Przeglądaj źródła

Change returned label value as numpy array

Can be more easily exploited by the trainer.
DricomDragon 5 lat temu
rodzic
commit
172a3723aa
1 zmienionych plików z 8 dodań i 3 usunięć
  1. 8 3
      python/lab/io_mnist.py

+ 8 - 3
python/lab/io_mnist.py

@@ -3,10 +3,10 @@ import numpy as np
 
 def load_training_samples():
 	"""
-	Return np_images, labels
+	Return np_images, np_expected
 	where
 	np_impages is a np.array of 60 000 x 784
-	labels is a python list of expected answers
+	np_expected is a np.array of 60 000 x 10
 	"""
 	mndata = MNIST('../../resources/download')
 
@@ -20,4 +20,9 @@ def load_training_samples():
 	# Normalize data between 0.0 and 1.0
 	np_images /= 255
 
-	return np_images, labels
+	# Contstruct expected outputs
+	np_expected = np.zeros((len(labels), 10))
+	for k, label in enumerate(labels):
+		np_expected[k][label] = 1.0
+
+	return np_images, np_expected