Skip to content

Commit

Permalink
Updating enumerable resolution for generics
Browse files Browse the repository at this point in the history
We now attempt to match each descriptor against both the service type
and implementation type of cached call sites to account for slot mismatches

Added unit test to validate this functionality

Fix dotnet#87017
  • Loading branch information
mesakomarevich committed Jun 1, 2023
1 parent 15f5f79 commit a6fe477
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -273,18 +273,53 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica
}
else
{
int slot = 0;
/*
* We cannot assume that services cached in _callSiteCache have a slot matching the order of registered descriptors.
* Consider the following:
*
* serviceCollection.AddTransient<IOpenGenericFoo<IFoo>, ClosedGenericFoo1>();
* serviceCollection.AddTransient(typeof(IOpenGenericFoo<>), typeof(OpenGenericFoo1<>));
*
* var serviceProvider = serviceCollection.BuildServiceProvider();
* var service = serviceProvider.GetService<IOpenGenericFoo<IFoo>>();
* var services = serviceProvider.GetServices<IOpenGenericFoo<IFoo>>();
*
* When service is resolved, we will first call TryCreateExact, which will find and cache ClosedGenericFoo1 with Slot = 0.
*
* When services is resolved, if we loop through descriptors calling TryCreateExact and then TryCreateOpenGeneric,
* we will first find OpenGenericFoo1<> (as it was registered last), call TryCreateOpenGeneric and check the cache
* for a service matching: { Type = IOpenGenericFoo<IFoo>, Slot = 0 }, which will return ClosedGenericFoo1.
*
* Next, we will find the descriptor for ClosedGenericFoo1, call TryCreateExact and check the cache for a service
* matching: { Type = IOpenGenericFoo<IFoo>, Slot = 1 }, which will fail, so we will create, cache, and return a
* new call site for ClosedGenericFoo1.
*
* Finally, we will return an enumerable containing two instances of ClosedGenericFoo1, and no intance of
* OpenGenericFoo1<IFoo>.
*
* To prevent this, we need to check each descriptor against both the service and implentation types of the cached
* services, and adjust our slot value accordingly.
*/
List<ServiceCallSite> cachedCallSites = GetCachedCallSitesForService(itemType);
int slot = cachedCallSites.Count;
// We are going in reverse so the last service in descriptor list gets slot 0
for (int i = _descriptors.Length - 1; i >= 0; i--)
{
ServiceDescriptor descriptor = _descriptors[i];
ServiceCallSite? callSite = TryCreateExact(descriptor, itemType, callSiteChain, slot) ??
TryCreateOpenGeneric(descriptor, itemType, callSiteChain, slot, false);

if (callSite != null)
ServiceCallSite? callSite = cachedCallSites.Find(scs => CachedCallSiteMatchesDescriptor(scs, descriptor));

if (callSite == null)
{
slot++;
callSite = TryCreateExact(descriptor, itemType, callSiteChain, slot) ??
TryCreateOpenGeneric(descriptor, itemType, callSiteChain, slot, false);

// We only increment slot when we create a new call site
slot = callSite == null ? slot : slot + 1;
}

if (callSite != null)
{
cacheLocation = GetCommonCacheLocation(cacheLocation, callSite.Cache.Location);
callSites.Add(callSite);
}
Expand All @@ -311,6 +346,42 @@ private static bool AreCompatible(DynamicallyAccessedMemberTypes serviceDynamica
}
}

private List<ServiceCallSite> GetCachedCallSitesForService(Type serviceType, int slot = DefaultSlot)
{
var cachedCallSites = new List<ServiceCallSite>();
var callSiteKey = new ServiceCacheKey(serviceType, slot);

while (_callSiteCache.TryGetValue(callSiteKey, out ServiceCallSite? serviceCallSite))
{
cachedCallSites.Add(serviceCallSite);
slot++;
callSiteKey = new ServiceCacheKey(serviceType, slot);
}

return cachedCallSites;
}

private static bool CachedCallSiteMatchesDescriptor(ServiceCallSite serviceCallSite, ServiceDescriptor descriptor)
{
// Check for exact match
if (serviceCallSite.ServiceType == descriptor.ServiceType
&& serviceCallSite.ImplementationType == descriptor.ImplementationType)
{
return true;
}

// Check for open generic match
if (serviceCallSite.ServiceType.IsConstructedGenericType
&& serviceCallSite.ServiceType.GetGenericTypeDefinition() == descriptor.ServiceType
&& (serviceCallSite.ImplementationType?.IsGenericType ?? false)
&& serviceCallSite.ImplementationType?.GetGenericTypeDefinition() == descriptor.ImplementationType)
{
return true;
}

return false;
}

private static CallSiteResultCacheLocation GetCommonCacheLocation(CallSiteResultCacheLocation locationA, CallSiteResultCacheLocation locationB)
{
return (CallSiteResultCacheLocation)Math.Max((int)locationA, (int)locationB);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Microsoft.Extensions.DependencyInjection.Extensions;
using Xunit;

namespace Microsoft.Extensions.DependencyInjection
Expand Down Expand Up @@ -249,6 +250,25 @@ public async Task CreateAsyncScope_Returns_AsyncServiceScope_Wrapping_ServiceSco
Assert.IsType<Foo1>(service);
}

[Fact]
public void GetServices_Returns_AllServices_After_Closed_Generic_Is_Cached()
{
// Arrange
var serviceProvider = CreateTestServiceProvider(8);

// Act
var service = serviceProvider.GetService<IOpenGenericFoo<IFoo>>();
var services = serviceProvider.GetServices<IOpenGenericFoo<IFoo>>();

// Assert
Assert.True(service is ClosedGenericFoo2);
Assert.Contains(services, item => item is ClosedGenericFoo1);
Assert.Contains(services, item => item is ClosedGenericFoo2);
Assert.Contains(services, item => item is OpenGenericFoo1<IFoo>);
Assert.Contains(services, item => item is OpenGenericFoo2<IFoo>);
Assert.Equal(4, services.Count());
}

private static IServiceProvider CreateTestServiceProvider(int count)
{
var serviceCollection = new ServiceCollection();
Expand All @@ -273,6 +293,28 @@ private static IServiceProvider CreateTestServiceProvider(int count)
serviceCollection.AddTransient<IBar, Bar2>();
}

// Note that ClosedGenericFoos are registered before OpenGenericFoos to test the inverse order lookup of
// descriptors for resolving enumerables
if (count > 4)
{
serviceCollection.AddTransient<IOpenGenericFoo<IFoo>, ClosedGenericFoo1>();
}

if (count > 5)
{
serviceCollection.AddTransient<IOpenGenericFoo<IFoo>, ClosedGenericFoo2>();
}

if (count > 6)
{
serviceCollection.AddTransient(typeof(IOpenGenericFoo<>), typeof(OpenGenericFoo1<>));
}

if (count > 7)
{
serviceCollection.AddTransient(typeof(IOpenGenericFoo<>), typeof(OpenGenericFoo2<>));
}

return serviceCollection.BuildServiceProvider();
}

Expand All @@ -288,6 +330,16 @@ public class Bar1 : IBar { }

public class Bar2 : IBar { }

public interface IOpenGenericFoo<T> where T : IFoo { }

public class OpenGenericFoo1<T> : IOpenGenericFoo<T> where T : IFoo { }

public class OpenGenericFoo2<T> : IOpenGenericFoo<T> where T : IFoo { }

public class ClosedGenericFoo1 : IOpenGenericFoo<IFoo> { }

public class ClosedGenericFoo2 : IOpenGenericFoo<IFoo> { }

private class RequiredServiceSupportingProvider : IServiceProvider, ISupportRequiredService
{
object ISupportRequiredService.GetRequiredService(Type serviceType)
Expand Down

0 comments on commit a6fe477

Please sign in to comment.