From d0d7d3e89200f11c3682543c04aeaef270fcd0de Mon Sep 17 00:00:00 2001 From: Gian Maria Ricci Date: Wed, 8 Jan 2025 18:01:04 +0100 Subject: [PATCH] Added context to cohere. --- .../Samples/CustomSearchPipelineBase.cs | 10 +++ .../Cohere/CohereCommandRQueryExecutor.cs | 4 ++ .../Cohere/CohereConfiguration.cs | 2 + .../Cohere/RawCohereChatClient.cs | 54 +++++++++++++--- .../Cohere/RawCohereClientDtos.cs | 64 ++++++++++++++++++- .../Helper/LLMCallLog.cs | 12 ++++ .../Helper/SemanticKernelWrapper.cs | 1 + 7 files changed, 138 insertions(+), 9 deletions(-) diff --git a/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs b/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs index f08cdbf..5d7d69f 100644 --- a/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs +++ b/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs @@ -221,6 +221,15 @@ public async Task RunSample2() call.TokenCount.CachedTokenRead, call.TokenCount.CachedTokenWrite); } + + if (call.Warnings.Count > 0) + { + Console.WriteLine("Warnings:"); + foreach (var warning in call.Warnings) + { + Console.WriteLine(warning); + } + } } } else @@ -351,6 +360,7 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder( services.AddSingleton(kernelMemoryBuilder); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); diff --git a/src/KernelMemory.Extensions/Cohere/CohereCommandRQueryExecutor.cs b/src/KernelMemory.Extensions/Cohere/CohereCommandRQueryExecutor.cs index 8bc8cfc..3b55b86 100644 --- a/src/KernelMemory.Extensions/Cohere/CohereCommandRQueryExecutor.cs +++ b/src/KernelMemory.Extensions/Cohere/CohereCommandRQueryExecutor.cs @@ -2,6 +2,7 @@ using Microsoft.Extensions.Logging; using Microsoft.KernelMemory.Diagnostics; using Microsoft.KernelMemory.MemoryStorage; +using Polly.Fallback; using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; @@ -47,15 +48,18 @@ public class CohereCommandRQueryExecutor : BasicAsyncQueryHandlerWithProgress private readonly RawCohereClient _rawCohereClient; private readonly CohereCommandRQueryExecutorConfiguration _config; + private readonly CohereTokenizer _cohereTokenizer; private readonly ILogger _log; public CohereCommandRQueryExecutor( RawCohereClient rawCohereClient, CohereCommandRQueryExecutorConfiguration config, + CohereTokenizer cohereTokenizer, ILogger? log = null) { _rawCohereClient = rawCohereClient; _config = config; + _cohereTokenizer = cohereTokenizer; _log = log ?? DefaultLogger.Instance; } diff --git a/src/KernelMemory.Extensions/Cohere/CohereConfiguration.cs b/src/KernelMemory.Extensions/Cohere/CohereConfiguration.cs index 995f520..45b7cc9 100644 --- a/src/KernelMemory.Extensions/Cohere/CohereConfiguration.cs +++ b/src/KernelMemory.Extensions/Cohere/CohereConfiguration.cs @@ -111,6 +111,8 @@ public static IServiceCollection ConfigureCohereChat( BaseUrl = baseUrl, }); + services.AddSingleton(); + return services; } diff --git a/src/KernelMemory.Extensions/Cohere/RawCohereChatClient.cs b/src/KernelMemory.Extensions/Cohere/RawCohereChatClient.cs index a644345..c048902 100644 --- a/src/KernelMemory.Extensions/Cohere/RawCohereChatClient.cs +++ b/src/KernelMemory.Extensions/Cohere/RawCohereChatClient.cs @@ -10,13 +10,15 @@ using System.Threading.Tasks; using KernelMemory.Extensions.Helper; using Microsoft.Extensions.Logging; +using Microsoft.KernelMemory.Context; using Microsoft.KernelMemory.Diagnostics; namespace KernelMemory.Extensions.Cohere; public class RawCohereChatClient { - private readonly HttpClient _httpClient; + private readonly HttpClient _httpClient; + private readonly IContextProvider _contextProvider; private readonly ILogger _log; private readonly string _apiKey; private readonly string _baseUrl; @@ -24,6 +26,7 @@ public class RawCohereChatClient public RawCohereChatClient( CohereChatConfiguration config, HttpClient httpClient, + IContextProvider contextProvider, ILogger? log = null) { if (String.IsNullOrEmpty(config.ApiKey)) @@ -31,7 +34,8 @@ public RawCohereChatClient( throw new ArgumentException("ApiKey is required", nameof(config.ApiKey)); } - this._httpClient = httpClient; + _httpClient = httpClient; + _contextProvider = contextProvider; _log = log ?? DefaultLogger.Instance; _apiKey = config.ApiKey; _baseUrl = config.BaseUrl; @@ -91,10 +95,9 @@ public async IAsyncEnumerable RagQueryStreamingAsync CohereRagRequest cohereRagRequest, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - if (cohereRagRequest is null) - { - throw new ArgumentNullException(nameof(cohereRagRequest)); - } + ArgumentNullException.ThrowIfNull(cohereRagRequest); + + var context = _contextProvider.GetContext(); var client = _httpClient; //force streaming @@ -130,7 +133,7 @@ public async IAsyncEnumerable RagQueryStreamingAsync string line = (await reader.ReadLineAsync(cancellationToken))!; var data = JsonSerializer.Deserialize(line)!; - if (data.EventType == "stream-start" || data.EventType == "stream-end" || data.EventType == "search-results") + if (data.EventType == "stream-start" || data.EventType == "search-results") { //not interested in this events continue; @@ -152,6 +155,11 @@ public async IAsyncEnumerable RagQueryStreamingAsync ResponseType = CohereRagResponseType.Citations }; } + else if (data.EventType == "stream-end") + { + //create log + AddLog(context, "CommandR+RAG", cohereRagRequest.Describe(), data); + } else { //not supported. @@ -159,5 +167,35 @@ public async IAsyncEnumerable RagQueryStreamingAsync } } } - } + } + + private void AddLog( + IContext context, + string name, + string input, + ChatStreamEvent data) + { + LLMCallLog callLog = new() + { + CallName = name, + ReturnObject = data, + InputPrompt = input, + Output = data.Response.Text, + TokenCount = new TokenCount() + { + InputTokens = data.Response?.Meta.Tokens.InputTokens ?? 0, + OutputTokens = data.Response?.Meta.Tokens.OutputTokens ?? 0, + } + }; + + if (data.Response?.Meta.Warnings?.Length > 0) + { + foreach (var warning in data.Response.Meta.Warnings) + { + callLog.AddWarning(warning); + } + } + + context.AddCallLog(callLog); + } } diff --git a/src/KernelMemory.Extensions/Cohere/RawCohereClientDtos.cs b/src/KernelMemory.Extensions/Cohere/RawCohereClientDtos.cs index 986b72a..9fcc40a 100644 --- a/src/KernelMemory.Extensions/Cohere/RawCohereClientDtos.cs +++ b/src/KernelMemory.Extensions/Cohere/RawCohereClientDtos.cs @@ -1,5 +1,8 @@ using Microsoft.KernelMemory.MemoryStorage; +using System; using System.Collections.Generic; +using System.Linq; +using System.Text; using System.Text.Json.Serialization; namespace KernelMemory.Extensions.Cohere; @@ -43,16 +46,51 @@ public static CohereRagRequest CreateFromMemoryRecord(string question, IEnumerab foreach (var memory in memoryRecords) { + //if the text is more than 300 words we need to split it + var text = memory.GetPartitionText(); + int start = 0; + int spaceCount = 0; + for (int i = 0; i < text.Length; i++) + { + if (text[i] == ' ') + { + spaceCount++; + } + if (spaceCount > 250) + { + ragRequest.Documents.Add(new RagDocument() + { + DocId = memory.Id, + Text = text[start..i] + }); + start = i; + spaceCount = 0; + } + } + ragRequest.Documents.Add(new RagDocument() { DocId = memory.Id, - Text = memory.GetPartitionText() + Text = text[start..text.Length] }); } return ragRequest; } + internal string Describe() + { + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.AppendLine($"Message: {Message}"); + stringBuilder.AppendLine($"Model: {Model}"); + stringBuilder.AppendLine($"Document count: {Documents.Count}"); + stringBuilder.AppendLine($"Temperature: {Temperature}"); + stringBuilder.AppendLine($"Stream: {Stream}"); + stringBuilder.AppendLine($"\n\nFullDocuments\n{string.Join("\n", Documents.Select(d => d.Text))}"); + + return stringBuilder.ToString(); + } + [JsonPropertyName("message")] public string Message { get; set; } @@ -321,6 +359,30 @@ public class ChatStreamEvent [JsonPropertyName("citations")] public List Citations { get; set; } + + [JsonPropertyName("response")] + public ChatStreamingResponse Response { get; set; } +} + +public class ChatStreamingResponse +{ + [JsonPropertyName("response_id")] + public string ResponseId { get; set; } + + [JsonPropertyName("text")] + public string Text { get; set; } + + [JsonPropertyName("generation_id")] + public string GenerationId { get; set; } + + [JsonPropertyName("chat_history")] + public List ChatHistory { get; set; } + + [JsonPropertyName("finish_reason")] + public string FinishReason { get; set; } + + [JsonPropertyName("meta")] + public Meta Meta { get; set; } } public class CohereRagCitation diff --git a/src/KernelMemory.Extensions/Helper/LLMCallLog.cs b/src/KernelMemory.Extensions/Helper/LLMCallLog.cs index 0356f66..ec76a9f 100644 --- a/src/KernelMemory.Extensions/Helper/LLMCallLog.cs +++ b/src/KernelMemory.Extensions/Helper/LLMCallLog.cs @@ -2,6 +2,7 @@ using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.Connectors.OpenAI; using OpenAI.Chat; +using System; using System.Collections.Generic; using System.Linq; @@ -20,6 +21,10 @@ public class LLMCallLog public object? ReturnObject { get; set; } + public IReadOnlyList Warnings => _warnings; + + private readonly List _warnings = new(); + public TokenCount TokenCount { get; set; } = null!; public void AddOpenaiChatMessageContent(OpenAIChatMessageContent mc) @@ -50,6 +55,11 @@ public void AddOpenaiChatMessageContent(OpenAIChatMessageContent mc) }; } } + + public void AddWarning(string warning) + { + _warnings.Add(warning); + } } public class TokenCount @@ -66,6 +76,8 @@ public class TokenCount /// public class LLMCallLogContext { + public Guid Id { get; private set; } = Guid.NewGuid(); + public IReadOnlyList CallLogs => _callLogs; private readonly List _callLogs = new(); diff --git a/src/KernelMemory.Extensions/Helper/SemanticKernelWrapper.cs b/src/KernelMemory.Extensions/Helper/SemanticKernelWrapper.cs index 7bfd3a0..27f4d50 100644 --- a/src/KernelMemory.Extensions/Helper/SemanticKernelWrapper.cs +++ b/src/KernelMemory.Extensions/Helper/SemanticKernelWrapper.cs @@ -18,6 +18,7 @@ namespace KernelMemory.Extensions.Helper; public interface ISemanticKernelWrapper { KernelFunction CreateFunctionFromMethod(Delegate method, string functionName); + KernelPlugin CreateFromFunctions(string pluginName, IEnumerable functions); KernelFunction CreateFunctionFromPrompt(PromptTemplateConfig config, IPromptTemplateFactory? promptTemplateFactory = null);