Revert "Updated notebook folders"

This reverts commit 06728004b6.
This commit is contained in:
rastala
2018-11-20 10:39:48 -05:00
parent d7127de03c
commit d10b1fa796
114 changed files with 28854 additions and 65033 deletions

View File

@@ -0,0 +1,57 @@
# Copyright (c) Microsoft. All rights reserved.
# Licensed under the MIT license.
import torch
import torch.nn as nn
from torchvision import transforms
import json
import base64
from io import BytesIO
from PIL import Image
from azureml.core.model import Model
def preprocess_image(image_file):
"""Preprocess the input image."""
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = Image.open(image_file)
image = data_transforms(image).float()
image = torch.tensor(image)
image = image.unsqueeze(0)
return image
def base64ToImg(base64ImgString):
base64Img = base64ImgString.encode('utf-8')
decoded_img = base64.b64decode(base64Img)
return BytesIO(decoded_img)
def init():
global model
model_path = Model.get_model_path('pytorch-hymenoptera')
model = torch.load(model_path, map_location=lambda storage, loc: storage)
model.eval()
def run(input_data):
img = base64ToImg(json.loads(input_data)['data'])
img = preprocess_image(img)
# get prediction
output = model(img)
classes = ['ants', 'bees']
softmax = nn.Softmax(dim=1)
pred_probs = softmax(model(img)).detach().numpy()[0]
index = torch.argmax(output, 1)
result = {"label": classes[index], "probability": str(pred_probs[index])}
return result