import chainer
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions


# Network definition
class CNN(chainer.Chain):
    """Two-layer fully connected classifier for MNIST.

    Despite the class name, this network has no convolutional layers: it is a
    ``Linear(100) -> sigmoid -> Linear(10)`` stack that returns 10
    unnormalized class scores per example.

    Args:
        train: Ignored. Kept only so existing ``CNN(train=...)`` calls keep
            working.
    """

    def __init__(self, train=True):
        super(CNN, self).__init__()
        # Register child links through init_scope (Chainer v2+); passing them
        # as keyword arguments to Chain.__init__ is the deprecated v1 style.
        with self.init_scope():
            # in_size=None defers the input size until the first forward
            # pass, so the layer adapts to whatever feature count x has.
            self.l1 = L.Linear(None, 100)  # n_in -> n_units
            self.l2 = L.Linear(None, 10)  # n_units -> n_out (10 classes)

    def __call__(self, x):
        """Return the class scores (pre-softmax) for the batch ``x``."""
        h = F.sigmoid(self.l1(x))
        h = self.l2(h)
        return h


# Load the MNIST dataset; ndim=3 yields images shaped (1, 28, 28), which
# L.Linear flattens internally.
train, test = chainer.datasets.get_mnist(ndim=3)

# Set up a neural network model and an optimizer. L.Classifier wraps the
# predictor with a loss function and reports loss/accuracy.
model = L.Classifier(CNN())
optimizer = chainer.optimizers.Adam()
optimizer.setup(model)

# Training iterator repeats and shuffles; the test iterator makes a single
# ordered pass per evaluation.
train_iter = chainer.iterators.SerialIterator(train, batch_size=100)
test_iter = chainer.iterators.SerialIterator(
    test, batch_size=100, repeat=False, shuffle=False)

# Set up a trainer for 5 epochs, writing logs under ./result.
updater = training.StandardUpdater(train_iter, optimizer)
trainer = training.Trainer(updater, (5, 'epoch'), out='result')
trainer.extend(extensions.Evaluator(test_iter, model))
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(
    ['epoch', 'main/loss', 'validation/main/loss',
     'main/accuracy', 'validation/main/accuracy']))

# Run the training
trainer.run()