Skip to content

Commit

Permalink
Move StaticArrays support to extension (#265)
Browse files Browse the repository at this point in the history
* Use weakdeps on Julia v1.9

Update Project.toml

* move StructStaticArray broadcast to ext

* fix doctest

* move `Adapt` to ext

And use curried adapter to avoid possible instability

* Apply suggestions from code review

* Adopt code style suggestion.

* restrict Adapt compat

* Add empty bc test.

* define `__broadcast` ourselves

---------

Co-authored-by: Oliver Schulz <oschulz@mpp.mpg.de>
  • Loading branch information
N5N3 and oschulz authored Dec 6, 2023
1 parent 99f0556 commit 3b38c22
Show file tree
Hide file tree
Showing 11 changed files with 209 additions and 206 deletions.
19 changes: 15 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,32 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
StructArraysAdaptExt = "Adapt"
StructArraysGPUArraysCoreExt = "GPUArraysCore"
StructArraysStaticArraysExt = "StaticArrays"

[compat]
Adapt = "1, 2, 3"
Adapt = "3.4"
ConstructionBase = "1"
DataAPI = "1"
GPUArraysCore = "0.1.2"
StaticArrays = "1.5.6"
StaticArraysCore = "1.3"
Tables = "1"
julia = "1.6"

[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PooledArrays = "2dfb63ee-cc39-5dd5-95bd-886bf059d720"
Expand All @@ -32,4 +43,4 @@ TypedTables = "9d95f2ec-7b3d-5a63-8d20-e2491e220bb9"
WeakRefStrings = "ea10d353-3f73-51f8-a26c-33c1cb351aa5"

[targets]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays"]
test = ["Test", "JLArrays", "StaticArrays", "OffsetArrays", "PooledArrays", "TypedTables", "WeakRefStrings", "Documenter", "SparseArrays", "GPUArraysCore", "Adapt"]
97 changes: 32 additions & 65 deletions docs/src/advanced.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,89 +6,56 @@ StructArrays support structures with custom data layout. The user is required to

Here is an example of a type `MyType` that has as custom fields either its field `data` or fields of its field `rest` (which is a named tuple):

```jldoctest advanced1
julia> using StructArrays
```@repl advanced1
using StructArrays
julia> struct MyType{T, NT<:NamedTuple}
data::T
rest::NT
end
struct MyType{T, NT<:NamedTuple}
data::T
rest::NT
end
julia> MyType(x; kwargs...) = MyType(x, values(kwargs))
MyType
MyType(x; kwargs...) = MyType(x, values(kwargs))
```

Let's create a small array of these objects:

```jldoctest advanced1
julia> s = [MyType(i/5, a=6-i, b=2) for i in 1:5]
5-element Vector{MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}}:
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.2, (a = 5, b = 2))
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.4, (a = 4, b = 2))
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.6, (a = 3, b = 2))
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(0.8, (a = 2, b = 2))
MyType{Float64, NamedTuple{(:a, :b), Tuple{Int64, Int64}}}(1.0, (a = 1, b = 2))
```@repl advanced1
s = [MyType(i/5, a=6-i, b=2) for i in 1:5]
```

The default `StructArray` does not unpack the `NamedTuple`:

```jldoctest advanced1
julia> sa = StructArray(s);
julia> sa.rest
5-element Vector{NamedTuple{(:a, :b), Tuple{Int64, Int64}}}:
(a = 5, b = 2)
(a = 4, b = 2)
(a = 3, b = 2)
(a = 2, b = 2)
(a = 1, b = 2)
julia> sa.a
ERROR: type NamedTuple has no field a
Stacktrace:
[1] component
[...]
```@repl advanced1
sa = StructArray(s);
sa.rest
sa.a
```

Suppose we wish to give the keywords their own fields. We can define custom `staticschema`, `component`, and `createinstance` methods for `MyType`:

```jldoctest advanced1
julia> function StructArrays.staticschema(::Type{MyType{T, NamedTuple{names, types}}}) where {T, names, types}
# Define the desired names and eltypes of the "fields"
return NamedTuple{(:data, names...), Base.tuple_type_cons(T, types)}
end;
julia> function StructArrays.component(m::MyType, key::Symbol)
# Define a component-extractor
return key === :data ? getfield(m, 1) : getfield(getfield(m, 2), key)
end;
julia> function StructArrays.createinstance(::Type{MyType{T, NT}}, x, args...) where {T, NT}
# Generate an instance of MyType from components
return MyType(x, NT(args))
end;
```@repl advanced1
function StructArrays.staticschema(::Type{MyType{T, NamedTuple{names, types}}}) where {T, names, types}
# Define the desired names and eltypes of the "fields"
return NamedTuple{(:data, names...), Base.tuple_type_cons(T, types)}
end;
function StructArrays.component(m::MyType, key::Symbol)
# Define a component-extractor
return key === :data ? getfield(m, 1) : getfield(getfield(m, 2), key)
end;
function StructArrays.createinstance(::Type{MyType{T, NT}}, x, args...) where {T, NT}
# Generate an instance of MyType from components
return MyType(x, NT(args))
end;
```

and now:

```jldoctest advanced1
julia> sa = StructArray(s);
julia> sa.a
5-element Vector{Int64}:
5
4
3
2
1
julia> sa.b
5-element Vector{Int64}:
2
2
2
2
2
```@repl advanced1
sa = StructArray(s);
sa.a
sa.b
```

The above strategy has been tested and implemented in [GeometryBasics.jl](https://github.com/JuliaGeometry/GeometryBasics.jl).
Expand Down
70 changes: 18 additions & 52 deletions docs/src/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,73 +9,39 @@ The package was largely inspired by the `Columns` type in [IndexedTables](https:
## Collection and initialization

One can create a `StructArray` by providing the struct type and a tuple or NamedTuple of field arrays:
```jldoctest intro
julia> using StructArrays
julia> struct Foo{T}
a::T
b::T
end
julia> adata = [1 2; 3 4]; bdata = [10 20; 30 40];
julia> x = StructArray{Foo}((adata, bdata))
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype Foo:
Foo{Int64}(1, 10) Foo{Int64}(2, 20)
Foo{Int64}(3, 30) Foo{Int64}(4, 40)
```@repl intro
using StructArrays
struct Foo{T}
a::T
b::T
end
adata = [1 2; 3 4]; bdata = [10 20; 30 40];
x = StructArray{Foo}((adata, bdata))
```

You can also initialze a StructArray by passing in a NamedTuple, in which case the name (rather than the order) specifies how the input arrays are assigned to fields:

```jldoctest intro
julia> x = StructArray{Foo}((b = adata, a = bdata)) # initialize a with bdata and vice versa
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype Foo:
Foo{Int64}(10, 1) Foo{Int64}(20, 2)
Foo{Int64}(30, 3) Foo{Int64}(40, 4)
```@repl intro
x = StructArray{Foo}((b = adata, a = bdata)) # initialize a with bdata and vice versa
```

If a struct is not specified, a StructArray with Tuple or NamedTuple elements will be created:
```jldoctest intro
julia> x = StructArray((adata, bdata))
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype Tuple{Int64, Int64}:
(1, 10) (2, 20)
(3, 30) (4, 40)
julia> x = StructArray((a = adata, b = bdata))
2×2 StructArray(::Matrix{Int64}, ::Matrix{Int64}) with eltype NamedTuple{(:a, :b), Tuple{Int64, Int64}}:
(a = 1, b = 10) (a = 2, b = 20)
(a = 3, b = 30) (a = 4, b = 40)
```@repl intro
x = StructArray((adata, bdata))
x = StructArray((a = adata, b = bdata))
```

It's also possible to create a `StructArray` by choosing a particular dimension to interpret as the components of a struct:

```jldoctest intro
julia> x = StructArray{Complex{Int}}(adata; dims=1) # along dimension 1, the first item `re` and the second is `im`
2-element StructArray(view(::Matrix{Int64}, 1, :), view(::Matrix{Int64}, 2, :)) with eltype Complex{Int64}:
1 + 3im
2 + 4im
julia> x = StructArray{Complex{Int}}(adata; dims=2) # along dimension 2, the first item `re` and the second is `im`
2-element StructArray(view(::Matrix{Int64}, :, 1), view(::Matrix{Int64}, :, 2)) with eltype Complex{Int64}:
1 + 2im
3 + 4im
```@repl intro
x = StructArray{Complex{Int}}(adata; dims=1) # along dimension 1, the first item `re` and the second is `im`
x = StructArray{Complex{Int}}(adata; dims=2) # along dimension 2, the first item `re` and the second is `im`
```

One can also create a `StructArray` from an iterable of structs without creating an intermediate `Array`:

```jldoctest intro
julia> StructArray(log(j+2.0*im) for j in 1:10)
10-element StructArray(::Vector{Float64}, ::Vector{Float64}) with eltype ComplexF64:
0.8047189562170501 + 1.1071487177940904im
1.0397207708399179 + 0.7853981633974483im
1.2824746787307684 + 0.5880026035475675im
1.4978661367769954 + 0.4636476090008061im
1.683647914993237 + 0.3805063771123649im
1.8444397270569681 + 0.3217505543966422im
1.985145956776061 + 0.27829965900511133im
2.1097538525880535 + 0.24497866312686414im
2.2213256282451583 + 0.21866894587394195im
2.3221954495706862 + 0.19739555984988078im
```@repl intro
StructArray(log(j+2.0*im) for j in 1:10)
```

Another option is to create an uninitialized `StructArray` and then fill it with data. Just like in normal arrays, this is done with the `undef` syntax:
Expand Down
5 changes: 5 additions & 0 deletions ext/StructArraysAdaptExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module StructArraysAdaptExt
# Use Adapt allows for automatic conversion of CPU to GPU StructArrays
using Adapt, StructArrays
Adapt.adapt_structure(to, s::StructArray) = replace_storage(adapt(to), s)
end
21 changes: 21 additions & 0 deletions ext/StructArraysGPUArraysCoreExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module StructArraysGPUArraysCoreExt

using StructArrays
using StructArrays: map_params, array_types

using Base: tail

import GPUArraysCore

# for GPU broadcast
import GPUArraysCore
function GPUArraysCore.backend(::Type{T}) where {T<:StructArray}
backends = map_params(GPUArraysCore.backend, array_types(T))
backend, others = backends[1], tail(backends)
isconsistent = mapfoldl(isequal(backend), &, others; init=true)
isconsistent || throw(ArgumentError("all component arrays must have the same GPU backend"))
return backend
end
StructArrays.always_struct_broadcast(::GPUArraysCore.AbstractGPUArrayStyle) = true

end # module
107 changes: 107 additions & 0 deletions ext/StructArraysStaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
module StructArraysStaticArraysExt

using StructArrays
using StaticArrays: StaticArray, FieldArray, tuple_prod

"""
StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
The `staticschema` of a `StaticArray` element type is the `staticschema` of the underlying `Tuple`.
```julia
julia> StructArrays.staticschema(SVector{2, Float64})
Tuple{Float64, Float64}
```
The one exception to this rule is `<:StaticArrays.FieldArray`, since `FieldArray` is based on a
struct. In this case, `staticschema(<:FieldArray)` returns the `staticschema` for the struct
which subtypes `FieldArray`.
"""
@generated function StructArrays.staticschema(::Type{<:StaticArray{S, T}}) where {S, T}
return quote
Base.@_inline_meta
return NTuple{$(tuple_prod(S)), T}
end
end
StructArrays.createinstance(::Type{T}, args...) where {T<:StaticArray} = T(args)
StructArrays.component(s::StaticArray, i) = getindex(s, i)

# invoke general fallbacks for a `FieldArray` type.
@inline function StructArrays.staticschema(T::Type{<:FieldArray})
invoke(StructArrays.staticschema, Tuple{Type{<:Any}}, T)
end
StructArrays.component(s::FieldArray, i) = invoke(StructArrays.component, Tuple{Any, Any}, s, i)
StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArrays.createinstance, Tuple{Type{<:Any}, Vararg}, T, args...)

# Broadcast overload
using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo
using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype
using StructArrays: isnonemptystructtype
using Base.Broadcast: Broadcasted, _broadcast_getindex

# StaticArrayStyle has no similar defined.
# Overload `try_struct_copy` instead.
@inline function StructArrays.try_struct_copy(bc::Broadcasted{StaticArrayStyle{M}}) where {M}
flat = broadcast_flatten(bc); as = flat.args; f = flat.f
argsizes = broadcast_sizes(as...)
ax = axes(bc)
ax isa Tuple{Vararg{SOneTo}} || error("Dimension is not static. Please file a bug at `StaticArrays.jl`.")
return _broadcast(f, Size(map(length, ax)), argsizes, as...)
end

# A functor generates the ith component of StructStaticBroadcast.
struct Similar_ith{SA, E<:Tuple}
elements::E
Similar_ith{SA}(elements::Tuple) where {SA} = new{SA, typeof(elements)}(elements)
end
function (s::Similar_ith{SA})(i::Int) where {SA}
ith_elements = ntuple(Val(length(s.elements))) do j
getfield(s.elements[j], i)
end
ith_SA = similar_type(SA, fieldtype(eltype(SA), i))
return @inbounds ith_SA(ith_elements)
end

@inline function _broadcast(f, sz::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where {newsize}
first_staticarray = first_statictype(a...)
elements, ET = if prod(newsize) == 0
# Use inference to get eltype in empty case (following StaticBroadcast defined in StaticArrays.jl)
eltys = Tuple{map(eltype, a)...}
(), Core.Compiler.return_type(f, eltys)
else
temp = __broadcast(f, sz, s, a...)
temp, eltype(temp)
end
if isnonemptystructtype(ET)
SA = similar_type(first_staticarray, ET, sz)
arrs = ntuple(Similar_ith{SA}(elements), Val(fieldcount(ET)))
return StructArray{ET}(arrs)
else
@inbounds return similar_type(first_staticarray, ET, sz)(elements)
end
end

# The `__broadcast` kernal is copied from `StaticArrays.jl`.
# see https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/broadcast.jl
@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
sizes = [sz.parameters[1] for sz s.parameters]

indices = CartesianIndices(newsize)
exprs = similar(indices, Expr)
for (j, current_ind) enumerate(indices)
exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
exprs[j] = :(f($(exprs_vals...)))
end

return quote
Base.@_inline_meta
return tuple($(exprs...))
end
end

broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I))
function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex)
li = LinearIndices(oldsize)
ind = _broadcast_getindex(li, newindex)
return :(a[$i][$ind])
end

end
Loading

0 comments on commit 3b38c22

Please sign in to comment.