quant.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. __dir__ = os.path.dirname(os.path.abspath(__file__))
  20. sys.path.append(__dir__)
  21. sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..')))
  22. sys.path.append(
  23. os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools')))
  24. import yaml
  25. import paddle
  26. import paddle.distributed as dist
  27. paddle.seed(2)
  28. from ppocr.data import build_dataloader
  29. from ppocr.modeling.architectures import build_model
  30. from ppocr.losses import build_loss
  31. from ppocr.optimizer import build_optimizer
  32. from ppocr.postprocess import build_post_process
  33. from ppocr.metrics import build_metric
  34. from ppocr.utils.save_load import load_model
  35. import tools.program as program
  36. from paddleslim.dygraph.quant import QAT
  37. dist.get_world_size()
  38. class PACT(paddle.nn.Layer):
  39. def __init__(self):
  40. super(PACT, self).__init__()
  41. alpha_attr = paddle.ParamAttr(
  42. name=self.full_name() + ".pact",
  43. initializer=paddle.nn.initializer.Constant(value=20),
  44. learning_rate=1.0,
  45. regularizer=paddle.regularizer.L2Decay(2e-5))
  46. self.alpha = self.create_parameter(
  47. shape=[1], attr=alpha_attr, dtype='float32')
  48. def forward(self, x):
  49. out_left = paddle.nn.functional.relu(x - self.alpha)
  50. out_right = paddle.nn.functional.relu(-self.alpha - x)
  51. x = x - out_left + out_right
  52. return x
  53. quant_config = {
  54. # weight preprocess type, default is None and no preprocessing is performed.
  55. 'weight_preprocess_type': None,
  56. # activation preprocess type, default is None and no preprocessing is performed.
  57. 'activation_preprocess_type': None,
  58. # weight quantize type, default is 'channel_wise_abs_max'
  59. 'weight_quantize_type': 'channel_wise_abs_max',
  60. # activation quantize type, default is 'moving_average_abs_max'
  61. 'activation_quantize_type': 'moving_average_abs_max',
  62. # weight quantize bit num, default is 8
  63. 'weight_bits': 8,
  64. # activation quantize bit num, default is 8
  65. 'activation_bits': 8,
  66. # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8'
  67. 'dtype': 'int8',
  68. # window size for 'range_abs_max' quantization. default is 10000
  69. 'window_size': 10000,
  70. # The decay coefficient of moving average, default is 0.9
  71. 'moving_rate': 0.9,
  72. # for dygraph quantization, layers of type in quantizable_layer_type will be quantized
  73. 'quantizable_layer_type': ['Conv2D', 'Linear'],
  74. }
  75. def main(config, device, logger, vdl_writer):
  76. # init dist environment
  77. if config['Global']['distributed']:
  78. dist.init_parallel_env()
  79. global_config = config['Global']
  80. # build dataloader
  81. train_dataloader = build_dataloader(config, 'Train', device, logger)
  82. if config['Eval']:
  83. valid_dataloader = build_dataloader(config, 'Eval', device, logger)
  84. else:
  85. valid_dataloader = None
  86. # build post process
  87. post_process_class = build_post_process(config['PostProcess'],
  88. global_config)
  89. # build model
  90. # for rec algorithm
  91. if hasattr(post_process_class, 'character'):
  92. char_num = len(getattr(post_process_class, 'character'))
  93. if config['Architecture']["algorithm"] in ["Distillation",
  94. ]: # distillation model
  95. for key in config['Architecture']["Models"]:
  96. if config['Architecture']['Models'][key]['Head'][
  97. 'name'] == 'MultiHead': # for multi head
  98. if config['PostProcess'][
  99. 'name'] == 'DistillationSARLabelDecode':
  100. char_num = char_num - 2
  101. # update SARLoss params
  102. assert list(config['Loss']['loss_config_list'][-1].keys())[
  103. 0] == 'DistillationSARLoss'
  104. config['Loss']['loss_config_list'][-1][
  105. 'DistillationSARLoss']['ignore_index'] = char_num + 1
  106. out_channels_list = {}
  107. out_channels_list['CTCLabelDecode'] = char_num
  108. out_channels_list['SARLabelDecode'] = char_num + 2
  109. config['Architecture']['Models'][key]['Head'][
  110. 'out_channels_list'] = out_channels_list
  111. else:
  112. config['Architecture']["Models"][key]["Head"][
  113. 'out_channels'] = char_num
  114. elif config['Architecture']['Head'][
  115. 'name'] == 'MultiHead': # for multi head
  116. if config['PostProcess']['name'] == 'SARLabelDecode':
  117. char_num = char_num - 2
  118. # update SARLoss params
  119. assert list(config['Loss']['loss_config_list'][1].keys())[
  120. 0] == 'SARLoss'
  121. if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
  122. config['Loss']['loss_config_list'][1]['SARLoss'] = {
  123. 'ignore_index': char_num + 1
  124. }
  125. else:
  126. config['Loss']['loss_config_list'][1]['SARLoss'][
  127. 'ignore_index'] = char_num + 1
  128. out_channels_list = {}
  129. out_channels_list['CTCLabelDecode'] = char_num
  130. out_channels_list['SARLabelDecode'] = char_num + 2
  131. config['Architecture']['Head'][
  132. 'out_channels_list'] = out_channels_list
  133. else: # base rec model
  134. config['Architecture']["Head"]['out_channels'] = char_num
  135. if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model
  136. config['Loss']['ignore_index'] = char_num - 1
  137. model = build_model(config['Architecture'])
  138. pre_best_model_dict = dict()
  139. # load fp32 model to begin quantization
  140. pre_best_model_dict = load_model(config, model, None, config['Architecture']["model_type"])
  141. freeze_params = False
  142. if config['Architecture']["algorithm"] in ["Distillation"]:
  143. for key in config['Architecture']["Models"]:
  144. freeze_params = freeze_params or config['Architecture']['Models'][
  145. key].get('freeze_params', False)
  146. act = None if freeze_params else PACT
  147. quanter = QAT(config=quant_config, act_preprocess=act)
  148. quanter.quantize(model)
  149. if config['Global']['distributed']:
  150. model = paddle.DataParallel(model)
  151. # build loss
  152. loss_class = build_loss(config['Loss'])
  153. # build optim
  154. optimizer, lr_scheduler = build_optimizer(
  155. config['Optimizer'],
  156. epochs=config['Global']['epoch_num'],
  157. step_each_epoch=len(train_dataloader),
  158. model=model)
  159. # resume PACT training process
  160. pre_best_model_dict = load_model(config, model, optimizer, config['Architecture']["model_type"])
  161. # build metric
  162. eval_class = build_metric(config['Metric'])
  163. logger.info('train dataloader has {} iters, valid dataloader has {} iters'.
  164. format(len(train_dataloader), len(valid_dataloader)))
  165. # start train
  166. program.train(config, train_dataloader, valid_dataloader, device, model,
  167. loss_class, optimizer, lr_scheduler, post_process_class,
  168. eval_class, pre_best_model_dict, logger, vdl_writer)
  169. if __name__ == '__main__':
  170. config, device, logger, vdl_writer = program.preprocess(is_train=True)
  171. main(config, device, logger, vdl_writer)