Updated notebook folders

This commit is contained in:
Roope Astala
2018-11-19 13:28:49 -05:00
parent f5bcc55fe3
commit 06728004b6
114 changed files with 65052 additions and 28873 deletions

View File

@@ -1,57 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license.
import torch
import torch.nn as nn
from torchvision import transforms
import json
import base64
from io import BytesIO
from PIL import Image
from azureml.core.model import Model
def preprocess_image(image_file):
"""Preprocess the input image."""
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = Image.open(image_file)
image = data_transforms(image).float()
image = torch.tensor(image)
image = image.unsqueeze(0)
return image
def base64ToImg(base64ImgString):
base64Img = base64ImgString.encode('utf-8')
decoded_img = base64.b64decode(base64Img)
return BytesIO(decoded_img)
def init():
global model
model_path = Model.get_model_path('pytorch-hymenoptera')
model = torch.load(model_path, map_location=lambda storage, loc: storage)
model.eval()
def run(input_data):
img = base64ToImg(json.loads(input_data)['data'])
img = preprocess_image(img)
# get prediction
output = model(img)
classes = ['ants', 'bees']
softmax = nn.Softmax(dim=1)
pred_probs = softmax(model(img)).detach().numpy()[0]
index = torch.argmax(output, 1)
result = {"label": classes[index], "probability": str(pred_probs[index])}
return result