feat: handle nulls when creating indices with cuda #2910

Merged (3 commits) on Sep 19, 2024

84 changes: 73 additions & 11 deletions python/python/lance/sampler.py
@@ -6,6 +6,7 @@

import gc
import logging
import math
import random
import warnings
from abc import ABC, abstractmethod
@@ -15,6 +16,7 @@
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar, Union

import pyarrow as pa
import pyarrow.compute as pc

import lance
from lance.dependencies import numpy as np
@@ -105,12 +107,63 @@ def _efficient_sample(
del tbl


def _filtered_efficient_sample(
dataset: lance.LanceDataset,
n: int,
columns: Optional[Union[List[str], Dict[str, str]]],
batch_size: int,
target_takes: int,
filter: str,
) -> Generator[pa.RecordBatch, None, None]:
total_records = len(dataset)
shard_size = math.ceil(n / target_takes)
num_shards = math.ceil(total_records / shard_size)

shards = list(range(num_shards))
random.shuffle(shards)

tables = []
remaining_rows = n
remaining_in_batch = min(batch_size, n)
for shard in shards:
start = shard * shard_size
end = min(start + shard_size, total_records)
table = dataset.to_table(
columns=columns,
offset=start,
limit=(end - start),
batch_size=shard_size,
)
if len(columns) == 1 and filter.lower() == f"{columns[0]} is not null":
table = pc.drop_null(table)
elif filter is not None:
raise NotImplementedError(f"Can't yet run filter <{filter}> in-memory")
if table.num_rows > 0:
tables.append(table)
remaining_rows -= table.num_rows
remaining_in_batch = remaining_in_batch - table.num_rows
if remaining_in_batch <= 0:
combined = pa.concat_tables(tables).combine_chunks()
batch = combined.slice(0, batch_size).to_batches()[0]
yield batch
remaining_in_batch = min(batch_size, remaining_rows)
if len(combined) > batch_size:
leftover = combined.slice(batch_size)
tables = [leftover]
remaining_in_batch -= len(leftover)
else:
tables = []
if remaining_rows <= 0:
break


def maybe_sample(
dataset: Union[str, Path, lance.LanceDataset],
n: int,
columns: Union[list[str], dict[str, str], str],
batch_size: int = 10240,
max_takes: int = 2048,
filt: Optional[str] = None,
) -> Generator[pa.RecordBatch, None, None]:
"""Sample n records from the dataset.

@@ -129,6 +182,10 @@ def maybe_sample(
This is employed to minimize the number of random reads necessary for sampling.
A sufficiently large value can provide an effective random sample without
the need for excessive random reads.
filter : str, optional
The filter to apply to the dataset, by default None. If a filter is provided,
then we will first load all row ids in memory and then batch through the ids
in random order until enough matches have been found.

Returns
-------
@@ -143,18 +200,23 @@

if n >= len(dataset):
# Dont have enough data in the dataset. Just do a full scan
yield from dataset.to_batches(columns=columns, batch_size=batch_size)
yield from dataset.to_batches(
columns=columns, batch_size=batch_size, filter=filt
)
elif filt is not None:
yield from _filtered_efficient_sample(
dataset, n, columns, batch_size, max_takes, filt
)
elif n > max_takes:
yield from _efficient_sample(dataset, n, columns, batch_size, max_takes)
else:
if n > max_takes:
yield from _efficient_sample(dataset, n, columns, batch_size, max_takes)
else:
choices = np.random.choice(len(dataset), n, replace=False)
idx = 0
while idx < len(choices):
end = min(idx + batch_size, len(choices))
tbl = dataset.take(choices[idx:end], columns=columns).combine_chunks()
yield tbl.to_batches()[0]
idx += batch_size
choices = np.random.choice(len(dataset), n, replace=False)
idx = 0
while idx < len(choices):
end = min(idx + batch_size, len(choices))
tbl = dataset.take(choices[idx:end], columns=columns).combine_chunks()
yield tbl.to_batches()[0]
idx += batch_size


T = TypeVar("T")
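
For orientation, a minimal usage sketch of the new filtered sampling path. The dataset URI and column name below are hypothetical; maybe_sample, its filt parameter, and the restriction to "<column> is not null" filters all come from the diff above.

import lance
from lance.sampler import maybe_sample

# Hypothetical dataset and column, purely for illustration.
ds = lance.dataset("/tmp/vectors.lance")

# With a filter set, sampling goes through _filtered_efficient_sample: shards
# of the dataset are scanned in random order and nulls are dropped until
# enough rows have been collected to fill the requested batches.
for batch in maybe_sample(
    ds, n=4096, columns=["vec"], batch_size=1024, filt="vec is not null"
):
    print(batch.num_rows)
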
4 changes: 1 addition & 3 deletions python/python/lance/torch/data.py
@@ -233,9 +233,6 @@ def __init__(
warnings.warn("rank and world_size are deprecated", DeprecationWarning)
self.sampler: Optional[Sampler] = sampler

if filter is not None and self.samples > 0 or self.samples is None:
raise ValueError("`filter` is not supported with `samples`")

# Dataset with huggingface metadata
if (
dataset.schema.metadata is not None
@@ -284,6 +281,7 @@ def __iter__(self):
n=self.samples,
columns=self.columns,
batch_size=self.batch_size,
filt=self.filter,
)
else:
raw_stream = sampler(
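
With the ValueError above removed, samples and filter can now be combined on the torch LanceDataset. A hedged sketch that mirrors the updated torch test at the end of this PR; the dataset URI and column name are placeholders:

import lance
from lance.torch.data import LanceDataset

ds = lance.dataset("/tmp/ids.lance")

# A "<column> is not null" filter is evaluated in memory during sampling; other
# filters (e.g. "ids >= 300") still raise NotImplementedError, as the updated
# test in torch_tests/test_data.py shows.
torch_ds = LanceDataset(
    ds,
    batch_size=10,
    columns=["ids"],
    samples=100,
    filter="ids is not null",
)
for batch in torch_ds:
    pass  # consume the sampled, null-free batches
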
22 changes: 18 additions & 4 deletions python/python/lance/vector.py
@@ -154,17 +154,30 @@ def train_ivf_centroids_on_accelerator(

k = int(k)

logging.info("Randomly select %s centroids from %s", k, dataset)
samples = dataset.sample(k, [column], sorted=True).combine_chunks()
fsl = samples.to_batches()[0][column]
init_centroids = torch.from_numpy(np.stack(fsl.to_numpy(zero_copy_only=False)))
if dataset.schema.field(column).nullable:
filt = f"{column} is not null"
else:
filt = None

logging.info("Randomly select %s centroids from %s (filt=%s)", k, dataset, filt)

ds = TorchDataset(
dataset,
batch_size=k,
columns=[column],
samples=sample_size,
filter=filt,
)

init_centroids = next(iter(ds))
logging.info("Done sampling: centroids shape: %s", init_centroids.shape)

ds = TorchDataset(
dataset,
batch_size=20480,
columns=[column],
samples=sample_size,
filter=filt,
Contributor:
Using filters here seems to cause a significant performance hit now (on the order of several minutes vs several seconds for kmeans clustering).

Contributor Author:
Hmm, maybe worth a fast path if we know there are no nulls then.

cache=True,
)

@@ -233,6 +246,7 @@ def compute_partitions(
batch_size=batch_size,
with_row_id=True,
columns=[column],
filter=f"{column} is not null",
)
loader = torch.utils.data.DataLoader(
torch_ds,
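
Following up on the review exchange above (filtered sampling taking minutes where k-means previously took seconds), one possible shape for the suggested fast path is sketched below. This is not part of the merged change: the helper name is made up, and it assumes count_rows accepts a filter expression.

def null_filter_if_needed(dataset, column: str):
    # Skip filtering entirely when the schema says the column cannot hold nulls.
    if not dataset.schema.field(column).nullable:
        return None
    # Only pay for the slower filtered path when nulls are actually present.
    if dataset.count_rows(filter=f"{column} is null") == 0:
        return None
    return f"{column} is not null"

Whether an up-front null count is actually cheaper than filtering during sampling would need to be benchmarked; the comment above only points at the general direction.
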
64 changes: 64 additions & 0 deletions python/python/tests/test_indices.py
@@ -32,6 +32,21 @@ def rand_dataset(tmpdir, request):
return ds


@pytest.fixture
def mostly_null_dataset(tmpdir, request):
vectors = np.random.randn(NUM_ROWS, DIMENSION).astype(np.float32)
vectors.shape = -1
vectors = pa.FixedSizeListArray.from_arrays(vectors, DIMENSION)
vectors = vectors.to_pylist()
vectors = [vec if i % 10 == 0 else None for i, vec in enumerate(vectors)]
vectors = pa.array(vectors, pa.list_(pa.float32(), DIMENSION))
table = pa.Table.from_arrays([vectors], names=["vectors"])

uri = str(tmpdir / "nulls_dataset")
ds = lance.write_dataset(table, uri, max_rows_per_file=NUM_ROWS_PER_FRAGMENT)
return ds


def test_ivf_centroids(tmpdir, rand_dataset):
ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(sample_rate=16)

@@ -44,6 +59,13 @@
assert ivf.centroids == reloaded.centroids


def test_ivf_centroids_mostly_null(mostly_null_dataset):
ivf = IndicesBuilder(mostly_null_dataset, "vectors").train_ivf(sample_rate=16)

assert ivf.distance_type == "l2"
assert len(ivf.centroids) == NUM_PARTITIONS


@pytest.mark.cuda
def test_ivf_centroids_cuda(rand_dataset):
ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(
@@ -54,6 +76,16 @@
assert len(ivf.centroids) == NUM_PARTITIONS


@pytest.mark.cuda
def test_ivf_centroids_mostly_null_cuda(mostly_null_dataset):
ivf = IndicesBuilder(mostly_null_dataset, "vectors").train_ivf(
sample_rate=16, accelerator="cuda"
)

assert ivf.distance_type == "l2"
assert len(ivf.centroids) == NUM_PARTITIONS


def test_ivf_centroids_distance_type(tmpdir, rand_dataset):
def check(distance_type):
ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(
@@ -95,6 +127,16 @@ def test_gen_pq(tmpdir, rand_dataset, rand_ivf):
assert pq.codebook == reloaded.codebook


def test_gen_pq_mostly_null(mostly_null_dataset):
centroids = np.random.rand(DIMENSION * 100).astype(np.float32)
centroids = pa.FixedSizeListArray.from_arrays(centroids, DIMENSION)
ivf = IvfModel(centroids, "l2")

pq = IndicesBuilder(mostly_null_dataset, "vectors").train_pq(ivf, sample_rate=2)
assert pq.dimension == DIMENSION
assert pq.num_subvectors == NUM_SUBVECTORS


@pytest.mark.cuda
def test_assign_partitions(rand_dataset, rand_ivf):
builder = IndicesBuilder(rand_dataset, "vectors")
@@ -113,6 +155,28 @@ def test_assign_partitions(rand_dataset, rand_ivf):
assert len(found_row_ids) == rand_dataset.count_rows()


@pytest.mark.cuda
def test_assign_partitions_mostly_null(mostly_null_dataset):
centroids = np.random.rand(DIMENSION * 100).astype(np.float32)
centroids = pa.FixedSizeListArray.from_arrays(centroids, DIMENSION)
ivf = IvfModel(centroids, "l2")

builder = IndicesBuilder(mostly_null_dataset, "vectors")

partitions_uri = builder.assign_ivf_partitions(ivf, accelerator="cuda")

partitions = lance.dataset(partitions_uri)
found_row_ids = set()
for batch in partitions.to_batches():
row_ids = batch["row_id"]
for row_id in row_ids:
found_row_ids.add(row_id)
part_ids = batch["partition"]
for part_id in part_ids:
assert part_id.as_py() < 100
assert len(found_row_ids) == (mostly_null_dataset.count_rows() / 10)


@pytest.fixture
def rand_pq(rand_dataset, rand_ivf):
dtype = rand_dataset.schema.field("vectors").type.value_type.to_pandas_dtype()
14 changes: 11 additions & 3 deletions python/python/tests/test_vector_index.py
@@ -19,7 +19,7 @@
from lance.vector import vec_to_table # noqa: E402


def create_table(nvec=1000, ndim=128, nans=0):
def create_table(nvec=1000, ndim=128, nans=0, nullify=False):
mat = np.random.randn(nvec, ndim)
if nans > 0:
nans_mat = np.empty((nans, ndim))
@@ -37,6 +37,13 @@ def gen_str(n):
.append_column("meta", pa.array(meta))
.append_column("id", pa.array(range(nvec + nans)))
)
if nullify:
idx = tbl.schema.get_field_index("vector")
vecs = tbl[idx].to_pylist()
nullified = [vec if i % 2 == 0 else None for i, vec in enumerate(vecs)]
field = tbl.schema.field(idx)
vecs = pa.array(nullified, field.type)
tbl = tbl.set_column(idx, field, vecs)
return tbl


@@ -191,8 +198,9 @@ def test_index_with_pq_codebook(tmp_path):


@pytest.mark.cuda
def test_create_index_using_cuda(tmp_path):
tbl = create_table()
@pytest.mark.parametrize("nullify", [False, True])
def test_create_index_using_cuda(tmp_path, nullify):
tbl = create_table(nullify=nullify)
dataset = lance.write_dataset(tbl, tmp_path)
dataset = dataset.create_index(
"vector",
18 changes: 10 additions & 8 deletions python/python/tests/torch_tests/test_data.py
@@ -134,14 +134,16 @@ def check(dataset):
)
)

# sampling fails
with pytest.raises(ValueError):
LanceDataset(
ds,
batch_size=10,
filter="ids >= 300",
samples=100,
columns=["ids"],
# sampling with filter
with pytest.raises(NotImplementedError):
check(
LanceDataset(
ds,
batch_size=10,
filter="ids >= 300",
samples=100,
columns=["ids"],
)
)

