e2e_pg_loss.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from paddle import nn
  18. import paddle
  19. from .det_basic_loss import DiceLoss
  20. from ppocr.utils.e2e_utils.extract_batchsize import pre_process
  21. class PGLoss(nn.Layer):
  22. def __init__(self,
  23. tcl_bs,
  24. max_text_length,
  25. max_text_nums,
  26. pad_num,
  27. eps=1e-6,
  28. **kwargs):
  29. super(PGLoss, self).__init__()
  30. self.tcl_bs = tcl_bs
  31. self.max_text_nums = max_text_nums
  32. self.max_text_length = max_text_length
  33. self.pad_num = pad_num
  34. self.dice_loss = DiceLoss(eps=eps)
  35. def border_loss(self, f_border, l_border, l_score, l_mask):
  36. l_border_split, l_border_norm = paddle.tensor.split(
  37. l_border, num_or_sections=[4, 1], axis=1)
  38. f_border_split = f_border
  39. b, c, h, w = l_border_norm.shape
  40. l_border_norm_split = paddle.expand(
  41. x=l_border_norm, shape=[b, 4 * c, h, w])
  42. b, c, h, w = l_score.shape
  43. l_border_score = paddle.expand(x=l_score, shape=[b, 4 * c, h, w])
  44. b, c, h, w = l_mask.shape
  45. l_border_mask = paddle.expand(x=l_mask, shape=[b, 4 * c, h, w])
  46. border_diff = l_border_split - f_border_split
  47. abs_border_diff = paddle.abs(border_diff)
  48. border_sign = abs_border_diff < 1.0
  49. border_sign = paddle.cast(border_sign, dtype='float32')
  50. border_sign.stop_gradient = True
  51. border_in_loss = 0.5 * abs_border_diff * abs_border_diff * border_sign + \
  52. (abs_border_diff - 0.5) * (1.0 - border_sign)
  53. border_out_loss = l_border_norm_split * border_in_loss
  54. border_loss = paddle.sum(border_out_loss * l_border_score * l_border_mask) / \
  55. (paddle.sum(l_border_score * l_border_mask) + 1e-5)
  56. return border_loss
  57. def direction_loss(self, f_direction, l_direction, l_score, l_mask):
  58. l_direction_split, l_direction_norm = paddle.tensor.split(
  59. l_direction, num_or_sections=[2, 1], axis=1)
  60. f_direction_split = f_direction
  61. b, c, h, w = l_direction_norm.shape
  62. l_direction_norm_split = paddle.expand(
  63. x=l_direction_norm, shape=[b, 2 * c, h, w])
  64. b, c, h, w = l_score.shape
  65. l_direction_score = paddle.expand(x=l_score, shape=[b, 2 * c, h, w])
  66. b, c, h, w = l_mask.shape
  67. l_direction_mask = paddle.expand(x=l_mask, shape=[b, 2 * c, h, w])
  68. direction_diff = l_direction_split - f_direction_split
  69. abs_direction_diff = paddle.abs(direction_diff)
  70. direction_sign = abs_direction_diff < 1.0
  71. direction_sign = paddle.cast(direction_sign, dtype='float32')
  72. direction_sign.stop_gradient = True
  73. direction_in_loss = 0.5 * abs_direction_diff * abs_direction_diff * direction_sign + \
  74. (abs_direction_diff - 0.5) * (1.0 - direction_sign)
  75. direction_out_loss = l_direction_norm_split * direction_in_loss
  76. direction_loss = paddle.sum(direction_out_loss * l_direction_score * l_direction_mask) / \
  77. (paddle.sum(l_direction_score * l_direction_mask) + 1e-5)
  78. return direction_loss
  79. def ctcloss(self, f_char, tcl_pos, tcl_mask, tcl_label, label_t):
  80. f_char = paddle.transpose(f_char, [0, 2, 3, 1])
  81. tcl_pos = paddle.reshape(tcl_pos, [-1, 3])
  82. tcl_pos = paddle.cast(tcl_pos, dtype=int)
  83. f_tcl_char = paddle.gather_nd(f_char, tcl_pos)
  84. f_tcl_char = paddle.reshape(
  85. f_tcl_char, [-1, 64, self.pad_num + 1]) # len(Lexicon_Table)+1
  86. f_tcl_char_fg, f_tcl_char_bg = paddle.split(
  87. f_tcl_char, [self.pad_num, 1], axis=2)
  88. f_tcl_char_bg = f_tcl_char_bg * tcl_mask + (1.0 - tcl_mask) * 20.0
  89. b, c, l = tcl_mask.shape
  90. tcl_mask_fg = paddle.expand(x=tcl_mask, shape=[b, c, self.pad_num * l])
  91. tcl_mask_fg.stop_gradient = True
  92. f_tcl_char_fg = f_tcl_char_fg * tcl_mask_fg + (1.0 - tcl_mask_fg) * (
  93. -20.0)
  94. f_tcl_char_mask = paddle.concat([f_tcl_char_fg, f_tcl_char_bg], axis=2)
  95. f_tcl_char_ld = paddle.transpose(f_tcl_char_mask, (1, 0, 2))
  96. N, B, _ = f_tcl_char_ld.shape
  97. input_lengths = paddle.to_tensor([N] * B, dtype='int64')
  98. cost = paddle.nn.functional.ctc_loss(
  99. log_probs=f_tcl_char_ld,
  100. labels=tcl_label,
  101. input_lengths=input_lengths,
  102. label_lengths=label_t,
  103. blank=self.pad_num,
  104. reduction='none')
  105. cost = cost.mean()
  106. return cost
  107. def forward(self, predicts, labels):
  108. images, tcl_maps, tcl_label_maps, border_maps \
  109. , direction_maps, training_masks, label_list, pos_list, pos_mask = labels
  110. # for all the batch_size
  111. pos_list, pos_mask, label_list, label_t = pre_process(
  112. label_list, pos_list, pos_mask, self.max_text_length,
  113. self.max_text_nums, self.pad_num, self.tcl_bs)
  114. f_score, f_border, f_direction, f_char = predicts['f_score'], predicts['f_border'], predicts['f_direction'], \
  115. predicts['f_char']
  116. score_loss = self.dice_loss(f_score, tcl_maps, training_masks)
  117. border_loss = self.border_loss(f_border, border_maps, tcl_maps,
  118. training_masks)
  119. direction_loss = self.direction_loss(f_direction, direction_maps,
  120. tcl_maps, training_masks)
  121. ctc_loss = self.ctcloss(f_char, pos_list, pos_mask, label_list, label_t)
  122. loss_all = score_loss + border_loss + direction_loss + 5 * ctc_loss
  123. losses = {
  124. 'loss': loss_all,
  125. "score_loss": score_loss,
  126. "border_loss": border_loss,
  127. "direction_loss": direction_loss,
  128. "ctc_loss": ctc_loss
  129. }
  130. return losses