Skip to content

Commit

Permalink
WIP: deprecate implicit scalar broadcasting in setindex!
Browse files Browse the repository at this point in the history
  • Loading branch information
mbauman committed Mar 7, 2018
1 parent c27ec72 commit 149fb89
Show file tree
Hide file tree
Showing 28 changed files with 249 additions and 191 deletions.
23 changes: 15 additions & 8 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1178,8 +1178,8 @@ function get!(X::AbstractArray{T}, A::AbstractArray, I::Union{AbstractRange,Abst
# Linear indexing
ind = findall(occursin(1:length(A)), I)
X[ind] = A[I[ind]]
X[1:first(ind)-1] = default
X[last(ind)+1:length(X)] = default
X[1:first(ind)-1] .= (default,)
X[last(ind)+1:length(X)] .= (default,)
X
end

Expand Down Expand Up @@ -1377,7 +1377,11 @@ function _cat(A, shape::NTuple{N}, catdims, X...) where N
end
end
I::NTuple{N, UnitRange{Int}} = (inds...,)
A[I...] = x
if x isa AbstractArray
A[I...] = x
else
A[I...] .= (x,)
end
end
return A
end
Expand Down Expand Up @@ -1927,27 +1931,27 @@ function mapslices(f, A::AbstractArray, dims::AbstractVector)
ridx[d] = axes(R,d)
end

R[ridx...] = r1
concatenate_setindex!(R, r1, ridx...)

nidx = length(otherdims)
indices = Iterators.drop(CartesianIndices(itershape), 1)
indices = Iterators.drop(CartesianIndices(itershape), 1) # skip the first element, we already handled it
inner_mapslices!(safe_for_reuse, indices, nidx, idx, otherdims, ridx, Aslice, A, f, R)
end

@noinline function inner_mapslices!(safe_for_reuse, indices, nidx, idx, otherdims, ridx, Aslice, A, f, R)
if safe_for_reuse
# when f returns an array, R[ridx...] = f(Aslice) line copies elements,
# so we can reuse Aslice
for I in indices # skip the first element, we already handled it
for I in indices
replace_tuples!(nidx, idx, ridx, otherdims, I)
_unsafe_getindex!(Aslice, A, idx...)
R[ridx...] = f(Aslice)
concatenate_setindex!(R, f(Aslice), ridx...)
end
else
# we can't guarantee safety (#18524), so allocate new storage for each slice
for I in indices
replace_tuples!(nidx, idx, ridx, otherdims, I)
R[ridx...] = f(A[idx...])
concatenate_setindex!(R, f(A[idx...]), ridx...)
end
end

Expand All @@ -1960,6 +1964,9 @@ function replace_tuples!(nidx, idx, ridx, otherdims, I)
end
end

concatenate_setindex!(R, v, I...) = (R[I...] .= (v,); R)
concatenate_setindex!(R, X::AbstractArray, I...) = (R[I...] = X)


## 1 argument

Expand Down
6 changes: 1 addition & 5 deletions base/abstractarraymath.jl
Original file line number Diff line number Diff line change
Expand Up @@ -364,10 +364,6 @@ _rshps(shp, shp_i, sz, i, ::Tuple{}) =
_reperr(s, n, N) = throw(ArgumentError("number of " * s * " repetitions " *
"($n) cannot be less than number of dimensions of input ($N)"))

# We need special handling when repeating arrays of arrays
cat_fill!(R, X, inds) = (R[inds...] = X)
cat_fill!(R, X::AbstractArray, inds) = fill!(view(R, inds...), X)

@noinline function _repeat(A::AbstractArray, inner, outer)
shape, inner_shape = rep_shapes(A, inner, outer)

Expand All @@ -386,7 +382,7 @@ cat_fill!(R, X::AbstractArray, inds) = fill!(view(R, inds...), X)
n = inner[i]
inner_indices[i] = (1:n) .+ ((c[i] - 1) * n)
end
cat_fill!(R, A[c], inner_indices)
R[inner_indices...] .= (A[c],)
end
end

Expand Down
3 changes: 0 additions & 3 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -720,9 +720,6 @@ function setindex!(A::Array{T}, X::Array{T}, c::Colon) where T
return A
end

setindex!(A::Array, x::Number, ::Colon) = fill!(A, x)
setindex!(A::Array{T, N}, x::Number, ::Vararg{Colon, N}) where {T, N} = fill!(A, x)

# efficiently grow an array

_growbeg!(a::Vector, delta::Integer) =
Expand Down
66 changes: 34 additions & 32 deletions base/bitarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ function gen_bitarray_from_itr(itr, st)
end
end
if ind > 1
@inbounds C[ind:bitcache_size] = false
@inbounds (C[ind:bitcache_size] .= false; C)
resize!(B, length(B) + ind - 1)
dumpbitcache(Bc, cind, C)
end
Expand All @@ -602,7 +602,7 @@ function fill_bitarray_from_itr!(B::BitArray, itr, st)
end
end
if ind > 1
@inbounds C[ind:bitcache_size] = false
@inbounds (C[ind:bitcache_size] .= false; nothing)
dumpbitcache(Bc, cind, C)
end
return B
Expand Down Expand Up @@ -647,43 +647,45 @@ end
indexoffset(i) = first(i)-1
indexoffset(::Colon) = 0

@inline function setindex!(B::BitArray, x, J0::Union{Colon,UnitRange{Int}})
I0 = to_indices(B, (J0,))[1]
@boundscheck checkbounds(B, I0)
y = Bool(x)
l0 = length(I0)
l0 == 0 && return B
f0 = indexoffset(I0)+1
fill_chunks!(B.chunks, y, f0, l0)
return B
end
# TODO: re-implement this guy
# @inline function setindex!(B::BitArray, x, J0::Union{UnitRange{Int}})
# I0 = to_indices(B, (J0,))[1]
# @boundscheck checkbounds(B, I0)
# y = Bool(x)
# l0 = length(I0)
# l0 == 0 && return B
# f0 = indexoffset(I0)+1
# fill_chunks!(B.chunks, y, f0, l0)
# return B
# end
@propagate_inbounds function setindex!(B::BitArray, X::AbstractArray, J0::Union{Colon,UnitRange{Int}})
_setindex!(IndexStyle(B), B, X, to_indices(B, (J0,))[1])
end

# logical indexing

# TODO: reimplement these guys
# When indexing with a BitArray, we can operate whole chunks at a time for a ~100x gain
@inline function setindex!(B::BitArray, x, I::BitArray)
@boundscheck checkbounds(B, I)
_unsafe_setindex!(B, x, I)
end
function _unsafe_setindex!(B::BitArray, x, I::BitArray)
y = convert(Bool, x)
Bc = B.chunks
Ic = I.chunks
length(Bc) == length(Ic) || throw_boundserror(B, I)
@inbounds if y
for i = 1:length(Bc)
Bc[i] |= Ic[i]
end
else
for i = 1:length(Bc)
Bc[i] &= ~Ic[i]
end
end
return B
end
# @inline function setindex!(B::BitArray, x, I::BitArray)
# @boundscheck checkbounds(B, I)
# _unsafe_setindex!(B, x, I)
# end
# function _unsafe_setindex!(B::BitArray, x, I::BitArray)
# y = convert(Bool, x)
# Bc = B.chunks
# Ic = I.chunks
# length(Bc) == length(Ic) || throw_boundserror(B, I)
# @inbounds if y
# for i = 1:length(Bc)
# Bc[i] |= Ic[i]
# end
# else
# for i = 1:length(Bc)
# Bc[i] &= ~Ic[i]
# end
# end
# return B
# end

# Assigning an array of bools is more complicated, but we can still do some
# work on chunks by combining X and I 64 bits at a time to improve perf by ~40%
Expand Down
2 changes: 1 addition & 1 deletion base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ end
end
end
if ind > 1
@inbounds C[ind:bitcache_size] = false
fill!(@inbounds(view(C, ind:bitcache_size)), false)
dumpbitcache(Bc, cind, C)
end
return B
Expand Down
19 changes: 19 additions & 0 deletions base/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1372,6 +1372,25 @@ function lastindex(a, n)
last(axes(a, n))
end

# PR
function deprecate_scalar_setindex_broadcast_message(v, I...)
value = (Broadcast.BroadcastStyle(typeof(v)) === Broadcast.Scalar() ? "x" : "(x,)")
"using `A[I...] = x` to implicitly broadcast `x` across many locations is deprecated. Use `A[I...] .= $value` instead."
end
deprecate_scalar_setindex_broadcast_message(v, ::Colon, ::Vararg{Colon}) =
"using `A[I...] = x` to implicitly broadcast `x` across many locations is deprecated. Use `fill!(A, x)` instead."

function _iterable(v, I...)
depwarn(deprecate_scalar_setindex_broadcast_message(v, I...), :setindex!)
Iterators.repeated(v)
end
function setindex!(B::BitArray, x, I0::Union{Colon,UnitRange{Int}}, I::Union{Int,UnitRange{Int},Colon}...)
depwarn(deprecate_scalar_setindex_broadcast_message(x, I0, I...), :setindex!)
B[I0, I...] .= (x,)
B
end


@deprecate_binding repmat repeat

@deprecate Timer(timeout, repeat) Timer(timeout, interval = repeat)
Expand Down
4 changes: 2 additions & 2 deletions base/iobuffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ function IOBuffer(;
append=flags.append,
truncate=flags.truncate,
maxsize=maxsize)
buf.data[:] = 0
fill!(buf.data, 0)
return buf
end

Expand Down Expand Up @@ -246,7 +246,7 @@ function truncate(io::GenericIOBuffer, n::Integer)
if n > length(io.data)
resize!(io.data, n)
end
io.data[io.size+1:n] = 0
io.data[io.size+1:n] .= 0
io.size = n
io.ptr = min(io.ptr, n+1)
ismarked(io) && io.mark > n && unmark(io)
Expand Down
20 changes: 10 additions & 10 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -666,12 +666,11 @@ function _setindex!(l::IndexStyle, A::AbstractArray, x, I::Union{Real, AbstractA
A
end

_iterable(v::AbstractArray) = v
_iterable(v) = Iterators.repeated(v)
_iterable(v::AbstractArray, I...) = v
@generated function _unsafe_setindex!(::IndexStyle, A::AbstractArray, x, I::Union{Real,AbstractArray}...)
N = length(I)
quote
x′ = _iterable(unalias(A, x))
x′ = _iterable(unalias(A, x), I...)
@nexprs $N d->(I_d = unalias(A, I[d]))
idxlens = @ncall $N index_lengths I
@ncall $N setindex_shape_check x′ (d->idxlens[d])
Expand Down Expand Up @@ -1603,12 +1602,13 @@ end
end
end

@inline function setindex!(B::BitArray, x,
I0::Union{Colon,UnitRange{Int}}, I::Union{Int,UnitRange{Int},Colon}...)
J = to_indices(B, (I0, I...))
@boundscheck checkbounds(B, J...)
_unsafe_setindex!(B, x, J...)
end
# TODO: reimplement this guy
# @inline function setindex!(B::BitArray, x,
# I0::Union{Colon,UnitRange{Int}}, I::Union{Int,UnitRange{Int},Colon}...)
# J = to_indices(B, (I0, I...))
# @boundscheck checkbounds(B, J...)
# _unsafe_setindex!(B, x, J...)
# end
@propagate_inbounds function setindex!(B::BitArray, X::AbstractArray,
I0::Union{Colon,UnitRange{Int}}, I::Union{Int,UnitRange{Int},Colon}...)
_setindex!(IndexStyle(B), B, X, to_indices(B, (I0, I...))...)
Expand Down Expand Up @@ -1865,7 +1865,7 @@ julia> extrema(A, (1,2))
"""
function extrema(A::AbstractArray, dims)
sz = [size(A)...]
sz[[dims...]] = 1
sz[[dims...]] .= 1
B = Array{Tuple{eltype(A),eltype(A)}}(uninitialized, sz...)
return extrema!(B, A)
end
Expand Down
2 changes: 1 addition & 1 deletion base/sysimg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ if false
# simple print definitions for debugging. enable these if something
# goes wrong during bootstrap before printing code is available.
# otherwise, they just just eventually get (noisily) overwritten later
global show, print, println
global show, print, println, string
show(io::IO, x) = Core.show(io, x)
print(io::IO, a...) = Core.print(io, a...)
println(io::IO, x...) = Core.println(io, x...)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/FileWatching/src/FileWatching.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ mutable struct _FDWatcher
if fdnum > length(FDWatchers)
old_len = length(FDWatchers)
resize!(FDWatchers, fdnum)
FDWatchers[(old_len + 1):fdnum] = nothing
fill!(view(FDWatchers, (old_len + 1):fdnum), nothing)
elseif FDWatchers[fdnum] !== nothing
this = FDWatchers[fdnum]::_FDWatcher
this.refcount = (this.refcount[1] + Int(readable), this.refcount[2] + Int(writable))
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1295,7 +1295,7 @@ function pinv(A::StridedMatrix{T}, tol::Real) where T
Sinv = zeros(Stype, length(SVD.S))
index = SVD.S .> tol*maximum(SVD.S)
Sinv[index] = one(Stype) ./ SVD.S[index]
Sinv[findall(.!isfinite.(Sinv))] = zero(Stype)
Sinv[findall(.!isfinite.(Sinv))] .= zero(Stype)
return SVD.Vt' * (Diagonal(Sinv) * SVD.U')
end
function pinv(A::StridedMatrix{T}) where T
Expand Down
6 changes: 3 additions & 3 deletions stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -774,7 +774,7 @@ function ldiv!(A::QRPivoted{T}, B::StridedMatrix{T}, rcond::Real) where T<:BlasF
end
ar = abs(A.factors[1])
if ar == 0
B[1:nA, :] = 0
B[1:nA, :] .= 0
return B, 0
end
rnk = 1
Expand All @@ -795,7 +795,7 @@ function ldiv!(A::QRPivoted{T}, B::StridedMatrix{T}, rcond::Real) where T<:BlasF
end
C, τ = LAPACK.tzrzf!(A.factors[1:rnk,:])
ldiv!(UpperTriangular(C[1:rnk,1:rnk]),view(lmul!(adjoint(A.Q), view(B, 1:mA, 1:nrhs)), 1:rnk, 1:nrhs))
B[rnk+1:end,:] = zero(T)
B[rnk+1:end,:] .= zero(T)
LAPACK.ormrz!('L', eltype(B)<:Complex ? 'C' : 'T', C, τ, view(B,1:nA,1:nrhs))
B[1:nA,:] = view(B, 1:nA, :)[invperm(A.p),:]
return B, rnk
Expand Down Expand Up @@ -832,7 +832,7 @@ function ldiv!(A::QR{T}, B::StridedMatrix{T}) where T
end
LinearAlgebra.ldiv!(UpperTriangular(view(R, :, 1:minmn)), view(B, 1:minmn, :))
if n > m # Apply elementary transformation to solution
B[m + 1:mB,1:nB] = zero(T)
B[m + 1:mB,1:nB] .= zero(T)
for j = 1:nB
for k = 1:m
vBj = B[k,j]
Expand Down
9 changes: 8 additions & 1 deletion stdlib/LinearAlgebra/src/uniformscaling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,14 @@ chol(J::UniformScaling, args...) = ((C, info) = _chol!(J, nothing); @assertposde


## Matrix construction from UniformScaling
Matrix{T}(s::UniformScaling, dims::Dims{2}) where {T} = setindex!(Base.zeros(T, dims), T(s.λ), diagind(dims...))
function Matrix{T}(s::UniformScaling, dims::Dims{2}) where {T}
A = zeros(T, dims)
v = T(s.λ)
for i in diagind(dims...)
@inbounds A[i] = v
end
return A
end
Matrix{T}(s::UniformScaling, m::Integer, n::Integer) where {T} = Matrix{T}(s, Dims((m, n)))
Matrix(s::UniformScaling, m::Integer, n::Integer) = Matrix(s, Dims((m, n)))
Matrix(s::UniformScaling, dims::Dims{2}) = Matrix{eltype(s)}(s, dims)
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/test/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ end
dim=2
S=zeros(Complex,dim,dim)
T=zeros(Complex,dim,dim)
T[:] = 1
fill!(T, 1)
z = 2.5 + 1.5im
S[1] = z
@test S*T == [z z; 0 0]
Expand Down
2 changes: 1 addition & 1 deletion stdlib/Pkg3/src/GraphType.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1267,7 +1267,7 @@ function build_eq_classes_soft1!(graph::Graph, p0::Int)

# disable the other versions by introducing additional constraints
fill!(gconstr0, false)
gconstr0[repr_vers] = true
gconstr0[repr_vers] .= true

return
end
Expand Down
2 changes: 1 addition & 1 deletion stdlib/SHA/src/sha3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ function digest!(context::T) where {T<:SHA3_CTX}
# Begin padding with a 0x06
context.buffer[usedspace+1] = 0x06
# Fill with zeros up until the last byte
context.buffer[usedspace+2:end-1] = 0x00
context.buffer[usedspace+2:end-1] .= 0x00
# Finish it off with a 0x80
context.buffer[end] = 0x80
else
Expand Down
Loading

0 comments on commit 149fb89

Please sign in to comment.