123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import paddle
- import os
- import sys
- __dir__ = os.path.dirname(__file__)
- sys.path.append(__dir__)
- sys.path.append(os.path.join(__dir__, '..'))
- from extract_textpoint_slow import *
- from extract_textpoint_fast import generate_pivot_list_fast, restore_poly
- class PGNet_PostProcess(object):
- # two different post-process
- def __init__(self,
- character_dict_path,
- valid_set,
- score_thresh,
- outs_dict,
- shape_list,
- point_gather_mode=None):
- self.Lexicon_Table = get_dict(character_dict_path)
- self.valid_set = valid_set
- self.score_thresh = score_thresh
- self.outs_dict = outs_dict
- self.shape_list = shape_list
- self.point_gather_mode = point_gather_mode
- def pg_postprocess_fast(self):
- p_score = self.outs_dict['f_score']
- p_border = self.outs_dict['f_border']
- p_char = self.outs_dict['f_char']
- p_direction = self.outs_dict['f_direction']
- if isinstance(p_score, paddle.Tensor):
- p_score = p_score[0].numpy()
- p_border = p_border[0].numpy()
- p_direction = p_direction[0].numpy()
- p_char = p_char[0].numpy()
- else:
- p_score = p_score[0]
- p_border = p_border[0]
- p_direction = p_direction[0]
- p_char = p_char[0]
- src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
- instance_yxs_list, seq_strs = generate_pivot_list_fast(
- p_score,
- p_char,
- p_direction,
- self.Lexicon_Table,
- score_thresh=self.score_thresh,
- point_gather_mode=self.point_gather_mode)
- poly_list, keep_str_list = restore_poly(instance_yxs_list, seq_strs,
- p_border, ratio_w, ratio_h,
- src_w, src_h, self.valid_set)
- data = {
- 'points': poly_list,
- 'texts': keep_str_list,
- }
- return data
- def pg_postprocess_slow(self):
- p_score = self.outs_dict['f_score']
- p_border = self.outs_dict['f_border']
- p_char = self.outs_dict['f_char']
- p_direction = self.outs_dict['f_direction']
- if isinstance(p_score, paddle.Tensor):
- p_score = p_score[0].numpy()
- p_border = p_border[0].numpy()
- p_direction = p_direction[0].numpy()
- p_char = p_char[0].numpy()
- else:
- p_score = p_score[0]
- p_border = p_border[0]
- p_direction = p_direction[0]
- p_char = p_char[0]
- src_h, src_w, ratio_h, ratio_w = self.shape_list[0]
- is_curved = self.valid_set == "totaltext"
- char_seq_idx_set, instance_yxs_list = generate_pivot_list_slow(
- p_score,
- p_char,
- p_direction,
- score_thresh=self.score_thresh,
- is_backbone=True,
- is_curved=is_curved)
- seq_strs = []
- for char_idx_set in char_seq_idx_set:
- pr_str = ''.join([self.Lexicon_Table[pos] for pos in char_idx_set])
- seq_strs.append(pr_str)
- poly_list = []
- keep_str_list = []
- all_point_list = []
- all_point_pair_list = []
- for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
- if len(yx_center_line) == 1:
- yx_center_line.append(yx_center_line[-1])
- offset_expand = 1.0
- if self.valid_set == 'totaltext':
- offset_expand = 1.2
- point_pair_list = []
- for batch_id, y, x in yx_center_line:
- offset = p_border[:, y, x].reshape(2, 2)
- if offset_expand != 1.0:
- offset_length = np.linalg.norm(
- offset, axis=1, keepdims=True)
- expand_length = np.clip(
- offset_length * (offset_expand - 1),
- a_min=0.5,
- a_max=3.0)
- offset_detal = offset / offset_length * expand_length
- offset = offset + offset_detal
- ori_yx = np.array([y, x], dtype=np.float32)
- point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
- [ratio_w, ratio_h]).reshape(-1, 2)
- point_pair_list.append(point_pair)
- all_point_list.append([
- int(round(x * 4.0 / ratio_w)),
- int(round(y * 4.0 / ratio_h))
- ])
- all_point_pair_list.append(point_pair.round().astype(np.int32)
- .tolist())
- detected_poly, pair_length_info = point_pair2poly(point_pair_list)
- detected_poly = expand_poly_along_width(
- detected_poly, shrink_ratio_of_width=0.2)
- detected_poly[:, 0] = np.clip(
- detected_poly[:, 0], a_min=0, a_max=src_w)
- detected_poly[:, 1] = np.clip(
- detected_poly[:, 1], a_min=0, a_max=src_h)
- if len(keep_str) < 2:
- continue
- keep_str_list.append(keep_str)
- detected_poly = np.round(detected_poly).astype('int32')
- if self.valid_set == 'partvgg':
- middle_point = len(detected_poly) // 2
- detected_poly = detected_poly[
- [0, middle_point - 1, middle_point, -1], :]
- poly_list.append(detected_poly)
- elif self.valid_set == 'totaltext':
- poly_list.append(detected_poly)
- else:
- print('--> Not supported format.')
- exit(-1)
- data = {
- 'points': poly_list,
- 'texts': keep_str_list,
- }
- return data
|