gen_ocr_train_val_test.py 6.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # coding:utf8
  2. import os
  3. import shutil
  4. import random
  5. import argparse
  6. # 删除划分的训练集、验证集、测试集文件夹,重新创建一个空的文件夹
  7. def isCreateOrDeleteFolder(path, flag):
  8. flagPath = os.path.join(path, flag)
  9. if os.path.exists(flagPath):
  10. shutil.rmtree(flagPath)
  11. os.makedirs(flagPath)
  12. flagAbsPath = os.path.abspath(flagPath)
  13. return flagAbsPath
  14. def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag):
  15. # 按照指定的比例划分训练集、验证集、测试集
  16. dataAbsPath = os.path.abspath(root)
  17. if flag == "det":
  18. labelFilePath = os.path.join(dataAbsPath, args.detLabelFileName)
  19. elif flag == "rec":
  20. labelFilePath = os.path.join(dataAbsPath, args.recLabelFileName)
  21. labelFileRead = open(labelFilePath, "r", encoding="UTF-8")
  22. labelFileContent = labelFileRead.readlines()
  23. random.shuffle(labelFileContent)
  24. labelRecordLen = len(labelFileContent)
  25. for index, labelRecordInfo in enumerate(labelFileContent):
  26. imageRelativePath = labelRecordInfo.split('\t')[0]
  27. imageLabel = labelRecordInfo.split('\t')[1]
  28. imageName = os.path.basename(imageRelativePath)
  29. if flag == "det":
  30. imagePath = os.path.join(dataAbsPath, imageName)
  31. elif flag == "rec":
  32. imagePath = os.path.join(dataAbsPath, "{}\\{}".format(args.recImageDirName, imageName))
  33. # 按预设的比例划分训练集、验证集、测试集
  34. trainValTestRatio = args.trainValTestRatio.split(":")
  35. trainRatio = eval(trainValTestRatio[0]) / 10
  36. valRatio = trainRatio + eval(trainValTestRatio[1]) / 10
  37. curRatio = index / labelRecordLen
  38. if curRatio < trainRatio:
  39. imageCopyPath = os.path.join(absTrainRootPath, imageName)
  40. shutil.copy(imagePath, imageCopyPath)
  41. trainTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
  42. elif curRatio >= trainRatio and curRatio < valRatio:
  43. imageCopyPath = os.path.join(absValRootPath, imageName)
  44. shutil.copy(imagePath, imageCopyPath)
  45. valTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
  46. else:
  47. imageCopyPath = os.path.join(absTestRootPath, imageName)
  48. shutil.copy(imagePath, imageCopyPath)
  49. testTxt.write("{}\t{}".format(imageCopyPath, imageLabel))
  50. # 删掉存在的文件
  51. def removeFile(path):
  52. if os.path.exists(path):
  53. os.remove(path)
  54. def genDetRecTrainVal(args):
  55. detAbsTrainRootPath = isCreateOrDeleteFolder(args.detRootPath, "train")
  56. detAbsValRootPath = isCreateOrDeleteFolder(args.detRootPath, "val")
  57. detAbsTestRootPath = isCreateOrDeleteFolder(args.detRootPath, "test")
  58. recAbsTrainRootPath = isCreateOrDeleteFolder(args.recRootPath, "train")
  59. recAbsValRootPath = isCreateOrDeleteFolder(args.recRootPath, "val")
  60. recAbsTestRootPath = isCreateOrDeleteFolder(args.recRootPath, "test")
  61. removeFile(os.path.join(args.detRootPath, "train.txt"))
  62. removeFile(os.path.join(args.detRootPath, "val.txt"))
  63. removeFile(os.path.join(args.detRootPath, "test.txt"))
  64. removeFile(os.path.join(args.recRootPath, "train.txt"))
  65. removeFile(os.path.join(args.recRootPath, "val.txt"))
  66. removeFile(os.path.join(args.recRootPath, "test.txt"))
  67. detTrainTxt = open(os.path.join(args.detRootPath, "train.txt"), "a", encoding="UTF-8")
  68. detValTxt = open(os.path.join(args.detRootPath, "val.txt"), "a", encoding="UTF-8")
  69. detTestTxt = open(os.path.join(args.detRootPath, "test.txt"), "a", encoding="UTF-8")
  70. recTrainTxt = open(os.path.join(args.recRootPath, "train.txt"), "a", encoding="UTF-8")
  71. recValTxt = open(os.path.join(args.recRootPath, "val.txt"), "a", encoding="UTF-8")
  72. recTestTxt = open(os.path.join(args.recRootPath, "test.txt"), "a", encoding="UTF-8")
  73. splitTrainVal(args.datasetRootPath, detAbsTrainRootPath, detAbsValRootPath, detAbsTestRootPath, detTrainTxt, detValTxt,
  74. detTestTxt, "det")
  75. for root, dirs, files in os.walk(args.datasetRootPath):
  76. for dir in dirs:
  77. if dir == 'crop_img':
  78. splitTrainVal(root, recAbsTrainRootPath, recAbsValRootPath, recAbsTestRootPath, recTrainTxt, recValTxt,
  79. recTestTxt, "rec")
  80. else:
  81. continue
  82. break
  83. if __name__ == "__main__":
  84. # 功能描述:分别划分检测和识别的训练集、验证集、测试集
  85. # 说明:可以根据自己的路径和需求调整参数,图像数据往往多人合作分批标注,每一批图像数据放在一个文件夹内用PPOCRLabel进行标注,
  86. # 如此会有多个标注好的图像文件夹汇总并划分训练集、验证集、测试集的需求
  87. parser = argparse.ArgumentParser()
  88. parser.add_argument(
  89. "--trainValTestRatio",
  90. type=str,
  91. default="6:2:2",
  92. help="ratio of trainset:valset:testset")
  93. parser.add_argument(
  94. "--datasetRootPath",
  95. type=str,
  96. default="../train_data/",
  97. help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."
  98. )
  99. parser.add_argument(
  100. "--detRootPath",
  101. type=str,
  102. default="../train_data/det",
  103. help="the path where the divided detection dataset is placed")
  104. parser.add_argument(
  105. "--recRootPath",
  106. type=str,
  107. default="../train_data/rec",
  108. help="the path where the divided recognition dataset is placed"
  109. )
  110. parser.add_argument(
  111. "--detLabelFileName",
  112. type=str,
  113. default="Label.txt",
  114. help="the name of the detection annotation file")
  115. parser.add_argument(
  116. "--recLabelFileName",
  117. type=str,
  118. default="rec_gt.txt",
  119. help="the name of the recognition annotation file"
  120. )
  121. parser.add_argument(
  122. "--recImageDirName",
  123. type=str,
  124. default="crop_img",
  125. help="the name of the folder where the cropped recognition dataset is located"
  126. )
  127. args = parser.parse_args()
  128. genDetRecTrainVal(args)