Skip to content

Commit

Permalink
implement count using mapreduce
Browse files Browse the repository at this point in the history
This creates the same calling interface for `count` as for e.g. `sum`,
namely allowing the `dims` keyword.
The implementation is also shorter than before without sacrificing
performance.
`mapreduce` with `add_sum` may even yield performance benefits through
chunking, though this was not observed in simple tests.
  • Loading branch information
stev47 committed Dec 12, 2019
1 parent beebfd3 commit 90f7561
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 14 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,8 @@ Standard library changes

* Sets are now displayed less compactly in the REPL, as a column of elements, like vectors
and dictionaries ([#33300]).
* `count` now accepts the `dims` keyword.
* new in-place `count!` function similar to `sum!`.

#### Libdl

Expand Down
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,7 @@ export
any,
firstindex,
collect,
count!,
count,
delete!,
deleteat!,
Expand Down
15 changes: 1 addition & 14 deletions base/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -846,18 +846,5 @@ julia> count([true, false, true, true])
3
```
"""
function count(pred, itr)
n = 0
for x in itr
n += pred(x)::Bool
end
return n
end
function count(pred, a::AbstractArray)
n = 0
for i in eachindex(a)
@inbounds n += pred(a[i])::Bool
end
return n
end
count(itr) = count(identity, itr)
count(f, itr) = mapreduce(x->f(x)::Bool, add_sum, itr, init=0)
63 changes: 63 additions & 0 deletions base/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,69 @@ julia> reduce(max, a, dims=1)
reduce(op, A::AbstractArray; kw...) = mapreduce(identity, op, A; kw...)

##### Specific reduction functions #####

_bool(f::Function) = x->f(x)::Bool

"""
count([f=identity,] A::AbstractArray; dims=:)
Count the number of elements in `A` for which `f` returns `true` over the given
dimensions.
!!! compat "Julia 1.4"
`dims` keyword was added in Julia 1.4.
# Examples
```jldoctest
julia> A = [1 2; 3 4]
2×2 Array{Int64,2}:
1 2
3 4
julia> count(<=(2), A, dims=1)
1×2 Array{Int64,2}:
1 1
julia> count(<=(2), A, dims=2)
2×1 Array{Int64,2}:
2
0
```
"""
count(A::AbstractArray; dims=:) = count(identity, A, dims=dims)
count(f, A::AbstractArray; dims=:) = mapreduce(_bool(f), add_sum, A, dims=dims, init=0)

"""
count!([f=identity,] r, A; init=true)
Count the number of elements in `A` for which `f` returns `true` over the
singleton dimensions of `r`, writing the result into `r` in-place.
If `init` is `true`, values in `r` are initialized to zero.
!!! compat "Julia 1.4"
inplace `count!` was added in Julia 1.4.
# Examples
```jldoctest
julia> A = [1 2; 3 4]
2×2 Array{Int64,2}:
1 2
3 4
julia> count!(<=(2), [1 1], A)
1×2 Array{Int64,2}:
1 1
julia> count!(<=(2), [1; 1], A)
2-element Array{Int64,1}:
2
0
```
"""
count!(r::AbstractArray, A::AbstractArray; init::Bool=true) = count!(identity, r, A; init=init)
count!(f, r::AbstractArray, A::AbstractArray; init::Bool=true) =
mapreducedim!(_bool(f), add_sum, initarray!(r, add_sum, init, A), A)

"""
sum(A::AbstractArray; dims)
Expand Down
13 changes: 13 additions & 0 deletions test/reducedim.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ safe_sum(A::Array{T}, region) where {T} = safe_mapslices(sum, A, region)
safe_prod(A::Array{T}, region) where {T} = safe_mapslices(prod, A, region)
safe_maximum(A::Array{T}, region) where {T} = safe_mapslices(maximum, A, region)
safe_minimum(A::Array{T}, region) where {T} = safe_mapslices(minimum, A, region)
safe_count(A::Array{T}, region) where {T} = safe_mapslices(count, A, region)
safe_sumabs(A::Array{T}, region) where {T} = safe_mapslices(sum, abs.(A), region)
safe_sumabs2(A::Array{T}, region) where {T} = safe_mapslices(sum, abs2.(A), region)
safe_maxabs(A::Array{T}, region) where {T} = safe_mapslices(maximum, abs.(A), region)
Expand All @@ -21,6 +22,8 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
1, 2, 3, 4, 5, (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4),
(1, 2, 3), (1, 3, 4), (2, 3, 4), (1, 2, 3, 4)]
Areduc = rand(3, 4, 5, 6)
Breduc = rand(Bool, 3, 4, 5, 6)
@assert axes(Areduc) == axes(Breduc)
r = fill(NaN, map(length, Base.reduced_indices(axes(Areduc), region)))
@test sum!(r, Areduc) safe_sum(Areduc, region)
@test prod!(r, Areduc) safe_prod(Areduc, region)
Expand All @@ -30,6 +33,7 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
@test sum!(abs2, r, Areduc) safe_sumabs2(Areduc, region)
@test maximum!(abs, r, Areduc) safe_maxabs(Areduc, region)
@test minimum!(abs, r, Areduc) safe_minabs(Areduc, region)
@test count!(abs, r, Breduc) safe_count(Breduc, region)

# With init=false
r2 = similar(r)
Expand All @@ -49,6 +53,8 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
@test maximum!(abs, r, Areduc, init=false) fill!(r2, 1.5)
fill!(r, -1.5)
@test minimum!(abs, r, Areduc, init=false) fill!(r2, -1.5)
fill!(r, 1)
@test count!(r, Breduc, init=false) safe_count(Breduc, region) .+ 1

@test @inferred(sum(Areduc, dims=region)) safe_sum(Areduc, region)
@test @inferred(prod(Areduc, dims=region)) safe_prod(Areduc, region)
Expand All @@ -58,6 +64,7 @@ safe_minabs(A::Array{T}, region) where {T} = safe_mapslices(minimum, abs.(A), re
@test @inferred(sum(abs2, Areduc, dims=region)) safe_sumabs2(Areduc, region)
@test @inferred(maximum(abs, Areduc, dims=region)) safe_maxabs(Areduc, region)
@test @inferred(minimum(abs, Areduc, dims=region)) safe_minabs(Areduc, region)
@test @inferred(count(Breduc, dims=region)) safe_count(Breduc, region)
end

# Test reduction along first dimension; this is special-cased for
Expand Down Expand Up @@ -416,3 +423,9 @@ end

@test sum([Variable(:x), Variable(:y)], dims=1) == [AffExpr([Variable(:x), Variable(:y)])]
end

# count
@testset "count: throw on non-bool types" begin
@test_throws TypeError count([1], dims=1)
@test_throws TypeError count!([1], [1])
end

0 comments on commit 90f7561

Please sign in to comment.