Skip to content

Commit

Permalink
Added dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
JVGD committed Oct 23, 2021
0 parents commit 19f8cae
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 0 deletions.
Empty file added experiment1/__init__.py
Empty file.
29 changes: 29 additions & 0 deletions experiment1/dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import torch as T
from torch.utils.data import Dataset


class NumberAdd(Dataset):
"""Dataset for adding numbers"""
def __init__(self, total_samples: int) -> None:
"""Generate dataset with pairs of numbers and their sum"""
self.N = total_samples
self.samples = T.randint(low=-1000, high=1000, size=(self.N, 2))
self.targets = T.sum(self.samples, dim=-1)

def __getitem__(self, index: int) -> dict:
"""Get a pair of sample / target"""
sample = self.samples[index]
target = self.targets[index]
return {"sample": sample, "target": target}

def __len__(self) -> int:
"""Size of the dataset"""
return self.N


if __name__ == "__main__":
ds = NumberAdd(total_samples=100)
for i in range(len(ds)):
data = ds[i]
print(data["sample"], data["target"])
assert T.sum(data["sample"]) == data["target"]

0 comments on commit 19f8cae

Please sign in to comment.