Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: check mutability of array before preallocating dual buffer #619

Merged
merged 2 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "DifferentiationInterface"
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.6.22"
version = "0.6.23"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ using DifferentiationInterface:
outer,
shuffled_gradient,
unwrap,
with_contexts
with_contexts,
ismutable_array
import ForwardDiff.DiffResults as DR
using ForwardDiff.DiffResults:
DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,11 @@
f::F, backend::AutoForwardDiff, x, tx::NTuple, contexts::Vararg{Context,C}
) where {F,C}
T = tag_type(f, backend, x)
xdual_tmp = make_dual_similar(T, x, tx)
if ismutable_array(x)
xdual_tmp = make_dual_similar(T, x, tx)
else
xdual_tmp = nothing

Check warning on line 74 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L74

Added line #L74 was not covered by tests
end
return ForwardDiffOneArgPushforwardPrep{T,typeof(xdual_tmp)}(xdual_tmp)
end

Expand All @@ -92,8 +96,12 @@
tx::NTuple{B},
contexts::Vararg{Context,C},
) where {F,T,B,C}
(; xdual_tmp) = prep
make_dual!(T, xdual_tmp, x, tx)
if ismutable_array(x)
make_dual!(T, prep.xdual_tmp, x, tx)
xdual_tmp = prep.xdual_tmp
else
xdual_tmp = make_dual(T, x, tx)

Check warning on line 103 in DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/onearg.jl#L103

Added line #L103 was not covered by tests
end
contexts_dual = translate(T, Val(B), contexts...)
ydual = f(xdual_tmp, contexts_dual...)
return ydual
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
return vcat(transpose.(map(vec, t))...)
end

DI.ismutable_array(::Type{<:SArray}) = false

Check warning on line 16 in DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/ext/DifferentiationInterfaceStaticArraysExt/DifferentiationInterfaceStaticArraysExt.jl#L16

Added line #L16 was not covered by tests

function DI.BatchSizeSettings(::AutoForwardDiff{nothing}, x::StaticArray)
return BatchSizeSettings{length(x),true,true}(length(x))
end
Expand Down
3 changes: 3 additions & 0 deletions DifferentiationInterface/src/utils/linalg.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
stack_vec_col(t::NTuple) = stack(vec, t; dims=2)
stack_vec_row(t::NTuple) = stack(vec, t; dims=1)

ismutable_array(::Type) = true
ismutable_array(x) = ismutable_array(typeof(x))
17 changes: 17 additions & 0 deletions DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

using ComponentArrays: ComponentArrays
using DifferentiationInterface, DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using ForwardDiff: ForwardDiff
using StaticArrays: StaticArrays
using Test
Expand Down Expand Up @@ -65,3 +66,19 @@
## Static

test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING)

@testset verbose = true "No allocations on StaticArrays" begin
filtered_static_scenarios = filter(static_scenarios(; include_batchified=false)) do scen
DIT.function_place(scen) == :out && DIT.operator_place(scen) == :out

Check warning on line 72 in DifferentiationInterface/test/Back/ForwardDiff/test.jl

View check run for this annotation

Codecov / codecov/patch

DifferentiationInterface/test/Back/ForwardDiff/test.jl#L72

Added line #L72 was not covered by tests
end
data = benchmark_differentiation(
AutoForwardDiff(),
filtered_static_scenarios;
benchmark=:prepared,
excluded=[:hessian, :pullback], # TODO: figure this out
logging=LOGGING,
)
@testset "$(row[:scenario])" for row in eachrow(data)
@test row[:allocs] == 0
end
end;
Loading