Quellcode durchsuchen

Fix missing updates during training

Weights and bias were not updated during training
DricomDragon vor 5 Jahren
Ursprung
Commit
a86ce67ccd
1 geänderte Dateien mit 7 neuen und 6 gelöschten Zeilen
  1. 7 6
      python/lab/trainer.py

+ 7 - 6
python/lab/trainer.py

@@ -22,15 +22,9 @@ def train(inputNetwork, learnRate, epochs, batchSize = 10):
 	# Prepare variables
 	a0 = np.empty((net.inputLength, batchSize))
 
-	w1 = net.layer1 # reference
-	b1 =  np.stack([net.bias1] * batchSize).transpose() # stack
-
 	z1 = np.empty((net.hiddenLength, batchSize))
 	a1 = np.empty((net.hiddenLength, batchSize))
 
-	w2 = net.layer2 # reference
-	b2 = np.stack([net.bias2] * batchSize).transpose() # stack
-
 	z2 = np.empty((net.outputLength, batchSize))
 	a2 = np.empty((net.outputLength, batchSize))
 	
@@ -50,6 +44,13 @@ def train(inputNetwork, learnRate, epochs, batchSize = 10):
 
 		# Iterate over batches
 		for batchIndex in range(0, nbSamples, batchSize):
+			# Update modified weights
+			w1 = net.layer1
+			b1 = np.stack([net.bias1] * batchSize).transpose() # stack
+
+			w2 = net.layer2
+			b2 = np.stack([net.bias2] * batchSize).transpose() # stack
+
 			# Capture batch
 			batchEndIndex = batchIndex + batchSize
 			batchSelection = permut[batchIndex : batchEndIndex]