eval_end2end.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import re
  16. import sys
  17. import shapely
  18. from shapely.geometry import Polygon
  19. import numpy as np
  20. from collections import defaultdict
  21. import operator
  22. import editdistance
  23. def strQ2B(ustring):
  24. rstring = ""
  25. for uchar in ustring:
  26. inside_code = ord(uchar)
  27. if inside_code == 12288:
  28. inside_code = 32
  29. elif (inside_code >= 65281 and inside_code <= 65374):
  30. inside_code -= 65248
  31. rstring += chr(inside_code)
  32. return rstring
  33. def polygon_from_str(polygon_points):
  34. """
  35. Create a shapely polygon object from gt or dt line.
  36. """
  37. polygon_points = np.array(polygon_points).reshape(4, 2)
  38. polygon = Polygon(polygon_points).convex_hull
  39. return polygon
  40. def polygon_iou(poly1, poly2):
  41. """
  42. Intersection over union between two shapely polygons.
  43. """
  44. if not poly1.intersects(
  45. poly2): # this test is fast and can accelerate calculation
  46. iou = 0
  47. else:
  48. try:
  49. inter_area = poly1.intersection(poly2).area
  50. union_area = poly1.area + poly2.area - inter_area
  51. iou = float(inter_area) / union_area
  52. except shapely.geos.TopologicalError:
  53. # except Exception as e:
  54. # print(e)
  55. print('shapely.geos.TopologicalError occurred, iou set to 0')
  56. iou = 0
  57. return iou
  58. def ed(str1, str2):
  59. return editdistance.eval(str1, str2)
  60. def e2e_eval(gt_dir, res_dir, ignore_blank=False):
  61. print('start testing...')
  62. iou_thresh = 0.5
  63. val_names = os.listdir(gt_dir)
  64. num_gt_chars = 0
  65. gt_count = 0
  66. dt_count = 0
  67. hit = 0
  68. ed_sum = 0
  69. for i, val_name in enumerate(val_names):
  70. with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
  71. gt_lines = [o.strip() for o in f.readlines()]
  72. gts = []
  73. ignore_masks = []
  74. for line in gt_lines:
  75. parts = line.strip().split('\t')
  76. # ignore illegal data
  77. if len(parts) < 9:
  78. continue
  79. assert (len(parts) < 11)
  80. if len(parts) == 9:
  81. gts.append(parts[:8] + [''])
  82. else:
  83. gts.append(parts[:8] + [parts[-1]])
  84. ignore_masks.append(parts[8])
  85. val_path = os.path.join(res_dir, val_name)
  86. if not os.path.exists(val_path):
  87. dt_lines = []
  88. else:
  89. with open(val_path, encoding='utf-8') as f:
  90. dt_lines = [o.strip() for o in f.readlines()]
  91. dts = []
  92. for line in dt_lines:
  93. # print(line)
  94. parts = line.strip().split("\t")
  95. assert (len(parts) < 10), "line error: {}".format(line)
  96. if len(parts) == 8:
  97. dts.append(parts + [''])
  98. else:
  99. dts.append(parts)
  100. dt_match = [False] * len(dts)
  101. gt_match = [False] * len(gts)
  102. all_ious = defaultdict(tuple)
  103. for index_gt, gt in enumerate(gts):
  104. gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
  105. gt_poly = polygon_from_str(gt_coors)
  106. for index_dt, dt in enumerate(dts):
  107. dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
  108. dt_poly = polygon_from_str(dt_coors)
  109. iou = polygon_iou(dt_poly, gt_poly)
  110. if iou >= iou_thresh:
  111. all_ious[(index_gt, index_dt)] = iou
  112. sorted_ious = sorted(
  113. all_ious.items(), key=operator.itemgetter(1), reverse=True)
  114. sorted_gt_dt_pairs = [item[0] for item in sorted_ious]
  115. # matched gt and dt
  116. for gt_dt_pair in sorted_gt_dt_pairs:
  117. index_gt, index_dt = gt_dt_pair
  118. if gt_match[index_gt] == False and dt_match[index_dt] == False:
  119. gt_match[index_gt] = True
  120. dt_match[index_dt] = True
  121. if ignore_blank:
  122. gt_str = strQ2B(gts[index_gt][8]).replace(" ", "")
  123. dt_str = strQ2B(dts[index_dt][8]).replace(" ", "")
  124. else:
  125. gt_str = strQ2B(gts[index_gt][8])
  126. dt_str = strQ2B(dts[index_dt][8])
  127. if ignore_masks[index_gt] == '0':
  128. ed_sum += ed(gt_str, dt_str)
  129. num_gt_chars += len(gt_str)
  130. if gt_str == dt_str:
  131. hit += 1
  132. gt_count += 1
  133. dt_count += 1
  134. # unmatched dt
  135. for tindex, dt_match_flag in enumerate(dt_match):
  136. if dt_match_flag == False:
  137. dt_str = dts[tindex][8]
  138. gt_str = ''
  139. ed_sum += ed(dt_str, gt_str)
  140. dt_count += 1
  141. # unmatched gt
  142. for tindex, gt_match_flag in enumerate(gt_match):
  143. if gt_match_flag == False and ignore_masks[tindex] == '0':
  144. dt_str = ''
  145. gt_str = gts[tindex][8]
  146. ed_sum += ed(gt_str, dt_str)
  147. num_gt_chars += len(gt_str)
  148. gt_count += 1
  149. eps = 1e-9
  150. print('hit, dt_count, gt_count', hit, dt_count, gt_count)
  151. precision = hit / (dt_count + eps)
  152. recall = hit / (gt_count + eps)
  153. fmeasure = 2.0 * precision * recall / (precision + recall + eps)
  154. avg_edit_dist_img = ed_sum / len(val_names)
  155. avg_edit_dist_field = ed_sum / (gt_count + eps)
  156. character_acc = 1 - ed_sum / (num_gt_chars + eps)
  157. print('character_acc: %.2f' % (character_acc * 100) + "%")
  158. print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
  159. print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
  160. print('precision: %.2f' % (precision * 100) + "%")
  161. print('recall: %.2f' % (recall * 100) + "%")
  162. print('fmeasure: %.2f' % (fmeasure * 100) + "%")
  163. if __name__ == '__main__':
  164. # if len(sys.argv) != 3:
  165. # print("python3 ocr_e2e_eval.py gt_dir res_dir")
  166. # exit(-1)
  167. # gt_folder = sys.argv[1]
  168. # pred_folder = sys.argv[2]
  169. gt_folder = sys.argv[1]
  170. pred_folder = sys.argv[2]
  171. e2e_eval(gt_folder, pred_folder)