MachineLearningNotebooks/how-to-use-azureml/reinforcement-learning/cartpole-on-compute-instance/files/cartpole_rollout.py

import os
import sys
import argparse
from ray.rllib.evaluate import RolloutSaver, rollout
from ray_on_aml.core import Ray_On_AML
import ray.cloudpickle as cloudpickle
from ray.tune.utils import merge_dicts
from ray.tune.registry import get_trainable_cls, _global_registry, ENV_CREATOR
from azureml.core import Run
from utils import callbacks
import collections
import copy
import gymnasium as gym
import json
from pathlib import Path


def run_rollout(checkpoint, algo, render, steps, episodes):
    config_dir = os.path.dirname(checkpoint)
    config_path = os.path.join(config_dir, "params.pkl")
    config = None

    # Try parent directory.
    if not os.path.exists(config_path):
        config_path = os.path.join(config_dir, "../params.pkl")

    # Load the config from the pickled params file.
    if os.path.exists(config_path):
        with open(config_path, "rb") as f:
            config = cloudpickle.load(f)
    # Without params.pkl we cannot rebuild the training config.
    else:
        raise ValueError(
            "Could not find params.pkl in either the checkpoint dir or "
            "its parent directory")
    # Make sure worker 0 has an Env.
    config["create_env_on_driver"] = True

    # Merge the `evaluation_config` settings from the pkl file into the
    # top-level config.
    evaluation_config = copy.deepcopy(config.get("evaluation_config", {}))
    config = merge_dicts(config, evaluation_config)
    env = config.get("env")

    # Make sure we have evaluation workers.
    if not config.get("evaluation_num_workers"):
        config["evaluation_num_workers"] = config.get("num_workers", 0)
    if not config.get("evaluation_duration"):
        config["evaluation_duration"] = 1
    # Hard-override this as it raises a warning by Algorithm otherwise.
    # Leaving it set to None makes no sense anyway, since we never call
    # `Algorithm.train()` here.
    config["evaluation_interval"] = 1
    # Rendering settings.
    config["render_env"] = render

    # Create the Algorithm from config.
    cls = get_trainable_cls(algo)
    algorithm = cls(env=env, config=config)

    # Load state from checkpoint, if provided.
    if checkpoint:
        algorithm.restore(checkpoint)

    # Do the actual rollout.
    with RolloutSaver(
        outfile=None,
        use_shelve=False,
        write_update_file=False,
        target_steps=steps,
        target_episodes=episodes,
        save_info=False,
    ) as saver:
        rollout(algorithm, env, steps, episodes, saver, not render)
    algorithm.stop()


if __name__ == "__main__":
    # Start ray head (single node)
    ray_on_aml = Ray_On_AML()
    ray = ray_on_aml.getRay()
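    # getRay() returns a Ray handle on the head node and nothing on worker
    # nodes, so only the head node runs the rollout logic below.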
    if ray:
        parser = argparse.ArgumentParser()
        parser.add_argument('--dataset_path', required=True, help='Path to artifacts dataset')
        parser.add_argument('--checkpoint', required=True, help='Name of checkpoint file directory')
        parser.add_argument('--algo', required=True, help='Name of RL algorithm')
        parser.add_argument('--render', default=False, required=False, help='True to render')
        parser.add_argument('--steps', required=False, type=int, help='Number of steps to run')
        parser.add_argument('--episodes', required=False, type=int, help='Number of episodes to run')

        args = parser.parse_args()
        # Get a handle to the current run.
        run = Run.get_context()

        # Get a handle to the training artifacts dataset via its mount path.
        dataset_path = run.input_datasets['dataset_path']

        # Find the checkpoint directory to be evaluated.
        checkpoint = os.path.join(dataset_path, args.checkpoint)
        print('Checkpoint:', checkpoint)

        # Connect to the Ray cluster and start the rollout.
        ray.init(address='auto')
        run_rollout(checkpoint, args.algo, args.render, args.steps, args.episodes)
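
The script expects to run as an Azure ML job: the named file-dataset input 'dataset_path' (looked up via run.input_datasets above) points at the training run's artifacts, and the checkpoint, algorithm, and step/episode counts come in through the argparse flags. Below is a minimal submission sketch, assuming the v1 azureml-core SDK, a workspace handle ws, a compute target ct, a registered Ray environment ray_env, and a file dataset artifacts_ds built from the training run's outputs; the checkpoint name 'checkpoint_000010' and the algorithm 'PPO' are purely illustrative.

    from azureml.core import Experiment, ScriptRunConfig

    # Mount the training artifacts under the named input 'dataset_path' so
    # that run.input_datasets['dataset_path'] resolves inside the job.
    src = ScriptRunConfig(
        source_directory='files',
        script='cartpole_rollout.py',
        arguments=[
            '--dataset_path', artifacts_ds.as_named_input('dataset_path').as_mount(),
            '--checkpoint', 'checkpoint_000010',
            '--algo', 'PPO',
            '--episodes', '10',
        ],
        compute_target=ct,
        environment=ray_env,
    )
    Experiment(ws, 'cartpole-rollout').submit(src)

The named input must match the 'dataset_path' key used in the script, and the --checkpoint value is joined onto the mounted dataset path to locate the checkpoint directory, which must contain (or sit next to) the params.pkl that run_rollout loads.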