123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475 |
- import importlib
- import copy
- from .rec_metric import RecMetric
- from .det_metric import DetMetric
- from .e2e_metric import E2EMetric
- from .cls_metric import ClsMetric
- from .vqa_token_ser_metric import VQASerTokenMetric
- from .vqa_token_re_metric import VQAReTokenMetric
class DistillationMetric:
    """Evaluate several sub-model outputs with one shared metric type.

    ``preds`` passed to :meth:`__call__` is a dict mapping a sub-model key to
    that sub-model's predictions. On the first call, one instance of the
    metric class named by ``base_metric_name`` is created per key. Every call
    then feeds each key's predictions to that key's own instance.

    Args:
        key: Sub-model key whose results :meth:`get_metric` reports under
            their plain names. Results for every other key are prefixed
            with ``"<key>_"``.
        base_metric_name: Name of a metric class resolvable as an attribute
            of this module (e.g. one of the metric classes imported at the
            top of this file).
        main_indicator: Forwarded to each base metric's constructor.
        **kwargs: Extra keyword arguments forwarded to each base metric's
            constructor.
    """

    def __init__(self,
                 key=None,
                 base_metric_name=None,
                 main_indicator=None,
                 **kwargs):
        self.key = key
        self.main_indicator = main_indicator
        self.base_metric_name = base_metric_name
        self.kwargs = kwargs
        # Per-key metric instances. They are created lazily on the first
        # __call__ because the set of keys is only known once predictions
        # arrive.
        self.metrics = None

    def _init_metrics(self, preds):
        """Create and reset one base-metric instance per key in ``preds``.

        ``self.metrics`` is assigned only after every instance has been
        built. If the class lookup or a constructor raises, it stays None
        and the next call retries instead of running with a partial dict.
        """
        # Resolve the metric class by name from this module's own namespace.
        mod = importlib.import_module(__name__)
        metric_cls = getattr(mod, self.base_metric_name)
        metrics = {}
        for key in preds:
            metric = metric_cls(main_indicator=self.main_indicator,
                                **self.kwargs)
            metric.reset()
            metrics[key] = metric
        self.metrics = metrics

    # Backward-compatible alias for the original (misspelled) method name.
    _init_metrcis = _init_metrics

    def __call__(self, preds, batch, **kwargs):
        """Feed one batch of predictions to every per-key metric.

        Args:
            preds: Dict mapping sub-model key to that sub-model's predictions.
                A key that was absent from the first call's ``preds`` has no
                metric instance and raises KeyError.
            batch: Forwarded unchanged to each base metric.
            **kwargs: Forwarded unchanged to each base metric.
        """
        assert isinstance(preds, dict)
        if self.metrics is None:
            self._init_metrics(preds)
        for key, pred in preds.items():
            self.metrics[key](pred, batch, **kwargs)

    def get_metric(self):
        """Collect results from every per-key metric into one flat dict.

        Each base metric's ``get_metric()`` dict is merged into the output.
        Entries from ``self.key`` keep their original names. Entries from any
        other key are renamed to ``"<key>_<name>"``. For example, with
        ``key="Student"`` and a base metric returning ``{"acc": ...}``, the
        output holds ``"acc"`` for Student and ``"Teacher_acc"`` for a
        ``"Teacher"`` key.

        Must be called after at least one ``__call__``. Before that no
        per-key metrics exist and iterating ``self.metrics`` (None) raises.
        """
        output = {}
        for key in self.metrics:
            result = self.metrics[key].get_metric()
            if key == self.key:
                output.update(result)
            else:
                for sub_key, value in result.items():
                    output["{}_{}".format(key, sub_key)] = value
        return output

    def reset(self):
        """Reset every per-key metric. A no-op before the first ``__call__``."""
        if self.metrics is None:
            # No per-key metrics have been created yet, so nothing to reset.
            return
        for metric in self.metrics.values():
            metric.reset()
|