Patch UNet Forward to support resolutions that are not multiples of 64

Also modifed the UI to no longer step in 64
This commit is contained in:
Billy Cao
2022-11-23 18:11:24 +08:00
parent 828438b4a1
commit adb6cb7619
3 changed files with 45 additions and 12 deletions

View File

@@ -5,6 +5,7 @@ import importlib
import torch
from torch import einsum
import torch.nn.functional as F
from ldm.util import default
from einops import rearrange
@@ -12,6 +13,8 @@ from einops import rearrange
from modules import shared
from modules.hypernetworks import hypernetwork
from ldm.modules.diffusionmodules.util import timestep_embedding
if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
try:
@@ -310,3 +313,31 @@ def xformers_attnblock_forward(self, x):
return x + out
except NotImplementedError:
return cross_attention_attnblock_forward(self, x)
def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
if h.shape[-2:] != hs[-1].shape[-2:]:
h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)