lmdb_dataset.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import os
  16. from paddle.io import Dataset
  17. import lmdb
  18. import cv2
  19. import string
  20. import six
  21. from PIL import Image
  22. from .imaug import transform, create_operators
  23. class LMDBDataSet(Dataset):
  24. def __init__(self, config, mode, logger, seed=None):
  25. super(LMDBDataSet, self).__init__()
  26. global_config = config['Global']
  27. dataset_config = config[mode]['dataset']
  28. loader_config = config[mode]['loader']
  29. batch_size = loader_config['batch_size_per_card']
  30. data_dir = dataset_config['data_dir']
  31. self.do_shuffle = loader_config['shuffle']
  32. self.lmdb_sets = self.load_hierarchical_lmdb_dataset(data_dir)
  33. logger.info("Initialize indexs of datasets:%s" % data_dir)
  34. self.data_idx_order_list = self.dataset_traversal()
  35. if self.do_shuffle:
  36. np.random.shuffle(self.data_idx_order_list)
  37. self.ops = create_operators(dataset_config['transforms'], global_config)
  38. self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx",
  39. 1)
  40. ratio_list = dataset_config.get("ratio_list", [1.0])
  41. self.need_reset = True in [x < 1 for x in ratio_list]
  42. def load_hierarchical_lmdb_dataset(self, data_dir):
  43. lmdb_sets = {}
  44. dataset_idx = 0
  45. for dirpath, dirnames, filenames in os.walk(data_dir + '/'):
  46. if not dirnames:
  47. env = lmdb.open(
  48. dirpath,
  49. max_readers=32,
  50. readonly=True,
  51. lock=False,
  52. readahead=False,
  53. meminit=False)
  54. txn = env.begin(write=False)
  55. num_samples = int(txn.get('num-samples'.encode()))
  56. lmdb_sets[dataset_idx] = {"dirpath":dirpath, "env":env, \
  57. "txn":txn, "num_samples":num_samples}
  58. dataset_idx += 1
  59. return lmdb_sets
  60. def dataset_traversal(self):
  61. lmdb_num = len(self.lmdb_sets)
  62. total_sample_num = 0
  63. for lno in range(lmdb_num):
  64. total_sample_num += self.lmdb_sets[lno]['num_samples']
  65. data_idx_order_list = np.zeros((total_sample_num, 2))
  66. beg_idx = 0
  67. for lno in range(lmdb_num):
  68. tmp_sample_num = self.lmdb_sets[lno]['num_samples']
  69. end_idx = beg_idx + tmp_sample_num
  70. data_idx_order_list[beg_idx:end_idx, 0] = lno
  71. data_idx_order_list[beg_idx:end_idx, 1] \
  72. = list(range(tmp_sample_num))
  73. data_idx_order_list[beg_idx:end_idx, 1] += 1
  74. beg_idx = beg_idx + tmp_sample_num
  75. return data_idx_order_list
  76. def get_img_data(self, value):
  77. """get_img_data"""
  78. if not value:
  79. return None
  80. imgdata = np.frombuffer(value, dtype='uint8')
  81. if imgdata is None:
  82. return None
  83. imgori = cv2.imdecode(imgdata, 1)
  84. if imgori is None:
  85. return None
  86. return imgori
  87. def get_ext_data(self):
  88. ext_data_num = 0
  89. for op in self.ops:
  90. if hasattr(op, 'ext_data_num'):
  91. ext_data_num = getattr(op, 'ext_data_num')
  92. break
  93. load_data_ops = self.ops[:self.ext_op_transform_idx]
  94. ext_data = []
  95. while len(ext_data) < ext_data_num:
  96. lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(
  97. len(self))]
  98. lmdb_idx = int(lmdb_idx)
  99. file_idx = int(file_idx)
  100. sample_info = self.get_lmdb_sample_info(
  101. self.lmdb_sets[lmdb_idx]['txn'], file_idx)
  102. if sample_info is None:
  103. continue
  104. img, label = sample_info
  105. data = {'image': img, 'label': label}
  106. data = transform(data, load_data_ops)
  107. if data is None:
  108. continue
  109. ext_data.append(data)
  110. return ext_data
  111. def get_lmdb_sample_info(self, txn, index):
  112. label_key = 'label-%09d'.encode() % index
  113. label = txn.get(label_key)
  114. if label is None:
  115. return None
  116. label = label.decode('utf-8')
  117. img_key = 'image-%09d'.encode() % index
  118. imgbuf = txn.get(img_key)
  119. return imgbuf, label
  120. def __getitem__(self, idx):
  121. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  122. lmdb_idx = int(lmdb_idx)
  123. file_idx = int(file_idx)
  124. sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
  125. file_idx)
  126. if sample_info is None:
  127. return self.__getitem__(np.random.randint(self.__len__()))
  128. img, label = sample_info
  129. data = {'image': img, 'label': label}
  130. data['ext_data'] = self.get_ext_data()
  131. outs = transform(data, self.ops)
  132. if outs is None:
  133. return self.__getitem__(np.random.randint(self.__len__()))
  134. return outs
  135. def __len__(self):
  136. return self.data_idx_order_list.shape[0]
  137. class LMDBDataSetSR(LMDBDataSet):
  138. def buf2PIL(self, txn, key, type='RGB'):
  139. imgbuf = txn.get(key)
  140. buf = six.BytesIO()
  141. buf.write(imgbuf)
  142. buf.seek(0)
  143. im = Image.open(buf).convert(type)
  144. return im
  145. def str_filt(self, str_, voc_type):
  146. alpha_dict = {
  147. 'digit': string.digits,
  148. 'lower': string.digits + string.ascii_lowercase,
  149. 'upper': string.digits + string.ascii_letters,
  150. 'all': string.digits + string.ascii_letters + string.punctuation
  151. }
  152. if voc_type == 'lower':
  153. str_ = str_.lower()
  154. for char in str_:
  155. if char not in alpha_dict[voc_type]:
  156. str_ = str_.replace(char, '')
  157. return str_
  158. def get_lmdb_sample_info(self, txn, index):
  159. self.voc_type = 'upper'
  160. self.max_len = 100
  161. self.test = False
  162. label_key = b'label-%09d' % index
  163. word = str(txn.get(label_key).decode())
  164. img_HR_key = b'image_hr-%09d' % index # 128*32
  165. img_lr_key = b'image_lr-%09d' % index # 64*16
  166. try:
  167. img_HR = self.buf2PIL(txn, img_HR_key, 'RGB')
  168. img_lr = self.buf2PIL(txn, img_lr_key, 'RGB')
  169. except IOError or len(word) > self.max_len:
  170. return self[index + 1]
  171. label_str = self.str_filt(word, self.voc_type)
  172. return img_HR, img_lr, label_str
  173. def __getitem__(self, idx):
  174. lmdb_idx, file_idx = self.data_idx_order_list[idx]
  175. lmdb_idx = int(lmdb_idx)
  176. file_idx = int(file_idx)
  177. sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'],
  178. file_idx)
  179. if sample_info is None:
  180. return self.__getitem__(np.random.randint(self.__len__()))
  181. img_HR, img_lr, label_str = sample_info
  182. data = {'image_hr': img_HR, 'image_lr': img_lr, 'label': label_str}
  183. outs = transform(data, self.ops)
  184. if outs is None:
  185. return self.__getitem__(np.random.randint(self.__len__()))
  186. return outs