Skip to content

Commit

Permalink
Merge pull request #78 from tsugumi-sys/adding-test-batch-size
Browse files Browse the repository at this point in the history
adding validation sets
  • Loading branch information
tsugumi-sys authored Jan 15, 2024
2 parents 337b6af + 40e629c commit 1ddaa91
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
8 changes: 7 additions & 1 deletion data_loaders/moving_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,14 @@ class MovingMNISTDataLoaders(BaseDataLoaders):
def __init__(
self,
train_batch_size: int,
validation_batch_size: int = 1,
input_frames: int = 10,
label_frames: int | None = None,
split_ratios: List[float] | None = None,
shuffle: bool = True,
):
self.train_batch_size = train_batch_size
self.validation_batch_size = validation_batch_size
self.input_frames = input_frames
self.label_frames = label_frames
self.shuffle = shuffle
Expand Down Expand Up @@ -76,7 +78,11 @@ def train_dataloader(self) -> DataLoader:

@property
def validation_dataloader(self) -> DataLoader:
return DataLoader(self.valid_dataset, batch_size=1, shuffle=self.shuffle)
return DataLoader(
self.valid_dataset,
batch_size=self.validation_batch_size,
shuffle=self.shuffle,
)

@property
def test_dataloader(self) -> DataLoader:
Expand Down
28 changes: 26 additions & 2 deletions tests/data_loaders/test_moving_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_MovingMNISTDataLoaders(mocked_MovingMNIST):


@patch("data_loaders.moving_mnist.MovingMNIST")
def test_MovingMNISTDataLoaders_label_frames_set(mocked_MovingMNIST):
def test_MovingMNISTDataLoaders_set_label_frames(mocked_MovingMNIST):
dataset_length = 10
train_batch_size = 2
input_frames = 10
Expand All @@ -48,7 +48,31 @@ def test_MovingMNISTDataLoaders_label_frames_set(mocked_MovingMNIST):


@patch("data_loaders.moving_mnist.MovingMNIST")
def test_MovingMNISTDataLoaders_split_ratio_set(mocked_MovingMNIST):
def test_MovingMNISTDataLoaders_set_validation_batch_size(mocked_MovingMNIST):
dataset_length = 10
train_batch_size = 2
validation_batch_size = 2
input_frames = 10
label_frames = 1
mocked_MovingMNIST.return_value = MockMovingMNIST(dataset_length=dataset_length)
dataloaders = MovingMNISTDataLoaders(
train_batch_size=train_batch_size,
validation_batch_size=validation_batch_size,
input_frames=input_frames,
label_frames=label_frames,
)
assert len(dataloaders.train_dataloader) == 4
assert len(dataloaders.validation_dataloader) == 1
assert len(dataloaders.test_dataloader) == 1
input, target = next(iter(dataloaders.train_dataloader))
assert input.size(0) == train_batch_size
assert input.size(2) == input_frames
assert target.size(0) == train_batch_size
assert target.size(2) == label_frames


@patch("data_loaders.moving_mnist.MovingMNIST")
def test_MovingMNISTDataLoaders_set_split_ratio(mocked_MovingMNIST):
dataset_length = 10
train_batch_size = 1
input_frames = 10
Expand Down

0 comments on commit 1ddaa91

Please sign in to comment.