Commit 59b414b

feat: support to read IVF partitions (#3462)
1 parent cca98fc

File tree

13 files changed: +360 -15 lines

python/python/lance/dataset.py

+109

@@ -3841,3 +3841,112 @@ def _validate_metadata(metadata: dict):
             )
         elif isinstance(v, dict):
             _validate_metadata(v)
+
+
+class VectorIndexReader:
+    """
+    A reader for a specific vector index. It exposes the number of
+    partitions, the centroids of the index, and the contents of
+    individual partitions.
+
+    Parameters
+    ----------
+    dataset: LanceDataset
+        The dataset containing the index.
+    index_name: str
+        The name of the vector index to read.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import lance
+        from lance.dataset import VectorIndexReader
+        import numpy as np
+        import pyarrow as pa
+
+        vectors = np.random.rand(256, 2)
+        data = pa.table({"vector": pa.array(vectors.tolist(),
+            type=pa.list_(pa.float32(), 2))})
+        dataset = lance.write_dataset(data, "/tmp/index_reader_demo")
+        dataset.create_index("vector", index_type="IVF_PQ",
+            num_partitions=4, num_sub_vectors=2)
+        reader = VectorIndexReader(dataset, "vector_idx")
+        assert reader.num_partitions() == 4
+        partition = reader.read_partition(0)
+        assert "_rowid" in partition.column_names
+
+    Raises
+    ------
+    ValueError
+        If the specified index is not a vector index.
+    """
+
+    def __init__(self, dataset: LanceDataset, index_name: str):
+        stats = dataset.stats.index_stats(index_name)
+        self.dataset = dataset
+        self.index_name = index_name
+        self.stats = stats
+        try:
+            self.num_partitions()
+        except KeyError:
+            raise ValueError(f"Index {index_name} is not a vector index")
+
+    def num_partitions(self) -> int:
+        """
+        Returns the number of partitions in the index.
+
+        Returns
+        -------
+        int
+            The number of partitions.
+        """
+        return self.stats["indices"][0]["num_partitions"]
+
+    def centroids(self) -> np.ndarray:
+        """
+        Returns the centroids of the index.
+
+        Returns
+        -------
+        np.ndarray
+            The IVF centroids, with shape (num_partitions, dim).
+        """
+        # If there are delta indices, they all share the same centroids,
+        # so reading from the first index is sufficient.
+        return np.array(
+            self.dataset._ds.get_index_centroids(self.stats["indices"][0]["centroids"])
+        )
+
+    def read_partition(
+        self, partition_id: int, *, with_vector: bool = False
+    ) -> pa.Table:
+        """
+        Returns a pyarrow table for the given IVF partition.
+
+        Parameters
+        ----------
+        partition_id: int
+            The id of the partition to read.
+        with_vector: bool, default False
+            Whether to include the vector column in the result;
+            for IVF_PQ, the vector column contains the PQ codes.
+
+        Returns
+        -------
+        pa.Table
+            A pyarrow table for the given partition, containing the row IDs
+            and, if with_vector is True, the quantized vectors.
+        """
+        if partition_id < 0 or partition_id >= self.num_partitions():
+            raise IndexError(
+                f"Partition id {partition_id} is out of range, "
+                f"expected 0 <= partition_id < {self.num_partitions()}"
+            )
+
+        return self.dataset._ds.read_index_partition(
+            self.index_name, partition_id, with_vector
+        ).read_all()
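The reader surfaces each IVF partition as a plain Arrow table, so partition membership can be recovered without touching index internals. A minimal sketch (not part of this commit) that builds a row-id to partition-id map, assuming the dataset and "vector_idx" index from the docstring example above:

import lance
from lance.dataset import VectorIndexReader

dataset = lance.dataset("/tmp/index_reader_demo")  # from the docstring example
reader = VectorIndexReader(dataset, "vector_idx")

row_to_partition = {}
for part_id in range(reader.num_partitions()):
    table = reader.read_partition(part_id)
    # "_rowid" is the stable row id column present in every partition table
    for row_id in table["_rowid"].to_pylist():
        row_to_partition[row_id] = part_id

print(f"{len(row_to_partition)} rows across {reader.num_partitions()} partitions")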

python/python/tests/test_vector_index.py

+31

@@ -13,6 +13,7 @@
 import pyarrow.compute as pc
 import pytest
 from lance import LanceFragment
+from lance.dataset import VectorIndexReader

 torch = pytest.importorskip("torch")
 from lance.util import validate_vector_index  # noqa: E402

@@ -1129,3 +1130,33 @@ def test_drop_indices(indexed_dataset):
     )

     assert len(results) == 15
+
+
+def test_read_partition(indexed_dataset):
+    idx_name = indexed_dataset.list_indices()[0]["name"]
+    reader = VectorIndexReader(indexed_dataset, idx_name)
+
+    num_rows = indexed_dataset.count_rows()
+    row_sum = 0
+    for part_id in range(reader.num_partitions()):
+        res = reader.read_partition(part_id)
+        row_sum += res.num_rows
+        assert "_rowid" in res.column_names
+    assert row_sum == num_rows
+
+    row_sum = 0
+    for part_id in range(reader.num_partitions()):
+        res = reader.read_partition(part_id, with_vector=True)
+        row_sum += res.num_rows
+        pq_column = res["__pq_code"]
+        assert "_rowid" in res.column_names
+        assert pq_column.type == pa.list_(pa.uint8(), 16)
+    assert row_sum == num_rows
+
+    # error tests
+    with pytest.raises(IndexError, match="out of range"):
+        reader.read_partition(reader.num_partitions() + 1)
+
+    with pytest.raises(ValueError, match="not a vector index"):
+        indexed_dataset.create_scalar_index("id", index_type="BTREE")
+        VectorIndexReader(indexed_dataset, "id_idx")
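The test above checks that the partitions cover every row. As a complementary, hand-rolled check (not in this commit), the sketch below verifies that each partition holds exactly the rows whose vectors are nearest to that partition's centroid. It assumes the default L2 metric, a freshly written single-fragment dataset with no deletes (so "_rowid" coincides with the row offset), and no exact centroid ties:

import lance
import numpy as np
import pyarrow as pa
from lance.dataset import VectorIndexReader

vectors = np.random.rand(256, 2).astype(np.float32)
data = pa.table({"vector": pa.array(vectors.tolist(),
    type=pa.list_(pa.float32(), 2))})
dataset = lance.write_dataset(data, "/tmp/ivf_partition_check", mode="overwrite")
dataset.create_index("vector", index_type="IVF_PQ",
    num_partitions=4, num_sub_vectors=2)
reader = VectorIndexReader(dataset, "vector_idx")

# Nearest centroid per vector under L2, computed by hand.
centroids = reader.centroids()  # shape (num_partitions, dim)
dists = np.linalg.norm(vectors[:, None, :] - centroids[None, :, :], axis=-1)
expected = dists.argmin(axis=1)

for part_id in range(reader.num_partitions()):
    row_ids = reader.read_partition(part_id)["_rowid"].to_numpy()
    assert (expected[row_ids.astype(np.int64)] == part_id).all()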

python/src/dataset.rs

+22 -1

@@ -31,7 +31,7 @@ use arrow_array::Array;
 use futures::{StreamExt, TryFutureExt};
 use lance::dataset::builder::DatasetBuilder;
 use lance::dataset::refs::{Ref, TagContents};
-use lance::dataset::scanner::MaterializationStyle;
+use lance::dataset::scanner::{DatasetRecordBatchStream, MaterializationStyle};
 use lance::dataset::statistics::{DataStatistics, DatasetStatisticsExt};
 use lance::dataset::{
     fragment::FileFragment as LanceFileFragment,

@@ -1558,6 +1558,27 @@ impl Dataset {

         Ok(())
     }
+
+    #[pyo3(signature = (index_name, partition_id, with_vector=false))]
+    fn read_index_partition(
+        &self,
+        index_name: String,
+        partition_id: usize,
+        with_vector: bool,
+    ) -> PyResult<PyArrowType<Box<dyn RecordBatchReader + Send>>> {
+        let stream = RT
+            .block_on(
+                None,
+                self.ds
+                    .read_index_partition(&index_name, partition_id, with_vector),
+            )?
+            .map_err(|err| PyValueError::new_err(err.to_string()))?;
+
+        let reader = Box::new(LanceReader::from_stream(DatasetRecordBatchStream::new(
+            stream,
+        )));
+        Ok(PyArrowType(reader))
+    }
 }

 impl Dataset {
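On the Python side, VectorIndexReader.read_partition drains this reader eagerly via read_all(). Since the binding itself returns an Arrow RecordBatchReader, a partition can also be consumed batch by batch. A sketch using the private _ds handle (internal API, shown for illustration only; assumes the dataset and "vector_idx" index from the earlier examples):

# Iterate a partition lazily instead of materializing it with read_all().
# `_ds` is Lance's internal handle; prefer VectorIndexReader in real code.
reader = dataset._ds.read_index_partition("vector_idx", 0, False)
for batch in reader:  # yields pyarrow.RecordBatch values
    print(batch.num_rows, batch.schema.names)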

rust/lance-index/src/traits.rs

+8

@@ -4,6 +4,7 @@
 use std::sync::Arc;

 use async_trait::async_trait;
+use datafusion::execution::SendableRecordBatchStream;
 use lance_core::Result;

 use crate::{optimize::OptimizeOptions, IndexParams, IndexType};

@@ -97,4 +98,11 @@ pub trait DatasetIndexExt {
         column: &str,
         index_id: Uuid,
     ) -> Result<()>;
+
+    async fn read_index_partition(
+        &self,
+        index_name: &str,
+        partition_id: usize,
+        with_vector: bool,
+    ) -> Result<SendableRecordBatchStream>;
 }

rust/lance-index/src/vector.rs

+13

@@ -9,6 +9,7 @@ use std::{collections::HashMap, sync::Arc};
 use arrow_array::{ArrayRef, RecordBatch, UInt32Array};
 use arrow_schema::Field;
 use async_trait::async_trait;
+use datafusion::execution::SendableRecordBatchStream;
 use ivf::storage::IvfModel;
 use lance_core::{Result, ROW_ID_FIELD};
 use lance_io::object_store::ObjectStore;

@@ -179,6 +180,18 @@ pub trait VectorIndex: Send + Sync + std::fmt::Debug + Index {
         self.load(reader, offset, length).await
     }

+    // For IVF only.
+    async fn partition_reader(
+        &self,
+        _partition_id: usize,
+        _with_vector: bool,
+    ) -> Result<SendableRecordBatchStream> {
+        unimplemented!("only for IVF")
+    }
+
+    // For SubIndex only.
+    async fn to_batch_stream(&self, with_vector: bool) -> Result<SendableRecordBatchStream>;
+
     /// Return the IDs of rows in the index.
     fn row_ids(&self) -> Box<dyn Iterator<Item = &'_ u64> + '_>;

rust/lance-index/src/vector/hnsw/index.rs

+30

@@ -10,7 +10,11 @@ use std::{

 use arrow_array::{RecordBatch, UInt32Array};
 use async_trait::async_trait;
+use datafusion::execution::SendableRecordBatchStream;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use deepsize::DeepSizeOf;
+use lance_arrow::RecordBatchExt;
+use lance_core::ROW_ID;
 use lance_core::{datatypes::Schema, Error, Result};
 use lance_file::reader::FileReader;
 use lance_io::traits::Reader;

@@ -263,6 +267,32 @@ impl<Q: Quantization + Send + Sync + 'static> VectorIndex for HNSWIndex<Q> {
         }))
     }

+    async fn to_batch_stream(&self, with_vector: bool) -> Result<SendableRecordBatchStream> {
+        let store = self.storage.as_ref().ok_or(Error::Index {
+            message: "vector storage not loaded".to_string(),
+            location: location!(),
+        })?;
+
+        // Project down to just the row-id column unless vectors were requested.
+        let schema = if with_vector {
+            store.schema().clone()
+        } else {
+            let schema = store.schema();
+            let row_id_idx = schema.index_of(ROW_ID)?;
+            Arc::new(schema.project(&[row_id_idx])?)
+        };
+
+        let batches = store
+            .to_batches()?
+            .map(|b| {
+                let batch = b.project_by_schema(&schema)?;
+                Ok(batch)
+            })
+            .collect::<Vec<_>>();
+        let stream = futures::stream::iter(batches);
+        let stream = RecordBatchStreamAdapter::new(schema, stream);
+        Ok(Box::pin(stream))
+    }
+
     fn row_ids(&self) -> Box<dyn Iterator<Item = &'_ u64> + '_> {
         Box::new(self.storage.as_ref().unwrap().row_ids())
     }

rust/lance/src/index.rs

+47 -9

@@ -7,8 +7,10 @@
 use std::collections::{HashMap, HashSet};
 use std::sync::{Arc, OnceLock};

-use arrow_schema::DataType;
+use arrow_schema::{DataType, Schema};
 use async_trait::async_trait;
+use datafusion::execution::SendableRecordBatchStream;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
 use futures::{stream, StreamExt, TryStreamExt};
 use itertools::Itertools;
 use lance_core::utils::parse::str_is_truthy;

@@ -686,6 +688,48 @@ impl DatasetIndexExt for Dataset {
             location: location!(),
         })
     }
+
+    async fn read_index_partition(
+        &self,
+        index_name: &str,
+        partition_id: usize,
+        with_vector: bool,
+    ) -> Result<SendableRecordBatchStream> {
+        let indices = self.load_indices_by_name(index_name).await?;
+        if indices.is_empty() {
+            return Err(Error::IndexNotFound {
+                identity: format!("name={}", index_name),
+                location: location!(),
+            });
+        }
+        let column = self.schema().field_by_id(indices[0].fields[0]).unwrap();
+
+        // Read the requested partition from every delta index with this name
+        // and merge the streams; all deltas share the same schema.
+        let mut schema: Option<Arc<Schema>> = None;
+        let mut partition_streams = Vec::with_capacity(indices.len());
+        for index in indices {
+            let index = self
+                .open_vector_index(&column.name, &index.uuid.to_string())
+                .await?;
+
+            let stream = index.partition_reader(partition_id, with_vector).await?;
+            if schema.is_none() {
+                schema = Some(stream.schema());
+            }
+            partition_streams.push(stream);
+        }
+
+        match schema {
+            Some(schema) => {
+                let merged = stream::select_all(partition_streams);
+                let stream = RecordBatchStreamAdapter::new(schema, merged);
+                Ok(Box::pin(stream))
+            }
+            None => Ok(Box::pin(RecordBatchStreamAdapter::new(
+                Arc::new(Schema::empty()),
+                stream::empty(),
+            ))),
+        }
+    }
 }

@@ -775,14 +819,8 @@ impl DatasetIndexInternalExt for Dataset {
         match &proto.implementation {
             Some(Implementation::VectorIndex(vector_index)) => {
                 let dataset = Arc::new(self.clone());
-                crate::index::vector::open_vector_index(
-                    dataset,
-                    column,
-                    uuid,
-                    vector_index,
-                    reader,
-                )
-                .await
+                crate::index::vector::open_vector_index(dataset, uuid, vector_index, reader)
+                    .await
             }
             None => Err(Error::Internal {
                 message: "Index proto was missing implementation field".into(),

rust/lance/src/index/vector.rs

-1

@@ -445,7 +445,6 @@ pub(crate) async fn remap_vector_index(
 #[instrument(level = "debug", skip(dataset, vec_idx, reader))]
 pub(crate) async fn open_vector_index(
     dataset: Arc<Dataset>,
-    column: &str,
     uuid: &str,
     vec_idx: &lance_index::pb::VectorIndex,
     reader: Arc<dyn Reader>,

rust/lance/src/index/vector/fixture_test.rs

+5

@@ -17,6 +17,7 @@ mod test {
     use arrow_array::{FixedSizeListArray, Float32Array, RecordBatch, UInt32Array};
     use arrow_schema::{DataType, Field, Schema};
     use async_trait::async_trait;
+    use datafusion::execution::SendableRecordBatchStream;
     use deepsize::{Context, DeepSizeOf};
     use lance_arrow::FixedSizeListArrayExt;
     use lance_index::vector::ivf::storage::IvfModel;

@@ -142,6 +143,10 @@ mod test {
             Ok(())
         }

+        async fn to_batch_stream(&self, _with_vector: bool) -> Result<SendableRecordBatchStream> {
+            unimplemented!("only for SubIndex")
+        }
+
         fn ivf_model(&self) -> IvfModel {
             unimplemented!("only for IVF")
         }
