Quellcode durchsuchen

Make weight generator more user-friendly

DricomDragon vor 5 Jahren
Ursprung
Commit
eedb7117ff
2 geänderte Dateien mit 16 neuen und 4 gelöschten Zeilen
  1. 15 2
      python/lab/generator.py
  2. 1 2
      python/start_session.py

+ 15 - 2
python/lab/generator.py

@@ -1,6 +1,19 @@
 # Generate neural network
 
 from lab import neural
+import numpy as np
+
+# Random generators
+def uniform(layer):
+	# TODO
+	return layer
+
+def gaussUnitDev(layer):
+	return np.random.normal(size = layer.shape)
+
+def gaussAdaptedDev():
+	# TODO
+	return layer
 
 # Network weight initialization
 def generate(activation, derivative, weightGenerator = None):
@@ -11,7 +24,7 @@ def generate(activation, derivative, weightGenerator = None):
 	net = neural.Network(activation, derivative)
 
 	if (weightGenerator is not None):
-		net.layer1 = weightGenerator(size = net.layer1.shape)
-		net.layer2 = weightGenerator(size = net.layer2.shape)
+		net.layer1 = weightGenerator(net.layer1)
+		net.layer2 = weightGenerator(net.layer2)
 
 	return net

+ 1 - 2
python/start_session.py

@@ -5,7 +5,6 @@ print("Start session")
 
 # Import python libraries
 from scipy.special import expit
-import numpy as np
 
 # Import local code to call
 from lab import generator, trainer, benchmark
@@ -17,7 +16,7 @@ activationDerivative = lambda x : expit(x) * (1 - expit(x))
 epochs = 1
 
 # Session
-network = generator.generate(activation, activationDerivative, np.random.normal)
+network = generator.generate(activation, activationDerivative, generator.gaussUnitDev)
 
 precisionBefore = benchmark.computePrecision(network)