# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
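
# Export a trained PaddleOCR model to a static-graph inference model.
# Illustrative invocation (config and paths are examples only; substitute
# your own):
#   python3 tools/export_model.py \
#       -c configs/rec/rec_mv3_none_bilstm_ctc.yml \
#       -o Global.pretrained_model=./output/rec/best_accuracy \
#          Global.save_inference_dir=./inference/rec/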

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

import paddle
from paddle.jit import to_static

from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser


def export_single_model(model,
                        arch_config,
                        save_path,
                        logger,
                        input_shape=None,
                        quanter=None):
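    """Convert a dygraph ``model`` to static graph and save it under ``save_path``.

    Each branch below pins ``paddle.static.InputSpec`` shapes matching the
    input layout the named algorithm expects; the final ``else`` handles the
    generic det/rec/table cases. When ``quanter`` is given, the quantized
    model is saved through the quanter instead of ``paddle.jit.save``.
    """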
    if arch_config["algorithm"] == "SRN":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 64, 256], dtype="float32"),
            [
                paddle.static.InputSpec(
                    shape=[None, 256, 1], dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, max_text_length, 1], dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, 8, max_text_length, max_text_length],
                    dtype="int64"),
                paddle.static.InputSpec(
                    shape=[None, 8, max_text_length, max_text_length],
                    dtype="int64")
            ]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "SAR":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 48, 160], dtype="float32"),
            [paddle.static.InputSpec(
                shape=[None], dtype="float32")]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "SVTR":
        if arch_config["Head"]["name"] == "MultiHead":
            other_shape = [
                paddle.static.InputSpec(
                    shape=[None, 3, 48, -1], dtype="float32"),
            ]
        else:
            other_shape = [
                paddle.static.InputSpec(
                    shape=[None] + input_shape, dtype="float32"),
            ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "PREN":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["model_type"] == "sr":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 16, 64], dtype="float32")
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ViTSTR":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 224, 224], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ABINet":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 32, 128], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["NRTR", "SPIN", "RFL"]:
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "VisionLAN":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "RobustScanner":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 48, 160], dtype="float32"),
            [
                paddle.static.InputSpec(
                    shape=[None], dtype="float32"),
                paddle.static.InputSpec(
                    shape=[None, max_text_length], dtype="int64")
            ]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "CAN":
        other_shape = [[
            paddle.static.InputSpec(
                shape=[None, 1, None, None], dtype="float32"),
            paddle.static.InputSpec(
                shape=[None, 1, None, None], dtype="float32"),
            paddle.static.InputSpec(
                shape=[None, arch_config["Head"]["max_text_length"]],
                dtype="int64")
        ]]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
        input_spec = [
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # input_ids
            paddle.static.InputSpec(
                shape=[None, 512, 4], dtype="int64"),  # bbox
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # attention_mask
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # token_type_ids
            paddle.static.InputSpec(
                shape=[None, 3, 224, 224], dtype="int64"),  # image
        ]
        if "Re" in arch_config["Backbone"]["name"]:
            input_spec.extend([
                paddle.static.InputSpec(
                    shape=[None, 512, 3], dtype="int64"),  # entities
                paddle.static.InputSpec(
                    shape=[None, None, 2], dtype="int64"),  # relations
            ])
        if model.backbone.use_visual_backbone is False:
            input_spec.pop(4)  # drop the image input when no visual backbone
        model = to_static(model, input_spec=[input_spec])
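    # Fallback for detection / recognition / table models: a single image
    # input, where -1 marks dimensions kept dynamic at inference time.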
    else:
        infer_shape = [3, -1, -1]
        if arch_config["model_type"] == "rec":
            infer_shape = [3, 32, -1]  # for rec models, H must be 32
            if ("Transform" in arch_config and
                    arch_config["Transform"] is not None and
                    arch_config["Transform"]["name"] == "TPS"):
                logger.info(
                    "When TPS is used in the network, variable-length input is "
                    "not supported; the input size must match the training size."
                )
                infer_shape[-1] = 100
        elif arch_config["model_type"] == "table":
            infer_shape = [3, 488, 488]
            if arch_config["algorithm"] == "TableMaster":
                infer_shape = [3, 480, 480]
            if arch_config["algorithm"] == "SLANet":
                infer_shape = [3, -1, -1]
        model = to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + infer_shape, dtype="float32")
            ])
    if quanter is None:
        paddle.jit.save(model, save_path)
    else:
        quanter.save_quantized_model(model, save_path)
    logger.info("inference model is saved to {}".format(save_path))
    return


def main():
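    """Build the model from the config, load weights, and export it.

    For distillation configs every sub-model is exported separately; for rec
    models the head's ``out_channels`` is patched from the post-process
    character dictionary before the model is built.
    """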
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    logger = get_logger()

    # build post process
    post_process_class = build_post_process(config["PostProcess"],
                                            config["Global"])

    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] == "Distillation":
            # distillation model
            for key in config["Architecture"]["Models"]:
                if config["Architecture"]["Models"][key]["Head"][
                        "name"] == "MultiHead":  # multi head
                    out_channels_list = {}
                    if config["PostProcess"][
                            "name"] == "DistillationSARLabelDecode":
                        char_num = char_num - 2
                    # the SAR head uses two extra symbols relative to CTC
                    out_channels_list["CTCLabelDecode"] = char_num
                    out_channels_list["SARLabelDecode"] = char_num + 2
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels_list"] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"] = char_num
                # just one final tensor needs to be exported for inference
                config["Architecture"]["Models"][key][
                    "return_all_feats"] = False
        elif config["Architecture"]["Head"]["name"] == "MultiHead":
            # multi head
            out_channels_list = {}
            char_num = len(getattr(post_process_class, "character"))
            if config["PostProcess"]["name"] == "SARLabelDecode":
                char_num = char_num - 2
            out_channels_list["CTCLabelDecode"] = char_num
            out_channels_list["SARLabelDecode"] = char_num + 2
            config["Architecture"]["Head"][
                "out_channels_list"] = out_channels_list
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num

    # for sr algorithm
    if config["Architecture"]["model_type"] == "sr":
        config["Architecture"]["Transform"]["infer_mode"] = True
    model = build_model(config["Architecture"])
    load_model(config, model, model_type=config["Architecture"]["model_type"])
    model.eval()
    save_path = config["Global"]["save_inference_dir"]
    arch_config = config["Architecture"]

    if (arch_config["algorithm"] == "SVTR" and
            arch_config["Head"]["name"] != "MultiHead"):
        input_shape = config["Eval"]["dataset"]["transforms"][-2][
            "SVTRRecResizeImg"]["image_shape"]
    else:
        input_shape = None

    if arch_config["algorithm"] == "Distillation":  # distillation model
        archs = list(arch_config["Models"].values())
        for idx, name in enumerate(model.model_name_list):
            sub_model_save_path = os.path.join(save_path, name, "inference")
            export_single_model(model.model_list[idx], archs[idx],
                                sub_model_save_path, logger)
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(
            model, arch_config, save_path, logger, input_shape=input_shape)


if __name__ == "__main__":
    main()