Skip to content

Commit

Permalink
chore: get rid of implicit imports and clarify extension imports (#649)
Browse files Browse the repository at this point in the history
* Remove imports from DI in extensions

* Add DI prefix everywhere

* Unwrap

* Typos

* Typos

* Context

* Inner outer

* Typos

* Basis

* Explicit imports from DI in DIT

* Typos

* Gradient and hvp

* Typos

* Typo

* Typos

* Typos

* Typos

* Remove implicit imports in DIT, add tests

* Relu

* Typo

* DIT

* Retoggle test on 1.11

* Not broken

* Public tests on 1.11

* Bump
  • Loading branch information
gdalle authored Dec 2, 2024
1 parent 8fe1dd1 commit d53f8ac
Show file tree
Hide file tree
Showing 68 changed files with 1,116 additions and 1,011 deletions.
5 changes: 3 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
Expand All @@ -32,10 +33,10 @@ DifferentiationInterfaceEnzymeExt = "Enzyme"
DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfaceForwardDiffExt = ["ForwardDiff", "DiffResults"]
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceReverseDiffExt = ["ReverseDiff", "DiffResults"]
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
DifferentiationInterfaceStaticArraysExt = "StaticArrays"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@ using ChainRulesCore:
frule_via_ad,
rrule_via_ad
import DifferentiationInterface as DI
using DifferentiationInterface:
Constant, DifferentiateWith, NoPullbackPrep, NoPushforwardPrep, PullbackPrep, unwrap

ruleconfig(backend::AutoChainRules) = backend.ruleconfig

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function ChainRulesCore.rrule(dw::DifferentiateWith, x)
function ChainRulesCore.rrule(dw::DI.DifferentiateWith, x)
(; f, backend) = dw
y = f(x)
prep_same = DI.prepare_pullback_same_point(f, backend, x, (y,))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,39 +1,39 @@
## Pullback

struct ChainRulesPullbackPrepSamePoint{Y,PB} <: PullbackPrep
struct ChainRulesPullbackPrepSamePoint{Y,PB} <: DI.PullbackPrep
y::Y
pb::PB
end

function DI.prepare_pullback(
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{Constant,C}
f, ::AutoReverseChainRules, x, ty::NTuple, contexts::Vararg{DI.Constant,C}
) where {C}
return NoPullbackPrep()
return DI.NoPullbackPrep()
end

function DI.prepare_pullback_same_point(
f,
::NoPullbackPrep,
::DI.NoPullbackPrep,
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{Constant,C},
contexts::Vararg{DI.Constant,C},
) where {C}
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x, map(unwrap, contexts)...)
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
return ChainRulesPullbackPrepSamePoint(y, pb)
end

function DI.value_and_pullback(
f,
::NoPullbackPrep,
::DI.NoPullbackPrep,
backend::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{Constant,C},
contexts::Vararg{DI.Constant,C},
) where {C}
rc = ruleconfig(backend)
y, pb = rrule_via_ad(rc, f, x, map(unwrap, contexts)...)
y, pb = rrule_via_ad(rc, f, x, map(DI.unwrap, contexts)...)
tx = map(ty) do dy
pb(dy)[2]
end
Expand All @@ -46,7 +46,7 @@ function DI.value_and_pullback(
::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{Constant,C},
contexts::Vararg{DI.Constant,C},
) where {C}
(; y, pb) = prep
tx = map(ty) do dy
Expand All @@ -61,7 +61,7 @@ function DI.pullback(
::AutoReverseChainRules,
x,
ty::NTuple,
contexts::Vararg{Constant,C},
contexts::Vararg{DI.Constant,C},
) where {C}
(; pb) = prep
tx = map(ty) do dy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ module DifferentiationInterfaceDiffractorExt

using ADTypes: ADTypes, AutoDiffractor
import DifferentiationInterface as DI
using DifferentiationInterface: NoPushforwardPrep
using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆

DI.check_available(::AutoDiffractor) = true
Expand All @@ -11,9 +10,9 @@ DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()

## Pushforward

DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = NoPushforwardPrep()
DI.prepare_pushforward(f, ::AutoDiffractor, x, tx::NTuple) = DI.NoPushforwardPrep()

function DI.pushforward(f, ::NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
function DI.pushforward(f, ::DI.NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
ty = map(tx) do dx
# code copied from Diffractor.jl
z = ∂☆{1}()(ZeroBundle{1}(f), bundle(x, dx))
Expand All @@ -24,7 +23,7 @@ function DI.pushforward(f, ::NoPushforwardPrep, ::AutoDiffractor, x, tx::NTuple)
end

function DI.value_and_pushforward(
f, prep::NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
f, prep::DI.NoPushforwardPrep, backend::AutoDiffractor, x, tx::NTuple
)
return f(x), DI.pushforward(f, prep, backend, x, tx)
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,6 @@ module DifferentiationInterfaceEnzymeExt
using ADTypes: ADTypes, AutoEnzyme
using Base: Fix1
import DifferentiationInterface as DI
using DifferentiationInterface:
Context,
DerivativePrep,
GradientPrep,
JacobianPrep,
HVPPrep,
PullbackPrep,
PushforwardPrep,
NoDerivativePrep,
NoGradientPrep,
NoHVPPrep,
NoJacobianPrep,
NoPullbackPrep,
NoPushforwardPrep
using Enzyme:
Active,
Annotation,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ function DI.prepare_pushforward(
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
return NoPushforwardPrep()
return DI.NoPushforwardPrep()
end

function DI.value_and_pushforward(
f::F,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
dx_sametype = convert(typeof(x), only(tx))
Expand All @@ -29,11 +29,11 @@ end

function DI.value_and_pushforward(
f::F,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f_and_df = get_f_and_df(f, backend, Val(B))
tx_sametype = map(Fix1(convert, typeof(x)), tx)
Expand All @@ -46,11 +46,11 @@ end

function DI.pushforward(
f::F,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
f_and_df = get_f_and_df(f, backend)
dx_sametype = convert(typeof(x), only(tx))
Expand All @@ -63,11 +63,11 @@ end

function DI.pushforward(
f::F,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f_and_df = get_f_and_df(f, backend, Val(B))
tx_sametype = map(Fix1(convert, typeof(x)), tx)
Expand All @@ -81,11 +81,11 @@ end
function DI.value_and_pushforward!(
f::F,
ty::NTuple,
prep::NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
# dy cannot be passed anyway
y, new_ty = DI.value_and_pushforward(f, prep, backend, x, tx, contexts...)
Expand All @@ -96,11 +96,11 @@ end
function DI.pushforward!(
f::F,
ty::NTuple,
prep::NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
# dy cannot be passed anyway
new_ty = DI.pushforward(f, prep, backend, x, tx, contexts...)
Expand All @@ -110,7 +110,7 @@ end

## Gradient

struct EnzymeForwardGradientPrep{B,O} <: GradientPrep
struct EnzymeForwardGradientPrep{B,O} <: DI.GradientPrep
shadows::O
end

Expand Down Expand Up @@ -175,7 +175,7 @@ end

## Jacobian

struct EnzymeForwardOneArgJacobianPrep{B,O} <: JacobianPrep
struct EnzymeForwardOneArgJacobianPrep{B,O} <: DI.JacobianPrep
shadows::O
output_length::Int
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@ function DI.prepare_pushforward(
::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
return NoPushforwardPrep()
return DI.NoPushforwardPrep()
end

function DI.value_and_pushforward(
f!::F,
y,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{1},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
f!_and_df! = get_f_and_df(f!, backend)
dx_sametype = convert(typeof(x), only(tx))
Expand All @@ -39,11 +39,11 @@ end
function DI.value_and_pushforward(
f!::F,
y,
::NoPushforwardPrep,
::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple{B},
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,B,C}
f!_and_df! = get_f_and_df(f!, backend, Val(B))
tx_sametype = map(Fix1(convert, typeof(x)), tx)
Expand All @@ -64,11 +64,11 @@ end
function DI.pushforward(
f!::F,
y,
prep::NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
_, ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)
return ty
Expand All @@ -78,11 +78,11 @@ function DI.value_and_pushforward!(
f!::F,
y,
ty::NTuple,
prep::NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
y, new_ty = DI.value_and_pushforward(f!, y, prep, backend, x, tx, contexts...)
foreach(copyto!, ty, new_ty)
Expand All @@ -93,11 +93,11 @@ function DI.pushforward!(
f!::F,
y,
ty::NTuple,
prep::NoPushforwardPrep,
prep::DI.NoPushforwardPrep,
backend::AutoEnzyme{<:Union{ForwardMode,Nothing}},
x,
tx::NTuple,
contexts::Vararg{Context,C},
contexts::Vararg{DI.Context,C},
) where {F,C}
new_ty = DI.pushforward(f!, y, prep, backend, x, tx, contexts...)
foreach(copyto!, ty, new_ty)
Expand Down
Loading

2 comments on commit d53f8ac

@gdalle
Copy link
Member Author

@gdalle gdalle commented on d53f8ac Dec 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=DifferentiationInterfaceTest

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120560

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a DifferentiationInterfaceTest-v0.9.1 -m "<description of version>" d53f8ac8421dfb89fa469d5bb70d4cc3304f14f5
git push origin DifferentiationInterfaceTest-v0.9.1

Please sign in to comment.