Files
MachineLearningNotebooks/training/01.train-hyperparameter-tune-deploy-with-pytorch/pytorch_score.py
rastala d10b1fa796 Revert "Updated notebook folders"
This reverts commit 06728004b6.
2018-11-20 10:39:48 -05:00

58 lines
1.4 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license.
import torch
import torch.nn as nn
from torchvision import transforms
import json
import base64
from io import BytesIO
from PIL import Image
from azureml.core.model import Model
def preprocess_image(image_file):
    """Load an image and turn it into a normalized, batched input tensor.

    The image is forced to RGB, resized so its shorter side is 256 pixels,
    center-cropped to 224x224, converted to a float tensor and normalized
    per channel with the mean/std values below (presumably the ImageNet
    statistics the model was trained with -- confirm against training code).

    Args:
        image_file: A path or binary file-like object readable by PIL's
            ``Image.open`` (``run`` passes a ``BytesIO``).

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)`` (a batch of one image).
    """
    data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Force exactly three channels: grayscale, palette or RGBA uploads would
    # otherwise yield a tensor whose channel count does not match the three
    # mean/std values given to Normalize. convert() returns a fully loaded
    # copy, so the opened source image can be closed right away.
    with Image.open(image_file) as raw_image:
        image = raw_image.convert('RGB')
    # ToTensor already produces a fresh float tensor; re-wrapping it with
    # torch.tensor() only made a redundant copy (and emitted a warning).
    tensor = data_transforms(image).float()
    # Add the leading batch dimension expected by the model.
    return tensor.unsqueeze(0)
def base64ToImg(base64ImgString):
    """Decode base64 text into an in-memory binary stream.

    Args:
        base64ImgString: ``str`` holding the base64 encoding of the raw bytes
            of an image file.

    Returns:
        A ``BytesIO`` wrapping the decoded bytes, positioned at offset 0 so
        it can be handed straight to an image reader.
    """
    raw_bytes = base64.b64decode(base64ImgString.encode('utf-8'))
    return BytesIO(raw_bytes)
def init():
    """Load the registered model into the module-level ``model`` global.

    Resolves the local path of the model registered as
    ``'pytorch-hymenoptera'``, deserializes it, and switches it to evaluation
    mode so ``run`` can use it. (Presumably invoked once by the Azure ML
    scoring host before any ``run`` call -- confirm against deployment
    config.)
    """
    global model
    weights_path = Model.get_model_path('pytorch-hymenoptera')
    # Identity map_location: keep every storage where it was deserialized,
    # i.e. on the CPU, so a model saved on a GPU machine still loads on a
    # CPU-only host.
    model = torch.load(weights_path, map_location=lambda stor, _loc: stor)
    model.eval()
def run(input_data):
    """Classify one base64-encoded image with the loaded model.

    Args:
        input_data: JSON text containing a ``"data"`` key whose value is the
            base64 encoding of an image file.

    Returns:
        dict with ``"label"`` (one of ``'ants'`` / ``'bees'``, picked by the
        argmax of the model output) and ``"probability"`` (the softmax
        probability of that label, rendered with ``str``).
    """
    img = base64ToImg(json.loads(input_data)['data'])
    img = preprocess_image(img)
    classes = ['ants', 'bees']
    # Inference only: disable autograd so no graph is recorded, and run the
    # network a single time, reusing its logits for both softmax and argmax.
    with torch.no_grad():
        output = model(img)
    softmax = nn.Softmax(dim=1)
    # preprocess_image returns a batch of one, so keep row 0 only.
    pred_probs = softmax(output).numpy()[0]
    # Convert the winning index to a plain int before using it to index the
    # Python list and the NumPy array (argmax returns a tensor).
    index = int(torch.argmax(output, 1)[0])
    result = {"label": classes[index], "probability": str(pred_probs[index])}
    return result