#!/usr/bin/python3
"""Start a training session for a neural network and display its results.

This file is designed to be launched by the user from the command line, e.g.::

    python3 <this file> --alpha 0.05 --epochs 2 --neurons 30 --generator gaussAdaptedDev

It builds a network with ``lab.generator.generate``, measures its precision
with ``lab.benchmark.computePrecision`` before and after training it with
``lab.trainer.train``, and prints both values.
"""
import argparse

# Values accepted by --generator. main() maps each name to a weight generator
# from lab.generator; "none" passes None as the weightGenerator argument.
GENERATOR_NAMES = ("none", "uniform", "gaussAdaptedDev", "gaussUnitDev")


def parseArguments(argv=None):
    """Parse the command-line options of a training session.

    Args:
        argv: list of argument strings to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with the attributes ``learnRate`` (float),
        ``epochs`` (int), ``hiddenNeurons`` (int) and ``generatorName``
        (one of GENERATOR_NAMES).

    Raises:
        SystemExit: on invalid arguments or ``-h``/``--help`` (argparse behavior).
    """
    parser = argparse.ArgumentParser(description = "Start a training session for a neuronal network and display results.")
    # Short and long flags must be passed as separate strings: a single
    # "-a, --alpha" string declares one oddly-named option, so "--alpha"
    # would be rejected as unrecognized.
    parser.add_argument("-a", "--alpha", help="set the learning rate",
                        dest="learnRate", type=float, default=0.05)
    parser.add_argument("-e", "--epochs", help="set the number of iterations",
                        dest="epochs", type=int, default=2)
    parser.add_argument("-n", "--neurons", help="set the number of hidden neurons",
                        dest="hiddenNeurons", type=int, default=30)
    parser.add_argument("-g", "--generator",
                        help="choose distribution function for network creation",
                        dest="generatorName", default="gaussAdaptedDev",
                        choices=GENERATOR_NAMES)
    return parser.parse_args(argv)


def main(argv=None):
    """Run one training session and print the precision before and after training.

    Args:
        argv: list of argument strings to parse; ``None`` means ``sys.argv[1:]``.
    """
    print("Start script")
    # Imported here, after the banner, so "Start script" is printed before
    # scipy and the lab package are loaded.
    from scipy.special import expit
    from lab import generator, trainer, benchmark
    from lab.generator import uniform, gaussAdaptedDev, gaussUnitDev

    args = parseArguments(argv)

    # Parameters
    print("Parameters")
    print("Learning rate :", args.learnRate)
    print("Number of epochs :", args.epochs)
    print("Number of hidden neurons :", args.hiddenNeurons)
    print("Generator :", args.generatorName)

    # Logistic sigmoid as activation function.
    activation = expit

    def activationDerivative(x):
        """Return the derivative of the logistic sigmoid: s(x) * (1 - s(x))."""
        sigmoid = expit(x)  # evaluated once instead of twice
        return sigmoid * (1 - sigmoid)

    generatorAssociation = {
        "none": None,
        "uniform": uniform,
        "gaussAdaptedDev": gaussAdaptedDev,
        "gaussUnitDev": gaussUnitDev,
    }
    # The argparse "choices" restriction guarantees the key is present.
    generatorFunction = generatorAssociation[args.generatorName]

    # Session
    print("Start session")
    print("... generating neural network ...")
    network = generator.generate(activation, activationDerivative,
                                 hiddenLength=args.hiddenNeurons,
                                 weightGenerator=generatorFunction)
    print("... compute precision before training ... ")
    precisionBefore = benchmark.computePrecision(network)
    print("... train the network (can take a while) ... ")
    network = trainer.train(network, args.learnRate, args.epochs)
    print("... compute precision after training ... ")
    precisionAfter = benchmark.computePrecision(network)
    print("End of session")

    # Display results
    print("Precision before training : ", precisionBefore)
    print("Precision after training : ", precisionAfter)


if __name__ == "__main__":
    main()