import errno
import os
import time

import numpy as np
import paddle
import paddle.nn as nn
import paddle.distributed as dist

from loss import build_loss, LossDistill, DMLLoss, KLJSLoss
from optimizer import create_optimizer
from data_loader import build_dataloader
from metric import create_metric
from mv3 import MobileNetV3_large_x0_5, distillmv3_large_x0_5, build_model
from config import preprocess
from paddleslim.dygraph.quant import QAT
from slim.slim_quant import PACT, quant_config
from slim.slim_fpgm import prune_model
from utils import load_model

# initialize the distributed environment for multi-GPU training
dist.get_world_size()
dist.init_parallel_env()
def _mkdir_if_not_exist(path, logger):
    """
    Create the directory if it does not exist; ignore the race condition when
    multiple processes try to create it at the same time.
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'directory {} already exists, possibly created by another process'.
                    format(path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))
def save_model(model,
               optimizer,
               model_path,
               logger,
               is_best=False,
               prefix='ppocr',
               **kwargs):
    """
    save model to the target path
    """
    _mkdir_if_not_exist(model_path, logger)
    model_prefix = os.path.join(model_path, prefix)
    paddle.save(model.state_dict(), model_prefix + '.pdparams')
    if isinstance(optimizer, list):
        # two-optimizer training saves one optimizer state file per optimizer
        paddle.save(optimizer[0].state_dict(), model_prefix + '.pdopt')
        paddle.save(optimizer[1].state_dict(), model_prefix + '_1' + '.pdopt')
    else:
        paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')

    # # save metric and config
    # with open(model_prefix + '.states', 'wb') as f:
    #     pickle.dump(kwargs, f, protocol=2)

    if is_best:
        logger.info('best model saved to {}'.format(model_prefix))
    else:
        logger.info('model saved to {}'.format(model_prefix))
def amp_scaler(config):
    if 'AMP' in config and config['AMP']['use_amp'] is True:
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
            'FLAGS_max_inplace_grad_add': 8,
        }
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config['AMP'].get('scale_loss', 1.0)
        use_dynamic_loss_scaling = config['AMP'].get('use_dynamic_loss_scaling',
                                                     False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
        return scaler
    else:
        return None
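# Sketch of how the scaler returned above is consumed in the training loops
# below (GradScaler.scale / GradScaler.minimize API):
#   scaled = scaler.scale(avg_loss)
#   scaled.backward()
#   scaler.minimize(optimizer, scaled)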
def set_seed(seed):
    paddle.seed(seed)
    np.random.seed(seed)
def train(config, scaler=None):
    EPOCH = config['epoch']
    topk = config['topk']

    batch_size = config['TRAIN']['batch_size']
    num_workers = config['TRAIN']['num_workers']
    train_loader = build_dataloader(
        'train', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    # build model
    # model = MobileNetV3_large_x0_5(class_dim=100)
    model = build_model(config)

    # build optimizer
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters())

    # load model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = 'The metrics of the loaded model are as follows: {}'.format(
            ', '.join(
                ['{}: {}'.format(k, v)
                 for k, v in pre_best_model_dict.items()]))
        logger.info(pre_str)

    # slim prune and quant
    if "quant_train" in config and config['quant_train'] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config['prune_train'] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # distribution
    model.train()
    model = paddle.DataParallel(model)

    # build loss function
    loss_func = build_loss(config)

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # calculate metric
            acc = metric_func(outs, label)

            # calculate loss
            avg_loss = loss_func(outs, label)

            if scaler is None:
                # backward
                avg_loss.backward()
                optimizer.step()
                optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {avg_loss.numpy()[0]}"
                strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
                strs += f", batch_time: {round(et - st, 4)} s"
                logger.info(strs)
                st = time.time()

        if epoch % 10 == 0:
            acc = eval(config, model)
            if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
                best_acc = acc
                best_acc['epoch'] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model,
                optimizer,
                config['save_model_dir'],
                logger,
                is_best,
                prefix="cls")
def train_distill(config, scaler=None):
    EPOCH = config['epoch']
    topk = config['topk']

    batch_size = config['TRAIN']['batch_size']
    num_workers = config['TRAIN']['num_workers']
    train_loader = build_dataloader(
        'train', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    # model = distillmv3_large_x0_5(class_dim=100)
    model = build_model(config)

    # pact quant train
    if "quant_train" in config and config['quant_train'] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config['prune_train'] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # build optimizer
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters())

    # load model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = 'The metrics of the loaded model are as follows: {}'.format(
            ', '.join(
                ['{}: {}'.format(k, v)
                 for k, v in pre_best_model_dict.items()]))
        logger.info(pre_str)

    model.train()
    model = paddle.DataParallel(model)

    # build loss functions
    loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
    loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
    loss_func_js = KLJSLoss(mode='js')

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # calculate metric on the first student's output
            acc = metric_func(outs['student'], label)

            # calculate loss: both students' supervised losses plus the mutual (DML) loss
            avg_loss = loss_func_distill(outs, label)['student'] + \
                       loss_func_distill(outs, label)['student1'] + \
                       loss_func_dml(outs, label)['student_student1']

            # backward
            if scaler is None:
                avg_loss.backward()
                optimizer.step()
                optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {avg_loss.numpy()[0]}"
                strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
                strs += f", batch_time: {round(et - st, 4)} s"
                logger.info(strs)
                st = time.time()

        if epoch % 10 == 0:
            acc = eval(config, model._layers.student)
            if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
                best_acc = acc
                best_acc['epoch'] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model,
                optimizer,
                config['save_model_dir'],
                logger,
                is_best,
                prefix="cls_distill")
def train_distill_multiopt(config, scaler=None):
    EPOCH = config['epoch']
    topk = config['topk']

    batch_size = config['TRAIN']['batch_size']
    num_workers = config['TRAIN']['num_workers']
    train_loader = build_dataloader(
        'train', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    # model = distillmv3_large_x0_5(class_dim=100)
    model = build_model(config)

    # build optimizers: one optimizer/scheduler per student sub-network
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.student.parameters())
    optimizer1, lr_scheduler1 = create_optimizer(
        config, parameter_list=model.student1.parameters())

    # load model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = 'The metrics of the loaded model are as follows: {}'.format(
            ', '.join(
                ['{}: {}'.format(k, v)
                 for k, v in pre_best_model_dict.items()]))
        logger.info(pre_str)

    # quant train
    if "quant_train" in config and config['quant_train'] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config['prune_train'] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    model.train()
    model = paddle.DataParallel(model)

    # build loss functions
    loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
    loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
    loss_func_js = KLJSLoss(mode='js')

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # calculate metric on the first student's output
            acc = metric_func(outs['student'], label)

            # calculate loss: each student gets its own supervised loss plus the shared DML loss
            avg_loss = loss_func_distill(outs, label)['student'] + \
                       loss_func_dml(outs, label)['student_student1']
            avg_loss1 = loss_func_distill(outs, label)['student1'] + \
                        loss_func_dml(outs, label)['student_student1']

            if scaler is None:
                # backward; the first backward keeps the graph so the second loss
                # can backpropagate through the shared forward pass
                avg_loss.backward(retain_graph=True)
                optimizer.step()
                optimizer.clear_grad()

                avg_loss1.backward()
                optimizer1.step()
                optimizer1.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward(retain_graph=True)
                scaler.minimize(optimizer, scaled_avg_loss)

                scaled_avg_loss = scaler.scale(avg_loss1)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer1, scaled_avg_loss)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()
            if not isinstance(lr_scheduler1, float):
                lr_scheduler1.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {avg_loss.numpy()[0]}, loss1: {avg_loss1.numpy()[0]}"
                strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
                strs += f", batch_time: {round(et - st, 4)} s"
                logger.info(strs)
                st = time.time()

        if epoch % 10 == 0:
            acc = eval(config, model._layers.student)
            if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
                best_acc = acc
                best_acc['epoch'] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model, [optimizer, optimizer1],
                config['save_model_dir'],
                logger,
                is_best,
                prefix="cls_distill_multiopt")
def eval(config, model):
    # note: this helper shadows the Python builtin eval within this module
    batch_size = config['VALID']['batch_size']
    num_workers = config['VALID']['num_workers']
    valid_loader = build_dataloader(
        'test', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    outs = []
    labels = []
    for idx, data in enumerate(valid_loader):
        img_batch, label = data
        img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
        label = paddle.unsqueeze(label, -1)
        out = model(img_batch)

        outs.append(out)
        labels.append(label)

    outs = paddle.concat(outs, axis=0)
    labels = paddle.concat(labels, axis=0)
    acc = metric_func(outs, labels)

    strs = f"The metrics are as follows: acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
    logger.info(strs)
    return acc
if __name__ == "__main__":
    config, logger = preprocess(is_train=False)

    # AMP scaler (None when AMP is disabled in the config)
    scaler = amp_scaler(config)

    model_type = config['model_type']

    if model_type == "cls":
        train(config, scaler)
    elif model_type == "cls_distill":
        train_distill(config, scaler)
    elif model_type == "cls_distill_multiopt":
        train_distill_multiopt(config, scaler)
    else:
        raise ValueError(
            "model_type should be one of ['cls', 'cls_distill', 'cls_distill_multiopt']"
        )
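# Example of the config fields this script reads; values below are illustrative
# placeholders only, the real config is produced by config.preprocess():
#
#   model_type: cls                 # or cls_distill / cls_distill_multiopt
#   epoch: 100
#   topk: 5
#   save_model_dir: ./output/
#   TRAIN: {batch_size: 128, num_workers: 4}
#   VALID: {batch_size: 128, num_workers: 4}
#   AMP: {use_amp: false, scale_loss: 1.0, use_dynamic_loss_scaling: false}
#   quant_train: false
#   prune_train: false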