quant_kl.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. __dir__ = os.path.dirname(os.path.abspath(__file__))
  20. sys.path.append(__dir__)
  21. sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
  22. sys.path.append(
  23. os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))
  24. import yaml
  25. import paddle
  26. import paddle.distributed as dist
  27. paddle.seed(2)
  28. from ppocr.data import build_dataloader
  29. from ppocr.modeling.architectures import build_model
  30. from ppocr.losses import build_loss
  31. from ppocr.optimizer import build_optimizer
  32. from ppocr.postprocess import build_post_process
  33. from ppocr.metrics import build_metric
  34. from ppocr.utils.save_load import load_model
  35. import tools.program as program
  36. import paddleslim
  37. from paddleslim.dygraph.quant import QAT
  38. import numpy as np
  39. dist.get_world_size()
  40. class PACT(paddle.nn.Layer):
  41. def __init__(self):
  42. super(PACT, self).__init__()
  43. alpha_attr = paddle.ParamAttr(
  44. name=self.full_name() + ".pact",
  45. initializer=paddle.nn.initializer.Constant(value=20),
  46. learning_rate=1.0,
  47. regularizer=paddle.regularizer.L2Decay(2e-5))
  48. self.alpha = self.create_parameter(
  49. shape=[1], attr=alpha_attr, dtype='float32')
  50. def forward(self, x):
  51. out_left = paddle.nn.functional.relu(x - self.alpha)
  52. out_right = paddle.nn.functional.relu(-self.alpha - x)
  53. x = x - out_left + out_right
  54. return x
  55. quant_config = {
  56. # weight preprocess type, default is None and no preprocessing is performed.
  57. 'weight_preprocess_type': None,
  58. # activation preprocess type, default is None and no preprocessing is performed.
  59. 'activation_preprocess_type': None,
  60. # weight quantize type, default is 'channel_wise_abs_max'
  61. 'weight_quantize_type': 'channel_wise_abs_max',
  62. # activation quantize type, default is 'moving_average_abs_max'
  63. 'activation_quantize_type': 'moving_average_abs_max',
  64. # weight quantize bit num, default is 8
  65. 'weight_bits': 8,
  66. # activation quantize bit num, default is 8
  67. 'activation_bits': 8,
  68. # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
  69. 'dtype': 'int8',
  70. # window size for 'range_abs_max' quantization. default is 10000
  71. 'window_size': 10000,
  72. # The decay coefficient of moving average, default is 0.9
  73. 'moving_rate': 0.9,
  74. # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
  75. 'quantizable_layer_type': ['Conv2D', 'Linear'],
  76. }
  77. def sample_generator(loader):
  78. def __reader__():
  79. for indx, data in enumerate(loader):
  80. images = np.array(data[0])
  81. yield images
  82. return __reader__
  83. def sample_generator_layoutxlm_ser(loader):
  84. def __reader__():
  85. for indx, data in enumerate(loader):
  86. input_ids = np.array(data[0])
  87. bbox = np.array(data[1])
  88. attention_mask = np.array(data[2])
  89. token_type_ids = np.array(data[3])
  90. images = np.array(data[4])
  91. yield [input_ids, bbox, attention_mask, token_type_ids, images]
  92. return __reader__
  93. def main(config, device, logger, vdl_writer):
  94. # init dist environment
  95. if config['Global']['distributed']:
  96. dist.init_parallel_env()
  97. global_config = config['Global']
  98. # build dataloader
  99. config['Train']['loader']['num_workers'] = 0
  100. is_layoutxlm_ser = config['Architecture']['model_type'] =='kie' and config['Architecture']['Backbone']['name'] == 'LayoutXLMForSer'
  101. train_dataloader = build_dataloader(config, 'Train', device, logger)
  102. if config['Eval']:
  103. config['Eval']['loader']['num_workers'] = 0
  104. valid_dataloader = build_dataloader(config, 'Eval', device, logger)
  105. if is_layoutxlm_ser:
  106. train_dataloader = valid_dataloader
  107. else:
  108. valid_dataloader = None
  109. paddle.enable_static()
  110. exe = paddle.static.Executor(device)
  111. if 'inference_model' in global_config.keys(): # , 'inference_model'):
  112. inference_model_dir = global_config['inference_model']
  113. else:
  114. inference_model_dir = os.path.dirname(global_config['pretrained_model'])
  115. if not (os.path.exists(os.path.join(inference_model_dir, "inference.pdmodel")) and \
  116. os.path.exists(os.path.join(inference_model_dir, "inference.pdiparams")) ):
  117. raise ValueError(
  118. "Please set inference model dir in Global.inference_model or Global.pretrained_model for post-quantazition"
  119. )
  120. if is_layoutxlm_ser:
  121. generator = sample_generator_layoutxlm_ser(train_dataloader)
  122. else:
  123. generator = sample_generator(train_dataloader)
  124. paddleslim.quant.quant_post_static(
  125. executor=exe,
  126. model_dir=inference_model_dir,
  127. model_filename='inference.pdmodel',
  128. params_filename='inference.pdiparams',
  129. quantize_model_path=global_config['save_inference_dir'],
  130. sample_generator=generator,
  131. save_model_filename='inference.pdmodel',
  132. save_params_filename='inference.pdiparams',
  133. batch_size=1,
  134. batch_nums=None)
  135. if __name__ == '__main__':
  136. config, device, logger, vdl_writer = program.preprocess(is_train=True)
  137. main(config, device, logger, vdl_writer)