#!/usr/bin/python3
"""Start a training session for a neural network and display results.

File designed to be launched by the user from the command line, e.g.::

    ./script.py --alpha 0.1 --epochs 5 --neurons 50 --generator uniform

The session builds a network with ``lab.generator.generate``, measures its
precision with ``lab.benchmark.computePrecision``, trains it with
``lab.trainer.train`` and measures the precision again.
"""

import argparse

# Accepted values for --generator. "none" passes weightGenerator=None to
# lab.generator.generate; the other names map to functions of lab.generator.
GENERATOR_NAMES = ["none", "uniform", "gaussAdaptedDev", "gaussUnitDev"]


def buildParser():
    """Return the command-line parser for a training session.

    Short and long spellings of each option are passed as *separate*
    strings: a single string such as ``'-a, --alpha'`` would register one
    option literally named ``-a, --alpha`` that ``--alpha`` never matches.
    """
    parser = argparse.ArgumentParser(
        description="Start a training session for a neuronal network and display results.")
    parser.add_argument('-a', '--alpha', help="set the learning rate",
                        dest="learnRate", type=float, default=0.05)
    parser.add_argument('-e', '--epochs', help="set the number of iterations",
                        dest="epochs", type=int, default=2)
    parser.add_argument('-n', '--neurons', help="set the number of hidden neurons",
                        dest="hiddenNeurons", type=int, default=30)
    parser.add_argument('-g', '--generator',
                        help="choose distribution function for network creation",
                        dest="generatorName", default="gaussAdaptedDev",
                        choices=GENERATOR_NAMES)
    return parser


def main(argv=None):
    """Parse arguments, run one training session and print the results.

    :param argv: argument list to parse instead of ``sys.argv[1:]``;
        ``None`` (the default) reads the real command line.
    """
    print("Start script")

    # Imports are deferred until after the banner, keeping the original
    # script's order (presumably so the user sees output before potentially
    # slow imports). Keeping them local also means importing this module
    # has no side effects.
    from scipy.special import expit
    from lab import generator, trainer, benchmark
    from lab.generator import uniform, gaussAdaptedDev, gaussUnitDev

    args = buildParser().parse_args(argv)

    # Parameters
    print("Parameters")
    print("Learning rate :", args.learnRate)
    print("Number of epochs :", args.epochs)
    print("Number of hidden neurons :", args.hiddenNeurons)
    print("Generator :", args.generatorName)

    activation = expit

    def activationDerivative(x):
        """Derivative of the logistic sigmoid: s(x) * (1 - s(x))."""
        sigmoid = expit(x)  # evaluate once rather than twice
        return sigmoid * (1 - sigmoid)

    # Keys must stay in sync with GENERATOR_NAMES (argparse enforces choices).
    generatorAssociation = {
        "none": None,
        "uniform": uniform,
        "gaussAdaptedDev": gaussAdaptedDev,
        "gaussUnitDev": gaussUnitDev,
    }
    generatorFunction = generatorAssociation[args.generatorName]

    # Session
    print("Start session")
    print("... generating neural network ...")
    network = generator.generate(activation, activationDerivative,
                                 hiddenLength=args.hiddenNeurons,
                                 weightGenerator=generatorFunction)
    print("... compute precision before training ... ")
    precisionBefore = benchmark.computePrecision(network)
    print("... train the network (can take a while) ... ")
    network = trainer.train(network, args.learnRate, args.epochs)
    print("... compute precision after training ... ")
    precisionAfter = benchmark.computePrecision(network)
    print("End of session")

    # Display results
    print("Precision before training : ", precisionBefore)
    print("Precision after training : ", precisionAfter)


if __name__ == "__main__":
    main()