diff --git a/src/ArrayInterface.jl b/src/ArrayInterface.jl index a9e9007c4..5fe7f1429 100644 --- a/src/ArrayInterface.jl +++ b/src/ArrayInterface.jl @@ -789,7 +789,7 @@ function Base.length(A::AbstractArray2) if len === nothing return prod(size(A)) else - return static(len) + return Int(len) end end diff --git a/src/dimensions.jl b/src/dimensions.jl index 73617a4b1..84b73c399 100644 --- a/src/dimensions.jl +++ b/src/dimensions.jl @@ -29,15 +29,17 @@ Returns the mapping from parent dimensions to child dimensions. """ from_parent_dims(x) = from_parent_dims(typeof(x)) from_parent_dims(::Type{T}) where {T} = nstatic(Val(ndims(T))) -from_parent_dims(::Type{T}) where {T<:Union{Transpose,Adjoint}} = (StaticInt(2), One()) +from_parent_dims(::Type{T}) where {T<:VecAdjTrans} = (StaticInt(2),) +from_parent_dims(::Type{T}) where {T<:MatAdjTrans} = (StaticInt(2), One()) from_parent_dims(::Type{<:SubArray{T,N,A,I}}) where {T,N,A,I} = _from_sub_dims(A, I) -@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,N,I<:Tuple{Vararg{Any,N}}} +@generated function _from_sub_dims(::Type{A}, ::Type{I}) where {A,I<:Tuple} out = Expr(:tuple) - n = 1 - for p in I.parameters + dim_i = 1 + for i in 1:ndims(A) + p = I.parameters[i] if argdims(A, p) > 0 - push!(out.args, :(StaticInt($n))) - n += 1 + push!(out.args, :(StaticInt($dim_i))) + dim_i += 1 else push!(out.args, :(StaticInt(0))) end diff --git a/test/dimensions.jl b/test/dimensions.jl index 1b3f53af0..03a5f48f6 100644 --- a/test/dimensions.jl +++ b/test/dimensions.jl @@ -20,7 +20,7 @@ function ArrayInterface.dimnames(::Type{T}, dim) where {L,T<:NamedDimsWrapper{L} return static(L[dim]) end end -ArrayInterface.has_dimnames(::Type{T}) where {T} = true +ArrayInterface.has_dimnames(::Type{T}) where {T<:NamedDimsWrapper} = true Base.parent(x::NamedDimsWrapper) = x.parent @testset "dimension permutations" begin @@ -47,9 +47,7 @@ Base.parent(x::NamedDimsWrapper) = x.parent @test @inferred(ArrayInterface.from_parent_dims(typeof(mview))) == (1, 0, 2) @test @inferred(ArrayInterface.from_parent_dims(typeof(madj))) == (2, 1) @test @inferred(ArrayInterface.from_parent_dims(typeof(vview))) == (0, 1) - @test @inferred(ArrayInterface.from_parent_dims(typeof(vadj))) == (2, 1) - @test @inferred(ArrayInterface.from_parent_dims(typeof(vadj), static(1))) == 2 - @test @inferred(ArrayInterface.from_parent_dims(typeof(vadj), 1)) == 2 + @test @inferred(ArrayInterface.from_parent_dims(typeof(vadj))) == (2,) @test @inferred(ArrayInterface.from_parent_dims(typeof(vadj), static(1))) == 2 @test @inferred(ArrayInterface.from_parent_dims(typeof(vadj), 1)) == 2 @@ -61,16 +59,20 @@ Base.parent(x::NamedDimsWrapper) = x.parent if VERSION ≥ v"1.6.0-DEV.1581" colormat = reinterpret(reshape, Float64, [(R = rand(), G = rand(), B = rand()) for i ∈ 1:100]) - @test @inferred(ArrayInterface.from_parent_dims(colormat)) === (static(2),) + @test @inferred(ArrayInterface.from_parent_dims(typeof(colormat))) === (static(2),) + @test @inferred(ArrayInterface.to_parent_dims(typeof(colormat))) === (static(0), static(1),) Rr = reinterpret(reshape, Int32, ones(4)) - @test @inferred(ArrayInterface.from_parent_dims(Rr)) === (static(2),) + @test @inferred(ArrayInterface.from_parent_dims(typeof(Rr))) === (static(2),) + @test @inferred(ArrayInterface.to_parent_dims(typeof(Rr))) === (static(0), static(1),) Rr = reinterpret(reshape, Int64, ones(4)) - @test @inferred(ArrayInterface.from_parent_dims(Rr)) === (static(1),) + @test @inferred(ArrayInterface.from_parent_dims(typeof(Rr))) === (static(1),) + @test @inferred(ArrayInterface.to_parent_dims(typeof(Rr))) === (static(1),) Sr = reinterpret(reshape, Complex{Int64}, zeros(2, 3, 4)) - @test @inferred(ArrayInterface.from_parent_dims(Sr)) === (static(0), static(1), static(2)) + @test @inferred(ArrayInterface.from_parent_dims(typeof(Sr))) === (static(0), static(1), static(2)) + @test @inferred(ArrayInterface.to_parent_dims(typeof(Sr))) === (static(2), static(3)) end end @@ -97,9 +99,12 @@ val_has_dimnames(x) = Val(ArrayInterface.has_dimnames(x)) y = NamedDimsWrapper{(:x,)}(ones(2)); dnums = ntuple(+, length(d)) @test @inferred(val_has_dimnames(x)) === Val(true) + @test @inferred(ArrayInterface.has_dimnames(ones(2,2))) === false + @test @inferred(ArrayInterface.has_dimnames(Array{Int,2})) === false @test @inferred(val_has_dimnames(typeof(x))) === Val(true) @test @inferred(val_has_dimnames(typeof(view(x, :, 1, :)))) === Val(true) @test @inferred(dimnames(x)) === d + @test @inferred(dimnames(parent(x))) === (static(:_), static(:_)) @test @inferred(dimnames(x')) === reverse(d) @test @inferred(dimnames(y')) === (static(:_), static(:x)) @test @inferred(dimnames(PermutedDimsArray(x, (2, 1)))) === reverse(d) @@ -143,6 +148,7 @@ end @test @inferred(axes(x, first(d))) == axes(parent(x), 1) @test strides(x, :x) == ArrayInterface.strides(parent(x))[1] @test @inferred(ArrayInterface.axes_types(x, static(:x))) <: Base.OneTo{Int} + @test ArrayInterface.axes_types(x, :x) <: Base.OneTo{Int} x[x = 1] = [2, 3] @test @inferred(getindex(x, x = 1)) == [2, 3] diff --git a/test/runtests.jl b/test/runtests.jl index d18e9dbeb..44d2d1a4d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -529,6 +529,7 @@ end @test @inferred(ArrayInterface.strides(M)) == strides(M) @test @inferred(ArrayInterface.strides(Mp)) == strides(Mp) @test @inferred(ArrayInterface.strides(Mp2)) == strides(Mp2) + @test_throws MethodError ArrayInterface.strides(DummyZeros(3,4)) @test @inferred(ArrayInterface.known_strides(A)) === (1, nothing, nothing) @test @inferred(ArrayInterface.known_strides(Ap)) === (1, nothing)