det_fce_loss.py 8.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/losses/fce_loss.py
  17. """
  18. import numpy as np
  19. from paddle import nn
  20. import paddle
  21. import paddle.nn.functional as F
  22. from functools import partial
  23. def multi_apply(func, *args, **kwargs):
  24. pfunc = partial(func, **kwargs) if kwargs else func
  25. map_results = map(pfunc, *args)
  26. return tuple(map(list, zip(*map_results)))
  27. class FCELoss(nn.Layer):
  28. """The class for implementing FCENet loss
  29. FCENet(CVPR2021): Fourier Contour Embedding for Arbitrary-shaped
  30. Text Detection
  31. [https://arxiv.org/abs/2104.10442]
  32. Args:
  33. fourier_degree (int) : The maximum Fourier transform degree k.
  34. num_sample (int) : The sampling points number of regression
  35. loss. If it is too small, fcenet tends to be overfitting.
  36. ohem_ratio (float): the negative/positive ratio in OHEM.
  37. """
  38. def __init__(self, fourier_degree, num_sample, ohem_ratio=3.):
  39. super().__init__()
  40. self.fourier_degree = fourier_degree
  41. self.num_sample = num_sample
  42. self.ohem_ratio = ohem_ratio
  43. def forward(self, preds, labels):
  44. assert isinstance(preds, dict)
  45. preds = preds['levels']
  46. p3_maps, p4_maps, p5_maps = labels[1:]
  47. assert p3_maps[0].shape[0] == 4 * self.fourier_degree + 5,\
  48. 'fourier degree not equal in FCEhead and FCEtarget'
  49. # to tensor
  50. gts = [p3_maps, p4_maps, p5_maps]
  51. for idx, maps in enumerate(gts):
  52. gts[idx] = paddle.to_tensor(np.stack(maps))
  53. losses = multi_apply(self.forward_single, preds, gts)
  54. loss_tr = paddle.to_tensor(0.).astype('float32')
  55. loss_tcl = paddle.to_tensor(0.).astype('float32')
  56. loss_reg_x = paddle.to_tensor(0.).astype('float32')
  57. loss_reg_y = paddle.to_tensor(0.).astype('float32')
  58. loss_all = paddle.to_tensor(0.).astype('float32')
  59. for idx, loss in enumerate(losses):
  60. loss_all += sum(loss)
  61. if idx == 0:
  62. loss_tr += sum(loss)
  63. elif idx == 1:
  64. loss_tcl += sum(loss)
  65. elif idx == 2:
  66. loss_reg_x += sum(loss)
  67. else:
  68. loss_reg_y += sum(loss)
  69. results = dict(
  70. loss=loss_all,
  71. loss_text=loss_tr,
  72. loss_center=loss_tcl,
  73. loss_reg_x=loss_reg_x,
  74. loss_reg_y=loss_reg_y, )
  75. return results
  76. def forward_single(self, pred, gt):
  77. cls_pred = paddle.transpose(pred[0], (0, 2, 3, 1))
  78. reg_pred = paddle.transpose(pred[1], (0, 2, 3, 1))
  79. gt = paddle.transpose(gt, (0, 2, 3, 1))
  80. k = 2 * self.fourier_degree + 1
  81. tr_pred = paddle.reshape(cls_pred[:, :, :, :2], (-1, 2))
  82. tcl_pred = paddle.reshape(cls_pred[:, :, :, 2:], (-1, 2))
  83. x_pred = paddle.reshape(reg_pred[:, :, :, 0:k], (-1, k))
  84. y_pred = paddle.reshape(reg_pred[:, :, :, k:2 * k], (-1, k))
  85. tr_mask = gt[:, :, :, :1].reshape([-1])
  86. tcl_mask = gt[:, :, :, 1:2].reshape([-1])
  87. train_mask = gt[:, :, :, 2:3].reshape([-1])
  88. x_map = paddle.reshape(gt[:, :, :, 3:3 + k], (-1, k))
  89. y_map = paddle.reshape(gt[:, :, :, 3 + k:], (-1, k))
  90. tr_train_mask = (train_mask * tr_mask).astype('bool')
  91. tr_train_mask2 = paddle.concat(
  92. [tr_train_mask.unsqueeze(1), tr_train_mask.unsqueeze(1)], axis=1)
  93. # tr loss
  94. loss_tr = self.ohem(tr_pred, tr_mask, train_mask)
  95. # tcl loss
  96. loss_tcl = paddle.to_tensor(0.).astype('float32')
  97. tr_neg_mask = tr_train_mask.logical_not()
  98. tr_neg_mask2 = paddle.concat(
  99. [tr_neg_mask.unsqueeze(1), tr_neg_mask.unsqueeze(1)], axis=1)
  100. if tr_train_mask.sum().item() > 0:
  101. loss_tcl_pos = F.cross_entropy(
  102. tcl_pred.masked_select(tr_train_mask2).reshape([-1, 2]),
  103. tcl_mask.masked_select(tr_train_mask).astype('int64'))
  104. loss_tcl_neg = F.cross_entropy(
  105. tcl_pred.masked_select(tr_neg_mask2).reshape([-1, 2]),
  106. tcl_mask.masked_select(tr_neg_mask).astype('int64'))
  107. loss_tcl = loss_tcl_pos + 0.5 * loss_tcl_neg
  108. # regression loss
  109. loss_reg_x = paddle.to_tensor(0.).astype('float32')
  110. loss_reg_y = paddle.to_tensor(0.).astype('float32')
  111. if tr_train_mask.sum().item() > 0:
  112. weight = (tr_mask.masked_select(tr_train_mask.astype('bool'))
  113. .astype('float32') + tcl_mask.masked_select(
  114. tr_train_mask.astype('bool')).astype('float32')) / 2
  115. weight = weight.reshape([-1, 1])
  116. ft_x, ft_y = self.fourier2poly(x_map, y_map)
  117. ft_x_pre, ft_y_pre = self.fourier2poly(x_pred, y_pred)
  118. dim = ft_x.shape[1]
  119. tr_train_mask3 = paddle.concat(
  120. [tr_train_mask.unsqueeze(1) for i in range(dim)], axis=1)
  121. loss_reg_x = paddle.mean(weight * F.smooth_l1_loss(
  122. ft_x_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
  123. ft_x.masked_select(tr_train_mask3).reshape([-1, dim]),
  124. reduction='none'))
  125. loss_reg_y = paddle.mean(weight * F.smooth_l1_loss(
  126. ft_y_pre.masked_select(tr_train_mask3).reshape([-1, dim]),
  127. ft_y.masked_select(tr_train_mask3).reshape([-1, dim]),
  128. reduction='none'))
  129. return loss_tr, loss_tcl, loss_reg_x, loss_reg_y
  130. def ohem(self, predict, target, train_mask):
  131. pos = (target * train_mask).astype('bool')
  132. neg = ((1 - target) * train_mask).astype('bool')
  133. pos2 = paddle.concat([pos.unsqueeze(1), pos.unsqueeze(1)], axis=1)
  134. neg2 = paddle.concat([neg.unsqueeze(1), neg.unsqueeze(1)], axis=1)
  135. n_pos = pos.astype('float32').sum()
  136. if n_pos.item() > 0:
  137. loss_pos = F.cross_entropy(
  138. predict.masked_select(pos2).reshape([-1, 2]),
  139. target.masked_select(pos).astype('int64'),
  140. reduction='sum')
  141. loss_neg = F.cross_entropy(
  142. predict.masked_select(neg2).reshape([-1, 2]),
  143. target.masked_select(neg).astype('int64'),
  144. reduction='none')
  145. n_neg = min(
  146. int(neg.astype('float32').sum().item()),
  147. int(self.ohem_ratio * n_pos.astype('float32')))
  148. else:
  149. loss_pos = paddle.to_tensor(0.)
  150. loss_neg = F.cross_entropy(
  151. predict.masked_select(neg2).reshape([-1, 2]),
  152. target.masked_select(neg).astype('int64'),
  153. reduction='none')
  154. n_neg = 100
  155. if len(loss_neg) > n_neg:
  156. loss_neg, _ = paddle.topk(loss_neg, n_neg)
  157. return (loss_pos + loss_neg.sum()) / (n_pos + n_neg).astype('float32')
  158. def fourier2poly(self, real_maps, imag_maps):
  159. """Transform Fourier coefficient maps to polygon maps.
  160. Args:
  161. real_maps (tensor): A map composed of the real parts of the
  162. Fourier coefficients, whose shape is (-1, 2k+1)
  163. imag_maps (tensor):A map composed of the imag parts of the
  164. Fourier coefficients, whose shape is (-1, 2k+1)
  165. Returns
  166. x_maps (tensor): A map composed of the x value of the polygon
  167. represented by n sample points (xn, yn), whose shape is (-1, n)
  168. y_maps (tensor): A map composed of the y value of the polygon
  169. represented by n sample points (xn, yn), whose shape is (-1, n)
  170. """
  171. k_vect = paddle.arange(
  172. -self.fourier_degree, self.fourier_degree + 1,
  173. dtype='float32').reshape([-1, 1])
  174. i_vect = paddle.arange(
  175. 0, self.num_sample, dtype='float32').reshape([1, -1])
  176. transform_matrix = 2 * np.pi / self.num_sample * paddle.matmul(k_vect,
  177. i_vect)
  178. x1 = paddle.einsum('ak, kn-> an', real_maps,
  179. paddle.cos(transform_matrix))
  180. x2 = paddle.einsum('ak, kn-> an', imag_maps,
  181. paddle.sin(transform_matrix))
  182. y1 = paddle.einsum('ak, kn-> an', real_maps,
  183. paddle.sin(transform_matrix))
  184. y2 = paddle.einsum('ak, kn-> an', imag_maps,
  185. paddle.cos(transform_matrix))
  186. x_maps = x1 - x2
  187. y_maps = y1 + y2
  188. return x_maps, y_maps