From 60a8c8cbaf395fe4960967f500f467723ff1d5e5 Mon Sep 17 00:00:00 2001 From: N5N3 <2642243996@qq.com> Date: Thu, 9 Nov 2023 01:05:31 +0800 Subject: [PATCH] define `__broadcast` ourselves --- ext/StructArraysStaticArraysExt.jl | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/ext/StructArraysStaticArraysExt.jl b/ext/StructArraysStaticArraysExt.jl index fab643a0..0c84f5b7 100644 --- a/ext/StructArraysStaticArraysExt.jl +++ b/ext/StructArraysStaticArraysExt.jl @@ -33,9 +33,9 @@ StructArrays.createinstance(T::Type{<:FieldArray}, args...) = invoke(StructArray # Broadcast overload using StaticArrays: StaticArrayStyle, similar_type, Size, SOneTo -using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype, __broadcast +using StaticArrays: broadcast_flatten, broadcast_sizes, first_statictype using StructArrays: isnonemptystructtype -using Base.Broadcast: Broadcasted +using Base.Broadcast: Broadcasted, _broadcast_getindex # StaticArrayStyle has no similar defined. # Overload `try_struct_copy` instead. @@ -79,4 +79,29 @@ end end end +# The `__broadcast` kernal is copied from `StaticArrays.jl`. +# see https://github.com/JuliaArrays/StaticArrays.jl/blob/master/src/broadcast.jl +@generated function __broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize + sizes = [sz.parameters[1] for sz ∈ s.parameters] + + indices = CartesianIndices(newsize) + exprs = similar(indices, Expr) + for (j, current_ind) ∈ enumerate(indices) + exprs_vals = (broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes)) + exprs[j] = :(f($(exprs_vals...))) + end + + return quote + Base.@_inline_meta + return tuple($(exprs...)) + end +end + +broadcast_getindex(::Tuple{}, i::Int, I::CartesianIndex) = return :(_broadcast_getindex(a[$i], $I)) +function broadcast_getindex(oldsize::Tuple, i::Int, newindex::CartesianIndex) + li = LinearIndices(oldsize) + ind = _broadcast_getindex(li, newindex) + return :(a[$i][$ind]) +end + end