Skip to content

Commit

Permalink
Optimize recode for large number of categories (#345)
Browse files Browse the repository at this point in the history
  • Loading branch information
pgagarinov authored May 4, 2021
1 parent c511b4f commit 79ff7a7
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
10 changes: 9 additions & 1 deletion benchmark/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,4 +86,12 @@ SUITE["repeated assignment"]["empty dest"] =
SUITE["repeated assignment"]["same levels dest"] =
@benchmarkable mycopy!(c2, a) setup = c2=copy(c)
SUITE["repeated assignment"]["many levels dest"] =
@benchmarkable mycopy!(d2, a) setup = d2=copy(d)
@benchmarkable mycopy!(d2, a) setup = d2=copy(d)

orig_vec = (x -> repeat(x, 32)).(string.([x % 1000 for x in 1:1000000]))
cat2merge_vec = (x -> repeat(x, 32)).(string.([x % 1000 for x in 1:100000]))
SUITE["recode"] = BenchmarkGroup()
SUITE["recode"]["vectors"] = @benchmarkable recode(orig_vec, cat2merge_vec => "None")
SUITE["recode"]["categorical_vectors"] = @benchmarkable recode(categorical(orig_vec), cat2merge_vec => "None")
SUITE["recode"]["matrices"] = @benchmarkable recode(reshape(orig_vec, :, 1), reshape(cat2merge_vec, :, 1) => "None")
SUITE["recode"]["categorical_matrices"] = @benchmarkable recode(categorical(reshape(orig_vec, :, 1)), reshape(cat2merge_vec, :, 1) => "None")
17 changes: 12 additions & 5 deletions src/recode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,21 @@ A user defined type could override this method to define an appropriate test fun
@inline recode_in(x, collection::Set) = x in collection
@inline recode_in(x, collection) = any(x y for y in collection)

optimize_pair(pair::Pair) = pair
optimize_pair(pair::Pair{<:AbstractArray}) = Set(pair.first) => pair.second

function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T}
if length(dest) != length(src)
throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))"))
end

opt_pairs = map(optimize_pair, pairs)

@inbounds for i in eachindex(dest, src)
x = src[i]

for j in 1:length(pairs)
p = pairs[j]
for j in 1:length(opt_pairs)
p = opt_pairs[j]
# we use isequal and recode_in because we cannot really distinguish scalars from collections
if x p.first || recode_in(x, p.first)
dest[i] = p.second
Expand Down Expand Up @@ -96,7 +101,9 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa
throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))"))
end

vals = T[p.second for p in pairs]
opt_pairs = map(optimize_pair, pairs)

vals = T[p.second for p in opt_pairs]
default !== nothing && push!(vals, default)

levels!(dest.pool, filter!(!ismissing, unique(vals)))
Expand All @@ -110,8 +117,8 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa
@inbounds for i in eachindex(drefs, src)
x = src[i]

for j in 1:length(pairs)
p = pairs[j]
for j in 1:length(opt_pairs)
p = opt_pairs[j]
# we use isequal and recode_in because we cannot really distinguish scalars from collections
if x p.first || recode_in(x, p.first)
drefs[i] = dupvals ? pairmap[j] : j
Expand Down

0 comments on commit 79ff7a7

Please sign in to comment.