Skip to content

Commit

Permalink
Retrieve pullback only once for ReverseRuleConfigBackend (#51)
Browse files Browse the repository at this point in the history
* Retrieve pullback only once for ReverseRuleConfigBackend

* Update Project.toml

* `HasReverseMode` is not exported...

* Fix CRC imports

* Missing bracket
  • Loading branch information
devmotion authored Feb 9, 2022
1 parent 8f0d6db commit 031f7f4
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "AbstractDifferentiation"
uuid = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d"
authors = ["Mohamed Tarek <mohamed82008@gmail.com> and contributors"]
version = "0.4.1"
version = "0.4.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
2 changes: 1 addition & 1 deletion src/AbstractDifferentiation.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module AbstractDifferentiation

using LinearAlgebra, ExprTools, Requires, Compat
using ChainRulesCore: RuleConfig, rrule_via_ad
using ChainRulesCore: ChainRulesCore

export AD

Expand Down
14 changes: 5 additions & 9 deletions src/ruleconfig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,13 @@
AD backend that uses reverse mode with any ChainRules-compatible reverse-mode AD package.
"""
struct ReverseRuleConfigBackend{RC <: RuleConfig} <: AbstractReverseMode
struct ReverseRuleConfigBackend{RC<:ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode}} <: AbstractReverseMode
ruleconfig::RC
end

AD.@primitive function pullback_function(ab::ReverseRuleConfigBackend, f, xs...)
return (vs) -> begin
_, back = rrule_via_ad(ab.ruleconfig, f, xs...)
if vs isa Tuple && length(vs) === 1
return Base.tail(back(vs[1]))
else
return Base.tail(back(vs))
end
end
_, back = ChainRulesCore.rrule_via_ad(ab.ruleconfig, f, xs...)
pullback(vs) = Base.tail(back(vs))
pullback(vs::Tuple{Any}) = Base.tail(back(first(vs)))
return pullback
end

2 comments on commit 031f7f4

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/54307

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.4.2 -m "<description of version>" 031f7f480054d3103d768a06e7cd9ad001a5e3e1
git push origin v0.4.2

Please sign in to comment.