|
@@ -2,20 +2,17 @@
|
|
|
|
|
|
import numpy as np
|
|
|
from lab import neural, io_mnist
|
|
|
-from copy import copy
|
|
|
|
|
|
-def train(inputNetwork, learnRate = 0.05, epochs = 2, batchSize = 10):
|
|
|
+def train(net, learnRate = 0.05, epochs = 2, batchSize = 10):
|
|
|
"""
|
|
|
Create an improved network
|
|
|
|
|
|
- inputNetwork : the network to be trained
|
|
|
+ net : the network to be trained
|
|
|
learnRate : speed of training, lower is slower but more precise
|
|
|
epochs : the number of iterations
|
|
|
|
|
|
- return : a trained copy with improved performance
|
|
|
+ return : the trained network
|
|
|
"""
|
|
|
- net = copy(inputNetwork)
|
|
|
-
|
|
|
# Load data
|
|
|
np_images, np_expected = io_mnist.load_training_samples()
|
|
|
|