From b548ab2927848acc96ac6e9b0e508cc2e45c210a Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Fri, 28 Oct 2022 02:16:30 +0900 Subject: [PATCH] Fix padding in dreambooth --- examples/dreambooth/train_dreambooth.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 8b0b20e07bb5..4f905702e960 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -496,7 +496,12 @@ def collate_fn(examples): pixel_values = torch.stack(pixel_values) pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + input_ids = tokenizer.pad( + {"input_ids": input_ids}, + padding="max_length", + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids batch = { "input_ids": input_ids,