Skip to content

Commit

Permalink
Unify reshape methods accepting Colon and OneTo
Browse files Browse the repository at this point in the history
  • Loading branch information
jishnub committed Dec 17, 2024
1 parent da9f934 commit edda303
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 18 deletions.
2 changes: 2 additions & 0 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -835,11 +835,13 @@ similar(a::AbstractArray, ::Type{T}, dims::Dims{N}) where {T,N} = Array{T,N}(
to_shape(::Tuple{}) = ()
to_shape(dims::Dims) = dims
to_shape(dims::DimsOrInds) = map(to_shape, dims)::DimsOrInds
to_shape(dims::Tuple{Vararg{Union{Integer, AbstractUnitRange, Colon}}}) = map(to_shape, dims)
# each dimension
to_shape(i::Int) = i
to_shape(i::Integer) = Int(i)
to_shape(r::OneTo) = Int(last(r))
to_shape(r::AbstractUnitRange) = r
to_shape(r::Colon) = r

"""
similar(storagetype, axes)
Expand Down
36 changes: 19 additions & 17 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,31 +119,31 @@ julia> reshape(1:6, 2, 3)
"""
reshape

reshape(parent::AbstractArray, dims::IntOrInd...) = reshape(parent, dims)
reshape(parent::AbstractArray, shp::Tuple{Union{Integer,OneTo}, Vararg{Union{Integer,OneTo}}}) = reshape(parent, to_shape(shp))
reshape(parent::AbstractArray, dims::Tuple{Integer, Vararg{Integer}}) = reshape(parent, map(Int, dims))
# we collect the vararg indices and only define methods for tuples of indices
reshape(parent::AbstractArray, dims::Union{Integer,Colon,AbstractUnitRange}...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Tuple{Vararg{Integer}}) = reshape(parent, map(Int, dims))
reshape(parent::AbstractArray, dims::Dims) = _reshape(parent, dims)

# Allow missing dimensions with Colon():
reshape(parent::AbstractVector, ::Colon) = parent
reshape(parent::AbstractVector, ::Tuple{Colon}) = parent
reshape(parent::AbstractArray, dims::Int...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Integer...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Union{Integer,Colon}...) = reshape(parent, dims)
reshape(parent::AbstractArray, dims::Tuple{Vararg{Union{Integer,Colon}}}) = reshape(parent, _reshape_uncolon(parent, dims))

@noinline throw1(dims) = throw(DimensionMismatch(LazyString("new dimensions ", dims,
# convert axes to sizes using to_shape, and convert colons to sizes using _reshape_uncolon
# We add a level of indirection to avoid method ambiguities in reshape
reshape(parent::AbstractArray, dims::Tuple{Vararg{Union{Integer,Colon,OneTo}}}) = _reshape_maybecolon(parent, dims)
_reshape_maybecolon(parent::AbstractVector, ::Tuple{Colon}) = parent
_reshape_maybecolon(parent::AbstractArray, dims::Tuple{Vararg{Union{Integer,Colon,OneTo}}}) = reshape(parent, _reshape_uncolon(length(parent), to_shape(dims)))

@noinline _reshape_throwcolon(dims) = throw(DimensionMismatch(LazyString("new dimensions ", dims,
" may have at most one omitted dimension specified by `Colon()`")))
@noinline throw2(lenA, dims) = throw(DimensionMismatch(string("array size ", lenA,
@noinline _reshape_throwsize(lenA, dims) = throw(DimensionMismatch(LazyString("array size ", lenA,
" must be divisible by the product of the new dimensions ", dims)))

@inline function _reshape_uncolon(A, _dims::Tuple{Vararg{Union{Integer, Colon}}})
_reshape_uncolon(len, ::Tuple{Colon}) = len
@inline function _reshape_uncolon(len, _dims::Tuple{Vararg{Union{Integer, Colon}}})
# promote the dims to `Int` at least
dims = map(x -> x isa Colon ? x : promote_type(typeof(x), Int)(x), _dims)
dims isa Tuple{Vararg{Integer}} && return dims
pre = _before_colon(dims...)
post = _after_colon(dims...)
_any_colon(post...) && throw1(dims)
len = length(A)
_any_colon(post...) && _reshape_throwcolon(dims)
_reshape_uncolon_computesize(len, dims, pre, post)
end
@inline function _reshape_uncolon_computesize(len::Int, dims, pre::Tuple{Vararg{Int}}, post::Tuple{Vararg{Int}})
Expand All @@ -167,18 +167,20 @@ end
(pre..., sz, post...)
end
@inline function _reshape_uncolon_computesize_nonempty(len, dims, pr)
iszero(pr) && throw2(len, dims)
iszero(pr) && _reshape_throwsize(len, dims)
(quo, rem) = divrem(len, pr)
iszero(rem) || throw2(len, dims)
iszero(rem) || _reshape_throwsize(len, dims)
quo
end
@inline _any_colon() = false
@inline _any_colon(dim::Colon, tail...) = true
@inline _any_colon(dim::Any, tail...) = _any_colon(tail...)
@inline _before_colon(dim::Any, tail...) = (dim, _before_colon(tail...)...)
@inline _before_colon(dim::Colon, tail...) = ()
@inline _before_colon() = ()
@inline _after_colon(dim::Any, tail...) = _after_colon(tail...)
@inline _after_colon(dim::Colon, tail...) = tail
@inline _after_colon() = ()

reshape(parent::AbstractArray{T,N}, ndims::Val{N}) where {T,N} = parent
function reshape(parent::AbstractArray, ndims::Val{N}) where N
Expand Down
3 changes: 3 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2205,6 +2205,9 @@ end
@test b isa Matrix{Int}
@test b.ref === a.ref
end
C = reshape(CartesianIndices((2,2)), big(4))
@test axes(C, 1) == 1:4
@test C == CartesianIndex.([(1,1), (2,1), (1,2), (2,2)])
end
@testset "AbstractArrayMath" begin
@testset "IsReal" begin
Expand Down
10 changes: 10 additions & 0 deletions test/offsetarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,16 @@ end
a = OffsetArray(4:5, 5:6)
@test reshape(a, :) === a
@test reshape(a, (:,)) === a
R = reshape(zeros(6), 2, :)
@test R isa Matrix
@test axes(R) == (1:2, 1:3)
R = reshape(zeros(6,1), 2, :)
@test R isa Matrix
@test axes(R) == (1:2, 1:3)
R = reshape(zeros(6), 2:3, :)
@test axes(R) == (2:3, 1:3)
R = reshape(zeros(6,1), 2:3, :)
@test axes(R) == (2:3, 1:3)
end

@testset "stack" begin
Expand Down
1 change: 0 additions & 1 deletion test/testhelpers/OffsetArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,6 @@ _similar_axes_or_length(A, T, ax::I, ::I) where {I} = similar(A, T, map(_indexle
_similar_axes_or_length(AT, ax::I, ::I) where {I} = similar(AT, map(_indexlength, ax))

# reshape accepts a single colon
Base.reshape(A::AbstractArray, inds::OffsetAxis...) = reshape(A, inds)
function Base.reshape(A::AbstractArray, inds::Tuple{Vararg{OffsetAxis}})
AR = reshape(no_offset_view(A), map(_indexlength, inds))
O = OffsetArray(AR, map(_offset, axes(AR), inds))
Expand Down

0 comments on commit edda303

Please sign in to comment.