Skip to content

Commit

Permalink
Support filtering in serverless client
Browse files Browse the repository at this point in the history
  • Loading branch information
dluc committed Aug 9, 2023
1 parent 102563e commit 9c6e4ce
Show file tree
Hide file tree
Showing 14 changed files with 152 additions and 68 deletions.
8 changes: 4 additions & 4 deletions dotnet/ClientLib/ISemanticMemoryClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,22 @@ public interface ISemanticMemoryClient

/// <summary>
/// Search the default user memory for an answer to the given query.
/// TODO: add support for tags.
/// </summary>
/// <param name="query">Query/question to answer</param>
/// <param name="filter">Filter to match</param>
/// <param name="cancellationToken">Async task cancellation token</param>
/// <returns>Answer to the query, if possible</returns>
public Task<MemoryAnswer> AskAsync(string query, CancellationToken cancellationToken = default);
public Task<MemoryAnswer> AskAsync(string query, MemoryFilter? filter = null, CancellationToken cancellationToken = default);

/// <summary>
/// Search a user memory for an answer to the given query.
/// TODO: add support for tags.
/// </summary>
/// <param name="userId">ID of the user's memory to search</param>
/// <param name="query">Query/question to answer</param>
/// <param name="filter">Filter to match</param>
/// <param name="cancellationToken">Async task cancellation token</param>
/// <returns>Answer to the query, if possible</returns>
public Task<MemoryAnswer> AskAsync(string userId, string query, CancellationToken cancellationToken = default);
public Task<MemoryAnswer> AskAsync(string userId, string query, MemoryFilter? filter = null, CancellationToken cancellationToken = default);

/// <summary>
/// Check if a document ID exists in a user memory and is ready for usage.
Expand Down
6 changes: 3 additions & 3 deletions dotnet/ClientLib/MemoryWebClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ public Task<string> ImportFileAsync(string fileName, DocumentDetails details, Ca
}

/// <inheritdoc />
public Task<MemoryAnswer> AskAsync(string query, CancellationToken cancellationToken = default)
public Task<MemoryAnswer> AskAsync(string query, MemoryFilter? filter = null, CancellationToken cancellationToken = default)
{
return this.AskAsync(new DocumentDetails().UserId, query, cancellationToken);
return this.AskAsync(new DocumentDetails().UserId, query, filter, cancellationToken);
}

/// <inheritdoc />
public async Task<MemoryAnswer> AskAsync(string userId, string query, CancellationToken cancellationToken = default)
public async Task<MemoryAnswer> AskAsync(string userId, string query, MemoryFilter? filter = null, CancellationToken cancellationToken = default)
{
var request = new { UserId = userId, Query = query, Tags = new TagCollection() };
using var content = new StringContent(JsonSerializer.Serialize(request), Encoding.UTF8, "application/json");
Expand Down
8 changes: 0 additions & 8 deletions dotnet/ClientLib/Models/MemoryAnswer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,14 +79,6 @@ public class Partition
[JsonPropertyOrder(2)]
public float Relevance { get; set; } = 0;

/// <summary>
/// Size in tokens of the partition. The size depends
/// on the AI model used to generate the answer.
/// </summary>
[JsonPropertyName("SizeInTokens")]
[JsonPropertyOrder(3)]
public int SizeInTokens { get; set; } = 0;

/// <summary>
/// Timestamp of the file/text partition.
/// </summary>
Expand Down
43 changes: 43 additions & 0 deletions dotnet/ClientLib/Models/MemoryFilter.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
// Copyright (c) Microsoft. All rights reserved.

using System.Collections.Generic;

namespace Microsoft.SemanticMemory.Client.Models;

/// <summary>
/// Fluent builder that collects tag name/value conditions used to filter memory records.
/// Every <c>By*</c> method records one condition and returns the same instance,
/// so calls can be chained, e.g. <c>new MemoryFilter().ByTag("type", "news")</c>.
/// </summary>
public class MemoryFilter
{
    // Conditions collected so far, stored as tags (name -> values).
    private readonly TagCollection _tags = new();

    /// <summary>
    /// Check whether any condition has been added to the filter.
    /// </summary>
    /// <returns>True when the underlying tag collection reports a count of zero</returns>
    public bool IsEmpty() => this._tags.Count == 0;

    /// <summary>
    /// Add a condition on a tag.
    /// </summary>
    /// <param name="name">Tag name</param>
    /// <param name="value">Tag value</param>
    /// <returns>This filter, to allow chaining</returns>
    public MemoryFilter ByTag(string name, string value) => this.With(name, value);

    /// <summary>
    /// Add a condition on the reserved user ID tag.
    /// </summary>
    /// <param name="userId">User ID to store under <see cref="Constants.ReservedUserIdTag"/></param>
    /// <returns>This filter, to allow chaining</returns>
    public MemoryFilter ByUser(string userId) => this.With(Constants.ReservedUserIdTag, userId);

    /// <summary>
    /// Add a condition on the reserved pipeline ID tag, using the given document ID as the value.
    /// </summary>
    /// <param name="docId">Document ID to store under <see cref="Constants.ReservedPipelineIdTag"/></param>
    /// <returns>This filter, to allow chaining</returns>
    public MemoryFilter ByDocument(string docId) => this.With(Constants.ReservedPipelineIdTag, docId);

    /// <summary>
    /// Get the collected conditions as a flat sequence of tag name/value pairs.
    /// </summary>
    /// <returns>The pairs produced by the underlying tag collection</returns>
    public IEnumerable<KeyValuePair<string, string?>> GetFilters() => this._tags.ToKeyValueList();

    // Shared implementation of the By* methods: record one name/value pair and return the filter.
    private MemoryFilter With(string tagName, string tagValue)
    {
        this._tags.Add(tagName, tagValue);
        return this;
    }
}
5 changes: 5 additions & 0 deletions dotnet/ClientLib/Models/TagCollection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ public void CopyTo(TagCollection tagCollection)
}
}

/// <summary>
/// Flatten the collection into a sequence of (tag name, tag value) pairs.
/// Each value stored under a tag yields its own pair, so a tag with several
/// values appears several times and a tag with no values does not appear.
/// The sequence is evaluated lazily, when it is enumerated.
/// </summary>
/// <returns>One key/value pair per tag value</returns>
public IEnumerable<KeyValuePair<string, string?>> ToKeyValueList()
{
    return this._data.SelectMany(
        entry => entry.Value,
        (entry, value) => new KeyValuePair<string, string?>(entry.Key, value));
}

public bool Remove(KeyValuePair<string, List<string?>> item)
{
return this._data.Remove(item);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
using Azure.Search.Documents.Models;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticMemory.Client.Models;
using Microsoft.SemanticMemory.Core.Configuration;
using Microsoft.SemanticMemory.Core.Diagnostics;

Expand Down Expand Up @@ -93,11 +94,12 @@ await client.IndexDocumentsAsync(
}

/// <inheritdoc />
public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync(
public async IAsyncEnumerable<(MemoryRecord, double)> GetSimilarListAsync(
string indexName,
Embedding<float> embedding,
int limit,
double minRelevanceScore = 0,
MemoryFilter? filter = null,
bool withEmbeddings = false,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
Expand All @@ -110,7 +112,25 @@ await client.IndexDocumentsAsync(
Value = embedding.Vector.ToList()
};

SearchOptions options = new() { Vector = vectorQuery };
SearchOptions options = new()
{
Vector = vectorQuery
};

if (filter != null && !filter.IsEmpty())
{
// We need to fetch more vectors because filters are applied after the vector search
vectorQuery.KNearestNeighborsCount = limit * 100;

IEnumerable<string> conditions = (from keyValue in filter.GetFilters()
let fieldValue = keyValue.Value?.Replace("'", "''", StringComparison.Ordinal)
select $"tags/any(s: s eq '{keyValue.Key}={fieldValue}')");
options.Filter = string.Join(" and ", conditions);
options.Size = limit;

this._log.LogDebug("Filtering vectors, limit {0}, condition: {1}", options.Size, options.Filter);
}

Response<SearchResults<AzureCognitiveSearchMemoryRecord>>? searchResult = null;
try
{
Expand All @@ -137,26 +157,36 @@ await client.IndexDocumentsAsync(
}

/// <inheritdoc />
public async IAsyncEnumerable<MemoryRecord> SearchByFieldValueAsync(
public async IAsyncEnumerable<MemoryRecord> GetListAsync(
string indexName,
string fieldName,
bool fieldIsCollection,
string fieldValue,
int limit,
MemoryFilter? filter = null,
int limit = 1,
bool withEmbeddings = false,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var client = this.GetSearchClient(indexName);

// See: https://learn.microsoft.com/azure/search/search-query-understand-collection-filters
fieldValue = fieldValue.Replace("'", "''", StringComparison.Ordinal);
var options = new SearchOptions
var options = new SearchOptions();
if (filter != null && !filter.IsEmpty())
{
Filter = fieldIsCollection
? $"{fieldName}/any(s: s eq '{fieldValue}')"
: $"{fieldName} eq '{fieldValue}')",
Size = limit
};
IEnumerable<string> conditions = (from keyValue in filter.GetFilters()
let fieldValue = keyValue.Value?.Replace("'", "''", StringComparison.Ordinal)
select $"tags/any(s: s eq '{keyValue.Key}={fieldValue}')");
options.Filter = string.Join(" and ", conditions);
options.Size = limit;

this._log.LogDebug("Filtering vectors, limit {0}, condition: {1}", options.Size, options.Filter);
}

// See: https://learn.microsoft.com/azure/search/search-query-understand-collection-filters
// fieldValue = fieldValue.Replace("'", "''", StringComparison.Ordinal);
// var options = new SearchOptions
// {
// Filter = fieldIsCollection
// ? $"{fieldName}/any(s: s eq '{fieldValue}')"
// : $"{fieldName} eq '{fieldValue}')",
// Size = limit
// };

Response<SearchResults<AzureCognitiveSearchMemoryRecord>>? searchResult = null;
try
Expand Down
17 changes: 8 additions & 9 deletions dotnet/CoreLib/MemoryStorage/ISemanticMemoryVectorDb.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticMemory.Client.Models;

namespace Microsoft.SemanticMemory.Core.MemoryStorage;

Expand Down Expand Up @@ -60,14 +61,16 @@ Task<string> UpsertAsync(
/// <param name="embedding">Target vector to compare to</param>
/// <param name="limit">Max number of results</param>
/// <param name="minRelevanceScore">Minimum similarity required</param>
/// <param name="filter">Values to match in the field used for tagging records (the field must be a list of strings)</param>
/// <param name="withEmbeddings">Whether to include vector in the result</param>
/// <param name="cancellationToken">Task cancellation token</param>
/// <returns>List of similar vectors, starting from the most similar</returns>
IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync(
IAsyncEnumerable<(MemoryRecord, double)> GetSimilarListAsync(
string indexName,
Embedding<float> embedding,
int limit,
double minRelevanceScore = 0,
MemoryFilter? filter = null,
bool withEmbeddings = false,
CancellationToken cancellationToken = default);

Expand All @@ -76,19 +79,15 @@ Task<string> UpsertAsync(
/// E.g. searching vectors by tag, for deletions.
/// </summary>
/// <param name="indexName">Index/Collection name</param>
/// <param name="fieldName">Field to search</param>
/// <param name="fieldIsCollection">Whether the field is a string or a collection of strings</param>
/// <param name="fieldValue">Value to match (if the field is a collection, the collection must contain the value)</param>
/// <param name="filter">Values to match in the field used for tagging records (the field must be a list of strings)</param>
/// <param name="limit">Max number of records to return</param>
/// <param name="withEmbeddings">Whether to include vector in the result</param>
/// <param name="cancellationToken">Task cancellation token</param>
/// <returns>List of records</returns>
IAsyncEnumerable<MemoryRecord> SearchByFieldValueAsync(
IAsyncEnumerable<MemoryRecord> GetListAsync(
string indexName,
string fieldName,
bool fieldIsCollection,
string fieldValue,
int limit,
MemoryFilter? filter = null,
int limit = 1,
bool withEmbeddings = false,
CancellationToken cancellationToken = default);

Expand Down
12 changes: 6 additions & 6 deletions dotnet/CoreLib/Search/SearchClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@ public SearchClient(
"Answer: ";
}

public Task<MemoryAnswer> SearchAsync(SearchRequest request, CancellationToken cancellationToken = default)
public Task<MemoryAnswer> AskAsync(MemoryQuery query, CancellationToken cancellationToken = default)
{
return this.SearchAsync(request.UserId, request.Query, cancellationToken);
return this.AskAsync(query.UserId, query.Query, query.Filter, cancellationToken);
}

public async Task<MemoryAnswer> SearchAsync(string userId, string query, CancellationToken cancellationToken = default)
public async Task<MemoryAnswer> AskAsync(string userId, string query, MemoryFilter? filter = null, CancellationToken cancellationToken = default)
{
var facts = new StringBuilder();
var tokensAvailable = 8000
Expand All @@ -82,9 +82,10 @@ public async Task<MemoryAnswer> SearchAsync(string userId, string query, Cancell
var embedding = await this.GenerateEmbeddingAsync(query).ConfigureAwait(false);

this._log.LogTrace("Fetching relevant memories");
IAsyncEnumerable<(MemoryRecord, double)> matches = this._vectorDb.GetNearestMatchesAsync(
indexName: userId, embedding, MatchesCount, MinSimilarity, false, cancellationToken: cancellationToken);
IAsyncEnumerable<(MemoryRecord, double)> matches = this._vectorDb.GetSimilarListAsync(
indexName: userId, embedding, MatchesCount, MinSimilarity, filter, false, cancellationToken: cancellationToken);

// Memories are sorted by relevance, starting from the most relevant
await foreach ((MemoryRecord memory, double relevance) in matches.WithCancellation(cancellationToken))
{
if (!memory.Tags.ContainsKey(Constants.ReservedPipelineIdTag))
Expand Down Expand Up @@ -156,7 +157,6 @@ public async Task<MemoryAnswer> SearchAsync(string userId, string query, Cancell
{
Text = partitionText,
Relevance = (float)relevance,
SizeInTokens = size,
LastUpdate = lastUpdate,
});

Expand Down
8 changes: 4 additions & 4 deletions dotnet/CoreLib/SemanticMemoryServerless.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,15 @@ public async Task<string> ImportFileAsync(string fileName, DocumentDetails detai
}

/// <inheritdoc />
public Task<MemoryAnswer> AskAsync(string query, CancellationToken cancellationToken = default)
public Task<MemoryAnswer> AskAsync(string query, MemoryFilter? filter = null, CancellationToken cancellationToken = default)
{
return this.AskAsync(new DocumentDetails().UserId, query, cancellationToken);
return this.AskAsync(new DocumentDetails().UserId, query, filter, cancellationToken);
}

/// <inheritdoc />
public Task<MemoryAnswer> AskAsync(string userId, string query, CancellationToken cancellationToken = default)
public Task<MemoryAnswer> AskAsync(string userId, string query, MemoryFilter? filter = null, CancellationToken cancellationToken = default)
{
return this._searchClient.SearchAsync(userId: userId, query: query, cancellationToken: cancellationToken);
return this._searchClient.AskAsync(userId: userId, query: query, filter: filter, cancellationToken: cancellationToken);
}

/// <inheritdoc />
Expand Down
6 changes: 3 additions & 3 deletions dotnet/CoreLib/SemanticMemoryService.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public interface ISemanticMemoryService
/// <returns>Pipeline status if available</returns>
Task<DataPipeline?> ReadPipelineStatusAsync(string userId, string documentId, CancellationToken cancellationToken = default);

Task<MemoryAnswer> AskAsync(SearchRequest request, CancellationToken cancellationToken = default);
Task<MemoryAnswer> AskAsync(MemoryQuery query, CancellationToken cancellationToken = default);
}

public class SemanticMemoryService : ISemanticMemoryService
Expand Down Expand Up @@ -59,8 +59,8 @@ public Task<string> UploadFileAsync(
}

///<inheritdoc />
public Task<MemoryAnswer> AskAsync(SearchRequest request, CancellationToken cancellationToken = default)
public Task<MemoryAnswer> AskAsync(MemoryQuery query, CancellationToken cancellationToken = default)
{
return this._searchClient.SearchAsync(request, cancellationToken);
return this._searchClient.AskAsync(query, cancellationToken);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

namespace Microsoft.SemanticMemory.Core.WebService;

public class SearchRequest
public class MemoryQuery
{
public string Query { get; set; } = string.Empty;
public string UserId { get; set; } = string.Empty;
public TagCollection Tags { get; set; } = new();
public string Query { get; set; } = string.Empty;
public MemoryFilter Filter { get; set; } = new();
}
4 changes: 2 additions & 2 deletions dotnet/Service/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,12 @@
// Ask endpoint
app.MapPost("/ask",
async Task<IResult> (
SearchRequest request,
MemoryQuery query,
ISemanticMemoryService service,
ILogger<Program> log) =>
{
log.LogTrace("New search request");
MemoryAnswer answer = await service.AskAsync(request);
MemoryAnswer answer = await service.AskAsync(query);
return Results.Ok(answer);
})
.Produces<MemoryAnswer>(StatusCodes.Status200OK);
Expand Down
19 changes: 16 additions & 3 deletions samples/001-dotnet-Serverless/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ await memory.ImportFilesAsync(new[]
{
await memory.ImportFileAsync("file5-NASA-news.pdf",
new DocumentDetails("user2", "f05")
.AddTag("collection", "samples")
.AddTag("collection", "webClient")
.AddTag("collection", ".NET")
.AddTag("collection", "meetings")
.AddTag("collection", "NASA")
.AddTag("collection", "space")
.AddTag("type", "news"));
}

Expand Down Expand Up @@ -85,3 +85,16 @@ await memory.ImportFileAsync("file5-NASA-news.pdf",
{
Console.WriteLine($" - {x.SourceName} - {x.Link} [{x.Partitions.First().LastUpdate:D}]");
}

// Test with tags
question = "What is Orion?";
Console.WriteLine($"\n\nQuestion: {question}");

var filter1 = new MemoryFilter().ByTag("type", "article");
var filter2 = new MemoryFilter().ByTag("type", "news");

answer = await memory.AskAsync("user2", question, filter1);
Console.WriteLine($"\nArticles: {answer.Result}\n\n");

answer = await memory.AskAsync("user2", question, filter2);
Console.WriteLine($"\nNews: {answer.Result}\n\n");
Loading

0 comments on commit 9c6e4ce

Please sign in to comment.