123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """Contains various CTC decoders."""
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import cv2
- import math
- import numpy as np
- from itertools import groupby
- from skimage.morphology._skeletonize import thin
def get_dict(character_dict_path):
    """Load a character dictionary file as a list of single characters.

    Each line of the file is decoded as UTF-8 and stripped of surrounding
    newline characters. The cleaned lines are concatenated and the result is
    split into individual characters.

    Args:
        character_dict_path: path of the dictionary file.

    Returns:
        list of one-character strings, in file order.
    """
    with open(character_dict_path, "rb") as dict_file:
        raw_lines = dict_file.readlines()
    # Bytes are read untranslated, so "\r\n" endings reach the strip calls
    # intact; the second strip removes the leftover "\r".
    cleaned_lines = [
        raw.decode('utf-8').strip("\n").strip("\r\n") for raw in raw_lines
    ]
    return list("".join(cleaned_lines))
def softmax(logits):
    """Row-wise softmax.

    logits: N x d array; each row is normalized independently.
    The row maximum is subtracted before exponentiation so large logits do
    not overflow.
    """
    shifted = logits - np.max(logits, axis=1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=1, keepdims=True)
def get_keep_pos_idxs(labels, remove_blank=None):
    """Collapse runs of repeated labels and locate the middle of each kept run.

    Consecutive equal labels are grouped (CTC-style de-duplication). For
    every group whose label differs from ``remove_blank``, the label is kept
    and the index of the middle element of its run is recorded.

    Args:
        labels: sequence of label ids.
        remove_blank: label id whose runs are dropped from the output, or
            None to keep every run. (Inside this function the parameter
            shadows the module-level ``remove_blank`` helper.)

    Returns:
        (keep_char_idx_list, keep_pos_idx_list): the kept labels and, for
        each one, the index into ``labels`` of the middle of its run
        (``run_start + run_len // 2``).
    """
    keep_pos_idx_list = []
    keep_char_idx_list = []
    # Running start index of the current run; replaces re-summing all
    # previous run lengths on every group (which was quadratic).
    run_start = 0
    for label, run in groupby(labels):
        run_len = sum(1 for _ in run)
        if label != remove_blank:
            keep_pos_idx_list.append(int(run_start + run_len // 2))
            keep_char_idx_list.append(label)
        run_start += run_len
    return keep_char_idx_list, keep_pos_idx_list
def remove_blank(labels, blank=0):
    """Return a new list holding ``labels`` without any ``blank`` entries."""
    return list(filter(lambda label: label != blank, labels))
def insert_blank(labels, blank=0):
    """Surround every label with ``blank``: [b, l0, b, l1, b, ..., b]."""
    interleaved = [blank]
    for label in labels:
        interleaved.extend((label, blank))
    return interleaved
def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
    """CTC greedy (best path) decoder.

    Takes the arg-max class of every step (axis 1 of ``probs_seq``),
    collapses consecutive repeats, then removes ``blank`` from the result.

    Args:
        probs_seq: 2-D array-like of per-step class scores; axis 1 is the
            class axis.
        blank: id of the CTC blank class.
        keep_blank_in_idxs: when True, runs of ``blank`` still get a
            position in the returned index list; when False they are
            dropped before positions are collected.

    Returns:
        (dst_str, keep_idx_list): decoded class ids with blanks removed,
        and the run-middle positions reported by ``get_keep_pos_idxs``.
    """
    best_path = np.argmax(np.array(probs_seq), axis=1)
    blank_to_drop = blank if not keep_blank_in_idxs else None
    collapsed, keep_idx_list = get_keep_pos_idxs(
        best_path, remove_blank=blank_to_drop)
    return remove_blank(collapsed, blank=blank), keep_idx_list
def _align_gather_points(gather_info):
    """Densify a (y, x) point sequence so consecutive points are at most 1 px apart.

    Between each pair of consecutive points, ``max(|dy|, |dx|) - 1`` evenly
    spaced points are inserted. Interpolated points are cast to the dtype of
    ``np.array(gather_info)`` (integer input is truncated toward zero, as
    ``np.insert`` would do). Duplicate or adjacent consecutive points insert
    nothing.

    Returns:
        list of points as nested Python lists (``ndarray.tolist`` output).
    """
    points = np.array(gather_info)
    if len(points) == 0:
        return points.tolist()
    aligned = []
    for start, end in zip(points[:-1], points[1:]):
        aligned.append(start)
        max_points = int(max(np.abs(start[0] - end[0]), np.abs(start[1] - end[1])))
        if max_points <= 1:
            # Nothing to fill. Skipping here also avoids the 0/0 step that
            # duplicate consecutive points used to produce.
            continue
        step = (start - end) / max_points
        for i in range(1, max_points):
            aligned.append((start - i * step).astype(points.dtype))
    aligned.append(points[-1])
    return np.array(aligned).tolist()


def instance_ctc_greedy_decoder(gather_info,
                                logits_map,
                                pts_num=4,
                                point_gather_mode=None):
    """Greedy-CTC decode the characters along one center-line instance.

    Args:
        gather_info: sequence of (y, x) points along the instance.
        logits_map: 3-D score map indexed as ``logits_map[y, x]``; the last
            class, ``C - 1``, is treated as the CTC blank.
        pts_num: number of points returned in ``keep_gather_list``; must be
            >= 2 (it is used as ``pts_num - 1`` in a division).
        point_gather_mode: 'align' densifies ``gather_info`` first so that
            consecutive points are at most one pixel apart; any other value
            uses the points as given.

    Returns:
        (dst_str, keep_gather_list): decoded class ids (repeats collapsed,
        blanks removed), and ``pts_num`` points taken at a stride of
        ``len(points) // (pts_num - 1)``, always starting with the first point
        and ending with the last one.
    """
    _, _, C = logits_map.shape
    if point_gather_mode == 'align':
        gather_info = _align_gather_points(gather_info)
    ys, xs = zip(*gather_info)
    probs_seq = logits_map[list(ys), list(xs)]
    labels = np.argmax(probs_seq, axis=1)
    blank = C - 1
    dst_str = [k for k, _ in groupby(labels) if k != blank]
    delta = len(gather_info) // (pts_num - 1)
    keep_idx_list = [0] + [delta * (i + 1) for i in range(pts_num - 2)] + [-1]
    keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
    return dst_str, keep_gather_list
def ctc_decoder_for_image(gather_info_list,
                          logits_map,
                          Lexicon_Table,
                          pts_num=6,
                          point_gather_mode=None):
    """Decode every center-line instance of one image, one after another.

    Instances with fewer than ``pts_num`` points are skipped before decoding,
    and instances whose decoded text is shorter than two characters are
    skipped afterwards.

    Args:
        gather_info_list: per-instance sequences of (y, x) points.
        logits_map: score map forwarded to ``instance_ctc_greedy_decoder``.
        Lexicon_Table: indexable table mapping a class id to its character.
        pts_num: minimum instance length, and number of outline points kept
            per instance.
        point_gather_mode: forwarded to ``instance_ctc_greedy_decoder``.

    Returns:
        (decoder_str, decoder_xys): the decoded strings and, for each of them,
        the points returned by ``instance_ctc_greedy_decoder``.
    """
    decoder_str = []
    decoder_xys = []
    for gather_info in gather_info_list:
        if len(gather_info) < pts_num:
            continue
        dst_str, xys_list = instance_ctc_greedy_decoder(
            gather_info,
            logits_map,
            pts_num=pts_num,
            point_gather_mode=point_gather_mode)
        dst_str_readable = ''.join(Lexicon_Table[idx] for idx in dst_str)
        if len(dst_str_readable) < 2:
            continue
        decoder_str.append(dst_str_readable)
        decoder_xys.append(xys_list)
    return decoder_str, decoder_xys
def sort_with_direction(pos_list, f_direction):
    """Order (y, x) points by their projection on the local mean direction.

    f_direction: h x w x 2 direction map, stored as (x, y) per pixel.
    pos_list: [[y, x], [y, x], [y, x] ...]

    Points are first sorted by projection onto the mean direction of all
    points. When there are 16 or more points, each half of that ordering is
    re-sorted by its own mean direction.

    Returns:
        (sorted_point, sorted_direction): list of [y, x] points, and an
        N x 2 array of their (y, x) directions in the same order.
    """
    def project_and_sort(points, directions):
        points = np.array(points).reshape(-1, 2)
        directions = np.array(directions).reshape(-1, 2)
        mean_dir = np.mean(directions, axis=0, keepdims=True)
        order = np.argsort(np.sum(points * mean_dir, axis=1))
        return points[order].tolist(), directions[order].tolist()

    yx = np.array(pos_list).reshape(-1, 2)
    # The map stores (x, y); flip to (y, x) to match the point layout.
    directions_yx = f_direction[yx[:, 0], yx[:, 1]][:, ::-1]
    sorted_point, sorted_direction = project_and_sort(yx, directions_yx)

    if len(sorted_point) >= 16:
        half = len(sorted_point) // 2
        head_pts, head_dirs = project_and_sort(sorted_point[:half],
                                               sorted_direction[:half])
        tail_pts, tail_dirs = project_and_sort(sorted_point[half:],
                                               sorted_direction[half:])
        sorted_point = head_pts + tail_pts
        sorted_direction = head_dirs + tail_dirs
    return sorted_point, np.array(sorted_direction)
def add_id(pos_list, image_id=0):
    """Prefix every point with ``image_id`` for the gather-feature step at inference.

    Returns a new list of ``(image_id, item[0], item[1])`` tuples; ``pos_list``
    itself is left unchanged.
    """
    return [(image_id, point[0], point[1]) for point in pos_list]
def sort_and_expand_with_direction(pos_list, f_direction):
    """Sort center-line points and extend both ends along the text direction.

    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]

    The points are ordered with ``sort_with_direction``. The mean direction
    of the first and last ``max(point_num // 3, 2)`` points then gives a unit
    step for each end: the left end steps against it, the right end along it.
    ``append_num = max(int(mean_len * 0.15), 1)`` steps are taken per side,
    where ``mean_len`` is the average magnitude of the two end directions.
    Rounded candidates that fall outside the map, or were already added on
    that side, are skipped.

    Returns:
        list of points: reversed left extension + sorted points + right
        extension (extension points are (y, x) tuples).
    """
    h, w, _ = f_direction.shape
    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)

    point_num = len(sorted_list)
    sub_direction_len = max(point_num // 3, 2)
    left_direction = point_direction[:sub_direction_len, :]
    right_direction = point_direction[point_num - sub_direction_len:, :]

    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
    left_average_len = np.linalg.norm(left_average_direction)
    left_start = np.array(sorted_list[0])
    left_step = left_average_direction / (left_average_len + 1e-6)

    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
    right_average_len = np.linalg.norm(right_average_direction)
    right_step = right_average_direction / (right_average_len + 1e-6)
    right_start = np.array(sorted_list[-1])

    append_num = max(
        int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
    left_list = []
    right_list = []
    for i in range(append_num):
        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
            'int32').tolist()
        # Negative coordinates are rejected too: as numpy indices they would
        # silently wrap around to the opposite border.
        if 0 <= ly < h and 0 <= lx < w and (ly, lx) not in left_list:
            left_list.append((ly, lx))
        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
            'int32').tolist()
        if 0 <= ry < h and 0 <= rx < w and (ry, rx) not in right_list:
            right_list.append((ry, rx))
    return left_list[::-1] + sorted_list + right_list
def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
    """Sort center-line points and extend each end while it stays inside the TCL mask.

    f_direction: h x w x 2
    pos_list: [[y, x], [y, x], [y, x] ...]
    binary_tcl_map: h x w

    The points are ordered with ``sort_with_direction``. Unit steps for each
    end come from the mean direction of the first and last
    ``max(point_num // 3, 2)`` points. Each side takes up to
    ``2 * max(int(mean_len * 0.15), 1)`` steps. Out-of-map or repeated
    candidates are skipped. A side stops at the first in-map candidate whose
    ``binary_tcl_map`` value is not above 0.5.

    Returns:
        list of points: reversed left extension + sorted points + right
        extension (extension points are (y, x) tuples).
    """
    h, w, _ = f_direction.shape
    sorted_list, point_direction = sort_with_direction(pos_list, f_direction)

    point_num = len(sorted_list)
    sub_direction_len = max(point_num // 3, 2)
    left_direction = point_direction[:sub_direction_len, :]
    right_direction = point_direction[point_num - sub_direction_len:, :]

    left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
    left_average_len = np.linalg.norm(left_average_direction)
    left_start = np.array(sorted_list[0])
    left_step = left_average_direction / (left_average_len + 1e-6)

    right_average_direction = np.mean(right_direction, axis=0, keepdims=True)
    right_average_len = np.linalg.norm(right_average_direction)
    right_step = right_average_direction / (right_average_len + 1e-6)
    right_start = np.array(sorted_list[-1])

    append_num = max(
        int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
    max_append_num = 2 * append_num

    left_list = []
    for i in range(max_append_num):
        ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
            'int32').tolist()
        # Require non-negative coordinates: a negative index would read
        # binary_tcl_map from the opposite border instead of failing.
        if 0 <= ly < h and 0 <= lx < w and (ly, lx) not in left_list:
            if binary_tcl_map[ly, lx] > 0.5:
                left_list.append((ly, lx))
            else:
                break

    right_list = []
    for i in range(max_append_num):
        ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
            'int32').tolist()
        if 0 <= ry < h and 0 <= rx < w and (ry, rx) not in right_list:
            if binary_tcl_map[ry, rx] > 0.5:
                right_list.append((ry, rx))
            else:
                break

    return left_list[::-1] + sorted_list + right_list
def point_pair2poly(point_pair_list):
    """Transfer vertical point pairs into polygon points in clockwise order.

    The polygon visits the first point of every pair in order, then the
    second point of every pair in reverse order.

    Returns:
        numpy array reshaped to (-1, 2).
    """
    forward_side = [pair[0] for pair in point_pair_list]
    backward_side = [pair[1] for pair in reversed(point_pair_list)]
    return np.array(forward_side + backward_side).reshape(-1, 2)
def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
    """Keep the [begin, end] fraction of a quad along its width.

    quad: 4 points (p0, p1, p2, p3); edges p0->p1 and p3->p2 are linearly
    interpolated at the two ratios. Ratios below 0 or above 1 extrapolate,
    which enlarges the quad instead of shrinking it.

    Returns:
        4 x 2 array [p0', p1', p2', p3'] with the same point ordering.
    """
    ratios = np.array([[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
    first_edge = quad[0] + (quad[1] - quad[0]) * ratios
    second_edge = quad[3] + (quad[2] - quad[3]) * ratios
    return np.array([first_edge[0], first_edge[1], second_edge[1], second_edge[0]])
def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
    """Push both short ends of a polygon outward along its width.

    ``poly`` is modified in place and also returned. The left end quad is
    (poly[0], poly[1], poly[-2], poly[-1]); the right end quad is the four
    points around index ``len(poly) // 2``. Each end moves outward by about
    ``shrink_ratio_of_width`` times the distance between its quad's first and
    fourth points. That distance is converted into a width ratio for
    ``shrink_quad_along_width``.
    """
    n = poly.shape[0]
    mid = n // 2

    head = np.array([poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
    head_span = np.linalg.norm(head[0] - head[3])
    head_width = np.linalg.norm(head[0] - head[1]) + 1e-6
    head_ratio = -shrink_ratio_of_width * head_span / head_width
    head_expanded = shrink_quad_along_width(head, head_ratio, 1.0)

    tail = np.array(
        [poly[mid - 2], poly[mid - 1], poly[mid], poly[mid + 1]],
        dtype=np.float32)
    tail_span = np.linalg.norm(tail[0] - tail[3])
    tail_width = np.linalg.norm(tail[0] - tail[1]) + 1e-6
    tail_ratio = 1.0 + shrink_ratio_of_width * tail_span / tail_width
    tail_expanded = shrink_quad_along_width(tail, 0.0, tail_ratio)

    # Both quads were built before any write, so the order below is the only
    # ordering constraint (it matters when indices coincide on tiny polygons).
    poly[0] = head_expanded[0]
    poly[-1] = head_expanded[-1]
    poly[mid - 1] = tail_expanded[1]
    poly[mid] = tail_expanded[2]
    return poly
def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
                 src_h, valid_set):
    """Rebuild text polygons in source-image coordinates.

    For each instance, every center-line point (y, x) is offset by
    ``p_border[:, y, x]`` reshaped to 2 x 2 (enlarged by 1.2 for
    'totaltext'). The result is flipped to (x, y), multiplied by 4.0 and
    divided by ``(ratio_w, ratio_h)``; 4.0 is presumably the feature-map
    stride (TODO confirm). The point pairs become a polygon, which is widened
    at both ends and clipped to [0, src_w] x [0, src_h].

    Args:
        instance_yxs_list: per-instance sequences of (y, x) points.
        seq_strs: decoded string per instance; strings shorter than two
            characters are reported and skipped.
        p_border: border-offset map indexed as ``p_border[:, y, x]``, with
            4 values per pixel (reshaped to 2 x 2).
        ratio_w, ratio_h: divisors applied to x and y respectively.
        src_w, src_h: upper clip bounds for x and y.
        valid_set: 'partvgg' keeps 4 points per polygon (first, the two
            middle points, last); 'totaltext' keeps the full polygon.

    Returns:
        (poly_list, keep_str_list)

    Any other ``valid_set`` (once an instance reaches that point) prints a
    message and calls ``sys.exit(-1)``.
    """
    import sys  # used only for the unsupported-format exit below

    poly_list = []
    keep_str_list = []
    offset_expand = 1.2 if valid_set == 'totaltext' else 1.0
    scale = np.array([ratio_w, ratio_h]).reshape(-1, 2)
    for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
        if len(keep_str) < 2:
            print('--> too short, {}'.format(keep_str))
            continue

        point_pair_list = []
        for y, x in yx_center_line:
            offset = p_border[:, y, x].reshape(2, 2) * offset_expand
            ori_yx = np.array([y, x], dtype=np.float32)
            # (y, x) -> (x, y), then rescale to the source image.
            point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / scale
            point_pair_list.append(point_pair)

        detected_poly = point_pair2poly(point_pair_list)
        detected_poly = expand_poly_along_width(
            detected_poly, shrink_ratio_of_width=0.2)
        detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
        detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)

        keep_str_list.append(keep_str)
        if valid_set == 'partvgg':
            middle_point = len(detected_poly) // 2
            detected_poly = detected_poly[
                [0, middle_point - 1, middle_point, -1], :]
            poly_list.append(detected_poly)
        elif valid_set == 'totaltext':
            poly_list.append(detected_poly)
        else:
            print('--> Not supported format.')
            # sys.exit instead of the site-module `exit`, which is meant for
            # the interactive shell and may be absent (e.g. under `python -S`).
            sys.exit(-1)
    return poly_list, keep_str_list
def generate_pivot_list_fast(p_score,
                             p_char_maps,
                             f_direction,
                             Lexicon_Table,
                             score_thresh=0.5,
                             point_gather_mode=None):
    """Return center points and end points of TCL instances; filter with the char maps.

    Steps:
      1. Threshold ``p_score[0]`` at ``score_thresh`` and thin it to a skeleton.
      2. Split the skeleton into 8-connected components. Components with
         label >= 1 and at least 3 pixels are ordered and extended by
         ``sort_and_expand_with_direction_v2``.
      3. Decode the instances with ``ctc_decoder_for_image`` on the
         channel-last char maps.

    Returns:
        (keep_yxs_list, decoded_str)
    """
    score_map = p_score[0]
    direction_map = f_direction.transpose(1, 2, 0)
    tcl_mask = (score_map > score_thresh) * 1.0
    skeleton = thin(tcl_mask.astype(np.uint8))
    instance_count, instance_label_map = cv2.connectedComponents(
        skeleton.astype(np.uint8), connectivity=8)

    # Get TCL instances; labels start at 1 (label 0 is not treated as one).
    all_pos_yxs = []
    for instance_id in range(1, instance_count):
        ys, xs = np.where(instance_label_map == instance_id)
        instance_points = list(zip(ys, xs))
        if len(instance_points) < 3:
            continue
        all_pos_yxs.append(sort_and_expand_with_direction_v2(
            instance_points, direction_map, tcl_mask))

    decoded_str, keep_yxs_list = ctc_decoder_for_image(
        all_pos_yxs,
        logits_map=p_char_maps.transpose([1, 2, 0]),
        Lexicon_Table=Lexicon_Table,
        point_gather_mode=point_gather_mode)
    return keep_yxs_list, decoded_str
def extract_main_direction(pos_list, f_direction):
    """Return the normalized mean direction of the given points.

    f_direction: h x w x 2, stored as (x, y) per pixel
    pos_list: [[y, x], [y, x], [y, x] ...]

    Returns:
        1 x 2 array in (y, x) order, divided by its norm (plus 1e-6 to avoid
        division by zero).
    """
    yx = np.array(pos_list)
    directions_yx = f_direction[yx[:, 0], yx[:, 1]][:, ::-1]
    mean_direction = np.mean(directions_yx, axis=0, keepdims=True)
    return mean_direction / (np.linalg.norm(mean_direction) + 1e-6)
def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
    """Sort (id, y, x) points by projection onto their mean direction (single pass).

    f_direction: h x w x 2, stored as (x, y) per pixel
    pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]

    Returns:
        list of [id, y, x] rows in ascending projection order.
    """
    full_rows = np.array(pos_list).reshape(-1, 3)
    yx = full_rows[:, 1:]
    directions_yx = f_direction[yx[:, 0], yx[:, 1]][:, ::-1]
    mean_direction = np.mean(directions_yx, axis=0, keepdims=True)
    projection = np.sum(yx * mean_direction, axis=1)
    return full_rows[np.argsort(projection)].tolist()
def sort_by_direction_with_image_id(pos_list, f_direction):
    """Order (image_id, y, x) points along the local text direction.

    f_direction: h x w x 2 direction map, stored as (x, y) per pixel.
    pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]. The first column is
        carried through unchanged; only (y, x) is used for ordering.

    Points are sorted by projection onto the mean direction of all points.
    When there are 16 or more points, each half of that ordering is re-sorted
    by its own mean direction.

    Returns:
        list of [id, y, x] rows in sorted order.
    """
    def sort_part_with_direction(pos_list_full, point_direction):
        pos_list_full = np.array(pos_list_full).reshape(-1, 3)
        pos_list = pos_list_full[:, 1:]
        point_direction = np.array(point_direction).reshape(-1, 2)
        average_direction = np.mean(point_direction, axis=0, keepdims=True)
        pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
        order = np.argsort(pos_proj_leng)  # computed once, applied to both
        return pos_list_full[order].tolist(), point_direction[order].tolist()

    pos_list = np.array(pos_list).reshape(-1, 3)
    point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]]  # x, y
    point_direction = point_direction[:, ::-1]  # x, y -> y, x
    sorted_point, sorted_direction = sort_part_with_direction(pos_list,
                                                              point_direction)

    point_num = len(sorted_point)
    if point_num >= 16:
        middle_num = point_num // 2
        sorted_first_part_point, _ = sort_part_with_direction(
            sorted_point[:middle_num], sorted_direction[:middle_num])
        sorted_last_part_point, _ = sort_part_with_direction(
            sorted_point[middle_num:], sorted_direction[middle_num:])
        sorted_point = sorted_first_part_point + sorted_last_part_point
    return sorted_point
|