# Source: MachineLearningNotebooks/how-to-use-azureml/training-with-deep-learning/
# distributed-chainer/chainer_mnist_distributed.py (version 1.0.17)

import argparse

import chainer
import chainer.cuda
import chainer.functions as F
import chainer.links as L
from chainer import training
from chainer.training import extensions

import chainermn
import chainermn.datasets
import chainermn.functions

chainer.disable_experimental_feature_warning = True
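

# This script builds a two-stage pipeline: the first layer of the MLP lives on
# one worker (MLP0) and the remaining layers on its partner (MLP1). The
# rank_in/rank_out wiring of chainermn.MultiNodeChainList below moves
# activations forward (and gradients backward) between the two stages.
#
# ChainerMN jobs like this one are typically launched under MPI with an even
# number of processes, e.g.:
#
#     mpiexec -n 4 python chainer_mnist_distributed.py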
class MLP0SubA(chainer.Chain):
    def __init__(self, comm, n_out):
        super(MLP0SubA, self).__init__(
            l1=L.Linear(784, n_out))

    def __call__(self, x):
        return F.relu(self.l1(x))


class MLP0SubB(chainer.Chain):
    def __init__(self, comm):
        super(MLP0SubB, self).__init__()

    def __call__(self, y):
        return y


class MLP0(chainermn.MultiNodeChainList):
    # Model on worker 0.
    def __init__(self, comm, n_out):
        super(MLP0, self).__init__(comm=comm)
        self.add_link(MLP0SubA(comm, n_out), rank_in=None, rank_out=1)
        self.add_link(MLP0SubB(comm), rank_in=1, rank_out=None)
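

# The second pipeline stage: MLP1Sub receives the hidden activation h0 from
# model_comm rank 0, applies two more fully connected layers, and sends the
# result back to rank 0 (rank_in=0, rank_out=0), where the loss is computed.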
class MLP1Sub(chainer.Chain):
    def __init__(self, n_units, n_out):
        super(MLP1Sub, self).__init__(
            l2=L.Linear(None, n_units),
            l3=L.Linear(None, n_out))

    def __call__(self, h0):
        h1 = F.relu(self.l2(h0))
        return self.l3(h1)


class MLP1(chainermn.MultiNodeChainList):
    # Model on worker 1.
    def __init__(self, comm, n_units, n_out):
        super(MLP1, self).__init__(comm=comm)
        self.add_link(MLP1Sub(n_units, n_out), rank_in=0, rank_out=0)


def main():
    parser = argparse.ArgumentParser(
        description='ChainerMN example: pipelined neural network')
    parser.add_argument('--batchsize', '-b', type=int, default=100,
                        help='Number of images in each mini-batch')
    parser.add_argument('--epoch', '-e', type=int, default=20,
                        help='Number of sweeps over the dataset to train')
    parser.add_argument('--gpu', '-g', action='store_true',
                        help='Use GPU')
    parser.add_argument('--out', '-o', default='result',
                        help='Directory to output the result')
    parser.add_argument('--unit', '-u', type=int, default=1000,
                        help='Number of units')
    args = parser.parse_args()
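
    # The global communicator is viewed as a 2 x (size / 2) grid:
    # data_axis (rank % 2) is this process's position in the model pipeline,
    # and model_axis (rank // 2) is which data-parallel replica it belongs to.
    # comm.split(color, key) groups processes that share a color, so data_comm
    # links the same pipeline stage across replicas (for gradient averaging)
    # and model_comm links the two stages of a single replica.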
    # Prepare ChainerMN communicator.
    if args.gpu:
        comm = chainermn.create_communicator('hierarchical')
        data_axis, model_axis = comm.rank % 2, comm.rank // 2
        data_comm = comm.split(data_axis, comm.rank)
        model_comm = comm.split(model_axis, comm.rank)
        device = comm.intra_rank
    else:
        comm = chainermn.create_communicator('naive')
        data_axis, model_axis = comm.rank % 2, comm.rank // 2
        data_comm = comm.split(data_axis, comm.rank)
        model_comm = comm.split(model_axis, comm.rank)
        device = -1

    if model_comm.size != 2:
        raise ValueError(
            'This example can only be executed on an even number '
            'of processes.')

    if comm.rank == 0:
        print('==========================================')
        if args.gpu:
            print('Using GPUs')
        print('Num unit: {}'.format(args.unit))
        print('Num Minibatch-size: {}'.format(args.batchsize))
        print('Num epoch: {}'.format(args.epoch))
        print('==========================================')
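
    # Each process instantiates only its own half of the pipelined model.
    # Pipeline position 0 hosts MLP0, wrapped in L.Classifier because the
    # final output (and hence the loss) arrives back on this side; position 1
    # hosts MLP1 and never sees images or labels directly.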
    if data_axis == 0:
        model = L.Classifier(MLP0(model_comm, args.unit))
    elif data_axis == 1:
        model = MLP1(model_comm, args.unit, 10)

    if device >= 0:
        chainer.cuda.get_device_from_id(device).use()
        model.to_gpu()

    optimizer = chainermn.create_multi_node_optimizer(
        chainer.optimizers.Adam(), data_comm)
    optimizer.setup(model)
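
    # create_multi_node_optimizer wraps Adam so that gradients are
    # all-reduced across data_comm before each update, keeping the replicas
    # of this pipeline stage in sync.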

    # Only the first replica (model_axis == 0) loads MNIST; scatter_dataset
    # then distributes it across each data-parallel group. Processes at
    # pipeline position 1 receive activations rather than images, so they
    # train on empty placeholder datasets of the same length.
    if model_axis == 0:
        train, test = chainer.datasets.get_mnist()
        if data_axis == 1:
            train = chainermn.datasets.create_empty_dataset(train)
            test = chainermn.datasets.create_empty_dataset(test)
    else:
        train, test = None, None
    train = chainermn.scatter_dataset(train, data_comm, shuffle=True)
    test = chainermn.scatter_dataset(test, data_comm, shuffle=True)

    train_iter = chainer.iterators.SerialIterator(
        train, args.batchsize, shuffle=False)
    test_iter = chainer.iterators.SerialIterator(
        test, args.batchsize, repeat=False, shuffle=False)
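
    # shuffle=False keeps every process drawing batches in the same
    # deterministic order, which the two pipeline stages rely on to stay
    # aligned; the data was already shuffled by scatter_dataset above.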

    updater = training.StandardUpdater(train_iter, optimizer, device=device)
    trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)
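
    # Validation runs through the same pipeline: each member of data_comm
    # evaluates its own shard of the test set, and the multi-node evaluator
    # aggregates the results across the group.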
    evaluator = extensions.Evaluator(test_iter, model, device=device)
    evaluator = chainermn.create_multi_node_evaluator(evaluator, data_comm)
    trainer.extend(evaluator)

    # Some display and output extensions are necessary only for worker 0.
    if comm.rank == 0:
        trainer.extend(extensions.LogReport())
        trainer.extend(extensions.PrintReport(
            ['epoch', 'main/loss', 'validation/main/loss',
             'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
        trainer.extend(extensions.ProgressBar())

    trainer.run()


if __name__ == '__main__':
    main()