Map speedup #6745

Merged
11 commits merged on Mar 1, 2024
examples/text_to_image/train_text_to_image_sdxl.py (18 changes: 13 additions & 5 deletions)
@@ -35,7 +35,7 @@
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from datasets import load_dataset
+from datasets import load_dataset, concatenate_datasets
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
@@ -896,13 +896,21 @@ def preprocess_train(examples):
     # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
     new_fingerprint = Hasher.hash(args)
     new_fingerprint_for_vae = Hasher.hash("vae")
-    train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
-    train_dataset = train_dataset.map(
+    train_dataset_with_embeddings = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
+    train_dataset_with_vae = train_dataset.map(
         compute_vae_encodings_fn,
         batched=True,
         batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
         new_fingerprint=new_fingerprint_for_vae,
     )
+    precomputed_dataset = concatenate_datasets(
+        [
+            train_dataset_with_embeddings,
+            train_dataset_with_vae.remove_columns(['image', 'text'])
+        ],
+        axis=1
+    )
+    precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)

     del text_encoders, tokenizers, vae
     gc.collect()
@@ -925,7 +933,7 @@ def collate_fn(examples):

     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
+        precomputed_dataset,
         shuffle=True,
         collate_fn=collate_fn,
         batch_size=args.train_batch_size,
@@ -976,7 +984,7 @@ def unwrap_model(model):
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

     logger.info("***** Running training *****")
-    logger.info(f" Num examples = {len(train_dataset)}")
+    logger.info(f" Num examples = {len(precomputed_dataset)}")
     logger.info(f" Num Epochs = {args.num_train_epochs}")
     logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")