Skip to content

Commit

Permalink
Properly generate epoch structure.
Browse files Browse the repository at this point in the history
  • Loading branch information
jackz314 committed Aug 4, 2022
1 parent 1cbf426 commit f9e6e79
Showing 1 changed file with 25 additions and 7 deletions.
32 changes: 25 additions & 7 deletions eeglabio/epochs.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def export_set(fname, data, sfreq, events, tmin, tmax, ch_names, event_id=None,
ev_dur = np.zeros((trials,), dtype=np.int64)

# indices of epochs each event belongs to
ev_epoch = np.arange(1, trials + 1)
ev_epoch = ev_lat // data.shape[1] + 1

# merge annotations into events array
if annotations is not None:
Expand Down Expand Up @@ -123,13 +123,31 @@ def export_set(fname, data, sfreq, events, tmin, tmax, ch_names, event_id=None,
names=["type", "latency", "duration", "epoch"])

# construct epochs array
# true epochs array, one subarray per events in epoch
# make sure epoch count is increasing (it should be)
# splitting code from https://stackoverflow.com/a/43094244/8170714
epoch_start_idx = np.unique(all_epoch, return_index=True)[1][1:] # skip 0
ep_event = np.split(np.arange(1, len(all_epoch)+1, dtype=np.double),
epoch_start_idx)
# starting latency for each epoch in seconds
ep_lat_offset = (all_epoch - 1) * data.shape[1] / sfreq
all_lat_shifted = all_lat / sfreq - ep_lat_offset # shifted rel to ep onset
# convert lat, pos, type to cell arrays by converting to object arrays
ep_lat = np.split(all_lat_shifted.astype(dtype=object) * 1000,
epoch_start_idx)
ep_pos = np.split(all_epoch.astype(dtype=object), epoch_start_idx)
ep_types = np.split(all_types.astype(dtype=object), epoch_start_idx)

# regular one event per epoch
# same as the indices for event epoch, except use array
ep_event = [np.array(n) for n in ev_epoch]
ep_lat = [np.array(n) for n in ev_lat]
ep_types = [np.array(n) for n in ev_types]

epochs = fromarrays([ep_event, ep_lat, ep_types],
names=["event", "eventlatency", "eventtype"])
# ep_event = [np.array(n) for n in ev_epoch]
# ep_lat = [np.array(n) for n in ev_lat]
# ep_types = [np.array(n) for n in ev_types]

field_names = ["event", "eventlatency", "eventposition", "eventtype"]
epochs = fromarrays([np.array(arr, dtype=object) for arr in
[ep_event, ep_lat, ep_pos, ep_types]],
names=field_names)

if isinstance(ref_channels, list):
ref_channels = " ".join(ref_channels)
Expand Down

0 comments on commit f9e6e79

Please sign in to comment.