12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879 |
- # Improve performance of neural network
- import numpy as np
- from lab import neural, io_mnist
def train(net, learnRate=0.05, epochs=2, batchSize=10, samples=None):
    """
    Train the network with mini-batch gradient descent and return it.

    The network's ``layer1``, ``bias1``, ``layer2`` and ``bias2`` arrays are
    updated in place; ``activationFunction`` / ``activationDerivative`` are
    used for the forward and backward passes. The biases are expected to be
    1-D vectors (one entry per neuron).

    net : the network to be trained
    learnRate : gradient-descent step size, lower is slower but more precise
    epochs : number of full passes over the training samples
    batchSize : maximum number of samples per mini-batch; the last batch of an
                epoch is smaller when the sample count is not a multiple of it
    samples : optional ``(images, expected)`` pair of arrays holding one sample
              per column; defaults to ``io_mnist.load_training_samples()``
    return : the trained network (the same object as ``net``)
    """
    # Load data; samples are stored column-wise, so shape[1] is the count
    if samples is None:
        samples = io_mnist.load_training_samples()
    np_images, np_expected = samples
    nbSamples = np_images.shape[1]

    g = net.activationFunction
    g_ = net.activationDerivative
    permut = np.arange(nbSamples)
    for _ in range(epochs):
        # Create mini batches: a fresh random sample order every epoch
        np.random.shuffle(permut)
        for batchIndex in range(0, nbSamples, batchSize):
            # Capture batch; the final slice may hold fewer than batchSize
            # samples, so average over its real length rather than batchSize
            batchSelection = permut[batchIndex : batchIndex + batchSize]
            m = batchSelection.size
            a0 = np_images[:, batchSelection]
            y = np_expected[:, batchSelection]
            w1 = net.layer1
            w2 = net.layer2
            # Forward computation; 1-D biases broadcast across batch columns
            z1 = w1 @ a0 + net.bias1[:, np.newaxis]
            a1 = g(z1)
            z2 = w2 @ a1 + net.bias2[:, np.newaxis]
            a2 = g(z2)
            # Backward propagation.
            # NOTE(review): no g_(z2) factor on the output error; a2 - y is the
            # output-layer error of a cross-entropy cost with sigmoid outputs —
            # confirm that is the intended cost.
            d2 = a2 - y
            # w2 aliases net.layer2, so d1 must be computed before the
            # in-place layer2 update below
            d1 = (w2.T @ d2) * g_(z1)
            # Weight correction (in place), gradients averaged over the batch
            net.layer2 -= learnRate * (d2 @ a1.T) / m
            net.layer1 -= learnRate * (d1 @ a0.T) / m
            net.bias2 -= learnRate * d2.sum(axis=1) / m
            net.bias1 -= learnRate * d1.sum(axis=1) / m
    return net
|