# infer_rec.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import os
import sys
import json

# Make the script runnable from any working directory: put the script's own
# directory and the repository root on sys.path BEFORE importing ppocr/tools.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

# Let Paddle grow device memory on demand instead of pre-allocating a chunk.
os.environ["FLAGS_allocator_strategy"] = 'auto_growth'

import paddle
from ppocr.data import create_operators, transform
from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import get_image_file_list
import tools.program as program
  32. def main():
  33. global_config = config['Global']
  34. # build post process
  35. post_process_class = build_post_process(config['PostProcess'],
  36. global_config)
  37. # build model
  38. if hasattr(post_process_class, 'character'):
  39. char_num = len(getattr(post_process_class, 'character'))
  40. if config['Architecture']["algorithm"] in ["Distillation",
  41. ]: # distillation model
  42. for key in config['Architecture']["Models"]:
  43. if config['Architecture']['Models'][key]['Head'][
  44. 'name'] == 'MultiHead': # for multi head
  45. out_channels_list = {}
  46. if config['PostProcess'][
  47. 'name'] == 'DistillationSARLabelDecode':
  48. char_num = char_num - 2
  49. out_channels_list['CTCLabelDecode'] = char_num
  50. out_channels_list['SARLabelDecode'] = char_num + 2
  51. config['Architecture']['Models'][key]['Head'][
  52. 'out_channels_list'] = out_channels_list
  53. else:
  54. config['Architecture']["Models"][key]["Head"][
  55. 'out_channels'] = char_num
  56. elif config['Architecture']['Head'][
  57. 'name'] == 'MultiHead': # for multi head loss
  58. out_channels_list = {}
  59. if config['PostProcess']['name'] == 'SARLabelDecode':
  60. char_num = char_num - 2
  61. out_channels_list['CTCLabelDecode'] = char_num
  62. out_channels_list['SARLabelDecode'] = char_num + 2
  63. config['Architecture']['Head'][
  64. 'out_channels_list'] = out_channels_list
  65. else: # base rec model
  66. config['Architecture']["Head"]['out_channels'] = char_num
  67. model = build_model(config['Architecture'])
  68. load_model(config, model)
  69. # create data ops
  70. transforms = []
  71. for op in config['Eval']['dataset']['transforms']:
  72. op_name = list(op)[0]
  73. if 'Label' in op_name:
  74. continue
  75. elif op_name in ['RecResizeImg']:
  76. op[op_name]['infer_mode'] = True
  77. elif op_name == 'KeepKeys':
  78. if config['Architecture']['algorithm'] == "SRN":
  79. op[op_name]['keep_keys'] = [
  80. 'image', 'encoder_word_pos', 'gsrm_word_pos',
  81. 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
  82. ]
  83. elif config['Architecture']['algorithm'] == "SAR":
  84. op[op_name]['keep_keys'] = ['image', 'valid_ratio']
  85. elif config['Architecture']['algorithm'] == "RobustScanner":
  86. op[op_name][
  87. 'keep_keys'] = ['image', 'valid_ratio', 'word_positons']
  88. else:
  89. op[op_name]['keep_keys'] = ['image']
  90. transforms.append(op)
  91. global_config['infer_mode'] = True
  92. ops = create_operators(transforms, global_config)
  93. save_res_path = config['Global'].get('save_res_path',
  94. "./output/rec/predicts_rec.txt")
  95. if not os.path.exists(os.path.dirname(save_res_path)):
  96. os.makedirs(os.path.dirname(save_res_path))
  97. model.eval()
  98. with open(save_res_path, "w") as fout:
  99. for file in get_image_file_list(config['Global']['infer_img']):
  100. logger.info("infer_img: {}".format(file))
  101. with open(file, 'rb') as f:
  102. img = f.read()
  103. data = {'image': img}
  104. batch = transform(data, ops)
  105. if config['Architecture']['algorithm'] == "SRN":
  106. encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
  107. gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
  108. gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
  109. gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
  110. others = [
  111. paddle.to_tensor(encoder_word_pos_list),
  112. paddle.to_tensor(gsrm_word_pos_list),
  113. paddle.to_tensor(gsrm_slf_attn_bias1_list),
  114. paddle.to_tensor(gsrm_slf_attn_bias2_list)
  115. ]
  116. if config['Architecture']['algorithm'] == "SAR":
  117. valid_ratio = np.expand_dims(batch[-1], axis=0)
  118. img_metas = [paddle.to_tensor(valid_ratio)]
  119. if config['Architecture']['algorithm'] == "RobustScanner":
  120. valid_ratio = np.expand_dims(batch[1], axis=0)
  121. word_positons = np.expand_dims(batch[2], axis=0)
  122. img_metas = [
  123. paddle.to_tensor(valid_ratio),
  124. paddle.to_tensor(word_positons),
  125. ]
  126. if config['Architecture']['algorithm'] == "CAN":
  127. image_mask = paddle.ones(
  128. (np.expand_dims(
  129. batch[0], axis=0).shape), dtype='float32')
  130. label = paddle.ones((1, 36), dtype='int64')
  131. images = np.expand_dims(batch[0], axis=0)
  132. images = paddle.to_tensor(images)
  133. if config['Architecture']['algorithm'] == "SRN":
  134. preds = model(images, others)
  135. elif config['Architecture']['algorithm'] == "SAR":
  136. preds = model(images, img_metas)
  137. elif config['Architecture']['algorithm'] == "RobustScanner":
  138. preds = model(images, img_metas)
  139. elif config['Architecture']['algorithm'] == "CAN":
  140. preds = model([images, image_mask, label])
  141. else:
  142. preds = model(images)
  143. post_result = post_process_class(preds)
  144. info = None
  145. if isinstance(post_result, dict):
  146. rec_info = dict()
  147. for key in post_result:
  148. if len(post_result[key][0]) >= 2:
  149. rec_info[key] = {
  150. "label": post_result[key][0][0],
  151. "score": float(post_result[key][0][1]),
  152. }
  153. info = json.dumps(rec_info, ensure_ascii=False)
  154. elif isinstance(post_result, list) and isinstance(post_result[0],
  155. int):
  156. # for RFLearning CNT branch
  157. info = str(post_result[0])
  158. else:
  159. if len(post_result[0]) >= 2:
  160. info = post_result[0][0] + "\t" + str(post_result[0][1])
  161. if info is not None:
  162. logger.info("\t result: {}".format(info))
  163. fout.write(file + "\t" + info + "\n")
  164. logger.info("success!")
if __name__ == '__main__':
    # program.preprocess() parses the -c/-o command-line options, selects the
    # device, and builds the logger; main() reads these module-level globals.
    config, device, logger, vdl_writer = program.preprocess()
    main()