ct_postprocess.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is referred from:
  16. https://github.com/shengtao96/CentripetalText/blob/main/test.py
  17. """
  18. from __future__ import absolute_import
  19. from __future__ import division
  20. from __future__ import print_function
  21. import os
  22. import os.path as osp
  23. import numpy as np
  24. import cv2
  25. import paddle
  26. import pyclipper
  27. class CTPostProcess(object):
  28. """
  29. The post process for Centripetal Text (CT).
  30. """
  31. def __init__(self, min_score=0.88, min_area=16, box_type='poly', **kwargs):
  32. self.min_score = min_score
  33. self.min_area = min_area
  34. self.box_type = box_type
  35. self.coord = np.zeros((2, 300, 300), dtype=np.int32)
  36. for i in range(300):
  37. for j in range(300):
  38. self.coord[0, i, j] = j
  39. self.coord[1, i, j] = i
  40. def __call__(self, preds, batch):
  41. outs = preds['maps']
  42. out_scores = preds['score']
  43. if isinstance(outs, paddle.Tensor):
  44. outs = outs.numpy()
  45. if isinstance(out_scores, paddle.Tensor):
  46. out_scores = out_scores.numpy()
  47. batch_size = outs.shape[0]
  48. boxes_batch = []
  49. for idx in range(batch_size):
  50. bboxes = []
  51. scores = []
  52. img_shape = batch[idx]
  53. org_img_size = img_shape[:3]
  54. img_shape = img_shape[3:]
  55. img_size = img_shape[:2]
  56. out = np.expand_dims(outs[idx], axis=0)
  57. outputs = dict()
  58. score = np.expand_dims(out_scores[idx], axis=0)
  59. kernel = out[:, 0, :, :] > 0.2
  60. loc = out[:, 1:, :, :].astype("float32")
  61. score = score[0].astype(np.float32)
  62. kernel = kernel[0].astype(np.uint8)
  63. loc = loc[0].astype(np.float32)
  64. label_num, label_kernel = cv2.connectedComponents(
  65. kernel, connectivity=4)
  66. for i in range(1, label_num):
  67. ind = (label_kernel == i)
  68. if ind.sum(
  69. ) < 10: # pixel number less than 10, treated as background
  70. label_kernel[ind] = 0
  71. label = np.zeros_like(label_kernel)
  72. h, w = label_kernel.shape
  73. pixels = self.coord[:, :h, :w].reshape(2, -1)
  74. points = pixels.transpose([1, 0]).astype(np.float32)
  75. off_points = (points + 10. / 4. * loc[:, pixels[1], pixels[0]].T
  76. ).astype(np.int32)
  77. off_points[:, 0] = np.clip(off_points[:, 0], 0, label.shape[1] - 1)
  78. off_points[:, 1] = np.clip(off_points[:, 1], 0, label.shape[0] - 1)
  79. label[pixels[1], pixels[0]] = label_kernel[off_points[:, 1],
  80. off_points[:, 0]]
  81. label[label_kernel > 0] = label_kernel[label_kernel > 0]
  82. score_pocket = [0.0]
  83. for i in range(1, label_num):
  84. ind = (label_kernel == i)
  85. if ind.sum() == 0:
  86. score_pocket.append(0.0)
  87. continue
  88. score_i = np.mean(score[ind])
  89. score_pocket.append(score_i)
  90. label_num = np.max(label) + 1
  91. label = cv2.resize(
  92. label, (img_size[1], img_size[0]),
  93. interpolation=cv2.INTER_NEAREST)
  94. scale = (float(org_img_size[1]) / float(img_size[1]),
  95. float(org_img_size[0]) / float(img_size[0]))
  96. for i in range(1, label_num):
  97. ind = (label == i)
  98. points = np.array(np.where(ind)).transpose((1, 0))
  99. if points.shape[0] < self.min_area:
  100. continue
  101. score_i = score_pocket[i]
  102. if score_i < self.min_score:
  103. continue
  104. if self.box_type == 'rect':
  105. rect = cv2.minAreaRect(points[:, ::-1])
  106. bbox = cv2.boxPoints(rect) * scale
  107. z = bbox.mean(0)
  108. bbox = z + (bbox - z) * 0.85
  109. elif self.box_type == 'poly':
  110. binary = np.zeros(label.shape, dtype='uint8')
  111. binary[ind] = 1
  112. try:
  113. _, contours, _ = cv2.findContours(
  114. binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  115. except BaseException:
  116. contours, _ = cv2.findContours(
  117. binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
  118. bbox = contours[0] * scale
  119. bbox = bbox.astype('int32')
  120. bboxes.append(bbox.reshape(-1, 2))
  121. scores.append(score_i)
  122. boxes_batch.append({'points': bboxes})
  123. return boxes_batch