# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Contains various CTC decoders."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
from itertools import groupby

import cv2
import numpy as np
from skimage.morphology._skeletonize import thin


  23. def get_dict(character_dict_path):
  24. character_str = ""
  25. with open(character_dict_path, "rb") as fin:
  26. lines = fin.readlines()
  27. for line in lines:
  28. line = line.decode('utf-8').strip("\n").strip("\r\n")
  29. character_str += line
  30. dict_character = list(character_str)
  31. return dict_character
  32. def softmax(logits):
  33. """
  34. logits: N x d
  35. """
  36. max_value = np.max(logits, axis=1, keepdims=True)
  37. exp = np.exp(logits - max_value)
  38. exp_sum = np.sum(exp, axis=1, keepdims=True)
  39. dist = exp / exp_sum
  40. return dist
  41. def get_keep_pos_idxs(labels, remove_blank=None):
  42. """
  43. Remove duplicate and get pos idxs of keep items.
  44. The value of keep_blank should be [None, 95].
  45. """
  46. duplicate_len_list = []
  47. keep_pos_idx_list = []
  48. keep_char_idx_list = []
  49. for k, v_ in groupby(labels):
  50. current_len = len(list(v_))
  51. if k != remove_blank:
  52. current_idx = int(sum(duplicate_len_list) + current_len // 2)
  53. keep_pos_idx_list.append(current_idx)
  54. keep_char_idx_list.append(k)
  55. duplicate_len_list.append(current_len)
  56. return keep_char_idx_list, keep_pos_idx_list
  57. def remove_blank(labels, blank=0):
  58. new_labels = [x for x in labels if x != blank]
  59. return new_labels
  60. def insert_blank(labels, blank=0):
  61. new_labels = [blank]
  62. for l in labels:
  63. new_labels += [l, blank]
  64. return new_labels
  65. def ctc_greedy_decoder(probs_seq, blank=95, keep_blank_in_idxs=True):
  66. """
  67. CTC greedy (best path) decoder.
  68. """
  69. raw_str = np.argmax(np.array(probs_seq), axis=1)
  70. remove_blank_in_pos = None if keep_blank_in_idxs else blank
  71. dedup_str, keep_idx_list = get_keep_pos_idxs(
  72. raw_str, remove_blank=remove_blank_in_pos)
  73. dst_str = remove_blank(dedup_str, blank=blank)
  74. return dst_str, keep_idx_list
  75. def instance_ctc_greedy_decoder(gather_info,
  76. logits_map,
  77. pts_num=4,
  78. point_gather_mode=None):
  79. _, _, C = logits_map.shape
  80. if point_gather_mode == 'align':
  81. insert_num = 0
  82. gather_info = np.array(gather_info)
  83. length = len(gather_info) - 1
  84. for index in range(length):
  85. stride_y = np.abs(gather_info[index + insert_num][0] - gather_info[
  86. index + 1 + insert_num][0])
  87. stride_x = np.abs(gather_info[index + insert_num][1] - gather_info[
  88. index + 1 + insert_num][1])
  89. max_points = int(max(stride_x, stride_y))
  90. stride = (gather_info[index + insert_num] -
  91. gather_info[index + 1 + insert_num]) / (max_points)
  92. insert_num_temp = max_points - 1
  93. for i in range(int(insert_num_temp)):
  94. insert_value = gather_info[index + insert_num] - (i + 1
  95. ) * stride
  96. insert_index = index + i + 1 + insert_num
  97. gather_info = np.insert(
  98. gather_info, insert_index, insert_value, axis=0)
  99. insert_num += insert_num_temp
  100. gather_info = gather_info.tolist()
  101. else:
  102. pass
  103. ys, xs = zip(*gather_info)
  104. logits_seq = logits_map[list(ys), list(xs)]
  105. probs_seq = logits_seq
  106. labels = np.argmax(probs_seq, axis=1)
  107. dst_str = [k for k, v_ in groupby(labels) if k != C - 1]
  108. detal = len(gather_info) // (pts_num - 1)
  109. keep_idx_list = [0] + [detal * (i + 1) for i in range(pts_num - 2)] + [-1]
  110. keep_gather_list = [gather_info[idx] for idx in keep_idx_list]
  111. return dst_str, keep_gather_list
  112. def ctc_decoder_for_image(gather_info_list,
  113. logits_map,
  114. Lexicon_Table,
  115. pts_num=6,
  116. point_gather_mode=None):
  117. """
  118. CTC decoder using multiple processes.
  119. """
  120. decoder_str = []
  121. decoder_xys = []
  122. for gather_info in gather_info_list:
  123. if len(gather_info) < pts_num:
  124. continue
  125. dst_str, xys_list = instance_ctc_greedy_decoder(
  126. gather_info,
  127. logits_map,
  128. pts_num=pts_num,
  129. point_gather_mode=point_gather_mode)
  130. dst_str_readable = ''.join([Lexicon_Table[idx] for idx in dst_str])
  131. if len(dst_str_readable) < 2:
  132. continue
  133. decoder_str.append(dst_str_readable)
  134. decoder_xys.append(xys_list)
  135. return decoder_str, decoder_xys
  136. def sort_with_direction(pos_list, f_direction):
  137. """
  138. f_direction: h x w x 2
  139. pos_list: [[y, x], [y, x], [y, x] ...]
  140. """
  141. def sort_part_with_direction(pos_list, point_direction):
  142. pos_list = np.array(pos_list).reshape(-1, 2)
  143. point_direction = np.array(point_direction).reshape(-1, 2)
  144. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  145. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  146. sorted_list = pos_list[np.argsort(pos_proj_leng)].tolist()
  147. sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
  148. return sorted_list, sorted_direction
  149. pos_list = np.array(pos_list).reshape(-1, 2)
  150. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
  151. point_direction = point_direction[:, ::-1] # x, y -> y, x
  152. sorted_point, sorted_direction = sort_part_with_direction(pos_list,
  153. point_direction)
  154. point_num = len(sorted_point)
  155. if point_num >= 16:
  156. middle_num = point_num // 2
  157. first_part_point = sorted_point[:middle_num]
  158. first_point_direction = sorted_direction[:middle_num]
  159. sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
  160. first_part_point, first_point_direction)
  161. last_part_point = sorted_point[middle_num:]
  162. last_point_direction = sorted_direction[middle_num:]
  163. sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
  164. last_part_point, last_point_direction)
  165. sorted_point = sorted_fist_part_point + sorted_last_part_point
  166. sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
  167. return sorted_point, np.array(sorted_direction)
  168. def add_id(pos_list, image_id=0):
  169. """
  170. Add id for gather feature, for inference.
  171. """
  172. new_list = []
  173. for item in pos_list:
  174. new_list.append((image_id, item[0], item[1]))
  175. return new_list
  176. def sort_and_expand_with_direction(pos_list, f_direction):
  177. """
  178. f_direction: h x w x 2
  179. pos_list: [[y, x], [y, x], [y, x] ...]
  180. """
  181. h, w, _ = f_direction.shape
  182. sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
  183. point_num = len(sorted_list)
  184. sub_direction_len = max(point_num // 3, 2)
  185. left_direction = point_direction[:sub_direction_len, :]
  186. right_dirction = point_direction[point_num - sub_direction_len:, :]
  187. left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
  188. left_average_len = np.linalg.norm(left_average_direction)
  189. left_start = np.array(sorted_list[0])
  190. left_step = left_average_direction / (left_average_len + 1e-6)
  191. right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
  192. right_average_len = np.linalg.norm(right_average_direction)
  193. right_step = right_average_direction / (right_average_len + 1e-6)
  194. right_start = np.array(sorted_list[-1])
  195. append_num = max(
  196. int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
  197. left_list = []
  198. right_list = []
  199. for i in range(append_num):
  200. ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
  201. 'int32').tolist()
  202. if ly < h and lx < w and (ly, lx) not in left_list:
  203. left_list.append((ly, lx))
  204. ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
  205. 'int32').tolist()
  206. if ry < h and rx < w and (ry, rx) not in right_list:
  207. right_list.append((ry, rx))
  208. all_list = left_list[::-1] + sorted_list + right_list
  209. return all_list
  210. def sort_and_expand_with_direction_v2(pos_list, f_direction, binary_tcl_map):
  211. """
  212. f_direction: h x w x 2
  213. pos_list: [[y, x], [y, x], [y, x] ...]
  214. binary_tcl_map: h x w
  215. """
  216. h, w, _ = f_direction.shape
  217. sorted_list, point_direction = sort_with_direction(pos_list, f_direction)
  218. point_num = len(sorted_list)
  219. sub_direction_len = max(point_num // 3, 2)
  220. left_direction = point_direction[:sub_direction_len, :]
  221. right_dirction = point_direction[point_num - sub_direction_len:, :]
  222. left_average_direction = -np.mean(left_direction, axis=0, keepdims=True)
  223. left_average_len = np.linalg.norm(left_average_direction)
  224. left_start = np.array(sorted_list[0])
  225. left_step = left_average_direction / (left_average_len + 1e-6)
  226. right_average_direction = np.mean(right_dirction, axis=0, keepdims=True)
  227. right_average_len = np.linalg.norm(right_average_direction)
  228. right_step = right_average_direction / (right_average_len + 1e-6)
  229. right_start = np.array(sorted_list[-1])
  230. append_num = max(
  231. int((left_average_len + right_average_len) / 2.0 * 0.15), 1)
  232. max_append_num = 2 * append_num
  233. left_list = []
  234. right_list = []
  235. for i in range(max_append_num):
  236. ly, lx = np.round(left_start + left_step * (i + 1)).flatten().astype(
  237. 'int32').tolist()
  238. if ly < h and lx < w and (ly, lx) not in left_list:
  239. if binary_tcl_map[ly, lx] > 0.5:
  240. left_list.append((ly, lx))
  241. else:
  242. break
  243. for i in range(max_append_num):
  244. ry, rx = np.round(right_start + right_step * (i + 1)).flatten().astype(
  245. 'int32').tolist()
  246. if ry < h and rx < w and (ry, rx) not in right_list:
  247. if binary_tcl_map[ry, rx] > 0.5:
  248. right_list.append((ry, rx))
  249. else:
  250. break
  251. all_list = left_list[::-1] + sorted_list + right_list
  252. return all_list
  253. def point_pair2poly(point_pair_list):
  254. """
  255. Transfer vertical point_pairs into poly point in clockwise.
  256. """
  257. point_num = len(point_pair_list) * 2
  258. point_list = [0] * point_num
  259. for idx, point_pair in enumerate(point_pair_list):
  260. point_list[idx] = point_pair[0]
  261. point_list[point_num - 1 - idx] = point_pair[1]
  262. return np.array(point_list).reshape(-1, 2)
  263. def shrink_quad_along_width(quad, begin_width_ratio=0., end_width_ratio=1.):
  264. ratio_pair = np.array(
  265. [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
  266. p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
  267. p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
  268. return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
  269. def expand_poly_along_width(poly, shrink_ratio_of_width=0.3):
  270. """
  271. expand poly along width.
  272. """
  273. point_num = poly.shape[0]
  274. left_quad = np.array(
  275. [poly[0], poly[1], poly[-2], poly[-1]], dtype=np.float32)
  276. left_ratio = -shrink_ratio_of_width * np.linalg.norm(left_quad[0] - left_quad[3]) / \
  277. (np.linalg.norm(left_quad[0] - left_quad[1]) + 1e-6)
  278. left_quad_expand = shrink_quad_along_width(left_quad, left_ratio, 1.0)
  279. right_quad = np.array(
  280. [
  281. poly[point_num // 2 - 2], poly[point_num // 2 - 1],
  282. poly[point_num // 2], poly[point_num // 2 + 1]
  283. ],
  284. dtype=np.float32)
  285. right_ratio = 1.0 + shrink_ratio_of_width * np.linalg.norm(right_quad[0] - right_quad[3]) / \
  286. (np.linalg.norm(right_quad[0] - right_quad[1]) + 1e-6)
  287. right_quad_expand = shrink_quad_along_width(right_quad, 0.0, right_ratio)
  288. poly[0] = left_quad_expand[0]
  289. poly[-1] = left_quad_expand[-1]
  290. poly[point_num // 2 - 1] = right_quad_expand[1]
  291. poly[point_num // 2] = right_quad_expand[2]
  292. return poly
  293. def restore_poly(instance_yxs_list, seq_strs, p_border, ratio_w, ratio_h, src_w,
  294. src_h, valid_set):
  295. poly_list = []
  296. keep_str_list = []
  297. for yx_center_line, keep_str in zip(instance_yxs_list, seq_strs):
  298. if len(keep_str) < 2:
  299. print('--> too short, {}'.format(keep_str))
  300. continue
  301. offset_expand = 1.0
  302. if valid_set == 'totaltext':
  303. offset_expand = 1.2
  304. point_pair_list = []
  305. for y, x in yx_center_line:
  306. offset = p_border[:, y, x].reshape(2, 2) * offset_expand
  307. ori_yx = np.array([y, x], dtype=np.float32)
  308. point_pair = (ori_yx + offset)[:, ::-1] * 4.0 / np.array(
  309. [ratio_w, ratio_h]).reshape(-1, 2)
  310. point_pair_list.append(point_pair)
  311. detected_poly = point_pair2poly(point_pair_list)
  312. detected_poly = expand_poly_along_width(
  313. detected_poly, shrink_ratio_of_width=0.2)
  314. detected_poly[:, 0] = np.clip(detected_poly[:, 0], a_min=0, a_max=src_w)
  315. detected_poly[:, 1] = np.clip(detected_poly[:, 1], a_min=0, a_max=src_h)
  316. keep_str_list.append(keep_str)
  317. if valid_set == 'partvgg':
  318. middle_point = len(detected_poly) // 2
  319. detected_poly = detected_poly[
  320. [0, middle_point - 1, middle_point, -1], :]
  321. poly_list.append(detected_poly)
  322. elif valid_set == 'totaltext':
  323. poly_list.append(detected_poly)
  324. else:
  325. print('--> Not supported format.')
  326. exit(-1)
  327. return poly_list, keep_str_list
  328. def generate_pivot_list_fast(p_score,
  329. p_char_maps,
  330. f_direction,
  331. Lexicon_Table,
  332. score_thresh=0.5,
  333. point_gather_mode=None):
  334. """
  335. return center point and end point of TCL instance; filter with the char maps;
  336. """
  337. p_score = p_score[0]
  338. f_direction = f_direction.transpose(1, 2, 0)
  339. p_tcl_map = (p_score > score_thresh) * 1.0
  340. skeleton_map = thin(p_tcl_map.astype(np.uint8))
  341. instance_count, instance_label_map = cv2.connectedComponents(
  342. skeleton_map.astype(np.uint8), connectivity=8)
  343. # get TCL Instance
  344. all_pos_yxs = []
  345. if instance_count > 0:
  346. for instance_id in range(1, instance_count):
  347. pos_list = []
  348. ys, xs = np.where(instance_label_map == instance_id)
  349. pos_list = list(zip(ys, xs))
  350. if len(pos_list) < 3:
  351. continue
  352. pos_list_sorted = sort_and_expand_with_direction_v2(
  353. pos_list, f_direction, p_tcl_map)
  354. all_pos_yxs.append(pos_list_sorted)
  355. p_char_maps = p_char_maps.transpose([1, 2, 0])
  356. decoded_str, keep_yxs_list = ctc_decoder_for_image(
  357. all_pos_yxs,
  358. logits_map=p_char_maps,
  359. Lexicon_Table=Lexicon_Table,
  360. point_gather_mode=point_gather_mode)
  361. return keep_yxs_list, decoded_str
  362. def extract_main_direction(pos_list, f_direction):
  363. """
  364. f_direction: h x w x 2
  365. pos_list: [[y, x], [y, x], [y, x] ...]
  366. """
  367. pos_list = np.array(pos_list)
  368. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]]
  369. point_direction = point_direction[:, ::-1] # x, y -> y, x
  370. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  371. average_direction = average_direction / (
  372. np.linalg.norm(average_direction) + 1e-6)
  373. return average_direction
  374. def sort_by_direction_with_image_id_deprecated(pos_list, f_direction):
  375. """
  376. f_direction: h x w x 2
  377. pos_list: [[id, y, x], [id, y, x], [id, y, x] ...]
  378. """
  379. pos_list_full = np.array(pos_list).reshape(-1, 3)
  380. pos_list = pos_list_full[:, 1:]
  381. point_direction = f_direction[pos_list[:, 0], pos_list[:, 1]] # x, y
  382. point_direction = point_direction[:, ::-1] # x, y -> y, x
  383. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  384. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  385. sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
  386. return sorted_list
  387. def sort_by_direction_with_image_id(pos_list, f_direction):
  388. """
  389. f_direction: h x w x 2
  390. pos_list: [[y, x], [y, x], [y, x] ...]
  391. """
  392. def sort_part_with_direction(pos_list_full, point_direction):
  393. pos_list_full = np.array(pos_list_full).reshape(-1, 3)
  394. pos_list = pos_list_full[:, 1:]
  395. point_direction = np.array(point_direction).reshape(-1, 2)
  396. average_direction = np.mean(point_direction, axis=0, keepdims=True)
  397. pos_proj_leng = np.sum(pos_list * average_direction, axis=1)
  398. sorted_list = pos_list_full[np.argsort(pos_proj_leng)].tolist()
  399. sorted_direction = point_direction[np.argsort(pos_proj_leng)].tolist()
  400. return sorted_list, sorted_direction
  401. pos_list = np.array(pos_list).reshape(-1, 3)
  402. point_direction = f_direction[pos_list[:, 1], pos_list[:, 2]] # x, y
  403. point_direction = point_direction[:, ::-1] # x, y -> y, x
  404. sorted_point, sorted_direction = sort_part_with_direction(pos_list,
  405. point_direction)
  406. point_num = len(sorted_point)
  407. if point_num >= 16:
  408. middle_num = point_num // 2
  409. first_part_point = sorted_point[:middle_num]
  410. first_point_direction = sorted_direction[:middle_num]
  411. sorted_fist_part_point, sorted_fist_part_direction = sort_part_with_direction(
  412. first_part_point, first_point_direction)
  413. last_part_point = sorted_point[middle_num:]
  414. last_point_direction = sorted_direction[middle_num:]
  415. sorted_last_part_point, sorted_last_part_direction = sort_part_with_direction(
  416. last_part_point, last_point_direction)
  417. sorted_point = sorted_fist_part_point + sorted_last_part_point
  418. sorted_direction = sorted_fist_part_direction + sorted_last_part_direction
  419. return sorted_point