feat: Ollama via async-openai with chatcompletion support (#545)
Adds support for chat completions (agents) for Ollama. SimplePrompt and embeddings now use async-openai underneath.

Copy-pasted as I expect the implementations to diverge in the future.
timonv authored Jan 11, 2025
1 parent edb94fe commit c919484
Showing 11 changed files with 307 additions and 51 deletions.
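In practical terms, agents can now run against a local Ollama server through the same OpenAI-compatible code path as the other async-openai-backed integrations. A minimal sketch of the intended usage follows; the builder method names (`Ollama::builder`, `default_prompt_model`, `Agent::builder().llm(...)`) are assumed to mirror the existing OpenAI integration and are not part of this diff:

```rust
use swiftide_agents::Agent;
use swiftide_integrations::ollama::Ollama;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Hypothetical builder call; assumes parity with the OpenAI
    // integration's `default_prompt_model` setter.
    let ollama = Ollama::builder()
        .default_prompt_model("llama3.2")
        .build()?;

    // The agent drives tool-calling chat completions through the new
    // `ChatCompletion` impl added in this commit.
    Agent::builder()
        .llm(&ollama)
        .build()?
        .query("Hello, Ollama")
        .await?;

    Ok(())
}
```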
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -28,4 +28,4 @@ jobs:
         with:
           repo-token: ${{ secrets.GITHUB_TOKEN }}
       - name: "Test"
-        run: cargo test -j 2 --all-features --tests
+        run: cargo test -j 2 --all-features --no-fail-fast
22 changes: 3 additions & 19 deletions Cargo.lock

Some generated files are not rendered by default.

3 changes: 1 addition & 2 deletions Cargo.toml
@@ -44,7 +44,7 @@ convert_case = "0.6.0"
 
 # Integrations
 spider = { version = "2.21" }
-async-openai = { version = "0.26" }
+async-openai = { version = "0.26.0" }
 qdrant-client = { version = "1.10", default-features = false, features = [
     "serde",
 ] }
@@ -67,7 +67,6 @@ fastembed = "4.4"
 flv-util = "0.5.2"
 htmd = "0.1"
 ignore = "0.4"
-ollama-rs = "0.2.2"
 proc-macro2 = "1.0"
 quote = "1.0"
 redis = "0.27"
2 changes: 1 addition & 1 deletion swiftide-agents/src/lib.rs
@@ -11,7 +11,7 @@
 //!
 //! # Example
 //!
-//! ```no_run
+//! ```ignore
 //! # use swiftide_agents::Agent;
 //! # use swiftide_integrations as integrations;
 //! # async fn run() -> Result<(), Box<dyn std::error::Error>> {
2 changes: 1 addition & 1 deletion swiftide-core/src/search_strategies/custom_strategy.rs
@@ -33,7 +33,7 @@ type QueryGenerator<Q> = Arc<dyn Fn(&Query<states::Pending>) -> Result<Q> + Send
 /// * `Q` - The retriever-specific query type (e.g., `sqlx::QueryBuilder` for `PostgreSQL`)
 ///
 /// # Examples
-/// ```rust
+/// ```ignore
 /// // Define search configuration
 /// const MAX_SEARCH_RESULTS: i64 = 5;
 ///
5 changes: 2 additions & 3 deletions swiftide-integrations/Cargo.toml
@@ -67,7 +67,6 @@ aws-sdk-bedrockruntime = { workspace = true, features = [
 ], optional = true }
 secrecy = { workspace = true, optional = true }
 reqwest = { workspace = true, optional = true }
-ollama-rs = { workspace = true, optional = true }
 deadpool = { workspace = true, features = [
     "managed",
     "rt_tokio_1",
@@ -127,8 +126,8 @@ tree-sitter = [
 openai = ["dep:async-openai"]
 # Groq prompting
 groq = ["dep:async-openai", "dep:secrecy", "dep:reqwest"]
-# Ollama prompting
-ollama = ["dep:ollama-rs"]
+# Ollama prompting, embedding, chatcompletion
+ollama = ["dep:async-openai", "dep:secrecy", "dep:reqwest"]
 # FastEmbed (by qdrant) for fast, local embeddings
 fastembed = ["dep:fastembed"]
 # Scraping via spider as loader and a html to markdown transformer
196 changes: 196 additions & 0 deletions swiftide-integrations/src/ollama/chat_completion.rs
@@ -0,0 +1,196 @@
use anyhow::{Context as _, Result};
use async_openai::types::{
    ChatCompletionMessageToolCall, ChatCompletionRequestAssistantMessageArgs,
    ChatCompletionRequestSystemMessageArgs, ChatCompletionRequestToolMessageArgs,
    ChatCompletionRequestUserMessageArgs, ChatCompletionTool, ChatCompletionToolArgs,
    ChatCompletionToolType, CreateChatCompletionRequestArgs, FunctionCall, FunctionObjectArgs,
};
use async_trait::async_trait;
use itertools::Itertools;
use serde_json::json;
use swiftide_core::chat_completion::{
    errors::ChatCompletionError, ChatCompletion, ChatCompletionRequest, ChatCompletionResponse,
    ChatMessage, ToolCall, ToolSpec,
};

use super::Ollama;

#[async_trait]
impl ChatCompletion for Ollama {
    #[tracing::instrument(skip_all)]
    async fn complete(
        &self,
        request: &ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, ChatCompletionError> {
        let model = self
            .default_options
            .prompt_model
            .as_ref()
            .context("Model not set")?;

        let messages = request
            .messages()
            .iter()
            .map(message_to_openai)
            .collect::<Result<Vec<_>>>()?;

        // Build the request to be sent to the OpenAI API.
        let mut openai_request = CreateChatCompletionRequestArgs::default()
            .model(model)
            .messages(messages)
            .to_owned();

        if !request.tools_spec.is_empty() {
            openai_request
                .tools(
                    request
                        .tools_spec()
                        .iter()
                        .map(tools_to_openai)
                        .collect::<Result<Vec<_>>>()?,
                )
                .tool_choice("auto")
                .parallel_tool_calls(true);
        }

        let request = openai_request
            .build()
            .map_err(|e| ChatCompletionError::LLM(Box::new(e)))?;

        tracing::debug!(
            model = &model,
            request = serde_json::to_string_pretty(&request).expect("infallible"),
            "Sending request to Ollama"
        );

        let response = self
            .client
            .chat()
            .create(request)
            .await
            .map_err(|e| ChatCompletionError::LLM(Box::new(e)))?;

        tracing::debug!(
            response = serde_json::to_string_pretty(&response).expect("infallible"),
            "Received response from Ollama"
        );

        ChatCompletionResponse::builder()
            .maybe_message(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.content.clone()),
            )
            .maybe_tool_calls(
                response
                    .choices
                    .first()
                    .and_then(|choice| choice.message.tool_calls.clone())
                    .map(|tool_calls| {
                        tool_calls
                            .iter()
                            .map(|tool_call| {
                                ToolCall::builder()
                                    .id(tool_call.id.clone())
                                    .args(tool_call.function.arguments.clone())
                                    .name(tool_call.function.name.clone())
                                    .build()
                                    .expect("infallible")
                            })
                            .collect_vec()
                    }),
            )
            .build()
            .map_err(ChatCompletionError::from)
    }
}

// TODO: Maybe just into the whole thing? Types are not in this crate

fn tools_to_openai(spec: &ToolSpec) -> Result<ChatCompletionTool> {
    let mut properties = serde_json::Map::new();

    for param in &spec.parameters {
        properties.insert(
            param.name.to_string(),
            json!({
                "type": "string",
                "description": param.description,
            }),
        );
    }

    ChatCompletionToolArgs::default()
        .r#type(ChatCompletionToolType::Function)
        .function(
            FunctionObjectArgs::default()
                .name(spec.name)
                .description(spec.description)
                .parameters(json!({
                    "type": "object",
                    "properties": properties,
                    "required": spec
                        .parameters
                        .iter()
                        .filter(|param| param.required)
                        .map(|param| param.name)
                        .collect_vec(),
                    "additionalProperties": false,
                }))
                .build()?,
        )
        .build()
        .map_err(anyhow::Error::from)
}

fn message_to_openai(
    message: &ChatMessage,
) -> Result<async_openai::types::ChatCompletionRequestMessage> {
    let openai_message = match message {
        ChatMessage::User(msg) => ChatCompletionRequestUserMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::System(msg) => ChatCompletionRequestSystemMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::Summary(msg) => ChatCompletionRequestAssistantMessageArgs::default()
            .content(msg.as_str())
            .build()?
            .into(),
        ChatMessage::ToolOutput(tool_call, tool_output) => {
            let Some(content) = tool_output.content() else {
                return Ok(ChatCompletionRequestToolMessageArgs::default()
                    .tool_call_id(tool_call.id())
                    .build()?
                    .into());
            };

            ChatCompletionRequestToolMessageArgs::default()
                .content(content)
                .tool_call_id(tool_call.id())
                .build()?
                .into()
        }
        ChatMessage::Assistant(msg, tool_calls) => {
            let mut builder = ChatCompletionRequestAssistantMessageArgs::default();

            if let Some(msg) = msg {
                builder.content(msg.as_str());
            }

            if let Some(tool_calls) = tool_calls {
                builder.tool_calls(
                    tool_calls
                        .iter()
                        .map(|tool_call| ChatCompletionMessageToolCall {
                            id: tool_call.id().to_string(),
                            r#type: ChatCompletionToolType::Function,
                            function: FunctionCall {
                                name: tool_call.name().to_string(),
                                arguments: tool_call.args().unwrap_or_default().to_string(),
                            },
                        })
                        .collect::<Vec<_>>(),
                );
            }

            builder.build()?.into()
        }
    };

    Ok(openai_message)
}
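For reference, a sketch of the parameters schema that `tools_to_openai` emits; note that every parameter is currently typed as `"string"` regardless of its actual kind. The tool fields below (`query`, its description) are hypothetical:

```rust
use serde_json::json;

// For a hypothetical ToolSpec with one required parameter named "query",
// the json! block in tools_to_openai produces this function schema:
let expected_schema = json!({
    "type": "object",
    "properties": {
        "query": {
            "type": "string",
            "description": "The search query",
        },
    },
    // Only parameters flagged `required` end up in this list.
    "required": ["query"],
    "additionalProperties": false,
});
```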
51 changes: 51 additions & 0 deletions swiftide-integrations/src/ollama/config.rs
@@ -0,0 +1,51 @@
use reqwest::header::HeaderMap;
use secrecy::Secret;
use serde::Deserialize;

const OLLAMA_API_BASE: &str = "http://localhost:11434/v1";

#[derive(Clone, Debug, Deserialize)]
#[serde(default)]
pub struct OllamaConfig {
    api_base: String,
    api_key: Secret<String>,
}

impl OllamaConfig {
    pub fn with_api_base(&mut self, api_base: &str) -> &mut Self {
        self.api_base = api_base.to_string();

        self
    }
}

impl Default for OllamaConfig {
    fn default() -> Self {
        Self {
            api_base: OLLAMA_API_BASE.to_string(),
            api_key: String::new().into(),
        }
    }
}

impl async_openai::config::Config for OllamaConfig {
    fn headers(&self) -> HeaderMap {
        HeaderMap::new()
    }

    fn url(&self, path: &str) -> String {
        format!("{}{}", self.api_base, path)
    }

    fn api_base(&self) -> &str {
        &self.api_base
    }

    fn api_key(&self) -> &Secret<String> {
        &self.api_key
    }

    fn query(&self) -> Vec<(&str, &str)> {
        vec![]
    }
}
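A brief sketch of pointing the config at a non-default Ollama host: `OLLAMA_API_BASE` defaults to the local server's OpenAI-compatible `/v1` endpoint, and `with_api_base` overrides it. The remote hostname below is hypothetical:

```rust
use async_openai::config::Config as _;

// Start from the default (http://localhost:11434/v1) and point the
// client at a remote Ollama instance instead (hypothetical host).
let mut config = OllamaConfig::default();
config.with_api_base("http://ollama.internal:11434/v1");

// The async-openai Config impl builds request URLs by appending the
// path to api_base, so all chat/embedding calls hit the new host.
assert_eq!(
    config.url("/chat/completions"),
    "http://ollama.internal:11434/v1/chat/completions"
);
```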