浏览代码

Create neural network structure

DricomDragon 5 年之前
父节点
当前提交
5f7d7872f6
共有 3 个文件被更改,包括 19 次插入6 次删除
  1. 2 2
      python/lab/generator.py
  2. 15 3
      python/lab/neural.py
  3. 2 1
      python/start_session.py

+ 2 - 2
python/lab/generator.py

@@ -15,7 +15,7 @@ def gaussUnitDev():
 	return 0
 
 # Network weight initialization
-def generate(activation, weightGenerator = flat):
+def generate(activation, derivative, weightGenerator = flat):
 	# TODO
-	return neural.Network()
+	return neural.Network(activation, derivative)
 

+ 15 - 3
python/lab/neural.py

@@ -1,7 +1,19 @@
 # Hold the network attributes
 
+import numpy as np
+
 class Network():
-	def __init__(self):
-		# TODO
-		print("Created")
+
+	def __init__(self, activationFunction, activationDerivative):
+		self.inputLength = 784
+		self.hiddenLength = 30
+		self.outputLength = 10
+		self.activationFunction = activationFunction
+		self.activationDerivative = activationDerivative
+
+		self.layer1 = np.zeros((self.hiddenLength, self.inputLength))
+		self.bias1 = np.zeros(self.hiddenLength)
+
+		self.layer2 = np.zeros((self.inputLength, self.hiddenLength))
+		self.bias2 = np.zeros(self.inputLength)
 

+ 2 - 1
python/start_session.py

@@ -12,10 +12,11 @@ from lab import generator, trainer, benchmark
 # Parameters
 learnRate = 0.05
 activation = expit
+activationDerivative = lambda x : expit(x) * (1 - expit(x))
 epochs = 10
 
 # Session
-newNetwork = generator.generate(activation, generator.gaussUnitDev)
+newNetwork = generator.generate(activation, activationDerivative, generator.gaussUnitDev)
 
 trainedNetwork = trainer.train(newNetwork, learnRate, epochs)