# Load digit database

from mnist import MNIST
import numpy as np

def load_samples(training, path='../resources/download'):
	"""
	Load the MNIST digit images and labels as NumPy arrays, one sample per column.

	training : load training data if true, load testing data otherwise
	path : directory passed to MNIST() to locate the data files. The default
	       is the original hard-coded location, which is relative, so it is
	       resolved against the current working directory.

	Return np_images, np_expected
	where, with N the number of samples in the selected set,
	np_images is a float64 np.array of shape (784, N); pixel values are
	          divided by 255 to normalize them into 0.0 - 1.0
	np_expected is a float64 np.array of shape (10, N); column k is the
	          one-hot encoding of sample k's label
	"""
	mndata = MNIST(path)

	if training:
		images, labels = mndata.load_training()
	else:
		images, labels = mndata.load_testing()

	# Normalize data between 0.0 and 1.0
	np_images = np.array(images, dtype=np.float64) / 255

	# Construct expected outputs: row k gets a 1.0 in the column of its label.
	# Vectorized fancy indexing replaces the per-sample Python loop.
	np_labels = np.asarray(labels, dtype=np.intp)
	np_expected = np.zeros((np_labels.size, 10))
	np_expected[np.arange(np_labels.size), np_labels] = 1.0

	# Transpose so that each column holds one sample
	return np.transpose(np_images), np.transpose(np_expected)

def load_training_samples():
	"""Shortcut for load_samples with training enabled: return (np_images, np_expected) for the training set."""
	return load_samples(True)

def load_testing_samples():
	"""Shortcut for load_samples with training disabled: return (np_images, np_expected) for the testing set."""
	return load_samples(False)