diff --git a/Source/SuperLinq/IteratorCollection.cs b/Source/SuperLinq/CollectionIterator.cs similarity index 57% rename from Source/SuperLinq/IteratorCollection.cs rename to Source/SuperLinq/CollectionIterator.cs index 89539b57..1df5bcbe 100644 --- a/Source/SuperLinq/IteratorCollection.cs +++ b/Source/SuperLinq/CollectionIterator.cs @@ -6,29 +6,29 @@ namespace SuperLinq; public partial class SuperEnumerable { [ExcludeFromCodeCoverage] - private abstract class IteratorCollection : ICollection, IReadOnlyCollection + private abstract class CollectionIterator : ICollection, IReadOnlyCollection { public bool IsReadOnly => true; - public void Add(TResult item) => + public void Add(T item) => throw new NotSupportedException(); - public bool Remove(TResult item) => + public bool Remove(T item) => throw new NotSupportedException(); public void Clear() => throw new NotSupportedException(); IEnumerator IEnumerable.GetEnumerator() => GetEnumerator(); - protected abstract IEnumerable GetEnumerable(); + protected abstract IEnumerable GetEnumerable(); public abstract int Count { get; } - public virtual IEnumerator GetEnumerator() => + public virtual IEnumerator GetEnumerator() => GetEnumerable().GetEnumerator(); - public virtual bool Contains(TResult item) => + public virtual bool Contains(T item) => GetEnumerable().Contains(item); - public virtual void CopyTo(TResult[] array, int arrayIndex) => + public virtual void CopyTo(T[] array, int arrayIndex) => GetEnumerable().CopyTo(array, arrayIndex); } } diff --git a/Source/SuperLinq/Do.cs b/Source/SuperLinq/Do.cs index 0ca301ab..d187a8c1 100644 --- a/Source/SuperLinq/Do.cs +++ b/Source/SuperLinq/Do.cs @@ -124,7 +124,7 @@ private static IEnumerable DoCore(IEnumerable source, } - private class DoIterator : IteratorCollection + private class DoIterator : CollectionIterator { private readonly ICollection _source; private readonly Action _onNext; diff --git a/Source/SuperLinq/FillBackward.cs b/Source/SuperLinq/FillBackward.cs index 9b0a06f6..8e309594 100644 --- a/Source/SuperLinq/FillBackward.cs +++ b/Source/SuperLinq/FillBackward.cs @@ -122,7 +122,7 @@ private static IEnumerable FillBackwardCore(IEnumerable source, Func : IteratorCollection + private sealed class FillBackwardCollection : CollectionIterator { private readonly ICollection _source; private readonly Func _predicate; diff --git a/Source/SuperLinq/FillForward.cs b/Source/SuperLinq/FillForward.cs index 727732c8..6d42548c 100644 --- a/Source/SuperLinq/FillForward.cs +++ b/Source/SuperLinq/FillForward.cs @@ -110,7 +110,7 @@ private static IEnumerable FillForwardCore(IEnumerable source, Func : IteratorCollection + private sealed class FillForwardCollection : CollectionIterator { private readonly ICollection _source; private readonly Func _predicate; diff --git a/Source/SuperLinq/Interleave.cs b/Source/SuperLinq/Interleave.cs index b607d306..586a0a46 100644 --- a/Source/SuperLinq/Interleave.cs +++ b/Source/SuperLinq/Interleave.cs @@ -82,7 +82,7 @@ private static IEnumerable InterleaveCore(IEnumerable> sour } } - private sealed class InterleaveIterator : IteratorCollection + private sealed class InterleaveIterator : CollectionIterator { private readonly IEnumerable> _sources; diff --git a/Source/SuperLinq/Rank.cs b/Source/SuperLinq/Rank.cs index 3a696a25..4b5e744f 100644 --- a/Source/SuperLinq/Rank.cs +++ b/Source/SuperLinq/Rank.cs @@ -203,7 +203,7 @@ public static partial class SuperEnumerable } } - private sealed class RankIterator : IteratorCollection + private sealed class RankIterator : CollectionIterator<(TSource, int)> { private readonly ICollection _source; private readonly Func _keySelector; diff --git a/Source/SuperLinq/Scan.cs b/Source/SuperLinq/Scan.cs index 11c4bf2b..030de6a7 100644 --- a/Source/SuperLinq/Scan.cs +++ b/Source/SuperLinq/Scan.cs @@ -48,7 +48,7 @@ private static IEnumerable ScanCore(IEnumerable sourc } } - private class ScanIterator : IteratorCollection + private class ScanIterator : CollectionIterator { private readonly ICollection _source; private readonly Func _transformation; @@ -132,7 +132,7 @@ private static IEnumerable ScanCore( } } - private class ScanStateIterator : IteratorCollection + private class ScanStateIterator : CollectionIterator { private readonly ICollection _source; private readonly TState _state; diff --git a/Source/SuperLinq/ScanRight.cs b/Source/SuperLinq/ScanRight.cs index 064b611a..778f6592 100644 --- a/Source/SuperLinq/ScanRight.cs +++ b/Source/SuperLinq/ScanRight.cs @@ -1,4 +1,6 @@ -namespace SuperLinq; +using System.Diagnostics.CodeAnalysis; + +namespace SuperLinq; public static partial class SuperEnumerable { @@ -31,27 +33,65 @@ public static IEnumerable ScanRight(this IEnumerable Guard.IsNotNull(source); Guard.IsNotNull(func); - return Core(source, func); + if (source is ICollection coll) + return new ScanRightIterator(coll, func); - static IEnumerable Core(IEnumerable source, Func func) - { - var list = source is IList l ? l : source.ToList(); + return ScanRightCore(source, func); + } + + private static IEnumerable ScanRightCore(IEnumerable source, Func func) + { + var list = source.ToList(); - if (list.Count == 0) - yield break; + if (list.Count == 0) + yield break; - var seed = list[^1]; - var stack = new Stack(list.Count); + var seed = list[^1]; + var stack = new Stack(list.Count); + stack.Push(seed); + + for (var i = list.Count - 2; i >= 0; i--) + { + seed = func(list[i], seed); stack.Push(seed); + } + + foreach (var item in stack) + yield return item; + } + + private class ScanRightIterator : CollectionIterator + { + private readonly ICollection _source; + private readonly Func _func; + + public ScanRightIterator(ICollection source, Func func) + { + _source = source; + _func = func; + } + + public override int Count => _source.Count; + + [ExcludeFromCodeCoverage] + protected override IEnumerable GetEnumerable() => + ScanRightCore(_source, _func); + + public override void CopyTo(T[] array, int arrayIndex) + { + var (sList, b, cnt) = _source is IList s + ? (s, 0, s.Count) + : (array, arrayIndex, SuperEnumerable.CopyTo(_source, array, arrayIndex)); + + var i = cnt - 1; + var state = sList[b + i]; + array[arrayIndex + i] = state; - for (var i = list.Count - 2; i >= 0; i--) + for (i--; i >= 0; i--) { - seed = func(list[i], seed); - stack.Push(seed); + state = _func(sList[b + i], state); + array[arrayIndex + i] = state; } - - foreach (var item in stack) - yield return item; } } @@ -85,22 +125,59 @@ public static IEnumerable ScanRight(this IEnu Guard.IsNotNull(source); Guard.IsNotNull(func); - return Core(source, seed, func); + if (source is ICollection coll) + return new ScanRightStateIterator(coll, seed, func); + + return ScanRightCore(source, seed, func); + } + + private static IEnumerable ScanRightCore(IEnumerable source, TAccumulate seed, Func func) + { + var list = source.ToList(); + var stack = new Stack(list.Count + 1); + stack.Push(seed); - static IEnumerable Core(IEnumerable source, TAccumulate seed, Func func) + for (var i = list.Count - 1; i >= 0; i--) { - var list = source is IList l ? l : source.ToList(); - var stack = new Stack(list.Count + 1); + seed = func(list[i], seed); stack.Push(seed); + } + + foreach (var item in stack) + yield return item; + } + + private class ScanRightStateIterator : CollectionIterator + { + private readonly ICollection _source; + private readonly TAccumulate _seed; + private readonly Func _func; + + public ScanRightStateIterator(ICollection source, TAccumulate seed, Func func) + { + _source = source; + _seed = seed; + _func = func; + } + + public override int Count => _source.Count + 1; + + [ExcludeFromCodeCoverage] + protected override IEnumerable GetEnumerable() => + ScanRightCore(_source, _seed, _func); + + public override void CopyTo(TAccumulate[] array, int arrayIndex) + { + var list = _source is IList l ? l : _source.ToList(); + + var seed = _seed; + array[arrayIndex + list.Count] = seed; for (var i = list.Count - 1; i >= 0; i--) { - seed = func(list[i], seed); - stack.Push(seed); + seed = _func(list[i], seed); + array[arrayIndex + i] = seed; } - - foreach (var item in stack) - yield return item; } } } diff --git a/Tests/SuperLinq.Test/BreakingList.cs b/Tests/SuperLinq.Test/BreakingList.cs new file mode 100644 index 00000000..9088f366 --- /dev/null +++ b/Tests/SuperLinq.Test/BreakingList.cs @@ -0,0 +1,41 @@ +namespace Test; + +internal static class BreakingList +{ + public static BreakingList AsBreakingList(this IEnumerable source) => new(source); +} + +internal class BreakingList : BreakingSequence, IList, IDisposableEnumerable +{ + protected readonly IList List; + + public BreakingList(params T[] values) : this((IList)values) { } + public BreakingList(IEnumerable source) => List = source.ToList(); + public BreakingList(IList list) => List = list; + + public int Count => List.Count; + + public T this[int index] + { + get => List[index]; + set => Assert.Fail("LINQ Operators should not be calling this method."); + } + + public int IndexOf(T item) + { + Assert.Fail("LINQ Operators should not be calling this method."); + return -1; + } + + public void Add(T item) => Assert.Fail("LINQ Operators should not be calling this method."); + public void Insert(int index, T item) => Assert.Fail("LINQ Operators should not be calling this method."); + public void Clear() => Assert.Fail("LINQ Operators should not be calling this method."); + public bool Contains(T item) => List.Contains(item); + public bool Remove(T item) { Assert.Fail("LINQ Operators should not be calling this method."); return false; } + public void RemoveAt(int index) => Assert.Fail("LINQ Operators should not be calling this method."); + public bool IsReadOnly => true; + + public virtual void CopyTo(T[] array, int arrayIndex) => Assert.Fail("LINQ Operators should not be calling this method."); + + public void Dispose() { } +} diff --git a/Tests/SuperLinq.Test/ScanRightTest.cs b/Tests/SuperLinq.Test/ScanRightTest.cs index 015dbb7d..5eef9f73 100644 --- a/Tests/SuperLinq.Test/ScanRightTest.cs +++ b/Tests/SuperLinq.Test/ScanRightTest.cs @@ -56,6 +56,42 @@ public void ScanRightIsLazy() _ = new BreakingSequence().ScanRight(BreakingFunc.Of()); } + [Fact] + public void ScanRightCollection() + { + using var seq = Enumerable.Range(1, 10).AsBreakingCollection(); + + var result = seq.ScanRight((a, b) => a + b); + Assert.Equal(10, result.Count()); + + result.ToArray() + .AssertSequenceEqual(55, 54, 52, 49, 45, 40, 34, 27, 19, 10); + Assert.Equal(1, seq.CopyCount); + + var arr = new int[20]; + _ = result.CopyTo(arr, 5); + arr + .AssertSequenceEqual(0, 0, 0, 0, 0, 55, 54, 52, 49, 45, 40, 34, 27, 19, 10, 0, 0, 0, 0, 0); + Assert.Equal(2, seq.CopyCount); + } + + [Fact] + public void ScanRightList() + { + using var seq = Enumerable.Range(1, 10).AsBreakingList(); + + var result = seq.ScanRight((a, b) => a + b); + Assert.Equal(10, result.Count()); + + result.ToArray() + .AssertSequenceEqual(55, 54, 52, 49, 45, 40, 34, 27, 19, 10); + + var arr = new int[20]; + _ = result.CopyTo(arr, 5); + arr + .AssertSequenceEqual(0, 0, 0, 0, 0, 55, 54, 52, 49, 45, 40, 34, 27, 19, 10, 0, 0, 0, 0, 0); + } + // ScanRight(source, seed, func) [Theory] @@ -108,4 +144,40 @@ public void ScanRightSeedIsLazy() { _ = new BreakingSequence().ScanRight(string.Empty, BreakingFunc.Of()); } + + [Fact] + public void ScanRightSeedCollection() + { + using var seq = Enumerable.Range(1, 10).AsBreakingCollection(); + + var result = seq.ScanRight(5, (a, b) => a + b); + Assert.Equal(11, result.Count()); + + result.ToArray() + .AssertSequenceEqual(60, 59, 57, 54, 50, 45, 39, 32, 24, 15, 5); + Assert.Equal(1, seq.CopyCount); + + var arr = new int[20]; + _ = result.CopyTo(arr, 5); + arr + .AssertSequenceEqual(0, 0, 0, 0, 0, 60, 59, 57, 54, 50, 45, 39, 32, 24, 15, 5, 0, 0, 0, 0); + Assert.Equal(2, seq.CopyCount); + } + + [Fact] + public void ScanRightSeedList() + { + using var seq = Enumerable.Range(1, 10).AsBreakingList(); + + var result = seq.ScanRight(5, (a, b) => a + b); + Assert.Equal(11, result.Count()); + + result.ToArray() + .AssertSequenceEqual(60, 59, 57, 54, 50, 45, 39, 32, 24, 15, 5); + + var arr = new int[20]; + _ = result.CopyTo(arr, 5); + arr + .AssertSequenceEqual(0, 0, 0, 0, 0, 60, 59, 57, 54, 50, 45, 39, 32, 24, 15, 5, 0, 0, 0, 0); + } }