table_fpn.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
from paddle import nn
import paddle.nn.functional as F
from paddle import ParamAttr


class TableFPN(nn.Layer):
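    """FPN-style neck for table recognition models.

    1x1 lateral convolutions project the four backbone feature maps
    (c2..c5) to 512 channels, the levels are fused top-down with
    nearest-neighbour upsampling, and every level is then resampled to
    c5's resolution, concatenated, and mixed by a 3x3 convolution whose
    output is added onto c5 as a small scaled residual.
    """
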
    def __init__(self, in_channels, out_channels, **kwargs):
        super(TableFPN, self).__init__()
        # out_channels is accepted by the constructor, but the neck is
        # hard-coded to 512 channels.
        self.out_channels = 512
        weight_attr = paddle.nn.initializer.KaimingUniform()
        # 1x1 lateral convolutions projecting c2..c5 to self.out_channels.
        self.in2_conv = nn.Conv2D(
            in_channels=in_channels[0],
            out_channels=self.out_channels,
            kernel_size=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.in3_conv = nn.Conv2D(
            in_channels=in_channels[1],
            out_channels=self.out_channels,
            kernel_size=1,
            stride=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.in4_conv = nn.Conv2D(
            in_channels=in_channels[2],
            out_channels=self.out_channels,
            kernel_size=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.in5_conv = nn.Conv2D(
            in_channels=in_channels[3],
            out_channels=self.out_channels,
            kernel_size=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
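        # Note: the 3x3 p5_conv..p2_conv layers below are created but never
        # referenced in forward(); they are kept as in the original file.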
        self.p5_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.p4_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.p3_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.p2_conv = nn.Conv2D(
            in_channels=self.out_channels,
            out_channels=self.out_channels // 4,
            kernel_size=3,
            padding=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)
        self.fuse_conv = nn.Conv2D(
            in_channels=self.out_channels * 4,
            out_channels=512,
            kernel_size=3,
            padding=1,
            weight_attr=ParamAttr(initializer=weight_attr),
            bias_attr=False)

    def forward(self, x):
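        """Fuse the backbone pyramid into a single c5-sized feature map.

        x is expected to be the list [c2, c3, c4, c5] of backbone feature
        maps ordered from highest to lowest resolution; the final residual
        addition requires c5 to already have 512 channels.
        """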
        c2, c3, c4, c5 = x
        # 1x1 lateral projections to 512 channels.
        in5 = self.in5_conv(c5)
        in4 = self.in4_conv(c4)
        in3 = self.in3_conv(c3)
        in2 = self.in2_conv(c2)
        # Top-down pathway: upsample the coarser map and add it to the finer one.
        out4 = in4 + F.upsample(
            in5, size=in4.shape[2:4], mode="nearest", align_mode=1)  # 1/16
        out3 = in3 + F.upsample(
            out4, size=in3.shape[2:4], mode="nearest", align_mode=1)  # 1/8
        out2 = in2 + F.upsample(
            out3, size=in2.shape[2:4], mode="nearest", align_mode=1)  # 1/4
        # Resample every fused level back to c5's resolution and concatenate.
        p4 = F.upsample(out4, size=in5.shape[2:4], mode="nearest", align_mode=1)
        p3 = F.upsample(out3, size=in5.shape[2:4], mode="nearest", align_mode=1)
        p2 = F.upsample(out2, size=in5.shape[2:4], mode="nearest", align_mode=1)
        fuse = paddle.concat([in5, p4, p3, p2], axis=1)
        # Mix the 4*512-channel stack down to 512 channels and add it onto c5
        # as a small scaled residual.
        fuse_conv = self.fuse_conv(fuse) * 0.005
        return [c5 + fuse_conv]
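

# -----------------------------------------------------------------------------
# Usage sketch (not part of the original PaddleOCR file): a minimal example of
# how this neck could be exercised on dummy tensors. The channel counts and
# spatial sizes below are illustrative assumptions; the last backbone stage is
# given 512 channels so that the residual addition onto c5 is shape-compatible.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    fpn = TableFPN(in_channels=[64, 128, 256, 512], out_channels=512)
    # Four feature maps at strides 4, 8, 16 and 32 for a 640x640 input.
    feats = [
        paddle.rand([1, 64, 160, 160]),
        paddle.rand([1, 128, 80, 80]),
        paddle.rand([1, 256, 40, 40]),
        paddle.rand([1, 512, 20, 20]),
    ]
    out = fpn(feats)
    print(out[0].shape)  # [1, 512, 20, 20] -- same shape as c5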