
Commit
Merge branch 'pr/1151'
gesen2egee committed Mar 13, 2024
2 parents a53b04b + 4fa42ba commit 9748ed2
Showing 4 changed files with 129 additions and 0 deletions.
6 changes: 6 additions & 0 deletions library/sdxl_train_util.py
@@ -12,6 +12,7 @@
from transformers import CLIPTokenizer
from library import model_util, sdxl_model_util, train_util, sdxl_original_unet
from library.sdxl_lpw_stable_diffusion import SdxlStableDiffusionLongPromptWeightingPipeline
from library import token_merging
from .utils import setup_logging
setup_logging()
import logging
@@ -57,6 +58,11 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()

# apply token merging patch
if args.todo_factor:
token_merging.patch_attention(unet, args, is_sdxl=True)
logger.info(f"enable token downsampling optimization | {unet._tome_info['args']}")

return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info


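For example, running SDXL with --todo_factor 2 and todo_max_downsample left at its default of 2 would make the logger.info call above print roughly (illustrative):

enable token downsampling optimization | {'downsample_factor': 2.0, 'max_downsample': 2}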
99 changes: 99 additions & 0 deletions library/token_merging.py
@@ -0,0 +1,99 @@
# based on:
# https://github.com/ethansmith2000/ImprovedTokenMerge
# https://github.com/ethansmith2000/comfy-todo (MIT)

import math

import torch
import torch.nn.functional as F


def up_or_downsample(item, cur_w, cur_h, new_w, new_h, method="nearest-exact"):
batch_size = item.shape[0]

item = item.reshape(batch_size, cur_h, cur_w, -1).permute(0, 3, 1, 2)
item = F.interpolate(item, size=(new_h, new_w), mode=method).permute(0, 2, 3, 1)
item = item.reshape(batch_size, new_h * new_w, -1)

return item


def compute_merge(x: torch.Tensor, tome_info: dict):
original_h, original_w = tome_info["size"]
original_tokens = original_h * original_w
downsample = int(math.ceil(math.sqrt(original_tokens // x.shape[1])))
cur_h = original_h // downsample
cur_w = original_w // downsample

args = tome_info["args"]
downsample_factor = args["downsample_factor"]

merge_op = lambda x: x
if downsample <= args["max_downsample"]:
new_h = int(cur_h / downsample_factor)
new_w = int(cur_w / downsample_factor)
merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h)

return merge_op
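
# Worked example (illustrative, not part of the committed file): for SDXL at
# 1024x1024 the latent is 128x128, so tome_info["size"] = (128, 128) and
# original_tokens = 16384. In the highest-resolution transformer blocks
# x.shape[1] == 4096, giving downsample = ceil(sqrt(16384 // 4096)) = 2. With
# max_downsample=2 and downsample_factor=2, cur_h = cur_w = 64, and merge_op
# resizes the attention keys/values from 64x64 down to 32x32 tokens.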


def hook_tome_model(model: torch.nn.Module):
""" Adds a forward pre hook to get the image size. This hook can be removed with remove_patch. """
def hook(module, args):
module._tome_info["size"] = (args[0].shape[2], args[0].shape[3])
return None

model._tome_info["hooks"].append(model.register_forward_pre_hook(hook))


def hook_attention(attn: torch.nn.Module):
""" Adds a forward pre hook to downsample attention keys and values. This hook can be removed with remove_patch. """
def hook(module, args, kwargs):
hidden_states = args[0]
m = compute_merge(hidden_states, module._tome_info)
kwargs["context"] = m(hidden_states)
return args, kwargs

attn._tome_info["hooks"].append(attn.register_forward_pre_hook(hook, with_kwargs=True))


def parse_todo_args(args, is_sdxl: bool = False) -> dict:
if args.todo_max_downsample is None:
args.todo_max_downsample = 2 if is_sdxl else 1
if is_sdxl and args.todo_max_downsample not in (2, 4):
raise ValueError(f"--todo_max_downsample for SDXL must be 2 or 4, received {args.todo_max_downsample}")

todo_kwargs = {
"downsample_factor": args.todo_factor,
"max_downsample": args.todo_max_downsample,
}

return todo_kwargs


def patch_attention(unet: torch.nn.Module, args, is_sdxl=False):
""" Patches the UNet's transformer blocks to apply token downsampling. """
todo_kwargs = parse_todo_args(args, is_sdxl)

unet._tome_info = {
"size": None,
"hooks": [],
"args": todo_kwargs,
}
hook_tome_model(unet)

for _, module in unet.named_modules():
if module.__class__.__name__ == "BasicTransformerBlock":
module.attn1._tome_info = unet._tome_info
hook_attention(module.attn1)

return unet


def remove_patch(unet: torch.nn.Module):
if hasattr(unet, "_tome_info"):
for hook in unet._tome_info["hooks"]:
hook.remove()
unet._tome_info["hooks"].clear()

return unet
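
A minimal usage sketch for the new module (illustrative; assumes unet is an already-loaded diffusers-style UNet whose BasicTransformerBlock modules expose attn1, as patch_attention expects):

from types import SimpleNamespace
from library import token_merging

# Hypothetical stand-in for the parsed training arguments.
args = SimpleNamespace(todo_factor=2.0, todo_max_downsample=None)  # None -> per-model default

unet = token_merging.patch_attention(unet, args, is_sdxl=True)  # registers the forward pre-hooks
# ... run training or inference ...
unet = token_merging.remove_patch(unet)  # detaches every hook the patch registered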
20 changes: 20 additions & 0 deletions library/train_util.py
@@ -69,6 +69,7 @@
import library.model_util as model_util
import library.huggingface_util as huggingface_util
import library.sai_model_spec as sai_model_spec
from library import token_merging
from library.utils import setup_logging

setup_logging()
@@ -3168,6 +3169,20 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="set maximum time step for U-Net training (1~1000, default is 1000) / U-Net学習時のtime stepの最大値を設定する(1~1000で指定、省略時はデフォルト値(1000))",
)
parser.add_argument(
"--todo_factor",
type=float,
help="token downsampling (ToDo) factor > 1 (recommend around 2-4)",
)
parser.add_argument(
"--todo_max_downsample",
type=int,
choices=[1, 2, 4, 8],
help=(
"apply ToDo to layers with at most this amount of downsampling."
" SDXL only accepts 2 and 4. Recommend 1 or 2. Default 1 (or 2 for SDXL)"
),
)

parser.add_argument(
"--lowram",
@@ -4219,6 +4234,11 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
clean_memory_on_device(accelerator.device)
accelerator.wait_for_everyone()

# apply token merging patch
if args.todo_factor:
token_merging.patch_attention(unet, args)
logger.info(f"enable token downsampling optimization | {unet._tome_info['args']}")

return text_encoder, vae, unet, load_stable_diffusion_format


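The new flags are passed alongside the existing training arguments, e.g. (hypothetical invocation; all other required arguments omitted):

accelerate launch train_network.py --todo_factor 2 --todo_max_downsample 2 ...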
4 changes: 4 additions & 0 deletions train_network.py
@@ -762,6 +762,10 @@ def train(self, args):
vae_name = os.path.basename(vae_name)
metadata["ss_vae_name"] = vae_name

if args.todo_factor:
metadata["ss_todo_factor"] = args.todo_factor
metadata["ss_todo_max_downsample"] = args.todo_max_downsample

metadata = {k: str(v) for k, v in metadata.items()}

# make minimum metadata for filtering
