Skip to content

Commit

Permalink
Update to the latest broadcast implement.
Browse files Browse the repository at this point in the history
On master `Broadcasted` store style by field. Update accordingly.
  • Loading branch information
N5N3 committed Oct 26, 2023
1 parent 99f0556 commit 5c0ae3d
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 38 deletions.
84 changes: 49 additions & 35 deletions src/structarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -497,33 +497,53 @@ end
import Base.Broadcast: BroadcastStyle, AbstractArrayStyle, Broadcasted, DefaultArrayStyle, Unknown, ArrayConflict
using Base.Broadcast: combine_styles

struct StructArrayStyle{S, N} <: AbstractArrayStyle{N} end
@static if fieldcount(Base.Broadcast.Broadcasted) == 4
struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
style::S
StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}(style)

Check warning on line 503 in src/structarray.jl

View check run for this annotation

Codecov / codecov/patch

src/structarray.jl#L503

Added line #L503 was not covered by tests
end
StructArrayStyle{N}(style::StructArrayStyle) where {N} = StructArrayStyle{N}(style.style)
parent_style(s::BroadcastStyle) = s
parent_style(s::StructArrayStyle) = s.style
style(bc::Broadcasted) = bc.style

Check warning on line 508 in src/structarray.jl

View check run for this annotation

Codecov / codecov/patch

src/structarray.jl#L505-L508

Added lines #L505 - L508 were not covered by tests
const broadcasted = Broadcasted
else
struct StructArrayStyle{N, S} <: AbstractArrayStyle{N}
StructArrayStyle{N}(style) where {N} = new{N, typeof(style)}()
end
StructArrayStyle{N}(style::StructArrayStyle{M, S}) where {N, M, S} = StructArrayStyle{N}(S())
parent_style(s::BroadcastStyle) = s
parent_style(::StructArrayStyle{N, S}) where {N, S} = S()
style(::Broadcasted{Style}) where {Style} = Style()
broadcasted(s, f, args, axes) = Broadcasted{typeof(s)}(f, args, axes)
end
StructArrayStyle{N, S}() where {N, S} = StructArrayStyle{N}(S())
parent_style(bc::Broadcasted) = parent_style(style(bc))
ofstyle(s, bc::Broadcasted) = broadcasted(s, bc.f, bc.args, bc.axes)

# Here we define the dimension tracking behavior of StructArrayStyle
function StructArrayStyle{S, M}(::Val{N}) where {S, M, N}
function StructArrayStyle{M, S}(::Val{N}) where {S, M, N}

Check warning on line 525 in src/structarray.jl

View check run for this annotation

Codecov / codecov/patch

src/structarray.jl#L525

Added line #L525 was not covered by tests
T = S <: AbstractArrayStyle{M} ? typeof(S(Val{N}())) : S
return StructArrayStyle{T, N}()
return StructArrayStyle{N, T}()

Check warning on line 527 in src/structarray.jl

View check run for this annotation

Codecov / codecov/patch

src/structarray.jl#L527

Added line #L527 was not covered by tests
end

# StructArrayStyle is a wrapped style.
# Here we try our best to resolve style conflict.
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{S, N}) where {S, N, M}
function BroadcastStyle(b::AbstractArrayStyle{M}, a::StructArrayStyle{N, S}) where {S, N, M}
N′ = M === Any || N === Any ? Any : max(M, N)
S′ = Broadcast.result_style(S(), b)
return S′ isa StructArrayStyle ? typeof(S′)(Val{N′}()) : StructArrayStyle{typeof(S′), N′}()
return StructArrayStyle{N′}(Broadcast.result_style(parent_style(a), b))
end
BroadcastStyle(::StructArrayStyle, ::DefaultArrayStyle) = Unknown()

@inline combine_style_types(::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(BroadcastStyle(A), args...)
@inline combine_style_types(s::BroadcastStyle, ::Type{A}, args...) where {A<:AbstractArray} =
combine_style_types(Broadcast.result_style(s, BroadcastStyle(A)), args...)
combine_style_types(::StructArrayStyle{S}) where {S} = S() # avoid nested StructArrayStyle
combine_style_types(s::BroadcastStyle) = s

Base.@pure cst(::Type{SA}) where {SA} = combine_style_types(array_types(SA).parameters...)

BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{typeof(cst(SA)), ndims(SA)}()
BroadcastStyle(::Type{SA}) where {SA<:StructArray} = StructArrayStyle{ndims(SA)}(cst(SA))

"""
always_struct_broadcast(style::BroadcastStyle)
Expand Down Expand Up @@ -551,8 +571,8 @@ See also [`always_struct_broadcast`](@ref).
"""
try_struct_copy(bc::Broadcasted) = copy(bc)

function Base.copy(bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
if always_struct_broadcast(S())
function Base.copy(bc::Broadcasted{<:StructArrayStyle})
if always_struct_broadcast(parent_style(bc))
return invoke(copy, Tuple{Broadcasted}, bc)
else
return try_struct_copy(replace_structarray(bc))
Expand All @@ -567,55 +587,49 @@ an equivalent one without it. This is not a must if the root `BroadcastStyle`
supports `AbstractArray`. But some `BroadcastStyle` limits the input array types,
e.g. `StaticArrayStyle`, thus we have to omit all `StructArray`.
"""
function replace_structarray(bc::Broadcasted{Style}) where {Style}
function replace_structarray(bc::Broadcasted)
args = replace_structarray_args(bc.args)
Style′ = parent_style(Style())
return Broadcasted{Style′}(bc.f, args, bc.axes)
style = parent_style(bc)
return broadcasted(style, bc.f, args, bc.axes)
end
function replace_structarray(A::StructArray)
f = Instantiator(eltype(A))
args = Tuple(components(A))
Style = typeof(combine_styles(args...))
return Broadcasted{Style}(f, args, axes(A))
style = combine_styles(args...)
return broadcasted(style, f, args, axes(A))
end
replace_structarray(@nospecialize(A)) = A

replace_structarray_args(args::Tuple) = (replace_structarray(args[1]), replace_structarray_args(tail(args))...)
replace_structarray_args(::Tuple{}) = ()

parent_style(@nospecialize(x)) = typeof(x)
parent_style(::StructArrayStyle{S, N}) where {S, N} = S
parent_style(::StructArrayStyle{S, N}) where {N, S<:AbstractArrayStyle{N}} = S
parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle{Any}, N} = S
parent_style(::StructArrayStyle{S, N}) where {S<:AbstractArrayStyle, N} = typeof(S(Val(N)))

# `instantiate` and `_axes` might be overloaded for static axes.
function Broadcast.instantiate(bc::Broadcasted{Style}) where {Style <: StructArrayStyle}
Style′ = parent_style(Style())
bc′ = Broadcast.instantiate(convert(Broadcasted{Style′}, bc))
return convert(Broadcasted{Style}, bc′)
function Broadcast.instantiate(bc::Broadcasted{<:StructArrayStyle})
bc′ = Broadcast.instantiate(ofstyle(parent_style(bc), bc))
return ofstyle(style(bc), bc′)
end

function Broadcast._axes(bc::Broadcasted{Style}, ::Nothing) where {Style <: StructArrayStyle}
Style′ = parent_style(Style())
return Broadcast._axes(convert(Broadcasted{Style′}, bc), nothing)
function Broadcast._axes(bc::Broadcasted{<:StructArrayStyle}, ::Nothing)
return Broadcast._axes(ofstyle(parent_style(bc), bc), nothing)
end

# Here we use `similar` defined for `S` to build the dest Array.
function Base.similar(bc::Broadcasted{StructArrayStyle{S, N}}, ::Type{ElType}) where {S, N, ElType}
bc′ = convert(Broadcasted{S}, bc)
function Base.similar(bc::Broadcasted{<:StructArrayStyle}, ::Type{ElType}) where {ElType}
bc′ = ofstyle(parent_style(bc), bc)
return isnonemptystructtype(ElType) ? buildfromschema(T -> similar(bc′, T), ElType) : similar(bc′, ElType)
end

# Unwrapper to recover the behaviour defined by parent style.
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{StructArrayStyle{S, N}}) where {S, N}
bc′ = always_struct_broadcast(S()) ? convert(Broadcasted{S}, bc) : replace_structarray(bc)
@inline function Base.copyto!(dest::AbstractArray, bc::Broadcasted{<:StructArrayStyle})
ps = parent_style(bc)
bc′ = always_struct_broadcast(ps) ? ofstyle(ps, bc) : replace_structarray(bc)
return copyto!(dest, bc′)
end

@inline function Broadcast.materialize!(::StructArrayStyle{S}, dest, bc::Broadcasted) where {S}
bc′ = always_struct_broadcast(S()) ? bc : replace_structarray(bc)
return Broadcast.materialize!(S(), dest, bc′)
@inline function Broadcast.materialize!(s::StructArrayStyle, dest, bc::Broadcasted)
ps = parent_style(s)
bc′ = always_struct_broadcast(ps) ? bc : replace_structarray(bc)
return Broadcast.materialize!(ps, dest, bc′)
end

# for aliasing analysis during broadcast
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1227,7 +1227,7 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
ares = map(a->a.re, as)
aims = map(a->a.im, as)
style = Broadcast.combine_styles(ares...)
@test Broadcast.combine_styles(as...) === StructArrayStyle{typeof(style),1}()
@test Broadcast.combine_styles(as...) === StructArrayStyle{1,typeof(style)}()
if !(style in tested_style)
push!(tested_style, style)
if style isa Broadcast.ArrayStyle{MyArray3}
Expand All @@ -1249,8 +1249,8 @@ Base.BroadcastStyle(::Broadcast.ArrayStyle{MyArray2}, S::Broadcast.DefaultArrayS
@test Base.broadcasted(+, s, MyArray1(rand(2))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}

#parent_style
@test StructArrays.parent_style(StructArrayStyle{Broadcast.DefaultArrayStyle{0},2}()) == Broadcast.DefaultArrayStyle{2}
@test StructArrays.parent_style(StructArrayStyle{Broadcast.Style{Tuple},2}()) == Broadcast.Style{Tuple}
@test StructArrays.parent_style(StructArrayStyle{2,Broadcast.DefaultArrayStyle{0}}()) == Broadcast.DefaultArrayStyle{0}()
@test StructArrays.parent_style(StructArrayStyle{2,Broadcast.Style{Tuple}}()) == Broadcast.Style{Tuple}()

# allocation test for overloaded `broadcast_unalias`
StructArrays.always_struct_broadcast(::Broadcast.ArrayStyle{MyArray1}) = false
Expand Down

0 comments on commit 5c0ae3d

Please sign in to comment.