# simple_dataset.py
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import random
import traceback

import numpy as np
from paddle.io import Dataset

from .imaug import transform, create_operators
  21. class SimpleDataSet(Dataset):
  22. def __init__(self, config, mode, logger, seed=None):
  23. super(SimpleDataSet, self).__init__()
  24. self.logger = logger
  25. self.mode = mode.lower()
  26. global_config = config['Global']
  27. dataset_config = config[mode]['dataset']
  28. loader_config = config[mode]['loader']
  29. self.delimiter = dataset_config.get('delimiter', '\t')
  30. label_file_list = dataset_config.pop('label_file_list')
  31. data_source_num = len(label_file_list)
  32. ratio_list = dataset_config.get("ratio_list", 1.0)
  33. if isinstance(ratio_list, (float, int)):
  34. ratio_list = [float(ratio_list)] * int(data_source_num)
  35. assert len(
  36. ratio_list
  37. ) == data_source_num, "The length of ratio_list should be the same as the file_list."
  38. self.data_dir = dataset_config['data_dir']
  39. self.do_shuffle = loader_config['shuffle']
  40. self.seed = seed
  41. logger.info("Initialize indexs of datasets:%s" % label_file_list)
  42. self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
  43. self.data_idx_order_list = list(range(len(self.data_lines)))
  44. if self.mode == "train" and self.do_shuffle:
  45. self.shuffle_data_random()
  46. self.ops = create_operators(dataset_config['transforms'], global_config)
  47. self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
  48. 2)
  49. self.need_reset = True in [x < 1 for x in ratio_list]
  50. def get_image_info_list(self, file_list, ratio_list):
  51. if isinstance(file_list, str):
  52. file_list = [file_list]
  53. data_lines = []
  54. for idx, file in enumerate(file_list):
  55. with open(file, "rb") as f:
  56. lines = f.readlines()
  57. if self.mode == "train" or ratio_list[idx] < 1.0:
  58. random.seed(self.seed)
  59. lines = random.sample(lines,
  60. round(len(lines) * ratio_list[idx]))
  61. data_lines.extend(lines)
  62. return data_lines
  63. def shuffle_data_random(self):
  64. random.seed(self.seed)
  65. random.shuffle(self.data_lines)
  66. return
  67. def _try_parse_filename_list(self, file_name):
  68. # multiple images -> one gt label
  69. if len(file_name) > 0 and file_name[0] == "[":
  70. try:
  71. info = json.loads(file_name)
  72. file_name = random.choice(info)
  73. except:
  74. pass
  75. return file_name
  76. def get_ext_data(self):
  77. ext_data_num = 0
  78. for op in self.ops:
  79. if hasattr(op, 'ext_data_num'):
  80. ext_data_num = getattr(op, 'ext_data_num')
  81. break
  82. load_data_ops = self.ops[:self.ext_op_transform_idx]
  83. ext_data = []
  84. while len(ext_data) < ext_data_num:
  85. file_idx = self.data_idx_order_list[np.random.randint(self.__len__(
  86. ))]
  87. data_line = self.data_lines[file_idx]
  88. data_line = data_line.decode('utf-8')
  89. substr = data_line.strip("\n").split(self.delimiter)
  90. file_name = substr[0]
  91. file_name = self._try_parse_filename_list(file_name)
  92. label = substr[1]
  93. img_path = os.path.join(self.data_dir, file_name)
  94. data = {'img_path': img_path, 'label': label}
  95. if not os.path.exists(img_path):
  96. continue
  97. with open(data['img_path'], 'rb') as f:
  98. img = f.read()
  99. data['image'] = img
  100. data = transform(data, load_data_ops)
  101. if data is None:
  102. continue
  103. if 'polys' in data.keys():
  104. if data['polys'].shape[1] != 4:
  105. continue
  106. ext_data.append(data)
  107. return ext_data
  108. def __getitem__(self, idx):
  109. file_idx = self.data_idx_order_list[idx]
  110. data_line = self.data_lines[file_idx]
  111. try:
  112. data_line = data_line.decode('utf-8')
  113. substr = data_line.strip("\n").split(self.delimiter)
  114. file_name = substr[0]
  115. file_name = self._try_parse_filename_list(file_name)
  116. label = substr[1]
  117. img_path = os.path.join(self.data_dir, file_name)
  118. data = {'img_path': img_path, 'label': label}
  119. if not os.path.exists(img_path):
  120. raise Exception("{} does not exist!".format(img_path))
  121. with open(data['img_path'], 'rb') as f:
  122. img = f.read()
  123. data['image'] = img
  124. data['ext_data'] = self.get_ext_data()
  125. outs = transform(data, self.ops)
  126. except:
  127. self.logger.error(
  128. "When parsing line {}, error happened with msg: {}".format(
  129. data_line, traceback.format_exc()))
  130. outs = None
  131. if outs is None:
  132. # during evaluation, we should fix the idx to get same results for many times of evaluation.
  133. rnd_idx = np.random.randint(self.__len__(
  134. )) if self.mode == "train" else (idx + 1) % self.__len__()
  135. return self.__getitem__(rnd_idx)
  136. return outs
  137. def __len__(self):
  138. return len(self.data_idx_order_list)