Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rename mutation_support into twoarg_support #230

Merged
merged 1 commit into from
Apr 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions DifferentiationInterface/docs/src/overloads.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Each cell can have three values:
```@setup overloads
using ADTypes: AbstractADType
using DifferentiationInterface
using DifferentiationInterface: backend_str, mutation_support, MutationSupported
using DifferentiationInterface: backend_str, twoarg_support, TwoArgSupported
using Markdown: Markdown
using Diffractor: Diffractor
using Enzyme: Enzyme
Expand Down Expand Up @@ -155,7 +155,7 @@ function print_overloads(backend, ext::Symbol)

println(io, "#### Two-argument functions `f!(y, x)`")
println(io)
if mutation_support(backend) == MutationSupported()
if twoarg_support(backend) == TwoArgSupported()
print_overload_table(io, operators_and_types_f!(backend), ext)
else
println(io, "Backend doesn't support mutating functions.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ const AutoForwardChainRules = AutoChainRules{<:RuleConfig{>:HasForwardsMode}}
const AutoReverseChainRules = AutoChainRules{<:RuleConfig{>:HasReverseMode}}

DI.check_available(::AutoChainRules) = true
DI.mutation_support(::AutoChainRules) = DI.MutationNotSupported()
DI.twoarg_support(::AutoChainRules) = DI.TwoArgNotSupported()

include("reverse_onearg.jl")
include("differentiate_with.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using DifferentiationInterface: NoPushforwardExtras
using Diffractor: DiffractorRuleConfig, TaylorTangentIndex, ZeroBundle, bundle, ∂☆

DI.check_available(::AutoDiffractor) = true
DI.mutation_support(::AutoDiffractor) = DI.MutationNotSupported()
DI.twoarg_support(::AutoDiffractor) = DI.TwoArgNotSupported()
DI.pullback_performance(::AutoDiffractor) = DI.PullbackSlow()

## Pushforward
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using FiniteDifferences: FiniteDifferences, grad, jacobian, jvp, j′vp
using LinearAlgebra: dot

DI.check_available(::AutoFiniteDifferences) = true
DI.mutation_support(::AutoFiniteDifferences) = DI.MutationNotSupported()
DI.twoarg_support(::AutoFiniteDifferences) = DI.TwoArgNotSupported()

function FiniteDifferences.to_vec(a::OneElement) # TODO: remove type piracy (https://github.com/JuliaDiff/FiniteDifferences.jl/issues/141)
return FiniteDifferences.to_vec(collect(a))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using DifferentiationInterface: NoGradientExtras, NoPullbackExtras
using Tracker: Tracker, back, data, forward, gradient, jacobian, param, withgradient

DI.check_available(::AutoTracker) = true
DI.mutation_support(::AutoTracker) = DI.MutationNotSupported()
DI.twoarg_support(::AutoTracker) = DI.TwoArgNotSupported()

## Pullback

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ using Zygote:
ZygoteRuleConfig, gradient, hessian, jacobian, pullback, withgradient, withjacobian

DI.check_available(::AutoZygote) = true
DI.mutation_support(::AutoZygote) = DI.MutationNotSupported()
DI.twoarg_support(::AutoZygote) = DI.TwoArgNotSupported()

## Pullback

Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/sparse/fallbacks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

for trait in (
:check_available,
:mutation_support,
:twoarg_support,
:pushforward_performance,
:pullback_performance,
:hvp_mode,
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/utils/check.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ end

Check whether `backend` supports differentiation of two-argument functions.
"""
check_twoarg(backend::AbstractADType) = Bool(mutation_support(backend))
check_twoarg(backend::AbstractADType) = Bool(twoarg_support(backend))

sqnorm(x::AbstractArray) = sum(abs2, x)

Expand Down
18 changes: 9 additions & 9 deletions DifferentiationInterface/src/utils/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,25 +3,25 @@
abstract type MutationBehavior end

"""
MutationSupported
TwoArgSupported

Trait identifying backends that support two-argument functions `f!(y, x)`.
"""
struct MutationSupported <: MutationBehavior end
struct TwoArgSupported <: MutationBehavior end

"""
MutationNotSupported
TwoArgNotSupported

Trait identifying backends that do not support two-argument functions `f!(y, x)`.
"""
struct MutationNotSupported <: MutationBehavior end
struct TwoArgNotSupported <: MutationBehavior end

"""
mutation_support(backend)
twoarg_support(backend)

Return [`MutationSupported`](@ref) or [`MutationNotSupported`](@ref) in a statically predictable way.
Return [`TwoArgSupported`](@ref) or [`TwoArgNotSupported`](@ref) in a statically predictable way.
"""
mutation_support(::AbstractADType) = MutationSupported()
twoarg_support(::AbstractADType) = TwoArgSupported()

## Pushforward

Expand Down Expand Up @@ -132,8 +132,8 @@ end

## Conversions

Base.Bool(::MutationSupported) = true
Base.Bool(::MutationNotSupported) = false
Base.Bool(::TwoArgSupported) = true
Base.Bool(::TwoArgNotSupported) = false

Base.Bool(::PushforwardFast) = true
Base.Bool(::PushforwardSlow) = false
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ using DifferentiationInterface:
inner,
mode,
outer,
mutation_support,
twoarg_support,
pushforward_performance,
pullback_performance
using DifferentiationInterface:
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterfaceTest/src/scenarios/scenario.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ operator_place(::AbstractScenario{args,op}) where {args,op} = op

function compatible(backend::AbstractADType, scen::AbstractScenario)
if nb_args(scen) == 2
return Bool(mutation_support(backend))
return Bool(twoarg_support(backend))
end
return true
end
Expand Down
4 changes: 2 additions & 2 deletions DifferentiationInterfaceTest/src/utils/zero_backends.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct AutoZeroForward <: AbstractADType end

ADTypes.mode(::AutoZeroForward) = ForwardMode()
DI.check_available(::AutoZeroForward) = true
DI.mutation_support(::AutoZeroForward) = DI.MutationSupported()
DI.twoarg_support(::AutoZeroForward) = DI.TwoArgSupported()

DI.prepare_pushforward(f, ::AutoZeroForward, x, dx) = NoPushforwardExtras()
DI.prepare_pushforward(f!, y, ::AutoZeroForward, x, dx) = NoPushforwardExtras()
Expand Down Expand Up @@ -55,7 +55,7 @@ struct AutoZeroReverse <: AbstractADType end

ADTypes.mode(::AutoZeroReverse) = ReverseMode()
DI.check_available(::AutoZeroReverse) = true
DI.mutation_support(::AutoZeroReverse) = DI.MutationSupported()
DI.twoarg_support(::AutoZeroReverse) = DI.TwoArgSupported()

DI.prepare_pullback(f, ::AutoZeroReverse, x, dy) = NoPullbackExtras()
DI.prepare_pullback(f!, y, ::AutoZeroReverse, x, dy) = NoPullbackExtras()
Expand Down
Loading