From dd22958caa56e4db885324f76188c13bdf504569 Mon Sep 17 00:00:00 2001 From: kabachuha Date: Mon, 1 Apr 2024 21:36:31 +0300 Subject: [PATCH] add option for smooth l1 (huber / delta) --- library/train_util.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/library/train_util.py b/library/train_util.py index c380e3311..95972e8fc 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -3091,7 +3091,7 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth: "--loss_type", type=str, default="l2", - choices=["l2", "huber", "huber_scheduled"], + choices=["l2", "huber", "huber_scheduled", "smooth_l1", "smooth_l1_scheduled"], help="The type of loss to use and whether it's scheduled based on the timestep" ) parser.add_argument( @@ -4608,7 +4608,7 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest #TODO: if a huber loss is selected, it will use constant timesteps for each batch # as. In the future there may be a smarter way - if args.loss_type == 'huber_scheduled': + if args.loss_type == 'huber_scheduled' or args.loss_type == 'smooth_l1_scheduled': #NOTE: Will unify scheduled and vanilla soon timesteps = torch.randint( min_timestep, max_timestep, (1,), device='cpu' ) @@ -4617,7 +4617,7 @@ def get_timesteps_and_huber_c(args, min_timestep, max_timestep, num_train_timest alpha = - math.log(args.huber_c) / num_train_timesteps huber_c = math.exp(-alpha * timestep) timesteps = timesteps.repeat(b_size).to(device) - elif args.loss_type == 'huber': + elif args.loss_type == 'huber' or args.loss_type == 'smooth_l1': # for fairness in comparison timesteps = torch.randint( min_timestep, max_timestep, (1,), device='cpu' @@ -4670,6 +4670,12 @@ def conditional_loss(model_pred:torch.Tensor, target:torch.Tensor, reduction:str loss = torch.mean(loss) elif reduction == "sum": loss = torch.sum(loss) + elif loss_type == 'smooth_l1' or loss_type == 'smooth_l1_scheduled': # NOTE: Will unify in the next commits + loss = 2 * (torch.sqrt((model_pred - target) ** 2 + huber_c**2) - huber_c) + if reduction == "mean": + loss = torch.mean(loss) + elif reduction == "sum": + loss = torch.sum(loss) else: raise NotImplementedError(f'Unsupported Loss Type {loss_type}') return loss