# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
from paddle import nn, ParamAttr
from paddle.nn import functional as F
import numpy as np
from .self_attention import WrapEncoderForFeature
from .self_attention import WrapEncoder
from paddle.static import Program
from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN
from collections import OrderedDict

gradient_clip = 10


class PVAM(nn.Layer):
    """Parallel Visual Attention Module.

    Runs a transformer encoder over the flattened backbone feature map,
    then attends over the encoded sequence once per reading position to
    produce one visual feature per character slot.
    """

    def __init__(self, in_channels, char_num, max_text_length, num_heads,
                 num_encoder_tus, hidden_dims):
        super(PVAM, self).__init__()
        self.char_num = char_num
        self.max_length = max_text_length
        self.num_heads = num_heads
        self.num_encoder_TUs = num_encoder_tus
        self.hidden_dims = hidden_dims
        # Transformer encoder over the visual feature sequence
        t = 256
        c = 512
        self.wrap_encoder_for_feature = WrapEncoderForFeature(
            src_vocab_size=1,
            max_length=t,
            n_layer=self.num_encoder_TUs,
            n_head=self.num_heads,
            d_key=int(self.hidden_dims / self.num_heads),
            d_value=int(self.hidden_dims / self.num_heads),
            d_model=self.hidden_dims,
            d_inner_hid=self.hidden_dims,
            prepostprocess_dropout=0.1,
            attention_dropout=0.1,
            relu_dropout=0.1,
            preprocess_cmd="n",
            postprocess_cmd="da",
            weight_sharing=True)

        # PVAM: additive attention between reading-position embeddings
        # and the encoded visual features
        self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1)
        self.fc0 = paddle.nn.Linear(
            in_features=in_channels, out_features=in_channels)
        self.emb = paddle.nn.Embedding(
            num_embeddings=self.max_length, embedding_dim=in_channels)
        self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2)
        self.fc1 = paddle.nn.Linear(
            in_features=in_channels, out_features=1, bias_attr=False)

    def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
        # inputs: [b, c, h, w] backbone feature map -> [b, h*w, c] sequence
        b, c, h, w = inputs.shape
        conv_features = paddle.reshape(inputs, shape=[-1, c, h * w])
        conv_features = paddle.transpose(conv_features, perm=[0, 2, 1])

        # transformer encoder
        b, t, c = conv_features.shape
        enc_inputs = [conv_features, encoder_word_pos, None]
        word_features = self.wrap_encoder_for_feature(enc_inputs)

        # pvam: score every (reading position, visual position) pair
        b, t, c = word_features.shape
        word_features = self.fc0(word_features)
        word_features_ = paddle.reshape(word_features, [-1, 1, t, c])
        word_features_ = paddle.tile(word_features_,
                                     [1, self.max_length, 1, 1])
        word_pos_feature = self.emb(gsrm_word_pos)
        word_pos_feature_ = paddle.reshape(word_pos_feature,
                                           [-1, self.max_length, 1, c])
        word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1])
        y = word_pos_feature_ + word_features_
        y = F.tanh(y)
        attention_weight = self.fc1(y)
        attention_weight = paddle.reshape(
            attention_weight, shape=[-1, self.max_length, t])
        attention_weight = F.softmax(attention_weight, axis=-1)
        pvam_features = paddle.matmul(attention_weight,
                                      word_features)  # [b, max_length, c]
        return pvam_features
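

# Usage sketch for PVAM on its own (illustrative, not part of the original
# file). Sizes assume the stock SRN config: a 512-channel 8x32 backbone map
# (so t = h*w = 256), max_text_length = 25, batch size 2:
#
#   pvam = PVAM(in_channels=512, char_num=38, max_text_length=25,
#               num_heads=8, num_encoder_tus=2, hidden_dims=512)
#   feat_map = paddle.rand([2, 512, 8, 32])
#   enc_pos = paddle.tile(
#       paddle.arange(256, dtype="int64").reshape([1, 256, 1]), [2, 1, 1])
#   word_pos = paddle.tile(
#       paddle.arange(25, dtype="int64").reshape([1, 25, 1]), [2, 1, 1])
#   out = pvam(feat_map, enc_pos, word_pos)
#   # out: [2, 25, 512] -- one attended visual feature per character slot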


class GSRM(nn.Layer):
    """Global Semantic Reasoning Module.

    Decodes a first-pass character distribution from the PVAM features,
    then refines it with two transformers: one reading the (shifted)
    sequence forward, the other backward.
    """

    def __init__(self, in_channels, char_num, max_text_length, num_heads,
                 num_encoder_tus, num_decoder_tus, hidden_dims):
        super(GSRM, self).__init__()
        self.char_num = char_num
        self.max_length = max_text_length
        self.num_heads = num_heads
        self.num_encoder_TUs = num_encoder_tus
        self.num_decoder_TUs = num_decoder_tus
        self.hidden_dims = hidden_dims

        self.fc0 = paddle.nn.Linear(
            in_features=in_channels, out_features=self.char_num)
        self.wrap_encoder0 = WrapEncoder(
            src_vocab_size=self.char_num + 1,
            max_length=self.max_length,
            n_layer=self.num_decoder_TUs,
            n_head=self.num_heads,
            d_key=int(self.hidden_dims / self.num_heads),
            d_value=int(self.hidden_dims / self.num_heads),
            d_model=self.hidden_dims,
            d_inner_hid=self.hidden_dims,
            prepostprocess_dropout=0.1,
            attention_dropout=0.1,
            relu_dropout=0.1,
            preprocess_cmd="n",
            postprocess_cmd="da",
            weight_sharing=True)
        self.wrap_encoder1 = WrapEncoder(
            src_vocab_size=self.char_num + 1,
            max_length=self.max_length,
            n_layer=self.num_decoder_TUs,
            n_head=self.num_heads,
            d_key=int(self.hidden_dims / self.num_heads),
            d_value=int(self.hidden_dims / self.num_heads),
            d_model=self.hidden_dims,
            d_inner_hid=self.hidden_dims,
            prepostprocess_dropout=0.1,
            attention_dropout=0.1,
            relu_dropout=0.1,
            preprocess_cmd="n",
            postprocess_cmd="da",
            weight_sharing=True)

        # Project hidden states back to the vocabulary by reusing the
        # (transposed) input embedding table, i.e. weight tying.
        self.mul = lambda x: paddle.matmul(
            x=x,
            y=self.wrap_encoder0.prepare_decoder.emb0.weight,
            transpose_y=True)

    def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
                gsrm_slf_attn_bias2):
        # ===== GSRM Visual-to-semantic embedding block =====
        b, t, c = inputs.shape
        pvam_features = paddle.reshape(inputs, [-1, c])
        word_out = self.fc0(pvam_features)
        word_ids = paddle.argmax(F.softmax(word_out), axis=1)
        word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1])

        # ===== GSRM Semantic reasoning block =====
        # This block is implemented with bi-transformers: gsrm_feature1 is
        # the forward stream, gsrm_feature2 the backward one.
        pad_idx = self.char_num

        # Forward stream: shift right, padding the first slot with pad_idx.
        word1 = paddle.cast(word_ids, "float32")
        word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC")
        word1 = paddle.cast(word1, "int64")
        word1 = word1[:, :-1, :]
        # Backward stream: the unshifted ids, realigned after encoding.
        word2 = word_ids

        enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
        enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]

        gsrm_feature1 = self.wrap_encoder0(enc_inputs_1)
        gsrm_feature2 = self.wrap_encoder1(enc_inputs_2)

        # Shift the backward features left by one so both streams predict
        # the same target position, then fuse them by addition.
        gsrm_feature2 = F.pad(gsrm_feature2, [0, 1],
                              value=0.,
                              data_format="NLC")
        gsrm_feature2 = gsrm_feature2[:, 1:, :]
        gsrm_features = gsrm_feature1 + gsrm_feature2

        gsrm_out = self.mul(gsrm_features)
        b, t, c = gsrm_out.shape
        gsrm_out = paddle.reshape(gsrm_out, [-1, c])

        return gsrm_features, word_out, gsrm_out
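

# Usage sketch for GSRM on its own (illustrative, not part of the original
# file). Zero biases stand in for the triangular self-attention masks the
# real data pipeline provides; word_pos is as in the PVAM sketch above:
#
#   gsrm = GSRM(in_channels=512, char_num=38, max_text_length=25,
#               num_heads=8, num_encoder_tus=2, num_decoder_tus=4,
#               hidden_dims=512)
#   bias = paddle.zeros([2, 8, 25, 25])
#   gsrm_feature, word_out, gsrm_out = gsrm(
#       paddle.rand([2, 25, 512]), word_pos, bias, bias)
#   # gsrm_feature: [2, 25, 512]; word_out and gsrm_out are flattened
#   # per-position logits (gsrm_out also covers the padding class).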


class VSFD(nn.Layer):
    """Visual-Semantic Fusion Decoder.

    Learns a per-channel gate from the concatenated visual (PVAM) and
    semantic (GSRM) features, and outputs fused character logits.
    """

    def __init__(self, in_channels=512, pvam_ch=512, char_num=38):
        super(VSFD, self).__init__()
        self.char_num = char_num
        self.fc0 = paddle.nn.Linear(
            in_features=in_channels * 2, out_features=pvam_ch)
        self.fc1 = paddle.nn.Linear(
            in_features=pvam_ch, out_features=self.char_num)

    def forward(self, pvam_feature, gsrm_feature):
        b, t, c1 = pvam_feature.shape
        b, t, c2 = gsrm_feature.shape
        combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2)
        img_comb_feature_ = paddle.reshape(
            combine_feature_, shape=[-1, c1 + c2])
        # Sigmoid gate: how much to trust the visual vs. semantic stream.
        img_comb_feature_map = self.fc0(img_comb_feature_)
        img_comb_feature_map = F.sigmoid(img_comb_feature_map)
        img_comb_feature_map = paddle.reshape(
            img_comb_feature_map, shape=[-1, t, c1])
        combine_feature = img_comb_feature_map * pvam_feature + (
            1.0 - img_comb_feature_map) * gsrm_feature
        img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1])
        out = self.fc1(img_comb_feature)
        return out
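

# The fusion above is a learned, per-channel convex combination:
#
#   g = sigmoid(fc0([v; s]))        # gate in (0, 1)
#   f = g * v + (1 - g) * s         # fused feature
#   logits = fc1(f)
#
# so a gate near 1 trusts the visual (PVAM) stream and a gate near 0 the
# semantic (GSRM) stream for that channel.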


class SRNHead(nn.Layer):
    """SRN recognition head: PVAM -> GSRM -> VSFD.

    `targets` must carry, as its last four entries, the auxiliary tensors
    produced by the data pipeline: the encoder and word position indices
    and the two GSRM self-attention bias masks.
    """

    def __init__(self, in_channels, out_channels, max_text_length, num_heads,
                 num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs):
        super(SRNHead, self).__init__()
        self.char_num = out_channels
        self.max_length = max_text_length
        self.num_heads = num_heads
        self.num_encoder_TUs = num_encoder_TUs
        self.num_decoder_TUs = num_decoder_TUs
        self.hidden_dims = hidden_dims

        self.pvam = PVAM(
            in_channels=in_channels,
            char_num=self.char_num,
            max_text_length=self.max_length,
            num_heads=self.num_heads,
            num_encoder_tus=self.num_encoder_TUs,
            hidden_dims=self.hidden_dims)

        self.gsrm = GSRM(
            in_channels=in_channels,
            char_num=self.char_num,
            max_text_length=self.max_length,
            num_heads=self.num_heads,
            num_encoder_tus=self.num_encoder_TUs,
            num_decoder_tus=self.num_decoder_TUs,
            hidden_dims=self.hidden_dims)
        self.vsfd = VSFD(in_channels=in_channels, char_num=self.char_num)

        # Share the character embedding between the two GSRM transformers.
        self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0

    def forward(self, inputs, targets=None):
        others = targets[-4:]
        encoder_word_pos = others[0]
        gsrm_word_pos = others[1]
        gsrm_slf_attn_bias1 = others[2]
        gsrm_slf_attn_bias2 = others[3]

        pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos)

        gsrm_feature, word_out, gsrm_out = self.gsrm(
            pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1,
            gsrm_slf_attn_bias2)

        final_out = self.vsfd(pvam_feature, gsrm_feature)
        if not self.training:
            final_out = F.softmax(final_out, axis=1)

        _, decoded_out = paddle.topk(final_out, k=1)

        predicts = OrderedDict([
            ('predict', final_out),
            ('pvam_feature', pvam_feature),
            ('decoded_out', decoded_out),
            ('word_out', word_out),
            ('gsrm_out', gsrm_out),
        ])

        return predicts
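

if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original file).
    # Shapes follow the stock SRN config: 8x32 feature map (t = 256),
    # max_text_length = 25, 38 classes, 8 heads. Zero attention biases
    # stand in for the triangular masks built by the real data pipeline.
    b, max_len, heads = 2, 25, 8
    head = SRNHead(
        in_channels=512,
        out_channels=38,
        max_text_length=max_len,
        num_heads=heads,
        num_encoder_TUs=2,
        num_decoder_TUs=4,
        hidden_dims=512)
    head.eval()
    feat_map = paddle.rand([b, 512, 8, 32])
    enc_pos = paddle.tile(
        paddle.arange(256, dtype="int64").reshape([1, 256, 1]), [b, 1, 1])
    word_pos = paddle.tile(
        paddle.arange(max_len, dtype="int64").reshape([1, max_len, 1]),
        [b, 1, 1])
    bias = paddle.zeros([b, heads, max_len, max_len])
    preds = head(feat_map, targets=[enc_pos, word_pos, bias, bias])
    print(preds['predict'].shape)  # expected: [b * max_len, 38]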