distillation_metric.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import importlib
  15. import copy
  16. from .rec_metric import RecMetric
  17. from .det_metric import DetMetric
  18. from .e2e_metric import E2EMetric
  19. from .cls_metric import ClsMetric
  20. from .vqa_token_ser_metric import VQASerTokenMetric
  21. from .vqa_token_re_metric import VQAReTokenMetric
  22. class DistillationMetric(object):
  23. def __init__(self,
  24. key=None,
  25. base_metric_name=None,
  26. main_indicator=None,
  27. **kwargs):
  28. self.main_indicator = main_indicator
  29. self.key = key
  30. self.main_indicator = main_indicator
  31. self.base_metric_name = base_metric_name
  32. self.kwargs = kwargs
  33. self.metrics = None
  34. def _init_metrcis(self, preds):
  35. self.metrics = dict()
  36. mod = importlib.import_module(__name__)
  37. for key in preds:
  38. self.metrics[key] = getattr(mod, self.base_metric_name)(
  39. main_indicator=self.main_indicator, **self.kwargs)
  40. self.metrics[key].reset()
  41. def __call__(self, preds, batch, **kwargs):
  42. assert isinstance(preds, dict)
  43. if self.metrics is None:
  44. self._init_metrcis(preds)
  45. output = dict()
  46. for key in preds:
  47. self.metrics[key].__call__(preds[key], batch, **kwargs)
  48. def get_metric(self):
  49. """
  50. return metrics {
  51. 'acc': 0,
  52. 'norm_edit_dis': 0,
  53. }
  54. """
  55. output = dict()
  56. for key in self.metrics:
  57. metric = self.metrics[key].get_metric()
  58. # main indicator
  59. if key == self.key:
  60. output.update(metric)
  61. else:
  62. for sub_key in metric:
  63. output["{}_{}".format(key, sub_key)] = metric[sub_key]
  64. return output
  65. def reset(self):
  66. for key in self.metrics:
  67. self.metrics[key].reset()