Browse Source

Add the argument of number hidden neurons

DricomDragon 5 years ago
parent
commit
ce244e1f15
2 changed files with 6 additions and 3 deletions
  1. 3 2
      python/lab/generator.py
  2. 3 1
      python/start_session.py

+ 3 - 2
python/lab/generator.py

@@ -29,17 +29,18 @@ def gaussAdaptedDev(layer):
 	return np.random.normal(scale = stdDev, size = layer.shape)
 
 # Network weight initialization
-def generate(activation, derivative, weightGenerator = None):
+def generate(activation, derivative, hiddenLength = 30, weightGenerator = None):
 	"""
 	activation : function used on network outputs
 	derivative : the derivative of the activation function
+	hiddenLength : number of neurons in the hidden layer
 	weightGenerator : is one of
 	None
 	generator.uniform
 	generator.gaussUnitDev
 	generator.gaussAdaptedDev
 	"""
-	net = neural.Network(activation, derivative)
+	net = neural.Network(activation, derivative, hiddenLength)
 
 	if (weightGenerator is not None):
 		net.layer1 = weightGenerator(net.layer1)

+ 3 - 1
python/start_session.py

@@ -14,6 +14,7 @@ from lab import generator, trainer, benchmark
 parser = argparse.ArgumentParser(description = "Start a training session for a neuronal network and display results.")
 parser.add_argument('-a, --alpha', help = "set the learning rate", dest="learnRate", type = float, default = 0.05)
 parser.add_argument('-e, --epochs', help = "set the number of iterations", dest="epochs", type = int, default = 2)
+parser.add_argument('-n, --neurons', help = "set the number of hidden neurons", dest="hiddenNeurons", type = int, default = 30)
 
 args = parser.parse_args()
 
@@ -21,6 +22,7 @@ args = parser.parse_args()
 print("Parameters")
 print("Learning rate : ", args.learnRate)
 print("Number of epochs : ", args.epochs)
+print("Number of hidden layers : ", args.hiddenNeurons)
 
 activation = expit
 activationDerivative = lambda x : expit(x) * (1 - expit(x))
@@ -28,7 +30,7 @@ activationDerivative = lambda x : expit(x) * (1 - expit(x))
 # Session
 print("Start session")
 print("... generating neural network ...")
-network = generator.generate(activation, activationDerivative, generator.gaussAdaptedDev)
+network = generator.generate(activation, activationDerivative, hiddenLength = args.hiddenNeurons, weightGenerator = generator.gaussAdaptedDev)
 
 print("... compute precision before training ... ")
 precisionBefore = benchmark.computePrecision(network)