table_master_match.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/match.py
  17. """
  18. import os
  19. import re
  20. import cv2
  21. import glob
  22. import copy
  23. import math
  24. import pickle
  25. import numpy as np
  26. from shapely.geometry import Polygon, MultiPoint
  27. """
  28. Useful functions for matching.
  29. """
  30. def remove_empty_bboxes(bboxes):
  31. """
  32. remove [0., 0., 0., 0.] in structure master bboxes.
  33. len(bboxes.shape) must be 2.
  34. :param bboxes:
  35. :return:
  36. """
  37. new_bboxes = []
  38. for bbox in bboxes:
  39. if sum(bbox) == 0.:
  40. continue
  41. new_bboxes.append(bbox)
  42. return np.array(new_bboxes)
  43. def xywh2xyxy(bboxes):
  44. if len(bboxes.shape) == 1:
  45. new_bboxes = np.empty_like(bboxes)
  46. new_bboxes[0] = bboxes[0] - bboxes[2] / 2
  47. new_bboxes[1] = bboxes[1] - bboxes[3] / 2
  48. new_bboxes[2] = bboxes[0] + bboxes[2] / 2
  49. new_bboxes[3] = bboxes[1] + bboxes[3] / 2
  50. return new_bboxes
  51. elif len(bboxes.shape) == 2:
  52. new_bboxes = np.empty_like(bboxes)
  53. new_bboxes[:, 0] = bboxes[:, 0] - bboxes[:, 2] / 2
  54. new_bboxes[:, 1] = bboxes[:, 1] - bboxes[:, 3] / 2
  55. new_bboxes[:, 2] = bboxes[:, 0] + bboxes[:, 2] / 2
  56. new_bboxes[:, 3] = bboxes[:, 1] + bboxes[:, 3] / 2
  57. return new_bboxes
  58. else:
  59. raise ValueError
  60. def xyxy2xywh(bboxes):
  61. if len(bboxes.shape) == 1:
  62. new_bboxes = np.empty_like(bboxes)
  63. new_bboxes[0] = bboxes[0] + (bboxes[2] - bboxes[0]) / 2
  64. new_bboxes[1] = bboxes[1] + (bboxes[3] - bboxes[1]) / 2
  65. new_bboxes[2] = bboxes[2] - bboxes[0]
  66. new_bboxes[3] = bboxes[3] - bboxes[1]
  67. return new_bboxes
  68. elif len(bboxes.shape) == 2:
  69. new_bboxes = np.empty_like(bboxes)
  70. new_bboxes[:, 0] = bboxes[:, 0] + (bboxes[:, 2] - bboxes[:, 0]) / 2
  71. new_bboxes[:, 1] = bboxes[:, 1] + (bboxes[:, 3] - bboxes[:, 1]) / 2
  72. new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0]
  73. new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1]
  74. return new_bboxes
  75. else:
  76. raise ValueError
  77. def pickle_load(path, prefix='end2end'):
  78. if os.path.isfile(path):
  79. data = pickle.load(open(path, 'rb'))
  80. elif os.path.isdir(path):
  81. data = dict()
  82. search_path = os.path.join(path, '{}_*.pkl'.format(prefix))
  83. pkls = glob.glob(search_path)
  84. for pkl in pkls:
  85. this_data = pickle.load(open(pkl, 'rb'))
  86. data.update(this_data)
  87. else:
  88. raise ValueError
  89. return data
  90. def convert_coord(xyxy):
  91. """
  92. Convert two points format to four points format.
  93. :param xyxy:
  94. :return:
  95. """
  96. new_bbox = np.zeros([4, 2], dtype=np.float32)
  97. new_bbox[0, 0], new_bbox[0, 1] = xyxy[0], xyxy[1]
  98. new_bbox[1, 0], new_bbox[1, 1] = xyxy[2], xyxy[1]
  99. new_bbox[2, 0], new_bbox[2, 1] = xyxy[2], xyxy[3]
  100. new_bbox[3, 0], new_bbox[3, 1] = xyxy[0], xyxy[3]
  101. return new_bbox
  102. def cal_iou(bbox1, bbox2):
  103. bbox1_poly = Polygon(bbox1).convex_hull
  104. bbox2_poly = Polygon(bbox2).convex_hull
  105. union_poly = np.concatenate((bbox1, bbox2))
  106. if not bbox1_poly.intersects(bbox2_poly):
  107. iou = 0
  108. else:
  109. inter_area = bbox1_poly.intersection(bbox2_poly).area
  110. union_area = MultiPoint(union_poly).convex_hull.area
  111. if union_area == 0:
  112. iou = 0
  113. else:
  114. iou = float(inter_area) / union_area
  115. return iou
  116. def cal_distance(p1, p2):
  117. delta_x = p1[0] - p2[0]
  118. delta_y = p1[1] - p2[1]
  119. d = math.sqrt((delta_x**2) + (delta_y**2))
  120. return d
  121. def is_inside(center_point, corner_point):
  122. """
  123. Find if center_point inside the bbox(corner_point) or not.
  124. :param center_point: center point (x, y)
  125. :param corner_point: corner point ((x1,y1),(x2,y2))
  126. :return:
  127. """
  128. x_flag = False
  129. y_flag = False
  130. if (center_point[0] >= corner_point[0][0]) and (
  131. center_point[0] <= corner_point[1][0]):
  132. x_flag = True
  133. if (center_point[1] >= corner_point[0][1]) and (
  134. center_point[1] <= corner_point[1][1]):
  135. y_flag = True
  136. if x_flag and y_flag:
  137. return True
  138. else:
  139. return False
  140. def find_no_match(match_list, all_end2end_nums, type='end2end'):
  141. """
  142. Find out no match end2end bbox in previous match list.
  143. :param match_list: matching pairs.
  144. :param all_end2end_nums: numbers of end2end_xywh
  145. :param type: 'end2end' corresponding to idx 0, 'master' corresponding to idx 1.
  146. :return: no match pse bbox index list
  147. """
  148. if type == 'end2end':
  149. idx = 0
  150. elif type == 'master':
  151. idx = 1
  152. else:
  153. raise ValueError
  154. no_match_indexs = []
  155. # m[0] is end2end index m[1] is master index
  156. matched_bbox_indexs = [m[idx] for m in match_list]
  157. for n in range(all_end2end_nums):
  158. if n not in matched_bbox_indexs:
  159. no_match_indexs.append(n)
  160. return no_match_indexs
  161. def is_abs_lower_than_threshold(this_bbox, target_bbox, threshold=3):
  162. # only consider y axis, for grouping in row.
  163. delta = abs(this_bbox[1] - target_bbox[1])
  164. if delta < threshold:
  165. return True
  166. else:
  167. return False
  168. def sort_line_bbox(g, bg):
  169. """
  170. Sorted the bbox in the same line(group)
  171. compare coord 'x' value, where 'y' value is closed in the same group.
  172. :param g: index in the same group
  173. :param bg: bbox in the same group
  174. :return:
  175. """
  176. xs = [bg_item[0] for bg_item in bg]
  177. xs_sorted = sorted(xs)
  178. g_sorted = [None] * len(xs_sorted)
  179. bg_sorted = [None] * len(xs_sorted)
  180. for g_item, bg_item in zip(g, bg):
  181. idx = xs_sorted.index(bg_item[0])
  182. bg_sorted[idx] = bg_item
  183. g_sorted[idx] = g_item
  184. return g_sorted, bg_sorted
  185. def flatten(sorted_groups, sorted_bbox_groups):
  186. idxs = []
  187. bboxes = []
  188. for group, bbox_group in zip(sorted_groups, sorted_bbox_groups):
  189. for g, bg in zip(group, bbox_group):
  190. idxs.append(g)
  191. bboxes.append(bg)
  192. return idxs, bboxes
  193. def sort_bbox(end2end_xywh_bboxes, no_match_end2end_indexes):
  194. """
  195. This function will group the render end2end bboxes in row.
  196. :param end2end_xywh_bboxes:
  197. :param no_match_end2end_indexes:
  198. :return:
  199. """
  200. groups = []
  201. bbox_groups = []
  202. for index, end2end_xywh_bbox in zip(no_match_end2end_indexes,
  203. end2end_xywh_bboxes):
  204. this_bbox = end2end_xywh_bbox
  205. if len(groups) == 0:
  206. groups.append([index])
  207. bbox_groups.append([this_bbox])
  208. else:
  209. flag = False
  210. for g, bg in zip(groups, bbox_groups):
  211. # this_bbox is belong to bg's row or not
  212. if is_abs_lower_than_threshold(this_bbox, bg[0]):
  213. g.append(index)
  214. bg.append(this_bbox)
  215. flag = True
  216. break
  217. if not flag:
  218. # this_bbox is not belong to bg's row, create a row.
  219. groups.append([index])
  220. bbox_groups.append([this_bbox])
  221. # sorted bboxes in a group
  222. tmp_groups, tmp_bbox_groups = [], []
  223. for g, bg in zip(groups, bbox_groups):
  224. g_sorted, bg_sorted = sort_line_bbox(g, bg)
  225. tmp_groups.append(g_sorted)
  226. tmp_bbox_groups.append(bg_sorted)
  227. # sorted groups, sort by coord y's value.
  228. sorted_groups = [None] * len(tmp_groups)
  229. sorted_bbox_groups = [None] * len(tmp_bbox_groups)
  230. ys = [bg[0][1] for bg in tmp_bbox_groups]
  231. sorted_ys = sorted(ys)
  232. for g, bg in zip(tmp_groups, tmp_bbox_groups):
  233. idx = sorted_ys.index(bg[0][1])
  234. sorted_groups[idx] = g
  235. sorted_bbox_groups[idx] = bg
  236. # flatten, get final result
  237. end2end_sorted_idx_list, end2end_sorted_bbox_list \
  238. = flatten(sorted_groups, sorted_bbox_groups)
  239. return end2end_sorted_idx_list, end2end_sorted_bbox_list, sorted_groups, sorted_bbox_groups
  240. def get_bboxes_list(end2end_result, structure_master_result):
  241. """
  242. This function is use to convert end2end results and structure master results to
  243. List of xyxy bbox format and List of xywh bbox format
  244. :param end2end_result: bbox's format is xyxy
  245. :param structure_master_result: bbox's format is xywh
  246. :return: 4 kind list of bbox ()
  247. """
  248. # end2end
  249. end2end_xyxy_list = []
  250. end2end_xywh_list = []
  251. for end2end_item in end2end_result:
  252. src_bbox = end2end_item['bbox']
  253. end2end_xyxy_list.append(src_bbox)
  254. xywh_bbox = xyxy2xywh(src_bbox)
  255. end2end_xywh_list.append(xywh_bbox)
  256. end2end_xyxy_bboxes = np.array(end2end_xyxy_list)
  257. end2end_xywh_bboxes = np.array(end2end_xywh_list)
  258. # structure master
  259. src_bboxes = structure_master_result['bbox']
  260. src_bboxes = remove_empty_bboxes(src_bboxes)
  261. structure_master_xyxy_bboxes = src_bboxes
  262. xywh_bbox = xyxy2xywh(src_bboxes)
  263. structure_master_xywh_bboxes = xywh_bbox
  264. return end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes
  265. def center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes):
  266. """
  267. Judge end2end Bbox's center point is inside structure master Bbox or not,
  268. if end2end Bbox's center is in structure master Bbox, get matching pair.
  269. :param end2end_xywh_bboxes:
  270. :param structure_master_xyxy_bboxes:
  271. :return: match pairs list, e.g. [[0,1], [1,2], ...]
  272. """
  273. match_pairs_list = []
  274. for i, end2end_xywh in enumerate(end2end_xywh_bboxes):
  275. for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
  276. x_end2end, y_end2end = end2end_xywh[0], end2end_xywh[1]
  277. x_master1, y_master1, x_master2, y_master2 \
  278. = master_xyxy[0], master_xyxy[1], master_xyxy[2], master_xyxy[3]
  279. center_point_end2end = (x_end2end, y_end2end)
  280. corner_point_master = ((x_master1, y_master1),
  281. (x_master2, y_master2))
  282. if is_inside(center_point_end2end, corner_point_master):
  283. match_pairs_list.append([i, j])
  284. return match_pairs_list
  285. def iou_rule_match(end2end_xyxy_bboxes, end2end_xyxy_indexes,
  286. structure_master_xyxy_bboxes):
  287. """
  288. Use iou to find matching list.
  289. choose max iou value bbox as match pair.
  290. :param end2end_xyxy_bboxes:
  291. :param end2end_xyxy_indexes: original end2end indexes.
  292. :param structure_master_xyxy_bboxes:
  293. :return: match pairs list, e.g. [[0,1], [1,2], ...]
  294. """
  295. match_pair_list = []
  296. for end2end_xyxy_index, end2end_xyxy in zip(end2end_xyxy_indexes,
  297. end2end_xyxy_bboxes):
  298. max_iou = 0
  299. max_match = [None, None]
  300. for j, master_xyxy in enumerate(structure_master_xyxy_bboxes):
  301. end2end_4xy = convert_coord(end2end_xyxy)
  302. master_4xy = convert_coord(master_xyxy)
  303. iou = cal_iou(end2end_4xy, master_4xy)
  304. if iou > max_iou:
  305. max_match[0], max_match[1] = end2end_xyxy_index, j
  306. max_iou = iou
  307. if max_match[0] is None:
  308. # no match
  309. continue
  310. match_pair_list.append(max_match)
  311. return match_pair_list
  312. def distance_rule_match(end2end_indexes, end2end_bboxes, master_indexes,
  313. master_bboxes):
  314. """
  315. Get matching between no-match end2end bboxes and no-match master bboxes.
  316. Use min distance to match.
  317. This rule will only run (no-match end2end nums > 0) and (no-match master nums > 0)
  318. It will Return master_bboxes_nums match-pairs.
  319. :param end2end_indexes:
  320. :param end2end_bboxes:
  321. :param master_indexes:
  322. :param master_bboxes:
  323. :return: match_pairs list, e.g. [[0,1], [1,2], ...]
  324. """
  325. min_match_list = []
  326. for j, master_bbox in zip(master_indexes, master_bboxes):
  327. min_distance = np.inf
  328. min_match = [0, 0] # i, j
  329. for i, end2end_bbox in zip(end2end_indexes, end2end_bboxes):
  330. x_end2end, y_end2end = end2end_bbox[0], end2end_bbox[1]
  331. x_master, y_master = master_bbox[0], master_bbox[1]
  332. end2end_point = (x_end2end, y_end2end)
  333. master_point = (x_master, y_master)
  334. dist = cal_distance(master_point, end2end_point)
  335. if dist < min_distance:
  336. min_match[0], min_match[1] = i, j
  337. min_distance = dist
  338. min_match_list.append(min_match)
  339. return min_match_list
  340. def extra_match(no_match_end2end_indexes, master_bbox_nums):
  341. """
  342. This function will create some virtual master bboxes,
  343. and get match with the no match end2end indexes.
  344. :param no_match_end2end_indexes:
  345. :param master_bbox_nums:
  346. :return:
  347. """
  348. end_nums = len(no_match_end2end_indexes) + master_bbox_nums
  349. extra_match_list = []
  350. for i in range(master_bbox_nums, end_nums):
  351. end2end_index = no_match_end2end_indexes[i - master_bbox_nums]
  352. extra_match_list.append([end2end_index, i])
  353. return extra_match_list
  354. def get_match_dict(match_list):
  355. """
  356. Convert match_list to a dict, where key is master bbox's index, value is end2end bbox index.
  357. :param match_list:
  358. :return:
  359. """
  360. match_dict = dict()
  361. for match_pair in match_list:
  362. end2end_index, master_index = match_pair[0], match_pair[1]
  363. if master_index not in match_dict.keys():
  364. match_dict[master_index] = [end2end_index]
  365. else:
  366. match_dict[master_index].append(end2end_index)
  367. return match_dict
  368. def deal_successive_space(text):
  369. """
  370. deal successive space character for text
  371. 1. Replace ' '*3 with '<space>' which is real space is text
  372. 2. Remove ' ', which is split token, not true space
  373. 3. Replace '<space>' with ' ', to get real text
  374. :param text:
  375. :return:
  376. """
  377. text = text.replace(' ' * 3, '<space>')
  378. text = text.replace(' ', '')
  379. text = text.replace('<space>', ' ')
  380. return text
  381. def reduce_repeat_bb(text_list, break_token):
  382. """
  383. convert ['<b>Local</b>', '<b>government</b>', '<b>unit</b>'] to ['<b>Local government unit</b>']
  384. PS: maybe style <i>Local</i> is also exist, too. it can be processed like this.
  385. :param text_list:
  386. :param break_token:
  387. :return:
  388. """
  389. count = 0
  390. for text in text_list:
  391. if text.startswith('<b>'):
  392. count += 1
  393. if count == len(text_list):
  394. new_text_list = []
  395. for text in text_list:
  396. text = text.replace('<b>', '').replace('</b>', '')
  397. new_text_list.append(text)
  398. return ['<b>' + break_token.join(new_text_list) + '</b>']
  399. else:
  400. return text_list
  401. def get_match_text_dict(match_dict, end2end_info, break_token=' '):
  402. match_text_dict = dict()
  403. for master_index, end2end_index_list in match_dict.items():
  404. text_list = [
  405. end2end_info[end2end_index]['text']
  406. for end2end_index in end2end_index_list
  407. ]
  408. text_list = reduce_repeat_bb(text_list, break_token)
  409. text = break_token.join(text_list)
  410. match_text_dict[master_index] = text
  411. return match_text_dict
  412. def merge_span_token(master_token_list):
  413. """
  414. Merge the span style token (row span or col span).
  415. :param master_token_list:
  416. :return:
  417. """
  418. new_master_token_list = []
  419. pointer = 0
  420. if master_token_list[-1] != '</tbody>':
  421. master_token_list.append('</tbody>')
  422. while master_token_list[pointer] != '</tbody>':
  423. try:
  424. if master_token_list[pointer] == '<td':
  425. if master_token_list[pointer + 1].startswith(
  426. ' colspan=') or master_token_list[
  427. pointer + 1].startswith(' rowspan='):
  428. """
  429. example:
  430. pattern <td colspan="3">
  431. '<td' + 'colspan=" "' + '>' + '</td>'
  432. """
  433. tmp = ''.join(master_token_list[pointer:pointer + 3 + 1])
  434. pointer += 4
  435. new_master_token_list.append(tmp)
  436. elif master_token_list[pointer + 2].startswith(
  437. ' colspan=') or master_token_list[
  438. pointer + 2].startswith(' rowspan='):
  439. """
  440. example:
  441. pattern <td rowspan="2" colspan="3">
  442. '<td' + 'rowspan=" "' + 'colspan=" "' + '>' + '</td>'
  443. """
  444. tmp = ''.join(master_token_list[pointer:pointer + 4 + 1])
  445. pointer += 5
  446. new_master_token_list.append(tmp)
  447. else:
  448. new_master_token_list.append(master_token_list[pointer])
  449. pointer += 1
  450. else:
  451. new_master_token_list.append(master_token_list[pointer])
  452. pointer += 1
  453. except:
  454. print("Break in merge...")
  455. break
  456. new_master_token_list.append('</tbody>')
  457. return new_master_token_list
  458. def deal_eb_token(master_token):
  459. """
  460. post process with <eb></eb>, <eb1></eb1>, ...
  461. emptyBboxTokenDict = {
  462. "[]": '<eb></eb>',
  463. "[' ']": '<eb1></eb1>',
  464. "['<b>', ' ', '</b>']": '<eb2></eb2>',
  465. "['\\u2028', '\\u2028']": '<eb3></eb3>',
  466. "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
  467. "['<b>', '</b>']": '<eb5></eb5>',
  468. "['<i>', ' ', '</i>']": '<eb6></eb6>',
  469. "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
  470. "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
  471. "['<i>', '</i>']": '<eb9></eb9>',
  472. "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']": '<eb10></eb10>',
  473. }
  474. :param master_token:
  475. :return:
  476. """
  477. master_token = master_token.replace('<eb></eb>', '<td></td>')
  478. master_token = master_token.replace('<eb1></eb1>', '<td> </td>')
  479. master_token = master_token.replace('<eb2></eb2>', '<td><b> </b></td>')
  480. master_token = master_token.replace('<eb3></eb3>', '<td>\u2028\u2028</td>')
  481. master_token = master_token.replace('<eb4></eb4>', '<td><sup> </sup></td>')
  482. master_token = master_token.replace('<eb5></eb5>', '<td><b></b></td>')
  483. master_token = master_token.replace('<eb6></eb6>', '<td><i> </i></td>')
  484. master_token = master_token.replace('<eb7></eb7>',
  485. '<td><b><i></i></b></td>')
  486. master_token = master_token.replace('<eb8></eb8>',
  487. '<td><b><i> </i></b></td>')
  488. master_token = master_token.replace('<eb9></eb9>', '<td><i></i></td>')
  489. master_token = master_token.replace('<eb10></eb10>',
  490. '<td><b> \u2028 \u2028 </b></td>')
  491. return master_token
  492. def insert_text_to_token(master_token_list, match_text_dict):
  493. """
  494. Insert OCR text result to structure token.
  495. :param master_token_list:
  496. :param match_text_dict:
  497. :return:
  498. """
  499. master_token_list = merge_span_token(master_token_list)
  500. merged_result_list = []
  501. text_count = 0
  502. for master_token in master_token_list:
  503. if master_token.startswith('<td'):
  504. if text_count > len(match_text_dict) - 1:
  505. text_count += 1
  506. continue
  507. elif text_count not in match_text_dict.keys():
  508. text_count += 1
  509. continue
  510. else:
  511. master_token = master_token.replace(
  512. '><', '>{}<'.format(match_text_dict[text_count]))
  513. text_count += 1
  514. master_token = deal_eb_token(master_token)
  515. merged_result_list.append(master_token)
  516. return ''.join(merged_result_list)
  517. def deal_isolate_span(thead_part):
  518. """
  519. Deal with isolate span cases in this function.
  520. It causes by wrong prediction in structure recognition model.
  521. eg. predict <td rowspan="2"></td> to <td></td> rowspan="2"></b></td>.
  522. :param thead_part:
  523. :return:
  524. """
  525. # 1. find out isolate span tokens.
  526. isolate_pattern = "<td></td> rowspan=\"(\d)+\" colspan=\"(\d)+\"></b></td>|" \
  527. "<td></td> colspan=\"(\d)+\" rowspan=\"(\d)+\"></b></td>|" \
  528. "<td></td> rowspan=\"(\d)+\"></b></td>|" \
  529. "<td></td> colspan=\"(\d)+\"></b></td>"
  530. isolate_iter = re.finditer(isolate_pattern, thead_part)
  531. isolate_list = [i.group() for i in isolate_iter]
  532. # 2. find out span number, by step 1 results.
  533. span_pattern = " rowspan=\"(\d)+\" colspan=\"(\d)+\"|" \
  534. " colspan=\"(\d)+\" rowspan=\"(\d)+\"|" \
  535. " rowspan=\"(\d)+\"|" \
  536. " colspan=\"(\d)+\""
  537. corrected_list = []
  538. for isolate_item in isolate_list:
  539. span_part = re.search(span_pattern, isolate_item)
  540. spanStr_in_isolateItem = span_part.group()
  541. # 3. merge the span number into the span token format string.
  542. if spanStr_in_isolateItem is not None:
  543. corrected_item = '<td{}></td>'.format(spanStr_in_isolateItem)
  544. corrected_list.append(corrected_item)
  545. else:
  546. corrected_list.append(None)
  547. # 4. replace original isolated token.
  548. for corrected_item, isolate_item in zip(corrected_list, isolate_list):
  549. if corrected_item is not None:
  550. thead_part = thead_part.replace(isolate_item, corrected_item)
  551. else:
  552. pass
  553. return thead_part
  554. def deal_duplicate_bb(thead_part):
  555. """
  556. Deal duplicate <b> or </b> after replace.
  557. Keep one <b></b> in a <td></td> token.
  558. :param thead_part:
  559. :return:
  560. """
  561. # 1. find out <td></td> in <thead></thead>.
  562. td_pattern = "<td rowspan=\"(\d)+\" colspan=\"(\d)+\">(.+?)</td>|" \
  563. "<td colspan=\"(\d)+\" rowspan=\"(\d)+\">(.+?)</td>|" \
  564. "<td rowspan=\"(\d)+\">(.+?)</td>|" \
  565. "<td colspan=\"(\d)+\">(.+?)</td>|" \
  566. "<td>(.*?)</td>"
  567. td_iter = re.finditer(td_pattern, thead_part)
  568. td_list = [t.group() for t in td_iter]
  569. # 2. is multiply <b></b> in <td></td> or not?
  570. new_td_list = []
  571. for td_item in td_list:
  572. if td_item.count('<b>') > 1 or td_item.count('</b>') > 1:
  573. # multiply <b></b> in <td></td> case.
  574. # 1. remove all <b></b>
  575. td_item = td_item.replace('<b>', '').replace('</b>', '')
  576. # 2. replace <tb> -> <tb><b>, </tb> -> </b></tb>.
  577. td_item = td_item.replace('<td>', '<td><b>').replace('</td>',
  578. '</b></td>')
  579. new_td_list.append(td_item)
  580. else:
  581. new_td_list.append(td_item)
  582. # 3. replace original thead part.
  583. for td_item, new_td_item in zip(td_list, new_td_list):
  584. thead_part = thead_part.replace(td_item, new_td_item)
  585. return thead_part
  586. def deal_bb(result_token):
  587. """
  588. In our opinion, <b></b> always occurs in <thead></thead> text's context.
  589. This function will find out all tokens in <thead></thead> and insert <b></b> by manual.
  590. :param result_token:
  591. :return:
  592. """
  593. # find out <thead></thead> parts.
  594. thead_pattern = '<thead>(.*?)</thead>'
  595. if re.search(thead_pattern, result_token) is None:
  596. return result_token
  597. thead_part = re.search(thead_pattern, result_token).group()
  598. origin_thead_part = copy.deepcopy(thead_part)
  599. # check "rowspan" or "colspan" occur in <thead></thead> parts or not .
  600. span_pattern = "<td rowspan=\"(\d)+\" colspan=\"(\d)+\">|<td colspan=\"(\d)+\" rowspan=\"(\d)+\">|<td rowspan=\"(\d)+\">|<td colspan=\"(\d)+\">"
  601. span_iter = re.finditer(span_pattern, thead_part)
  602. span_list = [s.group() for s in span_iter]
  603. has_span_in_head = True if len(span_list) > 0 else False
  604. if not has_span_in_head:
  605. # <thead></thead> not include "rowspan" or "colspan" branch 1.
  606. # 1. replace <td> to <td><b>, and </td> to </b></td>
  607. # 2. it is possible to predict text include <b> or </b> by Text-line recognition,
  608. # so we replace <b><b> to <b>, and </b></b> to </b>
  609. thead_part = thead_part.replace('<td>', '<td><b>')\
  610. .replace('</td>', '</b></td>')\
  611. .replace('<b><b>', '<b>')\
  612. .replace('</b></b>', '</b>')
  613. else:
  614. # <thead></thead> include "rowspan" or "colspan" branch 2.
  615. # Firstly, we deal rowspan or colspan cases.
  616. # 1. replace > to ><b>
  617. # 2. replace </td> to </b></td>
  618. # 3. it is possible to predict text include <b> or </b> by Text-line recognition,
  619. # so we replace <b><b> to <b>, and </b><b> to </b>
  620. # Secondly, deal ordinary cases like branch 1
  621. # replace ">" to "<b>"
  622. replaced_span_list = []
  623. for sp in span_list:
  624. replaced_span_list.append(sp.replace('>', '><b>'))
  625. for sp, rsp in zip(span_list, replaced_span_list):
  626. thead_part = thead_part.replace(sp, rsp)
  627. # replace "</td>" to "</b></td>"
  628. thead_part = thead_part.replace('</td>', '</b></td>')
  629. # remove duplicated <b> by re.sub
  630. mb_pattern = "(<b>)+"
  631. single_b_string = "<b>"
  632. thead_part = re.sub(mb_pattern, single_b_string, thead_part)
  633. mgb_pattern = "(</b>)+"
  634. single_gb_string = "</b>"
  635. thead_part = re.sub(mgb_pattern, single_gb_string, thead_part)
  636. # ordinary cases like branch 1
  637. thead_part = thead_part.replace('<td>', '<td><b>').replace('<b><b>',
  638. '<b>')
  639. # convert <tb><b></b></tb> back to <tb></tb>, empty cell has no <b></b>.
  640. # but space cell(<tb> </tb>) is suitable for <td><b> </b></td>
  641. thead_part = thead_part.replace('<td><b></b></td>', '<td></td>')
  642. # deal with duplicated <b></b>
  643. thead_part = deal_duplicate_bb(thead_part)
  644. # deal with isolate span tokens, which causes by wrong predict by structure prediction.
  645. # eg.PMC5994107_011_00.png
  646. thead_part = deal_isolate_span(thead_part)
  647. # replace original result with new thead part.
  648. result_token = result_token.replace(origin_thead_part, thead_part)
  649. return result_token
  650. class Matcher:
  651. def __init__(self, end2end_file, structure_master_file):
  652. """
  653. This class process the end2end results and structure recognition results.
  654. :param end2end_file: end2end results predict by end2end inference.
  655. :param structure_master_file: structure recognition results predict by structure master inference.
  656. """
  657. self.end2end_file = end2end_file
  658. self.structure_master_file = structure_master_file
  659. self.end2end_results = pickle_load(end2end_file, prefix='end2end')
  660. self.structure_master_results = pickle_load(
  661. structure_master_file, prefix='structure')
  662. def match(self):
  663. """
  664. Match process:
  665. pre-process : convert end2end and structure master results to xyxy, xywh ndnarray format.
  666. 1. Use pseBbox is inside masterBbox judge rule
  667. 2. Use iou between pseBbox and masterBbox rule
  668. 3. Use min distance of center point rule
  669. :return:
  670. """
  671. match_results = dict()
  672. for idx, (file_name,
  673. end2end_result) in enumerate(self.end2end_results.items()):
  674. match_list = []
  675. if file_name not in self.structure_master_results:
  676. continue
  677. structure_master_result = self.structure_master_results[file_name]
  678. end2end_xyxy_bboxes, end2end_xywh_bboxes, structure_master_xywh_bboxes, structure_master_xyxy_bboxes = \
  679. get_bboxes_list(end2end_result, structure_master_result)
  680. # rule 1: center rule
  681. center_rule_match_list = \
  682. center_rule_match(end2end_xywh_bboxes, structure_master_xyxy_bboxes)
  683. match_list.extend(center_rule_match_list)
  684. # rule 2: iou rule
  685. # firstly, find not match index in previous step.
  686. center_no_match_end2end_indexs = \
  687. find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
  688. if len(center_no_match_end2end_indexs) > 0:
  689. center_no_match_end2end_xyxy = end2end_xyxy_bboxes[
  690. center_no_match_end2end_indexs]
  691. # secondly, iou rule match
  692. iou_rule_match_list = \
  693. iou_rule_match(center_no_match_end2end_xyxy, center_no_match_end2end_indexs, structure_master_xyxy_bboxes)
  694. match_list.extend(iou_rule_match_list)
  695. # rule 3: distance rule
  696. # match between no-match end2end bboxes and no-match master bboxes.
  697. # it will return master_bboxes_nums match-pairs.
  698. # firstly, find not match index in previous step.
  699. centerIou_no_match_end2end_indexs = \
  700. find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
  701. centerIou_no_match_master_indexs = \
  702. find_no_match(match_list, len(structure_master_xywh_bboxes), type='master')
  703. if len(centerIou_no_match_master_indexs) > 0 and len(
  704. centerIou_no_match_end2end_indexs) > 0:
  705. centerIou_no_match_end2end_xywh = end2end_xywh_bboxes[
  706. centerIou_no_match_end2end_indexs]
  707. centerIou_no_match_master_xywh = structure_master_xywh_bboxes[
  708. centerIou_no_match_master_indexs]
  709. distance_match_list = distance_rule_match(
  710. centerIou_no_match_end2end_indexs,
  711. centerIou_no_match_end2end_xywh,
  712. centerIou_no_match_master_indexs,
  713. centerIou_no_match_master_xywh)
  714. match_list.extend(distance_match_list)
  715. # TODO:
  716. # The render no-match pseBbox, insert the last
  717. # After step3 distance rule, a master bbox at least match one end2end bbox.
  718. # But end2end bbox maybe overmuch, because numbers of master bbox will cut by max length.
  719. # For these render end2end bboxes, we will make some virtual master bboxes, and get matching.
  720. # The above extra insert bboxes will be further processed in "formatOutput" function.
  721. # After this operation, it will increase TEDS score.
  722. no_match_end2end_indexes = \
  723. find_no_match(match_list, len(end2end_xywh_bboxes), type='end2end')
  724. if len(no_match_end2end_indexes) > 0:
  725. no_match_end2end_xywh = end2end_xywh_bboxes[
  726. no_match_end2end_indexes]
  727. # sort the render no-match end2end bbox in row
  728. end2end_sorted_indexes_list, end2end_sorted_bboxes_list, sorted_groups, sorted_bboxes_groups = \
  729. sort_bbox(no_match_end2end_xywh, no_match_end2end_indexes)
  730. # make virtual master bboxes, and get matching with the no-match end2end bboxes.
  731. extra_match_list = extra_match(
  732. end2end_sorted_indexes_list,
  733. len(structure_master_xywh_bboxes))
  734. match_list_add_extra_match = copy.deepcopy(match_list)
  735. match_list_add_extra_match.extend(extra_match_list)
  736. else:
  737. # no no-match end2end bboxes
  738. match_list_add_extra_match = copy.deepcopy(match_list)
  739. sorted_groups = []
  740. sorted_bboxes_groups = []
  741. match_result_dict = {
  742. 'match_list': match_list,
  743. 'match_list_add_extra_match': match_list_add_extra_match,
  744. 'sorted_groups': sorted_groups,
  745. 'sorted_bboxes_groups': sorted_bboxes_groups
  746. }
  747. # format output
  748. match_result_dict = self._format(match_result_dict, file_name)
  749. match_results[file_name] = match_result_dict
  750. return match_results
  751. def _format(self, match_result, file_name):
  752. """
  753. Extend the master token(insert virtual master token), and format matching result.
  754. :param match_result:
  755. :param file_name:
  756. :return:
  757. """
  758. end2end_info = self.end2end_results[file_name]
  759. master_info = self.structure_master_results[file_name]
  760. master_token = master_info['text']
  761. sorted_groups = match_result['sorted_groups']
  762. # creat virtual master token
  763. virtual_master_token_list = []
  764. for line_group in sorted_groups:
  765. tmp_list = ['<tr>']
  766. item_nums = len(line_group)
  767. for _ in range(item_nums):
  768. tmp_list.append('<td></td>')
  769. tmp_list.append('</tr>')
  770. virtual_master_token_list.extend(tmp_list)
  771. # insert virtual master token
  772. master_token_list = master_token.split(',')
  773. if master_token_list[-1] == '</tbody>':
  774. # complete predict(no cut by max length)
  775. # This situation insert virtual master token will drop TEDs score in val set.
  776. # So we will not extend virtual token in this situation.
  777. # fake extend virtual
  778. master_token_list[:-1].extend(virtual_master_token_list)
  779. # real extend virtual
  780. # master_token_list = master_token_list[:-1]
  781. # master_token_list.extend(virtual_master_token_list)
  782. # master_token_list.append('</tbody>')
  783. elif master_token_list[-1] == '<td></td>':
  784. master_token_list.append('</tr>')
  785. master_token_list.extend(virtual_master_token_list)
  786. master_token_list.append('</tbody>')
  787. else:
  788. master_token_list.extend(virtual_master_token_list)
  789. master_token_list.append('</tbody>')
  790. # format output
  791. match_result.setdefault('matched_master_token_list', master_token_list)
  792. return match_result
  793. def get_merge_result(self, match_results):
  794. """
  795. Merge the OCR result into structure token to get final results.
  796. :param match_results:
  797. :return:
  798. """
  799. merged_results = dict()
  800. # break_token is linefeed token, when one master bbox has multiply end2end bboxes.
  801. break_token = ' '
  802. for idx, (file_name, match_info) in enumerate(match_results.items()):
  803. end2end_info = self.end2end_results[file_name]
  804. master_token_list = match_info['matched_master_token_list']
  805. match_list = match_info['match_list_add_extra_match']
  806. match_dict = get_match_dict(match_list)
  807. match_text_dict = get_match_text_dict(match_dict, end2end_info,
  808. break_token)
  809. merged_result = insert_text_to_token(master_token_list,
  810. match_text_dict)
  811. merged_result = deal_bb(merged_result)
  812. merged_results[file_name] = merged_result
  813. return merged_results
  814. class TableMasterMatcher(Matcher):
  815. def __init__(self):
  816. pass
  817. def __call__(self, structure_res, dt_boxes, rec_res, img_name=1):
  818. end2end_results = {img_name: []}
  819. for dt_box, res in zip(dt_boxes, rec_res):
  820. d = dict(
  821. bbox=np.array(dt_box),
  822. text=res[0], )
  823. end2end_results[img_name].append(d)
  824. self.end2end_results = end2end_results
  825. structure_master_result_dict = {img_name: {}}
  826. pred_structures, pred_bboxes = structure_res
  827. pred_structures = ','.join(pred_structures[3:-3])
  828. structure_master_result_dict[img_name]['text'] = pred_structures
  829. structure_master_result_dict[img_name]['bbox'] = pred_bboxes
  830. self.structure_master_results = structure_master_result_dict
  831. # match
  832. match_results = self.match()
  833. merged_results = self.get_merge_result(match_results)
  834. pred_html = merged_results[img_name]
  835. pred_html = '<html><body><table>' + pred_html + '</table></body></html>'
  836. return pred_html