poly_nms.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from shapely.geometry import Polygon
  16. def points2polygon(points):
  17. """Convert k points to 1 polygon.
  18. Args:
  19. points (ndarray or list): A ndarray or a list of shape (2k)
  20. that indicates k points.
  21. Returns:
  22. polygon (Polygon): A polygon object.
  23. """
  24. if isinstance(points, list):
  25. points = np.array(points)
  26. assert isinstance(points, np.ndarray)
  27. assert (points.size % 2 == 0) and (points.size >= 8)
  28. point_mat = points.reshape([-1, 2])
  29. return Polygon(point_mat)
  30. def poly_intersection(poly_det, poly_gt, buffer=0.0001):
  31. """Calculate the intersection area between two polygon.
  32. Args:
  33. poly_det (Polygon): A polygon predicted by detector.
  34. poly_gt (Polygon): A gt polygon.
  35. Returns:
  36. intersection_area (float): The intersection area between two polygons.
  37. """
  38. assert isinstance(poly_det, Polygon)
  39. assert isinstance(poly_gt, Polygon)
  40. if buffer == 0:
  41. poly_inter = poly_det & poly_gt
  42. else:
  43. poly_inter = poly_det.buffer(buffer) & poly_gt.buffer(buffer)
  44. return poly_inter.area, poly_inter
  45. def poly_union(poly_det, poly_gt):
  46. """Calculate the union area between two polygon.
  47. Args:
  48. poly_det (Polygon): A polygon predicted by detector.
  49. poly_gt (Polygon): A gt polygon.
  50. Returns:
  51. union_area (float): The union area between two polygons.
  52. """
  53. assert isinstance(poly_det, Polygon)
  54. assert isinstance(poly_gt, Polygon)
  55. area_det = poly_det.area
  56. area_gt = poly_gt.area
  57. area_inters, _ = poly_intersection(poly_det, poly_gt)
  58. return area_det + area_gt - area_inters
  59. def valid_boundary(x, with_score=True):
  60. num = len(x)
  61. if num < 8:
  62. return False
  63. if num % 2 == 0 and (not with_score):
  64. return True
  65. if num % 2 == 1 and with_score:
  66. return True
  67. return False
  68. def boundary_iou(src, target):
  69. """Calculate the IOU between two boundaries.
  70. Args:
  71. src (list): Source boundary.
  72. target (list): Target boundary.
  73. Returns:
  74. iou (float): The iou between two boundaries.
  75. """
  76. assert valid_boundary(src, False)
  77. assert valid_boundary(target, False)
  78. src_poly = points2polygon(src)
  79. target_poly = points2polygon(target)
  80. return poly_iou(src_poly, target_poly)
  81. def poly_iou(poly_det, poly_gt):
  82. """Calculate the IOU between two polygons.
  83. Args:
  84. poly_det (Polygon): A polygon predicted by detector.
  85. poly_gt (Polygon): A gt polygon.
  86. Returns:
  87. iou (float): The IOU between two polygons.
  88. """
  89. assert isinstance(poly_det, Polygon)
  90. assert isinstance(poly_gt, Polygon)
  91. area_inters, _ = poly_intersection(poly_det, poly_gt)
  92. area_union = poly_union(poly_det, poly_gt)
  93. if area_union == 0:
  94. return 0.0
  95. return area_inters / area_union
  96. def poly_nms(polygons, threshold):
  97. assert isinstance(polygons, list)
  98. polygons = np.array(sorted(polygons, key=lambda x: x[-1]))
  99. keep_poly = []
  100. index = [i for i in range(polygons.shape[0])]
  101. while len(index) > 0:
  102. keep_poly.append(polygons[index[-1]].tolist())
  103. A = polygons[index[-1]][:-1]
  104. index = np.delete(index, -1)
  105. iou_list = np.zeros((len(index), ))
  106. for i in range(len(index)):
  107. B = polygons[index[i]][:-1]
  108. iou_list[i] = boundary_iou(A, B)
  109. remove_index = np.where(iou_list > threshold)
  110. index = np.delete(index, remove_index)
  111. return keep_poly