db_postprocess.py 9.7 KB

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """
This code is adapted from:
https://github.com/WenmuZhou/DBNet.pytorch/blob/master/post_processing/seg_detector_representer.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import numpy as np
  22. import cv2
  23. import paddle
  24. from shapely.geometry import Polygon
  25. import pyclipper
  26. class DBPostProcess(object):
  27. """
  28. The post process for Differentiable Binarization (DB).
  29. """
  30. def __init__(self,
  31. thresh=0.3,
  32. box_thresh=0.7,
  33. max_candidates=1000,
  34. unclip_ratio=2.0,
  35. use_dilation=False,
  36. score_mode="fast",
  37. box_type='quad',
  38. **kwargs):
  39. self.thresh = thresh
  40. self.box_thresh = box_thresh
  41. self.max_candidates = max_candidates
  42. self.unclip_ratio = unclip_ratio
  43. self.min_size = 3
  44. self.score_mode = score_mode
  45. self.box_type = box_type
  46. assert score_mode in [
  47. "slow", "fast"
  48. ], "Score mode must be in [slow, fast] but got: {}".format(score_mode)
  49. self.dilation_kernel = None if not use_dilation else np.array(
  50. [[1, 1], [1, 1]])
  51. def polygons_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  52. '''
  53. _bitmap: single map with shape (1, H, W),
  54. whose values are binarized as {0, 1}
  55. '''
  56. bitmap = _bitmap
  57. height, width = bitmap.shape
  58. boxes = []
  59. scores = []
  60. contours, _ = cv2.findContours((bitmap * 255).astype(np.uint8),
  62. for contour in contours[:self.max_candidates]:
  63. epsilon = 0.002 * cv2.arcLength(contour, True)
  64. approx = cv2.approxPolyDP(contour, epsilon, True)
  65. points = approx.reshape((-1, 2))
  66. if points.shape[0] < 4:
  67. continue
  68. score = self.box_score_fast(pred, points.reshape(-1, 2))
  69. if self.box_thresh > score:
  70. continue
  71. if points.shape[0] > 2:
  72. box = self.unclip(points, self.unclip_ratio)
  73. if len(box) > 1:
  74. continue
  75. else:
  76. continue
  77. box = box.reshape(-1, 2)
  78. _, sside = self.get_mini_boxes(box.reshape((-1, 1, 2)))
  79. if sside < self.min_size + 2:
  80. continue
  81. box = np.array(box)
  82. box[:, 0] = np.clip(
  83. np.round(box[:, 0] / width * dest_width), 0, dest_width)
  84. box[:, 1] = np.clip(
  85. np.round(box[:, 1] / height * dest_height), 0, dest_height)
  86. boxes.append(box.tolist())
  87. scores.append(score)
  88. return boxes, scores
  89. def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
  90. '''
  91. _bitmap: single map with shape (1, H, W),
  92. whose values are binarized as {0, 1}
  93. '''
  94. bitmap = _bitmap
  95. height, width = bitmap.shape
  96. outs = cv2.findContours((bitmap * 255).astype(np.uint8), cv2.RETR_LIST,
  98. if len(outs) == 3:
  99. img, contours, _ = outs[0], outs[1], outs[2]
  100. elif len(outs) == 2:
  101. contours, _ = outs[0], outs[1]
  102. num_contours = min(len(contours), self.max_candidates)
  103. boxes = []
  104. scores = []
  105. for index in range(num_contours):
  106. contour = contours[index]
  107. points, sside = self.get_mini_boxes(contour)
  108. if sside < self.min_size:
  109. continue
  110. points = np.array(points)
  111. if self.score_mode == "fast":
  112. score = self.box_score_fast(pred, points.reshape(-1, 2))
  113. else:
  114. score = self.box_score_slow(pred, contour)
  115. if self.box_thresh > score:
  116. continue
  117. box = self.unclip(points, self.unclip_ratio).reshape(-1, 1, 2)
  118. box, sside = self.get_mini_boxes(box)
  119. if sside < self.min_size + 2:
  120. continue
  121. box = np.array(box)
  122. box[:, 0] = np.clip(
  123. np.round(box[:, 0] / width * dest_width), 0, dest_width)
  124. box[:, 1] = np.clip(
  125. np.round(box[:, 1] / height * dest_height), 0, dest_height)
  126. boxes.append(box.astype("int32"))
  127. scores.append(score)
  128. return np.array(boxes, dtype="int32"), scores
  129. def unclip(self, box, unclip_ratio):
  130. poly = Polygon(box)
  131. distance = poly.area * unclip_ratio / poly.length
  132. offset = pyclipper.PyclipperOffset()
  133. offset.AddPath(box, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON)
  134. expanded = np.array(offset.Execute(distance))
  135. return expanded
  136. def get_mini_boxes(self, contour):
  137. bounding_box = cv2.minAreaRect(contour)
  138. points = sorted(list(cv2.boxPoints(bounding_box)), key=lambda x: x[0])
  139. index_1, index_2, index_3, index_4 = 0, 1, 2, 3
  140. if points[1][1] > points[0][1]:
  141. index_1 = 0
  142. index_4 = 1
  143. else:
  144. index_1 = 1
  145. index_4 = 0
  146. if points[3][1] > points[2][1]:
  147. index_2 = 2
  148. index_3 = 3
  149. else:
  150. index_2 = 3
  151. index_3 = 2
  152. box = [
  153. points[index_1], points[index_2], points[index_3], points[index_4]
  154. ]
  155. return box, min(bounding_box[1])
  156. def box_score_fast(self, bitmap, _box):
  157. '''
  158. box_score_fast: use bbox mean score as the mean score
  159. '''
  160. h, w = bitmap.shape[:2]
  161. box = _box.copy()
  162. xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
  163. xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
  164. ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
  165. ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
  166. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  167. box[:, 0] = box[:, 0] - xmin
  168. box[:, 1] = box[:, 1] - ymin
  169. cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
  170. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  171. def box_score_slow(self, bitmap, contour):
  172. '''
  173. box_score_slow: use polyon mean score as the mean score
  174. '''
  175. h, w = bitmap.shape[:2]
  176. contour = contour.copy()
  177. contour = np.reshape(contour, (-1, 2))
  178. xmin = np.clip(np.min(contour[:, 0]), 0, w - 1)
  179. xmax = np.clip(np.max(contour[:, 0]), 0, w - 1)
  180. ymin = np.clip(np.min(contour[:, 1]), 0, h - 1)
  181. ymax = np.clip(np.max(contour[:, 1]), 0, h - 1)
  182. mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
  183. contour[:, 0] = contour[:, 0] - xmin
  184. contour[:, 1] = contour[:, 1] - ymin
  185. cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
  186. return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
  187. def __call__(self, outs_dict, shape_list):
  188. pred = outs_dict['maps']
  189. if isinstance(pred, paddle.Tensor):
  190. pred = pred.numpy()
  191. pred = pred[:, 0, :, :]
  192. segmentation = pred > self.thresh
  193. boxes_batch = []
  194. for batch_index in range(pred.shape[0]):
  195. src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
  196. if self.dilation_kernel is not None:
  197. mask = cv2.dilate(
  198. np.array(segmentation[batch_index]).astype(np.uint8),
  199. self.dilation_kernel)
  200. else:
  201. mask = segmentation[batch_index]
  202. if self.box_type == 'poly':
  203. boxes, scores = self.polygons_from_bitmap(pred[batch_index],
  204. mask, src_w, src_h)
  205. elif self.box_type == 'quad':
  206. boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
  207. src_w, src_h)
  208. else:
  209. raise ValueError("box_type can only be one of ['quad', 'poly']")
  210. boxes_batch.append({'points': boxes})
  211. return boxes_batch
  212. class DistillationDBPostProcess(object):
  213. def __init__(self,
  214. model_name=["student"],
  215. key=None,
  216. thresh=0.3,
  217. box_thresh=0.6,
  218. max_candidates=1000,
  219. unclip_ratio=1.5,
  220. use_dilation=False,
  221. score_mode="fast",
  222. box_type='quad',
  223. **kwargs):
  224. self.model_name = model_name
  225. self.key = key
  226. self.post_process = DBPostProcess(
  227. thresh=thresh,
  228. box_thresh=box_thresh,
  229. max_candidates=max_candidates,
  230. unclip_ratio=unclip_ratio,
  231. use_dilation=use_dilation,
  232. score_mode=score_mode,
  233. box_type=box_type)
  234. def __call__(self, predicts, shape_list):
  235. results = {}
  236. for k in self.model_name:
  237. results[k] = self.post_process(predicts[k], shape_list=shape_list)
  238. return results