Skip to content

Commit

Permalink
dataset returning tuple
Browse files Browse the repository at this point in the history
  • Loading branch information
JVGD committed Nov 1, 2021
1 parent ec87042 commit 2eb9816
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions experiment1/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ def __init__(self, total_samples: int) -> None:
self.samples = T.randint(low=-1000, high=1000, size=(self.N, 2))
self.targets = T.sum(self.samples, dim=-1)

def __getitem__(self, index: int) -> dict:
def __getitem__(self, index: int) -> tuple:
"""Get a pair of sample / target"""
sample = self.samples[index]
target = self.targets[index]
return {"sample": sample, "target": target}
return sample, target

def __len__(self) -> int:
"""Size of the dataset"""
Expand All @@ -24,6 +24,6 @@ def __len__(self) -> int:
if __name__ == "__main__":
ds = NumberAdd(total_samples=100)
for i in range(len(ds)):
data = ds[i]
print(data["sample"], data["target"])
assert T.sum(data["sample"]) == data["target"]
sample, target = ds[i]
print(sample, target)
assert T.sum(sample) == target

0 comments on commit 2eb9816

Please sign in to comment.