Skip to content

Commit

Permalink
Weird array test scenarios in DIT extensions (#359)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Jul 16, 2024
1 parent 1299a5e commit 0db3be8
Show file tree
Hide file tree
Showing 16 changed files with 228 additions and 141 deletions.
6 changes: 6 additions & 0 deletions DifferentiationInterface/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ julia = "1.6"
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
# DifferentiationInterfaceTest = "a82114a7-5aa3-49a8-9643-716bb13727a3"
Diffractor = "9f5e2b26-1114-432f-b630-d3fe2085c51c"
Expand All @@ -81,6 +82,7 @@ FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PolyesterForwardDiff = "98d1487c-24ca-40b6-b7ab-df2af84e126b"
Expand All @@ -89,6 +91,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand All @@ -99,14 +102,17 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
test = [
"ADTypes",
"Aqua",
"ComponentArrays",
"DataFrames",
# "DifferentiationInterfaceTest",
"JET",
"JLArrays",
"JuliaFormatter",
"Pkg",
"SparseArrays",
"SparseConnectivityTracer",
"SparseMatrixColorings",
"StableRNGs",
"StaticArrays",
"Test",
]
2 changes: 2 additions & 0 deletions DifferentiationInterface/test/Back/ForwardDiff/test.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using ComponentArrays: ComponentArrays
using DifferentiationInterface, DifferentiationInterfaceTest
using DifferentiationInterfaceTest: add_batchified!
using ForwardDiff: ForwardDiff
using SparseConnectivityTracer, SparseMatrixColorings
using StaticArrays: StaticArrays
using Test

dense_backends = [AutoForwardDiff(), AutoForwardDiff(; chunksize=5, tag=:hello)]
Expand Down
3 changes: 3 additions & 0 deletions DifferentiationInterface/test/Back/Zygote/test.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
using ComponentArrays: ComponentArrays
using DifferentiationInterface, DifferentiationInterfaceTest
using JLArrays: JLArrays
using SparseConnectivityTracer, SparseMatrixColorings
using StaticArrays: StaticArrays
using Test
using Zygote: Zygote

Expand Down
38 changes: 33 additions & 5 deletions DifferentiationInterfaceTest/Project.toml
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
name = "DifferentiationInterfaceTest"
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
authors = ["Guillaume Dalle", "Adrian Hill"]
version = "0.5.0"
version = "0.6.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Chairmarks = "0ca39b1e-fe0b-4e98-acfc-b1656634c4de"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[weakdeps]
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
DifferentiationInterfaceTestComponentArraysExt = "ComponentArrays"
DifferentiationInterfaceTestJLArraysExt = "JLArrays"
DifferentiationInterfaceTestStaticArraysExt = "StaticArrays"

[compat]
ADTypes = "1.0.0"
Chairmarks = "1.2.1"
Expand All @@ -42,17 +50,37 @@ julia = "1.6"
[extras]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
SparseMatrixColorings = "0a514795-09f3-496d-8182-132a7b665d35"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ADTypes", "Aqua", "DataFrames", "DifferentiationInterface", "ForwardDiff", "JET", "JuliaFormatter", "Pkg", "SparseArrays", "SparseConnectivityTracer", "SparseMatrixColorings", "Test", "Zygote"]
test = [
"ADTypes",
"Aqua",
"ComponentArrays",
"DataFrames",
"DifferentiationInterface",
"ForwardDiff",
"JET",
"JLArrays",
"JuliaFormatter",
"Pkg",
"SparseArrays",
"SparseConnectivityTracer",
"SparseMatrixColorings",
"StaticArrays",
"Test",
"Zygote",
]
2 changes: 1 addition & 1 deletion DifferentiationInterfaceTest/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Make it easy to know, for a given function:
- Type stability tests
- Count calls to the function
- Benchmark runtime and allocations
- Weird array types (GPU, static, components)
- Scenarios with weird array types (GPU, static, components) in package extensions

## Installation

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
module DifferentiationInterfaceTestComponentArraysExt

using ComponentArrays: ComponentVector
using DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using LinearAlgebra: dot
using Random: AbstractRNG, default_rng

## Vector to scalar

function comp_to_num(x::ComponentVector)::Number
Expand Down Expand Up @@ -42,12 +50,7 @@ end

## Gather

"""
component_scenarios(rng=Random.default_rng())
Create a vector of [`Scenario`](@ref)s with component array types from [ComponentArrays.jl](https://github.com/jonniedie/ComponentArrays.jl).
"""
function component_scenarios(rng::AbstractRNG=default_rng())
function DIT.component_scenarios(rng::AbstractRNG=default_rng())
dy_ = rand(rng)

x_comp = ComponentVector(; a=randn(rng, 4), b=randn(rng, 2))
Expand All @@ -60,3 +63,5 @@ function component_scenarios(rng::AbstractRNG=default_rng())
)
return scens
end

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
module DifferentiationInterfaceTestJLArraysExt

using DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using JLArrays: JLArray, jl
using Random: AbstractRNG, default_rng

num_to_arr_jlvector(x) = DIT.num_to_arr(x, JLArray{Float64,1})
num_to_arr_jlmatrix(x) = DIT.num_to_arr(x, JLArray{Float64,2})

DIT.pick_num_to_arr(::Type{<:JLArray{<:Real,1}}) = num_to_arr_jlvector
DIT.pick_num_to_arr(::Type{<:JLArray{<:Real,2}}) = num_to_arr_jlmatrix

function DIT.gpu_scenarios(rng::AbstractRNG=default_rng(); linalg=true)
x_ = rand(rng)
dx_ = rand(rng)
dy_ = rand(rng)

x_6 = jl(rand(rng, 6))
dx_6 = jl(rand(rng, 6))

x_2_3 = jl(rand(rng, 2, 3))
dx_2_3 = jl(rand(rng, 2, 3))

dy_12 = jl(rand(rng, 12))
dy_6_2 = jl(rand(rng, 6, 2))
dy_6 = jl(rand(rng, 6))
dy_2_3 = jl(rand(rng, 2, 3))

V = typeof(dy_6)
M = typeof(dy_2_3)

scens = vcat(
# one argument
DIT.num_to_num_scenarios_onearg(x_; dx=dx_, dy=dy_),
DIT.num_to_arr_scenarios_onearg(x_, V; dx=dx_, dy=dy_6),
DIT.num_to_arr_scenarios_onearg(x_, M; dx=dx_, dy=dy_2_3),
DIT.arr_to_num_scenarios_onearg(x_6; dx=dx_6, dy=dy_, linalg),
DIT.arr_to_num_scenarios_onearg(x_2_3; dx=dx_2_3, dy=dy_, linalg),
DIT.vec_to_vec_scenarios_onearg(x_6; dx=dx_6, dy=dy_12),
DIT.vec_to_mat_scenarios_onearg(x_6; dx=dx_6, dy=dy_6_2),
DIT.mat_to_vec_scenarios_onearg(x_2_3; dx=dx_2_3, dy=dy_12),
DIT.mat_to_mat_scenarios_onearg(x_2_3; dx=dx_2_3, dy=dy_6_2),
# two arguments
DIT.num_to_arr_scenarios_twoarg(x_, V; dx=dx_, dy=dy_6),
DIT.num_to_arr_scenarios_twoarg(x_, M; dx=dx_, dy=dy_2_3),
DIT.vec_to_vec_scenarios_twoarg(x_6; dx=dx_6, dy=dy_12),
DIT.vec_to_mat_scenarios_twoarg(x_6; dx=dx_6, dy=dy_6_2),
DIT.mat_to_vec_scenarios_twoarg(x_2_3; dx=dx_2_3, dy=dy_12),
DIT.mat_to_mat_scenarios_twoarg(x_2_3; dx=dx_2_3, dy=dy_6_2),
)
return scens
end

end
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
module DifferentiationInterfaceTestStaticArraysExt

using DifferentiationInterfaceTest
import DifferentiationInterfaceTest as DIT
using Random: AbstractRNG, default_rng
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector

num_to_arr_svector(x) = DIT.num_to_arr(x, SVector{6,Float64})
num_to_arr_smatrix(x) = DIT.num_to_arr(x, SMatrix{2,3,Float64,6})

DIT.pick_num_to_arr(::Type{<:SVector}) = num_to_arr_svector
DIT.pick_num_to_arr(::Type{<:SMatrix}) = num_to_arr_smatrix

function DIT.static_scenarios(rng::AbstractRNG=default_rng(); linalg=true)
x_ = rand(rng)
dx_ = rand(rng)
dy_ = rand(rng)

x_6 = rand(rng, 6)
dx_6 = rand(rng, 6)

x_2_3 = rand(rng, 2, 3)
dx_2_3 = rand(rng, 2, 3)

dy_6 = rand(rng, 6)
dy_12 = rand(rng, 12)
dy_2_3 = rand(rng, 2, 3)
dy_6_2 = rand(rng, 6, 2)

SV_6 = SVector{6}
MV_6 = MVector{6}
SV_12 = SVector{12}
MV_12 = MVector{12}

SM_2_3 = SMatrix{2,3}
MM_2_3 = MMatrix{2,3}
SM_6_2 = SMatrix{6,2}
MM_6_2 = MMatrix{6,2}

scens = vcat(
# one argument
DIT.num_to_arr_scenarios_onearg(x_, SV_6; dx=dx_, dy=SV_6(dy_6)),
DIT.num_to_arr_scenarios_onearg(x_, SM_2_3; dx=dx_, dy=SM_2_3(dy_2_3)),
DIT.arr_to_num_scenarios_onearg(SV_6(x_6); dx=SV_6(dx_6), dy=dy_, linalg),
DIT.arr_to_num_scenarios_onearg(SM_2_3(x_2_3); dx=SM_2_3(dx_2_3), dy=dy_, linalg),
DIT.vec_to_vec_scenarios_onearg(SV_6(x_6); dx=SV_6(dx_6), dy=SV_12(dy_12)),
DIT.vec_to_mat_scenarios_onearg(SV_6(x_6); dx=SV_6(dx_6), dy=SM_6_2(dy_6_2)),
DIT.mat_to_vec_scenarios_onearg(SM_2_3(x_2_3); dx=SM_2_3(dx_2_3), dy=SV_12(dy_12)),
DIT.mat_to_mat_scenarios_onearg(
SM_2_3(x_2_3); dx=SM_2_3(dx_2_3), dy=SM_6_2(dy_6_2)
),
# two arguments
DIT.num_to_arr_scenarios_twoarg(x_, MV_6; dx=dx_, dy=MV_6(dy_6)),
DIT.num_to_arr_scenarios_twoarg(x_, MM_2_3; dx=dx_, dy=MM_2_3(dy_2_3)),
DIT.vec_to_vec_scenarios_twoarg(MV_6(x_6); dx=MV_6(dx_6), dy=MV_12(dy_12)),
DIT.vec_to_mat_scenarios_twoarg(MV_6(x_6); dx=MV_6(dx_6), dy=MM_6_2(dy_6_2)),
DIT.mat_to_vec_scenarios_twoarg(MM_2_3(x_2_3); dx=MM_2_3(dx_2_3), dy=MV_12(dy_12)),
DIT.mat_to_mat_scenarios_twoarg(
MM_2_3(x_2_3); dx=MM_2_3(dx_2_3), dy=MM_6_2(dy_6_2)
),
)
scens = filter(scens) do s
DIT.place(s) == :outofplace || s.x isa Union{Number,MArray}
end
return scens
end

end
17 changes: 10 additions & 7 deletions DifferentiationInterfaceTest/src/DifferentiationInterfaceTest.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ using ADTypes:
SymbolicMode
using Chairmarks: @be, Benchmark, Sample
using Compat
using ComponentArrays: ComponentVector
using DataFrames: DataFrame
using DifferentiationInterface
using DifferentiationInterface:
Expand Down Expand Up @@ -63,22 +62,19 @@ using DifferentiationInterface:
using DocStringExtensions
import DifferentiationInterface as DI
using JET: JET
using JLArrays: JLArray, jl
using LinearAlgebra: Adjoint, Diagonal, Transpose, dot, parent
using PackageExtensionCompat: @require_extensions
using ProgressMeter: ProgressUnknown, next!
using Random: AbstractRNG, default_rng, rand!
using SparseArrays: SparseArrays, SparseMatrixCSC, nnz, spdiagm
using StaticArrays: MArray, MMatrix, MVector, SArray, SMatrix, SVector
using Test: @testset, @test

include("scenarios/scenario.jl")
include("scenarios/batchify.jl")
include("scenarios/default.jl")
include("scenarios/sparse.jl")
include("scenarios/static.jl")
include("scenarios/component.jl")
include("scenarios/gpu.jl")
include("scenarios/allocfree.jl")
include("scenarios/extensions.jl")

include("utils/zero_backends.jl")
include("utils/misc.jl")
Expand All @@ -92,6 +88,10 @@ include("tests/sparsity.jl")
include("tests/benchmark.jl")
include("test_differentiation.jl")

function __init__()
@require_extensions
end

export Scenario
export PushforwardScenario,
PullbackScenario,
Expand All @@ -102,8 +102,11 @@ export PushforwardScenario,
HVPScenario,
HessianScenario
export default_scenarios, sparse_scenarios
export static_scenarios, component_scenarios, gpu_scenarios
export test_differentiation, benchmark_differentiation
export DifferentiationBenchmarkDataRow
# extensions
export static_scenarios
export component_scenarios
export gpu_scenarios

end
26 changes: 3 additions & 23 deletions DifferentiationInterfaceTest/src/scenarios/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,30 +47,10 @@ function num_to_arr(x::Number, ::Type{A}) where {A<:AbstractArray}
end

num_to_arr_vector(x) = num_to_arr(x, Vector{Float64})
num_to_arr_svector(x) = num_to_arr(x, SVector{6,Float64})
num_to_arr_jlvector(x) = num_to_arr(x, JLArray{Float64,1})

num_to_arr_matrix(x) = num_to_arr(x, Matrix{Float64})
num_to_arr_smatrix(x) = num_to_arr(x, SMatrix{2,3,Float64,6})
num_to_arr_jlmatrix(x) = num_to_arr(x, JLArray{Float64,2})

function pick_num_to_arr(::Type{A}) where {A<:AbstractArray}
if A <: Vector
return num_to_arr_vector
elseif A <: SVector
return num_to_arr_svector
elseif A <: JLArray{<:Any,1}
return num_to_arr_jlvector
elseif A <: Matrix
return num_to_arr_matrix
elseif A <: SMatrix
return num_to_arr_smatrix
elseif A <: JLArray{<:Any,2}
return num_to_arr_jlmatrix
else
throw(ArgumentError("Array type $A not supported"))
end
end

pick_num_to_arr(::Type{<:Vector}) = num_to_arr_vector
pick_num_to_arr(::Type{<:Matrix}) = num_to_arr_matrix

function num_to_arr!(y::AbstractArray, x::Number)::Nothing
a = multiplicator(typeof(y))
Expand Down
Loading

0 comments on commit 0db3be8

Please sign in to comment.