diff --git a/test/test_datasets.py b/test/test_datasets.py index 74d03e7ea15..48d08b846de 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -1504,14 +1504,16 @@ class MovingMNISTTestCase(datasets_utils.DatasetTestCase): ADDITIONAL_CONFIGS = combinations_grid(split=(None, "train", "test"), split_ratio=(10, 1, 19)) + _NUM_FRAMES = 20 + def inject_fake_data(self, tmpdir, config): base_folder = os.path.join(tmpdir, self.DATASET_CLASS.__name__) os.makedirs(base_folder, exist_ok=True) - num_samples = 20 + num_samples = 5 data = np.concatenate( [ np.zeros((config["split_ratio"], num_samples, 64, 64)), - np.ones((20 - config["split_ratio"], num_samples, 64, 64)), + np.ones((self._NUM_FRAMES - config["split_ratio"], num_samples, 64, 64)), ] ) np.save(os.path.join(base_folder, "mnist_test_seq.npy"), data) @@ -1519,14 +1521,13 @@ def inject_fake_data(self, tmpdir, config): @datasets_utils.test_all_configs def test_split(self, config): - if config["split"] is None: - return - - with self.create_dataset(config) as (dataset, info): + with self.create_dataset(config) as (dataset, _): if config["split"] == "train": assert (dataset.data == 0).all() - else: + elif config["split"] == "test": assert (dataset.data == 1).all() + else: + assert dataset.data.size()[1] == self._NUM_FRAMES class DatasetFolderTestCase(datasets_utils.ImageDatasetTestCase): diff --git a/torchvision/datasets/moving_mnist.py b/torchvision/datasets/moving_mnist.py index afff0bfa3b9..ac5a2b1503d 100644 --- a/torchvision/datasets/moving_mnist.py +++ b/torchvision/datasets/moving_mnist.py @@ -58,7 +58,7 @@ def __init__( data = torch.from_numpy(np.load(os.path.join(self._base_folder, self._filename))) if self.split == "train": data = data[: self.split_ratio] - else: + elif self.split == "test": data = data[self.split_ratio :] self.data = data.transpose(0, 1).unsqueeze(2).contiguous()