log loss per image #7278

Merged: 5 commits, Mar 14, 2024
10 changes: 9 additions & 1 deletion examples/text_to_image/train_text_to_image_lora_sdxl.py
@@ -19,6 +19,7 @@
import logging
import math
import os
import os.path
Member:
We already have the `os` module imported; I think we can reuse it.
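For context, a minimal sketch of the reviewer's point: `import os` already makes the `os.path` submodule available as an attribute, so the extra `import os.path` line is redundant. (The path below is a made-up example.)

```python
import os

# os.path is reachable through the plain os import; no separate import needed.
fname = os.path.basename("data/images/cat.png")
print(fname)  # → cat.png
```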

import random
import shutil
from pathlib import Path
@@ -586,6 +587,7 @@ def main(args):
# Move unet, vae and text_encoder to device and cast to weight_dtype
# The VAE is in float32 to avoid NaN losses.
unet.to(accelerator.device, dtype=weight_dtype)

if args.pretrained_vae_model_name_or_path is None:
vae.to(accelerator.device, dtype=torch.float32)
else:
@@ -829,6 +831,7 @@ def tokenize_captions(examples, is_train=True):
)

def preprocess_train(examples):
fnames = [os.path.basename(image.filename) for image in examples[image_column]]
images = [image.convert("RGB") for image in examples[image_column]]
# image aug
original_sizes = []
@@ -858,13 +861,14 @@ def preprocess_train(examples):
tokens_one, tokens_two = tokenize_captions(examples)
examples["input_ids_one"] = tokens_one
examples["input_ids_two"] = tokens_two
examples["filenames"] = fnames
return examples

with accelerator.main_process_first():
if args.max_train_samples is not None:
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
# Set the training transforms
train_dataset = dataset["train"].with_transform(preprocess_train)
train_dataset = dataset["train"].with_transform(preprocess_train, output_all_columns=True)

def collate_fn(examples):
pixel_values = torch.stack([example["pixel_values"] for example in examples])
@@ -879,6 +883,7 @@ def collate_fn(examples):
"input_ids_two": input_ids_two,
"original_sizes": original_sizes,
"crop_top_lefts": crop_top_lefts,
"filenames": [example['filenames'] for example in examples]
}
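A minimal, hypothetical sketch of how the new `filenames` key flows through collation — tensor fields are replaced with plain lists here, and the example entries are made up:

```python
# Hypothetical stand-in examples; the real batch entries also carry tensors.
examples = [
    {"pixel_values": [0.1], "filenames": "cat.png"},
    {"pixel_values": [0.2], "filenames": "dog.png"},
]

def collate_filenames(examples):
    # Mirrors the diff's list comprehension for the "filenames" key.
    return {"filenames": [example["filenames"] for example in examples]}

print(collate_filenames(examples))  # → {'filenames': ['cat.png', 'dog.png']}
```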

# DataLoaders creation:
@@ -1073,10 +1078,13 @@ def compute_time_ids(original_size, crops_coords_top_left):
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()

for fname in batch['filenames']:
accelerator.log({'loss_for_' + fname: loss}, step=global_step)
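Note that at this point `loss` has already been reduced to the batch-mean scalar, so every filename in the batch is logged with the same value. A sketch of the log payload built inside that loop, with hypothetical filenames and a stand-in scalar:

```python
# Sketch of the per-image log dict; filenames and loss value are made up.
batch_filenames = ["cat.png", "dog.png"]
loss_value = 0.042  # stand-in for the scalar batch-mean loss

logs = {"loss_for_" + fname: loss_value for fname in batch_filenames}
print(logs)  # → {'loss_for_cat.png': 0.042, 'loss_for_dog.png': 0.042}
```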
Member:
Should we accept a CLI argument for this? That would be nice, no?

Contributor (author):
I agree, it's nice to have.
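A minimal argparse sketch of such an opt-in flag — the name `--debug_loss` is an assumption for illustration, not taken from this diff:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--debug_loss",
    action="store_true",  # defaults to False; per-image logging stays off unless requested
    help="If set, log a per-image loss entry keyed by filename.",
)

args = parser.parse_args(["--debug_loss"])
print(args.debug_loss)  # → True
```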

# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
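The running `train_loss` update above can be sketched in plain Python — the loss value and step count below are hypothetical:

```python
# Hypothetical values: the gathered mean loss is divided by the number of
# gradient-accumulation steps before being added to the running total, so the
# sum over one accumulation cycle approximates the cycle's mean loss.
gradient_accumulation_steps = 4
avg_loss = 0.8  # stand-in for accelerator.gather(...).mean().item()
train_loss = 0.0
train_loss += avg_loss / gradient_accumulation_steps
print(train_loss)  # → 0.2
```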


# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients: