Skip to content

Commit

Permalink
feat(core): support configuring stop words in model config (TabbyML#3209
Browse files Browse the repository at this point in the history
)

* feat: add stop words to model config

* chore: use the name stop_words_from_config

* [autofix.ci] apply automated fixes

* chore: fix snap test due to source file change

* chore: use empty text in addtional stop words

* chore: remove dup with_stop_words

* chore: fix unit test for code splitter

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
  • Loading branch information
zwpaper and autofix-ci[bot] authored Oct 5, 2024
1 parent b700fde commit 780b9eb
Show file tree
Hide file tree
Showing 8 changed files with 69 additions and 21 deletions.
8 changes: 8 additions & 0 deletions crates/tabby-common/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ fn default_embedding_config() -> ModelConfig {
num_gpu_layers: 9999,
enable_fast_attention: None,
context_size: default_context_size(),
additional_stop_words: None,
})
}

Expand Down Expand Up @@ -221,6 +222,7 @@ impl ModelConfig {
num_gpu_layers,
enable_fast_attention: None,
context_size: default_context_size(),
additional_stop_words: None,
})
}
}
Expand Down Expand Up @@ -256,6 +258,9 @@ pub struct HttpModelConfig {
/// Used by Chat/Completion API allowing users to get supported models info.
#[builder(default)]
pub supported_models: Option<Vec<String>>,

#[builder(default)]
pub additional_stop_words: Option<Vec<String>>,
}

#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
Expand All @@ -273,6 +278,9 @@ pub struct LocalModelConfig {

#[serde(default = "default_context_size")]
pub context_size: usize,

#[serde(default)]
pub additional_stop_words: Option<Vec<String>>,
}

fn default_parallelism() -> u8 {
Expand Down
2 changes: 1 addition & 1 deletion crates/tabby-index/src/code/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ mod tests {

#[tokio::test]
async fn test_code_splitter() {
// First file, chat/openai_chat.rs
// First file, tabby-inference/src/decoding.rs
let file_contents = include_str!("../../../tabby-inference/src/decoding.rs");

let rust_chunks = CodeIntelligence::chunks(file_contents, "rust")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ source: crates/tabby-index/src/code/index.rs
expression: "format!(\"{:#?}\", text_chunks)"
---
[
"use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap<String, Trie<u8>>,\n}\n\nfn reverse<T>(s: T) -> String\nwhere\n T: Into<String>,\n{\n s.into().chars().rev().collect()\n}\n\nimpl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie<u8>>;",
"impl StopConditionFactory {\n pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }",
"fn get_trie<'a>(&'a self, language: &'static Language) -> Option<CachedTrie<'a>> {\n let stop_words = language.get_stop_words();\n if stop_words.is_empty() {\n None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));",
"trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}\n\nfn create_stop_trie(stop_words: Vec<String>) -> Trie<u8> {\n let mut builder = TrieBuilder::new();\n for word in stop_words {\n builder.push(reverse(word))\n }\n builder.build()\n}\n\npub struct StopCondition<'a> {\n stop_trie: Option<CachedTrie<'a>>,\n reversed_text: String,\n num_decoded: usize,\n}",
"use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap<String, Trie<u8>>,\n stop_words_from_model_config: Vec<String>,\n}\n\nfn reverse<T>(s: T) -> String\nwhere\n T: Into<String>,\n{\n s.into().chars().rev().collect()\n}",
"impl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: vec![],\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie<u8>>;\n\nimpl StopConditionFactory {\n pub fn with_stop_words(stop_words: Vec<String>) -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: stop_words,\n }\n }",
"pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }",
"fn get_trie<'a>(&'a self, language: &'static Language) -> Option<CachedTrie<'a>> {\n let mut stop_words = language.get_stop_words();\n // append model stop words\n stop_words.extend(self.stop_words_from_model_config.iter().cloned());",
"if stop_words.is_empty() {\n None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));\n trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}",
"fn create_stop_trie(stop_words: Vec<String>) -> Trie<u8> {\n let mut builder = TrieBuilder::new();\n for word in stop_words {\n builder.push(reverse(word))\n }\n builder.build()\n}\n\npub struct StopCondition<'a> {\n stop_trie: Option<CachedTrie<'a>>,\n reversed_text: String,\n num_decoded: usize,\n}",
"impl<'a> StopCondition<'a> {\n pub fn new(stop_trie: Option<CachedTrie<'a>>, text: &str) -> Self {\n Self {\n stop_trie,\n reversed_text: reverse(text),\n num_decoded: 0,\n }\n }\n\n pub fn should_stop(&mut self, new_text: &str) -> (bool, usize) {\n self.num_decoded += 1;\n if !new_text.is_empty() {\n self.reversed_text = reverse(new_text) + &self.reversed_text;",
"if let Some(re) = &self.stop_trie {\n let matches = re.common_prefix_search(&self.reversed_text);\n let matched_length = matches.into_iter().map(|x| x.len()).max();\n if let Some(matched_length) = matched_length {\n return (true, matched_length);\n }\n }\n }\n (false, 0)\n }\n}\n\n#[cfg(test)]\nmod tests {\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;",
"if let Some(re) = &self.stop_trie {\n let matches = re.common_prefix_search(&self.reversed_text);\n let matched_length = matches.into_iter().map(|x| x.len()).max();\n if let Some(matched_length) = matched_length {\n return (true, matched_length);\n }\n }\n }\n (false, 0)\n }\n}\n\n#[cfg(test)]\nmod tests {\n\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;",
"#[test]\n fn test_trie_works() {\n let text = reverse(\"void write_u32(std::uint32_t val) const {\\n write_raw(&val, sizeof(val));\\n }\\n\\n ~llama_file() {\\n if (fp) {\\n std::fclose(fp);\\n }\\n }\\n};\\n\\nvoid\");\n\n let trie = create_stop_trie(vec![\"\\n\\n\".to_owned(), \"\\n\\n \".to_owned()]);\n assert!(trie.common_prefix_search(&text).is_empty());",
"let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n }",
"let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n \"<|file_sep|>\".to_owned(), // qwen 2.5 coder style\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n\n let qwen25coder = reverse(\"qwen25 style stop words;<|file_sep|>\");\n assert!(!trie.common_prefix_search(&qwen25coder).is_empty());\n }",
"#[test]\n fn test_stop_condition_max_length() {\n let factory = StopConditionFactory::default();\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"2\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"3\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"4\");\n assert!(!should_stop)",
"}\n}",
"}\n\n #[test]\n fn test_stop_condition_additional_stop_words() {\n let factory = StopConditionFactory::with_stop_words(vec![\"<|endoftext|>\".to_owned()]);\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"<|endoftext|>\");\n assert!(should_stop);\n }\n}",
]
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@ source: crates/tabby-index/src/code/index.rs
expression: "format!(\"{:#?}\", rust_chunks)"
---
[
"use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap<String, Trie<u8>>,\n}\n\nfn reverse<T>(s: T) -> String\nwhere\n T: Into<String>,\n{\n s.into().chars().rev().collect()\n}\n\nimpl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie<u8>>;",
"use dashmap::DashMap;\nuse tabby_common::languages::Language;\nuse trie_rs::{Trie, TrieBuilder};\n\npub struct StopConditionFactory {\n stop_trie_cache: DashMap<String, Trie<u8>>,\n stop_words_from_model_config: Vec<String>,\n}\n\nfn reverse<T>(s: T) -> String\nwhere\n T: Into<String>,\n{\n s.into().chars().rev().collect()\n}",
"impl Default for StopConditionFactory {\n fn default() -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: vec![],\n }\n }\n}\n\ntype CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie<u8>>;",
"impl StopConditionFactory",
"{\n pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }",
"{\n pub fn with_stop_words(stop_words: Vec<String>) -> Self {\n Self {\n stop_trie_cache: DashMap::new(),\n stop_words_from_model_config: stop_words,\n }\n }\n\n pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {\n if let Some(language) = language {\n StopCondition::new(self.get_trie(language), text)\n } else {\n StopCondition::new(None, text)\n }\n }",
"fn get_trie<'a>(&'a self, language: &'static Language) -> Option<CachedTrie<'a>>",
"{\n let stop_words = language.get_stop_words();\n if stop_words.is_empty() {\n None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));\n trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}",
"{\n let mut stop_words = language.get_stop_words();\n // append model stop words\n stop_words.extend(self.stop_words_from_model_config.iter().cloned());",
"if stop_words.is_empty() {\n None\n } else {\n let hashkey = language.language().to_owned();\n let mut trie = self.stop_trie_cache.get(&hashkey);\n if trie.is_none() {\n self.stop_trie_cache\n .insert(hashkey.clone(), create_stop_trie(stop_words));\n trie = self.stop_trie_cache.get(&hashkey);\n }\n\n trie\n }\n }\n}",
"fn create_stop_trie(stop_words: Vec<String>) -> Trie<u8> {\n let mut builder = TrieBuilder::new();\n for word in stop_words {\n builder.push(reverse(word))\n }\n builder.build()\n}\n\npub struct StopCondition<'a> {\n stop_trie: Option<CachedTrie<'a>>,\n reversed_text: String,\n num_decoded: usize,\n}",
"impl<'a> StopCondition<'a>",
"{\n pub fn new(stop_trie: Option<CachedTrie<'a>>, text: &str) -> Self {\n Self {\n stop_trie,\n reversed_text: reverse(text),\n num_decoded: 0,\n }\n }",
"pub fn should_stop(&mut self, new_text: &str) -> (bool, usize)",
"{\n self.num_decoded += 1;\n if !new_text.is_empty() {\n self.reversed_text = reverse(new_text) + &self.reversed_text;\n\n if let Some(re) = &self.stop_trie {\n let matches = re.common_prefix_search(&self.reversed_text);\n let matched_length = matches.into_iter().map(|x| x.len()).max();\n if let Some(matched_length) = matched_length {\n return (true, matched_length);\n }\n }\n }",
"(false, 0)\n }\n}\n\n#[cfg(test)]",
"mod tests",
"{\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;\n\n #[test]",
"{\n\n use tabby_common::languages::UNKNOWN_LANGUAGE;\n\n use super::*;\n\n #[test]",
"fn test_trie_works()",
"{\n let text = reverse(\"void write_u32(std::uint32_t val) const {\\n write_raw(&val, sizeof(val));\\n }\\n\\n ~llama_file() {\\n if (fp) {\\n std::fclose(fp);\\n }\\n }\\n};\\n\\nvoid\");\n\n let trie = create_stop_trie(vec![\"\\n\\n\".to_owned(), \"\\n\\n \".to_owned()]);\n assert!(trie.common_prefix_search(&text).is_empty());",
"let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n }\n\n #[test]",
"fn test_stop_condition_max_length() {\n let factory = StopConditionFactory::default();\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"2\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"3\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"4\");\n assert!(!should_stop)\n }\n}",
"let trie = create_stop_trie(vec![\n \"\\n\\n\".to_owned(),\n \"\\n\\n \".to_owned(),\n \"\\nvoid\".to_owned(),\n \"<|file_sep|>\".to_owned(), // qwen 2.5 coder style\n ]);\n assert!(!trie.common_prefix_search(&text).is_empty());\n\n let qwen25coder = reverse(\"qwen25 style stop words;<|file_sep|>\");\n assert!(!trie.common_prefix_search(&qwen25coder).is_empty());\n }\n\n #[test]",
"fn test_stop_condition_max_length() {\n let factory = StopConditionFactory::default();\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"2\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"3\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"4\");\n assert!(!should_stop)\n }",
"#[test]\n fn test_stop_condition_additional_stop_words() {\n let factory = StopConditionFactory::with_stop_words(vec![\"<|endoftext|>\".to_owned()]);\n let mut cond = factory.create(\"\", Some(&UNKNOWN_LANGUAGE));\n let (should_stop, _) = cond.should_stop(\"1\");\n assert!(!should_stop);\n let (should_stop, _) = cond.should_stop(\"<|endoftext|>\");\n assert!(should_stop);\n }\n}",
]
13 changes: 10 additions & 3 deletions crates/tabby-inference/src/code.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::Arc;
use async_stream::stream;
use derive_builder::Builder;
use futures::StreamExt;
use tabby_common::languages::Language;
use tabby_common::{config::ModelConfig, languages::Language};

use crate::{decoding::StopConditionFactory, CompletionOptionsBuilder, CompletionStream};

Expand Down Expand Up @@ -31,10 +31,17 @@ pub struct CodeGeneration {
}

impl CodeGeneration {
pub fn new(imp: Arc<dyn CompletionStream>) -> Self {
pub fn new(imp: Arc<dyn CompletionStream>, config: Option<ModelConfig>) -> Self {
let additional_stop_words = match config {
Some(ModelConfig::Local(config)) => config.additional_stop_words.unwrap_or_default(),
Some(ModelConfig::Http(config)) => config.additional_stop_words.unwrap_or_default(),
_ => vec![],
};
let stop_condition_factory = StopConditionFactory::with_stop_words(additional_stop_words);

Self {
imp,
stop_condition_factory: StopConditionFactory::default(),
stop_condition_factory,
}
}
}
Expand Down
29 changes: 28 additions & 1 deletion crates/tabby-inference/src/decoding.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use trie_rs::{Trie, TrieBuilder};

pub struct StopConditionFactory {
stop_trie_cache: DashMap<String, Trie<u8>>,
stop_words_from_model_config: Vec<String>,
}

fn reverse<T>(s: T) -> String
Expand All @@ -17,13 +18,21 @@ impl Default for StopConditionFactory {
fn default() -> Self {
Self {
stop_trie_cache: DashMap::new(),
stop_words_from_model_config: vec![],
}
}
}

type CachedTrie<'a> = dashmap::mapref::one::Ref<'a, String, Trie<u8>>;

impl StopConditionFactory {
pub fn with_stop_words(stop_words: Vec<String>) -> Self {
Self {
stop_trie_cache: DashMap::new(),
stop_words_from_model_config: stop_words,
}
}

pub fn create(&self, text: &str, language: Option<&'static Language>) -> StopCondition {
if let Some(language) = language {
StopCondition::new(self.get_trie(language), text)
Expand All @@ -33,7 +42,10 @@ impl StopConditionFactory {
}

fn get_trie<'a>(&'a self, language: &'static Language) -> Option<CachedTrie<'a>> {
let stop_words = language.get_stop_words();
let mut stop_words = language.get_stop_words();
// append model stop words
stop_words.extend(self.stop_words_from_model_config.iter().cloned());

if stop_words.is_empty() {
None
} else {
Expand Down Expand Up @@ -92,6 +104,7 @@ impl<'a> StopCondition<'a> {

#[cfg(test)]
mod tests {

use tabby_common::languages::UNKNOWN_LANGUAGE;

use super::*;
Expand All @@ -107,8 +120,12 @@ mod tests {
"\n\n".to_owned(),
"\n\n ".to_owned(),
"\nvoid".to_owned(),
"<|file_sep|>".to_owned(), // qwen 2.5 coder style
]);
assert!(!trie.common_prefix_search(&text).is_empty());

let qwen25coder = reverse("qwen25 style stop words;<|file_sep|>");
assert!(!trie.common_prefix_search(&qwen25coder).is_empty());
}

#[test]
Expand All @@ -124,4 +141,14 @@ mod tests {
let (should_stop, _) = cond.should_stop("4");
assert!(!should_stop)
}

#[test]
fn test_stop_condition_additional_stop_words() {
let factory = StopConditionFactory::with_stop_words(vec!["<|endoftext|>".to_owned()]);
let mut cond = factory.create("", Some(&UNKNOWN_LANGUAGE));
let (should_stop, _) = cond.should_stop("1");
assert!(!should_stop);
let (should_stop, _) = cond.should_stop("<|endoftext|>");
assert!(should_stop);
}
}
Loading

0 comments on commit 780b9eb

Please sign in to comment.