Files
MachineLearningNotebooks/how-to-use-azureml/explain-model/explain-tabular-data/explain-tabular-data.ipynb
Roope Astala 556a41e223 version 1.0.21
2019-03-25 15:06:08 -04:00

267 lines
7.8 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
"\n",
"Licensed under the MIT License."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Uncomment the line below if the explanation packages are not already installed in your environment\n",
"#!pip install --upgrade azureml-sdk[explain]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Explain a model with the Azure Machine Learning (AML) explain-model package\n",
"\n",
"1. Train an SVM model using scikit-learn\n",
"2. Run 'explain_model' in local mode, which doesn't contact any Azure services\n",
"3. Run 'explain_model' with AML Run History, which uses the Run History service to store and manage the explanation data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Disclaimer: this notebook is a preview of model explainability, and the APIs shown below are subject to breaking changes"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train an SVM model, which we will try to explain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Import Iris dataset\n",
"from sklearn import datasets\n",
"iris = datasets.load_iris()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Split data into train and test\n",
"from sklearn.model_selection import train_test_split\n",
"x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=0)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Import scikit-learn and fit an SVM model\n",
"def create_scikit_learn_model(X, y):\n",
" from sklearn import svm\n",
" clf = svm.SVC(gamma=0.001, C=100., probability=True)\n",
" model = clf.fit(X, y)\n",
" return model\n",
"model = create_scikit_learn_model(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run model explainer locally"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from azureml.explain.model.tabular_explainer import TabularExplainer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"start = time.time()\n",
"\n",
"explainer = TabularExplainer(model, x_train, features=iris.feature_names)\n",
"global_explanation = explainer.explain_global(x_test)\n",
"\n",
"# importance values for each class, test example, and feature (local importance)\n",
"local_imp_values = global_explanation.local_importance_values\n",
"# base prediction with feature importances ignored\n",
"expected_values = global_explanation.expected_values\n",
"# global feature importance information\n",
"global_imp_values = global_explanation.global_importance_values\n",
"ranked_global_imp_names = global_explanation.get_ranked_global_names()\n",
"# global per-class feature importance information\n",
"per_class_imp_values = global_explanation.per_class_values\n",
"ranked_per_class_imp_names = global_explanation.get_ranked_per_class_names()\n",
"\n",
"end = time.time()\n",
"print(end - start)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Run model explainer with AML Run History"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import azureml.core\n",
"from azureml.core import Workspace, Experiment, Run\n",
"from azureml.explain.model.tabular_explainer import TabularExplainer\n",
"from azureml.contrib.explain.model.explanation.explanation_client import ExplanationClient\n",
"# Check core SDK version number\n",
"print(\"SDK version:\", azureml.core.VERSION)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ws = Workspace.from_config()\n",
"print('Workspace name: ' + ws.name, \n",
" 'Azure region: ' + ws.location, \n",
" 'Subscription id: ' + ws.subscription_id, \n",
" 'Resource group: ' + ws.resource_group, sep = '\\n')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"experiment_name = 'explain_model'\n",
"experiment = Experiment(ws, experiment_name)\n",
"run = experiment.start_logging()\n",
"client = ExplanationClient.from_run(run)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import time\n",
"start = time.time()\n",
"explainer = TabularExplainer(model, x_train, features=iris.feature_names, classes=iris.target_names)\n",
"explanation = explainer.explain_global(x_test)\n",
"client.upload_model_explanation(explanation)\n",
"end = time.time()\n",
"print(end - start)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"explanation_from_run = client.download_model_explanation()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# global feature importance information\n",
"global_imp_values = explanation_from_run.global_importance_values\n",
"global_imp_names = explanation_from_run.get_ranked_global_names()\n",
"# global per-class feature importance information\n",
"per_class_imp_values = explanation_from_run.per_class_values\n",
"per_class_imp_names = explanation_from_run.get_ranked_per_class_names()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## This visualization is unsupported, and is not guaranteed to work in the future"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Get the SHAP values and explore them locally\n",
"import shap\n",
"import numpy as np\n",
"shap.initjs()\n",
"display(shap.force_plot(explanation_from_run.expected_values[1], np.asarray(explanation_from_run.local_importance_values[1]), x_test))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"run.complete()"
]
}
],
"metadata": {
"authors": [
{
"name": "wamartin"
}
],
"kernelspec": {
"display_name": "Python 3.6",
"language": "python",
"name": "python36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}