# data_loader.py — CIFAR-100 DataLoader construction helpers for Paddle.
import os
import signal

import numpy as np

from paddle.fluid.dataloader.collate import default_collate_fn
from paddle.io import Dataset, DataLoader, DistributedBatchSampler
from paddle.vision.datasets import Cifar100
from paddle.vision.transforms import Normalize
  8. def term_mp(sig_num, frame):
  9. """ kill all child processes
  10. """
  11. pid = os.getpid()
  12. pgid = os.getpgid(os.getpid())
  13. print("main proc {} exit, kill process group " "{}".format(pid, pgid))
  14. os.killpg(pgid, signal.SIGKILL)
  15. return
  16. def build_dataloader(mode,
  17. batch_size=4,
  18. seed=None,
  19. num_workers=0,
  20. device='gpu:0'):
  21. normalize = Normalize(
  22. mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='HWC')
  23. if mode.lower() == "train":
  24. dataset = Cifar100(mode=mode, transform=normalize)
  25. elif mode.lower() in ["test", 'valid', 'eval']:
  26. dataset = Cifar100(mode="test", transform=normalize)
  27. else:
  28. raise ValueError(f"{mode} should be one of ['train', 'test']")
  29. # define batch sampler
  30. batch_sampler = DistributedBatchSampler(
  31. dataset=dataset, batch_size=batch_size, shuffle=False, drop_last=True)
  32. data_loader = DataLoader(
  33. dataset=dataset,
  34. batch_sampler=batch_sampler,
  35. places=device,
  36. num_workers=num_workers,
  37. return_list=True,
  38. use_shared_memory=False)
  39. # support exit using ctrl+c
  40. signal.signal(signal.SIGINT, term_mp)
  41. signal.signal(signal.SIGTERM, term_mp)
  42. return data_loader
# Example usage (kept for reference):
# cifar100 = Cifar100(mode='train', transform=normalize)
# data = cifar100[0]
# image, label = data
# reader = build_dataloader('train')
# for idx, data in enumerate(reader):
#     print(idx, data[0].shape, data[1].shape)
#     if idx >= 10:
#         break