123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import numpy as np
- import paddle
- from ppocr.utils.utility import load_vqa_bio_label_maps
class VQASerTokenLayoutLMPostProcess(object):
    """Convert between text-label and text-index.

    Turns per-token class scores into label strings, either paired with
    ground-truth labels (when a batch is supplied) or aggregated per OCR
    segment (when it is not).

    Attributes set by ``__init__``:
        id2label_map: id -> label mapping returned by
            ``load_vqa_bio_label_maps``.
        label2id_map_for_draw: label -> id mapping in which every ``I-xxx``
            label is given the id of ``B-xxx``.
        id2label_map_for_show: id -> display name, where a leading ``B-`` or
            ``I-`` prefix is stripped from the label.
    """

    def __init__(self, class_path, **kwargs):
        """Build the lookup tables from the label file.

        Args:
            class_path: Forwarded unchanged to ``load_vqa_bio_label_maps``,
                which must return ``(label2id_map, id2label_map)``.
            **kwargs: Accepted for interface compatibility and ignored.

        Raises:
            KeyError: If an ``I-xxx`` label has no ``B-xxx`` counterpart in
                the loaded label map.
        """
        super().__init__()
        label2id_map, self.id2label_map = load_vqa_bio_label_maps(class_path)

        # An "I-xxx" label shares the id of "B-xxx" so that both parts of an
        # entity collapse to a single class id.
        self.label2id_map_for_draw = {}
        for key in label2id_map:
            source_key = "B" + key[1:] if key.startswith("I-") else key
            self.label2id_map_for_draw[key] = label2id_map[source_key]

        # Display names drop the "B-"/"I-" prefix; every other label
        # (including "O") is shown unchanged.
        self.id2label_map_for_show = {}
        for key, val in self.label2id_map_for_draw.items():
            if key.startswith(("B-", "I-")):
                self.id2label_map_for_show[val] = key[2:]
            else:
                self.id2label_map_for_show[val] = key

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode model output.

        Args:
            preds: Model output. A tuple is reduced to its first element and
                a ``paddle.Tensor`` is converted with ``.numpy()``.
            batch: When not ``None``, ``batch[5]`` is used as the label array
                and the metric path is taken.
            *args: Ignored.
            **kwargs: Forwarded to ``_infer`` when ``batch`` is ``None``; it
                must then provide ``segment_offset_ids`` and ``ocr_infos``.

        Returns:
            The result of ``_metric`` when ``batch`` is given, otherwise the
            result of ``_infer``.
        """
        if isinstance(preds, tuple):
            preds = preds[0]
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()

        if batch is not None:
            return self._metric(preds, batch[5])
        return self._infer(preds, **kwargs)

    def _metric(self, preds, label):
        """Decode predictions and labels at positions whose label is not -100.

        Args:
            preds: Score array; the arg-max is taken over axis 2.
            label: Array indexed as ``label[i, j]`` holding label ids, with
                -100 marking positions to skip.

        Returns:
            ``(decode_out_list, label_decode_out_list)``: two lists with one
            inner list of label strings per sample.
        """
        pred_idxs = preds.argmax(axis=2)
        num_samples, seq_len = pred_idxs.shape[0], pred_idxs.shape[1]
        decode_out_list = [[] for _ in range(num_samples)]
        label_decode_out_list = [[] for _ in range(num_samples)]

        for i in range(num_samples):
            for j in range(seq_len):
                if label[i, j] == -100:
                    continue
                label_decode_out_list[i].append(self.id2label_map[label[i, j]])
                decode_out_list[i].append(self.id2label_map[pred_idxs[i, j]])
        return decode_out_list, label_decode_out_list

    def _infer(self, preds, segment_offset_ids, ocr_infos):
        """Assign each OCR segment its most frequent predicted class.

        Args:
            preds: Iterable of per-sample score arrays; the arg-max is taken
                over axis 1 of each.
            segment_offset_ids: Per sample, the end offset of every segment
                within the token sequence; a segment starts where the
                previous one ended (the first starts at 0).
            ocr_infos: Per sample, a sequence of dicts, one per segment.
                These dicts are updated in place with ``"pred_id"`` and
                ``"pred"``.

        Returns:
            A list holding the (mutated) ``ocr_info`` of every sample.
        """
        results = []
        for pred, segment_offset_id, ocr_info in zip(preds, segment_offset_ids,
                                                     ocr_infos):
            token_labels = [
                self.id2label_map[idx] for idx in np.argmax(pred, axis=1)
            ]

            start_id = 0
            for idx, end_id in enumerate(segment_offset_id):
                curr_pred = [
                    self.label2id_map_for_draw[p]
                    for p in token_labels[start_id:end_id]
                ]
                if curr_pred:
                    # Majority vote over the segment's tokens.
                    pred_id = int(np.argmax(np.bincount(curr_pred)))
                else:
                    # An empty segment falls back to class id 0.
                    pred_id = 0

                ocr_info[idx]["pred_id"] = pred_id
                ocr_info[idx]["pred"] = self.id2label_map_for_show[pred_id]
                start_id = end_id
            results.append(ocr_info)
        return results
class DistillationSerPostProcess(VQASerTokenLayoutLMPostProcess):
    """
    DistillationSerPostProcess

    Runs the parent post-process once for each configured sub-model output
    and returns the results keyed by sub-model name.
    """

    def __init__(self, class_path, model_name=None, key=None, **kwargs):
        """Configure which sub-model outputs are decoded.

        Args:
            class_path: Forwarded to the parent constructor.
            model_name: A name or a list of names used to index ``preds`` in
                ``__call__``. A non-list value is wrapped in a list.
                Defaults to ``["Student"]``; a fresh list is built for every
                instance so instances never share the default object.
            key: Optional key used to index each sub-model output before it
                is decoded; ``None`` uses the output as is.
            **kwargs: Forwarded to the parent constructor.
        """
        super().__init__(class_path, **kwargs)
        if model_name is None:
            model_name = ["Student"]
        elif not isinstance(model_name, list):
            model_name = [model_name]
        self.model_name = model_name
        self.key = key

    def __call__(self, preds, batch=None, *args, **kwargs):
        """Decode the output of every configured sub-model.

        Args:
            preds: Mapping indexed by each name in ``self.model_name``.
            batch: Forwarded to the parent ``__call__``.
            *args: Forwarded to the parent ``__call__``.
            **kwargs: Forwarded to the parent ``__call__``.

        Returns:
            A dict mapping each model name to the parent's result for it.
        """
        output = {}
        for name in self.model_name:
            pred = preds[name]
            if self.key is not None:
                pred = pred[self.key]
            # `batch` goes positionally: passing it as a keyword ahead of
            # *args would bind it twice whenever extra positional arguments
            # are supplied and raise TypeError.
            output[name] = super().__call__(pred, batch, *args, **kwargs)
        return output
|