proposal_local_graph.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/proposal_local_graph.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import cv2
  22. import numpy as np
  23. import paddle
  24. import paddle.nn as nn
  25. import paddle.nn.functional as F
  26. from lanms import merge_quadrangle_n9 as la_nms
  27. from ppocr.ext_op import RoIAlignRotated
  28. from .local_graph import (euclidean_distance_matrix, feature_embedding,
  29. normalize_adjacent_matrix)
  30. def fill_hole(input_mask):
  31. h, w = input_mask.shape
  32. canvas = np.zeros((h + 2, w + 2), np.uint8)
  33. canvas[1:h + 1, 1:w + 1] = input_mask.copy()
  34. mask = np.zeros((h + 4, w + 4), np.uint8)
  35. cv2.floodFill(canvas, mask, (0, 0), 1)
  36. canvas = canvas[1:h + 1, 1:w + 1].astype(np.bool)
  37. return ~canvas | input_mask
  38. class ProposalLocalGraphs:
  39. def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
  40. pooling_scale, pooling_output_size, nms_thr, min_width,
  41. max_width, comp_shrink_ratio, comp_w_h_ratio, comp_score_thr,
  42. text_region_thr, center_region_thr, center_region_area_thr):
  43. assert len(k_at_hops) == 2
  44. assert isinstance(k_at_hops, tuple)
  45. assert isinstance(num_adjacent_linkages, int)
  46. assert isinstance(node_geo_feat_len, int)
  47. assert isinstance(pooling_scale, float)
  48. assert isinstance(pooling_output_size, tuple)
  49. assert isinstance(nms_thr, float)
  50. assert isinstance(min_width, float)
  51. assert isinstance(max_width, float)
  52. assert isinstance(comp_shrink_ratio, float)
  53. assert isinstance(comp_w_h_ratio, float)
  54. assert isinstance(comp_score_thr, float)
  55. assert isinstance(text_region_thr, float)
  56. assert isinstance(center_region_thr, float)
  57. assert isinstance(center_region_area_thr, int)
  58. self.k_at_hops = k_at_hops
  59. self.active_connection = num_adjacent_linkages
  60. self.local_graph_depth = len(self.k_at_hops)
  61. self.node_geo_feat_dim = node_geo_feat_len
  62. self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
  63. self.nms_thr = nms_thr
  64. self.min_width = min_width
  65. self.max_width = max_width
  66. self.comp_shrink_ratio = comp_shrink_ratio
  67. self.comp_w_h_ratio = comp_w_h_ratio
  68. self.comp_score_thr = comp_score_thr
  69. self.text_region_thr = text_region_thr
  70. self.center_region_thr = center_region_thr
  71. self.center_region_area_thr = center_region_area_thr
  72. def propose_comps(self, score_map, top_height_map, bot_height_map, sin_map,
  73. cos_map, comp_score_thr, min_width, max_width,
  74. comp_shrink_ratio, comp_w_h_ratio):
  75. """Propose text components.
  76. Args:
  77. score_map (ndarray): The score map for NMS.
  78. top_height_map (ndarray): The predicted text height map from each
  79. pixel in text center region to top sideline.
  80. bot_height_map (ndarray): The predicted text height map from each
  81. pixel in text center region to bottom sideline.
  82. sin_map (ndarray): The predicted sin(theta) map.
  83. cos_map (ndarray): The predicted cos(theta) map.
  84. comp_score_thr (float): The score threshold of text component.
  85. min_width (float): The minimum width of text components.
  86. max_width (float): The maximum width of text components.
  87. comp_shrink_ratio (float): The shrink ratio of text components.
  88. comp_w_h_ratio (float): The width to height ratio of text
  89. components.
  90. Returns:
  91. text_comps (ndarray): The text components.
  92. """
  93. comp_centers = np.argwhere(score_map > comp_score_thr)
  94. comp_centers = comp_centers[np.argsort(comp_centers[:, 0])]
  95. y = comp_centers[:, 0]
  96. x = comp_centers[:, 1]
  97. top_height = top_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
  98. bot_height = bot_height_map[y, x].reshape((-1, 1)) * comp_shrink_ratio
  99. sin = sin_map[y, x].reshape((-1, 1))
  100. cos = cos_map[y, x].reshape((-1, 1))
  101. top_mid_pts = comp_centers + np.hstack(
  102. [top_height * sin, top_height * cos])
  103. bot_mid_pts = comp_centers - np.hstack(
  104. [bot_height * sin, bot_height * cos])
  105. width = (top_height + bot_height) * comp_w_h_ratio
  106. width = np.clip(width, min_width, max_width)
  107. r = width / 2
  108. tl = top_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
  109. tr = top_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
  110. br = bot_mid_pts[:, ::-1] + np.hstack([-r * sin, r * cos])
  111. bl = bot_mid_pts[:, ::-1] - np.hstack([-r * sin, r * cos])
  112. text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)
  113. score = score_map[y, x].reshape((-1, 1))
  114. text_comps = np.hstack([text_comps, score])
  115. return text_comps
  116. def propose_comps_and_attribs(self, text_region_map, center_region_map,
  117. top_height_map, bot_height_map, sin_map,
  118. cos_map):
  119. """Generate text components and attributes.
  120. Args:
  121. text_region_map (ndarray): The predicted text region probability
  122. map.
  123. center_region_map (ndarray): The predicted text center region
  124. probability map.
  125. top_height_map (ndarray): The predicted text height map from each
  126. pixel in text center region to top sideline.
  127. bot_height_map (ndarray): The predicted text height map from each
  128. pixel in text center region to bottom sideline.
  129. sin_map (ndarray): The predicted sin(theta) map.
  130. cos_map (ndarray): The predicted cos(theta) map.
  131. Returns:
  132. comp_attribs (ndarray): The text component attributes.
  133. text_comps (ndarray): The text components.
  134. """
  135. assert (text_region_map.shape == center_region_map.shape ==
  136. top_height_map.shape == bot_height_map.shape == sin_map.shape ==
  137. cos_map.shape)
  138. text_mask = text_region_map > self.text_region_thr
  139. center_region_mask = (
  140. center_region_map > self.center_region_thr) * text_mask
  141. scale = np.sqrt(1.0 / (sin_map**2 + cos_map**2 + 1e-8))
  142. sin_map, cos_map = sin_map * scale, cos_map * scale
  143. center_region_mask = fill_hole(center_region_mask)
  144. center_region_contours, _ = cv2.findContours(
  145. center_region_mask.astype(np.uint8), cv2.RETR_TREE,
  146. cv2.CHAIN_APPROX_SIMPLE)
  147. mask_sz = center_region_map.shape
  148. comp_list = []
  149. for contour in center_region_contours:
  150. current_center_mask = np.zeros(mask_sz)
  151. cv2.drawContours(current_center_mask, [contour], -1, 1, -1)
  152. if current_center_mask.sum() <= self.center_region_area_thr:
  153. continue
  154. score_map = text_region_map * current_center_mask
  155. text_comps = self.propose_comps(
  156. score_map, top_height_map, bot_height_map, sin_map, cos_map,
  157. self.comp_score_thr, self.min_width, self.max_width,
  158. self.comp_shrink_ratio, self.comp_w_h_ratio)
  159. text_comps = la_nms(text_comps, self.nms_thr)
  160. text_comp_mask = np.zeros(mask_sz)
  161. text_comp_boxes = text_comps[:, :8].reshape(
  162. (-1, 4, 2)).astype(np.int32)
  163. cv2.drawContours(text_comp_mask, text_comp_boxes, -1, 1, -1)
  164. if (text_comp_mask * text_mask).sum() < text_comp_mask.sum() * 0.5:
  165. continue
  166. if text_comps.shape[-1] > 0:
  167. comp_list.append(text_comps)
  168. if len(comp_list) <= 0:
  169. return None, None
  170. text_comps = np.vstack(comp_list)
  171. text_comp_boxes = text_comps[:, :8].reshape((-1, 4, 2))
  172. centers = np.mean(text_comp_boxes, axis=1).astype(np.int32)
  173. x = centers[:, 0]
  174. y = centers[:, 1]
  175. scores = []
  176. for text_comp_box in text_comp_boxes:
  177. text_comp_box[:, 0] = np.clip(text_comp_box[:, 0], 0,
  178. mask_sz[1] - 1)
  179. text_comp_box[:, 1] = np.clip(text_comp_box[:, 1], 0,
  180. mask_sz[0] - 1)
  181. min_coord = np.min(text_comp_box, axis=0).astype(np.int32)
  182. max_coord = np.max(text_comp_box, axis=0).astype(np.int32)
  183. text_comp_box = text_comp_box - min_coord
  184. box_sz = (max_coord - min_coord + 1)
  185. temp_comp_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
  186. cv2.fillPoly(temp_comp_mask, [text_comp_box.astype(np.int32)], 1)
  187. temp_region_patch = text_region_map[min_coord[1]:(max_coord[1] + 1),
  188. min_coord[0]:(max_coord[0] + 1)]
  189. score = cv2.mean(temp_region_patch, temp_comp_mask)[0]
  190. scores.append(score)
  191. scores = np.array(scores).reshape((-1, 1))
  192. text_comps = np.hstack([text_comps[:, :-1], scores])
  193. h = top_height_map[y, x].reshape(
  194. (-1, 1)) + bot_height_map[y, x].reshape((-1, 1))
  195. w = np.clip(h * self.comp_w_h_ratio, self.min_width, self.max_width)
  196. sin = sin_map[y, x].reshape((-1, 1))
  197. cos = cos_map[y, x].reshape((-1, 1))
  198. x = x.reshape((-1, 1))
  199. y = y.reshape((-1, 1))
  200. comp_attribs = np.hstack([x, y, h, w, cos, sin])
  201. return comp_attribs, text_comps
  202. def generate_local_graphs(self, sorted_dist_inds, node_feats):
  203. """Generate local graphs and graph convolution network input data.
  204. Args:
  205. sorted_dist_inds (ndarray): The node indices sorted according to
  206. the Euclidean distance.
  207. node_feats (tensor): The features of nodes in graph.
  208. Returns:
  209. local_graphs_node_feats (tensor): The features of nodes in local
  210. graphs.
  211. adjacent_matrices (tensor): The adjacent matrices.
  212. pivots_knn_inds (tensor): The k-nearest neighbor indices in
  213. local graphs.
  214. pivots_local_graphs (tensor): The indices of nodes in local
  215. graphs.
  216. """
  217. assert sorted_dist_inds.ndim == 2
  218. assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
  219. node_feats.shape[0])
  220. knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
  221. pivot_local_graphs = []
  222. pivot_knns = []
  223. for pivot_ind, knn in enumerate(knn_graph):
  224. local_graph_neighbors = set(knn)
  225. for neighbor_ind in knn:
  226. local_graph_neighbors.update(
  227. set(sorted_dist_inds[neighbor_ind, 1:self.k_at_hops[1] +
  228. 1]))
  229. local_graph_neighbors.discard(pivot_ind)
  230. pivot_local_graph = list(local_graph_neighbors)
  231. pivot_local_graph.insert(0, pivot_ind)
  232. pivot_knn = [pivot_ind] + list(knn)
  233. pivot_local_graphs.append(pivot_local_graph)
  234. pivot_knns.append(pivot_knn)
  235. num_max_nodes = max([
  236. len(pivot_local_graph) for pivot_local_graph in pivot_local_graphs
  237. ])
  238. local_graphs_node_feat = []
  239. adjacent_matrices = []
  240. pivots_knn_inds = []
  241. pivots_local_graphs = []
  242. for graph_ind, pivot_knn in enumerate(pivot_knns):
  243. pivot_local_graph = pivot_local_graphs[graph_ind]
  244. num_nodes = len(pivot_local_graph)
  245. pivot_ind = pivot_local_graph[0]
  246. node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
  247. knn_inds = paddle.cast(
  248. paddle.to_tensor([node2ind_map[i]
  249. for i in pivot_knn[1:]]), 'int64')
  250. pivot_feats = node_feats[pivot_ind]
  251. normalized_feats = node_feats[paddle.to_tensor(
  252. pivot_local_graph)] - pivot_feats
  253. adjacent_matrix = np.zeros((num_nodes, num_nodes), dtype=np.float32)
  254. for node in pivot_local_graph:
  255. neighbors = sorted_dist_inds[node, 1:self.active_connection + 1]
  256. for neighbor in neighbors:
  257. if neighbor in pivot_local_graph:
  258. adjacent_matrix[node2ind_map[node], node2ind_map[
  259. neighbor]] = 1
  260. adjacent_matrix[node2ind_map[neighbor], node2ind_map[
  261. node]] = 1
  262. adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
  263. pad_adjacent_matrix = paddle.zeros((num_max_nodes, num_max_nodes), )
  264. pad_adjacent_matrix[:num_nodes, :num_nodes] = paddle.cast(
  265. paddle.to_tensor(adjacent_matrix), 'float32')
  266. pad_normalized_feats = paddle.concat(
  267. [
  268. normalized_feats, paddle.zeros(
  269. (num_max_nodes - num_nodes, normalized_feats.shape[1]),
  270. )
  271. ],
  272. axis=0)
  273. local_graph_nodes = paddle.to_tensor(pivot_local_graph)
  274. local_graph_nodes = paddle.concat(
  275. [
  276. local_graph_nodes, paddle.zeros(
  277. [num_max_nodes - num_nodes], dtype='int64')
  278. ],
  279. axis=-1)
  280. local_graphs_node_feat.append(pad_normalized_feats)
  281. adjacent_matrices.append(pad_adjacent_matrix)
  282. pivots_knn_inds.append(knn_inds)
  283. pivots_local_graphs.append(local_graph_nodes)
  284. local_graphs_node_feat = paddle.stack(local_graphs_node_feat, 0)
  285. adjacent_matrices = paddle.stack(adjacent_matrices, 0)
  286. pivots_knn_inds = paddle.stack(pivots_knn_inds, 0)
  287. pivots_local_graphs = paddle.stack(pivots_local_graphs, 0)
  288. return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
  289. pivots_local_graphs)
  290. def __call__(self, preds, feat_maps):
  291. """Generate local graphs and graph convolutional network input data.
  292. Args:
  293. preds (tensor): The predicted maps.
  294. feat_maps (tensor): The feature maps to extract content feature of
  295. text components.
  296. Returns:
  297. none_flag (bool): The flag showing whether the number of proposed
  298. text components is 0.
  299. local_graphs_node_feats (tensor): The features of nodes in local
  300. graphs.
  301. adjacent_matrices (tensor): The adjacent matrices.
  302. pivots_knn_inds (tensor): The k-nearest neighbor indices in
  303. local graphs.
  304. pivots_local_graphs (tensor): The indices of nodes in local
  305. graphs.
  306. text_comps (ndarray): The predicted text components.
  307. """
  308. if preds.ndim == 4:
  309. assert preds.shape[0] == 1
  310. preds = paddle.squeeze(preds)
  311. pred_text_region = F.sigmoid(preds[0]).numpy()
  312. pred_center_region = F.sigmoid(preds[1]).numpy()
  313. pred_sin_map = preds[2].numpy()
  314. pred_cos_map = preds[3].numpy()
  315. pred_top_height_map = preds[4].numpy()
  316. pred_bot_height_map = preds[5].numpy()
  317. comp_attribs, text_comps = self.propose_comps_and_attribs(
  318. pred_text_region, pred_center_region, pred_top_height_map,
  319. pred_bot_height_map, pred_sin_map, pred_cos_map)
  320. if comp_attribs is None or len(comp_attribs) < 2:
  321. none_flag = True
  322. return none_flag, (0, 0, 0, 0, 0)
  323. comp_centers = comp_attribs[:, 0:2]
  324. distance_matrix = euclidean_distance_matrix(comp_centers, comp_centers)
  325. geo_feats = feature_embedding(comp_attribs, self.node_geo_feat_dim)
  326. geo_feats = paddle.to_tensor(geo_feats)
  327. batch_id = np.zeros((comp_attribs.shape[0], 1), dtype=np.float32)
  328. comp_attribs = comp_attribs.astype(np.float32)
  329. angle = np.arccos(comp_attribs[:, -2]) * np.sign(comp_attribs[:, -1])
  330. angle = angle.reshape((-1, 1))
  331. rotated_rois = np.hstack([batch_id, comp_attribs[:, :-2], angle])
  332. rois = paddle.to_tensor(rotated_rois)
  333. content_feats = self.pooling(feat_maps, rois)
  334. content_feats = content_feats.reshape([content_feats.shape[0], -1])
  335. node_feats = paddle.concat([content_feats, geo_feats], axis=-1)
  336. sorted_dist_inds = np.argsort(distance_matrix, axis=1)
  337. (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
  338. pivots_local_graphs) = self.generate_local_graphs(sorted_dist_inds,
  339. node_feats)
  340. none_flag = False
  341. return none_flag, (local_graphs_node_feat, adjacent_matrices,
  342. pivots_knn_inds, pivots_local_graphs, text_comps)