Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make return type of map inferrable with heterogeneous arrays #42046

Merged
merged 2 commits into from
Sep 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -739,18 +739,19 @@ if isdefined(Core, :Compiler)
I = esc(itr)
return quote
if $I isa Generator && ($I).f isa Type
($I).f
T = ($I).f
else
Core.Compiler.return_type(_iterator_upper_bound, Tuple{typeof($I)})
T = Core.Compiler.return_type(_iterator_upper_bound, Tuple{typeof($I)})
end
promote_typejoin_union(T)
end
end
else
macro default_eltype(itr)
I = esc(itr)
return quote
if $I isa Generator && ($I).f isa Type
($I).f
promote_typejoin_union($I.f)
else
Any
end
Expand All @@ -775,8 +776,12 @@ function collect(itr::Generator)
return _array_for(et, itr.iter, isz)
end
v1, st = y
arr = _array_for(typeof(v1), itr.iter, isz, shape)
return collect_to_with_first!(arr, v1, itr, st)
dest = _array_for(typeof(v1), itr.iter, isz, shape)
# The typeassert gives inference a helping hand on the element type and dimensionality
# (work-around for #28382)
et′ = et <: Type ? Type : et
RT = dest isa AbstractArray ? AbstractArray{<:et′, ndims(dest)} : Any
collect_to_with_first!(dest, v1, itr, st)::RT
end
end

Expand Down
46 changes: 1 addition & 45 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Module containing the broadcasting implementation.
module Broadcast

using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, @pure,
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, promote_typejoin_union, @pure,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__, BroadcastFunction
Expand Down Expand Up @@ -713,50 +713,6 @@ eltypes(t::Tuple{Any}) = Tuple{_broadcast_getindex_eltype(t[1])}
eltypes(t::Tuple{Any,Any}) = Tuple{_broadcast_getindex_eltype(t[1]), _broadcast_getindex_eltype(t[2])}
eltypes(t::Tuple) = Tuple{_broadcast_getindex_eltype(t[1]), eltypes(tail(t)).types...}

function promote_typejoin_union(::Type{T}) where T
if T === Union{}
return Union{}
elseif T isa UnionAll
return Any # TODO: compute more precise bounds
elseif T isa Union
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
elseif T <: Tuple
return typejoin_union_tuple(T)
else
return T
end
end

@pure function typejoin_union_tuple(T::Type)
u = Base.unwrap_unionall(T)
u isa Union && return typejoin(
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
p = (u::DataType).parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
c = Vector{Any}(undef, lr)
for i = 1:lr
pi = p[i]
U = Core.Compiler.unwrapva(pi)
if U === Union{}
ci = Union{}
elseif U isa Union
ci = typejoin(U.a, U.b)
else
ci = U
end
if i == lr && Core.Compiler.isvarargtype(pi)
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
else
c[i] = ci
end
end
return Base.rewrap_unionall(Tuple{c...}, T)
end

# Inferred eltype of result of broadcast(f, args...)
combine_eltypes(f, args::Tuple) =
promote_typejoin_union(Base._return_type(f, eltypes(args)))
Expand Down
44 changes: 44 additions & 0 deletions base/promotion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,50 @@ function promote_typejoin(@nospecialize(a), @nospecialize(b))
end
_promote_typesubtract(@nospecialize(a)) = typesplit(a, Union{Nothing, Missing})

function promote_typejoin_union(::Type{T}) where T
if T === Union{}
return Union{}
elseif T isa UnionAll
return Any # TODO: compute more precise bounds
elseif T isa Union
return promote_typejoin(promote_typejoin_union(T.a), promote_typejoin_union(T.b))
elseif T <: Tuple
return typejoin_union_tuple(T)
else
return T
end
end

function typejoin_union_tuple(T::Type)
@_pure_meta
u = Base.unwrap_unionall(T)
u isa Union && return typejoin(
typejoin_union_tuple(Base.rewrap_unionall(u.a, T)),
typejoin_union_tuple(Base.rewrap_unionall(u.b, T)))
p = (u::DataType).parameters
lr = length(p)::Int
if lr == 0
return Tuple{}
end
c = Vector{Any}(undef, lr)
for i = 1:lr
pi = p[i]
U = Core.Compiler.unwrapva(pi)
if U === Union{}
ci = Union{}
elseif U isa Union
ci = typejoin(U.a, U.b)
else
ci = U
end
if i == lr && Core.Compiler.isvarargtype(pi)
c[i] = isdefined(pi, :N) ? Vararg{ci, pi.N} : Vararg{ci}
else
c[i] = ci
end
end
return Base.rewrap_unionall(Tuple{c...}, T)
end

# Returns length, isfixed
function full_va_len(p)
Expand Down
4 changes: 0 additions & 4 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -991,10 +991,6 @@ end
@test Core.Compiler.return_type(broadcast, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test isequal([1, 2] + [3.0, missing], [4.0, missing])
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test Core.Compiler.return_type(+, Tuple{Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
Expand Down
22 changes: 22 additions & 0 deletions test/generic_map_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,28 @@ function generic_map_tests(mapf, inplace_mapf=nothing)
@test A == map(x->x*x*x, Float64[1:10...])
@test A === B
end

# Issue #28382: inferrability of map with Union eltype
@test isequal(map(+, [1, 2], [3.0, missing]), [4.0, missing])
@test Core.Compiler.return_type(map, Tuple{typeof(+), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Union{Vector{Missing}, Vector{Union{Missing, Float64}}, Vector{Float64}}
@test isequal(map(tuple, [1, 2], [3.0, missing]), [(1, 3.0), (2, missing)])
@test Core.Compiler.return_type(map, Tuple{typeof(tuple), Vector{Int},
Vector{Union{Float64, Missing}}}) ==
Vector{<:Tuple{Int, Any}}
# Check that corner cases do not throw an error
@test isequal(map(x -> x === 1 ? nothing : x, [1, 2, missing]),
[nothing, 2, missing])
@test isequal(map(x -> x === 1 ? nothing : x, Any[1, 2, 3.0, missing]),
[nothing, 2, 3, missing])
@test map((x,y)->(x==1 ? 1.0 : x, y), [1, 2, 3], ["a", "b", "c"]) ==
[(1.0, "a"), (2, "b"), (3, "c")]
@test map(typeof, [iszero, isdigit]) == [typeof(iszero), typeof(isdigit)]
@test map(typeof, [iszero, iszero]) == [typeof(iszero), typeof(iszero)]
@test isequal(map(identity, Vector{<:Union{Int, Missing}}[[1, 2],[missing, 1]]),
[[1, 2],[missing, 1]])
@test map(x -> x < 0 ? false : x, Int[]) isa Vector{Integer}
end

function testmap_equivalence(mapf, f, c...)
Expand Down
1 change: 1 addition & 0 deletions test/sets.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ using Dates
@test isa(Set(sin(x) for x = 1:3), Set{Float64})
@test isa(Set(f17741(x) for x = 1:3), Set{Int})
@test isa(Set(f17741(x) for x = -1:1), Set{Integer})
@test isa(Set(f17741(x) for x = 1:0), Set{Integer})
end
let s1 = Set(["foo", "bar"]), s2 = Set(s1)
@test s1 == s2
Expand Down