east_postprocess.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. from .locality_aware_nms import nms_locality
  19. import cv2
  20. import paddle
  21. import os
  22. import sys
  23. class EASTPostProcess(object):
  24. """
  25. The post process for EAST.
  26. """
  27. def __init__(self,
  28. score_thresh=0.8,
  29. cover_thresh=0.1,
  30. nms_thresh=0.2,
  31. **kwargs):
  32. self.score_thresh = score_thresh
  33. self.cover_thresh = cover_thresh
  34. self.nms_thresh = nms_thresh
  35. def restore_rectangle_quad(self, origin, geometry):
  36. """
  37. Restore rectangle from quadrangle.
  38. """
  39. # quad
  40. origin_concat = np.concatenate(
  41. (origin, origin, origin, origin), axis=1) # (n, 8)
  42. pred_quads = origin_concat - geometry
  43. pred_quads = pred_quads.reshape((-1, 4, 2)) # (n, 4, 2)
  44. return pred_quads
  45. def detect(self,
  46. score_map,
  47. geo_map,
  48. score_thresh=0.8,
  49. cover_thresh=0.1,
  50. nms_thresh=0.2):
  51. """
  52. restore text boxes from score map and geo map
  53. """
  54. score_map = score_map[0]
  55. geo_map = np.swapaxes(geo_map, 1, 0)
  56. geo_map = np.swapaxes(geo_map, 1, 2)
  57. # filter the score map
  58. xy_text = np.argwhere(score_map > score_thresh)
  59. if len(xy_text) == 0:
  60. return []
  61. # sort the text boxes via the y axis
  62. xy_text = xy_text[np.argsort(xy_text[:, 0])]
  63. #restore quad proposals
  64. text_box_restored = self.restore_rectangle_quad(
  65. xy_text[:, ::-1] * 4, geo_map[xy_text[:, 0], xy_text[:, 1], :])
  66. boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
  67. boxes[:, :8] = text_box_restored.reshape((-1, 8))
  68. boxes[:, 8] = score_map[xy_text[:, 0], xy_text[:, 1]]
  69. try:
  70. import lanms
  71. boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
  72. except:
  73. print(
  74. 'you should install lanms by pip3 install lanms-nova to speed up nms_locality'
  75. )
  76. boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
  77. if boxes.shape[0] == 0:
  78. return []
  79. # Here we filter some low score boxes by the average score map,
  80. # this is different from the orginal paper.
  81. for i, box in enumerate(boxes):
  82. mask = np.zeros_like(score_map, dtype=np.uint8)
  83. cv2.fillPoly(mask, box[:8].reshape(
  84. (-1, 4, 2)).astype(np.int32) // 4, 1)
  85. boxes[i, 8] = cv2.mean(score_map, mask)[0]
  86. boxes = boxes[boxes[:, 8] > cover_thresh]
  87. return boxes
  88. def sort_poly(self, p):
  89. """
  90. Sort polygons.
  91. """
  92. min_axis = np.argmin(np.sum(p, axis=1))
  93. p = p[[min_axis, (min_axis + 1) % 4,\
  94. (min_axis + 2) % 4, (min_axis + 3) % 4]]
  95. if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
  96. return p
  97. else:
  98. return p[[0, 3, 2, 1]]
  99. def __call__(self, outs_dict, shape_list):
  100. score_list = outs_dict['f_score']
  101. geo_list = outs_dict['f_geo']
  102. if isinstance(score_list, paddle.Tensor):
  103. score_list = score_list.numpy()
  104. geo_list = geo_list.numpy()
  105. img_num = len(shape_list)
  106. dt_boxes_list = []
  107. for ino in range(img_num):
  108. score = score_list[ino]
  109. geo = geo_list[ino]
  110. boxes = self.detect(
  111. score_map=score,
  112. geo_map=geo,
  113. score_thresh=self.score_thresh,
  114. cover_thresh=self.cover_thresh,
  115. nms_thresh=self.nms_thresh)
  116. boxes_norm = []
  117. if len(boxes) > 0:
  118. h, w = score.shape[1:]
  119. src_h, src_w, ratio_h, ratio_w = shape_list[ino]
  120. boxes = boxes[:, :8].reshape((-1, 4, 2))
  121. boxes[:, :, 0] /= ratio_w
  122. boxes[:, :, 1] /= ratio_h
  123. for i_box, box in enumerate(boxes):
  124. box = self.sort_poly(box.astype(np.int32))
  125. if np.linalg.norm(box[0] - box[1]) < 5 \
  126. or np.linalg.norm(box[3] - box[0]) < 5:
  127. continue
  128. boxes_norm.append(box)
  129. dt_boxes_list.append({'points': np.array(boxes_norm)})
  130. return dt_boxes_list