# Files
#
# 59 lines
# 1.7 KiB
# Python

# Modified from https://www.geeksforgeeks.org/multiclass-classification-using-scikit-learn/
import argparse
# importing necessary libraries
import numpy as np
from sklearn import datasets
from sklearn.metrics import confusion_matrix
from sklearn.model_selection import train_test_split
import joblib
from azureml.core.run import Run
# Module-level run handle; main() uses it to log the hyperparameters and the
# test accuracy via run.log(...).
run = Run.get_context()
def main():
    """Train an SVM classifier on the iris dataset and record the results.

    Command-line arguments:
        --kernel   Kernel type passed to ``SVC`` (default ``'linear'``).
        --penalty  Value passed to ``SVC`` as ``C`` (default ``1.0``).

    The kernel, the penalty, and the accuracy on the held-out test split are
    logged through the module-level ``run`` object. The accuracy and the
    confusion matrix are also printed, and the fitted model is written to
    ``model.joblib`` in the current working directory.
    """
    # Kept as a function-scope import, as in the original script.
    from sklearn.svm import SVC

    parser = argparse.ArgumentParser()
    parser.add_argument('--kernel', type=str, default='linear',
                        help='Kernel type to be used in the algorithm')
    parser.add_argument('--penalty', type=float, default=1.0,
                        help='Penalty parameter of the error term')
    args = parser.parse_args()

    # Use the built-in str/float: the np.str / np.float aliases were
    # deprecated in NumPy 1.20 and removed in 1.24, where accessing them
    # raises AttributeError.
    run.log('Kernel type', str(args.kernel))
    run.log('Penalty', float(args.penalty))

    # loading the iris dataset
    iris = datasets.load_iris()

    # X -> features, y -> label
    X = iris.data
    y = iris.target

    # dividing X, y into train and test data (fixed seed, default split size)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    # training an SVM classifier with the requested kernel and penalty
    svm_model = SVC(kernel=args.kernel, C=args.penalty).fit(X_train, y_train)
    svm_predictions = svm_model.predict(X_test)

    # model accuracy for X_test
    accuracy = svm_model.score(X_test, y_test)
    print('Accuracy of SVM classifier on test set: {:.2f}'.format(accuracy))
    run.log('Accuracy', float(accuracy))

    # creating a confusion matrix
    cm = confusion_matrix(y_test, svm_predictions)
    print(cm)

    # save model
    joblib.dump(svm_model, 'model.joblib')
# Run the training only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == '__main__':
    main()