MachineLearningNotebooks/how-to-use-azureml/explain-model/explain-tabular-data-local/explain-local-sklearn-binary-classification.ipynb
2019-05-04 19:35:57 -07:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Breast cancer diagnosis classification with scikit-learn (run model explainer locally)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
"\n",
"Licensed under the MIT License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Explain a model with the AML explain-model package\n",
"\n",
"1. Train a SVM classification model using Scikit-learn\n",
"2. Run 'explain_model' with full data in local mode, which doesn't contact any Azure services\n",
"3. Run 'explain_model' with summarized data in local mode, which doesn't contact any Azure services\n",
"4. Visualize the global and local explanations with the visualization dashboard."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.datasets import load_breast_cancer\n",
"from sklearn import svm\n",
"from azureml.explain.model.tabular_explainer import TabularExplainer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 1. Run model explainer locally with full data"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the breast cancer diagnosis data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"breast_cancer_data = load_breast_cancer()\n",
"classes = breast_cancer_data.target_names.tolist()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Split data into train and test\n",
"from sklearn.model_selection import train_test_split\n",
"x_train, x_test, y_train, y_test = train_test_split(breast_cancer_data.data, breast_cancer_data.target, test_size=0.2, random_state=0)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train a SVM classification model, which you want to explain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"clf = svm.SVC(gamma=0.001, C=100., probability=True)\n",
"model = clf.fit(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explain predictions on your local machine"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tabular_explainer = TabularExplainer(model, x_train, features=breast_cancer_data.feature_names, classes=classes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explain overall model predictions (global explanation)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Passing in test dataset for evaluation examples - note it must be a representative sample of the original data\n",
"# x_train can be passed as well, but with more examples explanations will take longer although they may be more accurate\n",
"global_explanation = tabular_explainer.explain_global(x_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sorted SHAP values\n",
"print('ranked global importance values: {}'.format(global_explanation.get_ranked_global_values()))\n",
"# Corresponding feature names\n",
"print('ranked global importance names: {}'.format(global_explanation.get_ranked_global_names()))\n",
"# feature ranks (based on original order of features)\n",
"print('global importance rank: {}'.format(global_explanation.global_importance_rank))\n",
"# per class feature names\n",
"print('ranked per class feature names: {}'.format(global_explanation.get_ranked_per_class_names()))\n",
"# per class feature importance values\n",
"print('ranked per class feature values: {}'.format(global_explanation.get_ranked_per_class_values()))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dict(zip(global_explanation.get_ranked_global_names(), global_explanation.get_ranked_global_values()))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explain overall model predictions as a collection of local (instance-level) explanations"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# feature shap values for all features and all data points in the training data\n",
"print('local importance values: {}'.format(global_explanation.local_importance_values))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Explain local data points (individual instances)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# explain the first member of the test set\n",
"instance_num = 0\n",
"local_explanation = tabular_explainer.explain_local(x_test[instance_num,:])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get the prediction for the first member of the test set and explain why model made that prediction\n",
"prediction_value = clf.predict(x_test)[instance_num]\n",
"\n",
"sorted_local_importance_values = local_explanation.get_ranked_local_values()[prediction_value]\n",
"sorted_local_importance_names = local_explanation.get_ranked_local_names()[prediction_value]\n",
"\n",
"\n",
"dict(zip(sorted_local_importance_names, sorted_local_importance_values))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2. Load visualization dashboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from azureml.contrib.explain.model.visualize import ExplanationDashboard"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ExplanationDashboard(global_explanation, model, x_test)"
]
}
],
"metadata": {
"authors": [
{
"name": "mesameki"
}
],
"kernelspec": {
"display_name": "Python 3.6",
"language": "python",
"name": "python36"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}