123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229 |
- """
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- from __future__ import unicode_literals
- import sys
- import six
- import cv2
- import numpy as np
class GenTableMask(object):
    """Build a text-region mask for the cells of a table image.

    For every cell that carries a ``bbox`` entry, the cell crop is split into
    row bands by projection profiles, each resulting box is shrunk slightly,
    and the box is painted into a mask.

    Args:
        shrink_h_max: Upper bound on the vertical shrink applied to the top
            and bottom of each box (see ``shrink_bbox``).
        shrink_w_max: Upper bound on the horizontal shrink applied to the
            left and right of each box.
        mask_type: ``1`` produces a single-channel mask stored under
            ``data['mask_img']`` with painted pixels set to 1.0; any other
            value produces a 3-channel mask with painted pixels set to 255
            that replaces ``data['image']``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, shrink_h_max, shrink_w_max, mask_type=0, **kwargs):
        # Honour the limits passed by the caller; they used to be discarded
        # in favour of a hard-coded 5.
        self.shrink_h_max = shrink_h_max
        self.shrink_w_max = shrink_w_max
        self.mask_type = mask_type

    def projection(self, erosion, h, w, spilt_threshold=0):
        """Split a binary image into row bands via its horizontal projection.

        Args:
            erosion: 2-D array whose foreground pixels have the value 255.
            h: Number of rows of ``erosion`` to scan.
            w: Number of columns of ``erosion`` to scan.
            spilt_threshold: A row belongs to a band when it contains more
                than this many foreground pixels.

        Returns:
            A ``(box_list, projection_map)`` tuple. ``box_list`` holds
            ``(start, end)`` row-index tuples, one per band.
            ``projection_map`` is an array shaped like ``erosion`` filled
            with ones, where the first ``count`` entries of each row are
            zeroed (``count`` being that row's foreground pixel count), i.e.
            a drawing of the projection histogram.
        """
        # Horizontal projection: number of foreground pixels in every row.
        row_counts = np.count_nonzero(
            np.asarray(erosion)[:h, :w] == 255, axis=1)

        # Walk the profile and record where bands start and stop.
        box_list = []
        start_idx = 0  # row at which the current band was entered
        in_text = False  # whether the scan is currently inside a band
        for idx, count in enumerate(row_counts.tolist()):
            is_text = count > spilt_threshold
            if is_text and not in_text:
                in_text = True
                start_idx = idx
            elif in_text and not is_text:
                in_text = False
                # Bands spanning two rows or fewer are dropped.
                if idx - start_idx > 2:
                    box_list.append((start_idx, idx + 1))
        if in_text:
            # The band runs to the last scanned row.
            box_list.append((start_idx, h - 1))

        # Draw the projection histogram: zero the first `count` cells per row.
        projection_map = np.ones_like(erosion)
        region = projection_map[:h, :w]
        region[np.arange(region.shape[1]) < row_counts[:, np.newaxis]] = 0
        return box_list, projection_map

    def projection_cx(self, box_img):
        """Split a cell crop into one bounding box per row band.

        Args:
            box_img: 3-channel image; it is converted to grayscale with
                ``cv2.COLOR_BGR2GRAY``.

        Returns:
            A list of ``[left, top, right, bottom]`` boxes relative to
            ``box_img``. When fewer than two row bands are found the list
            holds the single box ``[0, 0, w, h]`` covering the whole crop.
        """
        box_gray_img = cv2.cvtColor(box_img, cv2.COLOR_BGR2GRAY)
        h, w = box_gray_img.shape
        # Inverse binarisation: pixels not brighter than 200 become 255.
        _, thresh = cv2.threshold(box_gray_img, 200, 255,
                                  cv2.THRESH_BINARY_INV)
        # Vertical erosion, only applied when the crop is wider than tall.
        if h < w:
            erode = cv2.erode(thresh, np.ones((2, 1), np.uint8), iterations=1)
        else:
            erode = thresh
        # Horizontal dilation.
        erosion = cv2.dilate(erode, np.ones((1, 5), np.uint8), iterations=1)

        box_list, _ = self.projection(erosion, h, w)
        if len(box_list) <= 1:
            return [[0, 0, w, h]]

        last = len(box_list) - 1
        split_bbox_list = []
        for i, (h_start, h_end) in enumerate(box_list):
            # The first band is stretched to the top of the crop and the
            # last band to its bottom. (The bottom case used to compare
            # against len(box_list) and therefore never fired.)
            if i == 0:
                h_start = 0
            if i == last:
                h_end = h
            word_img = erosion[h_start:h_end + 1, :]
            word_h, word_w = word_img.shape
            # Project the transposed band to find its horizontal extent.
            w_split_list, _ = self.projection(word_img.T, word_w, word_h)
            if w_split_list:
                w_start, w_end = w_split_list[0][0], w_split_list[-1][1]
            else:
                # Defensive fallback: no column band found, use full width.
                w_start, w_end = 0, w
            # Pad the band by one row on each side, staying inside the crop.
            if h_start > 0:
                h_start -= 1
            h_end = min(h_end + 1, h)
            split_bbox_list.append([w_start, h_start, w_end, h_end])
        return split_bbox_list

    def shrink_bbox(self, bbox):
        """Shrink a box towards its centre.

        Each side moves inwards by 10% of the box extent along that axis,
        at least 1 and at most ``shrink_h_max`` / ``shrink_w_max``. If
        shrinking would make an axis empty or inverted, that axis keeps its
        original coordinates.

        Args:
            bbox: ``(left, top, right, bottom)``.

        Returns:
            ``[left, top, right, bottom]`` of the shrunken box.
        """
        left, top, right, bottom = bbox
        sh_h = min(max(int((bottom - top) * 0.1), 1), self.shrink_h_max)
        sh_w = min(max(int((right - left) * 0.1), 1), self.shrink_w_max)
        left_new, right_new = left + sh_w, right - sh_w
        top_new, bottom_new = top + sh_h, bottom - sh_h
        if left_new >= right_new:
            left_new, right_new = left, right
        if top_new >= bottom_new:
            top_new, bottom_new = top, bottom
        return [left_new, top_new, right_new, bottom_new]

    def __call__(self, data):
        """Paint the mask for all cells in ``data`` and return ``data``.

        Args:
            data: Mapping with an ``'image'`` array and a ``'cells'``
                sequence. Cells that contain a ``'bbox'`` key provide
                ``(left, top, right, bottom)`` values that are used directly
                as slice indices into the image.

        Returns:
            The same ``data`` object. With ``mask_type == 1`` the mask is
            stored under ``'mask_img'``; otherwise it replaces ``'image'``.
            If no cell has a ``'bbox'`` key, ``data`` is returned unchanged.
        """
        img = data['image']
        cells = data['cells']
        height, width = img.shape[0:2]
        single_channel = self.mask_type == 1
        if single_channel:
            mask_img = np.zeros((height, width), dtype=np.float32)
        else:
            mask_img = np.zeros((height, width, 3), dtype=np.float32)

        for cell in cells:
            if "bbox" not in cell:
                continue
            left, top, right, bottom = cell['bbox']
            box_img = img[top:bottom, left:right, :].copy()
            for x0, y0, x1, y1 in self.projection_cx(box_img):
                # Move the box from crop coordinates to image coordinates
                # before shrinking it.
                m_left, m_top, m_right, m_bottom = self.shrink_bbox(
                    [x0 + left, y0 + top, x1 + left, y1 + top])
                # The mask is published only once something was painted, so
                # input without any bbox cell leaves `data` untouched.
                if single_channel:
                    mask_img[m_top:m_bottom, m_left:m_right] = 1.0
                    data['mask_img'] = mask_img
                else:
                    mask_img[m_top:m_bottom, m_left:m_right, :] = (255, 255,
                                                                   255)
                    data['image'] = mask_img
        return data
class ResizeTableImage(object):
    """Rescale a table image by ``max_len / max(height, width)``.

    Args:
        max_len: Target length used to derive the scale factor from the
            longer image side.
        resize_bboxes: When true (and ``infer_mode`` is false),
            ``data['bboxes']`` is multiplied by the same scale factor.
        infer_mode: When true, ``data['bboxes']`` is never touched.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, max_len, resize_bboxes=False, infer_mode=False,
                 **kwargs):
        super(ResizeTableImage, self).__init__()
        self.max_len = max_len
        self.resize_bboxes = resize_bboxes
        self.infer_mode = infer_mode

    def __call__(self, data):
        """Resize ``data['image']`` and record the geometry in ``data``.

        Writes ``'image'`` (resized), ``'src_img'`` (the input image),
        ``'shape'`` (``[height, width, ratio, ratio]`` of the input) and
        ``'max_len'``; optionally rescales ``'bboxes'``. Returns ``data``.
        """
        source = data['image']
        height, width = source.shape[0:2]
        scale = self.max_len / (max(height, width) * 1.0)
        # cv2.resize expects the target size as (width, height); both sides
        # are truncated to integers.
        target_size = (int(width * scale), int(height * scale))
        resized = cv2.resize(source, target_size)
        if self.resize_bboxes and not self.infer_mode:
            data['bboxes'] = data['bboxes'] * scale
        data['image'] = resized
        data['src_img'] = source
        data['shape'] = np.array([height, width, scale, scale])
        data['max_len'] = self.max_len
        return data
class PaddingTableImage(object):
    """Place an image in the top-left corner of a zero-filled canvas.

    Args:
        size: ``(pad_h, pad_w)`` of the output canvas.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, size, **kwargs):
        super(PaddingTableImage, self).__init__()
        self.size = size

    def __call__(self, data):
        """Pad ``data['image']`` and append the canvas size to ``data['shape']``.

        The output image is a float32 array of shape ``(pad_h, pad_w, 3)``.
        ``data['shape']`` must offer ``tolist()``; it is rebuilt as an array
        with ``pad_h`` and ``pad_w`` appended. Returns ``data``.
        """
        source = data['image']
        pad_h, pad_w = self.size
        canvas = np.zeros((pad_h, pad_w, 3), dtype=np.float32)
        src_h, src_w = source.shape[0:2]
        # Everything outside the copied region stays zero.
        canvas[:src_h, :src_w, :] = source
        data['image'] = canvas
        data['shape'] = np.array(data['shape'].tolist() + [pad_h, pad_w])
        return data
|