Skip to content

Commit

Permalink
Int sizes + axes tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Tokazama committed Feb 15, 2021
1 parent b360be7 commit 93dd634
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -775,14 +775,14 @@ include("stridelayout.jl")

abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end

Base.size(A::AbstractArray2) = ArrayInterface.size(A)
Base.size(A::AbstractArray2, dim) = ArrayInterface.size(A, dim)
Base.size(A::AbstractArray2) = map(Int, ArrayInterface.size(A))
Base.size(A::AbstractArray2, dim) = Int(ArrayInterface.size(A, dim))

Base.axes(A::AbstractArray2) = ArrayInterface.axes(A)
Base.axes(A::AbstractArray2, dim) = ArrayInterface.axes(A, dim)

Base.strides(A::AbstractArray2) = ArrayInterface.strides(A)
Base.strides(A::AbstractArray2, dim) = ArrayInterface.strides(A, dim)
Base.strides(A::AbstractArray2) = map(Int, ArrayInterface.strides(A))
Base.strides(A::AbstractArray2, dim) = Int(ArrayInterface.strides(A, dim))

function Base.length(A::AbstractArray2)
len = known_length(A)
Expand Down
1 change: 1 addition & 0 deletions src/axes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ end
Returns the type of the axes for `T`
"""
axes_types(x) = axes_types(typeof(x))
axes_types(::Type{T}) where {T<:Array} = Tuple{Vararg{OneTo{Int},ndims(T)}}
function axes_types(::Type{T}) where {T}
if parent_type(T) <: T
return Tuple{Vararg{OptionallyStaticUnitRange{One,Int},ndims(T)}}
Expand Down
3 changes: 3 additions & 0 deletions test/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ val_has_dimnames(x) = Val(ArrayInterface.has_dimnames(x))
@test @inferred(dimnames(view(x, :, :, :))) === (static(:x),static(:y), static(:_))
@test @inferred(dimnames(view(x, :, 1, :))) === (static(:x), static(:_))
@test @inferred(dimnames(x, ArrayInterface.One())) === static(:x)
@test @inferred(dimnames(parent(x), ArrayInterface.One())) === static(:_)
end

@testset "to_dims" begin
Expand All @@ -121,6 +122,7 @@ end
@test @inferred(ArrayInterface.to_dims(x, (:y, :x))) == (2, 1)
@test @inferred(ArrayInterface.to_dims(x, :x)) == 1
@test @inferred(ArrayInterface.to_dims(x, :y)) == 2
@test_throws DimensionMismatch ArrayInterface.to_dims(x, static(:z)) # not found
@test_throws DimensionMismatch ArrayInterface.to_dims(x, :z) # not found
end

Expand All @@ -140,6 +142,7 @@ end
@test @inferred(ArrayInterface.size(y')) == (1, size(parent(x), 1))
@test @inferred(axes(x, first(d))) == axes(parent(x), 1)
@test strides(x, :x) == ArrayInterface.strides(parent(x))[1]
@test @inferred(ArrayInterface.axes_types(x, static(:x))) <: Base.OneTo{Int}

x[x = 1] = [2, 3]
@test @inferred(getindex(x, x = 1)) == [2, 3]
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ end
@test @inferred(ArrayInterface.known_size(Rnr, static(1))) === 4
@test @inferred(ArrayInterface.known_size(Ar)) === (nothing,nothing, nothing,)
@test @inferred(ArrayInterface.known_size(Ar, static(1))) === nothing
@test @inferred(ArrayInterface.known_size(Ar, static(4))) === 1

@test @inferred(ArrayInterface.known_size(S)) === (2, 3, 4)
@test @inferred(ArrayInterface.known_size(Wrapper(S))) === (2, 3, 4)
Expand Down

0 comments on commit 93dd634

Please sign in to comment.