Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TrainingArguments regression with torch <2.0.0 for dataloader_prefetch_factor #29447

Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
ACCELERATE_MIN_VERSION,
ExplicitEnum,
cached_property,
get_torch_version,
is_accelerate_available,
is_safetensors_available,
is_sagemaker_dp_enabled,
Expand Down Expand Up @@ -1023,13 +1024,13 @@ class TrainingArguments:
)
},
)
dataloader_prefetch_factor: int = field(
default=None,
dataloader_prefetch_factor: Optional[int] = field(
default=None if is_torch_available() and version.parse(get_torch_version()) >= version.parse("2.0.0") else 2,
ringohoffman marked this conversation as resolved.
Show resolved Hide resolved
metadata={
"help": (
"Number of batches loaded in advance by each worker. "
"2 means there will be a total of 2 * num_workers batches prefetched across all workers. "
"Default is unset"
"Default is None for PyTorch >= 2.0.0 and otherwise 2."
)
},
)
Expand Down Expand Up @@ -1807,12 +1808,6 @@ def __post_init__(self):
if self.use_cpu:
self.dataloader_pin_memory = False

if self.dataloader_num_workers == 0 and self.dataloader_prefetch_factor is not None:
raise ValueError(
"--dataloader_prefetch_factor can only be set when data is loaded in a different process, i.e."
" when --dataloader_num_workers > 1."
)
ringohoffman marked this conversation as resolved.
Show resolved Hide resolved

if self.push_to_hub_token is not None:
warnings.warn(
"`--push_to_hub_token` is deprecated and will be removed in version 5 of 🤗 Transformers. Use "
Expand Down
Loading