start_session.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. #!/usr/bin/python3
  2. # File designed to be launched by user
  3. print("Start script")
  4. # Import python libraries
  5. from scipy.special import expit
  6. import argparse
  7. # Import local code to call
  8. from lab import generator, trainer, benchmark
  9. # Parse script arguments
  10. parser = argparse.ArgumentParser(description = "Start a training session for a neuronal network and display results.")
  11. parser.add_argument('-a, --alpha', help = "set the learning rate", dest="learnRate", type = float, default = 0.05)
  12. parser.add_argument('-e, --epochs', help = "set the number of iterations", dest="epochs", type = int, default = 2)
  13. parser.add_argument('-n, --neurons', help = "set the number of hidden neurons", dest="hiddenNeurons", type = int, default = 30)
  14. args = parser.parse_args()
  15. # Parameters
  16. print("Parameters")
  17. print("Learning rate : ", args.learnRate)
  18. print("Number of epochs : ", args.epochs)
  19. print("Number of hidden layers : ", args.hiddenNeurons)
  20. activation = expit
  21. activationDerivative = lambda x : expit(x) * (1 - expit(x))
  22. # Session
  23. print("Start session")
  24. print("... generating neural network ...")
  25. network = generator.generate(activation, activationDerivative, hiddenLength = args.hiddenNeurons, weightGenerator = generator.gaussAdaptedDev)
  26. print("... compute precision before training ... ")
  27. precisionBefore = benchmark.computePrecision(network)
  28. print("... train the network (can take a while) ... ")
  29. network = trainer.train(network, args.learnRate, args.epochs)
  30. print("... compute precision after training ... ")
  31. precisionAfter = benchmark.computePrecision(network)
  32. print("End of session")
  33. # Display results
  34. print("Precision before training : ", precisionBefore)
  35. print("Precision after training : ", precisionAfter)