convert_ppocr_label.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import json
  16. import os
  17. def poly_to_string(poly):
  18. if len(poly.shape) > 1:
  19. poly = np.array(poly).flatten()
  20. string = "\t".join(str(i) for i in poly)
  21. return string
  22. def convert_label(label_dir, mode="gt", save_dir="./save_results/"):
  23. if not os.path.exists(label_dir):
  24. raise ValueError(f"The file {label_dir} does not exist!")
  25. assert label_dir != save_dir, "hahahhaha"
  26. label_file = open(label_dir, 'r')
  27. data = label_file.readlines()
  28. gt_dict = {}
  29. for line in data:
  30. try:
  31. tmp = line.split('\t')
  32. assert len(tmp) == 2, ""
  33. except:
  34. tmp = line.strip().split(' ')
  35. gt_lists = []
  36. if tmp[0].split('/')[0] is not None:
  37. img_path = tmp[0]
  38. anno = json.loads(tmp[1])
  39. gt_collect = []
  40. for dic in anno:
  41. #txt = dic['transcription'].replace(' ', '') # ignore blank
  42. txt = dic['transcription']
  43. if 'score' in dic and float(dic['score']) < 0.5:
  44. continue
  45. if u'\u3000' in txt: txt = txt.replace(u'\u3000', u' ')
  46. #while ' ' in txt:
  47. # txt = txt.replace(' ', '')
  48. poly = np.array(dic['points']).flatten()
  49. if txt == "###":
  50. txt_tag = 1 ## ignore 1
  51. else:
  52. txt_tag = 0
  53. if mode == "gt":
  54. gt_label = poly_to_string(poly) + "\t" + str(
  55. txt_tag) + "\t" + txt + "\n"
  56. else:
  57. gt_label = poly_to_string(poly) + "\t" + txt + "\n"
  58. gt_lists.append(gt_label)
  59. gt_dict[img_path] = gt_lists
  60. else:
  61. continue
  62. if not os.path.exists(save_dir):
  63. os.makedirs(save_dir)
  64. for img_name in gt_dict.keys():
  65. save_name = img_name.split("/")[-1]
  66. save_file = os.path.join(save_dir, save_name + ".txt")
  67. with open(save_file, "w") as f:
  68. f.writelines(gt_dict[img_name])
  69. print("The convert label saved in {}".format(save_dir))
  70. def parse_args():
  71. import argparse
  72. parser = argparse.ArgumentParser(description="args")
  73. parser.add_argument("--label_path", type=str, required=True)
  74. parser.add_argument("--save_folder", type=str, required=True)
  75. parser.add_argument("--mode", type=str, default=False)
  76. args = parser.parse_args()
  77. return args
  78. if __name__ == "__main__":
  79. args = parse_args()
  80. convert_label(args.label_path, args.mode, args.save_folder)