Skip to content

Commit

Permalink
Merge branch 'pr/589' into auto_mask
Browse files Browse the repository at this point in the history
  • Loading branch information
gesen2egee committed Feb 10, 2024
2 parents cd19df4 + 5680057 commit 5f6c5ff
Show file tree
Hide file tree
Showing 8 changed files with 117 additions and 4 deletions.
6 changes: 6 additions & 0 deletions fine_tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
prepare_scheduler_for_custom_training,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
get_latent_masks
)


Expand Down Expand Up @@ -346,6 +347,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

if args.min_snr_gamma or args.scale_v_pred_loss_like_noise_pred or args.debiased_estimation_loss:
# do not mean over batch dimension for snr weight or scale v-pred loss
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
Expand Down
21 changes: 21 additions & 0 deletions library/custom_train_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,27 @@ def apply_noise_offset(latents, noise, noise_offset, adaptive_noise_scale):
return noise


def get_latent_masks(image_masks, latent_shape, device):
    """Convert per-image loss masks into latent-resolution weight maps.

    Each mask is divided by the square root of its own mean before being
    downsampled. Callers multiply both the prediction and the target by the
    returned map, so the squared error ends up weighted by ``mask**2 / mean``;
    for a binary mask that weight averages to 1, which keeps a partially
    masked sample from producing a smaller loss than an unmasked one.

    Args:
        image_masks: tensor of shape ``(batch, H, W)`` where ``H`` and ``W``
            are 8x the latent height/width. Any dtype is accepted; bool and
            integer masks are cast to float (values are NOT rescaled).
        latent_shape: shape of the latent / noise prediction,
            ``(batch, channels, h, w)``.
        device: device the returned mask should live on.

    Returns:
        float32 tensor of shape ``(batch, 1, h, w)``.

    Raises:
        RuntimeError: if ``image_masks`` cannot be reshaped to
            ``(batch, 1, h * 8, w * 8)``.
    """
    masks = image_masks.to(device)
    # mean() is undefined for bool/integer tensors, so promote them first
    if not torch.is_floating_point(masks):
        masks = masks.float()

    # given that masks lower the average loss this will counteract the effect
    factor = torch.sqrt(masks.mean(dim=(1, 2), keepdim=True))
    # a fully-zero mask has factor 0; divide by 1 instead to avoid NaN.
    # ones_like (not a Python scalar) keeps dtype/device matched on every
    # torch version.
    factor = torch.where(factor != 0.0, factor, torch.ones_like(factor))
    masks = masks / factor

    # the reshape both adds the channel axis and validates the pixel-space size
    masks = masks.reshape(latent_shape[0], 1, latent_shape[2] * 8, latent_shape[3] * 8)

    # resize to match latent
    return torch.nn.functional.interpolate(
        masks.float(),
        size=latent_shape[-2:],
        mode="nearest",
    )


"""
##########################################
# Perlin Noise
Expand Down
64 changes: 60 additions & 4 deletions library/train_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,9 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
self.text_encoder_outputs1: Optional[torch.Tensor] = None
self.text_encoder_outputs2: Optional[torch.Tensor] = None
self.text_encoder_pool2: Optional[torch.Tensor] = None
# Masked Loss
self.mask: np.ndarray = None
self.mask_flipped: np.ndarray = None


class BucketManager:
Expand Down Expand Up @@ -1097,6 +1100,7 @@ def __getitem__(self, index):
input_ids2_list = []
latents_list = []
images = []
masks = []
original_sizes_hw = []
crop_top_lefts = []
target_sizes_hw = []
Expand All @@ -1120,14 +1124,18 @@ def __getitem__(self, index):
crop_ltrb = image_info.latents_crop_ltrb # calc values later if flipped
if not flipped:
latents = image_info.latents
mask = image_info.mask
else:
latents = image_info.latents_flipped
mask = image_info.mask_flipped

image = None
elif image_info.latents_npz is not None: # FineTuningDatasetまたはcache_latents_to_disk=Trueの場合
latents, original_size, crop_ltrb, flipped_latents = load_latents_from_disk(image_info.latents_npz)
mask = load_mask(image_info.absolute_path, image_info.resized_size) / 255
if flipped:
latents = flipped_latents
mask = np.flip(mask, axis=1)
del flipped_latents
latents = torch.FloatTensor(latents)

Expand Down Expand Up @@ -1171,11 +1179,16 @@ def __getitem__(self, index):
if flipped:
img = img[:, ::-1, :].copy() # copy to avoid negative stride problem

# loss mask is alpha channel, separate it
mask = img[:, :, -1] / 255
img = img[:, :, :3]

latents = None
image = self.image_transforms(img) # -1.0~1.0のtorch.Tensorになる

images.append(image)
latents_list.append(latents)
masks.append(torch.tensor(mask))

target_size = (image.shape[2], image.shape[1]) if image is not None else (latents.shape[2] * 8, latents.shape[1] * 8)

Expand Down Expand Up @@ -1267,7 +1280,7 @@ def __getitem__(self, index):
else:
images = None
example["images"] = images

example["masks"] = torch.stack(masks) if masks[0] is not None else None
example["latents"] = torch.stack(latents_list) if latents_list[0] is not None else None
example["captions"] = captions

Expand Down Expand Up @@ -2191,12 +2204,44 @@ def load_arbitrary_dataset(args, tokenizer) -> MinimalDataset:

def load_image(image_path):
    """Load an image as an ``(H, W, 4)`` uint8 array: RGB plus loss mask.

    The first three channels are the image converted to RGB. The fourth
    channel is whatever ``load_mask`` returns for this image path; any alpha
    channel stored in the image file itself is discarded.

    Args:
        image_path: path of the image file to read.

    Returns:
        ``numpy.ndarray`` of dtype uint8 and shape ``(height, width, 4)``.
    """
    # the context manager closes the file handle once the pixels are copied out
    with Image.open(image_path) as image:
        rgb = np.asarray(image.convert("RGB"), dtype=np.uint8)

    # build the 4-channel array directly instead of converting to RGBA and
    # then overwriting the alpha channel
    img = np.empty(rgb.shape[:2] + (4,), dtype=np.uint8)
    img[..., :3] = rgb
    img[..., -1] = load_mask(image_path, rgb.shape[:2])
    return img


def load_mask(image_path, target_shape):
    """Load the loss mask belonging to ``image_path`` as a 2-D uint8 array.

    The mask is looked up at ``<image dir>/mask/<image stem>.png``. Accepted
    files are multi-channel 8-bit images (converted to grayscale), 2-D 8-bit
    images, 2-D images with values above 255 (scaled down by 257, i.e. treated
    as 16-bit), and 1-bit images (mapped to 0/255). When the file is missing,
    unreadable or in an unsupported format, a fully opaque (all 255) mask is
    returned instead.

    Args:
        image_path: path of the training image the mask belongs to.
        target_shape: ``(height, width)`` of the array to return, in numpy
            shape order. Callers holding a ``(width, height)`` size must swap
            it before calling.

    Returns:
        ``numpy.ndarray`` of dtype uint8 whose shape equals ``target_shape``.
    """
    # normalise so the comparison with ndarray.shape (a tuple) is meaningful
    target_shape = tuple(target_shape)

    p = pathlib.Path(image_path)
    mask_path = os.path.join(p.parent, 'mask', p.stem + '.png')
    result = None

    if os.path.exists(mask_path):
        try:
            with Image.open(mask_path) as mask_img:
                mask = np.array(mask_img)
                if mask.dtype == np.bool_ and mask.ndim == 2:
                    # 1-bit image: callers divide by 255, so map True to 255
                    result = mask.astype(np.uint8) * 255
                elif mask.ndim > 2 and mask.max() <= 255:
                    result = np.array(mask_img.convert("L"))
                elif mask.ndim == 2 and mask.max() > 255:
                    # 16-bit grayscale: 65535 // 255 == 257 maps 0..65535 to 0..255
                    scaled = mask // (((2 ** 16) - 1) // 255)
                    result = np.clip(scaled, 0, 255).astype(np.uint8)
                elif mask.ndim == 2 and mask.max() <= 255:
                    result = np.clip(mask, 0, 255).astype(np.uint8)
                else:
                    print(f"{mask_path} has invalid mask format: using default mask")
        except Exception:
            # best effort: fall back to the default mask, but do not swallow
            # KeyboardInterrupt/SystemExit the way a bare ``except`` would
            print(f"failed to load mask: {mask_path}")

    # use default when mask file is unavailable
    if result is None:
        result = np.full(target_shape, 255, np.uint8)

    # stretch mask to image shape
    if result.shape != target_shape:
        # cv2 expects dsize as (width, height), the reverse of numpy's (rows, cols)
        result = cv2.resize(
            result,
            dsize=(target_shape[1], target_shape[0]),
            interpolation=cv2.INTER_LINEAR,
        )

    return result


# 画像を読み込む。戻り値はnumpy.ndarray,(original width, original height),(crop left, crop top, crop right, crop bottom)
def trim_and_resize_if_required(
random_crop: bool, image: Image.Image, reso, resized_size: Tuple[int, int]
Expand Down Expand Up @@ -2243,12 +2288,17 @@ def cache_batch_latents(
latents_original_size and latents_crop_ltrb are also set
"""
images = []
masks = []
for info in image_infos:
image = load_image(info.absolute_path) if info.image is None else np.array(info.image, np.uint8)
# TODO 画像のメタデータが壊れていて、メタデータから割り当てたbucketと実際の画像サイズが一致しない場合があるのでチェック追加要
image, original_size, crop_ltrb = trim_and_resize_if_required(random_crop, image, info.bucket_reso, info.resized_size)
# alpha channel contains loss mask, separate it
mask = image[:, :, -1] / 255
image = image[:, :, :3]
image = IMAGE_TRANSFORMS(image)
images.append(image)
masks.append(mask)

info.latents_original_size = original_size
info.latents_crop_ltrb = crop_ltrb
Expand All @@ -2266,7 +2316,7 @@ def cache_batch_latents(
else:
flipped_latents = [None] * len(latents)

for info, latent, flipped_latent in zip(image_infos, latents, flipped_latents):
for info, latent, flipped_latent, mask in zip(image_infos, latents, flipped_latents, masks):
# check NaN
if torch.isnan(latents).any() or (flipped_latent is not None and torch.isnan(flipped_latent).any()):
raise RuntimeError(f"NaN detected in latents: {info.absolute_path}")
Expand All @@ -2275,8 +2325,10 @@ def cache_batch_latents(
save_latents_to_disk(info.latents_npz, latent, info.latents_original_size, info.latents_crop_ltrb, flipped_latent)
else:
info.latents = latent
info.mask = mask
if flip_aug:
info.latents_flipped = flipped_latent
info.mask_flipped = mask.flip(mask, dims=[3])

# FIXME this slows down caching a lot, specify this as an option
if torch.cuda.is_available():
Expand Down Expand Up @@ -3259,6 +3311,10 @@ def add_dataset_arguments(
"--bucket_no_upscale", action="store_true", help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します"
)

parser.add_argument(
"--masked_loss", action="store_true", help="Enable masking of latent loss using grayscale mask images"
)

parser.add_argument(
"--token_warmup_min",
type=int,
Expand Down
6 changes: 6 additions & 0 deletions sdxl_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
get_latent_masks
)
from library.sdxl_original_unet import SdxlUNet2DConditionModel

Expand Down Expand Up @@ -561,6 +562,11 @@ def fn_recursive_set_mem_eff(module: torch.nn.Module):

target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

if (
args.min_snr_gamma
or args.scale_v_pred_loss_like_noise_pred
Expand Down
6 changes: 6 additions & 0 deletions train_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
get_latent_masks
)

# perlin_noise,
Expand Down Expand Up @@ -333,6 +334,11 @@ def train(args):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down
6 changes: 6 additions & 0 deletions train_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
get_latent_masks
)


Expand Down Expand Up @@ -824,6 +825,11 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down
6 changes: 6 additions & 0 deletions train_textual_inversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
scale_v_prediction_loss_like_noise_prediction,
add_v_prediction_like_loss,
apply_debiased_estimation,
get_latent_masks
)

imagenet_templates_small = [
Expand Down Expand Up @@ -582,6 +583,11 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down
6 changes: 6 additions & 0 deletions train_textual_inversion_XTI.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
apply_noise_offset,
scale_v_prediction_loss_like_noise_prediction,
apply_debiased_estimation,
get_latent_masks
)
import library.original_unet as original_unet
from XTI_hijack import unet_forward_XTI, downblock_forward_XTI, upblock_forward_XTI
Expand Down Expand Up @@ -459,6 +460,11 @@ def remove_model(old_ckpt_name):
else:
target = noise

if args.masked_loss and batch['masks'] is not None:
mask = get_latent_masks(batch['masks'], noise_pred.shape, noise_pred.device)
noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])

Expand Down

0 comments on commit 5f6c5ff

Please sign in to comment.