mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-19 17:17:04 -05:00
update samples from Release-55 as a part of SDK release
This commit is contained in:
@@ -1,8 +1,11 @@
|
||||
import numpy as np
|
||||
import argparse
|
||||
from azureml.core import Run
|
||||
|
||||
import numpy as np
|
||||
|
||||
from sklearn.externals import joblib
|
||||
from azureml.automl.core.shared import constants, metrics
|
||||
|
||||
from azureml.automl.runtime.shared.score import scoring, constants
|
||||
from azureml.core import Run
|
||||
from azureml.core.model import Model
|
||||
|
||||
|
||||
@@ -29,22 +32,26 @@ model = joblib.load(model_path)
|
||||
run = Run.get_context()
|
||||
# get input dataset by name
|
||||
test_dataset = run.input_datasets['test_data']
|
||||
train_dataset = run.input_datasets['train_data']
|
||||
|
||||
X_test_df = test_dataset.drop_columns(columns=[target_column_name]) \
|
||||
.to_pandas_dataframe()
|
||||
y_test_df = test_dataset.with_timestamp_columns(None) \
|
||||
.keep_columns(columns=[target_column_name]) \
|
||||
.to_pandas_dataframe()
|
||||
y_train_df = test_dataset.with_timestamp_columns(None) \
|
||||
.keep_columns(columns=[target_column_name]) \
|
||||
.to_pandas_dataframe()
|
||||
|
||||
predicted = model.predict_proba(X_test_df)
|
||||
|
||||
# use automl metrics module
|
||||
scores = metrics.compute_metrics_classification(
|
||||
np.array(predicted),
|
||||
np.array(y_test_df),
|
||||
class_labels=model.classes_,
|
||||
metrics=list(constants.Metric.SCALAR_CLASSIFICATION_SET)
|
||||
)
|
||||
# Use the AutoML scoring module
|
||||
class_labels = np.unique(np.concatenate((y_train_df.values, y_test_df.values)))
|
||||
train_labels = model.classes_
|
||||
classification_metrics = list(constants.CLASSIFICATION_SCALAR_SET)
|
||||
scores = scoring.score_classification(y_test_df.values, predicted,
|
||||
classification_metrics,
|
||||
class_labels, train_labels)
|
||||
|
||||
print("scores:")
|
||||
print(scores)
|
||||
|
||||
Reference in New Issue
Block a user