Skip to content

Commit

Permalink
Improve ICollection<> error checking (#462)
Browse files Browse the repository at this point in the history
* DRY some `ICollection<>` tests
* Fix `CollectionIterator` bounds check
* `AssertCount`
* `Batch`
* `CountDown`
* `Do`
* `Exclude`
* `FillBackward`
* `FillForward`
* `Insert`
* `Interleave`
* `Lag`
* `Lead`
* `Pad`
* `PadStart`
* `PreScan`
* `Rank`
* `Replace`
* `Sequence`
* `ScanBy`
* `ScanRight`
* `Scan`
* `TagFirstLast`
* `Window`
* `WindowLeft`
* `WindowRight`
* `ZipMap`
* `ZipLongest`
* `ZipShortest`
  • Loading branch information
viceroypenguin authored May 26, 2023
1 parent 933afe3 commit bad12b0
Show file tree
Hide file tree
Showing 49 changed files with 316 additions and 135 deletions.
4 changes: 2 additions & 2 deletions Generators/SuperLinq.Generator/ZipLongest.sbntxt
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ public static partial class SuperEnumerable
{{~ end ~}}) =>
ZipLongest({{~ for $j in 1..$i ~}}{{ $ordinals[$j] }}, {{ end }}global::System.ValueTuple.Create);

private class ZipLongestIterator<{{ for $j in 1..$i }}T{{ $j }}, {{ end }}TResult> : ListIterator<TResult>
private sealed class ZipLongestIterator<{{ for $j in 1..$i }}T{{ $j }}, {{ end }}TResult> : ListIterator<TResult>
{
{{~ for $j in 1..$i ~}}
private readonly global::System.Collections.Generic.IList<T{{ $j }}> _list{{ $j }};
Expand Down Expand Up @@ -160,7 +160,7 @@ public static partial class SuperEnumerable

protected override TResult ElementAt(int index)
{
global::CommunityToolkit.Diagnostics.Guard.IsLessThan(index, Count);
global::CommunityToolkit.Diagnostics.Guard.IsBetweenOrEqualTo(index, 0, Count - 1);

return _resultSelector(
{{~ for $j in 1..$i ~}}
Expand Down
4 changes: 2 additions & 2 deletions Generators/SuperLinq.Generator/ZipShortest.sbntxt
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ public static partial class SuperEnumerable
{{~ end ~}}) =>
ZipShortest({{~ for $j in 1..$i ~}}{{ $ordinals[$j] }}, {{ end }}global::System.ValueTuple.Create);

private class ZipShortestIterator<{{ for $j in 1..$i }}T{{ $j }}, {{ end }}TResult> : ListIterator<TResult>
private sealed class ZipShortestIterator<{{ for $j in 1..$i }}T{{ $j }}, {{ end }}TResult> : ListIterator<TResult>
{
{{~ for $j in 1..$i ~}}
private readonly global::System.Collections.Generic.IList<T{{ $j }}> _list{{ $j }};
Expand Down Expand Up @@ -145,7 +145,7 @@ public static partial class SuperEnumerable

protected override TResult ElementAt(int index)
{
global::CommunityToolkit.Diagnostics.Guard.IsLessThan(index, Count);
global::CommunityToolkit.Diagnostics.Guard.IsBetweenOrEqualTo(index, 0, Count - 1);

return _resultSelector(
{{~ for $j in 1..$i ~}}
Expand Down
8 changes: 4 additions & 4 deletions Source/SuperLinq/AssertCount.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ static IEnumerable<TSource> Core(IEnumerable<TSource> source, int count)
}
}

private class AssertCountCollectionIterator<T> : CollectionIterator<T>
private sealed class AssertCountCollectionIterator<T> : CollectionIterator<T>
{
private readonly IEnumerable<T> _source;
private readonly int _count;
Expand Down Expand Up @@ -76,13 +76,13 @@ protected override IEnumerable<T> GetEnumerable()
public override void CopyTo(T[] array, int arrayIndex)
{
Guard.IsNotNull(array);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, Count - 1);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, array.Length - Count);

_ = _source.CopyTo(array, arrayIndex);
}
}

private class AssertCountListIterator<T> : ListIterator<T>
private sealed class AssertCountListIterator<T> : ListIterator<T>
{
private readonly IList<T> _source;
private readonly int _count;
Expand Down Expand Up @@ -114,7 +114,7 @@ protected override IEnumerable<T> GetEnumerable()
public override void CopyTo(T[] array, int arrayIndex)
{
Guard.IsNotNull(array);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, Count - 1);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, array.Length - Count);

_source.CopyTo(array, arrayIndex);
}
Expand Down
2 changes: 1 addition & 1 deletion Source/SuperLinq/CollectionIterator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public virtual bool Contains(T item) =>
public virtual void CopyTo(T[] array, int arrayIndex)
{
Guard.IsNotNull(array);
Guard.IsGreaterThanOrEqualTo(arrayIndex, 0);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, array.Length - Count);

if (Count + arrayIndex > array.Length)
ThrowHelper.ThrowArgumentException(nameof(array), "Destination is not long enough.");
Expand Down
1 change: 1 addition & 0 deletions Source/SuperLinq/CountDown.cs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ protected override IEnumerable<TResult> GetEnumerable()

protected override TResult ElementAt(int index)
{
Guard.IsBetweenOrEqualTo(index, 0, Count - 1);
return _resultSelector(
_source[index],
_source.Count - index < _count ? _source.Count - index - 1 : null);
Expand Down
2 changes: 1 addition & 1 deletion Source/SuperLinq/FillBackward.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ protected override IEnumerable<T> GetEnumerable() =>
public override void CopyTo(T[] array, int arrayIndex)
{
Guard.IsNotNull(array);
Guard.IsGreaterThanOrEqualTo(arrayIndex, 0);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, array.Length - Count);

_source.CopyTo(array, arrayIndex);

Expand Down
2 changes: 1 addition & 1 deletion Source/SuperLinq/FillForward.cs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ protected override IEnumerable<T> GetEnumerable() =>
public override void CopyTo(T[] array, int arrayIndex)
{
Guard.IsNotNull(array);
Guard.IsGreaterThanOrEqualTo(arrayIndex, 0);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, array.Length - Count);

_source.CopyTo(array, arrayIndex);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ private static bool DoRead<T>(bool flag, IEnumerator<T> iter, out T? value)
/// <typeparam name = "T2">The type of the elements of <paramref name = "second"/>.</typeparam>
/// <param name = "second">The second sequence of elements.</param>
public static global::System.Collections.Generic.IEnumerable<(T1? , T2? )> ZipLongest<T1, T2>(this global::System.Collections.Generic.IEnumerable<T1> first, global::System.Collections.Generic.IEnumerable<T2> second) => ZipLongest(first, second, global::System.ValueTuple.Create);
private class ZipLongestIterator<T1, T2, TResult> : ListIterator<TResult>
private sealed class ZipLongestIterator<T1, T2, TResult> : ListIterator<TResult>
{
private readonly global::System.Collections.Generic.IList<T1> _list1;
private readonly global::System.Collections.Generic.IList<T2> _list2;
Expand All @@ -102,7 +102,7 @@ protected override IEnumerable<TResult> GetEnumerable()

protected override TResult ElementAt(int index)
{
global::CommunityToolkit.Diagnostics.Guard.IsLessThan(index, Count);
global::CommunityToolkit.Diagnostics.Guard.IsBetweenOrEqualTo(index, 0, Count - 1);
return _resultSelector(index < _list1.Count ? _list1[index] : default, index < _list2.Count ? _list2[index] : default);
}
}
Expand Down Expand Up @@ -177,7 +177,7 @@ protected override TResult ElementAt(int index)
/// <typeparam name = "T3">The type of the elements of <paramref name = "third"/>.</typeparam>
/// <param name = "third">The third sequence of elements.</param>
public static global::System.Collections.Generic.IEnumerable<(T1? , T2? , T3? )> ZipLongest<T1, T2, T3>(this global::System.Collections.Generic.IEnumerable<T1> first, global::System.Collections.Generic.IEnumerable<T2> second, global::System.Collections.Generic.IEnumerable<T3> third) => ZipLongest(first, second, third, global::System.ValueTuple.Create);
private class ZipLongestIterator<T1, T2, T3, TResult> : ListIterator<TResult>
private sealed class ZipLongestIterator<T1, T2, T3, TResult> : ListIterator<TResult>
{
private readonly global::System.Collections.Generic.IList<T1> _list1;
private readonly global::System.Collections.Generic.IList<T2> _list2;
Expand All @@ -204,7 +204,7 @@ protected override IEnumerable<TResult> GetEnumerable()

protected override TResult ElementAt(int index)
{
global::CommunityToolkit.Diagnostics.Guard.IsLessThan(index, Count);
global::CommunityToolkit.Diagnostics.Guard.IsBetweenOrEqualTo(index, 0, Count - 1);
return _resultSelector(index < _list1.Count ? _list1[index] : default, index < _list2.Count ? _list2[index] : default, index < _list3.Count ? _list3[index] : default);
}
}
Expand Down Expand Up @@ -286,7 +286,7 @@ protected override TResult ElementAt(int index)
/// <typeparam name = "T4">The type of the elements of <paramref name = "fourth"/>.</typeparam>
/// <param name = "fourth">The fourth sequence of elements.</param>
public static global::System.Collections.Generic.IEnumerable<(T1? , T2? , T3? , T4? )> ZipLongest<T1, T2, T3, T4>(this global::System.Collections.Generic.IEnumerable<T1> first, global::System.Collections.Generic.IEnumerable<T2> second, global::System.Collections.Generic.IEnumerable<T3> third, global::System.Collections.Generic.IEnumerable<T4> fourth) => ZipLongest(first, second, third, fourth, global::System.ValueTuple.Create);
private class ZipLongestIterator<T1, T2, T3, T4, TResult> : ListIterator<TResult>
private sealed class ZipLongestIterator<T1, T2, T3, T4, TResult> : ListIterator<TResult>
{
private readonly global::System.Collections.Generic.IList<T1> _list1;
private readonly global::System.Collections.Generic.IList<T2> _list2;
Expand Down Expand Up @@ -315,7 +315,7 @@ protected override IEnumerable<TResult> GetEnumerable()

protected override TResult ElementAt(int index)
{
global::CommunityToolkit.Diagnostics.Guard.IsLessThan(index, Count);
global::CommunityToolkit.Diagnostics.Guard.IsBetweenOrEqualTo(index, 0, Count - 1);
return _resultSelector(index < _list1.Count ? _list1[index] : default, index < _list2.Count ? _list2[index] : default, index < _list3.Count ? _list3[index] : default, index < _list4.Count ? _list4[index] : default);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public static partial class SuperEnumerable
/// <typeparam name = "TSecond">The type of the elements of <paramref name = "second"/>.</typeparam>
/// <param name = "second">The second sequence of elements.</param>
public static global::System.Collections.Generic.IEnumerable<(TFirst, TSecond)> ZipShortest<TFirst, TSecond>(this global::System.Collections.Generic.IEnumerable<TFirst> first, global::System.Collections.Generic.IEnumerable<TSecond> second) => ZipShortest(first, second, global::System.ValueTuple.Create);
private class ZipShortestIterator<T1, T2, TResult> : ListIterator<TResult>
private sealed class ZipShortestIterator<T1, T2, TResult> : ListIterator<TResult>
{
private readonly global::System.Collections.Generic.IList<T1> _list1;
private readonly global::System.Collections.Generic.IList<T2> _list2;
Expand All @@ -86,7 +86,7 @@ protected override IEnumerable<TResult> GetEnumerable()

protected override TResult ElementAt(int index)
{
global::CommunityToolkit.Diagnostics.Guard.IsLessThan(index, Count);
global::CommunityToolkit.Diagnostics.Guard.IsBetweenOrEqualTo(index, 0, Count - 1);
return _resultSelector(_list1[index], _list2[index]);
}
}
Expand Down Expand Up @@ -156,7 +156,7 @@ protected override TResult ElementAt(int index)
/// <typeparam name = "TThird">The type of the elements of <paramref name = "third"/>.</typeparam>
/// <param name = "third">The third sequence of elements.</param>
public static global::System.Collections.Generic.IEnumerable<(TFirst, TSecond, TThird)> ZipShortest<TFirst, TSecond, TThird>(this global::System.Collections.Generic.IEnumerable<TFirst> first, global::System.Collections.Generic.IEnumerable<TSecond> second, global::System.Collections.Generic.IEnumerable<TThird> third) => ZipShortest(first, second, third, global::System.ValueTuple.Create);
private class ZipShortestIterator<T1, T2, T3, TResult> : ListIterator<TResult>
private sealed class ZipShortestIterator<T1, T2, T3, TResult> : ListIterator<TResult>
{
private readonly global::System.Collections.Generic.IList<T1> _list1;
private readonly global::System.Collections.Generic.IList<T2> _list2;
Expand All @@ -183,7 +183,7 @@ protected override IEnumerable<TResult> GetEnumerable()

protected override TResult ElementAt(int index)
{
global::CommunityToolkit.Diagnostics.Guard.IsLessThan(index, Count);
global::CommunityToolkit.Diagnostics.Guard.IsBetweenOrEqualTo(index, 0, Count - 1);
return _resultSelector(_list1[index], _list2[index], _list3[index]);
}
}
Expand Down Expand Up @@ -259,7 +259,7 @@ protected override TResult ElementAt(int index)
/// <typeparam name = "TFourth">The type of the elements of <paramref name = "fourth"/>.</typeparam>
/// <param name = "fourth">The fourth sequence of elements.</param>
public static global::System.Collections.Generic.IEnumerable<(TFirst, TSecond, TThird, TFourth)> ZipShortest<TFirst, TSecond, TThird, TFourth>(this global::System.Collections.Generic.IEnumerable<TFirst> first, global::System.Collections.Generic.IEnumerable<TSecond> second, global::System.Collections.Generic.IEnumerable<TThird> third, global::System.Collections.Generic.IEnumerable<TFourth> fourth) => ZipShortest(first, second, third, fourth, global::System.ValueTuple.Create);
private class ZipShortestIterator<T1, T2, T3, T4, TResult> : ListIterator<TResult>
private sealed class ZipShortestIterator<T1, T2, T3, T4, TResult> : ListIterator<TResult>
{
private readonly global::System.Collections.Generic.IList<T1> _list1;
private readonly global::System.Collections.Generic.IList<T2> _list2;
Expand Down Expand Up @@ -288,7 +288,7 @@ protected override IEnumerable<TResult> GetEnumerable()

protected override TResult ElementAt(int index)
{
global::CommunityToolkit.Diagnostics.Guard.IsLessThan(index, Count);
global::CommunityToolkit.Diagnostics.Guard.IsBetweenOrEqualTo(index, 0, Count - 1);
return _resultSelector(_list1[index], _list2[index], _list3[index], _list4[index]);
}
}
Expand Down
4 changes: 2 additions & 2 deletions Source/SuperLinq/Insert.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ private static IEnumerable<T> InsertCore<T>(IEnumerable<T> first, IEnumerable<T>
yield return iter.Current;
}

private class InsertCollectionIterator<T> : CollectionIterator<T>
private sealed class InsertCollectionIterator<T> : CollectionIterator<T>
{
private readonly IEnumerable<T> _first;
private readonly IEnumerable<T> _second;
Expand Down Expand Up @@ -179,7 +179,7 @@ public override void CopyTo(T[] array, int arrayIndex)
}
}

private class InsertListIterator<T> : ListIterator<T>
private sealed class InsertListIterator<T> : ListIterator<T>
{
private readonly IList<T> _first;
private readonly IList<T> _second;
Expand Down
7 changes: 5 additions & 2 deletions Source/SuperLinq/Lag.cs
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,12 @@ protected override IEnumerable<TResult> GetEnumerable()
i < _offset ? _defaultLagValue : _source[i - _offset]);
}

protected override TResult ElementAt(int index) =>
_resultSelector(
protected override TResult ElementAt(int index)
{
Guard.IsBetweenOrEqualTo(index, 0, Count - 1);
return _resultSelector(
_source[index],
index < _offset ? _defaultLagValue : _source[index - _offset]);
}
}
}
7 changes: 5 additions & 2 deletions Source/SuperLinq/Lead.cs
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,14 @@ protected override IEnumerable<TResult> GetEnumerable()
i < maxOffset ? _source[i + _offset] : _defaultLeadValue);
}

protected override TResult ElementAt(int index) =>
_resultSelector(
protected override TResult ElementAt(int index)
{
Guard.IsBetweenOrEqualTo(index, 0, Count - 1);
return _resultSelector(
_source[index],
index < Math.Max(_source.Count - _offset, 0)
? _source[index + _offset]
: _defaultLeadValue);
}
}
}
2 changes: 1 addition & 1 deletion Source/SuperLinq/Pad.cs
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ public override void CopyTo(T[] array, int arrayIndex)

protected override T ElementAt(int index)
{
Guard.IsLessThan(index, Count);
Guard.IsBetweenOrEqualTo(index, 0, Count - 1);
return index < _source.Count
? _source[index]
: _paddingSelector(index);
Expand Down
2 changes: 1 addition & 1 deletion Source/SuperLinq/PadStart.cs
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ public override void CopyTo(T[] array, int arrayIndex)

protected override T ElementAt(int index)
{
Guard.IsLessThan(index, Count);
Guard.IsBetweenOrEqualTo(index, 0, Count - 1);

var offset = Math.Max(_width - _source.Count, 0);
return index < offset
Expand Down
3 changes: 3 additions & 0 deletions Source/SuperLinq/PreScan.cs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ protected override IEnumerable<T> GetEnumerable() =>

public override void CopyTo(T[] array, int arrayIndex)
{
Guard.IsNotNull(array);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, array.Length - Count);

var (sList, b, cnt) = _source is IList<T> s
? (s, 0, s.Count)
: (array, arrayIndex, SuperEnumerable.CopyTo(_source, array, arrayIndex));
Expand Down
3 changes: 3 additions & 0 deletions Source/SuperLinq/Replace.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,9 @@ protected override IEnumerable<TSource> GetEnumerable()

public override void CopyTo(TSource[] array, int arrayIndex)
{
Guard.IsNotNull(array);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, array.Length - Count);

_source.CopyTo(array, arrayIndex);

var idx = _index.GetOffset(_source.Count);
Expand Down
6 changes: 6 additions & 0 deletions Source/SuperLinq/Scan.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ protected override IEnumerable<T> GetEnumerable() =>

public override void CopyTo(T[] array, int arrayIndex)
{
Guard.IsNotNull(array);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, array.Length - Count);

var (sList, b, cnt) = _source is IList<T> s
? (s, 0, s.Count)
: (array, arrayIndex, SuperEnumerable.CopyTo(_source, array, arrayIndex));
Expand Down Expand Up @@ -173,6 +176,9 @@ protected override IEnumerable<TState> GetEnumerable() =>

public override void CopyTo(TState[] array, int arrayIndex)
{
Guard.IsNotNull(array);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, array.Length - Count);

var list = _source is IList<TSource> l ? l : _source.ToList();

var state = _state;
Expand Down
6 changes: 6 additions & 0 deletions Source/SuperLinq/ScanRight.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ protected override IEnumerable<T> GetEnumerable() =>

public override void CopyTo(T[] array, int arrayIndex)
{
Guard.IsNotNull(array);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, array.Length - Count);

var (sList, b, cnt) = _source is IList<T> s
? (s, 0, s.Count)
: (array, arrayIndex, SuperEnumerable.CopyTo(_source, array, arrayIndex));
Expand Down Expand Up @@ -168,6 +171,9 @@ protected override IEnumerable<TAccumulate> GetEnumerable() =>

public override void CopyTo(TAccumulate[] array, int arrayIndex)
{
Guard.IsNotNull(array);
Guard.IsBetweenOrEqualTo(arrayIndex, 0, array.Length - Count);

var list = _source is IList<TSource> l ? l : _source.ToList();

var seed = _seed;
Expand Down
2 changes: 1 addition & 1 deletion Source/SuperLinq/Sequence.cs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public static IEnumerable<int> Sequence(int start, int stop, int step)
return new SequenceIterator(start, step, (((long)stop - start) / step) + 1);
}

private class SequenceIterator : ListIterator<int>
private sealed class SequenceIterator : ListIterator<int>
{
private readonly int _start;
private readonly int _step;
Expand Down
2 changes: 2 additions & 0 deletions Source/SuperLinq/ZipMap.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ public ZipMapIterator(IList<TSource> source, Func<TSource, TResult> selector)

protected override (TSource, TResult) ElementAt(int index)
{
Guard.IsBetweenOrEqualTo(index, 0, Count - 1);

var el = _source[index];
return (el, _selector(el));
}
Expand Down
10 changes: 4 additions & 6 deletions Tests/SuperLinq.Test/AssertCountTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ public void AssertCountCollectionBehavior()
using var seq = Enumerable.Range(0, 10_000).AsBreakingCollection();

var result = seq.AssertCount(10_000);
Assert.Equal(10_000, result.Count());
result.AssertCollectionErrorChecking(10_000);
result.AssertSequenceEqual(Enumerable.Range(0, 10_000));
}

Expand All @@ -73,14 +73,12 @@ public void AssertCountListBehavior()
using var seq = Enumerable.Range(0, 10_000).AsBreakingList();

var result = seq.AssertCount(10_000);
Assert.Equal(10_000, result.Count());
result.AssertCollectionErrorChecking(10_000);
result.AssertListElementChecking(10_000);

Assert.Equal(200, result.ElementAt(200));
Assert.Equal(1_200, result.ElementAt(1_200));
Assert.Equal(8_800, result.ElementAt(^1_200));

_ = Assert.Throws<ArgumentOutOfRangeException>(
"index",
() => result.ElementAt(40_001));
}

[Fact]
Expand Down
Loading

0 comments on commit bad12b0

Please sign in to comment.