|
@@ -12,10 +12,11 @@ from lab import generator, trainer, benchmark
|
|
|
|
|
|
learnRate = 0.05
|
|
|
activation = expit
|
|
|
+activationDerivative = lambda x : expit(x) * (1 - expit(x))
|
|
|
epochs = 10
|
|
|
|
|
|
|
|
|
-newNetwork = generator.generate(activation, generator.gaussUnitDev)
|
|
|
+newNetwork = generator.generate(activation, activationDerivative, generator.gaussUnitDev)
|
|
|
|
|
|
trainedNetwork = trainer.train(newNetwork, learnRate, epochs)
|
|
|
|