From edbf519ca39e1303937632d2c1019616ea588ea7 Mon Sep 17 00:00:00 2001 From: Stuart Turner Date: Tue, 9 May 2023 21:13:44 -0500 Subject: [PATCH 1/3] `Lag`: `IList<>` implementation --- Source/SuperLinq/Lag.cs | 46 +++++++-- Tests/SuperLinq.Test/LagTest.cs | 172 +++++++++++++++++--------------- 2 files changed, 128 insertions(+), 90 deletions(-) diff --git a/Source/SuperLinq/Lag.cs b/Source/SuperLinq/Lag.cs index 9be4f0ec..685b0030 100644 --- a/Source/SuperLinq/Lag.cs +++ b/Source/SuperLinq/Lag.cs @@ -17,10 +17,7 @@ public static partial class SuperEnumerable /// public static IEnumerable<(TSource current, TSource? lag)> Lag(this IEnumerable source, int offset) { - Guard.IsNotNull(source); - - return source.Select(Some) - .Lag(offset, default, (curr, lag) => (curr.Value, lag is (true, var some) ? some : default)); + return source.Lag(offset, ValueTuple.Create); } /// @@ -41,11 +38,7 @@ public static partial class SuperEnumerable /// public static IEnumerable Lag(this IEnumerable source, int offset, Func resultSelector) { - Guard.IsNotNull(source); - Guard.IsNotNull(resultSelector); - - return source.Select(Some) - .Lag(offset, default, (curr, lag) => resultSelector(curr.Value, lag is (true, var some) ? some : default)); + return source.Lag(offset, default!, resultSelector); } /// @@ -70,6 +63,9 @@ public static IEnumerable Lag(this IEnumerable list) + return new LagIterator(list, offset, defaultLagValue, resultSelector); + return Core(source, offset, defaultLagValue, resultSelector); static IEnumerable Core(IEnumerable source, int offset, TSource defaultLagValue, Func resultSelector) @@ -84,4 +80,36 @@ static IEnumerable Core(IEnumerable source, int offset, TSourc } } } + + private sealed class LagIterator : ListIterator + { + private readonly IList _source; + private readonly int _offset; + private readonly TSource _defaultLagValue; + private readonly Func _resultSelector; + + public LagIterator(IList source, int offset, TSource defaultLagValue, Func resultSelector) + { + _source = source; + _offset = offset; + _defaultLagValue = defaultLagValue; + _resultSelector = resultSelector; + } + + public override int Count => _source.Count; + + protected override IEnumerable GetEnumerable() + { + var cnt = (uint)_source.Count; + for (var i = 0; i < cnt; i++) + yield return _resultSelector( + _source[i], + i < _offset ? _defaultLagValue : _source[i - _offset]); + } + + protected override TResult ElementAt(int index) => + _resultSelector( + _source[index], + index < _offset ? _defaultLagValue : _source[index - _offset]); + } } diff --git a/Tests/SuperLinq.Test/LagTest.cs b/Tests/SuperLinq.Test/LagTest.cs index aaaa064d..a33f2d6d 100644 --- a/Tests/SuperLinq.Test/LagTest.cs +++ b/Tests/SuperLinq.Test/LagTest.cs @@ -35,108 +35,118 @@ public void TestLagZeroOffset() new BreakingSequence().Lag(0, (val, lagVal) => val + lagVal)); } - /// - /// Verify that lag can accept an propagate a default value passed to it. - /// - [Fact] - public void TestLagExplicitDefaultValue() + public static IEnumerable GetIntSequences() => + Enumerable.Range(1, 100) + .GetListSequences() + .Select(x => new object[] { x }); + + [Theory] + [MemberData(nameof(GetIntSequences))] + public void TestLagExplicitDefaultValue(IDisposableEnumerable seq) { - using var sequence = Enumerable.Range(1, 100).AsTestingSequence(); - - var result = sequence.Lag(10, -1, (val, lagVal) => lagVal).ToList(); - Assert.Equal(100, result.Count); - Assert.Equal(Enumerable.Repeat(-1, 10), result.Take(10)); + using (seq) + { + var result = seq.Lag(10, -1, (val, lagVal) => lagVal); + result.AssertSequenceEqual( + Enumerable.Repeat(-1, 10).Concat(Enumerable.Range(1, 90))); + } } - [Fact] - public void TestLagTuple() + [Theory] + [MemberData(nameof(GetIntSequences))] + public void TestLagTuple(IDisposableEnumerable seq) { - using var sequence = Enumerable.Range(1, 100).AsTestingSequence(); - - var result = sequence.Lag(10).ToList(); - Assert.Equal(100, result.Count); - result.AssertSequenceEqual( - Enumerable.Range(1, 100).Select(x => (x, x <= 10 ? default : x - 10))); + using (seq) + { + var result = seq.Lag(10); + result.AssertSequenceEqual( + Enumerable.Range(1, 100).Select(x => (x, x <= 10 ? default : x - 10))); + } } - /// - /// Verify that lag will use default(T) if a specific default value is not supplied for the lag value. - /// - [Fact] - public void TestLagImplicitDefaultValue() + [Theory] + [MemberData(nameof(GetIntSequences))] + public void TestLagImplicitDefaultValue(IDisposableEnumerable seq) { - using var sequence = Enumerable.Range(1, 100).AsTestingSequence(); - - var result = sequence.Lag(10, (val, lagVal) => lagVal).ToList(); - Assert.Equal(100, result.Count); - Assert.Equal(Enumerable.Repeat(default(int), 10), result.Take(10)); + using (seq) + { + var result = seq.Lag(10, (val, lagVal) => lagVal); + result.AssertSequenceEqual( + Enumerable.Repeat(default(int), 10) + .Concat(Enumerable.Range(1, 90))); + } } - /// - /// Verify that if the lag offset is greater than the sequence length lag - /// still yields all of the elements of the source sequence. - /// - [Fact] - public void TestLagOffsetGreaterThanSequenceLength() + [Theory] + [MemberData(nameof(GetIntSequences))] + public void TestLagOffsetGreaterThanSequenceLength(IDisposableEnumerable seq) { - using var sequence = Enumerable.Range(1, 100).AsTestingSequence(); - - var result = sequence.Lag(100 + 1, (a, b) => a).ToList(); - Assert.Equal(100, result.Count); - Assert.Equal(Enumerable.Range(1, 100), result); + using (seq) + { + var result = seq.Lag(100 + 1, (a, b) => a); + result.AssertSequenceEqual( + Enumerable.Range(1, 100)); + } } - /// - /// Verify that lag actually yields the correct pair of values from the sequence - /// when offsetting by a single item. - /// - [Fact] - public void TestLagPassesCorrectLagValueOffsetBy1() + [Theory] + [MemberData(nameof(GetIntSequences))] + public void TestLagPassesCorrectLagValueOffsetBy1(IDisposableEnumerable seq) { - using var sequence = Enumerable.Range(1, 100).AsTestingSequence(); - - var result = sequence.Lag(1, (a, b) => new { A = a, B = b }).ToList(); - Assert.Equal(100, result.Count); - Assert.True(result.All(x => x.B == (x.A - 1))); + using (seq) + { + var result = seq.Lag(1); + result.AssertSequenceEqual( + Enumerable.Range(1, 100) + .Select(x => (x, x - 1))); + } } - /// - /// Verify that lag yields the correct pair of values from the sequence when - /// offsetting by more than a single item. - /// - [Fact] - public void TestLagPassesCorrectLagValuesOffsetBy2() + [Theory] + [MemberData(nameof(GetIntSequences))] + public void TestLagPassesCorrectLagValuesOffsetBy2(IDisposableEnumerable seq) { - using var sequence = Enumerable.Range(1, 100).AsTestingSequence(); - - var result = sequence.Lag(2, (a, b) => new { A = a, B = b }).ToList(); - Assert.Equal(100, result.Count); - Assert.True(result.Skip(2).All(x => x.B == (x.A - 2))); - Assert.True(result.Take(2).All(x => (x.A - x.B) == x.A)); + using (seq) + { + var result = seq.Lag(2); + result.AssertSequenceEqual( + Enumerable.Range(1, 100) + .Select(x => (x, x <= 2 ? 0 : x - 2))); + } } - [Fact] - public void TestLagWithNullableReferences() + public static IEnumerable GetStringSequences() => + Seq("foo", "bar", "baz", "qux") + .GetListSequences() + .Select(x => new object[] { x, }); + + [Theory] + [MemberData(nameof(GetStringSequences))] + public void TestLagWithNullableReferences(IDisposableEnumerable seq) { - using var words = TestingSequence.Of("foo", "bar", "baz", "qux"); - var result = words.Lag(2, (a, b) => new { A = a, B = b }); - result.AssertSequenceEqual( - new { A = "foo", B = (string?)null }, - new { A = "bar", B = (string?)null }, - new { A = "baz", B = (string?)"foo" }, - new { A = "qux", B = (string?)"bar" }); + using (seq) + { + var result = seq.Lag(2, (a, b) => new { A = a, B = b }); + result.AssertSequenceEqual( + new { A = "foo", B = (string?)null }, + new { A = "bar", B = (string?)null }, + new { A = "baz", B = (string?)"foo" }, + new { A = "qux", B = (string?)"bar" }); + } } - [Fact] - public void TestLagWithNonNullableReferences() + [Theory] + [MemberData(nameof(GetStringSequences))] + public void TestLagWithNonNullableReferences(IDisposableEnumerable seq) { - using var words = TestingSequence.Of("foo", "bar", "baz", "qux"); - var empty = string.Empty; - var result = words.Lag(2, empty, (a, b) => new { A = a, B = b }); - result.AssertSequenceEqual( - new { A = "foo", B = empty }, - new { A = "bar", B = empty }, - new { A = "baz", B = "foo" }, - new { A = "qux", B = "bar" }); + using (seq) + { + var result = seq.Lag(2, string.Empty, (a, b) => new { A = a, B = b }); + result.AssertSequenceEqual( + new { A = "foo", B = string.Empty, }, + new { A = "bar", B = string.Empty, }, + new { A = "baz", B = "foo" }, + new { A = "qux", B = "bar" }); + } } } From 0172e98fd40e6337a5ad4c326991d83396392280 Mon Sep 17 00:00:00 2001 From: Stuart Turner Date: Tue, 9 May 2023 21:16:05 -0500 Subject: [PATCH 2/3] Test `IList<>` behavior --- Tests/SuperLinq.Test/LagTest.cs | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/Tests/SuperLinq.Test/LagTest.cs b/Tests/SuperLinq.Test/LagTest.cs index a33f2d6d..77da2402 100644 --- a/Tests/SuperLinq.Test/LagTest.cs +++ b/Tests/SuperLinq.Test/LagTest.cs @@ -149,4 +149,15 @@ public void TestLagWithNonNullableReferences(IDisposableEnumerable seq) new { A = "qux", B = "bar" }); } } + + [Fact] + public void ZipMapListBehavior() + { + using var seq = Enumerable.Range(0, 10_000).AsBreakingList(); + + var result = seq.Lag(20); + Assert.Equal(10_000, result.Count()); + Assert.Equal((50, 30), result.ElementAt(50)); + Assert.Equal((9_950, 9_930), result.ElementAt(^50)); + } } From dee4505a0db3c6d387aa0e76f22e817c2e219c97 Mon Sep 17 00:00:00 2001 From: Stuart Turner Date: Tue, 9 May 2023 21:34:13 -0500 Subject: [PATCH 3/3] Improve coverage --- Tests/SuperLinq.Test/LagTest.cs | 1 + 1 file changed, 1 insertion(+) diff --git a/Tests/SuperLinq.Test/LagTest.cs b/Tests/SuperLinq.Test/LagTest.cs index 77da2402..3fcfe821 100644 --- a/Tests/SuperLinq.Test/LagTest.cs +++ b/Tests/SuperLinq.Test/LagTest.cs @@ -157,6 +157,7 @@ public void ZipMapListBehavior() var result = seq.Lag(20); Assert.Equal(10_000, result.Count()); + Assert.Equal((10, 0), result.ElementAt(10)); Assert.Equal((50, 30), result.ElementAt(50)); Assert.Equal((9_950, 9_930), result.ElementAt(^50)); }