123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import numpy as np
- from ppocr.metrics.det_metric import DetMetric
class TableStructureMetric(object):
    """Exact-match accuracy of predicted table structure token sequences.

    A sample counts as correct only when the concatenation of its predicted
    structure tokens equals the concatenation of its ground-truth tokens.
    Counts accumulate across calls until ``get_metric`` reads and resets them.
    """

    def __init__(self,
                 main_indicator='acc',
                 eps=1e-6,
                 del_thead_tbody=False,
                 **kwargs):
        """
        @param main_indicator: name of the metric key this object reports as
            its main indicator (stored only).
        @param eps: added to the sample count in the accuracy denominator so an
            empty evaluation does not divide by zero.
        @param del_thead_tbody: when truthy, remove <thead>, </thead>, <tbody>
            and </tbody> from both strings before comparing them.
        @param kwargs: extra config keys; accepted and ignored.
        """
        self.main_indicator = main_indicator
        self.eps = eps
        self.del_thead_tbody = del_thead_tbody
        self.reset()

    def _join_structure(self, tokens):
        """Concatenate structure tokens, optionally dropping section wrapper tags."""
        text = ''.join(tokens)
        if not self.del_thead_tbody:
            return text
        # Applied one after another in this fixed order, as plain substring removal.
        for tag in ('<thead>', '</thead>', '<tbody>', '</tbody>'):
            text = text.replace(tag, '')
        return text

    def __call__(self, pred_label, batch=None, *args, **kwargs):
        """Accumulate exact-match counts for one batch.

        @param pred_label: ``(preds, labels)``. ``preds['structure_batch_list']``
            holds ``(tokens, confidence)`` pairs (confidence is unused) and
            ``labels['structure_batch_list']`` holds token sequences. The two
            lists are paired with ``zip``, so extra items on either side are
            ignored.
        @param batch: unused.
        """
        preds, labels = pred_label
        predicted = preds['structure_batch_list']
        expected = labels['structure_batch_list']
        # Tally locally so a failure mid-batch leaves the running totals untouched.
        matches = 0
        seen = 0
        for (pred_tokens, _), gt_tokens in zip(predicted, expected):
            if self._join_structure(pred_tokens) == self._join_structure(gt_tokens):
                matches += 1
            seen += 1
        self.correct_num += matches
        self.all_num += seen

    def get_metric(self):
        """Return ``{'acc': correct / (total + eps)}`` and reset the counters.

        return metrics {
            'acc': 0,
        }
        """
        accuracy = float(self.correct_num) / (self.all_num + self.eps)
        self.reset()
        return {'acc': accuracy}

    def reset(self):
        """Zero every accumulator."""
        self.correct_num = 0
        self.all_num = 0
        # The next three are initialised here but not read anywhere in this class.
        self.len_acc_num = 0
        self.token_nums = 0
        self.anys_dict = {}
class TableMetric(object):
    """Table recognition metric: structure accuracy plus an optional box metric.

    Structure accuracy always comes from ``TableStructureMetric``. When
    ``compute_bbox_metric`` is set, predicted and ground-truth cell boxes are
    converted to four-corner polygons (see ``format_box``) and scored with
    ``DetMetric``.
    """

    def __init__(self,
                 main_indicator='acc',
                 compute_bbox_metric=False,
                 box_format='xyxy',
                 del_thead_tbody=False,
                 **kwargs):
        """
        @param main_indicator: metric key used to select the best model. If it
            equals the DetMetric's ``main_indicator``, the box metrics form the
            top level of ``get_metric``'s result. Otherwise the structure
            metrics do.
        @param compute_bbox_metric: also evaluate cell boxes with DetMetric.
        @param box_format: layout of each input box: 'xyxy' (two opposite
            corners), 'xywh' (center x, center y, width, height) or 'xyxyxyxy'
            (four corner points).
        @param del_thead_tbody: forwarded to TableStructureMetric; strips
            <thead>/<tbody> tags before comparing structures.
        @param kwargs: extra config keys; accepted and ignored.
        """
        self.structure_metric = TableStructureMetric(
            del_thead_tbody=del_thead_tbody)
        self.bbox_metric = DetMetric() if compute_bbox_metric else None
        self.main_indicator = main_indicator
        self.box_format = box_format
        self.reset()

    def __call__(self, pred_label, batch=None, *args, **kwargs):
        """Accumulate metrics for one batch.

        @param pred_label: ``(preds, labels)`` dicts. Both carry
            'structure_batch_list'. Both also carry 'bbox_batch_list' when the
            box metric is enabled.
        @param batch: unused.
        """
        self.structure_metric(pred_label)
        if self.bbox_metric is not None:
            self.bbox_metric(*self.prepare_bbox_metric_input(pred_label))

    def prepare_bbox_metric_input(self, pred_label):
        """Build the ``(preds, batch)`` arguments passed to DetMetric.

        @param pred_label: ``(preds, labels)``. ``preds['bbox_batch_list']`` and
            ``labels['bbox_batch_list']`` are per-image lists of boxes in
            ``self.box_format``. Images are indexed by the prediction list.
        @return: ``[pred_batch, [0, 0, gt_boxes_batch, gt_ignore_tags_batch]]``.
            ``pred_batch`` is a list of ``{'points': [polygon, ...]}`` dicts. The
            two leading zeros fill batch slots DetMetric presumably does not
            read here (assumption; verify against DetMetric).
        """
        preds, labels = pred_label
        gt_batches = labels['bbox_batch_list']
        pred_bbox_batch_list = []
        gt_bbox_batch_list = []
        gt_ignore_tags_batch_list = []
        for batch_idx, pred_boxes in enumerate(preds['bbox_batch_list']):
            pred_bbox_batch_list.append(
                {'points': [self.format_box(box) for box in pred_boxes]})
            gt_bbox_list = [
                self.format_box(box) for box in gt_batches[batch_idx]
            ]
            gt_bbox_batch_list.append(gt_bbox_list)
            # Every ground-truth box gets ignore tag 0; presumably DetMetric
            # reads 0 as "not ignored".
            gt_ignore_tags_batch_list.append([0] * len(gt_bbox_list))
        return [
            pred_bbox_batch_list,
            [0, 0, gt_bbox_batch_list, gt_ignore_tags_batch_list]
        ]

    def get_metric(self):
        """Return the accumulated metrics.

        Without a box metric this is just the structure metric dict. With one,
        the dict for ``main_indicator`` is the top level. The other metric's
        keys are merged in under a 'structure_metric_' or 'bbox_metric_'
        prefix. TableStructureMetric resets its counters when read.
        """
        structure_metric = self.structure_metric.get_metric()
        if self.bbox_metric is None:
            return structure_metric
        bbox_metric = self.bbox_metric.get_metric()
        if self.main_indicator == self.bbox_metric.main_indicator:
            output = bbox_metric
            for sub_key in structure_metric:
                output["structure_metric_{}".format(
                    sub_key)] = structure_metric[sub_key]
        else:
            output = structure_metric
            for sub_key in bbox_metric:
                output["bbox_metric_{}".format(sub_key)] = bbox_metric[sub_key]
        return output

    def reset(self):
        """Reset the structure metric and, if enabled, the box metric."""
        self.structure_metric.reset()
        if self.bbox_metric is not None:
            self.bbox_metric.reset()

    def format_box(self, box):
        """Convert one box in ``self.box_format`` to four [x, y] corner points.

        'xyxy' and 'xywh' produce an axis-aligned rectangle ordered
        top-left, top-right, bottom-right, bottom-left (in image coordinates).
        'xyxyxyxy' keeps the given corner order. Boxes in any other format are
        returned unchanged.
        """
        if self.box_format == 'xyxy':
            x1, y1, x2, y2 = box
            return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
        if self.box_format == 'xywh':
            x, y, w, h = box
            # True division. Floor division (//) shrank odd integer sizes and
            # collapsed fractional sizes (e.g. w=0.3 -> 0) into degenerate boxes.
            half_w, half_h = w / 2, h / 2
            x1, y1 = x - half_w, y - half_h
            x2, y2 = x + half_w, y + half_h
            return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]
        if self.box_format == 'xyxyxyxy':
            x1, y1, x2, y2, x3, y3, x4, y4 = box
            return [[x1, y1], [x2, y2], [x3, y3], [x4, y4]]
        # NOTE(review): unknown formats pass through untouched (original
        # behavior); DetMetric will only cope if the box is already a polygon.
        return box
|