table_ops.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229
  1. """
  2. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. from __future__ import unicode_literals
  20. import sys
  21. import six
  22. import cv2
  23. import numpy as np
  24. class GenTableMask(object):
  25. """ gen table mask """
  26. def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
  27. self.shrink_h_max = 5
  28. self.shrink_w_max = 5
  29. self.mask_type = mask_type
  30. def projection(self, erosion, h, w, spilt_threshold=0):
  31. # 水平投影
  32. projection_map = np.ones_like(erosion)
  33. project_val_array = [0 for _ in range(0, h)]
  34. for j in range(0, h):
  35. for i in range(0, w):
  36. if erosion[j, i] == 255:
  37. project_val_array[j] += 1
  38. # 根据数组,获取切割点
  39. start_idx = 0 # 记录进入字符区的索引
  40. end_idx = 0 # 记录进入空白区域的索引
  41. in_text = False # 是否遍历到了字符区内
  42. box_list = []
  43. for i in range(len(project_val_array)):
  44. if in_text == False and project_val_array[
  45. i] > spilt_threshold: # 进入字符区了
  46. in_text = True
  47. start_idx = i
  48. elif project_val_array[
  49. i] <= spilt_threshold and in_text == True: # 进入空白区了
  50. end_idx = i
  51. in_text = False
  52. if end_idx - start_idx <= 2:
  53. continue
  54. box_list.append((start_idx, end_idx + 1))
  55. if in_text:
  56. box_list.append((start_idx, h - 1))
  57. # 绘制投影直方图
  58. for j in range(0, h):
  59. for i in range(0, project_val_array[j]):
  60. projection_map[j, i] = 0
  61. return box_list, projection_map
  62. def projection_cx(self, box_img):
  63. box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
  64. h, w = box_gray_img.shape
  65. # 灰度图片进行二值化处理
  66. ret, thresh1 = cv2.threshold(box_gray_img, 200, 255,
  67. cv2.THRESH_BINARY_INV)
  68. # 纵向腐蚀
  69. if h < w:
  70. kernel = np.ones((2, 1), np.uint8)
  71. erode = cv2.erode(thresh1, kernel, iterations=1)
  72. else:
  73. erode = thresh1
  74. # 水平膨胀
  75. kernel = np.ones((1, 5), np.uint8)
  76. erosion = cv2.dilate(erode, kernel, iterations=1)
  77. # 水平投影
  78. projection_map = np.ones_like(erosion)
  79. project_val_array = [0 for _ in range(0, h)]
  80. for j in range(0, h):
  81. for i in range(0, w):
  82. if erosion[j, i] == 255:
  83. project_val_array[j] += 1
  84. # 根据数组,获取切割点
  85. start_idx = 0 # 记录进入字符区的索引
  86. end_idx = 0 # 记录进入空白区域的索引
  87. in_text = False # 是否遍历到了字符区内
  88. box_list = []
  89. spilt_threshold = 0
  90. for i in range(len(project_val_array)):
  91. if in_text == False and project_val_array[
  92. i] > spilt_threshold: # 进入字符区了
  93. in_text = True
  94. start_idx = i
  95. elif project_val_array[
  96. i] <= spilt_threshold and in_text == True: # 进入空白区了
  97. end_idx = i
  98. in_text = False
  99. if end_idx - start_idx <= 2:
  100. continue
  101. box_list.append((start_idx, end_idx + 1))
  102. if in_text:
  103. box_list.append((start_idx, h - 1))
  104. # 绘制投影直方图
  105. for j in range(0, h):
  106. for i in range(0, project_val_array[j]):
  107. projection_map[j, i] = 0
  108. split_bbox_list = []
  109. if len(box_list) > 1:
  110. for i, (h_start, h_end) in enumerate(box_list):
  111. if i == 0:
  112. h_start = 0
  113. if i == len(box_list):
  114. h_end = h
  115. word_img = erosion[h_start:h_end + 1, :]
  116. word_h, word_w = word_img.shape
  117. w_split_list, w_projection_map = self.projection(word_img.T,
  118. word_w, word_h)
  119. w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
  120. if h_start > 0:
  121. h_start -= 1
  122. h_end += 1
  123. word_img = box_img[h_start:h_end + 1:, w_start:w_end + 1, :]
  124. split_bbox_list.append([w_start, h_start, w_end, h_end])
  125. else:
  126. split_bbox_list.append([0, 0, w, h])
  127. return split_bbox_list
  128. def shrink_bbox(self, bbox):
  129. left, top, right, bottom = bbox
  130. sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
  131. sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
  132. left_new = left + sh_w
  133. right_new = right - sh_w
  134. top_new = top + sh_h
  135. bottom_new = bottom - sh_h
  136. if left_new >= right_new:
  137. left_new = left
  138. right_new = right
  139. if top_new >= bottom_new:
  140. top_new = top
  141. bottom_new = bottom
  142. return [left_new, top_new, right_new, bottom_new]
  143. def __call__(self, data):
  144. img = data['image']
  145. cells = data['cells']
  146. height, width = img.shape[0:2]
  147. if self.mask_type == 1:
  148. mask_img = np.zeros((height, width), dtype=np.float32)
  149. else:
  150. mask_img = np.zeros((height, width, 3), dtype=np.float32)
  151. cell_num = len(cells)
  152. for cno in range(cell_num):
  153. if "bbox" in cells[cno]:
  154. bbox = cells[cno]['bbox']
  155. left, top, right, bottom = bbox
  156. box_img = img[top:bottom, left:right, :].copy()
  157. split_bbox_list = self.projection_cx(box_img)
  158. for sno in range(len(split_bbox_list)):
  159. split_bbox_list[sno][0] += left
  160. split_bbox_list[sno][1] += top
  161. split_bbox_list[sno][2] += left
  162. split_bbox_list[sno][3] += top
  163. for sno in range(len(split_bbox_list)):
  164. left, top, right, bottom = split_bbox_list[sno]
  165. left, top, right, bottom = self.shrink_bbox(
  166. [left, top, right, bottom])
  167. if self.mask_type == 1:
  168. mask_img[top:bottom, left:right] = 1.0
  169. data['mask_img'] = mask_img
  170. else:
  171. mask_img[top:bottom, left:right, :] = (255, 255, 255)
  172. data['image'] = mask_img
  173. return data
  174. class ResizeTableImage(object):
  175. def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
  176. **kwargs):
  177. super(ResizeTableImage, self).__init__()
  178. self.max_len = max_len
  179. self.resize_bboxes = resize_bboxes
  180. self.infer_mode = infer_mode
  181. def __call__(self, data):
  182. img = data['image']
  183. height, width = img.shape[0:2]
  184. ratio = self.max_len / (max(height, width) * 1.0)
  185. resize_h = int(height * ratio)
  186. resize_w = int(width * ratio)
  187. resize_img = cv2.resize(img, (resize_w, resize_h))
  188. if self.resize_bboxes and not self.infer_mode:
  189. data['bboxes'] = data['bboxes'] * ratio
  190. data['image'] = resize_img
  191. data['src_img'] = img
  192. data['shape'] = np.array([height, width, ratio, ratio])
  193. data['max_len'] = self.max_len
  194. return data
  195. class PaddingTableImage(object):
  196. def __init__(self, size, **kwargs):
  197. super(PaddingTableImage, self).__init__()
  198. self.size = size
  199. def __call__(self, data):
  200. img = data['image']
  201. pad_h, pad_w = self.size
  202. padding_img = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
  203. height, width = img.shape[0:2]
  204. padding_img[0:height, 0:width, :] = img.copy()
  205. data['image'] = padding_img
  206. shape = data['shape'].tolist()
  207. shape.extend([pad_h, pad_w])
  208. data['shape'] = np.array(shape)
  209. return data