`@non_differentiable` should use identical pullbacks when possible #678
Actually, I think I got this. I think the issue can be fixed without relying on macros too much:

```julia
using ChainRulesCore

function make_pullback_for_non_differentiable(::Val{N}) where {N}
    Vararg{Any,N}  # throw early for invalid `N`, must be nonnegative `Int`
    function pullback_for_non_differentiable(::Vararg{Any,N})
        f = _ -> NoTangent()
        ntuple(f, Val(N))
    end
end

using Test

@testset "`make_pullback_for_non_differentiable`" begin
    f = make_pullback_for_non_differentiable
    @testset "throws on invalid input" begin
        @test_throws Exception f(Val(0.0))
        @test_throws Exception f(Val(-1))
    end
    @testset "identical objects" begin
        for i ∈ 0:5
            v = Val(i)
            @test f(v) === f(v)
        end
    end
    @testset "dispatch" begin
        pullback = f(Val(2))
        @test_throws MethodError pullback()
        @test_throws MethodError pullback(1)
        @test (NoTangent(), NoTangent()) === pullback(1, 2)
        @test_throws MethodError pullback(1, 2, 3)
    end
end
```
AFAICT the pullback should return a tuple of length `N + 1`:

```julia
using ChainRulesCore

struct NonDiffPullback{N} end

(pb::NonDiffPullback{N})(::Vararg{Any,N}) where {N} = ntuple(Returns(NoTangent()), Val(N + 1))
```
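A minimal sketch of how a singleton pullback like this could be wired into hand-written `rrule`s, assuming ChainRulesCore's usual convention that a pullback receives a single output cotangent and returns one tangent for the function object plus one per argument (so `N + 1` in total). Unlike the snippet above, the pullback here takes one argument; that call signature is an assumption. `NoTangentPullback`, `is_small`, and `count_above` are hypothetical names, and `ntuple(_ -> NoTangent(), ...)` is used instead of `Returns`, which only exists from Julia 1.7 on.

```julia
using ChainRulesCore

# Hypothetical singleton pullback for a function with `N` primal arguments.
# It takes the single output cotangent and returns `N + 1` `NoTangent()`s:
# one for the function object itself plus one per argument.
struct NoTangentPullback{N} end

function (::NoTangentPullback{N})(_) where {N}
    return ntuple(_ -> NoTangent(), Val(N + 1))
end

# Two hypothetical non-differentiable functions of two arguments each.
is_small(x, tol) = abs(x) < tol
count_above(xs, t) = count(>(t), xs)

# Both rrules return the same pullback object, because it depends only on arity.
function ChainRulesCore.rrule(::typeof(is_small), x, tol)
    return is_small(x, tol), NoTangentPullback{2}()
end

function ChainRulesCore.rrule(::typeof(count_above), xs, t)
    return count_above(xs, t), NoTangentPullback{2}()
end
```

Because `NoTangentPullback{2}` has no fields, every instance is the same object, so callers see one pullback type per arity rather than one per `rrule` method.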
Thanks, you're right, there's an off-by-one error. But note we can't use …
We can use …
* make `@non_differentiable` use identical pullbacks when possible

  Fixes #678

* simpler
* bump version
The pullbacks returned by `@non_differentiable`-generated `rrule`s would ideally be identical for the same type signature. Presumably this could help compilation latency and type stability in user code.

Test:
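One possible reading of the type-stability point, sketched under that assumption: when a caller may receive the pullback of either of two `rrule`s, per-method closures yield a `Union` of distinct anonymous-function types, whereas a shared singleton yields one concrete type. This is an illustration, not the code `@non_differentiable` actually generates; `SharedPullback`, `is_positive`, `is_whole`, and the `closure_rrule_*`, `shared_rrule_*`, and `pick_*` helpers are hypothetical, and `Base.return_types` is used only to inspect what inference reports.

```julia
using ChainRulesCore

# Hypothetical shared pullback keyed only on the number of primal arguments.
struct SharedPullback{N} end
(::SharedPullback{N})(_) where {N} = ntuple(_ -> NoTangent(), Val(N + 1))

# Two hypothetical non-differentiable functions of one argument.
is_positive(x) = x > 0
is_whole(x) = isinteger(x)

# Closure style: each anonymous function literal gets its own type.
closure_rrule_positive(x) = (is_positive(x), _ -> (NoTangent(), NoTangent()))
closure_rrule_whole(x) = (is_whole(x), _ -> (NoTangent(), NoTangent()))

# Shared style: both return the same singleton type.
shared_rrule_positive(x) = (is_positive(x), SharedPullback{1}())
shared_rrule_whole(x) = (is_whole(x), SharedPullback{1}())

# A caller that picks one of the two pullbacks at runtime.
pick_closure(flag, x) = flag ? closure_rrule_positive(x)[2] : closure_rrule_whole(x)[2]
pick_shared(flag, x) = flag ? shared_rrule_positive(x)[2] : shared_rrule_whole(x)[2]

# Ask inference for the return type of each caller.
closure_T = only(Base.return_types(pick_closure, (Bool, Float64)))
shared_T = only(Base.return_types(pick_shared, (Bool, Float64)))
println(closure_T)  # a Union of two anonymous-function types
println(shared_T)   # SharedPullback{1}
```

With the shared singleton, downstream code that stores or calls the pullback can be compiled once for `SharedPullback{1}` instead of for each closure type.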
Failure:
I'm not good with macros so I probably won't be tackling this.