Skip to content

Commit

Permalink
fix tuple handling code w.r.t. kind types/vararg types
Browse files Browse the repository at this point in the history
  • Loading branch information
jrevels committed Jun 25, 2018
1 parent 6e0120d commit c16a109
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 87 deletions.
40 changes: 9 additions & 31 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,33 +141,6 @@ end
argtail(x, rest...) = rest
tail(x::Tuple) = argtail(x...)

# TODO: a better / more infer-able definition would pehaps be
# tuple_type_head(T::Type) = fieldtype(T::Type{<:Tuple}, 1)
tuple_type_head(T::UnionAll) = (@_pure_meta; UnionAll(T.var, tuple_type_head(T.body)))
function tuple_type_head(T::Union)
@_pure_meta
return Union{tuple_type_head(T.a), tuple_type_head(T.b)}
end
function tuple_type_head(T::DataType)
@_pure_meta
T.name === Tuple.name || throw(MethodError(tuple_type_head, (T,)))
return unwrapva(T.parameters[1])
end

tuple_type_tail(T::UnionAll) = (@_pure_meta; UnionAll(T.var, tuple_type_tail(T.body)))
function tuple_type_tail(T::Union)
@_pure_meta
return Union{tuple_type_tail(T.a), tuple_type_tail(T.b)}
end
function tuple_type_tail(T::DataType)
@_pure_meta
T.name === Tuple.name || throw(MethodError(tuple_type_tail, (T,)))
if isvatuple(T) && length(T.parameters) == 1
return T
end
return Tuple{argtail(T.parameters...)...}
end

tuple_type_cons(::Type, ::Type{Union{}}) = Union{}
function tuple_type_cons(::Type{S}, ::Type{T}) where T<:Tuple where S
@_pure_meta
Expand Down Expand Up @@ -243,10 +216,15 @@ function typename(a::Union)
end
typename(union::UnionAll) = typename(union.body)

convert(::Type{T}, x::T) where {T<:Tuple{Any, Vararg{Any}}} = x
convert(::Type{Tuple{}}, x::Tuple{Any, Vararg{Any}}) = throw(MethodError(convert, (Tuple{}, x)))
convert(::Type{T}, x::Tuple{Any, Vararg{Any}}) where {T<:Tuple} =
(convert(tuple_type_head(T), x[1]), convert(tuple_type_tail(T), tail(x))...)
convert(::Type{T}, x::T) where {T<:Tuple} = x

tuple_convert_check(T, x) = isvatuple(T) || nfields(x) === fieldcount(T) || throw(MethodError(convert, (T, x)))

function convert(::Type{T}, x::NTuple{N,Any}) where {T<:Tuple,N}
tuple_convert_check(T, x)
# TODO: this is inferring Unions for concrete converts
return ntuple(i -> convert(fieldtype(T, i), x[i]), Val(N))
end

# TODO: the following definitions are equivalent (behaviorally) to the above method
# I think they may be faster / more efficient for inference,
Expand Down
33 changes: 21 additions & 12 deletions base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ import ..@__MODULE__, ..parentmodule
const Base = parentmodule(@__MODULE__)
using .Base:
@inline, Pair, AbstractDict, IndexLinear, IndexCartesian, IndexStyle, AbstractVector, Vector,
tail, tuple_type_head, tuple_type_tail, tuple_type_cons, SizeUnknown, HasLength, HasShape,
IsInfinite, EltypeUnknown, HasEltype, OneTo, @propagate_inbounds, Generator, AbstractRange,
LinearIndices, (:), |, +, -, !==, !, <=, <, missing
tail, tuple_type_cons, SizeUnknown, HasLength, HasShape, IsInfinite, EltypeUnknown, HasEltype,
OneTo, @propagate_inbounds, Generator, AbstractRange, LinearIndices, (:), |, +, -, !==, !, <=,
<, missing

import .Base:
first, last,
Expand Down Expand Up @@ -749,9 +749,12 @@ julia> collect(Iterators.product(1:2, 3:5))
"""
product(iters...) = ProductIterator(iters)

IteratorSize(::Type{ProductIterator{Tuple{}}}) = HasShape{0}()
IteratorSize(::Type{ProductIterator{T}}) where {T<:Tuple} =
prod_iteratorsize( IteratorSize(tuple_type_head(T)), IteratorSize(ProductIterator{tuple_type_tail(T)}) )
function IteratorSize(::Type{ProductIterator{T}}) where {T<:Tuple}
if isvatuple(T)
throw(ArgumentError("Cannot compute IteratorSize for ProductIterator{$T}"))
end
return reduce(prod_iteratorsize, HasShape{0}(), ntuple(i -> IteratorSize(fieldtype(T, i))), fieldcount(T))
end

prod_iteratorsize(::HasLength, ::HasLength) = HasShape{2}()
prod_iteratorsize(::HasLength, ::HasShape{N}) where {N} = HasShape{N+1}()
Expand Down Expand Up @@ -787,15 +790,21 @@ _length(p::ProductIterator) = prod(map(unsafe_length, axes(p)))
IteratorEltype(::Type{ProductIterator{Tuple{}}}) = HasEltype()
IteratorEltype(::Type{ProductIterator{Tuple{I}}}) where {I} = IteratorEltype(I)
function IteratorEltype(::Type{ProductIterator{T}}) where {T<:Tuple}
I = tuple_type_head(T)
P = ProductIterator{tuple_type_tail(T)}
IteratorEltype(I) == EltypeUnknown() ? EltypeUnknown() : IteratorEltype(P)
if isvatuple(T)
throw(ArgumentError("Cannot compute IteratorEltype for ProductIterator{$T}"))
elseif any(ntuple(i -> IteratorEltype(fieldtype(T, i)) == EltypeUnknown(), fieldcount(T)))
return EltypeUnknown()
end
return HasEltype()
end

eltype(::Type{<:ProductIterator{I}}) where {I} = _prod_eltype(I)
_prod_eltype(::Type{Tuple{}}) = Tuple{}
_prod_eltype(::Type{I}) where {I<:Tuple} =
Base.tuple_type_cons(eltype(tuple_type_head(I)),_prod_eltype(tuple_type_tail(I)))
function _prod_eltype(::Type{I}) where {I<:Tuple}
if isvatuple(I)
throw(ArgumentError("Cannot compute _prod_eltype for tuple type $I"))
end
return Tuple{ntuple(i -> eltype(fieldtype(I, i)), fieldcount(I))...}
end

iterate(::ProductIterator{Tuple{}}) = (), true
iterate(::ProductIterator{Tuple{}}, state) = nothing
Expand Down
2 changes: 0 additions & 2 deletions base/precompile.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,6 @@ precompile(Tuple{typeof(Base.take!), Base.Channel{Any}})
precompile(Tuple{typeof(Base.task_done_hook), Task})
precompile(Tuple{typeof(Base.time_print), UInt64, Int64, Int64, Int64})
precompile(Tuple{typeof(Base.truncate), Base.GenericIOBuffer{Array{UInt8, 1}}, Int64})
precompile(Tuple{typeof(Base.tuple_type_head), Type{Tuple{Vararg{Int64, N}} where N}})
precompile(Tuple{typeof(Base.tuple_type_tail), Type{Tuple{Vararg{Int64, N}} where N}})
precompile(Tuple{typeof(Base.typed_vcat), Type{Any}, Array{Symbol, 1}, Array{Symbol, 1}, Array{Symbol, 1}, Array{Symbol, 1}, Array{Symbol, 1}, Array{String, 1}})
precompile(Tuple{typeof(Base.typeinfo_eltype), Type{Any}})
precompile(Tuple{typeof(Base.unique), Array{Any, 1}})
Expand Down
2 changes: 2 additions & 0 deletions base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ include("iterators.jl")
using .Iterators: zip, enumerate
using .Iterators: Flatten, product # for generators

include("tupleconstructor.jl") # constructing a Tuple from an iterator

include("namedtuple.jl")

# numeric operations
Expand Down
42 changes: 0 additions & 42 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,48 +219,6 @@ fill_to_length(t::Tuple{}, val, ::Val{2}) = (val, val)
# return (t..., ntuple(i -> val, N - length(t))...)
#end

# constructing from an iterator

# only define these in Base, to avoid overwriting the constructors
# NOTE: this means this constructor must be avoided in Core.Compiler!
if nameof(@__MODULE__) === :Base

(::Type{T})(x::Tuple) where {T<:Tuple} = convert(T, x) # still use `convert` for tuples

# resolve ambiguity between preceding and following methods
All16{E,N}(x::Tuple) where {E,N} = convert(All16{E,N}, x)

function (T::All16{E,N})(itr) where {E,N}
len = N+16
elts = collect(E, Iterators.take(itr,len))
if length(elts) != len
_totuple_err(T)
end
(elts...,)
end

(::Type{T})(itr) where {T<:Tuple} = _totuple(T, itr)

_totuple(::Type{Tuple{}}, itr, s...) = ()

function _totuple_err(@nospecialize T)
@_noinline_meta
throw(ArgumentError("too few elements for tuple type $T"))
end

function _totuple(T, itr, s...)
@_inline_meta
y = iterate(itr, s...)
y === nothing && _totuple_err(T)
(convert(tuple_type_head(T), y[1]), _totuple(tuple_type_tail(T), itr, y[2])...)
end

_totuple(::Type{Tuple{Vararg{E}}}, itr, s...) where {E} = (collect(E, Iterators.rest(itr,s...))...,)

_totuple(::Type{Tuple}, itr, s...) = (collect(Iterators.rest(itr,s...))...,)

end

## comparison ##

isequal(t1::Tuple, t2::Tuple) = (length(t1) == length(t2)) && _isequal(t1, t2)
Expand Down
60 changes: 60 additions & 0 deletions base/tupleconstructor.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

# NOTE: This is in a separate file from tuple.jl so that it can be conditionally included
# in sysimg.jl after the necessary machinery for `@generated` is defined. Furthermore,
# this code must not be loaded into Core.Compiler.

if nameof(@__MODULE__) === :Base

(::Type{T})(x::Tuple) where {T<:Tuple} = convert(T, x) # still use `convert` for tuples

# resolve ambiguity between preceding and following methods
All16{E,N}(x::Tuple) where {E,N} = convert(All16{E,N}, x)

function (T::All16{E,N})(itr) where {E,N}
len = N+16
elts = collect(E, Iterators.take(itr,len))
if length(elts) != len
_totuple_err(T)
end
(elts...,)
end

@generated function (::Type{T})(itr)::T where {T<:Tuple}
tuple_expr = Expr(:tuple)
if isvatuple(T)
t = unwrap_unionall(T)
n = length(t.parameters) - 1
else
n = fieldcount(T)
end
for i in 1:n
push!(tuple_expr.args, quote
if done(itr, state)
_totuple_err(T)
else
item, state = next(itr, state)
convert(fieldtype(T, $i), item)
end
end)
end
if isvatuple(T)
V = rewrap_unionall(unwrap_unionall(t.parameters[n + 1]), T)
U = unwrapva(V)
if n == 0 # then avoid creating a redundant iterator
return :((collect($U, itr)...,))
end
push!(tuple_expr.args, Expr(:..., :(collect($U, Iterators.rest(itr, state)))))
end
return quote
state = start(itr)
$tuple_expr
end
end

function _totuple_err(@nospecialize T)
@_noinline_meta
throw(ArgumentError("too few elements for tuple type $T"))
end

end # nameof(@__MODULE__) === :Base
2 changes: 2 additions & 0 deletions test/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ end
@test Tuple{Vararg{Float32}}(Float64[1,2,3]) === (1.0f0, 2.0f0, 3.0f0)
@test Tuple{Int,Vararg{Float32}}(Float64[1,2,3]) === (1, 2.0f0, 3.0f0)
@test Tuple{Int,Vararg{Any}}(Float64[1,2,3]) === (1, 2.0, 3.0)
@test (Tuple{Vararg{T}} where T<:AbstractFloat)([1,2,3]) === (1.0, 2.0, 3.0)
@test (Tuple{Tuple{T,T},Vararg{T}} where T<:AbstractFloat)([(1,2.0),3,4]) === ((1.0, 2.0), 3.0, 4.0)
@test Tuple(fill(1.,5)) === (1.0,1.0,1.0,1.0,1.0)
@test_throws MethodError convert(Tuple, fill(1.,5))

Expand Down

0 comments on commit c16a109

Please sign in to comment.