From 6b350ab2ac4b1b44ed2c35f94586ba6346c339a6 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Tue, 24 Oct 2023 11:34:27 +0200 Subject: [PATCH 01/23] Additions: - support for different lr for text encoder - support for Prodigy optimizer - support for min snr gamma - support for custom captions and dataset loading from the hub --- .../dreambooth/train_dreambooth_lora_sdxl.py | 431 ++++++++++++++---- 1 file changed, 350 insertions(+), 81 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index d7df6d4ef526..3b5991f3152b 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -23,6 +23,7 @@ import shutil import warnings from pathlib import Path +from typing import Dict import numpy as np import torch @@ -40,6 +41,7 @@ from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig +from datasets import load_dataset import diffusers from diffusers import ( @@ -141,13 +143,50 @@ def parse_args(input_args=None): required=False, help="Revision of pretrained model identifier from huggingface.co/models.", ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) parser.add_argument( "--instance_data_dir", type=str, default=None, - required=True, - help="A folder containing the training data of instance images.", + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), ) + + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--caption_column", + type=str, + default=None, + help="The column of the dataset containing the instance prompt for each image", + ) + parser.add_argument( "--class_data_dir", type=str, @@ -299,9 +338,16 @@ def parse_args(input_args=None): parser.add_argument( "--learning_rate", type=float, - default=5e-4, + default=1e-4, help="Initial learning rate (after the potential warmup period) to use.", ) + + parser.add_argument( + "--text_encoder_lr", + type=float, + default=5e-6, + help="Text encoder learning rate to use.", + ) parser.add_argument( "--scale_lr", action="store_true", @@ -317,6 +363,14 @@ def parse_args(input_args=None): ' "constant", "constant_with_warmup"]' ), ) + + parser.add_argument( + "--snr_gamma", + type=float, + action="store_true", + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." ) @@ -335,13 +389,41 @@ def parse_args(input_args=None): "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." ), ) + + parser.add_argument( + "--optimizer", + type=str, + default="prodigy", + help=( + 'The optimizer type to use. Choose between ["adamW", "prodigy"]' + ), + ) + parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if " + "optimizer is not set to AdamW" ) - parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") - parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") - parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") - parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + + parser.add_argument("--adam_beta1", type=float, default=0.9, + help="The beta1 parameter for the Adam and Prodigy optimizers.") + parser.add_argument("--adam_beta2", type=float, default=0.999, + help="The beta2 parameter for the Adam and Prodigy optimizers.") + parser.add_argument("--prodigy_beta3", type=float, default=None, + help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " + "uses the value of square root of beta2") + parser.add_argument("--prodigy_decouple", type=bool, default=True, + help="Use AdamW style decoupled weight decay") + parser.add_argument("--adam_weight_decay", type=float, default=1e-02, help="Weight decay to use. If you're using " + "the Adam optimizer you might want to " + "change value to 1e-4") + + parser.add_argument("--adam_epsilon", type=float, default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.") + + parser.add_argument("--prodigy_use_bias_correction ", type=bool, default=True, + help="Turn on Adam's bias correction. True by default.") + parser.add_argument("--prodigy_safeguard_warmup ", type=bool, default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default.") parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") @@ -414,6 +496,12 @@ def parse_args(input_args=None): else: args = parser.parse_args() + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: args.local_rank = env_local_rank @@ -440,21 +528,71 @@ class DreamBoothDataset(Dataset): """ def __init__( - self, - instance_data_root, - class_data_root=None, - class_num=None, - size=1024, - center_crop=False, + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + center_crop=False, ): self.size = size self.center_crop = center_crop self.instance_data_root = Path(instance_data_root) - if not self.instance_data_root.exists(): - raise ValueError("Instance images root doesn't exists.") - self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.instance_prompt = instance_prompt + self.custom_instance_prompts = None + self.class_prompt = class_prompt + + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + dataset = load_dataset( + instance_dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + + dataset = load_dataset(instance_data_root, + cache_dir=args.cache_dir, + ) + + # Preprocessing the datasets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + self.instance_images_path = dataset["train"][image_column] + + if args.caption_column is None: + try: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + self.custom_instance_prompts = dataset["train"][caption_column] + except IndexError: + logger.info(f"no caption column provided, deaulting to instance_prompt for all images") + self.custom_instance_prompts = None + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + self.num_instance_images = len(self.instance_images_path) self._length = self.num_instance_images @@ -484,13 +622,23 @@ def __len__(self): def __getitem__(self, index): example = {} - instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + instance_image = self.instance_images_path[index % self.num_instance_images] instance_image = exif_transpose(instance_image) if not instance_image.mode == "RGB": instance_image = instance_image.convert("RGB") example["instance_images"] = self.image_transforms(instance_image) + if self.custom_instance_prompts: + caption = self.custom_instance_prompts[index % self.num_instance_images] + if caption: + example["instance_prompt"] = self.custom_instance_prompts[index % self.num_instance_images] + else: + example["instance_prompt"] = self.instance_prompt + + else: # costum prompts were provided, but length does not match size of image dataset + example["instance_prompt"] = self.instance_prompt + if self.class_data_root: class_image = Image.open(self.class_images_path[index % self.num_class_images]) class_image = exif_transpose(class_image) @@ -498,25 +646,53 @@ def __getitem__(self, index): if not class_image.mode == "RGB": class_image = class_image.convert("RGB") example["class_images"] = self.image_transforms(class_image) + example["class_prompt"] = self.class_prompt return example def collate_fn(examples, with_prior_preservation=False): pixel_values = [example["instance_images"] for example in examples] + prompts = [example["instance_prompt"] for example in examples] # Concat class and instance examples for prior preservation. # We do this to avoid doing two forward passes. if with_prior_preservation: pixel_values += [example["class_images"] for example in examples] + prompts += [example["class_prompt"] for example in examples] pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - batch = {"pixel_values": pixel_values} + batch = {"pixel_values": pixel_values, "prompts": prompts} return batch +def compute_snr(timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod ** 0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + class PromptDataset(Dataset): "A simple dataset to prepare the prompts to generate class images on multiple GPUs." @@ -865,31 +1041,95 @@ def load_model_hook(models, input_dir): args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) - # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs - if args.use_8bit_adam: + # Optimization parameters + unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} + if args.text_encoder_lr: # different learning rate for text encoder and unet + text_lora_parameters_one_with_lr = {"params": text_lora_parameters_one, "lr": args.text_encoder_lr} + text_lora_parameters_two_with_lr = {"params": text_lora_parameters_two, "lr": args.text_encoder_lr} + + else: + text_lora_parameters_one_with_lr = {"params": text_lora_parameters_one, "lr": args.learning_rate} + text_lora_parameters_two_with_lr = {"params": text_lora_parameters_two, "lr": args.learning_rate} + + params_to_optimize = [unet_lora_parameters_with_lr, text_lora_parameters_one_with_lr, + text_lora_parameters_two_with_lr] + + # Optimizer creation + if args.use_8bit_adam and not args.optimizer.lower() == "AdamW": + logger.warn(f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}") + + if args.optimizer.lower() == "prodigy": try: - import bitsandbytes as bnb + import prodigyopt except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`") + + optimizer_class = prodigyopt.Prodigy + + if args.learning_rate <= 0.1: + logger.warn( + f"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + ) + if args.text_encoder_lr: + logger.warn( + f"Learning rates were provided both for the unet and the text encdoer- e.g. text_encoder_lr and learning_rate" + f"when using prodigy only learning_rate is used as the initial learning rate" ) - optimizer_class = bnb.optim.AdamW8bit + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + beta3=args.prodigy_beta3, + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + decouple=args.prodigy_decouple, + use_bias_correction=args.prodigy_use_bias_correction, + safeguard_warmup=args.prodigy_safeguard_warmup, + ) + + + elif args.optimizer.lower() == "AdamW": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) else: - optimizer_class = torch.optim.AdamW + raise ValueError( + f"Unsupported choice of optimizer: {args.optimizer.lower()}. Supported optimizers include [adamW, prodigy]") - # Optimizer creation - params_to_optimize = ( - itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two) - if args.train_text_encoder - else unet_lora_parameters + # Dataset and DataLoaders creation: + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_prompt=args.class_prompt, + class_data_root=args.class_data_dir if args.with_prior_preservation else None, + class_num=args.num_class_images, + size=args.resolution, + center_crop=args.center_crop, ) - optimizer = optimizer_class( - params_to_optimize, - lr=args.learning_rate, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + batch_size=args.train_batch_size, + shuffle=True, + collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), + num_workers=args.dataloader_num_workers, ) # Computes additional embeddings/ids required by the SDXL UNet. @@ -918,9 +1158,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) return prompt_embeds, pooled_prompt_embeds - # Handle instance prompt. + # Handle instance prompt. If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. instance_time_ids = compute_time_ids() - if not args.train_text_encoder: + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( args.instance_prompt, text_encoders, tokenizers ) @@ -933,49 +1174,33 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): args.class_prompt, text_encoders, tokenizers ) - # Clear the memory here. - if not args.train_text_encoder: + # Clear the memory here + if not args.train_text_encoder and not train_dataset.custom_instance_prompts: del tokenizers, text_encoders gc.collect() torch.cuda.empty_cache() - # Pack the statically computed variables appropriately. This is so that we don't + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. add_time_ids = instance_time_ids if args.with_prior_preservation: add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0) - if not args.train_text_encoder: - prompt_embeds = instance_prompt_hidden_states - unet_add_text_embeds = instance_pooled_prompt_embeds - if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) - else: - tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) - tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) - if args.with_prior_preservation: - class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) - class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) - tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) - tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) - - # Dataset and DataLoaders creation: - train_dataset = DreamBoothDataset( - instance_data_root=args.instance_data_dir, - class_data_root=args.class_data_dir if args.with_prior_preservation else None, - class_num=args.num_class_images, - size=args.resolution, - center_crop=args.center_crop, - ) - - train_dataloader = torch.utils.data.DataLoader( - train_dataset, - batch_size=args.train_batch_size, - shuffle=True, - collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation), - num_workers=args.dataloader_num_workers, - ) + if not train_dataset.custom_instance_prompts: + if not args.train_text_encoder: + prompt_embeds = instance_prompt_hidden_states + unet_add_text_embeds = instance_pooled_prompt_embeds + if args.with_prior_preservation: + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) + else: + tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) + tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) # Scheduler and math around the number of training steps. overrode_max_train_steps = False @@ -1035,7 +1260,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): path = os.path.basename(args.resume_from_checkpoint) else: # Get the mos recent checkpoint - dirs = os.listdir(args.output_dir) + dirs = os.listdir(args.model_dir_name) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None @@ -1048,7 +1273,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.output_dir, path)) + accelerator.load_state(os.path.join(args.model_dir_name, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step @@ -1078,6 +1303,25 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + prompts = batch["prompts"] + + # encode batch prompts when custom prompts are provided for each image - + if train_dataset.custom_instance_prompts: + + if not args.train_text_encoder: + prompt_embeds, unet_add_text_embeds = compute_text_embeddings( + prompts, text_encoders, tokenizers) + if args.with_prior_preservation: + prompt_embeds_input = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) + else: + tokens_one = tokenize_prompt(tokenizer_one, prompts) + tokens_two = tokenize_prompt(tokenizer_two, prompts) + if args.with_prior_preservation: + class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) + class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) + tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) + tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() @@ -1107,7 +1351,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): "time_ids": add_time_ids.repeat(elems_to_repeat, 1), "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1), } - prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1) + if not dataset.custom_instance_prompts: # i.e. we only encoded args.instance_prompt + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1) + else: + prompt_embeds_input = prompt_embeds model_pred = unet( noisy_model_input, timesteps, @@ -1141,16 +1388,38 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) - # Compute instance loss - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - # Compute prior loss prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean") + if args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(timesteps) + base_weight = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[ + 0] / snr + ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # We calculate the original loss and then we mean over the non-batch dimensions and + # rebalance the sample-wise losses with their respective loss weights. + # Finally, we take the mean of the rebalanced loss. + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + if args.with_prior_preservation: # Add the prior loss to the instance loss. loss = loss + args.prior_loss_weight * prior_loss - else: - loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") accelerator.backward(loss) if accelerator.sync_gradients: From 96e81f9fb992a67d15adb8468577f0bbabd05738 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Tue, 24 Oct 2023 14:55:08 +0200 Subject: [PATCH 02/23] adjusted --caption_column behaviour (to -not- use the second column of the dataset by default if --caption_column is not provided) --- .../dreambooth/train_dreambooth_lora_sdxl.py | 24 +++++++++---------- 1 file changed, 11 insertions(+), 13 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 3b5991f3152b..0500b48f9b62 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -58,7 +58,6 @@ from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available - # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.22.0.dev0") @@ -66,7 +65,7 @@ def save_model_card( - repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None + repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None ): img_str = "" for i, image in enumerate(images): @@ -102,7 +101,7 @@ def save_model_card( def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder=subfolder, revision=revision @@ -178,7 +177,9 @@ def parse_args(input_args=None): ) parser.add_argument( - "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + "--image_column", type=str, default="image", help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'." ) parser.add_argument( "--caption_column", @@ -579,13 +580,10 @@ def __init__( self.instance_images_path = dataset["train"][image_column] if args.caption_column is None: - try: - caption_column = column_names[1] - logger.info(f"caption column defaulting to {caption_column}") - self.custom_instance_prompts = dataset["train"][caption_column] - except IndexError: - logger.info(f"no caption column provided, deaulting to instance_prompt for all images") - self.custom_instance_prompts = None + logger.info(f"No caption column provided, defaulting to instance_prompt for all images. If your dataset " + f"contains captions/prompts for the images, make sure to specify the " + f"column as --caption_column") + self.custom_instance_prompts = None else: caption_column = args.caption_column if caption_column not in column_names: @@ -818,7 +816,7 @@ def main(args): pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images @@ -1038,7 +1036,7 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Optimization parameters From 7b2be6d9c51984ae4ca2df6a38f7f0b9d517da00 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Wed, 25 Oct 2023 11:37:42 +0200 Subject: [PATCH 03/23] fixed --output_dir / --model_dir_name confusion --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 0500b48f9b62..d19b9903020b 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1156,7 +1156,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) return prompt_embeds, pooled_prompt_embeds - # Handle instance prompt. If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), we encode the instance prompt once to avoid + # Handle instance prompt. If custom instance prompts are NOT provided + # (i.e. the instance prompt is used for all images), we encode the instance prompt once to avoid # the redundant encoding. instance_time_ids = compute_time_ids() if not args.train_text_encoder and not train_dataset.custom_instance_prompts: @@ -1178,7 +1179,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): gc.collect() torch.cuda.empty_cache() - # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), pack the statically computed variables appropriately here. This is so that we don't + # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images), + # pack the statically computed variables appropriately here. This is so that we don't # have to pass them to the dataloader. add_time_ids = instance_time_ids if args.with_prior_preservation: @@ -1258,7 +1260,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): path = os.path.basename(args.resume_from_checkpoint) else: # Get the mos recent checkpoint - dirs = os.listdir(args.model_dir_name) + dirs = os.listdir(args.output_dir) dirs = [d for d in dirs if d.startswith("checkpoint")] dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) path = dirs[-1] if len(dirs) > 0 else None @@ -1271,7 +1273,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): initial_global_step = 0 else: accelerator.print(f"Resuming from checkpoint {path}") - accelerator.load_state(os.path.join(args.model_dir_name, path)) + accelerator.load_state(os.path.join(args.output_dir, path)) global_step = int(path.split("-")[1]) initial_global_step = global_step From b1fda99fae14b655567cfab75c6f6ac1054c7476 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Mon, 13 Nov 2023 16:25:19 +0200 Subject: [PATCH 04/23] added --repeats, --adam_weight_decay_text_encoder + some fixes --- .../dreambooth/train_dreambooth_lora_sdxl.py | 80 +++++++++++-------- 1 file changed, 47 insertions(+), 33 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index d19b9903020b..5702df97b721 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -188,6 +188,11 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) + parser.add_argument("--repeats", + type=int, + default=100, + help="How many times to repeat the training data.") + parser.add_argument( "--class_data_dir", type=str, @@ -200,7 +205,7 @@ def parse_args(input_args=None): type=str, default=None, required=True, - help="The prompt with identifier specifying the instance", + help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", ) parser.add_argument( "--class_prompt", @@ -368,7 +373,6 @@ def parse_args(input_args=None): parser.add_argument( "--snr_gamma", type=float, - action="store_true", help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " "More details here: https://arxiv.org/abs/2303.09556.", ) @@ -396,7 +400,7 @@ def parse_args(input_args=None): type=str, default="prodigy", help=( - 'The optimizer type to use. Choose between ["adamW", "prodigy"]' + 'The optimizer type to use. Choose between ["AdamW", "prodigy"]' ), ) @@ -414,9 +418,9 @@ def parse_args(input_args=None): "uses the value of square root of beta2") parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") - parser.add_argument("--adam_weight_decay", type=float, default=1e-02, help="Weight decay to use. If you're using " - "the Adam optimizer you might want to " - "change value to 1e-4") + parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") + parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for " + "text_encoder") parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer and Prodigy optimizers.") @@ -536,6 +540,7 @@ def __init__( class_data_root=None, class_num=None, size=1024, + repeats=1, center_crop=False, ): self.size = size @@ -585,14 +590,14 @@ def __init__( f"column as --caption_column") self.custom_instance_prompts = None else: - caption_column = args.caption_column - if caption_column not in column_names: + if args.caption_column not in column_names: raise ValueError( f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) + self.custom_instance_prompts = dataset["train"][args.caption_column] self.num_instance_images = len(self.instance_images_path) - self._length = self.num_instance_images + self._length = self.num_instance_images * repeats if class_data_root is not None: self.class_data_root = Path(class_data_root) @@ -630,7 +635,7 @@ def __getitem__(self, index): if self.custom_instance_prompts: caption = self.custom_instance_prompts[index % self.num_instance_images] if caption: - example["instance_prompt"] = self.custom_instance_prompts[index % self.num_instance_images] + example["instance_prompt"] = caption else: example["instance_prompt"] = self.instance_prompt @@ -905,7 +910,8 @@ def main(args): xformers_version = version.parse(xformers.__version__) if xformers_version == version.parse("0.0.16"): logger.warn( - "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, " + "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." ) unet.enable_xformers_memory_efficient_attention() else: @@ -1041,19 +1047,21 @@ def load_model_hook(models, input_dir): # Optimization parameters unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} - if args.text_encoder_lr: # different learning rate for text encoder and unet - text_lora_parameters_one_with_lr = {"params": text_lora_parameters_one, "lr": args.text_encoder_lr} - text_lora_parameters_two_with_lr = {"params": text_lora_parameters_two, "lr": args.text_encoder_lr} - + if not args.train_text_encoder: + # different learning rate for text encoder and unet + text_lora_parameters_one_with_lr = {"params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate} + text_lora_parameters_two_with_lr = {"params": text_lora_parameters_two, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate} + params_to_optimize = [unet_lora_parameters_with_lr, text_lora_parameters_one_with_lr, + text_lora_parameters_two_with_lr] else: - text_lora_parameters_one_with_lr = {"params": text_lora_parameters_one, "lr": args.learning_rate} - text_lora_parameters_two_with_lr = {"params": text_lora_parameters_two, "lr": args.learning_rate} - - params_to_optimize = [unet_lora_parameters_with_lr, text_lora_parameters_one_with_lr, - text_lora_parameters_two_with_lr] + params_to_optimize = [unet_lora_parameters_with_lr] # Optimizer creation - if args.use_8bit_adam and not args.optimizer.lower() == "AdamW": + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": logger.warn(f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " f"set to {args.optimizer.lower()}") @@ -1088,7 +1096,7 @@ def load_model_hook(models, input_dir): ) - elif args.optimizer.lower() == "AdamW": + elif args.optimizer.lower() == "adamw": if args.use_8bit_adam: try: import bitsandbytes as bnb @@ -1119,6 +1127,7 @@ def load_model_hook(models, input_dir): class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_num=args.num_class_images, size=args.resolution, + repeats=args.repeats, center_crop=args.center_crop, ) @@ -1156,10 +1165,12 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) return prompt_embeds, pooled_prompt_embeds - # Handle instance prompt. If custom instance prompts are NOT provided - # (i.e. the instance prompt is used for all images), we encode the instance prompt once to avoid - # the redundant encoding. + # Handle instance prompt. instance_time_ids = compute_time_ids() + + # If no type of tuning is done on the text_encoder and custom instance prompts are NOT + # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid + # the redundant encoding. if not args.train_text_encoder and not train_dataset.custom_instance_prompts: instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings( args.instance_prompt, text_encoders, tokenizers @@ -1193,6 +1204,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if args.with_prior_preservation: prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) + # if we're optmizing the text encoder (both if instance prompt is used for all images or custom prompts) we need to tokenize and encode the + # batch prompts on all training steps else: tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt) tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt) @@ -1312,7 +1325,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): prompt_embeds, unet_add_text_embeds = compute_text_embeddings( prompts, text_encoders, tokenizers) if args.with_prior_preservation: - prompt_embeds_input = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) + prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) else: tokens_one = tokenize_prompt(tokenizer_one, prompts) @@ -1342,8 +1355,11 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # (this is the forward diffusion process) noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) - # Calculate the elements to repeat depending on the use of prior-preservation. - elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz + # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. + if not train_dataset.custom_instance_prompts: + elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz + else: + elems_to_repeat = 1 # todo - make it smarter # Predict the noise residual if not args.train_text_encoder: @@ -1351,10 +1367,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): "time_ids": add_time_ids.repeat(elems_to_repeat, 1), "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1), } - if not dataset.custom_instance_prompts: # i.e. we only encoded args.instance_prompt - prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1) - else: - prompt_embeds_input = prompt_embeds + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1) model_pred = unet( noisy_model_input, timesteps, @@ -1362,7 +1375,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): added_cond_kwargs=unet_added_conditions, ).sample else: - unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)} + unet_added_conditions = {"time_ids": add_time_ids.repeat(bsz, 1)} # todo - make it smarter + # unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)} prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], tokenizers=None, From ed9edd86101da3dd0ad03baff21463d77a89883d Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 13 Nov 2023 16:50:22 +0200 Subject: [PATCH 05/23] Update examples/dreambooth/train_dreambooth_lora_sdxl.py Co-authored-by: Patrick von Platen --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 5702df97b721..5c094f1f47c8 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1064,6 +1064,13 @@ def load_model_hook(models, input_dir): if args.use_8bit_adam and not args.optimizer.lower() == "adamw": logger.warn(f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " f"set to {args.optimizer.lower()}") + elif args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) if args.optimizer.lower() == "prodigy": try: From 6e14e9f8624a7cc44ff1eb5430534aae7add8379 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 13 Nov 2023 16:50:53 +0200 Subject: [PATCH 06/23] Update examples/dreambooth/train_dreambooth_lora_sdxl.py Co-authored-by: Patrick von Platen --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 5c094f1f47c8..6e55a386df92 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1086,7 +1086,7 @@ def load_model_hook(models, input_dir): ) if args.text_encoder_lr: logger.warn( - f"Learning rates were provided both for the unet and the text encdoer- e.g. text_encoder_lr and learning_rate" + f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr: {args.text_encoder_lr} and learning_rate: {args.learning_rate}" f"when using prodigy only learning_rate is used as the initial learning rate" ) From bde6dfa36d38f3b8eb1f69f89a1ad87b36f0cc99 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com> Date: Mon, 13 Nov 2023 16:58:31 +0200 Subject: [PATCH 07/23] Update examples/dreambooth/train_dreambooth_lora_sdxl.py Co-authored-by: Patrick von Platen --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 6e55a386df92..6f102c03a0fa 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1123,8 +1123,8 @@ def load_model_hook(models, input_dir): eps=args.adam_epsilon, ) else: - raise ValueError( - f"Unsupported choice of optimizer: {args.optimizer.lower()}. Supported optimizers include [adamW, prodigy]") + logger.warn("f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. Defaulting to "adamW") + optimizer_class = torch.optim.AdamW # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( From 16376860132a1255715dcff2a3b89dfc50157a39 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Wed, 15 Nov 2023 13:59:13 +0200 Subject: [PATCH 08/23] - import compute_snr from diffusers/training_utils.py - cluster adamw together - when using 'prodigy', if --train_text_encoder == True and --text_encoder_lr != --learning rate, changes the lr of the text encoders optimization params to be --learning_rate (otherwise errors) --- .../dreambooth/train_dreambooth_lora_sdxl.py | 99 +++++++------------ 1 file changed, 37 insertions(+), 62 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 6f102c03a0fa..9be556a80acf 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -54,7 +54,7 @@ from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict from diffusers.models.lora import LoRALinearLayer from diffusers.optimization import get_scheduler -from diffusers.training_utils import unet_lora_state_dict +from diffusers.training_utils import unet_lora_state_dict, compute_snr from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -671,31 +671,6 @@ def collate_fn(examples, with_prior_preservation=False): return batch -def compute_snr(timesteps): - """ - Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 - """ - alphas_cumprod = noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = alphas_cumprod ** 0.5 - sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - - # Expand the tensors. - # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] - alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] - sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) - - # Compute SNR. - snr = (alpha / sigma) ** 2 - return snr - - class PromptDataset(Dataset): "A simple dataset to prepare the prompts to generate class images on multiple GPUs." @@ -1047,7 +1022,7 @@ def load_model_hook(models, input_dir): # Optimization parameters unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} - if not args.train_text_encoder: + if args.train_text_encoder: # different learning rate for text encoder and unet text_lora_parameters_one_with_lr = {"params": text_lora_parameters_one, "weight_decay": args.adam_weight_decay_text_encoder, @@ -1061,16 +1036,34 @@ def load_model_hook(models, input_dir): params_to_optimize = [unet_lora_parameters_with_lr] # Optimizer creation + if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): + logger.warn(f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW") + args.optimizer = "adamw" + if args.use_8bit_adam and not args.optimizer.lower() == "adamw": logger.warn(f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " f"set to {args.optimizer.lower()}") - elif args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." - ) + + if args.optimizer.lower() == "adamw": + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + optimizer = optimizer_class( + params_to_optimize, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) if args.optimizer.lower() == "prodigy": try: @@ -1084,11 +1077,17 @@ def load_model_hook(models, input_dir): logger.warn( f"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" ) - if args.text_encoder_lr: + if args.train_text_encoder and args.text_encoder_lr: logger.warn( - f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr: {args.text_encoder_lr} and learning_rate: {args.learning_rate}" - f"when using prodigy only learning_rate is used as the initial learning rate" + f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:" + f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " + f"When using prodigy only learning_rate is used as the initial learning rate." + ) + # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be + # --learning_rate + params_to_optimize[1]["lr"] = args.learning_rate + params_to_optimize[2]["lr"] = args.learning_rate optimizer = optimizer_class( params_to_optimize, @@ -1102,30 +1101,6 @@ def load_model_hook(models, input_dir): safeguard_warmup=args.prodigy_safeguard_warmup, ) - - elif args.optimizer.lower() == "adamw": - if args.use_8bit_adam: - try: - import bitsandbytes as bnb - except ImportError: - raise ImportError( - "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." - ) - - optimizer_class = bnb.optim.AdamW8bit - else: - optimizer_class = torch.optim.AdamW - - optimizer = optimizer_class( - params_to_optimize, - betas=(args.adam_beta1, args.adam_beta2), - weight_decay=args.adam_weight_decay, - eps=args.adam_epsilon, - ) - else: - logger.warn("f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. Defaulting to "adamW") - optimizer_class = torch.optim.AdamW - # Dataset and DataLoaders creation: train_dataset = DreamBoothDataset( instance_data_root=args.instance_data_dir, @@ -1418,7 +1393,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. - snr = compute_snr(timesteps) + snr = compute_snr(noise_scheduler, timesteps) base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[ 0] / snr From 4ba9fc77af1d29754b128181d896509dc9ba42a1 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Wed, 15 Nov 2023 17:22:22 +0200 Subject: [PATCH 09/23] shape fixes when custom captions are used --- .../dreambooth/train_dreambooth_lora_sdxl.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 9be556a80acf..0d31ebc16fcf 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1122,7 +1122,7 @@ def load_model_hook(models, input_dir): ) # Computes additional embeddings/ids required by the SDXL UNet. - # regular text emebddings (when `train_text_encoder` is not True) + # regular text embeddings (when `train_text_encoder` is not True) # pooled text embeddings # time ids @@ -1339,17 +1339,19 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Calculate the elements to repeat depending on the use of prior-preservation and custom captions. if not train_dataset.custom_instance_prompts: - elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz + elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz + elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz else: - elems_to_repeat = 1 # todo - make it smarter + elems_to_repeat_text_embeds = 1 + elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz # Predict the noise residual if not args.train_text_encoder: unet_added_conditions = { - "time_ids": add_time_ids.repeat(elems_to_repeat, 1), - "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1), + "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1), + "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1), } - prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1) + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( noisy_model_input, timesteps, @@ -1357,7 +1359,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): added_cond_kwargs=unet_added_conditions, ).sample else: - unet_added_conditions = {"time_ids": add_time_ids.repeat(bsz, 1)} # todo - make it smarter + unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} # unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)} prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], @@ -1365,8 +1367,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): prompt=None, text_input_ids_list=[tokens_one, tokens_two], ) - unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)}) - prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1) + unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}) + prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions ).sample @@ -1406,9 +1408,6 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # Epsilon and sample both use the same loss weights. mse_loss_weights = base_weight - # We calculate the original loss and then we mean over the non-batch dimensions and - # rebalance the sample-wise losses with their respective loss weights. - # Finally, we take the mean of the rebalanced loss. loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean() From 3d507167678b407302ee0ec847191a44a67e73d5 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Thu, 16 Nov 2023 12:38:44 +0200 Subject: [PATCH 10/23] formatting and a little cleanup --- .../dreambooth/train_dreambooth_lora_sdxl.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 13d50812150b..880a642d2514 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -import argparse import gc import hashlib import itertools @@ -23,25 +22,25 @@ import shutil import warnings from pathlib import Path -from typing import Dict +import argparse import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers +from PIL import Image +from PIL.ImageOps import exif_transpose from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed +from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version -from PIL import Image -from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm from transformers import AutoTokenizer, PretrainedConfig -from datasets import load_dataset import diffusers from diffusers import ( @@ -65,7 +64,8 @@ def save_model_card( - repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None + repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, + vae_path=None ): img_str = "" for i, image in enumerate(images): @@ -1361,14 +1361,14 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): ).sample else: unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)} - # unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)} prompt_embeds, pooled_prompt_embeds = encode_prompt( text_encoders=[text_encoder_one, text_encoder_two], tokenizers=None, prompt=None, text_input_ids_list=[tokens_one, tokens_two], ) - unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}) + unet_added_conditions.update( + {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}) prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions From c07653fe788e99e4e6fc13e6c2ebac3e4132bc7b Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Thu, 16 Nov 2023 11:58:27 +0000 Subject: [PATCH 11/23] code styling --- .../dreambooth/train_dreambooth_lora_sdxl.py | 182 ++++++++++-------- 1 file changed, 106 insertions(+), 76 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 880a642d2514..4dc88a332deb 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -13,6 +13,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and +import argparse import gc import hashlib import itertools @@ -23,20 +24,19 @@ import warnings from pathlib import Path -import argparse import numpy as np import torch import torch.nn.functional as F import torch.utils.checkpoint import transformers -from PIL import Image -from PIL.ImageOps import exif_transpose from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version +from PIL import Image +from PIL.ImageOps import exif_transpose from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm @@ -53,10 +53,11 @@ from diffusers.loaders import LoraLoaderMixin from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict from diffusers.optimization import get_scheduler -from diffusers.training_utils import unet_lora_state_dict, compute_snr +from diffusers.training_utils import compute_snr, unet_lora_state_dict from diffusers.utils import check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available + # Will error if the minimal version of diffusers is not installed. Remove at your own risks. check_min_version("0.24.0.dev0") @@ -64,8 +65,7 @@ def save_model_card( - repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, - vae_path=None + repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None ): img_str = "" for i, image in enumerate(images): @@ -101,7 +101,7 @@ def save_model_card( def import_model_class_from_model_name_or_path( - pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" ): text_encoder_config = PretrainedConfig.from_pretrained( pretrained_model_name_or_path, subfolder=subfolder, revision=revision @@ -177,9 +177,12 @@ def parse_args(input_args=None): ) parser.add_argument( - "--image_column", type=str, default="image", help="The column of the dataset containing the target image. By " - "default, the standard Image Dataset maps out 'file_name' " - "to 'image'." + "--image_column", + type=str, + default="image", + help="The column of the dataset containing the target image. By " + "default, the standard Image Dataset maps out 'file_name' " + "to 'image'.", ) parser.add_argument( "--caption_column", @@ -188,10 +191,7 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument("--repeats", - type=int, - default=100, - help="How many times to repeat the training data.") + parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") parser.add_argument( "--class_data_dir", @@ -374,7 +374,7 @@ def parse_args(input_args=None): "--snr_gamma", type=float, help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " - "More details here: https://arxiv.org/abs/2303.09556.", + "More details here: https://arxiv.org/abs/2303.09556.", ) parser.add_argument( "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." @@ -399,36 +399,53 @@ def parse_args(input_args=None): "--optimizer", type=str, default="prodigy", - help=( - 'The optimizer type to use. Choose between ["AdamW", "prodigy"]' - ), + help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), ) parser.add_argument( - "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if " - "optimizer is not set to AdamW" - ) - - parser.add_argument("--adam_beta1", type=float, default=0.9, - help="The beta1 parameter for the Adam and Prodigy optimizers.") - parser.add_argument("--adam_beta2", type=float, default=0.999, - help="The beta2 parameter for the Adam and Prodigy optimizers.") - parser.add_argument("--prodigy_beta3", type=float, default=None, - help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " - "uses the value of square root of beta2") - parser.add_argument("--prodigy_decouple", type=bool, default=True, - help="Use AdamW style decoupled weight decay") + "--use_8bit_adam", + action="store_true", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if " "optimizer is not set to AdamW", + ) + + parser.add_argument( + "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers." + ) + parser.add_argument( + "--prodigy_beta3", + type=float, + default=None, + help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " + "uses the value of square root of beta2", + ) + parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") - parser.add_argument("--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for " - "text_encoder") + parser.add_argument( + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for " "text_encoder" + ) - parser.add_argument("--adam_epsilon", type=float, default=1e-08, - help="Epsilon value for the Adam optimizer and Prodigy optimizers.") + parser.add_argument( + "--adam_epsilon", + type=float, + default=1e-08, + help="Epsilon value for the Adam optimizer and Prodigy optimizers.", + ) - parser.add_argument("--prodigy_use_bias_correction ", type=bool, default=True, - help="Turn on Adam's bias correction. True by default.") - parser.add_argument("--prodigy_safeguard_warmup ", type=bool, default=True, - help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default.") + parser.add_argument( + "--prodigy_use_bias_correction ", + type=bool, + default=True, + help="Turn on Adam's bias correction. True by default.", + ) + parser.add_argument( + "--prodigy_safeguard_warmup ", + type=bool, + default=True, + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default.", + ) parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") @@ -533,15 +550,15 @@ class DreamBoothDataset(Dataset): """ def __init__( - self, - instance_data_root, - instance_prompt, - class_prompt, - class_data_root=None, - class_num=None, - size=1024, - repeats=1, - center_crop=False, + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, ): self.size = size self.center_crop = center_crop @@ -557,7 +574,7 @@ def __init__( # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script dataset = load_dataset( - instance_dataset_name, + args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, ) @@ -565,9 +582,10 @@ def __init__( if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - dataset = load_dataset(instance_data_root, - cache_dir=args.cache_dir, - ) + dataset = load_dataset( + instance_data_root, + cache_dir=args.cache_dir, + ) # Preprocessing the datasets. column_names = dataset["train"].column_names @@ -585,9 +603,11 @@ def __init__( self.instance_images_path = dataset["train"][image_column] if args.caption_column is None: - logger.info(f"No caption column provided, defaulting to instance_prompt for all images. If your dataset " - f"contains captions/prompts for the images, make sure to specify the " - f"column as --caption_column") + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" + ) self.custom_instance_prompts = None else: if args.caption_column not in column_names: @@ -797,7 +817,7 @@ def main(args): pipeline.to(accelerator.device) for example in tqdm( - sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process + sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): images = pipeline(example["prompt"]).images @@ -1018,33 +1038,44 @@ def load_model_hook(models, input_dir): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes ) # Optimization parameters unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate} if args.train_text_encoder: # different learning rate for text encoder and unet - text_lora_parameters_one_with_lr = {"params": text_lora_parameters_one, - "weight_decay": args.adam_weight_decay_text_encoder, - "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate} - text_lora_parameters_two_with_lr = {"params": text_lora_parameters_two, - "weight_decay": args.adam_weight_decay_text_encoder, - "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate} - params_to_optimize = [unet_lora_parameters_with_lr, text_lora_parameters_one_with_lr, - text_lora_parameters_two_with_lr] + text_lora_parameters_one_with_lr = { + "params": text_lora_parameters_one, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + text_lora_parameters_two_with_lr = { + "params": text_lora_parameters_two, + "weight_decay": args.adam_weight_decay_text_encoder, + "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate, + } + params_to_optimize = [ + unet_lora_parameters_with_lr, + text_lora_parameters_one_with_lr, + text_lora_parameters_two_with_lr, + ] else: params_to_optimize = [unet_lora_parameters_with_lr] # Optimizer creation if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"): - logger.warn(f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." - "Defaulting to adamW") + logger.warn( + f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]." + "Defaulting to adamW" + ) args.optimizer = "adamw" if args.use_8bit_adam and not args.optimizer.lower() == "adamw": - logger.warn(f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " - f"set to {args.optimizer.lower()}") + logger.warn( + f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was " + f"set to {args.optimizer.lower()}" + ) if args.optimizer.lower() == "adamw": if args.use_8bit_adam: @@ -1076,14 +1107,13 @@ def load_model_hook(models, input_dir): if args.learning_rate <= 0.1: logger.warn( - f"Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" + "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0" ) if args.train_text_encoder and args.text_encoder_lr: logger.warn( f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:" f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. " f"When using prodigy only learning_rate is used as the initial learning rate." - ) # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be # --learning_rate @@ -1303,10 +1333,10 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # encode batch prompts when custom prompts are provided for each image - if train_dataset.custom_instance_prompts: - if not args.train_text_encoder: prompt_embeds, unet_add_text_embeds = compute_text_embeddings( - prompts, text_encoders, tokenizers) + prompts, text_encoders, tokenizers + ) if args.with_prior_preservation: prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) @@ -1368,7 +1398,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): text_input_ids_list=[tokens_one, tokens_two], ) unet_added_conditions.update( - {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}) + {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)} + ) prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1) model_pred = unet( noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions @@ -1398,8 +1429,7 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) base_weight = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[ - 0] / snr + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) if noise_scheduler.config.prediction_type == "v_prediction": From 26535538c00a406b4e32f985b96f35a983c17caf Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Fri, 17 Nov 2023 14:13:25 +0200 Subject: [PATCH 12/23] --repeats default value fixed, changed to 1 --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 4dc88a332deb..e9b3d7bff51b 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -191,7 +191,7 @@ def parse_args(input_args=None): help="The column of the dataset containing the instance prompt for each image", ) - parser.add_argument("--repeats", type=int, default=100, help="How many times to repeat the training data.") + parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.") parser.add_argument( "--class_data_dir", From 15ece12eaf95374c7f7aafdb2f665e3cc5eacaf1 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Fri, 17 Nov 2023 16:26:08 +0200 Subject: [PATCH 13/23] bug fix - removed redundant lines of embedding concatenation when using prior_preservation (that duplicated class_prompt embeddings) --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index e9b3d7bff51b..af413228dfb0 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1337,17 +1337,9 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): prompt_embeds, unet_add_text_embeds = compute_text_embeddings( prompts, text_encoders, tokenizers ) - if args.with_prior_preservation: - prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0) - unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0) else: tokens_one = tokenize_prompt(tokenizer_one, prompts) tokens_two = tokenize_prompt(tokenizer_two, prompts) - if args.with_prior_preservation: - class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt) - class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt) - tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0) - tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0) # Convert images to latent space model_input = vae.encode(pixel_values).latent_dist.sample() From b9f9da610a862b519952b309fa495548cb86bbdb Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Sun, 19 Nov 2023 11:22:24 +0200 Subject: [PATCH 14/23] changed dataset loading logic according to the following usecases (to avoid unnecessary dependency on datasets)- 1. user provides --dataset_name 2. user provides local dir --instance_data_dir that contains a metadata .jsonl file 3. user provides local dir --instance_data_dir that contains only images in cases [1,2] we import datasets and use load_dataset method, in case [3] we process the data same as in the original script setting --- .../dreambooth/train_dreambooth_lora_sdxl.py | 107 +++++++++++------- 1 file changed, 64 insertions(+), 43 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index af413228dfb0..b8def51cd543 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -32,7 +32,6 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed -from datasets import load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version from PIL import Image @@ -163,12 +162,20 @@ def parse_args(input_args=None): type=str, default=None, help=( - "A folder containing the training data. Folder contents must follow the structure described in" - " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" - " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + "A folder containing the training data. " ), ) + parser.add_argument( + "--instance_data_metadata_file_name", + type=str, + default="metadata.jsonl", + help="file name(relative path) of a jsonl metadata file located in --instance_data_dir. In particular, " + "if you wish to use custom captions, a `metadata.jsonl` file must exist to provide captions for the " + "images. Ignored if `dataset_name` is specified. see " + "https://huggingface.co/docs/datasets/image_dataset#imagefolder for more information" + ) + parser.add_argument( "--cache_dir", type=str, @@ -550,15 +557,15 @@ class DreamBoothDataset(Dataset): """ def __init__( - self, - instance_data_root, - instance_prompt, - class_prompt, - class_data_root=None, - class_num=None, - size=1024, - repeats=1, - center_crop=False, + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, ): self.size = size self.center_crop = center_crop @@ -569,7 +576,12 @@ def __init__( self.custom_instance_prompts = None self.class_prompt = class_prompt + load_as_dataset = False + + # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, + # we load the training data using load_dataset if args.dataset_name is not None: + from datasets import load_dataset # Downloading and loading a dataset from the hub. # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script @@ -578,45 +590,54 @@ def __init__( args.dataset_config_name, cache_dir=args.cache_dir, ) + load_as_dataset = True else: if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") - dataset = load_dataset( - instance_data_root, - cache_dir=args.cache_dir, - ) + if args.instance_data_metadata_file_name in os.listdir(instance_data_root): + from datasets import load_dataset + dataset = load_dataset( + instance_data_root, + cache_dir=args.cache_dir, + ) + self.load_as_dataset = True - # Preprocessing the datasets. - column_names = dataset["train"].column_names + if load_as_dataset: + # Preprocessing the datasets. + column_names = dataset["train"].column_names - # 6. Get the column names for input/target. - if args.image_column is None: - image_column = column_names[0] - logger.info(f"image column defaulting to {image_column}") - else: - image_column = args.image_column - if image_column not in column_names: - raise ValueError( - f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + self.instance_images = dataset["train"][image_column] + + if args.caption_column is None: + logger.info( + "No caption column provided, defaulting to instance_prompt for all images. If your dataset " + "contains captions/prompts for the images, make sure to specify the " + "column as --caption_column" ) - self.instance_images_path = dataset["train"][image_column] + self.custom_instance_prompts = None + else: + if args.caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + self.custom_instance_prompts = dataset["train"][args.caption_column] - if args.caption_column is None: - logger.info( - "No caption column provided, defaulting to instance_prompt for all images. If your dataset " - "contains captions/prompts for the images, make sure to specify the " - "column as --caption_column" - ) - self.custom_instance_prompts = None else: - if args.caption_column not in column_names: - raise ValueError( - f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" - ) - self.custom_instance_prompts = dataset["train"][args.caption_column] + self.instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + self.custom_instance_prompts = None - self.num_instance_images = len(self.instance_images_path) + self.num_instance_images = len(self.instance_images) self._length = self.num_instance_images * repeats if class_data_root is not None: @@ -645,7 +666,7 @@ def __len__(self): def __getitem__(self, index): example = {} - instance_image = self.instance_images_path[index % self.num_instance_images] + instance_image = self.instance_images[index % self.num_instance_images] instance_image = exif_transpose(instance_image) if not instance_image.mode == "RGB": From 02947723d70628d269cf7ee0c843ea522eebb7de Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Sun, 19 Nov 2023 09:40:15 +0000 Subject: [PATCH 15/23] styling fix --- .../dreambooth/train_dreambooth_lora_sdxl.py | 32 +++++++++---------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index b8def51cd543..7ab373486748 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -161,9 +161,7 @@ def parse_args(input_args=None): "--instance_data_dir", type=str, default=None, - help=( - "A folder containing the training data. " - ), + help=("A folder containing the training data. "), ) parser.add_argument( @@ -171,9 +169,9 @@ def parse_args(input_args=None): type=str, default="metadata.jsonl", help="file name(relative path) of a jsonl metadata file located in --instance_data_dir. In particular, " - "if you wish to use custom captions, a `metadata.jsonl` file must exist to provide captions for the " - "images. Ignored if `dataset_name` is specified. see " - "https://huggingface.co/docs/datasets/image_dataset#imagefolder for more information" + "if you wish to use custom captions, a `metadata.jsonl` file must exist to provide captions for the " + "images. Ignored if `dataset_name` is specified. see " + "https://huggingface.co/docs/datasets/image_dataset#imagefolder for more information", ) parser.add_argument( @@ -557,15 +555,15 @@ class DreamBoothDataset(Dataset): """ def __init__( - self, - instance_data_root, - instance_prompt, - class_prompt, - class_data_root=None, - class_num=None, - size=1024, - repeats=1, - center_crop=False, + self, + instance_data_root, + instance_prompt, + class_prompt, + class_data_root=None, + class_num=None, + size=1024, + repeats=1, + center_crop=False, ): self.size = size self.center_crop = center_crop @@ -582,6 +580,7 @@ def __init__( # we load the training data using load_dataset if args.dataset_name is not None: from datasets import load_dataset + # Downloading and loading a dataset from the hub. # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script @@ -597,11 +596,12 @@ def __init__( if args.instance_data_metadata_file_name in os.listdir(instance_data_root): from datasets import load_dataset + dataset = load_dataset( instance_data_root, cache_dir=args.cache_dir, ) - self.load_as_dataset = True + load_as_dataset = True if load_as_dataset: # Preprocessing the datasets. From 3b86a172b9ec5128272354b1af6dcae866b03717 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Sun, 19 Nov 2023 10:01:54 +0000 Subject: [PATCH 16/23] arg name fix --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 7ab373486748..753fead2f778 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -523,10 +523,10 @@ def parse_args(input_args=None): else: args = parser.parse_args() - if args.dataset_name is None and args.train_data_dir is None: + if args.dataset_name is None and args.instance_data_dir is None: raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`") - if args.dataset_name is not None and args.train_data_dir is not None: + if args.dataset_name is not None and args.instance_data_dir is not None: raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`") env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) From 597263332d0690d12b77bcc7a2c1c7c7ce00ce1a Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Sun, 19 Nov 2023 22:42:47 +0200 Subject: [PATCH 17/23] adjusted the --repeats logic --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 753fead2f778..2f7080f6a9b9 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -617,7 +617,7 @@ def __init__( raise ValueError( f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) - self.instance_images = dataset["train"][image_column] + instance_images = dataset["train"][image_column] if args.caption_column is None: logger.info( @@ -631,14 +631,21 @@ def __init__( raise ValueError( f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" ) - self.custom_instance_prompts = dataset["train"][args.caption_column] + custom_instance_prompts = dataset["train"][args.caption_column] + # create final list of captions according to --repeats + self.custom_instance_prompts = [] + for caption in custom_instance_prompts: + self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) else: - self.instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] self.custom_instance_prompts = None + self.instance_images = [] + for img in instance_images: + self.instance_images.extend(itertools.repeat(img, repeats)) self.num_instance_images = len(self.instance_images) - self._length = self.num_instance_images * repeats + self._length = self.num_instance_images if class_data_root is not None: self.class_data_root = Path(class_data_root) From 67b572470da6adfbd3eb1b208c201acfb291d347 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Tue, 21 Nov 2023 00:26:10 +0200 Subject: [PATCH 18/23] -removed redundant arg and 'if' when loading local folder with prompts -updated readme template -some default val fixes -custom caption tests --- .../dreambooth/train_dreambooth_lora_sdxl.py | 114 ++++++++++-------- examples/test_examples.py | 41 +++++++ 2 files changed, 105 insertions(+), 50 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 2f7080f6a9b9..8d06f07ebb9f 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -64,36 +64,66 @@ def save_model_card( - repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None + repo_id: str, + images=None, + base_model=str, + train_text_encoder=False, + instance_prompt=str, + validation_prompt=str, + repo_folder=None, + vae_path=None, ): - img_str = "" + img_str = "widget:\n" if images else "" for i, image in enumerate(images): image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f"![img_{i}](./image_{i}.png)\n" + img_str += f""" + - text: '{validation_prompt if validation_prompt else ' ' }' + parameters: + negative_prompt: '-' + output: + url: f"image_{i}.png" + """ yaml = f""" --- -license: openrail++ -base_model: {base_model} -instance_prompt: {prompt} tags: - stable-diffusion-xl - stable-diffusion-xl-diffusers - text-to-image - diffusers - lora -inference: true +- template:sd-lora +widget: +{img_str} +--- +base_model: {base_model} +instance_prompt: {instance_prompt} +license: openrail++ --- """ + model_card = f""" -# LoRA DreamBooth - {repo_id} +# SDXL LoRA DreamBooth - {repo_id} -These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n -{img_str} + -LoRA for the text encoder was enabled: {train_text_encoder}. +## Model description +These are {repo_id} LoRA adaption weights for {base_model}. +The weights were trained using [DreamBooth](https://dreambooth.github.io/). +LoRA for the text encoder was enabled: {train_text_encoder}. Special VAE used for training: {vae_path}. + +## Trigger words + +You should use {instance_prompt} to trigger the image generation. + +## Download model + +Weights for this model are available in Safetensors format. + +[Download]({repo_id}/tree/main) them in the Files & versions tab. + """ with open(os.path.join(repo_folder, "README.md"), "w") as f: f.write(yaml + model_card) @@ -164,16 +194,6 @@ def parse_args(input_args=None): help=("A folder containing the training data. "), ) - parser.add_argument( - "--instance_data_metadata_file_name", - type=str, - default="metadata.jsonl", - help="file name(relative path) of a jsonl metadata file located in --instance_data_dir. In particular, " - "if you wish to use custom captions, a `metadata.jsonl` file must exist to provide captions for the " - "images. Ignored if `dataset_name` is specified. see " - "https://huggingface.co/docs/datasets/image_dataset#imagefolder for more information", - ) - parser.add_argument( "--cache_dir", type=str, @@ -378,6 +398,7 @@ def parse_args(input_args=None): parser.add_argument( "--snr_gamma", type=float, + default=None, help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " "More details here: https://arxiv.org/abs/2303.09556.", ) @@ -403,14 +424,14 @@ def parse_args(input_args=None): parser.add_argument( "--optimizer", type=str, - default="prodigy", + default="AdamW", help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'), ) parser.add_argument( "--use_8bit_adam", action="store_true", - help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if " "optimizer is not set to AdamW", + help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW", ) parser.add_argument( @@ -424,12 +445,12 @@ def parse_args(input_args=None): type=float, default=None, help="coefficients for computing the Prodidy stepsize using running averages. If set to None, " - "uses the value of square root of beta2", + "uses the value of square root of beta2. Ignored if optimizer is adamW", ) parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay") parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params") parser.add_argument( - "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for " "text_encoder" + "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder" ) parser.add_argument( @@ -440,16 +461,17 @@ def parse_args(input_args=None): ) parser.add_argument( - "--prodigy_use_bias_correction ", + "--prodigy_use_bias_correction", type=bool, default=True, - help="Turn on Adam's bias correction. True by default.", + help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW", ) parser.add_argument( - "--prodigy_safeguard_warmup ", + "--prodigy_safeguard_warmup", type=bool, default=True, - help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default.", + help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. " + "Ignored if optimizer is adamW", ) parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") @@ -574,13 +596,17 @@ def __init__( self.custom_instance_prompts = None self.class_prompt = class_prompt - load_as_dataset = False - # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory, # we load the training data using load_dataset if args.dataset_name is not None: - from datasets import load_dataset - + try: + from datasets import load_dataset + except ImportError: + raise ImportError( + "You are trying to load your data using the datasets library. If you wish to train using custom " + "captions please install the datasets library: `pip install datasets`. If you wish to load a " + "local folder containing images only, specify --instance_data_dir instead." + ) # Downloading and loading a dataset from the hub. # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script @@ -589,21 +615,6 @@ def __init__( args.dataset_config_name, cache_dir=args.cache_dir, ) - load_as_dataset = True - else: - if not self.instance_data_root.exists(): - raise ValueError("Instance images root doesn't exists.") - - if args.instance_data_metadata_file_name in os.listdir(instance_data_root): - from datasets import load_dataset - - dataset = load_dataset( - instance_data_root, - cache_dir=args.cache_dir, - ) - load_as_dataset = True - - if load_as_dataset: # Preprocessing the datasets. column_names = dataset["train"].column_names @@ -636,8 +647,10 @@ def __init__( self.custom_instance_prompts = [] for caption in custom_instance_prompts: self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) - else: + if not self.instance_data_root.exists(): + raise ValueError("Instance images root doesn't exists.") + instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] self.custom_instance_prompts = None @@ -1667,7 +1680,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): images=images, base_model=args.pretrained_model_name_or_path, train_text_encoder=args.train_text_encoder, - prompt=args.instance_prompt, + instance_prompt=args.instance_prompt, + validation_prompt=args.validation_prompt, repo_folder=args.output_dir, vae_path=args.pretrained_vae_model_name_or_path, ) diff --git a/examples/test_examples.py b/examples/test_examples.py index 89e866231e89..414cfeb4d755 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -421,6 +421,47 @@ def test_dreambooth_lora_sdxl_with_text_encoder(self): ) self.assertTrue(starts_with_unet) + def test_dreambooth_lora_sdxl_custom_captions(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + """.split() + + run_command(self._launch_args + test_args) + + def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/dreambooth/train_dreambooth_lora_sdxl.py + --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe + --dataset_name hf-internal-testing/dummy_image_text_data + --instance_prompt photo + --resolution 64 + --train_batch_size 1 + --gradient_accumulation_steps 1 + --max_train_steps 2 + --learning_rate 5.0e-04 + --scale_lr + --lr_scheduler constant + --lr_warmup_steps 0 + --output_dir {tmpdir} + --train_text_encoder + """.split() + + run_command(self._launch_args + test_args) + def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self): pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe" From 3552e88d5a92b07c690f3a5104cd2e1cdbc0e891 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Tue, 21 Nov 2023 00:36:55 +0200 Subject: [PATCH 19/23] image path fix for readme --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 8d06f07ebb9f..c4530e767753 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -81,7 +81,7 @@ def save_model_card( parameters: negative_prompt: '-' output: - url: f"image_{i}.png" + url: "image_{i}.png" """ yaml = f""" @@ -1697,4 +1697,4 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if __name__ == "__main__": args = parse_args() - main(args) + main(args) \ No newline at end of file From 0db43e6362a40adad2ff6ae96fdbf0da244b8f0b Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Mon, 20 Nov 2023 22:39:48 +0000 Subject: [PATCH 20/23] code style --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index c4530e767753..7daf574d30e1 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1697,4 +1697,4 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers): if __name__ == "__main__": args = parse_args() - main(args) \ No newline at end of file + main(args) From b68527ef8a33dfaf0555806988f07168c964da4a Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Tue, 21 Nov 2023 01:08:41 +0200 Subject: [PATCH 21/23] bug fix --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 7daf574d30e1..7b8c78a15774 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -590,8 +590,6 @@ def __init__( self.size = size self.center_crop = center_crop - self.instance_data_root = Path(instance_data_root) - self.instance_prompt = instance_prompt self.custom_instance_prompts = None self.class_prompt = class_prompt @@ -648,6 +646,7 @@ def __init__( for caption in custom_instance_prompts: self.custom_instance_prompts.extend(itertools.repeat(caption, repeats)) else: + self.instance_data_root = Path(instance_data_root) if not self.instance_data_root.exists(): raise ValueError("Instance images root doesn't exists.") From 397b4dc46ddf53f3c8098c36a08c4c70f35528d4 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Tue, 21 Nov 2023 01:51:16 +0200 Subject: [PATCH 22/23] --caption_column arg --- examples/test_examples.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/test_examples.py b/examples/test_examples.py index 414cfeb4d755..292c433a3395 100644 --- a/examples/test_examples.py +++ b/examples/test_examples.py @@ -427,6 +427,7 @@ def test_dreambooth_lora_sdxl_custom_captions(self): examples/dreambooth/train_dreambooth_lora_sdxl.py --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe --dataset_name hf-internal-testing/dummy_image_text_data + --caption_column text --instance_prompt photo --resolution 64 --train_batch_size 1 @@ -447,6 +448,7 @@ def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self): examples/dreambooth/train_dreambooth_lora_sdxl.py --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe --dataset_name hf-internal-testing/dummy_image_text_data + --caption_column text --instance_prompt photo --resolution 64 --train_batch_size 1 From 0cd1f8ae2f2bfe129e77056b35c39631f2d0ed08 Mon Sep 17 00:00:00 2001 From: Linoy Tsaban Date: Tue, 21 Nov 2023 11:35:38 +0200 Subject: [PATCH 23/23] readme fix --- examples/dreambooth/train_dreambooth_lora_sdxl.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 7b8c78a15774..d62e26765aba 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -78,10 +78,9 @@ def save_model_card( image.save(os.path.join(repo_folder, f"image_{i}.png")) img_str += f""" - text: '{validation_prompt if validation_prompt else ' ' }' - parameters: - negative_prompt: '-' output: - url: "image_{i}.png" + url: >- + "image_{i}.png" """ yaml = f"""