Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] #613 spike_train_generation module to handle multichannel AnalogSignal inputs #614

Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add unittests and add handling for analogsignals with multiple channe…
…ls to peak_extraction, spike_extraction
  • Loading branch information
Moritz-Alexander-Kern committed Jan 19, 2024
commit cff235301c34d4acb35e2b44446a84874c3fe55c
206 changes: 129 additions & 77 deletions elephant/spike_train_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from __future__ import division, print_function, unicode_literals

import warnings
from typing import List, Union, Optional
from typing import List, Literal, Union, Optional

import neo
import numpy as np
Expand Down Expand Up @@ -83,61 +83,12 @@
]


def spike_extraction(signal: Union[neo.core.AnalogSignal,
List[neo.core.AnalogSignal]],
threshold: pq.Quantity = 0.0 * pq.mV,
sign: str = 'above',
time_stamps: neo.core.SpikeTrain = None,
interval: tuple = (-2 * pq.ms, 4 * pq.ms),
always_as_list: bool = False
) -> Union[neo.core.SpikeTrain,
neo.core.spiketrainlist.SpikeTrainList]:
"""
Return the peak times for all events that cross threshold and the
waveforms. Usually used for extracting spikes from a membrane
potential to calculate waveform properties.

Parameters
---------- # noqa
signal : :class:`neo.core.AnalogSignal` or List[:class:`neo.core.AnalogSignal`]
An analog input signal or a list of analog input signals.
threshold : pq.Quantity, optional
Contains a value that must be reached for an event to be detected.
Default: 0.0 * pq.mV
sign : {'above', 'below'}, optional
Determines whether to count threshold crossings that cross above or
below the threshold.
Default: 'above'
time_stamps : :class:`neo.core.SpikeTrain` , optional
Provides the time stamps around which the waveform is extracted. If it
is None, the function `peak_detection` is used to calculate the
`time_stamps` from signal.
Default: None
interval : tuple of :class:`pq.Quantity`
Specifies the time interval around the `time_stamps` where the waveform
is extracted.
Default: (-2 * pq.ms, 4 * pq.ms)

Returns
-------
result_st : :class:`neo.core.SpikeTrain` or :class:`neo.core.SpikeTrainList`

Contains the time_stamps of each of the spikes and the waveforms in
`result_st.waveforms`.

See Also
--------
:func:`elephant.spike_train_generation.peak_detection`
"""
# if isinstance(signal, neo.core.AnalogSignal):
# signals = [signal]
# elif isinstance(signal, list) and all(isinstance(s, neo.core.AnalogSignal)
# for s in signal):
# signals = signal
# else:
# raise TypeError("signal must be a neo.core.AnalogSignal or"
# f" a list of neo.core.AnalogSignal, not {type(signal)}")

def _spike_extraction(signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: Literal['above', 'below'] = 'above',
time_stamps: neo.core.SpikeTrain = None,
interval: tuple = (-2 * pq.ms, 4 * pq.ms)
) -> neo.core.SpikeTrain:
# Get spike time_stamps
if time_stamps is None:
time_stamps = peak_detection(signal, threshold, sign=sign)
Expand Down Expand Up @@ -202,35 +153,84 @@ def spike_extraction(signal: Union[neo.core.AnalogSignal,
left_sweep=extr_left)


def threshold_detection(signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: str = 'above'
) -> neo.core.AnalogSignal:
def spike_extraction(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: Literal['above', 'below'] = 'above',
time_stamps: neo.core.SpikeTrain = None,
interval: tuple = (-2 * pq.ms, 4 * pq.ms),
always_as_list: bool = False
) -> Union[neo.core.SpikeTrain,
neo.core.spiketrainlist.SpikeTrainList]:
"""
Returns the times when the analog signal crosses a threshold.
Usually used for extracting spike times from a membrane potential.
Return the peak times for all events that cross threshold and the
waveforms. Usually used for extracting spikes from a membrane
potential to calculate waveform properties.

Parameters
----------
signal : :class:`neo.core.AnalogSignal`
An analog input signal.
threshold : :class:`pq.Quantity`, optional
An analog input signal one or more channels.
threshold : pq.Quantity, optional
Contains a value that must be reached for an event to be detected.
Default: 0.0 * pq.mV
sign : {'above', 'below'}, optional
Determines whether to count threshold crossings that cross above or
below the threshold.
Default: 'above'
time_stamps : :class:`neo.core.SpikeTrain` , optional
Provides the time stamps around which the waveform is extracted. If it
is None, the function `peak_detection` is used to calculate the
`time_stamps` from signal.
Default: None
interval : tuple of :class:`pq.Quantity`
Specifies the time interval around the `time_stamps` where the waveform
is extracted.
Default: (-2 * pq.ms, 4 * pq.ms)
always_as_list: bool, optional
If True, a list of neo.SpikeTrain is returned.
Default: False

Returns
-------
result_st :class:`neo.core.SpikeTrain`
Contains the spike times of each of the events (spikes) extracted from
the signal.
------- # noqa
result_st : :class:`neo.core.SpikeTrain`, list of :class:`neo.core.SpikeTrain`
Contains the time_stamps of each of the spikes and the waveforms in
`result_st.waveforms`.

See Also
--------
:func:`elephant.spike_train_generation.peak_detection`
"""
if isinstance(signal, neo.core.AnalogSignal):
if signal.shape[1] == 1:
if always_as_list:
return [_spike_extraction(signal, threshold=threshold,
time_stamps=time_stamps,
interval=interval,
sign=sign)]
else:
return _spike_extraction(signal, threshold=threshold,
time_stamps=time_stamps,
interval=interval,
sign=sign)
elif signal.shape[1] > 1:
return [_spike_extraction(neo.core.AnalogSignal(
signal[:, channel], sampling_rate=signal.sampling_rate),
threshold=threshold, sign=sign,
time_stamps=time_stamps,
interval=interval,
) for channel in range(signal.shape[1])]
else:
raise TypeError(
f"Signal must be AnalogSignal, provided: {type(signal)}")


def _threshold_detection(signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: str = 'above'
) -> neo.core.SpikeTrain:
if not isinstance(threshold, pq.Quantity):
raise ValueError('threshold must be a pq.Quantity')
raise TypeError('threshold must be a pq.Quantity')

if sign not in ('above', 'below'):
raise ValueError("sign should be 'above' or 'below'")
Expand Down Expand Up @@ -262,6 +262,57 @@ def threshold_detection(signal: neo.core.AnalogSignal,
return result_st


def threshold_detection(
signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: Literal['above', 'below'] = 'above',
always_as_list: bool = False,
) -> Union[neo.core.SpikeTrain, List[neo.core.SpikeTrain]]:
"""
Returns the times when the analog signal crosses a threshold.
Usually used for extracting spike times from a membrane potential.

Parameters
----------
signal : :class:`neo.core.AnalogSignal`
An analog input signal with one or multiple channels.
threshold : :class:`pq.Quantity`, optional
Contains a value that must be reached for an event to be detected.
Default: 0.0 * pq.mV
sign : {'above', 'below'}, optional
Determines whether to count threshold crossings that cross above or
below the threshold.
Default: 'above'
always_as_list: bool, optional
If True, a list of neo.SpikeTrain is returned.
Default: False

Returns
------- # noqa
result_st : :class:`neo.core.SpikeTrain`, List[:class:`neo.core.SpikeTrain`]
Contains the spike times of each of the events (spikes) extracted from
the signal. If `signal` is an AnalogSignal with multiple channels,
a list of AnalogSignals or `always_return_list=True` , a
list of :class:`neo.core.SpikeTrain` is returned.
"""
if isinstance(signal, neo.core.AnalogSignal):
if signal.shape[1] == 1:
if always_as_list:
return [_threshold_detection(signal, threshold=threshold,
sign=sign)]
else:
return _threshold_detection(signal, threshold=threshold,
sign=sign)
elif signal.shape[1] > 1:
return [_threshold_detection(neo.core.AnalogSignal(
signal[:, channel], sampling_rate=signal.sampling_rate),
threshold=threshold, sign=sign
) for channel in range(signal.shape[1])]
else:
raise TypeError(
f"Signal must be AnalogSignal, provided: {type(signal)}")


# legacy implementation of peak_detection
def _peak_detection(signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
Expand Down Expand Up @@ -324,7 +375,7 @@ def _peak_detection(signal: neo.core.AnalogSignal,

def peak_detection(signal: neo.core.AnalogSignal,
threshold: pq.Quantity = 0.0 * pq.mV,
sign: str = 'above',
sign: Literal['above', 'below'] = 'above',
as_array: bool = False,
always_as_list: bool = False
) -> Union[neo.core.SpikeTrain,
Expand All @@ -335,8 +386,8 @@ def peak_detection(signal: neo.core.AnalogSignal,
Similar to spike_train_generation.threshold_detection.

Parameters
---------- # noqa
signal : :class:`neo.core.AnalogSignal` or List[:class:`neo.core.AnalogSignal`]
----------
signal : :class:`neo.core.AnalogSignal`
An analog input signal or a list of analog input signals.
threshold : :class:`pq.Quantity`, optional
Contains a value that must be reached for an event to be detected.
Expand All @@ -354,12 +405,13 @@ def peak_detection(signal: neo.core.AnalogSignal,
Default: False

Returns
-------
result_st : :class:`neo.core.SpikeTrain` or List[:class:`neo.core.SpikeTrain`]
------- # noqa
result_st : :class:`neo.core.SpikeTrain`, List[:class:`neo.core.SpikeTrain`]
:class:`np.ndarrav`, List[:class:`np.ndarrav`]
Contains the spike times of each of the events (spikes) extracted from
the signal. If `signal` is an AnalogSignal with multiple channels,
a list of AnalogSignals or `always_return_list=True` , a
list of :class:`neo.core.SpikeTrain` is returned.
the signal.
If `signal` is an AnalogSignal with multiple channels or
`always_return_list=True` a list is returned.
"""
if isinstance(signal, neo.core.AnalogSignal):
if signal.shape[1] == 1:
Expand Down
Loading
Loading