diff --git a/fine_tune.py b/fine_tune.py
index 982dc8aec..b0fe23d5d 100644
--- a/fine_tune.py
+++ b/fine_tune.py
@@ -31,6 +31,7 @@
     prepare_scheduler_for_custom_training,
     scale_v_prediction_loss_like_noise_prediction,
     apply_debiased_estimation,
+    get_latent_masks,
 )
 
 
@@ -346,6 +347,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
                 else:
                     target = noise
 
+                if (args.masked_loss or args.mask_simple_background) and batch["masks"] is not None:
+                    mask = get_latent_masks(batch["masks"], noise_pred.shape, noise_pred.device)
+                    noise_pred = noise_pred * mask
+                    target = target * mask
+
                 if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
                     # do not mean over batch dimension for snr weight or scale v-pred loss
                     loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
diff --git a/library/config_util.py b/library/config_util.py
index a98c2b90d..cabfa2668 100644
--- a/library/config_util.py
+++ b/library/config_util.py
@@ -61,6 +61,7 @@ class BaseSubsetParams:
     flip_aug: bool = False
     face_crop_aug_range: Optional[Tuple[float, float]] = None
     random_crop: bool = False
+    mask_simple_background: bool = False
     caption_prefix: Optional[str] = None
     caption_suffix: Optional[str] = None
     caption_dropout_rate: float = 0.0
@@ -175,6 +176,7 @@ def __validate_and_convert_scalar_or_twodim(klass, value: Union[float, Sequence]
         "flip_aug": bool,
         "num_repeats": int,
         "random_crop": bool,
+        "mask_simple_background": bool,
         "shuffle_caption": bool,
         "keep_tokens": int,
         "keep_tokens_separator": str,
@@ -510,6 +512,7 @@ def generate_dataset_group_by_blueprint(dataset_group_blueprint: DatasetGroupBlu
         flip_aug: {subset.flip_aug}
         face_crop_aug_range: {subset.face_crop_aug_range}
         random_crop: {subset.random_crop}
+        mask_simple_background: {subset.mask_simple_background}
         token_warmup_min: {subset.token_warmup_min},
         token_warmup_step: {subset.token_warmup_step},
       """
diff --git a/library/custom_train_functions.py b/library/custom_train_functions.py
index e0a026dae..7df22dd40 100644
--- a/library/custom_train_functions.py
+++ b/library/custom_train_functions.py
@@ -471,6 +471,27 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
     return noise
 
 
+def get_latent_masks(image_masks, latent_shape, device):
+    # masking lowers the average loss, so rescale each mask by the square root of its mean:
+    # the squared (binary) mask then averages to 1 and the loss magnitude is preserved
+    factor = torch.sqrt(image_masks.mean([1, 2]))
+    factor = torch.where(factor != 0.0, factor, 1.0)  # avoid division by zero for all-zero masks
+    factor = factor.reshape(factor.shape + (1,) * 2)
+    image_masks = image_masks / factor
+
+    masks = (
+        image_masks
+        .to(device)
+        .reshape(latent_shape[0], 1, latent_shape[2] * 8, latent_shape[3] * 8)
+    )
+    # downscale to match the latent resolution
+    masks = torch.nn.functional.interpolate(masks.float(), size=latent_shape[-2:], mode="nearest")
+    return masks
+
+
 """
 ##########################################
 # Perlin Noise
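Review note: the rescaling in `get_latent_masks` divides each mask by the square root of its mean so that, for a binary mask, the squared mask averages to exactly 1 and masking does not shrink the mean MSE. A minimal standalone check of that property (not part of the patch):

```python
import torch

# For a binary mask m, (m / sqrt(m.mean()))**2 averages to exactly 1,
# so the masked MSE keeps the same expected magnitude as the unmasked one.
mask = (torch.rand(1, 64, 64) > 0.7).float()   # keep roughly 30% of the pixels
factor = torch.sqrt(mask.mean([1, 2]))         # per-image scale, as in get_latent_masks
scaled = mask / factor.reshape(-1, 1, 1)
print(scaled.pow(2).mean().item())             # ~= 1.0
```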
diff --git a/library/train_util.py b/library/train_util.py
index ba428e508..99af0863c 100644
--- a/library/train_util.py
+++ b/library/train_util.py
@@ -145,6 +145,9 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
         self.text_encoder_outputs1: Optional[torch.Tensor] = None
         self.text_encoder_outputs2: Optional[torch.Tensor] = None
         self.text_encoder_pool2: Optional[torch.Tensor] = None
+        # masked loss
+        self.mask: Optional[np.ndarray] = None
+        self.mask_flipped: Optional[np.ndarray] = None
 
 
 class BucketManager:
@@ -322,21 +325,24 @@ def color_aug(self, image: np.ndarray):
         # )
         hue_shift_limit = 8
 
+        rgb_channels = image[:, :, :3]
+        alpha_channel = image[:, :, -1]
         # remove dependency to albumentations
         if random.random() <= 0.33:
             if random.random() > 0.5:
                 # hue shift
-                hsv_img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+                hsv_img = cv2.cvtColor(rgb_channels, cv2.COLOR_BGR2HSV)
                 hue_shift = random.uniform(-hue_shift_limit, hue_shift_limit)
                 if hue_shift < 0:
                     hue_shift = 180 + hue_shift
                 hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180
-                image = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
+                rgb_channels = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2BGR)
             else:
                 # random gamma
                 gamma = random.uniform(0.95, 1.05)
-                image = np.clip(image**gamma, 0, 255).astype(np.uint8)
+                rgb_channels = np.clip(rgb_channels**gamma, 0, 255).astype(np.uint8)
 
+        image = np.dstack((rgb_channels, alpha_channel))
         return {"image": image}
 
     def get_augmentor(self, use_color_aug: bool):  # -> Optional[Callable[[np.ndarray], Dict[str, np.ndarray]]]:
@@ -356,6 +362,7 @@ def __init__(
         flip_aug: bool,
         face_crop_aug_range: Optional[Tuple[float, float]],
         random_crop: bool,
+        mask_simple_background: bool,
         caption_dropout_rate: float,
         caption_dropout_every_n_epochs: int,
         caption_tag_dropout_rate: float,
@@ -374,6 +381,7 @@ def __init__(
         self.flip_aug = flip_aug
         self.face_crop_aug_range = face_crop_aug_range
         self.random_crop = random_crop
+        self.mask_simple_background = mask_simple_background
         self.caption_dropout_rate = caption_dropout_rate
         self.caption_dropout_every_n_epochs = caption_dropout_every_n_epochs
         self.caption_tag_dropout_rate = caption_tag_dropout_rate
@@ -402,6 +410,7 @@ def __init__(
         flip_aug,
         face_crop_aug_range,
         random_crop,
+        mask_simple_background,
         caption_dropout_rate,
         caption_dropout_every_n_epochs,
         caption_tag_dropout_rate,
@@ -423,6 +432,7 @@ def __init__(
             flip_aug,
             face_crop_aug_range,
             random_crop,
+            mask_simple_background,
             caption_dropout_rate,
             caption_dropout_every_n_epochs,
             caption_tag_dropout_rate,
@@ -458,6 +468,7 @@ def __init__(
         flip_aug,
         face_crop_aug_range,
         random_crop,
+        mask_simple_background,
        caption_dropout_rate,
        caption_dropout_every_n_epochs,
        caption_tag_dropout_rate,
@@ -479,6 +490,7 @@ def __init__(
             flip_aug,
             face_crop_aug_range,
             random_crop,
+            mask_simple_background,
             caption_dropout_rate,
             caption_dropout_every_n_epochs,
             caption_tag_dropout_rate,
@@ -511,6 +523,7 @@ def __init__(
         flip_aug,
         face_crop_aug_range,
         random_crop,
+        mask_simple_background,
         caption_dropout_rate,
         caption_dropout_every_n_epochs,
         caption_tag_dropout_rate,
@@ -532,6 +545,7 @@ def __init__(
             flip_aug,
             face_crop_aug_range,
             random_crop,
+            mask_simple_background,
             caption_dropout_rate,
             caption_dropout_every_n_epochs,
             caption_tag_dropout_rate,
@@ -949,7 +963,7 @@ def cache_latents(self, vae, vae_batch_size=1, cache_to_disk=False, is_main_proc
         # iterate batches: batch doesn't have image, image will be loaded in cache_batch_latents and discarded
         print("caching latents...")
         for batch in tqdm(batches, smoothing=1, total=len(batches)):
-            cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop)
+            cache_batch_latents(vae, cache_to_disk, batch, subset.flip_aug, subset.random_crop, subset.mask_simple_background)
 
     # weight_dtypeを指定するとText Encoderそのもの、およひ出力がweight_dtypeになる
     # SDXLでのみ有効だが、datasetのメソッドとする必要があるので、sdxl_train_util.pyではなくこちらに実装する
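Review note: a quick sanity check (not part of the patch) that the reworked `color_aug` path augments only the RGB channels and re-stacks the original alpha, so the loss mask survives augmentation unchanged:

```python
import numpy as np

# Build a dummy RGBA image, apply a gamma tweak to RGB only, restack with
# the original alpha, and confirm the mask channel is untouched.
rgba = np.random.randint(0, 256, (8, 8, 4), np.uint8)
rgb_channels, alpha_channel = rgba[:, :, :3], rgba[:, :, -1]
rgb_channels = np.clip(rgb_channels**1.02, 0, 255).astype(np.uint8)  # "random gamma"
out = np.dstack((rgb_channels, alpha_channel))
assert (out[:, :, -1] == rgba[:, :, -1]).all()
```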
@@ -1097,6 +1111,7 @@ def __getitem__(self, index):
         input_ids2_list = []
         latents_list = []
         images = []
+        masks = []
         original_sizes_hw = []
         crop_top_lefts = []
         target_sizes_hw = []
@@ -1120,14 +1135,18 @@ def __getitem__(self, index):
                 crop_ltrb = image_info.latents_crop_ltrb  # calc values later if flipped
                 if not flipped:
                     latents = image_info.latents
+                    mask = image_info.mask
                 else:
                     latents = image_info.latents_flipped
+                    mask = image_info.mask_flipped
                 image = None
             elif image_info.latents_npz is not None:  # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
-                latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
+                latents, original_size, crop_ltrb, flipped_latents, mask = load_latents_from_disk(image_info.latents_npz)
                 if flipped:
                     latents = flipped_latents
+                    if mask is not None:  # old caches may not contain a mask
+                        mask = np.flip(mask, axis=1)
                     del flipped_latents
 
                 latents = torch.FloatTensor(latents)
@@ -1162,7 +1180,23 @@ def __getitem__(self, index):
                     original_size = [im_w, im_h]
                     crop_ltrb = (0, 0, 0, 0)
-
+
+                if subset.mask_simple_background:
+                    edge_width = max(1, min(img.shape[0], img.shape[1]) // 20)
+                    top_edge = img[:edge_width, :, :]
+                    bottom_edge = img[-edge_width:, :, :]
+                    left_edge = img[:, :edge_width, :].reshape(-1, img.shape[2])
+                    right_edge = img[:, -edge_width:, :].reshape(-1, img.shape[2])
+                    edges = np.concatenate([top_edge.reshape(-1, img.shape[2]),
+                                            bottom_edge.reshape(-1, img.shape[2]),
+                                            left_edge, right_edge])
+                    colors, counts = np.unique(edges, axis=0, return_counts=True)
+                    simple_color = colors[counts.argmax()]
+                    simple_color_ratio = counts.max() / counts.sum()
+                    if simple_color_ratio > 0.3:
+                        simple_color_mask = np.all(img[:, :, :-1] == simple_color[:3], axis=2)
+                        img[simple_color_mask, -1] = 0
+
                 # augmentation
                 aug = self.aug_helper.get_augmentor(subset.color_aug)
                 if aug is not None:
                     img = aug(image=img)["image"]
 
                 if flipped:
                     img = img[:, ::-1, :].copy()  # copy to avoid negative stride problem
 
+                # loss mask is alpha channel, separate it
+                mask = img[:, :, -1] / 255
+                img = img[:, :, :3]
+
                 latents = None
                 image = self.image_transforms(img)  # -1.0~1.0のtorch.Tensorになる
 
             images.append(image)
             latents_list.append(latents)
+            masks.append(torch.tensor(mask) if mask is not None else None)
 
             target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)
@@ -1267,7 +1306,7 @@ def __getitem__(self, index):
         else:
             images = None
         example["images"] = images
-
+        example["masks"] = torch.stack(masks) if masks[0] is not None else None
         example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
         example["captions"] = captions
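Review note: the simple-background heuristic above can be exercised in isolation. A hypothetical standalone version on a synthetic image (white canvas, red square) that should zero the alpha of the background only:

```python
import numpy as np

# Sample a border of width ~5% of the short side, find the single most common
# color there, and clear the alpha wherever the RGB matches it exactly.
img = np.full((100, 100, 4), 255, np.uint8)   # white RGBA canvas, opaque mask
img[30:70, 30:70, :3] = (255, 0, 0)           # red "subject"
edge_width = max(1, min(img.shape[0], img.shape[1]) // 20)
edges = np.concatenate([
    img[:edge_width].reshape(-1, 4), img[-edge_width:].reshape(-1, 4),
    img[:, :edge_width].reshape(-1, 4), img[:, -edge_width:].reshape(-1, 4),
])
colors, counts = np.unique(edges, axis=0, return_counts=True)
dominant = colors[counts.argmax()]
if counts.max() / counts.sum() > 0.3:         # only mask clearly dominant colors
    img[np.all(img[:, :, :3] == dominant[:3], axis=2), -1] = 0
print(img[0, 0, -1], img[50, 50, -1])         # 0 (background), 255 (subject)
```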
@@ -1988,7 +2027,7 @@ def is_disk_cached_latents_is_expected(reso, npz_path: str, flip_aug: bool):
 
 
 # 戻り値は、latents_tensor, (original_size width, original_size height), (crop left, crop top)
 def load_latents_from_disk(
     npz_path,
-) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor]]:
+) -> Tuple[Optional[torch.Tensor], Optional[List[int]], Optional[List[int]], Optional[torch.Tensor], Optional[np.ndarray]]:
     npz = np.load(npz_path)
     if "latents" not in npz:
         raise ValueError(f"error: npz is old format. please re-generate {npz_path}")
@@ -1997,14 +2036,20 @@ def load_latents_from_disk(
     original_size = npz["original_size"].tolist()
     crop_ltrb = npz["crop_ltrb"].tolist()
     flipped_latents = npz["latents_flipped"] if "latents_flipped" in npz else None
-    return latents, original_size, crop_ltrb, flipped_latents
+    mask = npz["mask"] if "mask" in npz else None
+    if mask is not None:
+        mask = mask.astype(np.float32) / 255.0  # stored as uint8 0-255, restore to [0, 1]
+    return latents, original_size, crop_ltrb, flipped_latents, mask
 
 
-def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None):
+def save_latents_to_disk(npz_path, latents_tensor, original_size, crop_ltrb, flipped_latents_tensor=None, mask=None):
     kwargs = {}
     if flipped_latents_tensor is not None:
         kwargs["latents_flipped"] = flipped_latents_tensor.float().cpu().numpy()
-    np.savez(
+    if mask is not None:
+        # the in-memory mask is float in [0, 1]; quantize to uint8 so it survives the round trip
+        kwargs["mask"] = (np.asarray(mask) * 255).astype(np.uint8)
+    np.savez_compressed(
         npz_path,
         latents=latents_tensor.float().cpu().numpy(),
         original_size=np.array(original_size),
@@ -2191,12 +2235,43 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:
 
 
 def load_image(image_path):
     image = Image.open(image_path)
-    if not image.mode == "RGB":
-        image = image.convert("RGB")
+    if not image.mode == "RGBA":
+        image = image.convert("RGBA")
     img = np.array(image, np.uint8)
+    img[..., -1] = load_mask(image_path, img.shape[:2])
     return img
 
 
+def load_mask(image_path, target_shape):
+    p = pathlib.Path(image_path)
+    mask_path = os.path.join(p.parent, "mask", p.stem + ".png")
+    result = None
+
+    if os.path.exists(mask_path):
+        try:
+            mask_img = Image.open(mask_path)
+            mask = np.array(mask_img)
+            if len(mask.shape) > 2 and mask.max() <= 255:
+                result = np.array(mask_img.convert("L"))  # 8-bit color/RGBA mask -> grayscale
+            elif len(mask.shape) == 2 and mask.max() > 255:
+                result = mask // (((2**16) - 1) // 255)  # 16-bit grayscale -> 0-255
+            elif len(mask.shape) == 2 and mask.max() <= 255:
+                result = mask  # already 8-bit grayscale
+            else:
+                print(f"{mask_path} has invalid mask format: using default mask")
+        except Exception:
+            print(f"failed to load mask: {mask_path}")
+
+    # use the default all-ones mask when the mask file is unavailable
+    if result is None:
+        result = np.full(target_shape, 255, np.uint8)
+
+    # stretch mask to the image shape; cv2.resize expects (width, height)
+    if result.shape != target_shape:
+        result = cv2.resize(result, dsize=(target_shape[1], target_shape[0]), interpolation=cv2.INTER_LINEAR)
+
+    return result
+
+
 # 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
 def trim_and_resize_if_required(
     random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
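Review note: masks are looked up in a `mask` subdirectory next to each image (`mask/<stem>.png`). With the quantization above, a mask written by `save_latents_to_disk` comes back from `load_latents_from_disk` to within 1/255 of the original. A minimal round-trip check (hypothetical file path):

```python
import numpy as np

# Save a float mask in [0, 1] as uint8 alongside dummy latents, reload it,
# and confirm the round-trip error stays within one quantization step.
mask = np.random.rand(64, 64).astype(np.float32)
np.savez_compressed("/tmp/latents_test.npz",
                    latents=np.zeros((4, 8, 8), np.float32),
                    mask=(mask * 255).astype(np.uint8))
loaded = np.load("/tmp/latents_test.npz")["mask"].astype(np.float32) / 255.0
assert np.abs(loaded - mask).max() <= 1 / 255
```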
@@ -2231,7 +2306,7 @@ def trim_and_resize_if_required(
 
 
 def cache_batch_latents(
-    vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool
+    vae: AutoencoderKL, cache_to_disk: bool, image_infos: List[ImageInfo], flip_aug: bool, random_crop: bool, mask_simple_background: bool
 ) -> None:
     r"""
     requires image_infos to have: absolute_path, bucket_reso, resized_size, latents_npz
@@ -2243,12 +2318,32 @@ def cache_batch_latents(
     latents_original_size and latents_crop_ltrb are also set
     """
     images = []
+    masks = []
     for info in image_infos:
         image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
         # TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
         image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
 
+        if mask_simple_background:
+            edge_width = max(1, min(image.shape[0], image.shape[1]) // 20)
+            top_edge = image[:edge_width, :, :]
+            bottom_edge = image[-edge_width:, :, :]
+            left_edge = image[:, :edge_width, :].reshape(-1, image.shape[2])
+            right_edge = image[:, -edge_width:, :].reshape(-1, image.shape[2])
+            edges = np.concatenate([top_edge.reshape(-1, image.shape[2]),
+                                    bottom_edge.reshape(-1, image.shape[2]),
+                                    left_edge, right_edge])
+            colors, counts = np.unique(edges, axis=0, return_counts=True)
+            simple_color = colors[counts.argmax()]
+            simple_color_ratio = counts.max() / counts.sum()
+            if simple_color_ratio > 0.3:
+                simple_color_mask = np.all(image[:, :, :-1] == simple_color[:3], axis=2)
+                image[simple_color_mask, -1] = 0
+
+        # alpha channel contains loss mask, separate it before the VAE transform
+        mask = image[:, :, -1] / 255
+        image = image[:, :, :3]
         image = IMAGE_TRANSFORMS(image)
         images.append(image)
+        masks.append(mask)
 
         info.latents_original_size = original_size
         info.latents_crop_ltrb = crop_ltrb
@@ -2266,17 +2361,19 @@ def cache_batch_latents(
     else:
         flipped_latents = [None] * len(latents)
 
-    for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
+    for info, latent, flipped_latent, mask in zip(image_infos, latents, flipped_latents, masks):
         # check NaN
         if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
             raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
 
         if cache_to_disk:
-            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
+            save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent, mask)
         else:
             info.latents = latent
+            info.mask = mask
             if flip_aug:
                 info.latents_flipped = flipped_latent
+                info.mask_flipped = np.flip(mask, axis=1).copy()  # mask is a 2D numpy array; flip horizontally
 
     # FIXME this slows down caching a lot, specify this as an option
     if torch.cuda.is_available():
@@ -3259,6 +3356,14 @@ def add_dataset_arguments(
         "--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
     )
 
+    parser.add_argument(
+        "--masked_loss", action="store_true", help="enable masking of the latent loss using grayscale mask images read from a 'mask' subdirectory next to each training image"
+    )
+
+    parser.add_argument(
+        "--mask_simple_background", action="store_true", help="enable auto-masking of the latent loss when a single color covers more than 30%% of the image border, so training focuses on the main content instead of a uniform background such as solid white or black / 画像の縁の30%%以上を単一の色が占める場合に潜在空間の損失の自動マスキングを有効にします。純白や純黒などの単純または均一な背景色を無視することで、モデルがメインコンテンツに焦点を合わせやすくなります"
+    )
+
     parser.add_argument(
         "--token_warmup_min",
         type=int,
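Review note: the doubled percent signs in the help texts above are deliberate. argparse runs %-formatting over help strings (for placeholders like %(default)s), so a bare % in them makes --help raise ValueError. A small demonstration:

```python
import argparse

# argparse %-formats help strings, so a literal percent sign must be
# escaped as %% or rendering the help crashes with ValueError.
p = argparse.ArgumentParser()
p.add_argument("--mask_simple_background", action="store_true",
               help="mask when one edge color covers more than 30%% of the border")
print(p.format_help())  # renders "... more than 30% of the border"
```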
diff --git a/sdxl_train.py b/sdxl_train.py
index a3f6f3a17..0f73a0e8a 100644
--- a/sdxl_train.py
+++ b/sdxl_train.py
@@ -33,6 +33,7 @@
     scale_v_prediction_loss_like_noise_prediction,
     add_v_prediction_like_loss,
     apply_debiased_estimation,
+    get_latent_masks,
 )
 from library.sdxl_original_unet import SdxlUNet2DConditionModel
 
@@ -561,6 +562,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
 
                 target = noise
 
+                if (args.masked_loss or args.mask_simple_background) and batch["masks"] is not None:
+                    mask = get_latent_masks(batch["masks"], noise_pred.shape, noise_pred.device)
+                    noise_pred = noise_pred * mask
+                    target = target * mask
+
                 if (
                     args.min_snr_gamma
                     or args.scale_v_pred_loss_like_noise_pred
diff --git a/train_db.py b/train_db.py
index 888cad25e..2762c9cde 100644
--- a/train_db.py
+++ b/train_db.py
@@ -34,6 +34,7 @@
     apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
     apply_debiased_estimation,
+    get_latent_masks,
 )
 
 # perlin_noise,
@@ -333,6 +334,11 @@ def train(args):
                 else:
                     target = noise
 
+                if (args.masked_loss or args.mask_simple_background) and batch["masks"] is not None:
+                    mask = get_latent_masks(batch["masks"], noise_pred.shape, noise_pred.device)
+                    noise_pred = noise_pred * mask
+                    target = target * mask
+
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                 loss = loss.mean([1, 2, 3])
diff --git a/train_network.py b/train_network.py
index 8d102ae8f..e49253807 100644
--- a/train_network.py
+++ b/train_network.py
@@ -40,6 +40,7 @@
     scale_v_prediction_loss_like_noise_prediction,
     add_v_prediction_like_loss,
     apply_debiased_estimation,
+    get_latent_masks,
 )
 
 
@@ -824,6 +825,11 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
+                if (args.masked_loss or args.mask_simple_background) and batch["masks"] is not None:
+                    mask = get_latent_masks(batch["masks"], noise_pred.shape, noise_pred.device)
+                    noise_pred = noise_pred * mask
+                    target = target * mask
+
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                 loss = loss.mean([1, 2, 3])
diff --git a/train_textual_inversion.py b/train_textual_inversion.py
index 441c1e00b..f827720a2 100644
--- a/train_textual_inversion.py
+++ b/train_textual_inversion.py
@@ -31,6 +31,7 @@
     scale_v_prediction_loss_like_noise_prediction,
     add_v_prediction_like_loss,
     apply_debiased_estimation,
+    get_latent_masks,
 )
 
 imagenet_templates_small = [
@@ -582,6 +583,11 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
+                if (args.masked_loss or args.mask_simple_background) and batch["masks"] is not None:
+                    mask = get_latent_masks(batch["masks"], noise_pred.shape, noise_pred.device)
+                    noise_pred = noise_pred * mask
+                    target = target * mask
+
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                 loss = loss.mean([1, 2, 3])
diff --git a/train_textual_inversion_XTI.py b/train_textual_inversion_XTI.py
index 7046a4808..8a20b978b 100644
--- a/train_textual_inversion_XTI.py
+++ b/train_textual_inversion_XTI.py
@@ -33,6 +33,7 @@
     apply_noise_offset,
     scale_v_prediction_loss_like_noise_prediction,
     apply_debiased_estimation,
+    get_latent_masks,
 )
 import library.original_unet as original_unet
 from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
@@ -459,6 +460,11 @@ def remove_model(old_ckpt_name):
                 else:
                     target = noise
 
+                if (args.masked_loss or args.mask_simple_background) and batch["masks"] is not None:
+                    mask = get_latent_masks(batch["masks"], noise_pred.shape, noise_pred.device)
+                    noise_pred = noise_pred * mask
+                    target = target * mask
+
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
                 loss = loss.mean([1, 2, 3])
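Review note: all six trainers share the same masking step. A condensed end-to-end sketch of that step, assuming `get_latent_masks` from `library.custom_train_functions` as added in this patch, with dummy tensors standing in for the model outputs and `batch["masks"]`:

```python
import torch
from library.custom_train_functions import get_latent_masks

# Stand-ins for the UNet prediction, the target, and batch["masks"]:
# image-space masks are 8x the latent resolution, one per sample.
noise_pred = torch.randn(2, 4, 32, 32)
target = torch.randn(2, 4, 32, 32)
image_masks = (torch.rand(2, 256, 256) > 0.5).float()

# Resample the masks to latent resolution and weight both sides of the MSE.
mask = get_latent_masks(image_masks, noise_pred.shape, noise_pred.device)
loss = torch.nn.functional.mse_loss(noise_pred * mask, target * mask, reduction="none")
loss = loss.mean([1, 2, 3])   # per-sample loss, as in the trainers above
print(loss.shape)             # torch.Size([2])
```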