Files
MachineLearningNotebooks/databricks/03a.Build_model.ipynb
rastala 90b3bf799f Update notebooks
Update notebooks
2018-09-20 10:02:39 -04:00

2 lines
5.1 KiB
Plaintext

{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "Azure ML & Azure Databricks notebooks by Parashar Shah.\n",
    "\n",
    "Copyright (c) Microsoft Corporation. All rights reserved.\n",
    "\n",
    "Licensed under the MIT License."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "Please ensure you have run all previous notebooks in sequence before running this."
   ],
   "metadata": {}
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Model Building"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "import os\n",
    "import pprint\n",
    "import numpy as np\n",
    "\n",
    "from pyspark.ml import Pipeline, PipelineModel\n",
    "from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler\n",
    "from pyspark.ml.classification import LogisticRegression\n",
    "from pyspark.ml.evaluation import BinaryClassificationEvaluator\n",
    "from pyspark.ml.tuning import CrossValidator, ParamGridBuilder"
   ],
   "metadata": {},
   "outputs": [],
   "execution_count": 4
  },
  {
   "cell_type": "code",
   "source": [
    "# get the train and test datasets\n",
    "train_data_path = \"AdultCensusIncomeTrain\"\n",
    "test_data_path = \"AdultCensusIncomeTest\"\n",
    "\n",
    "train = spark.read.parquet(train_data_path)\n",
    "test = spark.read.parquet(test_data_path)\n",
    "\n",
    "print(\"train: ({}, {})\".format(train.count(), len(train.columns)))\n",
    "print(\"test: ({}, {})\".format(test.count(), len(test.columns)))\n",
    "\n",
    "train.printSchema()"
   ],
   "metadata": {},
   "outputs": [],
   "execution_count": 5
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Define ML Pipeline"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "label = \"income\"\n",
    "\n",
    "reg = 0.1\n",
    "print(\"Regularization Rate is {}.\".format(reg))\n",
    "\n",
    "# create a new Logistic Regression model.\n",
    "lr = LogisticRegression(regParam=reg)\n",
    "\n",
    "# every column except the label is a candidate feature\n",
    "dtypes = dict(train.dtypes)\n",
    "dtypes.pop(label)\n",
    "\n",
    "si_xvars = []\n",
    "ohe_xvars = []\n",
    "featureCols = []\n",
    "for key, dtype in dtypes.items():\n",
    "    if dtype == \"string\":\n",
    "        featureCol = \"-\".join([key, \"encoded\"])\n",
    "        featureCols.append(featureCol)\n",
    "\n",
    "        tmpCol = \"-\".join([key, \"tmp\"])\n",
    "        # string-index and one-hot encode the string column\n",
    "        # https://spark.apache.org/docs/2.3.0/api/java/org/apache/spark/ml/feature/StringIndexer.html\n",
    "        # handleInvalid: Param for how to handle invalid data (unseen labels or NULL values).\n",
    "        # Options are 'skip' (filter out rows with invalid data), 'error' (throw an error),\n",
    "        # or 'keep' (put invalid data in a special additional bucket, at index numLabels). Default: \"error\"\n",
    "        # 'skip' is used here, so rows with unseen or NULL categories are dropped at transform time.\n",
    "        si_xvars.append(StringIndexer(inputCol=key, outputCol=tmpCol, handleInvalid=\"skip\"))\n",
    "        ohe_xvars.append(OneHotEncoder(inputCol=tmpCol, outputCol=featureCol))\n",
    "    else:\n",
    "        # non-string columns are passed to the assembler unchanged\n",
    "        featureCols.append(key)\n",
    "\n",
    "# string-index the label column into a column named \"label\"\n",
    "si_label = StringIndexer(inputCol=label, outputCol='label')\n",
    "\n",
    "# assemble the encoded feature columns in to a column named \"features\"\n",
    "assembler = VectorAssembler(inputCols=featureCols, outputCol=\"features\")\n",
    "\n",
    "# put together the pipeline\n",
    "pipe = Pipeline(stages=[*si_xvars, *ohe_xvars, si_label, assembler, lr])\n",
    "\n",
    "# train the model\n",
    "model = pipe.fit(train)\n",
    "print(model)"
   ],
   "metadata": {},
   "outputs": [],
   "execution_count": 7
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Tune ML Pipeline"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "# candidate regularization rates, as a plain list of Python floats\n",
    "regs = np.arange(0.0, 1.0, 0.2).tolist()\n",
    "\n",
    "paramGrid = ParamGridBuilder().addGrid(lr.regParam, regs).build()\n",
    "# fix the fold-split seed so cross-validation picks the same best model on every run\n",
    "cv = CrossValidator(estimator=pipe, evaluator=BinaryClassificationEvaluator(), estimatorParamMaps=paramGrid, seed=42)"
   ],
   "metadata": {},
   "outputs": [],
   "execution_count": 9
  },
  {
   "cell_type": "code",
   "source": [
    "cvModel = cv.fit(train)"
   ],
   "metadata": {},
   "outputs": [],
   "execution_count": 10
  },
  {
   "cell_type": "code",
   "source": [
    "model = cvModel.bestModel"
   ],
   "metadata": {},
   "outputs": [],
   "execution_count": 11
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Model Evaluation"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "# make prediction\n",
    "pred = model.transform(test)\n",
    "output = pred.select('hours_per_week', 'age', 'workclass', 'marital_status', 'income', 'prediction')\n",
    "display(output.limit(5))"
   ],
   "metadata": {},
   "outputs": [],
   "execution_count": 13
  },
  {
   "cell_type": "code",
   "source": [
    "# evaluate. note only 2 metrics are supported out of the box by Spark ML.\n",
    "bce = BinaryClassificationEvaluator(rawPredictionCol='rawPrediction')\n",
    "au_roc = bce.setMetricName('areaUnderROC').evaluate(pred)\n",
    "au_prc = bce.setMetricName('areaUnderPR').evaluate(pred)\n",
    "\n",
    "print(\"Area under ROC: {}\".format(au_roc))\n",
    "print(\"Area Under PR: {}\".format(au_prc))"
   ],
   "metadata": {},
   "outputs": [],
   "execution_count": 14
  },
  {
   "cell_type": "markdown",
   "source": [
    "# Model Persistence"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "source": [
    "##NOTE: by default the model is saved to and loaded from /dbfs/ instead of cwd!\n",
    "model_name = \"AdultCensus.mml\"\n",
    "model_dbfs = os.path.join(\"/dbfs\", model_name)\n",
    "\n",
    "model.write().overwrite().save(model_name)\n",
    "print(\"saved model to {}\".format(model_dbfs))"
   ],
   "metadata": {},
   "outputs": [],
   "execution_count": 16
  },
  {
   "cell_type": "code",
   "source": [
    "%sh\n",
    "\n",
    "ls -la /dbfs/AdultCensus.mml/*"
   ],
   "metadata": {},
   "outputs": [],
   "execution_count": 17
  },
  {
   "cell_type": "code",
   "source": [
    "dbutils.notebook.exit(\"success\")"
   ],
   "metadata": {},
   "outputs": [],
   "execution_count": 18
  }
 ],
 "metadata": {
  "name": "03a.Build_model",
  "notebookId": 3874566296719409
 },
 "nbformat": 4,
 "nbformat_minor": 0
}