expression compiler: automatic differentiation #253
253: expression compiler: automatic differentiation Task-Url: #253
for x which is part of DerivativeNode[operand=(a*x+(b*x^2))+(c*x^3), variable=x, derivative=((0*x+a*1)+((0*x^2)+(b*x^2)))+((0*x^3)+(c*x^3))]
    at arb4j/arb.expressions.nodes.VariableNode.type(VariableNode.java:592)
    at arb4j/arb.expressions.nodes.binary.BinaryOperationNode.type(BinaryOperationNode.java:363)
    at arb4j/arb.expressions.nodes.binary.BinaryOperationNode.type(BinaryOperationNode.java:362)
    at arb4j/arb.expressions.nodes.binary.BinaryOperationNode.type(BinaryOperationNode.java:362)
    at arb4j/arb.expressions.nodes.binary.BinaryOperationNode.generate(BinaryOperationNode.java:190)
    at arb4j/arb.expressions.nodes.DerivativeNode.generate(DerivativeNode.java:141)
    at arb4j/arb.expressions.Expression.generateEvaluationMethod(Expression.java:940)
    at arb4j/arb.expressions.Expression.generate(Expression.java:787)
    at arb4j/arb.expressions.Expression.defineClass(Expression.java:570)
    at arb4j/arb.expressions.Expression.getInstance(Expression.java:1218)
    at arb4j/arb.expressions.Expression.instantiate(Expression.java:1367)
    at arb4j/arb.functions.Function.instantiate(Function.java:126)
    at arb4j/arb.functions.rational.RationalNullaryFunction.express(RationalNullaryFunction.java:29)
    at arb4j/arb.functions.rational.RationalNullaryFunction.express(RationalNullaryFunction.java:39)
    at arb4j/arb.RationalFunction.express(RationalFunction.java:245)
    at arb4j/arb.expressions.ExpressionTest.testRationalFunctionDerivative(ExpressionTest.java:83)
    at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:103)
    at java.base/java.lang.reflect.Method.invoke(Method.java:580)
    at junit@4.13.2/junit.framework.TestCase.runTest(TestCase.java:177)
    at junit@4.13.2/junit.framework.TestCase.runBare(TestCase.java:142)
    at junit@4.13.2/junit.framework.TestResult$1.protect(TestResult.java:122)
    at junit@4.13.2/junit.framework.TestResult.runProtected(TestResult.java:142)
    at junit@4.13.2/junit.framework.TestResult.run(TestResult.java:125)
    at junit@4.13.2/junit.framework.TestCase.run(TestCase.java:130)
    at junit@4.13.2/junit.framework.TestSuite.runTest(TestSuite.java:241)
    at junit@4.13.2/junit.framework.TestSuite.run(TestSuite.java:236)
    at junit@4.13.2/org.junit.internal.runners.JUnit38ClassRunner.run(JUnit38ClassRunner.java:90)
    at org.eclipse.jdt.internal.junit4.runner.JUnit4TestReference.run(JUnit4TestReference.java:93)
    at org.eclipse.jdt.internal.junit.runner.TestExecution.run(TestExecution.java:40)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:530)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:758)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.run(RemoteTestRunner.java:453)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.main(RemoteTestRunner.java:211)
#253
interface arb.functions.real.RealFunction
    at arb4j/arb.functions.Function.newCoDomainInstance(Function.java:320)
    at arb4j/arb.functions.Function.evaluate(Function.java:239)
    at arb4j/arb.functions.Function.evaluate(Function.java:220)
    at arb4j/arb.functions.integer.Sequence.evaluate(Sequence.java:47)
    at arb4j/arb.expressions.nodes.unary.SphericalBesselFunctionNodeOfTheFirstKindTest.testj0ViaRealFunctionalExpression(SphericalBesselFunctionNodeOfTheFirstKindTest.java:24)
    at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:103)
    at java.base/java.lang.reflect.Method.invoke(Method.java:580)
    at junit@4.13.2/junit.framework.TestCase.runTest(TestCase.java:177)
    at junit@4.13.2/junit.framework.TestCase.runBare(TestCase.java:142)
    at junit@4.13.2/junit.framework.TestResult$1.protect(TestResult.java:122)
    at junit@4.13.2/junit.framework.TestResult.runProtected(TestResult.java:142)
    at junit@4.13.2/junit.framework.TestResult.run(TestResult.java:125)
    at junit@4.13.2/junit.framework.TestCase.run(TestCase.java:130)
    at junit@4.13.2/junit.framework.TestSuite.runTest(TestSuite.java:241)
    at junit@4.13.2/junit.framework.TestSuite.run(TestSuite.java:236)
    at junit@4.13.2/org.junit.internal.runners.JUnit38ClassRunner.run(JUnit38ClassRunner.java:90)
    at org.eclipse.jdt.internal.junit4.runner.JUnit4TestReference.run(JUnit4TestReference.java:93)
    at org.eclipse.jdt.internal.junit.runner.TestExecution.run(TestExecution.java:40)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:530)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:758)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.run(RemoteTestRunner.java:453)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.main(RemoteTestRunner.java:211)
#253
package arb.functions.real;

import arb.Initializable;
import arb.Integer;
import arb.Real;
import arb.Typesettable;
import arb.documentation.BusinessSourceLicenseVersionOnePointOne;
import arb.documentation.TheArb4jLibrary;
import arb.expressions.nodes.DerivativeNode;
import junit.framework.TestCase;

/**
 * Decompiled {@link DerivativeNode} test function
 *
 * @see BusinessSourceLicenseVersionOnePointOne © terms of the
 *      {@link TheArb4jLibrary}
 */
public class TestCompiledDerivative implements
                                    RealFunctional<Object, RealFunction>,
                                    Typesettable,
                                    AutoCloseable,
                                    Initializable
{
  public boolean       isInitialized;
  public final Integer cℤ2     = new Integer("3");
  public final Integer cℤ1     = new Integer("2");
  public final Integer cℤ4     = new Integer("1");
  public final Integer cℤ3     = new Integer("0");
  public Real          a;
  public Real          b;
  public Real          c;
  public Real          ifuncℝ4 = new Real();
  public Real          ifuncℝ5 = new Real();
  public Integer       iℤ2     = new Integer();
  public Real          ifuncℝ6 = new Real();
  public Integer       iℤ1     = new Integer();
  public Real          ifuncℝ7 = new Real();
  public Real          ifuncℝ1 = new Real();
  public Real          ifuncℝ2 = new Real();
  public Real          ifuncℝ3 = new Real();
  public Real          ifuncℝ8 = new Real();

  public static void main(String args[])
  {
    try ( TestCompiledDerivative derivative = new TestCompiledDerivative())
    {
      derivative.a = Real.named("a").set(2);
      derivative.b = Real.named("b").set(4);
      derivative.c = Real.named("c").set(6);
      RealFunction d   = derivative.evaluate(null, 128);
      double       val = d.eval(2.3);
      TestCase.assertEquals(115.61999999999998, val);
      System.out.format("%s(2.3)=%s\n", d, val);
    }
  }

  @Override
  public Class<RealFunction> coDomainType()
  {
    return RealFunction.class;
  }

  @Override
  public RealFunction evaluate(Object in, int order, int bits, RealFunction result)
  {
    if (!isInitialized)
    {
      initialize();
    }
    RealFunction realFunction = new RealFunction()
    {
      @Override
      public Real evaluate(Real input, int order, int bits, Real res)
      {
        return a.add(b.mul(cℤ1.mul(input.pow(cℤ1.sub(cℤ4, bits, iℤ1), bits, ifuncℝ1), bits, ifuncℝ2), bits, ifuncℝ3),
                     bits,
                     ifuncℝ4)
                .add(c.mul(cℤ2.mul(input.pow(cℤ2.sub(cℤ4, bits, iℤ2), bits, ifuncℝ5), bits, ifuncℝ6), bits, ifuncℝ7),
                     bits,
                     ifuncℝ8);
      }

      @Override
      public String toString()
      {
        return TestCompiledDerivative.this.toString();
      }
    };
    return realFunction;
  }

  @Override
  public void initialize()
  {
    if (isInitialized)
    {
      throw new AssertionError("Already initialized");
    }
    else if (a == null)
    {
      throw new AssertionError("x-∂a*x+b*x²+c*x³⁄∂x.a is null");
    }
    else if (b == null)
    {
      throw new AssertionError("x-∂a*x+b*x²+c*x³⁄∂x.b is null");
    }
    else if (c == null)
    {
      throw new AssertionError("x-∂a*x+b*x²+c*x³⁄∂x.c is null");
    }
    else
    {
      isInitialized = true;
    }
  }

  @Override
  public void close()
  {
    cℤ2.close();
    cℤ1.close();
    cℤ4.close();
    cℤ3.close();
    ifuncℝ4.close();
    ifuncℝ5.close();
    iℤ2.close();
    ifuncℝ6.close();
    iℤ1.close();
    ifuncℝ7.close();
    ifuncℝ1.close();
    ifuncℝ2.close();
    ifuncℝ3.close();
    ifuncℝ8.close();
  }

  @Override
  public String toString()
  {
    return "x➔∂a*x+b*x²+c*x³/∂x";
  }

  @Override
  public String typeset()
  {
    return "a + b \\cdot 2 \\cdot {x}^{(\\left(2-1\\right))} + c \\cdot 3 \\cdot {x}^{(\\left(3-1\\right))}";
  }
}
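As a quick sanity check of the expected value asserted in `main`, the derivative and its value at the test's parameters (a = 2, b = 4, c = 6, x = 2.3) work out as follows; 115.61999999999998 is 115.62 up to binary floating-point rounding.

```latex
\frac{d}{dx}\left(a x + b x^2 + c x^3\right) = a + 2 b x + 3 c x^2

f'(2.3) = 2 + 2 \cdot 4 \cdot 2.3 + 3 \cdot 6 \cdot 2.3^2 = 2 + 18.4 + 95.22 = 115.62
```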
    at k-2*j(k,x)func.evaluate(Unknown Source)
    at arb4j/arb.functions.real.RealFunction.eval(RealFunction.java:222)
    at arb4j/arb.expressions.nodes.unary.SphericalBesselFunctionNodeOfTheFirstKindTest.testj0ViaRealFunctionalExpression(SphericalBesselFunctionNodeOfTheFirstKindTest.java:25)
    at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:103)
    at java.base/java.lang.reflect.Method.invoke(Method.java:580)
    at junit@4.13.2/junit.framework.TestCase.runTest(TestCase.java:177)
    at junit@4.13.2/junit.framework.TestCase.runBare(TestCase.java:142)
    at junit@4.13.2/junit.framework.TestResult$1.protect(TestResult.java:122)
    at junit@4.13.2/junit.framework.TestResult.runProtected(TestResult.java:142)
    at junit@4.13.2/junit.framework.TestResult.run(TestResult.java:125)
    at junit@4.13.2/junit.framework.TestCase.run(TestCase.java:130)
    at junit@4.13.2/junit.framework.TestSuite.runTest(TestSuite.java:241)
    at junit@4.13.2/junit.framework.TestSuite.run(TestSuite.java:236)
    at org.eclipse.jdt.internal.junit.runner.junit3.JUnit3TestReference.run(JUnit3TestReference.java:128)
    at org.eclipse.jdt.internal.junit.runner.TestExecution.run(TestExecution.java:40)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:530)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.runTests(RemoteTestRunner.java:758)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.run(RemoteTestRunner.java:453)
    at org.eclipse.jdt.internal.junit.runner.RemoteTestRunner.main(RemoteTestRunner.java:211)
#253
gr8
Key Points for Implementation
Implementing the Differentiate Method
Here's an example implementation of `differentiate` for a function node, applying the chain rule f(g(x))' = f'(g(x)) · g'(x), to guide you:

@Override
public Node<D, R, F> differentiate(VariableNode<D, R, F> variable) {
    // Step 1: Differentiate the argument (g'(x)).
    Node<D, R, F> argDerivative = arg.differentiate(variable);
    // Step 2: Differentiate the function (f'(g(x))).
    Node<D, R, F> functionDerivative = differentiateFunction();
    // Step 3: Apply the chain rule: f'(g(x)) * g'(x).
    return new MultiplicationNode<>(expression, functionDerivative, argDerivative);
}

/**
 * Returns the node representing the derivative of the function.
 * This will vary based on whether the function is built-in or contextual.
 */
private Node<D, R, F> differentiateFunction() {
    // Check if the function is built-in or contextual.
    if (isBuiltin()) {
        return differentiateBuiltinFunction();
    } else if (contextual) {
        return differentiateContextualFunction();
    } else {
        throw new UnsupportedOperationException("Cannot differentiate function: " + functionName);
    }
}

/**
 * Handles differentiation for built-in functions.
 */
private Node<D, R, F> differentiateBuiltinFunction() {
    switch (functionName) {
        case "sin":
            return new FunctionNode<>("cos", arg, expression); // derivative of sin is cos
        case "cos":
            return new NegationNode<>(expression, new FunctionNode<>("sin", arg, expression)); // derivative of cos is -sin
        case "exp":
            return this; // derivative of exp is exp, so f'(g(x)) is this node itself
        // Add other built-in function derivatives
        default:
            throw new UnsupportedOperationException("Derivative not implemented for function: " + functionName);
    }
}

/**
 * Handles differentiation for contextual functions.
 */
private Node<D, R, F> differentiateContextualFunction() {
    // Add logic for differentiating contextual functions, potentially involving more advanced logic
    // depending on how contextual functions are defined and used in your framework.
    // As an example, you might retrieve the name of a derivative function from the mapping, if available:
    if (mapping != null && mapping.derivativeFunctionName != null) {
        return new FunctionNode<>(mapping.derivativeFunctionName, arg, expression);
    }
    // Alternatively, handle differentiation based on known properties of the function
    throw new UnsupportedOperationException("Contextual function differentiation not implemented: " + functionName);
}

Considerations
This implementation assumes that the classes, fields, and methods it uses (`MultiplicationNode`, `NegationNode`, `FunctionNode`, `isBuiltin()`, `functionName`, `contextual`, `mapping`) exist in your framework with these signatures.
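For an n-ary product (a product of n functions), the derivative follows this formula:
$$\frac{d}{dx}\prod_{i=1}^{n} f_i(x) = \sum_{i=1}^{n} f_i'(x) \prod_{j \ne i} f_j(x)$$
Example for Three Functions
For three functions u, v, and w, the derivative is:
$$(uvw)' = u'vw + uv'w + uvw'$$
Pattern
Each term in the sum is formed by differentiating exactly one factor and multiplying it by all the other factors, unchanged. The rule extends to any number of functions following this same pattern.

The derivative of a sum follows from the linearity of the derivative: you can differentiate each term separately and then sum the results:
$$\frac{d}{dx}\sum_{i=1}^{n} f_i(x) = \sum_{i=1}^{n} f_i'(x)$$
Pattern
The derivative operator can move inside the summation because the derivative of a sum is the sum of the derivatives, $(f+g)' = f' + g'$, and constant factors pass through it, $(cf)' = c\,f'$.
For example, if you have $f(x) = ax + bx^2 + cx^3$ with constants a, b, c, then $f'(x) = a + 2bx + 3cx^2$. This is much simpler than the product rule because addition is a linear operation.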
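Both patterns can be checked numerically at a single point. The sketch below is a hypothetical helper (`NaryProductRule` is not part of arb4j) that takes each factor's value f_i(x) and derivative f_i'(x) and forms every product-rule term by differentiating one factor and keeping the others; it never divides the full product by f_i, so it stays correct when a factor is zero.

```java
/**
 * Product and sum rules evaluated at a single point, given each factor's value
 * and derivative there. Hypothetical helper, not part of arb4j.
 */
public final class NaryProductRule {
  /** d/dx prod_i f_i = sum_i f_i' * prod_{j != i} f_j, without dividing by any f_i. */
  static double productDerivative(double[] values, double[] derivatives) {
    double sum = 0;
    for (int i = 0; i < values.length; i++) {
      double term = derivatives[i];          // differentiate factor i ...
      for (int j = 0; j < values.length; j++) {
        if (j != i) term *= values[j];       // ... and keep every other factor unchanged
      }
      sum += term;
    }
    return sum;
  }

  /** d/dx sum_i f_i = sum_i f_i' by linearity. */
  static double sumDerivative(double[] derivatives) {
    double sum = 0;
    for (double d : derivatives) sum += d;
    return sum;
  }

  public static void main(String[] args) {
    double x = 2;
    // u = x, v = x^2, w = x^3, so uvw = x^6 and (uvw)' = 6x^5 = 192 at x = 2
    double[] f  = { x, x * x, x * x * x };
    double[] df = { 1, 2 * x, 3 * x * x };
    System.out.println(productDerivative(f, df)); // 192.0
    System.out.println(sumDerivative(df));        // 1 + 4 + 12 = 17.0
  }
}
```

The double loop costs O(n²) multiplications; prefix and suffix products of the values would bring it down to O(n), still without any division.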
For n-ary products, integration has no general closed form comparable to the product rule. Using the multinomial theorem and repeated integration by parts, based on
$$\int u\,v'\,dx = u\,v - \int u'\,v\,dx,$$
the integral of a product of n functions can still be expanded systematically, but the result is a closed form only when the antiderivatives that appear along the way are themselves known. This is significantly more complex than differentiation.
{
  var x = RealFunction.parse("∂ln(x)/∂x");
  assertEquals("1/x", x.rootNode.toString());
}
#253

public void testArcSinDerivative()
{
  var f  = RealFunction.parse("∂arcsin(x)/∂x");
  var df = RealFunction.parse("1/√(1-x^2)");
  assertEquals(df.rootNode.toString(), f.rootNode.toString());
}
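The symbolic results asserted in these tests (1/x and 1/√(1-x²)) can be cross-checked numerically. The sketch below uses forward-mode automatic differentiation with dual numbers, which is a different technique from the tree-rewriting approach this issue describes; `DualCheck` and `Dual` are hypothetical names, not arb4j classes, and plain `double` stands in for arb4j's `Real`. Each elementary function applies the chain rule locally, so composed expressions get derivatives that are exact up to rounding.

```java
/**
 * Forward-mode automatic differentiation with dual numbers: a numerical
 * cross-check for symbolic derivatives. Hypothetical helper, not part of arb4j.
 */
public final class DualCheck {
  /** v + d·ε with ε² = 0, so d carries the derivative with respect to the seeded variable. */
  record Dual(double v, double d) {
    static Dual variable(double x) { return new Dual(x, 1); } // seed dx/dx = 1
    static Dual constant(double c) { return new Dual(c, 0); } // dc/dx = 0

    Dual add(Dual o) { return new Dual(v + o.v, d + o.d); }
    Dual sub(Dual o) { return new Dual(v - o.v, d - o.d); }
    Dual mul(Dual o) { return new Dual(v * o.v, d * o.v + v * o.d); }                  // product rule
    Dual div(Dual o) { return new Dual(v / o.v, (d * o.v - v * o.d) / (o.v * o.v)); } // quotient rule

    // Each elementary function applies the chain rule: f(g)' = f'(g) * g'
    Dual sin()    { return new Dual(Math.sin(v), Math.cos(v) * d); }
    Dual ln()     { return new Dual(Math.log(v), d / v); }
    Dual sqrt()   { double s = Math.sqrt(v); return new Dual(s, d / (2 * s)); }
    Dual arcsin() { return new Dual(Math.asin(v), d / Math.sqrt(1 - v * v)); }
  }

  public static void main(String[] args) {
    double x = 0.3;
    Dual dx = Dual.variable(x);
    System.out.println("d/dx ln(x)     = " + dx.ln().d() + ", 1/x = " + 1 / x);
    System.out.println("d/dx arcsin(x) = " + dx.arcsin().d() + ", 1/sqrt(1-x^2) = " + 1 / Math.sqrt(1 - x * x));
    // Composition: d/dx sin(x^2) = 2x cos(x^2)
    System.out.println("d/dx sin(x^2)  = " + dx.mul(dx).sin().d() + ", 2x cos(x^2) = " + 2 * x * Math.cos(x * x));
  }
}
```

Dual numbers only give values of derivatives at a point, so they complement rather than replace the symbolic `rootNode` comparisons above.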
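Add support for f'(t) and f''(t) syntax for the 1st and 2nd derivative, ... Let's write it purely in terms of $J_n(t)$. The Bessel differential equation in standard form is:
$$t^2 J_n''(t) + t\,J_n'(t) + (t^2 - n^2)\,J_n(t) = 0$$
To transform this into a Riccati equation, we use the substitution:
$$w(t) = \frac{J_n'(t)}{J_n(t)}$$
The transformation steps:
1. Differentiate the substitution: $w' = \frac{J_n''}{J_n} - w^2$, so $\frac{J_n''}{J_n} = w' + w^2$.
2. Divide the Bessel equation by $t^2 J_n(t)$: $\frac{J_n''}{J_n} + \frac{w}{t} + 1 - \frac{n^2}{t^2} = 0$.
3. Substitute step 1 into step 2: $w' + w^2 + \frac{w}{t} + 1 - \frac{n^2}{t^2} = 0$.

This shows that w(t), the logarithmic derivative of the Bessel function J_n(t), satisfies this first-order nonlinear differential equation.

I'll be damned, it makes sense now.

Yes, the logarithmic derivative $w(t) = J_n'(t)/J_n(t)$ of the Bessel function $J_n(t)$ satisfies the Riccati equation:
$$w'(t) = -w(t)^2 - \frac{w(t)}{t} + \frac{n^2}{t^2} - 1$$
This comes directly from the fact that $J_n(t)$ satisfies the Bessel differential equation:
$$t^2 J_n''(t) + t\,J_n'(t) + (t^2 - n^2)\,J_n(t) = 0$$
The transformation from one to the other is a standard technique for converting certain second-order linear differential equations into first-order nonlinear (Riccati) equations through the logarithmic-derivative substitution[1].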
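The Riccati form can be verified numerically. The sketch below is a hypothetical helper (not part of arb4j, and using `double` instead of `Real`): it sums the power series $J_n(t) = \sum_k (-1)^k (t/2)^{2k+n} / (k!\,(k+n)!)$ together with its term-by-term first and second derivatives, then evaluates the residual $w' + w^2 + w/t + 1 - n^2/t^2$, which should vanish up to rounding.

```java
/**
 * Numerically checks that w = J_n'/J_n satisfies the Riccati equation
 * w' + w^2 + w/t + 1 - n^2/t^2 = 0. Hypothetical helper, not part of arb4j.
 */
public final class BesselRiccatiCheck {
  /** {J_n(t), J_n'(t), J_n''(t)} from the power series, differentiated term by term. */
  static double[] besselJ(int n, double t) {
    double a = 1;                                // a_0 = 1 / (n! 2^n)
    for (int i = 1; i <= n; i++) a /= 2.0 * i;
    double j = 0, dj = 0, d2j = 0;
    for (int k = 0; k < 40; k++) {
      int p = 2 * k + n;                         // exponent of t in the k-th term
      j += a * Math.pow(t, p);
      if (p >= 1) dj += a * p * Math.pow(t, p - 1);
      if (p >= 2) d2j += a * p * (p - 1) * Math.pow(t, p - 2);
      a *= -1.0 / (4.0 * (k + 1) * (k + 1 + n)); // a_{k+1} / a_k
    }
    return new double[] { j, dj, d2j };
  }

  /** Residual of the Riccati equation at t; should be zero up to rounding. */
  static double riccatiResidual(int n, double t) {
    double[] J = besselJ(n, t);
    double w = J[1] / J[0];                      // logarithmic derivative
    double dw = J[2] / J[0] - w * w;             // w' = J''/J - w^2
    return dw + w * w + w / t + 1 - (double) (n * n) / (t * t);
  }

  public static void main(String[] args) {
    for (int n = 0; n <= 3; n++)
      System.out.printf("n=%d residual at t=1.5: %.2e%n", n, riccatiResidual(n, 1.5));
  }
}
```

The check point t must not be a zero of J_n, because w has a pole there; t = 1.5 lies below the first zero of J_0 (about 2.405) and of the higher orders.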
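\documentclass{article}
\title{The Logarithmic Derivative of Bessel Functions}
\begin{document}
The Bessel function $J_n(t)$ satisfies the Bessel differential equation
\[ t^2 J_n''(t) + t J_n'(t) + (t^2 - n^2) J_n(t) = 0. \]
Its logarithmic derivative is:
\[ w(t) = \frac{J_n'(t)}{J_n(t)}. \]
The logarithmic derivative satisfies a Riccati equation, as follows:
\begin{enumerate}
\item The derivative of this equation is:
  \[ w'(t) = \frac{J_n''(t)}{J_n(t)} - w(t)^2, \]
  so $J_n'(t) = w(t) J_n(t)$ and $J_n''(t) = \bigl(w'(t) + w(t)^2\bigr) J_n(t)$.
\item These expressions in the Bessel equation yield:
  \[ t^2 \bigl(w'(t) + w(t)^2\bigr) J_n(t) + t\, w(t) J_n(t) + (t^2 - n^2) J_n(t) = 0. \]
\item Division by $t^2 J_n(t)$ gives:
  \[ w'(t) + w(t)^2 + \frac{w(t)}{t} + 1 - \frac{n^2}{t^2} = 0. \]
\item The equation rearranges to:
  \[ w'(t) = -w(t)^2 - \frac{w(t)}{t} + \frac{n^2}{t^2} - 1. \]
\end{enumerate}
This Riccati equation has the standard form:
\[ w'(t) = q_0(t) + q_1(t)\, w(t) + q_2(t)\, w(t)^2, \]
where:
\[ q_0(t) = \frac{n^2}{t^2} - 1, \qquad q_1(t) = -\frac{1}{t}, \qquad q_2(t) = -1. \]
\end{document}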
Implementing an Automatic Differentiator
This issue tracks the implementation of an automatic differentiator for the expression compiler. The compiler parses expressions into binary trees, and this differentiator will apply differentiation rules recursively based on the tree structure.
Algorithm Outline
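Base Case: a constant differentiates to 0; the differentiation variable differentiates to 1, and any other variable is treated as a constant (derivative 0).
Differentiation Rules: at each operator node, apply the matching rule: sum, product, quotient, power, and the chain rule for function nodes.
Recursive Application: apply the rules to the child subtrees first, so that each node's derivative is built from the derivatives of its operands.
Construct New Tree: assemble the results into a new expression tree representing the derivative, leaving the original tree unchanged.
Simplification (Optional): fold constants and remove identity terms such as 0*x, a*1 and x^1, so that, for example, (0*x+a*1) reduces to a.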
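The outline above can be sketched end to end on a toy tree. The classes below (`ToyDifferentiator`, `Const`, `Var`, `Add`, `Mul`, `Pow`, `Sin`, `Cos`) are hypothetical stand-ins, not arb4j's `Node` hierarchy, and exponents are restricted to constants. `d` implements the base cases and rules recursively while building a new tree, and `simplify` is the optional clean-up pass; run on a*x + b*x^2 + c*x^3, the raw result contains terms like ((0*x)+(a*1)), much like the derivative shown in the stack trace above, which the simplifier reduces to a.

```java
import java.util.Map;

/**
 * A toy expression tree with a recursive differentiator and a light simplifier.
 * Hypothetical classes for illustration only; they are not arb4j's Node types.
 */
public final class ToyDifferentiator {
  interface Node {}
  record Const(double value) implements Node { public String toString() { return fmt(value); } }
  record Var(String name) implements Node { public String toString() { return name; } }
  record Add(Node l, Node r) implements Node { public String toString() { return "(" + l + "+" + r + ")"; } }
  record Mul(Node l, Node r) implements Node { public String toString() { return "(" + l + "*" + r + ")"; } }
  record Pow(Node base, double exponent) implements Node { public String toString() { return base + "^" + fmt(exponent); } }
  record Sin(Node arg) implements Node { public String toString() { return "sin(" + arg + ")"; } }
  record Cos(Node arg) implements Node { public String toString() { return "cos(" + arg + ")"; } }

  static String fmt(double v) { return v == Math.rint(v) ? Long.toString((long) v) : Double.toString(v); }

  /** Recursively differentiates e with respect to x, building a new tree. */
  static Node d(Node e, String x) {
    if (e instanceof Const) return new Const(0);                                // base case: constant
    if (e instanceof Var v) return new Const(v.name().equals(x) ? 1 : 0);       // base case: variable
    if (e instanceof Add a) return new Add(d(a.l(), x), d(a.r(), x));           // sum rule
    if (e instanceof Mul m)                                                     // product rule
      return new Add(new Mul(d(m.l(), x), m.r()), new Mul(m.l(), d(m.r(), x)));
    if (e instanceof Pow p)                                                     // power rule + chain rule
      return new Mul(new Mul(new Const(p.exponent()), new Pow(p.base(), p.exponent() - 1)), d(p.base(), x));
    if (e instanceof Sin s) return new Mul(new Cos(s.arg()), d(s.arg(), x));    // chain rule
    if (e instanceof Cos c) return new Mul(new Mul(new Const(-1), new Sin(c.arg())), d(c.arg(), x));
    throw new IllegalArgumentException("no differentiation rule for " + e);
  }

  static boolean isConst(Node n, double v) { return n instanceof Const c && c.value() == v; }

  /** Bottom-up simplification: folds constants and drops 0 and 1 identities. */
  static Node simplify(Node e) {
    if (e instanceof Add a) {
      Node l = simplify(a.l()), r = simplify(a.r());
      if (isConst(l, 0)) return r;
      if (isConst(r, 0)) return l;
      if (l instanceof Const p && r instanceof Const q) return new Const(p.value() + q.value());
      return new Add(l, r);
    }
    if (e instanceof Mul m) {
      Node l = simplify(m.l()), r = simplify(m.r());
      if (isConst(l, 0) || isConst(r, 0)) return new Const(0);
      if (isConst(l, 1)) return r;
      if (isConst(r, 1)) return l;
      if (l instanceof Const p && r instanceof Const q) return new Const(p.value() * q.value());
      return new Mul(l, r);
    }
    if (e instanceof Pow p) {
      Node b = simplify(p.base());
      if (p.exponent() == 0) return new Const(1);
      if (p.exponent() == 1) return b;
      return new Pow(b, p.exponent());
    }
    if (e instanceof Sin s) return new Sin(simplify(s.arg()));
    if (e instanceof Cos c) return new Cos(simplify(c.arg()));
    return e;
  }

  static double eval(Node e, Map<String, Double> env) {
    if (e instanceof Const k) return k.value();
    if (e instanceof Var v) return env.get(v.name());
    if (e instanceof Add a) return eval(a.l(), env) + eval(a.r(), env);
    if (e instanceof Mul m) return eval(m.l(), env) * eval(m.r(), env);
    if (e instanceof Pow p) return Math.pow(eval(p.base(), env), p.exponent());
    if (e instanceof Sin s) return Math.sin(eval(s.arg(), env));
    if (e instanceof Cos c) return Math.cos(eval(c.arg(), env));
    throw new IllegalArgumentException("cannot evaluate " + e);
  }

  /** a*x + b*x^2 + c*x^3, the expression from the test in this issue. */
  static Node cubic() {
    Node x = new Var("x");
    return new Add(new Add(new Mul(new Var("a"), x), new Mul(new Var("b"), new Pow(x, 2))),
                   new Mul(new Var("c"), new Pow(x, 3)));
  }

  public static void main(String[] args) {
    Node raw = d(cubic(), "x");
    Node df = simplify(raw);
    System.out.println("raw:        " + raw);
    System.out.println("simplified: " + df); // ((a+(b*(2*x)))+(c*(3*x^2)))
    System.out.println(eval(df, Map.of("a", 2.0, "b", 4.0, "c", 6.0, "x", 2.3))); // approximately 115.62
  }
}
```

Keeping `simplify` as a separate pass, rather than folding it into `d`, keeps each differentiation rule a direct transcription of the textbook formula.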
Stuff To Be Done And Whatnot