#!/usr/bin/python3

# Entry-point script, meant to be run directly by the user.
# The banner is printed before the imports below, presumably so the user gets
# feedback while scipy / the local lab package load. flush=True makes it show
# up immediately even when stdout is piped or redirected (block-buffered).
print("Start script", flush=True)

# Import python libraries
from scipy.special import expit
import argparse

# Import local code to call
from lab import generator, trainer, benchmark

# Parse script arguments
parser = argparse.ArgumentParser(
    description="Start a training session for a neuronal network and display results."
)
parser.add_argument(
    "--alpha",
    help="set the learning rate",
    dest="learnRate",
    type=float,
    default=0.05,
)
parser.add_argument(
    "--epochs",
    help="set the number of iterations",
    dest="epochs",
    type=int,
    default=2,
)

args = parser.parse_args()

# argparse only checks that the values parse as float / int; reject values
# that cannot describe a training run. The chained comparison is False for
# NaN as well, and the upper bound rejects "inf".
if not 0 < args.learnRate < float("inf"):
    parser.error("--alpha must be a finite positive number, got %r" % args.learnRate)
if args.epochs < 0:
    parser.error("--epochs must be a non-negative integer, got %d" % args.epochs)

# Parameters: echo the effective hyper-parameters so the run's output
# records which configuration produced the results printed at the end.
print("Parameters")
for caption, setting in (
    ("Learning rate : ", args.learnRate),
    ("Number of epochs : ", args.epochs),
):
    print(caption, setting)

# Neuron activation: the logistic sigmoid, 1 / (1 + exp(-x)).
activation = expit


def activationDerivative(x):
    """Return the derivative of the logistic sigmoid evaluated at ``x``.

    Uses the identity sigma'(x) = sigma(x) * (1 - sigma(x)), so ``expit``
    is evaluated only once per call. Like ``expit`` (a NumPy ufunc), this
    accepts scalars or arrays and operates element-wise.
    """
    s = expit(x)
    return s * (1 - s)

# Session: build a network, benchmark it, train it, benchmark it again.
# The progress messages are flushed so they appear before each slow step
# even when stdout is piped or redirected (block-buffered).
print("Start session")
print("... generating neural network ...", flush=True)
network = generator.generate(activation, activationDerivative, generator.gaussAdaptedDev)

print("... compute precision before training ... ", flush=True)
precisionBefore = benchmark.computePrecision(network)

print("... train the network (can take a while) ... ", flush=True)
# Rebind to whatever trainer.train returns so the second benchmark measures
# the trained network (whether train also mutates in place is not visible here).
network = trainer.train(network, args.learnRate, args.epochs)

print("... compute precision after training ... ", flush=True)
precisionAfter = benchmark.computePrecision(network)

print("End of session")

# Display results
print("Precision before training : ", precisionBefore)
print("Precision after training : ", precisionAfter)