From c8b6f6cce31de9df4e25235af42007d3bebcfb02 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Fri, 12 Mar 2021 22:56:24 +0000 Subject: [PATCH 1/3] Implement Enumerable.*By operators --- .../ref/System.Linq.Queryable.cs | 24 +- .../src/System/Linq/CachedReflection.cs | 84 ++++++ .../src/System/Linq/Queryable.cs | 239 ++++++++++++++++++ .../tests/DistinctTests.cs | 36 +++ .../tests/ExceptTests.cs | 46 ++++ .../tests/IntersectTests.cs | 46 ++++ .../System.Linq.Queryable/tests/MaxTests.cs | 49 ++++ .../System.Linq.Queryable/tests/MinTests.cs | 49 ++++ .../tests/TrimCompatibilityTests.cs | 2 +- .../System.Linq.Queryable/tests/UnionTests.cs | 46 ++++ src/libraries/System.Linq/ref/System.Linq.cs | 18 +- .../System.Linq/src/System/Linq/Distinct.cs | 28 ++ .../System.Linq/src/System/Linq/Except.cs | 33 +++ .../System.Linq/src/System/Linq/Intersect.cs | 34 ++- .../System.Linq/src/System/Linq/Max.cs | 117 ++++++++- .../System.Linq/src/System/Linq/Min.cs | 117 ++++++++- .../System.Linq/src/System/Linq/Union.cs | 41 +++ .../System.Linq/tests/DistinctTests.cs | 99 ++++++++ .../System.Linq/tests/ExceptTests.cs | 110 +++++++- .../System.Linq/tests/IntersectTests.cs | 114 +++++++++ src/libraries/System.Linq/tests/MaxTests.cs | 188 ++++++++++++++ src/libraries/System.Linq/tests/MinTests.cs | 188 ++++++++++++++ src/libraries/System.Linq/tests/UnionTests.cs | 121 +++++++++ 23 files changed, 1795 insertions(+), 34 deletions(-) diff --git a/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs b/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs index bd3497d53c6580..a6ae3b48620a9c 100644 --- a/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs +++ b/src/libraries/System.Linq.Queryable/ref/System.Linq.Queryable.cs @@ -77,18 +77,22 @@ public static partial class Queryable public static int Count(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static System.Linq.IQueryable DefaultIfEmpty(this System.Linq.IQueryable source) { throw null; } public static System.Linq.IQueryable DefaultIfEmpty(this System.Linq.IQueryable source, TSource defaultValue) { throw null; } + public static System.Linq.IQueryable DistinctBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector) { throw null; } + public static System.Linq.IQueryable DistinctBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Linq.IQueryable Distinct(this System.Linq.IQueryable source) { throw null; } public static System.Linq.IQueryable Distinct(this System.Linq.IQueryable source, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } - public static TSource? ElementAtOrDefault(this System.Linq.IQueryable source, int index) { throw null; } public static TSource? ElementAtOrDefault(this System.Linq.IQueryable source, System.Index index) { throw null; } - public static TSource ElementAt(this System.Linq.IQueryable source, int index) { throw null; } + public static TSource? ElementAtOrDefault(this System.Linq.IQueryable source, int index) { throw null; } public static TSource ElementAt(this System.Linq.IQueryable source, System.Index index) { throw null; } + public static TSource ElementAt(this System.Linq.IQueryable source, int index) { throw null; } + public static System.Linq.IQueryable ExceptBy(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Linq.Expressions.Expression> keySelector) { throw null; } + public static System.Linq.IQueryable ExceptBy(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Linq.Expressions.Expression> keySelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Linq.IQueryable Except(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2) { throw null; } public static System.Linq.IQueryable Except(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? FirstOrDefault(this System.Linq.IQueryable source) { throw null; } - public static TSource FirstOrDefault(this System.Linq.IQueryable source, TSource defaultValue) { throw null; } public static TSource? FirstOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static TSource FirstOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource defaultValue) { throw null; } + public static TSource FirstOrDefault(this System.Linq.IQueryable source, TSource defaultValue) { throw null; } public static TSource First(this System.Linq.IQueryable source) { throw null; } public static TSource First(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static System.Linq.IQueryable> GroupBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector) { throw null; } @@ -101,21 +105,29 @@ public static partial class Queryable public static System.Linq.IQueryable GroupBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector, System.Linq.Expressions.Expression> elementSelector, System.Linq.Expressions.Expression, TResult>> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Linq.IQueryable GroupJoin(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector, System.Linq.Expressions.Expression, TResult>> resultSelector) { throw null; } public static System.Linq.IQueryable GroupJoin(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector, System.Linq.Expressions.Expression, TResult>> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } + public static System.Linq.IQueryable IntersectBy(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Linq.Expressions.Expression> keySelector) { throw null; } + public static System.Linq.IQueryable IntersectBy(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Linq.Expressions.Expression> keySelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Linq.IQueryable Intersect(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2) { throw null; } public static System.Linq.IQueryable Intersect(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Linq.IQueryable Join(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector, System.Linq.Expressions.Expression> resultSelector) { throw null; } public static System.Linq.IQueryable Join(this System.Linq.IQueryable outer, System.Collections.Generic.IEnumerable inner, System.Linq.Expressions.Expression> outerKeySelector, System.Linq.Expressions.Expression> innerKeySelector, System.Linq.Expressions.Expression> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? LastOrDefault(this System.Linq.IQueryable source) { throw null; } - public static TSource LastOrDefault(this System.Linq.IQueryable source, TSource defaultValue) { throw null; } public static TSource? LastOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static TSource LastOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource defaultValue) { throw null; } + public static TSource LastOrDefault(this System.Linq.IQueryable source, TSource defaultValue) { throw null; } public static TSource Last(this System.Linq.IQueryable source) { throw null; } public static TSource Last(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static long LongCount(this System.Linq.IQueryable source) { throw null; } public static long LongCount(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } + public static TSource? MaxBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector) { throw null; } + public static TSource? MaxBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector, System.Collections.Generic.IComparer? comparer) { throw null; } public static TSource? Max(this System.Linq.IQueryable source) { throw null; } + public static TSource? Max(this System.Linq.IQueryable source, System.Collections.Generic.IComparer? comparer) { throw null; } public static TResult? Max(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> selector) { throw null; } + public static TSource? MinBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector) { throw null; } + public static TSource? MinBy(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector, System.Collections.Generic.IComparer? comparer) { throw null; } public static TSource? Min(this System.Linq.IQueryable source) { throw null; } + public static TSource? Min(this System.Linq.IQueryable source, System.Collections.Generic.IComparer? comparer) { throw null; } public static TResult? Min(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> selector) { throw null; } public static System.Linq.IQueryable OfType(this System.Linq.IQueryable source) { throw null; } public static System.Linq.IOrderedQueryable OrderByDescending(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> keySelector) { throw null; } @@ -133,9 +145,9 @@ public static partial class Queryable public static bool SequenceEqual(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2) { throw null; } public static bool SequenceEqual(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? SingleOrDefault(this System.Linq.IQueryable source) { throw null; } - public static TSource SingleOrDefault(this System.Linq.IQueryable source, TSource defaultValue) { throw null; } public static TSource? SingleOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static TSource SingleOrDefault(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate, TSource defaultValue) { throw null; } + public static TSource SingleOrDefault(this System.Linq.IQueryable source, TSource defaultValue) { throw null; } public static TSource Single(this System.Linq.IQueryable source) { throw null; } public static TSource Single(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } public static System.Linq.IQueryable SkipLast(this System.Linq.IQueryable source, int count) { throw null; } @@ -171,6 +183,8 @@ public static partial class Queryable public static System.Linq.IOrderedQueryable ThenByDescending(this System.Linq.IOrderedQueryable source, System.Linq.Expressions.Expression> keySelector, System.Collections.Generic.IComparer? comparer) { throw null; } public static System.Linq.IOrderedQueryable ThenBy(this System.Linq.IOrderedQueryable source, System.Linq.Expressions.Expression> keySelector) { throw null; } public static System.Linq.IOrderedQueryable ThenBy(this System.Linq.IOrderedQueryable source, System.Linq.Expressions.Expression> keySelector, System.Collections.Generic.IComparer? comparer) { throw null; } + public static System.Linq.IQueryable UnionBy(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Linq.Expressions.Expression> keySelector) { throw null; } + public static System.Linq.IQueryable UnionBy(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Linq.Expressions.Expression> keySelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Linq.IQueryable Union(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2) { throw null; } public static System.Linq.IQueryable Union(this System.Linq.IQueryable source1, System.Collections.Generic.IEnumerable source2, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Linq.IQueryable Where(this System.Linq.IQueryable source, System.Linq.Expressions.Expression> predicate) { throw null; } diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs index be11c326812cb5..3211d2c374b764 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/CachedReflection.cs @@ -224,6 +224,18 @@ public static MethodInfo Distinct_TSource_2(Type TSource) => (s_Distinct_TSource_2 ??= new Func, IEqualityComparer, IQueryable>(Queryable.Distinct).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_DistinctBy_TSource_TKey_2; + + public static MethodInfo DistinctBy_TSource_TKey_2(Type TSource, Type TKey) => + (s_DistinctBy_TSource_TKey_2 ??= new Func, Expression>, IQueryable>(Queryable.DistinctBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + + private static MethodInfo? s_DistinctBy_TSource_TKey_3; + + public static MethodInfo DistinctBy_TSource_TKey_3(Type TSource, Type TKey) => + (s_DistinctBy_TSource_TKey_3 ??= new Func, Expression>, IEqualityComparer, IQueryable>(Queryable.DistinctBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + private static MethodInfo? s_ElementAt_Int32_TSource_2; public static MethodInfo ElementAt_Int32_TSource_2(Type TSource) => @@ -260,6 +272,18 @@ public static MethodInfo Except_TSource_3(Type TSource) => (s_Except_TSource_3 ??= new Func, IEnumerable, IEqualityComparer, IQueryable>(Queryable.Except).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_ExceptBy_TSource_TKey_3; + + public static MethodInfo ExceptBy_TSource_TKey_3(Type TSource, Type TKey) => + (s_ExceptBy_TSource_TKey_3 ??= new Func, IEnumerable, Expression>, IQueryable>(Queryable.ExceptBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + + private static MethodInfo? s_ExceptBy_TSource_TKey_4; + + public static MethodInfo ExceptBy_TSource_TKey_4(Type TSource, Type TKey) => + (s_ExceptBy_TSource_TKey_4 ??= new Func, IEnumerable, Expression>, IEqualityComparer, IQueryable>(Queryable.ExceptBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + private static MethodInfo? s_First_TSource_1; public static MethodInfo First_TSource_1(Type TSource) => @@ -370,6 +394,18 @@ public static MethodInfo Intersect_TSource_3(Type TSource) => (s_Intersect_TSource_3 ??= new Func, IEnumerable, IEqualityComparer, IQueryable>(Queryable.Intersect).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_IntersectBy_TSource_TKey_3; + + public static MethodInfo IntersectBy_TSource_TKey_3(Type TSource, Type TKey) => + (s_IntersectBy_TSource_TKey_3 ??= new Func, IEnumerable, Expression>, IQueryable>(Queryable.IntersectBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + + private static MethodInfo? s_IntersectBy_TSource_TKey_4; + + public static MethodInfo IntersectBy_TSource_TKey_4(Type TSource, Type TKey) => + (s_IntersectBy_TSource_TKey_4 ??= new Func, IEnumerable, Expression>, IEqualityComparer, IQueryable>(Queryable.IntersectBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + private static MethodInfo? s_Join_TOuter_TInner_TKey_TResult_5; public static MethodInfo Join_TOuter_TInner_TKey_TResult_5(Type TOuter, Type TInner, Type TKey, Type TResult) => @@ -438,24 +474,60 @@ public static MethodInfo Max_TSource_1(Type TSource) => (s_Max_TSource_1 ??= new Func, object?>(Queryable.Max).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_Max_TSource_2; + + public static MethodInfo Max_TSource_2(Type TSource) => + (s_Max_TSource_2 ??= new Func, IComparer, object?>(Queryable.Max).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource); + private static MethodInfo? s_Max_TSource_TResult_2; public static MethodInfo Max_TSource_TResult_2(Type TSource, Type TResult) => (s_Max_TSource_TResult_2 ??= new Func, Expression>, object?>(Queryable.Max).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource, TResult); + private static MethodInfo? s_MaxBy_TSource_TKey_2; + + public static MethodInfo MaxBy_TSource_TKey_2(Type TSource, Type TKey) => + (s_MaxBy_TSource_TKey_2 ??= new Func, Expression>, object?>(Queryable.MaxBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + + private static MethodInfo? s_MaxBy_TSource_TKey_3; + + public static MethodInfo MaxBy_TSource_TKey_3(Type TSource, Type TKey) => + (s_MaxBy_TSource_TKey_3 ??= new Func, Expression>, IComparer, object?>(Queryable.MaxBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + private static MethodInfo? s_Min_TSource_1; public static MethodInfo Min_TSource_1(Type TSource) => (s_Min_TSource_1 ??= new Func, object?>(Queryable.Min).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_Min_TSource_2; + + public static MethodInfo Min_TSource_2(Type TSource) => + (s_Min_TSource_2 ??= new Func, IComparer, object?>(Queryable.Min).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource); + private static MethodInfo? s_Min_TSource_TResult_2; public static MethodInfo Min_TSource_TResult_2(Type TSource, Type TResult) => (s_Min_TSource_TResult_2 ??= new Func, Expression>, object?>(Queryable.Min).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource, TResult); + private static MethodInfo? s_MinBy_TSource_TKey_2; + + public static MethodInfo MinBy_TSource_TKey_2(Type TSource, Type TKey) => + (s_MinBy_TSource_TKey_2 ??= new Func, Expression>, object?>(Queryable.MinBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + + private static MethodInfo? s_MinBy_TSource_TKey_3; + + public static MethodInfo MinBy_TSource_TKey_3(Type TSource, Type TKey) => + (s_MinBy_TSource_TKey_3 ??= new Func, Expression>, IComparer, object?>(Queryable.MinBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + private static MethodInfo? s_OfType_TResult_1; public static MethodInfo OfType_TResult_1(Type TResult) => @@ -766,6 +838,18 @@ public static MethodInfo Union_TSource_3(Type TSource) => (s_Union_TSource_3 ??= new Func, IEnumerable, IEqualityComparer, IQueryable>(Queryable.Union).GetMethodInfo().GetGenericMethodDefinition()) .MakeGenericMethod(TSource); + private static MethodInfo? s_UnionBy_TSource_TKey_3; + + public static MethodInfo UnionBy_TSource_TKey_3(Type TSource, Type TKey) => + (s_UnionBy_TSource_TKey_3 ??= new Func, IEnumerable, Expression>, IQueryable>(Queryable.UnionBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + + private static MethodInfo? s_UnionBy_TSource_TKey_4; + + public static MethodInfo UnionBy_TSource_TKey_4(Type TSource, Type TKey) => + (s_UnionBy_TSource_TKey_4 ??= new Func, IEnumerable, Expression>, IEqualityComparer, IQueryable>(Queryable.UnionBy).GetMethodInfo().GetGenericMethodDefinition()) + .MakeGenericMethod(TSource, TKey); + private static MethodInfo? s_Where_TSource_2; public static MethodInfo Where_TSource_2(Type TSource) => diff --git a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs index 36a7789d464f84..2917955e9d2695 100644 --- a/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs +++ b/src/libraries/System.Linq.Queryable/src/System/Linq/Queryable.cs @@ -639,6 +639,36 @@ public static IQueryable Distinct(this IQueryable sou )); } + [DynamicDependency("DistinctBy`2", typeof(Enumerable))] + public static IQueryable DistinctBy(this IQueryable source, Expression> keySelector) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source.Provider.CreateQuery( + Expression.Call( + null, + CachedReflectionInfo.DistinctBy_TSource_TKey_2(typeof(TSource), typeof(TKey)), + source.Expression, Expression.Quote(keySelector) + )); + } + + [DynamicDependency("DistinctBy`2", typeof(Enumerable))] + public static IQueryable DistinctBy(this IQueryable source, Expression> keySelector, IEqualityComparer? comparer) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source.Provider.CreateQuery( + Expression.Call( + null, + CachedReflectionInfo.DistinctBy_TSource_TKey_3(typeof(TSource), typeof(TKey)), + source.Expression, Expression.Quote(keySelector), Expression.Constant(comparer, typeof(IEqualityComparer)) + )); + } + [DynamicDependency("Chunk`1", typeof(Enumerable))] public static IQueryable Chunk(this IQueryable source, int size) { @@ -763,6 +793,43 @@ public static IQueryable Union(this IQueryable source )); } + [DynamicDependency("UnionBy`2", typeof(Enumerable))] + public static IQueryable UnionBy(this IQueryable source1, IEnumerable source2, Expression> keySelector) + { + if (source1 == null) + throw Error.ArgumentNull(nameof(source1)); + if (source2 == null) + throw Error.ArgumentNull(nameof(source2)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source1.Provider.CreateQuery( + Expression.Call( + null, + CachedReflectionInfo.UnionBy_TSource_TKey_3(typeof(TSource), typeof(TKey)), + source1.Expression, GetSourceExpression(source2), Expression.Quote(keySelector) + )); + } + + [DynamicDependency("UnionBy`2", typeof(Enumerable))] + public static IQueryable UnionBy(this IQueryable source1, IEnumerable source2, Expression> keySelector, IEqualityComparer? comparer) + { + if (source1 == null) + throw Error.ArgumentNull(nameof(source1)); + if (source2 == null) + throw Error.ArgumentNull(nameof(source2)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source1.Provider.CreateQuery( + Expression.Call( + null, + CachedReflectionInfo.UnionBy_TSource_TKey_4(typeof(TSource), typeof(TKey)), + source1.Expression, + GetSourceExpression(source2), + Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IEqualityComparer)) + )); + } + [DynamicDependency("Intersect`1", typeof(Enumerable))] public static IQueryable Intersect(this IQueryable source1, IEnumerable source2) { @@ -795,6 +862,45 @@ public static IQueryable Intersect(this IQueryable so )); } + [DynamicDependency("IntersectBy`2", typeof(Enumerable))] + public static IQueryable IntersectBy(this IQueryable source1, IEnumerable source2, Expression> keySelector) + { + if (source1 == null) + throw Error.ArgumentNull(nameof(source1)); + if (source2 == null) + throw Error.ArgumentNull(nameof(source2)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source1.Provider.CreateQuery( + Expression.Call( + null, + CachedReflectionInfo.IntersectBy_TSource_TKey_3(typeof(TSource), typeof(TKey)), + source1.Expression, + GetSourceExpression(source2), + Expression.Quote(keySelector) + )); + } + + [DynamicDependency("IntersectBy`2", typeof(Enumerable))] + public static IQueryable IntersectBy(this IQueryable source1, IEnumerable source2, Expression> keySelector, IEqualityComparer? comparer) + { + if (source1 == null) + throw Error.ArgumentNull(nameof(source1)); + if (source2 == null) + throw Error.ArgumentNull(nameof(source2)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source1.Provider.CreateQuery( + Expression.Call( + null, + CachedReflectionInfo.IntersectBy_TSource_TKey_4(typeof(TSource), typeof(TKey)), + source1.Expression, + GetSourceExpression(source2), + Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IEqualityComparer)) + )); + } + [DynamicDependency("Except`1", typeof(Enumerable))] public static IQueryable Except(this IQueryable source1, IEnumerable source2) { @@ -827,6 +933,45 @@ public static IQueryable Except(this IQueryable sourc )); } + [DynamicDependency("ExceptBy`2", typeof(Enumerable))] + public static IQueryable ExceptBy(this IQueryable source1, IEnumerable source2, Expression> keySelector) + { + if (source1 == null) + throw Error.ArgumentNull(nameof(source1)); + if (source2 == null) + throw Error.ArgumentNull(nameof(source2)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source1.Provider.CreateQuery( + Expression.Call( + null, + CachedReflectionInfo.ExceptBy_TSource_TKey_3(typeof(TSource), typeof(TKey)), + source1.Expression, + GetSourceExpression(source2), + Expression.Quote(keySelector) + )); + } + + [DynamicDependency("ExceptBy`2", typeof(Enumerable))] + public static IQueryable ExceptBy(this IQueryable source1, IEnumerable source2, Expression> keySelector, IEqualityComparer? comparer) + { + if (source1 == null) + throw Error.ArgumentNull(nameof(source1)); + if (source2 == null) + throw Error.ArgumentNull(nameof(source2)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source1.Provider.CreateQuery( + Expression.Call( + null, + CachedReflectionInfo.ExceptBy_TSource_TKey_4(typeof(TSource), typeof(TKey)), + source1.Expression, + GetSourceExpression(source2), + Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IEqualityComparer)) + )); + } + [DynamicDependency("First`1", typeof(Enumerable))] public static TSource First(this IQueryable source) { @@ -1339,6 +1484,20 @@ public static long LongCount(this IQueryable source, Expressio CachedReflectionInfo.Min_TSource_1(typeof(TSource)), source.Expression)); } + [DynamicDependency("Min`1", typeof(Enumerable))] + public static TSource? Min(this IQueryable source, IComparer? comparer) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.Min_TSource_2(typeof(TSource)), + source.Expression, + Expression.Constant(comparer, typeof(IComparer)) + )); + } + [DynamicDependency("Min`2", typeof(Enumerable))] public static TResult? Min(this IQueryable source, Expression> selector) { @@ -1354,6 +1513,39 @@ public static long LongCount(this IQueryable source, Expressio )); } + [DynamicDependency("MinBy`2", typeof(Enumerable))] + public static TSource? MinBy(this IQueryable source, Expression> keySelector) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.MinBy_TSource_TKey_2(typeof(TSource), typeof(TKey)), + source.Expression, + Expression.Quote(keySelector) + )); + } + + [DynamicDependency("MinBy`2", typeof(Enumerable))] + public static TSource? MinBy(this IQueryable source, Expression> keySelector, IComparer? comparer) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.MinBy_TSource_TKey_3(typeof(TSource), typeof(TKey)), + source.Expression, + Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IComparer)) + )); + } + [DynamicDependency("Max`1", typeof(Enumerable))] public static TSource? Max(this IQueryable source) { @@ -1365,6 +1557,20 @@ public static long LongCount(this IQueryable source, Expressio CachedReflectionInfo.Max_TSource_1(typeof(TSource)), source.Expression)); } + [DynamicDependency("Max`1", typeof(Enumerable))] + public static TSource? Max(this IQueryable source, IComparer? comparer) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.Max_TSource_2(typeof(TSource)), + source.Expression, + Expression.Constant(comparer, typeof(IComparer)) + )); + } + [DynamicDependency("Max`2", typeof(Enumerable))] public static TResult? Max(this IQueryable source, Expression> selector) { @@ -1380,6 +1586,39 @@ public static long LongCount(this IQueryable source, Expressio )); } + [DynamicDependency("MaxBy`2", typeof(Enumerable))] + public static TSource? MaxBy(this IQueryable source, Expression> keySelector) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.MaxBy_TSource_TKey_2(typeof(TSource), typeof(TKey)), + source.Expression, + Expression.Quote(keySelector) + )); + } + + [DynamicDependency("MaxBy`2", typeof(Enumerable))] + public static TSource? MaxBy(this IQueryable source, Expression> keySelector, IComparer? comparer) + { + if (source == null) + throw Error.ArgumentNull(nameof(source)); + if (keySelector == null) + throw Error.ArgumentNull(nameof(keySelector)); + return source.Provider.Execute( + Expression.Call( + null, + CachedReflectionInfo.MaxBy_TSource_TKey_3(typeof(TSource), typeof(TKey)), + source.Expression, + Expression.Quote(keySelector), + Expression.Constant(comparer, typeof(IComparer)) + )); + } + [DynamicDependency("Sum", typeof(Enumerable))] public static int Sum(this IQueryable source) { diff --git a/src/libraries/System.Linq.Queryable/tests/DistinctTests.cs b/src/libraries/System.Linq.Queryable/tests/DistinctTests.cs index b252b287bfde60..44456901990320 100644 --- a/src/libraries/System.Linq.Queryable/tests/DistinctTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/DistinctTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Linq.Expressions; using Xunit; namespace System.Linq.Tests @@ -77,5 +78,40 @@ public void Distinct2() var count = (new int[] { 0, 1, 2, 2, 0 }).AsQueryable().Distinct(EqualityComparer.Default).Count(); Assert.Equal(3, count); } + + [Fact] + public void DistinctBy_NullSource_ThrowsArgumentNullException() + { + IQueryable source = null; + + AssertExtensions.Throws("source", () => source.DistinctBy(x => x)); + AssertExtensions.Throws("source", () => source.DistinctBy(x => x, EqualityComparer.Default)); + } + + [Fact] + public void DistinctBy_NullKeySelector_ThrowsArgumentNullException() + { + IQueryable source = Enumerable.Empty().AsQueryable(); + Expression> keySelector = null; + + AssertExtensions.Throws("keySelector", () => source.DistinctBy(keySelector)); + AssertExtensions.Throws("keySelector", () => source.DistinctBy(keySelector, EqualityComparer.Default)); + } + + [Fact] + public void DistinctBy() + { + var expected = Enumerable.Range(0, 3); + var actual = Enumerable.Range(0, 20).AsQueryable().DistinctBy(x => x % 3).ToArray(); + Assert.Equal(expected, actual); + } + + [Fact] + public void DistinctBy_CustomComparison() + { + var expected = Enumerable.Range(0, 3); + var actual = Enumerable.Range(0, 20).AsQueryable().DistinctBy(x => x % 3, EqualityComparer.Default).ToArray(); + Assert.Equal(expected, actual); + } } } diff --git a/src/libraries/System.Linq.Queryable/tests/ExceptTests.cs b/src/libraries/System.Linq.Queryable/tests/ExceptTests.cs index b7aa76cdd463b4..7ccd478d1b29f0 100644 --- a/src/libraries/System.Linq.Queryable/tests/ExceptTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/ExceptTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Linq.Expressions; using Xunit; namespace System.Linq.Tests @@ -85,5 +86,50 @@ public void Except2() var count = (new int[] { 0, 1, 2 }).AsQueryable().Except((new int[] { 1, 2, 3 }).AsQueryable(), EqualityComparer.Default).Count(); Assert.Equal(1, count); } + + [Fact] + public void ExceptBy_NullSource1_ThrowsArgumentNullException() + { + IQueryable source1 = null; + + AssertExtensions.Throws("source1", () => source1.ExceptBy(Enumerable.Empty(), x => x)); + AssertExtensions.Throws("source1", () => source1.ExceptBy(Enumerable.Empty(), x => x, EqualityComparer.Default)); + } + + [Fact] + public void ExceptBy_NullSource2_ThrowsArgumentNullException() + { + IQueryable source1 = Enumerable.Empty().AsQueryable(); + IQueryable source2 = null; + + AssertExtensions.Throws("source2", () => source1.ExceptBy(source2, x => x)); + AssertExtensions.Throws("source2", () => source1.ExceptBy(source2, x => x, EqualityComparer.Default)); + } + + [Fact] + public void ExceptBy_NullKeySelector_ThrowsArgumentNullException() + { + IQueryable source = Enumerable.Empty().AsQueryable(); + Expression> keySelector = null; + + AssertExtensions.Throws("keySelector", () => source.ExceptBy(source, keySelector)); + AssertExtensions.Throws("keySelector", () => source.ExceptBy(source, keySelector, EqualityComparer.Default)); + } + + [Fact] + public void ExceptBy() + { + var expected = Enumerable.Range(5, 5); + var actual = Enumerable.Range(0, 10).AsQueryable().ExceptBy(Enumerable.Range(0, 5), x => x).ToArray(); + Assert.Equal(expected, actual); + } + + [Fact] + public void ExceptBy_CustomComparison() + { + var expected = Enumerable.Range(5, 5); + var actual = Enumerable.Range(0, 10).AsQueryable().ExceptBy(Enumerable.Range(0, 5), x => x, EqualityComparer.Default).ToArray(); + Assert.Equal(expected, actual); + } } } diff --git a/src/libraries/System.Linq.Queryable/tests/IntersectTests.cs b/src/libraries/System.Linq.Queryable/tests/IntersectTests.cs index 0731d99aa8b425..184a618f64ef3b 100644 --- a/src/libraries/System.Linq.Queryable/tests/IntersectTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/IntersectTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Linq.Expressions; using Xunit; namespace System.Linq.Tests @@ -83,5 +84,50 @@ public void Intersect2() var count = (new int[] { 0, 1, 2 }).AsQueryable().Intersect((new int[] { 1, 2, 3 }).AsQueryable(), EqualityComparer.Default).Count(); Assert.Equal(2, count); } + + [Fact] + public void IntersectBy_NullSource1_ThrowsArgumentNullException() + { + IQueryable source1 = null; + + AssertExtensions.Throws("source1", () => source1.IntersectBy(Enumerable.Empty(), x => x)); + AssertExtensions.Throws("source1", () => source1.IntersectBy(Enumerable.Empty(), x => x, EqualityComparer.Default)); + } + + [Fact] + public void IntersectBy_NullSource2_ThrowsArgumentNullException() + { + IQueryable source1 = Enumerable.Empty().AsQueryable(); + IQueryable source2 = null; + + AssertExtensions.Throws("source2", () => source1.IntersectBy(source2, x => x)); + AssertExtensions.Throws("source2", () => source1.IntersectBy(source2, x => x, EqualityComparer.Default)); + } + + [Fact] + public void IntersectBy_NullKeySelector_ThrowsArgumentNullException() + { + IQueryable source = Enumerable.Empty().AsQueryable(); + Expression> keySelector = null; + + AssertExtensions.Throws("keySelector", () => source.IntersectBy(source, keySelector)); + AssertExtensions.Throws("keySelector", () => source.IntersectBy(source, keySelector, EqualityComparer.Default)); + } + + [Fact] + public void IntersectBy() + { + var expected = Enumerable.Range(5, 5); + var actual = Enumerable.Range(0, 10).AsQueryable().IntersectBy(Enumerable.Range(5, 20), x => x).ToArray(); + Assert.Equal(expected, actual); + } + + [Fact] + public void IntersectBy_CustomComparison() + { + var expected = Enumerable.Range(5, 5); + var actual = Enumerable.Range(0, 10).AsQueryable().IntersectBy(Enumerable.Range(5, 20), x => x, EqualityComparer.Default).ToArray(); + Assert.Equal(expected, actual); + } } } diff --git a/src/libraries/System.Linq.Queryable/tests/MaxTests.cs b/src/libraries/System.Linq.Queryable/tests/MaxTests.cs index 4354d895d3ee33..84a56d5b856d5a 100644 --- a/src/libraries/System.Linq.Queryable/tests/MaxTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/MaxTests.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; using System.Linq.Expressions; using Xunit; @@ -589,5 +590,53 @@ public void Max2() var val = (new int[] { 0, 2, 1 }).AsQueryable().Max(n => n); Assert.Equal(2, val); } + + [Fact] + public void Max_CustomComparer_NullSource_ThrowsArgumentNullException() + { + IQueryable source = null; + AssertExtensions.Throws("source", () => source.Max(Comparer.Default)); + } + + [Fact] + public void Max_CustomComparer() + { + IComparer comparer = Comparer.Create((x, y) => -x.CompareTo(y)); + IQueryable source = Enumerable.Range(1, 10).AsQueryable(); + Assert.Equal(1, source.Max(comparer)); + } + + [Fact] + public void MaxBy_NullSource_ThrowsArgumentNullException() + { + IQueryable source = null; + + AssertExtensions.Throws("source", () => source.MaxBy(x => x)); + AssertExtensions.Throws("source", () => source.MaxBy(x => x, Comparer.Default)); + } + + [Fact] + public void MaxBy_NullKeySelector_ThrowsArgumentNullException() + { + IQueryable source = Enumerable.Empty().AsQueryable(); + Expression> keySelector = null; + + AssertExtensions.Throws("keySelector", () => source.MaxBy(keySelector)); + AssertExtensions.Throws("keySelector", () => source.MaxBy(keySelector, Comparer.Default)); + } + + [Fact] + public void MaxBy() + { + IQueryable source = Enumerable.Range(1, 20).AsQueryable(); + Assert.Equal(1, source.MaxBy(x => -x)); + } + + [Fact] + public void MaxBy_CustomComparer() + { + IQueryable source = Enumerable.Range(1, 20).AsQueryable(); + Assert.Equal(20, source.MaxBy(x => -x, Comparer.Create((x, y) => -x.CompareTo(y)))); + } } } diff --git a/src/libraries/System.Linq.Queryable/tests/MinTests.cs b/src/libraries/System.Linq.Queryable/tests/MinTests.cs index 10bf12b4e17c97..8e9a8a6680b869 100644 --- a/src/libraries/System.Linq.Queryable/tests/MinTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/MinTests.cs @@ -1,6 +1,7 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Collections.Generic; using System.Linq.Expressions; using Xunit; @@ -557,5 +558,53 @@ public void Min2() var val = (new int[] { 0, 2, 1 }).AsQueryable().Min(n => n); Assert.Equal(0, val); } + + [Fact] + public void Min_CustomComparer_NullSource_ThrowsArgumentNullException() + { + IQueryable source = null; + AssertExtensions.Throws("source", () => source.Min(Comparer.Default)); + } + + [Fact] + public void Min_CustomComparer() + { + IComparer comparer = Comparer.Create((x, y) => -x.CompareTo(y)); + IQueryable source = Enumerable.Range(1, 10).AsQueryable(); + Assert.Equal(10, source.Min(comparer)); + } + + [Fact] + public void MinBy_NullSource_ThrowsArgumentNullException() + { + IQueryable source = null; + + AssertExtensions.Throws("source", () => source.MinBy(x => x)); + AssertExtensions.Throws("source", () => source.MinBy(x => x, Comparer.Default)); + } + + [Fact] + public void MinBy_NullKeySelector_ThrowsArgumentNullException() + { + IQueryable source = Enumerable.Empty().AsQueryable(); + Expression> keySelector = null; + + AssertExtensions.Throws("keySelector", () => source.MinBy(keySelector)); + AssertExtensions.Throws("keySelector", () => source.MinBy(keySelector, Comparer.Default)); + } + + [Fact] + public void MinBy() + { + IQueryable source = Enumerable.Range(1, 20).AsQueryable(); + Assert.Equal(20, source.MinBy(x => -x)); + } + + [Fact] + public void MinBy_CustomComparer() + { + IQueryable source = Enumerable.Range(1, 20).AsQueryable(); + Assert.Equal(1, source.MinBy(x => -x, Comparer.Create((x, y) => -x.CompareTo(y)))); + } } } diff --git a/src/libraries/System.Linq.Queryable/tests/TrimCompatibilityTests.cs b/src/libraries/System.Linq.Queryable/tests/TrimCompatibilityTests.cs index b1463804c69086..6a4583253edbec 100644 --- a/src/libraries/System.Linq.Queryable/tests/TrimCompatibilityTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/TrimCompatibilityTests.cs @@ -61,7 +61,7 @@ public static void CachedReflectionInfoMethodsNoAnnotations() .Where(m => m.GetParameters().Length > 0); // If you are adding a new method to this class, ensure the method meets these requirements - Assert.Equal(117, methods.Count()); + Assert.Equal(131, methods.Count()); foreach (MethodInfo method in methods) { ParameterInfo[] parameters = method.GetParameters(); diff --git a/src/libraries/System.Linq.Queryable/tests/UnionTests.cs b/src/libraries/System.Linq.Queryable/tests/UnionTests.cs index fe63038371739e..1d6d9ae6afe0a2 100644 --- a/src/libraries/System.Linq.Queryable/tests/UnionTests.cs +++ b/src/libraries/System.Linq.Queryable/tests/UnionTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Collections.Generic; +using System.Linq.Expressions; using Xunit; namespace System.Linq.Tests @@ -79,5 +80,50 @@ public void Union2() var count = (new int[] { 0, 1, 2 }).AsQueryable().Union((new int[] { 1, 2, 3 }).AsQueryable(), EqualityComparer.Default).Count(); Assert.Equal(4, count); } + + [Fact] + public void UnionBy_NullSource1_ThrowsArgumentNullException() + { + IQueryable source1 = null; + + AssertExtensions.Throws("source1", () => source1.UnionBy(Enumerable.Empty(), x => x)); + AssertExtensions.Throws("source1", () => source1.UnionBy(Enumerable.Empty(), x => x, EqualityComparer.Default)); + } + + [Fact] + public void UnionBy_NullSource2_ThrowsArgumentNullException() + { + IQueryable source1 = Enumerable.Empty().AsQueryable(); + IQueryable source2 = null; + + AssertExtensions.Throws("source2", () => source1.UnionBy(source2, x => x)); + AssertExtensions.Throws("source2", () => source1.UnionBy(source2, x => x, EqualityComparer.Default)); + } + + [Fact] + public void UnionBy_NullKeySelector_ThrowsArgumentNullException() + { + IQueryable source = Enumerable.Empty().AsQueryable(); + Expression> keySelector = null; + + AssertExtensions.Throws("keySelector", () => source.UnionBy(source, keySelector)); + AssertExtensions.Throws("keySelector", () => source.UnionBy(source, keySelector, EqualityComparer.Default)); + } + + [Fact] + public void UnionBy() + { + var expected = Enumerable.Range(0, 10); + var actual = Enumerable.Range(0, 5).AsQueryable().UnionBy(Enumerable.Range(5, 5), x => x).ToArray(); + Assert.Equal(expected, actual); + } + + [Fact] + public void UnionBy_CustomComparison() + { + var expected = Enumerable.Range(0, 10); + var actual = Enumerable.Range(0, 5).AsQueryable().UnionBy(Enumerable.Range(5, 5), x => x, EqualityComparer.Default).ToArray(); + Assert.Equal(expected, actual); + } } } diff --git a/src/libraries/System.Linq/ref/System.Linq.cs b/src/libraries/System.Linq/ref/System.Linq.cs index 0f09240a0249de..5b2d9996512aea 100644 --- a/src/libraries/System.Linq/ref/System.Linq.cs +++ b/src/libraries/System.Linq/ref/System.Linq.cs @@ -49,13 +49,17 @@ public static System.Collections.Generic.IEnumerable< public static int Count(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } public static System.Collections.Generic.IEnumerable DefaultIfEmpty(this System.Collections.Generic.IEnumerable source) { throw null; } public static System.Collections.Generic.IEnumerable DefaultIfEmpty(this System.Collections.Generic.IEnumerable source, TSource defaultValue) { throw null; } + public static System.Collections.Generic.IEnumerable DistinctBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector) { throw null; } + public static System.Collections.Generic.IEnumerable DistinctBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Collections.Generic.IEnumerable Distinct(this System.Collections.Generic.IEnumerable source) { throw null; } public static System.Collections.Generic.IEnumerable Distinct(this System.Collections.Generic.IEnumerable source, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } - public static TSource? ElementAtOrDefault(this System.Collections.Generic.IEnumerable source, int index) { throw null; } public static TSource? ElementAtOrDefault(this System.Collections.Generic.IEnumerable source, System.Index index) { throw null; } - public static TSource ElementAt(this System.Collections.Generic.IEnumerable source, int index) { throw null; } + public static TSource? ElementAtOrDefault(this System.Collections.Generic.IEnumerable source, int index) { throw null; } public static TSource ElementAt(this System.Collections.Generic.IEnumerable source, System.Index index) { throw null; } + public static TSource ElementAt(this System.Collections.Generic.IEnumerable source, int index) { throw null; } public static System.Collections.Generic.IEnumerable Empty() { throw null; } + public static System.Collections.Generic.IEnumerable ExceptBy(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Func keySelector) { throw null; } + public static System.Collections.Generic.IEnumerable ExceptBy(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Collections.Generic.IEnumerable Except(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second) { throw null; } public static System.Collections.Generic.IEnumerable Except(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static TSource? FirstOrDefault(this System.Collections.Generic.IEnumerable source) { throw null; } @@ -74,6 +78,8 @@ public static System.Collections.Generic.IEnumerable< public static System.Collections.Generic.IEnumerable GroupBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Func elementSelector, System.Func, TResult> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Collections.Generic.IEnumerable GroupJoin(this System.Collections.Generic.IEnumerable outer, System.Collections.Generic.IEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func, TResult> resultSelector) { throw null; } public static System.Collections.Generic.IEnumerable GroupJoin(this System.Collections.Generic.IEnumerable outer, System.Collections.Generic.IEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func, TResult> resultSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } + public static System.Collections.Generic.IEnumerable IntersectBy(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Func keySelector) { throw null; } + public static System.Collections.Generic.IEnumerable IntersectBy(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Collections.Generic.IEnumerable Intersect(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second) { throw null; } public static System.Collections.Generic.IEnumerable Intersect(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Collections.Generic.IEnumerable Join(this System.Collections.Generic.IEnumerable outer, System.Collections.Generic.IEnumerable inner, System.Func outerKeySelector, System.Func innerKeySelector, System.Func resultSelector) { throw null; } @@ -96,7 +102,10 @@ public static System.Collections.Generic.IEnumerable< public static long? Max(this System.Collections.Generic.IEnumerable source) { throw null; } public static float? Max(this System.Collections.Generic.IEnumerable source) { throw null; } public static float Max(this System.Collections.Generic.IEnumerable source) { throw null; } + public static TSource? MaxBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector) { throw null; } + public static TSource? MaxBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Collections.Generic.IComparer? comparer) { throw null; } public static TSource? Max(this System.Collections.Generic.IEnumerable source) { throw null; } + public static TSource? Max(this System.Collections.Generic.IEnumerable source, System.Collections.Generic.IComparer? comparer) { throw null; } public static decimal Max(this System.Collections.Generic.IEnumerable source, System.Func selector) { throw null; } public static double Max(this System.Collections.Generic.IEnumerable source, System.Func selector) { throw null; } public static int Max(this System.Collections.Generic.IEnumerable source, System.Func selector) { throw null; } @@ -118,7 +127,10 @@ public static System.Collections.Generic.IEnumerable< public static long? Min(this System.Collections.Generic.IEnumerable source) { throw null; } public static float? Min(this System.Collections.Generic.IEnumerable source) { throw null; } public static float Min(this System.Collections.Generic.IEnumerable source) { throw null; } + public static TSource? MinBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector) { throw null; } + public static TSource? MinBy(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Collections.Generic.IComparer? comparer) { throw null; } public static TSource? Min(this System.Collections.Generic.IEnumerable source) { throw null; } + public static TSource? Min(this System.Collections.Generic.IEnumerable source, System.Collections.Generic.IComparer? comparer) { throw null; } public static decimal Min(this System.Collections.Generic.IEnumerable source, System.Func selector) { throw null; } public static double Min(this System.Collections.Generic.IEnumerable source, System.Func selector) { throw null; } public static int Min(this System.Collections.Generic.IEnumerable source, System.Func selector) { throw null; } @@ -199,6 +211,8 @@ public static System.Collections.Generic.IEnumerable< public static System.Linq.ILookup ToLookup(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Func elementSelector) { throw null; } public static System.Linq.ILookup ToLookup(this System.Collections.Generic.IEnumerable source, System.Func keySelector, System.Func elementSelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static bool TryGetNonEnumeratedCount(this System.Collections.Generic.IEnumerable source, out int count) { throw null; } + public static System.Collections.Generic.IEnumerable UnionBy(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Func keySelector) { throw null; } + public static System.Collections.Generic.IEnumerable UnionBy(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Func keySelector, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Collections.Generic.IEnumerable Union(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second) { throw null; } public static System.Collections.Generic.IEnumerable Union(this System.Collections.Generic.IEnumerable first, System.Collections.Generic.IEnumerable second, System.Collections.Generic.IEqualityComparer? comparer) { throw null; } public static System.Collections.Generic.IEnumerable Where(this System.Collections.Generic.IEnumerable source, System.Func predicate) { throw null; } diff --git a/src/libraries/System.Linq/src/System/Linq/Distinct.cs b/src/libraries/System.Linq/src/System/Linq/Distinct.cs index 5c59fc8164d32c..5d5af39b306c30 100644 --- a/src/libraries/System.Linq/src/System/Linq/Distinct.cs +++ b/src/libraries/System.Linq/src/System/Linq/Distinct.cs @@ -20,6 +20,34 @@ public static IEnumerable Distinct(this IEnumerable s return new DistinctIterator(source, comparer); } + public static IEnumerable DistinctBy(this IEnumerable source, Func keySelector) => DistinctBy(source, keySelector, null); + + public static IEnumerable DistinctBy(this IEnumerable source, Func keySelector, IEqualityComparer? comparer) + { + if (source is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); + } + if (keySelector is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector); + } + + return DistinctByIterator(source, keySelector, comparer); + } + + private static IEnumerable DistinctByIterator(IEnumerable source, Func keySelector, IEqualityComparer? comparer) + { + var set = new HashSet(DefaultInternalSetCapacity, comparer); + foreach (TSource element in source) + { + if (set.Add(keySelector(element))) + { + yield return element; + } + } + } + /// /// An iterator that yields the distinct values in an . /// diff --git a/src/libraries/System.Linq/src/System/Linq/Except.cs b/src/libraries/System.Linq/src/System/Linq/Except.cs index b3e0f45075d2d1..f590cfd5488d0b 100644 --- a/src/libraries/System.Linq/src/System/Linq/Except.cs +++ b/src/libraries/System.Linq/src/System/Linq/Except.cs @@ -37,6 +37,26 @@ public static IEnumerable Except(this IEnumerable fir return ExceptIterator(first, second, comparer); } + public static IEnumerable ExceptBy(this IEnumerable first, IEnumerable second, Func keySelector) => ExceptBy(first, second, keySelector, null); + + public static IEnumerable ExceptBy(this IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer) + { + if (first is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.first); + } + if (second is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.second); + } + if (keySelector is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector); + } + + return ExceptByIterator(first, second, keySelector, comparer); + } + private static IEnumerable ExceptIterator(IEnumerable first, IEnumerable second, IEqualityComparer? comparer) { var set = new HashSet(second, comparer); @@ -49,5 +69,18 @@ private static IEnumerable ExceptIterator(IEnumerable } } } + + private static IEnumerable ExceptByIterator(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer) + { + var set = new HashSet(second, comparer); + + foreach (TSource element in first) + { + if (set.Add(keySelector(element))) + { + yield return element; + } + } + } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Intersect.cs b/src/libraries/System.Linq/src/System/Linq/Intersect.cs index 88c679106aba2d..a9519a6e73a7b0 100644 --- a/src/libraries/System.Linq/src/System/Linq/Intersect.cs +++ b/src/libraries/System.Linq/src/System/Linq/Intersect.cs @@ -7,7 +7,9 @@ namespace System.Linq { public static partial class Enumerable { - public static IEnumerable Intersect(this IEnumerable first, IEnumerable second) + public static IEnumerable Intersect(this IEnumerable first, IEnumerable second) => Intersect(first, second, null); + + public static IEnumerable Intersect(this IEnumerable first, IEnumerable second, IEqualityComparer? comparer) { if (first == null) { @@ -19,22 +21,27 @@ public static IEnumerable Intersect(this IEnumerable ThrowHelper.ThrowArgumentNullException(ExceptionArgument.second); } - return IntersectIterator(first, second, null); + return IntersectIterator(first, second, comparer); } - public static IEnumerable Intersect(this IEnumerable first, IEnumerable second, IEqualityComparer? comparer) + public static IEnumerable IntersectBy(this IEnumerable first, IEnumerable second, Func keySelector) => IntersectBy(first, second, keySelector, null); + + public static IEnumerable IntersectBy(this IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer) { - if (first == null) + if (first is null) { ThrowHelper.ThrowArgumentNullException(ExceptionArgument.first); } - - if (second == null) + if (second is null) { ThrowHelper.ThrowArgumentNullException(ExceptionArgument.second); } + if (keySelector is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector); + } - return IntersectIterator(first, second, comparer); + return IntersectByIterator(first, second, keySelector, comparer); } private static IEnumerable IntersectIterator(IEnumerable first, IEnumerable second, IEqualityComparer? comparer) @@ -49,5 +56,18 @@ private static IEnumerable IntersectIterator(IEnumerable IntersectByIterator(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer) + { + var set = new HashSet(second, comparer); + + foreach (TSource element in first) + { + if (set.Remove(keySelector(element))) + { + yield return element; + } + } + } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Max.cs b/src/libraries/System.Linq/src/System/Linq/Max.cs index c430306156400c..8e9f918d0724cb 100644 --- a/src/libraries/System.Linq/src/System/Linq/Max.cs +++ b/src/libraries/System.Linq/src/System/Linq/Max.cs @@ -441,13 +441,16 @@ public static decimal Max(this IEnumerable source) return value; } - public static TSource? Max(this IEnumerable source) + public static TSource? Max(this IEnumerable source) => Max(source, comparer: null); + public static TSource? Max(this IEnumerable source, IComparer? comparer) { if (source == null) { ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); } + comparer ??= Comparer.Default; + TSource? value = default; using (IEnumerator e = source.GetEnumerator()) { @@ -464,13 +467,12 @@ public static decimal Max(this IEnumerable source) } while (value == null); - Comparer comparer = Comparer.Default; while (e.MoveNext()) { - TSource x = e.Current; - if (x != null && comparer.Compare(x, value) > 0) + TSource next = e.Current; + if (next != null && comparer.Compare(next, value) > 0) { - value = x; + value = next; } } } @@ -482,12 +484,111 @@ public static decimal Max(this IEnumerable source) } value = e.Current; + if (comparer == Comparer.Default) + { + while (e.MoveNext()) + { + TSource next = e.Current; + if (Comparer.Default.Compare(next, value) > 0) + { + value = next; + } + } + } + else + { + while (e.MoveNext()) + { + TSource next = e.Current; + if (comparer.Compare(next, value) > 0) + { + value = next; + } + } + } + } + } + + return value; + } + + public static TSource? MaxBy(this IEnumerable source, Func keySelector) => MaxBy(source, keySelector, null); + public static TSource? MaxBy(this IEnumerable source, Func keySelector, IComparer? comparer) + { + if (source == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); + } + + if (keySelector == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector); + } + + comparer ??= Comparer.Default; + + TKey? key = default; + TSource? value = default; + using (IEnumerator e = source.GetEnumerator()) + { + if (key == null) + { + do + { + if (!e.MoveNext()) + { + return value; + } + + value = e.Current; + key = keySelector(value); + } + while (key == null); + while (e.MoveNext()) { - TSource x = e.Current; - if (Comparer.Default.Compare(x, value) > 0) + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (nextKey != null && comparer.Compare(nextKey, key) > 0) { - value = x; + key = nextKey; + value = nextValue; + } + } + } + else + { + if (!e.MoveNext()) + { + ThrowHelper.ThrowNoElementsException(); + } + + value = e.Current; + key = keySelector(value); + if (comparer == Comparer.Default) + { + while (e.MoveNext()) + { + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (Comparer.Default.Compare(nextKey, key) > 0) + { + key = nextKey; + value = nextValue; + } + } + } + else + { + while (e.MoveNext()) + { + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (comparer.Compare(nextKey, key) > 0) + { + key = nextKey; + value = nextValue; + } } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Min.cs b/src/libraries/System.Linq/src/System/Linq/Min.cs index f9531778910a42..fb90cdf348143a 100644 --- a/src/libraries/System.Linq/src/System/Linq/Min.cs +++ b/src/libraries/System.Linq/src/System/Linq/Min.cs @@ -399,13 +399,16 @@ public static decimal Min(this IEnumerable source) return value; } - public static TSource? Min(this IEnumerable source) + public static TSource? Min(this IEnumerable source) => Min(source, comparer: null); + public static TSource? Min(this IEnumerable source, IComparer? comparer) { if (source == null) { ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); } + comparer ??= Comparer.Default; + TSource? value = default; using (IEnumerator e = source.GetEnumerator()) { @@ -422,13 +425,12 @@ public static decimal Min(this IEnumerable source) } while (value == null); - Comparer comparer = Comparer.Default; while (e.MoveNext()) { - TSource x = e.Current; - if (x != null && comparer.Compare(x, value) < 0) + TSource next = e.Current; + if (next != null && comparer.Compare(next, value) < 0) { - value = x; + value = next; } } } @@ -440,12 +442,111 @@ public static decimal Min(this IEnumerable source) } value = e.Current; + if (comparer == Comparer.Default) + { + while (e.MoveNext()) + { + TSource next = e.Current; + if (Comparer.Default.Compare(next, value) < 0) + { + value = next; + } + } + } + else + { + while (e.MoveNext()) + { + TSource next = e.Current; + if (comparer.Compare(next, value) < 0) + { + value = next; + } + } + } + } + } + + return value; + } + + public static TSource? MinBy(this IEnumerable source, Func keySelector) => MinBy(source, keySelector, comparer: null); + public static TSource? MinBy(this IEnumerable source, Func keySelector, IComparer? comparer) + { + if (source == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.source); + } + + if (keySelector == null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector); + } + + comparer ??= Comparer.Default; + + TKey? key = default; + TSource? value = default; + using (IEnumerator e = source.GetEnumerator()) + { + if (key == null) + { + do + { + if (!e.MoveNext()) + { + return value; + } + + value = e.Current; + key = keySelector(value); + } + while (key == null); + while (e.MoveNext()) { - TSource x = e.Current; - if (Comparer.Default.Compare(x, value) < 0) + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (nextKey != null && comparer.Compare(nextKey, key) < 0) { - value = x; + key = nextKey; + value = nextValue; + } + } + } + else + { + if (!e.MoveNext()) + { + ThrowHelper.ThrowNoElementsException(); + } + + value = e.Current; + key = keySelector(value); + if (comparer == Comparer.Default) + { + while (e.MoveNext()) + { + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (Comparer.Default.Compare(nextKey, key) < 0) + { + key = nextKey; + value = nextValue; + } + } + } + else + { + while (e.MoveNext()) + { + TSource nextValue = e.Current; + TKey nextKey = keySelector(nextValue); + if (comparer.Compare(nextKey, key) < 0) + { + key = nextKey; + value = nextValue; + } } } } diff --git a/src/libraries/System.Linq/src/System/Linq/Union.cs b/src/libraries/System.Linq/src/System/Linq/Union.cs index d9b3c4bdb065b7..bc7eb36843d721 100644 --- a/src/libraries/System.Linq/src/System/Linq/Union.cs +++ b/src/libraries/System.Linq/src/System/Linq/Union.cs @@ -26,6 +26,47 @@ public static IEnumerable Union(this IEnumerable firs return first is UnionIterator union && AreEqualityComparersEqual(comparer, union._comparer) ? union.Union(second) : new UnionIterator2(first, second, comparer); } + public static IEnumerable UnionBy(this IEnumerable first, IEnumerable second, Func keySelector) => UnionBy(first, second, keySelector, null); + + public static IEnumerable UnionBy(this IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer) + { + if (first is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.first); + } + if (second is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.second); + } + if (keySelector is null) + { + ThrowHelper.ThrowArgumentNullException(ExceptionArgument.keySelector); + } + + return UnionByIterator(first, second, keySelector, comparer); + } + + private static IEnumerable UnionByIterator(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer) + { + var set = new HashSet(DefaultInternalSetCapacity, comparer); + + foreach (TSource element in first) + { + if (set.Add(keySelector(element))) + { + yield return element; + } + } + + foreach (TSource element in second) + { + if (set.Add(keySelector(element))) + { + yield return element; + } + } + } + /// /// An iterator that yields distinct values from two or more . /// diff --git a/src/libraries/System.Linq/tests/DistinctTests.cs b/src/libraries/System.Linq/tests/DistinctTests.cs index a5a438d5dcb83c..7408e96ddb38ce 100644 --- a/src/libraries/System.Linq/tests/DistinctTests.cs +++ b/src/libraries/System.Linq/tests/DistinctTests.cs @@ -267,5 +267,104 @@ public void RepeatEnumerating() Assert.Equal(result, result); } + + [Fact] + public void DistinctBy_SourceNull_ThrowsArgumentNullException() + { + string[] first = null; + + AssertExtensions.Throws("source", () => first.DistinctBy(x => x)); + AssertExtensions.Throws("source", () => first.DistinctBy(x => x, new AnagramEqualityComparer())); + } + + [Fact] + public void DistinctBy_KeySelectorNull_ThrowsArgumentNullException() + { + string[] source = { "Bob", "Tim", "Robert", "Chris" }; + Func keySelector = null; + + AssertExtensions.Throws("keySelector", () => source.DistinctBy(keySelector)); + AssertExtensions.Throws("keySelector", () => source.DistinctBy(keySelector, new AnagramEqualityComparer())); + } + + [Theory] + [MemberData(nameof(DistinctBy_TestData))] + public static void DistinctBy_HasExpectedOutput(IEnumerable source, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + { + Assert.Equal(expected, source.DistinctBy(keySelector, comparer)); + } + + [Theory] + [MemberData(nameof(DistinctBy_TestData))] + public static void DistinctBy_RunOnce_HasExpectedOutput(IEnumerable source, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + { + Assert.Equal(expected, source.RunOnce().DistinctBy(keySelector, comparer)); + } + + public static IEnumerable DistinctBy_TestData() + { + yield return WrapArgs( + source: Enumerable.Range(0, 10), + keySelector: x => x, + comparer: null, + expected: Enumerable.Range(0, 10)); + + yield return WrapArgs( + source: Enumerable.Range(5, 10), + keySelector: x => true, + comparer: null, + expected: new int[] { 5 }); + + yield return WrapArgs( + source: Enumerable.Range(0, 20), + keySelector: x => x % 5, + comparer: null, + expected: Enumerable.Range(0, 5)); + + yield return WrapArgs( + source: Enumerable.Repeat(5, 20), + keySelector: x => x, + comparer: null, + expected: Enumerable.Repeat(5, 1)); + + yield return WrapArgs( + source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" }, + keySelector: x => x, + null, + expected: new string[] { "Bob", "bob", "tim", "Tim" }); + + yield return WrapArgs( + source: new string[] { "Bob", "bob", "tim", "Bob", "Tim" }, + keySelector: x => x, + StringComparer.OrdinalIgnoreCase, + expected: new string[] { "Bob", "tim" }); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) }, + keySelector: x => x.Age, + comparer: null, + expected: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) }); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 20), ("Harry", 40) }, + keySelector: x => x.Age, + comparer: null, + expected: new (string Name, int Age)[] { ("Tom", 20), ("Harry", 40) }); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) }, + keySelector: x => x.Name, + comparer: null, + expected: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) }); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Bob", 20), ("bob", 30), ("Harry", 40) }, + keySelector: x => x.Name, + comparer: StringComparer.OrdinalIgnoreCase, + expected: new (string Name, int Age)[] { ("Bob", 20), ("Harry", 40) }); + + object[] WrapArgs(IEnumerable source, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + => new object[] { source, keySelector, comparer, expected }; + } } } diff --git a/src/libraries/System.Linq/tests/ExceptTests.cs b/src/libraries/System.Linq/tests/ExceptTests.cs index 5cb24615ca356f..36976b7cc57a54 100644 --- a/src/libraries/System.Linq/tests/ExceptTests.cs +++ b/src/libraries/System.Linq/tests/ExceptTests.cs @@ -27,9 +27,6 @@ public void SameResultsRepeatCallsStringQuery() var q2 = from x2 in new[] { "!@#$%^", "C", "AAA", "", "Calling Twice", "SoS" } select x2; - var rst1 = q1.Except(q2); - var rst2 = q1.Except(q2); - Assert.Equal(q1.Except(q2), q1.Except(q2)); } @@ -141,5 +138,112 @@ public void HashSetWithBuiltInComparer_HashSetContainsNotUsed() Assert.Equal(new[] { "A" }, input2.Except(input1, EqualityComparer.Default)); Assert.Equal(Enumerable.Empty(), input2.Except(input1, StringComparer.OrdinalIgnoreCase)); } + + [Fact] + public void ExceptBy_FirstNull_ThrowsArgumentNullException() + { + string[] first = null; + string[] second = { "bBo", "shriC" }; + + AssertExtensions.Throws("first", () => first.ExceptBy(second, x => x)); + AssertExtensions.Throws("first", () => first.ExceptBy(second, x => x, new AnagramEqualityComparer())); + } + + [Fact] + public void ExceptBy_SecondNull_ThrowsArgumentNullException() + { + string[] first = { "Bob", "Tim", "Robert", "Chris" }; + string[] second = null; + + AssertExtensions.Throws("second", () => first.ExceptBy(second, x => x)); + AssertExtensions.Throws("second", () => first.ExceptBy(second, x => x, new AnagramEqualityComparer())); + } + + [Fact] + public void ExceptBy_KeySelectorNull_ThrowsArgumentNullException() + { + string[] first = { "Bob", "Tim", "Robert", "Chris" }; + string[] second = { "bBo", "shriC" }; + Func keySelector = null; + + AssertExtensions.Throws("keySelector", () => first.ExceptBy(second, keySelector)); + AssertExtensions.Throws("keySelector", () => first.ExceptBy(second, keySelector, new AnagramEqualityComparer())); + } + + [Theory] + [MemberData(nameof(ExceptBy_TestData))] + public static void ExceptBy_HasExpectedOutput(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + { + Assert.Equal(expected, first.ExceptBy(second, keySelector, comparer)); + } + + [Theory] + [MemberData(nameof(ExceptBy_TestData))] + public static void ExceptBy_RunOnce_HasExpectedOutput(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + { + Assert.Equal(expected, first.RunOnce().ExceptBy(second.RunOnce(), keySelector, comparer)); + } + + public static IEnumerable ExceptBy_TestData() + { + yield return WrapArgs( + first: Enumerable.Range(0, 10), + second: Enumerable.Range(0, 5), + keySelector: x => x, + comparer: null, + expected: Enumerable.Range(5, 5)); + + yield return WrapArgs( + first: Enumerable.Repeat(5, 20), + second: Enumerable.Empty(), + keySelector: x => x, + comparer: null, + expected: Enumerable.Repeat(5, 1)); + + yield return WrapArgs( + first: Enumerable.Repeat(5, 20), + second: Enumerable.Repeat(5, 3), + keySelector: x => x, + comparer: null, + expected: Enumerable.Empty()); + + yield return WrapArgs( + first: new string[] { "Bob", "Tim", "Robert", "Chris" }, + second: new string[] { "bBo", "shriC" }, + keySelector: x => x, + null, + expected: new string[] { "Bob", "Tim", "Robert", "Chris" }); + + yield return WrapArgs( + first: new string[] { "Bob", "Tim", "Robert", "Chris" }, + second: new string[] { "bBo", "shriC" }, + keySelector: x => x, + new AnagramEqualityComparer(), + expected: new string[] { "Tim", "Robert" }); + + yield return WrapArgs( + first: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) }, + second: new int[] { 15, 20, 40 }, + keySelector: x => x.Age, + comparer: null, + expected: new (string Name, int Age)[] { ("Dick", 30) }); + + yield return WrapArgs( + first: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) }, + second: new string[] { "moT" }, + keySelector: x => x.Name, + comparer: null, + expected: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) }); + + yield return WrapArgs( + first: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) }, + second: new string[] { "moT" }, + keySelector: x => x.Name, + comparer: new AnagramEqualityComparer(), + expected: new (string Name, int Age)[] { ("Dick", 30), ("Harry", 40) }); + + object[] WrapArgs(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + => new object[] { first, second, keySelector, comparer, expected }; + } } } diff --git a/src/libraries/System.Linq/tests/IntersectTests.cs b/src/libraries/System.Linq/tests/IntersectTests.cs index 361854b69c170c..7b1d5fdfbd873d 100644 --- a/src/libraries/System.Linq/tests/IntersectTests.cs +++ b/src/libraries/System.Linq/tests/IntersectTests.cs @@ -135,5 +135,119 @@ public void HashSetWithBuiltInComparer_HashSetContainsNotUsed() Assert.Equal(Enumerable.Empty(), input2.Intersect(input1, EqualityComparer.Default)); Assert.Equal(new[] { "A" }, input2.Intersect(input1, StringComparer.OrdinalIgnoreCase)); } + + [Fact] + public void IntersectBy_FirstNull_ThrowsArgumentNullException() + { + string[] first = null; + string[] second = { "bBo", "shriC" }; + + AssertExtensions.Throws("first", () => first.IntersectBy(second, x => x)); + AssertExtensions.Throws("first", () => first.IntersectBy(second, x => x, new AnagramEqualityComparer())); + } + + [Fact] + public void IntersectBy_SecondNull_ThrowsArgumentNullException() + { + string[] first = { "Bob", "Tim", "Robert", "Chris" }; + string[] second = null; + + AssertExtensions.Throws("second", () => first.IntersectBy(second, x => x)); + AssertExtensions.Throws("second", () => first.IntersectBy(second, x => x, new AnagramEqualityComparer())); + } + + [Fact] + public void IntersectBy_KeySelectorNull_ThrowsArgumentNullException() + { + string[] first = { "Bob", "Tim", "Robert", "Chris" }; + string[] second = { "bBo", "shriC" }; + Func keySelector = null; + + AssertExtensions.Throws("keySelector", () => first.IntersectBy(second, keySelector)); + AssertExtensions.Throws("keySelector", () => first.IntersectBy(second, keySelector, new AnagramEqualityComparer())); + } + + [Theory] + [MemberData(nameof(IntersectBy_TestData))] + public static void IntersectBy_HasExpectedOutput(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + { + Assert.Equal(expected, first.IntersectBy(second, keySelector, comparer)); + } + + [Theory] + [MemberData(nameof(IntersectBy_TestData))] + public static void IntersectBy_RunOnce_HasExpectedOutput(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + { + Assert.Equal(expected, first.RunOnce().IntersectBy(second.RunOnce(), keySelector, comparer)); + } + + public static IEnumerable IntersectBy_TestData() + { + yield return WrapArgs( + first: Enumerable.Range(0, 10), + second: Enumerable.Range(0, 5), + keySelector: x => x, + comparer: null, + expected: Enumerable.Range(0, 5)); + + yield return WrapArgs( + first: Enumerable.Range(0, 10), + second: Enumerable.Range(10, 10), + keySelector: x => x, + comparer: null, + expected: Enumerable.Empty()); + + yield return WrapArgs( + first: Enumerable.Repeat(5, 20), + second: Enumerable.Empty(), + keySelector: x => x, + comparer: null, + expected: Enumerable.Empty()); + + yield return WrapArgs( + first: Enumerable.Repeat(5, 20), + second: Enumerable.Repeat(5, 3), + keySelector: x => x, + comparer: null, + expected: Enumerable.Repeat(5, 1)); + + yield return WrapArgs( + first: new string[] { "Bob", "Tim", "Robert", "Chris" }, + second: new string[] { "bBo", "shriC" }, + keySelector: x => x, + null, + expected: Array.Empty()); + + yield return WrapArgs( + first: new string[] { "Bob", "Tim", "Robert", "Chris" }, + second: new string[] { "bBo", "shriC" }, + keySelector: x => x, + new AnagramEqualityComparer(), + expected: new string[] { "Bob", "Chris" }); + + yield return WrapArgs( + first: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) }, + second: new int[] { 15, 20, 40 }, + keySelector: x => x.Age, + comparer: null, + expected: new (string Name, int Age)[] { ("Tom", 20), ("Harry", 40) }); + + yield return WrapArgs( + first: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) }, + second: new string[] { "moT" }, + keySelector: x => x.Name, + comparer: null, + expected: Array.Empty<(string Name, int Age)>()); + + yield return WrapArgs( + first: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40) }, + second: new string[] { "moT" }, + keySelector: x => x.Name, + comparer: new AnagramEqualityComparer(), + expected: new (string Name, int Age)[] { ("Tom", 20) }); + + object[] WrapArgs(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + => new object[] { first, second, keySelector, comparer, expected }; + } } } diff --git a/src/libraries/System.Linq/tests/MaxTests.cs b/src/libraries/System.Linq/tests/MaxTests.cs index 6c1e56a62df1e5..ed66912a8c2cc2 100644 --- a/src/libraries/System.Linq/tests/MaxTests.cs +++ b/src/libraries/System.Linq/tests/MaxTests.cs @@ -770,5 +770,193 @@ public void Max_Boolean_EmptySource_ThrowsInvalidOperationException() { Assert.Throws(() => Enumerable.Empty().Max()); } + + [Fact] + public static void Max_Generic_NullSource_ThrowsArgumentNullException() + { + IEnumerable source = null; + + AssertExtensions.Throws("source", () => source.Max()); + AssertExtensions.Throws("source", () => source.Max(comparer: null)); + AssertExtensions.Throws("source", () => source.Max(Comparer.Create((_, _) => 0))); + } + + [Fact] + public static void Max_Generic_EmptyStructSource_ThrowsInvalidOperationException() + { + Assert.Throws(() => Enumerable.Empty().Max()); + Assert.Throws(() => Enumerable.Empty().Max(comparer: null)); + Assert.Throws(() => Enumerable.Empty().Max(Comparer.Create((_,_) => 0))); + } + + [Theory] + [MemberData(nameof(Max_Generic_TestData))] + public static void Max_Generic_HasExpectedOutput(IEnumerable source, IComparer? comparer, TSource? expected) + { + Assert.Equal(expected, source.Max(comparer)); + } + + [Theory] + [MemberData(nameof(Max_Generic_TestData))] + public static void Max_Generic_RunOnce_HasExpectedOutput(IEnumerable source, IComparer? comparer, TSource? expected) + { + Assert.Equal(expected, source.RunOnce().Max(comparer)); + } + + public static IEnumerable Max_Generic_TestData() + { + yield return WrapArgs( + source: Enumerable.Empty(), + comparer: null, + expected: null); + + yield return WrapArgs( + source: Enumerable.Empty(), + comparer: Comparer.Create((_,_) => 0), + expected: null); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + comparer: null, + expected: 9); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: 0); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + comparer: Comparer.Create((x,y) => 0), + expected: 0); + + yield return WrapArgs( + source: new string[] { "Aardvark", "Zyzzyva", "Zebra", "Antelope" }, + comparer: null, + expected: "Zyzzyva"); + + yield return WrapArgs( + source: new string[] { "Aardvark", "Zyzzyva", "Zebra", "Antelope" }, + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: "Aardvark"); + + object[] WrapArgs(IEnumerable source, IComparer? comparer, TSource? expected) + => new object[] { source, comparer, expected }; + } + + [Fact] + public static void MaxBy_Generic_NullSource_ThrowsArgumentNullException() + { + IEnumerable source = null; + + AssertExtensions.Throws("source", () => source.MaxBy(x => x)); + AssertExtensions.Throws("source", () => source.MaxBy(x => x, comparer: null)); + AssertExtensions.Throws("source", () => source.MaxBy(x => x, Comparer.Create((_, _) => 0))); + } + + [Fact] + public static void MaxBy_Generic_NullKeySelector_ThrowsArgumentNullException() + { + IEnumerable source = Enumerable.Empty(); + Func keySelector = null; + + AssertExtensions.Throws("keySelector", () => source.MaxBy(keySelector)); + AssertExtensions.Throws("keySelector", () => source.MaxBy(keySelector, comparer: null)); + AssertExtensions.Throws("keySelector", () => source.MaxBy(keySelector, Comparer.Create((_, _) => 0))); + } + + [Fact] + public static void MaxBy_Generic_EmptyStructSource_ThrowsInvalidOperationException() + { + Assert.Throws(() => Enumerable.Empty().MaxBy(x => x)); + Assert.Throws(() => Enumerable.Empty().MaxBy(x => x, comparer: null)); + Assert.Throws(() => Enumerable.Empty().MaxBy(x => x, Comparer.Create((_, _) => 0))); + } + + [Theory] + [MemberData(nameof(MaxBy_Generic_TestData))] + public static void MaxBy_Generic_HasExpectedOutput(IEnumerable source, Func keySelector, IComparer? comparer, TSource? expected) + { + Assert.Equal(expected, source.MaxBy(keySelector, comparer)); + } + + [Theory] + [MemberData(nameof(MaxBy_Generic_TestData))] + public static void MaxBy_Generic_RunOnce_HasExpectedOutput(IEnumerable source, Func keySelector, IComparer? comparer, TSource? expected) + { + Assert.Equal(expected, source.RunOnce().MaxBy(keySelector, comparer)); + } + + public static IEnumerable MaxBy_Generic_TestData() + { + yield return WrapArgs( + source: Enumerable.Empty(), + keySelector: x => x, + comparer: null, + expected: null); + + yield return WrapArgs( + source: Enumerable.Empty(), + keySelector: x => x, + comparer: Comparer.Create((_, _) => 0), + expected: null); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + keySelector: x => x, + comparer: null, + expected: 9); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + keySelector: x => x, + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: 0); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + keySelector: x => x, + comparer: Comparer.Create((x, y) => 0), + expected: 0); + + yield return WrapArgs( + source: new string[] { "Aardvark", "Zyzzyva", "Zebra", "Antelope" }, + keySelector: x => x, + comparer: null, + expected: "Zyzzyva"); + + yield return WrapArgs( + source: new string[] { "Aardvark", "Zyzzyva", "Zebra", "Antelope" }, + keySelector: x => x, + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: "Aardvark"); + + yield return WrapArgs( + source: new (string Name, int Age) [] { ("Tom", 43), ("Dick", 55), ("Harry", 20) }, + keySelector: x => x.Age, + comparer: null, + expected: (Name: "Dick", Age: 55)); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Tom", 43), ("Dick", 55), ("Harry", 20) }, + keySelector: x => x.Age, + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: (Name: "Harry", Age: 20)); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Tom", 43), ("Dick", 55), ("Harry", 20) }, + keySelector: x => x.Name, + comparer: null, + expected: (Name: "Tom", Age: 43)); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Tom", 43), ("Dick", 55), ("Harry", 20) }, + keySelector: x => x.Name, + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: (Name: "Dick", Age: 55)); + + object[] WrapArgs(IEnumerable source, Func keySelector, IComparer? comparer, TSource? expected) + => new object[] { source, keySelector, comparer, expected }; + } } } diff --git a/src/libraries/System.Linq/tests/MinTests.cs b/src/libraries/System.Linq/tests/MinTests.cs index 44a9b1a1e4d97d..3d6eef37d09000 100644 --- a/src/libraries/System.Linq/tests/MinTests.cs +++ b/src/libraries/System.Linq/tests/MinTests.cs @@ -748,5 +748,193 @@ public void Min_Bool_EmptySource_ThrowsInvalodOperationException() { Assert.Throws(() => Enumerable.Empty().Min()); } + + [Fact] + public static void Min_Generic_NullSource_ThrowsArgumentNullException() + { + IEnumerable source = null; + + AssertExtensions.Throws("source", () => source.Min()); + AssertExtensions.Throws("source", () => source.Min(comparer: null)); + AssertExtensions.Throws("source", () => source.Min(Comparer.Create((_, _) => 0))); + } + + [Fact] + public static void Min_Generic_EmptyStructSource_ThrowsInvalidOperationException() + { + Assert.Throws(() => Enumerable.Empty().Min()); + Assert.Throws(() => Enumerable.Empty().Min(comparer: null)); + Assert.Throws(() => Enumerable.Empty().Min(Comparer.Create((_, _) => 0))); + } + + [Theory] + [MemberData(nameof(Min_Generic_TestData))] + public static void Min_Generic_HasExpectedOutput(IEnumerable source, IComparer? comparer, TSource? expected) + { + Assert.Equal(expected, source.Min(comparer)); + } + + [Theory] + [MemberData(nameof(Min_Generic_TestData))] + public static void Min_Generic_RunOnce_HasExpectedOutput(IEnumerable source, IComparer? comparer, TSource? expected) + { + Assert.Equal(expected, source.RunOnce().Min(comparer)); + } + + public static IEnumerable Min_Generic_TestData() + { + yield return WrapArgs( + source: Enumerable.Empty(), + comparer: null, + expected: null); + + yield return WrapArgs( + source: Enumerable.Empty(), + comparer: Comparer.Create((_, _) => 0), + expected: null); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + comparer: null, + expected: 0); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: 9); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + comparer: Comparer.Create((x, y) => 0), + expected: 0); + + yield return WrapArgs( + source: new string[] { "Aardvark", "Zyzzyva", "Zebra", "Antelope" }, + comparer: null, + expected: "Aardvark"); + + yield return WrapArgs( + source: new string[] { "Aardvark", "Zyzzyva", "Zebra", "Antelope" }, + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: "Zyzzyva"); + + object[] WrapArgs(IEnumerable source, IComparer? comparer, TSource? expected) + => new object[] { source, comparer, expected }; + } + + [Fact] + public static void MinBy_Generic_NullSource_ThrowsArgumentNullException() + { + IEnumerable source = null; + + AssertExtensions.Throws("source", () => source.MinBy(x => x)); + AssertExtensions.Throws("source", () => source.MinBy(x => x, comparer: null)); + AssertExtensions.Throws("source", () => source.MinBy(x => x, Comparer.Create((_, _) => 0))); + } + + [Fact] + public static void MinBy_Generic_NullKeySelector_ThrowsArgumentNullException() + { + IEnumerable source = Enumerable.Empty(); + Func keySelector = null; + + AssertExtensions.Throws("keySelector", () => source.MinBy(keySelector)); + AssertExtensions.Throws("keySelector", () => source.MinBy(keySelector, comparer: null)); + AssertExtensions.Throws("keySelector", () => source.MinBy(keySelector, Comparer.Create((_, _) => 0))); + } + + [Fact] + public static void MinBy_Generic_EmptyStructSource_ThrowsInvalidOperationException() + { + Assert.Throws(() => Enumerable.Empty().MinBy(x => x)); + Assert.Throws(() => Enumerable.Empty().MinBy(x => x, comparer: null)); + Assert.Throws(() => Enumerable.Empty().MinBy(x => x, Comparer.Create((_, _) => 0))); + } + + [Theory] + [MemberData(nameof(MinBy_Generic_TestData))] + public static void MinBy_Generic_HasExpectedOutput(IEnumerable source, Func keySelector, IComparer? comparer, TSource? expected) + { + Assert.Equal(expected, source.MinBy(keySelector, comparer)); + } + + [Theory] + [MemberData(nameof(MinBy_Generic_TestData))] + public static void MinBy_Generic_RunOnce_HasExpectedOutput(IEnumerable source, Func keySelector, IComparer? comparer, TSource? expected) + { + Assert.Equal(expected, source.RunOnce().MinBy(keySelector, comparer)); + } + + public static IEnumerable MinBy_Generic_TestData() + { + yield return WrapArgs( + source: Enumerable.Empty(), + keySelector: x => x, + comparer: null, + expected: null); + + yield return WrapArgs( + source: Enumerable.Empty(), + keySelector: x => x, + comparer: Comparer.Create((_, _) => 0), + expected: null); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + keySelector: x => x, + comparer: null, + expected: 0); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + keySelector: x => x, + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: 9); + + yield return WrapArgs( + source: Enumerable.Range(0, 10), + keySelector: x => x, + comparer: Comparer.Create((x, y) => 0), + expected: 0); + + yield return WrapArgs( + source: new string[] { "Aardvark", "Zyzzyva", "Zebra", "Antelope" }, + keySelector: x => x, + comparer: null, + expected: "Aardvark"); + + yield return WrapArgs( + source: new string[] { "Aardvark", "Zyzzyva", "Zebra", "Antelope" }, + keySelector: x => x, + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: "Zyzzyva"); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Tom", 43), ("Dick", 55), ("Harry", 20) }, + keySelector: x => x.Age, + comparer: null, + expected: (Name: "Harry", Age: 20)); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Tom", 43), ("Dick", 55), ("Harry", 20) }, + keySelector: x => x.Age, + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: (Name: "Dick", Age: 55)); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Tom", 43), ("Dick", 55), ("Harry", 20) }, + keySelector: x => x.Name, + comparer: null, + expected: (Name: "Dick", Age: 55)); + + yield return WrapArgs( + source: new (string Name, int Age)[] { ("Tom", 43), ("Dick", 55), ("Harry", 20) }, + keySelector: x => x.Name, + comparer: Comparer.Create((x, y) => -x.CompareTo(y)), + expected: (Name: "Tom", Age: 43)); + + object[] WrapArgs(IEnumerable source, Func keySelector, IComparer? comparer, TSource? expected) + => new object[] { source, keySelector, comparer, expected }; + } } } diff --git a/src/libraries/System.Linq/tests/UnionTests.cs b/src/libraries/System.Linq/tests/UnionTests.cs index 9c95220e37800d..9dbfd822b16395 100644 --- a/src/libraries/System.Linq/tests/UnionTests.cs +++ b/src/libraries/System.Linq/tests/UnionTests.cs @@ -413,5 +413,126 @@ public void HashSetWithBuiltInComparer_HashSetContainsNotUsed() Assert.Equal(new[] { "A", "a" }, input2.Union(input1, EqualityComparer.Default)); Assert.Equal(new[] { "A" }, input2.Union(input1, StringComparer.OrdinalIgnoreCase)); } + + [Fact] + public void UnionBy_FirstNull_ThrowsArgumentNullException() + { + string[] first = null; + string[] second = { "bBo", "shriC" }; + + AssertExtensions.Throws("first", () => first.UnionBy(second, x => x)); + AssertExtensions.Throws("first", () => first.UnionBy(second, x => x, new AnagramEqualityComparer())); + } + + [Fact] + public void UnionBy_SecondNull_ThrowsArgumentNullException() + { + string[] first = { "Bob", "Tim", "Robert", "Chris" }; + string[] second = null; + + AssertExtensions.Throws("second", () => first.UnionBy(second, x => x)); + AssertExtensions.Throws("second", () => first.UnionBy(second, x => x, new AnagramEqualityComparer())); + } + + [Fact] + public void UnionBy_KeySelectorNull_ThrowsArgumentNullException() + { + string[] first = { "Bob", "Tim", "Robert", "Chris" }; + string[] second = { "bBo", "shriC" }; + Func keySelector = null; + + AssertExtensions.Throws("keySelector", () => first.UnionBy(second, keySelector)); + AssertExtensions.Throws("keySelector", () => first.UnionBy(second, keySelector, new AnagramEqualityComparer())); + } + + [Theory] + [MemberData(nameof(UnionBy_TestData))] + public static void UnionBy_HasExpectedOutput(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + { + Assert.Equal(expected, first.UnionBy(second, keySelector, comparer)); + } + + [Theory] + [MemberData(nameof(UnionBy_TestData))] + public static void UnionBy_RunOnce_HasExpectedOutput(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + { + Assert.Equal(expected, first.RunOnce().UnionBy(second.RunOnce(), keySelector, comparer)); + } + + public static IEnumerable UnionBy_TestData() + { + yield return WrapArgs( + first: Enumerable.Range(0, 7), + second: Enumerable.Range(3, 7), + keySelector: x => x, + comparer: null, + expected: Enumerable.Range(0, 10)); + + yield return WrapArgs( + first: Enumerable.Range(0, 10), + second: Enumerable.Range(10, 10), + keySelector: x => x, + comparer: null, + expected: Enumerable.Range(0, 20)); + + yield return WrapArgs( + first: Enumerable.Empty(), + second: Enumerable.Range(0, 5), + keySelector: x => x, + comparer: null, + expected: Enumerable.Range(0, 5)); + + yield return WrapArgs( + first: Enumerable.Repeat(5, 20), + second: Enumerable.Empty(), + keySelector: x => x, + comparer: null, + expected: Enumerable.Repeat(5, 1)); + + yield return WrapArgs( + first: Enumerable.Repeat(5, 20), + second: Enumerable.Repeat(5, 3), + keySelector: x => x, + comparer: null, + expected: Enumerable.Repeat(5, 1)); + + yield return WrapArgs( + first: new string[] { "Bob", "Tim", "Robert", "Chris" }, + second: new string[] { "bBo", "shriC" }, + keySelector: x => x, + null, + expected: new string[] { "Bob", "Tim", "Robert", "Chris", "bBo", "shriC" }); + + yield return WrapArgs( + first: new string[] { "Bob", "Tim", "Robert", "Chris" }, + second: new string[] { "bBo", "shriC" }, + keySelector: x => x, + new AnagramEqualityComparer(), + expected: new string[] { "Bob", "Tim", "Robert", "Chris" }); + + yield return WrapArgs( + first: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40), ("Martin", 20) }, + second: new (string Name, int Age)[] { ("Peter", 21), ("John", 30), ("Toby", 33) }, + keySelector: x => x.Age, + comparer: null, + expected: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40), ("Peter", 21), ("Toby", 33) }); + + yield return WrapArgs( + first: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40), ("Martin", 20) }, + second: new (string Name, int Age)[] { ("Toby", 33), ("Harry", 35), ("tom", 67) }, + keySelector: x => x.Name, + comparer: null, + expected: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40), ("Martin", 20), ("Toby", 33), ("tom", 67) }); + + yield return WrapArgs( + first: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40), ("Martin", 20) }, + second: new (string Name, int Age)[] { ("Toby", 33), ("Harry", 35), ("tom", 67) }, + keySelector: x => x.Name, + comparer: StringComparer.OrdinalIgnoreCase, + expected: new (string Name, int Age)[] { ("Tom", 20), ("Dick", 30), ("Harry", 40), ("Martin", 20), ("Toby", 33) }); + + object[] WrapArgs(IEnumerable first, IEnumerable second, Func keySelector, IEqualityComparer? comparer, IEnumerable expected) + => new object[] { first, second, keySelector, comparer, expected }; + } } } From 9203a82b788b93f1fdac642e6ba890a374d7fb26 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Mon, 29 Mar 2021 18:30:35 +0100 Subject: [PATCH 2/3] apply feedback --- .../System.Linq/src/System/Linq/Distinct.cs | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Linq/src/System/Linq/Distinct.cs b/src/libraries/System.Linq/src/System/Linq/Distinct.cs index 5d5af39b306c30..ec053396db47ef 100644 --- a/src/libraries/System.Linq/src/System/Linq/Distinct.cs +++ b/src/libraries/System.Linq/src/System/Linq/Distinct.cs @@ -38,13 +38,21 @@ public static IEnumerable DistinctBy(this IEnumerable DistinctByIterator(IEnumerable source, Func keySelector, IEqualityComparer? comparer) { - var set = new HashSet(DefaultInternalSetCapacity, comparer); - foreach (TSource element in source) + using IEnumerator enumerator = source.GetEnumerator(); + + if (enumerator.MoveNext()) { - if (set.Add(keySelector(element))) + var set = new HashSet(DefaultInternalSetCapacity, comparer); + do { - yield return element; + TSource element = enumerator.Current; + if (set.Add(keySelector(element))) + { + yield return element; + } + } + while (enumerator.MoveNext()); } } From c9d275b1893a16368752c93d5ed766960d04a101 Mon Sep 17 00:00:00 2001 From: Eirik Tsarpalis Date: Fri, 9 Apr 2021 16:33:03 +0100 Subject: [PATCH 3/3] Update src/libraries/System.Linq/src/System/Linq/Distinct.cs Co-authored-by: Stephen Toub --- src/libraries/System.Linq/src/System/Linq/Distinct.cs | 1 - 1 file changed, 1 deletion(-) diff --git a/src/libraries/System.Linq/src/System/Linq/Distinct.cs b/src/libraries/System.Linq/src/System/Linq/Distinct.cs index ec053396db47ef..008128024302bf 100644 --- a/src/libraries/System.Linq/src/System/Linq/Distinct.cs +++ b/src/libraries/System.Linq/src/System/Linq/Distinct.cs @@ -50,7 +50,6 @@ private static IEnumerable DistinctByIterator(IEnumerabl { yield return element; } - } while (enumerator.MoveNext()); }