rec_sar_head.py
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/sar_encoder.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/sar_decoder.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import paddle
from paddle import ParamAttr
import paddle.nn as nn
import paddle.nn.functional as F


class SAREncoder(nn.Layer):
    """
    Args:
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        enc_drop_rnn (float): Dropout probability of RNN layer in encoder.
        enc_gru (bool): If True, use GRU, else LSTM in encoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        mask (bool): If True, mask padding in RNN sequence.
    """

    def __init__(self,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 d_model=512,
                 d_enc=512,
                 mask=True,
                 **kwargs):
        super().__init__()
        assert isinstance(enc_bi_rnn, bool)
        assert isinstance(enc_drop_rnn, (int, float))
        assert 0 <= enc_drop_rnn < 1.0
        assert isinstance(enc_gru, bool)
        assert isinstance(d_model, int)
        assert isinstance(d_enc, int)
        assert isinstance(mask, bool)

        self.enc_bi_rnn = enc_bi_rnn
        self.enc_drop_rnn = enc_drop_rnn
        self.mask = mask

        # LSTM/GRU encoder over the width axis of the feature map
        if enc_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'
        kwargs = dict(
            input_size=d_model,
            hidden_size=d_enc,
            num_layers=2,
            time_major=False,
            dropout=enc_drop_rnn,
            direction=direction)
        if enc_gru:
            self.rnn_encoder = nn.GRU(**kwargs)
        else:
            self.rnn_encoder = nn.LSTM(**kwargs)

        # global feature transformation
        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        self.linear = nn.Linear(encoder_rnn_out_size, encoder_rnn_out_size)

    def forward(self, feat, img_metas=None):
        if img_metas is not None:
            assert len(img_metas[0]) == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        h_feat = feat.shape[2]  # bsz * c * h * w
        # collapse the height dimension, keeping one feature vector per column
        feat_v = F.max_pool2d(
            feat, kernel_size=(h_feat, 1), stride=1, padding=0)
        feat_v = feat_v.squeeze(2)  # bsz * C * W
        feat_v = paddle.transpose(feat_v, perm=[0, 2, 1])  # bsz * W * C
        holistic_feat = self.rnn_encoder(feat_v)[0]  # bsz * T * C

        if valid_ratios is not None:
            # take the hidden state at the last valid (non-padded) time step
            valid_hf = []
            T = paddle.shape(holistic_feat)[1]
            for i in range(paddle.shape(valid_ratios)[0]):
                valid_step = paddle.minimum(
                    T, paddle.ceil(valid_ratios[i] * T).astype('int32')) - 1
                valid_hf.append(holistic_feat[i, valid_step, :])
            valid_hf = paddle.stack(valid_hf, axis=0)
        else:
            valid_hf = holistic_feat[:, -1, :]  # bsz * C
        holistic_feat = self.linear(valid_hf)  # bsz * C

        return holistic_feat
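

# Illustrative usage (added for this write-up; not part of the upstream file).
# A minimal smoke test for SAREncoder, assuming a backbone feature map with
# 512 channels at height 6 and width 40; the helper name `_demo_sar_encoder`
# and all shapes below are demo assumptions only.
def _demo_sar_encoder():
    encoder = SAREncoder(d_model=512, d_enc=512)
    encoder.eval()  # disable RNN dropout for a deterministic pass
    feat = paddle.rand([2, 512, 6, 40])  # bsz * C * H * W
    # Height is max-pooled away, the 2-layer LSTM runs over the W time steps,
    # and the last hidden state is linearly projected to one vector per image.
    holistic = encoder(feat)
    assert holistic.shape == [2, 512]  # bsz * C
    return holistic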


class BaseDecoder(nn.Layer):
    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                img_metas=None,
                train_mode=True):
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, img_metas)
        return self.forward_test(feat, out_enc, img_metas)


class ParallelSARDecoder(BaseDecoder):
    """
    Args:
        out_channels (int): Output class number.
        enc_bi_rnn (bool): If True, use bidirectional RNN in encoder.
        dec_bi_rnn (bool): If True, use bidirectional RNN in decoder.
        dec_drop_rnn (float): Dropout probability of RNN layer in decoder.
        dec_gru (bool): If True, use GRU, else LSTM in decoder.
        d_model (int): Dim of channels from backbone.
        d_enc (int): Dim of encoder RNN layer.
        d_k (int): Dim of channels of attention module.
        pred_dropout (float): Dropout probability of prediction layer.
        max_text_length (int): Maximum sequence length for decoding.
        mask (bool): If True, mask padding in feature map.
        pred_concat (bool): If True, concat glimpse feature from
            attention with holistic feature and hidden state.
    """

    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            enc_bi_rnn=False,
            dec_bi_rnn=False,
            dec_drop_rnn=0.0,
            dec_gru=False,
            d_model=512,
            d_enc=512,
            d_k=64,
            pred_dropout=0.1,
            max_text_length=30,
            mask=True,
            pred_concat=True,
            **kwargs):
        super().__init__()

        self.num_classes = out_channels
        self.enc_bi_rnn = enc_bi_rnn
        self.d_k = d_k
        self.start_idx = out_channels - 2
        self.padding_idx = out_channels - 1
        self.max_seq_len = max_text_length
        self.mask = mask
        self.pred_concat = pred_concat

        encoder_rnn_out_size = d_enc * (int(enc_bi_rnn) + 1)
        decoder_rnn_out_size = encoder_rnn_out_size * (int(dec_bi_rnn) + 1)

        # 2D attention layer
        self.conv1x1_1 = nn.Linear(decoder_rnn_out_size, d_k)
        self.conv3x3_1 = nn.Conv2D(
            d_model, d_k, kernel_size=3, stride=1, padding=1)
        self.conv1x1_2 = nn.Linear(d_k, 1)

        # Decoder RNN layer
        if dec_bi_rnn:
            direction = 'bidirectional'
        else:
            direction = 'forward'
        kwargs = dict(
            input_size=encoder_rnn_out_size,
            hidden_size=encoder_rnn_out_size,
            num_layers=2,
            time_major=False,
            dropout=dec_drop_rnn,
            direction=direction)
        if dec_gru:
            self.rnn_decoder = nn.GRU(**kwargs)
        else:
            self.rnn_decoder = nn.LSTM(**kwargs)

        # Decoder input embedding
        self.embedding = nn.Embedding(
            self.num_classes,
            encoder_rnn_out_size,
            padding_idx=self.padding_idx)

        # Prediction layer
        self.pred_dropout = nn.Dropout(pred_dropout)
        pred_num_classes = self.num_classes - 1
        if pred_concat:
            fc_in_channel = decoder_rnn_out_size + d_model + encoder_rnn_out_size
        else:
            fc_in_channel = d_model
        self.prediction = nn.Linear(fc_in_channel, pred_num_classes)

    def _2d_attention(self,
                      decoder_input,
                      feat,
                      holistic_feat,
                      valid_ratios=None):
        y = self.rnn_decoder(decoder_input)[0]
        # y: bsz * (seq_len + 1) * hidden_size

        attn_query = self.conv1x1_1(y)  # bsz * (seq_len + 1) * attn_size
        bsz, seq_len, attn_size = attn_query.shape
        attn_query = paddle.unsqueeze(attn_query, axis=[3, 4])
        # (bsz, seq_len + 1, attn_size, 1, 1)

        attn_key = self.conv3x3_1(feat)
        # bsz * attn_size * h * w
        attn_key = attn_key.unsqueeze(1)
        # bsz * 1 * attn_size * h * w

        # additive attention: broadcast each query over all spatial positions
        attn_weight = paddle.tanh(paddle.add(attn_key, attn_query))
        # bsz * (seq_len + 1) * attn_size * h * w
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 3, 4, 2])
        # bsz * (seq_len + 1) * h * w * attn_size
        attn_weight = self.conv1x1_2(attn_weight)
        # bsz * (seq_len + 1) * h * w * 1
        bsz, T, h, w, c = paddle.shape(attn_weight)
        assert c == 1

        if valid_ratios is not None:
            # cal mask of attention weight
            for i in range(paddle.shape(valid_ratios)[0]):
                valid_width = paddle.minimum(
                    w, paddle.ceil(valid_ratios[i] * w).astype("int32"))
                if valid_width < w:
                    attn_weight[i, :, :, valid_width:, :] = float('-inf')

        attn_weight = paddle.reshape(attn_weight, [bsz, T, -1])
        attn_weight = F.softmax(attn_weight, axis=-1)
        attn_weight = paddle.reshape(attn_weight, [bsz, T, h, w, c])
        attn_weight = paddle.transpose(attn_weight, perm=[0, 1, 4, 2, 3])
        # attn_weight: bsz * T * c * h * w
        # feat: bsz * c * h * w
        attn_feat = paddle.sum(paddle.multiply(feat.unsqueeze(1), attn_weight),
                               (3, 4),
                               keepdim=False)
        # bsz * (seq_len + 1) * C

        # Linear transformation
        if self.pred_concat:
            hf_c = holistic_feat.shape[-1]
            holistic_feat = paddle.expand(
                holistic_feat, shape=[bsz, seq_len, hf_c])
            y = self.prediction(paddle.concat((y, attn_feat, holistic_feat), 2))
        else:
            y = self.prediction(attn_feat)
        # bsz * (seq_len + 1) * num_classes
        if self.train_mode:
            y = self.pred_dropout(y)

        return y

    def forward_train(self, feat, out_enc, label, img_metas):
        '''
        img_metas: [label, valid_ratio]
        '''
        if img_metas is not None:
            assert paddle.shape(img_metas[0])[0] == paddle.shape(feat)[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        lab_embedding = self.embedding(label)
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        in_dec = paddle.concat((out_enc, lab_embedding), axis=1)
        # bsz * (seq_len + 1) * C
        out_dec = self._2d_attention(
            in_dec, feat, out_enc, valid_ratios=valid_ratios)

        return out_dec[:, 1:, :]  # bsz * seq_len * num_classes

    def forward_test(self, feat, out_enc, img_metas):
        if img_metas is not None:
            assert len(img_metas[0]) == feat.shape[0]

        valid_ratios = None
        if img_metas is not None and self.mask:
            valid_ratios = img_metas[-1]

        seq_len = self.max_seq_len
        bsz = feat.shape[0]

        start_token = paddle.full(
            (bsz, ), fill_value=self.start_idx, dtype='int64')
        # bsz
        start_token = self.embedding(start_token)
        # bsz * emb_dim
        emb_dim = start_token.shape[1]
        start_token = start_token.unsqueeze(1)
        start_token = paddle.expand(start_token, shape=[bsz, seq_len, emb_dim])
        # bsz * seq_len * emb_dim
        out_enc = out_enc.unsqueeze(1)
        # bsz * 1 * emb_dim
        decoder_input = paddle.concat((out_enc, start_token), axis=1)
        # bsz * (seq_len + 1) * emb_dim

        outputs = []
        for i in range(1, seq_len + 1):
            decoder_output = self._2d_attention(
                decoder_input, feat, out_enc, valid_ratios=valid_ratios)
            char_output = decoder_output[:, i, :]  # bsz * num_classes
            char_output = F.softmax(char_output, -1)
            outputs.append(char_output)
            # greedy decoding: feed the argmax prediction back as the next input
            max_idx = paddle.argmax(char_output, axis=1, keepdim=False)
            char_embedding = self.embedding(max_idx)  # bsz * emb_dim
            if i < seq_len:
                decoder_input[:, i + 1, :] = char_embedding

        outputs = paddle.stack(outputs, 1)  # bsz * seq_len * num_classes

        return outputs
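

# Illustrative sketch (added for this write-up; not upstream code): wiring the
# encoder and decoder together for greedy inference. out_channels=40 (e.g. 37
# characters + unknown + start + padding) and the feature-map shape are demo
# assumptions.
def _demo_parallel_sar_decoder():
    bsz, d_model, h, w = 2, 512, 6, 40
    encoder = SAREncoder(d_model=d_model, d_enc=512)
    decoder = ParallelSARDecoder(
        out_channels=40, d_model=d_model, d_enc=512, max_text_length=30)
    encoder.eval()
    decoder.eval()
    feat = paddle.rand([bsz, d_model, h, w])
    holistic = encoder(feat)  # bsz * C
    # With train_mode=False the decoder starts from the start token and feeds
    # its own argmax prediction back in at every step (forward_test above).
    probs = decoder(feat, holistic, label=None, img_metas=None,
                    train_mode=False)
    # the prediction layer scores num_classes - 1 classes (padding excluded)
    assert probs.shape == [bsz, 30, 39]
    return probs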


class SARHead(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 enc_dim=512,
                 max_text_length=30,
                 enc_bi_rnn=False,
                 enc_drop_rnn=0.1,
                 enc_gru=False,
                 dec_bi_rnn=False,
                 dec_drop_rnn=0.0,
                 dec_gru=False,
                 d_k=512,
                 pred_dropout=0.1,
                 pred_concat=True,
                 **kwargs):
        super(SARHead, self).__init__()

        # encoder module
        self.encoder = SAREncoder(
            enc_bi_rnn=enc_bi_rnn,
            enc_drop_rnn=enc_drop_rnn,
            enc_gru=enc_gru,
            d_model=in_channels,
            d_enc=enc_dim)

        # decoder module
        self.decoder = ParallelSARDecoder(
            out_channels=out_channels,
            enc_bi_rnn=enc_bi_rnn,
            dec_bi_rnn=dec_bi_rnn,
            dec_drop_rnn=dec_drop_rnn,
            dec_gru=dec_gru,
            d_model=in_channels,
            d_enc=enc_dim,
            d_k=d_k,
            pred_dropout=pred_dropout,
            max_text_length=max_text_length,
            pred_concat=pred_concat)

    def forward(self, feat, targets=None):
        '''
        targets: [label, valid_ratio]
        '''
        holistic_feat = self.encoder(feat, targets)  # bsz * c

        if self.training:
            label = targets[0]  # label
            final_out = self.decoder(
                feat, holistic_feat, label, img_metas=targets)
        else:
            final_out = self.decoder(
                feat,
                holistic_feat,
                label=None,
                img_metas=targets,
                train_mode=False)
            # (bsz, seq_len, num_classes)

        return final_out
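

# End-to-end sketch (added for this write-up; not upstream code). The label
# tensor, valid ratios, and vocabulary size are fabricated demo values; in
# PaddleOCR the `targets` list is produced by the SAR data pipeline.
def _demo_sar_head():
    bsz, in_channels = 2, 512
    head = SARHead(in_channels=in_channels, out_channels=40,
                   max_text_length=30)
    feat = paddle.rand([bsz, in_channels, 6, 40])

    # Inference: targets=None, so the decoder free-runs from the start token.
    head.eval()
    probs = head(feat)  # (bsz, max_text_length, out_channels - 1)

    # Training: targets = [label, valid_ratio]. Here the label is simply
    # filled with padding_idx = out_channels - 1 to exercise the shapes; a
    # real label holds character indices from the data pipeline.
    head.train()
    label = paddle.full([bsz, 31], fill_value=40 - 1, dtype='int64')
    valid_ratio = paddle.ones([bsz], dtype='float32')
    logits = head(feat, targets=[label, valid_ratio])  # (bsz, 31, 39)
    return probs, logits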