# Deteval.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

import numpy as np
import Polygon as plg
import scipy.io as io

from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
  19. def get_socre_A(gt_dir, pred_dict):
  20. allInputs = 1
  21. def input_reading_mod(pred_dict):
  22. """This helper reads input from txt files"""
  23. det = []
  24. n = len(pred_dict)
  25. for i in range(n):
  26. points = pred_dict[i]['points']
  27. text = pred_dict[i]['texts']
  28. point = ",".join(map(str, points.reshape(-1, )))
  29. det.append([point, text])
  30. return det
  31. def gt_reading_mod(gt_dict):
  32. """This helper reads groundtruths from mat files"""
  33. gt = []
  34. n = len(gt_dict)
  35. for i in range(n):
  36. points = gt_dict[i]['points'].tolist()
  37. h = len(points)
  38. text = gt_dict[i]['text']
  39. xx = [
  40. np.array(
  41. ['x:'], dtype='<U2'), 0, np.array(
  42. ['y:'], dtype='<U2'), 0, np.array(
  43. ['#'], dtype='<U1'), np.array(
  44. ['#'], dtype='<U1')
  45. ]
  46. t_x, t_y = [], []
  47. for j in range(h):
  48. t_x.append(points[j][0])
  49. t_y.append(points[j][1])
  50. xx[1] = np.array([t_x], dtype='int16')
  51. xx[3] = np.array([t_y], dtype='int16')
  52. if text != "":
  53. xx[4] = np.array([text], dtype='U{}'.format(len(text)))
  54. xx[5] = np.array(['c'], dtype='<U1')
  55. gt.append(xx)
  56. return gt
  57. def detection_filtering(detections, groundtruths, threshold=0.5):
  58. for gt_id, gt in enumerate(groundtruths):
  59. if (gt[5] == '#') and (gt[1].shape[1] > 1):
  60. gt_x = list(map(int, np.squeeze(gt[1])))
  61. gt_y = list(map(int, np.squeeze(gt[3])))
  62. for det_id, detection in enumerate(detections):
  63. detection_orig = detection
  64. detection = [float(x) for x in detection[0].split(',')]
  65. detection = list(map(int, detection))
  66. det_x = detection[0::2]
  67. det_y = detection[1::2]
  68. det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
  69. if det_gt_iou > threshold:
  70. detections[det_id] = []
  71. detections[:] = [item for item in detections if item != []]
  72. return detections
  73. def sigma_calculation(det_x, det_y, gt_x, gt_y):
  74. """
  75. sigma = inter_area / gt_area
  76. """
  77. return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
  78. area(gt_x, gt_y)), 2)
  79. def tau_calculation(det_x, det_y, gt_x, gt_y):
  80. if area(det_x, det_y) == 0.0:
  81. return 0
  82. return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
  83. area(det_x, det_y)), 2)
  84. ##############################Initialization###################################
  85. # global_sigma = []
  86. # global_tau = []
  87. # global_pred_str = []
  88. # global_gt_str = []
  89. ###############################################################################
  90. for input_id in range(allInputs):
  91. if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
  92. input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
  93. input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
  94. and (input_id != 'Deteval_result_non_curved.txt'):
  95. detections = input_reading_mod(pred_dict)
  96. groundtruths = gt_reading_mod(gt_dir)
  97. detections = detection_filtering(
  98. detections,
  99. groundtruths) # filters detections overlapping with DC area
  100. dc_id = []
  101. for i in range(len(groundtruths)):
  102. if groundtruths[i][5] == '#':
  103. dc_id.append(i)
  104. cnt = 0
  105. for a in dc_id:
  106. num = a - cnt
  107. del groundtruths[num]
  108. cnt += 1
  109. local_sigma_table = np.zeros((len(groundtruths), len(detections)))
  110. local_tau_table = np.zeros((len(groundtruths), len(detections)))
  111. local_pred_str = {}
  112. local_gt_str = {}
  113. for gt_id, gt in enumerate(groundtruths):
  114. if len(detections) > 0:
  115. for det_id, detection in enumerate(detections):
  116. detection_orig = detection
  117. detection = [float(x) for x in detection[0].split(',')]
  118. detection = list(map(int, detection))
  119. pred_seq_str = detection_orig[1].strip()
  120. det_x = detection[0::2]
  121. det_y = detection[1::2]
  122. gt_x = list(map(int, np.squeeze(gt[1])))
  123. gt_y = list(map(int, np.squeeze(gt[3])))
  124. gt_seq_str = str(gt[4].tolist()[0])
  125. local_sigma_table[gt_id, det_id] = sigma_calculation(
  126. det_x, det_y, gt_x, gt_y)
  127. local_tau_table[gt_id, det_id] = tau_calculation(
  128. det_x, det_y, gt_x, gt_y)
  129. local_pred_str[det_id] = pred_seq_str
  130. local_gt_str[gt_id] = gt_seq_str
  131. global_sigma = local_sigma_table
  132. global_tau = local_tau_table
  133. global_pred_str = local_pred_str
  134. global_gt_str = local_gt_str
  135. single_data = {}
  136. single_data['sigma'] = global_sigma
  137. single_data['global_tau'] = global_tau
  138. single_data['global_pred_str'] = global_pred_str
  139. single_data['global_gt_str'] = global_gt_str
  140. return single_data
  141. def get_socre_B(gt_dir, img_id, pred_dict):
  142. allInputs = 1
  143. def input_reading_mod(pred_dict):
  144. """This helper reads input from txt files"""
  145. det = []
  146. n = len(pred_dict)
  147. for i in range(n):
  148. points = pred_dict[i]['points']
  149. text = pred_dict[i]['texts']
  150. point = ",".join(map(str, points.reshape(-1, )))
  151. det.append([point, text])
  152. return det
  153. def gt_reading_mod(gt_dir, gt_id):
  154. gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
  155. gt = gt['polygt']
  156. return gt
  157. def detection_filtering(detections, groundtruths, threshold=0.5):
  158. for gt_id, gt in enumerate(groundtruths):
  159. if (gt[5] == '#') and (gt[1].shape[1] > 1):
  160. gt_x = list(map(int, np.squeeze(gt[1])))
  161. gt_y = list(map(int, np.squeeze(gt[3])))
  162. for det_id, detection in enumerate(detections):
  163. detection_orig = detection
  164. detection = [float(x) for x in detection[0].split(',')]
  165. detection = list(map(int, detection))
  166. det_x = detection[0::2]
  167. det_y = detection[1::2]
  168. det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
  169. if det_gt_iou > threshold:
  170. detections[det_id] = []
  171. detections[:] = [item for item in detections if item != []]
  172. return detections
  173. def sigma_calculation(det_x, det_y, gt_x, gt_y):
  174. """
  175. sigma = inter_area / gt_area
  176. """
  177. return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
  178. area(gt_x, gt_y)), 2)
  179. def tau_calculation(det_x, det_y, gt_x, gt_y):
  180. if area(det_x, det_y) == 0.0:
  181. return 0
  182. return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
  183. area(det_x, det_y)), 2)
  184. ##############################Initialization###################################
  185. # global_sigma = []
  186. # global_tau = []
  187. # global_pred_str = []
  188. # global_gt_str = []
  189. ###############################################################################
  190. for input_id in range(allInputs):
  191. if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
  192. input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
  193. input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
  194. and (input_id != 'Deteval_result_non_curved.txt'):
  195. detections = input_reading_mod(pred_dict)
  196. groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
  197. detections = detection_filtering(
  198. detections,
  199. groundtruths) # filters detections overlapping with DC area
  200. dc_id = []
  201. for i in range(len(groundtruths)):
  202. if groundtruths[i][5] == '#':
  203. dc_id.append(i)
  204. cnt = 0
  205. for a in dc_id:
  206. num = a - cnt
  207. del groundtruths[num]
  208. cnt += 1
  209. local_sigma_table = np.zeros((len(groundtruths), len(detections)))
  210. local_tau_table = np.zeros((len(groundtruths), len(detections)))
  211. local_pred_str = {}
  212. local_gt_str = {}
  213. for gt_id, gt in enumerate(groundtruths):
  214. if len(detections) > 0:
  215. for det_id, detection in enumerate(detections):
  216. detection_orig = detection
  217. detection = [float(x) for x in detection[0].split(',')]
  218. detection = list(map(int, detection))
  219. pred_seq_str = detection_orig[1].strip()
  220. det_x = detection[0::2]
  221. det_y = detection[1::2]
  222. gt_x = list(map(int, np.squeeze(gt[1])))
  223. gt_y = list(map(int, np.squeeze(gt[3])))
  224. gt_seq_str = str(gt[4].tolist()[0])
  225. local_sigma_table[gt_id, det_id] = sigma_calculation(
  226. det_x, det_y, gt_x, gt_y)
  227. local_tau_table[gt_id, det_id] = tau_calculation(
  228. det_x, det_y, gt_x, gt_y)
  229. local_pred_str[det_id] = pred_seq_str
  230. local_gt_str[gt_id] = gt_seq_str
  231. global_sigma = local_sigma_table
  232. global_tau = local_tau_table
  233. global_pred_str = local_pred_str
  234. global_gt_str = local_gt_str
  235. single_data = {}
  236. single_data['sigma'] = global_sigma
  237. single_data['global_tau'] = global_tau
  238. single_data['global_pred_str'] = global_pred_str
  239. single_data['global_gt_str'] = global_gt_str
  240. return single_data
  241. def get_score_C(gt_label, text, pred_bboxes):
  242. """
  243. get score for CentripetalText (CT) prediction.
  244. """
  245. def gt_reading_mod(gt_label, text):
  246. """This helper reads groundtruths from mat files"""
  247. groundtruths = []
  248. nbox = len(gt_label)
  249. for i in range(nbox):
  250. label = {"transcription": text[i][0], "points": gt_label[i].numpy()}
  251. groundtruths.append(label)
  252. return groundtruths
  253. def get_union(pD, pG):
  254. areaA = pD.area()
  255. areaB = pG.area()
  256. return areaA + areaB - get_intersection(pD, pG)
  257. def get_intersection(pD, pG):
  258. pInt = pD & pG
  259. if len(pInt) == 0:
  260. return 0
  261. return pInt.area()
  262. def detection_filtering(detections, groundtruths, threshold=0.5):
  263. for gt in groundtruths:
  264. point_num = gt['points'].shape[1] // 2
  265. if gt['transcription'] == '###' and (point_num > 1):
  266. gt_p = np.array(gt['points']).reshape(point_num,
  267. 2).astype('int32')
  268. gt_p = plg.Polygon(gt_p)
  269. for det_id, detection in enumerate(detections):
  270. det_y = detection[0::2]
  271. det_x = detection[1::2]
  272. det_p = np.concatenate((np.array(det_x), np.array(det_y)))
  273. det_p = det_p.reshape(2, -1).transpose()
  274. det_p = plg.Polygon(det_p)
  275. try:
  276. det_gt_iou = get_intersection(det_p,
  277. gt_p) / det_p.area()
  278. except:
  279. print(det_x, det_y, gt_p)
  280. if det_gt_iou > threshold:
  281. detections[det_id] = []
  282. detections[:] = [item for item in detections if item != []]
  283. return detections
  284. def sigma_calculation(det_p, gt_p):
  285. """
  286. sigma = inter_area / gt_area
  287. """
  288. if gt_p.area() == 0.:
  289. return 0
  290. return get_intersection(det_p, gt_p) / gt_p.area()
  291. def tau_calculation(det_p, gt_p):
  292. """
  293. tau = inter_area / det_area
  294. """
  295. if det_p.area() == 0.:
  296. return 0
  297. return get_intersection(det_p, gt_p) / det_p.area()
  298. detections = []
  299. for item in pred_bboxes:
  300. detections.append(item[:, ::-1].reshape(-1))
  301. groundtruths = gt_reading_mod(gt_label, text)
  302. detections = detection_filtering(
  303. detections, groundtruths) # filters detections overlapping with DC area
  304. for idx in range(len(groundtruths) - 1, -1, -1):
  305. #NOTE: source code use 'orin' to indicate '#', here we use 'anno',
  306. # which may cause slight drop in fscore, about 0.12
  307. if groundtruths[idx]['transcription'] == '###':
  308. groundtruths.pop(idx)
  309. local_sigma_table = np.zeros((len(groundtruths), len(detections)))
  310. local_tau_table = np.zeros((len(groundtruths), len(detections)))
  311. for gt_id, gt in enumerate(groundtruths):
  312. if len(detections) > 0:
  313. for det_id, detection in enumerate(detections):
  314. point_num = gt['points'].shape[1] // 2
  315. gt_p = np.array(gt['points']).reshape(point_num,
  316. 2).astype('int32')
  317. gt_p = plg.Polygon(gt_p)
  318. det_y = detection[0::2]
  319. det_x = detection[1::2]
  320. det_p = np.concatenate((np.array(det_x), np.array(det_y)))
  321. det_p = det_p.reshape(2, -1).transpose()
  322. det_p = plg.Polygon(det_p)
  323. local_sigma_table[gt_id, det_id] = sigma_calculation(det_p,
  324. gt_p)
  325. local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)
  326. data = {}
  327. data['sigma'] = local_sigma_table
  328. data['global_tau'] = local_tau_table
  329. data['global_pred_str'] = ''
  330. data['global_gt_str'] = ''
  331. return data
def combine_results(all_data, rec_flag=True):
    """Fuse per-image sigma/tau tables into final DetEval metrics.

    Matching follows the DetEval protocol: one-to-one, one-to-many (split)
    and many-to-one (merge) matches are scored in that order; split/merge
    matches are discounted by `fsc_k`.

    Args:
        all_data: list of dicts as produced by get_socre_A / get_socre_B /
            get_score_C, each holding 'sigma', 'global_tau',
            'global_pred_str' and 'global_gt_str'.
        rec_flag: when True, also count transcription hits (exact match,
            falling back to case-insensitive) for the end-to-end scores.

    Returns:
        dict with detection recall/precision/f_score, recognition
        seqerr and recall_e2e/precision_e2e/f_score_e2e, plus raw counters.
    """
    tr = 0.7     # sigma (recall-side) acceptance threshold
    tp = 0.6     # tau (precision-side) acceptance threshold
    fsc_k = 0.8  # discount applied to one-to-many / many-to-one matches
    k = 2        # minimum overlap count to consider a split/merge
    global_sigma = []
    global_tau = []
    global_pred_str = []
    global_gt_str = []
    for data in all_data:
        global_sigma.append(data['sigma'])
        global_tau.append(data['global_tau'])
        global_pred_str.append(data['global_pred_str'])
        global_gt_str.append(data['global_gt_str'])

    global_accumulative_recall = 0
    global_accumulative_precision = 0
    total_num_gt = 0
    total_num_det = 0
    hit_str_count = 0
    hit_count = 0  # NOTE: never updated below; kept for parity with upstream

    def one_to_one(local_sigma_table, local_tau_table,
                   local_accumulative_recall, local_accumulative_precision,
                   global_accumulative_recall, global_accumulative_precision,
                   gt_flag, det_flag, idy, rec_flag):
        # A gt/det pair is one-to-one when each side qualifies the other
        # uniquely under both sigma (> tr) and tau (> tp).
        # `num_gt` is a closure variable set in the per-image loop below.
        hit_str_num = 0
        for gt_id in range(num_gt):
            gt_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[gt_id, :] > tr)
            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
                0].shape[0]
            gt_matching_qualified_tau_candidates = np.where(
                local_tau_table[gt_id, :] > tp)
            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
                0].shape[0]

            det_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
                > tr)
            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
                0].shape[0]
            det_matching_qualified_tau_candidates = np.where(
                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
                tp)
            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
                0].shape[0]

            if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
                    (det_matching_num_qualified_sigma_candidates == 1) and (
                    det_matching_num_qualified_tau_candidates == 1):
                global_accumulative_recall = global_accumulative_recall + 1.0
                global_accumulative_precision = global_accumulative_precision + 1.0
                local_accumulative_recall = local_accumulative_recall + 1.0
                local_accumulative_precision = local_accumulative_precision + 1.0

                gt_flag[0, gt_id] = 1
                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
                # recg start
                if rec_flag:
                    gt_str_cur = global_gt_str[idy][gt_id]
                    pred_str_cur = global_pred_str[idy][matched_det_id[0]
                                                        .tolist()[0]]
                    if pred_str_cur == gt_str_cur:
                        hit_str_num += 1
                    else:
                        if pred_str_cur.lower() == gt_str_cur.lower():
                            hit_str_num += 1
                # recg end
                det_flag[0, matched_det_id] = 1
        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    def one_to_many(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy, rec_flag):
        # One ground truth covered by several detections (a "split"):
        # scored with discount fsc_k unless it collapses to one-to-one.
        hit_str_num = 0
        for gt_id in range(num_gt):
            # skip the following if the groundtruth was matched
            if gt_flag[0, gt_id] > 0:
                continue

            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]

            if num_non_zero_in_sigma >= k:
                ####search for all detections that overlaps with this groundtruth
                qualified_tau_candidates = np.where((local_tau_table[
                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
                num_qualified_tau_candidates = qualified_tau_candidates[
                    0].shape[0]

                if num_qualified_tau_candidates == 1:
                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
                            and
                            (local_sigma_table[gt_id, qualified_tau_candidates] >=
                             tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, gt_id] = 1
                        det_flag[0, qualified_tau_candidates] = 1
                        # recg start
                        if rec_flag:
                            gt_str_cur = global_gt_str[idy][gt_id]
                            pred_str_cur = global_pred_str[idy][
                                qualified_tau_candidates[0].tolist()[0]]
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                        # recg end
                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
                      >= tr):
                    gt_flag[0, gt_id] = 1
                    det_flag[0, qualified_tau_candidates] = 1
                    # recg start
                    if rec_flag:
                        gt_str_cur = global_gt_str[idy][gt_id]
                        pred_str_cur = global_pred_str[idy][
                            qualified_tau_candidates[0].tolist()[0]]
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + fsc_k
                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k

                    local_accumulative_recall = local_accumulative_recall + fsc_k
                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k

        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    def many_to_one(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy, rec_flag):
        # One detection covering several ground truths (a "merge"):
        # scored with discount fsc_k unless it collapses to one-to-one.
        # `num_det` is a closure variable set in the per-image loop below.
        hit_str_num = 0
        for det_id in range(num_det):
            # skip the following if the detection was matched
            if det_flag[0, det_id] > 0:
                continue

            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]

            if num_non_zero_in_tau >= k:
                ####search for all detections that overlaps with this groundtruth
                qualified_sigma_candidates = np.where((
                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
                num_qualified_sigma_candidates = qualified_sigma_candidates[
                    0].shape[0]

                if num_qualified_sigma_candidates == 1:
                    if ((local_tau_table[qualified_sigma_candidates, det_id] >=
                         tp) and
                            (local_sigma_table[qualified_sigma_candidates, det_id]
                             >= tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, qualified_sigma_candidates] = 1
                        det_flag[0, det_id] = 1
                        # recg start
                        if rec_flag:
                            pred_str_cur = global_pred_str[idy][det_id]
                            gt_len = len(qualified_sigma_candidates[0])
                            for idx in range(gt_len):
                                ele_gt_id = qualified_sigma_candidates[
                                    0].tolist()[idx]
                                if ele_gt_id not in global_gt_str[idy]:
                                    continue
                                gt_str_cur = global_gt_str[idy][ele_gt_id]
                                if pred_str_cur == gt_str_cur:
                                    hit_str_num += 1
                                    break
                                else:
                                    if pred_str_cur.lower() == gt_str_cur.lower(
                                    ):
                                        hit_str_num += 1
                                    break
                        # recg end
                elif (np.sum(local_tau_table[qualified_sigma_candidates,
                                             det_id]) >= tp):
                    det_flag[0, det_id] = 1
                    gt_flag[0, qualified_sigma_candidates] = 1
                    # recg start
                    if rec_flag:
                        pred_str_cur = global_pred_str[idy][det_id]
                        gt_len = len(qualified_sigma_candidates[0])
                        for idx in range(gt_len):
                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
                                idx]
                            if ele_gt_id not in global_gt_str[idy]:
                                continue
                            gt_str_cur = global_gt_str[idy][ele_gt_id]
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                                break
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                                break
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    global_accumulative_precision = global_accumulative_precision + fsc_k

                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    local_accumulative_precision = local_accumulative_precision + fsc_k
        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    # Per-image accumulation; gt_flag/det_flag carry match state across the
    # three passes so each gt/det is credited at most once.
    for idx in range(len(global_sigma)):
        local_sigma_table = np.array(global_sigma[idx])
        local_tau_table = global_tau[idx]

        num_gt = local_sigma_table.shape[0]
        num_det = local_sigma_table.shape[1]

        total_num_gt = total_num_gt + num_gt
        total_num_det = total_num_det + num_det

        local_accumulative_recall = 0
        local_accumulative_precision = 0
        gt_flag = np.zeros((1, num_gt))
        det_flag = np.zeros((1, num_det))

        #######first check for one-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
                                                        local_accumulative_recall, local_accumulative_precision,
                                                        global_accumulative_recall, global_accumulative_precision,
                                                        gt_flag, det_flag, idx, rec_flag)

        hit_str_count += hit_str_num
        #######then check for one-to-many case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
                                                         local_accumulative_recall, local_accumulative_precision,
                                                         global_accumulative_recall, global_accumulative_precision,
                                                         gt_flag, det_flag, idx, rec_flag)
        hit_str_count += hit_str_num
        #######then check for many-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
                                                         local_accumulative_recall, local_accumulative_precision,
                                                         global_accumulative_recall, global_accumulative_precision,
                                                         gt_flag, det_flag, idx, rec_flag)
        hit_str_count += hit_str_num

    # Final ratios; empty inputs fall back to 0 (or 1 for seqerr) instead
    # of raising.
    try:
        recall = global_accumulative_recall / total_num_gt
    except ZeroDivisionError:
        recall = 0

    try:
        precision = global_accumulative_precision / total_num_det
    except ZeroDivisionError:
        precision = 0

    try:
        f_score = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        f_score = 0

    try:
        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
    except ZeroDivisionError:
        seqerr = 1

    try:
        recall_e2e = float(hit_str_count) / total_num_gt
    except ZeroDivisionError:
        recall_e2e = 0

    try:
        precision_e2e = float(hit_str_count) / total_num_det
    except ZeroDivisionError:
        precision_e2e = 0

    try:
        f_score_e2e = 2 * precision_e2e * recall_e2e / (
            precision_e2e + recall_e2e)
    except ZeroDivisionError:
        f_score_e2e = 0

    final = {
        'total_num_gt': total_num_gt,
        'total_num_det': total_num_det,
        'global_accumulative_recall': global_accumulative_recall,
        'hit_str_count': hit_str_count,
        'recall': recall,
        'precision': precision,
        'f_score': f_score,
        'seqerr': seqerr,
        'recall_e2e': recall_e2e,
        'precision_e2e': precision_e2e,
        'f_score_e2e': f_score_e2e
    }
    return final
  607. return final