Skip to content

Commit

Permalink
Remove FillArrays dependency by implementing custom OneElement (#453)
Browse files Browse the repository at this point in the history
* Remove FillArrays dep by implementing custom `OneElement`

* Clean up

* Fixes
  • Loading branch information
gdalle authored Sep 6, 2024
1 parent e89793f commit b2dcdef
Show file tree
Hide file tree
Showing 17 changed files with 110 additions and 47 deletions.
2 changes: 0 additions & 2 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ version = "0.6.0"
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"

Expand Down Expand Up @@ -51,7 +50,6 @@ Compat = "3.46,4.2"
Diffractor = "=0.2.6"
Enzyme = "0.12.35"
FastDifferentiation = "0.3.9, 0.4"
FillArrays = "1.7.0"
FiniteDiff = "2.23.1"
FiniteDifferences = "0.12.31"
ForwardDiff = "0.10.36"
Expand Down
1 change: 0 additions & 1 deletion DifferentiationInterface/docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ DocumenterInterLinks = "d12716ef-a0f6-4df4-a9f1-a5a34e75c656"
DocumenterMermaid = "a078cd44-4d9c-4618-b545-3ab9d77f9177"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@ using FastDifferentiation:
make_variables,
sparse_hessian,
sparse_jacobian
using FillArrays: Fill
using LinearAlgebra: dot
using FastDifferentiation.RuntimeGeneratedFunctions: RuntimeGeneratedFunction

DI.check_available(::AutoFastDifferentiation) = true

monovec(x::Number) = Fill(x, 1)
monovec(x::Number) = [x]

myvec(x::Number) = monovec(x)
myvec(x::AbstractArray) = vec(x)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ using ADTypes: AutoFiniteDifferences
import DifferentiationInterface as DI
using DifferentiationInterface:
NoGradientExtras, NoJacobianExtras, NoPullbackExtras, NoPushforwardExtras, Tangents
using FillArrays: OneElement
using FiniteDifferences: FiniteDifferences, grad, jacobian, jvp, j′vp
using LinearAlgebra: dot

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ using DifferentiationInterface:
JacobianExtras,
NoPullbackExtras,
Tangents
using FillArrays: OneElement
using ReverseDiff.DiffResults: DiffResults, DiffResult, GradientResult, MutableDiffResult
using LinearAlgebra: dot, mul!
using ReverseDiff:
Expand All @@ -32,10 +31,8 @@ using ReverseDiff:

DI.check_available(::AutoReverseDiff) = true

function DI.basis(
::AutoReverseDiff, a::AbstractArray{T,N}, i::CartesianIndex{N}
) where {T,N}
return OneElement(one(T), Tuple(i), axes(a))
function DI.basis(::AutoReverseDiff, a::AbstractArray{T}, i) where {T}
return DI.OneElement(i, one(T), a)
end

include("onearg.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:iterat
if pushforward_performance(backend) isa PushforwardFast
p = similar(y)
extras = prepare_pushforward_same_point(
f, backend, x, basis(backend, x, first(CartesianIndices(x)))
f, backend, x, basis(backend, x, first(eachindex(x)))
)
for (kj, j) in enumerate(CartesianIndices(x))
for (kj, j) in enumerate(eachindex(x))
pushforward!(f, p, extras, backend, x, basis(backend, x, j))
for ki in LinearIndices(p)
if abs(p[ki]) > atol
Expand All @@ -42,9 +42,9 @@ function ADTypes.jacobian_sparsity(f, x, detector::DenseSparsityDetector{:iterat
else
p = similar(x)
extras = prepare_pullback_same_point(
f, backend, x, basis(backend, y, first(CartesianIndices(y)))
f, backend, x, basis(backend, y, first(eachindex(y)))
)
for (ki, i) in enumerate(CartesianIndices(y))
for (ki, i) in enumerate(eachindex(y))
pullback!(f, p, extras, backend, x, basis(backend, y, i))
for kj in LinearIndices(p)
if abs(p[kj]) > atol
Expand All @@ -64,9 +64,9 @@ function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:it
if pushforward_performance(backend) isa PushforwardFast
p = similar(y)
extras = prepare_pushforward_same_point(
f!, y, backend, x, basis(backend, x, first(CartesianIndices(x)))
f!, y, backend, x, basis(backend, x, first(eachindex(x)))
)
for (kj, j) in enumerate(CartesianIndices(x))
for (kj, j) in enumerate(eachindex(x))
pushforward!(f!, y, p, extras, backend, x, basis(backend, x, j))
for ki in LinearIndices(p)
if abs(p[ki]) > atol
Expand All @@ -78,9 +78,9 @@ function ADTypes.jacobian_sparsity(f!, y, x, detector::DenseSparsityDetector{:it
else
p = similar(x)
extras = prepare_pullback_same_point(
f!, y, backend, x, basis(backend, y, first(CartesianIndices(y)))
f!, y, backend, x, basis(backend, y, first(eachindex(y)))
)
for (ki, i) in enumerate(CartesianIndices(y))
for (ki, i) in enumerate(eachindex(y))
pullback!(f!, y, p, extras, backend, x, basis(backend, y, i))
for kj in LinearIndices(p)
if abs(p[kj]) > atol
Expand All @@ -98,10 +98,8 @@ function ADTypes.hessian_sparsity(f, x, detector::DenseSparsityDetector{:iterati
n = length(x)
I, J = Int[], Int[]
p = similar(x)
extras = prepare_hvp_same_point(
f, backend, x, basis(backend, x, first(CartesianIndices(x)))
)
for (kj, j) in enumerate(CartesianIndices(x))
extras = prepare_hvp_same_point(f, backend, x, basis(backend, x, first(eachindex(x))))
for (kj, j) in enumerate(eachindex(x))
hvp!(f, p, extras, backend, x, basis(backend, x, j))
for ki in LinearIndices(p)
if abs(p[ki]) > atol
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function DI.prepare_hessian(f::F, backend::AutoSparse, x) where {F}
groups = column_groups(coloring_result)
Ng = length(groups)
B = pick_batchsize(maybe_outer(dense_backend), Ng)
seeds = [multibasis(backend, x, CartesianIndices(x)[group]) for group in groups]
seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=2)
batched_seeds = [
Tangents(ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B))) for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ function _prepare_sparse_jacobian_aux(
groups = column_groups(coloring_result)
Ng = length(groups)
B = pick_batchsize(dense_backend, Ng)
seeds = [multibasis(backend, x, CartesianIndices(x)[group]) for group in groups]
seeds = [multibasis(backend, x, eachindex(x)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(y)), groups; dims=2)
batched_seeds = [
Tangents(ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B))) for
Expand Down Expand Up @@ -122,7 +122,7 @@ function _prepare_sparse_jacobian_aux(
groups = row_groups(coloring_result)
Ng = length(groups)
B = pick_batchsize(dense_backend, Ng)
seeds = [multibasis(backend, y, CartesianIndices(y)[group]) for group in groups]
seeds = [multibasis(backend, y, eachindex(y)[group]) for group in groups]
compressed_matrix = stack(_ -> vec(similar(x)), groups; dims=1)
batched_seeds = [
Tangents(ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % Ng], Val(B))) for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ using DifferentiationInterface:
SecondDerivativeExtras,
Tangents,
maybe_dense_ad
using FillArrays: Fill
using LinearAlgebra: dot
using Symbolics:
build_function,
Expand All @@ -33,7 +32,7 @@ using Symbolics.RuntimeGeneratedFunctions: RuntimeGeneratedFunction
DI.check_available(::AutoSymbolics) = true
DI.pullback_performance(::AutoSymbolics) = DI.PullbackSlow()

monovec(x::Number) = Fill(x, 1)
monovec(x::Number) = [x]

myvec(x::Number) = monovec(x)
myvec(x::AbstractArray) = vec(x)
Expand Down
1 change: 0 additions & 1 deletion DifferentiationInterface/src/DifferentiationInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ using ADTypes:
AutoTracker,
AutoZygote
using Compat
using FillArrays: OneElement
using LinearAlgebra: Symmetric, Transpose, dot, parent, transpose
using PackageExtensionCompat: @require_extensions

Expand Down
4 changes: 2 additions & 2 deletions DifferentiationInterface/src/first_order/jacobian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ function _prepare_jacobian_aux(
) where {FY}
N = length(x)
B = pick_batchsize(backend, N)
seeds = [basis(backend, x, ind) for ind in CartesianIndices(x)]
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
batched_seeds = [
Tangents(ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B))) for
a in 1:div(N, B, RoundUp)
Expand All @@ -102,7 +102,7 @@ function _prepare_jacobian_aux(
) where {FY}
M = length(y)
B = pick_batchsize(backend, M)
seeds = [basis(backend, y, ind) for ind in CartesianIndices(y)]
seeds = [basis(backend, y, ind) for ind in eachindex(y)]
batched_seeds = [
Tangents(ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % M], Val(B))) for
a in 1:div(M, B, RoundUp)
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/first_order/pullback.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ function _pullback_via_pushforward(
x::AbstractArray,
dy,
) where {F}
dx = map(CartesianIndices(x)) do j
dx = map(CartesianIndices(x)) do j # preserve shape
t1 = pushforward(
f!, y, pushforward_extras, backend, x, SingleTangent(basis(backend, x, j))
)
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/first_order/pushforward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ end
function _pushforward_via_pullback(
f!::F, y::AbstractArray, pullback_extras::PullbackExtras, backend::AbstractADType, x, dx
) where {F}
dy = map(CartesianIndices(y)) do i
dy = map(CartesianIndices(y)) do i # preserve shape
t1 = pullback(f!, y, pullback_extras, backend, x, SingleTangent(basis(backend, y, i)))
dot(dx, only(t1))
end
Expand Down
2 changes: 1 addition & 1 deletion DifferentiationInterface/src/second_order/hessian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ end
function prepare_hessian(f::F, backend::AbstractADType, x) where {F}
N = length(x)
B = pick_batchsize(maybe_outer(backend), N)
seeds = [basis(backend, x, ind) for ind in CartesianIndices(x)]
seeds = [basis(backend, x, ind) for ind in eachindex(x)]
batched_seeds = [
Tangents(ntuple(b -> seeds[1 + ((a - 1) * B + (b - 1)) % N], Val(B))) for
a in 1:div(N, B, RoundUp)
Expand Down
52 changes: 44 additions & 8 deletions DifferentiationInterface/src/utils/basis.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,43 @@
struct OneElement{I,N,T,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
ind::I
val::T
a::A

function OneElement(ind::Integer, val::T, a::A) where {N,T,A<:AbstractArray{T,N}}
right_ind = eachindex(a)[ind]
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
end

function OneElement(
ind::CartesianIndex{N}, val::T, a::A
) where {N,T,A<:AbstractArray{T,N}}
linear_ind = LinearIndices(a)[ind]
right_ind = eachindex(a)[linear_ind]
return new{typeof(right_ind),N,T,A}(right_ind, val, a)
end
end

Base.size(oe::OneElement) = size(oe.a)
Base.IndexStyle(oe::OneElement) = Base.IndexStyle(oe.a)

function Base.getindex(oe::OneElement{<:Integer}, ind::Integer)
if ind == oe.ind
return oe.val
else
return zero(eltype(oe.a))
end
end

function Base.getindex(oe::OneElement{<:CartesianIndex{N}}, ind::Vararg{Int,N}) where {N}
if ind == Tuple(oe.ind)
return oe.val
else
return zero(eltype(oe.a))
end
end

"""
basis(backend, a::AbstractArray, i::CartesianIndex)
basis(backend, a::AbstractArray, i)
Construct the `i`-th standard basis array in the vector space of `a` with element type `eltype(a)`.
Expand All @@ -11,7 +49,7 @@ this function can be extended on the backend type.
basis(::AbstractADType, a::AbstractArray, i) = basis(a, i)

"""
multibasis(backend, a::AbstractArray, inds::AbstractVector{<:CartesianIndex})
multibasis(backend, a::AbstractArray, inds::AbstractVector)
Construct the sum of the `i`-th standard basis arrays in the vector space of `a` with element type `eltype(a)`, for all `i ∈ inds`.
Expand All @@ -22,16 +60,14 @@ this function can be extended on the backend type.
"""
multibasis(::AbstractADType, a::AbstractArray, inds) = multibasis(a, inds)

function basis(a::AbstractArray{T,N}, i::CartesianIndex{N}) where {T,N}
return zero(a) + OneElement(one(T), Tuple(i), axes(a))
function basis(a::AbstractArray{T,N}, i) where {T,N}
return zero(a) + OneElement(i, one(T), a)
end

function multibasis(
a::AbstractArray{T,N}, inds::AbstractVector{<:CartesianIndex{N}}
) where {T,N}
function multibasis(a::AbstractArray{T,N}, inds::AbstractVector) where {T,N}
seed = zero(a)
for i in inds
seed += OneElement(one(T), Tuple(i), axes(a))
seed += OneElement(i, one(T), a)
end
return seed
end
41 changes: 41 additions & 0 deletions DifferentiationInterface/test/Internals/basis.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
using DifferentiationInterface: OneElement, basis
using LinearAlgebra
using StaticArrays, JLArrays
using Test

@testset "OneElement" begin
a = rand(Float32, 3)
i = eachindex(a)[2]
oe = OneElement(i, 2.0f0, a)
@test eltype(oe) == Float32
@test oe == [0, 2, 0]

a = rand(Float64, 3, 3)
i = eachindex(a)[4]
oe = OneElement(i, 2.0, a)
@test eltype(oe) == Float64
@test oe == [0 2 0; 0 0 0; 0 0 0]

a = Diagonal(ones(3))
i = eachindex(a)[4]
oe = OneElement(i, 2.0, a)
@test oe == [0 2 0; 0 0 0; 0 0 0]
end

@testset "Basis" begin
b_ref = [0, 1, 0]
@test basis(rand(3), 2) isa Vector
@test basis(rand(3), 2) == b_ref
@test basis(jl(rand(3)), 2) isa JLArray
@test all(basis(jl(rand(3)), 2) .== b_ref)
@test basis(@SVector(rand(3)), 2) isa SVector
@test basis(@SVector(rand(3)), 2) == b_ref

b_ref = [0 1 0; 0 0 0; 0 0 0]
@test basis(rand(3, 3), 4) isa Matrix
@test basis(rand(3, 3), 4) == b_ref
@test basis(jl(rand(3, 3)), 4) isa JLArray
@test all(basis(jl(rand(3, 3)), 4) .== b_ref)
@test basis(@SMatrix(rand(3, 3)), 4) isa SMatrix
@test basis(@SMatrix(rand(3, 3)), 4) == b_ref
end
8 changes: 3 additions & 5 deletions DifferentiationInterfaceTest/src/tests/correctness_eval.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
function test_scen_intact(new_scen, scen; isequal)
@testset "Scenario intact" begin
for n in fieldnames(typeof(scen))
n == :f && continue
@test isequal(getfield(new_scen, n), getfield(scen, n))
end
for n in fieldnames(typeof(scen))
n == :f && continue
@test isequal(getfield(new_scen, n), getfield(scen, n))
end
end

Expand Down

0 comments on commit b2dcdef

Please sign in to comment.