# Origin: git patch 19f8caea401d51c67c7a4b11ee2ecfa21e13a1e9
#   From:    Javier Vargas
#   Date:    Sat, 23 Oct 2021 16:02:27 +0200
#   Subject: [PATCH] Added dataset
#   Files:   experiment1/__init__.py (new, empty)
#            experiment1/dataset.py (new, contents below)
"""Synthetic dataset of integer pairs and their sums (experiment1/dataset.py)."""
from typing import Dict, Optional

import torch as T
from torch.utils.data import Dataset


class NumberAdd(Dataset):
    """Dataset for adding numbers.

    Every item is a pair of random integers together with their sum. All
    samples are drawn once, at construction time, and kept in memory.
    """

    def __init__(
        self,
        total_samples: int,
        *,
        low: int = -1000,
        high: int = 1000,
        seed: Optional[int] = None,
    ) -> None:
        """Generate dataset with pairs of numbers and their sum.

        Args:
            total_samples: Number of (pair, sum) items to generate.
            low: Smallest value a number in a pair can take (inclusive).
            high: Upper bound for the numbers in a pair (exclusive, as in
                ``torch.randint``).
            seed: When given, the pairs are drawn from a private generator
                seeded with this value, so two datasets built with the same
                arguments hold the same samples. When ``None`` the global
                torch RNG is used, exactly as before this parameter existed.
        """
        self.N = total_samples

        # A private generator keeps seeding local to this dataset instead of
        # mutating the process-wide RNG state.
        generator = None
        if seed is not None:
            generator = T.Generator()
            generator.manual_seed(seed)

        # One row per item, two columns: the two numbers to be added.
        self.samples = T.randint(
            low=low, high=high, size=(self.N, 2), generator=generator
        )
        # Summing over the last dimension yields one target per row.
        self.targets = T.sum(self.samples, dim=-1)

    def __getitem__(self, index: int) -> Dict[str, T.Tensor]:
        """Get a pair of sample / target.

        Returns:
            A dict with the number pair under ``"sample"`` and its sum under
            ``"target"``.
        """
        sample = self.samples[index]
        target = self.targets[index]
        return {"sample": sample, "target": target}

    def __len__(self) -> int:
        """Size of the dataset."""
        return self.N


def main() -> None:
    """Print every item of a small dataset and check each stored sum."""
    ds = NumberAdd(total_samples=100)
    # Index explicitly so that __len__ and __getitem__ are both exercised.
    for index in range(len(ds)):
        data = ds[index]
        print(data["sample"], data["target"])
        assert T.sum(data["sample"]) == data["target"]


if __name__ == "__main__":
    main()