Skip to content

Commit

Permalink
Make recode! type stable (#407)
Browse files Browse the repository at this point in the history
Varargs appear to be type-stable according to `@code_warntype`
but in practice that's not the case.
  • Loading branch information
tiemvanderdeure authored Jan 3, 2025
1 parent 341de70 commit 3e0d056
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 44 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "324d7699-5711-5eae-9e2f-1d82baa6b597"
version = "0.10.8"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
Future = "9fa8497b-333b-5362-9e8d-4d0656e87820"
Missings = "e1d29d7a-bbdc-5cf2-9ac0-f12de2c33e28"
Expand All @@ -24,6 +25,7 @@ CategoricalArraysSentinelArraysExt = "SentinelArrays"
CategoricalArraysStructTypesExt = "StructTypes"

[compat]
Compat = "3.37, 4"
DataAPI = "1.6"
JSON = "0.15, 0.16, 0.17, 0.18, 0.19, 0.20, 0.21"
JSON3 = "1.1.2"
Expand Down
1 change: 1 addition & 0 deletions src/CategoricalArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ module CategoricalArrays
using DataAPI
using Missings
using Printf
import Compat

# JuliaLang/julia#36810
if VERSION < v"1.5.2"
Expand Down
84 changes: 40 additions & 44 deletions src/recode.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,27 +52,34 @@ A user defined type could override this method to define an appropriate test fun
optimize_pair(pair::Pair) = pair
optimize_pair(pair::Pair{<:AbstractArray}) = Set(pair.first) => pair.second

function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T}
function recode!(dest::AbstractArray, src::AbstractArray, default::Any, pairs::Pair...)
if length(dest) != length(src)
throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))"))
end

opt_pairs = map(optimize_pair, pairs)
opt_pairs = optimize_pair.(pairs)

_recode!(dest, src, default, opt_pairs)
end

function _recode!(dest::AbstractArray{T}, src::AbstractArray, default,
pairs::NTuple{<:Any, Pair}) where {T}
recode_to = last.(pairs)
recode_from = first.(pairs)

@inbounds for i in eachindex(dest, src)
x = src[i]

for j in 1:length(opt_pairs)
p = opt_pairs[j]
# we use isequal and recode_in because we cannot really distinguish scalars from collections
if x ≅ p.first || recode_in(x, p.first)
dest[i] = p.second
@goto nextitem
end
end

# @inline is needed for type stability and Compat for compatibility before julia v1.8
# we use isequal and recode_in because we cannot really
# distinguish scalars from collections
j = Compat.@inline findfirst(y -> isequal(x, y) || recode_in(x,y), recode_from)

# Value in one of the pairs
if j !== nothing
dest[i] = recode_to[j]
# Value not in any of the pairs
if ismissing(x)
elseif ismissing(x)
eltype(dest) >: Missing ||
throw(MissingException("missing value found, but dest does not support them: " *
"recode them to a supported value"))
Expand All @@ -89,21 +96,16 @@ function recode!(dest::AbstractArray{T}, src::AbstractArray, default::Any, pairs
else
dest[i] = default
end

@label nextitem
end

dest
end

function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pairs::Pair...) where {T}
if length(dest) != length(src)
throw(DimensionMismatch("dest and src must be of the same length (got $(length(dest)) and $(length(src)))"))
end

opt_pairs = map(optimize_pair, pairs)
function _recode!(dest::CategoricalArray{T, <:Any, R}, src::AbstractArray, default::Any,
pairs::NTuple{<:Any, Pair}) where {T, R}
recode_from = first.(pairs)
vals = T[p.second for p in pairs]

vals = T[p.second for p in opt_pairs]
default !== nothing && push!(vals, default)

levels!(dest.pool, filter!(!ismissing, unique(vals)))
Expand All @@ -112,22 +114,22 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa
dupvals = length(vals) != length(levels(dest.pool))

drefs = dest.refs
pairmap = [ismissing(v) ? 0 : get(dest.pool, v) for v in vals]
defaultref = default === nothing || ismissing(default) ? 0 : get(dest.pool, default)
pairmap = [ismissing(v) ? zero(R) : get(dest.pool, v) for v in vals]
defaultref = default === nothing || ismissing(default) ? zero(R) : get(dest.pool, default)

@inbounds for i in eachindex(drefs, src)
x = src[i]

for j in 1:length(opt_pairs)
p = opt_pairs[j]
# we use isequal and recode_in because we cannot really distinguish scalars from collections
if x ≅ p.first || recode_in(x, p.first)
drefs[i] = dupvals ? pairmap[j] : j
@goto nextitem
end
end
# @inline is needed for type stability and Compat for compatibility before julia v1.8
# we use isequal and recode_in because we cannot really
# distinguish scalars from collections
j = Compat.@inline findfirst(y -> isequal(x, y) || recode_in(x, y), recode_from)

# Value in one of the pairs
if j !== nothing
drefs[i] = dupvals ? pairmap[j] : j
# Value not in any of the pairs
if ismissing(x)
elseif ismissing(x)
eltype(dest) >: Missing ||
throw(MissingException("missing value found, but dest does not support them: " *
"recode them to a supported value"))
Expand All @@ -144,8 +146,6 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa
else
drefs[i] = defaultref
end

@label nextitem
end

# Put existing levels first, and sort them if possible
Expand All @@ -168,25 +168,21 @@ function recode!(dest::CategoricalArray{T}, src::AbstractArray, default::Any, pa
dest
end

function recode!(dest::CategoricalArray{T, N, R}, src::CategoricalArray,
default::Any, pairs::Pair...) where {T, N, R<:Integer}
if length(dest) != length(src)
throw(DimensionMismatch("dest and src must be of the same length " *
"(got $(length(dest)) and $(length(src)))"))
end

function _recode!(dest::CategoricalArray{T, N, R}, src::CategoricalArray,
default::Any, pairs::NTuple{<:Any, Pair}) where {T, N, R<:Integer}
recode_from = first.(pairs)
vals = T[p.second for p in pairs]

if default === nothing
srclevels = levels(src)

# Remove recoded levels as they won't appear in result
firsts = (p.first for p in pairs)
keptlevels = Vector{T}(undef, 0)
sizehint!(keptlevels, length(srclevels))

for l in srclevels
if !(any(x -> x ≅ l, firsts) ||
any(f -> recode_in(l, f), firsts))
if !(any(x -> x ≅ l, recode_from) ||
any(f -> recode_in(l, f), recode_from))
try
push!(keptlevels, l)
catch err
Expand Down

0 comments on commit 3e0d056

Please sign in to comment.