# spectral_norm.py
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. def normal_(x, mean=0., std=1.):
  18. temp_value = paddle.normal(mean, std, shape=x.shape)
  19. x.set_value(temp_value)
  20. return x
  21. class SpectralNorm(object):
  22. def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
  23. self.name = name
  24. self.dim = dim
  25. if n_power_iterations <= 0:
  26. raise ValueError('Expected n_power_iterations to be positive, but '
  27. 'got n_power_iterations={}'.format(
  28. n_power_iterations))
  29. self.n_power_iterations = n_power_iterations
  30. self.eps = eps
  31. def reshape_weight_to_matrix(self, weight):
  32. weight_mat = weight
  33. if self.dim != 0:
  34. # transpose dim to front
  35. weight_mat = weight_mat.transpose([
  36. self.dim,
  37. * [d for d in range(weight_mat.dim()) if d != self.dim]
  38. ])
  39. height = weight_mat.shape[0]
  40. return weight_mat.reshape([height, -1])
  41. def compute_weight(self, module, do_power_iteration):
  42. weight = getattr(module, self.name + '_orig')
  43. u = getattr(module, self.name + '_u')
  44. v = getattr(module, self.name + '_v')
  45. weight_mat = self.reshape_weight_to_matrix(weight)
  46. if do_power_iteration:
  47. with paddle.no_grad():
  48. for _ in range(self.n_power_iterations):
  49. v.set_value(
  50. F.normalize(
  51. paddle.matmul(
  52. weight_mat,
  53. u,
  54. transpose_x=True,
  55. transpose_y=False),
  56. axis=0,
  57. epsilon=self.eps, ))
  58. u.set_value(
  59. F.normalize(
  60. paddle.matmul(weight_mat, v),
  61. axis=0,
  62. epsilon=self.eps, ))
  63. if self.n_power_iterations > 0:
  64. u = u.clone()
  65. v = v.clone()
  66. sigma = paddle.dot(u, paddle.mv(weight_mat, v))
  67. weight = weight / sigma
  68. return weight
  69. def remove(self, module):
  70. with paddle.no_grad():
  71. weight = self.compute_weight(module, do_power_iteration=False)
  72. delattr(module, self.name)
  73. delattr(module, self.name + '_u')
  74. delattr(module, self.name + '_v')
  75. delattr(module, self.name + '_orig')
  76. module.add_parameter(self.name, weight.detach())
  77. def __call__(self, module, inputs):
  78. setattr(
  79. module,
  80. self.name,
  81. self.compute_weight(
  82. module, do_power_iteration=module.training))
  83. @staticmethod
  84. def apply(module, name, n_power_iterations, dim, eps):
  85. for k, hook in module._forward_pre_hooks.items():
  86. if isinstance(hook, SpectralNorm) and hook.name == name:
  87. raise RuntimeError(
  88. "Cannot register two spectral_norm hooks on "
  89. "the same parameter {}".format(name))
  90. fn = SpectralNorm(name, n_power_iterations, dim, eps)
  91. weight = module._parameters[name]
  92. with paddle.no_grad():
  93. weight_mat = fn.reshape_weight_to_matrix(weight)
  94. h, w = weight_mat.shape
  95. # randomly initialize u and v
  96. u = module.create_parameter([h])
  97. u = normal_(u, 0., 1.)
  98. v = module.create_parameter([w])
  99. v = normal_(v, 0., 1.)
  100. u = F.normalize(u, axis=0, epsilon=fn.eps)
  101. v = F.normalize(v, axis=0, epsilon=fn.eps)
  102. # delete fn.name form parameters, otherwise you can not set attribute
  103. del module._parameters[fn.name]
  104. module.add_parameter(fn.name + "_orig", weight)
  105. # still need to assign weight back as fn.name because all sorts of
  106. # things may assume that it exists, e.g., when initializing weights.
  107. # However, we can't directly assign as it could be an Parameter and
  108. # gets added as a parameter. Instead, we register weight * 1.0 as a plain
  109. # attribute.
  110. setattr(module, fn.name, weight * 1.0)
  111. module.register_buffer(fn.name + "_u", u)
  112. module.register_buffer(fn.name + "_v", v)
  113. module.register_forward_pre_hook(fn)
  114. return fn
  115. def spectral_norm(module,
  116. name='weight',
  117. n_power_iterations=1,
  118. eps=1e-12,
  119. dim=None):
  120. if dim is None:
  121. if isinstance(module, (nn.Conv1DTranspose, nn.Conv2DTranspose,
  122. nn.Conv3DTranspose, nn.Linear)):
  123. dim = 1
  124. else:
  125. dim = 0
  126. SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
  127. return module