det_basic_loss.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. # copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/WenmuZhou/DBNet.pytorch/blob/master/models/losses/basic_loss.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import numpy as np
  22. import paddle
  23. from paddle import nn
  24. import paddle.nn.functional as F
  25. class BalanceLoss(nn.Layer):
  26. def __init__(self,
  27. balance_loss=True,
  28. main_loss_type='DiceLoss',
  29. negative_ratio=3,
  30. return_origin=False,
  31. eps=1e-6,
  32. **kwargs):
  33. """
  34. The BalanceLoss for Differentiable Binarization text detection
  35. args:
  36. balance_loss (bool): whether balance loss or not, default is True
  37. main_loss_type (str): can only be one of ['CrossEntropy','DiceLoss',
  38. 'Euclidean','BCELoss', 'MaskL1Loss'], default is 'DiceLoss'.
  39. negative_ratio (int|float): float, default is 3.
  40. return_origin (bool): whether return unbalanced loss or not, default is False.
  41. eps (float): default is 1e-6.
  42. """
  43. super(BalanceLoss, self).__init__()
  44. self.balance_loss = balance_loss
  45. self.main_loss_type = main_loss_type
  46. self.negative_ratio = negative_ratio
  47. self.return_origin = return_origin
  48. self.eps = eps
  49. if self.main_loss_type == "CrossEntropy":
  50. self.loss = nn.CrossEntropyLoss()
  51. elif self.main_loss_type == "Euclidean":
  52. self.loss = nn.MSELoss()
  53. elif self.main_loss_type == "DiceLoss":
  54. self.loss = DiceLoss(self.eps)
  55. elif self.main_loss_type == "BCELoss":
  56. self.loss = BCELoss(reduction='none')
  57. elif self.main_loss_type == "MaskL1Loss":
  58. self.loss = MaskL1Loss(self.eps)
  59. else:
  60. loss_type = [
  61. 'CrossEntropy', 'DiceLoss', 'Euclidean', 'BCELoss', 'MaskL1Loss'
  62. ]
  63. raise Exception(
  64. "main_loss_type in BalanceLoss() can only be one of {}".format(
  65. loss_type))
  66. def forward(self, pred, gt, mask=None):
  67. """
  68. The BalanceLoss for Differentiable Binarization text detection
  69. args:
  70. pred (variable): predicted feature maps.
  71. gt (variable): ground truth feature maps.
  72. mask (variable): masked maps.
  73. return: (variable) balanced loss
  74. """
  75. positive = gt * mask
  76. negative = (1 - gt) * mask
  77. positive_count = int(positive.sum())
  78. negative_count = int(
  79. min(negative.sum(), positive_count * self.negative_ratio))
  80. loss = self.loss(pred, gt, mask=mask)
  81. if not self.balance_loss:
  82. return loss
  83. positive_loss = positive * loss
  84. negative_loss = negative * loss
  85. negative_loss = paddle.reshape(negative_loss, shape=[-1])
  86. if negative_count > 0:
  87. sort_loss = negative_loss.sort(descending=True)
  88. negative_loss = sort_loss[:negative_count]
  89. # negative_loss, _ = paddle.topk(negative_loss, k=negative_count_int)
  90. balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
  91. positive_count + negative_count + self.eps)
  92. else:
  93. balance_loss = positive_loss.sum() / (positive_count + self.eps)
  94. if self.return_origin:
  95. return balance_loss, loss
  96. return balance_loss
  97. class DiceLoss(nn.Layer):
  98. def __init__(self, eps=1e-6):
  99. super(DiceLoss, self).__init__()
  100. self.eps = eps
  101. def forward(self, pred, gt, mask, weights=None):
  102. """
  103. DiceLoss function.
  104. """
  105. assert pred.shape == gt.shape
  106. assert pred.shape == mask.shape
  107. if weights is not None:
  108. assert weights.shape == mask.shape
  109. mask = weights * mask
  110. intersection = paddle.sum(pred * gt * mask)
  111. union = paddle.sum(pred * mask) + paddle.sum(gt * mask) + self.eps
  112. loss = 1 - 2.0 * intersection / union
  113. assert loss <= 1
  114. return loss
  115. class MaskL1Loss(nn.Layer):
  116. def __init__(self, eps=1e-6):
  117. super(MaskL1Loss, self).__init__()
  118. self.eps = eps
  119. def forward(self, pred, gt, mask):
  120. """
  121. Mask L1 Loss
  122. """
  123. loss = (paddle.abs(pred - gt) * mask).sum() / (mask.sum() + self.eps)
  124. loss = paddle.mean(loss)
  125. return loss
  126. class BCELoss(nn.Layer):
  127. def __init__(self, reduction='mean'):
  128. super(BCELoss, self).__init__()
  129. self.reduction = reduction
  130. def forward(self, input, label, mask=None, weight=None, name=None):
  131. loss = F.binary_cross_entropy(input, label, reduction=self.reduction)
  132. return loss