eval.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  # Prepend this file's directory and its parent directory to sys.path.
  # Resulting order at the front of sys.path: [parent, this directory, ...].
  # NOTE(review): presumably this is what lets the `ppocr` and `tools`
  # imports below resolve when the file is run as a script -- confirm
  # against the repository layout.
  __dir__ = os.path.dirname(os.path.abspath(__file__))
  sys.path[:0] = [os.path.abspath(os.path.join(__dir__, '..')), __dir__]
  22. import paddle
  23. from ppocr.data import build_dataloader
  24. from ppocr.modeling.architectures import build_model
  25. from ppocr.postprocess import build_post_process
  26. from ppocr.metrics import build_metric
  27. from ppocr.utils.save_load import load_model
  28. import tools.program as program
  def main():
      """Evaluate the configured model on the 'Eval' split and log its metrics.

      Reads the module-level globals ``config``, ``device`` and ``logger``
      (bound in the ``__main__`` guard from ``program.preprocess()``), so it
      must only be called after those names exist.

      Steps:
          1. Build the eval dataloader and the post-process object.
          2. If the post-process object exposes a ``character`` attribute,
             write the head output sizes into ``config['Architecture']``.
          3. Build the model and decide whether it needs extra inputs.
          4. Optionally set up AMP (grad scaler, O2 decoration).
          5. Load the checkpoint via ``load_model`` and log its stored metrics.
          6. Run ``program.eval`` and log every returned metric.

      Side effects:
          Mutates ``config['Architecture']`` in place and, when
          ``Global.use_amp`` is true, sets Paddle runtime flags.
      """
      global_config = config['Global']

      # build dataloader
      valid_dataloader = build_dataloader(config, 'Eval', device, logger)

      # build post process
      post_process_class = build_post_process(config['PostProcess'],
                                              global_config)

      arch_config = config['Architecture']

      # build model
      # for rec algorithm
      if hasattr(post_process_class, 'character'):
          char_num = len(post_process_class.character)
          if arch_config['algorithm'] in ["Distillation"]:  # distillation model
              for key in arch_config['Models']:
                  head_config = arch_config['Models'][key]['Head']
                  if head_config['name'] == 'MultiHead':  # for multi head
                      # Derive the CTC size per model WITHOUT mutating
                      # char_num: decrementing the shared counter inside the
                      # loop would shrink it again for every further model.
                      # NOTE(review): the "- 2" presumably accounts for two
                      # extra SAR-only symbols in the charset -- confirm
                      # against the SAR label decoder.
                      ctc_char_num = char_num
                      if config['PostProcess'][
                              'name'] == 'DistillationSARLabelDecode':
                          ctc_char_num = char_num - 2
                      head_config['out_channels_list'] = {
                          'CTCLabelDecode': ctc_char_num,
                          'SARLabelDecode': ctc_char_num + 2,
                      }
                  else:
                      head_config['out_channels'] = char_num
          elif arch_config['Head']['name'] == 'MultiHead':  # for multi head
              ctc_char_num = char_num
              if config['PostProcess']['name'] == 'SARLabelDecode':
                  ctc_char_num = char_num - 2
              arch_config['Head']['out_channels_list'] = {
                  'CTCLabelDecode': ctc_char_num,
                  'SARLabelDecode': ctc_char_num + 2,
              }
          else:  # base rec model
              arch_config['Head']['out_channels'] = char_num

      model = build_model(arch_config)

      # Algorithms listed here make `extra_input` true; the flag is simply
      # forwarded to program.eval below.
      extra_input_models = [
          "SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"
      ]
      if arch_config['algorithm'] == 'Distillation':
          # True as soon as any sub-model uses one of the listed algorithms.
          extra_input = any(
              arch_config['Models'][key]['algorithm'] in extra_input_models
              for key in arch_config['Models'])
      else:
          extra_input = arch_config['algorithm'] in extra_input_models

      # model_type handed to program.eval: 'can' overrides the configured
      # value for the CAN algorithm; None when the key is not configured.
      if "model_type" in arch_config:
          if arch_config['algorithm'] == 'CAN':
              model_type = 'can'
          else:
              model_type = arch_config['model_type']
      else:
          model_type = None

      # build metric
      eval_class = build_metric(config['Metric'])

      # amp
      use_amp = global_config.get("use_amp", False)
      amp_level = global_config.get("amp_level", 'O2')
      amp_custom_black_list = global_config.get('amp_custom_black_list', [])
      if use_amp:
          AMP_RELATED_FLAGS_SETTING = {
              'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
              'FLAGS_max_inplace_grad_add': 8,
          }
          # Use the public paddle.set_flags API instead of the legacy
          # paddle.fluid namespace.
          paddle.set_flags(AMP_RELATED_FLAGS_SETTING)
          scale_loss = global_config.get("scale_loss", 1.0)
          use_dynamic_loss_scaling = global_config.get(
              "use_dynamic_loss_scaling", False)
          scaler = paddle.amp.GradScaler(
              init_loss_scaling=scale_loss,
              use_dynamic_loss_scaling=use_dynamic_loss_scaling)
          if amp_level == "O2":
              model = paddle.amp.decorate(
                  models=model, level=amp_level, master_weight=True)
      else:
          scaler = None

      # load_model receives the *configured* model_type (not the 'can'
      # override). `.get` keeps that value when present and avoids a KeyError
      # when the key is absent, matching the optional handling above.
      best_model_dict = load_model(
          config, model, model_type=arch_config.get("model_type"))
      if best_model_dict:
          logger.info('metric in ckpt ***************')
          for k, v in best_model_dict.items():
              logger.info('{}:{}'.format(k, v))

      # start eval
      metric = program.eval(model, valid_dataloader, post_process_class,
                            eval_class, model_type, extra_input, scaler,
                            amp_level, amp_custom_black_list)
      logger.info('metric eval ***************')
      for k, v in metric.items():
          logger.info('{}:{}'.format(k, v))
  if __name__ == '__main__':
      # Script entry point. The four values returned by program.preprocess()
      # are bound as module-level names because main() reads config, device
      # and logger from module scope; vdl_writer is not used in this file's
      # visible code.
      config, device, logger, vdl_writer = program.preprocess()
      main()