From c9194845faa12b8a0fcecdd65f8ec9d3d221ba08 Mon Sep 17 00:00:00 2001
From: Timon Vonk
Date: Sat, 11 Jan 2025 22:37:55 +0100
Subject: [PATCH] feat: Ollama via async-openai with chatcompletion support
 (#545)

Adds support for chat completions (agents) for Ollama. SimplePrompt and
embeddings now use async-openai underneath. The code is copy-pasted rather
than shared, as I expect the implementations to diverge in the future.
---
 .github/workflows/test.yml                    |   2 +-
 Cargo.lock                                    |  22 +-
 Cargo.toml                                    |   3 +-
 swiftide-agents/src/lib.rs                    |   2 +-
 .../src/search_strategies/custom_strategy.rs  |   2 +-
 swiftide-integrations/Cargo.toml              |   5 +-
 .../src/ollama/chat_completion.rs             | 196 ++++++++++++++++++
 swiftide-integrations/src/ollama/config.rs    |  51 +++++
 swiftide-integrations/src/ollama/embed.rs     |  23 +-
 swiftide-integrations/src/ollama/mod.rs       |  15 +-
 .../src/ollama/simple_prompt.rs               |  37 +++-
 11 files changed, 307 insertions(+), 51 deletions(-)
 create mode 100644 swiftide-integrations/src/ollama/chat_completion.rs
 create mode 100644 swiftide-integrations/src/ollama/config.rs

diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index d50b1e91..b674c649 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -28,4 +28,4 @@ jobs:
         with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
       - name: "Test"
-        run: cargo test -j 2 --all-features --tests
+        run: cargo test -j 2 --all-features --no-fail-fast
diff --git a/Cargo.lock b/Cargo.lock
index 9d6d0c42..3d6571b4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4059,7 +4059,7 @@ dependencies = [
  "httpdate",
  "itoa 1.0.14",
  "pin-project-lite",
- "socket2 0.4.10",
+ "socket2 0.5.8",
  "tokio",
  "tower-service",
  "tracing",
@@ -5940,21 +5940,6 @@ dependencies = [
  "walkdir",
 ]
 
-[[package]]
-name = "ollama-rs"
-version = "0.2.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "763afb01db2dced00e656cc2cdcd875659fc3fac4c449e6337a4f04f9e3d9efc"
-dependencies = [
- "async-stream",
- "async-trait",
- "log",
- "reqwest",
- "serde",
- "serde_json",
- "url",
-]
-
 [[package]]
 name = "once_cell"
 version = "1.20.2"
@@ -6720,7 +6705,7 @@ version = "0.13.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "d0f3e5beed80eb580c68e2c600937ac2c4eedabdfd5ef1e5b7ea4f3fba84497b"
 dependencies = [
- "heck 0.4.1",
+ "heck 0.5.0",
  "itertools 0.13.0",
  "log",
  "multimap",
@@ -8059,7 +8044,7 @@ version = "0.8.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "03c3c6b7927ffe7ecaa769ee0e3994da3b8cafc8f444578982c83ecb161af917"
 dependencies = [
- "heck 0.4.1",
+ "heck 0.5.0",
  "proc-macro2",
  "quote",
  "syn 2.0.91",
@@ -8690,7 +8675,6 @@ dependencies = [
  "itertools 0.13.0",
  "lancedb",
  "mockall",
- "ollama-rs",
  "parquet",
  "pgvector",
  "qdrant-client",
diff --git a/Cargo.toml b/Cargo.toml
index 252611cd..ade3c0d8 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -44,7 +44,7 @@ convert_case = "0.6.0"
 
 # Integrations
 spider = { version = "2.21" }
-async-openai = { version = "0.26" }
+async-openai = { version = "0.26.0" }
 qdrant-client = { version = "1.10", default-features = false, features = [
   "serde",
 ] }
@@ -67,7 +67,6 @@ fastembed = "4.4"
 flv-util = "0.5.2"
 htmd = "0.1"
 ignore = "0.4"
-ollama-rs = "0.2.2"
 proc-macro2 = "1.0"
 quote = "1.0"
 redis = "0.27"
diff --git a/swiftide-agents/src/lib.rs b/swiftide-agents/src/lib.rs
index 2fa5b452..41abe4c6 100644
--- a/swiftide-agents/src/lib.rs
+++ b/swiftide-agents/src/lib.rs
@@ -11,7 +11,7 @@
 //!
 //! # Example
 //!
-//! ```no_run
+//! ```ignore
 //! # use swiftide_agents::Agent;
 //! # use swiftide_integrations as integrations;
 //! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
diff --git a/swiftide-core/src/search_strategies/custom_strategy.rs b/swiftide-core/src/search_strategies/custom_strategy.rs
index ac0e221e..d2e9ea6c 100644
--- a/swiftide-core/src/search_strategies/custom_strategy.rs
+++ b/swiftide-core/src/search_strategies/custom_strategy.rs
@@ -33,7 +33,7 @@ type QueryGenerator<Q> = Arc<dyn Fn(&Query<states::Pending>) -> Result<Q> + Send
 /// * `Q` - The retriever-specific query type (e.g., `sqlx::QueryBuilder` for `PostgreSQL`)
 ///
 /// # Examples
-/// ```rust
+/// ```ignore
 /// // Define search configuration
 /// const MAX_SEARCH_RESULTS: i64 = 5;
 ///
diff --git a/swiftide-integrations/Cargo.toml b/swiftide-integrations/Cargo.toml
index d7b21f72..6eb2a6d1 100644
--- a/swiftide-integrations/Cargo.toml
+++ b/swiftide-integrations/Cargo.toml
@@ -67,7 +67,6 @@ aws-sdk-bedrockruntime = { workspace = true, features = [
 ], optional = true }
 secrecy = { workspace = true, optional = true }
 reqwest = { workspace = true, optional = true }
-ollama-rs = { workspace = true, optional = true }
 deadpool = { workspace = true, features = [
   "managed",
   "rt_tokio_1",
@@ -127,8 +126,8 @@ tree-sitter = [
 openai = ["dep:async-openai"]
 # Groq prompting
 groq = ["dep:async-openai", "dep:secrecy", "dep:reqwest"]
-# Ollama prompting
-ollama = ["dep:ollama-rs"]
+# Ollama prompting, embedding and chat completion
+ollama = ["dep:async-openai", "dep:secrecy", "dep:reqwest"]
 # FastEmbed (by qdrant) for fast, local embeddings
 fastembed = ["dep:fastembed"]
 # Scraping via spider as loader and a html to markdown transformer
diff --git a/swiftide-integrations/src/ollama/chat_completion.rs b/swiftide-integrations/src/ollama/chat_completion.rs
new file mode 100644
index 00000000..04c6ef40
--- /dev/null
+++ b/swiftide-integrations/src/ollama/chat_completion.rs
@@ -0,0 +1,196 @@
+use anyhow::{Context as _, Result};
+use async_openai::types::{
+    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
+    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
+    ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
+    ChatCompletionToolType, CreateChatCompletionRequestArgs, FunctionCall, FunctionObjectArgs,
+};
+use async_trait::async_trait;
+use itertools::Itertools;
+use serde_json::json;
+use swiftide_core::chat_completion::{
+    errors::ChatCompletionError, ChatCompletion, ChatCompletionRequest, ChatCompletionResponse,
+    ChatMessage, ToolCall, ToolSpec,
+};
+
+use super::Ollama;
+
+#[async_trait]
+impl ChatCompletion for Ollama {
+    #[tracing::instrument(skip_all)]
+    async fn complete(
+        &self,
+        request: &ChatCompletionRequest,
+    ) -> Result<ChatCompletionResponse, ChatCompletionError> {
+        let model = self
+            .default_options
+            .prompt_model
+            .as_ref()
+            .context("Model not set")?;
+
+        let messages = request
+            .messages()
+            .iter()
+            .map(message_to_openai)
+            .collect::<Result<Vec<_>>>()?;
+
+        // Build the request to be sent to the OpenAI API.
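+        // Ollama exposes an OpenAI-compatible API under `/v1`, so async-openai's
+        // request builders can be reused as-is; only the base URL differs (see
+        // `OllamaConfig`).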
+        let mut openai_request = CreateChatCompletionRequestArgs::default()
+            .model(model)
+            .messages(messages)
+            .to_owned();
+
+        if !request.tools_spec.is_empty() {
+            openai_request
+                .tools(
+                    request
+                        .tools_spec()
+                        .iter()
+                        .map(tools_to_openai)
+                        .collect::<Result<Vec<_>>>()?,
+                )
+                .tool_choice("auto")
+                .parallel_tool_calls(true);
+        }
+
+        let request = openai_request
+            .build()
+            .map_err(|e| ChatCompletionError::LLM(Box::new(e)))?;
+
+        tracing::debug!(
+            model = &model,
+            request = serde_json::to_string_pretty(&request).expect("infallible"),
+            "Sending request to Ollama"
+        );
+
+        let response = self
+            .client
+            .chat()
+            .create(request)
+            .await
+            .map_err(|e| ChatCompletionError::LLM(Box::new(e)))?;
+
+        tracing::debug!(
+            response = serde_json::to_string_pretty(&response).expect("infallible"),
+            "Received response from Ollama"
+        );
+
+        ChatCompletionResponse::builder()
+            .maybe_message(
+                response
+                    .choices
+                    .first()
+                    .and_then(|choice| choice.message.content.clone()),
+            )
+            .maybe_tool_calls(
+                response
+                    .choices
+                    .first()
+                    .and_then(|choice| choice.message.tool_calls.clone())
+                    .map(|tool_calls| {
+                        tool_calls
+                            .iter()
+                            .map(|tool_call| {
+                                ToolCall::builder()
+                                    .id(tool_call.id.clone())
+                                    .args(tool_call.function.arguments.clone())
+                                    .name(tool_call.function.name.clone())
+                                    .build()
+                                    .expect("infallible")
+                            })
+                            .collect_vec()
+                    }),
+            )
+            .build()
+            .map_err(ChatCompletionError::from)
+    }
+}
+
+// TODO: Maybe implement `Into` for the whole conversion? The types are not defined in this crate.
+
+fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
+    let mut properties = serde_json::Map::new();
+
+    for param in &spec.parameters {
+        properties.insert(
+            param.name.to_string(),
+            json!({
+                "type": "string",
+                "description": param.description,
+            }),
+        );
+    }
+
+    ChatCompletionToolArgs::default()
+        .r#type(ChatCompletionToolType::Function)
+        .function(
+            FunctionObjectArgs::default()
+                .name(spec.name)
+                .description(spec.description)
+                .parameters(json!({
+                    "type": "object",
+                    "properties": properties,
+                    "required": spec
+                        .parameters
+                        .iter()
+                        .filter(|param| param.required)
+                        .map(|param| param.name)
+                        .collect_vec(),
+                    "additionalProperties": false,
+                }))
+                .build()?,
+        )
+        .build()
+        .map_err(anyhow::Error::from)
+}
+
+fn message_to_openai(
+    message: &ChatMessage,
+) -> Result<async_openai::types::ChatCompletionRequestMessage> {
+    let openai_message = match message {
+        ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
+            .content(msg.as_str())
+            .build()?
+            .into(),
+        ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
+            .content(msg.as_str())
+            .build()?
+            .into(),
+        ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
+            .content(msg.as_str())
+            .build()?
+            .into(),
+        ChatMessage::ToolOutput(tool_call, tool_output) => {
+            let Some(content) = tool_output.content() else {
+                return Ok(ChatCompletionRequestToolMessageArgs::default()
+                    .tool_call_id(tool_call.id())
+                    .build()?
+                    .into());
+            };
+
+            ChatCompletionRequestToolMessageArgs::default()
+                .content(content)
+                .tool_call_id(tool_call.id())
+                .build()?
+                .into()
+        }
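+        // Assistant messages may carry text content, tool calls, or both; each part
+        // is mapped onto the OpenAI request type only when present.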
+        ChatMessage::Assistant(msg, tool_calls) => {
+            let mut builder = ChatCompletionRequestAssistantMessageArgs::default();
+
+            if let Some(msg) = msg {
+                builder.content(msg.as_str());
+            }
+
+            if let Some(tool_calls) = tool_calls {
+                builder.tool_calls(
+                    tool_calls
+                        .iter()
+                        .map(|tool_call| ChatCompletionMessageToolCall {
+                            id: tool_call.id().to_string(),
+                            r#type: ChatCompletionToolType::Function,
+                            function: FunctionCall {
+                                name: tool_call.name().to_string(),
+                                arguments: tool_call.args().unwrap_or_default().to_string(),
+                            },
+                        })
+                        .collect::<Vec<_>>(),
+                );
+            }
+
+            builder.build()?.into()
+        }
+    };
+
+    Ok(openai_message)
+}
diff --git a/swiftide-integrations/src/ollama/config.rs b/swiftide-integrations/src/ollama/config.rs
new file mode 100644
index 00000000..3a1f350d
--- /dev/null
+++ b/swiftide-integrations/src/ollama/config.rs
@@ -0,0 +1,51 @@
+use reqwest::header::HeaderMap;
+use secrecy::Secret;
+use serde::Deserialize;
+
+const OLLAMA_API_BASE: &str = "http://localhost:11434/v1";
+
+#[derive(Clone, Debug, Deserialize)]
+#[serde(default)]
+pub struct OllamaConfig {
+    api_base: String,
+    api_key: Secret<String>,
+}
+
+impl OllamaConfig {
+    pub fn with_api_base(&mut self, api_base: &str) -> &mut Self {
+        self.api_base = api_base.to_string();
+
+        self
+    }
+}
+
+impl Default for OllamaConfig {
+    fn default() -> Self {
+        Self {
+            api_base: OLLAMA_API_BASE.to_string(),
+            api_key: String::new().into(),
+        }
+    }
+}
+
+impl async_openai::config::Config for OllamaConfig {
+    fn headers(&self) -> HeaderMap {
+        HeaderMap::new()
+    }
+
+    fn url(&self, path: &str) -> String {
+        format!("{}{}", self.api_base, path)
+    }
+
+    fn api_base(&self) -> &str {
+        &self.api_base
+    }
+
+    fn api_key(&self) -> &Secret<String> {
+        &self.api_key
+    }
+
+    fn query(&self) -> Vec<(&str, &str)> {
+        vec![]
+    }
+}
diff --git a/swiftide-integrations/src/ollama/embed.rs b/swiftide-integrations/src/ollama/embed.rs
index c9d71bd9..66780e0e 100644
--- a/swiftide-integrations/src/ollama/embed.rs
+++ b/swiftide-integrations/src/ollama/embed.rs
@@ -1,7 +1,7 @@
 use anyhow::{Context as _, Result};
+use async_openai::types::CreateEmbeddingRequestArgs;
 use async_trait::async_trait;
-use ollama_rs::generation::embeddings::request::GenerateEmbeddingsRequest;
 use swiftide_core::{EmbeddingModel, Embeddings};
 
 use super::Ollama;
@@ -15,19 +15,26 @@ impl EmbeddingModel for Ollama {
             .as_ref()
             .context("Model not set")?;
 
-        let request = GenerateEmbeddingsRequest::new(model.to_string(), input.into());
+        let request = CreateEmbeddingRequestArgs::default()
+            .model(model)
+            .input(&input)
+            .build()?;
 
         tracing::debug!(
-            messages = serde_json::to_string_pretty(&request)?,
+            num_chunks = input.len(),
+            model = &model,
             "[Embed] Request to ollama"
         );
 
         let response = self
             .client
-            .generate_embeddings(request)
+            .embeddings()
+            .create(request)
             .await
-            .context("Request to Ollama Failed")?;
+            .context("Request to Ollama failed")?;
 
-        tracing::debug!("[Embed] Response ollama");
+        let num_embeddings = response.data.len();
+        tracing::debug!(num_embeddings = num_embeddings, "[Embed] Response from ollama");
 
-        Ok(response.embeddings)
+        // WARN: Naively assumes that the order is preserved. Might not always be the case.
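+        // If ordering ever proves unreliable, a defensive variant (untested sketch)
+        // could sort by the `index` field that OpenAI-compatible embedding responses
+        // carry per datum:
+        //     let mut data = response.data;
+        //     data.sort_by_key(|d| d.index);
+        //     Ok(data.into_iter().map(|d| d.embedding).collect())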
+        Ok(response.data.into_iter().map(|d| d.embedding).collect())
     }
 }
diff --git a/swiftide-integrations/src/ollama/mod.rs b/swiftide-integrations/src/ollama/mod.rs
index 5dd48786..9193a5ba 100644
--- a/swiftide-integrations/src/ollama/mod.rs
+++ b/swiftide-integrations/src/ollama/mod.rs
@@ -2,11 +2,14 @@
 //! It includes the `Ollama` struct for managing API clients and default options for embedding and prompt models.
 //! The module is conditionally compiled based on the "ollama" feature flag.
 
+use config::OllamaConfig;
 use derive_builder::Builder;
 use std::sync::Arc;
 
-mod embed;
-mod simple_prompt;
+pub mod chat_completion;
+pub mod config;
+pub mod embed;
+pub mod simple_prompt;
 
 /// The `Ollama` struct encapsulates an `Ollama` client and default options for embedding and prompt models.
 /// It uses the `Builder` pattern for flexible and customizable instantiation.
@@ -22,7 +25,7 @@ pub struct Ollama {
     /// The `Ollama` client, wrapped in an `Arc` for thread-safe reference counting.
     #[builder(default = "default_client()", setter(custom))]
-    client: Arc<ollama_rs::Ollama>,
+    client: Arc<async_openai::Client<OllamaConfig>>,
     /// Default options for the embedding and prompt models.
     #[builder(default)]
     default_options: Options,
@@ -91,7 +94,7 @@ impl OllamaBuilder {
     ///
     /// # Returns
     /// A mutable reference to the `OllamaBuilder`.
-    pub fn client(&mut self, client: ollama_rs::Ollama) -> &mut Self {
+    pub fn client(&mut self, client: async_openai::Client<OllamaConfig>) -> &mut Self {
         self.client = Some(Arc::new(client));
         self
     }
@@ -135,8 +138,8 @@ impl OllamaBuilder {
     }
 }
 
-fn default_client() -> Arc<ollama_rs::Ollama> {
-    ollama_rs::Ollama::default().into()
+fn default_client() -> Arc<async_openai::Client<OllamaConfig>> {
+    Arc::new(async_openai::Client::with_config(OllamaConfig::default()))
 }
 
 #[cfg(test)]
diff --git a/swiftide-integrations/src/ollama/simple_prompt.rs b/swiftide-integrations/src/ollama/simple_prompt.rs
index 1eeb2523..4db3763f 100644
--- a/swiftide-integrations/src/ollama/simple_prompt.rs
+++ b/swiftide-integrations/src/ollama/simple_prompt.rs
@@ -1,8 +1,9 @@
 //! This module provides an implementation of the `SimplePrompt` trait for the `Ollama` struct.
 //! It defines an asynchronous function to interact with the `Ollama` API, allowing prompt processing
 //! and generating responses as part of the Swiftide system.
+use async_openai::types::{ChatCompletionRequestUserMessageArgs, CreateChatCompletionRequestArgs};
 use async_trait::async_trait;
-use swiftide_core::{prompt::Prompt, SimplePrompt};
+use swiftide_core::{prompt::Prompt, util::debug_long_utf8, SimplePrompt};
 
 use super::Ollama;
 use anyhow::{Context as _, Result};
@@ -33,28 +34,44 @@ impl SimplePrompt for Ollama {
             .context("Model not set")?;
 
         // Build the request to be sent to the Ollama API.
-        let request = ollama_rs::generation::completion::request::GenerationRequest::new(
-            model.to_string(),
-            prompt.render().await?,
-        );
+        let request = CreateChatCompletionRequestArgs::default()
+            .model(model)
+            .messages(vec![ChatCompletionRequestUserMessageArgs::default()
+                .content(prompt.render().await?)
+                .build()?
+                .into()])
+            .build()?;
 
         // Log the request for debugging purposes.
         tracing::debug!(
-            messages = serde_json::to_string_pretty(&request)?,
+            model = &model,
+            messages = debug_long_utf8(
+                serde_json::to_string_pretty(&request.messages.first())?,
+                100
+            ),
             "[SimplePrompt] Request to ollama"
         );
 
         // Send the request to the Ollama API and await the response.
-        // let mut response = self.client.chat().create(request).await?;
-        let response = self.client.generate(request).await?;
+        let response = self
+            .client
+            .chat()
+            .create(request)
+            .await?
+            .choices
+            .remove(0)
+            .message
+            .content
+            .take()
+            .context("Expected content in response")?;
 
         // Log the response for debugging purposes.
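+        // `debug_long_utf8` truncates the logged body (here to 100 characters) so
+        // long completions do not flood the logs.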
tracing::debug!( - response = serde_json::to_string_pretty(&response.response)?, + response = debug_long_utf8(&response, 100), "[SimplePrompt] Response from ollama" ); // Extract and return the content of the response, returning an error if not found. - Ok(response.response) + Ok(response) } }
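
A minimal usage sketch for reviewers, not part of the patch. It assumes `Ollama::builder()` is exposed and that the builder offers `default_prompt_model`/`default_embed_model` setters mirroring the other swiftide LLM integrations, and that `Prompt: From<&str>` holds; none of this is visible in the diff. The model names are placeholders for whatever has been pulled locally. With a local Ollama server running, the integration talks to its OpenAI-compatible /v1 endpoint:

use swiftide_core::{EmbeddingModel, SimplePrompt};
use swiftide_integrations::ollama::Ollama;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Defaults to http://localhost:11434/v1 via `OllamaConfig::default()`.
    let ollama = Ollama::builder()
        .default_prompt_model("llama3.2") // assumed setter, mirrors the OpenAI builder
        .default_embed_model("bge-m3")    // assumed setter
        .build()?;

    // SimplePrompt: one user message in, plain text out.
    let answer = ollama.prompt("Why is the sky blue?".into()).await?;
    println!("{answer}");

    // EmbeddingModel: one embedding per input chunk, in input order.
    let embeddings = ollama.embed(vec!["hello world".to_string()]).await?;
    println!("got {} embeddings", embeddings.len());

    Ok(())
}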