signature doesn't work for callable structs #16
Comments
PR would be appreciated. |
What is the reason for the assertion that is currently in the code? It seems like it doesn't really matter what the first item in |
I basically just copied and overwrote |
Yeah, this is true. Off-topic: what are you using this for? |
I am using So I use Here's the code snippet, it's relatively straightforward, I think.
"""
    generate_derivative(f, dvar::Symbol)
Automatically generates an analytical partial derivative of `f` w.r.t `dvar` using ModelingToolkit/Symbolics.jl.
To avoid symbolic tracing issues, the function should 1) be pure (no side effects or non-mathematical behavior) and 2) avoid
indeterminate control flow such as if-else or while blocks (technically should work but sometimes doesn't...). Additional
argument names are extracted automatically from the method signature of `f`. Keyword arg `choosefn` should be a function
which selects from available methods of `f` (returned by `methods`); defaults to `first`.
"""
function generate_derivative(f, dvar::Symbol; choosefn=first, contextmodule=CryoGrid)
    # Parse function parameter names using ExprTools
    fms = ExprTools.methods(f)
    symbol(arg::Symbol) = arg
    symbol(expr::Expr) = expr.args[1]
    argnames = map(symbol, ExprTools.signature(choosefn(fms))[:args])
    @assert dvar in argnames "function must have $dvar as an argument"
    dind = findfirst(s -> s == dvar, argnames)
    # Convert to MTK symbols
    argsyms = map(s -> Num(Sym{Real}(s)), argnames)
    # Generate analytical derivative of f
    x = argsyms[dind]
    ∂x = Differential(x)
    ∇f_expr = build_function(∂x(f(argsyms...)) |> expand_derivatives, argsyms...)
    ∇f = @RuntimeGeneratedFunction(∇f_expr)
    return ∇f
end
|
Neat. And way less evil than everything else I can think of doing with this. |
signature fails due to an assertion error for callable structs. Simple example:
ERROR: AssertionError: slot_syms[1] === Symbol("#self#")
This appears to be due to an erroneous assertion in argument_names:
@assert slot_syms[1] === Symbol("#self#")
For callable structs, the first slot is the name of the struct argument.
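
The difference in first-slot names can be sketched as follows. This is a minimal illustration, not the package's actual implementation: `Foo` and `arg_names_without_self` are hypothetical names made up for the example, and `Base.method_argnames` is an unexported Julia internal whose behavior may differ between Julia versions.

```julia
# Hypothetical callable struct, used only for illustration.
struct Foo end
(foo::Foo)(x) = x

# An ordinary function for comparison.
plain(x) = x

# Base.method_argnames is an unexported internal; it is assumed here to return
# the slot names for all arguments, including the leading "self" slot.
plain_names = Base.method_argnames(first(methods(plain)))
functor_names = Base.method_argnames(first(methods(Foo())))

# For an ordinary function the first slot is Symbol("#self#");
# for a callable struct it is the name given to the struct argument.
println(plain_names[1])
println(functor_names[1])

# Hypothetical helper: drop the first slot by position instead of asserting its name.
arg_names_without_self(m::Method) = Base.method_argnames(m)[2:end]

println(arg_names_without_self(first(methods(plain))))
println(arg_names_without_self(first(methods(Foo()))))
```

Dropping the first slot by position rather than by name is one possible way to treat both cases uniformly; whether that is the right fix for argument_names is for the maintainers to decide.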