mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-21 01:55:07 -05:00
Updated notebook folders
This commit is contained in:
@@ -1,57 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torchvision import transforms
|
||||
import json
|
||||
import base64
|
||||
from io import BytesIO
|
||||
from PIL import Image
|
||||
|
||||
from azureml.core.model import Model
|
||||
|
||||
|
||||
def preprocess_image(image_file):
|
||||
"""Preprocess the input image."""
|
||||
data_transforms = transforms.Compose([
|
||||
transforms.Resize(256),
|
||||
transforms.CenterCrop(224),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
|
||||
])
|
||||
|
||||
image = Image.open(image_file)
|
||||
image = data_transforms(image).float()
|
||||
image = torch.tensor(image)
|
||||
image = image.unsqueeze(0)
|
||||
return image
|
||||
|
||||
|
||||
def base64ToImg(base64ImgString):
|
||||
base64Img = base64ImgString.encode('utf-8')
|
||||
decoded_img = base64.b64decode(base64Img)
|
||||
return BytesIO(decoded_img)
|
||||
|
||||
|
||||
def init():
|
||||
global model
|
||||
model_path = Model.get_model_path('pytorch-hymenoptera')
|
||||
model = torch.load(model_path, map_location=lambda storage, loc: storage)
|
||||
model.eval()
|
||||
|
||||
|
||||
def run(input_data):
|
||||
img = base64ToImg(json.loads(input_data)['data'])
|
||||
img = preprocess_image(img)
|
||||
|
||||
# get prediction
|
||||
output = model(img)
|
||||
|
||||
classes = ['ants', 'bees']
|
||||
softmax = nn.Softmax(dim=1)
|
||||
pred_probs = softmax(model(img)).detach().numpy()[0]
|
||||
index = torch.argmax(output, 1)
|
||||
|
||||
result = {"label": classes[index], "probability": str(pred_probs[index])}
|
||||
return result
|
||||
Reference in New Issue
Block a user