
Commit

Fix issue with TE LR
bmaltais committed Mar 3, 2024
1 parent 642cca3 commit 74a7b03
Showing 3 changed files with 174 additions and 161 deletions.
kohya_gui/dreambooth_gui.py (167 changes: 87 additions & 80 deletions)
@@ -565,87 +565,94 @@ def train_model(
    else:
        run_cmd += fr' "{scriptdir}/sd-scripts/train_db.py"'

+    # Initialize a dictionary with always-included keyword arguments
+    kwargs_for_training = {
+        "adaptive_noise_scale": adaptive_noise_scale,
+        "additional_parameters": additional_parameters,
+        "bucket_no_upscale": bucket_no_upscale,
+        "bucket_reso_steps": bucket_reso_steps,
+        "cache_latents": cache_latents,
+        "cache_latents_to_disk": cache_latents_to_disk,
+        "caption_dropout_every_n_epochs": caption_dropout_every_n_epochs,
+        "caption_dropout_rate": caption_dropout_rate,
+        "caption_extension": caption_extension,
+        "clip_skip": clip_skip,
+        "color_aug": color_aug,
+        "enable_bucket": enable_bucket,
+        "epoch": epoch,
+        "flip_aug": flip_aug,
+        "full_bf16": full_bf16,
+        "full_fp16": full_fp16,
+        "gradient_accumulation_steps": gradient_accumulation_steps,
+        "gradient_checkpointing": gradient_checkpointing,
+        "keep_tokens": keep_tokens,
+        "learning_rate": learning_rate,
+        "logging_dir": logging_dir,
+        "lr_scheduler": lr_scheduler,
+        "lr_scheduler_args": lr_scheduler_args,
+        "lr_scheduler_num_cycles": lr_scheduler_num_cycles,
+        "lr_scheduler_power": lr_scheduler_power,
+        "lr_warmup_steps": lr_warmup_steps,
+        "max_bucket_reso": max_bucket_reso,
+        "max_data_loader_n_workers": max_data_loader_n_workers,
+        "max_resolution": max_resolution,
+        "max_timestep": max_timestep,
+        "max_token_length": max_token_length,
+        "max_train_epochs": max_train_epochs,
+        "max_train_steps": max_train_steps,
+        "mem_eff_attn": mem_eff_attn,
+        "min_bucket_reso": min_bucket_reso,
+        "min_snr_gamma": min_snr_gamma,
+        "min_timestep": min_timestep,
+        "mixed_precision": mixed_precision,
+        "multires_noise_discount": multires_noise_discount,
+        "multires_noise_iterations": multires_noise_iterations,
+        "no_token_padding": no_token_padding,
+        "noise_offset": noise_offset,
+        "noise_offset_type": noise_offset_type,
+        "optimizer": optimizer,
+        "optimizer_args": optimizer_args,
+        "output_dir": output_dir,
+        "output_name": output_name,
+        "persistent_data_loader_workers": persistent_data_loader_workers,
+        "pretrained_model_name_or_path": pretrained_model_name_or_path,
+        "prior_loss_weight": prior_loss_weight,
+        "random_crop": random_crop,
+        "reg_data_dir": reg_data_dir,
+        "resume": resume,
+        "save_every_n_epochs": save_every_n_epochs,
+        "save_every_n_steps": save_every_n_steps,
+        "save_last_n_steps": save_last_n_steps,
+        "save_last_n_steps_state": save_last_n_steps_state,
+        "save_model_as": save_model_as,
+        "save_precision": save_precision,
+        "save_state": save_state,
+        "scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred,
+        "seed": seed,
+        "shuffle_caption": shuffle_caption,
+        "stop_text_encoder_training": stop_text_encoder_training,
+        "train_batch_size": train_batch_size,
+        "train_data_dir": train_data_dir,
+        "use_wandb": use_wandb,
+        "v2": v2,
+        "v_parameterization": v_parameterization,
+        "v_pred_like_loss": v_pred_like_loss,
+        "vae": vae,
+        "vae_batch_size": vae_batch_size,
+        "wandb_api_key": wandb_api_key,
+        "weighted_captions": weighted_captions,
+        "xformers": xformers,
+    }
+
+    # Conditionally include specific keyword arguments based on sdxl
+    if sdxl:
+        kwargs_for_training["learning_rate_te1"] = learning_rate_te1
+        kwargs_for_training["learning_rate_te2"] = learning_rate_te2
+    else:
+        kwargs_for_training["learning_rate_te"] = learning_rate_te

-    run_cmd += run_cmd_advanced_training(
-        adaptive_noise_scale=adaptive_noise_scale,
-        additional_parameters=additional_parameters,
-        bucket_no_upscale=bucket_no_upscale,
-        bucket_reso_steps=bucket_reso_steps,
-        cache_latents=cache_latents,
-        cache_latents_to_disk=cache_latents_to_disk,
-        caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
-        caption_dropout_rate=caption_dropout_rate,
-        caption_extension=caption_extension,
-        clip_skip=clip_skip,
-        color_aug=color_aug,
-        enable_bucket=enable_bucket,
-        epoch=epoch,
-        flip_aug=flip_aug,
-        full_bf16=full_bf16,
-        full_fp16=full_fp16,
-        gradient_accumulation_steps=gradient_accumulation_steps,
-        gradient_checkpointing=gradient_checkpointing,
-        keep_tokens=keep_tokens,
-        learning_rate=learning_rate,
-        learning_rate_te1=learning_rate_te1 if sdxl else None,
-        learning_rate_te2=learning_rate_te2 if sdxl else None,
-        learning_rate_te=learning_rate_te if not sdxl else None,
-        logging_dir=logging_dir,
-        lr_scheduler=lr_scheduler,
-        lr_scheduler_args=lr_scheduler_args,
-        lr_scheduler_num_cycles=lr_scheduler_num_cycles,
-        lr_scheduler_power=lr_scheduler_power,
-        lr_warmup_steps=lr_warmup_steps,
-        max_bucket_reso=max_bucket_reso,
-        max_data_loader_n_workers=max_data_loader_n_workers,
-        max_resolution=max_resolution,
-        max_timestep=max_timestep,
-        max_token_length=max_token_length,
-        max_train_epochs=max_train_epochs,
-        max_train_steps=max_train_steps,
-        mem_eff_attn=mem_eff_attn,
-        min_bucket_reso=min_bucket_reso,
-        min_snr_gamma=min_snr_gamma,
-        min_timestep=min_timestep,
-        mixed_precision=mixed_precision,
-        multires_noise_discount=multires_noise_discount,
-        multires_noise_iterations=multires_noise_iterations,
-        no_token_padding=no_token_padding,
-        noise_offset=noise_offset,
-        noise_offset_type=noise_offset_type,
-        optimizer=optimizer,
-        optimizer_args=optimizer_args,
-        output_dir=output_dir,
-        output_name=output_name,
-        persistent_data_loader_workers=persistent_data_loader_workers,
-        pretrained_model_name_or_path=pretrained_model_name_or_path,
-        prior_loss_weight=prior_loss_weight,
-        random_crop=random_crop,
-        reg_data_dir=reg_data_dir,
-        resume=resume,
-        save_every_n_epochs=save_every_n_epochs,
-        save_every_n_steps=save_every_n_steps,
-        save_last_n_steps=save_last_n_steps,
-        save_last_n_steps_state=save_last_n_steps_state,
-        save_model_as=save_model_as,
-        save_precision=save_precision,
-        save_state=save_state,
-        scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred,
-        seed=seed,
-        shuffle_caption=shuffle_caption,
-        stop_text_encoder_training=stop_text_encoder_training,
-        train_batch_size=train_batch_size,
-        train_data_dir=train_data_dir,
-        use_wandb=use_wandb,
-        v2=v2,
-        v_parameterization=v_parameterization,
-        v_pred_like_loss=v_pred_like_loss,
-        vae=vae,
-        vae_batch_size=vae_batch_size,
-        wandb_api_key=wandb_api_key,
-        weighted_captions=weighted_captions,
-        xformers=xformers,
-    )
+    # Pass the dynamically constructed keyword arguments to the function
+    run_cmd += run_cmd_advanced_training(**kwargs_for_training)
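The dictionary-based call matters because the removed code passed explicit None values into run_cmd_advanced_training (e.g. learning_rate_te=None on SDXL runs). How that helper treats None is not shown in this diff, so the sketch below is a hypothetical None-skipping flag builder in the same spirit, not the repository's actual implementation:

def build_flags(**kwargs):
    # Hypothetical stand-in for run_cmd_advanced_training: emit a
    # "--key value" pair for every keyword actually provided, skipping
    # None/empty values and rendering booleans as bare flags.
    parts = []
    for key, value in kwargs.items():
        if value is None or value == "":
            continue  # argument does not apply to this run
        if isinstance(value, bool):
            if value:
                parts.append(f" --{key}")
        else:
            parts.append(f' --{key}="{value}"')
    return "".join(parts)

With kwargs_for_training built conditionally, the inapplicable text-encoder arguments are never present at all, so no helper-side None handling is needed: an SDXL run contributes only --learning_rate_te1 and --learning_rate_te2, and a non-SDXL run only --learning_rate_te.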

    run_cmd += run_cmd_sample(
        sample_every_n_steps,
kohya_gui/finetune_gui.py (166 changes: 86 additions & 80 deletions)
@@ -566,86 +566,92 @@ def train_model(
    cache_text_encoder_outputs = sdxl_checkbox and sdxl_cache_text_encoder_outputs
    no_half_vae = sdxl_checkbox and sdxl_no_half_vae

-    run_cmd += run_cmd_advanced_training(
-        adaptive_noise_scale=adaptive_noise_scale,
-        additional_parameters=additional_parameters,
-        block_lr=block_lr,
-        bucket_no_upscale=bucket_no_upscale,
-        bucket_reso_steps=bucket_reso_steps,
-        cache_latents=cache_latents,
-        cache_latents_to_disk=cache_latents_to_disk,
-        cache_text_encoder_outputs=cache_text_encoder_outputs
-        if sdxl_checkbox
-        else None,
-        caption_dropout_every_n_epochs=caption_dropout_every_n_epochs,
-        caption_dropout_rate=caption_dropout_rate,
-        caption_extension=caption_extension,
-        clip_skip=clip_skip,
-        color_aug=color_aug,
-        dataset_repeats=dataset_repeats,
-        enable_bucket=True,
-        flip_aug=flip_aug,
-        full_bf16=full_bf16,
-        full_fp16=full_fp16,
-        gradient_accumulation_steps=gradient_accumulation_steps,
-        gradient_checkpointing=gradient_checkpointing,
-        in_json=in_json,
-        keep_tokens=keep_tokens,
-        learning_rate=learning_rate,
-        learning_rate_te1=learning_rate_te1 if sdxl_checkbox else None,
-        learning_rate_te2=learning_rate_te2 if sdxl_checkbox else None,
-        learning_rate_te=learning_rate_te if not sdxl_checkbox else None,
-        logging_dir=logging_dir,
-        lr_scheduler=lr_scheduler,
-        lr_scheduler_args=lr_scheduler_args,
-        lr_warmup_steps=lr_warmup_steps,
-        max_bucket_reso=max_bucket_reso,
-        max_data_loader_n_workers=max_data_loader_n_workers,
-        max_resolution=max_resolution,
-        max_timestep=max_timestep,
-        max_token_length=max_token_length,
-        max_train_epochs=max_train_epochs,
-        max_train_steps=max_train_steps,
-        mem_eff_attn=mem_eff_attn,
-        min_bucket_reso=min_bucket_reso,
-        min_snr_gamma=min_snr_gamma,
-        min_timestep=min_timestep,
-        mixed_precision=mixed_precision,
-        multires_noise_discount=multires_noise_discount,
-        multires_noise_iterations=multires_noise_iterations,
-        no_half_vae=no_half_vae if sdxl_checkbox else None,
-        noise_offset=noise_offset,
-        noise_offset_type=noise_offset_type,
-        optimizer=optimizer,
-        optimizer_args=optimizer_args,
-        output_dir=output_dir,
-        output_name=output_name,
-        persistent_data_loader_workers=persistent_data_loader_workers,
-        pretrained_model_name_or_path=pretrained_model_name_or_path,
-        random_crop=random_crop,
-        resume=resume,
-        save_every_n_epochs=save_every_n_epochs,
-        save_every_n_steps=save_every_n_steps,
-        save_last_n_steps=save_last_n_steps,
-        save_last_n_steps_state=save_last_n_steps_state,
-        save_model_as=save_model_as,
-        save_precision=save_precision,
-        save_state=save_state,
-        scale_v_pred_loss_like_noise_pred=scale_v_pred_loss_like_noise_pred,
-        seed=seed,
-        shuffle_caption=shuffle_caption,
-        train_batch_size=train_batch_size,
-        train_data_dir=image_folder,
-        train_text_encoder=train_text_encoder,
-        use_wandb=use_wandb,
-        v2=v2,
-        v_parameterization=v_parameterization,
-        v_pred_like_loss=v_pred_like_loss,
-        vae_batch_size=vae_batch_size,
-        wandb_api_key=wandb_api_key,
-        weighted_captions=weighted_captions,
-        xformers=xformers,
-    )
+    # Initialize a dictionary with always-included keyword arguments
+    kwargs_for_training = {
+        "adaptive_noise_scale": adaptive_noise_scale,
+        "additional_parameters": additional_parameters,
+        "block_lr": block_lr,
+        "bucket_no_upscale": bucket_no_upscale,
+        "bucket_reso_steps": bucket_reso_steps,
+        "cache_latents": cache_latents,
+        "cache_latents_to_disk": cache_latents_to_disk,
+        "caption_dropout_every_n_epochs": caption_dropout_every_n_epochs,
+        "caption_dropout_rate": caption_dropout_rate,
+        "caption_extension": caption_extension,
+        "clip_skip": clip_skip,
+        "color_aug": color_aug,
+        "dataset_repeats": dataset_repeats,
+        "enable_bucket": True,
+        "flip_aug": flip_aug,
+        "full_bf16": full_bf16,
+        "full_fp16": full_fp16,
+        "gradient_accumulation_steps": gradient_accumulation_steps,
+        "gradient_checkpointing": gradient_checkpointing,
+        "in_json": in_json,
+        "keep_tokens": keep_tokens,
+        "learning_rate": learning_rate,
+        "logging_dir": logging_dir,
+        "lr_scheduler": lr_scheduler,
+        "lr_scheduler_args": lr_scheduler_args,
+        "lr_warmup_steps": lr_warmup_steps,
+        "max_bucket_reso": max_bucket_reso,
+        "max_data_loader_n_workers": max_data_loader_n_workers,
+        "max_resolution": max_resolution,
+        "max_timestep": max_timestep,
+        "max_token_length": max_token_length,
+        "max_train_epochs": max_train_epochs,
+        "max_train_steps": max_train_steps,
+        "mem_eff_attn": mem_eff_attn,
+        "min_bucket_reso": min_bucket_reso,
+        "min_snr_gamma": min_snr_gamma,
+        "min_timestep": min_timestep,
+        "mixed_precision": mixed_precision,
+        "multires_noise_discount": multires_noise_discount,
+        "multires_noise_iterations": multires_noise_iterations,
+        "noise_offset": noise_offset,
+        "noise_offset_type": noise_offset_type,
+        "optimizer": optimizer,
+        "optimizer_args": optimizer_args,
+        "output_dir": output_dir,
+        "output_name": output_name,
+        "persistent_data_loader_workers": persistent_data_loader_workers,
+        "pretrained_model_name_or_path": pretrained_model_name_or_path,
+        "random_crop": random_crop,
+        "resume": resume,
+        "save_every_n_epochs": save_every_n_epochs,
+        "save_every_n_steps": save_every_n_steps,
+        "save_last_n_steps": save_last_n_steps,
+        "save_last_n_steps_state": save_last_n_steps_state,
+        "save_model_as": save_model_as,
+        "save_precision": save_precision,
+        "save_state": save_state,
+        "scale_v_pred_loss_like_noise_pred": scale_v_pred_loss_like_noise_pred,
+        "seed": seed,
+        "shuffle_caption": shuffle_caption,
+        "train_batch_size": train_batch_size,
+        "train_data_dir": image_folder,
+        "train_text_encoder": train_text_encoder,
+        "use_wandb": use_wandb,
+        "v2": v2,
+        "v_parameterization": v_parameterization,
+        "v_pred_like_loss": v_pred_like_loss,
+        "vae_batch_size": vae_batch_size,
+        "wandb_api_key": wandb_api_key,
+        "weighted_captions": weighted_captions,
+        "xformers": xformers,
+    }
+
+    # Conditionally include specific keyword arguments based on sdxl_checkbox
+    if sdxl_checkbox:
+        kwargs_for_training["cache_text_encoder_outputs"] = cache_text_encoder_outputs
+        kwargs_for_training["learning_rate_te1"] = learning_rate_te1
+        kwargs_for_training["learning_rate_te2"] = learning_rate_te2
+        kwargs_for_training["no_half_vae"] = no_half_vae
+    else:
+        kwargs_for_training["learning_rate_te"] = learning_rate_te
+
+    # Pass the dynamically constructed keyword arguments to the function
+    run_cmd += run_cmd_advanced_training(**kwargs_for_training)
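The finetune module applies the same pattern, with two extra SDXL-only keys (cache_text_encoder_outputs and no_half_vae). A quick way to sanity-check the mutually exclusive text-encoder learning-rate keys (a hypothetical snippet, not part of the commit; SDXL has two text encoders and so takes te1/te2 rates, while SD 1.x/2.x takes a single te rate):

def te_lr_kwargs(sdxl_checkbox, learning_rate_te=1e-5,
                 learning_rate_te1=1e-5, learning_rate_te2=1e-5):
    # Mirrors the branch above: only the applicable keys are ever present.
    kwargs = {}
    if sdxl_checkbox:
        kwargs["learning_rate_te1"] = learning_rate_te1
        kwargs["learning_rate_te2"] = learning_rate_te2
    else:
        kwargs["learning_rate_te"] = learning_rate_te
    return kwargs

assert "learning_rate_te" not in te_lr_kwargs(True)
assert "learning_rate_te1" not in te_lr_kwargs(False)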

    run_cmd += run_cmd_sample(
        sample_every_n_steps,
test/config/dreambooth-Prodigy-SDXL.json (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@
"adaptive_noise_scale": 0,
"additional_parameters": "",
"bucket_no_upscale": true,
"bucket_reso_steps": 1,
"bucket_reso_steps": 32,
"cache_latents": true,
"cache_latents_to_disk": false,
"caption_dropout_every_n_epochs": 0.0,
