From 6da8cf35d52972129cdc97cf7adfc5b30753473d Mon Sep 17 00:00:00 2001 From: "Tamas K. Papp" Date: Fri, 18 Aug 2023 14:03:47 +0200 Subject: [PATCH] Make Broadcast.result_style work on styles with fields. (#50938) Fixes #50937. --------- Co-authored-by: Jameson Nash --- base/broadcast.jl | 7 +++++-- test/broadcast.jl | 13 +++++++++++++ 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index ff1f6d06af2562..fd330a7f2cb676 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -441,7 +441,9 @@ Base.Broadcast.DefaultArrayStyle{1}() function result_style end result_style(s::BroadcastStyle) = s -result_style(s1::S, s2::S) where S<:BroadcastStyle = S() +function result_style(s1::S, s2::S) where S<:BroadcastStyle + s1 ≡ s2 ? s1 : error("inconsistent broadcast styles, custom rule needed") +end # Test both orders so users typically only have to declare one order result_style(s1, s2) = result_join(s1, s2, BroadcastStyle(s1, s2), BroadcastStyle(s2, s1)) @@ -457,7 +459,8 @@ result_join(::Any, ::Any, s::BroadcastStyle, ::Unknown) = s result_join(::AbstractArrayStyle, ::AbstractArrayStyle, ::Unknown, ::Unknown) = ArrayConflict() # Fallbacks in case users define `rule` for both argument-orders (not recommended) -result_join(::Any, ::Any, ::S, ::S) where S<:BroadcastStyle = S() +result_join(::Any, ::Any, s1::S, s2::S) where S<:BroadcastStyle = result_style(s1, s2) + @noinline function result_join(::S, ::T, ::U, ::V) where {S,T,U,V} error(""" conflicting broadcast rules defined diff --git a/test/broadcast.jl b/test/broadcast.jl index 6cf05fbea139ca..73c01b1c0ee4d7 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -1142,6 +1142,19 @@ end @test CartesianIndex(1,2) .+ [CartesianIndex(3,4), CartesianIndex(5,6)] == [CartesianIndex(4, 6), CartesianIndex(6, 8)] end +struct MyBroadcastStyleWithField <: Broadcast.BroadcastStyle + i::Int +end +# asymmetry intended +Base.BroadcastStyle(a::MyBroadcastStyleWithField, b::MyBroadcastStyleWithField) = a + +@testset "issue #50937: styles that have fields" begin + @test Broadcast.result_style(MyBroadcastStyleWithField(1), MyBroadcastStyleWithField(1)) == + MyBroadcastStyleWithField(1) + @test_throws ErrorException Broadcast.result_style(MyBroadcastStyleWithField(1), + MyBroadcastStyleWithField(2)) +end + # test that `Broadcast` definition is defined as total and eligible for concrete evaluation import Base.Broadcast: BroadcastStyle, DefaultArrayStyle @test Base.infer_effects(BroadcastStyle, (DefaultArrayStyle{1},DefaultArrayStyle{2},)) |>