Skip to content

Commit

Permalink
add option for smooth l1 (huber / delta)
Browse files Browse the repository at this point in the history
  • Loading branch information
kabachuha committed Apr 1, 2024
1 parent c6495de commit dd22958
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3091,7 +3091,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
"--loss_type",
type=str,
default="l2",
choices=["l2", "huber", "huber_scheduled"],
choices=["l2", "huber", "huber_scheduled", "smooth_l1", "smooth_l1_scheduled"],
help="The type of loss to use and whether it's scheduled based on the timestep"
)
parser.add_argument(
Expand Down Expand Up @@ -4608,7 +4608,7 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest

#TODO: if a huber loss is selected, a single timestep is sampled and reused for every sample in the batch.
# In the future there may be a smarter way
if args.loss_type == 'huber_scheduled':
if args.loss_type == 'huber_scheduled' or args.loss_type == 'smooth_l1_scheduled': #NOTE: Will unify scheduled and vanilla soon
timesteps = torch.randint(
min_timestep, max_timestep, (1,), device='cpu'
)
Expand All @@ -4617,7 +4617,7 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest
alpha = - math.log(args.huber_c) / num_train_timesteps
huber_c = math.exp(-alpha * timestep)
timesteps = timesteps.repeat(b_size).to(device)
elif args.loss_type == 'huber':
elif args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
# sample a single timestep here too (as the scheduled variant does), for fairness in comparison
timesteps = torch.randint(
min_timestep, max_timestep, (1,), device='cpu'
Expand Down Expand Up @@ -4670,6 +4670,12 @@ def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
elif loss_type == 'smooth_l1' or loss_type == 'smooth_l1_scheduled': # NOTE: Will unify in the next commits
loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
if reduction == "mean":
loss = torch.mean(loss)
elif reduction == "sum":
loss = torch.sum(loss)
else:
raise NotImplementedError(f'Unsupported Loss Type {loss_type}')
return loss
Expand Down

0 comments on commit dd22958

Please sign in to comment.