# start_session.py (748 B)
  1. #!/usr/bin/python3
  2. # File designed to be launched by user
  3. print("Start session")
  4. # Import python libraries
  5. from scipy.special import expit
  6. # Import local code to call
  7. from lab import generator, trainer, benchmark
  8. # Parameters
  9. learnRate = 0.05
  10. activation = expit
  11. activationDerivative = lambda x : expit(x) * (1 - expit(x))
  12. epochs = 10
  13. # Session
  14. newNetwork = generator.generate(activation, activationDerivative, generator.gaussUnitDev)
  15. trainedNetwork = trainer.train(newNetwork, learnRate, epochs)
  16. precisionBefore = benchmark.computePrecision(newNetwork)
  17. precisionAfter = benchmark.computePrecision(trainedNetwork)
  18. # Display results
  19. print("Precision before training : ", precisionBefore)
  20. print("Precision after training : ", precisionAfter)