# sensitivity_anal.py
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. __dir__ = os.path.dirname(__file__)
  20. sys.path.append(__dir__)
  21. sys.path.append(os.path.join(__dir__, '..', '..', '..'))
  22. sys.path.append(os.path.join(__dir__, '..', '..', '..', 'tools'))
  23. import paddle
  24. import paddle.distributed as dist
  25. from ppocr.data import build_dataloader
  26. from ppocr.modeling.architectures import build_model
  27. from ppocr.losses import build_loss
  28. from ppocr.optimizer import build_optimizer
  29. from ppocr.postprocess import build_post_process
  30. from ppocr.metrics import build_metric
  31. from ppocr.utils.save_load import load_model
  32. import tools.program as program
# NOTE(review): return value is discarded — presumably kept for its side
# effect of touching the distributed environment at import time; confirm,
# otherwise this is dead code.
dist.get_world_size()
  34. def get_pruned_params(parameters):
  35. params = []
  36. for param in parameters:
  37. if len(
  38. param.shape
  39. ) == 4 and 'depthwise' not in param.name and 'transpose' not in param.name and "conv2d_57" not in param.name and "conv2d_56" not in param.name:
  40. params.append(param.name)
  41. return params
  42. def main(config, device, logger, vdl_writer):
  43. # init dist environment
  44. if config['Global']['distributed']:
  45. dist.init_parallel_env()
  46. global_config = config['Global']
  47. # build dataloader
  48. train_dataloader = build_dataloader(config, 'Train', device, logger)
  49. if config['Eval']:
  50. valid_dataloader = build_dataloader(config, 'Eval', device, logger)
  51. else:
  52. valid_dataloader = None
  53. # build post process
  54. post_process_class = build_post_process(config['PostProcess'],
  55. global_config)
  56. # build model
  57. # for rec algorithm
  58. if hasattr(post_process_class, 'character'):
  59. char_num = len(getattr(post_process_class, 'character'))
  60. config['Architecture']["Head"]['out_channels'] = char_num
  61. model = build_model(config['Architecture'])
  62. if config['Architecture']['model_type'] == 'det':
  63. input_shape = [1, 3, 640, 640]
  64. elif config['Architecture']['model_type'] == 'rec':
  65. input_shape = [1, 3, 32, 320]
  66. flops = paddle.flops(model, input_shape)
  67. logger.info("FLOPs before pruning: {}".format(flops))
  68. from paddleslim.dygraph import FPGMFilterPruner
  69. model.train()
  70. pruner = FPGMFilterPruner(model, input_shape)
  71. # build loss
  72. loss_class = build_loss(config['Loss'])
  73. # build optim
  74. optimizer, lr_scheduler = build_optimizer(
  75. config['Optimizer'],
  76. epochs=config['Global']['epoch_num'],
  77. step_each_epoch=len(train_dataloader),
  78. model=model)
  79. # build metric
  80. eval_class = build_metric(config['Metric'])
  81. # load pretrain model
  82. pre_best_model_dict = load_model(config, model, optimizer)
  83. logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
  84. format(len(train_dataloader), len(valid_dataloader)))
  85. # build metric
  86. eval_class = build_metric(config['Metric'])
  87. logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
  88. format(len(train_dataloader), len(valid_dataloader)))
  89. def eval_fn():
  90. metric = program.eval(model, valid_dataloader, post_process_class,
  91. eval_class, False)
  92. if config['Architecture']['model_type'] == 'det':
  93. main_indicator = 'hmean'
  94. else:
  95. main_indicator = 'acc'
  96. logger.info("metric[{}]: {}".format(main_indicator, metric[
  97. main_indicator]))
  98. return metric[main_indicator]
  99. run_sensitive_analysis = False
  100. """
  101. run_sensitive_analysis=True:
  102. Automatically compute the sensitivities of convolutions in a model.
  103. The sensitivity of a convolution is the losses of accuracy on test dataset in
  104. differenct pruned ratios. The sensitivities can be used to get a group of best
  105. ratios with some condition.
  106. run_sensitive_analysis=False:
  107. Set prune trim ratio to a fixed value, such as 10%. The larger the value,
  108. the more convolution weights will be cropped.
  109. """
  110. if run_sensitive_analysis:
  111. params_sensitive = pruner.sensitive(
  112. eval_func=eval_fn,
  113. sen_file="./deploy/slim/prune/sen.pickle",
  114. skip_vars=[
  115. "conv2d_57.w_0", "conv2d_transpose_2.w_0",
  116. "conv2d_transpose_3.w_0"
  117. ])
  118. logger.info(
  119. "The sensitivity analysis results of model parameters saved in sen.pickle"
  120. )
  121. # calculate pruned params's ratio
  122. params_sensitive = pruner._get_ratios_by_loss(
  123. params_sensitive, loss=0.02)
  124. for key in params_sensitive.keys():
  125. logger.info("{}, {}".format(key, params_sensitive[key]))
  126. else:
  127. params_sensitive = {}
  128. for param in model.parameters():
  129. if 'transpose' not in param.name and 'linear' not in param.name:
  130. # set prune ratio as 10%. The larger the value, the more convolution weights will be cropped
  131. params_sensitive[param.name] = 0.1
  132. plan = pruner.prune_vars(params_sensitive, [0])
  133. flops = paddle.flops(model, input_shape)
  134. logger.info("FLOPs after pruning: {}".format(flops))
  135. # start train
  136. program.train(config, train_dataloader, valid_dataloader, device, model,
  137. loss_class, optimizer, lr_scheduler, post_process_class,
  138. eval_class, pre_best_model_dict, logger, vdl_writer)
if __name__ == '__main__':
    # Parse config/CLI args and set up device, logger and VisualDL writer,
    # then run the pruning + fine-tuning pipeline.
    config, device, logger, vdl_writer = program.preprocess(is_train=True)
    main(config, device, logger, vdl_writer)