123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255 |
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- import paddle
- import paddle.nn as nn
- from arch.spectral_norm import spectral_norm
class CBN(nn.Layer):
    """Conv2D followed by an optional normalization layer and activation.

    Args:
        name: prefix used for the parameter / sublayer names created here
            (``<name>_weights``, ``<name>_bias``, ``<name>_bn``,
            ``<name>_<act>``).
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups: forwarded unchanged to ``paddle.nn.Conv2D``.
        use_bias: when true, the conv bias gets an explicitly named
            ``ParamAttr``; otherwise ``bias_attr=None`` is passed.
        norm_layer: name of a ``paddle.nn`` class instantiated with
            ``num_features=out_channels``, or a falsy value for none.
        act: name of a ``paddle.nn`` activation class, or a falsy value for
            none.
        act_attr: extra keyword arguments for the activation constructor.
    """

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super().__init__()
        # NOTE(review): Paddle interprets bias_attr=None as "create a default
        # bias", so use_bias=False still yields a bias parameter (only
        # bias_attr=False disables it). Left unchanged to keep existing
        # checkpoints loadable -- confirm whether this is intended.
        bias_attr = paddle.ParamAttr(name=name + "_bias") if use_bias else None

        # Attribute names (_conv, _norm_layer, _act) form state-dict keys and
        # must stay stable.
        self._conv = paddle.nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            weight_attr=paddle.ParamAttr(name=name + "_weights"),
            bias_attr=bias_attr)

        self._norm_layer = (getattr(paddle.nn, norm_layer)(
            num_features=out_channels, name=name + "_bn")
                            if norm_layer else None)

        if not act:
            self._act = None
        else:
            extra_kwargs = act_attr if act_attr else {}
            self._act = getattr(paddle.nn, act)(name=name + "_" + act,
                                                **extra_kwargs)

    def forward(self, x):
        """Apply conv, then the norm layer and activation when present."""
        result = self._conv(x)
        for stage in (self._norm_layer, self._act):
            if stage:
                result = stage(result)
        return result
class SNConv(nn.Layer):
    """Spectrally-normalized Conv2D with optional normalization and activation.

    The convolution is wrapped with ``spectral_norm`` (imported from
    ``arch.spectral_norm``) and stored as ``self._sn_conv``.

    Args:
        name: prefix used for created parameter / sublayer names.
        in_channels, out_channels, kernel_size, stride, padding, dilation,
        groups: forwarded unchanged to ``paddle.nn.Conv2D``.
        use_bias: when true, the conv bias gets an explicitly named
            ``ParamAttr``; otherwise ``bias_attr=None`` is passed.
        norm_layer: name of a ``paddle.nn`` class instantiated with
            ``num_features=out_channels``, or a falsy value for none.
        act: name of a ``paddle.nn`` activation class, or a falsy value for
            none.
        act_attr: extra keyword arguments for the activation constructor.
    """

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super().__init__()
        # NOTE(review): bias_attr=None still creates a default bias in Paddle;
        # use_bias=False therefore does not remove the bias. Kept as-is for
        # checkpoint compatibility -- confirm intent.
        bias_attr = paddle.ParamAttr(name=name + "_bias") if use_bias else None

        plain_conv = paddle.nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            weight_attr=paddle.ParamAttr(name=name + "_weights"),
            bias_attr=bias_attr)
        self._sn_conv = spectral_norm(plain_conv)

        if norm_layer:
            norm_cls = getattr(paddle.nn, norm_layer)
            self._norm_layer = norm_cls(
                num_features=out_channels, name=name + "_bn")
        else:
            self._norm_layer = None

        if act:
            act_cls = getattr(paddle.nn, act)
            act_kwargs = act_attr if act_attr else {}
            self._act = act_cls(name=name + "_" + act, **act_kwargs)
        else:
            self._act = None

    def forward(self, x):
        """Run the spectral-norm conv, then optional norm and activation."""
        y = self._sn_conv(x)
        if self._norm_layer:
            y = self._norm_layer(y)
        return self._act(y) if self._act else y
class SNConvTranspose(nn.Layer):
    """Spectrally-normalized Conv2DTranspose with optional norm and activation.

    The transposed convolution is wrapped with ``spectral_norm`` (imported
    from ``arch.spectral_norm``) and stored as ``self._sn_conv_transpose``.

    Args:
        name: prefix used for created parameter / sublayer names.
        in_channels, out_channels, kernel_size, stride, padding,
        output_padding, dilation, groups: forwarded unchanged to
            ``paddle.nn.Conv2DTranspose``.
        use_bias: when true, the bias gets an explicitly named ``ParamAttr``;
            otherwise ``bias_attr=None`` is passed.
        norm_layer: name of a ``paddle.nn`` class instantiated with
            ``num_features=out_channels``, or a falsy value for none.
        act: name of a ``paddle.nn`` activation class, or a falsy value for
            none.
        act_attr: extra keyword arguments for the activation constructor.
    """

    def __init__(self,
                 name,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 output_padding=0,
                 dilation=1,
                 groups=1,
                 use_bias=False,
                 norm_layer=None,
                 act=None,
                 act_attr=None):
        super().__init__()
        # NOTE(review): as with Conv2D, bias_attr=None makes Paddle create a
        # default bias, so use_bias=False keeps a bias. Unchanged to preserve
        # checkpoint compatibility -- confirm intent.
        bias_attr = paddle.ParamAttr(name=name + "_bias") if use_bias else None

        deconv = paddle.nn.Conv2DTranspose(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
            weight_attr=paddle.ParamAttr(name=name + "_weights"),
            bias_attr=bias_attr)
        self._sn_conv_transpose = spectral_norm(deconv)

        self._norm_layer = None
        if norm_layer:
            self._norm_layer = getattr(paddle.nn, norm_layer)(
                num_features=out_channels, name=name + "_bn")

        self._act = None
        if act:
            self._act = getattr(paddle.nn, act)(name=name + "_" + act,
                                                **(act_attr if act_attr else {}))

    def forward(self, x):
        """Run the spectral-norm transposed conv, then optional norm/act."""
        out = self._sn_conv_transpose(x)
        for post in (self._norm_layer, self._act):
            if post:
                out = post(out)
        return out
class MiddleNet(nn.Layer):
    """Stack of three spectral-norm convs: 1x1, padded 3x3, then 1x1.

    The 3x3 conv uses no internal padding; it is preceded by a 1-pixel
    replicate ``Pad2D``, so the spatial size is unchanged through the stack.

    Args:
        name: prefix for the names of the three ``SNConv`` sublayers.
        in_channels: channel count of the input.
        mid_channels: channel count produced by the first two convs.
        out_channels: channel count produced by the last 1x1 conv.
        use_bias: forwarded to every ``SNConv``.
    """

    def __init__(self, name, in_channels, mid_channels, out_channels,
                 use_bias):
        super(MiddleNet, self).__init__()
        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=in_channels,
            out_channels=mid_channels,
            kernel_size=1,
            use_bias=use_bias,
            norm_layer=None,
            act=None)
        self._pad2d = nn.Pad2D(padding=[1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=mid_channels,
            out_channels=mid_channels,
            kernel_size=3,
            use_bias=use_bias)
        self._sn_conv3 = SNConv(
            name=name + "_sn_conv3",
            in_channels=mid_channels,
            out_channels=out_channels,
            kernel_size=1,
            use_bias=use_bias)

    def forward(self, x):
        # Call sublayers directly (nn.Layer.__call__) rather than via
        # .forward(), so forward pre/post hooks registered on them still run.
        out = self._sn_conv1(x)
        out = self._pad2d(out)
        out = self._sn_conv2(out)
        return self._sn_conv3(out)
class ResBlock(nn.Layer):
    """Residual block: two padded spectral-norm 3x3 conv stages plus skip.

    Each stage is ``Pad2D`` (replicate) followed by an ``SNConv`` with a 3x3
    kernel, no internal padding, the given ``norm_layer`` and ReLU. The input
    is added to the output of the second stage.

    Args:
        name: prefix for the names of the two ``SNConv`` sublayers.
        channels: input and output channel count (kept constant).
        norm_layer: name of a ``paddle.nn`` norm class passed to ``SNConv``.
        use_dropout: when true, ``nn.Dropout(0.5)`` is applied between the two
            conv stages.
        use_dilation: selects the padding before the first conv (1 pixel when
            true, none when false); see the NOTE below.
        use_bias: forwarded to both ``SNConv`` layers.
    """

    def __init__(self, name, channels, norm_layer, use_dropout, use_dilation,
                 use_bias):
        super(ResBlock, self).__init__()
        # NOTE(review): with use_dilation=False the first 3x3 conv sees no
        # padding, so the block's output is 2 pixels smaller per spatial dim
        # than x and the residual sum in forward() cannot line up. No dilation
        # is ever passed to SNConv either. Only use_dilation=True appears
        # usable as written -- confirm the intended behaviour.
        if use_dilation:
            padding_mat = [1, 1, 1, 1]
        else:
            padding_mat = [0, 0, 0, 0]
        self._pad1 = nn.Pad2D(padding_mat, mode="replicate")
        self._sn_conv1 = SNConv(
            name=name + "_sn_conv1",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            padding=0,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)
        if use_dropout:
            self._dropout = nn.Dropout(0.5)
        else:
            self._dropout = None
        self._pad2 = nn.Pad2D([1, 1, 1, 1], mode="replicate")
        self._sn_conv2 = SNConv(
            name=name + "_sn_conv2",
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            norm_layer=norm_layer,
            use_bias=use_bias,
            act="ReLU",
            act_attr=None)

    def forward(self, x):
        # Sublayers are invoked via __call__ (not .forward()) so that any
        # registered forward hooks are honoured.
        out = self._pad1(x)
        out = self._sn_conv1(out)
        # Previously the dropout layer was constructed but never applied, so
        # use_dropout had no effect. Apply it between the two conv stages.
        if self._dropout is not None:
            out = self._dropout(out)
        out = self._pad2(out)
        out = self._sn_conv2(out)
        return out + x
|