From 8eaa521270c17096586e45211f4df8a9f6bb57be Mon Sep 17 00:00:00 2001 From: "Thomas S. Binns" Date: Mon, 10 Feb 2025 15:12:39 +0000 Subject: [PATCH] Add support for n-dimensional arrays in `_tfr_from_mt` (#13104) Co-authored-by: Eric Larson --- mne/time_frequency/tfr.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index f4a01e87895..0c8bb0f4fb0 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -4291,19 +4291,20 @@ def _tfr_from_mt(x_mt, weights): Parameters ---------- - x_mt : array, shape (n_channels, n_tapers, n_freqs, n_times) + x_mt : array, shape (..., n_tapers, n_freqs, n_times) The complex-valued multitaper coefficients. weights : array, shape (n_tapers, n_freqs) The weights to use to combine the tapered estimates. Returns ------- - tfr : array, shape (n_channels, n_freqs, n_times) + tfr : array, shape (..., n_freqs, n_times) The time-frequency power estimates. """ - weights = weights[np.newaxis, :, :, np.newaxis] # add singleton channel & time dims + # add singleton dim for time and any dims preceding the tapers + weights = weights[..., np.newaxis] tfr = weights * x_mt tfr *= tfr.conj() - tfr = tfr.real.sum(axis=1) - tfr *= 2 / (weights * weights.conj()).real.sum(axis=1) + tfr = tfr.real.sum(axis=-3) + tfr *= 2 / (weights * weights.conj()).real.sum(axis=-3) return tfr