# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert FUNSD annotations into the PaddleOCR KIE label format.

For each split, one line is written per image: the image file name, a tab,
and a JSON list of text regions with "transcription", "label", "points",
"linking" and "id" fields.
"""
import json
import os
from copy import deepcopy

import numpy as np


def trans_poly_to_bbox(poly):
    """Return the axis-aligned bounding box [x1, y1, x2, y2] of a polygon."""
    x1 = np.min([p[0] for p in poly])
    x2 = np.max([p[0] for p in poly])
    y1 = np.min([p[1] for p in poly])
    y2 = np.max([p[1] for p in poly])
    return [x1, y1, x2, y2]


def get_outer_poly(bbox_list):
    """Return the quadrilateral enclosing every [x1, y1, x2, y2] box in bbox_list."""
    x1 = min([bbox[0] for bbox in bbox_list])
    y1 = min([bbox[1] for bbox in bbox_list])
    x2 = max([bbox[2] for bbox in bbox_list])
    y2 = max([bbox[3] for bbox in bbox_list])
    return [[x1, y1], [x2, y1], [x2, y2], [x1, y2]]


def load_funsd_label(image_dir, anno_dir):
    """Parse every annotation file in anno_dir and return a dict mapping the
    file name (without extension) to its list of merged text regions."""
    imgs = os.listdir(image_dir)
    annos = os.listdir(anno_dir)

    imgs = [img.replace(".png", "") for img in imgs]
    annos = [anno.replace(".json", "") for anno in annos]

    fn_info_map = dict()
    for anno_fn in annos:
        res = []
        with open(os.path.join(anno_dir, anno_fn + ".json"), "r") as fin:
            infos = json.load(fin)
            infos = infos["form"]
        old_id2new_id_map = dict()
        global_new_id = 0
        for info in infos:
            if info["text"] is None:
                continue
            words = info["words"]
            if len(words) <= 0:
                continue
            word_idx = 1
            curr_bboxes = [words[0]["box"]]
            curr_texts = [words[0]["text"]]
            while word_idx < len(words):
                # switch to a new link when the current word starts to the left
                # of the previous word's right edge (a new text line begins)
                if words[word_idx]["box"][0] + 10 <= words[word_idx - 1]["box"][2]:
                    if len("".join(curr_texts[0])) > 0:
                        res.append({
                            "transcription": " ".join(curr_texts),
                            "label": info["label"],
                            "points": get_outer_poly(curr_bboxes),
                            "linking": info["linking"],
                            "id": global_new_id,
                        })
                        if info["id"] not in old_id2new_id_map:
                            old_id2new_id_map[info["id"]] = []
                        old_id2new_id_map[info["id"]].append(global_new_id)
                        global_new_id += 1
                    curr_bboxes = [words[word_idx]["box"]]
                    curr_texts = [words[word_idx]["text"]]
                else:
                    curr_bboxes.append(words[word_idx]["box"])
                    curr_texts.append(words[word_idx]["text"])
                word_idx += 1
            if len("".join(curr_texts[0])) > 0:
                res.append({
                    "transcription": " ".join(curr_texts),
                    "label": info["label"],
                    "points": get_outer_poly(curr_bboxes),
                    "linking": info["linking"],
                    "id": global_new_id,
                })
                if info["id"] not in old_id2new_id_map:
                    old_id2new_id_map[info["id"]] = []
                old_id2new_id_map[info["id"]].append(global_new_id)
                global_new_id += 1

        # sort regions top-to-bottom, then keep regions whose top edges lie
        # within 20 px of each other ordered left-to-right
        res = sorted(
            res, key=lambda r: (r["points"][0][1], r["points"][0][0]))
        for i in range(len(res) - 1):
            for j in range(i, 0, -1):
                if abs(res[j + 1]["points"][0][1] - res[j]["points"][0][1]) < 20 and \
                        (res[j + 1]["points"][0][0] < res[j]["points"][0][0]):
                    res[j], res[j + 1] = res[j + 1], res[j]
                else:
                    break

        # re-generate unique ids
        for idx, r in enumerate(res):
            new_links = []
            for link in r["linking"]:
                # illegal links will be removed
                if (link[0] not in old_id2new_id_map or
                        link[1] not in old_id2new_id_map):
                    continue
                for src in old_id2new_id_map[link[0]]:
                    for dst in old_id2new_id_map[link[1]]:
                        new_links.append([src, dst])
            res[idx]["linking"] = deepcopy(new_links)

        fn_info_map[anno_fn] = res
    return fn_info_map


def main():
    test_image_dir = "train_data/FUNSD/testing_data/images/"
    test_anno_dir = "train_data/FUNSD/testing_data/annotations/"
    test_output_dir = "train_data/FUNSD/test.json"
    fn_info_map = load_funsd_label(test_image_dir, test_anno_dir)
    with open(test_output_dir, "w") as fout:
        for fn in fn_info_map:
            fout.write(fn + ".png" + "\t" + json.dumps(
                fn_info_map[fn], ensure_ascii=False) + "\n")

    train_image_dir = "train_data/FUNSD/training_data/images/"
    train_anno_dir = "train_data/FUNSD/training_data/annotations/"
    train_output_dir = "train_data/FUNSD/train.json"
    fn_info_map = load_funsd_label(train_image_dir, train_anno_dir)
    with open(train_output_dir, "w") as fout:
        for fn in fn_info_map:
            fout.write(fn + ".png" + "\t" + json.dumps(
                fn_info_map[fn], ensure_ascii=False) + "\n")
    print("====ok====")


if __name__ == "__main__":
    main()