Skip to content

Commit

Permalink
add commandline param for per image loss logging
Browse files Browse the repository at this point in the history
  • Loading branch information
noskill committed Mar 12, 2024
1 parent 4cb59f4 commit a189b04
Showing 1 changed file with 18 additions and 8 deletions.
26 changes: 18 additions & 8 deletions examples/text_to_image/train_text_to_image_lora_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import logging
import math
import os
import os.path
import random
import shutil
from pathlib import Path
Expand Down Expand Up @@ -415,6 +414,11 @@ def parse_args(input_args=None):
default=4,
help=("The dimension of the LoRA update matrices."),
)
parser.add_argument(
"--debug-loss",
action="store_true",
default=False,
        help="debug loss for each image, if filenames are available in the dataset")

if input_args is not None:
args = parser.parse_args(input_args)
Expand Down Expand Up @@ -831,7 +835,6 @@ def tokenize_captions(examples, is_train=True):
)

def preprocess_train(examples):
fnames = [os.path.basename(image.filename) for image in examples[image_column]]
images = [image.convert("RGB") for image in examples[image_column]]
# image aug
original_sizes = []
Expand Down Expand Up @@ -861,7 +864,10 @@ def preprocess_train(examples):
tokens_one, tokens_two = tokenize_captions(examples)
examples["input_ids_one"] = tokens_one
examples["input_ids_two"] = tokens_two
examples["filenames"] = fnames
if args.debug_loss:
fnames = [os.path.basename(image.filename) for image in examples[image_column] if image.filename]
if fnames:
examples["filenames"] = fnames
return examples

with accelerator.main_process_first():
Expand All @@ -877,15 +883,19 @@ def collate_fn(examples):
crop_top_lefts = [example["crop_top_lefts"] for example in examples]
input_ids_one = torch.stack([example["input_ids_one"] for example in examples])
input_ids_two = torch.stack([example["input_ids_two"] for example in examples])
return {
result = {
"pixel_values": pixel_values,
"input_ids_one": input_ids_one,
"input_ids_two": input_ids_two,
"original_sizes": original_sizes,
"crop_top_lefts": crop_top_lefts,
"filenames": [example['filenames'] for example in examples]
}

filenames = [example['filenames'] for example in examples if 'filenames' in example]
if filenames:
result['filenames'] = filenames
return result

# DataLoaders creation:
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
Expand Down Expand Up @@ -1077,9 +1087,9 @@ def compute_time_ids(original_size, crops_coords_top_left):
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()

for fname in batch['filenames']:
accelerator.log({'loss_for_' + fname: loss}, step=global_step)
if args.debug_loss and 'filenames' in batch:
for fname in batch['filenames']:
accelerator.log({'loss_for_' + fname: loss}, step=global_step)
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
train_loss += avg_loss.item() / args.gradient_accumulation_steps
Expand Down

0 comments on commit a189b04

Please sign in to comment.