# export_prune_model.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. __dir__ = os.path.dirname(__file__)
  20. sys.path.append(__dir__)
  21. sys.path.append(os.path.join(__dir__, '..', '..', '..'))
  22. sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
  23. import paddle
  24. from ppocr.data import build_dataloader
  25. from ppocr.modeling.architectures import build_model
  26. from ppocr.postprocess import build_post_process
  27. from ppocr.metrics import build_metric
  28. from ppocr.utils.save_load import load_model
  29. import tools.program as program
  30. def main(config, device, logger, vdl_writer):
  31. global_config = config['Global']
  32. # build dataloader
  33. valid_dataloader = build_dataloader(config, 'Eval', device, logger)
  34. # build post process
  35. post_process_class = build_post_process(config['PostProcess'],
  36. global_config)
  37. # build model
  38. # for rec algorithm
  39. if hasattr(post_process_class, 'character'):
  40. char_num = len(getattr(post_process_class, 'character'))
  41. config['Architecture']["Head"]['out_channels'] = char_num
  42. model = build_model(config['Architecture'])
  43. if config['Architecture']['model_type'] == 'det':
  44. input_shape = [1, 3, 640, 640]
  45. elif config['Architecture']['model_type'] == 'rec':
  46. input_shape = [1, 3, 32, 320]
  47. flops = paddle.flops(model, input_shape)
  48. logger.info("FLOPs before pruning: {}".format(flops))
  49. from paddleslim.dygraph import FPGMFilterPruner
  50. model.train()
  51. pruner = FPGMFilterPruner(model, input_shape)
  52. # build metric
  53. eval_class = build_metric(config['Metric'])
  54. def eval_fn():
  55. metric = program.eval(model, valid_dataloader, post_process_class,
  56. eval_class)
  57. if config['Architecture']['model_type'] == 'det':
  58. main_indicator = 'hmean'
  59. else:
  60. main_indicator = 'acc'
  61. logger.info("metric[{}]: {}".format(main_indicator, metric[
  62. main_indicator]))
  63. return metric[main_indicator]
  64. params_sensitive = pruner.sensitive(
  65. eval_func=eval_fn,
  66. sen_file="./sen.pickle",
  67. skip_vars=[
  68. "conv2d_57.w_0", "conv2d_transpose_2.w_0", "conv2d_transpose_3.w_0"
  69. ])
  70. logger.info(
  71. "The sensitivity analysis results of model parameters saved in sen.pickle"
  72. )
  73. # calculate pruned params's ratio
  74. params_sensitive = pruner._get_ratios_by_loss(params_sensitive, loss=0.02)
  75. for key in params_sensitive.keys():
  76. logger.info("{}, {}".format(key, params_sensitive[key]))
  77. plan = pruner.prune_vars(params_sensitive, [0])
  78. flops = paddle.flops(model, input_shape)
  79. logger.info("FLOPs after pruning: {}".format(flops))
  80. # load pretrain model
  81. load_model(config, model)
  82. metric = program.eval(model, valid_dataloader, post_process_class,
  83. eval_class)
  84. if config['Architecture']['model_type'] == 'det':
  85. main_indicator = 'hmean'
  86. else:
  87. main_indicator = 'acc'
  88. logger.info("metric['']: {}".format(main_indicator, metric[main_indicator]))
  89. # start export model
  90. from paddle.jit import to_static
  91. infer_shape = [3, -1, -1]
  92. if config['Architecture']['model_type'] == "rec":
  93. infer_shape = [3, 32, -1] # for rec model, H must be 32
  94. if 'Transform' in config['Architecture'] and config['Architecture'][
  95. 'Transform'] is not None and config['Architecture'][
  96. 'Transform']['name'] == 'TPS':
  97. logger.info(
  98. 'When there is tps in the network, variable length input is not supported, and the input size needs to be the same as during training'
  99. )
  100. infer_shape[-1] = 100
  101. model = to_static(
  102. model,
  103. input_spec=[
  104. paddle.static.InputSpec(
  105. shape=[None] + infer_shape, dtype='float32')
  106. ])
  107. save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
  108. paddle.jit.save(model, save_path)
  109. logger.info('inference model is saved to {}'.format(save_path))
  110. if __name__ == '__main__':
  111. config, device, logger, vdl_writer = program.preprocess(is_train=True)
  112. main(config, device, logger, vdl_writer)