Files

173 lines
5.9 KiB
Python

import argparse
import os
import sys
import re
import json
import traceback
from PIL import Image
import torch
from torchvision import transforms
from azureml.core.model import Model
style_model = None
class TransformerNet(torch.nn.Module):
def __init__(self):
super(TransformerNet, self).__init__()
# Initial convolution layers
self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
# Residual layers
self.res1 = ResidualBlock(128)
self.res2 = ResidualBlock(128)
self.res3 = ResidualBlock(128)
self.res4 = ResidualBlock(128)
self.res5 = ResidualBlock(128)
# Upsampling Layers
self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
# Non-linearities
self.relu = torch.nn.ReLU()
def forward(self, X):
y = self.relu(self.in1(self.conv1(X)))
y = self.relu(self.in2(self.conv2(y)))
y = self.relu(self.in3(self.conv3(y)))
y = self.res1(y)
y = self.res2(y)
y = self.res3(y)
y = self.res4(y)
y = self.res5(y)
y = self.relu(self.in4(self.deconv1(y)))
y = self.relu(self.in5(self.deconv2(y)))
y = self.deconv3(y)
return y
class ConvLayer(torch.nn.Module):
def __init__(self, in_channels, out_channels, kernel_size, stride):
super(ConvLayer, self).__init__()
reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x):
out = self.reflection_pad(x)
out = self.conv2d(out)
return out
class ResidualBlock(torch.nn.Module):
"""ResidualBlock
introduced in: https://arxiv.org/abs/1512.03385
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
"""
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
self.relu = torch.nn.ReLU()
def forward(self, x):
residual = x
out = self.relu(self.in1(self.conv1(x)))
out = self.in2(self.conv2(out))
out = out + residual
return out
class UpsampleConvLayer(torch.nn.Module):
"""UpsampleConvLayer
Upsamples the input and then does a convolution. This method gives better results
compared to ConvTranspose2d.
ref: http://distill.pub/2016/deconv-checkerboard/
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
super(UpsampleConvLayer, self).__init__()
self.upsample = upsample
if upsample:
self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
reflection_padding = kernel_size // 2
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
def forward(self, x):
x_in = x
if self.upsample:
x_in = self.upsample_layer(x_in)
out = self.reflection_pad(x_in)
out = self.conv2d(out)
return out
def load_image(filename):
img = Image.open(filename)
return img
def save_image(filename, data):
img = data.clone().clamp(0, 255).numpy()
img = img.transpose(1, 2, 0).astype("uint8")
img = Image.fromarray(img)
img.save(filename)
def init():
global output_path, args
global style_model, device
output_path = os.environ['AZUREML_BI_OUTPUT_PATH']
print(f'output path: {output_path}')
print(f'Cuda available? {torch.cuda.is_available()}')
arg_parser = argparse.ArgumentParser(description="parser for fast-neural-style")
arg_parser.add_argument("--style", type=str, help="style name")
args, unknown_args = arg_parser.parse_known_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
with torch.no_grad():
style_model = TransformerNet()
model_path = Model.get_model_path(args.style)
state_dict = torch.load(os.path.join(model_path))
# remove saved deprecated running_* keys in InstanceNorm from the checkpoint
for k in list(state_dict.keys()):
if re.search(r'in\d+\.running_(mean|var)$', k):
del state_dict[k]
style_model.load_state_dict(state_dict)
style_model.to(device)
print(f'Model loaded successfully. Path: {model_path}')
def run(mini_batch):
result = []
for image_file_path in mini_batch:
img = load_image(image_file_path)
with torch.no_grad():
content_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Lambda(lambda x: x.mul(255))
])
content_image = content_transform(img)
content_image = content_image.unsqueeze(0).to(device)
output = style_model(content_image).cpu()
output_file_path = os.path.join(output_path, os.path.basename(image_file_path))
save_image(output_file_path, output[0])
result.append(output_file_path)
return result