Skip to content

Commit

Permalink
Initial commit.
Browse files Browse the repository at this point in the history
  • Loading branch information
jackz314 committed Apr 1, 2021
0 parents commit 7ee07da
Show file tree
Hide file tree
Showing 14 changed files with 422 additions and 0 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# PyEEGLAB

Python support for EEGLAB files
5 changes: 5 additions & 0 deletions pyeeglab/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from ._version import __version__
from . import epochs
from . import raw

__all__ = [epochs, raw]
3 changes: 3 additions & 0 deletions pyeeglab/_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""The version number."""

__version__ = '0.0.1'
88 changes: 88 additions & 0 deletions pyeeglab/epochs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import numpy as np
from numpy.core.records import fromarrays
from scipy.io import savemat
from utils import _get_eeglab_full_cords


def export_set(inst, fname):
"""Export Epochs to EEGLAB's .set format.
Parameters
----------
inst : mne.BaseEpochs
Epochs instance to save
fname : str
Name of the export file.
Notes
-----
Channel locations are expanded to the full EEGLAB format
For more details see .io.utils.cart_to_eeglab_full_coords
"""
# load data first
inst.load_data()

# remove extra epoc and STI channels
chs_drop = [ch for ch in ['epoc', 'STI 014'] if ch in inst.ch_names]
inst.drop_channels(chs_drop)

data = inst.get_data() * 1e6 # convert to microvolts
data = np.moveaxis(data, 0, 2) # convert to EEGLAB 3D format
fs = inst.info["sfreq"]
times = inst.times
trials = len(inst.events) # epoch count in EEGLAB

# get full EEGLAB coordinates to export
full_coords = _get_eeglab_full_cords(inst)

ch_names = inst.ch_names

# convert to record arrays for MATLAB format
chanlocs = fromarrays(
[ch_names, *full_coords.T, np.repeat('', len(ch_names))],
names=["labels", "X", "Y", "Z", "sph_theta", "sph_phi",
"sph_radius", "theta", "radius",
"sph_theta_besa", "sph_phi_besa", "type"])

# reverse order of event type dict to look up events faster
event_type_d = dict((v, k) for k, v in inst.event_id.items())
ev_types = [event_type_d[ev[2]] for ev in inst.events]

# EEGLAB latency, in units of data sample points
# ev_lat = [int(n) for n in self.events[:, 0]]
ev_lat = inst.events[:, 0]

# event durations should all be 0 except boundaries which we don't have
ev_dur = np.zeros((trials,), dtype=np.int64)

# indices of epochs each event belongs to
ev_epoch = np.arange(1, trials + 1)

# EEGLAB events format, also used for distinguishing epochs/trials
events = fromarrays([ev_types, ev_lat, ev_dur, ev_epoch],
names=["type", "latency", "duration", "epoch"])

# same as the indices for event epoch, except need to use array
ep_event = [np.array(n) for n in ev_epoch]
ep_lat = [np.array(n) for n in ev_lat]
ep_types = [np.array(n) for n in ev_types]

epochs = fromarrays([ep_event, ep_lat, ep_types],
names=["event", "eventlatency", "eventtype"])

eeg_d = dict(EEG=dict(data=data,
setname=fname,
nbchan=data.shape[0],
pnts=float(data.shape[1]),
trials=trials,
srate=fs,
xmin=times[0],
xmax=times[-1],
chanlocs=chanlocs,
event=events,
epoch=epochs,
icawinv=[],
icasphere=[],
icaweights=[]))
savemat(fname, eeg_d,
appendmat=False)
67 changes: 67 additions & 0 deletions pyeeglab/raw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np
from numpy.core.records import fromarrays
from scipy.io import savemat
from utils import _get_eeglab_full_cords


def export_set(inst, fname):
"""Export Raw to EEGLAB's .set format.
Parameters
----------
inst : mne.io.BaseRaw
Raw instance to save
fname : str
Name of the export file.
Notes
-----
Channel locations are expanded to the full EEGLAB format
For more details see .utils.cart_to_eeglab_full_coords
"""
# load data first
inst.load_data()

# remove extra epoc and STI channels
chs_drop = [ch for ch in ['epoc'] if ch in inst.ch_names]
if 'STI 014' in inst.ch_names and \
not (inst.filenames[0].endswith('.fif')):
chs_drop.append('STI 014')
inst.drop_channels(chs_drop)

data = inst.get_data() * 1e6 # convert to microvolts
fs = inst.info["sfreq"]
times = inst.times

# convert xyz to full eeglab coordinates
full_coords = _get_eeglab_full_cords(inst)

ch_names = inst.ch_names

# convert to record arrays for MATLAB format
chanlocs = fromarrays(
[ch_names, *full_coords.T, np.repeat('', len(ch_names))],
names=["labels", "X", "Y", "Z", "sph_theta", "sph_phi",
"sph_radius", "theta", "radius",
"sph_theta_besa", "sph_phi_besa", "type"])

events = fromarrays([inst.annotations.description,
inst.annotations.onset * fs + 1,
inst.annotations.duration * fs],
names=["type", "latency", "duration"])
eeg_d = dict(EEG=dict(data=data,
setname=fname,
nbchan=data.shape[0],
pnts=data.shape[1],
trials=1,
srate=fs,
xmin=times[0],
xmax=times[-1],
chanlocs=chanlocs,
event=events,
icawinv=[],
icasphere=[],
icaweights=[]))

savemat(fname, eeg_d,
appendmat=False)
3 changes: 3 additions & 0 deletions pyeeglab/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import os.path as op

data_dir = op.join(op.dirname(__file__), 'data')
Binary file added pyeeglab/tests/data/test-eve.fif
Binary file not shown.
Binary file added pyeeglab/tests/data/test_raw.fif
Binary file not shown.
43 changes: 43 additions & 0 deletions pyeeglab/tests/test_epochs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from pathlib import Path
import pytest
from epochs import export_set
from mne import read_events, pick_types, Epochs, read_epochs_eeglab
from mne.io import read_raw_fif
from os import path as op
import numpy as np
from numpy.testing import assert_allclose, assert_array_equal

raw_fname = Path(__file__).parent / "data" / "test_raw.fif"
event_name = Path(__file__).parent / "data" / 'test-eve.fif'


def _get_data(preload=False):
"""Get data."""
raw = read_raw_fif(raw_fname, preload=preload, verbose='warning')
events = read_events(event_name)
picks = pick_types(raw.info, meg=True, eeg=True, stim=True,
ecg=True, eog=True, include=['STI 014'],
exclude='bads')
return raw, events, picks


@pytest.mark.parametrize('preload', (True, False))
def test_export_set(tmpdir, preload):
"""Test saving an Epochs instance to EEGLAB's set format"""
raw, events = _get_data()[:2]
raw.load_data()
epochs = Epochs(raw, events, preload=preload)
temp_fname = op.join(str(tmpdir), 'test.set')
export_set(epochs, temp_fname)
epochs_read = read_epochs_eeglab(temp_fname)
assert epochs.ch_names == epochs_read.ch_names
cart_coords = np.array([d['loc'][:3]
for d in epochs.info['chs']]) # just xyz
cart_coords_read = np.array([d['loc'][:3]
for d in epochs_read.info['chs']])
assert_allclose(cart_coords, cart_coords_read)
assert_array_equal(epochs.events[:, 0],
epochs_read.events[:, 0]) # latency
assert epochs.event_id.keys() == epochs_read.event_id.keys() # just keys
assert_allclose(epochs.times, epochs_read.times)
assert_allclose(epochs.get_data(), epochs_read.get_data())
24 changes: 24 additions & 0 deletions pyeeglab/tests/test_raw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from pathlib import Path
from os import path as op
import numpy as np
from mne.io import read_raw_fif, read_raw_eeglab
from numpy.testing import assert_allclose
from raw import export_set
from utils_tests import _TempDir


def test_export_set():
"""Test saving a Raw instance to EEGLAB's set format"""
fname = Path(__file__).parent / "data" / "test_raw.fif"
raw = read_raw_fif(fname)
raw.load_data()
tmpdir = _TempDir()
temp_fname = op.join(str(tmpdir), 'test.set')
export_set(raw, temp_fname)
raw_read = read_raw_eeglab(temp_fname, preload=True)
assert raw.ch_names == raw_read.ch_names
cart_coords = np.array([d['loc'][:3] for d in raw.info['chs']]) # just xyz
cart_coords_read = np.array([d['loc'][:3] for d in raw_read.info['chs']])
assert_allclose(cart_coords, cart_coords_read)
assert_allclose(raw.times, raw_read.times)
assert_allclose(raw.get_data(), raw_read.get_data())
116 changes: 116 additions & 0 deletions pyeeglab/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import numpy as np


def _cart_to_eeglab_full_coords_xyz(x, y, z):
"""Convert Cartesian coordinates to EEGLAB full coordinates.
Also see https://github.com/sccn/eeglab/blob/develop/functions/sigprocfunc/convertlocs.m
Parameters
----------
x : ndarray, shape (n_points, )
Array of x coordinates
y : ndarray, shape (n_points, )
Array of y coordinates
z : ndarray, shape (n_points, )
Array of z coordinates
Returns
-------
sph_pts : ndarray, shape (n_points, 7)
Array containing points in spherical coordinates
(sph_theta, sph_phi, sph_radius, theta, radius,
sph_theta_besa, sph_phi_besa)
""" # noqa: E501

assert len(x) == len(y) == len(z)
out = np.empty((len(x), 7))

# https://github.com/sccn/eeglab/blob/develop/functions/sigprocfunc/topo2sph.m
def topo2sph(theta, radius):
c = np.empty((len(theta),))
h = np.empty((len(theta),))
for i, (t, r) in enumerate(zip(theta, radius)):
if t >= 0:
h[i] = 90 - t
else:
h[i] = -(90 + t)
if t != 0:
c[i] = np.sign(t) * 180 * r
else:
c[i] = 180 * r
return c, h

# cart to sph, see https://www.mathworks.com/help/matlab/ref/cart2sph.html
th = np.arctan2(y, x)
phi = np.arctan2(z, np.sqrt(np.square(x) + np.square(y)))
sph_r = np.sqrt(np.square(x) + np.square(y) + np.square(z))

# other stuff needed by EEGLAB
sph_theta = th / np.pi * 180
sph_phi = phi / np.pi * 180
sph_radius = sph_r
theta = -sph_theta
radius = 0.5 - sph_phi / 180
sph_theta_besa, sph_phi_besa = topo2sph(theta, radius)

# ordered based on EEGLAB order
out[:, 0] = sph_theta
out[:, 1] = sph_phi
out[:, 2] = sph_radius
out[:, 3] = theta
out[:, 4] = radius
out[:, 5] = sph_theta_besa
out[:, 6] = sph_phi_besa

out = np.nan_to_num(out)
return out


def _cart_to_eeglab_full_coords(cart):
"""Convert Cartesian coordinates to EEGLAB full coordinates.
Also see https://github.com/sccn/eeglab/blob/develop/functions/sigprocfunc/convertlocs.m
Parameters
----------
cart : ndarray, shape (n_points, 3)
Array containing points in Cartesian coordinates (x, y, z)
Returns
-------
sph_pts : ndarray, shape (n_points, 7)
Array containing points in spherical coordinates
(sph_theta, sph_phi, sph_radius, theta, radius,
sph_theta_besa, sph_phi_besa)
""" # noqa: E501

# based on transforms.py's _cart_to_sph()
assert cart.ndim == 2 and cart.shape[1] == 3
cart = np.atleast_2d(cart)
x, y, z = cart.T
return _cart_to_eeglab_full_coords_xyz(x, y, z)


def _get_eeglab_full_cords(inst):
"""Get full EEGLAB coords from MNE instance (Raw or Epochs)
Parameters
----------
inst: Epochs or Raw
Instance of epochs or raw to extract x,y,z coordinates from
Returns
-------
full_coords : ndarray, shape (n_channels, 10)
xyz + spherical and polar coords
see cart_to_eeglab_full_coords for more detail
"""
chs = inst.info["chs"]
cart_coords = np.array([d['loc'][:3] for d in chs])
# (-y x z) to (x y z)
cart_coords[:, 0] = -cart_coords[:, 0] # -y to y
cart_coords[:, [0, 1]] = cart_coords[:, [1, 0]] # swap x (1) and y (0)
other_coords = _cart_to_eeglab_full_coords(cart_coords)
full_coords = np.append(cart_coords, other_coords, 1) # hstack
return full_coords
24 changes: 24 additions & 0 deletions pyeeglab/utils_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import tempfile
from shutil import rmtree

class _TempDir(str):
"""Create and auto-destroy temp dir.
This is designed to be used with testing modules. Instances should be
defined inside test functions. Instances defined at module level can not
guarantee proper destruction of the temporary directory.
When used at module level, the current use of the __del__() method for
cleanup can fail because the rmtree function may be cleaned up before this
object (an alternative could be using the atexit module instead).
"""

def __new__(self): # noqa: D105
new = str.__new__(self, tempfile.mkdtemp(prefix='tmp_mne_tempdir_'))
return new

def __init__(self): # noqa: D102
self._path = self.__str__()

def __del__(self): # noqa: D105
rmtree(self._path, ignore_errors=True)
3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mne
numpy
scipy
Loading

0 comments on commit 7ee07da

Please sign in to comment.