
Commit f1cca68 ("add test")
committed Feb 2, 2024
1 parent: 9e7482f

2 files changed: +42 −9 lines

python/python/lance/sampler.py (+21 −7)

@@ -18,6 +18,7 @@
 
 import gc
 import logging
+import warnings
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
 from heapq import heappush, heappushpop
@@ -266,19 +267,32 @@ def __init__(self, rank: int, world_size: int):
         self._world_size = world_size
 
     def __call__(
-        self, dataset: lance.LanceDataset, *args, **kwargs
+        self,
+        dataset: lance.LanceDataset,
+        *args,
+        batch_size: int = 128,
+        columns: Optional[List[str]] = None,
+        batch_readahead: int = 16,
+        with_row_id: Optional[bool] = None,
+        **kwargs,
     ) -> Generator[lance.RecordBatch, None, None]:
-        total = self._ds.count_rows()
+        total = dataset.count_rows()
+
+        if with_row_id is not None:
+            warnings.warn(
+                "with_row_id is not supported for ShardedBatchSampler",
+            )
 
         def _gen_ranges():
             for start in range(
-                self._rank * self._batch_size,
+                self._rank * batch_size,
                 total,
-                self._world_size * self._batch_size,
+                self._world_size * batch_size,
             ):
-                yield start, min(start + self._batch_size, total)
+                yield start, min(start + batch_size, total)
 
-        return dataset.take_scan(
+        return dataset._ds.take_scan(
             _gen_ranges(),
-            columns=self._columns,
+            columns=columns,
+            batch_readahead=batch_readahead,
         )
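For context, the change moves the sharding parameters (batch_size, columns, batch_readahead) from sampler state into the `__call__` signature, while the round-robin assignment of row ranges to ranks stays the same. Below is a minimal standalone sketch of that assignment, assuming nothing beyond the arithmetic visible in the diff; `gen_shard_ranges` is a hypothetical helper written only for illustration and is not part of the Lance API.

# Minimal sketch of the round-robin batch sharding shown in the diff above.
# `gen_shard_ranges` is a hypothetical helper, not a Lance function.
from typing import Generator, Tuple


def gen_shard_ranges(
    total: int, rank: int, world_size: int, batch_size: int
) -> Generator[Tuple[int, int], None, None]:
    # Rank r takes batches r, r + world_size, r + 2 * world_size, ...
    # Each yielded pair is a half-open [start, end) row range.
    for start in range(rank * batch_size, total, world_size * batch_size):
        yield start, min(start + batch_size, total)


if __name__ == "__main__":
    # With 2000 rows, batch_size=25, world_size=2, rank 1 receives rows
    # 25-49, 75-99, ... i.e. the ids where (i // 25) % 2 == 1.
    ranges = list(gen_shard_ranges(2000, rank=1, world_size=2, batch_size=25))
    assert ranges[0] == (25, 50)
    assert ranges[1] == (75, 100)
    assert len(ranges) == 40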

python/python/tests/torch_tests/test_data.py (+21 −2)

@@ -13,17 +13,17 @@
 # limitations under the License.
 
 import shutil
-from pathlib import Path
 from itertools import chain
+from pathlib import Path
 
 import lance
 import numpy as np
 import pyarrow as pa
 import pytest
+from lance.sampler import ShardedBatchSampler, ShardedFragmentSampler
 
 torch = pytest.importorskip("torch")
 from lance.torch.data import LanceDataset  # noqa: E402
-from lance.sampler import ShardedFragmentSampler, ShardedBatchSampler, FullScanSampler
 
 
 def test_iter_over_dataset(tmp_path):
@@ -131,3 +131,22 @@ def test_sample_fragments(tmp_path: Path):
 
     all_ids = list(chain.from_iterable([batch["ids"].cpu().numpy() for batch in ds]))
     assert all_ids == [i for i in range(2000) if i // 100 % 2 == 1]
+
+
+def test_sample_batches(tmp_path: Path):
+    arr = pa.array(range(2000))
+    tbl = pa.Table.from_arrays([arr], ["ids"])
+
+    # Write 20 files
+    lance.write_dataset(tbl, tmp_path, max_rows_per_file=100)
+
+    ds = LanceDataset(
+        tmp_path,
+        batch_size=25,
+        columns=["ids"],
+        with_row_id=True,
+        sampler=ShardedBatchSampler(rank=1, world_size=2),
+    )
+
+    all_ids = list(chain.from_iterable([batch.cpu().numpy() for batch in ds]))
+    assert all_ids == [i for i in range(2000) if i // 25 % 2 == 1]
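A quick sanity check on the expected output of this test: 2000 rows at batch_size=25 gives 80 batches; with world_size=2, rank 1 takes the odd-indexed batches (40 of them), so the surviving ids are exactly those with (i // 25) % 2 == 1, which is what the final assertion encodes.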
