
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is referred from:
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/channel_reduction_encoder.py
https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/decoders/robust_scanner_decoder.py
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class BaseDecoder(nn.Layer):
    def __init__(self, **kwargs):
        super().__init__()

    def forward_train(self, feat, out_enc, targets, img_metas):
        raise NotImplementedError

    def forward_test(self, feat, out_enc, img_metas):
        raise NotImplementedError

    def forward(self,
                feat,
                out_enc,
                label=None,
                valid_ratios=None,
                word_positions=None,
                train_mode=True):
        self.train_mode = train_mode

        if train_mode:
            return self.forward_train(feat, out_enc, label, valid_ratios,
                                      word_positions)
        return self.forward_test(feat, out_enc, valid_ratios, word_positions)


class ChannelReductionEncoder(nn.Layer):
    """Change the channel number with a one-by-one convolutional layer.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super(ChannelReductionEncoder, self).__init__()

        self.layer = nn.Conv2D(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            weight_attr=nn.initializer.XavierNormal())

    def forward(self, feat):
        """
        Args:
            feat (Tensor): Image features with the shape of
                :math:`(N, C_{in}, H, W)`.

        Returns:
            Tensor: A tensor of shape :math:`(N, C_{out}, H, W)`.
        """
        return self.layer(feat)
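

# A minimal usage sketch for ChannelReductionEncoder; the batch size and
# feature-map shape are illustrative assumptions, not values fixed by this
# file:
#
#     encoder = ChannelReductionEncoder(in_channels=512, out_channels=128)
#     feat = paddle.rand([2, 512, 8, 32])  # (N, C_in, H, W) backbone output
#     out_enc = encoder(feat)              # (N, 128, 8, 32), spatial size kept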


def masked_fill(x, mask, value):
    y = paddle.full(x.shape, value, x.dtype)
    return paddle.where(mask, y, x)
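

# ``masked_fill`` mirrors ``torch.Tensor.masked_fill``: entries of ``x`` where
# ``mask`` is True are replaced by ``value``. A tiny sketch with assumed
# inputs:
#
#     x = paddle.to_tensor([1.0, 2.0, 3.0])
#     mask = paddle.to_tensor([False, True, False])
#     masked_fill(x, mask, float('-inf'))  # -> [1.0, -inf, 3.0]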


class DotProductAttentionLayer(nn.Layer):
    def __init__(self, dim_model=None):
        super().__init__()

        self.scale = dim_model**-0.5 if dim_model is not None else 1.

    def forward(self, query, key, value, h, w, valid_ratios=None):
        query = paddle.transpose(query, (0, 2, 1))
        logits = paddle.matmul(query, key) * self.scale
        n, c, t = logits.shape

        # reshape to (n, c, h, w)
        logits = paddle.reshape(logits, [n, c, h, w])
        if valid_ratios is not None:
            # compute the attention-weight mask from the valid width of
            # each image
            for i, valid_ratio in enumerate(valid_ratios):
                valid_width = min(w, int(w * valid_ratio + 0.5))
                if valid_width < w:
                    logits[i, :, :, valid_width:] = float('-inf')

        # reshape back to (n, c, t)
        logits = paddle.reshape(logits, [n, c, t])
        weights = F.softmax(logits, axis=2)
        value = paddle.transpose(value, (0, 2, 1))
        glimpse = paddle.matmul(weights, value)
        glimpse = paddle.transpose(glimpse, (0, 2, 1))
        return glimpse
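

# Shape walkthrough for DotProductAttentionLayer (a sketch; sizes are assumed
# for illustration). With T query steps and H * W keys:
#
#     attn = DotProductAttentionLayer()
#     query = paddle.rand([2, 128, 40])   # (N, D_m, T)
#     key = paddle.rand([2, 128, 256])    # (N, D_m, H * W), here 8 * 32
#     value = paddle.rand([2, 512, 256])  # (N, D_v, H * W)
#     glimpse = attn(query, key, value, h=8, w=32)  # (N, D_v, T) = (2, 512, 40)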


class SequenceAttentionDecoder(BaseDecoder):
    """Sequence attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        padding_idx (int): The index of `<PAD>`.
        dropout (float): Dropout rate.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 dropout=0,
                 return_feature=False,
                 encode_value=False):
        super().__init__()

        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.max_seq_len = max_seq_len
        self.start_idx = start_idx
        self.mask = mask

        self.embedding = nn.Embedding(
            self.num_classes, self.dim_model, padding_idx=padding_idx)

        self.sequence_layer = nn.LSTM(
            input_size=dim_model,
            hidden_size=dim_model,
            num_layers=rnn_layers,
            time_major=False,
            dropout=dropout)

        self.attention_layer = DotProductAttentionLayer()

        self.prediction = None
        if not self.return_feature:
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def forward_train(self, feat, out_enc, targets, valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (Tensor): A tensor of shape :math:`(N, T)`. Each element
                is the index of a character.
            valid_ratios (Tensor): valid length ratio of img.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        tgt_embedding = self.embedding(targets)

        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, len_q, c_q = tgt_embedding.shape
        assert c_q == self.dim_model
        assert len_q <= self.max_seq_len

        query, _ = self.sequence_layer(tgt_embedding)
        query = paddle.transpose(query, (0, 2, 1))
        key = paddle.reshape(out_enc, [n, c_enc, h * w])
        if self.encode_value:
            value = key
        else:
            value = paddle.reshape(feat, [n, c_feat, h * w])

        attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
        attn_out = paddle.transpose(attn_out, (0, 2, 1))

        if self.return_feature:
            return attn_out

        out = self.prediction(attn_out)
        return out

    def forward_test(self, feat, out_enc, valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): valid length ratio of img.

        Returns:
            Tensor: The output logit sequence tensor of shape
            :math:`(N, T, C-1)`.
        """
        seq_len = self.max_seq_len
        batch_size = feat.shape[0]

        decode_sequence = (paddle.ones(
            (batch_size, seq_len), dtype='int64') * self.start_idx)

        outputs = []
        for i in range(seq_len):
            step_out = self.forward_test_step(feat, out_enc, decode_sequence,
                                              i, valid_ratios)
            outputs.append(step_out)
            max_idx = paddle.argmax(step_out, axis=1, keepdim=False)
            if i < seq_len - 1:
                decode_sequence[:, i + 1] = max_idx

        outputs = paddle.stack(outputs, 1)

        return outputs

    def forward_test_step(self, feat, out_enc, decode_sequence, current_step,
                          valid_ratios):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            decode_sequence (Tensor): Shape :math:`(N, T)`. The tensor that
                stores history decoding result.
            current_step (int): Current decoding step.
            valid_ratios (Tensor): valid length ratio of img.

        Returns:
            Tensor: Shape :math:`(N, C-1)`. The logit tensor of predicted
            tokens at current time step.
        """
        embed = self.embedding(decode_sequence)

        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, _, c_q = embed.shape
        assert c_q == self.dim_model

        query, _ = self.sequence_layer(embed)
        query = paddle.transpose(query, (0, 2, 1))
        key = paddle.reshape(out_enc, [n, c_enc, h * w])
        if self.encode_value:
            value = key
        else:
            value = paddle.reshape(feat, [n, c_feat, h * w])

        # [n, c, l]
        attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
        out = attn_out[:, :, current_step]

        if self.return_feature:
            return out

        out = self.prediction(out)
        # Paddle's softmax takes ``axis``, not torch-style ``dim``.
        out = F.softmax(out, axis=-1)

        return out
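

# Test-time decoding in SequenceAttentionDecoder is greedy: the argmax of step
# ``i`` is written into ``decode_sequence`` at slot ``i + 1`` and re-embedded
# at the next step. A standalone sketch with assumed shapes (inside
# RobustScanner the class is used with ``return_feature=True`` instead):
#
#     decoder = SequenceAttentionDecoder(
#         num_classes=93, dim_input=512, dim_model=128)
#     decoder.eval()
#     feat = paddle.rand([2, 512, 8, 32])
#     out_enc = paddle.rand([2, 128, 8, 32])
#     probs = decoder.forward_test(feat, out_enc, None)  # (2, 40, 92)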


class PositionAwareLayer(nn.Layer):
    def __init__(self, dim_model, rnn_layers=2):
        super().__init__()

        self.dim_model = dim_model
        self.rnn = nn.LSTM(
            input_size=dim_model,
            hidden_size=dim_model,
            num_layers=rnn_layers,
            time_major=False)

        self.mixer = nn.Sequential(
            nn.Conv2D(
                dim_model, dim_model, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2D(
                dim_model, dim_model, kernel_size=3, stride=1, padding=1))

    def forward(self, img_feature):
        n, c, h, w = img_feature.shape
        rnn_input = paddle.transpose(img_feature, (0, 2, 3, 1))
        rnn_input = paddle.reshape(rnn_input, (n * h, w, c))
        rnn_output, _ = self.rnn(rnn_input)
        rnn_output = paddle.reshape(rnn_output, (n, h, w, c))
        rnn_output = paddle.transpose(rnn_output, (0, 3, 1, 2))
        out = self.mixer(rnn_output)
        return out
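

# PositionAwareLayer runs an LSTM along the width of each feature-map row
# (rows are folded into the batch axis) and then mixes the result with two
# 3x3 convolutions, preserving the input shape. A sketch under assumed sizes:
#
#     layer = PositionAwareLayer(dim_model=128)
#     x = paddle.rand([2, 128, 8, 32])  # (N, C, H, W)
#     y = layer(x)                      # (2, 128, 8, 32)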


class PositionAttentionDecoder(BaseDecoder):
    """Position attention decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        rnn_layers (int): Number of RNN layers.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        return_feature (bool): Return feature or logits as the result.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss.
    """

    def __init__(self,
                 num_classes=None,
                 rnn_layers=2,
                 dim_input=512,
                 dim_model=128,
                 max_seq_len=40,
                 mask=True,
                 return_feature=False,
                 encode_value=False):
        super().__init__()

        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.return_feature = return_feature
        self.encode_value = encode_value
        self.mask = mask

        self.embedding = nn.Embedding(self.max_seq_len + 1, self.dim_model)

        self.position_aware_module = PositionAwareLayer(self.dim_model,
                                                        rnn_layers)

        self.attention_layer = DotProductAttentionLayer()

        self.prediction = None
        if not self.return_feature:
            pred_num_classes = num_classes - 1
            self.prediction = nn.Linear(
                dim_model if encode_value else dim_input, pred_num_classes)

    def _get_position_index(self, length, batch_size):
        position_index_list = []
        for i in range(batch_size):
            position_index = paddle.arange(
                0, end=length, step=1, dtype='int64')
            position_index_list.append(position_index)
        batch_position_index = paddle.stack(position_index_list, axis=0)
        return batch_position_index

    def forward_train(self, feat, out_enc, targets, valid_ratios,
                      position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            targets (Tensor): A tensor of shape :math:`(N, T)`. Each element
                is the index of a character.
            valid_ratios (Tensor): valid length ratio of img.
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it will be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input
        _, len_q = targets.shape
        assert len_q <= self.max_seq_len

        position_out_enc = self.position_aware_module(out_enc)

        query = self.embedding(position_index)
        query = paddle.transpose(query, (0, 2, 1))
        key = paddle.reshape(position_out_enc, (n, c_enc, h * w))
        if self.encode_value:
            value = paddle.reshape(out_enc, (n, c_enc, h * w))
        else:
            value = paddle.reshape(feat, (n, c_feat, h * w))

        attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
        attn_out = paddle.transpose(attn_out, (0, 2, 1))  # [n, len_q, dim_v]

        if self.return_feature:
            return attn_out

        return self.prediction(attn_out)

    def forward_test(self, feat, out_enc, valid_ratios, position_index):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): valid length ratio of img.
            position_index (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)` if
            ``return_feature=False``. Otherwise it would be the hidden feature
            before the prediction projection layer, whose shape is
            :math:`(N, T, D_m)`.
        """
        n, c_enc, h, w = out_enc.shape
        assert c_enc == self.dim_model
        _, c_feat, _, _ = feat.shape
        assert c_feat == self.dim_input

        position_out_enc = self.position_aware_module(out_enc)

        query = self.embedding(position_index)
        query = paddle.transpose(query, (0, 2, 1))
        key = paddle.reshape(position_out_enc, (n, c_enc, h * w))
        if self.encode_value:
            value = paddle.reshape(out_enc, (n, c_enc, h * w))
        else:
            value = paddle.reshape(feat, (n, c_feat, h * w))

        attn_out = self.attention_layer(query, key, value, h, w, valid_ratios)
        attn_out = paddle.transpose(attn_out, (0, 2, 1))  # [n, len_q, dim_v]

        if self.return_feature:
            return attn_out

        return self.prediction(attn_out)
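

# The position branch attends with learned position embeddings as queries, so
# its test-time pass needs no decoding history and yields all T glimpses in a
# single call. A sketch with assumed shapes:
#
#     decoder = PositionAttentionDecoder(
#         num_classes=93, dim_input=512, dim_model=128, max_seq_len=40)
#     decoder.eval()
#     feat = paddle.rand([2, 512, 8, 32])
#     out_enc = paddle.rand([2, 128, 8, 32])
#     positions = paddle.tile(
#         paddle.arange(40, dtype='int64').unsqueeze(0), [2, 1])  # (N, T)
#     logits = decoder.forward_test(feat, out_enc, None, positions)  # (2, 40, 92)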


class RobustScannerFusionLayer(nn.Layer):
    def __init__(self, dim_model, dim=-1):
        super(RobustScannerFusionLayer, self).__init__()

        self.dim_model = dim_model
        self.dim = dim
        self.linear_layer = nn.Linear(dim_model * 2, dim_model * 2)

    def forward(self, x0, x1):
        assert x0.shape == x1.shape
        fusion_input = paddle.concat([x0, x1], self.dim)
        output = self.linear_layer(fusion_input)
        output = F.glu(output, self.dim)
        return output
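

# The fusion layer concatenates the two glimpses, doubles the width with a
# linear projection, and applies a gated linear unit, which halves the channel
# count back to ``dim_model``. A sketch with assumed shapes:
#
#     fusion = RobustScannerFusionLayer(dim_model=512)
#     a = paddle.rand([2, 40, 512])  # hybrid glimpse
#     b = paddle.rand([2, 40, 512])  # position glimpse
#     fused = fusion(a, b)           # (2, 40, 512)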


class RobustScannerDecoder(BaseDecoder):
    """Decoder for RobustScanner.

    RobustScanner: `RobustScanner: Dynamically Enhancing Positional Clues for
    Robust Text Recognition <https://arxiv.org/abs/2007.07542>`_

    Args:
        num_classes (int): Number of output classes :math:`C`.
        dim_input (int): Dimension :math:`D_i` of input vector ``feat``.
        dim_model (int): Dimension :math:`D_m` of the model. Should also be the
            same as encoder output vector ``out_enc``.
        max_seq_len (int): Maximum output sequence length :math:`T`.
        start_idx (int): The index of `<SOS>`.
        mask (bool): Whether to mask input features according to
            ``img_meta['valid_ratio']``.
        padding_idx (int): The index of `<PAD>`.
        encode_value (bool): Whether to use the output of encoder ``out_enc``
            as `value` of attention layer. If False, the original feature
            ``feat`` will be used.

    Warning:
        This decoder will not predict the final class which is assumed to be
        `<PAD>`. Therefore, its output size is always :math:`C - 1`. `<PAD>`
        is also ignored by loss as specified in
        :obj:`mmocr.models.textrecog.recognizer.EncodeDecodeRecognizer`.
    """

    def __init__(self,
                 num_classes=None,
                 dim_input=512,
                 dim_model=128,
                 hybrid_decoder_rnn_layers=2,
                 hybrid_decoder_dropout=0,
                 position_decoder_rnn_layers=2,
                 max_seq_len=40,
                 start_idx=0,
                 mask=True,
                 padding_idx=None,
                 encode_value=False):
        super().__init__()

        self.num_classes = num_classes
        self.dim_input = dim_input
        self.dim_model = dim_model
        self.max_seq_len = max_seq_len
        self.encode_value = encode_value
        self.start_idx = start_idx
        self.padding_idx = padding_idx
        self.mask = mask

        # init hybrid decoder
        self.hybrid_decoder = SequenceAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=hybrid_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_seq_len,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            dropout=hybrid_decoder_dropout,
            encode_value=encode_value,
            return_feature=True)

        # init position decoder
        self.position_decoder = PositionAttentionDecoder(
            num_classes=num_classes,
            rnn_layers=position_decoder_rnn_layers,
            dim_input=dim_input,
            dim_model=dim_model,
            max_seq_len=max_seq_len,
            mask=mask,
            encode_value=encode_value,
            return_feature=True)

        self.fusion_module = RobustScannerFusionLayer(
            self.dim_model if encode_value else dim_input)

        pred_num_classes = num_classes - 1
        self.prediction = nn.Linear(dim_model if encode_value else dim_input,
                                    pred_num_classes)

    def forward_train(self, feat, out_enc, target, valid_ratios,
                      word_positions):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            target (Tensor): A tensor of shape :math:`(N, T)`. Each element
                is the index of a character.
            valid_ratios (Tensor): valid length ratio of img.
            word_positions (Tensor): The position of each word.

        Returns:
            Tensor: A raw logit tensor of shape :math:`(N, T, C-1)`.
        """
        hybrid_glimpse = self.hybrid_decoder.forward_train(
            feat, out_enc, target, valid_ratios)
        position_glimpse = self.position_decoder.forward_train(
            feat, out_enc, target, valid_ratios, word_positions)

        fusion_out = self.fusion_module(hybrid_glimpse, position_glimpse)

        out = self.prediction(fusion_out)

        return out

    def forward_test(self, feat, out_enc, valid_ratios, word_positions):
        """
        Args:
            feat (Tensor): Tensor of shape :math:`(N, D_i, H, W)`.
            out_enc (Tensor): Encoder output of shape
                :math:`(N, D_m, H, W)`.
            valid_ratios (Tensor): valid length ratio of img.
            word_positions (Tensor): The position of each word.

        Returns:
            Tensor: The output logit sequence tensor of shape
            :math:`(N, T, C-1)`.
        """
        seq_len = self.max_seq_len
        batch_size = feat.shape[0]

        decode_sequence = (paddle.ones(
            (batch_size, seq_len), dtype='int64') * self.start_idx)

        position_glimpse = self.position_decoder.forward_test(
            feat, out_enc, valid_ratios, word_positions)

        outputs = []
        for i in range(seq_len):
            hybrid_glimpse_step = self.hybrid_decoder.forward_test_step(
                feat, out_enc, decode_sequence, i, valid_ratios)

            fusion_out = self.fusion_module(hybrid_glimpse_step,
                                            position_glimpse[:, i, :])

            char_out = self.prediction(fusion_out)
            char_out = F.softmax(char_out, -1)
            outputs.append(char_out)
            max_idx = paddle.argmax(char_out, axis=1, keepdim=False)
            if i < seq_len - 1:
                decode_sequence[:, i + 1] = max_idx

        outputs = paddle.stack(outputs, 1)

        return outputs
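

# An end-to-end decoding sketch for RobustScannerDecoder with assumed shapes.
# ``num_classes`` counts `<PAD>`, so the returned logits have
# ``num_classes - 1`` columns:
#
#     decoder = RobustScannerDecoder(
#         num_classes=93, dim_input=512, dim_model=128, max_seq_len=40)
#     decoder.eval()
#     feat = paddle.rand([2, 512, 8, 32])
#     out_enc = paddle.rand([2, 128, 8, 32])
#     positions = paddle.tile(
#         paddle.arange(40, dtype='int64').unsqueeze(0), [2, 1])
#     probs = decoder(feat, out_enc, valid_ratios=None,
#                     word_positions=positions, train_mode=False)  # (2, 40, 92)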


class RobustScannerHead(nn.Layer):
    def __init__(
            self,
            out_channels,  # 90 + unknown + start + padding
            in_channels,
            enc_outchannles=128,
            hybrid_dec_rnn_layers=2,
            hybrid_dec_dropout=0,
            position_dec_rnn_layers=2,
            start_idx=0,
            max_text_length=40,
            mask=True,
            padding_idx=None,
            encode_value=False,
            **kwargs):
        super(RobustScannerHead, self).__init__()

        # encoder module
        self.encoder = ChannelReductionEncoder(
            in_channels=in_channels, out_channels=enc_outchannles)

        # decoder module
        self.decoder = RobustScannerDecoder(
            num_classes=out_channels,
            dim_input=in_channels,
            dim_model=enc_outchannles,
            hybrid_decoder_rnn_layers=hybrid_dec_rnn_layers,
            hybrid_decoder_dropout=hybrid_dec_dropout,
            position_decoder_rnn_layers=position_dec_rnn_layers,
            max_seq_len=max_text_length,
            start_idx=start_idx,
            mask=mask,
            padding_idx=padding_idx,
            encode_value=encode_value)

    def forward(self, inputs, targets=None):
        '''
        targets: [label, valid_ratio, word_positions]
        '''
        out_enc = self.encoder(inputs)
        valid_ratios = None
        word_positions = targets[-1]

        if len(targets) > 1:
            valid_ratios = targets[-2]

        if self.training:
            label = targets[0]  # label
            label = paddle.to_tensor(label, dtype='int64')
            final_out = self.decoder(
                inputs, out_enc, label, valid_ratios, word_positions)
        else:
            final_out = self.decoder(
                inputs,
                out_enc,
                label=None,
                valid_ratios=valid_ratios,
                word_positions=word_positions,
                train_mode=False)
        return final_out
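

if __name__ == "__main__":
    # Minimal inference smoke test (a sketch; the batch size, feature-map
    # shape, and class count are illustrative assumptions, not values fixed
    # by this file).
    head = RobustScannerHead(out_channels=93, in_channels=512)
    head.eval()

    feats = paddle.rand([2, 512, 8, 32])  # backbone output, (N, C, H, W)
    valid_ratios = paddle.ones([2], dtype='float32')  # full-width images
    word_positions = paddle.tile(
        paddle.arange(40, dtype='int64').unsqueeze(0), [2, 1])  # (N, T)

    with paddle.no_grad():
        probs = head(feats, [valid_ratios, word_positions])
    print(probs.shape)  # expected: [2, 40, 92], i.e. (N, T, C - 1)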