label_ops.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. from __future__ import unicode_literals
  18. import copy
  19. import numpy as np
  20. import string
  21. from shapely.geometry import LineString, Point, Polygon
  22. import json
  23. import copy
  24. from random import sample
  25. from ppocr.utils.logging import get_logger
  26. from ppocr.data.imaug.vqa.augment import order_by_tbyx
  27. class ClsLabelEncode(object):
  28. def __init__(self, label_list, **kwargs):
  29. self.label_list = label_list
  30. def __call__(self, data):
  31. label = data['label']
  32. if label not in self.label_list:
  33. return None
  34. label = self.label_list.index(label)
  35. data['label'] = label
  36. return data
  37. class DetLabelEncode(object):
  38. def __init__(self, **kwargs):
  39. pass
  40. def __call__(self, data):
  41. label = data['label']
  42. label = json.loads(label)
  43. nBox = len(label)
  44. boxes, txts, txt_tags = [], [], []
  45. for bno in range(0, nBox):
  46. box = label[bno]['points']
  47. txt = label[bno]['transcription']
  48. boxes.append(box)
  49. txts.append(txt)
  50. if txt in ['*', '###']:
  51. txt_tags.append(True)
  52. else:
  53. txt_tags.append(False)
  54. if len(boxes) == 0:
  55. return None
  56. boxes = self.expand_points_num(boxes)
  57. boxes = np.array(boxes, dtype=np.float32)
  58. txt_tags = np.array(txt_tags, dtype=bool)
  59. data['polys'] = boxes
  60. data['texts'] = txts
  61. data['ignore_tags'] = txt_tags
  62. return data
  63. def order_points_clockwise(self, pts):
  64. rect = np.zeros((4, 2), dtype="float32")
  65. s = pts.sum(axis=1)
  66. rect[0] = pts[np.argmin(s)]
  67. rect[2] = pts[np.argmax(s)]
  68. tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
  69. diff = np.diff(np.array(tmp), axis=1)
  70. rect[1] = tmp[np.argmin(diff)]
  71. rect[3] = tmp[np.argmax(diff)]
  72. return rect
  73. def expand_points_num(self, boxes):
  74. max_points_num = 0
  75. for box in boxes:
  76. if len(box) > max_points_num:
  77. max_points_num = len(box)
  78. ex_boxes = []
  79. for box in boxes:
  80. ex_box = box + [box[-1]] * (max_points_num - len(box))
  81. ex_boxes.append(ex_box)
  82. return ex_boxes
  83. class BaseRecLabelEncode(object):
  84. """ Convert between text-label and text-index """
  85. def __init__(self,
  86. max_text_length,
  87. character_dict_path=None,
  88. use_space_char=False,
  89. lower=False):
  90. self.max_text_len = max_text_length
  91. self.beg_str = "sos"
  92. self.end_str = "eos"
  93. self.lower = lower
  94. if character_dict_path is None:
  95. logger = get_logger()
  96. logger.warning(
  97. "The character_dict_path is None, model can only recognize number and lower letters"
  98. )
  99. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  100. dict_character = list(self.character_str)
  101. self.lower = True
  102. else:
  103. self.character_str = []
  104. with open(character_dict_path, "rb") as fin:
  105. lines = fin.readlines()
  106. for line in lines:
  107. line = line.decode('utf-8').strip("\n").strip("\r\n")
  108. self.character_str.append(line)
  109. if use_space_char:
  110. self.character_str.append(" ")
  111. dict_character = list(self.character_str)
  112. dict_character = self.add_special_char(dict_character)
  113. self.dict = {}
  114. for i, char in enumerate(dict_character):
  115. self.dict[char] = i
  116. self.character = dict_character
  117. def add_special_char(self, dict_character):
  118. return dict_character
  119. def encode(self, text):
  120. """convert text-label into text-index.
  121. input:
  122. text: text labels of each image. [batch_size]
  123. output:
  124. text: concatenated text index for CTCLoss.
  125. [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
  126. length: length of each text. [batch_size]
  127. """
  128. if len(text) == 0 or len(text) > self.max_text_len:
  129. return None
  130. if self.lower:
  131. text = text.lower()
  132. text_list = []
  133. for char in text:
  134. if char not in self.dict:
  135. # logger = get_logger()
  136. # logger.warning('{} is not in dict'.format(char))
  137. continue
  138. text_list.append(self.dict[char])
  139. if len(text_list) == 0:
  140. return None
  141. return text_list
  142. class CTCLabelEncode(BaseRecLabelEncode):
  143. """ Convert between text-label and text-index """
  144. def __init__(self,
  145. max_text_length,
  146. character_dict_path=None,
  147. use_space_char=False,
  148. **kwargs):
  149. super(CTCLabelEncode, self).__init__(
  150. max_text_length, character_dict_path, use_space_char)
  151. def __call__(self, data):
  152. text = data['label']
  153. text = self.encode(text)
  154. if text is None:
  155. return None
  156. data['length'] = np.array(len(text))
  157. text = text + [0] * (self.max_text_len - len(text))
  158. data['label'] = np.array(text)
  159. label = [0] * len(self.character)
  160. for x in text:
  161. label[x] += 1
  162. data['label_ace'] = np.array(label)
  163. return data
  164. def add_special_char(self, dict_character):
  165. dict_character = ['blank'] + dict_character
  166. return dict_character
  167. class E2ELabelEncodeTest(BaseRecLabelEncode):
  168. def __init__(self,
  169. max_text_length,
  170. character_dict_path=None,
  171. use_space_char=False,
  172. **kwargs):
  173. super(E2ELabelEncodeTest, self).__init__(
  174. max_text_length, character_dict_path, use_space_char)
  175. def __call__(self, data):
  176. import json
  177. padnum = len(self.dict)
  178. label = data['label']
  179. label = json.loads(label)
  180. nBox = len(label)
  181. boxes, txts, txt_tags = [], [], []
  182. for bno in range(0, nBox):
  183. box = label[bno]['points']
  184. txt = label[bno]['transcription']
  185. boxes.append(box)
  186. txts.append(txt)
  187. if txt in ['*', '###']:
  188. txt_tags.append(True)
  189. else:
  190. txt_tags.append(False)
  191. boxes = np.array(boxes, dtype=np.float32)
  192. txt_tags = np.array(txt_tags, dtype=bool)
  193. data['polys'] = boxes
  194. data['ignore_tags'] = txt_tags
  195. temp_texts = []
  196. for text in txts:
  197. text = text.lower()
  198. text = self.encode(text)
  199. if text is None:
  200. return None
  201. text = text + [padnum] * (self.max_text_len - len(text)
  202. ) # use 36 to pad
  203. temp_texts.append(text)
  204. data['texts'] = np.array(temp_texts)
  205. return data
  206. class E2ELabelEncodeTrain(object):
  207. def __init__(self, **kwargs):
  208. pass
  209. def __call__(self, data):
  210. import json
  211. label = data['label']
  212. label = json.loads(label)
  213. nBox = len(label)
  214. boxes, txts, txt_tags = [], [], []
  215. for bno in range(0, nBox):
  216. box = label[bno]['points']
  217. txt = label[bno]['transcription']
  218. boxes.append(box)
  219. txts.append(txt)
  220. if txt in ['*', '###']:
  221. txt_tags.append(True)
  222. else:
  223. txt_tags.append(False)
  224. boxes = np.array(boxes, dtype=np.float32)
  225. txt_tags = np.array(txt_tags, dtype=bool)
  226. data['polys'] = boxes
  227. data['texts'] = txts
  228. data['ignore_tags'] = txt_tags
  229. return data
class KieLabelEncode(object):
    """Encode KIE (key information extraction) annotations.

    Converts the JSON 'label' of a sample into fixed-size numpy arrays
    (boxes, pairwise relations, text indices, node labels + edge matrix),
    all padded to 300 entries.
    """

    def __init__(self,
                 character_dict_path,
                 class_path,
                 norm=10,
                 directed=False,
                 **kwargs):
        super(KieLabelEncode, self).__init__()
        # Index 0 is reserved for the empty string (padding character).
        self.dict = dict({'': 0})
        self.label2classid_map = dict()
        with open(character_dict_path, 'r', encoding='utf-8') as fr:
            idx = 1
            for line in fr:
                char = line.strip()
                self.dict[char] = idx
                idx += 1
        with open(class_path, "r") as fin:
            lines = fin.readlines()
            for idx, line in enumerate(lines):
                line = line.strip("\n")
                self.label2classid_map[line] = idx
        # Divisor for the relative x/y offsets in compute_relation.
        self.norm = norm
        # If True, only keep edges between nodes whose label is 1.
        self.directed = directed

    def compute_relation(self, boxes):
        """Compute relation between every two boxes."""
        # Boxes arrive as flattened 8-tuples; columns 0/1 are the first
        # corner and 4/5 the opposite corner after sort_vertex — so these
        # slices read x1/y1 and x2/y2 of each box.
        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
        # Pairwise offsets/ratios between every pair of boxes.
        dxs = (x1s[:, 0][None] - x1s) / self.norm
        dys = (y1s[:, 0][None] - y1s) / self.norm
        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
        whs = ws / hs + np.zeros_like(xhhs)
        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
        return relations, bboxes

    def pad_text_indices(self, text_inds):
        """Pad text index to same length."""
        max_len = 300
        recoder_len = max([len(text_ind) for text_ind in text_inds])
        # -1 marks padding positions.
        padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
        for idx, text_ind in enumerate(text_inds):
            padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
        return padded_text_inds, recoder_len

    def list_to_numpy(self, ann_infos):
        """Convert bboxes, relations, texts and labels to ndarray."""
        boxes, text_inds = ann_infos['points'], ann_infos['text_inds']
        boxes = np.array(boxes, np.int32)
        relations, bboxes = self.compute_relation(boxes)

        labels = ann_infos.get('labels', None)
        if labels is not None:
            labels = np.array(labels, np.int32)
            edges = ann_infos.get('edges', None)
            if edges is not None:
                labels = labels[:, None]
                edges = np.array(edges)
                # Nodes sharing the same edge id are connected.
                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
                if self.directed:
                    # NOTE(review): '==' binds tighter than '&', so this is
                    # edges & (labels == 1) — confirm this matches intent.
                    edges = (edges & labels == 1).astype(np.int32)
                np.fill_diagonal(edges, -1)  # self-loops marked with -1
                # labels becomes [class_id | adjacency row]: shape (h, h + 1).
                labels = np.concatenate([labels, edges], -1)
        padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
        # Pad everything to a fixed node count of 300.
        max_num = 300
        temp_bboxes = np.zeros([max_num, 4])
        h, _ = bboxes.shape
        temp_bboxes[:h, :] = bboxes

        temp_relations = np.zeros([max_num, max_num, 5])
        temp_relations[:h, :h, :] = relations

        # pad_text_indices uses max_len == max_num, so the rows line up.
        temp_padded_text_inds = np.zeros([max_num, max_num])
        temp_padded_text_inds[:h, :] = padded_text_inds

        temp_labels = np.zeros([max_num, max_num])
        temp_labels[:h, :h + 1] = labels

        # tag carries the true node count and the longest text length.
        tag = np.array([h, recoder_len])
        return dict(
            image=ann_infos['image'],
            points=temp_bboxes,
            relations=temp_relations,
            texts=temp_padded_text_inds,
            labels=temp_labels,
            tag=tag)

    def convert_canonical(self, points_x, points_y):
        """Rotate the 4 points so the one closest to the top-left corner of
        their bounding rect comes first; cyclic order is preserved."""
        assert len(points_x) == 4
        assert len(points_y) == 4
        points = [Point(points_x[i], points_y[i]) for i in range(4)]
        polygon = Polygon([(p.x, p.y) for p in points])
        min_x, min_y, _, _ = polygon.bounds
        points_to_lefttop = [
            LineString([points[i], Point(min_x, min_y)]) for i in range(4)
        ]
        distances = np.array([line.length for line in points_to_lefttop])
        sort_dist_idx = np.argsort(distances)
        lefttop_idx = sort_dist_idx[0]

        if lefttop_idx == 0:
            point_orders = [0, 1, 2, 3]
        elif lefttop_idx == 1:
            point_orders = [1, 2, 3, 0]
        elif lefttop_idx == 2:
            point_orders = [2, 3, 0, 1]
        else:
            point_orders = [3, 0, 1, 2]

        sorted_points_x = [points_x[i] for i in point_orders]
        sorted_points_y = [points_y[j] for j in point_orders]
        return sorted_points_x, sorted_points_y

    def sort_vertex(self, points_x, points_y):
        """Sort the 4 vertices by angle around their centroid, then make the
        top-left-most vertex first via convert_canonical."""
        assert len(points_x) == 4
        assert len(points_y) == 4
        x = np.array(points_x)
        y = np.array(points_y)
        center_x = np.sum(x) * 0.25
        center_y = np.sum(y) * 0.25

        x_arr = np.array(x - center_x)
        y_arr = np.array(y - center_y)

        angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
        sort_idx = np.argsort(angle)

        sorted_points_x, sorted_points_y = [], []
        for i in range(4):
            sorted_points_x.append(points_x[sort_idx[i]])
            sorted_points_y.append(points_y[sort_idx[i]])
        return self.convert_canonical(sorted_points_x, sorted_points_y)

    def __call__(self, data):
        import json
        label = data['label']
        annotations = json.loads(label)
        boxes, texts, text_inds, labels, edges = [], [], [], [], []
        for ann in annotations:
            box = ann['points']
            x_list = [box[i][0] for i in range(4)]
            y_list = [box[i][1] for i in range(4)]
            sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
            sorted_box = []
            for x, y in zip(sorted_x_list, sorted_y_list):
                sorted_box.append(x)
                sorted_box.append(y)
            boxes.append(sorted_box)
            text = ann['transcription']
            texts.append(ann['transcription'])
            # Characters missing from the dict are silently dropped.
            text_ind = [self.dict[c] for c in text if c in self.dict]
            text_inds.append(text_ind)
            if 'label' in ann.keys():
                labels.append(self.label2classid_map[ann['label']])
            elif 'key_cls' in ann.keys():
                labels.append(ann['key_cls'])
            else:
                raise ValueError(
                    "Cannot found 'key_cls' in ann.keys(), please check your training annotation."
                )
            edges.append(ann.get('edge', 0))
        ann_infos = dict(
            image=data['image'],
            points=boxes,
            texts=texts,
            text_inds=text_inds,
            edges=edges,
            labels=labels)
        return self.list_to_numpy(ann_infos)
  384. class AttnLabelEncode(BaseRecLabelEncode):
  385. """ Convert between text-label and text-index """
  386. def __init__(self,
  387. max_text_length,
  388. character_dict_path=None,
  389. use_space_char=False,
  390. **kwargs):
  391. super(AttnLabelEncode, self).__init__(
  392. max_text_length, character_dict_path, use_space_char)
  393. def add_special_char(self, dict_character):
  394. self.beg_str = "sos"
  395. self.end_str = "eos"
  396. dict_character = [self.beg_str] + dict_character + [self.end_str]
  397. return dict_character
  398. def __call__(self, data):
  399. text = data['label']
  400. text = self.encode(text)
  401. if text is None:
  402. return None
  403. if len(text) >= self.max_text_len:
  404. return None
  405. data['length'] = np.array(len(text))
  406. text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
  407. - len(text) - 2)
  408. data['label'] = np.array(text)
  409. return data
  410. def get_ignored_tokens(self):
  411. beg_idx = self.get_beg_end_flag_idx("beg")
  412. end_idx = self.get_beg_end_flag_idx("end")
  413. return [beg_idx, end_idx]
  414. def get_beg_end_flag_idx(self, beg_or_end):
  415. if beg_or_end == "beg":
  416. idx = np.array(self.dict[self.beg_str])
  417. elif beg_or_end == "end":
  418. idx = np.array(self.dict[self.end_str])
  419. else:
  420. assert False, "Unsupport type %s in get_beg_end_flag_idx" \
  421. % beg_or_end
  422. return idx
  423. class RFLLabelEncode(BaseRecLabelEncode):
  424. """ Convert between text-label and text-index """
  425. def __init__(self,
  426. max_text_length,
  427. character_dict_path=None,
  428. use_space_char=False,
  429. **kwargs):
  430. super(RFLLabelEncode, self).__init__(
  431. max_text_length, character_dict_path, use_space_char)
  432. def add_special_char(self, dict_character):
  433. self.beg_str = "sos"
  434. self.end_str = "eos"
  435. dict_character = [self.beg_str] + dict_character + [self.end_str]
  436. return dict_character
  437. def encode_cnt(self, text):
  438. cnt_label = [0.0] * len(self.character)
  439. for char_ in text:
  440. cnt_label[char_] += 1
  441. return np.array(cnt_label)
  442. def __call__(self, data):
  443. text = data['label']
  444. text = self.encode(text)
  445. if text is None:
  446. return None
  447. if len(text) >= self.max_text_len:
  448. return None
  449. cnt_label = self.encode_cnt(text)
  450. data['length'] = np.array(len(text))
  451. text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
  452. - len(text) - 2)
  453. if len(text) != self.max_text_len:
  454. return None
  455. data['label'] = np.array(text)
  456. data['cnt_label'] = cnt_label
  457. return data
  458. def get_ignored_tokens(self):
  459. beg_idx = self.get_beg_end_flag_idx("beg")
  460. end_idx = self.get_beg_end_flag_idx("end")
  461. return [beg_idx, end_idx]
  462. def get_beg_end_flag_idx(self, beg_or_end):
  463. if beg_or_end == "beg":
  464. idx = np.array(self.dict[self.beg_str])
  465. elif beg_or_end == "end":
  466. idx = np.array(self.dict[self.end_str])
  467. else:
  468. assert False, "Unsupport type %s in get_beg_end_flag_idx" \
  469. % beg_or_end
  470. return idx
  471. class SEEDLabelEncode(BaseRecLabelEncode):
  472. """ Convert between text-label and text-index """
  473. def __init__(self,
  474. max_text_length,
  475. character_dict_path=None,
  476. use_space_char=False,
  477. **kwargs):
  478. super(SEEDLabelEncode, self).__init__(
  479. max_text_length, character_dict_path, use_space_char)
  480. def add_special_char(self, dict_character):
  481. self.padding = "padding"
  482. self.end_str = "eos"
  483. self.unknown = "unknown"
  484. dict_character = dict_character + [
  485. self.end_str, self.padding, self.unknown
  486. ]
  487. return dict_character
  488. def __call__(self, data):
  489. text = data['label']
  490. text = self.encode(text)
  491. if text is None:
  492. return None
  493. if len(text) >= self.max_text_len:
  494. return None
  495. data['length'] = np.array(len(text)) + 1 # conclude eos
  496. text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
  497. self.max_text_len - len(text) - 1)
  498. data['label'] = np.array(text)
  499. return data
  500. class SRNLabelEncode(BaseRecLabelEncode):
  501. """ Convert between text-label and text-index """
  502. def __init__(self,
  503. max_text_length=25,
  504. character_dict_path=None,
  505. use_space_char=False,
  506. **kwargs):
  507. super(SRNLabelEncode, self).__init__(
  508. max_text_length, character_dict_path, use_space_char)
  509. def add_special_char(self, dict_character):
  510. dict_character = dict_character + [self.beg_str, self.end_str]
  511. return dict_character
  512. def __call__(self, data):
  513. text = data['label']
  514. text = self.encode(text)
  515. char_num = len(self.character)
  516. if text is None:
  517. return None
  518. if len(text) > self.max_text_len:
  519. return None
  520. data['length'] = np.array(len(text))
  521. text = text + [char_num - 1] * (self.max_text_len - len(text))
  522. data['label'] = np.array(text)
  523. return data
  524. def get_ignored_tokens(self):
  525. beg_idx = self.get_beg_end_flag_idx("beg")
  526. end_idx = self.get_beg_end_flag_idx("end")
  527. return [beg_idx, end_idx]
  528. def get_beg_end_flag_idx(self, beg_or_end):
  529. if beg_or_end == "beg":
  530. idx = np.array(self.dict[self.beg_str])
  531. elif beg_or_end == "end":
  532. idx = np.array(self.dict[self.end_str])
  533. else:
  534. assert False, "Unsupport type %s in get_beg_end_flag_idx" \
  535. % beg_or_end
  536. return idx
class TableLabelEncode(AttnLabelEncode):
    """ Convert between text-label and text-index """

    def __init__(self,
                 max_text_length,
                 character_dict_path,
                 replace_empty_cell_token=False,
                 merge_no_span_structure=False,
                 learn_empty_box=False,
                 loc_reg_num=4,
                 **kwargs):
        self.max_text_len = max_text_length
        self.lower = False
        self.learn_empty_box = learn_empty_box
        self.merge_no_span_structure = merge_no_span_structure
        self.replace_empty_cell_token = replace_empty_cell_token

        # Structure-token alphabet, one token per line of the dict file.
        dict_character = []
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
            for line in lines:
                line = line.decode('utf-8').strip("\n").strip("\r\n")
                dict_character.append(line)

        if self.merge_no_span_structure:
            # Use the merged '<td></td>' token and drop the bare '<td>'.
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")

        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.idx2char = {v: k for k, v in self.dict.items()}

        self.character = dict_character
        # Number of box regression values per cell (4 = xyxy, 8 = 4 points).
        self.loc_reg_num = loc_reg_num
        # sos doubles as the padding token here.
        self.pad_idx = self.dict[self.beg_str]
        self.start_idx = self.dict[self.beg_str]
        self.end_idx = self.dict[self.end_str]

        # Tokens that open a table cell and therefore consume one bbox.
        self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
        # Maps a cell's literal token list (as str) to its empty-cell token.
        self.empty_bbox_token_dict = {
            "[]": '<eb></eb>',
            "[' ']": '<eb1></eb1>',
            "['<b>', ' ', '</b>']": '<eb2></eb2>',
            "['\\u2028', '\\u2028']": '<eb3></eb3>',
            "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
            "['<b>', '</b>']": '<eb5></eb5>',
            "['<i>', ' ', '</i>']": '<eb6></eb6>',
            "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
            "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
            "['<i>', '</i>']": '<eb9></eb9>',
            "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']":
            '<eb10></eb10>',
        }

    @property
    def _max_text_len(self):
        # +2 leaves room for the sos/eos tokens added in __call__.
        return self.max_text_len + 2

    def __call__(self, data):
        cells = data['cells']
        structure = data['structure']
        if self.merge_no_span_structure:
            structure = self._merge_no_span_structure(structure)
        if self.replace_empty_cell_token:
            structure = self._replace_empty_cell_token(structure, cells)
        # remove empty token and add " " to span token
        new_structure = []
        for token in structure:
            if token != '':
                if 'span' in token and token[0] != ' ':
                    token = ' ' + token
                new_structure.append(token)
        # encode structure
        structure = self.encode(new_structure)
        if structure is None:
            return None

        structure = [self.start_idx] + structure + [self.end_idx
                                                    ]  # add sos and eos
        structure = structure + [self.pad_idx] * (self._max_text_len -
                                                  len(structure))  # pad
        structure = np.array(structure)
        data['structure'] = structure
        # NOTE(review): after padding, len(structure) == self._max_text_len,
        # so this guard appears unreachable — confirm before removing.
        if len(structure) > self._max_text_len:
            return None

        # encode box
        bboxes = np.zeros(
            (self._max_text_len, self.loc_reg_num), dtype=np.float32)
        bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)

        bbox_idx = 0
        for i, token in enumerate(structure):
            if self.idx2char[token] in self.td_token:
                # Only non-empty cells that carry a bbox produce a target box.
                if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
                        'tokens']) > 0:
                    bbox = cells[bbox_idx]['bbox'].copy()
                    bbox = np.array(bbox, dtype=np.float32).reshape(-1)
                    bboxes[i] = bbox
                    bbox_masks[i] = 1.0
                if self.learn_empty_box:
                    # Supervise the mask even for empty cells.
                    bbox_masks[i] = 1.0
                bbox_idx += 1
        data['bboxes'] = bboxes
        data['bbox_masks'] = bbox_masks
        return data

    def _merge_no_span_structure(self, structure):
        """
        This code is refer from:
        https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
        """
        # Collapse every '<td>' + '</td>' pair into one '<td></td>' token.
        new_structure = []
        i = 0
        while i < len(structure):
            token = structure[i]
            if token == '<td>':
                token = '<td></td>'
                i += 1  # skip the following '</td>'
            new_structure.append(token)
            i += 1
        return new_structure

    def _replace_empty_cell_token(self, token_list, cells):
        """
        This fun code is refer from:
        https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
        """
        # Cells without a bbox are replaced by a dedicated <ebN> token keyed
        # on the cell's literal token list.
        bbox_idx = 0
        add_empty_bbox_token_list = []
        for token in token_list:
            if token in ['<td></td>', '<td', '<td>']:
                if 'bbox' not in cells[bbox_idx].keys():
                    content = str(cells[bbox_idx]['tokens'])
                    token = self.empty_bbox_token_dict[content]
                add_empty_bbox_token_list.append(token)
                bbox_idx += 1
            else:
                add_empty_bbox_token_list.append(token)
        return add_empty_bbox_token_list
  668. class TableMasterLabelEncode(TableLabelEncode):
  669. """ Convert between text-label and text-index """
  670. def __init__(self,
  671. max_text_length,
  672. character_dict_path,
  673. replace_empty_cell_token=False,
  674. merge_no_span_structure=False,
  675. learn_empty_box=False,
  676. loc_reg_num=4,
  677. **kwargs):
  678. super(TableMasterLabelEncode, self).__init__(
  679. max_text_length, character_dict_path, replace_empty_cell_token,
  680. merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs)
  681. self.pad_idx = self.dict[self.pad_str]
  682. self.unknown_idx = self.dict[self.unknown_str]
  683. @property
  684. def _max_text_len(self):
  685. return self.max_text_len
  686. def add_special_char(self, dict_character):
  687. self.beg_str = '<SOS>'
  688. self.end_str = '<EOS>'
  689. self.unknown_str = '<UKN>'
  690. self.pad_str = '<PAD>'
  691. dict_character = dict_character
  692. dict_character = dict_character + [
  693. self.unknown_str, self.beg_str, self.end_str, self.pad_str
  694. ]
  695. return dict_character
  696. class TableBoxEncode(object):
  697. def __init__(self, in_box_format='xyxy', out_box_format='xyxy', **kwargs):
  698. assert out_box_format in ['xywh', 'xyxy', 'xyxyxyxy']
  699. self.in_box_format = in_box_format
  700. self.out_box_format = out_box_format
  701. def __call__(self, data):
  702. img_height, img_width = data['image'].shape[:2]
  703. bboxes = data['bboxes']
  704. if self.in_box_format != self.out_box_format:
  705. if self.out_box_format == 'xywh':
  706. if self.in_box_format == 'xyxyxyxy':
  707. bboxes = self.xyxyxyxy2xywh(bboxes)
  708. elif self.in_box_format == 'xyxy':
  709. bboxes = self.xyxy2xywh(bboxes)
  710. bboxes[:, 0::2] /= img_width
  711. bboxes[:, 1::2] /= img_height
  712. data['bboxes'] = bboxes
  713. return data
  714. def xyxyxyxy2xywh(self, boxes):
  715. new_bboxes = np.zeros([len(bboxes), 4])
  716. new_bboxes[:, 0] = bboxes[:, 0::2].min() # x1
  717. new_bboxes[:, 1] = bboxes[:, 1::2].min() # y1
  718. new_bboxes[:, 2] = bboxes[:, 0::2].max() - new_bboxes[:, 0] # w
  719. new_bboxes[:, 3] = bboxes[:, 1::2].max() - new_bboxes[:, 1] # h
  720. return new_bboxes
  721. def xyxy2xywh(self, bboxes):
  722. new_bboxes = np.empty_like(bboxes)
  723. new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center
  724. new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center
  725. new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] # width
  726. new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] # height
  727. return new_bboxes
  728. class SARLabelEncode(BaseRecLabelEncode):
  729. """ Convert between text-label and text-index """
  730. def __init__(self,
  731. max_text_length,
  732. character_dict_path=None,
  733. use_space_char=False,
  734. **kwargs):
  735. super(SARLabelEncode, self).__init__(
  736. max_text_length, character_dict_path, use_space_char)
  737. def add_special_char(self, dict_character):
  738. beg_end_str = "<BOS/EOS>"
  739. unknown_str = "<UKN>"
  740. padding_str = "<PAD>"
  741. dict_character = dict_character + [unknown_str]
  742. self.unknown_idx = len(dict_character) - 1
  743. dict_character = dict_character + [beg_end_str]
  744. self.start_idx = len(dict_character) - 1
  745. self.end_idx = len(dict_character) - 1
  746. dict_character = dict_character + [padding_str]
  747. self.padding_idx = len(dict_character) - 1
  748. return dict_character
  749. def __call__(self, data):
  750. text = data['label']
  751. text = self.encode(text)
  752. if text is None:
  753. return None
  754. if len(text) >= self.max_text_len - 1:
  755. return None
  756. data['length'] = np.array(len(text))
  757. target = [self.start_idx] + text + [self.end_idx]
  758. padded_text = [self.padding_idx for _ in range(self.max_text_len)]
  759. padded_text[:len(target)] = target
  760. data['label'] = np.array(padded_text)
  761. return data
  762. def get_ignored_tokens(self):
  763. return [self.padding_idx]
  764. class PRENLabelEncode(BaseRecLabelEncode):
  765. def __init__(self,
  766. max_text_length,
  767. character_dict_path,
  768. use_space_char=False,
  769. **kwargs):
  770. super(PRENLabelEncode, self).__init__(
  771. max_text_length, character_dict_path, use_space_char)
  772. def add_special_char(self, dict_character):
  773. padding_str = '<PAD>' # 0
  774. end_str = '<EOS>' # 1
  775. unknown_str = '<UNK>' # 2
  776. dict_character = [padding_str, end_str, unknown_str] + dict_character
  777. self.padding_idx = 0
  778. self.end_idx = 1
  779. self.unknown_idx = 2
  780. return dict_character
  781. def encode(self, text):
  782. if len(text) == 0 or len(text) >= self.max_text_len:
  783. return None
  784. if self.lower:
  785. text = text.lower()
  786. text_list = []
  787. for char in text:
  788. if char not in self.dict:
  789. text_list.append(self.unknown_idx)
  790. else:
  791. text_list.append(self.dict[char])
  792. text_list.append(self.end_idx)
  793. if len(text_list) < self.max_text_len:
  794. text_list += [self.padding_idx] * (
  795. self.max_text_len - len(text_list))
  796. return text_list
  797. def __call__(self, data):
  798. text = data['label']
  799. encoded_text = self.encode(text)
  800. if encoded_text is None:
  801. return None
  802. data['label'] = np.array(encoded_text)
  803. return data
class VQATokenLabelEncode(object):
    """
    Label encode for NLP VQA methods

    Tokenizes OCR text lines with a LayoutLM-family tokenizer and aligns a
    bounding box with every sub-token; in training mode it also builds the
    SER labels and (optionally) the relation-extraction bookkeeping.
    """

    def __init__(self,
                 class_path,
                 contains_re=False,
                 add_special_ids=False,
                 algorithm='LayoutXLM',
                 use_textline_bbox_info=True,
                 order_method=None,
                 infer_mode=False,
                 ocr_engine=None,
                 **kwargs):
        super(VQATokenLabelEncode, self).__init__()
        from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer
        from ppocr.utils.utility import load_vqa_bio_label_maps
        # Select tokenizer class and pretrained weights by algorithm name.
        tokenizer_dict = {
            'LayoutXLM': {
                'class': LayoutXLMTokenizer,
                'pretrained_model': 'layoutxlm-base-uncased'
            },
            'LayoutLM': {
                'class': LayoutLMTokenizer,
                'pretrained_model': 'layoutlm-base-uncased'
            },
            'LayoutLMv2': {
                'class': LayoutLMv2Tokenizer,
                'pretrained_model': 'layoutlmv2-base-uncased'
            }
        }
        self.contains_re = contains_re
        tokenizer_config = tokenizer_dict[algorithm]
        self.tokenizer = tokenizer_config['class'].from_pretrained(
            tokenizer_config['pretrained_model'])
        # label2id is kept; the id2label direction is not needed here.
        self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
        self.add_special_ids = add_special_ids
        self.infer_mode = infer_mode
        self.ocr_engine = ocr_engine
        self.use_textline_bbox_info = use_textline_bbox_info
        # Only top-to-bottom y-then-x reading order is supported.
        self.order_method = order_method
        assert self.order_method in [None, "tb-yx"]

    def split_bbox(self, bbox, text, tokenizer):
        # Split a text-line bbox into per-token boxes: each word gets a
        # slice of the line proportional to its character count, and that
        # slice is repeated once per sub-token the tokenizer produces.
        words = text.split()
        token_bboxes = []
        curr_word_idx = 0
        x1, y1, x2, y2 = bbox
        # Width of a single character, assuming uniform spacing.
        unit_w = (x2 - x1) / len(text)
        for idx, word in enumerate(words):
            curr_w = len(word) * unit_w
            word_bbox = [x1, y1, x1 + curr_w, y2]
            token_bboxes.extend([word_bbox] * len(tokenizer.tokenize(word)))
            # Advance past the word plus one separating space.
            x1 += (len(word) + 1) * unit_w
        return token_bboxes

    def filter_empty_contents(self, ocr_info):
        """
        find out the empty texts and remove the links
        """
        new_ocr_info = []
        empty_index = []
        for idx, info in enumerate(ocr_info):
            if len(info["transcription"]) > 0:
                new_ocr_info.append(copy.deepcopy(info))
            else:
                empty_index.append(info["id"])
        # Drop any link that references a removed (empty) entity.
        for idx, info in enumerate(new_ocr_info):
            new_link = []
            for link in info["linking"]:
                if link[0] in empty_index or link[1] in empty_index:
                    continue
                new_link.append(link)
            new_ocr_info[idx]["linking"] = new_link
        return new_ocr_info

    def __call__(self, data):
        # load bbox and label info
        ocr_info = self._load_ocr_info(data)

        # Backfill "bbox" from the polygon "points" where it is missing.
        for idx in range(len(ocr_info)):
            if "bbox" not in ocr_info[idx]:
                ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx][
                    "points"])

        if self.order_method == "tb-yx":
            ocr_info = order_by_tbyx(ocr_info)

        # for re
        train_re = self.contains_re and not self.infer_mode
        if train_re:
            ocr_info = self.filter_empty_contents(ocr_info)

        height, width, _ = data['image'].shape

        words_list = []
        bbox_list = []
        input_ids_list = []
        token_type_ids_list = []
        segment_offset_id = []
        gt_label_list = []

        entities = []

        # RE-only accumulators; only defined when training with relations.
        if train_re:
            relations = []
            id2label = {}
            entity_id_to_index_map = {}
            empty_entity = set()

        data['ocr_info'] = copy.deepcopy(ocr_info)

        for info in ocr_info:
            text = info["transcription"]
            if len(text) <= 0:
                continue
            if train_re:
                # for re
                if len(text) == 0:
                    empty_entity.add(info["id"])
                    continue
                id2label[info["id"]] = info["label"]
                # Links are stored as sorted id pairs so duplicates collapse.
                relations.extend([tuple(sorted(l)) for l in info["linking"]])
            # smooth_box
            info["bbox"] = self.trans_poly_to_bbox(info["points"])

            encode_res = self.tokenizer.encode(
                text,
                pad_to_max_seq_len=False,
                return_attention_mask=True,
                return_token_type_ids=True)

            if not self.add_special_ids:
                # TODO: use tok.all_special_ids to remove
                # Strip the leading/trailing special tokens the tokenizer adds.
                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
                encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
                                                                            -1]
                encode_res["attention_mask"] = encode_res["attention_mask"][1:
                                                                            -1]

            if self.use_textline_bbox_info:
                # One shared line-level box per token.
                bbox = [info["bbox"]] * len(encode_res["input_ids"])
            else:
                # Per-word boxes estimated from character widths.
                bbox = self.split_bbox(info["bbox"], info["transcription"],
                                       self.tokenizer)
            if len(bbox) <= 0:
                continue
            bbox = self._smooth_box(bbox, height, width)
            if self.add_special_ids:
                # Dummy boxes for the kept special tokens.
                bbox.insert(0, [0, 0, 0, 0])
                bbox.append([0, 0, 0, 0])

            # parse label
            if not self.infer_mode:
                label = info['label']
                gt_label = self._parse_label(label, encode_res)

            # construct entities for re
            if train_re:
                if gt_label[0] != self.label2id_map["O"]:
                    entity_id_to_index_map[info["id"]] = len(entities)
                    label = label.upper()
                    entities.append({
                        "start": len(input_ids_list),
                        "end":
                        len(input_ids_list) + len(encode_res["input_ids"]),
                        "label": label.upper(),
                    })
            else:
                entities.append({
                    "start": len(input_ids_list),
                    "end": len(input_ids_list) + len(encode_res["input_ids"]),
                    "label": 'O',
                })
            input_ids_list.extend(encode_res["input_ids"])
            token_type_ids_list.extend(encode_res["token_type_ids"])
            bbox_list.extend(bbox)
            words_list.append(text)
            # Offsets are cumulative token counts, recorded per text line.
            segment_offset_id.append(len(input_ids_list))
            if not self.infer_mode:
                gt_label_list.extend(gt_label)

        data['input_ids'] = input_ids_list
        data['token_type_ids'] = token_type_ids_list
        data['bbox'] = bbox_list
        data['attention_mask'] = [1] * len(input_ids_list)
        data['labels'] = gt_label_list
        data['segment_offset_id'] = segment_offset_id
        data['tokenizer_params'] = dict(
            padding_side=self.tokenizer.padding_side,
            pad_token_type_id=self.tokenizer.pad_token_type_id,
            pad_token_id=self.tokenizer.pad_token_id)
        data['entities'] = entities

        if train_re:
            data['relations'] = relations
            data['id2label'] = id2label
            data['empty_entity'] = empty_entity
            data['entity_id_to_index_map'] = entity_id_to_index_map
        return data

    def trans_poly_to_bbox(self, poly):
        # Axis-aligned bounding box of a polygon, as integer [x1, y1, x2, y2].
        x1 = int(np.min([p[0] for p in poly]))
        x2 = int(np.max([p[0] for p in poly]))
        y1 = int(np.min([p[1] for p in poly]))
        y2 = int(np.max([p[1] for p in poly]))
        return [x1, y1, x2, y2]

    def _load_ocr_info(self, data):
        # In inference mode, run the OCR engine on the image; otherwise the
        # annotation JSON in data['label'] already holds the info list.
        if self.infer_mode:
            ocr_result = self.ocr_engine.ocr(data['image'], cls=False)[0]
            ocr_info = []
            for res in ocr_result:
                ocr_info.append({
                    "transcription": res[1][0],
                    "bbox": self.trans_poly_to_bbox(res[0]),
                    "points": res[0],
                })
            return ocr_info
        else:
            info = data['label']
            # read text info
            info_dict = json.loads(info)
            return info_dict

    def _smooth_box(self, bboxes, height, width):
        # Rescale pixel boxes to the 0-1000 grid used by LayoutLM-style
        # models, then truncate to int64.
        bboxes = np.array(bboxes)
        bboxes[:, 0] = bboxes[:, 0] * 1000 / width
        bboxes[:, 2] = bboxes[:, 2] * 1000 / width
        bboxes[:, 1] = bboxes[:, 1] * 1000 / height
        bboxes[:, 3] = bboxes[:, 3] * 1000 / height
        bboxes = bboxes.astype("int64").tolist()
        return bboxes

    def _parse_label(self, label, encode_res):
        # BIO tagging: first token gets B-<label>, the rest I-<label>;
        # "other"-style labels map every token to class 0.
        gt_label = []
        if label.lower() in ["other", "others", "ignore"]:
            gt_label.extend([0] * len(encode_res["input_ids"]))
        else:
            gt_label.append(self.label2id_map[("b-" + label).upper()])
            gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
                            (len(encode_res["input_ids"]) - 1))
        return gt_label
  1024. class MultiLabelEncode(BaseRecLabelEncode):
  1025. def __init__(self,
  1026. max_text_length,
  1027. character_dict_path=None,
  1028. use_space_char=False,
  1029. **kwargs):
  1030. super(MultiLabelEncode, self).__init__(
  1031. max_text_length, character_dict_path, use_space_char)
  1032. self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path,
  1033. use_space_char, **kwargs)
  1034. self.sar_encode = SARLabelEncode(max_text_length, character_dict_path,
  1035. use_space_char, **kwargs)
  1036. def __call__(self, data):
  1037. data_ctc = copy.deepcopy(data)
  1038. data_sar = copy.deepcopy(data)
  1039. data_out = dict()
  1040. data_out['img_path'] = data.get('img_path', None)
  1041. data_out['image'] = data['image']
  1042. ctc = self.ctc_encode.__call__(data_ctc)
  1043. sar = self.sar_encode.__call__(data_sar)
  1044. if ctc is None or sar is None:
  1045. return None
  1046. data_out['label_ctc'] = ctc['label']
  1047. data_out['label_sar'] = sar['label']
  1048. data_out['length'] = ctc['length']
  1049. return data_out
  1050. class NRTRLabelEncode(BaseRecLabelEncode):
  1051. """ Convert between text-label and text-index """
  1052. def __init__(self,
  1053. max_text_length,
  1054. character_dict_path=None,
  1055. use_space_char=False,
  1056. **kwargs):
  1057. super(NRTRLabelEncode, self).__init__(
  1058. max_text_length, character_dict_path, use_space_char)
  1059. def __call__(self, data):
  1060. text = data['label']
  1061. text = self.encode(text)
  1062. if text is None:
  1063. return None
  1064. if len(text) >= self.max_text_len - 1:
  1065. return None
  1066. data['length'] = np.array(len(text))
  1067. text.insert(0, 2)
  1068. text.append(3)
  1069. text = text + [0] * (self.max_text_len - len(text))
  1070. data['label'] = np.array(text)
  1071. return data
  1072. def add_special_char(self, dict_character):
  1073. dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
  1074. return dict_character
  1075. class ViTSTRLabelEncode(BaseRecLabelEncode):
  1076. """ Convert between text-label and text-index """
  1077. def __init__(self,
  1078. max_text_length,
  1079. character_dict_path=None,
  1080. use_space_char=False,
  1081. ignore_index=0,
  1082. **kwargs):
  1083. super(ViTSTRLabelEncode, self).__init__(
  1084. max_text_length, character_dict_path, use_space_char)
  1085. self.ignore_index = ignore_index
  1086. def __call__(self, data):
  1087. text = data['label']
  1088. text = self.encode(text)
  1089. if text is None:
  1090. return None
  1091. if len(text) >= self.max_text_len:
  1092. return None
  1093. data['length'] = np.array(len(text))
  1094. text.insert(0, self.ignore_index)
  1095. text.append(1)
  1096. text = text + [self.ignore_index] * (self.max_text_len + 2 - len(text))
  1097. data['label'] = np.array(text)
  1098. return data
  1099. def add_special_char(self, dict_character):
  1100. dict_character = ['<s>', '</s>'] + dict_character
  1101. return dict_character
  1102. class ABINetLabelEncode(BaseRecLabelEncode):
  1103. """ Convert between text-label and text-index """
  1104. def __init__(self,
  1105. max_text_length,
  1106. character_dict_path=None,
  1107. use_space_char=False,
  1108. ignore_index=100,
  1109. **kwargs):
  1110. super(ABINetLabelEncode, self).__init__(
  1111. max_text_length, character_dict_path, use_space_char)
  1112. self.ignore_index = ignore_index
  1113. def __call__(self, data):
  1114. text = data['label']
  1115. text = self.encode(text)
  1116. if text is None:
  1117. return None
  1118. if len(text) >= self.max_text_len:
  1119. return None
  1120. data['length'] = np.array(len(text))
  1121. text.append(0)
  1122. text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
  1123. data['label'] = np.array(text)
  1124. return data
  1125. def add_special_char(self, dict_character):
  1126. dict_character = ['</s>'] + dict_character
  1127. return dict_character
  1128. class SRLabelEncode(BaseRecLabelEncode):
  1129. def __init__(self,
  1130. max_text_length,
  1131. character_dict_path=None,
  1132. use_space_char=False,
  1133. **kwargs):
  1134. super(SRLabelEncode, self).__init__(max_text_length,
  1135. character_dict_path, use_space_char)
  1136. self.dic = {}
  1137. with open(character_dict_path, 'r') as fin:
  1138. for line in fin.readlines():
  1139. line = line.strip()
  1140. character, sequence = line.split()
  1141. self.dic[character] = sequence
  1142. english_stroke_alphabet = '0123456789'
  1143. self.english_stroke_dict = {}
  1144. for index in range(len(english_stroke_alphabet)):
  1145. self.english_stroke_dict[english_stroke_alphabet[index]] = index
  1146. def encode(self, label):
  1147. stroke_sequence = ''
  1148. for character in label:
  1149. if character not in self.dic:
  1150. continue
  1151. else:
  1152. stroke_sequence += self.dic[character]
  1153. stroke_sequence += '0'
  1154. label = stroke_sequence
  1155. length = len(label)
  1156. input_tensor = np.zeros(self.max_text_len).astype("int64")
  1157. for j in range(length - 1):
  1158. input_tensor[j + 1] = self.english_stroke_dict[label[j]]
  1159. return length, input_tensor
  1160. def __call__(self, data):
  1161. text = data['label']
  1162. length, input_tensor = self.encode(text)
  1163. data["length"] = length
  1164. data["input_tensor"] = input_tensor
  1165. if text is None:
  1166. return None
  1167. return data
  1168. class SPINLabelEncode(AttnLabelEncode):
  1169. """ Convert between text-label and text-index """
  1170. def __init__(self,
  1171. max_text_length,
  1172. character_dict_path=None,
  1173. use_space_char=False,
  1174. lower=True,
  1175. **kwargs):
  1176. super(SPINLabelEncode, self).__init__(
  1177. max_text_length, character_dict_path, use_space_char)
  1178. self.lower = lower
  1179. def add_special_char(self, dict_character):
  1180. self.beg_str = "sos"
  1181. self.end_str = "eos"
  1182. dict_character = [self.beg_str] + [self.end_str] + dict_character
  1183. return dict_character
  1184. def __call__(self, data):
  1185. text = data['label']
  1186. text = self.encode(text)
  1187. if text is None:
  1188. return None
  1189. if len(text) > self.max_text_len:
  1190. return None
  1191. data['length'] = np.array(len(text))
  1192. target = [0] + text + [1]
  1193. padded_text = [0 for _ in range(self.max_text_len + 2)]
  1194. padded_text[:len(target)] = target
  1195. data['label'] = np.array(padded_text)
  1196. return data
  1197. class VLLabelEncode(BaseRecLabelEncode):
  1198. """ Convert between text-label and text-index """
  1199. def __init__(self,
  1200. max_text_length,
  1201. character_dict_path=None,
  1202. use_space_char=False,
  1203. **kwargs):
  1204. super(VLLabelEncode, self).__init__(max_text_length,
  1205. character_dict_path, use_space_char)
  1206. self.dict = {}
  1207. for i, char in enumerate(self.character):
  1208. self.dict[char] = i
  1209. def __call__(self, data):
  1210. text = data['label'] # original string
  1211. # generate occluded text
  1212. len_str = len(text)
  1213. if len_str <= 0:
  1214. return None
  1215. change_num = 1
  1216. order = list(range(len_str))
  1217. change_id = sample(order, change_num)[0]
  1218. label_sub = text[change_id]
  1219. if change_id == (len_str - 1):
  1220. label_res = text[:change_id]
  1221. elif change_id == 0:
  1222. label_res = text[1:]
  1223. else:
  1224. label_res = text[:change_id] + text[change_id + 1:]
  1225. data['label_res'] = label_res # remaining string
  1226. data['label_sub'] = label_sub # occluded character
  1227. data['label_id'] = change_id # character index
  1228. # encode label
  1229. text = self.encode(text)
  1230. if text is None:
  1231. return None
  1232. text = [i + 1 for i in text]
  1233. data['length'] = np.array(len(text))
  1234. text = text + [0] * (self.max_text_len - len(text))
  1235. data['label'] = np.array(text)
  1236. label_res = self.encode(label_res)
  1237. label_sub = self.encode(label_sub)
  1238. if label_res is None:
  1239. label_res = []
  1240. else:
  1241. label_res = [i + 1 for i in label_res]
  1242. if label_sub is None:
  1243. label_sub = []
  1244. else:
  1245. label_sub = [i + 1 for i in label_sub]
  1246. data['length_res'] = np.array(len(label_res))
  1247. data['length_sub'] = np.array(len(label_sub))
  1248. label_res = label_res + [0] * (self.max_text_len - len(label_res))
  1249. label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
  1250. data['label_res'] = np.array(label_res)
  1251. data['label_sub'] = np.array(label_sub)
  1252. return data
  1253. class CTLabelEncode(object):
  1254. def __init__(self, **kwargs):
  1255. pass
  1256. def __call__(self, data):
  1257. label = data['label']
  1258. label = json.loads(label)
  1259. nBox = len(label)
  1260. boxes, txts = [], []
  1261. for bno in range(0, nBox):
  1262. box = label[bno]['points']
  1263. box = np.array(box)
  1264. boxes.append(box)
  1265. txt = label[bno]['transcription']
  1266. txts.append(txt)
  1267. if len(boxes) == 0:
  1268. return None
  1269. data['polys'] = boxes
  1270. data['texts'] = txts
  1271. return data
  1272. class CANLabelEncode(BaseRecLabelEncode):
  1273. def __init__(self,
  1274. character_dict_path,
  1275. max_text_length=100,
  1276. use_space_char=False,
  1277. lower=True,
  1278. **kwargs):
  1279. super(CANLabelEncode, self).__init__(
  1280. max_text_length, character_dict_path, use_space_char, lower)
  1281. def encode(self, text_seq):
  1282. text_seq_encoded = []
  1283. for text in text_seq:
  1284. if text not in self.character:
  1285. continue
  1286. text_seq_encoded.append(self.dict.get(text))
  1287. if len(text_seq_encoded) == 0:
  1288. return None
  1289. return text_seq_encoded
  1290. def __call__(self, data):
  1291. label = data['label']
  1292. if isinstance(label, str):
  1293. label = label.strip().split()
  1294. label.append(self.end_str)
  1295. data['label'] = self.encode(label)
  1296. return data