mirror of
https://github.com/Azure/MachineLearningNotebooks.git
synced 2025-12-20 09:37:04 -05:00
32 lines
814 B
Python
32 lines
814 B
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
# Licensed under the MIT license.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torchvision import transforms
|
|
import json
|
|
|
|
from azureml.core.model import Model
|
|
|
|
|
|
def init():
    """Load the registered ``pytorch-hymenoptera`` model into the global ``model``.

    Must be called before ``run``, which reads the module-level ``model`` this
    function assigns. The loaded object is put into evaluation mode.
    """
    import inspect

    global model
    model_path = Model.get_model_path('pytorch-hymenoptera')

    # Keep every tensor on the CPU regardless of the device it was saved from
    # (same effect as the old ``lambda storage, loc: storage`` mapping).
    load_kwargs = {"map_location": "cpu"}
    # The file holds a whole pickled nn.Module (it is .eval()-ed and called
    # directly), not a state_dict. PyTorch >= 2.6 defaults torch.load to
    # weights_only=True, which refuses to unpickle arbitrary classes, so opt out
    # explicitly wherever the keyword exists (older releases reject it).
    # NOTE(security): weights_only=False performs full pickle deserialization;
    # only load model files that come from a trusted model registry.
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    model = torch.load(model_path, **load_kwargs)
    model.eval()
|
|
|
|
|
|
def run(input_data):
    """Score one JSON request and return the predicted class with its probability.

    Args:
        input_data: JSON text of the form ``{"data": [...]}``. The ``data``
            value is converted with ``torch.tensor`` and passed to the global
            ``model`` as a batch (presumably N x C x H x W preprocessed
            images -- confirm against the client). The model is expected to
            return one row of logits per sample, ordered like ``classes``.

    Returns:
        dict: ``{"label": "ants" | "bees", "probability": str}`` describing the
        first sample of the batch; ``probability`` is the softmax score of the
        chosen label rendered with ``str``.

    Raises:
        ValueError: if ``input_data`` is not valid JSON (``json.JSONDecodeError``).
        KeyError: if the payload has no ``"data"`` key.
    """
    batch = torch.tensor(json.loads(input_data)['data'])

    # get prediction; inference only, so skip autograd bookkeeping
    with torch.no_grad():
        output = model(batch)

    classes = ['ants', 'bees']
    softmax = nn.Softmax(dim=1)
    # Report the first sample only, for both the label and its probability.
    # The previous argmax over the whole batch yielded a multi-element tensor
    # that cannot index a list when more than one sample was sent.
    pred_probs = softmax(output).numpy()[0]
    index = int(torch.argmax(output[0]))

    result = {"label": classes[index], "probability": str(pred_probs[index])}
    return result
|