diff --git a/src/libraries/System.Linq/src/System/Linq/AggregateBy.cs b/src/libraries/System.Linq/src/System/Linq/AggregateBy.cs index 054b1272db115a..53d81ae4faf536 100644 --- a/src/libraries/System.Linq/src/System/Linq/AggregateBy.cs +++ b/src/libraries/System.Linq/src/System/Linq/AggregateBy.cs @@ -49,7 +49,26 @@ public static IEnumerable> AggregateBy enumerator = source.GetEnumerator(); + + if (!enumerator.MoveNext()) + { + return []; + } + + Dictionary dict = new(keyComparer); + + do + { + TSource value = enumerator.Current; + TKey key = keySelector(value); + + ref TAccumulate? acc = ref CollectionsMarshal.GetValueRefOrAddDefault(dict, key, out bool exists); + acc = func(exists ? acc! : seed, value); + } + while (enumerator.MoveNext()); + + return dict; } /// @@ -97,71 +116,26 @@ public static IEnumerable> AggregateBy> AggregateByIterator(IEnumerable source, Func keySelector, TAccumulate seed, Func func, IEqualityComparer? keyComparer) where TKey : notnull - { using IEnumerator enumerator = source.GetEnumerator(); if (!enumerator.MoveNext()) { - yield break; - } - - foreach (KeyValuePair countBy in PopulateDictionary(enumerator, keySelector, seed, func, keyComparer)) - { - yield return countBy; - } - - static Dictionary PopulateDictionary(IEnumerator enumerator, Func keySelector, TAccumulate seed, Func func, IEqualityComparer? keyComparer) - { - Dictionary dict = new(keyComparer); - - do - { - TSource value = enumerator.Current; - TKey key = keySelector(value); - - ref TAccumulate? acc = ref CollectionsMarshal.GetValueRefOrAddDefault(dict, key, out bool exists); - acc = func(exists ? acc! : seed, value); - } - while (enumerator.MoveNext()); - - return dict; + return []; } - } - private static IEnumerable> AggregateByIterator(IEnumerable source, Func keySelector, Func seedSelector, Func func, IEqualityComparer? keyComparer) where TKey : notnull - { - using IEnumerator enumerator = source.GetEnumerator(); + Dictionary dict = new(keyComparer); - if (!enumerator.MoveNext()) + do { - yield break; - } + TSource value = enumerator.Current; + TKey key = keySelector(value); - foreach (KeyValuePair countBy in PopulateDictionary(enumerator, keySelector, seedSelector, func, keyComparer)) - { - yield return countBy; + ref TAccumulate? acc = ref CollectionsMarshal.GetValueRefOrAddDefault(dict, key, out bool exists); + acc = func(exists ? acc! : seedSelector(key), value); } + while (enumerator.MoveNext()); - static Dictionary PopulateDictionary(IEnumerator enumerator, Func keySelector, Func seedSelector, Func func, IEqualityComparer? keyComparer) - { - Dictionary dict = new(keyComparer); - - do - { - TSource value = enumerator.Current; - TKey key = keySelector(value); - - ref TAccumulate? acc = ref CollectionsMarshal.GetValueRefOrAddDefault(dict, key, out bool exists); - acc = func(exists ? acc! : seedSelector(key), value); - } - while (enumerator.MoveNext()); - - return dict; - } + return dict; } } } diff --git a/src/libraries/System.Linq/src/System/Linq/CountBy.cs b/src/libraries/System.Linq/src/System/Linq/CountBy.cs index 43b450f309cda6..412d67caf3f5fe 100644 --- a/src/libraries/System.Linq/src/System/Linq/CountBy.cs +++ b/src/libraries/System.Linq/src/System/Linq/CountBy.cs @@ -33,26 +33,13 @@ public static IEnumerable> CountBy(this I return []; } - return CountByIterator(source, keySelector, keyComparer); - } - - private static IEnumerable> CountByIterator(IEnumerable source, Func keySelector, IEqualityComparer? keyComparer) where TKey : notnull - { using IEnumerator enumerator = source.GetEnumerator(); if (!enumerator.MoveNext()) { - yield break; - } - - foreach (KeyValuePair countBy in BuildCountDictionary(enumerator, keySelector, keyComparer)) - { - yield return countBy; + return []; } - } - private static Dictionary BuildCountDictionary(IEnumerator enumerator, Func keySelector, IEqualityComparer? keyComparer) where TKey : notnull - { Dictionary countsBy = new(keyComparer); do