Skip to content

Commit

Permalink
Add base moderation to Providers (#513)
Browse files Browse the repository at this point in the history
  • Loading branch information
zakiali authored Jan 10, 2025
1 parent 2b56d48 commit 7b827d0
Show file tree
Hide file tree
Showing 10 changed files with 720 additions and 68 deletions.
14 changes: 12 additions & 2 deletions crates/goose-server/src/routes/reply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ mod tests {
use goose::{
agents::DefaultAgent as Agent,
providers::{
base::{Provider, ProviderUsage, Usage},
base::{Moderation, ModerationResult, Provider, ProviderUsage, Usage},
configs::{ModelConfig, OpenAiProviderConfig},
},
};
Expand All @@ -405,7 +405,7 @@ mod tests {

#[async_trait::async_trait]
impl Provider for MockProvider {
async fn complete(
async fn complete_internal(
&self,
_system_prompt: &str,
_messages: &[Message],
Expand All @@ -426,6 +426,16 @@ mod tests {
}
}

#[async_trait::async_trait]
impl Moderation for MockProvider {
async fn moderate_content(
&self,
_content: &str,
) -> Result<ModerationResult, anyhow::Error> {
Ok(ModerationResult::new(false, None, None))
}
}

#[test]
fn test_convert_messages_user_only() {
let incoming = vec![IncomingMessage {
Expand Down
16 changes: 11 additions & 5 deletions crates/goose/src/providers/anthropic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@ use serde_json::{json, Value};
use std::collections::HashSet;
use std::time::Duration;

use super::base::ProviderUsage;
use super::base::{Provider, Usage};
use super::base::{Moderation, ModerationResult, Provider, ProviderUsage, Usage};
use super::configs::{AnthropicProviderConfig, ModelConfig, ProviderModelConfig};
use super::model_pricing::cost;
use super::model_pricing::model_pricing_for;
Expand Down Expand Up @@ -205,7 +204,7 @@ impl Provider for AnthropicProvider {
cost
)
)]
async fn complete(
async fn complete_internal(
&self,
system: &str,
messages: &[Message],
Expand Down Expand Up @@ -285,6 +284,13 @@ impl Provider for AnthropicProvider {
}
}

#[async_trait]
impl Moderation for AnthropicProvider {
async fn moderate_content(&self, _content: &str) -> Result<ModerationResult> {
Ok(ModerationResult::new(false, None, None))
}
}

#[cfg(test)]
mod tests {
use crate::providers::configs::ModelConfig;
Expand Down Expand Up @@ -340,7 +346,7 @@ mod tests {
let messages = vec![Message::user().with_text("Hello?")];

let (message, usage) = provider
.complete("You are a helpful assistant.", &messages, &[])
.complete_internal("You are a helpful assistant.", &messages, &[])
.await?;

if let MessageContent::Text(text) = &message.content[0] {
Expand Down Expand Up @@ -399,7 +405,7 @@ mod tests {
);

let (message, usage) = provider
.complete("You are a helpful assistant.", &messages, &[tool])
.complete_internal("You are a helpful assistant.", &messages, &[tool])
.await?;

if let MessageContent::ToolRequest(tool_request) = &message.content[0] {
Expand Down
Loading

0 comments on commit 7b827d0

Please sign in to comment.