Skip to content

Commit

Permalink
Unpirate Union{}[] (#685)
Browse files Browse the repository at this point in the history
* Duplicate SA[] function signatures to avoid Union{} special case
* Test that Union{}[] and friends don't hit the StaticArrays code paths

Co-Authored-By: Chris Foster <chris42f@gmail.com>
  • Loading branch information
tkf and c42f committed Nov 7, 2019
1 parent cb46c2d commit 7819e3a
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
18 changes: 14 additions & 4 deletions src/initializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,16 @@ const SA_F64 = SA{Float64}
@inline similar_type(::Type{SA}, ::Size{S}) where {S} = SArray{Tuple{S...}}
@inline similar_type(::Type{SA{T}}, ::Size{S}) where {T,S} = SArray{Tuple{S...}, T}

@inline Base.getindex(sa::Type{<:SA}, xs...) = similar_type(sa, Size(length(xs)))(xs)
@inline Base.typed_vcat(sa::Type{<:SA}, xs::Number...) = similar_type(sa, Size(length(xs)))(xs)
@inline Base.typed_hcat(sa::Type{<:SA}, xs::Number...) = similar_type(sa, Size(1,length(xs)))(xs)
# These definitions are duplicated to avoid matching `sa === Union{}` in the
# neater-looking alternative `sa::Type{<:SA}`.
@inline Base.getindex(sa::Type{SA}, xs...) = similar_type(sa, Size(length(xs)))(xs)
@inline Base.getindex(sa::Type{SA{T}}, xs...) where T = similar_type(sa, Size(length(xs)))(xs)

@inline Base.typed_vcat(sa::Type{SA}, xs::Number...) = similar_type(sa, Size(length(xs)))(xs)
@inline Base.typed_vcat(sa::Type{SA{T}}, xs::Number...) where T = similar_type(sa, Size(length(xs)))(xs)

@inline Base.typed_hcat(sa::Type{SA}, xs::Number...) = similar_type(sa, Size(1,length(xs)))(xs)
@inline Base.typed_hcat(sa::Type{SA{T}}, xs::Number...) where T = similar_type(sa, Size(1,length(xs)))(xs)

Base.@pure function _SA_hvcat_transposed_size(rows)
M = rows[1]
Expand All @@ -40,7 +47,7 @@ Base.@pure function _SA_hvcat_transposed_size(rows)
Size(M, length(rows))
end

@inline function Base.typed_hvcat(sa::Type{<:SA}, rows::Dims, xs::Number...)
@inline function _SA_typed_hvcat(sa, rows, xs)
msize = _SA_hvcat_transposed_size(rows)
if msize === nothing
throw(ArgumentError("SA[...] matrix rows of length $rows are inconsistent"))
Expand All @@ -49,3 +56,6 @@ end
transpose(similar_type(sa, msize)(xs))
end

@inline Base.typed_hvcat(sa::Type{SA}, rows::Dims, xs::Number...) = _SA_typed_hvcat(sa, rows, xs)
@inline Base.typed_hvcat(sa::Type{SA{T}}, rows::Dims, xs::Number...) where T = _SA_typed_hvcat(sa, rows, xs)

10 changes: 10 additions & 0 deletions test/initializers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,14 @@ SA_test_hvcat(x,T) = SA{T}[1 x x;
@test SA_F64[1, 2] === SVector{2,Float64}((1,2))
@test SA_F32[1, 2] === SVector{2,Float32}((1,2))

# https://github.com/JuliaArrays/StaticArrays.jl/pull/685
@test Union{}[] isa Vector{Union{}}
@test Base.typed_vcat(Union{}) isa Vector{Union{}}
@test Base.typed_hcat(Union{}) isa Vector{Union{}}
@test Base.typed_hvcat(Union{}, ()) isa Vector{Union{}}
@test_throws MethodError Union{}[1]
@test_throws MethodError Union{}[1 2]
@test_throws MethodError Union{}[1; 2]
@test_throws MethodError Union{}[1 2; 3 4]

end

0 comments on commit 7819e3a

Please sign in to comment.