add dataloader prefetch factor in training args and trainer #28498

Merged · 14 commits · Jan 23, 2024
3 changes: 3 additions & 0 deletions src/transformers/trainer.py
@@ -806,6 +806,7 @@ def get_train_dataloader(self) -> DataLoader:
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

thanks for adding this @qmeeus. I was dealing with similar IterableDataset distributed loading issues recently as well :)

one quick question: is the prefetch_factor configuration supposed to be added outside of this if branch, since this branch only applies to map-style datasets?


return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))

@@ -863,6 +864,7 @@ def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if not isinstance(eval_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params))

@@ -895,6 +897,7 @@ def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
if not isinstance(test_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_eval_sampler(test_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor

# We use the same batch_size as for eval.
return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params))
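To make the review question above concrete, here is a minimal, hypothetical sketch (names like `ToyDataset` and `build_loader` are illustrative, not the PR's code) of passing `prefetch_factor` to a `torch.utils.data.DataLoader` only when worker processes are configured, independent of whether the dataset is map-style or iterable:

# Hypothetical sketch, not the PR's implementation: only pass prefetch_factor
# to the DataLoader when worker processes are configured, for any dataset type.
from typing import Optional

import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Tiny map-style dataset used purely for illustration."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.tensor(idx)


def build_loader(num_workers: int = 2, prefetch_factor: Optional[int] = 4) -> DataLoader:
    params = {"batch_size": 2, "num_workers": num_workers}
    # prefetch_factor only makes sense with multiprocessing workers, so it is
    # guarded here rather than inside a map-style-only branch.
    if num_workers > 0 and prefetch_factor is not None:
        params["prefetch_factor"] = prefetch_factor
    return DataLoader(ToyDataset(), **params)


if __name__ == "__main__":
    for batch in build_loader():
        print(batch)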
25 changes: 24 additions & 1 deletion src/transformers/training_args.py
@@ -532,6 +532,9 @@ class TrainingArguments:
If True, the data loader will not shut down the worker processes after a dataset has been consumed once.
This keeps the workers' Dataset instances alive, which can potentially speed up training but will
increase RAM usage. Will default to `False`.
dataloader_prefetch_factor (`int`, *optional*):
Number of batches loaded in advance by each worker.
2 means there will be a total of 2 * num_workers batches prefetched across all workers.
skip_memory_metrics (`bool`, *optional*, defaults to `True`):
Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
down the training and evaluation speed.
@@ -989,7 +992,16 @@ class TrainingArguments:
)
},
)

dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={
"help": (
"Number of batches loaded in advance by each worker. "
"2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
"Default is unset"
)
},
)
past_index: int = field(
default=-1,
metadata={"help": "If >=0, uses the corresponding part of the output as the past state for next step."},
@@ -1737,6 +1749,12 @@ def __post_init__(self):
if self.use_cpu:
self.dataloader_pin_memory = False

if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
raise ValueError(
"--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
" when --dataloader_num_workers > 0."
)

if self.push_to_hub_token is not None:
warnings.warn(
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
@@ -2634,6 +2652,7 @@ def set_dataloader(
num_workers: int = 0,
pin_memory: bool = True,
persistent_workers: bool = False,
prefetch_factor: Optional[int] = None,
auto_find_batch_size: bool = False,
ignore_data_skip: bool = False,
sampler_seed: Optional[int] = None,
@@ -2654,6 +2673,9 @@
If True, the data loader will not shut down the worker processes after a dataset has been consumed
once. This keeps the workers' Dataset instances alive, which can potentially speed up training but
will increase RAM usage. Will default to `False`.
prefetch_factor (`int`, *optional*):
Number of batches loaded in advance by each worker.
2 means there will be a total of 2 * num_workers batches prefetched across all workers.
auto_find_batch_size (`bool`, *optional*, defaults to `False`):
Whether to find a batch size that will fit into memory automatically through exponential decay,
avoiding CUDA Out-of-Memory errors. Requires accelerate to be installed (`pip install accelerate`)
@@ -2684,6 +2706,7 @@ def set_dataloader(
self.dataloader_num_workers = num_workers
self.dataloader_pin_memory = pin_memory
self.dataloader_persistent_workers = persistent_workers
self.dataloader_prefetch_factor = prefetch_factor
self.auto_find_batch_size = auto_find_batch_size
self.ignore_data_skip = ignore_data_skip
self.data_seed = sampler_seed
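Finally, a short usage sketch of the new argument, assuming it lands exactly as in this diff; the `"out"` output directory is a placeholder, and the second call demonstrates the `__post_init__` guard above:

# Usage sketch assuming the dataloader_prefetch_factor argument from this PR;
# "out" is a placeholder output directory.
from transformers import TrainingArguments

# Valid: prefetching only applies when worker processes load the data,
# so roughly 2 * 4 = 8 batches are kept in flight across all workers.
args = TrainingArguments(
    output_dir="out",
    dataloader_num_workers=4,
    dataloader_prefetch_factor=2,
)
print(args.dataloader_prefetch_factor)

# Invalid: __post_init__ raises because no workers are configured.
try:
    TrainingArguments(
        output_dir="out",
        dataloader_num_workers=0,
        dataloader_prefetch_factor=2,
    )
except ValueError as err:
    print(err)

Per the diff, the same value can also be set through the new `prefetch_factor` parameter of `TrainingArguments.set_dataloader`.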