mirror of https://github.com/Azure/MachineLearningNotebooks.git (synced 2025-12-19 17:17:04 -05:00)
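"""Roll out (evaluate) a trained RLlib agent inside an Azure ML run.

The agent is restored from a training checkpoint exposed to the run as a
mounted dataset and replayed with ray.rllib.rollout, optionally recording
video of the episodes.
"""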
import os
import sys

import ray
from ray.rllib import rollout
from ray.tune.registry import get_trainable_cls

from azureml.core import Run

# Imported so that callback references in a checkpointed training config
# can resolve when the agent is restored.
from utils import callbacks


def run_rollout(args, parser):
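    """Restore a trained agent from args.checkpoint and roll it out.

    `parser` is only used to report a missing --env argument.
    """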
    config = args.config
    if not args.env:
        if not config.get("env"):
            parser.error("the following arguments are required: --env")
        args.env = config.get("env")

    # Create the Trainer from config.
    cls = get_trainable_cls(args.run)
    agent = cls(env=args.env, config=config)

    # Load state from checkpoint.
    agent.restore(args.checkpoint)
    num_steps = int(args.steps)
    num_episodes = int(args.episodes)

    # Determine the video output directory. Some versions of the rollout
    # parser do not define --video-dir, so probe for the attribute and fall
    # back to the monitor-style rollout call if it is missing.
    use_arg_monitor = False
    try:
        args.video_dir
    except AttributeError:
        print("There is no such attribute: args.video_dir")
        use_arg_monitor = True

    video_dir = None
    if not use_arg_monitor:
        if args.monitor:
            video_dir = os.path.join("./logs", "video")
        elif args.video_dir:
            video_dir = os.path.expanduser(args.video_dir)

    # Do the actual rollout.
    with rollout.RolloutSaver(
            args.out,
            args.use_shelve,
            write_update_file=args.track_progress,
            target_steps=num_steps,
            target_episodes=num_episodes,
            save_info=args.save_info) as saver:
        if use_arg_monitor:
            # Monitor-style call: the final argument is the monitor flag.
            rollout.rollout(
                agent,
                args.env,
                num_steps,
                num_episodes,
                saver,
                args.no_render,
                args.monitor)
        else:
            # Video-dir-style call: the final argument is the output directory.
            rollout.rollout(
                agent, args.env,
                num_steps,
                num_episodes,
                saver,
                args.no_render, video_dir)


if __name__ == "__main__":
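    # Driver entry point: start Ray, parse the command line, locate the
    # checkpoint on the mounted artifacts dataset, then roll the agent out.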
    # Start Ray head (single node)
    os.system('ray start --head')
    ray.init(address='auto')

    # Add positional argument - serves as placeholder for checkpoint
    argvc = sys.argv[1:]
    argvc.insert(0, 'checkpoint-placeholder')
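    # rollout.create_parser() below declares `checkpoint` as a required
    # positional argument; the placeholder satisfies the parser, and the
    # real checkpoint path is assigned to args.checkpoint further down.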
    # Parse arguments
    rollout_parser = rollout.create_parser()

    rollout_parser.add_argument(
        '--checkpoint-number', required=False, type=int, default=1,
        help='Checkpoint number of the checkpoint from which to roll out')

    rollout_parser.add_argument(
        '--artifacts-dataset', required=True,
        help='The checkpoints artifacts dataset')

    rollout_parser.add_argument(
        '--artifacts-path', required=True,
        help='The checkpoints artifacts path')

    args = rollout_parser.parse_args(argvc)
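    # Example invocation (illustrative values; the script name is
    # hypothetical, and all flags except the three added above come from
    # rollout.create_parser()):
    #   python rollout_script.py --run PPO --env CartPole-v0 \
    #       --steps 2000 --no-render \
    #       --artifacts-dataset artifacts_dataset \
    #       --artifacts-path artifacts_path --checkpoint-number 10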
    # Get a handle to the current run
    run = Run.get_context()

    # Get handles to the training artifacts dataset and mount path
    artifacts_dataset = run.input_datasets['artifacts_dataset']
    artifacts_path = run.input_datasets['artifacts_path']
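    # Both named inputs are expected to point at the training run's
    # artifacts: one as a dataset object whose files can be enumerated with
    # to_path(), the other as the directory where that dataset is mounted
    # on this node. The input names are configured when the run is submitted.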
    # Find the checkpoint file to be evaluated. RLlib names checkpoint
    # files 'checkpoint-<number>', so match on the '-<number>' suffix.
    checkpoint_id = '-' + str(args.checkpoint_number)
    checkpoint_files = list(filter(
        lambda filename: filename.endswith(checkpoint_id),
        artifacts_dataset.to_path()))

    # to_path() returns paths with a leading '/', which would make
    # os.path.join discard the mount prefix, so strip it first.
    checkpoint_file = checkpoint_files[0]
    if checkpoint_file[0] == '/':
        checkpoint_file = checkpoint_file[1:]
    checkpoint = os.path.join(artifacts_path, checkpoint_file)
    print('Checkpoint:', checkpoint)
    # Set rollout checkpoint
    args.checkpoint = checkpoint

    # Start rollout
    run_rollout(args, rollout_parser)