# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from ppocr.metrics.det_metric import DetMetric


class TableStructureMetric(object):
    """Exact-match accuracy over predicted table-structure token sequences.

    Counts are accumulated across calls and turned into a ratio (and cleared)
    by :meth:`get_metric`.
    """

    # Section wrappers that are dropped from both prediction and target
    # before comparison when ``del_thead_tbody`` is enabled.
    _SECTION_TAGS = ('<thead>', '</thead>', '<tbody>', '</tbody>')

    def __init__(self,
                 main_indicator='acc',
                 eps=1e-6,
                 del_thead_tbody=False,
                 **kwargs):
        """
        Args:
            main_indicator: name of the headline metric; stored as-is.
            eps: small constant added to the denominator in ``get_metric``
                so that an empty accumulator does not divide by zero.
            del_thead_tbody: when true, ``<thead>``/``</thead>``/``<tbody>``/
                ``</tbody>`` are removed from both strings before comparing.
            **kwargs: accepted and ignored.
        """
        self.main_indicator = main_indicator
        self.eps = eps
        self.del_thead_tbody = del_thead_tbody
        self.reset()

    @classmethod
    def _strip_section_tags(cls, text):
        """Return ``text`` with every thead/tbody open/close tag removed."""
        for tag in cls._SECTION_TAGS:
            text = text.replace(tag, '')
        return text

    def __call__(self, pred_label, batch=None, *args, **kwargs):
        """Accumulate exact-match statistics for one batch.

        Args:
            pred_label: a ``(preds, labels)`` pair. Both are mappings with a
                ``'structure_batch_list'`` entry. Every prediction item is a
                2-element pair whose first element is a sequence of string
                tokens (the second element is not used here); every label
                item is a sequence of string tokens.
            batch: unused.
        """
        preds, labels = pred_label
        pairs = zip(preds['structure_batch_list'],
                    labels['structure_batch_list'])
        for (pred_tokens, _), target_tokens in pairs:
            pred_str = ''.join(pred_tokens)
            target_str = ''.join(target_tokens)
            if self.del_thead_tbody:
                pred_str = self._strip_section_tags(pred_str)
                target_str = self._strip_section_tags(target_str)
            if pred_str == target_str:
                self.correct_num += 1
            self.all_num += 1

    def get_metric(self):
        """Return the accumulated accuracy and clear the accumulator.

        Returns:
            ``{'acc': correct_num / (all_num + eps)}``
        """
        acc = 1.0 * self.correct_num / (self.all_num + self.eps)
        self.reset()
        return {'acc': acc}

    def reset(self):
        """Zero all accumulated state."""
        self.correct_num = 0
        self.all_num = 0
        # The attributes below are initialised here but never read or updated
        # anywhere in this class; they are kept so external readers still
        # find them.
        self.len_acc_num = 0
        self.token_nums = 0
        self.anys_dict = dict()


class TableMetric(object):
    """Structure accuracy plus, optionally, a bounding-box ``DetMetric``."""

    def __init__(self,
                 main_indicator='acc',
                 compute_bbox_metric=False,
                 box_format='xyxy',
                 del_thead_tbody=False,
                 **kwargs):
        """
        Args:
            main_indicator: name of the headline metric. In ``get_metric`` it
                is compared with the bbox metric's own ``main_indicator`` to
                decide which result dict keeps its keys un-prefixed.
            compute_bbox_metric: when true, a ``DetMetric`` is created and fed
                the converted boxes on every call.
            box_format: layout of each incoming box; one of ``'xyxy'``,
                ``'xywh'`` or ``'xyxyxyxy'`` (see ``format_box``).
            del_thead_tbody: forwarded to ``TableStructureMetric``.
            **kwargs: accepted and ignored.
        """
        self.structure_metric = TableStructureMetric(
            del_thead_tbody=del_thead_tbody)
        self.bbox_metric = DetMetric() if compute_bbox_metric else None
        self.main_indicator = main_indicator
        self.box_format = box_format
        self.reset()

    def __call__(self, pred_label, batch=None, *args, **kwargs):
        """Update the structure metric and, if enabled, the bbox metric.

        Args:
            pred_label: a ``(preds, labels)`` pair of mappings; see
                ``TableStructureMetric.__call__`` and
                ``prepare_bbox_metric_input`` for the keys that are read.
            batch: unused.
        """
        self.structure_metric(pred_label)
        if self.bbox_metric is not None:
            self.bbox_metric(*self.prepare_bbox_metric_input(pred_label))

    def prepare_bbox_metric_input(self, pred_label):
        """Convert table boxes into the positional arguments for ``DetMetric``.

        Args:
            pred_label: a ``(preds, labels)`` pair of mappings, each with a
                ``'bbox_batch_list'`` entry indexed by sample. The number of
                samples is taken from the predictions.

        Returns:
            A two-element list ``[pred_batch, gt_batch]`` where
            ``pred_batch`` is a list of ``{'points': [...]}`` dicts (one per
            sample) and ``gt_batch`` is ``[0, 0, gt_boxes, gt_ignore_tags]``.
        """
        preds, labels = pred_label
        pred_bbox_batch_list = []
        gt_bbox_batch_list = []
        gt_ignore_tags_batch_list = []
        for batch_idx in range(len(preds['bbox_batch_list'])):
            # pred
            pred_boxes = [
                self.format_box(box)
                for box in preds['bbox_batch_list'][batch_idx]
            ]
            pred_bbox_batch_list.append({'points': pred_boxes})
            # gt: no ground-truth box is ever marked as ignored.
            gt_boxes = [
                self.format_box(box)
                for box in labels['bbox_batch_list'][batch_idx]
            ]
            gt_bbox_batch_list.append(gt_boxes)
            gt_ignore_tags_batch_list.append([0] * len(gt_boxes))
        # The two leading zeros are placeholders; presumably DetMetric only
        # reads positions 2 and 3 of this list -- verify against DetMetric.
        return [
            pred_bbox_batch_list,
            [0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list]
        ]

    def get_metric(self):
        """Collect results from the sub-metrics into one dict.

        Without a bbox metric the structure result is returned unchanged.
        Otherwise, the dict belonging to the main indicator keeps its keys
        and the other one is merged in under a ``structure_metric_`` or
        ``bbox_metric_`` prefix.
        """
        structure_metric = self.structure_metric.get_metric()
        if self.bbox_metric is None:
            return structure_metric
        bbox_metric = self.bbox_metric.get_metric()
        if self.main_indicator == self.bbox_metric.main_indicator:
            output = bbox_metric
            for sub_key in structure_metric:
                output["structure_metric_{}".format(
                    sub_key)] = structure_metric[sub_key]
        else:
            output = structure_metric
            for sub_key in bbox_metric:
                output["bbox_metric_{}".format(sub_key)] = bbox_metric[sub_key]
        return output

    def reset(self):
        """Reset the structure metric and, if present, the bbox metric."""
        self.structure_metric.reset()
        if self.bbox_metric is not None:
            self.bbox_metric.reset()

    def format_box(self, box):
        """Convert ``box`` from ``self.box_format`` to four ``[x, y]`` corners.

        * ``'xyxy'``: ``(x1, y1, x2, y2)`` -> corners in the order
          (x1, y1), (x2, y1), (x2, y2), (x1, y2).
        * ``'xywh'``: ``(x, y, w, h)`` with ``(x, y)`` treated as the centre
          -> the same corner order.
        * ``'xyxyxyxy'``: eight values regrouped into four points.

        Any other ``box_format`` returns ``box`` unchanged.
        """
        if self.box_format == 'xyxy':
            x1, y1, x2, y2 = box
            box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
        elif self.box_format == 'xywh':
            x, y, w, h = box
            # True division: floor division here would shrink boxes with odd
            # or fractional sizes and collapse any box with w < 2 or h < 2
            # (e.g. normalised coordinates) to zero area.
            half_w, half_h = w / 2, h / 2
            x1, y1, x2, y2 = x - half_w, y - half_h, x + half_w, y + half_h
            box = [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
        elif self.box_format == 'xyxyxyxy':
            x1, y1, x2, y2, x3, y3, x4, y4 = box
            box = [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
        return box