# Improve the performance of a neural network by training it on MNIST

import numpy as np
from lab import neural, io_mnist

def train(net, learnRate=0.05, epochs=2, batchSize=10):
	"""
	Train the network with mini-batch gradient descent

	net : the network to be trained
	learnRate : speed of training; lower is slower but more stable
	epochs : number of full passes over the training set
	batchSize : number of samples per mini-batch

	return : the trained network (also modified in place)
	"""
	# Load data
	np_images, np_expected = io_mnist.load_training_samples()
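	# np_images holds one flattened image per column; np_expected holds the
	# matching expected output vector per column (layout inferred from the
	# indexing and shapes used below)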

	nbSamples = np_images.shape[1]

	# Shorthands for the activation function and its derivative
	g = net.activationFunction
	g_ = net.activationDerivative

	# Shapes of the per-batch arrays built inside the loop:
	# a0 is (inputLength, batchSize), z1/a1/d1 are (hiddenLength, batchSize),
	# z2/a2/d2/y are (outputLength, batchSize)

	# Sample indices, reshuffled at every epoch to form the mini-batches
	permut = np.arange(nbSamples)
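
	# Training loop: one epoch is a full shuffled pass over the samples,
	# processed batch by batch (forward pass, backward pass, gradient step)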

	for epoch in range(epochs):
		# Reshuffle the sample order for this epoch's mini-batches
		np.random.shuffle(permut)

		# Iterate over batches
		for batchIndex in range(0, nbSamples, batchSize):
			# Take references to the current weights; the biases are viewed
			# as column vectors so they broadcast across the whole batch
			w1 = net.layer1
			b1 = net.bias1[:, np.newaxis]

			w2 = net.layer2
			b2 = net.bias2[:, np.newaxis]

			# Select the indices of the current mini-batch; the slice is
			# truncated automatically if fewer than batchSize samples remain
			batchEndIndex = batchIndex + batchSize
			batchSelection = permut[batchIndex : batchEndIndex]

			a0 = np_images[:, batchSelection]
			y = np_expected[:, batchSelection]

			# Forward pass: affine transform then activation, layer by layer
			z1 = w1 @ a0 + b1
			a1 = g(z1)

			z2 = w2 @ a1 + b2
			a2 = g(z2) # the output layer reuses the same activation g

			# Backward propagation of the deltas.
			# d2 = a2 - y assumes a loss matched to the output activation
			# (e.g. cross-entropy with sigmoid or softmax outputs)
			d2 = a2 - y
			d1 = w2.transpose() @ d2 * g_(z1) # i.e. (w2.T @ d2) * g_(z1)

			# Gradient step, averaged over the actual batch size so that a
			# smaller final batch is still weighted correctly
			currentBatchSize = a0.shape[1]

			net.layer2 -= learnRate * d2 @ a1.transpose() / currentBatchSize
			net.layer1 -= learnRate * d1 @ a0.transpose() / currentBatchSize

			net.bias2 -= learnRate * d2.mean(axis=1)
			net.bias1 -= learnRate * d1.mean(axis=1)

	return net
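
# Minimal usage sketch. The `neural.Network` constructor and its argument
# names below are assumptions (the `neural` module is not shown here);
# adapt them to the real API before running.
if __name__ == "__main__":
	# Hypothetical constructor: layer sizes matching 28x28 MNIST inputs
	net = neural.Network(inputLength=784, hiddenLength=30, outputLength=10)
	net = train(net, learnRate=0.05, epochs=2, batchSize=10)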