From 031f7f480054d3103d768a06e7cd9ad001a5e3e1 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Thu, 10 Feb 2022 00:37:24 +0100 Subject: [PATCH] Retrieve pullback only once for ReverseRuleConfigBackend (#51) * Retrieve pullback only once for ReverseRuleConfigBackend * Update Project.toml * `HasReverseMode` is not exported... * Fix CRC imports * Missing bracket --- Project.toml | 2 +- src/AbstractDifferentiation.jl | 2 +- src/ruleconfig.jl | 14 +++++--------- 3 files changed, 7 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index c544dc8..9cb36e3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "AbstractDifferentiation" uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" authors = ["Mohamed Tarek and contributors"] -version = "0.4.1" +version = "0.4.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 12aadb9..4143be0 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -1,7 +1,7 @@ module AbstractDifferentiation using LinearAlgebra, ExprTools, Requires, Compat -using ChainRulesCore: RuleConfig, rrule_via_ad +using ChainRulesCore: ChainRulesCore export AD diff --git a/src/ruleconfig.jl b/src/ruleconfig.jl index 8174d7a..1dcb2c1 100644 --- a/src/ruleconfig.jl +++ b/src/ruleconfig.jl @@ -3,17 +3,13 @@ AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package. """ -struct ReverseRuleConfigBackend{RC <: RuleConfig} <: AbstractReverseMode +struct ReverseRuleConfigBackend{RC<:ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}} <: AbstractReverseMode ruleconfig::RC end AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...) - return (vs) -> begin - _, back = rrule_via_ad(ab.ruleconfig, f, xs...) - if vs isa Tuple && length(vs) === 1 - return Base.tail(back(vs[1])) - else - return Base.tail(back(vs)) - end - end + _, back = ChainRulesCore.rrule_via_ad(ab.ruleconfig, f, xs...) + pullback(vs) = Base.tail(back(vs)) + pullback(vs::Tuple{Any}) = Base.tail(back(first(vs))) + return pullback end