table_att_head.py

# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
import paddle.nn as nn
from paddle import ParamAttr
import paddle.nn.functional as F
import numpy as np

from .rec_att_head import AttentionGRUCell


def get_para_bias_attr(l2_decay, k):
    """Build [weight_attr, bias_attr]; when l2_decay > 0, both carry L2
    regularization and a uniform initializer scaled by 1 / sqrt(k)."""
    if l2_decay > 0:
        regularizer = paddle.regularizer.L2Decay(l2_decay)
        stdv = 1.0 / math.sqrt(k * 1.0)
        initializer = nn.initializer.Uniform(-stdv, stdv)
    else:
        regularizer = None
        initializer = None
    weight_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
    bias_attr = ParamAttr(regularizer=regularizer, initializer=initializer)
    return [weight_attr, bias_attr]
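

# Illustrative usage of get_para_bias_attr (not part of the original file):
# the returned attributes can be passed to an nn.Linear so its weight and bias
# pick up the L2 decay and the 1/sqrt(k) uniform init. The sizes below are
# arbitrary placeholders.
#
#     weight_attr, bias_attr = get_para_bias_attr(l2_decay=0.00001, k=96)
#     fc = nn.Linear(96, 30, weight_attr=weight_attr, bias_attr=bias_attr)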


class TableAttentionHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 hidden_size,
                 in_max_len=488,
                 max_text_length=800,
                 out_channels=30,
                 loc_reg_num=4,
                 **kwargs):
        super(TableAttentionHead, self).__init__()
        self.input_size = in_channels[-1]
        self.hidden_size = hidden_size
        self.out_channels = out_channels
        self.max_text_length = max_text_length

        self.structure_attention_cell = AttentionGRUCell(
            self.input_size, hidden_size, self.out_channels, use_gru=False)
        self.structure_generator = nn.Linear(hidden_size, self.out_channels)
        self.in_max_len = in_max_len
        # The location branch maps the flattened spatial size of the feature
        # map (which depends on the network input size in_max_len) to one
        # feature per decoding step (max_text_length + 1).
        if self.in_max_len == 640:
            self.loc_fea_trans = nn.Linear(400, self.max_text_length + 1)
        elif self.in_max_len == 800:
            self.loc_fea_trans = nn.Linear(625, self.max_text_length + 1)
        else:
            self.loc_fea_trans = nn.Linear(256, self.max_text_length + 1)
        self.loc_generator = nn.Linear(self.input_size + hidden_size,
                                       loc_reg_num)

    def _char_to_onehot(self, input_char, onehot_dim):
        input_one_hot = F.one_hot(input_char, onehot_dim)
        return input_one_hot

    def forward(self, inputs, targets=None):
        # Both the if and else branches below must assign the same variables:
        # if a variable is modified in only one branch, the modification does
        # not take effect.
        fea = inputs[-1]
        last_shape = int(np.prod(fea.shape[2:]))  # gry added
        fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], last_shape])
        fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)
        batch_size = fea.shape[0]

        hidden = paddle.zeros((batch_size, self.hidden_size))
        output_hiddens = paddle.zeros(
            (batch_size, self.max_text_length + 1, self.hidden_size))
        if self.training and targets is not None:
            structure = targets[0]
            for i in range(self.max_text_length + 1):
                elem_onehots = self._char_to_onehot(
                    structure[:, i], onehot_dim=self.out_channels)
                (outputs, hidden), alpha = self.structure_attention_cell(
                    hidden, fea, elem_onehots)
                output_hiddens[:, i, :] = outputs
            structure_probs = self.structure_generator(output_hiddens)
            loc_fea = fea.transpose([0, 2, 1])
            loc_fea = self.loc_fea_trans(loc_fea)
            loc_fea = loc_fea.transpose([0, 2, 1])
            loc_concat = paddle.concat([output_hiddens, loc_fea], axis=2)
            loc_preds = self.loc_generator(loc_concat)
            loc_preds = F.sigmoid(loc_preds)
        else:
            temp_elem = paddle.zeros(shape=[batch_size], dtype="int32")
            structure_probs = None
            loc_preds = None
            elem_onehots = None
            outputs = None
            alpha = None
            max_text_length = paddle.to_tensor(self.max_text_length)
            for i in range(max_text_length + 1):
                elem_onehots = self._char_to_onehot(
                    temp_elem, onehot_dim=self.out_channels)
                (outputs, hidden), alpha = self.structure_attention_cell(
                    hidden, fea, elem_onehots)
                output_hiddens[:, i, :] = outputs
                structure_probs_step = self.structure_generator(outputs)
                temp_elem = structure_probs_step.argmax(axis=1, dtype="int32")

            structure_probs = self.structure_generator(output_hiddens)
            structure_probs = F.softmax(structure_probs)
            loc_fea = fea.transpose([0, 2, 1])
            loc_fea = self.loc_fea_trans(loc_fea)
            loc_fea = loc_fea.transpose([0, 2, 1])
            loc_concat = paddle.concat([output_hiddens, loc_fea], axis=2)
            loc_preds = self.loc_generator(loc_concat)
            loc_preds = F.sigmoid(loc_preds)
        return {'structure_probs': structure_probs, 'loc_preds': loc_preds}
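

# Illustrative call (not from the original file): with the default
# in_max_len=488 the location branch expects the flattened spatial size of
# inputs[-1] to be 256, e.g. a 16 x 16 feature map; the channel count below is
# an arbitrary assumption.
#
#     head = TableAttentionHead(in_channels=[96], hidden_size=256,
#                               max_text_length=800, out_channels=30)
#     head.eval()
#     preds = head([paddle.rand([1, 96, 16, 16])])
#     # preds['structure_probs']: [1, 801, 30], preds['loc_preds']: [1, 801, 4]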


class SLAHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 hidden_size,
                 out_channels=30,
                 max_text_length=500,
                 loc_reg_num=4,
                 fc_decay=0.0,
                 **kwargs):
        """
        @param in_channels: input channels (a list; the last entry is used)
        @param hidden_size: hidden size for the GRU cell and the embedding
        @param out_channels: number of structure classes to recognize
        @param max_text_length: maximum number of predicted tokens
        @param loc_reg_num: number of location regression values per token
        @param fc_decay: L2 decay for the fully connected generators
        """
        super().__init__()
        in_channels = in_channels[-1]
        self.hidden_size = hidden_size
        self.max_text_length = max_text_length
        self.emb = self._char_to_onehot
        self.num_embeddings = out_channels
        self.loc_reg_num = loc_reg_num

        # structure branch
        self.structure_attention_cell = AttentionGRUCell(
            in_channels, hidden_size, self.num_embeddings)
        weight_attr, bias_attr = get_para_bias_attr(
            l2_decay=fc_decay, k=hidden_size)
        weight_attr1_1, bias_attr1_1 = get_para_bias_attr(
            l2_decay=fc_decay, k=hidden_size)
        weight_attr1_2, bias_attr1_2 = get_para_bias_attr(
            l2_decay=fc_decay, k=hidden_size)
        self.structure_generator = nn.Sequential(
            nn.Linear(
                self.hidden_size,
                self.hidden_size,
                weight_attr=weight_attr1_2,
                bias_attr=bias_attr1_2),
            nn.Linear(
                hidden_size,
                out_channels,
                weight_attr=weight_attr,
                bias_attr=bias_attr))
        # location branch
        weight_attr1, bias_attr1 = get_para_bias_attr(
            l2_decay=fc_decay, k=self.hidden_size)
        weight_attr2, bias_attr2 = get_para_bias_attr(
            l2_decay=fc_decay, k=self.hidden_size)
        self.loc_generator = nn.Sequential(
            nn.Linear(
                self.hidden_size,
                self.hidden_size,
                weight_attr=weight_attr1,
                bias_attr=bias_attr1),
            nn.Linear(
                self.hidden_size,
                loc_reg_num,
                weight_attr=weight_attr2,
                bias_attr=bias_attr2),
            nn.Sigmoid())

    def forward(self, inputs, targets=None):
        fea = inputs[-1]
        batch_size = fea.shape[0]
        # reshape
        fea = paddle.reshape(fea, [fea.shape[0], fea.shape[1], -1])
        fea = fea.transpose([0, 2, 1])  # (NTC)(batch, width, channels)

        hidden = paddle.zeros((batch_size, self.hidden_size))
        structure_preds = paddle.zeros(
            (batch_size, self.max_text_length + 1, self.num_embeddings))
        loc_preds = paddle.zeros(
            (batch_size, self.max_text_length + 1, self.loc_reg_num))
        structure_preds.stop_gradient = True
        loc_preds.stop_gradient = True
        if self.training and targets is not None:
            structure = targets[0]
            for i in range(self.max_text_length + 1):
                hidden, structure_step, loc_step = self._decode(structure[:, i],
                                                                fea, hidden)
                structure_preds[:, i, :] = structure_step
                loc_preds[:, i, :] = loc_step
        else:
            pre_chars = paddle.zeros(shape=[batch_size], dtype="int32")
            max_text_length = paddle.to_tensor(self.max_text_length)
            # for export
            loc_step, structure_step = None, None
            for i in range(max_text_length + 1):
                hidden, structure_step, loc_step = self._decode(pre_chars, fea,
                                                                hidden)
                pre_chars = structure_step.argmax(axis=1, dtype="int32")
                structure_preds[:, i, :] = structure_step
                loc_preds[:, i, :] = loc_step
        if not self.training:
            structure_preds = F.softmax(structure_preds)
        return {'structure_probs': structure_preds, 'loc_preds': loc_preds}

    def _decode(self, pre_chars, features, hidden):
        """
        Predict the table structure label and cell coordinates for one step.
        @param pre_chars: table structure label predicted in the previous step
        @param features: visual feature sequence, shape (batch, width, channels)
        @param hidden: hidden state from the previous step
        @return: (hidden, structure_step, loc_step)
        """
        emb_feature = self.emb(pre_chars)
        # output shape is b * self.hidden_size
        (output, hidden), alpha = self.structure_attention_cell(
            hidden, features, emb_feature)

        # structure
        structure_step = self.structure_generator(output)
        # loc
        loc_step = self.loc_generator(output)
        return hidden, structure_step, loc_step

    def _char_to_onehot(self, input_char):
        input_one_hot = F.one_hot(input_char, self.num_embeddings)
        return input_one_hot
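

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch, not part of the upstream file).
    # The channel count, feature-map size and the tiny max_text_length below
    # are arbitrary assumptions chosen only to keep the run fast. Because of
    # the relative import above, run this as a package module, e.g.
    #     python -m ppocr.modeling.heads.table_att_head
    head = SLAHead(
        in_channels=[64],
        hidden_size=64,
        out_channels=30,
        max_text_length=8,
        loc_reg_num=4)
    head.train()
    fea = paddle.rand([2, 64, 8, 8])  # (N, C, H, W) backbone/neck feature map
    structure_gt = paddle.randint(
        0, 30, shape=[2, 8 + 1], dtype="int64")  # one token id per decode step
    out = head([fea], targets=[structure_gt])
    print(out["structure_probs"].shape)  # [2, 9, 30]
    print(out["loc_preds"].shape)  # [2, 9, 4]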