diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index d8a99200783..18543b854d0 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -21,4 +21,4 @@ jobs: - run: pip install --upgrade towncrier pygithub gitpython numpy - run: python ./.github/actions/rename_towncrier/rename_towncrier.py - run: python ./tools/dev/ensure_headers.py - - uses: autofix-ci/action@ff86a557419858bb967097bfc916833f5647fa8c + - uses: autofix-ci/action@551dded8c6cc8a1054039c8bc0b8b48c51dfc6ef diff --git a/.github/workflows/check_changelog.yml b/.github/workflows/check_changelog.yml index cc85b591977..6995c399b34 100644 --- a/.github/workflows/check_changelog.yml +++ b/.github/workflows/check_changelog.yml @@ -5,6 +5,9 @@ on: # yamllint disable-line rule:truthy types: [opened, synchronize, labeled, unlabeled] branches: ["main"] +permissions: + contents: read + jobs: changelog_checker: name: Check towncrier entry in doc/changes/devel/ diff --git a/.github/workflows/circle_artifacts.yml b/.github/workflows/circle_artifacts.yml index fa32e1ce80c..301c6234eb5 100644 --- a/.github/workflows/circle_artifacts.yml +++ b/.github/workflows/circle_artifacts.yml @@ -1,4 +1,7 @@ on: [status] # yamllint disable-line rule:truthy +permissions: + contents: read + statuses: write jobs: circleci_artifacts_redirector_job: if: "${{ startsWith(github.event.context, 'ci/circleci: build_docs') }}" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cb769988655..fb5a6bd4247 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,7 @@ repos: # Ruff mne - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.9.1 + rev: v0.9.3 hooks: - id: ruff name: ruff lint mne @@ -23,7 +23,7 @@ repos: # Codespell - repo: https://github.com/codespell-project/codespell - rev: v2.3.0 + rev: v2.4.0 hooks: - id: codespell additional_dependencies: @@ -82,7 +82,7 @@ repos: # zizmor - repo: https://github.com/woodruffw/zizmor-pre-commit - rev: v1.1.1 + rev: v1.2.2 hooks: - id: zizmor diff --git a/doc/changes/devel/12071.newfeature.rst b/doc/changes/devel/12071.newfeature.rst new file mode 100644 index 00000000000..4e7995e3beb --- /dev/null +++ b/doc/changes/devel/12071.newfeature.rst @@ -0,0 +1 @@ +Add new ``select`` parameter to :func:`mne.viz.plot_evoked_topo` and :meth:`mne.Evoked.plot_topo` to toggle lasso selection of sensors, by `Marijn van Vliet`_. diff --git a/doc/changes/devel/12656.bugfix.rst b/doc/changes/devel/12656.bugfix.rst new file mode 100644 index 00000000000..3f32dbd23e5 --- /dev/null +++ b/doc/changes/devel/12656.bugfix.rst @@ -0,0 +1 @@ +Fix bug where :func:`mne.export.export_raw` does not correct for recording start time (:attr:`raw.first_time `) when exporting Raw instances to EDF or EEGLAB formats, by `Qian Chu`_. 
\ No newline at end of file diff --git a/doc/changes/devel/13065.bugfix.rst b/doc/changes/devel/13065.bugfix.rst new file mode 100644 index 00000000000..bbaa07ae127 --- /dev/null +++ b/doc/changes/devel/13065.bugfix.rst @@ -0,0 +1,7 @@ +Improved sklearn class compatibility and compliance, which resulted in some parameters of classes having an underscore appended to their name during ``fit``, such as: + +- :class:`mne.decoding.FilterEstimator` parameter ``picks`` passed to the initializer is set as ``est.picks_`` +- :class:`mne.decoding.UnsupervisedSpatialFilter` parameter ``estimator`` passed to the initializer is set as ``est.estimator_`` + +Unused ``verbose`` class parameters (that had no effect) were removed from :class:`~mne.decoding.PSDEstimator`, :class:`~mne.decoding.TemporalFilter`, and :class:`~mne.decoding.FilterEstimator` as well. +Changes by `Eric Larson`_. diff --git a/doc/changes/devel/13070.bugfix.rst b/doc/changes/devel/13070.bugfix.rst new file mode 100644 index 00000000000..3c6a3c25082 --- /dev/null +++ b/doc/changes/devel/13070.bugfix.rst @@ -0,0 +1 @@ +Return events when requested even when current matches the desired sfreq in :meth:`mne.io.Raw.resample` by :newcontrib:`Roy Eric Wieske`. \ No newline at end of file diff --git a/doc/changes/devel/13082.bugfix.rst b/doc/changes/devel/13082.bugfix.rst new file mode 100644 index 00000000000..0f5cad3d0af --- /dev/null +++ b/doc/changes/devel/13082.bugfix.rst @@ -0,0 +1 @@ +Fix bug with automated Mesa 3D detection for proper 3D option setting on systems with software rendering, by `Eric Larson`_. \ No newline at end of file diff --git a/doc/changes/names.inc b/doc/changes/names.inc index eb444c5e594..5a58ac0fa34 100644 --- a/doc/changes/names.inc +++ b/doc/changes/names.inc @@ -258,6 +258,7 @@ .. _Roman Goj: https://romanmne.blogspot.co.uk .. _Ross Maddox: https://www.urmc.rochester.edu/labs/maddox-lab.aspx .. _Rotem Falach: https://github.com/Falach +.. _Roy Eric Wieske: https://github.com/Randomidous .. _Sammi Chekroud: https://github.com/schekroud .. _Samu Taulu: https://phys.washington.edu/people/samu-taulu .. 
_Samuel Deslauriers-Gauthier: https://github.com/sdeslauriers diff --git a/doc/sphinxext/mne_doc_utils.py b/doc/sphinxext/mne_doc_utils.py index 7df361e4af1..e626838f251 100644 --- a/doc/sphinxext/mne_doc_utils.py +++ b/doc/sphinxext/mne_doc_utils.py @@ -97,6 +97,8 @@ def reset_warnings(gallery_conf, fname): r"numpy\.core is deprecated and has been renamed to numpy\._core", # matplotlib "__array_wrap__ must accept context and return_scalar.*", + # nibabel + "__array__ implementation doesn't accept.*", ): warnings.filterwarnings( # deal with other modules having bad imports "ignore", message=f".*{key}.*", category=DeprecationWarning diff --git a/doc/sphinxext/related_software.py b/doc/sphinxext/related_software.py index ab159b0fcb4..2548725390a 100644 --- a/doc/sphinxext/related_software.py +++ b/doc/sphinxext/related_software.py @@ -173,6 +173,7 @@ def _get_packages() -> dict[str, str]: packages = sorted(packages, key=lambda x: x.lower()) packages = [RENAMES.get(package, package) for package in packages] out = dict() + reasons = [] for package in status_iterator( packages, f"Adding {len(packages)} related software packages: " ): @@ -183,12 +184,17 @@ def _get_packages() -> dict[str, str]: else: md = importlib.metadata.metadata(package) except importlib.metadata.PackageNotFoundError: - pass # raise a complete error later + reasons.append(f"{package}: not found, needs to be installed") + continue # raise a complete error later else: # Every project should really have this + do_continue = False for key in ("Summary",): if key not in md: - raise ExtensionError(f"Missing {repr(key)} for {package}") + reasons.extend(f"{package}: missing {repr(key)}") + do_continue = True + if do_continue: + continue # It is annoying to find the home page url = None if "Home-page" in md: @@ -204,15 +210,17 @@ def _get_packages() -> dict[str, str]: if url is not None: break else: - raise RuntimeError( - f"Could not find Home-page for {package} in:\n" - f"{sorted(set(md))}\nwith Summary:\n{md['Summary']}" + reasons.append( + f"{package}: could not find Home-page in {sorted(md)}" ) + continue out[package]["url"] = url out[package]["description"] = md["Summary"].replace("\n", "") - bad = [package for package in packages if not out[package]] - if bad and REQUIRE_METADATA: - raise ExtensionError(f"Could not find metadata for:\n{' '.join(bad)}") + reason_str = "\n".join(reasons) + if reason_str and REQUIRE_METADATA: + raise ExtensionError( + f"Could not find suitable metadata for related software:\n{reason_str}" + ) return out diff --git a/examples/decoding/linear_model_patterns.py b/examples/decoding/linear_model_patterns.py index c1390cbb0d3..7373c0a18b3 100644 --- a/examples/decoding/linear_model_patterns.py +++ b/examples/decoding/linear_model_patterns.py @@ -79,7 +79,7 @@ # Extract and plot spatial filters and spatial patterns for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)): - # We fitted the linear model onto Z-scored data. To make the filters + # We fit the linear model on Z-scored data. 
To make the filters # interpretable, we must reverse this normalization step coef = scaler.inverse_transform([coef])[0] diff --git a/examples/preprocessing/movement_detection.py b/examples/preprocessing/movement_detection.py index 9bcac562588..dd468feb464 100644 --- a/examples/preprocessing/movement_detection.py +++ b/examples/preprocessing/movement_detection.py @@ -81,6 +81,7 @@ ############################################################################## # After checking the annotated movement artifacts, calculate the new transform # and plot it: + new_dev_head_t = compute_average_dev_head_t(raw, head_pos) raw.info["dev_head_t"] = new_dev_head_t fig = mne.viz.plot_alignment( diff --git a/mne/_fiff/proj.py b/mne/_fiff/proj.py index d6ec108e34d..aa010085904 100644 --- a/mne/_fiff/proj.py +++ b/mne/_fiff/proj.py @@ -1100,7 +1100,7 @@ def _has_eeg_average_ref_proj( def _needs_eeg_average_ref_proj(info): - """Determine if the EEG needs an averge EEG reference. + """Determine if the EEG needs an average EEG reference. This returns True if no custom reference has been applied and no average reference projection is present in the list of projections. diff --git a/mne/conftest.py b/mne/conftest.py index 8a4586067b3..2a73c7a1b8e 100644 --- a/mne/conftest.py +++ b/mne/conftest.py @@ -174,6 +174,7 @@ def pytest_configure(config: pytest.Config): # pandas ignore:\n*Pyarrow will become a required dependency of pandas.*:DeprecationWarning ignore:np\.find_common_type is deprecated.*:DeprecationWarning + ignore:Python binding for RankQuantileOptions.*: # pyvista <-> NumPy 2.0 ignore:__array_wrap__ must accept context and return_scalar arguments.*:DeprecationWarning # nibabel <-> NumPy 2.0 diff --git a/mne/cov.py b/mne/cov.py index 94239472fa2..694c836d0cd 100644 --- a/mne/cov.py +++ b/mne/cov.py @@ -1226,7 +1226,7 @@ def _compute_rank_raw_array( from .io import RawArray return _compute_rank( - RawArray(data, info, copy=None, verbose=_verbose_safe_false()), + RawArray(data, info, copy="auto", verbose=_verbose_safe_false()), rank, scalings, info, @@ -1405,7 +1405,7 @@ def _compute_covariance_auto( # project back cov = np.dot(eigvec.T, np.dot(cov, eigvec)) # undo bias - cov *= data.shape[0] / (data.shape[0] - 1) + cov *= data.shape[0] / max(data.shape[0] - 1, 1) # undo scaling _undo_scaling_cov(cov, picks_list, scalings) method_ = method[ei] diff --git a/mne/decoding/base.py b/mne/decoding/base.py index a291416bb17..f73cd976fe3 100644 --- a/mne/decoding/base.py +++ b/mne/decoding/base.py @@ -19,7 +19,7 @@ from sklearn.linear_model import LogisticRegression from sklearn.metrics import check_scoring from sklearn.model_selection import KFold, StratifiedKFold, check_cv -from sklearn.utils import check_array, indexable +from sklearn.utils import check_array, check_X_y, indexable from ..parallel import parallel_func from ..utils import _pl, logger, verbose, warn @@ -76,9 +76,9 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator): ) def __init__(self, model=None): + # TODO: We need to set this to get our tag checking to work properly if model is None: model = LogisticRegression(solver="liblinear") - self.model = model def __sklearn_tags__(self): @@ -122,7 +122,11 @@ def fit(self, X, y, **fit_params): self : instance of LinearModel Returns the modified instance. 
""" - X = check_array(X, input_name="X") + if y is not None: + X = check_array(X) + else: + X, y = check_X_y(X, y) + self.n_features_in_ = X.shape[1] if y is not None: y = check_array(y, dtype=None, ensure_2d=False, input_name="y") if y.ndim > 2: @@ -133,6 +137,7 @@ def fit(self, X, y, **fit_params): # fit the Model self.model.fit(X, y, **fit_params) + self.model_ = self.model # for better sklearn compat # Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y diff --git a/mne/decoding/csp.py b/mne/decoding/csp.py index 9e12335cdbe..ea38fd58ca3 100644 --- a/mne/decoding/csp.py +++ b/mne/decoding/csp.py @@ -6,7 +6,8 @@ import numpy as np from scipy.linalg import eigh -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted from .._fiff.meas_info import create_info from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh @@ -19,10 +20,11 @@ fill_doc, pinv, ) +from .transformer import MNETransformerMixin @fill_doc -class CSP(TransformerMixin, BaseEstimator): +class CSP(MNETransformerMixin, BaseEstimator): """M/EEG signal decomposition using the Common Spatial Patterns (CSP). This class can be used as a supervised decomposition to estimate spatial @@ -112,49 +114,44 @@ def __init__( component_order="mutual_info", ): # Init default CSP - if not isinstance(n_components, int): - raise ValueError("n_components must be an integer.") self.n_components = n_components self.rank = rank self.reg = reg - - # Init default cov_est - if not (cov_est == "concat" or cov_est == "epoch"): - raise ValueError("unknown covariance estimation method") self.cov_est = cov_est - - # Init default transform_into - self.transform_into = _check_option( - "transform_into", transform_into, ["average_power", "csp_space"] - ) - - # Init default log - if transform_into == "average_power": - if log is not None and not isinstance(log, bool): - raise ValueError( - 'log must be a boolean if transform_into == "average_power".' 
- ) - else: - if log is not None: - raise ValueError('log must be a None if transform_into == "csp_space".') + self.transform_into = transform_into self.log = log - - _validate_type(norm_trace, bool, "norm_trace") self.norm_trace = norm_trace self.cov_method_params = cov_method_params - self.component_order = _check_option( - "component_order", component_order, ("mutual_info", "alternate") + self.component_order = component_order + + def _validate_params(self, *, y): + _validate_type(self.n_components, int, "n_components") + if hasattr(self, "cov_est"): + _validate_type(self.cov_est, str, "cov_est") + _check_option("cov_est", self.cov_est, ("concat", "epoch")) + if hasattr(self, "norm_trace"): + _validate_type(self.norm_trace, bool, "norm_trace") + _check_option( + "transform_into", self.transform_into, ["average_power", "csp_space"] ) - - def _check_Xy(self, X, y=None): - """Check input data.""" - if not isinstance(X, np.ndarray): - raise ValueError(f"X should be of type ndarray (got {type(X)}).") - if y is not None: - if len(X) != len(y) or len(y) < 1: - raise ValueError("X and y must have the same length.") - if X.ndim < 3: - raise ValueError("X must have at least 3 dimensions.") + if self.transform_into == "average_power": + _validate_type( + self.log, + (bool, None), + "log", + extra="when transform_into is 'average_power'", + ) + else: + _validate_type( + self.log, None, "log", extra="when transform_into is 'csp_space'" + ) + _check_option( + "component_order", self.component_order, ("mutual_info", "alternate") + ) + self.classes_ = np.unique(y) + n_classes = len(self.classes_) + if n_classes < 2: + raise ValueError(f"n_classes must be >= 2, but got {n_classes} class") def fit(self, X, y): """Estimate the CSP decomposition on epochs. @@ -171,12 +168,9 @@ def fit(self, X, y): self : instance of CSP Returns the modified instance. """ - self._check_Xy(X, y) - - self._classes = np.unique(y) - n_classes = len(self._classes) - if n_classes < 2: - raise ValueError("n_classes must be >= 2.") + X, y = self._check_data(X, y=y, fit=True, return_y=True) + self._validate_params(y=y) + n_classes = len(self.classes_) if n_classes > 2 and self.component_order == "alternate": raise ValueError( "component_order='alternate' requires two classes, but data contains " @@ -225,13 +219,8 @@ def transform(self, X): If self.transform_into == 'csp_space' then returns the data in CSP space and shape is (n_epochs, n_components, n_times). """ - if not isinstance(X, np.ndarray): - raise ValueError(f"X should be of type ndarray (got {type(X)}).") - if self.filters_ is None: - raise RuntimeError( - "No filters available. Please first fit CSP decomposition." 
- ) - + check_is_fitted(self, "filters_") + X = self._check_data(X) pick_filters = self.filters_[: self.n_components] X = np.asarray([np.dot(pick_filters, epoch) for epoch in X]) @@ -577,7 +566,7 @@ def _compute_covariance_matrices(self, X, y): covs = [] sample_weights = [] - for ci, this_class in enumerate(self._classes): + for ci, this_class in enumerate(self.classes_): cov, weight = cov_estimator( X[y == this_class], cov_kind=f"class={this_class}", @@ -689,7 +678,7 @@ def _normalize_eigenvectors(self, eigen_vectors, covs, sample_weights): def _order_components( self, covs, sample_weights, eigen_vectors, eigen_values, component_order ): - n_classes = len(self._classes) + n_classes = len(self.classes_) if component_order == "mutual_info" and n_classes > 2: mutual_info = self._compute_mutual_info(covs, sample_weights, eigen_vectors) ix = np.argsort(mutual_info)[::-1] @@ -889,10 +878,8 @@ def fit(self, X, y): self : instance of SPoC Returns the modified instance. """ - self._check_Xy(X, y) - - if len(np.unique(y)) < 2: - raise ValueError("y must have at least two distinct values.") + X, y = self._check_data(X, y=y, fit=True, return_y=True) + self._validate_params(y=y) # The following code is directly copied from pyRiemann diff --git a/mne/decoding/ems.py b/mne/decoding/ems.py index b3e72a30e21..5c7557798ef 100644 --- a/mne/decoding/ems.py +++ b/mne/decoding/ems.py @@ -5,15 +5,16 @@ from collections import Counter import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator from .._fiff.pick import _picks_to_idx, pick_info, pick_types from ..parallel import parallel_func from ..utils import logger, verbose from .base import _set_cv +from .transformer import MNETransformerMixin -class EMS(TransformerMixin, BaseEstimator): +class EMS(MNETransformerMixin, BaseEstimator): """Transformer to compute event-matched spatial filters. This version of EMS :footcite:`SchurgerEtAl2013` operates on the entire @@ -37,6 +38,16 @@ class EMS(TransformerMixin, BaseEstimator): .. footbibliography:: """ + def __sklearn_tags__(self): + """Return sklearn tags.""" + from sklearn.utils import ClassifierTags + + tags = super().__sklearn_tags__() + if tags.classifier_tags is None: + tags.classifier_tags = ClassifierTags() + tags.classifier_tags.multi_class = False + return tags + def __repr__(self): # noqa: D105 if hasattr(self, "filters_"): return ( @@ -64,11 +75,12 @@ def fit(self, X, y): self : instance of EMS Returns self. """ - classes = np.unique(y) - if len(classes) != 2: + X, y = self._check_data(X, y=y, fit=True, return_y=True) + classes, y = np.unique(y, return_inverse=True) + if len(classes) > 2: raise ValueError("EMS only works for binary classification.") self.classes_ = classes - filters = X[y == classes[0]].mean(0) - X[y == classes[1]].mean(0) + filters = X[y == 0].mean(0) - X[y == 1].mean(0) filters /= np.linalg.norm(filters, axis=0)[None, :] self.filters_ = filters return self @@ -86,13 +98,14 @@ def transform(self, X): X : array, shape (n_epochs, n_times) The input data transformed by the spatial filters. """ + X = self._check_data(X) Xt = np.sum(X * self.filters_, axis=1) return Xt @verbose def compute_ems( - epochs, conditions=None, picks=None, n_jobs=None, cv=None, verbose=None + epochs, conditions=None, picks=None, n_jobs=None, cv=None, *, verbose=None ): """Compute event-matched spatial filter on epochs. 
diff --git a/mne/decoding/search_light.py b/mne/decoding/search_light.py index e3059a3e959..8bd96781185 100644 --- a/mne/decoding/search_light.py +++ b/mne/decoding/search_light.py @@ -5,18 +5,25 @@ import logging import numpy as np -from sklearn.base import BaseEstimator, MetaEstimatorMixin, TransformerMixin, clone +from sklearn.base import BaseEstimator, MetaEstimatorMixin, clone from sklearn.metrics import check_scoring from sklearn.preprocessing import LabelEncoder -from sklearn.utils import check_array +from sklearn.utils.validation import check_is_fitted from ..parallel import parallel_func -from ..utils import ProgressBar, _parse_verbose, array_split_idx, fill_doc, verbose +from ..utils import ( + ProgressBar, + _parse_verbose, + _verbose_safe_false, + array_split_idx, + fill_doc, +) from .base import _check_estimator +from .transformer import MNETransformerMixin @fill_doc -class SlidingEstimator(MetaEstimatorMixin, TransformerMixin, BaseEstimator): +class SlidingEstimator(MetaEstimatorMixin, MNETransformerMixin, BaseEstimator): """Search Light. Fit, predict and score a series of models to each subset of the dataset @@ -38,7 +45,6 @@ class SlidingEstimator(MetaEstimatorMixin, TransformerMixin, BaseEstimator): List of fitted scikit-learn estimators (one per task). """ - @verbose def __init__( self, base_estimator, @@ -49,7 +55,6 @@ def __init__( allow_2d=False, verbose=None, ): - _check_estimator(base_estimator) self.base_estimator = base_estimator self.n_jobs = n_jobs self.scoring = scoring @@ -102,9 +107,13 @@ def fit(self, X, y, **fit_params): self : object Return self. """ - X = self._check_Xy(X, y) + _check_estimator(self.base_estimator) + X, _ = self._check_Xy(X, y, fit=True) parallel, p_func, n_jobs = parallel_func( - _sl_fit, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_fit, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) self.estimators_ = list() self.fit_params_ = fit_params @@ -153,14 +162,19 @@ def fit_transform(self, X, y, **fit_params): def _transform(self, X, method): """Aux. function to make parallel predictions/transformation.""" - X = self._check_Xy(X) + X, is_nd = self._check_Xy(X) + orig_method = method + check_is_fitted(self) method = _check_method(self.base_estimator, method) if X.shape[-1] != len(self.estimators_): raise ValueError("The number of estimators does not match X.shape[-1]") # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _sl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_transform, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) X_splits = np.array_split(X, n_jobs, axis=-1) @@ -174,6 +188,10 @@ def _transform(self, X, method): ) y_pred = np.concatenate(y_pred, axis=1) + if orig_method == "transform": + y_pred = y_pred.astype(X.dtype) + if orig_method == "predict_proba" and not is_nd: + y_pred = y_pred[:, 0, :] return y_pred def transform(self, X): @@ -196,7 +214,7 @@ def transform(self, X): Xt : array, shape (n_samples, n_estimators) The transformed values generated by each estimator. """ # noqa: E501 - return self._transform(X, "transform").astype(X.dtype) + return self._transform(X, "transform") def predict(self, X): """Predict each data slice/task with a series of independent estimators. 
@@ -265,15 +283,12 @@ def decision_function(self, X): """ # noqa: E501 return self._transform(X, "decision_function") - def _check_Xy(self, X, y=None): + def _check_Xy(self, X, y=None, fit=False): """Aux. function to check input data.""" # Once we require sklearn 1.1+ we should do something like: - X = check_array(X, ensure_2d=False, allow_nd=True, input_name="X") - if y is not None: - y = check_array(y, dtype=None, ensure_2d=False, input_name="y") - if len(X) != len(y) or len(y) < 1: - raise ValueError("X and y must have the same length.") - if X.ndim < 3: + X = self._check_data(X, y=y, atleast_3d=False, fit=fit) + is_nd = X.ndim >= 3 + if not is_nd: err = None if not self.allow_2d: err = 3 @@ -282,7 +297,7 @@ def _check_Xy(self, X, y=None): if err: raise ValueError(f"X must have at least {err} dimensions.") X = X[..., np.newaxis] - return X + return X, is_nd def score(self, X, y): """Score each estimator on each task. @@ -307,7 +322,7 @@ def score(self, X, y): score : array, shape (n_samples, n_estimators) Score for each estimator/task. """ # noqa: E501 - X = self._check_Xy(X, y) + X, _ = self._check_Xy(X, y) if X.shape[-1] != len(self.estimators_): raise ValueError("The number of estimators does not match X.shape[-1]") @@ -317,7 +332,10 @@ def score(self, X, y): # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. parallel, p_func, n_jobs = parallel_func( - _sl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _sl_score, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) X_splits = np.array_split(X, n_jobs, axis=-1) est_splits = np.array_split(self.estimators_, n_jobs) @@ -483,11 +501,16 @@ def __repr__(self): # noqa: D105 def _transform(self, X, method): """Aux. function to make parallel predictions/transformation.""" - X = self._check_Xy(X) + X, is_nd = self._check_Xy(X) + check_is_fitted(self) + orig_method = method method = _check_method(self.base_estimator, method) parallel, p_func, n_jobs = parallel_func( - _gl_transform, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _gl_transform, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) context = _create_progressbar_context(self, X, "Transforming") @@ -500,6 +523,10 @@ def _transform(self, X, method): ) y_pred = np.concatenate(y_pred, axis=2) + if orig_method == "transform": + y_pred = y_pred.astype(X.dtype) + if orig_method == "predict_proba" and not is_nd: + y_pred = y_pred[:, 0, 0, :] return y_pred def transform(self, X): @@ -518,6 +545,7 @@ def transform(self, X): Xt : array, shape (n_samples, n_estimators, n_slices) The transformed values generated by each estimator. """ + check_is_fitted(self) return self._transform(X, "transform") def predict(self, X): @@ -603,11 +631,14 @@ def score(self, X, y): score : array, shape (n_samples, n_estimators, n_slices) Score for each estimator / data slice couple. """ # noqa: E501 - X = self._check_Xy(X, y) + X, _ = self._check_Xy(X, y) # For predictions/transforms the parallelization is across the data and # not across the estimators to avoid memory load. 
parallel, p_func, n_jobs = parallel_func( - _gl_score, self.n_jobs, max_jobs=X.shape[-1], verbose=False + _gl_score, + self.n_jobs, + max_jobs=X.shape[-1], + verbose=_verbose_safe_false(), ) scoring = check_scoring(self.base_estimator, self.scoring) y = _fix_auc(scoring, y) diff --git a/mne/decoding/ssd.py b/mne/decoding/ssd.py index 8bc0036d315..111ded9f274 100644 --- a/mne/decoding/ssd.py +++ b/mne/decoding/ssd.py @@ -4,8 +4,10 @@ import numpy as np from scipy.linalg import eigh -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted +from .._fiff.meas_info import Info, create_info from .._fiff.pick import _picks_to_idx from ..cov import Covariance, _regularized_covariance from ..defaults import _handle_default @@ -13,17 +15,17 @@ from ..rank import compute_rank from ..time_frequency import psd_array_welch from ..utils import ( - _check_option, _time_mask, _validate_type, _verbose_safe_false, fill_doc, logger, ) +from .transformer import MNETransformerMixin @fill_doc -class SSD(TransformerMixin, BaseEstimator): +class SSD(MNETransformerMixin, BaseEstimator): """ Signal decomposition using the Spatio-Spectral Decomposition (SSD). @@ -64,7 +66,7 @@ class SSD(TransformerMixin, BaseEstimator): If sort_by_spectral_ratio is set to True, then the SSD sources will be sorted according to their spectral ratio which is calculated based on :func:`mne.time_frequency.psd_array_welch`. The n_fft parameter sets the - length of FFT used. + length of FFT used. The default (None) will use 1 second of data. See :func:`mne.time_frequency.psd_array_welch` for more information. cov_method_params : dict | None (default None) As in :class:`mne.decoding.SPoC` @@ -104,7 +106,25 @@ def __init__( rank=None, ): """Initialize instance.""" - dicts = {"signal": filt_params_signal, "noise": filt_params_noise} + self.info = info + self.filt_params_signal = filt_params_signal + self.filt_params_noise = filt_params_noise + self.reg = reg + self.n_components = n_components + self.picks = picks + self.sort_by_spectral_ratio = sort_by_spectral_ratio + self.return_filtered = return_filtered + self.n_fft = n_fft + self.cov_method_params = cov_method_params + self.rank = rank + + def _validate_params(self, X): + if isinstance(self.info, float): # special case, mostly for testing + self.sfreq_ = self.info + else: + _validate_type(self.info, Info, "info") + self.sfreq_ = self.info["sfreq"] + dicts = {"signal": self.filt_params_signal, "noise": self.filt_params_noise} for param, dd in [("l", 0), ("h", 0), ("l", 1), ("h", 1)]: key = ("signal", "noise")[dd] if param + "_freq" not in dicts[key]: @@ -116,48 +136,47 @@ def __init__( _validate_type(val, ("numeric",), f"{key} {param}_freq") # check freq bands if ( - filt_params_noise["l_freq"] > filt_params_signal["l_freq"] - or filt_params_signal["h_freq"] > filt_params_noise["h_freq"] + self.filt_params_noise["l_freq"] > self.filt_params_signal["l_freq"] + or self.filt_params_signal["h_freq"] > self.filt_params_noise["h_freq"] ): raise ValueError( "Wrongly specified frequency bands!\n" "The signal band-pass must be within the noise " "band-pass!" 
) - self.picks = picks - del picks - self.info = info - self.freqs_signal = (filt_params_signal["l_freq"], filt_params_signal["h_freq"]) - self.freqs_noise = (filt_params_noise["l_freq"], filt_params_noise["h_freq"]) - self.filt_params_signal = filt_params_signal - self.filt_params_noise = filt_params_noise - # check if boolean - if not isinstance(sort_by_spectral_ratio, (bool)): - raise ValueError("sort_by_spectral_ratio must be boolean") - self.sort_by_spectral_ratio = sort_by_spectral_ratio - if n_fft is None: - self.n_fft = int(self.info["sfreq"]) - else: - self.n_fft = int(n_fft) - # check if boolean - if not isinstance(return_filtered, (bool)): - raise ValueError("return_filtered must be boolean") - self.return_filtered = return_filtered - self.reg = reg - self.n_components = n_components - self.rank = rank - self.cov_method_params = cov_method_params + self.freqs_signal_ = ( + self.filt_params_signal["l_freq"], + self.filt_params_signal["h_freq"], + ) + self.freqs_noise_ = ( + self.filt_params_noise["l_freq"], + self.filt_params_noise["h_freq"], + ) + _validate_type(self.sort_by_spectral_ratio, (bool,), "sort_by_spectral_ratio") + _validate_type(self.n_fft, ("numeric", None), "n_fft") + self.n_fft_ = min( + int(self.n_fft if self.n_fft is not None else self.sfreq_), + X.shape[-1], + ) + _validate_type(self.return_filtered, (bool,), "return_filtered") + if isinstance(self.info, Info): + ch_types = self.info.get_channel_types(picks=self.picks, unique=True) + if len(ch_types) > 1: + raise ValueError( + "At this point SSD only supports fitting " + f"single channel types. Your info has {len(ch_types)} types." + ) - def _check_X(self, X): + def _check_X(self, X, *, y=None, fit=False): """Check input data.""" - _validate_type(X, np.ndarray, "X") - _check_option("X.ndim", X.ndim, (2, 3)) + X = self._check_data(X, y=y, fit=fit, atleast_3d=False) n_chan = X.shape[-2] - if n_chan != self.info["nchan"]: + if isinstance(self.info, Info) and n_chan != self.info["nchan"]: raise ValueError( "Info must match the input data." f"Found {n_chan} channels but expected {self.info['nchan']}." ) + return X def fit(self, X, y=None): """Estimate the SSD decomposition on raw or epoched data. @@ -176,18 +195,17 @@ def fit(self, X, y=None): self : instance of SSD Returns the modified instance. """ - ch_types = self.info.get_channel_types(picks=self.picks, unique=True) - if len(ch_types) > 1: - raise ValueError( - "At this point SSD only supports fitting " - f"single channel types. Your info has {len(ch_types)} types." 
- ) - self.picks_ = _picks_to_idx(self.info, self.picks, none="data", exclude="bads") - self._check_X(X) + X = self._check_X(X, y=y, fit=True) + self._validate_params(X) + if isinstance(self.info, Info): + info = self.info + else: + info = create_info(X.shape[-2], self.sfreq_, ch_types="eeg") + self.picks_ = _picks_to_idx(info, self.picks, none="data", exclude="bads") X_aux = X[..., self.picks_, :] - X_signal = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) - X_noise = filter_data(X_aux, self.info["sfreq"], **self.filt_params_noise) + X_signal = filter_data(X_aux, self.sfreq_, **self.filt_params_signal) + X_noise = filter_data(X_aux, self.sfreq_, **self.filt_params_noise) X_noise -= X_signal if X.ndim == 3: X_signal = np.hstack(X_signal) @@ -199,19 +217,19 @@ def fit(self, X, y=None): reg=self.reg, method_params=self.cov_method_params, rank="full", - info=self.info, + info=info, ) cov_noise = _regularized_covariance( X_noise, reg=self.reg, method_params=self.cov_method_params, rank="full", - info=self.info, + info=info, ) # project cov to rank subspace cov_signal, cov_noise, rank_proj = _dimensionality_reduction( - cov_signal, cov_noise, self.info, self.rank + cov_signal, cov_noise, info, self.rank ) eigvals_, eigvects_ = eigh(cov_signal, cov_noise) @@ -226,10 +244,10 @@ def fit(self, X, y=None): # than the initial ordering. This ordering should be also learned when # fitting. X_ssd = self.filters_.T @ X[..., self.picks_, :] - sorter_spec = Ellipsis + sorter_spec = slice(None) if self.sort_by_spectral_ratio: _, sorter_spec = self.get_spectral_ratio(ssd_sources=X_ssd) - self.sorter_spec = sorter_spec + self.sorter_spec_ = sorter_spec logger.info("Done.") return self @@ -248,17 +266,13 @@ def transform(self, X): X_ssd : array, shape ([n_epochs, ]n_components, n_times) The processed data. """ - self._check_X(X) - if self.filters_ is None: - raise RuntimeError("No filters available. Please first call fit") + check_is_fitted(self, "filters_") + X = self._check_X(X) if self.return_filtered: X_aux = X[..., self.picks_, :] - X = filter_data(X_aux, self.info["sfreq"], **self.filt_params_signal) + X = filter_data(X_aux, self.sfreq_, **self.filt_params_signal) X_ssd = self.filters_.T @ X[..., self.picks_, :] - if X.ndim == 2: - X_ssd = X_ssd[self.sorter_spec][: self.n_components] - else: - X_ssd = X_ssd[:, self.sorter_spec, :][:, : self.n_components, :] + X_ssd = X_ssd[..., self.sorter_spec_, :][..., : self.n_components, :] return X_ssd def fit_transform(self, X, y=None, **fit_params): @@ -308,11 +322,9 @@ def get_spectral_ratio(self, ssd_sources): ---------- .. footbibliography:: """ - psd, freqs = psd_array_welch( - ssd_sources, sfreq=self.info["sfreq"], n_fft=self.n_fft - ) - sig_idx = _time_mask(freqs, *self.freqs_signal) - noise_idx = _time_mask(freqs, *self.freqs_noise) + psd, freqs = psd_array_welch(ssd_sources, sfreq=self.sfreq_, n_fft=self.n_fft_) + sig_idx = _time_mask(freqs, *self.freqs_signal_) + noise_idx = _time_mask(freqs, *self.freqs_noise_) if psd.ndim == 3: mean_sig = psd[:, :, sig_idx].mean(axis=2).mean(axis=0) mean_noise = psd[:, :, noise_idx].mean(axis=2).mean(axis=0) @@ -352,7 +364,7 @@ def apply(self, X): The processed data. 
""" X_ssd = self.transform(X) - pick_patterns = self.patterns_[self.sorter_spec][: self.n_components].T + pick_patterns = self.patterns_[self.sorter_spec_][: self.n_components].T X = pick_patterns @ X_ssd return X diff --git a/mne/decoding/tests/test_base.py b/mne/decoding/tests/test_base.py index 6d915dd24f9..504e309d53c 100644 --- a/mne/decoding/tests/test_base.py +++ b/mne/decoding/tests/test_base.py @@ -86,6 +86,8 @@ def _make_data(n_samples=1000, n_features=5, n_targets=3): X = Y.dot(A.T) X += np.random.randn(n_samples, n_features) # add noise X += np.random.rand(n_features) # Put an offset + if n_targets == 1: + Y = Y[:, 0] return X, Y, A @@ -95,7 +97,7 @@ def test_get_coef(): """Test getting linear coefficients (filters/patterns) from estimators.""" lm_classification = LinearModel() assert hasattr(lm_classification, "__sklearn_tags__") - print(lm_classification.__sklearn_tags__) + print(lm_classification.__sklearn_tags__()) assert is_classifier(lm_classification.model) assert is_classifier(lm_classification) assert not is_regressor(lm_classification.model) @@ -273,7 +275,12 @@ def test_get_coef_multiclass(n_features, n_targets): """Test get_coef on multiclass problems.""" # Check patterns with more than 1 regressor X, Y, A = _make_data(n_samples=30000, n_features=n_features, n_targets=n_targets) - lm = LinearModel(LinearRegression()).fit(X, Y) + lm = LinearModel(LinearRegression()) + assert not hasattr(lm, "model_") + lm.fit(X, Y) + # TODO: modifying non-underscored `model` is a sklearn no-no, maybe should be a + # metaestimator? + assert lm.model is lm.model_ assert_array_equal(lm.filters_.shape, lm.patterns_.shape) if n_targets == 1: want_shape = (n_features,) @@ -473,9 +480,8 @@ def test_cross_val_multiscore(): def test_sklearn_compliance(estimator, check): """Test LinearModel compliance with sklearn.""" ignores = ( - "check_n_features_in", # maybe we should add this someday? - "check_estimator_sparse_data", # we densify "check_estimators_overwrite_params", # self.model changes! 
+ "check_dont_overwrite_parameters", "check_parameters_default_constructible", ) if any(ignore in str(check) for ignore in ignores): diff --git a/mne/decoding/tests/test_csp.py b/mne/decoding/tests/test_csp.py index 7a1a83feeaf..e754b6952f9 100644 --- a/mne/decoding/tests/test_csp.py +++ b/mne/decoding/tests/test_csp.py @@ -19,6 +19,7 @@ from sklearn.model_selection import StratifiedKFold, cross_val_score from sklearn.pipeline import Pipeline, make_pipeline from sklearn.svm import SVC +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import Epochs, compute_proj_raw, io, pick_types, read_events from mne.decoding import CSP, LinearModel, Scaler, SPoC, get_coef @@ -139,18 +140,22 @@ def test_csp(): y = epochs.events[:, -1] # Init - pytest.raises(ValueError, CSP, n_components="foo", norm_trace=False) + csp = CSP(n_components="foo") + with pytest.raises(TypeError, match="must be an instance"): + csp.fit(epochs_data, y) for reg in ["foo", -0.1, 1.1]: csp = CSP(reg=reg, norm_trace=False) pytest.raises(ValueError, csp.fit, epochs_data, epochs.events[:, -1]) for reg in ["oas", "ledoit_wolf", 0, 0.5, 1.0]: CSP(reg=reg, norm_trace=False) - for cov_est in ["foo", None]: - pytest.raises(ValueError, CSP, cov_est=cov_est, norm_trace=False) + csp = CSP(cov_est="foo", norm_trace=False) + with pytest.raises(ValueError, match="Invalid value"): + csp.fit(epochs_data, y) + csp = CSP(norm_trace="foo") with pytest.raises(TypeError, match="instance of bool"): - CSP(norm_trace="foo") + csp.fit(epochs_data, y) for cov_est in ["concat", "epoch"]: - CSP(cov_est=cov_est, norm_trace=False) + CSP(cov_est=cov_est, norm_trace=False).fit(epochs_data, y) n_components = 3 # Fit @@ -171,8 +176,8 @@ def test_csp(): # Test data exception pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) - pytest.raises(ValueError, csp.fit, epochs, y) - pytest.raises(ValueError, csp.transform, epochs) + pytest.raises(ValueError, csp.fit, "foo", y) + pytest.raises(ValueError, csp.transform, "foo") # Test plots epochs.pick(picks="mag") @@ -200,7 +205,7 @@ def test_csp(): for cov_est in ["concat", "epoch"]: csp = CSP(n_components=n_components, cov_est=cov_est, norm_trace=False) csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) - assert_equal(len(csp._classes), 4) + assert_equal(len(csp.classes_), 4) assert_array_equal(csp.filters_.shape, [n_channels, n_channels]) assert_array_equal(csp.patterns_.shape, [n_channels, n_channels]) @@ -220,15 +225,17 @@ def test_csp(): # Different normalization return different transform assert np.sum((X_trans["True"] - X_trans["False"]) ** 2) > 1.0 # Check wrong inputs - pytest.raises(ValueError, CSP, transform_into="average_power", log="foo") + csp = CSP(transform_into="average_power", log="foo") + with pytest.raises(TypeError, match="must be an instance of bool"): + csp.fit(epochs_data, epochs.events[:, 2]) # Test csp space transform csp = CSP(transform_into="csp_space", norm_trace=False) assert csp.transform_into == "csp_space" for log in ("foo", True, False): - pytest.raises( - ValueError, CSP, transform_into="csp_space", log=log, norm_trace=False - ) + csp = CSP(transform_into="csp_space", log=log, norm_trace=False) + with pytest.raises(TypeError, match="must be an instance"): + csp.fit(epochs_data, epochs.events[:, 2]) n_components = 2 csp = CSP(n_components=n_components, transform_into="csp_space", norm_trace=False) Xt = csp.fit(epochs_data, epochs.events[:, 2]).transform(epochs_data) @@ -343,8 +350,8 @@ def test_regularized_csp(ch_type, 
rank, reg): # test init exception pytest.raises(ValueError, csp.fit, epochs_data, np.zeros_like(epochs.events)) - pytest.raises(ValueError, csp.fit, epochs, y) - pytest.raises(ValueError, csp.transform, epochs) + pytest.raises(ValueError, csp.fit, "foo", y) + pytest.raises(ValueError, csp.transform, "foo") csp.n_components = n_components sources = csp.transform(epochs_data) @@ -465,7 +472,9 @@ def test_csp_component_ordering(): """Test that CSP component ordering works as expected.""" x, y = deterministic_toy_data(["class_a", "class_b"]) - pytest.raises(ValueError, CSP, component_order="invalid") + csp = CSP(component_order="invalid") + with pytest.raises(ValueError, match="Invalid value"): + csp.fit(x, y) # component_order='alternate' only works with two classes csp = CSP(component_order="alternate") @@ -480,3 +489,10 @@ def test_csp_component_ordering(): # p_alt arranges them to [0.8, 0.06, 0.5, 0.1] # p_mut arranges them to [0.06, 0.1, 0.8, 0.5] assert_array_almost_equal(p_alt, p_mut[[2, 0, 3, 1]]) + + +@pytest.mark.filterwarnings("ignore:.*Only one sample available.*") +@parametrize_with_checks([CSP(), SPoC()]) +def test_sklearn_compliance(estimator, check): + """Test compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_ems.py b/mne/decoding/tests/test_ems.py index 10774c0681a..dc54303a541 100644 --- a/mne/decoding/tests/test_ems.py +++ b/mne/decoding/tests/test_ems.py @@ -11,6 +11,7 @@ pytest.importorskip("sklearn") from sklearn.model_selection import StratifiedKFold +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import Epochs, io, pick_types, read_events from mne.decoding import EMS, compute_ems @@ -91,3 +92,9 @@ def test_ems(): assert_equal(ems.__repr__(), "") assert_array_almost_equal(filters, np.mean(coefs, axis=0)) assert_array_almost_equal(surrogates, np.vstack(Xt)) + + +@parametrize_with_checks([EMS()]) +def test_sklearn_compliance(estimator, check): + """Test compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_search_light.py b/mne/decoding/tests/test_search_light.py index 7cb3a66dd81..e7abfd9209e 100644 --- a/mne/decoding/tests/test_search_light.py +++ b/mne/decoding/tests/test_search_light.py @@ -41,7 +41,7 @@ def make_data(): return X, y -def test_search_light(): +def test_search_light_basic(): """Test SlidingEstimator.""" # https://github.com/scikit-learn/scikit-learn/issues/27711 if platform.system() == "Windows" and check_version("numpy", "2.0.0.dev0"): @@ -52,7 +52,9 @@ def test_search_light(): X, y = make_data() n_epochs, _, n_time = X.shape # init - pytest.raises(ValueError, SlidingEstimator, "foo") + sl = SlidingEstimator("foo") + with pytest.raises(ValueError, match="must be"): + sl.fit(X, y) sl = SlidingEstimator(Ridge()) assert not is_classifier(sl) sl = SlidingEstimator(LogisticRegression(solver="liblinear")) @@ -69,7 +71,8 @@ def test_search_light(): # transforms pytest.raises(ValueError, sl.predict, X[:, :, :2]) y_trans = sl.transform(X) - assert X.dtype == y_trans.dtype == np.dtype(float) + assert X.dtype == float + assert y_trans.dtype == float y_pred = sl.predict(X) assert y_pred.dtype == np.dtype(int) assert_array_equal(y_pred.shape, [n_epochs, n_time]) @@ -344,22 +347,19 @@ def predict_proba(self, X): @pytest.mark.slowtest -@parametrize_with_checks([SlidingEstimator(LogisticRegression(), allow_2d=True)]) +@parametrize_with_checks( + [ + SlidingEstimator(LogisticRegression(), allow_2d=True), + GeneralizingEstimator(LogisticRegression(), allow_2d=True), + ] +) 
def test_sklearn_compliance(estimator, check): """Test LinearModel compliance with sklearn.""" ignores = ( - "check_estimator_sparse_data", # we densify - "check_classifiers_one_label_sample_weights", # don't handle singleton - "check_classifiers_classes", # dim mismatch + # TODO: we don't handle singleton right (probably) + "check_classifiers_one_label_sample_weights", + "check_classifiers_classes", "check_classifiers_train", - "check_decision_proba_consistency", - "check_parameters_default_constructible", - # Should probably fix these? - "check_estimators_unfitted", - "check_transformer_data_not_an_array", - "check_n_features_in", - "check_fit2d_predict1d", - "check_do_not_raise_errors_in_init_or_set_params", ) if any(ignore in str(check) for ignore in ignores): return diff --git a/mne/decoding/tests/test_ssd.py b/mne/decoding/tests/test_ssd.py index 8f4d2472803..b6cdfc472c3 100644 --- a/mne/decoding/tests/test_ssd.py +++ b/mne/decoding/tests/test_ssd.py @@ -11,6 +11,7 @@ pytest.importorskip("sklearn") from sklearn.pipeline import Pipeline +from sklearn.utils.estimator_checks import parametrize_with_checks from mne import create_info, io from mne.decoding import CSP @@ -101,8 +102,9 @@ def test_ssd(): l_trans_bandwidth=1, h_trans_bandwidth=1, ) + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(TypeError, match="must be an instance "): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # Wrongly specified noise band freq = 2 @@ -115,14 +117,16 @@ def test_ssd(): l_trans_bandwidth=1, h_trans_bandwidth=1, ) + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(ValueError, match="Wrongly specified "): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # filt param no dict filt_params_signal = freqs_sig filt_params_noise = freqs_noise + ssd = SSD(info, filt_params_signal, filt_params_noise) with pytest.raises(ValueError, match="must be defined"): - ssd = SSD(info, filt_params_signal, filt_params_noise) + ssd.fit(X) # Data type filt_params_signal = dict( @@ -140,15 +144,18 @@ def test_ssd(): ssd = SSD(info, filt_params_signal, filt_params_noise) raw = io.RawArray(X, info) - pytest.raises(TypeError, ssd.fit, raw) + with pytest.raises(ValueError): + ssd.fit(raw) # check non-boolean return_filtered - with pytest.raises(ValueError, match="return_filtered"): - ssd = SSD(info, filt_params_signal, filt_params_noise, return_filtered=0) + ssd = SSD(info, filt_params_signal, filt_params_noise, return_filtered=0) + with pytest.raises(TypeError, match="return_filtered"): + ssd.fit(X) # check non-boolean sort_by_spectral_ratio - with pytest.raises(ValueError, match="sort_by_spectral_ratio"): - ssd = SSD(info, filt_params_signal, filt_params_noise, sort_by_spectral_ratio=0) + ssd = SSD(info, filt_params_signal, filt_params_noise, sort_by_spectral_ratio=0) + with pytest.raises(TypeError, match="sort_by_spectral_ratio"): + ssd.fit(X) # More than 1 channel type ch_types = np.reshape([["mag"] * 10, ["eeg"] * 10], n_channels) @@ -161,7 +168,8 @@ def test_ssd(): # Number of channels info_3 = create_info(ch_names=n_channels + 1, sfreq=sf, ch_types="eeg") ssd = SSD(info_3, filt_params_signal, filt_params_noise) - pytest.raises(ValueError, ssd.fit, X) + with pytest.raises(ValueError, match="channels but expected"): + ssd.fit(X) # Fit n_components = 10 @@ -381,7 +389,7 @@ def test_sorting(): ssd.fit(Xtr) # check sorters - sorter_in = ssd.sorter_spec + sorter_in = ssd.sorter_spec_ ssd = SSD( info, filt_params_signal, @@ -476,3 
+484,29 @@ def test_non_full_rank_data(): if sys.platform == "darwin": pytest.xfail("Unknown linalg bug (Accelerate?)") ssd.fit(X) + + +@pytest.mark.filterwarnings("ignore:.*invalid value encountered in divide.*") +@pytest.mark.filterwarnings("ignore:.*is longer than.*") +@parametrize_with_checks( + [ + SSD( + 100.0, + dict(l_freq=0.0, h_freq=30.0), + dict(l_freq=0.0, h_freq=40.0), + ) + ] +) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + ignores = ( + "check_methods_sample_order_invariance", + # Shape stuff + "check_fit_idempotent", + "check_methods_subset_invariance", + "check_transformer_general", + "check_transformer_data_not_an_array", + ) + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) diff --git a/mne/decoding/tests/test_time_frequency.py b/mne/decoding/tests/test_time_frequency.py index 37e7d7d8dc2..638cebda21e 100644 --- a/mne/decoding/tests/test_time_frequency.py +++ b/mne/decoding/tests/test_time_frequency.py @@ -10,18 +10,23 @@ pytest.importorskip("sklearn") from sklearn.base import clone +from sklearn.utils.estimator_checks import parametrize_with_checks from mne.decoding.time_frequency import TimeFrequency -def test_timefrequency(): +def test_timefrequency_basic(): """Test TimeFrequency.""" # Init n_freqs = 3 freqs = [20, 21, 22] tf = TimeFrequency(freqs, sfreq=100) + n_epochs, n_chans, n_times = 10, 2, 100 + X = np.random.rand(n_epochs, n_chans, n_times) for output in ["avg_power", "foo", None]: - pytest.raises(ValueError, TimeFrequency, freqs, output=output) + tf = TimeFrequency(freqs, output=output) + with pytest.raises(ValueError, match="Invalid value"): + tf.fit(X) tf = clone(tf) # Clone estimator @@ -30,9 +35,9 @@ def test_timefrequency(): clone(tf) # Fit - n_epochs, n_chans, n_times = 10, 2, 100 - X = np.random.rand(n_epochs, n_chans, n_times) + assert not hasattr(tf, "fitted_") tf.fit(X, None) + assert tf.fitted_ # Transform tf = TimeFrequency(freqs, sfreq=100) @@ -41,9 +46,15 @@ def test_timefrequency(): Xt = tf.transform(X) assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times]) # 2-D X - Xt = tf.transform(X[:, 0, :]) + Xt = tf.fit_transform(X[:, 0, :]) assert_array_equal(Xt.shape, [n_epochs, n_freqs, n_times]) # 3-D with decim tf = TimeFrequency(freqs, sfreq=100, decim=2) - Xt = tf.transform(X) + Xt = tf.fit_transform(X) assert_array_equal(Xt.shape, [n_epochs, n_chans, n_freqs, n_times // 2]) + + +@parametrize_with_checks([TimeFrequency([300, 400], 1000.0, n_cycles=0.25)]) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + check(estimator) diff --git a/mne/decoding/tests/test_transformer.py b/mne/decoding/tests/test_transformer.py index 8dcc3ad74c7..a8afe209d96 100644 --- a/mne/decoding/tests/test_transformer.py +++ b/mne/decoding/tests/test_transformer.py @@ -17,10 +17,14 @@ from sklearn.decomposition import PCA from sklearn.kernel_ridge import KernelRidge +from sklearn.pipeline import make_pipeline +from sklearn.preprocessing import StandardScaler +from sklearn.utils.estimator_checks import parametrize_with_checks -from mne import Epochs, io, pick_types, read_events +from mne import Epochs, EpochsArray, create_info, io, pick_types, read_events from mne.decoding import ( FilterEstimator, + LinearModel, PSDEstimator, Scaler, TemporalFilter, @@ -36,6 +40,7 @@ data_dir = Path(__file__).parents[2] / "io" / "tests" / "data" raw_fname = data_dir / "test_raw.fif" event_name = data_dir / "test-eve.fif" +info = create_info(2, 
1000.0, "eeg") @pytest.mark.parametrize( @@ -101,9 +106,11 @@ def test_scaler(info, method): assert_array_almost_equal(epochs_data, Xi) # Test init exception - pytest.raises(ValueError, Scaler, None, None) - pytest.raises(TypeError, scaler.fit, epochs, y) - pytest.raises(TypeError, scaler.transform, epochs) + x = Scaler(None, None) + with pytest.raises(ValueError): + x.fit(epochs_data, y) + pytest.raises(ValueError, scaler.fit, "foo", y) + pytest.raises(ValueError, scaler.transform, "foo") epochs_bad = Epochs( raw, events, @@ -164,8 +171,8 @@ def test_filterestimator(): X = filt.fit_transform(epochs_data, y) # Test init exception - pytest.raises(ValueError, filt.fit, epochs, y) - pytest.raises(ValueError, filt.transform, epochs) + pytest.raises(ValueError, filt.fit, "foo", y) + pytest.raises(ValueError, filt.transform, "foo") def test_psdestimator(): @@ -182,14 +189,18 @@ def test_psdestimator(): epochs_data = epochs.get_data(copy=False) psd = PSDEstimator(2 * np.pi, 0, np.inf) y = epochs.events[:, -1] + assert not hasattr(psd, "fitted_") X = psd.fit_transform(epochs_data, y) + assert psd.fitted_ assert X.shape[0] == epochs_data.shape[0] assert_array_equal(psd.fit(epochs_data, y).transform(epochs_data), X) # Test init exception - pytest.raises(ValueError, psd.fit, epochs, y) - pytest.raises(ValueError, psd.transform, epochs) + with pytest.raises(ValueError): + psd.fit("foo", y) + with pytest.raises(ValueError): + psd.transform("foo") def test_vectorizer(): @@ -210,9 +221,16 @@ def test_vectorizer(): assert_equal(vect.fit_transform(data[1:]).shape, (149, 108)) # check if raised errors are working correctly - vect.fit(np.random.rand(105, 12, 3)) - pytest.raises(ValueError, vect.transform, np.random.rand(105, 12, 3, 1)) - pytest.raises(ValueError, vect.inverse_transform, np.random.rand(102, 12, 12)) + X = np.random.default_rng(0).standard_normal((105, 12, 3)) + y = np.arange(X.shape[0]) % 2 + pytest.raises(ValueError, vect.transform, X[..., np.newaxis]) + pytest.raises(ValueError, vect.inverse_transform, X[:, :-1]) + + # And that pipelines work properly + X_arr = EpochsArray(X, create_info(12, 1000.0, "eeg")) + vect.fit(X_arr) + clf = make_pipeline(Vectorizer(), StandardScaler(), LinearModel()) + clf.fit(X_arr, y) def test_unsupervised_spatial_filter(): @@ -235,11 +253,13 @@ def test_unsupervised_spatial_filter(): verbose=False, ) - # Test estimator - pytest.raises(ValueError, UnsupervisedSpatialFilter, KernelRidge(2)) + # Test estimator (must be a transformer) + X = epochs.get_data(copy=False) + usf = UnsupervisedSpatialFilter(KernelRidge(2)) + with pytest.raises(ValueError, match="transform"): + usf.fit(X) # Test fit - X = epochs.get_data(copy=False) n_components = 4 usf = UnsupervisedSpatialFilter(PCA(n_components)) usf.fit(X) @@ -255,7 +275,9 @@ def test_unsupervised_spatial_filter(): # Test with average param usf = UnsupervisedSpatialFilter(PCA(4), average=True) usf.fit_transform(X) - pytest.raises(ValueError, UnsupervisedSpatialFilter, PCA(4), 2) + usf = UnsupervisedSpatialFilter(PCA(4), 2) + with pytest.raises(TypeError, match="average must be"): + usf.fit(X) def test_temporal_filter(): @@ -281,8 +303,8 @@ def test_temporal_filter(): assert X.shape == Xt.shape # Test fit and transform numpy type check - with pytest.raises(ValueError, match="Data to be filtered must be"): - filt.transform([1, 2]) + with pytest.raises(ValueError): + filt.transform("foo") # Test with 2 dimensional data array X = np.random.rand(101, 500) @@ -298,4 +320,36 @@ def test_bad_triage(): filt = 
TemporalFilter(l_freq=8, h_freq=60, sfreq=160.0) # Used to fail with "ValueError: Effective band-stop frequency (135.0) is # too high (maximum based on Nyquist is 80.0)" + assert not hasattr(filt, "fitted_") filt.fit_transform(np.zeros((1, 1, 481))) + assert filt.fitted_ + + +@pytest.mark.filterwarnings("ignore:.*filter_length.*") +@parametrize_with_checks( + [ + FilterEstimator(info, l_freq=1, h_freq=10), + PSDEstimator(), + Scaler(scalings="mean"), + # Not easy to test Scaler(info) b/c number of channels must match + TemporalFilter(), + UnsupervisedSpatialFilter(PCA()), + Vectorizer(), + ] +) +def test_sklearn_compliance(estimator, check): + """Test LinearModel compliance with sklearn.""" + ignores = [] + if estimator.__class__.__name__ == "FilterEstimator": + ignores += [ + "check_estimators_overwrite_params", # we modify self.info + "check_methods_sample_order_invariance", + ] + if estimator.__class__.__name__.startswith(("PSD", "Temporal")): + ignores += [ + "check_transformers_unfitted", # allow unfitted transform + "check_methods_sample_order_invariance", + ] + if any(ignore in str(check) for ignore in ignores): + return + check(estimator) diff --git a/mne/decoding/time_frequency.py b/mne/decoding/time_frequency.py index de6ec52155b..29232aaeb9f 100644 --- a/mne/decoding/time_frequency.py +++ b/mne/decoding/time_frequency.py @@ -3,14 +3,16 @@ # Copyright the MNE-Python contributors. import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator +from sklearn.utils.validation import check_is_fitted from ..time_frequency.tfr import _compute_tfr -from ..utils import _check_option, fill_doc, verbose +from ..utils import _check_option, fill_doc +from .transformer import MNETransformerMixin @fill_doc -class TimeFrequency(TransformerMixin, BaseEstimator): +class TimeFrequency(MNETransformerMixin, BaseEstimator): """Time frequency transformer. Time-frequency transform of times series along the last axis. @@ -59,7 +61,6 @@ class TimeFrequency(TransformerMixin, BaseEstimator): mne.time_frequency.tfr_multitaper """ - @verbose def __init__( self, freqs, @@ -74,9 +75,6 @@ def __init__( verbose=None, ): """Init TimeFrequency transformer.""" - # Check non-average output - output = _check_option("output", output, ["complex", "power", "phase"]) - self.freqs = freqs self.sfreq = sfreq self.method = method @@ -89,6 +87,16 @@ def __init__( self.n_jobs = n_jobs self.verbose = verbose + def __sklearn_tags__(self): + """Return sklearn tags.""" + out = super().__sklearn_tags__() + from sklearn.utils import TransformerTags + + if out.transformer_tags is None: + out.transformer_tags = TransformerTags() + out.transformer_tags.preserves_dtype = [] # real->complex + return out + def fit_transform(self, X, y=None): """Time-frequency transform of times series along the last axis. @@ -123,6 +131,10 @@ def fit(self, X, y=None): # noqa: D401 self : object Return self. """ + # Check non-average output + _check_option("output", self.output, ["complex", "power", "phase"]) + self._check_data(X, y=y, fit=True) + self.fitted_ = True return self def transform(self, X): @@ -130,16 +142,18 @@ def transform(self, X): Parameters ---------- - X : array, shape (n_samples, n_channels, n_times) + X : array, shape (n_samples, [n_channels, ]n_times) The training data samples. The channel dimension can be zero- or 1-dimensional. 
Returns ------- - Xt : array, shape (n_samples, n_channels, n_freqs, n_times) + Xt : array, shape (n_samples, [n_channels, ]n_freqs, n_times) The time-frequency transform of the data, where n_channels can be zero- or 1-dimensional. """ + X = self._check_data(X, atleast_3d=False) + check_is_fitted(self, "fitted_") # Ensure 3-dimensional X shape = X.shape[1:-1] if not shape: diff --git a/mne/decoding/transformer.py b/mne/decoding/transformer.py index e475cd22161..6d0c83f42ab 100644 --- a/mne/decoding/transformer.py +++ b/mne/decoding/transformer.py @@ -3,19 +3,72 @@ # Copyright the MNE-Python contributors. import numpy as np -from sklearn.base import BaseEstimator, TransformerMixin +from sklearn.base import BaseEstimator, TransformerMixin, check_array, clone +from sklearn.preprocessing import RobustScaler, StandardScaler +from sklearn.utils import check_X_y +from sklearn.utils.validation import check_is_fitted, validate_data from .._fiff.pick import ( _pick_data_channels, _picks_by_type, _picks_to_idx, pick_info, - pick_types, ) from ..cov import _check_scalings_user +from ..epochs import BaseEpochs from ..filter import filter_data from ..time_frequency import psd_array_multitaper -from ..utils import _check_option, _validate_type, fill_doc, verbose +from ..utils import _check_option, _validate_type, fill_doc + + +class MNETransformerMixin(TransformerMixin): + """TransformerMixin plus some helpers.""" + + def _check_data( + self, + epochs_data, + *, + y=None, + atleast_3d=True, + fit=False, + return_y=False, + multi_output=False, + check_n_features=True, + ): + # Sklearn calls asarray under the hood which works, but elsewhere they check for + # __len__ then look at the size of obj[0]... which is an epoch of shape (1, ...) + # rather than what they expect (shape (...)). So we explicitly get the NumPy + # array to make everyone happy. + if isinstance(epochs_data, BaseEpochs): + epochs_data = epochs_data.get_data(copy=False) + kwargs = dict(dtype=np.float64, allow_nd=True, order="C", force_writeable=True) + if hasattr(self, "n_features_in_") and check_n_features: + if y is None: + epochs_data = validate_data( + self, + epochs_data, + **kwargs, + reset=fit, + ) + else: + epochs_data, y = validate_data( + self, + epochs_data, + y, + **kwargs, + reset=fit, + ) + elif y is None: + epochs_data = check_array(epochs_data, **kwargs) + else: + epochs_data, y = check_X_y( + X=epochs_data, y=y, multi_output=multi_output, **kwargs + ) + if fit: + self.n_features_in_ = epochs_data.shape[1] + if atleast_3d: + epochs_data = np.atleast_3d(epochs_data) + return (epochs_data, y) if return_y else epochs_data class _ConstantScaler: @@ -55,8 +108,9 @@ def fit_transform(self, X, y=None): def _sklearn_reshape_apply(func, return_result, X, *args, **kwargs): """Reshape epochs and apply function.""" - if not isinstance(X, np.ndarray): - raise ValueError(f"data should be an np.ndarray, got {type(X)}.") + _validate_type(X, np.ndarray, "X") + if X.size == 0: + return X.copy() if return_result else None orig_shape = X.shape X = np.reshape(X.transpose(0, 2, 1), (-1, orig_shape[1])) X = func(X, *args, **kwargs) @@ -67,7 +121,7 @@ def _sklearn_reshape_apply(func, return_result, X, *args, **kwargs): @fill_doc -class Scaler(TransformerMixin, BaseEstimator): +class Scaler(MNETransformerMixin, BaseEstimator): """Standardize channel data. This class scales data for each channel. 
It differs from scikit-learn @@ -109,31 +163,6 @@ def __init__(self, info=None, scalings=None, with_mean=True, with_std=True): self.with_std = with_std self.scalings = scalings - if not (scalings is None or isinstance(scalings, dict | str)): - raise ValueError( - f"scalings type should be dict, str, or None, got {type(scalings)}" - ) - if isinstance(scalings, str): - _check_option("scalings", scalings, ["mean", "median"]) - if scalings is None or isinstance(scalings, dict): - if info is None: - raise ValueError( - f'Need to specify "info" if scalings is {type(scalings)}' - ) - self._scaler = _ConstantScaler(info, scalings, self.with_std) - elif scalings == "mean": - from sklearn.preprocessing import StandardScaler - - self._scaler = StandardScaler( - with_mean=self.with_mean, with_std=self.with_std - ) - else: # scalings == 'median': - from sklearn.preprocessing import RobustScaler - - self._scaler = RobustScaler( - with_centering=self.with_mean, with_scaling=self.with_std - ) - def fit(self, epochs_data, y=None): """Standardize data across channels. @@ -149,11 +178,30 @@ def fit(self, epochs_data, y=None): self : instance of Scaler The modified instance. """ - _validate_type(epochs_data, np.ndarray, "epochs_data") - if epochs_data.ndim == 2: - epochs_data = epochs_data[..., np.newaxis] + epochs_data = self._check_data(epochs_data, y=y, fit=True, multi_output=True) assert epochs_data.ndim == 3, epochs_data.shape - _sklearn_reshape_apply(self._scaler.fit, False, epochs_data, y=y) + + _validate_type(self.scalings, (dict, str, type(None)), "scalings") + if isinstance(self.scalings, str): + _check_option( + "scalings", self.scalings, ["mean", "median"], extra="when str" + ) + if self.scalings is None or isinstance(self.scalings, dict): + if self.info is None: + raise ValueError( + f'Need to specify "info" if scalings is {type(self.scalings)}' + ) + self.scaler_ = _ConstantScaler(self.info, self.scalings, self.with_std) + elif self.scalings == "mean": + self.scaler_ = StandardScaler( + with_mean=self.with_mean, with_std=self.with_std + ) + else: # scalings == 'median': + self.scaler_ = RobustScaler( + with_centering=self.with_mean, with_scaling=self.with_std + ) + + _sklearn_reshape_apply(self.scaler_.fit, False, epochs_data, y=y) return self def transform(self, epochs_data): @@ -174,13 +222,14 @@ def transform(self, epochs_data): This function makes a copy of the data before the operations and the memory usage may be large with big data. """ - _validate_type(epochs_data, np.ndarray, "epochs_data") + check_is_fitted(self, "scaler_") + epochs_data = self._check_data(epochs_data, atleast_3d=False) if epochs_data.ndim == 2: # can happen with SlidingEstimator if self.info is not None: assert len(self.info["ch_names"]) == epochs_data.shape[1] epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape - return _sklearn_reshape_apply(self._scaler.transform, True, epochs_data) + return _sklearn_reshape_apply(self.scaler_.transform, True, epochs_data) def fit_transform(self, epochs_data, y=None): """Fit to data, then transform it. @@ -226,19 +275,20 @@ def inverse_transform(self, epochs_data): This function makes a copy of the data before the operations and the memory usage may be large with big data. 
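A note on the ``Scaler`` changes above: both ``fit`` and ``transform`` funnel through ``_sklearn_reshape_apply`` because the wrapped scikit-learn scalers only accept 2D input. The fold treats every time point of every epoch as one sample row, with channels as feature columns; a plain-NumPy sketch of that round trip (no MNE assumptions):

```python
import numpy as np

X = np.random.default_rng(0).normal(size=(5, 3, 100))  # epochs, channels, times
n_epochs, n_channels, n_times = X.shape
X2d = X.transpose(0, 2, 1).reshape(-1, n_channels)  # (500, 3): rows are samples
X_back = X2d.reshape(n_epochs, n_times, n_channels).transpose(0, 2, 1)
assert np.array_equal(X, X_back)  # lossless round trip
```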
""" + epochs_data = self._check_data(epochs_data, atleast_3d=False) squeeze = False # Can happen with CSP if epochs_data.ndim == 2: squeeze = True epochs_data = epochs_data[..., np.newaxis] assert epochs_data.ndim == 3, epochs_data.shape - out = _sklearn_reshape_apply(self._scaler.inverse_transform, True, epochs_data) + out = _sklearn_reshape_apply(self.scaler_.inverse_transform, True, epochs_data) if squeeze: out = out[..., 0] return out -class Vectorizer(TransformerMixin, BaseEstimator): +class Vectorizer(MNETransformerMixin, BaseEstimator): """Transform n-dimensional array into 2D array of n_samples by n_features. This class reshapes an n-dimensional array into an n_samples * n_features @@ -275,7 +325,7 @@ def fit(self, X, y=None): self : instance of Vectorizer Return the modified instance. """ - X = np.asarray(X) + X = self._check_data(X, y=y, atleast_3d=False, fit=True, check_n_features=False) self.features_shape_ = X.shape[1:] return self @@ -295,7 +345,7 @@ def transform(self, X): X : array, shape (n_samples, n_features) The transformed data. """ - X = np.asarray(X) + X = self._check_data(X, atleast_3d=False) if X.shape[1:] != self.features_shape_: raise ValueError("Shape of X used in fit and transform must be same") return X.reshape(len(X), -1) @@ -334,7 +384,7 @@ def inverse_transform(self, X): The data transformed into shape as used in fit. The first dimension is of length n_samples. """ - X = np.asarray(X) + X = self._check_data(X, atleast_3d=False, check_n_features=False) if X.ndim not in (2, 3): raise ValueError( f"X should be of 2 or 3 dimensions but has shape {X.shape}" @@ -343,7 +393,7 @@ def inverse_transform(self, X): @fill_doc -class PSDEstimator(TransformerMixin, BaseEstimator): +class PSDEstimator(MNETransformerMixin, BaseEstimator): """Compute power spectral density (PSD) using a multi-taper method. Parameters @@ -365,7 +415,6 @@ class PSDEstimator(TransformerMixin, BaseEstimator): n_jobs : int Number of parallel jobs to use (only used if adaptive=True). %(normalization)s - %(verbose)s See Also -------- @@ -375,7 +424,6 @@ class PSDEstimator(TransformerMixin, BaseEstimator): mne.Evoked.compute_psd """ - @verbose def __init__( self, sfreq=2 * np.pi, @@ -386,8 +434,6 @@ def __init__( low_bias=True, n_jobs=None, normalization="length", - *, - verbose=None, ): self.sfreq = sfreq self.fmin = fmin @@ -398,7 +444,7 @@ def __init__( self.n_jobs = n_jobs self.normalization = normalization - def fit(self, epochs_data, y): + def fit(self, epochs_data, y=None): """Compute power spectral density (PSD) using a multi-taper method. Parameters @@ -413,11 +459,8 @@ def fit(self, epochs_data, y): self : instance of PSDEstimator The modified instance. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - + self._check_data(epochs_data, y=y, fit=True) + self.fitted_ = True # sklearn compliance return self def transform(self, epochs_data): @@ -433,10 +476,7 @@ def transform(self, epochs_data): psd : array, shape (n_signals, n_freqs) or (n_freqs,) The computed PSD. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." 
- ) + epochs_data = self._check_data(epochs_data) psd, _ = psd_array_multitaper( epochs_data, sfreq=self.sfreq, @@ -452,7 +492,7 @@ def transform(self, epochs_data): @fill_doc -class FilterEstimator(TransformerMixin, BaseEstimator): +class FilterEstimator(MNETransformerMixin, BaseEstimator): """Estimator to filter RtEpochs. Applies a zero-phase low-pass, high-pass, band-pass, or band-stop @@ -488,7 +528,6 @@ class FilterEstimator(TransformerMixin, BaseEstimator): See mne.filter.construct_iir_filter for details. If iir_params is None and method="iir", 4th order Butterworth will be used. %(fir_design)s - %(verbose)s See Also -------- @@ -514,13 +553,11 @@ def __init__( method="fir", iir_params=None, fir_design="firwin", - *, - verbose=None, ): self.info = info self.l_freq = l_freq self.h_freq = h_freq - self.picks = _picks_to_idx(info, picks) + self.picks = picks self.filter_length = filter_length self.l_trans_bandwidth = l_trans_bandwidth self.h_trans_bandwidth = h_trans_bandwidth @@ -544,24 +581,11 @@ def fit(self, epochs_data, y): self : instance of FilterEstimator The modified instance. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - - if self.picks is None: - self.picks = pick_types( - self.info, meg=True, eeg=True, ref_meg=False, exclude=[] - ) + self.picks_ = _picks_to_idx(self.info, self.picks) + self._check_data(epochs_data, y=y, fit=True) if self.l_freq == 0: self.l_freq = None - if self.h_freq is not None and self.h_freq > (self.info["sfreq"] / 2.0): - self.h_freq = None - if self.l_freq is not None and not isinstance(self.l_freq, float): - self.l_freq = float(self.l_freq) - if self.h_freq is not None and not isinstance(self.h_freq, float): - self.h_freq = float(self.h_freq) if self.info["lowpass"] is None or ( self.h_freq is not None @@ -594,17 +618,12 @@ def transform(self, epochs_data): X : array, shape (n_epochs, n_channels, n_times) The data after filtering. """ - if not isinstance(epochs_data, np.ndarray): - raise ValueError( - f"epochs_data should be of type ndarray (got {type(epochs_data)})." - ) - epochs_data = np.atleast_3d(epochs_data) return filter_data( - epochs_data, + self._check_data(epochs_data), self.info["sfreq"], self.l_freq, self.h_freq, - self.picks, + self.picks_, self.filter_length, self.l_trans_bandwidth, self.h_trans_bandwidth, @@ -617,7 +636,7 @@ def transform(self, epochs_data): ) -class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator): +class UnsupervisedSpatialFilter(MNETransformerMixin, BaseEstimator): """Use unsupervised spatial filtering across time and samples. Parameters @@ -630,19 +649,6 @@ class UnsupervisedSpatialFilter(TransformerMixin, BaseEstimator): """ def __init__(self, estimator, average=False): - # XXX: Use _check_estimator #3381 - for attr in ("fit", "transform", "fit_transform"): - if not hasattr(estimator, attr): - raise ValueError( - "estimator must be a scikit-learn " - f"transformer, missing {attr} method" - ) - - if not isinstance(average, bool): - raise ValueError( - f"average parameter must be of bool type, got {type(bool)} instead" - ) - self.estimator = estimator self.average = average @@ -661,13 +667,25 @@ def fit(self, X, y=None): self : instance of UnsupervisedSpatialFilter Return the modified instance. 
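The ``picks`` change in ``FilterEstimator.__init__`` above follows the scikit-learn convention that drives much of this diff: ``__init__`` stores parameters verbatim, and anything derived during fitting gets a trailing underscore (``picks_``), so ``clone()`` can rebuild an unfitted copy from ``get_params()``. A sketch with a hypothetical estimator:

```python
import numpy as np
from sklearn.base import BaseEstimator, clone


class PickyEstimator(BaseEstimator):
    """Hypothetical estimator showing the params-vs-fitted-attrs split."""

    def __init__(self, picks=None):
        self.picks = picks  # stored verbatim, never resolved here

    def fit(self, X, y=None):
        # Resolve only at fit time, under a trailing underscore.
        self.picks_ = list(range(X.shape[1])) if self.picks is None else self.picks
        return self


est = PickyEstimator().fit(np.zeros((3, 5)))
assert est.picks is None and est.picks_ == [0, 1, 2, 3, 4]
assert not hasattr(clone(est), "picks_")  # clone() drops fitted state
```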
""" + # sklearn.utils.estimator_checks.check_estimator(self.estimator) is probably + # too strict for us, given that we don't fully adhere yet, so just check attrs + for attr in ("fit", "transform", "fit_transform"): + if not hasattr(self.estimator, attr): + raise ValueError( + "estimator must be a scikit-learn " + f"transformer, missing {attr} method" + ) + _validate_type(self.average, bool, "average") + X = self._check_data(X, y=y, fit=True) if self.average: X = np.mean(X, axis=0).T else: n_epochs, n_channels, n_times = X.shape # trial as time samples X = np.transpose(X, (1, 0, 2)).reshape((n_channels, n_epochs * n_times)).T - self.estimator.fit(X) + + self.estimator_ = clone(self.estimator) + self.estimator_.fit(X) return self def fit_transform(self, X, y=None): @@ -700,6 +718,8 @@ def transform(self, X): X : array, shape (n_epochs, n_channels, n_times) The transformed data. """ + check_is_fitted(self.estimator_) + X = self._check_data(X) return self._apply_method(X, "transform") def inverse_transform(self, X): @@ -735,7 +755,7 @@ def _apply_method(self, X, method): X = np.transpose(X, [1, 0, 2]) X = np.reshape(X, [n_channels, n_epochs * n_times]).T # apply method - method = getattr(self.estimator, method) + method = getattr(self.estimator_, method) X = method(X) # put it back to n_epochs, n_dimensions X = np.reshape(X.T, [-1, n_epochs, n_times]).transpose([1, 0, 2]) @@ -743,7 +763,7 @@ def _apply_method(self, X, method): @fill_doc -class TemporalFilter(TransformerMixin, BaseEstimator): +class TemporalFilter(MNETransformerMixin, BaseEstimator): """Estimator to filter data array along the last dimension. Applies a zero-phase low-pass, high-pass, band-pass, or band-stop @@ -817,7 +837,6 @@ class TemporalFilter(TransformerMixin, BaseEstimator): attenuation using fewer samples than "firwin2". .. versionadded:: 0.15 - %(verbose)s See Also -------- @@ -826,7 +845,6 @@ class TemporalFilter(TransformerMixin, BaseEstimator): mne.filter.filter_data """ - @verbose def __init__( self, l_freq=None, @@ -840,8 +858,6 @@ def __init__( iir_params=None, fir_window="hamming", fir_design="firwin", - *, - verbose=None, ): self.l_freq = l_freq self.h_freq = h_freq @@ -855,17 +871,12 @@ def __init__( self.fir_window = fir_window self.fir_design = fir_design - if not isinstance(self.n_jobs, int) and self.n_jobs == "cuda": - raise ValueError( - f'n_jobs must be int or "cuda", got {type(self.n_jobs)} instead.' - ) - def fit(self, X, y=None): """Do nothing (for scikit-learn compatibility purposes). Parameters ---------- - X : array, shape (n_epochs, n_channels, n_times) or or shape (n_channels, n_times) + X : array, shape ([n_epochs, ]n_channels, n_times) The data to be filtered over the last dimension. The channels dimension can be zero when passing a 2D array. y : None @@ -875,7 +886,9 @@ def fit(self, X, y=None): ------- self : instance of TemporalFilter The modified instance. - """ # noqa: E501 + """ + self.fitted_ = True # sklearn compliance + self._check_data(X, y=y, atleast_3d=False, fit=True) return self def transform(self, X): @@ -883,7 +896,7 @@ def transform(self, X): Parameters ---------- - X : array, shape (n_epochs, n_channels, n_times) or shape (n_channels, n_times) + X : array, shape ([n_epochs, ]n_channels, n_times) The data to be filtered over the last dimension. The channels dimension can be zero when passing a 2D array. @@ -892,6 +905,7 @@ def transform(self, X): X : array The data after filtering. 
""" # noqa: E501 + X = self._check_data(X, atleast_3d=False) X = np.atleast_2d(X) if X.ndim > 3: diff --git a/mne/epochs.py b/mne/epochs.py index 679643ab969..ee8921d3990 100644 --- a/mne/epochs.py +++ b/mne/epochs.py @@ -1353,6 +1353,7 @@ def plot_topo_image( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): return plot_topo_image_epochs( @@ -1371,6 +1372,7 @@ def plot_topo_image( fig_facecolor=fig_facecolor, fig_background=fig_background, font_color=font_color, + select=select, show=show, ) diff --git a/mne/evoked.py b/mne/evoked.py index c04f83531e3..7bd2355e4ee 100644 --- a/mne/evoked.py +++ b/mne/evoked.py @@ -613,6 +613,7 @@ def plot_topo( background_color="w", noise_cov=None, exclude="bads", + select=False, show=True, ): """. @@ -639,6 +640,7 @@ def plot_topo( background_color=background_color, noise_cov=noise_cov, exclude=exclude, + select=select, show=show, ) diff --git a/mne/export/_brainvision.py b/mne/export/_brainvision.py index ba64ba010ce..6503c540f41 100644 --- a/mne/export/_brainvision.py +++ b/mne/export/_brainvision.py @@ -107,6 +107,13 @@ def _export_mne_raw(*, raw, fname, events=None, overwrite=False): def _mne_annots2pybv_events(raw): """Convert mne Annotations to pybv events.""" + # check that raw.annotations.orig_time is the same as raw.info["meas_date"] + # so that onsets are relative to the first sample + # (after further correction for first_time) + if raw.annotations and raw.info["meas_date"] != raw.annotations.orig_time: + raise ValueError( + "Annotations must have the same orig_time as raw.info['meas_date']" + ) events = [] for annot in raw.annotations: # handle onset and duration: seconds to sample, relative to diff --git a/mne/export/_edf.py b/mne/export/_edf.py index ef870692014..e50b05f7056 100644 --- a/mne/export/_edf.py +++ b/mne/export/_edf.py @@ -7,6 +7,7 @@ import numpy as np +from ..annotations import _sync_onset from ..utils import _check_edfio_installed, warn _check_edfio_installed() @@ -204,7 +205,9 @@ def _export_raw(fname, raw, physical_range, add_ch_type): for desc, onset, duration, ch_names in zip( raw.annotations.description, - raw.annotations.onset, + # subtract raw.first_time because EDF marks events starting from the first + # available data point and ignores raw.first_time + _sync_onset(raw, raw.annotations.onset, inverse=False), raw.annotations.duration, raw.annotations.ch_names, ): diff --git a/mne/export/_eeglab.py b/mne/export/_eeglab.py index 3c8f896164a..459207f0616 100644 --- a/mne/export/_eeglab.py +++ b/mne/export/_eeglab.py @@ -4,6 +4,7 @@ import numpy as np +from ..annotations import _sync_onset from ..utils import _check_eeglabio_installed _check_eeglabio_installed() @@ -24,11 +25,16 @@ def _export_raw(fname, raw): ch_names = [ch for ch in raw.ch_names if ch not in drop_chs] cart_coords = _get_als_coords_from_chs(raw.info["chs"], drop_chs) - annotations = [ - raw.annotations.description, - raw.annotations.onset, - raw.annotations.duration, - ] + if raw.annotations: + annotations = [ + raw.annotations.description, + # subtract raw.first_time because EEGLAB marks events starting from + # the first available data point and ignores raw.first_time + _sync_onset(raw, raw.annotations.onset, inverse=False), + raw.annotations.duration, + ] + else: + annotations = None eeglabio.raw.export_set( fname, data=raw.get_data(picks=ch_names), diff --git a/mne/export/_export.py b/mne/export/_export.py index 6e63064bf7c..4b93fda917e 100644 --- a/mne/export/_export.py +++ b/mne/export/_export.py @@ -25,6 +25,14 
@@ def export_raw( %(export_warning)s + .. warning:: + When exporting ``Raw`` with annotations, ``raw.info["meas_date"]`` must be the + same as ``raw.annotations.orig_time``. This guarantees that the annotations are + in the same reference frame as the samples. When + :attr:`Raw.first_time <mne.io.Raw.first_time>` is not zero (e.g., after + cropping), the onsets are automatically corrected so that they are always + relative to the first sample. + Parameters ---------- %(fname_export_params)s diff --git a/mne/export/tests/test_export.py b/mne/export/tests/test_export.py index 191e91b1eed..6f712923c7d 100644 --- a/mne/export/tests/test_export.py +++ b/mne/export/tests/test_export.py @@ -122,6 +122,49 @@ def test_export_raw_eeglab(tmp_path): raw.export(temp_fname, overwrite=True) + +@pytest.mark.parametrize("tmin", (0, 1, 5, 10)) +def test_export_raw_eeglab_annotations(tmp_path, tmin): + """Test annotations in the exported EEGLAB file. + + All annotations should be preserved and onset corrected. + """ + pytest.importorskip("eeglabio") + raw = read_raw_fif(fname_raw, preload=True) + raw.apply_proj() + annotations = Annotations( + onset=[0.01, 0.05, 0.90, 1.05], + duration=[0, 1, 0, 0], + description=["test1", "test2", "test3", "test4"], + ch_names=[["MEG 0113"], ["MEG 0113", "MEG 0132"], [], ["MEG 0143"]], + ) + raw.set_annotations(annotations) + raw.crop(tmin) + + # export + temp_fname = tmp_path / "test.set" + raw.export(temp_fname) + + # read in the file + with pytest.warns(RuntimeWarning, match="is above the 99th percentile"): + raw_read = read_raw_eeglab(temp_fname, preload=True, montage_units="m") + assert raw_read.first_time == 0 # export resets first_time + valid_annot = ( + raw.annotations.onset >= tmin + ) # only annotations in the cropped range get exported + + # compare annotations before and after export + assert_array_almost_equal( + raw.annotations.onset[valid_annot] - raw.first_time, + raw_read.annotations.onset, + ) + assert_array_equal( + raw.annotations.duration[valid_annot], raw_read.annotations.duration + ) + assert_array_equal( + raw.annotations.description[valid_annot], raw_read.annotations.description + ) + + def _create_raw_for_edf_tests(stim_channel_index=None): rng = np.random.RandomState(12345) ch_types = [ @@ -154,6 +197,7 @@ def test_double_export_edf(tmp_path): """Test exporting an EDF file multiple times.""" raw = _create_raw_for_edf_tests(stim_channel_index=2) raw.info.set_meas_date("2023-09-04 14:53:09.000") + raw.set_annotations(Annotations(onset=[1], duration=[0], description=["test"])) # include subject info and measurement date raw.info["subject_info"] = dict( @@ -258,8 +302,12 @@ def test_edf_padding(tmp_path, pad_width): @edfio_mark() -def test_export_edf_annotations(tmp_path): - """Test that exporting EDF preserves annotations.""" +@pytest.mark.parametrize("tmin", (0, 0.005, 0.03, 1)) +def test_export_edf_annotations(tmp_path, tmin): + """Test annotations in the exported EDF file. + + All annotations should be preserved and onset corrected.
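The arithmetic pinned down by these export tests is small but easy to get backwards: after ``raw.crop(tmin)``, annotation onsets stay in recording time, while EDF/EEGLAB number samples from the first exported one, hence the ``_sync_onset`` subtraction. In plain numbers:

```python
import numpy as np

first_time = 5.0  # raw.first_time after raw.crop(5)
onsets = np.array([5.2, 6.0, 9.5])  # raw.annotations.onset, in recording time
exported = onsets - first_time  # what the EDF/EEGLAB writer stores
assert np.allclose(exported, [0.2, 1.0, 4.5])
```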
+ """ raw = _create_raw_for_edf_tests() annotations = Annotations( onset=[0.01, 0.05, 0.90, 1.05], @@ -268,17 +316,44 @@ def test_export_edf_annotations(tmp_path): ch_names=[["0"], ["0", "1"], [], ["1"]], ) raw.set_annotations(annotations) + raw.crop(tmin) + assert raw.first_time == tmin + + if raw.n_times % raw.info["sfreq"] == 0: + expectation = nullcontext() + else: + expectation = pytest.warns( + RuntimeWarning, match="EDF format requires equal-length data blocks" + ) # export temp_fname = tmp_path / "test.edf" - raw.export(temp_fname) + with expectation: + raw.export(temp_fname) # read in the file raw_read = read_raw_edf(temp_fname, preload=True) - assert_array_equal(raw.annotations.onset, raw_read.annotations.onset) - assert_array_equal(raw.annotations.duration, raw_read.annotations.duration) - assert_array_equal(raw.annotations.description, raw_read.annotations.description) - assert_array_equal(raw.annotations.ch_names, raw_read.annotations.ch_names) + assert raw_read.first_time == 0 # exportation resets first_time + bad_annot = raw_read.annotations.description == "BAD_ACQ_SKIP" + if bad_annot.any(): + raw_read.annotations.delete(bad_annot) + valid_annot = ( + raw.annotations.onset >= tmin + ) # only annotations in the cropped range gets exported + + # compare annotations before and after export + assert_array_almost_equal( + raw.annotations.onset[valid_annot] - raw.first_time, raw_read.annotations.onset + ) + assert_array_equal( + raw.annotations.duration[valid_annot], raw_read.annotations.duration + ) + assert_array_equal( + raw.annotations.description[valid_annot], raw_read.annotations.description + ) + assert_array_equal( + raw.annotations.ch_names[valid_annot], raw_read.annotations.ch_names + ) @edfio_mark() diff --git a/mne/forward/tests/test_make_forward.py b/mne/forward/tests/test_make_forward.py index 37ec6e041b5..a357c5779c9 100644 --- a/mne/forward/tests/test_make_forward.py +++ b/mne/forward/tests/test_make_forward.py @@ -482,7 +482,7 @@ def test_make_forward_solution_openmeeg(n_layers): eeg_atol=100, meg_corr_tol=0.98, eeg_corr_tol=0.98, - meg_rdm_tol=0.1, + meg_rdm_tol=0.11, eeg_rdm_tol=0.2, ) diff --git a/mne/io/base.py b/mne/io/base.py index 280330367f7..b3052b80aff 100644 --- a/mne/io/base.py +++ b/mne/io/base.py @@ -1386,7 +1386,10 @@ def resample( sfreq = float(sfreq) o_sfreq = float(self.info["sfreq"]) if _check_resamp_noop(sfreq, o_sfreq): - return self + if events is not None: + return self, events.copy() + else: + return self # When no event object is supplied, some basic detection of dropped # events is performed to generate a warning. 
Finding events can fail diff --git a/mne/io/fiff/tests/test_raw_fiff.py b/mne/io/fiff/tests/test_raw_fiff.py index 1ae0cc52901..3ae49189161 100644 --- a/mne/io/fiff/tests/test_raw_fiff.py +++ b/mne/io/fiff/tests/test_raw_fiff.py @@ -23,6 +23,7 @@ concatenate_events, create_info, equalize_channels, + events_from_annotations, find_events, make_fixed_length_epochs, pick_channels, @@ -1318,6 +1319,15 @@ def test_crop(): assert raw.n_times - 1 == raw3.n_times + +@testing.requires_testing_data +def test_resample_with_events(): + """Test resampling raws with events.""" + raw = read_raw_fif(fif_fname) + raw.resample(250) # pretend raw is recorded at 250 Hz + events, _ = events_from_annotations(raw) + raw, events = raw.resample(250, events=events) + + @testing.requires_testing_data def test_resample_equiv(): """Test resample (with I/O and multiple files).""" diff --git a/mne/preprocessing/tests/test_fine_cal.py b/mne/preprocessing/tests/test_fine_cal.py index 45971620db5..8b45208e848 100644 --- a/mne/preprocessing/tests/test_fine_cal.py +++ b/mne/preprocessing/tests/test_fine_cal.py @@ -231,7 +231,7 @@ def test_fine_cal_systems(system, tmp_path): err_limit = 6000 n_ref = 28 corrs = (0.19, 0.41, 0.49) - sfs = [0.5, 0.7, 0.9, 1.5] + sfs = [0.5, 0.7, 0.9, 1.55] corr_tol = 0.55 elif system == "fil": raw = read_raw_fil(fil_fname, verbose="error") diff --git a/mne/time_frequency/multitaper.py b/mne/time_frequency/multitaper.py index 98705e838c2..1c1a3baf238 100644 --- a/mne/time_frequency/multitaper.py +++ b/mne/time_frequency/multitaper.py @@ -63,7 +63,14 @@ def dpss_windows(N, half_nbw, Kmax, *, sym=True, norm=None, low_bias=True): ---------- .. footbibliography:: """ - dpss, eigvals = sp_dpss(N, half_nbw, Kmax, sym=sym, norm=norm, return_ratios=True) + # TODO VERSION can be removed when SciPy 1.16 is the minimum, + # workaround for https://github.com/scipy/scipy/pull/22344 + if N <= 1: + dpss, eigvals = np.ones((1, 1)), np.ones(1) + else: + dpss, eigvals = sp_dpss( + N, half_nbw, Kmax, sym=sym, norm=norm, return_ratios=True + ) if low_bias: idx = eigvals > 0.9 if not idx.any(): diff --git a/mne/time_frequency/tfr.py b/mne/time_frequency/tfr.py index fc60802f61b..f4a01e87895 100644 --- a/mne/time_frequency/tfr.py +++ b/mne/time_frequency/tfr.py @@ -562,7 +562,8 @@ def _compute_tfr( if len(Ws[0][0]) > epoch_data.shape[2]: raise ValueError( "At least one of the wavelets is longer than the " - "signal. Use a longer signal or shorter wavelets." + f"signal ({len(Ws[0][0])} > {epoch_data.shape[2]} samples). " + "Use a longer signal or shorter wavelets."
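Context for the ``dpss_windows`` guard above: per the linked SciPy PR, degenerate window lengths can make ``sp_dpss`` fail, and for ``N == 1`` the only sensible taper is a single ``[1.0]``; ordinary lengths pass through unchanged. For example:

```python
import numpy as np
from scipy.signal.windows import dpss as sp_dpss

tapers, ratios = sp_dpss(16, 2.0, 3, return_ratios=True)  # fine for N > 1
assert tapers.shape == (3, 16) and ratios.shape == (3,)
assert np.all(ratios[:2] > 0.9)  # leading tapers are well concentrated
```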
) # Initialize output diff --git a/mne/utils/docs.py b/mne/utils/docs.py index 683704c4bc6..54cc6845e58 100644 --- a/mne/utils/docs.py +++ b/mne/utils/docs.py @@ -1494,19 +1494,22 @@ def _reflow_param_docstring(docstring, has_first_line=True, width=75): docdict["export_fmt_support_epochs"] = """\ Supported formats: - - EEGLAB (``.set``, uses :mod:`eeglabio`) + +- EEGLAB (``.set``, uses :mod:`eeglabio`) """ docdict["export_fmt_support_evoked"] = """\ Supported formats: - - MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) + +- MFF (``.mff``, uses :func:`mne.export.export_evokeds_mff`) """ docdict["export_fmt_support_raw"] = """\ Supported formats: - - BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) - - EEGLAB (``.set``, uses :mod:`eeglabio`) - - EDF (``.edf``, uses `edfio `_) + +- BrainVision (``.vhdr``, ``.vmrk``, ``.eeg``, uses `pybv `_) +- EEGLAB (``.set``, uses :mod:`eeglabio`) +- EDF (``.edf``, uses `edfio `_) """ # noqa: E501 docdict["export_warning"] = """\ diff --git a/mne/utils/numerics.py b/mne/utils/numerics.py index 5029e8fbeca..11ba0ecb487 100644 --- a/mne/utils/numerics.py +++ b/mne/utils/numerics.py @@ -35,7 +35,7 @@ check_random_state, ) from .docs import fill_doc -from .misc import _empty_hash +from .misc import _empty_hash, _pl def split_list(v, n, idx=False): @@ -479,7 +479,8 @@ def _time_mask( extra = "" if include_tmax else "when include_tmax=False " raise ValueError( f"No samples remain when using tmin={orig_tmin} and tmax={orig_tmax} " - f"{extra}(original time bounds are [{times[0]}, {times[-1]}])" + f"{extra}(original time bounds are [{times[0]}, {times[-1]}] containing " + f"{len(times)} sample{_pl(times)})" ) return mask diff --git a/mne/viz/_brain/tests/test_brain.py b/mne/viz/_brain/tests/test_brain.py index 5d092c21713..46406542b5c 100644 --- a/mne/viz/_brain/tests/test_brain.py +++ b/mne/viz/_brain/tests/test_brain.py @@ -773,9 +773,6 @@ def test_single_hemi(hemi, renderer_interactive_pyvistaqt, brain_gc): def test_brain_save_movie(tmp_path, renderer, brain_gc, interactive_state): """Test saving a movie of a Brain instance.""" imageio_ffmpeg = pytest.importorskip("imageio_ffmpeg") - # TODO: Figure out why this fails -- some imageio_ffmpeg error - if os.getenv("MNE_CI_KIND", "") == "conda" and platform.system() == "Linux": - pytest.skip("Test broken for unknown reason on conda linux") brain = _create_testing_brain( hemi="lh", time_viewer=False, cortex=["r", "b"] diff --git a/mne/viz/_figure.py b/mne/viz/_figure.py index b63d2a395e2..090c661f633 100644 --- a/mne/viz/_figure.py +++ b/mne/viz/_figure.py @@ -424,7 +424,7 @@ def _redraw(self, update_data=True, annotations=False): if annotations and not self.mne.is_epochs: self._draw_annotations() - def _close(self, event): + def _close(self, event=None): """Handle close events (via keypress or window [x]).""" from matplotlib.pyplot import close @@ -500,11 +500,11 @@ def _create_ch_location_fig(self, pick): show=False, ) # highlight desired channel & disable interactivity - inds = np.isin(fig.lasso.ch_names, [ch_name]) + fig.lasso.selection_inds = np.isin(fig.lasso.names, [ch_name]) fig.lasso.disconnect() - fig.lasso.alpha_other = 0.3 + fig.lasso.alpha_nonselected = 0.3 fig.lasso.linewidth_selected = 3 - fig.lasso.style_sensors(inds) + fig.lasso.style_objects() return fig diff --git a/mne/viz/_mpl_figure.py b/mne/viz/_mpl_figure.py index 2e552bd4012..f3563b454f0 100644 --- a/mne/viz/_mpl_figure.py +++ b/mne/viz/_mpl_figure.py @@ -186,7 +186,7 @@ def _inch_to_rel(self, dim_inches, horiz=True): 
class MNEAnnotationFigure(MNEFigure): """Interactive dialog figure for annotations.""" - def _close(self, event): + def _close(self, event=None): """Handle close events (via keypress or window [x]).""" parent = self.mne.parent_fig # disable span selector @@ -275,7 +275,7 @@ def _set_active_button(self, idx, *, draw=True): class MNESelectionFigure(MNEFigure): """Interactive dialog figure for channel selections.""" - def _close(self, event): + def _close(self, event=None): """Handle close events.""" self.mne.parent_fig.mne.child_figs.remove(self) self.mne.fig_selection = None @@ -1536,7 +1536,7 @@ def _update_selection(self): def _update_highlighted_sensors(self): """Update the sensor plot to show what is selected.""" inds = np.isin( - self.mne.fig_selection.lasso.ch_names, self.mne.ch_names[self.mne.picks] + self.mne.fig_selection.lasso.names, self.mne.ch_names[self.mne.picks] ).nonzero()[0] self.mne.fig_selection.lasso.select_many(inds) diff --git a/mne/viz/backends/_pyvista.py b/mne/viz/backends/_pyvista.py index 8fe4b4bf1d8..0bd1ae1d3ca 100644 --- a/mne/viz/backends/_pyvista.py +++ b/mne/viz/backends/_pyvista.py @@ -1331,7 +1331,6 @@ def _is_osmesa(plotter): ) gpu_info = " ".join(gpu_info).lower() is_osmesa = "mesa" in gpu_info.split() - print(is_osmesa) if is_osmesa: # Try to warn if it's ancient version = re.findall("mesa ([0-9.]+)[ -].*", gpu_info) or re.findall( @@ -1345,7 +1344,7 @@ def _is_osmesa(plotter): "surface rendering, consider upgrading to 18.3.6 or " "later." ) - is_osmesa = "via llvmpipe" in gpu_info + is_osmesa = "llvmpipe" in gpu_info return is_osmesa diff --git a/mne/viz/evoked.py b/mne/viz/evoked.py index b047de4ea32..96ee0684e6e 100644 --- a/mne/viz/evoked.py +++ b/mne/viz/evoked.py @@ -1153,6 +1153,7 @@ def plot_evoked_topo( background_color="w", noise_cov=None, exclude="bads", + select=False, show=True, ): """Plot 2D topography of evoked responses. @@ -1218,6 +1219,12 @@ def plot_evoked_topo( exclude : list of str | ``'bads'`` Channels names to exclude from the plot. If ``'bads'``, the bad channels are excluded. By default, exclude is set to ``'bads'``. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 show : bool Show figure if True. @@ -1274,10 +1281,11 @@ def plot_evoked_topo( font_color=font_color, merge_channels=merge_grads, legend=legend, + noise_cov=noise_cov, axes=axes, exclude=exclude, + select=select, show=show, - noise_cov=noise_cov, ) diff --git a/mne/viz/tests/test_raw.py b/mne/viz/tests/test_raw.py index 89e0a7c543d..caa09ae4d07 100644 --- a/mne/viz/tests/test_raw.py +++ b/mne/viz/tests/test_raw.py @@ -1088,36 +1088,25 @@ def test_plot_sensors(raw): pytest.raises(TypeError, plot_sensors, raw) # needs to be info pytest.raises(ValueError, plot_sensors, raw.info, kind="sasaasd") plt.close("all") + + # Test lasso selection.
fig, sels = raw.plot_sensors("select", show_names=True) ax = fig.axes[0] - - # Click with no sensors - _fake_click(fig, ax, (0.0, 0.0), xform="data") - _fake_click(fig, ax, (0, 0.0), xform="data", kind="release") - assert fig.lasso.selection == [] - - # Lasso with 1 sensor (upper left) - _fake_click(fig, ax, (0, 1), xform="ax") - fig.canvas.draw() - assert fig.lasso.selection == [] - _fake_click(fig, ax, (0.65, 1), xform="ax", kind="motion") - _fake_click(fig, ax, (0.65, 0.7), xform="ax", kind="motion") - _fake_keypress(fig, "control") - _fake_click(fig, ax, (0, 0.7), xform="ax", kind="release", key="control") + # Lasso a single sensor. + _fake_click(fig, ax, (-0.13, 0.13), xform="data") + _fake_click(fig, ax, (-0.11, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.11, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.06), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="motion") + _fake_click(fig, ax, (-0.13, 0.13), xform="data", kind="release") assert fig.lasso.selection == ["MEG 0121"] - # check that point appearance changes - fc = fig.lasso.collection.get_facecolors() - ec = fig.lasso.collection.get_edgecolors() - assert (fc[:, -1] == [0.5, 1.0, 0.5]).all() - assert (ec[:, -1] == [0.25, 1.0, 0.25]).all() - - _fake_click(fig, ax, (0.7, 1), xform="ax", kind="motion", key="control") - xy = ax.collections[0].get_offsets() - _fake_click(fig, ax, xy[2], xform="data", key="control") # single sel + # Add another sensor with a single click. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (-0.1278, 0.0318), xform="data") + _fake_click(fig, ax, (-0.1278, 0.0318), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") assert fig.lasso.selection == ["MEG 0121", "MEG 0131"] - _fake_click(fig, ax, xy[2], xform="data", key="control") # deselect - assert fig.lasso.selection == ["MEG 0121"] plt.close("all") raw.info["dev_head_t"] = None # like empty room diff --git a/mne/viz/tests/test_topo.py b/mne/viz/tests/test_topo.py index 85b4b43dcf8..48d031739b9 100644 --- a/mne/viz/tests/test_topo.py +++ b/mne/viz/tests/test_topo.py @@ -23,7 +23,7 @@ ) from mne.viz.evoked import _line_plot_onselect from mne.viz.topo import _imshow_tfr, _plot_update_evoked_topo_proj, iter_topography -from mne.viz.utils import _fake_click +from mne.viz.utils import _fake_click, _fake_keypress base_dir = Path(__file__).parents[2] / "io" / "tests" / "data" evoked_fname = base_dir / "test-ave.fif" @@ -231,6 +231,16 @@ def test_plot_topo(): break plt.close("all") + # Test plot_topo with selection of channels enabled. + fig = evoked.plot_topo(select=True) + ax = fig.axes[0] + _fake_click(fig, ax, (0.05, 0.62), xform="data") + _fake_click(fig, ax, (0.2, 0.62), xform="data", kind="motion") + _fake_click(fig, ax, (0.2, 0.7), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.7), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.7), xform="data", kind="release") + assert fig.lasso.selection == ["MEG 0113", "MEG 0112", "MEG 0111"] + def test_plot_topo_nirs(fnirs_evoked): """Test plotting of ERP topography for nirs data.""" @@ -296,6 +306,30 @@ def test_plot_topo_image_epochs(): assert qm_cmap[0] is cmap +def test_plot_topo_select(): + """Test selecting sensors in an ERP topography plot.""" + # Show topography + evoked = _get_epochs().average() + fig = plot_evoked_topo(evoked, select=True) + ax = fig.axes[0] + + # Lasso select 3 out of the 6 sensors. 
+ _fake_click(fig, ax, (0.05, 0.5), xform="data") + _fake_click(fig, ax, (0.2, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.2, 0.6), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.6), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.05, 0.5), xform="data", kind="release") + assert fig.lasso.selection == ["MEG 0132", "MEG 0133", "MEG 0131"] + + # Add another sensor with a single click. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (0.11, 0.65), xform="data") + _fake_click(fig, ax, (0.21, 0.65), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") + assert fig.lasso.selection == ["MEG 0111", "MEG 0132", "MEG 0133", "MEG 0131"] + + def test_plot_tfr_topo(): """Test plotting of TFR data.""" epochs = _get_epochs() diff --git a/mne/viz/tests/test_utils.py b/mne/viz/tests/test_utils.py index 59e2976e464..55dc0f1e65c 100644 --- a/mne/viz/tests/test_utils.py +++ b/mne/viz/tests/test_utils.py @@ -16,6 +16,7 @@ from mne.viz import ClickableImage, add_background_image, mne_analyze_colormap from mne.viz.ui_events import ColormapRange, link, subscribe from mne.viz.utils import ( + SelectFromCollection, _compute_scalings, _fake_click, _fake_keypress, @@ -274,3 +275,71 @@ def callback(event): cmap_new1 = fig.axes[0].CB.mappable.get_cmap().name cmap_new2 = fig2.axes[0].CB.mappable.get_cmap().name assert cmap_new1 == cmap_new2 == cmap_want != cmap_old + + +def test_select_from_collection(): + """Test the lasso selector for matplotlib figures.""" + fig, ax = plt.subplots() + collection = ax.scatter([1, 2, 2, 1], [1, 1, 0, 0], color="black", edgecolor="red") + ax.set_xlim(-1, 4) + ax.set_ylim(-1, 2) + lasso = SelectFromCollection(ax, collection, names=["A", "B", "C", "D"]) + assert lasso.selection == [] + + # Make a selection with no patches inside of it. + _fake_click(fig, ax, (0, 0), xform="data") + _fake_click(fig, ax, (0.5, 0), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1), xform="data", kind="release") + assert lasso.selection == [] + + # Doing a single click on a patch should not select it. + _fake_click(fig, ax, (1, 1), xform="data") + assert lasso.selection == [] + + # Make a selection with two patches in it. + _fake_click(fig, ax, (0, 0.5), xform="data") + _fake_click(fig, ax, (3, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (3, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0, 0.5), xform="data", kind="release") + assert lasso.selection == ["A", "B"] + + # Use Control key to lasso an additional patch. + _fake_keypress(fig, "control") + _fake_click(fig, ax, (0.5, -0.5), xform="data") + _fake_click(fig, ax, (1.5, -0.5), xform="data", kind="motion") + _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 0.5), xform="data", kind="release") + _fake_keypress(fig, "control", kind="release") + assert lasso.selection == ["A", "B", "D"] + + # Use CTRL+SHIFT to remove a patch. 
+ _fake_keypress(fig, "ctrl+shift") + _fake_click(fig, ax, (0.5, 0.5), xform="data") + _fake_click(fig, ax, (1.5, 0.5), xform="data", kind="motion") + _fake_click(fig, ax, (1.5, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="motion") + _fake_click(fig, ax, (0.5, 1.5), xform="data", kind="release") + _fake_keypress(fig, "ctrl+shift", kind="release") + assert lasso.selection == ["B", "D"] + + # Check that the two selected patches have a different appearance. + fc = lasso.collection.get_facecolors() + ec = lasso.collection.get_edgecolors() + assert (fc[:, -1] == [0.5, 1.0, 0.5, 1.0]).all() + assert (ec[:, -1] == [0.25, 1.0, 0.25, 1.0]).all() + + # Test adding and removing single channels. + lasso.select_one(2) # should not do anything without modifier keys + assert lasso.selection == ["B", "D"] + _fake_keypress(fig, "control") + lasso.select_one(2) # add to selection + _fake_keypress(fig, "control", kind="release") + assert lasso.selection == ["B", "C", "D"] + _fake_keypress(fig, "ctrl+shift") + lasso.select_one(1) # remove from selection + assert lasso.selection == ["C", "D"] + _fake_keypress(fig, "ctrl+shift", kind="release") diff --git a/mne/viz/topo.py b/mne/viz/topo.py index 3364a455aed..5c43d4de48e 100644 --- a/mne/viz/topo.py +++ b/mne/viz/topo.py @@ -13,8 +13,10 @@ from .._fiff.pick import _picks_to_idx, channel_type, pick_types from ..defaults import _handle_default from ..utils import Bunch, _check_option, _clean_names, _is_numeric, _to_rgb, fill_doc +from .ui_events import ChannelsSelect, publish, subscribe from .utils import ( DraggableColorbar, + SelectFromCollection, _check_cov, _check_delayed_ssp, _draw_proj_checkbox, @@ -37,6 +39,7 @@ def iter_topography( axis_spinecolor="k", layout_scale=None, legend=False, + select=False, ): """Create iterator over channel positions. @@ -72,6 +75,12 @@ def iter_topography( If True, an additional axis is created in the bottom right corner that can be used to, e.g., construct a legend. The index of this axis will be -1. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 Returns ------- @@ -93,6 +102,7 @@ def iter_topography( axis_spinecolor, layout_scale, legend=legend, + select=select, ) @@ -128,6 +138,7 @@ def _iter_topography( img=False, axes=None, legend=False, + select=False, ): """Iterate over topography. @@ -193,8 +204,11 @@ def format_coord_multiaxis(x, y, ch_name=None): under_ax.set(xlim=[0, 1], ylim=[0, 1]) axs = list() + + shown_ch_names = [] for idx, name in iter_ch: ch_idx = ch_names.index(name) + shown_ch_names.append(name) if not unified: # old, slow way ax = plt.axes(pos[idx]) ax.patch.set_facecolor(axis_facecolor) @@ -226,24 +240,48 @@ def format_coord_multiaxis(x, y, ch_name=None): if unified: under_ax._mne_axs = axs # Create a PolyCollection for the axis backgrounds + sel_pos = pos[[i[0] for i in iter_ch]] verts = np.transpose( [ - pos[:, :2], - pos[:, :2] + pos[:, 2:] * [1, 0], - pos[:, :2] + pos[:, 2:], - pos[:, :2] + pos[:, 2:] * [0, 1], + sel_pos[:, :2], + sel_pos[:, :2] + sel_pos[:, 2:] * [1, 0], + sel_pos[:, :2] + sel_pos[:, 2:], + sel_pos[:, :2] + sel_pos[:, 2:] * [0, 1], ], [1, 0, 2], ) - if not img: - under_ax.add_collection( - collections.PolyCollection( - verts, - facecolor=axis_facecolor, - edgecolor=axis_spinecolor, - linewidth=1.0, + if not img: # Not needed for image plots. 
+ collection = collections.PolyCollection( + verts, + facecolor=axis_facecolor, + edgecolor=axis_spinecolor, + linewidth=1.0, + ) + under_ax.add_collection(collection) + + if select: + # Configure the lasso-selection tool + fig.lasso = SelectFromCollection( + ax=under_ax, + collection=collection, + names=shown_ch_names, + alpha_nonselected=0, + alpha_selected=1, + linewidth_nonselected=0, + linewidth_selected=0.7, ) - ) # Not needed for image plots. + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + selection_inds = np.flatnonzero( + np.isin(shown_ch_names, event.ch_names) + ) + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) for ax in axs: yield ax, ax._mne_ch_idx @@ -270,6 +308,7 @@ def _plot_topo( unified=False, img=False, axes=None, + select=False, ): """Plot on sensor layout.""" import matplotlib.pyplot as plt @@ -322,6 +361,7 @@ def _plot_topo( unified=unified, img=img, axes=axes, + select=select, ) for ax, ch_idx in my_topo_plot: @@ -340,8 +380,17 @@ def _plot_topo( def _plot_topo_onpick(event, show_func): """Onpick callback that shows a single channel in a new figure.""" - # make sure that the swipe gesture in OS-X doesn't open many figures orig_ax = event.inaxes + fig = orig_ax.figure + + # If we are doing lasso select, allow it to handle the click instead. + if hasattr(fig, "lasso") and event.key in ["control", "ctrl+shift"]: + return + + # make sure that the swipe gesture in OS-X doesn't open many figures + if fig.canvas._key in ["shift", "alt"]: + return + import matplotlib.pyplot as plt try: @@ -838,9 +887,10 @@ def _plot_evoked_topo( merge_channels=False, legend=True, axes=None, + noise_cov=None, exclude="bads", + select=False, show=True, - noise_cov=None, ): """Plot 2D topography of evoked responses. @@ -912,6 +962,10 @@ def _plot_evoked_topo( exclude : list of str | 'bads' Channels names to exclude from being shown. If 'bads', the bad channels are excluded. By default, exclude is set to 'bads'. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. show : bool Show figure if True. @@ -1091,6 +1145,7 @@ def _plot_evoked_topo( y_label=y_label, unified=True, axes=axes, + select=select, ) add_background_image(fig, fig_background) @@ -1098,7 +1153,10 @@ def _plot_evoked_topo( if legend is not False: legend_loc = 0 if legend is True else legend labels = [e.comment if e.comment else "Unknown" for e in evoked] - handles = fig.axes[0].lines[: len(evoked)] + if select: + handles = fig.axes[0].lines[1 : len(evoked) + 1] + else: + handles = fig.axes[0].lines[: len(evoked)] legend = plt.legend( labels=labels, handles=handles, loc=legend_loc, prop={"size": 10} ) @@ -1157,6 +1215,7 @@ def plot_topo_image_epochs( fig_facecolor="k", fig_background=None, font_color="w", + select=False, show=True, ): """Plot Event Related Potential / Fields image on topographies. @@ -1204,6 +1263,12 @@ def plot_topo_image_epochs( :func:`matplotlib.pyplot.imshow`. Defaults to ``None``. font_color : color The color of tick labels in the colorbar. Defaults to white. + select : bool + Whether to enable the lasso-selection tool to enable the user to select + channels. The selected channels will be available in + ``fig.lasso.selection``. + + .. versionadded:: 1.10.0 show : bool Whether to show the figure. Defaults to ``True``. 
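Usage note for the ``select`` flag threaded through ``iter_topography``/``_plot_topo`` above and exposed on ``plot_evoked_topo``/``Evoked.plot_topo``: selections land on the returned figure. A sketch, assuming the MNE sample dataset is available locally:

```python
import mne

path = mne.datasets.sample.data_path()
fname = path / "MEG" / "sample" / "sample_audvis-ave.fif"
evoked = mne.read_evokeds(fname, condition="Left Auditory")
fig = evoked.plot_topo(select=True)  # new in 1.10
# ... lasso channels interactively (Ctrl adds, Ctrl+Shift removes) ...
print(fig.lasso.selection)  # e.g. ["MEG 0111", "MEG 0112"]
```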
@@ -1293,6 +1358,7 @@ def plot_topo_image_epochs( y_label="Epoch", unified=True, img=True, + select=select, ) add_background_image(fig, fig_background) plt_show(show) diff --git a/mne/viz/ui_events.py b/mne/viz/ui_events.py index 256d5741ad3..b8b3fe29a4d 100644 --- a/mne/viz/ui_events.py +++ b/mne/viz/ui_events.py @@ -212,6 +212,26 @@ class Contours(UIEvent): contours: list[str] +@dataclass +@fill_doc +class ChannelsSelect(UIEvent): + """Indicates that the user has selected one or more channels. + + Parameters + ---------- + ch_names : list of str + The names of the channels that were selected. + + Attributes + ---------- + %(ui_event_name_source)s + ch_names : list of str + The names of the channels that were selected. + """ + + ch_names: list[str] + + def _get_event_channel(fig): """Get the event channel associated with a figure. diff --git a/mne/viz/utils.py b/mne/viz/utils.py index a09da17de7d..b9b844b321a 100644 --- a/mne/viz/utils.py +++ b/mne/viz/utils.py @@ -58,7 +58,7 @@ warn, ) from ..utils.misc import _identity_function -from .ui_events import ColormapRange, publish, subscribe +from .ui_events import ChannelsSelect, ColormapRange, publish, subscribe _channel_type_prettyprint = { "eeg": "EEG channel", @@ -807,12 +807,12 @@ def _fake_click(fig, ax, point, xform="ax", button=1, kind="press", key=None): ) -def _fake_keypress(fig, key): +def _fake_keypress(fig, key, kind="press"): from matplotlib import backend_bases fig.canvas.callbacks.process( - "key_press_event", - backend_bases.KeyEvent(name="key_press_event", canvas=fig.canvas, key=key), + f"key_{kind}_event", + backend_bases.KeyEvent(name=f"key_{kind}_event", canvas=fig.canvas, key=key), ) @@ -952,7 +952,7 @@ def plot_sensors( Whether to plot the sensors as 3d, topomap or as an interactive sensor selection dialog. Available options ``'topomap'``, ``'3d'``, ``'select'``. If ``'select'``, a set of channels can be selected - interactively by using lasso selector or clicking while holding control + interactively by using lasso selector or clicking while holding the control key. The selected channels are returned along with the figure instance. Defaults to ``'topomap'``. ch_type : None | str @@ -1163,10 +1163,10 @@ def _onpick_sensor(event, fig, ax, pos, ch_names, show_names): if event.mouseevent.inaxes != ax: return - if event.mouseevent.key == "control" and fig.lasso is not None: + if fig.lasso is not None and event.mouseevent.key in ["control", "ctrl+shift"]: + # Add the sensor to the selection instead of showing its name. 
for ind in event.ind: fig.lasso.select_one(ind) - return if show_names: return # channel names already visible @@ -1185,7 +1185,7 @@ def _onpick_sensor(event, fig, ax, pos, ch_names, show_names): fig.canvas.draw() -def _close_event(event, fig): +def _close_event(event=None, fig=None): """Listen for sensor plotter close event.""" if getattr(fig, "lasso", None) is not None: fig.lasso.disconnect() @@ -1272,7 +1272,17 @@ def _plot_sensors_2d( lw=linewidth, ) if kind == "select": - fig.lasso = SelectFromCollection(ax, pts, ch_names) + fig.lasso = SelectFromCollection(ax, pts, names=ch_names) + + def on_select(): + publish(fig, ChannelsSelect(ch_names=fig.lasso.selection)) + + def on_channels_select(event): + selection_inds = np.flatnonzero(np.isin(ch_names, event.ch_names)) + fig.lasso.select_many(selection_inds) + + fig.lasso.callbacks.append(on_select) + subscribe(fig, "channels_select", on_channels_select) else: fig.lasso = None @@ -1595,11 +1605,14 @@ def _update(self): class SelectFromCollection: - """Select channels from a matplotlib collection using ``LassoSelector``. + """Select objects from a matplotlib collection using ``LassoSelector``. - Selected channels are saved in the ``selection`` attribute. This tool - highlights selected points by fading other points out (i.e., reducing their - alpha values). + The names of the selected objects are saved in the ``selection`` attribute. + This tool highlights selected objects by fading other objects out (i.e., + reducing their alpha values). + + Holding down the Control key will add to the current selection, and holding down + Control+Shift will remove from the current selection. Parameters ---------- @@ -1607,112 +1620,144 @@ class SelectFromCollection: Axes to interact with. collection : instance of matplotlib collection Collection you want to select from. - alpha_other : 0 <= float <= 1 - To highlight a selection, this tool sets all selected points to an - alpha value of 1 and non-selected points to ``alpha_other``. - Defaults to 0.3. - linewidth_other : float - Linewidth to use for non-selected sensors. Default is 1. + names : list of str + The names of the objects. The selection is returned as a subset of these names. + alpha_selected : float + Alpha for selected objects (0=transparent, 1=opaque). + alpha_nonselected : float + Alpha for non-selected objects (0=transparent, 1=opaque). + linewidth_selected : float + Linewidth for the borders of selected objects. + linewidth_nonselected : float + Linewidth for the borders of non-selected objects. Notes ----- - This tool selects collection objects based on their *origins* - (i.e., ``offsets``). Calls all callbacks in self.callbacks when selection - is ready. + This tool selects collection objects whose bounding boxes intersect with a lasso + path. Calls all callbacks in self.callbacks when selection is ready.
""" def __init__( self, ax, collection, - ch_names, - alpha_other=0.5, - linewidth_other=0.5, + *, + names, alpha_selected=1, + alpha_nonselected=0.5, linewidth_selected=1, + linewidth_nonselected=0.5, + verbose=None, ): from matplotlib.widgets import LassoSelector + self.fig = ax.figure self.canvas = ax.figure.canvas self.collection = collection - self.ch_names = ch_names - self.alpha_other = alpha_other - self.linewidth_other = linewidth_other + self.names = names self.alpha_selected = alpha_selected + self.alpha_nonselected = alpha_nonselected self.linewidth_selected = linewidth_selected + self.linewidth_nonselected = linewidth_nonselected - self.xys = collection.get_offsets() - self.Npts = len(self.xys) + from matplotlib.collections import PolyCollection + from matplotlib.path import Path - # Ensure that we have separate colors for each object + if isinstance(collection, PolyCollection): + self.paths = collection.get_paths() + else: + self.paths = [Path([point]) for point in collection.get_offsets()] + self.Npts = len(self.paths) + if self.Npts != len(names): + raise ValueError( + f"Number of names ({len(names)}) does not match the number of objects " + f"in the collection ({self.Npts})." + ) + + # Ensure that we have colors for each object. self.fc = collection.get_facecolors() self.ec = collection.get_edgecolors() - self.lw = collection.get_linewidths() if len(self.fc) == 0: raise ValueError("Collection must have a facecolor") elif len(self.fc) == 1: self.fc = np.tile(self.fc, self.Npts).reshape(self.Npts, -1) + if len(self.ec) == 0: + self.ec = np.zeros((self.Npts, 4)) # all black + elif len(self.ec) == 1: self.ec = np.tile(self.ec, self.Npts).reshape(self.Npts, -1) - self.fc[:, -1] = self.alpha_other # deselect in the beginning - self.ec[:, -1] = self.alpha_other - self.lw = np.full(self.Npts, self.linewidth_other) + self.lw = np.full(self.Npts, float(self.linewidth_nonselected)) + # Initialize the lasso selector self.lasso = LassoSelector( ax, onselect=self.on_select, props=dict(color="red", linewidth=0.5) ) self.selection = list() + self.selection_inds = np.array([], dtype="int") self.callbacks = list() + # Deselect everything in the beginning. + self.style_objects() + + # For backwards compatibility + @property + def ch_names(self): + return self.names + + def notify(self): + """Notify listeners that a selection has been made.""" + logger.info(f"Selected channels: {self.selection}") + for callback in self.callbacks: + callback() + def on_select(self, verts): """Select a subset from the collection.""" from matplotlib.path import Path - if len(verts) <= 3: # Seems to be a good way to exclude single clicks. + # Don't respond to single clicks without extra keys being hold down. + # Figures like plot_evoked_topo want to do something else with them. + if len(verts) <= 3 and self.canvas._key not in ["control", "ctrl+shift"]: return path = Path(verts) - inds = np.nonzero([path.contains_point(xy) for xy in self.xys])[0] + inds = np.nonzero([path.intersects_path(p) for p in self.paths])[0] if self.canvas._key == "control": # Appending selection. 
- sels = [np.where(self.ch_names == c)[0][0] for c in self.selection] - inters = set(inds) - set(sels) - inds = list(inters.union(set(sels) - set(inds))) - - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + self.selection_inds = np.union1d(self.selection_inds, inds).astype("int") + elif self.canvas._key == "ctrl+shift": + self.selection_inds = np.setdiff1d(self.selection_inds, inds).astype("int") + else: + self.selection_inds = inds + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() self.notify() def select_one(self, ind): """Select or deselect one sensor.""" - ch_name = self.ch_names[ind] - if ch_name in self.selection: - sel_ind = self.selection.index(ch_name) - self.selection.pop(sel_ind) + if self.canvas._key == "control": + self.selection_inds = np.union1d(self.selection_inds, [ind]) + elif self.canvas._key == "ctrl+shift": + self.selection_inds = np.setdiff1d(self.selection_inds, [ind]) else: - self.selection.append(ch_name) - inds = np.isin(self.ch_names, self.selection).nonzero()[0] - self.style_sensors(inds) + return # don't notify() + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() self.notify() - def notify(self): - """Notify listeners that a selection has been made.""" - for callback in self.callbacks: - callback() - def select_many(self, inds): """Select many sensors using indices (for predefined selections).""" - self.selection[:] = np.array(self.ch_names)[inds].tolist() - self.style_sensors(inds) + self.selection_inds = inds + self.selection = [self.names[i] for i in self.selection_inds] + self.style_objects() - def style_sensors(self, inds): + def style_objects(self): """Style selected sensors as "active".""" # reset - self.fc[:, -1] = self.alpha_other - self.ec[:, -1] = self.alpha_other / 2 - self.lw[:] = self.linewidth_other + self.fc[:, -1] = self.alpha_nonselected + self.ec[:, -1] = self.alpha_nonselected / 2 + self.lw[:] = self.linewidth_nonselected # style sensors at `inds` - self.fc[inds, -1] = self.alpha_selected - self.ec[inds, -1] = self.alpha_selected - self.lw[inds] = self.linewidth_selected + self.fc[self.selection_inds, -1] = self.alpha_selected + self.ec[self.selection_inds, -1] = self.alpha_selected + self.lw[self.selection_inds] = self.linewidth_selected self.collection.set_facecolors(self.fc) self.collection.set_edgecolors(self.ec) self.collection.set_linewidths(self.lw) diff --git a/tools/circleci_dependencies.sh b/tools/circleci_dependencies.sh index dd3216ebf06..b306bb528f4 100755 --- a/tools/circleci_dependencies.sh +++ b/tools/circleci_dependencies.sh @@ -13,4 +13,4 @@ python -m pip install --upgrade --progress-bar off \ mne-icalabel mne-lsl mne-microstates mne-nirs mne-rsa \ neurodsp neurokit2 niseq nitime pactools \ plotly pycrostates pyprep pyriemann python-picard sesameeg \ - sleepecg tensorpac yasa meegkit eeg_positions + sleepecg tensorpac yasa meegkit eeg_positions wfdb diff --git a/tools/github_actions_dependencies.sh b/tools/github_actions_dependencies.sh index cebd2caefa7..d47d9070f8b 100755 --- a/tools/github_actions_dependencies.sh +++ b/tools/github_actions_dependencies.sh @@ -23,7 +23,7 @@ if [ ! 
-z "$CONDA_ENV" ]; then elif [[ "${MNE_CI_KIND}" == "pip" ]]; then # Only used for 3.13 at the moment, just get test deps plus a few extras # that we know are available - INSTALL_ARGS="nibabel scikit-learn numpydoc PySide6 mne-qt-browser pandas h5io mffpy defusedxml" + INSTALL_ARGS="nibabel scikit-learn numpydoc PySide6 mne-qt-browser pandas h5io mffpy defusedxml numba" INSTALL_KIND="test" else test "${MNE_CI_KIND}" == "pip-pre" diff --git a/tools/github_actions_env_vars.sh b/tools/github_actions_env_vars.sh index 8accf72a11a..9f424ae5f48 100755 --- a/tools/github_actions_env_vars.sh +++ b/tools/github_actions_env_vars.sh @@ -28,7 +28,7 @@ else # conda-like echo "MNE_LOGGING_LEVEL=warning" | tee -a $GITHUB_ENV echo "MNE_QT_BACKEND=PySide6" | tee -a $GITHUB_ENV # TODO: Also need "|unreliable on GitHub Actions conda" on macOS, but omit for now to make sure the failure actually shows up - echo "MNE_TEST_ALLOW_SKIP=.*(Requires (spm|brainstorm) dataset|CUDA not|PySide6 causes segfaults).*" | tee -a $GITHUB_ENV + echo "MNE_TEST_ALLOW_SKIP=.*(Requires (spm|brainstorm) dataset|CUDA not|PySide6 causes segfaults|Accelerate|Flakey verbose behavior).*" | tee -a $GITHUB_ENV fi fi set +x diff --git a/tools/vulture_allowlist.py b/tools/vulture_allowlist.py index 24bcd9af64a..9d0e215ee80 100644 --- a/tools/vulture_allowlist.py +++ b/tools/vulture_allowlist.py @@ -41,6 +41,8 @@ # Decoding _._more_tags +_.multi_class +_.preserves_dtype deep # Backward compat or rarely used diff --git a/tutorials/epochs/60_make_fixed_length_epochs.py b/tutorials/epochs/60_make_fixed_length_epochs.py index 04a4ec87c7d..10b8c12ea19 100644 --- a/tutorials/epochs/60_make_fixed_length_epochs.py +++ b/tutorials/epochs/60_make_fixed_length_epochs.py @@ -5,13 +5,12 @@ ================================================= This tutorial shows how to segment continuous data into a set of epochs spaced -equidistantly in time. The epochs will not be created based on experimental -events; instead, the continuous data will be "chunked" into consecutive epochs -(which may be temporally overlapping, adjacent, or separated). -We will also briefly demonstrate how to use these epochs in connectivity -analysis. +equidistantly in time. The epochs will not be created based on experimental events; +instead, the continuous data will be "chunked" into consecutive epochs (which may be +temporally overlapping, adjacent, or separated). We will also briefly demonstrate how +to use these epochs in connectivity analysis. -First, we import necessary modules and read in a sample raw data set. +First, we import the necessary modules and read in a sample raw data set. This data set contains brain activity that is event-related, i.e., synchronized to the onset of auditory stimuli. However, rather than creating epochs by segmenting the data around the onset of each stimulus, we will diff --git a/tutorials/evoked/10_evoked_overview.py b/tutorials/evoked/10_evoked_overview.py index 75e63692bd2..b251a1f8239 100644 --- a/tutorials/evoked/10_evoked_overview.py +++ b/tutorials/evoked/10_evoked_overview.py @@ -5,12 +5,11 @@ The Evoked data structure: evoked/averaged data =============================================== -This tutorial covers the basics of creating and working with :term:`evoked` -data. It introduces the :class:`~mne.Evoked` data structure in detail, -including how to load, query, subset, export, and plot data from an -:class:`~mne.Evoked` object. 
For details on creating an :class:`~mne.Evoked` -object from (possibly simulated) data in a :class:`NumPy array -`, see :ref:`tut-creating-data-structures`. +This tutorial covers the basics of creating and working with :term:`evoked` data. It +introduces the :class:`~mne.Evoked` data structure in detail, including how to load, +query, subset, export, and plot data from an :class:`~mne.Evoked` object. For details +on creating an :class:`~mne.Evoked` object from (possibly simulated) data in a +:class:`NumPy array `, see :ref:`tut-creating-data-structures`. As usual, we start by importing the modules we need: """ diff --git a/tutorials/intro/70_report.py b/tutorials/intro/70_report.py index cc32d02679b..fe87c0f3a44 100644 --- a/tutorials/intro/70_report.py +++ b/tutorials/intro/70_report.py @@ -12,11 +12,11 @@ and after each preprocessing step, epoch rejection statistics, MRI slices with overlaid BEM shells, all the way up to plots of estimated cortical activity. -Compared to a Jupyter notebook, :class:`mne.Report` is easier to deploy, as the -HTML pages it generates are self-contained and do not require a running Python -environment. However, it is less flexible as you can't change code and re-run -something directly within the browser. This tutorial covers the basics of -building a report. As usual, we will start by importing the modules and data we need: +Compared to a Jupyter notebook, :class:`mne.Report` is easier to deploy, as the HTML +pages it generates are self-contained and do not require a running Python environment. +However, it is less flexible as you can't change code and re-run something directly +within the browser. This tutorial covers the basics of building a report. As usual, +we will start by importing the modules and data we need: """ # Authors: The MNE-Python contributors. diff --git a/tutorials/inverse/20_dipole_fit.py b/tutorials/inverse/20_dipole_fit.py index 2b640aa8fc2..e72e76dd0fd 100644 --- a/tutorials/inverse/20_dipole_fit.py +++ b/tutorials/inverse/20_dipole_fit.py @@ -87,6 +87,7 @@ # %% # Calculate and visualise magnetic field predicted by dipole with maximum GOF # and compare to the measured data, highlighting the ipsilateral (right) source + fwd, stc = make_forward_dipole(dip, fname_bem, evoked.info, fname_trans) pred_evoked = simulate_evoked(fwd, stc, evoked.info, cov=None, nave=np.inf)
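Reviewer note: a minimal usage sketch of the reworked ``SelectFromCollection`` API from the mne/viz hunks above, since the new keyword-only ``names`` argument and the Control / Control+Shift modifiers are easy to miss. The import path (``mne.viz.utils``, the class's historical home) and the scatter data and channel names are assumptions for illustration, not part of this patch.

import matplotlib.pyplot as plt
import numpy as np

from mne.viz.utils import SelectFromCollection  # assumed import path

rng = np.random.default_rng(0)
xy = rng.uniform(size=(10, 2))
names = [f"EEG {i:03d}" for i in range(10)]

fig, ax = plt.subplots()
# The collection must have a facecolor, or the constructor raises ValueError.
pts = ax.scatter(xy[:, 0], xy[:, 1], s=100, facecolors="C0", edgecolors="k")

# ``names`` is keyword-only and replaces the old ``ch_names`` argument.
lasso = SelectFromCollection(ax, pts, names=names)

def on_select():
    # ``selection`` holds names; ``selection_inds`` holds the matching indices.
    print("selected:", lasso.selection)

lasso.callbacks.append(on_select)

# Predefined selections bypass the lasso (and do not fire the callbacks):
lasso.select_many(np.array([0, 2, 4]))
print(lasso.selection)  # ['EEG 000', 'EEG 002', 'EEG 004']

plt.show()

With the figure open, dragging a lasso replaces the selection, dragging with Control held adds to it, and Control+Shift removes from it, matching the union1d/setdiff1d branches in ``on_select``.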
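Reviewer note: the ``publish``/``subscribe`` calls added to ``_plot_sensors_2d`` tie the lasso into MNE's UI event system so that linked figures stay in sync. A short sketch of driving that mechanism directly, assuming ``ChannelsSelect`` is exposed from ``mne.viz.ui_events`` alongside the documented ``publish`` and ``subscribe`` helpers (the channel names are placeholders):

import matplotlib.pyplot as plt

from mne.viz.ui_events import ChannelsSelect, publish, subscribe

fig, ax = plt.subplots()

def on_channels_select(event):
    # Runs whenever a ChannelsSelect event is published on this figure.
    print("channels selected:", event.ch_names)

# "channels_select" is the event name the sensor-plot hunk subscribes to.
subscribe(fig, "channels_select", on_channels_select)

# Publishing notifies every subscriber; here it just prints the names.
publish(fig, ChannelsSelect(ch_names=["EEG 001", "EEG 002"]))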