train.py

import errno
import os
import time

import numpy as np
import paddle
import paddle.nn as nn
import paddle.distributed as dist

# initialize the distributed environment so paddle.DataParallel can be used below
dist.get_world_size()
dist.init_parallel_env()

from loss import build_loss, LossDistill, DMLLoss, KLJSLoss
from optimizer import create_optimizer
from data_loader import build_dataloader
from metric import create_metric
from mv3 import MobileNetV3_large_x0_5, distillmv3_large_x0_5, build_model
from config import preprocess
from paddleslim.dygraph.quant import QAT
from slim.slim_quant import PACT, quant_config
from slim.slim_fpgm import prune_model
from utils import load_model
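# This script is driven entirely by the config returned from preprocess(); the keys read
# below are: 'model_type', 'epoch', 'topk', 'save_model_dir', 'TRAIN'/'VALID' (each with
# 'batch_size' and 'num_workers'), plus the optional 'AMP', 'quant_train' and 'prune_train'
# switches for mixed precision, quantization-aware training and FPGM pruning.
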
def _mkdir_if_not_exist(path, logger):
    """
    Create the directory if it does not exist; ignore the error raised when
    several processes try to create it at the same time.
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    '{} already exists (probably created by another process)'.
                    format(path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))

def save_model(model,
               optimizer,
               model_path,
               logger,
               is_best=False,
               prefix='ppocr',
               **kwargs):
    """
    Save the model and optimizer state under the target path.
    """
    _mkdir_if_not_exist(model_path, logger)
    model_prefix = os.path.join(model_path, prefix)
    paddle.save(model.state_dict(), model_prefix + '.pdparams')
    if isinstance(optimizer, list):
        paddle.save(optimizer[0].state_dict(), model_prefix + '.pdopt')
        paddle.save(optimizer[1].state_dict(), model_prefix + '_1' + '.pdopt')
    else:
        paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')

    # # save metric and config
    # with open(model_prefix + '.states', 'wb') as f:
    #     pickle.dump(kwargs, f, protocol=2)
    if is_best:
        logger.info('save best model to {}'.format(model_prefix))
    else:
        logger.info('save model to {}'.format(model_prefix))

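# For reference: with a single optimizer, save_model() writes "<save_model_dir>/<prefix>.pdparams"
# and "<prefix>.pdopt"; when a list of two optimizers is passed (the multi-optimizer distill
# case below), the second optimizer state additionally goes to "<prefix>_1.pdopt".
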
def amp_scaler(config):
    if 'AMP' in config and config['AMP']['use_amp'] is True:
        AMP_RELATED_FLAGS_SETTING = {
            'FLAGS_cudnn_batchnorm_spatial_persistent': 1,
            'FLAGS_max_inplace_grad_add': 8,
        }
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config['AMP'].get('scale_loss', 1.0)
        use_dynamic_loss_scaling = config['AMP'].get(
            'use_dynamic_loss_scaling', False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
        return scaler
    else:
        return None

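# A minimal sketch of the 'AMP' config section that amp_scaler() looks for; the values
# here are illustrative only and not taken from the original config files:
#
#   AMP:
#     use_amp: True
#     scale_loss: 1024.0
#     use_dynamic_loss_scaling: True
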
def set_seed(seed):
    paddle.seed(seed)
    np.random.seed(seed)

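# Note: set_seed() is provided as a reproducibility helper but is not called from the
# __main__ entry point at the bottom of this file.
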
def train(config, scaler=None):
    EPOCH = config['epoch']
    topk = config['topk']

    batch_size = config['TRAIN']['batch_size']
    num_workers = config['TRAIN']['num_workers']
    train_loader = build_dataloader(
        'train', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    # build model
    # model = MobileNetV3_large_x0_5(class_dim=100)
    model = build_model(config)

    # build optimizer
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters())

    # load pretrained / resumed model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = 'The metrics of the loaded model are as follows: {}'.format(
            ', '.join(
                ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
        logger.info(pre_str)

    # slim: quantization-aware training or FPGM pruning
    if "quant_train" in config and config['quant_train'] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config['prune_train'] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # distribution
    model.train()
    model = paddle.DataParallel(model)

    # build loss function
    loss_func = build_loss(config)

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # calculate metric
            acc = metric_func(outs, label)

            # calculate loss
            avg_loss = loss_func(outs, label)

            if scaler is None:
                # backward
                avg_loss.backward()
                optimizer.step()
                optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {avg_loss.numpy()[0]}"
                strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
                strs += f", batch_time: {round(et - st, 4)} s"
                logger.info(strs)
                st = time.time()

        if epoch % 10 == 0:
            acc = eval(config, model)
            if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
                best_acc = acc
                best_acc['epoch'] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model,
                optimizer,
                config['save_model_dir'],
                logger,
                is_best,
                prefix="cls")

def train_distill(config, scaler=None):
    EPOCH = config['epoch']
    topk = config['topk']

    batch_size = config['TRAIN']['batch_size']
    num_workers = config['TRAIN']['num_workers']
    train_loader = build_dataloader(
        'train', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    # model = distillmv3_large_x0_5(class_dim=100)
    model = build_model(config)

    # PACT quantization-aware training or FPGM pruning
    if "quant_train" in config and config['quant_train'] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config['prune_train'] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    # build optimizer
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.parameters())

    # load pretrained / resumed model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = 'The metrics of the loaded model are as follows: {}'.format(
            ', '.join(
                ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
        logger.info(pre_str)

    model.train()
    model = paddle.DataParallel(model)

    # build loss functions
    loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
    loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
    loss_func_js = KLJSLoss(mode='js')

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # calculate metric
            acc = metric_func(outs['student'], label)

            # calculate loss
            avg_loss = loss_func_distill(outs, label)['student'] + \
                       loss_func_distill(outs, label)['student1'] + \
                       loss_func_dml(outs, label)['student_student1']

            # backward
            if scaler is None:
                avg_loss.backward()
                optimizer.step()
                optimizer.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {avg_loss.numpy()[0]}"
                strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
                strs += f", batch_time: {round(et - st, 4)} s"
                logger.info(strs)
                st = time.time()

        if epoch % 10 == 0:
            acc = eval(config, model._layers.student)
            if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
                best_acc = acc
                best_acc['epoch'] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model,
                optimizer,
                config['save_model_dir'],
                logger,
                is_best,
                prefix="cls_distill")

def train_distill_multiopt(config, scaler=None):
    EPOCH = config['epoch']
    topk = config['topk']

    batch_size = config['TRAIN']['batch_size']
    num_workers = config['TRAIN']['num_workers']
    train_loader = build_dataloader(
        'train', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    # model = distillmv3_large_x0_5(class_dim=100)
    model = build_model(config)

    # build one optimizer per student branch
    optimizer, lr_scheduler = create_optimizer(
        config, parameter_list=model.student.parameters())
    optimizer1, lr_scheduler1 = create_optimizer(
        config, parameter_list=model.student1.parameters())

    # load pretrained / resumed model
    pre_best_model_dict = load_model(config, model, optimizer)
    if len(pre_best_model_dict) > 0:
        pre_str = 'The metrics of the loaded model are as follows: {}'.format(
            ', '.join(
                ['{}: {}'.format(k, v) for k, v in pre_best_model_dict.items()]))
        logger.info(pre_str)

    # quantization-aware training or FPGM pruning
    if "quant_train" in config and config['quant_train'] is True:
        quanter = QAT(config=quant_config, act_preprocess=PACT)
        quanter.quantize(model)
    elif "prune_train" in config and config['prune_train'] is True:
        model = prune_model(model, [1, 3, 32, 32], 0.1)
    else:
        pass

    model.train()
    model = paddle.DataParallel(model)

    # build loss functions
    loss_func_distill = LossDistill(model_name_list=['student', 'student1'])
    loss_func_dml = DMLLoss(model_name_pairs=['student', 'student1'])
    loss_func_js = KLJSLoss(mode='js')

    data_num = len(train_loader)

    best_acc = {}
    for epoch in range(EPOCH):
        st = time.time()
        for idx, data in enumerate(train_loader):
            img_batch, label = data
            img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
            label = paddle.unsqueeze(label, -1)

            if scaler is not None:
                with paddle.amp.auto_cast():
                    outs = model(img_batch)
            else:
                outs = model(img_batch)

            # calculate metric
            acc = metric_func(outs['student'], label)

            # calculate one loss per student branch
            avg_loss = loss_func_distill(outs, label)['student'] + \
                       loss_func_dml(outs, label)['student_student1']
            avg_loss1 = loss_func_distill(outs, label)['student1'] + \
                        loss_func_dml(outs, label)['student_student1']

            if scaler is None:
                # backward
                avg_loss.backward(retain_graph=True)
                optimizer.step()
                optimizer.clear_grad()

                avg_loss1.backward()
                optimizer1.step()
                optimizer1.clear_grad()
            else:
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)

                scaled_avg_loss = scaler.scale(avg_loss1)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer1, scaled_avg_loss)

            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()
            if not isinstance(lr_scheduler1, float):
                lr_scheduler1.step()

            if idx % 10 == 0:
                et = time.time()
                strs = f"epoch: [{epoch}/{EPOCH}], iter: [{idx}/{data_num}], "
                strs += f"loss: {avg_loss.numpy()[0]}, loss1: {avg_loss1.numpy()[0]}"
                strs += f", acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
                strs += f", batch_time: {round(et - st, 4)} s"
                logger.info(strs)
                st = time.time()

        if epoch % 10 == 0:
            acc = eval(config, model._layers.student)
            if len(best_acc) < 1 or acc['top5'].numpy()[0] > best_acc['top5']:
                best_acc = acc
                best_acc['epoch'] = epoch
                is_best = True
            else:
                is_best = False
            logger.info(
                f"The best acc: acc_topk1: {best_acc['top1'].numpy()[0]}, acc_top5: {best_acc['top5'].numpy()[0]}, best_epoch: {best_acc['epoch']}"
            )
            save_model(
                model, [optimizer, optimizer1],
                config['save_model_dir'],
                logger,
                is_best,
                prefix="cls_distill_multiopt")

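# Note on the multi-optimizer variant above: avg_loss and avg_loss1 are built from the same
# forward pass (both include the shared DML term), so the first backward call uses
# retain_graph=True; each student branch is then updated by its own optimizer and scheduler,
# and both optimizer states are saved via the [optimizer, optimizer1] list.
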
def eval(config, model):
    batch_size = config['VALID']['batch_size']
    num_workers = config['VALID']['num_workers']
    valid_loader = build_dataloader(
        'test', batch_size=batch_size, num_workers=num_workers)

    # build metric
    metric_func = create_metric

    outs = []
    labels = []
    for idx, data in enumerate(valid_loader):
        img_batch, label = data
        img_batch = paddle.transpose(img_batch, [0, 3, 1, 2])
        label = paddle.unsqueeze(label, -1)
        out = model(img_batch)

        outs.append(out)
        labels.append(label)

    outs = paddle.concat(outs, axis=0)
    labels = paddle.concat(labels, axis=0)
    acc = metric_func(outs, labels)

    strs = f"The metrics are as follows: acc_topk1: {acc['top1'].numpy()[0]}, acc_top5: {acc['top5'].numpy()[0]}"
    logger.info(strs)
    return acc

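# eval() runs the whole validation split, concatenates the per-batch outputs and labels in
# memory, and reports top-1 / top-5 accuracy through create_metric; it is also reused every
# 10 epochs inside the training loops above to decide whether to save a "best" checkpoint.
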
if __name__ == "__main__":
    config, logger = preprocess(is_train=False)

    # AMP scaler (None when AMP is disabled in the config)
    scaler = amp_scaler(config)

    model_type = config['model_type']

    if model_type == "cls":
        train(config, scaler)
    elif model_type == "cls_distill":
        train_distill(config, scaler)
    elif model_type == "cls_distill_multiopt":
        train_distill_multiopt(config, scaler)
    else:
        raise ValueError(
            "model_type should be one of "
            "['cls', 'cls_distill', 'cls_distill_multiopt']")