table_metric.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ppocr.metrics.det_metric import DetMetric
  16. class TableStructureMetric(object):
  17. def __init__(self,
  18. main_indicator='acc',
  19. eps=1e-6,
  20. del_thead_tbody=False,
  21. **kwargs):
  22. self.main_indicator = main_indicator
  23. self.eps = eps
  24. self.del_thead_tbody = del_thead_tbody
  25. self.reset()
  26. def __call__(self, pred_label, batch=None, *args, **kwargs):
  27. preds, labels = pred_label
  28. pred_structure_batch_list = preds['structure_batch_list']
  29. gt_structure_batch_list = labels['structure_batch_list']
  30. correct_num = 0
  31. all_num = 0
  32. for (pred, pred_conf), target in zip(pred_structure_batch_list,
  33. gt_structure_batch_list):
  34. pred_str = ''.join(pred)
  35. target_str = ''.join(target)
  36. if self.del_thead_tbody:
  37. pred_str = pred_str.replace('<thead>', '').replace(
  38. '</thead>', '').replace('<tbody>', '').replace('</tbody>',
  39. '')
  40. target_str = target_str.replace('<thead>', '').replace(
  41. '</thead>', '').replace('<tbody>', '').replace('</tbody>',
  42. '')
  43. if pred_str == target_str:
  44. correct_num += 1
  45. all_num += 1
  46. self.correct_num += correct_num
  47. self.all_num += all_num
  48. def get_metric(self):
  49. """
  50. return metrics {
  51. 'acc': 0,
  52. }
  53. """
  54. acc = 1.0 * self.correct_num / (self.all_num + self.eps)
  55. self.reset()
  56. return {'acc': acc}
  57. def reset(self):
  58. self.correct_num = 0
  59. self.all_num = 0
  60. self.len_acc_num = 0
  61. self.token_nums = 0
  62. self.anys_dict = dict()
  63. class TableMetric(object):
  64. def __init__(self,
  65. main_indicator='acc',
  66. compute_bbox_metric=False,
  67. box_format='xyxy',
  68. del_thead_tbody=False,
  69. **kwargs):
  70. """
  71. @param sub_metrics: configs of sub_metric
  72. @param main_matric: main_matric for save best_model
  73. @param kwargs:
  74. """
  75. self.structure_metric = TableStructureMetric(
  76. del_thead_tbody=del_thead_tbody)
  77. self.bbox_metric = DetMetric() if compute_bbox_metric else None
  78. self.main_indicator = main_indicator
  79. self.box_format = box_format
  80. self.reset()
  81. def __call__(self, pred_label, batch=None, *args, **kwargs):
  82. self.structure_metric(pred_label)
  83. if self.bbox_metric is not None:
  84. self.bbox_metric(*self.prepare_bbox_metric_input(pred_label))
  85. def prepare_bbox_metric_input(self, pred_label):
  86. pred_bbox_batch_list = []
  87. gt_ignore_tags_batch_list = []
  88. gt_bbox_batch_list = []
  89. preds, labels = pred_label
  90. batch_num = len(preds['bbox_batch_list'])
  91. for batch_idx in range(batch_num):
  92. # pred
  93. pred_bbox_list = [
  94. self.format_box(pred_box)
  95. for pred_box in preds['bbox_batch_list'][batch_idx]
  96. ]
  97. pred_bbox_batch_list.append({'points': pred_bbox_list})
  98. # gt
  99. gt_bbox_list = []
  100. gt_ignore_tags_list = []
  101. for gt_box in labels['bbox_batch_list'][batch_idx]:
  102. gt_bbox_list.append(self.format_box(gt_box))
  103. gt_ignore_tags_list.append(0)
  104. gt_bbox_batch_list.append(gt_bbox_list)
  105. gt_ignore_tags_batch_list.append(gt_ignore_tags_list)
  106. return [
  107. pred_bbox_batch_list,
  108. [0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list]
  109. ]
  110. def get_metric(self):
  111. structure_metric = self.structure_metric.get_metric()
  112. if self.bbox_metric is None:
  113. return structure_metric
  114. bbox_metric = self.bbox_metric.get_metric()
  115. if self.main_indicator == self.bbox_metric.main_indicator:
  116. output = bbox_metric
  117. for sub_key in structure_metric:
  118. output["structure_metric_{}".format(
  119. sub_key)] = structure_metric[sub_key]
  120. else:
  121. output = structure_metric
  122. for sub_key in bbox_metric:
  123. output["bbox_metric_{}".format(sub_key)] = bbox_metric[sub_key]
  124. return output
  125. def reset(self):
  126. self.structure_metric.reset()
  127. if self.bbox_metric is not None:
  128. self.bbox_metric.reset()
  129. def format_box(self, box):
  130. if self.box_format == 'xyxy':
  131. x1, y1, x2, y2 = box
  132. box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
  133. elif self.box_format == 'xywh':
  134. x, y, w, h = box
  135. x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
  136. box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
  137. elif self.box_format == 'xyxyxyxy':
  138. x1, y1, x2, y2, x3, y3, x4, y4 = box
  139. box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
  140. return box