Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate registration to single configuration object and optimize registration #828

Merged
merged 2 commits into from
Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion samples/MediatR.Examples.AspNetCore/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ private static IMediator BuildMediator(WrappingWriter writer)

services.AddSingleton<TextWriter>(writer);

services.AddMediatR(typeof(Ping), typeof(Sing));
services.AddMediatR(cfg =>
{
cfg.RegisterServicesFromAssemblies(typeof(Ping).Assembly, typeof(Sing).Assembly);
});

services.AddScoped(typeof(IStreamRequestHandler<Sing, Song>), typeof(SingHandler));

Expand Down
57 changes: 36 additions & 21 deletions src/MediatR/MicrosoftExtensionsDI/MediatrServiceConfiguration.cs
Original file line number Diff line number Diff line change
@@ -1,48 +1,63 @@
using System;
using System.Collections.Generic;
using System.Reflection;
using MediatR;

namespace Microsoft.Extensions.DependencyInjection;

public class MediatRServiceConfiguration
{
public Func<Type, bool> TypeEvaluator { get; private set; } = t => true;
public Type MediatorImplementationType { get; private set; }
public ServiceLifetime Lifetime { get; private set; }
public Func<Type, bool> TypeEvaluator { get; set; } = t => true;
public Type MediatorImplementationType { get; set; } = typeof(Mediator);
public ServiceLifetime Lifetime { get; set; } = ServiceLifetime.Transient;
public RequestExceptionActionProcessorStrategy RequestExceptionActionProcessorStrategy { get; set; }
= RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions;

public MediatRServiceConfiguration()
{
MediatorImplementationType = typeof(Mediator);
Lifetime = ServiceLifetime.Transient;
}
internal List<Assembly> AssembliesToRegister { get; } = new();

public MediatRServiceConfiguration Using<TMediator>() where TMediator : IMediator
{
MediatorImplementationType = typeof(TMediator);
return this;
}
public List<ServiceDescriptor> BehaviorsToRegister { get; } = new();

public MediatRServiceConfiguration RegisterServicesFromAssemblyContaining<T>()
=> RegisterServicesFromAssemblyContaining(typeof(T));

public MediatRServiceConfiguration RegisterServicesFromAssemblyContaining(Type type)
=> RegisterServicesFromAssembly(type.Assembly);

public MediatRServiceConfiguration AsSingleton()
public MediatRServiceConfiguration RegisterServicesFromAssembly(Assembly assembly)
{
Lifetime = ServiceLifetime.Singleton;
AssembliesToRegister.Add(assembly);

return this;
}

public MediatRServiceConfiguration AsScoped()
public MediatRServiceConfiguration RegisterServicesFromAssemblies(
params Assembly[] assemblies)
{
Lifetime = ServiceLifetime.Scoped;
AssembliesToRegister.AddRange(assemblies);

return this;
}

public MediatRServiceConfiguration AsTransient()
public MediatRServiceConfiguration AddBehavior<TServiceType, TImplementationType>(
ServiceLifetime serviceLifetime = ServiceLifetime.Transient) =>
AddBehavior(typeof(TServiceType), typeof(TImplementationType), serviceLifetime);

public MediatRServiceConfiguration AddBehavior(
Type serviceType,
Type implementationType,
ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
Lifetime = ServiceLifetime.Transient;
BehaviorsToRegister.Add(new ServiceDescriptor(serviceType, implementationType, serviceLifetime));

return this;
}

public MediatRServiceConfiguration WithEvaluator(Func<Type, bool> evaluator)
public MediatRServiceConfiguration AddOpenBehavior(Type openBehaviorType, ServiceLifetime serviceLifetime = ServiceLifetime.Transient)
{
TypeEvaluator = evaluator;
var serviceType = typeof(IPipelineBehavior<,>);

BehaviorsToRegister.Add(new ServiceDescriptor(serviceType, openBehaviorType, serviceLifetime));

return this;
}
}
63 changes: 8 additions & 55 deletions src/MediatR/MicrosoftExtensionsDI/ServiceCollectionExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,71 +23,24 @@ public static class ServiceCollectionExtensions
/// Registers handlers and mediator types from the specified assemblies
/// </summary>
/// <param name="services">Service collection</param>
/// <param name="assemblies">Assemblies to scan</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, params Assembly[] assemblies)
=> services.AddMediatR(assemblies, configuration: null);

/// <summary>
/// Registers handlers and mediator types from the specified assemblies
/// </summary>
/// <param name="services">Service collection</param>
/// <param name="assemblies">Assemblies to scan</param>
/// <param name="configuration">The action used to configure the options</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, Action<MediatRServiceConfiguration>? configuration, params Assembly[] assemblies)
=> services.AddMediatR(assemblies, configuration);

/// <summary>
/// Registers handlers and mediator types from the specified assemblies
/// </summary>
/// <param name="services">Service collection</param>
/// <param name="assemblies">Assemblies to scan</param>
/// <param name="configuration">The action used to configure the options</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, IEnumerable<Assembly> assemblies, Action<MediatRServiceConfiguration>? configuration)
public static IServiceCollection AddMediatR(this IServiceCollection services,
Action<MediatRServiceConfiguration> configuration)
{
if (!assemblies.Any())
var serviceConfig = new MediatRServiceConfiguration();

configuration.Invoke(serviceConfig);

if (!serviceConfig.AssembliesToRegister.Any())
{
throw new ArgumentException("No assemblies found to scan. Supply at least one assembly to scan for handlers.");
}
var serviceConfig = new MediatRServiceConfiguration();

configuration?.Invoke(serviceConfig);
ServiceRegistrar.AddMediatRClasses(services, serviceConfig);

ServiceRegistrar.AddRequiredServices(services, serviceConfig);

ServiceRegistrar.AddMediatRClasses(services, assemblies, serviceConfig);

return services;
}

/// <summary>
/// Registers handlers and mediator types from the assemblies that contain the specified types
/// </summary>
/// <param name="services"></param>
/// <param name="handlerAssemblyMarkerTypes"></param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, params Type[] handlerAssemblyMarkerTypes)
=> services.AddMediatR(handlerAssemblyMarkerTypes, configuration: null);

/// <summary>
/// Registers handlers and mediator types from the assemblies that contain the specified types
/// </summary>
/// <param name="services"></param>
/// <param name="handlerAssemblyMarkerTypes"></param>
/// <param name="configuration">The action used to configure the options</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, Action<MediatRServiceConfiguration>? configuration, params Type[] handlerAssemblyMarkerTypes)
=> services.AddMediatR(handlerAssemblyMarkerTypes, configuration);

/// <summary>
/// Registers handlers and mediator types from the assemblies that contain the specified types
/// </summary>
/// <param name="services"></param>
/// <param name="handlerAssemblyMarkerTypes"></param>
/// <param name="configuration">The action used to configure the options</param>
/// <returns>Service collection</returns>
public static IServiceCollection AddMediatR(this IServiceCollection services, IEnumerable<Type> handlerAssemblyMarkerTypes, Action<MediatRServiceConfiguration>? configuration)
=> services.AddMediatR(handlerAssemblyMarkerTypes.Select(t => t.GetTypeInfo().Assembly), configuration);
}
2 changes: 1 addition & 1 deletion src/MediatR/Pipeline/RequestExceptionHandlerState.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ public class RequestExceptionHandlerState<TResponse>
public bool Handled { get; private set; }

/// <summary>
/// The response that is returned if <see cref="Handled"/> is <code>true</code>.
/// The response that is returned if <see cref="Handled"/> is <code>true</code>.
/// </summary>
public TResponse? Response { get; private set; }

Expand Down
53 changes: 43 additions & 10 deletions src/MediatR/Registration/ServiceRegistrar.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ namespace MediatR.Registration;

public static class ServiceRegistrar
{
public static void AddMediatRClasses(IServiceCollection services, IEnumerable<Assembly> assembliesToScan, MediatRServiceConfiguration configuration)
public static void AddMediatRClasses(IServiceCollection services, MediatRServiceConfiguration configuration)
{
assembliesToScan = assembliesToScan.Distinct().ToArray();
var assembliesToScan = configuration.AssembliesToRegister.Distinct().ToArray();

ConnectImplementationsToTypesClosing(typeof(IRequestHandler<,>), services, assembliesToScan, false, configuration);
ConnectImplementationsToTypesClosing(typeof(INotificationHandler<>), services, assembliesToScan, true, configuration);
Expand Down Expand Up @@ -217,20 +217,53 @@ public static void AddRequiredServices(IServiceCollection services, MediatRServi
services.TryAdd(new ServiceDescriptor(typeof(ISender), sp => sp.GetRequiredService<IMediator>(), serviceConfiguration.Lifetime));
services.TryAdd(new ServiceDescriptor(typeof(IPublisher), sp => sp.GetRequiredService<IMediator>(), serviceConfiguration.Lifetime));

// Use TryAddTransientExact (see below), we dó want to register our Pre/Post processor behavior, even if (a more concrete)
foreach (var serviceDescriptor in serviceConfiguration.BehaviorsToRegister)
{
services.Add(serviceDescriptor);
}

// Use TryAddTransientExact (see below), we do want to register our Pre/Post processor behavior, even if (a more concrete)
// registration for IPipelineBehavior<,> already exists. But only once.
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestPreProcessorBehavior<,>));
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestPostProcessorBehavior<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestPreProcessorBehavior<,>),
typeof(IRequestPreProcessor<>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestPostProcessorBehavior<,>),
typeof(IRequestPostProcessor<,>));

if (serviceConfiguration.RequestExceptionActionProcessorStrategy == RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
if (serviceConfiguration.RequestExceptionActionProcessorStrategy ==
RequestExceptionActionProcessorStrategy.ApplyForUnhandledExceptions)
{
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestExceptionActionProcessorBehavior<,>));
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestExceptionProcessorBehavior<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>),
typeof(IRequestExceptionAction<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>),
typeof(IRequestExceptionHandler<,,>));
}
else
{
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestExceptionProcessorBehavior<,>));
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), typeof(RequestExceptionActionProcessorBehavior<,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionProcessorBehavior<,>),
typeof(IRequestExceptionHandler<,,>));
RegisterBehaviorIfImplementationsExist(services, typeof(RequestExceptionActionProcessorBehavior<,>),
typeof(IRequestExceptionAction<,>));
}
}

private static void RegisterBehaviorIfImplementationsExist(
IServiceCollection services,
Type behaviorType,
Type subBehaviorType
)
{
var hasAnyRegistrationsOfSubBehaviorType = services
.Select(service => service.ImplementationType)
.Where(type => type != null)
.SelectMany(type => type!.GetInterfaces())
.Where(type => type.IsGenericType)
.Select(type => type.GetGenericTypeDefinition())
.Where(type => type != null)
.Any(type => type == subBehaviorType);

if (hasAnyRegistrationsOfSubBehaviorType)
{
services.TryAddTransientExact(typeof(IPipelineBehavior<,>), behaviorType);
}
}

Expand Down
2 changes: 1 addition & 1 deletion test/MediatR.Benchmarks/Benchmarks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public void GlobalSetup()

services.AddSingleton(TextWriter.Null);

services.AddMediatR(typeof(Ping));
services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblyContaining(typeof(Ping)));

services.AddScoped(typeof(IPipelineBehavior<,>), typeof(GenericPipelineBehavior<,>));
services.AddScoped(typeof(IRequestPreProcessor<>), typeof(GenericRequestPreProcessor<>));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public AssemblyResolutionTests()
{
IServiceCollection services = new ServiceCollection();
services.AddSingleton(new Logger());
services.AddMediatR(typeof(Ping).GetTypeInfo().Assembly);
services.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly));
_provider = services.BuildServiceProvider();
}

Expand Down Expand Up @@ -55,7 +55,7 @@ public void ShouldRequireAtLeastOneAssembly()
{
var services = new ServiceCollection();

Action registration = () => services.AddMediatR(new Type[0]);
Action registration = () => services.AddMediatR(_ => {});

registration.ShouldThrow<ArgumentException>();
}
Expand Down
14 changes: 11 additions & 3 deletions test/MediatR.Tests/MicrosoftExtensionsDI/CustomMediatorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ public CustomMediatorTests()
{
IServiceCollection services = new ServiceCollection();
services.AddSingleton(new Logger());
services.AddMediatR(cfg => cfg.Using<MyCustomMediator>(), typeof(CustomMediatorTests));
services.AddMediatR(cfg =>
{
cfg.MediatorImplementationType = typeof(MyCustomMediator);
cfg.RegisterServicesFromAssemblyContaining(typeof(CustomMediatorTests));
});
_provider = services.BuildServiceProvider();
}

Expand Down Expand Up @@ -43,10 +47,14 @@ public void Can_Call_AddMediatr_multiple_times()
{
IServiceCollection services = new ServiceCollection();
services.AddSingleton(new Logger());
services.AddMediatR(cfg => cfg.Using<MyCustomMediator>(), typeof(CustomMediatorTests));
services.AddMediatR(cfg =>
{
cfg.MediatorImplementationType = typeof(MyCustomMediator);
cfg.RegisterServicesFromAssemblyContaining(typeof(CustomMediatorTests));
});

// Call AddMediatr again, this should NOT override our custom mediatr (With MS DI, last registration wins)
services.AddMediatR(typeof(CustomMediatorTests));
services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblyContaining(typeof(CustomMediatorTests)));

var provider = services.BuildServiceProvider();
var mediator = provider.GetRequiredService<IMediator>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public DerivingRequestsTests()
{
IServiceCollection services = new ServiceCollection();
services.AddSingleton(new Logger());
services.AddMediatR(typeof(Ping));
services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblyContaining(typeof(Ping)));
_provider = services.BuildServiceProvider();
_mediator = _provider.GetRequiredService<IMediator>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ public DuplicateAssemblyResolutionTests()
{
IServiceCollection services = new ServiceCollection();
services.AddSingleton(new Logger());
services.AddMediatR(typeof(Ping), typeof(Ping));
services.AddMediatR(cfg => cfg.RegisterServicesFromAssemblies(typeof(Ping).Assembly, typeof(Ping).Assembly));
_provider = services.BuildServiceProvider();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public async Task Should_not_call_constructor_multiple_times_when_using_a_pipeli

services.AddSingleton(output);
services.AddTransient(typeof(IPipelineBehavior<,>), typeof(ConstructorTestBehavior<,>));
services.AddMediatR(typeof(Ping).GetTypeInfo().Assembly);
services.AddMediatR(cfg => cfg.RegisterServicesFromAssembly(typeof(Ping).Assembly));
var provider = services.BuildServiceProvider();

var mediator = provider.GetRequiredService<IMediator>();
Expand Down
Loading