Skip to content

Commit

Permalink
webui support controlnet extension (#948)
Browse files Browse the repository at this point in the history
only supported UNet inference with control, but controlnet model not
supported now

---------

Co-authored-by: Li Junliang <117806079+lijunliangTG@users.noreply.github.com>
  • Loading branch information
marigoold and lijunliangTG authored Jun 22, 2024
1 parent 30d1168 commit 102ee05
Show file tree
Hide file tree
Showing 8 changed files with 1,283 additions and 44 deletions.
8 changes: 5 additions & 3 deletions onediff_sd_webui_extensions/compile/compile_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,9 @@ def get_calibrate_info(filename: str) -> Union[None, Dict]:


def get_compiled_graph(sd_model, quantization) -> OneDiffCompiledGraph:
compiled_unet = compile_unet(
sd_model.model.diffusion_model, quantization=quantization
)
diffusion_model = sd_model.model.diffusion_model
# for controlnet
if "forward" in diffusion_model.__dict__:
diffusion_model.__dict__.pop("forward")
compiled_unet = compile_unet(diffusion_model, quantization=quantization)
return OneDiffCompiledGraph(sd_model, compiled_unet, quantization)
27 changes: 16 additions & 11 deletions onediff_sd_webui_extensions/compile/sd_webui_onediff_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,22 @@ def forward(self, x):


# https://github.com/Stability-AI/generative-models/blob/059d8e9cd9c55aea1ef2ece39abf605efb8b7cc9/sgm/modules/diffusionmodules/util.py#L207
def timestep_embedding(timesteps, dim, max_period=10000):
half = dim // 2
freqs = flow.exp(
-math.log(max_period)
* flow.arange(start=0, end=half, dtype=flow.float32)
/ half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = flow.cat([flow.cos(args), flow.sin(args)], dim=-1)
if dim % 2:
embedding = flow.cat([embedding, flow.zeros_like(embedding[:, :1])], dim=-1)
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
if not repeat_only:
half = dim // 2
freqs = flow.exp(
-math.log(max_period)
* flow.arange(start=0, end=half, dtype=flow.float32)
/ half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = flow.cat([flow.cos(args), flow.sin(args)], dim=-1)
if dim % 2:
embedding = flow.cat([embedding, flow.zeros_like(embedding[:, :1])], dim=-1)
else:
raise NotImplementedError(
"repeat_only=True is not implemented in timestep_embedding"
)
return embedding


Expand Down
Loading

0 comments on commit 102ee05

Please sign in to comment.