Skip to content

Commit

Permalink
Rewrote parts of scale matrix fixed point iteration code to use numba…
Browse files Browse the repository at this point in the history
… njit (hopefully faster).
  • Loading branch information
johannvk committed Jan 4, 2025
1 parent 05f0fbe commit 0216759
Showing 1 changed file with 64 additions and 1 deletion.
65 changes: 64 additions & 1 deletion skchange/costs/multivariate_t_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from skchange.utils.numba.stats import log_det_covariance


@jit(nopython=False)
def estimate_mle_cov_scale(centered_samples: np.ndarray, dof: float):
"""Estimate the scale parameter of the MLE covariance matrix."""
p = centered_samples.shape[1]
Expand All @@ -25,6 +26,7 @@ def estimate_mle_cov_scale(centered_samples: np.ndarray, dof: float):
return np.exp(log_alpha)


@jit(nopython=False)
def initial_scale_matrix_estimate(
centered_samples: np.ndarray, t_dof: float, num_zeroed_samples: int = 0
):
Expand Down Expand Up @@ -71,6 +73,33 @@ def scale_matrix_fixed_point_iteration(
return reconstructed_scale_matrix


@njit
def scale_matrix_fixed_point_iteration_njit(
    scale_matrix: np.ndarray,
    t_dof: float,
    centered_samples: np.ndarray,
    num_zeroed_samples: int = 0,
):
    """Perform one fixed point update of the mv_t MLE scale matrix.

    Parameters
    ----------
    scale_matrix : np.ndarray
        Current scale matrix estimate, indexed as a square ``(p, p)`` matrix.
    t_dof : float
        Degrees of freedom of the multivariate t-distribution.
    centered_samples : np.ndarray
        Centered samples, one row per sample, with shape ``(n, p)``.
    num_zeroed_samples : int, optional (default=0)
        Number of samples subtracted from ``n`` when normalising the
        weighted outer-product sum.

    Returns
    -------
    np.ndarray
        The re-estimated ``(p, p)`` scale matrix after one iteration.
    """
    n, p = centered_samples.shape

    # Subtract the number of 'zeroed' samples:
    effective_num_samples = n - num_zeroed_samples

    inv_cov_2d = np.linalg.solve(scale_matrix, np.eye(p))

    # Per-sample quadratic form x_i^T @ inv_cov @ x_i. This is equivalent to
    # np.einsum("ij,jk,ik->i", X, inv_cov, X), written with a matrix product
    # and a row-wise sum because np.einsum is not available in numba's
    # nopython mode.
    z_scores = np.sum((centered_samples @ inv_cov_2d) * centered_samples, axis=1)

    sample_weight = (p + t_dof) / (t_dof + z_scores)
    weighted_samples = centered_samples * sample_weight[:, np.newaxis]

    reconstructed_scale_matrix = (
        weighted_samples.T @ centered_samples
    ) / effective_num_samples

    return reconstructed_scale_matrix


@jit(nopython=False)
def solve_mle_scale_matrix(
initial_scale_matrix: np.ndarray,
centered_samples: np.ndarray,
Expand All @@ -91,7 +120,41 @@ def solve_mle_scale_matrix(
centered_samples=centered_samples,
num_zeroed_samples=num_zeroed_samples,
)
residual = sla.norm(temp_cov_matrix - scale_matrix, ord="fro")

# Note: 'ord = None' computes the Frobenius norm.
residual = np.linalg.norm(temp_cov_matrix - scale_matrix, ord=None)

scale_matrix = temp_cov_matrix.copy()
if residual < reverse_tol:
break

return scale_matrix, iteration


@njit
def solve_mle_scale_matrix_njit(
initial_scale_matrix: np.ndarray,
centered_samples: np.ndarray,
t_dof: float,
num_zeroed_samples: int = 0,
max_iter: int = 50,
reverse_tol: float = 1.0e-3,
) -> np.ndarray:
"""Perform fixed point iterations for the MLE scale matrix."""
scale_matrix = initial_scale_matrix.copy()
temp_cov_matrix = initial_scale_matrix.copy()

# Compute the MLE covariance matrix using fixed point iteration:
for iteration in range(max_iter):
temp_cov_matrix = scale_matrix_fixed_point_iteration_njit(
scale_matrix=scale_matrix,
t_dof=t_dof,
centered_samples=centered_samples,
num_zeroed_samples=num_zeroed_samples,
)

# Note: 'ord = None' computes the Frobenius norm.
residual = np.linalg.norm(temp_cov_matrix - scale_matrix, ord=None)

scale_matrix = temp_cov_matrix.copy()
if residual < reverse_tol:
Expand Down

0 comments on commit 0216759

Please sign in to comment.