rnn.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. from paddle import nn
  19. from ppocr.modeling.heads.rec_ctc_head import get_para_bias_attr
  20. from ppocr.modeling.backbones.rec_svtrnet import Block, ConvBNLayer, trunc_normal_, zeros_, ones_
  21. class Im2Seq(nn.Layer):
  22. def __init__(self, in_channels, **kwargs):
  23. super().__init__()
  24. self.out_channels = in_channels
  25. def forward(self, x):
  26. B, C, H, W = x.shape
  27. assert H == 1
  28. x = x.squeeze(axis=2)
  29. x = x.transpose([0, 2, 1]) # (NTC)(batch, width, channels)
  30. return x
  31. class EncoderWithRNN(nn.Layer):
  32. def __init__(self, in_channels, hidden_size):
  33. super(EncoderWithRNN, self).__init__()
  34. self.out_channels = hidden_size * 2
  35. self.lstm = nn.LSTM(
  36. in_channels, hidden_size, direction='bidirectional', num_layers=2)
  37. def forward(self, x):
  38. x, _ = self.lstm(x)
  39. return x
  40. class BidirectionalLSTM(nn.Layer):
  41. def __init__(self, input_size,
  42. hidden_size,
  43. output_size=None,
  44. num_layers=1,
  45. dropout=0,
  46. direction=False,
  47. time_major=False,
  48. with_linear=False):
  49. super(BidirectionalLSTM, self).__init__()
  50. self.with_linear = with_linear
  51. self.rnn = nn.LSTM(input_size,
  52. hidden_size,
  53. num_layers=num_layers,
  54. dropout=dropout,
  55. direction=direction,
  56. time_major=time_major)
  57. # text recognition the specified structure LSTM with linear
  58. if self.with_linear:
  59. self.linear = nn.Linear(hidden_size * 2, output_size)
  60. def forward(self, input_feature):
  61. recurrent, _ = self.rnn(input_feature) # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
  62. if self.with_linear:
  63. output = self.linear(recurrent) # batch_size x T x output_size
  64. return output
  65. return recurrent
  66. class EncoderWithCascadeRNN(nn.Layer):
  67. def __init__(self, in_channels, hidden_size, out_channels, num_layers=2, with_linear=False):
  68. super(EncoderWithCascadeRNN, self).__init__()
  69. self.out_channels = out_channels[-1]
  70. self.encoder = nn.LayerList(
  71. [BidirectionalLSTM(
  72. in_channels if i == 0 else out_channels[i - 1],
  73. hidden_size,
  74. output_size=out_channels[i],
  75. num_layers=1,
  76. direction='bidirectional',
  77. with_linear=with_linear)
  78. for i in range(num_layers)]
  79. )
  80. def forward(self, x):
  81. for i, l in enumerate(self.encoder):
  82. x = l(x)
  83. return x
  84. class EncoderWithFC(nn.Layer):
  85. def __init__(self, in_channels, hidden_size):
  86. super(EncoderWithFC, self).__init__()
  87. self.out_channels = hidden_size
  88. weight_attr, bias_attr = get_para_bias_attr(
  89. l2_decay=0.00001, k=in_channels)
  90. self.fc = nn.Linear(
  91. in_channels,
  92. hidden_size,
  93. weight_attr=weight_attr,
  94. bias_attr=bias_attr,
  95. name='reduce_encoder_fea')
  96. def forward(self, x):
  97. x = self.fc(x)
  98. return x
  99. class EncoderWithSVTR(nn.Layer):
  100. def __init__(
  101. self,
  102. in_channels,
  103. dims=64, # XS
  104. depth=2,
  105. hidden_dims=120,
  106. use_guide=False,
  107. num_heads=8,
  108. qkv_bias=True,
  109. mlp_ratio=2.0,
  110. drop_rate=0.1,
  111. attn_drop_rate=0.1,
  112. drop_path=0.,
  113. qk_scale=None):
  114. super(EncoderWithSVTR, self).__init__()
  115. self.depth = depth
  116. self.use_guide = use_guide
  117. self.conv1 = ConvBNLayer(
  118. in_channels, in_channels // 8, padding=1, act=nn.Swish)
  119. self.conv2 = ConvBNLayer(
  120. in_channels // 8, hidden_dims, kernel_size=1, act=nn.Swish)
  121. self.svtr_block = nn.LayerList([
  122. Block(
  123. dim=hidden_dims,
  124. num_heads=num_heads,
  125. mixer='Global',
  126. HW=None,
  127. mlp_ratio=mlp_ratio,
  128. qkv_bias=qkv_bias,
  129. qk_scale=qk_scale,
  130. drop=drop_rate,
  131. act_layer=nn.Swish,
  132. attn_drop=attn_drop_rate,
  133. drop_path=drop_path,
  134. norm_layer='nn.LayerNorm',
  135. epsilon=1e-05,
  136. prenorm=False) for i in range(depth)
  137. ])
  138. self.norm = nn.LayerNorm(hidden_dims, epsilon=1e-6)
  139. self.conv3 = ConvBNLayer(
  140. hidden_dims, in_channels, kernel_size=1, act=nn.Swish)
  141. # last conv-nxn, the input is concat of input tensor and conv3 output tensor
  142. self.conv4 = ConvBNLayer(
  143. 2 * in_channels, in_channels // 8, padding=1, act=nn.Swish)
  144. self.conv1x1 = ConvBNLayer(
  145. in_channels // 8, dims, kernel_size=1, act=nn.Swish)
  146. self.out_channels = dims
  147. self.apply(self._init_weights)
  148. def _init_weights(self, m):
  149. if isinstance(m, nn.Linear):
  150. trunc_normal_(m.weight)
  151. if isinstance(m, nn.Linear) and m.bias is not None:
  152. zeros_(m.bias)
  153. elif isinstance(m, nn.LayerNorm):
  154. zeros_(m.bias)
  155. ones_(m.weight)
  156. def forward(self, x):
  157. # for use guide
  158. if self.use_guide:
  159. z = x.clone()
  160. z.stop_gradient = True
  161. else:
  162. z = x
  163. # for short cut
  164. h = z
  165. # reduce dim
  166. z = self.conv1(z)
  167. z = self.conv2(z)
  168. # SVTR global block
  169. B, C, H, W = z.shape
  170. z = z.flatten(2).transpose([0, 2, 1])
  171. for blk in self.svtr_block:
  172. z = blk(z)
  173. z = self.norm(z)
  174. # last stage
  175. z = z.reshape([0, H, W, C]).transpose([0, 3, 1, 2])
  176. z = self.conv3(z)
  177. z = paddle.concat((h, z), axis=1)
  178. z = self.conv1x1(self.conv4(z))
  179. return z
  180. class SequenceEncoder(nn.Layer):
  181. def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
  182. super(SequenceEncoder, self).__init__()
  183. self.encoder_reshape = Im2Seq(in_channels)
  184. self.out_channels = self.encoder_reshape.out_channels
  185. self.encoder_type = encoder_type
  186. if encoder_type == 'reshape':
  187. self.only_reshape = True
  188. else:
  189. support_encoder_dict = {
  190. 'reshape': Im2Seq,
  191. 'fc': EncoderWithFC,
  192. 'rnn': EncoderWithRNN,
  193. 'svtr': EncoderWithSVTR,
  194. 'cascadernn': EncoderWithCascadeRNN
  195. }
  196. assert encoder_type in support_encoder_dict, '{} must in {}'.format(
  197. encoder_type, support_encoder_dict.keys())
  198. if encoder_type == "svtr":
  199. self.encoder = support_encoder_dict[encoder_type](
  200. self.encoder_reshape.out_channels, **kwargs)
  201. elif encoder_type == 'cascadernn':
  202. self.encoder = support_encoder_dict[encoder_type](
  203. self.encoder_reshape.out_channels, hidden_size, **kwargs)
  204. else:
  205. self.encoder = support_encoder_dict[encoder_type](
  206. self.encoder_reshape.out_channels, hidden_size)
  207. self.out_channels = self.encoder.out_channels
  208. self.only_reshape = False
  209. def forward(self, x):
  210. if self.encoder_type != 'svtr':
  211. x = self.encoder_reshape(x)
  212. if not self.only_reshape:
  213. x = self.encoder(x)
  214. return x
  215. else:
  216. x = self.encoder(x)
  217. x = self.encoder_reshape(x)
  218. return x