import errno
import os
import time

import numpy as np
import paddle
import paddle.nn as nn
import paddle.distributed as dist

dist.get_world_size()
dist.init_parallel_env()

from loss import build_loss, LossDistill, DMLLoss, KLJSLoss
from optimizer import create_optimizer
from data_loader import build_dataloader
from metric import create_metric
from mv3 import MobileNetV3_large_x0_5, distillmv3_large_x0_5, build_model
from config import preprocess

from paddleslim.dygraph.quant import QAT
from slim.slim_quant import PACT, quant_config
from slim.slim_fpgm import prune_model
from utils import load_model


def _mkdir_if_not_exist(path, logger):
    """
    mkdir if not exists, ignore the exception when multiple processes mkdir together
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'directory {} has already been created by another process'.
                    format(path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))


def save_model(model,
               optimizer,
               model_path,
               logger,
               is_best=False,
               prefix='ppocr',
               **kwargs):
    """
    save model to the target path
    """
    _mkdir_if_not_exist(model_path, logger)
    model_prefix = os.path.join(model_path, prefix)
    paddle.save(model.state_dict(), model_prefix + '.pdparams')
    if type(optimizer) is list:
        paddle.save(optimizer[0].state_dict(), model_prefix + '.pdopt')
        paddle.save(optimizer[1].state_dict(), model_prefix + "_1" + '.pdopt')
    else:
        paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')

    # # save metric and config
    # with open(model_prefix + '.states', 'wb') as f:
    #     pickle.dump(kwargs, f, protocol=2)
    if is_best:
        logger.info('best model saved to {}'.format(model_prefix))
    else:
        logger.info('model saved to {}'.format(model_prefix))


def amp_scaler(config):
    if 'AMP' in config and config['AMP']['use_amp'] is True:
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
            'FLAGS_max_inplace_grad_add': 8,
        }
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["AMP"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["AMP"].get(
            "use_dynamic_loss_scaling", False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
        return scaler
    else:
        return None


def set_seed(seed):
    paddle.seed(seed)
    np.random.seed(seed)


def train(config, scaler=None):
    EPOCH = config['epoch']
    topk = config['topk']

    batch_size = config['TRAIN']['batch_size']
    num_workers = config['TRAIN']['num_workers']
    train_loader = build_dataloader(
        'train', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    # build model
    # model = MobileNetV3_large_x0_5(class_dim=100)
    model = build_model(config)

    # build optimizer
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters())

    # load model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = 'The metrics of the loaded model are as follows: {}'.format(
            ', '.join([
                '{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()
            ]))
        logger.info(pre_str)

    # slim: prune and quant
    if "quant_train" in config and config['quant_train'] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config['prune_train'] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # distributed training
    model.train()
    model = paddle.DataParallel(model)

    # build loss function
    loss_func = build_loss(config)

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # calculate metric
            acc = metric_func(outs, label)

            # calculate loss
            avg_loss = loss_func(outs, label)

            if scaler is None:
                # backward
                avg_loss.backward()
                optimizer.step()
                optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {avg_loss.numpy()[0]}"
                strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
                strs += f", batch_time: {round(et-st, 4)} s"
                logger.info(strs)
                st = time.time()

        if epoch % 10 == 0:
            acc = eval(config, model)
            if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
                best_acc = acc
                best_acc['epoch'] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model,
                optimizer,
                config['save_model_dir'],
                logger,
                is_best,
                prefix="cls")


def train_distill(config, scaler=None):
    EPOCH = config['epoch']
    topk = config['topk']

    batch_size = config['TRAIN']['batch_size']
    num_workers = config['TRAIN']['num_workers']
    train_loader = build_dataloader(
        'train', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    # model = distillmv3_large_x0_5(class_dim=100)
    model = build_model(config)

    # PACT quant train
    if "quant_train" in config and config['quant_train'] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config['prune_train'] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # build optimizer
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters())

    # load model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = 'The metrics of the loaded model are as follows: {}'.format(
            ', '.join([
                '{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()
            ]))
        logger.info(pre_str)

    model.train()
    model = paddle.DataParallel(model)

    # build loss functions
    loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
    loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
    loss_func_js = KLJSLoss(mode='js')

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # calculate metric
            acc = metric_func(outs['student'], label)

            # calculate loss
            avg_loss = loss_func_distill(outs, label)['student'] + \
                       loss_func_distill(outs, label)['student1'] + \
                       loss_func_dml(outs, label)['student_student1']

            # backward
            if scaler is None:
                avg_loss.backward()
                optimizer.step()
                optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
[{idx}/{data_num}], " strs += f"loss: {avg_loss.numpy()[0]}" strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}" strs += f", batch_time: {round(et-st, 4)} s" logger.info(strs) st = time.time() if epoch % 10 == 0: acc = eval(config, model._layers.student) if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']: best_acc = acc best_acc['epoch'] = epoch is_best = True else: is_best = False logger.info( f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}" ) save_model( model, optimizer, config['save_model_dir'], logger, is_best, prefix="cls_distill") def train_distill_multiopt(config, scaler=None): EPOCH = config['epoch'] topk = config['topk'] batch_size = config['TRAIN']['batch_size'] num_workers = config['TRAIN']['num_workers'] train_loader = build_dataloader( 'train', batch_size=batch_size, num_workers=num_workers) # build metric metric_func = create_metric # model = distillmv3_large_x0_5(class_dim=100) model = build_model(config) # build_optimizer optimizer, lr_scheduler = create_optimizer( config, parameter_list=model.student.parameters()) optimizer1, lr_scheduler1 = create_optimizer( config, parameter_list=model.student1.parameters()) # load model pre_best_model_dict = load_model(config, model, optimizer) if len(pre_best_model_dict) > 0: pre_str = 'The metric of loaded metric as follows {}'.format(', '.join( ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()])) logger.info(pre_str) # quant train if "quant_train" in config and config['quant_train'] is True: quanter = QAT(config=quant_config, act_preprocess=PACT) quanter.quantize(model) elif "prune_train" in config and config['prune_train'] is True: model = prune_model(model, [1, 3, 32, 32], 0.1) else: pass model.train() model = paddle.DataParallel(model) # build loss function loss_func_distill = LossDistill(model_name_list=['student', 'student1']) loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1']) loss_func_js = KLJSLoss(mode='js') data_num = len(train_loader) best_acc = {} for epoch in range(EPOCH): st = time.time() for idx, data in enumerate(train_loader): img_batch, label = data img_batch = paddle.transpose(img_batch, [0, 3, 1, 2]) label = paddle.unsqueeze(label, -1) if scaler is not None: with paddle.amp.auto_cast(): outs = model(img_batch) else: outs = model(img_batch) # cal metric acc = metric_func(outs['student'], label) # cal loss avg_loss = loss_func_distill(outs, label)['student'] + loss_func_dml( outs, label)['student_student1'] avg_loss1 = loss_func_distill(outs, label)['student1'] + loss_func_dml( outs, label)['student_student1'] if scaler is None: # backward avg_loss.backward(retain_graph=True) optimizer.step() optimizer.clear_grad() avg_loss1.backward() optimizer1.step() optimizer1.clear_grad() else: scaled_avg_loss = scaler.scale(avg_loss) scaled_avg_loss.backward() scaler.minimize(optimizer, scaled_avg_loss) scaled_avg_loss = scaler.scale(avg_loss1) scaled_avg_loss.backward() scaler.minimize(optimizer1, scaled_avg_loss) if not isinstance(lr_scheduler, float): lr_scheduler.step() if not isinstance(lr_scheduler1, float): lr_scheduler1.step() if idx % 10 == 0: et = time.time() strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], " strs += f"loss: {avg_loss.numpy()[0]}, loss1: {avg_loss1.numpy()[0]}" strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}" strs += f", batch_time: {round(et-st, 4)} s" logger.info(strs) st = time.time() if epoch 
        if epoch % 10 == 0:
            acc = eval(config, model._layers.student)
            if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
                best_acc = acc
                best_acc['epoch'] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model, [optimizer, optimizer1],
                config['save_model_dir'],
                logger,
                is_best,
                prefix="cls_distill_multiopt")


def eval(config, model):
    batch_size = config['VALID']['batch_size']
    num_workers = config['VALID']['num_workers']
    valid_loader = build_dataloader(
        'test', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    outs = []
    labels = []
    for idx, data in enumerate(valid_loader):
        img_batch, label = data
        img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
        label = paddle.unsqueeze(label, -1)
        out = model(img_batch)

        outs.append(out)
        labels.append(label)

    outs = paddle.concat(outs, axis=0)
    labels = paddle.concat(labels, axis=0)
    acc = metric_func(outs, labels)

    strs = f"The metrics are as follows: acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
    logger.info(strs)

    return acc


if __name__ == "__main__":
    config, logger = preprocess(is_train=False)

    # AMP scaler; pass it down so AMP actually takes effect when configured
    scaler = amp_scaler(config)

    model_type = config['model_type']

    if model_type == "cls":
        train(config, scaler)
    elif model_type == "cls_distill":
        train_distill(config, scaler)
    elif model_type == "cls_distill_multiopt":
        train_distill_multiopt(config, scaler)
    else:
        raise ValueError(
            "model_type should be one of ['cls', 'cls_distill', 'cls_distill_multiopt']"
        )
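# A hedged sketch of the config keys this script reads (gathered from the code
# above; the example values are illustrative assumptions, not project defaults):
#
#   model_type: cls              # or cls_distill / cls_distill_multiopt
#   epoch: 100
#   topk: 5
#   save_model_dir: ./output/
#   TRAIN:
#     batch_size: 128
#     num_workers: 4
#   VALID:
#     batch_size: 128
#     num_workers: 4
#   AMP:
#     use_amp: False
#     scale_loss: 1.0
#     use_dynamic_loss_scaling: False
#   quant_train: False           # enable PACT quantization-aware training
#   prune_train: False           # enable FPGM pruning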