Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More fall backs to parent working and AbstractArray2 #120

Merged
merged 11 commits into from
Feb 17, 2021
60 changes: 51 additions & 9 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@ else
end
end

if VERSION ≥ v"1.6.0-DEV.1581"
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,true}}) where {T,N,S,A} = True()
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,false}}) where {T,N,S,A} = False()
else
_is_reshaped(::Type{ReinterpretArray{T,N,S,A}}) where {T,N,S,A} = False()
end

Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)
Expand Down Expand Up @@ -56,12 +63,9 @@ function known_length(::Type{T}) where {T}
if parent_type(T) <: T
return nothing
else
return known_length(parent_type(T))
return _known_length(known_size(T))
end
end
@inline function known_length(::Type{<:SubArray{T,N,P,I}}) where {T,N,P,I}
return _known_length(ntuple(i -> known_length(I.parameters[i]), Val(N)))
end
_known_length(x::Tuple{Vararg{Union{Nothing,Int}}}) = nothing
_known_length(x::Tuple{Vararg{Int}}) = prod(x)

Expand Down Expand Up @@ -594,6 +598,7 @@ struct CPUPointer <: AbstractCPU end
struct CheckParent end
struct CPUIndex <: AbstractCPU end
struct GPU <: AbstractDevice end

"""
device(::Type{T})

Expand Down Expand Up @@ -623,13 +628,15 @@ defines_strides(::Type{T}) -> Bool

Is strides(::T) defined?
"""
defines_strides(::Type) = false
function defines_strides(::Type{T}) where {T}
if parent_type(T) <: T
return false
else
return defines_strides(parent_type(T))
end
end
defines_strides(x) = defines_strides(typeof(x))
defines_strides(::Type{<:StridedArray}) = true
defines_strides(
::Type{A},
) where {A<:Union{<:Transpose,<:Adjoint,<:SubArray,<:PermutedDimsArray}} =
defines_strides(parent_type(A))

"""
can_avx(f)
Expand Down Expand Up @@ -755,8 +762,41 @@ include("static.jl")
include("ranges.jl")
include("indexing.jl")
include("dimensions.jl")
include("axes.jl")
include("size.jl")
include("stridelayout.jl")


abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end

Base.size(A::AbstractArray2) = ArrayInterface.size(A)
Base.size(A::AbstractArray2, dim) = ArrayInterface.size(A, dim)

Base.axes(A::AbstractArray2) = ArrayInterface.axes(A)
Base.axes(A::AbstractArray2, dim) = ArrayInterface.axes(A, dim)

Base.strides(A::AbstractArray2) = ArrayInterface.strides(A)
Base.strides(A::AbstractArray2, dim) = ArrayInterface.strides(A, dim)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the size and stride methods convert to Int to avoid potential type instabilities/problems in generic code unaware of ArrayInterface?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These only exist to pass things forward to the ArrayInterface implementation. But if you look at the actually implementation...

axes(a, dim) = axes(a, to_dims(a, dim))
function axes(a::A, dim::Integer) where {A}
    if parent_type(A) <: A
        return Base.axes(a, Int(dim))
    else
        return axes(parent(a), to_parent_dims(A, dim))
    end
end

...StaticInt is never passed to a method in Base but persists internally, preventing premature conversion of StaticInt to an Int and relying on constant propagation for inference.

Copy link
Collaborator

@chriselrod chriselrod Feb 10, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I define a statically sized AbstractArray2 using the ArtayInterface, then Base.size will suddenly return StaticInts as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I thought it made sense because StaticInt is part of the public interface and we want to propagate static information, avoiding a dependence on constant propagation when dealing with deeply nested types. I'm sure we will run into some situations where methods explicitly require Int, but I also assumed that the end goal was for StaticInt to become standard.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thinking was that users might assume that size returns an NTuple{N,Int}, and end up inadvertently writing code like myprod that doesn't handle heterogenous tuples well:

julia> using ArrayInterface: StaticInt

julia> sz = (StaticInt(4), 5);

julia> @code_warntype prod(sz)
MethodInstance for prod(::Tuple{StaticInt{4}, Int64})
  from prod(x::Tuple{Any, Vararg{Any}}) in Base at tuple.jl:480
Arguments
  #self#::Core.Const(prod)
  x::Tuple{StaticInt{4}, Int64}
Body::Int64
1%1 = Core._apply_iterate(Base.iterate, Base.:*, x)::Int64
└──      return %1


julia> function myprod(x)
           p = 1
           for a  x
               p *= a
           end
           p
       end
myprod (generic function with 1 method)

julia> @code_warntype myprod(sz)
MethodInstance for myprod(::Tuple{StaticInt{4}, Int64})
  from myprod(x) in Main at REPL[10]:1
Arguments
  #self#::Core.Const(myprod)
  x::Tuple{StaticInt{4}, Int64}
Locals
  @_3::Union{Nothing, Tuple{Union{StaticInt{4}, Int64}, Int64}}
  p::Int64
  a::Union{StaticInt{4}, Int64}
Body::Int64
1 ─       (p = 1)
│   %2  = x::Tuple{StaticInt{4}, Int64}
│         (@_3 = Base.iterate(%2))
│   %4  = (@_3::Core.Const((static(4), 2)) === nothing)::Core.Const(false)
│   %5  = Base.not_int(%4)::Core.Const(true)
└──       goto #4 if not %5
2%7  = @_3::Union{Tuple{StaticInt{4}, Int64}, Tuple{Int64, Int64}}
│         (a = Core.getfield(%7, 1))
│   %9  = Core.getfield(%7, 2)::Int64
│         (p = p * a)
│         (@_3 = Base.iterate(%2, %9))
│   %12 = (@_3 === nothing)::Bool%13 = Base.not_int(%12)::Bool
└──       goto #4 if not %13
3 ─       goto #2
4return p


julia> @btime myprod($(Ref(sz))[])
  14.331 ns (1 allocation: 16 bytes)
20

julia> @btime prod($(Ref(sz))[])
  0.989 ns (0 allocations: 0 bytes)
20

I don't want to accidentally cause bad performance in other libraries. I have no idea how prevalent code like this is though, and would hope authors would fix it.

I also figure that if code is set up to take advantage of static size information, it could probably also take the step to use ArrayInterface.size & co instead. E.g., many of my libraries have using ArrayInterface: size, strides.

But if code is written in ignorance of these, is it more likely that they'd benefit from the forced constant prop, or that they'd be hurt by the need for constant prop to ensure type stability?

As is, I have definitions like here:

@inline Base.size(A::AbstractStrideArray) = map(Int, size(A))
@inline Base.strides(A::AbstractStrideArray) = map(Int, strides(A))

@inline create_axis(s, ::Zero) = CloseOpen(s)
@inline create_axis(s, ::One) = One():s
@inline create_axis(s, o) = CloseOpen(s, s+o)

@inline ArrayInterface.axes(A::AbstractStrideArray) = map(create_axis, size(A), offsets(A))
@inline Base.axes(A::AbstractStrideArray) = axes(A)

@inline ArrayInterface.offsets(A::PtrArray) = getfield(getfield(A, :ptr), :offsets)
@inline ArrayInterface.static_length(A::AbstractStrideArray) = prod(size(A))

# type stable, because index known at compile time
@inline type_stable_select(t::NTuple, ::StaticInt{N}) where {N} = t[N]
@inline type_stable_select(t::Tuple, ::StaticInt{N}) where {N} = t[N]
# type stable, because tuple is homogenous
@inline type_stable_select(t::NTuple, i::Integer) = t[i]
# make the tuple homogenous before indexing
@inline type_stable_select(t::Tuple, i::Integer) = map(Int, t)[i]

@inline function ArrayInterface.axes(A::AbstractStrideVector, i::Integer)
    if i == 1
        o = type_stable_select(offsets(A), i)
        s = type_stable_select(size(A), i)
        return create_axis(s, o)
    else
        return One():1
    end
end
@inline function ArrayInterface.axes(A::AbstractStrideArray, i::Integer)
    o = type_stable_select(offsets(A), i)
    s = type_stable_select(size(A), i)
    create_axis(s, o)
end
@inline Base.axes(A::AbstractStrideArray, i::Integer) = axes(A, i)

Meant to try and avoid placing too many performance foot-guns about.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the examples Chris. I think it nicely illustrates the problems with returning StaticInt in the wild.

It will undoubtedly cause problems introducing StaticInt into the public space, although StaticArrays already does this and returns SOneTo for its axes. Perhaps another way of approaching this is asking ourselves if we really do plan for something like StaticInt to be passed around in the future. If we do then at what point do we start pushing people to make that adjustment? Having to move back and forth between static and dynamic space so that users don't have to learn about using things like NTuple{N} vs Tuple{Vararg{Any,N}} seems like a mistake to me, but so does country music and I'm not going to take that away from people either.

@timholy, you had some valuable input when we made StaticInt. Any opinions?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I made it so that strides and size will return Int types. Returning statically sized axes seems to work for StaticArrays so I left that in. I think this also makes it so that new subtypes of AbstractArray2 can just define axes and axes_types and ArrayInterface.size and ArrayInterface.strides can catch any static info.


function Base.length(A::AbstractArray2)
len = known_length(A)
if len === nothing
return prod(size(A))
else
return static(len)
end
end

@propagate_inbounds Base.getindex(A::AbstractArray2, args...) = getindex(A, args...)
@propagate_inbounds Base.getindex(A::AbstractArray2; kwargs...) = getindex(A; kwargs...)

@propagate_inbounds function Base.setindex!(A::AbstractArray2, val, args...)
return setindex!(A, val, args...)
end
@propagate_inbounds function Base.setindex!(A::AbstractArray2, val; kwargs...)
return setindex!(A, val; kwargs...)
end

function __init__()

@require SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin
Expand Down Expand Up @@ -1008,6 +1048,8 @@ function __init__()
end
stride_rank(::Type{A}) where {A<:OffsetArrays.OffsetArray} =
stride_rank(parent_type(A))
ArrayInterface.axes(A::OffsetArrays.OffsetArray) = Base.axes(A)
ArrayInterface.axes(A::OffsetArrays.OffsetArray, dim::Integer) = Base.axes(A, dim)
end
end

Expand Down
153 changes: 153 additions & 0 deletions src/axes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@

"""
axes_types(::Type{T}, dim)

Returns the axis type along dimension `dim`.
"""
axes_types(x, dim) = axes_types(typeof(x), dim)
@inline axes_types(::Type{T}, dim) where {T} = axes_types(T, to_dims(T, dim))
@inline function axes_types(::Type{T}, dim::StaticInt) where {T}
if dim > ndims(T)
return OptionallyStaticUnitRange{One,One}
else
return _get_tuple(axes_types(T), dim)
end
end
@inline function axes_types(::Type{T}, dim::Integer) where {T}
if dim > ndims(T)
return OptionallyStaticUnitRange{One,One}
else
return axes_types(T).parameters[Int(dim)]
end
end

"""
axes_types(::Type{T}) -> Type

Returns the type of the axes for `T`
"""
axes_types(x) = axes_types(typeof(x))
function axes_types(::Type{T}) where {T}
if parent_type(T) <: T
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}}
else
return eachop_tuple(axes_types, parent_type(T), to_parent_dims(T))
end
end
function axes_types(::Type{T}) where {T<:MatAdjTrans}
return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T))
end
function axes_types(::Type{T}) where {T<:PermutedDimsArray}
return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T))
end
function axes_types(::Type{T}) where {T<:AbstractRange}
if known_length(T) === nothing
return Tuple{OptionallyStaticUnitRange{One,Int}}
else
return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}}
end
end

@inline function axes_types(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}}
return _sub_axes_types(Val(ArrayStyle(T)), I, axes_types(P))
end
@inline function axes_types(::Type{T}) where {T<:Base.ReinterpretArray}
return _reinterpret_axes_types(
axes_types(parent_type(T)),
eltype(T),
eltype(parent_type(T)),
)
end
function axes_types(::Type{T}) where {N,T<:Base.ReshapedArray{<:Any,N}}
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},N}}
end

# These methods help handle identifying axes that don't directly propagate from the
# parent array axes. They may be worth making a formal part of the API, as they provide
# a low traffic spot to change what axes_types produces.
@inline function sub_axis_type(::Type{A}, ::Type{I}) where {A,I}
if known_length(I) === nothing
return OptionallyStaticUnitRange{One,Int}
else
return OptionallyStaticUnitRange{One,StaticInt{known_length(I)}}
end
end
@generated function _sub_axes_types(
::Val{S},
::Type{I},
::Type{PI},
) where {S,I<:Tuple,PI<:Tuple}
out = Expr(:curly, :Tuple)
d = 1
for i in I.parameters
ad = argdims(S, i)
if ad > 0
push!(out.args, :(sub_axis_type($(PI.parameters[d]), $i)))
d += ad
else
d += 1
end
end
Expr(:block, Expr(:meta, :inline), out)
end
@inline function reinterpret_axis_type(::Type{A}, ::Type{T}, ::Type{S}) where {A,T,S}
if known_length(A) === nothing
return OptionallyStaticUnitRange{One,Int}
else
return OptionallyStaticUnitRange{
One,
StaticInt{Int(known_length(A) / (sizeof(T) / sizeof(S)))},
}
end
end
@generated function _reinterpret_axes_types(
::Type{I},
::Type{T},
::Type{S},
) where {I<:Tuple,T,S}
out = Expr(:curly, :Tuple)
for i = 1:length(I.parameters)
if i === 1
push!(out.args, reinterpret_axis_type(I.parameters[1], T, S))
else
push!(out.args, I.parameters[i])
end
end
Expr(:block, Expr(:meta, :inline), out)
end

"""
axes(A, d)

Return a valid range that maps to each index along dimension `d` of `A`.
"""
axes(a, dim) = axes(a, to_dims(a, dim))
function axes(a::A, dim::Integer) where {A}
if parent_type(A) <: A
return Base.axes(a, Int(dim))
else
return axes(parent(a), to_parent_dims(A, dim))
end
end
axes(A::SubArray, dim::Integer) = Base.axes(A, dim) # TODO implement ArrayInterface version
axes(A::ReinterpretArray, dim::Integer) = Base.axes(A, dim) # TODO implement ArrayInterface version
axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, dim) # TODO implement ArrayInterface version

"""
axes(A)

Return a tuple of ranges where each range maps to each element along a dimension of `A`.
"""
@inline function axes(a::A) where {A}
if parent_type(A) <: A
return Base.axes(a)
else
return axes(parent(a))
end
end
axes(A::PermutedDimsArray) = permute(axes(parent(A)), to_parent_dims(A))
axes(A::Union{Transpose,Adjoint}) = Base.axes(A) # TODO implement ArrayInterface version
axes(A::SubArray) = Base.axes(A) # TODO implement ArrayInterface version
axes(A::ReinterpretArray) = Base.axes(A) # TODO implement ArrayInterface version
axes(A::Base.ReshapedArray) = Base.axes(A) # TODO implement ArrayInterface version

Loading