Skip to content

Commit

Permalink
type-stable
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 13, 2021
1 parent 96dce9e commit 17f851b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
9 changes: 5 additions & 4 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2788,19 +2788,20 @@ function mapslices(f, A::AbstractArray; dims)
Aslice = A[idx1...]
r1 = f(Aslice)

if r1 isa AbstractArray && ndims(r1) > 0
res1 = if r1 isa AbstractArray && ndims(r1) > 0
n = sum(dim_mask)
if ndims(r1) > n && any(ntuple(d -> size(r1,d+n)>1, ndims(r1)-n))
s = size(r1)[1:n]
throw(DimensionMismatch("cannot assign slice f(x) of size $(size(r1)) into output of size $s"))
end
res1 = r1
r1
else
# If the result of f on a single slice is a scalar then we add singleton
# dimensions. When adding the dimensions, we have to respect the
# index type of the input array (e.g. in the case of OffsetArrays)
res1 = similar(Aslice, typeof(r1), reduced_indices(Aslice, 1:ndims(Aslice)))
res1[begin] = r1
_res1 = similar(Aslice, typeof(r1), reduced_indices(Aslice, 1:ndims(Aslice)))
_res1[begin] = r1
_res1
end

# Determine result size and allocate. We always pad ndims(res1) out to length(dims):
Expand Down
6 changes: 3 additions & 3 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,7 @@ end
@test mapslices(nnz, sparse(1.0I, 3, 3), dims=1) == [1 1 1]

r = rand(Int8, 4,5,2)
@test mapslices(transpose, r, dims=(1,3)) == permutedims(r, (3,2,1))
@test @inferred(mapslices(transpose, r, dims=(1,3))) == permutedims(r, (3,2,1))
@test vec(mapslices(repr, r, dims=(2,1))) == map(repr, eachslice(r, dims=3))
@test mapslices(cumsum, sparse(r[:,:,1]), dims=1) == cumsum(r[:,:,1], dims=1)
@test mapslices(prod, sparse(r[:,:,1]), dims=1) == prod(r[:,:,1], dims=1)
Expand All @@ -1210,8 +1210,8 @@ end
@test_throws ArgumentError mapslices(identity, rand(2,3), dims=0) # previously BoundsError
@test_throws ArgumentError mapslices(identity, rand(2,3), dims=(1,3)) # previously BoundsError
@test_throws DimensionMismatch mapslices(x -> x * x', rand(2,3), dims=1) # explicitly caught
@test mapslices(hcat, [1 2; 3 4], dims=1) == [1 2; 3 4] # previously an error, now allowed
@test mapslices(identity, [1 2; 3 4], dims=(2,2)) == [1 2; 3 4] # previously an error
@test @inferred(mapslices(hcat, [1 2; 3 4], dims=1)) == [1 2; 3 4] # previously an error, now allowed
@test @inferred(mapslices(identity, [1 2; 3 4], dims=(2,2))) == [1 2; 3 4] # previously an error
end

@testset "single multidimensional index" begin
Expand Down

0 comments on commit 17f851b

Please sign in to comment.