Quellcode durchsuchen

Add argument parser for learning rate and epochs

DricomDragon vor 5 Jahren
Ursprung
Commit
95f45be5db
1 geänderte Dateien mit 13 neuen und 3 gelöschten Zeilen
  1. 13 3
      python/start_session.py

+ 13 - 3
python/start_session.py

@@ -5,22 +5,32 @@ print("Start session")
 
 # Import python libraries
 from scipy.special import expit
+import argparse
 
 # Import local code to call
 from lab import generator, trainer, benchmark
 
+# Parse script arguments
+parser = argparse.ArgumentParser(description = "Start a training session for a neuronal network and display results.")
+parser.add_argument('--alpha', help = "set the learning rate", dest="learnRate", type = float, default = 0.05)
+parser.add_argument('--epochs', help = "set the number of iterations", dest="epochs", type = int, default = 2)
+
+args = parser.parse_args()
+
 # Parameters
-learnRate = 0.05
+print("Parameters")
+print("Learning rate : ", args.learnRate)
+print("Number of epochs : ", args.epochs)
+
 activation = expit
 activationDerivative = lambda x : expit(x) * (1 - expit(x))
-epochs = 10
 
 # Session
 network = generator.generate(activation, activationDerivative, generator.gaussAdaptedDev)
 
 precisionBefore = benchmark.computePrecision(network)
 
-network = trainer.train(network, learnRate, epochs)
+network = trainer.train(network, args.learnRate, args.epochs)
 
 precisionAfter = benchmark.computePrecision(network)