rec_can_loss.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/LBH1024/CAN/models/can.py
  17. """
  18. import paddle
  19. import paddle.nn as nn
  20. import numpy as np
  21. class CANLoss(nn.Layer):
  22. '''
  23. CANLoss is consist of two part:
  24. word_average_loss: average accuracy of the symbol
  25. counting_loss: counting loss of every symbol
  26. '''
  27. def __init__(self):
  28. super(CANLoss, self).__init__()
  29. self.use_label_mask = False
  30. self.out_channel = 111
  31. self.cross = nn.CrossEntropyLoss(
  32. reduction='none') if self.use_label_mask else nn.CrossEntropyLoss()
  33. self.counting_loss = nn.SmoothL1Loss(reduction='mean')
  34. self.ratio = 16
  35. def forward(self, preds, batch):
  36. word_probs = preds[0]
  37. counting_preds = preds[1]
  38. counting_preds1 = preds[2]
  39. counting_preds2 = preds[3]
  40. labels = batch[2]
  41. labels_mask = batch[3]
  42. counting_labels = gen_counting_label(labels, self.out_channel, True)
  43. counting_loss = self.counting_loss(counting_preds1, counting_labels) + self.counting_loss(counting_preds2, counting_labels) \
  44. + self.counting_loss(counting_preds, counting_labels)
  45. word_loss = self.cross(
  46. paddle.reshape(word_probs, [-1, word_probs.shape[-1]]),
  47. paddle.reshape(labels, [-1]))
  48. word_average_loss = paddle.sum(
  49. paddle.reshape(word_loss * labels_mask, [-1])) / (
  50. paddle.sum(labels_mask) + 1e-10
  51. ) if self.use_label_mask else word_loss
  52. loss = word_average_loss + counting_loss
  53. return {'loss': loss}
  54. def gen_counting_label(labels, channel, tag):
  55. b, t = labels.shape
  56. counting_labels = np.zeros([b, channel])
  57. if tag:
  58. ignore = [0, 1, 107, 108, 109, 110]
  59. else:
  60. ignore = []
  61. for i in range(b):
  62. for j in range(t):
  63. k = labels[i][j]
  64. if k in ignore:
  65. continue
  66. else:
  67. counting_labels[i][k] += 1
  68. counting_labels = paddle.to_tensor(counting_labels, dtype='float32')
  69. return counting_labels