@njit
def scale_matrix_fixed_point_iteration_njit(
    scale_matrix: np.ndarray,
    t_dof: float,
    centered_samples: np.ndarray,
    num_zeroed_samples: int = 0,
):
    """Perform one fixed point iteration for the mv_t MLE scale matrix.

    Each sample is re-weighted by ``(p + t_dof) / (t_dof + z)``, where ``z`` is
    the quadratic form ``x @ inv(scale_matrix) @ x`` of that sample, and the
    weighted outer-product sum is normalised by the effective sample count.

    Parameters
    ----------
    scale_matrix : np.ndarray
        Current ``(p, p)`` scale matrix estimate.
    t_dof : float
        Degrees of freedom of the multivariate t-distribution.
    centered_samples : np.ndarray
        ``(n, p)`` array of samples; named as already centered by the caller.
    num_zeroed_samples : int, default=0
        Number of samples subtracted from ``n`` in the normalising constant.

    Returns
    -------
    np.ndarray
        The updated ``(p, p)`` scale matrix estimate.
    """
    n, p = centered_samples.shape

    # Subtract the number of 'zeroed' samples:
    effective_num_samples = n - num_zeroed_samples

    inv_cov_2d = np.linalg.solve(scale_matrix, np.eye(p))

    # Row-wise quadratic forms x_i^T inv(S) x_i. 'np.einsum' is not available
    # in numba's nopython mode, so the contraction "ij,jk,ik->i" is written as
    # a matrix product followed by an element-wise product and a row sum.
    z_scores = np.sum((centered_samples @ inv_cov_2d) * centered_samples, axis=1)

    sample_weight = (p + t_dof) / (t_dof + z_scores)
    weighted_samples = centered_samples * sample_weight[:, np.newaxis]

    reconstructed_scale_matrix = (
        weighted_samples.T @ centered_samples
    ) / effective_num_samples

    return reconstructed_scale_matrix
+ residual = np.linalg.norm(temp_cov_matrix - scale_matrix, ord=None) + + scale_matrix = temp_cov_matrix.copy() + if residual < reverse_tol: + break + + return scale_matrix, iteration + + +@njit +def solve_mle_scale_matrix_njit( + initial_scale_matrix: np.ndarray, + centered_samples: np.ndarray, + t_dof: float, + num_zeroed_samples: int = 0, + max_iter: int = 50, + reverse_tol: float = 1.0e-3, +) -> np.ndarray: + """Perform fixed point iterations for the MLE scale matrix.""" + scale_matrix = initial_scale_matrix.copy() + temp_cov_matrix = initial_scale_matrix.copy() + + # Compute the MLE covariance matrix using fixed point iteration: + for iteration in range(max_iter): + temp_cov_matrix = scale_matrix_fixed_point_iteration_njit( + scale_matrix=scale_matrix, + t_dof=t_dof, + centered_samples=centered_samples, + num_zeroed_samples=num_zeroed_samples, + ) + + # Note: 'ord = None' computes the Frobenius norm. + residual = np.linalg.norm(temp_cov_matrix - scale_matrix, ord=None) scale_matrix = temp_cov_matrix.copy() if residual < reverse_tol: