diff --git a/pyproject.toml b/pyproject.toml index 8a4673b4..eb7fc978 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "squirrel-core" -version = "0.19.0" +version = "0.19.1" description = "Squirrel is a Python library that enables ML teams to share, load, and transform data in a collaborative, flexible and efficient way." authors = ["Merantix Momentum"] license = "Apache 2.0" diff --git a/squirrel/iterstream/base.py b/squirrel/iterstream/base.py index 341fc70d..a63c71d5 100644 --- a/squirrel/iterstream/base.py +++ b/squirrel/iterstream/base.py @@ -449,7 +449,7 @@ def __iter__(self) -> t.Iterator: except StopIteration: if not _started: return - current_ = iter(deepcopy(self.source)) + current_ = iter(deepcopy(self.source)) else: for _ in range(self.n): yield from iter(deepcopy(self.source)) diff --git a/test/test_iterstream/test_stream.py b/test/test_iterstream/test_stream.py index ceba7cb6..aa496643 100644 --- a/test/test_iterstream/test_stream.py +++ b/test/test_iterstream/test_stream.py @@ -191,6 +191,17 @@ def test_loop(samples: t.List[SampleType], n: int) -> None: assert IterableSource([1, 2, 3]).loop(3).collect() == [1, 2, 3, 1, 2, 3, 1, 2, 3] +def test_loop_infinite() -> None: + """Test infinite loop""" + it = IterableSource([1, 2, 3]).loop() + data = [] + for i, x in enumerate(it): + data.append(x) + if i == 8: + break + assert data == [1, 2, 3, 1, 2, 3, 1, 2, 3] + + def test_take_side_effect() -> None: """Test that take_ fetches correct number of elements from an iterator.""" lst = [1, 2, 3, 4]