fix for pyTorch 1.2 #549
Conversation
The current code and this change test the training DataLoader against IterableDataset, rather than the dataset attribute of that DataLoader. I made a PR in #547 but I think it's cleaner to do it here where you are making other changes.
@@ -24,7 +32,7 @@ def init_train_dataloader(self, model):
     self.get_train_dataloader = model.train_dataloader

     # determine number of training batches
-    if isinstance(self.get_train_dataloader(), IterableDataset):
+    if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader(), IterableDataset):
Suggested change:
-    if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader(), IterableDataset):
+    if EXIST_ITER_DATASET and isinstance(self.get_train_dataloader().dataset, IterableDataset):
I think that we want to check the dataset attribute rather than the DataLoader itself (since it contains the IterableDataset). If we fix it here I can remove #547 .
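For context, a minimal illustration of the distinction (the toy dataset class here is made up for this example): a DataLoader that wraps an IterableDataset is still just a DataLoader, so the isinstance check only succeeds on its .dataset attribute.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyIterable(IterableDataset):
    """Hypothetical iterable-style dataset, only for illustration."""

    def __iter__(self):
        return iter(torch.arange(4))


loader = DataLoader(ToyIterable(), batch_size=2)

print(isinstance(loader, IterableDataset))          # False -- loader is a DataLoader
print(isinstance(loader.dataset, IterableDataset))  # True  -- the wrapped dataset
```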
@@ -167,7 +175,8 @@ def get_dataloaders(self, model):
     self.get_val_dataloaders()

     # support IterableDataset for train data
-    self.is_iterable_train_dataloader = isinstance(self.get_train_dataloader(), IterableDataset)
+    self.is_iterable_train_dataloader = (EXIST_ITER_DATASET and
+                                         isinstance(self.get_train_dataloader(), IterableDataset))
Suggested change:
-                                         isinstance(self.get_train_dataloader(), IterableDataset))
+                                         isinstance(self.get_train_dataloader().dataset, IterableDataset))
Same note as above
try:
    from torch.utils.data import IterableDataset
except ImportError:
    # loading for pyTorch 1.2
    print('Your version of pyTorch does not support `IterableDataset`, please upgrade to 1.3+')
awesome warning :)
@MikeScarp thanks for your comment, I will have a look at it... This is more about solving the issue with the missing dataset type in older PyTorch versions than doing proper testing, so I believe that you can keep the other PR :)
@williamFalcon rebased :)
great job @Borda. Good feedback @MikeScarp |
What does this PR do?
Follow-up to #546. Fixes #491.
The `IterableDataset` is just missing in lower versions so it cannot be "tested"; it was first introduced in the 1.2 release.
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.