Type annotation for train_dataset and eval_dataset params of Trainer incompatible with IterableDataset #29678

Closed
4 tasks
stevemadere opened this issue Mar 15, 2024 · 3 comments · Fixed by #29738

Comments

@stevemadere
Contributor

System Info

The constructor for Trainer declares the following parameters:

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module] = None,
        args: TrainingArguments = None,
        data_collator: Optional[DataCollator] = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None,
        tokenizer: Optional[PreTrainedTokenizerBase] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        compute_metrics: Optional[Callable[[EvalPrediction], Dict]] = None,
        callbacks: Optional[List[TrainerCallback]] = None,
        optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ):


evidence:
permalink to source

But the doc says

   train_dataset (`torch.utils.data.Dataset` or `torch.utils.data.IterableDataset`, *optional*):
        The dataset to use for training. If it is a [`~datasets.Dataset`], columns not accepted by the
        `model.forward()` method are automatically removed.

evidence:
permalink to source

If I try to instantiate a Trainer with an actual IterableDataset as the train_dataset parameter, Pyright complains (rightly):

Argument of type "IterableDataset" cannot be assigned to parameter
"train_dataset" of type "Dataset[Unknown] | None" in function "__init__"
Type "IterableDataset" cannot be assigned to type "Dataset[Unknown] | None"
"IterableDataset" is incompatible with "Dataset[Unknown]"
 "IterableDataset" is incompatible with "None"

Please change the type hints for these parameters to allow for IterableDataset values as well.
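
To make the request concrete, here is a minimal sketch of what a widened annotation could look like. It is an illustration under assumptions, not necessarily how the fix should be written: `Dataset` and `IterableDataset` below are hypothetical stand-in classes (so the snippet runs without torch or datasets installed), and `TrainerSketch` is not the real Trainer; it models only the two dataset parameters.

```python
from typing import Dict, Optional, Union, get_args, get_type_hints


class Dataset:
    """Hypothetical stand-in for the dataset type named in the current hint."""


class IterableDataset:
    """Hypothetical stand-in for the iterable dataset type being passed in."""


class TrainerSketch:
    """Not the real Trainer: only the two dataset parameters are modeled."""

    def __init__(
        self,
        # Widened: either dataset kind (or None) is accepted by the hint.
        train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
        eval_dataset: Optional[
            Union[Dataset, IterableDataset, Dict[str, Union[Dataset, IterableDataset]]]
        ] = None,
    ):
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset


# Inspect the resolved hint to confirm IterableDataset is part of the union.
hints = get_type_hints(TrainerSketch.__init__)
train_args = get_args(hints["train_dataset"])

# With the widened hint, a type checker has no reason to flag this call.
trainer = TrainerSketch(train_dataset=IterableDataset())
```

With a hint shaped like this, the call at the bottom type-checks without any `# type: ignore` comment.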

@muellerz , @pacman100

workaround:

When Pyright complains about your source code, follow the parameter with a comment like this:
train_dataset=my_iterable_ds, # type: ignore (IterableDataset is apparently not envisioned here)
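
An alternative to the ignore comment is `typing.cast`, which silences the checker for that one argument only. The sketch below uses hypothetical stand-ins (`Dataset`, `IterableDataset`, and a `build_trainer` function in place of the real `Trainer(...)` call) so that it runs on its own; in real code the cast target would be whatever type the Trainer hint names.

```python
from typing import Optional, cast


class Dataset:
    """Hypothetical stand-in for the type named in the Trainer hint."""


class IterableDataset:
    """Hypothetical stand-in for the iterable dataset being passed."""


def build_trainer(train_dataset: Optional[Dataset] = None) -> Optional[Dataset]:
    """Stand-in for Trainer(...): it simply returns what it was given."""
    return train_dataset


my_iterable_ds = IterableDataset()

# cast() does nothing at runtime. It only tells the type checker to treat
# the value as a Dataset, so the very same object is passed through.
passed = build_trainer(train_dataset=cast(Dataset, my_iterable_ds))
```

Like `# type: ignore`, this hides the mismatch rather than fixing it, so it is only a stopgap until the hint itself is widened.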

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

# the core issue is this
iterable_ds = any_dataset.to_iterable_dataset()
trainer = Trainer(
    model=any_model,
    train_dataset=iterable_ds,  # causes Pyright warnings
    args=training_args,
)

# pedantically complete example
from transformers import BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

# Load a portion of the dataset for quick demonstration
dataset = load_dataset('imdb', split='train[:1%]')

# Convert the loaded dataset to an iterable dataset
iterable_ds = dataset.to_iterable_dataset()

# Load a pre-trained model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased')

# Define training arguments
training_args = TrainingArguments(
    output_dir='./results',          # Output directory for model checkpoints
    num_train_epochs=1,              # Total number of training epochs
    per_device_train_batch_size=8,   # Batch size per device during training
    warmup_steps=500,                # Number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # Strength of weight decay
    logging_dir='./logs',            # Directory for storing logs
    logging_steps=10,
)

# Initialize the Trainer
trainer = Trainer(
    model=model,
    args=training_args, 
    train_dataset=iterable_ds, # Causes PyRight warnings
)

Expected behavior

No pyright warnings when passing an IterableDataset object as the train_dataset param of the Trainer constructor.

@amyeroberts
Collaborator

cc @muellerzr

@amyeroberts
Collaborator

Hi @stevemadere, thanks for opening this issue!

We want to be consistent with types as a form of documentation to help users, but we don't strictly enforce them and we don't guarantee compatibility or passing checks with tools such as mypy or pyright. We prioritise readability and practicality over full typing coverage.

If there's a place in the code you think a type should be updated, please feel free to open a PR - we'd be happy to review! 🤗

@stevemadere
Contributor Author

Here you go #29738
