encoder.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. from arch.base_module import SNConv, SNConvTranspose, ResBlock
  17. class Encoder(nn.Layer):
  18. def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
  19. act, act_attr, conv_block_dropout, conv_block_num,
  20. conv_block_dilation):
  21. super(Encoder, self).__init__()
  22. self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
  23. self._in_conv = SNConv(
  24. name=name + "_in_conv",
  25. in_channels=in_channels,
  26. out_channels=encode_dim,
  27. kernel_size=7,
  28. use_bias=use_bias,
  29. norm_layer=norm_layer,
  30. act=act,
  31. act_attr=act_attr)
  32. self._down1 = SNConv(
  33. name=name + "_down1",
  34. in_channels=encode_dim,
  35. out_channels=encode_dim * 2,
  36. kernel_size=3,
  37. stride=2,
  38. padding=1,
  39. use_bias=use_bias,
  40. norm_layer=norm_layer,
  41. act=act,
  42. act_attr=act_attr)
  43. self._down2 = SNConv(
  44. name=name + "_down2",
  45. in_channels=encode_dim * 2,
  46. out_channels=encode_dim * 4,
  47. kernel_size=3,
  48. stride=2,
  49. padding=1,
  50. use_bias=use_bias,
  51. norm_layer=norm_layer,
  52. act=act,
  53. act_attr=act_attr)
  54. self._down3 = SNConv(
  55. name=name + "_down3",
  56. in_channels=encode_dim * 4,
  57. out_channels=encode_dim * 4,
  58. kernel_size=3,
  59. stride=2,
  60. padding=1,
  61. use_bias=use_bias,
  62. norm_layer=norm_layer,
  63. act=act,
  64. act_attr=act_attr)
  65. conv_blocks = []
  66. for i in range(conv_block_num):
  67. conv_blocks.append(
  68. ResBlock(
  69. name="{}_conv_block_{}".format(name, i),
  70. channels=encode_dim * 4,
  71. norm_layer=norm_layer,
  72. use_dropout=conv_block_dropout,
  73. use_dilation=conv_block_dilation,
  74. use_bias=use_bias))
  75. self._conv_blocks = nn.Sequential(*conv_blocks)
  76. def forward(self, x):
  77. out_dict = dict()
  78. x = self._pad2d(x)
  79. out_dict["in_conv"] = self._in_conv.forward(x)
  80. out_dict["down1"] = self._down1.forward(out_dict["in_conv"])
  81. out_dict["down2"] = self._down2.forward(out_dict["down1"])
  82. out_dict["down3"] = self._down3.forward(out_dict["down2"])
  83. out_dict["res_blocks"] = self._conv_blocks.forward(out_dict["down3"])
  84. return out_dict
  85. class EncoderUnet(nn.Layer):
  86. def __init__(self, name, in_channels, encode_dim, use_bias, norm_layer,
  87. act, act_attr):
  88. super(EncoderUnet, self).__init__()
  89. self._pad2d = paddle.nn.Pad2D([3, 3, 3, 3], mode="replicate")
  90. self._in_conv = SNConv(
  91. name=name + "_in_conv",
  92. in_channels=in_channels,
  93. out_channels=encode_dim,
  94. kernel_size=7,
  95. use_bias=use_bias,
  96. norm_layer=norm_layer,
  97. act=act,
  98. act_attr=act_attr)
  99. self._down1 = SNConv(
  100. name=name + "_down1",
  101. in_channels=encode_dim,
  102. out_channels=encode_dim * 2,
  103. kernel_size=3,
  104. stride=2,
  105. padding=1,
  106. use_bias=use_bias,
  107. norm_layer=norm_layer,
  108. act=act,
  109. act_attr=act_attr)
  110. self._down2 = SNConv(
  111. name=name + "_down2",
  112. in_channels=encode_dim * 2,
  113. out_channels=encode_dim * 2,
  114. kernel_size=3,
  115. stride=2,
  116. padding=1,
  117. use_bias=use_bias,
  118. norm_layer=norm_layer,
  119. act=act,
  120. act_attr=act_attr)
  121. self._down3 = SNConv(
  122. name=name + "_down3",
  123. in_channels=encode_dim * 2,
  124. out_channels=encode_dim * 2,
  125. kernel_size=3,
  126. stride=2,
  127. padding=1,
  128. use_bias=use_bias,
  129. norm_layer=norm_layer,
  130. act=act,
  131. act_attr=act_attr)
  132. self._down4 = SNConv(
  133. name=name + "_down4",
  134. in_channels=encode_dim * 2,
  135. out_channels=encode_dim * 2,
  136. kernel_size=3,
  137. stride=2,
  138. padding=1,
  139. use_bias=use_bias,
  140. norm_layer=norm_layer,
  141. act=act,
  142. act_attr=act_attr)
  143. self._up1 = SNConvTranspose(
  144. name=name + "_up1",
  145. in_channels=encode_dim * 2,
  146. out_channels=encode_dim * 2,
  147. kernel_size=3,
  148. stride=2,
  149. padding=1,
  150. use_bias=use_bias,
  151. norm_layer=norm_layer,
  152. act=act,
  153. act_attr=act_attr)
  154. self._up2 = SNConvTranspose(
  155. name=name + "_up2",
  156. in_channels=encode_dim * 4,
  157. out_channels=encode_dim * 4,
  158. kernel_size=3,
  159. stride=2,
  160. padding=1,
  161. use_bias=use_bias,
  162. norm_layer=norm_layer,
  163. act=act,
  164. act_attr=act_attr)
  165. def forward(self, x):
  166. output_dict = dict()
  167. x = self._pad2d(x)
  168. output_dict['in_conv'] = self._in_conv.forward(x)
  169. output_dict['down1'] = self._down1.forward(output_dict['in_conv'])
  170. output_dict['down2'] = self._down2.forward(output_dict['down1'])
  171. output_dict['down3'] = self._down3.forward(output_dict['down2'])
  172. output_dict['down4'] = self._down4.forward(output_dict['down3'])
  173. output_dict['up1'] = self._up1.forward(output_dict['down4'])
  174. output_dict['up2'] = self._up2.forward(
  175. paddle.concat(
  176. (output_dict['down3'], output_dict['up1']), axis=1))
  177. output_dict['concat'] = paddle.concat(
  178. (output_dict['down2'], output_dict['up2']), axis=1)
  179. return output_dict