Skip to content

Commit

Permalink
CollectionIterator improvements (#409)
Browse files Browse the repository at this point in the history
* Mark iterator classes `sealed`
* Use `TryGetCollectionCount()` instead of `ICollection<>`
  • Loading branch information
viceroypenguin authored May 11, 2023
1 parent ea91311 commit d147954
Show file tree
Hide file tree
Showing 7 changed files with 49 additions and 25 deletions.
17 changes: 10 additions & 7 deletions Source/SuperLinq/Do.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
namespace SuperLinq;
using System.Diagnostics.CodeAnalysis;

namespace SuperLinq;

public partial class SuperEnumerable
{
Expand Down Expand Up @@ -94,8 +96,8 @@ public static IEnumerable<TSource> Do<TSource>(this IEnumerable<TSource> source,
Guard.IsNotNull(onError);
Guard.IsNotNull(onCompleted);

if (source is ICollection<TSource> coll)
return new DoIterator<TSource>(coll, onNext, onError, onCompleted);
if (source.TryGetCollectionCount() != null)
return new DoIterator<TSource>(source, onNext, onError, onCompleted);

return DoCore(source, onNext, onError, onCompleted);
}
Expand Down Expand Up @@ -124,23 +126,24 @@ private static IEnumerable<TSource> DoCore<TSource>(IEnumerable<TSource> source,
}


private class DoIterator<T> : CollectionIterator<T>
private sealed class DoIterator<T> : CollectionIterator<T>
{
private readonly ICollection<T> _source;
private readonly IEnumerable<T> _source;
private readonly Action<T> _onNext;
private readonly Action<Exception>? _onError;
private readonly Action _onCompleted;

public DoIterator(ICollection<T> source, Action<T> onNext, Action<Exception>? onError, Action onCompleted)
public DoIterator(IEnumerable<T> source, Action<T> onNext, Action<Exception>? onError, Action onCompleted)
{
_source = source;
_onNext = onNext;
_onError = onError;
_onCompleted = onCompleted;
}

public override int Count => _source.Count;
public override int Count => _source.GetCollectionCount();

[ExcludeFromCodeCoverage]
protected override IEnumerable<T> GetEnumerable() =>
_onError != null
? DoCore(_source, _onNext, _onError, _onCompleted)
Expand Down
2 changes: 1 addition & 1 deletion Source/SuperLinq/PreScan.cs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ private static IEnumerable<TSource> PreScanCore<TSource>(IEnumerable<TSource> so
}
}

private class PreScanIterator<T> : CollectionIterator<T>
private sealed class PreScanIterator<T> : CollectionIterator<T>
{
private readonly ICollection<T> _source;
private readonly Func<T, T, T> _transformation;
Expand Down
10 changes: 5 additions & 5 deletions Source/SuperLinq/Rank.cs
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,8 @@ public static partial class SuperEnumerable
Guard.IsNotNull(source);
Guard.IsNotNull(keySelector);

if (source is ICollection<TSource> coll)
return new RankIterator<TSource, TKey>(coll, keySelector, comparer, isDense: false);
if (source.TryGetCollectionCount() != null)
return new RankIterator<TSource, TKey>(source, keySelector, comparer, isDense: false);

return RankByCore(source, keySelector, comparer, isDense: false);
}
Expand Down Expand Up @@ -205,20 +205,20 @@ public static partial class SuperEnumerable

private sealed class RankIterator<TSource, TKey> : CollectionIterator<(TSource, int)>
{
private readonly ICollection<TSource> _source;
private readonly IEnumerable<TSource> _source;
private readonly Func<TSource, TKey> _keySelector;
private readonly IComparer<TKey>? _comparer;
private readonly bool _isDense;

public RankIterator(ICollection<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey>? comparer, bool isDense)
public RankIterator(IEnumerable<TSource> source, Func<TSource, TKey> keySelector, IComparer<TKey>? comparer, bool isDense)
{
_source = source;
_keySelector = keySelector;
_comparer = comparer;
_isDense = isDense;
}

public override int Count => _source.Count;
public override int Count => _source.GetCollectionCount();

[ExcludeFromCodeCoverage]
protected override IEnumerable<(TSource, int)> GetEnumerable() =>
Expand Down
4 changes: 2 additions & 2 deletions Source/SuperLinq/Scan.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ private static IEnumerable<TSource> ScanCore<TSource>(IEnumerable<TSource> sourc
}
}

private class ScanIterator<T> : CollectionIterator<T>
private sealed class ScanIterator<T> : CollectionIterator<T>
{
private readonly ICollection<T> _source;
private readonly Func<T, T, T> _transformation;
Expand Down Expand Up @@ -149,7 +149,7 @@ private static IEnumerable<TState> ScanCore<TSource, TState>(
}
}

private class ScanStateIterator<TSource, TState> : CollectionIterator<TState>
private sealed class ScanStateIterator<TSource, TState> : CollectionIterator<TState>
{
private readonly ICollection<TSource> _source;
private readonly TState _state;
Expand Down
14 changes: 8 additions & 6 deletions Source/SuperLinq/ScanBy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,11 @@ public static partial class SuperEnumerable

comparer ??= EqualityComparer<TKey>.Default;

if (source is ICollection<TSource> coll)
if (source.TryGetCollectionCount() != null)
{
return new ScanByIterator<TSource, TKey, TState>(
coll, keySelector, seedSelector, accumulator, comparer);
source, keySelector, seedSelector, accumulator, comparer);
}

return ScanByCore(source, keySelector, seedSelector, accumulator, comparer);
}
Expand Down Expand Up @@ -107,16 +109,16 @@ public static partial class SuperEnumerable
}
}

private class ScanByIterator<TSource, TKey, TState> : CollectionIterator<(TKey key, TState state)>
private sealed class ScanByIterator<TSource, TKey, TState> : CollectionIterator<(TKey key, TState state)>
{
private readonly ICollection<TSource> _source;
private readonly IEnumerable<TSource> _source;
private readonly Func<TSource, TKey> _keySelector;
private readonly Func<TKey, TState> _seedSelector;
private readonly Func<TState, TKey, TSource, TState> _accumulator;
private readonly IEqualityComparer<TKey> _comparer;

public ScanByIterator(
ICollection<TSource> source,
IEnumerable<TSource> source,
Func<TSource, TKey> keySelector,
Func<TKey, TState> seedSelector,
Func<TState, TKey, TSource, TState> accumulator,
Expand All @@ -129,7 +131,7 @@ public ScanByIterator(
_comparer = comparer;
}

public override int Count => _source.Count;
public override int Count => _source.GetCollectionCount();

[ExcludeFromCodeCoverage]
protected override IEnumerable<(TKey key, TState state)> GetEnumerable() =>
Expand Down
4 changes: 2 additions & 2 deletions Source/SuperLinq/ScanRight.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ private static IEnumerable<TSource> ScanRightCore<TSource>(IEnumerable<TSource>
yield return item;
}

private class ScanRightIterator<T> : CollectionIterator<T>
private sealed class ScanRightIterator<T> : CollectionIterator<T>
{
private readonly ICollection<T> _source;
private readonly Func<T, T, T> _func;
Expand Down Expand Up @@ -147,7 +147,7 @@ private static IEnumerable<TAccumulate> ScanRightCore<TSource, TAccumulate>(IEnu
yield return item;
}

private class ScanRightStateIterator<TSource, TAccumulate> : CollectionIterator<TAccumulate>
private sealed class ScanRightStateIterator<TSource, TAccumulate> : CollectionIterator<TAccumulate>
{
private readonly ICollection<TSource> _source;
private readonly TAccumulate _seed;
Expand Down
23 changes: 21 additions & 2 deletions Source/SuperLinq/SuperEnumerable.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
namespace SuperLinq;
using System.Diagnostics.CodeAnalysis;

namespace SuperLinq;

/// <summary>
/// Provides a set of static methods for querying objects that
/// implement <see cref="IEnumerable{T}" />.
/// </summary>
public static partial class SuperEnumerable
{
[ExcludeFromCodeCoverage]
internal static int? TryGetCollectionCount<T>(this IEnumerable<T> source) =>
#if NET6_0_OR_GREATER
source.TryGetNonEnumeratedCount(out var count) ? count : default(int?);
Expand All @@ -15,7 +18,23 @@ public static partial class SuperEnumerable
null => ThrowHelper.ThrowArgumentNullException<int?>(nameof(source)),
ICollection<T> collection => collection.Count,
System.Collections.ICollection collection => collection.Count,
_ => null
_ => null,
};
#endif

[ExcludeFromCodeCoverage]
internal static int GetCollectionCount<T>(this IEnumerable<T> source) =>
#if NET6_0_OR_GREATER
source.TryGetNonEnumeratedCount(out var count)
? count
: ThrowHelper.ThrowInvalidOperationException<int>("Expected valid non-enumerated count.");
#else
source switch
{
null => ThrowHelper.ThrowArgumentNullException<int>(nameof(source)),
ICollection<T> collection => collection.Count,
System.Collections.ICollection collection => collection.Count,
_ => ThrowHelper.ThrowInvalidOperationException<int>("Expected valid non-enumerated count."),
};
#endif

Expand Down

0 comments on commit d147954

Please sign in to comment.