Skip to content

Commit

Permalink
Add support for n-dimensional arrays in _tfr_from_mt (mne-tools#13104)
Browse files Browse the repository at this point in the history
Co-authored-by: Eric Larson <larson.eric.d@gmail.com>
  • Loading branch information
tsbinns and larsoner authored Feb 10, 2025
1 parent e4cc4e2 commit 8eaa521
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions mne/time_frequency/tfr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4291,19 +4291,20 @@ def _tfr_from_mt(x_mt, weights):
Parameters
----------
x_mt : array, shape (n_channels, n_tapers, n_freqs, n_times)
x_mt : array, shape (..., n_tapers, n_freqs, n_times)
The complex-valued multitaper coefficients.
weights : array, shape (n_tapers, n_freqs)
The weights to use to combine the tapered estimates.
Returns
-------
tfr : array, shape (n_channels, n_freqs, n_times)
tfr : array, shape (..., n_freqs, n_times)
The time-frequency power estimates.
"""
weights = weights[np.newaxis, :, :, np.newaxis] # add singleton channel & time dims
# add singleton dim for time and any dims preceding the tapers
weights = weights[..., np.newaxis]
tfr = weights * x_mt
tfr *= tfr.conj()
tfr = tfr.real.sum(axis=1)
tfr *= 2 / (weights * weights.conj()).real.sum(axis=1)
tfr = tfr.real.sum(axis=-3)
tfr *= 2 / (weights * weights.conj()).real.sum(axis=-3)
return tfr

0 comments on commit 8eaa521

Please sign in to comment.