# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
- import importlib
- import copy
- from .rec_metric import RecMetric
- from .det_metric import DetMetric
- from .e2e_metric import E2EMetric
- from .cls_metric import ClsMetric
- from .vqa_token_ser_metric import VQASerTokenMetric
- from .vqa_token_re_metric import VQAReTokenMetric
class DistillationMetric(object):
    """Evaluate one base metric independently for every branch of a
    multi-branch (distillation) model.

    ``preds`` passed to ``__call__`` must be a dict mapping a branch name to
    that branch's predictions. On the first call, one instance of the metric
    class named ``base_metric_name`` is created per branch key. The class is
    looked up as an attribute of this module.

    Args:
        key: Branch whose results ``get_metric`` reports under their own
            names. Every other branch's results are reported as
            ``"<branch>_<name>"``.
        base_metric_name: Name of the metric class to instantiate. It must be
            an attribute of this module (e.g. one of the classes imported at
            the top of the file).
        main_indicator: Forwarded to each base metric's constructor.
        **kwargs: Extra keyword arguments forwarded to each base metric's
            constructor.
    """

    def __init__(self,
                 key=None,
                 base_metric_name=None,
                 main_indicator=None,
                 **kwargs):
        self.main_indicator = main_indicator
        self.key = key
        self.base_metric_name = base_metric_name
        self.kwargs = kwargs
        # Built lazily on the first __call__, because the branch names are
        # only known once predictions have been seen.
        self.metrics = None

    def _init_metrics(self, preds):
        """Create and reset one base-metric instance per key of ``preds``."""
        self.metrics = dict()
        mod = importlib.import_module(__name__)
        metric_cls = getattr(mod, self.base_metric_name)
        for key in preds:
            self.metrics[key] = metric_cls(
                main_indicator=self.main_indicator, **self.kwargs)
            self.metrics[key].reset()

    # Backward-compatible alias for the original (misspelled) method name.
    _init_metrcis = _init_metrics

    def __call__(self, preds, batch, **kwargs):
        """Feed one batch of per-branch predictions to the branch metrics.

        Args:
            preds: Dict mapping branch name to that branch's predictions.
            batch: Batch data, forwarded unchanged to every base metric.
            **kwargs: Forwarded unchanged to every base metric call.
        """
        assert isinstance(preds, dict)
        if self.metrics is None:
            self._init_metrics(preds)
        for key in preds:
            self.metrics[key](preds[key], batch, **kwargs)

    def get_metric(self):
        """Collect the accumulated results of every branch metric.

        Returns:
            dict: The results of branch ``self.key`` under their original
            names, plus every other branch's results under
            ``"<branch>_<name>"``. The dict is empty if no batch has been
            evaluated yet.
        """
        output = dict()
        if self.metrics is None:
            # Nothing evaluated yet; iterating None would raise TypeError.
            return output
        for key, metric_obj in self.metrics.items():
            metric = metric_obj.get_metric()
            if key == self.key:
                # Main branch: expose its results unprefixed.
                output.update(metric)
            else:
                for sub_key, value in metric.items():
                    output["{}_{}".format(key, sub_key)] = value
        return output

    def reset(self):
        """Reset every branch metric; a no-op before the first __call__."""
        if self.metrics is None:
            return
        for metric_obj in self.metrics.values():
            metric_obj.reset()
|