Skip to content

Commit 085b4d9

Browse files
authored
fix: shard batch iterator can read partial batches (#1889)
1 parent 628f7a3 commit 085b4d9

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

python/python/lance/_dataset/sharded_batch_iterator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def _gen_ranges():
141141
total,
142142
self._world_size * self._batch_size,
143143
):
144-
yield start, start + self._batch_size
144+
yield start, min(start + self._batch_size, total)
145145

146146
return self._ds._ds.take_scan(
147147
_gen_ranges(),

python/python/tests/test_dataset.py

+19
Original file line numberDiff line numberDiff line change
@@ -1416,3 +1416,22 @@ def test_sharded_iterator_batches(tmp_path: Path):
14161416
for j in range(i, i + BATCH_SIZE)
14171417
]
14181418
)
1419+
1420+
1421+
def test_sharded_iterator_non_full_batch(tmp_path: Path):
1422+
arr = pa.array(range(1186))
1423+
tbl = pa.table({"a": arr})
1424+
1425+
ds = lance.write_dataset(tbl, tmp_path)
1426+
shard_datast = ShardedBatchIterator(
1427+
ds,
1428+
1,
1429+
2,
1430+
columns=["a"],
1431+
batch_size=100,
1432+
granularity="batch",
1433+
)
1434+
batches = pa.concat_arrays([b["a"] for b in shard_datast])
1435+
1436+
# Can read partial batches
1437+
assert len(set(range(1100, 1186)) - set(batches.to_pylist())) == 0

0 commit comments

Comments
 (0)