# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
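

# Helper that fills `x` in place with samples drawn from N(mean, std) and
# returns it (serving the same role as torch.nn.init.normal_).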
def normal_(x, mean=0., std=1.):
    temp_value = paddle.normal(mean, std, shape=x.shape)
    x.set_value(temp_value)
    return x


class SpectralNorm(object):
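    """Spectral normalization (Miyato et al., 2018) as a forward pre-hook.

    Rescales the wrapped module's weight by its largest singular value,
    estimated with power iteration; mirrors the structure of
    torch.nn.utils.spectral_norm.
    """
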
    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError('Expected n_power_iterations to be positive, but '
                             'got n_power_iterations={}'.format(
                                 n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps
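
    # Collapses an N-D weight into a 2-D matrix whose leading axis is `dim`,
    # e.g. a Conv2D weight of shape [out_c, in_c, kh, kw] with dim=0 becomes
    # a matrix of shape [out_c, in_c * kh * kw].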
    def reshape_weight_to_matrix(self, weight):
        weight_mat = weight
        if self.dim != 0:
            # transpose dim to front
            weight_mat = weight_mat.transpose([
                self.dim,
                *[d for d in range(weight_mat.dim()) if d != self.dim]
            ])
        height = weight_mat.shape[0]
        return weight_mat.reshape([height, -1])

    def compute_weight(self, module, do_power_iteration):
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)
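        # Each power-iteration step refines the singular-vector estimates:
        #   v <- normalize(W^T u),  u <- normalize(W v)
        # after which sigma = u^T W v approximates the spectral norm of W.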
        if do_power_iteration:
            with paddle.no_grad():
                for _ in range(self.n_power_iterations):
                    v.set_value(
                        F.normalize(
                            paddle.matmul(
                                weight_mat,
                                u,
                                transpose_x=True,
                                transpose_y=False),
                            axis=0,
                            epsilon=self.eps))
                    u.set_value(
                        F.normalize(
                            paddle.matmul(weight_mat, v),
                            axis=0,
                            epsilon=self.eps))
                if self.n_power_iterations > 0:
                    # Clone so the division below does not backpropagate
                    # through the buffers that were updated in place above.
                    u = u.clone()
                    v = v.clone()
        sigma = paddle.dot(u, paddle.mv(weight_mat, v))
        weight = weight / sigma
        return weight
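
    # Re-registers the currently normalized weight as a plain parameter and
    # deletes the `_orig`/`_u`/`_v` bookkeeping attributes.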
    def remove(self, module):
        with paddle.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.add_parameter(self.name, weight.detach())
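
    # Forward pre-hook: recompute `module.<name>` before every forward pass.
    # Power iteration only runs while the module is in training mode.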
    def __call__(self, module, inputs):
        setattr(
            module,
            self.name,
            self.compute_weight(
                module, do_power_iteration=module.training))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with paddle.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)
            h, w = weight_mat.shape
            # randomly initialize u and v
            u = module.create_parameter([h])
            u = normal_(u, 0., 1.)
            v = module.create_parameter([w])
            v = normal_(v, 0., 1.)
            u = F.normalize(u, axis=0, epsilon=fn.eps)
            v = F.normalize(v, axis=0, epsilon=fn.eps)

        # Delete fn.name from parameters, otherwise you cannot set the
        # attribute below.
        del module._parameters[fn.name]
        module.add_parameter(fn.name + "_orig", weight)
        # We still need to assign the weight back as fn.name because all
        # sorts of things may assume that it exists, e.g., when initializing
        # weights. However, we can't assign it directly, as it could be a
        # Parameter and would then be added as a parameter. Instead, we
        # register weight * 1.0 as a plain attribute.
        setattr(module, fn.name, weight * 1.0)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)
        return fn
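

# Public entry point. For transposed convolutions and Linear layers the
# weight's output axis is dim 1 (Paddle stores Linear weights as
# [in_features, out_features]), hence the dim=1 default below.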
def spectral_norm(module,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    if dim is None:
        if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose,
                               nn.Conv3DTranspose, nn.Linear)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module
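

if __name__ == '__main__':
    # Minimal usage sketch (illustrative, not part of the original module):
    # wrap a layer, then run a forward pass; the pre-hook rescales `weight`
    # by 1 / sigma on each call while the layer is in training mode.
    layer = spectral_norm(nn.Conv2D(3, 64, kernel_size=3))
    x = paddle.rand([1, 3, 32, 32])
    y = layer(x)
    print(y.shape)  # [1, 64, 30, 30]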