slim_fpgm.py 665 B

12345678910111213141516171819202122
  1. import paddleslim
  2. import paddle
  3. import numpy as np
  4. from paddleslim.dygraph import FPGMFilterPruner
  5. def prune_model(model, input_shape, prune_ratio=0.1):
  6. flops = paddle.flops(model, input_shape)
  7. pruner = FPGMFilterPruner(model, input_shape)
  8. params_sensitive = {}
  9. for param in model.parameters():
  10. if 'transpose' not in param.name and 'linear' not in param.name:
  11. # set prune ratio as 10%. The larger the value, the more convolution weights will be cropped
  12. params_sensitive[param.name] = prune_ratio
  13. plan = pruner.prune_vars(params_sensitive, [0])
  14. flops = paddle.flops(model, input_shape)
  15. return model