diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index 40736b6c1..5fe7f1429 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -6,7 +6,9 @@ using LinearAlgebra using SparseArrays using Base.Cartesian -using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray +using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray, + ReshapedArray + @static if VERSION >= v"1.7.0-DEV.421" using Base: @aggressive_constprop @@ -16,6 +18,13 @@ else end end +if VERSION ≥ v"1.6.0-DEV.1581" + _is_reshaped(::Type{ReinterpretArray{T,N,S,A,true}}) where {T,N,S,A} = true + _is_reshaped(::Type{ReinterpretArray{T,N,S,A,false}}) where {T,N,S,A} = false +else + _is_reshaped(::Type{ReinterpretArray{T,N,S,A}}) where {T,N,S,A} = false +end + Base.@pure __parameterless_type(T) = Base.typename(T).wrapper parameterless_type(x) = parameterless_type(typeof(x)) parameterless_type(x::Type) = __parameterless_type(x) @@ -23,6 +32,8 @@ parameterless_type(x::Type) = __parameterless_type(x) const VecAdjTrans{T,V<:AbstractVector{T}} = Union{Transpose{T,V},Adjoint{T,V}} const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}} +include("static.jl") + """ parent_type(::Type{T}) @@ -56,12 +67,9 @@ function known_length(::Type{T}) where {T} if parent_type(T) <: T return nothing else - return known_length(parent_type(T)) + return _known_length(known_size(T)) end end -@inline function known_length(::Type{<:SubArray{T,N,P,I}}) where {T,N,P,I} - return _known_length(ntuple(i -> known_length(I.parameters[i]), Val(N))) -end _known_length(x::Tuple{Vararg{Union{Nothing,Int}}}) = nothing _known_length(x::Tuple{Vararg{Int}}) = prod(x) @@ -594,6 +602,7 @@ struct CPUPointer <: AbstractCPU end struct CheckParent end struct CPUIndex <: AbstractCPU end struct GPU <: AbstractDevice end + """ device(::Type{T}) @@ -623,13 +632,18 @@ defines_strides(::Type{T}) -> Bool Is strides(::T) defined? """ -defines_strides(::Type) = false defines_strides(x) = defines_strides(typeof(x)) +function defines_strides(::Type{T}) where {T} + if parent_type(T) <: T + return false + else + return defines_strides(parent_type(T)) + end +end defines_strides(::Type{<:StridedArray}) = true -defines_strides( - ::Type{A}, -) where {A<:Union{<:Transpose,<:Adjoint,<:SubArray,<:PermutedDimsArray}} = - defines_strides(parent_type(A)) +function defines_strides(::Type{<:SubArray{T,N,P,I}}) where {T,N,P,I} + return stride_preserving_index(I) === True() +end """ can_avx(f) @@ -751,12 +765,44 @@ end end end -include("static.jl") include("ranges.jl") include("indexing.jl") include("dimensions.jl") +include("axes.jl") +include("size.jl") include("stridelayout.jl") + +abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end + +Base.size(A::AbstractArray2) = map(Int, ArrayInterface.size(A)) +Base.size(A::AbstractArray2, dim) = Int(ArrayInterface.size(A, dim)) + +Base.axes(A::AbstractArray2) = ArrayInterface.axes(A) +Base.axes(A::AbstractArray2, dim) = ArrayInterface.axes(A, dim) + +Base.strides(A::AbstractArray2) = map(Int, ArrayInterface.strides(A)) +Base.strides(A::AbstractArray2, dim) = Int(ArrayInterface.strides(A, dim)) + +function Base.length(A::AbstractArray2) + len = known_length(A) + if len === nothing + return prod(size(A)) + else + return Int(len) + end +end + +@propagate_inbounds Base.getindex(A::AbstractArray2, args...) = getindex(A, args...) +@propagate_inbounds Base.getindex(A::AbstractArray2; kwargs...) = getindex(A; kwargs...) + +@propagate_inbounds function Base.setindex!(A::AbstractArray2, val, args...) + return setindex!(A, val, args...) +end +@propagate_inbounds function Base.setindex!(A::AbstractArray2, val; kwargs...) + return setindex!(A, val; kwargs...) +end + function __init__() @require SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin @@ -811,6 +857,7 @@ function __init__() function dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N} return ArrayInterface._all_dense(Val(N)) end + defines_strides(::Type{<:StaticArrays.SArray}) = true defines_strides(::Type{<:StaticArrays.MArray}) = true @generated function axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S} @@ -1008,6 +1055,8 @@ function __init__() end stride_rank(::Type{A}) where {A<:OffsetArrays.OffsetArray} = stride_rank(parent_type(A)) + ArrayInterface.axes(A::OffsetArrays.OffsetArray) = Base.axes(A) + ArrayInterface.axes(A::OffsetArrays.OffsetArray, dim::Integer) = Base.axes(A, dim) end end diff --git a/src/axes.jl b/src/axes.jl new file mode 100644 index 000000000..6838cf36d --- /dev/null +++ b/src/axes.jl @@ -0,0 +1,158 @@ + +""" + axes_types(::Type{T}, dim) + +Returns the axis type along dimension `dim`. +""" +axes_types(x, dim) = axes_types(typeof(x), dim) +@inline axes_types(::Type{T}, dim) where {T} = axes_types(T, to_dims(T, dim)) +@inline function axes_types(::Type{T}, dim::StaticInt{D}) where {T,D} + if D > ndims(T) + return OptionallyStaticUnitRange{One,One} + else + return _get_tuple(axes_types(T), dim) + end +end +@inline function axes_types(::Type{T}, dim::Int) where {T} + if dim > ndims(T) + return OptionallyStaticUnitRange{One,One} + else + return axes_types(T).parameters[dim] + end +end + +""" + axes_types(::Type{T}) -> Type + +Returns the type of the axes for `T` +""" +axes_types(x) = axes_types(typeof(x)) +axes_types(::Type{T}) where {T<:Array} = Tuple{Vararg{OneTo{Int},ndims(T)}} +function axes_types(::Type{T}) where {T} + if parent_type(T) <: T + return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}} + else + return axes_types(parent_type(T)) + end +end +function axes_types(::Type{T}) where {T<:VecAdjTrans} + return Tuple{OptionallyStaticUnitRange{One,One},axes_types(parent_type(T), One())} +end +function axes_types(::Type{T}) where {T<:MatAdjTrans} + return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T)) +end +function axes_types(::Type{T}) where {T<:PermutedDimsArray} + return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T)) +end +function axes_types(::Type{T}) where {T<:AbstractRange} + if known_length(T) === nothing + return Tuple{OptionallyStaticUnitRange{One,Int}} + else + return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}} + end +end +function axes_types(::Type{T}) where {N,T<:Base.ReshapedArray{<:Any,N}} + return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},N}} +end + +_int_or_static_int(::Nothing) = Int +_int_or_static_int(x::Int) = StaticInt{x} + +@inline function axes_types(::Type{T}) where {N,P,I,T<:SubArray{<:Any,N,P,I}} + return eachop_tuple(_sub_axis_type, T, to_parent_dims(T)) +end +@inline function _sub_axis_type(::Type{A}, dim::StaticInt) where {T,N,P,I,A<:SubArray{T,N,P,I}} + return OptionallyStaticUnitRange{ + _int_or_static_int(known_first(axes_types(P, dim))), + _int_or_static_int(known_length(_get_tuple(I, dim))) + } +end + +function axes_types(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}} + if _is_reshaped(R) + if sizeof(S) === sizeof(T) + return axes_types(A) + elseif sizeof(S) > sizeof(T) + return eachop_tuple(_reshaped_axis_type, R, to_parent_dims(R)) + else + return eachop_tuple(axes_types, A, to_parent_dims(R)) + end + else + return eachop_tuple(_non_reshaped_axis_type, R, to_parent_dims(R)) + end +end + +function _reshaped_axis_type(::Type{R}, dim::StaticInt) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}} + return axes_types(parent_type(R), dim) +end +function _reshaped_axis_type(::Type{R}, dim::Zero) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}} + return OptionallyStaticUnitRange{One,StaticInt{div(sizeof(S), sizeof(T))}} +end + +function _non_reshaped_axis_type(::Type{R}, dim::StaticInt) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}} + return axes_types(parent_type(R), dim) +end +function _non_reshaped_axis_type(::Type{R}, dim::One) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}} + paxis = axes_types(A, dim) + len = known_length(paxis) + if len === nothing + raxis = OptionallyStaticUnitRange{One,Int} + else + raxis = OptionallyStaticUnitRange{One,StaticInt{div(len * sizeof(S), sizeof(T))}} + end + return similar_type(paxis, Int, raxis) +end + +#= + similar_type(orignal_type, new_data_type) +=# +similar_type(::Type{OneTo{Int}}, ::Type{Int}, ::Type{OneTo{Int}}) = OneTo{Int} +similar_type(::Type{OneTo{Int}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,Int}}) = OneTo{Int} +similar_type(::Type{OneTo{Int}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,StaticInt{N}}}) where {N} = OptionallyStaticUnitRange{One,StaticInt{N}} + +similar_type(::Type{OptionallyStaticUnitRange{One,Int}}, ::Type{Int}, ::Type{OneTo{Int}}) = OptionallyStaticUnitRange{One,Int} +similar_type(::Type{OptionallyStaticUnitRange{One,Int}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,Int}}) = OptionallyStaticUnitRange{One,Int} +similar_type(::Type{OptionallyStaticUnitRange{One,Int}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,StaticInt{N}}}) where {N} = OptionallyStaticUnitRange{One,StaticInt{N}} + +similar_type(::Type{OptionallyStaticUnitRange{One,StaticInt{N}}}, ::Type{Int}, ::Type{OneTo{Int}}) where {N} = OptionallyStaticUnitRange{One,Int} +similar_type(::Type{OptionallyStaticUnitRange{One,StaticInt{N}}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,Int}}) where {N} = OptionallyStaticUnitRange{One,Int} +similar_type(::Type{OptionallyStaticUnitRange{One,StaticInt{N1}}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,StaticInt{N2}}}) where {N1,N2} = OptionallyStaticUnitRange{One,StaticInt{N2}} + +""" + axes(A, d) + +Return a valid range that maps to each index along dimension `d` of `A`. +""" +axes(a, dim) = axes(a, to_dims(a, dim)) +function axes(a::A, dim::Integer) where {A} + if parent_type(A) <: A + return Base.axes(a, Int(dim)) + else + return axes(parent(a), to_parent_dims(A, dim)) + end +end +axes(A::SubArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version +axes(A::ReinterpretArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version +axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version + +""" + axes(A) + +Return a tuple of ranges where each range maps to each element along a dimension of `A`. +""" +@inline function axes(a::A) where {A} + if parent_type(A) <: A + return Base.axes(a) + else + return axes(parent(a)) + end +end +axes(A::PermutedDimsArray) = permute(axes(parent(A)), to_parent_dims(A)) +function axes(A::Union{Transpose,Adjoint}) + p = parent(A) + return (axes(p, StaticInt(2)), axes(p, One())) +end +axes(A::SubArray) = Base.axes(A) # TODO implement ArrayInterface version +axes(A::ReinterpretArray) = Base.axes(A) # TODO implement ArrayInterface version +axes(A::Base.ReshapedArray) = Base.axes(A) # TODO implement ArrayInterface version + diff --git a/src/dimensions.jl b/src/dimensions.jl index 10cd52921..84b73c399 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -1,4 +1,8 @@ +function throw_dim_error(@nospecialize(x), @nospecialize(dim)) + throw(DimensionMismatch("$x does not have dimension corresponding to $dim")) +end + #julia> @btime ArrayInterface.is_increasing(ArrayInterface.nstatic(Val(10))) # 0.045 ns (0 allocations: 0 bytes) #ArrayInterface.True() @@ -19,20 +23,23 @@ end is_increasing(::Tuple{StaticInt{X}}) where {X} = True() """ - from_parent_dims(::Type{T}) -> Bool + from_parent_dims(::Type{T}) -> Tuple Returns the mapping from parent dimensions to child dimensions. """ +from_parent_dims(x) = from_parent_dims(typeof(x)) from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T))) -from_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One()) +from_parent_dims(::Type{T}) where {T<:VecAdjTrans} = (StaticInt(2),) +from_parent_dims(::Type{T}) where {T<:MatAdjTrans} = (StaticInt(2), One()) from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I) -@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}} +@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,I<:Tuple} out = Expr(:tuple) - n = 1 - for p in I.parameters + dim_i = 1 + for i in 1:ndims(A) + p = I.parameters[i] if argdims(A, p) > 0 - push!(out.args, :(StaticInt($n))) - n += 1 + push!(out.args, :(StaticInt($dim_i))) + dim_i += 1 else push!(out.args, :(StaticInt(0))) end @@ -41,8 +48,44 @@ from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A end from_parent_dims(::Type{<:PermutedDimsArray{T,N,<:Any,I}}) where {T,N,I} = static(Val(I)) +function from_parent_dims(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}} + if !_is_reshaped(R) || sizeof(S) === sizeof(T) + return nstatic(Val(ndims(A))) + elseif sizeof(S) > sizeof(T) + return tail(nstatic(Val(ndims(A) + 1))) + else # sizeof(S) < sizeof(T) + return (Zero(), nstatic(Val(N))...) + end +end + """ - to_parent_dims(::Type{T}) -> Bool + from_parent_dims(::Type{T}, dim) -> Integer + +Returns the mapping from child dimensions to parent dimensions. +""" +from_parent_dims(x, dim) = from_parent_dims(typeof(x), dim) +@aggressive_constprop function from_parent_dims(::Type{T}, dim::Int)::Int where {T} + if dim > ndims(T) + return static(ndims(parent_type(T)) + dim - ndims(T)) + elseif dim > 0 + return @inbounds(getfield(from_parent_dims(T), dim)) + else + throw_dim_error(T, dim) + end +end + +function from_parent_dims(::Type{T}, ::StaticInt{dim}) where {T,dim} + if dim > ndims(T) + return static(ndims(parent_type(T)) + dim - ndims(T)) + elseif dim > 0 + return @inbounds(getfield(from_parent_dims(T), dim)) + else + throw_dim_error(T, dim) + end +end + +""" + to_parent_dims(::Type{T}) -> Tuple Returns the mapping from child dimensions to parent dimensions. """ @@ -62,6 +105,42 @@ to_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _to_sub_dims(A, I) end out end +function to_parent_dims(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}} + pdims = nstatic(Val(ndims(A))) + if !_is_reshaped(R) || sizeof(S) === sizeof(T) + return pdims + elseif sizeof(S) > sizeof(T) + return (Zero(), pdims...,) + else + return tail(pdims) + end +end + +""" + to_parent_dims(::Type{T}, dim) -> Integer + +Returns the mapping from child dimensions to parent dimensions. +""" +to_parent_dims(x, dim) = to_parent_dims(typeof(x), dim) +@aggressive_constprop function to_parent_dims(::Type{T}, dim::Int)::Int where {T} + if dim > ndims(T) + return static(ndims(parent_type(T)) + dim - ndims(T)) + elseif dim > 0 + return @inbounds(getfield(to_parent_dims(T), dim)) + else + throw_dim_error(T, dim) + end +end + +function to_parent_dims(::Type{T}, ::StaticInt{dim}) where {T,dim} + if dim > ndims(T) + return static(ndims(parent_type(T)) + dim - ndims(T)) + elseif dim > 0 + return @inbounds(getfield(to_parent_dims(T), dim)) + else + throw_dim_error(T, dim) + end +end """ has_dimnames(::Type{T}) -> Bool @@ -87,63 +166,33 @@ const SUnderscore = StaticSymbol(:_) Return the names of the dimensions for `x`. """ @inline dimnames(x) = dimnames(typeof(x)) -@inline dimnames(x, dim::Int) = dimnames(typeof(x), dim) -@inline dimnames(x, dim::StaticInt) = dimnames(typeof(x), dim) -@inline function dimnames(::Type{T}, ::StaticInt{dim}) where {T,dim} - if ndims(T) < dim - return SUnderscore - else - return getfield(dimnames(T), dim) - end -end -@inline function dimnames(::Type{T}, dim::Int) where {T} - if ndims(T) < dim +@inline dimnames(x, dim) = dimnames(typeof(x), dim) +@inline function dimnames(::Type{T}, dim) where {T} + if parent_type(T) <: T return SUnderscore else - return getfield(dimnames(T), dim) + return dimnames(parent_type(T), to_parent_dims(T, dim)) end end @inline function dimnames(::Type{T}) where {T} if parent_type(T) <: T return ntuple(_ -> SUnderscore, Val(ndims(T))) else - return dimnames(parent_type(T)) - end -end -@inline function dimnames(::Type{T}) where {T<:Union{Adjoint,Transpose}} - _transpose_dimnames(dimnames(parent_type(T))) -end -@inline _transpose_dimnames(x::Tuple{Any,Any}) = (last(x), first(x)) -@inline _transpose_dimnames(x::Tuple{Any}) = (SUnderscore, first(x)) - -@inline function dimnames(::Type{T}) where {I,T<:PermutedDimsArray{<:Any,<:Any,I}} - return map(i -> dimnames(parent_type(T), i), I) -end -function dimnames(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}} - return _sub_array_dimnames(Val(dimnames(P)), Val(argdims(P, I))) -end -@generated function _sub_array_dimnames(::Val{L}, ::Val{I}) where {L,I} - e = Expr(:tuple) - nl = length(L) - for i in 1:length(I) - if I[i] > 0 - if nl < i - push!(e.args, :(ArrayInterface.SUnderscore)) - else - push!(e.args, QuoteNode(L[i])) - end + perm = to_parent_dims(T) + if invariant_permutation(perm, perm) isa True + return dimnames(parent_type(T)) + else + return eachop(dimnames, parent_type(T), perm) end end - return e +end +function dimnames(::Type{T}) where {T<:SubArray} + return eachop(dimnames, parent_type(T), to_parent_dims(T)) end _to_int(x::Integer) = Int(x) _to_int(x::StaticInt) = x -function no_dimname_error(@nospecialize(x), @nospecialize(dim)) - throw(ArgumentError("($(repr(dim))) does not correspond to any dimension of ($(x))")) -end - """ to_dims(::Type{T}, dim) -> Integer @@ -154,12 +203,16 @@ to_dims(::Type{T}, dim::Integer) where {T} = _to_int(dim) to_dims(::Type{T}, dim::Colon) where {T} = dim function to_dims(::Type{T}, dim::StaticSymbol) where {T} i = find_first_eq(dim, dimnames(T)) - i === nothing && no_dimname_error(T, dim) + if i === nothing + throw_dim_error(T, dim) + end return i end -@inline function to_dims(::Type{T}, dim::Symbol) where {T} - i = find_first_eq(dim, Symbol.(dimnames(T))) - i === nothing && no_dimname_error(T, dim) +@aggressive_constprop function to_dims(::Type{T}, dim::Symbol) where {T} + i = find_first_eq(dim, map(Symbol, dimnames(T))) + if i === nothing + throw_dim_error(T, dim) + end return i end to_dims(::Type{T}, dims::Tuple) where {T} = map(i -> to_dims(T, i), dims) @@ -212,160 +265,3 @@ function _order_named_inds_check(inds::Tuple{Vararg{Any,N}}, nkwargs::Int) where return nothing end -""" - axes_types(::Type{T}[, d]) -> Type - -Returns the type of the axes for `T` -""" -axes_types(x) = axes_types(typeof(x)) -axes_types(x, d) = axes_types(typeof(x), d) -@inline axes_types(::Type{T}, d) where {T} = axes_types(T).parameters[to_dims(T, d)] -function axes_types(::Type{T}) where {T} - if parent_type(T) <: T - return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}} - else - return axes_types(parent_type(T)) - end -end -function axes_types(::Type{T}) where {T<:MatAdjTrans} - return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T)) -end -function axes_types(::Type{T}) where {T<:PermutedDimsArray} - return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T)) -end -function axes_types(::Type{T}) where {T<:AbstractRange} - if known_length(T) === nothing - return Tuple{OptionallyStaticUnitRange{One,Int}} - else - return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}} - end -end - -@inline function axes_types(::Type{T}) where {P,I,T<:SubArray{<:Any,<:Any,P,I}} - return _sub_axes_types(Val(ArrayStyle(T)), I, axes_types(P)) -end -@inline function axes_types(::Type{T}) where {T<:Base.ReinterpretArray} - return _reinterpret_axes_types( - axes_types(parent_type(T)), - eltype(T), - eltype(parent_type(T)), - ) -end -function axes_types(::Type{T}) where {N,T<:Base.ReshapedArray{<:Any,N}} - return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},N}} -end - -# These methods help handle identifying axes that don't directly propagate from the -# parent array axes. They may be worth making a formal part of the API, as they provide -# a low traffic spot to change what axes_types produces. -@inline function sub_axis_type(::Type{A}, ::Type{I}) where {A,I} - if known_length(I) === nothing - return OptionallyStaticUnitRange{One,Int} - else - return OptionallyStaticUnitRange{One,StaticInt{known_length(I)}} - end -end -@generated function _sub_axes_types( - ::Val{S}, - ::Type{I}, - ::Type{PI}, -) where {S,I<:Tuple,PI<:Tuple} - out = Expr(:curly, :Tuple) - d = 1 - for i in I.parameters - ad = argdims(S, i) - if ad > 0 - push!(out.args, :(sub_axis_type($(PI.parameters[d]), $i))) - d += ad - else - d += 1 - end - end - Expr(:block, Expr(:meta, :inline), out) -end -@inline function reinterpret_axis_type(::Type{A}, ::Type{T}, ::Type{S}) where {A,T,S} - if known_length(A) === nothing - return OptionallyStaticUnitRange{One,Int} - else - return OptionallyStaticUnitRange{ - One, - StaticInt{Int(known_length(A) / (sizeof(T) / sizeof(S)))}, - } - end -end -@generated function _reinterpret_axes_types( - ::Type{I}, - ::Type{T}, - ::Type{S}, -) where {I<:Tuple,T,S} - out = Expr(:curly, :Tuple) - for i = 1:length(I.parameters) - if i === 1 - push!(out.args, reinterpret_axis_type(I.parameters[1], T, S)) - else - push!(out.args, I.parameters[i]) - end - end - Expr(:block, Expr(:meta, :inline), out) -end - -""" - size(A) - -Returns the size of `A`. If the size of any axes are known at compile time, -these should be returned as `Static` numbers. For example: -```julia -julia> using StaticArrays, ArrayInterface - -julia> A = @SMatrix rand(3,4); - -julia> ArrayInterface.size(A) -(StaticInt{3}(), StaticInt{4}()) -``` -""" -@inline size(A) = Base.size(A) -@inline size(A, d::Integer) = size(A)[Int(d)] -@inline size(A, d) = Base.size(A, to_dims(A, d)) -@inline size(x::VecAdjTrans) = (One(), static_length(x)) - -function size(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - return _size(size(parent(B)), B.indices, map(static_length, B.indices)) -end -function strides(B::S) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - return _strides(strides(parent(B)), B.indices) -end -@generated function _size(A::Tuple{Vararg{Any,N}}, inds::I, l::L) where {N,I<:Tuple,L} - t = Expr(:tuple) - for n = 1:N - if (I.parameters[n] <: Base.Slice) - push!(t.args, :(@inbounds(_try_static(A[$n], l[$n])))) - elseif I.parameters[n] <: Number - nothing - else - push!(t.args, Expr(:ref, :l, n)) - end - end - Expr(:block, Expr(:meta, :inline), t) -end -@inline size(v::AbstractVector) = (static_length(v),) -@inline size(B::MatAdjTrans) = permute(size(parent(B)), to_parent_dims(B)) -@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A} - return permute(size(parent(B)), to_parent_dims(B)) -end -@inline size(A::AbstractArray, ::StaticInt{N}) where {N} = size(A)[N] -@inline size(A::AbstractArray, ::Val{N}) where {N} = size(A)[N] -""" - axes(A, d) - -Return a valid range that maps to each index along dimension `d` of `A`. -""" -@inline axes(A, d) = axes(A, to_dims(A, d)) -@inline axes(A, d::Integer) = axes(A)[Int(d)] - -""" - axes(A) - -Return a tuple of ranges where each range maps to each element along a dimension of `A`. -""" -@inline axes(A) = Base.axes(A) - diff --git a/src/indexing.jl b/src/indexing.jl index 4281b3017..2931ac1e7 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -456,6 +456,8 @@ function unsafe_getindex(::UnsafeGetCollection, A, inds; kwargs...) return unsafe_get_collection(A, inds; kwargs...) end +unsafe_get_element_error(A, inds) = throw(MethodError(unsafe_get_element, (A, inds))) + """ unsafe_get_element(A::AbstractArray{T}, inds::Tuple) -> T @@ -463,7 +465,13 @@ Returns an element of `A` at the indices `inds`. This method assumes all `inds` have been checked for being in bounds. Any new array type using `ArrayInterface.getindex` must define `unsafe_get_element(::NewArrayType, inds)`. """ -unsafe_get_element(A, inds; kwargs...) = throw(MethodError(unsafe_getindex, (A, inds))) +function unsafe_get_element(a::A, inds) where {A} + if parent_type(A) <: A + unsafe_get_element_error(a, inds) + else + return @inbounds(parent(a)[inds...]) + end +end function unsafe_get_element(A::Array, inds) if length(inds) === 0 return Base.arrayref(false, A, 1) @@ -497,28 +505,10 @@ function unsafe_get_collection(A, inds; kwargs...) return dest end -can_preserve_indices(::Type{T}) where {T<:AbstractRange} = true -can_preserve_indices(::Type{T}) where {T<:Int} = true -can_preserve_indices(::Type{T}) where {T} = false - -# if linear indexing on multidim or can't reconstruct AbstractUnitRange -# then construct Array of CartesianIndex/LinearIndices -function can_preserve_indices(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} - return all(eachop(_can_preserve_indices, T, nstatic(Val(N)))) -end -function _can_preserve_indices(::Type{T}, i::StaticInt) where {T} - if can_preserve_indices(_get_tuple(T, i)) - return True() - else - return False() - end -end - _ints2range(x::Integer) = x:x _ints2range(x::AbstractRange) = x - @inline function unsafe_get_collection(A::CartesianIndices{N}, inds) where {N} - if (length(inds) === 1 && N > 1) || !can_preserve_indices(typeof(inds)) + if (length(inds) === 1 && N > 1) || stride_preserving_index(typeof(inds)) === False() return Base._getindex(IndexStyle(A), A, inds...) else return CartesianIndices(to_axes(A, _ints2range.(inds))) @@ -527,7 +517,7 @@ end @inline function unsafe_get_collection(A::LinearIndices{N}, inds) where {N} if is_linear_indexing(A, inds) return @inbounds(eachindex(A)[first(inds)]) - elseif can_preserve_indices(typeof(inds)) + elseif stride_preserving_index(typeof(inds)) === True() return LinearIndices(to_axes(A, _ints2range.(inds))) else return Base._getindex(IndexStyle(A), A, inds...) @@ -570,6 +560,10 @@ function unsafe_setindex!(::UnsafeGetCollection, A, val, inds::Tuple; kwargs...) return unsafe_set_collection!(A, val, inds; kwargs...) end +function unsafe_set_element_error(A, val, inds) + throw(MethodError(unsafe_set_element!, (A, val, inds))) +end + """ unsafe_set_element!(A, val, inds::Tuple) @@ -577,8 +571,12 @@ Sets an element of `A` to `val` at indices `inds`. This method assumes all `inds have been checked for being in bounds. Any new array type using `ArrayInterface.setindex!` must define `unsafe_set_element!(::NewArrayType, val, inds)`. """ -function unsafe_set_element!(A, val, inds; kwargs...) - return throw(MethodError(unsafe_set_element!, (A, val, inds))) +function unsafe_set_element!(a::A, val, inds; kwargs...) where {A} + if parent_type(A) <: A + unsafe_set_element_error(a, val, inds) + else + return @inbounds(parent(a)[inds...] = val) + end end function unsafe_set_element!(A::Array{T}, val, inds::Tuple) where {T} if length(inds) === 0 diff --git a/src/ranges.jl b/src/ranges.jl index 376c85002..17bca8ca9 100644 --- a/src/ranges.jl +++ b/src/ranges.jl @@ -423,24 +423,17 @@ end Base.:(-)(r::OptionallyStaticRange) = -static_first(r):-static_step(r):-static_last(r) -""" - indices(x[, d]) - -Given an array `x`, this returns the indices along dimension `d`. If `x` is a tuple -of arrays, then the indices corresponding to dimension `d` of all arrays in `x` are -returned. If any indices are not equal along dimension `d`, an error is thrown. A -tuple may be used to specify a different dimension for each array. If `d` is not -specified, then the indices for visiting each index of `x` are returned. -""" -@inline function indices(x) - inds = eachindex(x) - if inds isa AbstractUnitRange && eltype(inds) <: Integer - return Base.Slice(OptionallyStaticUnitRange(inds)) +function Base.show(io::IO, r::OptionallyStaticRange) + print(io, first(r)) + if known_step(r) === 1 + print(io, ":") else - return inds + print(io, ":") + print(io, step(r)) + print(io, ":") end + print(io, last(r)) end -@inline indices(x::AbstractUnitRange{<:Integer}) = Base.Slice(OptionallyStaticUnitRange(x)) """ reduce_tup(f::F, inds::Tuple{Vararg{Any,N}}) where {F,N} @@ -573,6 +566,31 @@ Base.Slice(Static(1):100) q end +@inline function _pick_range(x, y) + fst = _try_static(static_first(x), static_first(y)) + lst = _try_static(static_last(x), static_last(y)) + return Base.Slice(OptionallyStaticUnitRange(fst, lst)) +end + +""" + indices(x[, d]) + +Given an array `x`, this returns the indices along dimension `d`. If `x` is a tuple +of arrays, then the indices corresponding to dimension `d` of all arrays in `x` are +returned. If any indices are not equal along dimension `d`, an error is thrown. A +tuple may be used to specify a different dimension for each array. If `d` is not +specified, then the indices for visiting each index of `x` are returned. +""" +@inline function indices(x) + inds = eachindex(x) + if inds isa AbstractUnitRange && eltype(inds) <: Integer + return Base.Slice(OptionallyStaticUnitRange(inds)) + else + return inds + end +end +@inline indices(x::AbstractUnitRange{<:Integer}) = Base.Slice(OptionallyStaticUnitRange(x)) + function indices(x::Tuple) inds = map(eachindex, x) return reduce_tup(_pick_range, inds) @@ -590,21 +608,3 @@ end return reduce_tup(_pick_range, inds) end -@inline function _pick_range(x, y) - fst = _try_static(static_first(x), static_first(y)) - lst = _try_static(static_last(x), static_last(y)) - return Base.Slice(OptionallyStaticUnitRange(fst, lst)) -end - -function Base.show(io::IO, r::OptionallyStaticRange) - print(io, first(r)) - if known_step(r) === 1 - print(io, ":") - else - print(io, ":") - print(io, step(r)) - print(io, ":") - end - print(io, last(r)) -end - diff --git a/src/size.jl b/src/size.jl new file mode 100644 index 000000000..82171332a --- /dev/null +++ b/src/size.jl @@ -0,0 +1,101 @@ + +""" + size(A) + +Returns the size of `A`. If the size of any axes are known at compile time, +these should be returned as `Static` numbers. For example: +```julia +julia> using StaticArrays, ArrayInterface + +julia> A = @SMatrix rand(3,4); + +julia> ArrayInterface.size(A) +(StaticInt{3}(), StaticInt{4}()) +``` +""" +function size(a::A) where {A} + if parent_type(A) <: A + return map(static_length, axes(a)) + else + return size(parent(a)) + end +end +#size(a::AbstractVector) = (size(a, One()),) + +size(x::SubArray) = eachop(_sub_size, x.indices, to_parent_dims(x)) +_sub_size(x::Tuple, ::StaticInt{dim}) where {dim} = static_length(getfield(x, dim)) + +@inline size(B::VecAdjTrans) = (One(), length(parent(B))) +@inline size(B::MatAdjTrans) = permute(size(parent(B)), to_parent_dims(B)) +@inline function size(B::PermutedDimsArray{T,N,I1,I2,A}) where {T,N,I1,I2,A} + return permute(size(parent(B)), to_parent_dims(B)) +end +function size(a::ReinterpretArray{T,N,S,A}) where {T,N,S,A} + psize = size(parent(a)) + if _is_reshaped(typeof(a)) + if sizeof(S) === sizeof(T) + return psize + elseif sizeof(S) > sizeof(T) + return (static(div(sizeof(S), sizeof(T))), psize...) + else + return tail(psize) + end + else + return (div(first(psize) * static(sizeof(S)), static(sizeof(T))), tail(psize)...,) + end +end +size(A::ReshapedArray) = A.dims +size(A::AbstractRange) = (static_length(A),) + +""" + size(A, dim) + +Returns the size of `A` along dimension `dim`. +""" +size(a, dim) = size(a, to_dims(a, dim)) +function size(a::A, dim::Integer) where {A} + if parent_type(A) <: A + len = known_size(A, dim) + if len === nothing + return Int(length(axes(a, dim))) + else + return StaticInt(len) + end + else + return size(a)[dim] + end +end +function size(A::SubArray, dim::Integer) + pdim = to_parent_dims(A, dim) + if pdim > ndims(parent_type(A)) + return size(parent(A), pdim) + else + return static_length(A.indices[pdim]) + end +end + +""" + known_size(::Type{T}) -> Tuple + +Returns the size of each dimension for `T` known at compile time. If a dimension does not +have a known size along a dimension then `nothing` is returned in its position. +""" +known_size(x) = known_size(typeof(x)) +known_size(::Type{T}) where {T} = eachop(known_size, T, nstatic(Val(ndims(T)))) + +""" + known_size(::Type{T}, dim) + +Returns the size along dimension `dim` known at compile time. If it is not known then +returns `nothing`. +""" +@inline known_size(x, dim) = known_size(typeof(x), dim) +@inline known_size(::Type{T}, dim) where {T} = known_size(T, to_dims(T, dim)) +@inline function known_size(::Type{T}, dim::Integer) where {T} + if ndims(T) < dim + return 1 + else + return known_length(axes_types(T, dim)) + end +end + diff --git a/src/static.jl b/src/static.jl index b2475c0fa..96249393e 100644 --- a/src/static.jl +++ b/src/static.jl @@ -399,7 +399,11 @@ is_static(::Type{T}) where {T} = False() _tuple_static(::Type{T}, i) where {T} = is_static(_get_tuple(T, i)) function is_static(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} - return all(eachop(_tuple_static, T, nstatic(Val(N)))) + if all(eachop(_tuple_static, T, nstatic(Val(N)))) + return True() + else + return False() + end end """ diff --git a/src/stridelayout.jl b/src/stridelayout.jl index 622569e41..763ae0a7e 100644 --- a/src/stridelayout.jl +++ b/src/stridelayout.jl @@ -1,6 +1,26 @@ +#= + stride_preserving_index(::Type{T}) -> StaticBool + +Returns `True` if strides between each element can still be derived when indexing with an +instance of type `T`. +=# +stride_preserving_index(::Type{T}) where {T<:AbstractRange} = True() +stride_preserving_index(::Type{T}) where {T<:Int} = True() +stride_preserving_index(::Type{T}) where {T} = False() +function stride_preserving_index(::Type{T}) where {N,T<:Tuple{Vararg{Any,N}}} + if all(eachop(_stride_preserving_index, T, nstatic(Val(N)))) + return True() + else + return False() + end +end +function _stride_preserving_index(::Type{T}, i::StaticInt) where {T} + return stride_preserving_index(_get_tuple(T, i)) +end + """ - offsets(A) -> Tuple + offsets(A[, dim]) -> Tuple Returns offsets of indices with respect to 0. If values are known at compile time, it should return them as `Static` numbers. @@ -12,7 +32,7 @@ offsets(x) = eachop(offsets, x, nstatic(Val(ndims(x)))) offsets(::Tuple) = (One(),) """ -contiguous_axis(::Type{T}) -> StaticInt{N} + contiguous_axis(::Type{T}) -> StaticInt{N} Returns the axis of an array of type `T` containing contiguous data. If no axis is contiguous, it returns `StaticInt{-1}`. @@ -48,18 +68,18 @@ function contiguous_axis(::Type{T}) where {T<:MatAdjTrans} return StaticInt(3) - c end end -function contiguous_axis(::Type{T}) where {I1,I2,T<:PermutedDimsArray{<:Any,<:Any,I1,I2}} +function contiguous_axis(::Type{T}) where {T<:PermutedDimsArray} c = contiguous_axis(parent_type(T)) if c === nothing return nothing elseif isone(-c) return c else - return StaticInt(I2[c]) + return from_parent_dims(T, c) end end -function contiguous_axis(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} - return _contiguous_axis(S, contiguous_axis(A)) +function contiguous_axis(::Type{T}) where {T<:SubArray} + return _contiguous_axis(T, contiguous_axis(parent_type(T))) end _contiguous_axis(::Type{A}, ::Nothing) where {T,N,P,I,A<:SubArray{T,N,P,I}} = nothing @@ -110,7 +130,13 @@ function rank_to_sortperm(R::Tuple{Vararg{StaticInt,N}}) where {N} end stride_rank(x) = stride_rank(typeof(x)) -stride_rank(::Type) = nothing +function stride_rank(::Type{T}) where {T} + if parent_type(T) <: T + return nothing + else + return stride_rank(parent_type(T)) + end +end stride_rank(::Type{Array{T,N}}) where {T,N} = nstatic(Val(N)) stride_rank(::Type{<:Tuple}) = (One(),) @@ -142,11 +168,6 @@ end _reshaped_striderank(::True, ::Val{N}, ::Val{0}) where {N} = nstatic(Val(N)) _reshaped_striderank(_, __, ___) = nothing -""" - If the contiguous dimension is not the dimension with `StrideRank{1}`: -""" - - """ contiguous_batch_size(::Type{T}) -> StaticInt{N} @@ -208,7 +229,13 @@ Returns a tuple of indicators for whether each axis is dense. An axis `i` of array `A` is dense if `stride(A, i) * Base.size(A, i) == stride(A, j)` where `stride_rank(A)[i] + 1 == stride_rank(A)[j]`. """ dense_dims(x) = dense_dims(typeof(x)) -dense_dims(::Type) = nothing +function dense_dims(::Type{T}) where {T} + if parent_type(T) <: T + return nothing + else + return dense_dims(parent_type(T)) + end +end _all_dense(::Val{N}) where {N} = ntuple(_ -> True(), Val{N}()) dense_dims(::Type{Array{T,N}}) where {T,N} = _all_dense(Val{N}()) @@ -241,7 +268,7 @@ function dense_dims(::Type{S}) where {N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArra return _dense_dims(S, dense_dims(A), Val(stride_rank(A))) end -_dense_dims(::Any, ::Any) = nothing +_dense_dims(::Type{S}, ::Nothing, ::Val{R}) where {R,N,NP,T,A<:AbstractArray{T,NP},I,S<:SubArray{T,N,A,I}} = nothing @generated function _dense_dims( ::Type{S}, ::D, @@ -293,75 +320,45 @@ function _reshaped_dense_dims(dense::D, ::True, ::Val{N}, ::Val{0}) where {D,N} end """ - strides(A) -> Tuple - -Returns the strides of array `A`. If any strides are known at compile time, -these should be returned as `Static` numbers. For example: -```julia -julia> A = rand(3,4); - -julia> ArrayInterface.strides(A) -(static(1), 3) - -Additionally, the behavior differs from `Base.strides` for adjoint vectors: - -julia> x = rand(5); - -julia> ArrayInterface.strides(x') -(static(1), static(1)) - -This is to support the pattern of using just the first stride for linear indexing, `x[i]`, -while still producing correct behavior when using valid cartesian indices, such as `x[1,i]`. -``` -""" -strides(A) = Base.strides(A) -strides(A, d) = strides(A)[to_dims(A, d)] - -@inline function known_length(::Type{T}) where {T <: Base.ReinterpretArray} - return _known_length(known_length(parent_type(T)), eltype(T), eltype(parent_type(T))) -end -_known_length(::Nothing, _, __) = nothing -@inline _known_length(L::Integer, ::Type{T}, ::Type{P}) where {T,P} = L * sizeof(P) ÷ sizeof(T) - - - -""" - known_offsets(::Type{T}[, d]) -> Tuple + known_offsets(::Type{T}[, dim]) -> Tuple Returns a tuple of offset values known at compile time. If the offset of a given axis is not known at compile time `nothing` is returned its position. """ -@inline known_offsets(x, d) = known_offsets(x)[to_dims(x, d)] -known_offsets(x) = known_offsets(typeof(x)) -@generated function known_offsets(::Type{T}) where {T} - out = Expr(:tuple) - for p in axes_types(T).parameters - push!(out.args, known_first(p)) +known_offsets(x, dim) = known_offsets(typeof(x), dim) +known_offsets(::Type{T}, dim) where {T} = known_offsets(T, to_dims(T, dim)) +function known_offsets(::Type{T}, dim::Integer) where {T} + if ndims(T) < dim + return 1 + else + return known_offsets(T)[dim] end - return out end -""" - known_size(::Type{T}[, d]) -> Tuple - -Returns the size of each dimension for `T` known at compile time. If a dimension does not -have a known size along a dimension then `nothing` is returned in its position. -""" -@inline known_size(x, d) = known_size(x)[to_dims(x, d)] -known_size(x) = known_size(typeof(x)) -function known_size(::Type{T}) where {T} - return eachop(_known_axis_length, axes_types(T), nstatic(Val(ndims(T)))) +known_offsets(x) = known_offsets(typeof(x)) +function known_offsets(::Type{T}) where {T} + return eachop(_known_offsets, axes_types(T), nstatic(Val(ndims(T)))) end -_known_axis_length(::Type{T}, c::StaticInt) where {T} = known_length(_get_tuple(T, c)) +_known_offsets(::Type{T}, dim::StaticInt) where {T} = known_first(_get_tuple(T, dim)) """ - known_strides(::Type{T}[, d]) -> Tuple + known_strides(::Type{T}[, dim]) -> Tuple Returns the strides of array `A` known at compile time. Any strides that are not known at compile time are represented by `nothing`. """ +known_strides(x, dim) = known_strides(typeof(x), dim) +known_strides(::Type{T}, dim) where {T} = known_strides(T, to_dims(T, dim)) +function known_strides(::Type{T}, dim::Integer) where {T} + # see https://github.com/JuliaLang/julia/blob/6468dcb04ea2947f43a11f556da9a5588de512a0/base/reinterpretarray.jl#L148 + if ndims(T) < dim + return known_length(T) + else + return known_strides(T)[dim] + end +end + known_strides(x) = known_strides(typeof(x)) -known_strides(x, d) = known_strides(x)[to_dims(x, d)] known_strides(::Type{T}) where {T<:Vector} = (1,) function known_strides(::Type{T}) where {T<:MatAdjTrans} return permute(known_strides(parent_type(T)), to_parent_dims(T)) @@ -370,108 +367,97 @@ end strd = first(known_strides(parent_type(T))) return (strd, strd) end -@inline function known_strides(::Type{T}) where {I1,T<:PermutedDimsArray{<:Any,<:Any,I1}} - return permute(known_strides(parent_type(T)), Val{I1}()) +@inline function known_strides(::Type{T}) where {T<:PermutedDimsArray} + return permute(known_strides(parent_type(T)), to_parent_dims(T)) end -@inline function known_strides(::Type{T}) where {I1,T<:SubArray{<:Any,<:Any,<:Any,I1}} +@inline function known_strides(::Type{T}) where {T<:SubArray} return permute(known_strides(parent_type(T)), to_parent_dims(T)) end function known_strides(::Type{T}) where {T} if ndims(T) === 1 return (1,) else - return _known_strides(Val(Base.front(known_size(T)))) - end -end -@generated function _known_strides(::Val{S}) where {S} - out = Expr(:tuple) - N = length(S) - push!(out.args, 1) - for s in S - if s === nothing || out.args[end] === nothing - push!(out.args, nothing) - else - push!(out.args, out.args[end] * s) - end + return size_to_strides(known_size(T), 1) end - return Expr(:block, Expr(:meta, :inline), out) end +_stride_error(@nospecialize(x)) = throw(MethodError(ArrayInterface.strides, (x,))) + +""" + strides(A) -> Tuple + +Returns the strides of array `A`. If any strides are known at compile time, +these should be returned as `Static` numbers. For example: +```julia +julia> A = rand(3,4); + +julia> ArrayInterface.strides(A) +(static(1), 3) + +Additionally, the behavior differs from `Base.strides` for adjoint vectors: + +julia> x = rand(5); + +julia> ArrayInterface.strides(x') +(static(1), static(1)) + +This is to support the pattern of using just the first stride for linear indexing, `x[i]`, +while still producing correct behavior when using valid cartesian indices, such as `x[1,i]`. +``` +""" @inline strides(A::Vector{<:Any}) = (StaticInt(1),) @inline strides(A::Array{<:Any,N}) where {N} = (StaticInt(1), Base.tail(Base.strides(A))...) -@inline strides(A::AbstractArray) = _strides(A, Base.strides(A), contiguous_axis(A)) +function strides(x) + defines_strides(x) || _stride_error(x) + return size_to_strides(size(x), One()) +end +#@inline strides(A) = _strides(A, Base.strides(A), contiguous_axis(A)) function strides(x::VecAdjTrans) st = first(strides(parent(x))) return (st, st) end +@inline strides(B::MatAdjTrans) = permute(strides(parent(B)), to_parent_dims(B)) +@inline strides(B::PermutedDimsArray) = permute(strides(parent(B)), to_parent_dims(B)) -@generated function _strides(A::AbstractArray{T,N}, s::NTuple{N}, ::StaticInt{C}) where {T,N,C} - if C ≤ 0 || C > N - return Expr(:block, Expr(:meta, :inline), :s) - else - stup = Expr(:tuple) - for n ∈ 1:N - if n == C - push!(stup.args, :(One())) - else - push!(stup.args, Expr(:ref, :s, n)) - end - end - return quote - $(Expr(:meta, :inline)) - @inbounds $stup - end - end +getmul(x::Tuple, y::Tuple, ::StaticInt{i}) where {i} = getfield(x, i) * getfield(y, i) +function strides(A::SubArray) + return eachop(getmul, map(maybe_static_step, A.indices), strides(parent(A)), to_parent_dims(A)) end -if VERSION ≥ v"1.6.0-DEV.1581" - @generated function _strides( - _::Base.ReinterpretArray{T,N,S,A,true}, - s::NTuple{N}, - ::StaticInt{1}, - ) where {T,N,S,D,A<:Array{S,D}} - stup = Expr(:tuple, :(One())) - if D < N - push!(stup.args, Expr(:call, Expr(:curly, :StaticInt, sizeof(S) ÷ sizeof(T)))) - end - for n ∈ 2+(D 1 && t.args[i - 1] === :nothing) + push!(t.args, :nothing) + else + next = Symbol(:val_, i) + push!(out.args, :($next = $prev * getfield(sz, $i))) + push!(t.args, next) + prev = next end + i += 1 end + push!(out.args, t) + return out end -@inline strides(B::MatAdjTrans) = permute(strides(parent(B)), to_parent_dims(B)) -@inline function strides(B::PermutedDimsArray) - return permute(strides(parent(B)), to_parent_dims(B)) +strides(a, dim) = strides(a, to_dims(a, dim)) +function strides(a::A, dim::Integer) where {A} + if parent_type(A) <: A + return Base.stride(a, Int(dim)) + else + return strides(parent(a), to_parent_dims(A, dim)) + end end + @inline stride(A::AbstractArray, ::StaticInt{N}) where {N} = strides(A)[N] @inline stride(A::AbstractArray, ::Val{N}) where {N} = strides(A)[N] stride(A, i) = Base.stride(A, i) # for type stability -@generated function _strides(A::Tuple{Vararg{Any,N}}, inds::I) where {N,I<:Tuple} - t = Expr(:tuple) - for n = 1:N - if I.parameters[n] <: AbstractUnitRange - push!(t.args, Expr(:ref, :A, n)) - elseif I.parameters[n] <: AbstractRange - push!( - t.args, - Expr( - :call, - :(*), - Expr(:ref, :A, n), - Expr(:call, :static_step, Expr(:ref, :inds, n)), - ), - ) - elseif !(I.parameters[n] <: Integer) - return nothing - end - end - Expr(:block, Expr(:meta, :inline), t) -end - diff --git a/test/dimensions.jl b/test/dimensions.jl index 82932c1ca..03a5f48f6 100644 --- a/test/dimensions.jl +++ b/test/dimensions.jl @@ -6,31 +6,22 @@ using ArrayInterface: dimnames ### define wrapper with dimnames ### -struct NamedDimsWrapper{L,T,N,P<:AbstractArray{T,N}} <: AbstractArray{T,N} +struct NamedDimsWrapper{L,T,N,P<:AbstractArray{T,N}} <: ArrayInterface.AbstractArray2{T,N} parent::P NamedDimsWrapper{L}(p) where {L} = new{L,eltype(p),ndims(p),typeof(p)}(p) end ArrayInterface.parent_type(::Type{T}) where {P,T<:NamedDimsWrapper{<:Any,<:Any,<:Any,P}} = P ArrayInterface.has_dimnames(::Type{T}) where {T<:NamedDimsWrapper} = true ArrayInterface.dimnames(::Type{T}) where {L,T<:NamedDimsWrapper{L}} = static(Val(L)) -Base.parent(x::NamedDimsWrapper) = x.parent -Base.size(x::NamedDimsWrapper) = size(parent(x)) -Base.size(x::NamedDimsWrapper, d) = ArrayInterface.size(x, d) -Base.axes(x::NamedDimsWrapper) = axes(parent(x)) -Base.axes(x::NamedDimsWrapper, d) = ArrayInterface.axes(x, d) -Base.strides(x::NamedDimsWrapper) = Base.strides(parent(x)) -Base.strides(x::NamedDimsWrapper, d) = ArrayInterface.strides(x, d) - -Base.getindex(x::NamedDimsWrapper; kwargs...) = ArrayInterface.getindex(x; kwargs...) -Base.getindex(x::NamedDimsWrapper, args...) = ArrayInterface.getindex(x, args...) -Base.setindex!(x::NamedDimsWrapper, val; kwargs...) = ArrayInterface.setindex!(x, val; kwargs...) -Base.setindex!(x::NamedDimsWrapper, val, args...) = ArrayInterface.setindex!(x, val, args...) -function ArrayInterface.unsafe_get_element(x::NamedDimsWrapper, inds; kwargs...) - return @inbounds(parent(x)[inds...]) -end -function ArrayInterface.unsafe_set_element!(x::NamedDimsWrapper, val, inds; kwargs...) - return @inbounds(parent(x)[inds...] = val) +function ArrayInterface.dimnames(::Type{T}, dim) where {L,T<:NamedDimsWrapper{L}} + if ndims(T) < dim + return static(:_) + else + return static(L[dim]) + end end +ArrayInterface.has_dimnames(::Type{T}) where {T<:NamedDimsWrapper} = true +Base.parent(x::NamedDimsWrapper) = x.parent @testset "dimension permutations" begin a = ones(2, 2, 2) @@ -46,13 +37,43 @@ end @test @inferred(ArrayInterface.to_parent_dims(typeof(madj))) == (2, 1) @test @inferred(ArrayInterface.to_parent_dims(typeof(vview))) == (2,) @test @inferred(ArrayInterface.to_parent_dims(typeof(vadj))) == (2, 1) + @test @inferred(ArrayInterface.to_parent_dims(typeof(vadj), static(1))) == 2 + @test @inferred(ArrayInterface.to_parent_dims(typeof(vadj), 1)) == 2 + @test @inferred(ArrayInterface.to_parent_dims(typeof(vadj), static(3))) == 2 + @test @inferred(ArrayInterface.to_parent_dims(typeof(vadj), 3)) == 2 @test @inferred(ArrayInterface.from_parent_dims(typeof(a))) == (1, 2, 3) @test @inferred(ArrayInterface.from_parent_dims(typeof(perm))) == (2, 3, 1) @test @inferred(ArrayInterface.from_parent_dims(typeof(mview))) == (1, 0, 2) @test @inferred(ArrayInterface.from_parent_dims(typeof(madj))) == (2, 1) @test @inferred(ArrayInterface.from_parent_dims(typeof(vview))) == (0, 1) - @test @inferred(ArrayInterface.from_parent_dims(typeof(vadj))) == (2, 1) + @test @inferred(ArrayInterface.from_parent_dims(typeof(vadj))) == (2,) + @test @inferred(ArrayInterface.from_parent_dims(typeof(vadj), static(1))) == 2 + @test @inferred(ArrayInterface.from_parent_dims(typeof(vadj), 1)) == 2 + + @test_throws DimensionMismatch ArrayInterface.to_parent_dims(typeof(vadj), 0) + @test_throws DimensionMismatch ArrayInterface.to_parent_dims(typeof(vadj), static(0)) + + @test_throws DimensionMismatch ArrayInterface.from_parent_dims(typeof(vadj), 0) + @test_throws DimensionMismatch ArrayInterface.from_parent_dims(typeof(vadj), static(0)) + + if VERSION ≥ v"1.6.0-DEV.1581" + colormat = reinterpret(reshape, Float64, [(R = rand(), G = rand(), B = rand()) for i ∈ 1:100]) + @test @inferred(ArrayInterface.from_parent_dims(typeof(colormat))) === (static(2),) + @test @inferred(ArrayInterface.to_parent_dims(typeof(colormat))) === (static(0), static(1),) + + Rr = reinterpret(reshape, Int32, ones(4)) + @test @inferred(ArrayInterface.from_parent_dims(typeof(Rr))) === (static(2),) + @test @inferred(ArrayInterface.to_parent_dims(typeof(Rr))) === (static(0), static(1),) + + Rr = reinterpret(reshape, Int64, ones(4)) + @test @inferred(ArrayInterface.from_parent_dims(typeof(Rr))) === (static(1),) + @test @inferred(ArrayInterface.to_parent_dims(typeof(Rr))) === (static(1),) + + Sr = reinterpret(reshape, Complex{Int64}, zeros(2, 3, 4)) + @test @inferred(ArrayInterface.from_parent_dims(typeof(Sr))) === (static(0), static(1), static(2)) + @test @inferred(ArrayInterface.to_parent_dims(typeof(Sr))) === (static(2), static(3)) + end end @testset "order_named_inds" begin @@ -78,16 +99,21 @@ val_has_dimnames(x) = Val(ArrayInterface.has_dimnames(x)) y = NamedDimsWrapper{(:x,)}(ones(2)); dnums = ntuple(+, length(d)) @test @inferred(val_has_dimnames(x)) === Val(true) + @test @inferred(ArrayInterface.has_dimnames(ones(2,2))) === false + @test @inferred(ArrayInterface.has_dimnames(Array{Int,2})) === false @test @inferred(val_has_dimnames(typeof(x))) === Val(true) @test @inferred(val_has_dimnames(typeof(view(x, :, 1, :)))) === Val(true) @test @inferred(dimnames(x)) === d + @test @inferred(dimnames(parent(x))) === (static(:_), static(:_)) @test @inferred(dimnames(x')) === reverse(d) @test @inferred(dimnames(y')) === (static(:_), static(:x)) @test @inferred(dimnames(PermutedDimsArray(x, (2, 1)))) === reverse(d) + @test @inferred(dimnames(PermutedDimsArray(x', (2, 1)))) === d @test @inferred(dimnames(view(x, :, 1))) === (static(:x),) @test @inferred(dimnames(view(x, :, :, :))) === (static(:x),static(:y), static(:_)) @test @inferred(dimnames(view(x, :, 1, :))) === (static(:x), static(:_)) @test @inferred(dimnames(x, ArrayInterface.One())) === static(:x) + @test @inferred(dimnames(parent(x), ArrayInterface.One())) === static(:_) end @testset "to_dims" begin @@ -101,14 +127,15 @@ end @test @inferred(ArrayInterface.to_dims(x, (:y, :x))) == (2, 1) @test @inferred(ArrayInterface.to_dims(x, :x)) == 1 @test @inferred(ArrayInterface.to_dims(x, :y)) == 2 - @test_throws ArgumentError ArrayInterface.to_dims(x, :z) # not found + @test_throws DimensionMismatch ArrayInterface.to_dims(x, static(:z)) # not found + @test_throws DimensionMismatch ArrayInterface.to_dims(x, :z) # not found end @testset "large case" begin @test @inferred(ArrayInterface.to_dims(y, :x)) == 1 @test @inferred(ArrayInterface.to_dims(y, :a)) == 3 @test @inferred(ArrayInterface.to_dims(y, :d)) == 6 - @test_throws ArgumentError ArrayInterface.to_dims(y, :z) # not found + @test_throws DimensionMismatch ArrayInterface.to_dims(y, :z) # not found end end @@ -116,10 +143,12 @@ end d = (static(:x), static(:y)) x = NamedDimsWrapper{d}(ones(2,2)); y = NamedDimsWrapper{(:x,)}(ones(2)); - @test @inferred(size(x, :x)) == size(parent(x), 1) + @test @inferred(size(x, first(d))) == size(parent(x), 1) @test @inferred(ArrayInterface.size(y')) == (1, size(parent(x), 1)) - @test @inferred(axes(x, :x)) == axes(parent(x), 1) + @test @inferred(axes(x, first(d))) == axes(parent(x), 1) @test strides(x, :x) == ArrayInterface.strides(parent(x))[1] + @test @inferred(ArrayInterface.axes_types(x, static(:x))) <: Base.OneTo{Int} + @test ArrayInterface.axes_types(x, :x) <: Base.OneTo{Int} x[x = 1] = [2, 3] @test @inferred(getindex(x, x = 1)) == [2, 3] diff --git a/test/runtests.jl b/test/runtests.jl index 2ec14df75..44d2d1a4d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,7 +15,6 @@ x = @SVector [1,2,3] x = @MVector [1,2,3] @test ArrayInterface.ismutable(x) == true @test ArrayInterface.ismutable(view(x, 1:2)) == true -@test ArrayInterface.ismutable(1:10) == false @test ArrayInterface.ismutable((0.1,1.0)) == false @test ArrayInterface.ismutable(Base.ImmutableDict{Symbol,Int64}) == false @test ArrayInterface.ismutable((;x=1)) == false @@ -292,14 +291,26 @@ DummyZeros(dims...) = DummyZeros{Float64}(dims...) Base.size(x::DummyZeros) = x.dims Base.getindex(::DummyZeros{T}, inds...) where {T} = zero(T) +struct Wrapper{T,N,P<:AbstractArray{T,N}} <: ArrayInterface.AbstractArray2{T,N} + parent::P +end +ArrayInterface.parent_type(::Type{<:Wrapper{T,N,P}}) where {T,N,P} = P +Base.parent(x::Wrapper) = x.parent + using OffsetArrays @testset "Memory Layout" begin x = zeros(100); # R = reshape(view(x, 1:100), (10,10)); # A = zeros(3,4,5); - A = reshape(view(x, 1:60), (3,4,5)) + A = Wrapper(reshape(view(x, 1:60), (3,4,5))) D1 = view(A, 1:2:3, :, :) # first dimension is discontiguous D2 = view(A, :, 2:2:4, :) # first dimension is contiguous + + @test @inferred(ArrayInterface.defines_strides(x)) + @test @inferred(ArrayInterface.defines_strides(A)) + @test @inferred(ArrayInterface.defines_strides(D1)) + @test !@inferred(ArrayInterface.defines_strides(view(A, :, [1,2],1))) + @test @inferred(device(A)) === ArrayInterface.CPUPointer() @test @inferred(device((1,2,3))) === ArrayInterface.CPUIndex() @test @inferred(device(PermutedDimsArray(A,(3,1,2)))) === ArrayInterface.CPUPointer() @@ -330,9 +341,12 @@ using OffsetArrays @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StaticInt(-1) @test @inferred(contiguous_axis(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StaticInt(1) @test @inferred(contiguous_axis((3,4))) === StaticInt(1) - @test @inferred(contiguous_axis(DummyZeros(3,4))) === nothing @test @inferred(contiguous_axis(rand(4)')) === StaticInt(2) @test @inferred(contiguous_axis(view(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])', :, 1)')) === StaticInt(-1) + @test @inferred(contiguous_axis(DummyZeros(3,4))) === nothing + @test @inferred(contiguous_axis(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing + @test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :))) === nothing + @test @inferred(contiguous_axis(view(DummyZeros(3,4), 1, :)')) === nothing @test @inferred(ArrayInterface.contiguous_axis_indicator(@SArray(zeros(2,2,2)))) == (true,false,false) @test @inferred(ArrayInterface.contiguous_axis_indicator(A)) == (true,false,false) @@ -356,20 +370,22 @@ using OffsetArrays @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) === ArrayInterface.StaticInt(-1) @test @inferred(contiguous_batch_size(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) === ArrayInterface.StaticInt(0) - @test @inferred(stride_rank(@SArray(zeros(2,2,2)))) == ((1, 2, 3)) - @test @inferred(stride_rank(A)) == ((1,2,3)) - @test @inferred(stride_rank(view(A,:,:,1))) == ((1,2)) + @test @inferred(stride_rank(@SArray(zeros(2,2,2)))) == (1, 2, 3) + @test @inferred(stride_rank(A)) == (1,2,3) + @test @inferred(stride_rank(view(A,:,:,1))) === (static(1), static(2)) @test @inferred(stride_rank(view(A,:,:,1))) === ((ArrayInterface.StaticInt(1),ArrayInterface.StaticInt(2))) - @test @inferred(stride_rank(PermutedDimsArray(A,(3,1,2)))) == ((3, 1, 2)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) == ((1, 2)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) == ((2, 1)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) == ((3, 1, 2)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) == ((3, 2)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) == ((2, 3)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) == ((1, 3)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])')) == ((2, 1)) - @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,[1,3,4]]))) == ((3, 1, 2)) - + @test @inferred(stride_rank(PermutedDimsArray(A,(3,1,2)))) == (3, 1, 2) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) == (1, 2) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) == (2, 1) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) == (3, 1, 2) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) == (3, 2) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) == (2, 3) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) == (1, 3) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,2,1])')) == (2, 1) + @test @inferred(stride_rank(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,[1,3,4]]))) == (3, 1, 2) + @test @inferred(stride_rank(DummyZeros(3,4)')) === nothing + @test @inferred(stride_rank(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing + @test @inferred(stride_rank(view(DummyZeros(3,4), 1, :))) === nothing #= @btime ArrayInterface.is_column_major($(PermutedDimsArray(A,(3,1,2)))) 0.047 ns (0 allocations: 0 bytes) @@ -395,20 +411,28 @@ using OffsetArrays @test @inferred(ArrayInterface.is_column_major(1:10)) === False() @test @inferred(ArrayInterface.is_column_major(2.3)) === False() - @test @inferred(dense_dims(@SArray(zeros(2,2,2)))) == ((true,true,true)) - @test @inferred(dense_dims(A)) == ((true,true,true)) - @test @inferred(dense_dims(PermutedDimsArray(A,(3,1,2)))) == ((true,true,true)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) == ((true,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) == ((false,true)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) == ((false,true,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,1:2]))) == ((false,true,true)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) == ((false,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) == ((false,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) == ((true,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,[1,2]]))) == ((false,true,false)) - @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,[1,2,3],:]))) == ((false,false,false)) - @test @inferred(dense_dims(vec(A))) == ((true,)) - @test @inferred(dense_dims(vec(A)')) == ((true,true)) + @test @inferred(dense_dims(@SArray(zeros(2,2,2)))) == (true,true,true) + @test @inferred(dense_dims(A)) == (true,true,true) + @test @inferred(dense_dims(PermutedDimsArray(A,(3,1,2)))) == (true,true,true) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:]))) == (true,false) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2,1:2,:])')) == (false,true) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,1:2,:]))) == (false,true,false) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,1:2]))) == (false,true,true) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:]))) == (false,false) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,2,:])')) == (false,false) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[:,1:2,1])')) == (true,false) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,:,[1,2]]))) == (false,true,false) + @test @inferred(dense_dims(@view(PermutedDimsArray(A,(3,1,2))[2:3,[1,2,3],:]))) == (false,false,false) + # TODO Currently Wrapper can't function the same as Array because Array can change + # the dimensions on reshape. We should be rewrapping the result in `Wrapper` but we + # first need to develop a standard method for reconstructing arrays + @test @inferred(dense_dims(vec(parent(A)))) == (true,) + @test @inferred(dense_dims(vec(parent(A))')) == (true,true) + @test @inferred(dense_dims(DummyZeros(3,4))) === nothing + @test @inferred(dense_dims(DummyZeros(3,4)')) === nothing + @test @inferred(dense_dims(PermutedDimsArray(DummyZeros(3,4), (2, 1)))) === nothing + @test @inferred(dense_dims(view(DummyZeros(3,4), :, 1))) === nothing + @test @inferred(dense_dims(view(DummyZeros(3,4), :, 1)')) === nothing B = Array{Int8}(undef, 2,2,2,2); doubleperm = PermutedDimsArray(PermutedDimsArray(B,(4,2,3,1)), (4,2,1,3)); @@ -426,7 +450,7 @@ end Mp2 = @view(PermutedDimsArray(M,(3,1,2))[2:3,:,2])'; D = @view(A[:,2:2:4,:]) R = StaticInt(1):StaticInt(2) - Rr = reinterpret(Int32, R) + Rnr = reinterpret(Int32, R) Ar = reinterpret(Float32, A) sv5 = @SVector(zeros(5)); v5 = Vector{Float64}(undef, 5); @@ -437,18 +461,20 @@ end @test @inferred(ArrayInterface.size(A)) === size(A) @test @inferred(ArrayInterface.size(Ap)) === size(Ap) @test @inferred(ArrayInterface.size(R)) === (StaticInt(2),) - @test @inferred(ArrayInterface.size(Rr)) === (StaticInt(4),) - @test @inferred(ArrayInterface.known_length(Rr)) === 4 + @test @inferred(ArrayInterface.size(Rnr)) === (StaticInt(4),) + @test @inferred(ArrayInterface.known_length(Rnr)) === 4 @test @inferred(ArrayInterface.size(S)) === (StaticInt(2), StaticInt(3), StaticInt(4)) @test @inferred(ArrayInterface.size(Sp)) === (2, 2, StaticInt(3)) @test @inferred(ArrayInterface.size(Sp2)) === (2, StaticInt(3), StaticInt(2)) @test @inferred(ArrayInterface.size(S)) == size(S) @test @inferred(ArrayInterface.size(Sp)) == size(Sp) + @test @inferred(ArrayInterface.size(parent(Sp2))) === (static(4), static(3), static(2)) @test @inferred(ArrayInterface.size(Sp2)) == size(Sp2) @test @inferred(ArrayInterface.size(Sp2, StaticInt(1))) === 2 @test @inferred(ArrayInterface.size(Sp2, StaticInt(2))) === StaticInt(3) @test @inferred(ArrayInterface.size(Sp2, StaticInt(3))) === StaticInt(2) + @test @inferred(ArrayInterface.size(Wrapper(Sp2), StaticInt(3))) === StaticInt(2) @test @inferred(ArrayInterface.size(M)) === (StaticInt(2), StaticInt(3), StaticInt(4)) @test @inferred(ArrayInterface.size(Mp)) === (StaticInt(3), StaticInt(4)) @@ -460,12 +486,19 @@ end @test @inferred(ArrayInterface.known_size(A)) === (nothing, nothing, nothing) @test @inferred(ArrayInterface.known_size(Ap)) === (nothing,nothing) + @test @inferred(ArrayInterface.known_size(Wrapper(Ap))) === (nothing,nothing) @test @inferred(ArrayInterface.known_size(R)) === (2,) - @test @inferred(ArrayInterface.known_size(Rr)) === (4,) + @test @inferred(ArrayInterface.known_size(Wrapper(R))) === (2,) + @test @inferred(ArrayInterface.known_size(Rnr)) === (4,) + @test @inferred(ArrayInterface.known_size(Rnr, static(1))) === 4 @test @inferred(ArrayInterface.known_size(Ar)) === (nothing,nothing, nothing,) + @test @inferred(ArrayInterface.known_size(Ar, static(1))) === nothing + @test @inferred(ArrayInterface.known_size(Ar, static(4))) === 1 @test @inferred(ArrayInterface.known_size(S)) === (2, 3, 4) + @test @inferred(ArrayInterface.known_size(Wrapper(S))) === (2, 3, 4) @test @inferred(ArrayInterface.known_size(Sp)) === (nothing, nothing, 3) + @test @inferred(ArrayInterface.known_size(Wrapper(Sp))) === (nothing, nothing, 3) @test @inferred(ArrayInterface.known_size(Sp2)) === (nothing, 3, 2) @test @inferred(ArrayInterface.known_size(Sp2, StaticInt(1))) === nothing @test @inferred(ArrayInterface.known_size(Sp2, StaticInt(2))) === 3 @@ -496,6 +529,7 @@ end @test @inferred(ArrayInterface.strides(M)) == strides(M) @test @inferred(ArrayInterface.strides(Mp)) == strides(Mp) @test @inferred(ArrayInterface.strides(Mp2)) == strides(Mp2) + @test_throws MethodError ArrayInterface.strides(DummyZeros(3,4)) @test @inferred(ArrayInterface.known_strides(A)) === (1, nothing, nothing) @test @inferred(ArrayInterface.known_strides(Ap)) === (1, nothing) @@ -508,6 +542,7 @@ end @test @inferred(ArrayInterface.known_strides(Sp2, StaticInt(1))) === 6 @test @inferred(ArrayInterface.known_strides(Sp2, StaticInt(2))) === 2 @test @inferred(ArrayInterface.known_strides(Sp2, StaticInt(3))) === 1 + @test @inferred(ArrayInterface.known_strides(Sp2, StaticInt(4))) === ArrayInterface.known_length(Sp2) @test @inferred(ArrayInterface.known_strides(view(Sp2, :, 1, 1)')) === (6, 6) @test @inferred(ArrayInterface.known_strides(M)) === (1, 2, 6) @@ -529,6 +564,8 @@ end @test @inferred(ArrayInterface.known_offsets(A)) === (1, 1, 1) @test @inferred(ArrayInterface.known_offsets(Ap)) === (1, 1) @test @inferred(ArrayInterface.known_offsets(Ar)) === (1, 1, 1) + @test @inferred(ArrayInterface.known_offsets(Ar, static(1))) === 1 + @test @inferred(ArrayInterface.known_offsets(Ar, static(4))) === 1 @test @inferred(ArrayInterface.known_offsets(S)) === (1, 1, 1) @test @inferred(ArrayInterface.known_offsets(Sp)) === (1, 1, 1) @@ -539,7 +576,7 @@ end @test @inferred(ArrayInterface.known_offsets(Mp2)) === (1, 1) @test @inferred(ArrayInterface.known_offsets(R)) === (1,) - @test @inferred(ArrayInterface.known_offsets(Rr)) === (1,) + @test @inferred(ArrayInterface.known_offsets(Rnr)) === (1,) @test @inferred(ArrayInterface.known_offsets(1:10)) === (1,) O = OffsetArray(A, 3, 7, 10); @@ -554,6 +591,15 @@ end colormat = reinterpret(reshape, Float64, colors) @test @inferred(ArrayInterface.strides(colormat)) === (StaticInt(1), StaticInt(3)) + + Rr = reinterpret(reshape, Int32, R) + @test @inferred(ArrayInterface.size(Rr)) === (StaticInt(2),StaticInt(2)) + @test @inferred(ArrayInterface.known_size(Rr)) === (2, 2) + + Sr = Wrapper(reinterpret(reshape, Complex{Int64}, S)) + @test @inferred(ArrayInterface.size(Sr)) == (static(3), static(4)) + @test @inferred(ArrayInterface.known_size(Sr)) === (3, 4) + @test @inferred(ArrayInterface.strides(Sr)) === (static(1), static(3)) end end