Files
MachineLearningNotebooks/training/01.train-hyperparameter-tune-deploy-with-pytorch/pytorch_score.py
2018-11-05 15:27:36 -05:00

58 lines
1.4 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license.
import torch
import torch.nn as nn
from torchvision import transforms
import json
import base64
from io import BytesIO
from PIL import Image
from azureml.core.model import Model
def preprocess_image(image_file):
    """Load an image and convert it into a normalized one-image batch.

    Args:
        image_file: A path or file-like object accepted by ``PIL.Image.open``.

    Returns:
        A float tensor of shape ``(1, 3, 224, 224)``: the image resized,
        center-cropped, scaled by ``ToTensor`` and normalized per channel,
        with a leading batch dimension added.
    """
    data_transforms = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # Normalize is configured with three per-channel statistics, so force a
    # three-channel image; greyscale, palette or RGBA inputs would otherwise
    # yield a tensor whose channel count does not match.
    image = Image.open(image_file).convert('RGB')
    image = data_transforms(image).float()
    # The transform pipeline already returns a tensor, so no extra
    # torch.tensor(...) copy is needed; just add the batch dimension.
    return image.unsqueeze(0)
def base64ToImg(base64ImgString):
    """Decode a base64 string into an in-memory binary stream.

    Args:
        base64ImgString: Base64-encoded payload as a ``str``.

    Returns:
        A ``BytesIO`` positioned at the start of the decoded bytes.
    """
    return BytesIO(base64.b64decode(base64ImgString.encode('utf-8')))
def init():
    """Load the registered model into the module-level ``model`` global.

    Resolves the path of the model registered as ``'pytorch-hymenoptera'``,
    deserializes it with ``torch.load`` and switches it to evaluation mode.
    """
    global model

    def _keep_storage(storage, location):
        # Return the storage unchanged regardless of the location it was
        # saved from (per the torch.load docs this keeps tensors on CPU).
        return storage

    weights_path = Model.get_model_path('pytorch-hymenoptera')
    model = torch.load(weights_path, map_location=_keep_storage)
    model.eval()
def run(input_data):
    """Classify one base64-encoded image carried in a JSON request.

    Args:
        input_data: JSON string whose ``data`` field holds the
            base64-encoded image.

    Returns:
        A dict with ``label`` (``'ants'`` or ``'bees'``) and ``probability``
        (the softmax probability of that label, rendered as a string).
    """
    img = base64ToImg(json.loads(input_data)['data'])
    img = preprocess_image(img)

    classes = ['ants', 'bees']

    # get prediction: run the model once and reuse the output for both the
    # probabilities and the argmax; autograd is not needed for scoring.
    with torch.no_grad():
        output = model(img)
        pred_probs = nn.Softmax(dim=1)(output).detach().numpy()[0]

    # Convert the one-element argmax tensor to a plain int so list and NumPy
    # indexing do not depend on implicit tensor-to-index conversion.
    index = torch.argmax(output, 1).item()
    result = {"label": classes[index], "probability": str(pred_probs[index])}
    return result