# drrg_postprocess.py
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/postprocess/drrg_postprocessor.py
"""
import functools
import operator
from collections import deque

import cv2
import numpy as np
import paddle
from numpy.linalg import norm
  24. class Node:
  25. def __init__(self, ind):
  26. self.__ind = ind
  27. self.__links = set()
  28. @property
  29. def ind(self):
  30. return self.__ind
  31. @property
  32. def links(self):
  33. return set(self.__links)
  34. def add_link(self, link_node):
  35. self.__links.add(link_node)
  36. link_node.__links.add(self)
  37. def graph_propagation(edges, scores, text_comps, edge_len_thr=50.):
  38. assert edges.ndim == 2
  39. assert edges.shape[1] == 2
  40. assert edges.shape[0] == scores.shape[0]
  41. assert text_comps.ndim == 2
  42. assert isinstance(edge_len_thr, float)
  43. edges = np.sort(edges, axis=1)
  44. score_dict = {}
  45. for i, edge in enumerate(edges):
  46. if text_comps is not None:
  47. box1 = text_comps[edge[0], :8].reshape(4, 2)
  48. box2 = text_comps[edge[1], :8].reshape(4, 2)
  49. center1 = np.mean(box1, axis=0)
  50. center2 = np.mean(box2, axis=0)
  51. distance = norm(center1 - center2)
  52. if distance > edge_len_thr:
  53. scores[i] = 0
  54. if (edge[0], edge[1]) in score_dict:
  55. score_dict[edge[0], edge[1]] = 0.5 * (
  56. score_dict[edge[0], edge[1]] + scores[i])
  57. else:
  58. score_dict[edge[0], edge[1]] = scores[i]
  59. nodes = np.sort(np.unique(edges.flatten()))
  60. mapping = -1 * np.ones((np.max(nodes) + 1), dtype=np.int)
  61. mapping[nodes] = np.arange(nodes.shape[0])
  62. order_inds = mapping[edges]
  63. vertices = [Node(node) for node in nodes]
  64. for ind in order_inds:
  65. vertices[ind[0]].add_link(vertices[ind[1]])
  66. return vertices, score_dict
  67. def connected_components(nodes, score_dict, link_thr):
  68. assert isinstance(nodes, list)
  69. assert all([isinstance(node, Node) for node in nodes])
  70. assert isinstance(score_dict, dict)
  71. assert isinstance(link_thr, float)
  72. clusters = []
  73. nodes = set(nodes)
  74. while nodes:
  75. node = nodes.pop()
  76. cluster = {node}
  77. node_queue = [node]
  78. while node_queue:
  79. node = node_queue.pop(0)
  80. neighbors = set([
  81. neighbor for neighbor in node.links
  82. if score_dict[tuple(sorted([node.ind, neighbor.ind]))] >=
  83. link_thr
  84. ])
  85. neighbors.difference_update(cluster)
  86. nodes.difference_update(neighbors)
  87. cluster.update(neighbors)
  88. node_queue.extend(neighbors)
  89. clusters.append(list(cluster))
  90. return clusters
  91. def clusters2labels(clusters, num_nodes):
  92. assert isinstance(clusters, list)
  93. assert all([isinstance(cluster, list) for cluster in clusters])
  94. assert all(
  95. [isinstance(node, Node) for cluster in clusters for node in cluster])
  96. assert isinstance(num_nodes, int)
  97. node_labels = np.zeros(num_nodes)
  98. for cluster_ind, cluster in enumerate(clusters):
  99. for node in cluster:
  100. node_labels[node.ind] = cluster_ind
  101. return node_labels
  102. def remove_single(text_comps, comp_pred_labels):
  103. assert text_comps.ndim == 2
  104. assert text_comps.shape[0] == comp_pred_labels.shape[0]
  105. single_flags = np.zeros_like(comp_pred_labels)
  106. pred_labels = np.unique(comp_pred_labels)
  107. for label in pred_labels:
  108. current_label_flag = (comp_pred_labels == label)
  109. if np.sum(current_label_flag) == 1:
  110. single_flags[np.where(current_label_flag)[0][0]] = 1
  111. keep_ind = [i for i in range(len(comp_pred_labels)) if not single_flags[i]]
  112. filtered_text_comps = text_comps[keep_ind, :]
  113. filtered_labels = comp_pred_labels[keep_ind]
  114. return filtered_text_comps, filtered_labels
  115. def norm2(point1, point2):
  116. return ((point1[0] - point2[0])**2 + (point1[1] - point2[1])**2)**0.5
  117. def min_connect_path(points):
  118. assert isinstance(points, list)
  119. assert all([isinstance(point, list) for point in points])
  120. assert all([isinstance(coord, int) for point in points for coord in point])
  121. points_queue = points.copy()
  122. shortest_path = []
  123. current_edge = [[], []]
  124. edge_dict0 = {}
  125. edge_dict1 = {}
  126. current_edge[0] = points_queue[0]
  127. current_edge[1] = points_queue[0]
  128. points_queue.remove(points_queue[0])
  129. while points_queue:
  130. for point in points_queue:
  131. length0 = norm2(point, current_edge[0])
  132. edge_dict0[length0] = [point, current_edge[0]]
  133. length1 = norm2(current_edge[1], point)
  134. edge_dict1[length1] = [current_edge[1], point]
  135. key0 = min(edge_dict0.keys())
  136. key1 = min(edge_dict1.keys())
  137. if key0 <= key1:
  138. start = edge_dict0[key0][0]
  139. end = edge_dict0[key0][1]
  140. shortest_path.insert(0, [points.index(start), points.index(end)])
  141. points_queue.remove(start)
  142. current_edge[0] = start
  143. else:
  144. start = edge_dict1[key1][0]
  145. end = edge_dict1[key1][1]
  146. shortest_path.append([points.index(start), points.index(end)])
  147. points_queue.remove(end)
  148. current_edge[1] = end
  149. edge_dict0 = {}
  150. edge_dict1 = {}
  151. shortest_path = functools.reduce(operator.concat, shortest_path)
  152. shortest_path = sorted(set(shortest_path), key=shortest_path.index)
  153. return shortest_path
  154. def in_contour(cont, point):
  155. x, y = point
  156. is_inner = cv2.pointPolygonTest(cont, (int(x), int(y)), False) > 0.5
  157. return is_inner
  158. def fix_corner(top_line, bot_line, start_box, end_box):
  159. assert isinstance(top_line, list)
  160. assert all(isinstance(point, list) for point in top_line)
  161. assert isinstance(bot_line, list)
  162. assert all(isinstance(point, list) for point in bot_line)
  163. assert start_box.shape == end_box.shape == (4, 2)
  164. contour = np.array(top_line + bot_line[::-1])
  165. start_left_mid = (start_box[0] + start_box[3]) / 2
  166. start_right_mid = (start_box[1] + start_box[2]) / 2
  167. end_left_mid = (end_box[0] + end_box[3]) / 2
  168. end_right_mid = (end_box[1] + end_box[2]) / 2
  169. if not in_contour(contour, start_left_mid):
  170. top_line.insert(0, start_box[0].tolist())
  171. bot_line.insert(0, start_box[3].tolist())
  172. elif not in_contour(contour, start_right_mid):
  173. top_line.insert(0, start_box[1].tolist())
  174. bot_line.insert(0, start_box[2].tolist())
  175. if not in_contour(contour, end_left_mid):
  176. top_line.append(end_box[0].tolist())
  177. bot_line.append(end_box[3].tolist())
  178. elif not in_contour(contour, end_right_mid):
  179. top_line.append(end_box[1].tolist())
  180. bot_line.append(end_box[2].tolist())
  181. return top_line, bot_line
  182. def comps2boundaries(text_comps, comp_pred_labels):
  183. assert text_comps.ndim == 2
  184. assert len(text_comps) == len(comp_pred_labels)
  185. boundaries = []
  186. if len(text_comps) < 1:
  187. return boundaries
  188. for cluster_ind in range(0, int(np.max(comp_pred_labels)) + 1):
  189. cluster_comp_inds = np.where(comp_pred_labels == cluster_ind)
  190. text_comp_boxes = text_comps[cluster_comp_inds, :8].reshape(
  191. (-1, 4, 2)).astype(np.int32)
  192. score = np.mean(text_comps[cluster_comp_inds, -1])
  193. if text_comp_boxes.shape[0] < 1:
  194. continue
  195. elif text_comp_boxes.shape[0] > 1:
  196. centers = np.mean(text_comp_boxes, axis=1).astype(np.int32).tolist()
  197. shortest_path = min_connect_path(centers)
  198. text_comp_boxes = text_comp_boxes[shortest_path]
  199. top_line = np.mean(
  200. text_comp_boxes[:, 0:2, :], axis=1).astype(np.int32).tolist()
  201. bot_line = np.mean(
  202. text_comp_boxes[:, 2:4, :], axis=1).astype(np.int32).tolist()
  203. top_line, bot_line = fix_corner(
  204. top_line, bot_line, text_comp_boxes[0], text_comp_boxes[-1])
  205. boundary_points = top_line + bot_line[::-1]
  206. else:
  207. top_line = text_comp_boxes[0, 0:2, :].astype(np.int32).tolist()
  208. bot_line = text_comp_boxes[0, 2:4:-1, :].astype(np.int32).tolist()
  209. boundary_points = top_line + bot_line
  210. boundary = [p for coord in boundary_points for p in coord] + [score]
  211. boundaries.append(boundary)
  212. return boundaries
  213. class DRRGPostprocess(object):
  214. """Merge text components and construct boundaries of text instances.
  215. Args:
  216. link_thr (float): The edge score threshold.
  217. """
  218. def __init__(self, link_thr, **kwargs):
  219. assert isinstance(link_thr, float)
  220. self.link_thr = link_thr
  221. def __call__(self, preds, shape_list):
  222. """
  223. Args:
  224. edges (ndarray): The edge array of shape N * 2, each row is a node
  225. index pair that makes up an edge in graph.
  226. scores (ndarray): The edge score array of shape (N,).
  227. text_comps (ndarray): The text components.
  228. Returns:
  229. List[list[float]]: The predicted boundaries of text instances.
  230. """
  231. edges, scores, text_comps = preds
  232. if edges is not None:
  233. if isinstance(edges, paddle.Tensor):
  234. edges = edges.numpy()
  235. if isinstance(scores, paddle.Tensor):
  236. scores = scores.numpy()
  237. if isinstance(text_comps, paddle.Tensor):
  238. text_comps = text_comps.numpy()
  239. assert len(edges) == len(scores)
  240. assert text_comps.ndim == 2
  241. assert text_comps.shape[1] == 9
  242. vertices, score_dict = graph_propagation(edges, scores, text_comps)
  243. clusters = connected_components(vertices, score_dict, self.link_thr)
  244. pred_labels = clusters2labels(clusters, text_comps.shape[0])
  245. text_comps, pred_labels = remove_single(text_comps, pred_labels)
  246. boundaries = comps2boundaries(text_comps, pred_labels)
  247. else:
  248. boundaries = []
  249. boundaries, scores = self.resize_boundary(
  250. boundaries, (1 / shape_list[0, 2:]).tolist()[::-1])
  251. boxes_batch = [dict(points=boundaries, scores=scores)]
  252. return boxes_batch
  253. def resize_boundary(self, boundaries, scale_factor):
  254. """Rescale boundaries via scale_factor.
  255. Args:
  256. boundaries (list[list[float]]): The boundary list. Each boundary
  257. with size 2k+1 with k>=4.
  258. scale_factor(ndarray): The scale factor of size (4,).
  259. Returns:
  260. boundaries (list[list[float]]): The scaled boundaries.
  261. """
  262. boxes = []
  263. scores = []
  264. for b in boundaries:
  265. sz = len(b)
  266. scores.append(b[-1])
  267. b = (np.array(b[:sz - 1]) *
  268. (np.tile(scale_factor[:2], int(
  269. (sz - 1) / 2)).reshape(1, sz - 1))).flatten().tolist()
  270. boxes.append(np.array(b).reshape([-1, 2]))
  271. return boxes, scores