# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
# Licensed under the Apache License, Version 2.0
# Script adapted from:
# https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/dist_test/python/mnist_replica.py
# ==============================================================================
"""Distributed MNIST training and validation, with model replicas.
|
|
A simple softmax model with one hidden layer is defined. The parameters
|
|
(weights and biases) are located on one parameter server (ps), while the ops
|
|
are executed on two worker nodes by default. The TF sessions also run on the
|
|
worker node.
|
|
Multiple invocations of this script can be done in parallel, with different
|
|
values for --task_index. There should be exactly one invocation with
|
|
--task_index, which will create a master session that carries out variable
|
|
initialization. The other, non-master, sessions will wait for the master
|
|
session to finish the initialization before proceeding to the training stage.
|
|
The coordination between the multiple worker invocations occurs due to
|
|
the definition of the parameters on the same ps devices. The parameter updates
|
|
from one worker is visible to all other workers. As such, the workers can
|
|
perform forward computation and gradient calculation in parallel, which
|
|
should lead to increased training speed for the simple model.
|
|
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import math
import sys
import tempfile
import time
import json

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from azureml.core.run import Run

flags = tf.app.flags
flags.DEFINE_string("data_dir", "/tmp/mnist-data",
                    "Directory for storing mnist data")
flags.DEFINE_boolean("download_only", False,
                     "Only perform downloading of data; Do not proceed to "
                     "session preparation, model definition or training")
flags.DEFINE_integer("num_gpus", 0, "Total number of gpus for each machine. "
                     "If you don't use GPU, please set it to '0'")
flags.DEFINE_integer("replicas_to_aggregate", None,
                     "Number of replicas to aggregate before parameter update "
                     "is applied (For sync_replicas mode only; default: "
                     "num_workers)")
flags.DEFINE_integer("hidden_units", 100,
                     "Number of units in the hidden layer of the NN")
flags.DEFINE_integer("train_steps", 200,
                     "Number of (global) training steps to perform")
flags.DEFINE_integer("batch_size", 100, "Training batch size")
flags.DEFINE_float("learning_rate", 0.01, "Learning rate")
flags.DEFINE_boolean(
    "sync_replicas", False,
    "Use the sync_replicas (synchronized replicas) mode, "
    "wherein the parameter updates from workers are aggregated "
    "before being applied, to avoid stale gradients")
flags.DEFINE_boolean(
    "existing_servers", False, "Whether servers already exist. If True, "
    "will use the worker hosts via their GRPC URLs (one client process "
    "per worker host). Otherwise, will create an in-process TensorFlow "
    "server.")

FLAGS = flags.FLAGS

IMAGE_PIXELS = 28


def main(unused_argv):
    data_root = os.path.join("outputs", "MNIST")
    mnist = None
    tf_config = os.environ.get("TF_CONFIG")
    if not tf_config:
        raise ValueError("TF_CONFIG not found.")
    tf_config_json = json.loads(tf_config)
    cluster = tf_config_json.get('cluster')
    job_name = tf_config_json.get('task', {}).get('type')
    task_index = tf_config_json.get('task', {}).get('index')
    job_name = "worker" if job_name == "master" else job_name
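    # TF_CONFIG (parsed above) is set by the launcher, here Azure ML. Its
    # layout is roughly the following; hosts, ports, and job names are
    # illustrative only:
    #   {"cluster": {"ps": ["host0:2222"],
    #                "worker": ["host1:2222", "host2:2222"]},
    #    "task": {"type": "worker", "index": 0}}
    # A task of type "master" is treated as an ordinary worker above.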
    # Only the first worker downloads the MNIST data; the remaining tasks wait
    # for the sentinel file before reading the data from the same location.
    sentinel_path = os.path.join(data_root, "complete.txt")
    if job_name == "worker" and task_index == 0:
        mnist = input_data.read_data_sets(data_root, one_hot=True)
        with open(sentinel_path, 'w+') as f:
            f.write("download complete")
    else:
        while not os.path.exists(sentinel_path):
            time.sleep(0.01)
        mnist = input_data.read_data_sets(data_root, one_hot=True)

    if FLAGS.download_only:
        sys.exit(0)

    print("job name = %s" % job_name)
    print("task index = %d" % task_index)
    print("number of GPUs = %d" % FLAGS.num_gpus)

    # Construct the cluster and start the server
    cluster_spec = tf.train.ClusterSpec(cluster)

    # Get the number of workers.
    num_workers = len(cluster_spec.task_indices("worker"))
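    # num_workers is also used below as the default number of gradient updates
    # to aggregate per step in sync_replicas mode.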

    if not FLAGS.existing_servers:
        # Not using existing servers. Create an in-process server.
        server = tf.train.Server(
            cluster_spec, job_name=job_name, task_index=task_index)
        if job_name == "ps":
            server.join()
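    # Parameter server tasks block in server.join() above and never reach the
    # training code below; everything from here on runs on worker tasks.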

    is_chief = (task_index == 0)
    if FLAGS.num_gpus > 0:
        # Avoid gpu allocation conflict: now allocate task_num -> #gpu
        # for each worker in the corresponding machine
        gpu = (task_index % FLAGS.num_gpus)
        worker_device = "/job:worker/task:%d/gpu:%d" % (task_index, gpu)
    elif FLAGS.num_gpus == 0:
        # Just allocate the CPU to worker server
        cpu = 0
        worker_device = "/job:worker/task:%d/cpu:%d" % (task_index, cpu)

    # The device setter will automatically place Variables ops on separate
    # parameter servers (ps). The non-Variable ops will be placed on the
    # workers. The ps use CPU and workers use corresponding GPU.
    with tf.device(
            tf.train.replica_device_setter(
                worker_device=worker_device,
                ps_device="/job:ps/cpu:0",
                cluster=cluster)):
        global_step = tf.Variable(0, name="global_step", trainable=False)

        # Variables of the hidden layer
        hid_w = tf.Variable(
            tf.truncated_normal(
                [IMAGE_PIXELS * IMAGE_PIXELS, FLAGS.hidden_units],
                stddev=1.0 / IMAGE_PIXELS),
            name="hid_w")
        hid_b = tf.Variable(tf.zeros([FLAGS.hidden_units]), name="hid_b")

        # Variables of the softmax layer
        sm_w = tf.Variable(
            tf.truncated_normal(
                [FLAGS.hidden_units, 10],
                stddev=1.0 / math.sqrt(FLAGS.hidden_units)),
            name="sm_w")
        sm_b = tf.Variable(tf.zeros([10]), name="sm_b")
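        # Together these define a single-hidden-layer MLP:
        # 784 (= 28 * 28) inputs -> hidden_units ReLU units -> 10-way softmax.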

        # Ops: located on the worker specified with task_index
        x = tf.placeholder(tf.float32, [None, IMAGE_PIXELS * IMAGE_PIXELS])
        y_ = tf.placeholder(tf.float32, [None, 10])

        hid_lin = tf.nn.xw_plus_b(x, hid_w, hid_b)
        hid = tf.nn.relu(hid_lin)

        y = tf.nn.softmax(tf.nn.xw_plus_b(hid, sm_w, sm_b))
        cross_entropy = -tf.reduce_sum(
            y_ * tf.log(tf.clip_by_value(y, 1e-10, 1.0)))
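        # Hand-rolled categorical cross-entropy; clipping y to [1e-10, 1.0]
        # avoids taking log(0) for overly confident predictions.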

        opt = tf.train.AdamOptimizer(FLAGS.learning_rate)

        if FLAGS.sync_replicas:
            if FLAGS.replicas_to_aggregate is None:
                replicas_to_aggregate = num_workers
            else:
                replicas_to_aggregate = FLAGS.replicas_to_aggregate
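            # SyncReplicasOptimizer waits for gradients from
            # replicas_to_aggregate workers and applies a single averaged
            # update, instead of applying each worker's possibly stale
            # update asynchronously.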
            opt = tf.train.SyncReplicasOptimizer(
                opt,
                replicas_to_aggregate=replicas_to_aggregate,
                total_num_replicas=num_workers,
                name="mnist_sync_replicas")

        train_step = opt.minimize(cross_entropy, global_step=global_step)

        if FLAGS.sync_replicas:
            local_init_op = opt.local_step_init_op
            if is_chief:
                local_init_op = opt.chief_init_op

            ready_for_local_init_op = opt.ready_for_local_init_op

            # Initial token and chief queue runners required by the
            # sync_replicas mode
            chief_queue_runner = opt.get_chief_queue_runner()
            sync_init_op = opt.get_init_tokens_op()

        init_op = tf.global_variables_initializer()
        train_dir = tempfile.mkdtemp()

        # The Supervisor coordinates start-up: the chief session runs init_op,
        # while non-chief sessions poll every recovery_wait_secs seconds until
        # the variables have been initialized.
        if FLAGS.sync_replicas:
            sv = tf.train.Supervisor(
                is_chief=is_chief,
                logdir=train_dir,
                init_op=init_op,
                local_init_op=local_init_op,
                ready_for_local_init_op=ready_for_local_init_op,
                recovery_wait_secs=1,
                global_step=global_step)
        else:
            sv = tf.train.Supervisor(
                is_chief=is_chief,
                logdir=train_dir,
                init_op=init_op,
                recovery_wait_secs=1,
                global_step=global_step)

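        # device_filters limit what this session can see to the ps tasks and
        # this worker's own task, so it does not block waiting on devices
        # that belong to other workers.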
        sess_config = tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=False,
            device_filters=["/job:ps",
                            "/job:worker/task:%d" % task_index])

        # The chief worker (task_index==0) session will prepare the session,
        # while the remaining workers will wait for the preparation to
        # complete.
        if is_chief:
            print("Worker %d: Initializing session..." % task_index)
        else:
            print("Worker %d: Waiting for session to be initialized..." %
                  task_index)

        if FLAGS.existing_servers:
            # Connect to the worker host listed for this task in TF_CONFIG.
            # (Assumes the cluster's "worker" entry is a list of "host:port"
            # strings, which is the standard TF_CONFIG layout.)
            server_grpc_url = "grpc://" + cluster["worker"][task_index]
            print("Using existing server at: %s" % server_grpc_url)

            sess = sv.prepare_or_wait_for_session(server_grpc_url,
                                                  config=sess_config)
        else:
            sess = sv.prepare_or_wait_for_session(server.target,
                                                  config=sess_config)

        print("Worker %d: Session initialization complete." % task_index)

        if FLAGS.sync_replicas and is_chief:
            # Chief worker will start the chief queue runner and call the
            # init op.
            sess.run(sync_init_op)
            sv.start_queue_runners(sess, [chief_queue_runner])

        # Perform training
        time_begin = time.time()
        print("Training begins @ %f" % time_begin)

        local_step = 0
        while True:
            # Training feed
            batch_xs, batch_ys = mnist.train.next_batch(FLAGS.batch_size)
            train_feed = {x: batch_xs, y_: batch_ys}

            _, step = sess.run([train_step, global_step],
                               feed_dict=train_feed)
            local_step += 1

            now = time.time()
            print("%f: Worker %d: training step %d done (global step: %d)" %
                  (now, task_index, local_step, step))

            if step >= FLAGS.train_steps:
                break
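        # global_step is shared via the parameter server, so with several
        # workers each one performs only part of the FLAGS.train_steps
        # updates before the loop exits.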

        time_end = time.time()
        print("Training ends @ %f" % time_end)
        training_time = time_end - time_begin
        print("Training elapsed time: %f s" % training_time)

        # Validation feed
        val_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
        val_xent = sess.run(cross_entropy, feed_dict=val_feed)
        print("After %d training step(s), validation cross entropy = %g" %
              (FLAGS.train_steps, val_xent))
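        # Only the chief worker reports the metric, so the Azure ML run
        # records a single CrossEntropy value rather than one per worker.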
if job_name == "worker" and task_index == 0:
|
|
run = Run.get_context()
|
|
run.log("CrossEntropy", val_xent)
|
|
|


if __name__ == "__main__":
    tf.app.run()