diff --git a/experiment1/dataset.py b/experiment1/dataset.py index a48d7df..7d55057 100644 --- a/experiment1/dataset.py +++ b/experiment1/dataset.py @@ -10,11 +10,11 @@ def __init__(self, total_samples: int) -> None: self.samples = T.randint(low=-1000, high=1000, size=(self.N, 2)) self.targets = T.sum(self.samples, dim=-1) - def __getitem__(self, index: int) -> dict: + def __getitem__(self, index: int) -> tuple: """Get a pair of sample / target""" sample = self.samples[index] target = self.targets[index] - return {"sample": sample, "target": target} + return sample, target def __len__(self) -> int: """Size of the dataset""" @@ -24,6 +24,6 @@ def __len__(self) -> int: if __name__ == "__main__": ds = NumberAdd(total_samples=100) for i in range(len(ds)): - data = ds[i] - print(data["sample"], data["target"]) - assert T.sum(data["sample"]) == data["target"] \ No newline at end of file + sample, target = ds[i] + print(sample, target) + assert T.sum(sample) == target \ No newline at end of file