Add Minimal Implementation of Masked Weight Loss #236

Closed
wants to merge 1 commit
50 changes: 42 additions & 8 deletions library/train_util.py
@@ -68,6 +68,8 @@ def __init__(self, image_key: str, num_repeats: int, caption: str, is_reg: bool,
self.latents_flipped: torch.Tensor = None
self.latents_npz: str = None
self.latents_npz_flipped: str = None
self.mask: np.ndarray = None
self.mask_flipped: np.ndarray = None


class BucketManager():
@@ -233,7 +235,6 @@ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_to
self.dropout_rate: float = 0
self.dropout_every_n_epochs: int = None
self.tag_dropout_rate: float = 0

# augmentation
flip_p = 0.5 if flip_aug else 0.0
if color_aug:
@@ -257,6 +258,7 @@ def __init__(self, tokenizer, max_token_length, shuffle_caption, shuffle_keep_to
self.image_data: Dict[str, ImageInfo] = {}

self.replacements = {}
self.mask_max_attention = 1

def set_current_epoch(self, epoch):
self.current_epoch = epoch
@@ -465,11 +467,32 @@ def shuffle_buckets(self):
random.shuffle(self.buckets_indices)
self.bucket_manager.shuffle()

def load_mask(self, path):
try:
mask_path = path[:path.rindex('.')]+".mask"
mask = np.array(Image.open(mask_path))
if len(mask.shape) > 2 and mask.max() <= 255:
print(mask.shape)
return np.array(Image.open(mask_path).convert("L"))
elif len(mask.shape) == 2 and mask.max() > 255:
print(mask.max())
return mask//(((2**16)-1)//255)
elif len(mask.shape) == 2 and mask.max() <= 255:
return mask
else:
print(f"{mask_path} has invalid mask format: Defaulting to no mask")
return np.ones_like(np.array(Image.open(path).convert("L")))*255
except:
print(f"{mask_path} not found: Defaulting to no mask")
return np.ones_like(np.array(Image.open(path).convert("L")))*255

def load_image(self, image_path):
image = Image.open(image_path)
if not image.mode == "RGB":
image = image.convert("RGB")
if not image.mode == "RGBA":
image = image.convert("RGBA")
img = np.array(image, np.uint8)
#if img[:,-1].mean() == 255:
img[...,-1] = self.load_mask(image_path)
return img
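
Note on usage: load_mask above looks for an optional companion file with the same basename as the training image and a ".mask" extension, read as a grayscale weight map (color masks are converted to grayscale, 16-bit masks are scaled down to 8-bit); white means full loss weight, black means the pixel is ignored, and a missing or malformed file falls back to an all-white mask. A minimal sketch of producing such a file with Pillow follows; the file name "dog.png" and the thresholding rule are illustrative only, not part of this PR.

import numpy as np
from PIL import Image

# Build an 8-bit grayscale weight map for an illustrative image "dog.png":
# 255 keeps full loss weight, 0 removes the pixel from the loss.
img = np.array(Image.open("dog.png").convert("L"))
mask = np.where(img > 16, 255, 0).astype(np.uint8)  # placeholder rule; use any mask source

# Save it next to the image with the same basename and a ".mask" suffix so
# load_mask("dog.png") can find it. The format must be passed explicitly
# because Pillow cannot infer it from the unknown ".mask" extension.
Image.fromarray(mask).save("dog.mask", format="PNG")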

def trim_and_resize_if_required(self, image, reso, resized_size):
@@ -508,16 +531,19 @@ def cache_latents(self, vae):

image = self.load_image(info.absolute_path)
image = self.trim_and_resize_if_required(image, info.bucket_reso, info.resized_size)

mask = image[:,:,-1] #grab alpha channel
image = image[:,:,:3] #drop alpha channel
img_tensor = self.image_transforms(image)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
info.latents = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
info.mask = mask/(255/self.mask_max_attention)

if self.flip_aug:
image = image[:, ::-1].copy() # cannot convert to Tensor without copy
img_tensor = self.image_transforms(image)
img_tensor = img_tensor.unsqueeze(0).to(device=vae.device, dtype=vae.dtype)
info.latents_flipped = vae.encode(img_tensor).latent_dist.sample().squeeze(0).to("cpu")
info.mask_flipped = mask[::-1]/(255/self.mask_max_attention)
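
Note: the division by 255/self.mask_max_attention above rescales the cached 8-bit mask into per-pixel loss weights in the range 0 to mask_max_attention, which the constructor sets to 1. A small standalone check of that mapping, with illustrative values:

import numpy as np

mask_max_attention = 1  # default set in the dataset constructor
mask = np.array([0, 64, 128, 255], dtype=np.uint8)  # illustrative 8-bit mask values
weights = mask / (255 / mask_max_attention)
print(weights)  # approximately [0., 0.251, 0.502, 1.] -> per-pixel loss weights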

def get_image_size(self, image_path):
image = Image.open(image_path)
@@ -606,14 +632,17 @@ def __getitem__(self, index):
input_ids_list = []
latents_list = []
images = []
masks = []

for image_key in bucket[image_index:image_index + bucket_batch_size]:
image_info = self.image_data[image_key]
loss_weights.append(self.prior_loss_weight if image_info.is_reg else 1.0)

# process the image/latents
if image_info.latents is not None:
latents = image_info.latents if not self.flip_aug or random.random() < .5 else image_info.latents_flipped
rand_flip = random.random()
latents = image_info.latents if not self.flip_aug or rand_flip < .5 else image_info.latents_flipped
mask = image_info.mask if not self.flip_aug or rand_flip < .5 else image_info.mask_flipped
image = None
elif image_info.latents_npz is not None:
latents = self.load_latents_from_npz(image_info, self.flip_aug and random.random() >= .5)
@@ -622,6 +651,8 @@
else:
# load the image and crop it if necessary
img, face_cx, face_cy, face_w, face_h = self.load_image_with_face_info(image_info.absolute_path)
mask = img[:,:,-1]/(255/self.mask_max_attention) #grab alpha channel
img = img[:,:,:3] #drop alpha channel
im_h, im_w = img.shape[0:2]

if self.enable_bucket:
@@ -647,7 +678,8 @@

latents = None
image = self.image_transforms(img) # becomes a torch.Tensor in the range -1.0 to 1.0

mask = torch.from_numpy(mask)
masks.append(torch.tensor(mask))
images.append(image)
latents_list.append(latents)

@@ -672,7 +704,7 @@
else:
images = None
example['images'] = images

example['masks'] = torch.stack(masks) if masks[0] is not None else None
example['latents'] = torch.stack(latents_list) if latents_list[0] is not None else None
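
Note: after this change the collated example dict carries a 'masks' tensor at image resolution alongside 'latents', which is spatially 1/8 of the image size; that mismatch is why the training-loop change in train_network.py below has to resize the mask before applying it to the prediction.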

if self.debug_dataset:
@@ -1494,6 +1526,8 @@ def add_dataset_arguments(parser: argparse.ArgumentParser, support_dreambooth: b
help="steps of resolution for buckets, divisible by 8 is recommended / bucketの解像度の単位、8で割り切れる値を推奨します")
parser.add_argument("--bucket_no_upscale", action="store_true",
help="make bucket for each image without upscaling / 画像を拡大せずbucketを作成します")
parser.add_argument("--masked_loss", action="store_true",
help="Enable Masked Loss from Alpha Channel")
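
Usage note: the --masked_loss flag added above is opt-in; it is a plain store_true switch, so passing it on the train_network.py command line sets args.masked_loss to True, which gates the mask multiplication in the training-loop change below. Without the flag, or for batches whose 'masks' entry is None, the loss is computed exactly as before.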

if support_caption_dropout:
# Textual Inversion does not support caption dropout
@@ -2059,4 +2093,4 @@ def __getitem__(self, idx):
return (tensor_pil, img_path)


# endregion
# endregion
26 changes: 25 additions & 1 deletion train_network.py
@@ -1,5 +1,6 @@
from torch.cuda.amp import autocast
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional
import importlib
import argparse
import gc
@@ -377,6 +378,25 @@ def train(args):
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
target = noise

if args.masked_loss and batch['masks'] is not None:

mask = (
batch['masks']
.to(noise_pred.device)
.reshape(
noise_pred.shape[0], 1, noise_pred.shape[2] * 8, noise_pred.shape[3] * 8
)
)
# resize to match noise_pred
mask = torch.nn.functional.interpolate(
mask.float(),
size=noise_pred.shape[-2:],
mode="nearest",
)

noise_pred = noise_pred * mask
target = target * mask

loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="none")
loss = loss.mean([1, 2, 3])
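
Note: taken together, the hunk above reshapes the pixel-space mask to the image resolution implied by the latents, resizes it down to the prediction's spatial size with nearest-neighbour interpolation, and multiplies both the prediction and the target by it before the unreduced MSE, so masked-out pixels contribute zero loss. A minimal standalone sketch of the same idea; the helper name masked_mse_loss is an illustration, not part of this PR.

import torch
import torch.nn.functional as F

def masked_mse_loss(noise_pred, target, masks):
    # noise_pred/target: (B, C, h, w) latent-space tensors
    # masks: (B, H, W) pixel-space weights, in [0, 1] with the default mask_max_attention of 1
    mask = masks.to(noise_pred.device).unsqueeze(1).float()                  # (B, 1, H, W)
    mask = F.interpolate(mask, size=noise_pred.shape[-2:], mode="nearest")   # downscale to latent size
    loss = F.mse_loss(noise_pred * mask, target * mask, reduction="none")
    return loss.mean([1, 2, 3])                                              # per-sample loss, as in the PR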
@@ -408,7 +428,11 @@ def train(args):
loss_list[step] = current_loss
loss_total += current_loss
avr_loss = loss_total / len(loss_list)
logs = {"loss": avr_loss} # , "lr": lr_scheduler.get_last_lr()[0]}

if args.masked_loss and batch['masks'] is not None:
logs = {"loss": avr_loss, "Batch Mask Average Weight": batch['masks'].mean().item()} # , "lr": lr_scheduler.get_last_lr()[0]}
else:
logs = {"loss": avr_loss}
progress_bar.set_postfix(**logs)
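
Note: the extra "Batch Mask Average Weight" log key reports the mean of the current batch's mask weights, so a value near 1.0 means most pixels carry full loss weight while a small value means most of the batch is masked out.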

if args.logging_dir is not None: