
Commit

todo_factor: accept multiple values (per depth). autodetect todo_max_factor

feffy380 committed Mar 11, 2024
1 parent 1d35144 commit 2629117
Showing 3 changed files with 26 additions and 9 deletions.
library/sdxl_train_util.py (2 changes: 1 addition & 1 deletion)
@@ -61,7 +61,7 @@ def load_target_model(args, accelerator, model_version: str, weight_dtype):
     # apply token merging patch
     if args.todo_factor:
         token_downsampling.apply_patch(unet, args, is_sdxl=True)
-        logger.info(f"enable token downsampling optimization | {unet._todo_info['args']}")
+        logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}")
 
     return load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info

library/token_downsampling.py (28 changes: 22 additions & 6 deletions)
@@ -26,10 +26,10 @@ def compute_merge(x: torch.Tensor, todo_info: dict):
     cur_w = original_w // downsample
 
     args = todo_info["args"]
-    downsample_factor = args["downsample_factor"]
 
     merge_op = lambda x: x
     if downsample <= args["max_depth"]:
+        downsample_factor = args["downsample_factor"][downsample]
         new_h = int(cur_h / downsample_factor)
         new_w = int(cur_w / downsample_factor)
         merge_op = lambda x: up_or_downsample(x, cur_w, cur_h, new_w, new_h)
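
Aside (not part of the diff): downsample, the attention layer's depth key (1, 2, 4, ...), now indexes a per-depth dict instead of reading a single scalar factor. A minimal runnable sketch with assumed values (the dict shown is what parse_todo_args would build for --todo_factor 2 3 without SDXL):

    # illustrative values, not taken from the repository
    args = {"downsample_factor": {1: 2.0, 2: 3.0}, "max_depth": 2}

    for downsample in (1, 2, 4, 8):  # depth keys are powers of 2
        if downsample <= args["max_depth"]:
            factor = args["downsample_factor"][downsample]  # per-depth lookup
            print(f"depth {downsample}: downsample tokens by {factor}x")
        else:
            print(f"depth {downsample}: no downsampling")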
@@ -58,14 +58,30 @@ def hook(module, args, kwargs):


 def parse_todo_args(args, is_sdxl: bool = False) -> dict:
     # validate max_depth
-    if args.todo_max_depth is None:
-        args.todo_max_depth = 2 if is_sdxl else 1
-    if is_sdxl and args.todo_max_depth not in (2, 3):
-        raise ValueError(f"--todo_max_depth for SDXL must be 2 or 3, received {args.todo_factor}")
+    args.todo_max_depth = min(len(args.todo_factor), 4)
+    if is_sdxl and args.todo_max_depth > 2:
+        raise ValueError(f"todo_max_depth for SDXL cannot be larger than 2, received {args.todo_max_depth}")
 
+    # validate factor
+    if len(args.todo_factor) > 1:
+        if len(args.todo_factor) != args.todo_max_depth:
+            raise ValueError(f"todo_factor number of values must be 1 or same as todo_max_depth, received {len(args.todo_factor)}")
+
+    # create dict of factors to support per-depth override
+    factors = args.todo_factor
+    if len(factors) == 1:
+        factors *= args.todo_max_depth
+    factors = {2**(i + int(is_sdxl)): factor for i, factor in enumerate(factors)}
+
+    # convert depth to powers of 2 to match layer dimensions: [1,2,3,4] -> [1,2,4,8]
+    # offset by 1 for sdxl which starts at 2
+    max_depth = 2**(args.todo_max_depth + int(is_sdxl) - 1)
+
     todo_kwargs = {
-        "downsample_factor": args.todo_factor,
-        "max_depth": 2**(args.todo_max_depth - 1),
+        "downsample_factor": factors,
+        "max_depth": max_depth,
     }
 
     return todo_kwargs
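
Aside (not part of the diff): a standalone sketch of the mapping this function now produces, with the validation trimmed out; Namespace stands in for the real argparse result:

    from argparse import Namespace

    def parse_todo_args_sketch(args, is_sdxl=False):
        # trimmed copy of the logic above, validation omitted
        args.todo_max_depth = min(len(args.todo_factor), 4)
        factors = args.todo_factor
        if len(factors) == 1:
            factors *= args.todo_max_depth  # broadcast one factor to every depth
        factors = {2**(i + int(is_sdxl)): f for i, f in enumerate(factors)}
        max_depth = 2**(args.todo_max_depth + int(is_sdxl) - 1)
        return {"downsample_factor": factors, "max_depth": max_depth}

    print(parse_todo_args_sketch(Namespace(todo_factor=[2.0])))
    # {'downsample_factor': {1: 2.0}, 'max_depth': 1}
    print(parse_todo_args_sketch(Namespace(todo_factor=[2.0, 4.0]), is_sdxl=True))
    # {'downsample_factor': {2: 2.0, 4: 4.0}, 'max_depth': 4}

The SDXL case shifts both the dict keys and max_depth one level deeper, matching the comment above about SDXL starting at depth 2.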
library/train_util.py (5 changes: 3 additions & 2 deletions)
@@ -3139,7 +3139,8 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
     parser.add_argument(
         "--todo_factor",
         type=float,
-        help="token downsampling (ToDo) factor > 1 (recommend around 2-4)",
+        nargs="+",
+        help="token downsampling (ToDo) factor > 1 (recommend around 2-4). Specify multiple to set factor for each depth",
     )
     parser.add_argument(
         "--todo_max_depth",
@@ -4203,7 +4204,7 @@ def load_target_model(args, weight_dtype, accelerator, unet_use_linear_projectio
     # apply token merging patch
     if args.todo_factor:
         token_downsampling.apply_patch(unet, args)
-        logger.info(f"enable token downsampling optimization | {unet._todo_info['args']}")
+        logger.info(f"enable token downsampling optimization: downsample_factor={args.todo_factor}, max_depth={args.todo_max_depth}")
 
     return text_encoder, vae, unet, load_stable_diffusion_format
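
Aside (not part of the diff): with nargs="+", argparse collects one or more floats into a list, and parse_todo_args now derives todo_max_depth from that list's length (min(len, 4)), so the separate flag no longer needs to be set by hand. A quick runnable check of the parsing behavior:

    import argparse

    # same flag definition as above, trimmed to the essentials
    parser = argparse.ArgumentParser()
    parser.add_argument("--todo_factor", type=float, nargs="+")
    print(parser.parse_args(["--todo_factor", "2", "3"]).todo_factor)  # [2.0, 3.0]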

