Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] #613 spike_train_generation module to handle multichannel AnalogSignal inputs #614

Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
move input checks from _spike_extraction to spike_extraction
  • Loading branch information
Moritz-Alexander-Kern committed Jan 31, 2024
commit ce383bad0b8659b9d506e0a896aa5782904ed39b
21 changes: 10 additions & 11 deletions elephant/spike_train_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,6 @@ def _spike_extraction(signal: neo.core.AnalogSignal,
time_stamps: neo.core.SpikeTrain = None,
interval: tuple = (-2 * pq.ms, 4 * pq.ms)
) -> neo.core.SpikeTrain:
# Get spike time_stamps
if time_stamps is None:
time_stamps = peak_detection(signal, threshold, sign=sign)
elif hasattr(time_stamps, 'times'):
time_stamps = time_stamps.times
else:
raise TypeError("time_stamps must be None, a `neo.core.SpikeTrain`"
" or expose the.times interface")

if len(time_stamps) == 0:
return neo.SpikeTrain(time_stamps, units=signal.times.units,
t_start=signal.t_start, t_stop=signal.t_stop,
Expand Down Expand Up @@ -201,6 +192,15 @@ def spike_extraction(
--------
:func:`elephant.spike_train_generation.peak_detection`
"""
# Get spike time_stamps
if time_stamps is None:
time_stamps = peak_detection(signal, threshold, sign=sign)
elif hasattr(time_stamps, 'times'):
time_stamps = time_stamps.times
else:
raise TypeError("time_stamps must be None, a `neo.core.SpikeTrain`"
" or expose the.times interface")

if isinstance(signal, neo.core.AnalogSignal):
if signal.shape[1] == 1:
if always_as_list:
Expand Down Expand Up @@ -392,8 +392,7 @@ def peak_detection(signal: neo.core.AnalogSignal,
sign: Literal['above', 'below'] = 'above',
as_array: bool = False,
always_as_list: bool = False
) -> Union[neo.core.SpikeTrain,
List[neo.core.SpikeTrain]]:
) -> Union[neo.core.SpikeTrain, SpikeTrainList]:
"""
Return the peak times for all events that cross threshold.
Usually used for extracting spike times from a membrane potential.
Expand Down