Skip to content

Commit

Permalink
initial work on sd_unet for SDXL
Browse files Browse the repository at this point in the history
  • Loading branch information
AUTOMATIC1111 committed Sep 11, 2023
1 parent 9246423 commit 5954432
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 7 deletions.
17 changes: 12 additions & 5 deletions modules/sd_hijack.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
from torch.nn.functional import silu
from types import MethodType

from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet, patches
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.openaimodel
import ldm.models.diffusion.ddpm
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
import ldm.modules.encoders.modules
Expand Down Expand Up @@ -37,6 +38,8 @@
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None

ldm_original_forward = patches.patch(__file__, ldm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)
sgm_original_forward = patches.patch(__file__, sgm.modules.diffusionmodules.openaimodel.UNetModel, "forward", sd_unet.UNetModel_forward)

def list_optimizers():
new_optimizers = script_callbacks.list_optimizers_callback()
Expand Down Expand Up @@ -239,10 +242,13 @@ def flatten(el):

self.layers = flatten(m)

if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward
if isinstance(m, ldm.models.diffusion.ddpm.LatentDiffusion):
sd_unet.original_forward = ldm_original_forward
elif isinstance(m, sgm.models.diffusion.DiffusionEngine):
sd_unet.original_forward = sgm_original_forward
else:
sd_unet.original_forward = None

ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

def undo_hijack(self, m):
conditioner = getattr(m, 'conditioner', None)
Expand Down Expand Up @@ -279,7 +285,8 @@ def undo_hijack(self, m):
self.layers = None
self.clip = None

ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui
sd_unet.original_forward = None


def apply_circular(self, enable):
if self.circular_enabled == enable:
Expand Down
4 changes: 2 additions & 2 deletions modules/sd_unet.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch.nn
import ldm.modules.diffusionmodules.openaimodel

from modules import script_callbacks, shared, devices

unet_options = []
current_unet_option = None
current_unet = None
original_forward = None


def list_unets():
Expand Down Expand Up @@ -88,5 +88,5 @@ def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
if current_unet is not None:
return current_unet.forward(x, timesteps, context, *args, **kwargs)

return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)
return original_forward(self, x, timesteps, context, *args, **kwargs)

0 comments on commit 5954432

Please sign in to comment.