Skip to content

Commit

Permalink
Allow any f for sum/minimum/maximum(f, v::AbstractSparseVector) (#29884)
Browse files Browse the repository at this point in the history
* sum/minimum/maximum(f, v::AbstractSparseVector)

generalize sum/minimum/maximum(abs/abs2, v::AbstractSparseVector) to
arbitrary f

* add broken sum(f, [])==0 tests for reference
  • Loading branch information
alyst authored Apr 10, 2021
1 parent bb608e5 commit fc69c9a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 18 deletions.
52 changes: 34 additions & 18 deletions stdlib/SparseArrays/src/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1354,34 +1354,50 @@ end

### Reduction

function _sum(f, x::AbstractSparseVector)
n = length(x)
n > 0 || return sum(f, nonzeros(x)) # return zero() of proper type
m = nnz(x)
(m == 0 ? n * f(zero(eltype(x))) :
m == n ? sum(f, nonzeros(x)) :
Base.add_sum((n - m) * f(zero(eltype(x))), sum(f, nonzeros(x))))
end

sum(f::Union{Function, Type}, x::AbstractSparseVector) = _sum(f, x) # resolve ambiguity
sum(f, x::AbstractSparseVector) = _sum(f, x)
sum(x::AbstractSparseVector) = sum(nonzeros(x))

function maximum(x::AbstractSparseVector{T}) where T<:Real
function _maximum(f, x::AbstractSparseVector)
n = length(x)
n > 0 || throw(ArgumentError("maximum over empty array is not allowed."))
if n == 0
if f === abs || f === abs2
return zero(eltype(x)) # preserving maximum(abs/abs2, x) behaviour in 1.0.x
else
throw(ArgumentError("maximum over an empty array is not allowed."))
end
end
m = nnz(x)
(m == 0 ? zero(T) :
m == n ? maximum(nonzeros(x)) :
max(zero(T), maximum(nonzeros(x))))::T
(m == 0 ? f(zero(eltype(x))) :
m == n ? maximum(f, nonzeros(x)) :
max(f(zero(eltype(x))), maximum(f, nonzeros(x))))
end

function minimum(x::AbstractSparseVector{T}) where T<:Real
maximum(f::Union{Function, Type}, x::AbstractSparseVector) = _maximum(f, x) # resolve ambiguity
maximum(f, x::AbstractSparseVector) = _maximum(f, x)
maximum(x::AbstractSparseVector) = maximum(identity, x)

function _minimum(f, x::AbstractSparseVector)
n = length(x)
n > 0 || throw(ArgumentError("minimum over empty array is not allowed."))
n > 0 || throw(ArgumentError("minimum over an empty array is not allowed."))
m = nnz(x)
(m == 0 ? zero(T) :
m == n ? minimum(nonzeros(x)) :
min(zero(T), minimum(nonzeros(x))))::T
(m == 0 ? f(zero(eltype(x))) :
m == n ? minimum(f, nonzeros(x)) :
min(f(zero(eltype(x))), minimum(f, nonzeros(x))))
end

for f in [:sum, :maximum, :minimum], op in [:abs, :abs2]
SV = :AbstractSparseVector
if f === :minimum
@eval ($f)(::typeof($op), x::$SV{T}) where {T<:Number} = nnz(x) < length(x) ? ($op)(zero(T)) : ($f)($op, nonzeros(x))
else
@eval ($f)(::typeof($op), x::$SV) = ($f)($op, nonzeros(x))
end
end
minimum(f::Union{Function, Type}, x::AbstractSparseVector) = _minimum(f, x) # resolve ambiguity
minimum(f, x::AbstractSparseVector) = _minimum(f, x)
minimum(x::AbstractSparseVector) = minimum(identity, x)

norm(x::SparseVectorUnion, p::Real=2) = norm(nonzeros(x), p)

Expand Down
28 changes: 28 additions & 0 deletions stdlib/SparseArrays/test/sparsevector.jl
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,19 @@ end
@test sum(x) == 4.0
@test sum(abs, x) == 5.5
@test sum(abs2, x) == 14.375
@test @inferred(sum(t -> true, x)) === 8
@test @inferred(sum(t -> abs(t) + one(t), x)) == 13.5

@test @inferred(sum(t -> true, spzeros(Float64, 8))) === 8
@test @inferred(sum(t -> abs(t) + one(t), spzeros(Float64, 8))) === 8.0

# reducing over an empty collection
# FIXME sum(f, []) throws, should be fixed both for generic and sparse vectors
@test_broken sum(t -> true, zeros(Float64, 0)) === 0
@test_broken sum(t -> true, spzeros(Float64, 0)) === 0
@test @inferred(sum(abs2, spzeros(Float64, 0))) === 0.0
@test_broken sum(t -> abs(t) + one(t), zeros(Float64, 0)) === 0.0
@test_broken sum(t -> abs(t) + one(t), spzeros(Float64, 0)) === 0.0

@test norm(x) == sqrt(14.375)
@test norm(x, 1) == 5.5
Expand All @@ -802,6 +815,12 @@ end
@test minimum(x) == -0.75
@test maximum(abs, x) == 3.5
@test minimum(abs, x) == 0.0
@test @inferred(minimum(t -> true, x)) === true
@test @inferred(maximum(t -> true, x)) === true
@test @inferred(minimum(t -> abs(t) + one(t), x)) == 1.0
@test @inferred(maximum(t -> abs(t) + one(t), x)) == 4.5
@test @inferred(minimum(t -> t + one(t), x)) == 0.25
@test @inferred(maximum(t -> -abs(t) + one(t), x)) == 1.0
end

let x = abs.(spv_x1)
Expand All @@ -826,6 +845,15 @@ end
@test minimum(x) == 0.0
@test maximum(abs, x) == 0.0
@test minimum(abs, x) == 0.0
@test @inferred(minimum(t -> true, x)) === true
@test @inferred(maximum(t -> true, x)) === true
@test @inferred(minimum(t -> abs(t) + one(t), x)) === 1.0
@test @inferred(maximum(t -> abs(t) + one(t), x)) === 1.0
end

let x = spzeros(Float64, 0)
@test_throws ArgumentError minimum(t -> true, x)
@test_throws ArgumentError maximum(t -> true, x)
end
end

Expand Down

2 comments on commit fc69c9a

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible new issues were detected. A full report can be found here. cc @maleadt

Please sign in to comment.