mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-20 01:27:06 -05:00
Compare commits
2 Commits
lostmygith
...
release_up
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7871e37ec0 | ||
|
|
58e584e7eb |
@@ -10,6 +10,8 @@ from sklearn.metrics import mean_absolute_error, mean_squared_error
|
|||||||
from azureml.automl.runtime.shared.score import scoring, constants
|
from azureml.automl.runtime.shared.score import scoring, constants
|
||||||
from azureml.core import Run
|
from azureml.core import Run
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def align_outputs(y_predicted, X_trans, X_test, y_test,
|
def align_outputs(y_predicted, X_trans, X_test, y_test,
|
||||||
predicted_column_name='predicted',
|
predicted_column_name='predicted',
|
||||||
@@ -221,6 +223,10 @@ def MAPE(actual, pred):
|
|||||||
return np.mean(APE(actual_safe, pred_safe))
|
return np.mean(APE(actual_safe, pred_safe))
|
||||||
|
|
||||||
|
|
||||||
|
def map_location_cuda(storage, loc):
|
||||||
|
return storage.cuda()
|
||||||
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--max_horizon', type=int, dest='max_horizon',
|
'--max_horizon', type=int, dest='max_horizon',
|
||||||
@@ -274,8 +280,13 @@ X_lookback_df = lookback_dataset.drop_columns(columns=[target_column_name])
|
|||||||
y_lookback_df = lookback_dataset.with_timestamp_columns(
|
y_lookback_df = lookback_dataset.with_timestamp_columns(
|
||||||
None).keep_columns(columns=[target_column_name])
|
None).keep_columns(columns=[target_column_name])
|
||||||
|
|
||||||
fitted_model = joblib.load(model_path)
|
# Load the trained model with torch.
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
map_location = map_location_cuda
|
||||||
|
else:
|
||||||
|
map_location = 'cpu'
|
||||||
|
with open(model_path, 'rb') as fh:
|
||||||
|
fitted_model = torch.load(fh, map_location=map_location)
|
||||||
|
|
||||||
if hasattr(fitted_model, 'get_lookback'):
|
if hasattr(fitted_model, 'get_lookback'):
|
||||||
lookback = fitted_model.get_lookback()
|
lookback = fitted_model.get_lookback()
|
||||||
|
|||||||
Reference in New Issue
Block a user