|
@@ -0,0 +1,28 @@
|
|
|
+#!/usr/bin/python3
|
|
|
+
|
|
|
+# File designed to be launched by user
|
|
|
+print("Start session")
|
|
|
+
|
|
|
+# Import python libraries
|
|
|
+from scipy.special import expit
|
|
|
+
|
|
|
+# Import local code to call
|
|
|
+from lab import generator, trainer, benchmark
|
|
|
+
|
|
|
+# Parameters
|
|
|
+learnRate = 0.05
|
|
|
+activation = expit
|
|
|
+epochs = 10
|
|
|
+
|
|
|
+# Session
|
|
|
+newNetwork = generator.generate(activation, generator.gaussUnitDev)
|
|
|
+
|
|
|
+trainedNetwork = trainer.train(newNetwork, learnRate, epochs)
|
|
|
+
|
|
|
+precisionBefore = benchmark.computePrecision(newNetwork)
|
|
|
+precisionAfter = benchmark.computePrecision(trainedNetwork)
|
|
|
+
|
|
|
+# Display results
|
|
|
+print("Precision before training : ", precisionBefore)
|
|
|
+print("Precision after training : ", precisionAfter)
|
|
|
+
|