1234567891011121314151617181920212223242526272829303132333435363738394041 |
- # Load digit database
- from mnist import MNIST
- import numpy as np
def load_samples(training, path='../resources/download'):
    """Load one split of the MNIST digit database as column-per-sample arrays.

    training : load training data if true, load testing data otherwise
    path : directory handed to ``MNIST`` to locate the database files.
        Defaults to the original hard-coded location; being relative, it
        is resolved against the current working directory.

    Return np_images, np_expected
    where
        np_images is a float64 np.array of 784 x N, with pixel values
            divided by 255 so they lie in [0.0, 1.0]
            (N = number of samples in the chosen split)
        np_expected is a float64 np.array of 10 x N holding the one-hot
            encoding of each label: np_expected[label, k] == 1.0
    """
    mndata = MNIST(path)
    if training:
        images, labels = mndata.load_training()
    else:
        images, labels = mndata.load_testing()

    # One row per image; normalize data between 0.0 and 1.0
    # (assumes raw pixel intensities are in 0..255).
    np_images = np.array(images, dtype=np.float64)
    np_images /= 255

    # Construct expected outputs. A single fancy-index assignment sets
    # row k, column labels[k] to 1.0 for every sample at once, replacing
    # the per-sample Python loop. An explicit integer dtype keeps an empty
    # label list usable as an index.
    np_labels = np.asarray(labels, dtype=np.intp)
    sample_count = len(np_labels)
    np_expected = np.zeros((sample_count, 10))
    np_expected[np.arange(sample_count), np_labels] = 1.0

    # Transpose so that each sample becomes a column.
    return np.transpose(np_images), np.transpose(np_expected)
def load_training_samples():
    """Return (images, expected one-hot outputs) for the training split."""
    return load_samples(True)
def load_testing_samples():
    """Return (images, expected one-hot outputs) for the testing split."""
    return load_samples(False)
|