Skip to content

Commit

Permalink
Calculation of the spike-triggered phase and amplitude (#121)
Browse files Browse the repository at this point in the history
* Added functions for calculating the spike-triggered phase
  • Loading branch information
mdenker authored and alperyeg committed Apr 4, 2018
1 parent b72b2e2 commit 1b45ecd
Show file tree
Hide file tree
Showing 2 changed files with 361 additions and 0 deletions.
171 changes: 171 additions & 0 deletions elephant/phase_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
"""
Methods for performing phase analysis.
:copyright: Copyright 2014-2018 by the Elephant team, see AUTHORS.txt.
:license: Modified BSD, see LICENSE.txt for details.
"""

import numpy as np
import quantities as pq


def spike_triggered_phase(hilbert_transform, spiketrains, interpolate):
"""
Calculate the set of spike-triggered phases of an AnalogSignal.
Parameters
----------
hilbert_transform : AnalogSignal or list of AnalogSignal
AnalogSignal of the complex analytic signal (e.g., returned by the
elephant.signal_processing.hilbert()). All spike trains are compared to
this signal, if only one signal is given. Otherwise, length of
hilbert_transform must match the length of spiketrains.
spiketrains : Spiketrain or list of Spiketrain
Spiketrains on which to trigger hilbert_transform extraction
interpolate : bool
If True, the phases and amplitudes of hilbert_transform for spikes
falling between two samples of signal is interpolated. Otherwise, the
closest sample of hilbert_transform is used.
Returns
-------
phases : list of arrays
Spike-triggered phases. Entries in the list correspond to the
SpikeTrains in spiketrains. Each entry contains an array with the
spike-triggered angles (in rad) of the signal.
amp : list of arrays
Corresponding spike-triggered amplitudes.
times : list of arrays
A list of times corresponding to the signal
Corresponding times (corresponds to the spike times).
Example
-------
Create a 20 Hz oscillatory signal sampled at 1 kHz and a random Poisson
spike train:
>>> f_osc = 20. * pq.Hz
>>> f_sampling = 1 * pq.ms
>>> tlen = 100 * pq.s
>>> time_axis = np.arange(
0, tlen.magnitude,
f_sampling.rescale(pq.s).magnitude) * pq.s
>>> analogsignal = AnalogSignal(
np.sin(2 * np.pi * (f_osc * time_axis).simplified.magnitude),
units=pq.mV, t_start=0 * pq.ms, sampling_period=f_sampling)
>>> spiketrain = elephant.spike_train_generation.
homogeneous_poisson_process(
50 * pq.Hz, t_start=0.0 * ms, t_stop=tlen.rescale(pq.ms))
Calculate spike-triggered phases and amplitudes of the oscillation:
>>> phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(analogsignal),
spiketrain,
interpolate=True)
"""

# Convert inputs to lists
if not isinstance(spiketrains, list):
spiketrains = [spiketrains]

if not isinstance(hilbert_transform, list):
hilbert_transform = [hilbert_transform]

# Number of signals
num_spiketrains = len(spiketrains)
num_phase = len(hilbert_transform)

if num_spiketrains != 1 and num_phase != 1 and \
num_spiketrains != num_phase:
raise ValueError(
"Number of spike trains and number of phase signals"
"must match, or either of the two must be a single signal.")

# For each trial, select the first input
start = [elem.t_start for elem in hilbert_transform]
stop = [elem.t_stop for elem in hilbert_transform]

result_phases = []
result_amps = []
result_times = []

# Step through each signal
for spiketrain_i, spiketrain in enumerate(spiketrains):
# Check which hilbert_transform AnalogSignal to look at - if there is
# only one then all spike trains relate to this one, otherwise the two
# lists of spike trains and phases are matched up
if num_phase > 1:
phase_i = spiketrain_i
else:
phase_i = 0

# Take only spikes which lie directly within the signal segment -
# ignore spikes sitting on the last sample
sttimeind = np.where(np.logical_and(
spiketrain >= start[phase_i], spiketrain < stop[phase_i]))[0]

# Find index into signal for each spike
ind_at_spike = np.round(
(spiketrain[sttimeind] - hilbert_transform[phase_i].t_start) /
hilbert_transform[phase_i].sampling_period).magnitude.astype(int)

# Extract times for speed reasons
times = hilbert_transform[phase_i].times

# Append new list to the results for this spiketrain
result_phases.append([])
result_amps.append([])
result_times.append([])

# Step through all spikes
for spike_i, ind_at_spike_j in enumerate(ind_at_spike):
# Difference vector between actual spike time and sample point,
# positive if spike time is later than sample point
dv = spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j]

# Make sure ind_at_spike is to the left of the spike time
if dv < 0 and ind_at_spike_j > 0:
ind_at_spike_j = ind_at_spike_j - 1

if interpolate:
# Get relative spike occurrence between the two closest signal
# sample points
# if z->0 spike is more to the left sample
# if z->1 more to the right sample
z = (spiketrain[sttimeind[spike_i]] - times[ind_at_spike_j]) /\
hilbert_transform[phase_i].sampling_period

# Save hilbert_transform (interpolate on circle)
p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j])
p2 = np.angle(hilbert_transform[phase_i][ind_at_spike_j + 1])
result_phases[spiketrain_i].append(
np.angle(
(1 - z) * np.exp(np.complex(0, p1)) +
z * np.exp(np.complex(0, p2))))

# Save amplitude
result_amps[spiketrain_i].append(
(1 - z) * np.abs(
hilbert_transform[phase_i][ind_at_spike_j]) +
z * np.abs(hilbert_transform[phase_i][ind_at_spike_j + 1]))
else:
p1 = np.angle(hilbert_transform[phase_i][ind_at_spike_j])
result_phases[spiketrain_i].append(p1)

# Save amplitude
result_amps[spiketrain_i].append(
np.abs(hilbert_transform[phase_i][ind_at_spike_j]))

# Save time
result_times[spiketrain_i].append(spiketrain[sttimeind[spike_i]])

# Convert outputs to arrays
for i, entry in enumerate(result_phases):
result_phases[i] = np.array(entry).flatten()
for i, entry in enumerate(result_amps):
result_amps[i] = pq.Quantity(entry, units=entry[0].units).flatten()
for i, entry in enumerate(result_times):
result_times[i] = pq.Quantity(entry, units=entry[0].units).flatten()

return result_phases, result_amps, result_times
190 changes: 190 additions & 0 deletions elephant/test/test_phase_analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# -*- coding: utf-8 -*-
"""
Unit tests for the phase analysis module.
:copyright: Copyright 2016 by the Elephant team, see AUTHORS.txt.
:license: Modified BSD, see LICENSE.txt for details.
"""
from __future__ import division, print_function

import unittest

from neo import SpikeTrain, AnalogSignal
import numpy as np
import quantities as pq

import elephant.phase_analysis

from numpy.ma.testutils import assert_allclose


class SpikeTriggeredPhaseTestCase(unittest.TestCase):

def setUp(self):
tlen0 = 100 * pq.s
f0 = 20. * pq.Hz
fs0 = 1 * pq.ms
t0 = np.arange(
0, tlen0.rescale(pq.s).magnitude,
fs0.rescale(pq.s).magnitude) * pq.s
self.anasig0 = AnalogSignal(
np.sin(2 * np.pi * (f0 * t0).simplified.magnitude),
units=pq.mV, t_start=0 * pq.ms, sampling_period=fs0)
self.st0 = SpikeTrain(
np.arange(50, tlen0.rescale(pq.ms).magnitude - 50, 50) * pq.ms,
t_start=0 * pq.ms, t_stop=tlen0)
self.st1 = SpikeTrain(
[100., 100.1, 100.2, 100.3, 100.9, 101.] * pq.ms,
t_start=0 * pq.ms, t_stop=tlen0)

def test_perfect_locking_one_spiketrain_one_signal(self):
phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st0,
interpolate=True)

assert_allclose(phases[0], - np.pi / 2.)
assert_allclose(amps[0], 1, atol=0.1)
assert_allclose(times[0].magnitude, self.st0.magnitude)
self.assertEqual(len(phases[0]), len(self.st0))
self.assertEqual(len(amps[0]), len(self.st0))
self.assertEqual(len(times[0]), len(self.st0))

def test_perfect_locking_many_spiketrains_many_signals(self):
phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
[
elephant.signal_processing.hilbert(self.anasig0),
elephant.signal_processing.hilbert(self.anasig0)],
[self.st0, self.st0],
interpolate=True)

assert_allclose(phases[0], -np.pi / 2.)
assert_allclose(amps[0], 1, atol=0.1)
assert_allclose(times[0].magnitude, self.st0.magnitude)
self.assertEqual(len(phases[0]), len(self.st0))
self.assertEqual(len(amps[0]), len(self.st0))
self.assertEqual(len(times[0]), len(self.st0))

def test_perfect_locking_one_spiketrains_many_signals(self):
phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
[
elephant.signal_processing.hilbert(self.anasig0),
elephant.signal_processing.hilbert(self.anasig0)],
[self.st0],
interpolate=True)

assert_allclose(phases[0], -np.pi / 2.)
assert_allclose(amps[0], 1, atol=0.1)
assert_allclose(times[0].magnitude, self.st0.magnitude)
self.assertEqual(len(phases[0]), len(self.st0))
self.assertEqual(len(amps[0]), len(self.st0))
self.assertEqual(len(times[0]), len(self.st0))

def test_perfect_locking_many_spiketrains_one_signal(self):
phases, amps, times = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
[self.st0, self.st0],
interpolate=True)

assert_allclose(phases[0], -np.pi / 2.)
assert_allclose(amps[0], 1, atol=0.1)
assert_allclose(times[0].magnitude, self.st0.magnitude)
self.assertEqual(len(phases[0]), len(self.st0))
self.assertEqual(len(amps[0]), len(self.st0))
self.assertEqual(len(times[0]), len(self.st0))

def test_interpolate(self):
phases_int, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st1,
interpolate=True)

self.assertLess(phases_int[0][0], phases_int[0][1])
self.assertLess(phases_int[0][1], phases_int[0][2])
self.assertLess(phases_int[0][2], phases_int[0][3])
self.assertLess(phases_int[0][3], phases_int[0][4])
self.assertLess(phases_int[0][4], phases_int[0][5])

phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
self.st1,
interpolate=False)

self.assertEqual(phases_noint[0][0], phases_noint[0][1])
self.assertEqual(phases_noint[0][1], phases_noint[0][2])
self.assertEqual(phases_noint[0][2], phases_noint[0][3])
self.assertEqual(phases_noint[0][3], phases_noint[0][4])
self.assertNotEqual(phases_noint[0][4], phases_noint[0][5])

# Verify that when using interpolation and the spike sits on the sample
# of the Hilbert transform, this is the same result as when not using
# interpolation with a spike slightly to the right
self.assertEqual(phases_noint[0][2], phases_int[0][0])
self.assertEqual(phases_noint[0][4], phases_int[0][0])

def test_inconsistent_numbers_spiketrains_hilbert(self):
self.assertRaises(
ValueError, elephant.phase_analysis.spike_triggered_phase,
[
elephant.signal_processing.hilbert(self.anasig0),
elephant.signal_processing.hilbert(self.anasig0)],
[self.st0, self.st0, self.st0], False)

self.assertRaises(
ValueError, elephant.phase_analysis.spike_triggered_phase,
[
elephant.signal_processing.hilbert(self.anasig0),
elephant.signal_processing.hilbert(self.anasig0)],
[self.st0, self.st0, self.st0], False)

def test_spike_earlier_than_hilbert(self):
# This is a spike clearly outside the bounds
st = SpikeTrain(
[-50, 50],
units='s', t_start=-100*pq.s, t_stop=100*pq.s)
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
st,
interpolate=False)
self.assertEqual(len(phases_noint[0]), 1)

# This is a spike right on the border (start of the signal is at 0s,
# spike sits at t=0s). By definition of intervals in
# Elephant (left borders inclusive, right borders exclusive), this
# spike is to be considered.
st = SpikeTrain(
[0, 50],
units='s', t_start=-100*pq.s, t_stop=100*pq.s)
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
st,
interpolate=False)
self.assertEqual(len(phases_noint[0]), 2)

def test_spike_later_than_hilbert(self):
# This is a spike clearly outside the bounds
st = SpikeTrain(
[1, 250],
units='s', t_start=-1*pq.s, t_stop=300*pq.s)
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
st,
interpolate=False)
self.assertEqual(len(phases_noint[0]), 1)

# This is a spike right on the border (length of the signal is 100s,
# spike sits at t=100s). However, by definition of intervals in
# Elephant (left borders inclusive, right borders exclusive), this
# spike is not to be considered.
st = SpikeTrain(
[1, 100],
units='s', t_start=-1*pq.s, t_stop=200*pq.s)
phases_noint, _, _ = elephant.phase_analysis.spike_triggered_phase(
elephant.signal_processing.hilbert(self.anasig0),
st,
interpolate=False)
self.assertEqual(len(phases_noint[0]), 1)


if __name__ == '__main__':
unittest.main()

0 comments on commit 1b45ecd

Please sign in to comment.