rec_resnet_rfl.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is refer from:
  16. https://github.com/hikopensource/DAVAR-Lab-OCR/blob/main/davarocr/davar_rcg/models/backbones/ResNetRFL.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import paddle
  22. import paddle.nn as nn
  23. from paddle.nn.initializer import TruncatedNormal, Constant, Normal, KaimingNormal
# Shared weight-initializer singletons. None of them are referenced in the
# code visible in this file — presumably consumed by sibling modules that
# import them; verify before removing. TODO confirm external usage.
kaiming_init_ = KaimingNormal()   # He (Kaiming) normal init for conv weights
zeros_ = Constant(value=0.)       # constant 0 init (e.g. for biases / BN beta)
ones_ = Constant(value=1.)        # constant 1 init (e.g. for BN gamma)
  27. class BasicBlock(nn.Layer):
  28. """Res-net Basic Block"""
  29. expansion = 1
  30. def __init__(self,
  31. inplanes,
  32. planes,
  33. stride=1,
  34. downsample=None,
  35. norm_type='BN',
  36. **kwargs):
  37. """
  38. Args:
  39. inplanes (int): input channel
  40. planes (int): channels of the middle feature
  41. stride (int): stride of the convolution
  42. downsample (int): type of the down_sample
  43. norm_type (str): type of the normalization
  44. **kwargs (None): backup parameter
  45. """
  46. super(BasicBlock, self).__init__()
  47. self.conv1 = self._conv3x3(inplanes, planes)
  48. self.bn1 = nn.BatchNorm(planes)
  49. self.conv2 = self._conv3x3(planes, planes)
  50. self.bn2 = nn.BatchNorm(planes)
  51. self.relu = nn.ReLU()
  52. self.downsample = downsample
  53. self.stride = stride
  54. def _conv3x3(self, in_planes, out_planes, stride=1):
  55. return nn.Conv2D(
  56. in_planes,
  57. out_planes,
  58. kernel_size=3,
  59. stride=stride,
  60. padding=1,
  61. bias_attr=False)
  62. def forward(self, x):
  63. residual = x
  64. out = self.conv1(x)
  65. out = self.bn1(out)
  66. out = self.relu(out)
  67. out = self.conv2(out)
  68. out = self.bn2(out)
  69. if self.downsample is not None:
  70. residual = self.downsample(x)
  71. out += residual
  72. out = self.relu(out)
  73. return out
  74. class ResNetRFL(nn.Layer):
  75. def __init__(self,
  76. in_channels,
  77. out_channels=512,
  78. use_cnt=True,
  79. use_seq=True):
  80. """
  81. Args:
  82. in_channels (int): input channel
  83. out_channels (int): output channel
  84. """
  85. super(ResNetRFL, self).__init__()
  86. assert use_cnt or use_seq
  87. self.use_cnt, self.use_seq = use_cnt, use_seq
  88. self.backbone = RFLBase(in_channels)
  89. self.out_channels = out_channels
  90. self.out_channels_block = [
  91. int(self.out_channels / 4), int(self.out_channels / 2),
  92. self.out_channels, self.out_channels
  93. ]
  94. block = BasicBlock
  95. layers = [1, 2, 5, 3]
  96. self.inplanes = int(self.out_channels // 2)
  97. self.relu = nn.ReLU()
  98. if self.use_seq:
  99. self.maxpool3 = nn.MaxPool2D(
  100. kernel_size=2, stride=(2, 1), padding=(0, 1))
  101. self.layer3 = self._make_layer(
  102. block, self.out_channels_block[2], layers[2], stride=1)
  103. self.conv3 = nn.Conv2D(
  104. self.out_channels_block[2],
  105. self.out_channels_block[2],
  106. kernel_size=3,
  107. stride=1,
  108. padding=1,
  109. bias_attr=False)
  110. self.bn3 = nn.BatchNorm(self.out_channels_block[2])
  111. self.layer4 = self._make_layer(
  112. block, self.out_channels_block[3], layers[3], stride=1)
  113. self.conv4_1 = nn.Conv2D(
  114. self.out_channels_block[3],
  115. self.out_channels_block[3],
  116. kernel_size=2,
  117. stride=(2, 1),
  118. padding=(0, 1),
  119. bias_attr=False)
  120. self.bn4_1 = nn.BatchNorm(self.out_channels_block[3])
  121. self.conv4_2 = nn.Conv2D(
  122. self.out_channels_block[3],
  123. self.out_channels_block[3],
  124. kernel_size=2,
  125. stride=1,
  126. padding=0,
  127. bias_attr=False)
  128. self.bn4_2 = nn.BatchNorm(self.out_channels_block[3])
  129. if self.use_cnt:
  130. self.inplanes = int(self.out_channels // 2)
  131. self.v_maxpool3 = nn.MaxPool2D(
  132. kernel_size=2, stride=(2, 1), padding=(0, 1))
  133. self.v_layer3 = self._make_layer(
  134. block, self.out_channels_block[2], layers[2], stride=1)
  135. self.v_conv3 = nn.Conv2D(
  136. self.out_channels_block[2],
  137. self.out_channels_block[2],
  138. kernel_size=3,
  139. stride=1,
  140. padding=1,
  141. bias_attr=False)
  142. self.v_bn3 = nn.BatchNorm(self.out_channels_block[2])
  143. self.v_layer4 = self._make_layer(
  144. block, self.out_channels_block[3], layers[3], stride=1)
  145. self.v_conv4_1 = nn.Conv2D(
  146. self.out_channels_block[3],
  147. self.out_channels_block[3],
  148. kernel_size=2,
  149. stride=(2, 1),
  150. padding=(0, 1),
  151. bias_attr=False)
  152. self.v_bn4_1 = nn.BatchNorm(self.out_channels_block[3])
  153. self.v_conv4_2 = nn.Conv2D(
  154. self.out_channels_block[3],
  155. self.out_channels_block[3],
  156. kernel_size=2,
  157. stride=1,
  158. padding=0,
  159. bias_attr=False)
  160. self.v_bn4_2 = nn.BatchNorm(self.out_channels_block[3])
  161. def _make_layer(self, block, planes, blocks, stride=1):
  162. downsample = None
  163. if stride != 1 or self.inplanes != planes * block.expansion:
  164. downsample = nn.Sequential(
  165. nn.Conv2D(
  166. self.inplanes,
  167. planes * block.expansion,
  168. kernel_size=1,
  169. stride=stride,
  170. bias_attr=False),
  171. nn.BatchNorm(planes * block.expansion), )
  172. layers = list()
  173. layers.append(block(self.inplanes, planes, stride, downsample))
  174. self.inplanes = planes * block.expansion
  175. for _ in range(1, blocks):
  176. layers.append(block(self.inplanes, planes))
  177. return nn.Sequential(*layers)
  178. def forward(self, inputs):
  179. x_1 = self.backbone(inputs)
  180. if self.use_cnt:
  181. v_x = self.v_maxpool3(x_1)
  182. v_x = self.v_layer3(v_x)
  183. v_x = self.v_conv3(v_x)
  184. v_x = self.v_bn3(v_x)
  185. visual_feature_2 = self.relu(v_x)
  186. v_x = self.v_layer4(visual_feature_2)
  187. v_x = self.v_conv4_1(v_x)
  188. v_x = self.v_bn4_1(v_x)
  189. v_x = self.relu(v_x)
  190. v_x = self.v_conv4_2(v_x)
  191. v_x = self.v_bn4_2(v_x)
  192. visual_feature_3 = self.relu(v_x)
  193. else:
  194. visual_feature_3 = None
  195. if self.use_seq:
  196. x = self.maxpool3(x_1)
  197. x = self.layer3(x)
  198. x = self.conv3(x)
  199. x = self.bn3(x)
  200. x_2 = self.relu(x)
  201. x = self.layer4(x_2)
  202. x = self.conv4_1(x)
  203. x = self.bn4_1(x)
  204. x = self.relu(x)
  205. x = self.conv4_2(x)
  206. x = self.bn4_2(x)
  207. x_3 = self.relu(x)
  208. else:
  209. x_3 = None
  210. return [visual_feature_3, x_3]
  211. class ResNetBase(nn.Layer):
  212. def __init__(self, in_channels, out_channels, block, layers):
  213. super(ResNetBase, self).__init__()
  214. self.out_channels_block = [
  215. int(out_channels / 4), int(out_channels / 2), out_channels,
  216. out_channels
  217. ]
  218. self.inplanes = int(out_channels / 8)
  219. self.conv0_1 = nn.Conv2D(
  220. in_channels,
  221. int(out_channels / 16),
  222. kernel_size=3,
  223. stride=1,
  224. padding=1,
  225. bias_attr=False)
  226. self.bn0_1 = nn.BatchNorm(int(out_channels / 16))
  227. self.conv0_2 = nn.Conv2D(
  228. int(out_channels / 16),
  229. self.inplanes,
  230. kernel_size=3,
  231. stride=1,
  232. padding=1,
  233. bias_attr=False)
  234. self.bn0_2 = nn.BatchNorm(self.inplanes)
  235. self.relu = nn.ReLU()
  236. self.maxpool1 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
  237. self.layer1 = self._make_layer(block, self.out_channels_block[0],
  238. layers[0])
  239. self.conv1 = nn.Conv2D(
  240. self.out_channels_block[0],
  241. self.out_channels_block[0],
  242. kernel_size=3,
  243. stride=1,
  244. padding=1,
  245. bias_attr=False)
  246. self.bn1 = nn.BatchNorm(self.out_channels_block[0])
  247. self.maxpool2 = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
  248. self.layer2 = self._make_layer(
  249. block, self.out_channels_block[1], layers[1], stride=1)
  250. self.conv2 = nn.Conv2D(
  251. self.out_channels_block[1],
  252. self.out_channels_block[1],
  253. kernel_size=3,
  254. stride=1,
  255. padding=1,
  256. bias_attr=False)
  257. self.bn2 = nn.BatchNorm(self.out_channels_block[1])
  258. def _make_layer(self, block, planes, blocks, stride=1):
  259. downsample = None
  260. if stride != 1 or self.inplanes != planes * block.expansion:
  261. downsample = nn.Sequential(
  262. nn.Conv2D(
  263. self.inplanes,
  264. planes * block.expansion,
  265. kernel_size=1,
  266. stride=stride,
  267. bias_attr=False),
  268. nn.BatchNorm(planes * block.expansion), )
  269. layers = list()
  270. layers.append(block(self.inplanes, planes, stride, downsample))
  271. self.inplanes = planes * block.expansion
  272. for _ in range(1, blocks):
  273. layers.append(block(self.inplanes, planes))
  274. return nn.Sequential(*layers)
  275. def forward(self, x):
  276. x = self.conv0_1(x)
  277. x = self.bn0_1(x)
  278. x = self.relu(x)
  279. x = self.conv0_2(x)
  280. x = self.bn0_2(x)
  281. x = self.relu(x)
  282. x = self.maxpool1(x)
  283. x = self.layer1(x)
  284. x = self.conv1(x)
  285. x = self.bn1(x)
  286. x = self.relu(x)
  287. x = self.maxpool2(x)
  288. x = self.layer2(x)
  289. x = self.conv2(x)
  290. x = self.bn2(x)
  291. x = self.relu(x)
  292. return x
  293. class RFLBase(nn.Layer):
  294. """ Reciprocal feature learning share backbone network"""
  295. def __init__(self, in_channels, out_channels=512):
  296. super(RFLBase, self).__init__()
  297. self.ConvNet = ResNetBase(in_channels, out_channels, BasicBlock,
  298. [1, 2, 5, 3])
  299. def forward(self, inputs):
  300. return self.ConvNet(inputs)