mirror of
https://github.com/Azure/MachineLearningNotebooks.git
Merge pull request #1530 from Azure/release_update/Release-105
update samples from Release-105 as a part of SDK release
@@ -451,9 +451,8 @@
"metadata": {},
"source": [
"### Create a dataset of training artifacts\n",
"To evaluate a trained policy (a checkpoint) we need to make the checkpoint accessible to the rollout script. All the training artifacts are stored in workspace default datastore under **azureml/<run_id>** directory.\n",
"\n",
"Here we create a file dataset from the stored artifacts, and then use this dataset to feed these data to rollout estimator."
"To evaluate a trained policy (a checkpoint) we need to make the checkpoint accessible to the rollout script.\n",
"We can use the Run API to download policy training artifacts (saved model and checkpoints) to local compute."
]
},
{
@@ -462,22 +461,24 @@
"metadata": {},
"outputs": [],
"source": [
"from azureml.core import Dataset\n",
"from os import path\n",
"from distutils import dir_util\n",
"\n",
"run_id = child_run_0.id # Or set to run id of a completed run (e.g. 'rl-cartpole-v0_1587572312_06e04ace_head')\n",
"run_artifacts_path = os.path.join('azureml', run_id)\n",
"print(\"Run artifacts path:\", run_artifacts_path)\n",
"training_artifacts_path = path.join(\"logs\", training_algorithm)\n",
"print(\"Training artifacts path:\", training_artifacts_path)\n",
"\n",
"# Create a file dataset object from the files stored on default datastore\n",
"datastore = ws.get_default_datastore()\n",
"training_artifacts_ds = Dataset.File.from_files(datastore.path(os.path.join(run_artifacts_path, '**')))"
"if path.exists(training_artifacts_path):\n",
"    dir_util.remove_tree(training_artifacts_path)\n",
"\n",
"# Download run artifacts to local compute\n",
"child_run_0.download_files(training_artifacts_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To verify, we can print out the number (and paths) of all the files in the dataset, as follows."
"Now let's find the checkpoints and the last checkpoint number."
]
},
{
@@ -486,7 +487,73 @@
"metadata": {},
"outputs": [],
"source": [
"artifacts_paths = training_artifacts_ds.to_path()\n",
"# A helper function to find checkpoint files in a directory\n",
"def find_checkpoints(file_path):\n",
"    print(\"Looking in path:\", file_path)\n",
"    checkpoints = []\n",
"    for root, _, files in os.walk(file_path):\n",
"        for name in files:\n",
"            if os.path.basename(root).startswith('checkpoint_'):\n",
"                checkpoints.append(path.join(root, name))\n",
"    return checkpoints"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Find checkpoints and last checkpoint number\n",
"checkpoint_files = find_checkpoints(training_artifacts_path)\n",
"\n",
"checkpoint_numbers = []\n",
"for file in checkpoint_files:\n",
"    file = os.path.basename(file)\n",
"    if file.startswith('checkpoint-') and not file.endswith('.tune_metadata'):\n",
"        checkpoint_numbers.append(int(file.split('-')[1]))\n",
"\n",
"print(\"Checkpoints:\", checkpoint_numbers)\n",
"\n",
"last_checkpoint_number = max(checkpoint_numbers)\n",
"print(\"Last checkpoint number:\", last_checkpoint_number)"
]
},
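For reference, the filtering in the cell above relies on the checkpoint layout produced by the training run: checkpoint files live in folders named checkpoint_<N> and are themselves named checkpoint-<N>, alongside a companion checkpoint-<N>.tune_metadata file that is skipped. A minimal sketch of that parse, using illustrative paths only (the algorithm folder name and checkpoint number below are assumptions, not taken from an actual run):

from os import path

# Illustrative layout assumed by find_checkpoints and the parsing cell above.
sample_files = [
    path.join("logs", "PPO", "checkpoint_10", "checkpoint-10"),
    path.join("logs", "PPO", "checkpoint_10", "checkpoint-10.tune_metadata"),
]

for f in sample_files:
    name = path.basename(f)
    if name.startswith("checkpoint-") and not name.endswith(".tune_metadata"):
        print("Checkpoint number:", int(name.split("-")[1]))  # prints 10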
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we upload checkpoints to default datastore and create a file dataset. This dataset will be used to pass in the checkpoints to the rollout script."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Upload the checkpoint files and create a DataSet\n",
"from azureml.core import Dataset\n",
"\n",
"datastore = ws.get_default_datastore()\n",
"checkpoint_dataref = datastore.upload_files(checkpoint_files, target_path='cartpole_checkpoints_' + run_id, overwrite=True)\n",
"checkpoint_ds = Dataset.File.from_files(checkpoint_dataref)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To verify, we can print out the number (and paths) of all the files in the dataset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"artifacts_paths = checkpoint_ds.to_path()\n",
"print(\"Number of files in dataset:\", len(artifacts_paths))\n",
"\n",
"# Uncomment line below to print all file paths\n",
@@ -505,36 +572,6 @@
"\n",
"The checkpoints dataset will be accessible to the rollout script as a mounted folder. The mounted folder and the checkpoint number, passed in via `checkpoint-number`, will be used to create a path to the checkpoint we are going to evaluate. The created checkpoint path then will be passed into RLlib rollout script for evaluation.\n",
"\n",
"Let's find the checkpoints and the last checkpoint number first."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Find checkpoints and last checkpoint number\n",
"checkpoint_files = [\n",
"    os.path.basename(file) for file in training_artifacts_ds.to_path() \\\n",
"        if os.path.basename(file).startswith('checkpoint-') and \\\n",
"            not os.path.basename(file).endswith('tune_metadata')\n",
"]\n",
"\n",
"checkpoint_numbers = []\n",
"for file in checkpoint_files:\n",
"    checkpoint_numbers.append(int(file.split('-')[1]))\n",
"\n",
"print(\"Checkpoints:\", checkpoint_numbers)\n",
"\n",
"last_checkpoint_number = max(checkpoint_numbers)\n",
"print(\"Last checkpoint number:\", last_checkpoint_number)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's configure rollout estimator. Note that we use the last checkpoint for evaluation. The assumption is that the last checkpoint points to our best trained agent. You may change this to any of the checkpoint numbers printed above and observe the effect."
]
},
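For context, here is a hypothetical sketch of how a rollout entry script could combine the mounted checkpoints dataset with the `checkpoint-number` argument to build the checkpoint path handed to the RLlib rollout code. The 'artifacts_path' input name and the argument name mirror the cells in this notebook; the actual rollout script shipped with the sample may resolve the path differently.

import argparse
import os
from azureml.core import Run

# Hypothetical sketch -- not the sample's actual rollout script.
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint-number', type=int, default=1)
args, _ = parser.parse_known_args()

run = Run.get_context()
# Mount point of the dataset passed as as_named_input('artifacts_path').as_mount()
mounted_path = run.input_datasets['artifacts_path']
checkpoint_path = os.path.join(mounted_path, 'checkpoint-{}'.format(args.checkpoint_number))
print('Evaluating checkpoint:', checkpoint_path)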
@@ -576,8 +613,8 @@
"    \n",
"    # Data inputs\n",
"    inputs=[\n",
"        training_artifacts_ds.as_named_input('artifacts_dataset'),\n",
"        training_artifacts_ds.as_named_input('artifacts_path').as_mount()],\n",
"        checkpoint_ds.as_named_input('artifacts_dataset'),\n",
"        checkpoint_ds.as_named_input('artifacts_path').as_mount()],\n",
"    \n",
"    # The Azure Machine Learning compute target\n",
"    compute_target=compute_target,\n",
@@ -474,61 +474,14 @@
"from os import path\n",
"from distutils import dir_util\n",
"\n",
"path_prefix = path.join(\"logs\", training_algorithm)\n",
"print(\"Path prefix:\", path_prefix)\n",
"training_artifacts_path = path.join(\"logs\", training_algorithm)\n",
"print(\"Training artifacts path:\", training_artifacts_path)\n",
"\n",
"if path.exists(path_prefix):\n",
"    dir_util.remove_tree(path_prefix)\n",
"if path.exists(training_artifacts_path):\n",
"    dir_util.remove_tree(training_artifacts_path)\n",
"\n",
"# Uncomment line below to download run artifacts to local compute\n",
"#child_run_0.download_files(path_prefix)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Create a dataset of training artifacts\n",
"To evaluate a trained policy (a checkpoint) we need to make the checkpoint accessible to the rollout script. All the training artifacts are stored in workspace default datastore under **azureml/<run_id>** directory.\n",
"\n",
"Here we create a file dataset from the stored artifacts, and then use this dataset to feed these data to rollout estimator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from azureml.core import Dataset\n",
"\n",
"run_id = child_run_0.id # Or set to run id of a completed run (e.g. 'rl-cartpole-v0_1587572312_06e04ace_head')\n",
"run_artifacts_path = os.path.join('azureml', run_id)\n",
"print(\"Run artifacts path:\", run_artifacts_path)\n",
"\n",
"# Create a file dataset object from the files stored on default datastore\n",
"datastore = ws.get_default_datastore()\n",
"training_artifacts_ds = Dataset.File.from_files(datastore.path(os.path.join(run_artifacts_path, '**')))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To verify, we can print out the number (and paths) of all the files in the dataset, as follows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"artifacts_paths = training_artifacts_ds.to_path()\n",
"print(\"Number of files in dataset:\", len(artifacts_paths))\n",
"\n",
"# Uncomment line below to print all file paths\n",
"#print(\"Artifacts dataset file paths: \", artifacts_paths)"
"# Download run artifacts to local compute\n",
"child_run_0.download_files(training_artifacts_path)"
]
},
{
@@ -550,21 +503,6 @@
"source": [
"import shutil\n",
"\n",
"# A helper function to download movies from a dataset to local directory\n",
"def download_movies(artifacts_ds, movies, destination):\n",
"    # Create the local destination directory \n",
"    if path.exists(destination):\n",
"        dir_util.remove_tree(destination)\n",
"    dir_util.mkpath(destination)\n",
"\n",
"    for i, artifact in enumerate(artifacts_ds.to_path()):\n",
"        if artifact in movies:\n",
"            print('Downloading {} ...'.format(artifact))\n",
"            artifacts_ds.skip(i).take(1).download(target_path=destination, overwrite=True)\n",
"\n",
"    print('Downloading movies completed!')\n",
"\n",
"\n",
"# A helper function to find movies in a directory\n",
"def find_movies(movie_path):\n",
"    print(\"Looking in path:\", movie_path)\n",
@@ -590,34 +528,6 @@
"    )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now let's find the first and the last recorded videos in training artifacts dataset and download them to a local directory."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Find first and last movie\n",
"mp4_files = [file for file in training_artifacts_ds.to_path() if file.endswith('.mp4')]\n",
"mp4_files.sort()\n",
"\n",
"first_movie = mp4_files[0] if len(mp4_files) > 0 else None\n",
"last_movie = mp4_files[-1] if len(mp4_files) > 1 else None\n",
"\n",
"print(\"First movie:\", first_movie)\n",
"print(\"Last movie:\", last_movie)\n",
"\n",
"# Download movies\n",
"training_movies_path = path.join(\"training\", \"videos\")\n",
"download_movies(training_artifacts_ds, [first_movie, last_movie], training_movies_path)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -631,7 +541,7 @@
"metadata": {},
"outputs": [],
"source": [
"mp4_files = find_movies(training_movies_path)\n",
"mp4_files = find_movies(training_artifacts_path)\n",
"mp4_files.sort()"
]
},
@@ -704,16 +614,31 @@
"metadata": {},
"outputs": [],
"source": [
"# Find checkpoints and last checkpoint number\n",
"checkpoint_files = [\n",
"    os.path.basename(file) for file in training_artifacts_ds.to_path() \\\n",
"        if os.path.basename(file).startswith('checkpoint-') and \\\n",
"            not os.path.basename(file).endswith('tune_metadata')\n",
"]\n",
"# A helper function to find checkpoint files in a directory\n",
"def find_checkpoints(file_path):\n",
"    print(\"Looking in path:\", file_path)\n",
"    checkpoints = []\n",
"    for root, _, files in os.walk(file_path):\n",
"        for name in files:\n",
"            if os.path.basename(root).startswith('checkpoint_'):\n",
"                checkpoints.append(path.join(root, name))\n",
"    return checkpoints\n",
"\n",
"checkpoint_files = find_checkpoints(training_artifacts_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Find checkpoints and last checkpoint number\n",
"checkpoint_numbers = []\n",
"for file in checkpoint_files:\n",
"    checkpoint_numbers.append(int(file.split('-')[1]))\n",
"    file = os.path.basename(file)\n",
"    if file.startswith('checkpoint-') and not file.endswith('.tune_metadata'):\n",
"        checkpoint_numbers.append(int(file.split('-')[-1]))\n",
"\n",
"print(\"Checkpoints:\", checkpoint_numbers)\n",
"\n",
@@ -721,6 +646,20 @@
"print(\"Last checkpoint number:\", last_checkpoint_number)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Upload the checkpoint files and create a DataSet\n",
"from azureml.core import Dataset\n",
"\n",
"datastore = ws.get_default_datastore()\n",
"checkpoint_dataref = datastore.upload_files(checkpoint_files, target_path='cartpole_checkpoints_' + run_id, overwrite=True)\n",
"checkpoint_ds = Dataset.File.from_files(checkpoint_dataref)"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -796,8 +735,8 @@
"    \n",
"    # Data inputs\n",
"    inputs=[\n",
"        training_artifacts_ds.as_named_input('artifacts_dataset'),\n",
"        training_artifacts_ds.as_named_input('artifacts_path').as_mount()],\n",
"        checkpoint_ds.as_named_input('artifacts_dataset'),\n",
"        checkpoint_ds.as_named_input('artifacts_path').as_mount()],\n",
"    \n",
"    # The Azure Machine Learning compute target set up for Ray head nodes\n",
"    compute_target=compute_target,\n",
@@ -879,16 +818,15 @@
"print('Number of child runs:', len(child_runs))\n",
"child_run_0 = child_runs[0]\n",
"\n",
"run_id = child_run_0.id # Or set to run id of a completed run (e.g. 'rl-cartpole-v0_1587572312_06e04ace_head')\n",
"run_artifacts_path = os.path.join('azureml', run_id)\n",
"print(\"Run artifacts path:\", run_artifacts_path)\n",
"# Download rollout artifacts\n",
"rollout_artifacts_path = path.join(\"logs\", \"rollout\")\n",
"print(\"Rollout artifacts path:\", rollout_artifacts_path)\n",
"\n",
"# Create a file dataset object from the files stored on default datastore\n",
"datastore = ws.get_default_datastore()\n",
"rollout_artifacts_ds = Dataset.File.from_files(datastore.path(os.path.join(run_artifacts_path, '**')))\n",
"if path.exists(rollout_artifacts_path):\n",
"    dir_util.remove_tree(rollout_artifacts_path)\n",
"\n",
"artifacts_paths = rollout_artifacts_ds.to_path()\n",
"print(\"Number of files in dataset:\", len(artifacts_paths))"
"# Download videos to local compute\n",
"child_run_0.download_files(\"logs/video\", output_directory = rollout_artifacts_path)"
]
},
{
@@ -904,20 +842,11 @@
"metadata": {},
"outputs": [],
"source": [
"# Find last movie\n",
"mp4_files = [file for file in rollout_artifacts_ds.to_path() if file.endswith('.mp4')]\n",
"mp4_files.sort()\n",
"\n",
"last_movie = mp4_files[-1] if len(mp4_files) > 1 else None\n",
"print(\"Last movie:\", last_movie)\n",
"\n",
"# Download last movie\n",
"rollout_movies_path = path.join(\"rollout\", \"videos\")\n",
"download_movies(rollout_artifacts_ds, [last_movie], rollout_movies_path)\n",
"\n",
"# Look for the downloaded movie in local directory\n",
"mp4_files = find_movies(rollout_movies_path)\n",
"mp4_files.sort()"
"mp4_files = find_movies(rollout_artifacts_path)\n",
"mp4_files.sort()\n",
"last_movie = mp4_files[-1] if len(mp4_files) > 1 else None\n",
"print(\"Last movie:\", last_movie)"
]
},
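As an optional check, the downloaded rollout video can be previewed inline. A minimal sketch, assuming IPython's display utilities are available and last_movie from the cell above points at a local .mp4 file (the sample notebook may use its own display helper instead):

from IPython.display import Video, display

# Preview the last rollout video found above (assumes last_movie is a local .mp4 path).
if last_movie is not None:
    display(Video(last_movie, embed=True))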
{
@@ -960,16 +889,12 @@
"#compute_target.delete()\n",
"\n",
"# To delete downloaded training artifacts\n",
"#if os.path.exists(path_prefix):\n",
"#    dir_util.remove_tree(path_prefix)\n",
"\n",
"# To delete downloaded training videos\n",
"#if path.exists(training_movies_path):\n",
"#    dir_util.remove_tree(training_movies_path)\n",
"#if os.path.exists(training_artifacts_path):\n",
"#    dir_util.remove_tree(training_artifacts_path)\n",
"\n",
"# To delete downloaded rollout videos\n",
"#if path.exists(rollout_movies_path):\n",
"#    dir_util.remove_tree(rollout_movies_path)"
"#if path.exists(rollout_artifacts_path):\n",
"#    dir_util.remove_tree(rollout_artifacts_path)"
]
},
{
@@ -986,6 +911,9 @@
"authors": [
{
"name": "hoazari"
},
{
"name": "dasommer"
}
],
"kernelspec": {