diff --git a/Directory.Packages.props b/Directory.Packages.props
index fbac589..c72f582 100644
--- a/Directory.Packages.props
+++ b/Directory.Packages.props
@@ -3,21 +3,21 @@
true
-
+
-
-
+
+
-
-
-
-
-
+
+
+
+
+
diff --git a/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs b/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs
index b959f16..6ca2a46 100644
--- a/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs
+++ b/src/KernelMemory.Extensions.ConsoleTest/Samples/CustomSearchPipelineBase.cs
@@ -15,7 +15,6 @@
using Microsoft.KernelMemory.Prompts;
using Microsoft.SemanticKernel;
using Spectre.Console;
-using static KernelMemory.Extensions.QueryPipeline.SemanticKernelQueryRewriter;
namespace SemanticMemory.Samples;
@@ -82,11 +81,16 @@ public async Task RunSample2()
.Title("Select the query executor to use")
.AddChoices(["KernelMemory Default", "Cohere CommandR+"]));
+ var queryRewriterTool = AnsiConsole.Prompt(new SelectionPrompt()
+ .Title("Select query rewriter")
+ .AddChoices(["Semantic Kernel Base", "Semantic Kernel Handlebar"]));
+
var kernelBuider = CreateBasicKernelBuilder();
var builder = CreateBasicKernelMemoryBuilder(
services,
storageToUse == "elasticsearch",
- queryExecutorToUse == "Cohere CommandR+");
+ queryExecutorToUse == "Cohere CommandR+",
+ queryRewriterTool == "Semantic Kernel Handlebar");
var kernelMemory = builder.Build();
var kernel = kernelBuider.Build();
@@ -221,7 +225,8 @@ private static async Task IndexDocument(MemoryServerless kernelMemory, string do
private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
ServiceCollection services,
bool useElasticSearch,
- bool useCohereCommandRPlusForQueryExecutor)
+ bool useCohereCommandRPlusForQueryExecutor,
+ bool useHandlebarQueryRewriter)
{
// we need a series of services to use Kernel Memory, the first one is
// an embedding service that will be used to create dense vector for
@@ -283,6 +288,7 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
services.AddSingleton(kernelMemoryBuilder);
services.AddSingleton();
+ services.AddSingleton();
services.AddSingleton();
services.AddSingleton();
services.AddSingleton();
@@ -315,7 +321,15 @@ private static IKernelMemoryBuilder CreateBasicKernelMemoryBuilder(
}
config.SetReRanker();
- config.SetQueryRewriter();
+
+ if (useHandlebarQueryRewriter)
+ {
+ config.SetQueryRewriter();
+ }
+ else
+ {
+ config.SetQueryRewriter();
+ }
});
return kernelMemoryBuilder;
}
diff --git a/src/KernelMemory.Extensions/KernelMemory.Extensions.csproj b/src/KernelMemory.Extensions/KernelMemory.Extensions.csproj
index 669e3bf..a831858 100644
--- a/src/KernelMemory.Extensions/KernelMemory.Extensions.csproj
+++ b/src/KernelMemory.Extensions/KernelMemory.Extensions.csproj
@@ -67,4 +67,8 @@
+
+
+
+
diff --git a/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs b/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs
index a507cf3..3a365e0 100644
--- a/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs
+++ b/src/KernelMemory.Extensions/QueryPipeline/IConversationQueryRewriter.cs
@@ -1,5 +1,7 @@
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.ChatCompletion;
+using Microsoft.SemanticKernel.Connectors.OpenAI;
+using Microsoft.SemanticKernel.PromptTemplates.Handlebars;
using System.Threading.Tasks;
namespace KernelMemory.Extensions.QueryPipeline;
@@ -66,14 +68,80 @@ be a standalone question that contains also the previous context. If there is no
return result?.ToString() ?? question;
}
+}
+
+///
+/// Allows some parametrization of the rewriter.
+///
+public class SemanticKernelQueryRewriterOptions
+{
+ public string? ModelId { get; set; }
+
+ public float Temperature { get; set; } = 0.0f;
+}
- ///
- /// Allows some parametrization of the rewriter.
- ///
- public class SemanticKernelQueryRewriterOptions
+public class HandlebarSemanticKernelQueryRewriter : IConversationQueryRewriter
+{
+ private readonly SemanticKernelQueryRewriterOptions _semanticKernelQueryRewriterOptions;
+ private readonly Kernel _kernel;
+ private readonly KernelFunction _chatFunction;
+
+ public HandlebarSemanticKernelQueryRewriter(
+ SemanticKernelQueryRewriterOptions semanticKernelQueryRewriterOptions,
+ Kernel kernel)
{
- public string? ModelId { get; set; }
+ _semanticKernelQueryRewriterOptions = semanticKernelQueryRewriterOptions;
+ _kernel = kernel;
+
+ // Create a template for chat with settings
+ _chatFunction = kernel.CreateFunctionFromPrompt(new PromptTemplateConfig()
+ {
+ Name = "TestRewrite",
+ Description = "Rewrite a query for kernel memory.",
+ Template = @"system:
+* Given the following conversation history and the users next question,rephrase the question to be a stand alone question.
+If the conversation is irrelevant or empty, just restate the original question.
+Do not add more details than necessary to the question.
+
+chat history:
+{{#each history}}
+question:
+{{question}}
+answer:
+{{answer}}
+{{/each}}
+
+Follow up Input: {{ chat_input }}
+Standalone Question:",
+ TemplateFormat = "handlebars",
+ InputVariables =
+ [
+ new() { Name = "chat_input", Description = "New question of the user", IsRequired = false, Default = "" },
+ new() { Name = "history", Description = "The history of the RAG CHAT.", IsRequired = true }
+ ],
+ ExecutionSettings =
+ {
+ { "default", new OpenAIPromptExecutionSettings()
+ {
+ MaxTokens = 1000,
+ Temperature = 0,
+ ModelId = "gpt35",
+ }
+ },
+ }
+ },
+ promptTemplateFactory: new HandlebarsPromptTemplateFactory());
+ }
- public float Temperature { get; set; } = 0.0f;
+ public async Task RewriteAsync(Conversation conversation, string question)
+ {
+ KernelArguments ka = new();
+ ka["chat_input"] = question;
+
+ ka["history"] = conversation.GetQuestions();
+
+ var result = await _kernel.InvokeAsync(_chatFunction, ka);
+
+ return result?.ToString() ?? question;
}
}
diff --git a/src/KernelMemory.Extensions/QueryPipeline/OriginalKernelMemorySearchClient.cs b/src/KernelMemory.Extensions/QueryPipeline/OriginalKernelMemorySearchClient.cs
deleted file mode 100644
index cdb4f40..0000000
--- a/src/KernelMemory.Extensions/QueryPipeline/OriginalKernelMemorySearchClient.cs
+++ /dev/null
@@ -1,370 +0,0 @@
-using Microsoft.Extensions.Logging;
-using Microsoft.KernelMemory;
-using Microsoft.KernelMemory.AI;
-using Microsoft.KernelMemory.Diagnostics;
-using Microsoft.KernelMemory.MemoryStorage;
-using Microsoft.KernelMemory.Prompts;
-using Microsoft.KernelMemory.Search;
-using System;
-using System.Collections.Generic;
-using System.Diagnostics;
-using System.Linq;
-using System.Text;
-using System.Threading;
-using System.Threading.Tasks;
-
-namespace KernelMemory.Extensions
-{
- public class OriginalKernelMemorySearchClient : ISearchClient
- {
- private readonly IMemoryDb _memoryDb;
- private readonly ITextGenerator _textGenerator;
- private readonly SearchClientConfig _config;
- private readonly ILogger _log;
- private readonly string _answerPrompt;
-
- public OriginalKernelMemorySearchClient(
- IMemoryDb memoryDb,
- ITextGenerator textGenerator,
- SearchClientConfig? config = null,
- IPromptProvider? promptProvider = null,
- ILogger? log = null)
- {
- this._memoryDb = memoryDb;
- this._textGenerator = textGenerator;
- this._config = config ?? new SearchClientConfig();
- this._config.Validate();
-
-#pragma warning disable KMEXP00 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
- promptProvider ??= new EmbeddedPromptProvider();
-#pragma warning restore KMEXP00 // Type is for evaluation purposes only and is subject to change or removal in future updates. Suppress this diagnostic to proceed.
- this._answerPrompt = promptProvider.ReadPrompt(Constants.PromptNamesAnswerWithFacts);
-
- this._log = log ?? DefaultLogger.Instance;
-
- if (this._memoryDb == null)
- {
- throw new KernelMemoryException("Search memory DB not configured");
- }
-
- if (this._textGenerator == null)
- {
- throw new KernelMemoryException("Text generator not configured");
- }
- }
-
- ///
- public Task> ListIndexesAsync(CancellationToken cancellationToken = default)
- {
- return this._memoryDb.GetIndexesAsync(cancellationToken);
- }
-
- ///
- public async Task SearchAsync(
- string index,
- string query,
- ICollection? filters = null,
- double minRelevance = 0,
- int limit = -1,
- CancellationToken cancellationToken = default)
- {
- if (limit <= 0) { limit = this._config.MaxMatchesCount; }
-
- var result = new SearchResult
- {
- Query = query,
- Results = new List()
- };
-
- if (string.IsNullOrWhiteSpace(query) && (filters == null || filters.Count == 0))
- {
- this._log.LogWarning("No query or filters provided");
- return result;
- }
-
- var list = new List<(MemoryRecord memory, double relevance)>();
- if (!string.IsNullOrEmpty(query))
- {
- this._log.LogTrace("Fetching relevant memories by similarity, min relevance {0}", minRelevance);
- IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync(
- index: index,
- text: query,
- filters: filters,
- minRelevance: minRelevance,
- limit: limit,
- withEmbeddings: false,
- cancellationToken: cancellationToken);
-
- // Memories are sorted by relevance, starting from the most relevant
- await foreach ((MemoryRecord memory, double relevance) in matches.ConfigureAwait(false))
- {
- list.Add((memory, relevance));
- }
- }
- else
- {
- this._log.LogTrace("Fetching relevant memories by filtering");
- IAsyncEnumerable matches = this._memoryDb.GetListAsync(
- index: index,
- filters: filters,
- limit: limit,
- withEmbeddings: false,
- cancellationToken: cancellationToken);
-
- await foreach (MemoryRecord memory in matches.ConfigureAwait(false))
- {
- list.Add((memory, float.MinValue));
- }
- }
-
- // Memories are sorted by relevance, starting from the most relevant
- foreach ((MemoryRecord memory, double relevance) in list)
- {
- // Note: a document can be composed by multiple files
- string documentId = memory.GetDocumentId(this._log);
-
- // Identify the file in case there are multiple files
- string fileId = memory.GetFileId(this._log);
-
- // TODO: URL to access the file in content storage
- string linkToFile = $"{index}/{documentId}/{fileId}";
-
- var partitionText = memory.GetPartitionText(this._log).Trim();
- if (string.IsNullOrEmpty(partitionText))
- {
- this._log.LogError("The document partition is empty, doc: {0}", memory.Id);
- continue;
- }
-
- if (relevance > float.MinValue) { this._log.LogTrace("Adding result with relevance {0}", relevance); }
-
- // If the file is already in the list of citations, only add the partition
- var citation = result.Results.FirstOrDefault(x => x.Link == linkToFile);
- if (citation == null)
- {
- citation = new Citation();
- result.Results.Add(citation);
- }
-
- // Add the partition to the list of citations
- citation.Index = index;
- citation.DocumentId = documentId;
- citation.FileId = fileId;
- citation.Link = linkToFile;
- citation.SourceContentType = memory.GetFileContentType(this._log);
- citation.SourceName = memory.GetFileName(this._log);
- citation.SourceUrl = memory.GetWebPageUrl(index);
-
- citation.Partitions.Add(new Citation.Partition
- {
- Text = partitionText,
- Relevance = (float)relevance,
- PartitionNumber = memory.GetPartitionNumber(this._log),
- SectionNumber = memory.GetSectionNumber(),
- LastUpdate = memory.GetLastUpdate(),
- Tags = memory.Tags,
- });
- }
-
- if (result.Results.Count == 0)
- {
- this._log.LogDebug("No memories found");
- }
-
- return result;
- }
-
- ///
- public async Task AskAsync(
- string index,
- string question,
- ICollection? filters = null,
- double minRelevance = 0,
- CancellationToken cancellationToken = default)
- {
- var noAnswerFound = new MemoryAnswer
- {
- Question = question,
- NoResult = true,
- Result = this._config.EmptyAnswer,
- };
-
- if (string.IsNullOrEmpty(question))
- {
- this._log.LogWarning("No question provided");
- noAnswerFound.NoResultReason = "No question provided";
- return noAnswerFound;
- }
-
- var facts = new StringBuilder();
- var maxTokens = this._config.MaxAskPromptSize > 0
- ? this._config.MaxAskPromptSize
- : this._textGenerator.MaxTokenTotal;
- var tokensAvailable = maxTokens
- - this._textGenerator.CountTokens(this._answerPrompt)
- - this._textGenerator.CountTokens(question)
- - this._config.AnswerTokens;
-
- var factsUsedCount = 0;
- var factsAvailableCount = 0;
- var answer = noAnswerFound;
-
- this._log.LogTrace("Fetching relevant memories");
- IAsyncEnumerable<(MemoryRecord, double)> matches = this._memoryDb.GetSimilarListAsync(
- index: index,
- text: question,
- filters: filters,
- minRelevance: minRelevance,
- limit: this._config.MaxMatchesCount,
- withEmbeddings: false,
- cancellationToken: cancellationToken);
-
- // Memories are sorted by relevance, starting from the most relevant
- await foreach ((MemoryRecord memory, double relevance) in matches.ConfigureAwait(false))
- {
- // Note: a document can be composed by multiple files
- string documentId = memory.GetDocumentId(this._log);
-
- // Identify the file in case there are multiple files
- string fileId = memory.GetFileId(this._log);
-
- // TODO: URL to access the file in content storage
- string linkToFile = $"{index}/{documentId}/{fileId}";
-
- string fileName = memory.GetFileName(this._log);
-
- var partitionText = memory.GetPartitionText(this._log).Trim();
- if (string.IsNullOrEmpty(partitionText))
- {
- this._log.LogError("The document partition is empty, doc: {0}", memory.Id);
- continue;
- }
-
- factsAvailableCount++;
-
- // TODO: add file age in days, to push relevance of newer documents
- var fact = $"==== [File:{fileName};Relevance:{relevance:P1}]:\n{partitionText}\n";
-
- // Use the partition/chunk only if there's room for it
- var size = this._textGenerator.CountTokens(fact);
- if (size >= tokensAvailable)
- {
- // Stop after reaching the max number of tokens
- break;
- }
-
- factsUsedCount++;
- if (relevance > float.MinValue) { this._log.LogTrace("Adding text {0} with relevance {1}", factsUsedCount, relevance); }
-
- facts.Append(fact);
- tokensAvailable -= size;
-
- // If the file is already in the list of citations, only add the partition
- var citation = answer.RelevantSources.FirstOrDefault(x => x.Link == linkToFile);
- if (citation == null)
- {
- citation = new Citation();
- answer.RelevantSources.Add(citation);
- }
-
- // Add the partition to the list of citations
- citation.Index = index;
- citation.DocumentId = documentId;
- citation.FileId = fileId;
- citation.Link = linkToFile;
- citation.SourceContentType = memory.GetFileContentType(this._log);
- citation.SourceName = fileName;
- citation.SourceUrl = memory.GetWebPageUrl(index);
-
- citation.Partitions.Add(new Citation.Partition
- {
- Text = partitionText,
- Relevance = (float)relevance,
- LastUpdate = memory.GetLastUpdate(),
- Tags = memory.Tags,
- });
- }
-
- if (factsAvailableCount > 0 && factsUsedCount == 0)
- {
- this._log.LogError("Unable to inject memories in the prompt, not enough tokens available");
- noAnswerFound.NoResultReason = "Unable to use memories";
- return noAnswerFound;
- }
-
- if (factsUsedCount == 0)
- {
- this._log.LogWarning("No memories available");
- noAnswerFound.NoResultReason = "No memories available";
- return noAnswerFound;
- }
-
- var text = new StringBuilder();
- var charsGenerated = 0;
- var watch = new Stopwatch();
- watch.Restart();
- await foreach (var x in this.GenerateAnswerAsync(question, facts.ToString())
- .WithCancellation(cancellationToken).ConfigureAwait(false))
- {
- text.Append(x);
-
- if (this._log.IsEnabled(LogLevel.Trace) && text.Length - charsGenerated >= 30)
- {
- charsGenerated = text.Length;
- this._log.LogTrace("{0} chars generated", charsGenerated);
- }
- }
-
- watch.Stop();
- this._log.LogTrace("Answer generated in {0} msecs", watch.ElapsedMilliseconds);
-
- answer.Result = text.ToString();
- answer.NoResult = ValueIsEquivalentTo(answer.Result, this._config.EmptyAnswer);
- if (answer.NoResult)
- {
- answer.NoResultReason = "No relevant memories found";
- }
-
- return answer;
- }
-
- private IAsyncEnumerable GenerateAnswerAsync(string question, string facts)
- {
- var prompt = this._answerPrompt;
- prompt = prompt.Replace("{{$facts}}", facts.Trim(), StringComparison.OrdinalIgnoreCase);
-
- question = question.Trim();
- question = question.EndsWith('?') ? question : $"{question}?";
- prompt = prompt.Replace("{{$input}}", question, StringComparison.OrdinalIgnoreCase);
-
- prompt = prompt.Replace("{{$notFound}}", this._config.EmptyAnswer, StringComparison.OrdinalIgnoreCase);
-
- var options = new TextGenerationOptions
- {
- Temperature = this._config.Temperature,
- TopP = this._config.TopP,
- PresencePenalty = this._config.PresencePenalty,
- FrequencyPenalty = this._config.FrequencyPenalty,
- MaxTokens = this._config.AnswerTokens,
- StopSequences = this._config.StopSequences,
- TokenSelectionBiases = this._config.TokenSelectionBiases,
- };
-
- if (this._log.IsEnabled(LogLevel.Debug))
- {
- this._log.LogDebug("Running RAG prompt, size: {0} tokens, requesting max {1} tokens",
- this._textGenerator.CountTokens(prompt),
- this._config.AnswerTokens);
- }
-
- return this._textGenerator.GenerateTextAsync(prompt, options);
- }
-
- private static bool ValueIsEquivalentTo(string value, string target)
- {
- value = value.Trim().Trim('.', '"', '\'', '`', '~', '!', '?', '@', '#', '$', '%', '^', '+', '*', '_', '-', '=', '|', '\\', '/', '(', ')', '[', ']', '{', '}', '<', '>');
- target = target.Trim().Trim('.', '"', '\'', '`', '~', '!', '?', '@', '#', '$', '%', '^', '+', '*', '_', '-', '=', '|', '\\', '/', '(', ')', '[', ']', '{', '}', '<', '>');
- return string.Equals(value, target, StringComparison.OrdinalIgnoreCase);
- }
- }
-}
diff --git a/src/KernelMemory.Extensions/QueryPipeline/StandardRagQueryExecutor.cs b/src/KernelMemory.Extensions/QueryPipeline/StandardRagQueryExecutor.cs
index e614bf8..b691bc1 100644
--- a/src/KernelMemory.Extensions/QueryPipeline/StandardRagQueryExecutor.cs
+++ b/src/KernelMemory.Extensions/QueryPipeline/StandardRagQueryExecutor.cs
@@ -150,7 +150,6 @@ private IAsyncEnumerable GenerateAnswerAsync(string question, string fac
var options = new TextGenerationOptions
{
Temperature = this._config.Temperature,
- TopP = this._config.TopP,
PresencePenalty = this._config.PresencePenalty,
FrequencyPenalty = this._config.FrequencyPenalty,
MaxTokens = this._config.AnswerTokens,