trainer.py

# Improve the performance of a neural network by mini-batch gradient descent
import numpy as np
from lab import neural, io_mnist
from copy import deepcopy


def train(inputNetwork, learnRate=0.05, epochs=2, batchSize=10):
  6. """
  7. Create an improved network
  8. inputNetwork : the network to be trained
  9. learnRate : speed of training, lower is slower but more precise
  10. epochs : the number of iterations
  11. return : a trained copy with improved performance
  12. """
    # Deep copy: a shallow copy would share the weight arrays with
    # inputNetwork, and the in-place updates below would mutate them
    net = deepcopy(inputNetwork)

    # Load data
    np_images, np_expected = io_mnist.load_training_samples()
    nbSamples = np_images.shape[1]
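    # Layout assumption (inferred from the column indexing below, not
    # documented by io_mnist here): np_images is (inputLength, nbSamples)
    # and np_expected is (outputLength, nbSamples), one one-hot label
    # per column.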
    # Prepare variables
    # Per-batch array shapes:
    #   a0 : (net.inputLength, batchSize)            input activations
    #   z1, a1, d1 : (net.hiddenLength, batchSize)   hidden layer
    #   z2, a2, y, d2 : (net.outputLength, batchSize) output layer
    g = net.activationFunction
    g_ = net.activationDerivative
    permut = np.arange(nbSamples)
    for epoch in range(epochs):
        # Create mini batches: reshuffle the sample order every epoch
        np.random.shuffle(permut)

        # Iterate over batches
        for batchIndex in range(0, nbSamples, batchSize):
            # Fetch the current weights (they are updated in place after
            # every batch); reshape the biases so they broadcast over the
            # batch, which also handles a short final batch
            w1 = net.layer1
            b1 = net.bias1[:, np.newaxis]
            w2 = net.layer2
            b2 = net.bias2[:, np.newaxis]
            # Capture batch
            batchEndIndex = batchIndex + batchSize
            batchSelection = permut[batchIndex:batchEndIndex]
            a0 = np_images[:, batchSelection]
            y = np_expected[:, batchSelection]
            # Forward computation
            z1 = w1 @ a0 + b1
            a1 = g(z1)
            z2 = w2 @ a1 + b2
            a2 = g(z2)
            # Backward propagation
            d2 = a2 - y
            d1 = w2.transpose() @ d2 * g_(z1)
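            # The deltas follow standard backprop. d2 = a2 - y assumes a
            # cross-entropy loss paired with the output activation (an
            # assumption -- this file does not show the loss), which
            # cancels the g_(z2) factor:
            #   dL/dz2 = a2 - y
            #   dL/dz1 = W2^T (dL/dz2) * g'(z1)    (* is elementwise)
            #   dL/dW  = dL/dz @ a_prev^T, averaged over the batch below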
            # Weight correction: average the gradients over the batch,
            # using the actual batch width so a short final batch is safe
            curBatchSize = a0.shape[1]
            net.layer2 -= learnRate * d2 @ a1.transpose() / curBatchSize
            net.layer1 -= learnRate * d1 @ a0.transpose() / curBatchSize
            net.bias2 -= learnRate * d2.sum(axis=1) / curBatchSize
            net.bias1 -= learnRate * d1.sum(axis=1) / curBatchSize
    return net
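

# Attribute contract that train() relies on (inferred from the code above):
#   net.layer1 : (hiddenLength, inputLength) weights,  net.bias1 : (hiddenLength,)
#   net.layer2 : (outputLength, hiddenLength) weights, net.bias2 : (outputLength,)
#   net.activationFunction / net.activationDerivative : elementwise g and g'
#
# Example usage -- a sketch only: the constructor below is hypothetical,
# since this file shows nothing of how lab.neural builds a network.
#
# if __name__ == "__main__":
#     net = neural.Network(inputLength=784, hiddenLength=30, outputLength=10)
#     trained = train(net, learnRate=0.05, epochs=2, batchSize=10)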