Skip to content

Commit

Permalink
feat: add ConstructionBaseExt to allow Setfield and Functors support
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 13, 2024
1 parent c5c2b8c commit b05286d
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 5 deletions.
9 changes: 7 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
name = "ADTypes"
uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
authors = ["Vaibhav Dixit <vaibhavyashdixit@gmail.com>, Guillaume Dalle and contributors"]
version = "1.9.1"
version = "1.10.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[weakdeps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

[extensions]
ADTypesChainRulesCoreExt = "ChainRulesCore"
ADTypesConstructionBaseExt = "ConstructionBase"
ADTypesEnzymeCoreExt = "EnzymeCore"

[compat]
ChainRulesCore = "1.0.2"
ConstructionBase = "1.5"
EnzymeCore = "0.5.3,0.6,0.7,0.8"
julia = "1.6"

Expand All @@ -25,7 +29,8 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "ChainRulesCore", "EnzymeCore", "JET", "Test"]
test = ["Aqua", "ChainRulesCore", "EnzymeCore", "JET", "Setfield", "Test"]
18 changes: 18 additions & 0 deletions ext/ADTypesConstructionBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module ADTypesConstructionBaseExt

using ADTypes: AutoEnzyme, AutoForwardDiff, AutoPolyesterForwardDiff
using ConstructionBase: ConstructionBase

function ConstructionBase.constructorof(::Type{<:AutoEnzyme{M, A}}) where {M, A}
return AutoEnzyme{A}
end

function ConstructionBase.constructorof(::Type{<:AutoForwardDiff{chunksize}}) where {chunksize}
return AutoForwardDiff{chunksize}
end

function ConstructionBase.constructorof(::Type{<:AutoPolyesterForwardDiff{chunksize}}) where {chunksize}
return AutoPolyesterForwardDiff{chunksize}
end

end
14 changes: 11 additions & 3 deletions src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,11 @@ struct AutoEnzyme{M, A} <: AbstractADType
mode::M
end

AutoEnzyme{A}(mode::M) where {M, A} = AutoEnzyme{M, A}(mode)

function AutoEnzyme(;
mode::M = nothing, function_annotation::Type{A} = Nothing) where {M, A}
return AutoEnzyme{M, A}(mode)
return AutoEnzyme{A}(mode)
end

mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension
Expand Down Expand Up @@ -181,8 +183,10 @@ struct AutoForwardDiff{chunksize, T} <: AbstractADType
tag::T
end

AutoForwardDiff{chunksize}(tag::T) where {chunksize, T} = AutoForwardDiff{chunksize, T}(tag)

function AutoForwardDiff(; chunksize = nothing, tag = nothing)
AutoForwardDiff{chunksize, typeof(tag)}(tag)
return AutoForwardDiff{chunksize}(tag)
end

mode(::AutoForwardDiff) = ForwardMode()
Expand Down Expand Up @@ -271,8 +275,12 @@ struct AutoPolyesterForwardDiff{chunksize, T} <: AbstractADType
tag::T
end

function AutoPolyesterForwardDiff{chunksize}(tag::T) where {chunksize, T}
return AutoPolyesterForwardDiff{chunksize, T}(tag)
end

function AutoPolyesterForwardDiff(; chunksize = nothing, tag = nothing)
AutoPolyesterForwardDiff{chunksize, typeof(tag)}(tag)
return AutoPolyesterForwardDiff{chunksize}(tag)
end

mode(::AutoPolyesterForwardDiff) = ForwardMode()
Expand Down
33 changes: 33 additions & 0 deletions test/misc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,36 @@ for backend in [
]
println(backend)
end

using Setfield

@testset "Setfield compatibility" begin
ad = AutoEnzyme()
@test ad.mode === nothing
@set! ad.mode = EnzymeCore.Reverse
@test ad.mode isa EnzymeCore.ReverseMode

struct CustomTestTag end

ad = AutoForwardDiff()
@test ad.tag === nothing
@set! ad.tag = CustomTestTag()
@test ad.tag isa CustomTestTag

ad = AutoForwardDiff(; chunksize = 10)
@test ad.tag === nothing
@set! ad.tag = CustomTestTag()
@test ad.tag isa CustomTestTag
@test ad isa AutoForwardDiff{10}

ad = AutoPolyesterForwardDiff()
@test ad.tag === nothing
@set! ad.tag = CustomTestTag()
@test ad.tag isa CustomTestTag

ad = AutoPolyesterForwardDiff(; chunksize = 10)
@test ad.tag === nothing
@set! ad.tag = CustomTestTag()
@test ad.tag isa CustomTestTag
@test ad isa AutoPolyesterForwardDiff{10}
end

0 comments on commit b05286d

Please sign in to comment.