
Commit 56c70a7

Committed Oct 14, 2024
feat: support to customize tokenizer
Signed-off-by: BubbleCal <bubble-cal@outlook.com>
1 parent fe5fcf3 commit 56c70a7

9 files changed, +372 -63 lines changed


‎Cargo.toml

+18 -18

@@ -21,7 +21,7 @@ exclude = ["python"]
 resolver = "2"
 
 [workspace.package]
-version = "0.18.3"
+version = "0.18.4"
 edition = "2021"
 authors = ["Lance Devs <dev@lancedb.com>"]
 license = "Apache-2.0"
@@ -44,21 +44,21 @@ categories = [
 rust-version = "1.78"
 
 [workspace.dependencies]
-lance = { version = "=0.18.3", path = "./rust/lance" }
-lance-arrow = { version = "=0.18.3", path = "./rust/lance-arrow" }
-lance-core = { version = "=0.18.3", path = "./rust/lance-core" }
-lance-datafusion = { version = "=0.18.3", path = "./rust/lance-datafusion" }
-lance-datagen = { version = "=0.18.3", path = "./rust/lance-datagen" }
-lance-encoding = { version = "=0.18.3", path = "./rust/lance-encoding" }
-lance-encoding-datafusion = { version = "=0.18.3", path = "./rust/lance-encoding-datafusion" }
-lance-file = { version = "=0.18.3", path = "./rust/lance-file" }
-lance-index = { version = "=0.18.3", path = "./rust/lance-index" }
-lance-io = { version = "=0.18.3", path = "./rust/lance-io" }
-lance-jni = { version = "=0.18.3", path = "./java/core/lance-jni" }
-lance-linalg = { version = "=0.18.3", path = "./rust/lance-linalg" }
-lance-table = { version = "=0.18.3", path = "./rust/lance-table" }
-lance-test-macros = { version = "=0.18.3", path = "./rust/lance-test-macros" }
-lance-testing = { version = "=0.18.3", path = "./rust/lance-testing" }
+lance = { version = "=0.18.4", path = "./rust/lance" }
+lance-arrow = { version = "=0.18.4", path = "./rust/lance-arrow" }
+lance-core = { version = "=0.18.4", path = "./rust/lance-core" }
+lance-datafusion = { version = "=0.18.4", path = "./rust/lance-datafusion" }
+lance-datagen = { version = "=0.18.4", path = "./rust/lance-datagen" }
+lance-encoding = { version = "=0.18.4", path = "./rust/lance-encoding" }
+lance-encoding-datafusion = { version = "=0.18.4", path = "./rust/lance-encoding-datafusion" }
+lance-file = { version = "=0.18.4", path = "./rust/lance-file" }
+lance-index = { version = "=0.18.4", path = "./rust/lance-index" }
+lance-io = { version = "=0.18.4", path = "./rust/lance-io" }
+lance-jni = { version = "=0.18.4", path = "./java/core/lance-jni" }
+lance-linalg = { version = "=0.18.4", path = "./rust/lance-linalg" }
+lance-table = { version = "=0.18.4", path = "./rust/lance-table" }
+lance-test-macros = { version = "=0.18.4", path = "./rust/lance-test-macros" }
+lance-testing = { version = "=0.18.4", path = "./rust/lance-testing" }
 approx = "0.5.1"
 # Note that this one does not include pyarrow
 arrow = { version = "52.2", optional = false, features = ["prettyprint"] }
@@ -111,7 +111,7 @@ datafusion-physical-expr = { version = "41.0", features = [
 ] }
 deepsize = "0.2.0"
 either = "1.0"
-fsst = { version = "=0.18.3", path = "./rust/lance-encoding/compression-algo/fsst" }
+fsst = { version = "=0.18.4", path = "./rust/lance-encoding/compression-algo/fsst" }
 futures = "0.3"
 http = "0.2.9"
 hyperloglogplus = { version = "0.4.1", features = ["const-loop"] }
@@ -141,7 +141,7 @@ serde = { version = "^1" }
 serde_json = { version = "1" }
 shellexpand = "3.0"
 snafu = "0.7.5"
-tantivy = "0.22.0"
+tantivy = { version = "0.22.0", features = ["stopwords"] }
 tempfile = "3"
 test-log = { version = "0.2.15" }
 tokio = { version = "1.23", features = [

‎python/Cargo.toml

+1 -1

@@ -1,6 +1,6 @@
 [package]
 name = "pylance"
-version = "0.18.3"
+version = "0.18.4"
 edition = "2021"
 authors = ["Lance Devs <dev@lancedb.com>"]
 rust-version = "1.65"

‎python/python/lance/dataset.py

+21

@@ -1337,6 +1337,27 @@ def create_scalar_index(
             query. This will significantly increase the index size.
             It won't impact the performance of non-phrase queries even if it is set to
             True.
+        base_tokenizer: str, default "simple"
+            This is for the ``INVERTED`` index. The base tokenizer to use. The value
+            can be:
+            * "simple": splits tokens on whitespace and punctuation.
+            * "whitespace": splits tokens on whitespace.
+            * "raw": no tokenization.
+        language: str, default "English"
+            This is for the ``INVERTED`` index. The language for stemming
+            and stop words. This is only used when `stem` or `remove_stop_words` is true
+        max_token_length: Optional[int], default 40
+            This is for the ``INVERTED`` index. The maximum token length.
+            Any token longer than this will be removed.
+        lower_case: bool, default True
+            This is for the ``INVERTED`` index. If True, the index will convert all
+            text to lowercase.
+        stem: bool, default False
+            This is for the ``INVERTED`` index. If True, the index will stem the
+            tokens.
+        remove_stop_words: bool, default False
+            This is for the ``INVERTED`` index. If True, the index will remove
+            stop words.
 
         Examples
         --------
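
These new keyword arguments are forwarded through python/src/dataset.rs (next file) into the Rust `TokenizerConfig` introduced by this commit. As a rough sketch of the equivalent Rust-side configuration (the builder methods come from the new tokenizer module at the end of this diff; the import path and standalone function here are assumed for illustration):

    use lance_index::scalar::inverted::TokenizerConfig;

    fn build_custom_analyzer() -> lance_core::Result<tantivy::tokenizer::TextAnalyzer> {
        // Defaults: base_tokenizer="simple", language=English, max_token_length=Some(40),
        // lower_case=true, stem=false, remove_stop_words=false, ascii_folding=false.
        let config = TokenizerConfig::default()
            .base_tokenizer("whitespace".to_owned()) // split on whitespace only
            .max_token_length(Some(40))              // drop tokens longer than 40 characters
            .stem(true)                              // stem with the configured language
            .remove_stop_words(true)                 // needs the tantivy "stopwords" feature enabled above
            .ascii_folding(true);                    // fold accented characters to ASCII
        config.build()                               // returns a configured tantivy TextAnalyzer
    }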

‎python/src/dataset.rs

+37

@@ -1205,6 +1205,43 @@ impl Dataset {
             if let Some(with_position) = kwargs.get_item("with_position")? {
                 params.with_position = with_position.extract()?;
             }
+            if let Some(base_tokenizer) = kwargs.get_item("base_tokenizer")? {
+                params.tokenizer_config = params
+                    .tokenizer_config
+                    .base_tokenizer(base_tokenizer.extract()?);
+            }
+            if let Some(language) = kwargs.get_item("language")? {
+                let language = language.extract()?;
+                params.tokenizer_config =
+                    params.tokenizer_config.language(language).map_err(|e| {
+                        PyValueError::new_err(format!(
+                            "can't set tokenizer language to {}: {:?}",
+                            language, e
+                        ))
+                    })?;
+            }
+            if let Some(max_token_length) = kwargs.get_item("max_token_length")? {
+                params.tokenizer_config = params
+                    .tokenizer_config
+                    .max_token_length(max_token_length.extract()?);
+            }
+            if let Some(lower_case) = kwargs.get_item("lower_case")? {
+                params.tokenizer_config =
+                    params.tokenizer_config.lower_case(lower_case.extract()?);
+            }
+            if let Some(stem) = kwargs.get_item("stem")? {
+                params.tokenizer_config = params.tokenizer_config.stem(stem.extract()?);
+            }
+            if let Some(remove_stop_words) = kwargs.get_item("remove_stop_words")? {
+                params.tokenizer_config = params
+                    .tokenizer_config
+                    .remove_stop_words(remove_stop_words.extract()?);
+            }
+            if let Some(ascii_folding) = kwargs.get_item("ascii_folding")? {
+                params.tokenizer_config = params
+                    .tokenizer_config
+                    .ascii_folding(ascii_folding.extract()?);
+            }
         }
         Box::new(params)
     }

‎rust/lance-index/src/scalar.rs

+20 -1

@@ -4,6 +4,7 @@
 //! Scalar indices for metadata search & filtering
 
 use std::collections::HashMap;
+use std::fmt::Debug;
 use std::{any::Any, ops::Bound, sync::Arc};
 
 use arrow::buffer::{OffsetBuffer, ScalarBuffer};
@@ -17,6 +18,7 @@ use datafusion_common::{scalar::ScalarValue, Column};
 use datafusion_expr::expr::ScalarFunction;
 use datafusion_expr::Expr;
 use deepsize::DeepSizeOf;
+use inverted::TokenizerConfig;
 use lance_core::utils::mask::RowIdTreeMap;
 use lance_core::{Error, Result};
 use snafu::{location, Location};
@@ -91,19 +93,36 @@ impl IndexParams for ScalarIndexParams {
     }
 }
 
-#[derive(Debug, Clone, DeepSizeOf)]
+#[derive(Clone)]
 pub struct InvertedIndexParams {
     /// If true, store the position of the term in the document
     /// This can significantly increase the size of the index
     /// If false, only store the frequency of the term in the document
     /// Default is true
     pub with_position: bool,
+
+    pub tokenizer_config: TokenizerConfig,
+}
+
+impl Debug for InvertedIndexParams {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("InvertedIndexParams")
+            .field("with_position", &self.with_position)
+            .finish()
+    }
+}
+
+impl DeepSizeOf for InvertedIndexParams {
+    fn deep_size_of_children(&self, _: &mut deepsize::Context) -> usize {
+        0
+    }
 }
 
 impl Default for InvertedIndexParams {
    fn default() -> Self {
        Self {
            with_position: true,
+            tokenizer_config: TokenizerConfig::default(),
        }
    }
 }
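
`InvertedIndexParams` now carries a `tokenizer_config` alongside `with_position`; `Debug` and `DeepSizeOf` are implemented by hand and simply skip the tokenizer config. A minimal sketch of constructing params with a custom tokenizer, mirroring the updated test helper in builder.rs below (import paths assumed):

    use lance_index::scalar::InvertedIndexParams;
    use lance_index::scalar::inverted::TokenizerConfig;

    fn accent_insensitive_params() -> InvertedIndexParams {
        // with_position(false) keeps the index smaller; tokenizer_config is a plain public field.
        let mut params = InvertedIndexParams::default().with_position(false);
        params.tokenizer_config = TokenizerConfig::default().ascii_folding(true);
        params
    }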

‎rust/lance-index/src/scalar/inverted.rs

+2

@@ -3,11 +3,13 @@
 
 mod builder;
 mod index;
+mod tokenizer;
 mod wand;
 
 pub use builder::InvertedIndexBuilder;
 pub use index::*;
 use lance_core::Result;
+pub use tokenizer::*;
 
 use super::btree::TrainingSource;
 use super::{IndexStore, InvertedIndexParams};

‎rust/lance-index/src/scalar/inverted/builder.rs

+92 -26

@@ -10,9 +10,9 @@ use std::sync::Arc;
 use crate::scalar::lance_format::LanceIndexStore;
 use crate::scalar::{IndexReader, IndexStore, IndexWriter, InvertedIndexParams};
 use crate::vector::graph::OrderedFloat;
-use arrow::array::AsArray;
+use arrow::array::{ArrayBuilder, AsArray, Int32Builder, StringBuilder};
 use arrow::datatypes;
-use arrow_array::RecordBatch;
+use arrow_array::{Int32Array, RecordBatch, StringArray};
 use arrow_schema::SchemaRef;
 use crossbeam_queue::ArrayQueue;
 use datafusion::execution::SendableRecordBatchStream;
@@ -131,8 +131,8 @@ impl InvertedIndexBuilder {
             senders.push(sender);
             result_futs.push(tokio::spawn({
                 async move {
-                    while let Some((row_id, tokens)) = receiver.recv().await {
-                        worker.add(row_id, tokens).await?;
+                    while let Some((row_id, tokens, positions)) = receiver.recv().await {
+                        worker.add(row_id, tokens, positions).await?;
                     }
                     let reader = worker.into_reader(inverted_list).await?;
                     Result::Ok(reader)
@@ -143,18 +143,15 @@ impl InvertedIndexBuilder {
         let start = std::time::Instant::now();
         let senders = Arc::new(senders);
         let tokenizer_pool = Arc::new(ArrayQueue::new(num_shards));
-        let token_buffers_pool = Arc::new(ArrayQueue::new(num_shards));
+        let tokenizer = self.params.tokenizer_config.build()?;
         for _ in 0..num_shards {
-            let _ = tokenizer_pool.push(TOKENIZER.clone());
-            token_buffers_pool
-                .push(vec![Vec::new(); num_shards])
-                .unwrap();
+            let _ = tokenizer_pool.push(tokenizer.clone());
         }
         let mut stream = stream
             .map(move |batch| {
                 let senders = senders.clone();
                 let tokenizer_pool = tokenizer_pool.clone();
-                let token_buffers_pool = token_buffers_pool.clone();
+                // let token_buffers_pool = token_buffers_pool.clone();
                 CPU_RUNTIME.spawn_blocking(move || {
                     let batch = batch?;
                     let doc_iter = iter_str_array(batch.column(0));
@@ -164,37 +161,55 @@
                         .filter_map(|(doc, row_id)| doc.map(|doc| (doc, *row_id)));
 
                     let mut tokenizer = tokenizer_pool.pop().unwrap();
-                    let mut token_buffers = token_buffers_pool.pop().unwrap();
 
                     let num_tokens = docs
                         .map(|(doc, row_id)| {
                             // tokenize the document
+                            let predicted_num_tokens = doc.len() / 5 / num_shards;
+                            let mut token_buffers = std::iter::repeat_with(|| {
+                                (
+                                    StringBuilder::with_capacity(
+                                        predicted_num_tokens,
+                                        doc.len() / num_shards,
+                                    ),
+                                    Int32Builder::with_capacity(predicted_num_tokens),
+                                )
+                            })
+                            .take(num_shards)
+                            .collect_vec();
                             let mut num_tokens = 0;
                             let mut token_stream = tokenizer.token_stream(doc);
                             while token_stream.advance() {
                                 let token = token_stream.token_mut();
                                 let mut hasher = DefaultHasher::new();
                                 hasher.write(token.text.as_bytes());
                                 let shard = hasher.finish() as usize % num_shards;
-                                token_buffers[shard]
-                                    .push((std::mem::take(&mut token.text), token.position as i32));
+                                let (ref mut token_builder, ref mut position_builder) =
+                                    &mut token_buffers[shard];
+                                token_builder.append_value(&token.text);
+                                position_builder.append_value(token.position as i32);
                                 num_tokens += 1;
                             }
 
-                            for (shard, buffer) in token_buffers.iter_mut().enumerate() {
-                                if buffer.is_empty() {
+                            for (shard, (token_builder, position_builder)) in
+                                token_buffers.iter_mut().enumerate()
+                            {
+                                if token_builder.is_empty() {
                                     continue;
                                 }
-                                let buffer = std::mem::take(buffer);
-                                senders[shard].blocking_send((row_id, buffer)).unwrap();
+
+                                let tokens = token_builder.finish();
+                                let positions = position_builder.finish();
+                                senders[shard]
+                                    .blocking_send((row_id, tokens, positions))
+                                    .unwrap();
                             }
 
                             (row_id, num_tokens)
                         })
                        .collect_vec();
 
                    let _ = tokenizer_pool.push(tokenizer);
-                    token_buffers_pool.push(token_buffers).unwrap();
                    Result::Ok(num_tokens)
                })
            })
@@ -355,7 +370,10 @@
         let batch = tokens.to_batch()?;
         let mut writer = store.new_index_file(TOKENS_FILE, batch.schema()).await?;
         writer.write_record_batch(batch).await?;
-        writer.finish().await?;
+
+        let tokenizer = serde_json::to_string(&self.params.tokenizer_config)?;
+        let metadata = HashMap::from_iter(vec![("tokenizer".to_owned(), tokenizer)]);
+        writer.finish_with_metadata(metadata).await?;
 
         log::info!("finished writing tokens");
         Ok(())
@@ -421,21 +439,26 @@ impl IndexWorker {
         self.schema.column_with_name(POSITION_COL).is_some()
     }
 
-    async fn add(&mut self, row_id: u64, tokens: Vec<(String, i32)>) -> Result<()> {
+    async fn add(&mut self, row_id: u64, tokens: StringArray, positions: Int32Array) -> Result<()> {
         let mut token_occurrences = HashMap::new();
-        for (token, position) in tokens {
+        for (token, position) in tokens.iter().zip(positions.values().into_iter()) {
+            let token = if let Some(token) = token {
+                token
+            } else {
+                continue;
+            };
             token_occurrences
                 .entry(token)
                 .or_insert_with(Vec::new)
-                .push(position);
+                .push(*position);
         }
         let with_position = self.has_position();
         token_occurrences
             .into_iter()
             .for_each(|(token, term_positions)| {
                 let posting_list = self
                     .posting_lists
-                    .entry(token.clone())
+                    .entry(token.to_owned())
                     .or_insert_with(|| PostingListBuilder::empty(with_position));
 
                 let old_size = if posting_list.is_empty() {
@@ -702,20 +725,23 @@ mod tests {
     use lance_io::object_store::ObjectStore;
     use object_store::path::Path;
 
+    use crate::scalar::inverted::TokenizerConfig;
     use crate::scalar::lance_format::LanceIndexStore;
     use crate::scalar::{FullTextSearchQuery, SargableQuery, ScalarIndex};
 
     use super::InvertedIndex;
 
     async fn create_index<Offset: arrow::array::OffsetSizeTrait>(
         with_position: bool,
+        tokenizer: TokenizerConfig,
     ) -> Arc<InvertedIndex> {
         let tempdir = tempfile::tempdir().unwrap();
         let index_dir = Path::from_filesystem_path(tempdir.path()).unwrap();
         let cache = FileMetadataCache::with_capacity(128 * 1024 * 1024, CapacityMode::Bytes);
         let store = LanceIndexStore::new(ObjectStore::local(), index_dir, cache);
 
-        let params = super::InvertedIndexParams::default().with_position(with_position);
+        let mut params = super::InvertedIndexParams::default().with_position(with_position);
+        params.tokenizer_config = tokenizer;
         let mut invert_index = super::InvertedIndexBuilder::new(params);
         let doc_col = GenericStringArray::<Offset>::from(vec![
             "lance database the search",
@@ -724,6 +750,7 @@
             "database search",
             "unrelated doc",
             "unrelated",
+            "mots accentués",
         ]);
         let row_id_col = UInt64Array::from(Vec::from_iter(0..doc_col.len() as u64));
         let batch = RecordBatch::try_new(
@@ -750,7 +777,7 @@
     }
 
     async fn test_inverted_index<Offset: arrow::array::OffsetSizeTrait>() {
-        let invert_index = create_index::<Offset>(false).await;
+        let invert_index = create_index::<Offset>(false, TokenizerConfig::default()).await;
         let row_ids = invert_index
             .search(&SargableQuery::FullTextSearch(
                 FullTextSearchQuery::new("lance".to_owned()).limit(Some(3)),
@@ -800,7 +827,7 @@
         assert!(results.unwrap_err().to_string().contains("position is not found but required for phrase queries, try recreating the index with position"));
 
         // recreate the index with position
-        let invert_index = create_index::<Offset>(true).await;
+        let invert_index = create_index::<Offset>(true, TokenizerConfig::default()).await;
         let row_ids = invert_index
             .search(&SargableQuery::FullTextSearch(
                 FullTextSearchQuery::new("lance database".to_owned()).limit(Some(10)),
@@ -857,4 +884,43 @@
     async fn test_inverted_index_with_large_string() {
         test_inverted_index::<i64>().await;
     }
+
+    #[tokio::test]
+    async fn test_accented_chars() {
+        let invert_index = create_index::<i32>(false, TokenizerConfig::default()).await;
+        let row_ids = invert_index
+            .search(&SargableQuery::FullTextSearch(
+                FullTextSearchQuery::new("accentués".to_owned()).limit(Some(3)),
+            ))
+            .await
+            .unwrap();
+        assert_eq!(row_ids.len(), Some(1));
+
+        let row_ids = invert_index
+            .search(&SargableQuery::FullTextSearch(
+                FullTextSearchQuery::new("accentues".to_owned()).limit(Some(3)),
+            ))
+            .await
+            .unwrap();
+        assert_eq!(row_ids.len(), Some(0));
+
+        // with ascii folding enabled, the search should be accent-insensitive
+        let invert_index =
+            create_index::<i32>(true, TokenizerConfig::default().ascii_folding(true)).await;
+        let row_ids = invert_index
+            .search(&SargableQuery::FullTextSearch(
+                FullTextSearchQuery::new("accentués".to_owned()).limit(Some(3)),
+            ))
+            .await
+            .unwrap();
+        assert_eq!(row_ids.len(), Some(1));
+
+        let row_ids = invert_index
+            .search(&SargableQuery::FullTextSearch(
+                FullTextSearchQuery::new("accentues".to_owned()).limit(Some(3)),
+            ))
+            .await
+            .unwrap();
+        assert_eq!(row_ids.len(), Some(1));
+    }
 }
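
Note how the tokenizer configuration is persisted: it is serialized to JSON and written into the tokens file's schema metadata under the "tokenizer" key (`finish_with_metadata` above), and index.rs below reads it back and rebuilds the analyzer, falling back to the default config for indices written before this change. A small sketch of that round trip (the serde derives and `build()` are from this commit; the snippet itself is illustrative):

    use lance_index::scalar::inverted::TokenizerConfig;

    fn tokenizer_config_roundtrip() -> lance_core::Result<()> {
        // Write side (builder.rs): store the config as JSON in the tokens file metadata.
        let config = TokenizerConfig::default().stem(true);
        let json = serde_json::to_string(&config)?;

        // Read side (index.rs): rebuild the analyzer from the stored JSON; an index
        // without the "tokenizer" key falls back to TokenizerConfig::default().
        let restored: TokenizerConfig = serde_json::from_str(&json)?;
        let _analyzer = restored.build()?;
        Ok(())
    }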

‎rust/lance-index/src/scalar/inverted/index.rs

+33 -17

@@ -2,6 +2,7 @@
 // SPDX-FileCopyrightText: Copyright The Lance Authors
 
 use std::collections::{HashMap, HashSet};
+use std::fmt::Debug;
 use std::sync::Arc;
 
 use arrow::array::{
@@ -27,11 +28,10 @@ use lazy_static::lazy_static;
 use moka::future::Cache;
 use roaring::RoaringBitmap;
 use snafu::{location, Location};
-use tantivy::tokenizer::Language;
 use tracing::instrument;
 
 use super::builder::inverted_list_schema;
-use super::{wand::*, InvertedIndexBuilder};
+use super::{wand::*, InvertedIndexBuilder, TokenizerConfig};
 use crate::prefilter::{NoFilter, PreFilter};
 use crate::scalar::{
     AnyQuery, FullTextSearchQuery, IndexReader, IndexStore, SargableQuery, ScalarIndex,
@@ -57,26 +57,30 @@ pub const K1: f32 = 1.2;
 pub const B: f32 = 0.75;
 
 lazy_static! {
-    pub static ref TOKENIZER: tantivy::tokenizer::TextAnalyzer = {
-        tantivy::tokenizer::TextAnalyzer::builder(tantivy::tokenizer::SimpleTokenizer::default())
-            .filter(tantivy::tokenizer::RemoveLongFilter::limit(40))
-            .filter(tantivy::tokenizer::LowerCaser)
-            .filter(tantivy::tokenizer::Stemmer::new(Language::English))
-            .build()
-    };
     static ref CACHE_SIZE: usize = std::env::var("LANCE_INVERTED_CACHE_SIZE")
         .ok()
         .and_then(|s| s.parse().ok())
         .unwrap_or(512 * 1024 * 1024);
 }
 
-#[derive(Debug, Clone)]
+#[derive(Clone)]
 pub struct InvertedIndex {
+    tokenizer: tantivy::tokenizer::TextAnalyzer,
     tokens: TokenSet,
     inverted_list: Arc<InvertedListReader>,
     docs: DocSet,
 }
 
+impl Debug for InvertedIndex {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("InvertedIndex")
+            .field("tokens", &self.tokens)
+            .field("inverted_list", &self.inverted_list)
+            .field("docs", &self.docs)
+            .finish()
+    }
+}
+
 impl DeepSizeOf for InvertedIndex {
     fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
         self.tokens.deep_size_of_children(context)
@@ -102,7 +106,8 @@ impl InvertedIndex {
         query: &FullTextSearchQuery,
         prefilter: Arc<dyn PreFilter>,
     ) -> Result<Vec<(u64, f32)>> {
-        let tokens = collect_tokens(&query.query);
+        let mut tokenizer = self.tokenizer.clone();
+        let tokens = collect_tokens(&query.query, &mut tokenizer);
         let token_ids = self.map(&tokens).into_iter();
         let token_ids = if !is_phrase_query(&query.query) {
             token_ids.sorted_unstable().dedup().collect()
@@ -239,8 +244,16 @@ impl ScalarIndex for InvertedIndex {
             let store = store.clone();
             async move {
                 let token_reader = store.open_index_file(TOKENS_FILE).await?;
+                let tokenizer = token_reader
+                    .schema()
+                    .metadata
+                    .get("tokenizer")
+                    .map(|s| serde_json::from_str::<TokenizerConfig>(s))
+                    .transpose()?
+                    .unwrap_or_default()
+                    .build()?;
                 let tokens = TokenSet::load(token_reader).await?;
-                Result::Ok(tokens)
+                Result::Ok((tokenizer, tokens))
             }
         });
         let invert_list_fut = tokio::spawn({
@@ -260,11 +273,12 @@
             }
         });
 
-        let tokens = tokens_fut.await??;
+        let (tokenizer, tokens) = tokens_fut.await??;
         let inverted_list = invert_list_fut.await??;
        let docs = docs_fut.await??;
 
         Ok(Arc::new(Self {
+            tokenizer,
             tokens,
             inverted_list,
             docs,
@@ -959,13 +973,16 @@ fn do_flat_full_text_search<Offset: OffsetSizeTrait>(
     query: &str,
 ) -> Result<Vec<u64>> {
     let mut results = Vec::new();
-    let query_tokens = collect_tokens(query).into_iter().collect::<HashSet<_>>();
+    let mut tokenizer = TokenizerConfig::default().build()?;
+    let query_tokens = collect_tokens(query, &mut tokenizer)
+        .into_iter()
+        .collect::<HashSet<_>>();
     for batch in batches {
         let row_id_array = batch[ROW_ID].as_primitive::<UInt64Type>();
         let doc_array = batch[doc_col].as_string::<Offset>();
         for i in 0..row_id_array.len() {
             let doc = doc_array.value(i);
-            let doc_tokens = collect_tokens(doc);
+            let doc_tokens = collect_tokens(doc, &mut tokenizer);
             if doc_tokens.iter().any(|token| query_tokens.contains(token)) {
                 results.push(row_id_array.value(i));
                 assert!(doc.contains(query));
@@ -976,8 +993,7 @@
     Ok(results)
 }
 
-pub fn collect_tokens(text: &str) -> Vec<String> {
-    let mut tokenizer = TOKENIZER.clone();
+pub fn collect_tokens(text: &str, tokenizer: &mut tantivy::tokenizer::TextAnalyzer) -> Vec<String> {
     let mut stream = tokenizer.token_stream(text);
     let mut tokens = Vec::new();
     while let Some(token) = stream.next() {
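
`collect_tokens` no longer clones a global `TOKENIZER`; callers pass the analyzer explicitly, so queries are tokenized with the same configuration the index was built with. A sketch of calling it with a freshly built default analyzer (re-exported paths assumed):

    use lance_index::scalar::inverted::{collect_tokens, TokenizerConfig};

    fn tokenize_query(query: &str) -> lance_core::Result<Vec<String>> {
        // Same pattern as do_flat_full_text_search above: build once, reuse for every call.
        let mut tokenizer = TokenizerConfig::default().build()?;
        Ok(collect_tokens(query, &mut tokenizer))
    }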

rust/lance-index/src/scalar/inverted/tokenizer.rs

+148

@@ -0,0 +1,148 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+use lance_core::{Error, Result};
+use serde::{Deserialize, Serialize};
+use snafu::{location, Location};
+
+/// Tokenizer configs
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct TokenizerConfig {
+    /// base tokenizer:
+    /// - `simple`: splits tokens on whitespace and punctuation
+    /// - `whitespace`: splits tokens on whitespace
+    /// - `raw`: no tokenization
+    /// `simple` is recommended for most cases and the default value
+    base_tokenizer: String,
+
+    /// language for stemming and stop words
+    /// this is only used when `stem` or `remove_stop_words` is true
+    language: tantivy::tokenizer::Language,
+
+    /// maximum token length
+    /// - `None`: no limit
+    /// - `Some(n)`: remove tokens longer than `n`
+    max_token_length: Option<usize>,
+
+    /// whether lower case tokens
+    lower_case: bool,
+
+    /// whether apply stemming
+    stem: bool,
+
+    /// whether remove stop words
+    remove_stop_words: bool,
+
+    /// ascii folding
+    ascii_folding: bool,
+}
+
+impl Default for TokenizerConfig {
+    fn default() -> Self {
+        Self::new("simple".to_owned(), tantivy::tokenizer::Language::English)
+    }
+}
+
+impl TokenizerConfig {
+    pub fn new(base_tokenizer: String, language: tantivy::tokenizer::Language) -> Self {
+        TokenizerConfig {
+            base_tokenizer,
+            language,
+            max_token_length: Some(40),
+            lower_case: true,
+            stem: false,
+            remove_stop_words: false,
+            ascii_folding: false,
+        }
+    }
+
+    pub fn base_tokenizer(mut self, base_tokenizer: String) -> Self {
+        self.base_tokenizer = base_tokenizer;
+        self
+    }
+
+    pub fn language(mut self, language: &str) -> Result<Self> {
+        // need to convert to valid JSON string
+        let language = serde_json::from_str(format!("\"{}\"", language).as_str())?;
+        self.language = language;
+        Ok(self)
+    }
+
+    pub fn max_token_length(mut self, max_token_length: Option<usize>) -> Self {
+        self.max_token_length = max_token_length;
+        self
+    }
+
+    pub fn lower_case(mut self, lower_case: bool) -> Self {
+        self.lower_case = lower_case;
+        self
+    }
+
+    pub fn stem(mut self, stem: bool) -> Self {
+        self.stem = stem;
+        self
+    }
+
+    pub fn remove_stop_words(mut self, remove_stop_words: bool) -> Self {
+        self.remove_stop_words = remove_stop_words;
+        self
+    }
+
+    pub fn ascii_folding(mut self, ascii_folding: bool) -> Self {
+        self.ascii_folding = ascii_folding;
+        self
+    }
+
+    pub fn build(&self) -> Result<tantivy::tokenizer::TextAnalyzer> {
+        let mut builder = build_base_tokenizer_builder(&self.base_tokenizer)?;
+        if let Some(max_token_length) = self.max_token_length {
+            builder = builder.filter_dynamic(tantivy::tokenizer::RemoveLongFilter::limit(
+                max_token_length,
+            ));
+        }
+        if self.lower_case {
+            builder = builder.filter_dynamic(tantivy::tokenizer::LowerCaser);
+        }
+        if self.stem {
+            builder = builder.filter_dynamic(tantivy::tokenizer::Stemmer::new(self.language));
+        }
+        if self.remove_stop_words {
+            let stop_word_filter = tantivy::tokenizer::StopWordFilter::new(self.language)
+                .ok_or_else(|| {
+                    Error::invalid_input(
+                        format!(
+                            "removing stop words for language {:?} is not supported yet",
+                            self.language
+                        ),
+                        location!(),
+                    )
+                })?;
+            builder = builder.filter_dynamic(stop_word_filter);
+        }
+        if self.ascii_folding {
+            builder = builder.filter_dynamic(tantivy::tokenizer::AsciiFoldingFilter);
+        }
+        Ok(builder.build())
+    }
+}
+
+fn build_base_tokenizer_builder(name: &str) -> Result<tantivy::tokenizer::TextAnalyzerBuilder> {
+    match name {
+        "simple" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
+            tantivy::tokenizer::SimpleTokenizer::default(),
+        )
+        .dynamic()),
+        "whitespace" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
+            tantivy::tokenizer::WhitespaceTokenizer::default(),
+        )
+        .dynamic()),
+        "raw" => Ok(tantivy::tokenizer::TextAnalyzer::builder(
+            tantivy::tokenizer::RawTokenizer::default(),
+        )
+        .dynamic()),
+        _ => Err(Error::invalid_input(
+            format!("unknown base tokenizer {}", name),
+            location!(),
+        )),
+    }
+}
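
The filters compose in a fixed order when `build()` runs: remove-long, lowercase, stemmer, stop words, then ASCII folding, on top of one of the three base tokenizers; an unknown base tokenizer is rejected at build time. A hypothetical end-to-end sketch using tantivy's token stream API (not code from this commit):

    use lance_index::scalar::inverted::TokenizerConfig;
    use tantivy::tokenizer::TokenStream;

    fn demo_tokenize() -> lance_core::Result<Vec<String>> {
        // Unknown base tokenizers fail when the analyzer is built, not when the config is created.
        assert!(TokenizerConfig::default().base_tokenizer("ngram".to_owned()).build().is_err());

        // With the default "simple" tokenizer plus ascii_folding, accented and unaccented
        // spellings should produce the same tokens (e.g. "accentués" -> "accentues").
        let mut analyzer = TokenizerConfig::default().ascii_folding(true).build()?;
        let mut stream = analyzer.token_stream("Mots accentués, déjà vu");
        let mut tokens = Vec::new();
        while stream.advance() {
            tokens.push(stream.token().text.clone());
        }
        Ok(tokens) // expected: ["mots", "accentues", "deja", "vu"]
    }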
