# kie_sdmgr_head.py

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# reference from: https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/kie/heads/sdmgr_head.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr


class SDMGRHead(nn.Layer):
    """SDMGR (Spatial Dual-Modality Graph Reasoning) head for key information
    extraction: embeds the text of every detected box with an LSTM, optionally
    fuses it with visual features, and reasons over a fully connected box graph
    with stacked GNN layers to classify nodes and edges."""

    def __init__(self,
                 in_channels,
                 num_chars=92,
                 visual_dim=16,
                 fusion_dim=1024,
                 node_input=32,
                 node_embed=256,
                 edge_input=5,
                 edge_embed=256,
                 num_gnn=2,
                 num_classes=26,
                 bidirectional=False):
        super().__init__()

        # Visual/textual fusion and per-character embedding + LSTM text encoder.
        self.fusion = Block([visual_dim, node_embed], node_embed, fusion_dim)
        self.node_embed = nn.Embedding(num_chars, node_input, 0)
        hidden = node_embed // 2 if bidirectional else node_embed
        self.rnn = nn.LSTM(
            input_size=node_input, hidden_size=hidden, num_layers=1)
        # Edge (spatial relation) embedding, GNN stack, and classifiers.
        self.edge_embed = nn.Linear(edge_input, edge_embed)
        self.gnn_layers = nn.LayerList(
            [GNNLayer(node_embed, edge_embed) for _ in range(num_gnn)])
        self.node_cls = nn.Linear(node_embed, num_classes)
        self.edge_cls = nn.Linear(edge_embed, 2)

    def forward(self, input, targets):
        relations, texts, x = input
        node_nums, char_nums = [], []
        for text in texts:
            node_nums.append(text.shape[0])
            char_nums.append(paddle.sum((text > -1).astype(int), axis=-1))

        # Pad every text tensor to the same width and embed the characters.
        max_num = max([char_num.max() for char_num in char_nums])
        all_nodes = paddle.concat([
            paddle.concat(
                [text, paddle.zeros(
                    (text.shape[0], max_num - text.shape[1]))], -1)
            for text in texts
        ])
        temp = paddle.clip(all_nodes, min=0).astype(int)
        embed_nodes = self.node_embed(temp)
        rnn_nodes, _ = self.rnn(embed_nodes)

        # Keep the LSTM output at the last valid character of every non-empty box.
        b, h, w = rnn_nodes.shape
        nodes = paddle.zeros([b, w])
        all_nums = paddle.concat(char_nums)
        valid = paddle.nonzero((all_nums > 0).astype(int))
        temp_all_nums = (
            paddle.gather(all_nums, valid) - 1).unsqueeze(-1).unsqueeze(-1)
        temp_all_nums = paddle.expand(temp_all_nums, [
            temp_all_nums.shape[0], temp_all_nums.shape[1], rnn_nodes.shape[-1]
        ])
        temp_all_nodes = paddle.gather(rnn_nodes, valid)
        N, C, A = temp_all_nodes.shape
        one_hot = F.one_hot(
            temp_all_nums[:, 0, :], num_classes=C).transpose([0, 2, 1])
        one_hot = paddle.multiply(
            temp_all_nodes, one_hot.astype("float32")).sum(axis=1, keepdim=True)
        t = one_hot.expand([N, 1, A]).squeeze(1)
        nodes = paddle.scatter(nodes, valid.squeeze(1), t)

        # Fuse textual node features with visual features when they are provided.
        if x is not None:
            nodes = self.fusion([x, nodes])

        # Embed the pairwise spatial relations and run the GNN stack.
        all_edges = paddle.concat(
            [rel.reshape([-1, rel.shape[-1]]) for rel in relations])
        embed_edges = self.edge_embed(all_edges.astype('float32'))
        embed_edges = F.normalize(embed_edges)

        for gnn_layer in self.gnn_layers:
            nodes, cat_nodes = gnn_layer(nodes, embed_edges, node_nums)

        node_cls, edge_cls = self.node_cls(nodes), self.edge_cls(cat_nodes)
        return node_cls, edge_cls
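
# Input/output contract of SDMGRHead.forward, inferred from the code above
# (descriptive summary, not part of the original file):
#   relations: list with one tensor per image, shaped [num_boxes, num_boxes, edge_input],
#              holding pairwise spatial relation features between text boxes.
#   texts:     list with one tensor per image, shaped [num_boxes, max_chars], holding
#              character indices; entries <= -1 are padding and are clipped to the
#              embedding pad index 0 before lookup.
#   x:         visual features of shape [total_boxes, visual_dim], or None to skip fusion.
#   targets:   accepted for API compatibility but not used in this forward pass.
# Returns node_cls [total_boxes, num_classes] and edge_cls [sum(num_boxes**2), 2].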


class GNNLayer(nn.Layer):
    """Single graph reasoning step: builds features for every ordered node pair,
    scores each pair with a learned attention coefficient, and adds the
    attention-weighted messages back to the node features as a residual."""

    def __init__(self, node_dim=256, edge_dim=256):
        super().__init__()
        self.in_fc = nn.Linear(node_dim * 2 + edge_dim, node_dim)
        self.coef_fc = nn.Linear(node_dim, 1)
        self.out_fc = nn.Linear(node_dim, node_dim)
        self.relu = nn.ReLU()

    def forward(self, nodes, edges, nums):
        # Concatenate every (source, target) node pair within each sample.
        start, cat_nodes = 0, []
        for num in nums:
            sample_nodes = nodes[start:start + num]
            cat_nodes.append(
                paddle.concat([
                    paddle.expand(sample_nodes.unsqueeze(1), [-1, num, -1]),
                    paddle.expand(sample_nodes.unsqueeze(0), [num, -1, -1])
                ], -1).reshape([num**2, -1]))
            start += num
        cat_nodes = paddle.concat([paddle.concat(cat_nodes), edges], -1)
        cat_nodes = self.relu(self.in_fc(cat_nodes))
        coefs = self.coef_fc(cat_nodes)

        # Attention over neighbours; the -1e9 diagonal masks self-pairs in the softmax.
        start, residuals = 0, []
        for num in nums:
            residual = F.softmax(
                -paddle.eye(num).unsqueeze(-1) * 1e9 +
                coefs[start:start + num**2].reshape([num, num, -1]), 1)
            residuals.append((residual * cat_nodes[start:start + num**2]
                              .reshape([num, num, -1])).sum(1))
            start += num**2

        nodes += self.relu(self.out_fc(paddle.concat(residuals)))
        return [nodes, cat_nodes]
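
# Shape sketch for a single GNNLayer call (illustrative, derived from the code above):
# with nums=[4], nodes of shape [4, node_dim] and edges of shape [4 * 4, edge_dim],
# the layer returns the residually updated nodes [4, node_dim] and the pairwise
# features cat_nodes [16, node_dim].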


class Block(nn.Layer):
    """Low-rank, chunked bilinear fusion block used to merge the visual and
    textual node features."""

    def __init__(self,
                 input_dims,
                 output_dim,
                 mm_dim=1600,
                 chunks=20,
                 rank=15,
                 shared=False,
                 dropout_input=0.,
                 dropout_pre_lin=0.,
                 dropout_output=0.,
                 pos_norm='before_cat'):
        super().__init__()
        self.rank = rank
        self.dropout_input = dropout_input
        self.dropout_pre_lin = dropout_pre_lin
        self.dropout_output = dropout_output
        assert (pos_norm in ['before_cat', 'after_cat'])
        self.pos_norm = pos_norm
        # Modules
        self.linear0 = nn.Linear(input_dims[0], mm_dim)
        self.linear1 = (self.linear0
                        if shared else nn.Linear(input_dims[1], mm_dim))
        self.merge_linears0 = nn.LayerList()
        self.merge_linears1 = nn.LayerList()
        self.chunks = self.chunk_sizes(mm_dim, chunks)
        for size in self.chunks:
            ml0 = nn.Linear(size, size * rank)
            self.merge_linears0.append(ml0)
            ml1 = ml0 if shared else nn.Linear(size, size * rank)
            self.merge_linears1.append(ml1)
        self.linear_out = nn.Linear(mm_dim, output_dim)

    def forward(self, x):
        x0 = self.linear0(x[0])
        x1 = self.linear1(x[1])
        bs = x1.shape[0]
        if self.dropout_input > 0:
            x0 = F.dropout(x0, p=self.dropout_input, training=self.training)
            x1 = F.dropout(x1, p=self.dropout_input, training=self.training)
        x0_chunks = paddle.split(x0, self.chunks, -1)
        x1_chunks = paddle.split(x1, self.chunks, -1)
        zs = []
        for x0_c, x1_c, m0, m1 in zip(x0_chunks, x1_chunks, self.merge_linears0,
                                      self.merge_linears1):
            m = m0(x0_c) * m1(x1_c)  # bs x split_size*rank
            m = m.reshape([bs, self.rank, -1])
            z = paddle.sum(m, 1)
            if self.pos_norm == 'before_cat':
                z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
                z = F.normalize(z)
            zs.append(z)
        z = paddle.concat(zs, 1)
        if self.pos_norm == 'after_cat':
            z = paddle.sqrt(F.relu(z)) - paddle.sqrt(F.relu(-z))
            z = F.normalize(z)
        if self.dropout_pre_lin > 0:
            z = F.dropout(z, p=self.dropout_pre_lin, training=self.training)
        z = self.linear_out(z)
        if self.dropout_output > 0:
            z = F.dropout(z, p=self.dropout_output, training=self.training)
        return z

    def chunk_sizes(self, dim, chunks):
        # Split `dim` into `chunks` nearly equal sizes, shrinking the last chunk so
        # the sizes sum to `dim`, e.g. chunk_sizes(1024, 20) -> [52] * 19 + [36].
        split_size = (dim + chunks - 1) // chunks
        sizes_list = [split_size] * chunks
        sizes_list[-1] = sizes_list[-1] - (sum(sizes_list) - dim)
        return sizes_list
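

if __name__ == "__main__":
    # Minimal usage sketch (not from the original file): runs the two standalone
    # sub-layers on random tensors to illustrate the shapes they expect, mirroring
    # how SDMGRHead wires them together above.
    paddle.seed(0)

    # One image with 4 text boxes -> 4 nodes and 4 * 4 = 16 ordered node pairs.
    gnn = GNNLayer(node_dim=256, edge_dim=256)
    nodes = paddle.randn([4, 256])
    edges = paddle.randn([16, 256])
    new_nodes, cat_nodes = gnn(nodes, edges, [4])
    print(new_nodes.shape, cat_nodes.shape)  # [4, 256] [16, 256]

    # Block fusion as used by SDMGRHead: visual_dim=16, node_embed=256, fusion_dim=1024.
    fusion = Block([16, 256], 256, 1024)
    visual = paddle.randn([4, 16])
    fused = fusion([visual, new_nodes])
    print(fused.shape)  # [4, 256]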