# load_cifar.py
import os
import pickle as p

import numpy as np
from PIL import Image
  4. def load_CIFAR_batch(filename):
  5. """ load single batch of cifar """
  6. with open(filename, 'rb') as f:
  7. datadict = p.load(f, encoding='bytes')
  8. # 以字典的形式取出数据
  9. X = datadict[b'data']
  10. Y = datadict[b'fine_labels']
  11. try:
  12. X = X.reshape(10000, 3, 32, 32)
  13. except:
  14. X = X.reshape(50000, 3, 32, 32)
  15. Y = np.array(Y)
  16. print(Y.shape)
  17. return X, Y
  18. if __name__ == "__main__":
  19. mode = "train"
  20. imgX, imgY = load_CIFAR_batch(f"./cifar-100-python/{mode}")
  21. with open(f'./cifar-100-python/{mode}_imgs/img_label.txt', 'a+') as f:
  22. for i in range(imgY.shape[0]):
  23. f.write('img' + str(i) + ' ' + str(imgY[i]) + '\n')
  24. for i in range(imgX.shape[0]):
  25. imgs = imgX[i]
  26. img0 = imgs[0]
  27. img1 = imgs[1]
  28. img2 = imgs[2]
  29. i0 = Image.fromarray(img0)
  30. i1 = Image.fromarray(img1)
  31. i2 = Image.fromarray(img2)
  32. img = Image.merge("RGB", (i0, i1, i2))
  33. name = "img" + str(i) + ".png"
  34. img.save(f"./cifar-100-python/{mode}_imgs/" + name, "png")
  35. print("save successfully!")