# infer_table.py (4.2 KB)
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import numpy as np
  18. import os
  19. import sys
  20. import json
  21. __dir__ = os.path.dirname(os.path.abspath(__file__))
  22. sys.path.append(__dir__)
  23. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))
  24. os.environ["FLAGS_allocator_strategy"] = 'auto_growth'
  25. import paddle
  26. from paddle.jit import to_static
  27. from ppocr.data import create_operators, transform
  28. from ppocr.modeling.architectures import build_model
  29. from ppocr.postprocess import build_post_process
  30. from ppocr.utils.save_load import load_model
  31. from ppocr.utils.utility import get_image_file_list
  32. from ppocr.utils.visual import draw_rectangle
  33. from tools.infer.utility import draw_boxes
  34. import tools.program as program
  35. import cv2
  36. @paddle.no_grad()
  37. def main(config, device, logger, vdl_writer):
  38. global_config = config['Global']
  39. # build post process
  40. post_process_class = build_post_process(config['PostProcess'],
  41. global_config)
  42. # build model
  43. if hasattr(post_process_class, 'character'):
  44. config['Architecture']["Head"]['out_channels'] = len(
  45. getattr(post_process_class, 'character'))
  46. model = build_model(config['Architecture'])
  47. algorithm = config['Architecture']['algorithm']
  48. load_model(config, model)
  49. # create data ops
  50. transforms = []
  51. for op in config['Eval']['dataset']['transforms']:
  52. op_name = list(op)[0]
  53. if 'Encode' in op_name:
  54. continue
  55. if op_name == 'KeepKeys':
  56. op[op_name]['keep_keys'] = ['image', 'shape']
  57. transforms.append(op)
  58. global_config['infer_mode'] = True
  59. ops = create_operators(transforms, global_config)
  60. save_res_path = config['Global']['save_res_path']
  61. os.makedirs(save_res_path, exist_ok=True)
  62. model.eval()
  63. with open(
  64. os.path.join(save_res_path, 'infer.txt'), mode='w',
  65. encoding='utf-8') as f_w:
  66. for file in get_image_file_list(config['Global']['infer_img']):
  67. logger.info("infer_img: {}".format(file))
  68. with open(file, 'rb') as f:
  69. img = f.read()
  70. data = {'image': img}
  71. batch = transform(data, ops)
  72. images = np.expand_dims(batch[0], axis=0)
  73. shape_list = np.expand_dims(batch[1], axis=0)
  74. images = paddle.to_tensor(images)
  75. preds = model(images)
  76. post_result = post_process_class(preds, [shape_list])
  77. structure_str_list = post_result['structure_batch_list'][0]
  78. bbox_list = post_result['bbox_batch_list'][0]
  79. structure_str_list = structure_str_list[0]
  80. structure_str_list = [
  81. '<html>', '<body>', '<table>'
  82. ] + structure_str_list + ['</table>', '</body>', '</html>']
  83. bbox_list_str = json.dumps(bbox_list.tolist())
  84. logger.info("result: {}, {}".format(structure_str_list,
  85. bbox_list_str))
  86. f_w.write("result: {}, {}\n".format(structure_str_list,
  87. bbox_list_str))
  88. if len(bbox_list) > 0 and len(bbox_list[0]) == 4:
  89. img = draw_rectangle(file, bbox_list)
  90. else:
  91. img = draw_boxes(cv2.imread(file), bbox_list)
  92. cv2.imwrite(
  93. os.path.join(save_res_path, os.path.basename(file)), img)
  94. logger.info('save result to {}'.format(save_res_path))
  95. logger.info("success!")
  96. if __name__ == '__main__':
  97. config, device, logger, vdl_writer = program.preprocess()
  98. main(config, device, logger, vdl_writer)