From 6829bca75f459f511623d9e854d601644432a5f2 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Thu, 17 Sep 2020 15:58:15 -0700 Subject: [PATCH 1/3] Fix output shape of map and collect --- src/ThreadsX.jl | 3 ++- src/map.jl | 15 +++++++++++---- test/test_with_base.jl | 7 +++++++ 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/src/ThreadsX.jl b/src/ThreadsX.jl index 1840d681..1ea6a2d0 100644 --- a/src/ThreadsX.jl +++ b/src/ThreadsX.jl @@ -38,7 +38,8 @@ module Implementations import SplittablesBase using ArgCheck: @argcheck, @check using BangBang: SingletonVector, append!!, push!!, union!! -using Base: Ordering, add_sum, mapreduce_empty, mul_prod, reduce_empty +using Base: + HasShape, IteratorSize, Ordering, add_sum, mapreduce_empty, mul_prod, reduce_empty using ConstructionBase: setproperties using InitialValues: asmonoid using Referenceables: referenceable diff --git a/src/map.jl b/src/map.jl index 49a95eae..294cee72 100644 --- a/src/map.jl +++ b/src/map.jl @@ -3,10 +3,15 @@ __map(f, itr; kwargs...) = __map(f, itrs...; kwargs...) = tcollect(MapSplat(f), zip(itrs...); basesize = default_basesize(itrs[1]), kwargs...) +reshape_as(ys, xs) = reshape_as(ys, xs, IteratorSize(xs)) +reshape_as(ys, xs, ::IteratorSize) = ys +reshape_as(ys, xs, ::HasShape) = reshape(ys, size(xs)) +reshape_as(::Empty{T}, xs, ::HasShape) where {T} = T(undef, size(xs)...) + function _map(f, itr, itrs...; kwargs...) ys = __map(f, itr, itrs...; kwargs...) isempty(ys) && return map(f, itr, itrs...) - return ys + return reshape_as(ys, itr) end ThreadsX.map(f, itr, itrs...; kwargs...) = _map(f, itr, itrs...; kwargs...) @@ -36,8 +41,10 @@ end struct ConvertTo{T} end (::ConvertTo{T})(x) where {T} = convert(T, x) -ThreadsX.collect(::Type{T}, itr; kwargs...) where {T} = - tcopy(Map(ConvertTo{T}()), Vector{T}, itr; basesize = default_basesize(itr), kwargs...) +ThreadsX.collect(::Type{T}, itr; kwargs...) where {T} = reshape_as( + tcopy(Map(ConvertTo{T}()), Vector{T}, itr; basesize = default_basesize(itr), kwargs...), + itr, +) ThreadsX.collect(itr; kwargs...) = - tcollect(itr; basesize = default_basesize(itr), kwargs...) + reshape_as(tcollect(itr; basesize = default_basesize(itr), kwargs...), itr) diff --git a/test/test_with_base.jl b/test/test_with_base.jl index 1c4f3bc3..e5408c5f 100644 --- a/test/test_with_base.jl +++ b/test/test_with_base.jl @@ -1,5 +1,6 @@ module TestWithBase +using Base: splat using Test using ThreadsX @@ -10,9 +11,14 @@ inc(x) = x + 1 raw_testdata = """ collect(1:10) collect(Float64, 1:10) +collect(x for x in 1:10 if isodd(x)) +collect(Float64, (x for x in 1:10 if isodd(x))) collect(inc(x) for x in 1:10) collect(Float64, (inc(x) for x in 1:10)) +collect(x * y for x in 1:10, y in 11:20) +collect(Float64, (x * y for x in 1:10, y in 11:20)) map(inc, 1:10) +map(inc, (x for x in 1:10 if isodd(x))) map(inc, Float64[]) map(inc, ones(3, 3)) map(inc, ones(3, 0)) @@ -21,6 +27,7 @@ map(*, 1:10, 11:20) map(*, ones(3, 3), ones(3, 3)) map(*, ones(3, 0), ones(3, 0)) map(*, ones(0, 3), ones(0, 3)) +map(splat(*), Iterators.product(1:10, 11:20)) reduce(+, 1:10) reduce(+, 1:0) reduce(+, Bool[]) From 26bfade38909a8df90226c56d1cae8dd3ac91cf0 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Thu, 17 Sep 2020 20:00:31 -0700 Subject: [PATCH 2/3] Fix map and collect on filter etc. default_basesize(xs) now works with eduction and iterator transforms. --- src/basesizes.jl | 4 ++++ src/utils.jl | 3 --- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/basesizes.jl b/src/basesizes.jl index fa4ec7d4..67d6fd13 100644 --- a/src/basesizes.jl +++ b/src/basesizes.jl @@ -1,3 +1,7 @@ +default_basesize(n::Integer) = max(1, cld(n, (8 * Threads.nthreads()))) +default_basesize(xs) = + default_basesize(SplittablesBase.amount(last(extract_transducer(xs)))) + default_basesize(_, _, xs) = default_basesize(xs::AbstractArray) # TODO: handle `Base.Fix2` etc. diff --git a/src/utils.jl b/src/utils.jl index b0b60cf0..311e721e 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,3 @@ -default_basesize(n::Integer) = max(1, cld(n, (8 * Threads.nthreads()))) -default_basesize(xs) = default_basesize(length(xs)) - function adhoc_partition(xs, n) @check firstindex(xs) == 1 m = cld(length(xs), n) From 18e956f9ff538771dd9e01c3542629735fe0db20 Mon Sep 17 00:00:00 2001 From: Takafumi Arakaki Date: Thu, 17 Sep 2020 20:12:25 -0700 Subject: [PATCH 3/3] Test and fix empty cases --- src/map.jl | 3 ++- test/test_with_base.jl | 3 +++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/map.jl b/src/map.jl index 294cee72..27b9b98a 100644 --- a/src/map.jl +++ b/src/map.jl @@ -6,7 +6,8 @@ __map(f, itrs...; kwargs...) = reshape_as(ys, xs) = reshape_as(ys, xs, IteratorSize(xs)) reshape_as(ys, xs, ::IteratorSize) = ys reshape_as(ys, xs, ::HasShape) = reshape(ys, size(xs)) -reshape_as(::Empty{T}, xs, ::HasShape) where {T} = T(undef, size(xs)...) +reshape_as(::Empty{T}, xs, isize::HasShape) where {T<:AbstractVector} = + reshape_as(T(undef, length(xs)), xs, isize) function _map(f, itr, itrs...; kwargs...) ys = __map(f, itr, itrs...; kwargs...) diff --git a/test/test_with_base.jl b/test/test_with_base.jl index e5408c5f..d135dbb5 100644 --- a/test/test_with_base.jl +++ b/test/test_with_base.jl @@ -17,6 +17,8 @@ collect(inc(x) for x in 1:10) collect(Float64, (inc(x) for x in 1:10)) collect(x * y for x in 1:10, y in 11:20) collect(Float64, (x * y for x in 1:10, y in 11:20)) +collect(x * y for x in 1:0, y in 11:20) +collect(Float64, (x * y for x in 1:0, y in 11:20)) map(inc, 1:10) map(inc, (x for x in 1:10 if isodd(x))) map(inc, Float64[]) @@ -28,6 +30,7 @@ map(*, ones(3, 3), ones(3, 3)) map(*, ones(3, 0), ones(3, 0)) map(*, ones(0, 3), ones(0, 3)) map(splat(*), Iterators.product(1:10, 11:20)) +map(splat(*), Iterators.product(1:0, 11:20)) reduce(+, 1:10) reduce(+, 1:0) reduce(+, Bool[])