mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-20 01:27:06 -05:00
267 lines
7.8 KiB
Plaintext
267 lines
7.8 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
|
|
"\n",
|
|
"Licensed under the MIT License."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Uncomment these if explanation packages are not already installed in your environment\n",
|
|
"#!pip install --upgrade azureml-sdk[explain]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Explain a model with the AML explain-model package\n",
|
|
"\n",
|
|
"1. Train a SVM model using Scikit-learn\n",
|
|
"2. Run 'explain_model' in local mode, which doesn't contact any Azure services\n",
|
|
"3. Run 'explain_model' with AML Run History, which leverages Run History Service to store and manage the explanation data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Disclaimer: this notebook is a preview of model explainability, and the APIs shown below are subject to breaking changes"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Train a SVM model, which we will try to explain"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Import Iris dataset\n",
|
|
"from sklearn import datasets\n",
|
|
"iris = datasets.load_iris()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Split data into train and test\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"x_train, x_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Import scikit learn, fit a SVM model\n",
|
|
"def create_scikit_learn_model(X, y):\n",
|
|
" from sklearn import svm\n",
|
|
" clf = svm.SVC(gamma=0.001, C=100., probability=True)\n",
|
|
" model = clf.fit(X, y)\n",
|
|
" return model\n",
|
|
"model = create_scikit_learn_model(x_train, y_train)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Run model explainer locally"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from azureml.explain.model.tabular_explainer import TabularExplainer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import time\n",
|
|
"start = time.time()\n",
|
|
"\n",
|
|
"explainer = TabularExplainer(model, x_train, features=iris.feature_names)\n",
|
|
"global_explanation = explainer.explain_global(x_test)\n",
|
|
"\n",
|
|
"# importance values for each class, test example, and feature (local importance)\n",
|
|
"local_imp_values = global_explanation.local_importance_values\n",
|
|
"# base prediction with feature importances ignored\n",
|
|
"expected_values = global_explanation.expected_values\n",
|
|
"# global feature importance information\n",
|
|
"global_imp_values = global_explanation.global_importance_values\n",
|
|
"ranked_global_imp_names = global_explanation.get_ranked_global_names()\n",
|
|
"# global per-class feature importance information\n",
|
|
"per_class_imp_values = global_explanation.per_class_values\n",
|
|
"ranked_per_class_imp_names = global_explanation.get_ranked_per_class_names()\n",
|
|
"\n",
|
|
"end = time.time()\n",
|
|
"print(end - start)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Run model explainer with AML Run History"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import azureml.core\n",
|
|
"from azureml.core import Workspace, Experiment, Run\n",
|
|
"from azureml.explain.model.tabular_explainer import TabularExplainer\n",
|
|
"from azureml.contrib.explain.model.explanation.explanation_client import ExplanationClient\n",
|
|
"# Check core SDK version number\n",
|
|
"print(\"SDK version:\", azureml.core.VERSION)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"ws = Workspace.from_config()\n",
|
|
"print('Workspace name: ' + ws.name, \n",
|
|
" 'Azure region: ' + ws.location, \n",
|
|
" 'Subscription id: ' + ws.subscription_id, \n",
|
|
" 'Resource group: ' + ws.resource_group, sep = '\\n')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"experiment_name = 'explain_model'\n",
|
|
"experiment = Experiment(ws, experiment_name)\n",
|
|
"run = experiment.start_logging()\n",
|
|
"client = ExplanationClient.from_run(run)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import time\n",
|
|
"start = time.time()\n",
|
|
"explainer = TabularExplainer(model, x_train, features=iris.feature_names, classes=iris.target_names)\n",
|
|
"explanation = explainer.explain_global(x_test)\n",
|
|
"client.upload_model_explanation(explanation)\n",
|
|
"end = time.time()\n",
|
|
"print(end - start)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"explanation_from_run = client.download_model_explanation()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# global feature importance information\n",
|
|
"global_imp_values = explanation_from_run.global_importance_values\n",
|
|
"global_imp_names = explanation_from_run.get_ranked_global_names()\n",
|
|
"# global per-class feature importance information\n",
|
|
"per_class_imp_values = explanation_from_run.per_class_values\n",
|
|
"per_class_imp_names = explanation_from_run.get_ranked_per_class_names()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## This visualization is unsupported, and is not guaranteed to work in the future"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Get the shap values and explore locally\n",
|
|
"import shap\n",
|
|
"import numpy as np\n",
|
|
"shap.initjs()\n",
|
|
"display(shap.force_plot(explanation_from_run.expected_values[1], np.asarray(explanation_from_run.local_importance_values[1]), x_test))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"run.complete()"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"authors": [
|
|
{
|
|
"name": "wamartin"
|
|
}
|
|
],
|
|
"kernelspec": {
|
|
"display_name": "Python 3.6",
|
|
"language": "python",
|
|
"name": "python36"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
} |