# rec_efficientb3_pren.py
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Code is refer from:
  16. https://github.com/RuijieJ/pren/blob/main/Nets/EfficientNet.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import math
  22. import re
  23. import collections
  24. import paddle
  25. import paddle.nn as nn
  26. import paddle.nn.functional as F
  27. __all__ = ['EfficientNetb3']
  28. GlobalParams = collections.namedtuple('GlobalParams', [
  29. 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'num_classes',
  30. 'width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth',
  31. 'drop_connect_rate', 'image_size'
  32. ])
  33. BlockArgs = collections.namedtuple('BlockArgs', [
  34. 'kernel_size', 'num_repeat', 'input_filters', 'output_filters',
  35. 'expand_ratio', 'id_skip', 'stride', 'se_ratio'
  36. ])
  37. class BlockDecoder:
  38. @staticmethod
  39. def _decode_block_string(block_string):
  40. assert isinstance(block_string, str)
  41. ops = block_string.split('_')
  42. options = {}
  43. for op in ops:
  44. splits = re.split(r'(\d.*)', op)
  45. if len(splits) >= 2:
  46. key, value = splits[:2]
  47. options[key] = value
  48. assert (('s' in options and len(options['s']) == 1) or
  49. (len(options['s']) == 2 and options['s'][0] == options['s'][1]))
  50. return BlockArgs(
  51. kernel_size=int(options['k']),
  52. num_repeat=int(options['r']),
  53. input_filters=int(options['i']),
  54. output_filters=int(options['o']),
  55. expand_ratio=int(options['e']),
  56. id_skip=('noskip' not in block_string),
  57. se_ratio=float(options['se']) if 'se' in options else None,
  58. stride=[int(options['s'][0])])
  59. @staticmethod
  60. def decode(string_list):
  61. assert isinstance(string_list, list)
  62. blocks_args = []
  63. for block_string in string_list:
  64. blocks_args.append(BlockDecoder._decode_block_string(block_string))
  65. return blocks_args
  66. def efficientnet(width_coefficient=None,
  67. depth_coefficient=None,
  68. dropout_rate=0.2,
  69. drop_connect_rate=0.2,
  70. image_size=None,
  71. num_classes=1000):
  72. blocks_args = [
  73. 'r1_k3_s11_e1_i32_o16_se0.25',
  74. 'r2_k3_s22_e6_i16_o24_se0.25',
  75. 'r2_k5_s22_e6_i24_o40_se0.25',
  76. 'r3_k3_s22_e6_i40_o80_se0.25',
  77. 'r3_k5_s11_e6_i80_o112_se0.25',
  78. 'r4_k5_s22_e6_i112_o192_se0.25',
  79. 'r1_k3_s11_e6_i192_o320_se0.25',
  80. ]
  81. blocks_args = BlockDecoder.decode(blocks_args)
  82. global_params = GlobalParams(
  83. batch_norm_momentum=0.99,
  84. batch_norm_epsilon=1e-3,
  85. dropout_rate=dropout_rate,
  86. drop_connect_rate=drop_connect_rate,
  87. num_classes=num_classes,
  88. width_coefficient=width_coefficient,
  89. depth_coefficient=depth_coefficient,
  90. depth_divisor=8,
  91. min_depth=None,
  92. image_size=image_size, )
  93. return blocks_args, global_params
  94. class EffUtils:
  95. @staticmethod
  96. def round_filters(filters, global_params):
  97. """ Calculate and round number of filters based on depth multiplier. """
  98. multiplier = global_params.width_coefficient
  99. if not multiplier:
  100. return filters
  101. divisor = global_params.depth_divisor
  102. min_depth = global_params.min_depth
  103. filters *= multiplier
  104. min_depth = min_depth or divisor
  105. new_filters = max(min_depth,
  106. int(filters + divisor / 2) // divisor * divisor)
  107. if new_filters < 0.9 * filters:
  108. new_filters += divisor
  109. return int(new_filters)
  110. @staticmethod
  111. def round_repeats(repeats, global_params):
  112. """ Round number of filters based on depth multiplier. """
  113. multiplier = global_params.depth_coefficient
  114. if not multiplier:
  115. return repeats
  116. return int(math.ceil(multiplier * repeats))
class MbConvBlock(nn.Layer):
    """Mobile inverted bottleneck (MBConv) block with optional
    squeeze-and-excitation: 1x1 expand -> depthwise conv -> SE gate ->
    1x1 project, with an identity skip + drop-connect when applicable."""

    def __init__(self, block_args):
        # block_args: a BlockArgs namedtuple describing this block.
        super(MbConvBlock, self).__init__()
        self._block_args = block_args
        # SE is active only for a valid ratio in (0, 1].
        self.has_se = (self._block_args.se_ratio is not None) and \
            (0 < self._block_args.se_ratio <= 1)
        self.id_skip = block_args.id_skip
        # expansion phase: 1x1 conv widening channels by expand_ratio
        # (skipped entirely when the ratio is 1).
        self.inp = self._block_args.input_filters
        oup = self._block_args.input_filters * self._block_args.expand_ratio
        if self._block_args.expand_ratio != 1:
            self._expand_conv = nn.Conv2D(self.inp, oup, 1, bias_attr=False)
            self._bn0 = nn.BatchNorm(oup)
        # depthwise conv phase (groups == channels)
        k = self._block_args.kernel_size
        s = self._block_args.stride
        # BlockDecoder produces stride as a one-element list; unwrap it.
        if isinstance(s, list):
            s = s[0]
        self._depthwise_conv = nn.Conv2D(
            oup,
            oup,
            groups=oup,
            kernel_size=k,
            stride=s,
            padding='same',
            bias_attr=False)
        self._bn1 = nn.BatchNorm(oup)
        # squeeze and excitation layer, if desired
        # (bottleneck size is derived from the *input* filters, per the paper)
        if self.has_se:
            num_squeezed_channels = max(1,
                                        int(self._block_args.input_filters *
                                            self._block_args.se_ratio))
            self._se_reduce = nn.Conv2D(oup, num_squeezed_channels, 1)
            self._se_expand = nn.Conv2D(num_squeezed_channels, oup, 1)
        # output phase and some util class: 1x1 projection to output_filters
        self.final_oup = self._block_args.output_filters
        self._project_conv = nn.Conv2D(oup, self.final_oup, 1, bias_attr=False)
        self._bn2 = nn.BatchNorm(self.final_oup)
        self._swish = nn.Swish()

    def _drop_connect(self, inputs, p, training):
        """Stochastic depth: zero whole samples with probability p during
        training, rescaling survivors by 1/keep_prob; identity at eval."""
        if not training:
            return inputs
        batch_size = inputs.shape[0]
        keep_prob = 1 - p
        # keep_prob + U[0,1): floor() yields 1 with prob keep_prob, else 0.
        random_tensor = keep_prob
        random_tensor += paddle.rand([batch_size, 1, 1, 1], dtype=inputs.dtype)
        random_tensor = paddle.to_tensor(random_tensor, place=inputs.place)
        binary_tensor = paddle.floor(random_tensor)
        output = inputs / keep_prob * binary_tensor
        return output

    def forward(self, inputs, drop_connect_rate=None):
        # expansion and depthwise conv
        x = inputs
        if self._block_args.expand_ratio != 1:
            x = self._swish(self._bn0(self._expand_conv(inputs)))
        x = self._swish(self._bn1(self._depthwise_conv(x)))
        # squeeze and excitation: channel-wise gate from global average pool
        if self.has_se:
            x_squeezed = F.adaptive_avg_pool2d(x, 1)
            x_squeezed = self._se_expand(
                self._swish(self._se_reduce(x_squeezed)))
            x = F.sigmoid(x_squeezed) * x
        x = self._bn2(self._project_conv(x))
        # skip conntection and drop connect
        # NOTE(review): stride may still be a list here (e.g. [1]), in which
        # case `stride == 1` is False and the skip path is not taken even for
        # unit stride — confirm this matches the reference implementation.
        if self.id_skip and self._block_args.stride == 1 and \
                self.inp == self.final_oup:
            if drop_connect_rate:
                x = self._drop_connect(
                    x, p=drop_connect_rate, training=self.training)
            x = x + inputs
        return x
class EfficientNetb3_PREN(nn.Layer):
    """EfficientNet-B3 backbone for the PREN text recognizer.

    forward() returns three intermediate feature maps (for an FPN-style
    neck) instead of classification logits; their channel counts are
    collected in self.out_channels during construction.
    """

    def __init__(self, in_channels):
        # in_channels: channel count of the input image tensor.
        super(EfficientNetb3_PREN, self).__init__()
        """
        the fllowing are efficientnetb3's superparams,
        they means efficientnetb3 network's width, depth, resolution and
        dropout respectively, to fit for text recognition task, the resolution
        here is changed from 300 to 64.
        """
        w, d, s, p = 1.2, 1.4, 64, 0.3
        self._blocks_args, self._global_params = efficientnet(
            width_coefficient=w,
            depth_coefficient=d,
            dropout_rate=p,
            image_size=s)
        # Channel counts of the three feature maps handed to the neck.
        self.out_channels = []
        # stem: 3x3 stride-2 conv
        out_channels = EffUtils.round_filters(32, self._global_params)
        self._conv_stem = nn.Conv2D(
            in_channels, out_channels, 3, 2, padding='same', bias_attr=False)
        self._bn0 = nn.BatchNorm(out_channels)
        # build blocks
        self._blocks = []
        # to extract three feature maps for fpn based on efficientnetb3 backbone
        self._concerned_block_idxes = [7, 17, 25]
        # _concerned_idx counts blocks 1-based here, while forward() matches
        # 0-based enumerate indices; the one-slot offset stays inside the same
        # stage, so the recorded out_channels still match the extracted maps.
        _concerned_idx = 0
        for i, block_args in enumerate(self._blocks_args):
            # Apply B3 width/depth scaling to this stage's config.
            block_args = block_args._replace(
                input_filters=EffUtils.round_filters(block_args.input_filters,
                                                     self._global_params),
                output_filters=EffUtils.round_filters(block_args.output_filters,
                                                      self._global_params),
                num_repeat=EffUtils.round_repeats(block_args.num_repeat,
                                                  self._global_params))
            # add_sublayer registers the block's parameters on this Layer.
            self._blocks.append(
                self.add_sublayer(f"{i}-0", MbConvBlock(block_args)))
            _concerned_idx += 1
            if _concerned_idx in self._concerned_block_idxes:
                self.out_channels.append(block_args.output_filters)
            if block_args.num_repeat > 1:
                # Repeats after the first keep channels and use stride 1.
                block_args = block_args._replace(
                    input_filters=block_args.output_filters, stride=1)
                for j in range(block_args.num_repeat - 1):
                    self._blocks.append(
                        self.add_sublayer(f'{i}-{j+1}', MbConvBlock(block_args)))
                    _concerned_idx += 1
                    if _concerned_idx in self._concerned_block_idxes:
                        self.out_channels.append(block_args.output_filters)
        self._swish = nn.Swish()

    def forward(self, inputs):
        """Run the backbone and return the list of feature maps taken
        after the blocks listed in self._concerned_block_idxes."""
        outs = []
        x = self._swish(self._bn0(self._conv_stem(inputs)))
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                # Scale drop-connect linearly with block depth.
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)
            if idx in self._concerned_block_idxes:
                outs.append(x)
        return outs