Skip to content

Commit

Permalink
Update dependencies and add netstandard2.0 target (for HL) and change…
Browse files Browse the repository at this point in the history
… code accordingly.
  • Loading branch information
gkronber committed Jun 17, 2024
1 parent da3c05a commit 79b2991
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 71 deletions.
4 changes: 2 additions & 2 deletions HEAL.Expressions.Benchmark/HEAL.Expressions.Benchmark.csproj
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net7.0</TargetFramework>
<TargetFrameworks>net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="BenchmarkDotNet" Version="0.13.7" />
<PackageReference Include="BenchmarkDotNet" Version="0.13.12" />
</ItemGroup>
<ItemGroup>
<ProjectReference Include="..\HEAL.Expressions\HEAL.Expressions.csproj" />
Expand Down
10 changes: 5 additions & 5 deletions HEAL.Expressions.Tests/HEAL.Expressions.Tests.csproj
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net7.0</TargetFramework>
<TargetFrameworks>net8.0</TargetFrameworks>
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" />
<PackageReference Include="MSTest.TestAdapter" Version="3.1.1" />
<PackageReference Include="MSTest.TestFramework" Version="3.1.1" />
<PackageReference Include="coverlet.collector" Version="6.0.0">
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.10.0" />
<PackageReference Include="MSTest.TestAdapter" Version="3.4.3" />
<PackageReference Include="MSTest.TestFramework" Version="3.4.3" />
<PackageReference Include="coverlet.collector" Version="6.0.2">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
Expand Down
60 changes: 33 additions & 27 deletions HEAL.Expressions/ExprFormatter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@ public static string ToString(Expression<Expr.ParametricFunction> expr, string[]
protected override Expression VisitBinary(BinaryExpression node) {
if (node.NodeType == ExpressionType.ArrayIndex) {
if (node.Left == pParam) {
sb.AppendFormat(CultureInfo.InvariantCulture, "{0:g8}", p[GetIndex(node)]);
sb.AppendFormat(CultureInfo.InvariantCulture, "{0:g8}", p[ExprFormatter.GetIndex(node)]);
} else if (node.Left == varParam) {
sb.Append(EscapeVarName(varNames[GetIndex(node)]));
sb.Append(ExprFormatter.EscapeVarName(varNames[ExprFormatter.GetIndex(node)]));
} else throw new ArgumentException();
} else {
FormatLeftChildExpr(node.NodeType, node.Left);
sb.Append($" {OpSymbol(node.NodeType)} ");
sb.Append($" {ExprFormatter.OpSymbol(node.NodeType)} ");
if (!IsBinaryOp(node.Right)) {
Visit(node.Right);
} else if (node.NodeType == ExpressionType.Subtract || node.NodeType == ExpressionType.Divide) {
Expand Down Expand Up @@ -81,7 +81,7 @@ private Expression VisitPow(MethodCallExpression node) {
protected override Expression VisitUnary(UnaryExpression node) {
if (node.NodeType == ExpressionType.UnaryPlus) Visit(node.Operand);
else {
if (!IsBinaryOp(node.Operand) && Priority(node.NodeType) >= Priority(node.Operand.NodeType)) {
if (!IsBinaryOp(node.Operand) && ExprFormatter.Priority(node.NodeType) >= ExprFormatter.Priority(node.Operand.NodeType)) {
sb.Append("-(");
Visit(node.Operand);
sb.Append(')');
Expand Down Expand Up @@ -115,34 +115,40 @@ private void FormatRightChildExpr(ExpressionType parentType, Expression childExp
Visit(childExpr);
}
}
private int Priority(ExpressionType nodeType) =>
nodeType switch {
ExpressionType.Constant => 13,
ExpressionType.Call or ExpressionType.ArrayIndex => 12,
ExpressionType.UnaryPlus or ExpressionType.Negate => 10,
ExpressionType.Multiply or ExpressionType.Divide => 9,
ExpressionType.Add or ExpressionType.Subtract => 7,
_ => throw new ArgumentException()
};

private string OpSymbol(ExpressionType nodeType) =>
nodeType switch {
ExpressionType.Add => "+",
ExpressionType.Subtract => "-",
ExpressionType.Multiply => "*",
ExpressionType.Divide => "/",
_ => throw new ArgumentException()
};



private int GetIndex(BinaryExpression node) {

private static int Priority(ExpressionType nodeType) {
switch(nodeType) {
case ExpressionType.Constant: return 13;
case ExpressionType.Call:
case ExpressionType.ArrayIndex: return 12;
case ExpressionType.UnaryPlus:
case ExpressionType.Negate: return 10;
case ExpressionType.Multiply:
case ExpressionType.Divide: return 9;
case ExpressionType.Add:
case ExpressionType.Subtract: return 7;
default: throw new ArgumentException();
}
}

private static string OpSymbol(ExpressionType nodeType) {
switch(nodeType) {
case ExpressionType.Add: return "+";
case ExpressionType.Subtract: return "-";
case ExpressionType.Multiply: return "*";
case ExpressionType.Divide: return "/";
default: throw new ArgumentException();
}
}


private static int GetIndex(BinaryExpression node) {
return (int)((ConstantExpression)node.Right).Value;
}



private string EscapeVarName(string v) {
private static string EscapeVarName(string v) {
if (v.Contains(' ')) return $"'{v}'";
else return v;
}
Expand Down
22 changes: 13 additions & 9 deletions HEAL.Expressions/ExpressionInterpreter.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
Expand Down Expand Up @@ -166,7 +165,9 @@ private void Evaluate(double[] theta, double[] f, int startRow, int batchSize) {

if (result.values.Length == 1) {
// parameters or constants
Array.Fill(f, result.GetValue(0), startRow, batchSize);
// Array.Fill(f, result.GetValue(0), startRow, batchSize);
var val = result.GetValue(0);
for (int i = startRow; i < startRow + batchSize; i++) f[i] = val;
} else {
Array.Copy(result.values, 0, f, startRow, batchSize);
}
Expand All @@ -193,12 +194,15 @@ private void EvaluateWithJac(double[] theta, double[] f, int startRow, int batch
// clear arrays
if (jacX != null) Array.Clear(jacX, startRow * jacX.GetLength(1), batchSize * jacX.GetLength(1));
if (jacTheta != null) Array.Clear(jacTheta, startRow * jacTheta.GetLength(1), batchSize * jacTheta.GetLength(1));
for (int i = 0; i < instructions.Count; i++) if (instructions[i].diffValues != null) Array.Clear(instructions[i].diffValues);
for (int i = 0; i < instructions.Count; i++) if (instructions[i].diffValues != null) Array.Clear(instructions[i].diffValues, 0, instructions[i].diffValues.Length);


// backpropagate
var lastInstr = instructions.Last();
if (lastInstr.diffValues != null) Array.Fill(lastInstr.diffValues, 1.0);
if (lastInstr.diffValues != null) {
// Array.Fill(lastInstr.diffValues, 1.0);
for (int i = 0; i < lastInstr.diffValues.Length; i++) lastInstr.diffValues[i] = 1.0;
}

for (int instrIdx = instructions.Count - 1; instrIdx >= 0; instrIdx--) {
var curInstr = instructions[instrIdx];
Expand Down Expand Up @@ -369,11 +373,11 @@ private Instruction.OpcEnum OpCode(Expression expression) {
private struct Instruction {
public enum OpcEnum { None, Const, Param, Var, Neg, Add, Sub, Mul, Div, Log, Abs, Exp, Sin, Cos, Cosh, Tanh, Pow, PowAbs, Sqrt, Cbrt, Sign, Logistic, InvLogistic, LogisticPrime, InvLogisticPrime };

public int idx1 { get; init; } // child idx1 for internal nodes, index into p or x for parameters or variables
public int idx2 { get; init; } // child idx2 for internal nodes (only for binary operations)
public OpcEnum opc { get; init; }
public double[] values { get; init; }// for internal nodes and variables
public double[] diffValues { get; init; } // for reverse autodiff
public int idx1 { get; set; } // child idx1 for internal nodes, index into p or x for parameters or variables
public int idx2 { get; set; } // child idx2 for internal nodes (only for binary operations)
public OpcEnum opc { get; set; }
public double[] values { get; set; }// for internal nodes and variables
public double[] diffValues { get; set; } // for reverse autodiff

public double GetValue(int idx) => values.Length == 1 ? values[0] : values[idx];
}
Expand Down
2 changes: 1 addition & 1 deletion HEAL.Expressions/HEAL.Expressions.csproj
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net7.0</TargetFramework>
<TargetFrameworks>net8.0;netstandard2.0</TargetFrameworks>
</PropertyGroup>
</Project>
24 changes: 12 additions & 12 deletions HEAL.Expressions/RuleBasedSimplificationVisitor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ public class RuleBasedSimplificationVisitor : ExpressionVisitor {
private ParameterExpression p;
private List<double> pValues;

private readonly List<(string rule, Expression expr)> matchedRules = new(); // for debugging which rules are actually used
private readonly List<(string rule, Expression expr)> matchedRules = new List<(string rule, Expression expr)>(); // for debugging which rules are actually used

private readonly Dictionary<string, Expression> visitCache = new(); // for memoization of Visit() methods
private readonly Dictionary<string, Expression> visitCache = new Dictionary<string, Expression>(); // for memoization of Visit() methods

private RuleBasedSimplificationVisitor(ParameterExpression p, double[] pValues, bool debugRules = false) {
this.p = p;
Expand Down Expand Up @@ -1021,7 +1021,7 @@ private void AddReparameterizationRules() {
// only parameters and constants (rule for folding constants should be applied first)
new MethodCallExpressionRule(
"fold parameters and constants in method calls",
e => e.Arguments.All(e => IsParameter(e) || IsConstant(e)),
e => e.Arguments.All(arg => IsParameter(arg) || IsConstant(arg)),
e => NewParameter((double) e.Method.Invoke(e.Object, e.Arguments.Select(GetParameterOrConstantValue).OfType<object>().ToArray()))
),
//
Expand Down Expand Up @@ -1340,12 +1340,12 @@ protected override Expression VisitBinary(BinaryExpression node) {
var right = Visit(node.Right);
result = node.Update(left, null, right);

var r = binaryRules.FirstOrDefault(r => r.Match((BinaryExpression)result));
var r = binaryRules.FirstOrDefault(ru => ru.Match((BinaryExpression)result));
while (r != default(BinaryExpressionRule)) {
MarkUsage(r.Description, result);
result = r.Apply((BinaryExpression)result);
if (result is BinaryExpression binExpr) {
r = binaryRules.FirstOrDefault(r => r.Match(binExpr));
r = binaryRules.FirstOrDefault(ru => ru.Match(binExpr));
} else break;
}

Expand All @@ -1359,12 +1359,12 @@ protected override Expression VisitUnary(UnaryExpression node) {

var opd = Visit(node.Operand);
result = node.Update(opd);
var r = unaryRules.FirstOrDefault(r => r.Match((UnaryExpression)result));
var r = unaryRules.FirstOrDefault(ru => ru.Match((UnaryExpression)result));
while (r != default(UnaryExpressionRule)) {
MarkUsage(r.Description, result);
result = r.Apply((UnaryExpression)result);
if (result is UnaryExpression unaryExpr) {
r = unaryRules.FirstOrDefault(r => r.Match(unaryExpr));
r = unaryRules.FirstOrDefault(ru => ru.Match(unaryExpr));
} else break;
}

Expand All @@ -1376,12 +1376,12 @@ protected override Expression VisitMethodCall(MethodCallExpression node) {
if (visitCache.TryGetValue(node.ToString(), out var result)) return result;

result = node.Update(node.Object, node.Arguments.Select(Visit));
var r = callRules.FirstOrDefault(r => r.Match((MethodCallExpression)result));
var r = callRules.FirstOrDefault(ru => ru.Match((MethodCallExpression)result));
while (r != default(MethodCallExpressionRule)) {
MarkUsage(r.Description, result);
result = r.Apply((MethodCallExpression)result);
if (result is MethodCallExpression callExpr) {
r = callRules.FirstOrDefault(r => r.Match(callExpr));
r = callRules.FirstOrDefault(ru => ru.Match(callExpr));
} else break;
}

Expand All @@ -1392,7 +1392,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node) {

// for debugging rules
private void MarkUsage(string description, Expression expr) {
matchedRules.Add(new(description, expr));
matchedRules.Add(ValueTuple.Create(description, expr));
}

private int Compare(Expression left, Expression right) {
Expand Down Expand Up @@ -1488,8 +1488,8 @@ private bool HasScalingParameter(Expression arg) {
private bool IsPower(MethodInfo method) => method == pow || method == powabs;

private IEnumerable<Expression> FoldTerms(IEnumerable<Expression> terms) {
Dictionary<string, Expression> exprStr2scale = new();
Dictionary<string, Expression> exprStr2expr = new();
var exprStr2scale = new Dictionary<string, Expression>();
var exprStr2expr = new Dictionary<string, Expression>();
foreach (var t in terms) {
(var scaledTerm, var scale) = ExtractScaleExprFromTerm(t);
var scaledTermStr = scaledTerm.ToString();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net7.0</TargetFramework>
<TargetFrameworks>net8.0</TargetFrameworks>
<ImplicitUsings>enable</ImplicitUsings>
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
</PropertyGroup>
<ItemGroup>
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.7.1" />
<PackageReference Include="MSTest.TestAdapter" Version="3.1.1" />
<PackageReference Include="MSTest.TestFramework" Version="3.1.1" />
<PackageReference Include="coverlet.collector" Version="6.0.0">
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.10.0" />
<PackageReference Include="MSTest.TestAdapter" Version="3.4.3" />
<PackageReference Include="MSTest.TestFramework" Version="3.4.3" />
<PackageReference Include="coverlet.collector" Version="6.0.2">
<PrivateAssets>all</PrivateAssets>
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>
</PackageReference>
<PackageReference Include="NUnit" Version="3.13.3" />
<PackageReference Include="NUnit" Version="4.1.0" />
<PackageReference Include="NUnit3TestAdapter" Version="4.5.0" />
</ItemGroup>
<ItemGroup>
Expand Down
3 changes: 2 additions & 1 deletion HEAL.NonlinearRegression.Console.Tests/NLR_EndToEnd.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using NUnit.Framework;
using NUnit.Framework.Legacy;
using System.Globalization;

namespace HEAL.NonlinearRegression.Console.Tests {
Expand Down Expand Up @@ -484,7 +485,7 @@ internal void RunConsoleTest(Action action, string expected) {
}
var actual = File.ReadAllText(randFilename);
System.Console.WriteLine(actual);
Assert.AreEqual(expected.ReplaceLineEndings(), actual.ReplaceLineEndings());
ClassicAssert.AreEqual(expected.ReplaceLineEndings(), actual.ReplaceLineEndings());
} finally {
if (File.Exists(randFilename))
File.Delete(randFilename);
Expand Down
13 changes: 7 additions & 6 deletions HEAL.NonlinearRegression.Console.Tests/NonlinearRegression.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using NUnit.Framework;
using System.Globalization;
using System.Globalization;
using NUnit.Framework;
using NUnit.Framework.Legacy;

namespace HEAL.NonlinearRegression.Console.Tests {
public class NLR {
Expand Down Expand Up @@ -53,13 +54,13 @@ public void FitPuromycin() {
nlr.WriteStatistics();

System.Console.WriteLine($"Deviance: {nlr.Deviance:e4}, BIC: {nlr.BIC:f2}");
Assert.AreEqual(96.91354730673082, nlr.BIC, 1e-5);
ClassicAssert.AreEqual(96.91354730673082, nlr.BIC, 1e-5);

var prediction = nlr.PredictWithIntervals(x, IntervalEnum.LaplaceApproximation);
System.Console.WriteLine($"pred: {prediction[0, 0]}, low: {prediction[0, 2]}, high: {prediction[0, 3]}");
Assert.AreEqual(50.565977770482867, prediction[0, 0], 1e-4);
Assert.AreEqual(41.543747215791058, prediction[0, 2], 1e-4);
Assert.AreEqual(59.588208325174676, prediction[0, 3], 1e-6);
ClassicAssert.AreEqual(50.565977770482867, prediction[0, 0], 1e-4);
ClassicAssert.AreEqual(41.543747215791058, prediction[0, 2], 1e-4);
ClassicAssert.AreEqual(59.588208325174676, prediction[0, 3], 1e-6);
}
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<OutputType>Exe</OutputType>
<TargetFramework>net7.0</TargetFramework>
<TargetFrameworks>net8.0</TargetFrameworks>
<ImplicitUsings>disable</ImplicitUsings>
<Nullable>enable</Nullable>
</PropertyGroup>
Expand Down
2 changes: 1 addition & 1 deletion HEAL.NonlinearRegression/HEAL.NonlinearRegression.csproj
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
<Project Sdk="Microsoft.NET.Sdk">
<PropertyGroup>
<TargetFramework>net7.0</TargetFramework>
<TargetFrameworks>net8.0;netstandard2.0</TargetFrameworks>
<Description>Fit and evaluate nonlinear regression models.</Description>
<PackageProjectUrl>https://github.com/heal-research/HEAL.NonlinearRegression</PackageProjectUrl>
<PackageReadmeFile>README.md</PackageReadmeFile>
Expand Down

0 comments on commit 79b2991

Please sign in to comment.