pg_process.py 40 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import cv2
  16. import numpy as np
  17. from skimage.morphology._skeletonize import thin
  18. from ppocr.utils.e2e_utils.extract_textpoint_fast import sort_and_expand_with_direction_v2
  19. __all__ = ['PGProcessTrain']
  20. class PGProcessTrain(object):
  21. def __init__(self,
  22. character_dict_path,
  23. max_text_length,
  24. max_text_nums,
  25. tcl_len,
  26. batch_size=14,
  27. use_resize=True,
  28. use_random_crop=False,
  29. min_crop_size=24,
  30. min_text_size=4,
  31. max_text_size=512,
  32. point_gather_mode=None,
  33. **kwargs):
  34. self.tcl_len = tcl_len
  35. self.max_text_length = max_text_length
  36. self.max_text_nums = max_text_nums
  37. self.batch_size = batch_size
  38. if use_random_crop is True:
  39. self.min_crop_size = min_crop_size
  40. self.use_random_crop = use_random_crop
  41. self.min_text_size = min_text_size
  42. self.max_text_size = max_text_size
  43. self.use_resize = use_resize
  44. self.point_gather_mode = point_gather_mode
  45. self.Lexicon_Table = self.get_dict(character_dict_path)
  46. self.pad_num = len(self.Lexicon_Table)
  47. self.img_id = 0
  48. def get_dict(self, character_dict_path):
  49. character_str = ""
  50. with open(character_dict_path, "rb") as fin:
  51. lines = fin.readlines()
  52. for line in lines:
  53. line = line.decode('utf-8').strip("\n").strip("\r\n")
  54. character_str += line
  55. dict_character = list(character_str)
  56. return dict_character
  57. def quad_area(self, poly):
  58. """
  59. compute area of a polygon
  60. :param poly:
  61. :return:
  62. """
  63. edge = [(poly[1][0] - poly[0][0]) * (poly[1][1] + poly[0][1]),
  64. (poly[2][0] - poly[1][0]) * (poly[2][1] + poly[1][1]),
  65. (poly[3][0] - poly[2][0]) * (poly[3][1] + poly[2][1]),
  66. (poly[0][0] - poly[3][0]) * (poly[0][1] + poly[3][1])]
  67. return np.sum(edge) / 2.
  68. def gen_quad_from_poly(self, poly):
  69. """
  70. Generate min area quad from poly.
  71. """
  72. point_num = poly.shape[0]
  73. min_area_quad = np.zeros((4, 2), dtype=np.float32)
  74. rect = cv2.minAreaRect(poly.astype(
  75. np.int32)) # (center (x,y), (width, height), angle of rotation)
  76. box = np.array(cv2.boxPoints(rect))
  77. first_point_idx = 0
  78. min_dist = 1e4
  79. for i in range(4):
  80. dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
  81. np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
  82. np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
  83. np.linalg.norm(box[(i + 3) % 4] - poly[-1])
  84. if dist < min_dist:
  85. min_dist = dist
  86. first_point_idx = i
  87. for i in range(4):
  88. min_area_quad[i] = box[(first_point_idx + i) % 4]
  89. return min_area_quad
  90. def check_and_validate_polys(self, polys, tags, im_size):
  91. """
  92. check so that the text poly is in the same direction,
  93. and also filter some invalid polygons
  94. :param polys:
  95. :param tags:
  96. :return:
  97. """
  98. (h, w) = im_size
  99. if polys.shape[0] == 0:
  100. return polys, np.array([]), np.array([])
  101. polys[:, :, 0] = np.clip(polys[:, :, 0], 0, w - 1)
  102. polys[:, :, 1] = np.clip(polys[:, :, 1], 0, h - 1)
  103. validated_polys = []
  104. validated_tags = []
  105. hv_tags = []
  106. for poly, tag in zip(polys, tags):
  107. quad = self.gen_quad_from_poly(poly)
  108. p_area = self.quad_area(quad)
  109. if abs(p_area) < 1:
  110. print('invalid poly')
  111. continue
  112. if p_area > 0:
  113. if tag == False:
  114. print('poly in wrong direction')
  115. tag = True # reversed cases should be ignore
  116. poly = poly[(0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2,
  117. 1), :]
  118. quad = quad[(0, 3, 2, 1), :]
  119. len_w = np.linalg.norm(quad[0] - quad[1]) + np.linalg.norm(quad[3] -
  120. quad[2])
  121. len_h = np.linalg.norm(quad[0] - quad[3]) + np.linalg.norm(quad[1] -
  122. quad[2])
  123. hv_tag = 1
  124. if len_w * 2.0 < len_h:
  125. hv_tag = 0
  126. validated_polys.append(poly)
  127. validated_tags.append(tag)
  128. hv_tags.append(hv_tag)
  129. return np.array(validated_polys), np.array(validated_tags), np.array(
  130. hv_tags)
  131. def crop_area(self,
  132. im,
  133. polys,
  134. tags,
  135. hv_tags,
  136. txts,
  137. crop_background=False,
  138. max_tries=25):
  139. """
  140. make random crop from the input image
  141. :param im:
  142. :param polys: [b,4,2]
  143. :param tags:
  144. :param crop_background:
  145. :param max_tries: 50 -> 25
  146. :return:
  147. """
  148. h, w, _ = im.shape
  149. pad_h = h // 10
  150. pad_w = w // 10
  151. h_array = np.zeros((h + pad_h * 2), dtype=np.int32)
  152. w_array = np.zeros((w + pad_w * 2), dtype=np.int32)
  153. for poly in polys:
  154. poly = np.round(poly, decimals=0).astype(np.int32)
  155. minx = np.min(poly[:, 0])
  156. maxx = np.max(poly[:, 0])
  157. w_array[minx + pad_w:maxx + pad_w] = 1
  158. miny = np.min(poly[:, 1])
  159. maxy = np.max(poly[:, 1])
  160. h_array[miny + pad_h:maxy + pad_h] = 1
  161. # ensure the cropped area not across a text
  162. h_axis = np.where(h_array == 0)[0]
  163. w_axis = np.where(w_array == 0)[0]
  164. if len(h_axis) == 0 or len(w_axis) == 0:
  165. return im, polys, tags, hv_tags, txts
  166. for i in range(max_tries):
  167. xx = np.random.choice(w_axis, size=2)
  168. xmin = np.min(xx) - pad_w
  169. xmax = np.max(xx) - pad_w
  170. xmin = np.clip(xmin, 0, w - 1)
  171. xmax = np.clip(xmax, 0, w - 1)
  172. yy = np.random.choice(h_axis, size=2)
  173. ymin = np.min(yy) - pad_h
  174. ymax = np.max(yy) - pad_h
  175. ymin = np.clip(ymin, 0, h - 1)
  176. ymax = np.clip(ymax, 0, h - 1)
  177. if xmax - xmin < self.min_crop_size or \
  178. ymax - ymin < self.min_crop_size:
  179. continue
  180. if polys.shape[0] != 0:
  181. poly_axis_in_area = (polys[:, :, 0] >= xmin) & (polys[:, :, 0] <= xmax) \
  182. & (polys[:, :, 1] >= ymin) & (polys[:, :, 1] <= ymax)
  183. selected_polys = np.where(
  184. np.sum(poly_axis_in_area, axis=1) == 4)[0]
  185. else:
  186. selected_polys = []
  187. if len(selected_polys) == 0:
  188. # no text in this area
  189. if crop_background:
  190. txts_tmp = []
  191. for selected_poly in selected_polys:
  192. txts_tmp.append(txts[selected_poly])
  193. txts = txts_tmp
  194. return im[ymin: ymax + 1, xmin: xmax + 1, :], \
  195. polys[selected_polys], tags[selected_polys], hv_tags[selected_polys], txts
  196. else:
  197. continue
  198. im = im[ymin:ymax + 1, xmin:xmax + 1, :]
  199. polys = polys[selected_polys]
  200. tags = tags[selected_polys]
  201. hv_tags = hv_tags[selected_polys]
  202. txts_tmp = []
  203. for selected_poly in selected_polys:
  204. txts_tmp.append(txts[selected_poly])
  205. txts = txts_tmp
  206. polys[:, :, 0] -= xmin
  207. polys[:, :, 1] -= ymin
  208. return im, polys, tags, hv_tags, txts
  209. return im, polys, tags, hv_tags, txts
    def fit_and_gather_tcl_points_v2(self,
                                     min_area_quad,
                                     poly,
                                     max_h,
                                     max_w,
                                     fixed_point_num=64,
                                     img_id=0,
                                     reference_height=3):
        """
        Find the center point of poly as key_points, then fit and gather.
        The center line is rasterised, its pixels are ordered along a
        direction derived from min_area_quad, sub-sampled to at most
        fixed_point_num points, and padded to self.tcl_len rows.
        :param min_area_quad: (4, 2) quad used to derive the sort direction
        :param poly: (N, 2) polygon; point idx is paired with point N-1-idx
        :param max_h: height of the raster / upper clip bound for y
        :param max_w: width of the raster / upper clip bound for x
        :param fixed_point_num: maximum number of points kept
        :param img_id: value written into column 0 of every output row
        :param reference_height: scales the random y jitter; jitter is only
            applied when it is >= 3
        :return: pos_l, int32 (tcl_len, 3) rows of (img_id, y, x), and
            pos_m, float32 (tcl_len, 1) mask that is 1.0 for real points
        """
        key_point_xys = []
        point_num = poly.shape[0]
        # midpoint of every opposite point pair of the polygon
        for idx in range(point_num // 2):
            center_point = (poly[idx] + poly[point_num - 1 - idx]) / 2.0
            key_point_xys.append(center_point)

        # draw the center line and collect every pixel it covers
        tmp_image = np.zeros(
            shape=(
                max_h,
                max_w, ), dtype='float32')
        cv2.polylines(tmp_image, [np.array(key_point_xys).astype('int32')],
                      False, 1.0)
        ys, xs = np.where(tmp_image > 0)
        xy_text = np.array(list(zip(xs, ys)), dtype='float32')

        # NOTE(review): these are half-differences of corner pairs, not the
        # midpoints of the left/right edges ((q0 + q3) / 2, (q1 + q2) / 2)
        # that the names suggest - confirm this is intended.
        left_center_pt = (
            (min_area_quad[0] - min_area_quad[1]) / 2.0).reshape(1, 2)
        right_center_pt = (
            (min_area_quad[1] - min_area_quad[2]) / 2.0).reshape(1, 2)
        proj_unit_vec = (right_center_pt - left_center_pt) / (
            np.linalg.norm(right_center_pt - left_center_pt) + 1e-6)
        proj_unit_vec_tile = np.tile(proj_unit_vec,
                                     (xy_text.shape[0], 1))  # (n, 2)
        left_center_pt_tile = np.tile(left_center_pt,
                                      (xy_text.shape[0], 1))  # (n, 2)
        xy_text_to_left_center = xy_text - left_center_pt_tile
        # order the pixels by their projection onto proj_unit_vec
        proj_value = np.sum(xy_text_to_left_center * proj_unit_vec_tile, axis=1)
        xy_text = xy_text[np.argsort(proj_value)]

        # convert to np and keep the num of point not greater then fixed_point_num
        pos_info = np.array(xy_text).reshape(-1, 2)[:, ::-1]  # xy-> yx
        point_num = len(pos_info)
        if point_num > fixed_point_num:
            # evenly spaced sub-sampling down to fixed_point_num points
            keep_ids = [
                int((point_num * 1.0 / fixed_point_num) * x)
                for x in range(fixed_point_num)
            ]
            pos_info = pos_info[keep_ids, :]

        keep = int(min(len(pos_info), fixed_point_num))
        # with probability 0.2, jitter the y coordinate (column 0) only
        if np.random.rand() < 0.2 and reference_height >= 3:
            dl = (np.random.rand(keep) - 0.5) * reference_height * 0.3
            random_float = np.array([1, 0]).reshape([1, 2]) * dl.reshape(
                [keep, 1])
            pos_info += random_float
            pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
            pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)

        # padding to fixed length
        # NOTE(review): the slice assignments below need keep <= self.tcl_len,
        # i.e. tcl_len >= fixed_point_num - confirm against the config.
        pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
        pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
        pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
        pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
        pos_m[:keep] = 1.0
        return pos_l, pos_m
    def fit_and_gather_tcl_points_v3(self,
                                     min_area_quad,
                                     poly,
                                     max_h,
                                     max_w,
                                     fixed_point_num=64,
                                     img_id=0,
                                     reference_height=3):
        """
        Find the center point of poly as key_points, then fit and gather.
        Variant that thins the filled text mask to a skeleton, orders the
        skeleton pixels with sort_and_expand_with_direction_v2, fills gaps
        between consecutive points, then sub-samples, jitters and pads.
        Reads self.ds_ratio and self.f_direction, which
        generate_tcl_ctc_label sets before calling this method.
        :param min_area_quad: not used by this variant
        :param poly: array of quads indexed as poly[quad, corner, xy]
        :param max_h: mask height / upper clip bound for y
        :param max_w: mask width / upper clip bound for x
        :param fixed_point_num: maximum number of points kept
        :param img_id: value written into column 0 of every output row
        :param reference_height: scales the random y jitter
        :return: (pos_l, pos_m) shaped like the _v2 result, or None when the
            skeleton component has fewer than 3 pixels
        """
        det_mask = np.zeros((int(max_h / self.ds_ratio),
                             int(max_w / self.ds_ratio))).astype(np.float32)

        # score_big_map
        cv2.fillPoly(det_mask,
                     np.round(poly / self.ds_ratio).astype(np.int32), 1.0)
        det_mask = cv2.resize(
            det_mask, dsize=None, fx=self.ds_ratio, fy=self.ds_ratio)
        det_mask = np.array(det_mask > 1e-3, dtype='float32')

        f_direction = self.f_direction
        skeleton_map = thin(det_mask.astype(np.uint8))
        instance_count, instance_label_map = cv2.connectedComponents(
            skeleton_map.astype(np.uint8), connectivity=8)

        # only the component labelled 1 is used
        ys, xs = np.where(instance_label_map == 1)
        pos_list = list(zip(ys, xs))
        if len(pos_list) < 3:
            return None
        # presumably returns the points ordered along the text direction -
        # verify against sort_and_expand_with_direction_v2
        pos_list_sorted = sort_and_expand_with_direction_v2(
            pos_list, f_direction, det_mask)

        pos_list_sorted = np.array(pos_list_sorted)
        length = len(pos_list_sorted) - 1
        # number of points inserted so far; shifts every later index
        insert_num = 0
        for index in range(length):
            stride_y = np.abs(pos_list_sorted[index + insert_num][0] -
                              pos_list_sorted[index + 1 + insert_num][0])
            stride_x = np.abs(pos_list_sorted[index + insert_num][1] -
                              pos_list_sorted[index + 1 + insert_num][1])
            max_points = int(max(stride_x, stride_y))

            # NOTE(review): max_points is 0 when two consecutive points
            # coincide, which makes this a division by zero.
            stride = (pos_list_sorted[index + insert_num] -
                      pos_list_sorted[index + 1 + insert_num]) / (max_points)
            insert_num_temp = max_points - 1

            # insert evenly spaced points between the two neighbours
            for i in range(int(insert_num_temp)):
                insert_value = pos_list_sorted[index + insert_num] - (i + 1
                                                                      ) * stride
                insert_index = index + i + 1 + insert_num
                pos_list_sorted = np.insert(
                    pos_list_sorted, insert_index, insert_value, axis=0)
            insert_num += insert_num_temp

        pos_info = np.array(pos_list_sorted).reshape(-1, 2).astype(
            np.float32)  # rows come from zip(ys, xs), i.e. already (y, x)

        point_num = len(pos_info)
        if point_num > fixed_point_num:
            # evenly spaced sub-sampling down to fixed_point_num points
            keep_ids = [
                int((point_num * 1.0 / fixed_point_num) * x)
                for x in range(fixed_point_num)
            ]
            pos_info = pos_info[keep_ids, :]

        keep = int(min(len(pos_info), fixed_point_num))
        reference_width = (np.abs(poly[0, 0, 0] - poly[-1, 1, 0]) +
                           np.abs(poly[0, 3, 0] - poly[-1, 2, 0])) // 2
        # rand() is in [0, 1), so this branch is always taken
        if np.random.rand() < 1:
            # per-point jitter on y, one shared offset on x
            dh = (np.random.rand(keep) - 0.5) * reference_height
            offset = np.random.rand() - 0.5
            dw = np.array([[0, offset * reference_width * 0.2]])
            random_float_h = np.array([1, 0]).reshape([1, 2]) * dh.reshape(
                [keep, 1])
            random_float_w = dw.repeat(keep, axis=0)
            pos_info += random_float_h
            pos_info += random_float_w
            pos_info[:, 0] = np.clip(pos_info[:, 0], 0, max_h - 1)
            pos_info[:, 1] = np.clip(pos_info[:, 1], 0, max_w - 1)

        # padding to fixed length
        pos_l = np.zeros((self.tcl_len, 3), dtype=np.int32)
        pos_l[:, 0] = np.ones((self.tcl_len, )) * img_id
        pos_m = np.zeros((self.tcl_len, 1), dtype=np.float32)
        pos_l[:keep, 1:] = np.round(pos_info).astype(np.int32)
        pos_m[:keep] = 1.0
        return pos_l, pos_m
  349. def generate_direction_map(self, poly_quads, n_char, direction_map):
  350. """
  351. """
  352. width_list = []
  353. height_list = []
  354. for quad in poly_quads:
  355. quad_w = (np.linalg.norm(quad[0] - quad[1]) +
  356. np.linalg.norm(quad[2] - quad[3])) / 2.0
  357. quad_h = (np.linalg.norm(quad[0] - quad[3]) +
  358. np.linalg.norm(quad[2] - quad[1])) / 2.0
  359. width_list.append(quad_w)
  360. height_list.append(quad_h)
  361. norm_width = max(sum(width_list) / n_char, 1.0)
  362. average_height = max(sum(height_list) / len(height_list), 1.0)
  363. k = 1
  364. for quad in poly_quads:
  365. direct_vector_full = (
  366. (quad[1] + quad[2]) - (quad[0] + quad[3])) / 2.0
  367. direct_vector = direct_vector_full / (
  368. np.linalg.norm(direct_vector_full) + 1e-6) * norm_width
  369. direction_label = tuple(
  370. map(float,
  371. [direct_vector[0], direct_vector[1], 1.0 / average_height]))
  372. cv2.fillPoly(direction_map,
  373. quad.round().astype(np.int32)[np.newaxis, :, :],
  374. direction_label)
  375. k += 1
  376. return direction_map
  377. def calculate_average_height(self, poly_quads):
  378. """
  379. """
  380. height_list = []
  381. for quad in poly_quads:
  382. quad_h = (np.linalg.norm(quad[0] - quad[3]) +
  383. np.linalg.norm(quad[2] - quad[1])) / 2.0
  384. height_list.append(quad_h)
  385. average_height = max(sum(height_list) / len(height_list), 1.0)
  386. return average_height
    def generate_tcl_ctc_label(self,
                               h,
                               w,
                               polys,
                               tags,
                               text_strs,
                               ds_ratio,
                               tcl_ratio=0.3,
                               shrink_ratio_of_width=0.15):
        """
        Generate polygon.
        Builds all label maps for one image at the down-sampled resolution
        (h * ds_ratio, w * ds_ratio) plus the per-text point/label lists.
        Stores ds_ratio on self, and (in 'align' mode) self.f_direction, for
        fit_and_gather_tcl_points_v3.
        :param h: image height before down-sampling
        :param w: image width before down-sampling
        :param polys: (B, N, 2) polygons in image coordinates
        :param tags: per-polygon ignore flags; truthy means "ignore"
        :param text_strs: per-polygon transcriptions
        :param ds_ratio: down-sampling factor applied to h, w and polys
        :param tcl_ratio: fraction of the text height kept around the
            center line
        :param shrink_ratio_of_width: passed to shrink_poly_along_width
        :return: score_map, score_label_map, tbo_map, direction_map,
            training_mask, pos_list, pos_mask, label_list,
            score_label_map_text_label_list
        """
        self.ds_ratio = ds_ratio
        # full-resolution score map, down-sampled at the very end
        score_map_big = np.zeros(
            (
                h,
                w, ), dtype=np.float32)
        # from here on h, w and polys are at the down-sampled resolution
        h, w = int(h * ds_ratio), int(w * ds_ratio)
        polys = polys * ds_ratio

        score_map = np.zeros(
            (
                h,
                w, ), dtype=np.float32)
        score_label_map = np.zeros(
            (
                h,
                w, ), dtype=np.float32)
        tbo_map = np.zeros((h, w, 5), dtype=np.float32)
        training_mask = np.ones(
            (
                h,
                w, ), dtype=np.float32)
        # background direction label is (0, 0, 1)
        direction_map = np.ones((h, w, 3)) * np.array([0, 0, 1]).reshape(
            [1, 1, 3]).astype(np.float32)

        label_idx = 0
        score_label_map_text_label_list = []
        pos_list, pos_mask, label_list = [], [], []
        for poly_idx, poly_tag in enumerate(zip(polys, tags)):
            poly = poly_tag[0]
            tag = poly_tag[1]

            # generate min_area_quad
            min_area_quad, center_point = self.gen_min_area_quad_from_poly(poly)
            min_area_quad_h = 0.5 * (
                np.linalg.norm(min_area_quad[0] - min_area_quad[3]) +
                np.linalg.norm(min_area_quad[1] - min_area_quad[2]))
            min_area_quad_w = 0.5 * (
                np.linalg.norm(min_area_quad[0] - min_area_quad[1]) +
                np.linalg.norm(min_area_quad[2] - min_area_quad[3]))

            # skip texts whose short side is outside the configured range
            if min(min_area_quad_h, min_area_quad_w) < self.min_text_size * ds_ratio \
                    or min(min_area_quad_h, min_area_quad_w) > self.max_text_size * ds_ratio:
                continue

            if tag:
                # ignored text: only down-weight its area in the training mask
                cv2.fillPoly(training_mask,
                             poly.astype(np.int32)[np.newaxis, :, :], 0.15)
            else:
                text_label = text_strs[poly_idx]
                text_label = self.prepare_text_label(text_label,
                                                     self.Lexicon_Table)
                # characters missing from the dictionary are dropped
                text_label_index_list = [[self.Lexicon_Table.index(c_)]
                                         for c_ in text_label
                                         if c_ in self.Lexicon_Table]
                if len(text_label_index_list) < 1:
                    continue

                tcl_poly = self.poly2tcl(poly, tcl_ratio)
                tcl_quads = self.poly2quads(tcl_poly)
                poly_quads = self.poly2quads(poly)

                stcl_quads, quad_index = self.shrink_poly_along_width(
                    tcl_quads,
                    shrink_ratio_of_width=shrink_ratio_of_width,
                    expand_height_ratio=1.0 / tcl_ratio)

                cv2.fillPoly(score_map,
                             np.round(stcl_quads).astype(np.int32), 1.0)
                cv2.fillPoly(score_map_big,
                             np.round(stcl_quads / ds_ratio).astype(np.int32),
                             1.0)

                # border offsets for every pixel of each shrunk quad
                for idx, quad in enumerate(stcl_quads):
                    quad_mask = np.zeros((h, w), dtype=np.float32)
                    quad_mask = cv2.fillPoly(
                        quad_mask,
                        np.round(quad[np.newaxis, :, :]).astype(np.int32), 1.0)
                    tbo_map = self.gen_quad_tbo(poly_quads[quad_index[idx]],
                                                quad_mask, tbo_map)

                # score label map and score_label_map_text_label_list for refine
                if label_idx == 0:
                    # entry 0 is a placeholder holding the dictionary size
                    text_pos_list_ = [[len(self.Lexicon_Table)], ]
                    score_label_map_text_label_list.append(text_pos_list_)

                label_idx += 1
                cv2.fillPoly(score_label_map,
                             np.round(poly_quads).astype(np.int32), label_idx)
                score_label_map_text_label_list.append(text_label_index_list)

                # direction info, fix-me
                n_char = len(text_label_index_list)
                direction_map = self.generate_direction_map(poly_quads, n_char,
                                                            direction_map)

                # pos info
                average_shrink_height = self.calculate_average_height(
                    stcl_quads)

                if self.point_gather_mode == 'align':
                    # v3 reads the first two direction channels via self
                    self.f_direction = direction_map[:, :, :-1].copy()
                    pos_res = self.fit_and_gather_tcl_points_v3(
                        min_area_quad,
                        stcl_quads,
                        max_h=h,
                        max_w=w,
                        fixed_point_num=64,
                        img_id=self.img_id,
                        reference_height=average_shrink_height)
                    if pos_res is None:
                        continue
                    pos_l, pos_m = pos_res[0], pos_res[1]

                else:
                    pos_l, pos_m = self.fit_and_gather_tcl_points_v2(
                        min_area_quad,
                        poly,
                        max_h=h,
                        max_w=w,
                        fixed_point_num=64,
                        img_id=self.img_id,
                        reference_height=average_shrink_height)

                label_l = text_label_index_list
                # NOTE(review): single-character texts are dropped from the
                # pos/label lists here, after the maps above were already
                # updated for them - confirm this is intended.
                if len(text_label_index_list) < 2:
                    continue

                pos_list.append(pos_l)
                pos_mask.append(pos_m)
                label_list.append(label_l)

        # use big score_map for smooth tcl lines
        # (this replaces the score_map filled inside the loop)
        score_map_big_resized = cv2.resize(
            score_map_big, dsize=None, fx=ds_ratio, fy=ds_ratio)
        score_map = np.array(score_map_big_resized > 1e-3, dtype='float32')

        return score_map, score_label_map, tbo_map, direction_map, training_mask, \
               pos_list, pos_mask, label_list, score_label_map_text_label_list
  518. def adjust_point(self, poly):
  519. """
  520. adjust point order.
  521. """
  522. point_num = poly.shape[0]
  523. if point_num == 4:
  524. len_1 = np.linalg.norm(poly[0] - poly[1])
  525. len_2 = np.linalg.norm(poly[1] - poly[2])
  526. len_3 = np.linalg.norm(poly[2] - poly[3])
  527. len_4 = np.linalg.norm(poly[3] - poly[0])
  528. if (len_1 + len_3) * 1.5 < (len_2 + len_4):
  529. poly = poly[[1, 2, 3, 0], :]
  530. elif point_num > 4:
  531. vector_1 = poly[0] - poly[1]
  532. vector_2 = poly[1] - poly[2]
  533. cos_theta = np.dot(vector_1, vector_2) / (
  534. np.linalg.norm(vector_1) * np.linalg.norm(vector_2) + 1e-6)
  535. theta = np.arccos(np.round(cos_theta, decimals=4))
  536. if abs(theta) > (70 / 180 * math.pi):
  537. index = list(range(1, point_num)) + [0]
  538. poly = poly[np.array(index), :]
  539. return poly
  540. def gen_min_area_quad_from_poly(self, poly):
  541. """
  542. Generate min area quad from poly.
  543. """
  544. point_num = poly.shape[0]
  545. min_area_quad = np.zeros((4, 2), dtype=np.float32)
  546. if point_num == 4:
  547. min_area_quad = poly
  548. center_point = np.sum(poly, axis=0) / 4
  549. else:
  550. rect = cv2.minAreaRect(poly.astype(
  551. np.int32)) # (center (x,y), (width, height), angle of rotation)
  552. center_point = rect[0]
  553. box = np.array(cv2.boxPoints(rect))
  554. first_point_idx = 0
  555. min_dist = 1e4
  556. for i in range(4):
  557. dist = np.linalg.norm(box[(i + 0) % 4] - poly[0]) + \
  558. np.linalg.norm(box[(i + 1) % 4] - poly[point_num // 2 - 1]) + \
  559. np.linalg.norm(box[(i + 2) % 4] - poly[point_num // 2]) + \
  560. np.linalg.norm(box[(i + 3) % 4] - poly[-1])
  561. if dist < min_dist:
  562. min_dist = dist
  563. first_point_idx = i
  564. for i in range(4):
  565. min_area_quad[i] = box[(first_point_idx + i) % 4]
  566. return min_area_quad, center_point
  567. def shrink_quad_along_width(self,
  568. quad,
  569. begin_width_ratio=0.,
  570. end_width_ratio=1.):
  571. """
  572. Generate shrink_quad_along_width.
  573. """
  574. ratio_pair = np.array(
  575. [[begin_width_ratio], [end_width_ratio]], dtype=np.float32)
  576. p0_1 = quad[0] + (quad[1] - quad[0]) * ratio_pair
  577. p3_2 = quad[3] + (quad[2] - quad[3]) * ratio_pair
  578. return np.array([p0_1[0], p0_1[1], p3_2[1], p3_2[0]])
  579. def shrink_poly_along_width(self,
  580. quads,
  581. shrink_ratio_of_width,
  582. expand_height_ratio=1.0):
  583. """
  584. shrink poly with given length.
  585. """
  586. upper_edge_list = []
  587. def get_cut_info(edge_len_list, cut_len):
  588. for idx, edge_len in enumerate(edge_len_list):
  589. cut_len -= edge_len
  590. if cut_len <= 0.000001:
  591. ratio = (cut_len + edge_len_list[idx]) / edge_len_list[idx]
  592. return idx, ratio
  593. for quad in quads:
  594. upper_edge_len = np.linalg.norm(quad[0] - quad[1])
  595. upper_edge_list.append(upper_edge_len)
  596. # length of left edge and right edge.
  597. left_length = np.linalg.norm(quads[0][0] - quads[0][
  598. 3]) * expand_height_ratio
  599. right_length = np.linalg.norm(quads[-1][1] - quads[-1][
  600. 2]) * expand_height_ratio
  601. shrink_length = min(left_length, right_length,
  602. sum(upper_edge_list)) * shrink_ratio_of_width
  603. # shrinking length
  604. upper_len_left = shrink_length
  605. upper_len_right = sum(upper_edge_list) - shrink_length
  606. left_idx, left_ratio = get_cut_info(upper_edge_list, upper_len_left)
  607. left_quad = self.shrink_quad_along_width(
  608. quads[left_idx], begin_width_ratio=left_ratio, end_width_ratio=1)
  609. right_idx, right_ratio = get_cut_info(upper_edge_list, upper_len_right)
  610. right_quad = self.shrink_quad_along_width(
  611. quads[right_idx], begin_width_ratio=0, end_width_ratio=right_ratio)
  612. out_quad_list = []
  613. if left_idx == right_idx:
  614. out_quad_list.append(
  615. [left_quad[0], right_quad[1], right_quad[2], left_quad[3]])
  616. else:
  617. out_quad_list.append(left_quad)
  618. for idx in range(left_idx + 1, right_idx):
  619. out_quad_list.append(quads[idx])
  620. out_quad_list.append(right_quad)
  621. return np.array(out_quad_list), list(range(left_idx, right_idx + 1))
  622. def prepare_text_label(self, label_str, Lexicon_Table):
  623. """
  624. Prepare text lablel by given Lexicon_Table.
  625. """
  626. if len(Lexicon_Table) == 36:
  627. return label_str.lower()
  628. else:
  629. return label_str
  630. def vector_angle(self, A, B):
  631. """
  632. Calculate the angle between vector AB and x-axis positive direction.
  633. """
  634. AB = np.array([B[1] - A[1], B[0] - A[0]])
  635. return np.arctan2(*AB)
  636. def theta_line_cross_point(self, theta, point):
  637. """
  638. Calculate the line through given point and angle in ax + by + c =0 form.
  639. """
  640. x, y = point
  641. cos = np.cos(theta)
  642. sin = np.sin(theta)
  643. return [sin, -cos, cos * y - sin * x]
  644. def line_cross_two_point(self, A, B):
  645. """
  646. Calculate the line through given point A and B in ax + by + c =0 form.
  647. """
  648. angle = self.vector_angle(A, B)
  649. return self.theta_line_cross_point(angle, A)
  650. def average_angle(self, poly):
  651. """
  652. Calculate the average angle between left and right edge in given poly.
  653. """
  654. p0, p1, p2, p3 = poly
  655. angle30 = self.vector_angle(p3, p0)
  656. angle21 = self.vector_angle(p2, p1)
  657. return (angle30 + angle21) / 2
  658. def line_cross_point(self, line1, line2):
  659. """
  660. line1 and line2 in 0=ax+by+c form, compute the cross point of line1 and line2
  661. """
  662. a1, b1, c1 = line1
  663. a2, b2, c2 = line2
  664. d = a1 * b2 - a2 * b1
  665. if d == 0:
  666. print('Cross point does not exist')
  667. return np.array([0, 0], dtype=np.float32)
  668. else:
  669. x = (b1 * c2 - b2 * c1) / d
  670. y = (a2 * c1 - a1 * c2) / d
  671. return np.array([x, y], dtype=np.float32)
  672. def quad2tcl(self, poly, ratio):
  673. """
  674. Generate center line by poly clock-wise point. (4, 2)
  675. """
  676. ratio_pair = np.array(
  677. [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
  678. p0_3 = poly[0] + (poly[3] - poly[0]) * ratio_pair
  679. p1_2 = poly[1] + (poly[2] - poly[1]) * ratio_pair
  680. return np.array([p0_3[0], p1_2[0], p1_2[1], p0_3[1]])
  681. def poly2tcl(self, poly, ratio):
  682. """
  683. Generate center line by poly clock-wise point.
  684. """
  685. ratio_pair = np.array(
  686. [[0.5 - ratio / 2], [0.5 + ratio / 2]], dtype=np.float32)
  687. tcl_poly = np.zeros_like(poly)
  688. point_num = poly.shape[0]
  689. for idx in range(point_num // 2):
  690. point_pair = poly[idx] + (poly[point_num - 1 - idx] - poly[idx]
  691. ) * ratio_pair
  692. tcl_poly[idx] = point_pair[0]
  693. tcl_poly[point_num - 1 - idx] = point_pair[1]
  694. return tcl_poly
  695. def gen_quad_tbo(self, quad, tcl_mask, tbo_map):
  696. """
  697. Generate tbo_map for give quad.
  698. """
  699. # upper and lower line function: ax + by + c = 0;
  700. up_line = self.line_cross_two_point(quad[0], quad[1])
  701. lower_line = self.line_cross_two_point(quad[3], quad[2])
  702. quad_h = 0.5 * (np.linalg.norm(quad[0] - quad[3]) +
  703. np.linalg.norm(quad[1] - quad[2]))
  704. quad_w = 0.5 * (np.linalg.norm(quad[0] - quad[1]) +
  705. np.linalg.norm(quad[2] - quad[3]))
  706. # average angle of left and right line.
  707. angle = self.average_angle(quad)
  708. xy_in_poly = np.argwhere(tcl_mask == 1)
  709. for y, x in xy_in_poly:
  710. point = (x, y)
  711. line = self.theta_line_cross_point(angle, point)
  712. cross_point_upper = self.line_cross_point(up_line, line)
  713. cross_point_lower = self.line_cross_point(lower_line, line)
  714. ##FIX, offset reverse
  715. upper_offset_x, upper_offset_y = cross_point_upper - point
  716. lower_offset_x, lower_offset_y = cross_point_lower - point
  717. tbo_map[y, x, 0] = upper_offset_y
  718. tbo_map[y, x, 1] = upper_offset_x
  719. tbo_map[y, x, 2] = lower_offset_y
  720. tbo_map[y, x, 3] = lower_offset_x
  721. tbo_map[y, x, 4] = 1.0 / max(min(quad_h, quad_w), 1.0) * 2
  722. return tbo_map
  723. def poly2quads(self, poly):
  724. """
  725. Split poly into quads.
  726. """
  727. quad_list = []
  728. point_num = poly.shape[0]
  729. # point pair
  730. point_pair_list = []
  731. for idx in range(point_num // 2):
  732. point_pair = [poly[idx], poly[point_num - 1 - idx]]
  733. point_pair_list.append(point_pair)
  734. quad_num = point_num // 2 - 1
  735. for idx in range(quad_num):
  736. # reshape and adjust to clock-wise
  737. quad_list.append((np.array(point_pair_list)[[idx, idx + 1]]
  738. ).reshape(4, 2)[[0, 2, 3, 1]])
  739. return np.array(quad_list)
  def rotate_im_poly(self, im, text_polys):
      """
      Rotate the image by a random odd number of quarter turns and move
      the polygons with it.

      One ``np.random.rand()`` draw selects either 1 or 3 applications of
      ``np.rot90`` (a 180 degree rotation is never produced). The first
      four points of every polygon are rotated about the source image
      centre and re-centred on the rotated image.

      Args:
          im: image array; ``im.shape[0]`` is the height and
              ``im.shape[1]`` the width.
          text_polys: iterable of polygons, each with at least 4 points
              indexable as ``poly[j][0]`` (x) and ``poly[j][1]`` (y).

      Returns:
          Tuple ``(dst_im, dst_polys)`` where ``dst_im`` is the rotated
          copy of ``im`` and ``dst_polys`` is a float32 array with 4
          points per polygon.
      """
      im_w, im_h = im.shape[1], im.shape[0]

      # Single random draw: 1 quarter turn, or 3 when the draw is > 0.5.
      rand_degree_cnt = 3 if np.random.rand() > 0.5 else 1
      dst_im = np.rot90(im.copy(), k=rand_degree_cnt)

      rot_degree = -90 * rand_degree_cnt
      rot_angle = rot_degree * math.pi / 180.0
      # The angle is fixed for the whole call, so evaluate cos/sin once
      # instead of four times per point.
      cos_a = math.cos(rot_angle)
      sin_a = math.sin(rot_angle)

      cx, cy = 0.5 * im_w, 0.5 * im_h
      ncx, ncy = 0.5 * dst_im.shape[1], 0.5 * dst_im.shape[0]

      dst_polys = []
      for word_bb in text_polys:
          poly = []
          for j in range(4):  # only the first 4 points are kept
              sx, sy = word_bb[j][0], word_bb[j][1]
              dx = cos_a * (sx - cx) - sin_a * (sy - cy) + ncx
              dy = sin_a * (sx - cx) + cos_a * (sy - cy) + ncy
              poly.append([dx, dy])
          dst_polys.append(poly)
      return dst_im, np.array(dst_polys, dtype=np.float32)
  def __call__(self, data):
      """
      Augment one sample and attach the training targets to ``data``.

      Reads ``data['image']``, ``data['polys']``, ``data['ignore_tags']``
      and ``data['texts']``; applies random aspect jitter, optional
      resize or random crop, random rescale, blur and brightness
      changes; pads the image onto a 512 x 512 canvas; then builds the
      label maps through ``self.generate_tcl_ctc_label``.

      Returns:
          ``data`` with the keys ``images``, ``tcl_maps``,
          ``tcl_label_maps``, ``border_maps``, ``direction_maps``,
          ``training_masks``, ``label_list``, ``pos_list`` and
          ``pos_mask`` added, or ``None`` when the sample is rejected.
      """
      target_size = 512
      image = data['image']
      polys = data['polys']
      tags = data['ignore_tags']
      texts = data['texts']

      src_h, src_w, _ = image.shape
      polys, tags, hv_tags = self.check_and_validate_polys(
          polys, tags, (src_h, src_w))
      if polys.shape[0] <= 0:
          return None

      # Aspect-ratio jitter: x is scaled by s and y by 1/s, so the area
      # is kept.
      aspect = np.random.choice(np.arange(1.0, 1.55, 0.1))
      if np.random.rand() < 0.5:
          aspect = 1.0 / aspect
      scale_x = math.sqrt(aspect)
      scale_y = 1.0 / scale_x
      image = cv2.resize(image, dsize=None, fx=scale_x, fy=scale_y)
      polys[:, :, 0] *= scale_x
      polys[:, :, 1] *= scale_y

      if self.use_resize is True:
          cur_h, cur_w, _ = image.shape
          longest = max(cur_h, cur_w)
          # Bring the longer side into the [200, 512] range.
          ratio = None
          if longest < 200:
              ratio = 200 / longest
          elif longest > 512:
              ratio = 512 / longest
          if ratio is not None:
              image = cv2.resize(
                  image, (int(cur_w * ratio), int(cur_h * ratio)))
              polys[:, :, :2] *= ratio
      elif self.use_random_crop is True:
          cur_h, cur_w, _ = image.shape
          if max(cur_h, cur_w) > 2048:
              shrink = 2048.0 / max(cur_h, cur_w)
              image = cv2.resize(image, dsize=None, fx=shrink, fy=shrink)
              polys *= shrink
          cur_h, cur_w, _ = image.shape
          if min(cur_h, cur_w) < 16:
              return None
          # Crop without allowing a background-only result.
          image, polys, tags, hv_tags, texts = self.crop_area(
              image, polys, tags, hv_tags, texts, crop_background=False)

      # Reject samples with no polygon left or with every polygon ignored.
      if polys.shape[0] == 0 or np.sum((tags * 1.0)) >= tags.size:
          return None
      new_h, new_w, _ = image.shape
      if new_h is None or new_w is None:
          return None

      # Fit the longer side to target_size, then shrink by a random factor.
      fit_ratio = float(target_size) / max(new_w, new_h)
      scale_choices = np.array(
          [0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0, 1.0, 1.0, 1.0, 1.0])
      resize_scale = fit_ratio * np.random.choice(scale_choices)
      image = cv2.resize(
          image, dsize=None, fx=resize_scale, fy=resize_scale)
      polys[:, :, :2] *= resize_scale

      # Photometric jitter; each effect fires with probability 0.05.
      if np.random.rand() < 0.1 * 0.5:
          kernel = np.random.permutation(5)[0] + 1
          kernel = int(kernel / 2) * 2 + 1  # force an odd kernel size
          image = cv2.GaussianBlur(
              image, ksize=(kernel, kernel), sigmaX=0, sigmaY=0)
      if np.random.rand() < 0.1 * 0.5:
          image = np.clip(
              image * (1.0 + np.random.rand() * 0.5), 0.0, 255.0)
      if np.random.rand() < 0.1 * 0.5:
          image = np.clip(
              image * (1.0 - np.random.rand() * 0.5), 0.0, 255.0)

      # Drop samples whose shorter side is below half the canvas.
      new_h, new_w, _ = image.shape
      if min(new_w, new_h) < target_size * 0.5:
          return None

      # (channel index, mean, std) triples used to fill and normalise.
      # NOTE(review): presumably BGR channel order with ImageNet
      # statistics - confirm against the image loader.
      channel_stats = (
          (2, 0.485, 0.229),
          (1, 0.456, 0.224),
          (0, 0.406, 0.225),
      )
      canvas = np.ones((target_size, target_size, 3), dtype=np.float32)
      for channel, mean, _ in channel_stats:
          canvas[:, :, channel] = mean * 255

      # Random placement of the image inside the canvas.
      slack_h = target_size - new_h
      slack_w = target_size - new_w
      top = int(np.random.rand() * slack_h) if slack_h > 1 else 0
      left = int(np.random.rand() * slack_w) if slack_w > 1 else 0
      canvas[top:top + new_h, left:left + new_w, :] = image.copy()
      polys[:, :, 0] += left
      polys[:, :, 1] += top

      (score_map, score_label_map, border_map, direction_map,
       training_mask, pos_list, pos_mask, label_list,
       _) = self.generate_tcl_ctc_label(
           target_size, target_size, polys, tags, texts, 0.25)
      if len(label_list) <= 0:  # eliminate negative samples
          return None

      # Truncate or pad every label to max_text_length, then make it an
      # array.
      max_len = self.max_text_length
      for idx, label in enumerate(label_list):
          if len(label) > max_len:
              label = label[:max_len]
          else:
              label.extend(
                  [self.pad_num] for _ in range(max_len - len(label)))
          label_list[idx] = np.array(label)

      instance_num = len(pos_list)
      if instance_num <= 0 or instance_num > self.max_text_nums:
          return None

      # Pad the per-instance lists up to max_text_nums entries.
      missing = self.max_text_nums - instance_num
      pos_list.extend([np.zeros([64, 3])] * missing)
      pos_mask.extend([np.zeros([64, 1])] * missing)
      label_list.extend([np.zeros([max_len, 1]) + self.pad_num] * missing)

      # Advance the position counter, wrapping at batch_size.
      if self.img_id == self.batch_size - 1:
          self.img_id = 0
      else:
          self.img_id += 1

      # Normalise: subtract mean * 255 and divide by std * 255 per channel.
      for channel, mean, std in channel_stats:
          canvas[:, :, channel] -= mean * 255
          canvas[:, :, channel] /= (255.0 * std)

      # HWC -> CHW, then reverse the channel axis.
      outputs = (
          ('images', canvas.transpose((2, 0, 1))[::-1, :, :]),
          ('tcl_maps', score_map[np.newaxis, :, :]),
          ('tcl_label_maps', score_label_map[np.newaxis, :, :]),
          ('border_maps', border_map.transpose((2, 0, 1))),
          ('direction_maps', direction_map.transpose((2, 0, 1))),
          ('training_masks', training_mask[np.newaxis, :, :]),
          ('label_list', np.array(label_list)),
          ('pos_list', np.array(pos_list)),
          ('pos_mask', np.array(pos_mask)),
      )
      for key, value in outputs:
          data[key] = value
      return data