export_model.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

import argparse

import paddle
from paddle.jit import to_static

from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
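
# This script converts a trained PaddleOCR dygraph checkpoint into a
# static-graph inference model (the inference.pdmodel / inference.pdiparams
# pair written by paddle.jit.save) that the inference tools consume.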


def export_single_model(model,
                        arch_config,
                        save_path,
                        logger,
                        input_shape=None,
                        quanter=None):
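    """Trace `model` to a static graph and save it as an inference model.

    The `paddle.static.InputSpec` list is chosen per algorithm, because each
    architecture expects a different input shape and some (e.g. SRN,
    RobustScanner, CAN) take extra non-image inputs. If `quanter` is given,
    the quantized model is saved instead of the plain traced one.
    """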
    if arch_config["algorithm"] == "SRN":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 64, 256], dtype="float32"),
            [
                paddle.static.InputSpec(
                    shape=[None, 256, 1], dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, max_text_length, 1], dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, 8, max_text_length, max_text_length],
                    dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, 8, max_text_length, max_text_length],
                    dtype="int64"),
            ],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "SAR":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 48, 160], dtype="float32"),
            [paddle.static.InputSpec(
                shape=[None], dtype="float32")],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "SVTR":
        if arch_config["Head"]["name"] == "MultiHead":
            other_shape = [
                paddle.static.InputSpec(
                    shape=[None, 3, 48, -1], dtype="float32"),
            ]
        else:
            other_shape = [
                paddle.static.InputSpec(
                    shape=[None] + input_shape, dtype="float32"),
            ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "PREN":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["model_type"] == "sr":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 16, 64], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ViTSTR":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 224, 224], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ABINet":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 32, 128], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "VisionLAN":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "RobustScanner":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 48, 160], dtype="float32"),
            [
                paddle.static.InputSpec(
                    shape=[None, ], dtype="float32"),
                paddle.static.InputSpec(
                    shape=[None, max_text_length], dtype="int64"),
            ],
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "CAN":
        other_shape = [[
            paddle.static.InputSpec(
                shape=[None, 1, None, None], dtype="float32"),
            paddle.static.InputSpec(
                shape=[None, 1, None, None], dtype="float32"),
            paddle.static.InputSpec(
                shape=[None, arch_config["Head"]["max_text_length"]],
                dtype="int64"),
        ]]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
        input_spec = [
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # input_ids
            paddle.static.InputSpec(
                shape=[None, 512, 4], dtype="int64"),  # bbox
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # attention_mask
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # token_type_ids
            paddle.static.InputSpec(
                shape=[None, 3, 224, 224], dtype="int64"),  # image
        ]
        if "Re" in arch_config["Backbone"]["name"]:
            input_spec.extend([
                paddle.static.InputSpec(
                    shape=[None, 512, 3], dtype="int64"),  # entities
                paddle.static.InputSpec(
                    shape=[None, None, 2], dtype="int64"),  # relations
            ])
        if model.backbone.use_visual_backbone is False:
            # drop the image spec when the backbone has no visual branch
            input_spec.pop(4)
        model = to_static(model, input_spec=[input_spec])
    else:
        infer_shape = [3, -1, -1]
        if arch_config["model_type"] == "rec":
            infer_shape = [3, 32, -1]  # for rec models, H must be 32
            if "Transform" in arch_config and arch_config[
                    "Transform"] is not None and arch_config["Transform"][
                        "name"] == "TPS":
                logger.info(
                    "When the network contains a TPS transform, "
                    "variable-length input is not supported; the input size "
                    "must match the size used during training.")
                infer_shape[-1] = 100
        elif arch_config["model_type"] == "table":
            infer_shape = [3, 488, 488]
            if arch_config["algorithm"] == "TableMaster":
                infer_shape = [3, 480, 480]
            if arch_config["algorithm"] == "SLANet":
                infer_shape = [3, -1, -1]
        model = to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + infer_shape, dtype="float32")
            ])
    if quanter is None:
        paddle.jit.save(model, save_path)
    else:
        quanter.save_quantized_model(model, save_path)
    logger.info("inference model is saved to {}".format(save_path))
    return
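

# A minimal sketch of the dygraph-to-static export pattern used above.
# `SomeRecModel` and the shapes are illustrative placeholders, not part of
# this file:
#
#     net = SomeRecModel()
#     net.eval()
#     net = to_static(net, input_spec=[
#         paddle.static.InputSpec(shape=[None, 3, 32, -1], dtype="float32")])
#     paddle.jit.save(net, "./inference/rec/inference")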


def main():
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    logger = get_logger()
    # build post process
    post_process_class = build_post_process(config["PostProcess"],
                                            config["Global"])
    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in [
                "Distillation",
        ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if config["Architecture"]["Models"][key]["Head"][
                        "name"] == "MultiHead":  # multi head
                    out_channels_list = {}
                    if config["PostProcess"][
                            "name"] == "DistillationSARLabelDecode":
                        char_num = char_num - 2
                    # the SAR head is two channels wider than the CTC head
                    # to account for its extra special tokens
                    out_channels_list["CTCLabelDecode"] = char_num
                    out_channels_list["SARLabelDecode"] = char_num + 2
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels_list"] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"] = char_num
                # just one final tensor needs to be exported for inference
                config["Architecture"]["Models"][key][
                    "return_all_feats"] = False
        elif config["Architecture"]["Head"][
                "name"] == "MultiHead":  # multi head
            out_channels_list = {}
            char_num = len(getattr(post_process_class, "character"))
            if config["PostProcess"]["name"] == "SARLabelDecode":
                char_num = char_num - 2
            out_channels_list["CTCLabelDecode"] = char_num
            out_channels_list["SARLabelDecode"] = char_num + 2
            config["Architecture"]["Head"][
                "out_channels_list"] = out_channels_list
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num
    # for sr algorithm
    if config["Architecture"]["model_type"] == "sr":
        config["Architecture"]["Transform"]["infer_mode"] = True

    model = build_model(config["Architecture"])
    load_model(config, model, model_type=config["Architecture"]["model_type"])
    model.eval()

    save_path = config["Global"]["save_inference_dir"]
    arch_config = config["Architecture"]
    if arch_config["algorithm"] == "SVTR" and arch_config["Head"][
            "name"] != "MultiHead":
        input_shape = config["Eval"]["dataset"]["transforms"][-2][
            "SVTRRecResizeImg"]["image_shape"]
    else:
        input_shape = None
    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
        archs = list(arch_config["Models"].values())
        for idx, name in enumerate(model.model_name_list):
            sub_model_save_path = os.path.join(save_path, name, "inference")
            export_single_model(model.model_list[idx], archs[idx],
                                sub_model_save_path, logger)
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(
            model, arch_config, save_path, logger, input_shape=input_shape)


if __name__ == "__main__":
    main()
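
# Typical invocation (config and paths are illustrative, not fixed values):
#
#     python tools/export_model.py \
#         -c configs/rec/rec_mv3_none_bilstm_ctc.yml \
#         -o Global.pretrained_model=./output/rec/best_accuracy \
#            Global.save_inference_dir=./inference/rec_crnn/
#
# The exported files land in Global.save_inference_dir (one subdirectory
# per sub-model for distillation configs).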