Skip to content

Commit

Permalink
ScanRight: ICollection<> implementation (#387)
Browse files Browse the repository at this point in the history
  • Loading branch information
viceroypenguin authored May 7, 2023
1 parent 244f777 commit d6a3261
Show file tree
Hide file tree
Showing 10 changed files with 228 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,29 @@ namespace SuperLinq;
public partial class SuperEnumerable
{
[ExcludeFromCodeCoverage]
private abstract class IteratorCollection<TSource, TResult> : ICollection<TResult>, IReadOnlyCollection<TResult>
private abstract class CollectionIterator<T> : ICollection<T>, IReadOnlyCollection<T>
{
public bool IsReadOnly => true;
public void Add(TResult item) =>
public void Add(T item) =>
throw new NotSupportedException();
public bool Remove(TResult item) =>
public bool Remove(T item) =>
throw new NotSupportedException();
public void Clear() =>
throw new NotSupportedException();

IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();

protected abstract IEnumerable<TResult> GetEnumerable();
protected abstract IEnumerable<T> GetEnumerable();

public abstract int Count { get; }

public virtual IEnumerator<TResult> GetEnumerator() =>
public virtual IEnumerator<T> GetEnumerator() =>
GetEnumerable().GetEnumerator();

public virtual bool Contains(TResult item) =>
public virtual bool Contains(T item) =>
GetEnumerable().Contains(item);

public virtual void CopyTo(TResult[] array, int arrayIndex) =>
public virtual void CopyTo(T[] array, int arrayIndex) =>
GetEnumerable().CopyTo(array, arrayIndex);
}
}
2 changes: 1 addition & 1 deletion Source/SuperLinq/Do.cs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ private static IEnumerable<TSource> DoCore<TSource>(IEnumerable<TSource> source,
}


private class DoIterator<T> : IteratorCollection<T, T>
private class DoIterator<T> : CollectionIterator<T>
{
private readonly ICollection<T> _source;
private readonly Action<T> _onNext;
Expand Down
2 changes: 1 addition & 1 deletion Source/SuperLinq/FillBackward.cs
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ private static IEnumerable<T> FillBackwardCore<T>(IEnumerable<T> source, Func<T,
}
}

private sealed class FillBackwardCollection<T> : IteratorCollection<T, T>
private sealed class FillBackwardCollection<T> : CollectionIterator<T>
{
private readonly ICollection<T> _source;
private readonly Func<T, bool> _predicate;
Expand Down
2 changes: 1 addition & 1 deletion Source/SuperLinq/FillForward.cs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ private static IEnumerable<T> FillForwardCore<T>(IEnumerable<T> source, Func<T,
}
}

private sealed class FillForwardCollection<T> : IteratorCollection<T, T>
private sealed class FillForwardCollection<T> : CollectionIterator<T>
{
private readonly ICollection<T> _source;
private readonly Func<T, bool> _predicate;
Expand Down
2 changes: 1 addition & 1 deletion Source/SuperLinq/Interleave.cs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ private static IEnumerable<T> InterleaveCore<T>(IEnumerable<IEnumerable<T>> sour
}
}

private sealed class InterleaveIterator<T> : IteratorCollection<T, T>
private sealed class InterleaveIterator<T> : CollectionIterator<T>
{
private readonly IEnumerable<ICollection<T>> _sources;

Expand Down
2 changes: 1 addition & 1 deletion Source/SuperLinq/Rank.cs
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ public static partial class SuperEnumerable
}
}

private sealed class RankIterator<TSource, TKey> : IteratorCollection<TSource, (TSource, int)>
private sealed class RankIterator<TSource, TKey> : CollectionIterator<(TSource, int)>
{
private readonly ICollection<TSource> _source;
private readonly Func<TSource, TKey> _keySelector;
Expand Down
4 changes: 2 additions & 2 deletions Source/SuperLinq/Scan.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ private static IEnumerable<TSource> ScanCore<TSource>(IEnumerable<TSource> sourc
}
}

private class ScanIterator<T> : IteratorCollection<T, T>
private class ScanIterator<T> : CollectionIterator<T>
{
private readonly ICollection<T> _source;
private readonly Func<T, T, T> _transformation;
Expand Down Expand Up @@ -132,7 +132,7 @@ private static IEnumerable<TState> ScanCore<TSource, TState>(
}
}

private class ScanStateIterator<TSource, TState> : IteratorCollection<TSource, TState>
private class ScanStateIterator<TSource, TState> : CollectionIterator<TState>
{
private readonly ICollection<TSource> _source;
private readonly TState _state;
Expand Down
125 changes: 101 additions & 24 deletions Source/SuperLinq/ScanRight.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
namespace SuperLinq;
using System.Diagnostics.CodeAnalysis;

namespace SuperLinq;

public static partial class SuperEnumerable
{
Expand Down Expand Up @@ -31,27 +33,65 @@ public static IEnumerable<TSource> ScanRight<TSource>(this IEnumerable<TSource>
Guard.IsNotNull(source);
Guard.IsNotNull(func);

return Core(source, func);
if (source is ICollection<TSource> coll)
return new ScanRightIterator<TSource>(coll, func);

static IEnumerable<TSource> Core(IEnumerable<TSource> source, Func<TSource, TSource, TSource> func)
{
var list = source is IList<TSource> l ? l : source.ToList();
return ScanRightCore(source, func);
}

private static IEnumerable<TSource> ScanRightCore<TSource>(IEnumerable<TSource> source, Func<TSource, TSource, TSource> func)
{
var list = source.ToList();

if (list.Count == 0)
yield break;
if (list.Count == 0)
yield break;

var seed = list[^1];
var stack = new Stack<TSource>(list.Count);
var seed = list[^1];
var stack = new Stack<TSource>(list.Count);
stack.Push(seed);

for (var i = list.Count - 2; i >= 0; i--)
{
seed = func(list[i], seed);
stack.Push(seed);
}

foreach (var item in stack)
yield return item;
}

private class ScanRightIterator<T> : CollectionIterator<T>
{
private readonly ICollection<T> _source;
private readonly Func<T, T, T> _func;

public ScanRightIterator(ICollection<T> source, Func<T, T, T> func)
{
_source = source;
_func = func;
}

public override int Count => _source.Count;

[ExcludeFromCodeCoverage]
protected override IEnumerable<T> GetEnumerable() =>
ScanRightCore(_source, _func);

public override void CopyTo(T[] array, int arrayIndex)
{
var (sList, b, cnt) = _source is IList<T> s
? (s, 0, s.Count)
: (array, arrayIndex, SuperEnumerable.CopyTo(_source, array, arrayIndex));

var i = cnt - 1;
var state = sList[b + i];
array[arrayIndex + i] = state;

for (var i = list.Count - 2; i >= 0; i--)
for (i--; i >= 0; i--)
{
seed = func(list[i], seed);
stack.Push(seed);
state = _func(sList[b + i], state);
array[arrayIndex + i] = state;
}

foreach (var item in stack)
yield return item;
}
}

Expand Down Expand Up @@ -85,22 +125,59 @@ public static IEnumerable<TAccumulate> ScanRight<TSource, TAccumulate>(this IEnu
Guard.IsNotNull(source);
Guard.IsNotNull(func);

return Core(source, seed, func);
if (source is ICollection<TSource> coll)
return new ScanRightStateIterator<TSource, TAccumulate>(coll, seed, func);

return ScanRightCore(source, seed, func);
}

private static IEnumerable<TAccumulate> ScanRightCore<TSource, TAccumulate>(IEnumerable<TSource> source, TAccumulate seed, Func<TSource, TAccumulate, TAccumulate> func)
{
var list = source.ToList();
var stack = new Stack<TAccumulate>(list.Count + 1);
stack.Push(seed);

static IEnumerable<TAccumulate> Core(IEnumerable<TSource> source, TAccumulate seed, Func<TSource, TAccumulate, TAccumulate> func)
for (var i = list.Count - 1; i >= 0; i--)
{
var list = source is IList<TSource> l ? l : source.ToList();
var stack = new Stack<TAccumulate>(list.Count + 1);
seed = func(list[i], seed);
stack.Push(seed);
}

foreach (var item in stack)
yield return item;
}

private class ScanRightStateIterator<TSource, TAccumulate> : CollectionIterator<TAccumulate>
{
private readonly ICollection<TSource> _source;
private readonly TAccumulate _seed;
private readonly Func<TSource, TAccumulate, TAccumulate> _func;

public ScanRightStateIterator(ICollection<TSource> source, TAccumulate seed, Func<TSource, TAccumulate, TAccumulate> func)
{
_source = source;
_seed = seed;
_func = func;
}

public override int Count => _source.Count + 1;

[ExcludeFromCodeCoverage]
protected override IEnumerable<TAccumulate> GetEnumerable() =>
ScanRightCore(_source, _seed, _func);

public override void CopyTo(TAccumulate[] array, int arrayIndex)
{
var list = _source is IList<TSource> l ? l : _source.ToList();

var seed = _seed;
array[arrayIndex + list.Count] = seed;

for (var i = list.Count - 1; i >= 0; i--)
{
seed = func(list[i], seed);
stack.Push(seed);
seed = _func(list[i], seed);
array[arrayIndex + i] = seed;
}

foreach (var item in stack)
yield return item;
}
}
}
41 changes: 41 additions & 0 deletions Tests/SuperLinq.Test/BreakingList.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
namespace Test;

internal static class BreakingList
{
public static BreakingList<T> AsBreakingList<T>(this IEnumerable<T> source) => new(source);
}

internal class BreakingList<T> : BreakingSequence<T>, IList<T>, IDisposableEnumerable<T>
{
protected readonly IList<T> List;

public BreakingList(params T[] values) : this((IList<T>)values) { }
public BreakingList(IEnumerable<T> source) => List = source.ToList();
public BreakingList(IList<T> list) => List = list;

public int Count => List.Count;

public T this[int index]
{
get => List[index];
set => Assert.Fail("LINQ Operators should not be calling this method.");
}

public int IndexOf(T item)
{
Assert.Fail("LINQ Operators should not be calling this method.");
return -1;
}

public void Add(T item) => Assert.Fail("LINQ Operators should not be calling this method.");
public void Insert(int index, T item) => Assert.Fail("LINQ Operators should not be calling this method.");
public void Clear() => Assert.Fail("LINQ Operators should not be calling this method.");
public bool Contains(T item) => List.Contains(item);
public bool Remove(T item) { Assert.Fail("LINQ Operators should not be calling this method."); return false; }
public void RemoveAt(int index) => Assert.Fail("LINQ Operators should not be calling this method.");
public bool IsReadOnly => true;

public virtual void CopyTo(T[] array, int arrayIndex) => Assert.Fail("LINQ Operators should not be calling this method.");

public void Dispose() { }
}
72 changes: 72 additions & 0 deletions Tests/SuperLinq.Test/ScanRightTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,42 @@ public void ScanRightIsLazy()
_ = new BreakingSequence<int>().ScanRight(BreakingFunc.Of<int, int, int>());
}

[Fact]
public void ScanRightCollection()
{
using var seq = Enumerable.Range(1, 10).AsBreakingCollection();

var result = seq.ScanRight((a, b) => a + b);
Assert.Equal(10, result.Count());

result.ToArray()
.AssertSequenceEqual(55, 54, 52, 49, 45, 40, 34, 27, 19, 10);
Assert.Equal(1, seq.CopyCount);

var arr = new int[20];
_ = result.CopyTo(arr, 5);
arr
.AssertSequenceEqual(0, 0, 0, 0, 0, 55, 54, 52, 49, 45, 40, 34, 27, 19, 10, 0, 0, 0, 0, 0);
Assert.Equal(2, seq.CopyCount);
}

[Fact]
public void ScanRightList()
{
using var seq = Enumerable.Range(1, 10).AsBreakingList();

var result = seq.ScanRight((a, b) => a + b);
Assert.Equal(10, result.Count());

result.ToArray()
.AssertSequenceEqual(55, 54, 52, 49, 45, 40, 34, 27, 19, 10);

var arr = new int[20];
_ = result.CopyTo(arr, 5);
arr
.AssertSequenceEqual(0, 0, 0, 0, 0, 55, 54, 52, 49, 45, 40, 34, 27, 19, 10, 0, 0, 0, 0, 0);
}

// ScanRight(source, seed, func)

[Theory]
Expand Down Expand Up @@ -108,4 +144,40 @@ public void ScanRightSeedIsLazy()
{
_ = new BreakingSequence<int>().ScanRight(string.Empty, BreakingFunc.Of<int, string, string>());
}

[Fact]
public void ScanRightSeedCollection()
{
using var seq = Enumerable.Range(1, 10).AsBreakingCollection();

var result = seq.ScanRight(5, (a, b) => a + b);
Assert.Equal(11, result.Count());

result.ToArray()
.AssertSequenceEqual(60, 59, 57, 54, 50, 45, 39, 32, 24, 15, 5);
Assert.Equal(1, seq.CopyCount);

var arr = new int[20];
_ = result.CopyTo(arr, 5);
arr
.AssertSequenceEqual(0, 0, 0, 0, 0, 60, 59, 57, 54, 50, 45, 39, 32, 24, 15, 5, 0, 0, 0, 0);
Assert.Equal(2, seq.CopyCount);
}

[Fact]
public void ScanRightSeedList()
{
using var seq = Enumerable.Range(1, 10).AsBreakingList();

var result = seq.ScanRight(5, (a, b) => a + b);
Assert.Equal(11, result.Count());

result.ToArray()
.AssertSequenceEqual(60, 59, 57, 54, 50, 45, 39, 32, 24, 15, 5);

var arr = new int[20];
_ = result.CopyTo(arr, 5);
arr
.AssertSequenceEqual(0, 0, 0, 0, 0, 60, 59, 57, 54, 50, 45, 39, 32, 24, 15, 5, 0, 0, 0, 0);
}
}

0 comments on commit d6a3261

Please sign in to comment.