123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179 |
- # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from rapidfuzz.distance import Levenshtein
- from difflib import SequenceMatcher
- import numpy as np
- import string
class RecMetric(object):
    """Text-recognition metric: exact-match accuracy and 1 - normalized
    Levenshtein distance, accumulated across batches until get_metric().

    Args:
        main_indicator: name of the primary metric key ('acc').
        is_filter: if True, strip non-alphanumeric chars and lower-case
            both prediction and target before comparing.
        ignore_space: if True, drop all spaces before comparing.
    """

    def __init__(self,
                 main_indicator='acc',
                 is_filter=False,
                 ignore_space=True,
                 **kwargs):
        self.main_indicator = main_indicator
        self.is_filter = is_filter
        self.ignore_space = ignore_space
        self.eps = 1e-5  # guards against division by zero on empty batches
        self.reset()

    def _normalize_text(self, text):
        # Keep ASCII letters/digits only, then lower-case.
        allowed = string.digits + string.ascii_letters
        kept = [ch for ch in text if ch in allowed]
        return ''.join(kept).lower()

    def __call__(self, pred_label, *args, **kwargs):
        """Update running counters with one batch; return batch-level metrics."""
        preds, labels = pred_label
        batch_correct = 0
        batch_total = 0
        batch_edit_dis = 0.0
        for (pred, _conf), (target, _tgt_conf) in zip(preds, labels):
            if self.ignore_space:
                pred, target = pred.replace(" ", ""), target.replace(" ", "")
            if self.is_filter:
                pred = self._normalize_text(pred)
                target = self._normalize_text(target)
            batch_edit_dis += Levenshtein.normalized_distance(pred, target)
            batch_correct += int(pred == target)
            batch_total += 1
        self.correct_num += batch_correct
        self.all_num += batch_total
        self.norm_edit_dis += batch_edit_dis
        return {
            'acc': batch_correct / (batch_total + self.eps),
            'norm_edit_dis': 1 - batch_edit_dis / (batch_total + self.eps)
        }

    def get_metric(self):
        """
        return metrics {
            'acc': 0,
            'norm_edit_dis': 0,
        }
        """
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        norm_edit_dis = 1 - self.norm_edit_dis / (self.all_num + self.eps)
        # Counters are cleared so the next epoch starts fresh.
        self.reset()
        return {'acc': acc, 'norm_edit_dis': norm_edit_dis}

    def reset(self):
        # Running accumulators across batches.
        self.correct_num = 0
        self.all_num = 0
        self.norm_edit_dis = 0
class CNTMetric(object):
    """Exact-match accuracy metric: counts predictions equal to their labels,
    accumulated across batches until get_metric()."""

    def __init__(self, main_indicator='acc', **kwargs):
        self.main_indicator = main_indicator
        self.eps = 1e-5  # avoids division by zero when no samples were seen
        self.reset()

    def __call__(self, pred_label, *args, **kwargs):
        """Update running counters with one batch; return batch accuracy."""
        preds, labels = pred_label
        pairs = list(zip(preds, labels))
        hits = sum(p == t for p, t in pairs)
        total = len(pairs)
        self.correct_num += hits
        self.all_num += total
        return {'acc': hits / (total + self.eps), }

    def get_metric(self):
        """
        return metrics {
            'acc': 0,
        }
        """
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        # Clear counters so the next epoch starts fresh.
        self.reset()
        return {'acc': acc}

    def reset(self):
        self.correct_num = 0
        self.all_num = 0
class CANMetric(object):
    """Metric for counting-aware handwritten expression recognition (CAN).

    Tracks two epoch-level statistics:
      * word_rate: average per-symbol match ratio between predicted and
        ground-truth sequences (SequenceMatcher ratio, length-normalized).
      * exp_rate: fraction of expressions whose masked prefix matches exactly.
    """

    def __init__(self, main_indicator='exp_rate', **kwargs):
        self.main_indicator = main_indicator
        # Epoch accumulators: per-batch rates weighted by batch dimensions.
        self.word_right = []
        self.exp_right = []
        self.word_total_length = 0  # running sum of sequence lengths
        self.exp_total_num = 0      # running count of expressions
        self.word_rate = 0
        self.exp_rate = 0
        self.reset()
        self.epoch_reset()

    def __call__(self, preds, batch, **kwargs):
        """Update metrics with one batch.

        preds: class-probability tensor; argmax over axis 2 suggests shape
            (batch, seq_len, num_classes) — assumption, TODO confirm at caller.
        batch: (word_label, word_label_mask) pair; mask sums give the valid
            prefix length per sample.
        kwargs: any truthy value (e.g. epoch_reset=True) clears the
            epoch-level accumulators before processing this batch.
        """
        for k, v in kwargs.items():
            epoch_reset = v
            if epoch_reset:
                self.epoch_reset()
        word_probs = preds
        word_label, word_label_mask = batch
        line_right = 0
        if word_probs is not None:
            word_pred = word_probs.argmax(2)
        # .cpu().detach() implies a torch-like tensor API — TODO confirm.
        word_pred = word_pred.cpu().detach().numpy()
        # Per-sample similarity of the masked (valid) prefix of label vs
        # prediction; the trailing factor renormalizes the symmetric ratio
        # to be relative to the label length only.
        word_scores = [
            SequenceMatcher(
                None,
                s1[:int(np.sum(s3))],
                s2[:int(np.sum(s3))],
                autojunk=False).ratio() * (
                    len(s1[:int(np.sum(s3))]) + len(s2[:int(np.sum(s3))])) /
            len(s1[:int(np.sum(s3))]) / 2
            for s1, s2, s3 in zip(word_label, word_pred, word_label_mask)
        ]
        batch_size = len(word_scores)
        # A score of exactly 1 means the valid prefix matched symbol-for-symbol.
        for i in range(batch_size):
            if word_scores[i] == 1:
                line_right += 1
        self.word_rate = np.mean(word_scores)  # float
        self.exp_rate = line_right / batch_size  # float
        exp_length, word_length = word_label.shape[:2]
        # Weight batch rates by batch dimensions so get_metric() can compute
        # properly weighted epoch averages.
        self.word_right.append(self.word_rate * word_length)
        self.exp_right.append(self.exp_rate * exp_length)
        self.word_total_length = self.word_total_length + word_length
        self.exp_total_num = self.exp_total_num + exp_length

    def get_metric(self):
        """
        return {
            'word_rate': 0,
            "exp_rate": 0,
        }
        """
        # Weighted epoch averages; note epoch accumulators are NOT cleared
        # here — only epoch_reset() clears them.
        cur_word_rate = sum(self.word_right) / self.word_total_length
        cur_exp_rate = sum(self.exp_right) / self.exp_total_num
        self.reset()
        return {'word_rate': cur_word_rate, "exp_rate": cur_exp_rate}

    def reset(self):
        # Clears only the last-batch rates, not the epoch accumulators.
        self.word_rate = 0
        self.exp_rate = 0

    def epoch_reset(self):
        # Clears the epoch-level accumulators.
        self.word_right = []
        self.exp_right = []
        self.word_total_length = 0
        self.exp_total_num = 0
|