MachineLearningNotebooks/how-to-use-azureml/explain-model/explain-tabular-data-run-history/explain-run-history-sklearn-classification.ipynb
Roope Astala 2d41c00488 version 1.0.39
2019-05-14 16:01:14 -04:00


{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Breast cancer diagnosis classification with scikit-learn (save model explanations via AML Run History)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
"\n",
"Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Explain a model with the AML explain-model package\n",
"\n",
"1. Train an SVM classification model using scikit-learn\n",
"2. Run 'explain_model' with AML Run History, which uses the run history service to store and manage the explanation data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_breast_cancer\n",
"from sklearn import svm\n",
"from azureml.explain.model.tabular_explainer import TabularExplainer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Run model explainer locally with full data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the breast cancer diagnosis data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"breast_cancer_data = load_breast_cancer()\n",
"classes = breast_cancer_data.target_names.tolist()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Split data into train and test\n",
"from sklearn.model_selection import train_test_split\n",
"x_train, x_test, y_train, y_test = train_test_split(breast_cancer_data.data, breast_cancer_data.target, test_size=0.2, random_state=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the SVM classification model that you want to explain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"clf = svm.SVC(gamma=0.001, C=100., probability=True)\n",
"model = clf.fit(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explain predictions on your local machine"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tabular_explainer = TabularExplainer(model, x_train, features=breast_cancer_data.feature_names, classes=classes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explain overall model predictions (global explanation)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pass the test dataset as the evaluation examples; it must be a representative sample of the original data\n",
"# x_train can be passed instead, but with more examples the explanation takes longer, although it may be more accurate\n",
"global_explanation = tabular_explainer.explain_global(x_test)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Save model explanation with AML Run History"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import azureml.core\n",
"from azureml.core import Workspace, Experiment, Run\n",
"from azureml.explain.model.tabular_explainer import TabularExplainer\n",
"from azureml.contrib.explain.model.explanation.explanation_client import ExplanationClient\n",
"# Check core SDK version number\n",
"print(\"SDK version:\", azureml.core.VERSION)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ws = Workspace.from_config()\n",
"print('Workspace name: ' + ws.name,\n",
"      'Azure region: ' + ws.location,\n",
"      'Subscription id: ' + ws.subscription_id,\n",
"      'Resource group: ' + ws.resource_group, sep='\\n')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"experiment_name = 'explain_model'\n",
"experiment = Experiment(ws, experiment_name)\n",
"run = experiment.start_logging()\n",
"client = ExplanationClient.from_run(run)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Upload the model explanation data for storage or visualization in the web UX\n",
"# The explanation can then be downloaded on any compute\n",
"client.upload_model_explanation(global_explanation)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get model explanation data\n",
"explanation = client.download_model_explanation()\n",
"local_importance_values = explanation.local_importance_values\n",
"expected_values = explanation.expected_values"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get the top k (e.g., 4) most important features with their importance values\n",
"explanation = client.download_model_explanation(top_k=4)\n",
"global_importance_values = explanation.get_ranked_global_values()\n",
"global_importance_names = explanation.get_ranked_global_names()\n",
"# Index 0 selects the ranking for the first class in `classes`\n",
"per_class_names = explanation.get_ranked_per_class_names()[0]\n",
"per_class_values = explanation.get_ranked_per_class_values()[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print('per class feature importance values: {}'.format(per_class_values))\n",
"print('per class feature importance names: {}'.format(per_class_names))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Map each top-ranked feature name to its importance value for the first class\n",
"dict(zip(per_class_names, per_class_values))"
]
}
],
"metadata": {
"authors": [
{
"name": "mesameki"
}
],
"kernelspec": {
"display_name": "Python 3.6",
"language": "python",
"name": "python36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
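
The last cells of the notebook download the top-k ranked feature names and values and pair them with `dict(zip(...))`. The sketch below illustrates that ranking-and-pairing step in plain Python, without an Azure workspace. It is not the SDK's implementation: `top_k_features` is a hypothetical helper, the ranking rule (descending by value) is an assumption, and the importance values are made up for illustration rather than taken from a real explanation.

```python
from typing import Dict, Sequence


def top_k_features(names: Sequence[str], values: Sequence[float], k: int) -> Dict[str, float]:
    """Hypothetical helper: return the k features with the largest importance values."""
    if len(names) != len(values):
        raise ValueError("names and values must have the same length")
    # Assumption: features are ranked in descending order of importance value.
    ranked = sorted(zip(names, values), key=lambda pair: pair[1], reverse=True)
    # dict() preserves insertion order, so the result stays ranked.
    return dict(ranked[:k])


# Made-up importance values for illustration only; they are not results from the notebook.
feature_names = ["mean radius", "mean texture", "worst area", "worst perimeter", "mean area"]
importance_values = [0.02, 0.01, 0.30, 0.12, 0.07]

top_features = top_k_features(feature_names, importance_values, k=4)
print(top_features)
```

The returned dictionary has the same shape as the notebook's `dict(zip(per_class_names, per_class_values))`: feature names as keys, importance values as values, in ranked order.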