make_border_map.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/WenmuZhou/DBNet.pytorch/blob/master/data_loader/modules/make_border_map.py
  17. """
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
import cv2
# Ignore NumPy divide-by-zero and invalid-value floating-point errors.
# NOTE(review): np.seterr changes the error state for the whole process (every
# NumPy user, not just this module). Presumably it is here to silence the 0/0
# divisions that the point-to-segment distance code in this file can hit;
# consider scoping that with np.errstate instead.
np.seterr(divide='ignore', invalid='ignore')
import pyclipper
from shapely.geometry import Polygon
import sys
import warnings
# NOTE(review): installs a process-wide "ignore" filter for ALL Python
# warnings as a side effect of importing this module; consider narrowing it.
warnings.simplefilter("ignore")
# Public API of this module.
__all__ = ['MakeBorderMap']
  31. class MakeBorderMap(object):
  32. def __init__(self,
  33. shrink_ratio=0.4,
  34. thresh_min=0.3,
  35. thresh_max=0.7,
  36. **kwargs):
  37. self.shrink_ratio = shrink_ratio
  38. self.thresh_min = thresh_min
  39. self.thresh_max = thresh_max
  40. def __call__(self, data):
  41. img = data['image']
  42. text_polys = data['polys']
  43. ignore_tags = data['ignore_tags']
  44. canvas = np.zeros(img.shape[:2], dtype=np.float32)
  45. mask = np.zeros(img.shape[:2], dtype=np.float32)
  46. for i in range(len(text_polys)):
  47. if ignore_tags[i]:
  48. continue
  49. self.draw_border_map(text_polys[i], canvas, mask=mask)
  50. canvas = canvas * (self.thresh_max - self.thresh_min) + self.thresh_min
  51. data['threshold_map'] = canvas
  52. data['threshold_mask'] = mask
  53. return data
  54. def draw_border_map(self, polygon, canvas, mask):
  55. polygon = np.array(polygon)
  56. assert polygon.ndim == 2
  57. assert polygon.shape[1] == 2
  58. polygon_shape = Polygon(polygon)
  59. if polygon_shape.area <= 0:
  60. return
  61. distance = polygon_shape.area * (
  62. 1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length
  63. subject = [tuple(l) for l in polygon]
  64. padding = pyclipper.PyclipperOffset()
  65. padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  66. padded_polygon = np.array(padding.Execute(distance)[0])
  67. cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0)
  68. xmin = padded_polygon[:, 0].min()
  69. xmax = padded_polygon[:, 0].max()
  70. ymin = padded_polygon[:, 1].min()
  71. ymax = padded_polygon[:, 1].max()
  72. width = xmax - xmin + 1
  73. height = ymax - ymin + 1
  74. polygon[:, 0] = polygon[:, 0] - xmin
  75. polygon[:, 1] = polygon[:, 1] - ymin
  76. xs = np.broadcast_to(
  77. np.linspace(
  78. 0, width - 1, num=width).reshape(1, width), (height, width))
  79. ys = np.broadcast_to(
  80. np.linspace(
  81. 0, height - 1, num=height).reshape(height, 1), (height, width))
  82. distance_map = np.zeros(
  83. (polygon.shape[0], height, width), dtype=np.float32)
  84. for i in range(polygon.shape[0]):
  85. j = (i + 1) % polygon.shape[0]
  86. absolute_distance = self._distance(xs, ys, polygon[i], polygon[j])
  87. distance_map[i] = np.clip(absolute_distance / distance, 0, 1)
  88. distance_map = distance_map.min(axis=0)
  89. xmin_valid = min(max(0, xmin), canvas.shape[1] - 1)
  90. xmax_valid = min(max(0, xmax), canvas.shape[1] - 1)
  91. ymin_valid = min(max(0, ymin), canvas.shape[0] - 1)
  92. ymax_valid = min(max(0, ymax), canvas.shape[0] - 1)
  93. canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax(
  94. 1 - distance_map[ymin_valid - ymin:ymax_valid - ymax + height,
  95. xmin_valid - xmin:xmax_valid - xmax + width],
  96. canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1])
  97. def _distance(self, xs, ys, point_1, point_2):
  98. '''
  99. compute the distance from point to a line
  100. ys: coordinates in the first axis
  101. xs: coordinates in the second axis
  102. point_1, point_2: (x, y), the end of the line
  103. '''
  104. height, width = xs.shape[:2]
  105. square_distance_1 = np.square(xs - point_1[0]) + np.square(ys - point_1[
  106. 1])
  107. square_distance_2 = np.square(xs - point_2[0]) + np.square(ys - point_2[
  108. 1])
  109. square_distance = np.square(point_1[0] - point_2[0]) + np.square(
  110. point_1[1] - point_2[1])
  111. cosin = (square_distance - square_distance_1 - square_distance_2) / (
  112. 2 * np.sqrt(square_distance_1 * square_distance_2))
  113. square_sin = 1 - np.square(cosin)
  114. square_sin = np.nan_to_num(square_sin)
  115. result = np.sqrt(square_distance_1 * square_distance_2 * square_sin /
  116. square_distance)
  117. result[cosin <
  118. 0] = np.sqrt(np.fmin(square_distance_1, square_distance_2))[cosin
  119. < 0]
  120. # self.extend_line(point_1, point_2, result)
  121. return result
  122. def extend_line(self, point_1, point_2, result, shrink_ratio):
  123. ex_point_1 = (int(
  124. round(point_1[0] + (point_1[0] - point_2[0]) * (1 + shrink_ratio))),
  125. int(
  126. round(point_1[1] + (point_1[1] - point_2[1]) * (
  127. 1 + shrink_ratio))))
  128. cv2.line(
  129. result,
  130. tuple(ex_point_1),
  131. tuple(point_1),
  132. 4096.0,
  133. 1,
  134. lineType=cv2.LINE_AA,
  135. shift=0)
  136. ex_point_2 = (int(
  137. round(point_2[0] + (point_2[0] - point_1[0]) * (1 + shrink_ratio))),
  138. int(
  139. round(point_2[1] + (point_2[1] - point_1[1]) * (
  140. 1 + shrink_ratio))))
  141. cv2.line(
  142. result,
  143. tuple(ex_point_2),
  144. tuple(point_2),
  145. 4096.0,
  146. 1,
  147. lineType=cv2.LINE_AA,
  148. shift=0)
  149. return ex_point_1, ex_point_2