123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import numpy as np
- from .locality_aware_nms import nms_locality
- import cv2
- import paddle
- import os
- import sys
class EASTPostProcess(object):
    """
    The post process for EAST.

    Turns the network outputs (a score map and a geometry map per image)
    into quadrilateral text boxes: thresholds the score map, restores one
    quad per remaining pixel, merges the quads with locality-aware NMS,
    re-scores them with the mean score under each quad, and finally rescales
    the survivors with the per-image resize ratios.
    """

    def __init__(self,
                 score_thresh=0.8,
                 cover_thresh=0.1,
                 nms_thresh=0.2,
                 **kwargs):
        """
        Store the thresholds used by ``__call__``.

        Args:
            score_thresh: minimum score-map value for a pixel to produce a
                quad proposal.
            cover_thresh: minimum mean score under a merged quad for the
                quad to be kept.
            nms_thresh: threshold forwarded to the NMS routine.
            **kwargs: accepted and ignored.
        """
        self.score_thresh = score_thresh
        self.cover_thresh = cover_thresh
        self.nms_thresh = nms_thresh

    def restore_rectangle_quad(self, origin, geometry):
        """
        Restore quadrangles from per-pixel corner offsets.

        Args:
            origin: array of shape (n, 2) holding one anchor point per
                proposal.
            geometry: array of shape (n, 8) holding four (dx, dy) offsets
                per proposal.

        Returns:
            Array of shape (n, 4, 2): each corner is ``origin - offset``.
        """
        # Repeat the anchor four times so it lines up with the 4 corners.
        origin_concat = np.concatenate(
            (origin, origin, origin, origin), axis=1)  # (n, 8)
        pred_quads = origin_concat - geometry
        return pred_quads.reshape((-1, 4, 2))  # (n, 4, 2)

    def detect(self,
               score_map,
               geo_map,
               score_thresh=0.8,
               cover_thresh=0.1,
               nms_thresh=0.2):
        """
        Restore text boxes from score map and geo map.

        Args:
            score_map: array whose first entry along axis 0 is the 2-D
                score map (presumably shape (1, H, W) -- confirm with the
                caller).
            geo_map: array whose axes are reordered from (C, H, W) to
                (H, W, C); C is expected to be 8 so it matches
                ``restore_rectangle_quad``.
            score_thresh: minimum score for a pixel to yield a proposal.
            cover_thresh: minimum mean score under a merged box.
            nms_thresh: threshold forwarded to the NMS routine.

        Returns:
            An empty list when nothing is detected, otherwise an array of
            shape (m, 9): 8 quad coordinates followed by the box score.
        """
        score_map = score_map[0]
        # (C, H, W) -> (H, C, W) -> (H, W, C)
        geo_map = np.swapaxes(geo_map, 1, 0)
        geo_map = np.swapaxes(geo_map, 1, 2)

        # filter the score map
        xy_text = np.argwhere(score_map > score_thresh)
        if len(xy_text) == 0:
            return []
        # sort the text boxes via the y axis
        xy_text = xy_text[np.argsort(xy_text[:, 0])]
        rows, cols = xy_text[:, 0], xy_text[:, 1]

        # restore quad proposals; (row, col) is flipped to (x, y) and scaled
        # by 4 (presumably the maps are at 1/4 of the input resolution).
        text_box_restored = self.restore_rectangle_quad(
            xy_text[:, ::-1] * 4, geo_map[rows, cols, :])
        boxes = np.zeros((text_box_restored.shape[0], 9), dtype=np.float32)
        boxes[:, :8] = text_box_restored.reshape((-1, 8))
        boxes[:, 8] = score_map[rows, cols]

        # Prefer the optional compiled lanms implementation; on any failure
        # fall back to the Python one. ``Exception`` (not a bare except) so
        # KeyboardInterrupt / SystemExit still propagate.
        try:
            import lanms
            boxes = lanms.merge_quadrangle_n9(boxes, nms_thresh)
        except Exception:
            print(
                'you should install lanms by pip3 install lanms-nova to speed up nms_locality'
            )
            boxes = nms_locality(boxes.astype(np.float64), nms_thresh)
        if boxes.shape[0] == 0:
            return []

        # Here we filter some low score boxes by the average score map,
        # this is different from the original paper.
        for i, box in enumerate(boxes):
            mask = np.zeros_like(score_map, dtype=np.uint8)
            # Quad coordinates are divided by 4 to index the score map.
            cv2.fillPoly(mask, box[:8].reshape(
                (-1, 4, 2)).astype(np.int32) // 4, 1)
            boxes[i, 8] = cv2.mean(score_map, mask)[0]
        return boxes[boxes[:, 8] > cover_thresh]

    def sort_poly(self, p):
        """
        Sort the four points of a polygon.

        The point with the smallest ``x + y`` is rotated to index 0 while
        the cyclic order is kept. If the first edge is not wider than it is
        tall, the traversal direction is reversed.

        Args:
            p: array of shape (4, 2).

        Returns:
            Array of shape (4, 2) with the points reordered.
        """
        min_axis = np.argmin(np.sum(p, axis=1))
        p = p[[(min_axis + offset) % 4 for offset in range(4)]]
        if abs(p[0, 0] - p[1, 0]) > abs(p[0, 1] - p[1, 1]):
            return p
        return p[[0, 3, 2, 1]]

    def __call__(self, outs_dict, shape_list):
        """
        Run the post process on a batch.

        Args:
            outs_dict: mapping with keys ``'f_score'`` and ``'f_geo'``;
                values are ``paddle.Tensor`` or array-like, indexed per
                image.
            shape_list: one 4-item entry per image; the last two items are
                used as ``ratio_h`` and ``ratio_w``.

        Returns:
            A list with one ``{'points': ndarray}`` dict per image.
        """
        score_list = outs_dict['f_score']
        geo_list = outs_dict['f_geo']
        # Convert each output independently so mixed input types also work.
        if isinstance(score_list, paddle.Tensor):
            score_list = score_list.numpy()
        if isinstance(geo_list, paddle.Tensor):
            geo_list = geo_list.numpy()

        dt_boxes_list = []
        for ino in range(len(shape_list)):
            boxes = self.detect(
                score_map=score_list[ino],
                geo_map=geo_list[ino],
                score_thresh=self.score_thresh,
                cover_thresh=self.cover_thresh,
                nms_thresh=self.nms_thresh)
            boxes_norm = []
            if len(boxes) > 0:
                _, _, ratio_h, ratio_w = shape_list[ino]
                boxes = boxes[:, :8].reshape((-1, 4, 2))
                # Divide by the resize ratios to rescale the coordinates.
                boxes[:, :, 0] /= ratio_w
                boxes[:, :, 1] /= ratio_h
                for box in boxes:
                    box = self.sort_poly(box.astype(np.int32))
                    # Drop boxes whose first or last edge is under 5 px.
                    if np.linalg.norm(box[0] - box[1]) < 5 \
                            or np.linalg.norm(box[3] - box[0]) < 5:
                        continue
                    boxes_norm.append(box)
            dt_boxes_list.append({'points': np.array(boxes_norm)})
        return dt_boxes_list
|