123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192 |
- # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import numpy as np
- from ppstructure.table.table_master_match import deal_eb_token, deal_bb
def distance(box_1, box_2):
    """Return a corner-weighted L1 distance between two 4-value boxes.

    Each box is unpacked as ``(x_min, y_min, x_max, y_max)``. The result is
    the summed L1 distance of both corner pairs plus the smaller of the two
    single-corner distances, so a box aligned on one corner scores lower.
    """
    ax1, ay1, ax2, ay2 = box_1
    bx1, by1, bx2, by2 = box_2
    first_corner = abs(bx1 - ax1) + abs(by1 - ay1)
    second_corner = abs(bx2 - ax2) + abs(by2 - ay2)
    # Accumulate left-to-right exactly like a flat four-term sum.
    both_corners = first_corner + abs(bx2 - ax2) + abs(by2 - ay2)
    return both_corners + min(first_corner, second_corner)
def compute_iou(rec1, rec2):
    """Compute the intersection-over-union of two axis-aligned rectangles.

    :param rec1: (y0, x0, y1, x1), i.e. (top, left, bottom, right)
    :param rec2: (y0, x0, y1, x1)
    :return: scalar IoU in [0, 1]; 0.0 when the rectangles do not overlap

    Because IoU is symmetric in the two axes, passing both boxes as
    (x0, y0, x1, y1) instead yields the same value.
    """
    area_a = (rec1[2] - rec1[0]) * (rec1[3] - rec1[1])
    area_b = (rec2[2] - rec2[0]) * (rec2[3] - rec2[1])
    combined = area_a + area_b

    # Edges of the (possibly empty) overlap rectangle.
    overlap_left = max(rec1[1], rec2[1])
    overlap_right = min(rec1[3], rec2[3])
    overlap_top = max(rec1[0], rec2[0])
    overlap_bottom = min(rec1[2], rec2[2])

    if overlap_left >= overlap_right or overlap_top >= overlap_bottom:
        return 0.0

    overlap = (overlap_right - overlap_left) * (overlap_bottom - overlap_top)
    return overlap / (combined - overlap) * 1.0
class TableMatch:
    """Fill predicted table-structure tokens with matched OCR text.

    Each OCR detection box is assigned to the predicted cell box it overlaps
    best. The assigned OCR strings are then written into the HTML structure
    tokens produced by the table-structure model.
    """

    def __init__(self, filter_ocr_result=False, use_master=False):
        """
        :param filter_ocr_result: drop OCR boxes lying entirely above the
            topmost predicted cell before matching
        :param use_master: build HTML via the TableMaster post-processing
            path (``deal_eb_token`` / ``deal_bb``)
        """
        self.filter_ocr_result = filter_ocr_result
        self.use_master = use_master

    def __call__(self, structure_res, dt_boxes, rec_res):
        """Return the table HTML string.

        :param structure_res: ``(pred_structures, pred_bboxes)`` from the
            structure model
        :param dt_boxes: OCR detection boxes, unpacked as 4 values each
        :param rec_res: OCR recognition results; element ``[0]`` of each is
            the recognized text
        """
        pred_structures, pred_bboxes = structure_res
        if self.filter_ocr_result:
            dt_boxes, rec_res = self._filter_ocr_result(pred_bboxes, dt_boxes,
                                                        rec_res)
        matched_index = self.match_result(dt_boxes, pred_bboxes)
        if self.use_master:
            pred_html, _ = self.get_pred_html_master(pred_structures,
                                                     matched_index, rec_res)
        else:
            pred_html, _ = self.get_pred_html(pred_structures, matched_index,
                                              rec_res)
        return pred_html

    def match_result(self, dt_boxes, pred_bboxes):
        """Assign every OCR box to its best-matching predicted cell box.

        Cells are ranked by highest IoU first, then smallest ``distance``.
        Ties go to the lowest cell index.

        :return: dict mapping cell index (position in ``pred_bboxes``) to the
            list of ``dt_boxes`` indices assigned to it, in input order.
            Empty when either input is empty.
        """
        matched = {}
        if len(dt_boxes) == 0 or len(pred_bboxes) == 0:
            return matched

        # Normalise 8-value (quadrilateral) cell boxes to 4-value
        # axis-aligned boxes once, instead of once per OCR box.
        cell_boxes = []
        for pred_box in pred_bboxes:
            if len(pred_box) == 8:
                pred_box = [
                    np.min(pred_box[0::2]), np.min(pred_box[1::2]),
                    np.max(pred_box[0::2]), np.max(pred_box[1::2])
                ]
            cell_boxes.append(pred_box)

        for i, gt_box in enumerate(dt_boxes):
            costs = [(1. - compute_iou(gt_box, cell_box),
                      distance(gt_box, cell_box)) for cell_box in cell_boxes]
            # min() returns the first minimal index, which matches the
            # stable-sort-then-index selection used previously.
            best = min(range(len(costs)), key=costs.__getitem__)
            matched.setdefault(best, []).append(i)
        return matched

    @staticmethod
    def _merge_cell_text(indices, ocr_contents):
        """Join the OCR texts assigned to one cell; return ``(text, bold)``.

        When a cell has several fragments, each one loses a single leading
        space and any ``<b>``/``</b>`` tags. Fragments that end up empty are
        skipped, and a separating space is added between fragments. The cell
        is marked bold when its first fragment contained ``<b>``. A
        single-fragment cell is returned verbatim and never marked bold.
        """
        if not indices:
            return '', False
        multi = len(indices) > 1
        # Index [0] selects the text of the (text, score) record; testing the
        # record itself would only match a text exactly equal to '<b>'.
        bold = multi and '<b>' in ocr_contents[indices[0]][0]
        last = len(indices) - 1
        pieces = []
        for i, idx in enumerate(indices):
            content = ocr_contents[idx][0]
            if multi:
                if not content:
                    continue
                if content[0] == ' ':
                    content = content[1:]
                # Remove the tags wherever they occur; positional slicing
                # corrupted text whose tag was not exactly at either end.
                content = content.replace('<b>', '').replace('</b>', '')
                if not content:
                    continue
                if i != last and content[-1] != ' ':
                    content += ' '
            pieces.append(content)
        return ''.join(pieces), bold

    def get_pred_html(self, pred_structures, matched_index, ocr_contents):
        """Build table HTML by inserting matched OCR text into cell tokens.

        Cells are counted by tokens containing ``'</td>'``. The n-th such
        token receives the text of ``matched_index[n]``. This presumably
        assumes the cell order matches the ``pred_bboxes`` order used by
        ``match_result`` (TODO confirm).

        :return: ``(html, tokens)`` where ``html == ''.join(tokens)``
        """
        end_html = []
        td_index = 0
        for tag in pred_structures:
            if '</td>' not in tag:
                end_html.append(tag)
                continue
            is_empty_cell = tag == '<td></td>'
            if is_empty_cell:
                end_html.append('<td>')
            if td_index in matched_index:
                text, bold = self._merge_cell_text(matched_index[td_index],
                                                   ocr_contents)
                cell = '<b>{}</b>'.format(text) if bold else text
                if cell:
                    end_html.append(cell)
            end_html.append('</td>' if is_empty_cell else tag)
            td_index += 1
        return ''.join(end_html), end_html

    def get_pred_html_master(self, pred_structures, matched_index,
                             ocr_contents):
        """TableMaster variant of ``get_pred_html``.

        Every token, filled or not, is passed through ``deal_eb_token``, and
        the joined HTML through ``deal_bb``. Both helpers come from
        ``ppstructure.table.table_master_match``.

        :return: ``(html, tokens)``
        """
        end_html = []
        td_index = 0
        for token in pred_structures:
            if '</td>' in token:
                txt, bold = '', False
                if td_index in matched_index:
                    txt, bold = self._merge_cell_text(
                        matched_index[td_index], ocr_contents)
                if bold:
                    txt = '<b>{}</b>'.format(txt)
                if token == '<td></td>':
                    token = '<td>{}</td>'.format(txt)
                else:
                    token = '{}</td>'.format(txt)
                td_index += 1
            end_html.append(deal_eb_token(token))
        html = deal_bb(''.join(end_html))
        return html, end_html

    def _filter_ocr_result(self, pred_bboxes, dt_boxes, rec_res):
        """Drop OCR results whose box lies entirely above the table.

        The table top is taken as the smallest odd-indexed coordinate of
        ``pred_bboxes`` (the y values, assuming interleaved x/y layout). An
        OCR box is dropped when its largest odd-indexed coordinate is above
        that line. With no predicted boxes, nothing is filtered.

        :return: ``(dt_boxes, rec_res)`` as lists
        """
        if len(pred_bboxes) == 0:
            return list(dt_boxes), list(rec_res)
        table_top = np.asarray(pred_bboxes)[:, 1::2].min()
        new_dt_boxes = []
        new_rec_res = []
        for box, rec in zip(dt_boxes, rec_res):
            if np.max(box[1::2]) < table_top:
                continue
            new_dt_boxes.append(box)
            new_rec_res.append(rec)
        return new_dt_boxes, new_rec_res
|