matcher.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from ppstructure.table.table_master_match import deal_eb_token, deal_bb
  16. def distance(box_1, box_2):
  17. x1, y1, x2, y2 = box_1
  18. x3, y3, x4, y4 = box_2
  19. dis = abs(x3 - x1) + abs(y3 - y1) + abs(x4 - x2) + abs(y4 - y2)
  20. dis_2 = abs(x3 - x1) + abs(y3 - y1)
  21. dis_3 = abs(x4 - x2) + abs(y4 - y2)
  22. return dis + min(dis_2, dis_3)
  23. def compute_iou(rec1, rec2):
  24. """
  25. computing IoU
  26. :param rec1: (y0, x0, y1, x1), which reflects
  27. (top, left, bottom, right)
  28. :param rec2: (y0, x0, y1, x1)
  29. :return: scala value of IoU
  30. """
  31. # computing area of each rectangles
  32. S_rec1 = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
  33. S_rec2 = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
  34. # computing the sum_area
  35. sum_area = S_rec1 + S_rec2
  36. # find the each edge of intersect rectangle
  37. left_line = max(rec1[1], rec2[1])
  38. right_line = min(rec1[3], rec2[3])
  39. top_line = max(rec1[0], rec2[0])
  40. bottom_line = min(rec1[2], rec2[2])
  41. # judge if there is an intersect
  42. if left_line >= right_line or top_line >= bottom_line:
  43. return 0.0
  44. else:
  45. intersect = (right_line - left_line) * (bottom_line - top_line)
  46. return (intersect / (sum_area - intersect)) * 1.0
  47. class TableMatch:
  48. def __init__(self, filter_ocr_result=False, use_master=False):
  49. self.filter_ocr_result = filter_ocr_result
  50. self.use_master = use_master
  51. def __call__(self, structure_res, dt_boxes, rec_res):
  52. pred_structures, pred_bboxes = structure_res
  53. if self.filter_ocr_result:
  54. dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes,
  55. rec_res)
  56. matched_index = self.match_result(dt_boxes, pred_bboxes)
  57. if self.use_master:
  58. pred_html, pred = self.get_pred_html_master(pred_structures,
  59. matched_index, rec_res)
  60. else:
  61. pred_html, pred = self.get_pred_html(pred_structures, matched_index,
  62. rec_res)
  63. return pred_html
  64. def match_result(self, dt_boxes, pred_bboxes):
  65. matched = {}
  66. for i, gt_box in enumerate(dt_boxes):
  67. distances = []
  68. for j, pred_box in enumerate(pred_bboxes):
  69. if len(pred_box) == 8:
  70. pred_box = [
  71. np.min(pred_box[0::2]), np.min(pred_box[1::2]),
  72. np.max(pred_box[0::2]), np.max(pred_box[1::2])
  73. ]
  74. distances.append((distance(gt_box, pred_box),
  75. 1. - compute_iou(gt_box, pred_box)
  76. )) # compute iou and l1 distance
  77. sorted_distances = distances.copy()
  78. # select det box by iou and l1 distance
  79. sorted_distances = sorted(
  80. sorted_distances, key=lambda item: (item[1], item[0]))
  81. if distances.index(sorted_distances[0]) not in matched.keys():
  82. matched[distances.index(sorted_distances[0])] = [i]
  83. else:
  84. matched[distances.index(sorted_distances[0])].append(i)
  85. return matched
  86. def get_pred_html(self, pred_structures, matched_index, ocr_contents):
  87. end_html = []
  88. td_index = 0
  89. for tag in pred_structures:
  90. if '</td>' in tag:
  91. if '<td></td>' == tag:
  92. end_html.extend('<td>')
  93. if td_index in matched_index.keys():
  94. b_with = False
  95. if '<b>' in ocr_contents[matched_index[td_index][
  96. 0]] and len(matched_index[td_index]) > 1:
  97. b_with = True
  98. end_html.extend('<b>')
  99. for i, td_index_index in enumerate(matched_index[td_index]):
  100. content = ocr_contents[td_index_index][0]
  101. if len(matched_index[td_index]) > 1:
  102. if len(content) == 0:
  103. continue
  104. if content[0] == ' ':
  105. content = content[1:]
  106. if '<b>' in content:
  107. content = content[3:]
  108. if '</b>' in content:
  109. content = content[:-4]
  110. if len(content) == 0:
  111. continue
  112. if i != len(matched_index[
  113. td_index]) - 1 and ' ' != content[-1]:
  114. content += ' '
  115. end_html.extend(content)
  116. if b_with:
  117. end_html.extend('</b>')
  118. if '<td></td>' == tag:
  119. end_html.append('</td>')
  120. else:
  121. end_html.append(tag)
  122. td_index += 1
  123. else:
  124. end_html.append(tag)
  125. return ''.join(end_html), end_html
  126. def get_pred_html_master(self, pred_structures, matched_index,
  127. ocr_contents):
  128. end_html = []
  129. td_index = 0
  130. for token in pred_structures:
  131. if '</td>' in token:
  132. txt = ''
  133. b_with = False
  134. if td_index in matched_index.keys():
  135. if '<b>' in ocr_contents[matched_index[td_index][
  136. 0]] and len(matched_index[td_index]) > 1:
  137. b_with = True
  138. for i, td_index_index in enumerate(matched_index[td_index]):
  139. content = ocr_contents[td_index_index][0]
  140. if len(matched_index[td_index]) > 1:
  141. if len(content) == 0:
  142. continue
  143. if content[0] == ' ':
  144. content = content[1:]
  145. if '<b>' in content:
  146. content = content[3:]
  147. if '</b>' in content:
  148. content = content[:-4]
  149. if len(content) == 0:
  150. continue
  151. if i != len(matched_index[
  152. td_index]) - 1 and ' ' != content[-1]:
  153. content += ' '
  154. txt += content
  155. if b_with:
  156. txt = '<b>{}</b>'.format(txt)
  157. if '<td></td>' == token:
  158. token = '<td>{}</td>'.format(txt)
  159. else:
  160. token = '{}</td>'.format(txt)
  161. td_index += 1
  162. token = deal_eb_token(token)
  163. end_html.append(token)
  164. html = ''.join(end_html)
  165. html = deal_bb(html)
  166. return html, end_html
  167. def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
  168. y1 = pred_bboxes[:, 1::2].min()
  169. new_dt_boxes = []
  170. new_rec_res = []
  171. for box, rec in zip(dt_boxes, rec_res):
  172. if np.max(box[1::2]) < y1:
  173. continue
  174. new_dt_boxes.append(box)
  175. new_rec_res.append(rec)
  176. return new_dt_boxes, new_rec_res