From 29a805502c7ae86e3a1b87864b20e6a39db4728e Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Wed, 18 Aug 2021 17:32:37 +0200 Subject: [PATCH 1/2] Fix collect on stateful generator --- base/array.jl | 18 +++++++++++++++--- test/iterators.jl | 11 ++++++++++- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/base/array.jl b/base/array.jl index 15c354dce6085..32c42af0315da 100644 --- a/base/array.jl +++ b/base/array.jl @@ -766,13 +766,25 @@ function collect(itr::Generator) et = @default_eltype(itr) if isa(isz, SizeUnknown) return grow_to!(Vector{et}(), itr) - else + elseif isa(isz, HasLength) + len = length(itr) + y = iterate(itr) + if y === nothing + return et[] + end + v1, st = y + return collect_to_with_first!(Vector{typeof(v1)}(undef, len), v1, itr, st) + elseif isa(isz, HasShape) + axs = axes(itr) y = iterate(itr) if y === nothing - return _array_for(et, itr.iter, isz) + return similar(Array{et,length(axs)}, axs) end v1, st = y - collect_to_with_first!(_array_for(typeof(v1), itr.iter, isz), v1, itr, st) + arr = similar(Array{typeof(v1),length(axs)}, axs) + return collect_to_with_first!(arr, v1, itr, st) + else + error("unreachable") end end diff --git a/test/iterators.jl b/test/iterators.jl index fb8edcab92209..c7d00c4e7e2e8 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -292,6 +292,15 @@ let (a, b) = (1:3, [4 6; end end +# collect stateful iterator +let + itr = (i+1 for i in Base.Stateful([1,2,3])) + @test collect(itr) == [2, 3, 4] + A = zeros(Int, 0, 0) + itr = (i-1 for i in Base.Stateful(A)) + @test collect(itr) == Int[] # Stateful do not preserve shape +end + # with 1D inputs let a = 1:2, b = 1.0:10.0, @@ -860,4 +869,4 @@ end @test Iterators.peel(1:10)[2] |> collect == 2:10 @test Iterators.peel(x^2 for x in 2:4)[1] == 4 @test Iterators.peel(x^2 for x in 2:4)[2] |> collect == [9, 16] -end \ No newline at end of file +end From d7f347826b88af90ad9584d3288a6ecec0d3c4a1 Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Thu, 19 Aug 2021 10:09:25 +0200 Subject: [PATCH 2/2] Fix type instability in generator collect --- base/array.jl | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/base/array.jl b/base/array.jl index 32c42af0315da..7febf2a63828f 100644 --- a/base/array.jl +++ b/base/array.jl @@ -758,33 +758,25 @@ else end end -_array_for(::Type{T}, itr, ::HasLength) where {T} = Vector{T}(undef, Int(length(itr)::Integer)) -_array_for(::Type{T}, itr, ::HasShape{N}) where {T,N} = similar(Array{T,N}, axes(itr)) +_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr)) +_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr)) +_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len) +_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs) function collect(itr::Generator) isz = IteratorSize(itr.iter) et = @default_eltype(itr) if isa(isz, SizeUnknown) return grow_to!(Vector{et}(), itr) - elseif isa(isz, HasLength) - len = length(itr) - y = iterate(itr) - if y === nothing - return et[] - end - v1, st = y - return collect_to_with_first!(Vector{typeof(v1)}(undef, len), v1, itr, st) - elseif isa(isz, HasShape) - axs = axes(itr) + else + shape = isz isa HasLength ? length(itr) : axes(itr) y = iterate(itr) if y === nothing - return similar(Array{et,length(axs)}, axs) + return _array_for(et, itr.iter, isz) end v1, st = y - arr = similar(Array{typeof(v1),length(axs)}, axs) + arr = _array_for(typeof(v1), itr.iter, isz, shape) return collect_to_with_first!(arr, v1, itr, st) - else - error("unreachable") end end