rec_spin_att_head.py

# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is adapted from:
https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/sequence_heads/att_head.py
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F


class SPINAttentionHead(nn.Layer):
    def __init__(self, in_channels, out_channels, hidden_size, **kwargs):
        super(SPINAttentionHead, self).__init__()
        self.input_size = in_channels
        self.hidden_size = hidden_size
        self.num_classes = out_channels

        # Attention decoder cell over the visual feature sequence.
        self.attention_cell = AttentionLSTMCell(
            in_channels, hidden_size, out_channels, use_gru=False)
        # Projects the decoder hidden state to per-class logits.
        self.generator = nn.Linear(hidden_size, out_channels)

    def _char_to_onehot(self, input_char, onehot_dim):
        input_one_hot = F.one_hot(input_char, onehot_dim)
        return input_one_hot
    def forward(self, inputs, targets=None, batch_max_length=25):
        batch_size = paddle.shape(inputs)[0]
        num_steps = batch_max_length + 1  # +1 for [sos] at end of sentence

        # Initial LSTM hidden and cell states.
        hidden = (paddle.zeros((batch_size, self.hidden_size)),
                  paddle.zeros((batch_size, self.hidden_size)))
        output_hiddens = []

        if self.training:  # for train
            # Teacher forcing: feed the ground-truth character at each step.
            targets = targets[0]
            for i in range(num_steps):
                char_onehots = self._char_to_onehot(
                    targets[:, i], onehot_dim=self.num_classes)
                (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                               char_onehots)
                output_hiddens.append(paddle.unsqueeze(outputs, axis=1))
            output = paddle.concat(output_hiddens, axis=1)
            probs = self.generator(output)
        else:
            # Greedy decoding: start from a zero token and feed back the
            # argmax prediction of the previous step.
            targets = paddle.zeros(shape=[batch_size], dtype="int32")
            probs = None
            char_onehots = None
            outputs = None
            alpha = None

            for i in range(num_steps):
                char_onehots = self._char_to_onehot(
                    targets, onehot_dim=self.num_classes)
                (outputs, hidden), alpha = self.attention_cell(hidden, inputs,
                                                               char_onehots)
                probs_step = self.generator(outputs)
                if probs is None:
                    probs = paddle.unsqueeze(probs_step, axis=1)
                else:
                    probs = paddle.concat(
                        [probs, paddle.unsqueeze(
                            probs_step, axis=1)], axis=1)
                next_input = probs_step.argmax(axis=1)
                targets = next_input

        if not self.training:
            probs = paddle.nn.functional.softmax(probs, axis=2)
        return probs


class AttentionLSTMCell(nn.Layer):
    def __init__(self, input_size, hidden_size, num_embeddings, use_gru=False):
        super(AttentionLSTMCell, self).__init__()
        # Additive attention: project encoder features and the previous
        # hidden state into the same space, then score each time step.
        self.i2h = nn.Linear(input_size, hidden_size, bias_attr=False)
        self.h2h = nn.Linear(hidden_size, hidden_size)
        self.score = nn.Linear(hidden_size, 1, bias_attr=False)
        if not use_gru:
            self.rnn = nn.LSTMCell(
                input_size=input_size + num_embeddings, hidden_size=hidden_size)
        else:
            self.rnn = nn.GRUCell(
                input_size=input_size + num_embeddings, hidden_size=hidden_size)
        self.hidden_size = hidden_size

    def forward(self, prev_hidden, batch_H, char_onehots):
        # Attention scores over the encoder sequence.
        batch_H_proj = self.i2h(batch_H)
        prev_hidden_proj = paddle.unsqueeze(self.h2h(prev_hidden[0]), axis=1)
        res = paddle.add(batch_H_proj, prev_hidden_proj)
        res = paddle.tanh(res)
        e = self.score(res)

        alpha = F.softmax(e, axis=1)
        alpha = paddle.transpose(alpha, [0, 2, 1])
        # Weighted sum of encoder features, concatenated with the one-hot
        # embedding of the previous character, drives the recurrent cell.
        context = paddle.squeeze(paddle.mm(alpha, batch_H), axis=1)
        concat_context = paddle.concat([context, char_onehots], 1)
        cur_hidden = self.rnn(concat_context, prev_hidden)

        return cur_hidden, alpha
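

# --- Illustrative usage sketch (not part of the original file). ---
# A minimal smoke test of SPINAttentionHead in both modes; the feature shape
# [batch, seq_len, in_channels], the class count (out_channels) and the
# hidden_size below are arbitrary assumptions for demonstration only.
if __name__ == "__main__":
    head = SPINAttentionHead(in_channels=512, out_channels=39, hidden_size=256)
    feats = paddle.randn([2, 64, 512])  # assumed encoder output: [B, T, C]

    # Inference: greedy decoding, returns softmax probabilities of shape
    # [B, batch_max_length + 1, num_classes].
    head.eval()
    with paddle.no_grad():
        probs = head(feats)
    print(probs.shape)  # expected: [2, 26, 39]

    # Training: teacher forcing with (assumed) integer label indices padded
    # to batch_max_length + 1 steps; returns raw logits of the same shape.
    head.train()
    labels = paddle.randint(low=0, high=39, shape=[2, 26], dtype="int64")
    logits = head(feats, targets=[labels])
    print(logits.shape)  # expected: [2, 26, 39]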