mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-20 01:27:06 -05:00
262 lines
7.5 KiB
Plaintext
262 lines
7.5 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# Breast cancer diagnosis classification with scikit-learn (save model explanations via AML Run History)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
""
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Copyright (c) Microsoft Corporation. All rights reserved.\n",
|
|
"\n",
|
|
"Licensed under the MIT License."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"Explain a model with the AML explain-model package\n",
|
|
"\n",
|
|
"1. Train an SVM classification model using scikit-learn\n",
|
|
"2. Run 'explain_model' with AML Run History, which uses the Run History service to store and manage the explanation data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"from sklearn.datasets import load_breast_cancer\n",
|
|
"from sklearn import svm\n",
|
|
"from azureml.explain.model.tabular_explainer import TabularExplainer"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 1. Run model explainer locally with full data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Load the breast cancer diagnosis data"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"breast_cancer_data = load_breast_cancer()\n",
|
|
"classes = breast_cancer_data.target_names.tolist()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Split data into train and test\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"x_train, x_test, y_train, y_test = train_test_split(breast_cancer_data.data, breast_cancer_data.target, test_size=0.2, random_state=0)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Train an SVM classification model that you want to explain"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"clf = svm.SVC(gamma=0.001, C=100., probability=True)\n",
|
|
"model = clf.fit(x_train, y_train)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Explain predictions on your local machine"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"tabular_explainer = TabularExplainer(model, x_train, features=breast_cancer_data.feature_names, classes=classes)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"## Explain overall model predictions (global explanation)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Pass in the test dataset as the evaluation examples - note it must be a representative sample of the original data\n",
|
|
"# x_train can be passed as well; with more examples, explanations take longer to compute but may be more accurate\n",
|
|
"global_explanation = tabular_explainer.explain_global(x_test)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"metadata": {},
|
|
"source": [
|
|
"# 2. Save Model Explanation With AML Run History"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import azureml.core\n",
|
|
"from azureml.core import Workspace, Experiment, Run\n",
|
|
"from azureml.explain.model.tabular_explainer import TabularExplainer\n",
|
|
"from azureml.contrib.explain.model.explanation.explanation_client import ExplanationClient\n",
|
|
"# Check core SDK version number\n",
|
|
"print(\"SDK version:\", azureml.core.VERSION)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"ws = Workspace.from_config()\n",
|
|
"print('Workspace name: ' + ws.name, \n",
|
|
" 'Azure region: ' + ws.location, \n",
|
|
" 'Subscription id: ' + ws.subscription_id, \n",
|
|
" 'Resource group: ' + ws.resource_group, sep = '\\n')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"experiment_name = 'explain_model'\n",
|
|
"experiment = Experiment(ws, experiment_name)\n",
|
|
"run = experiment.start_logging()\n",
|
|
"client = ExplanationClient.from_run(run)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Upload the model explanation data for storage or visualization in the web UX\n",
|
|
"# The explanation can then be downloaded on any compute\n",
|
|
"client.upload_model_explanation(global_explanation)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Get model explanation data\n",
|
|
"explanation = client.download_model_explanation()\n",
|
|
"local_importance_values = explanation.local_importance_values\n",
|
|
"expected_values = explanation.expected_values"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Get the top k (e.g., 4) most important features with their importance values\n",
|
|
"explanation = client.download_model_explanation(top_k=4)\n",
|
|
"global_importance_values = explanation.get_ranked_global_values()\n",
|
|
"global_importance_names = explanation.get_ranked_global_names()\n",
|
|
"per_class_names = explanation.get_ranked_per_class_names()[0]\n",
|
|
"per_class_values = explanation.get_ranked_per_class_values()[0]"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"print('per class feature importance values: {}'.format(per_class_values))\n",
|
|
"print('per class feature importance names: {}'.format(per_class_names))"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"dict(zip(per_class_names, per_class_values))"
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"authors": [
|
|
{
|
|
"name": "mesameki"
|
|
}
|
|
],
|
|
"kernelspec": {
|
|
"display_name": "Python 3.6",
|
|
"language": "python",
|
|
"name": "python36"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.6.8"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 2
|
|
} |