io_mnist.py 999 B

1234567891011121314151617181920212223242526272829303132333435363738394041
  1. # Load digit database
  2. from mnist import MNIST
  3. import numpy as np
  def load_samples(training, data_dir='../resources/download'):
      """
      Load MNIST samples and return them as column-oriented NumPy arrays.

      training : load training data if true, load testing data otherwise
      data_dir : directory handed to the MNIST loader; defaults to the
                 location this module has always used (relative to the
                 current working directory)

      Return np_images, np_expected
      where
          np_images is a float64 np.array with one column per sample
              (784 x 60 000 for the standard training set), scaled by
              1/255 so that 8-bit pixel values land in [0.0, 1.0]
          np_expected is a float64 np.array of 10 x <number of samples>,
              each column being the one-hot encoding of the sample's label
      """
      mndata = MNIST(data_dir)
      if training:
          images, labels = mndata.load_training()
      else:
          images, labels = mndata.load_testing()

      # Normalize data between 0.0 and 1.0 (in place, to avoid a second
      # full-size copy of the image matrix).
      np_images = np.array(images, dtype=np.float64)
      np_images /= 255

      # Construct expected outputs: one row per sample, with a single 1.0 in
      # the column given by the label. The indexed assignment sets all rows
      # at once instead of looping over every sample in Python.
      label_indices = np.asarray(labels, dtype=np.intp)
      sample_count = len(label_indices)
      np_expected = np.zeros((sample_count, 10))
      np_expected[np.arange(sample_count), label_indices] = 1.0

      # Callers expect samples as columns, hence the transposes.
      return np.transpose(np_images), np.transpose(np_expected)
  def load_training_samples():
      """Return the (np_images, np_expected) pair for the training data."""
      training_set = load_samples(True)
      return training_set
  def load_testing_samples():
      """Return the (np_images, np_expected) pair for the testing data."""
      testing_set = load_samples(False)
      return testing_set