Skip to content

Commit

Permalink
Merge branch 'main' into dev-headsurface
Browse files Browse the repository at this point in the history
  • Loading branch information
vferat authored Feb 3, 2025
2 parents 927bc7a + 715540a commit 5628b0b
Show file tree
Hide file tree
Showing 65 changed files with 1,114 additions and 521 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/autofix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ jobs:
- run: pip install --upgrade towncrier pygithub gitpython numpy
- run: python ./.github/actions/rename_towncrier/rename_towncrier.py
- run: python ./tools/dev/ensure_headers.py
- uses: autofix-ci/action@ff86a557419858bb967097bfc916833f5647fa8c
- uses: autofix-ci/action@551dded8c6cc8a1054039c8bc0b8b48c51dfc6ef
3 changes: 3 additions & 0 deletions .github/workflows/check_changelog.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ on: # yamllint disable-line rule:truthy
types: [opened, synchronize, labeled, unlabeled]
branches: ["main"]

permissions:
contents: read

jobs:
changelog_checker:
name: Check towncrier entry in doc/changes/devel/
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/circle_artifacts.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
on: [status] # yamllint disable-line rule:truthy
permissions:
contents: read
statuses: write
jobs:
circleci_artifacts_redirector_job:
if: "${{ startsWith(github.event.context, 'ci/circleci: build_docs') }}"
Expand Down
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
repos:
# Ruff mne
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.9.1
rev: v0.9.3
hooks:
- id: ruff
name: ruff lint mne
Expand All @@ -23,7 +23,7 @@ repos:

# Codespell
- repo: https://github.com/codespell-project/codespell
rev: v2.3.0
rev: v2.4.0
hooks:
- id: codespell
additional_dependencies:
Expand Down Expand Up @@ -82,7 +82,7 @@ repos:

# zizmor
- repo: https://github.com/woodruffw/zizmor-pre-commit
rev: v1.1.1
rev: v1.2.2
hooks:
- id: zizmor

Expand Down
1 change: 1 addition & 0 deletions doc/changes/devel/12071.newfeature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add new ``select`` parameter to :func:`mne.viz.plot_evoked_topo` and :meth:`mne.Evoked.plot_topo` to toggle lasso selection of sensors, by `Marijn van Vliet`_.
1 change: 1 addition & 0 deletions doc/changes/devel/12656.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug where :func:`mne.export.export_raw` does not correct for recording start time (:attr:`raw.first_time <mne.io.Raw.first_time>`) when exporting Raw instances to EDF or EEGLAB formats, by `Qian Chu`_.
7 changes: 7 additions & 0 deletions doc/changes/devel/13065.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Improved sklearn class compatibility and compliance, which resulted in some parameters of classes having an underscore appended to their name during ``fit``, such as:

- :class:`mne.decoding.FilterEstimator` parameter ``picks`` passed to the initializer is set as ``est.picks_``
- :class:`mne.decoding.UnsupervisedSpatialFilter` parameter ``estimator`` passed to the initializer is set as ``est.estimator_``

Unused ``verbose`` class parameters (that had no effect) were removed from :class:`~mne.decoding.PSDEstimator`, :class:`~mne.decoding.TemporalFilter`, and :class:`~mne.decoding.FilterEstimator` as well.
Changes by `Eric Larson`_.
1 change: 1 addition & 0 deletions doc/changes/devel/13070.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Return events when requested even when current matches the desired sfreq in :meth:`mne.io.Raw.resample` by :newcontrib:`Roy Eric Wieske`.
1 change: 1 addition & 0 deletions doc/changes/devel/13082.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix bug with automated Mesa 3D detection for proper 3D option setting on systems with software rendering, by `Eric Larson`_.
1 change: 1 addition & 0 deletions doc/changes/names.inc
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@
.. _Roman Goj: https://romanmne.blogspot.co.uk
.. _Ross Maddox: https://www.urmc.rochester.edu/labs/maddox-lab.aspx
.. _Rotem Falach: https://github.com/Falach
.. _Roy Eric Wieske: https://github.com/Randomidous
.. _Sammi Chekroud: https://github.com/schekroud
.. _Samu Taulu: https://phys.washington.edu/people/samu-taulu
.. _Samuel Deslauriers-Gauthier: https://github.com/sdeslauriers
Expand Down
2 changes: 2 additions & 0 deletions doc/sphinxext/mne_doc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,8 @@ def reset_warnings(gallery_conf, fname):
r"numpy\.core is deprecated and has been renamed to numpy\._core",
# matplotlib
"__array_wrap__ must accept context and return_scalar.*",
# nibabel
"__array__ implementation doesn't accept.*",
):
warnings.filterwarnings( # deal with other modules having bad imports
"ignore", message=f".*{key}.*", category=DeprecationWarning
Expand Down
24 changes: 16 additions & 8 deletions doc/sphinxext/related_software.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def _get_packages() -> dict[str, str]:
packages = sorted(packages, key=lambda x: x.lower())
packages = [RENAMES.get(package, package) for package in packages]
out = dict()
reasons = []
for package in status_iterator(
packages, f"Adding {len(packages)} related software packages: "
):
Expand All @@ -183,12 +184,17 @@ def _get_packages() -> dict[str, str]:
else:
md = importlib.metadata.metadata(package)
except importlib.metadata.PackageNotFoundError:
pass # raise a complete error later
reasons.append(f"{package}: not found, needs to be installed")
continue # raise a complete error later
else:
# Every project should really have this
do_continue = False
for key in ("Summary",):
if key not in md:
raise ExtensionError(f"Missing {repr(key)} for {package}")
reasons.extend(f"{package}: missing {repr(key)}")
do_continue = True
if do_continue:
continue
# It is annoying to find the home page
url = None
if "Home-page" in md:
Expand All @@ -204,15 +210,17 @@ def _get_packages() -> dict[str, str]:
if url is not None:
break
else:
raise RuntimeError(
f"Could not find Home-page for {package} in:\n"
f"{sorted(set(md))}\nwith Summary:\n{md['Summary']}"
reasons.append(
f"{package}: could not find Home-page in {sorted(md)}"
)
continue
out[package]["url"] = url
out[package]["description"] = md["Summary"].replace("\n", "")
bad = [package for package in packages if not out[package]]
if bad and REQUIRE_METADATA:
raise ExtensionError(f"Could not find metadata for:\n{' '.join(bad)}")
reason_str = "\n".join(reasons)
if reason_str and REQUIRE_METADATA:
raise ExtensionError(
f"Could not find suitable metadata for related software:\n{reason_str}"
)

return out

Expand Down
2 changes: 1 addition & 1 deletion examples/decoding/linear_model_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@

# Extract and plot spatial filters and spatial patterns
for name, coef in (("patterns", model.patterns_), ("filters", model.filters_)):
# We fitted the linear model onto Z-scored data. To make the filters
# We fit the linear model on Z-scored data. To make the filters
# interpretable, we must reverse this normalization step
coef = scaler.inverse_transform([coef])[0]

Expand Down
1 change: 1 addition & 0 deletions examples/preprocessing/movement_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
##############################################################################
# After checking the annotated movement artifacts, calculate the new transform
# and plot it:

new_dev_head_t = compute_average_dev_head_t(raw, head_pos)
raw.info["dev_head_t"] = new_dev_head_t
fig = mne.viz.plot_alignment(
Expand Down
2 changes: 1 addition & 1 deletion mne/_fiff/proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,7 +1100,7 @@ def _has_eeg_average_ref_proj(


def _needs_eeg_average_ref_proj(info):
"""Determine if the EEG needs an averge EEG reference.
"""Determine if the EEG needs an average EEG reference.
This returns True if no custom reference has been applied and no average
reference projection is present in the list of projections.
Expand Down
1 change: 1 addition & 0 deletions mne/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ def pytest_configure(config: pytest.Config):
# pandas
ignore:\n*Pyarrow will become a required dependency of pandas.*:DeprecationWarning
ignore:np\.find_common_type is deprecated.*:DeprecationWarning
ignore:Python binding for RankQuantileOptions.*:
# pyvista <-> NumPy 2.0
ignore:__array_wrap__ must accept context and return_scalar arguments.*:DeprecationWarning
# nibabel <-> NumPy 2.0
Expand Down
4 changes: 2 additions & 2 deletions mne/cov.py
Original file line number Diff line number Diff line change
Expand Up @@ -1226,7 +1226,7 @@ def _compute_rank_raw_array(
from .io import RawArray

return _compute_rank(
RawArray(data, info, copy=None, verbose=_verbose_safe_false()),
RawArray(data, info, copy="auto", verbose=_verbose_safe_false()),
rank,
scalings,
info,
Expand Down Expand Up @@ -1405,7 +1405,7 @@ def _compute_covariance_auto(
# project back
cov = np.dot(eigvec.T, np.dot(cov, eigvec))
# undo bias
cov *= data.shape[0] / (data.shape[0] - 1)
cov *= data.shape[0] / max(data.shape[0] - 1, 1)
# undo scaling
_undo_scaling_cov(cov, picks_list, scalings)
method_ = method[ei]
Expand Down
11 changes: 8 additions & 3 deletions mne/decoding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import check_scoring
from sklearn.model_selection import KFold, StratifiedKFold, check_cv
from sklearn.utils import check_array, indexable
from sklearn.utils import check_array, check_X_y, indexable

from ..parallel import parallel_func
from ..utils import _pl, logger, verbose, warn
Expand Down Expand Up @@ -76,9 +76,9 @@ class LinearModel(MetaEstimatorMixin, BaseEstimator):
)

def __init__(self, model=None):
# TODO: We need to set this to get our tag checking to work properly
if model is None:
model = LogisticRegression(solver="liblinear")

self.model = model

def __sklearn_tags__(self):
Expand Down Expand Up @@ -122,7 +122,11 @@ def fit(self, X, y, **fit_params):
self : instance of LinearModel
Returns the modified instance.
"""
X = check_array(X, input_name="X")
if y is not None:
X = check_array(X)
else:
X, y = check_X_y(X, y)
self.n_features_in_ = X.shape[1]
if y is not None:
y = check_array(y, dtype=None, ensure_2d=False, input_name="y")
if y.ndim > 2:
Expand All @@ -133,6 +137,7 @@ def fit(self, X, y, **fit_params):

# fit the Model
self.model.fit(X, y, **fit_params)
self.model_ = self.model # for better sklearn compat

# Computes patterns using Haufe's trick: A = Cov_X . W . Precision_Y

Expand Down
99 changes: 43 additions & 56 deletions mne/decoding/csp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@

import numpy as np
from scipy.linalg import eigh
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_is_fitted

from .._fiff.meas_info import create_info
from ..cov import _compute_rank_raw_array, _regularized_covariance, _smart_eigh
Expand All @@ -19,10 +20,11 @@
fill_doc,
pinv,
)
from .transformer import MNETransformerMixin


@fill_doc
class CSP(TransformerMixin, BaseEstimator):
class CSP(MNETransformerMixin, BaseEstimator):
"""M/EEG signal decomposition using the Common Spatial Patterns (CSP).
This class can be used as a supervised decomposition to estimate spatial
Expand Down Expand Up @@ -112,49 +114,44 @@ def __init__(
component_order="mutual_info",
):
# Init default CSP
if not isinstance(n_components, int):
raise ValueError("n_components must be an integer.")
self.n_components = n_components
self.rank = rank
self.reg = reg

# Init default cov_est
if not (cov_est == "concat" or cov_est == "epoch"):
raise ValueError("unknown covariance estimation method")
self.cov_est = cov_est

# Init default transform_into
self.transform_into = _check_option(
"transform_into", transform_into, ["average_power", "csp_space"]
)

# Init default log
if transform_into == "average_power":
if log is not None and not isinstance(log, bool):
raise ValueError(
'log must be a boolean if transform_into == "average_power".'
)
else:
if log is not None:
raise ValueError('log must be a None if transform_into == "csp_space".')
self.transform_into = transform_into
self.log = log

_validate_type(norm_trace, bool, "norm_trace")
self.norm_trace = norm_trace
self.cov_method_params = cov_method_params
self.component_order = _check_option(
"component_order", component_order, ("mutual_info", "alternate")
self.component_order = component_order

def _validate_params(self, *, y):
_validate_type(self.n_components, int, "n_components")
if hasattr(self, "cov_est"):
_validate_type(self.cov_est, str, "cov_est")
_check_option("cov_est", self.cov_est, ("concat", "epoch"))
if hasattr(self, "norm_trace"):
_validate_type(self.norm_trace, bool, "norm_trace")
_check_option(
"transform_into", self.transform_into, ["average_power", "csp_space"]
)

def _check_Xy(self, X, y=None):
"""Check input data."""
if not isinstance(X, np.ndarray):
raise ValueError(f"X should be of type ndarray (got {type(X)}).")
if y is not None:
if len(X) != len(y) or len(y) < 1:
raise ValueError("X and y must have the same length.")
if X.ndim < 3:
raise ValueError("X must have at least 3 dimensions.")
if self.transform_into == "average_power":
_validate_type(
self.log,
(bool, None),
"log",
extra="when transform_into is 'average_power'",
)
else:
_validate_type(
self.log, None, "log", extra="when transform_into is 'csp_space'"
)
_check_option(
"component_order", self.component_order, ("mutual_info", "alternate")
)
self.classes_ = np.unique(y)
n_classes = len(self.classes_)
if n_classes < 2:
raise ValueError(f"n_classes must be >= 2, but got {n_classes} class")

def fit(self, X, y):
"""Estimate the CSP decomposition on epochs.
Expand All @@ -171,12 +168,9 @@ def fit(self, X, y):
self : instance of CSP
Returns the modified instance.
"""
self._check_Xy(X, y)

self._classes = np.unique(y)
n_classes = len(self._classes)
if n_classes < 2:
raise ValueError("n_classes must be >= 2.")
X, y = self._check_data(X, y=y, fit=True, return_y=True)
self._validate_params(y=y)
n_classes = len(self.classes_)
if n_classes > 2 and self.component_order == "alternate":
raise ValueError(
"component_order='alternate' requires two classes, but data contains "
Expand Down Expand Up @@ -225,13 +219,8 @@ def transform(self, X):
If self.transform_into == 'csp_space' then returns the data in CSP
space and shape is (n_epochs, n_components, n_times).
"""
if not isinstance(X, np.ndarray):
raise ValueError(f"X should be of type ndarray (got {type(X)}).")
if self.filters_ is None:
raise RuntimeError(
"No filters available. Please first fit CSP decomposition."
)

check_is_fitted(self, "filters_")
X = self._check_data(X)
pick_filters = self.filters_[: self.n_components]
X = np.asarray([np.dot(pick_filters, epoch) for epoch in X])

Expand Down Expand Up @@ -577,7 +566,7 @@ def _compute_covariance_matrices(self, X, y):

covs = []
sample_weights = []
for ci, this_class in enumerate(self._classes):
for ci, this_class in enumerate(self.classes_):
cov, weight = cov_estimator(
X[y == this_class],
cov_kind=f"class={this_class}",
Expand Down Expand Up @@ -689,7 +678,7 @@ def _normalize_eigenvectors(self, eigen_vectors, covs, sample_weights):
def _order_components(
self, covs, sample_weights, eigen_vectors, eigen_values, component_order
):
n_classes = len(self._classes)
n_classes = len(self.classes_)
if component_order == "mutual_info" and n_classes > 2:
mutual_info = self._compute_mutual_info(covs, sample_weights, eigen_vectors)
ix = np.argsort(mutual_info)[::-1]
Expand Down Expand Up @@ -889,10 +878,8 @@ def fit(self, X, y):
self : instance of SPoC
Returns the modified instance.
"""
self._check_Xy(X, y)

if len(np.unique(y)) < 2:
raise ValueError("y must have at least two distinct values.")
X, y = self._check_data(X, y=y, fit=True, return_y=True)
self._validate_params(y=y)

# The following code is directly copied from pyRiemann

Expand Down
Loading

0 comments on commit 5628b0b

Please sign in to comment.