# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.insert(0, os.path.abspath(os.path.join(__dir__, "..")))

import argparse

import paddle
from paddle.jit import to_static

from ppocr.modeling.architectures import build_model
from ppocr.postprocess import build_post_process
from ppocr.utils.save_load import load_model
from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser


def export_single_model(model,
                        arch_config,
                        save_path,
                        logger,
                        input_shape=None,
                        quanter=None):
    if arch_config["algorithm"] == "SRN":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 64, 256], dtype="float32"), [
                    paddle.static.InputSpec(
                        shape=[None, 256, 1], dtype="int64"),
                    paddle.static.InputSpec(
                        shape=[None, max_text_length, 1], dtype="int64"),
                    paddle.static.InputSpec(
                        shape=[None, 8, max_text_length, max_text_length],
                        dtype="int64"),
                    paddle.static.InputSpec(
                        shape=[None, 8, max_text_length, max_text_length],
                        dtype="int64")
                ]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "SAR":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 48, 160], dtype="float32"),
            [paddle.static.InputSpec(
                shape=[None], dtype="float32")]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "SVTR":
        if arch_config["Head"]["name"] == 'MultiHead':
            other_shape = [
                paddle.static.InputSpec(
                    shape=[None, 3, 48, -1], dtype="float32"),
            ]
        else:
            other_shape = [
                paddle.static.InputSpec(
                    shape=[None] + input_shape, dtype="float32"),
            ]
        model = to_static(model, input_spec=other_shape)
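    # Each remaining branch traces the model with the InputSpec layout its
    # architecture expects: a fixed image size for most recognizers, plus
    # extra sequence inputs for RobustScanner/CAN/LayoutLM-family models.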
    elif arch_config["algorithm"] == "PREN":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["model_type"] == "sr":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 16, 64], dtype="float32")
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ViTSTR":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 224, 224], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "ABINet":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 32, 128], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["NRTR", "SPIN", 'RFL']:
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 1, 32, 100], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "VisionLAN":
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 64, 256], dtype="float32"),
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "RobustScanner":
        max_text_length = arch_config["Head"]["max_text_length"]
        other_shape = [
            paddle.static.InputSpec(
                shape=[None, 3, 48, 160], dtype="float32"), [
                    paddle.static.InputSpec(
                        shape=[None, ], dtype="float32"),
                    paddle.static.InputSpec(
                        shape=[None, max_text_length], dtype="int64")
                ]
        ]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] == "CAN":
        other_shape = [[
            paddle.static.InputSpec(
                shape=[None, 1, None, None], dtype="float32"),
            paddle.static.InputSpec(
                shape=[None, 1, None, None], dtype="float32"),
            paddle.static.InputSpec(
                shape=[None, arch_config['Head']['max_text_length']],
                dtype="int64")
        ]]
        model = to_static(model, input_spec=other_shape)
    elif arch_config["algorithm"] in ["LayoutLM", "LayoutLMv2", "LayoutXLM"]:
        input_spec = [
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # input_ids
            paddle.static.InputSpec(
                shape=[None, 512, 4], dtype="int64"),  # bbox
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # attention_mask
            paddle.static.InputSpec(
                shape=[None, 512], dtype="int64"),  # token_type_ids
            paddle.static.InputSpec(
                shape=[None, 3, 224, 224], dtype="int64"),  # image
        ]
        if 'Re' in arch_config['Backbone']['name']:
            input_spec.extend([
                paddle.static.InputSpec(
                    shape=[None, 512, 3], dtype="int64"),  # entities
                paddle.static.InputSpec(
                    shape=[None, None, 2], dtype="int64"),  # relations
            ])
        if model.backbone.use_visual_backbone is False:
            input_spec.pop(4)
        model = to_static(model, input_spec=[input_spec])
    else:
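        # Default export path (detection, classification, and table models):
        # a single NCHW image input; -1 keeps a dimension dynamic, and fixed
        # sizes are substituted per model type below.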
        infer_shape = [3, -1, -1]
        if arch_config["model_type"] == "rec":
            infer_shape = [3, 32, -1]  # for rec model, H must be 32
            if "Transform" in arch_config and arch_config[
                    "Transform"] is not None and arch_config["Transform"][
                        "name"] == "TPS":
                logger.info(
                    "When TPS is used in the network, variable-length input "
                    "is not supported; the input size must match the size "
                    "used during training.")
                infer_shape[-1] = 100
        elif arch_config["model_type"] == "table":
            infer_shape = [3, 488, 488]
            if arch_config["algorithm"] == "TableMaster":
                infer_shape = [3, 480, 480]
            if arch_config["algorithm"] == "SLANet":
                infer_shape = [3, -1, -1]
        model = to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(
                    shape=[None] + infer_shape, dtype="float32")
            ])

    if quanter is None:
        paddle.jit.save(model, save_path)
    else:
        quanter.save_quantized_model(model, save_path)
    logger.info("inference model is saved to {}".format(save_path))
    return


def main():
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    logger = get_logger()

    # build post process
    post_process_class = build_post_process(config["PostProcess"],
                                            config["Global"])

    # build model
    # for rec algorithm
    if hasattr(post_process_class, "character"):
        char_num = len(getattr(post_process_class, "character"))
        if config["Architecture"]["algorithm"] in ["Distillation",
                                                   ]:  # distillation model
            for key in config["Architecture"]["Models"]:
                if config["Architecture"]["Models"][key]["Head"][
                        "name"] == 'MultiHead':  # multi head
                    out_channels_list = {}
                    if config['PostProcess'][
                            'name'] == 'DistillationSARLabelDecode':
                        char_num = char_num - 2
                    out_channels_list['CTCLabelDecode'] = char_num
                    out_channels_list['SARLabelDecode'] = char_num + 2
                    config['Architecture']['Models'][key]['Head'][
                        'out_channels_list'] = out_channels_list
                else:
                    config["Architecture"]["Models"][key]["Head"][
                        "out_channels"] = char_num
                # only one final tensor needs to be exported for inference
                config["Architecture"]["Models"][key][
                    "return_all_feats"] = False
        elif config['Architecture']['Head'][
                'name'] == 'MultiHead':  # multi head
            out_channels_list = {}
            char_num = len(getattr(post_process_class, 'character'))
            if config['PostProcess']['name'] == 'SARLabelDecode':
                char_num = char_num - 2
            out_channels_list['CTCLabelDecode'] = char_num
            out_channels_list['SARLabelDecode'] = char_num + 2
            config['Architecture']['Head'][
                'out_channels_list'] = out_channels_list
        else:  # base rec model
            config["Architecture"]["Head"]["out_channels"] = char_num

    # for sr algorithm
    if config["Architecture"]["model_type"] == "sr":
        config['Architecture']["Transform"]['infer_mode'] = True
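    # Build the network from the (possibly patched) config, load the trained
    # weights, and switch to eval mode before tracing.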
    model = build_model(config["Architecture"])
    load_model(config, model, model_type=config['Architecture']["model_type"])
    model.eval()

    save_path = config["Global"]["save_inference_dir"]

    arch_config = config["Architecture"]

    if arch_config["algorithm"] == "SVTR" and arch_config["Head"][
            "name"] != 'MultiHead':
        input_shape = config["Eval"]["dataset"]["transforms"][-2][
            'SVTRRecResizeImg']['image_shape']
    else:
        input_shape = None

    if arch_config["algorithm"] in ["Distillation", ]:  # distillation model
        archs = list(arch_config["Models"].values())
        for idx, name in enumerate(model.model_name_list):
            sub_model_save_path = os.path.join(save_path, name, "inference")
            export_single_model(model.model_list[idx], archs[idx],
                                sub_model_save_path, logger)
    else:
        save_path = os.path.join(save_path, "inference")
        export_single_model(
            model, arch_config, save_path, logger, input_shape=input_shape)


if __name__ == "__main__":
    main()
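# Example invocation (the config and checkpoint paths below are illustrative;
# substitute your own):
#
#   python3 tools/export_model.py \
#       -c configs/rec/rec_mv3_none_bilstm_ctc.yml \
#       -o Global.pretrained_model=./output/rec/best_accuracy \
#          Global.save_inference_dir=./inference/rec/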