vqa_token_re_metric.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

__all__ = ['VQAReTokenMetric']


class VQAReTokenMetric(object):
    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        # Accumulate per-batch predictions and the packed ground-truth
        # relation/entity arrays; metrics are computed later in get_metric.
        pred_relations, relations, entities = preds
        self.pred_relations_list.extend(pred_relations)
        self.relations_list.extend(relations)
        self.entities_list.extend(entities)
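
    # Packed ground-truth layout assumed by get_metric below (inferred from
    # the slicing it performs; shown here for illustration only):
    #   relations[0, 0]  -> number of relations n
    #   relations[1:n+1] -> (head_entity_idx, tail_entity_idx) per relation
    #   entities[0, k]   -> number of entities m, repeated in each column k
    #   entities[1:m+1]  -> (start, end, label) per entity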
    def get_metric(self):
        # Unpack the padded ground-truth arrays into the rel-dict format
        # expected by re_score (see the layout sketch above).
        gt_relations = []
        for b in range(len(self.relations_list)):
            rel_sent = []
            relation_list = self.relations_list[b]
            entitie_list = self.entities_list[b]
            head_len = relation_list[0, 0]
            if head_len > 0:
                entitie_start_list = entitie_list[1:entitie_list[0, 0] + 1, 0]
                entitie_end_list = entitie_list[1:entitie_list[0, 1] + 1, 1]
                entitie_label_list = entitie_list[1:entitie_list[0, 2] + 1, 2]
                for head, tail in zip(relation_list[1:head_len + 1, 0],
                                      relation_list[1:head_len + 1, 1]):
                    rel = {}
                    rel["head_id"] = head
                    rel["head"] = (entitie_start_list[head],
                                   entitie_end_list[head])
                    rel["head_type"] = entitie_label_list[head]
                    rel["tail_id"] = tail
                    rel["tail"] = (entitie_start_list[tail],
                                   entitie_end_list[tail])
                    rel["tail_type"] = entitie_label_list[tail]
                    rel["type"] = 1
                    rel_sent.append(rel)
            gt_relations.append(rel_sent)
        re_metrics = self.re_score(
            self.pred_relations_list, gt_relations, mode="boundaries")
        metrics = {
            "precision": re_metrics["ALL"]["p"],
            "recall": re_metrics["ALL"]["r"],
            "hmean": re_metrics["ALL"]["f1"],
        }
        # The accumulated state is consumed once per evaluation pass.
        self.reset()
        return metrics

    def reset(self):
        self.pred_relations_list = []
        self.relations_list = []
        self.entities_list = []

    def re_score(self, pred_relations, gt_relations, mode="strict"):
        """Evaluate RE predictions.

        Args:
            pred_relations (list): list of lists of predicted relations
                (several relations per sentence)
            gt_relations (list): list of lists of ground-truth relations,
                where each relation is a dict:
                    rel = {"head": (start_idx (inclusive), end_idx (exclusive)),
                           "tail": (start_idx (inclusive), end_idx (exclusive)),
                           "head_type": ent_type,
                           "tail_type": ent_type,
                           "type": rel_type}
            mode (str): either 'strict' or 'boundaries'
        """
        assert mode in ["strict", "boundaries"]
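
        # Illustrative relation dict in the documented format (the values
        # are made-up examples, not real fixtures):
        #   {"head": (0, 3), "head_type": 1,
        #    "tail": (5, 8), "tail_type": 2, "type": 1}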

        # Only relation type 1 is scored; type 0 ("no relation") is excluded.
        relation_types = [1]
        scores = {
            rel: {"tp": 0, "fp": 0, "fn": 0}
            for rel in relation_types + ["ALL"]
        }

        # Count GT relations and predicted relations
        n_sents = len(gt_relations)
        n_rels = sum(len(sent) for sent in gt_relations)
        n_found = sum(len(sent) for sent in pred_relations)

        # Count TP, FP and FN per type
        for pred_sent, gt_sent in zip(pred_relations, gt_relations):
            for rel_type in relation_types:
                # strict mode takes argument types into account
                if mode == "strict":
                    pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                  rel["tail_type"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                rel["tail_type"])
                               for rel in gt_sent if rel["type"] == rel_type}
                # boundaries mode only takes argument spans into account
                elif mode == "boundaries":
                    pred_rels = {(rel["head"], rel["tail"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["tail"])
                               for rel in gt_sent if rel["type"] == rel_type}
                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
                scores[rel_type]["fn"] += len(gt_rels - pred_rels)
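
        # Worked example (boundaries mode, hypothetical spans): with
        #   pred_rels = {((0, 3), (5, 8)), ((0, 3), (9, 12))}
        #   gt_rels   = {((0, 3), (5, 8))}
        # the set arithmetic above yields tp=1, fp=1, fn=0.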

        # Compute per-type precision / recall / F1
        for rel_type in scores.keys():
            if scores[rel_type]["tp"]:
                scores[rel_type]["p"] = scores[rel_type]["tp"] / (
                    scores[rel_type]["fp"] + scores[rel_type]["tp"])
                scores[rel_type]["r"] = scores[rel_type]["tp"] / (
                    scores[rel_type]["fn"] + scores[rel_type]["tp"])
            else:
                scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0

            if scores[rel_type]["p"] + scores[rel_type]["r"] > 0:
                scores[rel_type]["f1"] = (
                    2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
                    (scores[rel_type]["p"] + scores[rel_type]["r"]))
            else:
                scores[rel_type]["f1"] = 0

        # Compute micro F1 scores (tp/fp/fn pooled over all relation types)
        tp = sum(scores[rel_type]["tp"] for rel_type in relation_types)
        fp = sum(scores[rel_type]["fp"] for rel_type in relation_types)
        fn = sum(scores[rel_type]["fn"] for rel_type in relation_types)
        if tp:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1 = 2 * precision * recall / (precision + recall)
        else:
            precision, recall, f1 = 0, 0, 0
        scores["ALL"]["p"] = precision
        scores["ALL"]["r"] = recall
        scores["ALL"]["f1"] = f1
        scores["ALL"]["tp"] = tp
        scores["ALL"]["fp"] = fp
        scores["ALL"]["fn"] = fn

        # Compute macro F1 scores (unweighted mean of the per-type figures)
        scores["ALL"]["Macro_f1"] = np.mean(
            [scores[ent_type]["f1"] for ent_type in relation_types])
        scores["ALL"]["Macro_p"] = np.mean(
            [scores[ent_type]["p"] for ent_type in relation_types])
        scores["ALL"]["Macro_r"] = np.mean(
            [scores[ent_type]["r"] for ent_type in relation_types])
        return scores
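

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It fabricates one
# sample in the packed-array layout sketched above and runs the metric end
# to end; every value below is an illustrative assumption, not a PaddleOCR
# fixture.
if __name__ == "__main__":
    # Two entities: entity 0 spans tokens [0, 3), entity 1 spans [5, 8).
    entities = np.array([
        [2, 2, 2],  # header row: entity count repeated per column
        [0, 3, 1],  # entity 0: (start, end, label)
        [5, 8, 2],  # entity 1: (start, end, label)
    ])
    # One ground-truth relation: head entity 0 -> tail entity 1.
    relations = np.array([
        [1, 1],  # header row: relation count
        [0, 1],  # relation 0: (head_idx, tail_idx)
    ])
    # A matching prediction in the rel-dict format consumed by re_score.
    pred_relations = [[{
        "head_id": 0, "head": (0, 3), "head_type": 1,
        "tail_id": 1, "tail": (5, 8), "tail_type": 2,
        "type": 1,
    }]]

    metric = VQAReTokenMetric()
    metric((pred_relations, [relations], [entities]), batch=None)
    # A perfect match should give precision == recall == hmean == 1.0.
    print(metric.get_metric())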