rec_resnet_32.py 9.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/hikopensource/DAVAR-Lab-OCR/davarocr/davar_rcg/models/backbones/ResNet32.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import paddle.nn as nn
  22. __all__ = ["ResNet32"]
  23. conv_weight_attr = nn.initializer.KaimingNormal()
  24. class ResNet32(nn.Layer):
  25. """
  26. Feature Extractor is proposed in FAN Ref [1]
  27. Ref [1]: Focusing Attention: Towards Accurate Text Recognition in Neural Images ICCV-2017
  28. """
  29. def __init__(self, in_channels, out_channels=512):
  30. """
  31. Args:
  32. in_channels (int): input channel
  33. output_channel (int): output channel
  34. """
  35. super(ResNet32, self).__init__()
  36. self.out_channels = out_channels
  37. self.ConvNet = ResNet(in_channels, out_channels, BasicBlock, [1, 2, 5, 3])
  38. def forward(self, inputs):
  39. """
  40. Args:
  41. inputs: input feature
  42. Returns:
  43. output feature
  44. """
  45. return self.ConvNet(inputs)
  46. class BasicBlock(nn.Layer):
  47. """Res-net Basic Block"""
  48. expansion = 1
  49. def __init__(self, inplanes, planes,
  50. stride=1, downsample=None,
  51. norm_type='BN', **kwargs):
  52. """
  53. Args:
  54. inplanes (int): input channel
  55. planes (int): channels of the middle feature
  56. stride (int): stride of the convolution
  57. downsample (int): type of the down_sample
  58. norm_type (str): type of the normalization
  59. **kwargs (None): backup parameter
  60. """
  61. super(BasicBlock, self).__init__()
  62. self.conv1 = self._conv3x3(inplanes, planes)
  63. self.bn1 = nn.BatchNorm2D(planes)
  64. self.conv2 = self._conv3x3(planes, planes)
  65. self.bn2 = nn.BatchNorm2D(planes)
  66. self.relu = nn.ReLU()
  67. self.downsample = downsample
  68. self.stride = stride
  69. def _conv3x3(self, in_planes, out_planes, stride=1):
  70. """
  71. Args:
  72. in_planes (int): input channel
  73. out_planes (int): channels of the middle feature
  74. stride (int): stride of the convolution
  75. Returns:
  76. nn.Layer: Conv2D with kernel = 3
  77. """
  78. return nn.Conv2D(in_planes, out_planes,
  79. kernel_size=3, stride=stride,
  80. padding=1, weight_attr=conv_weight_attr,
  81. bias_attr=False)
  82. def forward(self, x):
  83. residual = x
  84. out = self.conv1(x)
  85. out = self.bn1(out)
  86. out = self.relu(out)
  87. out = self.conv2(out)
  88. out = self.bn2(out)
  89. if self.downsample is not None:
  90. residual = self.downsample(x)
  91. out += residual
  92. out = self.relu(out)
  93. return out
  94. class ResNet(nn.Layer):
  95. """Res-Net network structure"""
  96. def __init__(self, input_channel,
  97. output_channel, block, layers):
  98. """
  99. Args:
  100. input_channel (int): input channel
  101. output_channel (int): output channel
  102. block (BasicBlock): convolution block
  103. layers (list): layers of the block
  104. """
  105. super(ResNet, self).__init__()
  106. self.output_channel_block = [int(output_channel / 4),
  107. int(output_channel / 2),
  108. output_channel,
  109. output_channel]
  110. self.inplanes = int(output_channel / 8)
  111. self.conv0_1 = nn.Conv2D(input_channel, int(output_channel / 16),
  112. kernel_size=3, stride=1,
  113. padding=1,
  114. weight_attr=conv_weight_attr,
  115. bias_attr=False)
  116. self.bn0_1 = nn.BatchNorm2D(int(output_channel / 16))
  117. self.conv0_2 = nn.Conv2D(int(output_channel / 16), self.inplanes,
  118. kernel_size=3, stride=1,
  119. padding=1,
  120. weight_attr=conv_weight_attr,
  121. bias_attr=False)
  122. self.bn0_2 = nn.BatchNorm2D(self.inplanes)
  123. self.relu = nn.ReLU()
  124. self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
  125. self.layer1 = self._make_layer(block,
  126. self.output_channel_block[0],
  127. layers[0])
  128. self.conv1 = nn.Conv2D(self.output_channel_block[0],
  129. self.output_channel_block[0],
  130. kernel_size=3, stride=1,
  131. padding=1,
  132. weight_attr=conv_weight_attr,
  133. bias_attr=False)
  134. self.bn1 = nn.BatchNorm2D(self.output_channel_block[0])
  135. self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
  136. self.layer2 = self._make_layer(block,
  137. self.output_channel_block[1],
  138. layers[1], stride=1)
  139. self.conv2 = nn.Conv2D(self.output_channel_block[1],
  140. self.output_channel_block[1],
  141. kernel_size=3, stride=1,
  142. padding=1,
  143. weight_attr=conv_weight_attr,
  144. bias_attr=False,)
  145. self.bn2 = nn.BatchNorm2D(self.output_channel_block[1])
  146. self.maxpool3 = nn.MaxPool2D(kernel_size=2,
  147. stride=(2, 1),
  148. padding=(0, 1))
  149. self.layer3 = self._make_layer(block, self.output_channel_block[2],
  150. layers[2], stride=1)
  151. self.conv3 = nn.Conv2D(self.output_channel_block[2],
  152. self.output_channel_block[2],
  153. kernel_size=3, stride=1,
  154. padding=1,
  155. weight_attr=conv_weight_attr,
  156. bias_attr=False)
  157. self.bn3 = nn.BatchNorm2D(self.output_channel_block[2])
  158. self.layer4 = self._make_layer(block, self.output_channel_block[3],
  159. layers[3], stride=1)
  160. self.conv4_1 = nn.Conv2D(self.output_channel_block[3],
  161. self.output_channel_block[3],
  162. kernel_size=2, stride=(2, 1),
  163. padding=(0, 1),
  164. weight_attr=conv_weight_attr,
  165. bias_attr=False)
  166. self.bn4_1 = nn.BatchNorm2D(self.output_channel_block[3])
  167. self.conv4_2 = nn.Conv2D(self.output_channel_block[3],
  168. self.output_channel_block[3],
  169. kernel_size=2, stride=1,
  170. padding=0,
  171. weight_attr=conv_weight_attr,
  172. bias_attr=False)
  173. self.bn4_2 = nn.BatchNorm2D(self.output_channel_block[3])
  174. def _make_layer(self, block, planes, blocks, stride=1):
  175. """
  176. Args:
  177. block (block): convolution block
  178. planes (int): input channels
  179. blocks (list): layers of the block
  180. stride (int): stride of the convolution
  181. Returns:
  182. nn.Sequential: the combination of the convolution block
  183. """
  184. downsample = None
  185. if stride != 1 or self.inplanes != planes * block.expansion:
  186. downsample = nn.Sequential(
  187. nn.Conv2D(self.inplanes, planes * block.expansion,
  188. kernel_size=1, stride=stride,
  189. weight_attr=conv_weight_attr,
  190. bias_attr=False),
  191. nn.BatchNorm2D(planes * block.expansion),
  192. )
  193. layers = list()
  194. layers.append(block(self.inplanes, planes, stride, downsample))
  195. self.inplanes = planes * block.expansion
  196. for _ in range(1, blocks):
  197. layers.append(block(self.inplanes, planes))
  198. return nn.Sequential(*layers)
  199. def forward(self, x):
  200. x = self.conv0_1(x)
  201. x = self.bn0_1(x)
  202. x = self.relu(x)
  203. x = self.conv0_2(x)
  204. x = self.bn0_2(x)
  205. x = self.relu(x)
  206. x = self.maxpool1(x)
  207. x = self.layer1(x)
  208. x = self.conv1(x)
  209. x = self.bn1(x)
  210. x = self.relu(x)
  211. x = self.maxpool2(x)
  212. x = self.layer2(x)
  213. x = self.conv2(x)
  214. x = self.bn2(x)
  215. x = self.relu(x)
  216. x = self.maxpool3(x)
  217. x = self.layer3(x)
  218. x = self.conv3(x)
  219. x = self.bn3(x)
  220. x = self.relu(x)
  221. x = self.layer4(x)
  222. x = self.conv4_1(x)
  223. x = self.bn4_1(x)
  224. x = self.relu(x)
  225. x = self.conv4_2(x)
  226. x = self.bn4_2(x)
  227. x = self.relu(x)
  228. return x