save_load.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import errno
import os
import pickle

import six
import paddle

from ppocr.utils.logging import get_logger

__all__ = ['load_model']


def _mkdir_if_not_exist(path, logger):
    """
    Create the directory if it does not exist. Ignore the OSError raised
    when several processes try to create the same directory concurrently.
    """
    if not os.path.exists(path):
        try:
            os.makedirs(path)
        except OSError as e:
            if e.errno == errno.EEXIST and os.path.isdir(path):
                logger.warning(
                    'directory {} already created by another process'.format(
                        path))
            else:
                raise OSError('Failed to mkdir {}'.format(path))
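

# A minimal usage sketch (the output path below is a hypothetical
# placeholder). Because EEXIST is tolerated, the helper is safe to call
# from several distributed trainer processes at once:
#
#     logger = get_logger()
#     _mkdir_if_not_exist('./output/det_db', logger)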


def load_model(config, model, optimizer=None, model_type='det'):
    """
    Load model parameters from a checkpoint or a pretrained model.
    """
    logger = get_logger()
    global_config = config['Global']
    checkpoints = global_config.get('checkpoints')
    pretrained_model = global_config.get('pretrained_model')
    best_model_dict = {}
    is_float16 = False
    is_nlp_model = model_type == 'kie' and config["Architecture"][
        "algorithm"] not in ["SDMGR"]

    if is_nlp_model:
        # NOTE: for kie model distillation, resume training is not supported now
        if config["Architecture"]["algorithm"] in ["Distillation"]:
            return best_model_dict
        checkpoints = config['Architecture']['Backbone']['checkpoints']
        # load kie method metric
        if checkpoints:
            if os.path.exists(os.path.join(checkpoints, 'metric.states')):
                with open(os.path.join(checkpoints, 'metric.states'),
                          'rb') as f:
                    # 'latin1' keeps states pickled under Python 2 loadable
                    # in Python 3
                    states_dict = pickle.load(f) if six.PY2 else pickle.load(
                        f, encoding='latin1')
                best_model_dict = states_dict.get('best_model_dict', {})
                if 'epoch' in states_dict:
                    best_model_dict['start_epoch'] = states_dict['epoch'] + 1
                logger.info("resume from {}".format(checkpoints))

            if optimizer is not None:
                # drop a trailing path separator before appending suffixes
                if checkpoints[-1] in ['/', '\\']:
                    checkpoints = checkpoints[:-1]
                if os.path.exists(checkpoints + '.pdopt'):
                    optim_dict = paddle.load(checkpoints + '.pdopt')
                    optimizer.set_state_dict(optim_dict)
                else:
                    logger.warning(
                        "{}.pdopt does not exist, optimizer params are not loaded".
                        format(checkpoints))

        return best_model_dict
    if checkpoints:
        if checkpoints.endswith('.pdparams'):
            checkpoints = checkpoints.replace('.pdparams', '')
        assert os.path.exists(checkpoints + ".pdparams"), \
            "The {}.pdparams does not exist!".format(checkpoints)

        # load params from trained model
        params = paddle.load(checkpoints + '.pdparams')
        state_dict = model.state_dict()
        new_state_dict = {}
        for key, value in state_dict.items():
            if key not in params:
                logger.warning("{} not in loaded params {} !".format(
                    key, params.keys()))
                continue
            pre_value = params[key]
            if pre_value.dtype == paddle.float16:
                is_float16 = True
            if pre_value.dtype != value.dtype:
                pre_value = pre_value.astype(value.dtype)
            if list(value.shape) == list(pre_value.shape):
                new_state_dict[key] = pre_value
            else:
                logger.warning(
                    "The shape of model params {} {} does not match the loaded params shape {} !".
                    format(key, value.shape, pre_value.shape))
        model.set_state_dict(new_state_dict)
        if is_float16:
            logger.info(
                "The parameter type is float16, which is converted to float32 when loading"
            )
        if optimizer is not None:
            if os.path.exists(checkpoints + '.pdopt'):
                optim_dict = paddle.load(checkpoints + '.pdopt')
                optimizer.set_state_dict(optim_dict)
            else:
                logger.warning(
                    "{}.pdopt does not exist, optimizer params are not loaded".
                    format(checkpoints))

        if os.path.exists(checkpoints + '.states'):
            with open(checkpoints + '.states', 'rb') as f:
                states_dict = pickle.load(f) if six.PY2 else pickle.load(
                    f, encoding='latin1')
            best_model_dict = states_dict.get('best_model_dict', {})
            if 'epoch' in states_dict:
                best_model_dict['start_epoch'] = states_dict['epoch'] + 1
        logger.info("resume from {}".format(checkpoints))
    elif pretrained_model:
        is_float16 = load_pretrained_params(model, pretrained_model)
    else:
        logger.info('train from scratch')
    best_model_dict['is_float16'] = is_float16
    return best_model_dict
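

# A minimal usage sketch for resuming or warm-starting training. The config
# values, paths, and the ``model``/``optimizer`` objects below are
# hypothetical placeholders; in PaddleOCR they come from the parsed YAML
# config and the training program.
#
#     config = {
#         'Global': {
#             'checkpoints': None,
#             'pretrained_model': './pretrain_models/det_mv3_db',
#         },
#         'Architecture': {'model_type': 'det', 'algorithm': 'DB'},
#     }
#     best_model_dict = load_model(config, model, optimizer, model_type='det')
#     start_epoch = best_model_dict.get('start_epoch', 1)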


def load_pretrained_params(model, path):
    """
    Load pretrained parameters into the model, skipping keys that are
    missing from the model or whose shapes do not match.
    """
    logger = get_logger()
    if path.endswith('.pdparams'):
        path = path.replace('.pdparams', '')
    assert os.path.exists(path + ".pdparams"), \
        "The {}.pdparams does not exist!".format(path)

    params = paddle.load(path + '.pdparams')
    state_dict = model.state_dict()
    new_state_dict = {}
    is_float16 = False
    for k1 in params.keys():
        if k1 not in state_dict.keys():
            logger.warning("The pretrained params {} not in model".format(k1))
        else:
            if params[k1].dtype == paddle.float16:
                is_float16 = True
            if params[k1].dtype != state_dict[k1].dtype:
                params[k1] = params[k1].astype(state_dict[k1].dtype)
            if list(state_dict[k1].shape) == list(params[k1].shape):
                new_state_dict[k1] = params[k1]
            else:
                logger.warning(
                    "The shape of model params {} {} does not match the loaded params {} {} !".
                    format(k1, state_dict[k1].shape, k1, params[k1].shape))
    model.set_state_dict(new_state_dict)
    if is_float16:
        logger.info(
            "The parameter type is float16, which is converted to float32 when loading"
        )
    logger.info("loaded pretrained params successfully from {}".format(path))
    return is_float16
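

# Sketch: warm-start a backbone before fine-tuning (the path is a
# hypothetical placeholder; the '.pdparams' suffix may be included or
# omitted, since it is stripped before loading):
#
#     is_float16 = load_pretrained_params(
#         model, './pretrain_models/MobileNetV3_large_x0_5')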


def save_model(model,
               optimizer,
               model_path,
               logger,
               config,
               is_best=False,
               prefix='ppocr',
               **kwargs):
    """
    Save the model to the target path.
    """
    _mkdir_if_not_exist(model_path, logger)
    model_prefix = os.path.join(model_path, prefix)
    paddle.save(optimizer.state_dict(), model_prefix + '.pdopt')
    is_nlp_model = config['Architecture']["model_type"] == 'kie' and config[
        "Architecture"]["algorithm"] not in ["SDMGR"]
    if not is_nlp_model:
        paddle.save(model.state_dict(), model_prefix + '.pdparams')
        metric_prefix = model_prefix
    else:  # for kie system, we follow the save/load rules in NLP
        if config['Global']['distributed']:
            arch = model._layers
        else:
            arch = model
        if config["Architecture"]["algorithm"] in ["Distillation"]:
            arch = arch.Student
        arch.backbone.model.save_pretrained(model_prefix)
        metric_prefix = os.path.join(model_prefix, 'metric')
    # save metric and config
    with open(metric_prefix + '.states', 'wb') as f:
        pickle.dump(kwargs, f, protocol=2)
    if is_best:
        logger.info('save best model to {}'.format(model_prefix))
    else:
        logger.info("save model in {}".format(model_prefix))
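

# Sketch of a typical call from a training loop (values are hypothetical).
# Extra keyword arguments such as ``best_model_dict`` and ``epoch`` are
# pickled into ``<prefix>.states`` and read back by ``load_model`` on resume:
#
#     save_model(model, optimizer, './output/det_db', logger, config,
#                is_best=True, prefix='best_accuracy',
#                best_model_dict=best_model_dict, epoch=epoch)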