Skip to content

Commit

Permalink
Add weighted random
Browse files Browse the repository at this point in the history
  • Loading branch information
Vivelin committed Dec 7, 2023
1 parent 94309c8 commit 599f0c6
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 4 deletions.
34 changes: 34 additions & 0 deletions src/Vivelin.Core/EnumerableExtensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
using System.Diagnostics;

namespace Vivelin;

public static class EnumerableExtensions
{
public static T Sample<T>(this IReadOnlyList<T> source)
=> source.Sample(Random.Shared);

public static T Sample<T>(this IReadOnlyList<T> source, Random rng)
{
var index = rng.Next(source.Count);
return source[index];
}

public static T WeightedSample<T>(this IEnumerable<T> source) where T : IWeighted
=> source.WeightedSample(Random.Shared);

public static T WeightedSample<T>(this IEnumerable<T> source, Random rng) where T : IWeighted
{
var totalWeight = source.Sum(x => x.Weight);
var targetWeight = rng.NextDouble() * totalWeight;

foreach (var item in source)
{
if (targetWeight < item.Weight)
return item;

targetWeight -= item.Weight;
}

throw new UnreachableException();
}
}
6 changes: 6 additions & 0 deletions src/Vivelin.Core/IWeighted.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
namespace Vivelin;

public interface IWeighted
{
double Weight { get; }
}
10 changes: 8 additions & 2 deletions src/Vivelin.Core/Schrödinger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ namespace Vivelin;

public class Schrödinger<T>
{
private static readonly bool s_isWeighted = typeof(T).IsAssignableTo(typeof(IWeighted));

public Schrödinger(ReadOnlySpan<T> values)
{
Values = ImmutableList.Create(values);
Expand All @@ -20,7 +22,11 @@ public static implicit operator T(Schrödinger<T> value)

public T Resolve(Random rng)
{
var index = rng.Next(Values.Count);
return Values[index];
if (s_isWeighted)
{
return (T)Values.Cast<IWeighted>().WeightedSample(rng);
}

return Values.Sample(rng);
}
}
10 changes: 10 additions & 0 deletions src/Vivelin.Core/WeightedValue.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace Vivelin;

public readonly record struct Weighted<T>(T Value, double Weight) : IWeighted
{
public Weighted(T value)
: this(value, 1.0) { }

public static implicit operator T(Weighted<T> weighted)
=> weighted.Value;
}
29 changes: 29 additions & 0 deletions tests/Vivelin.Core.Tests/EnumerableExtensionsTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
namespace Vivelin.Core.Tests;

public class EnumerableExtensionsTests
{
[Fact]
public void Sample_ReturnsRandomValue()
{
var values = new List<string>(["Bulbasaur", "Charmander", "Squirtle"]);

values.Sample(Random.Shared).Should().BeOneOf(values);
}

[Fact]
public void Sample_ReturnsWeightedRandomValue()
{
var values = new List<Weighted<string>>([
new("Bulbasaur", 100),
new("Charmander", 1),
new("Squirtle", 1),
]);

var results = Enumerable.Range(0, 100).Select(_ => values.WeightedSample(Random.Shared));
var bulbasaurs = results.Count(x => x.Value == "Bulbasaur");
var charmanders = results.Count(x => x.Value == "Charmander");
var squirtles = results.Count(x => x.Value == "Squirtle");
bulbasaurs.Should().NotBeCloseTo(charmanders, delta: 70);
bulbasaurs.Should().NotBeCloseTo(squirtles, delta: 70);
}
}
21 changes: 19 additions & 2 deletions tests/Vivelin.Core.Tests/SchrödingerTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,25 @@ public void SchrodingersString_ImplicitlyResolvesToRandomString()
{
var values = new[] { "Bulbasaur", "Charmander", "Squirtle" };
var instance = new Schrödinger<string>(values);

// Avoiding FluentAssertions here to demonstrate implicit usage
Assert.Contains(instance, values);
}
}

[Fact]
public void SchrodingersWeightedString_SelectsHigherWeightsMoreOften()
{
var instance = new Schrödinger<Weighted<string>>([
new("Bulbasaur", 100),
new("Charmander", 1),
new("Squirtle", 1),
]);

var values = Enumerable.Range(0, 100).Select(_ => instance.Resolve());
var bulbasaurs = values.Count(x => x.Value == "Bulbasaur");
var charmanders = values.Count(x => x.Value == "Charmander");
var squirtles = values.Count(x => x.Value == "Squirtle");
bulbasaurs.Should().NotBeCloseTo(charmanders, delta: 70);
bulbasaurs.Should().NotBeCloseTo(squirtles, delta: 70);
}
}

0 comments on commit 599f0c6

Please sign in to comment.