Skip to content

Commit

Permalink
Support broadcasting over multiple DenseAxisArrays (#2533)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Apr 5, 2021
1 parent ff182a7 commit f1d41eb
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 41 deletions.
77 changes: 36 additions & 41 deletions src/Containers/DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
struct _AxisLookup{D}
data::D
end
Base.:(==)(x::_AxisLookup{D}, y::_AxisLookup{D}) where {D} = x.data == y.data

# Default fallbacks.
Base.getindex(::_AxisLookup, key) = throw(KeyError(key))
Expand Down Expand Up @@ -327,52 +328,46 @@ end
################

# This implementation follows the instructions at
# https://docs.julialang.org/en/latest/manual/interfaces/#man-interfaces-broadcasting-1
# for implementing broadcast. We eagerly evaluate expressions involving
# DenseAxisArrays, overriding operation fusion. For now, nested (fused)
# broadcasts like f.(A .+ 1) don't work, and we don't support broadcasts
# where multiple DenseAxisArrays appear. This is a stopgap solution to get tests
# passing on Julia 0.7 and leaves lots of room for improvement.
struct DenseAxisArrayBroadcastStyle <: Broadcast.BroadcastStyle end
# Scalars can be used with DenseAxisArray in broadcast
function Base.BroadcastStyle(
::DenseAxisArrayBroadcastStyle,
::Base.Broadcast.DefaultArrayStyle{0},
)
return DenseAxisArrayBroadcastStyle()
end
Base.BroadcastStyle(::Type{<:DenseAxisArray}) = DenseAxisArrayBroadcastStyle()
function Base.Broadcast.broadcasted(::DenseAxisArrayBroadcastStyle, f, args...)
array = find_jump_array(args)
if sum(arg isa DenseAxisArray for arg in args) > 1
error(
"Broadcast operations with multiple DenseAxisArrays are not yet " *
"supported.",
)
# https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting
# for implementing broadcast.

function Base.BroadcastStyle(::Type{T}) where {T<:DenseAxisArray}
return Broadcast.ArrayStyle{T}()
end

function _broadcast_axes_check(x::NTuple{N}) where {N}
axes = first(x)
for i in 2:N
if x[i][1] != axes[1]
error(
"Unable to broadcast over DenseAxisArrays with different axes.",
)
end
end
result_data = broadcast(f, unpack_jump_array(args)...)
return DenseAxisArray(result_data, array.axes, array.lookup)
end
function find_jump_array(args::Tuple)
return find_jump_array(args[1], Base.tail(args))
end
find_jump_array(array::DenseAxisArray, rest) = array
find_jump_array(::Any, rest) = find_jump_array(rest)
function find_jump_array(broadcasted::Broadcast.Broadcasted)
return error(
"Unsupported nested broadcast operation. DenseAxisArray supports " *
"only simple broadcast operations like f.(A) but not f.(A .+ 1).",
)
return axes
end

function unpack_jump_array(args::Tuple)
return unpack_jump_array(args[1], Base.tail(args))
_broadcast_axes(x::Tuple) = _broadcast_axes(first(x), Base.tail(x))
_broadcast_axes(::Tuple{}) = ()
_broadcast_axes(::Any, tail) = _broadcast_axes(tail)
function _broadcast_axes(x::DenseAxisArray, tail)
return ((x.axes, x.lookup), _broadcast_axes(tail)...)
end
unpack_jump_array(args::Tuple{}) = ()
function unpack_jump_array(array::DenseAxisArray, rest)
return (array.data, unpack_jump_array(rest)...)

_broadcast_args(x::Tuple) = _broadcast_args(first(x), Base.tail(x))
_broadcast_args(::Tuple{}) = ()
_broadcast_args(x::Any, tail) = (x, _broadcast_args(tail)...)
_broadcast_args(x::DenseAxisArray, tail) = (x.data, _broadcast_args(tail)...)

function Base.Broadcast.broadcasted(
::Broadcast.ArrayStyle{<:DenseAxisArray},
f,
args...,
)
axes_lookup = _broadcast_axes_check(_broadcast_axes(args))
new_args = _broadcast_args(args)
return DenseAxisArray(broadcast(f, new_args...), axes_lookup...)
end
unpack_jump_array(other::Any, rest) = (other, unpack_jump_array(rest)...)

########
# Show #
Expand Down
22 changes: 22 additions & 0 deletions test/Containers/DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,26 @@ And data, a 0-dimensional $(Array{Int,0}):
@test y isa DenseAxisArray
@test x == y
end
@testset "Broadcast" begin
foo(x, y) = x + y
foo_b(x, y) = foo.(x, y)
bar(x, y) = (foo.(x, y) .+ x) .^ 2
a = [5.0 6.0; 7.0 8.0]
A = DenseAxisArray(a, [:a, :b], [:a, :b])
b = a .+ 1
B = A .+ 1
@test B == DenseAxisArray(b, [:a, :b], [:a, :b])
C = @inferred foo_b(A, B)
@test C == DenseAxisArray(foo_b(a, b), [:a, :b], [:a, :b])
D = @inferred bar(A, B)
@test D == DenseAxisArray(bar(a, b), [:a, :b], [:a, :b])
end
@testset "Broadcast_errors" begin
a = [5.0 6.0; 7.0 8.0]
A = DenseAxisArray(a, [:a, :b], [:a, :b])
B = DenseAxisArray(a, [:b, :a], [:a, :b])
@test_throws ErrorException A .+ B
b = [5.0 6.0; 7.0 8.0; 9.0 10.0]
@test_throws DimensionMismatch A .+ b
end
end

0 comments on commit f1d41eb

Please sign in to comment.