rec_postprocess.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. from paddle.nn import functional as F
  17. import re
  18. class BaseRecLabelDecode(object):
  19. """ Convert between text-label and text-index """
  20. def __init__(self, character_dict_path=None, use_space_char=False):
  21. self.beg_str = "sos"
  22. self.end_str = "eos"
  23. self.reverse = False
  24. self.character_str = []
  25. if character_dict_path is None:
  26. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  27. dict_character = list(self.character_str)
  28. else:
  29. with open(character_dict_path, "rb") as fin:
  30. lines = fin.readlines()
  31. for line in lines:
  32. line = line.decode('utf-8').strip("\n").strip("\r\n")
  33. self.character_str.append(line)
  34. if use_space_char:
  35. self.character_str.append(" ")
  36. dict_character = list(self.character_str)
  37. if 'arabic' in character_dict_path:
  38. self.reverse = True
  39. dict_character = self.add_special_char(dict_character)
  40. self.dict = {}
  41. for i, char in enumerate(dict_character):
  42. self.dict[char] = i
  43. self.character = dict_character
  44. def pred_reverse(self, pred):
  45. pred_re = []
  46. c_current = ''
  47. for c in pred:
  48. if not bool(re.search('[a-zA-Z0-9 :*./%+-]', c)):
  49. if c_current != '':
  50. pred_re.append(c_current)
  51. pred_re.append(c)
  52. c_current = ''
  53. else:
  54. c_current += c
  55. if c_current != '':
  56. pred_re.append(c_current)
  57. return ''.join(pred_re[::-1])
  58. def add_special_char(self, dict_character):
  59. return dict_character
  60. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  61. """ convert text-index into text-label. """
  62. result_list = []
  63. ignored_tokens = self.get_ignored_tokens()
  64. batch_size = len(text_index)
  65. for batch_idx in range(batch_size):
  66. selection = np.ones(len(text_index[batch_idx]), dtype=bool)
  67. if is_remove_duplicate:
  68. selection[1:] = text_index[batch_idx][1:] != text_index[
  69. batch_idx][:-1]
  70. for ignored_token in ignored_tokens:
  71. selection &= text_index[batch_idx] != ignored_token
  72. char_list = [
  73. self.character[text_id]
  74. for text_id in text_index[batch_idx][selection]
  75. ]
  76. if text_prob is not None:
  77. conf_list = text_prob[batch_idx][selection]
  78. else:
  79. conf_list = [1] * len(selection)
  80. if len(conf_list) == 0:
  81. conf_list = [0]
  82. text = ''.join(char_list)
  83. if self.reverse: # for arabic rec
  84. text = self.pred_reverse(text)
  85. result_list.append((text, np.mean(conf_list).tolist()))
  86. return result_list
  87. def get_ignored_tokens(self):
  88. return [0] # for ctc blank
  89. class CTCLabelDecode(BaseRecLabelDecode):
  90. """ Convert between text-label and text-index """
  91. def __init__(self, character_dict_path=None, use_space_char=False,
  92. **kwargs):
  93. super(CTCLabelDecode, self).__init__(character_dict_path,
  94. use_space_char)
  95. def __call__(self, preds, label=None, *args, **kwargs):
  96. if isinstance(preds, tuple) or isinstance(preds, list):
  97. preds = preds[-1]
  98. if isinstance(preds, paddle.Tensor):
  99. preds = preds.numpy()
  100. preds_idx = preds.argmax(axis=2)
  101. preds_prob = preds.max(axis=2)
  102. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
  103. if label is None:
  104. return text
  105. label = self.decode(label)
  106. return text, label
  107. def add_special_char(self, dict_character):
  108. dict_character = ['blank'] + dict_character
  109. return dict_character
  110. class DistillationCTCLabelDecode(CTCLabelDecode):
  111. """
  112. Convert
  113. Convert between text-label and text-index
  114. """
  115. def __init__(self,
  116. character_dict_path=None,
  117. use_space_char=False,
  118. model_name=["student"],
  119. key=None,
  120. multi_head=False,
  121. **kwargs):
  122. super(DistillationCTCLabelDecode, self).__init__(character_dict_path,
  123. use_space_char)
  124. if not isinstance(model_name, list):
  125. model_name = [model_name]
  126. self.model_name = model_name
  127. self.key = key
  128. self.multi_head = multi_head
  129. def __call__(self, preds, label=None, *args, **kwargs):
  130. output = dict()
  131. for name in self.model_name:
  132. pred = preds[name]
  133. if self.key is not None:
  134. pred = pred[self.key]
  135. if self.multi_head and isinstance(pred, dict):
  136. pred = pred['ctc']
  137. output[name] = super().__call__(pred, label=label, *args, **kwargs)
  138. return output
  139. class AttnLabelDecode(BaseRecLabelDecode):
  140. """ Convert between text-label and text-index """
  141. def __init__(self, character_dict_path=None, use_space_char=False,
  142. **kwargs):
  143. super(AttnLabelDecode, self).__init__(character_dict_path,
  144. use_space_char)
  145. def add_special_char(self, dict_character):
  146. self.beg_str = "sos"
  147. self.end_str = "eos"
  148. dict_character = dict_character
  149. dict_character = [self.beg_str] + dict_character + [self.end_str]
  150. return dict_character
  151. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  152. """ convert text-index into text-label. """
  153. result_list = []
  154. ignored_tokens = self.get_ignored_tokens()
  155. [beg_idx, end_idx] = self.get_ignored_tokens()
  156. batch_size = len(text_index)
  157. for batch_idx in range(batch_size):
  158. char_list = []
  159. conf_list = []
  160. for idx in range(len(text_index[batch_idx])):
  161. if text_index[batch_idx][idx] in ignored_tokens:
  162. continue
  163. if int(text_index[batch_idx][idx]) == int(end_idx):
  164. break
  165. if is_remove_duplicate:
  166. # only for predict
  167. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  168. batch_idx][idx]:
  169. continue
  170. char_list.append(self.character[int(text_index[batch_idx][
  171. idx])])
  172. if text_prob is not None:
  173. conf_list.append(text_prob[batch_idx][idx])
  174. else:
  175. conf_list.append(1)
  176. text = ''.join(char_list)
  177. result_list.append((text, np.mean(conf_list).tolist()))
  178. return result_list
  179. def __call__(self, preds, label=None, *args, **kwargs):
  180. """
  181. text = self.decode(text)
  182. if label is None:
  183. return text
  184. else:
  185. label = self.decode(label, is_remove_duplicate=False)
  186. return text, label
  187. """
  188. if isinstance(preds, paddle.Tensor):
  189. preds = preds.numpy()
  190. preds_idx = preds.argmax(axis=2)
  191. preds_prob = preds.max(axis=2)
  192. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  193. if label is None:
  194. return text
  195. label = self.decode(label, is_remove_duplicate=False)
  196. return text, label
  197. def get_ignored_tokens(self):
  198. beg_idx = self.get_beg_end_flag_idx("beg")
  199. end_idx = self.get_beg_end_flag_idx("end")
  200. return [beg_idx, end_idx]
  201. def get_beg_end_flag_idx(self, beg_or_end):
  202. if beg_or_end == "beg":
  203. idx = np.array(self.dict[self.beg_str])
  204. elif beg_or_end == "end":
  205. idx = np.array(self.dict[self.end_str])
  206. else:
  207. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  208. % beg_or_end
  209. return idx
  210. class RFLLabelDecode(BaseRecLabelDecode):
  211. """ Convert between text-label and text-index """
  212. def __init__(self, character_dict_path=None, use_space_char=False,
  213. **kwargs):
  214. super(RFLLabelDecode, self).__init__(character_dict_path,
  215. use_space_char)
  216. def add_special_char(self, dict_character):
  217. self.beg_str = "sos"
  218. self.end_str = "eos"
  219. dict_character = dict_character
  220. dict_character = [self.beg_str] + dict_character + [self.end_str]
  221. return dict_character
  222. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  223. """ convert text-index into text-label. """
  224. result_list = []
  225. ignored_tokens = self.get_ignored_tokens()
  226. [beg_idx, end_idx] = self.get_ignored_tokens()
  227. batch_size = len(text_index)
  228. for batch_idx in range(batch_size):
  229. char_list = []
  230. conf_list = []
  231. for idx in range(len(text_index[batch_idx])):
  232. if text_index[batch_idx][idx] in ignored_tokens:
  233. continue
  234. if int(text_index[batch_idx][idx]) == int(end_idx):
  235. break
  236. if is_remove_duplicate:
  237. # only for predict
  238. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  239. batch_idx][idx]:
  240. continue
  241. char_list.append(self.character[int(text_index[batch_idx][
  242. idx])])
  243. if text_prob is not None:
  244. conf_list.append(text_prob[batch_idx][idx])
  245. else:
  246. conf_list.append(1)
  247. text = ''.join(char_list)
  248. result_list.append((text, np.mean(conf_list).tolist()))
  249. return result_list
  250. def __call__(self, preds, label=None, *args, **kwargs):
  251. # if seq_outputs is not None:
  252. if isinstance(preds, tuple) or isinstance(preds, list):
  253. cnt_outputs, seq_outputs = preds
  254. if isinstance(seq_outputs, paddle.Tensor):
  255. seq_outputs = seq_outputs.numpy()
  256. preds_idx = seq_outputs.argmax(axis=2)
  257. preds_prob = seq_outputs.max(axis=2)
  258. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  259. if label is None:
  260. return text
  261. label = self.decode(label, is_remove_duplicate=False)
  262. return text, label
  263. else:
  264. cnt_outputs = preds
  265. if isinstance(cnt_outputs, paddle.Tensor):
  266. cnt_outputs = cnt_outputs.numpy()
  267. cnt_length = []
  268. for lens in cnt_outputs:
  269. length = round(np.sum(lens))
  270. cnt_length.append(length)
  271. if label is None:
  272. return cnt_length
  273. label = self.decode(label, is_remove_duplicate=False)
  274. length = [len(res[0]) for res in label]
  275. return cnt_length, length
  276. def get_ignored_tokens(self):
  277. beg_idx = self.get_beg_end_flag_idx("beg")
  278. end_idx = self.get_beg_end_flag_idx("end")
  279. return [beg_idx, end_idx]
  280. def get_beg_end_flag_idx(self, beg_or_end):
  281. if beg_or_end == "beg":
  282. idx = np.array(self.dict[self.beg_str])
  283. elif beg_or_end == "end":
  284. idx = np.array(self.dict[self.end_str])
  285. else:
  286. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  287. % beg_or_end
  288. return idx
  289. class SEEDLabelDecode(BaseRecLabelDecode):
  290. """ Convert between text-label and text-index """
  291. def __init__(self, character_dict_path=None, use_space_char=False,
  292. **kwargs):
  293. super(SEEDLabelDecode, self).__init__(character_dict_path,
  294. use_space_char)
  295. def add_special_char(self, dict_character):
  296. self.padding_str = "padding"
  297. self.end_str = "eos"
  298. self.unknown = "unknown"
  299. dict_character = dict_character + [
  300. self.end_str, self.padding_str, self.unknown
  301. ]
  302. return dict_character
  303. def get_ignored_tokens(self):
  304. end_idx = self.get_beg_end_flag_idx("eos")
  305. return [end_idx]
  306. def get_beg_end_flag_idx(self, beg_or_end):
  307. if beg_or_end == "sos":
  308. idx = np.array(self.dict[self.beg_str])
  309. elif beg_or_end == "eos":
  310. idx = np.array(self.dict[self.end_str])
  311. else:
  312. assert False, "unsupport type %s in get_beg_end_flag_idx" % beg_or_end
  313. return idx
  314. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  315. """ convert text-index into text-label. """
  316. result_list = []
  317. [end_idx] = self.get_ignored_tokens()
  318. batch_size = len(text_index)
  319. for batch_idx in range(batch_size):
  320. char_list = []
  321. conf_list = []
  322. for idx in range(len(text_index[batch_idx])):
  323. if int(text_index[batch_idx][idx]) == int(end_idx):
  324. break
  325. if is_remove_duplicate:
  326. # only for predict
  327. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  328. batch_idx][idx]:
  329. continue
  330. char_list.append(self.character[int(text_index[batch_idx][
  331. idx])])
  332. if text_prob is not None:
  333. conf_list.append(text_prob[batch_idx][idx])
  334. else:
  335. conf_list.append(1)
  336. text = ''.join(char_list)
  337. result_list.append((text, np.mean(conf_list).tolist()))
  338. return result_list
  339. def __call__(self, preds, label=None, *args, **kwargs):
  340. """
  341. text = self.decode(text)
  342. if label is None:
  343. return text
  344. else:
  345. label = self.decode(label, is_remove_duplicate=False)
  346. return text, label
  347. """
  348. preds_idx = preds["rec_pred"]
  349. if isinstance(preds_idx, paddle.Tensor):
  350. preds_idx = preds_idx.numpy()
  351. if "rec_pred_scores" in preds:
  352. preds_idx = preds["rec_pred"]
  353. preds_prob = preds["rec_pred_scores"]
  354. else:
  355. preds_idx = preds["rec_pred"].argmax(axis=2)
  356. preds_prob = preds["rec_pred"].max(axis=2)
  357. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  358. if label is None:
  359. return text
  360. label = self.decode(label, is_remove_duplicate=False)
  361. return text, label
  362. class SRNLabelDecode(BaseRecLabelDecode):
  363. """ Convert between text-label and text-index """
  364. def __init__(self, character_dict_path=None, use_space_char=False,
  365. **kwargs):
  366. super(SRNLabelDecode, self).__init__(character_dict_path,
  367. use_space_char)
  368. self.max_text_length = kwargs.get('max_text_length', 25)
  369. def __call__(self, preds, label=None, *args, **kwargs):
  370. pred = preds['predict']
  371. char_num = len(self.character_str) + 2
  372. if isinstance(pred, paddle.Tensor):
  373. pred = pred.numpy()
  374. pred = np.reshape(pred, [-1, char_num])
  375. preds_idx = np.argmax(pred, axis=1)
  376. preds_prob = np.max(pred, axis=1)
  377. preds_idx = np.reshape(preds_idx, [-1, self.max_text_length])
  378. preds_prob = np.reshape(preds_prob, [-1, self.max_text_length])
  379. text = self.decode(preds_idx, preds_prob)
  380. if label is None:
  381. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  382. return text
  383. label = self.decode(label)
  384. return text, label
  385. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  386. """ convert text-index into text-label. """
  387. result_list = []
  388. ignored_tokens = self.get_ignored_tokens()
  389. batch_size = len(text_index)
  390. for batch_idx in range(batch_size):
  391. char_list = []
  392. conf_list = []
  393. for idx in range(len(text_index[batch_idx])):
  394. if text_index[batch_idx][idx] in ignored_tokens:
  395. continue
  396. if is_remove_duplicate:
  397. # only for predict
  398. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  399. batch_idx][idx]:
  400. continue
  401. char_list.append(self.character[int(text_index[batch_idx][
  402. idx])])
  403. if text_prob is not None:
  404. conf_list.append(text_prob[batch_idx][idx])
  405. else:
  406. conf_list.append(1)
  407. text = ''.join(char_list)
  408. result_list.append((text, np.mean(conf_list).tolist()))
  409. return result_list
  410. def add_special_char(self, dict_character):
  411. dict_character = dict_character + [self.beg_str, self.end_str]
  412. return dict_character
  413. def get_ignored_tokens(self):
  414. beg_idx = self.get_beg_end_flag_idx("beg")
  415. end_idx = self.get_beg_end_flag_idx("end")
  416. return [beg_idx, end_idx]
  417. def get_beg_end_flag_idx(self, beg_or_end):
  418. if beg_or_end == "beg":
  419. idx = np.array(self.dict[self.beg_str])
  420. elif beg_or_end == "end":
  421. idx = np.array(self.dict[self.end_str])
  422. else:
  423. assert False, "unsupport type %s in get_beg_end_flag_idx" \
  424. % beg_or_end
  425. return idx
  426. class SARLabelDecode(BaseRecLabelDecode):
  427. """ Convert between text-label and text-index """
  428. def __init__(self, character_dict_path=None, use_space_char=False,
  429. **kwargs):
  430. super(SARLabelDecode, self).__init__(character_dict_path,
  431. use_space_char)
  432. self.rm_symbol = kwargs.get('rm_symbol', False)
  433. def add_special_char(self, dict_character):
  434. beg_end_str = "<BOS/EOS>"
  435. unknown_str = "<UKN>"
  436. padding_str = "<PAD>"
  437. dict_character = dict_character + [unknown_str]
  438. self.unknown_idx = len(dict_character) - 1
  439. dict_character = dict_character + [beg_end_str]
  440. self.start_idx = len(dict_character) - 1
  441. self.end_idx = len(dict_character) - 1
  442. dict_character = dict_character + [padding_str]
  443. self.padding_idx = len(dict_character) - 1
  444. return dict_character
  445. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  446. """ convert text-index into text-label. """
  447. result_list = []
  448. ignored_tokens = self.get_ignored_tokens()
  449. batch_size = len(text_index)
  450. for batch_idx in range(batch_size):
  451. char_list = []
  452. conf_list = []
  453. for idx in range(len(text_index[batch_idx])):
  454. if text_index[batch_idx][idx] in ignored_tokens:
  455. continue
  456. if int(text_index[batch_idx][idx]) == int(self.end_idx):
  457. if text_prob is None and idx == 0:
  458. continue
  459. else:
  460. break
  461. if is_remove_duplicate:
  462. # only for predict
  463. if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
  464. batch_idx][idx]:
  465. continue
  466. char_list.append(self.character[int(text_index[batch_idx][
  467. idx])])
  468. if text_prob is not None:
  469. conf_list.append(text_prob[batch_idx][idx])
  470. else:
  471. conf_list.append(1)
  472. text = ''.join(char_list)
  473. if self.rm_symbol:
  474. comp = re.compile('[^A-Z^a-z^0-9^\u4e00-\u9fa5]')
  475. text = text.lower()
  476. text = comp.sub('', text)
  477. result_list.append((text, np.mean(conf_list).tolist()))
  478. return result_list
  479. def __call__(self, preds, label=None, *args, **kwargs):
  480. if isinstance(preds, paddle.Tensor):
  481. preds = preds.numpy()
  482. preds_idx = preds.argmax(axis=2)
  483. preds_prob = preds.max(axis=2)
  484. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  485. if label is None:
  486. return text
  487. label = self.decode(label, is_remove_duplicate=False)
  488. return text, label
  489. def get_ignored_tokens(self):
  490. return [self.padding_idx]
  491. class DistillationSARLabelDecode(SARLabelDecode):
  492. """
  493. Convert
  494. Convert between text-label and text-index
  495. """
  496. def __init__(self,
  497. character_dict_path=None,
  498. use_space_char=False,
  499. model_name=["student"],
  500. key=None,
  501. multi_head=False,
  502. **kwargs):
  503. super(DistillationSARLabelDecode, self).__init__(character_dict_path,
  504. use_space_char)
  505. if not isinstance(model_name, list):
  506. model_name = [model_name]
  507. self.model_name = model_name
  508. self.key = key
  509. self.multi_head = multi_head
  510. def __call__(self, preds, label=None, *args, **kwargs):
  511. output = dict()
  512. for name in self.model_name:
  513. pred = preds[name]
  514. if self.key is not None:
  515. pred = pred[self.key]
  516. if self.multi_head and isinstance(pred, dict):
  517. pred = pred['sar']
  518. output[name] = super().__call__(pred, label=label, *args, **kwargs)
  519. return output
  520. class PRENLabelDecode(BaseRecLabelDecode):
  521. """ Convert between text-label and text-index """
  522. def __init__(self, character_dict_path=None, use_space_char=False,
  523. **kwargs):
  524. super(PRENLabelDecode, self).__init__(character_dict_path,
  525. use_space_char)
  526. def add_special_char(self, dict_character):
  527. padding_str = '<PAD>' # 0
  528. end_str = '<EOS>' # 1
  529. unknown_str = '<UNK>' # 2
  530. dict_character = [padding_str, end_str, unknown_str] + dict_character
  531. self.padding_idx = 0
  532. self.end_idx = 1
  533. self.unknown_idx = 2
  534. return dict_character
  535. def decode(self, text_index, text_prob=None):
  536. """ convert text-index into text-label. """
  537. result_list = []
  538. batch_size = len(text_index)
  539. for batch_idx in range(batch_size):
  540. char_list = []
  541. conf_list = []
  542. for idx in range(len(text_index[batch_idx])):
  543. if text_index[batch_idx][idx] == self.end_idx:
  544. break
  545. if text_index[batch_idx][idx] in \
  546. [self.padding_idx, self.unknown_idx]:
  547. continue
  548. char_list.append(self.character[int(text_index[batch_idx][
  549. idx])])
  550. if text_prob is not None:
  551. conf_list.append(text_prob[batch_idx][idx])
  552. else:
  553. conf_list.append(1)
  554. text = ''.join(char_list)
  555. if len(text) > 0:
  556. result_list.append((text, np.mean(conf_list).tolist()))
  557. else:
  558. # here confidence of empty recog result is 1
  559. result_list.append(('', 1))
  560. return result_list
  561. def __call__(self, preds, label=None, *args, **kwargs):
  562. if isinstance(preds, paddle.Tensor):
  563. preds = preds.numpy()
  564. preds_idx = preds.argmax(axis=2)
  565. preds_prob = preds.max(axis=2)
  566. text = self.decode(preds_idx, preds_prob)
  567. if label is None:
  568. return text
  569. label = self.decode(label)
  570. return text, label
  571. class NRTRLabelDecode(BaseRecLabelDecode):
  572. """ Convert between text-label and text-index """
  573. def __init__(self, character_dict_path=None, use_space_char=True, **kwargs):
  574. super(NRTRLabelDecode, self).__init__(character_dict_path,
  575. use_space_char)
  576. def __call__(self, preds, label=None, *args, **kwargs):
  577. if len(preds) == 2:
  578. preds_id = preds[0]
  579. preds_prob = preds[1]
  580. if isinstance(preds_id, paddle.Tensor):
  581. preds_id = preds_id.numpy()
  582. if isinstance(preds_prob, paddle.Tensor):
  583. preds_prob = preds_prob.numpy()
  584. if preds_id[0][0] == 2:
  585. preds_idx = preds_id[:, 1:]
  586. preds_prob = preds_prob[:, 1:]
  587. else:
  588. preds_idx = preds_id
  589. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  590. if label is None:
  591. return text
  592. label = self.decode(label[:, 1:])
  593. else:
  594. if isinstance(preds, paddle.Tensor):
  595. preds = preds.numpy()
  596. preds_idx = preds.argmax(axis=2)
  597. preds_prob = preds.max(axis=2)
  598. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  599. if label is None:
  600. return text
  601. label = self.decode(label[:, 1:])
  602. return text, label
  603. def add_special_char(self, dict_character):
  604. dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
  605. return dict_character
  606. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  607. """ convert text-index into text-label. """
  608. result_list = []
  609. batch_size = len(text_index)
  610. for batch_idx in range(batch_size):
  611. char_list = []
  612. conf_list = []
  613. for idx in range(len(text_index[batch_idx])):
  614. try:
  615. char_idx = self.character[int(text_index[batch_idx][idx])]
  616. except:
  617. continue
  618. if char_idx == '</s>': # end
  619. break
  620. char_list.append(char_idx)
  621. if text_prob is not None:
  622. conf_list.append(text_prob[batch_idx][idx])
  623. else:
  624. conf_list.append(1)
  625. text = ''.join(char_list)
  626. result_list.append((text.lower(), np.mean(conf_list).tolist()))
  627. return result_list
  628. class ViTSTRLabelDecode(NRTRLabelDecode):
  629. """ Convert between text-label and text-index """
  630. def __init__(self, character_dict_path=None, use_space_char=False,
  631. **kwargs):
  632. super(ViTSTRLabelDecode, self).__init__(character_dict_path,
  633. use_space_char)
  634. def __call__(self, preds, label=None, *args, **kwargs):
  635. if isinstance(preds, paddle.Tensor):
  636. preds = preds[:, 1:].numpy()
  637. else:
  638. preds = preds[:, 1:]
  639. preds_idx = preds.argmax(axis=2)
  640. preds_prob = preds.max(axis=2)
  641. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  642. if label is None:
  643. return text
  644. label = self.decode(label[:, 1:])
  645. return text, label
  646. def add_special_char(self, dict_character):
  647. dict_character = ['<s>', '</s>'] + dict_character
  648. return dict_character
  649. class ABINetLabelDecode(NRTRLabelDecode):
  650. """ Convert between text-label and text-index """
  651. def __init__(self, character_dict_path=None, use_space_char=False,
  652. **kwargs):
  653. super(ABINetLabelDecode, self).__init__(character_dict_path,
  654. use_space_char)
  655. def __call__(self, preds, label=None, *args, **kwargs):
  656. if isinstance(preds, dict):
  657. preds = preds['align'][-1].numpy()
  658. elif isinstance(preds, paddle.Tensor):
  659. preds = preds.numpy()
  660. else:
  661. preds = preds
  662. preds_idx = preds.argmax(axis=2)
  663. preds_prob = preds.max(axis=2)
  664. text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False)
  665. if label is None:
  666. return text
  667. label = self.decode(label)
  668. return text, label
  669. def add_special_char(self, dict_character):
  670. dict_character = ['</s>'] + dict_character
  671. return dict_character
  672. class SPINLabelDecode(AttnLabelDecode):
  673. """ Convert between text-label and text-index """
  674. def __init__(self, character_dict_path=None, use_space_char=False,
  675. **kwargs):
  676. super(SPINLabelDecode, self).__init__(character_dict_path,
  677. use_space_char)
  678. def add_special_char(self, dict_character):
  679. self.beg_str = "sos"
  680. self.end_str = "eos"
  681. dict_character = dict_character
  682. dict_character = [self.beg_str] + [self.end_str] + dict_character
  683. return dict_character
  684. class VLLabelDecode(BaseRecLabelDecode):
  685. """ Convert between text-label and text-index """
  686. def __init__(self, character_dict_path=None, use_space_char=False,
  687. **kwargs):
  688. super(VLLabelDecode, self).__init__(character_dict_path, use_space_char)
  689. self.max_text_length = kwargs.get('max_text_length', 25)
  690. self.nclass = len(self.character) + 1
  691. def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
  692. """ convert text-index into text-label. """
  693. result_list = []
  694. ignored_tokens = self.get_ignored_tokens()
  695. batch_size = len(text_index)
  696. for batch_idx in range(batch_size):
  697. selection = np.ones(len(text_index[batch_idx]), dtype=bool)
  698. if is_remove_duplicate:
  699. selection[1:] = text_index[batch_idx][1:] != text_index[
  700. batch_idx][:-1]
  701. for ignored_token in ignored_tokens:
  702. selection &= text_index[batch_idx] != ignored_token
  703. char_list = [
  704. self.character[text_id - 1]
  705. for text_id in text_index[batch_idx][selection]
  706. ]
  707. if text_prob is not None:
  708. conf_list = text_prob[batch_idx][selection]
  709. else:
  710. conf_list = [1] * len(selection)
  711. if len(conf_list) == 0:
  712. conf_list = [0]
  713. text = ''.join(char_list)
  714. result_list.append((text, np.mean(conf_list).tolist()))
  715. return result_list
  716. def __call__(self, preds, label=None, length=None, *args, **kwargs):
  717. if len(preds) == 2: # eval mode
  718. text_pre, x = preds
  719. b = text_pre.shape[1]
  720. lenText = self.max_text_length
  721. nsteps = self.max_text_length
  722. if not isinstance(text_pre, paddle.Tensor):
  723. text_pre = paddle.to_tensor(text_pre, dtype='float32')
  724. out_res = paddle.zeros(
  725. shape=[lenText, b, self.nclass], dtype=x.dtype)
  726. out_length = paddle.zeros(shape=[b], dtype=x.dtype)
  727. now_step = 0
  728. for _ in range(nsteps):
  729. if 0 in out_length and now_step < nsteps:
  730. tmp_result = text_pre[now_step, :, :]
  731. out_res[now_step] = tmp_result
  732. tmp_result = tmp_result.topk(1)[1].squeeze(axis=1)
  733. for j in range(b):
  734. if out_length[j] == 0 and tmp_result[j] == 0:
  735. out_length[j] = now_step + 1
  736. now_step += 1
  737. for j in range(0, b):
  738. if int(out_length[j]) == 0:
  739. out_length[j] = nsteps
  740. start = 0
  741. output = paddle.zeros(
  742. shape=[int(out_length.sum()), self.nclass], dtype=x.dtype)
  743. for i in range(0, b):
  744. cur_length = int(out_length[i])
  745. output[start:start + cur_length] = out_res[0:cur_length, i, :]
  746. start += cur_length
  747. net_out = output
  748. length = out_length
  749. else: # train mode
  750. net_out = preds[0]
  751. length = length
  752. net_out = paddle.concat([t[:l] for t, l in zip(net_out, length)])
  753. text = []
  754. if not isinstance(net_out, paddle.Tensor):
  755. net_out = paddle.to_tensor(net_out, dtype='float32')
  756. net_out = F.softmax(net_out, axis=1)
  757. for i in range(0, length.shape[0]):
  758. preds_idx = net_out[int(length[:i].sum()):int(length[:i].sum(
  759. ) + length[i])].topk(1)[1][:, 0].tolist()
  760. preds_text = ''.join([
  761. self.character[idx - 1]
  762. if idx > 0 and idx <= len(self.character) else ''
  763. for idx in preds_idx
  764. ])
  765. preds_prob = net_out[int(length[:i].sum()):int(length[:i].sum(
  766. ) + length[i])].topk(1)[0][:, 0]
  767. preds_prob = paddle.exp(
  768. paddle.log(preds_prob).sum() / (preds_prob.shape[0] + 1e-6))
  769. text.append((preds_text, preds_prob.numpy()[0]))
  770. if label is None:
  771. return text
  772. label = self.decode(label)
  773. return text, label
  774. class CANLabelDecode(BaseRecLabelDecode):
  775. """ Convert between latex-symbol and symbol-index """
  776. def __init__(self, character_dict_path=None, use_space_char=False,
  777. **kwargs):
  778. super(CANLabelDecode, self).__init__(character_dict_path,
  779. use_space_char)
  780. def decode(self, text_index, preds_prob=None):
  781. result_list = []
  782. batch_size = len(text_index)
  783. for batch_idx in range(batch_size):
  784. seq_end = text_index[batch_idx].argmin(0)
  785. idx_list = text_index[batch_idx][:seq_end].tolist()
  786. symbol_list = [self.character[idx] for idx in idx_list]
  787. probs = []
  788. if preds_prob is not None:
  789. probs = preds_prob[batch_idx][:len(symbol_list)].tolist()
  790. result_list.append([' '.join(symbol_list), probs])
  791. return result_list
  792. def __call__(self, preds, label=None, *args, **kwargs):
  793. pred_prob, _, _, _ = preds
  794. preds_idx = pred_prob.argmax(axis=2)
  795. text = self.decode(preds_idx)
  796. if label is None:
  797. return text
  798. label = self.decode(label)
  799. return text, label