infer_e2e.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import os
  19. import sys
  20. __dir__ = os.path.dirname(os.path.abspath(__file__))
  21. sys.path.append(__dir__)
  22. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
  23. os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
  24. import cv2
  25. import json
  26. import paddle
  27. from ppocr.data import create_operators, transform
  28. from ppocr.modeling.architectures import build_model
  29. from ppocr.postprocess import build_post_process
  30. from ppocr.utils.save_load import load_model
  31. from ppocr.utils.utility import get_image_file_list
  32. import tools.program as program
  33. from PIL import Image, ImageDraw, ImageFont
  34. import math
  35. def draw_e2e_res_for_chinese(image,
  36. boxes,
  37. txts,
  38. config,
  39. img_name,
  40. font_path="./doc/simfang.ttf"):
  41. h, w = image.height, image.width
  42. img_left = image.copy()
  43. img_right = Image.new('RGB', (w, h), (255, 255, 255))
  44. import random
  45. random.seed(0)
  46. draw_left = ImageDraw.Draw(img_left)
  47. draw_right = ImageDraw.Draw(img_right)
  48. for idx, (box, txt) in enumerate(zip(boxes, txts)):
  49. box = np.array(box)
  50. box = [tuple(x) for x in box]
  51. color = (random.randint(0, 255), random.randint(0, 255),
  52. random.randint(0, 255))
  53. draw_left.polygon(box, fill=color)
  54. draw_right.polygon(box, outline=color)
  55. font = ImageFont.truetype(font_path, 15, encoding="utf-8")
  56. draw_right.text([box[0][0], box[0][1]], txt, fill=(0, 0, 0), font=font)
  57. img_left = Image.blend(image, img_left, 0.5)
  58. img_show = Image.new('RGB', (w * 2, h), (255, 255, 255))
  59. img_show.paste(img_left, (0, 0, w, h))
  60. img_show.paste(img_right, (w, 0, w * 2, h))
  61. save_e2e_path = os.path.dirname(config['Global'][
  62. 'save_res_path']) + "/e2e_results/"
  63. if not os.path.exists(save_e2e_path):
  64. os.makedirs(save_e2e_path)
  65. save_path = os.path.join(save_e2e_path, os.path.basename(img_name))
  66. cv2.imwrite(save_path, np.array(img_show)[:, :, ::-1])
  67. logger.info("The e2e Image saved in {}".format(save_path))
  68. def draw_e2e_res(dt_boxes, strs, config, img, img_name):
  69. if len(dt_boxes) > 0:
  70. src_im = img
  71. for box, str in zip(dt_boxes, strs):
  72. box = box.astype(np.int32).reshape((-1, 1, 2))
  73. cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
  74. cv2.putText(
  75. src_im,
  76. str,
  77. org=(int(box[0, 0, 0]), int(box[0, 0, 1])),
  78. fontFace=cv2.FONT_HERSHEY_COMPLEX,
  79. fontScale=0.7,
  80. color=(0, 255, 0),
  81. thickness=1)
  82. save_det_path = os.path.dirname(config['Global'][
  83. 'save_res_path']) + "/e2e_results/"
  84. if not os.path.exists(save_det_path):
  85. os.makedirs(save_det_path)
  86. save_path = os.path.join(save_det_path, os.path.basename(img_name))
  87. cv2.imwrite(save_path, src_im)
  88. logger.info("The e2e Image saved in {}".format(save_path))
  89. def main():
  90. global_config = config['Global']
  91. # build model
  92. model = build_model(config['Architecture'])
  93. load_model(config, model)
  94. # build post process
  95. post_process_class = build_post_process(config['PostProcess'],
  96. global_config)
  97. # create data ops
  98. transforms = []
  99. for op in config['Eval']['dataset']['transforms']:
  100. op_name = list(op)[0]
  101. if 'Label' in op_name:
  102. continue
  103. elif op_name == 'KeepKeys':
  104. op[op_name]['keep_keys'] = ['image', 'shape']
  105. transforms.append(op)
  106. ops = create_operators(transforms, global_config)
  107. save_res_path = config['Global']['save_res_path']
  108. if not os.path.exists(os.path.dirname(save_res_path)):
  109. os.makedirs(os.path.dirname(save_res_path))
  110. model.eval()
  111. with open(save_res_path, "wb") as fout:
  112. for file in get_image_file_list(config['Global']['infer_img']):
  113. logger.info("infer_img: {}".format(file))
  114. with open(file, 'rb') as f:
  115. img = f.read()
  116. data = {'image': img}
  117. batch = transform(data, ops)
  118. images = np.expand_dims(batch[0], axis=0)
  119. shape_list = np.expand_dims(batch[1], axis=0)
  120. images = paddle.to_tensor(images)
  121. preds = model(images)
  122. post_result = post_process_class(preds, shape_list)
  123. points, strs = post_result['points'], post_result['texts']
  124. # write result
  125. dt_boxes_json = []
  126. for poly, str in zip(points, strs):
  127. tmp_json = {"transcription": str}
  128. tmp_json['points'] = poly.tolist()
  129. dt_boxes_json.append(tmp_json)
  130. otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
  131. fout.write(otstr.encode())
  132. src_img = cv2.imread(file)
  133. if global_config['infer_visual_type'] == 'EN':
  134. draw_e2e_res(points, strs, config, src_img, file)
  135. elif global_config['infer_visual_type'] == 'CN':
  136. src_img = Image.fromarray(
  137. cv2.cvtColor(src_img, cv2.COLOR_BGR2RGB))
  138. draw_e2e_res_for_chinese(
  139. src_img,
  140. points,
  141. strs,
  142. config,
  143. file,
  144. font_path="./doc/fonts/simfang.ttf")
  145. logger.info("success!")
  146. if __name__ == '__main__':
  147. config, device, logger, vdl_writer = program.preprocess()
  148. main()