# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import sys
import shapely
from shapely.geometry import Polygon
import numpy as np
from collections import defaultdict
import operator
import editdistance


def strQ2B(ustring):
    rstring = ""
    for uchar in ustring:
        inside_code = ord(uchar)
        if inside_code == 12288:
            inside_code = 32
        elif (inside_code >= 65281 and inside_code <= 65374):
            inside_code -= 65248
        rstring += chr(inside_code)
    return rstring


def polygon_from_str(polygon_points):
    """
    Create a shapely polygon object from gt or dt line.
    """
    polygon_points = np.array(polygon_points).reshape(4, 2)
    polygon = Polygon(polygon_points).convex_hull
    return polygon


def polygon_iou(poly1, poly2):
    """
    Intersection over union between two shapely polygons.
    """
    if not poly1.intersects(
            poly2):  # this test is fast and can accelerate calculation
        iou = 0
    else:
        try:
            inter_area = poly1.intersection(poly2).area
            union_area = poly1.area + poly2.area - inter_area
            iou = float(inter_area) / union_area
        except shapely.geos.TopologicalError:
            # except Exception as e:
            #     print(e)
            print('shapely.geos.TopologicalError occurred, iou set to 0')
            iou = 0
    return iou


def ed(str1, str2):
    return editdistance.eval(str1, str2)


def e2e_eval(gt_dir, res_dir, ignore_blank=False):
    print('start testing...')
    iou_thresh = 0.5
    val_names = os.listdir(gt_dir)
    num_gt_chars = 0
    gt_count = 0
    dt_count = 0
    hit = 0
    ed_sum = 0

    for i, val_name in enumerate(val_names):
        with open(os.path.join(gt_dir, val_name), encoding='utf-8') as f:
            gt_lines = [o.strip() for o in f.readlines()]
        gts = []
        ignore_masks = []
        for line in gt_lines:
            parts = line.strip().split('\t')
            # ignore illegal data
            if len(parts) < 9:
                continue
            assert (len(parts) < 11)
            if len(parts) == 9:
                gts.append(parts[:8] + [''])
            else:
                gts.append(parts[:8] + [parts[-1]])

            ignore_masks.append(parts[8])

        val_path = os.path.join(res_dir, val_name)
        if not os.path.exists(val_path):
            dt_lines = []
        else:
            with open(val_path, encoding='utf-8') as f:
                dt_lines = [o.strip() for o in f.readlines()]
        dts = []
        for line in dt_lines:
            # print(line)
            parts = line.strip().split("\t")
            assert (len(parts) < 10), "line error: {}".format(line)
            if len(parts) == 8:
                dts.append(parts + [''])
            else:
                dts.append(parts)

        dt_match = [False] * len(dts)
        gt_match = [False] * len(gts)
        all_ious = defaultdict(tuple)
        for index_gt, gt in enumerate(gts):
            gt_coors = [float(gt_coor) for gt_coor in gt[0:8]]
            gt_poly = polygon_from_str(gt_coors)
            for index_dt, dt in enumerate(dts):
                dt_coors = [float(dt_coor) for dt_coor in dt[0:8]]
                dt_poly = polygon_from_str(dt_coors)
                iou = polygon_iou(dt_poly, gt_poly)
                if iou >= iou_thresh:
                    all_ious[(index_gt, index_dt)] = iou
        sorted_ious = sorted(
            all_ious.items(), key=operator.itemgetter(1), reverse=True)
        sorted_gt_dt_pairs = [item[0] for item in sorted_ious]

        # matched gt and dt
        for gt_dt_pair in sorted_gt_dt_pairs:
            index_gt, index_dt = gt_dt_pair
            if gt_match[index_gt] == False and dt_match[index_dt] == False:
                gt_match[index_gt] = True
                dt_match[index_dt] = True
                if ignore_blank:
                    gt_str = strQ2B(gts[index_gt][8]).replace(" ", "")
                    dt_str = strQ2B(dts[index_dt][8]).replace(" ", "")
                else:
                    gt_str = strQ2B(gts[index_gt][8])
                    dt_str = strQ2B(dts[index_dt][8])
                if ignore_masks[index_gt] == '0':
                    ed_sum += ed(gt_str, dt_str)
                    num_gt_chars += len(gt_str)
                    if gt_str == dt_str:
                        hit += 1
                    gt_count += 1
                    dt_count += 1

        # unmatched dt
        for tindex, dt_match_flag in enumerate(dt_match):
            if dt_match_flag == False:
                dt_str = dts[tindex][8]
                gt_str = ''
                ed_sum += ed(dt_str, gt_str)
                dt_count += 1

        # unmatched gt
        for tindex, gt_match_flag in enumerate(gt_match):
            if gt_match_flag == False and ignore_masks[tindex] == '0':
                dt_str = ''
                gt_str = gts[tindex][8]
                ed_sum += ed(gt_str, dt_str)
                num_gt_chars += len(gt_str)
                gt_count += 1

    eps = 1e-9
    print('hit, dt_count, gt_count', hit, dt_count, gt_count)
    precision = hit / (dt_count + eps)
    recall = hit / (gt_count + eps)
    fmeasure = 2.0 * precision * recall / (precision + recall + eps)
    avg_edit_dist_img = ed_sum / len(val_names)
    avg_edit_dist_field = ed_sum / (gt_count + eps)
    character_acc = 1 - ed_sum / (num_gt_chars + eps)

    print('character_acc: %.2f' % (character_acc * 100) + "%")
    print('avg_edit_dist_field: %.2f' % (avg_edit_dist_field))
    print('avg_edit_dist_img: %.2f' % (avg_edit_dist_img))
    print('precision: %.2f' % (precision * 100) + "%")
    print('recall: %.2f' % (recall * 100) + "%")
    print('fmeasure: %.2f' % (fmeasure * 100) + "%")


if __name__ == '__main__':
    # if len(sys.argv) != 3:
    #     print("python3 ocr_e2e_eval.py gt_dir res_dir")
    #     exit(-1)
    # gt_folder = sys.argv[1]
    # pred_folder = sys.argv[2]
    gt_folder = sys.argv[1]
    pred_folder = sys.argv[2]
    e2e_eval(gt_folder, pred_folder)