{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Copyright (c) Microsoft Corporation. All rights reserved.\n",
    "\n",
    "Licensed under the MIT License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Explain multiclass classification model's predictions\n",
    "_**This notebook showcases how to use the Azure Machine Learning Interpretability SDK to explain and visualize a multiclass classification model's predictions.**_\n",
    "\n",
    "\n",
    "## Table of Contents\n",
    "\n",
    "1. [Introduction](#Introduction)\n",
    "1. [Setup](#Setup)\n",
    "1. [Run model explainer locally at training time](#Explain)\n",
    "    1. Train a multiclass classification model\n",
    "    1. Explain the model\n",
    "    1. Generate global explanations\n",
    "    1. Generate local explanations\n",
    "1. [Visualize results](#Visualize)\n",
    "1. [Next steps](#Next)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "This notebook illustrates how to explain a multiclass classification model's predictions locally at training time, without contacting any Azure services.\n",
    "It demonstrates the API calls needed to obtain global and local explanations, and shows a visualization dashboard that provides an interactive way of discovering patterns in the data and explanations.\n",
    "\n",
    "We will showcase three tabular data explainers: TabularExplainer (SHAP), MimicExplainer (global surrogate), and PFIExplainer.\n",
    "\n",
    "*Interpretability Toolkit Architecture* (architecture diagram omitted)\n",
    "\n",
    "Problem: Iris flower classification with scikit-learn (run model explainer locally)\n",
    "\n",
    "1. Train an SVM classification model using scikit-learn\n",
    "2. Run the explainer's explain_global and explain_local functions with the full dataset in local mode, which doesn't contact any Azure services.\n",
    "3. Visualize the global and local explanations with the visualization dashboard.\n",
    "---\n",
    "\n",
    "Setup: If you are using Jupyter notebooks, the extensions should be installed automatically with the package.\n",
    "If you are using JupyterLab, run the following command:\n",
    "```\n",
    "(myenv) $ jupyter labextension install @jupyter-widgets/jupyterlab-manager\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Explain\n",
    "\n",
    "### Run model explainer locally at training time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "from sklearn import svm\n",
    "\n",
    "# Explainers:\n",
    "# 1. SHAP Tabular Explainer\n",
    "from interpret.ext.blackbox import TabularExplainer\n",
    "\n",
    "# OR\n",
    "\n",
    "# 2. Mimic Explainer\n",
    "from interpret.ext.blackbox import MimicExplainer\n",
    "# You can use one of the following four interpretable models as a global surrogate to the black box model\n",
    "from interpret.ext.glassbox import LGBMExplainableModel\n",
    "from interpret.ext.glassbox import LinearExplainableModel\n",
    "from interpret.ext.glassbox import SGDExplainableModel\n",
    "from interpret.ext.glassbox import DecisionTreeExplainableModel\n",
    "\n",
    "# OR\n",
    "\n",
    "# 3. PFI Explainer\n",
    "from interpret.ext.blackbox import PFIExplainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load the Iris flower dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "iris = load_iris()\n",
    "X = iris['data']\n",
    "y = iris['target']\n",
    "classes = iris['target_names']\n",
    "feature_names = iris['feature_names']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split data into train and test\n",
    "from sklearn.model_selection import train_test_split\n",
    "x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train an SVM classification model, which you want to explain"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "clf = svm.SVC(gamma=0.001, C=100., probability=True)\n",
    "model = clf.fit(x_train, y_train)"
   ]
  },
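  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional (an added sketch, not part of the original walkthrough): a quick sanity check of the model's accuracy on the held-out test set, since explanations are most meaningful for a model that actually fits the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check (added sketch): verify the model generalizes reasonably well before explaining it\n",
    "from sklearn.metrics import accuracy_score\n",
    "\n",
    "print('test accuracy: {:.3f}'.format(accuracy_score(y_test, clf.predict(x_test))))"
   ]
  },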
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Explain predictions on your local machine"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 1. Using SHAP TabularExplainer\n",
    "explainer = TabularExplainer(model,\n",
    "                             x_train,\n",
    "                             features=feature_names,\n",
    "                             classes=classes)\n",
    "\n",
    "\n",
    "# 2. Using MimicExplainer\n",
    "# augment_data is optional and, if True, oversamples the initialization examples to improve the surrogate model's fit to the original model.\n",
    "# Useful for high-dimensional data where the number of rows is less than the number of columns.\n",
    "# max_num_of_augmentations is optional and defines the maximum number of times we can increase the input data size.\n",
    "# LGBMExplainableModel can be replaced with LinearExplainableModel, SGDExplainableModel, or DecisionTreeExplainableModel\n",
    "# explainer = MimicExplainer(model,\n",
    "#                            x_train,\n",
    "#                            LGBMExplainableModel,\n",
    "#                            augment_data=True,\n",
    "#                            max_num_of_augmentations=10,\n",
    "#                            features=feature_names,\n",
    "#                            classes=classes)\n",
    "\n",
    "\n",
    "# 3. Using PFIExplainer\n",
    "\n",
    "# Use the parameter \"metric\" to pass a metric name or function to evaluate the permutation.\n",
    "# Note that if a metric function is provided, a higher value must be better;\n",
    "# otherwise, take the negative of the function or set the parameter \"is_error_metric\" to True.\n",
    "# Default metrics:\n",
    "# F1 score for binary classification, F1 score with micro average for multiclass classification, and\n",
    "# mean absolute error for regression\n",
    "\n",
    "# explainer = PFIExplainer(model,\n",
    "#                          features=feature_names,\n",
    "#                          classes=classes)"
   ]
  },
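  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a minimal added sketch (not part of the original notebook), the commented lines below show how a custom metric function could be passed to the PFIExplainer via the \"metric\" parameter described above. The use of functools.partial to bind f1_score's average argument is an assumption for the multiclass case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch (hypothetical usage): pass a custom metric where higher values are better.\n",
    "# f1_score needs its 'average' argument bound for multiclass problems.\n",
    "# from functools import partial\n",
    "# from sklearn.metrics import f1_score\n",
    "# explainer = PFIExplainer(model,\n",
    "#                          features=feature_names,\n",
    "#                          classes=classes,\n",
    "#                          metric=partial(f1_score, average='micro'))"
   ]
  },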
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate global explanations\n",
    "Explain overall model predictions (global explanation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Passing in the test dataset for evaluation examples - note it must be a representative sample of the original data\n",
    "# x_train can be passed as well, but with more examples explanations will take longer, although they may be more accurate\n",
    "global_explanation = explainer.explain_global(x_test)\n",
    "\n",
    "# Note: if you used the PFIExplainer in the previous step, use the next line of code instead\n",
    "# global_explanation = explainer.explain_global(x_test, true_labels=y_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sorted SHAP values\n",
    "print('ranked global importance values: {}'.format(global_explanation.get_ranked_global_values()))\n",
    "# Corresponding feature names\n",
    "print('ranked global importance names: {}'.format(global_explanation.get_ranked_global_names()))\n",
    "# Feature ranks (based on original order of features)\n",
    "print('global importance rank: {}'.format(global_explanation.global_importance_rank))\n",
    "\n",
    "# Note: PFIExplainer does not support per-class explanations\n",
    "# Per-class feature names\n",
    "print('ranked per class feature names: {}'.format(global_explanation.get_ranked_per_class_names()))\n",
    "# Per-class feature importance values\n",
    "print('ranked per class feature values: {}'.format(global_explanation.get_ranked_per_class_values()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print out a dictionary that holds the sorted feature importance names and values\n",
    "print('global importance dictionary: {}'.format(global_explanation.get_feature_importance_dict()))"
   ]
  },
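  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an added sketch (not part of the original notebook), the same sorted dictionary can be rendered as a horizontal bar chart with matplotlib, which is often easier to scan than the printed output. This assumes matplotlib is installed in your environment."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: plot the sorted global feature importances as a bar chart\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "importances = global_explanation.get_feature_importance_dict()\n",
    "names = list(importances.keys())\n",
    "values = list(importances.values())\n",
    "\n",
    "plt.figure(figsize=(8, 4))\n",
    "plt.barh(names[::-1], values[::-1])  # most important feature on top\n",
    "plt.xlabel('mean absolute importance')\n",
    "plt.title('Global feature importance')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },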
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Explain overall model predictions as a collection of local (instance-level) explanations"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Per-feature SHAP values for all features and all data points in the evaluation data\n",
    "print('local importance values: {}'.format(global_explanation.local_importance_values))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate local explanations\n",
    "Explain local data points (individual instances)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Note: PFIExplainer does not support local explanations\n",
    "# You can pass a specific data point or a group of data points to the explain_local function\n",
    "\n",
    "# E.g., explain the first data point in the test set\n",
    "instance_num = 0\n",
    "local_explanation = explainer.explain_local(x_test[instance_num,:])"
   ]
  },
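  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As noted in the comment above, explain_local also accepts a group of data points. The added sketch below (not part of the original notebook) explains the first five test rows in one call."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Added sketch: explain a group of data points in a single call\n",
    "group_local_explanation = explainer.explain_local(x_test[0:5])\n",
    "\n",
    "# For multiclass models, local_importance_values is organized as [class][row][feature]\n",
    "print('rows explained: {}'.format(len(group_local_explanation.local_importance_values[0])))"
   ]
  },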
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get the prediction for the first member of the test set and explain why the model made that prediction\n",
    "prediction_value = clf.predict(x_test)[instance_num]\n",
    "\n",
    "sorted_local_importance_values = local_explanation.get_ranked_local_values()[prediction_value]\n",
    "sorted_local_importance_names = local_explanation.get_ranked_local_names()[prediction_value]\n",
    "\n",
    "print('local importance values: {}'.format(sorted_local_importance_values))\n",
    "print('local importance names: {}'.format(sorted_local_importance_names))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualize\n",
    "Load the visualization dashboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from azureml.contrib.interpret.visualize import ExplanationDashboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ExplanationDashboard(global_explanation, model, x_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Next\n",
    "Learn about other use cases of the explain package:\n",
    "\n",
    "1. [Training time: regression problem](./explain-regression-local.ipynb)\n",
    "1. [Training time: binary classification problem](./explain-binary-classification-local.ipynb)\n",
    "1. Explain models with engineered features:\n",
    "    1. [Simple feature transformations](./simple-feature-transformations-explain-local.ipynb)\n",
    "    1. [Advanced feature transformations](./advanced-feature-transformations-explain-local.ipynb)\n",
    "1. [Save model explanations via Azure Machine Learning Run History](../azure-integration/run-history/save-retrieve-explanations-run-history.ipynb)\n",
    "1. [Run explainers remotely on Azure Machine Learning Compute (AMLCompute)](../azure-integration/remote-explanation/explain-model-on-amlcompute.ipynb)\n",
    "1. Inference time: deploy a classification model and explainer:\n",
    "    1. [Deploy a locally-trained model and explainer](../azure-integration/scoring-time/train-explain-model-locally-and-deploy.ipynb)\n",
    "    1. [Deploy a remotely-trained model and explainer](../azure-integration/scoring-time/train-explain-model-on-amlcompute-and-deploy.ipynb)"
   ]
  }
 ],
 "metadata": {
  "authors": [
   {
    "name": "mesameki"
   }
  ],
  "kernelspec": {
   "display_name": "Python 3.6",
   "language": "python",
   "name": "python36"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}