From 19a834c3ab448614e8887b07f2bb4e0aaabf0805 Mon Sep 17 00:00:00 2001
From: kabachuha
Date: Mon, 1 Apr 2024 21:46:28 +0300
Subject: [PATCH] unify huber scheduling

---
 library/train_util.py | 39 +++++++++++++++++++++++++--------------
 1 file changed, 25 insertions(+), 14 deletions(-)

diff --git a/library/train_util.py b/library/train_util.py
index 95972e8fc..44a45f9c2 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -3091,14 +3091,21 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
         "--loss_type",
         type=str,
         default="l2",
-        choices=["l2", "huber", "huber_scheduled", "smooth_l1", "smooth_l1_scheduled"],
+        choices=["l2", "huber", "smooth_l1"],
+        help="The type of loss function to use",
+    )
+    parser.add_argument(
+        "--huber_schedule",
+        type=str,
+        default="exponential",
+        choices=["constant", "exponential", "snr"],  # TODO: implement the snr schedule
         help="The type of loss to use and whether it's scheduled based on the timestep"
     )
     parser.add_argument(
         "--huber_c",
         type=float,
         default=0.1,
-        help="The huber loss parameter. Only used if one of the huber loss modes is selected with loss_type.",
+        help="The Huber loss parameter. Only used if one of the Huber loss modes (huber or smooth_l1) is selected with loss_type.",
     )
 
     parser.add_argument(
@@ -4608,22 +4615,26 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest
     # TODO: if a huber loss is selected, it will use a constant timestep for the whole batch.
     # In the future there may be a smarter way
 
-    if args.loss_type == 'huber_scheduled' or args.loss_type == 'smooth_l1_scheduled': #NOTE: Will unify scheduled and vanilla soon
+
+    if args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
         timesteps = torch.randint(
             min_timestep, max_timestep, (1,), device='cpu'
         )
         timestep = timesteps.item()
 
-        alpha = - math.log(args.huber_c) / num_train_timesteps
-        huber_c = math.exp(-alpha * timestep)
-        timesteps = timesteps.repeat(b_size).to(device)
-    elif args.loss_type == 'huber' or args.loss_type == 'smooth_l1':
-        # for fairness in comparison
-        timesteps = torch.randint(
-            min_timestep, max_timestep, (1,), device='cpu'
-        )
+        if args.huber_schedule == "exponential":
+            alpha = -math.log(args.huber_c) / num_train_timesteps
+            huber_c = math.exp(-alpha * timestep)
+        elif args.huber_schedule == "snr":
+            # TODO: implement an SNR-based schedule
+            huber_c = args.huber_c  # placeholder until then
+            pass
+        elif args.huber_schedule == "constant":
+            huber_c = args.huber_c
+        else:
+            raise NotImplementedError(f'Unknown Huber loss schedule {args.huber_schedule}!')
+
         timesteps = timesteps.repeat(b_size).to(device)
-        huber_c = args.huber_c
     elif args.loss_type == 'l2':
         timesteps = torch.randint(min_timestep, max_timestep, (b_size,), device=device)
         huber_c = 1 # may be anything, as it's not used
@@ -4664,13 +4675,13 @@ def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str
 
     if loss_type == 'l2':
         loss = torch.nn.functional.mse_loss(model_pred, target, reduction=reduction)
-    elif loss_type == 'huber' or loss_type == 'huber_scheduled':
+    elif loss_type == 'huber':
         loss = 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
         elif reduction == "sum":
             loss = torch.sum(loss)
-    elif loss_type == 'smooth_l1' or loss_type == 'smooth_l1_scheduled': # NOTE: Will unify in the next commits
+    elif loss_type == 'smooth_l1':
         loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)
         if reduction == "mean":
             loss = torch.mean(loss)
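
Note (illustration, not part of the patch): taken together, the hunks implement a timestep-scheduled pseudo-Huber loss. The standalone sketch below mirrors the patch's formulas; the helper names scheduled_huber_c and pseudo_huber_loss and the demo values are assumptions for illustration, not code from the repository.

import math

import torch


def scheduled_huber_c(huber_c, timestep, num_train_timesteps, huber_schedule="exponential"):
    # Mirrors get_timesteps_and_huber_c: resolve the Huber parameter for one timestep.
    if huber_schedule == "exponential":
        # Decays from 1 at timestep 0 to huber_c at num_train_timesteps, so the
        # loss is more outlier-robust (closer to L1) at noisier timesteps.
        alpha = -math.log(huber_c) / num_train_timesteps
        return math.exp(-alpha * timestep)
    elif huber_schedule == "constant":
        return huber_c
    else:
        # "snr" is still a TODO in the patch and falls back to the constant value
        return huber_c


def pseudo_huber_loss(model_pred, target, huber_c):
    # Matches the 'huber' branch of conditional_loss: roughly quadratic for
    # residuals well below huber_c, linear (slope 2 * huber_c) beyond it.
    return 2 * huber_c * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c)


pred, tgt = torch.randn(4, 8), torch.randn(4, 8)
c = scheduled_huber_c(huber_c=0.1, timestep=500, num_train_timesteps=1000)
print(pseudo_huber_loss(pred, tgt, c).mean())

With huber_c=0.1 and 1000 training timesteps, the exponential schedule gives exp(ln(0.1) * 500/1000) = sqrt(0.1) ≈ 0.316 at the midpoint, interpolating smoothly between L2-like behavior at timestep 0 and the most robust setting at the final timestep.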