Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lag: IList<> implementation #404

Merged
merged 3 commits into from
May 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions Source/SuperLinq/Lag.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,7 @@ public static partial class SuperEnumerable
/// </remarks>
public static IEnumerable<(TSource current, TSource? lag)> Lag<TSource>(this IEnumerable<TSource> source, int offset)
{
Guard.IsNotNull(source);

return source.Select(Some)
.Lag(offset, default, (curr, lag) => (curr.Value, lag is (true, var some) ? some : default));
return source.Lag(offset, ValueTuple.Create);
}

/// <summary>
Expand All @@ -41,11 +38,7 @@ public static partial class SuperEnumerable
/// </remarks>
public static IEnumerable<TResult> Lag<TSource, TResult>(this IEnumerable<TSource> source, int offset, Func<TSource, TSource?, TResult> resultSelector)
{
Guard.IsNotNull(source);
Guard.IsNotNull(resultSelector);

return source.Select(Some)
.Lag(offset, default, (curr, lag) => resultSelector(curr.Value, lag is (true, var some) ? some : default));
return source.Lag(offset, default!, resultSelector);
}

/// <summary>
Expand All @@ -70,6 +63,9 @@ public static IEnumerable<TResult> Lag<TSource, TResult>(this IEnumerable<TSourc
Guard.IsNotNull(resultSelector);
Guard.IsGreaterThanOrEqualTo(offset, 1);

if (source is IList<TSource> list)
return new LagIterator<TSource, TResult>(list, offset, defaultLagValue, resultSelector);

return Core(source, offset, defaultLagValue, resultSelector);

static IEnumerable<TResult> Core(IEnumerable<TSource> source, int offset, TSource defaultLagValue, Func<TSource, TSource, TResult> resultSelector)
Expand All @@ -84,4 +80,36 @@ static IEnumerable<TResult> Core(IEnumerable<TSource> source, int offset, TSourc
}
}
}

private sealed class LagIterator<TSource, TResult> : ListIterator<TResult>
{
private readonly IList<TSource> _source;
private readonly int _offset;
private readonly TSource _defaultLagValue;
private readonly Func<TSource, TSource, TResult> _resultSelector;

public LagIterator(IList<TSource> source, int offset, TSource defaultLagValue, Func<TSource, TSource, TResult> resultSelector)
{
_source = source;
_offset = offset;
_defaultLagValue = defaultLagValue;
_resultSelector = resultSelector;
}

public override int Count => _source.Count;

protected override IEnumerable<TResult> GetEnumerable()
{
var cnt = (uint)_source.Count;
for (var i = 0; i < cnt; i++)
yield return _resultSelector(
_source[i],
i < _offset ? _defaultLagValue : _source[i - _offset]);
}

protected override TResult ElementAt(int index) =>
_resultSelector(
_source[index],
index < _offset ? _defaultLagValue : _source[index - _offset]);
}
}
180 changes: 101 additions & 79 deletions Tests/SuperLinq.Test/LagTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,108 +35,130 @@ public void TestLagZeroOffset()
new BreakingSequence<int>().Lag(0, (val, lagVal) => val + lagVal));
}

/// <summary>
/// Verify that lag can accept an propagate a default value passed to it.
/// </summary>
[Fact]
public void TestLagExplicitDefaultValue()
public static IEnumerable<object[]> GetIntSequences() =>
Enumerable.Range(1, 100)
.GetListSequences()
.Select(x => new object[] { x });

[Theory]
[MemberData(nameof(GetIntSequences))]
public void TestLagExplicitDefaultValue(IDisposableEnumerable<int> seq)
{
using var sequence = Enumerable.Range(1, 100).AsTestingSequence();

var result = sequence.Lag(10, -1, (val, lagVal) => lagVal).ToList();
Assert.Equal(100, result.Count);
Assert.Equal(Enumerable.Repeat(-1, 10), result.Take(10));
using (seq)
{
var result = seq.Lag(10, -1, (val, lagVal) => lagVal);
result.AssertSequenceEqual(
Enumerable.Repeat(-1, 10).Concat(Enumerable.Range(1, 90)));
}
}

[Fact]
public void TestLagTuple()
[Theory]
[MemberData(nameof(GetIntSequences))]
public void TestLagTuple(IDisposableEnumerable<int> seq)
{
using var sequence = Enumerable.Range(1, 100).AsTestingSequence();

var result = sequence.Lag(10).ToList();
Assert.Equal(100, result.Count);
result.AssertSequenceEqual(
Enumerable.Range(1, 100).Select(x => (x, x <= 10 ? default : x - 10)));
using (seq)
{
var result = seq.Lag(10);
result.AssertSequenceEqual(
Enumerable.Range(1, 100).Select(x => (x, x <= 10 ? default : x - 10)));
}
}

/// <summary>
/// Verify that lag will use default(T) if a specific default value is not supplied for the lag value.
/// </summary>
[Fact]
public void TestLagImplicitDefaultValue()
[Theory]
[MemberData(nameof(GetIntSequences))]
public void TestLagImplicitDefaultValue(IDisposableEnumerable<int> seq)
{
using var sequence = Enumerable.Range(1, 100).AsTestingSequence();

var result = sequence.Lag(10, (val, lagVal) => lagVal).ToList();
Assert.Equal(100, result.Count);
Assert.Equal(Enumerable.Repeat(default(int), 10), result.Take(10));
using (seq)
{
var result = seq.Lag(10, (val, lagVal) => lagVal);
result.AssertSequenceEqual(
Enumerable.Repeat(default(int), 10)
.Concat(Enumerable.Range(1, 90)));
}
}

/// <summary>
/// Verify that if the lag offset is greater than the sequence length lag
/// still yields all of the elements of the source sequence.
/// </summary>
[Fact]
public void TestLagOffsetGreaterThanSequenceLength()
[Theory]
[MemberData(nameof(GetIntSequences))]
public void TestLagOffsetGreaterThanSequenceLength(IDisposableEnumerable<int> seq)
{
using var sequence = Enumerable.Range(1, 100).AsTestingSequence();

var result = sequence.Lag(100 + 1, (a, b) => a).ToList();
Assert.Equal(100, result.Count);
Assert.Equal(Enumerable.Range(1, 100), result);
using (seq)
{
var result = seq.Lag(100 + 1, (a, b) => a);
result.AssertSequenceEqual(
Enumerable.Range(1, 100));
}
}

/// <summary>
/// Verify that lag actually yields the correct pair of values from the sequence
/// when offsetting by a single item.
/// </summary>
[Fact]
public void TestLagPassesCorrectLagValueOffsetBy1()
[Theory]
[MemberData(nameof(GetIntSequences))]
public void TestLagPassesCorrectLagValueOffsetBy1(IDisposableEnumerable<int> seq)
{
using var sequence = Enumerable.Range(1, 100).AsTestingSequence();

var result = sequence.Lag(1, (a, b) => new { A = a, B = b }).ToList();
Assert.Equal(100, result.Count);
Assert.True(result.All(x => x.B == (x.A - 1)));
using (seq)
{
var result = seq.Lag(1);
result.AssertSequenceEqual(
Enumerable.Range(1, 100)
.Select(x => (x, x - 1)));
}
}

/// <summary>
/// Verify that lag yields the correct pair of values from the sequence when
/// offsetting by more than a single item.
/// </summary>
[Fact]
public void TestLagPassesCorrectLagValuesOffsetBy2()
[Theory]
[MemberData(nameof(GetIntSequences))]
public void TestLagPassesCorrectLagValuesOffsetBy2(IDisposableEnumerable<int> seq)
{
using var sequence = Enumerable.Range(1, 100).AsTestingSequence();
using (seq)
{
var result = seq.Lag(2);
result.AssertSequenceEqual(
Enumerable.Range(1, 100)
.Select(x => (x, x <= 2 ? 0 : x - 2)));
}
}

public static IEnumerable<object[]> GetStringSequences() =>
Seq("foo", "bar", "baz", "qux")
.GetListSequences()
.Select(x => new object[] { x, });

var result = sequence.Lag(2, (a, b) => new { A = a, B = b }).ToList();
Assert.Equal(100, result.Count);
Assert.True(result.Skip(2).All(x => x.B == (x.A - 2)));
Assert.True(result.Take(2).All(x => (x.A - x.B) == x.A));
[Theory]
[MemberData(nameof(GetStringSequences))]
public void TestLagWithNullableReferences(IDisposableEnumerable<string> seq)
{
using (seq)
{
var result = seq.Lag(2, (a, b) => new { A = a, B = b });
result.AssertSequenceEqual(
new { A = "foo", B = (string?)null },
new { A = "bar", B = (string?)null },
new { A = "baz", B = (string?)"foo" },
new { A = "qux", B = (string?)"bar" });
}
}

[Fact]
public void TestLagWithNullableReferences()
[Theory]
[MemberData(nameof(GetStringSequences))]
public void TestLagWithNonNullableReferences(IDisposableEnumerable<string> seq)
{
using var words = TestingSequence.Of("foo", "bar", "baz", "qux");
var result = words.Lag(2, (a, b) => new { A = a, B = b });
result.AssertSequenceEqual(
new { A = "foo", B = (string?)null },
new { A = "bar", B = (string?)null },
new { A = "baz", B = (string?)"foo" },
new { A = "qux", B = (string?)"bar" });
using (seq)
{
var result = seq.Lag(2, string.Empty, (a, b) => new { A = a, B = b });
result.AssertSequenceEqual(
new { A = "foo", B = string.Empty, },
new { A = "bar", B = string.Empty, },
new { A = "baz", B = "foo" },
new { A = "qux", B = "bar" });
}
}

[Fact]
public void TestLagWithNonNullableReferences()
public void ZipMapListBehavior()
{
using var words = TestingSequence.Of("foo", "bar", "baz", "qux");
var empty = string.Empty;
var result = words.Lag(2, empty, (a, b) => new { A = a, B = b });
result.AssertSequenceEqual(
new { A = "foo", B = empty },
new { A = "bar", B = empty },
new { A = "baz", B = "foo" },
new { A = "qux", B = "bar" });
using var seq = Enumerable.Range(0, 10_000).AsBreakingList();

var result = seq.Lag(20);
Assert.Equal(10_000, result.Count());
Assert.Equal((10, 0), result.ElementAt(10));
Assert.Equal((50, 30), result.ElementAt(50));
Assert.Equal((9_950, 9_930), result.ElementAt(^50));
}
}