# base_module.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. from arch.spectral_norm import spectral_norm
  17. class CBN(nn.Layer):
  18. def __init__(self,
  19. name,
  20. in_channels,
  21. out_channels,
  22. kernel_size,
  23. stride=1,
  24. padding=0,
  25. dilation=1,
  26. groups=1,
  27. use_bias=False,
  28. norm_layer=None,
  29. act=None,
  30. act_attr=None):
  31. super(CBN, self).__init__()
  32. if use_bias:
  33. bias_attr = paddle.ParamAttr(name=name + "_bias")
  34. else:
  35. bias_attr = None
  36. self._conv = paddle.nn.Conv2D(
  37. in_channels=in_channels,
  38. out_channels=out_channels,
  39. kernel_size=kernel_size,
  40. stride=stride,
  41. padding=padding,
  42. dilation=dilation,
  43. groups=groups,
  44. weight_attr=paddle.ParamAttr(name=name + "_weights"),
  45. bias_attr=bias_attr)
  46. if norm_layer:
  47. self._norm_layer = getattr(paddle.nn, norm_layer)(
  48. num_features=out_channels, name=name + "_bn")
  49. else:
  50. self._norm_layer = None
  51. if act:
  52. if act_attr:
  53. self._act = getattr(paddle.nn, act)(**act_attr,
  54. name=name + "_" + act)
  55. else:
  56. self._act = getattr(paddle.nn, act)(name=name + "_" + act)
  57. else:
  58. self._act = None
  59. def forward(self, x):
  60. out = self._conv(x)
  61. if self._norm_layer:
  62. out = self._norm_layer(out)
  63. if self._act:
  64. out = self._act(out)
  65. return out
  66. class SNConv(nn.Layer):
  67. def __init__(self,
  68. name,
  69. in_channels,
  70. out_channels,
  71. kernel_size,
  72. stride=1,
  73. padding=0,
  74. dilation=1,
  75. groups=1,
  76. use_bias=False,
  77. norm_layer=None,
  78. act=None,
  79. act_attr=None):
  80. super(SNConv, self).__init__()
  81. if use_bias:
  82. bias_attr = paddle.ParamAttr(name=name + "_bias")
  83. else:
  84. bias_attr = None
  85. self._sn_conv = spectral_norm(
  86. paddle.nn.Conv2D(
  87. in_channels=in_channels,
  88. out_channels=out_channels,
  89. kernel_size=kernel_size,
  90. stride=stride,
  91. padding=padding,
  92. dilation=dilation,
  93. groups=groups,
  94. weight_attr=paddle.ParamAttr(name=name + "_weights"),
  95. bias_attr=bias_attr))
  96. if norm_layer:
  97. self._norm_layer = getattr(paddle.nn, norm_layer)(
  98. num_features=out_channels, name=name + "_bn")
  99. else:
  100. self._norm_layer = None
  101. if act:
  102. if act_attr:
  103. self._act = getattr(paddle.nn, act)(**act_attr,
  104. name=name + "_" + act)
  105. else:
  106. self._act = getattr(paddle.nn, act)(name=name + "_" + act)
  107. else:
  108. self._act = None
  109. def forward(self, x):
  110. out = self._sn_conv(x)
  111. if self._norm_layer:
  112. out = self._norm_layer(out)
  113. if self._act:
  114. out = self._act(out)
  115. return out
  116. class SNConvTranspose(nn.Layer):
  117. def __init__(self,
  118. name,
  119. in_channels,
  120. out_channels,
  121. kernel_size,
  122. stride=1,
  123. padding=0,
  124. output_padding=0,
  125. dilation=1,
  126. groups=1,
  127. use_bias=False,
  128. norm_layer=None,
  129. act=None,
  130. act_attr=None):
  131. super(SNConvTranspose, self).__init__()
  132. if use_bias:
  133. bias_attr = paddle.ParamAttr(name=name + "_bias")
  134. else:
  135. bias_attr = None
  136. self._sn_conv_transpose = spectral_norm(
  137. paddle.nn.Conv2DTranspose(
  138. in_channels=in_channels,
  139. out_channels=out_channels,
  140. kernel_size=kernel_size,
  141. stride=stride,
  142. padding=padding,
  143. output_padding=output_padding,
  144. dilation=dilation,
  145. groups=groups,
  146. weight_attr=paddle.ParamAttr(name=name + "_weights"),
  147. bias_attr=bias_attr))
  148. if norm_layer:
  149. self._norm_layer = getattr(paddle.nn, norm_layer)(
  150. num_features=out_channels, name=name + "_bn")
  151. else:
  152. self._norm_layer = None
  153. if act:
  154. if act_attr:
  155. self._act = getattr(paddle.nn, act)(**act_attr,
  156. name=name + "_" + act)
  157. else:
  158. self._act = getattr(paddle.nn, act)(name=name + "_" + act)
  159. else:
  160. self._act = None
  161. def forward(self, x):
  162. out = self._sn_conv_transpose(x)
  163. if self._norm_layer:
  164. out = self._norm_layer(out)
  165. if self._act:
  166. out = self._act(out)
  167. return out
  168. class MiddleNet(nn.Layer):
  169. def __init__(self, name, in_channels, mid_channels, out_channels,
  170. use_bias):
  171. super(MiddleNet, self).__init__()
  172. self._sn_conv1 = SNConv(
  173. name=name + "_sn_conv1",
  174. in_channels=in_channels,
  175. out_channels=mid_channels,
  176. kernel_size=1,
  177. use_bias=use_bias,
  178. norm_layer=None,
  179. act=None)
  180. self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
  181. self._sn_conv2 = SNConv(
  182. name=name + "_sn_conv2",
  183. in_channels=mid_channels,
  184. out_channels=mid_channels,
  185. kernel_size=3,
  186. use_bias=use_bias)
  187. self._sn_conv3 = SNConv(
  188. name=name + "_sn_conv3",
  189. in_channels=mid_channels,
  190. out_channels=out_channels,
  191. kernel_size=1,
  192. use_bias=use_bias)
  193. def forward(self, x):
  194. sn_conv1 = self._sn_conv1.forward(x)
  195. pad_2d = self._pad2d.forward(sn_conv1)
  196. sn_conv2 = self._sn_conv2.forward(pad_2d)
  197. sn_conv3 = self._sn_conv3.forward(sn_conv2)
  198. return sn_conv3
  199. class ResBlock(nn.Layer):
  200. def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
  201. use_bias):
  202. super(ResBlock, self).__init__()
  203. if use_dilation:
  204. padding_mat = [1, 1, 1, 1]
  205. else:
  206. padding_mat = [0, 0, 0, 0]
  207. self._pad1 = nn.Pad2D(padding_mat, mode="replicate")
  208. self._sn_conv1 = SNConv(
  209. name=name + "_sn_conv1",
  210. in_channels=channels,
  211. out_channels=channels,
  212. kernel_size=3,
  213. padding=0,
  214. norm_layer=norm_layer,
  215. use_bias=use_bias,
  216. act="ReLU",
  217. act_attr=None)
  218. if use_dropout:
  219. self._dropout = nn.Dropout(0.5)
  220. else:
  221. self._dropout = None
  222. self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate")
  223. self._sn_conv2 = SNConv(
  224. name=name + "_sn_conv2",
  225. in_channels=channels,
  226. out_channels=channels,
  227. kernel_size=3,
  228. norm_layer=norm_layer,
  229. use_bias=use_bias,
  230. act="ReLU",
  231. act_attr=None)
  232. def forward(self, x):
  233. pad1 = self._pad1.forward(x)
  234. sn_conv1 = self._sn_conv1.forward(pad1)
  235. pad2 = self._pad2.forward(sn_conv1)
  236. sn_conv2 = self._sn_conv2.forward(pad2)
  237. return sn_conv2 + x