Restructure is not type stable but could be made stable? #177

Closed
Red-Portal opened this issue Jun 24, 2024 · 4 comments

@Red-Portal

Red-Portal commented Jun 24, 2024

Hi, it seems that calling restructure is not type stable by default. This is currently causing issues with Enzyme.jl (see this issue). Here is an MWE to illustrate the point:

using Cthulhu, LinearAlgebra, Optimisers, Functors

struct Model{A,B}
    a::A
    b::B
end

Functors.@functor Model

m = Model(randn(10), LowerTriangular(Matrix(I, 10, 10)))

params, re = Optimisers.destructure(m)

@code_warntype re(params)

This returns:

MethodInstance for (::Optimisers.Restructure{Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}, @NamedTuple{a::Int64, b::Tuple{}}})(::Vector{Float64})
  from (re::Optimisers.Restructure)(flat::AbstractVector) @ Optimisers ~/.julia/packages/Optimisers/yDIWk/src/destructure.jl:59
Arguments
  re::Optimisers.Restructure{Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}, @NamedTuple{a::Int64, b::Tuple{}}}
  flat::Vector{Float64}
Body::Model
1 ─ %1 = Base.getproperty(re, :model)::Model{Vector{Float64}, LowerTriangular{Bool, Matrix{Bool}}}
│   %2 = Base.getproperty(re, :offsets)::@NamedTuple{a::Int64, b::Tuple{}}
│   %3 = Base.getproperty(re, :length)::Int64
│   %4 = Optimisers._rebuild(%1, %2, flat, %3)::Model
└──      return %4

where the inferred return type, Model, is abstract: its type parameters are not inferred.

This can be solved in a brute-force manner by declaring the return type on the call method of Restructure (defined here):

   (re::Restructure)(flat::AbstractVector)::typeof(re.model) = _rebuild(re.model, re.offsets, flat, re.length)

where we are informing the compiler that the return type will be the same as re.model. I think this is safe to assume, and this immediately resolves the instability. Any thoughts on this?
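For readers unfamiliar with return-type declarations, here is a minimal sketch in plain Julia of what such an annotation does to inference. It does not use Optimisers at all: Holder, rebuild_plain and rebuild_declared are hypothetical names, and the untyped Dict merely stands in for whatever makes the real method body hard to infer.

```julia
# Hypothetical container: `store` is untyped, so lookups infer as `Any`.
struct Holder{T}
    template::T
    store::Dict{Symbol,Any}
end

# Without a declaration, inference only knows the result is `Any`.
rebuild_plain(h::Holder) = h.store[:value]

# With a declaration, Julia converts the result to typeof(h.template)
# and asserts that type before returning.
rebuild_declared(h::Holder)::typeof(h.template) = h.store[:value]

h = Holder([1.0, 2.0], Dict{Symbol,Any}(:value => [3.0, 4.0]))

# Ask the compiler what it infers for each method.
plain_type = only(Base.return_types(rebuild_plain, (typeof(h),)))
declared_type = only(Base.return_types(rebuild_declared, (typeof(h),)))

println("plain: ", plain_type, ", declared: ", declared_type)
```

The declaration narrows what callers see, but the body of the method is still inferred as before; the conversion and assertion happen at run time.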

@ToucheSir
Member

Can you elaborate on why you want re(params) to be type stable? If it's to ensure that subsequent code is type stable, then re(params)::typeof(m) in user code might work better. If it's to ensure that Restructure is internally type stable, a type assertion won't be enough.
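A minimal sketch of the user-side assertion, assuming the caller knows the concrete model type ahead of time. The helper name rebuild is hypothetical, and a NamedTuple stands in for the model to keep the example short.

```julia
using Optimisers

# Hypothetical helper: rebuild the model, then assert the concrete type
# so that code after this call can be inferred.
rebuild(re, params, ::Type{M}) where {M} = re(params)::M

# A NamedTuple of arrays is used as a stand-in model.
m = (a = randn(3), b = randn(2))
params, re = Optimisers.destructure(m)

m2 = rebuild(re, params, typeof(m))
println(typeof(m2) == typeof(m))
```

The assertion throws a TypeError if the rebuilt model has a different type, so it only fits code paths where the model type is expected to stay the same.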

@Red-Portal
Author

Red-Portal commented Jun 28, 2024

Hi @ToucheSir! The first reason is that Enzyme fails due to the instability. The second is that I find it unusual for this to be type unstable. Restructure is fully aware of the type we expect, so I am actually surprised that the current implementation is not type stable out of the box.

@ToucheSir
Member

Restructure is fully aware of the type we expect

Not quite. You are allowed to pass an array of a different eltype to Restructure, which might give you a different model return type.

I am actually surprised that the current implementation is not type stable out of the box.

...but even if you weren't, the type instability would remain. The reason is that Functors uses an untyped IdDict internally to keep track of shared/aliased parameters: https://github.com/FluxML/Functors.jl/blob/2eddcb74f9589e61b847362c2a91d91bd90ef628/src/walks.jl#L170-L201

The fix here would be to tell fmap not to cache by passing cache = nothing. This isn't done by default because we can't assume people are using models with zero parameter sharing/tying, but perhaps the models you're working with can. Disabling the cache would be necessary but may not be sufficient to remove type instabilities, as the Julia compiler really hates nested recursive calls like Functors.jl uses.
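As a rough illustration of what the cache does, the sketch below calls fmap directly on a small NamedTuple with one shared array. It assumes a Functors.jl version whose fmap accepts the cache keyword, and it does not show how to thread that keyword through Restructure, which would need changes inside Optimisers.

```julia
using Functors

# Two fields that alias the same array, i.e. a tied parameter.
shared = [1.0, 2.0]
m = (a = shared, b = shared)

# Default: the cache remembers the first result, so the tie is preserved.
cached = fmap(x -> 2 .* x, m)

# cache = nothing: every leaf is mapped independently and the tie is lost.
uncached = fmap(x -> 2 .* x, m; cache = nothing)

println(cached.a === cached.b)
println(uncached.a === uncached.b)
```

This is why disabling the cache is only safe for models without shared parameters: the values still match afterwards, but the two fields no longer refer to the same array.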

P.S. since you're working with Enzyme and not Zygote, you may be interested in adapting the mutating (de|Re)structure implementation in #165 for your use case.

@Red-Portal
Author

Thanks for the tips. Let me try to explicitly set the return type of the reconstruct expression as you suggested.
