123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import sys
# Put this module's directory and its parent directory on sys.path
# (presumably so the package-level `ppocr` import resolves when this file
# is loaded outside an installed package -- confirm against the entry points).
# abspath matters: a relative __file__ would make dirname() return a relative
# path (or ''), and those sys.path entries would silently change meaning if the
# process later changes its working directory.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(__dir__)
sys.path.append(os.path.join(__dir__, '..'))
- from ppocr.utils.e2e_utils.pgnet_pp_utils import PGNet_PostProcess
class PGPostProcess(object):
    """Post-processing step for PGNet network outputs.

    The decoding configuration is captured at construction time. Each call
    builds a ``PGNet_PostProcess`` helper and runs either its fast or its
    slow decoding routine, depending on ``mode``.

    Args:
        character_dict_path: Path to the character dictionary; forwarded
            to ``PGNet_PostProcess``.
        valid_set: Forwarded unchanged to ``PGNet_PostProcess``
            (presumably the target dataset name -- confirm).
        score_thresh: Score threshold forwarded to ``PGNet_PostProcess``.
        mode: ``'fast'`` selects ``pg_postprocess_fast``; every other value
            selects ``pg_postprocess_slow``.
        point_gather_mode: Optional value forwarded as the
            ``point_gather_mode`` keyword of ``PGNet_PostProcess``.
        **kwargs: Accepted and ignored (presumably so extra config keys do
            not cause a TypeError).
    """

    def __init__(self,
                 character_dict_path,
                 valid_set,
                 score_thresh,
                 mode,
                 point_gather_mode=None,
                 **kwargs):
        self.character_dict_path = character_dict_path
        self.valid_set = valid_set
        self.score_thresh = score_thresh
        self.mode = mode
        self.point_gather_mode = point_gather_mode
        # True only on CPython 3.5. The original note says the C++ la-nms
        # path is faster but only supported on Python 3.5; this flag is not
        # read anywhere inside this class.
        self.is_python35 = sys.version_info[:2] == (3, 5)

    def __call__(self, outs_dict, shape_list):
        """Decode raw network outputs into final results.

        Args:
            outs_dict: Raw network outputs, forwarded to ``PGNet_PostProcess``.
            shape_list: Forwarded to ``PGNet_PostProcess`` (presumably
                per-image shape/scale info -- confirm against the caller).

        Returns:
            Whatever the selected ``pg_postprocess_fast`` or
            ``pg_postprocess_slow`` method returns.
        """
        decoder = PGNet_PostProcess(
            self.character_dict_path,
            self.valid_set,
            self.score_thresh,
            outs_dict,
            shape_list,
            point_gather_mode=self.point_gather_mode)
        use_fast = self.mode == 'fast'
        run = (decoder.pg_postprocess_fast if use_fast
               else decoder.pg_postprocess_slow)
        return run()
|