# pg_fpn.py
# copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle import nn
  21. class ConvBNLayer(nn.Layer):
  22. def __init__(self,
  23. in_channels,
  24. out_channels,
  25. kernel_size,
  26. stride=1,
  27. groups=1,
  28. is_vd_mode=False,
  29. act=None,
  30. name=None):
  31. super(ConvBNLayer, self).__init__()
  32. self.is_vd_mode = is_vd_mode
  33. self._pool2d_avg = nn.AvgPool2D(
  34. kernel_size=2, stride=2, padding=0, ceil_mode=True)
  35. self._conv = nn.Conv2D(
  36. in_channels=in_channels,
  37. out_channels=out_channels,
  38. kernel_size=kernel_size,
  39. stride=stride,
  40. padding=(kernel_size - 1) // 2,
  41. groups=groups,
  42. weight_attr=ParamAttr(name=name + "_weights"),
  43. bias_attr=False)
  44. if name == "conv1":
  45. bn_name = "bn_" + name
  46. else:
  47. bn_name = "bn" + name[3:]
  48. self._batch_norm = nn.BatchNorm(
  49. out_channels,
  50. act=act,
  51. param_attr=ParamAttr(name=bn_name + '_scale'),
  52. bias_attr=ParamAttr(bn_name + '_offset'),
  53. moving_mean_name=bn_name + '_mean',
  54. moving_variance_name=bn_name + '_variance',
  55. use_global_stats=False)
  56. def forward(self, inputs):
  57. y = self._conv(inputs)
  58. y = self._batch_norm(y)
  59. return y
  60. class DeConvBNLayer(nn.Layer):
  61. def __init__(self,
  62. in_channels,
  63. out_channels,
  64. kernel_size=4,
  65. stride=2,
  66. padding=1,
  67. groups=1,
  68. if_act=True,
  69. act=None,
  70. name=None):
  71. super(DeConvBNLayer, self).__init__()
  72. self.if_act = if_act
  73. self.act = act
  74. self.deconv = nn.Conv2DTranspose(
  75. in_channels=in_channels,
  76. out_channels=out_channels,
  77. kernel_size=kernel_size,
  78. stride=stride,
  79. padding=padding,
  80. groups=groups,
  81. weight_attr=ParamAttr(name=name + '_weights'),
  82. bias_attr=False)
  83. self.bn = nn.BatchNorm(
  84. num_channels=out_channels,
  85. act=act,
  86. param_attr=ParamAttr(name="bn_" + name + "_scale"),
  87. bias_attr=ParamAttr(name="bn_" + name + "_offset"),
  88. moving_mean_name="bn_" + name + "_mean",
  89. moving_variance_name="bn_" + name + "_variance",
  90. use_global_stats=False)
  91. def forward(self, x):
  92. x = self.deconv(x)
  93. x = self.bn(x)
  94. return x
class PGFPN(nn.Layer):
    """Feature Pyramid Network neck for the PGNet text detector.

    Takes seven backbone feature maps ``(c0..c6)``, fuses the three
    shallowest (c0..c2) top-down ("FPN_Down_Fusion") and the five deepest
    (c2..c6) bottom-up ("FPN UP Fusion"), then adds the two fused maps into
    a single 128-channel output.

    Args:
        in_channels: accepted for API compatibility but unused — the channel
            counts are hard-coded in ``num_inputs`` / ``num_outputs`` below.
        **kwargs: ignored.
    """

    def __init__(self, in_channels, **kwargs):
        super(PGFPN, self).__init__()
        # Channel counts of the five deepest backbone features, deepest
        # first (c6, c5, c4, c3, c2), and of the corresponding FPN outputs.
        num_inputs = [2048, 2048, 1024, 512, 256]
        num_outputs = [256, 256, 192, 192, 128]
        self.out_channels = 128
        # --- Down-fusion branch: per-level 3x3 projections of c0, c1, c2 ---
        self.conv_bn_layer_1 = ConvBNLayer(
            in_channels=3,
            out_channels=32,
            kernel_size=3,
            stride=1,
            act=None,
            name='FPN_d1')
        self.conv_bn_layer_2 = ConvBNLayer(
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=1,
            act=None,
            name='FPN_d2')
        self.conv_bn_layer_3 = ConvBNLayer(
            in_channels=256,
            out_channels=128,
            kernel_size=3,
            stride=1,
            act=None,
            name='FPN_d3')
        # Stride-2 convs downsample so each level matches the next one's
        # resolution before the elementwise add in forward().
        self.conv_bn_layer_4 = ConvBNLayer(
            in_channels=32,
            out_channels=64,
            kernel_size=3,
            stride=2,
            act=None,
            name='FPN_d4')
        self.conv_bn_layer_5 = ConvBNLayer(
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=1,
            act='relu',
            name='FPN_d5')
        self.conv_bn_layer_6 = ConvBNLayer(
            in_channels=64,
            out_channels=128,
            kernel_size=3,
            stride=2,
            act=None,
            name='FPN_d6')
        self.conv_bn_layer_7 = ConvBNLayer(
            in_channels=128,
            out_channels=128,
            kernel_size=3,
            stride=1,
            act='relu',
            name='FPN_d7')
        # Final 1x1 conv producing the down-fusion output (f_down).
        self.conv_bn_layer_8 = ConvBNLayer(
            in_channels=128,
            out_channels=128,
            kernel_size=1,
            stride=1,
            act=None,
            name='FPN_d8')
        # --- Up-fusion branch: 1x1 lateral projections of c6..c2 ---
        self.conv_h0 = ConvBNLayer(
            in_channels=num_inputs[0],
            out_channels=num_outputs[0],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_h{}".format(0))
        self.conv_h1 = ConvBNLayer(
            in_channels=num_inputs[1],
            out_channels=num_outputs[1],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_h{}".format(1))
        self.conv_h2 = ConvBNLayer(
            in_channels=num_inputs[2],
            out_channels=num_outputs[2],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_h{}".format(2))
        self.conv_h3 = ConvBNLayer(
            in_channels=num_inputs[3],
            out_channels=num_outputs[3],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_h{}".format(3))
        self.conv_h4 = ConvBNLayer(
            in_channels=num_inputs[4],
            out_channels=num_outputs[4],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_h{}".format(4))
        # Transposed convs upsample each fused level to the next (shallower)
        # level's resolution and channel count.
        self.dconv0 = DeConvBNLayer(
            in_channels=num_outputs[0],
            out_channels=num_outputs[0 + 1],
            name="dconv_{}".format(0))
        self.dconv1 = DeConvBNLayer(
            in_channels=num_outputs[1],
            out_channels=num_outputs[1 + 1],
            act=None,
            name="dconv_{}".format(1))
        self.dconv2 = DeConvBNLayer(
            in_channels=num_outputs[2],
            out_channels=num_outputs[2 + 1],
            act=None,
            name="dconv_{}".format(2))
        self.dconv3 = DeConvBNLayer(
            in_channels=num_outputs[3],
            out_channels=num_outputs[3 + 1],
            act=None,
            name="dconv_{}".format(3))
        # 3x3 smoothing convs applied after each add+relu in the up branch.
        self.conv_g1 = ConvBNLayer(
            in_channels=num_outputs[1],
            out_channels=num_outputs[1],
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv_g{}".format(1))
        self.conv_g2 = ConvBNLayer(
            in_channels=num_outputs[2],
            out_channels=num_outputs[2],
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv_g{}".format(2))
        self.conv_g3 = ConvBNLayer(
            in_channels=num_outputs[3],
            out_channels=num_outputs[3],
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv_g{}".format(3))
        self.conv_g4 = ConvBNLayer(
            in_channels=num_outputs[4],
            out_channels=num_outputs[4],
            kernel_size=3,
            stride=1,
            act='relu',
            name="conv_g{}".format(4))
        # Final 1x1 conv producing the up-fusion output (f_up).
        self.convf = ConvBNLayer(
            in_channels=num_outputs[4],
            out_channels=num_outputs[4],
            kernel_size=1,
            stride=1,
            act=None,
            name="conv_f{}".format(4))

    def forward(self, x):
        """Fuse the seven backbone features into one 128-channel map.

        Args:
            x: iterable of seven feature maps (c0..c6), shallowest first.

        Returns:
            Tensor with ``self.out_channels`` (128) channels: relu(f_down + f_up).
        """
        c0, c1, c2, c3, c4, c5, c6 = x
        # FPN_Down_Fusion: fuse the three shallow features top-down.
        f = [c0, c1, c2]
        g = [None, None, None]
        h = [None, None, None]
        h[0] = self.conv_bn_layer_1(f[0])
        h[1] = self.conv_bn_layer_2(f[1])
        h[2] = self.conv_bn_layer_3(f[2])
        # Each step: downsample previous level, add the lateral projection,
        # relu, then smooth/downsample toward the next level.
        g[0] = self.conv_bn_layer_4(h[0])
        g[1] = paddle.add(g[0], h[1])
        g[1] = F.relu(g[1])
        g[1] = self.conv_bn_layer_5(g[1])
        g[1] = self.conv_bn_layer_6(g[1])
        g[2] = paddle.add(g[1], h[2])
        g[2] = F.relu(g[2])
        g[2] = self.conv_bn_layer_7(g[2])
        f_down = self.conv_bn_layer_8(g[2])
        # FPN UP Fusion: fuse the five deep features bottom-up (deepest first).
        f1 = [c6, c5, c4, c3, c2]
        g = [None, None, None, None, None]
        h = [None, None, None, None, None]
        h[0] = self.conv_h0(f1[0])
        h[1] = self.conv_h1(f1[1])
        h[2] = self.conv_h2(f1[2])
        h[3] = self.conv_h3(f1[3])
        h[4] = self.conv_h4(f1[4])
        # Each step: upsample previous level, add the lateral projection,
        # relu, smooth with conv_g, then upsample toward the next level.
        g[0] = self.dconv0(h[0])
        g[1] = paddle.add(g[0], h[1])
        g[1] = F.relu(g[1])
        g[1] = self.conv_g1(g[1])
        g[1] = self.dconv1(g[1])
        g[2] = paddle.add(g[1], h[2])
        g[2] = F.relu(g[2])
        g[2] = self.conv_g2(g[2])
        g[2] = self.dconv2(g[2])
        g[3] = paddle.add(g[2], h[3])
        g[3] = F.relu(g[3])
        g[3] = self.conv_g3(g[3])
        g[3] = self.dconv3(g[3])
        g[4] = paddle.add(x=g[3], y=h[4])
        g[4] = F.relu(g[4])
        g[4] = self.conv_g4(g[4])
        f_up = self.convf(g[4])
        # Merge the two branches (both 128 channels at the same resolution).
        f_common = paddle.add(f_down, f_up)
        f_common = F.relu(f_common)
        return f_common