@non_differentiable should use identical pullbacks when possible #678

Closed
nsajko opened this issue May 30, 2024 · 4 comments · Fixed by #679

Comments

@nsajko
Contributor

nsajko commented May 30, 2024

The pullbacks returned by an `rrule` generated with `@non_differentiable` would ideally be identical (`===`) for the same type signature. Presumably this could reduce compilation latency and help type stability in user code.

Test:

```julia
f(x) = rand()*x*0.1
g(x) = rand()*x*0.2
using ChainRulesCore
@non_differentiable f(::Any)
@non_differentiable g(::Any)
using Test
@test last(rrule(f, 0.3)) === last(rrule(g, 0.4))
```

Failure:

```
julia> @test last(rrule(f, 0.3)) === last(rrule(g, 0.4))
Test Failed at REPL[7]:1
  Expression: last(rrule(f, 0.3)) === last(rrule(g, 0.4))
   Evaluated: var"#f_pullback#2"() === var"#g_pullback#4"()

ERROR: There was an error during testing
```
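The `Evaluated:` line shows why the test fails: the two pullbacks are instances of distinct generated types (`var"#f_pullback#2"` and `var"#g_pullback#4"`), and in Julia every local or anonymous function gets its own type, even when the bodies are identical. A minimal plain-Julia sketch of the difference, with `nothing` standing in for `NoTangent()` and a hypothetical `SharedPullback` type (not part of ChainRulesCore):

```julia
# Stand-ins for per-function pullbacks: each anonymous function
# gets its own type, even though the bodies are identical.
f_pullback = _ -> (nothing, nothing)
g_pullback = _ -> (nothing, nothing)
println(f_pullback === g_pullback)  # false

# Hypothetical singleton callable struct, parameterized only by the arity N:
# there is exactly one instance per N, so rules with the same arity can share it.
struct SharedPullback{N} end
(::SharedPullback{N})(_) where {N} = ntuple(_ -> nothing, Val(N + 1))

println(SharedPullback{1}() === SharedPullback{1}())  # true
println(SharedPullback{1}() === SharedPullback{2}())  # false
```

Because such a type depends only on the arity, rules for functions with the same signature shape would return the very same object, which is the property the test checks.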

I'm not good with macros so I probably won't be tackling this.

@nsajko
Contributor Author

nsajko commented May 30, 2024

> I'm not good with macros so I probably won't be tackling this.

Actually, I think I've got this. The issue can probably be fixed without relying much on macros, by using a helper function `make_pullback_for_non_differentiable`:

```julia
using ChainRulesCore

function make_pullback_for_non_differentiable(::Val{N}) where {N}
    Vararg{Any,N}  # throw early for invalid `N`, must be nonnegative `Int`
    function pullback_for_non_differentiable(::Vararg{Any,N})
        f = _ -> NoTangent()
        ntuple(f, Val(N))
    end
end

using Test

@testset "`make_pullback_for_non_differentiable`" begin
    f = make_pullback_for_non_differentiable
    @testset "throws on invalid input" begin
        @test_throws Exception f(Val(0.0))
        @test_throws Exception f(Val(-1))
    end
    @testset "identical objects" begin
        for i ∈ 0:5
            v = Val(i)
            @test f(v) === f(v)
        end
    end
    @testset "dispatch" begin
        pullback = f(Val(2))
        @test_throws MethodError pullback()
        @test_throws MethodError pullback(1)
        @test (NoTangent(), NoTangent()) === pullback(1, 2)
        @test_throws MethodError pullback(1, 2, 3)
    end
end
```

@devmotion
Member

AFAICT the pullback should return a tuple of length N + 1, since the first element is the tangent for the function itself. Alternatively, with a callable struct:

```julia
using ChainRulesCore

struct NonDiffPullback{N} end
(pb::NonDiffPullback{N})(::Vararg{Any,N}) where {N} = ntuple(Returns(NoTangent()), Val(N + 1))
```
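To illustrate how such a struct would make the test from the issue pass, here is a sketch with hand-written rules standing in for what `@non_differentiable f(::Any)` could expand to (this is an assumption, not the macro's actual expansion). It assumes Julia ≥ 1.7 for `Base.Returns`; `N = 1` because `f` and `g` each take one argument, and the pullback is called with a single cotangent.

```julia
using ChainRulesCore

# The callable struct from the comment above, repeated so this block runs on its own.
struct NonDiffPullback{N} end
(pb::NonDiffPullback{N})(::Vararg{Any,N}) where {N} = ntuple(Returns(NoTangent()), Val(N + 1))

f(x) = rand() * x * 0.1
g(x) = rand() * x * 0.2

# Hand-written stand-ins for macro-generated rules (assumption, for illustration only):
# both return the same singleton pullback because both have one argument.
ChainRulesCore.rrule(::typeof(f), x) = (f(x), NonDiffPullback{1}())
ChainRulesCore.rrule(::typeof(g), x) = (g(x), NonDiffPullback{1}())

pb_f = last(rrule(f, 0.3))
pb_g = last(rrule(g, 0.4))
println(pb_f === pb_g)  # true

# Calling the pullback with one cotangent yields (NoTangent(), NoTangent()):
# one entry for the function itself and one for `x`.
pb_f(1.0)
```

Since `NonDiffPullback{1}` has no fields, every rule for a one-argument non-differentiable function hands back the same object.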

@nsajko
Contributor Author

nsajko commented May 30, 2024

Thanks, you're right, there's an off-by-one error. But note that we can't use `Returns` until support for Julia v1.6 is dropped.
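For reference, a sketch of a `Returns`-free variant of the callable struct that should also run on Julia v1.6 without Compat: an anonymous function stands in for `Returns(NoTangent())`. The name `NonDiffPullbackV16` is hypothetical.

```julia
using ChainRulesCore

# Hypothetical Returns-free variant: `_ -> NoTangent()` replaces Returns(NoTangent()).
struct NonDiffPullbackV16{N} end
(::NonDiffPullbackV16{N})(::Vararg{Any,N}) where {N} = ntuple(_ -> NoTangent(), Val(N + 1))

pb = NonDiffPullbackV16{2}()
println(pb === NonDiffPullbackV16{2}())  # true

# Two arguments in, N + 1 = 3 tangents out.
pb(1, 2)
```

The anonymous function only builds the returned tuple; the struct instance itself is still a field-less singleton, so identity across rules is unaffected.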

@devmotion
Member

We can use `Returns`: the package already depends on Compat >= 3.40, and it seems support for `Returns` was added in Compat 3.35.

devmotion pushed a commit that referenced this issue May 31, 2024
* make `@non_differentiable` use identical pullbacks when possible

Fixes #678

* simpler

* bump version