Mirror of https://github.com/Azure/MachineLearningNotebooks.git (synced 2025-12-20 01:27:06 -05:00)
Compare commits: update-spa...release_up (1 commit)
| Author | SHA1 | Date |
|---|---|---|
|  | efc61d2219 |  |
@@ -103,7 +103,7 @@
 "source": [
 "import azureml.core\n",
 "\n",
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
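Aside (not part of the diff): when the installed SDK trails the version a notebook was authored with, the usual remedy is a pip upgrade from within the notebook. A minimal sketch, assuming pip is available in the kernel's environment:

```python
# Check the installed Azure ML SDK version, as the notebooks above do.
import azureml.core
print("You are currently using version", azureml.core.VERSION, "of the Azure ML SDK")

# In a notebook cell, upgrading is typically done with pip, e.g.:
# %pip install --upgrade azureml-sdk
```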
@@ -21,8 +21,8 @@ dependencies:
  - pip:
    # Required packages for AzureML execution, history, and data preparation.
-   - azureml-widgets~=1.23.0
+   - azureml-widgets~=1.24.0
    - pytorch-transformers==1.0.0
    - spacy==2.1.8
    - https://aka.ms/automl-resources/packages/en_core_web_sm-2.1.0.tar.gz
-   - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.23.0/validated_win32_requirements.txt [--no-deps]
+   - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.24.0/validated_win32_requirements.txt [--no-deps]
@@ -21,8 +21,8 @@ dependencies:
  - pip:
    # Required packages for AzureML execution, history, and data preparation.
-   - azureml-widgets~=1.23.0
+   - azureml-widgets~=1.24.0
    - pytorch-transformers==1.0.0
    - spacy==2.1.8
    - https://aka.ms/automl-resources/packages/en_core_web_sm-2.1.0.tar.gz
-   - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.23.0/validated_linux_requirements.txt [--no-deps]
+   - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.24.0/validated_linux_requirements.txt [--no-deps]
@@ -22,8 +22,8 @@ dependencies:
  - pip:
    # Required packages for AzureML execution, history, and data preparation.
-   - azureml-widgets~=1.23.0
+   - azureml-widgets~=1.24.0
    - pytorch-transformers==1.0.0
    - spacy==2.1.8
    - https://aka.ms/automl-resources/packages/en_core_web_sm-2.1.0.tar.gz
-   - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.23.0/validated_darwin_requirements.txt [--no-deps]
+   - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.24.0/validated_darwin_requirements.txt [--no-deps]
@@ -105,7 +105,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-classification-bank-marketing-all-features
dependencies:
- pip:
  - azureml-sdk
@@ -93,7 +93,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-classification-credit-card-fraud
dependencies:
- pip:
  - azureml-sdk
@@ -96,7 +96,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-classification-text-dnn
dependencies:
- pip:
  - azureml-sdk
@@ -81,7 +81,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-continuous-retraining
dependencies:
- pip:
  - azureml-sdk
@@ -5,7 +5,7 @@ set options=%3
 set PIP_NO_WARN_SCRIPT_LOCATION=0

 IF "%conda_env_name%"=="" SET conda_env_name="azure_automl_experimental"
-IF "%automl_env_file%"=="" SET automl_env_file="automl_env.yml"
+IF "%automl_env_file%"=="" SET automl_env_file="automl_thin_client_env.yml"

 IF NOT EXIST %automl_env_file% GOTO YmlMissing
@@ -12,7 +12,7 @@ fi

 if [ "$AUTOML_ENV_FILE" == "" ]
 then
-    AUTOML_ENV_FILE="automl_env.yml"
+    AUTOML_ENV_FILE="automl_thin_client_env.yml"
 fi

 if [ ! -f $AUTOML_ENV_FILE ]; then
@@ -12,7 +12,7 @@ fi

 if [ "$AUTOML_ENV_FILE" == "" ]
 then
-    AUTOML_ENV_FILE="automl_env.yml"
+    AUTOML_ENV_FILE="automl_thin_client_env_mac.yml"
 fi

 if [ ! -f $AUTOML_ENV_FILE ]; then
@@ -7,6 +7,7 @@ dependencies:
  - nb_conda
  - cython
  - urllib3<1.24
+ - PyJWT < 2.0.0
  - numpy==1.18.5

  - pip:
@@ -15,4 +16,3 @@ dependencies:
    - azureml-sdk
    - azureml-widgets
    - pandas
-   - PyJWT < 2.0.0
@@ -8,6 +8,7 @@ dependencies:
  - nb_conda
  - cython
  - urllib3<1.24
+ - PyJWT < 2.0.0
  - numpy==1.18.5

  - pip:
@@ -16,4 +17,3 @@ dependencies:
    - azureml-sdk
    - azureml-widgets
    - pandas
-   - PyJWT < 2.0.0
@@ -90,7 +90,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -194,7 +194,6 @@
 "|**n_cross_validations**|Number of cross validation splits.|\n",
 "|**training_data**|(sparse) array-like, shape = [n_samples, n_features]|\n",
 "|**label_column_name**|(sparse) array-like, shape = [n_samples, ], targets values.|\n",
-"|**scenario**|We need to set this parameter to 'Latest' to enable some experimental features. This parameter should not be set outside of this experimental notebook.|\n",
 "\n",
 "**_You can find more information about primary metrics_** [here](https://docs.microsoft.com/en-us/azure/machine-learning/service/how-to-configure-auto-train#primary-metric)"
 ]
@@ -223,7 +222,6 @@
 " compute_target = compute_target,\n",
 " training_data = train_data,\n",
 " label_column_name = label,\n",
-" scenario='Latest',\n",
 " **automl_settings\n",
 " )"
 ]
@@ -0,0 +1,4 @@
name: auto-ml-regression-model-proxy
dependencies:
- pip:
  - azureml-sdk
@@ -113,7 +113,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-forecasting-beer-remote
dependencies:
- pip:
  - azureml-sdk
@@ -87,7 +87,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-forecasting-bike-share
dependencies:
- pip:
  - azureml-sdk
@@ -97,7 +97,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-forecasting-energy-demand
dependencies:
- pip:
  - azureml-sdk
@@ -94,7 +94,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-forecasting-function
dependencies:
- pip:
  - azureml-sdk
@@ -82,7 +82,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-forecasting-orange-juice-sales
dependencies:
- pip:
  - azureml-sdk
@@ -96,7 +96,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-classification-credit-card-fraud-local
dependencies:
- pip:
  - azureml-sdk
@@ -96,7 +96,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-regression-explanation-featurization
dependencies:
- pip:
  - azureml-sdk
@@ -92,7 +92,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n",
+"print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n",
 "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")"
 ]
},
@@ -0,0 +1,4 @@
name: auto-ml-regression
dependencies:
- pip:
  - azureml-sdk
84 how-to-use-azureml/azure-synapse/README.md Normal file
@@ -0,0 +1,84 @@
Azure Synapse Analytics is a limitless analytics service that brings together data integration, enterprise data warehousing, and big data analytics. It gives you the freedom to query data on your terms, using either serverless or dedicated resources at scale. Azure Synapse brings these worlds together with a unified experience to ingest, explore, prepare, manage, and serve data for immediate BI and machine learning needs. A core offering within Azure Synapse Analytics is serverless Apache Spark pools enhanced for big data workloads.

The Synapse integration in Azure ML is for customers who want to use Apache Spark in Azure Synapse Analytics to prepare data at scale in Azure ML before training their ML models. It lets customers work through their end-to-end ML lifecycle, including large-scale data preparation, model training, and deployment, within the Azure ML workspace, without having to use suboptimal tools for machine learning or switch between multiple tools for data preparation and model training. Performing all ML tasks within Azure ML reduces the time customers need to iterate on a machine learning project, which typically includes multiple rounds of data preparation and training.

In the public preview, the following capabilities are provided:

- Link an Azure Synapse Analytics workspace to an Azure Machine Learning workspace (via ARM, UI, or SDK)
- Attach Apache Spark pools powered by Azure Synapse Analytics as Azure Machine Learning compute targets (via ARM, UI, or SDK)
- Launch Apache Spark sessions in notebooks and perform interactive data exploration and preparation (a sketch follows this list). This interactive experience leverages Apache Spark magics, and customers have session-level Conda support to install packages.
- Productionize ML pipelines by leveraging Apache Spark pools to pre-process big data
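The interactive Spark-magic experience mentioned above can be sketched as follows. This is a minimal, hypothetical sketch (not part of the original README): it assumes the `azureml-synapse` package is installed in the notebook kernel and that a compute named `synapse-compute` has already been attached; the `%synapse` line/cell magics and the data path are placeholders following the package's documented usage.

```python
# Hypothetical notebook-cell sketch; %synapse magics require azureml-synapse.
# Start a Spark session on the attached pool (line magic in its own cell):
#   %synapse start --compute-target synapse-compute
#
# Run PySpark against the session (cell magic at the top of a cell):
#   %%synapse
#   df = spark.read.csv("abfss://<container>@<account>.dfs.core.windows.net/data.csv", header=True)
#   df.show(5)
#
# Stop the session when finished:
#   %synapse stop
```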
# Using Synapse in Azure Machine Learning

## Create Synapse resources

Follow the documents below to create a Synapse workspace; a resource-setup.sh script is available for you to create the resources.

- Create from the [Portal](https://docs.microsoft.com/en-us/azure/synapse-analytics/quickstart-create-workspace)
- Create from the [CLI](https://docs.microsoft.com/en-us/azure/synapse-analytics/quickstart-create-workspace-cli)

Follow the documents below to create a Synapse Spark pool:

- Create from the [Portal](https://docs.microsoft.com/en-us/azure/synapse-analytics/quickstart-create-apache-spark-pool-portal)
- Create from the [CLI](https://docs.microsoft.com/en-us/cli/azure/ext/synapse/synapse/spark/pool?view=azure-cli-latest)
## Link Synapse workspace

Make sure you are an owner of the Synapse workspace so that you can link it to the Azure ML workspace.
You can run resource-setup.py to link the Synapse workspace and attach the compute:

```python
from azureml.core import Workspace

ws = Workspace.from_config()

from azureml.core import LinkedService, SynapseWorkspaceLinkedServiceConfiguration

synapse_link_config = SynapseWorkspaceLinkedServiceConfiguration(
    subscription_id="<subscription id>",
    resource_group="<resource group>",
    name="<synapse workspace name>"
)

linked_service = LinkedService.register(
    workspace=ws,
    name='<link name>',
    linked_service_config=synapse_link_config)
```
## Attach Synapse Spark pool as Azure ML compute

```python
from azureml.core.compute import SynapseCompute, ComputeTarget

spark_pool_name = "<spark pool name>"
attached_synapse_name = "<attached compute name>"

attach_config = SynapseCompute.attach_configuration(
    linked_service,
    type="SynapseSpark",
    pool_name=spark_pool_name)

synapse_compute = ComputeTarget.attach(
    workspace=ws,
    name=attached_synapse_name,
    attach_configuration=attach_config)

synapse_compute.wait_for_completion()
```
## Set up permissions

Grant the Spark admin role to the system-assigned identity of the linked service so that users can submit experiment runs or pipeline runs from the Azure ML workspace to the Synapse Spark pool.

Grant the Spark admin role to the specific user so that the user can start Spark sessions on the Synapse Spark pool.

You can get the system-assigned identity information by running:

```python
print(linked_service.system_assigned_identity_principal_id)
```

- Launch Synapse Studio for the Synapse workspace and grant the linked service MSI the "Synapse Apache Spark administrator" role.

- In the Azure portal, grant the linked service MSI the "Storage Blob Data Contributor" role on the primary ADLS Gen2 account of the Synapse workspace in order to use the library management feature.
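With permissions in place, a Spark preprocessing step can be submitted from an Azure ML pipeline to the attached pool. The following is a minimal sketch (not part of the original README) using `SynapseSparkStep` from `azureml-pipeline-steps`; the script `prep.py` under `./code`, the experiment name, and the resource sizes are placeholder assumptions.

```python
# Hypothetical sketch: run a PySpark script on the attached Synapse Spark pool
# as a pipeline step. Assumes ./code/prep.py exists and attached_synapse_name
# refers to the compute attached above.
from azureml.core import Experiment
from azureml.pipeline.core import Pipeline
from azureml.pipeline.steps import SynapseSparkStep

step = SynapseSparkStep(
    name="synapse-prep",
    file="prep.py",                      # PySpark script to run on the pool
    source_directory="./code",
    compute_target=attached_synapse_name,
    driver_memory="7g",
    driver_cores=4,
    executor_memory="7g",
    executor_cores=2,
    num_executors=2)

pipeline = Pipeline(workspace=ws, steps=[step])
run = Experiment(ws, "synapse-prep-demo").submit(pipeline)  # placeholder name
```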
@@ -0,0 +1,6 @@
name: multi-model-register-and-deploy
dependencies:
- pip:
  - azureml-sdk
  - numpy
  - scikit-learn
@@ -0,0 +1,6 @@
name: model-register-and-deploy
dependencies:
- pip:
  - azureml-sdk
  - numpy
  - scikit-learn
@@ -0,0 +1,4 @@
name: deploy-aks-with-controlled-rollout
dependencies:
- pip:
  - azureml-sdk
@@ -0,0 +1,4 @@
name: enable-app-insights-in-production-service
dependencies:
- pip:
  - azureml-sdk
@@ -0,0 +1,8 @@
name: onnx-convert-aml-deploy-tinyyolo
dependencies:
- pip:
  - azureml-sdk
  - numpy
  - git+https://github.com/apple/coremltools@v2.1
  - onnx<1.7.0
  - onnxmltools
@@ -0,0 +1,9 @@
name: onnx-inference-facial-expression-recognition-deploy
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - matplotlib
  - numpy
  - onnx<1.7.0
  - opencv-python-headless
@@ -0,0 +1,9 @@
name: onnx-inference-mnist-deploy
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - matplotlib
  - numpy
  - onnx<1.7.0
  - opencv-python-headless
@@ -0,0 +1,4 @@
name: onnx-model-register-and-deploy
dependencies:
- pip:
  - azureml-sdk
@@ -0,0 +1,4 @@
name: onnx-modelzoo-aml-deploy-resnet50
dependencies:
- pip:
  - azureml-sdk
@@ -0,0 +1,5 @@
name: onnx-train-pytorch-aml-deploy-mnist
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,5 @@
name: production-deploy-to-aks-gpu
dependencies:
- pip:
  - azureml-sdk
  - tensorflow
@@ -0,0 +1,8 @@
name: production-deploy-to-aks-ssl
dependencies:
- pip:
  - azureml-sdk
  - matplotlib
  - tqdm
  - scipy
  - sklearn
@@ -0,0 +1,8 @@
name: production-deploy-to-aks
dependencies:
- pip:
  - azureml-sdk
  - matplotlib
  - tqdm
  - scipy
  - sklearn
@@ -0,0 +1,4 @@
name: model-register-and-deploy-spark
dependencies:
- pip:
  - azureml-sdk
@@ -0,0 +1,9 @@
name: explain-model-on-amlcompute
dependencies:
- pip:
  - azureml-sdk
  - azureml-interpret
  - interpret-community[visualization]
  - matplotlib
  - azureml-dataset-runtime
  - ipywidgets
@@ -226,36 +226,6 @@
 " ('classifier', SVC(C=1.0, probability=True))])"
 ]
},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"'''\n",
-"# Uncomment below if sklearn-pandas is not installed\n",
-"#!pip install sklearn-pandas\n",
-"from sklearn_pandas import DataFrameMapper\n",
-"\n",
-"# Impute, standardize the numeric features and one-hot encode the categorical features. \n",
-"\n",
-"\n",
-"numeric_transformations = [([f], Pipeline(steps=[('imputer', SimpleImputer(strategy='median')), ('scaler', StandardScaler())])) for f in numerical]\n",
-"\n",
-"categorical_transformations = [([f], OneHotEncoder(handle_unknown='ignore', sparse=False)) for f in categorical]\n",
-"\n",
-"transformations = numeric_transformations + categorical_transformations\n",
-"\n",
-"# Append classifier to preprocessing pipeline.\n",
-"# Now we have a full prediction pipeline.\n",
-"clf = Pipeline(steps=[('preprocessor', transformations),\n",
-" ('classifier', SVC(C=1.0, probability=True))]) \n",
-"\n",
-"\n",
-"\n",
-"'''"
-]
-},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -0,0 +1,8 @@
name: save-retrieve-explanations-run-history
dependencies:
- pip:
  - azureml-sdk
  - azureml-interpret
  - interpret-community[visualization]
  - matplotlib
  - ipywidgets
@@ -166,12 +166,12 @@
 "source": [
 "from sklearn.model_selection import train_test_split\n",
 "import joblib\n",
+"from sklearn.compose import ColumnTransformer\n",
 "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n",
 "from sklearn.impute import SimpleImputer\n",
 "from sklearn.pipeline import Pipeline\n",
 "from sklearn.linear_model import LogisticRegression\n",
 "from sklearn.ensemble import RandomForestClassifier\n",
-"from sklearn_pandas import DataFrameMapper\n",
 "\n",
 "from interpret.ext.blackbox import TabularExplainer\n",
 "\n",
@@ -201,17 +201,23 @@
 "# Store the numerical columns in a list numerical\n",
 "numerical = attritionXData.columns.difference(categorical)\n",
 "\n",
-"numeric_transformations = [([f], Pipeline(steps=[\n",
+"# We create the preprocessing pipelines for both numeric and categorical data.\n",
+"numeric_transformer = Pipeline(steps=[\n",
 " ('imputer', SimpleImputer(strategy='median')),\n",
-" ('scaler', StandardScaler())])) for f in numerical]\n",
+" ('scaler', StandardScaler())])\n",
 "\n",
-"categorical_transformations = [([f], OneHotEncoder(handle_unknown='ignore', sparse=False)) for f in categorical]\n",
+"categorical_transformer = Pipeline(steps=[\n",
+" ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),\n",
+" ('onehot', OneHotEncoder(handle_unknown='ignore'))])\n",
 "\n",
-"transformations = numeric_transformations + categorical_transformations\n",
+"transformations = ColumnTransformer(\n",
+" transformers=[\n",
+" ('num', numeric_transformer, numerical),\n",
+" ('cat', categorical_transformer, categorical)])\n",
 "\n",
 "# Append classifier to preprocessing pipeline.\n",
 "# Now we have a full prediction pipeline.\n",
-"clf = Pipeline(steps=[('preprocessor', DataFrameMapper(transformations)),\n",
+"clf = Pipeline(steps=[('preprocessor', transformations),\n",
 " ('classifier', RandomForestClassifier())])\n",
 "\n",
 "# Split data into train and test\n",
@@ -350,7 +356,7 @@
 "# the submitted job is run in. Note the remote environment(s) needs to be similar to the local\n",
 "# environment, otherwise if a model is trained or deployed in a different environment this can\n",
 "# cause errors. Please take extra care when specifying your dependencies in a production environment.\n",
-"myenv = CondaDependencies.create(pip_packages=['sklearn-pandas', 'pyyaml', sklearn_dep, pandas_dep] + azureml_pip_packages,\n",
+"myenv = CondaDependencies.create(pip_packages=['pyyaml', sklearn_dep, pandas_dep] + azureml_pip_packages,\n",
 " pin_sdk_version=False)\n",
 "\n",
 "with open(\"myenv.yml\",\"w\") as f:\n",
@@ -0,0 +1,8 @@
name: train-explain-model-locally-and-deploy
dependencies:
- pip:
  - azureml-sdk
  - azureml-interpret
  - interpret-community[visualization]
  - matplotlib
  - ipywidgets
@@ -294,7 +294,7 @@
 "# the submitted job is run in. Note the remote environment(s) needs to be similar to the local\n",
 "# environment, otherwise if a model is trained or deployed in a different environment this can\n",
 "# cause errors. Please take extra care when specifying your dependencies in a production environment.\n",
-"azureml_pip_packages.extend(['sklearn-pandas', 'pyyaml', sklearn_dep, pandas_dep])\n",
+"azureml_pip_packages.extend(['pyyaml', sklearn_dep, pandas_dep])\n",
 "run_config.environment.python.conda_dependencies = CondaDependencies.create(pip_packages=azureml_pip_packages)\n",
 "# Now submit a run on AmlCompute\n",
 "from azureml.core.script_run_config import ScriptRunConfig\n",
@@ -458,7 +458,7 @@
 "# the submitted job is run in. Note the remote environment(s) needs to be similar to the local\n",
 "# environment, otherwise if a model is trained or deployed in a different environment this can\n",
 "# cause errors. Please take extra care when specifying your dependencies in a production environment.\n",
-"azureml_pip_packages.extend(['sklearn-pandas', 'pyyaml', sklearn_dep, pandas_dep])\n",
+"azureml_pip_packages.extend(['pyyaml', sklearn_dep, pandas_dep])\n",
 "myenv = CondaDependencies.create(pip_packages=azureml_pip_packages)\n",
 "\n",
 "with open(\"myenv.yml\",\"w\") as f:\n",
@@ -0,0 +1,10 @@
name: train-explain-model-on-amlcompute-and-deploy
dependencies:
- pip:
  - azureml-sdk
  - azureml-interpret
  - interpret-community[visualization]
  - matplotlib
  - azureml-dataset-runtime
  - azureml-core
  - ipywidgets
@@ -5,13 +5,13 @@
 import os
 import pandas as pd
 import zipfile
+from sklearn.model_selection import train_test_split
 import joblib
+from sklearn.compose import ColumnTransformer
-from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import StandardScaler, OneHotEncoder
 from sklearn.impute import SimpleImputer
 from sklearn.pipeline import Pipeline
 from sklearn.linear_model import LogisticRegression
-from sklearn_pandas import DataFrameMapper

 from azureml.core.run import Run
 from interpret.ext.blackbox import TabularExplainer
@@ -57,16 +57,22 @@ for col, value in attritionXData.iteritems():
 # store the numerical columns
 numerical = attritionXData.columns.difference(categorical)

-numeric_transformations = [([f], Pipeline(steps=[
+# We create the preprocessing pipelines for both numeric and categorical data.
+numeric_transformer = Pipeline(steps=[
     ('imputer', SimpleImputer(strategy='median')),
-    ('scaler', StandardScaler())])) for f in numerical]
+    ('scaler', StandardScaler())])

-categorical_transformations = [([f], OneHotEncoder(handle_unknown='ignore', sparse=False)) for f in categorical]
+categorical_transformer = Pipeline(steps=[
+    ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
+    ('onehot', OneHotEncoder(handle_unknown='ignore'))])

-transformations = numeric_transformations + categorical_transformations
+transformations = ColumnTransformer(
+    transformers=[
+        ('num', numeric_transformer, numerical),
+        ('cat', categorical_transformer, categorical)])

 # append classifier to preprocessing pipeline
-clf = Pipeline(steps=[('preprocessor', DataFrameMapper(transformations)),
+clf = Pipeline(steps=[('preprocessor', transformations),
                       ('classifier', LogisticRegression(solver='lbfgs'))])

 # get the run this was submitted from to interact with run history
@@ -0,0 +1,5 @@
name: aml-pipelines-data-transfer
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,5 @@
name: aml-pipelines-getting-started
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,5 @@
name: aml-pipelines-how-to-use-modulestep
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,5 @@
name: aml-pipelines-how-to-use-pipeline-drafts
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,9 @@
name: aml-pipelines-parameter-tuning-with-hyperdrive
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - matplotlib
  - numpy
  - pandas_ml
  - azureml-dataset-runtime[pandas,fuse]
@@ -0,0 +1,6 @@
name: aml-pipelines-publish-and-run-using-rest-endpoint
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - requests
@@ -0,0 +1,5 @@
name: aml-pipelines-setup-schedule-for-a-published-pipeline
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,6 @@
name: aml-pipelines-setup-versioned-pipeline-endpoints
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - requests
@@ -0,0 +1,5 @@
name: aml-pipelines-showcasing-datapath-and-pipelineparameter
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,5 @@
name: aml-pipelines-showcasing-dataset-and-pipelineparameter
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,4 @@
name: aml-pipelines-with-automated-machine-learning-step
dependencies:
- pip:
  - azureml-sdk
@@ -0,0 +1,5 @@
name: aml-pipelines-with-commandstep-r
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,5 @@
name: aml-pipelines-with-commandstep
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,5 @@
name: aml-pipelines-with-data-dependency-steps
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,6 @@
name: aml-pipelines-with-notebook-runner-step
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - azureml-contrib-notebook
@@ -0,0 +1,10 @@
name: nyc-taxi-data-regression-model-building
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - azureml-opendatasets
  - azureml-train-automl
  - matplotlib
  - pandas
  - pyarrow
@@ -0,0 +1,7 @@
name: file-dataset-image-inference-mnist
dependencies:
- pip:
  - azureml-sdk
  - azureml-pipeline-steps
  - azureml-widgets
  - pandas
@@ -0,0 +1,7 @@
name: tabular-dataset-inference-iris
dependencies:
- pip:
  - azureml-sdk
  - azureml-pipeline-steps
  - azureml-widgets
  - pandas
@@ -0,0 +1,7 @@
name: pipeline-style-transfer-parallel-run
dependencies:
- pip:
  - azureml-sdk
  - azureml-pipeline-steps
  - azureml-widgets
  - requests
@@ -0,0 +1,5 @@
name: distributed-chainer
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,12 @@
name: train-hyperparameter-tune-deploy-with-chainer
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - numpy
  - matplotlib
  - json
  - urllib
  - gzip
  - struct
  - requests
@@ -0,0 +1,5 @@
name: fastai-with-custom-docker
dependencies:
- pip:
  - azureml-sdk
  - fastai==1.0.61
@@ -0,0 +1,8 @@
name: train-hyperparameter-tune-deploy-with-keras
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - tensorflow
  - keras<=2.3.1
  - matplotlib
@@ -21,7 +21,8 @@
 "metadata": {},
 "source": [
 "# Distributed PyTorch with DistributedDataParallel\n",
-"In this tutorial, you will train a PyTorch model on the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset using distributed training with PyTorch's `DistributedDataParallel` module across a GPU cluster. "
+"\n",
+"In this tutorial, you will train a PyTorch model on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset using distributed training with PyTorch's `DistributedDataParallel` module across a GPU cluster."
 ]
},
{
@@ -113,7 +114,7 @@
 "from azureml.core.compute_target import ComputeTargetException\n",
 "\n",
 "# choose a name for your cluster\n",
-"cluster_name = \"gpu-cluster\"\n",
+"cluster_name = 'gpu-cluster'\n",
 "\n",
 "try:\n",
 " compute_target = ComputeTarget(workspace=ws, name=cluster_name)\n",
@@ -139,6 +140,68 @@
 "The above code creates GPU compute. If you instead want to create CPU compute, provide a different VM size to the `vm_size` parameter, such as `STANDARD_D2_V2`."
 ]
},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Prepare dataset\n",
+"\n",
+"Prepare the dataset used for training. We will first download and extract the publicly available CIFAR-10 dataset from the cs.toronto.edu website and then create an Azure ML FileDataset to use the data for training."
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Download and extract CIFAR-10 data"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"import urllib\n",
+"import tarfile\n",
+"import os\n",
+"\n",
+"url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'\n",
+"filename = 'cifar-10-python.tar.gz'\n",
+"data_root = 'cifar-10'\n",
+"filepath = os.path.join(data_root, filename)\n",
+"\n",
+"if not os.path.isdir(data_root):\n",
+" os.makedirs(data_root, exist_ok=True)\n",
+" urllib.request.urlretrieve(url, filepath)\n",
+" with tarfile.open(filepath, \"r:gz\") as tar:\n",
+" tar.extractall(path=data_root)\n",
+" os.remove(filepath) # delete tar.gz file after extraction"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"### Create Azure ML dataset\n",
+"\n",
+"The `upload_directory` method will upload the data to a datastore and create a FileDataset from it. In this tutorial we will use the workspace's default datastore."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"from azureml.core import Dataset\n",
+"\n",
+"datastore = ws.get_default_datastore()\n",
+"dataset = Dataset.File.upload_directory(\n",
+" src_dir=data_root, target=(datastore, data_root)\n",
+")"
+]
+},
 {
 "cell_type": "markdown",
 "metadata": {},
@@ -161,8 +224,6 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"import os\n",
-"\n",
 "project_folder = './pytorch-distr'\n",
 "os.makedirs(project_folder, exist_ok=True)"
 ]
@@ -172,26 +233,14 @@
 "metadata": {},
 "source": [
 "### Prepare training script\n",
-"Now you will need to create your training script. In this tutorial, the script for distributed training of MNIST is already provided for you at `pytorch_mnist.py`. In practice, you should be able to take any custom PyTorch training script as is and run it with Azure ML without having to modify your code.\n",
-"\n",
-"However, if you would like to use Azure ML's [metric logging](https://docs.microsoft.com/azure/machine-learning/service/concept-azure-machine-learning-architecture#logging) capabilities, you will have to add a small amount of Azure ML logic inside your training script. In this example, at each logging interval, we will log the loss for that minibatch to our Azure ML run.\n",
-"\n",
-"To do so, in `pytorch_mnist.py`, we will first access the Azure ML `Run` object within the script:\n",
-"```Python\n",
-"from azureml.core.run import Run\n",
-"run = Run.get_context()\n",
-"```\n",
-"Later within the script, we log the loss metric to our run:\n",
-"```Python\n",
-"run.log('loss', losses.avg)\n",
-"```"
+"Now you will need to create your training script. In this tutorial, the script for distributed training on CIFAR-10 is already provided for you at `train.py`. In practice, you should be able to take any custom PyTorch training script as is and run it with Azure ML without having to modify your code."
 ]
},
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Once your script is ready, copy the training script `pytorch_mnist.py` into the project directory."
+"Once your script is ready, copy the training script `train.py` into the project directory."
 ]
},
{
@@ -202,7 +251,7 @@
 "source": [
 "import shutil\n",
 "\n",
-"shutil.copy('pytorch_mnist.py', project_folder)"
+"shutil.copy('train.py', project_folder)"
 ]
},
{
@@ -231,26 +280,7 @@
 "source": [
 "### Create an environment\n",
 "\n",
-"Define a conda environment YAML file with your training script dependencies and create an Azure ML environment."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"%%writefile conda_dependencies.yml\n",
-"\n",
-"channels:\n",
-"- conda-forge\n",
-"dependencies:\n",
-"- python=3.6.2\n",
-"- pip:\n",
-" - azureml-defaults\n",
-" - torch==1.6.0\n",
-" - torchvision==0.7.0\n",
-" - future==0.17.1"
+"In this tutorial, we will use one of Azure ML's curated PyTorch environments for training. [Curated environments](https://docs.microsoft.com/azure/machine-learning/how-to-use-environments#use-a-curated-environment) are available in your workspace by default. Specifically, we will use the PyTorch 1.6 GPU curated environment."
 ]
},
{
@@ -261,24 +291,39 @@
 "source": [
 "from azureml.core import Environment\n",
 "\n",
-"pytorch_env = Environment.from_conda_specification(name = 'pytorch-1.6-gpu', file_path = './conda_dependencies.yml')\n",
-"\n",
-"# Specify a GPU base image\n",
-"pytorch_env.docker.enabled = True\n",
-"pytorch_env.docker.base_image = 'mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.1-cudnn7-ubuntu18.04'"
+"pytorch_env = Environment.get(ws, name='AzureML-PyTorch-1.6-GPU')"
 ]
},
{
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Configure the training job: torch.distributed with NCCL backend\n",
+"### Configure the training job\n",
 "\n",
-"Create a ScriptRunConfig object to specify the configuration details of your training job, including your training script, environment to use, and the compute target to run on.\n",
-"\n",
-"In order to run a distributed PyTorch job with **torch.distributed** using the NCCL backend, create a `PyTorchConfiguration` and pass it to the `distributed_job_config` parameter of the ScriptRunConfig constructor. Specify `communication_backend='Nccl'` in the PyTorchConfiguration. The below code will configure a 2-node distributed job. The NCCL backend is the recommended backend for PyTorch distributed GPU training.\n",
-"\n",
-"The script arguments refer to the Azure ML-set environment variables `AZ_BATCHAI_PYTORCH_INIT_METHOD` for shared file-system initialization and `AZ_BATCHAI_TASK_INDEX` for the global rank of each worker process."
+"To launch a distributed PyTorch job on Azure ML, you have two options:\n",
+"\n",
+"1. Per-process launch - specify the total # of worker processes (typically one per GPU) you want to run, and Azure ML will handle launching each process.\n",
+"2. Per-node launch with [torch.distributed.launch](https://pytorch.org/docs/stable/distributed.html#launch-utility) - provide the `torch.distributed.launch` command you want to run on each node.\n",
+"\n",
+"For more information, see the [documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-pytorch#distributeddataparallel).\n",
+"\n",
+"Both options are shown below."
 ]
},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"#### Per-process launch\n",
+"\n",
+"To use the per-process launch option in which Azure ML will handle launching each of the processes to run your training script,\n",
+"\n",
+"1. Specify the training script and arguments\n",
+"2. Create a `PyTorchConfiguration` and specify `node_count` and `process_count`. The `process_count` is the total number of processes you want to run for the job; this should typically equal the # of GPUs available on each node multiplied by the # of nodes. Since this tutorial uses the `STANDARD_NC6` SKU, which has one GPU, the total process count for a 2-node job is `2`. If you are using a SKU with >1 GPUs, adjust the `process_count` accordingly.\n",
+"\n",
+"Azure ML will set the `MASTER_ADDR`, `MASTER_PORT`, `NODE_RANK`, `WORLD_SIZE` environment variables on each node, in addition to the process-level `RANK` and `LOCAL_RANK` environment variables, that are needed for distributed PyTorch training."
+]
+},
{
@@ -290,17 +335,61 @@
 "from azureml.core import ScriptRunConfig\n",
 "from azureml.core.runconfig import PyTorchConfiguration\n",
 "\n",
-"args = ['--dist-backend', 'nccl',\n",
-" '--dist-url', '$AZ_BATCHAI_PYTORCH_INIT_METHOD',\n",
-" '--rank', '$AZ_BATCHAI_TASK_INDEX',\n",
-" '--world-size', 2]\n",
+"# create distributed config\n",
+"distr_config = PyTorchConfiguration(process_count=2, node_count=2)\n",
+"\n",
+"# create args\n",
+"args = [\"--data-dir\", dataset.as_download(), \"--epochs\", 25]\n",
 "\n",
+"# create job config\n",
 "src = ScriptRunConfig(source_directory=project_folder,\n",
-" script='pytorch_mnist.py',\n",
+" script='train.py',\n",
 " arguments=args,\n",
 " compute_target=compute_target,\n",
 " environment=pytorch_env,\n",
-" distributed_job_config=PyTorchConfiguration(communication_backend='Nccl', node_count=2))"
+" distributed_job_config=distr_config)"
 ]
},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"#### Per-node launch with `torch.distributed.launch`\n",
+"\n",
+"If you would instead like to use the PyTorch-provided launch utility `torch.distributed.launch` to handle launching the worker processes on each node, you can do so as well. \n",
+"\n",
+"1. Provide the launch command to the `command` parameter of ScriptRunConfig. For PyTorch jobs Azure ML will set the `MASTER_ADDR`, `MASTER_PORT`, and `NODE_RANK` environment variables on each node, so you can simply just reference those environment variables in your command. If you are using a SKU with >1 GPUs, adjust the `--nproc_per_node` argument accordingly.\n",
+"\n",
+"2. Create a `PyTorchConfiguration` and specify the `node_count`. You do not need to specify the `process_count`; by default Azure ML will launch one process per node to run the `command` you provided.\n",
+"\n",
+"Uncomment the code below to configure a job with this method."
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"'''\n",
+"from azureml.core import ScriptRunConfig\n",
+"from azureml.core.runconfig import PyTorchConfiguration\n",
+"\n",
+"# create distributed config\n",
+"distr_config = PyTorchConfiguration(node_count=2)\n",
+"\n",
+"# define command\n",
+"launch_cmd = [\"python -m torch.distributed.launch --nproc_per_node 1 --nnodes 2 \" \\\n",
+" \"--node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT --use_env \" \\\n",
+" \"train.py --data-dir\", dataset.as_download(), \"--epochs 25\"]\n",
+"\n",
+"# create job config\n",
+"src = ScriptRunConfig(source_directory=project_folder,\n",
+" command=launch_cmd,\n",
+" compute_target=compute_target,\n",
+" environment=pytorch_env,\n",
+" distributed_job_config=distr_config)\n",
+"'''"
+]
+},
{
@@ -308,7 +397,7 @@
 "metadata": {},
 "source": [
 "### Submit job\n",
-"Run your experiment by submitting your ScriptRunConfig object. Note that this call is asynchronous."
+"Run your experiment by submitting your `ScriptRunConfig` object. Note that this call is asynchronous."
 ]
},
{
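Aside (not part of the diff): submitting the configured job typically looks like the following minimal sketch, using the standard `Experiment.submit` API; the experiment name is a placeholder.

```python
# Hypothetical usage sketch: submit the ScriptRunConfig defined above.
from azureml.core import Experiment

experiment = Experiment(workspace=ws, name="pytorch-distr-cifar")  # placeholder name
run = experiment.submit(src)   # asynchronous; returns immediately
print(run.get_portal_url())    # monitor the run in Azure ML studio
```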
@@ -355,50 +444,12 @@
 "source": [
 "run.wait_for_completion(show_output=True) # this provides a verbose log"
 ]
},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"### Configure training job: torch.distributed with Gloo backend\n",
-"\n",
-"If you would instead like to use the Gloo backend for distributed training, you can do so via the following code. The Gloo backend is recommended for distributed CPU training."
-]
-},
-{
-"cell_type": "code",
-"execution_count": null,
-"metadata": {},
-"outputs": [],
-"source": [
-"from azureml.core import ScriptRunConfig\n",
-"from azureml.core.runconfig import PyTorchConfiguration\n",
-"\n",
-"args = ['--dist-backend', 'gloo',\n",
-" '--dist-url', '$AZ_BATCHAI_PYTORCH_INIT_METHOD',\n",
-" '--rank', '$AZ_BATCHAI_TASK_INDEX',\n",
-" '--world-size', 2]\n",
-"\n",
-"src = ScriptRunConfig(source_directory=project_folder,\n",
-" script='pytorch_mnist.py',\n",
-" arguments=args,\n",
-" compute_target=compute_target,\n",
-" environment=pytorch_env,\n",
-" distributed_job_config=PyTorchConfiguration(communication_backend='Gloo', node_count=2))"
-]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": [
-"Once you create the ScriptRunConfig, you can follow the submit steps as shown in the previous steps to submit a PyTorch distributed run using the Gloo backend."
-]
-}
 ],
 "metadata": {
 "authors": [
 {
-"name": "ninhu"
+"name": "minxia"
 }
 ],
 "category": "training",
@@ -406,7 +457,7 @@
 "AML Compute"
 ],
 "datasets": [
-"MNIST"
+"CIFAR-10"
 ],
 "deployment": [
 "None"
@@ -432,12 +483,12 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.6.9"
+"version": "3.7.7"
 },
 "tags": [
 "None"
 ],
-"task": "Train a model using distributed training via Nccl/Gloo"
+"task": "Train a model using distributed training via PyTorch DistributedDataParallel"
 },
 "nbformat": 4,
 "nbformat_minor": 2
@@ -0,0 +1,5 @@
name: distributed-pytorch-with-distributeddataparallel
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,238 @@
# Copyright (c) 2017 Facebook, Inc. All rights reserved.
# BSD 3-Clause License
#
# Script adapted from:
# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# ==============================================================================

# imports
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import os
import argparse


# define network architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, 3)
        self.conv3 = nn.Conv2d(64, 128, 3)
        self.fc1 = nn.Linear(128 * 6 * 6, 120)
        self.dropout = nn.Dropout(p=0.2)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = x.view(-1, 128 * 6 * 6)
        x = self.dropout(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def train(train_loader, model, criterion, optimizer, epoch, device, print_freq, rank):
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data[0].to(device), data[1].to(device)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        if i % print_freq == 0:  # print every print_freq mini-batches
            print(
                "Rank %d: [%d, %5d] loss: %.3f"
                % (rank, epoch + 1, i + 1, running_loss / print_freq)
            )
            running_loss = 0.0


def evaluate(test_loader, model, device):
    classes = (
        "plane",
        "car",
        "bird",
        "cat",
        "deer",
        "dog",
        "frog",
        "horse",
        "ship",
        "truck",
    )

    model.eval()

    correct = 0
    total = 0
    class_correct = list(0.0 for i in range(10))
    class_total = list(0.0 for i in range(10))
    with torch.no_grad():
        for data in test_loader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            c = (predicted == labels).squeeze()
            for i in range(10):
                label = labels[i]
                class_correct[label] += c[i].item()
                class_total[label] += 1

    # print total test set accuracy
    print(
        "Accuracy of the network on the 10000 test images: %d %%"
        % (100 * correct / total)
    )

    # print test accuracy for each of the classes
    for i in range(10):
        print(
            "Accuracy of %5s : %2d %%"
            % (classes[i], 100 * class_correct[i] / class_total[i])
        )


def main(args):
    # get PyTorch environment variables
    world_size = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])

    distributed = world_size > 1

    # set device
    if distributed:
        device = torch.device("cuda", local_rank)
    else:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # initialize distributed process group using default env:// method
    if distributed:
        torch.distributed.init_process_group(backend="nccl")

    # define train and test dataset DataLoaders
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    train_set = torchvision.datasets.CIFAR10(
        root=args.data_dir, train=True, download=False, transform=transform
    )

    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_set)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=args.batch_size,
        shuffle=(train_sampler is None),
        num_workers=args.workers,
        sampler=train_sampler,
    )

    test_set = torchvision.datasets.CIFAR10(
        root=args.data_dir, train=False, download=False, transform=transform
    )
    test_loader = torch.utils.data.DataLoader(
        test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers
    )

    model = Net().to(device)

    # wrap model with DDP
    if distributed:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], output_device=local_rank
        )

    # define loss function and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(), lr=args.learning_rate, momentum=args.momentum
    )

    # train the model
    for epoch in range(args.epochs):
        print("Rank %d: Starting epoch %d" % (rank, epoch))
        if distributed:
            train_sampler.set_epoch(epoch)
        model.train()
        train(
            train_loader,
            model,
            criterion,
            optimizer,
            epoch,
            device,
            args.print_freq,
            rank,
        )

    print("Rank %d: Finished Training" % (rank))

    if not distributed or rank == 0:
        os.makedirs(args.output_dir, exist_ok=True)
        model_path = os.path.join(args.output_dir, "cifar_net.pt")
        torch.save(model.state_dict(), model_path)

        # evaluate on full test dataset
        evaluate(test_loader, model, device)


if __name__ == "__main__":
    # setup argparse
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data-dir", type=str, help="directory containing CIFAR-10 dataset"
    )
    parser.add_argument("--epochs", default=10, type=int, help="number of epochs")
    parser.add_argument(
        "--batch-size",
        default=16,
        type=int,
        help="mini batch size for each gpu/process",
    )
    parser.add_argument(
        "--workers",
        default=2,
        type=int,
        help="number of data loading workers for each gpu/process",
    )
    parser.add_argument(
        "--learning-rate", default=0.001, type=float, help="learning rate"
    )
    parser.add_argument("--momentum", default=0.9, type=float, help="momentum")
    parser.add_argument(
        "--output-dir", default="outputs", type=str, help="directory to save model to"
    )
    parser.add_argument(
        "--print-freq",
        default=200,
        type=int,
        help="frequency of printing training statistics",
    )
    args = parser.parse_args()

    main(args)
@@ -0,0 +1,5 @@
name: distributed-pytorch-with-horovod
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -1,209 +0,0 @@
# Copyright (c) 2017, PyTorch contributors
# Modifications copyright (C) Microsoft Corporation
# Licensed under the BSD license
# Adapted from https://github.com/Azure/BatchAI/tree/master/recipes/PyTorch/PyTorch-GPU-Distributed-Gloo

from __future__ import print_function
import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.utils.data
import torch.utils.data.distributed
import torchvision.models as models

from azureml.core.run import Run
# get the Azure ML run object
run = Run.get_context()

# Training settings
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
parser.add_argument('--batch-size', type=int, default=64, metavar='N',
                    help='input batch size for training (default: 64)')
parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N',
                    help='input batch size for testing (default: 1000)')
parser.add_argument('--epochs', type=int, default=10, metavar='N',
                    help='number of epochs to train (default: 10)')
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                    help='learning rate (default: 0.01)')
parser.add_argument('--momentum', type=float, default=0.5, metavar='M',
                    help='SGD momentum (default: 0.5)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
                    help='random seed (default: 1)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--log-interval', type=int, default=10, metavar='N',
                    help='how many batches to wait before logging training status')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--world-size', default=1, type=int,
                    help='number of distributed processes')
parser.add_argument('--dist-url', type=str,
                    help='url used to set up distributed training')
parser.add_argument('--dist-backend', default='nccl', type=str,
                    help='distributed backend')
parser.add_argument('--rank', default=-1, type=int,
                    help='rank of the worker')

best_prec1 = 0
args = parser.parse_args()

args.distributed = args.world_size >= 2

if args.distributed:
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)

train_dataset = datasets.MNIST('data-%d' % args.rank, train=True, download=True,
                               transform=transforms.Compose([
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.1307,), (0.3081,))
                               ]))

if args.distributed:
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
    train_sampler = None

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=args.batch_size, shuffle=(train_sampler is None),
    num_workers=args.workers, pin_memory=True, sampler=train_sampler)

# evaluate on the held-out MNIST test split, not the training data
test_dataset = datasets.MNIST('data-%d' % args.rank, train=False, download=True,
                              transform=transforms.Compose([
                                  transforms.ToTensor(),
                                  transforms.Normalize((0.1307,), (0.3081,))
                              ]))

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=args.test_batch_size, shuffle=False,
    num_workers=args.workers, pin_memory=True)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, 320)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


model = Net()

if not args.distributed:
    model = torch.nn.DataParallel(model).cuda()
else:
    model.cuda()
    model = torch.nn.parallel.DistributedDataParallel(model)

# define loss function (criterion) and optimizer;
# the network returns log-probabilities, so NLLLoss is the matching criterion
# (CrossEntropyLoss would apply log_softmax a second time)
criterion = nn.NLLLoss().cuda()

optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum,
                            weight_decay=args.weight_decay)


def train(epoch):
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()
    for i, (input, target) in enumerate(train_loader):
        # measure data loading time
        data_time.update(time.time() - end)

        input, target = input.cuda(), target.cuda()

        # compute output
        try:
            output = model(input)
            loss = criterion(output, target)

            # measure accuracy and record loss
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            losses.update(loss.item(), input.size(0))
            top1.update(prec1[0], input.size(0))
            top5.update(prec5[0], input.size(0))

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 10 == 0:
                run.log("loss", losses.avg)
                run.log("prec@1", "{0:.3f}".format(top1.avg))
                run.log("prec@5", "{0:.3f}".format(top5.avg))
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader),
                                                                      batch_time=batch_time, data_time=data_time,
                                                                      loss=losses, top1=top1, top5=top5))
        except Exception:
            import sys
            print("Unexpected error:", sys.exc_info()[0])


class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Computes the precision@k for the specified values of k"""
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


for epoch in range(1, args.epochs + 1):
    train(epoch)
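One detail worth checking in scripts like the one above: nn.CrossEntropyLoss expects raw logits, while a network that ends in log_softmax should be paired with nn.NLLLoss. A small self-contained check of that equivalence (illustrative only, not part of the sample):

# CrossEntropyLoss(logits) == NLLLoss(log_softmax(logits)) by definition.
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 10)
target = torch.tensor([1, 0, 3, 9])
a = nn.CrossEntropyLoss()(logits, target)
b = nn.NLLLoss()(F.log_softmax(logits, dim=1), target)
assert torch.allclose(a, b)  # the two pairings compute the same loss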
@@ -0,0 +1,10 @@
name: train-hyperparameter-tune-deploy-with-pytorch
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - pillow==5.4.1
  - matplotlib
  - numpy==1.19.3
  - https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp36-cp36m-win_amd64.whl
  - https://download.pytorch.org/whl/cpu/torchvision-0.7.0%2Bcpu-cp36-cp36m-win_amd64.whl
@@ -0,0 +1,6 @@
name: train-hyperparameter-tune-deploy-with-sklearn
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - numpy
@@ -0,0 +1,11 @@
name: distributed-tensorflow-with-horovod
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
  - keras
  - tensorflow-gpu==1.13.2
  - horovod==0.19.1
  - matplotlib
  - pandas
  - fuse
@@ -0,0 +1,5 @@
name: distributed-tensorflow-with-parameter-server
dependencies:
- pip:
  - azureml-sdk
  - azureml-widgets
@@ -0,0 +1,12 @@
name: train-hyperparameter-tune-deploy-with-tensorflow
dependencies:
- numpy
- matplotlib
- pip:
  - azureml-sdk
  - azureml-widgets
  - pandas
  - keras
  - tensorflow==2.0.0
  - matplotlib
  - fuse
@@ -0,0 +1,8 @@
name: pong_rllib
dependencies:
- pip:
  - azureml-sdk
  - azureml-contrib-reinforcementlearning
  - azureml-widgets
  - matplotlib
  - azure-mgmt-network==12.0.0
@@ -0,0 +1,6 @@
name: cartpole_ci
dependencies:
- pip:
  - azureml-sdk
  - azureml-contrib-reinforcementlearning
  - azureml-widgets
@@ -0,0 +1,6 @@
name: cartpole_sc
dependencies:
- pip:
  - azureml-sdk
  - azureml-contrib-reinforcementlearning
  - azureml-widgets
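These conda specifications are consumed by the corresponding sample notebooks. A hedged sketch of one way such a file becomes an Azure ML environment (the local file name below is an assumption):

# Illustrative only: build an Azure ML Environment from a conda spec file.
# "cartpole_sc.yml" is an assumed local file name for the spec above.
from azureml.core import Environment

env = Environment.from_conda_specification(
    name="cartpole_sc",
    file_path="cartpole_sc.yml",
)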
@@ -1,70 +0,0 @@
FROM mcr.microsoft.com/azureml/base:openmpi3.1.2-ubuntu18.04

# Install some basic utilities
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    cpio \
    git \
    bzip2 \
    libx11-6 \
    tmux \
    htop \
    gcc \
    xvfb \
    python-opengl \
    x11-xserver-utils \
    ffmpeg \
    mesa-utils \
    nano \
    vim \
    rsync \
 && rm -rf /var/lib/apt/lists/*

# Create a working directory
RUN mkdir /app
WORKDIR /app

# Install the libraries Minecraft needs
RUN mkdir -p /usr/share/man/man1 && \
    sudo apt-get update && \
    sudo apt-get install -y \
    openjdk-8-jre-headless=8u162-b12-1 \
    openjdk-8-jdk-headless=8u162-b12-1 \
    openjdk-8-jre=8u162-b12-1 \
    openjdk-8-jdk=8u162-b12-1

# Create a Python 3.7 environment
RUN conda install conda-build \
    && conda create -y --name py37 python=3.7.3 \
    && conda clean -ya
ENV CONDA_DEFAULT_ENV=py37

# Install minerl
RUN pip install --upgrade --user minerl

RUN pip install \
    pandas \
    matplotlib \
    numpy \
    scipy \
    azureml-defaults \
    tensorboardX \
    tensorflow==1.15rc2 \
    tabulate \
    dm_tree \
    lz4 \
    ray==0.8.3 \
    ray[rllib]==0.8.3 \
    ray[tune]==0.8.3

COPY patch_files/* /root/.local/lib/python3.7/site-packages/minerl/env/Malmo/Minecraft/src/main/java/com/microsoft/Malmo/Client/

# Start minerl to pre-fetch minerl files (saves time when starting minerl during training)
RUN xvfb-run -a -s "-screen 0 1400x900x24" python -c "import gym; import minerl; env = gym.make('MineRLTreechop-v0'); env.close();"

RUN pip install --index-url https://test.pypi.org/simple/ malmo && \
    python -c "import malmo.minecraftbootstrap; malmo.minecraftbootstrap.download();"

ENV MALMO_XSD_PATH="/app/MalmoPlatform/Schemas"
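A sketch of how a custom Dockerfile like the one above can back an Azure ML environment; the attribute names below are from azureml-core as I understand it, and the local path is an assumption:

# Illustrative sketch: point an Azure ML Environment at a local Dockerfile
# instead of a prebuilt base image ("./Dockerfile" path is an assumption).
from azureml.core import Environment

env = Environment(name="minerl-cpu")
env.docker.base_image = None                  # do not use a prebuilt image
env.docker.base_dockerfile = "./Dockerfile"   # build from the file above
env.python.user_managed_dependencies = True   # deps are baked into the image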
@@ -1,939 +0,0 @@
// --------------------------------------------------------------------------------------------------
// Copyright (c) 2016 Microsoft Corporation
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
// associated documentation files (the "Software"), to deal in the Software without restriction,
// including without limitation the rights to use, copy, modify, merge, publish, distribute,
// sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all copies or
// substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
// NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
// --------------------------------------------------------------------------------------------------

package com.microsoft.Malmo.Client;

import com.microsoft.Malmo.MalmoMod;
import com.microsoft.Malmo.MissionHandlerInterfaces.IWantToQuit;
import com.microsoft.Malmo.Schemas.MissionInit;
import com.microsoft.Malmo.Utils.TCPUtils;

import net.minecraft.profiler.Profiler;
import com.microsoft.Malmo.Utils.TimeHelper;

import net.minecraftforge.common.config.Configuration;
import java.io.*;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import java.util.Hashtable;
import com.microsoft.Malmo.Utils.TCPInputPoller;
import java.util.logging.Level;

import java.util.LinkedList;
import java.util.List;


/**
 * MalmoEnvServer - service supporting OpenAI gym "environment" for multi-agent Malmo missions.
 */
public class MalmoEnvServer implements IWantToQuit {
    private static Profiler profiler = new Profiler();
    private static int nsteps = 0;
    private static boolean debug = false;

    private static String hello = "<MalmoEnv";

    private class EnvState {

        // Mission parameters:
        String missionInit = null;
        String token = null;
        String experimentId = null;
        int agentCount = 0;
        int reset = 0;
        boolean quit = false;
        boolean synchronous = false;
        Long seed = null;

        // OpenAI gym state:
        boolean done = false;
        double reward = 0.0;
        byte[] obs = null;
        String info = "";
        LinkedList<String> commands = new LinkedList<String>();
    }

    private static boolean envPolicy = false; // Are we configured by config policy?

    // Synchronize on EnvState.
    private Lock lock = new ReentrantLock();
    private Condition cond = lock.newCondition();

    private EnvState envState = new EnvState();

    private Hashtable<String, Integer> initTokens = new Hashtable<String, Integer>();

    static final long COND_WAIT_SECONDS = 3; // Max wait in seconds before timing out (and replying to RPC).
    static final int BYTES_INT = 4;
    static final int BYTES_DOUBLE = 8;
    private static final Charset utf8 = Charset.forName("UTF-8");

    // Service uses a single per-environment client connection - initiated by the remote environment.

    private int port;
    private TCPInputPoller missionPoller; // Used for command parsing and not actual communication.
    private String version;

    // AOG: From running experiments, I've found that MineRL can get stuck resetting the
    // environment, which causes huge delays while we wait for the Python side to time
    // out and restart the Minecraft instance. Minecraft itself is normally in a recoverable
    // state, but the MalmoEnvServer instance will be blocked in a tight spin loop trying
    // to handle a Peek request from the Python client. To unstick things, I've added this
    // flag that can be set when we know things are in a bad state to abort the peek request.
    // WARNING: THIS IS ONLY TREATING THE SYMPTOM AND NOT THE ROOT CAUSE
    // The reason things are getting stuck is because the player is either dying or we're
    // receiving a quit request while an episode reset is in progress.
    private boolean abortRequest;
    public void abort() {
        System.out.println("AOG: MalmoEnvServer.abort");
        abortRequest = true;
    }

    /***
     * Malmo "Env" service.
     * @param port the port the service listens on.
     * @param missionPoller for plugging into existing comms handling.
     */
    public MalmoEnvServer(String version, int port, TCPInputPoller missionPoller) {
        this.version = version;
        this.missionPoller = missionPoller;
        this.port = port;
        // AOG - Assume we don't want to be aborting in the first place
        this.abortRequest = false;
    }
    /** Initialize the malmo env configuration. For now this is either on, or the "legacy" AgentHost protocol. */
    public static void update(Configuration configs) {
        envPolicy = configs.get(MalmoMod.ENV_CONFIGS, "env", "false").getBoolean();
    }

    public static boolean isEnv() {
        return envPolicy;
    }

    /**
     * Start servicing the MalmoEnv protocol.
     * @throws IOException
     */
    public void serve() throws IOException {

        ServerSocket serverSocket = new ServerSocket(port);
        serverSocket.setPerformancePreferences(0, 2, 1);

        while (true) {
            try {
                final Socket socket = serverSocket.accept();
                socket.setTcpNoDelay(true);

                Thread thread = new Thread("EnvServerSocketHandler") {
                    public void run() {
                        boolean running = false;
                        try {
                            checkHello(socket);

                            while (true) {
                                DataInputStream din = new DataInputStream(socket.getInputStream());
                                int hdr = din.readInt();
                                byte[] data = new byte[hdr];
                                din.readFully(data);

                                String command = new String(data, utf8);

                                if (command.startsWith("<Step")) {

                                    profiler.startSection("root");
                                    long start = System.nanoTime();
                                    step(command, socket, din);
                                    profiler.endSection();
                                    if (nsteps % 100 == 0 && debug) {
                                        List<Profiler.Result> dat = profiler.getProfilingData("root");
                                        for (int qq = 0; qq < dat.size(); qq++) {
                                            Profiler.Result res = dat.get(qq);
                                            System.out.println(res.profilerName + " " + res.totalUsePercentage + " " + res.usePercentage);
                                        }
                                    }

                                } else if (command.startsWith("<Peek")) {

                                    peek(command, socket, din);

                                } else if (command.startsWith("<Init")) {

                                    init(command, socket);

                                } else if (command.startsWith("<Find")) {

                                    find(command, socket);

                                } else if (command.startsWith("<MissionInit")) {

                                    if (missionInit(din, command, socket)) {
                                        running = true;
                                    }

                                } else if (command.startsWith("<Quit")) {

                                    quit(command, socket);

                                    profiler.profilingEnabled = false;

                                } else if (command.startsWith("<Exit")) {

                                    exit(command, socket);

                                    profiler.profilingEnabled = false;

                                } else if (command.startsWith("<Close")) {

                                    close(command, socket);
                                    profiler.profilingEnabled = false;

                                } else if (command.startsWith("<Status")) {

                                    status(command, socket);

                                } else if (command.startsWith("<Echo")) {
                                    command = "<Echo>" + command + "</Echo>";
                                    data = command.getBytes(utf8);
                                    hdr = data.length;

                                    DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
                                    dout.writeInt(hdr);
                                    dout.write(data, 0, hdr);
                                    dout.flush();
                                } else {
                                    throw new IOException("Unknown env service command");
                                }
                            }
                        } catch (IOException ioe) {
                            // ioe.printStackTrace();
                            TCPUtils.Log(Level.SEVERE, "MalmoEnv socket error: " + ioe + " (can be on disconnect)");
                            // System.out.println("[ERROR] " + "MalmoEnv socket error: " + ioe + " (can be on disconnect)");
                            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] MalmoEnv socket error");
                            try {
                                if (running) {
                                    TCPUtils.Log(Level.INFO, "Want to quit on disconnect.");

                                    System.out.println("[LOGTOPY] " + "Want to quit on disconnect.");
                                    setWantToQuit();
                                }
                                socket.close();
                            } catch (IOException ioe2) {
                            }
                        }
                    }
                };
                thread.start();
            } catch (IOException ioe) {
                TCPUtils.Log(Level.SEVERE, "MalmoEnv service exits on " + ioe);
            }
        }
    }
    private void checkHello(Socket socket) throws IOException {
        DataInputStream din = new DataInputStream(socket.getInputStream());
        int hdr = din.readInt();
        if (hdr <= 0 || hdr > hello.length() + 8) // Version number may be somewhat longer in future.
            throw new IOException("Invalid MalmoEnv hello header length");
        byte[] data = new byte[hdr];
        din.readFully(data);
        if (!new String(data).startsWith(hello + version))
            throw new IOException("MalmoEnv invalid protocol or version - expected " + hello + version);
    }

    // Handler for <MissionInit> messages.
    private boolean missionInit(DataInputStream din, String command, Socket socket) throws IOException {

        String ipOriginator = socket.getInetAddress().getHostName();

        int hdr;
        byte[] data;
        hdr = din.readInt();
        data = new byte[hdr];
        din.readFully(data);
        String id = new String(data, utf8);

        TCPUtils.Log(Level.INFO, "Mission init " + id);

        String[] token = id.split(":");
        String experimentId = token[0];
        int role = Integer.parseInt(token[1]);
        int reset = Integer.parseInt(token[2]);
        int agentCount = Integer.parseInt(token[3]);
        Boolean isSynchronous = Boolean.parseBoolean(token[4]);
        Long seed = null;
        if (token.length > 5)
            seed = Long.parseLong(token[5]);

        if (isSynchronous && agentCount > 1) {
            throw new IOException("Synchronous mode currently does not support multiple agents.");
        }
        port = -1;
        boolean allTokensConsumed = true;
        boolean started = false;

        lock.lock();
        try {
            if (role == 0) {

                String previousToken = experimentId + ":0:" + (reset - 1);
                initTokens.remove(previousToken);

                String myToken = experimentId + ":0:" + reset;
                if (!initTokens.containsKey(myToken)) {
                    TCPUtils.Log(Level.INFO, "(Pre)Start " + role + " reset " + reset);
                    started = startUp(command, ipOriginator, experimentId, reset, agentCount, myToken, seed, isSynchronous);
                    if (started)
                        initTokens.put(myToken, 0);
                } else {
                    started = true; // Pre-started previously.
                }

                // Check that all previous tokens have been consumed. If not, don't proceed to the mission.
                allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount);
                if (!allTokensConsumed) {
                    try {
                        cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS);
                    } catch (InterruptedException ie) {
                    }
                    allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount);
                }
            } else {
                TCPUtils.Log(Level.INFO, "Start " + role + " reset " + reset);

                started = startUp(command, ipOriginator, experimentId, reset, agentCount, experimentId + ":" + role + ":" + reset, seed, isSynchronous);
            }
        } finally {
            lock.unlock();
        }

        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(allTokensConsumed && started ? 1 : 0);
        dout.flush();

        return allTokensConsumed && started;
    }

    private boolean areAllTokensConsumed(String experimentId, int reset, int agentCount) {
        boolean allTokensConsumed = true;
        for (int i = 1; i < agentCount; i++) {
            String tokenForAgent = experimentId + ":" + i + ":" + (reset - 1);
            if (initTokens.containsKey(tokenForAgent)) {
                TCPUtils.Log(Level.FINE, "Mission init - unconsumed " + tokenForAgent);
                allTokensConsumed = false;
            }
        }
        return allTokensConsumed;
    }

    private boolean startUp(String command, String ipOriginator, String experimentId, int reset, int agentCount, String myToken, Long seed, Boolean isSynchronous) throws IOException {

        // Clear out mission state
        envState.reward = 0.0;
        envState.commands.clear();
        envState.obs = null;
        envState.info = "";

        envState.missionInit = command;
        envState.done = false;
        envState.quit = false;
        envState.token = myToken;
        envState.experimentId = experimentId;
        envState.agentCount = agentCount;
        envState.reset = reset;
        envState.synchronous = isSynchronous;
        envState.seed = seed;

        return startUpMission(command, ipOriginator);
    }

    private boolean startUpMission(String command, String ipOriginator) throws IOException {

        if (missionPoller == null)
            return false;

        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        DataOutputStream dos = new DataOutputStream(baos);

        missionPoller.commandReceived(command, ipOriginator, dos);

        dos.flush();
        byte[] reply = baos.toByteArray();
        ByteArrayInputStream bais = new ByteArrayInputStream(reply);
        DataInputStream dis = new DataInputStream(bais);
        int hdr = dis.readInt();
        byte[] replyBytes = new byte[hdr];
        dis.readFully(replyBytes);

        String replyStr = new String(replyBytes);
        if (replyStr.equals("MALMOOK")) {
            TCPUtils.Log(Level.INFO, "MalmoEnvServer Mission starting ...");
            return true;
        } else if (replyStr.equals("MALMOBUSY")) {
            TCPUtils.Log(Level.INFO, "MalmoEnvServer Busy - I want to quit");
            this.envState.quit = true;
        }
        return false;
    }
    private static final int stepTagLength = "<Step_>".length(); // Step with option code.

    private synchronized void stepSync(String command, Socket socket, DataInputStream din) throws IOException {
        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Entering synchronous step.");
        nsteps += 1;
        profiler.startSection("commandProcessing");
        String actions = command.substring(stepTagLength, command.length() - (stepTagLength + 2));
        int options = Character.getNumericValue(command.charAt(stepTagLength - 2));
        boolean withInfo = options == 0 || options == 2;

        // Prepare to write data to the client.
        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        double reward = 0.0;
        boolean done;
        byte[] obs;
        String info = "";
        boolean sent = false;

        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Acquiring lock for synchronous step.");
        lock.lock();
        try {
            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Lock is acquired.");
            done = envState.done;

            // TODO Handle when the environment is done.

            // Process the actions.
            if (actions.contains("\n")) {
                String[] cmds = actions.split("\\n");
                for (String cmd : cmds) {
                    envState.commands.add(cmd);
                }
            } else {
                if (!actions.isEmpty())
                    envState.commands.add(actions);
            }
            sent = true;

            profiler.endSection(); // cmd
            profiler.startSection("requestTick");

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Received: " + actions);
            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Requesting tick.");
            // Now wait to run a tick.
            // If synchronous mode is off then we should see if want to quit is true.
            while (!TimeHelper.SyncManager.requestTick() && !done) { Thread.yield(); }

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Tick request granted.");
            profiler.endSection();
            profiler.startSection("waitForTick");

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Waiting for tick.");
            // Then wait until the tick is finished.
            while (!TimeHelper.SyncManager.isTickCompleted() && !done) { Thread.yield(); }

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> TICK DONE. Getting observation.");
            profiler.endSection();
            profiler.startSection("getObservation");
            // After which, get the observations.
            obs = getObservation(done);

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Observation received. Getting info.");
            profiler.endSection();
            profiler.startSection("getInfo");

            // Pick up rewards.
            reward = envState.reward;
            if (withInfo) {
                info = envState.info;
                // if (info == null)
                //     TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> FILLING INFO: NULL");
                // else
                //     TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> FILLING " + info.toString());
            }
            done = envState.done;
            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> STATUS " + Boolean.toString(done));
            envState.info = null;
            envState.obs = null;
            envState.reward = 0.0;

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Info received.");
            profiler.endSection();
        } finally {
            lock.unlock();
        }

        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Lock released. Writing observation, info, done.");
        profiler.startSection("writeObs");
        dout.writeInt(obs.length);
        dout.write(obs);

        dout.writeInt(BYTES_DOUBLE + 2);
        dout.writeDouble(reward);
        dout.writeByte(done ? 1 : 0);
        dout.writeByte(sent ? 1 : 0);

        if (withInfo) {
            byte[] infoBytes = info.getBytes(utf8);
            dout.writeInt(infoBytes.length);
            dout.write(infoBytes);
        }

        profiler.endSection(); // write obs
        profiler.startSection("flush");

        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Packets written. Flushing.");
        dout.flush();
        profiler.endSection(); // flush

        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Done with step.");
    }

    // Handler for <Step_> messages. A single-digit option code after the _ specifies whether the turn key and info are included in the message.
    private void step(String command, Socket socket, DataInputStream din) throws IOException {
        if (envState.synchronous) {
            stepSync(command, socket, din);
        } else {
            System.out.println("[ERROR] Asynchronous stepping is not supported in MineRL.");
        }
    }
    // Handler for <Peek> messages.
    private void peek(String command, Socket socket, DataInputStream din) throws IOException {

        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        byte[] obs;
        boolean done;
        String info = "";
        // AOG - As we've only seen issues with the peek request, I've focused my changes on just
        // this function. Initially we want to be optimistic and assume we're not going to abort
        // the request, and my observations of event timings indicate that there is plenty of time
        // between the peek request being received and the reset failing, so a race condition is
        // unlikely.
        abortRequest = false;

        lock.lock();

        try {
            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <PEEK> Waiting for pistol to fire.");
            while (!TimeHelper.SyncManager.hasServerFiredPistol() && !abortRequest) {

                // Now wait to run a tick.
                while (!TimeHelper.SyncManager.requestTick() && !abortRequest) { Thread.yield(); }

                // Then wait until the tick is finished.
                while (!TimeHelper.SyncManager.isTickCompleted() && !abortRequest) { Thread.yield(); }

                Thread.yield();
            }

            if (abortRequest) {
                System.out.println("AOG: Aborting peek request");
                // AOG - We detect the lack of observation within our Python wrapper and throw a slightly
                // different exception that bypasses MineRL's automatic clean-up code. If we were to report
                // 'done', MineRL would detect this as a runtime error and kill the Minecraft process,
                // triggering a lengthy restart. So far from testing, Minecraft itself is fine and we can
                // retry the reset; it's only the tight loops above that were causing things to stall and
                // time out.
                // No observation
                dout.writeInt(0);
                // No info
                dout.writeInt(0);
                // Done
                dout.writeInt(1);
                dout.writeByte(0);
                dout.flush();
                return;
            }

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <PEEK> Pistol fired!");
            // Wait two ticks for the first observation from the server to be propagated.
            while (!TimeHelper.SyncManager.requestTick()) { Thread.yield(); }

            // Then wait until the tick is finished.
            while (!TimeHelper.SyncManager.isTickCompleted()) { Thread.yield(); }

            while (!TimeHelper.SyncManager.requestTick()) { Thread.yield(); }

            // Then wait until the tick is finished.
            while (!TimeHelper.SyncManager.isTickCompleted()) { Thread.yield(); }

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <PEEK> Getting observation.");
            obs = getObservation(false);

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <PEEK> Observation acquired.");
            done = envState.done;
            info = envState.info;

        } finally {
            lock.unlock();
        }

        dout.writeInt(obs.length);
        dout.write(obs);

        byte[] infoBytes = info.getBytes(utf8);
        dout.writeInt(infoBytes.length);
        dout.write(infoBytes);

        dout.writeInt(1);
        dout.writeByte(done ? 1 : 0);

        dout.flush();
    }

    // Get the current observation. If there is none and we are not done, wait for a short time.
    public byte[] getObservation(boolean done) {
        byte[] obs = envState.obs;
        if (obs == null) {
            System.out.println("[ERROR] Video observation is null; please notify the developer.");
        }
        return obs;
    }
    // Handler for <Find> messages - used by non-zero roles to discover the integrated server port from the primary (role 0) service.

    private final static int findTagLength = "<Find>".length();

    private void find(String command, Socket socket) throws IOException {

        Integer port;
        lock.lock();
        try {
            String token = command.substring(findTagLength, command.length() - (findTagLength + 1));
            TCPUtils.Log(Level.INFO, "Find token? " + token);

            // Purge previous token.
            String[] tokenSplits = token.split(":");
            String experimentId = tokenSplits[0];
            int role = Integer.parseInt(tokenSplits[1]);
            int reset = Integer.parseInt(tokenSplits[2]);

            String previousToken = experimentId + ":" + role + ":" + (reset - 1);
            initTokens.remove(previousToken);
            cond.signalAll();

            // Check for the next token. Wait for a short time if it has not already been produced.
            port = initTokens.get(token);
            if (port == null) {
                try {
                    cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS);
                } catch (InterruptedException ie) {
                }
                port = initTokens.get(token);
                if (port == null) {
                    port = 0;
                    TCPUtils.Log(Level.INFO, "Role " + role + " reset " + reset + " waiting for token.");
                }
            }
        } finally {
            lock.unlock();
        }

        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(port);
        dout.flush();
    }

    public boolean isSynchronous() {
        return envState.synchronous;
    }

    // Handler for <Init> messages. These reset the service, so use with care!
    private void init(String command, Socket socket) throws IOException {
        lock.lock();
        try {
            initTokens = new Hashtable<String, Integer>();
            DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
            dout.writeInt(BYTES_INT);
            dout.writeInt(1);
            dout.flush();
        } finally {
            lock.unlock();
        }
    }

    // Handler for <Quit> (quit mission) messages.
    private void quit(String command, Socket socket) throws IOException {
        lock.lock();
        try {
            if (!envState.done) {
                envState.quit = true;
            }

            // Wait two ticks for the quit to be propagated.
            while (!TimeHelper.SyncManager.requestTick()) { Thread.yield(); }

            // Then wait until the tick is finished.
            while (!TimeHelper.SyncManager.isTickCompleted()) { Thread.yield(); }

            DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
            dout.writeInt(BYTES_INT);
            dout.writeInt(envState.done ? 1 : 0);
            dout.flush();
        } finally {
            lock.unlock();
        }
    }

    private final static int closeTagLength = "<Close>".length();

    // Handler for <Close> messages.
    private void close(String command, Socket socket) throws IOException {
        lock.lock();
        try {
            String token = command.substring(closeTagLength, command.length() - (closeTagLength + 1));

            initTokens.remove(token);

            DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
            dout.writeInt(BYTES_INT);
            dout.writeInt(1);
            dout.flush();
        } finally {
            lock.unlock();
        }
    }
    // Handler for <Status> messages.
    private void status(String command, Socket socket) throws IOException {
        lock.lock();
        try {
            String status = "{}"; // TODO Possibly have something more interesting to report.
            DataOutputStream dout = new DataOutputStream(socket.getOutputStream());

            byte[] statusBytes = status.getBytes(utf8);
            dout.writeInt(statusBytes.length);
            dout.write(statusBytes);

            dout.flush();
        } finally {
            lock.unlock();
        }
    }

    // Handler for <Exit> messages. These "kill the service" temporarily, so use with care!
    private void exit(String command, Socket socket) throws IOException {
        // lock.lock();
        try {
            // We may exit before we get a chance to reply.
            TimeHelper.SyncManager.setSynchronous(false);
            DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
            dout.writeInt(BYTES_INT);
            dout.writeInt(1);
            dout.flush();

            ClientStateMachine.exitJava();

        } finally {
            // lock.unlock();
        }
    }

    // Malmo client state machine interface methods:

    public String getCommand() {
        try {
            String command = envState.commands.poll();
            if (command == null)
                return "";
            else
                return command;
        } finally {
        }
    }

    public void endMission() {
        // lock.lock();
        try {
            // AOG - If the mission is ending, we always want to abort requests, as they won't
            // be able to progress to completion and will stall.
            System.out.println("AOG: MalmoEnvServer.endMission");
            abort();
            envState.done = true;
            envState.quit = false;
            envState.missionInit = null;

            if (envState.token != null) {
                initTokens.remove(envState.token);
                envState.token = null;
                envState.experimentId = null;
                envState.agentCount = 0;
                envState.reset = 0;

                // cond.signalAll();
            }
            // lock.unlock();
        } finally {
        }
    }

    // Record a Malmo "observation" json as the env info, since an environment "obs" is a video frame.
    public void observation(String info) {
        // Parsing obs as JSON would be slower but less fragile than extracting the turn_key using string search.
        // lock.lock();
        try {
            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <OBSERVATION> Inserting: " + info);
            envState.info = info;
            // cond.signalAll();
        } finally {
            // lock.unlock();
        }
    }

    public void addRewards(double rewards) {
        // lock.lock();
        try {
            envState.reward += rewards;
        } finally {
            // lock.unlock();
        }
    }

    public void addFrame(byte[] frame) {
        // lock.lock();
        try {
            envState.obs = frame; // Replaces current.
            // cond.signalAll();
        } finally {
            // lock.unlock();
        }
    }

    public void notifyIntegrationServerStarted(int integrationServerPort) {
        lock.lock();
        try {
            if (envState.token != null) {
                TCPUtils.Log(Level.INFO, "Integration server start up - token: " + envState.token);
                addTokens(integrationServerPort, envState.token, envState.experimentId, envState.agentCount, envState.reset);
                cond.signalAll();
            } else {
                TCPUtils.Log(Level.WARNING, "No mission token on integration server start up!");
            }
        } finally {
            lock.unlock();
        }
    }

    private void addTokens(int integratedServerPort, String myToken, String experimentId, int agentCount, int reset) {
        initTokens.put(myToken, integratedServerPort);
        // Place tokens for other agents to find.
        for (int i = 1; i < agentCount; i++) {
            String tokenForAgent = experimentId + ":" + i + ":" + reset;
            initTokens.put(tokenForAgent, integratedServerPort);
        }
    }

    // IWantToQuit implementation.

    @Override
    public boolean doIWantToQuit(MissionInit missionInit) {
        // lock.lock();
        try {
            return envState.quit;
        } finally {
            // lock.unlock();
        }
    }

    public Long getSeed() {
        return envState.seed;
    }

    private void setWantToQuit() {
        // lock.lock();
        try {
            envState.quit = true;

        } finally {

            if (TimeHelper.SyncManager.isSynchronous()) {
                // We want to desynchronize everything.
                TimeHelper.SyncManager.setSynchronous(false);
            }
            // lock.unlock();
        }
    }

    @Override
    public void prepare(MissionInit missionInit) {
    }

    @Override
    public void cleanup() {
    }

    @Override
    public String getOutcome() {
        return "Env quit";
    }
}
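The service above frames every message the same way: a 4-byte big-endian length header followed by a UTF-8 payload (see checkHello and the writeInt/readFully pairs). A minimal Python sketch of a matching client-side framing helper, assuming only that wire format (host, port, and version string below are placeholders, not values from this file):

# Illustrative client-side framing for the MalmoEnv protocol above:
# each message is a 4-byte big-endian length followed by UTF-8 bytes.
import socket
import struct

def _recv_exact(sock, n):
    # Read exactly n bytes, looping because recv may return short reads.
    buf = b""
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise ConnectionError("socket closed mid-frame")
        buf += chunk
    return buf

def write_frame(sock, text):
    payload = text.encode("utf-8")
    sock.sendall(struct.pack(">i", len(payload)) + payload)

def read_frame(sock):
    (length,) = struct.unpack(">i", _recv_exact(sock, 4))
    return _recv_exact(sock, length)

# Example handshake (port and version are placeholder assumptions):
# sock = socket.create_connection(("127.0.0.1", 9000))
# write_frame(sock, "<MalmoEnv" + "0.37.0")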
@@ -1,78 +0,0 @@
FROM mcr.microsoft.com/azureml/base-gpu:openmpi3.1.2-cuda10.0-cudnn7-ubuntu18.04

# Install some basic utilities
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    cpio \
    git \
    bzip2 \
    libx11-6 \
    tmux \
    htop \
    gcc \
    xvfb \
    python-opengl \
    x11-xserver-utils \
    ffmpeg \
    mesa-utils \
    nano \
    vim \
    rsync \
 && rm -rf /var/lib/apt/lists/*

# Create a working directory
RUN mkdir /app
WORKDIR /app

# Create a Python 3.7 environment
RUN conda install conda-build \
    && conda create -y --name py37 python=3.7.3 \
    && conda clean -ya
ENV CONDA_DEFAULT_ENV=py37

# Install the libraries Minecraft needs
RUN mkdir -p /usr/share/man/man1 && \
    sudo apt-get update && \
    sudo apt-get install -y \
    openjdk-8-jre-headless=8u162-b12-1 \
    openjdk-8-jdk-headless=8u162-b12-1 \
    openjdk-8-jre=8u162-b12-1 \
    openjdk-8-jdk=8u162-b12-1

RUN pip install --upgrade --user minerl

# PyTorch with CUDA 10 installation
RUN conda install -y -c pytorch \
    cuda100=1.0 \
    magma-cuda100=2.4.0 \
    "pytorch=1.1.0=py3.7_cuda10.0.130_cudnn7.5.1_0" \
    torchvision=0.3.0 \
    && conda clean -ya

RUN pip install \
    pandas \
    matplotlib \
    numpy \
    scipy \
    azureml-defaults \
    tensorboardX \
    tensorflow-gpu==1.15rc2 \
    GPUtil \
    tabulate \
    dm_tree \
    lz4 \
    ray==0.8.3 \
    ray[rllib]==0.8.3 \
    ray[tune]==0.8.3

COPY patch_files/* /root/.local/lib/python3.7/site-packages/minerl/env/Malmo/Minecraft/src/main/java/com/microsoft/Malmo/Client/

# Start minerl to pre-fetch minerl files (saves time when starting minerl during training)
RUN xvfb-run -a -s "-screen 0 1400x900x24" python -c "import gym; import minerl; env = gym.make('MineRLTreechop-v0'); env.close();"

RUN pip install --index-url https://test.pypi.org/simple/ malmo && \
    python -c "import malmo.minecraftbootstrap; malmo.minecraftbootstrap.download();"

ENV MALMO_XSD_PATH="/app/MalmoPlatform/Schemas"
File diff suppressed because it is too large
Load Diff
@@ -1,939 +0,0 @@
|
||||
// --------------------------------------------------------------------------------------------------
|
||||
// Copyright (c) 2016 Microsoft Corporation
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and
|
||||
// associated documentation files (the "Software"), to deal in the Software without restriction,
|
||||
// including without limitation the rights to use, copy, modify, merge, publish, distribute,
|
||||
// sublicense, and/or l copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all copies or
|
||||
// substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
|
||||
// NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
|
||||
// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
// --------------------------------------------------------------------------------------------------
|
||||
|
||||
package com.microsoft.Malmo.Client;
|
||||
|
||||
import com.microsoft.Malmo.MalmoMod;
|
||||
import com.microsoft.Malmo.MissionHandlerInterfaces.IWantToQuit;
|
||||
import com.microsoft.Malmo.Schemas.MissionInit;
|
||||
import com.microsoft.Malmo.Utils.TCPUtils;
|
||||
|
||||
import net.minecraft.profiler.Profiler;
|
||||
import com.microsoft.Malmo.Utils.TimeHelper;
|
||||
|
||||
import net.minecraftforge.common.config.Configuration;
|
||||
import java.io.*;
|
||||
import java.net.ServerSocket;
|
||||
import java.net.Socket;
|
||||
import java.nio.charset.Charset;
|
||||
import java.util.Arrays;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.locks.Condition;
|
||||
import java.util.concurrent.locks.Lock;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.Hashtable;
|
||||
import com.microsoft.Malmo.Utils.TCPInputPoller;
|
||||
import java.util.logging.Level;
|
||||
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
|
||||
/**
|
||||
* MalmoEnvServer - service supporting OpenAI gym "environment" for multi-agent Malmo missions.
|
||||
*/
|
||||
public class MalmoEnvServer implements IWantToQuit {
|
||||
private static Profiler profiler = new Profiler();
|
||||
private static int nsteps = 0;
|
||||
private static boolean debug = false;
|
||||
|
||||
private static String hello = "<MalmoEnv" ;
|
||||
|
||||
private class EnvState {
|
||||
|
||||
// Mission parameters:
|
||||
String missionInit = null;
|
||||
String token = null;
|
||||
String experimentId = null;
|
||||
int agentCount = 0;
|
||||
int reset = 0;
|
||||
boolean quit = false;
|
||||
boolean synchronous = false;
|
||||
Long seed = null;
|
||||
|
||||
// OpenAI gym state:
|
||||
boolean done = false;
|
||||
double reward = 0.0;
|
||||
byte[] obs = null;
|
||||
String info = "";
|
||||
LinkedList<String> commands = new LinkedList<String>();
|
||||
}
|
||||
|
||||
private static boolean envPolicy = false; // Are we configured by config policy?
|
||||
|
||||
// Synchronize on EnvStateasd
|
||||
|
||||
|
||||
private Lock lock = new ReentrantLock();
|
||||
private Condition cond = lock.newCondition();
|
||||
|
||||
private EnvState envState = new EnvState();
|
||||
|
||||
private Hashtable<String, Integer> initTokens = new Hashtable<String, Integer>();
|
||||
|
||||
static final long COND_WAIT_SECONDS = 3; // Max wait in seconds before timing out (and replying to RPC).
|
||||
static final int BYTES_INT = 4;
|
||||
static final int BYTES_DOUBLE = 8;
|
||||
private static final Charset utf8 = Charset.forName("UTF-8");
|
||||
|
||||
// Service uses a single per-environment client connection - initiated by the remote environment.
|
||||
|
||||
private int port;
|
||||
private TCPInputPoller missionPoller; // Used for command parsing and not actual communication.
|
||||
private String version;
|
||||
|
||||
// AOG: From running experiments, I've found that MineRL can get stuck resetting the
|
||||
// environment which causes huge delays while we wait for the Python side to time
|
||||
// out and restart the Minecraft instace. Minecraft itself is normally in a recoverable
|
||||
// state, but the MalmoEnvServer instance will be blocked in a tight spin loop trying
|
||||
// handling a Peek request from the Python client. To unstick things, I've added this
|
||||
// flag that can be set when we know things are in a bad state to abort the peek request.
|
||||
// WARNING: THIS IS ONLY TREATING THE SYMPTOM AND NOT THE ROOT CAUSE
|
||||
// The reason things are getting stuck is because the player is either dying or we're
|
||||
// receiving a quit request while an episode reset is in progress.
|
||||
private boolean abortRequest;
|
||||
public void abort() {
|
||||
System.out.println("AOG: MalmoEnvServer.abort");
|
||||
abortRequest = true;
|
||||
}
|
||||
|
||||
/***
|
||||
* Malmo "Env" service.
|
||||
* @param port the port the service listens on.
|
||||
* @param missionPoller for plugging into existing comms handling.
|
||||
*/
|
||||
public MalmoEnvServer(String version, int port, TCPInputPoller missionPoller) {
|
||||
this.version = version;
|
||||
this.missionPoller = missionPoller;
|
||||
this.port = port;
|
||||
// AOG - Assume we don't wan't to be aborting in the first place
|
||||
this.abortRequest = false;
|
||||
}
|
||||
|
||||
/** Initialize malmo env configuration. For now either on or "legacy" AgentHost protocol.*/
|
||||
static public void update(Configuration configs) {
|
||||
envPolicy = configs.get(MalmoMod.ENV_CONFIGS, "env", "false").getBoolean();
|
||||
}
|
||||
|
||||
public static boolean isEnv() {
|
||||
return envPolicy;
|
||||
}
|
||||
|
||||
/**
|
||||
* Start servicing the MalmoEnv protocol.
|
||||
* @throws IOException
|
||||
*/
|
||||
public void serve() throws IOException {
|
||||
|
||||
ServerSocket serverSocket = new ServerSocket(port);
|
||||
serverSocket.setPerformancePreferences(0,2,1);
|
||||
|
||||
|
||||
while (true) {
|
||||
try {
|
||||
final Socket socket = serverSocket.accept();
|
||||
socket.setTcpNoDelay(true);
|
||||
|
||||
Thread thread = new Thread("EnvServerSocketHandler") {
|
||||
public void run() {
|
||||
boolean running = false;
|
||||
try {
|
||||
checkHello(socket);
|
||||
|
||||
while (true) {
|
||||
DataInputStream din = new DataInputStream(socket.getInputStream());
|
||||
int hdr = din.readInt();
|
||||
byte[] data = new byte[hdr];
|
||||
din.readFully(data);
|
||||
|
||||
String command = new String(data, utf8);
|
||||
|
||||
if (command.startsWith("<Step")) {
|
||||
|
||||
profiler.startSection("root");
|
||||
long start = System.nanoTime();
|
||||
step(command, socket, din);
|
||||
profiler.endSection();
|
||||
if (nsteps % 100 == 0 && debug){
|
||||
List<Profiler.Result> dat = profiler.getProfilingData("root");
|
||||
for(int qq = 0; qq < dat.size(); qq++){
|
||||
Profiler.Result res = dat.get(qq);
|
||||
System.out.println(res.profilerName + " " + res.totalUsePercentage + " "+ res.usePercentage);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
} else if (command.startsWith("<Peek")) {
|
||||
|
||||
peek(command, socket, din);
|
||||
|
||||
} else if (command.startsWith("<Init")) {
|
||||
|
||||
init(command, socket);
|
||||
|
||||
} else if (command.startsWith("<Find")) {
|
||||
|
||||
find(command, socket);
|
||||
|
||||
} else if (command.startsWith("<MissionInit")) {
|
||||
|
||||
if (missionInit(din, command, socket))
|
||||
{
|
||||
running = true;
|
||||
}
|
||||
|
||||
} else if (command.startsWith("<Quit")) {
|
||||
|
||||
quit(command, socket);
|
||||
|
||||
profiler.profilingEnabled = false;
|
||||
|
||||
} else if (command.startsWith("<Exit")) {
|
||||
|
||||
exit(command, socket);
|
||||
|
||||
profiler.profilingEnabled = false;
|
||||
|
||||
} else if (command.startsWith("<Close")) {
|
||||
|
||||
close(command, socket);
|
||||
profiler.profilingEnabled = false;
|
||||
|
||||
} else if (command.startsWith("<Status")) {
|
||||
|
||||
status(command, socket);
|
||||
|
||||
} else if (command.startsWith("<Echo")) {
|
||||
command = "<Echo>" + command + "</Echo>";
|
||||
data = command.getBytes(utf8);
|
||||
hdr = data.length;
|
||||
|
||||
DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
|
||||
dout.writeInt(hdr);
|
||||
dout.write(data, 0, hdr);
|
||||
dout.flush();
|
||||
} else {
|
||||
throw new IOException("Unknown env service command");
|
||||
}
|
||||
}
|
||||
} catch (IOException ioe) {
|
||||
// ioe.printStackTrace();
|
||||
TCPUtils.Log(Level.SEVERE, "MalmoEnv socket error: " + ioe + " (can be on disconnect)");
|
||||
// System.out.println("[ERROR] " + "MalmoEnv socket error: " + ioe + " (can be on disconnect)");
|
||||
// TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] MalmoEnv socket error");
|
||||
try {
|
||||
if (running) {
|
||||
TCPUtils.Log(Level.INFO,"Want to quit on disconnect.");
|
||||
|
||||
System.out.println("[LOGTOPY] " + "Want to quit on disconnect.");
|
||||
setWantToQuit();
|
||||
}
|
||||
socket.close();
|
||||
} catch (IOException ioe2) {
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
thread.start();
|
||||
} catch (IOException ioe) {
|
||||
TCPUtils.Log(Level.SEVERE, "MalmoEnv service exits on " + ioe);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void checkHello(Socket socket) throws IOException {
|
||||
DataInputStream din = new DataInputStream(socket.getInputStream());
|
||||
int hdr = din.readInt();
|
||||
if (hdr <= 0 || hdr > hello.length() + 8) // Version number may be somewhat longer in future.
|
||||
throw new IOException("Invalid MalmoEnv hello header length");
|
||||
byte[] data = new byte[hdr];
|
||||
din.readFully(data);
|
||||
if (!new String(data).startsWith(hello + version))
|
||||
throw new IOException("MalmoEnv invalid protocol or version - expected " + hello + version);
|
||||
}

    // Handler for <MissionInit> messages.
    private boolean missionInit(DataInputStream din, String command, Socket socket) throws IOException {
        String ipOriginator = socket.getInetAddress().getHostName();

        int hdr;
        byte[] data;
        hdr = din.readInt();
        data = new byte[hdr];
        din.readFully(data);
        String id = new String(data, utf8);

        TCPUtils.Log(Level.INFO, "Mission Init " + id);

        String[] token = id.split(":");
        String experimentId = token[0];
        int role = Integer.parseInt(token[1]);
        int reset = Integer.parseInt(token[2]);
        int agentCount = Integer.parseInt(token[3]);
        Boolean isSynchronous = Boolean.parseBoolean(token[4]);
        Long seed = null;
        if (token.length > 5)
            seed = Long.parseLong(token[5]);

        if (isSynchronous && agentCount > 1) {
            throw new IOException("Synchronous mode currently does not support multiple agents.");
        }
        port = -1;
        boolean allTokensConsumed = true;
        boolean started = false;

        lock.lock();
        try {
            if (role == 0) {
                String previousToken = experimentId + ":0:" + (reset - 1);
                initTokens.remove(previousToken);

                String myToken = experimentId + ":0:" + reset;
                if (!initTokens.containsKey(myToken)) {
                    TCPUtils.Log(Level.INFO, "(Pre)Start " + role + " reset " + reset);
                    started = startUp(command, ipOriginator, experimentId, reset, agentCount, myToken, seed, isSynchronous);
                    if (started)
                        initTokens.put(myToken, 0);
                } else {
                    started = true; // Pre-started previously.
                }

                // Check that all previous tokens have been consumed. If not, don't proceed to the mission.
                allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount);
                if (!allTokensConsumed) {
                    try {
                        cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS);
                    } catch (InterruptedException ie) {
                    }
                    allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount);
                }
            } else {
                TCPUtils.Log(Level.INFO, "Start " + role + " reset " + reset);
                started = startUp(command, ipOriginator, experimentId, reset, agentCount, experimentId + ":" + role + ":" + reset, seed, isSynchronous);
            }
        } finally {
            lock.unlock();
        }

        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(allTokensConsumed && started ? 1 : 0);
        dout.flush();

        return allTokensConsumed && started;
    }
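
    // Editor's illustrative sketch (not part of the original source): how a client
    // might compose the colon-separated id that missionInit parses above:
    // experimentId:role:reset:agentCount:isSynchronous, with an optional trailing :seed.
    private static String buildMissionInitId(String experimentId, int role, int reset,
                                             int agentCount, boolean synchronous, Long seed) {
        String id = experimentId + ":" + role + ":" + reset + ":" + agentCount + ":" + synchronous;
        return seed == null ? id : id + ":" + seed; // Seed token is optional (token.length > 5).
    }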

    private boolean areAllTokensConsumed(String experimentId, int reset, int agentCount) {
        boolean allTokensConsumed = true;
        for (int i = 1; i < agentCount; i++) {
            String tokenForAgent = experimentId + ":" + i + ":" + (reset - 1);
            if (initTokens.containsKey(tokenForAgent)) {
                TCPUtils.Log(Level.FINE, "Mission init - unconsumed " + tokenForAgent);
                allTokensConsumed = false;
            }
        }
        return allTokensConsumed;
    }

    private boolean startUp(String command, String ipOriginator, String experimentId, int reset, int agentCount, String myToken, Long seed, Boolean isSynchronous) throws IOException {
        // Clear out mission state.
        envState.reward = 0.0;
        envState.commands.clear();
        envState.obs = null;
        envState.info = "";

        envState.missionInit = command;
        envState.done = false;
        envState.quit = false;
        envState.token = myToken;
        envState.experimentId = experimentId;
        envState.agentCount = agentCount;
        envState.reset = reset;
        envState.synchronous = isSynchronous;
        envState.seed = seed;

        return startUpMission(command, ipOriginator);
    }

    private boolean startUpMission(String command, String ipOriginator) throws IOException {
        if (missionPoller == null)
            return false;

        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        DataOutputStream dos = new DataOutputStream(baos);

        missionPoller.commandReceived(command, ipOriginator, dos);

        dos.flush();
        byte[] reply = baos.toByteArray();
        ByteArrayInputStream bais = new ByteArrayInputStream(reply);
        DataInputStream dis = new DataInputStream(bais);
        int hdr = dis.readInt();
        byte[] replyBytes = new byte[hdr];
        dis.readFully(replyBytes);

        String replyStr = new String(replyBytes);
        if (replyStr.equals("MALMOOK")) {
            TCPUtils.Log(Level.INFO, "MalmoEnvServer Mission starting ...");
            return true;
        } else if (replyStr.equals("MALMOBUSY")) {
            TCPUtils.Log(Level.INFO, "MalmoEnvServer Busy - I want to quit");
            this.envState.quit = true;
        }
        return false;
    }

    private static final int stepTagLength = "<Step_>".length(); // Step with option code.

    private synchronized void stepSync(String command, Socket socket, DataInputStream din) throws IOException {
        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Entering synchronous step.");
        nsteps += 1;
        profiler.startSection("commandProcessing");
        String actions = command.substring(stepTagLength, command.length() - (stepTagLength + 2));
        int options = Character.getNumericValue(command.charAt(stepTagLength - 2));
        boolean withInfo = options == 0 || options == 2;

        // Prepare to write data to the client.
        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        double reward = 0.0;
        boolean done;
        byte[] obs;
        String info = "";
        boolean sent = false;

        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Acquiring lock for synchronous step.");
        lock.lock();
        try {
            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Lock is acquired.");
            done = envState.done;

            // TODO Handle when the environment is done.

            // Process the actions.
            if (actions.contains("\n")) {
                String[] cmds = actions.split("\\n");
                for (String cmd : cmds) {
                    envState.commands.add(cmd);
                }
            } else {
                if (!actions.isEmpty())
                    envState.commands.add(actions);
            }
            sent = true;

            profiler.endSection(); // commandProcessing
            profiler.startSection("requestTick");

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Received: " + actions);
            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Requesting tick.");
            // Now wait to run a tick.
            // If synchronous mode is off, check whether the mission is done instead of spinning.
            while (!TimeHelper.SyncManager.requestTick() && !done) { Thread.yield(); }

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Tick request granted.");
            profiler.endSection();
            profiler.startSection("waitForTick");

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Waiting for tick.");
            // Then wait until the tick is finished.
            while (!TimeHelper.SyncManager.isTickCompleted() && !done) { Thread.yield(); }

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> TICK DONE. Getting observation.");
            profiler.endSection();
            profiler.startSection("getObservation");
            // After which, get the observations.
            obs = getObservation(done);

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Observation received. Getting info.");
            profiler.endSection();
            profiler.startSection("getInfo");

            // Pick up rewards.
            reward = envState.reward;
            if (withInfo) {
                info = envState.info;
                // if (info == null)
                //     TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> FILLING INFO: NULL");
                // else
                //     TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> FILLING " + info.toString());
            }
            done = envState.done;
            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> STATUS " + Boolean.toString(done));
            envState.info = null;
            envState.obs = null;
            envState.reward = 0.0;

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Info received.");
            profiler.endSection();
        } finally {
            lock.unlock();
        }

        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Lock released. Writing observation, info, done.");
        profiler.startSection("writeObs");
        dout.writeInt(obs.length);
        dout.write(obs);

        dout.writeInt(BYTES_DOUBLE + 2);
        dout.writeDouble(reward);
        dout.writeByte(done ? 1 : 0);
        dout.writeByte(sent ? 1 : 0);

        if (withInfo) {
            byte[] infoBytes = info.getBytes(utf8);
            dout.writeInt(infoBytes.length);
            dout.write(infoBytes);
        }

        profiler.endSection(); // writeObs
        profiler.startSection("flush");

        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Packets written. Flushing.");
        dout.flush();
        profiler.endSection(); // flush

        // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <STEP> Done with step.");
    }
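
    // Editor's illustrative sketch (not part of the original source): reading the reply
    // frames stepSync writes above - the observation bytes, then reward/done/sent, then
    // the info string when the option code requested it.
    private static void readStepReply(DataInputStream din, boolean withInfo) throws IOException {
        byte[] obs = new byte[din.readInt()];      // Observation frame: length, then pixels.
        din.readFully(obs);
        din.readInt();                             // Second frame length (BYTES_DOUBLE + 2).
        double reward = din.readDouble();          // Reward accumulated since the last step.
        boolean done = din.readByte() == 1;        // Mission-over flag.
        boolean sent = din.readByte() == 1;        // Whether the actions were queued.
        if (withInfo) {
            byte[] info = new byte[din.readInt()]; // Optional info frame: length, then JSON.
            din.readFully(info);
        }
    }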

    // Handler for <Step_> messages. A single-digit option code after the underscore
    // specifies whether the turn key and info are included in the message.
    private void step(String command, Socket socket, DataInputStream din) throws IOException {
        if (envState.synchronous) {
            stepSync(command, socket, din);
        } else {
            System.out.println("[ERROR] Asynchronous stepping is not supported in MineRL.");
        }
    }

    // Handler for <Peek> messages.
    private void peek(String command, Socket socket, DataInputStream din) throws IOException {
        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        byte[] obs;
        boolean done;
        String info = "";
        // AOG - As we've only seen issues with the peek request, I've focused my changes on just
        // this function. Initially we want to be optimistic and assume we're not going to abort
        // the request, and my observations of event timings indicate that there is plenty of time
        // between the peek request being received and the reset failing, so a race condition is
        // unlikely.
        abortRequest = false;

        lock.lock();

        try {
            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <PEEK> Waiting for pistol to fire.");
            while (!TimeHelper.SyncManager.hasServerFiredPistol() && !abortRequest) {
                // Now wait to run a tick.
                while (!TimeHelper.SyncManager.requestTick() && !abortRequest) { Thread.yield(); }

                // Then wait until the tick is finished.
                while (!TimeHelper.SyncManager.isTickCompleted() && !abortRequest) { Thread.yield(); }

                Thread.yield();
            }

            if (abortRequest) {
                System.out.println("AOG: Aborting peek request");
                // AOG - We detect the lack of observation within our Python wrapper and throw a slightly
                // different exception that bypasses MineRL's automatic clean-up code. If we were to report
                // 'done', MineRL would detect this as a runtime error and kill the Minecraft process,
                // triggering a lengthy restart. So far from testing, Minecraft itself is fine and we can
                // retry the reset; it's only the tight loops above that were causing things to stall and
                // time out.
                // No observation.
                dout.writeInt(0);
                // No info.
                dout.writeInt(0);
                // Done flag.
                dout.writeInt(1);
                dout.writeByte(0);
                dout.flush();
                return;
            }

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <PEEK> Pistol fired!");
            // Wait two ticks for the first observation from the server to be propagated.
            while (!TimeHelper.SyncManager.requestTick()) { Thread.yield(); }

            // Then wait until the tick is finished.
            while (!TimeHelper.SyncManager.isTickCompleted()) { Thread.yield(); }

            while (!TimeHelper.SyncManager.requestTick()) { Thread.yield(); }

            // Then wait until the tick is finished.
            while (!TimeHelper.SyncManager.isTickCompleted()) { Thread.yield(); }

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <PEEK> Getting observation.");
            obs = getObservation(false);

            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <PEEK> Observation acquired.");
            done = envState.done;
            info = envState.info;
        } finally {
            lock.unlock();
        }

        dout.writeInt(obs.length);
        dout.write(obs);

        byte[] infoBytes = info.getBytes(utf8);
        dout.writeInt(infoBytes.length);
        dout.write(infoBytes);

        dout.writeInt(1);
        dout.writeByte(done ? 1 : 0);

        dout.flush();
    }
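
    // Editor's illustrative sketch (not part of the original source): the reply layout
    // peek writes above - observation, info, then a one-byte done flag. An aborted
    // request instead sends two empty frames followed by a zero done byte.
    private static boolean readPeekReply(DataInputStream din) throws IOException {
        byte[] obs = new byte[din.readInt()];  // Observation frame (length 0 on abort).
        din.readFully(obs);
        byte[] info = new byte[din.readInt()]; // Info frame (length 0 on abort).
        din.readFully(info);
        din.readInt();                         // Done frame length (always 1).
        return din.readByte() == 1;            // Done flag.
    }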

    // Get the current observation. Logs an error if no frame is available.
    public byte[] getObservation(boolean done) {
        byte[] obs = envState.obs;
        if (obs == null) {
            System.out.println("[ERROR] Video observation is null; please notify the developer.");
        }
        return obs;
    }

    // Handler for <Find> messages - used by non-zero roles to discover the integrated
    // server port from the primary (role 0) service.

    private final static int findTagLength = "<Find>".length();

    private void find(String command, Socket socket) throws IOException {
        Integer port;
        lock.lock();
        try {
            String token = command.substring(findTagLength, command.length() - (findTagLength + 1));
            TCPUtils.Log(Level.INFO, "Find token? " + token);

            // Purge the previous token.
            String[] tokenSplits = token.split(":");
            String experimentId = tokenSplits[0];
            int role = Integer.parseInt(tokenSplits[1]);
            int reset = Integer.parseInt(tokenSplits[2]);

            String previousToken = experimentId + ":" + role + ":" + (reset - 1);
            initTokens.remove(previousToken);
            cond.signalAll();

            // Check for the next token. Wait for a short time if it has not yet been produced.
            port = initTokens.get(token);
            if (port == null) {
                try {
                    cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS);
                } catch (InterruptedException ie) {
                }
                port = initTokens.get(token);
                if (port == null) {
                    port = 0;
                    TCPUtils.Log(Level.INFO, "Role " + role + " reset " + reset + " waiting for token.");
                }
            }
        } finally {
            lock.unlock();
        }

        DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
        dout.writeInt(BYTES_INT);
        dout.writeInt(port);
        dout.flush();
    }
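
    // Editor's illustrative sketch (not part of the original source): a non-zero role
    // reading the reply find writes above. A port of 0 means the token was not yet
    // available and the caller should retry.
    private static int readFindReply(DataInputStream din) throws IOException {
        din.readInt();        // Frame length (BYTES_INT).
        return din.readInt(); // Integrated server port, or 0 if not yet available.
    }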

    public boolean isSynchronous() {
        return envState.synchronous;
    }

    // Handler for <Init> messages. These reset the service, so use with care!
    private void init(String command, Socket socket) throws IOException {
        lock.lock();
        try {
            initTokens = new Hashtable<String, Integer>();
            DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
            dout.writeInt(BYTES_INT);
            dout.writeInt(1);
            dout.flush();
        } finally {
            lock.unlock();
        }
    }

    // Handler for <Quit> (quit mission) messages.
    private void quit(String command, Socket socket) throws IOException {
        lock.lock();
        try {
            if (!envState.done) {
                envState.quit = true;
            }

            // Run a tick so the quit request can take effect.
            while (!TimeHelper.SyncManager.requestTick()) { Thread.yield(); }

            // Then wait until the tick is finished.
            while (!TimeHelper.SyncManager.isTickCompleted()) { Thread.yield(); }

            DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
            dout.writeInt(BYTES_INT);
            dout.writeInt(envState.done ? 1 : 0);
            dout.flush();
        } finally {
            lock.unlock();
        }
    }

    private final static int closeTagLength = "<Close>".length();

    // Handler for <Close> messages.
    private void close(String command, Socket socket) throws IOException {
        lock.lock();
        try {
            String token = command.substring(closeTagLength, command.length() - (closeTagLength + 1));

            initTokens.remove(token);

            DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
            dout.writeInt(BYTES_INT);
            dout.writeInt(1);
            dout.flush();
        } finally {
            lock.unlock();
        }
    }

    // Handler for <Status> messages.
    private void status(String command, Socket socket) throws IOException {
        lock.lock();
        try {
            String status = "{}"; // TODO Possibly have something more interesting to report.
            DataOutputStream dout = new DataOutputStream(socket.getOutputStream());

            byte[] statusBytes = status.getBytes(utf8);
            dout.writeInt(statusBytes.length);
            dout.write(statusBytes);

            dout.flush();
        } finally {
            lock.unlock();
        }
    }

    // Handler for <Exit> messages. These "kill the service" temporarily, so use with care!
    private void exit(String command, Socket socket) throws IOException {
        // lock.lock();
        try {
            // We may exit before we get a chance to reply.
            TimeHelper.SyncManager.setSynchronous(false);
            DataOutputStream dout = new DataOutputStream(socket.getOutputStream());
            dout.writeInt(BYTES_INT);
            dout.writeInt(1);
            dout.flush();

            ClientStateMachine.exitJava();
        } finally {
            // lock.unlock();
        }
    }

    // Malmo client state machine interface methods:

    public String getCommand() {
        try {
            String command = envState.commands.poll();
            if (command == null)
                return "";
            else
                return command;
        } finally {
        }
    }

    public void endMission() {
        // lock.lock();
        try {
            // AOG - If the mission is ending, we always want to abort requests, as they
            // won't be able to progress to completion and will stall.
            System.out.println("AOG: MalmoEnvServer.endMission");
            abort();
            envState.done = true;
            envState.quit = false;
            envState.missionInit = null;

            if (envState.token != null) {
                initTokens.remove(envState.token);
                envState.token = null;
                envState.experimentId = null;
                envState.agentCount = 0;
                envState.reset = 0;

                // cond.signalAll();
            }
            // lock.unlock();
        } finally {
        }
    }

    // Record a Malmo "observation" json as the env info, since an environment "obs" is a video frame.
    public void observation(String info) {
        // Parsing obs as JSON would be slower but less fragile than extracting the turn_key using string search.
        // lock.lock();
        try {
            // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] <OBSERVATION> Inserting: " + info);
            envState.info = info;
            // cond.signalAll();
        } finally {
            // lock.unlock();
        }
    }

    public void addRewards(double rewards) {
        // lock.lock();
        try {
            envState.reward += rewards;
        } finally {
            // lock.unlock();
        }
    }

    public void addFrame(byte[] frame) {
        // lock.lock();
        try {
            envState.obs = frame; // Replaces the current frame.
            // cond.signalAll();
        } finally {
            // lock.unlock();
        }
    }

    public void notifyIntegrationServerStarted(int integrationServerPort) {
        lock.lock();
        try {
            if (envState.token != null) {
                TCPUtils.Log(Level.INFO, "Integration server start up - token: " + envState.token);
                addTokens(integrationServerPort, envState.token, envState.experimentId, envState.agentCount, envState.reset);
                cond.signalAll();
            } else {
                TCPUtils.Log(Level.WARNING, "No mission token on integration server start up!");
            }
        } finally {
            lock.unlock();
        }
    }

    private void addTokens(int integratedServerPort, String myToken, String experimentId, int agentCount, int reset) {
        initTokens.put(myToken, integratedServerPort);
        // Place tokens for other agents to find.
        for (int i = 1; i < agentCount; i++) {
            String tokenForAgent = experimentId + ":" + i + ":" + reset;
            initTokens.put(tokenForAgent, integratedServerPort);
        }
    }
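
    // Editor's note (example, not in the original source): with experimentId "exp",
    // agentCount 3 and reset 1, addTokens stores the caller's own token plus
    // "exp:1:1" and "exp:2:1", which the other agents retrieve via <Find>.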

    // IWantToQuit implementation.

    @Override
    public boolean doIWantToQuit(MissionInit missionInit) {
        // lock.lock();
        try {
            return envState.quit;
        } finally {
            // lock.unlock();
        }
    }

    public Long getSeed() {
        return envState.seed;
    }

    private void setWantToQuit() {
        // lock.lock();
        try {
            envState.quit = true;
        } finally {
            if (TimeHelper.SyncManager.isSynchronous()) {
                // We want to desynchronize everything.
                TimeHelper.SyncManager.setSynchronous(false);
            }
            // lock.unlock();
        }
    }

    @Override
    public void prepare(MissionInit missionInit) {
    }

    @Override
    public void cleanup() {
    }

    @Override
    public String getOutcome() {
        return "Env quit";
    }
}