# trainer.py
  1. # Improve performance of neural network
  2. import numpy as np
  3. from lab import neural, io_mnist
  4. from copy import copy
  5. def train(inputNetwork, learnRate, epochs, batchSize = 10):
  6. """
  7. Create an improved network
  8. inputNetwork : the network to be trained
  9. epochs : the number of iterations
  10. return : a trained copy with improved performance
  11. """
  12. net = copy(inputNetwork)
  13. np_images, np_expected = io_mnist.load_training_samples()
  14. nbSamples = np_images.shape[1]
  15. # Prepare variables
  16. a0 = np.empty((net.inputLength, batchSize))
  17. w1 = net.layer1 # reference
  18. b1 = np.stack([net.bias1] * batchSize).transpose() # stack
  19. z1 = np.empty((net.hiddenLength, batchSize))
  20. a1 = np.empty((net.hiddenLength, batchSize))
  21. w2 = net.layer2 # reference
  22. b2 = np.stack([net.bias2] * batchSize).transpose() # stack
  23. z2 = np.empty((net.outputLength, batchSize))
  24. a2 = np.empty((net.outputLength, batchSize))
  25. y = np.empty((net.outputLength, batchSize))
  26. g = net.activationFunction
  27. g_ = net.activationDerivative
  28. d2 = np.empty(a2.shape)
  29. d1 = np.empty(a1.shape)
  30. for epoch in range(epochs):
  31. # Create mini batches
  32. # TODO Shuffle samples
  33. # Iterate over batches
  34. for batchIndex in range(0, nbSamples, batchSize):
  35. # Capture batch
  36. batchEndIndex = batchIndex + batchSize
  37. a0 = np_images[:, batchIndex:batchEndIndex]
  38. y = np_expected[:, batchIndex:batchEndIndex]
  39. # Forward computation
  40. z1 = w1 @ a0 + b1
  41. a1 = g(z1)
  42. z2 = w2 @ a1 + b2
  43. a2 = g(z2)
  44. # Backward propagation
  45. d2 = a2 - y
  46. d1 = w2.transpose() @ d2 * g_(z1)
  47. # Weight correction
  48. net.layer2 -= learnRate * d2 @ a1.transpose() / batchSize
  49. net.layer1 -= learnRate * d1 @ a0.transpose() / batchSize
  50. net.bias2 -= learnRate * d2 @ np.ones(batchSize) / batchSize
  51. net.bias1 -= learnRate * d1 @ np.ones(batchSize) / batchSize
  52. return net