Skip to content

Commit

Permalink
Fix GlobalFrechet normalization
Browse files Browse the repository at this point in the history
  • Loading branch information
matthieubulte committed Jun 20, 2024
1 parent 92fc4af commit a848e3f
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions pyfrechet/regression/frechet_regression/global_frechet.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@ def fit(self, X, y: MetricData):
return self

def weights_for(self, x):
S_inv_dx = cho_solve(self.Sigma_chol_, (x - self.mu_).T).T
return self._normalize_weights(1 + np.sum(S_inv_dx * self.centered_x_train_, axis=1), sum_to_one=True, clip=True)
S_inv_dx = cho_solve(self.Sigma_chol_, x - self.mu_)
return (1 + self.centered_x_train_ @ S_inv_dx)/self.centered_x_train_.shape[0]

0 comments on commit a848e3f

Please sign in to comment.