mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-20 01:27:06 -05:00
131 lines
3.8 KiB
Python
131 lines
3.8 KiB
Python
import argparse
|
|
import os
|
|
import re
|
|
|
|
from azureml.core import Run
|
|
from azureml.core.model import Model
|
|
|
|
from minecraft_environment import create_env_for_rollout
|
|
from malmo_video_recorder import MalmoVideoRecorder
|
|
from gym import wrappers
|
|
|
|
import ray
|
|
import ray.tune as tune
|
|
from ray.rllib import rollout
|
|
from ray.tune.registry import get_trainable_cls
|
|
|
|
|
|
def write_mission_file_for_seed(mission_file, seed):
|
|
with open(mission_file, 'r') as base_file:
|
|
mission_file_path = mission_file.replace('v0', seed)
|
|
content = base_file.read().format(SEED_PLACEHOLDER=seed)
|
|
|
|
mission_file = open(mission_file_path, 'w')
|
|
mission_file.writelines(content)
|
|
mission_file.close()
|
|
|
|
return mission_file_path
|
|
|
|
|
|
def run_rollout(trainable_type, mission_file, seed):
|
|
# Writes the mission file for minerl
|
|
mission_file_path = write_mission_file_for_seed(mission_file, seed)
|
|
|
|
# Instantiate the agent. Note: the IMPALA trainer implementation in
|
|
# Ray uses an AsyncSamplesOptimizer. Under the hood, this starts a
|
|
# LearnerThread which will wait for training samples. This will fail
|
|
# after a timeout, but has no influence on the rollout. See
|
|
# https://github.com/ray-project/ray/blob/708dff6d8f7dd6f7919e06c1845f1fea0cca5b89/rllib/optimizers/aso_learner.py#L66
|
|
config = {
|
|
"env_config": {
|
|
"mission": mission_file_path,
|
|
"is_rollout": True,
|
|
"seed": seed
|
|
},
|
|
"num_workers": 0
|
|
}
|
|
cls = get_trainable_cls(args.run)
|
|
agent = cls(env="Minecraft", config=config)
|
|
|
|
# The optimizer is not needed during a rollout
|
|
agent.optimizer.stop()
|
|
|
|
# Load state from checkpoint
|
|
agent.restore(f'{checkpoint_path}/{checkpoint_file}')
|
|
|
|
# Get a reference to the environment
|
|
env = agent.workers.local_worker().env
|
|
|
|
# Let the agent choose actions until the game is over
|
|
obs = env.reset()
|
|
done = False
|
|
total_reward = 0
|
|
|
|
while not done:
|
|
action = agent.compute_action(obs)
|
|
obs, reward, done, info = env.step(action)
|
|
|
|
total_reward += reward
|
|
|
|
print(f'Total reward using seed {seed}: {total_reward}')
|
|
|
|
# This avoids a sigterm trace in the logs, see minerl.env.malmo.Instance
|
|
env.instance.watcher_process.kill()
|
|
|
|
env.close()
|
|
agent.stop()
|
|
|
|
return env.get_trajectory()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
parser = argparse.ArgumentParser()
|
|
parser.add_argument('--model_name', required=True)
|
|
parser.add_argument('--run', required=False, default="IMPALA")
|
|
args = parser.parse_args()
|
|
|
|
# Register custom Minecraft environment
|
|
tune.register_env("Minecraft", create_env_for_rollout)
|
|
|
|
ray.init(address='auto')
|
|
|
|
# Download the model files (contains a checkpoint)
|
|
ws = Run.get_context().experiment.workspace
|
|
model = Model(ws, args.model_name)
|
|
checkpoint_path = model.download(exist_ok=True)
|
|
|
|
files_ = os.listdir(checkpoint_path)
|
|
cp_pattern = re.compile('^checkpoint-\\d+$')
|
|
|
|
checkpoint_file = None
|
|
for f_ in files_:
|
|
if cp_pattern.match(f_):
|
|
checkpoint_file = f_
|
|
|
|
if checkpoint_file is None:
|
|
raise Exception("No checkpoint file found.")
|
|
|
|
# These are the Minecraft mission seeds for the rollouts
|
|
rollout_seeds = ['1234', '43289', '65224', '983341']
|
|
|
|
# Initialize the Malmo video recorder
|
|
video_recorder = MalmoVideoRecorder()
|
|
video_recorder.init_malmo()
|
|
|
|
# Path references to the mission files
|
|
base_training_mission_file = \
|
|
'minecraft_missions/lava_maze_rollout-v0.xml'
|
|
base_video_recording_mission_file = \
|
|
'minecraft_missions/lava_maze_rollout_video.xml'
|
|
|
|
for seed in rollout_seeds:
|
|
trajectory = run_rollout(
|
|
args.run,
|
|
base_training_mission_file,
|
|
seed)
|
|
|
|
video_recorder.record_malmo_video(
|
|
trajectory,
|
|
base_video_recording_mission_file,
|
|
seed)
|