drrg_targets.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is adapted from:
  16. https://github.com/open-mmlab/mmocr/blob/main/mmocr/datasets/pipelines/textdet_targets/drrg_targets.py
  17. """
  18. import cv2
  19. import numpy as np
  20. from lanms import merge_quadrangle_n9 as la_nms
  21. from numpy.linalg import norm
  class DRRGTargets(object):
      """Generate ground-truth targets for DRRG text detection.

      From an image and its text polygons this builds the text region mask,
      the effective (non-ignored) mask, the text center region mask, the
      top/bottom height maps, the sin/cos orientation maps and a fixed-size,
      zero-padded array of text component attributes.

      Args:
          orientation_thr (float): Used by ``find_head_tail`` for
              quadrilaterals: when the steeper pair of edges is more than
              ``orientation_thr`` times longer than the flatter pair, the
              flatter edges become head/tail; otherwise the steeper ones do.
          resample_step (float): Approximate spacing in pixels between the
              points resampled along each sideline.
          num_min_comps (int): Minimum number of components; random
              components are appended to reach it.
          num_max_comps (int): Number of rows in the padded component array;
              extra components are truncated.
          min_width (float): Lower clip bound for component widths.
          max_width (float): Upper clip bound for component widths.
          center_region_shrink_ratio (float): Fraction of the distance from
              the center line to each sideline kept in the center region.
          comp_shrink_ratio (float): Scale applied to top/bottom heights when
              building candidate component boxes before NMS.
          comp_w_h_ratio (float): Component width-to-height ratio.
          text_comp_nms_thr (float): Threshold passed to ``la_nms``.
          min_rand_half_height (float): Lower bound for the top/bottom heights
              of random components.
          max_rand_half_height (float): Upper bound for the top/bottom heights
              of random components.
          jitter_level (float): Relative amplitude of the random jitter
              applied to component attributes.
      """

      def __init__(self,
                   orientation_thr=2.0,
                   resample_step=8.0,
                   num_min_comps=9,
                   num_max_comps=600,
                   min_width=8.0,
                   max_width=24.0,
                   center_region_shrink_ratio=0.3,
                   comp_shrink_ratio=1.0,
                   comp_w_h_ratio=0.3,
                   text_comp_nms_thr=0.25,
                   min_rand_half_height=8.0,
                   max_rand_half_height=24.0,
                   jitter_level=0.2,
                   **kwargs):
          super().__init__()
          self.orientation_thr = orientation_thr
          self.resample_step = resample_step
          self.num_max_comps = num_max_comps
          self.num_min_comps = num_min_comps
          self.min_width = min_width
          self.max_width = max_width
          self.center_region_shrink_ratio = center_region_shrink_ratio
          self.comp_shrink_ratio = comp_shrink_ratio
          self.comp_w_h_ratio = comp_w_h_ratio
          self.text_comp_nms_thr = text_comp_nms_thr
          self.min_rand_half_height = min_rand_half_height
          self.max_rand_half_height = max_rand_half_height
          self.jitter_level = jitter_level
          # Small constant guarding divisions by (near) zero vector lengths.
          self.eps = 1e-8

      def vector_angle(self, vec1, vec2):
          """Return the angle(s) in radians between ``vec1`` and ``vec2``.

          Each argument is either a 1-D vector or a 2-D array of row vectors;
          NumPy broadcasting pairs them up.
          """

          def to_unit(vec):
              length = norm(vec, axis=-1) + self.eps
              if vec.ndim > 1:
                  length = length.reshape((-1, 1))
              return vec / length

          cos_angle = np.sum(to_unit(vec1) * to_unit(vec2), axis=-1)
          # Clip guards arccos against rounding slightly outside [-1, 1].
          return np.arccos(np.clip(cos_angle, -1.0, 1.0))

      def vector_slope(self, vec):
          """Return ``|dy / dx|`` of a 2-element vector."""
          assert len(vec) == 2
          return abs(vec[1] / (vec[0] + self.eps))

      def vector_sin(self, vec):
          """Return ``dy / |vec|`` of a 2-element vector."""
          assert len(vec) == 2
          return vec[1] / (norm(vec) + self.eps)

      def vector_cos(self, vec):
          """Return ``dx / |vec|`` of a 2-element vector."""
          assert len(vec) == 2
          return vec[0] / (norm(vec) + self.eps)

      def find_head_tail(self, points, orientation_thr):
          """Find the head and tail edges of a text polygon.

          Edge ``i`` runs from ``points[i]`` to ``points[(i + 1) % N]``.

          Args:
              points (ndarray): Polygon vertices of shape (N, 2), N >= 4.
              orientation_thr (float): Only used when N == 4; see the class
                  docstring.

          Returns:
              tuple(list[int], list[int]): ``[start, end]`` vertex indices of
              the head edge and of the tail edge. The lists are freshly
              created on every call.
          """
          assert points.ndim == 2
          assert points.shape[0] >= 4
          assert points.shape[1] == 2
          assert isinstance(orientation_thr, float)

          if len(points) > 4:
              pad_points = np.vstack([points, points[0]])
              edge_vec = pad_points[1:] - pad_points[:-1]

              # Angle features: the summed angle between each edge and its two
              # neighbours, and the angle between the two neighbours.
              theta_sum = []
              adjacent_vec_theta = []
              for i, edge_vec1 in enumerate(edge_vec):
                  adjacent_ind = [x % len(edge_vec) for x in [i - 1, i + 1]]
                  adjacent_edge_vec = edge_vec[adjacent_ind]
                  temp_theta_sum = np.sum(
                      self.vector_angle(edge_vec1, adjacent_edge_vec))
                  temp_adjacent_theta = self.vector_angle(
                      adjacent_edge_vec[0], adjacent_edge_vec[1])
                  theta_sum.append(temp_theta_sum)
                  adjacent_vec_theta.append(temp_adjacent_theta)
              theta_sum_score = np.array(theta_sum) / np.pi
              adjacent_theta_score = np.array(adjacent_vec_theta) / np.pi

              # Distance feature: farthest endpoint of each edge from the
              # polygon's vertex mean, normalised to [0, 1].
              poly_center = np.mean(points, axis=0)
              edge_dist = np.maximum(
                  norm(pad_points[1:] - poly_center, axis=-1),
                  norm(pad_points[:-1] - poly_center, axis=-1))
              dist_score = edge_dist / (np.max(edge_dist) + self.eps)

              position_score = np.zeros(len(edge_vec))
              score = 0.5 * theta_sum_score + 0.15 * adjacent_theta_score
              score += 0.35 * dist_score
              # For an even vertex count, boost the middle and the closing
              # edge (presumably matching the usual "first half of the
              # vertices is one sideline" annotation layout).
              if len(points) % 2 == 0:
                  position_score[(len(score) // 2 - 1)] += 1
                  position_score[-1] += 1
              score += 0.1 * position_score

              # score_matrix[i, j] scores head edge i paired with tail edge
              # i + 2 + j; the Gaussian favours tails near the opposite side.
              pad_score = np.concatenate([score, score])
              score_matrix = np.zeros((len(score), len(score) - 3))
              x = np.arange(len(score) - 3) / float(len(score) - 4)
              gaussian = 1. / (np.sqrt(2. * np.pi) * 0.5) * np.exp(-np.power(
                  (x - 0.5) / 0.5, 2.) / 2)
              gaussian = gaussian / np.max(gaussian)
              for i in range(len(score)):
                  score_matrix[i, :] = score[i] + pad_score[
                      (i + 2):(i + len(score) - 1)] * gaussian * 0.3

              head_start, tail_increment = np.unravel_index(
                  score_matrix.argmax(), score_matrix.shape)
              tail_start = (head_start + tail_increment + 2) % len(points)
              head_end = (head_start + 1) % len(points)
              tail_end = (tail_start + 1) % len(points)

              if head_end > tail_end:
                  head_start, tail_start = tail_start, head_start
                  head_end, tail_end = tail_end, head_end
              head_inds = [head_start, head_end]
              tail_inds = [tail_start, tail_end]
          else:
              # Quadrilateral: the pair of opposite edges with the smaller
              # total slope is called "horizontal".
              slope_01_23 = (self.vector_slope(points[1] - points[0]) +
                             self.vector_slope(points[3] - points[2]))
              slope_12_30 = (self.vector_slope(points[2] - points[1]) +
                             self.vector_slope(points[0] - points[3]))
              if slope_01_23 < slope_12_30:
                  horizontal_edge_inds = [[0, 1], [2, 3]]
                  vertical_edge_inds = [[3, 0], [1, 2]]
              else:
                  horizontal_edge_inds = [[3, 0], [1, 2]]
                  vertical_edge_inds = [[0, 1], [2, 3]]

              def edge_len(inds):
                  return norm(points[inds[0]] - points[inds[1]])

              vertical_len_sum = (edge_len(vertical_edge_inds[0]) +
                                  edge_len(vertical_edge_inds[1]))
              horizontal_len_sum = (edge_len(horizontal_edge_inds[0]) +
                                    edge_len(horizontal_edge_inds[1]))

              if vertical_len_sum > horizontal_len_sum * orientation_thr:
                  head_inds = horizontal_edge_inds[0]
                  tail_inds = horizontal_edge_inds[1]
              else:
                  head_inds = vertical_edge_inds[0]
                  tail_inds = vertical_edge_inds[1]

          return head_inds, tail_inds

      def reorder_poly_edge(self, points):
          """Split a polygon into head edge, tail edge and two sidelines.

          Args:
              points (ndarray): Polygon vertices of shape (N, 2), N >= 4.

          Returns:
              tuple(ndarray): ``(head_edge, tail_edge, top_sideline,
              bot_sideline)``. One sideline runs from the head's end vertex
              towards the tail; the other runs from the tail's end vertex
              back towards the head. The top sideline is the one with the
              smaller mean y (ties keep the head-to-tail sideline on top).
          """
          assert points.ndim == 2
          assert points.shape[0] >= 4
          assert points.shape[1] == 2

          head_inds, tail_inds = self.find_head_tail(points, self.orientation_thr)
          head_edge, tail_edge = points[head_inds], points[tail_inds]

          pad_points = np.vstack([points, points])
          # A tail edge ending at vertex 0 wraps around; use N instead so the
          # slices below walk forward through the doubled vertex list.
          if tail_inds[1] < 1:
              tail_inds[1] = len(points)
          sideline1 = pad_points[head_inds[1]:tail_inds[1]]
          sideline2 = pad_points[tail_inds[1]:(head_inds[1] + len(points))]

          sideline_mean_shift = (np.mean(sideline1, axis=0) -
                                 np.mean(sideline2, axis=0))
          if sideline_mean_shift[1] > 0:
              top_sideline, bot_sideline = sideline2, sideline1
          else:
              top_sideline, bot_sideline = sideline1, sideline2

          return head_edge, tail_edge, top_sideline, bot_sideline

      def cal_curve_length(self, line):
          """Return per-segment lengths and the total length of a polyline.

          Args:
              line (ndarray): Polyline points of shape (N, 2), N >= 2.

          Returns:
              tuple(ndarray, float): Segment lengths of shape (N - 1,) and
              their sum.
          """
          assert line.ndim == 2
          assert len(line) >= 2

          edges_length = np.sqrt((line[1:, 0] - line[:-1, 0])**2 +
                                 (line[1:, 1] - line[:-1, 1])**2)
          total_length = np.sum(edges_length)
          return edges_length, total_length

      def resample_line(self, line, n):
          """Resample a polyline into ``n`` points equally spaced by arc length.

          The first and last points are kept exactly.

          Args:
              line (ndarray): Polyline points of shape (N, 2), N >= 2.
              n (int): Number of output points, n > 2.

          Returns:
              ndarray: Resampled points of shape (n, 2).
          """
          assert line.ndim == 2
          assert line.shape[0] >= 2
          assert line.shape[1] == 2
          assert isinstance(n, int)
          assert n > 2

          edges_length, total_length = self.cal_curve_length(line)
          # Cumulative arc length at each original vertex.
          t_org = np.insert(np.cumsum(edges_length), 0, 0)
          unit_t = total_length / (n - 1)
          t_equidistant = np.arange(1, n - 1, dtype=np.float32) * unit_t

          edge_ind = 0
          points = [line[0]]
          for t in t_equidistant:
              # Advance to the segment containing arc length t.
              while edge_ind < len(edges_length) - 1 and t > t_org[edge_ind + 1]:
                  edge_ind += 1
              t_l, t_r = t_org[edge_ind], t_org[edge_ind + 1]
              # Linear interpolation between the segment's two endpoints.
              weight = np.array(
                  [t_r - t, t - t_l], dtype=np.float32) / (t_r - t_l + self.eps)
              p_coords = np.dot(weight, line[[edge_ind, edge_ind + 1]])
              points.append(p_coords)
          points.append(line[-1])

          resampled_line = np.vstack(points)
          return resampled_line

      def resample_sidelines(self, sideline1, sideline2, resample_step):
          """Resample two sidelines to the same number of points.

          The point count is ``avg_length / resample_step + 1`` (at least 3),
          where ``avg_length`` is the mean length of the two sidelines.

          Returns:
              tuple(ndarray, ndarray): The two resampled sidelines.
          """
          assert sideline1.ndim == sideline2.ndim == 2
          assert sideline1.shape[1] == sideline2.shape[1] == 2
          assert sideline1.shape[0] >= 2
          assert sideline2.shape[0] >= 2
          assert isinstance(resample_step, float)

          _, length1 = self.cal_curve_length(sideline1)
          _, length2 = self.cal_curve_length(sideline2)

          avg_length = (length1 + length2) / 2
          resample_point_num = max(int(float(avg_length) / resample_step) + 1, 3)

          resampled_line1 = self.resample_line(sideline1, resample_point_num)
          resampled_line2 = self.resample_line(sideline2, resample_point_num)

          return resampled_line1, resampled_line2

      def dist_point2line(self, point, line):
          """Distance from point(s) to the infinite line through two points.

          Args:
              point (ndarray): A point of shape (2,) or points of shape (N, 2).
              line (tuple(ndarray, ndarray)): Two points defining the line.

          Returns:
              float or ndarray: Perpendicular distance(s), one per point.
          """
          assert isinstance(line, tuple)
          point1, point2 = line
          line_vec = point2 - point1
          point_vec = point - point1
          # Scalar 2-D cross product written out explicitly: np.cross on
          # 2-element vectors is deprecated since NumPy 2.0.
          cross = (line_vec[0] * point_vec[..., 1] -
                   line_vec[1] * point_vec[..., 0])
          return np.abs(cross) / (norm(line_vec) + self.eps)

      def draw_center_region_maps(self, top_line, bot_line, center_line,
                                  center_region_mask, top_height_map,
                                  bot_height_map, sin_map, cos_map,
                                  region_shrink_ratio):
          """Draw one text instance's center region and geometry maps in place.

          For every segment ``i -> i + 1`` of the center line, a quadrilateral
          is formed by moving from the center line towards each sideline by
          ``region_shrink_ratio`` of the distance. Inside it:
          ``center_region_mask`` is set to 1, ``sin_map``/``cos_map`` get the
          sin/cos of the (top mid point - bottom mid point) direction, and
          ``top_height_map``/``bot_height_map`` get each pixel's distance to
          the corresponding top/bottom sideline segment.

          Args:
              top_line (ndarray): Resampled top sideline.
              bot_line (ndarray): Resampled bottom sideline, same shape.
              center_line (ndarray): Center line, same shape.
              center_region_mask (ndarray): Mask written in place.
              top_height_map (ndarray): Map written in place.
              bot_height_map (ndarray): Map written in place.
              sin_map (ndarray): Map written in place.
              cos_map (ndarray): Map written in place.
              region_shrink_ratio (float): Shrink ratio towards the center line.
          """
          assert top_line.shape == bot_line.shape == center_line.shape
          assert (center_region_mask.shape == top_height_map.shape ==
                  bot_height_map.shape == sin_map.shape == cos_map.shape)
          assert isinstance(region_shrink_ratio, float)

          h, w = center_region_mask.shape
          for i in range(0, len(center_line) - 1):
              top_mid_point = (top_line[i] + top_line[i + 1]) / 2
              bot_mid_point = (bot_line[i] + bot_line[i + 1]) / 2

              sin_theta = self.vector_sin(top_mid_point - bot_mid_point)
              cos_theta = self.vector_cos(top_mid_point - bot_mid_point)

              tl = center_line[i] + (top_line[i] -
                                     center_line[i]) * region_shrink_ratio
              tr = center_line[i + 1] + (
                  top_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
              br = center_line[i + 1] + (
                  bot_line[i + 1] - center_line[i + 1]) * region_shrink_ratio
              bl = center_line[i] + (bot_line[i] -
                                     center_line[i]) * region_shrink_ratio
              current_center_box = np.vstack([tl, tr, br, bl]).astype(np.int32)

              cv2.fillPoly(center_region_mask, [current_center_box], color=1)
              cv2.fillPoly(sin_map, [current_center_box], color=sin_theta)
              cv2.fillPoly(cos_map, [current_center_box], color=cos_theta)

              # Rasterise the (clipped) box into a local mask so that the
              # height maps are only evaluated on the pixels it covers.
              current_center_box[:, 0] = np.clip(current_center_box[:, 0], 0,
                                                 w - 1)
              current_center_box[:, 1] = np.clip(current_center_box[:, 1], 0,
                                                 h - 1)
              min_coord = np.min(current_center_box, axis=0).astype(np.int32)
              max_coord = np.max(current_center_box, axis=0).astype(np.int32)
              current_center_box = current_center_box - min_coord
              box_sz = (max_coord - min_coord + 1)

              center_box_mask = np.zeros((box_sz[1], box_sz[0]), dtype=np.uint8)
              cv2.fillPoly(center_box_mask, [current_center_box], color=1)

              inds = np.argwhere(center_box_mask > 0)
              inds = inds + (min_coord[1], min_coord[0])
              inds_xy = np.fliplr(inds)
              top_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
                  inds_xy, (top_line[i], top_line[i + 1]))
              bot_height_map[(inds[:, 0], inds[:, 1])] = self.dist_point2line(
                  inds_xy, (bot_line[i], bot_line[i + 1]))

      def generate_center_mask_attrib_maps(self, img_size, text_polys):
          """Generate center lines, the center region mask and geometry maps.

          Args:
              img_size (tuple): The image size (height, width).
              text_polys (list[ndarray]): Text polygons, each of shape (N, 2).

          Returns:
              tuple: ``(center_lines, center_region_mask, top_height_map,
              bot_height_map, sin_map, cos_map)``. ``center_lines`` is a list
              of int32 point arrays; the mask is uint8 and the maps are
              float32, all of shape (height, width).
          """
          assert isinstance(img_size, tuple)
          h, w = img_size

          center_lines = []
          center_region_mask = np.zeros((h, w), np.uint8)
          top_height_map = np.zeros((h, w), dtype=np.float32)
          bot_height_map = np.zeros((h, w), dtype=np.float32)
          sin_map = np.zeros((h, w), dtype=np.float32)
          cos_map = np.zeros((h, w), dtype=np.float32)

          for poly in text_polys:
              _, _, top_line, bot_line = self.reorder_poly_edge(poly)
              resampled_top_line, resampled_bot_line = self.resample_sidelines(
                  top_line, bot_line, self.resample_step)
              # The sidelines are traversed in opposite directions around the
              # polygon; reverse one so both run the same way.
              resampled_bot_line = resampled_bot_line[::-1]
              center_line = (resampled_top_line + resampled_bot_line) / 2

              # Normalise direction: steep lines run top-to-bottom (increasing
              # y), other lines run left-to-right (increasing x).
              if self.vector_slope(center_line[-1] - center_line[0]) > 2:
                  if (center_line[-1] - center_line[0])[1] < 0:
                      center_line = center_line[::-1]
                      resampled_top_line = resampled_top_line[::-1]
                      resampled_bot_line = resampled_bot_line[::-1]
              else:
                  if (center_line[-1] - center_line[0])[0] < 0:
                      center_line = center_line[::-1]
                      resampled_top_line = resampled_top_line[::-1]
                      resampled_bot_line = resampled_bot_line[::-1]

              # NOTE(review): one sideline starts at the head and the other at
              # the tail (see reorder_poly_edge), so top_line[0] - bot_line[0]
              # (and likewise [-1]) spans the polygon diagonally rather than
              # the head/tail edge. Kept unchanged to match the upstream
              # mmocr implementation — confirm whether that is intended.
              line_head_shrink_len = np.clip(
                  (norm(top_line[0] - bot_line[0]) * self.comp_w_h_ratio),
                  self.min_width, self.max_width) / 2
              line_tail_shrink_len = np.clip(
                  (norm(top_line[-1] - bot_line[-1]) * self.comp_w_h_ratio),
                  self.min_width, self.max_width) / 2
              num_head_shrink = int(line_head_shrink_len // self.resample_step)
              num_tail_shrink = int(line_tail_shrink_len // self.resample_step)

              # Trim points from both ends only when enough points remain.
              if len(center_line) > num_head_shrink + num_tail_shrink + 2:
                  center_line = center_line[num_head_shrink:len(center_line) -
                                            num_tail_shrink]
                  resampled_top_line = resampled_top_line[num_head_shrink:len(
                      resampled_top_line) - num_tail_shrink]
                  resampled_bot_line = resampled_bot_line[num_head_shrink:len(
                      resampled_bot_line) - num_tail_shrink]
                  center_lines.append(center_line.astype(np.int32))

              self.draw_center_region_maps(
                  resampled_top_line, resampled_bot_line, center_line,
                  center_region_mask, top_height_map, bot_height_map, sin_map,
                  cos_map, self.center_region_shrink_ratio)

          return (center_lines, center_region_mask, top_height_map,
                  bot_height_map, sin_map, cos_map)

      def generate_rand_comp_attribs(self, num_rand_comps, center_sample_mask):
          """Generate random text component attributes.

          Args:
              num_rand_comps (int): Number of random components, > 0.
              center_sample_mask (ndarray): 2-D mask of pixels where random
                  component centers may be sampled.

          Returns:
              ndarray: float32 array of shape (num_rand_comps, 7) with columns
              (x, y, height, width, cos, sin, comp_label=0).

          Raises:
              ValueError: If no valid center remains after the border margin
                  and erosion are applied.
          """
          assert isinstance(num_rand_comps, int)
          assert num_rand_comps > 0
          assert center_sample_mask.ndim == 2

          h, w = center_sample_mask.shape

          max_rand_half_height = self.max_rand_half_height
          min_rand_half_height = self.min_rand_half_height
          max_rand_height = max_rand_half_height * 2
          max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
                                   self.min_width, self.max_width)
          # Half-diagonal of the largest random box, plus one pixel; centers
          # are only sampled at least this far from the image border.
          margin = int(
              np.sqrt((max_rand_height / 2)**2 + (max_rand_width / 2)**2)) + 1

          if 2 * margin + 1 > min(h, w):
              # The image is too small for the configured sizes: shrink them.
              assert min(h, w) > (np.sqrt(2) * (self.min_width + 1))
              max_rand_half_height = max(min(h, w) / 4, self.min_width / 2 + 1)
              min_rand_half_height = max(max_rand_half_height / 4,
                                         self.min_width / 2)

              max_rand_height = max_rand_half_height * 2
              max_rand_width = np.clip(max_rand_height * self.comp_w_h_ratio,
                                       self.min_width, self.max_width)
              margin = int(
                  np.sqrt((max_rand_height / 2)**2 +
                          (max_rand_width / 2)**2)) + 1

          inner_center_sample_mask = np.zeros_like(center_sample_mask)
          inner_center_sample_mask[margin:h - margin, margin:w - margin] = \
              center_sample_mask[margin:h - margin, margin:w - margin]
          kernel_size = int(np.clip(max_rand_half_height, 7, 21))
          inner_center_sample_mask = cv2.erode(
              inner_center_sample_mask,
              np.ones((kernel_size, kernel_size), np.uint8))

          center_candidates = np.argwhere(inner_center_sample_mask > 0)
          num_center_candidates = len(center_candidates)
          if num_center_candidates == 0:
              raise ValueError(
                  'No valid center for random text components: the sample '
                  'region is empty after applying the border margin and '
                  'erosion.')
          sample_inds = np.random.choice(num_center_candidates, num_rand_comps)
          rand_centers = center_candidates[sample_inds]

          rand_top_height = np.random.randint(
              min_rand_half_height,
              max_rand_half_height,
              size=(len(rand_centers), 1))
          rand_bot_height = np.random.randint(
              min_rand_half_height,
              max_rand_half_height,
              size=(len(rand_centers), 1))

          # Random direction, normalised to a unit (cos, sin) pair.
          rand_cos = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
          rand_sin = 2 * np.random.random(size=(len(rand_centers), 1)) - 1
          scale = np.sqrt(1.0 / (rand_cos**2 + rand_sin**2 + 1e-8))
          rand_cos = rand_cos * scale
          rand_sin = rand_sin * scale

          height = (rand_top_height + rand_bot_height)
          width = np.clip(height * self.comp_w_h_ratio, self.min_width,
                          self.max_width)

          # Centers come from argwhere as (y, x); flip to (x, y).
          rand_comp_attribs = np.hstack([
              rand_centers[:, ::-1], height, width, rand_cos, rand_sin,
              np.zeros_like(rand_sin)
          ]).astype(np.float32)

          return rand_comp_attribs

      def jitter_comp_attribs(self, comp_attribs, jitter_level):
          """Jitter text component attributes.

          The input array is not modified.

          Args:
              comp_attribs (ndarray): Component attributes of shape (N, 7),
                  N > 0, columns (x, y, h, w, cos, sin, comp_label).
              jitter_level (float): The jitter level of the attributes.

          Returns:
              ndarray: The jittered attributes (x, y, h, w, cos, sin,
              comp_label); (cos, sin) is renormalised to unit length and the
              labels are unchanged.
          """
          assert comp_attribs.shape[1] == 7
          assert comp_attribs.shape[0] > 0
          assert isinstance(jitter_level, float)

          # The column reshapes below are views, so the in-place updates would
          # otherwise write straight into the caller's array.
          comp_attribs = comp_attribs.copy()

          x = comp_attribs[:, 0].reshape((-1, 1))
          y = comp_attribs[:, 1].reshape((-1, 1))
          h = comp_attribs[:, 2].reshape((-1, 1))
          w = comp_attribs[:, 3].reshape((-1, 1))
          cos = comp_attribs[:, 4].reshape((-1, 1))
          sin = comp_attribs[:, 5].reshape((-1, 1))
          comp_labels = comp_attribs[:, 6].reshape((-1, 1))

          x += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
              h * np.abs(cos) + w * np.abs(sin)) * jitter_level
          y += (np.random.random(size=(len(comp_attribs), 1)) - 0.5) * (
              h * np.abs(sin) + w * np.abs(cos)) * jitter_level

          h += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
                ) * h * jitter_level
          w += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
                ) * w * jitter_level

          cos += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
                  ) * 2 * jitter_level
          sin += (np.random.random(size=(len(comp_attribs), 1)) - 0.5
                  ) * 2 * jitter_level

          scale = np.sqrt(1.0 / (cos**2 + sin**2 + 1e-8))
          cos = cos * scale
          sin = sin * scale

          jittered_comp_attribs = np.hstack([x, y, h, w, cos, sin, comp_labels])

          return jittered_comp_attribs

      def generate_comp_attribs(self, center_lines, text_mask, center_region_mask,
                                top_height_map, bot_height_map, sin_map,
                                cos_map):
          """Generate text component attributes.

          Args:
              center_lines (list[ndarray]): The list of text center lines.
              text_mask (ndarray): The text region mask.
              center_region_mask (ndarray): The text center region mask.
              top_height_map (ndarray): The map on which the distance from
                  points to top side lines will be drawn for each pixel in
                  text center regions.
              bot_height_map (ndarray): The map on which the distance from
                  points to bottom side lines will be drawn for each pixel in
                  text center regions.
              sin_map (ndarray): The sin(theta) map where theta is the angle
                  between vector (top point - bottom point) and vector (1, 0).
              cos_map (ndarray): The cos(theta) map where theta is the angle
                  between vector (top point - bottom point) and vector (1, 0).

          Returns:
              ndarray: float32 array of shape (num_max_comps, 8). Each valid
              row is (num_comps, x, y, h, w, cos, sin, comp_label); remaining
              rows are zero padding.
          """
          assert isinstance(center_lines, list)
          assert (text_mask.shape == center_region_mask.shape ==
                  top_height_map.shape == bot_height_map.shape ==
                  sin_map.shape == cos_map.shape)

          # Rasterise the (open) center lines restricted to the center region;
          # every remaining pixel is a candidate component center.
          center_lines_mask = np.zeros_like(center_region_mask)
          cv2.polylines(center_lines_mask, center_lines, 0, 1, 1)
          center_lines_mask = center_lines_mask * center_region_mask
          comp_centers = np.argwhere(center_lines_mask > 0)

          y = comp_centers[:, 0]
          x = comp_centers[:, 1]

          top_height = top_height_map[y, x].reshape(
              (-1, 1)) * self.comp_shrink_ratio
          bot_height = bot_height_map[y, x].reshape(
              (-1, 1)) * self.comp_shrink_ratio
          sin = sin_map[y, x].reshape((-1, 1))
          cos = cos_map[y, x].reshape((-1, 1))

          # comp_centers are (y, x), so the offsets are (h * sin, h * cos).
          top_mid_points = comp_centers + np.hstack(
              [top_height * sin, top_height * cos])
          bot_mid_points = comp_centers - np.hstack(
              [bot_height * sin, bot_height * cos])

          width = (top_height + bot_height) * self.comp_w_h_ratio
          width = np.clip(width, self.min_width, self.max_width)
          r = width / 2

          # Corners in (x, y) order: flip the mid points, then step +/- r
          # perpendicular to the (cos, sin) direction.
          half_width_vec = np.hstack([-r * sin, r * cos])
          tl = top_mid_points[:, ::-1] - half_width_vec
          tr = top_mid_points[:, ::-1] + half_width_vec
          br = bot_mid_points[:, ::-1] + half_width_vec
          bl = bot_mid_points[:, ::-1] - half_width_vec
          text_comps = np.hstack([tl, tr, br, bl]).astype(np.float32)

          # Each row: 8 corner coordinates plus a score, merged by la_nms.
          score = np.ones((text_comps.shape[0], 1), dtype=np.float32)
          text_comps = np.hstack([text_comps, score])
          text_comps = la_nms(text_comps, self.text_comp_nms_thr)

          if text_comps.shape[0] >= 1:
              img_h, img_w = center_region_mask.shape
              text_comps[:, 0:8:2] = np.clip(text_comps[:, 0:8:2], 0, img_w - 1)
              text_comps[:, 1:8:2] = np.clip(text_comps[:, 1:8:2], 0, img_h - 1)

              comp_centers = np.mean(
                  text_comps[:, 0:8].reshape((-1, 4, 2)),
                  axis=1).astype(np.int32)
              x = comp_centers[:, 0]
              y = comp_centers[:, 1]

              height = (top_height_map[y, x] + bot_height_map[y, x]).reshape(
                  (-1, 1))
              width = np.clip(height * self.comp_w_h_ratio, self.min_width,
                              self.max_width)

              cos = cos_map[y, x].reshape((-1, 1))
              sin = sin_map[y, x].reshape((-1, 1))

              # Components lying in the same connected center region share a
              # label (0 is the background label).
              _, comp_label_mask = cv2.connectedComponents(
                  center_region_mask, connectivity=8)
              comp_labels = comp_label_mask[y, x].reshape(
                  (-1, 1)).astype(np.float32)

              x = x.reshape((-1, 1)).astype(np.float32)
              y = y.reshape((-1, 1)).astype(np.float32)
              comp_attribs = np.hstack(
                  [x, y, height, width, cos, sin, comp_labels])
              comp_attribs = self.jitter_comp_attribs(comp_attribs,
                                                      self.jitter_level)

              if comp_attribs.shape[0] < self.num_min_comps:
                  num_rand_comps = self.num_min_comps - comp_attribs.shape[0]
                  rand_comp_attribs = self.generate_rand_comp_attribs(
                      num_rand_comps, 1 - text_mask)
                  comp_attribs = np.vstack([comp_attribs, rand_comp_attribs])
          else:
              comp_attribs = self.generate_rand_comp_attribs(self.num_min_comps,
                                                             1 - text_mask)

          # First column: number of valid rows (capped at num_max_comps below).
          num_comps = (np.ones(
              (comp_attribs.shape[0], 1),
              dtype=np.float32) * comp_attribs.shape[0])
          comp_attribs = np.hstack([num_comps, comp_attribs])

          if comp_attribs.shape[0] > self.num_max_comps:
              comp_attribs = comp_attribs[:self.num_max_comps, :]
              comp_attribs[:, 0] = self.num_max_comps

          pad_comp_attribs = np.zeros(
              (self.num_max_comps, comp_attribs.shape[1]), dtype=np.float32)
          pad_comp_attribs[:comp_attribs.shape[0], :] = comp_attribs

          return pad_comp_attribs

      def generate_text_region_mask(self, img_size, text_polys):
          """Generate the text region mask.

          Args:
              img_size (tuple): The image size (height, width).
              text_polys (list): Text polygons, each convertible to an int32
                  array of (x, y) points.

          Returns:
              ndarray: uint8 mask of shape (height, width); 1 inside any
              polygon, 0 elsewhere.
          """
          assert isinstance(img_size, tuple)

          h, w = img_size
          text_region_mask = np.zeros((h, w), dtype=np.uint8)

          for poly in text_polys:
              polygon = np.array(poly, dtype=np.int32).reshape((1, -1, 2))
              cv2.fillPoly(text_region_mask, polygon, 1)

          return text_region_mask

      def generate_effective_mask(self, mask_size: tuple, polygons_ignore):
          """Generate the effective mask: ignored regions 0, everything else 1.

          Args:
              mask_size (tuple): The mask size.
              polygons_ignore (list[ndarray]): The ignored text polygons.

          Returns:
              ndarray: uint8 effective mask of shape ``mask_size``.
          """
          mask = np.ones(mask_size, dtype=np.uint8)

          for poly in polygons_ignore:
              instance = poly.astype(np.int32).reshape(1, -1, 2)
              cv2.fillPoly(mask, instance, 0)

          return mask

      @staticmethod
      def _split_ignored_polygons(polygons, ignore_tags):
          """Partition polygons into ``(kept, ignored)`` lists by their tags.

          Tags are tested for truthiness rather than identity with ``True``
          so that NumPy booleans (``np.bool_``) are honoured as well.
          """
          kept, ignored = [], []
          for tag, polygon in zip(ignore_tags, polygons):
              if tag:
                  ignored.append(polygon)
              else:
                  kept.append(polygon)
          return kept, ignored

      def generate_targets(self, data):
          """Generate the gt targets for DRRG.

          Args:
              data (dict): Sample dict; reads ``'image'`` (its first two dims
                  are taken as height and width), ``'polys'`` and
                  ``'ignore_tags'`` (one tag per polygon; truthy = ignored).

          Returns:
              dict: The same dict, updated in place with ``gt_text_mask``,
              ``gt_center_region_mask``, ``gt_mask``, ``gt_top_height_map``,
              ``gt_bot_height_map``, ``gt_sin_map``, ``gt_cos_map`` and
              ``gt_comp_attribs``.
          """
          assert isinstance(data, dict)

          image = data['image']
          polygons = data['polys']
          ignore_tags = data['ignore_tags']
          # Only height and width are needed; this also accepts 2-D images.
          h, w = image.shape[:2]

          polygon_masks, polygon_masks_ignore = self._split_ignored_polygons(
              polygons, ignore_tags)

          gt_text_mask = self.generate_text_region_mask((h, w), polygon_masks)
          gt_mask = self.generate_effective_mask((h, w), polygon_masks_ignore)
          (center_lines, gt_center_region_mask, gt_top_height_map,
           gt_bot_height_map, gt_sin_map,
           gt_cos_map) = self.generate_center_mask_attrib_maps((h, w),
                                                               polygon_masks)

          gt_comp_attribs = self.generate_comp_attribs(
              center_lines, gt_text_mask, gt_center_region_mask,
              gt_top_height_map, gt_bot_height_map, gt_sin_map, gt_cos_map)

          mapping = {
              'gt_text_mask': gt_text_mask,
              'gt_center_region_mask': gt_center_region_mask,
              'gt_mask': gt_mask,
              'gt_top_height_map': gt_top_height_map,
              'gt_bot_height_map': gt_bot_height_map,
              'gt_sin_map': gt_sin_map,
              'gt_cos_map': gt_cos_map
          }

          data.update(mapping)
          data['gt_comp_attribs'] = gt_comp_attribs
          return data

      def __call__(self, data):
          data = self.generate_targets(data)
          return data