Skip to content

Commit

Permalink
Check that axes start with 1 for AbstractRange operations
Browse files Browse the repository at this point in the history
Now that we have Base.IdentityUnitRange and Base.Slice, we need to be
careful about fallbacks that just use `first`, `step`, `stop`-style
properties.
  • Loading branch information
timholy committed Feb 4, 2019
1 parent 589b96d commit a5c4c15
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 26 deletions.
38 changes: 25 additions & 13 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Module containing the broadcasting implementation.
module Broadcast

using .Base.Cartesian
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin,
using .Base: Indices, OneTo, tail, to_shape, isoperator, promote_typejoin, require_one_based_indexing,
_msk_end, unsafe_bitgetindex, bitcache_chunks, bitcache_size, dumpbitcache, unalias
import .Base: copy, copyto!, axes
export broadcast, broadcast!, BroadcastStyle, broadcast_axes, broadcastable, dotview, @__dot__
Expand Down Expand Up @@ -1002,15 +1002,20 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::OrdinalRange) = r
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen) = r
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange) = r

broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange) =
(require_one_based_indexing(r); range(-first(r), step=-step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen) = StepRangeLen(-r.ref, -r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange) = LinRange(-r.start, -r.stop, length(r))

broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) =
(require_one_based_indexing(r); range(x + first(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) =
(require_one_based_indexing(r); range(first(r) + x, length=length(r)))
# For #18336 we need to prevent promotion of the step type:
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) = range(first(r) + x, step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) = range(x + first(r), step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) =
(require_one_based_indexing(r); range(first(r) + x, step=step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) =
(require_one_based_indexing(r); range(x + first(r), step=step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen{T}, x::Number) where T =
StepRangeLen{typeof(T(r.ref)+x)}(r.ref + x, r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::StepRangeLen{T}) where T =
Expand All @@ -1019,9 +1024,12 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::LinRange, x::Number) = LinRa
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::LinRange) = LinRange(x + r.start, x + r.stop, length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r1::AbstractRange, r2::AbstractRange) = r1 + r2

broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) = range(first(r)-x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r)-x, step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x-first(r), step=-step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Number) =
(require_one_based_indexing(r); range(first(r)-x, length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) =
(require_one_based_indexing(r); range(first(r)-x, step=step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) =
(require_one_based_indexing(r); range(x-first(r), step=-step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen{T}, x::Number) where T =
StepRangeLen{typeof(T(r.ref)-x)}(r.ref - x, r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::StepRangeLen{T}) where T =
Expand All @@ -1030,22 +1038,26 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange, x::Number) = LinRa
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::LinRange) = LinRange(x - r.start, x - r.stop, length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r1::AbstractRange, r2::AbstractRange) = r1 - r2

broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) = range(x*first(r), step=x*step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::AbstractRange) =
(require_one_based_indexing(r); range(x*first(r), step=x*step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::StepRangeLen{T}) where {T} =
StepRangeLen{typeof(x*T(r.ref))}(x*r.ref, x*r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), x::Number, r::LinRange) = LinRange(x * r.start, x * r.stop, r.len)
# separate in case of noncommutative multiplication
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) = range(first(r)*x, step=step(r)*x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::AbstractRange, x::Number) =
(require_one_based_indexing(r); range(first(r)*x, step=step(r)*x, length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::StepRangeLen{T}, x::Number) where {T} =
StepRangeLen{typeof(T(r.ref)*x)}(r.ref*x, r.step*x, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(*), r::LinRange, x::Number) = LinRange(r.start * x, r.stop * x, r.len)

broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) = range(first(r)/x, step=step(r)/x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::AbstractRange, x::Number) =
(require_one_based_indexing(r); range(first(r)/x, step=step(r)/x, length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::StepRangeLen{T}, x::Number) where {T} =
StepRangeLen{typeof(T(r.ref)/x)}(r.ref/x, r.step/x, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(/), r::LinRange, x::Number) = LinRange(r.start / x, r.stop / x, r.len)

broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::AbstractRange) = range(x\first(r), step=x\step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::AbstractRange) =
(require_one_based_indexing(r); range(x\first(r), step=x\step(r), length=length(r)))
broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::StepRangeLen) = StepRangeLen(x\r.ref, x\r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(\), x::Number, r::LinRange) = LinRange(x \ r.start, x \ r.stop, r.len)

Expand Down
63 changes: 50 additions & 13 deletions base/range.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,14 @@ RangeStepStyle(::Type{<:AbstractRange{<:Integer}}) = RangeStepRegular()

convert(::Type{T}, r::AbstractRange) where {T<:AbstractRange} = r isa T ? r : T(r)

AxesStartStyle(::Type{<:AbstractRange}) = AxesStartAny()
AxesStartStyle(r::AbstractRange) = AxesStartStyle(typeof(r))

require_one_based_indexing(r::AbstractRange) = _require_one_based_indexing(AxesStartStyle(r), r)
_require_one_based_indexing(::AxesStartStyle, r) =
!has_offset_axes(r) || throw(ArgumentError("offset arrays are not supported but got an array with index other than 1"))
_require_one_based_indexing(::AxesStart1, r) = true

## ordinal ranges

"""
Expand Down Expand Up @@ -250,6 +258,8 @@ steprange_last_empty(start, step, stop) = start - step

StepRange(start::T, step::S, stop::T) where {T,S} = StepRange{T,S}(start, step, stop)

AxesStartStyle(::Type{<:StepRange}) = AxesStart1()

"""
UnitRange{T<:Real}
Expand Down Expand Up @@ -297,6 +307,8 @@ if isdefined(Main, :Base)
end
end

AxesStartStyle(::Type{<:UnitRange}) = AxesStart1()

"""
Base.OneTo(n)
Expand All @@ -318,6 +330,8 @@ end
OneTo(stop::T) where {T<:Integer} = OneTo{T}(stop)
OneTo(r::AbstractRange{T}) where {T<:Integer} = OneTo{T}(r)

AxesStartStyle(::Type{<:OneTo}) = AxesStart1()

## Step ranges parameterized by length

"""
Expand Down Expand Up @@ -350,6 +364,8 @@ StepRangeLen(ref::R, step::S, len::Integer, offset::Integer = 1) where {R,S} =
StepRangeLen{T}(ref::R, step::S, len::Integer, offset::Integer = 1) where {T,R,S} =
StepRangeLen{T,R,S}(ref, step, len, offset)

AxesStartStyle(::Type{<:StepRangeLen}) = AxesStart1()

## range with computed step

"""
Expand Down Expand Up @@ -387,6 +403,8 @@ function LinRange(start, stop, len::Integer)
LinRange{T}(start, stop, len)
end

AxesStartStyle(::Type{<:LinRange}) = AxesStart1()

function _range(start::T, ::Nothing, stop::S, len::Integer) where {T,S}
a, b = promote(start, stop)
_range(a, nothing, b, len)
Expand Down Expand Up @@ -713,10 +731,14 @@ show(io::IO, r::AbstractRange) = print(io, repr(first(r)), ':', repr(step(r)), '
show(io::IO, r::UnitRange) = print(io, repr(first(r)), ':', repr(last(r)))
show(io::IO, r::OneTo) = print(io, "Base.OneTo(", r.stop, ")")

range_axes_first_same(r, s) = _range_axes_first_same(AxesStartStyle(r), AxesStartStyle(s), r, s)
_range_axes_first_same(::AxesStart1, ::AxesStart1, r, s) = true
_range_axes_first_same(::AxesStartStyle, ::AxesStartStyle, r, s) = first(axes1(r)) == first(axes1(s))

==(r::T, s::T) where {T<:AbstractRange} =
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s))
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s)) & range_axes_first_same(r, s)
==(r::OrdinalRange, s::OrdinalRange) =
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s))
(first(r) == first(s)) & (step(r) == step(s)) & (last(r) == last(s)) & range_axes_first_same(r, s)
==(r::T, s::T) where {T<:Union{StepRangeLen,LinRange}} =
(first(r) == first(s)) & (length(r) == length(s)) & (last(r) == last(s))
==(r::Union{StepRange{T},StepRangeLen{T,T}}, s::Union{StepRange{T},StepRangeLen{T,T}}) where {T} =
Expand All @@ -727,6 +749,7 @@ function ==(r::AbstractRange, s::AbstractRange)
if lr != length(s)
return false
end
range_axes_first_same(r, s) || return false
yr, ys = iterate(r), iterate(s)
while yr !== nothing
yr[1] == ys[1] || return false
Expand Down Expand Up @@ -849,7 +872,7 @@ end

## linear operations on ranges ##

-(r::OrdinalRange) = range(-first(r), step=-step(r), length=length(r))
-(r::OrdinalRange) = (require_one_based_indexing(r); range(-first(r), step=-step(r), length=length(r)))
-(r::StepRangeLen{T,R,S}) where {T,R,S} =
StepRangeLen{T,R,S}(-r.ref, -r.step, length(r), r.offset)
-(r::LinRange) = LinRange(-r.start, -r.stop, length(r))
Expand All @@ -873,8 +896,10 @@ OneTo{T}(r::OneTo) where {T<:Integer} = OneTo{T}(r.stop)

promote_rule(a::Type{UnitRange{T1}}, ::Type{UR}) where {T1,UR<:AbstractUnitRange} =
promote_rule(a, UnitRange{eltype(UR)})
UnitRange{T}(r::AbstractUnitRange) where {T<:Real} = UnitRange{T}(first(r), last(r))
UnitRange(r::AbstractUnitRange) = UnitRange(first(r), last(r))
UnitRange{T}(r::AbstractUnitRange) where {T<:Real} =
(require_one_based_indexing(r); UnitRange{T}(first(r), last(r)))
UnitRange(r::AbstractUnitRange) =
(require_one_based_indexing(r); UnitRange(first(r), last(r)))

AbstractUnitRange{T}(r::AbstractUnitRange{T}) where {T} = r
AbstractUnitRange{T}(r::UnitRange) where {T} = UnitRange{T}(r)
Expand All @@ -889,10 +914,14 @@ StepRange{T1,T2}(r::StepRange{T1,T2}) where {T1,T2} = r

promote_rule(a::Type{StepRange{T1a,T1b}}, ::Type{UR}) where {T1a,T1b,UR<:AbstractUnitRange} =
promote_rule(a, StepRange{eltype(UR), eltype(UR)})
StepRange{T1,T2}(r::AbstractRange) where {T1,T2} =
function StepRange{T1,T2}(r::AbstractRange) where {T1,T2}
require_one_based_indexing(r)
StepRange{T1,T2}(convert(T1, first(r)), convert(T2, step(r)), convert(T1, last(r)))
StepRange(r::AbstractUnitRange{T}) where {T} =
end
function StepRange(r::AbstractUnitRange{T}) where {T}
require_one_based_indexing(r)
StepRange{T,T}(first(r), step(r), last(r))
end
(::Type{StepRange{T1,T2} where T1})(r::AbstractRange) where {T2} = StepRange{eltype(r),T2}(r)

promote_rule(::Type{StepRangeLen{T1,R1,S1}},::Type{StepRangeLen{T2,R2,S2}}) where {T1,T2,R1,R2,S1,S2} =
Expand All @@ -908,15 +937,16 @@ StepRangeLen{T}(r::StepRangeLen) where {T} =
promote_rule(a::Type{StepRangeLen{T,R,S}}, ::Type{OR}) where {T,R,S,OR<:AbstractRange} =
promote_rule(a, StepRangeLen{eltype(OR), eltype(OR), eltype(OR)})
StepRangeLen{T,R,S}(r::AbstractRange) where {T,R,S} =
StepRangeLen{T,R,S}(R(first(r)), S(step(r)), length(r))
(require_one_based_indexing(r); StepRangeLen{T,R,S}(R(first(r)), S(step(r)), length(r)))
StepRangeLen{T}(r::AbstractRange) where {T} =
StepRangeLen(T(first(r)), T(step(r)), length(r))
(require_one_based_indexing(r); StepRangeLen(T(first(r)), T(step(r)), length(r)))
StepRangeLen(r::AbstractRange) = StepRangeLen{eltype(r)}(r)

promote_rule(a::Type{LinRange{T1}}, b::Type{LinRange{T2}}) where {T1,T2} =
el_same(promote_type(T1,T2), a, b)
LinRange{T}(r::LinRange{T}) where {T} = r
LinRange{T}(r::AbstractRange) where {T} = LinRange{T}(first(r), last(r), length(r))
LinRange{T}(r::AbstractRange) where {T} =
(require_one_based_indexing(r); LinRange{T}(first(r), last(r), length(r)))
LinRange(r::AbstractRange{T}) where {T} = LinRange{T}(r)

promote_rule(a::Type{LinRange{T}}, ::Type{OR}) where {T,OR<:OrdinalRange} =
Expand Down Expand Up @@ -944,7 +974,10 @@ end
Array{T,1}(r::AbstractRange{T}) where {T} = vcat(r)
collect(r::AbstractRange) = vcat(r)

reverse(r::OrdinalRange) = (:)(last(r), -step(r), first(r))
function reverse(r::OrdinalRange)
require_one_based_indexing(r)
(:)(last(r), -step(r), first(r))
end
function reverse(r::StepRangeLen)
# If `r` is empty, `length(r) - r.offset + 1 will be nonpositive hence
# invalid. As `reverse(r)` is also empty, any offset would work so we keep
Expand All @@ -964,8 +997,11 @@ sort!(r::AbstractUnitRange) = r

sort(r::AbstractRange) = issorted(r) ? r : reverse(r)

sortperm(r::AbstractUnitRange) = 1:length(r)
sortperm(r::AbstractRange) = issorted(r) ? (1:1:length(r)) : (length(r):-1:1)
sortperm(r::AbstractUnitRange) = (require_one_based_indexing(r); 1:length(r))
function sortperm(r::AbstractRange)
require_one_based_indexing(r)
issorted(r) ? (1:1:length(r)) : (length(r):-1:1)
end

function sum(r::AbstractRange{<:Real})
l = length(r)
Expand Down Expand Up @@ -1004,6 +1040,7 @@ function _define_range_op(@nospecialize f)
r1l = length(r1)
(r1l == length(r2) ||
throw(DimensionMismatch("argument dimensions must match")))
require_one_based_indexing(r1, r2)
range($f(first(r1), first(r2)), step=$f(step(r1), step(r2)), length=r1l)
end

Expand Down
15 changes: 15 additions & 0 deletions base/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,18 @@ struct RangeStepRegular <: RangeStepStyle end # range with regular step
struct RangeStepIrregular <: RangeStepStyle end # range with rounding error

RangeStepStyle(instance) = RangeStepStyle(typeof(instance))

# trait that allows skipping of axes-checking on abstract range types (risks overflow on `length`)
"""
AxesStartStyle(instance)
AxesStartStyle(T::Type)
Indicate the value that `axes(instance)` starts with. Containers that return `AxesStart1()`
must have `axes(instance)` start with 1 (e.g., `Base.OneTo` axes). Such containers may
bypass axes checks for certain operations (e.g., range comparisons to avoid risk of overflow).
`AxesStartAny()` indicates that one cannot count on the axes starting with 1, and that
an explicit check is required.
"""
abstract type AxesStartStyle end
struct AxesStart1 <: AxesStartStyle end
struct AxesStartAny <: AxesStartStyle end
41 changes: 41 additions & 0 deletions test/ranges.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1458,3 +1458,44 @@ end
Base.TwicePrecision(-1.0, -0.0), 0)
@test reverse(reverse(1.0:0.0)) === 1.0:0.0
end

@testset "Fallbacks for IdentityUnitRange" begin
r = Base.IdentityUnitRange(-2:2)
argerr = ArgumentError("offset arrays are not supported but got an array with index other than 1")
@test r != -2:2
@test r != -2:1:2
@test r == r
@test r != Base.IdentityUnitRange(-1:2)
@test +r === r
@test_throws argerr UnitRange{Int}(r)
@test_throws argerr UnitRange(r)
@test_throws argerr StepRange{Int,Int}(r)
@test_throws argerr StepRange(r)
@test_throws argerr StepRangeLen(r)
@test_throws argerr StepRangeLen{Int,Int,Int}(r)
@test_throws argerr LinRange(r)
@test_throws argerr -r
@test_throws argerr .-r
@test_throws argerr r .+ 1
@test_throws argerr 1 .+ r
@test_throws argerr r .+ im
@test_throws argerr im .+ r
@test_throws argerr r .- 1
@test_throws argerr 1 .- r
@test_throws argerr 2 * r
@test_throws argerr r * 2
@test_throws argerr 2 .* r
@test_throws argerr r .* 2
@test_throws argerr r / 2
@test_throws argerr r ./ 2
@test_throws argerr 2 \ r
@test_throws argerr 2 .\ r
@test_throws argerr r + r
@test_throws argerr r - r
@test_throws argerr r .+ r
@test_throws argerr r .- r
@test_throws MethodError r .* r
@test_throws DimensionMismatch r .* (-2:2)
@test_throws argerr reverse(r)
@test_throws argerr sortperm(r)
end

0 comments on commit a5c4c15

Please sign in to comment.