From 69ea0a7e0674b0bd2e360f2601543779c0b1f93e Mon Sep 17 00:00:00 2001
From: Weston Pace <weston.pace@gmail.com>
Date: Wed, 18 Sep 2024 15:03:58 -0700
Subject: [PATCH 1/3] Handle nulls when creating indices with cuda

---
 python/python/lance/sampler.py           | 84 ++++++++++++++++++++----
 python/python/lance/torch/data.py        |  4 +-
 python/python/lance/vector.py            | 22 +++++--
 python/python/tests/test_indices.py      | 63 ++++++++++++++++++
 python/python/tests/test_vector_index.py | 14 +++-
 5 files changed, 166 insertions(+), 21 deletions(-)

diff --git a/python/python/lance/sampler.py b/python/python/lance/sampler.py
index fb3f670c86..f0896228df 100644
--- a/python/python/lance/sampler.py
+++ b/python/python/lance/sampler.py
@@ -6,6 +6,7 @@
 
 import gc
 import logging
+import math
 import random
 import warnings
 from abc import ABC, abstractmethod
@@ -15,6 +16,7 @@
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, TypeVar, Union
 
 import pyarrow as pa
+import pyarrow.compute as pc
 
 import lance
 from lance.dependencies import numpy as np
@@ -105,12 +107,63 @@ def _efficient_sample(
         del tbl
 
 
+def _filtered_efficient_sample(
+    dataset: lance.LanceDataset,
+    n: int,
+    columns: Optional[Union[List[str], Dict[str, str]]],
+    batch_size: int,
+    target_takes: int,
+    filter: str,
+) -> Generator[pa.RecordBatch, None, None]:
+    total_records = len(dataset)
+    shard_size = math.ceil(n / target_takes)
+    num_shards = math.ceil(total_records / shard_size)
+
+    shards = list(range(num_shards))
+    random.shuffle(shards)
+
+    tables = []
+    remaining_rows = n
+    remaining_in_batch = min(batch_size, n)
+    for shard in shards:
+        start = shard * shard_size
+        end = min(start + shard_size, total_records)
+        table = dataset.to_table(
+            columns=columns,
+            offset=start,
+            limit=(end - start),
+            batch_size=shard_size,
+        )
+        if len(columns) == 1 and filter.lower() == f"{columns[0]} is not null":
+            table = pc.drop_null(table)
+        elif filter is not None:
+            raise Exception(f"Can't yet run filter <{filter}> in-memory")
+        if table.num_rows > 0:
+            tables.append(table)
+            remaining_rows -= table.num_rows
+            remaining_in_batch = remaining_in_batch - table.num_rows
+            if remaining_in_batch <= 0:
+                combined = pa.concat_tables(tables).combine_chunks()
+                batch = combined.slice(0, batch_size).to_batches()[0]
+                yield batch
+                remaining_in_batch = min(batch_size, remaining_rows)
+                if len(combined) > batch_size:
+                    leftover = combined.slice(batch_size)
+                    tables = [leftover]
+                    remaining_in_batch -= len(leftover)
+                else:
+                    tables = []
+            if remaining_rows <= 0:
+                break
+
+
 def maybe_sample(
     dataset: Union[str, Path, lance.LanceDataset],
     n: int,
     columns: Union[list[str], dict[str, str], str],
     batch_size: int = 10240,
     max_takes: int = 2048,
+    filt: Optional[str] = None,
 ) -> Generator[pa.RecordBatch, None, None]:
     """Sample n records from the dataset.
 
@@ -129,6 +182,10 @@ def maybe_sample(
         This is employed to minimize the number of random reads necessary for sampling.
         A sufficiently large value can provide an effective random sample without
         the need for excessive random reads.
+    filter : str, optional
+        The filter to apply to the dataset, by default None.  If a filter is provided,
+        then we will first load all row ids in memory and then batch through the ids
+        in random order until enough matches have been found.
 
     Returns
     -------
@@ -143,18 +200,23 @@ def maybe_sample(
 
     if n >= len(dataset):
         # Dont have enough data in the dataset. Just do a full scan
-        yield from dataset.to_batches(columns=columns, batch_size=batch_size)
+        yield from dataset.to_batches(
+            columns=columns, batch_size=batch_size, filter=filt
+        )
+    elif filt is not None:
+        yield from _filtered_efficient_sample(
+            dataset, n, columns, batch_size, max_takes, filt
+        )
+    elif n > max_takes:
+        yield from _efficient_sample(dataset, n, columns, batch_size, max_takes)
     else:
-        if n > max_takes:
-            yield from _efficient_sample(dataset, n, columns, batch_size, max_takes)
-        else:
-            choices = np.random.choice(len(dataset), n, replace=False)
-            idx = 0
-            while idx < len(choices):
-                end = min(idx + batch_size, len(choices))
-                tbl = dataset.take(choices[idx:end], columns=columns).combine_chunks()
-                yield tbl.to_batches()[0]
-                idx += batch_size
+        choices = np.random.choice(len(dataset), n, replace=False)
+        idx = 0
+        while idx < len(choices):
+            end = min(idx + batch_size, len(choices))
+            tbl = dataset.take(choices[idx:end], columns=columns).combine_chunks()
+            yield tbl.to_batches()[0]
+            idx += batch_size
 
 
 T = TypeVar("T")
diff --git a/python/python/lance/torch/data.py b/python/python/lance/torch/data.py
index 30cb58c86f..3f066543dd 100644
--- a/python/python/lance/torch/data.py
+++ b/python/python/lance/torch/data.py
@@ -233,9 +233,6 @@ def __init__(
             warnings.warn("rank and world_size are deprecated", DeprecationWarning)
         self.sampler: Optional[Sampler] = sampler
 
-        if filter is not None and self.samples > 0 or self.samples is None:
-            raise ValueError("`filter` is not supported with `samples`")
-
         # Dataset with huggingface metadata
         if (
             dataset.schema.metadata is not None
@@ -284,6 +281,7 @@ def __iter__(self):
                     n=self.samples,
                     columns=self.columns,
                     batch_size=self.batch_size,
+                    filt=self.filter,
                 )
             else:
                 raw_stream = sampler(
diff --git a/python/python/lance/vector.py b/python/python/lance/vector.py
index 48daa44148..e30911e1b9 100644
--- a/python/python/lance/vector.py
+++ b/python/python/lance/vector.py
@@ -154,10 +154,22 @@ def train_ivf_centroids_on_accelerator(
 
     k = int(k)
 
-    logging.info("Randomly select %s centroids from %s", k, dataset)
-    samples = dataset.sample(k, [column], sorted=True).combine_chunks()
-    fsl = samples.to_batches()[0][column]
-    init_centroids = torch.from_numpy(np.stack(fsl.to_numpy(zero_copy_only=False)))
+    if dataset.schema.field(column).nullable:
+        filt = f"{column} is not null"
+    else:
+        filt = None
+
+    logging.info("Randomly select %s centroids from %s (filt=%s)", k, dataset, filt)
+
+    ds = TorchDataset(
+        dataset,
+        batch_size=k,
+        columns=[column],
+        samples=sample_size,
+        filter=filt,
+    )
+
+    init_centroids = next(iter(ds))
     logging.info("Done sampling: centroids shape: %s", init_centroids.shape)
 
     ds = TorchDataset(
@@ -165,6 +177,7 @@ def train_ivf_centroids_on_accelerator(
         batch_size=20480,
         columns=[column],
         samples=sample_size,
+        filter=filt,
         cache=True,
     )
 
@@ -233,6 +246,7 @@ def compute_partitions(
         batch_size=batch_size,
         with_row_id=True,
         columns=[column],
+        filter=f"{column} is not null",
     )
     loader = torch.utils.data.DataLoader(
         torch_ds,
diff --git a/python/python/tests/test_indices.py b/python/python/tests/test_indices.py
index dc909a8925..8f9b91676b 100644
--- a/python/python/tests/test_indices.py
+++ b/python/python/tests/test_indices.py
@@ -32,6 +32,21 @@ def rand_dataset(tmpdir, request):
     return ds
 
 
+@pytest.fixture
+def mostly_null_dataset(tmpdir, request):
+    vectors = np.random.randn(NUM_ROWS, DIMENSION).astype(np.float32)
+    vectors.shape = -1
+    vectors = pa.FixedSizeListArray.from_arrays(vectors, DIMENSION)
+    vectors = vectors.to_pylist()
+    vectors = [vec if i % 10 == 0 else None for i, vec in enumerate(vectors)]
+    vectors = pa.array(vectors, pa.list_(pa.float32(), DIMENSION))
+    table = pa.Table.from_arrays([vectors], names=["vectors"])
+
+    uri = str(tmpdir / "nulls_dataset")
+    ds = lance.write_dataset(table, uri, max_rows_per_file=NUM_ROWS_PER_FRAGMENT)
+    return ds
+
+
 def test_ivf_centroids(tmpdir, rand_dataset):
     ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(sample_rate=16)
 
@@ -44,6 +59,13 @@ def test_ivf_centroids(tmpdir, rand_dataset):
     assert ivf.centroids == reloaded.centroids
 
 
+def test_ivf_centroids_mostly_null(mostly_null_dataset):
+    ivf = IndicesBuilder(mostly_null_dataset, "vectors").train_ivf(sample_rate=16)
+
+    assert ivf.distance_type == "l2"
+    assert len(ivf.centroids) == NUM_PARTITIONS
+
+
 @pytest.mark.cuda
 def test_ivf_centroids_cuda(rand_dataset):
     ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(
@@ -54,6 +76,15 @@ def test_ivf_centroids_cuda(rand_dataset):
     assert len(ivf.centroids) == NUM_PARTITIONS
 
 
+def test_ivf_centroids_mostly_null_cuda(mostly_null_dataset):
+    ivf = IndicesBuilder(mostly_null_dataset, "vectors").train_ivf(
+        sample_rate=16, accelerator="cuda"
+    )
+
+    assert ivf.distance_type == "l2"
+    assert len(ivf.centroids) == NUM_PARTITIONS
+
+
 def test_ivf_centroids_distance_type(tmpdir, rand_dataset):
     def check(distance_type):
         ivf = IndicesBuilder(rand_dataset, "vectors").train_ivf(
@@ -95,6 +126,16 @@ def test_gen_pq(tmpdir, rand_dataset, rand_ivf):
     assert pq.codebook == reloaded.codebook
 
 
+def test_gen_pq_mostly_null(mostly_null_dataset):
+    centroids = np.random.rand(DIMENSION * 100).astype(np.float32)
+    centroids = pa.FixedSizeListArray.from_arrays(centroids, DIMENSION)
+    ivf = IvfModel(centroids, "l2")
+
+    pq = IndicesBuilder(mostly_null_dataset, "vectors").train_pq(ivf, sample_rate=2)
+    assert pq.dimension == DIMENSION
+    assert pq.num_subvectors == NUM_SUBVECTORS
+
+
 @pytest.mark.cuda
 def test_assign_partitions(rand_dataset, rand_ivf):
     builder = IndicesBuilder(rand_dataset, "vectors")
@@ -113,6 +154,28 @@ def test_assign_partitions(rand_dataset, rand_ivf):
     assert len(found_row_ids) == rand_dataset.count_rows()
 
 
+@pytest.mark.cuda
+def test_assign_partitions_mostly_null(mostly_null_dataset):
+    centroids = np.random.rand(DIMENSION * 100).astype(np.float32)
+    centroids = pa.FixedSizeListArray.from_arrays(centroids, DIMENSION)
+    ivf = IvfModel(centroids, "l2")
+
+    builder = IndicesBuilder(mostly_null_dataset, "vectors")
+
+    partitions_uri = builder.assign_ivf_partitions(ivf, accelerator="cuda")
+
+    partitions = lance.dataset(partitions_uri)
+    found_row_ids = set()
+    for batch in partitions.to_batches():
+        row_ids = batch["row_id"]
+        for row_id in row_ids:
+            found_row_ids.add(row_id)
+        part_ids = batch["partition"]
+        for part_id in part_ids:
+            assert part_id.as_py() < 100
+    assert len(found_row_ids) == (mostly_null_dataset.count_rows() / 10)
+
+
 @pytest.fixture
 def rand_pq(rand_dataset, rand_ivf):
     dtype = rand_dataset.schema.field("vectors").type.value_type.to_pandas_dtype()
diff --git a/python/python/tests/test_vector_index.py b/python/python/tests/test_vector_index.py
index 6fc54ee690..7673a64ee6 100644
--- a/python/python/tests/test_vector_index.py
+++ b/python/python/tests/test_vector_index.py
@@ -19,7 +19,7 @@
 from lance.vector import vec_to_table  # noqa: E402
 
 
-def create_table(nvec=1000, ndim=128, nans=0):
+def create_table(nvec=1000, ndim=128, nans=0, nullify=False):
     mat = np.random.randn(nvec, ndim)
     if nans > 0:
         nans_mat = np.empty((nans, ndim))
@@ -37,6 +37,13 @@ def gen_str(n):
         .append_column("meta", pa.array(meta))
         .append_column("id", pa.array(range(nvec + nans)))
     )
+    if nullify:
+        idx = tbl.schema.get_field_index("vector")
+        vecs = tbl[idx].to_pylist()
+        nullified = [vec if i % 2 == 0 else None for i, vec in enumerate(vecs)]
+        field = tbl.schema.field(idx)
+        vecs = pa.array(nullified, field.type)
+        tbl = tbl.set_column(idx, field, vecs)
     return tbl
 
 
@@ -191,8 +198,9 @@ def test_index_with_pq_codebook(tmp_path):
 
 
 @pytest.mark.cuda
-def test_create_index_using_cuda(tmp_path):
-    tbl = create_table()
+@pytest.mark.parametrize("nullify", [False, True])
+def test_create_index_using_cuda(tmp_path, nullify):
+    tbl = create_table(nullify=nullify)
     dataset = lance.write_dataset(tbl, tmp_path)
     dataset = dataset.create_index(
         "vector",

From 868f5b605bd5ab0d6577b36c191b92460ab53d80 Mon Sep 17 00:00:00 2001
From: Weston Pace <weston.pace@gmail.com>
Date: Wed, 18 Sep 2024 15:30:02 -0700
Subject: [PATCH 2/3] Forgot a cuda mark on one of the tests

---
 python/python/tests/test_indices.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/python/python/tests/test_indices.py b/python/python/tests/test_indices.py
index 8f9b91676b..015994c483 100644
--- a/python/python/tests/test_indices.py
+++ b/python/python/tests/test_indices.py
@@ -76,6 +76,7 @@ def test_ivf_centroids_cuda(rand_dataset):
     assert len(ivf.centroids) == NUM_PARTITIONS
 
 
+@pytest.mark.cuda
 def test_ivf_centroids_mostly_null_cuda(mostly_null_dataset):
     ivf = IndicesBuilder(mostly_null_dataset, "vectors").train_ivf(
         sample_rate=16, accelerator="cuda"

From 0df1d9705a67bbd4fb3a5b0c63727aae15949f19 Mon Sep 17 00:00:00 2001
From: Weston Pace <weston.pace@gmail.com>
Date: Wed, 18 Sep 2024 15:58:26 -0700
Subject: [PATCH 3/3] Change error raised for not-implemented feature.  Fix
 test

---
 python/python/lance/sampler.py               |  2 +-
 python/python/tests/torch_tests/test_data.py | 18 ++++++++++--------
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/python/python/lance/sampler.py b/python/python/lance/sampler.py
index f0896228df..2028609d1e 100644
--- a/python/python/lance/sampler.py
+++ b/python/python/lance/sampler.py
@@ -137,7 +137,7 @@ def _filtered_efficient_sample(
         if len(columns) == 1 and filter.lower() == f"{columns[0]} is not null":
             table = pc.drop_null(table)
         elif filter is not None:
-            raise Exception(f"Can't yet run filter <{filter}> in-memory")
+            raise NotImplementedError(f"Can't yet run filter <{filter}> in-memory")
         if table.num_rows > 0:
             tables.append(table)
             remaining_rows -= table.num_rows
diff --git a/python/python/tests/torch_tests/test_data.py b/python/python/tests/torch_tests/test_data.py
index 66a214001c..bb76b30771 100644
--- a/python/python/tests/torch_tests/test_data.py
+++ b/python/python/tests/torch_tests/test_data.py
@@ -134,14 +134,16 @@ def check(dataset):
         )
     )
 
-    # sampling fails
-    with pytest.raises(ValueError):
-        LanceDataset(
-            ds,
-            batch_size=10,
-            filter="ids >= 300",
-            samples=100,
-            columns=["ids"],
+    # sampling with filter
+    with pytest.raises(NotImplementedError):
+        check(
+            LanceDataset(
+                ds,
+                batch_size=10,
+                filter="ids >= 300",
+                samples=100,
+                columns=["ids"],
+            )
         )