# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import sys
  16. __dir__ = os.path.dirname(os.path.abspath(__file__))
  17. sys.path.append(__dir__)
  18. sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
  19. sys.path.insert(
  20. 0, os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))
  21. import argparse
  22. import paddle
  23. from paddle.jit import to_static
  24. from ppocr.modeling.architectures import build_model
  25. from ppocr.postprocess import build_post_process
  26. from ppocr.utils.save_load import load_model
  27. from ppocr.utils.logging import get_logger
  28. from tools.program import load_config, merge_config, ArgsParser
  29. from ppocr.metrics import build_metric
  30. import tools.program as program
  31. from paddleslim.dygraph.quant import QAT
  32. from ppocr.data import build_dataloader
  33. from tools.export_model import export_single_model
  34. def main():
  35. ############################################################################################################
  36. # 1. quantization configs
  37. ############################################################################################################
  38. quant_config = {
  39. # weight preprocess type, default is None and no preprocessing is performed.
  40. 'weight_preprocess_type': None,
  41. # activation preprocess type, default is None and no preprocessing is performed.
  42. 'activation_preprocess_type': None,
  43. # weight quantize type, default is 'channel_wise_abs_max'
  44. 'weight_quantize_type': 'channel_wise_abs_max',
  45. # activation quantize type, default is 'moving_average_abs_max'
  46. 'activation_quantize_type': 'moving_average_abs_max',
  47. # weight quantize bit num, default is 8
  48. 'weight_bits': 8,
  49. # activation quantize bit num, default is 8
  50. 'activation_bits': 8,
  51. # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
  52. 'dtype': 'int8',
  53. # window size for 'range_abs_max' quantization. default is 10000
  54. 'window_size': 10000,
  55. # The decay coefficient of moving average, default is 0.9
  56. 'moving_rate': 0.9,
  57. # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
  58. 'quantizable_layer_type': ['Conv2D', 'Linear'],
  59. }
  60. FLAGS = ArgsParser().parse_args()
  61. config = load_config(FLAGS.config)
  62. config = merge_config(config, FLAGS.opt)
  63. logger = get_logger()
  64. # build post process
  65. post_process_class = build_post_process(config['PostProcess'],
  66. config['Global'])
  67. # build model
  68. if hasattr(post_process_class, 'character'):
  69. char_num = len(getattr(post_process_class, 'character'))
  70. if config['Architecture']["algorithm"] in ["Distillation",
  71. ]: # distillation model
  72. for key in config['Architecture']["Models"]:
  73. if config['Architecture']['Models'][key]['Head'][
  74. 'name'] == 'MultiHead': # for multi head
  75. if config['PostProcess'][
  76. 'name'] == 'DistillationSARLabelDecode':
  77. char_num = char_num - 2
  78. # update SARLoss params
  79. assert list(config['Loss']['loss_config_list'][-1].keys())[
  80. 0] == 'DistillationSARLoss'
  81. config['Loss']['loss_config_list'][-1][
  82. 'DistillationSARLoss']['ignore_index'] = char_num + 1
  83. out_channels_list = {}
  84. out_channels_list['CTCLabelDecode'] = char_num
  85. out_channels_list['SARLabelDecode'] = char_num + 2
  86. config['Architecture']['Models'][key]['Head'][
  87. 'out_channels_list'] = out_channels_list
  88. else:
  89. config['Architecture']["Models"][key]["Head"][
  90. 'out_channels'] = char_num
  91. elif config['Architecture']['Head'][
  92. 'name'] == 'MultiHead': # for multi head
  93. if config['PostProcess']['name'] == 'SARLabelDecode':
  94. char_num = char_num - 2
  95. # update SARLoss params
  96. assert list(config['Loss']['loss_config_list'][1].keys())[
  97. 0] == 'SARLoss'
  98. if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
  99. config['Loss']['loss_config_list'][1]['SARLoss'] = {
  100. 'ignore_index': char_num + 1
  101. }
  102. else:
  103. config['Loss']['loss_config_list'][1]['SARLoss'][
  104. 'ignore_index'] = char_num + 1
  105. out_channels_list = {}
  106. out_channels_list['CTCLabelDecode'] = char_num
  107. out_channels_list['SARLabelDecode'] = char_num + 2
  108. config['Architecture']['Head'][
  109. 'out_channels_list'] = out_channels_list
  110. else: # base rec model
  111. config['Architecture']["Head"]['out_channels'] = char_num
  112. if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
  113. config['Loss']['ignore_index'] = char_num - 1
  114. model = build_model(config['Architecture'])
  115. # get QAT model
  116. quanter = QAT(config=quant_config)
  117. quanter.quantize(model)
  118. load_model(config, model)
  119. # build metric
  120. eval_class = build_metric(config['Metric'])
  121. # build dataloader
  122. valid_dataloader = build_dataloader(config, 'Eval', device, logger)
  123. use_srn = config['Architecture']['algorithm'] == "SRN"
  124. model_type = config['Architecture'].get('model_type', None)
  125. # start eval
  126. metric = program.eval(model, valid_dataloader, post_process_class,
  127. eval_class, model_type, use_srn)
  128. model.eval()
  129. logger.info('metric eval ***************')
  130. for k, v in metric.items():
  131. logger.info('{}:{}'.format(k, v))
  132. save_path = config["Global"]["save_inference_dir"]
  133. arch_config = config["Architecture"]
  134. if arch_config["algorithm"] == "SVTR" and arch_config["Head"][
  135. "name"] != 'MultiHead':
  136. input_shape = config["Eval"]["dataset"]["transforms"][-2][
  137. 'SVTRRecResizeImg']['image_shape']
  138. else:
  139. input_shape = None
  140. if arch_config["algorithm"] in ["Distillation", ]: # distillation model
  141. archs = list(arch_config["Models"].values())
  142. for idx, name in enumerate(model.model_name_list):
  143. sub_model_save_path = os.path.join(save_path, name, "inference")
  144. export_single_model(model.model_list[idx], archs[idx],
  145. sub_model_save_path, logger, input_shape,
  146. quanter)
  147. else:
  148. save_path = os.path.join(save_path, "inference")
  149. export_single_model(model, arch_config, save_path, logger, input_shape,
  150. quanter)
  151. if __name__ == "__main__":
  152. config, device, logger, vdl_writer = program.preprocess()
  153. main()