From 9513c52b3b659e8fb74f150c35d0fb27663eb60f Mon Sep 17 00:00:00 2001 From: Luke Manley Date: Thu, 28 Nov 2024 09:44:38 -0500 Subject: [PATCH] fix: Improve `hist` binning around breakpoints (#20054) --- crates/polars-ops/src/chunked_array/hist.rs | 19 ++++++++----------- py-polars/tests/unit/operations/test_hist.py | 15 +++++++++++++++ 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/hist.rs b/crates/polars-ops/src/chunked_array/hist.rs index ca0b125cb0f3..efca210089f9 100644 --- a/crates/polars-ops/src/chunked_array/hist.rs +++ b/crates/polars-ops/src/chunked_array/hist.rs @@ -98,25 +98,22 @@ where let mut count: Vec = vec![0; num_bins]; let min_break: f64 = breaks[0]; let max_break: f64 = breaks[num_bins]; - let width = breaks[1] - min_break; // guaranteed at least one bin + let scale = num_bins as f64 / (max_break - min_break); for chunk in ca.downcast_iter() { for item in chunk.non_null_values_iter() { let item = item.to_f64().unwrap(); - if include_lower && item == min_break { - count[0] += 1; - } else if item == max_break { - count[num_bins - 1] += 1; - } else if item > min_break && item < max_break { - let width_multiple = (item - min_break) / width; - let idx = width_multiple.floor(); - // handle the case where item lands on the boundary - let idx = if idx == width_multiple { + if item > min_break && item <= max_break { + let idx = scale * (item - min_break); + let idx_floor = idx.floor(); + let idx = if idx == idx_floor { idx - 1.0 } else { - idx + idx_floor }; count[idx as usize] += 1; + } else if include_lower && item == min_break { + count[0] += 1; } } } diff --git a/py-polars/tests/unit/operations/test_hist.py b/py-polars/tests/unit/operations/test_hist.py index bb94f7b94d9b..ee7127e67df5 100644 --- a/py-polars/tests/unit/operations/test_hist.py +++ b/py-polars/tests/unit/operations/test_hist.py @@ -438,3 +438,18 @@ def test_hist_same_values_20030() -> None: } ) assert_frame_equal(out, expected) + + +def test_hist_breakpoint_accuracy() -> None: + s = pl.Series([1, 2, 3, 4]) + out = s.hist(bin_count=3) + expected = pl.DataFrame( + { + "breakpoint": pl.Series([2.0, 3.0, 4.0], dtype=pl.Float64), + "category": pl.Series( + ["(0.997, 2.0]", "(2.0, 3.0]", "(3.0, 4.0]"], dtype=pl.Categorical + ), + "count": pl.Series([2, 1, 1], dtype=pl.get_index_type()), + } + ) + assert_frame_equal(out, expected)