det_drrg_loss.py 8.4 KB

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/losses/drrg_loss.py
"""

import paddle
import paddle.nn.functional as F
from paddle import nn


class DRRGLoss(nn.Layer):
    def __init__(self, ohem_ratio=3.0):
        super().__init__()
        # Maximum ratio of hard negative to positive samples kept by the
        # online hard example mining in balance_bce_loss.
        self.ohem_ratio = ohem_ratio
        self.downsample_ratio = 1.0

    def balance_bce_loss(self, pred, gt, mask):
        """Balanced Binary-CrossEntropy Loss.

        Args:
            pred (Tensor): Shape of :math:`(1, H, W)`.
            gt (Tensor): Shape of :math:`(1, H, W)`.
            mask (Tensor): Shape of :math:`(1, H, W)`.

        Returns:
            Tensor: Balanced bce loss.
        """
        assert pred.shape == gt.shape == mask.shape
        assert paddle.all(pred >= 0) and paddle.all(pred <= 1)
        assert paddle.all(gt >= 0) and paddle.all(gt <= 1)
        positive = gt * mask
        negative = (1 - gt) * mask
        positive_count = int(positive.sum())

        if positive_count > 0:
            loss = F.binary_cross_entropy(pred, gt, reduction='none')
            positive_loss = paddle.sum(loss * positive)
            negative_loss = loss * negative
            # OHEM: keep at most ohem_ratio hard negatives per positive pixel.
            negative_count = min(
                int(negative.sum()), int(positive_count * self.ohem_ratio))
        else:
            positive_loss = paddle.to_tensor(0.0)
            loss = F.binary_cross_entropy(pred, gt, reduction='none')
            negative_loss = loss * negative
            negative_count = 100
        negative_loss, _ = paddle.topk(
            negative_loss.reshape([-1]), negative_count)

        balance_loss = (positive_loss + paddle.sum(negative_loss)) / (
            float(positive_count + negative_count) + 1e-5)

        return balance_loss
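
    # Illustrative use, as a hedged sketch with assumed shapes (a single
    # 40x40 map); the prediction must already be a probability map in [0, 1]:
    #
    #   loss_fn = DRRGLoss()
    #   pred = F.sigmoid(paddle.randn([1, 40, 40]))
    #   gt = (paddle.rand([1, 40, 40]) > 0.5).astype('float32')
    #   mask = paddle.ones([1, 40, 40])
    #   loss = loss_fn.balance_bce_loss(pred, gt, mask)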

    def gcn_loss(self, gcn_data):
        """CrossEntropy Loss from gcn module.

        Args:
            gcn_data (tuple(Tensor, Tensor)): The first is the
                prediction with shape :math:`(N, 2)` and the
                second is the gt label with shape :math:`(m, n)`
                where :math:`m * n = N`.

        Returns:
            Tensor: CrossEntropy loss.
        """
        gcn_pred, gt_labels = gcn_data
        gt_labels = gt_labels.reshape([-1])
        loss = F.cross_entropy(gcn_pred, gt_labels)

        return loss
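
    # Illustrative use, with assumed shapes: (N, 2) logits for N node pairs
    # and an (m, n) label map that is flattened internally, m * n = N:
    #
    #   gcn_pred = paddle.randn([8, 2])
    #   gt_labels = paddle.randint(0, 2, [2, 4])
    #   loss = DRRGLoss().gcn_loss((gcn_pred, gt_labels))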

    def bitmasks2tensor(self, bitmasks, target_sz):
        """Convert Bitmasks to tensor.

        Args:
            bitmasks (list[BitmapMasks]): The BitmapMasks list. Each item is
                for one img.
            target_sz (tuple(int, int)): The target tensor of size
                :math:`(H, W)`.

        Returns:
            list[Tensor]: The list of kernel tensors. Each element stands for
                one kernel level.
        """
        batch_size = len(bitmasks)
        results = []

        kernel = []
        for batch_inx in range(batch_size):
            mask = bitmasks[batch_inx]
            # hxw
            mask_sz = mask.shape
            # left, right, top, bottom
            pad = [0, target_sz[1] - mask_sz[1], 0, target_sz[0] - mask_sz[0]]
            mask = F.pad(mask, pad, mode='constant', value=0)
            kernel.append(mask)
        kernel = paddle.stack(kernel)
        results.append(kernel)

        return results
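
    # Illustrative use, with assumed sizes: pad a batch of 32x32 masks to a
    # 40x40 feature size and stack them into a single tensor:
    #
    #   masks = paddle.ones([2, 32, 32])
    #   padded = DRRGLoss().bitmasks2tensor(masks, (40, 40))
    #   # padded[0].shape -> [2, 40, 40]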

    def forward(self, preds, labels):
        """Compute DRRG loss.

        Args:
            preds (tuple): ``(pred_maps, gcn_data)``, where ``pred_maps``
                holds the six prediction channels (text region, center
                region, sin, cos, top height, bottom height) and
                ``gcn_data`` is the input of :meth:`gcn_loss`.
            labels (list[Tensor]): Ground-truth maps; indices 1 to 7 are the
                text mask, center region mask, effective mask, top/bottom
                height maps and sin/cos maps.

        Returns:
            dict: The overall loss and its components.
        """
        assert isinstance(preds, tuple)
        gt_text_mask, gt_center_region_mask, gt_mask, gt_top_height_map, \
            gt_bot_height_map, gt_sin_map, gt_cos_map = labels[1:8]

        downsample_ratio = self.downsample_ratio

        pred_maps, gcn_data = preds
        pred_text_region = pred_maps[:, 0, :, :]
        pred_center_region = pred_maps[:, 1, :, :]
        pred_sin_map = pred_maps[:, 2, :, :]
        pred_cos_map = pred_maps[:, 3, :, :]
        pred_top_height_map = pred_maps[:, 4, :, :]
        pred_bot_height_map = pred_maps[:, 5, :, :]
        feature_sz = pred_maps.shape

        # bitmask 2 tensor
        mapping = {
            'gt_text_mask': paddle.cast(gt_text_mask, 'float32'),
            'gt_center_region_mask': paddle.cast(gt_center_region_mask,
                                                 'float32'),
            'gt_mask': paddle.cast(gt_mask, 'float32'),
            'gt_top_height_map': paddle.cast(gt_top_height_map, 'float32'),
            'gt_bot_height_map': paddle.cast(gt_bot_height_map, 'float32'),
            'gt_sin_map': paddle.cast(gt_sin_map, 'float32'),
            'gt_cos_map': paddle.cast(gt_cos_map, 'float32')
        }
        gt = {}
        for key, value in mapping.items():
            gt[key] = value
            # downsample_ratio is fixed to 1.0 in __init__, so only the first
            # branch is exercised.
            if abs(downsample_ratio - 1.0) < 1e-2:
                gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
            else:
                gt[key] = [item.rescale(downsample_ratio) for item in gt[key]]
                gt[key] = self.bitmasks2tensor(gt[key], feature_sz[2:])
                if key in ['gt_top_height_map', 'gt_bot_height_map']:
                    gt[key] = [item * downsample_ratio for item in gt[key]]
            # No-op here; the reference implementation uses this loop for a
            # device transfer.
            gt[key] = [item for item in gt[key]]

        # Normalize the predicted direction field so that sin^2 + cos^2 = 1.
        scale = paddle.sqrt(1.0 / (pred_sin_map**2 + pred_cos_map**2 + 1e-8))
        pred_sin_map = pred_sin_map * scale
        pred_cos_map = pred_cos_map * scale

        loss_text = self.balance_bce_loss(
            F.sigmoid(pred_text_region), gt['gt_text_mask'][0],
            gt['gt_mask'][0])

        text_mask = (gt['gt_text_mask'][0] * gt['gt_mask'][0])
        negative_text_mask = ((1 - gt['gt_text_mask'][0]) * gt['gt_mask'][0])
        loss_center_map = F.binary_cross_entropy(
            F.sigmoid(pred_center_region),
            gt['gt_center_region_mask'][0],
            reduction='none')
        if int(text_mask.sum()) > 0:
            loss_center_positive = paddle.sum(
                loss_center_map * text_mask) / paddle.sum(text_mask)
        else:
            loss_center_positive = paddle.to_tensor(0.0)
        loss_center_negative = paddle.sum(
            loss_center_map * negative_text_mask) / paddle.sum(
                negative_text_mask)
        loss_center = loss_center_positive + 0.5 * loss_center_negative

        center_mask = (gt['gt_center_region_mask'][0] * gt['gt_mask'][0])
        if int(center_mask.sum()) > 0:
            map_sz = pred_top_height_map.shape
            ones = paddle.ones(map_sz, dtype='float32')
            loss_top = F.smooth_l1_loss(
                pred_top_height_map / (gt['gt_top_height_map'][0] + 1e-2),
                ones,
                reduction='none')
            loss_bot = F.smooth_l1_loss(
                pred_bot_height_map / (gt['gt_bot_height_map'][0] + 1e-2),
                ones,
                reduction='none')
            gt_height = (
                gt['gt_top_height_map'][0] + gt['gt_bot_height_map'][0])
            loss_height = paddle.sum(
                (paddle.log(gt_height + 1) *
                 (loss_top + loss_bot)) * center_mask) / paddle.sum(
                     center_mask)

            loss_sin = paddle.sum(
                F.smooth_l1_loss(
                    pred_sin_map, gt['gt_sin_map'][0],
                    reduction='none') * center_mask) / paddle.sum(center_mask)
            loss_cos = paddle.sum(
                F.smooth_l1_loss(
                    pred_cos_map, gt['gt_cos_map'][0],
                    reduction='none') * center_mask) / paddle.sum(center_mask)
        else:
            loss_height = paddle.to_tensor(0.0)
            loss_sin = paddle.to_tensor(0.0)
            loss_cos = paddle.to_tensor(0.0)

        loss_gcn = self.gcn_loss(gcn_data)

        loss = (loss_text + loss_center + loss_height + loss_sin + loss_cos +
                loss_gcn)

        results = dict(
            loss=loss,
            loss_text=loss_text,
            loss_center=loss_center,
            loss_height=loss_height,
            loss_sin=loss_sin,
            loss_cos=loss_cos,
            loss_gcn=loss_gcn)

        return results
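

# Minimal usage sketch with dummy data; every shape below is an assumption
# chosen only to satisfy the interfaces above: a 2-image batch, a 6-channel
# 40x40 prediction map, and gcn logits over 8 node pairs.
if __name__ == "__main__":
    paddle.seed(0)
    loss_fn = DRRGLoss(ohem_ratio=3.0)

    batch, h, w = 2, 40, 40
    pred_maps = paddle.randn([batch, 6, h, w])
    # gcn head: (N, 2) logits and (m, n) labels with m * n = N.
    gcn_data = (paddle.randn([8, 2]), paddle.randint(0, 2, [2, 4]))

    ones = paddle.ones([batch, h, w], dtype='float32')
    # Half-text / half-background mask so both positive and negative pixels
    # exist for the balanced BCE and the center-region loss.
    gt_text = paddle.concat(
        [paddle.ones([batch, h, w // 2]), paddle.zeros([batch, h, w // 2])],
        axis=2)
    labels = [
        None,  # index 0 is not used by this loss
        gt_text,  # gt_text_mask
        ones,  # gt_center_region_mask
        ones,  # gt_mask (effective regions)
        ones * 4.0,  # gt_top_height_map
        ones * 4.0,  # gt_bot_height_map
        ones * 0.6,  # gt_sin_map
        ones * 0.8,  # gt_cos_map
    ]

    results = loss_fn((pred_maps, gcn_data), labels)
    print({k: float(v) for k, v in results.items()})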