#!/usr/bin/python3
# start_session.py -- file designed to be launched by the user.
#
# Command-line entry point for one training session: parses the session
# options, builds a network with lab.generator, measures its precision with
# lab.benchmark, trains it with lab.trainer, measures the precision again and
# prints both values.

# Printed before the imports below are executed.
print("Start script")

# Import python libraries
import argparse
from scipy.special import expit

# Import local code to call
from lab import generator, trainer, benchmark
from lab.generator import uniform, gaussAdaptedDev, gaussUnitDev

# Maps every accepted --generator value to the weight generator handed to
# generator.generate(); "none" passes None.
generatorAssociation = {
    "none": None,
    "uniform": uniform,
    "gaussAdaptedDev": gaussAdaptedDev,
    "gaussUnitDev": gaussUnitDev,
}

# Parse script arguments
# The short and long flags are passed as separate strings: a single
# '-a, --alpha' string registers one option literally named "-a, --alpha",
# which leaves "--alpha" unrecognized on the command line.
parser = argparse.ArgumentParser(description="Start a training session for a neuronal network and display results.")
parser.add_argument('-a', '--alpha', help="set the learning rate", dest="learnRate", type=float, default=0.05)
parser.add_argument('-e', '--epochs', help="set the number of iterations", dest="epochs", type=int, default=2)
parser.add_argument('-n', '--neurons', help="set the number of hidden neurons", dest="hiddenNeurons", type=int, default=30)
# The choices come from generatorAssociation so that the accepted names and
# the lookup table cannot drift apart.
parser.add_argument('-g', '--generator', help="choose distribution function for network creation", dest="generatorName", default="gaussAdaptedDev", choices=list(generatorAssociation))
args = parser.parse_args()

# Parameters
print("Parameters")
print("Learning rate :", args.learnRate)
print("Number of epochs :", args.epochs)
# The value is the hidden-neuron count (-n/--neurons), not a layer count.
print("Number of hidden neurons :", args.hiddenNeurons)
print("Generator :", args.generatorName)

# Activation function: the logistic sigmoid provided by scipy.
activation = expit


def activationDerivative(x):
    """Return the derivative of the logistic sigmoid evaluated at x."""
    sigmoid = expit(x)
    return sigmoid * (1 - sigmoid)


# argparse has already restricted generatorName to the keys of the table.
generatorFunction = generatorAssociation[args.generatorName]

# Session
print("Start session")
print("... generating neural network ...")
network = generator.generate(activation, activationDerivative, hiddenLength=args.hiddenNeurons, weightGenerator=generatorFunction)
print("... compute precision before training ... ")
precisionBefore = benchmark.computePrecision(network)
print("... train the network (can take a while) ... ")
network = trainer.train(network, args.learnRate, args.epochs)
print("... compute precision after training ... ")
precisionAfter = benchmark.computePrecision(network)
print("End of session")

# Display results
print("Precision before training : ", precisionBefore)
print("Precision after training : ", precisionAfter)