diff --git a/Project.toml b/Project.toml index 53946677..b2d575d6 100644 --- a/Project.toml +++ b/Project.toml @@ -7,17 +7,20 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [weakdeps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] StructArraysAdaptExt = "Adapt" StructArraysGPUArraysCoreExt = "GPUArraysCore" +StructArraysSparseArraysExt = "SparseArrays" StructArraysStaticArraysExt = "StaticArrays" [compat] @@ -34,6 +37,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -43,4 +47,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9" WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5" [targets] -test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "Adapt"] +test = ["Adapt", "Documenter", "GPUArraysCore", "JLArrays", "LinearAlgebra", "OffsetArrays", "PooledArrays", "SparseArrays", "StaticArrays", "Test", "TypedTables", "WeakRefStrings"] \ No newline at end of file diff --git a/ext/StructArraysSparseArraysExt.jl b/ext/StructArraysSparseArraysExt.jl new file mode 100644 index 00000000..f6dfd0bf --- /dev/null +++ b/ext/StructArraysSparseArraysExt.jl @@ -0,0 +1,13 @@ +module StructArraysSparseArraysExt + +using StructArrays: StructArray, components, createinstance +import SparseArrays: sparse, issparse + +function sparse(S::StructArray{T}) where {T} + sparse_components = map(sparse, components(S)) + return createinstance.(T, sparse_components...) +end + +issparse(S::StructArray) = all(issparse, components(S)) + +end diff --git a/src/StructArrays.jl b/src/StructArrays.jl index 4130d059..7c1b1dd0 100644 --- a/src/StructArrays.jl +++ b/src/StructArrays.jl @@ -28,6 +28,7 @@ end @static if !isdefined(Base, :get_extension) include("../ext/StructArraysAdaptExt.jl") include("../ext/StructArraysGPUArraysCoreExt.jl") + include("../ext/StructArraysSparseArraysExt.jl") include("../ext/StructArraysStaticArraysExt.jl") end diff --git a/test/runtests.jl b/test/runtests.jl index 611dfc65..78318c3b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ using DataAPI: refarray, refvalue using Adapt: adapt, Adapt using JLArrays using GPUArraysCore: backend +using LinearAlgebra using Test using SparseArrays @@ -613,9 +614,10 @@ end A = spzeros(3) B = spzeros(3) S = StructArray{Complex{eltype(A)}}((A,B)) - fill!(S, 0) - @test all(iszero, A) - @test all(iszero, B) + fill!(S, 2+3im) + @test all(==(2), A) + @test all(==(3), B) + @test issparse(S) end end @@ -1144,6 +1146,23 @@ end end end +@testset "sparse" begin + @testset "Vector" begin + v = [1,0,2] + sv = StructArray{Complex{Int}}((v, v)) + spv = @inferred sparse(sv) + @test spv isa SparseVector{eltype(sv)} + @test spv == sv + end + @testset "Matrix" begin + d = Diagonal(Float64[1:4;]) + sa = StructArray{ComplexF64}((d, d)) + sp = @inferred sparse(sa) + @test sp isa SparseMatrixCSC{eltype(sa)} + @test sp == sa + end +end + struct ArrayConverter end Adapt.adapt_storage(::ArrayConverter, xs::AbstractArray) = convert(Array, xs)