# File: table_postprocess.py (6.7 KB)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import numpy as np
  15. import paddle
  16. from .rec_postprocess import AttnLabelDecode
  17. class TableLabelDecode(AttnLabelDecode):
  18. """ """
  19. def __init__(self,
  20. character_dict_path,
  21. merge_no_span_structure=False,
  22. **kwargs):
  23. dict_character = []
  24. with open(character_dict_path, "rb") as fin:
  25. lines = fin.readlines()
  26. for line in lines:
  27. line = line.decode('utf-8').strip("\n").strip("\r\n")
  28. dict_character.append(line)
  29. if merge_no_span_structure:
  30. if "<td></td>" not in dict_character:
  31. dict_character.append("<td></td>")
  32. if "<td>" in dict_character:
  33. dict_character.remove("<td>")
  34. dict_character = self.add_special_char(dict_character)
  35. self.dict = {}
  36. for i, char in enumerate(dict_character):
  37. self.dict[char] = i
  38. self.character = dict_character
  39. self.td_token = ['<td>', '<td', '<td></td>']
  40. def __call__(self, preds, batch=None):
  41. structure_probs = preds['structure_probs']
  42. bbox_preds = preds['loc_preds']
  43. if isinstance(structure_probs, paddle.Tensor):
  44. structure_probs = structure_probs.numpy()
  45. if isinstance(bbox_preds, paddle.Tensor):
  46. bbox_preds = bbox_preds.numpy()
  47. shape_list = batch[-1]
  48. result = self.decode(structure_probs, bbox_preds, shape_list)
  49. if len(batch) == 1: # only contains shape
  50. return result
  51. label_decode_result = self.decode_label(batch)
  52. return result, label_decode_result
  53. def decode(self, structure_probs, bbox_preds, shape_list):
  54. """convert text-label into text-index.
  55. """
  56. ignored_tokens = self.get_ignored_tokens()
  57. end_idx = self.dict[self.end_str]
  58. structure_idx = structure_probs.argmax(axis=2)
  59. structure_probs = structure_probs.max(axis=2)
  60. structure_batch_list = []
  61. bbox_batch_list = []
  62. batch_size = len(structure_idx)
  63. for batch_idx in range(batch_size):
  64. structure_list = []
  65. bbox_list = []
  66. score_list = []
  67. for idx in range(len(structure_idx[batch_idx])):
  68. char_idx = int(structure_idx[batch_idx][idx])
  69. if idx > 0 and char_idx == end_idx:
  70. break
  71. if char_idx in ignored_tokens:
  72. continue
  73. text = self.character[char_idx]
  74. if text in self.td_token:
  75. bbox = bbox_preds[batch_idx, idx]
  76. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  77. bbox_list.append(bbox)
  78. structure_list.append(text)
  79. score_list.append(structure_probs[batch_idx, idx])
  80. structure_batch_list.append([structure_list, np.mean(score_list)])
  81. bbox_batch_list.append(np.array(bbox_list))
  82. result = {
  83. 'bbox_batch_list': bbox_batch_list,
  84. 'structure_batch_list': structure_batch_list,
  85. }
  86. return result
  87. def decode_label(self, batch):
  88. """convert text-label into text-index.
  89. """
  90. structure_idx = batch[1]
  91. gt_bbox_list = batch[2]
  92. shape_list = batch[-1]
  93. ignored_tokens = self.get_ignored_tokens()
  94. end_idx = self.dict[self.end_str]
  95. structure_batch_list = []
  96. bbox_batch_list = []
  97. batch_size = len(structure_idx)
  98. for batch_idx in range(batch_size):
  99. structure_list = []
  100. bbox_list = []
  101. for idx in range(len(structure_idx[batch_idx])):
  102. char_idx = int(structure_idx[batch_idx][idx])
  103. if idx > 0 and char_idx == end_idx:
  104. break
  105. if char_idx in ignored_tokens:
  106. continue
  107. structure_list.append(self.character[char_idx])
  108. bbox = gt_bbox_list[batch_idx][idx]
  109. if bbox.sum() != 0:
  110. bbox = self._bbox_decode(bbox, shape_list[batch_idx])
  111. bbox_list.append(bbox)
  112. structure_batch_list.append(structure_list)
  113. bbox_batch_list.append(bbox_list)
  114. result = {
  115. 'bbox_batch_list': bbox_batch_list,
  116. 'structure_batch_list': structure_batch_list,
  117. }
  118. return result
  119. def _bbox_decode(self, bbox, shape):
  120. h, w, ratio_h, ratio_w, pad_h, pad_w = shape
  121. bbox[0::2] *= w
  122. bbox[1::2] *= h
  123. return bbox
  124. class TableMasterLabelDecode(TableLabelDecode):
  125. """ """
  126. def __init__(self,
  127. character_dict_path,
  128. box_shape='ori',
  129. merge_no_span_structure=True,
  130. **kwargs):
  131. super(TableMasterLabelDecode, self).__init__(character_dict_path,
  132. merge_no_span_structure)
  133. self.box_shape = box_shape
  134. assert box_shape in [
  135. 'ori', 'pad'
  136. ], 'The shape used for box normalization must be ori or pad'
  137. def add_special_char(self, dict_character):
  138. self.beg_str = '<SOS>'
  139. self.end_str = '<EOS>'
  140. self.unknown_str = '<UKN>'
  141. self.pad_str = '<PAD>'
  142. dict_character = dict_character
  143. dict_character = dict_character + [
  144. self.unknown_str, self.beg_str, self.end_str, self.pad_str
  145. ]
  146. return dict_character
  147. def get_ignored_tokens(self):
  148. pad_idx = self.dict[self.pad_str]
  149. start_idx = self.dict[self.beg_str]
  150. end_idx = self.dict[self.end_str]
  151. unknown_idx = self.dict[self.unknown_str]
  152. return [start_idx, end_idx, pad_idx, unknown_idx]
  153. def _bbox_decode(self, bbox, shape):
  154. h, w, ratio_h, ratio_w, pad_h, pad_w = shape
  155. if self.box_shape == 'pad':
  156. h, w = pad_h, pad_w
  157. bbox[0::2] *= w
  158. bbox[1::2] *= h
  159. bbox[0::2] /= ratio_w
  160. bbox[1::2] /= ratio_h
  161. x, y, w, h = bbox
  162. x1, y1, x2, y2 = x - w // 2, y - h // 2, x + w // 2, y + h // 2
  163. bbox = np.array([x1, y1, x2, y2])
  164. return bbox