eval_table.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

# Directory that contains this script.
__dir__ = os.path.dirname(os.path.abspath(__file__))
# Make this script's directory and the directory two levels above it
# importable. These edits must run before the project imports below;
# presumably '../..' is the repository root that provides the `ppstructure`
# and `ppocr` packages — confirm if the script is moved.
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../..')))

import cv2
import pickle
# NOTE(review): `paddle` is not referenced in this file; kept in case the
# import has required side effects — confirm before removing.
import paddle
from tqdm import tqdm
from ppstructure.table.table_metric import TEDS
from ppstructure.table.predict_table import TableSystem
from ppstructure.utility import init_args
from ppocr.utils.logging import get_logger

# Module-level logger; main() reports the final TEDS score through it.
logger = get_logger()
  28. def parse_args():
  29. parser = init_args()
  30. parser.add_argument("--gt_path", type=str)
  31. return parser.parse_args()
  32. def load_txt(txt_path):
  33. pred_html_dict = {}
  34. if not os.path.exists(txt_path):
  35. return pred_html_dict
  36. with open(txt_path, encoding='utf-8') as f:
  37. lines = f.readlines()
  38. for line in lines:
  39. line = line.strip().split('\t')
  40. img_name, pred_html = line
  41. pred_html_dict[img_name] = pred_html
  42. return pred_html_dict
  43. def load_result(path):
  44. data = {}
  45. if os.path.exists(path):
  46. data = pickle.load(open(path, 'rb'))
  47. return data
  48. def save_result(path, data):
  49. old_data = load_result(path)
  50. old_data.update(data)
  51. with open(path, 'wb') as f:
  52. pickle.dump(old_data, f)
  53. def main(gt_path, img_root, args):
  54. os.makedirs(args.output, exist_ok=True)
  55. # init TableSystem
  56. text_sys = TableSystem(args)
  57. # load gt and preds html result
  58. gt_html_dict = load_txt(gt_path)
  59. ocr_result = load_result(os.path.join(args.output, 'ocr.pickle'))
  60. structure_result = load_result(
  61. os.path.join(args.output, 'structure.pickle'))
  62. pred_htmls = []
  63. gt_htmls = []
  64. for img_name, gt_html in tqdm(gt_html_dict.items()):
  65. img = cv2.imread(os.path.join(img_root, img_name))
  66. # run ocr and save result
  67. if img_name not in ocr_result:
  68. dt_boxes, rec_res, _, _ = text_sys._ocr(img)
  69. ocr_result[img_name] = [dt_boxes, rec_res]
  70. save_result(os.path.join(args.output, 'ocr.pickle'), ocr_result)
  71. # run structure and save result
  72. if img_name not in structure_result:
  73. structure_res, _ = text_sys._structure(img)
  74. structure_result[img_name] = structure_res
  75. save_result(
  76. os.path.join(args.output, 'structure.pickle'), structure_result)
  77. dt_boxes, rec_res = ocr_result[img_name]
  78. structure_res = structure_result[img_name]
  79. # match ocr and structure
  80. pred_html = text_sys.match(structure_res, dt_boxes, rec_res)
  81. pred_htmls.append(pred_html)
  82. gt_htmls.append(gt_html)
  83. # compute teds
  84. teds = TEDS(n_jobs=16)
  85. scores = teds.batch_evaluate_html(gt_htmls, pred_htmls)
  86. logger.info('teds: {}'.format(sum(scores) / len(scores)))
  87. if __name__ == '__main__':
  88. args = parse_args()
  89. main(args.gt_path, args.image_dir, args)