Skip to content

Commit

Permalink
Update predicted
Browse files Browse the repository at this point in the history
  • Loading branch information
mrvillage committed Nov 13, 2024
1 parent 6f063da commit d7a670b
Showing 1 changed file with 15 additions and 6 deletions.
21 changes: 15 additions & 6 deletions src/calc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,10 @@ pub fn logistic_regression_irls(xs: MatRef<'_, f64>, ys: &[f64]) -> LogisticMode
let r2 = R2Simd::new(ys, &mu).calculate();
let adj_r2 = calculate_adj_r2(r2, ys.len(), xs.ncols());

if should_disable_predicted() {
mu = Vec::new();
}

LogisticModel {
slopes,
intercept,
Expand Down Expand Up @@ -769,15 +773,20 @@ pub fn logistic_regression_newton_raphson(xs: MatRef<'_, f64>, ys: &[f64]) -> Lo
}
beta.copy_from_slice(beta_new.try_as_slice().unwrap());
}
let predicted = (&x * faer::col::from_slice(beta.as_slice()))
.try_as_slice()
.unwrap()
.iter()
.map(|x| logistic(*x))
.collect();
let r2 = R2Simd::new(ys, &mu).calculate();
let adj_r2 = calculate_adj_r2(r2, ys.len(), xs.ncols());

let predicted = if should_disable_predicted() {
Vec::new()
} else {
(&x * faer::col::from_slice(beta.as_slice()))
.try_as_slice()
.unwrap()
.iter()
.map(|x| logistic(*x))
.collect()
};

LogisticModel {
predicted,
intercept: beta[x.ncols() - 1],
Expand Down

0 comments on commit d7a670b

Please sign in to comment.