From e3a105599a6258f1de6c4ca92612cbc921238bff Mon Sep 17 00:00:00 2001 From: mrvillage Date: Thu, 14 Nov 2024 18:04:53 -0500 Subject: [PATCH] Missed functions for `LogisticModel` --- src/calc.rs | 42 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/src/calc.rs b/src/calc.rs index 6c500d0..f480343 100644 --- a/src/calc.rs +++ b/src/calc.rs @@ -579,6 +579,48 @@ pub struct LogisticModel { adj_r2: f64, } +impl LogisticModel { + #[inline] + #[cfg_attr(coverage_nightly, coverage(off))] + pub fn slopes(&self) -> &[f64] { + &self.slopes + } + + #[inline] + #[cfg_attr(coverage_nightly, coverage(off))] + pub fn intercept(&self) -> f64 { + self.intercept + } + + #[inline] + #[cfg_attr(coverage_nightly, coverage(off))] + pub fn predicted(&self) -> &[f64] { + &self.predicted + } + + #[inline] + #[cfg_attr(coverage_nightly, coverage(off))] + pub fn r2(&self) -> f64 { + self.r2 + } + + #[inline] + #[cfg_attr(coverage_nightly, coverage(off))] + pub fn adj_r2(&self) -> f64 { + self.adj_r2 + } + + pub fn predict(&self, x: &[f64]) -> f64 { + self.intercept + + self + .slopes + .iter() + .zip(x.iter()) + .map(|(a, b)| a * b) + .sum::() + } +} + #[inline(always)] fn logit(x: f64) -> f64 { (x / (1.0 - x)).ln()