import gym
import minerl.env.core
import minerl.env.comms
import numpy as np

from ray.rllib.env.atari_wrappers import FrameStack
from minerl.env.malmo import InstanceManager

# Modify the MineRL timeouts to detect common errors
# more quickly and speed up recovery
minerl.env.core.SOCKTIME = 60.0
minerl.env.comms.retry_timeout = 1


class EnvWrapper(minerl.env.core.MineRLEnv):
    def __init__(self, xml, port):
        InstanceManager.configure_malmo_base_port(port)
        # Maps the discrete action index to a Malmo command string:
        # 0 moves forward; 1 and 2 rotate the camera by 90 degrees
        # (270 is equivalent to -90), treated as left/right turns by
        # TrackingEnv below.
        self.action_to_command_array = [
            'move 1',
            'camera 0 270',
            'camera 0 90']

        super().__init__(
            xml,
            gym.spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8),
            gym.spaces.Discrete(3)
        )

        self.metadata['video.frames_per_second'] = 2

    def _setup_spaces(self, observation_space, action_space):
        # Use the spaces passed to __init__ unchanged.
        self.observation_space = observation_space
        self.action_space = action_space

    def _process_action(self, action_in) -> str:
        assert self.action_space.contains(action_in)
        assert action_in < len(
            self.action_to_command_array), 'action index out of bounds.'
        return self.action_to_command_array[action_in]

    def _process_observation(self, pov, info):
        '''
        Overridden to simplify: returns only `pov` rather than the
        obs_dict (observation dictionary) that MineRLEnv builds.
        '''
        pov = np.frombuffer(pov, dtype=np.uint8)

        if len(pov) == 0:
            raise ValueError('Invalid observation, probably an aborted peek')

        # The raw frame arrives bottom-up; reshape and flip it vertically.
        pov = pov.reshape(
            (self.height, self.width, self.depth)
        )[::-1, :, :]

        assert self.observation_space.contains(pov)

        self._last_pov = pov

        return pov


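# Minimal usage sketch (illustrative only: 'mission.xml' and port 9000
# are placeholder values, and a local Malmo/Minecraft installation is
# assumed):
#
#     env = EnvWrapper('mission.xml', 9000)
#     obs = env.reset()                      # (84, 84, 3) uint8 frame
#     obs, reward, done, info = env.step(0)  # 0 -> 'move 1'
#     env.close()

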
class TrackingEnv(gym.Wrapper):
    '''
    Dead-reckons the agent's grid position and heading from its own
    actions and adds a small exploration bonus (+0.1) to the reward the
    first time each grid cell is visited.
    '''

    def __init__(self, env):
        super().__init__(env)
        self._actions = [
            self._forward,
            self._turn_left,
            self._turn_right
        ]

    def _reset_state(self):
        self._facing = (1, 0)
        self._position = (0, 0)
        self._visited = {}
        self._update_visited()

    def _forward(self):
        self._position = (
            self._position[0] + self._facing[0],
            self._position[1] + self._facing[1]
        )

    def _turn_left(self):
        self._facing = (self._facing[1], -self._facing[0])

    def _turn_right(self):
        self._facing = (-self._facing[1], self._facing[0])

    def _encode_state(self):
        return self._position

    def _update_visited(self):
        # Returns how often the current cell had been visited *before*
        # this call, so 0 means the cell is new.
        state = self._encode_state()
        value = self._visited.get(state, 0)
        self._visited[state] = value + 1
        return value

    def reset(self):
        self._reset_state()
        return super().reset()

    def step(self, action):
        o, r, d, i = super().step(action)
        self._actions[action]()
        revisit_count = self._update_visited()
        if revisit_count == 0:
            r += 0.1

        return o, r, d, i


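# Worked example of the heading math above (added for clarity): starting
# from _facing == (1, 0), successive _turn_left calls yield
# (0, -1) -> (-1, 0) -> (0, 1) -> (1, 0), i.e. exact 90-degree rotations.
# _forward therefore dead-reckons positions on an integer grid, and
# _update_visited can award the +0.1 bonus exactly once per cell.

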
class TrajectoryWrapper(gym.Wrapper):
    '''
    Records every action taken as its Malmo command string so the
    trajectory of an episode can be retrieved afterwards.
    '''

    def __init__(self, env):
        super().__init__(env)
        self._trajectory = []
        self._action_to_malmo_command_array = ['move 1', 'turn -1', 'turn 1']

    def get_trajectory(self):
        return self._trajectory

    def _to_malmo_action(self, action_index):
        return self._action_to_malmo_command_array[action_index]

    def step(self, action):
        self._trajectory.append(self._to_malmo_action(action))
        o, r, d, i = super().step(action)

        return o, r, d, i


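# Example of the recorded form (illustrative): after actions [0, 1, 0],
# get_trajectory() returns ['move 1', 'turn -1', 'move 1']. Note that the
# turns are recorded in Malmo's discrete 'turn' syntax rather than as the
# 'camera' commands EnvWrapper actually sends.

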
class DummyEnv(gym.Env):
    '''
    Placeholder environment that only exposes the observation and action
    spaces of the fully wrapped environment (two stacked 84x84x3 frames,
    hence a depth of 6). It is never reset or stepped.
    '''

    def __init__(self):
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(84, 84, 6),
            dtype=np.uint8)
        self.action_space = gym.spaces.Discrete(3)


# Define a function to create a MineRL environment
def create_env(config):
    mission = config['mission']
    # Give every (worker, vectorized sub-env) pair its own Malmo port
    # range so concurrent Minecraft instances don't collide.
    port = 1000 * config.worker_index + config.vector_index
    print('*********************************************')
    print(f'* Worker {config.worker_index} creating from '
          f'mission: {mission}, port {port}')
    print('*********************************************')

    if config.worker_index == 0:
        # The first environment is only used for checking the action
        # and observation spaces. By using a dummy environment, there's
        # no need to spin up a Minecraft instance behind it, saving some
        # CPU resources on the head node.
        return DummyEnv()

    env = EnvWrapper(mission, port)
    env = TrackingEnv(env)
    env = FrameStack(env, 2)

    return env


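# Sketch of how this factory is typically registered with RLlib (the env
# name 'minerl_env' and the config keys are illustrative, not taken from
# this file):
#
#     from ray.tune.registry import register_env
#     register_env('minerl_env', create_env)
#
# Training with env='minerl_env' and
# config={'env_config': {'mission': '<mission xml>'}} then makes RLlib
# pass an EnvContext carrying worker_index and vector_index, as used above.

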
def create_env_for_rollout(config):
    mission = config['mission']
    port = 1000 * config.worker_index + config.vector_index
    print('*********************************************')
    print(f'* Worker {config.worker_index} creating from '
          f'mission: {mission}, port {port}')
    print('*********************************************')

    env = EnvWrapper(mission, port)
    env = TrackingEnv(env)
    env = FrameStack(env, 2)
    env = TrajectoryWrapper(env)

    return env


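# For a quick local test outside RLlib, `config` can be faked with a dict
# subclass carrying the two attributes that RLlib's EnvContext provides
# (an illustrative stand-in, not part of the original code):
#
#     class FakeContext(dict):
#         worker_index = 1
#         vector_index = 0
#
#     env = create_env_for_rollout(FakeContext(mission='mission.xml'))
#     obs = env.reset()
#     obs, reward, done, info = env.step(env.action_space.sample())
#     print(env.get_trajectory())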