infer_kie.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
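
"""Key information extraction (KIE) inference script.

Loads a trained KIE model (for example the graph-based SDMGR model, whose head
returns node and edge outputs), runs it on every sample listed in the
Global.infer_img annotation file, draws the predicted category of each text box
next to the input image, and writes the predictions to Global.save_res_path.
"""
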
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import os
import sys
import time

import numpy as np

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

# Set the allocator strategy before paddle is imported so the flag takes effect.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import paddle
import paddle.nn.functional as F

from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.utils.save_load import load_model
import tools.program as program
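
# read_class_list builds an {index: class_name} mapping from the label-list
# file referenced by Global.class_path (one class name per line).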
def read_class_list(filepath):
    ret = {}
    with open(filepath, "r") as f:
        lines = f.readlines()
        for idx, line in enumerate(lines):
            ret[idx] = line.strip("\n")
    return ret
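
# draw_kie_result visualizes one sample: batch[6] is the preprocessed image and
# batch[7] holds its text boxes as [x_min, y_min, x_max, y_max]. Each box is
# outlined on the image, its predicted class and softmax score are written on a
# white canvas beside it, and the combined picture is saved under
# <save_res_path dir>/kie_results/<count>.png.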
def draw_kie_result(batch, node, idx_to_cls, count):
    img = batch[6].copy()
    boxes = batch[7]
    h, w = img.shape[:2]
    pred_img = np.ones((h, w * 2, 3), dtype=np.uint8) * 255
    max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
    node_pred_label = max_idx.numpy().tolist()
    node_pred_score = max_value.numpy().tolist()
    for i, box in enumerate(boxes):
        if i >= len(node_pred_label):
            break
        new_box = [[box[0], box[1]], [box[2], box[1]], [box[2], box[3]],
                   [box[0], box[3]]]
        Pts = np.array([new_box], np.int32)
        # outline the text box on the input image
        cv2.polylines(
            img, [Pts.reshape((-1, 1, 2))],
            True,
            color=(255, 255, 0),
            thickness=1)
        x_min = int(min([point[0] for point in new_box]))
        y_min = int(min([point[1] for point in new_box]))
        pred_label = node_pred_label[i]
        if pred_label in idx_to_cls:
            pred_label = idx_to_cls[pred_label]
        pred_score = '{:.2f}'.format(node_pred_score[i])
        # write "<label>(<score>)" on the prediction canvas
        text = '{}({})'.format(pred_label, pred_score)
        cv2.putText(pred_img, text, (x_min * 2, y_min),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 0, 0), 1)
    # put the annotated image and the prediction canvas side by side
    vis_img = np.ones((h, w * 3, 3), dtype=np.uint8) * 255
    vis_img[:, :w] = img
    vis_img[:, w:] = pred_img
    save_kie_path = os.path.dirname(config['Global'][
        'save_res_path']) + "/kie_results/"
    if not os.path.exists(save_kie_path):
        os.makedirs(save_kie_path)
    save_path = os.path.join(save_kie_path, str(count) + ".png")
    cv2.imwrite(save_path, vis_img)
    logger.info("The KIE result image is saved in {}".format(save_path))
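
# A single output line written by write_kie_result is a JSON list sorted by the
# predicted label, e.g. (values are illustrative):
# [{"label": 1, "transcription": "TOTAL", "score": "0.98", "points": [[x1, y1], ...]}, ...]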
def write_kie_result(fout, node, data):
    """
    Write the inference result to the output file, sorted by the predicted
    label of each text line. The format is the same as the input, with an
    additional score attribute.
    """
    annotations = json.loads(data['label'])
    max_value, max_idx = paddle.max(node, -1), paddle.argmax(node, -1)
    node_pred_label = max_idx.numpy().tolist()
    node_pred_score = max_value.numpy().tolist()
    res = []
    for i, pred_label in enumerate(node_pred_label):
        pred_score = '{:.2f}'.format(node_pred_score[i])
        pred_res = {
            'label': pred_label,
            'transcription': annotations[i]['transcription'],
            'score': pred_score,
            'points': annotations[i]['points'],
        }
        res.append(pred_res)
    res.sort(key=lambda x: x['label'])
    fout.writelines([json.dumps(res, ensure_ascii=False) + '\n'])
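
# main wires the pieces together: build the model from config['Architecture'],
# load its weights, apply the Eval dataset transforms to every sample listed in
# Global.infer_img, run the model to get per-box node scores, then visualize
# and write out the predictions while timing the forward passes.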
def main():
    global_config = config['Global']

    # build model
    model = build_model(config['Architecture'])
    load_model(config, model)

    # create data ops
    transforms = []
    for op in config['Eval']['dataset']['transforms']:
        transforms.append(op)

    data_dir = config['Eval']['dataset']['data_dir']
    ops = create_operators(transforms, global_config)

    save_res_path = config['Global']['save_res_path']
    class_path = config['Global']['class_path']
    idx_to_cls = read_class_list(class_path)
    os.makedirs(os.path.dirname(save_res_path), exist_ok=True)

    model.eval()

    warmup_times = 0
    count_t = []
    with open(save_res_path, "w") as fout:
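        # Global.infer_img is an annotation file: each line is
        # "<image path relative to data_dir>\t<JSON label>", where the JSON
        # label is a list of {"transcription": ..., "points": ...} dicts
        # (wildreceipt-style), kept so write_kie_result can copy text and boxes.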
        with open(config['Global']['infer_img'], "rb") as f:
            lines = f.readlines()
            for index, data_line in enumerate(lines):
                if index == 10:
                    warmup_t = time.time()
                data_line = data_line.decode('utf-8')
                substr = data_line.strip("\n").split("\t")
                img_path, label = data_dir + "/" + substr[0], substr[1]
                data = {'img_path': img_path, 'label': label}
                with open(data['img_path'], 'rb') as f:
                    img = f.read()
                    data['image'] = img
                batch = transform(data, ops)
                batch_pred = [0] * len(batch)
                for i in range(len(batch)):
                    batch_pred[i] = paddle.to_tensor(
                        np.expand_dims(
                            batch[i], axis=0))
                # time only the model forward pass and the softmax
                st = time.time()
                node, edge = model(batch_pred)
                node = F.softmax(node, -1)
                count_t.append(time.time() - st)
                draw_kie_result(batch, node, idx_to_cls, index)
                write_kie_result(fout, node, data)
    logger.info("success!")
    logger.info("It took {} s to predict {} images.".format(
        np.sum(count_t), len(count_t)))
    # warmup_times is 0, so the throughput below covers every sample
    ips = len(count_t[warmup_times:]) / np.sum(count_t[warmup_times:])
    logger.info("The ips is {} images/s".format(ips))

if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess()
    main()
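
# Example usage (a sketch; config, checkpoint and annotation paths are
# illustrative and depend on your setup):
#   python3 tools/infer_kie.py -c configs/kie/kie_unet_sdmgr.yml \
#       -o Global.checkpoints=./output/kie/best_accuracy \
#          Global.infer_img=./train_data/wildreceipt/wildreceipt_test.txt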