Skip to content

Commit

Permalink
PHL-schedule should depend on noise scheduler's num timesteps
Browse files Browse the repository at this point in the history
  • Loading branch information
kabachuha committed Mar 31, 2024
1 parent 8abc112 commit a58f290
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
6 changes: 3 additions & 3 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4604,7 +4604,7 @@ def save_sd_model_on_train_end_common(
if args.huggingface_repo_id is not None:
huggingface_util.upload(args, out_dir, "/" + model_name, force_sync_upload=True)

def get_timesteps_and_huber_c(args, min_timestep, max_timestep, b_size, device):
def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timesteps, b_size, device):

#TODO: if a huber loss is selected, it will use constant timesteps for each batch
# as. In the future there may be a smarter way
Expand All @@ -4614,7 +4614,7 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, b_size, device):
)
timestep = timesteps.item()

alpha = - math.log(args.huber_c) / max_timestep
alpha = - math.log(args.huber_c) / num_train_timesteps
huber_c = math.exp(-alpha * timestep)
timesteps = timesteps.repeat(b_size).to(device)
elif args.loss_type == 'huber':
Expand Down Expand Up @@ -4648,7 +4648,7 @@ def get_noise_noisy_latents_and_timesteps(args, noise_scheduler, latents):
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = noise_scheduler.config.num_train_timesteps if args.max_timestep is None else args.max_timestep

timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, b_size, latents.device)
timesteps, huber_c = get_timesteps_and_huber_c(args, min_timestep, max_timestep, noise_scheduler.config.num_train_timesteps, b_size, latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
Expand Down
2 changes: 1 addition & 1 deletion train_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,7 +419,7 @@ def remove_model(old_ckpt_name):
)

# Sample a random timestep for each image
timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, b_size, latents.device)
timesteps, huber_c = train_util.get_timesteps_and_huber_c(args, 0, noise_scheduler.config.num_train_timesteps, noise_scheduler.config.num_train_timesteps, b_size, latents.device)

# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
Expand Down

0 comments on commit a58f290

Please sign in to comment.