signature doesn't work for callable structs #16

Closed
bgroenks96 opened this issue Apr 8, 2021 · 6 comments · Fixed by #21

@bgroenks96
Contributor

bgroenks96 commented Apr 8, 2021

`signature` fails with an assertion error for callable structs. A simple example:

```julia
using ExprTools

struct MyCallableStruct
    val::Int
end
(s::MyCallableStruct)(a::Int,b::Int) = s.val*(a+b)
# get method and invoke signature (the struct needs a `val` to be constructed)
methods(MyCallableStruct(1)) |> first |> signature
```

```
ERROR: AssertionError: slot_syms[1] === Symbol("#self#")
```

This appears to be due to an erroneous assertion in `argument_names`: `@assert slot_syms[1] === Symbol("#self#")`

For callable structs, the first slot is not `#self#` but the name of the struct argument (`s` in the example above).
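The difference can be seen directly. The sketch below uses `Base.method_argnames`, an internal, unexported Base helper (not part of ExprTools, and subject to change between Julia versions), to list the slot names of an ordinary method and of a callable-struct method. `plain_add` is a hypothetical function added only for comparison.

```julia
# Compare the leading slot name of an ordinary method and a callable-struct method.
# Assumption: Base.method_argnames is internal Julia API and may change between versions.

struct MyCallableStruct
    val::Int
end
(s::MyCallableStruct)(a::Int, b::Int) = s.val * (a + b)

# Hypothetical ordinary function, used only for comparison
plain_add(a::Int, b::Int) = a + b

plain_names = Base.method_argnames(first(methods(plain_add)))
callable_names = Base.method_argnames(first(methods(MyCallableStruct(2))))

# For the ordinary function the first slot is `#self#`;
# for the callable struct it is the struct argument `s`.
@show plain_names
@show callable_names
```

The positional argument names (`a`, `b`) are identical in both cases; only the first slot differs, which is exactly what the assertion trips over.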

@oxinabox
Member

oxinabox commented Jul 2, 2021

A PR would be appreciated.
I think the new method of `signature` that accepts a type-tuple should work for callable structs.
It doesn't try to work out argument names, so it doesn't need to reason about `slot_syms`.
But we should be able to reason about `slot_syms` in a way that works for callable structs.

@bgroenks96
Contributor Author

What is the reason for the assertion that is currently in the code? It seems like it doesn't really matter what the first item in slot_syms is because the rest of the code ignores it.

@bgroenks96
Contributor Author

I basically just copied and overwrote ExprTools.argument_names and deleted the assertion line. It's been working for me ever since.

@oxinabox
Member

oxinabox commented Jul 2, 2021

> What is the reason for the assertion that is currently in the code? It seems like it doesn't really matter what the first item in slot_syms is because the rest of the code ignores it.

Yeah, this is true.
It is there because this is not part of Julia's public API, so I wanted to make sure it was what I thought it was and keep myself oriented.
It can be removed and replaced with a comment saying that the first name is the name of the function object itself.
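Following that suggestion, a minimal sketch of argument-name extraction without the assertion might look like this. `argument_names_sketch` is a hypothetical name, not ExprTools' actual function, and it reads names through the internal `Base.method_argnames` rather than whatever ExprTools uses internally.

```julia
# Hypothetical replacement for the asserting version; not ExprTools' real code.
function argument_names_sketch(m::Method)
    slot_syms = Base.method_argnames(m)  # internal Base API (assumption: available in your Julia version)
    # The first slot always names the function object itself: `#self#` for an
    # ordinary function, or the argument name (e.g. `s`) for a callable struct.
    # It is not a positional argument, so skip it instead of asserting its name.
    return slot_syms[2:end]
end

struct MyCallableStruct
    val::Int
end
(s::MyCallableStruct)(a::Int, b::Int) = s.val * (a + b)

# Hypothetical ordinary function, for comparison
plain_add(a::Int, b::Int) = a + b

@show argument_names_sketch(first(methods(MyCallableStruct(1))))
@show argument_names_sketch(first(methods(plain_add)))
```

Skipping the first slot unconditionally is what lets one code path handle both ordinary functions and callable structs.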


Off-topic: what are you using this for?
I made it so I could do invenia/Nabla.jl#189
I am currently preparing a JuliaCon talk about it, so I would be interested to have a second real-world example.

@bgroenks96
Contributor Author

I am using Symbolics.jl to generate analytical derivatives of functions, but I wanted to provide a clean API that would simply take a function and an argument name and then generate the derivative w.r.t. that argument. Doing this in Symbolics.jl is of course not super hard, but it does require a bit of work (as well as the use of RuntimeGeneratedFunctions).

So I use ExprTools to extract the user's method signature and select the appropriate argument.

Here's the code snippet; it's relatively straightforward, I think.

```julia
using ExprTools
using Symbolics                  # Num, Differential, expand_derivatives, build_function
using SymbolicUtils: Sym         # Sym{Real}(name), as in the SymbolicUtils API of the time
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)  # required before using @RuntimeGeneratedFunction

"""
    generate_derivative(f, dvar::Symbol)

Automatically generates an analytical partial derivative of `f` w.r.t. `dvar` using ModelingToolkit/Symbolics.jl.
To avoid symbolic tracing issues, the function should 1) be pure (no side effects or non-mathematical behavior) and 2) avoid
indeterminate control flow such as if-else or while blocks (technically should work but sometimes doesn't...). Additional
argument names are extracted automatically from the method signature of `f`. Keyword arg `choosefn` should be a function
which selects from available methods of `f` (returned by `methods`); defaults to `first`.
"""
function generate_derivative(f, dvar::Symbol; choosefn=first, contextmodule=CryoGrid)
    # Collect the methods of f (Base.methods; ExprTools does not define its own)
    fms = methods(f)
    # Argument entries are either bare symbols (`a`) or typed expressions (`a::Int`)
    symbol(arg::Symbol) = arg
    symbol(expr::Expr) = expr.args[1]
    argnames = map(symbol, ExprTools.signature(choosefn(fms))[:args])
    @assert dvar in argnames "function must have $dvar as an argument"
    dind = findfirst(s -> s == dvar, argnames)
    # Convert to MTK symbols
    argsyms = map(s -> Num(Sym{Real}(s)), argnames)
    # Generate analytical derivative of f
    x = argsyms[dind]
    ∂x = Differential(x)
    ∇f_expr = build_function(∂x(f(argsyms...)) |> expand_derivatives, argsyms...)
    ∇f = @RuntimeGeneratedFunction(∇f_expr)
    return ∇f
end
```

@oxinabox
Member

oxinabox commented Jul 2, 2021

Neat. And way less evil than everything else I can think of doing with this.

bgroenks96 added a commit to bgroenks96/ExprTools.jl that referenced this issue Jul 4, 2021
2 participants