From ceab9ea74106fa65e46548e693322b82d640a7c1 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Mon, 17 Aug 2020 13:27:37 -0500 Subject: [PATCH] Improve shape inference in `cat` `cat` is frequently called with poor inference, since one only has to concatenate a couple of different container types before inference punts on the result type. While this does not make the return type in mixed container types fully inferrable, it does improve the analysis of the shape. --- base/abstractarray.jl | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 0ba8fbebacf91..e26eb5be263b0 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -1437,15 +1437,15 @@ vcat(V::AbstractVector{T}...) where {T} = typed_vcat(T, V...) AbstractVecOrTuple{T} = Union{AbstractVector{<:T}, Tuple{Vararg{T}}} function _typed_vcat(::Type{T}, V::AbstractVecOrTuple{AbstractVector}) where T - n::Int = 0 + n = 0 for Vk in V - n += length(Vk) + n += Int(length(Vk))::Int end a = similar(V[1], T, n) pos = 1 - for k=1:length(V) + for k=1:Int(length(V))::Int Vk = V[k] - p1 = pos+length(Vk)-1 + p1 = pos + Int(length(Vk))::Int - 1 a[pos:p1] = Vk pos = p1+1 end @@ -1507,7 +1507,7 @@ function _typed_vcat(::Type{T}, A::AbstractVecOrTuple{AbstractVecOrMat}) where T pos = 1 for k=1:nargs Ak = A[k] - p1 = pos+size(Ak,1)-1 + p1 = pos+size(Ak,1)::Int-1 B[pos:p1, :] = Ak pos = p1+1 end @@ -1585,17 +1585,18 @@ end _cat(dims, X...) = cat_t(promote_eltypeof(X...), X...; dims=dims) @inline cat_t(::Type{T}, X...; dims) where {T} = _cat_t(dims, T, X...) -@inline function _cat_t(dims, T::Type, X...) +@inline function _cat_t(dims, ::Type{T}, X...) where {T} catdims = dims2cat(dims) - shape = cat_shape(catdims, map(cat_size, X)) + shape = cat_shape(catdims, map(cat_size, X)::Tuple{Vararg{Union{Int,Dims}}})::Dims A = cat_similar(X[1], T, shape) - if count(!iszero, catdims) > 1 + if count(!iszero, catdims)::Int > 1 fill!(A, zero(T)) end return __cat(A, shape, catdims, X...) end -function __cat(A, shape::NTuple{N}, catdims, X...) where N +function __cat(A, shape::NTuple{M,Int}, catdims, X...) where M + N = M::Int offsets = zeros(Int, N) inds = Vector{UnitRange{Int}}(undef, N) concat = copyto!(zeros(Bool, N), catdims) @@ -1702,8 +1703,8 @@ julia> hcat(x, [1; 2; 3]) """ hcat(X...) = cat(X...; dims=Val(2)) -typed_vcat(T::Type, X...) = cat_t(T, X...; dims=Val(1)) -typed_hcat(T::Type, X...) = cat_t(T, X...; dims=Val(2)) +typed_vcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(1)) +typed_hcat(::Type{T}, X...) where T = cat_t(T, X...; dims=Val(2)) """ cat(A...; dims=dims)