# predict_system.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import subprocess

# Make the script runnable directly: put its own directory on sys.path and
# put the directory one level up (presumably the repository root — confirm)
# first on sys.path, so the `ppocr`, `tools` and `ppstructure` imports below
# resolve.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '../')))

# Set before the project imports below, which presumably import Paddle.
# NOTE(review): assumes Paddle reads FLAGS_allocator_strategy at import or
# initialization time; keep this line above those imports.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import cv2
import json
import numpy as np
import time
import logging
from copy import deepcopy

from ppocr.utils.utility import get_image_file_list, check_and_read
from ppocr.utils.logging import get_logger
from ppocr.utils.visual import draw_ser_results, draw_re_results
from tools.infer.predict_system import TextSystem
from ppstructure.layout.predict_layout import LayoutPredictor
from ppstructure.table.predict_table import TableSystem, to_excel
from ppstructure.utility import parse_args, draw_structure_result

# Module-wide logger used by StructureSystem and main().
logger = get_logger()
  35. class StructureSystem(object):
  36. def __init__(self, args):
  37. self.mode = args.mode
  38. self.recovery = args.recovery
  39. self.image_orientation_predictor = None
  40. if args.image_orientation:
  41. import paddleclas
  42. self.image_orientation_predictor = paddleclas.PaddleClas(
  43. model_name="text_image_orientation")
  44. if self.mode == 'structure':
  45. if not args.show_log:
  46. logger.setLevel(logging.INFO)
  47. if args.layout == False and args.ocr == True:
  48. args.ocr = False
  49. logger.warning(
  50. "When args.layout is false, args.ocr is automatically set to false"
  51. )
  52. args.drop_score = 0
  53. # init model
  54. self.layout_predictor = None
  55. self.text_system = None
  56. self.table_system = None
  57. if args.layout:
  58. self.layout_predictor = LayoutPredictor(args)
  59. if args.ocr:
  60. self.text_system = TextSystem(args)
  61. if args.table:
  62. if self.text_system is not None:
  63. self.table_system = TableSystem(
  64. args, self.text_system.text_detector,
  65. self.text_system.text_recognizer)
  66. else:
  67. self.table_system = TableSystem(args)
  68. elif self.mode == 'kie':
  69. from ppstructure.kie.predict_kie_token_ser_re import SerRePredictor
  70. self.kie_predictor = SerRePredictor(args)
  71. def __call__(self, img, return_ocr_result_in_table=False, img_idx=0):
  72. time_dict = {
  73. 'image_orientation': 0,
  74. 'layout': 0,
  75. 'table': 0,
  76. 'table_match': 0,
  77. 'det': 0,
  78. 'rec': 0,
  79. 'kie': 0,
  80. 'all': 0
  81. }
  82. start = time.time()
  83. if self.image_orientation_predictor is not None:
  84. tic = time.time()
  85. cls_result = self.image_orientation_predictor.predict(
  86. input_data=img)
  87. cls_res = next(cls_result)
  88. angle = cls_res[0]['label_names'][0]
  89. cv_rotate_code = {
  90. '90': cv2.ROTATE_90_COUNTERCLOCKWISE,
  91. '180': cv2.ROTATE_180,
  92. '270': cv2.ROTATE_90_CLOCKWISE
  93. }
  94. if angle in cv_rotate_code:
  95. img = cv2.rotate(img, cv_rotate_code[angle])
  96. toc = time.time()
  97. time_dict['image_orientation'] = toc - tic
  98. if self.mode == 'structure':
  99. ori_im = img.copy()
  100. if self.layout_predictor is not None:
  101. layout_res, elapse = self.layout_predictor(img)
  102. time_dict['layout'] += elapse
  103. else:
  104. h, w = ori_im.shape[:2]
  105. layout_res = [dict(bbox=None, label='table')]
  106. res_list = []
  107. for region in layout_res:
  108. res = ''
  109. if region['bbox'] is not None:
  110. x1, y1, x2, y2 = region['bbox']
  111. x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
  112. roi_img = ori_im[y1:y2, x1:x2, :]
  113. else:
  114. x1, y1, x2, y2 = 0, 0, w, h
  115. roi_img = ori_im
  116. if region['label'] == 'table':
  117. if self.table_system is not None:
  118. res, table_time_dict = self.table_system(
  119. roi_img, return_ocr_result_in_table)
  120. time_dict['table'] += table_time_dict['table']
  121. time_dict['table_match'] += table_time_dict['match']
  122. time_dict['det'] += table_time_dict['det']
  123. time_dict['rec'] += table_time_dict['rec']
  124. else:
  125. if self.text_system is not None:
  126. if self.recovery:
  127. wht_im = np.ones(ori_im.shape, dtype=ori_im.dtype)
  128. wht_im[y1:y2, x1:x2, :] = roi_img
  129. filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
  130. wht_im)
  131. else:
  132. filter_boxes, filter_rec_res, ocr_time_dict = self.text_system(
  133. roi_img)
  134. time_dict['det'] += ocr_time_dict['det']
  135. time_dict['rec'] += ocr_time_dict['rec']
  136. # remove style char,
  137. # when using the recognition model trained on the PubtabNet dataset,
  138. # it will recognize the text format in the table, such as <b>
  139. style_token = [
  140. '<strike>', '<strike>', '<sup>', '</sub>', '<b>',
  141. '</b>', '<sub>', '</sup>', '<overline>',
  142. '</overline>', '<underline>', '</underline>', '<i>',
  143. '</i>'
  144. ]
  145. res = []
  146. for box, rec_res in zip(filter_boxes, filter_rec_res):
  147. rec_str, rec_conf = rec_res
  148. for token in style_token:
  149. if token in rec_str:
  150. rec_str = rec_str.replace(token, '')
  151. if not self.recovery:
  152. box += [x1, y1]
  153. res.append({
  154. 'text': rec_str,
  155. 'confidence': float(rec_conf),
  156. 'text_region': box.tolist()
  157. })
  158. res_list.append({
  159. 'type': region['label'].lower(),
  160. 'bbox': [x1, y1, x2, y2],
  161. 'img': roi_img,
  162. 'res': res,
  163. 'img_idx': img_idx
  164. })
  165. end = time.time()
  166. time_dict['all'] = end - start
  167. return res_list, time_dict
  168. elif self.mode == 'kie':
  169. re_res, elapse = self.kie_predictor(img)
  170. time_dict['kie'] = elapse
  171. time_dict['all'] = elapse
  172. return re_res[0], time_dict
  173. return None, None
  174. def save_structure_res(res, save_folder, img_name, img_idx=0):
  175. excel_save_folder = os.path.join(save_folder, img_name)
  176. os.makedirs(excel_save_folder, exist_ok=True)
  177. res_cp = deepcopy(res)
  178. # save res
  179. with open(
  180. os.path.join(excel_save_folder, 'res_{}.txt'.format(img_idx)),
  181. 'w',
  182. encoding='utf8') as f:
  183. for region in res_cp:
  184. roi_img = region.pop('img')
  185. f.write('{}\n'.format(json.dumps(region)))
  186. if region['type'].lower() == 'table' and len(region[
  187. 'res']) > 0 and 'html' in region['res']:
  188. excel_path = os.path.join(
  189. excel_save_folder,
  190. '{}_{}.xlsx'.format(region['bbox'], img_idx))
  191. to_excel(region['res']['html'], excel_path)
  192. elif region['type'].lower() == 'figure':
  193. img_path = os.path.join(
  194. excel_save_folder,
  195. '{}_{}.jpg'.format(region['bbox'], img_idx))
  196. cv2.imwrite(img_path, roi_img)
  197. def main(args):
  198. image_file_list = get_image_file_list(args.image_dir)
  199. image_file_list = image_file_list
  200. image_file_list = image_file_list[args.process_id::args.total_process_num]
  201. if not args.use_pdf2docx_api:
  202. structure_sys = StructureSystem(args)
  203. save_folder = os.path.join(args.output, structure_sys.mode)
  204. os.makedirs(save_folder, exist_ok=True)
  205. img_num = len(image_file_list)
  206. for i, image_file in enumerate(image_file_list):
  207. logger.info("[{}/{}] {}".format(i, img_num, image_file))
  208. img, flag_gif, flag_pdf = check_and_read(image_file)
  209. img_name = os.path.basename(image_file).split('.')[0]
  210. if args.recovery and args.use_pdf2docx_api and flag_pdf:
  211. from pdf2docx.converter import Converter
  212. os.makedirs(args.output, exist_ok=True)
  213. docx_file = os.path.join(args.output,
  214. '{}_api.docx'.format(img_name))
  215. cv = Converter(image_file)
  216. cv.convert(docx_file)
  217. cv.close()
  218. logger.info('docx save to {}'.format(docx_file))
  219. continue
  220. if not flag_gif and not flag_pdf:
  221. img = cv2.imread(image_file)
  222. if not flag_pdf:
  223. if img is None:
  224. logger.error("error in loading image:{}".format(image_file))
  225. continue
  226. imgs = [img]
  227. else:
  228. imgs = img
  229. all_res = []
  230. for index, img in enumerate(imgs):
  231. res, time_dict = structure_sys(img, img_idx=index)
  232. img_save_path = os.path.join(save_folder, img_name,
  233. 'show_{}.jpg'.format(index))
  234. os.makedirs(os.path.join(save_folder, img_name), exist_ok=True)
  235. if structure_sys.mode == 'structure' and res != []:
  236. draw_img = draw_structure_result(img, res, args.vis_font_path)
  237. save_structure_res(res, save_folder, img_name, index)
  238. elif structure_sys.mode == 'kie':
  239. if structure_sys.kie_predictor.predictor is not None:
  240. draw_img = draw_re_results(
  241. img, res, font_path=args.vis_font_path)
  242. else:
  243. draw_img = draw_ser_results(
  244. img, res, font_path=args.vis_font_path)
  245. with open(
  246. os.path.join(save_folder, img_name,
  247. 'res_{}_kie.txt'.format(index)),
  248. 'w',
  249. encoding='utf8') as f:
  250. res_str = '{}\t{}\n'.format(
  251. image_file,
  252. json.dumps(
  253. {
  254. "ocr_info": res
  255. }, ensure_ascii=False))
  256. f.write(res_str)
  257. if res != []:
  258. cv2.imwrite(img_save_path, draw_img)
  259. logger.info('result save to {}'.format(img_save_path))
  260. if args.recovery and res != []:
  261. from ppstructure.recovery.recovery_to_doc import sorted_layout_boxes, convert_info_docx
  262. h, w, _ = img.shape
  263. res = sorted_layout_boxes(res, w)
  264. all_res += res
  265. if args.recovery and all_res != []:
  266. try:
  267. convert_info_docx(img, all_res, save_folder, img_name)
  268. except Exception as ex:
  269. logger.error("error in layout recovery image:{}, err msg: {}".
  270. format(image_file, ex))
  271. continue
  272. logger.info("Predict time : {:.3f}s".format(time_dict['all']))
  273. if __name__ == "__main__":
  274. args = parse_args()
  275. if args.use_mp:
  276. p_list = []
  277. total_process_num = args.total_process_num
  278. for process_id in range(total_process_num):
  279. cmd = [sys.executable, "-u"] + sys.argv + [
  280. "--process_id={}".format(process_id),
  281. "--use_mp={}".format(False)
  282. ]
  283. p = subprocess.Popen(cmd, stdout=sys.stdout, stderr=sys.stdout)
  284. p_list.append(p)
  285. for p in p_list:
  286. p.wait()
  287. else:
  288. main(args)