From 780b9eb654936974ff1df3259bd3568bf237a021 Mon Sep 17 00:00:00 2001 From: Wei Zhang Date: Sat, 5 Oct 2024 12:41:50 +0800 Subject: [PATCH] feat(core): support configuring stop words in model config (#3209) * feat: add stop words to model config * chore: use the name stop_words_from_config * [autofix.ci] apply automated fixes * chore: fix snap test due to source file change * chore: use empty text in additional stop words * chore: remove dup with_stop_words * chore: fix unit test for code splitter --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- crates/tabby-common/src/config.rs | 8 +++++ crates/tabby-index/src/code/index.rs | 2 +- ...__code__index__tests__code_splitter-2.snap | 16 +++++----- ...ex__code__index__tests__code_splitter.snap | 15 ++++++---- crates/tabby-inference/src/code.rs | 13 +++++++-- crates/tabby-inference/src/decoding.rs | 29 ++++++++++++++++++- crates/tabby/src/services/completion.rs | 2 +- crates/tabby/src/services/model/mod.rs | 5 ++-- 8 files changed, 69 insertions(+), 21 deletions(-) diff --git a/crates/tabby-common/src/config.rs b/crates/tabby-common/src/config.rs index 32d5315ae551..a6bf9e993abb 100644 --- a/crates/tabby-common/src/config.rs +++ b/crates/tabby-common/src/config.rs @@ -185,6 +185,7 @@ fn default_embedding_config() -> ModelConfig { num_gpu_layers: 9999, enable_fast_attention: None, context_size: default_context_size(), + additional_stop_words: None, }) } @@ -221,6 +222,7 @@ impl ModelConfig { num_gpu_layers, enable_fast_attention: None, context_size: default_context_size(), + additional_stop_words: None, }) } } @@ -256,6 +258,9 @@ pub struct HttpModelConfig { /// Used by Chat/Completion API allowing users to get supported models info. 
#[builder(default)] pub supported_models: Option>, + + #[builder(default)] + pub additional_stop_words: Option>, } #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] @@ -273,6 +278,9 @@ pub struct LocalModelConfig { #[serde(default = "default_context_size")] pub context_size: usize, + + #[serde(default)] + pub additional_stop_words: Option>, } fn default_parallelism() -> u8 { diff --git a/crates/tabby-index/src/code/index.rs b/crates/tabby-index/src/code/index.rs index c47a97f55159..14f8b82b1f62 100644 --- a/crates/tabby-index/src/code/index.rs +++ b/crates/tabby-index/src/code/index.rs @@ -150,7 +150,7 @@ mod tests { #[tokio::test] async fn test_code_splitter() { - // First file, chat/openai_chat.rs + // First file, tabby-inference/src/decoding.rs let file_contents = include_str!("../../../tabby-inference/src/decoding.rs"); let rust_chunks = CodeIntelligence::chunks(file_contents, "rust") diff --git a/crates/tabby-index/src/code/snapshots/tabby_index__code__index__tests__code_splitter-2.snap b/crates/tabby-index/src/code/snapshots/tabby_index__code__index__tests__code_splitter-2.snap index 8e81ee46c51f..1260d07f2867 100644 --- a/crates/tabby-index/src/code/snapshots/tabby_index__code__index__tests__code_splitter-2.snap +++ b/crates/tabby-index/src/code/snapshots/tabby_index__code__index__tests__code_splitter-2.snap @@ -3,14 +3,16 @@ source: crates/tabby-index/src/code/index.rs expression: "format!(\"{:#?}\", text_chunks)" --- [ - "use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap>,\n}\n\nfn reverse(s: T) -> String\nwhere\n T: Into,\n{\n s.into().chars().rev().collect()\n}\n\nimpl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie>;", - "impl StopConditionFactory {\n pub fn create(&self, text: &str, language: 
Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }", - "fn get_trie<'a>(&'a self, language: &'static Language) -> Option> {\n let stop_words = language.get_stop_words();\n if stop_words.is_empty() {\n None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));", - "trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}\n\nfn create_stop_trie(stop_words: Vec) -> Trie {\n let mut builder = TrieBuilder::new();\n for word in stop_words {\n builder.push(reverse(word))\n }\n builder.build()\n}\n\npub struct StopCondition<'a> {\n stop_trie: Option>,\n reversed_text: String,\n num_decoded: usize,\n}", + "use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap>,\n stop_words_from_model_config: Vec,\n}\n\nfn reverse(s: T) -> String\nwhere\n T: Into,\n{\n s.into().chars().rev().collect()\n}", + "impl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: vec![],\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie>;\n\nimpl StopConditionFactory {\n pub fn with_stop_words(stop_words: Vec) -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: stop_words,\n }\n }", + "pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }", + "fn get_trie<'a>(&'a self, language: &'static Language) -> Option> {\n let mut stop_words = language.get_stop_words();\n // append model stop 
words\n stop_words.extend(self.stop_words_from_model_config.iter().cloned());", + "if stop_words.is_empty() {\n None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));\n trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}", + "fn create_stop_trie(stop_words: Vec) -> Trie {\n let mut builder = TrieBuilder::new();\n for word in stop_words {\n builder.push(reverse(word))\n }\n builder.build()\n}\n\npub struct StopCondition<'a> {\n stop_trie: Option>,\n reversed_text: String,\n num_decoded: usize,\n}", "impl<'a> StopCondition<'a> {\n pub fn new(stop_trie: Option>, text: &str) -> Self {\n Self {\n stop_trie,\n reversed_text: reverse(text),\n num_decoded: 0,\n }\n }\n\n pub fn should_stop(&mut self, new_text: &str) -> (bool, usize) {\n self.num_decoded += 1;\n if !new_text.is_empty() {\n self.reversed_text = reverse(new_text) + &self.reversed_text;", - "if let Some(re) = &self.stop_trie {\n let matches = re.common_prefix_search(&self.reversed_text);\n let matched_length = matches.into_iter().map(|x| x.len()).max();\n if let Some(matched_length) = matched_length {\n return (true, matched_length);\n }\n }\n }\n (false, 0)\n }\n}\n\n#[cfg(test)]\nmod tests {\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;", + "if let Some(re) = &self.stop_trie {\n let matches = re.common_prefix_search(&self.reversed_text);\n let matched_length = matches.into_iter().map(|x| x.len()).max();\n if let Some(matched_length) = matched_length {\n return (true, matched_length);\n }\n }\n }\n (false, 0)\n }\n}\n\n#[cfg(test)]\nmod tests {\n\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;", "#[test]\n fn test_trie_works() {\n let text = reverse(\"void write_u32(std::uint32_t val) const {\\n write_raw(&val, sizeof(val));\\n }\\n\\n ~llama_file() {\\n if (fp) {\\n std::fclose(fp);\\n }\\n 
}\\n};\\n\\nvoid\");\n\n let trie = create_stop_trie(vec![\"\\n\\n\".to_owned(), \"\\n\\n \".to_owned()]);\n assert!(trie.common_prefix_search(&text).is_empty());", - "let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n }", + "let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n \"<|file_sep|>\".to_owned(), // qwen 2.5 coder style\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n\n let qwen25coder = reverse(\"qwen25 style stop words;<|file_sep|>\");\n assert!(!trie.common_prefix_search(&qwen25coder).is_empty());\n }", "#[test]\n fn test_stop_condition_max_length() {\n let factory = StopConditionFactory::default();\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"2\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"3\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"4\");\n assert!(!should_stop)", - "}\n}", + "}\n\n #[test]\n fn test_stop_condition_additional_stop_words() {\n let factory = StopConditionFactory::with_stop_words(vec![\"<|endoftext|>\".to_owned()]);\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"<|endoftext|>\");\n assert!(should_stop);\n }\n}", ] diff --git a/crates/tabby-index/src/code/snapshots/tabby_index__code__index__tests__code_splitter.snap b/crates/tabby-index/src/code/snapshots/tabby_index__code__index__tests__code_splitter.snap index 53449d28470c..7ddf6ec20d24 100644 --- a/crates/tabby-index/src/code/snapshots/tabby_index__code__index__tests__code_splitter.snap +++ 
b/crates/tabby-index/src/code/snapshots/tabby_index__code__index__tests__code_splitter.snap @@ -3,11 +3,13 @@ source: crates/tabby-index/src/code/index.rs expression: "format!(\"{:#?}\", rust_chunks)" --- [ - "use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap>,\n}\n\nfn reverse(s: T) -> String\nwhere\n T: Into,\n{\n s.into().chars().rev().collect()\n}\n\nimpl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie>;", + "use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap>,\n stop_words_from_model_config: Vec,\n}\n\nfn reverse(s: T) -> String\nwhere\n T: Into,\n{\n s.into().chars().rev().collect()\n}", + "impl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: vec![],\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie>;", "impl StopConditionFactory", - "{\n pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }", + "{\n pub fn with_stop_words(stop_words: Vec) -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: stop_words,\n }\n }\n\n pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }", "fn get_trie<'a>(&'a self, language: &'static Language) -> Option>", - "{\n let stop_words = language.get_stop_words();\n if stop_words.is_empty() {\n 
None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));\n trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}", + "{\n let mut stop_words = language.get_stop_words();\n // append model stop words\n stop_words.extend(self.stop_words_from_model_config.iter().cloned());", + "if stop_words.is_empty() {\n None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));\n trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}", "fn create_stop_trie(stop_words: Vec) -> Trie {\n let mut builder = TrieBuilder::new();\n for word in stop_words {\n builder.push(reverse(word))\n }\n builder.build()\n}\n\npub struct StopCondition<'a> {\n stop_trie: Option>,\n reversed_text: String,\n num_decoded: usize,\n}", "impl<'a> StopCondition<'a>", "{\n pub fn new(stop_trie: Option>, text: &str) -> Self {\n Self {\n stop_trie,\n reversed_text: reverse(text),\n num_decoded: 0,\n }\n }", @@ -15,9 +17,10 @@ expression: "format!(\"{:#?}\", rust_chunks)" "{\n self.num_decoded += 1;\n if !new_text.is_empty() {\n self.reversed_text = reverse(new_text) + &self.reversed_text;\n\n if let Some(re) = &self.stop_trie {\n let matches = re.common_prefix_search(&self.reversed_text);\n let matched_length = matches.into_iter().map(|x| x.len()).max();\n if let Some(matched_length) = matched_length {\n return (true, matched_length);\n }\n }\n }", "(false, 0)\n }\n}\n\n#[cfg(test)]", "mod tests", - "{\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;\n\n #[test]", + "{\n\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;\n\n #[test]", "fn test_trie_works()", "{\n let text = reverse(\"void write_u32(std::uint32_t val) const {\\n 
write_raw(&val, sizeof(val));\\n }\\n\\n ~llama_file() {\\n if (fp) {\\n std::fclose(fp);\\n }\\n }\\n};\\n\\nvoid\");\n\n let trie = create_stop_trie(vec![\"\\n\\n\".to_owned(), \"\\n\\n \".to_owned()]);\n assert!(trie.common_prefix_search(&text).is_empty());", - "let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n }\n\n #[test]", - "fn test_stop_condition_max_length() {\n let factory = StopConditionFactory::default();\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"2\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"3\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"4\");\n assert!(!should_stop)\n }\n}", + "let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n \"<|file_sep|>\".to_owned(), // qwen 2.5 coder style\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n\n let qwen25coder = reverse(\"qwen25 style stop words;<|file_sep|>\");\n assert!(!trie.common_prefix_search(&qwen25coder).is_empty());\n }\n\n #[test]", + "fn test_stop_condition_max_length() {\n let factory = StopConditionFactory::default();\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"2\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"3\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"4\");\n assert!(!should_stop)\n }", + "#[test]\n fn test_stop_condition_additional_stop_words() {\n let factory = StopConditionFactory::with_stop_words(vec![\"<|endoftext|>\".to_owned()]);\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let 
(should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"<|endoftext|>\");\n assert!(should_stop);\n }\n}", ] diff --git a/crates/tabby-inference/src/code.rs b/crates/tabby-inference/src/code.rs index 76e14c4e1c14..2d420053bf42 100644 --- a/crates/tabby-inference/src/code.rs +++ b/crates/tabby-inference/src/code.rs @@ -3,7 +3,7 @@ use std::sync::Arc; use async_stream::stream; use derive_builder::Builder; use futures::StreamExt; -use tabby_common::languages::Language; +use tabby_common::{config::ModelConfig, languages::Language}; use crate::{decoding::StopConditionFactory, CompletionOptionsBuilder, CompletionStream}; @@ -31,10 +31,17 @@ pub struct CodeGeneration { } impl CodeGeneration { - pub fn new(imp: Arc) -> Self { + pub fn new(imp: Arc, config: Option) -> Self { + let additional_stop_words = match config { + Some(ModelConfig::Local(config)) => config.additional_stop_words.unwrap_or_default(), + Some(ModelConfig::Http(config)) => config.additional_stop_words.unwrap_or_default(), + _ => vec![], + }; + let stop_condition_factory = StopConditionFactory::with_stop_words(additional_stop_words); + Self { imp, - stop_condition_factory: StopConditionFactory::default(), + stop_condition_factory, } } } diff --git a/crates/tabby-inference/src/decoding.rs b/crates/tabby-inference/src/decoding.rs index 91087f4d0f05..de8961393230 100644 --- a/crates/tabby-inference/src/decoding.rs +++ b/crates/tabby-inference/src/decoding.rs @@ -4,6 +4,7 @@ use trie_rs::{Trie, TrieBuilder}; pub struct StopConditionFactory { stop_trie_cache: DashMap>, + stop_words_from_model_config: Vec, } fn reverse(s: T) -> String @@ -17,6 +18,7 @@ impl Default for StopConditionFactory { fn default() -> Self { Self { stop_trie_cache: DashMap::new(), + stop_words_from_model_config: vec![], } } } @@ -24,6 +26,13 @@ impl Default for StopConditionFactory { type CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie>; impl StopConditionFactory { + pub 
fn with_stop_words(stop_words: Vec) -> Self { + Self { + stop_trie_cache: DashMap::new(), + stop_words_from_model_config: stop_words, + } + } + pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition { if let Some(language) = language { StopCondition::new(self.get_trie(language), text) @@ -33,7 +42,10 @@ impl StopConditionFactory { } fn get_trie<'a>(&'a self, language: &'static Language) -> Option> { - let stop_words = language.get_stop_words(); + let mut stop_words = language.get_stop_words(); + // append model stop words + stop_words.extend(self.stop_words_from_model_config.iter().cloned()); + if stop_words.is_empty() { None } else { @@ -92,6 +104,7 @@ impl<'a> StopCondition<'a> { #[cfg(test)] mod tests { + use tabby_common::languages::UNKNOWN_LANGUAGE; use super::*; @@ -107,8 +120,12 @@ mod tests { "\n\n".to_owned(), "\n\n ".to_owned(), "\nvoid".to_owned(), + "<|file_sep|>".to_owned(), // qwen 2.5 coder style ]); assert!(!trie.common_prefix_search(&text).is_empty()); + + let qwen25coder = reverse("qwen25 style stop words;<|file_sep|>"); + assert!(!trie.common_prefix_search(&qwen25coder).is_empty()); } #[test] @@ -124,4 +141,14 @@ mod tests { let (should_stop, _) = cond.should_stop("4"); assert!(!should_stop) } + + #[test] + fn test_stop_condition_additional_stop_words() { + let factory = StopConditionFactory::with_stop_words(vec!["<|endoftext|>".to_owned()]); + let mut cond = factory.create("", Some(&UNKNOWN_LANGUAGE)); + let (should_stop, _) = cond.should_stop("1"); + assert!(!should_stop); + let (should_stop, _) = cond.should_stop("<|endoftext|>"); + assert!(should_stop); + } } diff --git a/crates/tabby/src/services/completion.rs b/crates/tabby/src/services/completion.rs index b489dd159783..3344b0e29f3c 100644 --- a/crates/tabby/src/services/completion.rs +++ b/crates/tabby/src/services/completion.rs @@ -436,7 +436,7 @@ mod tests { } fn mock_completion_service() -> CompletionService { - let generation = 
CodeGeneration::new(Arc::new(MockCompletionStream)); + let generation = CodeGeneration::new(Arc::new(MockCompletionStream), None); CompletionService::new( CompletionConfig::default(), Arc::new(generation), diff --git a/crates/tabby/src/services/model/mod.rs b/crates/tabby/src/services/model/mod.rs index d72aba0d7ecb..0948ae084766 100644 --- a/crates/tabby/src/services/model/mod.rs +++ b/crates/tabby/src/services/model/mod.rs @@ -18,8 +18,9 @@ pub async fn load_code_generation_and_chat( Option, Option>, ) { - let (engine, prompt_info, chat) = load_completion_and_chat(completion_model, chat_model).await; - let code = engine.map(|engine| Arc::new(CodeGeneration::new(engine))); + let (engine, prompt_info, chat) = + load_completion_and_chat(completion_model.clone(), chat_model).await; + let code = engine.map(|engine| Arc::new(CodeGeneration::new(engine, completion_model))); (code, prompt_info, chat) }