# rec_metric.py

# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import string
from difflib import SequenceMatcher

import numpy as np
from rapidfuzz.distance import Levenshtein


class RecMetric(object):
    """Text recognition metric: exact-match accuracy and 1 - normalized edit distance."""

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 ignore_space=True,
                 **kwargs):
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        self.eps = 1e-5
        self.reset()

    def _normalize_text(self, text):
        # Keep only ASCII letters and digits, then lowercase.
        text = ''.join(
            filter(lambda x: x in (string.digits + string.ascii_letters), text))
        return text.lower()

    def __call__(self, pred_label, *args, **kwargs):
        preds, labels = pred_label
        correct_num = 0
        all_num = 0
        norm_edit_dis = 0.0
        for (pred, pred_conf), (target, _) in zip(preds, labels):
            if self.ignore_space:
                pred = pred.replace(" ", "")
                target = target.replace(" ", "")
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
            norm_edit_dis += Levenshtein.normalized_distance(pred, target)
            if pred == target:
                correct_num += 1
            all_num += 1
        # Accumulate running totals so get_metric() can report epoch-level values.
        self.correct_num += correct_num
        self.all_num += all_num
        self.norm_edit_dis += norm_edit_dis
        return {
            'acc': correct_num / (all_num + self.eps),
            'norm_edit_dis': 1 - norm_edit_dis / (all_num + self.eps)
        }

    def get_metric(self):
        """
        return metrics {
            'acc': 0,
            'norm_edit_dis': 0,
        }
        """
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
        self.reset()
        return {'acc': acc, 'norm_edit_dis': norm_edit_dis}

    def reset(self):
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0
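
# Usage sketch for RecMetric (illustrative only, not part of the original module).
# It assumes the post-processing step feeding this metric yields lists of
# (text, confidence) pairs for both predictions and labels, which is what
# __call__ unpacks above:
#
#     metric = RecMetric(is_filter=True)
#     preds = [('hello', 0.98), ('W0rld', 0.75)]
#     labels = [('hello', 1.0), ('world', 1.0)]
#     metric((preds, labels))   # per-batch: {'acc': ~0.5, 'norm_edit_dis': ~0.9}
#     metric.get_metric()       # epoch-level values accumulated so far, then reset()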


class CNTMetric(object):
    """Accuracy metric computed by direct equality between predictions and labels."""

    def __init__(self, main_indicator='acc', **kwargs):
        self.main_indicator = main_indicator
        self.eps = 1e-5
        self.reset()

    def __call__(self, pred_label, *args, **kwargs):
        preds, labels = pred_label
        correct_num = 0
        all_num = 0
        for pred, target in zip(preds, labels):
            if pred == target:
                correct_num += 1
            all_num += 1
        # Accumulate running totals so get_metric() can report epoch-level values.
        self.correct_num += correct_num
        self.all_num += all_num
        return {'acc': correct_num / (all_num + self.eps), }

    def get_metric(self):
        """
        return metrics {
            'acc': 0,
        }
        """
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        self.reset()
        return {'acc': acc}

    def reset(self):
        self.correct_num = 0
        self.all_num = 0
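
# Usage sketch for CNTMetric (illustrative only). Predictions and labels are
# compared with == directly, so any values supporting equality work, e.g. the
# per-image character counts of a counting head:
#
#     cnt = CNTMetric()
#     cnt(([3, 5, 2], [3, 4, 2]))   # per-batch: {'acc': ~0.667}
#     cnt.get_metric()              # accumulated accuracy, then reset()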


class CANMetric(object):
    """Word-level match rate and expression-level accuracy for CAN predictions."""

    def __init__(self, main_indicator='exp_rate', **kwargs):
        self.main_indicator = main_indicator
        self.word_right = []
        self.exp_right = []
        self.word_total_length = 0
        self.exp_total_num = 0
        self.word_rate = 0
        self.exp_rate = 0
        self.reset()
        self.epoch_reset()

    def __call__(self, preds, batch, **kwargs):
        # Clear the epoch accumulators when the caller passes epoch_reset=True.
        epoch_reset = kwargs.get('epoch_reset', False)
        if epoch_reset:
            self.epoch_reset()
        word_probs = preds
        word_label, word_label_mask = batch
        line_right = 0
        if word_probs is not None:
            word_pred = word_probs.argmax(2)
            word_pred = word_pred.cpu().detach().numpy()
            # Per-sample similarity between the predicted and labeled token
            # sequences, truncated to the valid length given by the label mask.
            word_scores = [
                SequenceMatcher(
                    None,
                    s1[:int(np.sum(s3))],
                    s2[:int(np.sum(s3))],
                    autojunk=False).ratio() * (
                        len(s1[:int(np.sum(s3))]) + len(s2[:int(np.sum(s3))]))
                / len(s1[:int(np.sum(s3))]) / 2
                for s1, s2, s3 in zip(word_label, word_pred, word_label_mask)
            ]
            batch_size = len(word_scores)
            for i in range(batch_size):
                if word_scores[i] == 1:
                    line_right += 1
            self.word_rate = np.mean(word_scores)  # float
            self.exp_rate = line_right / batch_size  # float
            exp_length, word_length = word_label.shape[:2]
            self.word_right.append(self.word_rate * word_length)
            self.exp_right.append(self.exp_rate * exp_length)
            self.word_total_length = self.word_total_length + word_length
            self.exp_total_num = self.exp_total_num + exp_length

    def get_metric(self):
        """
        return {
            'word_rate': 0,
            "exp_rate": 0,
        }
        """
        cur_word_rate = sum(self.word_right) / self.word_total_length
        cur_exp_rate = sum(self.exp_right) / self.exp_total_num
        self.reset()
        return {'word_rate': cur_word_rate, "exp_rate": cur_exp_rate}

    def reset(self):
        self.word_rate = 0
        self.exp_rate = 0

    def epoch_reset(self):
        self.word_right = []
        self.exp_right = []
        self.word_total_length = 0
        self.exp_total_num = 0
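
# Usage sketch for CANMetric (illustrative only; the shapes and the paddle tensor
# below are assumptions, not taken from the original file). `preds` is expected
# to be a framework tensor of word probabilities shaped [batch, seq_len, vocab]
# (argmax over axis 2 is followed by .cpu().detach().numpy()), and `batch`
# carries the padded label ids and their mask, both shaped [batch, seq_len]:
#
#     import paddle
#     can = CANMetric()
#     word_probs = paddle.rand([2, 10, 111])            # dummy probabilities
#     word_label = np.zeros([2, 10], dtype='int64')     # padded label ids
#     word_label_mask = np.ones([2, 10])                # 1 marks valid positions
#     can(word_probs, [word_label, word_label_mask], epoch_reset=False)
#     can.get_metric()   # {'word_rate': ..., 'exp_rate': ...}, then reset()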