sr_metric.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/utils/ssim_psnr.py
  16. """
  17. from math import exp
  18. import paddle
  19. import paddle.nn.functional as F
  20. import paddle.nn as nn
  21. import string
  22. class SSIM(nn.Layer):
  23. def __init__(self, window_size=11, size_average=True):
  24. super(SSIM, self).__init__()
  25. self.window_size = window_size
  26. self.size_average = size_average
  27. self.channel = 1
  28. self.window = self.create_window(window_size, self.channel)
  29. def gaussian(self, window_size, sigma):
  30. gauss = paddle.to_tensor([
  31. exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
  32. for x in range(window_size)
  33. ])
  34. return gauss / gauss.sum()
  35. def create_window(self, window_size, channel):
  36. _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
  37. _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
  38. window = _2D_window.expand([channel, 1, window_size, window_size])
  39. return window
  40. def _ssim(self, img1, img2, window, window_size, channel,
  41. size_average=True):
  42. mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
  43. mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
  44. mu1_sq = mu1.pow(2)
  45. mu2_sq = mu2.pow(2)
  46. mu1_mu2 = mu1 * mu2
  47. sigma1_sq = F.conv2d(
  48. img1 * img1, window, padding=window_size // 2,
  49. groups=channel) - mu1_sq
  50. sigma2_sq = F.conv2d(
  51. img2 * img2, window, padding=window_size // 2,
  52. groups=channel) - mu2_sq
  53. sigma12 = F.conv2d(
  54. img1 * img2, window, padding=window_size // 2,
  55. groups=channel) - mu1_mu2
  56. C1 = 0.01**2
  57. C2 = 0.03**2
  58. ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
  59. (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
  60. if size_average:
  61. return ssim_map.mean()
  62. else:
  63. return ssim_map.mean([1, 2, 3])
  64. def ssim(self, img1, img2, window_size=11, size_average=True):
  65. (_, channel, _, _) = img1.shape
  66. window = self.create_window(window_size, channel)
  67. return self._ssim(img1, img2, window, window_size, channel,
  68. size_average)
  69. def forward(self, img1, img2):
  70. (_, channel, _, _) = img1.shape
  71. if channel == self.channel and self.window.dtype == img1.dtype:
  72. window = self.window
  73. else:
  74. window = self.create_window(self.window_size, channel)
  75. self.window = window
  76. self.channel = channel
  77. return self._ssim(img1, img2, window, self.window_size, channel,
  78. self.size_average)
  79. class SRMetric(object):
  80. def __init__(self, main_indicator='all', **kwargs):
  81. self.main_indicator = main_indicator
  82. self.eps = 1e-5
  83. self.psnr_result = []
  84. self.ssim_result = []
  85. self.calculate_ssim = SSIM()
  86. self.reset()
  87. def reset(self):
  88. self.correct_num = 0
  89. self.all_num = 0
  90. self.norm_edit_dis = 0
  91. self.psnr_result = []
  92. self.ssim_result = []
  93. def calculate_psnr(self, img1, img2):
  94. # img1 and img2 have range [0, 1]
  95. mse = ((img1 * 255 - img2 * 255)**2).mean()
  96. if mse == 0:
  97. return float('inf')
  98. return 20 * paddle.log10(255.0 / paddle.sqrt(mse))
  99. def _normalize_text(self, text):
  100. text = ''.join(
  101. filter(lambda x: x in (string.digits + string.ascii_letters), text))
  102. return text.lower()
  103. def __call__(self, pred_label, *args, **kwargs):
  104. metric = {}
  105. images_sr = pred_label["sr_img"]
  106. images_hr = pred_label["hr_img"]
  107. psnr = self.calculate_psnr(images_sr, images_hr)
  108. ssim = self.calculate_ssim(images_sr, images_hr)
  109. self.psnr_result.append(psnr)
  110. self.ssim_result.append(ssim)
  111. def get_metric(self):
  112. """
  113. return metrics {
  114. 'acc': 0,
  115. 'norm_edit_dis': 0,
  116. }
  117. """
  118. self.psnr_avg = sum(self.psnr_result) / len(self.psnr_result)
  119. self.psnr_avg = round(self.psnr_avg.item(), 6)
  120. self.ssim_avg = sum(self.ssim_result) / len(self.ssim_result)
  121. self.ssim_avg = round(self.ssim_avg.item(), 6)
  122. self.all_avg = self.psnr_avg + self.ssim_avg
  123. self.reset()
  124. return {
  125. 'psnr_avg': self.psnr_avg,
  126. "ssim_avg": self.ssim_avg,
  127. "all": self.all_avg
  128. }