diff --git a/data_loaders/moving_mnist.py b/data_loaders/moving_mnist.py index 6704312..056b907 100644 --- a/data_loaders/moving_mnist.py +++ b/data_loaders/moving_mnist.py @@ -56,7 +56,9 @@ def __init__( ) self.split_ratios = split_ratios - moving_mnist = MovingMNIST(root="./data", download=True) + moving_mnist = MovingMNIST( + root="./data", download=True, transform=lambda x: x / 255.0 + ) train_dataset, valid_dataset, test_dataset = random_split( moving_mnist, [*self.split_ratios],