From 79b299162688f09194317c9878c551f494113e9c Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Mon, 17 Jun 2024 18:27:17 +0200 Subject: [PATCH] Update dependencies and add netstandard2.0 target (for HL) and change code accordingly. --- .../HEAL.Expressions.Benchmark.csproj | 4 +- .../HEAL.Expressions.Tests.csproj | 10 ++-- HEAL.Expressions/ExprFormatter.cs | 60 ++++++++++--------- HEAL.Expressions/ExpressionInterpreter.cs | 22 ++++--- HEAL.Expressions/HEAL.Expressions.csproj | 2 +- .../RuleBasedSimplificationVisitor.cs | 24 ++++---- ...L.NonlinearRegression.Console.Tests.csproj | 12 ++-- .../NLR_EndToEnd.cs | 3 +- .../NonlinearRegression.cs | 13 ++-- .../HEAL.NonlinearRegression.Console.csproj | 2 +- .../HEAL.NonlinearRegression.csproj | 2 +- 11 files changed, 83 insertions(+), 71 deletions(-) diff --git a/HEAL.Expressions.Benchmark/HEAL.Expressions.Benchmark.csproj b/HEAL.Expressions.Benchmark/HEAL.Expressions.Benchmark.csproj index 3d5642b..b93f013 100644 --- a/HEAL.Expressions.Benchmark/HEAL.Expressions.Benchmark.csproj +++ b/HEAL.Expressions.Benchmark/HEAL.Expressions.Benchmark.csproj @@ -1,12 +1,12 @@ Exe - net7.0 + net8.0 enable enable - + diff --git a/HEAL.Expressions.Tests/HEAL.Expressions.Tests.csproj b/HEAL.Expressions.Tests/HEAL.Expressions.Tests.csproj index 248a791..1cb9bd1 100644 --- a/HEAL.Expressions.Tests/HEAL.Expressions.Tests.csproj +++ b/HEAL.Expressions.Tests/HEAL.Expressions.Tests.csproj @@ -1,13 +1,13 @@ - net7.0 + net8.0 false - - - - + + + + all runtime; build; native; contentfiles; analyzers; buildtransitive diff --git a/HEAL.Expressions/ExprFormatter.cs b/HEAL.Expressions/ExprFormatter.cs index 48d867c..66f608d 100644 --- a/HEAL.Expressions/ExprFormatter.cs +++ b/HEAL.Expressions/ExprFormatter.cs @@ -30,13 +30,13 @@ public static string ToString(Expression expr, string[] protected override Expression VisitBinary(BinaryExpression node) { if (node.NodeType == ExpressionType.ArrayIndex) { if (node.Left == pParam) { - sb.AppendFormat(CultureInfo.InvariantCulture, "{0:g8}", p[GetIndex(node)]); + sb.AppendFormat(CultureInfo.InvariantCulture, "{0:g8}", p[ExprFormatter.GetIndex(node)]); } else if (node.Left == varParam) { - sb.Append(EscapeVarName(varNames[GetIndex(node)])); + sb.Append(ExprFormatter.EscapeVarName(varNames[ExprFormatter.GetIndex(node)])); } else throw new ArgumentException(); } else { FormatLeftChildExpr(node.NodeType, node.Left); - sb.Append($" {OpSymbol(node.NodeType)} "); + sb.Append($" {ExprFormatter.OpSymbol(node.NodeType)} "); if (!IsBinaryOp(node.Right)) { Visit(node.Right); } else if (node.NodeType == ExpressionType.Subtract || node.NodeType == ExpressionType.Divide) { @@ -81,7 +81,7 @@ private Expression VisitPow(MethodCallExpression node) { protected override Expression VisitUnary(UnaryExpression node) { if (node.NodeType == ExpressionType.UnaryPlus) Visit(node.Operand); else { - if (!IsBinaryOp(node.Operand) && Priority(node.NodeType) >= Priority(node.Operand.NodeType)) { + if (!IsBinaryOp(node.Operand) && ExprFormatter.Priority(node.NodeType) >= ExprFormatter.Priority(node.Operand.NodeType)) { sb.Append("-("); Visit(node.Operand); sb.Append(')'); @@ -115,34 +115,40 @@ private void FormatRightChildExpr(ExpressionType parentType, Expression childExp Visit(childExpr); } } - private int Priority(ExpressionType nodeType) => - nodeType switch { - ExpressionType.Constant => 13, - ExpressionType.Call or ExpressionType.ArrayIndex => 12, - ExpressionType.UnaryPlus or ExpressionType.Negate => 10, - ExpressionType.Multiply or ExpressionType.Divide => 9, - ExpressionType.Add or ExpressionType.Subtract => 7, - _ => throw new ArgumentException() - }; - - private string OpSymbol(ExpressionType nodeType) => - nodeType switch { - ExpressionType.Add => "+", - ExpressionType.Subtract => "-", - ExpressionType.Multiply => "*", - ExpressionType.Divide => "/", - _ => throw new ArgumentException() - }; - - - - private int GetIndex(BinaryExpression node) { + + private static int Priority(ExpressionType nodeType) { + switch(nodeType) { + case ExpressionType.Constant: return 13; + case ExpressionType.Call: + case ExpressionType.ArrayIndex: return 12; + case ExpressionType.UnaryPlus: + case ExpressionType.Negate: return 10; + case ExpressionType.Multiply: + case ExpressionType.Divide: return 9; + case ExpressionType.Add: + case ExpressionType.Subtract: return 7; + default: throw new ArgumentException(); + } + } + + private static string OpSymbol(ExpressionType nodeType) { + switch(nodeType) { + case ExpressionType.Add: return "+"; + case ExpressionType.Subtract: return "-"; + case ExpressionType.Multiply: return "*"; + case ExpressionType.Divide: return "/"; + default: throw new ArgumentException(); + } + } + + + private static int GetIndex(BinaryExpression node) { return (int)((ConstantExpression)node.Right).Value; } - private string EscapeVarName(string v) { + private static string EscapeVarName(string v) { if (v.Contains(' ')) return $"'{v}'"; else return v; } diff --git a/HEAL.Expressions/ExpressionInterpreter.cs b/HEAL.Expressions/ExpressionInterpreter.cs index 1488662..ca1e6d7 100644 --- a/HEAL.Expressions/ExpressionInterpreter.cs +++ b/HEAL.Expressions/ExpressionInterpreter.cs @@ -1,6 +1,5 @@ using System; using System.Collections.Generic; -using System.Collections.Immutable; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -166,7 +165,9 @@ private void Evaluate(double[] theta, double[] f, int startRow, int batchSize) { if (result.values.Length == 1) { // parameters or constants - Array.Fill(f, result.GetValue(0), startRow, batchSize); + // Array.Fill(f, result.GetValue(0), startRow, batchSize); + var val = result.GetValue(0); + for (int i = startRow; i < startRow + batchSize; i++) f[i] = val; } else { Array.Copy(result.values, 0, f, startRow, batchSize); } @@ -193,12 +194,15 @@ private void EvaluateWithJac(double[] theta, double[] f, int startRow, int batch // clear arrays if (jacX != null) Array.Clear(jacX, startRow * jacX.GetLength(1), batchSize * jacX.GetLength(1)); if (jacTheta != null) Array.Clear(jacTheta, startRow * jacTheta.GetLength(1), batchSize * jacTheta.GetLength(1)); - for (int i = 0; i < instructions.Count; i++) if (instructions[i].diffValues != null) Array.Clear(instructions[i].diffValues); + for (int i = 0; i < instructions.Count; i++) if (instructions[i].diffValues != null) Array.Clear(instructions[i].diffValues, 0, instructions[i].diffValues.Length); // backpropagate var lastInstr = instructions.Last(); - if (lastInstr.diffValues != null) Array.Fill(lastInstr.diffValues, 1.0); + if (lastInstr.diffValues != null) { + // Array.Fill(lastInstr.diffValues, 1.0); + for (int i = 0; i < lastInstr.diffValues.Length; i++) lastInstr.diffValues[i] = 1.0; + } for (int instrIdx = instructions.Count - 1; instrIdx >= 0; instrIdx--) { var curInstr = instructions[instrIdx]; @@ -369,11 +373,11 @@ private Instruction.OpcEnum OpCode(Expression expression) { private struct Instruction { public enum OpcEnum { None, Const, Param, Var, Neg, Add, Sub, Mul, Div, Log, Abs, Exp, Sin, Cos, Cosh, Tanh, Pow, PowAbs, Sqrt, Cbrt, Sign, Logistic, InvLogistic, LogisticPrime, InvLogisticPrime }; - public int idx1 { get; init; } // child idx1 for internal nodes, index into p or x for parameters or variables - public int idx2 { get; init; } // child idx2 for internal nodes (only for binary operations) - public OpcEnum opc { get; init; } - public double[] values { get; init; }// for internal nodes and variables - public double[] diffValues { get; init; } // for reverse autodiff + public int idx1 { get; set; } // child idx1 for internal nodes, index into p or x for parameters or variables + public int idx2 { get; set; } // child idx2 for internal nodes (only for binary operations) + public OpcEnum opc { get; set; } + public double[] values { get; set; }// for internal nodes and variables + public double[] diffValues { get; set; } // for reverse autodiff public double GetValue(int idx) => values.Length == 1 ? values[0] : values[idx]; } diff --git a/HEAL.Expressions/HEAL.Expressions.csproj b/HEAL.Expressions/HEAL.Expressions.csproj index 40711dc..67e47bb 100644 --- a/HEAL.Expressions/HEAL.Expressions.csproj +++ b/HEAL.Expressions/HEAL.Expressions.csproj @@ -1,5 +1,5 @@ - net7.0 + net8.0;netstandard2.0 \ No newline at end of file diff --git a/HEAL.Expressions/RuleBasedSimplificationVisitor.cs b/HEAL.Expressions/RuleBasedSimplificationVisitor.cs index 0e4c07c..fde5e85 100644 --- a/HEAL.Expressions/RuleBasedSimplificationVisitor.cs +++ b/HEAL.Expressions/RuleBasedSimplificationVisitor.cs @@ -15,9 +15,9 @@ public class RuleBasedSimplificationVisitor : ExpressionVisitor { private ParameterExpression p; private List pValues; - private readonly List<(string rule, Expression expr)> matchedRules = new(); // for debugging which rules are actually used + private readonly List<(string rule, Expression expr)> matchedRules = new List<(string rule, Expression expr)>(); // for debugging which rules are actually used - private readonly Dictionary visitCache = new(); // for memoization of Visit() methods + private readonly Dictionary visitCache = new Dictionary(); // for memoization of Visit() methods private RuleBasedSimplificationVisitor(ParameterExpression p, double[] pValues, bool debugRules = false) { this.p = p; @@ -1021,7 +1021,7 @@ private void AddReparameterizationRules() { // only parameters and constants (rule for folding constants should be applied first) new MethodCallExpressionRule( "fold parameters and constants in method calls", - e => e.Arguments.All(e => IsParameter(e) || IsConstant(e)), + e => e.Arguments.All(arg => IsParameter(arg) || IsConstant(arg)), e => NewParameter((double) e.Method.Invoke(e.Object, e.Arguments.Select(GetParameterOrConstantValue).OfType().ToArray())) ), // @@ -1340,12 +1340,12 @@ protected override Expression VisitBinary(BinaryExpression node) { var right = Visit(node.Right); result = node.Update(left, null, right); - var r = binaryRules.FirstOrDefault(r => r.Match((BinaryExpression)result)); + var r = binaryRules.FirstOrDefault(ru => ru.Match((BinaryExpression)result)); while (r != default(BinaryExpressionRule)) { MarkUsage(r.Description, result); result = r.Apply((BinaryExpression)result); if (result is BinaryExpression binExpr) { - r = binaryRules.FirstOrDefault(r => r.Match(binExpr)); + r = binaryRules.FirstOrDefault(ru => ru.Match(binExpr)); } else break; } @@ -1359,12 +1359,12 @@ protected override Expression VisitUnary(UnaryExpression node) { var opd = Visit(node.Operand); result = node.Update(opd); - var r = unaryRules.FirstOrDefault(r => r.Match((UnaryExpression)result)); + var r = unaryRules.FirstOrDefault(ru => ru.Match((UnaryExpression)result)); while (r != default(UnaryExpressionRule)) { MarkUsage(r.Description, result); result = r.Apply((UnaryExpression)result); if (result is UnaryExpression unaryExpr) { - r = unaryRules.FirstOrDefault(r => r.Match(unaryExpr)); + r = unaryRules.FirstOrDefault(ru => ru.Match(unaryExpr)); } else break; } @@ -1376,12 +1376,12 @@ protected override Expression VisitMethodCall(MethodCallExpression node) { if (visitCache.TryGetValue(node.ToString(), out var result)) return result; result = node.Update(node.Object, node.Arguments.Select(Visit)); - var r = callRules.FirstOrDefault(r => r.Match((MethodCallExpression)result)); + var r = callRules.FirstOrDefault(ru => ru.Match((MethodCallExpression)result)); while (r != default(MethodCallExpressionRule)) { MarkUsage(r.Description, result); result = r.Apply((MethodCallExpression)result); if (result is MethodCallExpression callExpr) { - r = callRules.FirstOrDefault(r => r.Match(callExpr)); + r = callRules.FirstOrDefault(ru => ru.Match(callExpr)); } else break; } @@ -1392,7 +1392,7 @@ protected override Expression VisitMethodCall(MethodCallExpression node) { // for debugging rules private void MarkUsage(string description, Expression expr) { - matchedRules.Add(new(description, expr)); + matchedRules.Add(ValueTuple.Create(description, expr)); } private int Compare(Expression left, Expression right) { @@ -1488,8 +1488,8 @@ private bool HasScalingParameter(Expression arg) { private bool IsPower(MethodInfo method) => method == pow || method == powabs; private IEnumerable FoldTerms(IEnumerable terms) { - Dictionary exprStr2scale = new(); - Dictionary exprStr2expr = new(); + var exprStr2scale = new Dictionary(); + var exprStr2expr = new Dictionary(); foreach (var t in terms) { (var scaledTerm, var scale) = ExtractScaleExprFromTerm(t); var scaledTermStr = scaledTerm.ToString(); diff --git a/HEAL.NonlinearRegression.Console.Tests/HEAL.NonlinearRegression.Console.Tests.csproj b/HEAL.NonlinearRegression.Console.Tests/HEAL.NonlinearRegression.Console.Tests.csproj index 4a89183..0ff8f01 100644 --- a/HEAL.NonlinearRegression.Console.Tests/HEAL.NonlinearRegression.Console.Tests.csproj +++ b/HEAL.NonlinearRegression.Console.Tests/HEAL.NonlinearRegression.Console.Tests.csproj @@ -1,19 +1,19 @@ - net7.0 + net8.0 enable enable false - - - - + + + + all runtime; build; native; contentfiles; analyzers; buildtransitive - + diff --git a/HEAL.NonlinearRegression.Console.Tests/NLR_EndToEnd.cs b/HEAL.NonlinearRegression.Console.Tests/NLR_EndToEnd.cs index c8067ec..1d458a4 100644 --- a/HEAL.NonlinearRegression.Console.Tests/NLR_EndToEnd.cs +++ b/HEAL.NonlinearRegression.Console.Tests/NLR_EndToEnd.cs @@ -1,4 +1,5 @@ using NUnit.Framework; +using NUnit.Framework.Legacy; using System.Globalization; namespace HEAL.NonlinearRegression.Console.Tests { @@ -484,7 +485,7 @@ internal void RunConsoleTest(Action action, string expected) { } var actual = File.ReadAllText(randFilename); System.Console.WriteLine(actual); - Assert.AreEqual(expected.ReplaceLineEndings(), actual.ReplaceLineEndings()); + ClassicAssert.AreEqual(expected.ReplaceLineEndings(), actual.ReplaceLineEndings()); } finally { if (File.Exists(randFilename)) File.Delete(randFilename); diff --git a/HEAL.NonlinearRegression.Console.Tests/NonlinearRegression.cs b/HEAL.NonlinearRegression.Console.Tests/NonlinearRegression.cs index 6f09a3e..d1dd04d 100644 --- a/HEAL.NonlinearRegression.Console.Tests/NonlinearRegression.cs +++ b/HEAL.NonlinearRegression.Console.Tests/NonlinearRegression.cs @@ -1,5 +1,6 @@ -using NUnit.Framework; -using System.Globalization; +using System.Globalization; +using NUnit.Framework; +using NUnit.Framework.Legacy; namespace HEAL.NonlinearRegression.Console.Tests { public class NLR { @@ -53,13 +54,13 @@ public void FitPuromycin() { nlr.WriteStatistics(); System.Console.WriteLine($"Deviance: {nlr.Deviance:e4}, BIC: {nlr.BIC:f2}"); - Assert.AreEqual(96.91354730673082, nlr.BIC, 1e-5); + ClassicAssert.AreEqual(96.91354730673082, nlr.BIC, 1e-5); var prediction = nlr.PredictWithIntervals(x, IntervalEnum.LaplaceApproximation); System.Console.WriteLine($"pred: {prediction[0, 0]}, low: {prediction[0, 2]}, high: {prediction[0, 3]}"); - Assert.AreEqual(50.565977770482867, prediction[0, 0], 1e-4); - Assert.AreEqual(41.543747215791058, prediction[0, 2], 1e-4); - Assert.AreEqual(59.588208325174676, prediction[0, 3], 1e-6); + ClassicAssert.AreEqual(50.565977770482867, prediction[0, 0], 1e-4); + ClassicAssert.AreEqual(41.543747215791058, prediction[0, 2], 1e-4); + ClassicAssert.AreEqual(59.588208325174676, prediction[0, 3], 1e-6); } } } diff --git a/HEAL.NonlinearRegression.Console/HEAL.NonlinearRegression.Console.csproj b/HEAL.NonlinearRegression.Console/HEAL.NonlinearRegression.Console.csproj index 859ded8..10780e3 100644 --- a/HEAL.NonlinearRegression.Console/HEAL.NonlinearRegression.Console.csproj +++ b/HEAL.NonlinearRegression.Console/HEAL.NonlinearRegression.Console.csproj @@ -1,7 +1,7 @@ Exe - net7.0 + net8.0 disable enable diff --git a/HEAL.NonlinearRegression/HEAL.NonlinearRegression.csproj b/HEAL.NonlinearRegression/HEAL.NonlinearRegression.csproj index 712bb89..7b1e71b 100644 --- a/HEAL.NonlinearRegression/HEAL.NonlinearRegression.csproj +++ b/HEAL.NonlinearRegression/HEAL.NonlinearRegression.csproj @@ -1,6 +1,6 @@  - net7.0 + net8.0;netstandard2.0 Fit and evaluate nonlinear regression models. https://github.com/heal-research/HEAL.NonlinearRegression README.md