diff --git a/Directory.Packages.props b/Directory.Packages.props index fbac589..c72f582 100644 --- a/Directory.Packages.props +++ b/Directory.Packages.props @@ -3,21 +3,21 @@ true - + - - + + - - - - - + + + + + diff --git a/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs b/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs index b959f16..6ca2a46 100644 --- a/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs +++ b/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs @@ -15,7 +15,6 @@ using Microsoft.KernelMemory.Prompts; using Microsoft.SemanticKernel; using Spectre.Console; -using static KernelMemory.Extensions.QueryPipeline.SemanticKernelQueryRewriter; namespace SemanticMemory.Samples; @@ -82,11 +81,16 @@ public async Task RunSample2() .Title("Select the query executor to use") .AddChoices(["KernelMemory Default", "Cohere CommandR+"])); + var queryRewriterTool = AnsiConsole.Prompt(new SelectionPrompt() + .Title("Select query rewriter") + .AddChoices(["Semantic Kernel Base", "Semantic Kernel Handlebar"])); + var kernelBuider = CreateBasicKernelBuilder(); var builder = CreateBasicKernelMemoryBuilder( services, storageToUse == "elasticsearch", - queryExecutorToUse == "Cohere CommandR+"); + queryExecutorToUse == "Cohere CommandR+", + queryRewriterTool == "Semantic Kernel Handlebar"); var kernelMemory = builder.Build(); var kernel = kernelBuider.Build(); @@ -221,7 +225,8 @@ private static async Task IndexDocument(MemoryServerless kernelMemory, string do private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder( ServiceCollection services, bool useElasticSearch, - bool useCohereCommandRPlusForQueryExecutor) + bool useCohereCommandRPlusForQueryExecutor, + bool useHandlebarQueryRewriter) { // we need a series of services to use Kernel Memory, the first one is // an embedding service that will be used to create dense vector for @@ -283,6 +288,7 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder( services.AddSingleton(kernelMemoryBuilder); services.AddSingleton(); + services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); services.AddSingleton(); @@ -315,7 +321,15 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder( } config.SetReRanker(); - config.SetQueryRewriter(); + + if (useHandlebarQueryRewriter) + { + config.SetQueryRewriter(); + } + else + { + config.SetQueryRewriter(); + } }); return kernelMemoryBuilder; } diff --git a/src/KernelMemory.Extensions/KernelMemory.Extensions.csproj b/src/KernelMemory.Extensions/KernelMemory.Extensions.csproj index 669e3bf..a831858 100644 --- a/src/KernelMemory.Extensions/KernelMemory.Extensions.csproj +++ b/src/KernelMemory.Extensions/KernelMemory.Extensions.csproj @@ -67,4 +67,8 @@ + + + + diff --git a/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs b/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs index a507cf3..3a365e0 100644 --- a/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs +++ b/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs @@ -1,5 +1,7 @@ using Microsoft.SemanticKernel; using Microsoft.SemanticKernel.ChatCompletion; +using Microsoft.SemanticKernel.Connectors.OpenAI; +using Microsoft.SemanticKernel.PromptTemplates.Handlebars; using System.Threading.Tasks; namespace KernelMemory.Extensions.QueryPipeline; @@ -66,14 +68,80 @@ be a standalone question that contains also the previous context. If there is no return result?.ToString() ?? question; } +} + +/// +/// Allows some parametrization of the rewriter. +/// +public class SemanticKernelQueryRewriterOptions +{ + public string? ModelId { get; set; } + + public float Temperature { get; set; } = 0.0f; +} - /// - /// Allows some parametrization of the rewriter. - /// - public class SemanticKernelQueryRewriterOptions +public class HandlebarSemanticKernelQueryRewriter : IConversationQueryRewriter +{ + private readonly SemanticKernelQueryRewriterOptions _semanticKernelQueryRewriterOptions; + private readonly Kernel _kernel; + private readonly KernelFunction _chatFunction; + + public HandlebarSemanticKernelQueryRewriter( + SemanticKernelQueryRewriterOptions semanticKernelQueryRewriterOptions, + Kernel kernel) { - public string? ModelId { get; set; } + _semanticKernelQueryRewriterOptions = semanticKernelQueryRewriterOptions; + _kernel = kernel; + + // Create a template for chat with settings + _chatFunction = kernel.CreateFunctionFromPrompt(new PromptTemplateConfig() + { + Name = "TestRewrite", + Description = "Rewrite a query for kernel memory.", + Template = @"system: +* Given the following conversation history and the users next question,rephrase the question to be a stand alone question. +If the conversation is irrelevant or empty, just restate the original question. +Do not add more details than necessary to the question. + +chat history: +{{#each history}} +question: +{{question}} +answer: +{{answer}} +{{/each}} + +Follow up Input: {{ chat_input }} +Standalone Question:", + TemplateFormat = "handlebars", + InputVariables = + [ + new() { Name = "chat_input", Description = "New question of the user", IsRequired = false, Default = "" }, + new() { Name = "history", Description = "The history of the RAG CHAT.", IsRequired = true } + ], + ExecutionSettings = + { + { "default", new OpenAIPromptExecutionSettings() + { + MaxTokens = 1000, + Temperature = 0, + ModelId = "gpt35", + } + }, + } + }, + promptTemplateFactory: new HandlebarsPromptTemplateFactory()); + } - public float Temperature { get; set; } = 0.0f; + public async Task RewriteAsync(Conversation conversation, string question) + { + KernelArguments ka = new(); + ka["chat_input"] = question; + + ka["history"] = conversation.GetQuestions(); + + var result = await _kernel.InvokeAsync(_chatFunction, ka); + + return result?.ToString() ?? question; } } diff --git a/src/KernelMemory.Extensions/QueryPipeline/OriginalKernelMemorySearchClient.cs b/src/KernelMemory.Extensions/QueryPipeline/OriginalKernelMemorySearchClient.cs deleted file mode 100644 index cdb4f40..0000000 --- a/src/KernelMemory.Extensions/QueryPipeline/OriginalKernelMemorySearchClient.cs +++ /dev/null @@ -1,370 +0,0 @@ -using Microsoft.Extensions.Logging; -using Microsoft.KernelMemory; -using Microsoft.KernelMemory.AI; -using Microsoft.KernelMemory.Diagnostics; -using Microsoft.KernelMemory.MemoryStorage; -using Microsoft.KernelMemory.Prompts; -using Microsoft.KernelMemory.Search; -using System; -using System.Collections.Generic; -using System.Diagnostics; -using System.Linq; -using System.Text; -using System.Threading; -using System.Threading.Tasks; - -namespace KernelMemory.Extensions -{ - public class OriginalKernelMemorySearchClient : ISearchClient - { - private readonly IMemoryDb _memoryDb; - private readonly ITextGenerator _textGenerator; - private readonly SearchClientConfig _config; - private readonly ILogger _log; - private readonly string _answerPrompt; - - public OriginalKernelMemorySearchClient( - IMemoryDb memoryDb, - ITextGenerator textGenerator, - SearchClientConfig? config = null, - IPromptProvider? promptProvider = null, - ILogger? log = null) - { - this._memoryDb = memoryDb; - this._textGenerator = textGenerator; - this._config = config ?? new SearchClientConfig(); - this._config.Validate(); - -#pragma warning disable KMEXP00 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - promptProvider ??= new EmbeddedPromptProvider(); -#pragma warning restore KMEXP00 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed. - this._answerPrompt = promptProvider.ReadPrompt(Constants.PromptNamesAnswerWithFacts); - - this._log = log ?? DefaultLogger.Instance; - - if (this._memoryDb == null) - { - throw new KernelMemoryException("Search memory DB not configured"); - } - - if (this._textGenerator == null) - { - throw new KernelMemoryException("Text generator not configured"); - } - } - - /// - public Task> ListIndexesAsync(CancellationToken cancellationToken = default) - { - return this._memoryDb.GetIndexesAsync(cancellationToken); - } - - /// - public async Task SearchAsync( - string index, - string query, - ICollection? filters = null, - double minRelevance = 0, - int limit = -1, - CancellationToken cancellationToken = default) - { - if (limit <= 0) { limit = this._config.MaxMatchesCount; } - - var result = new SearchResult - { - Query = query, - Results = new List() - }; - - if (string.IsNullOrWhiteSpace(query) && (filters == null || filters.Count == 0)) - { - this._log.LogWarning("No query or filters provided"); - return result; - } - - var list = new List<(MemoryRecord memory, double relevance)>(); - if (!string.IsNullOrEmpty(query)) - { - this._log.LogTrace("Fetching relevant memories by similarity, min relevance {0}", minRelevance); - IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync( - index: index, - text: query, - filters: filters, - minRelevance: minRelevance, - limit: limit, - withEmbeddings: false, - cancellationToken: cancellationToken); - - // Memories are sorted by relevance, starting from the most relevant - await foreach ((MemoryRecord memory, double relevance) in matches.ConfigureAwait(false)) - { - list.Add((memory, relevance)); - } - } - else - { - this._log.LogTrace("Fetching relevant memories by filtering"); - IAsyncEnumerable matches = this._memoryDb.GetListAsync( - index: index, - filters: filters, - limit: limit, - withEmbeddings: false, - cancellationToken: cancellationToken); - - await foreach (MemoryRecord memory in matches.ConfigureAwait(false)) - { - list.Add((memory, float.MinValue)); - } - } - - // Memories are sorted by relevance, starting from the most relevant - foreach ((MemoryRecord memory, double relevance) in list) - { - // Note: a document can be composed by multiple files - string documentId = memory.GetDocumentId(this._log); - - // Identify the file in case there are multiple files - string fileId = memory.GetFileId(this._log); - - // TODO: URL to access the file in content storage - string linkToFile = $"{index}/{documentId}/{fileId}"; - - var partitionText = memory.GetPartitionText(this._log).Trim(); - if (string.IsNullOrEmpty(partitionText)) - { - this._log.LogError("The document partition is empty, doc: {0}", memory.Id); - continue; - } - - if (relevance > float.MinValue) { this._log.LogTrace("Adding result with relevance {0}", relevance); } - - // If the file is already in the list of citations, only add the partition - var citation = result.Results.FirstOrDefault(x => x.Link == linkToFile); - if (citation == null) - { - citation = new Citation(); - result.Results.Add(citation); - } - - // Add the partition to the list of citations - citation.Index = index; - citation.DocumentId = documentId; - citation.FileId = fileId; - citation.Link = linkToFile; - citation.SourceContentType = memory.GetFileContentType(this._log); - citation.SourceName = memory.GetFileName(this._log); - citation.SourceUrl = memory.GetWebPageUrl(index); - - citation.Partitions.Add(new Citation.Partition - { - Text = partitionText, - Relevance = (float)relevance, - PartitionNumber = memory.GetPartitionNumber(this._log), - SectionNumber = memory.GetSectionNumber(), - LastUpdate = memory.GetLastUpdate(), - Tags = memory.Tags, - }); - } - - if (result.Results.Count == 0) - { - this._log.LogDebug("No memories found"); - } - - return result; - } - - /// - public async Task AskAsync( - string index, - string question, - ICollection? filters = null, - double minRelevance = 0, - CancellationToken cancellationToken = default) - { - var noAnswerFound = new MemoryAnswer - { - Question = question, - NoResult = true, - Result = this._config.EmptyAnswer, - }; - - if (string.IsNullOrEmpty(question)) - { - this._log.LogWarning("No question provided"); - noAnswerFound.NoResultReason = "No question provided"; - return noAnswerFound; - } - - var facts = new StringBuilder(); - var maxTokens = this._config.MaxAskPromptSize > 0 - ? this._config.MaxAskPromptSize - : this._textGenerator.MaxTokenTotal; - var tokensAvailable = maxTokens - - this._textGenerator.CountTokens(this._answerPrompt) - - this._textGenerator.CountTokens(question) - - this._config.AnswerTokens; - - var factsUsedCount = 0; - var factsAvailableCount = 0; - var answer = noAnswerFound; - - this._log.LogTrace("Fetching relevant memories"); - IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync( - index: index, - text: question, - filters: filters, - minRelevance: minRelevance, - limit: this._config.MaxMatchesCount, - withEmbeddings: false, - cancellationToken: cancellationToken); - - // Memories are sorted by relevance, starting from the most relevant - await foreach ((MemoryRecord memory, double relevance) in matches.ConfigureAwait(false)) - { - // Note: a document can be composed by multiple files - string documentId = memory.GetDocumentId(this._log); - - // Identify the file in case there are multiple files - string fileId = memory.GetFileId(this._log); - - // TODO: URL to access the file in content storage - string linkToFile = $"{index}/{documentId}/{fileId}"; - - string fileName = memory.GetFileName(this._log); - - var partitionText = memory.GetPartitionText(this._log).Trim(); - if (string.IsNullOrEmpty(partitionText)) - { - this._log.LogError("The document partition is empty, doc: {0}", memory.Id); - continue; - } - - factsAvailableCount++; - - // TODO: add file age in days, to push relevance of newer documents - var fact = $"==== [File:{fileName};Relevance:{relevance:P1}]:\n{partitionText}\n"; - - // Use the partition/chunk only if there's room for it - var size = this._textGenerator.CountTokens(fact); - if (size >= tokensAvailable) - { - // Stop after reaching the max number of tokens - break; - } - - factsUsedCount++; - if (relevance > float.MinValue) { this._log.LogTrace("Adding text {0} with relevance {1}", factsUsedCount, relevance); } - - facts.Append(fact); - tokensAvailable -= size; - - // If the file is already in the list of citations, only add the partition - var citation = answer.RelevantSources.FirstOrDefault(x => x.Link == linkToFile); - if (citation == null) - { - citation = new Citation(); - answer.RelevantSources.Add(citation); - } - - // Add the partition to the list of citations - citation.Index = index; - citation.DocumentId = documentId; - citation.FileId = fileId; - citation.Link = linkToFile; - citation.SourceContentType = memory.GetFileContentType(this._log); - citation.SourceName = fileName; - citation.SourceUrl = memory.GetWebPageUrl(index); - - citation.Partitions.Add(new Citation.Partition - { - Text = partitionText, - Relevance = (float)relevance, - LastUpdate = memory.GetLastUpdate(), - Tags = memory.Tags, - }); - } - - if (factsAvailableCount > 0 && factsUsedCount == 0) - { - this._log.LogError("Unable to inject memories in the prompt, not enough tokens available"); - noAnswerFound.NoResultReason = "Unable to use memories"; - return noAnswerFound; - } - - if (factsUsedCount == 0) - { - this._log.LogWarning("No memories available"); - noAnswerFound.NoResultReason = "No memories available"; - return noAnswerFound; - } - - var text = new StringBuilder(); - var charsGenerated = 0; - var watch = new Stopwatch(); - watch.Restart(); - await foreach (var x in this.GenerateAnswerAsync(question, facts.ToString()) - .WithCancellation(cancellationToken).ConfigureAwait(false)) - { - text.Append(x); - - if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30) - { - charsGenerated = text.Length; - this._log.LogTrace("{0} chars generated", charsGenerated); - } - } - - watch.Stop(); - this._log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds); - - answer.Result = text.ToString(); - answer.NoResult = ValueIsEquivalentTo(answer.Result, this._config.EmptyAnswer); - if (answer.NoResult) - { - answer.NoResultReason = "No relevant memories found"; - } - - return answer; - } - - private IAsyncEnumerable GenerateAnswerAsync(string question, string facts) - { - var prompt = this._answerPrompt; - prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase); - - question = question.Trim(); - question = question.EndsWith('?') ? question : $"{question}?"; - prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase); - - prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase); - - var options = new TextGenerationOptions - { - Temperature = this._config.Temperature, - TopP = this._config.TopP, - PresencePenalty = this._config.PresencePenalty, - FrequencyPenalty = this._config.FrequencyPenalty, - MaxTokens = this._config.AnswerTokens, - StopSequences = this._config.StopSequences, - TokenSelectionBiases = this._config.TokenSelectionBiases, - }; - - if (this._log.IsEnabled(LogLevel.Debug)) - { - this._log.LogDebug("Running RAG prompt, size: {0} tokens, requesting max {1} tokens", - this._textGenerator.CountTokens(prompt), - this._config.AnswerTokens); - } - - return this._textGenerator.GenerateTextAsync(prompt, options); - } - - private static bool ValueIsEquivalentTo(string value, string target) - { - value = value.Trim().Trim('.', '"', '\'', '`', '~', '!', '?', '@', '#', '$', '%', '^', '+', '*', '_', '-', '=', '|', '\\', '/', '(', ')', '[', ']', '{', '}', '<', '>'); - target = target.Trim().Trim('.', '"', '\'', '`', '~', '!', '?', '@', '#', '$', '%', '^', '+', '*', '_', '-', '=', '|', '\\', '/', '(', ')', '[', ']', '{', '}', '<', '>'); - return string.Equals(value, target, StringComparison.OrdinalIgnoreCase); - } - } -} diff --git a/src/KernelMemory.Extensions/QueryPipeline/StandardRagQueryExecutor.cs b/src/KernelMemory.Extensions/QueryPipeline/StandardRagQueryExecutor.cs index e614bf8..b691bc1 100644 --- a/src/KernelMemory.Extensions/QueryPipeline/StandardRagQueryExecutor.cs +++ b/src/KernelMemory.Extensions/QueryPipeline/StandardRagQueryExecutor.cs @@ -150,7 +150,6 @@ private IAsyncEnumerable GenerateAnswerAsync(string question, string fac var options = new TextGenerationOptions { Temperature = this._config.Temperature, - TopP = this._config.TopP, PresencePenalty = this._config.PresencePenalty, FrequencyPenalty = this._config.FrequencyPenalty, MaxTokens = this._config.AnswerTokens,