ct_fpn.py

# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import math

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import TruncatedNormal, Constant, Normal

ones_ = Constant(value=1.)
zeros_ = Constant(value=0.)

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../../..')))


class Conv_BN_ReLU(nn.Layer):
    def __init__(self,
                 in_planes,
                 out_planes,
                 kernel_size=1,
                 stride=1,
                 padding=0):
        super(Conv_BN_ReLU, self).__init__()
        self.conv = nn.Conv2D(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias_attr=False)
        self.bn = nn.BatchNorm2D(out_planes)
        self.relu = nn.ReLU()

        # He-style initialization for conv weights; BN scale starts at 1, bias at 0.
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                n = m._kernel_size[0] * m._kernel_size[1] * m._out_channels
                normal_ = Normal(mean=0.0, std=math.sqrt(2. / n))
                normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2D):
                zeros_(m.bias)
                ones_(m.weight)

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))


class FPEM(nn.Layer):
    def __init__(self, in_channels, out_channels):
        super(FPEM, self).__init__()
        planes = out_channels
        self.dwconv3_1 = nn.Conv2D(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=planes,
            bias_attr=False)
        self.smooth_layer3_1 = Conv_BN_ReLU(planes, planes)

        self.dwconv2_1 = nn.Conv2D(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=planes,
            bias_attr=False)
        self.smooth_layer2_1 = Conv_BN_ReLU(planes, planes)

        self.dwconv1_1 = nn.Conv2D(
            planes,
            planes,
            kernel_size=3,
            stride=1,
            padding=1,
            groups=planes,
            bias_attr=False)
        self.smooth_layer1_1 = Conv_BN_ReLU(planes, planes)

        self.dwconv2_2 = nn.Conv2D(
            planes,
            planes,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=planes,
            bias_attr=False)
        self.smooth_layer2_2 = Conv_BN_ReLU(planes, planes)

        self.dwconv3_2 = nn.Conv2D(
            planes,
            planes,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=planes,
            bias_attr=False)
        self.smooth_layer3_2 = Conv_BN_ReLU(planes, planes)

        self.dwconv4_2 = nn.Conv2D(
            planes,
            planes,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=planes,
            bias_attr=False)
        self.smooth_layer4_2 = Conv_BN_ReLU(planes, planes)

    def _upsample_add(self, x, y):
        return F.upsample(x, scale_factor=2, mode='bilinear') + y

    def forward(self, f1, f2, f3, f4):
        # up-down
        f3 = self.smooth_layer3_1(self.dwconv3_1(self._upsample_add(f4, f3)))
        f2 = self.smooth_layer2_1(self.dwconv2_1(self._upsample_add(f3, f2)))
        f1 = self.smooth_layer1_1(self.dwconv1_1(self._upsample_add(f2, f1)))

        # down-up
        f2 = self.smooth_layer2_2(self.dwconv2_2(self._upsample_add(f2, f1)))
        f3 = self.smooth_layer3_2(self.dwconv3_2(self._upsample_add(f3, f2)))
        f4 = self.smooth_layer4_2(self.dwconv4_2(self._upsample_add(f4, f3)))
        return f1, f2, f3, f4

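
# Shape note (a sketch, assuming reduced backbone features at strides 4/8/16/32,
# e.g. f1: N x 128 x 160 x 160, f2: N x 128 x 80 x 80, f3: N x 128 x 40 x 40,
# f4: N x 128 x 20 x 20): the up-down pass upsamples the coarser map by 2 before
# each stride-1 depthwise conv, and the down-up pass uses stride-2 depthwise
# convs, so each output keeps the same shape as the corresponding input.

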
class CTFPN(nn.Layer):
    def __init__(self, in_channels, out_channel=128):
        super(CTFPN, self).__init__()
        self.out_channels = out_channel * 4

        self.reduce_layer1 = Conv_BN_ReLU(in_channels[0], 128)
        self.reduce_layer2 = Conv_BN_ReLU(in_channels[1], 128)
        self.reduce_layer3 = Conv_BN_ReLU(in_channels[2], 128)
        self.reduce_layer4 = Conv_BN_ReLU(in_channels[3], 128)

        self.fpem1 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128)
        self.fpem2 = FPEM(in_channels=(64, 128, 256, 512), out_channels=128)

    def _upsample(self, x, scale=1):
        return F.upsample(x, scale_factor=scale, mode='bilinear')

    def forward(self, f):
        # reduce channel
        f1 = self.reduce_layer1(f[0])  # N, 64, 160, 160  --> N, 128, 160, 160
        f2 = self.reduce_layer2(f[1])  # N, 128, 80, 80   --> N, 128, 80, 80
        f3 = self.reduce_layer3(f[2])  # N, 256, 40, 40   --> N, 128, 40, 40
        f4 = self.reduce_layer4(f[3])  # N, 512, 20, 20   --> N, 128, 20, 20

        # FPEM
        f1_1, f2_1, f3_1, f4_1 = self.fpem1(f1, f2, f3, f4)
        f1_2, f2_2, f3_2, f4_2 = self.fpem2(f1_1, f2_1, f3_1, f4_1)

        # FFM
        f1 = f1_1 + f1_2
        f2 = f2_1 + f2_2
        f3 = f3_1 + f3_2
        f4 = f4_1 + f4_2

        f2 = self._upsample(f2, scale=2)
        f3 = self._upsample(f3, scale=4)
        f4 = self._upsample(f4, scale=8)

        ff = paddle.concat((f1, f2, f3, f4), 1)  # N, 512, 160, 160
        return ff
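

# A minimal smoke test, not part of the original module: it assumes backbone
# feature maps with channels (64, 128, 256, 512) at strides 4/8/16/32 of a
# 640x640 input, which matches the shape comments in CTFPN.forward above.
if __name__ == "__main__":
    # Dummy pyramid features: [N, C, H, W] at strides 4, 8, 16, 32.
    feats = [
        paddle.randn([1, 64, 160, 160]),
        paddle.randn([1, 128, 80, 80]),
        paddle.randn([1, 256, 40, 40]),
        paddle.randn([1, 512, 20, 20]),
    ]
    neck = CTFPN(in_channels=[64, 128, 256, 512])
    out = neck(feats)
    # Expect [1, 512, 160, 160]: four 128-channel maps concatenated at stride 4.
    print(out.shape)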