From b261537485c11cd4a8f57651648ed702e6498087 Mon Sep 17 00:00:00 2001 From: Foxtrek_64 Date: Sun, 28 Feb 2021 14:43:30 -0600 Subject: [PATCH] Fix #20064 --- .../src/System/Linq/CachedReflection.cs | 42 ++++++++++ .../src/System/Linq/Queryable.cs | 82 +++++++++++++++++++ .../System.Linq/src/System/Linq/First.cs | 24 ++++-- .../System.Linq/src/System/Linq/Last.cs | 20 +++-- .../System.Linq/src/System/Linq/Single.cs | 12 ++- 5 files changed, 165 insertions(+), 15 deletions(-) diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs index f0eb0f169bf3c0..2d37fe2f53016a 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs @@ -284,6 +284,20 @@ public static MethodInfo FirstOrDefault_TSource_2(Type TSource) => (s_FirstOrDefault_TSource_2 ??= new Func, Expression>, object?>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_FirstOrDefault_TSource_3; + + public static MethodInfo FirstOrDefault_TSource_3(Type TSource) => + (s_FirstOrDefault_TSource_3 ?? + (s_FirstOrDefault_TSource_3 = new Func, object?, object?>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(); + + private static MethodInfo? s_FirstOrDefault_TSource_4; + + public static MethodInfo FirstOrDefault_TSource_4(Type TSource) => + (s_FirstOrDefault_TSource_4 ?? + (s_FirstOrDefault_TSource_4 = new Func, Expression>, object?, object?>(Queryable.FirstOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(TSource); + private static MethodInfo? s_GroupBy_TSource_TKey_2; public static MethodInfo GroupBy_TSource_TKey_2(Type TSource, Type TKey) => @@ -392,6 +406,20 @@ public static MethodInfo LastOrDefault_TSource_2(Type TSource) => (s_LastOrDefault_TSource_2 ??= new Func, Expression>, object?>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_LastOrDefault_TSource_3; + + public static MethodInfo LastOrDefault_TSource_3(Type TSource) => + (s_LastOrDefault_TSource_3 ?? + (s_LastOrDefault_TSource_3 = new Func, object?, object?>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(TSource); + + private static MethodInfo? s_LastOrDefault_TSource_4; + + public static MethodInfo LastOrDefault_TSource_4(Type TSource) => + (s_LastOrDefault_TSource_4 ?? + (s_LastOrDefault_TSource_4 = new Func, Expression>, object?, object?>(Queryable.LastOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(TSource); + private static MethodInfo? s_LongCount_TSource_1; public static MethodInfo LongCount_TSource_1(Type TSource) => @@ -536,6 +564,20 @@ public static MethodInfo SingleOrDefault_TSource_2(Type TSource) => (s_SingleOrDefault_TSource_2 ??= new Func, Expression>, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_SingleOrDefault_TSource_3; + + public static MethodInfo SingleOrDefault_TSource_3(Type TSource) => + (s_SingleOrDefault_TSource_3 ?? + (s_SingleOrDefault_TSource_3 = new Func, object?, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(TSource); + + private static MethodInfo? s_SingleOrDefault_TSource_4; + + public static MethodInfo SingleOrDefault_TSource_4(Type TSource) => + (s_SingleOrDefault_TSource_4 ?? + (s_SingleOrDefault_TSource_4 = new Func, Expression>, object?, object?>(Queryable.SingleOrDefault).GetMethodInfo().GetGenericMethodDefinition())) + .MakeGenericMethod(); + private static MethodInfo? s_Skip_TSource_2; public static MethodInfo Skip_TSource_2(Type TSource) => diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs index c08bcab79af79c..065218f1ba12ce 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs @@ -864,6 +864,18 @@ public static TSource First(this IQueryable source, Expression CachedReflectionInfo.FirstOrDefault_TSource_1(typeof(TSource)), source.Expression)); } + [DynamicDependency("FirstOrDefault`1", typeof(Enumerable))] + public static TSource? FirstOrDefault(this IQueryable source, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.FirstOrDefault_TSource_3(typeof(TSource)), + source.Expression, Expression.Constant(defaultValue, typeof(TSource)))); + } + [DynamicDependency("FirstOrDefault`1", typeof(Enumerable))] public static TSource? FirstOrDefault(this IQueryable source, Expression> predicate) { @@ -879,6 +891,21 @@ public static TSource First(this IQueryable source, Expression )); } + [DynamicDependency("FirstOrDefault`1", typeof(Enumerable))] + public static TSource? FirstOrDefault(this IQueryable source, Expression> predicate, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (predicate == null) + throw Error.ArgumentNull(nameof(predicate)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.FirstOrDefault_TSource_4(typeof(TSource)), + source.Expression, Expression.Quote(predicate), Expression.Constant(defaultValue, typeof(TSource)) + )); + } + [DynamicDependency("Last`1", typeof(Enumerable))] public static TSource Last(this IQueryable source) { @@ -916,6 +943,18 @@ public static TSource Last(this IQueryable source, Expression< CachedReflectionInfo.LastOrDefault_TSource_1(typeof(TSource)), source.Expression)); } + [DynamicDependency("LastOrDefault`1", typeof(Enumerable))] + public static TSource? LastOrDefault(this IQueryable source, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.LastOrDefault_TSource_3(typeof(TSource)), + source.Expression, Expression.Constant(defaultValue, typeof(TSource)))); + } + [DynamicDependency("LastOrDefault`1", typeof(Enumerable))] public static TSource? LastOrDefault(this IQueryable source, Expression> predicate) { @@ -931,6 +970,21 @@ public static TSource Last(this IQueryable source, Expression< )); } + [DynamicDependency("LastOrDefault`1", typeof(Enumerable))] + public static TSource? LastOrDefault(this IQueryable source, Expression> predicate, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (predicate == null) + throw Error.ArgumentNull(nameof(predicate)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.LastOrDefault_TSource_4(typeof(TSource)), + source.Expression, Expression.Quote(predicate), Expression.Constant(defaultValue, typeof(TSource)) + )); + } + [DynamicDependency("Single`1", typeof(Enumerable))] public static TSource Single(this IQueryable source) { @@ -968,6 +1022,19 @@ public static TSource Single(this IQueryable source, Expressio CachedReflectionInfo.SingleOrDefault_TSource_1(typeof(TSource)), source.Expression)); } + [DynamicDependency("SingleOrDefault`1", typeof(Enumerable))] + public static TSource? SingleOrDefault(this IQueryable source, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.SingleOrDefault_TSource_3(typeof(TSource)), + source.Expression, Expression.Constant(defaultValue, typeof(TSource)))); + + } + [DynamicDependency("SingleOrDefault`1", typeof(Enumerable))] public static TSource? SingleOrDefault(this IQueryable source, Expression> predicate) { @@ -983,6 +1050,21 @@ public static TSource Single(this IQueryable source, Expressio )); } + [DynamicDependency("SingleOrDefault`1", typeof(Enumerable))] + public static TSource? SingleOrDefault(this IQueryable source, Expression> predicate, TSource? defaultValue) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (predicate == null) + throw Error.ArgumentNull(nameof(predicate)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.SingleOrDefault_TSource_4(typeof(TSource)), + source.Expression, Expression.Quote(predicate), Expression.Constant(defaultValue, typeof(TSource)) + )); + } + [DynamicDependency("ElementAt`1", typeof(Enumerable))] public static TSource ElementAt(this IQueryable source, int index) { diff --git a/src/libraries/System.Linq/src/System/Linq/First.cs b/src/libraries/System.Linq/src/System/Linq/First.cs index 64b89915c259f8..6751a14e9f1550 100644 --- a/src/libraries/System.Linq/src/System/Linq/First.cs +++ b/src/libraries/System.Linq/src/System/Linq/First.cs @@ -31,12 +31,21 @@ public static TSource First(this IEnumerable source, Func(this IEnumerable source) => - source.TryGetFirst(out bool _); + source.TryGetFirst(out _); + + public static TSource? FirstOrDefault(this IEnumerable source, TSource? defaultValue) => + source.TryGetFirst(defaultValue, out _); public static TSource? FirstOrDefault(this IEnumerable source, Func predicate) => - source.TryGetFirst(predicate, out bool _); + source.TryGetFirst(predicate, out _); + + public static TSource? FirstOrDefault(this IEnumerable source, Func predicate, TSource? defaultValue) => + source.TryGetFirst(predicate, defaultValue, out _); - private static TSource? TryGetFirst(this IEnumerable source, out bool found) + private static TSource? TryGetFirst(this IEnumerable source, out bool found) => + source.TryGetFirst(default(TSource), out found); + + private static TSource? TryGetFirst(this IEnumerable source, TSource? defaultValue, out bool found) { if (source == null) { @@ -69,10 +78,13 @@ public static TSource First(this IEnumerable source, Func(this IEnumerable source, Func predicate, out bool found) + private static TSource? TryGetFirst(this IEnumerable source, Func predicate, out bool found) => + source.TryGetFirst(predicate, default(TSource), out found); + + private static TSource? TryGetFirst(this IEnumerable source, Func predicate, TSource? defaultValue, out bool found) { if (source == null) { @@ -94,7 +106,7 @@ public static TSource First(this IEnumerable source, Func(this IEnumerable source, Func(this IEnumerable source) => - source.TryGetLast(out bool _); + public static TSource? LastOrDefault(this IEnumerable source) + => source.TryGetLast(out _); + public static TSource? LastOrDefault(this IEnumerable source, TSource? defaultValue) + => source.TryGetLast(defaultValue, out _); - public static TSource? LastOrDefault(this IEnumerable source, Func predicate) => - source.TryGetLast(predicate, out bool _); + public static TSource? LastOrDefault(this IEnumerable source, Func predicate) + => source.TryGetLast(predicate, out bool _); + public static TSource? LastOrDefault(this IEnumerable source, Func predicate, TSource? defaultValue) + => source.TryGetLast(predicate, defaultValue, out bool _); private static TSource? TryGetLast(this IEnumerable source, out bool found) + => source.TryGetLast(default(TSource?), out found); + private static TSource? TryGetLast(this IEnumerable source, TSource? defaultValue, out bool found) { if (source == null) { @@ -77,10 +83,12 @@ public static TSource Last(this IEnumerable source, Func(this IEnumerable source, Func predicate, out bool found) + => source.TryGetLast(predicate, default(TSource?), out found); + private static TSource? TryGetLast(this IEnumerable source, Func predicate, TSource? defaultValue, out bool found) { if (source == null) { @@ -135,7 +143,7 @@ public static TSource Last(this IEnumerable source, Func(this IEnumerable source, Func(this IEnumerable source) + => source.SingleOrDefault(default(TSource)); + + public static TSource? SingleOrDefault(this IEnumerable source, TSource? defaultValue) { if (source == null) { @@ -95,7 +98,7 @@ public static TSource Single(this IEnumerable source, Func(this IEnumerable source, Func(this IEnumerable source, Func(this IEnumerable source, Func predicate) + => source.SingleOrDefault(predicate, default); + + public static TSource? SingleOrDefault(this IEnumerable source, Func predicate, TSource? defaultValue) { if (source == null) { @@ -153,7 +159,7 @@ public static TSource Single(this IEnumerable source, Func