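"""Evaluate a trained RLlib agent on Azure ML using Ray-on-AML.

The script loads the pickled training config (params.pkl) found next to, or
one level above, the given checkpoint, rebuilds the Algorithm from it,
restores the checkpoint, and rolls the policy out for the requested number
of steps and/or episodes.

Illustrative launch command (the script name and argument values below are
examples only, not values taken from this repo):

    python <this_script>.py --dataset_path artifacts \
        --checkpoint checkpoint_000100 --algo PPO --episodes 10
"""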

import os
import argparse
import copy

from ray.rllib.evaluate import RolloutSaver, rollout
from ray_on_aml.core import Ray_On_AML
import ray.cloudpickle as cloudpickle
from ray.tune.utils import merge_dicts
from ray.tune.registry import get_trainable_cls

from azureml.core import Run

# Imported for its side effect: the pickled training config may reference the
# custom callbacks module, which must be importable when it is deserialized.
from utils import callbacks  # noqa: F401


def run_rollout(checkpoint, algo, render, steps, episodes):
    """Rebuild the trained Algorithm from its pickled config and roll it out."""
    config_dir = os.path.dirname(checkpoint)
    config_path = os.path.join(config_dir, "params.pkl")
    config = None

    # Try the parent directory if params.pkl is not next to the checkpoint.
    if not os.path.exists(config_path):
        config_path = os.path.join(config_dir, "../params.pkl")

    # Load the pickled training config; without it the Algorithm cannot be
    # reconstructed.
    if os.path.exists(config_path):
        with open(config_path, "rb") as f:
            config = cloudpickle.load(f)
    else:
        raise ValueError(
            "Could not find params.pkl in either the checkpoint dir or its "
            "parent directory"
        )

    # Make sure worker 0 has an Env.
    config["create_env_on_driver"] = True

    # Merge the top-level config with its `evaluation_config` overrides from
    # the pkl file.
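    # For illustration, merge_dicts() does a deep, right-biased merge, so the
    # evaluation-time overrides win over the training-time values:
    #   merge_dicts({"lr": 1e-3, "env_config": {"x": 1}},
    #               {"env_config": {"x": 2}})
    #   -> {"lr": 1e-3, "env_config": {"x": 2}}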
    evaluation_config = copy.deepcopy(config.get("evaluation_config", {}))
    config = merge_dicts(config, evaluation_config)
    env = config.get("env")

    # Make sure we have evaluation workers.
    if not config.get("evaluation_num_workers"):
        config["evaluation_num_workers"] = config.get("num_workers", 0)
    if not config.get("evaluation_duration"):
        config["evaluation_duration"] = 1

    # Hard-override this, as the Algorithm raises a warning otherwise. Leaving
    # it at None would make no sense anyway, since `Algorithm.train()` is
    # never called here.
    config["evaluation_interval"] = 1

    # Rendering settings.
    config["render_env"] = render
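    # For reference: get_trainable_cls() resolves a registered algorithm name
    # (e.g. "PPO") to its Algorithm class via Ray Tune's trainable registry.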
    # Create the Algorithm from the config.
    cls = get_trainable_cls(algo)
    algorithm = cls(env=env, config=config)

    # Load state from the checkpoint, if provided.
    if checkpoint:
        algorithm.restore(checkpoint)
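    # Note: with outfile=None the RolloutSaver only tracks rollout progress;
    # passing a file path (optionally with use_shelve=True) would persist the
    # collected trajectories instead.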
    # Do the actual rollout.
    with RolloutSaver(
        outfile=None,
        use_shelve=False,
        write_update_file=False,
        target_steps=steps,
        target_episodes=episodes,
        save_info=False,
    ) as saver:
        # The final argument of rollout() is `no_render`, hence the negation.
        rollout(algorithm, env, steps, episodes, saver, not render)
    algorithm.stop()


if __name__ == "__main__":
    # Start the Ray head (single node).
    ray_on_aml = Ray_On_AML()
    ray = ray_on_aml.getRay()
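    # getRay() is expected to return a Ray handle only on the head node and
    # None on worker nodes, so everything below runs on the head node only.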
    if ray:
        parser = argparse.ArgumentParser()
        parser.add_argument('--dataset_path', required=True, help='Path to artifacts dataset')
        parser.add_argument('--checkpoint', required=True, help='Name of checkpoint file directory')
        parser.add_argument('--algo', required=True, help='Name of RL algorithm')
        # Parse the flag explicitly so that `--render False` is actually falsy
        # (a plain string value would always be truthy).
        parser.add_argument('--render', default=False, required=False,
                            type=lambda v: str(v).lower() in ('true', '1', 'yes'),
                            help='True to render')
        parser.add_argument('--steps', required=False, type=int, help='Number of steps to run')
        parser.add_argument('--episodes', required=False, type=int, help='Number of episodes to run')
        args = parser.parse_args()

        # Get a handle to the current run.
        run = Run.get_context()

        # Get the mount path of the training artifacts dataset. (The mounted
        # path is read from the run's input datasets, not from args.dataset_path.)
        dataset_path = run.input_datasets['dataset_path']

        # Locate the checkpoint to be evaluated.
        checkpoint = os.path.join(dataset_path, args.checkpoint)
        print('Checkpoint:', checkpoint)

        # Connect to the Ray cluster and start the rollout.
        ray.init(address='auto')
        run_rollout(checkpoint, args.algo, args.render, args.steps, args.episodes)