diff --git a/python/python/lance/__init__.py b/python/python/lance/__init__.py index 83b22cf521..54834a5240 100644 --- a/python/python/lance/__init__.py +++ b/python/python/lance/__init__.py @@ -11,6 +11,7 @@ from .dataset import ( DataStatistics, FieldStatistics, + Index, LanceDataset, LanceOperation, LanceScanner, @@ -47,6 +48,7 @@ "LanceScanner", "MergeInsertBuilder", "Transaction", + "Index", "__version__", "bytes_read_counter", "iops_counter", diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 57dacdb580..a6c8c2b390 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -29,6 +29,7 @@ Tuple, TypedDict, Union, + overload, ) import pyarrow as pa @@ -79,6 +80,10 @@ Iterable[float], ] +IndexType = Literal["BTREE", "BITMAP", "LABEL_LIST", "INVERTED", "FTS", "NGRAM"] + +SearchType = Literal["DfsQueryThenFetch", "QueryThenFetch"] + class MergeInsertBuilder(_MergeInsertBuilder): def execute(self, data_obj: ReaderLike, *, schema: Optional[pa.Schema] = None): @@ -412,6 +417,11 @@ def scanner( currently only supports a single column in the columns list. - query: str The query string to search for. + - search_type: SearchType + The search type to use; defaults to QueryThenFetch. + QueryThenFetch: search each index delta directly with its own local statistics. + DfsQueryThenFetch: collect term frequencies across all index deltas first, + then search. Costs more time but is more accurate. fast_search: bool, default False If True, then the search will only be performed on the indexed data, which yields faster search time. @@ -623,6 +633,11 @@ def to_table( currently only supports a single column in the columns list. - query: str The query string to search for. + - search_type: SearchType + The search type to use; defaults to QueryThenFetch. + QueryThenFetch: search each index delta directly with its own local statistics. + DfsQueryThenFetch: collect term frequencies across all index deltas first, + then search. Costs more time but is more accurate. include_deleted_rows: bool, optional, default False If True, then rows that have been deleted, but are still present in the fragment, will be returned. These rows will have the _rowid column set @@ -1502,22 +1517,42 @@ def cleanup_old_versions( td_to_micros(older_than), delete_unverified, error_if_tagged_old_versions ) + if TYPE_CHECKING: + + @overload + def create_scalar_index( + self, + column: str, + index_type: IndexType, + name: Optional[str] = None, + *, + fragment_ids: List[int], + replace: bool = True, + **kwargs, + ) -> Index: ... + + @overload + def create_scalar_index( + self, + column: str, + index_type: IndexType, + name: Optional[str] = None, + *, + fragment_ids: None = None, + replace: bool = True, + **kwargs, + ) -> LanceDataset: ... + def create_scalar_index( self, column: str, - index_type: Union[ - Literal["BTREE"], - Literal["BITMAP"], - Literal["LABEL_LIST"], - Literal["INVERTED"], - Literal["FTS"], - Literal["NGRAM"], - ], + index_type: IndexType, name: Optional[str] = None, *, + fragment_ids: Optional[List[int]] = None, replace: bool = True, **kwargs, - ): + ) -> Index | LanceDataset: """Create a scalar index on a column. Scalar indices, like vector indices, can be used to speed up scans. A scalar @@ -1590,6 +1625,9 @@ def create_scalar_index( name : str, optional The index name. If not provided, it will be generated from the column name. + fragment_ids: list of int, optional + The fragment ids to create the index on. If not provided, the index will + be created on all fragments.
replace : bool, default True Replace the existing index if it exists. @@ -1625,6 +1663,13 @@ def create_scalar_index( non-ascii characters to ascii characters if possible. This would remove accents like "é" -> "e". + Returns + ------- + index : Index | LanceDataset + Returns an Index if fragment_ids is provided; commit the index to the + dataset later with a LanceOperation.CreateIndex and LanceDataset.commit(). + Returns the LanceDataset itself if fragment_ids is not provided. + Examples -------- @@ -1711,8 +1756,76 @@ def create_scalar_index( raise TypeError( f"Scalar index column {column} cannot currently be a duration" ) - + if fragment_ids is not None: + return self._ds.create_fragment_index( + [column], index_type, name, replace, None, fragment_ids, kwargs + ) self._ds.create_index([column], index_type, name, replace, None, kwargs) + return self + + if TYPE_CHECKING: + + @overload + def create_index( + self, + column: Union[str, List[str]], + index_type: str, + name: Optional[str] = None, + metric: str = "L2", + replace: bool = False, + num_partitions: Optional[int] = None, + ivf_centroids: Optional[ + Union[np.ndarray, pa.FixedSizeListArray, pa.FixedShapeTensorArray] + ] = None, + pq_codebook: Optional[ + Union[np.ndarray, pa.FixedSizeListArray, pa.FixedShapeTensorArray] + ] = None, + num_sub_vectors: Optional[int] = None, + accelerator: Optional[Union[str, "torch.Device"]] = None, + index_cache_size: Optional[int] = None, + shuffle_partition_batches: Optional[int] = None, + shuffle_partition_concurrency: Optional[int] = None, + # experimental parameters + ivf_centroids_file: Optional[str] = None, + precomputed_partition_dataset: Optional[str] = None, + storage_options: Optional[Dict[str, str]] = None, + filter_nan: bool = True, + one_pass_ivfpq: bool = False, + *, + fragment_ids: List[int], + **kwargs, + ) -> Index: ... + + @overload + def create_index( + self, + column: Union[str, List[str]], + index_type: str, + name: Optional[str] = None, + metric: str = "L2", + replace: bool = False, + num_partitions: Optional[int] = None, + ivf_centroids: Optional[ + Union[np.ndarray, pa.FixedSizeListArray, pa.FixedShapeTensorArray] + ] = None, + pq_codebook: Optional[ + Union[np.ndarray, pa.FixedSizeListArray, pa.FixedShapeTensorArray] + ] = None, + num_sub_vectors: Optional[int] = None, + accelerator: Optional[Union[str, "torch.Device"]] = None, + index_cache_size: Optional[int] = None, + shuffle_partition_batches: Optional[int] = None, + shuffle_partition_concurrency: Optional[int] = None, + # experimental parameters + ivf_centroids_file: Optional[str] = None, + precomputed_partition_dataset: Optional[str] = None, + storage_options: Optional[Dict[str, str]] = None, + filter_nan: bool = True, + one_pass_ivfpq: bool = False, + *, + fragment_ids: None = None, + **kwargs, + ) -> LanceDataset: ... def create_index( self, @@ -1739,8 +1852,10 @@ def create_index( storage_options: Optional[Dict[str, str]] = None, filter_nan: bool = True, one_pass_ivfpq: bool = False, + *, + fragment_ids: Optional[List[int]] = None, **kwargs, - ) -> LanceDataset: + ) -> LanceDataset | Index: """Create index on column. **Experimental API** @@ -1805,10 +1920,18 @@ def create_index( for nullable columns. Obtains a small speed boost. one_pass_ivfpq: bool Defaults to False. If enabled, index type must be "IVF_PQ". Reduces disk IO. + fragment_ids: list of int, optional + The fragment ids to create the index on. If not provided, the index will + be created on all fragments. kwargs : Parameters passed to the index building process.
- + Returns + ------- + index : Index | LanceDataset + Returns an Index if fragment_ids is provided; commit the index to the + dataset later with a LanceOperation.CreateIndex and LanceDataset.commit(). + Returns the LanceDataset itself if fragment_ids is not provided. The SQ (Scalar Quantization) is available for only ``IVF_HNSW_SQ`` index type, this quantization method is used to reduce the memory usage of the index, @@ -2259,6 +2382,18 @@ def drop_index(self, name: str): """ return self._ds.drop_index(name) + def unindexed_fragments(self, name: str) -> List[FragmentMetadata]: + """ + Return the fragments that are not covered by any of the deltas of the index. + """ + return self._ds.unindexed_fragments(name) + + def indexed_fragments(self, name: str) -> List[List[FragmentMetadata]]: + """ + Return the fragments that are covered by each of the deltas of the index. + """ + return self._ds.indexed_fragments(name) + def session(self) -> Session: """ Return the dataset session, which holds the dataset's state. @@ -2276,7 +2411,9 @@ def _commit( "LanceDataset._commit() is deprecated, use LanceDataset.commit() instead", DeprecationWarning, ) - return LanceDataset.commit(base_uri, operation, read_version, commit_lock) + return LanceDataset.commit( + base_uri, operation, read_version=read_version, commit_lock=commit_lock + ) @staticmethod def commit( @@ -2610,9 +2747,8 @@ class ExecuteResult(TypedDict): class Index(TypedDict): name: str - type: str uuid: str - fields: List[str] + fields: List[int] version: int fragment_ids: Set[int] @@ -2954,11 +3090,8 @@ class CreateIndex(BaseOperation): Operation that creates an index on the dataset. """ - uuid: str - name: str - fields: List[int] - dataset_version: int - fragment_ids: Set[int] + new_indices: List[Index] + removed_indices: List[Index] @dataclass class DataReplacementGroup: @@ -3293,6 +3426,7 @@ def full_text_search( self, query: str, columns: Optional[List[str]] = None, + search_type: SearchType = "QueryThenFetch", ) -> ScannerBuilder: """ Filter rows by full text searching. *Experimental API*, Must create inverted index on the given column before searching, """ - self._full_text_query = {"query": query, "columns": columns} + self._full_text_query = { + "query": query, + "columns": columns, + "search_type": search_type, + } return self def to_scanner(self) -> LanceScanner: diff --git a/python/python/lance/lance/__init__.pyi b/python/python/lance/lance/__init__.pyi index 82fe3eed7c..c5f1887f74 100644 --- a/python/python/lance/lance/__init__.pyi +++ b/python/python/lance/lance/__init__.pyi @@ -266,7 +266,19 @@ class _Dataset: storage_options: Optional[Dict[str, str]] = None, kwargs: Optional[Dict[str, Any]] = None, ): ... + def create_fragment_index( + self, + columns: List[str], + index_type: str, + name: Optional[str] = None, + replace: Optional[bool] = None, + storage_options: Optional[Dict[str, str]] = None, + fragment_ids: Optional[List[int]] = None, + kwargs: Optional[Dict[str, Any]] = None, + ) -> Index: ... def drop_index(self, name: str): ... + def unindexed_fragments(self, name: str) -> List[FragmentMetadata]: ... + def indexed_fragments(self, name: str) -> List[List[FragmentMetadata]]: ... def count_fragments(self) -> int: ... def num_small_files(self, max_rows_per_group: int) -> int: ... def get_fragments(self) -> List[_Fragment]: ...
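The Python-side changes above let an index be built for a subset of fragments and committed as a separate step. A minimal sketch of the intended workflow, mirroring the tests below (the dataset path and table contents are illustrative assumptions; "text_idx" follows the default <column>_idx naming the tests rely on):

import lance
import pyarrow as pa

# Two separate writes produce two fragments.
ds = lance.write_dataset(pa.table({"text": ["Frodo was a puppy"]}), "/tmp/frag_index_demo", mode="overwrite")
ds = lance.write_dataset(pa.table({"text": ["Frodo was a happy puppy"]}), "/tmp/frag_index_demo", mode="append")

frags = ds.get_fragments()

# With fragment_ids set, create_scalar_index returns an Index dict describing the
# partial index instead of committing it to the dataset.
index = ds.create_scalar_index("text", "INVERTED", fragment_ids=[frags[0].fragment_id])

# Commit the partial index explicitly via a CreateIndex operation.
op = lance.LanceOperation.CreateIndex(new_indices=[index], removed_indices=[])
ds = lance.LanceDataset.commit(ds.uri, op, read_version=ds.version)

# The second fragment is not covered by any delta of the index yet.
print([f.id for f in ds.unindexed_fragments("text_idx")])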
diff --git a/python/python/tests/test_commit_index.py b/python/python/tests/test_commit_index.py index fa2eeacabc..85be3f607d 100644 --- a/python/python/tests/test_commit_index.py +++ b/python/python/tests/test_commit_index.py @@ -5,6 +5,7 @@ import shutil import string from pathlib import Path +from typing import List import lance import numpy as np @@ -57,11 +58,18 @@ def test_commit_index(dataset_with_index, test_table, tmp_path): # Commit the index to dataset_without_index field_idx = dataset_without_index.schema.get_field_index("meta") create_index_op = lance.LanceOperation.CreateIndex( - index_id, - "meta_idx", - [field_idx], - dataset_without_index.version, - set([f.fragment_id for f in dataset_without_index.get_fragments()]), + new_indices=[ + lance.Index( + uuid=index_id, + name="meta_idx", + fields=[field_idx], + version=dataset_without_index.version, + fragment_ids=set( + [f.fragment_id for f in dataset_without_index.get_fragments()] + ), + ) + ], + removed_indices=[], ) dataset_without_index = lance.LanceDataset.commit( dataset_without_index.uri, @@ -84,3 +92,189 @@ def test_commit_index(dataset_with_index, test_table, tmp_path): ) plan = scanner.explain_plan() assert "MaterializeIndex" in plan + + +@pytest.fixture() +def tmp_tables() -> List[pa.Table]: + tables = [ + { + "text": [ + "Frodo was a puppy", + "There were several kittens playing", + ], + "sentiment": ["neutral", "neutral"], + }, + { + "text": [ + "Frodo was a happy puppy", + "Frodo was a very happy puppy", + ], + "sentiment": ["positive", "positive"], + }, + { + "text": [ + "Frodo was a sad puppy", + "Frodo was a very sad puppy", + ], + "sentiment": ["negative", "negative"], + }, + ] + for tb in tables: + tb["text2"] = tb["text"] + tb["text3"] = tb["text"] + return [pa.table(tb) for tb in tables] + + +def test_indexed_unindexed_fragments(tmp_tables, tmp_path): + ds = lance.write_dataset(tmp_tables[0], tmp_path, mode="overwrite") + ds = lance.write_dataset(tmp_tables[1], tmp_path, mode="append") + ds = lance.write_dataset(tmp_tables[2], tmp_path, mode="append") + frags = [f for f in ds.get_fragments()] + index = ds.create_scalar_index( + "text", "INVERTED", fragment_ids=[frags[0].fragment_id] + ) + assert isinstance(index, dict) + + indices = [index] + create_index_op = lance.LanceOperation.CreateIndex( + new_indices=indices, + removed_indices=[], + ) + ds = lance.LanceDataset.commit( + ds.uri, + create_index_op, + read_version=ds.version, + ) + + unindexed_fragments = ds.unindexed_fragments("text_idx") + assert len(unindexed_fragments) == 2 + assert unindexed_fragments[0].id == frags[1].fragment_id + assert unindexed_fragments[1].id == frags[2].fragment_id + + indexed_fragments = [f for fs in ds.indexed_fragments("text_idx") for f in fs] + assert len(indexed_fragments) == 1 + assert indexed_fragments[0].id == frags[0].fragment_id + + +def test_dfs_query_then_fetch(tmp_tables, tmp_path): + ds = lance.write_dataset(tmp_tables[0], tmp_path, mode="overwrite") + ds = lance.write_dataset(tmp_tables[1], tmp_path, mode="append") + ds = lance.write_dataset(tmp_tables[2], tmp_path, mode="append") + indices = [] + frags = list(ds.get_fragments()) + for f in frags[:2]: + # we can create an inverted index distributely + index = ds.create_scalar_index("text", "INVERTED", fragment_ids=[f.fragment_id]) + assert isinstance(index, dict) + indices.append(index) + + index = ds.create_scalar_index( + "text2", "INVERTED", fragment_ids=[frags[0].fragment_id, frags[1].fragment_id] + ) + indices.append(index) + index = 
ds.create_scalar_index( + "text3", "INVERTED", fragment_ids=[frags[0].fragment_id] + ) + indices.append(index) + create_index_op = lance.LanceOperation.CreateIndex( + new_indices=indices, + removed_indices=[], + ) + + ds = lance.LanceDataset.commit( + ds.uri, + create_index_op, + read_version=ds.version, + ) + + # test query then fetch + text_query_fetch = ds.to_table( + full_text_query={"columns": ["text"], "query": "puppy"}, + prefilter=True, + with_row_id=True, + ) + assert sorted(text_query_fetch["_rowid"].to_pylist()) == [ + 0, + 1 << 32, + (1 << 32) + 1, + 2 << 32, + (2 << 32) + 1, + ] + + # test dfs query then fetch + text_dfs_query_fetch = ds.to_table( + full_text_query={ + "columns": ["text"], + "query": "puppy", + "search_type": "DfsQueryThenFetch", + }, + prefilter=True, + with_row_id=True, + ) + assert sorted(text_dfs_query_fetch["_rowid"].to_pylist()) == [ + 0, + 1 << 32, + (1 << 32) + 1, + 2 << 32, + (2 << 32) + 1, + ] + + def table_to_tuple(tb): + return list(zip(tb["_rowid"].to_pylist(), tb["_score"].to_pylist())) + + # it should be the same as dfs query then fetch for column text + text2_query_fetch = ds.to_table( + full_text_query={"columns": ["text2"], "query": "puppy"}, + prefilter=True, + with_row_id=True, + ) + assert sorted(table_to_tuple(text2_query_fetch)) == sorted( + table_to_tuple(text_dfs_query_fetch) + ) + + # for column text2, it should be the same as query then fetch + text2_dfs_query_fetch = ds.to_table( + full_text_query={ + "columns": ["text2"], + "query": "puppy", + "search_type": "DfsQueryThenFetch", + }, + prefilter=True, + with_row_id=True, + ) + assert sorted(table_to_tuple(text2_query_fetch)) == sorted( + table_to_tuple(text2_dfs_query_fetch) + ) + + text3_dfs_neutral = ds.to_table( + full_text_query={ + "columns": ["text3"], + "query": "puppy", + "search_type": "DfsQueryThenFetch", + }, + filter="sentiment='neutral'", + prefilter=True, + with_row_id=True, + ) + assert ( + sorted(table_to_tuple(text3_dfs_neutral)) + == sorted(table_to_tuple(text_query_fetch))[:1] + ) + + text3_neutral = ds.to_table( + full_text_query={"columns": ["text3"], "query": "puppy"}, + filter="sentiment='neutral'", + prefilter=True, + with_row_id=True, + ) + assert sorted(table_to_tuple(text3_neutral)) == sorted( + table_to_tuple(text3_dfs_neutral) + ) + + text_neutral = ds.to_table( + full_text_query={"columns": ["text"], "query": "puppy"}, + filter="sentiment='neutral'", + prefilter=True, + with_row_id=True, + ) + assert sorted(table_to_tuple(text_neutral)) == sorted(table_to_tuple(text3_neutral)) diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 39a00d7fb1..c5e850eef3 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -50,7 +50,7 @@ use lance::dataset::{ColumnAlteration, ProjectionRequest}; use lance::index::vector::utils::get_vector_type; use lance::index::{vector::VectorIndexParams, DatasetIndexInternalExt}; use lance_arrow::as_fixed_size_list_array; -use lance_index::scalar::InvertedIndexParams; +use lance_index::scalar::{InvertedIndexParams, SearchType}; use lance_index::{ optimize::OptimizeOptions, scalar::{FullTextSearchQuery, ScalarIndexParams, ScalarIndexType}, @@ -62,7 +62,7 @@ use lance_index::{ }; use lance_io::object_store::ObjectStoreParams; use lance_linalg::distance::MetricType; -use lance_table::format::Fragment; +use lance_table::format::{Fragment, Index}; use lance_table::io::commit::CommitHandler; use object_store::path::Path; use pyo3::exceptions::{PyStopIteration, PyTypeError}; @@ -82,7 +82,7 @@ use 
crate::file::object_store_from_uri_or_path; use crate::fragment::FileFragment; use crate::schema::LanceSchema; use crate::session::Session; -use crate::utils::PyLance; +use crate::utils::{export_vec, PyLance}; use crate::RT; use crate::{LanceReader, Scanner}; @@ -561,7 +561,21 @@ impl Dataset { } else { None }; - let full_text_query = FullTextSearchQuery::new(query).columns(columns); + let search_type = full_text_query.get_item("search_type")?; + let search_type = match search_type { + Some(search_type) => { + let search_type = search_type.to_string().to_lowercase(); + if search_type == SearchType::DfsQueryThenFetch.to_string().to_lowercase() { + SearchType::DfsQueryThenFetch + } else { + SearchType::QueryThenFetch + } + } + None => SearchType::QueryThenFetch, + }; + let full_text_query = FullTextSearchQuery::new(query) + .columns(columns) + .search_type(search_type); scanner .full_text_search(full_text_query) .map_err(|err| PyValueError::new_err(err.to_string()))?; @@ -1164,6 +1178,45 @@ impl Dataset { Ok(()) } + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = (columns, index_type, name = None, replace = None, storage_options = None, fragment_ids = None, kwargs = None))] + fn create_fragment_index( + &mut self, + columns: Vec, + index_type: &str, + name: Option, + replace: Option, + storage_options: Option>, + fragment_ids: Option>, + kwargs: Option<&Bound>, + ) -> PyResult> { + let columns: Vec<&str> = columns.iter().map(|s| &**s).collect(); + let index_type = index_type.to_uppercase(); + let idx_type = self.parse_index_type(&index_type)?; + log::info!("Creating index: type={}", index_type); + let params: Box = + self.parse_index_params(&columns, &index_type, storage_options, kwargs)?; + + let replace = replace.unwrap_or(true); + + let mut new_self = self.ds.as_ref().clone(); + let res = RT + .block_on( + None, + new_self.create_fragment_index( + &columns, + idx_type, + name, + params.as_ref(), + replace, + fragment_ids, + ), + )? + .map_err(|err| PyIOError::new_err(err.to_string()))?; + self.ds = Arc::new(new_self); + Ok(PyLance(res)) + } + #[pyo3(signature = (columns, index_type, name = None, replace = None, storage_options = None, kwargs = None))] fn create_index( &mut self, @@ -1176,90 +1229,10 @@ impl Dataset { ) -> PyResult<()> { let columns: Vec<&str> = columns.iter().map(|s| &**s).collect(); let index_type = index_type.to_uppercase(); - let idx_type = match index_type.as_str() { - "BTREE" => IndexType::Scalar, - "BITMAP" => IndexType::Bitmap, - "NGRAM" => IndexType::NGram, - "LABEL_LIST" => IndexType::LabelList, - "INVERTED" | "FTS" => IndexType::Inverted, - "IVF_FLAT" | "IVF_PQ" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => IndexType::Vector, - _ => { - return Err(PyValueError::new_err(format!( - "Index type '{index_type}' is not supported." 
- ))) - } - }; - + let idx_type = self.parse_index_type(&index_type)?; log::info!("Creating index: type={}", index_type); - let params: Box = match index_type.as_str() { - "BTREE" => Box::::default(), - "BITMAP" => Box::new(ScalarIndexParams { - // Temporary workaround until we add support for auto-detection of scalar index type - force_index_type: Some(ScalarIndexType::Bitmap), - }), - "NGRAM" => Box::new(ScalarIndexParams { - force_index_type: Some(ScalarIndexType::NGram), - }), - "LABEL_LIST" => Box::new(ScalarIndexParams { - force_index_type: Some(ScalarIndexType::LabelList), - }), - "INVERTED" | "FTS" => { - let mut params = InvertedIndexParams::default(); - if let Some(kwargs) = kwargs { - if let Some(with_position) = kwargs.get_item("with_position")? { - params.with_position = with_position.extract()?; - } - if let Some(base_tokenizer) = kwargs.get_item("base_tokenizer")? { - params.tokenizer_config = params - .tokenizer_config - .base_tokenizer(base_tokenizer.extract()?); - } - if let Some(language) = kwargs.get_item("language")? { - let language: PyBackedStr = - language.downcast::()?.clone().try_into()?; - params.tokenizer_config = - params.tokenizer_config.language(&language).map_err(|e| { - PyValueError::new_err(format!( - "can't set tokenizer language to {}: {:?}", - language, e - )) - })?; - } - if let Some(max_token_length) = kwargs.get_item("max_token_length")? { - params.tokenizer_config = params - .tokenizer_config - .max_token_length(max_token_length.extract()?); - } - if let Some(lower_case) = kwargs.get_item("lower_case")? { - params.tokenizer_config = - params.tokenizer_config.lower_case(lower_case.extract()?); - } - if let Some(stem) = kwargs.get_item("stem")? { - params.tokenizer_config = params.tokenizer_config.stem(stem.extract()?); - } - if let Some(remove_stop_words) = kwargs.get_item("remove_stop_words")? { - params.tokenizer_config = params - .tokenizer_config - .remove_stop_words(remove_stop_words.extract()?); - } - if let Some(ascii_folding) = kwargs.get_item("ascii_folding")? { - params.tokenizer_config = params - .tokenizer_config - .ascii_folding(ascii_folding.extract()?); - } - } - Box::new(params) - } - _ => { - let column_type = match self.ds.schema().field(columns[0]) { - Some(f) => f.data_type().clone(), - None => { - return Err(PyValueError::new_err("Column not found in dataset schema.")) - } - }; - prepare_vector_index_params(&index_type, &column_type, storage_options, kwargs)? - } - }; + let params: Box = + self.parse_index_params(&columns, &index_type, storage_options, kwargs)?; let replace = replace.unwrap_or(true); @@ -1283,6 +1256,33 @@ impl Dataset { Ok(()) } + fn unindexed_fragments(&self, name: &str) -> PyResult { + let result = RT + .block_on(None, self.ds.unindexed_fragments(name))? + .map_err(|err| PyIOError::new_err(err.to_string()))?; + + Python::with_gil(|py| { + let py_vec = export_vec(py, &result)?; + PyList::new(py, py_vec).map(|list| list.into()) + }) + } + + fn indexed_fragments(&self, name: &str) -> PyResult { + let result = RT + .block_on(None, self.ds.indexed_fragments(name))? 
+ .map_err(|err| PyIOError::new_err(err.to_string()))?; + Python::with_gil(|py| { + let result = result + .iter() + .map(|vec| { + let py_vec = export_vec(py, vec)?; + PyList::new(py, py_vec).map(|list| list.into()) + }) + .collect::, _>>()?; + PyList::new(py, result).map(|list| list.into()) + }) + } + fn count_fragments(&self) -> usize { self.ds.count_fragments() } @@ -1619,6 +1619,98 @@ impl Dataset { fn list_tags(&self) -> ::lance::error::Result> { RT.runtime.block_on(self.ds.tags.list()) } + + fn parse_index_params( + &mut self, + columns: &[&str], + index_type: &str, + storage_options: Option>, + kwargs: Option<&Bound>, + ) -> PyResult> { + match index_type { + "BTREE" => Ok(Box::::default()), + "BITMAP" => Ok(Box::new(ScalarIndexParams { + // Temporary workaround until we add support for auto-detection of scalar index type + force_index_type: Some(ScalarIndexType::Bitmap), + })), + "NGRAM" => Ok(Box::new(ScalarIndexParams { + force_index_type: Some(ScalarIndexType::NGram), + })), + "LABEL_LIST" => Ok(Box::new(ScalarIndexParams { + force_index_type: Some(ScalarIndexType::LabelList), + })), + "INVERTED" | "FTS" => { + let mut params = InvertedIndexParams::default(); + if let Some(kwargs) = kwargs { + if let Some(with_position) = kwargs.get_item("with_position")? { + params.with_position = with_position.extract()?; + } + if let Some(base_tokenizer) = kwargs.get_item("base_tokenizer")? { + params.tokenizer_config = params + .tokenizer_config + .base_tokenizer(base_tokenizer.extract()?); + } + if let Some(language) = kwargs.get_item("language")? { + let language: PyBackedStr = + language.downcast::()?.clone().try_into()?; + params.tokenizer_config = + params.tokenizer_config.language(&language).map_err(|e| { + PyValueError::new_err(format!( + "can't set tokenizer language to {}: {:?}", + language, e + )) + })?; + } + if let Some(max_token_length) = kwargs.get_item("max_token_length")? { + params.tokenizer_config = params + .tokenizer_config + .max_token_length(max_token_length.extract()?); + } + if let Some(lower_case) = kwargs.get_item("lower_case")? { + params.tokenizer_config = + params.tokenizer_config.lower_case(lower_case.extract()?); + } + if let Some(stem) = kwargs.get_item("stem")? { + params.tokenizer_config = params.tokenizer_config.stem(stem.extract()?); + } + if let Some(remove_stop_words) = kwargs.get_item("remove_stop_words")? { + params.tokenizer_config = params + .tokenizer_config + .remove_stop_words(remove_stop_words.extract()?); + } + if let Some(ascii_folding) = kwargs.get_item("ascii_folding")? { + params.tokenizer_config = params + .tokenizer_config + .ascii_folding(ascii_folding.extract()?); + } + } + Ok(Box::new(params)) + } + _ => { + let column_type = match self.ds.schema().field(columns[0]) { + Some(f) => f.data_type().clone(), + None => { + return Err(PyValueError::new_err("Column not found in dataset schema.")) + } + }; + prepare_vector_index_params(index_type, &column_type, storage_options, kwargs) + } + } + } + + fn parse_index_type(&self, index_type: &str) -> PyResult { + match index_type { + "BTREE" => Ok(IndexType::Scalar), + "BITMAP" => Ok(IndexType::Bitmap), + "NGRAM" => Ok(IndexType::NGram), + "LABEL_LIST" => Ok(IndexType::LabelList), + "INVERTED" | "FTS" => Ok(IndexType::Inverted), + "IVF_FLAT" | "IVF_PQ" | "IVF_HNSW_PQ" | "IVF_HNSW_SQ" => Ok(IndexType::Vector), + _ => Err(PyValueError::new_err(format!( + "Index type '{index_type}' is not supported." 
+ ))), + } + } } #[pyfunction(name = "_write_dataset")] diff --git a/python/src/transaction.rs b/python/src/transaction.rs index 33ed60eaa5..c6effb86eb 100644 --- a/python/src/transaction.rs +++ b/python/src/transaction.rs @@ -9,8 +9,8 @@ use lance::dataset::transaction::{ use lance::datatypes::Schema; use lance_table::format::{DataFile, Fragment, Index}; use pyo3::exceptions::PyValueError; -use pyo3::types::PySet; -use pyo3::{intern, prelude::*}; +use pyo3::types::{PyDict, PyList, PyNone, PySet}; +use pyo3::{intern, prelude::*, PyTypeCheck}; use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python}; use uuid::Uuid; @@ -47,6 +47,78 @@ impl<'py> IntoPyObject<'py> for PyLance<&DataReplacementGroup> { } } +impl FromPyObject<'_> for PyLance { + fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { + let uuid = ob.get_item("uuid")?.to_string(); + let name = ob.get_item("name")?.extract()?; + let fields = ob.get_item("fields")?.extract()?; + let dataset_version = ob.get_item("version")?.extract()?; + + let fragment_ids = ob.get_item("fragment_ids")?; + let fragment_ids = if PySet::type_check(&fragment_ids) { + let fragment_ids_ref: &Bound<'_, PySet> = fragment_ids.downcast()?; + fragment_ids_ref + .into_iter() + .map(|id| id.extract()) + .collect::>>()? + } else if PyList::type_check(&fragment_ids) { + let fragment_ids_ref: &Bound<'_, PyList> = fragment_ids.downcast()?; + fragment_ids_ref + .into_iter() + .map(|id| id.extract()) + .collect::>>()? + } else { + return Err(PyValueError::new_err("Invalid fragment_ids")); + }; + let fragment_bitmap = Some(fragment_ids.into_iter().collect()); + Ok(Self(Index { + uuid: Uuid::parse_str(&uuid).map_err(|e| PyValueError::new_err(e.to_string()))?, + name, + fields, + dataset_version, + fragment_bitmap, + // TODO: we should use lance::dataset::Dataset::commit_existing_index once + // we have a way to determine index details from an existing index. 
+ index_details: None, + })) + } +} + +impl<'py> IntoPyObject<'py> for PyLance<&Index> { + type Target = PyDict; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + let uuid = self.0.uuid.to_string().into_pyobject(py)?; + let name = self.0.name.clone().into_pyobject(py)?; + let fields = export_vec(py, &self.0.fields)?; + let dataset_version = self.0.dataset_version.into_pyobject(py)?; + let fragment_ids = match &self.0.fragment_bitmap { + Some(bitmap) => bitmap.into_iter().collect::>().into_pyobject(py)?, + None => PyNone::get(py).to_owned().into_any(), + }; + + let kwargs = PyDict::new(py); + kwargs.set_item("uuid", uuid).unwrap(); + kwargs.set_item("name", name).unwrap(); + kwargs.set_item("fields", fields).unwrap(); + kwargs.set_item("version", dataset_version).unwrap(); + kwargs.set_item("fragment_ids", fragment_ids).unwrap(); + Ok(kwargs) + } +} + +impl<'py> IntoPyObject<'py> for PyLance { + type Target = PyDict; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> Result { + PyLance(&self.0).into_pyobject(py) + } +} + impl FromPyObject<'_> for PyLance { fn extract_bound(ob: &Bound<'_, PyAny>) -> PyResult { match class_name(ob)?.as_str() { @@ -119,33 +191,10 @@ impl FromPyObject<'_> for PyLance { Ok(Self(op)) } "CreateIndex" => { - let uuid = ob.getattr("uuid")?.to_string(); - let name = ob.getattr("name")?.extract()?; - let fields = ob.getattr("fields")?.extract()?; - let dataset_version = ob.getattr("dataset_version")?.extract()?; - - let fragment_ids = ob.getattr("fragment_ids")?; - let fragment_ids_ref: &Bound<'_, PySet> = fragment_ids.downcast()?; - let fragment_ids = fragment_ids_ref - .into_iter() - .map(|id| id.extract()) - .collect::>>()?; - let fragment_bitmap = Some(fragment_ids.into_iter().collect()); - - let new_indices = vec![Index { - uuid: Uuid::parse_str(&uuid) - .map_err(|e| PyValueError::new_err(e.to_string()))?, - name, - fields, - dataset_version, - fragment_bitmap, - // TODO: we should use lance::dataset::Dataset::commit_existing_index once - // we have a way to determine index details from an existing index. 
- index_details: None, - }]; - + let removed_indices = extract_vec(&ob.getattr("removed_indices")?)?; + let new_indices = extract_vec(&ob.getattr("new_indices")?)?; let op = Operation::CreateIndex { - removed_indices: Vec::new(), + removed_indices, new_indices, }; Ok(Self(op)) @@ -211,6 +260,17 @@ impl<'py> IntoPyObject<'py> for PyLance<&Operation> { .expect("Failed to get Update class"); cls.call1((removed_fragment_ids, updated_fragments, new_fragments)) } + Operation::CreateIndex { + removed_indices, + new_indices, + } => { + let removed_indices = export_vec(py, removed_indices.as_slice())?; + let new_indices = export_vec(py, new_indices.as_slice())?; + let cls = namespace + .getattr("CreateIndex") + .expect("Failed to get CreateIndex class"); + cls.call1((removed_indices, new_indices)) + } Operation::DataReplacement { replacements } => { let replacements = export_vec(py, replacements.as_slice())?; let cls = namespace diff --git a/python/src/utils.rs b/python/src/utils.rs index 1b1a78cd02..389618dc82 100644 --- a/python/src/utils.rs +++ b/python/src/utils.rs @@ -281,3 +281,13 @@ pub fn class_name(ob: &Bound<'_, PyAny>) -> PyResult { None => Ok(full_name), } } + +impl<'py> IntoPyObject<'py> for PyLance<&i32> { + type Target = PyAny; + type Output = Bound<'py, Self::Target>; + type Error = PyErr; + + fn into_pyobject(self, py: Python<'py>) -> PyResult { + self.0.into_bound_py_any(py) + } +} diff --git a/rust/lance-index/src/scalar.rs b/rust/lance-index/src/scalar.rs index 7cba5815e6..58a65f8132 100644 --- a/rust/lance-index/src/scalar.rs +++ b/rust/lance-index/src/scalar.rs @@ -4,7 +4,7 @@ //! Scalar indices for metadata search & filtering use std::collections::HashMap; -use std::fmt::Debug; +use std::fmt::{Debug, Display}; use std::{any::Any, ops::Bound, sync::Arc}; use arrow::buffer::{OffsetBuffer, ScalarBuffer}; @@ -250,6 +250,21 @@ impl PartialEq for dyn AnyQuery { } } +#[derive(Debug, Clone, PartialEq)] +pub enum SearchType { + QueryThenFetch, + DfsQueryThenFetch, +} + +impl Display for SearchType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::QueryThenFetch => write!(f, "QueryThenFetch"), + Self::DfsQueryThenFetch => write!(f, "DfsQueryThenFetch"), + } + } +} + /// A full text search query #[derive(Debug, Clone, PartialEq)] pub struct FullTextSearchQuery { @@ -265,6 +280,8 @@ pub struct FullTextSearchQuery { /// Increasing this value will reduce the recall and improve the performance /// 1.0 is the value that would give the best performance without recall loss pub wand_factor: Option, + + pub search_type: SearchType, } impl FullTextSearchQuery { @@ -274,6 +291,7 @@ impl FullTextSearchQuery { limit: None, columns: vec![], wand_factor: None, + search_type: SearchType::QueryThenFetch, } } @@ -284,6 +302,11 @@ impl FullTextSearchQuery { self } + pub fn search_type(mut self, search_type: SearchType) -> Self { + self.search_type = search_type; + self + } + pub fn limit(mut self, limit: Option) -> Self { self.limit = limit; self diff --git a/rust/lance-index/src/scalar/inverted/index.rs b/rust/lance-index/src/scalar/inverted/index.rs index c2091e2c75..5dbdc2832d 100644 --- a/rust/lance-index/src/scalar/inverted/index.rs +++ b/rust/lance-index/src/scalar/inverted/index.rs @@ -38,7 +38,7 @@ use super::{wand::*, InvertedIndexBuilder, TokenizerConfig}; use crate::prefilter::{NoFilter, PreFilter}; use crate::scalar::{ AnyQuery, FullTextSearchQuery, IndexReader, IndexStore, InvertedIndexParams, SargableQuery, - ScalarIndex, SearchResult, + 
ScalarIndex, SearchResult, SearchType, }; use crate::Index; @@ -69,6 +69,63 @@ lazy_static! { .unwrap_or(512 * 1024 * 1024); } +#[macro_export] +macro_rules! as_inverted_index { + ($index:expr, $uuid:expr) => { + $index + .as_any() + .downcast_ref::() + .ok_or_else(|| { + DataFusionError::Execution(format!("Index {} is not an inverted index", $uuid,)) + }) + }; +} + +#[derive(Clone, Debug)] +pub struct DistributedFrequency { + posting_lens: HashMap, + num_tokens: u64, + num_docs: usize, +} + +impl DistributedFrequency { + fn avgdl(&self) -> f32 { + self.num_tokens as f32 / self.num_docs as f32 + } + fn posting_frequency(&self, token: &String) -> Option { + self.posting_lens + .get(token) + .map(|posting_len| PostingDistributedFrequency::new(self.num_docs, *posting_len)) + } +} + +#[derive(Clone, Debug)] +pub struct ParsedQuery { + query: FullTextSearchQuery, + tokens: Vec, + dfs: DistributedFrequency, +} + +impl ParsedQuery { + pub fn count(&mut self, index: &InvertedIndex) { + for token in &self.tokens { + let len = index.posting_len(token); + if len > 0 { + match self.dfs.posting_lens.get_mut(token) { + Some(v) => { + (*v) += len; + } + None => { + self.dfs.posting_lens.insert(token.clone(), len); + } + } + } + } + self.dfs.num_tokens += index.num_tokens(); + self.dfs.num_docs += index.num_docs(); + } +} + #[derive(Clone)] pub struct InvertedIndex { params: InvertedIndexParams, @@ -97,6 +154,44 @@ impl DeepSizeOf for InvertedIndex { } impl InvertedIndex { + pub fn parse(&self, query: &FullTextSearchQuery) -> Result { + let mut tokenizer = self.tokenizer.clone(); + let tokens = collect_tokens(&query.query, &mut tokenizer, None); + let posting_lens = tokens + .iter() + .map(|token| { + let posting_len = self.posting_len(token); + (token.clone(), posting_len) + }) + .collect(); + let num_docs = self.num_docs(); + let num_tokens = self.num_tokens(); + let dfs = DistributedFrequency { + posting_lens, + num_docs, + num_tokens, + }; + Ok(ParsedQuery { + query: query.clone(), + tokens, + dfs, + }) + } + + pub fn num_tokens(&self) -> u64 { + self.docs.total_tokens + } + + pub fn num_docs(&self) -> usize { + self.docs.token_count.len() + } + + pub fn posting_len(&self, token: &str) -> usize { + match self.tokens.get(token) { + Some(token_id) => self.inverted_list.posting_len(token_id), + None => 0, + } + } // map tokens to token ids // ignore tokens that are not in the index cause they won't contribute to the search #[instrument(level = "debug", skip_all)] @@ -107,17 +202,17 @@ impl InvertedIndex { .collect() } - #[instrument(level = "debug", skip_all)] - pub async fn full_text_search( - &self, - query: &FullTextSearchQuery, - prefilter: Arc, - ) -> Result> { - let mut tokenizer = self.tokenizer.clone(); - let tokens = collect_tokens(&query.query, &mut tokenizer, None); - let token_ids = self.map(&tokens).into_iter(); - let token_ids = if !is_phrase_query(&query.query) { - token_ids.sorted_unstable().dedup().collect() + fn zip_with_id(&self, texts: &[String]) -> Vec<(String, u32)> { + texts + .iter() + .filter_map(|text| self.tokens.get(text).map(|id| (text.clone(), id))) + .collect() + } + + fn token_ids(&self, tokens: &[String], is_phrase_query: bool) -> Result> { + let token_ids = self.zip_with_id(tokens).into_iter(); + if !is_phrase_query { + Ok(token_ids.dedup_by(|x, y| x.1 == y.1).collect()) } else { if !self.inverted_list.has_positions() { return Err(Error::Index { message: "position is not found but required for phrase queries, try recreating the index with position".to_owned(), location: 
location!() }); @@ -127,9 +222,27 @@ impl InvertedIndex { if token_ids.len() != tokens.len() { return Ok(Vec::new()); } - token_ids - }; - self.bm25_search(token_ids, query, prefilter).await + Ok(token_ids) + } + } + + pub async fn parsed_search( + &self, + parsed: &ParsedQuery, + prefilter: Arc, + ) -> Result> { + let token_ids = self.token_ids(&parsed.tokens, is_phrase_query(&parsed.query.query))?; + self.bm25_search(token_ids, parsed, prefilter).await + } + + #[instrument(level = "debug", skip_all)] + pub async fn full_text_search( + &self, + query: &FullTextSearchQuery, + prefilter: Arc, + ) -> Result> { + let parsed = self.parse(query)?; + self.parsed_search(&parsed, prefilter).await } // search the documents that contain the query @@ -138,42 +251,56 @@ impl InvertedIndex { #[instrument(level = "debug", skip_all)] async fn bm25_search( &self, - token_ids: Vec, - query: &FullTextSearchQuery, + token_ids: Vec<(String, u32)>, + parsed: &ParsedQuery, prefilter: Arc, ) -> Result> { - let limit = query + let limit = parsed + .query .limit .map(|limit| limit as usize) .unwrap_or(usize::MAX); - let wand_factor = query.wand_factor.unwrap_or(1.0); + let wand_factor = parsed.query.wand_factor.unwrap_or(1.0); let mask = prefilter.mask(); - let is_phrase_query = is_phrase_query(&query.query); - let postings = stream::iter(token_ids) + let is_phrase_query = is_phrase_query(&parsed.query.query); + let token_id_dfs: Vec<(u32, Option)> = token_ids + .into_iter() + .map(|(token, id)| { + let dfs = if parsed.query.search_type == SearchType::DfsQueryThenFetch { + parsed.dfs.posting_frequency(&token) + } else { + None + }; + (id, dfs) + }) + .collect(); + let postings = stream::iter(token_id_dfs) .enumerate() .zip(repeat_with(|| (self.inverted_list.clone(), mask.clone()))) - .map(|((position, token_id), (inverted_list, mask))| async move { - let posting = inverted_list - .posting_list(token_id, is_phrase_query) - .await?; - Result::Ok(PostingIterator::new( - token_id, - position as i32, - posting, - self.docs.len(), - mask, - )) - }) + .map( + |((position, (token_id, posting_frequency)), (inverted_list, mask))| async move { + let posting = inverted_list + .posting_list(token_id, is_phrase_query) + .await?; + Result::Ok(PostingIterator::new( + token_id, + position as i32, + posting, + parsed.dfs.num_docs, + mask, + posting_frequency.map(|dfs| dfs.len), + )) + }, + ) // Use compute count since data hopefully cached .buffered(get_num_compute_intensive_cpus()) .try_collect::>() .await?; - - let mut wand = Wand::new(self.docs.len(), postings.into_iter()); + let avgdl = parsed.dfs.avgdl(); + let mut wand = Wand::new(parsed.dfs.num_docs, postings.into_iter()); wand.search(is_phrase_query, limit, wand_factor, |doc, freq| { - let doc_norm = - K1 * (1.0 - B + B * self.docs.num_tokens(doc) as f32 / self.docs.average_length()); + let doc_norm = K1 * (1.0 - B + B * self.docs.num_tokens(doc) as f32 / avgdl); freq / (freq + doc_norm) }) .await @@ -1031,22 +1158,22 @@ fn do_flat_full_text_search( pub fn flat_bm25_search( batch: RecordBatch, doc_col: &str, - inverted_list: &InvertedListReader, - query_tokens: &HashSet, - query_token_ids: &HashMap>, + parsed: &ParsedQuery, tokenizer: &mut tantivy::tokenizer::TextAnalyzer, - avgdl: f32, - num_docs: usize, ) -> std::result::Result { + let query_tokens: HashSet = parsed.tokens.clone().into_iter().dedup().collect(); + let doc_iter = iter_str_array(&batch[doc_col]); let mut scores = Vec::with_capacity(batch.num_rows()); + let avgdl = parsed.dfs.avgdl(); + for doc in doc_iter { 
let Some(doc) = doc else { scores.push(0.0); continue; }; - let doc_tokens = collect_tokens(doc, tokenizer, Some(query_tokens)); + let doc_tokens = collect_tokens(doc, tokenizer, Some(&query_tokens)); let doc_norm = K1 * (1.0 - B + B * doc_tokens.len() as f32 / avgdl); let mut doc_token_count = HashMap::new(); for token in doc_tokens { @@ -1056,17 +1183,14 @@ pub fn flat_bm25_search( .or_insert(1); } let mut score = 0.0; - for (token, token_id) in query_token_ids.iter() { + for token in query_tokens.iter() { let freq = doc_token_count.get(token).copied().unwrap_or_default() as f32; - - let idf = if let Some(token_id) = token_id { - // for known token, we just use the index's metadata to calculate the score - // it's not accurate but it's good enough for ranking - idf(inverted_list.posting_len(*token_id), num_docs) + let idf = if let Some(freq) = parsed.dfs.posting_frequency(token) { + idf(freq.len, freq.num_docs) } else { // for unknown token, we set the idf to a very high value // so that the new token will significantly effect the score - idf(1, num_docs) + idf(1, parsed.dfs.num_docs) }; score += idf * (freq * (K1 + 1.0) / (freq + doc_norm)); } @@ -1083,35 +1207,13 @@ pub fn flat_bm25_search( pub fn flat_bm25_search_stream( input: SendableRecordBatchStream, doc_col: String, - query: FullTextSearchQuery, + parsed: ParsedQuery, index: &InvertedIndex, ) -> SendableRecordBatchStream { let mut tokenizer = index.tokenizer.clone(); - let query_token_ids = collect_tokens(&query.query, &mut tokenizer, None) - .into_iter() - .dedup() - .map(|token| { - let token_id = index.tokens.get(&token); - (token, token_id) - }) - .collect::>(); - let query_tokens = query_token_ids.keys().cloned().collect::>(); - let inverted_list = index.inverted_list.clone(); - let num_docs = index.docs.len(); - let avgdl = index.docs.average_length(); - let stream = input.map(move |batch| { let batch = batch?; - let scored_batch = flat_bm25_search( - batch, - &doc_col, - inverted_list.as_ref(), - &query_tokens, - &query_token_ids, - &mut tokenizer, - avgdl, - num_docs, - )?; + let scored_batch = flat_bm25_search(batch, &doc_col, &parsed, &mut tokenizer)?; // filter out rows with score 0 let score_col = scored_batch[SCORE_COL].as_primitive::(); diff --git a/rust/lance-index/src/scalar/inverted/wand.rs b/rust/lance-index/src/scalar/inverted/wand.rs index 9bb176e0d7..e7ef7d2448 100644 --- a/rust/lance-index/src/scalar/inverted/wand.rs +++ b/rust/lance-index/src/scalar/inverted/wand.rs @@ -16,6 +16,18 @@ use super::builder::OrderedDoc; use super::index::{idf, K1}; use super::{DocInfo, PostingList}; +#[derive(Clone, Debug)] +pub struct PostingDistributedFrequency { + pub num_docs: usize, + pub len: usize, +} + +impl PostingDistributedFrequency { + pub fn new(num_docs: usize, len: usize) -> Self { + Self { num_docs, len } + } +} + #[derive(Clone)] pub struct PostingIterator { token_id: u32, @@ -24,6 +36,7 @@ pub struct PostingIterator { index: usize, mask: Arc, approximate_upper_bound: f32, + len: usize, } impl PartialEq for PostingIterator { @@ -58,12 +71,17 @@ impl PostingIterator { list: PostingList, num_doc: usize, mask: Arc, + dfs_len: Option, ) -> Self { - let approximate_upper_bound = match list.max_score() { - Some(max_score) => max_score, - None => idf(list.len(), num_doc) * (K1 + 1.0), + let len: usize = match &dfs_len { + Some(dfs_len) => *dfs_len, + None => list.len(), + }; + let approximate_upper_bound = match (&dfs_len, list.max_score()) { + (None, Some(max_score)) => max_score, + (None, None) => idf(list.len(), 
num_doc) * (K1 + 1.0), + (Some(dfs_len), _) => idf(*dfs_len, num_doc) * (K1 + 1.0), }; - // move the iterator to the first selected document. This is important // because caller might directly call `doc()` without calling `next()`. let mut index = 0; @@ -78,6 +96,7 @@ impl PostingIterator { index, mask, approximate_upper_bound, + len, } } @@ -192,7 +211,7 @@ impl Wand { break; } debug_assert!(cur_doc.row_id == doc_id); - let idf = idf(posting.list.len(), self.num_docs); + let idf = idf(posting.len, self.num_docs); score += idf * (K1 + 1.0) * scorer(doc_id, cur_doc.frequency); } score diff --git a/rust/lance-index/src/traits.rs b/rust/lance-index/src/traits.rs index 82db7718f4..272efae1d2 100644 --- a/rust/lance-index/src/traits.rs +++ b/rust/lance-index/src/traits.rs @@ -35,6 +35,16 @@ pub trait DatasetIndexExt { replace: bool, ) -> Result<()>; + async fn create_fragment_index( + &mut self, + columns: &[&str], + index_type: IndexType, + name: Option, + params: &dyn IndexParams, + replace: bool, + fragment_ids: Option>, + ) -> Result; + /// Drop indices by name. /// /// Upon finish, a new dataset version is generated. diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index 11839d8b21..7fda5fb0aa 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -1080,7 +1080,7 @@ impl Dataset { // Gets a filtered list of fragments from ids in O(N) time instead of using // `get_fragment` which would require O(N^2) time. - fn get_frags_from_ordered_ids(&self, ordered_ids: &[u32]) -> Vec> { + pub fn get_frags_from_ordered_ids(&self, ordered_ids: &[u32]) -> Vec> { let mut fragments = Vec::with_capacity(ordered_ids.len()); let mut id_iter = ordered_ids.iter(); let mut id = id_iter.next(); diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index 9e9df5ef4c..e4b5dc1587 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -191,14 +191,15 @@ fn vector_index_details() -> prost_types::Any { #[async_trait] impl DatasetIndexExt for Dataset { #[instrument(skip_all)] - async fn create_index( + async fn create_fragment_index( &mut self, columns: &[&str], index_type: IndexType, name: Option, params: &dyn IndexParams, replace: bool, - ) -> Result<()> { + fragment_ids: Option>, + ) -> Result { if columns.len() != 1 { return Err(Error::Index { message: "Only support building index on 1 column at the moment".to_string(), @@ -237,6 +238,11 @@ impl DatasetIndexExt for Dataset { } } + let fragment_bitmap = match &fragment_ids { + Some(fragment_ids) => Some(fragment_ids.iter().collect()), + None => Some(self.get_fragments().iter().map(|f| f.id() as u32).collect()), + }; + let index_id = Uuid::new_v4(); let index_details: prost_types::Any = match (index_type, params.index_name()) { ( @@ -248,7 +254,8 @@ impl DatasetIndexExt for Dataset { LANCE_SCALAR_INDEX, ) => { let params = ScalarIndexParams::new(index_type.try_into()?); - build_scalar_index(self, column, &index_id.to_string(), ¶ms).await? + build_scalar_index(self, column, &index_id.to_string(), fragment_ids, ¶ms) + .await? } (IndexType::Scalar, LANCE_SCALAR_INDEX) => { // Guess the index type @@ -259,7 +266,8 @@ impl DatasetIndexExt for Dataset { message: "Scalar index type must take a ScalarIndexParams".to_string(), location: location!(), })?; - build_scalar_index(self, column, &index_id.to_string(), params).await? + build_scalar_index(self, column, &index_id.to_string(), fragment_ids, params) + .await? } (IndexType::Inverted, _) => { // Inverted index params. 
@@ -271,7 +279,14 @@ impl DatasetIndexExt for Dataset { location: location!(), })?; - build_inverted_index(self, column, &index_id.to_string(), inverted_params).await?; + build_inverted_index( + self, + column, + &index_id.to_string(), + fragment_ids, + inverted_params, + ) + .await?; inverted_index_details() } (IndexType::Vector, LANCE_VECTOR_INDEX) => { @@ -330,15 +345,28 @@ impl DatasetIndexExt for Dataset { }); } }; - - let new_idx = IndexMetadata { + Ok(IndexMetadata { uuid: index_id, name: index_name, fields: vec![field.id], dataset_version: self.manifest.version, - fragment_bitmap: Some(self.get_fragments().iter().map(|f| f.id() as u32).collect()), + fragment_bitmap, index_details: Some(index_details), - }; + }) + } + + #[instrument(skip_all)] + async fn create_index( + &mut self, + columns: &[&str], + index_type: IndexType, + name: Option, + params: &dyn IndexParams, + replace: bool, + ) -> Result<()> { + let new_idx = self + .create_fragment_index(columns, index_type, name, params, replace, None) + .await?; let transaction = Transaction::new( self.manifest.version, Operation::CreateIndex { diff --git a/rust/lance/src/index/scalar.rs b/rust/lance/src/index/scalar.rs index 6b03ae486c..5375924fff 100644 --- a/rust/lance/src/index/scalar.rs +++ b/rust/lance/src/index/scalar.rs @@ -11,6 +11,7 @@ use async_trait::async_trait; use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use datafusion::physical_plan::SendableRecordBatchStream; use futures::TryStreamExt; +use itertools::Itertools; use lance_core::{Error, Result}; use lance_datafusion::{chunker::chunk_concat_stream, exec::LanceExecutionOptions}; use lance_index::scalar::btree::DEFAULT_BTREE_BATCH_SIZE; @@ -42,6 +43,7 @@ const TRAINING_UPDATE_FREQ: usize = 1000000; struct TrainingRequest { dataset: Arc, column: String, + fragment_ids: Option>, } #[async_trait] @@ -70,6 +72,24 @@ impl TrainingRequest { let num_rows = self.dataset.count_all_rows().await?; let mut scan = self.dataset.scan(); + if let Some(ref fragment_ids) = self.fragment_ids { + let fragment_ids = fragment_ids.clone().into_iter().dedup().collect_vec(); + let frags = self.dataset.get_frags_from_ordered_ids(&fragment_ids); + let frags: Result> = fragment_ids + .iter() + .zip(frags) + .map(|(id, frag)| { + let Some(frag) = frag else { + return Err(Error::InvalidInput { + source: format!("No fragment with id {}", id).into(), + location: location!(), + }); + }; + Ok(frag.metadata().clone()) + }) + .collect(); + scan.with_fragments(frags?); + } let column_field = self.dataset @@ -231,11 +251,13 @@ pub(super) async fn build_scalar_index( dataset: &Dataset, column: &str, uuid: &str, + fragment_ids: Option>, params: &ScalarIndexParams, ) -> Result { let training_request = Box::new(TrainingRequest { dataset: Arc::new(dataset.clone()), column: column.to_string(), + fragment_ids, }); let field = dataset.schema().field(column).ok_or(Error::InvalidInput { source: format!("No column with name {}", column).into(), @@ -319,11 +341,13 @@ pub(super) async fn build_inverted_index( dataset: &Dataset, column: &str, uuid: &str, + fragment_ids: Option>, params: &InvertedIndexParams, ) -> Result<()> { let training_request = Box::new(TrainingRequest { dataset: Arc::new(dataset.clone()), column: column.to_string(), + fragment_ids, }); let index_store = LanceIndexStore::from_dataset(dataset, uuid); train_inverted_index(training_request, &index_store, params.clone()).await diff --git a/rust/lance/src/io/exec/fts.rs b/rust/lance/src/io/exec/fts.rs index e8409e9830..4962396989 100644 
--- a/rust/lance/src/io/exec/fts.rs +++ b/rust/lance/src/io/exec/fts.rs @@ -15,9 +15,12 @@ use datafusion::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan, Pla use datafusion_physical_expr::{EquivalenceProperties, Partitioning}; use futures::stream::{self}; use futures::{StreamExt, TryStreamExt}; +use lance_index::as_inverted_index; use lance_index::prefilter::{FilterLoader, PreFilter}; -use lance_index::scalar::inverted::{flat_bm25_search_stream, InvertedIndex, FTS_SCHEMA}; -use lance_index::scalar::FullTextSearchQuery; +use lance_index::scalar::inverted::{ + flat_bm25_search_stream, InvertedIndex, ParsedQuery, FTS_SCHEMA, +}; +use lance_index::scalar::{FullTextSearchQuery, SearchType}; use lance_table::format::Index; use tracing::instrument; @@ -29,6 +32,30 @@ use super::utils::{ }; use super::PreFilterSource; +async fn parse_query( + query: &FullTextSearchQuery, + column: &str, + indices: &[Index], + ds: &Arc, +) -> DataFusionResult { + let mut parsed = None; + for index in indices { + let uuid = index.uuid.to_string(); + let index = ds.open_generic_index(column, &uuid).await?; + let index = as_inverted_index!(index, uuid)?; + match &mut parsed { + None => parsed = Some(index.parse(query)?), + Some(parsed) => parsed.count(index), + } + } + match parsed { + Some(parsed) => Ok(parsed), + None => Err(DataFusionError::Execution( + "Unable to parse query".to_string(), + )), + } +} + /// An execution node that performs full text search /// /// This node would perform full text search with inverted index on the dataset. @@ -159,60 +186,86 @@ impl ExecutionPlan for FtsExec { let prefilter_source = self.prefilter_source.clone(); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - let indices = self.indices.clone(); - let stream = stream::iter(indices) - .map(move |(column, indices)| { - let index_meta = indices[0].clone(); - let uuid = index_meta.uuid.to_string(); - let query = query.clone(); - let ds = ds.clone(); - let context = context.clone(); - let prefilter_source = prefilter_source.clone(); + let stream = stream::iter(self.indices.clone()); - async move { - let prefilter_loader = match &prefilter_source { - PreFilterSource::FilteredRowIds(src_node) => { - let stream = src_node.execute(partition, context.clone())?; - Some(Box::new(FilteredRowIdsToPrefilter(stream)) - as Box) - } - PreFilterSource::ScalarIndexQuery(src_node) => { - let stream = src_node.execute(partition, context.clone())?; - Some(Box::new(SelectionVectorToPrefilter(stream)) - as Box) + let stream = match query.search_type { + SearchType::QueryThenFetch => stream + .map(|(column, indices)| (column, indices, Option::::None)) + .boxed(), + SearchType::DfsQueryThenFetch => { + let ds = ds.clone(); + let query = query.clone(); + stream + .then(move |(column, indices)| { + let ds = ds.clone(); + let query = query.clone(); + async move { + match parse_query(&query, &column, &indices, &ds).await { + Ok(parsed) => (column, indices, Some(parsed)), + // use query then fetch if distributed frequency collection failed + Err(_) => (column, indices, Option::::None), + } } - PreFilterSource::None => None, - }; - let pre_filter = Arc::new(DatasetPreFilter::new( - ds.clone(), - &[index_meta], - prefilter_loader, - )); + }) + .boxed() + } + }; - let index = ds.open_generic_index(&column, &uuid).await?; - let index = - index - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Execution(format!( - "Index {} is not an inverted index", - uuid, - )) - })?; - pre_filter.wait_for_ready().await?; - 
let results = index.full_text_search(&query, pre_filter).await?; - - let (row_ids, scores): (Vec, Vec) = results.into_iter().unzip(); - let batch = RecordBatch::try_new( - FTS_SCHEMA.clone(), - vec![ - Arc::new(UInt64Array::from(row_ids)), - Arc::new(Float32Array::from(scores)), - ], - )?; - Ok::<_, DataFusionError>(batch) + let stream = stream + .flat_map(move |(column, indices, parsed)| { + let mut all_batches = Vec::with_capacity(indices.len()); + + for index_meta in indices { + let parsed = parsed.clone(); + let query = query.clone(); + let ds = ds.clone(); + let context = context.clone(); + let prefilter_source = prefilter_source.clone(); + let column = column.clone(); + all_batches.push(async move { + let uuid = index_meta.uuid.to_string(); + let prefilter_loader = match &prefilter_source { + PreFilterSource::FilteredRowIds(src_node) => { + let stream = src_node.execute(partition, context.clone())?; + Some(Box::new(FilteredRowIdsToPrefilter(stream)) + as Box) + } + PreFilterSource::ScalarIndexQuery(src_node) => { + let stream = src_node.execute(partition, context.clone())?; + Some(Box::new(SelectionVectorToPrefilter(stream)) + as Box) + } + PreFilterSource::None => None, + }; + let pre_filter = Arc::new(DatasetPreFilter::new( + ds.clone(), + &[index_meta], + prefilter_loader, + )); + + let index = ds.open_generic_index(&column, &uuid).await?; + let index = as_inverted_index!(index, uuid)?; + pre_filter.wait_for_ready().await?; + + let parsed = match parsed { + Some(parsed) => parsed, + None => index.parse(&query)?, + }; + + let results = index.parsed_search(&parsed, pre_filter).await?; + + let (row_ids, scores): (Vec, Vec) = results.into_iter().unzip(); + let batch = RecordBatch::try_new( + FTS_SCHEMA.clone(), + vec![ + Arc::new(UInt64Array::from(row_ids)), + Arc::new(Float32Array::from(scores)), + ], + )?; + Ok::<_, DataFusionError>(batch) + }); } + stream::iter(all_batches) }) .buffered(self.indices.len()); let schema = self.schema(); @@ -346,24 +399,18 @@ impl ExecutionPlan for FlatFtsExec { let query = query.clone(); let ds = ds.clone(); let context = context.clone(); - async move { - let index = ds.open_generic_index(&column, &uuid).await?; - let index = - index - .as_any() - .downcast_ref::() - .ok_or_else(|| { - DataFusionError::Execution(format!( - "Index {} is not an inverted index", - uuid, - )) - })?; - let unindexed_stream = input.execute(partition, context)?; + let index = ds.open_generic_index(&column, &uuid).await?; + let index = as_inverted_index!(index, uuid)?; + let parsed = match query.search_type { + SearchType::DfsQueryThenFetch => { + parse_query(&query, &column, &indices, &ds).await? + } + SearchType::QueryThenFetch => index.parse(&query)?, + }; let unindexed_result_stream = - flat_bm25_search_stream(unindexed_stream, column, query, index); - + flat_bm25_search_stream(unindexed_stream, column, parsed, index); Ok::<_, DataFusionError>(unindexed_result_stream) } })
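Taken together with the Rust changes above, the new search type is surfaced through the existing full_text_query argument. A hedged sketch of both modes on a dataset like the one in the tests (the column name and query string are carried over from the tests; an inverted index covering the "text" column is assumed to be committed):

# Default QueryThenFetch: each index delta is scored with its own local statistics.
qtf = ds.to_table(
    full_text_query={"columns": ["text"], "query": "puppy"},
    prefilter=True,
    with_row_id=True,
)

# DfsQueryThenFetch: document and posting frequencies are aggregated across all
# index deltas first, so BM25 scores stay consistent when the index is split by fragment.
dfs = ds.to_table(
    full_text_query={
        "columns": ["text"],
        "query": "puppy",
        "search_type": "DfsQueryThenFetch",
    },
    prefilter=True,
    with_row_id=True,
)
print(dfs["_rowid"].to_pylist(), dfs["_score"].to_pylist())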