diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 3ad27b90be25..06885b9207be 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -35,7 +35,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed -from datasets import load_dataset +from datasets import concatenate_datasets, load_dataset from huggingface_hub import create_repo, upload_folder from packaging import version from torchvision import transforms @@ -896,13 +896,19 @@ def preprocess_train(examples): # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 new_fingerprint = Hasher.hash(args) new_fingerprint_for_vae = Hasher.hash(vae_path) - train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) - train_dataset = train_dataset.map( + train_dataset_with_embeddings = train_dataset.map( + compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint + ) + train_dataset_with_vae = train_dataset.map( compute_vae_encodings_fn, batched=True, batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps, new_fingerprint=new_fingerprint_for_vae, ) + precomputed_dataset = concatenate_datasets( + [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1 + ) + precomputed_dataset = precomputed_dataset.with_transform(preprocess_train) del text_encoders, tokenizers, vae gc.collect() @@ -925,7 +931,7 @@ def collate_fn(examples): # DataLoaders creation: train_dataloader = torch.utils.data.DataLoader( - train_dataset, + precomputed_dataset, shuffle=True, collate_fn=collate_fn, batch_size=args.train_batch_size, @@ -976,7 +982,7 @@ def unwrap_model(model): total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps logger.info("***** Running training *****") - logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num examples = {len(precomputed_dataset)}") logger.info(f" Num Epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")