mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-20 09:37:04 -05:00
121 lines
4.2 KiB
Python
121 lines
4.2 KiB
Python
# Copyright 2019 Uber Technologies, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Script adapted from: https://github.com/horovod/horovod/blob/master/examples/tensorflow2_keras_mnist.py
# ==============================================================================
|
|
|
|
import tensorflow as tf
|
|
import horovod.tensorflow.keras as hvd
|
|
|
|
import os
|
|
import argparse
|
|
|
|
# Command-line hyperparameters, declared as (flags, value type, default) rows
# so the options can be read as a single table.
_CLI_OPTIONS = (
    (("--learning-rate", "-lr"), float, 0.001),
    (("--epochs",), int, 24),
    (("--steps-per-epoch",), int, 500),
)

parser = argparse.ArgumentParser()
for _flags, _value_type, _default in _CLI_OPTIONS:
    parser.add_argument(*_flags, type=_value_type, default=_default)

# Parsed values are exposed as args.learning_rate, args.epochs and
# args.steps_per_epoch.
args = parser.parse_args()
|
|
|
|
# Horovod: initialize Horovod.
hvd.init()

# Horovod: pin one GPU to this process, chosen by the process's local rank.
gpus = tf.config.experimental.list_physical_devices("GPU")
if gpus:
    # Turn on memory growth for every detected GPU before restricting
    # visibility to the single device this process will use.
    for device in gpus:
        tf.config.experimental.set_memory_growth(device, True)
    tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], "GPU")
|
|
|
|
# Load the MNIST training split; the second element returned by load_data()
# is discarded. The cache file name embeds the Horovod rank -- presumably so
# that processes sharing a cache directory do not write the same file
# (NOTE(review): confirm).
mnist_cache_name = "mnist-%d.npz" % hvd.rank()
(mnist_images, mnist_labels), _ = tf.keras.datasets.mnist.load_data(
    path=mnist_cache_name
)

# Append a trailing axis to the images and divide by 255.0, then cast to the
# dtypes used for training: float32 features, int64 labels.
features = tf.cast(mnist_images[..., tf.newaxis] / 255.0, tf.float32)
labels = tf.cast(mnist_labels, tf.int64)

# Endless stream of shuffled (shuffle buffer of 10000) batches of 128 examples.
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
dataset = dataset.repeat()
dataset = dataset.shuffle(10000)
dataset = dataset.batch(128)
|
|
|
|
# Convolutional classifier: two 3x3 ReLU convolutions, max pooling and dropout,
# then a 128-unit ReLU dense layer, dropout, and a 10-way softmax output.
layers = tf.keras.layers

mnist_model = tf.keras.Sequential(
    [
        layers.Conv2D(filters=32, kernel_size=(3, 3), activation="relu"),
        layers.Conv2D(filters=64, kernel_size=(3, 3), activation="relu"),
        layers.MaxPooling2D(pool_size=(2, 2)),
        layers.Dropout(rate=0.25),
        layers.Flatten(),
        layers.Dense(units=128, activation="relu"),
        layers.Dropout(rate=0.5),
        layers.Dense(units=10, activation="softmax"),
    ]
)
|
|
|
|
# Horovod: scale the base learning rate by the Horovod size (hvd.size()).
scaled_lr = args.learning_rate * hvd.size()

# Horovod: wrap the Adam optimizer in Horovod's DistributedOptimizer.
opt = hvd.DistributedOptimizer(tf.optimizers.Adam(scaled_lr))

# Horovod: `experimental_run_tf_function=False` ensures TensorFlow uses
# hvd.DistributedOptimizer() to compute gradients.
mnist_model.compile(
    optimizer=opt,
    loss=tf.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
    experimental_run_tf_function=False,
)
|
|
|
|
# Horovod: broadcast initial variable states from rank 0 to all other
# processes, so every worker starts from the same weights whether training
# begins from random initialization or from a restored checkpoint.
broadcast_initial_state = hvd.callbacks.BroadcastGlobalVariablesCallback(0)

# Horovod: average metrics among workers at the end of every epoch.
average_metrics = hvd.callbacks.MetricAverageCallback()

# Horovod: using the fully scaled learning rate from the very beginning leads
# to worse final accuracy, so warm up towards `scaled_lr` during the first
# three epochs. See https://arxiv.org/abs/1706.02677 for details.
lr_warmup = hvd.callbacks.LearningRateWarmupCallback(
    warmup_epochs=3, initial_lr=scaled_lr, verbose=1
)

# Note: the metric-averaging callback must stay in the list before
# ReduceLROnPlateau, TensorBoard or other metrics-based callbacks.
callbacks = [broadcast_initial_state, average_metrics, lr_warmup]
|
|
|
|
# Horovod: only worker 0 saves checkpoints, so other workers cannot corrupt
# them.
if hvd.rank() == 0:
    output_dir = "./outputs"
    os.makedirs(output_dir, exist_ok=True)
    # "{epoch}" stays a literal placeholder in the path handed to Keras.
    checkpoint_path = os.path.join(output_dir, "checkpoint-{epoch}.h5")
    callbacks.append(tf.keras.callbacks.ModelCheckpoint(checkpoint_path))
|
|
|
|
# Horovod: write logs on worker 0 only; every other rank trains silently.
verbose = 1 if hvd.rank() == 0 else 0

# Horovod: divide the requested steps per epoch across the workers.
# Integer division yields 0 when there are more workers than requested steps,
# which is not a usable step count for an endlessly repeating dataset, so
# always run at least one step per epoch.
steps_per_worker = max(1, args.steps_per_epoch // hvd.size())

# Train the model.
mnist_model.fit(
    dataset,
    steps_per_epoch=steps_per_worker,
    callbacks=callbacks,
    epochs=args.epochs,
    verbose=verbose,
)
|