Skip to content

Commit

Permalink
improve conversion to sparse arrays (#289)
Browse files Browse the repository at this point in the history
* sparse conversion

* map sparse over components

* move tests up

* Move to package extension and extend issparse

* formatting

* use createinstance

---------

Co-authored-by: Jishnu Bhattacharya <jishnub.github@gmail.com>
  • Loading branch information
piever and jishnub authored Dec 14, 2023
1 parent 15b044b commit 8b665db
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 4 deletions.
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,20 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
StructArraysAdaptExt = "Adapt"
StructArraysGPUArraysCoreExt = "GPUArraysCore"
StructArraysSparseArraysExt = "SparseArrays"
StructArraysStaticArraysExt = "StaticArrays"

[compat]
Expand All @@ -34,6 +37,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Expand All @@ -43,4 +47,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"

[targets]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "Adapt"]
test = ["Adapt", "Documenter", "GPUArraysCore", "JLArrays", "LinearAlgebra", "OffsetArrays", "PooledArrays", "SparseArrays", "StaticArrays", "Test", "TypedTables", "WeakRefStrings"]
13 changes: 13 additions & 0 deletions ext/StructArraysSparseArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
module StructArraysSparseArraysExt

using StructArrays: StructArray, components, createinstance
import SparseArrays: sparse, issparse

function sparse(S::StructArray{T}) where {T}
sparse_components = map(sparse, components(S))
return createinstance.(T, sparse_components...)
end

issparse(S::StructArray) = all(issparse, components(S))

end
1 change: 1 addition & 0 deletions src/StructArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ end
@static if !isdefined(Base, :get_extension)
include("../ext/StructArraysAdaptExt.jl")
include("../ext/StructArraysGPUArraysCoreExt.jl")
include("../ext/StructArraysSparseArraysExt.jl")
include("../ext/StructArraysStaticArraysExt.jl")
end

Expand Down
25 changes: 22 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using DataAPI: refarray, refvalue
using Adapt: adapt, Adapt
using JLArrays
using GPUArraysCore: backend
using LinearAlgebra
using Test
using SparseArrays

Expand Down Expand Up @@ -613,9 +614,10 @@ end
A = spzeros(3)
B = spzeros(3)
S = StructArray{Complex{eltype(A)}}((A,B))
fill!(S, 0)
@test all(iszero, A)
@test all(iszero, B)
fill!(S, 2+3im)
@test all(==(2), A)
@test all(==(3), B)
@test issparse(S)
end
end

Expand Down Expand Up @@ -1144,6 +1146,23 @@ end
end
end

@testset "sparse" begin
@testset "Vector" begin
v = [1,0,2]
sv = StructArray{Complex{Int}}((v, v))
spv = @inferred sparse(sv)
@test spv isa SparseVector{eltype(sv)}
@test spv == sv
end
@testset "Matrix" begin
d = Diagonal(Float64[1:4;])
sa = StructArray{ComplexF64}((d, d))
sp = @inferred sparse(sa)
@test sp isa SparseMatrixCSC{eltype(sa)}
@test sp == sa
end
end

struct ArrayConverter end

Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)
Expand Down

0 comments on commit 8b665db

Please sign in to comment.