Skip to content

Commit

Permalink
Merge pull request #120 from SciML/abstract-type
Browse files Browse the repository at this point in the history
More fall backs to parent working and `AbstractArray2`
  • Loading branch information
chriselrod authored Feb 17, 2021
2 parents ee80672 + 5387002 commit baccd56
Show file tree
Hide file tree
Showing 10 changed files with 748 additions and 481 deletions.
71 changes: 60 additions & 11 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@ using LinearAlgebra
using SparseArrays
using Base.Cartesian

using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray
using Base: @propagate_inbounds, tail, OneTo, LogicalIndex, Slice, ReinterpretArray,
ReshapedArray


@static if VERSION >= v"1.7.0-DEV.421"
using Base: @aggressive_constprop
Expand All @@ -16,13 +18,22 @@ else
end
end

if VERSION v"1.6.0-DEV.1581"
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,true}}) where {T,N,S,A} = true
_is_reshaped(::Type{ReinterpretArray{T,N,S,A,false}}) where {T,N,S,A} = false
else
_is_reshaped(::Type{ReinterpretArray{T,N,S,A}}) where {T,N,S,A} = false
end

Base.@pure __parameterless_type(T) = Base.typename(T).wrapper
parameterless_type(x) = parameterless_type(typeof(x))
parameterless_type(x::Type) = __parameterless_type(x)

const VecAdjTrans{T,V<:AbstractVector{T}} = Union{Transpose{T,V},Adjoint{T,V}}
const MatAdjTrans{T,M<:AbstractMatrix{T}} = Union{Transpose{T,M},Adjoint{T,M}}

include("static.jl")

"""
parent_type(::Type{T})
Expand Down Expand Up @@ -56,12 +67,9 @@ function known_length(::Type{T}) where {T}
if parent_type(T) <: T
return nothing
else
return known_length(parent_type(T))
return _known_length(known_size(T))
end
end
@inline function known_length(::Type{<:SubArray{T,N,P,I}}) where {T,N,P,I}
return _known_length(ntuple(i -> known_length(I.parameters[i]), Val(N)))
end
_known_length(x::Tuple{Vararg{Union{Nothing,Int}}}) = nothing
_known_length(x::Tuple{Vararg{Int}}) = prod(x)

Expand Down Expand Up @@ -594,6 +602,7 @@ struct CPUPointer <: AbstractCPU end
struct CheckParent end
struct CPUIndex <: AbstractCPU end
struct GPU <: AbstractDevice end

"""
device(::Type{T})
Expand Down Expand Up @@ -623,13 +632,18 @@ defines_strides(::Type{T}) -> Bool
Is strides(::T) defined?
"""
defines_strides(::Type) = false
defines_strides(x) = defines_strides(typeof(x))
function defines_strides(::Type{T}) where {T}
if parent_type(T) <: T
return false
else
return defines_strides(parent_type(T))
end
end
defines_strides(::Type{<:StridedArray}) = true
defines_strides(
::Type{A},
) where {A<:Union{<:Transpose,<:Adjoint,<:SubArray,<:PermutedDimsArray}} =
defines_strides(parent_type(A))
function defines_strides(::Type{<:SubArray{T,N,P,I}}) where {T,N,P,I}
return stride_preserving_index(I) === True()
end

"""
can_avx(f)
Expand Down Expand Up @@ -751,12 +765,44 @@ end
end
end

include("static.jl")
include("ranges.jl")
include("indexing.jl")
include("dimensions.jl")
include("axes.jl")
include("size.jl")
include("stridelayout.jl")


abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end

Base.size(A::AbstractArray2) = map(Int, ArrayInterface.size(A))
Base.size(A::AbstractArray2, dim) = Int(ArrayInterface.size(A, dim))

Base.axes(A::AbstractArray2) = ArrayInterface.axes(A)
Base.axes(A::AbstractArray2, dim) = ArrayInterface.axes(A, dim)

Base.strides(A::AbstractArray2) = map(Int, ArrayInterface.strides(A))
Base.strides(A::AbstractArray2, dim) = Int(ArrayInterface.strides(A, dim))

function Base.length(A::AbstractArray2)
len = known_length(A)
if len === nothing
return prod(size(A))
else
return Int(len)
end
end

@propagate_inbounds Base.getindex(A::AbstractArray2, args...) = getindex(A, args...)
@propagate_inbounds Base.getindex(A::AbstractArray2; kwargs...) = getindex(A; kwargs...)

@propagate_inbounds function Base.setindex!(A::AbstractArray2, val, args...)
return setindex!(A, val, args...)
end
@propagate_inbounds function Base.setindex!(A::AbstractArray2, val; kwargs...)
return setindex!(A, val; kwargs...)
end

function __init__()

@require SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" begin
Expand Down Expand Up @@ -811,6 +857,7 @@ function __init__()
function dense_dims(::Type{<:StaticArrays.StaticArray{S,T,N}}) where {S,T,N}
return ArrayInterface._all_dense(Val(N))
end
defines_strides(::Type{<:StaticArrays.SArray}) = true
defines_strides(::Type{<:StaticArrays.MArray}) = true

@generated function axes_types(::Type{<:StaticArrays.StaticArray{S}}) where {S}
Expand Down Expand Up @@ -1008,6 +1055,8 @@ function __init__()
end
stride_rank(::Type{A}) where {A<:OffsetArrays.OffsetArray} =
stride_rank(parent_type(A))
ArrayInterface.axes(A::OffsetArrays.OffsetArray) = Base.axes(A)
ArrayInterface.axes(A::OffsetArrays.OffsetArray, dim::Integer) = Base.axes(A, dim)
end
end

Expand Down
158 changes: 158 additions & 0 deletions src/axes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@

"""
axes_types(::Type{T}, dim)
Returns the axis type along dimension `dim`.
"""
axes_types(x, dim) = axes_types(typeof(x), dim)
@inline axes_types(::Type{T}, dim) where {T} = axes_types(T, to_dims(T, dim))
@inline function axes_types(::Type{T}, dim::StaticInt{D}) where {T,D}
if D > ndims(T)
return OptionallyStaticUnitRange{One,One}
else
return _get_tuple(axes_types(T), dim)
end
end
@inline function axes_types(::Type{T}, dim::Int) where {T}
if dim > ndims(T)
return OptionallyStaticUnitRange{One,One}
else
return axes_types(T).parameters[dim]
end
end

"""
axes_types(::Type{T}) -> Type
Returns the type of the axes for `T`
"""
axes_types(x) = axes_types(typeof(x))
axes_types(::Type{T}) where {T<:Array} = Tuple{Vararg{OneTo{Int},ndims(T)}}
function axes_types(::Type{T}) where {T}
if parent_type(T) <: T
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}}
else
return axes_types(parent_type(T))
end
end
function axes_types(::Type{T}) where {T<:VecAdjTrans}
return Tuple{OptionallyStaticUnitRange{One,One},axes_types(parent_type(T), One())}
end
function axes_types(::Type{T}) where {T<:MatAdjTrans}
return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T))
end
function axes_types(::Type{T}) where {T<:PermutedDimsArray}
return eachop_tuple(_get_tuple, axes_types(parent_type(T)), to_parent_dims(T))
end
function axes_types(::Type{T}) where {T<:AbstractRange}
if known_length(T) === nothing
return Tuple{OptionallyStaticUnitRange{One,Int}}
else
return Tuple{OptionallyStaticUnitRange{One,StaticInt{known_length(T)}}}
end
end
function axes_types(::Type{T}) where {N,T<:Base.ReshapedArray{<:Any,N}}
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},N}}
end

_int_or_static_int(::Nothing) = Int
_int_or_static_int(x::Int) = StaticInt{x}

@inline function axes_types(::Type{T}) where {N,P,I,T<:SubArray{<:Any,N,P,I}}
return eachop_tuple(_sub_axis_type, T, to_parent_dims(T))
end
@inline function _sub_axis_type(::Type{A}, dim::StaticInt) where {T,N,P,I,A<:SubArray{T,N,P,I}}
return OptionallyStaticUnitRange{
_int_or_static_int(known_first(axes_types(P, dim))),
_int_or_static_int(known_length(_get_tuple(I, dim)))
}
end

function axes_types(::Type{R}) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
if _is_reshaped(R)
if sizeof(S) === sizeof(T)
return axes_types(A)
elseif sizeof(S) > sizeof(T)
return eachop_tuple(_reshaped_axis_type, R, to_parent_dims(R))
else
return eachop_tuple(axes_types, A, to_parent_dims(R))
end
else
return eachop_tuple(_non_reshaped_axis_type, R, to_parent_dims(R))
end
end

function _reshaped_axis_type(::Type{R}, dim::StaticInt) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
return axes_types(parent_type(R), dim)
end
function _reshaped_axis_type(::Type{R}, dim::Zero) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
return OptionallyStaticUnitRange{One,StaticInt{div(sizeof(S), sizeof(T))}}
end

function _non_reshaped_axis_type(::Type{R}, dim::StaticInt) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
return axes_types(parent_type(R), dim)
end
function _non_reshaped_axis_type(::Type{R}, dim::One) where {T,N,S,A,R<:ReinterpretArray{T,N,S,A}}
paxis = axes_types(A, dim)
len = known_length(paxis)
if len === nothing
raxis = OptionallyStaticUnitRange{One,Int}
else
raxis = OptionallyStaticUnitRange{One,StaticInt{div(len * sizeof(S), sizeof(T))}}
end
return similar_type(paxis, Int, raxis)
end

#=
similar_type(orignal_type, new_data_type)
=#
similar_type(::Type{OneTo{Int}}, ::Type{Int}, ::Type{OneTo{Int}}) = OneTo{Int}
similar_type(::Type{OneTo{Int}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,Int}}) = OneTo{Int}
similar_type(::Type{OneTo{Int}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,StaticInt{N}}}) where {N} = OptionallyStaticUnitRange{One,StaticInt{N}}

similar_type(::Type{OptionallyStaticUnitRange{One,Int}}, ::Type{Int}, ::Type{OneTo{Int}}) = OptionallyStaticUnitRange{One,Int}
similar_type(::Type{OptionallyStaticUnitRange{One,Int}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,Int}}) = OptionallyStaticUnitRange{One,Int}
similar_type(::Type{OptionallyStaticUnitRange{One,Int}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,StaticInt{N}}}) where {N} = OptionallyStaticUnitRange{One,StaticInt{N}}

similar_type(::Type{OptionallyStaticUnitRange{One,StaticInt{N}}}, ::Type{Int}, ::Type{OneTo{Int}}) where {N} = OptionallyStaticUnitRange{One,Int}
similar_type(::Type{OptionallyStaticUnitRange{One,StaticInt{N}}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,Int}}) where {N} = OptionallyStaticUnitRange{One,Int}
similar_type(::Type{OptionallyStaticUnitRange{One,StaticInt{N1}}}, ::Type{Int}, ::Type{OptionallyStaticUnitRange{One,StaticInt{N2}}}) where {N1,N2} = OptionallyStaticUnitRange{One,StaticInt{N2}}

"""
axes(A, d)
Return a valid range that maps to each index along dimension `d` of `A`.
"""
axes(a, dim) = axes(a, to_dims(a, dim))
function axes(a::A, dim::Integer) where {A}
if parent_type(A) <: A
return Base.axes(a, Int(dim))
else
return axes(parent(a), to_parent_dims(A, dim))
end
end
axes(A::SubArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
axes(A::ReinterpretArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version
axes(A::Base.ReshapedArray, dim::Integer) = Base.axes(A, Int(dim)) # TODO implement ArrayInterface version

"""
axes(A)
Return a tuple of ranges where each range maps to each element along a dimension of `A`.
"""
@inline function axes(a::A) where {A}
if parent_type(A) <: A
return Base.axes(a)
else
return axes(parent(a))
end
end
axes(A::PermutedDimsArray) = permute(axes(parent(A)), to_parent_dims(A))
function axes(A::Union{Transpose,Adjoint})
p = parent(A)
return (axes(p, StaticInt(2)), axes(p, One()))
end
axes(A::SubArray) = Base.axes(A) # TODO implement ArrayInterface version
axes(A::ReinterpretArray) = Base.axes(A) # TODO implement ArrayInterface version
axes(A::Base.ReshapedArray) = Base.axes(A) # TODO implement ArrayInterface version

Loading

0 comments on commit baccd56

Please sign in to comment.