locality_aware_nms.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
  """
  Locality-aware NMS.
  This code is adapted from: https://github.com/songdejia/EAST/blob/master/locality_aware_nms.py
  """
  5. import numpy as np
  6. from shapely.geometry import Polygon
  def intersection(g, p):
      """
      Intersection over union (IoU) of two quadrilaterals.

      The first 8 values of each argument are read as four (x, y) vertices;
      any further values (e.g. a score in column 8) are ignored. Both shapes
      are passed through ``buffer(0)`` (a common Shapely idiom for repairing
      invalid, e.g. self-intersecting, polygons) before measuring.

      :param g: 1-D array with at least 8 coordinate values.
      :param p: 1-D array with the same layout as ``g``.
      :return: intersection area / union area, or 0 when either repaired
          shape is still invalid or the union area is zero.
      """
      g = Polygon(g[:8].reshape((4, 2))).buffer(0)
      p = Polygon(p[:8].reshape((4, 2))).buffer(0)
      if not g.is_valid or not p.is_valid:
          return 0
      # Intersect the repaired geometries directly: wrapping them in
      # Polygon() again is redundant for a Polygon and fails if buffer(0)
      # hands back a MultiPolygon.
      inter = g.intersection(p).area
      union = g.area + p.area - inter
      if union == 0:
          return 0
      return inter / union
  def intersection_iog(g, p):
      """
      Area of the intersection of two quadrilaterals divided by the area of ``p``.

      :param g: 1-D array whose first 8 values are four (x, y) vertices.
      :param p: 1-D array with the same layout; its area is the denominator.
      :return: intersection area / area of ``p``, or 0 when either quad is not
          a valid polygon or ``p`` has zero area (that case also prints a
          message to stdout).
      """
      poly_g, poly_p = (Polygon(quad[:8].reshape((4, 2))) for quad in (g, p))
      if not (poly_g.is_valid and poly_p.is_valid):
          return 0
      overlap = poly_g.intersection(poly_p).area
      denominator = poly_p.area
      if denominator == 0:
          print("p_area is very small")
          return 0
      return overlap / denominator
  def weighted_merge(g, p):
      """
      Merge two boxes into one, weighting their coordinates by their scores.

      Coordinates (first 8 values) of the result are the average of ``g`` and
      ``p`` weighted by their scores (index 8); the result's score is the sum
      of both scores. Values after index 8, if any, are copied from ``g``.
      Neither input is modified, and the result has a floating-point dtype so
      integer coordinates are not truncated.

      :param g: 1-D array, 8 coordinates then a score.
      :param p: 1-D array with the same layout.
      :return: a new 1-D array holding the merged box.
      """
      # Fresh array: writing into ``g`` would clobber the caller's data
      # (e.g. a row view of the array being iterated in nms_locality).
      merged = np.array(g, dtype=np.result_type(g, p, np.float32))
      weight_g, weight_p = g[8], p[8]
      total = weight_g + weight_p
      if total == 0:
          # Weights sum to zero: fall back to a plain average instead of 0/0.
          merged[:8] = (g[:8] + p[:8]) / 2.0
      else:
          merged[:8] = (weight_g * g[:8] + weight_p * p[:8]) / total
      merged[8] = total
      return merged
  def standard_nms(S, thres):
      """
      Greedy non-maximum suppression over quadrilaterals, returning rows.

      :param S: 2-D array whose rows are 8 quad coordinates followed by a
          score in column 8.
      :param thres: a box is dropped when its IoU with the box kept in the
          current round exceeds this value.
      :return: the kept rows of ``S``, highest score first.
      """
      remaining = np.argsort(S[:, 8])[::-1]
      kept = []
      while remaining.size:
          best, rest = remaining[0], remaining[1:]
          kept.append(best)
          overlaps = np.array([intersection(S[best], S[other]) for other in rest])
          # Survivors keep their relative (descending-score) order.
          remaining = rest[overlaps <= thres]
      return S[kept]
  def standard_nms_inds(S, thres):
      """
      Greedy non-maximum suppression over quadrilaterals, returning indices.

      :param S: 2-D array whose rows are 8 quad coordinates followed by a
          score in column 8.
      :param thres: a box is dropped when its IoU with the box kept in the
          current round exceeds this value.
      :return: list of kept row indices into ``S``, highest score first.
      """
      kept = []
      candidates = np.argsort(S[:, 8])[::-1]
      while len(candidates) > 0:
          head = candidates[0]
          kept.append(head)
          tail = candidates[1:]
          ious = np.array([intersection(S[head], S[idx]) for idx in tail])
          survivors = np.where(ious <= thres)[0]
          candidates = tail[survivors]
      return kept
  def nms(S, thres):
      """
      Greedy non-maximum suppression returning indices.

      Same behaviour as ``standard_nms_inds``. It is kept as a public name for
      existing callers and delegates there, so the algorithm lives in one place
      instead of two identical copies.

      :param S: 2-D array whose rows are 8 quad coordinates followed by a
          score in column 8.
      :param thres: a box is dropped when its IoU with the box kept in the
          current round exceeds this value.
      :return: list of kept row indices into ``S``, highest score first.
      """
      return standard_nms_inds(S, thres)
  def soft_nms(boxes_in, Nt_thres=0.3, threshold=0.8, sigma=0.5, method=2):
      """
      Soft-NMS over quadrilateral boxes.

      At step ``i`` the highest-scoring remaining box is moved to row ``i``.
      Every later box whose IoU with it is > 0 has its score multiplied by a
      decay weight. A box is discarded if its decayed score falls below
      ``threshold``. Boxes that never overlap a selected box keep their
      original score and are never discarded, even when that score is already
      below ``threshold``.

      :param boxes_in: N x 9 array (8 quad coordinates, then the score). Not
          modified; the work is done on a copy.
      :param Nt_thres: IoU threshold for the linear and hard decay rules.
      :param threshold: minimum decayed score for a box to survive.
          NOTE(review): the 0.8 default is strict; a small cutoff such as
          0.001 may have been intended -- confirm with callers.
      :param sigma: Gaussian decay parameter (``method == 2``).
      :param method: 1 = linear (weight ``1 - iou`` when iou > Nt_thres),
          2 = Gaussian (weight ``exp(-iou**2 / sigma)``), anything else =
          hard NMS (weight 0 when iou > Nt_thres).
      :return: surviving rows in selection order (scores non-increasing), or
          an empty 1-D array when ``boxes_in`` has no rows.
      """
      boxes = boxes_in.copy()
      N = boxes.shape[0]
      if N < 1:
          return np.array([])
      i = 0
      # N shrinks as boxes are discarded, so loop on the live count.
      while i < N:
          # Bring the best remaining box (first one on ties) to row i.
          maxpos = i + int(np.argmax(boxes[i:N, 8]))
          if maxpos != i:
              boxes[[i, maxpos]] = boxes[[maxpos, i]]
          tbox = boxes[i].copy()
          pos = i + 1
          while pos < N:
              ts_iou_val = intersection(tbox, boxes[pos])
              if ts_iou_val > 0:
                  if method == 1:
                      weight = 1 - ts_iou_val if ts_iou_val > Nt_thres else 1
                  elif method == 2:
                      weight = np.exp(-1.0 * ts_iou_val**2 / sigma)
                  else:
                      weight = 0 if ts_iou_val > Nt_thres else 1
                  boxes[pos, 8] = weight * boxes[pos, 8]
                  if boxes[pos, 8] < threshold:
                      # Discard by overwriting with the last live box and
                      # shrinking N; re-check this row, it now holds a new box.
                      boxes[pos, :] = boxes[N - 1, :]
                      N -= 1
                      continue
              pos += 1
          i += 1
      return boxes[:N]
  def nms_locality(polys, thres=0.3):
      """
      Locality-aware NMS of EAST.

      Walks the rows in their given order. While a row's IoU with the running
      box exceeds ``thres``, it is merged into that box via ``weighted_merge``.
      Otherwise the running box is emitted and the row starts a new one. The
      emitted boxes are finally filtered with ``standard_nms``.

      :param polys: N x 9 array (or anything ``np.array`` turns into one):
          8 coordinates, then the score. It is copied first, so the caller's
          data is never overwritten by the merging.
      :param thres: IoU threshold used both for merging and for the final NMS.
      :return: boxes after NMS, or an empty 1-D array when there is no input.
      """
      # Iterating a numpy array yields row views, and weighted_merge may write
      # into its first argument; work on a private copy of the input.
      polys = np.array(polys, copy=True)
      S = []
      p = None
      for g in polys:
          if p is not None and intersection(g, p) > thres:
              p = weighted_merge(g, p)
          else:
              if p is not None:
                  S.append(p)
              p = g
      if p is not None:
          S.append(p)
      if not S:
          return np.array([])
      return standard_nms(np.array(S), thres)
  if __name__ == '__main__':
      # Sample quad 343,350,448,135,474,143,369,359 as four (x, y) vertices;
      # print its area as a quick Shapely sanity check.
      sample_quad = np.array([[343, 350], [448, 135], [474, 143], [369, 359]])
      sample_area = Polygon(sample_quad).area
      print(sample_area)