123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import numpy as np
- import os
- import sys
- __dir__ = os.path.dirname(os.path.abspath(__file__))
- sys.path.append(__dir__)
- sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
- os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
- import cv2
- import json
- import paddle
- import paddle.distributed as dist
- from ppocr.data import create_operators, transform
- from ppocr.modeling.architectures import build_model
- from ppocr.postprocess import build_post_process
- from ppocr.utils.save_load import load_model
- from ppocr.utils.visual import draw_re_results
- from ppocr.utils.logging import get_logger
- from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
- from tools.program import ArgsParser, load_config, merge_config
- from tools.infer_kie_token_ser import SerPredictor
class ReArgsParser(ArgsParser):
    """Argument parser for the joint SER + RE inference script.

    On top of the options registered by the base ``ArgsParser`` it adds a
    second configuration file (``--config_ser``) and an override list
    (``--opt_ser``) for the SER model whose output feeds the RE model.
    """

    def __init__(self):
        super(ReArgsParser, self).__init__()
        ser_config_help = "ser configuration file to use"
        ser_opt_help = "set ser configuration options "
        # Effectively required: parse_args() asserts that it was supplied.
        self.add_argument("-c_ser", "--config_ser", help=ser_config_help)
        # One or more override items, post-processed by the inherited
        # _parse_opt() in parse_args().
        self.add_argument("-o_ser", "--opt_ser", nargs='+', help=ser_opt_help)

    def parse_args(self, argv=None):
        """Parse ``argv`` and post-process the SER-specific options.

        Args:
            argv: argument list to parse; ``None`` is forwarded to the base
                parser, which presumably falls back to ``sys.argv[1:]``.

        Returns:
            The parsed namespace, with ``opt_ser`` replaced by the result of
            ``self._parse_opt(opt_ser)``.

        Raises:
            AssertionError: if ``--config_ser`` was not given.
        """
        namespace = super(ReArgsParser, self).parse_args(argv)
        assert namespace.config_ser is not None, (
            "Please specify --config_ser=ser_configure_file_path.")
        namespace.opt_ser = self._parse_opt(namespace.opt_ser)
        return namespace
def make_input(ser_inputs, ser_results):
    """Convert SER model inputs/outputs into the input list of the RE model.

    Only the first sample of the batch is inspected (``ser_inputs[8][0]`` and
    ``ser_results[0]``); the entity and relation arrays built from it are
    then repeated ``batch_size`` times along a new leading axis.

    Args:
        ser_inputs: sequence of SER inputs. ``ser_inputs[0].shape[:2]`` gives
            ``(batch_size, max_seq_len)``, ``ser_inputs[8]`` holds per-sample
            entity lists whose items are dicts with ``'start'``/``'end'``,
            and ``ser_inputs[:5]`` are carried over into the RE input.
        ser_results: per-sample SER predictions; each item is a dict with a
            ``'pred'`` label string, ``'O'`` meaning "not an entity".

    Returns:
        ``(re_inputs, entity_idx_dict_batch)``:
          * ``re_inputs`` -- a new list ``ser_inputs[:5] + [entities,
            relations]``. The two arrays are converted with
            ``paddle.to_tensor`` when ``ser_inputs[0]`` is a
            ``paddle.Tensor``.
          * ``entity_idx_dict_batch`` -- one dict per batch item mapping an
            RE entity index to the index of its SER result.

        Array layout (before the batch axis is added):
          * ``entities`` has shape ``(max_seq_len + 1, 3)`` with columns
            ``start``, ``end`` and ``label``. Row 0 holds the entity count
            in every column, rows ``1..n`` hold the values, and the rest is
            ``-1`` padding.
          * ``relations`` has shape ``(num_pairs + 1, 2)`` with columns
            ``head`` and ``tail`` (QUESTION index, ANSWER index). Row 0
            holds the pair count.

    Raises:
        AssertionError: if the entity and SER-result counts differ.
        KeyError: if a non-``'O'`` prediction is not HEADER/QUESTION/ANSWER.
    """
    entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
    batch_size, max_seq_len = ser_inputs[0].shape[:2]
    entities = ser_inputs[8][0]
    ser_results = ser_results[0]
    assert len(entities) == len(ser_results)

    # Entities: keep only tokens tagged with an entity label and remember
    # which SER result each kept entity came from.
    start = []
    end = []
    label = []
    entity_idx_dict = {}
    for i, (res, entity) in enumerate(zip(ser_results, entities)):
        if res['pred'] == 'O':
            continue
        entity_idx_dict[len(start)] = i
        start.append(entity['start'])
        end.append(entity['end'])
        label.append(entities_labels[res['pred']])

    num_entities = len(start)
    entities = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
    entities[0, :] = num_entities  # same count for start/end/label columns
    entities[1:num_entities + 1, 0] = start
    entities[1:num_entities + 1, 1] = end
    entities[1:num_entities + 1, 2] = label

    # Relations: every (QUESTION, ANSWER) pair is a candidate link. The
    # nested comprehension yields pairs ordered by question index, then
    # answer index.
    question_ids = [
        i for i, lbl in enumerate(label)
        if lbl == entities_labels['QUESTION']
    ]
    answer_ids = [
        i for i, lbl in enumerate(label) if lbl == entities_labels['ANSWER']
    ]
    pairs = [(q, a) for q in question_ids for a in answer_ids]
    num_pairs = len(pairs)
    relations = np.full([num_pairs + 1, 2], fill_value=-1, dtype=np.int64)
    relations[0, :] = num_pairs  # same count for head/tail columns
    if pairs:
        # Guarded: an empty list cannot be broadcast into a (0, 2) slice.
        relations[1:, :] = pairs

    # Same entity/relation arrays for every batch item.
    entities = np.repeat(np.expand_dims(entities, axis=0), batch_size, axis=0)
    relations = np.repeat(
        np.expand_dims(relations, axis=0), batch_size, axis=0)

    if isinstance(ser_inputs[0], paddle.Tensor):
        entities = paddle.to_tensor(entities)
        relations = paddle.to_tensor(relations)

    # Keep the first five SER inputs and drop the rest. The original note
    # says these are ocr_info, segment_offset_id and label. list() accepts
    # tuples too and always returns a fresh list the caller may mutate.
    re_inputs = list(ser_inputs[:5]) + [entities, relations]

    # A separate copy per batch item, so a mutation by a consumer of one
    # item cannot leak into the others.
    entity_idx_dict_batch = [dict(entity_idx_dict) for _ in range(batch_size)]
    return re_inputs, entity_idx_dict_batch
class SerRePredictor(object):
    """Two-stage KIE predictor: SER entity tagging followed by RE linking.

    The SER stage is delegated to ``SerPredictor``. Its inputs and results
    are converted by ``make_input`` into the entity/relation inputs of the
    RE model built here.
    """

    def __init__(self, config, ser_config):
        """Build the SER engine, the RE post-processor and the RE model.

        Args:
            config: RE configuration; its ``Global``, ``PostProcess`` and
                ``Architecture`` sections are read.
            ser_config: SER configuration passed to ``SerPredictor``.
                NOTE: mutated in place -- ``Global.infer_mode`` is overwritten
                with the RE config's value when the RE config sets one.
        """
        global_config = config['Global']

        # Keep the SER stage's infer_mode in sync with the RE config.
        if "infer_mode" in global_config:
            ser_config["Global"]["infer_mode"] = global_config["infer_mode"]

        self.ser_engine = SerPredictor(ser_config)

        # init re model
        # build post process
        self.post_process_class = build_post_process(config['PostProcess'],
                                                     global_config)
        # build model
        self.model = build_model(config['Architecture'])
        load_model(
            config, self.model, model_type=config['Architecture']["model_type"])
        self.model.eval()

    def __call__(self, data):
        """Run SER and then RE on one input sample.

        Args:
            data: input dict forwarded unchanged to the SER engine (e.g.
                ``{'img_path': ...}`` plus an optional ``'label'``).

        Returns:
            Whatever the RE post-processor returns for the batch.
        """
        # Pure inference: disable autograd so the SER and RE forward passes
        # do not build a backward graph or keep intermediate activations for
        # gradients.
        with paddle.no_grad():
            ser_results, ser_inputs = self.ser_engine(data)
            re_input, entity_idx_dict_batch = make_input(ser_inputs,
                                                         ser_results)
            if self.model.backbone.use_visual_backbone is False:
                # Without a visual backbone, input 4 (presumably the image
                # tensor -- confirm against the SER input layout) is unused.
                re_input.pop(4)
            preds = self.model(re_input)
            post_result = self.post_process_class(
                preds,
                ser_results=ser_results,
                entity_idx_dict_batch=entity_idx_dict_batch)
        return post_result
def preprocess():
    """Parse the command line, load both configs and select the device.

    Returns:
        ``(config, ser_config, device, logger)``:
          * the RE config merged with ``--opt`` overrides,
          * the SER config merged with ``--opt_ser`` overrides,
          * the value returned by ``paddle.set_device``,
          * the logger from ``get_logger()``.
    """
    FLAGS = ReArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    ser_config = load_config(FLAGS.config_ser)
    ser_config = merge_config(ser_config, FLAGS.opt_ser)

    logger = get_logger()

    # Pick the device from the RE config's Global.use_gpu. No check is made
    # here that this Paddle build supports GPU; any incompatibility is left
    # to paddle.set_device to report.
    use_gpu = config['Global']['use_gpu']
    device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
    device = paddle.set_device(device)

    logger.info('{} re config {}'.format('*' * 10, '*' * 10))
    print_dict(config, logger)
    logger.info('\n')
    logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
    print_dict(ser_config, logger)
    # This script only runs inference, so do not report "train".
    logger.info('infer with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, ser_config, device, logger
if __name__ == '__main__':
    config, ser_config, device, logger = preprocess()
    os.makedirs(config['Global']['save_res_path'], exist_ok=True)

    ser_re_engine = SerRePredictor(config, ser_config)

    # infer_mode explicitly False: Global.infer_img is a label file whose
    # lines are "<image path relative to Eval data_dir>\t<label>".
    # Otherwise Global.infer_img is handed to get_image_file_list.
    from_label_file = config["Global"].get("infer_mode", None) is False
    if from_label_file:
        data_dir = config['Eval']['dataset']['data_dir']
        with open(config['Global']['infer_img'], "rb") as f:
            # Blank lines carry no tab-separated label and would otherwise
            # crash on substr[1] below.
            infer_imgs = [line for line in f.readlines() if line.strip()]
    else:
        infer_imgs = get_image_file_list(config['Global']['infer_img'])

    with open(
            os.path.join(config['Global']['save_res_path'],
                         "infer_results.txt"),
            "w",
            encoding='utf-8') as fout:
        for idx, info in enumerate(infer_imgs):
            if from_label_file:
                data_line = info.decode('utf-8')
                # rstrip both CR and LF so CRLF files do not leave a stray
                # "\r" at the end of the label.
                substr = data_line.rstrip("\r\n").split("\t")
                img_path = os.path.join(data_dir, substr[0])
                data = {'img_path': img_path, 'label': substr[1]}
            else:
                img_path = info
                data = {'img_path': img_path}

            save_img_path = os.path.join(
                config['Global']['save_res_path'],
                os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")

            # The predictor returns a per-batch list; one sample per call.
            result = ser_re_engine(data)[0]
            fout.write(img_path + "\t" + json.dumps(
                result, ensure_ascii=False) + "\n")

            img_res = draw_re_results(img_path, result)
            cv2.imwrite(save_img_path, img_res)
            # 1-based so the last message reads [N/N].
            logger.info("process: [{}/{}], save result to {}".format(
                idx + 1, len(infer_imgs), save_img_path))
|