Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: always return correct batch size #3066

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/python/lance/sampler.py
Original file line number Diff line number Diff line change
@@ -139,6 +139,8 @@ def _filtered_efficient_sample(
elif filter is not None:
raise NotImplementedError(f"Can't yet run filter <{filter}> in-memory")
if table.num_rows > 0:
if table.num_rows > remaining_rows:
table = table.slice(0, remaining_rows)
tables.append(table)
remaining_rows -= table.num_rows
remaining_in_batch = remaining_in_batch - table.num_rows
19 changes: 14 additions & 5 deletions python/python/lance/torch/data.py
Original file line number Diff line number Diff line change
@@ -31,6 +31,19 @@
__all__ = ["LanceDataset"]


def _fsl_to_tensor(arr: pa.FixedSizeListArray, dimension: int) -> torch.Tensor:
    """Convert an Arrow fixed-size-list (FSL) array into a 2D torch tensor.

    ``FixedSizeListArray.values`` does not take the array's offset/length
    into account, so when ``arr`` is a slice of a larger array we have to cut
    the matching window out of the child values ourselves.

    Parameters
    ----------
    arr : pa.FixedSizeListArray
        The (possibly sliced) fixed-size-list array to convert.
    dimension : int
        Number of child values per list entry; becomes the second axis of
        the returned tensor.

    Returns
    -------
    torch.Tensor
        A tensor reshaped to ``(-1, dimension)`` built from the child values
        that belong to ``arr``.
    """
    # Translate the list-level offset/length into child-value coordinates.
    first_value = arr.offset * dimension
    value_count = len(arr) * dimension
    flat_values = arr.values.slice(first_value, value_count)

    # Go through numpy, insisting on a zero-copy conversion, then fold the
    # flat values into one row per list entry.
    matrix = flat_values.to_numpy(zero_copy_only=True).reshape(-1, dimension)
    return torch.from_numpy(matrix)


def _to_tensor(
batch: pa.RecordBatch,
*,
@@ -54,11 +67,7 @@ def _to_tensor(
pa.types.is_floating(arr.type.value_type)
or pa.types.is_integer(arr.type.value_type)
):
np_tensor = arr.values.to_numpy(zero_copy_only=True).reshape(
-1, arr.type.list_size
)
tensor = torch.from_numpy(np_tensor)
del np_tensor
tensor = _fsl_to_tensor(arr, arr.type.list_size)
elif (
pa.types.is_integer(arr.type)
or pa.types.is_floating(arr.type)
26 changes: 26 additions & 0 deletions python/python/tests/torch_tests/test_data.py
Original file line number Diff line number Diff line change
@@ -185,6 +185,32 @@ def test_sample_batches(tmp_path: Path):
assert all_ids == [i for i in range(2000) if i // 25 % 2 == 1]


def test_filtered_sampling_odd_batch_size(tmp_path: Path):
    """Filtered sampling must still hand back batches of the requested size.

    The dataset is written with 200 rows per file and read with a batch size
    of 38, so the two do not line up evenly; the first batch should
    nevertheless have shape ``(38, 3)``.
    """
    num_rows = 10000
    batch_size = 38

    vectors = pa.array(
        [[1.0, 2.0, 3.0] for _ in range(num_rows)], pa.list_(pa.float32(), 3)
    )
    flags = [i % 2 for i in range(num_rows)]
    tbl = pa.Table.from_pydict({"vector": vectors, "filterme": flags})

    lance.write_dataset(tbl, tmp_path, max_rows_per_file=200)

    ds = LanceDataset(
        tmp_path,
        batch_size=batch_size,
        columns=["vector"],
        samples=batch_size * 256,
        filter="vector is not null",
    )

    first_batch = next(iter(ds))

    assert first_batch.shape[0] == 38
    assert first_batch.shape[1] == 3


def test_sample_batches_with_filter(tmp_path: Path):
NUM_ROWS = 10000
tbl = pa.Table.from_pydict(