Update train.py
This commit is contained in:
@@ -2,7 +2,8 @@
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import os
|
||||
from sklearn.datasets import load_diabetes
|
||||
import argparse
|
||||
|
||||
from sklearn.linear_model import Ridge
|
||||
from sklearn.metrics import mean_squared_error
|
||||
from sklearn.model_selection import train_test_split
|
||||
@@ -12,8 +13,16 @@ from sklearn.externals import joblib
|
||||
import numpy as np
|
||||
|
||||
os.makedirs('./outputs', exist_ok=True)
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--data-folder', type=str,
|
||||
dest='data_folder', help='data folder')
|
||||
args = parser.parse_args()
|
||||
|
||||
X, y = load_diabetes(return_X_y=True)
|
||||
print('Data folder is at:', args.data_folder)
|
||||
print('List all files: ', os.listdir(args.data_folder))
|
||||
|
||||
X = np.load(os.path.join(args.data_folder, 'features.npy'))
|
||||
y = np.load(os.path.join(args.data_folder, 'labels.npy'))
|
||||
|
||||
run = Run.get_submitted_run()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user