pubtab_dataset.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. # copyright (c) 2021 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import os
  16. import random
  17. from paddle.io import Dataset
  18. import json
  19. from copy import deepcopy
  20. from .imaug import transform, create_operators
  21. class PubTabDataSet(Dataset):
  22. def __init__(self, config, mode, logger, seed=None):
  23. super(PubTabDataSet, self).__init__()
  24. self.logger = logger
  25. global_config = config['Global']
  26. dataset_config = config[mode]['dataset']
  27. loader_config = config[mode]['loader']
  28. label_file_list = dataset_config.pop('label_file_list')
  29. data_source_num = len(label_file_list)
  30. ratio_list = dataset_config.get("ratio_list", [1.0])
  31. if isinstance(ratio_list, (float, int)):
  32. ratio_list = [float(ratio_list)] * int(data_source_num)
  33. assert len(
  34. ratio_list
  35. ) == data_source_num, "The length of ratio_list should be the same as the file_list."
  36. self.data_dir = dataset_config['data_dir']
  37. self.do_shuffle = loader_config['shuffle']
  38. self.seed = seed
  39. self.mode = mode.lower()
  40. logger.info("Initialize indexs of datasets:%s" % label_file_list)
  41. self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
  42. # self.check(config['Global']['max_text_length'])
  43. if mode.lower() == "train" and self.do_shuffle:
  44. self.shuffle_data_random()
  45. self.ops = create_operators(dataset_config['transforms'], global_config)
  46. self.need_reset = True in [x < 1 for x in ratio_list]
  47. def get_image_info_list(self, file_list, ratio_list):
  48. if isinstance(file_list, str):
  49. file_list = [file_list]
  50. data_lines = []
  51. for idx, file in enumerate(file_list):
  52. with open(file, "rb") as f:
  53. lines = f.readlines()
  54. if self.mode == "train" or ratio_list[idx] < 1.0:
  55. random.seed(self.seed)
  56. lines = random.sample(lines,
  57. round(len(lines) * ratio_list[idx]))
  58. data_lines.extend(lines)
  59. return data_lines
  60. def check(self, max_text_length):
  61. data_lines = []
  62. for line in self.data_lines:
  63. data_line = line.decode('utf-8').strip("\n")
  64. info = json.loads(data_line)
  65. file_name = info['filename']
  66. cells = info['html']['cells'].copy()
  67. structure = info['html']['structure']['tokens'].copy()
  68. img_path = os.path.join(self.data_dir, file_name)
  69. if not os.path.exists(img_path):
  70. self.logger.warning("{} does not exist!".format(img_path))
  71. continue
  72. if len(structure) == 0 or len(structure) > max_text_length:
  73. continue
  74. # data = {'img_path': img_path, 'cells': cells, 'structure':structure,'file_name':file_name}
  75. data_lines.append(line)
  76. self.data_lines = data_lines
  77. def shuffle_data_random(self):
  78. if self.do_shuffle:
  79. random.seed(self.seed)
  80. random.shuffle(self.data_lines)
  81. return
  82. def __getitem__(self, idx):
  83. try:
  84. data_line = self.data_lines[idx]
  85. data_line = data_line.decode('utf-8').strip("\n")
  86. info = json.loads(data_line)
  87. file_name = info['filename']
  88. cells = info['html']['cells'].copy()
  89. structure = info['html']['structure']['tokens'].copy()
  90. img_path = os.path.join(self.data_dir, file_name)
  91. if not os.path.exists(img_path):
  92. raise Exception("{} does not exist!".format(img_path))
  93. data = {
  94. 'img_path': img_path,
  95. 'cells': cells,
  96. 'structure': structure,
  97. 'file_name': file_name
  98. }
  99. with open(data['img_path'], 'rb') as f:
  100. img = f.read()
  101. data['image'] = img
  102. outs = transform(data, self.ops)
  103. except:
  104. import traceback
  105. err = traceback.format_exc()
  106. self.logger.error(
  107. "When parsing line {}, error happened with msg: {}".format(
  108. data_line, err))
  109. outs = None
  110. if outs is None:
  111. rnd_idx = np.random.randint(self.__len__(
  112. )) if self.mode == "train" else (idx + 1) % self.__len__()
  113. return self.__getitem__(rnd_idx)
  114. return outs
  115. def __len__(self):
  116. return len(self.data_lines)