# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
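
# Make this script's directory and the repository root importable so that
# `ppocr` and `tools` resolve when train.py is run directly from the repo.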
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, '..')))

import yaml
import paddle
import paddle.distributed as dist

from ppocr.data import build_dataloader
from ppocr.modeling.architectures import build_model, apply_to_static
from ppocr.losses import build_loss
from ppocr.optimizer import build_optimizer
from ppocr.postprocess import build_post_process
from ppocr.metrics import build_metric
from ppocr.utils.save_load import load_model
from ppocr.utils.utility import set_seed
import tools.program as program
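
# The return value of this call is discarded; it only queries the distributed
# world size once at import time.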
dist.get_world_size()


def main(config, device, logger, vdl_writer):
    # init dist environment
    if config['Global']['distributed']:
        dist.init_parallel_env()

    global_config = config['Global']

    # build dataloader
    train_dataloader = build_dataloader(config, 'Train', device, logger)
    if len(train_dataloader) == 0:
        logger.error(
            "No images in the train dataset, please ensure\n"
            "\t1. The number of images in the train label_file_list is greater than or equal to the batch size.\n"
            "\t2. The annotation file and path in the configuration file are provided normally.")
        return

    if config['Eval']:
        valid_dataloader = build_dataloader(config, 'Eval', device, logger)
    else:
        valid_dataloader = None

    # build post process
    post_process_class = build_post_process(config['PostProcess'],
                                            global_config)

    # build model
    # for rec algorithm
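    # For recognition models, the head's output channels must match the size
    # of the character dictionary exposed by the post-process class (plus the
    # extra CTC/SAR symbols), so char_num is patched into the config before
    # build_model is called.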
    if hasattr(post_process_class, 'character'):
        char_num = len(getattr(post_process_class, 'character'))
        if config['Architecture']["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config['Architecture']["Models"]:
                if config['Architecture']['Models'][key]['Head'][
                        'name'] == 'MultiHead':  # for multi head
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        char_num = char_num - 2
                    # update SARLoss params
                    assert list(config['Loss']['loss_config_list'][-1].keys())[
                        0] == 'DistillationSARLoss'
                    config['Loss']['loss_config_list'][-1][
                        'DistillationSARLoss']['ignore_index'] = char_num + 1
                    out_channels_list = {}
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config['Architecture']["Models"][key]["Head"][
                        'out_channels'] = char_num
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # for multi head
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            # update SARLoss params
            assert list(config['Loss']['loss_config_list'][1].keys())[
                0] == 'SARLoss'
            if config['Loss']['loss_config_list'][1]['SARLoss'] is None:
                config['Loss']['loss_config_list'][1]['SARLoss'] = {
                    'ignore_index': char_num + 1
                }
            else:
                config['Loss']['loss_config_list'][1]['SARLoss'][
                    'ignore_index'] = char_num + 1
            out_channels_list = {}
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config['Architecture']["Head"]['out_channels'] = char_num

        if config['PostProcess']['name'] == 'SARLabelDecode':  # for SAR model
            config['Loss']['ignore_index'] = char_num - 1

    model = build_model(config['Architecture'])
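
    # Optionally convert BatchNorm layers to SyncBatchNorm so that batch
    # statistics are synchronized across devices during multi-GPU training.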
    use_sync_bn = config["Global"].get("use_sync_bn", False)
    if use_sync_bn:
        model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        logger.info('convert_sync_batchnorm')
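
    # Optionally convert the model to static graph mode; apply_to_static is a
    # no-op unless the corresponding switch is enabled in the config.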
    model = apply_to_static(model, config, logger)

    # build loss
    loss_class = build_loss(config['Loss'])

    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config['Optimizer'],
        epochs=config['Global']['epoch_num'],
        step_each_epoch=len(train_dataloader),
        model=model)

    # build metric
    eval_class = build_metric(config['Metric'])

    logger.info('train dataloader has {} iters'.format(len(train_dataloader)))
    if valid_dataloader is not None:
        logger.info('valid dataloader has {} iters'.format(
            len(valid_dataloader)))
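
    # Mixed-precision setup: when use_amp is on, set framework flags, create a
    # gradient scaler for loss scaling, and for AMP level O2 decorate the
    # model and optimizer so parameters are cast with FP32 master weights kept.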
    use_amp = config["Global"].get("use_amp", False)
    amp_level = config["Global"].get("amp_level", 'O2')
    amp_custom_black_list = config['Global'].get('amp_custom_black_list', [])
    if use_amp:
        AMP_RELATED_FLAGS_SETTING = {'FLAGS_max_inplace_grad_add': 8, }
        if paddle.is_compiled_with_cuda():
            AMP_RELATED_FLAGS_SETTING.update({
                'FLAGS_cudnn_batchnorm_spatial_persistent': 1
            })
        paddle.fluid.set_flags(AMP_RELATED_FLAGS_SETTING)
        scale_loss = config["Global"].get("scale_loss", 1.0)
        use_dynamic_loss_scaling = config["Global"].get(
            "use_dynamic_loss_scaling", False)
        scaler = paddle.amp.GradScaler(
            init_loss_scaling=scale_loss,
            use_dynamic_loss_scaling=use_dynamic_loss_scaling)
        if amp_level == "O2":
            model, optimizer = paddle.amp.decorate(
                models=model,
                optimizers=optimizer,
                level=amp_level,
                master_weight=True)
    else:
        scaler = None

    # load pretrain model
    pre_best_model_dict = load_model(config, model, optimizer,
                                     config['Architecture']["model_type"])

    if config['Global']['distributed']:
        model = paddle.DataParallel(model)

    # start train
    program.train(config, train_dataloader, valid_dataloader, device, model,
                  loss_class, optimizer, lr_scheduler, post_process_class,
                  eval_class, pre_best_model_dict, logger, vdl_writer, scaler,
                  amp_level, amp_custom_black_list)


def test_reader(config, device, logger):
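    """Benchmark the train dataloader: iterate over it once and log the time
    taken to produce each batch."""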
    loader = build_dataloader(config, 'Train', device, logger)
    import time
    starttime = time.time()
    count = 0
    try:
        for data in loader():
            count += 1
            if count % 1 == 0:  # log every batch
                batch_time = time.time() - starttime
                starttime = time.time()
                logger.info("reader: {}, {}, {}".format(
                    count, len(data[0]), batch_time))
    except Exception as e:
        logger.info(e)
    logger.info("finish reader: {}, Success!".format(count))


if __name__ == '__main__':
    config, device, logger, vdl_writer = program.preprocess(is_train=True)
    seed = config['Global']['seed'] if 'seed' in config['Global'] else 1024
    set_seed(seed)
    main(config, device, logger, vdl_writer)
    # test_reader(config, device, logger)