diff --git a/src/extras.jl b/src/extras.jl index f536f06f..2afcef38 100644 --- a/src/extras.jl +++ b/src/extras.jl @@ -9,11 +9,14 @@ function fill_refs!(refs::AbstractArray, X::AbstractArray, @inbounds for i in eachindex(X) x = X[i] - if ismissing(x) + if x isa Number && isnan(x) + throw(ArgumentError("NaN values are not allowed in input vector")) + elseif ismissing(x) refs[i] = 0 - elseif x == upper + elseif isequal(x, upper) refs[i] = n-1 - elseif extend !== true && !(lower <= x <= upper) + elseif extend !== true && + !((isless(lower, x) || isequal(x, lower)) && isless(x, upper)) extend === missing || throw(ArgumentError("value $x (at index $i) does not fall inside the breaks: " * "adapt them manually, or pass extend=true or extend=missing")) @@ -55,10 +58,10 @@ also accept them. the intervals; or a function `f(from, to, i; leftclosed, rightclosed)` that generates the labels from the left and right interval boundaries and the group index. Defaults to `"[from, to)"` (or `"[from, to]"` for the rightmost interval if `extend == true`). -* `allowempty::Bool=false`: when `false`, an error is raised if some breaks appear - multiple times, generating empty intervals; when `true`, duplicate breaks are allowed - and the intervals they generate are kept as unused levels - (but duplicate labels are not allowed). +* `allowempty::Bool=false`: when `false`, an error is raised if some breaks other than + the last one appear multiple times, generating empty intervals; when `true`, + duplicate breaks are allowed and the intervals they generate are kept as + unused levels (but duplicate labels are not allowed). # Examples ```jldoctest @@ -132,14 +135,19 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector, extend::Union{Bool, Missing}, labels::Union{AbstractVector{<:SupportedTypes},Function}, allowempty::Bool=false) where {T, N} - if !allowempty && !allunique(breaks) - throw(ArgumentError("all breaks must be unique unless `allowempty=true`")) - end - if !issorted(breaks) breaks = sort(breaks) end + if any(x -> x isa Number && isnan(x), breaks) + throw(ArgumentError("NaN values are not allowed in breaks")) + end + + if !allowempty && !allunique(@view breaks[1:end-1]) + throw(ArgumentError("all breaks other than the last one must be unique " * + "unless `allowempty=true`")) + end + if extend === true xnm = T >: Missing ? skipmissing(x) : x length(breaks) >= 1 || throw(ArgumentError("at least one break must be provided")) @@ -158,11 +166,11 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector, rethrow(err) end end - if !ismissing(min_x) && breaks[1] > min_x + if !ismissing(min_x) && isless(min_x, breaks[1]) # this type annotation is needed on Julia<1.7 for stable inference breaks = [min_x::nonmissingtype(eltype(x)); breaks] end - if !ismissing(max_x) && breaks[end] < max_x + if !ismissing(max_x) && isless(breaks[end], max_x) breaks = [breaks; max_x::nonmissingtype(eltype(x))] end length(breaks) > 1 || @@ -189,16 +197,15 @@ function _cut(x::AbstractArray{T, N}, breaks::AbstractVector, from = breaks[1:n-1] to = breaks[2:n] firstlevel = labels(from[1], to[1], 1, - leftclosed=breaks[1] != breaks[2], rightclosed=false) + leftclosed=!isequal(breaks[1], breaks[2]), rightclosed=false) levs = Vector{typeof(firstlevel)}(undef, n-1) levs[1] = firstlevel for i in 2:n-2 levs[i] = labels(from[i], to[i], i, - leftclosed=breaks[i] != breaks[i+1], rightclosed=false) + leftclosed=!isequal(breaks[i], breaks[i+1]), rightclosed=false) end levs[end] = labels(from[end], to[end], n-1, - leftclosed=breaks[end-1] != breaks[end], - rightclosed=true) + leftclosed=true, rightclosed=true) else length(labels) == n-1 || throw(ArgumentError("labels must be of length $(n-1), but got length $(length(labels))")) @@ -243,21 +250,28 @@ quantiles. the labels from the left and right interval boundaries and the group index. Defaults to `"Qi: [from, to)"` (or `"Qi: [from, to]"` for the rightmost interval). * `allowempty::Bool=false`: when `false`, an error is raised if some quantiles breakpoints - are equal, generating empty intervals; when `true`, duplicate breaks are allowed - and the intervals they generate are kept as unused levels - (but duplicate labels are not allowed). + other than the last one are equal, generating empty intervals; + when `true`, duplicate breaks are allowed and the intervals they generate are kept as + unused levels (but duplicate labels are not allowed). """ function cut(x::AbstractArray, ngroups::Integer; labels::Union{AbstractVector{<:SupportedTypes},Function}=quantile_formatter, allowempty::Bool=false) + ngroups >= 1 || throw(ArgumentError("ngroups must be strictly positive (got $ngroups)")) xnm = eltype(x) >: Missing ? skipmissing(x) : x - breaks = Statistics.quantile(xnm, (1:ngroups-1)/ngroups) - if !allowempty && !allunique(breaks) - n = length(unique(breaks)) - 1 - throw(ArgumentError("cannot compute $ngroups quantiles: `quantile` " * - "returned only $n groups due to duplicated values in `x`." * + # Computing extrema is faster than taking 0 and 1 quantiles + min_x, max_x = extrema(xnm) + if (min_x isa Number && isnan(min_x)) || + (max_x isa Number && isnan(max_x)) + throw(ArgumentError("NaN values are not allowed in input vector")) + end + breaks = quantile(xnm, (1:ngroups-1)/ngroups) + breaks = [min_x; breaks; max_x] + if !allowempty && !allunique(@view breaks[1:end-1]) + throw(ArgumentError("cannot compute $ngroups quantiles due to " * + "too many duplicated values in `x`. " * "Pass `allowempty=true` to allow empty quantiles or " * "choose a lower value for `ngroups`.")) end - cut(x, breaks; extend=true, labels=labels, allowempty=allowempty) + cut(x, breaks; labels=labels, allowempty=allowempty) end diff --git a/test/15_extras.jl b/test/15_extras.jl index 14fb4352..1aaf8dc7 100644 --- a/test/15_extras.jl +++ b/test/15_extras.jl @@ -111,9 +111,6 @@ const ≅ = isequal @test isa(x, CategoricalVector{Union{Int, String, T}}) @test isordered(x) @test levels(x) == [0, "2", 4, "6", 8] - - @test_throws ArgumentError cut([-0.0, 0.0], 2) - @test_throws ArgumentError cut([-0.0, 0.0], 2, labels=[-0.0, 0.0]) end @testset "cut with missing values in input" begin @@ -144,6 +141,11 @@ end @test levels(x) == ["Q1: [2.0, 3.5)", "Q2: [3.5, 5.0]"] end +@testset "cut(x, n) with invalid n" begin + @test_throws ArgumentError cut(1:10, 0) + @test_throws ArgumentError cut(1:10, -1) +end + @testset "cut with formatter function" begin my_formatter(from, to, i; leftclosed, rightclosed) = "$i: $from -- $to" @@ -185,11 +187,20 @@ end x = [zeros(10); ones(10)] @test_throws ArgumentError cut(x, [0, 0.1, 0.1, 10]) @test_throws ArgumentError cut(x, 10) + y = cut(x, [0, 0.1, 10, 10]) + @test y == [fill("[0.0, 0.1)", 10); fill("[0.1, 10.0)", 10)] + @test levels(y) == ["[0.0, 0.1)", "[0.1, 10.0)", "[10.0, 10.0]"] @test_throws ArgumentError cut(1:10, [1, 5, 5, 11]) y = cut(1:10, [1, 5, 5, 11], allowempty=true) @test y == cut(1:10, [1, 5, 11]) @test levels(y) == ["[1, 5)", "(5, 5)", "[5, 11]"] + y = cut(1:10, [1, 5, 11, 11]) + @test y == [fill("[1, 5)", 4); fill("[5, 11)", 6)] + @test levels(y) == ["[1, 5)", "[5, 11)", "[11, 11]"] + y = cut(1:10, [1, 5, 10, 10]) + @test y == [fill("[1, 5)", 4); fill("[5, 10)", 5); "[10, 10]"] + @test levels(y) == ["[1, 5)", "[5, 10)", "[10, 10]"] @test_throws ArgumentError cut(1:10, [1, 5, 5, 5, 11]) @test_throws ArgumentError cut(1:10, [1, 5, 5, 11], @@ -242,6 +253,49 @@ end fmt = (from, to, i; leftclosed, rightclosed) -> (i % 2 == 0 ? to : 0.0) @test_throws ArgumentError cut(1:8, 0:2:10, labels=fmt) + + @test_throws ArgumentError cut([fill(1, 10); 4], 2) + @test_throws ArgumentError cut([fill(1, 10); 4], 3) + x = cut([fill(1, 10); 4], 2, allowempty=true) + @test unique(x) == ["Q2: [1.0, 4.0]"] + x = cut([fill(1, 10); 4], 3, allowempty=true) + @test unique(x) == ["Q3: [1.0, 4.0]"] + @test levels(x) == ["Q1: (1.0, 1.0)", "Q2: (1.0, 1.0)", "Q3: [1.0, 4.0]"] + + x = cut([fill(1, 5); fill(4, 5)], 2) + @test x == [fill("Q1: [1.0, 2.5)", 5); fill("Q2: [2.5, 4.0]", 5)] + @test levels(x) == ["Q1: [1.0, 2.5)", "Q2: [2.5, 4.0]"] + @test_throws ArgumentError cut([fill(1, 5); fill(4, 5)], 3) + x = cut([fill(1, 5); fill(4, 5)], 3, allowempty=true) + @test x == [fill("Q2: [1.0, 4.0)", 5); fill("Q3: [4.0, 4.0]", 5)] + @test levels(x) == ["Q1: (1.0, 1.0)", "Q2: [1.0, 4.0)", "Q3: [4.0, 4.0]"] +end + +@testset "cut with -0.0" begin + x = cut([-0.0, 0.0, 0.0, -0.0], 2) + @test x == ["Q1: [-0.0, 0.0)", "Q2: [0.0, 0.0]", "Q2: [0.0, 0.0]", "Q1: [-0.0, 0.0)"] + @test levels(x) == ["Q1: [-0.0, 0.0)", "Q2: [0.0, 0.0]"] + + x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0, 0.0, 0.0]) + @test x == ["[-0.0, 0.0)", "[0.0, 0.0]", "[0.0, 0.0]", "[-0.0, 0.0)"] + @test levels(x) == ["[-0.0, 0.0)", "[0.0, 0.0]"] + + x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0, 0.0]) + @test x == fill("[-0.0, 0.0]", 4) + @test levels(x) == ["[-0.0, 0.0]"] + + x = cut([-0.0, 0.0, 0.0, -0.0], [0.0], extend=true) + @test x == fill("[-0.0, 0.0]", 4) + @test levels(x) == ["[-0.0, 0.0]"] + + x = cut([-0.0, 0.0, 0.0, -0.0], [-0.0], extend=true) + @test x == fill("[-0.0, 0.0]", 4) + @test levels(x) == ["[-0.0, 0.0]"] + + x = cut([-0.0, 0.0, 0.0, -0.0], 2, labels=[-0.0, 0.0]) + @test x == [-0.0, 0.0, 0.0, -0.0] + + @test_throws ArgumentError cut([-0.0, 0.0, 0.0, -0.0], [-0.0, -0.0, 0.0]) end @testset "cut with extend=true" begin @@ -276,4 +330,35 @@ end @test x == ["[-1.0, 0.0)", "[-1.0, 0.0)", "[0.0, 1.0]", "[0.0, 1.0]", "[0.0, 1.0]"] end +@testset "cut with NaN and Inf" begin + @test_throws ArgumentError("NaN values are not allowed in input vector") cut([1, NaN, 2, 3], [1, 10]) + @test_throws ArgumentError("NaN values are not allowed in input vector") cut([1, NaN, 2, 3], [1], extend=true) + @test_throws ArgumentError("NaN values are not allowed in input vector") cut([1, NaN, 2, 3], 2) + @test_throws ArgumentError("NaN values are not allowed in breaks") cut([1, 2], [1, NaN]) + + x = cut([1, Inf], [1], extend=true) + @test x ≅ ["[1.0, Inf]", "[1.0, Inf]"] + @test levels(x) == ["[1.0, Inf]"] + + x = cut([1, -Inf], [1], extend=true) + @test x ≅ ["[-Inf, 1.0]", "[-Inf, 1.0]"] + @test levels(x) == ["[-Inf, 1.0]"] + + x = cut([1:5; Inf], [1, 2, Inf]) + @test x ≅ ["[1.0, 2.0)"; fill("[2.0, Inf]", 5)] + @test levels(x) == ["[1.0, 2.0)", "[2.0, Inf]"] + + x = cut([1:5; -Inf], [-Inf, 2, 5]) + @test x ≅ ["[-Inf, 2.0)"; fill("[2.0, 5.0]", 4); "[-Inf, 2.0)"] + @test levels(x) == ["[-Inf, 2.0)", "[2.0, 5.0]"] + + x = cut([1:5; Inf], 2) + @test x ≅ [fill("Q1: [1.0, 3.5)", 3); fill("Q2: [3.5, Inf]", 3)] + @test levels(x) == ["Q1: [1.0, 3.5)", "Q2: [3.5, Inf]"] + + x = cut([1:5; -Inf], 2) + @test x ≅ [fill("Q1: [-Inf, 2.5)", 2); fill("Q2: [2.5, 5.0]", 3); "Q1: [-Inf, 2.5)"] + @test levels(x) == ["Q1: [-Inf, 2.5)", "Q2: [2.5, 5.0]"] end + +end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 142bd15f..e59180e7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,8 @@ module TestCategoricalArrays using Test using CategoricalArrays + const ≊ = isequal + tests = [ "01_value.jl", "04_constructors.jl",