notify #25, update ctrlnet core
remove some typing imports not available during loading
Kahsolt committed Apr 30, 2023
1 parent 343cc38 commit b2fa229
Showing 2 changed files with 76 additions and 74 deletions.
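
The commit description ("remove some typing imports not available during loading") matches the signature changes in the diff below: annotations like `self:UnetHook` are moved out of the `def` line and into the function body. A minimal sketch of why that helps, with an assumed import path for `UnetHook` (the class comes from the sd-webui-controlnet extension, which may not be installed): parameter annotations are evaluated when the `def` statement runs, while a bare annotation on a local name inside the body is never evaluated at runtime (PEP 526), so the script still loads even when the name is undefined.

```python
from torch import Tensor

try:
    # provided by the sd-webui-controlnet extension; this import path is an assumption
    from scripts.hook import UnetHook
    controlnet_found = True
except ImportError:
    controlnet_found = False   # `UnetHook` is now an undefined name

# Before: the signature annotation is evaluated when `def` runs, so a missing
# `UnetHook` raises NameError and the whole script fails to load:
# def cfg_based_adder(self: UnetHook, base: Tensor, x: Tensor, require_autocast: bool): ...

# After: the annotation moves into the body, where it is never evaluated at runtime,
# so the module loads even without the extension installed.
def cfg_based_adder(self, base: Tensor, x: Tensor, require_autocast: bool):
    self: UnetHook  # type hint for readers/IDEs only
    return base + x
```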
1 change: 1 addition & 0 deletions README.md
@@ -46,6 +46,7 @@ Not only prompts! We also support non-prompt conditions, read => [README_ext.md]

⚪ Fixups

- 2023/04/30: update controlnet core to `v1.1.116`
- 2023/03/29: `v2.4` bug fixes on script hook, now working correctly with extra networks & [sd-webui-controlnet](https://github.com/Mikubill/sd-webui-controlnet)
- 2023/01/31: keep up with webui's updates, (issue #14: `ImportError: cannot import name 'single_sample_to_image'`)
- 2023/01/28: keep up with webui's updates, extra-networks rework
149 changes: 75 additions & 74 deletions scripts/controlnet_travel.py
@@ -47,33 +47,17 @@ class InterpMethod(Enum):
except:
controlnet_found = False

def cfg_based_adder(self:UnetHook, base:Tensor, x:Tensor, require_autocast:bool, is_adapter=False):
def cfg_based_adder(self, base:Tensor, x:Tensor, require_autocast:bool):
self: UnetHook

if isinstance(x, float):
return base + x

if require_autocast:
zeros = torch.zeros_like(base)
zeros[:, :x.shape[1], ...] = x
x = zeros

# assume the input format is [cond, uncond] and they have same shape
# see https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/0cc0ee1bcb4c24a8c9715f66cede06601bfc00c8/modules/sd_samplers_kdiffusion.py#L114
if base.shape[0] % 2 == 0 and (self.guess_mode or shared.opts.data.get("control_net_cfg_based_guidance", False)):
if self.is_vanilla_samplers:
uncond, cond = base.chunk(2)
if x.shape[0] % 2 == 0:
_, x_cond = x.chunk(2)
return torch.cat([uncond, cond + x_cond], dim=0)
if is_adapter:
return torch.cat([uncond, cond + x], dim=0)
else:
cond, uncond = base.chunk(2)
if x.shape[0] % 2 == 0:
x_cond, _ = x.chunk(2)
return torch.cat([cond + x_cond, uncond], dim=0)
if is_adapter:
return torch.cat([cond + x, uncond], dim=0)


# resize to sample resolution
base_h, base_w = base.shape[-2:]
xh, xw = x.shape[-2:]
@@ -82,12 +66,13 @@ def cfg_based_adder(self:UnetHook, base:Tensor, x:Tensor, require_autocast:bool,

return base + x

def forward(outer:UnetHook, x:Tensor, timesteps:Tensor=None, context:Tensor=None, **kwargs):
def forward(outer, x:Tensor, timesteps:Tensor=None, context:Tensor=None, **kwargs):
''' NOTE: This function is called `sampling_steps*2` times (once for cond & uncond respectively) '''
outer: UnetHook

total_control = [0.0] * 13
total_adapter = [0.0] * 4
total_extra_cond = torch.zeros([0, context.shape[-1]]).to(devices.get_device_for("controlnet"))
total_extra_cond = None
only_mid_control = outer.only_mid_control
require_inpaint_hijack = False

@@ -97,50 +82,52 @@ def forward(outer:UnetHook, x:Tensor, timesteps:Tensor=None, context:Tensor=None
x: Tensor # [1, 4, 64, 64]
context: Tensor # [1, 77, 768]

# handle external cond first
for param in outer.control_params: # do nothing due to no extra_cond
# High-res fix
is_in_high_res_fix = False
for param in outer.control_params:
# select which hint_cond to use
param.used_hint_cond = param.hint_cond
# has high-res fix
if param.hr_hint_cond is not None and x.ndim == 4 and param.hint_cond.ndim == 3 and param.hr_hint_cond.ndim == 3:
_, h_lr, w_lr = param.hint_cond.shape
_, h_hr, w_hr = param.hr_hint_cond.shape
_, _, h, w = x.shape
h, w = h * 8, w * 8
if abs(h - h_lr) < abs(h - h_hr):
# we are in low-res path
param.used_hint_cond = param.hint_cond
else:
# we are in high-res path
param.used_hint_cond = param.hr_hint_cond
is_in_high_res_fix = True
if shared.opts.data.get("control_net_high_res_only_mid", False):
only_mid_control = True

# handle external cond
for param in outer.control_params:
if param.guidance_stopped or not param.is_extra_cond:
continue
if outer.lowvram:
param.control_model.to(devices.get_device_for("controlnet"))

control = param.control_model(x=x, hint=param.hint_cond, timesteps=timesteps, context=context)
total_extra_cond = torch.cat([total_extra_cond, control.clone().squeeze(0) * param.weight])

# check if it's non-batch-cond mode (lowvram, edit model etc)
if context.shape[0] % 2 != 0 and outer.batch_cond_available: # True
outer.batch_cond_available = False
if len(total_extra_cond) > 0 or outer.guess_mode or shared.opts.data.get("control_net_cfg_based_guidance", False):
print("Warning: StyleAdapter and cfg/guess mode may not works due to non-batch-cond inference")

# concat styleadapter to cond, pad uncond to same length
if len(total_extra_cond) > 0 and outer.batch_cond_available: # False
total_extra_cond = torch.repeat_interleave(total_extra_cond.unsqueeze(0), context.shape[0] // 2, dim=0)
if outer.is_vanilla_samplers:
uncond, cond = context.chunk(2)
cond = torch.cat([cond, total_extra_cond], dim=1)
uncond = torch.cat([uncond, uncond[:, -total_extra_cond.shape[1]:, :]], dim=1)
context = torch.cat([uncond, cond], dim=0)
param.control_model.to(devices.get_device_for("controlnet"))
query_size = int(x.shape[0])
control = param.control_model(x=x, hint=param.used_hint_cond, timesteps=timesteps, context=context)
uc_mask = param.generate_uc_mask(query_size, dtype=x.dtype, device=x.device)[:, None, None]
control = torch.concatenate([control.clone() for _ in range(query_size)], dim=0)
control *= param.weight
control *= uc_mask
if total_extra_cond is None:
total_extra_cond = control.clone()
else:
cond, uncond = context.chunk(2)
cond = torch.cat([cond, total_extra_cond], dim=1)
uncond = torch.cat([uncond, uncond[:, -total_extra_cond.shape[1]:, :]], dim=1)
context = torch.cat([cond, uncond], dim=0)
total_extra_cond = torch.cat([total_extra_cond, control.clone()], dim=1)

if total_extra_cond is not None:
context = torch.cat([context, total_extra_cond], dim=1)

# handle unet injection stuff
for i, param in enumerate(outer.control_params):
for param in outer.control_params:
if param.guidance_stopped or param.is_extra_cond:
continue
if outer.lowvram:
param.control_model.to(devices.get_device_for("controlnet"))

# hires stuffs
# note that this method may not works if hr_scale < 1.1
if abs(x.shape[-1] - param.hint_cond.shape[-1] // 8) > 8:
only_mid_control = shared.opts.data.get("control_net_only_midctrl_hires", True)
# If you want to completely disable control net, uncomment this.
# return self._original_forward(x, timesteps=timesteps, context=context, **kwargs)


param.control_model.to(devices.get_device_for("controlnet"))
# inpaint model workaround
x_in = x
control_model = param.control_model.control_model
@@ -152,29 +139,41 @@ def forward(outer:UnetHook, x:Tensor, timesteps:Tensor=None, context:Tensor=None
# NOTE: perform hint shallow fusion here
if interp_alpha == 0.0: # collect hint_cond on key frames
if len(to_hint_cond) < len(outer.control_params):
to_hint_cond.append(param.hint_cond.cpu().clone())
to_hint_cond.append(param.used_hint_cond.cpu().clone())
else: # interp with cached hint_cond
param.hint_cond = mid_hint_cond[i].to(x_in.device)
param.used_hint_cond = mid_hint_cond[i].to(x_in.device)

assert param.hint_cond is not None, f"Controlnet is enabled but no input image is given"
control = param.control_model(x=x_in, hint=param.hint_cond, timesteps=timesteps, context=context)
assert param.used_hint_cond is not None, f"Controlnet is enabled but no input image is given"
control = param.control_model(x=x_in, hint=param.used_hint_cond, timesteps=timesteps, context=context)
control_scales = ([param.weight] * 13)

if outer.lowvram:
param.control_model.to("cpu")
if param.guess_mode:

if param.cfg_injection or param.global_average_pooling:
query_size = int(x.shape[0])
if param.is_adapter:
# see https://github.com/Mikubill/sd-webui-controlnet/issues/269
control_scales = param.weight * [0.25, 0.62, 0.825, 1.0]
control = [torch.concatenate([c.clone() for _ in range(query_size)], dim=0) for c in control]
uc_mask = param.generate_uc_mask(query_size, dtype=x.dtype, device=x.device)[:, None, None, None]
control = [c * uc_mask for c in control]

if param.soft_injection or is_in_high_res_fix:
# important! using the soft weights with high-res fix can significantly reduce artifacts.
if param.is_adapter:
control_scales = [param.weight * x for x in (0.25, 0.62, 0.825, 1.0)]
else:
control_scales = [param.weight * (0.825 ** float(12 - i)) for i in range(13)]
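# e.g. with param.weight = 1.0 the line above yields 13 scales rising geometrically
# from 0.825**12 ≈ 0.10 at index 0 to 0.825**0 = 1.0 at index 12, so the later
# control residuals keep (nearly) full strength while earlier ones are progressively damped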

if param.advanced_weighting is not None:
control_scales = param.advanced_weighting

control = [c * scale for c, scale in zip(control, control_scales)]
if param.global_average_pooling:
control = [torch.mean(c, dim=(2, 3), keepdim=True) for c in control]

for idx, item in enumerate(control):
target = total_adapter if param.is_adapter else total_control
target[idx] += item
target[idx] = item + target[idx]

# NOTE: perform latent fusion here
if interp_alpha == 0.0: # collect control tensors on key frames
@@ -205,9 +204,9 @@ def forward(outer:UnetHook, x:Tensor, timesteps:Tensor=None, context:Tensor=None
h = module(h, emb, context)

# t2i-adapter, same as openaimodel.py:744
if ((i+1)%3 == 0) and len(total_adapter):
h = cfg_based_adder(outer, h, total_adapter.pop(0), require_inpaint_hijack, is_adapter=True)

if ((i+1) % 3 == 0) and len(total_adapter):
h = cfg_based_adder(outer, h, total_adapter.pop(0), require_inpaint_hijack)
hs.append(h)
h = self.middle_block(h, emb, context)

@@ -226,8 +225,9 @@ def forward(outer:UnetHook, x:Tensor, timesteps:Tensor=None, context:Tensor=None
h = h.type(x.dtype)
return self.out(h)

def forward2(self: UnetHook, *args, **kwargs):
def forward2(self, *args, **kwargs):
# webui will handle other components
self: UnetHook
try:
if shared.cmd_opts.lowvram:
lowvram.send_everything_to_cpu()
@@ -351,13 +351,14 @@ def run(self, p:StableDiffusionProcessing,
self.controlnet_script = None
self.hooked = None
try:
from scripts.controlnet import Script as ControlNetScript
#from scripts.controlnet import Script as ControlNetScript
from scripts.external_code import ControlNetUnit
for script in p.scripts.alwayson_scripts:
if hasattr(script, "latest_network") and script.title().lower() == "controlnet":
script_args: Tuple[ControlNetUnit] = p.script_args[script.args_from:script.args_to]
if not any([u.enabled for u in script_args]): return Processed(p, [], p.seed, 'sd-webui-controlnet not enabled')
self.controlnet_script: ControlNetScript = script
#self.controlnet_script: ControlNetScript = script
self.controlnet_script = script
break
except ImportError:
return Processed(p, [], p.seed, 'sd-webui-controlnet not installed')
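
On the `uc_mask` lines added in `forward` above: as far as I can tell from the diff, `generate_uc_mask` yields a per-sample 0/1 mask over the batch, so ControlNet/adapter residuals are zeroed for the unconditional samples instead of going through the old chunk-and-concat guess-mode branch. A rough, self-contained sketch of that idea (not the extension's actual code; the mask layout is an assumption):

```python
import torch

def apply_cfg_injection(control: torch.Tensor, uc_mask: torch.Tensor) -> torch.Tensor:
    """control: [B, C, H, W] ControlNet residuals; uc_mask: [B], 1.0 for cond samples, 0.0 for uncond."""
    return control * uc_mask[:, None, None, None]

# toy batch: first two samples conditional, last two unconditional (assumed layout)
control = torch.randn(4, 320, 8, 8)
uc_mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
steered = apply_cfg_injection(control, uc_mask)
assert torch.count_nonzero(steered[2:]) == 0   # uncond half receives no control signal
```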