local_graph.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/modules/local_graph.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import numpy as np
  22. import paddle
  23. import paddle.nn as nn
  24. from ppocr.ext_op import RoIAlignRotated
  25. def normalize_adjacent_matrix(A):
  26. assert A.ndim == 2
  27. assert A.shape[0] == A.shape[1]
  28. A = A + np.eye(A.shape[0])
  29. d = np.sum(A, axis=0)
  30. d = np.clip(d, 0, None)
  31. d_inv = np.power(d, -0.5).flatten()
  32. d_inv[np.isinf(d_inv)] = 0.0
  33. d_inv = np.diag(d_inv)
  34. G = A.dot(d_inv).transpose().dot(d_inv)
  35. return G
  36. def euclidean_distance_matrix(A, B):
  37. """Calculate the Euclidean distance matrix.
  38. Args:
  39. A (ndarray): The point sequence.
  40. B (ndarray): The point sequence with the same dimensions as A.
  41. returns:
  42. D (ndarray): The Euclidean distance matrix.
  43. """
  44. assert A.ndim == 2
  45. assert B.ndim == 2
  46. assert A.shape[1] == B.shape[1]
  47. m = A.shape[0]
  48. n = B.shape[0]
  49. A_dots = (A * A).sum(axis=1).reshape((m, 1)) * np.ones(shape=(1, n))
  50. B_dots = (B * B).sum(axis=1) * np.ones(shape=(m, 1))
  51. D_squared = A_dots + B_dots - 2 * A.dot(B.T)
  52. zero_mask = np.less(D_squared, 0.0)
  53. D_squared[zero_mask] = 0.0
  54. D = np.sqrt(D_squared)
  55. return D
  56. def feature_embedding(input_feats, out_feat_len):
  57. """Embed features. This code was partially adapted from
  58. https://github.com/GXYM/DRRG licensed under the MIT license.
  59. Args:
  60. input_feats (ndarray): The input features of shape (N, d), where N is
  61. the number of nodes in graph, d is the input feature vector length.
  62. out_feat_len (int): The length of output feature vector.
  63. Returns:
  64. embedded_feats (ndarray): The embedded features.
  65. """
  66. assert input_feats.ndim == 2
  67. assert isinstance(out_feat_len, int)
  68. assert out_feat_len >= input_feats.shape[1]
  69. num_nodes = input_feats.shape[0]
  70. feat_dim = input_feats.shape[1]
  71. feat_repeat_times = out_feat_len // feat_dim
  72. residue_dim = out_feat_len % feat_dim
  73. if residue_dim > 0:
  74. embed_wave = np.array([
  75. np.power(1000, 2.0 * (j // 2) / feat_repeat_times + 1)
  76. for j in range(feat_repeat_times + 1)
  77. ]).reshape((feat_repeat_times + 1, 1, 1))
  78. repeat_feats = np.repeat(
  79. np.expand_dims(
  80. input_feats, axis=0), feat_repeat_times, axis=0)
  81. residue_feats = np.hstack([
  82. input_feats[:, 0:residue_dim], np.zeros(
  83. (num_nodes, feat_dim - residue_dim))
  84. ])
  85. residue_feats = np.expand_dims(residue_feats, axis=0)
  86. repeat_feats = np.concatenate([repeat_feats, residue_feats], axis=0)
  87. embedded_feats = repeat_feats / embed_wave
  88. embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
  89. embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
  90. embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
  91. (num_nodes, -1))[:, 0:out_feat_len]
  92. else:
  93. embed_wave = np.array([
  94. np.power(1000, 2.0 * (j // 2) / feat_repeat_times)
  95. for j in range(feat_repeat_times)
  96. ]).reshape((feat_repeat_times, 1, 1))
  97. repeat_feats = np.repeat(
  98. np.expand_dims(
  99. input_feats, axis=0), feat_repeat_times, axis=0)
  100. embedded_feats = repeat_feats / embed_wave
  101. embedded_feats[:, 0::2] = np.sin(embedded_feats[:, 0::2])
  102. embedded_feats[:, 1::2] = np.cos(embedded_feats[:, 1::2])
  103. embedded_feats = np.transpose(embedded_feats, (1, 0, 2)).reshape(
  104. (num_nodes, -1)).astype(np.float32)
  105. return embedded_feats
  106. class LocalGraphs:
  107. def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
  108. pooling_scale, pooling_output_size, local_graph_thr):
  109. assert len(k_at_hops) == 2
  110. assert all(isinstance(n, int) for n in k_at_hops)
  111. assert isinstance(num_adjacent_linkages, int)
  112. assert isinstance(node_geo_feat_len, int)
  113. assert isinstance(pooling_scale, float)
  114. assert all(isinstance(n, int) for n in pooling_output_size)
  115. assert isinstance(local_graph_thr, float)
  116. self.k_at_hops = k_at_hops
  117. self.num_adjacent_linkages = num_adjacent_linkages
  118. self.node_geo_feat_dim = node_geo_feat_len
  119. self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
  120. self.local_graph_thr = local_graph_thr
  121. def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
  122. """Generate local graphs for GCN to predict which instance a text
  123. component belongs to.
  124. Args:
  125. sorted_dist_inds (ndarray): The complete graph node indices, which
  126. is sorted according to the Euclidean distance.
  127. gt_comp_labels(ndarray): The ground truth labels define the
  128. instance to which the text components (nodes in graphs) belong.
  129. Returns:
  130. pivot_local_graphs(list[list[int]]): The list of local graph
  131. neighbor indices of pivots.
  132. pivot_knns(list[list[int]]): The list of k-nearest neighbor indices
  133. of pivots.
  134. """
  135. assert sorted_dist_inds.ndim == 2
  136. assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
  137. gt_comp_labels.shape[0])
  138. knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
  139. pivot_local_graphs = []
  140. pivot_knns = []
  141. for pivot_ind, knn in enumerate(knn_graph):
  142. local_graph_neighbors = set(knn)
  143. for neighbor_ind in knn:
  144. local_graph_neighbors.update(
  145. set(sorted_dist_inds[neighbor_ind, 1:self.k_at_hops[1] +
  146. 1]))
  147. local_graph_neighbors.discard(pivot_ind)
  148. pivot_local_graph = list(local_graph_neighbors)
  149. pivot_local_graph.insert(0, pivot_ind)
  150. pivot_knn = [pivot_ind] + list(knn)
  151. if pivot_ind < 1:
  152. pivot_local_graphs.append(pivot_local_graph)
  153. pivot_knns.append(pivot_knn)
  154. else:
  155. add_flag = True
  156. for graph_ind, added_knn in enumerate(pivot_knns):
  157. added_pivot_ind = added_knn[0]
  158. added_local_graph = pivot_local_graphs[graph_ind]
  159. union = len(
  160. set(pivot_local_graph[1:]).union(
  161. set(added_local_graph[1:])))
  162. intersect = len(
  163. set(pivot_local_graph[1:]).intersection(
  164. set(added_local_graph[1:])))
  165. local_graph_iou = intersect / (union + 1e-8)
  166. if (local_graph_iou > self.local_graph_thr and
  167. pivot_ind in added_knn and
  168. gt_comp_labels[added_pivot_ind] ==
  169. gt_comp_labels[pivot_ind] and
  170. gt_comp_labels[pivot_ind] != 0):
  171. add_flag = False
  172. break
  173. if add_flag:
  174. pivot_local_graphs.append(pivot_local_graph)
  175. pivot_knns.append(pivot_knn)
  176. return pivot_local_graphs, pivot_knns
  177. def generate_gcn_input(self, node_feat_batch, node_label_batch,
  178. local_graph_batch, knn_batch, sorted_dist_ind_batch):
  179. """Generate graph convolution network input data.
  180. Args:
  181. node_feat_batch (List[Tensor]): The batched graph node features.
  182. node_label_batch (List[ndarray]): The batched text component
  183. labels.
  184. local_graph_batch (List[List[list[int]]]): The local graph node
  185. indices of image batch.
  186. knn_batch (List[List[list[int]]]): The knn graph node indices of
  187. image batch.
  188. sorted_dist_ind_batch (list[ndarray]): The node indices sorted
  189. according to the Euclidean distance.
  190. Returns:
  191. local_graphs_node_feat (Tensor): The node features of graph.
  192. adjacent_matrices (Tensor): The adjacent matrices of local graphs.
  193. pivots_knn_inds (Tensor): The k-nearest neighbor indices in
  194. local graph.
  195. gt_linkage (Tensor): The surpervision signal of GCN for linkage
  196. prediction.
  197. """
  198. assert isinstance(node_feat_batch, list)
  199. assert isinstance(node_label_batch, list)
  200. assert isinstance(local_graph_batch, list)
  201. assert isinstance(knn_batch, list)
  202. assert isinstance(sorted_dist_ind_batch, list)
  203. num_max_nodes = max([
  204. len(pivot_local_graph)
  205. for pivot_local_graphs in local_graph_batch
  206. for pivot_local_graph in pivot_local_graphs
  207. ])
  208. local_graphs_node_feat = []
  209. adjacent_matrices = []
  210. pivots_knn_inds = []
  211. pivots_gt_linkage = []
  212. for batch_ind, sorted_dist_inds in enumerate(sorted_dist_ind_batch):
  213. node_feats = node_feat_batch[batch_ind]
  214. pivot_local_graphs = local_graph_batch[batch_ind]
  215. pivot_knns = knn_batch[batch_ind]
  216. node_labels = node_label_batch[batch_ind]
  217. for graph_ind, pivot_knn in enumerate(pivot_knns):
  218. pivot_local_graph = pivot_local_graphs[graph_ind]
  219. num_nodes = len(pivot_local_graph)
  220. pivot_ind = pivot_local_graph[0]
  221. node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}
  222. knn_inds = paddle.to_tensor(
  223. [node2ind_map[i] for i in pivot_knn[1:]])
  224. pivot_feats = node_feats[pivot_ind]
  225. normalized_feats = node_feats[paddle.to_tensor(
  226. pivot_local_graph)] - pivot_feats
  227. adjacent_matrix = np.zeros(
  228. (num_nodes, num_nodes), dtype=np.float32)
  229. for node in pivot_local_graph:
  230. neighbors = sorted_dist_inds[node, 1:
  231. self.num_adjacent_linkages + 1]
  232. for neighbor in neighbors:
  233. if neighbor in pivot_local_graph:
  234. adjacent_matrix[node2ind_map[node], node2ind_map[
  235. neighbor]] = 1
  236. adjacent_matrix[node2ind_map[neighbor],
  237. node2ind_map[node]] = 1
  238. adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
  239. pad_adjacent_matrix = paddle.zeros(
  240. (num_max_nodes, num_max_nodes))
  241. pad_adjacent_matrix[:num_nodes, :num_nodes] = paddle.cast(
  242. paddle.to_tensor(adjacent_matrix), 'float32')
  243. pad_normalized_feats = paddle.concat(
  244. [
  245. normalized_feats, paddle.zeros(
  246. (num_max_nodes - num_nodes,
  247. normalized_feats.shape[1]))
  248. ],
  249. axis=0)
  250. local_graph_labels = node_labels[pivot_local_graph]
  251. knn_labels = local_graph_labels[knn_inds.numpy()]
  252. link_labels = ((node_labels[pivot_ind] == knn_labels) &
  253. (node_labels[pivot_ind] > 0)).astype(np.int64)
  254. link_labels = paddle.to_tensor(link_labels)
  255. local_graphs_node_feat.append(pad_normalized_feats)
  256. adjacent_matrices.append(pad_adjacent_matrix)
  257. pivots_knn_inds.append(knn_inds)
  258. pivots_gt_linkage.append(link_labels)
  259. local_graphs_node_feat = paddle.stack(local_graphs_node_feat, 0)
  260. adjacent_matrices = paddle.stack(adjacent_matrices, 0)
  261. pivots_knn_inds = paddle.stack(pivots_knn_inds, 0)
  262. pivots_gt_linkage = paddle.stack(pivots_gt_linkage, 0)
  263. return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
  264. pivots_gt_linkage)
  265. def __call__(self, feat_maps, comp_attribs):
  266. """Generate local graphs as GCN input.
  267. Args:
  268. feat_maps (Tensor): The feature maps to extract the content
  269. features of text components.
  270. comp_attribs (ndarray): The text component attributes.
  271. Returns:
  272. local_graphs_node_feat (Tensor): The node features of graph.
  273. adjacent_matrices (Tensor): The adjacent matrices of local graphs.
  274. pivots_knn_inds (Tensor): The k-nearest neighbor indices in local
  275. graph.
  276. gt_linkage (Tensor): The surpervision signal of GCN for linkage
  277. prediction.
  278. """
  279. assert isinstance(feat_maps, paddle.Tensor)
  280. assert comp_attribs.ndim == 3
  281. assert comp_attribs.shape[2] == 8
  282. sorted_dist_inds_batch = []
  283. local_graph_batch = []
  284. knn_batch = []
  285. node_feat_batch = []
  286. node_label_batch = []
  287. for batch_ind in range(comp_attribs.shape[0]):
  288. num_comps = int(comp_attribs[batch_ind, 0, 0])
  289. comp_geo_attribs = comp_attribs[batch_ind, :num_comps, 1:7]
  290. node_labels = comp_attribs[batch_ind, :num_comps, 7].astype(
  291. np.int32)
  292. comp_centers = comp_geo_attribs[:, 0:2]
  293. distance_matrix = euclidean_distance_matrix(comp_centers,
  294. comp_centers)
  295. batch_id = np.zeros(
  296. (comp_geo_attribs.shape[0], 1), dtype=np.float32) * batch_ind
  297. comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
  298. angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
  299. comp_geo_attribs[:, -1])
  300. angle = angle.reshape((-1, 1))
  301. rotated_rois = np.hstack(
  302. [batch_id, comp_geo_attribs[:, :-2], angle])
  303. rois = paddle.to_tensor(rotated_rois)
  304. content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0),
  305. rois)
  306. content_feats = content_feats.reshape([content_feats.shape[0], -1])
  307. geo_feats = feature_embedding(comp_geo_attribs,
  308. self.node_geo_feat_dim)
  309. geo_feats = paddle.to_tensor(geo_feats)
  310. node_feats = paddle.concat([content_feats, geo_feats], axis=-1)
  311. sorted_dist_inds = np.argsort(distance_matrix, axis=1)
  312. pivot_local_graphs, pivot_knns = self.generate_local_graphs(
  313. sorted_dist_inds, node_labels)
  314. node_feat_batch.append(node_feats)
  315. node_label_batch.append(node_labels)
  316. local_graph_batch.append(pivot_local_graphs)
  317. knn_batch.append(pivot_knns)
  318. sorted_dist_inds_batch.append(sorted_dist_inds)
  319. (node_feats, adjacent_matrices, knn_inds, gt_linkage) = \
  320. self.generate_gcn_input(node_feat_batch,
  321. node_label_batch,
  322. local_graph_batch,
  323. knn_batch,
  324. sorted_dist_inds_batch)
  325. return node_feats, adjacent_matrices, knn_inds, gt_linkage