-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Calculation of the spike-triggered phase and amplitude (#121)
* Added functions for calculating the spike-triggered phase
- Loading branch information
Showing
2 changed files
with
361 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,171 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Methods for performing phase analysis. | ||
:copyright: Copyright 2014-2018 by the Elephant team, see AUTHORS.txt. | ||
:license: Modified BSD, see LICENSE.txt for details. | ||
""" | ||
|
||
import numpy as np | ||
import quantities as pq | ||
|
||
|
||
def spike_triggered_phase(hilbert_transform, spiketrains, interpolate): | ||
""" | ||
Calculate the set of spike-triggered phases of an AnalogSignal. | ||
Parameters | ||
---------- | ||
hilbert_transform : AnalogSignal or list of AnalogSignal | ||
AnalogSignal of the complex analytic signal (e.g., returned by the | ||
elephant.signal_processing.hilbert()). All spike trains are compared to | ||
this signal, if only one signal is given. Otherwise, length of | ||
hilbert_transform must match the length of spiketrains. | ||
spiketrains : Spiketrain or list of Spiketrain | ||
Spiketrains on which to trigger hilbert_transform extraction | ||
interpolate : bool | ||
If True, the phases and amplitudes of hilbert_transform for spikes | ||
falling between two samples of signal is interpolated. Otherwise, the | ||
closest sample of hilbert_transform is used. | ||
Returns | ||
------- | ||
phases : list of arrays | ||
Spike-triggered phases. Entries in the list correspond to the | ||
SpikeTrains in spiketrains. Each entry contains an array with the | ||
spike-triggered angles (in rad) of the signal. | ||
amp : list of arrays | ||
Corresponding spike-triggered amplitudes. | ||
times : list of arrays | ||
A list of times corresponding to the signal | ||
Corresponding times (corresponds to the spike times). | ||
Example | ||
------- | ||
Create a 20 Hz oscillatory signal sampled at 1 kHz and a random Poisson | ||
spike train: | ||
>>> f_osc = 20. * pq.Hz | ||
>>> f_sampling = 1 * pq.ms | ||
>>> tlen = 100 * pq.s | ||
>>> time_axis = np.arange( | ||
0, tlen.magnitude, | ||
f_sampling.rescale(pq.s).magnitude) * pq.s | ||
>>> analogsignal = AnalogSignal( | ||
np.sin(2 * np.pi * (f_osc * time_axis).simplified.magnitude), | ||
units=pq.mV, t_start=0 * pq.ms, sampling_period=f_sampling) | ||
>>> spiketrain = elephant.spike_train_generation. | ||
homogeneous_poisson_process( | ||
50 * pq.Hz, t_start=0.0 * ms, t_stop=tlen.rescale(pq.ms)) | ||
Calculate spike-triggered phases and amplitudes of the oscillation: | ||
>>> phases, amps, times = elephant.phase_analysis.spike_triggered_phase( | ||
elephant.signal_processing.hilbert(analogsignal), | ||
spiketrain, | ||
interpolate=True) | ||
""" | ||
|
||
# Convert inputs to lists | ||
if not isinstance(spiketrains, list): | ||
spiketrains = [spiketrains] | ||
|
||
if not isinstance(hilbert_transform, list): | ||
hilbert_transform = [hilbert_transform] | ||
|
||
# Number of signals | ||
num_spiketrains = len(spiketrains) | ||
num_phase = len(hilbert_transform) | ||
|
||
if num_spiketrains != 1 and num_phase != 1 and \ | ||
num_spiketrains != num_phase: | ||
raise ValueError( | ||
"Number of spike trains and number of phase signals" | ||
"must match, or either of the two must be a single signal.") | ||
|
||
# For each trial, select the first input | ||
start = [elem.t_start for elem in hilbert_transform] | ||
stop = [elem.t_stop for elem in hilbert_transform] | ||
|
||
result_phases = [] | ||
result_amps = [] | ||
result_times = [] | ||
|
||
# Step through each signal | ||
for spiketrain_i, spiketrain in enumerate(spiketrains): | ||
# Check which hilbert_transform AnalogSignal to look at - if there is | ||
# only one then all spike trains relate to this one, otherwise the two | ||
# lists of spike trains and phases are matched up | ||
if num_phase > 1: | ||
phase_i = spiketrain_i | ||
else: | ||
phase_i = 0 | ||
|
||
# Take only spikes which lie directly within the signal segment - | ||
# ignore spikes sitting on the last sample | ||
sttimeind = np.where(np.logical_and( | ||
spiketrain >= start[phase_i], spiketrain < stop[phase_i]))[0] | ||
|
||
# Find index into signal for each spike | ||
ind_at_spike = np.round( | ||
(spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) / | ||
hilbert_transform[phase_i].sampling_period).magnitude.astype(int) | ||
|
||
# Extract times for speed reasons | ||
times = hilbert_transform[phase_i].times | ||
|
||
# Append new list to the results for this spiketrain | ||
result_phases.append([]) | ||
result_amps.append([]) | ||
result_times.append([]) | ||
|
||
# Step through all spikes | ||
for spike_i, ind_at_spike_j in enumerate(ind_at_spike): | ||
# Difference vector between actual spike time and sample point, | ||
# positive if spike time is later than sample point | ||
dv = spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j] | ||
|
||
# Make sure ind_at_spike is to the left of the spike time | ||
if dv < 0 and ind_at_spike_j > 0: | ||
ind_at_spike_j = ind_at_spike_j - 1 | ||
|
||
if interpolate: | ||
# Get relative spike occurrence between the two closest signal | ||
# sample points | ||
# if z->0 spike is more to the left sample | ||
# if z->1 more to the right sample | ||
z = (spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j]) /\ | ||
hilbert_transform[phase_i].sampling_period | ||
|
||
# Save hilbert_transform (interpolate on circle) | ||
p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j]) | ||
p2 = np.angle(hilbert_transform[phase_i][ind_at_spike_j + 1]) | ||
result_phases[spiketrain_i].append( | ||
np.angle( | ||
(1 - z) * np.exp(np.complex(0, p1)) + | ||
z * np.exp(np.complex(0, p2)))) | ||
|
||
# Save amplitude | ||
result_amps[spiketrain_i].append( | ||
(1 - z) * np.abs( | ||
hilbert_transform[phase_i][ind_at_spike_j]) + | ||
z * np.abs(hilbert_transform[phase_i][ind_at_spike_j + 1])) | ||
else: | ||
p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j]) | ||
result_phases[spiketrain_i].append(p1) | ||
|
||
# Save amplitude | ||
result_amps[spiketrain_i].append( | ||
np.abs(hilbert_transform[phase_i][ind_at_spike_j])) | ||
|
||
# Save time | ||
result_times[spiketrain_i].append(spiketrain[sttimeind[spike_i]]) | ||
|
||
# Convert outputs to arrays | ||
for i, entry in enumerate(result_phases): | ||
result_phases[i] = np.array(entry).flatten() | ||
for i, entry in enumerate(result_amps): | ||
result_amps[i] = pq.Quantity(entry, units=entry[0].units).flatten() | ||
for i, entry in enumerate(result_times): | ||
result_times[i] = pq.Quantity(entry, units=entry[0].units).flatten() | ||
|
||
return result_phases, result_amps, result_times |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
# -*- coding: utf-8 -*- | ||
""" | ||
Unit tests for the phase analysis module. | ||
:copyright: Copyright 2016 by the Elephant team, see AUTHORS.txt. | ||
:license: Modified BSD, see LICENSE.txt for details. | ||
""" | ||
from __future__ import division, print_function | ||
|
||
import unittest | ||
|
||
from neo import SpikeTrain, AnalogSignal | ||
import numpy as np | ||
import quantities as pq | ||
|
||
import elephant.phase_analysis | ||
|
||
from numpy.ma.testutils import assert_allclose | ||
|
||
|
||
class SpikeTriggeredPhaseTestCase(unittest.TestCase): | ||
|
||
def setUp(self): | ||
tlen0 = 100 * pq.s | ||
f0 = 20. * pq.Hz | ||
fs0 = 1 * pq.ms | ||
t0 = np.arange( | ||
0, tlen0.rescale(pq.s).magnitude, | ||
fs0.rescale(pq.s).magnitude) * pq.s | ||
self.anasig0 = AnalogSignal( | ||
np.sin(2 * np.pi * (f0 * t0).simplified.magnitude), | ||
units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0) | ||
self.st0 = SpikeTrain( | ||
np.arange(50, tlen0.rescale(pq.ms).magnitude - 50, 50) * pq.ms, | ||
t_start=0 * pq.ms, t_stop=tlen0) | ||
self.st1 = SpikeTrain( | ||
[100., 100.1, 100.2, 100.3, 100.9, 101.] * pq.ms, | ||
t_start=0 * pq.ms, t_stop=tlen0) | ||
|
||
def test_perfect_locking_one_spiketrain_one_signal(self): | ||
phases, amps, times = elephant.phase_analysis.spike_triggered_phase( | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
self.st0, | ||
interpolate=True) | ||
|
||
assert_allclose(phases[0], - np.pi / 2.) | ||
assert_allclose(amps[0], 1, atol=0.1) | ||
assert_allclose(times[0].magnitude, self.st0.magnitude) | ||
self.assertEqual(len(phases[0]), len(self.st0)) | ||
self.assertEqual(len(amps[0]), len(self.st0)) | ||
self.assertEqual(len(times[0]), len(self.st0)) | ||
|
||
def test_perfect_locking_many_spiketrains_many_signals(self): | ||
phases, amps, times = elephant.phase_analysis.spike_triggered_phase( | ||
[ | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
elephant.signal_processing.hilbert(self.anasig0)], | ||
[self.st0, self.st0], | ||
interpolate=True) | ||
|
||
assert_allclose(phases[0], -np.pi / 2.) | ||
assert_allclose(amps[0], 1, atol=0.1) | ||
assert_allclose(times[0].magnitude, self.st0.magnitude) | ||
self.assertEqual(len(phases[0]), len(self.st0)) | ||
self.assertEqual(len(amps[0]), len(self.st0)) | ||
self.assertEqual(len(times[0]), len(self.st0)) | ||
|
||
def test_perfect_locking_one_spiketrains_many_signals(self): | ||
phases, amps, times = elephant.phase_analysis.spike_triggered_phase( | ||
[ | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
elephant.signal_processing.hilbert(self.anasig0)], | ||
[self.st0], | ||
interpolate=True) | ||
|
||
assert_allclose(phases[0], -np.pi / 2.) | ||
assert_allclose(amps[0], 1, atol=0.1) | ||
assert_allclose(times[0].magnitude, self.st0.magnitude) | ||
self.assertEqual(len(phases[0]), len(self.st0)) | ||
self.assertEqual(len(amps[0]), len(self.st0)) | ||
self.assertEqual(len(times[0]), len(self.st0)) | ||
|
||
def test_perfect_locking_many_spiketrains_one_signal(self): | ||
phases, amps, times = elephant.phase_analysis.spike_triggered_phase( | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
[self.st0, self.st0], | ||
interpolate=True) | ||
|
||
assert_allclose(phases[0], -np.pi / 2.) | ||
assert_allclose(amps[0], 1, atol=0.1) | ||
assert_allclose(times[0].magnitude, self.st0.magnitude) | ||
self.assertEqual(len(phases[0]), len(self.st0)) | ||
self.assertEqual(len(amps[0]), len(self.st0)) | ||
self.assertEqual(len(times[0]), len(self.st0)) | ||
|
||
def test_interpolate(self): | ||
phases_int, _, _ = elephant.phase_analysis.spike_triggered_phase( | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
self.st1, | ||
interpolate=True) | ||
|
||
self.assertLess(phases_int[0][0], phases_int[0][1]) | ||
self.assertLess(phases_int[0][1], phases_int[0][2]) | ||
self.assertLess(phases_int[0][2], phases_int[0][3]) | ||
self.assertLess(phases_int[0][3], phases_int[0][4]) | ||
self.assertLess(phases_int[0][4], phases_int[0][5]) | ||
|
||
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
self.st1, | ||
interpolate=False) | ||
|
||
self.assertEqual(phases_noint[0][0], phases_noint[0][1]) | ||
self.assertEqual(phases_noint[0][1], phases_noint[0][2]) | ||
self.assertEqual(phases_noint[0][2], phases_noint[0][3]) | ||
self.assertEqual(phases_noint[0][3], phases_noint[0][4]) | ||
self.assertNotEqual(phases_noint[0][4], phases_noint[0][5]) | ||
|
||
# Verify that when using interpolation and the spike sits on the sample | ||
# of the Hilbert transform, this is the same result as when not using | ||
# interpolation with a spike slightly to the right | ||
self.assertEqual(phases_noint[0][2], phases_int[0][0]) | ||
self.assertEqual(phases_noint[0][4], phases_int[0][0]) | ||
|
||
def test_inconsistent_numbers_spiketrains_hilbert(self): | ||
self.assertRaises( | ||
ValueError, elephant.phase_analysis.spike_triggered_phase, | ||
[ | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
elephant.signal_processing.hilbert(self.anasig0)], | ||
[self.st0, self.st0, self.st0], False) | ||
|
||
self.assertRaises( | ||
ValueError, elephant.phase_analysis.spike_triggered_phase, | ||
[ | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
elephant.signal_processing.hilbert(self.anasig0)], | ||
[self.st0, self.st0, self.st0], False) | ||
|
||
def test_spike_earlier_than_hilbert(self): | ||
# This is a spike clearly outside the bounds | ||
st = SpikeTrain( | ||
[-50, 50], | ||
units='s', t_start=-100*pq.s, t_stop=100*pq.s) | ||
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
st, | ||
interpolate=False) | ||
self.assertEqual(len(phases_noint[0]), 1) | ||
|
||
# This is a spike right on the border (start of the signal is at 0s, | ||
# spike sits at t=0s). By definition of intervals in | ||
# Elephant (left borders inclusive, right borders exclusive), this | ||
# spike is to be considered. | ||
st = SpikeTrain( | ||
[0, 50], | ||
units='s', t_start=-100*pq.s, t_stop=100*pq.s) | ||
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
st, | ||
interpolate=False) | ||
self.assertEqual(len(phases_noint[0]), 2) | ||
|
||
def test_spike_later_than_hilbert(self): | ||
# This is a spike clearly outside the bounds | ||
st = SpikeTrain( | ||
[1, 250], | ||
units='s', t_start=-1*pq.s, t_stop=300*pq.s) | ||
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
st, | ||
interpolate=False) | ||
self.assertEqual(len(phases_noint[0]), 1) | ||
|
||
# This is a spike right on the border (length of the signal is 100s, | ||
# spike sits at t=100s). However, by definition of intervals in | ||
# Elephant (left borders inclusive, right borders exclusive), this | ||
# spike is not to be considered. | ||
st = SpikeTrain( | ||
[1, 100], | ||
units='s', t_start=-1*pq.s, t_stop=200*pq.s) | ||
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase( | ||
elephant.signal_processing.hilbert(self.anasig0), | ||
st, | ||
interpolate=False) | ||
self.assertEqual(len(phases_noint[0]), 1) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |