Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auto mask uniform background based on PR #589 mask loss #1114

Closed
wants to merge 16 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
get_latent_masks
)


Expand Down Expand Up @@ -346,6 +347,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
target = noise

if (args.masked_loss or args.mask_simple_background) and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
Expand Down
3 changes: 3 additions & 0 deletions library/config_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class BaseSubsetParams:
flip_aug: bool = False
face_crop_aug_range: Optional[Tuple[float, float]] = None
random_crop: bool = False
mask_simple_background: bool = False
caption_prefix: Optional[str] = None
caption_suffix: Optional[str] = None
caption_dropout_rate: float = 0.0
Expand Down Expand Up @@ -175,6 +176,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
"flip_aug": bool,
"num_repeats": int,
"random_crop": bool,
"mask_simple_background": bool,
"shuffle_caption": bool,
"keep_tokens": int,
"keep_tokens_separator": str,
Expand Down Expand Up @@ -510,6 +512,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
flip_aug: {subset.flip_aug}
face_crop_aug_range: {subset.face_crop_aug_range}
random_crop: {subset.random_crop}
mask_simple_background: {subset.mask_simple_background}
token_warmup_min: {subset.token_warmup_min},
token_warmup_step: {subset.token_warmup_step},
"""
Expand Down
21 changes: 21 additions & 0 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,27 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise


def get_latent_masks(image_masks, latent_shape, device):
    """Convert per-image pixel masks into latent-resolution loss masks.

    Args:
        image_masks: tensor of shape (batch, H, W); 1.0 keeps a pixel's
            loss contribution, 0.0 masks it out.
        latent_shape: target latent shape (batch, channels, latent_H, latent_W).
        device: device the returned masks should live on.

    Returns:
        Float tensor of shape (batch, 1, latent_H, latent_W); broadcasts
        over the latent channel dimension.
    """
    # Masked (zero) regions lower the average loss; dividing the kept
    # pixels by sqrt(mean(mask)) counteracts that so the expected loss
    # magnitude stays comparable between masked and unmasked samples.
    factor = torch.sqrt(image_masks.mean(dim=[1, 2]))
    # Guard against an all-zero mask (would divide by zero).
    factor = torch.where(factor != 0.0, factor, 1.0)
    image_masks = image_masks / factor.reshape(-1, 1, 1)

    # Add a channel axis instead of reshaping through a hard-coded x8
    # pixel size: identical result for SD's 8x VAE downscale, but also
    # correct for any other mask/latent size ratio.
    masks = image_masks.to(device).unsqueeze(1)
    # Nearest-neighbour resize down to the latent grid.
    masks = torch.nn.functional.interpolate(
        masks.float(),
        size=latent_shape[-2:],
        mode="nearest",
    )
    return masks


"""
##########################################
# Perlin Noise
Expand Down
Loading