Skip to content

Commit

Permalink
Simplify multiple vector DBs/embedding generators DI
Browse files Browse the repository at this point in the history
  • Loading branch information
dluc committed Aug 8, 2023
1 parent 93f2fc0 commit 07b6e2d
Show file tree
Hide file tree
Showing 21 changed files with 142 additions and 155 deletions.
27 changes: 0 additions & 27 deletions dotnet/CoreLib/AI/AzureOpenAI/DependencyInjection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
using Microsoft.SemanticMemory.Core.AppBuilders;
using Microsoft.SemanticMemory.Core.ContentStorage.AzureBlobs;
using Microsoft.SemanticMemory.Core.Diagnostics;

Expand Down Expand Up @@ -84,30 +83,4 @@ public static IServiceCollection AddSemanticKernelWithAzureOpenAI(this IServiceC
throw new NotImplementedException($"Azure OpenAI auth type '{config.Auth}' not available");
}
}

public static void AddAzureOpenAIEmbeddingGenerationToList(this ConfiguredServices<ITextEmbeddingGeneration> services, AzureOpenAIConfig config)
{
switch (config.Auth)
{
case "":
case string x when x.Equals("AzureIdentity", StringComparison.OrdinalIgnoreCase):
services.Add(serviceProvider => new AzureTextEmbeddingGeneration(
modelId: config.Deployment,
endpoint: config.Endpoint,
credential: new DefaultAzureCredential(),
logger: serviceProvider.GetService<ILogger<AzureBlob>>()));
break;

case string y when y.Equals("APIKey", StringComparison.OrdinalIgnoreCase):
services.Add(serviceProvider => new AzureTextEmbeddingGeneration(
modelId: config.Deployment,
endpoint: config.Endpoint,
apiKey: config.APIKey,
logger: serviceProvider.GetService<ILogger<AzureBlob>>()));
break;

default:
throw new NotImplementedException($"Azure OpenAI auth type '{config.Auth}' not available");
}
}
}
10 changes: 0 additions & 10 deletions dotnet/CoreLib/AI/OpenAI/DependencyInjection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
using Microsoft.SemanticMemory.Core.AppBuilders;
using Microsoft.SemanticMemory.Core.Diagnostics;

namespace Microsoft.SemanticMemory.Core.AI.OpenAI;
Expand Down Expand Up @@ -64,13 +63,4 @@ public static IServiceCollection AddSemanticKernelWithOpenAI(this IServiceCollec
setAsDefault: true)
.Build());
}

public static void AddOpenAITextEmbeddingGenerationToList(this ConfiguredServices<ITextEmbeddingGeneration> services, OpenAIConfig config)
{
services.Add(serviceProvider => new OpenAITextEmbeddingGeneration(
modelId: config.EmbeddingModel,
apiKey: config.APIKey,
organization: config.OrgId,
logger: serviceProvider.GetService<ILogger<OpenAITextEmbeddingGeneration>>()));
}
}
22 changes: 0 additions & 22 deletions dotnet/CoreLib/AppBuilders/ConfiguredServices.cs

This file was deleted.

32 changes: 32 additions & 0 deletions dotnet/CoreLib/Configuration/TypeCollection.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;

namespace Microsoft.SemanticMemory.Core.Configuration;

public class TypeCollection<T> where T : class
{
private readonly List<Type> _types;

public void Add<TImplementation>() where TImplementation : T
{
this._types.Add(typeof(TImplementation));
}

public List<Type> GetList()
{
return this._types.Select(x => x).ToList();
}

public TypeCollection()
{
this._types = new();
}

public TypeCollection(Type firstValue)
{
this._types = new() { firstValue };
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticMemory.Client;
using Microsoft.SemanticMemory.Core.AppBuilders;

namespace Microsoft.SemanticMemory.Core.ContentStorage.AzureBlobs;

Expand All @@ -15,10 +13,4 @@ public static IServiceCollection AddAzureBlobAsContentStorage(this IServiceColle
.AddSingleton<IContentStorage, AzureBlob>()
.AddSingleton<AzureBlob, AzureBlob>();
}

public static void AddAzureBlobAsContentStorageToList(this ConfiguredServices<IContentStorage> services, AzureBlobConfig config)
{
services.Add(serviceProvider => serviceProvider.GetService<AzureBlob>()
?? throw new SemanticMemoryException("Unable to instantiate " + typeof(AzureBlob)));
}
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticMemory.Client;
using Microsoft.SemanticMemory.Core.AppBuilders;

namespace Microsoft.SemanticMemory.Core.ContentStorage.FileSystemStorage;

Expand All @@ -15,10 +13,4 @@ public static IServiceCollection AddFileSystemAsContentStorage(this IServiceColl
.AddSingleton<IContentStorage, FileSystem>()
.AddSingleton<FileSystem, FileSystem>();
}

public static void AddFileSystemAsContentStorageToList(this ConfiguredServices<IContentStorage> services, FileSystemConfig config)
{
services.Add(serviceProvider => serviceProvider.GetService<FileSystem>()
?? throw new SemanticMemoryException("Unable to instantiate " + typeof(FileSystem)));
}
}
1 change: 1 addition & 0 deletions dotnet/CoreLib/CoreLib.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
<PackageReference Include="Azure.Storage.Queues" Version="12.15.0" />
<PackageReference Include="DocumentFormat.OpenXml" Version="2.20.0" />
<PackageReference Include="Microsoft.AspNetCore.Http" Version="2.2.2" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="7.0.0" />
<PackageReference Include="Microsoft.Extensions.Hosting" Version="7.0.1" />
<PackageReference Include="Microsoft.Extensions.Logging.Abstractions" Version="7.0.1" />
<PackageReference Include="Microsoft.SemanticKernel" Version="0.17.230718.1-preview" />
Expand Down
17 changes: 2 additions & 15 deletions dotnet/CoreLib/Handlers/GenerateEmbeddingsHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,8 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticMemory.Client;
using Microsoft.SemanticMemory.Core.AppBuilders;
using Microsoft.SemanticMemory.Core.ContentStorage;
using Microsoft.SemanticMemory.Core.Diagnostics;
using Microsoft.SemanticMemory.Core.Pipeline;
Expand All @@ -32,26 +29,16 @@ public class GenerateEmbeddingsHandler : IPipelineStepHandler
/// </summary>
/// <param name="stepName">Pipeline step for which the handler will be invoked</param>
/// <param name="orchestrator">Current orchestrator used by the pipeline, giving access to content and other helps.</param>
/// <param name="serviceProvider">.NET service provider</param>
/// <param name="log">Application logger</param>
public GenerateEmbeddingsHandler(
string stepName,
IPipelineOrchestrator orchestrator,
IServiceProvider serviceProvider,
ILogger<GenerateEmbeddingsHandler>? log = null)
{
this.StepName = stepName;
this._orchestrator = orchestrator;
this._log = log
?? serviceProvider.GetService<ILogger<GenerateEmbeddingsHandler>>()
?? DefaultLogger<GenerateEmbeddingsHandler>.Instance;

var embeddingGeneratorBuilders = serviceProvider.GetService<ConfiguredServices<ITextEmbeddingGeneration>>()
?? throw new SemanticMemoryException("List of embedding generators not configured");
foreach (Func<IServiceProvider, ITextEmbeddingGeneration> x in embeddingGeneratorBuilders.GetList())
{
this._embeddingGenerators.Add(x.Invoke(serviceProvider));
}
this._log = log ?? DefaultLogger<GenerateEmbeddingsHandler>.Instance;
this._embeddingGenerators = orchestrator.GetEmbeddingGenerators();

this._log.LogInformation("Handler '{0}' ready, {1} embedding generators", stepName, this._embeddingGenerators.Count);
if (this._embeddingGenerators.Count < 1)
Expand Down
16 changes: 2 additions & 14 deletions dotnet/CoreLib/Handlers/SaveEmbeddingsHandler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,8 @@
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticMemory.Client;
using Microsoft.SemanticMemory.Core.AppBuilders;
using Microsoft.SemanticMemory.Core.ContentStorage;
using Microsoft.SemanticMemory.Core.Diagnostics;
using Microsoft.SemanticMemory.Core.MemoryStorage;
Expand All @@ -29,26 +27,16 @@ public class SaveEmbeddingsHandler : IPipelineStepHandler
/// </summary>
/// <param name="stepName">Pipeline step for which the handler will be invoked</param>
/// <param name="orchestrator">Current orchestrator used by the pipeline, giving access to content and other helps.</param>
/// <param name="serviceProvider">.NET service provider</param>
/// <param name="log">Application logger</param>
public SaveEmbeddingsHandler(
string stepName,
IPipelineOrchestrator orchestrator,
IServiceProvider serviceProvider,
ILogger<SaveEmbeddingsHandler>? log = null)
{
this.StepName = stepName;
this._orchestrator = orchestrator;
this._log = log
?? serviceProvider.GetService<ILogger<SaveEmbeddingsHandler>>()
?? DefaultLogger<SaveEmbeddingsHandler>.Instance;

var vectorDbBuilders = serviceProvider.GetService<ConfiguredServices<ISemanticMemoryVectorDb>>()
?? throw new SemanticMemoryException("List of embedding generators not configured");
foreach (Func<IServiceProvider, ISemanticMemoryVectorDb> x in vectorDbBuilders.GetList())
{
this._vectorDbs.Add(x.Invoke(serviceProvider));
}
this._log = log ?? DefaultLogger<SaveEmbeddingsHandler>.Instance;
this._vectorDbs = orchestrator.GetVectorDbs();

this._log.LogInformation("Handler {0} ready, {1} vector storages", stepName, this._vectorDbs.Count);
if (this._vectorDbs.Count < 1)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
// Copyright (c) Microsoft. All rights reserved.

using Microsoft.Extensions.DependencyInjection;
using Microsoft.SemanticMemory.Client;
using Microsoft.SemanticMemory.Core.AppBuilders;

namespace Microsoft.SemanticMemory.Core.MemoryStorage.AzureCognitiveSearch;

Expand All @@ -15,10 +13,4 @@ public static IServiceCollection AddAzureCognitiveSearchAsVectorDb(this IService
.AddSingleton<ISemanticMemoryVectorDb, AzureCognitiveSearchMemory>()
.AddSingleton<AzureCognitiveSearchMemory, AzureCognitiveSearchMemory>();
}

public static void AddAzureCognitiveSearchAsVectorDbToList(this ConfiguredServices<ISemanticMemoryVectorDb> services, AzureCognitiveSearchConfig config)
{
services.Add(serviceProvider => serviceProvider.GetService<AzureCognitiveSearchMemory>()
?? throw new SemanticMemoryException("Unable to instantiate " + typeof(AzureCognitiveSearchMemory)));
}
}
41 changes: 41 additions & 0 deletions dotnet/CoreLib/Pipeline/BaseOrchestrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,60 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticMemory.Client;
using Microsoft.SemanticMemory.Client.Models;
using Microsoft.SemanticMemory.Core.Configuration;
using Microsoft.SemanticMemory.Core.ContentStorage;
using Microsoft.SemanticMemory.Core.Diagnostics;
using Microsoft.SemanticMemory.Core.MemoryStorage;
using Microsoft.SemanticMemory.Core.WebService;

namespace Microsoft.SemanticMemory.Core.Pipeline;

public abstract class BaseOrchestrator : IPipelineOrchestrator, IDisposable
{
private readonly List<ISemanticMemoryVectorDb> _vectorDbs;
private readonly List<ITextEmbeddingGeneration> _embeddingGenerators;

protected IContentStorage ContentStorage { get; private set; }
protected ILogger<BaseOrchestrator> Log { get; private set; }
protected CancellationTokenSource CancellationTokenSource { get; private set; }
protected IMimeTypeDetection MimeTypeDetection { get; private set; }

protected BaseOrchestrator(
IContentStorage contentStorage,
IServiceProvider serviceProvider,
IMimeTypeDetection? mimeTypeDetection = null,
ILogger<BaseOrchestrator>? log = null)
{
this.MimeTypeDetection = mimeTypeDetection ?? new MimeTypesDetection();
this.ContentStorage = contentStorage;
this.Log = log ?? DefaultLogger<BaseOrchestrator>.Instance;
this.CancellationTokenSource = new CancellationTokenSource();

this._embeddingGenerators = new List<ITextEmbeddingGeneration>();

var embeddingGenerators = serviceProvider.GetService<TypeCollection<ITextEmbeddingGeneration>>()
?? throw new SemanticMemoryException("Service provider is missing " + typeof(TypeCollection<ITextEmbeddingGeneration>));
foreach (Type t in embeddingGenerators.GetList())
{
var service = serviceProvider.GetService(t)
?? throw new SemanticMemoryException("Unable to instantiate " + t.FullName);
this._embeddingGenerators.Add((ITextEmbeddingGeneration)service);
}

this._vectorDbs = new List<ISemanticMemoryVectorDb>();
var vectorDbs = serviceProvider.GetService<TypeCollection<ISemanticMemoryVectorDb>>()
?? throw new SemanticMemoryException("Service provider is missing " + typeof(TypeCollection<ISemanticMemoryVectorDb>));
foreach (Type t in vectorDbs.GetList())
{
var service = serviceProvider.GetService(t)
?? throw new SemanticMemoryException("Unable to instantiate " + t.FullName);
this._vectorDbs.Add((ISemanticMemoryVectorDb)service);
}
}

///<inheritdoc />
Expand Down Expand Up @@ -154,6 +183,18 @@ public Task WriteTextFileAsync(DataPipeline pipeline, string fileName, string fi
return this.WriteFileAsync(pipeline, fileName, BinaryData.FromString(fileContent), cancellationToken);
}

///<inheritdoc />
public List<ITextEmbeddingGeneration> GetEmbeddingGenerators()
{
return this._embeddingGenerators;
}

///<inheritdoc />
public List<ISemanticMemoryVectorDb> GetVectorDbs()
{
return this._vectorDbs;
}

///<inheritdoc />
public Task WriteFileAsync(DataPipeline pipeline, string fileName, BinaryData fileContent, CancellationToken cancellationToken = default)
{
Expand Down
3 changes: 2 additions & 1 deletion dotnet/CoreLib/Pipeline/DistributedPipelineOrchestrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ public DistributedPipelineOrchestrator(
IContentStorage contentStorage,
IMimeTypeDetection mimeTypeDetection,
QueueClientFactory queueClientFactory,
IServiceProvider serviceProvider,
ILogger<DistributedPipelineOrchestrator> log)
: base(contentStorage, mimeTypeDetection, log)
: base(contentStorage, serviceProvider, mimeTypeDetection, log)
{
this._queueClientFactory = queueClientFactory;
}
Expand Down
14 changes: 14 additions & 0 deletions dotnet/CoreLib/Pipeline/IPipelineOrchestrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Http;
using Microsoft.SemanticKernel.AI.Embeddings;
using Microsoft.SemanticMemory.Client.Models;
using Microsoft.SemanticMemory.Core.MemoryStorage;
using Microsoft.SemanticMemory.Core.WebService;

namespace Microsoft.SemanticMemory.Core.Pipeline;
Expand Down Expand Up @@ -118,4 +120,16 @@ public interface IPipelineOrchestrator
/// <param name="fileContent">File content</param>
/// <param name="cancellationToken">Async task cancellation token</param>
Task WriteTextFileAsync(DataPipeline pipeline, string fileName, string fileContent, CancellationToken cancellationToken = default);

/// <summary>
/// Get list of embedding generators to use during the ingestion, e.g. to create
/// multiple vectors.
/// </summary>
List<ITextEmbeddingGeneration> GetEmbeddingGenerators();

/// <summary>
/// Get list of Vector DBs where to store embeddings.
/// </summary>
/// <returns></returns>
List<ISemanticMemoryVectorDb> GetVectorDbs();
}
3 changes: 2 additions & 1 deletion dotnet/CoreLib/Pipeline/InProcessPipelineOrchestrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ public class InProcessPipelineOrchestrator : BaseOrchestrator

public InProcessPipelineOrchestrator(
IContentStorage contentStorage,
IServiceProvider serviceProvider,
IMimeTypeDetection? mimeTypeDetection = null,
ILogger<InProcessPipelineOrchestrator>? log = null)
: base(contentStorage, mimeTypeDetection, log)
: base(contentStorage, serviceProvider, mimeTypeDetection, log)
{
}

Expand Down
Loading

0 comments on commit 07b6e2d

Please sign in to comment.