# vqa_token_ser_metric.py (1.5 KB)
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import paddle
  19. __all__ = ['KIEMetric']
  20. class VQASerTokenMetric(object):
  21. def __init__(self, main_indicator='hmean', **kwargs):
  22. self.main_indicator = main_indicator
  23. self.reset()
  24. def __call__(self, preds, batch, **kwargs):
  25. preds, labels = preds
  26. self.pred_list.extend(preds)
  27. self.gt_list.extend(labels)
  28. def get_metric(self):
  29. from seqeval.metrics import f1_score, precision_score, recall_score
  30. metrics = {
  31. "precision": precision_score(self.gt_list, self.pred_list),
  32. "recall": recall_score(self.gt_list, self.pred_list),
  33. "hmean": f1_score(self.gt_list, self.pred_list),
  34. }
  35. self.reset()
  36. return metrics
  37. def reset(self):
  38. self.pred_list = []
  39. self.gt_list = []