visual.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import cv2
  15. import os
  16. import numpy as np
  17. from PIL import Image, ImageDraw, ImageFont
  18. def draw_ser_results(image,
  19. ocr_results,
  20. font_path="doc/fonts/simfang.ttf",
  21. font_size=14):
  22. np.random.seed(2021)
  23. color = (np.random.permutation(range(255)),
  24. np.random.permutation(range(255)),
  25. np.random.permutation(range(255)))
  26. color_map = {
  27. idx: (color[0][idx], color[1][idx], color[2][idx])
  28. for idx in range(1, 255)
  29. }
  30. if isinstance(image, np.ndarray):
  31. image = Image.fromarray(image)
  32. elif isinstance(image, str) and os.path.isfile(image):
  33. image = Image.open(image).convert('RGB')
  34. img_new = image.copy()
  35. draw = ImageDraw.Draw(img_new)
  36. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  37. for ocr_info in ocr_results:
  38. if ocr_info["pred_id"] not in color_map:
  39. continue
  40. color = color_map[ocr_info["pred_id"]]
  41. text = "{}: {}".format(ocr_info["pred"], ocr_info["transcription"])
  42. if "bbox" in ocr_info:
  43. # draw with ocr engine
  44. bbox = ocr_info["bbox"]
  45. else:
  46. # draw with ocr groundtruth
  47. bbox = trans_poly_to_bbox(ocr_info["points"])
  48. draw_box_txt(bbox, text, draw, font, font_size, color)
  49. img_new = Image.blend(image, img_new, 0.7)
  50. return np.array(img_new)
  51. def draw_box_txt(bbox, text, draw, font, font_size, color):
  52. # draw ocr results outline
  53. bbox = ((bbox[0], bbox[1]), (bbox[2], bbox[3]))
  54. draw.rectangle(bbox, fill=color)
  55. # draw ocr results
  56. tw = font.getsize(text)[0]
  57. th = font.getsize(text)[1]
  58. start_y = max(0, bbox[0][1] - th)
  59. draw.rectangle(
  60. [(bbox[0][0] + 1, start_y), (bbox[0][0] + tw + 1, start_y + th)],
  61. fill=(0, 0, 255))
  62. draw.text((bbox[0][0] + 1, start_y), text, fill=(255, 255, 255), font=font)
  63. def trans_poly_to_bbox(poly):
  64. x1 = np.min([p[0] for p in poly])
  65. x2 = np.max([p[0] for p in poly])
  66. y1 = np.min([p[1] for p in poly])
  67. y2 = np.max([p[1] for p in poly])
  68. return [x1, y1, x2, y2]
  69. def draw_re_results(image,
  70. result,
  71. font_path="doc/fonts/simfang.ttf",
  72. font_size=18):
  73. np.random.seed(0)
  74. if isinstance(image, np.ndarray):
  75. image = Image.fromarray(image)
  76. elif isinstance(image, str) and os.path.isfile(image):
  77. image = Image.open(image).convert('RGB')
  78. img_new = image.copy()
  79. draw = ImageDraw.Draw(img_new)
  80. font = ImageFont.truetype(font_path, font_size, encoding="utf-8")
  81. color_head = (0, 0, 255)
  82. color_tail = (255, 0, 0)
  83. color_line = (0, 255, 0)
  84. for ocr_info_head, ocr_info_tail in result:
  85. draw_box_txt(ocr_info_head["bbox"], ocr_info_head["transcription"],
  86. draw, font, font_size, color_head)
  87. draw_box_txt(ocr_info_tail["bbox"], ocr_info_tail["transcription"],
  88. draw, font, font_size, color_tail)
  89. center_head = (
  90. (ocr_info_head['bbox'][0] + ocr_info_head['bbox'][2]) // 2,
  91. (ocr_info_head['bbox'][1] + ocr_info_head['bbox'][3]) // 2)
  92. center_tail = (
  93. (ocr_info_tail['bbox'][0] + ocr_info_tail['bbox'][2]) // 2,
  94. (ocr_info_tail['bbox'][1] + ocr_info_tail['bbox'][3]) // 2)
  95. draw.line([center_head, center_tail], fill=color_line, width=5)
  96. img_new = Image.blend(image, img_new, 0.5)
  97. return np.array(img_new)
  98. def draw_rectangle(img_path, boxes):
  99. boxes = np.array(boxes)
  100. img = cv2.imread(img_path)
  101. img_show = img.copy()
  102. for box in boxes.astype(int):
  103. x1, y1, x2, y2 = box
  104. cv2.rectangle(img_show, (x1, y1), (x2, y2), (255, 0, 0), 2)
  105. return img_show