Skip to content

Commit

Permalink
unify huber scheduling
Browse files Browse the repository at this point in the history
  • Loading branch information
kabachuha committed Apr 1, 2024
1 parent dd22958 commit 19a834c
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3091,14 +3091,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--loss_type",
type=str,
default="l2",
choices=["l2", "huber", "huber_scheduled", "smooth_l1", "smooth_l1_scheduled"],
choices=["l2", "huber", "smooth_l1"],
help="The type of loss to use and whether it's scheduled based on the timestep"
)
parser.add_argument(
"--huber_schedule",
type=str,
default="exponential",
choices=["constant", "exponential", "snr"], #TODO: add snr
help="The type of loss to use and whether it's scheduled based on the timestep"
)
parser.add_argument(
"--huber_c",
type=float,
default=0.1,
help="The huber loss parameter. Only used if one of the huber loss modes is selected with loss_type.",
help="The huber loss parameter. Only used if one of the huber loss modes (huber or smooth l1) is selected with loss_type.",
)

parser.add_argument(
Expand Down Expand Up @@ -4608,22 +4615,26 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest

#TODO: if a huber loss is selected, it will use constant timesteps for each batch
# as. In the future there may be a smarter way
if args.loss_type == 'huber_scheduled' or args.loss_type == 'smooth_l1_scheduled': #NOTE: Will unify scheduled and vanilla soon

if args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
timesteps = torch.randint(
min_timestep, max_timestep, (1,), device='cpu'
)
timestep = timesteps.item()

alpha = - math.log(args.huber_c) / num_train_timesteps
huber_c = math.exp(-alpha * timestep)
timesteps = timesteps.repeat(b_size).to(device)
elif args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
# for fairness in comparison
timesteps = torch.randint(
min_timestep, max_timestep, (1,), device='cpu'
)
if args.huber_schedule == "exponential":
alpha = - math.log(args.huber_c) / num_train_timesteps
huber_c = math.exp(-alpha * timestep)
elif args.huber_schedule == "snr":
# TODO
huber_c = args.huber_c # Placeholder
pass
elif args.huber_schedule == "constant":
huber_c = args.huber_c
else:
raise NotImplementedError(f'Unknown Huber loss schedule {args.huber_schedule}!')

timesteps = timesteps.repeat(b_size).to(device)
huber_c = args.huber_c
elif args.loss_type == 'l2':
timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
huber_c = 1 # may be anything, as it's not used
Expand Down Expand Up @@ -4664,13 +4675,13 @@ def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str

if loss_type == 'l2':
loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
elif loss_type == 'huber' or loss_type == 'huber_scheduled':
elif loss_type == 'huber':
loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == 'smooth_l1' or loss_type == 'smooth_l1_scheduled': # NOTE: Will unify in the next commits
elif loss_type == 'smooth_l1':
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
Expand Down

0 comments on commit 19a834c

Please sign in to comment.