infer_kie_token_ser_re.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import os
  19. import sys
  20. __dir__ = os.path.dirname(os.path.abspath(__file__))
  21. sys.path.append(__dir__)
  22. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
  23. os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
  24. import cv2
  25. import json
  26. import paddle
  27. import paddle.distributed as dist
  28. from ppocr.data import create_operators, transform
  29. from ppocr.modeling.architectures import build_model
  30. from ppocr.postprocess import build_post_process
  31. from ppocr.utils.save_load import load_model
  32. from ppocr.utils.visual import draw_re_results
  33. from ppocr.utils.logging import get_logger
  34. from ppocr.utils.utility import get_image_file_list, load_vqa_bio_label_maps, print_dict
  35. from tools.program import ArgsParser, load_config, merge_config
  36. from tools.infer_kie_token_ser import SerPredictor
  37. class ReArgsParser(ArgsParser):
  38. def __init__(self):
  39. super(ReArgsParser, self).__init__()
  40. self.add_argument(
  41. "-c_ser", "--config_ser", help="ser configuration file to use")
  42. self.add_argument(
  43. "-o_ser",
  44. "--opt_ser",
  45. nargs='+',
  46. help="set ser configuration options ")
  47. def parse_args(self, argv=None):
  48. args = super(ReArgsParser, self).parse_args(argv)
  49. assert args.config_ser is not None, \
  50. "Please specify --config_ser=ser_configure_file_path."
  51. args.opt_ser = self._parse_opt(args.opt_ser)
  52. return args
  53. def make_input(ser_inputs, ser_results):
  54. entities_labels = {'HEADER': 0, 'QUESTION': 1, 'ANSWER': 2}
  55. batch_size, max_seq_len = ser_inputs[0].shape[:2]
  56. entities = ser_inputs[8][0]
  57. ser_results = ser_results[0]
  58. assert len(entities) == len(ser_results)
  59. # entities
  60. start = []
  61. end = []
  62. label = []
  63. entity_idx_dict = {}
  64. for i, (res, entity) in enumerate(zip(ser_results, entities)):
  65. if res['pred'] == 'O':
  66. continue
  67. entity_idx_dict[len(start)] = i
  68. start.append(entity['start'])
  69. end.append(entity['end'])
  70. label.append(entities_labels[res['pred']])
  71. entities = np.full([max_seq_len + 1, 3], fill_value=-1, dtype=np.int64)
  72. entities[0, 0] = len(start)
  73. entities[1:len(start) + 1, 0] = start
  74. entities[0, 1] = len(end)
  75. entities[1:len(end) + 1, 1] = end
  76. entities[0, 2] = len(label)
  77. entities[1:len(label) + 1, 2] = label
  78. # relations
  79. head = []
  80. tail = []
  81. for i in range(len(label)):
  82. for j in range(len(label)):
  83. if label[i] == 1 and label[j] == 2:
  84. head.append(i)
  85. tail.append(j)
  86. relations = np.full([len(head) + 1, 2], fill_value=-1, dtype=np.int64)
  87. relations[0, 0] = len(head)
  88. relations[1:len(head) + 1, 0] = head
  89. relations[0, 1] = len(tail)
  90. relations[1:len(tail) + 1, 1] = tail
  91. entities = np.expand_dims(entities, axis=0)
  92. entities = np.repeat(entities, batch_size, axis=0)
  93. relations = np.expand_dims(relations, axis=0)
  94. relations = np.repeat(relations, batch_size, axis=0)
  95. # remove ocr_info segment_offset_id and label in ser input
  96. if isinstance(ser_inputs[0], paddle.Tensor):
  97. entities = paddle.to_tensor(entities)
  98. relations = paddle.to_tensor(relations)
  99. ser_inputs = ser_inputs[:5] + [entities, relations]
  100. entity_idx_dict_batch = []
  101. for b in range(batch_size):
  102. entity_idx_dict_batch.append(entity_idx_dict)
  103. return ser_inputs, entity_idx_dict_batch
  104. class SerRePredictor(object):
  105. def __init__(self, config, ser_config):
  106. global_config = config['Global']
  107. if "infer_mode" in global_config:
  108. ser_config["Global"]["infer_mode"] = global_config["infer_mode"]
  109. self.ser_engine = SerPredictor(ser_config)
  110. # init re model
  111. # build post process
  112. self.post_process_class = build_post_process(config['PostProcess'],
  113. global_config)
  114. # build model
  115. self.model = build_model(config['Architecture'])
  116. load_model(
  117. config, self.model, model_type=config['Architecture']["model_type"])
  118. self.model.eval()
  119. def __call__(self, data):
  120. ser_results, ser_inputs = self.ser_engine(data)
  121. re_input, entity_idx_dict_batch = make_input(ser_inputs, ser_results)
  122. if self.model.backbone.use_visual_backbone is False:
  123. re_input.pop(4)
  124. preds = self.model(re_input)
  125. post_result = self.post_process_class(
  126. preds,
  127. ser_results=ser_results,
  128. entity_idx_dict_batch=entity_idx_dict_batch)
  129. return post_result
  130. def preprocess():
  131. FLAGS = ReArgsParser().parse_args()
  132. config = load_config(FLAGS.config)
  133. config = merge_config(config, FLAGS.opt)
  134. ser_config = load_config(FLAGS.config_ser)
  135. ser_config = merge_config(ser_config, FLAGS.opt_ser)
  136. logger = get_logger()
  137. # check if set use_gpu=True in paddlepaddle cpu version
  138. use_gpu = config['Global']['use_gpu']
  139. device = 'gpu:{}'.format(dist.ParallelEnv().dev_id) if use_gpu else 'cpu'
  140. device = paddle.set_device(device)
  141. logger.info('{} re config {}'.format('*' * 10, '*' * 10))
  142. print_dict(config, logger)
  143. logger.info('\n')
  144. logger.info('{} ser config {}'.format('*' * 10, '*' * 10))
  145. print_dict(ser_config, logger)
  146. logger.info('train with paddle {} and device {}'.format(paddle.__version__,
  147. device))
  148. return config, ser_config, device, logger
  149. if __name__ == '__main__':
  150. config, ser_config, device, logger = preprocess()
  151. os.makedirs(config['Global']['save_res_path'], exist_ok=True)
  152. ser_re_engine = SerRePredictor(config, ser_config)
  153. if config["Global"].get("infer_mode", None) is False:
  154. data_dir = config['Eval']['dataset']['data_dir']
  155. with open(config['Global']['infer_img'], "rb") as f:
  156. infer_imgs = f.readlines()
  157. else:
  158. infer_imgs = get_image_file_list(config['Global']['infer_img'])
  159. with open(
  160. os.path.join(config['Global']['save_res_path'],
  161. "infer_results.txt"),
  162. "w",
  163. encoding='utf-8') as fout:
  164. for idx, info in enumerate(infer_imgs):
  165. if config["Global"].get("infer_mode", None) is False:
  166. data_line = info.decode('utf-8')
  167. substr = data_line.strip("\n").split("\t")
  168. img_path = os.path.join(data_dir, substr[0])
  169. data = {'img_path': img_path, 'label': substr[1]}
  170. else:
  171. img_path = info
  172. data = {'img_path': img_path}
  173. save_img_path = os.path.join(
  174. config['Global']['save_res_path'],
  175. os.path.splitext(os.path.basename(img_path))[0] + "_ser_re.jpg")
  176. result = ser_re_engine(data)
  177. result = result[0]
  178. fout.write(img_path + "\t" + json.dumps(
  179. result, ensure_ascii=False) + "\n")
  180. img_res = draw_re_results(img_path, result)
  181. cv2.imwrite(save_img_path, img_res)
  182. logger.info("process: [{}/{}], save result to {}".format(
  183. idx, len(infer_imgs), save_img_path))