From 6a4b25635da84e27696c9a3022c7bae3c784de03 Mon Sep 17 00:00:00 2001 From: Steve Madere Date: Tue, 19 Mar 2024 15:27:06 -0500 Subject: [PATCH 1/2] Fixed typehint for train_dataset param in Trainer.__init__(). Added IterableDataset option. --- src/transformers/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index bef4b24c517c..3e0eabd465dd 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -52,7 +52,7 @@ from huggingface_hub import ModelCard, create_repo, upload_folder from packaging import version from torch import nn -from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler +from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, IterableDataset from . import __version__ from .configuration_utils import PretrainedConfig @@ -350,7 +350,7 @@ def __init__( model: Union[PreTrainedModel, nn.Module] = None, args: TrainingArguments = None, data_collator: Optional[DataCollator] = None, - train_dataset: Optional[Dataset] = None, + train_dataset: Optional[Union[Dataset, IterableDataset]] = None, eval_dataset: Optional[Union[Dataset, Dict[str, Dataset]]] = None, tokenizer: Optional[PreTrainedTokenizerBase] = None, model_init: Optional[Callable[[], PreTrainedModel]] = None, From afc2ed0996a37baff6705ef2630e8f709e1b5490 Mon Sep 17 00:00:00 2001 From: Steve Madere Date: Tue, 19 Mar 2024 17:19:12 -0500 Subject: [PATCH 2/2] make fixup --- src/transformers/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 3e0eabd465dd..91a80fb5dd94 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -52,7 +52,7 @@ from huggingface_hub import ModelCard, create_repo, upload_folder from packaging import version from torch import nn -from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, IterableDataset +from torch.utils.data import DataLoader, Dataset, IterableDataset, RandomSampler, SequentialSampler from . import __version__ from .configuration_utils import PretrainedConfig