#!/usr/bin/python3
"""start_session.py -- script designed to be launched directly by the user.

Runs one session: builds a network through ``lab.generator``, measures its
precision with ``lab.benchmark``, trains it with ``lab.trainer`` and measures
the precision again, then prints both values.
"""
# The banner is printed before the imports below, exactly as in the original
# script, so it is the first thing the user sees.
print("Start session")

# Import python libraries
from scipy.special import expit
import numpy as np

# Import local code to call
from lab import generator, trainer, benchmark

# Parameters
learnRate = 0.05
activation = expit  # logistic sigmoid from scipy.special
epochs = 1


def activationDerivative(x):
    """Return the derivative of the logistic sigmoid evaluated at ``x``.

    Uses the identity sigma'(x) = sigma(x) * (1 - sigma(x)). The sigmoid is
    evaluated once and reused instead of being computed twice per call.
    """
    sigmoid = expit(x)
    return sigmoid * (1 - sigmoid)


# Session
# NOTE(review): np.random.normal is handed to the generator as a callable;
# presumably it is used to draw the initial weights -- confirm in lab.generator.
network = generator.generate(activation, activationDerivative, np.random.normal)
precisionBefore = benchmark.computePrecision(network)
network = trainer.train(network, learnRate, epochs)
precisionAfter = benchmark.computePrecision(network)

# Display results
print("Precision before training : ", precisionBefore)
print("Precision after training : ", precisionAfter)