Skip to content

Commit

Permalink
Fixed dimnames and from_parent_dims tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Tokazama committed Feb 16, 2021
1 parent 93dd634 commit 5387002
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/ArrayInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -789,7 +789,7 @@ function Base.length(A::AbstractArray2)
if len === nothing
return prod(size(A))
else
return static(len)
return Int(len)
end
end

Expand Down
14 changes: 8 additions & 6 deletions src/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,17 @@ Returns the mapping from parent dimensions to child dimensions.
"""
from_parent_dims(x) = from_parent_dims(typeof(x))
from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T)))
from_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One())
from_parent_dims(::Type{T}) where {T<:VecAdjTrans} = (StaticInt(2),)
from_parent_dims(::Type{T}) where {T<:MatAdjTrans} = (StaticInt(2), One())
from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I)
@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}}
@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,I<:Tuple}
out = Expr(:tuple)
n = 1
for p in I.parameters
dim_i = 1
for i in 1:ndims(A)
p = I.parameters[i]
if argdims(A, p) > 0
push!(out.args, :(StaticInt($n)))
n += 1
push!(out.args, :(StaticInt($dim_i)))
dim_i += 1
else
push!(out.args, :(StaticInt(0)))
end
Expand Down
22 changes: 14 additions & 8 deletions test/dimensions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ function ArrayInterface.dimnames(::Type{T}, dim) where {L,T<:NamedDimsWrapper{L}
return static(L[dim])
end
end
ArrayInterface.has_dimnames(::Type{T}) where {T} = true
ArrayInterface.has_dimnames(::Type{T}) where {T<:NamedDimsWrapper} = true
Base.parent(x::NamedDimsWrapper) = x.parent

@testset "dimension permutations" begin
Expand All @@ -47,9 +47,7 @@ Base.parent(x::NamedDimsWrapper) = x.parent
@test @inferred(ArrayInterface.from_parent_dims(typeof(mview))) == (1, 0, 2)
@test @inferred(ArrayInterface.from_parent_dims(typeof(madj))) == (2, 1)
@test @inferred(ArrayInterface.from_parent_dims(typeof(vview))) == (0, 1)
@test @inferred(ArrayInterface.from_parent_dims(typeof(vadj))) == (2, 1)
@test @inferred(ArrayInterface.from_parent_dims(typeof(vadj), static(1))) == 2
@test @inferred(ArrayInterface.from_parent_dims(typeof(vadj), 1)) == 2
@test @inferred(ArrayInterface.from_parent_dims(typeof(vadj))) == (2,)
@test @inferred(ArrayInterface.from_parent_dims(typeof(vadj), static(1))) == 2
@test @inferred(ArrayInterface.from_parent_dims(typeof(vadj), 1)) == 2

Expand All @@ -61,16 +59,20 @@ Base.parent(x::NamedDimsWrapper) = x.parent

if VERSION v"1.6.0-DEV.1581"
colormat = reinterpret(reshape, Float64, [(R = rand(), G = rand(), B = rand()) for i 1:100])
@test @inferred(ArrayInterface.from_parent_dims(colormat)) === (static(2),)
@test @inferred(ArrayInterface.from_parent_dims(typeof(colormat))) === (static(2),)
@test @inferred(ArrayInterface.to_parent_dims(typeof(colormat))) === (static(0), static(1),)

Rr = reinterpret(reshape, Int32, ones(4))
@test @inferred(ArrayInterface.from_parent_dims(Rr)) === (static(2),)
@test @inferred(ArrayInterface.from_parent_dims(typeof(Rr))) === (static(2),)
@test @inferred(ArrayInterface.to_parent_dims(typeof(Rr))) === (static(0), static(1),)

Rr = reinterpret(reshape, Int64, ones(4))
@test @inferred(ArrayInterface.from_parent_dims(Rr)) === (static(1),)
@test @inferred(ArrayInterface.from_parent_dims(typeof(Rr))) === (static(1),)
@test @inferred(ArrayInterface.to_parent_dims(typeof(Rr))) === (static(1),)

Sr = reinterpret(reshape, Complex{Int64}, zeros(2, 3, 4))
@test @inferred(ArrayInterface.from_parent_dims(Sr)) === (static(0), static(1), static(2))
@test @inferred(ArrayInterface.from_parent_dims(typeof(Sr))) === (static(0), static(1), static(2))
@test @inferred(ArrayInterface.to_parent_dims(typeof(Sr))) === (static(2), static(3))
end
end

Expand All @@ -97,9 +99,12 @@ val_has_dimnames(x) = Val(ArrayInterface.has_dimnames(x))
y = NamedDimsWrapper{(:x,)}(ones(2));
dnums = ntuple(+, length(d))
@test @inferred(val_has_dimnames(x)) === Val(true)
@test @inferred(ArrayInterface.has_dimnames(ones(2,2))) === false
@test @inferred(ArrayInterface.has_dimnames(Array{Int,2})) === false
@test @inferred(val_has_dimnames(typeof(x))) === Val(true)
@test @inferred(val_has_dimnames(typeof(view(x, :, 1, :)))) === Val(true)
@test @inferred(dimnames(x)) === d
@test @inferred(dimnames(parent(x))) === (static(:_), static(:_))
@test @inferred(dimnames(x')) === reverse(d)
@test @inferred(dimnames(y')) === (static(:_), static(:x))
@test @inferred(dimnames(PermutedDimsArray(x, (2, 1)))) === reverse(d)
Expand Down Expand Up @@ -143,6 +148,7 @@ end
@test @inferred(axes(x, first(d))) == axes(parent(x), 1)
@test strides(x, :x) == ArrayInterface.strides(parent(x))[1]
@test @inferred(ArrayInterface.axes_types(x, static(:x))) <: Base.OneTo{Int}
@test ArrayInterface.axes_types(x, :x) <: Base.OneTo{Int}

x[x = 1] = [2, 3]
@test @inferred(getindex(x, x = 1)) == [2, 3]
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ end
@test @inferred(ArrayInterface.strides(M)) == strides(M)
@test @inferred(ArrayInterface.strides(Mp)) == strides(Mp)
@test @inferred(ArrayInterface.strides(Mp2)) == strides(Mp2)
@test_throws MethodError ArrayInterface.strides(DummyZeros(3,4))

@test @inferred(ArrayInterface.known_strides(A)) === (1, nothing, nothing)
@test @inferred(ArrayInterface.known_strides(Ap)) === (1, nothing)
Expand Down

0 comments on commit 5387002

Please sign in to comment.