# Dotnet: Add modelServiceId support to SemanticKernelAgent (#5422)
The `SemanticKernelAgent` class has been updated to include an optional
`modelServiceId` parameter, allowing the specification of a service ID
for the model.

## Why are these changes needed?

Currently, `SemanticKernelAgent` resolves `IChatCompletionService` with the
parameterless overload. This fails when multiple models are registered in the
Kernel.

To support different models registered in the Kernel, I changed the resolution
of `IChatCompletionService` within `SemanticKernelAgent` to take an optional
parameter. When it is not set, I resolve the default instance; otherwise, I use
the optional parameter as a service ID for resolving the
`IChatCompletionService` service.
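
For illustration, here is a minimal sketch of the scenario this enables, assuming two Azure OpenAI deployments registered under different service IDs (the deployment names, credentials, and IDs below are placeholders, not part of this PR):

```csharp
using AutoGen.SemanticKernel;
using Microsoft.SemanticKernel;

// Placeholders: endpoint, key, deployment names, and service IDs are illustrative only.
var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT")!;
var apiKey = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY")!;

var kernel = Kernel.CreateBuilder()
    .AddAzureOpenAIChatCompletion("gpt-4o", endpoint, apiKey, "strong-model")
    .AddAzureOpenAIChatCompletion("gpt-4o-mini", endpoint, apiKey, "fast-model")
    .Build();

// Each agent resolves the IChatCompletionService registered under its own service ID;
// leaving modelServiceId unset keeps the old behavior of resolving the default service.
var strongAgent = new SemanticKernelAgent(kernel, "assistant", modelServiceId: "strong-model");
var fastAgent = new SemanticKernelAgent(kernel, "assistant", modelServiceId: "fast-model");
```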

## Related issue number



## Checks

- [x] I've included any doc changes needed for
https://microsoft.github.io/autogen/. See
https://microsoft.github.io/autogen/docs/Contribute#documentation to
build and test documentation locally.
- [x] I've added tests (if relevant) corresponding to the changes
introduced in this PR.
- [x] I've made sure all auto checks have passed.

---------

Co-authored-by: Ryan Sweet <rysweet@microsoft.com>
Co-authored-by: Xiaoyun Zhang <bigmiao.zhang@gmail.com>
3 people authored Feb 25, 2025
1 parent 1380412 commit b37c192
Showing 4 changed files with 55 additions and 19 deletions.
@@ -50,7 +50,7 @@ public static async Task RunAsync()
 
         kernel.Plugins.AddFromObject(new LightPlugin());
         var skAgent = kernel
-            .ToSemanticKernelAgent(name: "assistant", systemMessage: "You control the light", settings);
+            .ToSemanticKernelAgent(name: "assistant", systemMessage: "You control the light", settings: settings);
 
         // Send a message to the skAgent, the skAgent supports the following message types:
         // - IMessage<ChatMessageContent>

@@ -8,9 +8,9 @@ namespace AutoGen.SemanticKernel.Extension;
 
 public static class KernelExtension
 {
-    public static SemanticKernelAgent ToSemanticKernelAgent(this Kernel kernel, string name, string systemMessage = "You are a helpful AI assistant", PromptExecutionSettings? settings = null)
+    public static SemanticKernelAgent ToSemanticKernelAgent(this Kernel kernel, string name, string systemMessage = "You are a helpful AI assistant", string? modelServiceId = null, PromptExecutionSettings? settings = null)
     {
-        return new SemanticKernelAgent(kernel, name, systemMessage, settings);
+        return new SemanticKernelAgent(kernel, name, systemMessage, modelServiceId, settings);
     }
 
     /// <summary>
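
The extension method simply forwards the new parameter to the `SemanticKernelAgent` constructor. Because `modelServiceId` now sits before `settings`, call sites that used to pass `settings` positionally need the `settings:` name, as in the example change above. A hedged usage sketch (endpoint, key, deployment, and service ID are placeholders):

```csharp
using AutoGen.SemanticKernel;
using AutoGen.SemanticKernel.Extension;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Connectors.OpenAI;

// Placeholders for illustration only.
var kernel = Kernel.CreateBuilder()
    .AddAzureOpenAIChatCompletion("my-deployment", "https://example.openai.azure.com/", "api-key", "my-service-id")
    .Build();

var settings = new OpenAIPromptExecutionSettings { Temperature = 0 };

var skAgent = kernel.ToSemanticKernelAgent(
    name: "assistant",
    systemMessage: "You are a helpful AI assistant",
    modelServiceId: "my-service-id",
    settings: settings);
```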
dotnet/src/AutoGen.SemanticKernel/SemanticKernelAgent.cs (22 changes: 20 additions & 2 deletions)

@@ -32,17 +32,28 @@ public class SemanticKernelAgent : IStreamingAgent
 {
     private readonly Kernel _kernel;
     private readonly string _systemMessage;
+    private readonly string? _modelServiceId;
     private readonly PromptExecutionSettings? _settings;
 
+    /// <summary>
+    /// Create a new instance of <see cref="SemanticKernelAgent"/>
+    /// </summary>
+    /// <param name="kernel">The Semantic Kernel - Kernel object</param>
+    /// <param name="name">The name of the agent.</param>
+    /// <param name="systemMessage">The system message.</param>
+    /// <param name="modelServiceId">Optional serviceId for the model.</param>
+    /// <param name="settings">The prompt execution settings.</param>
     public SemanticKernelAgent(
         Kernel kernel,
         string name,
         string systemMessage = "You are a helpful AI assistant",
+        string? modelServiceId = null,
         PromptExecutionSettings? settings = null)
     {
         _kernel = kernel;
         this.Name = name;
         _systemMessage = systemMessage;
+        _modelServiceId = modelServiceId;
         _settings = settings;
     }
 
@@ -52,7 +63,7 @@ public async Task<IMessage> GenerateReplyAsync(IEnumerable<IMessage> messages, G
     {
         var chatHistory = BuildChatHistory(messages);
         var option = BuildOption(options);
-        var chatService = _kernel.GetRequiredService<IChatCompletionService>();
+        var chatService = GetChatCompletionService();
 
         var reply = await chatService.GetChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken);
 
@@ -71,7 +82,7 @@ public async IAsyncEnumerable<IMessage> GenerateStreamingReplyAsync(
     {
         var chatHistory = BuildChatHistory(messages);
         var option = BuildOption(options);
-        var chatService = _kernel.GetRequiredService<IChatCompletionService>();
+        var chatService = GetChatCompletionService();
         var response = chatService.GetStreamingChatMessageContentsAsync(chatHistory, option, _kernel, cancellationToken);
 
         await foreach (var content in response)
@@ -108,6 +119,13 @@ private PromptExecutionSettings BuildOption(GenerateReplyOptions? options)
         };
     }
 
+    private IChatCompletionService GetChatCompletionService()
+    {
+        return string.IsNullOrEmpty(_modelServiceId)
+            ? _kernel.GetRequiredService<IChatCompletionService>()
+            : _kernel.GetRequiredService<IChatCompletionService>(_modelServiceId);
+    }
+
     private IEnumerable<ChatMessageContent> ProcessMessage(IEnumerable<IMessage> messages)
     {
         return messages.Select(m => m switch
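
The new `GetChatCompletionService` helper is now the single place where the chat completion service is resolved: with no service ID it falls back to the kernel's default `IChatCompletionService`, otherwise it performs a keyed lookup, which only succeeds if a service was registered under the same ID. A hedged sketch of the registration side, mirroring the keyed-service test below (endpoint, key, deployment, and service ID are placeholders):

```csharp
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

// Placeholders for illustration only.
var modelServiceId = "my-service-id";

// The fourth argument registers the connector under that service ID,
// the same way the new keyed-service test wires it up.
var kernel = Kernel.CreateBuilder()
    .AddAzureOpenAIChatCompletion("my-deployment", "https://example.openai.azure.com/", "api-key", modelServiceId)
    .Build();

// The keyed lookup GetChatCompletionService performs when _modelServiceId is set.
var keyedChatService = kernel.GetRequiredService<IChatCompletionService>(modelServiceId);
```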
dotnet/test/AutoGen.SemanticKernel.Tests/SemanticKernelAgentTest.cs (46 changes: 32 additions & 14 deletions)

@@ -37,25 +37,25 @@ public async Task BasicConversationTestAsync()
             .AddAzureOpenAIChatCompletion(deploymentName, endpoint, key);
 
         var kernel = builder.Build();
-
         kernel.GetRequiredService<IChatCompletionService>();
-
         var skAgent = new SemanticKernelAgent(kernel, "assistant");
-
-        var chatMessageContent = MessageEnvelope.Create(new ChatMessageContent(AuthorRole.Assistant, "Hello"));
-        var reply = await skAgent.SendAsync(chatMessageContent);
+        await TestBasicConversationAsync(skAgent);
+    }
 
-        reply.Should().BeOfType<MessageEnvelope<ChatMessageContent>>();
-        reply.As<MessageEnvelope<ChatMessageContent>>().From.Should().Be("assistant");
+    [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]
+    public async Task BasicConversationTestWithKeyedServiceAsync()
+    {
+        var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT") ?? throw new Exception("Please set AZURE_OPENAI_ENDPOINT environment variable.");
+        var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY") ?? throw new Exception("Please set AZURE_OPENAI_API_KEY environment variable.");
+        var deploymentName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME") ?? throw new Exception("Please set AZURE_OPENAI_DEPLOY_NAME environment variable.");
+        var modelServiceId = "my-service-id";
+        var builder = Kernel.CreateBuilder()
+            .AddAzureOpenAIChatCompletion(deploymentName, endpoint, key, modelServiceId);
 
-        // test streaming
-        var streamingReply = skAgent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
+        var kernel = builder.Build();
+        var skAgent = new SemanticKernelAgent(kernel, "assistant", modelServiceId: modelServiceId);
 
-        await foreach (var streamingMessage in streamingReply)
-        {
-            streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatMessageContent>>();
-            streamingMessage.As<MessageEnvelope<StreamingChatMessageContent>>().From.Should().Be("assistant");
-        }
+        await TestBasicConversationAsync(skAgent);
     }
 
     [ApiKeyFact("AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOY_NAME")]

@@ -241,4 +241,22 @@ public async Task SkChatCompletionAgentPluginTestAsync()
         reply.GetContent()!.ToLower().Should().Contain("seattle");
         reply.GetContent()!.ToLower().Should().Contain("sunny");
     }
+
+    private static async Task TestBasicConversationAsync(SemanticKernelAgent agent)
+    {
+        var chatMessageContent = MessageEnvelope.Create(new ChatMessageContent(AuthorRole.Assistant, "Hello"));
+        var reply = await agent.SendAsync(chatMessageContent);
+
+        reply.Should().BeOfType<MessageEnvelope<ChatMessageContent>>();
+        reply.As<MessageEnvelope<ChatMessageContent>>().From.Should().Be("assistant");
+
+        // test streaming
+        var streamingReply = agent.GenerateStreamingReplyAsync(new[] { chatMessageContent });
+
+        await foreach (var streamingMessage in streamingReply)
+        {
+            streamingMessage.Should().BeOfType<MessageEnvelope<StreamingChatMessageContent>>();
+            streamingMessage.As<MessageEnvelope<StreamingChatMessageContent>>().From.Should().Be("assistant");
+        }
+    }
 }
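
Putting the pieces together, here is a hedged end-to-end sketch of the keyed path from a caller's point of view, following the shape of `BasicConversationTestWithKeyedServiceAsync` (the environment variables and service ID mirror the test; the reply handling at the end is illustrative rather than taken from this PR):

```csharp
using AutoGen.Core;
using AutoGen.SemanticKernel;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;

var endpoint = Environment.GetEnvironmentVariable("AZURE_OPENAI_ENDPOINT")!;
var key = Environment.GetEnvironmentVariable("AZURE_OPENAI_API_KEY")!;
var deploymentName = Environment.GetEnvironmentVariable("AZURE_OPENAI_DEPLOY_NAME")!;
var modelServiceId = "my-service-id";

var kernel = Kernel.CreateBuilder()
    .AddAzureOpenAIChatCompletion(deploymentName, endpoint, key, modelServiceId)
    .Build();

var skAgent = new SemanticKernelAgent(kernel, "assistant", modelServiceId: modelServiceId);

// Same message shape the tests use.
var message = MessageEnvelope.Create(new ChatMessageContent(AuthorRole.User, "Hello"));
var reply = await skAgent.SendAsync(message);

if (reply is IMessage<ChatMessageContent> envelope)
{
    Console.WriteLine(envelope.Content.Content);
}
```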
