Skip to content

Commit

Permalink
Remove the skew and kurtosis loss objectives; they hurt more than the…
Browse files Browse the repository at this point in the history
…y help
  • Loading branch information
cheald committed Apr 24, 2024
1 parent 6bf5dae commit 357aa44
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 27 deletions.
19 changes: 0 additions & 19 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3369,25 +3369,6 @@ def add_training_arguments(parser: argparse.ArgumentParser, support_dreambooth:
default=None,
help="Weight for standard deviation loss. Encourages the model to learn noise with a stddev like the true noise. May prevent 'deep fry'. 1.0 is a good starting place.",
)
parser.add_argument(
"--kurtosis_loss_weight",
type=float,
default=None,
help="Weight for kurtosis loss. Encourages the model to learn noise with a kurtosis like the true noise. Recommended if using std_loss_weight.",
)
parser.add_argument(
"--skew_loss_weight",
type=float,
default=None,
help="Weight for skew loss. Encourages the model to learn noise with a skew like the true noise. Recommended if using std_loss_weight.",
)
parser.add_argument(
"--latent_corruption",
type=float,
default=None,
help="latent corruption for training (default is None) / 学習時のlatent corruption(デフォルトはNone)",
)


if support_dreambooth:
# DreamBooth training
Expand Down
8 changes: 0 additions & 8 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -913,14 +913,6 @@ def remove_model(old_ckpt_name):
std_loss = F.mse_loss(pred_std, true_std, reduction="none")
loss = loss + std_loss * args.std_loss_weight

if args.skew_loss_weight is not None:
skew_loss = F.mse_loss(pred_skews, true_skews, reduction="none")
loss = loss + skew_loss * args.skew_loss_weight

if args.kurtosis_loss_weight is not None:
kurtosis_loss = F.mse_loss(pred_kurtoses, true_kurtoses, reduction="none")
loss = loss + kurtosis_loss * args.kurtosis_loss_weight

# print(kl_loss.dtype, pred_std.dtype, noise_pred.dtype, true_std.dtype, pred_skews.dtype, true_skews.dtype, pred_kurtoses.dtype, true_kurtoses.dtype)
# step_logs["loss/kl_loss"] = kl_loss.mean().item()
step_logs["metrics/noise_pred_std"] = pred_std.mean().item()
Expand Down

0 comments on commit 357aa44

Please sign in to comment.