123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155 |
- # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/utils/ssim_psnr.py
- """
- from math import exp
- import paddle
- import paddle.nn.functional as F
- import paddle.nn as nn
- import string
- class SSIM(nn.Layer):
- def __init__(self, window_size=11, size_average=True):
- super(SSIM, self).__init__()
- self.window_size = window_size
- self.size_average = size_average
- self.channel = 1
- self.window = self.create_window(window_size, self.channel)
- def gaussian(self, window_size, sigma):
- gauss = paddle.to_tensor([
- exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
- for x in range(window_size)
- ])
- return gauss / gauss.sum()
- def create_window(self, window_size, channel):
- _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
- _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
- window = _2D_window.expand([channel, 1, window_size, window_size])
- return window
- def _ssim(self, img1, img2, window, window_size, channel,
- size_average=True):
- mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
- mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
- mu1_sq = mu1.pow(2)
- mu2_sq = mu2.pow(2)
- mu1_mu2 = mu1 * mu2
- sigma1_sq = F.conv2d(
- img1 * img1, window, padding=window_size // 2,
- groups=channel) - mu1_sq
- sigma2_sq = F.conv2d(
- img2 * img2, window, padding=window_size // 2,
- groups=channel) - mu2_sq
- sigma12 = F.conv2d(
- img1 * img2, window, padding=window_size // 2,
- groups=channel) - mu1_mu2
- C1 = 0.01**2
- C2 = 0.03**2
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
- (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
- if size_average:
- return ssim_map.mean()
- else:
- return ssim_map.mean([1, 2, 3])
- def ssim(self, img1, img2, window_size=11, size_average=True):
- (_, channel, _, _) = img1.shape
- window = self.create_window(window_size, channel)
- return self._ssim(img1, img2, window, window_size, channel,
- size_average)
- def forward(self, img1, img2):
- (_, channel, _, _) = img1.shape
- if channel == self.channel and self.window.dtype == img1.dtype:
- window = self.window
- else:
- window = self.create_window(self.window_size, channel)
- self.window = window
- self.channel = channel
- return self._ssim(img1, img2, window, self.window_size, channel,
- self.size_average)
- class SRMetric(object):
- def __init__(self, main_indicator='all', **kwargs):
- self.main_indicator = main_indicator
- self.eps = 1e-5
- self.psnr_result = []
- self.ssim_result = []
- self.calculate_ssim = SSIM()
- self.reset()
- def reset(self):
- self.correct_num = 0
- self.all_num = 0
- self.norm_edit_dis = 0
- self.psnr_result = []
- self.ssim_result = []
- def calculate_psnr(self, img1, img2):
- # img1 and img2 have range [0, 1]
- mse = ((img1 * 255 - img2 * 255)**2).mean()
- if mse == 0:
- return float('inf')
- return 20 * paddle.log10(255.0 / paddle.sqrt(mse))
- def _normalize_text(self, text):
- text = ''.join(
- filter(lambda x: x in (string.digits + string.ascii_letters), text))
- return text.lower()
- def __call__(self, pred_label, *args, **kwargs):
- metric = {}
- images_sr = pred_label["sr_img"]
- images_hr = pred_label["hr_img"]
- psnr = self.calculate_psnr(images_sr, images_hr)
- ssim = self.calculate_ssim(images_sr, images_hr)
- self.psnr_result.append(psnr)
- self.ssim_result.append(ssim)
- def get_metric(self):
- """
- return metrics {
- 'acc': 0,
- 'norm_edit_dis': 0,
- }
- """
- self.psnr_avg = sum(self.psnr_result) / len(self.psnr_result)
- self.psnr_avg = round(self.psnr_avg.item(), 6)
- self.ssim_avg = sum(self.ssim_result) / len(self.ssim_result)
- self.ssim_avg = round(self.ssim_avg.item(), 6)
- self.all_avg = self.psnr_avg + self.ssim_avg
- self.reset()
- return {
- 'psnr_avg': self.psnr_avg,
- "ssim_avg": self.ssim_avg,
- "all": self.all_avg
- }
|