mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-19 17:17:04 -05:00
112 lines
3.3 KiB
Python
112 lines
3.3 KiB
Python
# ---------------------------------------------------------
|
|
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
# ---------------------------------------------------------
|
|
|
|
"""Utilities for azureml-contrib-fairness notebooks."""
|
|
|
|
import arff
|
|
from collections import OrderedDict
|
|
from contextlib import closing
|
|
import gzip
|
|
import pandas as pd
|
|
from sklearn.datasets import fetch_openml
|
|
from sklearn.utils import Bunch
|
|
import time
|
|
|
|
|
|
def fetch_openml_with_retries(data_id, max_retries=4, retry_delay=60):
|
|
"""Fetch a given dataset from OpenML with retries as specified."""
|
|
for i in range(max_retries):
|
|
try:
|
|
print("Download attempt {0} of {1}".format(i + 1, max_retries))
|
|
data = fetch_openml(data_id=data_id, as_frame=True)
|
|
break
|
|
except Exception as e: # noqa: B902
|
|
print("Download attempt failed with exception:")
|
|
print(e)
|
|
if i + 1 != max_retries:
|
|
print("Will retry after {0} seconds".format(retry_delay))
|
|
time.sleep(retry_delay)
|
|
retry_delay = retry_delay * 2
|
|
else:
|
|
raise RuntimeError("Unable to download dataset from OpenML")
|
|
|
|
return data
|
|
|
|
|
|
_categorical_columns = [
|
|
'workclass',
|
|
'education',
|
|
'marital-status',
|
|
'occupation',
|
|
'relationship',
|
|
'race',
|
|
'sex',
|
|
'native-country'
|
|
]
|
|
|
|
|
|
def fetch_census_dataset():
|
|
"""Fetch the Adult Census Dataset.
|
|
|
|
This uses a particular URL for the Adult Census dataset. The code
|
|
is a simplified version of fetch_openml() in sklearn.
|
|
|
|
The data are copied from:
|
|
https://openml.org/data/v1/download/1595261.gz
|
|
(as of 2021-03-31)
|
|
"""
|
|
try:
|
|
from urllib import urlretrieve
|
|
except ImportError:
|
|
from urllib.request import urlretrieve
|
|
|
|
filename = "1595261.gz"
|
|
data_url = "https://rainotebookscdn.blob.core.windows.net/datasets/"
|
|
|
|
remaining_attempts = 5
|
|
sleep_duration = 10
|
|
while remaining_attempts > 0:
|
|
try:
|
|
urlretrieve(data_url + filename, filename)
|
|
|
|
http_stream = gzip.GzipFile(filename=filename, mode='rb')
|
|
|
|
with closing(http_stream):
|
|
def _stream_generator(response):
|
|
for line in response:
|
|
yield line.decode('utf-8')
|
|
|
|
stream = _stream_generator(http_stream)
|
|
data = arff.load(stream)
|
|
except Exception as exc: # noqa: B902
|
|
remaining_attempts -= 1
|
|
print("Error downloading dataset from {} ({} attempt(s) remaining)"
|
|
.format(data_url, remaining_attempts))
|
|
print(exc)
|
|
time.sleep(sleep_duration)
|
|
sleep_duration *= 2
|
|
continue
|
|
else:
|
|
# dataset successfully downloaded
|
|
break
|
|
else:
|
|
raise Exception("Could not retrieve dataset from {}.".format(data_url))
|
|
|
|
attributes = OrderedDict(data['attributes'])
|
|
arff_columns = list(attributes)
|
|
|
|
raw_df = pd.DataFrame(data=data['data'], columns=arff_columns)
|
|
|
|
target_column_name = 'class'
|
|
target = raw_df.pop(target_column_name)
|
|
for col_name in _categorical_columns:
|
|
dtype = pd.api.types.CategoricalDtype(attributes[col_name])
|
|
raw_df[col_name] = raw_df[col_name].astype(dtype, copy=False)
|
|
|
|
result = Bunch()
|
|
result.data = raw_df
|
|
result.target = target
|
|
|
|
return result
|