# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import paddle

__all__ = ['VQAReTokenMetric']


class VQAReTokenMetric(object):
    def __init__(self, main_indicator='hmean', **kwargs):
        self.main_indicator = main_indicator
        self.reset()

    def __call__(self, preds, batch, **kwargs):
        # Accumulate per-batch predictions and packed ground-truth arrays;
        # the metric itself is computed lazily in get_metric().
        pred_relations, relations, entities = preds
        self.pred_relations_list.extend(pred_relations)
        self.relations_list.extend(relations)
        self.entities_list.extend(entities)

    def get_metric(self):
        gt_relations = []
        for b in range(len(self.relations_list)):
            rel_sent = []
            relation_list = self.relations_list[b]
            entity_list = self.entities_list[b]
            # Row 0 of each packed array stores the number of valid entries
            # per column; the payload starts at row 1.
            head_len = relation_list[0, 0]
            if head_len > 0:
                entity_start_list = entity_list[1:entity_list[0, 0] + 1, 0]
                entity_end_list = entity_list[1:entity_list[0, 1] + 1, 1]
                entity_label_list = entity_list[1:entity_list[0, 2] + 1, 2]
                for head, tail in zip(relation_list[1:head_len + 1, 0],
                                      relation_list[1:head_len + 1, 1]):
                    rel = {}
                    rel["head_id"] = head
                    rel["head"] = (entity_start_list[head],
                                   entity_end_list[head])
                    rel["head_type"] = entity_label_list[head]
                    rel["tail_id"] = tail
                    rel["tail"] = (entity_start_list[tail],
                                   entity_end_list[tail])
                    rel["tail_type"] = entity_label_list[tail]
                    rel["type"] = 1
                    rel_sent.append(rel)
            gt_relations.append(rel_sent)
        re_metrics = self.re_score(
            self.pred_relations_list, gt_relations, mode="boundaries")
        metrics = {
            "precision": re_metrics["ALL"]["p"],
            "recall": re_metrics["ALL"]["r"],
            "hmean": re_metrics["ALL"]["f1"],
        }
        self.reset()
        return metrics
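
    # Illustration of the packed layout decoded above (a sketch with
    # synthetic values, not part of the metric itself). For one batch
    # element with two entities and one ground-truth relation:
    #
    #   entities  = [[2, 2, 2],    # row 0: valid counts per column
    #                [0, 3, 1],    # entity 0: start=0, end=3, label=1
    #                [5, 9, 2]]    # entity 1: start=5, end=9, label=2
    #   relations = [[1, 1],       # row 0: number of (head, tail) pairs
    #                [0, 1]]       # pair 0: head=entity 0, tail=entity 1
    #
    # get_metric() then rebuilds rel["head"] = (0, 3) and
    # rel["tail"] = (5, 9) for the single gold relation.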

    def reset(self):
        self.pred_relations_list = []
        self.relations_list = []
        self.entities_list = []

    def re_score(self, pred_relations, gt_relations, mode="strict"):
        """Evaluate RE predictions.

        Args:
            pred_relations (list): list of lists of predicted relations
                (several relations per sentence)
            gt_relations (list): list of lists of ground-truth relations,
                where each relation is a dict:
                rel = {"head": (start_idx (inclusive), end_idx (exclusive)),
                       "tail": (start_idx (inclusive), end_idx (exclusive)),
                       "head_type": ent_type,
                       "tail_type": ent_type,
                       "type": rel_type}
            mode (str): "strict" or "boundaries"
        """
        assert mode in ["strict", "boundaries"]

        # Type 0 denotes "no relation", so only type 1 is scored.
        relation_types = [v for v in [0, 1] if not v == 0]
        scores = {
            rel: {"tp": 0, "fp": 0, "fn": 0}
            for rel in relation_types + ["ALL"]
        }

        # Count sentences, GT relations and predicted relations.
        n_sents = len(gt_relations)
        n_rels = sum(len(sent) for sent in gt_relations)
        n_found = sum(len(sent) for sent in pred_relations)

        # Count TP, FP and FN per relation type
        for pred_sent, gt_sent in zip(pred_relations, gt_relations):
            for rel_type in relation_types:
                # strict mode takes argument types into account
                if mode == "strict":
                    pred_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                  rel["tail_type"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["head_type"], rel["tail"],
                                rel["tail_type"])
                               for rel in gt_sent if rel["type"] == rel_type}
                # boundaries mode only takes argument spans into account
                elif mode == "boundaries":
                    pred_rels = {(rel["head"], rel["tail"])
                                 for rel in pred_sent
                                 if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["tail"])
                               for rel in gt_sent if rel["type"] == rel_type}

                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
                scores[rel_type]["fn"] += len(gt_rels - pred_rels)

        # Compute per-type precision / recall / F1
        for rel_type in scores.keys():
            if scores[rel_type]["tp"]:
                scores[rel_type]["p"] = scores[rel_type]["tp"] / (
                    scores[rel_type]["fp"] + scores[rel_type]["tp"])
                scores[rel_type]["r"] = scores[rel_type]["tp"] / (
                    scores[rel_type]["fn"] + scores[rel_type]["tp"])
            else:
                scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0

            if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
                scores[rel_type]["f1"] = (
                    2 * scores[rel_type]["p"] * scores[rel_type]["r"] /
                    (scores[rel_type]["p"] + scores[rel_type]["r"]))
            else:
                scores[rel_type]["f1"] = 0

        # Compute micro precision / recall / F1 over all relation types
        tp = sum(scores[rel_type]["tp"] for rel_type in relation_types)
        fp = sum(scores[rel_type]["fp"] for rel_type in relation_types)
        fn = sum(scores[rel_type]["fn"] for rel_type in relation_types)
        if tp:
            precision = tp / (tp + fp)
            recall = tp / (tp + fn)
            f1 = 2 * precision * recall / (precision + recall)
        else:
            precision, recall, f1 = 0, 0, 0
        scores["ALL"]["p"] = precision
        scores["ALL"]["r"] = recall
        scores["ALL"]["f1"] = f1
        scores["ALL"]["tp"] = tp
        scores["ALL"]["fp"] = fp
        scores["ALL"]["fn"] = fn

        # Compute macro precision / recall / F1 (unweighted mean over types)
        scores["ALL"]["Macro_f1"] = np.mean(
            [scores[rel_type]["f1"] for rel_type in relation_types])
        scores["ALL"]["Macro_p"] = np.mean(
            [scores[rel_type]["p"] for rel_type in relation_types])
        scores["ALL"]["Macro_r"] = np.mean(
            [scores[rel_type]["r"] for rel_type in relation_types])

        return scores
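

# A minimal usage sketch (an illustrative addition, not part of the original
# module): build one batch element in the packed layout shown above and
# score a prediction whose spans match the gold relation exactly.
if __name__ == "__main__":
    metric = VQAReTokenMetric()

    # Two entities and one gold relation, packed as described in get_metric().
    entities = np.array([[2, 2, 2], [0, 3, 1], [5, 9, 2]])
    relations = np.array([[1, 1], [0, 1]])

    # One predicted relation; head/tail spans match the gold relation.
    pred_relations = [[{
        "head_id": 0, "head": (0, 3), "head_type": 1,
        "tail_id": 1, "tail": (5, 9), "tail_type": 2,
        "type": 1,
    }]]

    metric((pred_relations, [relations], [entities]), batch=None)
    # Expect precision = recall = hmean = 1.0 for this exact match.
    print(metric.get_metric())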