train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
`model.forward()` method are automatically removed.
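The column-dropping behaviour this docstring describes can be illustrated with a short sketch. This is a simplification, not the actual Trainer code: it uses inspect.signature to collect the argument names a hypothetical forward() accepts and filters a batch dict down to them.

```python
import inspect

# Hypothetical forward() standing in for model.forward(); the parameter
# names mirror a typical transformers model but are illustrative only.
def forward(input_ids, attention_mask=None, labels=None):
    ...

# Columns whose names don't match a forward() parameter get dropped.
accepted = set(inspect.signature(forward).parameters)
batch = {"input_ids": [1, 2], "attention_mask": [1, 1], "text": "raw string"}
cleaned = {k: v for k, v in batch.items() if k in accepted}
print(sorted(cleaned))  # ['attention_mask', 'input_ids']
```

Here the "text" column is removed because forward() has no parameter of that name.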
If I try to instantiate a Trainer with an actual IterableDataset for the train_dataset parameter, PyRight complains (rightly) that
Argument of type "IterableDataset" cannot be assigned to parameter
"train_dataset" of type "Dataset[Unknown] | None" in function "init"
Type "IterableDataset" cannot be assigned to type "Dataset[Unknown] | None"
"IterableDataset" is incompatible with "Dataset[Unknown]"
"IterableDataset" is incompatible with "None"
Please change the type hints for these parameters to allow for IterableDataset values as well.
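For illustration, a widened annotation along these lines would satisfy pyright. This is a sketch, not the actual transformers source: it uses stand-in classes so it runs without torch installed, and the real fix would annotate Trainer.__init__ with the actual torch.utils.data types.

```python
from typing import Optional, Union

# Stand-ins for torch.utils.data.Dataset and IterableDataset, so this
# sketch is self-contained; the real fix would import the torch classes.
class Dataset: ...
class IterableDataset: ...

# Hypothetical widened signature: accepting the union lets a static
# checker pass either dataset flavour without complaint.
def make_trainer(train_dataset: Optional[Union[Dataset, IterableDataset]] = None) -> str:
    return type(train_dataset).__name__

print(make_trainer(IterableDataset()))  # IterableDataset
print(make_trainer(Dataset()))          # Dataset
```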
Who can help?
No response
Information
The official example scripts
My own modified scripts
Tasks
An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
My own task or dataset (give details below)
Reproduction
# the core issue is this
iterable_ds = any_dataset.to_iterable_dataset()
trainer = Trainer(
    model=any_model,
    args=training_args,
    train_dataset=iterable_ds,  # causes PyRight warnings
)

# pedantically complete example
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

# Load a portion of the dataset for quick demonstration
dataset = load_dataset('imdb', split='train[:1%]')

# Convert the loaded dataset to an iterable dataset
iterable_ds = dataset.to_iterable_dataset()

# Load a pre-trained model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',         # Output directory for model checkpoints
    num_train_epochs=1,             # Total number of training epochs
    per_device_train_batch_size=8,  # Batch size per device during training
    warmup_steps=500,               # Number of warmup steps for learning rate scheduler
    weight_decay=0.01,              # Strength of weight decay
    logging_dir='./logs',           # Directory for storing logs
    logging_steps=10,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=iterable_ds,  # Causes PyRight warnings
)
Expected behavior
No pyright warnings when passing an IterableDataset object as the train_dataset param of the Trainer constructor.
We want to be consistent with types as a form of documentation to help users, but we don't strictly enforce them and we don't guarantee compatibility or passing checks with tools such as mypy or pyright. We prioritise readability and practicality over full typing coverage.
If there's a place in the code you think a type should be updated, please feel free to open a PR - we'd be happy to review! 🤗
System Info
The constructor for Trainer declares the following parameters:
evidence:
permalink to source
But the doc says
evidence:
permalink to source
@muellerz, @pacman100
workaround:
when pyright complains about your source code, follow the parameter with a comment like this:
train_dataset=my_iterable_ds, # type: ignore (IterableDataset is apparently not envisioned here)
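Where a bare # type: ignore feels too blunt, typing.cast is an alternative workaround: it silences the checker for just that value and is a no-op at runtime. The sketch below uses stand-in classes in place of the torch ones so it is self-contained.

```python
from typing import cast

class Dataset: ...          # stand-in for torch.utils.data.Dataset
class IterableDataset: ...  # stand-in for torch.utils.data.IterableDataset

iterable_ds = IterableDataset()

# cast() only influences the type checker; at runtime it returns its
# argument unchanged, so the Trainer still receives the same object.
train_ds = cast(Dataset, iterable_ds)
print(train_ds is iterable_ds)  # True
```

The trade-off versus # type: ignore is that cast keeps the rest of the line fully checked, at the cost of asserting a type that is not literally true.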