# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys __dir__ = os.path.dirname(os.path.abspath(__file__)) sys.path.append(__dir__) sys.path.append(os.path.abspath(os.path.join(__dir__, '..', '..', '..'))) sys.path.append( os.path.abspath(os.path.join(__dir__, '..', '..', '..', 'tools'))) import yaml import paddle import paddle.distributed as dist paddle.seed(2) from ppocr.data import build_dataloader from ppocr.modeling.architectures import build_model from ppocr.losses import build_loss from ppocr.optimizer import build_optimizer from ppocr.postprocess import build_post_process from ppocr.metrics import build_metric from ppocr.utils.save_load import load_model import tools.program as program from paddleslim.dygraph.quant import QAT dist.get_world_size() class PACT(paddle.nn.Layer): def __init__(self): super(PACT, self).__init__() alpha_attr = paddle.ParamAttr( name=self.full_name() + ".pact", initializer=paddle.nn.initializer.Constant(value=20), learning_rate=1.0, regularizer=paddle.regularizer.L2Decay(2e-5)) self.alpha = self.create_parameter( shape=[1], attr=alpha_attr, dtype='float32') def forward(self, x): out_left = paddle.nn.functional.relu(x - self.alpha) out_right = paddle.nn.functional.relu(-self.alpha - x) x = x - out_left + out_right return x quant_config = { # weight preprocess type, default is None and no preprocessing is performed. 'weight_preprocess_type': None, # activation preprocess type, default is None and no preprocessing is performed. 'activation_preprocess_type': None, # weight quantize type, default is 'channel_wise_abs_max' 'weight_quantize_type': 'channel_wise_abs_max', # activation quantize type, default is 'moving_average_abs_max' 'activation_quantize_type': 'moving_average_abs_max', # weight quantize bit num, default is 8 'weight_bits': 8, # activation quantize bit num, default is 8 'activation_bits': 8, # data type after quantization, such as 'uint8', 'int8', etc. default is 'int8' 'dtype': 'int8', # window size for 'range_abs_max' quantization. default is 10000 'window_size': 10000, # The decay coefficient of moving average, default is 0.9 'moving_rate': 0.9, # for dygraph quantization, layers of type in quantizable_layer_type will be quantized 'quantizable_layer_type': ['Conv2D', 'Linear'], } def main(config, device, logger, vdl_writer): # init dist environment if config['Global']['distributed']: dist.init_parallel_env() global_config = config['Global'] # build dataloader train_dataloader = build_dataloader(config, 'Train', device, logger) if config['Eval']: valid_dataloader = build_dataloader(config, 'Eval', device, logger) else: valid_dataloader = None # build post process post_process_class = build_post_process(config['PostProcess'], global_config) # build model # for rec algorithm if hasattr(post_process_class, 'character'): char_num = len(getattr(post_process_class, 'character')) if config['Architecture']["algorithm"] in ["Distillation", ]: # distillation model for key in config['Architecture']["Models"]: if config['Architecture']['Models'][key]['Head'][ 'name'] == 'MultiHead': # for multi head if config['PostProcess'][ 'name'] == 'DistillationSARLabelDecode': char_num = char_num - 2 # update SARLoss params assert list(config['Loss']['loss_config_list'][-1].keys())[ 0] == 'DistillationSARLoss' config['Loss']['loss_config_list'][-1][ 'DistillationSARLoss']['ignore_index'] = char_num + 1 out_channels_list = {} out_channels_list['CTCLabelDecode'] = char_num out_channels_list['SARLabelDecode'] = char_num + 2 config['Architecture']['Models'][key]['Head'][ 'out_channels_list'] = out_channels_list else: config['Architecture']["Models"][key]["Head"][ 'out_channels'] = char_num elif config['Architecture']['Head'][ 'name'] == 'MultiHead': # for multi head if config['PostProcess']['name'] == 'SARLabelDecode': char_num = char_num - 2 # update SARLoss params assert list(config['Loss']['loss_config_list'][1].keys())[ 0] == 'SARLoss' if config['Loss']['loss_config_list'][1]['SARLoss'] is None: config['Loss']['loss_config_list'][1]['SARLoss'] = { 'ignore_index': char_num + 1 } else: config['Loss']['loss_config_list'][1]['SARLoss'][ 'ignore_index'] = char_num + 1 out_channels_list = {} out_channels_list['CTCLabelDecode'] = char_num out_channels_list['SARLabelDecode'] = char_num + 2 config['Architecture']['Head'][ 'out_channels_list'] = out_channels_list else: # base rec model config['Architecture']["Head"]['out_channels'] = char_num if config['PostProcess']['name'] == 'SARLabelDecode': # for SAR model config['Loss']['ignore_index'] = char_num - 1 model = build_model(config['Architecture']) pre_best_model_dict = dict() # load fp32 model to begin quantization pre_best_model_dict = load_model(config, model, None, config['Architecture']["model_type"]) freeze_params = False if config['Architecture']["algorithm"] in ["Distillation"]: for key in config['Architecture']["Models"]: freeze_params = freeze_params or config['Architecture']['Models'][ key].get('freeze_params', False) act = None if freeze_params else PACT quanter = QAT(config=quant_config, act_preprocess=act) quanter.quantize(model) if config['Global']['distributed']: model = paddle.DataParallel(model) # build loss loss_class = build_loss(config['Loss']) # build optim optimizer, lr_scheduler = build_optimizer( config['Optimizer'], epochs=config['Global']['epoch_num'], step_each_epoch=len(train_dataloader), model=model) # resume PACT training process pre_best_model_dict = load_model(config, model, optimizer, config['Architecture']["model_type"]) # build metric eval_class = build_metric(config['Metric']) logger.info('train dataloader has {} iters, valid dataloader has {} iters'. format(len(train_dataloader), len(valid_dataloader))) # start train program.train(config, train_dataloader, valid_dataloader, device, model, loss_class, optimizer, lr_scheduler, post_process_class, eval_class, pre_best_model_dict, logger, vdl_writer) if __name__ == '__main__': config, device, logger, vdl_writer = program.preprocess(is_train=True) main(config, device, logger, vdl_writer)