Skip to content

Commit

Permalink
Fix inference
Browse files Browse the repository at this point in the history
  • Loading branch information
odow committed Apr 4, 2021
1 parent 3e3ce21 commit 2670b38
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 23 deletions.
44 changes: 22 additions & 22 deletions src/Containers/DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,41 +328,41 @@ end
################

# This implementation follows the instructions at
# hhttps://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting
# https://docs.julialang.org/en/v1/manual/interfaces/#man-interfaces-broadcasting
# for implementing broadcast.

struct _DenseAxisArrayBroadcastStyle <: Broadcast.BroadcastStyle end

Base.BroadcastStyle(::Type{<:DenseAxisArray}) = _DenseAxisArrayBroadcastStyle()

function Base.BroadcastStyle(
::_DenseAxisArrayBroadcastStyle,
::Base.Broadcast.DefaultArrayStyle{0},
)
return _DenseAxisArrayBroadcastStyle()
function Base.BroadcastStyle(::Type{T}) where {T<:DenseAxisArray}
return Broadcast.ArrayStyle{T}()
end

function _broadcast_axes_inner(arg::DenseAxisArray, axes_lookup)
if axes_lookup !== nothing && axes(arg) != axes_lookup[1]
error("Unable to broadcast over DenseAxisArrays with different axes.")
function _broadcast_axes_check(x::NTuple{N}) where {N}
axes = first(x)
for i = 2:N
if x[i][1] != axes[1]
error("Unable to broadcast over DenseAxisArrays with different axes.")
end
end
return (arg.axes, arg.lookup)
return axes
end
_broadcast_axes_inner(::Any, axes_lookup) = axes_lookup

function _broadcast_axes(args::Tuple, axes = nothing)
new_axes = _broadcast_axes_inner(first(args), axes)
return _broadcast_axes(Base.tail(args), new_axes)
_broadcast_axes(x::Tuple) = _broadcast_axes(first(x), Base.tail(x))
_broadcast_axes(::Tuple{}) = ()
_broadcast_axes(::Any, tail) = _broadcast_axes(tail)
function _broadcast_axes(x::DenseAxisArray, tail)
return ((x.axes, x.lookup), _broadcast_axes(tail)...)
end
_broadcast_axes(::NTuple{0}, axes) = axes

_broadcast_args(x::Tuple) = _broadcast_args(first(x), Base.tail(x))
_broadcast_args(::Tuple{}) = ()
_broadcast_args(x::Any, tail) = (x, _broadcast_args(tail)...)
_broadcast_args(x::DenseAxisArray, tail) = (x.data, _broadcast_args(tail)...)
_broadcast_args(::Tuple{}) = ()

function Base.Broadcast.broadcasted(::_DenseAxisArrayBroadcastStyle, f, args...)
axes_lookup = _broadcast_axes(args)
function Base.Broadcast.broadcasted(
::Broadcast.ArrayStyle{<:DenseAxisArray},
f,
args...,
)
axes_lookup = _broadcast_axes_check(_broadcast_axes(args))
new_args = _broadcast_args(args)
return DenseAxisArray(broadcast(f, new_args...), axes_lookup...)
end
Expand Down
2 changes: 1 addition & 1 deletion test/Containers/DenseAxisArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ And data, a 0-dimensional $(Array{Int,0}):
@testset "Broadcast" begin
foo(x, y) = x + y
foo_b(x, y) = foo.(x, y)
bar(x, y) = foo.(x, y) .+ x
bar(x, y) = (foo.(x, y) .+ x).^2
a = [5.0 6.0; 7.0 8.0]
A = DenseAxisArray(a, [:a, :b], [:a, :b])
b = a .+ 1
Expand Down

0 comments on commit 2670b38

Please sign in to comment.