Browse Source

Add argument for network generation

DricomDragon 5 years ago
parent
commit
0cc0163c6e
1 changed files with 18 additions and 4 deletions
  1. 18 4
      python/start_session.py

+ 18 - 4
python/start_session.py

@@ -10,27 +10,41 @@ import argparse
 # Import local code to call
 from lab import generator, trainer, benchmark
 
+from lab.generator import uniform, gaussAdaptedDev, gaussUnitDev
+
 # Parse script arguments
 parser = argparse.ArgumentParser(description = "Start a training session for a neuronal network and display results.")
 parser.add_argument('-a, --alpha', help = "set the learning rate", dest="learnRate", type = float, default = 0.05)
 parser.add_argument('-e, --epochs', help = "set the number of iterations", dest="epochs", type = int, default = 2)
 parser.add_argument('-n, --neurons', help = "set the number of hidden neurons", dest="hiddenNeurons", type = int, default = 30)
+parser.add_argument('-g, --generator', help = "choose distribution function for network creation", dest="generatorName", default = "gaussAdaptedDev", choices=["none", "uniform", "gaussAdaptedDev", "gaussUnitDev"])
 
 args = parser.parse_args()
 
 # Parameters
 print("Parameters")
-print("Learning rate : ", args.learnRate)
-print("Number of epochs : ", args.epochs)
-print("Number of hidden layers : ", args.hiddenNeurons)
+print("Learning rate :", args.learnRate)
+print("Number of epochs :", args.epochs)
+print("Number of hidden layers :", args.hiddenNeurons)
+print("Generator :", args.generatorName)
 
 activation = expit
 activationDerivative = lambda x : expit(x) * (1 - expit(x))
 
+generatorAssociation = {
+	"none" : None,
+	"uniform" : uniform,
+	"gaussAdaptedDev" : gaussAdaptedDev,
+	"gaussUnitDev" : gaussUnitDev
+}
+
+generatorFunction = generatorAssociation[args.generatorName]
+
+
 # Session
 print("Start session")
 print("... generating neural network ...")
-network = generator.generate(activation, activationDerivative, hiddenLength = args.hiddenNeurons, weightGenerator = generator.gaussAdaptedDev)
+network = generator.generate(activation, activationDerivative, hiddenLength = args.hiddenNeurons, weightGenerator = generatorFunction)
 
 print("... compute precision before training ... ")
 precisionBefore = benchmark.computePrecision(network)