Skip to content

Commit

Permalink
fix: use engine linalg in WLS
Browse files Browse the repository at this point in the history
  • Loading branch information
gavincyi committed Sep 22, 2023
1 parent cc8b9f5 commit 98ea2ed
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/fpm_risk_model/regressor/wls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
from typing import Optional

from numpy import ndarray
from numpy.linalg import pinv

from ..engine import LinAlgEngine

linalg = LinAlgEngine()


@dataclass
Expand Down Expand Up @@ -85,14 +88,14 @@ def _close_fit(X: ndarray, y: ndarray, weights: Optional[ndarray] = None):
if len(weights.shape) == 1 and weights.shape[0] == y.shape[0]:
weights = weights**0.5
X_t_w = X.T * weights * weights.T
beta = pinv(X_t_w @ X) @ X_t_w @ y
beta = linalg.pinv(X_t_w @ X) @ X_t_w @ y
else:
raise ValueError(
f"Dimension of y {y.shape} does not align with weights "
f"{weights.shape}"
)
else:
beta = pinv(X.T @ X) @ X.T @ y
beta = linalg.pinv(X.T @ X) @ X.T @ y

alpha = y - X @ beta
return RegressionResult(alpha=alpha, beta=beta)

0 comments on commit 98ea2ed

Please sign in to comment.