diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py
index 53a55e38642b..92379a7f838d 100644
--- a/examples/dreambooth/train_dreambooth_lora.py
+++ b/examples/dreambooth/train_dreambooth_lora.py
@@ -66,6 +66,9 @@
 from diffusers.utils.torch_utils import is_compiled_module
 
 
+if is_wandb_available():
+    import wandb
+
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
 check_min_version("0.27.0.dev0")
 
@@ -113,6 +116,71 @@ def save_model_card(
     model_card.save(os.path.join(repo_folder, "README.md"))
 
 
+def log_validation(
+    pipeline,
+    args,
+    accelerator,
+    pipeline_args,
+    epoch,
+    is_final_validation=False,
+):
+    logger.info(
+        f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
+        f" {args.validation_prompt}."
+    )
+    # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+    scheduler_args = {}
+
+    if "variance_type" in pipeline.scheduler.config:
+        variance_type = pipeline.scheduler.config.variance_type
+
+        if variance_type in ["learned", "learned_range"]:
+            variance_type = "fixed_small"
+
+        scheduler_args["variance_type"] = variance_type
+
+    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
+
+    pipeline = pipeline.to(accelerator.device)
+    pipeline.set_progress_bar_config(disable=True)
+
+    # run inference
+    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
+
+    if args.validation_images is None:
+        images = []
+        for _ in range(args.num_validation_images):
+            with torch.cuda.amp.autocast():
+                image = pipeline(**pipeline_args, generator=generator).images[0]
+            images.append(image)
+    else:
+        images = []
+        for image in args.validation_images:
+            image = Image.open(image)
+            with torch.cuda.amp.autocast():
+                image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+            images.append(image)
+
+    for tracker in accelerator.trackers:
+        phase_name = "test" if is_final_validation else "validation"
+        if tracker.name == "tensorboard":
+            np_images = np.stack([np.asarray(img) for img in images])
+            tracker.writer.add_images(phase_name, np_images, epoch, dataformats="NHWC")
+        if tracker.name == "wandb":
+            tracker.log(
+                {
+                    phase_name: [
+                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
+                    ]
+                }
+            )
+
+    del pipeline
+    torch.cuda.empty_cache()
+
+    return images
+
+
 def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
     text_encoder_config = PretrainedConfig.from_pretrained(
         pretrained_model_name_or_path,
@@ -684,7 +752,6 @@ def main(args):
     if args.report_to == "wandb":
         if not is_wandb_available():
             raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
-        import wandb
 
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
@@ -1265,10 +1332,6 @@ def compute_text_embeddings(prompt):
         if accelerator.is_main_process:
             if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
-                logger.info(
-                    f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
-                    f" {args.validation_prompt}."
-                )
                 # create pipeline
                 pipeline = DiffusionPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
@@ -1279,26 +1342,6 @@ def compute_text_embeddings(prompt):
                     torch_dtype=weight_dtype,
                 )
 
-                # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-                scheduler_args = {}
-
-                if "variance_type" in pipeline.scheduler.config:
-                    variance_type = pipeline.scheduler.config.variance_type
-
-                    if variance_type in ["learned", "learned_range"]:
-                        variance_type = "fixed_small"
-
-                    scheduler_args["variance_type"] = variance_type
-
-                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
-                    pipeline.scheduler.config, **scheduler_args
-                )
-
-                pipeline = pipeline.to(accelerator.device)
-                pipeline.set_progress_bar_config(disable=True)
-
-                # run inference
-                generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                 if args.pre_compute_text_embeddings:
                     pipeline_args = {
                         "prompt_embeds": validation_prompt_encoder_hidden_states,
@@ -1307,36 +1350,13 @@ def compute_text_embeddings(prompt):
                 else:
                     pipeline_args = {"prompt": args.validation_prompt}
 
-                if args.validation_images is None:
-                    images = []
-                    for _ in range(args.num_validation_images):
-                        with torch.cuda.amp.autocast():
-                            image = pipeline(**pipeline_args, generator=generator).images[0]
-                        images.append(image)
-                else:
-                    images = []
-                    for image in args.validation_images:
-                        image = Image.open(image)
-                        with torch.cuda.amp.autocast():
-                            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
-                        images.append(image)
-
-                for tracker in accelerator.trackers:
-                    if tracker.name == "tensorboard":
-                        np_images = np.stack([np.asarray(img) for img in images])
-                        tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
-                    if tracker.name == "wandb":
-                        tracker.log(
-                            {
-                                "validation": [
-                                    wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                    for i, image in enumerate(images)
-                                ]
-                            }
-                        )
-
-                del pipeline
-                torch.cuda.empty_cache()
+                images = log_validation(
+                    pipeline,
+                    args,
+                    accelerator,
+                    pipeline_args,
+                    epoch,
+                )
 
     # Save the lora layers
     accelerator.wait_for_everyone()
@@ -1364,46 +1384,21 @@ def compute_text_embeddings(prompt):
             args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype
         )
 
-        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
-        scheduler_args = {}
-
-        if "variance_type" in pipeline.scheduler.config:
-            variance_type = pipeline.scheduler.config.variance_type
-
-            if variance_type in ["learned", "learned_range"]:
-                variance_type = "fixed_small"
-
-            scheduler_args["variance_type"] = variance_type
-
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args)
-
-        pipeline = pipeline.to(accelerator.device)
-
         # load attention processors
        pipeline.load_lora_weights(args.output_dir, weight_name="pytorch_lora_weights.safetensors")
 
         # run inference
         images = []
         if args.validation_prompt and args.num_validation_images > 0:
-            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
-            images = [
-                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
-                for _ in range(args.num_validation_images)
-            ]
-
-            for tracker in accelerator.trackers:
-                if tracker.name == "tensorboard":
-                    np_images = np.stack([np.asarray(img) for img in images])
-                    tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
-                if tracker.name == "wandb":
-                    tracker.log(
-                        {
-                            "test": [
-                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
-                                for i, image in enumerate(images)
-                            ]
-                        }
-                    )
+            pipeline_args = {"prompt": args.validation_prompt, "num_inference_steps": 25}
+            images = log_validation(
+                pipeline,
+                args,
+                accelerator,
+                pipeline_args,
+                epoch,
+                is_final_validation=True,
+            )
 
     if args.push_to_hub:
         save_model_card(
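
Note for reviewers: the least obvious part of the moved code is the scheduler swap at the top of the new `log_validation` helper. The sketch below isolates that logic so it can be sanity-checked outside the training script. It is a minimal sketch, not part of this PR, and the `runwayml/stable-diffusion-v1-5` model id is a placeholder assumption.

```python
# Minimal sketch (not part of this PR) of the variance_type handling that
# log_validation() now centralizes. The model id below is a placeholder.
import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

pipeline = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# Training uses the simplified objective, so a scheduler that previously
# predicted a variance ("learned"/"learned_range") must fall back to a
# fixed variance before DPMSolverMultistepScheduler is swapped in.
scheduler_args = {}
if "variance_type" in pipeline.scheduler.config:
    variance_type = pipeline.scheduler.config.variance_type
    if variance_type in ["learned", "learned_range"]:
        variance_type = "fixed_small"
    scheduler_args["variance_type"] = variance_type

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config, **scheduler_args
)
```

The same helper then serves both call sites because the tracker key is derived from `is_final_validation` (`"validation"` during training, `"test"` for the final run after the LoRA weights are reloaded).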