mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-23 20:00:06 -05:00
Compare commits
2 Commits
release_up
...
azureml-sd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
879a272a8d | ||
|
|
bc65bde097 |
@@ -20,11 +20,11 @@ Using these samples, you will be able to do the following.
|
||||
|
||||
| File/folder | Description |
|
||||
|-------------------|--------------------------------------------|
|
||||
| [README.md](README.md) | This README file. |
|
||||
| [devenv_setup.ipynb](setup/devenv_setup.ipynb) | Notebook to setup development environment for Azure ML RL |
|
||||
| [cartpole_ci.ipynb](cartpole-on-compute-instance/cartpole_ci.ipynb) | Notebook to train a Cartpole playing agent on an Azure ML Compute Instance |
|
||||
| [cartpole_cc.ipynb](cartpole-on-single-compute/cartpole_cc.ipynb) | Notebook to train a Cartpole playing agent on an Azure ML Compute Cluster (single node) |
|
||||
| [pong_rllib.ipynb](atari-on-distributed-compute/pong_rllib.ipynb) | Notebook to train Pong agent using RLlib on multiple compute targets |
|
||||
| [minecraft.ipynb](minecraft-on-distributed-compute/minecraft.ipynb) | Notebook to train an agent to navigate through a lava maze in the Minecraft game |
|
||||
|
||||
## Prerequisites
|
||||
|
||||
@@ -111,7 +111,7 @@ contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additio
|
||||
|
||||
For more on SDK concepts, please refer to [notebooks](https://github.com/Azure/MachineLearningNotebooks).
|
||||
|
||||
**Please let us know your feedback.**
|
||||
**Please let us know your [feedback](https://github.com/Azure/MachineLearningNotebooks/labels/Reinforcement%20Learning).**
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
# Image for training MineRL/Malmo reinforcement-learning agents on Azure ML.
# Base: Azure ML OpenMPI 3.1.2 on Ubuntu 18.04.
FROM mcr.microsoft.com/azureml/base:openmpi3.1.2-ubuntu18.04

# Install some basic utilities (xvfb/python-opengl/mesa-utils provide the
# headless display stack Minecraft needs; the rest are build/debug tools).
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    cpio \
    git \
    bzip2 \
    libx11-6 \
    tmux \
    htop \
    gcc \
    xvfb \
    python-opengl \
    x11-xserver-utils \
    ffmpeg \
    mesa-utils \
    nano \
    vim \
    rsync \
    && rm -rf /var/lib/apt/lists/*

# Create a working directory
RUN mkdir /app
WORKDIR /app

# Install Minecraft needed libraries.
# NOTE(review): the JDK 8 packages are pinned to exact Ubuntu package versions
# (8u162-b12-1); if the archive drops these versions the build breaks — confirm
# the pins are still resolvable when updating the base image.
RUN mkdir -p /usr/share/man/man1 && \
    sudo apt-get update && \
    sudo apt-get install -y \
    openjdk-8-jre-headless=8u162-b12-1 \
    openjdk-8-jdk-headless=8u162-b12-1 \
    openjdk-8-jre=8u162-b12-1 \
    openjdk-8-jdk=8u162-b12-1

# Create a Python 3.7 environment
RUN conda install conda-build \
    && conda create -y --name py37 python=3.7.3 \
    && conda clean -ya
ENV CONDA_DEFAULT_ENV=py37

# Install minerl
RUN pip install --upgrade --user minerl

# Training/runtime Python dependencies; ray is pinned to 0.8.3 for RLlib/Tune.
RUN pip install \
    pandas \
    matplotlib \
    numpy \
    scipy \
    azureml-defaults \
    tensorboardX \
    tensorflow==1.15rc2 \
    tabulate \
    dm_tree \
    lz4 \
    ray==0.8.3 \
    ray[rllib]==0.8.3 \
    ray[tune]==0.8.3

# Overlay patched Malmo client sources (e.g. MalmoEnvServer.java) into the
# minerl package installed above.
COPY patch_files/* /root/.local/lib/python3.7/site-packages/minerl/env/Malmo/Minecraft/src/main/java/com/microsoft/Malmo/Client/

# Start minerl to pre-fetch minerl files (saves time when starting minerl during training)
RUN xvfb-run -a -s "-screen 0 1400x900x24" python -c "import gym; import minerl; env = gym.make('MineRLTreechop-v0'); env.close();"

# NOTE(review): malmo is installed from test.pypi.org — a pre-release index;
# confirm this is still the intended source.
RUN pip install --index-url https://test.pypi.org/simple/ malmo && \
    python -c "import malmo.minecraftbootstrap; malmo.minecraftbootstrap.download();"

ENV MALMO_XSD_PATH="/app/MalmoPlatform/Schemas"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,939 @@
|
||||
// --------------------------------------------------------------------------------------------------
|
||||
// Copyright (c) 2016 Microsoft Corporation
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||
// associated documentation files (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge, publish, distribute,
|
||||
// sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all copies or
|
||||
// substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
|
||||
// NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
// --------------------------------------------------------------------------------------------------
|
||||
|
||||
package com.microsoft.Malmo.Client;
|
||||
|
||||
import com.microsoft.Malmo.MalmoMod;
|
||||
import com.microsoft.Malmo.MissionHandlerInterfaces.IWantToQuit;
|
||||
import com.microsoft.Malmo.Schemas.MissionInit;
|
||||
import com.microsoft.Malmo.Utils.TCPUtils;
|
||||
|
||||
import net.minecraft.profiler.Profiler;
|
||||
import com.microsoft.Malmo.Utils.TimeHelper;
|
||||
|
||||
import net.minecraftforge.common.config.Configuration;
|
||||
import java.io.*;
|
||||
import java.net.ServerSocket;
|
||||
import java.net.Socket;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.Arrays;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.locks.Condition;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.Hashtable;
|
||||
import com.microsoft.Malmo.Utils.TCPInputPoller;
|
||||
import java.util.logging.Level;
|
||||
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
/**
|
||||
* MalmoEnvServer - service supporting OpenAI gym "environment" for multi-agent Malmo missions.
|
||||
*/
|
||||
public class MalmoEnvServer implements IWantToQuit {
    // Shared profiler used to time the <Step> handling sections (root/commandProcessing/...).
    private static Profiler profiler = new Profiler();
    // Count of <Step> messages handled so far; used to gate periodic profiler dumps.
    private static int nsteps = 0;
    // When true, profiling data is printed every 100 steps in serve().
    private static boolean debug = false;

    // Protocol hello prefix; the client must send hello + version as its first message.
    private static String hello = "<MalmoEnv" ;
|
||||
|
||||
/**
 * Mutable per-mission state shared between the socket handler threads and the
 * Malmo client state machine. All access must happen while holding {@code lock}.
 */
private class EnvState {

    // Mission parameters:
    String missionInit = null;   // Raw <MissionInit> XML for the current mission.
    String token = null;         // Init token "experimentId:role:reset" for this agent.
    String experimentId = null;
    int agentCount = 0;          // Number of agents in the mission.
    int reset = 0;               // Reset (episode) counter.
    boolean quit = false;        // Set when the remote side asked to quit the mission.
    boolean synchronous = false; // True when lock-step (tick-synchronized) mode is on.
    Long seed = null;            // Optional world seed (null = unseeded).

    // OpenAI gym state:
    boolean done = false;        // Episode-terminated flag.
    double reward = 0.0;         // Reward accumulated since the last step reply.
    byte[] obs = null;           // Latest video observation frame, or null if none buffered.
    String info = "";            // Latest info JSON string.
    LinkedList<String> commands = new LinkedList<String>(); // Commands queued for the client.
}
|
||||
|
||||
private static boolean envPolicy = false; // Are we configured by config policy?

// Synchronize on EnvState:
private Lock lock = new ReentrantLock();
private Condition cond = lock.newCondition();

private EnvState envState = new EnvState();

// Init tokens ("experimentId:role:reset") produced on start-up and consumed by
// <Find>/<Close>; the value is the integrated server port (0 = not yet known).
private Hashtable<String, Integer> initTokens = new Hashtable<String, Integer>();

static final long COND_WAIT_SECONDS = 3; // Max wait in seconds before timing out (and replying to RPC).
static final int BYTES_INT = 4;
static final int BYTES_DOUBLE = 8;
// Charset for all wire messages.
private static final Charset utf8 = Charset.forName("UTF-8");

// Service uses a single per-environment client connection - initiated by the remote environment.

private int port;                     // Port the service listens on.
private TCPInputPoller missionPoller; // Used for command parsing and not actual communication.
private String version;               // Protocol version checked in the hello handshake.
||||
|
||||
// AOG: From running experiments, I've found that MineRL can get stuck resetting the
// environment which causes huge delays while we wait for the Python side to time
// out and restart the Minecraft instance. Minecraft itself is normally in a recoverable
// state, but the MalmoEnvServer instance will be blocked in a tight spin loop trying
// to handle a Peek request from the Python client. To unstick things, I've added this
// flag that can be set when we know things are in a bad state to abort the peek request.
// WARNING: THIS IS ONLY TREATING THE SYMPTOM AND NOT THE ROOT CAUSE
// The reason things are getting stuck is because the player is either dying or we're
// receiving a quit request while an episode reset is in progress.
private boolean abortRequest;

/** Request that any in-flight peek() spin loop bail out (see abortRequest above). */
public void abort() {
    System.out.println("AOG: MalmoEnvServer.abort");
    abortRequest = true;
}
|
||||
|
||||
/***
 * Malmo "Env" service.
 * @param version protocol version expected in the client hello handshake.
 * @param port the port the service listens on.
 * @param missionPoller for plugging into existing comms handling.
 */
public MalmoEnvServer(String version, int port, TCPInputPoller missionPoller) {
    this.version = version;
    this.missionPoller = missionPoller;
    this.port = port;
    // AOG - Assume we don't want to be aborting in the first place.
    this.abortRequest = false;
}
|
||||
|
||||
/** Initialize malmo env configuration. For now either on or "legacy" AgentHost protocol.
 * Reads the boolean "env" key from the mod's ENV_CONFIGS section (defaults to false). */
static public void update(Configuration configs) {
    envPolicy = configs.get(MalmoMod.ENV_CONFIGS, "env", "false").getBoolean();
}
|
||||
|
||||
/** @return true when the MalmoEnv (gym) protocol is enabled by config policy. */
public static boolean isEnv() {
    return envPolicy;
}
|
||||
|
||||
/**
 * Start servicing the MalmoEnv protocol.
 *
 * Accept-loop: one handler thread is spawned per client connection. Each
 * message on the wire is a 4-byte big-endian length header followed by a
 * UTF-8 XML-ish command; dispatch is on the command's tag prefix.
 * This method never returns normally (infinite accept loop).
 * @throws IOException if the listen socket cannot be created.
 */
public void serve() throws IOException {

    ServerSocket serverSocket = new ServerSocket(port);
    // Prefer low latency over bandwidth/connection time for this RPC-style link.
    serverSocket.setPerformancePreferences(0,2,1);

    while (true) {
        try {
            final Socket socket = serverSocket.accept();
            socket.setTcpNoDelay(true);

            Thread thread = new Thread("EnvServerSocketHandler") {
                public void run() {
                    // Set once a mission has successfully started on this connection;
                    // controls whether a disconnect should quit the mission.
                    boolean running = false;
                    try {
                        // First message must be the protocol/version hello.
                        checkHello(socket);

                        while (true) {
                            // Read one length-prefixed command.
                            DataInputStream din = new DataInputStream(socket.getInputStream());
                            int hdr = din.readInt();
                            byte[] data = new byte[hdr];
                            din.readFully(data);

                            String command = new String(data, utf8);

                            if (command.startsWith("<Step")) {
                                // Profile the step handling; dump stats every 100 steps in debug mode.
                                profiler.startSection("root");
                                long start = System.nanoTime();
                                step(command, socket, din);
                                profiler.endSection();
                                if (nsteps % 100 == 0 && debug){
                                    List<Profiler.Result> dat = profiler.getProfilingData("root");
                                    for(int qq = 0; qq < dat.size(); qq++){
                                        Profiler.Result res = dat.get(qq);
                                        System.out.println(res.profilerName + " " + res.totalUsePercentage + " "+ res.usePercentage);
                                    }
                                }
                            } else if (command.startsWith("<Peek")) {
                                peek(command, socket, din);
                            } else if (command.startsWith("<Init")) {
                                init(command, socket);
                            } else if (command.startsWith("<Find")) {
                                find(command, socket);
                            } else if (command.startsWith("<MissionInit")) {
                                if (missionInit(din, command, socket))
                                {
                                    running = true;
                                }
                            } else if (command.startsWith("<Quit")) {
                                quit(command, socket);
                                profiler.profilingEnabled = false;
                            } else if (command.startsWith("<Exit")) {
                                exit(command, socket);
                                profiler.profilingEnabled = false;
                            } else if (command.startsWith("<Close")) {
                                close(command, socket);
                                profiler.profilingEnabled = false;
                            } else if (command.startsWith("<Status")) {
                                status(command, socket);
                            } else if (command.startsWith("<Echo")) {
                                // Echo is answered inline: wrap and send straight back.
                                command = "<Echo>" + command + "</Echo>";
                                data = command.getBytes(utf8);
                                hdr = data.length;

                                DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
                                dout.writeInt(hdr);
                                dout.write(data, 0, hdr);
                                dout.flush();
                            } else {
                                throw new IOException("Unknown env service command");
                            }
                        }
                    } catch (IOException ioe) {
                        // Normal path on client disconnect as well as on real errors.
                        TCPUtils.Log(Level.SEVERE, "MalmoEnv socket error: " + ioe + " (can be on disconnect)");
                        try {
                            if (running) {
                                // A mission was in progress - treat disconnect as a quit.
                                TCPUtils.Log(Level.INFO,"Want to quit on disconnect.");
                                System.out.println("[LOGTOPY] " + "Want to quit on disconnect.");
                                setWantToQuit();
                            }
                            socket.close();
                        } catch (IOException ioe2) {
                            // Best-effort close; nothing more we can do.
                        }
                    }
                }
            };
            thread.start();
        } catch (IOException ioe) {
            TCPUtils.Log(Level.SEVERE, "MalmoEnv service exits on " + ioe);
        }
    }
}
|
||||
|
||||
private void checkHello(Socket socket) throws IOException {
|
||||
DataInputStream din = new DataInputStream(socket.getInputStream());
|
||||
int hdr = din.readInt();
|
||||
if (hdr <= 0 || hdr > hello.length() + 8) // Version number may be somewhat longer in future.
|
||||
throw new IOException("Invalid MalmoEnv hello header length");
|
||||
byte[] data = new byte[hdr];
|
||||
din.readFully(data);
|
||||
if (!new String(data).startsWith(hello + version))
|
||||
throw new IOException("MalmoEnv invalid protocol or version - expected " + hello + version);
|
||||
}
|
||||
|
||||
// Handler for <MissionInit> messages.
// Wire format: the <MissionInit...> command already read, followed by one more
// length-prefixed id string "experimentId:role:reset:agentCount:synchronous[:seed]".
// Replies with a single int: 1 if the mission started (or was pre-started) and all
// previous init tokens were consumed, else 0.
// Returns that same success flag so serve() can mark the connection "running".
private boolean missionInit(DataInputStream din, String command, Socket socket) throws IOException {

    String ipOriginator = socket.getInetAddress().getHostName();

    // Read the colon-separated mission id record.
    int hdr;
    byte[] data;
    hdr = din.readInt();
    data = new byte[hdr];
    din.readFully(data);
    String id = new String(data, utf8);

    TCPUtils.Log(Level.INFO,"Mission Init" + id);

    String[] token = id.split(":");
    String experimentId = token[0];
    int role = Integer.parseInt(token[1]);
    int reset = Integer.parseInt(token[2]);
    int agentCount = Integer.parseInt(token[3]);
    Boolean isSynchronous = Boolean.parseBoolean(token[4]);
    Long seed = null;
    if(token.length > 5)
        seed = Long.parseLong(token[5]);

    if(isSynchronous && agentCount > 1){
        throw new IOException("Synchronous mode currently does not support multiple agents.");
    }
    // NOTE(review): this clobbers the service's listen-port field; it appears to be
    // reused as "integrated server port not yet known" for find() — confirm intent.
    port = -1;
    boolean allTokensConsumed = true;
    boolean started = false;

    lock.lock();
    try {
        if (role == 0) {
            // Primary agent: drop the token from the previous reset, then (pre)start
            // the mission under this reset's token if not already started.
            String previousToken = experimentId + ":0:" + (reset - 1);
            initTokens.remove(previousToken);

            String myToken = experimentId + ":0:" + reset;
            if (!initTokens.containsKey(myToken)) {
                TCPUtils.Log(Level.INFO,"(Pre)Start " + role + " reset " + reset);
                started = startUp(command, ipOriginator, experimentId, reset, agentCount, myToken, seed, isSynchronous);
                if (started)
                    initTokens.put(myToken, 0);
            } else {
                started = true; // Pre-started previously.
            }

            // Check that all previous tokens have been consumed. If not don't proceed to mission.
            // Wait briefly on the condition in case other roles are about to consume theirs.
            allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount);
            if (!allTokensConsumed) {
                try {
                    cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS);
                } catch (InterruptedException ie) {
                    // Timed wait only; fall through and re-check.
                }
                allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount);
            }
        } else {
            // Secondary agents start directly under their own token.
            TCPUtils.Log(Level.INFO, "Start " + role + " reset " + reset);
            started = startUp(command, ipOriginator, experimentId, reset, agentCount, experimentId + ":" + role + ":" + reset, seed, isSynchronous);
        }
    } finally {
        lock.unlock();
    }

    // Reply with the combined success flag.
    DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
    dout.writeInt(BYTES_INT);
    dout.writeInt(allTokensConsumed && started ? 1 : 0);
    dout.flush();

    dout.flush();

    return allTokensConsumed && started;
}
|
||||
|
||||
private boolean areAllTokensConsumed(String experimentId, int reset, int agentCount) {
|
||||
boolean allTokensConsumed = true;
|
||||
for (int i = 1; i < agentCount; i++) {
|
||||
String tokenForAgent = experimentId + ":" + i + ":" + (reset - 1);
|
||||
if (initTokens.containsKey(tokenForAgent)) {
|
||||
TCPUtils.Log(Level.FINE,"Mission init - unconsumed " + tokenForAgent);
|
||||
allTokensConsumed = false;
|
||||
}
|
||||
}
|
||||
return allTokensConsumed;
|
||||
}
|
||||
|
||||
// Reset all per-mission state in envState and hand the <MissionInit> command to
// the mission poller. Caller must hold `lock`.
// Returns true when the mission was accepted for start (see startUpMission).
private boolean startUp(String command, String ipOriginator, String experimentId, int reset, int agentCount, String myToken, Long seed, Boolean isSynchronous) throws IOException {

    // Clear out mission state
    envState.reward = 0.0;
    envState.commands.clear();
    envState.obs = null;
    envState.info = "";

    envState.missionInit = command;
    envState.done = false;
    envState.quit = false;
    envState.token = myToken;
    envState.experimentId = experimentId;
    envState.agentCount = agentCount;
    envState.reset = reset;
    envState.synchronous = isSynchronous;
    envState.seed = seed;

    return startUpMission(command, ipOriginator);
}
|
||||
|
||||
// Forward the mission-start command to the existing TCPInputPoller machinery and
// interpret its length-prefixed reply: "MALMOOK" -> mission starting (true);
// "MALMOBUSY" -> client busy, flag quit so the current mission winds down (false).
private boolean startUpMission(String command, String ipOriginator) throws IOException {

    if (missionPoller == null)
        return false;

    // The poller writes its reply into an in-memory stream which we then re-parse.
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    DataOutputStream dos = new DataOutputStream(baos);

    missionPoller.commandReceived(command, ipOriginator, dos);

    dos.flush();
    byte[] reply = baos.toByteArray();
    ByteArrayInputStream bais = new ByteArrayInputStream(reply);
    DataInputStream dis = new DataInputStream(bais);
    int hdr = dis.readInt();
    byte[] replyBytes = new byte[hdr];
    dis.readFully(replyBytes);

    // NOTE(review): decoded with the platform default charset; replies are ASCII
    // ("MALMOOK"/"MALMOBUSY") so this is benign in practice.
    String replyStr = new String(replyBytes);
    if (replyStr.equals("MALMOOK")) {
        TCPUtils.Log(Level.INFO, "MalmoEnvServer Mission starting ...");
        return true;
    } else if (replyStr.equals("MALMOBUSY")) {
        TCPUtils.Log(Level.INFO, "MalmoEnvServer Busy - I want to quit");
        this.envState.quit = true;
    }
    return false;
}
|
||||
|
||||
private static final int stepTagLength = "<Step_>".length(); // Step with option code.

// Synchronous (lock-step) handler for <Step_N>actions</Step_N> messages.
// Queues the actions for the client state machine, requests exactly one game
// tick via TimeHelper.SyncManager, waits for it to complete, then replies with:
//   [obs length][obs bytes][10][reward double][done byte][sent byte]
//   and, when the option code requests info (0 or 2), [info length][info bytes].
// The spin loops' ordering (request tick -> wait tick -> read observation)
// is load-bearing; do not reorder.
private synchronized void stepSync(String command, Socket socket, DataInputStream din) throws IOException
{
    nsteps += 1;
    profiler.startSection("commandProcessing");
    // Strip "<Step_N>" ... "</Step_N>" to get the newline-separated action list.
    String actions = command.substring(stepTagLength, command.length() - (stepTagLength + 2));
    // Single digit option code embedded in the tag; 0/2 => include info in the reply.
    int options = Character.getNumericValue(command.charAt(stepTagLength - 2));
    boolean withInfo = options == 0 || options == 2;

    // Prepare to write data to the client.
    DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
    double reward = 0.0;
    boolean done;
    byte[] obs;
    String info = "";
    boolean sent = false;

    lock.lock();
    try {
        done = envState.done;

        // TODO Handle when the environment is done.

        // Process the actions: split multi-command payloads on newlines.
        if (actions.contains("\n")) {
            String[] cmds = actions.split("\\n");
            for(String cmd : cmds) {
                envState.commands.add(cmd);
            }
        } else {
            if (!actions.isEmpty())
                envState.commands.add(actions);
        }
        sent = true;

        profiler.endSection(); //cmd
        profiler.startSection("requestTick");

        // Now wait to run a tick.
        // If synchronous mode is off then we should see if want to quit is true.
        while(!TimeHelper.SyncManager.requestTick() && !done ){Thread.yield();}

        profiler.endSection();
        profiler.startSection("waitForTick");

        // Then wait until the tick is finished.
        while(!TimeHelper.SyncManager.isTickCompleted() && !done ){ Thread.yield();}

        profiler.endSection();
        profiler.startSection("getObservation");
        // After which, get the observations.
        obs = getObservation(done);

        profiler.endSection();
        profiler.startSection("getInfo");

        // Pick up rewards.
        reward = envState.reward;
        if (withInfo) {
            info = envState.info;
        }
        done = envState.done;
        // Consume the state we just read so the next step starts fresh.
        envState.info = null;
        envState.obs = null;
        envState.reward = 0.0;

        profiler.endSection();
    } finally {
        lock.unlock();
    }

    // Write the reply outside the lock.
    profiler.startSection("writeObs");
    dout.writeInt(obs.length);
    dout.write(obs);

    // Reward double + done byte + sent byte.
    dout.writeInt(BYTES_DOUBLE + 2);
    dout.writeDouble(reward);
    dout.writeByte(done ? 1 : 0);
    dout.writeByte(sent ? 1 : 0);

    if (withInfo) {
        byte[] infoBytes = info.getBytes(utf8);
        dout.writeInt(infoBytes.length);
        dout.write(infoBytes);
    }

    profiler.endSection(); //write obs
    profiler.startSection("flush");

    dout.flush();
    profiler.endSection(); // flush
}
|
||||
// Handler for <Step_> messages. Single digit option code after _ specifies if turnkey and info are included in message.
|
||||
private void step(String command, Socket socket, DataInputStream din) throws IOException {
|
||||
if(envState.synchronous){
|
||||
stepSync(command, socket, din);
|
||||
}
|
||||
else{
|
||||
System.out.println("[ERROR] Asynchronous stepping is not supported in MineRL.");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Handler for <Peek> messages.
// Drives game ticks until the server "fires the pistol" (mission fully started),
// runs two more ticks so the first observation propagates, then replies with
// [obs][info][done]. Can be aborted via abort()/abortRequest, in which case an
// empty reply with done=0 is sent so the Python wrapper can retry the reset.
private void peek(String command, Socket socket, DataInputStream din) throws IOException {

    DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
    byte[] obs;
    boolean done;
    String info = "";
    // AOG - As we've only seen issues with the peek request, I've focused my changes to just
    // this function. Initially we want to be optimistic and assume we're not going to abort
    // the request and my observations of event timings indicate that there is plenty of time
    // between the peek request being received and the reset failing, so a race condition is
    // unlikely.
    abortRequest = false;

    lock.lock();

    try {
        // Tick the game until the mission start signal arrives (or we are aborted).
        while(!TimeHelper.SyncManager.hasServerFiredPistol() && !abortRequest){

            // Now wait to run a tick
            while(!TimeHelper.SyncManager.requestTick() && !abortRequest){Thread.yield();}

            // Then wait until the tick is finished
            while(!TimeHelper.SyncManager.isTickCompleted() && !abortRequest){ Thread.yield();}

            Thread.yield();
        }

        if (abortRequest) {
            System.out.println("AOG: Aborting peek request");
            // AOG - We detect the lack of observation within our Python wrapper and throw a slightly
            // different exception that by-passes MineRL's automatic clean up code. If we were to report
            // 'done', then MineRL detects this as a runtime error and kills the Minecraft process
            // triggering a lengthy restart. So far from testing, Minecraft itself is fine and we can
            // retry the reset; it's only the tight loops above that were causing things to stall and
            // timeout.
            // No observation
            dout.writeInt(0);
            // No info
            dout.writeInt(0);
            // Done flag record: length 1, value 0 (not done - see note above).
            dout.writeInt(1);
            dout.writeByte(0);
            dout.flush();
            return;
        }

        // Wait two ticks for the first observation from server to be propagated.
        while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();}

        // Then wait until the tick is finished
        while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();}

        while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();}

        // Then wait until the tick is finished
        while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();}

        obs = getObservation(false);

        done = envState.done;
        info = envState.info;

    } finally {
        lock.unlock();
    }

    // Reply: observation, info, then the done flag.
    dout.writeInt(obs.length);
    dout.write(obs);

    byte[] infoBytes = info.getBytes(utf8);
    dout.writeInt(infoBytes.length);
    dout.write(infoBytes);

    dout.writeInt(1);
    dout.writeByte(done ? 1 : 0);

    dout.flush();
}
|
||||
|
||||
// Get the current observation. If none and not done wait for a short time.
|
||||
public byte[] getObservation(boolean done) {
|
||||
byte[] obs = envState.obs;
|
||||
if (obs == null){
|
||||
System.out.println("[ERROR] Video observation is null; please notify the developer.");
|
||||
}
|
||||
return obs;
|
||||
}
|
||||
|
||||
// Handler for <Find> messages - used by non-zero roles to discover integrated server port from primary (role 0) service.

private final static int findTagLength = "<Find>".length();

// Wire format: <Find>experimentId:role:reset</Find>. Purges the caller's token
// from the previous reset (signalling any waiter in missionInit), then looks up
// the port for the requested token, waiting up to COND_WAIT_SECONDS if it has
// not been produced yet. Replies with the port, or 0 when still unknown.
private void find(String command, Socket socket) throws IOException {

    Integer port;
    lock.lock();
    try {
        String token = command.substring(findTagLength, command.length() - (findTagLength + 1));
        TCPUtils.Log(Level.INFO, "Find token? " + token);

        // Purge previous token.
        String[] tokenSplits = token.split(":");
        String experimentId = tokenSplits[0];
        int role = Integer.parseInt(tokenSplits[1]);
        int reset = Integer.parseInt(tokenSplits[2]);

        String previousToken = experimentId + ":" + role + ":" + (reset - 1);
        initTokens.remove(previousToken);
        cond.signalAll();

        // Check for next token. Wait for a short time if not already produced.
        port = initTokens.get(token);
        if (port == null) {
            try {
                cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS);
            } catch (InterruptedException ie) {
                // Timed wait only; fall through and re-check.
            }
            port = initTokens.get(token);
            if (port == null) {
                port = 0;
                TCPUtils.Log(Level.INFO,"Role " + role + " reset " + reset + " waiting for token.");
            }
        }
    } finally {
        lock.unlock();
    }

    DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
    dout.writeInt(BYTES_INT);
    dout.writeInt(port);
    dout.flush();
}
|
||||
|
||||
/** @return true when the current mission runs in lock-step (tick-synchronized) mode. */
public boolean isSynchronous(){
    return envState.synchronous;
}
|
||||
|
||||
// Handler for <Init> messages. These reset the service so use with care!
// Discards all outstanding init tokens and acknowledges with a single int 1.
private void init(String command, Socket socket) throws IOException {
    lock.lock();
    try {
        initTokens = new Hashtable<String, Integer>();
        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(1);
        dout.flush();
    } finally {
        lock.unlock();
    }
}
|
||||
|
||||
// Handler for <Quit> (quit mission) messages.
// Flags the mission to quit (if not already done), drives one tick so the quit
// can take effect, then replies with the current done flag (1 = mission over).
private void quit(String command, Socket socket) throws IOException {
    lock.lock();
    try {
        if (!envState.done){
            envState.quit = true;
        }

        // Run one tick so the state machine can observe the quit flag.
        while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();}

        // Then wait until the tick is finished
        while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();}

        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(envState.done ? 1 : 0);
        dout.flush();
    } finally {
        lock.unlock();
    }
}
|
||||
|
||||
private final static int closeTagLength = "<Close>".length();

// Handler for <Close> messages.
// Wire format: <Close>token</Close>. Consumes (removes) the given init token
// and acknowledges with a single int 1.
private void close(String command, Socket socket) throws IOException {
    lock.lock();
    try {
        String token = command.substring(closeTagLength, command.length() - (closeTagLength + 1));

        initTokens.remove(token);

        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(1);
        dout.flush();
    } finally {
        lock.unlock();
    }
}
|
||||
|
||||
// Handler for <Status> messages.
// Replies with a length-prefixed JSON status string (currently always "{}").
private void status(String command, Socket socket) throws IOException {
    lock.lock();
    try {
        String status = "{}"; // TODO Possibly have something more interesting to report.
        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());

        byte[] statusBytes = status.getBytes(utf8);
        dout.writeInt(statusBytes.length);
        dout.write(statusBytes);

        dout.flush();
    } finally {
        lock.unlock();
    }
}
|
||||
|
||||
// Handler for <Exit> messages. These "kill the service" temporarily so use with care!
private void exit(String command, Socket socket) throws IOException {
    // NOTE(review): the service lock is deliberately not taken here (see the
    // commented-out calls) — we may terminate before getting a chance to release it.
    // lock.lock();
    try {
        // We may exit before we get a chance to reply.
        TimeHelper.SyncManager.setSynchronous(false);

        // Acknowledge before tearing down, in case the reply can still be delivered.
        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(1);
        dout.flush();

        // Presumably terminates the Java client; control may not return here — TODO confirm.
        ClientStateMachine.exitJava();

    } finally {
        // lock.unlock();
    }
}
|
||||
|
||||
// Malmo client state machine interface methods:
|
||||
|
||||
public String getCommand() {
|
||||
try {
|
||||
String command = envState.commands.poll();
|
||||
if (command == null)
|
||||
return "";
|
||||
else
|
||||
return command;
|
||||
} finally {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Called when the current mission ends: aborts any in-flight peek request and
 * clears all per-mission state so the next MissionInit starts clean.
 */
public void endMission() {
    // NOTE(review): runs without the service lock (commented out below) — confirm intended.
    // lock.lock();
    try {
        // AOG - If the mission is ending, we always want to abort requests and they won't
        // be able to progress to completion and will stall.
        System.out.println("AOG: MalmoEnvServer.endMission");
        abort();
        envState.done = true;
        envState.quit = false;
        envState.missionInit = null;

        // Release this mission's token and reset the identifying fields.
        if (envState.token != null) {
            initTokens.remove(envState.token);
            envState.token = null;
            envState.experimentId = null;
            envState.agentCount = 0;
            envState.reset = 0;

            // cond.signalAll();
        }
        // lock.unlock();
    } finally {
    }
}
|
||||
// Record a Malmo "observation" json - as the env info since an environment "obs" is a video frame.
|
||||
public void observation(String info) {
|
||||
// Parsing obs as JSON would be slower but less fragile than extracting the turn_key using string search.
|
||||
// lock.lock();
|
||||
try {
|
||||
// TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <OBSERVATION> Inserting: " + info);
|
||||
envState.info = info;
|
||||
// cond.signalAll();
|
||||
} finally {
|
||||
// lock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public void addRewards(double rewards) {
|
||||
// lock.lock();
|
||||
try {
|
||||
envState.reward += rewards;
|
||||
} finally {
|
||||
// lock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public void addFrame(byte[] frame) {
|
||||
// lock.lock();
|
||||
try {
|
||||
envState.obs = frame; // Replaces current.
|
||||
// cond.signalAll();
|
||||
} finally {
|
||||
// lock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Called when the integrated Minecraft server comes up: publishes its port under
 * the current mission token(s) and wakes any threads waiting on the condition
 * (e.g. Find / MissionInit handlers).
 *
 * @param integrationServerPort port of the newly started integrated server.
 */
public void notifyIntegrationServerStarted(int integrationServerPort) {
    lock.lock();
    try {
        if (envState.token != null) {
            TCPUtils.Log(Level.INFO,"Integration server start up - token: " + envState.token);
            addTokens(integrationServerPort, envState.token, envState.experimentId, envState.agentCount, envState.reset);
            cond.signalAll();
        } else {
            // A server started outside a mission handshake; nothing to publish.
            TCPUtils.Log(Level.WARNING,"No mission token on integration server start up!");
        }
    } finally {
        lock.unlock();
    }
}
|
||||
|
||||
private void addTokens(int integratedServerPort, String myToken, String experimentId, int agentCount, int reset) {
|
||||
initTokens.put(myToken, integratedServerPort);
|
||||
// Place tokens for other agents to find.
|
||||
for (int i = 1; i < agentCount; i++) {
|
||||
String tokenForAgent = experimentId + ":" + i + ":" + reset;
|
||||
initTokens.put(tokenForAgent, integratedServerPort);
|
||||
}
|
||||
}
|
||||
|
||||
// IWantToQuit implementation.
|
||||
|
||||
@Override
|
||||
public boolean doIWantToQuit(MissionInit missionInit) {
|
||||
// lock.lock();
|
||||
try {
|
||||
return envState.quit;
|
||||
} finally {
|
||||
// lock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
public Long getSeed(){
|
||||
return envState.seed;
|
||||
}
|
||||
|
||||
// Flag the mission as quitting and drop out of synchronous (lock-step) mode so
// any thread spinning on tick requests can make progress.
private void setWantToQuit() {
    // lock.lock();
    try {
        envState.quit = true;

    } finally {

        if(TimeHelper.SyncManager.isSynchronous()){
            // We want to desynchronize everything.
            TimeHelper.SyncManager.setSynchronous(false);
        }
        // lock.unlock();
    }
}
|
||||
|
||||
@Override
public void prepare(MissionInit missionInit) {
    // No per-mission preparation is needed by this quit handler.
}
|
||||
|
||||
@Override
public void cleanup() {
    // No per-mission cleanup is needed by this quit handler.
}
|
||||
|
||||
@Override
|
||||
public String getOutcome() {
|
||||
return "Env quit";
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,78 @@
|
||||
FROM mcr.microsoft.com/azureml/base-gpu:openmpi3.1.2-cuda10.0-cudnn7-ubuntu18.04

# Install some basic utilities
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    cpio \
    git \
    bzip2 \
    libx11-6 \
    tmux \
    htop \
    gcc \
    xvfb \
    python-opengl \
    x11-xserver-utils \
    ffmpeg \
    mesa-utils \
    nano \
    vim \
    rsync \
    && rm -rf /var/lib/apt/lists/*

# Create a working directory
RUN mkdir /app
WORKDIR /app

# Create a Python 3.7 environment
RUN conda install conda-build \
    && conda create -y --name py37 python=3.7.3 \
    && conda clean -ya
ENV CONDA_DEFAULT_ENV=py37

# Install libraries needed by Minecraft (pinned OpenJDK 8 builds).
# Fix: clean the apt lists afterwards, matching the first apt-get layer,
# so the package index does not bloat this image layer.
RUN mkdir -p /usr/share/man/man1 && \
    sudo apt-get update && \
    sudo apt-get install -y \
    openjdk-8-jre-headless=8u162-b12-1 \
    openjdk-8-jdk-headless=8u162-b12-1 \
    openjdk-8-jre=8u162-b12-1 \
    openjdk-8-jdk=8u162-b12-1 && \
    sudo rm -rf /var/lib/apt/lists/*

# --user installs minerl under /root/.local, which the patch COPY below relies on.
RUN pip install --upgrade --user minerl

# PyTorch with CUDA 10 installation
RUN conda install -y -c pytorch \
    cuda100=1.0 \
    magma-cuda100=2.4.0 \
    "pytorch=1.1.0=py3.7_cuda10.0.130_cudnn7.5.1_0" \
    torchvision=0.3.0 \
    && conda clean -ya

RUN pip install \
    pandas \
    matplotlib \
    numpy \
    scipy \
    azureml-defaults \
    tensorboardX \
    tensorflow-gpu==1.15rc2 \
    GPUtil \
    tabulate \
    dm_tree \
    lz4 \
    ray==0.8.3 \
    ray[rllib]==0.8.3 \
    ray[tune]==0.8.3

# Overlay patched Malmo client sources into the minerl install from the previous step.
COPY patch_files/* /root/.local/lib/python3.7/site-packages/minerl/env/Malmo/Minecraft/src/main/java/com/microsoft/Malmo/Client/

# Start minerl to pre-fetch minerl files (saves time when starting minerl during training)
RUN xvfb-run -a -s "-screen 0 1400x900x24" python -c "import gym; import minerl; env = gym.make('MineRLTreechop-v0'); env.close();"

RUN pip install --index-url https://test.pypi.org/simple/ malmo && \
    python -c "import malmo.minecraftbootstrap; malmo.minecraftbootstrap.download();"

ENV MALMO_XSD_PATH="/app/MalmoPlatform/Schemas"
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,939 @@
|
||||
// --------------------------------------------------------------------------------------------------
|
||||
// Copyright (c) 2016 Microsoft Corporation
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||
// associated documentation files (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge, publish, distribute,
|
||||
// sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all copies or
|
||||
// substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
|
||||
// NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
// --------------------------------------------------------------------------------------------------
|
||||
|
||||
package com.microsoft.Malmo.Client;
|
||||
|
||||
import com.microsoft.Malmo.MalmoMod;
|
||||
import com.microsoft.Malmo.MissionHandlerInterfaces.IWantToQuit;
|
||||
import com.microsoft.Malmo.Schemas.MissionInit;
|
||||
import com.microsoft.Malmo.Utils.TCPUtils;
|
||||
|
||||
import net.minecraft.profiler.Profiler;
|
||||
import com.microsoft.Malmo.Utils.TimeHelper;
|
||||
|
||||
import net.minecraftforge.common.config.Configuration;
|
||||
import java.io.*;
|
||||
import java.net.ServerSocket;
|
||||
import java.net.Socket;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.Arrays;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.locks.Condition;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.Hashtable;
|
||||
import com.microsoft.Malmo.Utils.TCPInputPoller;
|
||||
import java.util.logging.Level;
|
||||
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
/**
|
||||
* MalmoEnvServer - service supporting OpenAI gym "environment" for multi-agent Malmo missions.
|
||||
*/
|
||||
public class MalmoEnvServer implements IWantToQuit {
|
||||
private static Profiler profiler = new Profiler();
|
||||
private static int nsteps = 0;
|
||||
private static boolean debug = false;
|
||||
|
||||
private static String hello = "<MalmoEnv" ;
|
||||
|
||||
/** Mutable per-mission state shared between socket handler threads and the game loop. */
private class EnvState {

    // Mission parameters:
    String missionInit = null;   // Raw <MissionInit> command for the current mission.
    String token = null;         // This agent's token, formatted experimentId:role:reset.
    String experimentId = null;  // Experiment identifier shared by all agents.
    int agentCount = 0;          // Number of agents in the mission.
    int reset = 0;               // Reset (episode) counter within the experiment.
    boolean quit = false;        // Set when the mission should end.
    boolean synchronous = false; // Lock-step (tick-by-tick) stepping mode.
    Long seed = null;            // Optional world seed; null when unspecified.

    // OpenAI gym state:
    boolean done = false;                 // Episode finished.
    double reward = 0.0;                  // Reward accumulated since the last step reply.
    byte[] obs = null;                    // Latest video frame; consumed by step/peek.
    String info = "";                     // Latest observation JSON ("info").
    LinkedList<String> commands = new LinkedList<String>(); // Pending agent commands.
}
|
||||
|
||||
private static boolean envPolicy = false; // Are we configured by config policy?
|
||||
|
||||
// Synchronize on EnvState
|
||||
|
||||
|
||||
private Lock lock = new ReentrantLock();
|
||||
private Condition cond = lock.newCondition();
|
||||
|
||||
private EnvState envState = new EnvState();
|
||||
|
||||
private Hashtable<String, Integer> initTokens = new Hashtable<String, Integer>();
|
||||
|
||||
static final long COND_WAIT_SECONDS = 3; // Max wait in seconds before timing out (and replying to RPC).
|
||||
static final int BYTES_INT = 4;
|
||||
static final int BYTES_DOUBLE = 8;
|
||||
private static final Charset utf8 = Charset.forName("UTF-8");
|
||||
|
||||
// Service uses a single per-environment client connection - initiated by the remote environment.
|
||||
|
||||
private int port;
|
||||
private TCPInputPoller missionPoller; // Used for command parsing and not actual communication.
|
||||
private String version;
|
||||
|
||||
// AOG: From running experiments, I've found that MineRL can get stuck resetting the
|
||||
// environment which causes huge delays while we wait for the Python side to time
|
||||
// out and restart the Minecraft instace. Minecraft itself is normally in a recoverable
|
||||
// state, but the MalmoEnvServer instance will be blocked in a tight spin loop trying
|
||||
// handling a Peek request from the Python client. To unstick things, I've added this
|
||||
// flag that can be set when we know things are in a bad state to abort the peek request.
|
||||
// WARNING: THIS IS ONLY TREATING THE SYMPTOM AND NOT THE ROOT CAUSE
|
||||
// The reason things are getting stuck is because the player is either dying or we're
|
||||
// receiving a quit request while an episode reset is in progress.
|
||||
// When true, the peek handler bails out of its spin loops (see comment block above).
private boolean abortRequest;
/** Request that any in-flight peek spin loop abandons its wait (see AOG note above this field). */
public void abort() {
    System.out.println("AOG: MalmoEnvServer.abort");
    abortRequest = true;
}
|
||||
|
||||
/***
 * Malmo "Env" service.
 * @param version protocol/mod version string validated during the client hello.
 * @param port the port the service listens on.
 * @param missionPoller for plugging into existing comms handling.
 */
public MalmoEnvServer(String version, int port, TCPInputPoller missionPoller) {
    this.version = version;
    this.missionPoller = missionPoller;
    this.port = port;
    // AOG - Assume we don't want to be aborting in the first place
    this.abortRequest = false;
}
|
||||
|
||||
/** Initialize malmo env configuration. For now either on or "legacy" AgentHost protocol.
 *  Reads the "env" flag (default "false") from the mod's ENV_CONFIGS section. */
static public void update(Configuration configs) {
    envPolicy = configs.get(MalmoMod.ENV_CONFIGS, "env", "false").getBoolean();
}
|
||||
|
||||
public static boolean isEnv() {
|
||||
return envPolicy;
|
||||
}
|
||||
|
||||
/**
 * Start servicing the MalmoEnv protocol: accept client connections forever and
 * spawn one handler thread per connection. Each handler validates the hello,
 * then dispatches length-prefixed XML-tagged commands until the socket closes.
 * @throws IOException if the listening socket cannot be created.
 */
public void serve() throws IOException {

    ServerSocket serverSocket = new ServerSocket(port);
    // Prefer low latency over bandwidth/connection time.
    serverSocket.setPerformancePreferences(0,2,1);

    while (true) {
        try {
            final Socket socket = serverSocket.accept();
            socket.setTcpNoDelay(true);

            Thread thread = new Thread("EnvServerSocketHandler") {
                public void run() {
                    // Set once a mission has started on this connection; governs
                    // whether a disconnect should trigger a mission quit.
                    boolean running = false;
                    try {
                        checkHello(socket);

                        // Protocol loop: every message is a 4-byte length followed by
                        // a UTF-8 command whose leading XML tag selects the handler.
                        while (true) {
                            DataInputStream din = new DataInputStream(socket.getInputStream());
                            int hdr = din.readInt();
                            byte[] data = new byte[hdr];
                            din.readFully(data);

                            String command = new String(data, utf8);

                            if (command.startsWith("<Step")) {

                                profiler.startSection("root");
                                long start = System.nanoTime(); // NOTE(review): unused — candidate for removal.
                                step(command, socket, din);
                                profiler.endSection();
                                // Periodically dump profiling data when debugging.
                                if (nsteps % 100 == 0 && debug){
                                    List<Profiler.Result> dat = profiler.getProfilingData("root");
                                    for(int qq = 0; qq < dat.size(); qq++){
                                        Profiler.Result res = dat.get(qq);
                                        System.out.println(res.profilerName + " " + res.totalUsePercentage + " "+ res.usePercentage);
                                    }
                                }

                            } else if (command.startsWith("<Peek")) {

                                peek(command, socket, din);

                            } else if (command.startsWith("<Init")) {

                                init(command, socket);

                            } else if (command.startsWith("<Find")) {

                                find(command, socket);

                            } else if (command.startsWith("<MissionInit")) {

                                if (missionInit(din, command, socket))
                                {
                                    running = true;
                                }

                            } else if (command.startsWith("<Quit")) {

                                quit(command, socket);

                                profiler.profilingEnabled = false;

                            } else if (command.startsWith("<Exit")) {

                                exit(command, socket);

                                profiler.profilingEnabled = false;

                            } else if (command.startsWith("<Close")) {

                                close(command, socket);
                                profiler.profilingEnabled = false;

                            } else if (command.startsWith("<Status")) {

                                status(command, socket);

                            } else if (command.startsWith("<Echo")) {
                                // Echo is handled inline: wrap and send the message back.
                                command = "<Echo>" + command + "</Echo>";
                                data = command.getBytes(utf8);
                                hdr = data.length;

                                DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
                                dout.writeInt(hdr);
                                dout.write(data, 0, hdr);
                                dout.flush();
                            } else {
                                throw new IOException("Unknown env service command");
                            }
                        }
                    } catch (IOException ioe) {
                        // A disconnect surfaces here as an IOException; usually benign.
                        TCPUtils.Log(Level.SEVERE, "MalmoEnv socket error: " + ioe + " (can be on disconnect)");
                        try {
                            if (running) {
                                // A mission was in progress — ask it to quit.
                                TCPUtils.Log(Level.INFO,"Want to quit on disconnect.");

                                System.out.println("[LOGTOPY] " + "Want to quit on disconnect.");
                                setWantToQuit();
                            }
                            socket.close();
                        } catch (IOException ioe2) {
                        }
                    }
                }
            };
            thread.start();
        } catch (IOException ioe) {
            TCPUtils.Log(Level.SEVERE, "MalmoEnv service exits on " + ioe);
        }
    }
}
|
||||
|
||||
private void checkHello(Socket socket) throws IOException {
|
||||
DataInputStream din = new DataInputStream(socket.getInputStream());
|
||||
int hdr = din.readInt();
|
||||
if (hdr <= 0 || hdr > hello.length() + 8) // Version number may be somewhat longer in future.
|
||||
throw new IOException("Invalid MalmoEnv hello header length");
|
||||
byte[] data = new byte[hdr];
|
||||
din.readFully(data);
|
||||
if (!new String(data).startsWith(hello + version))
|
||||
throw new IOException("MalmoEnv invalid protocol or version - expected " + hello + version);
|
||||
}
|
||||
|
||||
// Handler for <MissionInit> messages.
// Reads an id of the form experimentId:role:reset:agentCount:synchronous[:seed],
// starts (or re-uses a pre-started) mission, and replies 1/0 for success/failure.
private boolean missionInit(DataInputStream din, String command, Socket socket) throws IOException {

    String ipOriginator = socket.getInetAddress().getHostName();

    // Read the length-prefixed mission id that follows the command.
    int hdr;
    byte[] data;
    hdr = din.readInt();
    data = new byte[hdr];
    din.readFully(data);
    String id = new String(data, utf8);

    TCPUtils.Log(Level.INFO,"Mission Init" + id);

    String[] token = id.split(":");
    String experimentId = token[0];
    int role = Integer.parseInt(token[1]);
    int reset = Integer.parseInt(token[2]);
    int agentCount = Integer.parseInt(token[3]);
    Boolean isSynchronous = Boolean.parseBoolean(token[4]);
    Long seed = null;
    if(token.length > 5)
        seed = Long.parseLong(token[5]);

    if(isSynchronous && agentCount > 1){
        throw new IOException("Synchronous mode currently does not support multiple agents.");
    }
    port = -1;
    boolean allTokensConsumed = true;
    boolean started = false;

    lock.lock();
    try {
        if (role == 0) {
            // Primary agent: retire the token from the previous reset …
            String previousToken = experimentId + ":0:" + (reset - 1);
            initTokens.remove(previousToken);

            // … and start the mission unless it was already pre-started.
            String myToken = experimentId + ":0:" + reset;
            if (!initTokens.containsKey(myToken)) {
                TCPUtils.Log(Level.INFO,"(Pre)Start " + role + " reset " + reset);
                started = startUp(command, ipOriginator, experimentId, reset, agentCount, myToken, seed, isSynchronous);
                if (started)
                    initTokens.put(myToken, 0);
            } else {
                started = true; // Pre-started previously.
            }

            // Check that all previous tokens have been consumed. If not don't proceed to mission.

            allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount);
            if (!allTokensConsumed) {
                // Give the other agents a short grace period, then re-check once.
                try {
                    cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS);
                } catch (InterruptedException ie) {
                }
                allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount);
            }
        } else {
            // Secondary agent: start directly under its own token.
            TCPUtils.Log(Level.INFO, "Start " + role + " reset " + reset);

            started = startUp(command, ipOriginator, experimentId, reset, agentCount, experimentId + ":" + role + ":" + reset, seed, isSynchronous);
        }
    } finally {
        lock.unlock();
    }

    // Reply 1 only when the mission started and no stale tokens remain.
    DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
    dout.writeInt(BYTES_INT);
    dout.writeInt(allTokensConsumed && started ? 1 : 0);
    dout.flush();

    dout.flush(); // NOTE(review): duplicate flush — harmless, candidate for removal.

    return allTokensConsumed && started;
}
|
||||
|
||||
private boolean areAllTokensConsumed(String experimentId, int reset, int agentCount) {
|
||||
boolean allTokensConsumed = true;
|
||||
for (int i = 1; i < agentCount; i++) {
|
||||
String tokenForAgent = experimentId + ":" + i + ":" + (reset - 1);
|
||||
if (initTokens.containsKey(tokenForAgent)) {
|
||||
TCPUtils.Log(Level.FINE,"Mission init - unconsumed " + tokenForAgent);
|
||||
allTokensConsumed = false;
|
||||
}
|
||||
}
|
||||
return allTokensConsumed;
|
||||
}
|
||||
|
||||
private boolean startUp(String command, String ipOriginator, String experimentId, int reset, int agentCount, String myToken, Long seed, Boolean isSynchronous) throws IOException {
|
||||
|
||||
// Clear out mission state
|
||||
envState.reward = 0.0;
|
||||
envState.commands.clear();
|
||||
envState.obs = null;
|
||||
envState.info = "";
|
||||
|
||||
|
||||
envState.missionInit = command;
|
||||
envState.done = false;
|
||||
envState.quit = false;
|
||||
envState.token = myToken;
|
||||
envState.experimentId = experimentId;
|
||||
envState.agentCount = agentCount;
|
||||
envState.reset = reset;
|
||||
envState.synchronous = isSynchronous;
|
||||
envState.seed = seed;
|
||||
|
||||
return startUpMission(command, ipOriginator);
|
||||
}
|
||||
|
||||
private boolean startUpMission(String command, String ipOriginator) throws IOException {
|
||||
|
||||
if (missionPoller == null)
|
||||
return false;
|
||||
|
||||
ByteArrayOutputStream baos = new ByteArrayOutputStream();
|
||||
DataOutputStream dos = new DataOutputStream(baos);
|
||||
|
||||
missionPoller.commandReceived(command, ipOriginator, dos);
|
||||
|
||||
dos.flush();
|
||||
byte[] reply = baos.toByteArray();
|
||||
ByteArrayInputStream bais = new ByteArrayInputStream(reply);
|
||||
DataInputStream dis = new DataInputStream(bais);
|
||||
int hdr = dis.readInt();
|
||||
byte[] replyBytes = new byte[hdr];
|
||||
dis.readFully(replyBytes);
|
||||
|
||||
String replyStr = new String(replyBytes);
|
||||
if (replyStr.equals("MALMOOK")) {
|
||||
TCPUtils.Log(Level.INFO, "MalmoEnvServer Mission starting ...");
|
||||
return true;
|
||||
} else if (replyStr.equals("MALMOBUSY")) {
|
||||
TCPUtils.Log(Level.INFO, "MalmoEnvServer Busy - I want to quit");
|
||||
this.envState.quit = true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private static final int stepTagLength = "<Step_>".length(); // Step with option code.

// Synchronous (lock-step) step: queue the agent's action(s), run exactly one game
// tick, then reply with observation, reward, done/sent flags and (optionally) info.
private synchronized void stepSync(String command, Socket socket, DataInputStream din) throws IOException
{
    nsteps += 1;
    profiler.startSection("commandProcessing");
    // Action text sits between "<Step_N>" and "</Step_N>".
    String actions = command.substring(stepTagLength, command.length() - (stepTagLength + 2));
    // Single digit option code; 0 and 2 request that info be included in the reply.
    int options = Character.getNumericValue(command.charAt(stepTagLength - 2));
    boolean withInfo = options == 0 || options == 2;

    // Prepare to write data to the client.
    DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
    double reward = 0.0;
    boolean done;
    byte[] obs;
    String info = "";
    boolean sent = false;

    lock.lock();
    try {

        done = envState.done;

        // TODO Handle when the environment is done.

        // Process the actions (possibly several newline-separated commands).
        if (actions.contains("\n")) {
            String[] cmds = actions.split("\\n");
            for(String cmd : cmds) {
                envState.commands.add(cmd);
            }
        } else {
            if (!actions.isEmpty())
                envState.commands.add(actions);
        }
        sent = true;

        profiler.endSection(); //cmd
        profiler.startSection("requestTick");

        // Now wait to run a tick.
        // If synchronous mode is off then we should see if want to quit is true.
        while(!TimeHelper.SyncManager.requestTick() && !done ){Thread.yield();}

        profiler.endSection();
        profiler.startSection("waitForTick");

        // Then wait until the tick is finished.
        while(!TimeHelper.SyncManager.isTickCompleted() && !done ){ Thread.yield();}

        profiler.endSection();
        profiler.startSection("getObservation");
        // After which, get the observations.
        obs = getObservation(done);

        profiler.endSection();
        profiler.startSection("getInfo");

        // Pick up rewards.
        reward = envState.reward;
        if (withInfo) {
            info = envState.info;
        }
        done = envState.done;
        // Consume the per-step state so the next step starts fresh.
        envState.info = null;
        envState.obs = null;
        envState.reward = 0.0;

        profiler.endSection();
    } finally {
        lock.unlock();
    }

    // Reply outside the lock: observation, then (reward, done, sent), then optional info.
    profiler.startSection("writeObs");
    dout.writeInt(obs.length);
    dout.write(obs);

    dout.writeInt(BYTES_DOUBLE + 2);
    dout.writeDouble(reward);
    dout.writeByte(done ? 1 : 0);
    dout.writeByte(sent ? 1 : 0);

    if (withInfo) {
        byte[] infoBytes = info.getBytes(utf8);
        dout.writeInt(infoBytes.length);
        dout.write(infoBytes);
    }

    profiler.endSection(); //write obs
    profiler.startSection("flush");

    dout.flush();
    profiler.endSection(); // flush
}
|
||||
// Handler for <Step_> messages. Single digit option code after _ specifies if turnkey and info are included in message.
|
||||
private void step(String command, Socket socket, DataInputStream din) throws IOException {
|
||||
if(envState.synchronous){
|
||||
stepSync(command, socket, din);
|
||||
}
|
||||
else{
|
||||
System.out.println("[ERROR] Asynchronous stepping is not supported in MineRL.");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Handler for <Peek> messages: tick the game until the server signals mission start,
// then run two more ticks so the first observation propagates, and reply with
// (observation, info, done). Can be aborted via abort() when a reset stalls.
private void peek(String command, Socket socket, DataInputStream din) throws IOException {

    DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
    byte[] obs;
    boolean done;
    String info = "";
    // AOG - As we've only seen issues with the peek request, I've focused my changes to just
    // this function. Initially we want to be optimistic and assume we're not going to abort
    // the request and my observations of event timings indicate that there is plenty of time
    // between the peek request being received and the reset failing, so a race condition is
    // unlikely.
    abortRequest = false;

    lock.lock();

    try {
        // Tick until the server has "fired the pistol" (mission start) or we are aborted.
        while(!TimeHelper.SyncManager.hasServerFiredPistol() && !abortRequest){

            // Now wait to run a tick
            while(!TimeHelper.SyncManager.requestTick() && !abortRequest){Thread.yield();}

            // Then wait until the tick is finished
            while(!TimeHelper.SyncManager.isTickCompleted() && !abortRequest){ Thread.yield();}

            Thread.yield();
        }

        if (abortRequest) {
            System.out.println("AOG: Aborting peek request");
            // AOG - We detect the lack of observation within our Python wrapper and throw a slightly
            // different exception that by-passes MineRL's automatic clean up code. If we were to report
            // 'done', then MineRL detects this as a runtime error and kills the Minecraft process
            // triggering a lengthy restart. So far from testing, Minecraft itself is fine and we can
            // retry the reset; it's only the tight loops above that were causing things to stall and
            // timeout.
            // No observation
            dout.writeInt(0);
            // No info
            dout.writeInt(0);
            // Done
            dout.writeInt(1);
            dout.writeByte(0);
            dout.flush();
            return;
        }

        // Wait two ticks for the first observation from server to be propagated.
        while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();}

        // Then wait until the tick is finished
        while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();}

        while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();}

        // Then wait until the tick is finished
        while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();}

        obs = getObservation(false);

        done = envState.done;
        info = envState.info;

    } finally {
        lock.unlock();
    }

    // Reply: observation, info, then the done flag (length 1).
    dout.writeInt(obs.length);
    dout.write(obs);

    byte[] infoBytes = info.getBytes(utf8);
    dout.writeInt(infoBytes.length);
    dout.write(infoBytes);

    dout.writeInt(1);
    dout.writeByte(done ? 1 : 0);

    dout.flush();
}
|
||||
|
||||
// Get the current observation. If none and not done wait for a short time.
|
||||
public byte[] getObservation(boolean done) {
|
||||
byte[] obs = envState.obs;
|
||||
if (obs == null){
|
||||
System.out.println("[ERROR] Video observation is null; please notify the developer.");
|
||||
}
|
||||
return obs;
|
||||
}
|
||||
|
||||
// Handler for <Find> messages - used by non-zero roles to discover integrated server port from primary (role 0) service.

private final static int findTagLength = "<Find>".length();

// Looks up the port registered under the token in the message (retiring the token
// from the previous reset), waiting briefly for it to appear; replies 0 when absent.
private void find(String command, Socket socket) throws IOException {

    Integer port;
    lock.lock();
    try {
        // Token text sits between "<Find>" and "</Find>".
        String token = command.substring(findTagLength, command.length() - (findTagLength + 1));
        TCPUtils.Log(Level.INFO, "Find token? " + token);

        // Purge previous token.
        String[] tokenSplits = token.split(":");
        String experimentId = tokenSplits[0];
        int role = Integer.parseInt(tokenSplits[1]);
        int reset = Integer.parseInt(tokenSplits[2]);

        String previousToken = experimentId + ":" + role + ":" + (reset - 1);
        initTokens.remove(previousToken);
        // Wake any MissionInit handler waiting for stale tokens to be consumed.
        cond.signalAll();

        // Check for next token. Wait for a short time if not already produced.
        port = initTokens.get(token);
        if (port == null) {
            try {
                cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS);
            } catch (InterruptedException ie) {
            }
            port = initTokens.get(token);
            if (port == null) {
                // 0 signals "not yet available" to the caller.
                port = 0;
                TCPUtils.Log(Level.INFO,"Role " + role + " reset " + reset + " waiting for token.");
            }
        }
    } finally {
        lock.unlock();
    }

    DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
    dout.writeInt(BYTES_INT);
    dout.writeInt(port);
    dout.flush();
}
|
||||
|
||||
// True when the current mission runs in lock-step (synchronous tick) mode.
public boolean isSynchronous(){
    return envState.synchronous;
}
|
||||
|
||||
// Handler for <Init> messages. These reset the service so use with care!
// Discards ALL registered tokens (affecting every experiment on this
// service) and acknowledges with a single int 1.
private void init(String command, Socket socket) throws IOException {
    lock.lock();
    try {
        initTokens = new Hashtable<String, Integer>();
        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(1);
        dout.flush();
    } finally {
        lock.unlock();
    }
}
|
||||
|
||||
// Handler for <Quit> (quit mission) messages.
// Flags the mission to quit (unless already done), pushes two synchronous
// ticks through so the quit propagates, then replies with done (1) / not
// done (0).
private void quit(String command, Socket socket) throws IOException {
    lock.lock();
    try {
        if (!envState.done){

            envState.quit = true;
        }

        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <PEEK> Pistol fired!.");
        // Wait two ticks for the first observation from server to be propagated.
        // Busy-wait (yield) until the tick request is accepted...
        while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();}

        // Then wait until the tick is finished
        while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();}

        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(envState.done ? 1 : 0);
        dout.flush();
    } finally {
        lock.unlock();
    }
}
|
||||
|
||||
private final static int closeTagLength = "<Close>".length();

// Handler for <Close> messages.
// Removes the token named in "<Close>token</Close>" from the registry and
// acknowledges with a single int 1.
private void close(String command, Socket socket) throws IOException {
    lock.lock();
    try {
        // Strip the "<Close>" prefix and "</Close>" suffix.
        String token = command.substring(closeTagLength, command.length() - (closeTagLength + 1));

        initTokens.remove(token);

        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(1);
        dout.flush();
    } finally {
        lock.unlock();
    }
}
|
||||
|
||||
// Handler for <Status> messages.
// Replies with a length-prefixed UTF-8 JSON status string (currently always
// the empty object "{}").
private void status(String command, Socket socket) throws IOException {
    lock.lock();
    try {
        String status = "{}"; // TODO Possibly have something more interesting to report.
        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());

        byte[] statusBytes = status.getBytes(utf8);
        dout.writeInt(statusBytes.length);
        dout.write(statusBytes);

        dout.flush();
    } finally {
        lock.unlock();
    }
}
|
||||
|
||||
// Handler for <Exit> messages. These "kill the service" temporarily so use with care!
// Disables synchronous mode, acknowledges with int 1, then exits the Java
// client process via ClientStateMachine.exitJava(). Deliberately does not
// take the lock: the process may terminate before a reply could be sent.
private void exit(String command, Socket socket) throws IOException {
    // lock.lock();
    try {
        // We may exit before we get a chance to reply.
        TimeHelper.SyncManager.setSynchronous(false);
        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(1);
        dout.flush();

        ClientStateMachine.exitJava();

    } finally {
        // lock.unlock();
    }
}
|
||||
|
||||
// Malmo client state machine interface methods:
|
||||
|
||||
public String getCommand() {
|
||||
try {
|
||||
String command = envState.commands.poll();
|
||||
if (command == null)
|
||||
return "";
|
||||
else
|
||||
return command;
|
||||
} finally {
|
||||
}
|
||||
}
|
||||
|
||||
// Tear down per-mission state: abort in-flight requests, mark the episode
// done, and release this mission's token from the registry.
// NOTE(review): locking here is commented out, and the commented
// "lock.unlock()" sits inside the try rather than the finally — this looks
// like residue of a removed locking scheme; confirm thread-safety before
// re-enabling.
public void endMission() {
    // lock.lock();
    try {
        // AOG - If the mission is ending, we always want to abort requests and they won't
        // be able to progress to completion and will stall.
        System.out.println("AOG: MalmoEnvServer.endMission");
        abort();
        envState.done = true;
        envState.quit = false;
        envState.missionInit = null;

        if (envState.token != null) {
            initTokens.remove(envState.token);
            envState.token = null;
            envState.experimentId = null;
            envState.agentCount = 0;
            envState.reset = 0;

            // cond.signalAll();
        }
        // lock.unlock();
    } finally {
    }
}
|
||||
// Record a Malmo "observation" json - as the env info since an environment "obs" is a video frame.
// Stores the latest observation JSON string; readers pick it up from
// envState.info. NOTE(review): the locking is commented out, so this write
// is unsynchronized — confirm the single-writer assumption.
public void observation(String info) {
    // Parsing obs as JSON would be slower but less fragile than extracting the turn_key using string search.
    // lock.lock();
    try {
        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <OBSERVATION> Inserting: " + info);
        envState.info = info;
        // cond.signalAll();
    } finally {
        // lock.unlock();
    }
}
|
||||
|
||||
// Accumulate a reward delta from the mission into the pending episode reward.
// NOTE(review): unsynchronized read-modify-write (locking commented out).
public void addRewards(double rewards) {
    // lock.lock();
    try {
        envState.reward += rewards;
    } finally {
        // lock.unlock();
    }
}
|
||||
|
||||
// Store the latest video frame, replacing any previous (unread) one.
public void addFrame(byte[] frame) {
    // lock.lock();
    try {
        envState.obs = frame; // Replaces current.
        // cond.signalAll();
    } finally {
        // lock.unlock();
    }
}
|
||||
|
||||
// Called when the integrated (in-process) Minecraft server is up: publishes
// its port under this mission's token (and derived tokens for other agent
// roles) and wakes any thread blocked in find().
public void notifyIntegrationServerStarted(int integrationServerPort) {
    lock.lock();
    try {
        if (envState.token != null) {
            TCPUtils.Log(Level.INFO,"Integration server start up - token: " + envState.token);
            addTokens(integrationServerPort, envState.token, envState.experimentId, envState.agentCount, envState.reset);
            cond.signalAll();
        } else {
            TCPUtils.Log(Level.WARNING,"No mission token on integration server start up!");
        }
    } finally {
        lock.unlock();
    }
}
|
||||
|
||||
private void addTokens(int integratedServerPort, String myToken, String experimentId, int agentCount, int reset) {
|
||||
initTokens.put(myToken, integratedServerPort);
|
||||
// Place tokens for other agents to find.
|
||||
for (int i = 1; i < agentCount; i++) {
|
||||
String tokenForAgent = experimentId + ":" + i + ":" + reset;
|
||||
initTokens.put(tokenForAgent, integratedServerPort);
|
||||
}
|
||||
}
|
||||
|
||||
// IWantToQuit implementation.
|
||||
|
||||
@Override
// Reports whether the environment has been asked to quit the mission.
public boolean doIWantToQuit(MissionInit missionInit) {
    // lock.lock();
    try {
        return envState.quit;
    } finally {
        // lock.unlock();
    }
}
|
||||
|
||||
// The mission seed recorded in the environment state (may be null).
public Long getSeed(){
    return envState.seed;
}
|
||||
|
||||
// Flag the mission to quit and, if running in synchronous mode, drop back
// to asynchronous ticking so the client is not left blocked on ticks.
private void setWantToQuit() {
    // lock.lock();
    try {
        envState.quit = true;

    } finally {

        if(TimeHelper.SyncManager.isSynchronous()){
            // We want to desynchronize everything.
            TimeHelper.SyncManager.setSynchronous(false);
        }
        // lock.unlock();
    }
}
|
||||
|
||||
@Override
// No-op: this quit handler needs no per-mission preparation.
public void prepare(MissionInit missionInit) {
}
|
||||
|
||||
@Override
// No-op: nothing to release when the mission ends.
public void cleanup() {
}
|
||||
|
||||
@Override
// Fixed outcome string reported when this handler triggers the quit.
public String getOutcome() {
    return "Env quit";
}
|
||||
}
|
||||
@@ -0,0 +1,173 @@
|
||||
import time
|
||||
import glob
|
||||
import pathlib
|
||||
|
||||
from malmo import MalmoPython, malmoutils
|
||||
from malmo.launch_minecraft_in_background import launch_minecraft_in_background
|
||||
|
||||
|
||||
class MalmoVideoRecorder:
    '''Replays a recorded trajectory in Malmo with a second "camera" agent
    that spectates the run and records it as an MP4.

    Usage: call init_malmo() once (launches two Minecraft instances on ports
    10000/10001), then record_malmo_video() per trajectory.
    '''

    # Default directory for the recorded video archives (.tgz).
    DEFAULT_RECORDINGS_DIR = './logs/videos'

    def __init__(self):
        # Agent hosts and client pool are created lazily in init_malmo().
        self.agent_host_bot = None      # the agent replaying the trajectory
        self.agent_host_camera = None   # the spectator that records video
        self.client_pool = None
        self.is_malmo_initialized = False

    def init_malmo(self, recordings_directory=DEFAULT_RECORDINGS_DIR):
        '''Launch two Minecraft instances and set up the agent hosts.

        Idempotent: returns immediately if already initialized.
        '''
        if self.is_malmo_initialized:
            return

        # NOTE(review): hard-coded install path and ports — confirm these
        # match the container image this runs in.
        launch_minecraft_in_background(
            '/app/MalmoPlatform/Minecraft',
            ports=[10000, 10001])

        # Set up two agent hosts
        self.agent_host_bot = MalmoPython.AgentHost()
        self.agent_host_camera = MalmoPython.AgentHost()

        # Create list of Minecraft clients to attach to. The agents must
        # have been launched before calling record_malmo_video using
        # init_malmo()
        self.client_pool = MalmoPython.ClientPool()
        self.client_pool.add(MalmoPython.ClientInfo('127.0.0.1', 10000))
        self.client_pool.add(MalmoPython.ClientInfo('127.0.0.1', 10001))

        # Use bot's agenthost to hold the command-line options
        malmoutils.parse_command_line(
            self.agent_host_bot,
            ['--record_video', '--recording_dir', recordings_directory])

        self.is_malmo_initialized = True

    def _start_mission(self, agent_host, mission, recording_spec, role):
        '''Start a mission for one agent, retrying transient Malmo errors.

        Retries up to 5 times on "insufficient clients" / "server not found";
        "warming up" retries indefinitely. Re-raises the MissionException
        once attempts are exhausted or on any other error code.
        '''
        used_attempts = 0
        max_attempts = 5

        while True:
            try:
                agent_host.startMission(
                    mission,
                    self.client_pool,
                    recording_spec,
                    role,
                    '')
                break
            except MalmoPython.MissionException as e:
                errorCode = e.details.errorCode
                if errorCode == (MalmoPython.MissionErrorCode
                                 .MISSION_SERVER_WARMING_UP):
                    time.sleep(2)
                elif errorCode == (MalmoPython.MissionErrorCode
                                   .MISSION_INSUFFICIENT_CLIENTS_AVAILABLE):
                    print('Not enough Minecraft instances running.')
                    used_attempts += 1
                    if used_attempts < max_attempts:
                        print('Will wait in case they are starting up.')
                        # Long wait: instances can take minutes to boot.
                        time.sleep(300)
                elif errorCode == (MalmoPython.MissionErrorCode
                                   .MISSION_SERVER_NOT_FOUND):
                    print('Server not found.')
                    used_attempts += 1
                    if used_attempts < max_attempts:
                        print('Will wait and retry.')
                        time.sleep(2)
                else:
                    # Unknown error: exhaust attempts so we re-raise below.
                    used_attempts = max_attempts
                if used_attempts >= max_attempts:
                    raise e

    def _wait_for_start(self, agent_hosts):
        '''Block until every agent reports the mission has begun.

        Raises on any world-state error, or after a 120 s timeout.
        '''
        start_flags = [False for a in agent_hosts]
        start_time = time.time()
        time_out = 120

        while not all(start_flags) and time.time() - start_time < time_out:
            states = [a.peekWorldState() for a in agent_hosts]
            start_flags = [w.has_mission_begun for w in states]
            errors = [e for w in states for e in w.errors]

            if len(errors) > 0:
                print("Errors waiting for mission start:")
                for e in errors:
                    print(e.text)
                raise Exception("Encountered errors while starting mission.")
        if time.time() - start_time >= time_out:
            raise Exception("Timed out while waiting for mission to start.")

    def _get_xml(self, xml_file, seed):
        '''Read the mission XML template and substitute {SEED_PLACEHOLDER}.'''
        with open(xml_file, 'r') as mission_file:
            return mission_file.read().format(SEED_PLACEHOLDER=seed)

    def _is_mission_running(self):
        # The mission counts as running while either agent is still in it.
        return self.agent_host_bot.peekWorldState().is_mission_running or \
            self.agent_host_camera.peekWorldState().is_mission_running

    def record_malmo_video(self, instructions, xml_file, seed):
        '''
        Replays a set of instructions through Malmo using two players. The
        first player will navigate the specified mission based on the given
        instructions. The second player observes the first player's moves,
        which is captured in a video.

        NOTE(review): assumes the mission ends no later than the last
        instruction (the "quit" command sent below); if the mission outlived
        the instruction list this loop would raise IndexError — confirm the
        mission XML guarantees termination.
        '''

        if not self.is_malmo_initialized:
            raise Exception('Malmo not initialized. Call init_malmo() first.')

        # Set up the mission
        my_mission = MalmoPython.MissionSpec(
            self._get_xml(xml_file, seed),
            True)

        bot_recording_spec = MalmoPython.MissionRecordSpec()
        camera_recording_spec = MalmoPython.MissionRecordSpec()

        recordingsDirectory = \
            malmoutils.get_recordings_directory(self.agent_host_bot)
        if recordingsDirectory:
            # Record the spectator's view as MP4 inside a per-seed .tgz.
            camera_recording_spec.setDestination(
                recordingsDirectory + "//rollout_" + str(seed) + ".tgz")
            camera_recording_spec.recordMP4(
                MalmoPython.FrameType.VIDEO,
                36,
                2000000,
                False)

        # Start the agents
        self._start_mission(
            self.agent_host_bot,
            my_mission,
            bot_recording_spec,
            0)
        self._start_mission(
            self.agent_host_camera,
            my_mission,
            camera_recording_spec,
            1)
        self._wait_for_start([self.agent_host_camera, self.agent_host_bot])

        # Teleport the camera agent to the required position
        self.agent_host_camera.sendCommand('tp -29 72 -6.7')
        instruction_index = 0

        while self._is_mission_running():

            command = instructions[instruction_index]
            instruction_index += 1

            self.agent_host_bot.sendCommand(command)

            # Pause for half a second - change this for faster/slower videos
            time.sleep(0.5)

            if instruction_index == len(instructions):
                # Out of instructions: celebrate with a jump, then quit.
                self.agent_host_bot.sendCommand("jump 1")
                time.sleep(2)

                self.agent_host_bot.sendCommand("quit")

                # Wait a little for Malmo to reset before the
                # next mission is started
                time.sleep(2)
        print("Video recorded.")
|
||||
@@ -0,0 +1,180 @@
|
||||
import json
|
||||
import logging
|
||||
|
||||
import gym
|
||||
import minerl.env.core
|
||||
import minerl.env.comms
|
||||
import numpy as np
|
||||
|
||||
from ray.rllib.env.atari_wrappers import FrameStack
|
||||
from minerl.env.malmo import InstanceManager
|
||||
|
||||
# Modify the MineRL timeouts to detect common errors
|
||||
# quicker and speed up recovery
|
||||
minerl.env.core.SOCKTIME = 60.0
|
||||
minerl.env.comms.retry_timeout = 1
|
||||
|
||||
|
||||
class EnvWrapper(minerl.env.core.MineRLEnv):
    '''MineRLEnv with a fixed interface for the lava-maze task:
    84x84x3 uint8 image observations and a 3-way discrete action space
    (forward, turn left, turn right as Malmo camera commands).
    '''

    def __init__(self, xml, port):
        # Each worker gets its own Malmo base port to avoid collisions.
        InstanceManager.configure_malmo_base_port(port)
        # Index in the Discrete(3) space -> Malmo command string.
        self.action_to_command_array = [
            'move 1',
            'camera 0 270',
            'camera 0 90']

        super().__init__(
            xml,
            gym.spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8),
            gym.spaces.Discrete(3)
        )

        self.metadata['video.frames_per_second'] = 2

    def _setup_spaces(self, observation_space, action_space):
        # Use the spaces passed to __init__ verbatim (no dict spaces).
        self.observation_space = observation_space
        self.action_space = action_space

    def _process_action(self, action_in) -> str:
        '''Translate a discrete action index into its Malmo command.'''
        assert self.action_space.contains(action_in)
        assert action_in <= len(
            self.action_to_command_array) - 1, 'action index out of bounds.'
        return self.action_to_command_array[action_in]

    def _process_observation(self, pov, info):
        '''
        Overwritten to simplify: returns only `pov` (the raw video frame as
        an (height, width, depth) uint8 array, vertically flipped) rather
        than the obs dict MineRLEnv normally builds.
        '''
        pov = np.frombuffer(pov, dtype=np.uint8)

        if pov is None or len(pov) == 0:
            raise Exception('Invalid observation, probably an aborted peek')
        else:
            # Frames arrive bottom-up; [::-1] flips them right side up.
            # NOTE(review): self.height/width/depth come from the MineRLEnv
            # base class — confirm they match the 84x84x3 space above.
            pov = pov.reshape(
                (self.height, self.width, self.depth)
            )[::-1, :, :]

        assert self.observation_space.contains(pov)

        # Keep the latest frame around for debugging/rendering.
        self._last_pov = pov

        return pov
|
||||
|
||||
|
||||
class TrackingEnv(gym.Wrapper):
    '''Reward-shaping wrapper that dead-reckons the agent's grid position
    from its discrete actions and adds a +0.1 exploration bonus the first
    time each cell is visited in an episode.

    NOTE(review): position/facing are inferred purely from the action
    sequence (index 0 = forward, 1 = left, 2 = right) — this assumes every
    action succeeds in the environment; confirm against the mission setup.
    '''

    def __init__(self, env):
        super().__init__(env)
        # Action index -> state-update function (must match the wrapped
        # env's action semantics).
        self._actions = [
            self._forward,
            self._turn_left,
            self._turn_right
        ]

    def _reset_state(self):
        # Facing is a unit vector on the grid; start facing (1, 0).
        self._facing = (1, 0)
        self._position = (0, 0)
        self._visited = {}
        # The start cell counts as visited.
        self._update_visited()

    def _forward(self):
        # Step one cell in the current facing direction.
        self._position = (
            self._position[0] + self._facing[0],
            self._position[1] + self._facing[1]
        )

    def _turn_left(self):
        # 90-degree rotation: (x, z) -> (z, -x).
        self._facing = (self._facing[1], -self._facing[0])

    def _turn_right(self):
        # 90-degree rotation: (x, z) -> (-z, x).
        self._facing = (-self._facing[1], self._facing[0])

    def _encode_state(self):
        # Visited cells are keyed by position only (facing ignored).
        return self._position

    def _update_visited(self):
        '''Record a visit to the current cell; return the PRIOR visit count
        (0 means this is the first visit).'''
        state = self._encode_state()
        value = self._visited.get(state, 0)
        self._visited[state] = value + 1
        return value

    def reset(self):
        self._reset_state()
        return super().reset()

    def step(self, action):
        o, r, d, i = super().step(action)
        # Mirror the action into the tracked position/facing.
        self._actions[action]()
        revisit_count = self._update_visited()
        if revisit_count == 0:
            # Exploration bonus for entering a new cell.
            r += 0.1

        return o, r, d, i
|
||||
|
||||
|
||||
class TrajectoryWrapper(gym.Wrapper):
    '''Records every action taken as its Malmo command string so the episode
    can later be replayed (e.g. for video recording).

    Note: the trajectory accumulates across episodes; it is never cleared
    on reset.
    '''

    def __init__(self, env):
        super().__init__(env)
        self._trajectory = []
        # Action index -> Malmo command used for replay.
        self._action_to_malmo_command_array = ['move 1', 'turn -1', 'turn 1']

    def get_trajectory(self):
        '''Return the list of Malmo commands recorded so far.'''
        return self._trajectory

    def _to_malmo_action(self, action_index):
        return self._action_to_malmo_command_array[action_index]

    def step(self, action):
        # Record before stepping so aborted steps are still captured.
        self._trajectory.append(self._to_malmo_action(action))
        o, r, d, i = super().step(action)

        return o, r, d, i
|
||||
|
||||
|
||||
class DummyEnv(gym.Env):
    '''Placeholder env that only declares the spaces of the real (frame-
    stacked) Minecraft env — (84, 84, 6) = two stacked 84x84x3 frames, and
    3 discrete actions. Used on the driver, which never steps the env, so
    reset()/step() are deliberately not implemented.
    '''

    def __init__(self):
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(84, 84, 6),
            dtype=np.uint8)
        self.action_space = gym.spaces.Discrete(3)
|
||||
|
||||
|
||||
# Define a function to create a MineRL environment
|
||||
def create_env(config):
    """RLlib environment factory for training.

    Worker 0 (the driver) only inspects the spaces, so it receives a
    lightweight DummyEnv instead of a real Minecraft instance. All other
    workers get a MineRL-backed env with revisit tracking and 2-frame
    stacking, connected on a port derived from the worker/vector index.
    """
    mission = config["mission"]
    port = 1000 * config.worker_index + config.vector_index
    print('*********************************************')
    print(f'* Worker {config.worker_index} creating from mission: {mission}, port {port}')
    print('*********************************************')

    if config.worker_index == 0:
        # The driver never steps the env; a dummy avoids spinning up a
        # Minecraft instance and saves head-node CPU.
        return DummyEnv()

    return FrameStack(TrackingEnv(EnvWrapper(mission, port)), 2)
|
||||
|
||||
|
||||
def create_env_for_rollout(config):
    """RLlib environment factory for evaluation rollouts.

    Same wrapper stack as create_env, plus a TrajectoryWrapper to capture
    the Malmo commands issued, and no dummy shortcut for worker 0.
    """
    mission = config['mission']
    port = 1000 * config.worker_index + config.vector_index
    print('*********************************************')
    print(f'* Worker {config.worker_index} creating from mission: {mission}, port {port}')
    print('*********************************************')

    return TrajectoryWrapper(
        FrameStack(TrackingEnv(EnvWrapper(mission, port)), 2))
|
||||
@@ -0,0 +1,95 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no" ?>
|
||||
<Mission xmlns="http://ProjectMalmo.microsoft.com" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
|
||||
|
||||
<About>
|
||||
<Summary>$(ENV_NAME)</Summary>
|
||||
</About>
|
||||
|
||||
<ModSettings>
|
||||
<MsPerTick>50</MsPerTick>
|
||||
</ModSettings>
|
||||
|
||||
<ServerSection>
|
||||
<ServerInitialConditions>
|
||||
<Time>
|
||||
<StartTime>6000</StartTime>
|
||||
<AllowPassageOfTime>false</AllowPassageOfTime>
|
||||
</Time>
|
||||
<Weather>clear</Weather>
|
||||
<AllowSpawning>false</AllowSpawning>
|
||||
</ServerInitialConditions>
|
||||
<ServerHandlers>
|
||||
<FlatWorldGenerator generatorString="3;7,220*1,5*3,2;3;,biome_1"/>
|
||||
|
||||
<DrawingDecorator>
|
||||
<DrawSphere x="-29" y="70" z="-2" radius="100" type="air"/>
|
||||
<DrawCuboid x1="-34" y1="70" z1="-7" x2="-24" y2="70" z2="3" type="lava" />
|
||||
</DrawingDecorator>
|
||||
|
||||
<MazeDecorator>
    <!-- Randomly generated 5x5 grass-path maze over lava gaps: start on an
         emerald block, goal on a lapis block. Seed "random" produces a new
         maze layout on every mission start. -->
    <Seed>random</Seed>
    <SizeAndPosition width="5" length="5" height="10" xOrigin="-32" yOrigin="69" zOrigin="-5"/>
    <StartBlock type="emerald_block" fixedToEdge="false"/>
    <EndBlock type="lapis_block" fixedToEdge="false"/>
    <PathBlock type="grass"/>
    <FloorBlock type="air"/>
    <GapBlock type="lava"/>
    <GapProbability>0.6</GapProbability>
    <AllowDiagonalMovement>false</AllowDiagonalMovement>
</MazeDecorator>
|
||||
|
||||
<ServerQuitFromTimeUp timeLimitMs="300000" description="out_of_time"/>
|
||||
<ServerQuitWhenAnyAgentFinishes/>
|
||||
</ServerHandlers>
|
||||
</ServerSection>
|
||||
|
||||
<AgentSection mode="Survival">
|
||||
<Name>AML_Bot</Name>
|
||||
|
||||
<AgentStart>
|
||||
<Placement x="-28.5" y="71.0" z="-1.5" pitch="70" yaw="0"/>
|
||||
</AgentStart>
|
||||
|
||||
<AgentHandlers>
|
||||
|
||||
<VideoProducer want_depth="false">
|
||||
<Width>84</Width>
|
||||
<Height>84</Height>
|
||||
</VideoProducer>
|
||||
|
||||
<FileBasedPerformanceProducer/>
|
||||
|
||||
<ObservationFromFullInventory flat="false"/>
|
||||
<ObservationFromFullStats/>
|
||||
<HumanLevelCommands>
|
||||
<ModifierList type="deny-list">
|
||||
<command>moveMouse</command>
|
||||
<command>inventory</command>
|
||||
</ModifierList>
|
||||
</HumanLevelCommands>
|
||||
<CameraCommands/>
|
||||
<ObservationFromCompass/>
|
||||
<DiscreteMovementCommands/>
|
||||
|
||||
<RewardForMissionEnd>
|
||||
<Reward description="out_of_time" reward="-1" />
|
||||
</RewardForMissionEnd>
|
||||
|
||||
<RewardForTouchingBlockType>
|
||||
<Block reward="-1.0" type="lava" behaviour="onceOnly"/>
|
||||
<Block reward="1.0" type="lapis_block" behaviour="onceOnly"/>
|
||||
</RewardForTouchingBlockType>
|
||||
|
||||
<RewardForSendingCommand reward="-0.02"/>
|
||||
|
||||
<AgentQuitFromTouchingBlockType>
|
||||
<Block type="lava" />
|
||||
<Block type="lapis_block" />
|
||||
</AgentQuitFromTouchingBlockType>
|
||||
<PauseCommand/>
|
||||
<AgentQuitFromReachingCommandQuota total="50"/>
|
||||
</AgentHandlers>
|
||||
</AgentSection>
|
||||
|
||||
|
||||
</Mission>
|
||||
@@ -0,0 +1,95 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no" ?>
|
||||
<Mission xmlns="http://ProjectMalmo.microsoft.com" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
|
||||
|
||||
<About>
|
||||
<Summary>$(ENV_NAME)</Summary>
|
||||
</About>
|
||||
|
||||
<ModSettings>
|
||||
<MsPerTick>50</MsPerTick>
|
||||
</ModSettings>
|
||||
|
||||
<ServerSection>
|
||||
<ServerInitialConditions>
|
||||
<Time>
|
||||
<StartTime>6000</StartTime>
|
||||
<AllowPassageOfTime>false</AllowPassageOfTime>
|
||||
</Time>
|
||||
<Weather>clear</Weather>
|
||||
<AllowSpawning>false</AllowSpawning>
|
||||
</ServerInitialConditions>
|
||||
<ServerHandlers>
|
||||
<FlatWorldGenerator generatorString="3;7,220*1,5*3,2;3;,biome_1"/>
|
||||
|
||||
<DrawingDecorator>
|
||||
<DrawSphere x="-29" y="70" z="-2" radius="100" type="air"/>
|
||||
<DrawCuboid x1="-34" y1="70" z1="-7" x2="-24" y2="70" z2="3" type="lava" />
|
||||
</DrawingDecorator>
|
||||
|
||||
<MazeDecorator>
|
||||
<Seed>{SEED_PLACEHOLDER}</Seed>
|
||||
<SizeAndPosition width="6" length="6" height="10" xOrigin="-32" yOrigin="69" zOrigin="-5"/>
|
||||
<StartBlock type="emerald_block" fixedToEdge="false"/>
|
||||
<EndBlock type="lapis_block" fixedToEdge="false"/>
|
||||
<PathBlock type="grass"/>
|
||||
<FloorBlock type="air"/>
|
||||
<GapBlock type="lava"/>
|
||||
<GapProbability>0.6</GapProbability>
|
||||
<AllowDiagonalMovement>false</AllowDiagonalMovement>
|
||||
</MazeDecorator>
|
||||
|
||||
<ServerQuitFromTimeUp timeLimitMs="300000" description="out_of_time"/>
|
||||
<ServerQuitWhenAnyAgentFinishes/>
|
||||
</ServerHandlers>
|
||||
</ServerSection>
|
||||
|
||||
<AgentSection mode="Survival">
|
||||
<Name>AML_Bot</Name>
|
||||
|
||||
<AgentStart>
|
||||
<Placement x="-28.5" y="71.0" z="-1.5" pitch="70" yaw="0"/>
|
||||
</AgentStart>
|
||||
|
||||
<AgentHandlers>
|
||||
|
||||
<VideoProducer want_depth="false">
|
||||
<Width>84</Width>
|
||||
<Height>84</Height>
|
||||
</VideoProducer>
|
||||
|
||||
<FileBasedPerformanceProducer/>
|
||||
|
||||
<ObservationFromFullInventory flat="false"/>
|
||||
<ObservationFromFullStats/>
|
||||
<HumanLevelCommands>
|
||||
<ModifierList type="deny-list">
|
||||
<command>moveMouse</command>
|
||||
<command>inventory</command>
|
||||
</ModifierList>
|
||||
</HumanLevelCommands>
|
||||
<CameraCommands/>
|
||||
<ObservationFromCompass/>
|
||||
<DiscreteMovementCommands/>
|
||||
|
||||
<RewardForMissionEnd>
|
||||
<Reward description="out_of_time" reward="-1" />
|
||||
</RewardForMissionEnd>
|
||||
|
||||
<RewardForTouchingBlockType>
|
||||
<Block reward="-1.0" type="lava" behaviour="onceOnly"/>
|
||||
<Block reward="1.0" type="lapis_block" behaviour="onceOnly"/>
|
||||
</RewardForTouchingBlockType>
|
||||
|
||||
<RewardForSendingCommand reward="-0.02"/>
|
||||
|
||||
<AgentQuitFromTouchingBlockType>
|
||||
<Block type="lava" />
|
||||
<Block type="lapis_block" />
|
||||
</AgentQuitFromTouchingBlockType>
|
||||
<PauseCommand/>
|
||||
<AgentQuitFromReachingCommandQuota total="50"/>
|
||||
</AgentHandlers>
|
||||
</AgentSection>
|
||||
|
||||
|
||||
</Mission>
|
||||
@@ -0,0 +1,74 @@
|
||||
<?xml version="1.0" encoding="UTF-8" standalone="no" ?>
|
||||
<Mission xmlns="http://ProjectMalmo.microsoft.com" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
|
||||
|
||||
<About>
|
||||
<Summary>AML-Video-Gatherer</Summary>
|
||||
</About>
|
||||
|
||||
<ModSettings>
|
||||
<MsPerTick>50</MsPerTick>
|
||||
</ModSettings>
|
||||
|
||||
<ServerSection>
|
||||
<ServerInitialConditions>
|
||||
<Time>
|
||||
<StartTime>6000</StartTime>
|
||||
<AllowPassageOfTime>false</AllowPassageOfTime>
|
||||
</Time>
|
||||
<Weather>clear</Weather>
|
||||
<AllowSpawning>false</AllowSpawning>
|
||||
</ServerInitialConditions>
|
||||
<ServerHandlers>
|
||||
<FlatWorldGenerator generatorString="3;7,220*1,5*3,2;3;,biome_1"/>
|
||||
|
||||
<MazeDecorator>
|
||||
<Seed>{SEED_PLACEHOLDER}</Seed>
|
||||
<SizeAndPosition width="6" length="6" height="10" xOrigin="-32" yOrigin="69" zOrigin="-5"/>
|
||||
<StartBlock type="emerald_block" fixedToEdge="false"/>
|
||||
<EndBlock type="lapis_block" fixedToEdge="false"/>
|
||||
<PathBlock type="grass"/>
|
||||
<FloorBlock type="air"/>
|
||||
<GapBlock type="lava"/>
|
||||
<GapProbability>0.6</GapProbability>
|
||||
<AllowDiagonalMovement>false</AllowDiagonalMovement>
|
||||
</MazeDecorator>
|
||||
|
||||
<ServerQuitFromTimeUp timeLimitMs="300000" description="out_of_time"/>
|
||||
<ServerQuitWhenAnyAgentFinishes/>
|
||||
</ServerHandlers>
|
||||
</ServerSection>
|
||||
|
||||
<AgentSection mode="Survival">
|
||||
<Name>Agent</Name>
|
||||
|
||||
<AgentStart>
|
||||
<Placement x="-28.5" y="71.0" z="-1.5" yaw="0"/>
|
||||
</AgentStart>
|
||||
|
||||
<AgentHandlers>
|
||||
<HumanLevelCommands>
|
||||
<ModifierList type="deny-list">
|
||||
<command>moveMouse</command>
|
||||
<command>inventory</command>
|
||||
</ModifierList>
|
||||
</HumanLevelCommands>
|
||||
<DiscreteMovementCommands/>
|
||||
<MissionQuitCommands/>
|
||||
<AgentQuitFromReachingCommandQuota total="50"/>
|
||||
</AgentHandlers>
|
||||
</AgentSection>
|
||||
|
||||
<AgentSection mode="Spectator">
|
||||
<Name>Camera_Bot</Name>
|
||||
<AgentStart>
|
||||
<Placement x="-29" y="72" z="-6.7" pitch="16" yaw="0"/>
|
||||
</AgentStart>
|
||||
<AgentHandlers>
|
||||
<VideoProducer want_depth="false">
|
||||
<Width>860</Width>
|
||||
<Height>480</Height>
|
||||
</VideoProducer>
|
||||
<AbsoluteMovementCommands/>
|
||||
</AgentHandlers>
|
||||
</AgentSection>
|
||||
</Mission>
|
||||
@@ -0,0 +1,130 @@
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
from azureml.core import Run
|
||||
from azureml.core.model import Model
|
||||
|
||||
from minecraft_environment import create_env_for_rollout
|
||||
from malmo_video_recorder import MalmoVideoRecorder
|
||||
from gym import wrappers
|
||||
|
||||
import ray
|
||||
import ray.tune as tune
|
||||
from ray.rllib import rollout
|
||||
from ray.tune.registry import get_trainable_cls
|
||||
|
||||
|
||||
def write_mission_file_for_seed(mission_file, seed):
    """Instantiate a mission XML template for a given maze seed.

    Reads ``mission_file`` (whose name contains 'v0' and whose content
    contains a ``{SEED_PLACEHOLDER}`` format field), substitutes the seed,
    and writes the result next to the template with 'v0' replaced by the
    seed in the file name.

    Returns the path of the written mission file.
    """
    mission_file_path = mission_file.replace('v0', seed)

    with open(mission_file, 'r') as base_file:
        content = base_file.read().format(SEED_PLACEHOLDER=seed)

    # Use a context manager so the handle is closed even if the write fails
    # (the original opened/closed the output manually and also shadowed the
    # ``mission_file`` parameter with the file object).
    with open(mission_file_path, 'w') as out_file:
        out_file.write(content)

    return mission_file_path
|
||||
|
||||
|
||||
def run_rollout(trainable_type, mission_file, seed):
    """Run one evaluation episode of a trained agent on a seeded maze.

    Parameters:
        trainable_type: registry name of the RLlib trainable (e.g. "IMPALA").
        mission_file: path of the mission XML template (name contains 'v0').
        seed: maze seed substituted into the mission file.

    Returns the list of Malmo commands the agent issued (the trajectory).

    NOTE(review): relies on the module-level globals ``checkpoint_path`` and
    ``checkpoint_file`` set in the ``__main__`` block — confirm before
    reusing this function elsewhere.
    """
    # Writes the mission file for minerl
    mission_file_path = write_mission_file_for_seed(mission_file, seed)

    # Instantiate the agent. Note: the IMPALA trainer implementation in
    # Ray uses an AsyncSamplesOptimizer. Under the hood, this starts a
    # LearnerThread which will wait for training samples. This will fail
    # after a timeout, but has no influence on the rollout. See
    # https://github.com/ray-project/ray/blob/708dff6d8f7dd6f7919e06c1845f1fea0cca5b89/rllib/optimizers/aso_learner.py#L66
    config = {
        "env_config": {
            "mission": mission_file_path,
            "is_rollout": True,
            "seed": seed
        },
        "num_workers": 0
    }
    # Bug fix: use the trainable_type parameter instead of the global
    # ``args.run`` (the parameter was previously ignored).
    cls = get_trainable_cls(trainable_type)
    agent = cls(env="Minecraft", config=config)

    # The optimizer is not needed during a rollout
    agent.optimizer.stop()

    # Load state from checkpoint
    agent.restore(f'{checkpoint_path}/{checkpoint_file}')

    # Get a reference to the environment
    env = agent.workers.local_worker().env

    # Let the agent choose actions until the game is over
    obs = env.reset()
    done = False
    total_reward = 0

    while not done:
        action = agent.compute_action(obs)
        obs, reward, done, info = env.step(action)

        total_reward += reward

    print(f'Total reward using seed {seed}: {total_reward}')

    # This avoids a sigterm trace in the logs, see minerl.env.malmo.Instance
    env.instance.watcher_process.kill()

    env.close()
    agent.stop()

    return env.get_trajectory()
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Rollout driver: downloads a trained checkpoint from the Azure ML model
    # registry, replays it on several fixed maze seeds, and records a video
    # of each replay via Malmo.
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', required=True)
    parser.add_argument('--run', required=False, default="IMPALA")
    args = parser.parse_args()

    # Register custom Minecraft environment
    tune.register_env("Minecraft", create_env_for_rollout)

    ray.init(address='auto')

    # Download the model files (contains a checkpoint)
    ws = Run.get_context().experiment.workspace
    model = Model(ws, args.model_name)
    checkpoint_path = model.download(exist_ok=True)

    files_ = os.listdir(checkpoint_path)
    cp_pattern = re.compile('^checkpoint-\\d+$')

    # Pick a file named "checkpoint-<N>"; if several match, the last one in
    # directory order wins (directory listing order is unspecified).
    checkpoint_file = None
    for f_ in files_:
        if cp_pattern.match(f_):
            checkpoint_file = f_

    if checkpoint_file is None:
        raise Exception("No checkpoint file found.")

    # These are the Minecraft mission seeds for the rollouts
    rollout_seeds = ['1234', '43289', '65224', '983341']

    # Initialize the Malmo video recorder
    video_recorder = MalmoVideoRecorder()
    video_recorder.init_malmo()

    # Path references to the mission files
    base_training_mission_file = \
        'minecraft_missions/lava_maze_rollout-v0.xml'
    base_video_recording_mission_file = \
        'minecraft_missions/lava_maze_rollout_video.xml'

    # For each seed: roll out the policy, then replay the captured command
    # trajectory with the spectator camera to produce a video.
    for seed in rollout_seeds:
        trajectory = run_rollout(
            args.run,
            base_training_mission_file,
            seed)

        video_recorder.record_malmo_video(
            trajectory,
            base_video_recording_mission_file,
            seed)
|
||||
@@ -0,0 +1,45 @@
|
||||
import ray
|
||||
import ray.tune as tune
|
||||
|
||||
from utils import callbacks
|
||||
from minecraft_environment import create_env
|
||||
|
||||
|
||||
def stop(trial_id, result):
    """Tune stopping criterion for the training run.

    Stops when the agent solves the maze on average (mean episode
    reward of at least 1) or when total training time exceeds 5 hours.
    """
    # Solved: the mean episode reward has reached the goal reward.
    if result["episode_reward_mean"] >= 1:
        return True
    # Time budget exhausted: more than 5 hours of total training time.
    return result["time_total_s"] > 5 * 60 * 60
|
||||
|
||||
|
||||
if __name__ == '__main__':
    # Expose the custom Minecraft lava-maze environment to RLlib under the
    # id "Minecraft" (referenced by the "env" key in the config below).
    tune.register_env("Minecraft", create_env)

    # Join the Ray cluster started by Azure ML; 'auto' resolves the head
    # node's address from the environment.
    ray.init(address='auto')

    # Train with the IMPALA algorithm until `stop` returns True (maze
    # solved on average, or the 5-hour time budget is exhausted).
    tune.run(
        run_or_experiment="IMPALA",
        config={
            "env": "Minecraft",
            # Passed through to the registered environment creator.
            "env_config": {
                "mission": "minecraft_missions/lava_maze-v0.xml"
            },
            # 10 rollout workers with 2 CPUs each; each worker drives its
            # own Minecraft game instance.
            "num_workers": 10,
            "num_cpus_per_worker": 2,
            "rollout_fragment_length": 50,
            "train_batch_size": 1024,
            # Experience replay buffer settings.
            "replay_buffer_num_slots": 4000,
            "replay_proportion": 10,
            # Generous learner queue timeout — presumably to tolerate slow
            # Minecraft instance startup; confirm units against RLlib docs.
            "learner_queue_timeout": 900,
            "num_sgd_iter": 2,
            "num_data_loader_buffers": 2,
            # Epsilon-greedy exploration, annealed from 1.0 down to 0.02
            # over 500k timesteps.
            "exploration_config": {
                "type": "EpsilonGreedy",
                "initial_epsilon": 1.0,
                "final_epsilon": 0.02,
                "epsilon_timesteps": 500000
            },
            # Log metrics to the Azure ML run after each training iteration.
            "callbacks": {"on_train_result": callbacks.on_train_result},
        },
        stop=stop,
        # Save a final model checkpoint when training ends.
        checkpoint_at_end=True,
        # './logs' becomes part of the run's downloadable files and feeds
        # the Tensorboard integration.
        local_dir='./logs'
    )
|
||||
@@ -0,0 +1,18 @@
|
||||
'''RLlib callbacks module:
|
||||
Common callback methods to be passed to RLlib trainer.
|
||||
'''
|
||||
|
||||
from azureml.core import Run
|
||||
|
||||
|
||||
def on_train_result(info):
    '''Callback on train result: record trainer metrics on the Azure ML run.

    Reads the trainer's result dict from `info` and logs the
    'episode_reward_mean' and 'episodes_total' values to the current run.
    '''
    azureml_run = Run.get_context()
    result = info["result"]

    # Log each metric of interest under its own name on the run.
    for metric_name in ('episode_reward_mean', 'episodes_total'):
        azureml_run.log(name=metric_name, value=result[metric_name])
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 4.2 MiB |
@@ -0,0 +1,925 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
|
||||
"\n",
|
||||
"Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Reinforcement Learning in Azure Machine Learning - Training a Minecraft agent using custom environments\n",
|
||||
"\n",
|
||||
"This tutorial will show how to set up a more complex reinforcement\n",
|
||||
"learning (RL) training scenario. It demonstrates how to train an agent to\n",
|
||||
"navigate through a lava maze in the Minecraft game using Azure Machine\n",
|
||||
"Learning.\n",
|
||||
"\n",
|
||||
"**Please note:** This notebook trains an agent on a randomly generated\n",
|
||||
"Minecraft level. As a result, on rare occasions, a training run may fail\n",
|
||||
"to produce a model that can solve the maze. If this happens, you can\n",
|
||||
"re-run the training step as indicated below.\n",
|
||||
"\n",
|
||||
"**Please note:** This notebook uses 1 NC6 type node and 8 D2 type nodes\n",
|
||||
"for up to 5 hours of training, which corresponds to approximately $9.06 (USD)\n",
|
||||
"as of May 2020.\n",
|
||||
"\n",
|
||||
"Minecraft is currently one of the most popular video\n",
|
||||
"games and as such has been a study object for RL. [Project \n",
|
||||
"Malmo](https://www.microsoft.com/en-us/research/project/project-malmo/) is\n",
|
||||
"a platform for artificial intelligence experimentation and research built on\n",
|
||||
"top of Minecraft. We will use Minecraft [gym](https://gym.openai.com) environments from Project\n",
|
||||
"Malmo's 2019 MineRL competition, which are part of the \n",
|
||||
"[MineRL](http://minerl.io/docs/index.html) Python package.\n",
|
||||
"\n",
|
||||
"Minecraft environments require a display to run, so we will demonstrate\n",
|
||||
"how to set up a virtual display within the docker container used for training.\n",
|
||||
"Learning will be based on the agent's visual observations. To\n",
|
||||
"generate the necessary amount of sample data, we will run several\n",
|
||||
"instances of the Minecraft game in parallel. Below, you can see a video of\n",
|
||||
"a trained agent navigating a lava maze. Starting from the green position,\n",
|
||||
"it moves to the blue position by moving forward, turning left or turning right:\n",
|
||||
"\n",
|
||||
"<table style=\"width:50%\">\n",
|
||||
" <tr>\n",
|
||||
" <th style=\"text-align: center;\">\n",
|
||||
" <img src=\"./images/lava_maze_minecraft.gif\" alt=\"Minecraft lava maze\" align=\"middle\" margin-left=\"auto\" margin-right=\"auto\"/>\n",
|
||||
" </th>\n",
|
||||
" </tr>\n",
|
||||
" <tr style=\"text-align: center;\">\n",
|
||||
" <th>Fig 1. Video of a trained Minecraft agent navigating a lava maze.</th>\n",
|
||||
" </tr>\n",
|
||||
"</table>\n",
|
||||
"\n",
|
||||
"The tutorial will cover the following steps:\n",
|
||||
"- Initializing Azure Machine Learning resources for training\n",
|
||||
"- Training the RL agent with Azure Machine Learning service\n",
|
||||
"- Monitoring training progress\n",
|
||||
"- Reviewing training results\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Prerequisites\n",
|
||||
"\n",
|
||||
"The user should have completed the Azure Machine Learning introductory tutorial.\n",
|
||||
"You will need to make sure that you have a valid subscription id, a resource group and a\n",
|
||||
"workspace. For detailed instructions see [Tutorial: Get started creating\n",
|
||||
"your first ML experiment.](https://docs.microsoft.com/en-us/azure/machine-learning/tutorial-1st-experiment-sdk-setup)\n",
|
||||
"\n",
|
||||
"In addition, please follow the instructions in the [Reinforcement Learning in\n",
|
||||
"Azure Machine Learning - Setting Up Development Environment](../setup/devenv_setup.ipynb)\n",
|
||||
"notebook to correctly set up a Virtual Network which is required for completing \n",
|
||||
"this tutorial.\n",
|
||||
"\n",
|
||||
"While this is a standalone notebook, we highly recommend going over the\n",
|
||||
"introductory notebooks for RL first.\n",
|
||||
"- Getting started:\n",
|
||||
" - [RL using a compute instance with Azure Machine Learning service](../cartpole-on-compute-instance/cartpole_ci.ipynb)\n",
|
||||
" - [Using Azure Machine Learning compute](../cartpole-on-single-compute/cartpole_sc.ipynb)\n",
|
||||
"- [Scaling RL training runs with Azure Machine Learning service](../atari-on-distributed-compute/pong_rllib.ipynb)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Initialize resources\n",
|
||||
"\n",
|
||||
"All required Azure Machine Learning service resources for this tutorial can be set up from Jupyter.\n",
|
||||
"This includes:\n",
|
||||
"- Connecting to your existing Azure Machine Learning workspace.\n",
|
||||
"- Creating an experiment to track runs.\n",
|
||||
"- Creating remote compute targets for [Ray](https://docs.ray.io/en/latest/index.html).\n",
|
||||
"\n",
|
||||
"### Azure Machine Learning SDK\n",
|
||||
"\n",
|
||||
"Display the Azure Machine Learning SDK version."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import azureml.core\n",
|
||||
"print(\"Azure Machine Learning SDK Version: \", azureml.core.VERSION)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Connect to workspace\n",
|
||||
"\n",
|
||||
"Get a reference to an existing Azure Machine Learning workspace."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from azureml.core import Workspace\n",
|
||||
"\n",
|
||||
"ws = Workspace.from_config()\n",
|
||||
"print(ws.name, ws.location, ws.resource_group, sep=' | ')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create an experiment\n",
|
||||
"\n",
|
||||
"Create an experiment to track the runs in your workspace. A\n",
|
||||
"workspace can have multiple experiments and each experiment\n",
|
||||
"can be used to track multiple runs (see [documentation](https://docs.microsoft.com/en-us/python/api/azureml-core/azureml.core.experiment.experiment?view=azure-ml-py)\n",
|
||||
"for details)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"nbpresent": {
|
||||
"id": "bc70f780-c240-4779-96f3-bc5ef9a37d59"
|
||||
}
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from azureml.core import Experiment\n",
|
||||
"\n",
|
||||
"exp = Experiment(workspace=ws, name='minecraft-maze')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create or attach an existing compute resource\n",
|
||||
"\n",
|
||||
"A compute target is a designated compute resource where you\n",
|
||||
"run your training script. For more information, see [What\n",
|
||||
"are compute targets in Azure Machine Learning service?](https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target).\n",
|
||||
"\n",
|
||||
"#### GPU target for Ray head\n",
|
||||
"\n",
|
||||
"In the experiment setup for this tutorial, the Ray head node\n",
|
||||
"will run on a GPU-enabled node. A maximum cluster size\n",
|
||||
"of 1 node is therefore sufficient. If you wish to run\n",
|
||||
"multiple experiments in parallel using the same GPU\n",
|
||||
"cluster, you may elect to increase this number. The cluster\n",
|
||||
"will automatically scale down to 0 nodes when no training jobs\n",
|
||||
"are scheduled (see `min_nodes`).\n",
|
||||
"\n",
|
||||
"The code below creates a compute cluster of GPU-enabled NC6\n",
|
||||
"nodes. If the cluster with the specified name is already in\n",
|
||||
"your workspace the code will skip the creation process.\n",
|
||||
"\n",
|
||||
"Note that we must specify a Virtual Network during compute\n",
|
||||
"creation to allow communication between the cluster running\n",
|
||||
"the Ray head node and the additional Ray compute nodes. For\n",
|
||||
"details on how to setup the Virtual Network, please follow the\n",
|
||||
"instructions in the \"Prerequisites\" section above.\n",
|
||||
"\n",
|
||||
"**Note: Creation of a compute resource can take several minutes**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from azureml.core.compute import ComputeTarget, AmlCompute\n",
|
||||
"from azureml.core.compute_target import ComputeTargetException\n",
|
||||
"\n",
|
||||
"# please enter the name of your Virtual Network (see Prerequisites -> Workspace setup)\n",
|
||||
"vnet_name = 'your_vnet'\n",
|
||||
"\n",
|
||||
"# name of the Virtual Network subnet ('default' is the default name)\n",
|
||||
"subnet_name = 'default'\n",
|
||||
"\n",
|
||||
"gpu_cluster_name = 'gpu-cluster-nc6'\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" gpu_cluster = ComputeTarget(workspace=ws, name=gpu_cluster_name)\n",
|
||||
" print('Found existing compute target')\n",
|
||||
"except ComputeTargetException:\n",
|
||||
" print('Creating a new compute target...')\n",
|
||||
" compute_config = AmlCompute.provisioning_configuration(\n",
|
||||
" vm_size='Standard_NC6',\n",
|
||||
" min_nodes=0,\n",
|
||||
" max_nodes=1,\n",
|
||||
" vnet_resourcegroup_name=ws.resource_group,\n",
|
||||
" vnet_name=vnet_name,\n",
|
||||
" subnet_name=subnet_name)\n",
|
||||
"\n",
|
||||
" gpu_cluster = ComputeTarget.create(ws, gpu_cluster_name, compute_config)\n",
|
||||
" gpu_cluster.wait_for_completion(show_output=True, min_node_count=None, timeout_in_minutes=20)\n",
|
||||
"\n",
|
||||
" print('Cluster created.')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### CPU target for additional Ray nodes\n",
|
||||
"\n",
|
||||
"The code below creates a compute cluster of D2 nodes. If the cluster with the specified name is already in your workspace the code will skip the creation process.\n",
|
||||
"\n",
|
||||
"This cluster will be used to start additional Ray nodes\n",
|
||||
"increasing the cluster's CPU resources.\n",
|
||||
"\n",
|
||||
"**Note: Creation of a compute resource can take several minutes**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"cpu_cluster_name = 'cpu-cluster-d2'\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" cpu_cluster = ComputeTarget(workspace=ws, name=cpu_cluster_name)\n",
|
||||
" print('Found existing compute target')\n",
|
||||
"except ComputeTargetException:\n",
|
||||
" print('Creating a new compute target...')\n",
|
||||
" compute_config = AmlCompute.provisioning_configuration(\n",
|
||||
" vm_size='STANDARD_D2',\n",
|
||||
" min_nodes=0,\n",
|
||||
" max_nodes=10,\n",
|
||||
" vnet_resourcegroup_name=ws.resource_group,\n",
|
||||
" vnet_name=vnet_name,\n",
|
||||
" subnet_name=subnet_name)\n",
|
||||
"\n",
|
||||
" cpu_cluster = ComputeTarget.create(ws, cpu_cluster_name, compute_config)\n",
|
||||
" cpu_cluster.wait_for_completion(show_output=True, min_node_count=None, timeout_in_minutes=20)\n",
|
||||
"\n",
|
||||
" print('Cluster created.')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Training the agent\n",
|
||||
"\n",
|
||||
"### Training environments\n",
|
||||
"\n",
|
||||
"This tutorial uses custom docker images (CPU and GPU respectively)\n",
|
||||
"with the necessary software installed. The\n",
|
||||
"[Environment](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-use-environments)\n",
|
||||
"class stores the configuration for the training environment. The docker\n",
|
||||
"image is set via `env.docker.base_image` which can point to any\n",
|
||||
"publicly available docker image. `user_managed_dependencies`\n",
|
||||
"is set so that the preinstalled Python packages in the image are preserved.\n",
|
||||
"\n",
|
||||
"Note that since Minecraft requires a display to start, we set the `interpreter_path`\n",
|
||||
"such that the Python process is started via **xvfb-run**."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from azureml.core import Environment\n",
|
||||
"\n",
|
||||
"def create_env(env_type):\n",
|
||||
" env = Environment(name='minecraft-{env_type}'.format(env_type=env_type))\n",
|
||||
"\n",
|
||||
" env.docker.enabled = True\n",
|
||||
" env.docker.base_image = 'akdmsft/minecraft-{env_type}'.format(env_type=env_type)\n",
|
||||
"\n",
|
||||
" env.python.interpreter_path = \"xvfb-run -s '-screen 0 640x480x16 -ac +extension GLX +render' python\"\n",
|
||||
" env.python.user_managed_dependencies = True\n",
|
||||
" \n",
|
||||
" return env\n",
|
||||
" \n",
|
||||
"cpu_minecraft_env = create_env('cpu')\n",
|
||||
"gpu_minecraft_env = create_env('gpu')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Training script\n",
|
||||
"\n",
|
||||
"As described above, we use the MineRL Python package to launch\n",
|
||||
"Minecraft game instances. MineRL provides several OpenAI gym\n",
|
||||
"environments for different scenarios, such as chopping wood.\n",
|
||||
"Besides predefined environments, MineRL lets its users create\n",
|
||||
"custom Minecraft environments through\n",
|
||||
"[minerl.env](http://minerl.io/docs/api/env.html). In the helper\n",
|
||||
"file **minecraft_environment.py** provided with this tutorial, we use the\n",
|
||||
"latter option to customize a Minecraft level with a lava maze\n",
|
||||
"that the agent has to navigate. The agent receives a negative\n",
|
||||
"reward of -1 for falling into the lava, a negative reward of\n",
|
||||
"-0.02 for sending a command (i.e. navigating through the maze\n",
|
||||
"with fewer actions yields a higher total reward) and a positive reward\n",
|
||||
"of 1 for reaching the goal. To encourage the agent to explore\n",
|
||||
"the maze, it also receives a positive reward of 0.1 for visiting\n",
|
||||
"a tile for the first time.\n",
|
||||
"\n",
|
||||
"The agent learns purely from visual observations and the image\n",
|
||||
"is scaled to an 84x84 format, stacking four frames. For the\n",
|
||||
"purposes of this example, we use a small action space of size\n",
|
||||
"three: move forward, turn 90 degrees to the left, and turn 90\n",
|
||||
"degrees to the right.\n",
|
||||
"\n",
|
||||
"The training script itself registers the function to create training\n",
|
||||
"environments with the `tune.register_env` function and connects to\n",
|
||||
"the Ray cluster Azure Machine Learning service started on the GPU \n",
|
||||
"and CPU nodes. Lastly, it starts a RL training run with `tune.run()`.\n",
|
||||
"\n",
|
||||
"We recommend setting the `local_dir` parameter to `./logs` as this\n",
|
||||
"directory will automatically become available as part of the training\n",
|
||||
"run's files in the Azure Portal. The Tensorboard integration\n",
|
||||
"(see \"View the Tensorboard\" section below) also depends on the files'\n",
|
||||
"availability. For a list of common parameter options, please refer\n",
|
||||
"to the [Ray documentation](https://docs.ray.io/en/latest/rllib-training.html#common-parameters).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"# Taken from minecraft_environment.py and minecraft_train.py\n",
|
||||
"\n",
|
||||
"# Define a function to create a MineRL environment\n",
|
||||
"def create_env(config):\n",
|
||||
" mission = config['mission']\n",
|
||||
" port = 1000 * config.worker_index + config.vector_index\n",
|
||||
" print('*********************************************')\n",
|
||||
" print(f'* Worker {config.worker_index} creating from mission: {mission}, port {port}')\n",
|
||||
" print('*********************************************')\n",
|
||||
"\n",
|
||||
" if config.worker_index == 0:\n",
|
||||
" # The first environment is only used for checking the action and observation space.\n",
|
||||
" # By using a dummy environment, there's no need to spin up a Minecraft instance behind it\n",
|
||||
" # saving some CPU resources on the head node.\n",
|
||||
" return DummyEnv()\n",
|
||||
"\n",
|
||||
" env = EnvWrapper(mission, port)\n",
|
||||
" env = TrackingEnv(env)\n",
|
||||
" env = FrameStack(env, 2)\n",
|
||||
" \n",
|
||||
" return env\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"def stop(trial_id, result):\n",
|
||||
" return result[\"episode_reward_mean\"] >= 1 \\\n",
|
||||
" or result[\"time_total_s\"] > 5 * 60 * 60\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"if __name__ == '__main__':\n",
|
||||
" tune.register_env(\"Minecraft\", create_env)\n",
|
||||
"\n",
|
||||
" ray.init(address='auto')\n",
|
||||
"\n",
|
||||
" tune.run(\n",
|
||||
" run_or_experiment=\"IMPALA\",\n",
|
||||
" config={\n",
|
||||
" \"env\": \"Minecraft\",\n",
|
||||
" \"env_config\": {\n",
|
||||
" \"mission\": \"minecraft_missions/lava_maze-v0.xml\"\n",
|
||||
" },\n",
|
||||
" \"num_workers\": 10,\n",
|
||||
" \"num_cpus_per_worker\": 2,\n",
|
||||
" \"rollout_fragment_length\": 50,\n",
|
||||
" \"train_batch_size\": 1024,\n",
|
||||
" \"replay_buffer_num_slots\": 4000,\n",
|
||||
" \"replay_proportion\": 10,\n",
|
||||
" \"learner_queue_timeout\": 900,\n",
|
||||
" \"num_sgd_iter\": 2,\n",
|
||||
" \"num_data_loader_buffers\": 2,\n",
|
||||
" \"exploration_config\": {\n",
|
||||
" \"type\": \"EpsilonGreedy\",\n",
|
||||
" \"initial_epsilon\": 1.0,\n",
|
||||
" \"final_epsilon\": 0.02,\n",
|
||||
" \"epsilon_timesteps\": 500000\n",
|
||||
" },\n",
|
||||
" \"callbacks\": {\"on_train_result\": callbacks.on_train_result},\n",
|
||||
" },\n",
|
||||
" stop=stop,\n",
|
||||
" checkpoint_at_end=True,\n",
|
||||
" local_dir='./logs'\n",
|
||||
" )\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"### Submitting a training run\n",
|
||||
"\n",
|
||||
"Below, you create the training run using a `ReinforcementLearningEstimator`\n",
|
||||
"object, which contains all the configuration parameters for this experiment:\n",
|
||||
"- `source_directory`: Contains the training script and helper files to be\n",
|
||||
"copied onto the node running the Ray head.\n",
|
||||
"- `entry_script`: The training script, described in more detail above.\n",
|
||||
"- `compute_target`: The compute target for the Ray head and training\n",
|
||||
"script execution.\n",
|
||||
"- `environment`: The Azure machine learning environment definition for\n",
|
||||
"the node running the Ray head.\n",
|
||||
"- `worker_configuration`: The configuration object for the additional\n",
|
||||
"Ray nodes to be attached to the Ray cluster:\n",
|
||||
" - `compute_target`: The compute target for the additional Ray nodes.\n",
|
||||
" - `node_count`: The number of nodes to attach to the Ray cluster.\n",
|
||||
" - `environment`: The environment definition for the additional Ray nodes.\n",
|
||||
"- `max_run_duration_seconds`: The time after which to abort the run if it\n",
|
||||
"is still running.\n",
|
||||
"- `shm_size`: The size of docker container's shared memory block. \n",
|
||||
"\n",
|
||||
"For more details, please take a look at the [online documentation](https://docs.microsoft.com/en-us/python/api/azureml-contrib-reinforcementlearning/?view=azure-ml-py)\n",
|
||||
"for Azure Machine Learning service's reinforcement learning offering.\n",
|
||||
"\n",
|
||||
"We configure 8 extra D2 (worker) nodes for the Ray cluster, giving us a total of\n",
|
||||
"22 CPUs and 1 GPU. The GPU and one CPU are used by the IMPALA learner,\n",
|
||||
"and each MineRL environment receives 2 CPUs allowing us to spawn a total\n",
|
||||
"of 10 rollout workers (see `num_workers` parameter in the training script).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"Lastly, the `RunDetails` widget displays information about the submitted\n",
|
||||
"RL experiment, including a link to the Azure portal with more details."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from azureml.contrib.train.rl import ReinforcementLearningEstimator, WorkerConfiguration\n",
|
||||
"from azureml.widgets import RunDetails\n",
|
||||
"\n",
|
||||
"worker_config = WorkerConfiguration(\n",
|
||||
" compute_target=cpu_cluster, \n",
|
||||
" node_count=8,\n",
|
||||
" environment=cpu_minecraft_env)\n",
|
||||
"\n",
|
||||
"rl_est = ReinforcementLearningEstimator(\n",
|
||||
" source_directory='files',\n",
|
||||
" entry_script='minecraft_train.py',\n",
|
||||
" compute_target=gpu_cluster,\n",
|
||||
" environment=gpu_minecraft_env,\n",
|
||||
" worker_configuration=worker_config,\n",
|
||||
" max_run_duration_seconds=6 * 60 * 60,\n",
|
||||
" shm_size=1024 * 1024 * 1024 * 30)\n",
|
||||
"\n",
|
||||
"train_run = exp.submit(rl_est)\n",
|
||||
"\n",
|
||||
"RunDetails(train_run).show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# If you wish to cancel the run before it completes, uncomment and execute:\n",
|
||||
"#train_run.cancel()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Monitoring training progress\n",
|
||||
"\n",
|
||||
"### View the Tensorboard\n",
|
||||
"\n",
|
||||
"The Tensorboard can be displayed via the Azure Machine Learning service's\n",
|
||||
"[Tensorboard API](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-monitor-tensorboard).\n",
|
||||
"When running locally, please make sure to follow the instructions in the\n",
|
||||
"link and install required packages. Running this cell will output a URL\n",
|
||||
"for the Tensorboard.\n",
|
||||
"\n",
|
||||
"Note that the training script sets the log directory when starting RLlib\n",
|
||||
"via the `local_dir` parameter. `./logs` will automatically appear in\n",
|
||||
"the downloadable files for a run. Since this script is executed on the\n",
|
||||
"Ray head node run, we need to get a reference to it as shown below.\n",
|
||||
"\n",
|
||||
"The Tensorboard API will continuously stream logs from the run.\n",
|
||||
"\n",
|
||||
"**Note: It may take a couple of minutes after the run is in \"Running\" state\n",
|
||||
"before Tensorboard files are available and the board will refresh automatically**"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import time\n",
|
||||
"from azureml.tensorboard import Tensorboard\n",
|
||||
"\n",
|
||||
"head_run = None\n",
|
||||
"\n",
|
||||
"timeout = 60\n",
|
||||
"while timeout > 0 and head_run is None:\n",
|
||||
" timeout -= 1\n",
|
||||
" \n",
|
||||
" try:\n",
|
||||
" head_run = next(r for r in train_run.get_children() if r.id.endswith('head'))\n",
|
||||
" except StopIteration:\n",
|
||||
" time.sleep(1)\n",
|
||||
"\n",
|
||||
"tb = Tensorboard([head_run])\n",
|
||||
"tb.start()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Review results\n",
|
||||
"\n",
|
||||
"Please ensure that the training run has completed before continuing with this section."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"train_run.wait_for_completion()\n",
|
||||
"\n",
|
||||
"print('Training run completed.')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"**Please note:** If the final \"episode_reward_mean\" metric from the training run is negative,\n",
|
||||
"the produced model does not solve the problem of navigating the maze well. You can view\n",
|
||||
"the metric on the Tensorboard or in \"Metrics\" section of the head run in the Azure Machine Learning\n",
|
||||
"portal. We recommend training a new model by rerunning the notebook starting from \"Submitting a training run\".\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"### Export final model\n",
|
||||
"\n",
|
||||
"The key result from the training run is the final checkpoint\n",
|
||||
"containing the state of the IMPALA trainer (model) upon meeting the\n",
|
||||
"stopping criteria specified in `minecraft_train.py`.\n",
|
||||
"\n",
|
||||
"Azure Machine Learning service offers the [Model.register()](https://docs.microsoft.com/en-us/python/api/azureml-core/azureml.core.model.model?view=azure-ml-py)\n",
|
||||
"API which allows you to persist the model files from the\n",
|
||||
"training run. We identify the directory containing the\n",
|
||||
"final model written during the training run and register\n",
|
||||
"it with Azure Machine Learning service. We use a Dataset\n",
|
||||
"object to filter out the correct files."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import re\n",
|
||||
"import os\n",
|
||||
"import tempfile\n",
|
||||
"\n",
|
||||
"from azureml.core import Dataset\n",
|
||||
"\n",
|
||||
"path_prefix = os.path.join(tempfile.gettempdir(), 'tmp_training_artifacts')\n",
|
||||
"\n",
|
||||
"run_artifacts_path = os.path.join('azureml', head_run.id)\n",
|
||||
"datastore = ws.get_default_datastore()\n",
|
||||
"\n",
|
||||
"run_artifacts_ds = Dataset.File.from_files(datastore.path(os.path.join(run_artifacts_path, '**')))\n",
|
||||
"\n",
|
||||
"cp_pattern = re.compile('.*checkpoint-\\\\d+$')\n",
|
||||
"\n",
|
||||
"checkpoint_files = [file for file in run_artifacts_ds.to_path() if cp_pattern.match(file)]\n",
|
||||
"\n",
|
||||
"# There should only be one checkpoint with our training settings...\n",
|
||||
"final_checkpoint = os.path.dirname(os.path.join(run_artifacts_path, os.path.normpath(checkpoint_files[-1][1:])))\n",
|
||||
"datastore.download(target_path=path_prefix, prefix=final_checkpoint.replace('\\\\', '/'), show_progress=True)\n",
|
||||
"\n",
|
||||
"print('Download complete.')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from azureml.core.model import Model\n",
|
||||
"\n",
|
||||
"model_name = 'final_model_minecraft_maze'\n",
|
||||
"\n",
|
||||
"model = Model.register(\n",
|
||||
" workspace=ws,\n",
|
||||
" model_path=os.path.join(path_prefix, final_checkpoint),\n",
|
||||
" model_name=model_name,\n",
|
||||
" description='Model of an agent trained to navigate a lava maze in Minecraft.')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Models can be used through a variety of APIs. Please see the\n",
|
||||
"[documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-deploy-and-where)\n",
|
||||
"for more details.\n",
|
||||
"\n",
|
||||
"### Test agent performance in a rollout\n",
|
||||
"\n",
|
||||
"To observe the trained agent's behavior, it is a common practice to\n",
|
||||
"view its behavior in a rollout. The previous reinforcement learning\n",
|
||||
"tutorials explain rollouts in more detail.\n",
|
||||
"\n",
|
||||
"The provided `minecraft_rollout.py` script loads the final checkpoint\n",
|
||||
"of the trained agent from the model registered with Azure Machine Learning\n",
|
||||
"service. It then starts a rollout on 4 different lava maze layouts, that\n",
|
||||
"are all larger and thus more difficult than the maze the agent was trained\n",
|
||||
"on. The script further records videos by replaying the agent's decisions\n",
|
||||
"in [Malmo](https://github.com/microsoft/malmo). Malmo supports multiple\n",
|
||||
"agents in the same environment, thus allowing us to capture videos that\n",
|
||||
"depict the agent from another agent's perspective. The provided\n",
|
||||
"`malmo_video_recorder.py` file and the Malmo Github repository have more\n",
|
||||
"details on the video recording setup.\n",
|
||||
"\n",
|
||||
"You can view the rewards for each rollout episode in the logs for the 'head'\n",
|
||||
"run submitted below. In some episodes, the agent may fail to reach the goal\n",
|
||||
"due to the higher level of difficulty - in practice, we could continue\n",
|
||||
"training the agent on harder tasks starting with the final checkpoint."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"script_params = {\n",
|
||||
" '--model_name': model_name\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"rollout_est = ReinforcementLearningEstimator(\n",
|
||||
" source_directory='files',\n",
|
||||
" entry_script='minecraft_rollout.py',\n",
|
||||
" script_params=script_params,\n",
|
||||
" compute_target=gpu_cluster,\n",
|
||||
" environment=gpu_minecraft_env,\n",
|
||||
" shm_size=1024 * 1024 * 1024 * 30)\n",
|
||||
"\n",
|
||||
"rollout_run = exp.submit(rollout_est)\n",
|
||||
"\n",
|
||||
"RunDetails(rollout_run).show()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### View videos captured during rollout\n",
|
||||
"\n",
|
||||
"To inspect the agent's training progress you can view the videos captured\n",
|
||||
"during the rollout episodes. First, ensure that the training run has\n",
|
||||
"completed."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rollout_run.wait_for_completion()\n",
|
||||
"\n",
|
||||
"head_run_rollout = next(r for r in rollout_run.get_children() if r.id.endswith('head'))\n",
|
||||
"\n",
|
||||
"print('Rollout completed.')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, you need to download the video files from the training run. We use a\n",
|
||||
"Dataset to filter out the video files which are in tgz archives."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"rollout_run_artifacts_path = os.path.join('azureml', head_run_rollout.id)\n",
|
||||
"datastore = ws.get_default_datastore()\n",
|
||||
"\n",
|
||||
"rollout_run_artifacts_ds = Dataset.File.from_files(datastore.path(os.path.join(rollout_run_artifacts_path, '**')))\n",
|
||||
"\n",
|
||||
"video_archives = [file for file in rollout_run_artifacts_ds.to_path() if file.endswith('.tgz')]\n",
|
||||
"video_archives = [os.path.join(rollout_run_artifacts_path, os.path.normpath(file[1:])) for file in video_archives]\n",
|
||||
"\n",
|
||||
"datastore.download(\n",
|
||||
" target_path=path_prefix,\n",
|
||||
" prefix=os.path.dirname(video_archives[0]).replace('\\\\', '/'),\n",
|
||||
" show_progress=True)\n",
|
||||
"\n",
|
||||
"print('Download complete.')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Next, unzip the video files and rename them by the Minecraft mission seed used\n",
|
||||
"(see `minecraft_rollout.py` for more details on how the seed is used)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import tarfile\n",
|
||||
"import shutil\n",
|
||||
"\n",
|
||||
"training_artifacts_dir = './training_artifacts'\n",
|
||||
"video_dir = os.path.join(training_artifacts_dir, 'videos')\n",
|
||||
"video_files = []\n",
|
||||
"\n",
|
||||
"for tar_file_path in video_archives:\n",
|
||||
" seed = tar_file_path[tar_file_path.index('rollout_') + len('rollout_'): tar_file_path.index('.tgz')]\n",
|
||||
" \n",
|
||||
" tar = tarfile.open(os.path.join(path_prefix, tar_file_path).replace('\\\\', '/'), 'r')\n",
|
||||
" tar_info = next(t_info for t_info in tar.getmembers() if t_info.name.endswith('mp4'))\n",
|
||||
" tar.extract(tar_info, video_dir)\n",
|
||||
" tar.close()\n",
|
||||
" \n",
|
||||
" unzipped_folder = os.path.join(video_dir, next(f_ for f_ in os.listdir(video_dir) if not f_.endswith('mp4'))) \n",
|
||||
" video_file = os.path.join(unzipped_folder,'video.mp4')\n",
|
||||
" final_video_path = os.path.join(video_dir, '{seed}.mp4'.format(seed=seed))\n",
|
||||
" \n",
|
||||
" shutil.move(video_file, final_video_path) \n",
|
||||
" video_files.append(final_video_path)\n",
|
||||
" \n",
|
||||
" shutil.rmtree(unzipped_folder)\n",
|
||||
"\n",
|
||||
"# Clean up any downloaded 'tmp' files\n",
|
||||
"shutil.rmtree(path_prefix)\n",
|
||||
"\n",
|
||||
"print('Local video files:\\n', video_files)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Finally, run the cell below to display the videos in-line. In some cases,\n",
|
||||
"the agent may struggle to find the goal since the maze size was increased\n",
|
||||
"compared to training."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from IPython.core.display import display, HTML\n",
|
||||
"\n",
|
||||
"index = 0\n",
|
||||
"while index < len(video_files) - 1:\n",
|
||||
" display(\n",
|
||||
" HTML('\\\n",
|
||||
" <video controls alt=\"cannot display video\" autoplay loop width=49%> \\\n",
|
||||
" <source src=\"{f1}\" type=\"video/mp4\"> \\\n",
|
||||
" </video> \\\n",
|
||||
" <video controls alt=\"cannot display video\" autoplay loop width=49%> \\\n",
|
||||
" <source src=\"{f2}\" type=\"video/mp4\"> \\\n",
|
||||
" </video>'.format(f1=video_files[index], f2=video_files[index + 1]))\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" index += 2\n",
|
||||
"\n",
|
||||
"if index < len(video_files):\n",
|
||||
" display(\n",
|
||||
" HTML('\\\n",
|
||||
" <video controls alt=\"cannot display video\" autoplay loop width=49%> \\\n",
|
||||
" <source src=\"{f1}\" type=\"video/mp4\"> \\\n",
|
||||
" </video>'.format(f1=video_files[index]))\n",
|
||||
" )"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Cleaning up\n",
|
||||
"\n",
|
||||
"Below, you can find code snippets for your convenience to clean up any resources created as part of this tutorial you don't wish to retain."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# to stop the Tensorboard, uncomment and run\n",
|
||||
"#tb.stop()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# to delete the gpu compute target, uncomment and run\n",
|
||||
"#gpu_cluster.delete()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# to delete the cpu compute target, uncomment and run\n",
|
||||
"#cpu_cluster.delete()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# to delete the registered model, uncomment and run\n",
|
||||
"#model.delete()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# to delete the local video files, uncomment and run\n",
|
||||
"#shutil.rmtree(training_artifacts_dir)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Next steps\n",
|
||||
"\n",
|
||||
"This is currently the last introductory tutorial for Azure Machine Learning\n",
|
||||
"service's Reinforcement\n",
|
||||
"Learning offering. We would love to hear your feedback to build the features\n",
|
||||
"you need!\n",
|
||||
"\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"authors": [
|
||||
{
|
||||
"name": "andress"
|
||||
}
|
||||
],
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.6",
|
||||
"language": "python",
|
||||
"name": "python36"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.7.0"
|
||||
},
|
||||
"notice": "Copyright (c) Microsoft Corporation. All rights reserved.\u00e2\u20ac\u00afLicensed under the MIT License.\u00e2\u20ac\u00af "
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 4
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
name: minecraft
|
||||
dependencies:
|
||||
- pip:
|
||||
- azureml-sdk
|
||||
- azureml-contrib-reinforcementlearning
|
||||
- azureml-widgets
|
||||
- tensorboard
|
||||
- azureml-tensorboard
|
||||
@@ -72,7 +72,7 @@
|
||||
"from azureml.core import Workspace\n",
|
||||
"\n",
|
||||
"ws = Workspace.from_config()\n",
|
||||
"print(ws.name, ws.location, ws.resource_group, sep = ' | ')"
|
||||
"print(ws.name, ws.location, ws.resource_group, sep = ' | ') "
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
2
index.md
2
index.md
@@ -124,6 +124,7 @@ Machine Learning notebook samples and encourage efficient retrieval of topics an
|
||||
| [pong_rllib](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/reinforcement-learning/atari-on-distributed-compute/pong_rllib.ipynb) | | | | | | |
|
||||
| [cartpole_ci](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/reinforcement-learning/cartpole-on-compute-instance/cartpole_ci.ipynb) | | | | | | |
|
||||
| [cartpole_cc](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/reinforcement-learning/cartpole-on-single-compute/cartpole_cc.ipynb) | | | | | | |
|
||||
| [minecraft](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/minecraft.ipynb) | | | | | | |
|
||||
| [devenv_setup](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/reinforcement-learning/setup/devenv_setup.ipynb) | | | | | | |
|
||||
| [Logging APIs](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/track-and-monitor-experiments/logging-api/logging-api.ipynb) | Logging APIs and analyzing results | None | None | None | None | None |
|
||||
| [distributed-cntk-with-custom-docker](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/training-with-deep-learning/distributed-cntk-with-custom-docker/distributed-cntk-with-custom-docker.ipynb) | | | | | | |
|
||||
@@ -132,5 +133,6 @@ Machine Learning notebook samples and encourage efficient retrieval of topics an
|
||||
| [tutorial-1st-experiment-sdk-train](https://github.com/Azure/MachineLearningNotebooks/blob/master//tutorials/create-first-ml-experiment/tutorial-1st-experiment-sdk-train.ipynb) | | | | | | |
|
||||
| [img-classification-part1-training](https://github.com/Azure/MachineLearningNotebooks/blob/master//tutorials/image-classification-mnist-data/img-classification-part1-training.ipynb) | | | | | | |
|
||||
| [img-classification-part2-deploy](https://github.com/Azure/MachineLearningNotebooks/blob/master//tutorials/image-classification-mnist-data/img-classification-part2-deploy.ipynb) | | | | | | |
|
||||
| [img-classification-part3-deploy-encrypted](https://github.com/Azure/MachineLearningNotebooks/blob/master//tutorials/image-classification-mnist-data/img-classification-part3-deploy-encrypted.ipynb) | | | | | | |
|
||||
| [tutorial-pipeline-batch-scoring-classification](https://github.com/Azure/MachineLearningNotebooks/blob/master//tutorials/machine-learning-pipelines-advanced/tutorial-pipeline-batch-scoring-classification.ipynb) | | | | | | |
|
||||
| [regression-automated-ml](https://github.com/Azure/MachineLearningNotebooks/blob/master//tutorials/regression-automl-nyc-taxi-data/regression-automated-ml.ipynb) | | | | | | |
|
||||
|
||||
@@ -19,6 +19,7 @@ The following tutorials are intended to provide an introductory overview of Azur
|
||||
| [Train your first ML Model](https://docs.microsoft.com/azure/machine-learning/tutorial-1st-experiment-sdk-train) | Learn the foundational design patterns in Azure Machine Learning and train a scikit-learn model based on a diabetes data set. | [tutorial-quickstart-train-model.ipynb](create-first-ml-experiment/tutorial-1st-experiment-sdk-train.ipynb) | Regression | Scikit-Learn
|
||||
| [Train an image classification model](https://docs.microsoft.com/azure/machine-learning/tutorial-train-models-with-aml) | Train a scikit-learn image classification model. | [img-classification-part1-training.ipynb](image-classification-mnist-data/img-classification-part1-training.ipynb) | Image Classification | Scikit-Learn
|
||||
| [Deploy an image classification model](https://docs.microsoft.com/azure/machine-learning/tutorial-deploy-models-with-aml) | Deploy a scikit-learn image classification model to Azure Container Instances. | [img-classification-part2-deploy.ipynb](image-classification-mnist-data/img-classification-part2-deploy.ipynb) | Image Classification | Scikit-Learn
|
||||
| [Deploy an encrypted inferencing service](https://docs.microsoft.com/azure/machine-learning/tutorial-deploy-models-with-aml) |Deploy an image classification model for encrypted inferencing in Azure Container Instances | [img-classification-part3-deploy-encrypted.ipynb](image-classification-mnist-data/img-classification-part3-deploy-encrypted.ipynb) | Image Classification | Scikit-Learn
|
||||
| [Use automated machine learning to predict taxi fares](https://docs.microsoft.com/azure/machine-learning/tutorial-auto-train-models) | Train a regression model to predict taxi fares using Automated Machine Learning. | [regression-part2-automated-ml.ipynb](regression-automl-nyc-taxi-data/regression-automated-ml.ipynb) | Regression | Automated ML
|
||||
|
||||
## Advanced Samples
|
||||
|
||||
@@ -0,0 +1,615 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
|
||||
"\n",
|
||||
"Licensed under the MIT License."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Tutorial #3: Deploy an image classification model for encrypted inferencing in Azure Container Instance (ACI)\n",
|
||||
"\n",
|
||||
"This tutorial is **a new addition to the two-part series**. In the [previous tutorial](img-classification-part1-training.ipynb), you trained machine learning models and then registered a model in your workspace on the cloud. \n",
|
||||
"\n",
|
||||
"Now, you're ready to deploy the model as a encrypted inferencing web service in [Azure Container Instances](https://docs.microsoft.com/azure/container-instances/) (ACI). A web service is an image, in this case a Docker image, that encapsulates the scoring logic and the model itself. \n",
|
||||
"\n",
|
||||
"In this part of the tutorial, you use Azure Machine Learning service (Preview) to:\n",
|
||||
"\n",
|
||||
"> * Set up your testing environment\n",
|
||||
"> * Retrieve the model from your workspace\n",
|
||||
"> * Test the model locally\n",
|
||||
"> * Deploy the model to ACI\n",
|
||||
"> * Test the deployed model\n",
|
||||
"\n",
|
||||
"ACI is a great solution for testing and understanding the workflow. For scalable production deployments, consider using Azure Kubernetes Service. For more information, see [how to deploy and where](https://docs.microsoft.com/azure/machine-learning/service/how-to-deploy-and-where).\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Prerequisites\n",
|
||||
"\n",
|
||||
"Complete the model training in the [Tutorial #1: Train an image classification model with Azure Machine Learning](train-models.ipynb) notebook. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# If you did NOT complete the tutorial, you can instead run this cell \n",
|
||||
"# This will register a model and download the data needed for this tutorial\n",
|
||||
"# These prerequisites are created in the training tutorial\n",
|
||||
"# Feel free to skip this cell if you completed the training tutorial \n",
|
||||
"\n",
|
||||
"# register a model\n",
|
||||
"from azureml.core import Workspace\n",
|
||||
"ws = Workspace.from_config()\n",
|
||||
"\n",
|
||||
"from azureml.core.model import Model\n",
|
||||
"\n",
|
||||
"model_name = \"sklearn_mnist\"\n",
|
||||
"model = Model.register(model_path=\"sklearn_mnist_model.pkl\",\n",
|
||||
" model_name=model_name,\n",
|
||||
" tags={\"data\": \"mnist\", \"model\": \"classification\"},\n",
|
||||
" description=\"Mnist handwriting recognition\",\n",
|
||||
" workspace=ws)\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Setup the Environment \n",
|
||||
"\n",
|
||||
"Add `encrypted-inference` package as a conda dependency "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from azureml.core.environment import Environment\n",
|
||||
"from azureml.core.conda_dependencies import CondaDependencies\n",
|
||||
"\n",
|
||||
"# to install required packages\n",
|
||||
"env = Environment('tutorial-env')\n",
|
||||
"cd = CondaDependencies.create(pip_packages=['azureml-dataprep[pandas,fuse]>=1.1.14', 'azureml-defaults', 'azure-storage-blob', 'encrypted-inference==0.9'], conda_packages = ['scikit-learn==0.22.1'])\n",
|
||||
"\n",
|
||||
"env.python.conda_dependencies = cd\n",
|
||||
"\n",
|
||||
"# Register environment to re-use later\n",
|
||||
"env.register(workspace = ws)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Set up the environment\n",
|
||||
"\n",
|
||||
"Start by setting up a testing environment.\n",
|
||||
"\n",
|
||||
"### Import packages\n",
|
||||
"\n",
|
||||
"Import the Python packages needed for this tutorial."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": [
|
||||
"check version"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%matplotlib inline\n",
|
||||
"import numpy as np\n",
|
||||
"import matplotlib.pyplot as plt\n",
|
||||
" \n",
|
||||
"import azureml.core\n",
|
||||
"\n",
|
||||
"# display the core SDK version number\n",
|
||||
"print(\"Azure ML SDK Version: \", azureml.core.VERSION)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Install Homomorphic Encryption based library for Secure Inferencing\n",
|
||||
"\n",
|
||||
"Our library is based on [Microsoft SEAL](https://github.com/Microsoft/SEAL) and pubished to [PyPi.org](https://pypi.org/project/encrypted-inference) as an easy to use package "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install encrypted-inference==0.9"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Deploy as web service\n",
|
||||
"\n",
|
||||
"Deploy the model as a web service hosted in ACI. \n",
|
||||
"\n",
|
||||
"To build the correct environment for ACI, provide the following:\n",
|
||||
"* A scoring script to show how to use the model\n",
|
||||
"* A configuration file to build the ACI\n",
|
||||
"* The model you trained before\n",
|
||||
"\n",
|
||||
"### Create scoring script\n",
|
||||
"\n",
|
||||
"Create the scoring script, called score.py, used by the web service call to show how to use the model.\n",
|
||||
"\n",
|
||||
"You must include two required functions into the scoring script:\n",
|
||||
"* The `init()` function, which typically loads the model into a global object. This function is run only once when the Docker container is started. \n",
|
||||
"\n",
|
||||
"* The `run(input_data)` function uses the model to predict a value based on the input data. Inputs and outputs to the run typically use JSON for serialization and de-serialization, but other formats are supported. The function fetches homomorphic encryption based public keys that are uploaded by the service caller. \n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%writefile score.py\n",
|
||||
"import json\n",
|
||||
"import os\n",
|
||||
"import pickle\n",
|
||||
"import joblib\n",
|
||||
"from azure.storage.blob import BlobServiceClient, BlobClient, ContainerClient, PublicAccess\n",
|
||||
"from encrypted.inference.eiserver import EIServer\n",
|
||||
"\n",
|
||||
"def init():\n",
|
||||
" global model\n",
|
||||
" # AZUREML_MODEL_DIR is an environment variable created during deployment.\n",
|
||||
" # It is the path to the model folder (./azureml-models/$MODEL_NAME/$VERSION)\n",
|
||||
" # For multiple models, it points to the folder containing all deployed models (./azureml-models)\n",
|
||||
" model_path = os.path.join(os.getenv('AZUREML_MODEL_DIR'), 'sklearn_mnist_model.pkl')\n",
|
||||
" model = joblib.load(model_path)\n",
|
||||
"\n",
|
||||
" global server\n",
|
||||
" server = EIServer(model.coef_, model.intercept_, verbose=True)\n",
|
||||
"\n",
|
||||
"def run(raw_data):\n",
|
||||
"\n",
|
||||
" json_properties = json.loads(raw_data)\n",
|
||||
"\n",
|
||||
" key_id = json_properties['key_id']\n",
|
||||
" conn_str = json_properties['conn_str']\n",
|
||||
" container = json_properties['container']\n",
|
||||
" data = json_properties['data']\n",
|
||||
"\n",
|
||||
" # download the Galois keys from blob storage\n",
|
||||
" #TODO optimize by caching the keys locally \n",
|
||||
" blob_service_client = BlobServiceClient.from_connection_string(conn_str=conn_str)\n",
|
||||
" blob_client = blob_service_client.get_blob_client(container=container, blob=key_id)\n",
|
||||
" public_keys = blob_client.download_blob().readall()\n",
|
||||
" \n",
|
||||
" result = {}\n",
|
||||
" # make prediction\n",
|
||||
" result = server.predict(data, public_keys)\n",
|
||||
"\n",
|
||||
" # you can return any data type as long as it is JSON-serializable\n",
|
||||
" return result"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Create configuration file\n",
|
||||
"\n",
|
||||
"Create a deployment configuration file and specify the number of CPUs and gigabyte of RAM needed for your ACI container. While it depends on your model, the default of 1 core and 1 gigabyte of RAM is usually sufficient for many models. If you feel you need more later, you would have to recreate the image and redeploy the service."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": [
|
||||
"configure web service",
|
||||
"aci"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from azureml.core.webservice import AciWebservice\n",
|
||||
"\n",
|
||||
"aciconfig = AciWebservice.deploy_configuration(cpu_cores=1, \n",
|
||||
" memory_gb=1, \n",
|
||||
" tags={\"data\": \"MNIST\", \"method\" : \"sklearn\"}, \n",
|
||||
" description='Encrypted Predict MNIST with sklearn + SEAL')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Deploy in ACI\n",
|
||||
"Estimated time to complete: **about 2-5 minutes**\n",
|
||||
"\n",
|
||||
"Configure the image and deploy. The following code goes through these steps:\n",
|
||||
"\n",
|
||||
"1. Create environment object containing dependencies needed by the model using the environment file (`myenv.yml`)\n",
|
||||
"1. Create inference configuration necessary to deploy the model as a web service using:\n",
|
||||
" * The scoring file (`score.py`)\n",
|
||||
" * envrionment object created in previous step\n",
|
||||
"1. Deploy the model to the ACI container.\n",
|
||||
"1. Get the web service HTTP endpoint."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": [
|
||||
"configure image",
|
||||
"create image",
|
||||
"deploy web service",
|
||||
"aci"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%time\n",
|
||||
"from azureml.core.webservice import Webservice\n",
|
||||
"from azureml.core.model import InferenceConfig\n",
|
||||
"from azureml.core.environment import Environment\n",
|
||||
"from azureml.core import Workspace\n",
|
||||
"from azureml.core.model import Model\n",
|
||||
"\n",
|
||||
"ws = Workspace.from_config()\n",
|
||||
"model = Model(ws, 'sklearn_mnist')\n",
|
||||
"\n",
|
||||
"myenv = Environment.get(workspace=ws, name=\"tutorial-env\")\n",
|
||||
"inference_config = InferenceConfig(entry_script=\"score.py\", environment=myenv)\n",
|
||||
"\n",
|
||||
"service = Model.deploy(workspace=ws, \n",
|
||||
" name='sklearn-mnist-svc', \n",
|
||||
" models=[model], \n",
|
||||
" inference_config=inference_config, \n",
|
||||
" deployment_config=aciconfig)\n",
|
||||
"\n",
|
||||
"service.wait_for_deployment(show_output=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Get the scoring web service's HTTP endpoint, which accepts REST client calls. This endpoint can be shared with anyone who wants to test the web service or integrate it into an application."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": [
|
||||
"get scoring uri"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"print(service.scoring_uri)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Test the model\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Download test data\n",
|
||||
"Download the test data to the **./data/** directory"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from azureml.core import Dataset\n",
|
||||
"from azureml.opendatasets import MNIST\n",
|
||||
"\n",
|
||||
"data_folder = os.path.join(os.getcwd(), 'data')\n",
|
||||
"os.makedirs(data_folder, exist_ok=True)\n",
|
||||
"\n",
|
||||
"mnist_file_dataset = MNIST.get_file_dataset()\n",
|
||||
"mnist_file_dataset.download(data_folder, overwrite=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Load test data\n",
|
||||
"\n",
|
||||
"Load the test data from the **./data/** directory created during the training tutorial."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from utils import load_data\n",
|
||||
"import os\n",
|
||||
"import glob\n",
|
||||
"\n",
|
||||
"data_folder = os.path.join(os.getcwd(), 'data')\n",
|
||||
"# note we also shrink the intensity values (X) from 0-255 to 0-1. This helps the neural network converge faster\n",
|
||||
"X_test = load_data(glob.glob(os.path.join(data_folder,\"**/t10k-images-idx3-ubyte.gz\"), recursive=True)[0], False) / 255.0\n",
|
||||
"y_test = load_data(glob.glob(os.path.join(data_folder,\"**/t10k-labels-idx1-ubyte.gz\"), recursive=True)[0], True).reshape(-1)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Predict test data\n",
|
||||
"\n",
|
||||
"Feed the test dataset to the model to get predictions.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"The following code goes through these steps:\n",
|
||||
"\n",
|
||||
"1. Create our Homomorphic Encryption based client \n",
|
||||
"\n",
|
||||
"1. Upload HE generated public keys \n",
|
||||
"\n",
|
||||
"1. Encrypt the data\n",
|
||||
"\n",
|
||||
"1. Send the data as JSON to the web service hosted in ACI. \n",
|
||||
"\n",
|
||||
"1. Use the SDK's `run` API to invoke the service. You can also make raw calls using any HTTP tool such as curl."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Create our Homomorphic Encryption based client \n",
|
||||
"\n",
|
||||
"Create a new EILinearRegressionClient and setup the public keys "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from encrypted.inference.eiclient import EILinearRegressionClient\n",
|
||||
"\n",
|
||||
"# Create a new Encrypted inference client and a new secret key.\n",
|
||||
"edp = EILinearRegressionClient(verbose=True)\n",
|
||||
"\n",
|
||||
"public_keys_blob, public_keys_data = edp.get_public_keys()\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Upload HE generated public keys\n",
|
||||
"\n",
|
||||
"Upload the public keys to the workspace default blob store. This will allow us to share the keys with the inference server"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import azureml.core\n",
|
||||
"from azureml.core import Workspace, Datastore\n",
|
||||
"import os\n",
|
||||
"\n",
|
||||
"ws = Workspace.from_config()\n",
|
||||
"\n",
|
||||
"datastore = ws.get_default_datastore()\n",
|
||||
"container_name=datastore.container_name\n",
|
||||
"\n",
|
||||
"# Create a local file and write the keys to it\n",
|
||||
"public_keys = open(public_keys_blob, \"wb\")\n",
|
||||
"public_keys.write(public_keys_data)\n",
|
||||
"public_keys.close()\n",
|
||||
"\n",
|
||||
"# Upload the file to blob store\n",
|
||||
"datastore.upload_files([public_keys_blob])\n",
|
||||
"\n",
|
||||
"# Delete the local file\n",
|
||||
"os.remove(public_keys_blob)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Encrypt the data "
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"#choose any one sample from the test data \n",
|
||||
"sample_index = 1\n",
|
||||
"\n",
|
||||
"#encrypt the data\n",
|
||||
"raw_data = edp.encrypt(X_test[sample_index])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Send the test data to the webservice hosted in ACI\n",
|
||||
"\n",
|
||||
"Feed the test dataset to the model to get predictions. We will need to send the connection string to the blob storage where the public keys were uploaded \n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import json\n",
|
||||
"from azureml.core import Webservice\n",
|
||||
"\n",
|
||||
"service = Webservice(ws, 'sklearn-mnist-svc')\n",
|
||||
"\n",
|
||||
"#pass the connection string for blob storage to give the server access to the uploaded public keys \n",
|
||||
"conn_str_template = 'DefaultEndpointsProtocol={};AccountName={};AccountKey={};EndpointSuffix=core.windows.net'\n",
|
||||
"conn_str = conn_str_template.format(datastore.protocol, datastore.account_name, datastore.account_key)\n",
|
||||
"\n",
|
||||
"#build the json \n",
|
||||
"data = json.dumps({\"data\": raw_data, \"key_id\" : public_keys_blob, \"conn_str\" : conn_str, \"container\" : container_name })\n",
|
||||
"data = bytes(data, encoding='ASCII')\n",
|
||||
"\n",
|
||||
"print ('Making an encrypted inference web service call ')\n",
|
||||
"eresult = service.run(input_data=data)\n",
|
||||
"\n",
|
||||
"print ('Received encrypted inference results')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"#### Decrypt the data\n",
|
||||
"\n",
|
||||
"Use the client to decrypt the results"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np \n",
|
||||
"\n",
|
||||
"results = edp.decrypt(eresult)\n",
|
||||
"\n",
|
||||
"print ('Decrypted the results ', results)\n",
|
||||
"\n",
|
||||
"#Apply argmax to identify the prediction result\n",
|
||||
"prediction = np.argmax(results)\n",
|
||||
"\n",
|
||||
"print ( ' Prediction : ', prediction)\n",
|
||||
"print ( ' Actual Label : ', y_test[sample_index])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Clean up resources\n",
|
||||
"\n",
|
||||
"To keep the resource group and workspace for other tutorials and exploration, you can delete only the ACI deployment using this API call:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"tags": [
|
||||
"delete web service"
|
||||
]
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"service.delete()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"\n",
|
||||
"If you're not going to use what you've created here, delete the resources you just created with this quickstart so you don't incur any charges. In the Azure portal, select and delete your resource group. You can also keep the resource group, but delete a single workspace by displaying the workspace properties and selecting the Delete button.\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"## Next steps\n",
|
||||
"\n",
|
||||
"In this Azure Machine Learning tutorial, you used Python to:\n",
|
||||
"\n",
|
||||
"> * Set up your testing environment\n",
|
||||
"> * Retrieve the model from your workspace\n",
|
||||
"> * Test the model locally\n",
|
||||
"> * Deploy the model to ACI\n",
|
||||
"> * Test the deployed model\n",
|
||||
" \n",
|
||||
"You can also try out the [regression tutorial](regression-part1-data-prep.ipynb)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
""
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"authors": [
|
||||
{
|
||||
"name": "vkanne"
|
||||
}
|
||||
],
|
||||
"celltoolbar": "Edit Metadata",
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3.6",
|
||||
"language": "python",
|
||||
"name": "python36"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.6"
|
||||
},
|
||||
"msauthor": "vkanne"
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
name: img-classification-part3-deploy-encrypted
|
||||
dependencies:
|
||||
- pip:
|
||||
- azureml-sdk
|
||||
- matplotlib
|
||||
- sklearn
|
||||
- pandas
|
||||
- azureml-opendatasets
|
||||
- encrypted-inference==0.9
|
||||
- azure-storage-blob
|
||||
Reference in New Issue
Block a user