mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-20 01:27:06 -05:00
259 lines
6.7 KiB
Plaintext
259 lines
6.7 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Breast cancer diagnosis classification with scikit-learn (run model explainer locally)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
|
|
"\n",
|
|
"Licensed under the MIT License."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Explain a model with the AML explain-model package\n",
|
|
"\n",
|
|
"1. Train an SVM classification model using scikit-learn\n",
|
|
"2. Run 'explain_model' with full data in local mode, which doesn't contact any Azure services\n",
|
|
"3. Run 'explain_model' with summarized data in local mode, which doesn't contact any Azure services\n",
|
|
"4. Visualize the global and local explanations with the visualization dashboard."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from sklearn.datasets import load_breast_cancer\n",
|
|
"from sklearn import svm\n",
|
|
"from azureml.explain.model.tabular_explainer import TabularExplainer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 1. Run model explainer locally with full data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Load the breast cancer diagnosis data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"breast_cancer_data = load_breast_cancer()\n",
|
|
"classes = breast_cancer_data.target_names.tolist()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Split data into train and test\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"x_train, x_test, y_train, y_test = train_test_split(breast_cancer_data.data, breast_cancer_data.target, test_size=0.2, random_state=0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Train an SVM classification model that you want to explain"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"clf = svm.SVC(gamma=0.001, C=100., probability=True)\n",
|
|
"model = clf.fit(x_train, y_train)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Explain predictions on your local machine"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"tabular_explainer = TabularExplainer(model, x_train, features=breast_cancer_data.feature_names, classes=classes)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Explain overall model predictions (global explanation)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Passing in test dataset for evaluation examples - note it must be a representative sample of the original data\n",
|
|
"# x_train can be passed as well; with more examples the explanation takes longer to compute but may be more accurate\n",
|
|
"global_explanation = tabular_explainer.explain_global(x_test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Sorted SHAP values\n",
|
|
"print('ranked global importance values: {}'.format(global_explanation.get_ranked_global_values()))\n",
|
|
"# Corresponding feature names\n",
|
|
"print('ranked global importance names: {}'.format(global_explanation.get_ranked_global_names()))\n",
|
|
"# feature ranks (based on original order of features)\n",
|
|
"print('global importance rank: {}'.format(global_explanation.global_importance_rank))\n",
|
|
"# per class feature names\n",
|
|
"print('ranked per class feature names: {}'.format(global_explanation.get_ranked_per_class_names()))\n",
|
|
"# per class feature importance values\n",
|
|
"print('ranked per class feature values: {}'.format(global_explanation.get_ranked_per_class_values()))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dict(zip(global_explanation.get_ranked_global_names(), global_explanation.get_ranked_global_values()))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Explain overall model predictions as a collection of local (instance-level) explanations"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# feature SHAP values for all features and all data points in the evaluation data (x_test) passed to explain_global\n",
|
|
"print('local importance values: {}'.format(global_explanation.local_importance_values))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Explain local data points (individual instances)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# explain the first member of the test set\n",
|
|
"instance_num = 0\n",
|
|
"local_explanation = tabular_explainer.explain_local(x_test[instance_num,:])"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# get the prediction for the first member of the test set and explain why the model made that prediction\n",
|
|
"prediction_value = clf.predict(x_test)[instance_num]\n",
|
|
"\n",
|
|
"sorted_local_importance_values = local_explanation.get_ranked_local_values()[prediction_value]\n",
|
|
"sorted_local_importance_names = local_explanation.get_ranked_local_names()[prediction_value]\n",
|
|
"\n",
|
|
"\n",
|
|
"dict(zip(sorted_local_importance_names, sorted_local_importance_values))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 2. Load visualization dashboard"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from azureml.contrib.explain.model.visualize import ExplanationDashboard"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"ExplanationDashboard(global_explanation, model, x_test)"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"authors": [
|
|
{
|
|
"name": "mesameki"
|
|
}
|
|
],
|
|
"kernelspec": {
|
|
"display_name": "Python 3.6",
|
|
"language": "python",
|
|
"name": "python36"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
}
|