diff --git a/01.getting-started/04.train-on-remote-vm/train.py b/01.getting-started/04.train-on-remote-vm/train.py index 9f039de5..f0eac9a0 100644 --- a/01.getting-started/04.train-on-remote-vm/train.py +++ b/01.getting-started/04.train-on-remote-vm/train.py @@ -2,7 +2,8 @@ # Licensed under the MIT license. import os -from sklearn.datasets import load_diabetes +import argparse + from sklearn.linear_model import Ridge from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split @@ -12,8 +13,16 @@ from sklearn.externals import joblib import numpy as np os.makedirs('./outputs', exist_ok=True) +parser = argparse.ArgumentParser() +parser.add_argument('--data-folder', type=str, + dest='data_folder', help='data folder') +args = parser.parse_args() -X, y = load_diabetes(return_X_y=True) +print('Data folder is at:', args.data_folder) +print('List all files: ', os.listdir(args.data_folder)) + +X = np.load(os.path.join(args.data_folder, 'features.npy')) +y = np.load(os.path.join(args.data_folder, 'labels.npy')) run = Run.get_submitted_run()