diff --git a/src/Vivelin.Core/EnumerableExtensions.cs b/src/Vivelin.Core/EnumerableExtensions.cs new file mode 100644 index 0000000..73a785e --- /dev/null +++ b/src/Vivelin.Core/EnumerableExtensions.cs @@ -0,0 +1,34 @@ +using System.Diagnostics; + +namespace Vivelin; + +public static class EnumerableExtensions +{ + public static T Sample(this IReadOnlyList source) + => source.Sample(Random.Shared); + + public static T Sample(this IReadOnlyList source, Random rng) + { + var index = rng.Next(source.Count); + return source[index]; + } + + public static T WeightedSample(this IEnumerable source) where T : IWeighted + => source.WeightedSample(Random.Shared); + + public static T WeightedSample(this IEnumerable source, Random rng) where T : IWeighted + { + var totalWeight = source.Sum(x => x.Weight); + var targetWeight = rng.NextDouble() * totalWeight; + + foreach (var item in source) + { + if (targetWeight < item.Weight) + return item; + + targetWeight -= item.Weight; + } + + throw new UnreachableException(); + } +} diff --git a/src/Vivelin.Core/IWeighted.cs b/src/Vivelin.Core/IWeighted.cs new file mode 100644 index 0000000..e59785c --- /dev/null +++ b/src/Vivelin.Core/IWeighted.cs @@ -0,0 +1,6 @@ +namespace Vivelin; + +public interface IWeighted +{ + double Weight { get; } +} \ No newline at end of file diff --git "a/src/Vivelin.Core/Schr\303\266dinger.cs" "b/src/Vivelin.Core/Schr\303\266dinger.cs" index 80668e6..f621b8f 100644 --- "a/src/Vivelin.Core/Schr\303\266dinger.cs" +++ "b/src/Vivelin.Core/Schr\303\266dinger.cs" @@ -4,6 +4,8 @@ namespace Vivelin; public class Schrödinger { + private static readonly bool s_isWeighted = typeof(T).IsAssignableTo(typeof(IWeighted)); + public Schrödinger(ReadOnlySpan values) { Values = ImmutableList.Create(values); @@ -20,7 +22,11 @@ public static implicit operator T(Schrödinger value) public T Resolve(Random rng) { - var index = rng.Next(Values.Count); - return Values[index]; + if (s_isWeighted) + { + return (T)Values.Cast().WeightedSample(rng); + } + + return Values.Sample(rng); } } diff --git a/src/Vivelin.Core/WeightedValue.cs b/src/Vivelin.Core/WeightedValue.cs new file mode 100644 index 0000000..a3fd8bb --- /dev/null +++ b/src/Vivelin.Core/WeightedValue.cs @@ -0,0 +1,10 @@ +namespace Vivelin; + +public readonly record struct Weighted(T Value, double Weight) : IWeighted +{ + public Weighted(T value) + : this(value, 1.0) { } + + public static implicit operator T(Weighted weighted) + => weighted.Value; +} diff --git a/tests/Vivelin.Core.Tests/EnumerableExtensionsTests.cs b/tests/Vivelin.Core.Tests/EnumerableExtensionsTests.cs new file mode 100644 index 0000000..50a4a21 --- /dev/null +++ b/tests/Vivelin.Core.Tests/EnumerableExtensionsTests.cs @@ -0,0 +1,29 @@ +namespace Vivelin.Core.Tests; + +public class EnumerableExtensionsTests +{ + [Fact] + public void Sample_ReturnsRandomValue() + { + var values = new List(["Bulbasaur", "Charmander", "Squirtle"]); + + values.Sample(Random.Shared).Should().BeOneOf(values); + } + + [Fact] + public void Sample_ReturnsWeightedRandomValue() + { + var values = new List>([ + new("Bulbasaur", 100), + new("Charmander", 1), + new("Squirtle", 1), + ]); + + var results = Enumerable.Range(0, 100).Select(_ => values.WeightedSample(Random.Shared)); + var bulbasaurs = results.Count(x => x.Value == "Bulbasaur"); + var charmanders = results.Count(x => x.Value == "Charmander"); + var squirtles = results.Count(x => x.Value == "Squirtle"); + bulbasaurs.Should().NotBeCloseTo(charmanders, delta: 70); + bulbasaurs.Should().NotBeCloseTo(squirtles, delta: 70); + } +} diff --git "a/tests/Vivelin.Core.Tests/Schr\303\266dingerTests.cs" "b/tests/Vivelin.Core.Tests/Schr\303\266dingerTests.cs" index 015c217..d9f44e4 100644 --- "a/tests/Vivelin.Core.Tests/Schr\303\266dingerTests.cs" +++ "b/tests/Vivelin.Core.Tests/Schr\303\266dingerTests.cs" @@ -17,8 +17,25 @@ public void SchrodingersString_ImplicitlyResolvesToRandomString() { var values = new[] { "Bulbasaur", "Charmander", "Squirtle" }; var instance = new Schrödinger(values); - + // Avoiding FluentAssertions here to demonstrate implicit usage Assert.Contains(instance, values); } -} \ No newline at end of file + + [Fact] + public void SchrodingersWeightedString_SelectsHigherWeightsMoreOften() + { + var instance = new Schrödinger>([ + new("Bulbasaur", 100), + new("Charmander", 1), + new("Squirtle", 1), + ]); + + var values = Enumerable.Range(0, 100).Select(_ => instance.Resolve()); + var bulbasaurs = values.Count(x => x.Value == "Bulbasaur"); + var charmanders = values.Count(x => x.Value == "Charmander"); + var squirtles = values.Count(x => x.Value == "Squirtle"); + bulbasaurs.Should().NotBeCloseTo(charmanders, delta: 70); + bulbasaurs.Should().NotBeCloseTo(squirtles, delta: 70); + } +}