diff --git a/fine_tune.py b/fine_tune.py index 0efae56af..898e53cb3 100644 --- a/fine_tune.py +++ b/fine_tune.py @@ -347,7 +347,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): else: target = noise - if args.masked_loss and batch['masks'] is not None: + if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None: mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device) noise_pred = noise_pred * mask target = target * mask diff --git a/library/train_util.py b/library/train_util.py index b2570c707..8d5a2f8ee 100644 --- a/library/train_util.py +++ b/library/train_util.py @@ -1131,8 +1131,8 @@ def __getitem__(self, index): image = None elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合 - latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz) - mask = load_mask(image_info.absolute_path, image_info.resized_size) / 255 + latents, original_size, crop_ltrb, flipped_latents, mask = load_latents_from_disk(image_info.latents_npz) + mask = mask / 255 if flipped: latents = flipped_latents mask = np.flip(mask, axis=1) @@ -2001,7 +2001,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool): # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top) def load_latents_from_disk( npz_path, -) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]: +) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[np.ndarray]]: npz = np.load(npz_path) if "latents" not in npz: raise ValueError(f"error: npz is old format. please re-generate {npz_path}") @@ -2010,14 +2010,19 @@ def load_latents_from_disk( original_size = npz["original_size"].tolist() crop_ltrb = npz["crop_ltrb"].tolist() flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None - return latents, original_size, crop_ltrb, flipped_latents + mask = npz["mask"] if "mask" in npz else None + if mask is not None: + mask = mask.astype(np.float32) + return latents, original_size, crop_ltrb, flipped_latents, mask -def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None): +def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, mask=None): kwargs = {} if flipped_latents_tensor is not None: kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy() - np.savez( + if mask is not None: + kwargs["mask"] = np.array(mask, dtype=np.uint8) + np.savez_compressed( npz_path, latents=latents_tensor.float().cpu().numpy(), original_size=np.array(original_size), @@ -2322,7 +2327,7 @@ def cache_batch_latents( raise RuntimeError(f"NaN detected in latents: {info.absolute_path}") if cache_to_disk: - save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent) + save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, mask) else: info.latents = latent info.mask = mask diff --git a/sdxl_train.py b/sdxl_train.py index df05a2d9d..283c59372 100644 --- a/sdxl_train.py +++ b/sdxl_train.py @@ -562,7 +562,7 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module): target = noise - if args.masked_loss and batch['masks'] is not None: + if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None: mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device) noise_pred = noise_pred * mask target = target * mask diff --git a/train_db.py b/train_db.py index 088b795ac..973ca1652 100644 --- a/train_db.py +++ b/train_db.py @@ -334,7 +334,7 @@ def train(args): else: target = noise - if args.masked_loss and batch['masks'] is not None: + if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None: mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device) noise_pred = noise_pred * mask target = target * mask diff --git a/train_network.py b/train_network.py index cddd3554e..d333c220f 100644 --- a/train_network.py +++ b/train_network.py @@ -825,7 +825,7 @@ def remove_model(old_ckpt_name): else: target = noise - if args.masked_loss and batch['masks'] is not None: + if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None: mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device) noise_pred = noise_pred * mask target = target * mask diff --git a/train_textual_inversion.py b/train_textual_inversion.py index f1eb3299c..d1e9460c8 100644 --- a/train_textual_inversion.py +++ b/train_textual_inversion.py @@ -583,7 +583,7 @@ def remove_model(old_ckpt_name): else: target = noise - if args.masked_loss and batch['masks'] is not None: + if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None: mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device) noise_pred = noise_pred * mask target = target * mask diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py index 1b999c7ba..76dbed9a8 100644 --- a/train_textual_inversion_XTI.py +++ b/train_textual_inversion_XTI.py @@ -460,7 +460,7 @@ def remove_model(old_ckpt_name): else: target = noise - if args.masked_loss and batch['masks'] is not None: + if (args.masked_loss or args.auto_masked_loss) and batch['masks'] is not None: mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device) noise_pred = noise_pred * mask target = target * mask