[BREAKING] Transition from Tapir to Mooncake #500

Merged
merged 4 commits into from
Sep 25, 2024
Changes from 1 commit
58 changes: 29 additions & 29 deletions .github/workflows/Test.yml
@@ -32,27 +32,27 @@ jobs:
- "pre"
group:
- Formalities
- Internals
- Back/ChainRulesCore
- Back/Diffractor
- Back/Enzyme
- Back/FastDifferentiation
- Back/FiniteDiff
- Back/FiniteDifferences
- Back/ForwardDiff
- Back/PolyesterForwardDiff
- Back/ReverseDiff
- Back/SecondOrder
- Back/Symbolics
- Back/Tapir
- Back/Tracker
- Back/Zygote
- Misc/DifferentiateWith
- Misc/FromPrimitive
- Misc/SparsityDetector
- Misc/ZeroBackends
- Down/Flux
- Down/Lux
# - Internals
# - Back/ChainRulesCore
# - Back/Diffractor
# - Back/Enzyme
# - Back/FastDifferentiation
# - Back/FiniteDiff
# - Back/FiniteDifferences
# - Back/ForwardDiff
- Back/Mooncake
# - Back/PolyesterForwardDiff
# - Back/ReverseDiff
# - Back/SecondOrder
# - Back/Symbolics
# - Back/Tracker
# - Back/Zygote
# - Misc/DifferentiateWith
# - Misc/FromPrimitive
# - Misc/SparsityDetector
# - Misc/ZeroBackends
# - Down/Flux
# - Down/Lux
exclude:
# lts
- version: "lts"
@@ -67,14 +67,14 @@ jobs:
group: Back/FiniteDiff
- version: "lts"
group: Back/FastDifferentiation
- version: "lts"
group: Back/Mooncake
- version: "lts"
group: Back/PolyesterForwardDiff
- version: "lts"
group: Back/SecondOrder
- version: "lts"
group: Back/Symbolics
- version: "lts"
group: Back/Tapir
- version: "lts"
group: Misc/SparsityDetector
- version: "lts"
@@ -89,7 +89,7 @@ jobs:
- version: "pre"
group: Back/Enzyme
- version: "pre"
group: Back/Tapir
group: Back/Mooncake
- version: "pre"
group: Back/SecondOrder
- version: "pre"
@@ -140,13 +140,13 @@ jobs:
matrix:
version:
- "1"
- "lts"
- "pre"
# - "lts"
# - "pre"
group:
- Formalities
- Zero
- Standard
- Weird
# - Zero
# - Standard
# - Weird
exclude:
- version: "lts"
group: Formalities
10 changes: 5 additions & 5 deletions DifferentiationInterface/Project.toml
@@ -17,12 +17,12 @@ FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

@@ -34,17 +34,17 @@ DifferentiationInterfaceFastDifferentiationExt = "FastDifferentiation"
DifferentiationInterfaceFiniteDiffExt = "FiniteDiff"
DifferentiationInterfaceFiniteDifferencesExt = "FiniteDifferences"
DifferentiationInterfaceForwardDiffExt = "ForwardDiff"
DifferentiationInterfaceMooncakeExt = "Mooncake"
DifferentiationInterfacePolyesterForwardDiffExt = "PolyesterForwardDiff"
DifferentiationInterfaceReverseDiffExt = "ReverseDiff"
DifferentiationInterfaceSparseArraysExt = "SparseArrays"
DifferentiationInterfaceSparseMatrixColoringsExt = "SparseMatrixColorings"
DifferentiationInterfaceSymbolicsExt = "Symbolics"
DifferentiationInterfaceTapirExt = "Tapir"
DifferentiationInterfaceTrackerExt = "Tracker"
DifferentiationInterfaceZygoteExt = ["Zygote", "ForwardDiff"]

[compat]
ADTypes = "1.7.0"
ADTypes = "1.9.0"
ChainRulesCore = "1.23.0"
Compat = "3.46,4.2"
Diffractor = "=0.2.6"
@@ -54,14 +54,14 @@ FiniteDiff = "2.23.1"
FiniteDifferences = "0.12.31"
ForwardDiff = "0.10.36"
LinearAlgebra = "<0.0.1,1"
Mooncake = "0.4.0"
PackageExtensionCompat = "1.0.2"
PolyesterForwardDiff = "0.1.1"
ReverseDiff = "1.15.1"
SparseArrays = "<0.0.1,1"
SparseConnectivityTracer = "0.5.0,0.6"
SparseMatrixColorings = "0.4.0"
Symbolics = "5.27.1, 6"
Tapir = "0.2.48"
Tracker = "0.2.33"
Zygote = "0.6.69"
julia = "1.6"
@@ -81,6 +81,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -91,7 +92,6 @@ SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2 changes: 1 addition & 1 deletion DifferentiationInterface/README.md
@@ -37,10 +37,10 @@ We support the following backends defined by [ADTypes.jl](https://github.com/Sci
- [FiniteDiff.jl](https://github.com/JuliaDiff/FiniteDiff.jl)
- [FiniteDifferences.jl](https://github.com/JuliaDiff/FiniteDifferences.jl)
- [ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl)
- [Mooncake.jl](https://github.com/withbayes/Mooncake.jl)
- [PolyesterForwardDiff.jl](https://github.com/JuliaDiff/PolyesterForwardDiff.jl)
- [ReverseDiff.jl](https://github.com/JuliaDiff/ReverseDiff.jl)
- [Symbolics.jl](https://github.com/JuliaSymbolics/Symbolics.jl)
- [Tapir.jl](https://github.com/withbayes/Tapir.jl)
- [Tracker.jl](https://github.com/FluxML/Tracker.jl)
- [Zygote.jl](https://github.com/FluxML/Zygote.jl)

8 changes: 4 additions & 4 deletions DifferentiationInterface/docs/src/explanation/backends.md
@@ -11,10 +11,10 @@ We support the following dense backend choices from [ADTypes.jl](https://github.
- [`AutoFiniteDiff`](@extref ADTypes.AutoFiniteDiff)
- [`AutoFiniteDifferences`](@extref ADTypes.AutoFiniteDifferences)
- [`AutoForwardDiff`](@extref ADTypes.AutoForwardDiff)
- [`AutoMooncake`](@extref ADTypes.AutoMooncake)
- [`AutoPolyesterForwardDiff`](@extref ADTypes.AutoPolyesterForwardDiff)
- [`AutoReverseDiff`](@extref ADTypes.AutoReverseDiff)
- [`AutoSymbolics`](@extref ADTypes.AutoSymbolics)
- [`AutoTapir`](@extref ADTypes.AutoTapir)
- [`AutoTracker`](@extref ADTypes.AutoTracker)
- [`AutoZygote`](@extref ADTypes.AutoZygote)

@@ -55,10 +55,10 @@ In practice, many AD backends have custom implementations for high-level operato
| `AutoFiniteDiff` | 🔀 | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| `AutoFiniteDifferences` | 🔀 | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| `AutoForwardDiff` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| `AutoMooncake` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| `AutoPolyesterForwardDiff` | 🔀 | ❌ | 🔀 | ✅ | ✅ | 🔀 | 🔀 | 🔀 |
| `AutoReverseDiff` | ❌ | 🔀 | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| `AutoSymbolics` | ✅ | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| `AutoTapir` | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| `AutoTracker` | ❌ | ✅ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ |
| `AutoZygote` | ❌ | ✅ | ❌ | ✅ | ✅ | ✅ | 🔀 | ❌ |

@@ -144,9 +144,9 @@ For all operators, preparation generates an [executable function](https://docs.s
!!! warning
Preparation can be very slow for symbolic AD.

### Tapir
### Mooncake

For `pullback`, preparation [builds the reverse rule](https://github.com/withbayes/Tapir.jl?tab=readme-ov-file#how-it-works) of the function.
For `pullback`, preparation [builds the reverse rule](https://github.com/withbayes/Mooncake.jl?tab=readme-ov-file#how-it-works) of the function.
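
As a rough illustration of what this preparation step looks like from the user's side, the sketch below prepares a pullback once and then reuses it. It assumes the `Tangents` wrapper and the `(f, prep, backend, x, ty)` argument order that appear in this PR's extension code, and it assumes `AutoMooncake` accepts a `config` keyword (here set to `nothing`). The function `f` and the input `x` are made up for the example.

```julia
using DifferentiationInterface
import DifferentiationInterface as DI
import Mooncake  # loading Mooncake activates the DI extension

f(x) = sum(abs2, x)   # hypothetical function to differentiate
x = [1.0, 2.0, 3.0]   # hypothetical input

# Assumption: `AutoMooncake` takes a `config` keyword; `nothing` means "no custom config".
backend = AutoMooncake(; config=nothing)

# One output cotangent, wrapped in `Tangents` as in this revision of the API.
ty = DI.Tangents(1.0)

# Preparation builds the reverse rule once...
prep = prepare_pullback(f, backend, x, ty)

# ...and later calls reuse it.
y, tx = value_and_pullback(f, prep, backend, x, ty)
```

Reusing `prep` only makes sense for inputs of the same type as the `x` used during preparation, since the rule is built from `typeof(f)` and `typeof(x)`.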

### Tracker

@@ -1,9 +1,8 @@
module DifferentiationInterfaceTapirExt
module DifferentiationInterfaceMooncakeExt

using ADTypes: ADTypes, AutoTapir
using ADTypes: ADTypes, AutoMooncake
import DifferentiationInterface as DI
using DifferentiationInterface: PullbackPrep, Tangents
using Tapir:
using Mooncake:
CoDual,
NoTangent,
build_rrule,
@@ -19,9 +18,9 @@ using Tapir:
fdata,
rdata,
__value_and_pullback!!,
get_tapir_interpreter
get_interpreter

DI.check_available(::AutoTapir) = true
DI.check_available(::AutoMooncake) = true

include("onearg.jl")
include("twoarg.jl")
@@ -1,42 +1,47 @@
struct TapirOneArgPullbackPrep{Y,R} <: PullbackPrep
struct MooncakeOneArgPullbackPrep{Y,R} <: DI.PullbackPrep
y_prototype::Y
rrule::R
end

function DI.prepare_pullback(f, backend::AutoTapir, x, ty::Tangents)
function DI.prepare_pullback(f, backend::AutoMooncake, x, ty::DI.Tangents)
y = f(x)
rrule = build_rrule(
get_tapir_interpreter(),
get_interpreter(),
Tuple{typeof(f),typeof(x)};
safety_on=backend.safe_mode,
silence_safety_messages=false,
debug_mode=false, # TODO: modify
silence_debug_messages=false,
)
prep = TapirOneArgPullbackPrep(y, rrule)
prep = MooncakeOneArgPullbackPrep(y, rrule)
DI.value_and_pullback(f, prep, backend, x, ty) # warm up
return prep
end

function DI.value_and_pullback(
f, prep::TapirOneArgPullbackPrep, backend::AutoTapir, x, ty::Tangents
f, prep::MooncakeOneArgPullbackPrep, backend::AutoMooncake, x, ty::DI.Tangents
)
y = f(x)
tx = map(ty) do dy
only(DI.pullback(f, prep, backend, x, Tangents(dy)))
only(DI.pullback(f, prep, backend, x, DI.Tangents(dy)))
end
return y, tx
end

function DI.value_and_pullback(
f, prep::TapirOneArgPullbackPrep{Y}, ::AutoTapir, x, ty::Tangents{1}
f, prep::MooncakeOneArgPullbackPrep{Y}, ::AutoMooncake, x, ty::DI.Tangents{1}
) where {Y}
dy = only(ty)
dy_righttype = convert(tangent_type(Y), dy)
new_y, (_, new_dx) = value_and_pullback!!(prep.rrule, dy_righttype, f, x)
return new_y, Tangents(new_dx)
return new_y, DI.Tangents(new_dx)
end

function DI.value_and_pullback!(
f, prep::TapirOneArgPullbackPrep{Y}, tx::Tangents, ::AutoTapir, x, ty::Tangents{1}
f,
prep::MooncakeOneArgPullbackPrep{Y},
tx::DI.Tangents,
::AutoMooncake,
x,
ty::DI.Tangents{1},
) where {Y}
dx, dy = only(tx), only(ty)
dy_righttype = convert(tangent_type(Y), dy)
@@ -48,12 +53,19 @@ function DI.value_and_pullback!(
return y, tx
end

function DI.pullback(f, prep::TapirOneArgPullbackPrep, backend::AutoTapir, x, ty::Tangents)
function DI.pullback(
f, prep::MooncakeOneArgPullbackPrep, backend::AutoMooncake, x, ty::DI.Tangents
)
return DI.value_and_pullback(f, prep, backend, x, ty)[2]
end

function DI.pullback!(
f, tx::Tangents, prep::TapirOneArgPullbackPrep, backend::AutoTapir, x, ty::Tangents
f,
tx::DI.Tangents,
prep::MooncakeOneArgPullbackPrep,
backend::AutoMooncake,
x,
ty::DI.Tangents,
)
return DI.value_and_pullback!(f, tx, prep, backend, x, ty)[2]
end
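
The `convert(tangent_type(Y), dy)` step in the single-tangent `value_and_pullback` method above can be read as follows: the caller's cotangent is converted to the type Mooncake expects for the output before the rule is run. A minimal sketch, assuming `tangent_type` is reachable as `Mooncake.tangent_type` (it is used unqualified in the extension) and using made-up values:

```julia
import Mooncake

# `tangent_type` maps a primal type to the type Mooncake uses for its (co)tangents.
T_scalar = Mooncake.tangent_type(Float64)
T_vector = Mooncake.tangent_type(Vector{Float64})

# A cotangent supplied with a different numeric type (here an Int)
# is converted before being handed to the reverse rule.
dy = 1
dy_righttype = convert(T_scalar, dy)
```

Only floating-point primal types are shown here; other types may map to different tangent representations.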
@@ -1,33 +1,33 @@
struct TapirTwoArgPullbackPrep{R} <: PullbackPrep
struct MooncakeTwoArgPullbackPrep{R} <: DI.PullbackPrep
rrule::R
end

function DI.prepare_pullback(f!, y, backend::AutoTapir, x, ty::Tangents)
function DI.prepare_pullback(f!, y, backend::AutoMooncake, x, ty::DI.Tangents)
rrule = build_rrule(
get_tapir_interpreter(),
get_interpreter(),
Tuple{typeof(f!),typeof(y),typeof(x)};
safety_on=backend.safe_mode,
silence_safety_messages=false,
debug_mode=false, # TODO: modify
silence_debug_messages=false,
)
prep = TapirTwoArgPullbackPrep(rrule)
prep = MooncakeTwoArgPullbackPrep(rrule)
DI.value_and_pullback(f!, y, prep, backend, x, ty) # warm up
return prep
end

# see https://github.com/withbayes/Tapir.jl/issues/113#issuecomment-2036718992
# see https://github.com/withbayes/Mooncake.jl/issues/113#issuecomment-2036718992

function DI.value_and_pullback(
f!, y, prep::TapirTwoArgPullbackPrep, backend::AutoTapir, x, ty::Tangents
f!, y, prep::MooncakeTwoArgPullbackPrep, backend::AutoMooncake, x, ty::DI.Tangents
)
tx = map(ty) do dy
only(DI.pullback(f!, y, prep, backend, x, Tangents(dy)))
only(DI.pullback(f!, y, prep, backend, x, DI.Tangents(dy)))
end
f!(y, x)
return y, tx
end

function DI.value_and_pullback(
f!, y, prep::TapirTwoArgPullbackPrep, ::AutoTapir, x, ty::Tangents{1}
f!, y, prep::MooncakeTwoArgPullbackPrep, ::AutoMooncake, x, ty::DI.Tangents{1}
)
dy = only(ty)
dy_righttype = convert(tangent_type(typeof(y)), copy(dy))
@@ -67,5 +67,5 @@ function DI.value_and_pullback(
# Run the reverse-pass.
_, _, new_dx = pb!!(NoRData())

return y, Tangents(tangent(fdata(dx_righttype), new_dx))
return y, DI.Tangents(tangent(fdata(dx_righttype), new_dx))
end
4 changes: 2 additions & 2 deletions DifferentiationInterface/src/DifferentiationInterface.jl
@@ -22,10 +22,10 @@ using ADTypes:
AutoFiniteDiff,
AutoFiniteDifferences,
AutoForwardDiff,
AutoMooncake,
AutoPolyesterForwardDiff,
AutoReverseDiff,
AutoSymbolics,
AutoTapir,
AutoTracker,
AutoZygote
using Compat
@@ -110,10 +110,10 @@ export AutoFastDifferentiation
export AutoFiniteDiff
export AutoFiniteDifferences
export AutoForwardDiff
export AutoMooncake
export AutoPolyesterForwardDiff
export AutoReverseDiff
export AutoSymbolics
export AutoTapir
export AutoTracker
export AutoZygote
