pse.pyx 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import numpy as np
  2. import cv2
  3. cimport numpy as np
  4. cimport cython
  5. cimport libcpp
  6. cimport libcpp.pair
  7. cimport libcpp.queue
  8. from libcpp.pair cimport *
  9. from libcpp.queue cimport *
  10. @cython.boundscheck(False)
  11. @cython.wraparound(False)
  12. cdef np.ndarray[np.int32_t, ndim=2] _pse(np.ndarray[np.uint8_t, ndim=3] kernels,
  13. np.ndarray[np.int32_t, ndim=2] label,
  14. int kernel_num,
  15. int label_num,
  16. float min_area=0):
  17. cdef np.ndarray[np.int32_t, ndim=2] pred
  18. pred = np.zeros((label.shape[0], label.shape[1]), dtype=np.int32)
  19. for label_idx in range(1, label_num):
  20. if np.sum(label == label_idx) < min_area:
  21. label[label == label_idx] = 0
  22. cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] que = \
  23. queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
  24. cdef libcpp.queue.queue[libcpp.pair.pair[np.int16_t,np.int16_t]] nxt_que = \
  25. queue[libcpp.pair.pair[np.int16_t,np.int16_t]]()
  26. cdef np.int16_t* dx = [-1, 1, 0, 0]
  27. cdef np.int16_t* dy = [0, 0, -1, 1]
  28. cdef np.int16_t tmpx, tmpy
  29. points = np.array(np.where(label > 0)).transpose((1, 0))
  30. for point_idx in range(points.shape[0]):
  31. tmpx, tmpy = points[point_idx, 0], points[point_idx, 1]
  32. que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
  33. pred[tmpx, tmpy] = label[tmpx, tmpy]
  34. cdef libcpp.pair.pair[np.int16_t,np.int16_t] cur
  35. cdef int cur_label
  36. for kernel_idx in range(kernel_num - 1, -1, -1):
  37. while not que.empty():
  38. cur = que.front()
  39. que.pop()
  40. cur_label = pred[cur.first, cur.second]
  41. is_edge = True
  42. for j in range(4):
  43. tmpx = cur.first + dx[j]
  44. tmpy = cur.second + dy[j]
  45. if tmpx < 0 or tmpx >= label.shape[0] or tmpy < 0 or tmpy >= label.shape[1]:
  46. continue
  47. if kernels[kernel_idx, tmpx, tmpy] == 0 or pred[tmpx, tmpy] > 0:
  48. continue
  49. que.push(pair[np.int16_t,np.int16_t](tmpx, tmpy))
  50. pred[tmpx, tmpy] = cur_label
  51. is_edge = False
  52. if is_edge:
  53. nxt_que.push(cur)
  54. que, nxt_que = nxt_que, que
  55. return pred
  56. def pse(kernels, min_area):
  57. kernel_num = kernels.shape[0]
  58. label_num, label = cv2.connectedComponents(kernels[-1], connectivity=4)
  59. return _pse(kernels[:-1], label, kernel_num, label_num, min_area)