Version 0.1.80

This commit is contained in:
rastala
2018-11-20 11:00:48 -05:00
parent d10b1fa796
commit 5726fe3ddb
61 changed files with 23040 additions and 24681 deletions

View File

@@ -5,35 +5,10 @@ import torch
import torch.nn as nn
from torchvision import transforms
import json
import base64
from io import BytesIO
from PIL import Image
from azureml.core.model import Model
def preprocess_image(image_file):
"""Preprocess the input image."""
data_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
image = Image.open(image_file)
image = data_transforms(image).float()
image = torch.tensor(image)
image = image.unsqueeze(0)
return image
def base64ToImg(base64ImgString):
base64Img = base64ImgString.encode('utf-8')
decoded_img = base64.b64decode(base64Img)
return BytesIO(decoded_img)
def init():
global model
model_path = Model.get_model_path('pytorch-hymenoptera')
@@ -42,16 +17,15 @@ def init():
def run(input_data):
img = base64ToImg(json.loads(input_data)['data'])
img = preprocess_image(img)
input_data = torch.tensor(json.loads(input_data)['data'])
# get prediction
output = model(img)
classes = ['ants', 'bees']
softmax = nn.Softmax(dim=1)
pred_probs = softmax(model(img)).detach().numpy()[0]
index = torch.argmax(output, 1)
with torch.no_grad():
output = model(input_data)
classes = ['ants', 'bees']
softmax = nn.Softmax(dim=1)
pred_probs = softmax(output).numpy()[0]
index = torch.argmax(output, 1)
result = {"label": classes[index], "probability": str(pred_probs[index])}
return result