12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953 |
- # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- This code is refer from:
- https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/match.py
- """
- import os
- import re
- import cv2
- import glob
- import copy
- import math
- import pickle
- import numpy as np
- from shapely.geometry import Polygon, MultiPoint
- """
- Useful function in matching.
- """
def remove_empty_bboxes(bboxes):
    """
    Drop all-zero placeholder boxes ([0., 0., 0., 0.]) from structure
    master bboxes. len(bboxes.shape) must be 2.
    :param bboxes: iterable of 4-element boxes
    :return: ndarray containing only the non-empty boxes
    """
    kept = [bbox for bbox in bboxes if sum(bbox) != 0.]
    return np.array(kept)
def xywh2xyxy(bboxes):
    """
    Convert boxes from center format (cx, cy, w, h) to corner format
    (x1, y1, x2, y2). Accepts one box (1-D) or a batch (2-D, row per box).
    :param bboxes: ndarray of shape (4,) or (N, 4)
    :return: converted ndarray with the same shape and dtype
    :raises ValueError: for any other rank
    """
    ndim = len(bboxes.shape)
    out = np.empty_like(bboxes)
    if ndim == 1:
        cx, cy, w, h = bboxes[0], bboxes[1], bboxes[2], bboxes[3]
        out[0] = cx - w / 2
        out[1] = cy - h / 2
        out[2] = cx + w / 2
        out[3] = cy + h / 2
    elif ndim == 2:
        out[:, 0] = bboxes[:, 0] - bboxes[:, 2] / 2
        out[:, 1] = bboxes[:, 1] - bboxes[:, 3] / 2
        out[:, 2] = bboxes[:, 0] + bboxes[:, 2] / 2
        out[:, 3] = bboxes[:, 1] + bboxes[:, 3] / 2
    else:
        raise ValueError
    return out
def xyxy2xywh(bboxes):
    """
    Convert boxes from corner format (x1, y1, x2, y2) to center format
    (cx, cy, w, h). Accepts one box (1-D) or a batch (2-D, row per box).
    :param bboxes: ndarray of shape (4,) or (N, 4)
    :return: converted ndarray with the same shape and dtype
    :raises ValueError: for any other rank
    """
    ndim = len(bboxes.shape)
    out = np.empty_like(bboxes)
    if ndim == 1:
        x1, y1, x2, y2 = bboxes[0], bboxes[1], bboxes[2], bboxes[3]
        out[0] = x1 + (x2 - x1) / 2
        out[1] = y1 + (y2 - y1) / 2
        out[2] = x2 - x1
        out[3] = y2 - y1
    elif ndim == 2:
        out[:, 0] = bboxes[:, 0] + (bboxes[:, 2] - bboxes[:, 0]) / 2
        out[:, 1] = bboxes[:, 1] + (bboxes[:, 3] - bboxes[:, 1]) / 2
        out[:, 2] = bboxes[:, 2] - bboxes[:, 0]
        out[:, 3] = bboxes[:, 3] - bboxes[:, 1]
    else:
        raise ValueError
    return out
def pickle_load(path, prefix='end2end'):
    """
    Load pickled result dict(s).
    If `path` is a file, load it directly; if it is a directory, merge every
    '{prefix}_*.pkl' shard inside it into one dict.
    Fix: the original called `pickle.load(open(path, 'rb'))` and never closed
    the file objects; `with` blocks now guarantee the handles are released.
    :param path: pickle file or directory of pickle shards
    :param prefix: shard filename prefix used when `path` is a directory
    :return: loaded (merged) dict
    :raises ValueError: when `path` is neither a file nor a directory
    """
    if os.path.isfile(path):
        with open(path, 'rb') as f:
            data = pickle.load(f)
    elif os.path.isdir(path):
        data = dict()
        search_path = os.path.join(path, '{}_*.pkl'.format(prefix))
        for pkl in glob.glob(search_path):
            with open(pkl, 'rb') as f:
                data.update(pickle.load(f))
    else:
        raise ValueError
    return data
def convert_coord(xyxy):
    """
    Expand a two-point box (x1, y1, x2, y2) into its four corners,
    ordered clockwise from the top-left.
    :param xyxy: sequence of 4 coordinates
    :return: float32 ndarray of shape (4, 2)
    """
    x1, y1, x2, y2 = xyxy[0], xyxy[1], xyxy[2], xyxy[3]
    return np.array(
        [[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32)
def cal_iou(bbox1, bbox2):
    """
    Intersection-over-union of two 4-point boxes using shapely polygons.
    Note: the "union" area is the area of the convex hull of all 8 corner
    points, not the true polygon union — an approximation that is exact for
    overlapping axis-aligned rectangles.
    :param bbox1: (4, 2) array of corner points
    :param bbox2: (4, 2) array of corner points
    :return: iou value in [0, 1]; 0 when the boxes do not intersect
    """
    bbox1_poly = Polygon(bbox1).convex_hull
    bbox2_poly = Polygon(bbox2).convex_hull
    # all 8 corners, used below to build the convex-hull "union"
    union_poly = np.concatenate((bbox1, bbox2))

    if not bbox1_poly.intersects(bbox2_poly):
        iou = 0
    else:
        inter_area = bbox1_poly.intersection(bbox2_poly).area
        union_area = MultiPoint(union_poly).convex_hull.area
        # degenerate (zero-area) hull: avoid division by zero
        if union_area == 0:
            iou = 0
        else:
            iou = float(inter_area) / union_area
    return iou
def cal_distance(p1, p2):
    """
    Euclidean distance between two (x, y) points.
    :param p1: first point
    :param p2: second point
    :return: distance as a float
    """
    dx = p1[0] - p2[0]
    dy = p1[1] - p2[1]
    return math.sqrt(dx**2 + dy**2)
def is_inside(center_point, corner_point):
    """
    Test whether center_point lies inside the axis-aligned box given by
    corner_point (both boundaries inclusive).
    :param center_point: center point (x, y)
    :param corner_point: corner point ((x1, y1), (x2, y2))
    :return: True if the point is inside the box, else False
    """
    (x1, y1), (x2, y2) = corner_point
    cx, cy = center_point[0], center_point[1]
    return (x1 <= cx <= x2) and (y1 <= cy <= y2)
def find_no_match(match_list, all_end2end_nums, type='end2end'):
    """
    Find the bbox indexes that appear in no pair of `match_list`.
    Improvement: membership is now tested against a set (O(1) per lookup)
    instead of scanning a list for every candidate index (O(n) each).
    The parameter name `type` shadows the builtin but is kept because
    callers pass it by keyword.
    :param match_list: matching pairs; m[0] is end2end index, m[1] is master index.
    :param all_end2end_nums: total number of end2end (or master) bboxes
    :param type: 'end2end' selects pair index 0, 'master' selects index 1.
    :return: sorted list of unmatched indexes
    :raises ValueError: for any other `type`
    """
    if type == 'end2end':
        idx = 0
    elif type == 'master':
        idx = 1
    else:
        raise ValueError
    matched_bbox_indexs = {m[idx] for m in match_list}
    return [n for n in range(all_end2end_nums) if n not in matched_bbox_indexs]
def is_abs_lower_than_threshold(this_bbox, target_bbox, threshold=3):
    """
    Row-grouping test: True when the two bboxes' y values (index 1) differ
    by strictly less than `threshold` pixels. Only the y axis is considered.
    :param this_bbox: candidate bbox
    :param target_bbox: reference bbox
    :param threshold: maximum allowed |Δy| (exclusive)
    :return: bool
    """
    return abs(this_bbox[1] - target_bbox[1]) < threshold
def sort_line_bbox(g, bg):
    """
    Sort the bboxes of one line (group) left-to-right by their x value,
    keeping the index list `g` aligned with the bbox list `bg`.
    Fix: the original located each item's position with `list.index` on its
    x value, so two bboxes sharing the same x collapsed into one slot and
    left a `None` hole in the output. Sorting the positions as pairs is
    duplicate-safe (and stable for equal x).
    :param g: indexes in the same group
    :param bg: bboxes in the same group (bbox[0] is x)
    :return: (g_sorted, bg_sorted)
    """
    order = sorted(range(len(bg)), key=lambda k: bg[k][0])
    g_sorted = [g[k] for k in order]
    bg_sorted = [bg[k] for k in order]
    return g_sorted, bg_sorted
def flatten(sorted_groups, sorted_bbox_groups):
    """
    Flatten per-row index/bbox groups into two flat, parallel lists.
    :param sorted_groups: list of index groups (one per row)
    :param sorted_bbox_groups: list of bbox groups (one per row)
    :return: (flat index list, flat bbox list)
    """
    pairs = [
        pair
        for group, bbox_group in zip(sorted_groups, sorted_bbox_groups)
        for pair in zip(group, bbox_group)
    ]
    idxs = [p[0] for p in pairs]
    bboxes = [p[1] for p in pairs]
    return idxs, bboxes
def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes):
    """
    Group the leftover (render) end2end bboxes into rows, sort each row
    left-to-right and the rows top-to-bottom, then flatten.
    Fix: the original ordered rows with `list.index` on the first bbox's y
    value, so two rows whose leading y coords were equal collapsed into one
    slot and left `None` holes; sorting the paired rows by key is
    duplicate-safe.
    :param end2end_xywh_bboxes: leftover bboxes in xywh format
    :param no_match_end2end_indexes: their original end2end indexes
    :return: (flat sorted indexes, flat sorted bboxes, row index groups,
              row bbox groups)
    """
    groups = []
    bbox_groups = []
    for index, this_bbox in zip(no_match_end2end_indexes, end2end_xywh_bboxes):
        for g, bg in zip(groups, bbox_groups):
            # this_bbox belongs to bg's row when the y values are close
            if is_abs_lower_than_threshold(this_bbox, bg[0]):
                g.append(index)
                bg.append(this_bbox)
                break
        else:
            # no existing row accepted this bbox: start a new row
            groups.append([index])
            bbox_groups.append([this_bbox])

    # sort bboxes within each row by x
    tmp_groups, tmp_bbox_groups = [], []
    for g, bg in zip(groups, bbox_groups):
        g_sorted, bg_sorted = sort_line_bbox(g, bg)
        tmp_groups.append(g_sorted)
        tmp_bbox_groups.append(bg_sorted)

    # sort rows top-to-bottom by the first bbox's y value
    paired = sorted(
        zip(tmp_groups, tmp_bbox_groups), key=lambda gb: gb[1][0][1])
    sorted_groups = [g for g, _ in paired]
    sorted_bbox_groups = [bg for _, bg in paired]

    # flatten, get final result
    end2end_sorted_idx_list, end2end_sorted_bbox_list \
        = flatten(sorted_groups, sorted_bbox_groups)

    return end2end_sorted_idx_list, end2end_sorted_bbox_list, sorted_groups, sorted_bbox_groups
def get_bboxes_list(end2end_result, structure_master_result):
    """
    Convert end2end results and structure master results into parallel
    lists of boxes in both xyxy and xywh formats.
    :param end2end_result: list of dicts whose 'bbox' is in xyxy format
    :param structure_master_result: dict whose 'bbox' is in xyxy format
    :return: (end2end xyxy, end2end xywh, master xywh, master xyxy) arrays
    """
    # end2end
    end2end_xyxy_list = [item['bbox'] for item in end2end_result]
    end2end_xywh_list = [xyxy2xywh(bbox) for bbox in end2end_xyxy_list]
    end2end_xyxy_bboxes = np.array(end2end_xyxy_list)
    end2end_xywh_bboxes = np.array(end2end_xywh_list)

    # structure master: drop all-zero placeholder boxes first
    structure_master_xyxy_bboxes = remove_empty_bboxes(
        structure_master_result['bbox'])
    structure_master_xywh_bboxes = xyxy2xywh(structure_master_xyxy_bboxes)

    return end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes
def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes):
    """
    Match by the center rule: an end2end bbox matches a structure master
    bbox when its center point falls inside the master bbox (inclusive).
    :param end2end_xywh_bboxes: end2end boxes, (cx, cy, w, h)
    :param structure_master_xyxy_bboxes: master boxes, (x1, y1, x2, y2)
    :return: match pairs list, e.g. [[0,1], [1,2], ...]
    """
    match_pairs_list = []
    for i, end2end_xywh in enumerate(end2end_xywh_bboxes):
        cx, cy = end2end_xywh[0], end2end_xywh[1]
        for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
            # center of the end2end box must lie inside the master box
            if master_xyxy[0] <= cx <= master_xyxy[2] and \
                    master_xyxy[1] <= cy <= master_xyxy[3]:
                match_pairs_list.append([i, j])
    return match_pairs_list
def iou_rule_match(end2end_xyxy_bboxes, end2end_xyxy_indexes,
                   structure_master_xyxy_bboxes):
    """
    Match by the iou rule: each end2end bbox is paired with the master bbox
    of maximum iou (iou must be > 0 to count).
    Improvement: the end2end box's 4-point form is loop-invariant for the
    inner loop, so it is now computed once per end2end box instead of once
    per (end2end, master) pair.
    :param end2end_xyxy_bboxes: end2end boxes in xyxy format
    :param end2end_xyxy_indexes: original end2end indexes.
    :param structure_master_xyxy_bboxes: master boxes in xyxy format
    :return: match pairs list, e.g. [[0,1], [1,2], ...]
    """
    match_pair_list = []
    for end2end_xyxy_index, end2end_xyxy in zip(end2end_xyxy_indexes,
                                                end2end_xyxy_bboxes):
        end2end_4xy = convert_coord(end2end_xyxy)
        max_iou = 0
        max_match = [None, None]
        for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
            master_4xy = convert_coord(master_xyxy)
            iou = cal_iou(end2end_4xy, master_4xy)
            if iou > max_iou:
                max_match[0], max_match[1] = end2end_xyxy_index, j
                max_iou = iou
        if max_match[0] is None:
            # no master box overlaps this end2end box
            continue
        match_pair_list.append(max_match)
    return match_pair_list
def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes,
                        master_bboxes):
    """
    Match the remaining (no-match) master bboxes to the nearest remaining
    end2end bbox by center-point Euclidean distance. Only runs when both
    leftover sets are non-empty; returns one pair per master bbox.
    :param end2end_indexes: leftover end2end indexes
    :param end2end_bboxes: leftover end2end bboxes (x, y are the center)
    :param master_indexes: leftover master indexes
    :param master_bboxes: leftover master bboxes (x, y are the center)
    :return: match_pairs list, e.g. [[0,1], [1,2], ...]
    """
    min_match_list = []
    for j, master_bbox in zip(master_indexes, master_bboxes):
        best_match = [0, 0]  # [end2end index, master index]
        best_dist = np.inf
        for i, end2end_bbox in zip(end2end_indexes, end2end_bboxes):
            # center-to-center Euclidean distance
            dx = master_bbox[0] - end2end_bbox[0]
            dy = master_bbox[1] - end2end_bbox[1]
            dist = math.sqrt(dx**2 + dy**2)
            if dist < best_dist:
                best_match[0], best_match[1] = i, j
                best_dist = dist
        min_match_list.append(best_match)
    return min_match_list
def extra_match(no_match_end2end_indexes, master_bbox_nums):
    """
    Pair each still-unmatched end2end index with a virtual master index
    appended after the real master bboxes (master_bbox_nums, +1, ...).
    :param no_match_end2end_indexes: end2end indexes left unmatched
    :param master_bbox_nums: number of real master bboxes
    :return: list of [end2end_index, virtual_master_index] pairs
    """
    return [[end2end_index, master_bbox_nums + offset]
            for offset, end2end_index in enumerate(no_match_end2end_indexes)]
def get_match_dict(match_list):
    """
    Group match pairs by master index.
    :param match_list: list of [end2end_index, master_index] pairs
    :return: dict {master_index: [end2end_index, ...]}
    """
    match_dict = dict()
    for end2end_index, master_index in match_list:
        match_dict.setdefault(master_index, []).append(end2end_index)
    return match_dict
def deal_successive_space(text):
    """
    Normalize spaces in recognized text:
    1. mark runs of three spaces (real spaces) with '<space>'
    2. strip the remaining single spaces (split tokens, not real spaces)
    3. turn each '<space>' marker back into one real space
    :param text: raw recognized text
    :return: cleaned text
    """
    marked = text.replace('   ', '<space>')
    compact = marked.replace(' ', '')
    return compact.replace('<space>', ' ')
def reduce_repeat_bb(text_list, break_token):
    """
    Collapse per-word bold tags into one bold span, e.g.
    ['<b>Local</b>', '<b>government</b>'] -> ['<b>Local government</b>'].
    Only applies when every item starts with '<b>'; otherwise the list is
    returned unchanged.
    :param text_list: recognized text fragments for one cell
    :param break_token: separator used to join the fragments
    :return: possibly collapsed text list
    """
    all_bold = all(text.startswith('<b>') for text in text_list)
    if not all_bold:
        return text_list
    stripped = [
        text.replace('<b>', '').replace('</b>', '') for text in text_list
    ]
    return ['<b>' + break_token.join(stripped) + '</b>']
def get_match_text_dict(match_dict, end2end_info, break_token=' '):
    """
    Build the text for each master cell by joining the texts of its matched
    end2end boxes (bold tags collapsed first).
    :param match_dict: {master_index: [end2end_index, ...]}
    :param end2end_info: per-index dicts carrying a 'text' field
    :param break_token: separator between fragments of one cell
    :return: {master_index: joined cell text}
    """
    match_text_dict = dict()
    for master_index, end2end_index_list in match_dict.items():
        fragments = [
            end2end_info[k]['text'] for k in end2end_index_list
        ]
        fragments = reduce_repeat_bb(fragments, break_token)
        match_text_dict[master_index] = break_token.join(fragments)
    return match_text_dict
- def merge_span_token(master_token_list):
- """
- Merge the span style token (row span or col span).
- :param master_token_list:
- :return:
- """
- new_master_token_list = []
- pointer = 0
- if master_token_list[-1] != '</tbody>':
- master_token_list.append('</tbody>')
- while master_token_list[pointer] != '</tbody>':
- try:
- if master_token_list[pointer] == '<td':
- if master_token_list[pointer + 1].startswith(
- ' colspan=') or master_token_list[
- pointer + 1].startswith(' rowspan='):
- """
- example:
- pattern <td colspan="3">
- '<td' + 'colspan=" "' + '>' + '</td>'
- """
- tmp = ''.join(master_token_list[pointer:pointer + 3 + 1])
- pointer += 4
- new_master_token_list.append(tmp)
- elif master_token_list[pointer + 2].startswith(
- ' colspan=') or master_token_list[
- pointer + 2].startswith(' rowspan='):
- """
- example:
- pattern <td rowspan="2" colspan="3">
- '<td' + 'rowspan=" "' + 'colspan=" "' + '>' + '</td>'
- """
- tmp = ''.join(master_token_list[pointer:pointer + 4 + 1])
- pointer += 5
- new_master_token_list.append(tmp)
- else:
- new_master_token_list.append(master_token_list[pointer])
- pointer += 1
- else:
- new_master_token_list.append(master_token_list[pointer])
- pointer += 1
- except:
- print("Break in merge...")
- break
- new_master_token_list.append('</tbody>')
- return new_master_token_list
def deal_eb_token(master_token):
    """
    Post-process empty-bbox tokens: expand each <eb*></eb*> shorthand back
    into the '<td>...</td>' content it stands for (empty cell, space cell,
    bold/italic/sup wrappers, line-separator cells, ...).
    :param master_token: structure token string
    :return: token string with all <eb*> shorthands expanded
    """
    replacements = (
        ('<eb></eb>', '<td></td>'),
        ('<eb1></eb1>', '<td> </td>'),
        ('<eb2></eb2>', '<td><b> </b></td>'),
        ('<eb3></eb3>', '<td>\u2028\u2028</td>'),
        ('<eb4></eb4>', '<td><sup> </sup></td>'),
        ('<eb5></eb5>', '<td><b></b></td>'),
        ('<eb6></eb6>', '<td><i> </i></td>'),
        ('<eb7></eb7>', '<td><b><i></i></b></td>'),
        ('<eb8></eb8>', '<td><b><i> </i></b></td>'),
        ('<eb9></eb9>', '<td><i></i></td>'),
        ('<eb10></eb10>', '<td><b> \u2028 \u2028 </b></td>'),
    )
    for src, dst in replacements:
        master_token = master_token.replace(src, dst)
    return master_token
def insert_text_to_token(master_token_list, match_text_dict):
    """
    Insert OCR text into the structure tokens: the k-th '<td' token
    receives match_text_dict[k] (cells with no text are dropped), then
    <eb*> shorthands are expanded and everything is joined into one string.
    :param master_token_list: structure tokens
    :param match_text_dict: {cell index: text}
    :return: merged HTML-like string
    """
    master_token_list = merge_span_token(master_token_list)
    merged_result_list = []
    text_count = 0
    for master_token in master_token_list:
        if master_token.startswith('<td'):
            out_of_range = text_count > len(match_text_dict) - 1
            if out_of_range or text_count not in match_text_dict:
                # no text matched this cell: skip the token entirely
                text_count += 1
                continue
            master_token = master_token.replace(
                '><', '>{}<'.format(match_text_dict[text_count]))
            text_count += 1
        master_token = deal_eb_token(master_token)
        merged_result_list.append(master_token)

    return ''.join(merged_result_list)
def deal_isolate_span(thead_part):
    """
    Repair isolated span fragments caused by wrong structure predictions,
    eg. '<td></td> rowspan="2"></b></td>' -> '<td rowspan="2"></td>'.
    Improvement: the original called `.group()` on the result of
    `re.search` unconditionally, which would raise AttributeError if no
    span attribute were found; that case is now guarded and the token is
    left unchanged.
    :param thead_part: the '<thead>...</thead>' substring
    :return: repaired thead string
    """
    # 1. find out isolate span tokens.
    isolate_pattern = "<td></td> rowspan=\"(\d)+\" colspan=\"(\d)+\"></b></td>|" \
                      "<td></td> colspan=\"(\d)+\" rowspan=\"(\d)+\"></b></td>|" \
                      "<td></td> rowspan=\"(\d)+\"></b></td>|" \
                      "<td></td> colspan=\"(\d)+\"></b></td>"
    isolate_list = [m.group() for m in re.finditer(isolate_pattern, thead_part)]

    # 2. find out span number, by step 1 results.
    span_pattern = " rowspan=\"(\d)+\" colspan=\"(\d)+\"|" \
                   " colspan=\"(\d)+\" rowspan=\"(\d)+\"|" \
                   " rowspan=\"(\d)+\"|" \
                   " colspan=\"(\d)+\""
    corrected_list = []
    for isolate_item in isolate_list:
        span_part = re.search(span_pattern, isolate_item)
        if span_part is None:
            # cannot recover the span attributes; leave this token as-is
            corrected_list.append(None)
            continue
        # 3. merge the span number into the span token format string.
        corrected_list.append('<td{}></td>'.format(span_part.group()))

    # 4. replace original isolated tokens.
    for corrected_item, isolate_item in zip(corrected_list, isolate_list):
        if corrected_item is not None:
            thead_part = thead_part.replace(isolate_item, corrected_item)
    return thead_part
def deal_duplicate_bb(thead_part):
    """
    Collapse multiple <b></b> wrappers inside one <td></td> cell of the
    thead into a single wrapper around the whole cell content.
    :param thead_part: the '<thead>...</thead>' substring
    :return: thead string with at most one <b></b> per cell
    """
    # 1. find every <td ...>...</td> cell (span variants included).
    td_pattern = "<td rowspan=\"(\d)+\" colspan=\"(\d)+\">(.+?)</td>|" \
                 "<td colspan=\"(\d)+\" rowspan=\"(\d)+\">(.+?)</td>|" \
                 "<td rowspan=\"(\d)+\">(.+?)</td>|" \
                 "<td colspan=\"(\d)+\">(.+?)</td>|" \
                 "<td>(.*?)</td>"
    td_list = [m.group() for m in re.finditer(td_pattern, thead_part)]

    # 2. rewrite cells that contain more than one <b> or </b>.
    new_td_list = []
    for td_item in td_list:
        if td_item.count('<b>') > 1 or td_item.count('</b>') > 1:
            # strip every bold tag, then wrap the whole cell once
            plain = td_item.replace('<b>', '').replace('</b>', '')
            plain = plain.replace('<td>', '<td><b>').replace('</td>',
                                                             '</b></td>')
            new_td_list.append(plain)
        else:
            new_td_list.append(td_item)

    # 3. splice the rewritten cells back into the thead.
    for td_item, new_td_item in zip(td_list, new_td_list):
        thead_part = thead_part.replace(td_item, new_td_item)
    return thead_part
def deal_bb(result_token):
    """
    In our opinion, <b></b> always occurs in <thead></thead> text's context.
    This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
    The replacement order below is significant; the steps are cleanup passes
    over each other's output.
    :param result_token: full structure+text token string
    :return: token string with the thead re-bolded
    """
    # find out <thead></thead> parts.
    thead_pattern = '<thead>(.*?)</thead>'
    if re.search(thead_pattern, result_token) is None:
        return result_token
    thead_part = re.search(thead_pattern, result_token).group()
    origin_thead_part = copy.deepcopy(thead_part)

    # check "rowspan" or "colspan" occur in <thead></thead> parts or not .
    span_pattern = "<td rowspan=\"(\d)+\" colspan=\"(\d)+\">|<td colspan=\"(\d)+\" rowspan=\"(\d)+\">|<td rowspan=\"(\d)+\">|<td colspan=\"(\d)+\">"
    span_iter = re.finditer(span_pattern, thead_part)
    span_list = [s.group() for s in span_iter]
    has_span_in_head = True if len(span_list) > 0 else False

    if not has_span_in_head:
        # <thead></thead> not include "rowspan" or "colspan" branch 1.
        # 1. replace <td> to <td><b>, and </td> to </b></td>
        # 2. it is possible to predict text include <b> or </b> by Text-line recognition,
        #    so we replace <b><b> to <b>, and </b></b> to </b>
        thead_part = thead_part.replace('<td>', '<td><b>')\
            .replace('</td>', '</b></td>')\
            .replace('<b><b>', '<b>')\
            .replace('</b></b>', '</b>')
    else:
        # <thead></thead> include "rowspan" or "colspan" branch 2.
        # Firstly, we deal rowspan or colspan cases.
        # 1. replace > to ><b>
        # 2. replace </td> to </b></td>
        # 3. it is possible to predict text include <b> or </b> by Text-line recognition,
        #    so we replace <b><b> to <b>, and </b><b> to </b>

        # Secondly, deal ordinary cases like branch 1

        # replace ">" to "<b>"
        replaced_span_list = []
        for sp in span_list:
            replaced_span_list.append(sp.replace('>', '><b>'))
        for sp, rsp in zip(span_list, replaced_span_list):
            thead_part = thead_part.replace(sp, rsp)

        # replace "</td>" to "</b></td>"
        thead_part = thead_part.replace('</td>', '</b></td>')

        # remove duplicated <b> by re.sub
        mb_pattern = "(<b>)+"
        single_b_string = "<b>"
        thead_part = re.sub(mb_pattern, single_b_string, thead_part)
        mgb_pattern = "(</b>)+"
        single_gb_string = "</b>"
        thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)

        # ordinary cases like branch 1
        thead_part = thead_part.replace('<td>', '<td><b>').replace('<b><b>',
                                                                   '<b>')

    # convert <tb><b></b></tb> back to <tb></tb>, empty cell has no <b></b>.
    # but space cell(<tb> </tb>) is suitable for <td><b> </b></td>
    thead_part = thead_part.replace('<td><b></b></td>', '<td></td>')
    # deal with duplicated <b></b>
    thead_part = deal_duplicate_bb(thead_part)
    # deal with isolate span tokens, which causes by wrong predict by structure prediction.
    # eg.PMC5994107_011_00.png
    thead_part = deal_isolate_span(thead_part)
    # replace original result with new thead part.
    result_token = result_token.replace(origin_thead_part, thead_part)
    return result_token
class Matcher:
    # Matches end2end OCR text boxes to structure-master cell boxes and
    # merges the text into the predicted table structure.

    def __init__(self, end2end_file, structure_master_file):
        """
        This class process the end2end results and structure recognition results.
        :param end2end_file: end2end results predict by end2end inference.
        :param structure_master_file: structure recognition results predict by structure master inference.
        """
        self.end2end_file = end2end_file
        self.structure_master_file = structure_master_file
        # {file_name: result} dicts loaded from pickle file(s)
        self.end2end_results = pickle_load(end2end_file, prefix='end2end')
        self.structure_master_results = pickle_load(
            structure_master_file, prefix='structure')

    def match(self):
        """
        Match process:
        pre-process : convert end2end and structure master results to xyxy, xywh ndnarray format.
        1. Use pseBbox is inside masterBbox judge rule
        2. Use iou between pseBbox and masterBbox rule
        3. Use min distance of center point rule
        :return: {file_name: match_result_dict}
        """
        match_results = dict()
        for idx, (file_name,
                  end2end_result) in enumerate(self.end2end_results.items()):
            match_list = []
            # skip images that have no structure prediction
            if file_name not in self.structure_master_results:
                continue
            structure_master_result = self.structure_master_results[file_name]
            end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes = \
                get_bboxes_list(end2end_result, structure_master_result)

            # rule 1: center rule
            center_rule_match_list = \
                center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes)
            match_list.extend(center_rule_match_list)

            # rule 2: iou rule
            # firstly, find not match index in previous step.
            center_no_match_end2end_indexs = \
                find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
            if len(center_no_match_end2end_indexs) > 0:
                center_no_match_end2end_xyxy = end2end_xyxy_bboxes[
                    center_no_match_end2end_indexs]
                # secondly, iou rule match
                iou_rule_match_list = \
                    iou_rule_match(center_no_match_end2end_xyxy, center_no_match_end2end_indexs, structure_master_xyxy_bboxes)
                match_list.extend(iou_rule_match_list)

            # rule 3: distance rule
            # match between no-match end2end bboxes and no-match master bboxes.
            # it will return master_bboxes_nums match-pairs.
            # firstly, find not match index in previous step.
            centerIou_no_match_end2end_indexs = \
                find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
            centerIou_no_match_master_indexs = \
                find_no_match(match_list, len(structure_master_xywh_bboxes), type='master')
            if len(centerIou_no_match_master_indexs) > 0 and len(
                    centerIou_no_match_end2end_indexs) > 0:
                centerIou_no_match_end2end_xywh = end2end_xywh_bboxes[
                    centerIou_no_match_end2end_indexs]
                centerIou_no_match_master_xywh = structure_master_xywh_bboxes[
                    centerIou_no_match_master_indexs]
                distance_match_list = distance_rule_match(
                    centerIou_no_match_end2end_indexs,
                    centerIou_no_match_end2end_xywh,
                    centerIou_no_match_master_indexs,
                    centerIou_no_match_master_xywh)
                match_list.extend(distance_match_list)

            # TODO:
            # The render no-match pseBbox, insert the last
            # After step3 distance rule, a master bbox at least match one end2end bbox.
            # But end2end bbox maybe overmuch, because numbers of master bbox will cut by max length.
            # For these render end2end bboxes, we will make some virtual master bboxes, and get matching.
            # The above extra insert bboxes will be further processed in "formatOutput" function.
            # After this operation, it will increase TEDS score.
            no_match_end2end_indexes = \
                find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
            if len(no_match_end2end_indexes) > 0:
                no_match_end2end_xywh = end2end_xywh_bboxes[
                    no_match_end2end_indexes]
                # sort the render no-match end2end bbox in row
                end2end_sorted_indexes_list, end2end_sorted_bboxes_list, sorted_groups, sorted_bboxes_groups = \
                    sort_bbox(no_match_end2end_xywh, no_match_end2end_indexes)
                # make virtual master bboxes, and get matching with the no-match end2end bboxes.
                extra_match_list = extra_match(
                    end2end_sorted_indexes_list,
                    len(structure_master_xywh_bboxes))
                match_list_add_extra_match = copy.deepcopy(match_list)
                match_list_add_extra_match.extend(extra_match_list)
            else:
                # no no-match end2end bboxes
                match_list_add_extra_match = copy.deepcopy(match_list)
                sorted_groups = []
                sorted_bboxes_groups = []

            match_result_dict = {
                'match_list': match_list,
                'match_list_add_extra_match': match_list_add_extra_match,
                'sorted_groups': sorted_groups,
                'sorted_bboxes_groups': sorted_bboxes_groups
            }

            # format output
            match_result_dict = self._format(match_result_dict, file_name)

            match_results[file_name] = match_result_dict

        return match_results

    def _format(self, match_result, file_name):
        """
        Extend the master token(insert virtual master token), and format matching result.
        :param match_result: dict produced in match()
        :param file_name: key into the cached end2end/structure results
        :return: match_result with 'matched_master_token_list' added
        """
        end2end_info = self.end2end_results[file_name]
        master_info = self.structure_master_results[file_name]
        master_token = master_info['text']
        sorted_groups = match_result['sorted_groups']

        # creat virtual master token: one <tr> of empty <td></td> per row group
        virtual_master_token_list = []
        for line_group in sorted_groups:
            tmp_list = ['<tr>']
            item_nums = len(line_group)
            for _ in range(item_nums):
                tmp_list.append('<td></td>')
            tmp_list.append('</tr>')
            virtual_master_token_list.extend(tmp_list)

        # insert virtual master token
        master_token_list = master_token.split(',')
        if master_token_list[-1] == '</tbody>':
            # complete predict(no cut by max length)
            # This situation insert virtual master token will drop TEDs score in val set.
            # So we will not extend virtual token in this situation.

            # fake extend virtual
            # NOTE: `[:-1]` creates a copy, so this extend is a deliberate
            # no-op — the real insertion (below, commented out) is disabled.
            master_token_list[:-1].extend(virtual_master_token_list)

            # real extend virtual
            # master_token_list = master_token_list[:-1]
            # master_token_list.extend(virtual_master_token_list)
            # master_token_list.append('</tbody>')
        elif master_token_list[-1] == '<td></td>':
            master_token_list.append('</tr>')
            master_token_list.extend(virtual_master_token_list)
            master_token_list.append('</tbody>')
        else:
            master_token_list.extend(virtual_master_token_list)
            master_token_list.append('</tbody>')

        # format output
        match_result.setdefault('matched_master_token_list', master_token_list)
        return match_result

    def get_merge_result(self, match_results):
        """
        Merge the OCR result into structure token to get final results.
        :param match_results: output of match()
        :return: {file_name: merged HTML-like string}
        """
        merged_results = dict()

        # break_token is linefeed token, when one master bbox has multiply end2end bboxes.
        break_token = ' '

        for idx, (file_name, match_info) in enumerate(match_results.items()):
            end2end_info = self.end2end_results[file_name]
            master_token_list = match_info['matched_master_token_list']
            match_list = match_info['match_list_add_extra_match']

            match_dict = get_match_dict(match_list)
            match_text_dict = get_match_text_dict(match_dict, end2end_info,
                                                  break_token)
            merged_result = insert_text_to_token(master_token_list,
                                                 match_text_dict)
            merged_result = deal_bb(merged_result)

            merged_results[file_name] = merged_result
        return merged_results
class TableMasterMatcher(Matcher):
    # In-memory variant of Matcher: results are supplied per image through
    # __call__ instead of being loaded from pickle files.

    def __init__(self):
        # Intentionally does NOT call super().__init__(): the parent would
        # load results from pickle files, while this subclass sets
        # self.end2end_results / self.structure_master_results in __call__.
        pass

    def __call__(self, structure_res, dt_boxes, rec_res, img_name=1):
        """
        Run the full match + merge pipeline for one image.
        :param structure_res: (structure token list, structure bboxes)
        :param dt_boxes: detected text boxes
        :param rec_res: recognition results; res[0] is the text
        :param img_name: key used for the single-image result dicts
        :return: full '<html>...</html>' table string
        """
        end2end_results = {img_name: []}
        for dt_box, res in zip(dt_boxes, rec_res):
            d = dict(
                bbox=np.array(dt_box),
                text=res[0], )
            end2end_results[img_name].append(d)

        self.end2end_results = end2end_results

        structure_master_result_dict = {img_name: {}}
        pred_structures, pred_bboxes = structure_res
        # drop the leading/trailing sentinel tokens and join as the master
        # token string expected by Matcher._format (comma-separated)
        pred_structures = ','.join(pred_structures[3:-3])
        structure_master_result_dict[img_name]['text'] = pred_structures
        structure_master_result_dict[img_name]['bbox'] = pred_bboxes
        self.structure_master_results = structure_master_result_dict

        # match
        match_results = self.match()
        merged_results = self.get_merge_result(match_results)
        pred_html = merged_results[img_name]
        pred_html = '<html><body><table>' + pred_html + '</table></body></html>'
        return pred_html
|