diff --git a/configuration.ipynb b/configuration.ipynb index afba9b3b..fcba452e 100644 --- a/configuration.ipynb +++ b/configuration.ipynb @@ -103,7 +103,7 @@ "source": [ "import azureml.core\n", "\n", - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/automl_env.yml b/how-to-use-azureml/automated-machine-learning/automl_env.yml index 642d4e86..f0d34df8 100644 --- a/how-to-use-azureml/automated-machine-learning/automl_env.yml +++ b/how-to-use-azureml/automated-machine-learning/automl_env.yml @@ -21,8 +21,8 @@ dependencies: - pip: # Required packages for AzureML execution, history, and data preparation. - - azureml-widgets~=1.23.0 + - azureml-widgets~=1.24.0 - pytorch-transformers==1.0.0 - spacy==2.1.8 - https://aka.ms/automl-resources/packages/en_core_web_sm-2.1.0.tar.gz - - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.23.0/validated_win32_requirements.txt [--no-deps] + - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.24.0/validated_win32_requirements.txt [--no-deps] diff --git a/how-to-use-azureml/automated-machine-learning/automl_env_linux.yml b/how-to-use-azureml/automated-machine-learning/automl_env_linux.yml index de9f66cf..ec2df76c 100644 --- a/how-to-use-azureml/automated-machine-learning/automl_env_linux.yml +++ b/how-to-use-azureml/automated-machine-learning/automl_env_linux.yml @@ -21,8 +21,8 @@ dependencies: - pip: # Required packages for AzureML execution, history, and data preparation. 
- - azureml-widgets~=1.23.0 + - azureml-widgets~=1.24.0 - pytorch-transformers==1.0.0 - spacy==2.1.8 - https://aka.ms/automl-resources/packages/en_core_web_sm-2.1.0.tar.gz - - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.23.0/validated_linux_requirements.txt [--no-deps] + - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.24.0/validated_linux_requirements.txt [--no-deps] diff --git a/how-to-use-azureml/automated-machine-learning/automl_env_mac.yml b/how-to-use-azureml/automated-machine-learning/automl_env_mac.yml index 2be16810..7986e34e 100644 --- a/how-to-use-azureml/automated-machine-learning/automl_env_mac.yml +++ b/how-to-use-azureml/automated-machine-learning/automl_env_mac.yml @@ -22,8 +22,8 @@ dependencies: - pip: # Required packages for AzureML execution, history, and data preparation. - - azureml-widgets~=1.23.0 + - azureml-widgets~=1.24.0 - pytorch-transformers==1.0.0 - spacy==2.1.8 - https://aka.ms/automl-resources/packages/en_core_web_sm-2.1.0.tar.gz - - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.23.0/validated_darwin_requirements.txt [--no-deps] + - -r https://automlcesdkdataresources.blob.core.windows.net/validated-requirements/1.24.0/validated_darwin_requirements.txt [--no-deps] diff --git a/how-to-use-azureml/automated-machine-learning/classification-bank-marketing-all-features/auto-ml-classification-bank-marketing-all-features.ipynb b/how-to-use-azureml/automated-machine-learning/classification-bank-marketing-all-features/auto-ml-classification-bank-marketing-all-features.ipynb index 41b399d7..60098ace 100644 --- a/how-to-use-azureml/automated-machine-learning/classification-bank-marketing-all-features/auto-ml-classification-bank-marketing-all-features.ipynb +++ b/how-to-use-azureml/automated-machine-learning/classification-bank-marketing-all-features/auto-ml-classification-bank-marketing-all-features.ipynb @@ -105,7 +105,7 @@ 
"metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/classification-bank-marketing-all-features/auto-ml-classification-bank-marketing-all-features.yml b/how-to-use-azureml/automated-machine-learning/classification-bank-marketing-all-features/auto-ml-classification-bank-marketing-all-features.yml new file mode 100644 index 00000000..0f30214b --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/classification-bank-marketing-all-features/auto-ml-classification-bank-marketing-all-features.yml @@ -0,0 +1,4 @@ +name: auto-ml-classification-bank-marketing-all-features +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/automated-machine-learning/classification-credit-card-fraud/auto-ml-classification-credit-card-fraud.ipynb b/how-to-use-azureml/automated-machine-learning/classification-credit-card-fraud/auto-ml-classification-credit-card-fraud.ipynb index 6e9e39db..ffe01c11 100644 --- a/how-to-use-azureml/automated-machine-learning/classification-credit-card-fraud/auto-ml-classification-credit-card-fraud.ipynb +++ b/how-to-use-azureml/automated-machine-learning/classification-credit-card-fraud/auto-ml-classification-credit-card-fraud.ipynb @@ -93,7 +93,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/classification-credit-card-fraud/auto-ml-classification-credit-card-fraud.yml 
b/how-to-use-azureml/automated-machine-learning/classification-credit-card-fraud/auto-ml-classification-credit-card-fraud.yml new file mode 100644 index 00000000..148f33d5 --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/classification-credit-card-fraud/auto-ml-classification-credit-card-fraud.yml @@ -0,0 +1,4 @@ +name: auto-ml-classification-credit-card-fraud +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/automated-machine-learning/classification-text-dnn/auto-ml-classification-text-dnn.ipynb b/how-to-use-azureml/automated-machine-learning/classification-text-dnn/auto-ml-classification-text-dnn.ipynb index ab1182fc..767e7c80 100644 --- a/how-to-use-azureml/automated-machine-learning/classification-text-dnn/auto-ml-classification-text-dnn.ipynb +++ b/how-to-use-azureml/automated-machine-learning/classification-text-dnn/auto-ml-classification-text-dnn.ipynb @@ -96,7 +96,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/classification-text-dnn/auto-ml-classification-text-dnn.yml b/how-to-use-azureml/automated-machine-learning/classification-text-dnn/auto-ml-classification-text-dnn.yml new file mode 100644 index 00000000..4c952264 --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/classification-text-dnn/auto-ml-classification-text-dnn.yml @@ -0,0 +1,4 @@ +name: auto-ml-classification-text-dnn +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/automated-machine-learning/continuous-retraining/auto-ml-continuous-retraining.ipynb b/how-to-use-azureml/automated-machine-learning/continuous-retraining/auto-ml-continuous-retraining.ipynb index 56f9c66c..5d7e3336 100644 --- 
a/how-to-use-azureml/automated-machine-learning/continuous-retraining/auto-ml-continuous-retraining.ipynb +++ b/how-to-use-azureml/automated-machine-learning/continuous-retraining/auto-ml-continuous-retraining.ipynb @@ -81,7 +81,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/continuous-retraining/auto-ml-continuous-retraining.yml b/how-to-use-azureml/automated-machine-learning/continuous-retraining/auto-ml-continuous-retraining.yml new file mode 100644 index 00000000..9b05ea1f --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/continuous-retraining/auto-ml-continuous-retraining.yml @@ -0,0 +1,4 @@ +name: auto-ml-continuous-retraining +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client.cmd b/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client.cmd index 1bc35112..ac5123e6 100644 --- a/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client.cmd +++ b/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client.cmd @@ -5,7 +5,7 @@ set options=%3 set PIP_NO_WARN_SCRIPT_LOCATION=0 IF "%conda_env_name%"=="" SET conda_env_name="azure_automl_experimental" -IF "%automl_env_file%"=="" SET automl_env_file="automl_env.yml" +IF "%automl_env_file%"=="" SET automl_env_file="automl_thin_client_env.yml" IF NOT EXIST %automl_env_file% GOTO YmlMissing diff --git a/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client_linux.sh b/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client_linux.sh index c2f7c86b..e73b5961 100644 --- 
a/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client_linux.sh +++ b/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client_linux.sh @@ -12,7 +12,7 @@ fi if [ "$AUTOML_ENV_FILE" == "" ] then - AUTOML_ENV_FILE="automl_env.yml" + AUTOML_ENV_FILE="automl_thin_client_env.yml" fi if [ ! -f $AUTOML_ENV_FILE ]; then diff --git a/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client_mac.sh b/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client_mac.sh index fa966ad7..506b6ecd 100644 --- a/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client_mac.sh +++ b/how-to-use-azureml/automated-machine-learning/experimental/automl_setup_thin_client_mac.sh @@ -12,7 +12,7 @@ fi if [ "$AUTOML_ENV_FILE" == "" ] then - AUTOML_ENV_FILE="automl_env.yml" + AUTOML_ENV_FILE="automl_thin_client_env_mac.yml" fi if [ ! -f $AUTOML_ENV_FILE ]; then diff --git a/how-to-use-azureml/automated-machine-learning/experimental/automl_thin_client_env.yml b/how-to-use-azureml/automated-machine-learning/experimental/automl_thin_client_env.yml index 408755fb..ce785fbf 100644 --- a/how-to-use-azureml/automated-machine-learning/experimental/automl_thin_client_env.yml +++ b/how-to-use-azureml/automated-machine-learning/experimental/automl_thin_client_env.yml @@ -7,6 +7,7 @@ dependencies: - nb_conda - cython - urllib3<1.24 +- PyJWT < 2.0.0 - numpy==1.18.5 - pip: @@ -15,4 +16,3 @@ dependencies: - azureml-sdk - azureml-widgets - pandas - - PyJWT < 2.0.0 diff --git a/how-to-use-azureml/automated-machine-learning/experimental/automl_thin_client_env_mac.yml b/how-to-use-azureml/automated-machine-learning/experimental/automl_thin_client_env_mac.yml index f20a4e27..531ac69f 100644 --- a/how-to-use-azureml/automated-machine-learning/experimental/automl_thin_client_env_mac.yml +++ b/how-to-use-azureml/automated-machine-learning/experimental/automl_thin_client_env_mac.yml @@ 
-8,6 +8,7 @@ dependencies: - nb_conda - cython - urllib3<1.24 +- PyJWT < 2.0.0 - numpy==1.18.5 - pip: @@ -16,4 +17,3 @@ dependencies: - azureml-sdk - azureml-widgets - pandas - - PyJWT < 2.0.0 diff --git a/how-to-use-azureml/automated-machine-learning/experimental/regression-model-proxy/auto-ml-regression-model-proxy.ipynb b/how-to-use-azureml/automated-machine-learning/experimental/regression-model-proxy/auto-ml-regression-model-proxy.ipynb index f2366c38..fdda92a2 100644 --- a/how-to-use-azureml/automated-machine-learning/experimental/regression-model-proxy/auto-ml-regression-model-proxy.ipynb +++ b/how-to-use-azureml/automated-machine-learning/experimental/regression-model-proxy/auto-ml-regression-model-proxy.ipynb @@ -90,7 +90,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, @@ -194,7 +194,6 @@ "|**n_cross_validations**|Number of cross validation splits.|\n", "|**training_data**|(sparse) array-like, shape = [n_samples, n_features]|\n", "|**label_column_name**|(sparse) array-like, shape = [n_samples, ], targets values.|\n", - "|**scenario**|We need to set this parameter to 'Latest' to enable some experimental features. 
This parameter should not be set outside of this experimental notebook.|\n", "\n", "**_You can find more information about primary metrics_** [here](https://docs.microsoft.com/en-us/azure/machine-learning/service/how-to-configure-auto-train#primary-metric)" ] @@ -223,7 +222,6 @@ " compute_target = compute_target,\n", " training_data = train_data,\n", " label_column_name = label,\n", - " scenario='Latest',\n", " **automl_settings\n", " )" ] diff --git a/how-to-use-azureml/automated-machine-learning/experimental/regression-model-proxy/auto-ml-regression-model-proxy.yml b/how-to-use-azureml/automated-machine-learning/experimental/regression-model-proxy/auto-ml-regression-model-proxy.yml new file mode 100644 index 00000000..e5d127ea --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/experimental/regression-model-proxy/auto-ml-regression-model-proxy.yml @@ -0,0 +1,4 @@ +name: auto-ml-regression-model-proxy +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/automated-machine-learning/forecasting-beer-remote/auto-ml-forecasting-beer-remote.ipynb b/how-to-use-azureml/automated-machine-learning/forecasting-beer-remote/auto-ml-forecasting-beer-remote.ipynb index f2572c95..97755d3e 100644 --- a/how-to-use-azureml/automated-machine-learning/forecasting-beer-remote/auto-ml-forecasting-beer-remote.ipynb +++ b/how-to-use-azureml/automated-machine-learning/forecasting-beer-remote/auto-ml-forecasting-beer-remote.ipynb @@ -113,7 +113,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/forecasting-beer-remote/auto-ml-forecasting-beer-remote.yml 
b/how-to-use-azureml/automated-machine-learning/forecasting-beer-remote/auto-ml-forecasting-beer-remote.yml new file mode 100644 index 00000000..103560d8 --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/forecasting-beer-remote/auto-ml-forecasting-beer-remote.yml @@ -0,0 +1,4 @@ +name: auto-ml-forecasting-beer-remote +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/automated-machine-learning/forecasting-bike-share/auto-ml-forecasting-bike-share.ipynb b/how-to-use-azureml/automated-machine-learning/forecasting-bike-share/auto-ml-forecasting-bike-share.ipynb index 500c9627..21b19f4d 100644 --- a/how-to-use-azureml/automated-machine-learning/forecasting-bike-share/auto-ml-forecasting-bike-share.ipynb +++ b/how-to-use-azureml/automated-machine-learning/forecasting-bike-share/auto-ml-forecasting-bike-share.ipynb @@ -87,7 +87,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/forecasting-bike-share/auto-ml-forecasting-bike-share.yml b/how-to-use-azureml/automated-machine-learning/forecasting-bike-share/auto-ml-forecasting-bike-share.yml new file mode 100644 index 00000000..70a3271c --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/forecasting-bike-share/auto-ml-forecasting-bike-share.yml @@ -0,0 +1,4 @@ +name: auto-ml-forecasting-bike-share +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/automated-machine-learning/forecasting-energy-demand/auto-ml-forecasting-energy-demand.ipynb b/how-to-use-azureml/automated-machine-learning/forecasting-energy-demand/auto-ml-forecasting-energy-demand.ipynb index a5f910da..6ba8039c 100644 --- 
a/how-to-use-azureml/automated-machine-learning/forecasting-energy-demand/auto-ml-forecasting-energy-demand.ipynb +++ b/how-to-use-azureml/automated-machine-learning/forecasting-energy-demand/auto-ml-forecasting-energy-demand.ipynb @@ -97,7 +97,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/forecasting-energy-demand/auto-ml-forecasting-energy-demand.yml b/how-to-use-azureml/automated-machine-learning/forecasting-energy-demand/auto-ml-forecasting-energy-demand.yml new file mode 100644 index 00000000..13bd78f8 --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/forecasting-energy-demand/auto-ml-forecasting-energy-demand.yml @@ -0,0 +1,4 @@ +name: auto-ml-forecasting-energy-demand +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/automated-machine-learning/forecasting-forecast-function/auto-ml-forecasting-function.ipynb b/how-to-use-azureml/automated-machine-learning/forecasting-forecast-function/auto-ml-forecasting-function.ipynb index 8bb87b18..4517881d 100644 --- a/how-to-use-azureml/automated-machine-learning/forecasting-forecast-function/auto-ml-forecasting-function.ipynb +++ b/how-to-use-azureml/automated-machine-learning/forecasting-forecast-function/auto-ml-forecasting-function.ipynb @@ -94,7 +94,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git 
a/how-to-use-azureml/automated-machine-learning/forecasting-forecast-function/auto-ml-forecasting-function.yml b/how-to-use-azureml/automated-machine-learning/forecasting-forecast-function/auto-ml-forecasting-function.yml new file mode 100644 index 00000000..144797d6 --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/forecasting-forecast-function/auto-ml-forecasting-function.yml @@ -0,0 +1,4 @@ +name: auto-ml-forecasting-function +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/automated-machine-learning/forecasting-orange-juice-sales/auto-ml-forecasting-orange-juice-sales.ipynb b/how-to-use-azureml/automated-machine-learning/forecasting-orange-juice-sales/auto-ml-forecasting-orange-juice-sales.ipynb index 8b5d2dd6..2fbe860e 100644 --- a/how-to-use-azureml/automated-machine-learning/forecasting-orange-juice-sales/auto-ml-forecasting-orange-juice-sales.ipynb +++ b/how-to-use-azureml/automated-machine-learning/forecasting-orange-juice-sales/auto-ml-forecasting-orange-juice-sales.ipynb @@ -82,7 +82,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/forecasting-orange-juice-sales/auto-ml-forecasting-orange-juice-sales.yml b/how-to-use-azureml/automated-machine-learning/forecasting-orange-juice-sales/auto-ml-forecasting-orange-juice-sales.yml new file mode 100644 index 00000000..a6cc3e71 --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/forecasting-orange-juice-sales/auto-ml-forecasting-orange-juice-sales.yml @@ -0,0 +1,4 @@ +name: auto-ml-forecasting-orange-juice-sales +dependencies: +- pip: + - azureml-sdk diff --git 
a/how-to-use-azureml/automated-machine-learning/local-run-classification-credit-card-fraud/auto-ml-classification-credit-card-fraud-local.ipynb b/how-to-use-azureml/automated-machine-learning/local-run-classification-credit-card-fraud/auto-ml-classification-credit-card-fraud-local.ipynb index bb7f1f95..fa7bec19 100644 --- a/how-to-use-azureml/automated-machine-learning/local-run-classification-credit-card-fraud/auto-ml-classification-credit-card-fraud-local.ipynb +++ b/how-to-use-azureml/automated-machine-learning/local-run-classification-credit-card-fraud/auto-ml-classification-credit-card-fraud-local.ipynb @@ -96,7 +96,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/local-run-classification-credit-card-fraud/auto-ml-classification-credit-card-fraud-local.yml b/how-to-use-azureml/automated-machine-learning/local-run-classification-credit-card-fraud/auto-ml-classification-credit-card-fraud-local.yml new file mode 100644 index 00000000..6c817042 --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/local-run-classification-credit-card-fraud/auto-ml-classification-credit-card-fraud-local.yml @@ -0,0 +1,4 @@ +name: auto-ml-classification-credit-card-fraud-local +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/automated-machine-learning/regression-explanation-featurization/auto-ml-regression-explanation-featurization.ipynb b/how-to-use-azureml/automated-machine-learning/regression-explanation-featurization/auto-ml-regression-explanation-featurization.ipynb index 3e9e0f5c..d7193587 100644 --- 
a/how-to-use-azureml/automated-machine-learning/regression-explanation-featurization/auto-ml-regression-explanation-featurization.ipynb +++ b/how-to-use-azureml/automated-machine-learning/regression-explanation-featurization/auto-ml-regression-explanation-featurization.ipynb @@ -96,7 +96,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/regression-explanation-featurization/auto-ml-regression-explanation-featurization.yml b/how-to-use-azureml/automated-machine-learning/regression-explanation-featurization/auto-ml-regression-explanation-featurization.yml new file mode 100644 index 00000000..9db24f2b --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/regression-explanation-featurization/auto-ml-regression-explanation-featurization.yml @@ -0,0 +1,4 @@ +name: auto-ml-regression-explanation-featurization +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/automated-machine-learning/regression/auto-ml-regression.ipynb b/how-to-use-azureml/automated-machine-learning/regression/auto-ml-regression.ipynb index d56227b3..fea6c750 100644 --- a/how-to-use-azureml/automated-machine-learning/regression/auto-ml-regression.ipynb +++ b/how-to-use-azureml/automated-machine-learning/regression/auto-ml-regression.ipynb @@ -92,7 +92,7 @@ "metadata": {}, "outputs": [], "source": [ - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] }, diff --git a/how-to-use-azureml/automated-machine-learning/regression/auto-ml-regression.yml 
b/how-to-use-azureml/automated-machine-learning/regression/auto-ml-regression.yml new file mode 100644 index 00000000..4e84e13a --- /dev/null +++ b/how-to-use-azureml/automated-machine-learning/regression/auto-ml-regression.yml @@ -0,0 +1,4 @@ +name: auto-ml-regression +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/azure-synapse/README.md b/how-to-use-azureml/azure-synapse/README.md new file mode 100644 index 00000000..1398865e --- /dev/null +++ b/how-to-use-azureml/azure-synapse/README.md @@ -0,0 +1,84 @@ +Azure Synapse Analytics is a limitless analytics service that brings together data integration, enterprise data warehousing, and big data analytics. It gives you the freedom to query data on your terms, using either serverless or dedicated resources—at scale. Azure Synapse brings these worlds together with a unified experience to ingest, explore, prepare, manage, and serve data for immediate BI and machine learning needs. A core offering within Azure Synapse Analytics are serverless Apache Spark pools enhanced for big data workloads. + +The Synapse integration in Azure ML is for customers who want to use Apache Spark in Azure Synapse Analytics to prepare data at scale in Azure ML before training their ML model. This will allow customers to work on their end-to-end ML lifecycle including large-scale data preparation, model training and deployment within an Azure ML workspace without having to use suboptimal tools for machine learning or switch between multiple tools for data preparation and model training. The ability to perform all ML tasks within Azure ML will reduce the time required for customers to iterate on a machine learning project, which typically includes multiple rounds of data preparation and training. 
+ +In the public preview, the following capabilities are provided: + +- Link Azure Synapse Analytics workspace to Azure Machine Learning workspace (via ARM, UI or SDK) +- Attach Apache Spark pools powered by Azure Synapse Analytics as Azure Machine Learning compute targets (via ARM, UI or SDK) +- Launch Apache Spark sessions in notebooks and perform interactive data exploration and preparation. This interactive experience leverages Apache Spark magic and customers will have session-level Conda support to install packages. +- Productionize ML pipelines by leveraging Apache Spark pools to pre-process big data + +# Using Synapse in Azure Machine Learning + +## Create Synapse resources + +Follow the documents below to create a Synapse workspace; resource-setup.sh is also available for you to create the resources. + +- Create from [Portal](https://docs.microsoft.com/en-us/azure/synapse-analytics/quickstart-create-workspace) +- Create from [CLI](https://docs.microsoft.com/en-us/azure/synapse-analytics/quickstart-create-workspace-cli) + +Follow the documents below to create a Synapse Spark pool: + +- Create from [Portal](https://docs.microsoft.com/en-us/azure/synapse-analytics/quickstart-create-apache-spark-pool-portal) +- Create from [CLI](https://docs.microsoft.com/en-us/cli/azure/ext/synapse/synapse/spark/pool?view=azure-cli-latest) + +## Link Synapse Workspace + +Make sure you are the owner of the Synapse workspace so that you can link it into AML. 
+You can run resource-setup.py to link the synapse workspace and attach compute + +```python +from azureml.core import Workspace +ws = Workspace.from_config() + +from azureml.core import LinkedService, SynapseWorkspaceLinkedServiceConfiguration +synapse_link_config = SynapseWorkspaceLinkedServiceConfiguration( + subscription_id="", + resource_group="I', gz.read(4)) + if not label: + n_rows = struct.unpack('>I', gz.read(4))[0] + n_cols = struct.unpack('>I', gz.read(4))[0] + res = np.frombuffer(gz.read(n_items[0] * n_rows * n_cols), dtype=np.uint8) + res = res.reshape(n_items[0], n_rows * n_cols) + else: + res = np.frombuffer(gz.read(n_items[0]), dtype=np.uint8) + res = res.reshape(n_items[0], 1) + return res + + +def download_mnist(): + data_folder = os.path.join(os.getcwd(), 'data/mnist') + os.makedirs(data_folder, exist_ok=True) + + mnist_file_dataset = MNIST.get_file_dataset() + mnist_file_dataset.download(data_folder, overwrite=True) + + X_train = load_data(glob.glob(os.path.join(data_folder, "**/train-images-idx3-ubyte.gz"), + recursive=True)[0], False) / 255.0 + X_test = load_data(glob.glob(os.path.join(data_folder, "**/t10k-images-idx3-ubyte.gz"), + recursive=True)[0], False) / 255.0 + y_train = load_data(glob.glob(os.path.join(data_folder, "**/train-labels-idx1-ubyte.gz"), + recursive=True)[0], True).reshape(-1) + y_test = load_data(glob.glob(os.path.join(data_folder, "**/t10k-labels-idx1-ubyte.gz"), + recursive=True)[0], True).reshape(-1) + + train = tuple_dataset.TupleDataset(X_train.astype(np.float32), y_train.astype(np.int32)) + test = tuple_dataset.TupleDataset(X_test.astype(np.float32), y_test.astype(np.int32)) + + return train, test diff --git a/how-to-use-azureml/ml-frameworks/fastai/fastai-with-custom-docker/fastai-with-custom-docker.yml b/how-to-use-azureml/ml-frameworks/fastai/fastai-with-custom-docker/fastai-with-custom-docker.yml new file mode 100644 index 00000000..3e5f80ae --- /dev/null +++ 
b/how-to-use-azureml/ml-frameworks/fastai/fastai-with-custom-docker/fastai-with-custom-docker.yml @@ -0,0 +1,5 @@ +name: fastai-with-custom-docker +dependencies: +- pip: + - azureml-sdk + - fastai==1.0.61 diff --git a/how-to-use-azureml/ml-frameworks/keras/train-hyperparameter-tune-deploy-with-keras/train-hyperparameter-tune-deploy-with-keras.yml b/how-to-use-azureml/ml-frameworks/keras/train-hyperparameter-tune-deploy-with-keras/train-hyperparameter-tune-deploy-with-keras.yml new file mode 100644 index 00000000..8fa4d352 --- /dev/null +++ b/how-to-use-azureml/ml-frameworks/keras/train-hyperparameter-tune-deploy-with-keras/train-hyperparameter-tune-deploy-with-keras.yml @@ -0,0 +1,8 @@ +name: train-hyperparameter-tune-deploy-with-keras +dependencies: +- pip: + - azureml-sdk + - azureml-widgets + - tensorflow + - keras<=2.3.1 + - matplotlib diff --git a/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-nccl-gloo/distributed-pytorch-with-nccl-gloo.ipynb b/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-distributeddataparallel/distributed-pytorch-with-distributeddataparallel.ipynb similarity index 63% rename from how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-nccl-gloo/distributed-pytorch-with-nccl-gloo.ipynb rename to how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-distributeddataparallel/distributed-pytorch-with-distributeddataparallel.ipynb index 55c15863..df2b5ce9 100644 --- a/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-nccl-gloo/distributed-pytorch-with-nccl-gloo.ipynb +++ b/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-distributeddataparallel/distributed-pytorch-with-distributeddataparallel.ipynb @@ -21,7 +21,8 @@ "metadata": {}, "source": [ "# Distributed PyTorch with DistributedDataParallel\n", - "In this tutorial, you will train a PyTorch model on the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset using distributed training with PyTorch's 
`DistributedDataParallel` module across a GPU cluster. " + "\n", + "In this tutorial, you will train a PyTorch model on the [CIFAR-10](https://www.cs.toronto.edu/~kriz/cifar.html) dataset using distributed training with PyTorch's `DistributedDataParallel` module across a GPU cluster." ] }, { @@ -113,7 +114,7 @@ "from azureml.core.compute_target import ComputeTargetException\n", "\n", "# choose a name for your cluster\n", - "cluster_name = \"gpu-cluster\"\n", + "cluster_name = 'gpu-cluster'\n", "\n", "try:\n", " compute_target = ComputeTarget(workspace=ws, name=cluster_name)\n", @@ -139,6 +140,68 @@ "The above code creates GPU compute. If you instead want to create CPU compute, provide a different VM size to the `vm_size` parameter, such as `STANDARD_D2_V2`." ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prepare dataset\n", + "\n", + "Prepare the dataset used for training. We will first download and extract the publicly available CIFAR-10 dataset from the cs.toronto.edu website and then create an Azure ML FileDataset to use the data for training." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download and extract CIFAR-10 data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import urllib\n", + "import tarfile\n", + "import os\n", + "\n", + "url = 'https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'\n", + "filename = 'cifar-10-python.tar.gz'\n", + "data_root = 'cifar-10'\n", + "filepath = os.path.join(data_root, filename)\n", + "\n", + "if not os.path.isdir(data_root):\n", + " os.makedirs(data_root, exist_ok=True)\n", + " urllib.request.urlretrieve(url, filepath)\n", + " with tarfile.open(filepath, \"r:gz\") as tar:\n", + " tar.extractall(path=data_root)\n", + " os.remove(filepath) # delete tar.gz file after extraction" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create Azure ML dataset\n", + "\n", + "The `upload_directory` method will upload the data to a datastore and create a FileDataset from it. In this tutorial we will use the workspace's default datastore." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from azureml.core import Dataset\n", + "\n", + "datastore = ws.get_default_datastore()\n", + "dataset = Dataset.File.upload_directory(\n", + " src_dir=data_root, target=(datastore, data_root)\n", + ")" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -161,8 +224,6 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", - "\n", "project_folder = './pytorch-distr'\n", "os.makedirs(project_folder, exist_ok=True)" ] @@ -172,26 +233,14 @@ "metadata": {}, "source": [ "### Prepare training script\n", - "Now you will need to create your training script. In this tutorial, the script for distributed training of MNIST is already provided for you at `pytorch_mnist.py`. 
In practice, you should be able to take any custom PyTorch training script as is and run it with Azure ML without having to modify your code.\n", - "\n", - "However, if you would like to use Azure ML's [metric logging](https://docs.microsoft.com/azure/machine-learning/service/concept-azure-machine-learning-architecture#logging) capabilities, you will have to add a small amount of Azure ML logic inside your training script. In this example, at each logging interval, we will log the loss for that minibatch to our Azure ML run.\n", - "\n", - "To do so, in `pytorch_mnist.py`, we will first access the Azure ML `Run` object within the script:\n", - "```Python\n", - "from azureml.core.run import Run\n", - "run = Run.get_context()\n", - "```\n", - "Later within the script, we log the loss metric to our run:\n", - "```Python\n", - "run.log('loss', losses.avg)\n", - "```" + "Now you will need to create your training script. In this tutorial, the script for distributed training on CIFAR-10 is already provided for you at `train.py`. In practice, you should be able to take any custom PyTorch training script as is and run it with Azure ML without having to modify your code." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Once your script is ready, copy the training script `pytorch_mnist.py` into the project directory." + "Once your script is ready, copy the training script `train.py` into the project directory." ] }, { @@ -202,7 +251,7 @@ "source": [ "import shutil\n", "\n", - "shutil.copy('pytorch_mnist.py', project_folder)" + "shutil.copy('train.py', project_folder)" ] }, { @@ -231,26 +280,7 @@ "source": [ "### Create an environment\n", "\n", - "Define a conda environment YAML file with your training script dependencies and create an Azure ML environment." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "%%writefile conda_dependencies.yml\n", - "\n", - "channels:\n", - "- conda-forge\n", - "dependencies:\n", - "- python=3.6.2\n", - "- pip:\n", - " - azureml-defaults\n", - " - torch==1.6.0\n", - " - torchvision==0.7.0\n", - " - future==0.17.1" + "In this tutorial, we will use one of Azure ML's curated PyTorch environments for training. [Curated environments](https://docs.microsoft.com/azure/machine-learning/how-to-use-environments#use-a-curated-environment) are available in your workspace by default. Specifically, we will use the PyTorch 1.6 GPU curated environment." ] }, { @@ -261,24 +291,39 @@ "source": [ "from azureml.core import Environment\n", "\n", - "pytorch_env = Environment.from_conda_specification(name = 'pytorch-1.6-gpu', file_path = './conda_dependencies.yml')\n", - "\n", - "# Specify a GPU base image\n", - "pytorch_env.docker.enabled = True\n", - "pytorch_env.docker.base_image = 'mcr.microsoft.com/azureml/openmpi3.1.2-cuda10.1-cudnn7-ubuntu18.04'" + "pytorch_env = Environment.get(ws, name='AzureML-PyTorch-1.6-GPU')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Configure the training job: torch.distributed with NCCL backend\n", + "### Configure the training job\n", "\n", - "Create a ScriptRunConfig object to specify the configuration details of your training job, including your training script, environment to use, and the compute target to run on.\n", + "To launch a distributed PyTorch job on Azure ML, you have two options:\n", "\n", - "In order to run a distributed PyTorch job with **torch.distributed** using the NCCL backend, create a `PyTorchConfiguration` and pass it to the `distributed_job_config` parameter of the ScriptRunConfig constructor. Specify `communication_backend='Nccl'` in the PyTorchConfiguration. The below code will configure a 2-node distributed job. 
The NCCL backend is the recommended backend for PyTorch distributed GPU training.\n", + "1. Per-process launch - specify the total # of worker processes (typically one per GPU) you want to run, and\n", + "Azure ML will handle launching each process.\n", + "2. Per-node launch with [torch.distributed.launch](https://pytorch.org/docs/stable/distributed.html#launch-utility) - provide the `torch.distributed.launch` command you want to\n", + "run on each node.\n", "\n", - "The script arguments refers to the Azure ML-set environment variables `AZ_BATCHAI_PYTORCH_INIT_METHOD` for shared file-system initialization and `AZ_BATCHAI_TASK_INDEX` for the global rank of each worker process." + "For more information, see the [documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-train-pytorch#distributeddataparallel).\n", + "\n", + "Both options are shown below." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Per-process launch\n", + "\n", + "To use the per-process launch option in which Azure ML will handle launching each of the processes to run your training script,\n", + "\n", + "1. Specify the training script and arguments\n", + "2. Create a `PyTorchConfiguration` and specify `node_count` and `process_count`. The `process_count` is the total number of processes you want to run for the job; this should typically equal the # of GPUs available on each node multiplied by the # of nodes. Since this tutorial uses the `STANDARD_NC6` SKU, which has one GPU, the total process count for a 2-node job is `2`. If you are using a SKU with >1 GPUs, adjust the `process_count` accordingly.\n", + "\n", + "Azure ML will set the `MASTER_ADDR`, `MASTER_PORT`, `NODE_RANK`, `WORLD_SIZE` environment variables on each node, in addition to the process-level `RANK` and `LOCAL_RANK` environment variables, that are needed for distributed PyTorch training." 
] }, { @@ -290,17 +335,61 @@ "from azureml.core import ScriptRunConfig\n", "from azureml.core.runconfig import PyTorchConfiguration\n", "\n", - "args = ['--dist-backend', 'nccl',\n", - " '--dist-url', '$AZ_BATCHAI_PYTORCH_INIT_METHOD',\n", - " '--rank', '$AZ_BATCHAI_TASK_INDEX',\n", - " '--world-size', 2]\n", + "# create distributed config\n", + "distr_config = PyTorchConfiguration(process_count=2, node_count=2)\n", "\n", + "# create args\n", + "args = [\"--data-dir\", dataset.as_download(), \"--epochs\", 25]\n", + "\n", + "# create job config\n", "src = ScriptRunConfig(source_directory=project_folder,\n", - " script='pytorch_mnist.py',\n", + " script='train.py',\n", " arguments=args,\n", " compute_target=compute_target,\n", " environment=pytorch_env,\n", - " distributed_job_config=PyTorchConfiguration(communication_backend='Nccl', node_count=2))" + " distributed_job_config=distr_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Per-node launch with `torch.distributed.launch`\n", + "\n", + "If you would instead like to use the PyTorch-provided launch utility `torch.distributed.launch` to handle launching the worker processes on each node, you can do so as well. \n", + "\n", + "1. Provide the launch command to the `command` parameter of ScriptRunConfig. For PyTorch jobs Azure ML will set the `MASTER_ADDR`, `MASTER_PORT`, and `NODE_RANK` environment variables on each node, so you can simply just reference those environment variables in your command. If you are using a SKU with >1 GPUs, adjust the `--nproc_per_node` argument accordingly.\n", + "\n", + "2. Create a `PyTorchConfiguration` and specify the `node_count`. You do not need to specify the `process_count`; by default Azure ML will launch one process per node to run the `command` you provided.\n", + "\n", + "Uncomment the code below to configure a job with this method." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "'''\n", + "from azureml.core import ScriptRunConfig\n", + "from azureml.core.runconfig import PyTorchConfiguration\n", + "\n", + "# create distributed config\n", + "distr_config = PyTorchConfiguration(node_count=2)\n", + "\n", + "# define command\n", + "launch_cmd = [\"python -m torch.distributed.launch --nproc_per_node 1 --nnodes 2 \" \\\n", + " \"--node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT --use_env \" \\\n", + " \"train.py --data-dir\", dataset.as_download(), \"--epochs 25\"]\n", + "\n", + "# create job config\n", + "src = ScriptRunConfig(source_directory=project_folder,\n", + " command=launch_cmd,\n", + " compute_target=compute_target,\n", + " environment=pytorch_env,\n", + " distributed_job_config=distr_config)\n", + "'''" ] }, { @@ -308,7 +397,7 @@ "metadata": {}, "source": [ "### Submit job\n", - "Run your experiment by submitting your ScriptRunConfig object. Note that this call is asynchronous." + "Run your experiment by submitting your `ScriptRunConfig` object. Note that this call is asynchronous." ] }, { @@ -355,50 +444,12 @@ "source": [ "run.wait_for_completion(show_output=True) # this provides a verbose log" ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Configure training job: torch.distributed with Gloo backend\n", - "\n", - "If you would instead like to use the Gloo backend for distributed training, you can do so via the following code. The Gloo backend is recommended for distributed CPU training." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from azureml.core import ScriptRunConfig\n", - "from azureml.core.runconfig import PyTorchConfiguration\n", - "\n", - "args = ['--dist-backend', 'gloo',\n", - " '--dist-url', '$AZ_BATCHAI_PYTORCH_INIT_METHOD',\n", - " '--rank', '$AZ_BATCHAI_TASK_INDEX',\n", - " '--world-size', 2]\n", - "\n", - "src = ScriptRunConfig(source_directory=project_folder,\n", - " script='pytorch_mnist.py',\n", - " arguments=args,\n", - " compute_target=compute_target,\n", - " environment=pytorch_env,\n", - " distributed_job_config=PyTorchConfiguration(communication_backend='Gloo', node_count=2))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Once you create the ScriptRunConfig, you can follow the submit steps as shown in the previous steps to submit a PyTorch distributed run using the Gloo backend." - ] } ], "metadata": { "authors": [ { - "name": "ninhu" + "name": "minxia" } ], "category": "training", @@ -406,7 +457,7 @@ "AML Compute" ], "datasets": [ - "MNIST" + "CIFAR-10" ], "deployment": [ "None" @@ -432,12 +483,12 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.7.7" }, "tags": [ "None" ], - "task": "Train a model using distributed training via Nccl/Gloo" + "task": "Train a model using distributed training via PyTorch DistributedDataParallel" }, "nbformat": 4, "nbformat_minor": 2 diff --git a/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-distributeddataparallel/distributed-pytorch-with-distributeddataparallel.yml b/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-distributeddataparallel/distributed-pytorch-with-distributeddataparallel.yml new file mode 100644 index 00000000..8fa7e81d --- /dev/null +++ 
b/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-distributeddataparallel/distributed-pytorch-with-distributeddataparallel.yml @@ -0,0 +1,5 @@ +name: distributed-pytorch-with-distributeddataparallel +dependencies: +- pip: + - azureml-sdk + - azureml-widgets diff --git a/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-distributeddataparallel/train.py b/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-distributeddataparallel/train.py new file mode 100644 index 00000000..c6c302dd --- /dev/null +++ b/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-distributeddataparallel/train.py @@ -0,0 +1,238 @@ +# Copyright (c) 2017 Facebook, Inc. All rights reserved. +# BSD 3-Clause License +# +# Script adapted from: +# https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html +# ============================================================================== + +# imports +import torch +import torchvision +import torchvision.transforms as transforms +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import os +import argparse + + +# define network architecture +class Net(nn.Module): + def __init__(self): + super(Net, self).__init__() + self.conv1 = nn.Conv2d(3, 32, 3) + self.pool = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(32, 64, 3) + self.conv3 = nn.Conv2d(64, 128, 3) + self.fc1 = nn.Linear(128 * 6 * 6, 120) + self.dropout = nn.Dropout(p=0.2) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = self.pool(F.relu(self.conv2(x))) + x = self.pool(F.relu(self.conv3(x))) + x = x.view(-1, 128 * 6 * 6) + x = self.dropout(F.relu(self.fc1(x))) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +def train(train_loader, model, criterion, optimizer, epoch, device, print_freq, rank): + running_loss = 0.0 + for i, data in enumerate(train_loader, 0): + # get the inputs; data is a list of [inputs, labels] + 
inputs, labels = data[0].to(device), data[1].to(device) + + # zero the parameter gradients + optimizer.zero_grad() + + # forward + backward + optimize + outputs = model(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + + # print statistics + running_loss += loss.item() + if i % print_freq == 0: # print every print_freq mini-batches + print( + "Rank %d: [%d, %5d] loss: %.3f" + % (rank, epoch + 1, i + 1, running_loss / print_freq) + ) + running_loss = 0.0 + + +def evaluate(test_loader, model, device): + classes = ( + "plane", + "car", + "bird", + "cat", + "deer", + "dog", + "frog", + "horse", + "ship", + "truck", + ) + + model.eval() + + correct = 0 + total = 0 + class_correct = list(0.0 for i in range(10)) + class_total = list(0.0 for i in range(10)) + with torch.no_grad(): + for data in test_loader: + images, labels = data[0].to(device), data[1].to(device) + outputs = model(images) + _, predicted = torch.max(outputs.data, 1) + total += labels.size(0) + correct += (predicted == labels).sum().item() + c = (predicted == labels).squeeze() + for i in range(10): + label = labels[i] + class_correct[label] += c[i].item() + class_total[label] += 1 + + # print total test set accuracy + print( + "Accuracy of the network on the 10000 test images: %d %%" + % (100 * correct / total) + ) + + # print test accuracy for each of the classes + for i in range(10): + print( + "Accuracy of %5s : %2d %%" + % (classes[i], 100 * class_correct[i] / class_total[i]) + ) + + +def main(args): + # get PyTorch environment variables + world_size = int(os.environ["WORLD_SIZE"]) + rank = int(os.environ["RANK"]) + local_rank = int(os.environ["LOCAL_RANK"]) + + distributed = world_size > 1 + + # set device + if distributed: + device = torch.device("cuda", local_rank) + else: + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + # initialize distributed process group using default env:// method + if distributed: + 
torch.distributed.init_process_group(backend="nccl") + + # define train and test dataset DataLoaders + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + + train_set = torchvision.datasets.CIFAR10( + root=args.data_dir, train=True, download=False, transform=transform + ) + + if distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler(train_set) + else: + train_sampler = None + + train_loader = torch.utils.data.DataLoader( + train_set, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + num_workers=args.workers, + sampler=train_sampler, + ) + + test_set = torchvision.datasets.CIFAR10( + root=args.data_dir, train=False, download=False, transform=transform + ) + test_loader = torch.utils.data.DataLoader( + test_set, batch_size=args.batch_size, shuffle=False, num_workers=args.workers + ) + + model = Net().to(device) + + # wrap model with DDP + if distributed: + model = nn.parallel.DistributedDataParallel( + model, device_ids=[local_rank], output_device=local_rank + ) + + # define loss function and optimizer + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD( + model.parameters(), lr=args.learning_rate, momentum=args.momentum + ) + + # train the model + for epoch in range(args.epochs): + print("Rank %d: Starting epoch %d" % (rank, epoch)) + if distributed: + train_sampler.set_epoch(epoch) + model.train() + train( + train_loader, + model, + criterion, + optimizer, + epoch, + device, + args.print_freq, + rank, + ) + + print("Rank %d: Finished Training" % (rank)) + + if not distributed or rank == 0: + os.makedirs(args.output_dir, exist_ok=True) + model_path = os.path.join(args.output_dir, "cifar_net.pt") + torch.save(model.state_dict(), model_path) + + # evaluate on full test dataset + evaluate(test_loader, model, device) + + +if __name__ == "__main__": + # setup argparse + parser = argparse.ArgumentParser() + parser.add_argument( + "--data-dir", type=str, 
help="directory containing CIFAR-10 dataset" + ) + parser.add_argument("--epochs", default=10, type=int, help="number of epochs") + parser.add_argument( + "--batch-size", + default=16, + type=int, + help="mini batch size for each gpu/process", + ) + parser.add_argument( + "--workers", + default=2, + type=int, + help="number of data loading workers for each gpu/process", + ) + parser.add_argument( + "--learning-rate", default=0.001, type=float, help="learning rate" + ) + parser.add_argument("--momentum", default=0.9, type=float, help="momentum") + parser.add_argument( + "--output-dir", default="outputs", type=str, help="directory to save model to" + ) + parser.add_argument( + "--print-freq", + default=200, + type=int, + help="frequency of printing training statistics", + ) + args = parser.parse_args() + + main(args) diff --git a/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-horovod/distributed-pytorch-with-horovod.yml b/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-horovod/distributed-pytorch-with-horovod.yml new file mode 100644 index 00000000..58bb77d8 --- /dev/null +++ b/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-horovod/distributed-pytorch-with-horovod.yml @@ -0,0 +1,5 @@ +name: distributed-pytorch-with-horovod +dependencies: +- pip: + - azureml-sdk + - azureml-widgets diff --git a/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-nccl-gloo/pytorch_mnist.py b/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-nccl-gloo/pytorch_mnist.py deleted file mode 100644 index 7a9aeb60..00000000 --- a/how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-nccl-gloo/pytorch_mnist.py +++ /dev/null @@ -1,209 +0,0 @@ -# Copyright (c) 2017, PyTorch contributors -# Modifications copyright (C) Microsoft Corporation -# Licensed under the BSD license -# Adapted from https://github.com/Azure/BatchAI/tree/master/recipes/PyTorch/PyTorch-GPU-Distributed-Gloo - -from __future__ 
import print_function -import argparse -import os -import shutil -import time -import torch -import torch.nn as nn -import torch.nn.functional as F -import torch.optim as optim -from torchvision import datasets, transforms -import torch.nn.parallel -import torch.backends.cudnn as cudnn -import torch.distributed as dist -import torch.utils.data -import torch.utils.data.distributed -import torchvision.models as models - -from azureml.core.run import Run -# get the Azure ML run object -run = Run.get_context() - -# Training settings -parser = argparse.ArgumentParser(description='PyTorch MNIST Example') -parser.add_argument('--batch-size', type=int, default=64, metavar='N', - help='input batch size for training (default: 64)') -parser.add_argument('--test-batch-size', type=int, default=1000, metavar='N', - help='input batch size for testing (default: 1000)') -parser.add_argument('--epochs', type=int, default=10, metavar='N', - help='number of epochs to train (default: 10)') -parser.add_argument('--lr', type=float, default=0.01, metavar='LR', - help='learning rate (default: 0.01)') -parser.add_argument('--momentum', type=float, default=0.5, metavar='M', - help='SGD momentum (default: 0.5)') -parser.add_argument('--seed', type=int, default=1, metavar='S', - help='random seed (default: 1)') -parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', - help='number of data loading workers (default: 4)') -parser.add_argument('--log-interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') -parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float, - metavar='W', help='weight decay (default: 1e-4)') -parser.add_argument('--world-size', default=1, type=int, - help='number of distributed processes') -parser.add_argument('--dist-url', type=str, - help='url used to set up distributed training') -parser.add_argument('--dist-backend', default='nccl', type=str, - help='distributed backend') 
-parser.add_argument('--rank', default=-1, type=int, - help='rank of the worker') - -best_prec1 = 0 -args = parser.parse_args() - -args.distributed = args.world_size >= 2 - -if args.distributed: - dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, - world_size=args.world_size, rank=args.rank) - -train_dataset = datasets.MNIST('data-%d' % args.rank, train=True, download=True, - transform=transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.1307,), (0.3081,)) - ])) - -if args.distributed: - train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) -else: - train_sampler = None - -train_loader = torch.utils.data.DataLoader( - train_dataset, - batch_size=args.batch_size, shuffle=(train_sampler is None), - num_workers=args.workers, pin_memory=True, sampler=train_sampler) - - -test_loader = torch.utils.data.DataLoader( - train_dataset, - batch_size=args.batch_size, shuffle=False, - num_workers=args.workers, pin_memory=True) - - -class Net(nn.Module): - def __init__(self): - super(Net, self).__init__() - self.conv1 = nn.Conv2d(1, 10, kernel_size=5) - self.conv2 = nn.Conv2d(10, 20, kernel_size=5) - self.conv2_drop = nn.Dropout2d() - self.fc1 = nn.Linear(320, 50) - self.fc2 = nn.Linear(50, 10) - - def forward(self, x): - x = F.relu(F.max_pool2d(self.conv1(x), 2)) - x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2)) - x = x.view(-1, 320) - x = F.relu(self.fc1(x)) - x = F.dropout(x, training=self.training) - x = self.fc2(x) - return F.log_softmax(x) - - -model = Net() - -if not args.distributed: - model = torch.nn.DataParallel(model).cuda() -else: - model.cuda() - model = torch.nn.parallel.DistributedDataParallel(model) - -# define loss function (criterion) and optimizer -criterion = nn.CrossEntropyLoss().cuda() - -optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.weight_decay) - - -def train(epoch): - batch_time = AverageMeter() - data_time = 
AverageMeter() - losses = AverageMeter() - top1 = AverageMeter() - top5 = AverageMeter() - - # switch to train mode - model.train() - end = time.time() - for i, (input, target) in enumerate(train_loader): - # measure data loading time - data_time.update(time.time() - end) - - input, target = input.cuda(), target.cuda() - - # compute output - try: - output = model(input) - loss = criterion(output, target) - - # measure accuracy and record loss - prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) - losses.update(loss.item(), input.size(0)) - top1.update(prec1[0], input.size(0)) - top5.update(prec5[0], input.size(0)) - - # compute gradient and do SGD step - optimizer.zero_grad() - loss.backward() - optimizer.step() - - # measure elapsed time - batch_time.update(time.time() - end) - end = time.time() - - if i % 10 == 0: - run.log("loss", losses.avg) - run.log("prec@1", "{0:.3f}".format(top1.avg)) - run.log("prec@5", "{0:.3f}".format(top5.avg)) - print('Epoch: [{0}][{1}/{2}]\t' - 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' - 'Data {data_time.val:.3f} ({data_time.avg:.3f})\t' - 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' - 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' - 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(epoch, i, len(train_loader), - batch_time=batch_time, data_time=data_time, - loss=losses, top1=top1, top5=top5)) - except: - import sys - print("Unexpected error:", sys.exc_info()[0]) - - -class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self): - self.reset() - - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 - - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count - - -def accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - maxk = max(topk) - batch_size = target.size(0) - - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = 
pred.eq(target.view(1, -1).expand_as(pred)) - - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) - res.append(correct_k.mul_(100.0 / batch_size)) - return res - - -for epoch in range(1, args.epochs + 1): - train(epoch) diff --git a/how-to-use-azureml/ml-frameworks/pytorch/train-hyperparameter-tune-deploy-with-pytorch/train-hyperparameter-tune-deploy-with-pytorch.yml b/how-to-use-azureml/ml-frameworks/pytorch/train-hyperparameter-tune-deploy-with-pytorch/train-hyperparameter-tune-deploy-with-pytorch.yml new file mode 100644 index 00000000..c04135a1 --- /dev/null +++ b/how-to-use-azureml/ml-frameworks/pytorch/train-hyperparameter-tune-deploy-with-pytorch/train-hyperparameter-tune-deploy-with-pytorch.yml @@ -0,0 +1,10 @@ +name: train-hyperparameter-tune-deploy-with-pytorch +dependencies: +- pip: + - azureml-sdk + - azureml-widgets + - pillow==5.4.1 + - matplotlib + - numpy==1.19.3 + - https://download.pytorch.org/whl/cpu/torch-1.6.0%2Bcpu-cp36-cp36m-win_amd64.whl + - https://download.pytorch.org/whl/cpu/torchvision-0.7.0%2Bcpu-cp36-cp36m-win_amd64.whl diff --git a/how-to-use-azureml/ml-frameworks/scikit-learn/train-hyperparameter-tune-deploy-with-sklearn/train-hyperparameter-tune-deploy-with-sklearn.yml b/how-to-use-azureml/ml-frameworks/scikit-learn/train-hyperparameter-tune-deploy-with-sklearn/train-hyperparameter-tune-deploy-with-sklearn.yml new file mode 100644 index 00000000..2691a849 --- /dev/null +++ b/how-to-use-azureml/ml-frameworks/scikit-learn/train-hyperparameter-tune-deploy-with-sklearn/train-hyperparameter-tune-deploy-with-sklearn.yml @@ -0,0 +1,6 @@ +name: train-hyperparameter-tune-deploy-with-sklearn +dependencies: +- pip: + - azureml-sdk + - azureml-widgets + - numpy diff --git a/how-to-use-azureml/ml-frameworks/tensorflow/distributed-tensorflow-with-horovod/distributed-tensorflow-with-horovod.yml 
b/how-to-use-azureml/ml-frameworks/tensorflow/distributed-tensorflow-with-horovod/distributed-tensorflow-with-horovod.yml new file mode 100644 index 00000000..3fbd7704 --- /dev/null +++ b/how-to-use-azureml/ml-frameworks/tensorflow/distributed-tensorflow-with-horovod/distributed-tensorflow-with-horovod.yml @@ -0,0 +1,11 @@ +name: distributed-tensorflow-with-horovod +dependencies: +- pip: + - azureml-sdk + - azureml-widgets + - keras + - tensorflow-gpu==1.13.2 + - horovod==0.19.1 + - matplotlib + - pandas + - fuse diff --git a/how-to-use-azureml/ml-frameworks/tensorflow/distributed-tensorflow-with-parameter-server/distributed-tensorflow-with-parameter-server.yml b/how-to-use-azureml/ml-frameworks/tensorflow/distributed-tensorflow-with-parameter-server/distributed-tensorflow-with-parameter-server.yml new file mode 100644 index 00000000..bc5a30eb --- /dev/null +++ b/how-to-use-azureml/ml-frameworks/tensorflow/distributed-tensorflow-with-parameter-server/distributed-tensorflow-with-parameter-server.yml @@ -0,0 +1,5 @@ +name: distributed-tensorflow-with-parameter-server +dependencies: +- pip: + - azureml-sdk + - azureml-widgets diff --git a/how-to-use-azureml/ml-frameworks/tensorflow/train-hyperparameter-tune-deploy-with-tensorflow/train-hyperparameter-tune-deploy-with-tensorflow.yml b/how-to-use-azureml/ml-frameworks/tensorflow/train-hyperparameter-tune-deploy-with-tensorflow/train-hyperparameter-tune-deploy-with-tensorflow.yml new file mode 100644 index 00000000..76b7eabc --- /dev/null +++ b/how-to-use-azureml/ml-frameworks/tensorflow/train-hyperparameter-tune-deploy-with-tensorflow/train-hyperparameter-tune-deploy-with-tensorflow.yml @@ -0,0 +1,12 @@ +name: train-hyperparameter-tune-deploy-with-tensorflow +dependencies: +- numpy +- matplotlib +- pip: + - azureml-sdk + - azureml-widgets + - pandas + - keras + - tensorflow==2.0.0 + - matplotlib + - fuse diff --git a/how-to-use-azureml/reinforcement-learning/atari-on-distributed-compute/pong_rllib.yml 
b/how-to-use-azureml/reinforcement-learning/atari-on-distributed-compute/pong_rllib.yml new file mode 100644 index 00000000..d9d808d9 --- /dev/null +++ b/how-to-use-azureml/reinforcement-learning/atari-on-distributed-compute/pong_rllib.yml @@ -0,0 +1,8 @@ +name: pong_rllib +dependencies: +- pip: + - azureml-sdk + - azureml-contrib-reinforcementlearning + - azureml-widgets + - matplotlib + - azure-mgmt-network==12.0.0 diff --git a/how-to-use-azureml/reinforcement-learning/cartpole-on-compute-instance/cartpole_ci.yml b/how-to-use-azureml/reinforcement-learning/cartpole-on-compute-instance/cartpole_ci.yml new file mode 100644 index 00000000..c5a2ed39 --- /dev/null +++ b/how-to-use-azureml/reinforcement-learning/cartpole-on-compute-instance/cartpole_ci.yml @@ -0,0 +1,6 @@ +name: cartpole_ci +dependencies: +- pip: + - azureml-sdk + - azureml-contrib-reinforcementlearning + - azureml-widgets diff --git a/how-to-use-azureml/reinforcement-learning/cartpole-on-single-compute/cartpole_sc.yml b/how-to-use-azureml/reinforcement-learning/cartpole-on-single-compute/cartpole_sc.yml new file mode 100644 index 00000000..48d5edfa --- /dev/null +++ b/how-to-use-azureml/reinforcement-learning/cartpole-on-single-compute/cartpole_sc.yml @@ -0,0 +1,6 @@ +name: cartpole_sc +dependencies: +- pip: + - azureml-sdk + - azureml-contrib-reinforcementlearning + - azureml-widgets diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/cpu/Dockerfile b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/cpu/Dockerfile deleted file mode 100644 index e6c05310..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/cpu/Dockerfile +++ /dev/null @@ -1,70 +0,0 @@ -FROM mcr.microsoft.com/azureml/base:openmpi3.1.2-ubuntu18.04 - -# Install some basic utilities -RUN apt-get update && apt-get install -y \ - curl \ - ca-certificates \ - sudo \ - cpio \ - git \ - bzip2 \ - libx11-6 \ - tmux \ - htop \ 
- gcc \ - xvfb \ - python-opengl \ - x11-xserver-utils \ - ffmpeg \ - mesa-utils \ - nano \ - vim \ - rsync \ - && rm -rf /var/lib/apt/lists/* - -# Create a working directory -RUN mkdir /app -WORKDIR /app - -# Install Minecraft needed libraries -RUN mkdir -p /usr/share/man/man1 && \ - sudo apt-get update && \ - sudo apt-get install -y \ - openjdk-8-jre-headless=8u162-b12-1 \ - openjdk-8-jdk-headless=8u162-b12-1 \ - openjdk-8-jre=8u162-b12-1 \ - openjdk-8-jdk=8u162-b12-1 - -# Create a Python 3.7 environment -RUN conda install conda-build \ - && conda create -y --name py37 python=3.7.3 \ - && conda clean -ya -ENV CONDA_DEFAULT_ENV=py37 - -# Install minerl -RUN pip install --upgrade --user minerl - -RUN pip install \ - pandas \ - matplotlib \ - numpy \ - scipy \ - azureml-defaults \ - tensorboardX \ - tensorflow==1.15rc2 \ - tabulate \ - dm_tree \ - lz4 \ - ray==0.8.3 \ - ray[rllib]==0.8.3 \ - ray[tune]==0.8.3 - -COPY patch_files/* /root/.local/lib/python3.7/site-packages/minerl/env/Malmo/Minecraft/src/main/java/com/microsoft/Malmo/Client/ - -# Start minerl to pre-fetch minerl files (saves time when starting minerl during training) -RUN xvfb-run -a -s "-screen 0 1400x900x24" python -c "import gym; import minerl; env = gym.make('MineRLTreechop-v0'); env.close();" - -RUN pip install --index-url https://test.pypi.org/simple/ malmo && \ - python -c "import malmo.minecraftbootstrap; malmo.minecraftbootstrap.download();" - -ENV MALMO_XSD_PATH="/app/MalmoPlatform/Schemas" diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/cpu/patch_files/ClientStateMachine.java b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/cpu/patch_files/ClientStateMachine.java deleted file mode 100644 index 71668152..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/cpu/patch_files/ClientStateMachine.java +++ /dev/null @@ -1,2481 +0,0 @@ -// 
-------------------------------------------------------------------------------------------------- -// Copyright (c) 2016 Microsoft Corporation -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and -// associated documentation files (the "Software"), to deal in the Software without restriction, -// including without limitation the rights to use, copy, modify, merge, publish, distribute, -// sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT -// NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
-// -------------------------------------------------------------------------------------------------- - -package com.microsoft.Malmo.Client; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.lang.reflect.Field; -import java.math.BigDecimal; -import java.net.UnknownHostException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.logging.Level; - -import javax.xml.bind.JAXBException; -import javax.xml.stream.XMLStreamException; - -import net.minecraft.client.Minecraft; -import net.minecraft.client.entity.EntityPlayerSP; -import net.minecraft.client.gui.GuiDisconnected; -import net.minecraft.client.gui.GuiIngameMenu; -import net.minecraft.client.gui.GuiMainMenu; -import net.minecraft.client.gui.GuiScreen; -import net.minecraft.client.multiplayer.WorldClient; -import net.minecraft.client.network.NetHandlerPlayClient; -import net.minecraft.client.settings.GameSettings; -import net.minecraft.entity.Entity; -import net.minecraft.entity.EntityLivingBase; -import net.minecraft.launchwrapper.Launch; -import net.minecraft.network.NetworkManager; -import net.minecraft.server.integrated.IntegratedServer; -import net.minecraft.util.math.MathHelper; -import net.minecraft.util.text.TextComponentString; -import net.minecraft.world.GameType; -import net.minecraft.world.World; -import net.minecraft.world.chunk.Chunk; -import net.minecraft.world.chunk.IChunkProvider; -import net.minecraftforge.common.MinecraftForge; -import net.minecraftforge.common.config.Configuration; -import net.minecraftforge.fml.client.event.ConfigChangedEvent.OnConfigChangedEvent; -import net.minecraftforge.fml.common.FMLCommonHandler; -import net.minecraftforge.fml.common.Loader; -import net.minecraftforge.fml.common.eventhandler.SubscribeEvent; -import net.minecraftforge.fml.common.gameevent.TickEvent; -import 
net.minecraftforge.fml.common.gameevent.TickEvent.ClientTickEvent; -import net.minecraftforge.fml.common.gameevent.TickEvent.Phase; -import net.minecraftforge.fml.common.gameevent.TickEvent.ServerTickEvent; - -import org.xml.sax.SAXException; - -import com.google.gson.JsonObject; -import com.microsoft.Malmo.IState; -import com.microsoft.Malmo.MalmoMod; -import com.microsoft.Malmo.MalmoMod.IMalmoMessageListener; -import com.microsoft.Malmo.MalmoMod.MalmoMessageType; -import com.microsoft.Malmo.StateEpisode; -import com.microsoft.Malmo.StateMachine; -import com.microsoft.Malmo.Client.MalmoModClient.InputType; -import com.microsoft.Malmo.MissionHandlerInterfaces.IVideoProducer; -import com.microsoft.Malmo.MissionHandlerInterfaces.IWantToQuit; -import com.microsoft.Malmo.MissionHandlers.MissionBehaviour; -import com.microsoft.Malmo.MissionHandlers.MultidimensionalReward; -import com.microsoft.Malmo.Schemas.AgentSection; -import com.microsoft.Malmo.Schemas.AgentStart; -import com.microsoft.Malmo.Schemas.ClientAgentConnection; -import com.microsoft.Malmo.Schemas.MinecraftServerConnection; -import com.microsoft.Malmo.Schemas.Mission; -import com.microsoft.Malmo.Schemas.MissionDiagnostics; -import com.microsoft.Malmo.Schemas.MissionEnded; -import com.microsoft.Malmo.Schemas.MissionInit; -import com.microsoft.Malmo.Schemas.MissionResult; -import com.microsoft.Malmo.Schemas.Reward; -import com.microsoft.Malmo.Schemas.ModSettings; -import com.microsoft.Malmo.Schemas.PosAndDirection; -import com.microsoft.Malmo.Utils.AddressHelper; -import com.microsoft.Malmo.Utils.AuthenticationHelper; -import com.microsoft.Malmo.Utils.SchemaHelper; -import com.microsoft.Malmo.Utils.ScreenHelper; -import com.microsoft.Malmo.Utils.SeedHelper; -import com.microsoft.Malmo.Utils.ScoreHelper; -import com.microsoft.Malmo.Utils.TextureHelper; -import com.microsoft.Malmo.Utils.ScreenHelper.TextCategory; -import com.microsoft.Malmo.Utils.TCPInputPoller; -import 
com.microsoft.Malmo.Utils.TCPInputPoller.CommandAndIPAddress; -import com.microsoft.Malmo.Utils.TimeHelper.SyncTickEvent; -import com.microsoft.Malmo.Utils.TCPSocketChannel; -import com.microsoft.Malmo.Utils.TCPUtils; -import com.microsoft.Malmo.Utils.TimeHelper; -import com.mojang.authlib.properties.Property; - -/** - * Class designed to track and control the state of the mod, especially regarding mission launching/running.
- * States are defined by the MissionState enum, and control is handled by - * MissionStateEpisode subclasses. The ability to set the state directly is - * restricted, but hooks such as onPlayerReadyForMission etc are exposed to - * allow subclasses to react to certain state changes.
- * The ProjectMalmo mod app class inherits from this and uses these hooks to run missions. - */ -public class ClientStateMachine extends StateMachine implements IMalmoMessageListener -{ - // AOG - Dropped from 2000 to 1000 to speed up detection of failed server restarts - private static final int WAIT_MAX_TICKS = 1000; // Over 1 minute and a half in client ticks. - private static final int VIDEO_MAX_WAIT = 90 * 1000; // Max wait for video in ms. - private static final String MISSING_MCP_PORT_ERROR = "no_mcp"; - private static final String INFO_MCP_PORT = "info_mcp"; - private static final String INFO_RESERVE_STATUS = "info_reservation"; - - private MissionInit currentMissionInit = null; // The MissionInit object for the mission currently being loaded/run. - private MissionBehaviour missionBehaviour = new MissionBehaviour(); - private String missionQuitCode = ""; // The reason why this mission ended. - private MultidimensionalReward finalReward = new MultidimensionalReward(true); // The reward at the end of the mission, sent separately to ensure timely delivery. - private MissionDiagnostics missionEndedData = new MissionDiagnostics(); - private ScreenHelper screenHelper = new ScreenHelper(); - protected MalmoModClient inputController; - - // Env service: - protected MalmoEnvServer envServer; - - // Socket stuff: - protected TCPInputPoller missionPoller; - protected TCPInputPoller controlInputPoller; - protected int integratedServerPort; - String reservationID = ""; // empty if we are not reserved, otherwise "RESERVED" + the experiment ID we are reserved for. 
- long reservationExpirationTime = 0; - private TCPSocketChannel missionControlSocket; - - private void reserveClient(String id) - { - synchronized(this.reservationID) - { - ClientStateMachine.this.getScreenHelper().clearFragment(INFO_RESERVE_STATUS); - - // id is in the form :, where long is the length of time to keep the reservation for, - // and expID is the experimentationID used to ensure the client is reserved for the correct experiment. - int separator = id.indexOf(":"); - if (separator == -1) - { - System.out.println("Error - malformed reservation request - client will not be reserved."); - this.reservationID = ""; - } - else - { - long duration = Long.valueOf(id.substring(0, separator)); - String expID = id.substring(separator + 1); - this.reservationExpirationTime = System.currentTimeMillis() + duration; - // We don't just use the id, in case users have supplied a blank string as their experiment ID. - this.reservationID = "RESERVED" + expID; - ClientStateMachine.this.getScreenHelper().addFragment("Reserved: " + expID, TextCategory.TXT_INFO, (int)duration);//INFO_RESERVE_STATUS); - } - } - } - - private boolean isReserved() - { - synchronized(this.reservationID) - { - System.out.println("==== RES: " + this.reservationID + " - " + (this.reservationExpirationTime - System.currentTimeMillis())); - return !this.reservationID.isEmpty() && this.reservationExpirationTime > System.currentTimeMillis(); - } - } - - private boolean isAvailable(String id) - { - synchronized(this.reservationID) - { - return (this.reservationID.isEmpty() || this.reservationID.equals("RESERVED" + id) || System.currentTimeMillis() >= this.reservationExpirationTime); - } - } - - private void cancelReservation() - { - synchronized(this.reservationID) - { - this.reservationID = ""; - ClientStateMachine.this.getScreenHelper().clearFragment(INFO_RESERVE_STATUS); - } - } - - protected TCPSocketChannel getMissionControlSocket() { return this.missionControlSocket; } - - protected void 
createMissionControlSocket() - { - TCPUtils.LogSection ls = new TCPUtils.LogSection("Creating MissionControlSocket"); - // Set up a TCP connection to the agent: - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - if (this.missionControlSocket == null || - this.missionControlSocket.getPort() != cac.getAgentMissionControlPort() || - this.missionControlSocket.getAddress() == null || - !this.missionControlSocket.isValid() || - !this.missionControlSocket.isOpen() || - !this.missionControlSocket.getAddress().equals(cac.getAgentIPAddress())) - { - if (this.missionControlSocket != null) - this.missionControlSocket.close(); - this.missionControlSocket = new TCPSocketChannel(cac.getAgentIPAddress(), cac.getAgentMissionControlPort(), "mcp"); - } - ls.close(); - } - - public ClientStateMachine(ClientState initialState, MalmoModClient inputController) - { - super(initialState); - this.inputController = inputController; - - // Register ourself on the event busses, so we can harness the client tick: - MinecraftForge.EVENT_BUS.register(this); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_TEXT); - } - - @Override - public void clearErrorDetails() - { - super.clearErrorDetails(); - this.missionQuitCode = ""; - } - - @SubscribeEvent - public void onClientTick(TickEvent.ClientTickEvent ev) - { - // Use the client tick to ensure we regularly update our state (from the client thread) - updateState(); - } - - public ScreenHelper getScreenHelper() - { - return screenHelper; - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - if (messageType == MalmoMessageType.SERVER_TEXT) - { - String chat = data.get("chat"); - if (chat != null) - Minecraft.getMinecraft().ingameGUI.getChatGUI().printChatMessageWithOptionalDeletion(new TextComponentString(chat), 1); - else - { - String text = data.get("text"); - ScreenHelper.TextCategory category = ScreenHelper.TextCategory.valueOf(data.get("category")); - 
String strtime = data.get("displayTime"); - Integer time = (strtime != null) ? Integer.valueOf(strtime) : null; - this.getScreenHelper().addFragment(text, category, time); - } - } - } - - @Override - protected String getName() - { - return "CLIENT"; - } - - @Override - protected void onPreStateChange(IState toState) - { - this.getScreenHelper().addFragment("CLIENT: " + toState, ScreenHelper.TextCategory.TXT_CLIENT_STATE, ""); - } - - /** - * Create the episode object for the requested state. - * - * @param state the state the mod is entering - * @return a MissionStateEpisode that localises all the logic required to run this state - */ - @Override - protected StateEpisode getStateEpisodeForState(IState state) - { - if (!(state instanceof ClientState)) - return null; - - ClientState cs = (ClientState) state; - switch (cs) { - case WAITING_FOR_MOD_READY: - return new InitialiseClientModEpisode(this); - case DORMANT: - return new DormantEpisode(this); - case CREATING_HANDLERS: - return new CreateHandlersEpisode(this); - case EVALUATING_WORLD_REQUIREMENTS: - return new EvaluateWorldRequirementsEpisode(this); - case PAUSING_OLD_SERVER: - return new PauseOldServerEpisode(this); - case CLOSING_OLD_SERVER: - return new CloseOldServerEpisode(this); - case CREATING_NEW_WORLD: - return new CreateWorldEpisode(this); - case WAITING_FOR_SERVER_READY: - return new WaitingForServerEpisode(this); - case RUNNING: - return new MissionRunningEpisode(this); - case IDLING: - return new MissionIdlingEpisode(this); - case MISSION_ENDED: - return new MissionEndedEpisode(this, MissionResult.ENDED, false, false, true); - case ERROR_DUFF_HANDLERS: - return new MissionEndedEpisode(this, MissionResult.MOD_FAILED_TO_INSTANTIATE_HANDLERS, true, true, true); - case ERROR_INTEGRATED_SERVER_UNREACHABLE: - return new MissionEndedEpisode(this, MissionResult.MOD_SERVER_UNREACHABLE, true, true, true); - case ERROR_NO_WORLD: - return new MissionEndedEpisode(this, MissionResult.MOD_HAS_NO_WORLD_LOADED, 
true, true, true); - case ERROR_CANNOT_CREATE_WORLD: - return new MissionEndedEpisode(this, MissionResult.MOD_FAILED_TO_CREATE_WORLD, true, true, true); - case ERROR_CANNOT_START_AGENT: // run-ons deliberate - case ERROR_LOST_AGENT: - case ERROR_LOST_VIDEO: - return new MissionEndedEpisode(this, MissionResult.MOD_HAS_NO_AGENT_AVAILABLE, true, true, false); - case ERROR_LOST_NETWORK_CONNECTION: // run-on deliberate - case ERROR_CANNOT_CONNECT_TO_SERVER: - return new MissionEndedEpisode(this, MissionResult.MOD_CONNECTION_FAILED, true, false, true); // No point trying to inform the server - we can't reach it anyway! - case ERROR_TIMED_OUT_WAITING_FOR_EPISODE_START: // run-ons deliberate - case ERROR_TIMED_OUT_WAITING_FOR_EPISODE_PAUSE: - case ERROR_TIMED_OUT_WAITING_FOR_EPISODE_CLOSE: - case ERROR_TIMED_OUT_WAITING_FOR_MISSION_END: - case ERROR_TIMED_OUT_WAITING_FOR_WORLD_CREATE: - return new MissionEndedEpisode(this, MissionResult.MOD_CONNECTION_FAILED, true, true, true); - case MISSION_ABORTED: - return new MissionEndedEpisode(this, MissionResult.MOD_SERVER_ABORTED_MISSION, true, false, true); // Don't inform the server - it already knows (we're acting on its notification) - case WAITING_FOR_SERVER_MISSION_END: - return new WaitingForServerMissionEndEpisode(this); - default: - break; - } - return null; - } - - protected MissionInit currentMissionInit() - { - return this.currentMissionInit; - } - - protected MissionBehaviour currentMissionBehaviour() - { - return this.missionBehaviour; - } - - protected class MissionInitResult - { - public MissionInit missionInit = null; - public boolean wasMissionInit = false; - public String error = null; - } - - protected MissionInitResult decodeMissionInit(String command) - { - MissionInitResult result = new MissionInitResult(); - if (command == null) - { - result.error = "Null command passed."; - return result; - } - - String rootNodeName = SchemaHelper.getRootNodeName(command); - if (rootNodeName != null && 
rootNodeName.equals("MissionInit")) - { - result.wasMissionInit = true; - // Attempt to decode the MissionInit XML string. - try - { - result.missionInit = (MissionInit) SchemaHelper.deserialiseObject(command, "MissionInit.xsd", MissionInit.class); - } - catch (JAXBException e) - { - System.out.println("JAXB exception: " + e); - if (e.getMessage() != null) - result.error = e.getMessage(); - else if (e.getLinkedException() != null && e.getLinkedException().getMessage() != null) - result.error = e.getLinkedException().getMessage(); - else - result.error = "Unspecified problem parsing MissionInit - check your Mission xml."; - } - catch (SAXException e) - { - System.out.println("SAX exception: " + e); - result.error = e.getMessage(); - } - catch (XMLStreamException e) - { - System.out.println("XMLStreamException: " + e); - result.error = e.getMessage(); - } - } - return result; - } - - protected boolean areMissionsEqual(Mission m1, Mission m2) - { - return true; - // FIX NEEDED - the following code fails because m1 may have been - // modified since loading - eg the MazeDecorator writes directly to the XML, - // and the use of some of the getters in the XSD-generated code can cause extra - // (empty) nodes to be added to the resulting XML. - // We need a more robust way of comparing two mission objects. - // For now, simply return true, since a false positive is less dangerous - // than a false negative. - /* - try { - String s1 = SchemaHelper.serialiseObject(m1, Mission.class); - String s2 = SchemaHelper.serialiseObject(m2, Mission.class); - return s1.compareTo(s2) == 0; - } catch( JAXBException e ) { - System.out.println("JAXB exception: " + e); - return false; - }*/ - } - - /** - * Set up the mission poller.
- * This is called during the initialisation episode, but also needs to be - * available for other episodes in case the configuration changes, resulting - * in changes to the ports. - * - * @throws UnknownHostException - */ - protected void initialiseComms() throws UnknownHostException - { - // Start polling for missions: - if (this.missionPoller != null) - { - this.missionPoller.stopServer(); - } - - this.missionPoller = new TCPInputPoller(AddressHelper.getMissionControlPortOverride(), AddressHelper.MIN_MISSION_CONTROL_PORT, AddressHelper.MAX_FREE_PORT, true, "mcp") - { - @Override - public void onError(String error, DataOutputStream dos) - { - System.out.println("SENDING ERROR: " + error); - try - { - dos.writeInt(error.length()); - dos.writeBytes(error); - dos.flush(); - } - catch (IOException e) - { - } - } - - private void reply(String reply, DataOutputStream dos) - { - System.out.println("REPLYING WITH: " + reply); - try - { - dos.writeInt(reply.length()); - dos.writeBytes(reply); - dos.flush(); - } - catch (IOException e) - { - System.out.println("Failed to reply to message!"); - } - } - - @Override - public boolean onCommand(String command, String ipFrom, DataOutputStream dos) - { - System.out.println("Received from " + ipFrom + ":" + - command.substring(0, Math.min(command.length(), 1024))); - boolean keepProcessing = false; - - // Possible commands: - // 1: MALMO_REQUEST_CLIENT:: - // 2: MALMO_CANCEL_REQUEST - // 3: MALMO_FIND_SERVER - // 4: MALMO_KILL_CLIENT - // 5: MissionInit - - String reservePrefixGeneral = "MALMO_REQUEST_CLIENT:"; - String reservePrefix = reservePrefixGeneral + Loader.instance().activeModContainer().getVersion() + ":"; - String findServerPrefix = "MALMO_FIND_SERVER"; - String cancelRequestCommand = "MALMO_CANCEL_REQUEST"; - String killClientCommand = "MALMO_KILL_CLIENT"; - - if (command.startsWith(reservePrefix)) - { - // Reservation request. - // We either reply with MALMOOK, if we are free, or MALMOBUSY if not. 
- IState currentState = getStableState(); - if (currentState != null && currentState.equals(ClientState.DORMANT) && !isReserved()) - { - reserveClient(command.substring(reservePrefix.length())); - reply("MALMOOK", dos); - } - else - { - // We're busy - we can't be reserved. - reply("MALMOBUSY", dos); - } - } - else if (command.startsWith(reservePrefixGeneral)) - { - // Reservation request, but it didn't match the request we expect, above. - // This happens if the agent sending the request is running a different version of Malmo - - // a version mismatch error. - reply("MALMOERRORVERSIONMISMATCH in reservation string (Got " + command + ", expected " + reservePrefix + " - check your path for old versions of MalmoPython/MalmoJava/Malmo.lib etc)", dos); - } - else if (command.equals(cancelRequestCommand)) - { - // If we've been reserved, cancel the reservation. - if (isReserved()) - { - cancelReservation(); - reply("MALMOOK", dos); - } - else - { - // We weren't reserved in the first place - something is odd. - reply("MALMOERRORAttempt to cancel a reservation that was never made.", dos); - } - } - else if (command.startsWith(findServerPrefix)) - { - // Request to find the server for the given experiment ID. - String expID = command.substring(findServerPrefix.length()); - if (currentMissionInit() != null && currentMissionInit().getExperimentUID().equals(expID)) - { - // Our Experiment IDs match, so we are running the same experiment. - // Return the port and server IP address to the caller: - MinecraftServerConnection msc = currentMissionInit().getMinecraftServerConnection(); - if (msc == null) - reply("MALMONOSERVERYET", dos); // Mission might be starting up. - else - reply("MALMOS" + msc.getAddress().trim() + ":" + msc.getPort(), dos); - } - else - { - // We don't have a MissionInit ourselves, or we're running a different experiment, - // so we can't help. 
- reply("MALMONOSERVER", dos); - } - } - else if (command.equals(killClientCommand)) - { - // Kill switch provided in case AI takes over the world... - // Or, more likely, in case this Minecraft instance has become unreliable (eg if it's been running for several days) - // and needs to be replaced with a fresh instance. - // If we are currently running a mission, we gracefully decline, to prevent users from wiping out - // other users' experiments. - // We also decline unless we were launched in "replaceable" mode - a command-line switch that indicates we were - // launched by a script which is still running, and can therefore replace us when we terminate. - IState currentState = getStableState(); - if (currentState != null && currentState.equals(ClientState.DORMANT) && !isReserved()) - { - Configuration config = MalmoMod.instance.getModSessionConfigFile(); - if (config.getBoolean("replaceable", "runtype", false, "Will be replaced if killed")) - { - reply("MALMOOK", dos); - - missionPoller.stopServer(); - exitJava(); - } - else - { - reply("MALMOERRORNOTKILLABLE", dos); - } - } - else - { - // We're too busy and important to be killed. - reply("MALMOBUSY", dos); - } - } - else - { - // See if we've been sent a MissionInit message: - - MissionInitResult missionInitResult = decodeMissionInit(command); - - if (missionInitResult.wasMissionInit && missionInitResult.missionInit == null) - { - // Got sent a duff MissionInit xml - pass back the JAXB/SAXB errors. - reply("MALMOERROR" + missionInitResult.error, dos); - } - else if (missionInitResult.wasMissionInit && missionInitResult.missionInit != null) - { - MissionInit missionInit = missionInitResult.missionInit; - // We've been sent a MissionInit message. 
- // First, check the version number: - String platformVersion = missionInit.getPlatformVersion(); - String ourVersion = Loader.instance().activeModContainer().getVersion(); - if (platformVersion == null || !platformVersion.equals(ourVersion)) - { - reply("MALMOERRORVERSIONMISMATCH (Got " + platformVersion + ", expected " + ourVersion + " - check your path for old versions of MalmoPython/MalmoJava/Malmo.lib etc)", dos); - } - else - { - // MissionInit passed to us - this is a request to launch this mission. Can we? - IState currentState = getStableState(); - if (currentState != null && currentState.equals(ClientState.DORMANT) && isAvailable(missionInit.getExperimentUID())) - { - reply("MALMOOK", dos); - keepProcessing = true; // State machine will now process this MissionInit and start the mission. - } - else - { - // We're busy - we can't run this mission. - reply("MALMOBUSY", dos); - } - } - } - } - - return keepProcessing; - } - }; - - int mcPort = 0; - if (MalmoEnvServer.isEnv()) { - // Start up new "Env" service instead of Malmo AgentHost api. - System.out.println("***** Start MalmoEnvServer on port " + AddressHelper.getMissionControlPortOverride()); - this.envServer = new MalmoEnvServer(Loader.instance().activeModContainer().getVersion(), AddressHelper.getMissionControlPortOverride(), this.missionPoller); - Thread thread = new Thread("MalmoEnvServer") { - public void run() { - try { - envServer.serve(); - } catch (IOException ioe) { - System.out.println("MalmoEnvServer exist on " + ioe); - } - } - }; - thread.start(); - } else { - // "Legacy" AgentHost api. - this.missionPoller.start(); - mcPort = ClientStateMachine.this.missionPoller.getPortBlocking(); - } - - // Tell the address helper what the actual port is: - AddressHelper.setMissionControlPort(mcPort); - if (AddressHelper.getMissionControlPort() == -1) - { - // Failed to create a mission control port - nothing will work! - System.out.println("**** NO MISSION CONTROL SOCKET CREATED - WAS THE PORT IN USE? 
(Check Mod GUI options) ****"); - ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Could not open a Mission Control Port - check the Mod GUI options.", TextCategory.TXT_CLIENT_WARNING, MISSING_MCP_PORT_ERROR); - } - else - { - // Clear the error string, if there was one: - ClientStateMachine.this.getScreenHelper().clearFragment(MISSING_MCP_PORT_ERROR); - } - // Display the port number: - ClientStateMachine.this.getScreenHelper().clearFragment(INFO_MCP_PORT); - if (AddressHelper.getMissionControlPort() != -1) - ClientStateMachine.this.getScreenHelper().addFragment("MCP: " + AddressHelper.getMissionControlPort(), TextCategory.TXT_INFO, INFO_MCP_PORT); - } - - public static void exitJava() { - // Give non-hard exit 10 seconds to complete and force a hard exit. - Thread deadMansHandle = new Thread(new Runnable() { - @Override - public void run() { - for (int i = 10; i > 0; i--) { - try { - Thread.sleep(1000); - System.out.println("Waiting to exit " + i + "..."); - } catch (InterruptedException e) { - System.out.println("Interrupted " + i + "..."); - } - } - - // Kill it with fire!!! - System.out.println("Attempting hard exit"); - FMLCommonHandler.instance().exitJava(0, true); - } - }); - - deadMansHandle.setDaemon(true); - deadMansHandle.start(); - - // Have to use FMLCommonHandler; direct calls to System.exit() are trapped and denied by the FML code. 
- FMLCommonHandler.instance().exitJava(0, false); - } - - // --------------------------------------------------------------------------------------------------------- - // Episode helpers - each extends a MissionStateEpisode to encapsulate a certain state - // --------------------------------------------------------------------------------------------------------- - - public abstract class ErrorAwareEpisode extends StateEpisode implements IMalmoMessageListener - { - protected Boolean errorFlag = false; - protected Map errorData = null; - - public ErrorAwareEpisode(ClientStateMachine machine) - { - super(machine); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_ABORT); - } - - protected boolean pingAgent(boolean abortIfFailed) - { - if (AddressHelper.getMissionControlPort() == 0) { - // MalmoEnvServer has no server to client ping. - return true; - } - - boolean sentOkay = ClientStateMachine.this.getMissionControlSocket().sendTCPString("", 1); - if (!sentOkay) - { - // It's not available - bail. 
- ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Lost contact with agent - aborting mission", TextCategory.TXT_CLIENT_WARNING, 10000); - if (abortIfFailed) - episodeHasCompletedWithErrors(ClientState.ERROR_LOST_AGENT, "Lost contact with the agent"); - } - return sentOkay; - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - if (messageType == MalmoMod.MalmoMessageType.SERVER_ABORT) - { - synchronized (this.errorFlag) - { - this.errorFlag = true; - this.errorData = data; - // Save the error message, if there is one: - if (data != null) - { - String message = data.get("message"); - String user = data.get("username"); - String error = data.get("error"); - String report = ""; - if (user != null) - report += "From " + user + ": "; - if (error != null) - report += error; - if (message != null) - report += " (" + message + ")"; - ClientStateMachine.this.saveErrorDetails(report); - } - onAbort(data); - } - } - } - - @Override - public void cleanup() - { - super.cleanup(); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, MalmoMessageType.SERVER_ABORT); - } - - protected boolean inAbortState() - { - synchronized (this.errorFlag) - { - return this.errorFlag; - } - } - - protected Map getErrorData() - { - synchronized (this.errorFlag) - { - return this.errorData; - } - } - - protected void onAbort(Map errorData) - { - // Default does nothing, but can be overridden. - } - } - - /** - * Helper base class that responds to the config change and updates our AddressHelper.
- * This will also reset the mission poller. Depending on the state, more - * work may be needed (eg to recreate the command handler, etc) - it's up to - * the individual state episodes to do whatever else needs doing. - */ - abstract public class ConfigAwareStateEpisode extends ErrorAwareEpisode - { - ConfigAwareStateEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - public void onConfigChanged(OnConfigChangedEvent ev) - { - if (ev.getConfigID().equals(MalmoMod.SOCKET_CONFIGS)) - { - AddressHelper.update(MalmoMod.instance.getModSessionConfigFile()); - try - { - ClientStateMachine.this.initialiseComms(); - } - catch (UnknownHostException e) - { - // TODO What to do here? - e.printStackTrace(); - } - ScreenHelper.update(MalmoMod.instance.getModPermanentConfigFile()); - TCPUtils.update(MalmoMod.instance.getModPermanentConfigFile()); - } - } - } - - /** Initial episode - perform client setup */ - public class InitialiseClientModEpisode extends ConfigAwareStateEpisode - { - InitialiseClientModEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() throws Exception - { - ClientStateMachine.this.initialiseComms(); - - // This is necessary in order to allow user to exit the Minecraft window without halting the experiment: - GameSettings settings = Minecraft.getMinecraft().gameSettings; - settings.pauseOnLostFocus = false; - // And hook the screen helper into the ingame gui (which is responsible for overlaying chat, titles etc) - - // this has to be done after Minecraft.init(), so we do it here. - ScreenHelper.hookIntoInGameGui(); - } - - @Override - public void onRenderTick(TickEvent.RenderTickEvent ev) - { - // We wait until we start to get render ticks, at which point we assume Minecraft has finished starting up. 
- episodeHasCompleted(ClientState.DORMANT); - } - } - - // --------------------------------------------------------------------------------------------------------- - /** Dormant state - receptive to new missions */ - public class DormantEpisode extends ConfigAwareStateEpisode - { - private ClientStateMachine csMachine; - - protected DormantEpisode(ClientStateMachine machine) - { - super(machine); - this.csMachine = machine; - } - - @Override - protected void execute() - { - TextureHelper.init(); - - // Clear our current MissionInit state: - csMachine.currentMissionInit = null; - // Clear our current error state: - clearErrorDetails(); - // And clear out any stale commands left over from recent missions: - if (ClientStateMachine.this.controlInputPoller != null) - ClientStateMachine.this.controlInputPoller.clearCommands(); - // Finally, do some Java housekeeping: - System.gc(); - } - - @Override - public void onClientTick(TickEvent.ClientTickEvent ev) throws Exception - { - - Minecraft.getMinecraft().mcProfiler.startSection("malmoHandleMissionCommands"); - checkForMissionCommand(); - Minecraft.getMinecraft().mcProfiler.endSection(); - } - - private void checkForMissionCommand() throws Exception - { - if (ClientStateMachine.this.missionPoller == null) - return; - - CommandAndIPAddress comip = missionPoller.getCommandAndIPAddress(); - if (comip == null) - return; - String missionMessage = comip.command; - if (missionMessage == null || missionMessage.length() == 0) - return; - Minecraft.getMinecraft().mcProfiler.endSection(); - Minecraft.getMinecraft().mcProfiler.startSection("malmoDecodeMissionInit"); - - MissionInitResult missionInitResult = decodeMissionInit(missionMessage); - Minecraft.getMinecraft().mcProfiler.endSection(); - - MissionInit missionInit = missionInitResult.missionInit; - if (missionInit != null) - { - missionInit.getClientAgentConnection().setAgentIPAddress(comip.ipAddress); - System.out.println("Mission received: " + 
missionInit.getMission().getAbout().getSummary()); - csMachine.currentMissionInit = missionInit; - TimeHelper.SyncManager.numTicks = 0; - ScoreHelper.logMissionInit(missionInit); - - ClientStateMachine.this.createMissionControlSocket(); - // Move on to next state: - episodeHasCompleted(ClientState.CREATING_HANDLERS); - } - else - { - throw new Exception("Failed to get valid MissionInit object from SchemaHelper."); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Now the MissionInit XML has been decoded, the client needs to create the - * Mission Handlers. - */ - public class CreateHandlersEpisode extends ConfigAwareStateEpisode - { - protected CreateHandlersEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() throws Exception - { - // First, clear our reservation state, if we were reserved: - ClientStateMachine.this.cancelReservation(); - - // Now try creating the handlers: - try - { - if(envServer != null){ - SeedHelper.advanceNextSeed(envServer.getSeed()); - } - ClientStateMachine.this.missionBehaviour = MissionBehaviour.createAgentHandlersFromMissionInit(currentMissionInit()); - if (envServer != null) { - ClientStateMachine.this.missionBehaviour.addQuitProducer(envServer); - } - } - catch (Exception e) - { - // TODO - System.err.println("ERROR: Exception caught making agent handlers" + e.toString()); - e.printStackTrace(); - } - // Set up our command input poller. This is only checked during the MissionRunning episode, but - // it needs to be started now, so we can report the port it's using back to the agent. - TCPUtils.LogSection ls = new TCPUtils.LogSection("Initialise Command Input Poller"); - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - int requestedPort = cac.getClientCommandsPort(); - // If the requested port is 0, we dynamically allocate our own port, and feed that back to the agent. 
- // If the requested port is non-zero, we have to use it. - if (requestedPort != 0 && ClientStateMachine.this.controlInputPoller != null && ClientStateMachine.this.controlInputPoller.getPort() != requestedPort) - { - // A specific port has been requested, and it's not the one we are currently using, - // so we need to recreate our poller. - System.out.println("Requested command port is not the same as the input poller port; the port was not free. Stopping server."); - ClientStateMachine.this.controlInputPoller.stopServer(); - ClientStateMachine.this.controlInputPoller = null; - } - if (ClientStateMachine.this.controlInputPoller == null) - { - if (requestedPort == 0) - ClientStateMachine.this.controlInputPoller = new TCPInputPoller(AddressHelper.MIN_FREE_PORT, AddressHelper.MAX_FREE_PORT, true, "com"); - else - ClientStateMachine.this.controlInputPoller = new TCPInputPoller(requestedPort, "com"); - System.out.println("Starting command server."); - ClientStateMachine.this.controlInputPoller.start(); - } - // Make sure the cac is up-to-date: - cac.setClientCommandsPort(ClientStateMachine.this.controlInputPoller.getPortBlocking()); - ls.close(); - - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - // Set the agent's name as the current username: - List agents = currentMissionInit().getMission().getAgentSection(); - String agentName = agents.get(currentMissionInit().getClientRole()).getName(); - AuthenticationHelper.setPlayerName(Minecraft.getMinecraft().getSession(), agentName); - // If the player's profile properties are empty, MC will keep pinging the Minecraft session service - // to fill them, resulting in multiple http requests and grumpy responses from the server - // (see https://github.com/Microsoft/malmo/issues/568). - // To prevent this, we add a dummy property. 
- Minecraft.getMinecraft().getProfileProperties().put("dummy", new Property("dummy", "property")); - // Handlers and poller created successfully; proceed to next stage of loading. - // We will either need to connect to an existing server, or to start - // a new integrated server ourselves, depending on our role. - // For now, assume that the mod with role 0 is responsible for the server. - if (currentMissionInit().getClientRole() == 0) - { - // We are responsible for the server - investigate what needs to happen next: - episodeHasCompleted(ClientState.EVALUATING_WORLD_REQUIREMENTS); - } - else - { - // We may need to connect to a server. - episodeHasCompleted(ClientState.WAITING_FOR_SERVER_READY); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Attempt to connect to a server. Wait until connection is established. - */ - public class WaitingForServerEpisode extends ConfigAwareStateEpisode - { - String agentName; - int ticksUntilNextPing = 0; - int totalTicks = 0; - boolean waitingForChunk = false; - boolean waitingForPlayer = true; - - protected WaitingForServerEpisode(ClientStateMachine machine) - { - super(machine); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_ALLPLAYERSJOINED); - } - - private boolean isChunkReady() - { - // First, find the starting position we ought to have: - List agents = currentMissionInit().getMission().getAgentSection(); - if (agents == null || agents.size() <= currentMissionInit().getClientRole()) - return true; // This should never happen. 
- AgentSection as = agents.get(currentMissionInit().getClientRole()); - if (as.getAgentStart() != null && as.getAgentStart().getPlacement() != null) - { - PosAndDirection pos = as.getAgentStart().getPlacement(); - int x = MathHelper.floor(pos.getX().doubleValue()) >> 4; - int z = MathHelper.floor(pos.getZ().doubleValue()) >> 4; - // Now get the chunk we should be starting in: - IChunkProvider chunkprov = Minecraft.getMinecraft().world.getChunkProvider(); - EntityPlayerSP player = Minecraft.getMinecraft().player; - if (player.addedToChunk) - { - // Our player is already added to a chunk - is it the right one? - Chunk actualChunk = chunkprov.provideChunk(player.chunkCoordX, player.chunkCoordZ); - Chunk requestedChunk = chunkprov.provideChunk(x, z); - if (actualChunk == requestedChunk && actualChunk != null && !actualChunk.isEmpty()) - { - // We're in the right chunk, and it's not an empty chunk. - // We're ready to proceed, but first set our client positions to where we ought to be. - // The server should be doing this too, but there's no harm (probably) in doing it ourselves. - player.posX = pos.getX().doubleValue(); - player.posY = pos.getY().doubleValue(); - player.posZ = pos.getZ().doubleValue(); - return true; - } - } - return false; // Our starting position has been specified, but it's not yet ready. - } - return true; // No starting position specified, so doesn't matter where we start. - } - - @Override - protected void onClientTick(ClientTickEvent ev) - { - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - if (this.waitingForPlayer) - { - if (Minecraft.getMinecraft().player != null) - { - this.waitingForPlayer = false; - handleLan(); - } - else - return; - } - - totalTicks++; - - if (ticksUntilNextPing == 0) - { - // Tell the server what our agent name is. - // We do this repeatedly, because the server might not yet be listening. 
- if (Minecraft.getMinecraft().player != null && !this.waitingForChunk) - { - HashMap map = new HashMap(); - map.put("agentname", agentName); - map.put("username", Minecraft.getMinecraft().player.getName()); - currentMissionBehaviour().appendExtraServerInformation(map); - System.out.println("***Telling server we are ready - " + agentName); - MalmoMod.network.sendToServer(new MalmoMod.MalmoMessage(MalmoMessageType.CLIENT_AGENTREADY, 0, map)); - } - - // We also ping our agent, just to check it is still available: - pingAgent(true); // Will abort to an error state if client unavailable. - - ticksUntilNextPing = 10; // Try again in ten ticks. - } - else - { - ticksUntilNextPing--; - } - - if (this.waitingForChunk) - { - // The server is ready, we're just waiting for our chunk to appear. - if (isChunkReady()) - proceed(); - } - - List agents = currentMissionInit().getMission().getAgentSection(); - boolean completedWithErrors = false; - - if (agents.size() > 1 && currentMissionInit().getClientRole() != 0) - { - // We are waiting to join an out-of-process server. Need to pay attention to what happens - - // if we can't join, for any reason, we should abort the mission. - GuiScreen screen = Minecraft.getMinecraft().currentScreen; - if (screen != null && screen instanceof GuiDisconnected) { - // Disconnected screen appears when something has gone wrong. - // Would be nice to grab the reason from the screen, but it's a private member. - // (Can always use reflection, but it's so inelegant.) 
- String msg = "Unable to connect to Minecraft server in multi-agent mission."; - TCPUtils.Log(Level.SEVERE, msg); - episodeHasCompletedWithErrors(ClientState.ERROR_CANNOT_CONNECT_TO_SERVER, msg); - completedWithErrors = true; - } - } - - if (!completedWithErrors && totalTicks > WAIT_MAX_TICKS) - { - String msg = "Too long waiting for server episode to start."; - TCPUtils.Log(Level.SEVERE, msg); - // AOG - If we have timed out waiting for the server to be ready, then the - // MalmoEnvServer is also likely stuck trying to handle a peek request from - // Python client. We need to signal the env server should abort the request - // so that the client detects the error and can retry. - if (envServer != null) { - envServer.abort(); - } - episodeHasCompletedWithErrors(ClientState.ERROR_TIMED_OUT_WAITING_FOR_EPISODE_START, msg); - } - } - - @Override - protected void execute() throws Exception - { - totalTicks = 0; - - Minecraft.getMinecraft().displayGuiScreen(null); // Clear any menu screen that might confuse things. - // Get our name from the Mission: - List agents = currentMissionInit().getMission().getAgentSection(); - //if (agents == null || agents.size() <= currentMissionInit().getClientRole()) - // throw new Exception("No agent section for us!"); // TODO - this.agentName = agents.get(currentMissionInit().getClientRole()).getName(); - - if (agents.size() > 1 && currentMissionInit().getClientRole() != 0) - { - // Multi-agent mission, we should be joining a server. - // (Unless we are already on the correct server.) 
- String address = currentMissionInit().getMinecraftServerConnection().getAddress().trim(); - int port = currentMissionInit().getMinecraftServerConnection().getPort(); - String targetIP = address + ":" + port; - System.out.println("We should be joining " + targetIP); - EntityPlayerSP player = Minecraft.getMinecraft().player; - boolean namesMatch = (player == null) || Minecraft.getMinecraft().player.getName().equals(this.agentName); - if (!namesMatch) - { - // The name of our agent no longer matches the agent in our game profile - - // safest way to update is to log out and back in again. - // This hangs so just warn instead about the miss-match and proceed. - TCPUtils.Log(Level.WARNING,"Agent name does not match agent in game."); - // Minecraft.getMinecraft().world.sendQuittingDisconnectingPacket(); - // Minecraft.getMinecraft().loadWorld((WorldClient)null); - } - if (Minecraft.getMinecraft().getCurrentServerData() == null || !Minecraft.getMinecraft().getCurrentServerData().serverIP.equals(targetIP)) - { - net.minecraftforge.fml.client.FMLClientHandler.instance().connectToServerAtStartup(address, port); - } - this.waitingForPlayer = false; - } - } - - protected void handleLan() - { - // Get our name from the Mission: - List agents = currentMissionInit().getMission().getAgentSection(); - //if (agents == null || agents.size() <= currentMissionInit().getClientRole()) - // throw new Exception("No agent section for us!"); // TODO - this.agentName = agents.get(currentMissionInit().getClientRole()).getName(); - - if (agents.size() > 1 && currentMissionInit().getClientRole() == 0) // Multi-agent mission - make sure the server is open to the LAN: - { - MinecraftServerConnection msc = new MinecraftServerConnection(); - String address = currentMissionInit().getClientAgentConnection().getClientIPAddress(); - // Do we need to open to LAN? 
- if (Minecraft.getMinecraft().isSingleplayer() && !Minecraft.getMinecraft().getIntegratedServer().getPublic()) - { - String portstr = Minecraft.getMinecraft().getIntegratedServer().shareToLAN(GameType.SURVIVAL, true); // Set to true to stop spam kicks. - ClientStateMachine.this.integratedServerPort = Integer.valueOf(portstr); - } - - TCPUtils.Log(Level.INFO,"Integrated server port: " + ClientStateMachine.this.integratedServerPort); - msc.setPort(ClientStateMachine.this.integratedServerPort); - msc.setAddress(address); - - if (envServer != null) { - envServer.notifyIntegrationServerStarted(ClientStateMachine.this.integratedServerPort); - } - currentMissionInit().setMinecraftServerConnection(msc); - } - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - super.onMessage(messageType, data); - - if (messageType != MalmoMessageType.SERVER_ALLPLAYERSJOINED) - return; - - List handlers = new ArrayList(); - for (Entry entry : data.entrySet()) - { - if (entry.getKey().equals("startPosition")) - { - try - { - String[] parts = entry.getValue().split(":"); - Float x = Float.valueOf(parts[0]); - Float y = Float.valueOf(parts[1]); - Float z = Float.valueOf(parts[2]); - // Find the starting position we ought to have: - List agents = currentMissionInit().getMission().getAgentSection(); - if (agents != null && agents.size() > currentMissionInit().getClientRole()) - { - // And write this new position into it: - AgentSection as = agents.get(currentMissionInit().getClientRole()); - AgentStart startSection = as.getAgentStart(); - if (startSection != null) - { - PosAndDirection pos = startSection.getPlacement(); - if (pos == null) - pos = new PosAndDirection(); - pos.setX(new BigDecimal(x)); - pos.setY(new BigDecimal(y)); - pos.setZ(new BigDecimal(z)); - startSection.setPlacement(pos); - as.setAgentStart(startSection); - } - } - } - catch (Exception e) - { - System.out.println("Couldn't interpret position data"); - } - } - else - { - String 
extraHandler = entry.getValue(); - if (extraHandler != null && extraHandler.length() > 0) - { - try - { - Class handlerClass = Class.forName(entry.getKey()); - Object handler = SchemaHelper.deserialiseObject(extraHandler, "MissionInit.xsd", handlerClass); - handlers.add(handler); - } - catch (Exception e) - { - System.out.println("Error trying to create extra handlers: " + e); - // Do something... like episodeHasCompletedWithErrors(nextState, error)? - } - } - } - } - if (!handlers.isEmpty()) - currentMissionBehaviour().addExtraHandlers(handlers); - this.waitingForChunk = true; - } - - private void proceed() - { - // The server is ready, so send our MissionInit back to the agent and go! - // We launch the agent by sending it the MissionInit message we were sent - // (but with the Launcher's IP address included) - String xml = null; - boolean sentOkay = false; - String errorReport = ""; - try - { - xml = SchemaHelper.serialiseObject(currentMissionInit(), MissionInit.class); - if (AddressHelper.getMissionControlPort() == 0) { - if (envServer != null) { - // TODO MalmoEnvServer <- Running - } - sentOkay = true; - } else { - sentOkay = ClientStateMachine.this.getMissionControlSocket().sendTCPString(xml, 1); - } - } - catch (JAXBException e) - { - errorReport = e.getMessage(); - } - if (sentOkay) - episodeHasCompleted(ClientState.RUNNING); - else - { - ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Could not contact agent to start mission - mission will abort.", TextCategory.TXT_CLIENT_WARNING, 10000); - if (!errorReport.isEmpty()) - { - ClientStateMachine.this.getScreenHelper().addFragment("ERROR DETAILS: " + errorReport, TextCategory.TXT_CLIENT_WARNING, 10000); - errorReport = ": " + errorReport; - } - episodeHasCompletedWithErrors(ClientState.ERROR_CANNOT_START_AGENT, "Failed to send MissionInit back to agent" + errorReport); - } - } - - @Override - public void cleanup() - { - super.cleanup(); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, 
MalmoMessageType.SERVER_ALLPLAYERSJOINED); - } - } - - /** - * Wait for the server to decide the mission has ended.
- * We're not allowed to return to dormant until the server decides everyone can. - */ - public class WaitingForServerMissionEndEpisode extends ConfigAwareStateEpisode - { - protected WaitingForServerMissionEndEpisode(ClientStateMachine machine) - { - super(machine); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_MISSIONOVER); - } - - @Override - protected void execute() throws Exception - { - // Get our name from the Mission: - List agents = currentMissionInit().getMission().getAgentSection(); - if (agents == null || agents.size() <= currentMissionInit().getClientRole()) - throw new Exception("No agent section for us!"); // TODO - String agentName = agents.get(currentMissionInit().getClientRole()).getName(); - - // Now send a message to the server saying that we are ready: - HashMap map = new HashMap(); - map.put("agentname", agentName); - MalmoMod.network.sendToServer(new MalmoMod.MalmoMessage(MalmoMessageType.CLIENT_AGENTSTOPPED, 0, map)); - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - super.onMessage(messageType, data); - if (messageType == MalmoMessageType.SERVER_MISSIONOVER) - episodeHasCompleted(ClientState.DORMANT); - } - - @Override - public void cleanup() - { - super.cleanup(); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, MalmoMessageType.SERVER_MISSIONOVER); - } - - @Override - protected void onAbort(Map errorData) - { - episodeHasCompleted(ClientState.MISSION_ABORTED); - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Depending on the basemap provided, either begin to perform a full world - * load, or reset the current world - */ - public class EvaluateWorldRequirementsEpisode extends ConfigAwareStateEpisode - { - EvaluateWorldRequirementsEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() - { - // We are responsible for creating the 
server, if required. - // This means we need access to the server's MissionHandlers: - MissionBehaviour serverHandlers = null; - try - { - serverHandlers = MissionBehaviour.createServerHandlersFromMissionInit(currentMissionInit()); - } - catch (Exception e) - { - episodeHasCompletedWithErrors(ClientState.ERROR_DUFF_HANDLERS, "Could not create server mission handlers: " + e.getMessage()); - } - - World world = null; - if (Minecraft.getMinecraft().getIntegratedServer() != null) - world = Minecraft.getMinecraft().getIntegratedServer().getEntityWorld(); - - boolean needsNewWorld = serverHandlers != null && serverHandlers.worldGenerator != null && serverHandlers.worldGenerator.shouldCreateWorld(currentMissionInit(), world); - boolean worldCurrentlyExists = world != null; - if (worldCurrentlyExists) - { - // If a world already exists, we need to check that our requested agent name matches the name - // of the player. If not, the safest thing to do is start a new server. - // Get our name from the Mission: - List agents = currentMissionInit().getMission().getAgentSection(); - String agentName = agents.get(currentMissionInit().getClientRole()).getName(); - if (Minecraft.getMinecraft().player != null) - { - if (!Minecraft.getMinecraft().player.getName().equals(agentName)) - needsNewWorld = true; - } - } - if (needsNewWorld && worldCurrentlyExists) - { - // We want a new world, and there is currently a world running, - // so we need to kill the current world. 
- episodeHasCompleted(ClientState.PAUSING_OLD_SERVER); - } - else if (needsNewWorld && !worldCurrentlyExists) - { - // We want a new world, and there is currently nothing running, - // so jump to world creation: - episodeHasCompleted(ClientState.CREATING_NEW_WORLD); - } - else if (!needsNewWorld && worldCurrentlyExists) - { - // We don't want a new world, and we can use the current one - - // but we own the server, so we need to pass it the new mission init: - Minecraft.getMinecraft().getIntegratedServer().addScheduledTask(new Runnable() - { - @Override - public void run() - { - try - { - MalmoMod.instance.sendMissionInitDirectToServer(currentMissionInit); - } - catch (Exception e) - { - episodeHasCompletedWithErrors(ClientState.ERROR_INTEGRATED_SERVER_UNREACHABLE, "Could not send MissionInit to our integrated server: " + e.getMessage()); - } - } - }); - // Skip all the map loading stuff and go straight to waiting for the server: - episodeHasCompleted(ClientState.WAITING_FOR_SERVER_READY); - } - else if (!needsNewWorld && !worldCurrentlyExists) - { - // Mission has requested no new world, but there is no current world to play in - this is an error: - episodeHasCompletedWithErrors(ClientState.ERROR_NO_WORLD, "We have no world to play in - check that your ServerHandlers section contains a world generator"); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Pause the old server. It's vital that we do this, otherwise it will - * respond to the quit disconnect package straight away and kill the server - * thread, which means there will be no server to respond to the loadWorld - * code. 
- */ - public class PauseOldServerEpisode extends ConfigAwareStateEpisode - { - int serverTickCount = 0; - int clientTickCount = 0; - int totalTicks = 0; - - PauseOldServerEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() - { - serverTickCount = 0; - clientTickCount = 0; - totalTicks = 0; - - if (Minecraft.getMinecraft().getIntegratedServer() != null && Minecraft.getMinecraft().world != null) - { - // If the integrated server has been opened to the LAN, we won't be able to pause it. - // To get around this, we need to make it think it's not open, by modifying its isPublic flag. - if (Minecraft.getMinecraft().getIntegratedServer().getPublic()) - { - if (!killPublicFlag(Minecraft.getMinecraft().getIntegratedServer())) - { - episodeHasCompletedWithErrors(ClientState.ERROR_CANNOT_CREATE_WORLD, "Can not pause the old server since it's open to LAN; no way to safely create new world."); - } - } - - Minecraft.getMinecraft().displayGuiScreen(new GuiIngameMenu()); - } - } - - private boolean killPublicFlag(IntegratedServer server) - { - // Are we in a dev environment? - boolean devEnv = (Boolean) Launch.blackboard.get("fml.deobfuscatedEnvironment"); - // We need to know, because the member name will either be obfuscated or not. - String isPublicMemberName = devEnv ? "isPublic" : "field_71346_p"; - // NOTE: obfuscated name may need updating if Forge changes. 
- Field isPublic; - try - { - isPublic = IntegratedServer.class.getDeclaredField(isPublicMemberName); - isPublic.setAccessible(true); - isPublic.set(server, false); - return true; - } - catch (SecurityException e) - { - e.printStackTrace(); - } - catch (IllegalAccessException e) - { - e.printStackTrace(); - } - catch (IllegalArgumentException e) - { - e.printStackTrace(); - } - catch (NoSuchFieldException e) - { - e.printStackTrace(); - } - return false; - } - - @Override - public void onClientTick(TickEvent.ClientTickEvent ev) - { - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - // We need to make sure that both the client and server have paused. - - // Since the server sets its pause state in response to the client's pause state, - // and it only performs this check once, at the top of its tick method, - // to be sure that the server has had time to set the flag correctly we need to make sure - // that at least one server tick method has *started* since the flag was set. - // We can't do this by catching the onServerTick events, since we don't receive them when the game is paused. - - // The following code makes use of the fact that the server both locks and empties the server's futureQueue, - // every time through the server tick method. - // This locking means that if the client - which needs to wait on the lock - - // tries to add an event to the queue in response to an event on the queue being executed, - // the newly added event will have to happen in a subsequent tick. - if ((Minecraft.getMinecraft().isGamePaused() || Minecraft.getMinecraft().player == null) && ev != null && ev.phase == Phase.END && this.clientTickCount == this.serverTickCount && this.clientTickCount <= 2) - { - this.clientTickCount++; // Increment our count, and wait for the server to catch up. 
- Minecraft.getMinecraft().getIntegratedServer().addScheduledTask(new Runnable() - { - public void run() - { - // Increment the server count. - PauseOldServerEpisode.this.serverTickCount++; - } - }); - } - - if (this.serverTickCount > 2) { - episodeHasCompleted(ClientState.CLOSING_OLD_SERVER); - } else if (++totalTicks > WAIT_MAX_TICKS) { - String msg = "Too long waiting for server episode to pause."; - TCPUtils.Log(Level.SEVERE, msg); - episodeHasCompletedWithErrors(ClientState.ERROR_TIMED_OUT_WAITING_FOR_EPISODE_PAUSE, msg); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Send a disconnecting message to the current server - sent before attempting to load a new world. - */ - public class CloseOldServerEpisode extends ConfigAwareStateEpisode - { - int totalTicks; - - CloseOldServerEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() - { - totalTicks = 0; - - if (Minecraft.getMinecraft().world != null) - { - // If the Minecraft server isn't paused at this point, - // then the following line will cause the server thread to exit... - Minecraft.getMinecraft().world.sendQuittingDisconnectingPacket(); - // ...in which case the next line will block. - Minecraft.getMinecraft().loadWorld((WorldClient) null); - // Must display the GUI or Minecraft will attempt to access a non-existent player in the client tick. - Minecraft.getMinecraft().displayGuiScreen(new GuiMainMenu()); - - // Allow shutdown messages to flow through. - try { - Thread.sleep(10000); - } catch (InterruptedException ie) { - } - } - } - - @Override - public void onClientTick(TickEvent.ClientTickEvent ev) - { - // Check to see whether anything has caused us to abort - if so, go to the abort state. 
- if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - if (ev.phase == Phase.END) - episodeHasCompleted(ClientState.CREATING_NEW_WORLD); - - if (++totalTicks > WAIT_MAX_TICKS) - { - String msg = "Too long waiting for server episode to close."; - TCPUtils.Log(Level.SEVERE, msg); - episodeHasCompletedWithErrors(ClientState.ERROR_TIMED_OUT_WAITING_FOR_EPISODE_CLOSE, msg); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Attempt to create a world. - */ - public class CreateWorldEpisode extends ConfigAwareStateEpisode - { - boolean serverStarted = false; - boolean worldCreated = false; - int totalTicks = 0; - - CreateWorldEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() - { - try - { - totalTicks = 0; - - // We need to use the server's MissionHandlers here: - MissionBehaviour serverHandlers = MissionBehaviour.createServerHandlersFromMissionInit(currentMissionInit()); - if (serverHandlers != null && serverHandlers.worldGenerator != null) - { - if (serverHandlers.worldGenerator.createWorld(currentMissionInit())) - { - this.worldCreated = true; - if (Minecraft.getMinecraft().getIntegratedServer() != null) - Minecraft.getMinecraft().getIntegratedServer().setOnlineMode(false); - } - else - { - // World has not been created. 
- episodeHasCompletedWithErrors(ClientState.ERROR_CANNOT_CREATE_WORLD, "Server world-creation handler failed to create a world: " + serverHandlers.worldGenerator.getErrorDetails()); - } - } - } - catch (Exception e) - { - episodeHasCompletedWithErrors(ClientState.ERROR_CANNOT_CREATE_WORLD, "Server world-creation handler failed to create a world: " + e.getMessage()); - } - } - - @Override - protected void onServerTick(ServerTickEvent ev) - { - if (this.worldCreated && !this.serverStarted) - { - // The server has started ticking - we can set up its state machine, - // and move on to the next state in our own machine. - this.serverStarted = true; - MalmoMod.instance.initIntegratedServer(currentMissionInit()); // Needs to be done from the server thread. - episodeHasCompleted(ClientState.WAITING_FOR_SERVER_READY); - } - } - - @Override - public void onClientTick(TickEvent.ClientTickEvent ev) - { - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - if (++totalTicks > WAIT_MAX_TICKS) - { - String msg = "Too long waiting for world to be created."; - TCPUtils.Log(Level.SEVERE, msg); - episodeHasCompletedWithErrors(ClientState.ERROR_TIMED_OUT_WAITING_FOR_WORLD_CREATE, msg); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * State in which an agent has finished the mission, but is waiting for the server to draw stumps. 
- */ - public class MissionIdlingEpisode extends ConfigAwareStateEpisode - { - int totalTicks = 0; - - protected MissionIdlingEpisode(ClientStateMachine machine) - { - super(machine); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_STOPAGENTS); - } - - @Override - protected void execute() - { - totalTicks = 0; - TimeHelper.SyncManager.numTicks = 0; - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - super.onMessage(messageType, data); - // This message will be sent to us once the server has decided the mission is over. - if (messageType == MalmoMessageType.SERVER_STOPAGENTS) - episodeHasCompleted(ClientState.MISSION_ENDED); - } - - @Override - public void cleanup() - { - super.cleanup(); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, MalmoMessageType.SERVER_STOPAGENTS); - } - - @Override - public void onClientTick(TickEvent.ClientTickEvent ev) - { - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - ++totalTicks; - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * State in which a mission is running.
- * This state is ended by the death of the player or by the IWantToQuit - * handler, or by the server declaring the mission is over. - */ - public class MissionRunningEpisode extends ConfigAwareStateEpisode implements VideoProducedObserver - { - public static final int FailedTCPSendCountTolerance = 3; // Number of TCP timeouts before we cancel the mission - - protected MissionRunningEpisode(ClientStateMachine machine) - { - super(machine); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_STOPAGENTS); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_GO); - } - - boolean serverHasFiredStartingPistol = false; - boolean playerDied = false; - private int failedTCPRewardSendCount = 0; - private int failedTCPObservationSendCount = 0; - private boolean wantsToQuit = false; // We have decided our mission is at an end - private List videoHooks = new ArrayList(); - private String quitCode = ""; - private TCPSocketChannel observationSocket = null; - private TCPSocketChannel rewardSocket = null; - private long lastPingSent = 0; - private long pingFrequencyMs = 1000; - private boolean shouldMissionEnd = false; - private long frameTimestamp = 0; - - public void frameProduced() { - this.frameTimestamp = System.currentTimeMillis(); - } - - protected void onMissionStarted() - { - frameTimestamp = 0; - - // Open our communication channels: - openSockets(); - - this.shouldMissionEnd = false; - // Tell the server we have started: - HashMap map = new HashMap(); - map.put("username", Minecraft.getMinecraft().player.getName()); - MalmoMod.network.sendToServer(new MalmoMod.MalmoMessage(MalmoMessageType.CLIENT_AGENTRUNNING, 0, map)); - - // Set up our mission handlers: - if (currentMissionBehaviour().commandHandler != null) - { - currentMissionBehaviour().commandHandler.install(currentMissionInit()); - currentMissionBehaviour().commandHandler.setOverriding(true); - } - - if (currentMissionBehaviour().observationProducer != null) 
- currentMissionBehaviour().observationProducer.prepare(currentMissionInit()); - - if (currentMissionBehaviour().quitProducer != null) - currentMissionBehaviour().quitProducer.prepare(currentMissionInit()); - - if (currentMissionBehaviour().rewardProducer != null) - currentMissionBehaviour().rewardProducer.prepare(currentMissionInit()); - - if (currentMissionBehaviour().performanceProducer != null) - currentMissionBehaviour().performanceProducer.prepare(currentMissionInit()); - - // Disable the gui for the episode! - Minecraft.getMinecraft().gameSettings.hideGUI = true; - - for (IVideoProducer videoProducer : currentMissionBehaviour().videoProducers) - { - VideoHook hook = new VideoHook(); - this.videoHooks.add(hook); - frameProduced(); - hook.start(currentMissionInit(), videoProducer, this, envServer); - } - - // Make sure we have mouse control: - ClientStateMachine.this.inputController.setInputType(InputType.AI); - Minecraft.getMinecraft().inGameHasFocus = true; // Otherwise auto-repeat won't work for mouse clicks. - - // Overclocking: - ModSettings modsettings = currentMissionInit().getMission().getModSettings(); - if (modsettings != null && modsettings.getMsPerTick() != null) - TimeHelper.setMinecraftClientClockSpeed(1000 / modsettings.getMsPerTick()); - if (modsettings != null && modsettings.isPrioritiseOffscreenRendering() == Boolean.TRUE) - TimeHelper.displayGranularityMs = 1000; - TimeHelper.unpause(); - - // Synchronization - if (envServer != null){ - if(!envServer.doIWantToQuit(currentMissionInit())){ - TimeHelper.SyncManager.setSynchronous(envServer.isSynchronous()); - } else { - TimeHelper.SyncManager.setSynchronous(false); - } - } - } - - protected void onMissionEnded(IState nextState, String errorReport) - { - //Send the final data associated with the misson here. 
- this.shouldMissionEnd = false; - sendData(true); - - // Tidy up our mission handlers: - if (currentMissionBehaviour().rewardProducer != null) - currentMissionBehaviour().rewardProducer.cleanup(); - - if (currentMissionBehaviour().quitProducer != null) - currentMissionBehaviour().quitProducer.cleanup(); - - if (currentMissionBehaviour().observationProducer != null) - currentMissionBehaviour().observationProducer.cleanup(); - - if (currentMissionBehaviour().commandHandler != null) - { - currentMissionBehaviour().commandHandler.setOverriding(false); - currentMissionBehaviour().commandHandler.deinstall(currentMissionInit()); - } - - if (AddressHelper.getMissionControlPort() == 0) { - if (envServer != null) { - byte[] obs = envServer.getObservation(false); - envServer.endMission(); - } - } - - // Close our communication channels: - closeSockets(); - - for (VideoHook hook : this.videoHooks) - hook.stop(ClientStateMachine.this.missionEndedData); - - - // Disable the gui for the episode! - Minecraft.getMinecraft().gameSettings.hideGUI = false; - - // Return Minecraft speed to "normal": - TimeHelper.SyncManager.setPistolFired(false); - TimeHelper.setMinecraftClientClockSpeed(20); - TimeHelper.displayGranularityMs = 0; - TimeHelper.unpause(); - TimeHelper.SyncManager.setSynchronous(false); - - ClientStateMachine.this.missionQuitCode = this.quitCode; - if (errorReport != null) - episodeHasCompletedWithErrors(nextState, errorReport); - else - episodeHasCompleted(nextState); - } - - @Override - protected void execute() - { - onMissionStarted(); - } - - @Override - public void onClientTick(ClientTickEvent event) - { - // If we aren't performing synchronous ticking use the client Tick to handle updates - if(!TimeHelper.SyncManager.isSynchronous()){ - onTick(false, event.phase); - } - } - - @Override - public void onSyncTick(SyncTickEvent ev){ - // If we are performing synchronous ticking - onTick(true, ev.pos); - } - - private synchronized void onTick(Boolean synchronous, 
TickEvent.Phase phase){ - // TimeHelper.SyncManager.debugLog("[CLIENT_STATE_MACHINE] " + phase.toString()); - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - onMissionEnded(ClientState.MISSION_ABORTED, "Mission was aborted by server: " + ClientStateMachine.this.getErrorDetails()); - // Check to see whether we've been kicked from the server. - NetHandlerPlayClient npc = Minecraft.getMinecraft().getConnection(); - if(npc == null){ - if(this.serverHasFiredStartingPistol){ - onMissionEnded(ClientState.ERROR_LOST_NETWORK_CONNECTION, "Server was closed"); - return; - } - } - else{ - NetworkManager netman = npc.getNetworkManager(); - if (netman != null && !netman.hasNoChannel() && !netman.isChannelOpen()) - { - // Connection has been lost. - onMissionEnded(ClientState.ERROR_LOST_NETWORK_CONNECTION, "Client was kicked from server - " + netman.getExitMessage().getUnformattedText()); - } - - } - - // Check we are still in touch with the agent: - if (System.currentTimeMillis() > this.lastPingSent + this.pingFrequencyMs) - { - this.lastPingSent = System.currentTimeMillis(); - // Ping the agent - if serverHasFiredStartingPistol is true, we don't need to abort - - // we can simply set the wantsToQuit flag and end the mission cleanly. - // If serverHasFiredStartingPistol is false, then the mission isn't yet running, and - // setting the quit flag will do nothing - so we need to abort. - if (!pingAgent(false)) - { - if (!this.serverHasFiredStartingPistol){ - onMissionEnded(ClientState.ERROR_LOST_AGENT, "Lost contact with the agent"); - return; - } - else - { - System.out.println("Error - agent is not responding to pings."); - this.wantsToQuit = true; - this.quitCode = MalmoMod.AGENT_UNRESPONSIVE_CODE; - } - } - } - - - if (this.frameTimestamp != 0 && (System.currentTimeMillis() - this.frameTimestamp > VIDEO_MAX_WAIT) && !synchronous) { - System.out.println("No video produced recently. 
Aborting mission."); - if (!this.serverHasFiredStartingPistol) - onMissionEnded(ClientState.ERROR_LOST_VIDEO, "No video produced recently."); - else - { - System.out.println("Error - not receiving video."); - this.wantsToQuit = true; - this.quitCode = MalmoMod.VIDEO_UNRESPONSIVE_CODE; - } - } - - if(Minecraft.getMinecraft().world == null){ - if(this.serverHasFiredStartingPistol){ - onMissionEnded(ClientState.ERROR_NO_WORLD, "No world for client. Must be in main menu"); - } - - return; - - } - // Check here to see whether the player has died or not: - if (!this.playerDied && Minecraft.getMinecraft().player.isDead) - { - this.playerDied = true; - this.quitCode = MalmoMod.AGENT_DEAD_QUIT_CODE; - } - - // Although we only arrive in this episode once the server has determined that all clients are ready to go, - // the server itself waits for all clients to begin running before it enters the running state itself. - // This creates a small vulnerability, since a running client could theoretically *finish* its mission - // before the server manages to *start*. - // (This has potentially disastrous effects for the state machine, and is easy to reproduce by, - // for example, setting the start point and goal of the mission to the same coordinates.) - - // To guard against this happening, although we are running, we don't act on anything - - // we don't check for commands, or send observations or rewards - until we get the SERVER_GO signal, - // which is sent once the server's running episode has started. - - - TimeHelper.SyncManager.setPistolFired(this.serverHasFiredStartingPistol); - if (!this.serverHasFiredStartingPistol){ - return; - } - - // Perhaps the race condition could be that synchronous is then set to false when the quit command is recieved! - if(synchronous && phase == Phase.START){ - checkForControlCommand(); - } - if (phase == Phase.END) - { - - - // Check whether or not we want to quit: - IWantToQuit quitHandler = (currentMissionBehaviour() != null) ? 
currentMissionBehaviour().quitProducer : null; - boolean quitHandlerFired = (quitHandler != null && quitHandler.doIWantToQuit(currentMissionInit())); - if (quitHandlerFired || this.wantsToQuit || this.playerDied || this.shouldMissionEnd) - { - if (quitHandlerFired) - { - this.quitCode = quitHandler.getOutcome(); - } - try - { - // Save the quit code for anything that needs it: - MalmoMod.getPropertiesForCurrentThread().put("QuitCode", this.quitCode); - } - catch (Exception e) - { - System.out.println("Failed to get properties - final reward may go missing."); - } - - // Get the final reward data: - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - // if (currentMissionBehaviour() != null && currentMissionBehaviour().rewardProducer != null && cac != null) - // currentMissionBehaviour().rewardProducer.getReward(currentMissionInit(), ClientStateMachine.this.finalReward); - - // Now send a message to the server saying that we have finished our mission: - List agents = currentMissionInit().getMission().getAgentSection(); - String agentName = agents.get(currentMissionInit().getClientRole()).getName(); - HashMap map = new HashMap(); - map.put("agentname", agentName); - map.put("username", Minecraft.getMinecraft().player.getName()); - map.put("quitcode", this.quitCode); - MalmoMod.network.sendToServer(new MalmoMod.MalmoMessage(MalmoMessageType.CLIENT_AGENTFINISHEDMISSION, 0, map)); - - onMissionEnded(ClientState.MISSION_ABORTED, null); - } - else - { - // If in the case that we are asynchronous, do this - // wack stuff of checking input at the end of a tick... 
- if(!synchronous){ - - checkForControlCommand(); - } - - // Send off observation and reward data: - // And see if we have any incoming commands to act upon: - sendData(false); - } - } - } - - private void openSockets() - { - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - this.observationSocket = new TCPSocketChannel(cac.getAgentIPAddress(), cac.getAgentObservationsPort(), "obs"); - this.rewardSocket = new TCPSocketChannel(cac.getAgentIPAddress(), cac.getAgentRewardsPort(), "rew"); - } - - private void closeSockets() - { - this.observationSocket.close(); - this.rewardSocket.close(); - } - - private void sendData(boolean done) - { - TCPUtils.LogSection ls = new TCPUtils.LogSection("Sending data"); - - Minecraft.getMinecraft().mcProfiler.startSection("malmoSendData"); - // Create the observation data: - String data = ""; - Minecraft.getMinecraft().mcProfiler.startSection("malmoGatherObservationJSON"); - - if (currentMissionBehaviour() != null && currentMissionBehaviour().observationProducer != null) - { - JsonObject json = new JsonObject(); - currentMissionBehaviour().observationProducer.writeObservationsToJSON(json, currentMissionInit()); - data = json.toString(); - } - Minecraft.getMinecraft().mcProfiler.endSection(); //malmogatherjson - Minecraft.getMinecraft().mcProfiler.startSection("malmoSendTCPObservations"); - - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - - if (data != null && data.length() > 2 && cac != null) // An empty json string will be "{}" (length 2) - don't send these. - { - // TimeHelper.SyncManager.debugLog("[CLIENT_STATE_MACHINE INFO] " + Integer.toString(AddressHelper.getMissionControlPort())); - if (AddressHelper.getMissionControlPort() == 0) { - if (envServer != null) { - // TODO wierd, aren't we doing this? 
- envServer.observation(data); - } - } else { - if (this.observationSocket.sendTCPString(data)) { - this.failedTCPObservationSendCount = 0; - } else { - // Failed to send observation message. - this.failedTCPObservationSendCount++; - TCPUtils.Log(Level.WARNING, "Observation signal delivery failure count at " + this.failedTCPObservationSendCount); - ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Agent missed observation signal", TextCategory.TXT_CLIENT_WARNING, 5000); - } - } - } - Minecraft.getMinecraft().mcProfiler.endSection(); //malmotcp - Minecraft.getMinecraft().mcProfiler.startSection("malmoGatherRewardSignal"); - - // Now create the reward signal: - if (currentMissionBehaviour() != null && currentMissionBehaviour().rewardProducer != null && cac != null) - { - MultidimensionalReward reward = new MultidimensionalReward(); - currentMissionBehaviour().rewardProducer.getReward(currentMissionInit(), reward); - - if (!reward.isEmpty()) - { - - String strReward = reward.getAsSimpleString(); - Minecraft.getMinecraft().mcProfiler.startSection("malmoSendTCPReward"); - - ScoreHelper.logReward(strReward); - - if (AddressHelper.getMissionControlPort() == 0) { - // MalmoEnvServer - reward - if (envServer != null) { - envServer.addRewards(reward.getRewardTotal()); - } - } else { - if (this.rewardSocket.sendTCPString(strReward)) { - this.failedTCPRewardSendCount = 0; // Reset the count of consecutive TCP failures. - } else { - // Failed to send TCP message - probably because the agent has quit under our feet. - // (This happens a lot when developing a Python agent - the developer has no easy way to quit - // the agent cleanly, so tends to kill the process.) 
- this.failedTCPRewardSendCount++; - TCPUtils.Log(Level.WARNING, "Reward signal delivery failure count at " + this.failedTCPRewardSendCount); - ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Agent missed reward signal", TextCategory.TXT_CLIENT_WARNING, 5000); - } - } - - Minecraft.getMinecraft().mcProfiler.endSection(); //sendTCP reward. - } - if (currentMissionBehaviour().performanceProducer != null) - currentMissionBehaviour().performanceProducer.step(reward.getRewardTotal(), done); - } - else if(currentMissionBehaviour() != null){ - if (currentMissionBehaviour().performanceProducer != null) - currentMissionBehaviour().performanceProducer.step(0, done); - } - Minecraft.getMinecraft().mcProfiler.endSection(); //Gather reward. - Minecraft.getMinecraft().mcProfiler.endSection(); //sendData - - int maxFailedTCPSendCount = 0; - for (VideoHook hook : this.videoHooks) - { - if (hook.failedTCPSendCount > maxFailedTCPSendCount) - maxFailedTCPSendCount = hook.failedTCPSendCount; - } - if (maxFailedTCPSendCount > 0) - TCPUtils.Log(Level.WARNING, "Video signal failure count at " + maxFailedTCPSendCount); - // Check that our messages are getting through: - int maxFailed = Math.max(this.failedTCPRewardSendCount, maxFailedTCPSendCount); - maxFailed = Math.max(maxFailed, this.failedTCPObservationSendCount); - if (maxFailed > FailedTCPSendCountTolerance) - { - // They're not - and we've exceeded the count of allowed TCP failures. - System.out.println("ERROR: TCP messages are not getting through - quitting mission."); - this.wantsToQuit = true; - this.quitCode = MalmoMod.AGENT_UNRESPONSIVE_CODE; - } - ls.close(); - } - - /** - * Check to see if any control instructions have been received and act on them if so. - */ - public void checkForControlCommand() - { - Minecraft.getMinecraft().mcProfiler.endStartSection("malmoCommandHandling"); - String command; - boolean quitHandlerFired = false; - IWantToQuit quitHandler = (currentMissionBehaviour() != null) ? 
currentMissionBehaviour().quitProducer : null; - - if (envServer != null) { - command = envServer.getCommand(); - } else { - command = ClientStateMachine.this.controlInputPoller.getCommand(); - } - while (command != null && command.length() > 0 && !quitHandlerFired) - { - // TCPUtils.Log(Level.INFO, "Act on " + command); - // Pass the command to our various control overrides: - Minecraft.getMinecraft().mcProfiler.startSection("malmoCommandAct"); - - boolean handled = handleCommand(command); - // Get the next command: - if (envServer != null) { - command = envServer.getCommand(); - } else { - command = ClientStateMachine.this.controlInputPoller.getCommand(); - } - // If there *is* another command (commands came in faster than one per client tick), - // then we should check our quit producer before deciding whether to execute it. - Minecraft.getMinecraft().mcProfiler.endStartSection("malmoCommandRecheckQuitHandlers"); - if (command != null && command.length() > 0 && handled) - quitHandlerFired = (quitHandler != null && quitHandler.doIWantToQuit(currentMissionInit())); - Minecraft.getMinecraft().mcProfiler.endSection(); - } - } - - /** - * Attempt to handle a command string by passing it to our various external controllers in turn. - * - * @param command the command string to be handled. - * @return true if the command was handled. - */ - private boolean handleCommand(String command) - { - if (currentMissionBehaviour() != null && currentMissionBehaviour().commandHandler != null) - { - return currentMissionBehaviour().commandHandler.execute(command, currentMissionInit()); - } - return false; - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - super.onMessage(messageType, data); - // This message will be sent to us once the server has decided the mission is over. - if (messageType == MalmoMessageType.SERVER_STOPAGENTS) - { - this.quitCode = data.containsKey("QuitCode") ? 
data.get("QuitCode") : ""; - try - { - // Save the quit code for anything that needs it: - MalmoMod.getPropertiesForCurrentThread().put("QuitCode", this.quitCode); - } - catch (Exception e) - { - System.out.println("Failed to get properties - final reward may go missing."); - } - // Get the final reward data: - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - if (currentMissionBehaviour() != null && currentMissionBehaviour().rewardProducer != null && cac != null) - currentMissionBehaviour().rewardProducer.getReward(currentMissionInit(), ClientStateMachine.this.finalReward); - - this.shouldMissionEnd = true; - - } - else if (messageType == MalmoMessageType.SERVER_GO) - { - // First, force all entities to get re-added to their chunks, clearing out any old entities in the process. - // We need to do this because the process of teleporting all agents to their start positions, combined - // with setting them to/from spectator mode, leaves the client chunk entity lists etc in a parlous state. - List lel = Minecraft.getMinecraft().world.loadedEntityList; - for (int i = 0; i < lel.size(); i++) - { - Entity entity = (Entity)lel.get(i); - Chunk chunk = Minecraft.getMinecraft().world.getChunkFromChunkCoords(entity.chunkCoordX, entity.chunkCoordZ); - List entitiesToRemove = new ArrayList(); - for (int k = 0; k < chunk.getEntityLists().length; k++) - { - Iterator iterator = chunk.getEntityLists()[k].iterator(); - while (iterator.hasNext()) - { - Entity chunkent = (Entity)iterator.next(); - if (chunkent.getEntityId() == entity.getEntityId()) - { - entitiesToRemove.add(chunkent); - } - } - } - for (Entity removeEnt : entitiesToRemove) - { - chunk.removeEntity(removeEnt); - } - entity.addedToChunk = false; // Will force it to get re-added to the chunk list. - if (entity instanceof EntityLivingBase) - { - // If we want the entities to be rendered with the correct yaw from the outset, - // we need to set their render offset manually. 
- // (Set the offset from the outset to avoid the onset of upset.) - ((EntityLivingBase)entity).renderYawOffset = entity.rotationYaw; - ((EntityLivingBase)entity).prevRenderYawOffset = entity.rotationYaw; - } - if (entity instanceof EntityPlayerSP) - { - // Although the following call takes place on the server, and should have taken effect already, - // there is some discontinuity which is causing the effects to get lost, so we call it here too: - entity.setInvisible(false); - } - } - this.serverHasFiredStartingPistol = true; // GO GO GO! - } - } - - @Override - public void cleanup() - { - super.cleanup(); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, MalmoMessageType.SERVER_STOPAGENTS); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, MalmoMessageType.SERVER_GO); - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * State that occurs at the end of the mission, whether due to death, - * failure, success, error, or whatever. - */ - public class MissionEndedEpisode extends ConfigAwareStateEpisode - { - private MissionResult result; - private boolean aborting; - private boolean informServer; - private boolean informAgent; - private int totalTicks = 0; - - public MissionEndedEpisode(ClientStateMachine machine, MissionResult mr, boolean aborting, boolean informServer, boolean informAgent) - { - super(machine); - this.result = mr; - this.aborting = aborting; - this.informServer = informServer; - this.informAgent = informAgent; - } - - @Override - protected void execute() - { - totalTicks = 0; - - // Get a text report: - String errorFeedback = ClientStateMachine.this.getErrorDetails(); - String quitFeedback = ClientStateMachine.this.missionQuitCode; - String concatenation = (errorFeedback != null && !errorFeedback.isEmpty() && quitFeedback != null && !quitFeedback.isEmpty()) ? 
";\n" : ""; - String report = quitFeedback + concatenation + errorFeedback; - - if (this.informServer) - { - // Inform the server of what has happened. - HashMap map = new HashMap(); - if (Minecraft.getMinecraft().player != null) // Might not be a player yet. - map.put("username", Minecraft.getMinecraft().player.getName()); - map.put("error", ClientStateMachine.this.getErrorDetails()); - MalmoMod.network.sendToServer(new MalmoMod.MalmoMessage(MalmoMessageType.CLIENT_BAILED, 0, map)); - } - - if (this.informAgent) - { - // Create a MissionEnded instance for this result: - MissionEnded missionEnded = new MissionEnded(); - missionEnded.setStatus(this.result); - if (ClientStateMachine.this.missionQuitCode != null && ClientStateMachine.this.missionQuitCode.equals(MalmoMod.AGENT_DEAD_QUIT_CODE)) - missionEnded.setStatus(MissionResult.PLAYER_DIED); // Need to do this manually. - missionEnded.setHumanReadableStatus(report); - - // TODO: WE HAVE TO MOVE THIS TO THE onMISSIONENDED of Client Mission - // BECAUSE IT WOULD TAKE AN EXTRA TICK TO HAVE THIS APPEAR PROPERLY. - // THIS MOVE IS INCOMPATIBLE WITH MULTIPLE AGENTS AND REWARD DISTRIBUTION - // A PROPER REHAUL OF THE WHOLE SIMULATOR TO SUPPROT SYNCHRONOUS TICKING - // ACCROSS MULTIPLE AGENTS AND A STATE MACHINE WHOSE STATE CHANGES INDEPENDENT - // OF CLIENT TICKS IS REQUIRED. 
- // if (!ClientStateMachine.this.finalReward.isEmpty()) - // { - // if (envServer != null) { - // envServer.addRewards(ClientStateMachine.this.finalReward.getRewardTotal()); - // } - // missionEnded.setReward(ClientStateMachine.this.finalReward.getAsReward()); - // ClientStateMachine.this.finalReward.clear(); - // } - missionEnded.setMissionDiagnostics(ClientStateMachine.this.missionEndedData); // send our diagnostics - ClientStateMachine.this.missionEndedData = new MissionDiagnostics(); // and clear them for the next mission - // And send MissionEnded message to the agent to inform it that the mission has ended: - System.out.println("inform the agent"); - sendMissionEnded(missionEnded); - } - - if (this.aborting) // Take the shortest path back to dormant. - episodeHasCompleted(ClientState.DORMANT); - } - - private void sendMissionEnded(MissionEnded missionEnded) - { - // Send a MissionEnded message to the agent to inform it that the mission has ended. - // Create a string XML representation: - String missionEndedString = null; - try - { - missionEndedString = SchemaHelper.serialiseObject(missionEnded, MissionEnded.class); - if (ScoreHelper.isScoring()) { - Reward reward = missionEnded.getReward(); - if (reward == null) { - reward = new Reward(); - } - ScoreHelper.logMissionEndRewards(reward); - } - } - catch (JAXBException e) - { - TCPUtils.Log(Level.SEVERE, "Failed mission end XML serialization: " + e); - } - - boolean sentOkay = false; - if (missionEndedString != null) - { - if (AddressHelper.getMissionControlPort() == 0) { - sentOkay = true; - } else { - TCPSocketChannel sender = ClientStateMachine.this.getMissionControlSocket(); - System.out.println(String.format("Sending mission ended message to %s:%d.", sender.getAddress(), sender.getPort())); - sentOkay = sender.sendTCPString(missionEndedString); - sender.close(); - } - } - - if (!sentOkay) - { - // Couldn't formulate a reply to the agent - bit of a problem. 
- // Can't do much to alert the agent itself, - // will have to settle for alerting anyone who is watching the mod: - ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Could not send mission ended message - agent may need manually resetting.", TextCategory.TXT_CLIENT_WARNING, 10000); - } - } - - @Override - public void onClientTick(ClientTickEvent event) - { - if (!this.aborting) - episodeHasCompleted(ClientState.WAITING_FOR_SERVER_MISSION_END); - - if (++totalTicks > WAIT_MAX_TICKS) - { - String msg = "Too long waiting for server to end mission."; - TCPUtils.Log(Level.SEVERE, msg); - episodeHasCompletedWithErrors(ClientState.ERROR_TIMED_OUT_WAITING_FOR_MISSION_END, msg); - } - } - } -} diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/cpu/patch_files/MalmoEnvServer.java b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/cpu/patch_files/MalmoEnvServer.java deleted file mode 100644 index 6b74acac..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/cpu/patch_files/MalmoEnvServer.java +++ /dev/null @@ -1,939 +0,0 @@ -// -------------------------------------------------------------------------------------------------- -// Copyright (c) 2016 Microsoft Corporation -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and -// associated documentation files (the "Software"), to deal in the Software without restriction, -// including without limitation the rights to use, copy, modify, merge, publish, distribute, -// sublicense, and/or l copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. 
-// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT -// NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -------------------------------------------------------------------------------------------------- - -package com.microsoft.Malmo.Client; - -import com.microsoft.Malmo.MalmoMod; -import com.microsoft.Malmo.MissionHandlerInterfaces.IWantToQuit; -import com.microsoft.Malmo.Schemas.MissionInit; -import com.microsoft.Malmo.Utils.TCPUtils; - -import net.minecraft.profiler.Profiler; -import com.microsoft.Malmo.Utils.TimeHelper; - -import net.minecraftforge.common.config.Configuration; -import java.io.*; -import java.net.ServerSocket; -import java.net.Socket; -import java.nio.charset.Charset; -import java.util.Arrays; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantLock; -import java.util.Hashtable; -import com.microsoft.Malmo.Utils.TCPInputPoller; -import java.util.logging.Level; - -import java.util.LinkedList; -import java.util.List; - - -/** - * MalmoEnvServer - service supporting OpenAI gym "environment" for multi-agent Malmo missions. - */ -public class MalmoEnvServer implements IWantToQuit { - private static Profiler profiler = new Profiler(); - private static int nsteps = 0; - private static boolean debug = false; - - private static String hello = " commands = new LinkedList(); - } - - private static boolean envPolicy = false; // Are we configured by config policy? 
- - // Synchronize on EnvStateasd - - - private Lock lock = new ReentrantLock(); - private Condition cond = lock.newCondition(); - - private EnvState envState = new EnvState(); - - private Hashtable initTokens = new Hashtable(); - - static final long COND_WAIT_SECONDS = 3; // Max wait in seconds before timing out (and replying to RPC). - static final int BYTES_INT = 4; - static final int BYTES_DOUBLE = 8; - private static final Charset utf8 = Charset.forName("UTF-8"); - - // Service uses a single per-environment client connection - initiated by the remote environment. - - private int port; - private TCPInputPoller missionPoller; // Used for command parsing and not actual communication. - private String version; - - // AOG: From running experiments, I've found that MineRL can get stuck resetting the - // environment which causes huge delays while we wait for the Python side to time - // out and restart the Minecraft instace. Minecraft itself is normally in a recoverable - // state, but the MalmoEnvServer instance will be blocked in a tight spin loop trying - // handling a Peek request from the Python client. To unstick things, I've added this - // flag that can be set when we know things are in a bad state to abort the peek request. - // WARNING: THIS IS ONLY TREATING THE SYMPTOM AND NOT THE ROOT CAUSE - // The reason things are getting stuck is because the player is either dying or we're - // receiving a quit request while an episode reset is in progress. - private boolean abortRequest; - public void abort() { - System.out.println("AOG: MalmoEnvServer.abort"); - abortRequest = true; - } - - /*** - * Malmo "Env" service. - * @param port the port the service listens on. - * @param missionPoller for plugging into existing comms handling. 
- */ - public MalmoEnvServer(String version, int port, TCPInputPoller missionPoller) { - this.version = version; - this.missionPoller = missionPoller; - this.port = port; - // AOG - Assume we don't wan't to be aborting in the first place - this.abortRequest = false; - } - - /** Initialize malmo env configuration. For now either on or "legacy" AgentHost protocol.*/ - static public void update(Configuration configs) { - envPolicy = configs.get(MalmoMod.ENV_CONFIGS, "env", "false").getBoolean(); - } - - public static boolean isEnv() { - return envPolicy; - } - - /** - * Start servicing the MalmoEnv protocol. - * @throws IOException - */ - public void serve() throws IOException { - - ServerSocket serverSocket = new ServerSocket(port); - serverSocket.setPerformancePreferences(0,2,1); - - - while (true) { - try { - final Socket socket = serverSocket.accept(); - socket.setTcpNoDelay(true); - - Thread thread = new Thread("EnvServerSocketHandler") { - public void run() { - boolean running = false; - try { - checkHello(socket); - - while (true) { - DataInputStream din = new DataInputStream(socket.getInputStream()); - int hdr = din.readInt(); - byte[] data = new byte[hdr]; - din.readFully(data); - - String command = new String(data, utf8); - - if (command.startsWith(" dat = profiler.getProfilingData("root"); - for(int qq = 0; qq < dat.size(); qq++){ - Profiler.Result res = dat.get(qq); - System.out.println(res.profilerName + " " + res.totalUsePercentage + " "+ res.usePercentage); - } - } - - - } else if (command.startsWith(""; - data = command.getBytes(utf8); - hdr = data.length; - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(hdr); - dout.write(data, 0, hdr); - dout.flush(); - } else { - throw new IOException("Unknown env service command"); - } - } - } catch (IOException ioe) { - // ioe.printStackTrace(); - TCPUtils.Log(Level.SEVERE, "MalmoEnv socket error: " + ioe + " (can be on disconnect)"); - // System.out.println("[ERROR] " + 
"MalmoEnv socket error: " + ioe + " (can be on disconnect)"); - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] MalmoEnv socket error"); - try { - if (running) { - TCPUtils.Log(Level.INFO,"Want to quit on disconnect."); - - System.out.println("[LOGTOPY] " + "Want to quit on disconnect."); - setWantToQuit(); - } - socket.close(); - } catch (IOException ioe2) { - } - } - } - }; - thread.start(); - } catch (IOException ioe) { - TCPUtils.Log(Level.SEVERE, "MalmoEnv service exits on " + ioe); - } - } - } - - private void checkHello(Socket socket) throws IOException { - DataInputStream din = new DataInputStream(socket.getInputStream()); - int hdr = din.readInt(); - if (hdr <= 0 || hdr > hello.length() + 8) // Version number may be somewhat longer in future. - throw new IOException("Invalid MalmoEnv hello header length"); - byte[] data = new byte[hdr]; - din.readFully(data); - if (!new String(data).startsWith(hello + version)) - throw new IOException("MalmoEnv invalid protocol or version - expected " + hello + version); - } - - // Handler for messages. 
- private boolean missionInit(DataInputStream din, String command, Socket socket) throws IOException { - - String ipOriginator = socket.getInetAddress().getHostName(); - - int hdr; - byte[] data; - hdr = din.readInt(); - data = new byte[hdr]; - din.readFully(data); - String id = new String(data, utf8); - - TCPUtils.Log(Level.INFO,"Mission Init" + id); - - String[] token = id.split(":"); - String experimentId = token[0]; - int role = Integer.parseInt(token[1]); - int reset = Integer.parseInt(token[2]); - int agentCount = Integer.parseInt(token[3]); - Boolean isSynchronous = Boolean.parseBoolean(token[4]); - Long seed = null; - if(token.length > 5) - seed = Long.parseLong(token[5]); - - if(isSynchronous && agentCount > 1){ - throw new IOException("Synchronous mode currently does not support multiple agents."); - } - port = -1; - boolean allTokensConsumed = true; - boolean started = false; - - lock.lock(); - try { - if (role == 0) { - - String previousToken = experimentId + ":0:" + (reset - 1); - initTokens.remove(previousToken); - - String myToken = experimentId + ":0:" + reset; - if (!initTokens.containsKey(myToken)) { - TCPUtils.Log(Level.INFO,"(Pre)Start " + role + " reset " + reset); - started = startUp(command, ipOriginator, experimentId, reset, agentCount, myToken, seed, isSynchronous); - if (started) - initTokens.put(myToken, 0); - } else { - started = true; // Pre-started previously. - } - - // Check that all previous tokens have been consumed. If not don't proceed to mission. 
- - allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount); - if (!allTokensConsumed) { - try { - cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS); - } catch (InterruptedException ie) { - } - allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount); - } - } else { - TCPUtils.Log(Level.INFO, "Start " + role + " reset " + reset); - - started = startUp(command, ipOriginator, experimentId, reset, agentCount, experimentId + ":" + role + ":" + reset, seed, isSynchronous); - } - } finally { - lock.unlock(); - } - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(allTokensConsumed && started ? 1 : 0); - dout.flush(); - - dout.flush(); - - return allTokensConsumed && started; - } - - private boolean areAllTokensConsumed(String experimentId, int reset, int agentCount) { - boolean allTokensConsumed = true; - for (int i = 1; i < agentCount; i++) { - String tokenForAgent = experimentId + ":" + i + ":" + (reset - 1); - if (initTokens.containsKey(tokenForAgent)) { - TCPUtils.Log(Level.FINE,"Mission init - unconsumed " + tokenForAgent); - allTokensConsumed = false; - } - } - return allTokensConsumed; - } - - private boolean startUp(String command, String ipOriginator, String experimentId, int reset, int agentCount, String myToken, Long seed, Boolean isSynchronous) throws IOException { - - // Clear out mission state - envState.reward = 0.0; - envState.commands.clear(); - envState.obs = null; - envState.info = ""; - - - envState.missionInit = command; - envState.done = false; - envState.quit = false; - envState.token = myToken; - envState.experimentId = experimentId; - envState.agentCount = agentCount; - envState.reset = reset; - envState.synchronous = isSynchronous; - envState.seed = seed; - - return startUpMission(command, ipOriginator); - } - - private boolean startUpMission(String command, String ipOriginator) throws IOException { - - if (missionPoller == null) - return false; 
- - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(baos); - - missionPoller.commandReceived(command, ipOriginator, dos); - - dos.flush(); - byte[] reply = baos.toByteArray(); - ByteArrayInputStream bais = new ByteArrayInputStream(reply); - DataInputStream dis = new DataInputStream(bais); - int hdr = dis.readInt(); - byte[] replyBytes = new byte[hdr]; - dis.readFully(replyBytes); - - String replyStr = new String(replyBytes); - if (replyStr.equals("MALMOOK")) { - TCPUtils.Log(Level.INFO, "MalmoEnvServer Mission starting ..."); - return true; - } else if (replyStr.equals("MALMOBUSY")) { - TCPUtils.Log(Level.INFO, "MalmoEnvServer Busy - I want to quit"); - this.envState.quit = true; - } - return false; - } - - private static final int stepTagLength = "".length(); // Step with option code. - private synchronized void stepSync(String command, Socket socket, DataInputStream din) throws IOException - { - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Entering synchronous step."); - nsteps += 1; - profiler.startSection("commandProcessing"); - String actions = command.substring(stepTagLength, command.length() - (stepTagLength + 2)); - int options = Character.getNumericValue(command.charAt(stepTagLength - 2)); - boolean withInfo = options == 0 || options == 2; - - - - - // Prepare to write data to the client. - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - double reward = 0.0; - boolean done; - byte[] obs; - String info = ""; - boolean sent = false; - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Acquiring lock for synchronous step."); - - lock.lock(); - try { - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Lock is acquired."); - - done = envState.done; - - // TODO Handle when the environment is done. - - // Process the actions. 
- if (actions.contains("\n")) { - String[] cmds = actions.split("\\n"); - for(String cmd : cmds) { - envState.commands.add(cmd); - } - } else { - if (!actions.isEmpty()) - envState.commands.add(actions); - } - sent = true; - - - - profiler.endSection(); //cmd - profiler.startSection("requestTick"); - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Received: " + actions); - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Requesting tick."); - // Now wait to run a tick - // If synchronous mode is off then we should see if want to quit is true. - while(!TimeHelper.SyncManager.requestTick() && !done ){Thread.yield();} - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Tick request granted."); - - profiler.endSection(); - profiler.startSection("waitForTick"); - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Waiting for tick."); - - // Then wait until the tick is finished - while(!TimeHelper.SyncManager.isTickCompleted() && !done ){ Thread.yield();} - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] TICK DONE. Getting observation."); - - - - profiler.endSection(); - profiler.startSection("getObservation"); - // After which, get the observations. - obs = getObservation(done); - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Observation received. Getting info."); - - profiler.endSection(); - profiler.startSection("getInfo"); - - - // Pick up rewards. 
- reward = envState.reward; - if (withInfo) { - info = envState.info; - // if(info == null) - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] FILLING INFO: NULL"); - // else - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] FILLING " + info.toString()); - - } - done = envState.done; - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] STATUS " + Boolean.toString(done)); - envState.info = null; - envState.obs = null; - envState.reward = 0.0; - - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Info received.."); - profiler.endSection(); - } finally { - lock.unlock(); - } - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Lock released. Writing observation, info, done."); - - profiler.startSection("writeObs"); - dout.writeInt(obs.length); - dout.write(obs); - - dout.writeInt(BYTES_DOUBLE + 2); - dout.writeDouble(reward); - dout.writeByte(done ? 1 : 0); - dout.writeByte(sent ? 1 : 0); - - if (withInfo) { - byte[] infoBytes = info.getBytes(utf8); - dout.writeInt(infoBytes.length); - dout.write(infoBytes); - } - - profiler.endSection(); //write obs - profiler.startSection("flush"); - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Packets written. Flushing."); - dout.flush(); - profiler.endSection(); // flush - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Done with step."); - } - // Handler for messages. Single digit option code after _ specifies if turnkey and info are included in message. - private void step(String command, Socket socket, DataInputStream din) throws IOException { - if(envState.synchronous){ - stepSync(command, socket, din); - } - else{ - System.out.println("[ERROR] Asynchronous stepping is not supported in MineRL."); - } - - } - - // Handler for messages. 
- private void peek(String command, Socket socket, DataInputStream din) throws IOException { - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - byte[] obs; - boolean done; - String info = ""; - // AOG - As we've only seen issues with the peek reqest, I've focused my changes to just - // this function. Initially we want to be optimistic and assume we're not going to abort - // the request and my observations of event timings indicate that there is plenty of time - // between the peek request being received and the reset failing, so a race condition is - // unlikely. - abortRequest = false; - - lock.lock(); - - try { - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Waiting for pistol to fire."); - while(!TimeHelper.SyncManager.hasServerFiredPistol() && !abortRequest){ - - // Now wait to run a tick - while(!TimeHelper.SyncManager.requestTick() && !abortRequest){Thread.yield();} - - - // Then wait until the tick is finished - while(!TimeHelper.SyncManager.isTickCompleted() && !abortRequest){ Thread.yield();} - - - Thread.yield(); - } - - if (abortRequest) { - System.out.println("AOG: Aborting peek request"); - // AOG - We detect the lack of observation within our Python wrapper and throw a slightly - // diferent exception that by-passes MineRLs automatic clean up code. If we were to report - // 'done', the MineRL detects this as a runtime error and kills the Minecraft process - // triggering a lengthy restart. So far from testing, Minecraft itself is fine can we can - // retry the reset, it's only the tight loops above that were causing things to stall and - // timeout. - // No observation - dout.writeInt(0); - // No info - dout.writeInt(0); - // Done - dout.writeInt(1); - dout.writeByte(0); - dout.flush(); - return; - } - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Pistol fired!."); - // Wait two ticks for the first observation from server to be propagated. 
- while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();} - - // Then wait until the tick is finished - while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();} - - - - while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();} - - // Then wait until the tick is finished - while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();} - - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Getting observation."); - - obs = getObservation(false); - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Observation acquired."); - done = envState.done; - info = envState.info; - - } finally { - lock.unlock(); - } - - dout.writeInt(obs.length); - dout.write(obs); - - byte[] infoBytes = info.getBytes(utf8); - dout.writeInt(infoBytes.length); - dout.write(infoBytes); - - dout.writeInt(1); - dout.writeByte(done ? 1 : 0); - - dout.flush(); - } - - // Get the current observation. If none and not done wait for a short time. - public byte[] getObservation(boolean done) { - byte[] obs = envState.obs; - if (obs == null){ - System.out.println("[ERROR] Video observation is null; please notify the developer."); - } - return obs; - } - - // Handler for messages - used by non-zero roles to discover integrated server port from primary (role 0) service. - - private final static int findTagLength = "".length(); - - private void find(String command, Socket socket) throws IOException { - - Integer port; - lock.lock(); - try { - String token = command.substring(findTagLength, command.length() - (findTagLength + 1)); - TCPUtils.Log(Level.INFO, "Find token? " + token); - - // Purge previous token. - String[] tokenSplits = token.split(":"); - String experimentId = tokenSplits[0]; - int role = Integer.parseInt(tokenSplits[1]); - int reset = Integer.parseInt(tokenSplits[2]); - - String previousToken = experimentId + ":" + role + ":" + (reset - 1); - initTokens.remove(previousToken); - cond.signalAll(); - - // Check for next token. 
Wait for a short time if not already produced. - port = initTokens.get(token); - if (port == null) { - try { - cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS); - } catch (InterruptedException ie) { - } - port = initTokens.get(token); - if (port == null) { - port = 0; - TCPUtils.Log(Level.INFO,"Role " + role + " reset " + reset + " waiting for token."); - } - } - } finally { - lock.unlock(); - } - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(port); - dout.flush(); - } - - public boolean isSynchronous(){ - return envState.synchronous; - } - - // Handler for messages. These reset the service so use with care! - private void init(String command, Socket socket) throws IOException { - lock.lock(); - try { - initTokens = new Hashtable(); - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(1); - dout.flush(); - } finally { - lock.unlock(); - } - } - - // Handler for (quit mission) messages. - private void quit(String command, Socket socket) throws IOException { - lock.lock(); - try { - if (!envState.done){ - - envState.quit = true; - } - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Pistol fired!."); - // Wait two ticks for the first observation from server to be propagated. - while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();} - - // Then wait until the tick is finished - while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();} - - - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(envState.done ? 1 : 0); - dout.flush(); - } finally { - lock.unlock(); - } - } - - private final static int closeTagLength = "".length(); - - // Handler for messages. 
- private void close(String command, Socket socket) throws IOException { - lock.lock(); - try { - String token = command.substring(closeTagLength, command.length() - (closeTagLength + 1)); - - initTokens.remove(token); - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(1); - dout.flush(); - } finally { - lock.unlock(); - } - } - - // Handler for messages. - private void status(String command, Socket socket) throws IOException { - lock.lock(); - try { - String status = "{}"; // TODO Possibly have something more interesting to report. - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - - byte[] statusBytes = status.getBytes(utf8); - dout.writeInt(statusBytes.length); - dout.write(statusBytes); - - dout.flush(); - } finally { - lock.unlock(); - } - } - - // Handler for messages. These "kill the service" temporarily so use with care!f - private void exit(String command, Socket socket) throws IOException { - // lock.lock(); - try { - // We may exit before we get a chance to reply. - TimeHelper.SyncManager.setSynchronous(false); - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(1); - dout.flush(); - - ClientStateMachine.exitJava(); - - } finally { - // lock.unlock(); - } - } - - // Malmo client state machine interface methods: - - public String getCommand() { - try { - String command = envState.commands.poll(); - if (command == null) - return ""; - else - return command; - } finally { - } - } - - public void endMission() { - // lock.lock(); - try { - // AOG - If the mission is ending, we always want to abort requests and they won't - // be able to progress to completion and will stall. 
- System.out.println("AOG: MalmoEnvServer.endMission"); - abort(); - envState.done = true; - envState.quit = false; - envState.missionInit = null; - - if (envState.token != null) { - initTokens.remove(envState.token); - envState.token = null; - envState.experimentId = null; - envState.agentCount = 0; - envState.reset = 0; - - // cond.signalAll(); - } - // lock.unlock(); - } finally { - } - } - // Record a Malmo "observation" json - as the env info since an environment "obs" is a video frame. - public void observation(String info) { - // Parsing obs as JSON would be slower but less fragile than extracting the turn_key using string search. - // lock.lock(); - try { - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Inserting: " + info); - envState.info = info; - // cond.signalAll(); - } finally { - // lock.unlock(); - } - } - - public void addRewards(double rewards) { - // lock.lock(); - try { - envState.reward += rewards; - } finally { - // lock.unlock(); - } - } - - public void addFrame(byte[] frame) { - // lock.lock(); - try { - envState.obs = frame; // Replaces current. - // cond.signalAll(); - } finally { - // lock.unlock(); - } - } - - public void notifyIntegrationServerStarted(int integrationServerPort) { - lock.lock(); - try { - if (envState.token != null) { - TCPUtils.Log(Level.INFO,"Integration server start up - token: " + envState.token); - addTokens(integrationServerPort, envState.token, envState.experimentId, envState.agentCount, envState.reset); - cond.signalAll(); - } else { - TCPUtils.Log(Level.WARNING,"No mission token on integration server start up!"); - } - } finally { - lock.unlock(); - } - } - - private void addTokens(int integratedServerPort, String myToken, String experimentId, int agentCount, int reset) { - initTokens.put(myToken, integratedServerPort); - // Place tokens for other agents to find. 
- for (int i = 1; i < agentCount; i++) { - String tokenForAgent = experimentId + ":" + i + ":" + reset; - initTokens.put(tokenForAgent, integratedServerPort); - } - } - - // IWantToQuit implementation. - - @Override - public boolean doIWantToQuit(MissionInit missionInit) { - // lock.lock(); - try { - return envState.quit; - } finally { - // lock.unlock(); - } - } - - public Long getSeed(){ - return envState.seed; - } - - private void setWantToQuit() { - // lock.lock(); - try { - envState.quit = true; - - } finally { - - if(TimeHelper.SyncManager.isSynchronous()){ - // We want to dsynchronize everything. - TimeHelper.SyncManager.setSynchronous(false); - } - // lock.unlock(); - } - } - - @Override - public void prepare(MissionInit missionInit) { - } - - @Override - public void cleanup() { - } - - @Override - public String getOutcome() { - return "Env quit"; - } -} diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/gpu/Dockerfile b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/gpu/Dockerfile deleted file mode 100644 index 57f6a4be..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/gpu/Dockerfile +++ /dev/null @@ -1,78 +0,0 @@ -FROM mcr.microsoft.com/azureml/base-gpu:openmpi3.1.2-cuda10.0-cudnn7-ubuntu18.04 - -# Install some basic utilities -RUN apt-get update && apt-get install -y \ - curl \ - ca-certificates \ - sudo \ - cpio \ - git \ - bzip2 \ - libx11-6 \ - tmux \ - htop \ - gcc \ - xvfb \ - python-opengl \ - x11-xserver-utils \ - ffmpeg \ - mesa-utils \ - nano \ - vim \ - rsync \ - && rm -rf /var/lib/apt/lists/* - -# Create a working directory -RUN mkdir /app -WORKDIR /app - -# Create a Python 3.7 environment -RUN conda install conda-build \ - && conda create -y --name py37 python=3.7.3 \ - && conda clean -ya -ENV CONDA_DEFAULT_ENV=py37 - -# Install Minecraft needed libraries -RUN mkdir -p /usr/share/man/man1 && \ - sudo apt-get update && 
\ - sudo apt-get install -y \ - openjdk-8-jre-headless=8u162-b12-1 \ - openjdk-8-jdk-headless=8u162-b12-1 \ - openjdk-8-jre=8u162-b12-1 \ - openjdk-8-jdk=8u162-b12-1 - -RUN pip install --upgrade --user minerl - -# PyTorch with CUDA 10 installation -RUN conda install -y -c pytorch \ - cuda100=1.0 \ - magma-cuda100=2.4.0 \ - "pytorch=1.1.0=py3.7_cuda10.0.130_cudnn7.5.1_0" \ - torchvision=0.3.0 \ - && conda clean -ya - -RUN pip install \ - pandas \ - matplotlib \ - numpy \ - scipy \ - azureml-defaults \ - tensorboardX \ - tensorflow-gpu==1.15rc2 \ - GPUtil \ - tabulate \ - dm_tree \ - lz4 \ - ray==0.8.3 \ - ray[rllib]==0.8.3 \ - ray[tune]==0.8.3 - -COPY patch_files/* /root/.local/lib/python3.7/site-packages/minerl/env/Malmo/Minecraft/src/main/java/com/microsoft/Malmo/Client/ - -# Start minerl to pre-fetch minerl files (saves time when starting minerl during training) -RUN xvfb-run -a -s "-screen 0 1400x900x24" python -c "import gym; import minerl; env = gym.make('MineRLTreechop-v0'); env.close();" - -RUN pip install --index-url https://test.pypi.org/simple/ malmo && \ - python -c "import malmo.minecraftbootstrap; malmo.minecraftbootstrap.download();" - -ENV MALMO_XSD_PATH="/app/MalmoPlatform/Schemas" diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/gpu/patch_files/ClientStateMachine.java b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/gpu/patch_files/ClientStateMachine.java deleted file mode 100644 index 71668152..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/gpu/patch_files/ClientStateMachine.java +++ /dev/null @@ -1,2481 +0,0 @@ -// -------------------------------------------------------------------------------------------------- -// Copyright (c) 2016 Microsoft Corporation -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and -// associated documentation files (the "Software"), to 
deal in the Software without restriction, -// including without limitation the rights to use, copy, modify, merge, publish, distribute, -// sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. -// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT -// NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -------------------------------------------------------------------------------------------------- - -package com.microsoft.Malmo.Client; - -import java.io.DataOutputStream; -import java.io.IOException; -import java.lang.reflect.Field; -import java.math.BigDecimal; -import java.net.UnknownHostException; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Map.Entry; -import java.util.logging.Level; - -import javax.xml.bind.JAXBException; -import javax.xml.stream.XMLStreamException; - -import net.minecraft.client.Minecraft; -import net.minecraft.client.entity.EntityPlayerSP; -import net.minecraft.client.gui.GuiDisconnected; -import net.minecraft.client.gui.GuiIngameMenu; -import net.minecraft.client.gui.GuiMainMenu; -import net.minecraft.client.gui.GuiScreen; -import net.minecraft.client.multiplayer.WorldClient; -import net.minecraft.client.network.NetHandlerPlayClient; -import net.minecraft.client.settings.GameSettings; -import net.minecraft.entity.Entity; -import 
net.minecraft.entity.EntityLivingBase; -import net.minecraft.launchwrapper.Launch; -import net.minecraft.network.NetworkManager; -import net.minecraft.server.integrated.IntegratedServer; -import net.minecraft.util.math.MathHelper; -import net.minecraft.util.text.TextComponentString; -import net.minecraft.world.GameType; -import net.minecraft.world.World; -import net.minecraft.world.chunk.Chunk; -import net.minecraft.world.chunk.IChunkProvider; -import net.minecraftforge.common.MinecraftForge; -import net.minecraftforge.common.config.Configuration; -import net.minecraftforge.fml.client.event.ConfigChangedEvent.OnConfigChangedEvent; -import net.minecraftforge.fml.common.FMLCommonHandler; -import net.minecraftforge.fml.common.Loader; -import net.minecraftforge.fml.common.eventhandler.SubscribeEvent; -import net.minecraftforge.fml.common.gameevent.TickEvent; -import net.minecraftforge.fml.common.gameevent.TickEvent.ClientTickEvent; -import net.minecraftforge.fml.common.gameevent.TickEvent.Phase; -import net.minecraftforge.fml.common.gameevent.TickEvent.ServerTickEvent; - -import org.xml.sax.SAXException; - -import com.google.gson.JsonObject; -import com.microsoft.Malmo.IState; -import com.microsoft.Malmo.MalmoMod; -import com.microsoft.Malmo.MalmoMod.IMalmoMessageListener; -import com.microsoft.Malmo.MalmoMod.MalmoMessageType; -import com.microsoft.Malmo.StateEpisode; -import com.microsoft.Malmo.StateMachine; -import com.microsoft.Malmo.Client.MalmoModClient.InputType; -import com.microsoft.Malmo.MissionHandlerInterfaces.IVideoProducer; -import com.microsoft.Malmo.MissionHandlerInterfaces.IWantToQuit; -import com.microsoft.Malmo.MissionHandlers.MissionBehaviour; -import com.microsoft.Malmo.MissionHandlers.MultidimensionalReward; -import com.microsoft.Malmo.Schemas.AgentSection; -import com.microsoft.Malmo.Schemas.AgentStart; -import com.microsoft.Malmo.Schemas.ClientAgentConnection; -import com.microsoft.Malmo.Schemas.MinecraftServerConnection; -import 
com.microsoft.Malmo.Schemas.Mission; -import com.microsoft.Malmo.Schemas.MissionDiagnostics; -import com.microsoft.Malmo.Schemas.MissionEnded; -import com.microsoft.Malmo.Schemas.MissionInit; -import com.microsoft.Malmo.Schemas.MissionResult; -import com.microsoft.Malmo.Schemas.Reward; -import com.microsoft.Malmo.Schemas.ModSettings; -import com.microsoft.Malmo.Schemas.PosAndDirection; -import com.microsoft.Malmo.Utils.AddressHelper; -import com.microsoft.Malmo.Utils.AuthenticationHelper; -import com.microsoft.Malmo.Utils.SchemaHelper; -import com.microsoft.Malmo.Utils.ScreenHelper; -import com.microsoft.Malmo.Utils.SeedHelper; -import com.microsoft.Malmo.Utils.ScoreHelper; -import com.microsoft.Malmo.Utils.TextureHelper; -import com.microsoft.Malmo.Utils.ScreenHelper.TextCategory; -import com.microsoft.Malmo.Utils.TCPInputPoller; -import com.microsoft.Malmo.Utils.TCPInputPoller.CommandAndIPAddress; -import com.microsoft.Malmo.Utils.TimeHelper.SyncTickEvent; -import com.microsoft.Malmo.Utils.TCPSocketChannel; -import com.microsoft.Malmo.Utils.TCPUtils; -import com.microsoft.Malmo.Utils.TimeHelper; -import com.mojang.authlib.properties.Property; - -/** - * Class designed to track and control the state of the mod, especially regarding mission launching/running.
- * States are defined by the MissionState enum, and control is handled by - * MissionStateEpisode subclasses. The ability to set the state directly is - * restricted, but hooks such as onPlayerReadyForMission etc are exposed to - * allow subclasses to react to certain state changes.
- * The ProjectMalmo mod app class inherits from this and uses these hooks to run missions. - */ -public class ClientStateMachine extends StateMachine implements IMalmoMessageListener -{ - // AOG - Dropped from 2000 to 1000 to speed up detection of failed server restarts - private static final int WAIT_MAX_TICKS = 1000; // Over 1 minute and a half in client ticks. - private static final int VIDEO_MAX_WAIT = 90 * 1000; // Max wait for video in ms. - private static final String MISSING_MCP_PORT_ERROR = "no_mcp"; - private static final String INFO_MCP_PORT = "info_mcp"; - private static final String INFO_RESERVE_STATUS = "info_reservation"; - - private MissionInit currentMissionInit = null; // The MissionInit object for the mission currently being loaded/run. - private MissionBehaviour missionBehaviour = new MissionBehaviour(); - private String missionQuitCode = ""; // The reason why this mission ended. - private MultidimensionalReward finalReward = new MultidimensionalReward(true); // The reward at the end of the mission, sent separately to ensure timely delivery. - private MissionDiagnostics missionEndedData = new MissionDiagnostics(); - private ScreenHelper screenHelper = new ScreenHelper(); - protected MalmoModClient inputController; - - // Env service: - protected MalmoEnvServer envServer; - - // Socket stuff: - protected TCPInputPoller missionPoller; - protected TCPInputPoller controlInputPoller; - protected int integratedServerPort; - String reservationID = ""; // empty if we are not reserved, otherwise "RESERVED" + the experiment ID we are reserved for. 
- long reservationExpirationTime = 0; - private TCPSocketChannel missionControlSocket; - - private void reserveClient(String id) - { - synchronized(this.reservationID) - { - ClientStateMachine.this.getScreenHelper().clearFragment(INFO_RESERVE_STATUS); - - // id is in the form :, where long is the length of time to keep the reservation for, - // and expID is the experimentationID used to ensure the client is reserved for the correct experiment. - int separator = id.indexOf(":"); - if (separator == -1) - { - System.out.println("Error - malformed reservation request - client will not be reserved."); - this.reservationID = ""; - } - else - { - long duration = Long.valueOf(id.substring(0, separator)); - String expID = id.substring(separator + 1); - this.reservationExpirationTime = System.currentTimeMillis() + duration; - // We don't just use the id, in case users have supplied a blank string as their experiment ID. - this.reservationID = "RESERVED" + expID; - ClientStateMachine.this.getScreenHelper().addFragment("Reserved: " + expID, TextCategory.TXT_INFO, (int)duration);//INFO_RESERVE_STATUS); - } - } - } - - private boolean isReserved() - { - synchronized(this.reservationID) - { - System.out.println("==== RES: " + this.reservationID + " - " + (this.reservationExpirationTime - System.currentTimeMillis())); - return !this.reservationID.isEmpty() && this.reservationExpirationTime > System.currentTimeMillis(); - } - } - - private boolean isAvailable(String id) - { - synchronized(this.reservationID) - { - return (this.reservationID.isEmpty() || this.reservationID.equals("RESERVED" + id) || System.currentTimeMillis() >= this.reservationExpirationTime); - } - } - - private void cancelReservation() - { - synchronized(this.reservationID) - { - this.reservationID = ""; - ClientStateMachine.this.getScreenHelper().clearFragment(INFO_RESERVE_STATUS); - } - } - - protected TCPSocketChannel getMissionControlSocket() { return this.missionControlSocket; } - - protected void 
createMissionControlSocket() - { - TCPUtils.LogSection ls = new TCPUtils.LogSection("Creating MissionControlSocket"); - // Set up a TCP connection to the agent: - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - if (this.missionControlSocket == null || - this.missionControlSocket.getPort() != cac.getAgentMissionControlPort() || - this.missionControlSocket.getAddress() == null || - !this.missionControlSocket.isValid() || - !this.missionControlSocket.isOpen() || - !this.missionControlSocket.getAddress().equals(cac.getAgentIPAddress())) - { - if (this.missionControlSocket != null) - this.missionControlSocket.close(); - this.missionControlSocket = new TCPSocketChannel(cac.getAgentIPAddress(), cac.getAgentMissionControlPort(), "mcp"); - } - ls.close(); - } - - public ClientStateMachine(ClientState initialState, MalmoModClient inputController) - { - super(initialState); - this.inputController = inputController; - - // Register ourself on the event busses, so we can harness the client tick: - MinecraftForge.EVENT_BUS.register(this); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_TEXT); - } - - @Override - public void clearErrorDetails() - { - super.clearErrorDetails(); - this.missionQuitCode = ""; - } - - @SubscribeEvent - public void onClientTick(TickEvent.ClientTickEvent ev) - { - // Use the client tick to ensure we regularly update our state (from the client thread) - updateState(); - } - - public ScreenHelper getScreenHelper() - { - return screenHelper; - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - if (messageType == MalmoMessageType.SERVER_TEXT) - { - String chat = data.get("chat"); - if (chat != null) - Minecraft.getMinecraft().ingameGUI.getChatGUI().printChatMessageWithOptionalDeletion(new TextComponentString(chat), 1); - else - { - String text = data.get("text"); - ScreenHelper.TextCategory category = ScreenHelper.TextCategory.valueOf(data.get("category")); - 
String strtime = data.get("displayTime"); - Integer time = (strtime != null) ? Integer.valueOf(strtime) : null; - this.getScreenHelper().addFragment(text, category, time); - } - } - } - - @Override - protected String getName() - { - return "CLIENT"; - } - - @Override - protected void onPreStateChange(IState toState) - { - this.getScreenHelper().addFragment("CLIENT: " + toState, ScreenHelper.TextCategory.TXT_CLIENT_STATE, ""); - } - - /** - * Create the episode object for the requested state. - * - * @param state the state the mod is entering - * @return a MissionStateEpisode that localises all the logic required to run this state - */ - @Override - protected StateEpisode getStateEpisodeForState(IState state) - { - if (!(state instanceof ClientState)) - return null; - - ClientState cs = (ClientState) state; - switch (cs) { - case WAITING_FOR_MOD_READY: - return new InitialiseClientModEpisode(this); - case DORMANT: - return new DormantEpisode(this); - case CREATING_HANDLERS: - return new CreateHandlersEpisode(this); - case EVALUATING_WORLD_REQUIREMENTS: - return new EvaluateWorldRequirementsEpisode(this); - case PAUSING_OLD_SERVER: - return new PauseOldServerEpisode(this); - case CLOSING_OLD_SERVER: - return new CloseOldServerEpisode(this); - case CREATING_NEW_WORLD: - return new CreateWorldEpisode(this); - case WAITING_FOR_SERVER_READY: - return new WaitingForServerEpisode(this); - case RUNNING: - return new MissionRunningEpisode(this); - case IDLING: - return new MissionIdlingEpisode(this); - case MISSION_ENDED: - return new MissionEndedEpisode(this, MissionResult.ENDED, false, false, true); - case ERROR_DUFF_HANDLERS: - return new MissionEndedEpisode(this, MissionResult.MOD_FAILED_TO_INSTANTIATE_HANDLERS, true, true, true); - case ERROR_INTEGRATED_SERVER_UNREACHABLE: - return new MissionEndedEpisode(this, MissionResult.MOD_SERVER_UNREACHABLE, true, true, true); - case ERROR_NO_WORLD: - return new MissionEndedEpisode(this, MissionResult.MOD_HAS_NO_WORLD_LOADED, 
true, true, true); - case ERROR_CANNOT_CREATE_WORLD: - return new MissionEndedEpisode(this, MissionResult.MOD_FAILED_TO_CREATE_WORLD, true, true, true); - case ERROR_CANNOT_START_AGENT: // run-ons deliberate - case ERROR_LOST_AGENT: - case ERROR_LOST_VIDEO: - return new MissionEndedEpisode(this, MissionResult.MOD_HAS_NO_AGENT_AVAILABLE, true, true, false); - case ERROR_LOST_NETWORK_CONNECTION: // run-on deliberate - case ERROR_CANNOT_CONNECT_TO_SERVER: - return new MissionEndedEpisode(this, MissionResult.MOD_CONNECTION_FAILED, true, false, true); // No point trying to inform the server - we can't reach it anyway! - case ERROR_TIMED_OUT_WAITING_FOR_EPISODE_START: // run-ons deliberate - case ERROR_TIMED_OUT_WAITING_FOR_EPISODE_PAUSE: - case ERROR_TIMED_OUT_WAITING_FOR_EPISODE_CLOSE: - case ERROR_TIMED_OUT_WAITING_FOR_MISSION_END: - case ERROR_TIMED_OUT_WAITING_FOR_WORLD_CREATE: - return new MissionEndedEpisode(this, MissionResult.MOD_CONNECTION_FAILED, true, true, true); - case MISSION_ABORTED: - return new MissionEndedEpisode(this, MissionResult.MOD_SERVER_ABORTED_MISSION, true, false, true); // Don't inform the server - it already knows (we're acting on its notification) - case WAITING_FOR_SERVER_MISSION_END: - return new WaitingForServerMissionEndEpisode(this); - default: - break; - } - return null; - } - - protected MissionInit currentMissionInit() - { - return this.currentMissionInit; - } - - protected MissionBehaviour currentMissionBehaviour() - { - return this.missionBehaviour; - } - - protected class MissionInitResult - { - public MissionInit missionInit = null; - public boolean wasMissionInit = false; - public String error = null; - } - - protected MissionInitResult decodeMissionInit(String command) - { - MissionInitResult result = new MissionInitResult(); - if (command == null) - { - result.error = "Null command passed."; - return result; - } - - String rootNodeName = SchemaHelper.getRootNodeName(command); - if (rootNodeName != null && 
rootNodeName.equals("MissionInit")) - { - result.wasMissionInit = true; - // Attempt to decode the MissionInit XML string. - try - { - result.missionInit = (MissionInit) SchemaHelper.deserialiseObject(command, "MissionInit.xsd", MissionInit.class); - } - catch (JAXBException e) - { - System.out.println("JAXB exception: " + e); - if (e.getMessage() != null) - result.error = e.getMessage(); - else if (e.getLinkedException() != null && e.getLinkedException().getMessage() != null) - result.error = e.getLinkedException().getMessage(); - else - result.error = "Unspecified problem parsing MissionInit - check your Mission xml."; - } - catch (SAXException e) - { - System.out.println("SAX exception: " + e); - result.error = e.getMessage(); - } - catch (XMLStreamException e) - { - System.out.println("XMLStreamException: " + e); - result.error = e.getMessage(); - } - } - return result; - } - - protected boolean areMissionsEqual(Mission m1, Mission m2) - { - return true; - // FIX NEEDED - the following code fails because m1 may have been - // modified since loading - eg the MazeDecorator writes directly to the XML, - // and the use of some of the getters in the XSD-generated code can cause extra - // (empty) nodes to be added to the resulting XML. - // We need a more robust way of comparing two mission objects. - // For now, simply return true, since a false positive is less dangerous - // than a false negative. - /* - try { - String s1 = SchemaHelper.serialiseObject(m1, Mission.class); - String s2 = SchemaHelper.serialiseObject(m2, Mission.class); - return s1.compareTo(s2) == 0; - } catch( JAXBException e ) { - System.out.println("JAXB exception: " + e); - return false; - }*/ - } - - /** - * Set up the mission poller.
- * This is called during the initialisation episode, but also needs to be - * available for other episodes in case the configuration changes, resulting - * in changes to the ports. - * - * @throws UnknownHostException - */ - protected void initialiseComms() throws UnknownHostException - { - // Start polling for missions: - if (this.missionPoller != null) - { - this.missionPoller.stopServer(); - } - - this.missionPoller = new TCPInputPoller(AddressHelper.getMissionControlPortOverride(), AddressHelper.MIN_MISSION_CONTROL_PORT, AddressHelper.MAX_FREE_PORT, true, "mcp") - { - @Override - public void onError(String error, DataOutputStream dos) - { - System.out.println("SENDING ERROR: " + error); - try - { - dos.writeInt(error.length()); - dos.writeBytes(error); - dos.flush(); - } - catch (IOException e) - { - } - } - - private void reply(String reply, DataOutputStream dos) - { - System.out.println("REPLYING WITH: " + reply); - try - { - dos.writeInt(reply.length()); - dos.writeBytes(reply); - dos.flush(); - } - catch (IOException e) - { - System.out.println("Failed to reply to message!"); - } - } - - @Override - public boolean onCommand(String command, String ipFrom, DataOutputStream dos) - { - System.out.println("Received from " + ipFrom + ":" + - command.substring(0, Math.min(command.length(), 1024))); - boolean keepProcessing = false; - - // Possible commands: - // 1: MALMO_REQUEST_CLIENT:: - // 2: MALMO_CANCEL_REQUEST - // 3: MALMO_FIND_SERVER - // 4: MALMO_KILL_CLIENT - // 5: MissionInit - - String reservePrefixGeneral = "MALMO_REQUEST_CLIENT:"; - String reservePrefix = reservePrefixGeneral + Loader.instance().activeModContainer().getVersion() + ":"; - String findServerPrefix = "MALMO_FIND_SERVER"; - String cancelRequestCommand = "MALMO_CANCEL_REQUEST"; - String killClientCommand = "MALMO_KILL_CLIENT"; - - if (command.startsWith(reservePrefix)) - { - // Reservation request. - // We either reply with MALMOOK, if we are free, or MALMOBUSY if not. 
- IState currentState = getStableState(); - if (currentState != null && currentState.equals(ClientState.DORMANT) && !isReserved()) - { - reserveClient(command.substring(reservePrefix.length())); - reply("MALMOOK", dos); - } - else - { - // We're busy - we can't be reserved. - reply("MALMOBUSY", dos); - } - } - else if (command.startsWith(reservePrefixGeneral)) - { - // Reservation request, but it didn't match the request we expect, above. - // This happens if the agent sending the request is running a different version of Malmo - - // a version mismatch error. - reply("MALMOERRORVERSIONMISMATCH in reservation string (Got " + command + ", expected " + reservePrefix + " - check your path for old versions of MalmoPython/MalmoJava/Malmo.lib etc)", dos); - } - else if (command.equals(cancelRequestCommand)) - { - // If we've been reserved, cancel the reservation. - if (isReserved()) - { - cancelReservation(); - reply("MALMOOK", dos); - } - else - { - // We weren't reserved in the first place - something is odd. - reply("MALMOERRORAttempt to cancel a reservation that was never made.", dos); - } - } - else if (command.startsWith(findServerPrefix)) - { - // Request to find the server for the given experiment ID. - String expID = command.substring(findServerPrefix.length()); - if (currentMissionInit() != null && currentMissionInit().getExperimentUID().equals(expID)) - { - // Our Experiment IDs match, so we are running the same experiment. - // Return the port and server IP address to the caller: - MinecraftServerConnection msc = currentMissionInit().getMinecraftServerConnection(); - if (msc == null) - reply("MALMONOSERVERYET", dos); // Mission might be starting up. - else - reply("MALMOS" + msc.getAddress().trim() + ":" + msc.getPort(), dos); - } - else - { - // We don't have a MissionInit ourselves, or we're running a different experiment, - // so we can't help. 
- reply("MALMONOSERVER", dos); - } - } - else if (command.equals(killClientCommand)) - { - // Kill switch provided in case AI takes over the world... - // Or, more likely, in case this Minecraft instance has become unreliable (eg if it's been running for several days) - // and needs to be replaced with a fresh instance. - // If we are currently running a mission, we gracefully decline, to prevent users from wiping out - // other users' experiments. - // We also decline unless we were launched in "replaceable" mode - a command-line switch that indicates we were - // launched by a script which is still running, and can therefore replace us when we terminate. - IState currentState = getStableState(); - if (currentState != null && currentState.equals(ClientState.DORMANT) && !isReserved()) - { - Configuration config = MalmoMod.instance.getModSessionConfigFile(); - if (config.getBoolean("replaceable", "runtype", false, "Will be replaced if killed")) - { - reply("MALMOOK", dos); - - missionPoller.stopServer(); - exitJava(); - } - else - { - reply("MALMOERRORNOTKILLABLE", dos); - } - } - else - { - // We're too busy and important to be killed. - reply("MALMOBUSY", dos); - } - } - else - { - // See if we've been sent a MissionInit message: - - MissionInitResult missionInitResult = decodeMissionInit(command); - - if (missionInitResult.wasMissionInit && missionInitResult.missionInit == null) - { - // Got sent a duff MissionInit xml - pass back the JAXB/SAXB errors. - reply("MALMOERROR" + missionInitResult.error, dos); - } - else if (missionInitResult.wasMissionInit && missionInitResult.missionInit != null) - { - MissionInit missionInit = missionInitResult.missionInit; - // We've been sent a MissionInit message. 
- // First, check the version number: - String platformVersion = missionInit.getPlatformVersion(); - String ourVersion = Loader.instance().activeModContainer().getVersion(); - if (platformVersion == null || !platformVersion.equals(ourVersion)) - { - reply("MALMOERRORVERSIONMISMATCH (Got " + platformVersion + ", expected " + ourVersion + " - check your path for old versions of MalmoPython/MalmoJava/Malmo.lib etc)", dos); - } - else - { - // MissionInit passed to us - this is a request to launch this mission. Can we? - IState currentState = getStableState(); - if (currentState != null && currentState.equals(ClientState.DORMANT) && isAvailable(missionInit.getExperimentUID())) - { - reply("MALMOOK", dos); - keepProcessing = true; // State machine will now process this MissionInit and start the mission. - } - else - { - // We're busy - we can't run this mission. - reply("MALMOBUSY", dos); - } - } - } - } - - return keepProcessing; - } - }; - - int mcPort = 0; - if (MalmoEnvServer.isEnv()) { - // Start up new "Env" service instead of Malmo AgentHost api. - System.out.println("***** Start MalmoEnvServer on port " + AddressHelper.getMissionControlPortOverride()); - this.envServer = new MalmoEnvServer(Loader.instance().activeModContainer().getVersion(), AddressHelper.getMissionControlPortOverride(), this.missionPoller); - Thread thread = new Thread("MalmoEnvServer") { - public void run() { - try { - envServer.serve(); - } catch (IOException ioe) { - System.out.println("MalmoEnvServer exist on " + ioe); - } - } - }; - thread.start(); - } else { - // "Legacy" AgentHost api. - this.missionPoller.start(); - mcPort = ClientStateMachine.this.missionPoller.getPortBlocking(); - } - - // Tell the address helper what the actual port is: - AddressHelper.setMissionControlPort(mcPort); - if (AddressHelper.getMissionControlPort() == -1) - { - // Failed to create a mission control port - nothing will work! - System.out.println("**** NO MISSION CONTROL SOCKET CREATED - WAS THE PORT IN USE? 
(Check Mod GUI options) ****"); - ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Could not open a Mission Control Port - check the Mod GUI options.", TextCategory.TXT_CLIENT_WARNING, MISSING_MCP_PORT_ERROR); - } - else - { - // Clear the error string, if there was one: - ClientStateMachine.this.getScreenHelper().clearFragment(MISSING_MCP_PORT_ERROR); - } - // Display the port number: - ClientStateMachine.this.getScreenHelper().clearFragment(INFO_MCP_PORT); - if (AddressHelper.getMissionControlPort() != -1) - ClientStateMachine.this.getScreenHelper().addFragment("MCP: " + AddressHelper.getMissionControlPort(), TextCategory.TXT_INFO, INFO_MCP_PORT); - } - - public static void exitJava() { - // Give non-hard exit 10 seconds to complete and force a hard exit. - Thread deadMansHandle = new Thread(new Runnable() { - @Override - public void run() { - for (int i = 10; i > 0; i--) { - try { - Thread.sleep(1000); - System.out.println("Waiting to exit " + i + "..."); - } catch (InterruptedException e) { - System.out.println("Interrupted " + i + "..."); - } - } - - // Kill it with fire!!! - System.out.println("Attempting hard exit"); - FMLCommonHandler.instance().exitJava(0, true); - } - }); - - deadMansHandle.setDaemon(true); - deadMansHandle.start(); - - // Have to use FMLCommonHandler; direct calls to System.exit() are trapped and denied by the FML code. 
- FMLCommonHandler.instance().exitJava(0, false); - } - - // --------------------------------------------------------------------------------------------------------- - // Episode helpers - each extends a MissionStateEpisode to encapsulate a certain state - // --------------------------------------------------------------------------------------------------------- - - public abstract class ErrorAwareEpisode extends StateEpisode implements IMalmoMessageListener - { - protected Boolean errorFlag = false; - protected Map errorData = null; - - public ErrorAwareEpisode(ClientStateMachine machine) - { - super(machine); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_ABORT); - } - - protected boolean pingAgent(boolean abortIfFailed) - { - if (AddressHelper.getMissionControlPort() == 0) { - // MalmoEnvServer has no server to client ping. - return true; - } - - boolean sentOkay = ClientStateMachine.this.getMissionControlSocket().sendTCPString("", 1); - if (!sentOkay) - { - // It's not available - bail. 
- ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Lost contact with agent - aborting mission", TextCategory.TXT_CLIENT_WARNING, 10000); - if (abortIfFailed) - episodeHasCompletedWithErrors(ClientState.ERROR_LOST_AGENT, "Lost contact with the agent"); - } - return sentOkay; - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - if (messageType == MalmoMod.MalmoMessageType.SERVER_ABORT) - { - synchronized (this.errorFlag) - { - this.errorFlag = true; - this.errorData = data; - // Save the error message, if there is one: - if (data != null) - { - String message = data.get("message"); - String user = data.get("username"); - String error = data.get("error"); - String report = ""; - if (user != null) - report += "From " + user + ": "; - if (error != null) - report += error; - if (message != null) - report += " (" + message + ")"; - ClientStateMachine.this.saveErrorDetails(report); - } - onAbort(data); - } - } - } - - @Override - public void cleanup() - { - super.cleanup(); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, MalmoMessageType.SERVER_ABORT); - } - - protected boolean inAbortState() - { - synchronized (this.errorFlag) - { - return this.errorFlag; - } - } - - protected Map getErrorData() - { - synchronized (this.errorFlag) - { - return this.errorData; - } - } - - protected void onAbort(Map errorData) - { - // Default does nothing, but can be overridden. - } - } - - /** - * Helper base class that responds to the config change and updates our AddressHelper.
- * This will also reset the mission poller. Depending on the state, more - * work may be needed (eg to recreate the command handler, etc) - it's up to - * the individual state episodes to do whatever else needs doing. - */ - abstract public class ConfigAwareStateEpisode extends ErrorAwareEpisode - { - ConfigAwareStateEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - public void onConfigChanged(OnConfigChangedEvent ev) - { - if (ev.getConfigID().equals(MalmoMod.SOCKET_CONFIGS)) - { - AddressHelper.update(MalmoMod.instance.getModSessionConfigFile()); - try - { - ClientStateMachine.this.initialiseComms(); - } - catch (UnknownHostException e) - { - // TODO What to do here? - e.printStackTrace(); - } - ScreenHelper.update(MalmoMod.instance.getModPermanentConfigFile()); - TCPUtils.update(MalmoMod.instance.getModPermanentConfigFile()); - } - } - } - - /** Initial episode - perform client setup */ - public class InitialiseClientModEpisode extends ConfigAwareStateEpisode - { - InitialiseClientModEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() throws Exception - { - ClientStateMachine.this.initialiseComms(); - - // This is necessary in order to allow user to exit the Minecraft window without halting the experiment: - GameSettings settings = Minecraft.getMinecraft().gameSettings; - settings.pauseOnLostFocus = false; - // And hook the screen helper into the ingame gui (which is responsible for overlaying chat, titles etc) - - // this has to be done after Minecraft.init(), so we do it here. - ScreenHelper.hookIntoInGameGui(); - } - - @Override - public void onRenderTick(TickEvent.RenderTickEvent ev) - { - // We wait until we start to get render ticks, at which point we assume Minecraft has finished starting up. 
- episodeHasCompleted(ClientState.DORMANT); - } - } - - // --------------------------------------------------------------------------------------------------------- - /** Dormant state - receptive to new missions */ - public class DormantEpisode extends ConfigAwareStateEpisode - { - private ClientStateMachine csMachine; - - protected DormantEpisode(ClientStateMachine machine) - { - super(machine); - this.csMachine = machine; - } - - @Override - protected void execute() - { - TextureHelper.init(); - - // Clear our current MissionInit state: - csMachine.currentMissionInit = null; - // Clear our current error state: - clearErrorDetails(); - // And clear out any stale commands left over from recent missions: - if (ClientStateMachine.this.controlInputPoller != null) - ClientStateMachine.this.controlInputPoller.clearCommands(); - // Finally, do some Java housekeeping: - System.gc(); - } - - @Override - public void onClientTick(TickEvent.ClientTickEvent ev) throws Exception - { - - Minecraft.getMinecraft().mcProfiler.startSection("malmoHandleMissionCommands"); - checkForMissionCommand(); - Minecraft.getMinecraft().mcProfiler.endSection(); - } - - private void checkForMissionCommand() throws Exception - { - if (ClientStateMachine.this.missionPoller == null) - return; - - CommandAndIPAddress comip = missionPoller.getCommandAndIPAddress(); - if (comip == null) - return; - String missionMessage = comip.command; - if (missionMessage == null || missionMessage.length() == 0) - return; - Minecraft.getMinecraft().mcProfiler.endSection(); - Minecraft.getMinecraft().mcProfiler.startSection("malmoDecodeMissionInit"); - - MissionInitResult missionInitResult = decodeMissionInit(missionMessage); - Minecraft.getMinecraft().mcProfiler.endSection(); - - MissionInit missionInit = missionInitResult.missionInit; - if (missionInit != null) - { - missionInit.getClientAgentConnection().setAgentIPAddress(comip.ipAddress); - System.out.println("Mission received: " + 
missionInit.getMission().getAbout().getSummary()); - csMachine.currentMissionInit = missionInit; - TimeHelper.SyncManager.numTicks = 0; - ScoreHelper.logMissionInit(missionInit); - - ClientStateMachine.this.createMissionControlSocket(); - // Move on to next state: - episodeHasCompleted(ClientState.CREATING_HANDLERS); - } - else - { - throw new Exception("Failed to get valid MissionInit object from SchemaHelper."); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Now the MissionInit XML has been decoded, the client needs to create the - * Mission Handlers. - */ - public class CreateHandlersEpisode extends ConfigAwareStateEpisode - { - protected CreateHandlersEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() throws Exception - { - // First, clear our reservation state, if we were reserved: - ClientStateMachine.this.cancelReservation(); - - // Now try creating the handlers: - try - { - if(envServer != null){ - SeedHelper.advanceNextSeed(envServer.getSeed()); - } - ClientStateMachine.this.missionBehaviour = MissionBehaviour.createAgentHandlersFromMissionInit(currentMissionInit()); - if (envServer != null) { - ClientStateMachine.this.missionBehaviour.addQuitProducer(envServer); - } - } - catch (Exception e) - { - // TODO - System.err.println("ERROR: Exception caught making agent handlers" + e.toString()); - e.printStackTrace(); - } - // Set up our command input poller. This is only checked during the MissionRunning episode, but - // it needs to be started now, so we can report the port it's using back to the agent. - TCPUtils.LogSection ls = new TCPUtils.LogSection("Initialise Command Input Poller"); - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - int requestedPort = cac.getClientCommandsPort(); - // If the requested port is 0, we dynamically allocate our own port, and feed that back to the agent. 
- // If the requested port is non-zero, we have to use it. - if (requestedPort != 0 && ClientStateMachine.this.controlInputPoller != null && ClientStateMachine.this.controlInputPoller.getPort() != requestedPort) - { - // A specific port has been requested, and it's not the one we are currently using, - // so we need to recreate our poller. - System.out.println("Requested command port is not the same as the input poller port; the port was not free. Stopping server."); - ClientStateMachine.this.controlInputPoller.stopServer(); - ClientStateMachine.this.controlInputPoller = null; - } - if (ClientStateMachine.this.controlInputPoller == null) - { - if (requestedPort == 0) - ClientStateMachine.this.controlInputPoller = new TCPInputPoller(AddressHelper.MIN_FREE_PORT, AddressHelper.MAX_FREE_PORT, true, "com"); - else - ClientStateMachine.this.controlInputPoller = new TCPInputPoller(requestedPort, "com"); - System.out.println("Starting command server."); - ClientStateMachine.this.controlInputPoller.start(); - } - // Make sure the cac is up-to-date: - cac.setClientCommandsPort(ClientStateMachine.this.controlInputPoller.getPortBlocking()); - ls.close(); - - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - // Set the agent's name as the current username: - List agents = currentMissionInit().getMission().getAgentSection(); - String agentName = agents.get(currentMissionInit().getClientRole()).getName(); - AuthenticationHelper.setPlayerName(Minecraft.getMinecraft().getSession(), agentName); - // If the player's profile properties are empty, MC will keep pinging the Minecraft session service - // to fill them, resulting in multiple http requests and grumpy responses from the server - // (see https://github.com/Microsoft/malmo/issues/568). - // To prevent this, we add a dummy property. 
- Minecraft.getMinecraft().getProfileProperties().put("dummy", new Property("dummy", "property")); - // Handlers and poller created successfully; proceed to next stage of loading. - // We will either need to connect to an existing server, or to start - // a new integrated server ourselves, depending on our role. - // For now, assume that the mod with role 0 is responsible for the server. - if (currentMissionInit().getClientRole() == 0) - { - // We are responsible for the server - investigate what needs to happen next: - episodeHasCompleted(ClientState.EVALUATING_WORLD_REQUIREMENTS); - } - else - { - // We may need to connect to a server. - episodeHasCompleted(ClientState.WAITING_FOR_SERVER_READY); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Attempt to connect to a server. Wait until connection is established. - */ - public class WaitingForServerEpisode extends ConfigAwareStateEpisode - { - String agentName; - int ticksUntilNextPing = 0; - int totalTicks = 0; - boolean waitingForChunk = false; - boolean waitingForPlayer = true; - - protected WaitingForServerEpisode(ClientStateMachine machine) - { - super(machine); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_ALLPLAYERSJOINED); - } - - private boolean isChunkReady() - { - // First, find the starting position we ought to have: - List agents = currentMissionInit().getMission().getAgentSection(); - if (agents == null || agents.size() <= currentMissionInit().getClientRole()) - return true; // This should never happen. 
- AgentSection as = agents.get(currentMissionInit().getClientRole()); - if (as.getAgentStart() != null && as.getAgentStart().getPlacement() != null) - { - PosAndDirection pos = as.getAgentStart().getPlacement(); - int x = MathHelper.floor(pos.getX().doubleValue()) >> 4; - int z = MathHelper.floor(pos.getZ().doubleValue()) >> 4; - // Now get the chunk we should be starting in: - IChunkProvider chunkprov = Minecraft.getMinecraft().world.getChunkProvider(); - EntityPlayerSP player = Minecraft.getMinecraft().player; - if (player.addedToChunk) - { - // Our player is already added to a chunk - is it the right one? - Chunk actualChunk = chunkprov.provideChunk(player.chunkCoordX, player.chunkCoordZ); - Chunk requestedChunk = chunkprov.provideChunk(x, z); - if (actualChunk == requestedChunk && actualChunk != null && !actualChunk.isEmpty()) - { - // We're in the right chunk, and it's not an empty chunk. - // We're ready to proceed, but first set our client positions to where we ought to be. - // The server should be doing this too, but there's no harm (probably) in doing it ourselves. - player.posX = pos.getX().doubleValue(); - player.posY = pos.getY().doubleValue(); - player.posZ = pos.getZ().doubleValue(); - return true; - } - } - return false; // Our starting position has been specified, but it's not yet ready. - } - return true; // No starting position specified, so doesn't matter where we start. - } - - @Override - protected void onClientTick(ClientTickEvent ev) - { - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - if (this.waitingForPlayer) - { - if (Minecraft.getMinecraft().player != null) - { - this.waitingForPlayer = false; - handleLan(); - } - else - return; - } - - totalTicks++; - - if (ticksUntilNextPing == 0) - { - // Tell the server what our agent name is. - // We do this repeatedly, because the server might not yet be listening. 
- if (Minecraft.getMinecraft().player != null && !this.waitingForChunk) - { - HashMap map = new HashMap(); - map.put("agentname", agentName); - map.put("username", Minecraft.getMinecraft().player.getName()); - currentMissionBehaviour().appendExtraServerInformation(map); - System.out.println("***Telling server we are ready - " + agentName); - MalmoMod.network.sendToServer(new MalmoMod.MalmoMessage(MalmoMessageType.CLIENT_AGENTREADY, 0, map)); - } - - // We also ping our agent, just to check it is still available: - pingAgent(true); // Will abort to an error state if client unavailable. - - ticksUntilNextPing = 10; // Try again in ten ticks. - } - else - { - ticksUntilNextPing--; - } - - if (this.waitingForChunk) - { - // The server is ready, we're just waiting for our chunk to appear. - if (isChunkReady()) - proceed(); - } - - List agents = currentMissionInit().getMission().getAgentSection(); - boolean completedWithErrors = false; - - if (agents.size() > 1 && currentMissionInit().getClientRole() != 0) - { - // We are waiting to join an out-of-process server. Need to pay attention to what happens - - // if we can't join, for any reason, we should abort the mission. - GuiScreen screen = Minecraft.getMinecraft().currentScreen; - if (screen != null && screen instanceof GuiDisconnected) { - // Disconnected screen appears when something has gone wrong. - // Would be nice to grab the reason from the screen, but it's a private member. - // (Can always use reflection, but it's so inelegant.) 
- String msg = "Unable to connect to Minecraft server in multi-agent mission."; - TCPUtils.Log(Level.SEVERE, msg); - episodeHasCompletedWithErrors(ClientState.ERROR_CANNOT_CONNECT_TO_SERVER, msg); - completedWithErrors = true; - } - } - - if (!completedWithErrors && totalTicks > WAIT_MAX_TICKS) - { - String msg = "Too long waiting for server episode to start."; - TCPUtils.Log(Level.SEVERE, msg); - // AOG - If we have timed out waiting for the server to be ready, then the - // MalmoEnvServer is also likely stuck trying to handle a peek request from - // Python client. We need to signal the env server should abort the request - // so that the client detects the error and can retry. - if (envServer != null) { - envServer.abort(); - } - episodeHasCompletedWithErrors(ClientState.ERROR_TIMED_OUT_WAITING_FOR_EPISODE_START, msg); - } - } - - @Override - protected void execute() throws Exception - { - totalTicks = 0; - - Minecraft.getMinecraft().displayGuiScreen(null); // Clear any menu screen that might confuse things. - // Get our name from the Mission: - List agents = currentMissionInit().getMission().getAgentSection(); - //if (agents == null || agents.size() <= currentMissionInit().getClientRole()) - // throw new Exception("No agent section for us!"); // TODO - this.agentName = agents.get(currentMissionInit().getClientRole()).getName(); - - if (agents.size() > 1 && currentMissionInit().getClientRole() != 0) - { - // Multi-agent mission, we should be joining a server. - // (Unless we are already on the correct server.) 
- String address = currentMissionInit().getMinecraftServerConnection().getAddress().trim(); - int port = currentMissionInit().getMinecraftServerConnection().getPort(); - String targetIP = address + ":" + port; - System.out.println("We should be joining " + targetIP); - EntityPlayerSP player = Minecraft.getMinecraft().player; - boolean namesMatch = (player == null) || Minecraft.getMinecraft().player.getName().equals(this.agentName); - if (!namesMatch) - { - // The name of our agent no longer matches the agent in our game profile - - // safest way to update is to log out and back in again. - // This hangs so just warn instead about the miss-match and proceed. - TCPUtils.Log(Level.WARNING,"Agent name does not match agent in game."); - // Minecraft.getMinecraft().world.sendQuittingDisconnectingPacket(); - // Minecraft.getMinecraft().loadWorld((WorldClient)null); - } - if (Minecraft.getMinecraft().getCurrentServerData() == null || !Minecraft.getMinecraft().getCurrentServerData().serverIP.equals(targetIP)) - { - net.minecraftforge.fml.client.FMLClientHandler.instance().connectToServerAtStartup(address, port); - } - this.waitingForPlayer = false; - } - } - - protected void handleLan() - { - // Get our name from the Mission: - List agents = currentMissionInit().getMission().getAgentSection(); - //if (agents == null || agents.size() <= currentMissionInit().getClientRole()) - // throw new Exception("No agent section for us!"); // TODO - this.agentName = agents.get(currentMissionInit().getClientRole()).getName(); - - if (agents.size() > 1 && currentMissionInit().getClientRole() == 0) // Multi-agent mission - make sure the server is open to the LAN: - { - MinecraftServerConnection msc = new MinecraftServerConnection(); - String address = currentMissionInit().getClientAgentConnection().getClientIPAddress(); - // Do we need to open to LAN? 
- if (Minecraft.getMinecraft().isSingleplayer() && !Minecraft.getMinecraft().getIntegratedServer().getPublic()) - { - String portstr = Minecraft.getMinecraft().getIntegratedServer().shareToLAN(GameType.SURVIVAL, true); // Set to true to stop spam kicks. - ClientStateMachine.this.integratedServerPort = Integer.valueOf(portstr); - } - - TCPUtils.Log(Level.INFO,"Integrated server port: " + ClientStateMachine.this.integratedServerPort); - msc.setPort(ClientStateMachine.this.integratedServerPort); - msc.setAddress(address); - - if (envServer != null) { - envServer.notifyIntegrationServerStarted(ClientStateMachine.this.integratedServerPort); - } - currentMissionInit().setMinecraftServerConnection(msc); - } - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - super.onMessage(messageType, data); - - if (messageType != MalmoMessageType.SERVER_ALLPLAYERSJOINED) - return; - - List handlers = new ArrayList(); - for (Entry entry : data.entrySet()) - { - if (entry.getKey().equals("startPosition")) - { - try - { - String[] parts = entry.getValue().split(":"); - Float x = Float.valueOf(parts[0]); - Float y = Float.valueOf(parts[1]); - Float z = Float.valueOf(parts[2]); - // Find the starting position we ought to have: - List agents = currentMissionInit().getMission().getAgentSection(); - if (agents != null && agents.size() > currentMissionInit().getClientRole()) - { - // And write this new position into it: - AgentSection as = agents.get(currentMissionInit().getClientRole()); - AgentStart startSection = as.getAgentStart(); - if (startSection != null) - { - PosAndDirection pos = startSection.getPlacement(); - if (pos == null) - pos = new PosAndDirection(); - pos.setX(new BigDecimal(x)); - pos.setY(new BigDecimal(y)); - pos.setZ(new BigDecimal(z)); - startSection.setPlacement(pos); - as.setAgentStart(startSection); - } - } - } - catch (Exception e) - { - System.out.println("Couldn't interpret position data"); - } - } - else - { - String 
extraHandler = entry.getValue(); - if (extraHandler != null && extraHandler.length() > 0) - { - try - { - Class handlerClass = Class.forName(entry.getKey()); - Object handler = SchemaHelper.deserialiseObject(extraHandler, "MissionInit.xsd", handlerClass); - handlers.add(handler); - } - catch (Exception e) - { - System.out.println("Error trying to create extra handlers: " + e); - // Do something... like episodeHasCompletedWithErrors(nextState, error)? - } - } - } - } - if (!handlers.isEmpty()) - currentMissionBehaviour().addExtraHandlers(handlers); - this.waitingForChunk = true; - } - - private void proceed() - { - // The server is ready, so send our MissionInit back to the agent and go! - // We launch the agent by sending it the MissionInit message we were sent - // (but with the Launcher's IP address included) - String xml = null; - boolean sentOkay = false; - String errorReport = ""; - try - { - xml = SchemaHelper.serialiseObject(currentMissionInit(), MissionInit.class); - if (AddressHelper.getMissionControlPort() == 0) { - if (envServer != null) { - // TODO MalmoEnvServer <- Running - } - sentOkay = true; - } else { - sentOkay = ClientStateMachine.this.getMissionControlSocket().sendTCPString(xml, 1); - } - } - catch (JAXBException e) - { - errorReport = e.getMessage(); - } - if (sentOkay) - episodeHasCompleted(ClientState.RUNNING); - else - { - ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Could not contact agent to start mission - mission will abort.", TextCategory.TXT_CLIENT_WARNING, 10000); - if (!errorReport.isEmpty()) - { - ClientStateMachine.this.getScreenHelper().addFragment("ERROR DETAILS: " + errorReport, TextCategory.TXT_CLIENT_WARNING, 10000); - errorReport = ": " + errorReport; - } - episodeHasCompletedWithErrors(ClientState.ERROR_CANNOT_START_AGENT, "Failed to send MissionInit back to agent" + errorReport); - } - } - - @Override - public void cleanup() - { - super.cleanup(); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, 
MalmoMessageType.SERVER_ALLPLAYERSJOINED); - } - } - - /** - * Wait for the server to decide the mission has ended.
- * We're not allowed to return to dormant until the server decides everyone can. - */ - public class WaitingForServerMissionEndEpisode extends ConfigAwareStateEpisode - { - protected WaitingForServerMissionEndEpisode(ClientStateMachine machine) - { - super(machine); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_MISSIONOVER); - } - - @Override - protected void execute() throws Exception - { - // Get our name from the Mission: - List agents = currentMissionInit().getMission().getAgentSection(); - if (agents == null || agents.size() <= currentMissionInit().getClientRole()) - throw new Exception("No agent section for us!"); // TODO - String agentName = agents.get(currentMissionInit().getClientRole()).getName(); - - // Now send a message to the server saying that we are ready: - HashMap map = new HashMap(); - map.put("agentname", agentName); - MalmoMod.network.sendToServer(new MalmoMod.MalmoMessage(MalmoMessageType.CLIENT_AGENTSTOPPED, 0, map)); - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - super.onMessage(messageType, data); - if (messageType == MalmoMessageType.SERVER_MISSIONOVER) - episodeHasCompleted(ClientState.DORMANT); - } - - @Override - public void cleanup() - { - super.cleanup(); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, MalmoMessageType.SERVER_MISSIONOVER); - } - - @Override - protected void onAbort(Map errorData) - { - episodeHasCompleted(ClientState.MISSION_ABORTED); - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Depending on the basemap provided, either begin to perform a full world - * load, or reset the current world - */ - public class EvaluateWorldRequirementsEpisode extends ConfigAwareStateEpisode - { - EvaluateWorldRequirementsEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() - { - // We are responsible for creating the 
server, if required. - // This means we need access to the server's MissionHandlers: - MissionBehaviour serverHandlers = null; - try - { - serverHandlers = MissionBehaviour.createServerHandlersFromMissionInit(currentMissionInit()); - } - catch (Exception e) - { - episodeHasCompletedWithErrors(ClientState.ERROR_DUFF_HANDLERS, "Could not create server mission handlers: " + e.getMessage()); - } - - World world = null; - if (Minecraft.getMinecraft().getIntegratedServer() != null) - world = Minecraft.getMinecraft().getIntegratedServer().getEntityWorld(); - - boolean needsNewWorld = serverHandlers != null && serverHandlers.worldGenerator != null && serverHandlers.worldGenerator.shouldCreateWorld(currentMissionInit(), world); - boolean worldCurrentlyExists = world != null; - if (worldCurrentlyExists) - { - // If a world already exists, we need to check that our requested agent name matches the name - // of the player. If not, the safest thing to do is start a new server. - // Get our name from the Mission: - List agents = currentMissionInit().getMission().getAgentSection(); - String agentName = agents.get(currentMissionInit().getClientRole()).getName(); - if (Minecraft.getMinecraft().player != null) - { - if (!Minecraft.getMinecraft().player.getName().equals(agentName)) - needsNewWorld = true; - } - } - if (needsNewWorld && worldCurrentlyExists) - { - // We want a new world, and there is currently a world running, - // so we need to kill the current world. 
- episodeHasCompleted(ClientState.PAUSING_OLD_SERVER); - } - else if (needsNewWorld && !worldCurrentlyExists) - { - // We want a new world, and there is currently nothing running, - // so jump to world creation: - episodeHasCompleted(ClientState.CREATING_NEW_WORLD); - } - else if (!needsNewWorld && worldCurrentlyExists) - { - // We don't want a new world, and we can use the current one - - // but we own the server, so we need to pass it the new mission init: - Minecraft.getMinecraft().getIntegratedServer().addScheduledTask(new Runnable() - { - @Override - public void run() - { - try - { - MalmoMod.instance.sendMissionInitDirectToServer(currentMissionInit); - } - catch (Exception e) - { - episodeHasCompletedWithErrors(ClientState.ERROR_INTEGRATED_SERVER_UNREACHABLE, "Could not send MissionInit to our integrated server: " + e.getMessage()); - } - } - }); - // Skip all the map loading stuff and go straight to waiting for the server: - episodeHasCompleted(ClientState.WAITING_FOR_SERVER_READY); - } - else if (!needsNewWorld && !worldCurrentlyExists) - { - // Mission has requested no new world, but there is no current world to play in - this is an error: - episodeHasCompletedWithErrors(ClientState.ERROR_NO_WORLD, "We have no world to play in - check that your ServerHandlers section contains a world generator"); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Pause the old server. It's vital that we do this, otherwise it will - * respond to the quit disconnect package straight away and kill the server - * thread, which means there will be no server to respond to the loadWorld - * code. 
- */ - public class PauseOldServerEpisode extends ConfigAwareStateEpisode - { - int serverTickCount = 0; - int clientTickCount = 0; - int totalTicks = 0; - - PauseOldServerEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() - { - serverTickCount = 0; - clientTickCount = 0; - totalTicks = 0; - - if (Minecraft.getMinecraft().getIntegratedServer() != null && Minecraft.getMinecraft().world != null) - { - // If the integrated server has been opened to the LAN, we won't be able to pause it. - // To get around this, we need to make it think it's not open, by modifying its isPublic flag. - if (Minecraft.getMinecraft().getIntegratedServer().getPublic()) - { - if (!killPublicFlag(Minecraft.getMinecraft().getIntegratedServer())) - { - episodeHasCompletedWithErrors(ClientState.ERROR_CANNOT_CREATE_WORLD, "Can not pause the old server since it's open to LAN; no way to safely create new world."); - } - } - - Minecraft.getMinecraft().displayGuiScreen(new GuiIngameMenu()); - } - } - - private boolean killPublicFlag(IntegratedServer server) - { - // Are we in a dev environment? - boolean devEnv = (Boolean) Launch.blackboard.get("fml.deobfuscatedEnvironment"); - // We need to know, because the member name will either be obfuscated or not. - String isPublicMemberName = devEnv ? "isPublic" : "field_71346_p"; - // NOTE: obfuscated name may need updating if Forge changes. 
- Field isPublic; - try - { - isPublic = IntegratedServer.class.getDeclaredField(isPublicMemberName); - isPublic.setAccessible(true); - isPublic.set(server, false); - return true; - } - catch (SecurityException e) - { - e.printStackTrace(); - } - catch (IllegalAccessException e) - { - e.printStackTrace(); - } - catch (IllegalArgumentException e) - { - e.printStackTrace(); - } - catch (NoSuchFieldException e) - { - e.printStackTrace(); - } - return false; - } - - @Override - public void onClientTick(TickEvent.ClientTickEvent ev) - { - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - // We need to make sure that both the client and server have paused. - - // Since the server sets its pause state in response to the client's pause state, - // and it only performs this check once, at the top of its tick method, - // to be sure that the server has had time to set the flag correctly we need to make sure - // that at least one server tick method has *started* since the flag was set. - // We can't do this by catching the onServerTick events, since we don't receive them when the game is paused. - - // The following code makes use of the fact that the server both locks and empties the server's futureQueue, - // every time through the server tick method. - // This locking means that if the client - which needs to wait on the lock - - // tries to add an event to the queue in response to an event on the queue being executed, - // the newly added event will have to happen in a subsequent tick. - if ((Minecraft.getMinecraft().isGamePaused() || Minecraft.getMinecraft().player == null) && ev != null && ev.phase == Phase.END && this.clientTickCount == this.serverTickCount && this.clientTickCount <= 2) - { - this.clientTickCount++; // Increment our count, and wait for the server to catch up. 
- Minecraft.getMinecraft().getIntegratedServer().addScheduledTask(new Runnable() - { - public void run() - { - // Increment the server count. - PauseOldServerEpisode.this.serverTickCount++; - } - }); - } - - if (this.serverTickCount > 2) { - episodeHasCompleted(ClientState.CLOSING_OLD_SERVER); - } else if (++totalTicks > WAIT_MAX_TICKS) { - String msg = "Too long waiting for server episode to pause."; - TCPUtils.Log(Level.SEVERE, msg); - episodeHasCompletedWithErrors(ClientState.ERROR_TIMED_OUT_WAITING_FOR_EPISODE_PAUSE, msg); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Send a disconnecting message to the current server - sent before attempting to load a new world. - */ - public class CloseOldServerEpisode extends ConfigAwareStateEpisode - { - int totalTicks; - - CloseOldServerEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() - { - totalTicks = 0; - - if (Minecraft.getMinecraft().world != null) - { - // If the Minecraft server isn't paused at this point, - // then the following line will cause the server thread to exit... - Minecraft.getMinecraft().world.sendQuittingDisconnectingPacket(); - // ...in which case the next line will block. - Minecraft.getMinecraft().loadWorld((WorldClient) null); - // Must display the GUI or Minecraft will attempt to access a non-existent player in the client tick. - Minecraft.getMinecraft().displayGuiScreen(new GuiMainMenu()); - - // Allow shutdown messages to flow through. - try { - Thread.sleep(10000); - } catch (InterruptedException ie) { - } - } - } - - @Override - public void onClientTick(TickEvent.ClientTickEvent ev) - { - // Check to see whether anything has caused us to abort - if so, go to the abort state. 
- if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - if (ev.phase == Phase.END) - episodeHasCompleted(ClientState.CREATING_NEW_WORLD); - - if (++totalTicks > WAIT_MAX_TICKS) - { - String msg = "Too long waiting for server episode to close."; - TCPUtils.Log(Level.SEVERE, msg); - episodeHasCompletedWithErrors(ClientState.ERROR_TIMED_OUT_WAITING_FOR_EPISODE_CLOSE, msg); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * Attempt to create a world. - */ - public class CreateWorldEpisode extends ConfigAwareStateEpisode - { - boolean serverStarted = false; - boolean worldCreated = false; - int totalTicks = 0; - - CreateWorldEpisode(ClientStateMachine machine) - { - super(machine); - } - - @Override - protected void execute() - { - try - { - totalTicks = 0; - - // We need to use the server's MissionHandlers here: - MissionBehaviour serverHandlers = MissionBehaviour.createServerHandlersFromMissionInit(currentMissionInit()); - if (serverHandlers != null && serverHandlers.worldGenerator != null) - { - if (serverHandlers.worldGenerator.createWorld(currentMissionInit())) - { - this.worldCreated = true; - if (Minecraft.getMinecraft().getIntegratedServer() != null) - Minecraft.getMinecraft().getIntegratedServer().setOnlineMode(false); - } - else - { - // World has not been created. 
- episodeHasCompletedWithErrors(ClientState.ERROR_CANNOT_CREATE_WORLD, "Server world-creation handler failed to create a world: " + serverHandlers.worldGenerator.getErrorDetails()); - } - } - } - catch (Exception e) - { - episodeHasCompletedWithErrors(ClientState.ERROR_CANNOT_CREATE_WORLD, "Server world-creation handler failed to create a world: " + e.getMessage()); - } - } - - @Override - protected void onServerTick(ServerTickEvent ev) - { - if (this.worldCreated && !this.serverStarted) - { - // The server has started ticking - we can set up its state machine, - // and move on to the next state in our own machine. - this.serverStarted = true; - MalmoMod.instance.initIntegratedServer(currentMissionInit()); // Needs to be done from the server thread. - episodeHasCompleted(ClientState.WAITING_FOR_SERVER_READY); - } - } - - @Override - public void onClientTick(TickEvent.ClientTickEvent ev) - { - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - if (++totalTicks > WAIT_MAX_TICKS) - { - String msg = "Too long waiting for world to be created."; - TCPUtils.Log(Level.SEVERE, msg); - episodeHasCompletedWithErrors(ClientState.ERROR_TIMED_OUT_WAITING_FOR_WORLD_CREATE, msg); - } - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * State in which an agent has finished the mission, but is waiting for the server to draw stumps. 
- */ - public class MissionIdlingEpisode extends ConfigAwareStateEpisode - { - int totalTicks = 0; - - protected MissionIdlingEpisode(ClientStateMachine machine) - { - super(machine); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_STOPAGENTS); - } - - @Override - protected void execute() - { - totalTicks = 0; - TimeHelper.SyncManager.numTicks = 0; - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - super.onMessage(messageType, data); - // This message will be sent to us once the server has decided the mission is over. - if (messageType == MalmoMessageType.SERVER_STOPAGENTS) - episodeHasCompleted(ClientState.MISSION_ENDED); - } - - @Override - public void cleanup() - { - super.cleanup(); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, MalmoMessageType.SERVER_STOPAGENTS); - } - - @Override - public void onClientTick(TickEvent.ClientTickEvent ev) - { - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - episodeHasCompleted(ClientState.MISSION_ABORTED); - - ++totalTicks; - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * State in which a mission is running.
- * This state is ended by the death of the player or by the IWantToQuit - * handler, or by the server declaring the mission is over. - */ - public class MissionRunningEpisode extends ConfigAwareStateEpisode implements VideoProducedObserver - { - public static final int FailedTCPSendCountTolerance = 3; // Number of TCP timeouts before we cancel the mission - - protected MissionRunningEpisode(ClientStateMachine machine) - { - super(machine); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_STOPAGENTS); - MalmoMod.MalmoMessageHandler.registerForMessage(this, MalmoMessageType.SERVER_GO); - } - - boolean serverHasFiredStartingPistol = false; - boolean playerDied = false; - private int failedTCPRewardSendCount = 0; - private int failedTCPObservationSendCount = 0; - private boolean wantsToQuit = false; // We have decided our mission is at an end - private List videoHooks = new ArrayList(); - private String quitCode = ""; - private TCPSocketChannel observationSocket = null; - private TCPSocketChannel rewardSocket = null; - private long lastPingSent = 0; - private long pingFrequencyMs = 1000; - private boolean shouldMissionEnd = false; - private long frameTimestamp = 0; - - public void frameProduced() { - this.frameTimestamp = System.currentTimeMillis(); - } - - protected void onMissionStarted() - { - frameTimestamp = 0; - - // Open our communication channels: - openSockets(); - - this.shouldMissionEnd = false; - // Tell the server we have started: - HashMap map = new HashMap(); - map.put("username", Minecraft.getMinecraft().player.getName()); - MalmoMod.network.sendToServer(new MalmoMod.MalmoMessage(MalmoMessageType.CLIENT_AGENTRUNNING, 0, map)); - - // Set up our mission handlers: - if (currentMissionBehaviour().commandHandler != null) - { - currentMissionBehaviour().commandHandler.install(currentMissionInit()); - currentMissionBehaviour().commandHandler.setOverriding(true); - } - - if (currentMissionBehaviour().observationProducer != null) 
- currentMissionBehaviour().observationProducer.prepare(currentMissionInit()); - - if (currentMissionBehaviour().quitProducer != null) - currentMissionBehaviour().quitProducer.prepare(currentMissionInit()); - - if (currentMissionBehaviour().rewardProducer != null) - currentMissionBehaviour().rewardProducer.prepare(currentMissionInit()); - - if (currentMissionBehaviour().performanceProducer != null) - currentMissionBehaviour().performanceProducer.prepare(currentMissionInit()); - - // Disable the gui for the episode! - Minecraft.getMinecraft().gameSettings.hideGUI = true; - - for (IVideoProducer videoProducer : currentMissionBehaviour().videoProducers) - { - VideoHook hook = new VideoHook(); - this.videoHooks.add(hook); - frameProduced(); - hook.start(currentMissionInit(), videoProducer, this, envServer); - } - - // Make sure we have mouse control: - ClientStateMachine.this.inputController.setInputType(InputType.AI); - Minecraft.getMinecraft().inGameHasFocus = true; // Otherwise auto-repeat won't work for mouse clicks. - - // Overclocking: - ModSettings modsettings = currentMissionInit().getMission().getModSettings(); - if (modsettings != null && modsettings.getMsPerTick() != null) - TimeHelper.setMinecraftClientClockSpeed(1000 / modsettings.getMsPerTick()); - if (modsettings != null && modsettings.isPrioritiseOffscreenRendering() == Boolean.TRUE) - TimeHelper.displayGranularityMs = 1000; - TimeHelper.unpause(); - - // Synchronization - if (envServer != null){ - if(!envServer.doIWantToQuit(currentMissionInit())){ - TimeHelper.SyncManager.setSynchronous(envServer.isSynchronous()); - } else { - TimeHelper.SyncManager.setSynchronous(false); - } - } - } - - protected void onMissionEnded(IState nextState, String errorReport) - { - //Send the final data associated with the misson here. 
- this.shouldMissionEnd = false; - sendData(true); - - // Tidy up our mission handlers: - if (currentMissionBehaviour().rewardProducer != null) - currentMissionBehaviour().rewardProducer.cleanup(); - - if (currentMissionBehaviour().quitProducer != null) - currentMissionBehaviour().quitProducer.cleanup(); - - if (currentMissionBehaviour().observationProducer != null) - currentMissionBehaviour().observationProducer.cleanup(); - - if (currentMissionBehaviour().commandHandler != null) - { - currentMissionBehaviour().commandHandler.setOverriding(false); - currentMissionBehaviour().commandHandler.deinstall(currentMissionInit()); - } - - if (AddressHelper.getMissionControlPort() == 0) { - if (envServer != null) { - byte[] obs = envServer.getObservation(false); - envServer.endMission(); - } - } - - // Close our communication channels: - closeSockets(); - - for (VideoHook hook : this.videoHooks) - hook.stop(ClientStateMachine.this.missionEndedData); - - - // Disable the gui for the episode! - Minecraft.getMinecraft().gameSettings.hideGUI = false; - - // Return Minecraft speed to "normal": - TimeHelper.SyncManager.setPistolFired(false); - TimeHelper.setMinecraftClientClockSpeed(20); - TimeHelper.displayGranularityMs = 0; - TimeHelper.unpause(); - TimeHelper.SyncManager.setSynchronous(false); - - ClientStateMachine.this.missionQuitCode = this.quitCode; - if (errorReport != null) - episodeHasCompletedWithErrors(nextState, errorReport); - else - episodeHasCompleted(nextState); - } - - @Override - protected void execute() - { - onMissionStarted(); - } - - @Override - public void onClientTick(ClientTickEvent event) - { - // If we aren't performing synchronous ticking use the client Tick to handle updates - if(!TimeHelper.SyncManager.isSynchronous()){ - onTick(false, event.phase); - } - } - - @Override - public void onSyncTick(SyncTickEvent ev){ - // If we are performing synchronous ticking - onTick(true, ev.pos); - } - - private synchronized void onTick(Boolean synchronous, 
TickEvent.Phase phase){ - // TimeHelper.SyncManager.debugLog("[CLIENT_STATE_MACHINE] " + phase.toString()); - // Check to see whether anything has caused us to abort - if so, go to the abort state. - if (inAbortState()) - onMissionEnded(ClientState.MISSION_ABORTED, "Mission was aborted by server: " + ClientStateMachine.this.getErrorDetails()); - // Check to see whether we've been kicked from the server. - NetHandlerPlayClient npc = Minecraft.getMinecraft().getConnection(); - if(npc == null){ - if(this.serverHasFiredStartingPistol){ - onMissionEnded(ClientState.ERROR_LOST_NETWORK_CONNECTION, "Server was closed"); - return; - } - } - else{ - NetworkManager netman = npc.getNetworkManager(); - if (netman != null && !netman.hasNoChannel() && !netman.isChannelOpen()) - { - // Connection has been lost. - onMissionEnded(ClientState.ERROR_LOST_NETWORK_CONNECTION, "Client was kicked from server - " + netman.getExitMessage().getUnformattedText()); - } - - } - - // Check we are still in touch with the agent: - if (System.currentTimeMillis() > this.lastPingSent + this.pingFrequencyMs) - { - this.lastPingSent = System.currentTimeMillis(); - // Ping the agent - if serverHasFiredStartingPistol is true, we don't need to abort - - // we can simply set the wantsToQuit flag and end the mission cleanly. - // If serverHasFiredStartingPistol is false, then the mission isn't yet running, and - // setting the quit flag will do nothing - so we need to abort. - if (!pingAgent(false)) - { - if (!this.serverHasFiredStartingPistol){ - onMissionEnded(ClientState.ERROR_LOST_AGENT, "Lost contact with the agent"); - return; - } - else - { - System.out.println("Error - agent is not responding to pings."); - this.wantsToQuit = true; - this.quitCode = MalmoMod.AGENT_UNRESPONSIVE_CODE; - } - } - } - - - if (this.frameTimestamp != 0 && (System.currentTimeMillis() - this.frameTimestamp > VIDEO_MAX_WAIT) && !synchronous) { - System.out.println("No video produced recently. 
Aborting mission."); - if (!this.serverHasFiredStartingPistol) - onMissionEnded(ClientState.ERROR_LOST_VIDEO, "No video produced recently."); - else - { - System.out.println("Error - not receiving video."); - this.wantsToQuit = true; - this.quitCode = MalmoMod.VIDEO_UNRESPONSIVE_CODE; - } - } - - if(Minecraft.getMinecraft().world == null){ - if(this.serverHasFiredStartingPistol){ - onMissionEnded(ClientState.ERROR_NO_WORLD, "No world for client. Must be in main menu"); - } - - return; - - } - // Check here to see whether the player has died or not: - if (!this.playerDied && Minecraft.getMinecraft().player.isDead) - { - this.playerDied = true; - this.quitCode = MalmoMod.AGENT_DEAD_QUIT_CODE; - } - - // Although we only arrive in this episode once the server has determined that all clients are ready to go, - // the server itself waits for all clients to begin running before it enters the running state itself. - // This creates a small vulnerability, since a running client could theoretically *finish* its mission - // before the server manages to *start*. - // (This has potentially disastrous effects for the state machine, and is easy to reproduce by, - // for example, setting the start point and goal of the mission to the same coordinates.) - - // To guard against this happening, although we are running, we don't act on anything - - // we don't check for commands, or send observations or rewards - until we get the SERVER_GO signal, - // which is sent once the server's running episode has started. - - - TimeHelper.SyncManager.setPistolFired(this.serverHasFiredStartingPistol); - if (!this.serverHasFiredStartingPistol){ - return; - } - - // Perhaps the race condition could be that synchronous is then set to false when the quit command is recieved! - if(synchronous && phase == Phase.START){ - checkForControlCommand(); - } - if (phase == Phase.END) - { - - - // Check whether or not we want to quit: - IWantToQuit quitHandler = (currentMissionBehaviour() != null) ? 
currentMissionBehaviour().quitProducer : null; - boolean quitHandlerFired = (quitHandler != null && quitHandler.doIWantToQuit(currentMissionInit())); - if (quitHandlerFired || this.wantsToQuit || this.playerDied || this.shouldMissionEnd) - { - if (quitHandlerFired) - { - this.quitCode = quitHandler.getOutcome(); - } - try - { - // Save the quit code for anything that needs it: - MalmoMod.getPropertiesForCurrentThread().put("QuitCode", this.quitCode); - } - catch (Exception e) - { - System.out.println("Failed to get properties - final reward may go missing."); - } - - // Get the final reward data: - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - // if (currentMissionBehaviour() != null && currentMissionBehaviour().rewardProducer != null && cac != null) - // currentMissionBehaviour().rewardProducer.getReward(currentMissionInit(), ClientStateMachine.this.finalReward); - - // Now send a message to the server saying that we have finished our mission: - List agents = currentMissionInit().getMission().getAgentSection(); - String agentName = agents.get(currentMissionInit().getClientRole()).getName(); - HashMap map = new HashMap(); - map.put("agentname", agentName); - map.put("username", Minecraft.getMinecraft().player.getName()); - map.put("quitcode", this.quitCode); - MalmoMod.network.sendToServer(new MalmoMod.MalmoMessage(MalmoMessageType.CLIENT_AGENTFINISHEDMISSION, 0, map)); - - onMissionEnded(ClientState.MISSION_ABORTED, null); - } - else - { - // If in the case that we are asynchronous, do this - // wack stuff of checking input at the end of a tick... 
- if(!synchronous){ - - checkForControlCommand(); - } - - // Send off observation and reward data: - // And see if we have any incoming commands to act upon: - sendData(false); - } - } - } - - private void openSockets() - { - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - this.observationSocket = new TCPSocketChannel(cac.getAgentIPAddress(), cac.getAgentObservationsPort(), "obs"); - this.rewardSocket = new TCPSocketChannel(cac.getAgentIPAddress(), cac.getAgentRewardsPort(), "rew"); - } - - private void closeSockets() - { - this.observationSocket.close(); - this.rewardSocket.close(); - } - - private void sendData(boolean done) - { - TCPUtils.LogSection ls = new TCPUtils.LogSection("Sending data"); - - Minecraft.getMinecraft().mcProfiler.startSection("malmoSendData"); - // Create the observation data: - String data = ""; - Minecraft.getMinecraft().mcProfiler.startSection("malmoGatherObservationJSON"); - - if (currentMissionBehaviour() != null && currentMissionBehaviour().observationProducer != null) - { - JsonObject json = new JsonObject(); - currentMissionBehaviour().observationProducer.writeObservationsToJSON(json, currentMissionInit()); - data = json.toString(); - } - Minecraft.getMinecraft().mcProfiler.endSection(); //malmogatherjson - Minecraft.getMinecraft().mcProfiler.startSection("malmoSendTCPObservations"); - - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - - if (data != null && data.length() > 2 && cac != null) // An empty json string will be "{}" (length 2) - don't send these. - { - // TimeHelper.SyncManager.debugLog("[CLIENT_STATE_MACHINE INFO] " + Integer.toString(AddressHelper.getMissionControlPort())); - if (AddressHelper.getMissionControlPort() == 0) { - if (envServer != null) { - // TODO wierd, aren't we doing this? 
- envServer.observation(data); - } - } else { - if (this.observationSocket.sendTCPString(data)) { - this.failedTCPObservationSendCount = 0; - } else { - // Failed to send observation message. - this.failedTCPObservationSendCount++; - TCPUtils.Log(Level.WARNING, "Observation signal delivery failure count at " + this.failedTCPObservationSendCount); - ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Agent missed observation signal", TextCategory.TXT_CLIENT_WARNING, 5000); - } - } - } - Minecraft.getMinecraft().mcProfiler.endSection(); //malmotcp - Minecraft.getMinecraft().mcProfiler.startSection("malmoGatherRewardSignal"); - - // Now create the reward signal: - if (currentMissionBehaviour() != null && currentMissionBehaviour().rewardProducer != null && cac != null) - { - MultidimensionalReward reward = new MultidimensionalReward(); - currentMissionBehaviour().rewardProducer.getReward(currentMissionInit(), reward); - - if (!reward.isEmpty()) - { - - String strReward = reward.getAsSimpleString(); - Minecraft.getMinecraft().mcProfiler.startSection("malmoSendTCPReward"); - - ScoreHelper.logReward(strReward); - - if (AddressHelper.getMissionControlPort() == 0) { - // MalmoEnvServer - reward - if (envServer != null) { - envServer.addRewards(reward.getRewardTotal()); - } - } else { - if (this.rewardSocket.sendTCPString(strReward)) { - this.failedTCPRewardSendCount = 0; // Reset the count of consecutive TCP failures. - } else { - // Failed to send TCP message - probably because the agent has quit under our feet. - // (This happens a lot when developing a Python agent - the developer has no easy way to quit - // the agent cleanly, so tends to kill the process.) 
- this.failedTCPRewardSendCount++; - TCPUtils.Log(Level.WARNING, "Reward signal delivery failure count at " + this.failedTCPRewardSendCount); - ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Agent missed reward signal", TextCategory.TXT_CLIENT_WARNING, 5000); - } - } - - Minecraft.getMinecraft().mcProfiler.endSection(); //sendTCP reward. - } - if (currentMissionBehaviour().performanceProducer != null) - currentMissionBehaviour().performanceProducer.step(reward.getRewardTotal(), done); - } - else if(currentMissionBehaviour() != null){ - if (currentMissionBehaviour().performanceProducer != null) - currentMissionBehaviour().performanceProducer.step(0, done); - } - Minecraft.getMinecraft().mcProfiler.endSection(); //Gather reward. - Minecraft.getMinecraft().mcProfiler.endSection(); //sendData - - int maxFailedTCPSendCount = 0; - for (VideoHook hook : this.videoHooks) - { - if (hook.failedTCPSendCount > maxFailedTCPSendCount) - maxFailedTCPSendCount = hook.failedTCPSendCount; - } - if (maxFailedTCPSendCount > 0) - TCPUtils.Log(Level.WARNING, "Video signal failure count at " + maxFailedTCPSendCount); - // Check that our messages are getting through: - int maxFailed = Math.max(this.failedTCPRewardSendCount, maxFailedTCPSendCount); - maxFailed = Math.max(maxFailed, this.failedTCPObservationSendCount); - if (maxFailed > FailedTCPSendCountTolerance) - { - // They're not - and we've exceeded the count of allowed TCP failures. - System.out.println("ERROR: TCP messages are not getting through - quitting mission."); - this.wantsToQuit = true; - this.quitCode = MalmoMod.AGENT_UNRESPONSIVE_CODE; - } - ls.close(); - } - - /** - * Check to see if any control instructions have been received and act on them if so. - */ - public void checkForControlCommand() - { - Minecraft.getMinecraft().mcProfiler.endStartSection("malmoCommandHandling"); - String command; - boolean quitHandlerFired = false; - IWantToQuit quitHandler = (currentMissionBehaviour() != null) ? 
currentMissionBehaviour().quitProducer : null; - - if (envServer != null) { - command = envServer.getCommand(); - } else { - command = ClientStateMachine.this.controlInputPoller.getCommand(); - } - while (command != null && command.length() > 0 && !quitHandlerFired) - { - // TCPUtils.Log(Level.INFO, "Act on " + command); - // Pass the command to our various control overrides: - Minecraft.getMinecraft().mcProfiler.startSection("malmoCommandAct"); - - boolean handled = handleCommand(command); - // Get the next command: - if (envServer != null) { - command = envServer.getCommand(); - } else { - command = ClientStateMachine.this.controlInputPoller.getCommand(); - } - // If there *is* another command (commands came in faster than one per client tick), - // then we should check our quit producer before deciding whether to execute it. - Minecraft.getMinecraft().mcProfiler.endStartSection("malmoCommandRecheckQuitHandlers"); - if (command != null && command.length() > 0 && handled) - quitHandlerFired = (quitHandler != null && quitHandler.doIWantToQuit(currentMissionInit())); - Minecraft.getMinecraft().mcProfiler.endSection(); - } - } - - /** - * Attempt to handle a command string by passing it to our various external controllers in turn. - * - * @param command the command string to be handled. - * @return true if the command was handled. - */ - private boolean handleCommand(String command) - { - if (currentMissionBehaviour() != null && currentMissionBehaviour().commandHandler != null) - { - return currentMissionBehaviour().commandHandler.execute(command, currentMissionInit()); - } - return false; - } - - @Override - public void onMessage(MalmoMessageType messageType, Map data) - { - super.onMessage(messageType, data); - // This message will be sent to us once the server has decided the mission is over. - if (messageType == MalmoMessageType.SERVER_STOPAGENTS) - { - this.quitCode = data.containsKey("QuitCode") ? 
data.get("QuitCode") : ""; - try - { - // Save the quit code for anything that needs it: - MalmoMod.getPropertiesForCurrentThread().put("QuitCode", this.quitCode); - } - catch (Exception e) - { - System.out.println("Failed to get properties - final reward may go missing."); - } - // Get the final reward data: - ClientAgentConnection cac = currentMissionInit().getClientAgentConnection(); - if (currentMissionBehaviour() != null && currentMissionBehaviour().rewardProducer != null && cac != null) - currentMissionBehaviour().rewardProducer.getReward(currentMissionInit(), ClientStateMachine.this.finalReward); - - this.shouldMissionEnd = true; - - } - else if (messageType == MalmoMessageType.SERVER_GO) - { - // First, force all entities to get re-added to their chunks, clearing out any old entities in the process. - // We need to do this because the process of teleporting all agents to their start positions, combined - // with setting them to/from spectator mode, leaves the client chunk entity lists etc in a parlous state. - List lel = Minecraft.getMinecraft().world.loadedEntityList; - for (int i = 0; i < lel.size(); i++) - { - Entity entity = (Entity)lel.get(i); - Chunk chunk = Minecraft.getMinecraft().world.getChunkFromChunkCoords(entity.chunkCoordX, entity.chunkCoordZ); - List entitiesToRemove = new ArrayList(); - for (int k = 0; k < chunk.getEntityLists().length; k++) - { - Iterator iterator = chunk.getEntityLists()[k].iterator(); - while (iterator.hasNext()) - { - Entity chunkent = (Entity)iterator.next(); - if (chunkent.getEntityId() == entity.getEntityId()) - { - entitiesToRemove.add(chunkent); - } - } - } - for (Entity removeEnt : entitiesToRemove) - { - chunk.removeEntity(removeEnt); - } - entity.addedToChunk = false; // Will force it to get re-added to the chunk list. - if (entity instanceof EntityLivingBase) - { - // If we want the entities to be rendered with the correct yaw from the outset, - // we need to set their render offset manually. 
- // (Set the offset from the outset to avoid the onset of upset.) - ((EntityLivingBase)entity).renderYawOffset = entity.rotationYaw; - ((EntityLivingBase)entity).prevRenderYawOffset = entity.rotationYaw; - } - if (entity instanceof EntityPlayerSP) - { - // Although the following call takes place on the server, and should have taken effect already, - // there is some discontinuity which is causing the effects to get lost, so we call it here too: - entity.setInvisible(false); - } - } - this.serverHasFiredStartingPistol = true; // GO GO GO! - } - } - - @Override - public void cleanup() - { - super.cleanup(); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, MalmoMessageType.SERVER_STOPAGENTS); - MalmoMod.MalmoMessageHandler.deregisterForMessage(this, MalmoMessageType.SERVER_GO); - } - } - - // --------------------------------------------------------------------------------------------------------- - /** - * State that occurs at the end of the mission, whether due to death, - * failure, success, error, or whatever. - */ - public class MissionEndedEpisode extends ConfigAwareStateEpisode - { - private MissionResult result; - private boolean aborting; - private boolean informServer; - private boolean informAgent; - private int totalTicks = 0; - - public MissionEndedEpisode(ClientStateMachine machine, MissionResult mr, boolean aborting, boolean informServer, boolean informAgent) - { - super(machine); - this.result = mr; - this.aborting = aborting; - this.informServer = informServer; - this.informAgent = informAgent; - } - - @Override - protected void execute() - { - totalTicks = 0; - - // Get a text report: - String errorFeedback = ClientStateMachine.this.getErrorDetails(); - String quitFeedback = ClientStateMachine.this.missionQuitCode; - String concatenation = (errorFeedback != null && !errorFeedback.isEmpty() && quitFeedback != null && !quitFeedback.isEmpty()) ? 
";\n" : ""; - String report = quitFeedback + concatenation + errorFeedback; - - if (this.informServer) - { - // Inform the server of what has happened. - HashMap map = new HashMap(); - if (Minecraft.getMinecraft().player != null) // Might not be a player yet. - map.put("username", Minecraft.getMinecraft().player.getName()); - map.put("error", ClientStateMachine.this.getErrorDetails()); - MalmoMod.network.sendToServer(new MalmoMod.MalmoMessage(MalmoMessageType.CLIENT_BAILED, 0, map)); - } - - if (this.informAgent) - { - // Create a MissionEnded instance for this result: - MissionEnded missionEnded = new MissionEnded(); - missionEnded.setStatus(this.result); - if (ClientStateMachine.this.missionQuitCode != null && ClientStateMachine.this.missionQuitCode.equals(MalmoMod.AGENT_DEAD_QUIT_CODE)) - missionEnded.setStatus(MissionResult.PLAYER_DIED); // Need to do this manually. - missionEnded.setHumanReadableStatus(report); - - // TODO: WE HAVE TO MOVE THIS TO THE onMISSIONENDED of Client Mission - // BECAUSE IT WOULD TAKE AN EXTRA TICK TO HAVE THIS APPEAR PROPERLY. - // THIS MOVE IS INCOMPATIBLE WITH MULTIPLE AGENTS AND REWARD DISTRIBUTION - // A PROPER REHAUL OF THE WHOLE SIMULATOR TO SUPPROT SYNCHRONOUS TICKING - // ACCROSS MULTIPLE AGENTS AND A STATE MACHINE WHOSE STATE CHANGES INDEPENDENT - // OF CLIENT TICKS IS REQUIRED. 
- // if (!ClientStateMachine.this.finalReward.isEmpty()) - // { - // if (envServer != null) { - // envServer.addRewards(ClientStateMachine.this.finalReward.getRewardTotal()); - // } - // missionEnded.setReward(ClientStateMachine.this.finalReward.getAsReward()); - // ClientStateMachine.this.finalReward.clear(); - // } - missionEnded.setMissionDiagnostics(ClientStateMachine.this.missionEndedData); // send our diagnostics - ClientStateMachine.this.missionEndedData = new MissionDiagnostics(); // and clear them for the next mission - // And send MissionEnded message to the agent to inform it that the mission has ended: - System.out.println("inform the agent"); - sendMissionEnded(missionEnded); - } - - if (this.aborting) // Take the shortest path back to dormant. - episodeHasCompleted(ClientState.DORMANT); - } - - private void sendMissionEnded(MissionEnded missionEnded) - { - // Send a MissionEnded message to the agent to inform it that the mission has ended. - // Create a string XML representation: - String missionEndedString = null; - try - { - missionEndedString = SchemaHelper.serialiseObject(missionEnded, MissionEnded.class); - if (ScoreHelper.isScoring()) { - Reward reward = missionEnded.getReward(); - if (reward == null) { - reward = new Reward(); - } - ScoreHelper.logMissionEndRewards(reward); - } - } - catch (JAXBException e) - { - TCPUtils.Log(Level.SEVERE, "Failed mission end XML serialization: " + e); - } - - boolean sentOkay = false; - if (missionEndedString != null) - { - if (AddressHelper.getMissionControlPort() == 0) { - sentOkay = true; - } else { - TCPSocketChannel sender = ClientStateMachine.this.getMissionControlSocket(); - System.out.println(String.format("Sending mission ended message to %s:%d.", sender.getAddress(), sender.getPort())); - sentOkay = sender.sendTCPString(missionEndedString); - sender.close(); - } - } - - if (!sentOkay) - { - // Couldn't formulate a reply to the agent - bit of a problem. 
- // Can't do much to alert the agent itself, - // will have to settle for alerting anyone who is watching the mod: - ClientStateMachine.this.getScreenHelper().addFragment("ERROR: Could not send mission ended message - agent may need manually resetting.", TextCategory.TXT_CLIENT_WARNING, 10000); - } - } - - @Override - public void onClientTick(ClientTickEvent event) - { - if (!this.aborting) - episodeHasCompleted(ClientState.WAITING_FOR_SERVER_MISSION_END); - - if (++totalTicks > WAIT_MAX_TICKS) - { - String msg = "Too long waiting for server to end mission."; - TCPUtils.Log(Level.SEVERE, msg); - episodeHasCompletedWithErrors(ClientState.ERROR_TIMED_OUT_WAITING_FOR_MISSION_END, msg); - } - } - } -} diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/gpu/patch_files/MalmoEnvServer.java b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/gpu/patch_files/MalmoEnvServer.java deleted file mode 100644 index 6b74acac..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/docker/gpu/patch_files/MalmoEnvServer.java +++ /dev/null @@ -1,939 +0,0 @@ -// -------------------------------------------------------------------------------------------------- -// Copyright (c) 2016 Microsoft Corporation -// -// Permission is hereby granted, free of charge, to any person obtaining a copy of this software and -// associated documentation files (the "Software"), to deal in the Software without restriction, -// including without limitation the rights to use, copy, modify, merge, publish, distribute, -// sublicense, and/or l copies of the Software, and to permit persons to whom the Software is -// furnished to do so, subject to the following conditions: -// -// The above copyright notice and this permission notice shall be included in all copies or -// substantial portions of the Software. 
-// -// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT -// NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND -// NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, -// DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. -// -------------------------------------------------------------------------------------------------- - -package com.microsoft.Malmo.Client; - -import com.microsoft.Malmo.MalmoMod; -import com.microsoft.Malmo.MissionHandlerInterfaces.IWantToQuit; -import com.microsoft.Malmo.Schemas.MissionInit; -import com.microsoft.Malmo.Utils.TCPUtils; - -import net.minecraft.profiler.Profiler; -import com.microsoft.Malmo.Utils.TimeHelper; - -import net.minecraftforge.common.config.Configuration; -import java.io.*; -import java.net.ServerSocket; -import java.net.Socket; -import java.nio.charset.Charset; -import java.util.Arrays; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.locks.Condition; -import java.util.concurrent.locks.Lock; -import java.util.concurrent.locks.ReentrantLock; -import java.util.Hashtable; -import com.microsoft.Malmo.Utils.TCPInputPoller; -import java.util.logging.Level; - -import java.util.LinkedList; -import java.util.List; - - -/** - * MalmoEnvServer - service supporting OpenAI gym "environment" for multi-agent Malmo missions. - */ -public class MalmoEnvServer implements IWantToQuit { - private static Profiler profiler = new Profiler(); - private static int nsteps = 0; - private static boolean debug = false; - - private static String hello = " commands = new LinkedList(); - } - - private static boolean envPolicy = false; // Are we configured by config policy? 
- - // Synchronize on EnvStateasd - - - private Lock lock = new ReentrantLock(); - private Condition cond = lock.newCondition(); - - private EnvState envState = new EnvState(); - - private Hashtable initTokens = new Hashtable(); - - static final long COND_WAIT_SECONDS = 3; // Max wait in seconds before timing out (and replying to RPC). - static final int BYTES_INT = 4; - static final int BYTES_DOUBLE = 8; - private static final Charset utf8 = Charset.forName("UTF-8"); - - // Service uses a single per-environment client connection - initiated by the remote environment. - - private int port; - private TCPInputPoller missionPoller; // Used for command parsing and not actual communication. - private String version; - - // AOG: From running experiments, I've found that MineRL can get stuck resetting the - // environment which causes huge delays while we wait for the Python side to time - // out and restart the Minecraft instace. Minecraft itself is normally in a recoverable - // state, but the MalmoEnvServer instance will be blocked in a tight spin loop trying - // handling a Peek request from the Python client. To unstick things, I've added this - // flag that can be set when we know things are in a bad state to abort the peek request. - // WARNING: THIS IS ONLY TREATING THE SYMPTOM AND NOT THE ROOT CAUSE - // The reason things are getting stuck is because the player is either dying or we're - // receiving a quit request while an episode reset is in progress. - private boolean abortRequest; - public void abort() { - System.out.println("AOG: MalmoEnvServer.abort"); - abortRequest = true; - } - - /*** - * Malmo "Env" service. - * @param port the port the service listens on. - * @param missionPoller for plugging into existing comms handling. 
- */ - public MalmoEnvServer(String version, int port, TCPInputPoller missionPoller) { - this.version = version; - this.missionPoller = missionPoller; - this.port = port; - // AOG - Assume we don't wan't to be aborting in the first place - this.abortRequest = false; - } - - /** Initialize malmo env configuration. For now either on or "legacy" AgentHost protocol.*/ - static public void update(Configuration configs) { - envPolicy = configs.get(MalmoMod.ENV_CONFIGS, "env", "false").getBoolean(); - } - - public static boolean isEnv() { - return envPolicy; - } - - /** - * Start servicing the MalmoEnv protocol. - * @throws IOException - */ - public void serve() throws IOException { - - ServerSocket serverSocket = new ServerSocket(port); - serverSocket.setPerformancePreferences(0,2,1); - - - while (true) { - try { - final Socket socket = serverSocket.accept(); - socket.setTcpNoDelay(true); - - Thread thread = new Thread("EnvServerSocketHandler") { - public void run() { - boolean running = false; - try { - checkHello(socket); - - while (true) { - DataInputStream din = new DataInputStream(socket.getInputStream()); - int hdr = din.readInt(); - byte[] data = new byte[hdr]; - din.readFully(data); - - String command = new String(data, utf8); - - if (command.startsWith(" dat = profiler.getProfilingData("root"); - for(int qq = 0; qq < dat.size(); qq++){ - Profiler.Result res = dat.get(qq); - System.out.println(res.profilerName + " " + res.totalUsePercentage + " "+ res.usePercentage); - } - } - - - } else if (command.startsWith(""; - data = command.getBytes(utf8); - hdr = data.length; - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(hdr); - dout.write(data, 0, hdr); - dout.flush(); - } else { - throw new IOException("Unknown env service command"); - } - } - } catch (IOException ioe) { - // ioe.printStackTrace(); - TCPUtils.Log(Level.SEVERE, "MalmoEnv socket error: " + ioe + " (can be on disconnect)"); - // System.out.println("[ERROR] " + 
"MalmoEnv socket error: " + ioe + " (can be on disconnect)"); - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] MalmoEnv socket error"); - try { - if (running) { - TCPUtils.Log(Level.INFO,"Want to quit on disconnect."); - - System.out.println("[LOGTOPY] " + "Want to quit on disconnect."); - setWantToQuit(); - } - socket.close(); - } catch (IOException ioe2) { - } - } - } - }; - thread.start(); - } catch (IOException ioe) { - TCPUtils.Log(Level.SEVERE, "MalmoEnv service exits on " + ioe); - } - } - } - - private void checkHello(Socket socket) throws IOException { - DataInputStream din = new DataInputStream(socket.getInputStream()); - int hdr = din.readInt(); - if (hdr <= 0 || hdr > hello.length() + 8) // Version number may be somewhat longer in future. - throw new IOException("Invalid MalmoEnv hello header length"); - byte[] data = new byte[hdr]; - din.readFully(data); - if (!new String(data).startsWith(hello + version)) - throw new IOException("MalmoEnv invalid protocol or version - expected " + hello + version); - } - - // Handler for messages. 
- private boolean missionInit(DataInputStream din, String command, Socket socket) throws IOException { - - String ipOriginator = socket.getInetAddress().getHostName(); - - int hdr; - byte[] data; - hdr = din.readInt(); - data = new byte[hdr]; - din.readFully(data); - String id = new String(data, utf8); - - TCPUtils.Log(Level.INFO,"Mission Init" + id); - - String[] token = id.split(":"); - String experimentId = token[0]; - int role = Integer.parseInt(token[1]); - int reset = Integer.parseInt(token[2]); - int agentCount = Integer.parseInt(token[3]); - Boolean isSynchronous = Boolean.parseBoolean(token[4]); - Long seed = null; - if(token.length > 5) - seed = Long.parseLong(token[5]); - - if(isSynchronous && agentCount > 1){ - throw new IOException("Synchronous mode currently does not support multiple agents."); - } - port = -1; - boolean allTokensConsumed = true; - boolean started = false; - - lock.lock(); - try { - if (role == 0) { - - String previousToken = experimentId + ":0:" + (reset - 1); - initTokens.remove(previousToken); - - String myToken = experimentId + ":0:" + reset; - if (!initTokens.containsKey(myToken)) { - TCPUtils.Log(Level.INFO,"(Pre)Start " + role + " reset " + reset); - started = startUp(command, ipOriginator, experimentId, reset, agentCount, myToken, seed, isSynchronous); - if (started) - initTokens.put(myToken, 0); - } else { - started = true; // Pre-started previously. - } - - // Check that all previous tokens have been consumed. If not don't proceed to mission. 
- - allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount); - if (!allTokensConsumed) { - try { - cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS); - } catch (InterruptedException ie) { - } - allTokensConsumed = areAllTokensConsumed(experimentId, reset, agentCount); - } - } else { - TCPUtils.Log(Level.INFO, "Start " + role + " reset " + reset); - - started = startUp(command, ipOriginator, experimentId, reset, agentCount, experimentId + ":" + role + ":" + reset, seed, isSynchronous); - } - } finally { - lock.unlock(); - } - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(allTokensConsumed && started ? 1 : 0); - dout.flush(); - - dout.flush(); - - return allTokensConsumed && started; - } - - private boolean areAllTokensConsumed(String experimentId, int reset, int agentCount) { - boolean allTokensConsumed = true; - for (int i = 1; i < agentCount; i++) { - String tokenForAgent = experimentId + ":" + i + ":" + (reset - 1); - if (initTokens.containsKey(tokenForAgent)) { - TCPUtils.Log(Level.FINE,"Mission init - unconsumed " + tokenForAgent); - allTokensConsumed = false; - } - } - return allTokensConsumed; - } - - private boolean startUp(String command, String ipOriginator, String experimentId, int reset, int agentCount, String myToken, Long seed, Boolean isSynchronous) throws IOException { - - // Clear out mission state - envState.reward = 0.0; - envState.commands.clear(); - envState.obs = null; - envState.info = ""; - - - envState.missionInit = command; - envState.done = false; - envState.quit = false; - envState.token = myToken; - envState.experimentId = experimentId; - envState.agentCount = agentCount; - envState.reset = reset; - envState.synchronous = isSynchronous; - envState.seed = seed; - - return startUpMission(command, ipOriginator); - } - - private boolean startUpMission(String command, String ipOriginator) throws IOException { - - if (missionPoller == null) - return false; 
- - ByteArrayOutputStream baos = new ByteArrayOutputStream(); - DataOutputStream dos = new DataOutputStream(baos); - - missionPoller.commandReceived(command, ipOriginator, dos); - - dos.flush(); - byte[] reply = baos.toByteArray(); - ByteArrayInputStream bais = new ByteArrayInputStream(reply); - DataInputStream dis = new DataInputStream(bais); - int hdr = dis.readInt(); - byte[] replyBytes = new byte[hdr]; - dis.readFully(replyBytes); - - String replyStr = new String(replyBytes); - if (replyStr.equals("MALMOOK")) { - TCPUtils.Log(Level.INFO, "MalmoEnvServer Mission starting ..."); - return true; - } else if (replyStr.equals("MALMOBUSY")) { - TCPUtils.Log(Level.INFO, "MalmoEnvServer Busy - I want to quit"); - this.envState.quit = true; - } - return false; - } - - private static final int stepTagLength = "".length(); // Step with option code. - private synchronized void stepSync(String command, Socket socket, DataInputStream din) throws IOException - { - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Entering synchronous step."); - nsteps += 1; - profiler.startSection("commandProcessing"); - String actions = command.substring(stepTagLength, command.length() - (stepTagLength + 2)); - int options = Character.getNumericValue(command.charAt(stepTagLength - 2)); - boolean withInfo = options == 0 || options == 2; - - - - - // Prepare to write data to the client. - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - double reward = 0.0; - boolean done; - byte[] obs; - String info = ""; - boolean sent = false; - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Acquiring lock for synchronous step."); - - lock.lock(); - try { - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Lock is acquired."); - - done = envState.done; - - // TODO Handle when the environment is done. - - // Process the actions. 
- if (actions.contains("\n")) { - String[] cmds = actions.split("\\n"); - for(String cmd : cmds) { - envState.commands.add(cmd); - } - } else { - if (!actions.isEmpty()) - envState.commands.add(actions); - } - sent = true; - - - - profiler.endSection(); //cmd - profiler.startSection("requestTick"); - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Received: " + actions); - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Requesting tick."); - // Now wait to run a tick - // If synchronous mode is off then we should see if want to quit is true. - while(!TimeHelper.SyncManager.requestTick() && !done ){Thread.yield();} - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Tick request granted."); - - profiler.endSection(); - profiler.startSection("waitForTick"); - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Waiting for tick."); - - // Then wait until the tick is finished - while(!TimeHelper.SyncManager.isTickCompleted() && !done ){ Thread.yield();} - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] TICK DONE. Getting observation."); - - - - profiler.endSection(); - profiler.startSection("getObservation"); - // After which, get the observations. - obs = getObservation(done); - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Observation received. Getting info."); - - profiler.endSection(); - profiler.startSection("getInfo"); - - - // Pick up rewards. 
- reward = envState.reward; - if (withInfo) { - info = envState.info; - // if(info == null) - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] FILLING INFO: NULL"); - // else - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] FILLING " + info.toString()); - - } - done = envState.done; - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] STATUS " + Boolean.toString(done)); - envState.info = null; - envState.obs = null; - envState.reward = 0.0; - - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Info received.."); - profiler.endSection(); - } finally { - lock.unlock(); - } - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Lock released. Writing observation, info, done."); - - profiler.startSection("writeObs"); - dout.writeInt(obs.length); - dout.write(obs); - - dout.writeInt(BYTES_DOUBLE + 2); - dout.writeDouble(reward); - dout.writeByte(done ? 1 : 0); - dout.writeByte(sent ? 1 : 0); - - if (withInfo) { - byte[] infoBytes = info.getBytes(utf8); - dout.writeInt(infoBytes.length); - dout.write(infoBytes); - } - - profiler.endSection(); //write obs - profiler.startSection("flush"); - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Packets written. Flushing."); - dout.flush(); - profiler.endSection(); // flush - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Done with step."); - } - // Handler for messages. Single digit option code after _ specifies if turnkey and info are included in message. - private void step(String command, Socket socket, DataInputStream din) throws IOException { - if(envState.synchronous){ - stepSync(command, socket, din); - } - else{ - System.out.println("[ERROR] Asynchronous stepping is not supported in MineRL."); - } - - } - - // Handler for messages. 
- private void peek(String command, Socket socket, DataInputStream din) throws IOException { - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - byte[] obs; - boolean done; - String info = ""; - // AOG - As we've only seen issues with the peek reqest, I've focused my changes to just - // this function. Initially we want to be optimistic and assume we're not going to abort - // the request and my observations of event timings indicate that there is plenty of time - // between the peek request being received and the reset failing, so a race condition is - // unlikely. - abortRequest = false; - - lock.lock(); - - try { - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Waiting for pistol to fire."); - while(!TimeHelper.SyncManager.hasServerFiredPistol() && !abortRequest){ - - // Now wait to run a tick - while(!TimeHelper.SyncManager.requestTick() && !abortRequest){Thread.yield();} - - - // Then wait until the tick is finished - while(!TimeHelper.SyncManager.isTickCompleted() && !abortRequest){ Thread.yield();} - - - Thread.yield(); - } - - if (abortRequest) { - System.out.println("AOG: Aborting peek request"); - // AOG - We detect the lack of observation within our Python wrapper and throw a slightly - // diferent exception that by-passes MineRLs automatic clean up code. If we were to report - // 'done', the MineRL detects this as a runtime error and kills the Minecraft process - // triggering a lengthy restart. So far from testing, Minecraft itself is fine can we can - // retry the reset, it's only the tight loops above that were causing things to stall and - // timeout. - // No observation - dout.writeInt(0); - // No info - dout.writeInt(0); - // Done - dout.writeInt(1); - dout.writeByte(0); - dout.flush(); - return; - } - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Pistol fired!."); - // Wait two ticks for the first observation from server to be propagated. 
- while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();} - - // Then wait until the tick is finished - while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();} - - - - while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();} - - // Then wait until the tick is finished - while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();} - - - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Getting observation."); - - obs = getObservation(false); - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Observation acquired."); - done = envState.done; - info = envState.info; - - } finally { - lock.unlock(); - } - - dout.writeInt(obs.length); - dout.write(obs); - - byte[] infoBytes = info.getBytes(utf8); - dout.writeInt(infoBytes.length); - dout.write(infoBytes); - - dout.writeInt(1); - dout.writeByte(done ? 1 : 0); - - dout.flush(); - } - - // Get the current observation. If none and not done wait for a short time. - public byte[] getObservation(boolean done) { - byte[] obs = envState.obs; - if (obs == null){ - System.out.println("[ERROR] Video observation is null; please notify the developer."); - } - return obs; - } - - // Handler for messages - used by non-zero roles to discover integrated server port from primary (role 0) service. - - private final static int findTagLength = "".length(); - - private void find(String command, Socket socket) throws IOException { - - Integer port; - lock.lock(); - try { - String token = command.substring(findTagLength, command.length() - (findTagLength + 1)); - TCPUtils.Log(Level.INFO, "Find token? " + token); - - // Purge previous token. - String[] tokenSplits = token.split(":"); - String experimentId = tokenSplits[0]; - int role = Integer.parseInt(tokenSplits[1]); - int reset = Integer.parseInt(tokenSplits[2]); - - String previousToken = experimentId + ":" + role + ":" + (reset - 1); - initTokens.remove(previousToken); - cond.signalAll(); - - // Check for next token. 
Wait for a short time if not already produced. - port = initTokens.get(token); - if (port == null) { - try { - cond.await(COND_WAIT_SECONDS, TimeUnit.SECONDS); - } catch (InterruptedException ie) { - } - port = initTokens.get(token); - if (port == null) { - port = 0; - TCPUtils.Log(Level.INFO,"Role " + role + " reset " + reset + " waiting for token."); - } - } - } finally { - lock.unlock(); - } - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(port); - dout.flush(); - } - - public boolean isSynchronous(){ - return envState.synchronous; - } - - // Handler for messages. These reset the service so use with care! - private void init(String command, Socket socket) throws IOException { - lock.lock(); - try { - initTokens = new Hashtable(); - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(1); - dout.flush(); - } finally { - lock.unlock(); - } - } - - // Handler for (quit mission) messages. - private void quit(String command, Socket socket) throws IOException { - lock.lock(); - try { - if (!envState.done){ - - envState.quit = true; - } - - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Pistol fired!."); - // Wait two ticks for the first observation from server to be propagated. - while(!TimeHelper.SyncManager.requestTick() ){Thread.yield();} - - // Then wait until the tick is finished - while(!TimeHelper.SyncManager.isTickCompleted()){ Thread.yield();} - - - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(envState.done ? 1 : 0); - dout.flush(); - } finally { - lock.unlock(); - } - } - - private final static int closeTagLength = "".length(); - - // Handler for messages. 
- private void close(String command, Socket socket) throws IOException { - lock.lock(); - try { - String token = command.substring(closeTagLength, command.length() - (closeTagLength + 1)); - - initTokens.remove(token); - - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(1); - dout.flush(); - } finally { - lock.unlock(); - } - } - - // Handler for messages. - private void status(String command, Socket socket) throws IOException { - lock.lock(); - try { - String status = "{}"; // TODO Possibly have something more interesting to report. - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - - byte[] statusBytes = status.getBytes(utf8); - dout.writeInt(statusBytes.length); - dout.write(statusBytes); - - dout.flush(); - } finally { - lock.unlock(); - } - } - - // Handler for messages. These "kill the service" temporarily so use with care!f - private void exit(String command, Socket socket) throws IOException { - // lock.lock(); - try { - // We may exit before we get a chance to reply. - TimeHelper.SyncManager.setSynchronous(false); - DataOutputStream dout = new DataOutputStream(socket.getOutputStream()); - dout.writeInt(BYTES_INT); - dout.writeInt(1); - dout.flush(); - - ClientStateMachine.exitJava(); - - } finally { - // lock.unlock(); - } - } - - // Malmo client state machine interface methods: - - public String getCommand() { - try { - String command = envState.commands.poll(); - if (command == null) - return ""; - else - return command; - } finally { - } - } - - public void endMission() { - // lock.lock(); - try { - // AOG - If the mission is ending, we always want to abort requests and they won't - // be able to progress to completion and will stall. 
- System.out.println("AOG: MalmoEnvServer.endMission"); - abort(); - envState.done = true; - envState.quit = false; - envState.missionInit = null; - - if (envState.token != null) { - initTokens.remove(envState.token); - envState.token = null; - envState.experimentId = null; - envState.agentCount = 0; - envState.reset = 0; - - // cond.signalAll(); - } - // lock.unlock(); - } finally { - } - } - // Record a Malmo "observation" json - as the env info since an environment "obs" is a video frame. - public void observation(String info) { - // Parsing obs as JSON would be slower but less fragile than extracting the turn_key using string search. - // lock.lock(); - try { - // TimeHelper.SyncManager.debugLog("[MALMO_ENV_SERVER] Inserting: " + info); - envState.info = info; - // cond.signalAll(); - } finally { - // lock.unlock(); - } - } - - public void addRewards(double rewards) { - // lock.lock(); - try { - envState.reward += rewards; - } finally { - // lock.unlock(); - } - } - - public void addFrame(byte[] frame) { - // lock.lock(); - try { - envState.obs = frame; // Replaces current. - // cond.signalAll(); - } finally { - // lock.unlock(); - } - } - - public void notifyIntegrationServerStarted(int integrationServerPort) { - lock.lock(); - try { - if (envState.token != null) { - TCPUtils.Log(Level.INFO,"Integration server start up - token: " + envState.token); - addTokens(integrationServerPort, envState.token, envState.experimentId, envState.agentCount, envState.reset); - cond.signalAll(); - } else { - TCPUtils.Log(Level.WARNING,"No mission token on integration server start up!"); - } - } finally { - lock.unlock(); - } - } - - private void addTokens(int integratedServerPort, String myToken, String experimentId, int agentCount, int reset) { - initTokens.put(myToken, integratedServerPort); - // Place tokens for other agents to find. 
- for (int i = 1; i < agentCount; i++) { - String tokenForAgent = experimentId + ":" + i + ":" + reset; - initTokens.put(tokenForAgent, integratedServerPort); - } - } - - // IWantToQuit implementation. - - @Override - public boolean doIWantToQuit(MissionInit missionInit) { - // lock.lock(); - try { - return envState.quit; - } finally { - // lock.unlock(); - } - } - - public Long getSeed(){ - return envState.seed; - } - - private void setWantToQuit() { - // lock.lock(); - try { - envState.quit = true; - - } finally { - - if(TimeHelper.SyncManager.isSynchronous()){ - // We want to dsynchronize everything. - TimeHelper.SyncManager.setSynchronous(false); - } - // lock.unlock(); - } - } - - @Override - public void prepare(MissionInit missionInit) { - } - - @Override - public void cleanup() { - } - - @Override - public String getOutcome() { - return "Env quit"; - } -} diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/malmo_video_recorder.py b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/malmo_video_recorder.py deleted file mode 100644 index 2801d6bc..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/malmo_video_recorder.py +++ /dev/null @@ -1,173 +0,0 @@ -import time -import glob -import pathlib - -from malmo import MalmoPython, malmoutils -from malmo.launch_minecraft_in_background import launch_minecraft_in_background - - -class MalmoVideoRecorder: - DEFAULT_RECORDINGS_DIR = './logs/videos' - - def __init__(self): - self.agent_host_bot = None - self.agent_host_camera = None - self.client_pool = None - self.is_malmo_initialized = False - - def init_malmo(self, recordings_directory=DEFAULT_RECORDINGS_DIR): - if self.is_malmo_initialized: - return - - launch_minecraft_in_background( - '/app/MalmoPlatform/Minecraft', - ports=[10000, 10001]) - - # Set up two agent hosts - self.agent_host_bot = MalmoPython.AgentHost() - self.agent_host_camera = 
MalmoPython.AgentHost() - - # Create list of Minecraft clients to attach to. The agents must - # have been launched before calling record_malmo_video using - # init_malmo() - self.client_pool = MalmoPython.ClientPool() - self.client_pool.add(MalmoPython.ClientInfo('127.0.0.1', 10000)) - self.client_pool.add(MalmoPython.ClientInfo('127.0.0.1', 10001)) - - # Use bot's agenthost to hold the command-line options - malmoutils.parse_command_line( - self.agent_host_bot, - ['--record_video', '--recording_dir', recordings_directory]) - - self.is_malmo_initialized = True - - def _start_mission(self, agent_host, mission, recording_spec, role): - used_attempts = 0 - max_attempts = 5 - - while True: - try: - agent_host.startMission( - mission, - self.client_pool, - recording_spec, - role, - '') - break - except MalmoPython.MissionException as e: - errorCode = e.details.errorCode - if errorCode == (MalmoPython.MissionErrorCode - .MISSION_SERVER_WARMING_UP): - time.sleep(2) - elif errorCode == (MalmoPython.MissionErrorCode - .MISSION_INSUFFICIENT_CLIENTS_AVAILABLE): - print('Not enough Minecraft instances running.') - used_attempts += 1 - if used_attempts < max_attempts: - print('Will wait in case they are starting up.') - time.sleep(300) - elif errorCode == (MalmoPython.MissionErrorCode - .MISSION_SERVER_NOT_FOUND): - print('Server not found.') - used_attempts += 1 - if used_attempts < max_attempts: - print('Will wait and retry.') - time.sleep(2) - else: - used_attempts = max_attempts - if used_attempts >= max_attempts: - raise e - - def _wait_for_start(self, agent_hosts): - start_flags = [False for a in agent_hosts] - start_time = time.time() - time_out = 120 - - while not all(start_flags) and time.time() - start_time < time_out: - states = [a.peekWorldState() for a in agent_hosts] - start_flags = [w.has_mission_begun for w in states] - errors = [e for w in states for e in w.errors] - - if len(errors) > 0: - print("Errors waiting for mission start:") - for e in errors: - 
print(e.text) - raise Exception("Encountered errors while starting mission.") - if time.time() - start_time >= time_out: - raise Exception("Timed out while waiting for mission to start.") - - def _get_xml(self, xml_file, seed): - with open(xml_file, 'r') as mission_file: - return mission_file.read().format(SEED_PLACEHOLDER=seed) - - def _is_mission_running(self): - return self.agent_host_bot.peekWorldState().is_mission_running or \ - self.agent_host_camera.peekWorldState().is_mission_running - - def record_malmo_video(self, instructions, xml_file, seed): - ''' - Replays a set of instructions through Malmo using two players. The - first player will navigate the specified mission based on the given - instructions. The second player observes the first player's moves, - which is captured in a video. - ''' - - if not self.is_malmo_initialized: - raise Exception('Malmo not initialized. Call init_malmo() first.') - - # Set up the mission - my_mission = MalmoPython.MissionSpec( - self._get_xml(xml_file, seed), - True) - - bot_recording_spec = MalmoPython.MissionRecordSpec() - camera_recording_spec = MalmoPython.MissionRecordSpec() - - recordingsDirectory = \ - malmoutils.get_recordings_directory(self.agent_host_bot) - if recordingsDirectory: - camera_recording_spec.setDestination( - recordingsDirectory + "//rollout_" + str(seed) + ".tgz") - camera_recording_spec.recordMP4( - MalmoPython.FrameType.VIDEO, - 36, - 2000000, - False) - - # Start the agents - self._start_mission( - self.agent_host_bot, - my_mission, - bot_recording_spec, - 0) - self._start_mission( - self.agent_host_camera, - my_mission, - camera_recording_spec, - 1) - self._wait_for_start([self.agent_host_camera, self.agent_host_bot]) - - # Teleport the camera agent to the required position - self.agent_host_camera.sendCommand('tp -29 72 -6.7') - instruction_index = 0 - - while self._is_mission_running(): - - command = instructions[instruction_index] - instruction_index += 1 - - 
self.agent_host_bot.sendCommand(command) - - # Pause for half a second - change this for faster/slower videos - time.sleep(0.5) - - if instruction_index == len(instructions): - self.agent_host_bot.sendCommand("jump 1") - time.sleep(2) - - self.agent_host_bot.sendCommand("quit") - - # Wait a little for Malmo to reset before the - # next mission is started - time.sleep(2) - print("Video recorded.") diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_environment.py b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_environment.py deleted file mode 100644 index f4ac416c..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_environment.py +++ /dev/null @@ -1,180 +0,0 @@ -import json -import logging - -import gym -import minerl.env.core -import minerl.env.comms -import numpy as np - -from ray.rllib.env.atari_wrappers import FrameStack -from minerl.env.malmo import InstanceManager - -# Modify the MineRL timeouts to detect common errors -# quicker and speed up recovery -minerl.env.core.SOCKTIME = 60.0 -minerl.env.comms.retry_timeout = 1 - - -class EnvWrapper(minerl.env.core.MineRLEnv): - def __init__(self, xml, port): - InstanceManager.configure_malmo_base_port(port) - self.action_to_command_array = [ - 'move 1', - 'camera 0 270', - 'camera 0 90'] - - super().__init__( - xml, - gym.spaces.Box(low=0, high=255, shape=(84, 84, 3), dtype=np.uint8), - gym.spaces.Discrete(3) - ) - - self.metadata['video.frames_per_second'] = 2 - - def _setup_spaces(self, observation_space, action_space): - self.observation_space = observation_space - self.action_space = action_space - - def _process_action(self, action_in) -> str: - assert self.action_space.contains(action_in) - assert action_in <= len( - self.action_to_command_array) - 1, 'action index out of bounds.' 
- return self.action_to_command_array[action_in] - - def _process_observation(self, pov, info): - ''' - Overwritten to simplify: returns only `pov` and - not as the MineRLEnv an obs_dict (observation directory) - ''' - pov = np.frombuffer(pov, dtype=np.uint8) - - if pov is None or len(pov) == 0: - raise Exception('Invalid observation, probably an aborted peek') - else: - pov = pov.reshape( - (self.height, self.width, self.depth) - )[::-1, :, :] - - assert self.observation_space.contains(pov) - - self._last_pov = pov - - return pov - - -class TrackingEnv(gym.Wrapper): - def __init__(self, env): - super().__init__(env) - self._actions = [ - self._forward, - self._turn_left, - self._turn_right - ] - - def _reset_state(self): - self._facing = (1, 0) - self._position = (0, 0) - self._visited = {} - self._update_visited() - - def _forward(self): - self._position = ( - self._position[0] + self._facing[0], - self._position[1] + self._facing[1] - ) - - def _turn_left(self): - self._facing = (self._facing[1], -self._facing[0]) - - def _turn_right(self): - self._facing = (-self._facing[1], self._facing[0]) - - def _encode_state(self): - return self._position - - def _update_visited(self): - state = self._encode_state() - value = self._visited.get(state, 0) - self._visited[state] = value + 1 - return value - - def reset(self): - self._reset_state() - return super().reset() - - def step(self, action): - o, r, d, i = super().step(action) - self._actions[action]() - revisit_count = self._update_visited() - if revisit_count == 0: - r += 0.1 - - return o, r, d, i - - -class TrajectoryWrapper(gym.Wrapper): - def __init__(self, env): - super().__init__(env) - self._trajectory = [] - self._action_to_malmo_command_array = ['move 1', 'turn -1', 'turn 1'] - - def get_trajectory(self): - return self._trajectory - - def _to_malmo_action(self, action_index): - return self._action_to_malmo_command_array[action_index] - - def step(self, action): - 
self._trajectory.append(self._to_malmo_action(action)) - o, r, d, i = super().step(action) - - return o, r, d, i - - -class DummyEnv(gym.Env): - def __init__(self): - self.observation_space = gym.spaces.Box( - low=0, - high=255, - shape=(84, 84, 6), - dtype=np.uint8) - self.action_space = gym.spaces.Discrete(3) - - -# Define a function to create a MineRL environment -def create_env(config): - mission = config["mission"] - port = 1000 * config.worker_index + config.vector_index - print('*********************************************') - print(f'* Worker {config.worker_index} creating from \ - mission: {mission}, port {port}') - print('*********************************************') - - if config.worker_index == 0: - # The first environment is only used for checking the action - # and observation space. By using a dummy environment, there's - # no need to spin up a Minecraft instance behind it saving some - # CPU resources on the head node. - return DummyEnv() - - env = EnvWrapper(mission, port) - env = TrackingEnv(env) - env = FrameStack(env, 2) - - return env - - -def create_env_for_rollout(config): - mission = config['mission'] - port = 1000 * config.worker_index + config.vector_index - print('*********************************************') - print(f'* Worker {config.worker_index} creating from \ - mission: {mission}, port {port}') - print('*********************************************') - - env = EnvWrapper(mission, port) - env = TrackingEnv(env) - env = FrameStack(env, 2) - env = TrajectoryWrapper(env) - - return env diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_missions/lava_maze-v0.xml b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_missions/lava_maze-v0.xml deleted file mode 100644 index 0e6d5dc7..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_missions/lava_maze-v0.xml +++ /dev/null @@ -1,95 +0,0 @@ - 
- - - - $(ENV_NAME) - - - - 50 - - - - - - clear - false - - - - - - - - - - - random - - - - - - - 0.6 - false - - - - - - - - - AML_Bot - - - - - - - - - 84 - 84 - - - - - - - - - moveMouse - inventory - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_missions/lava_maze_rollout-v0.xml b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_missions/lava_maze_rollout-v0.xml deleted file mode 100644 index 07051c02..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_missions/lava_maze_rollout-v0.xml +++ /dev/null @@ -1,95 +0,0 @@ - - - - - $(ENV_NAME) - - - - 50 - - - - - - clear - false - - - - - - - - - - - {SEED_PLACEHOLDER} - - - - - - - 0.6 - false - - - - - - - - - AML_Bot - - - - - - - - - 84 - 84 - - - - - - - - - moveMouse - inventory - - - - - - - - - - - - - - - - - - - - - - - - - - - - \ No newline at end of file diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_missions/lava_maze_rollout_video.xml b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_missions/lava_maze_rollout_video.xml deleted file mode 100644 index a0de6355..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_missions/lava_maze_rollout_video.xml +++ /dev/null @@ -1,74 +0,0 @@ - - - - - AML-Video-Gatherer - - - - 50 - - - - - - clear - false - - - - - - {SEED_PLACEHOLDER} - - - - - - - 0.6 - false - - - - - - - - - Agent - - - - - - - - - moveMouse - inventory - - - - - - - - - - Camera_Bot - - - - - - 860 - 480 - - - - - \ No newline at end of file diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_rollout.py 
b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_rollout.py deleted file mode 100644 index b4f4cafa..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_rollout.py +++ /dev/null @@ -1,130 +0,0 @@ -import argparse -import os -import re - -from azureml.core import Run -from azureml.core.model import Model - -from minecraft_environment import create_env_for_rollout -from malmo_video_recorder import MalmoVideoRecorder -from gym import wrappers - -import ray -import ray.tune as tune -from ray.rllib import rollout -from ray.tune.registry import get_trainable_cls - - -def write_mission_file_for_seed(mission_file, seed): - with open(mission_file, 'r') as base_file: - mission_file_path = mission_file.replace('v0', seed) - content = base_file.read().format(SEED_PLACEHOLDER=seed) - - mission_file = open(mission_file_path, 'w') - mission_file.writelines(content) - mission_file.close() - - return mission_file_path - - -def run_rollout(trainable_type, mission_file, seed): - # Writes the mission file for minerl - mission_file_path = write_mission_file_for_seed(mission_file, seed) - - # Instantiate the agent. Note: the IMPALA trainer implementation in - # Ray uses an AsyncSamplesOptimizer. Under the hood, this starts a - # LearnerThread which will wait for training samples. This will fail - # after a timeout, but has no influence on the rollout. 
See - # https://github.com/ray-project/ray/blob/708dff6d8f7dd6f7919e06c1845f1fea0cca5b89/rllib/optimizers/aso_learner.py#L66 - config = { - "env_config": { - "mission": mission_file_path, - "is_rollout": True, - "seed": seed - }, - "num_workers": 0 - } - cls = get_trainable_cls(args.run) - agent = cls(env="Minecraft", config=config) - - # The optimizer is not needed during a rollout - agent.optimizer.stop() - - # Load state from checkpoint - agent.restore(f'{checkpoint_path}/{checkpoint_file}') - - # Get a reference to the environment - env = agent.workers.local_worker().env - - # Let the agent choose actions until the game is over - obs = env.reset() - done = False - total_reward = 0 - - while not done: - action = agent.compute_action(obs) - obs, reward, done, info = env.step(action) - - total_reward += reward - - print(f'Total reward using seed {seed}: {total_reward}') - - # This avoids a sigterm trace in the logs, see minerl.env.malmo.Instance - env.instance.watcher_process.kill() - - env.close() - agent.stop() - - return env.get_trajectory() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument('--model_name', required=True) - parser.add_argument('--run', required=False, default="IMPALA") - args = parser.parse_args() - - # Register custom Minecraft environment - tune.register_env("Minecraft", create_env_for_rollout) - - ray.init(address='auto') - - # Download the model files (contains a checkpoint) - ws = Run.get_context().experiment.workspace - model = Model(ws, args.model_name) - checkpoint_path = model.download(exist_ok=True) - - files_ = os.listdir(checkpoint_path) - cp_pattern = re.compile('^checkpoint-\\d+$') - - checkpoint_file = None - for f_ in files_: - if cp_pattern.match(f_): - checkpoint_file = f_ - - if checkpoint_file is None: - raise Exception("No checkpoint file found.") - - # These are the Minecraft mission seeds for the rollouts - rollout_seeds = ['1234', '43289', '65224', '983341'] - - # Initialize the 
Malmo video recorder - video_recorder = MalmoVideoRecorder() - video_recorder.init_malmo() - - # Path references to the mission files - base_training_mission_file = \ - 'minecraft_missions/lava_maze_rollout-v0.xml' - base_video_recording_mission_file = \ - 'minecraft_missions/lava_maze_rollout_video.xml' - - for seed in rollout_seeds: - trajectory = run_rollout( - args.run, - base_training_mission_file, - seed) - - video_recorder.record_malmo_video( - trajectory, - base_video_recording_mission_file, - seed) diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_train.py b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_train.py deleted file mode 100644 index d97411d6..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/minecraft_train.py +++ /dev/null @@ -1,49 +0,0 @@ -import os - -import ray -import ray.tune as tune - -from utils import callbacks -from minecraft_environment import create_env - - -def stop(trial_id, result): - max_train_time = int(os.environ.get("AML_MAX_TRAIN_TIME_SECONDS", 5 * 60 * 60)) - - return result["episode_reward_mean"] >= 1 \ - or result["time_total_s"] >= max_train_time - - -if __name__ == '__main__': - tune.register_env("Minecraft", create_env) - - ray.init(address='auto') - - tune.run( - run_or_experiment="IMPALA", - config={ - "env": "Minecraft", - "env_config": { - "mission": "minecraft_missions/lava_maze-v0.xml" - }, - "num_workers": 10, - "num_cpus_per_worker": 2, - "rollout_fragment_length": 50, - "train_batch_size": 1024, - "replay_buffer_num_slots": 4000, - "replay_proportion": 10, - "learner_queue_timeout": 900, - "num_sgd_iter": 2, - "num_data_loader_buffers": 2, - "exploration_config": { - "type": "EpsilonGreedy", - "initial_epsilon": 1.0, - "final_epsilon": 0.02, - "epsilon_timesteps": 500000 - }, - "callbacks": {"on_train_result": callbacks.on_train_result}, - }, - stop=stop, - 
checkpoint_at_end=True, - local_dir='./logs' - ) diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/networkutils.py b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/networkutils.py deleted file mode 100644 index 64af7d8b..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/networkutils.py +++ /dev/null @@ -1,237 +0,0 @@ -import sys -import csv -from azure.mgmt.network import NetworkManagementClient - - -def check_port_in_port_range(expected_port: str, - dest_port_range: str): - """ - Check if a port is within a port range - Port range maybe like *, 8080 or 8888-8889 - """ - - if dest_port_range == '*': - return True - - dest_ports = dest_port_range.split('-') - - if len(dest_ports) == 1 and \ - int(dest_ports[0]) == int(expected_port): - return True - - if len(dest_ports) == 2 and \ - int(dest_ports[0]) <= int(expected_port) and \ - int(dest_ports[1]) >= int(expected_port): - return True - - return False - - -def check_port_in_destination_port_ranges(expected_port: str, - dest_port_ranges: list): - """ - Check if a port is within a given list of port ranges - i.e. check if port 8080 is in port ranges of 22,80,8080-8090,443 - """ - - for dest_port_range in dest_port_ranges: - if check_port_in_port_range(expected_port, dest_port_range) is True: - return True - - return False - - -def check_ports_in_destination_port_ranges(expected_ports: list, - dest_port_ranges: list): - """ - Check if all ports in a given port list are within a given list - of port ranges - i.e. 
check if port 8080,8081 are in port ranges of 22,80,8080-8090,443 - """ - - for expected_port in expected_ports: - if check_port_in_destination_port_ranges( - expected_port, dest_port_ranges) is False: - return False - - return True - - -def check_source_address_prefix(source_address_prefix: str): - """Check if source address prefix is BatchNodeManagement or default""" - - required_prefix = 'BatchNodeManagement' - default_prefix = 'default' - - if source_address_prefix.lower() == required_prefix.lower() or \ - source_address_prefix.lower() == default_prefix.lower(): - return True - - return False - - -def check_protocol(protocol: str): - """Check if protocol is supported - Tcp/Any""" - - required_protocol = 'Tcp' - any_protocol = 'Any' - - if required_protocol.lower() == protocol.lower() or \ - any_protocol.lower() == protocol.lower(): - return True - - return False - - -def check_direction(direction: str): - """Check if port direction is inbound""" - - required_direction = 'Inbound' - - if required_direction.lower() == direction.lower(): - return True - - return False - - -def check_provisioning_state(provisioning_state: str): - """Check if the provisioning state is succeeded""" - - required_provisioning_state = 'Succeeded' - - if required_provisioning_state.lower() == provisioning_state.lower(): - return True - - return False - - -def check_rule_for_Azure_ML(rule): - """Check if the ports required for Azure Machine Learning are open""" - - required_ports = ['29876', '29877'] - - if check_source_address_prefix(rule.source_address_prefix) is False: - return False - - if check_protocol(rule.protocol) is False: - return False - - if check_direction(rule.direction) is False: - return False - - if check_provisioning_state(rule.provisioning_state) is False: - return False - - if rule.destination_port_range is not None: - if check_ports_in_destination_port_ranges( - required_ports, - [rule.destination_port_range]) is False: - return False - else: - if 
check_ports_in_destination_port_ranges( - required_ports, - rule.destination_port_ranges) is False: - return False - - return True - - -def check_vnet_security_rules(auth_object, - vnet_subscription_id, - vnet_resource_group, - vnet_name, - save_to_file=False): - """ - Check all the rules of virtual network if required ports for Azure Machine - Learning are open - """ - - network_client = NetworkManagementClient( - auth_object, - vnet_subscription_id) - - # get the vnet - vnet = network_client.virtual_networks.get( - resource_group_name=vnet_resource_group, - virtual_network_name=vnet_name) - - vnet_location = vnet.location - vnet_info = [] - - if vnet.subnets is None or len(vnet.subnets) == 0: - print('WARNING: No subnet found for VNet:', vnet_name) - - # for each subnet of the vnet - for subnet in vnet.subnets: - if subnet.network_security_group is None: - print('WARNING: No network security group found for subnet.', - 'Subnet', - subnet.id.split("/")[-1]) - else: - # get all the rules - network_security_group_name = \ - subnet.network_security_group.id.split("/")[-1] - network_security_group_resource_group_name = \ - subnet.network_security_group.id.split("/")[4] - network_security_group_subscription_id = \ - subnet.network_security_group.id.split("/")[2] - - security_rules = list(network_client.security_rules.list( - network_security_group_resource_group_name, - network_security_group_name)) - - rule_matched = None - for rule in security_rules: - rule_info = [] - # add vnet details - rule_info.append(vnet_name) - rule_info.append(vnet_subscription_id) - rule_info.append(vnet_resource_group) - rule_info.append(vnet_location) - # add subnet details - rule_info.append(subnet.id.split("/")[-1]) - rule_info.append(network_security_group_name) - rule_info.append(network_security_group_subscription_id) - rule_info.append(network_security_group_resource_group_name) - # add rule details - rule_info.append(rule.priority) - rule_info.append(rule.name) - 
rule_info.append(rule.source_address_prefix) - if rule.destination_port_range is not None: - rule_info.append(rule.destination_port_range) - else: - rule_info.append(rule.destination_port_ranges) - rule_info.append(rule.direction) - rule_info.append(rule.provisioning_state) - vnet_info.append(rule_info) - - if check_rule_for_Azure_ML(rule) is True: - rule_matched = rule - - if rule_matched is not None: - print("INFORMATION: Rule matched with required ports. Subnet:", - subnet.id.split("/")[-1], "Rule:", rule.name) - else: - print("WARNING: No rule matched with required ports. Subnet:", - subnet.id.split("/")[-1]) - - if save_to_file is True: - file_name = vnet_name + ".csv" - with open(file_name, mode='w') as vnet_rule_file: - vnet_rule_file_writer = csv.writer( - vnet_rule_file, - delimiter=',', - quotechar='"', - quoting=csv.QUOTE_MINIMAL) - header = ['VNet_Name', 'VNet_Subscription_ID', - 'VNet_Resource_Group', 'VNet_Location', - 'Subnet_Name', 'NSG_Name', - 'NSG_Subscription_ID', 'NSG_Resource_Group', - 'Rule_Priority', 'Rule_Name', 'Rule_Source', - 'Rule_Destination_Ports', 'Rule_Direction', - 'Rule_Provisioning_State'] - vnet_rule_file_writer.writerow(header) - vnet_rule_file_writer.writerows(vnet_info) - - print("INFORMATION: Network security group rules for your virtual \ -network are saved in file", file_name) diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/utils/callbacks.py b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/utils/callbacks.py deleted file mode 100644 index 782c694d..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/files/utils/callbacks.py +++ /dev/null @@ -1,18 +0,0 @@ -'''RLlib callbacks module: - Common callback methods to be passed to RLlib trainer. -''' - -from azureml.core import Run - - -def on_train_result(info): - '''Callback on train result to record metrics returned by trainer. 
- ''' - run = Run.get_context() - run.log( - name='episode_reward_mean', - value=info["result"]["episode_reward_mean"]) - - run.log( - name='episodes_total', - value=info["result"]["episodes_total"]) diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/images/lava_maze_minecraft.gif b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/images/lava_maze_minecraft.gif deleted file mode 100644 index dc508fd4..00000000 Binary files a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/images/lava_maze_minecraft.gif and /dev/null differ diff --git a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/minecraft.ipynb b/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/minecraft.ipynb deleted file mode 100644 index b5cff26f..00000000 --- a/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/minecraft.ipynb +++ /dev/null @@ -1,1076 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Copyright (c) Microsoft Corporation. All rights reserved.\n", - "\n", - "Licensed under the MIT License." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "![Impressions](https://PixelServer20190423114238.azurewebsites.net/api/impressions/MachineLearningNotebooks/how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/minecraft.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Reinforcement Learning in Azure Machine Learning - Training a Minecraft agent using custom environments\n", - "\n", - "This tutorial will show how to set up a more complex reinforcement\n", - "learning (RL) training scenario. 
It demonstrates how to train an agent to\n", - "navigate through a lava maze in the Minecraft game using Azure Machine\n", - "Learning.\n", - "\n", - "**Please note:** This notebook trains an agent on a randomly generated\n", - "Minecraft level. As a result, on rare occasions, a training run may fail\n", - "to produce a model that can solve the maze. If this happens, you can\n", - "re-run the training step as indicated below.\n", - "\n", - "**Please note:** This notebook uses 1 NC6 type node and 8 D2 type nodes\n", - "for up to 5 hours of training, which corresponds to approximately $9.06 (USD)\n", - "as of May 2020.\n", - "\n", - "Minecraft is currently one of the most popular video\n", - "games and as such has been a study object for RL. [Project \n", - "Malmo](https://www.microsoft.com/en-us/research/project/project-malmo/) is\n", - "a platform for artificial intelligence experimentation and research built on\n", - "top of Minecraft. We will use Minecraft [gym](https://gym.openai.com) environments from Project\n", - "Malmo's 2019 MineRL competition, which are part of the \n", - "[MineRL](http://minerl.io/docs/index.html) Python package.\n", - "\n", - "Minecraft environments require a display to run, so we will demonstrate\n", - "how to set up a virtual display within the docker container used for training.\n", - "Learning will be based on the agent's visual observations. To\n", - "generate the necessary amount of sample data, we will run several\n", - "instances of the Minecraft game in parallel. Below, you can see a video of\n", - "a trained agent navigating a lava maze. Starting from the green position,\n", - "it moves to the blue position by moving forward, turning left or turning right:\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
\n", - " \"Minecraft\n", - "
Fig 1. Video of a trained Minecraft agent navigating a lava maze.
\n", - "\n", - "The tutorial will cover the following steps:\n", - "- Initializing Azure Machine Learning resources for training\n", - "- Training the RL agent with Azure Machine Learning service\n", - "- Monitoring training progress\n", - "- Reviewing training results\n", - "\n", - "\n", - "## Prerequisites\n", - "\n", - "The user should have completed the Azure Machine Learning introductory tutorial.\n", - "You will need to make sure that you have a valid subscription id, a resource group and a\n", - "workspace. For detailed instructions see [Tutorial: Get started creating\n", - "your first ML experiment.](https://docs.microsoft.com/en-us/azure/machine-learning/tutorial-1st-experiment-sdk-setup)\n", - "\n", - "While this is a standalone notebook, we highly recommend going over the\n", - "introductory notebooks for RL first.\n", - "- Getting started:\n", - " - [RL using a compute instance with Azure Machine Learning service](../cartpole-on-compute-instance/cartpole_ci.ipynb)\n", - " - [Using Azure Machine Learning compute](../cartpole-on-single-compute/cartpole_sc.ipynb)\n", - "- [Scaling RL training runs with Azure Machine Learning service](../atari-on-distributed-compute/pong_rllib.ipynb)\n", - "\n", - "\n", - "## Initialize resources\n", - "\n", - "All required Azure Machine Learning service resources for this tutorial can be set up from Jupyter.\n", - "This includes:\n", - "- Connecting to your existing Azure Machine Learning workspace.\n", - "- Creating an experiment to track runs.\n", - "- Setting up a virtual network\n", - "- Creating remote compute targets for [Ray](https://docs.ray.io/en/latest/index.html).\n", - "\n", - "### Azure Machine Learning SDK\n", - "\n", - "Display the Azure Machine Learning SDK version." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import azureml.core\n", - "print(\"Azure Machine Learning SDK Version: \", azureml.core.VERSION) " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Connect to workspace\n", - "\n", - "Get a reference to an existing Azure Machine Learning workspace." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from azureml.core import Workspace\n", - "\n", - "ws = Workspace.from_config()\n", - "print(ws.name, ws.location, ws.resource_group, sep=' | ')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Create an experiment\n", - "\n", - "Create an experiment to track the runs in your workspace. A\n", - "workspace can have multiple experiments and each experiment\n", - "can be used to track multiple runs (see [documentation](https://docs.microsoft.com/en-us/python/api/azureml-core/azureml.core.experiment.experiment?view=azure-ml-py)\n", - "for details)." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "nbpresent": { - "id": "bc70f780-c240-4779-96f3-bc5ef9a37d59" - } - }, - "outputs": [], - "source": [ - "from azureml.core import Experiment\n", - "\n", - "exp = Experiment(workspace=ws, name='minecraft-maze')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Create Virtual Network\n", - "\n", - "If you are using separate compute targets for the Ray head and worker, a virtual network must be created in the resource group. 
If you have alraeady created a virtual network in the resource group, you can skip this step.\n", - "\n", - "To do this, you first must install the Azure Networking API.\n", - "\n", - "`pip install --upgrade azure-mgmt-network==12.0.0`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# If you need to install the Azure Networking SDK, uncomment the following line.\n", - "#!pip install --upgrade azure-mgmt-network==12.0.0" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from azure.mgmt.network import NetworkManagementClient\n", - "\n", - "# Virtual network name\n", - "vnet_name =\"rl_minecraft_vnet\"\n", - "\n", - "# Default subnet\n", - "subnet_name =\"default\"\n", - "\n", - "# The Azure subscription you are using\n", - "subscription_id=ws.subscription_id\n", - "\n", - "# The resource group for the reinforcement learning cluster\n", - "resource_group=ws.resource_group\n", - "\n", - "# Azure region of the resource group\n", - "location=ws.location\n", - "\n", - "network_client = NetworkManagementClient(ws._auth_object, subscription_id)\n", - "\n", - "async_vnet_creation = network_client.virtual_networks.create_or_update(\n", - " resource_group,\n", - " vnet_name,\n", - " {\n", - " 'location': location,\n", - " 'address_space': {\n", - " 'address_prefixes': ['10.0.0.0/16']\n", - " }\n", - " }\n", - ")\n", - "\n", - "async_vnet_creation.wait()\n", - "print(\"Virtual network created successfully: \", async_vnet_creation.result())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Set up Network Security Group on Virtual Network\n", - "\n", - "Depending on your Azure setup, you may need to open certain ports to make it possible for Azure to manage the compute targets that you create. 
The ports that need to be opened are described [here](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-enable-virtual-network).\n", - "\n", - "A common situation is that ports `29876-29877` are closed. The following code will add a security rule to open these ports. Or you can do this manually in the [Azure portal](https://portal.azure.com).\n", - "\n", - "You may need to modify the code below to match your scenario." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import azure.mgmt.network.models\n", - "\n", - "security_group_name = vnet_name + '-' + \"nsg\"\n", - "security_rule_name = \"AllowAML\"\n", - "\n", - "# Create a network security group\n", - "nsg_params = azure.mgmt.network.models.NetworkSecurityGroup(\n", - " location=location,\n", - " security_rules=[\n", - " azure.mgmt.network.models.SecurityRule(\n", - " name=security_rule_name,\n", - " access=azure.mgmt.network.models.SecurityRuleAccess.allow,\n", - " description='Reinforcement Learning in Azure Machine Learning rule',\n", - " destination_address_prefix='*',\n", - " destination_port_range='29876-29877',\n", - " direction=azure.mgmt.network.models.SecurityRuleDirection.inbound,\n", - " priority=400,\n", - " protocol=azure.mgmt.network.models.SecurityRuleProtocol.tcp,\n", - " source_address_prefix='BatchNodeManagement',\n", - " source_port_range='*'\n", - " ),\n", - " ],\n", - ")\n", - "\n", - "async_nsg_creation = network_client.network_security_groups.create_or_update(\n", - " resource_group,\n", - " security_group_name,\n", - " nsg_params,\n", - ")\n", - "\n", - "async_nsg_creation.wait() \n", - "print(\"Network security group created successfully:\", async_nsg_creation.result())\n", - "\n", - "network_security_group = network_client.network_security_groups.get(\n", - " resource_group,\n", - " security_group_name,\n", - ")\n", - "\n", - "# Define a subnet to be created with network security group\n", - "subnet = 
azure.mgmt.network.models.Subnet(\n", - " id='default',\n", - " address_prefix='10.0.0.0/24',\n", - " network_security_group=network_security_group\n", - " )\n", - " \n", - "# Create subnet on virtual network\n", - "async_subnet_creation = network_client.subnets.create_or_update(\n", - " resource_group_name=resource_group,\n", - " virtual_network_name=vnet_name,\n", - " subnet_name=subnet_name,\n", - " subnet_parameters=subnet\n", - ")\n", - "\n", - "async_subnet_creation.wait()\n", - "print(\"Subnet created successfully:\", async_subnet_creation.result())" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Review the virtual network security rules\n", - "Ensure that the virtual network is configured correctly with required ports open. It is possible that you have configured rules with broader range of ports that allows ports 29876-29877 to be opened. Kindly review your network security group rules. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from files.networkutils import *\n", - "\n", - "check_vnet_security_rules(ws._auth_object, ws.subscription_id, ws.resource_group, vnet_name, True)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Create or attach an existing compute resource\n", - "\n", - "A compute target is a designated compute resource where you\n", - "run your training script. For more information, see [What\n", - "are compute targets in Azure Machine Learning service?](https://docs.microsoft.com/en-us/azure/machine-learning/concept-compute-target).\n", - "\n", - "#### GPU target for Ray head\n", - "\n", - "In the experiment setup for this tutorial, the Ray head node\n", - "will run on a GPU-enabled node. A maximum cluster size\n", - "of 1 node is therefore sufficient. If you wish to run\n", - "multiple experiments in parallel using the same GPU\n", - "cluster, you may elect to increase this number. 
The cluster\n", - "will automatically scale down to 0 nodes when no training jobs\n", - "are scheduled (see `min_nodes`).\n", - "\n", - "The code below creates a compute cluster of GPU-enabled NC6\n", - "nodes. If the cluster with the specified name is already in\n", - "your workspace the code will skip the creation process.\n", - "\n", - "Note that we must specify a Virtual Network during compute\n", - "creation to allow communication between the cluster running\n", - "the Ray head node and the additional Ray compute nodes. For\n", - "details on how to setup the Virtual Network, please follow the\n", - "instructions in the \"Prerequisites\" section above.\n", - "\n", - "**Note: Creation of a compute resource can take several minutes**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from azureml.core.compute import ComputeTarget, AmlCompute\n", - "from azureml.core.compute_target import ComputeTargetException\n", - "\n", - "gpu_cluster_name = 'gpu-cl-nc6-vnet'\n", - "\n", - "try:\n", - " gpu_cluster = ComputeTarget(workspace=ws, name=gpu_cluster_name)\n", - " print('Found existing compute target')\n", - "except ComputeTargetException:\n", - " print('Creating a new compute target...')\n", - " compute_config = AmlCompute.provisioning_configuration(\n", - " vm_size='Standard_NC6',\n", - " min_nodes=0,\n", - " max_nodes=1,\n", - " vnet_resourcegroup_name=ws.resource_group,\n", - " vnet_name=vnet_name,\n", - " subnet_name=subnet_name)\n", - "\n", - " gpu_cluster = ComputeTarget.create(ws, gpu_cluster_name, compute_config)\n", - " gpu_cluster.wait_for_completion(show_output=True, min_node_count=None, timeout_in_minutes=20)\n", - "\n", - " print('Cluster created.')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### CPU target for additional Ray nodes\n", - "\n", - "The code below creates a compute cluster of D2 nodes. 
If the cluster with the specified name is already in your workspace the code will skip the creation process.\n", - "\n", - "This cluster will be used to start additional Ray nodes\n", - "increasing the clusters CPU resources.\n", - "\n", - "**Note: Creation of a compute resource can take several minutes**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "cpu_cluster_name = 'cpu-cl-d2-vnet'\n", - "\n", - "try:\n", - " cpu_cluster = ComputeTarget(workspace=ws, name=cpu_cluster_name)\n", - " print('Found existing compute target')\n", - "except ComputeTargetException:\n", - " print('Creating a new compute target...')\n", - " compute_config = AmlCompute.provisioning_configuration(\n", - " vm_size='STANDARD_D2',\n", - " min_nodes=0,\n", - " max_nodes=10,\n", - " vnet_resourcegroup_name=ws.resource_group,\n", - " vnet_name=vnet_name,\n", - " subnet_name=subnet_name)\n", - "\n", - " cpu_cluster = ComputeTarget.create(ws, cpu_cluster_name, compute_config)\n", - " cpu_cluster.wait_for_completion(show_output=True, min_node_count=None, timeout_in_minutes=20)\n", - "\n", - " print('Cluster created.')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Training the agent\n", - "\n", - "### Training environments\n", - "\n", - "This tutorial uses custom docker images (CPU and GPU respectively)\n", - "with the necessary software installed. The\n", - "[Environment](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-use-environments)\n", - "class stores the configuration for the training environment. The docker\n", - "image is set via `env.docker.base_image` which can point to any\n", - "publicly available docker image. 
`user_managed_dependencies`\n", - "is set so that the preinstalled Python packages in the image are preserved.\n", - "\n", - "Note that since Minecraft requires a display to start, we set the `interpreter_path`\n", - "such that the Python process is started via **xvfb-run**." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import os\n", - "from azureml.core import Environment\n", - "\n", - "max_train_time = os.environ.get(\"AML_MAX_TRAIN_TIME_SECONDS\", 5 * 60 * 60)\n", - "\n", - "def create_env(env_type):\n", - " env = Environment(name='minecraft-{env_type}'.format(env_type=env_type))\n", - "\n", - " env.docker.enabled = True\n", - " env.docker.base_image = 'akdmsft/minecraft-{env_type}'.format(env_type=env_type)\n", - "\n", - " env.python.interpreter_path = \"xvfb-run -s '-screen 0 640x480x16 -ac +extension GLX +render' python\"\n", - " env.environment_variables[\"AML_MAX_TRAIN_TIME_SECONDS\"] = str(max_train_time)\n", - " env.python.user_managed_dependencies = True\n", - " \n", - " return env\n", - " \n", - "cpu_minecraft_env = create_env('cpu')\n", - "gpu_minecraft_env = create_env('gpu')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Training script\n", - "\n", - "As described above, we use the MineRL Python package to launch\n", - "Minecraft game instances. MineRL provides several OpenAI gym\n", - "environments for different scenarios, such as chopping wood.\n", - "Besides predefined environments, MineRL lets its users create\n", - "custom Minecraft environments through\n", - "[minerl.env](http://minerl.io/docs/api/env.html). In the helper\n", - "file **minecraft_environment.py** provided with this tutorial, we use the\n", - "latter option to customize a Minecraft level with a lava maze\n", - "that the agent has to navigate. The agent receives a negative\n", - "reward of -1 for falling into the lava, a negative reward of\n", - "-0.02 for sending a command (i.e. 
navigating through the maze\n", - "with fewer actions yields a higher total reward) and a positive reward\n", - "of 1 for reaching the goal. To encourage the agent to explore\n", - "the maze, it also receives a positive reward of 0.1 for visiting\n", - "a tile for the first time.\n", - "\n", - "The agent learns purely from visual observations and the image\n", - "is scaled to an 84x84 format, stacking four frames. For the\n", - "purposes of this example, we use a small action space of size\n", - "three: move forward, turn 90 degrees to the left, and turn 90\n", - "degrees to the right.\n", - "\n", - "The training script itself registers the function to create training\n", - "environments with the `tune.register_env` function and connects to\n", - "the Ray cluster Azure Machine Learning service started on the GPU \n", - "and CPU nodes. Lastly, it starts a RL training run with `tune.run()`.\n", - "\n", - "We recommend setting the `local_dir` parameter to `./logs` as this\n", - "directory will automatically become available as part of the training\n", - "run's files in the Azure Portal. The Tensorboard integration\n", - "(see \"View the Tensorboard\" section below) also depends on the files'\n", - "availability. 
For a list of common parameter options, please refer\n", - "to the [Ray documentation](https://docs.ray.io/en/latest/rllib-training.html#common-parameters).\n", - "\n", - "\n", - "```python\n", - "# Taken from minecraft_environment.py and minecraft_train.py\n", - "\n", - "# Define a function to create a MineRL environment\n", - "def create_env(config):\n", - " mission = config['mission']\n", - " port = 1000 * config.worker_index + config.vector_index\n", - " print('*********************************************')\n", - " print(f'* Worker {config.worker_index} creating from mission: {mission}, port {port}')\n", - " print('*********************************************')\n", - "\n", - " if config.worker_index == 0:\n", - " # The first environment is only used for checking the action and observation space.\n", - " # By using a dummy environment, there's no need to spin up a Minecraft instance behind it\n", - " # saving some CPU resources on the head node.\n", - " return DummyEnv()\n", - "\n", - " env = EnvWrapper(mission, port)\n", - " env = TrackingEnv(env)\n", - " env = FrameStack(env, 2)\n", - " \n", - " return env\n", - "\n", - "\n", - "def stop(trial_id, result):\n", - " return result[\"episode_reward_mean\"] >= 1 \\\n", - " or result[\"time_total_s\"] > 5 * 60 * 60\n", - "\n", - "\n", - "if __name__ == '__main__':\n", - " tune.register_env(\"Minecraft\", create_env)\n", - "\n", - " ray.init(address='auto')\n", - "\n", - " tune.run(\n", - " run_or_experiment=\"IMPALA\",\n", - " config={\n", - " \"env\": \"Minecraft\",\n", - " \"env_config\": {\n", - " \"mission\": \"minecraft_missions/lava_maze-v0.xml\"\n", - " },\n", - " \"num_workers\": 10,\n", - " \"num_cpus_per_worker\": 2,\n", - " \"rollout_fragment_length\": 50,\n", - " \"train_batch_size\": 1024,\n", - " \"replay_buffer_num_slots\": 4000,\n", - " \"replay_proportion\": 10,\n", - " \"learner_queue_timeout\": 900,\n", - " \"num_sgd_iter\": 2,\n", - " \"num_data_loader_buffers\": 2,\n", - " 
\"exploration_config\": {\n", - " \"type\": \"EpsilonGreedy\",\n", - " \"initial_epsilon\": 1.0,\n", - " \"final_epsilon\": 0.02,\n", - " \"epsilon_timesteps\": 500000\n", - " },\n", - " \"callbacks\": {\"on_train_result\": callbacks.on_train_result},\n", - " },\n", - " stop=stop,\n", - " checkpoint_at_end=True,\n", - " local_dir='./logs'\n", - " )\n", - "```\n", - "\n", - "### Submitting a training run\n", - "\n", - "Below, you create the training run using a `ReinforcementLearningEstimator`\n", - "object, which contains all the configuration parameters for this experiment:\n", - "- `source_directory`: Contains the training script and helper files to be\n", - "copied onto the node running the Ray head.\n", - "- `entry_script`: The training script, described in more detail above..\n", - "- `compute_target`: The compute target for the Ray head and training\n", - "script execution.\n", - "- `environment`: The Azure machine learning environment definition for\n", - "the node running the Ray head.\n", - "- `worker_configuration`: The configuration object for the additional\n", - "Ray nodes to be attached to the Ray cluster:\n", - " - `compute_target`: The compute target for the additional Ray nodes.\n", - " - `node_count`: The number of nodes to attach to the Ray cluster.\n", - " - `environment`: The environment definition for the additional Ray nodes.\n", - "- `max_run_duration_seconds`: The time after which to abort the run if it\n", - "is still running.\n", - "- `shm_size`: The size of docker container's shared memory block. \n", - "\n", - "For more details, please take a look at the [online documentation](https://docs.microsoft.com/en-us/python/api/azureml-contrib-reinforcementlearning/?view=azure-ml-py)\n", - "for Azure Machine Learning service's reinforcement learning offering.\n", - "\n", - "We configure 8 extra D2 (worker) nodes for the Ray cluster, giving us a total of\n", - "22 CPUs and 1 GPU. 
The GPU and one CPU are used by the IMPALA learner,\n", - "and each MineRL environment receives 2 CPUs allowing us to spawn a total\n", - "of 10 rollout workers (see `num_workers` parameter in the training script).\n", - "\n", - "\n", - "Lastly, the `RunDetails` widget displays information about the submitted\n", - "RL experiment, including a link to the Azure portal with more details." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from azureml.contrib.train.rl import ReinforcementLearningEstimator, WorkerConfiguration\n", - "from azureml.widgets import RunDetails\n", - "\n", - "worker_config = WorkerConfiguration(\n", - " compute_target=cpu_cluster, \n", - " node_count=8,\n", - " environment=cpu_minecraft_env)\n", - "\n", - "rl_est = ReinforcementLearningEstimator(\n", - " source_directory='files',\n", - " entry_script='minecraft_train.py',\n", - " compute_target=gpu_cluster,\n", - " environment=gpu_minecraft_env,\n", - " worker_configuration=worker_config,\n", - " max_run_duration_seconds=6 * 60 * 60,\n", - " shm_size=1024 * 1024 * 1024 * 30)\n", - "\n", - "train_run = exp.submit(rl_est)\n", - "\n", - "RunDetails(train_run).show()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# If you wish to cancel the run before it completes, uncomment and execute:\n", - "#train_run.cancel()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Monitoring training progress\n", - "\n", - "### View the Tensorboard\n", - "\n", - "The Tensorboard can be displayed via the Azure Machine Learning service's\n", - "[Tensorboard API](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-monitor-tensorboard).\n", - "When running locally, please make sure to follow the instructions in the\n", - "link and install required packages. 
Running this cell will output a URL\n", - "for the Tensorboard.\n", - "\n", - "Note that the training script sets the log directory when starting RLlib\n", - "via the `local_dir` parameter. `./logs` will automatically appear in\n", - "the downloadable files for a run. Since this script is executed on the\n", - "Ray head node run, we need to get a reference to it as shown below.\n", - "\n", - "The Tensorboard API will continuously stream logs from the run.\n", - "\n", - "**Note: It may take a couple of minutes after the run is in \"Running\" state\n", - "before Tensorboard files are available and the board will refresh automatically**" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import time\n", - "from azureml.tensorboard import Tensorboard\n", - "\n", - "head_run = None\n", - "\n", - "timeout = 60\n", - "while timeout > 0 and head_run is None:\n", - " timeout -= 1\n", - " \n", - " try:\n", - " head_run = next(r for r in train_run.get_children() if r.id.endswith('head'))\n", - " except StopIteration:\n", - " time.sleep(1)\n", - "\n", - "tb = Tensorboard([head_run], port=6007)\n", - "tb.start()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Review results\n", - "\n", - "Please ensure that the training run has completed before continuing with this section." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "train_run.wait_for_completion()\n", - "\n", - "print('Training run completed.')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "**Please note:** If the final \"episode_reward_mean\" metric from the training run is negative,\n", - "the produced model does not solve the problem of navigating the maze well. You can view\n", - "the metric on the Tensorboard or in \"Metrics\" section of the head run in the Azure Machine Learning\n", - "portal. 
We recommend training a new model by rerunning the notebook starting from \"Submitting a training run\".\n", - "\n", - "\n", - "### Export final model\n", - "\n", - "The key result from the training run is the final checkpoint\n", - "containing the state of the IMPALA trainer (model) upon meeting the\n", - "stopping criteria specified in `minecraft_train.py`.\n", - "\n", - "Azure Machine Learning service offers the [Model.register()](https://docs.microsoft.com/en-us/python/api/azureml-core/azureml.core.model.model?view=azure-ml-py)\n", - "API which allows you to persist the model files from the\n", - "training run. We identify the directory containing the\n", - "final model written during the training run and register\n", - "it with Azure Machine Learning service. We use a Dataset\n", - "object to filter out the correct files." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import re\n", - "import tempfile\n", - "\n", - "from azureml.core import Dataset\n", - "\n", - "path_prefix = os.path.join(tempfile.gettempdir(), 'tmp_training_artifacts')\n", - "\n", - "run_artifacts_path = os.path.join('azureml', head_run.id)\n", - "datastore = ws.get_default_datastore()\n", - "\n", - "run_artifacts_ds = Dataset.File.from_files(datastore.path(os.path.join(run_artifacts_path, '**')))\n", - "\n", - "cp_pattern = re.compile('.*checkpoint-\\\\d+$')\n", - "\n", - "checkpoint_files = [file for file in run_artifacts_ds.to_path() if cp_pattern.match(file)]\n", - "\n", - "# There should only be one checkpoint with our training settings...\n", - "final_checkpoint = os.path.dirname(os.path.join(run_artifacts_path, os.path.normpath(checkpoint_files[-1][1:])))\n", - "datastore.download(target_path=path_prefix, prefix=final_checkpoint.replace('\\\\', '/'), show_progress=True)\n", - "\n", - "print('Download complete.')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - 
"source": [ - "from azureml.core.model import Model\n", - "\n", - "model_name = 'final_model_minecraft_maze'\n", - "\n", - "model = Model.register(\n", - " workspace=ws,\n", - " model_path=os.path.join(path_prefix, final_checkpoint),\n", - " model_name=model_name,\n", - " description='Model of an agent trained to navigate a lava maze in Minecraft.')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Models can be used through a varity of APIs. Please see the\n", - "[documentation](https://docs.microsoft.com/en-us/azure/machine-learning/how-to-deploy-and-where)\n", - "for more details.\n", - "\n", - "### Test agent performance in a rollout\n", - "\n", - "To observe the trained agent's behavior, it is a common practice to\n", - "view its behavior in a rollout. The previous reinforcement learning\n", - "tutorials explain rollouts in more detail.\n", - "\n", - "The provided `minecraft_rollout.py` script loads the final checkpoint\n", - "of the trained agent from the model registered with Azure Machine Learning\n", - "service. It then starts a rollout on 4 different lava maze layouts, that\n", - "are all larger and thus more difficult than the maze the agent was trained\n", - "on. The script further records videos by replaying the agent's decisions\n", - "in [Malmo](https://github.com/microsoft/malmo). Malmo supports multiple\n", - "agents in the same environment, thus allowing us to capture videos that\n", - "depict the agent from another agent's perspective. The provided\n", - "`malmo_video_recorder.py` file and the Malmo Github repository have more\n", - "details on the video recording setup.\n", - "\n", - "You can view the rewards for each rollout episode in the logs for the 'head'\n", - "run submitted below. In some episodes, the agent may fail to reach the goal\n", - "due to the higher level of difficulty - in practice, we could continue\n", - "training the agent on harder tasks starting with the final checkpoint." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "script_params = {\n", - " '--model_name': model_name\n", - "}\n", - "\n", - "rollout_est = ReinforcementLearningEstimator(\n", - " source_directory='files',\n", - " entry_script='minecraft_rollout.py',\n", - " script_params=script_params,\n", - " compute_target=gpu_cluster,\n", - " environment=gpu_minecraft_env,\n", - " shm_size=1024 * 1024 * 1024 * 30)\n", - "\n", - "rollout_run = exp.submit(rollout_est)\n", - "\n", - "RunDetails(rollout_run).show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### View videos captured during rollout\n", - "\n", - "To inspect the agent's training progress you can view the videos captured\n", - "during the rollout episodes. First, ensure that the training run has\n", - "completed." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rollout_run.wait_for_completion()\n", - "\n", - "head_run_rollout = next(r for r in rollout_run.get_children() if r.id.endswith('head'))\n", - "\n", - "print('Rollout completed.')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, you need to download the video files from the training run. We use a\n", - "Dataset to filter out the video files which are in tgz archives." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "rollout_run_artifacts_path = os.path.join('azureml', head_run_rollout.id)\n", - "datastore = ws.get_default_datastore()\n", - "\n", - "rollout_run_artifacts_ds = Dataset.File.from_files(datastore.path(os.path.join(rollout_run_artifacts_path, '**')))\n", - "\n", - "video_archives = [file for file in rollout_run_artifacts_ds.to_path() if file.endswith('.tgz')]\n", - "video_archives = [os.path.join(rollout_run_artifacts_path, os.path.normpath(file[1:])) for file in video_archives]\n", - "\n", - "datastore.download(\n", - " target_path=path_prefix,\n", - " prefix=os.path.dirname(video_archives[0]).replace('\\\\', '/'),\n", - " show_progress=True)\n", - "\n", - "print('Download complete.')" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Next, unzip the video files and rename them by the Minecraft mission seed used\n", - "(see `minecraft_rollout.py` for more details on how the seed is used)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "import tarfile\n", - "import shutil\n", - "\n", - "training_artifacts_dir = './training_artifacts'\n", - "video_dir = os.path.join(training_artifacts_dir, 'videos')\n", - "video_files = []\n", - "\n", - "for tar_file_path in video_archives:\n", - " seed = tar_file_path[tar_file_path.index('rollout_') + len('rollout_'): tar_file_path.index('.tgz')]\n", - " \n", - " tar = tarfile.open(os.path.join(path_prefix, tar_file_path).replace('\\\\', '/'), 'r')\n", - " tar_info = next(t_info for t_info in tar.getmembers() if t_info.name.endswith('mp4'))\n", - " tar.extract(tar_info, video_dir)\n", - " tar.close()\n", - " \n", - " unzipped_folder = os.path.join(video_dir, next(f_ for f_ in os.listdir(video_dir) if not f_.endswith('mp4'))) \n", - " video_file = os.path.join(unzipped_folder,'video.mp4')\n", - " final_video_path = os.path.join(video_dir, '{seed}.mp4'.format(seed=seed))\n", - " \n", - " shutil.move(video_file, final_video_path) \n", - " video_files.append(final_video_path)\n", - " \n", - " shutil.rmtree(unzipped_folder)\n", - "\n", - "# Clean up any downloaded 'tmp' files\n", - "shutil.rmtree(path_prefix)\n", - "\n", - "print('Local video files:\\n', video_files)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Finally, run the cell below to display the videos in-line. In some cases,\n", - "the agent may struggle to find the goal since the maze size was increased\n", - "compared to training." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "from IPython.core.display import display, HTML\n", - "\n", - "index = 0\n", - "while index < len(video_files) - 1:\n", - " display(\n", - " HTML('\\\n", - " \\\n", - " '.format(f1=video_files[index], f2=video_files[index + 1]))\n", - " )\n", - " \n", - " index += 2\n", - "\n", - "if index < len(video_files):\n", - " display(\n", - " HTML('\\\n", - " '.format(f1=video_files[index]))\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Cleaning up\n", - "\n", - "Below, you can find code snippets for your convenience to clean up any resources created as part of this tutorial you don't wish to retain." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# to stop the Tensorboard, uncomment and run\n", - "#tb.stop()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# to delete the gpu compute target, uncomment and run\n", - "#gpu_cluster.delete()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# to delete the cpu compute target, uncomment and run\n", - "#cpu_cluster.delete()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# to delete the registered model, uncomment and run\n", - "#model.delete()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# to delete the local video files, uncomment and run\n", - "#shutil.rmtree(training_artifacts_dir)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Next steps\n", - "\n", - "This is currently the last introductory tutorial for Azure Machine Learning\n", - "service's Reinforcement\n", - "Learning offering. 
We would love to hear your feedback to build the features\n", - "you need!\n", - "\n" - ] - } - ], - "metadata": { - "authors": [ - { - "name": "andress" - } - ], - "kernelspec": { - "display_name": "Python 3.6", - "language": "python", - "name": "python36" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - }, - "notice": "Copyright (c) Microsoft Corporation. All rights reserved.\u00e2\u20ac\u00afLicensed under the MIT License.\u00e2\u20ac\u00af " - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/how-to-use-azureml/reinforcement-learning/multiagent-particle-envs/particle.yml b/how-to-use-azureml/reinforcement-learning/multiagent-particle-envs/particle.yml new file mode 100644 index 00000000..b1c52d07 --- /dev/null +++ b/how-to-use-azureml/reinforcement-learning/multiagent-particle-envs/particle.yml @@ -0,0 +1,9 @@ +name: particle +dependencies: +- pip: + - azureml-sdk + - azureml-contrib-reinforcementlearning + - azureml-widgets + - tensorboard + - azureml-tensorboard + - ipython diff --git a/how-to-use-azureml/responsible-ai/visualize-upload-loan-decision/rai-loan-decision.ipynb b/how-to-use-azureml/responsible-ai/visualize-upload-loan-decision/rai-loan-decision.ipynb index a154615f..b008df60 100644 --- a/how-to-use-azureml/responsible-ai/visualize-upload-loan-decision/rai-loan-decision.ipynb +++ b/how-to-use-azureml/responsible-ai/visualize-upload-loan-decision/rai-loan-decision.ipynb @@ -33,7 +33,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Install required packages" + "## Install required packages\n", + "\n", + "This notebook works with Fairlearn v0.4.6, and not later versions. 
If needed, please uncomment and run the following cell:" ] }, { @@ -42,9 +44,7 @@ "metadata": {}, "outputs": [], "source": [ - "%pip install --upgrade fairlearn\n", - "%pip install --upgrade interpret-community\n", - "%pip install --upgrade raiwidgets" + "# %pip install --upgrade fairlearn==0.4.6" ] }, { @@ -71,8 +71,6 @@ "source": [ "from fairlearn.reductions import GridSearch\n", "from fairlearn.reductions import DemographicParity, ErrorRate\n", - "from fairlearn.datasets import fetch_adult\n", - "from fairlearn.metrics import MetricFrame, selection_rate\n", "\n", "from sklearn import svm, neighbors, tree\n", "from sklearn.compose import ColumnTransformer, make_column_selector\n", @@ -83,6 +81,7 @@ "from sklearn.preprocessing import StandardScaler, OneHotEncoder\n", "from sklearn.svm import SVC\n", "from sklearn.metrics import accuracy_score\n", + "from sklearn.datasets import fetch_openml\n", "\n", "import pandas as pd\n", "import numpy as np\n", @@ -106,7 +105,7 @@ "metadata": {}, "outputs": [], "source": [ - "dataset = fetch_adult(as_frame=True)\n", + "dataset = fetch_openml(data_id=1590, as_frame=True)\n", "X_raw, y = dataset['data'], dataset['target']\n", "X_raw[\"race\"].value_counts().to_dict()" ] @@ -342,13 +341,13 @@ "metadata": {}, "outputs": [], "source": [ - "from raiwidgets import FairnessDashboard\n", + "from fairlearn.widget import FairlearnDashboard\n", "\n", "y_pred = model.predict(X_test)\n", "\n", - "FairnessDashboard(sensitive_features=sensitive_features_test,\n", - " y_true=y_test,\n", - " y_pred=y_pred)" + "FairlearnDashboard(sensitive_features=sensitive_features_test,\n", + " y_true=y_test,\n", + " y_pred=y_pred)" ] }, { @@ -404,7 +403,7 @@ "sweep.fit(X_train_prep, y_train,\n", " sensitive_features=sensitive_features_train.sex)\n", "\n", - "predictors = sweep.predictors_" + "predictors = sweep._predictors" ] }, { @@ -420,13 +419,18 @@ "metadata": {}, "outputs": [], "source": [ + "from fairlearn.metrics import 
demographic_parity_difference\n", + "\n", "accuracies, disparities = [], []\n", "\n", "for predictor in predictors:\n", - " accuracy_metric_frame = MetricFrame(accuracy_score, y_train, predictor.predict(X_train_prep), sensitive_features=sensitive_features_train.sex)\n", - " selection_rate_metric_frame = MetricFrame(selection_rate, y_train, predictor.predict(X_train_prep), sensitive_features=sensitive_features_train.sex)\n", - " accuracies.append(accuracy_metric_frame.overall)\n", - " disparities.append(selection_rate_metric_frame.difference())\n", + " y_pred = predictor.predict(X_train_prep)\n", + " # accuracy_metric_frame = MetricFrame(accuracy_score, y_train, predictor.predict(X_train_prep), sensitive_features=sensitive_features_train.sex)\n", + " # selection_rate_metric_frame = MetricFrame(selection_rate, y_train, predictor.predict(X_train_prep), sensitive_features=sensitive_features_train.sex)\n", + " accuracies.append(accuracy_score(y_train, y_pred))\n", + " disparities.append(demographic_parity_difference(y_train,\n", + " y_pred,\n", + " sensitive_features=sensitive_features_train.sex))\n", " \n", "all_results = pd.DataFrame({\"predictor\": predictors, \"accuracy\": accuracies, \"disparity\": disparities})\n", "\n", @@ -456,8 +460,6 @@ "metadata": {}, "outputs": [], "source": [ - "from raiwidgets import FairnessDashboard\n", - "\n", "dashboard_all = {}\n", "for name, predictor in all_models_dict.items():\n", " value = predictor.predict(X_test_prep)\n", @@ -467,7 +469,7 @@ "for name, predictor in dominant_models_dict.items():\n", " dominant_all[name] = predictor.predict(X_test_prep)\n", "\n", - "FairnessDashboard(sensitive_features=sensitive_features_test, \n", + "FairlearnDashboard(sensitive_features=sensitive_features_test, \n", " y_true=y_test,\n", " y_pred=dominant_all)" ] @@ -551,7 +553,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Next, we register each of the models in the `dashboard_predicted` dictionary into the workspace. 
For this, we have to save each model to a file, and then register that file:" + "Next, we register each of the models in the `dominant_all` dictionary into the workspace. For this, we have to save each model to a file, and then register that file:" ] }, { @@ -576,7 +578,7 @@ " return registered_model.id\n", "\n", "model_name_id_mapping = dict()\n", - "for name, model in dashboard_all.items():\n", + "for name, model in dominant_all.items():\n", " m_id = register_model(name, model)\n", " model_name_id_mapping[name] = m_id" ] @@ -594,9 +596,9 @@ "metadata": {}, "outputs": [], "source": [ - "dashboard_all_ids = dict()\n", - "for name, y_pred in dashboard_all.items():\n", - " dashboard_all_ids[model_name_id_mapping[name]] = y_pred" + "dominant_all_ids = dict()\n", + "for name, y_pred in dominant_all.items():\n", + " dominant_all_ids[model_name_id_mapping[name]] = y_pred" ] }, { @@ -619,7 +621,7 @@ "from fairlearn.metrics._group_metric_set import _create_group_metric_set\n", "\n", "dash_dict_all = _create_group_metric_set(y_true=y_test,\n", - " predictions=dashboard_all_ids,\n", + " predictions=dominant_all_ids,\n", " sensitive_features=sf,\n", " prediction_type='binary_classification')" ] diff --git a/how-to-use-azureml/responsible-ai/visualize-upload-loan-decision/rai-loan-decision.yml b/how-to-use-azureml/responsible-ai/visualize-upload-loan-decision/rai-loan-decision.yml new file mode 100644 index 00000000..3fffcf2f --- /dev/null +++ b/how-to-use-azureml/responsible-ai/visualize-upload-loan-decision/rai-loan-decision.yml @@ -0,0 +1,12 @@ +name: rai-loan-decision +dependencies: +- pip: + - azureml-sdk + - azureml-interpret + - azureml-contrib-fairness + - interpret-community[visualization] + - fairlearn==0.4.6 + - matplotlib + - azureml-dataset-runtime + - ipywidgets + - raiwidgets diff --git a/how-to-use-azureml/track-and-monitor-experiments/logging-api/logging-api.ipynb b/how-to-use-azureml/track-and-monitor-experiments/logging-api/logging-api.ipynb index 
3d153063..5ccabd65 100644 --- a/how-to-use-azureml/track-and-monitor-experiments/logging-api/logging-api.ipynb +++ b/how-to-use-azureml/track-and-monitor-experiments/logging-api/logging-api.ipynb @@ -100,7 +100,7 @@ "\n", "# Check core SDK version number\n", "\n", - "print(\"This notebook was created using SDK version 1.23.0, you are currently running version\", azureml.core.VERSION)" + "print(\"This notebook was created using SDK version 1.24.0, you are currently running version\", azureml.core.VERSION)" ] }, { diff --git a/how-to-use-azureml/track-and-monitor-experiments/logging-api/logging-api.yml b/how-to-use-azureml/track-and-monitor-experiments/logging-api/logging-api.yml new file mode 100644 index 00000000..42144437 --- /dev/null +++ b/how-to-use-azureml/track-and-monitor-experiments/logging-api/logging-api.yml @@ -0,0 +1,8 @@ +name: logging-api +dependencies: +- numpy +- matplotlib +- tqdm +- pip: + - azureml-sdk + - azureml-widgets diff --git a/how-to-use-azureml/track-and-monitor-experiments/manage-runs/manage-runs.yml b/how-to-use-azureml/track-and-monitor-experiments/manage-runs/manage-runs.yml new file mode 100644 index 00000000..34a95ec8 --- /dev/null +++ b/how-to-use-azureml/track-and-monitor-experiments/manage-runs/manage-runs.yml @@ -0,0 +1,4 @@ +name: manage-runs +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/track-and-monitor-experiments/tensorboard/export-run-history-to-tensorboard/export-run-history-to-tensorboard.yml b/how-to-use-azureml/track-and-monitor-experiments/tensorboard/export-run-history-to-tensorboard/export-run-history-to-tensorboard.yml new file mode 100644 index 00000000..17aa9f1d --- /dev/null +++ b/how-to-use-azureml/track-and-monitor-experiments/tensorboard/export-run-history-to-tensorboard/export-run-history-to-tensorboard.yml @@ -0,0 +1,10 @@ +name: export-run-history-to-tensorboard +dependencies: +- pip: + - azureml-sdk + - azureml-tensorboard + - tensorflow + - tqdm + - scipy + - sklearn + - 
setuptools>=41.0.0 diff --git a/how-to-use-azureml/track-and-monitor-experiments/tensorboard/tensorboard/tensorboard.yml b/how-to-use-azureml/track-and-monitor-experiments/tensorboard/tensorboard/tensorboard.yml new file mode 100644 index 00000000..024d3600 --- /dev/null +++ b/how-to-use-azureml/track-and-monitor-experiments/tensorboard/tensorboard/tensorboard.yml @@ -0,0 +1,7 @@ +name: tensorboard +dependencies: +- pip: + - azureml-sdk + - azureml-tensorboard + - tensorflow + - setuptools>=41.0.0 diff --git a/how-to-use-azureml/track-and-monitor-experiments/using-mlflow/train-local/train-local.yml b/how-to-use-azureml/track-and-monitor-experiments/using-mlflow/train-local/train-local.yml new file mode 100644 index 00000000..5095b89f --- /dev/null +++ b/how-to-use-azureml/track-and-monitor-experiments/using-mlflow/train-local/train-local.yml @@ -0,0 +1,7 @@ +name: train-local +dependencies: +- scikit-learn +- matplotlib +- pip: + - azureml-sdk + - azureml-mlflow diff --git a/how-to-use-azureml/track-and-monitor-experiments/using-mlflow/train-remote/train-remote.yml b/how-to-use-azureml/track-and-monitor-experiments/using-mlflow/train-remote/train-remote.yml new file mode 100644 index 00000000..e96f6ab6 --- /dev/null +++ b/how-to-use-azureml/track-and-monitor-experiments/using-mlflow/train-remote/train-remote.yml @@ -0,0 +1,4 @@ +name: train-remote +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/training/train-on-amlcompute/train-on-amlcompute.yml b/how-to-use-azureml/training/train-on-amlcompute/train-on-amlcompute.yml new file mode 100644 index 00000000..57cc15b6 --- /dev/null +++ b/how-to-use-azureml/training/train-on-amlcompute/train-on-amlcompute.yml @@ -0,0 +1,6 @@ +name: train-on-amlcompute +dependencies: +- scikit-learn +- pip: + - azureml-sdk + - azureml-widgets diff --git a/how-to-use-azureml/training/train-on-local/train-on-local.yml b/how-to-use-azureml/training/train-on-local/train-on-local.yml new file mode 100644 index 
00000000..76f64467 --- /dev/null +++ b/how-to-use-azureml/training/train-on-local/train-on-local.yml @@ -0,0 +1,7 @@ +name: train-on-local +dependencies: +- matplotlib +- scikit-learn +- pip: + - azureml-sdk + - azureml-widgets diff --git a/how-to-use-azureml/training/using-environments/using-environments.yml b/how-to-use-azureml/training/using-environments/using-environments.yml new file mode 100644 index 00000000..88422a40 --- /dev/null +++ b/how-to-use-azureml/training/using-environments/using-environments.yml @@ -0,0 +1,4 @@ +name: using-environments +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/work-with-data/datadrift-tutorial/datadrift-tutorial.yml b/how-to-use-azureml/work-with-data/datadrift-tutorial/datadrift-tutorial.yml new file mode 100644 index 00000000..6633d9e5 --- /dev/null +++ b/how-to-use-azureml/work-with-data/datadrift-tutorial/datadrift-tutorial.yml @@ -0,0 +1,5 @@ +name: datadrift-tutorial +dependencies: +- pip: + - azureml-sdk + - azureml-datadrift diff --git a/how-to-use-azureml/work-with-data/datasets-tutorial/pipeline-with-datasets/pipeline-for-image-classification.yml b/how-to-use-azureml/work-with-data/datasets-tutorial/pipeline-with-datasets/pipeline-for-image-classification.yml new file mode 100644 index 00000000..f33e9474 --- /dev/null +++ b/how-to-use-azureml/work-with-data/datasets-tutorial/pipeline-with-datasets/pipeline-for-image-classification.yml @@ -0,0 +1,6 @@ +name: pipeline-for-image-classification +dependencies: +- pip: + - azureml-sdk + - pandas<=0.23.4 + - fuse diff --git a/how-to-use-azureml/work-with-data/datasets-tutorial/scriptrun-with-data-input-output/how-to-use-scriptrun.yml b/how-to-use-azureml/work-with-data/datasets-tutorial/scriptrun-with-data-input-output/how-to-use-scriptrun.yml new file mode 100644 index 00000000..87dc3b4c --- /dev/null +++ b/how-to-use-azureml/work-with-data/datasets-tutorial/scriptrun-with-data-input-output/how-to-use-scriptrun.yml @@ -0,0 +1,4 @@ +name: 
how-to-use-scriptrun +dependencies: +- pip: + - azureml-sdk diff --git a/how-to-use-azureml/work-with-data/datasets-tutorial/timeseries-datasets/tabular-timeseries-dataset-filtering.yml b/how-to-use-azureml/work-with-data/datasets-tutorial/timeseries-datasets/tabular-timeseries-dataset-filtering.yml new file mode 100644 index 00000000..af9acab3 --- /dev/null +++ b/how-to-use-azureml/work-with-data/datasets-tutorial/timeseries-datasets/tabular-timeseries-dataset-filtering.yml @@ -0,0 +1,5 @@ +name: tabular-timeseries-dataset-filtering +dependencies: +- pip: + - azureml-sdk + - pandas<=0.23.4 diff --git a/how-to-use-azureml/work-with-data/datasets-tutorial/train-with-datasets/train-with-datasets.yml b/how-to-use-azureml/work-with-data/datasets-tutorial/train-with-datasets/train-with-datasets.yml new file mode 100644 index 00000000..d13f92dc --- /dev/null +++ b/how-to-use-azureml/work-with-data/datasets-tutorial/train-with-datasets/train-with-datasets.yml @@ -0,0 +1,8 @@ +name: train-with-datasets +dependencies: +- pip: + - azureml-sdk + - azureml-widgets + - pandas<=0.23.4 + - fuse + - scikit-learn diff --git a/index.md b/index.md index dedd23e0..3739461c 100644 --- a/index.md +++ b/index.md @@ -60,8 +60,8 @@ Machine Learning notebook samples and encourage efficient retrieval of topics an | [Train a model with hyperparameter tuning](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/ml-frameworks/chainer/train-hyperparameter-tune-deploy-with-chainer/train-hyperparameter-tune-deploy-with-chainer.ipynb) | Train a Convolutional Neural Network (CNN) | MNIST | AML Compute | Azure Container Instance | Chainer | None | | [Train a model with a custom Docker image](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/ml-frameworks/fastai/fastai-with-custom-docker/fastai-with-custom-docker.ipynb) | Train with custom Docker image | Oxford IIIT Pet | AML Compute | None | Pytorch | None | | [Train a DNN using 
hyperparameter tuning and deploying with Keras](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/ml-frameworks/keras/train-hyperparameter-tune-deploy-with-keras/train-hyperparameter-tune-deploy-with-keras.ipynb) | Create a multi-class classifier | MNIST | AML Compute | Azure Container Instance | TensorFlow | None | +| [Distributed training with PyTorch](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-distributeddataparallel/distributed-pytorch-with-distributeddataparallel.ipynb) | Train a model using distributed training via PyTorch DistributedDataParallel | CIFAR-10 | AML Compute | None | PyTorch | None | | [Distributed PyTorch](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-horovod/distributed-pytorch-with-horovod.ipynb) | Train a model using the distributed training via Horovod | MNIST | AML Compute | None | PyTorch | None | -| [Distributed training with PyTorch](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/ml-frameworks/pytorch/distributed-pytorch-with-nccl-gloo/distributed-pytorch-with-nccl-gloo.ipynb) | Train a model using distributed training via Nccl/Gloo | MNIST | AML Compute | None | PyTorch | None | | [Training with hyperparameter tuning using PyTorch](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/ml-frameworks/pytorch/train-hyperparameter-tune-deploy-with-pytorch/train-hyperparameter-tune-deploy-with-pytorch.ipynb) | Train an image classification model using transfer learning with the PyTorch estimator | ImageNet | AML Compute | Azure Container Instance | PyTorch | None | | [Training and hyperparameter tuning with 
Scikit-learn](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/ml-frameworks/scikit-learn/train-hyperparameter-tune-deploy-with-sklearn/train-hyperparameter-tune-deploy-with-sklearn.ipynb) | Train a support vector machine (SVM) to perform classification | Iris | AML Compute | None | Scikit-learn | None | | [Distributed training using TensorFlow with Horovod](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/ml-frameworks/tensorflow/distributed-tensorflow-with-horovod/distributed-tensorflow-with-horovod.ipynb) | Use the TensorFlow estimator to train a word2vec model | None | AML Compute | None | TensorFlow | None | @@ -127,7 +127,6 @@ Machine Learning notebook samples and encourage efficient retrieval of topics an | [pong_rllib](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/reinforcement-learning/atari-on-distributed-compute/pong_rllib.ipynb) | | | | | | | | [cartpole_ci](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/reinforcement-learning/cartpole-on-compute-instance/cartpole_ci.ipynb) | | | | | | | | [cartpole_sc](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/reinforcement-learning/cartpole-on-single-compute/cartpole_sc.ipynb) | | | | | | | -| [minecraft](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/reinforcement-learning/minecraft-on-distributed-compute/minecraft.ipynb) | | | | | | | | [particle](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/reinforcement-learning/multiagent-particle-envs/particle.ipynb) | | | | | | | | [rai-loan-decision](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/responsible-ai/visualize-upload-loan-decision/rai-loan-decision.ipynb) | | | | | | | | [Logging 
APIs](https://github.com/Azure/MachineLearningNotebooks/blob/master//how-to-use-azureml/track-and-monitor-experiments/logging-api/logging-api.ipynb) | Logging APIs and analyzing results | None | None | None | None | None | diff --git a/setup-environment/configuration.ipynb b/setup-environment/configuration.ipynb index 9dffd464..895a64bb 100644 --- a/setup-environment/configuration.ipynb +++ b/setup-environment/configuration.ipynb @@ -102,7 +102,7 @@ "source": [ "import azureml.core\n", "\n", - "print(\"This notebook was created using version 1.23.0 of the Azure ML SDK\")\n", + "print(\"This notebook was created using version 1.24.0 of the Azure ML SDK\")\n", "print(\"You are currently using version\", azureml.core.VERSION, \"of the Azure ML SDK\")" ] },