MachineLearningNotebooks/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_environment.py

import json
import logging

import gym
import minerl.env.core
import minerl.env.comms
import numpy as np
from ray.rllib.env.atari_wrappers import FrameStack
from minerl.env.malmo import InstanceManager

# Modify the MineRL timeouts to detect common errors
# more quickly and speed up recovery.
minerl.env.core.SOCKTIME = 60.0
minerl.env.comms.retry_timeout = 1


class EnvWrapper(minerl.env.core.MineRLEnv):
    def __init__(self, xml, port):
        InstanceManager.configure_malmo_base_port(port)

        # The three discrete actions map to Malmo commands: move forward,
        # plus two camera rotations that TrackingEnv below interprets as
        # 90-degree left and right turns.
        self.action_to_command_array = [
            'move 1',
            'camera 0 270',
            'camera 0 90']

        super().__init__(
            xml,
            gym.spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8),
            gym.spaces.Discrete(3)
        )

        self.metadata['video.frames_per_second'] = 2

    def _setup_spaces(self, observation_space, action_space):
        self.observation_space = observation_space
        self.action_space = action_space

    def _process_action(self, action_in) -> str:
        assert self.action_space.contains(action_in)
        assert action_in <= len(
            self.action_to_command_array) - 1, 'action index out of bounds.'
        return self.action_to_command_array[action_in]

    def _process_observation(self, pov, info):
        '''
        Overridden to simplify: returns only `pov`, not the
        obs_dict (observation dictionary) that MineRLEnv returns.
        '''
        pov = np.frombuffer(pov, dtype=np.uint8)

        if pov is None or len(pov) == 0:
            raise Exception('Invalid observation, probably an aborted peek')
        else:
            pov = pov.reshape(
                (self.height, self.width, self.depth)
            )[::-1, :, :]

        assert self.observation_space.contains(pov)
        self._last_pov = pov
        return pov

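
# A hypothetical direct-use sketch (assumed usage; the mission XML path and
# port below are placeholders, not from the original sample):
#
#   env = EnvWrapper('missions/mission.xml', 9000)
#   obs = env.reset()                       # an (84, 84, 3) uint8 frame
#   obs, reward, done, info = env.step(0)   # issues the 'move 1' command
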
class TrackingEnv(gym.Wrapper):
    '''Tracks the agent's position on a grid and adds a small
    exploration bonus the first time each cell is visited.'''

    def __init__(self, env):
        super().__init__(env)
        self._actions = [
            self._forward,
            self._turn_left,
            self._turn_right
        ]

    def _reset_state(self):
        self._facing = (1, 0)
        self._position = (0, 0)
        self._visited = {}
        self._update_visited()

    def _forward(self):
        self._position = (
            self._position[0] + self._facing[0],
            self._position[1] + self._facing[1]
        )

    def _turn_left(self):
        # Rotate the facing vector 90 degrees: (x, y) -> (y, -x).
        self._facing = (self._facing[1], -self._facing[0])

    def _turn_right(self):
        # Inverse rotation: (x, y) -> (-y, x).
        self._facing = (-self._facing[1], self._facing[0])

    def _encode_state(self):
        return self._position

    def _update_visited(self):
        # Returns how often the current cell had been visited before
        # incrementing its counter.
        state = self._encode_state()
        value = self._visited.get(state, 0)
        self._visited[state] = value + 1
        return value

    def reset(self):
        self._reset_state()
        return super().reset()

    def step(self, action):
        o, r, d, i = super().step(action)
        self._actions[action]()
        revisit_count = self._update_visited()
        if revisit_count == 0:
            # Reward shaping: +0.1 for reaching a previously unvisited cell.
            r += 0.1
        return o, r, d, i

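
# Worked illustration of the tracking arithmetic above (not part of the
# original sample): starting at position (0, 0) facing (1, 0), action 1
# (_turn_left) changes the facing to (0, -1); a following action 0
# (_forward) then moves the position to (0, -1). Since that cell has not
# been visited before, step() adds the +0.1 exploration bonus.
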
class TrajectoryWrapper(gym.Wrapper):
    '''Records the Malmo command corresponding to each action taken,
    retrievable via get_trajectory().'''

    def __init__(self, env):
        super().__init__(env)
        self._trajectory = []
        self._action_to_malmo_command_array = ['move 1', 'turn -1', 'turn 1']

    def get_trajectory(self):
        return self._trajectory

    def _to_malmo_action(self, action_index):
        return self._action_to_malmo_command_array[action_index]

    def step(self, action):
        self._trajectory.append(self._to_malmo_action(action))
        o, r, d, i = super().step(action)
        return o, r, d, i

class DummyEnv(gym.Env):
    '''Placeholder environment exposing only the observation and action
    spaces of the wrapped, frame-stacked Minecraft environment.'''

    def __init__(self):
        # 6 channels = 2 stacked RGB frames (see FrameStack(env, 2) below).
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(84, 84, 6),
            dtype=np.uint8)
        self.action_space = gym.spaces.Discrete(3)

# Define a function to create a MineRL environment
def create_env(config):
    mission = config['mission']
    port = 1000 * config.worker_index + config.vector_index

    print('*********************************************')
    print(f'* Worker {config.worker_index} creating from '
          f'mission: {mission}, port {port}')
    print('*********************************************')

    if config.worker_index == 0:
        # The first environment is only used for checking the action
        # and observation space. By using a dummy environment, there's
        # no need to spin up a Minecraft instance behind it, saving some
        # CPU resources on the head node.
        return DummyEnv()

    env = EnvWrapper(mission, port)
    env = TrackingEnv(env)
    env = FrameStack(env, 2)
    return env

def create_env_for_rollout(config):
    mission = config['mission']
    port = 1000 * config.worker_index + config.vector_index

    print('*********************************************')
    print(f'* Worker {config.worker_index} creating from '
          f'mission: {mission}, port {port}')
    print('*********************************************')

    env = EnvWrapper(mission, port)
    env = TrackingEnv(env)
    env = FrameStack(env, 2)
    env = TrajectoryWrapper(env)
    return env
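
# A minimal sketch of how these factories are typically consumed (an
# assumption about the surrounding training script, which is not shown in
# this file): RLlib looks environments up by a registered name and passes
# an EnvContext -- a dict that also carries `worker_index` and
# `vector_index` -- as `config` to the factory.
#
#   from ray.tune.registry import register_env
#
#   register_env('Minecraft', create_env)
#   register_env('MinecraftRollout', create_env_for_rollout)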