Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Integrate trials object with Fano factor #645

Open
wants to merge 19 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
4011266
add handling for trial object
Moritz-Alexander-Kern Oct 22, 2024
218a653
add tests
Moritz-Alexander-Kern Oct 22, 2024
ed5ca52
Merge branch 'master' into enh/trials_fano_factor
Moritz-Alexander-Kern Nov 14, 2024
29694fa
add tests for trial object pooling trials or spiketrains
Moritz-Alexander-Kern Nov 14, 2024
1711f43
add parameters pool_trials, pool_spiketrains
Moritz-Alexander-Kern Nov 14, 2024
7b0ccf9
add type check for pool parameters
Moritz-Alexander-Kern Nov 14, 2024
1925bf3
refactor type annotations
Moritz-Alexander-Kern Nov 14, 2024
3c1eb5d
add to docstring
Moritz-Alexander-Kern Nov 14, 2024
6bb2fc3
add user warning and did refactoring of function
Moritz-Alexander-Kern Nov 14, 2024
69a4d24
remove pool trials parameter
Moritz-Alexander-Kern Dec 10, 2024
5b1cf75
add paramter ignored for spiketrainslist to docstring
Moritz-Alexander-Kern Dec 10, 2024
303f363
remove user warning to manually check duration for numpy arrays
Moritz-Alexander-Kern Dec 10, 2024
e9e6778
remove pool_trials arg
Moritz-Alexander-Kern Dec 10, 2024
663b1c8
remove pool_trials from fano-factor
Moritz-Alexander-Kern Jan 13, 2025
381ae35
remove test for type error
Moritz-Alexander-Kern Jan 13, 2025
58e202b
Merge branch 'master' into enh/trials_fano_factor
Moritz-Alexander-Kern Jan 13, 2025
b75355a
fix docstring for return value
Moritz-Alexander-Kern Jan 13, 2025
fce4427
add test for results to fano factor
Moritz-Alexander-Kern Jan 14, 2025
f776d04
Merge branch 'master' into enh/trials_fano_factor
Moritz-Alexander-Kern Jan 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 45 additions & 31 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
import scipy.signal
from numpy import ndarray
from scipy.special import erf
from typing import Union
from typing import List, Union

import elephant.conversion as conv
import elephant.kernels as kernels
Expand Down Expand Up @@ -270,10 +270,11 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None):
return rates


def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms):
def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray], elephant.trials.Trials],
warn_tolerance: pq.Quantity = 0.1 * pq.ms) -> Union[float, List[float], List[List[float]]]:
r"""
Evaluates the empirical Fano factor F of the spike counts of
a list of `neo.SpikeTrain` objects.
a list of `neo.SpikeTrain` objects or `elephant.trials.Trial` object.

Given the vector v containing the observed spike counts (one per
spike train) in the time window [t0, t1], F is defined as:
Expand All @@ -288,18 +289,20 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms):

Parameters
----------
spiketrains : list
spiketrains : list or elephant.trials.Trial
List of `neo.SpikeTrain` or `pq.Quantity` or `np.ndarray` or list of
spike times for which to compute the Fano factor of spike counts.
spike times for which to compute the Fano factor of spike counts, or
an `elephant.trials.Trial` object. If a Trial object is used, spike trains are
pooled across trials before computing the Fano factor.
warn_tolerance : pq.Quantity
In case of a list of input neo.SpikeTrains, if their durations vary by
more than `warn_tolerence` in their absolute values, throw a warning
more than `warn_tolerance` in their absolute values, throw a warning
(see Notes).
Default: 0.1 ms

Returns
-------
fano : float
fano : float, list of floats or list of list of floats
Moritz-Alexander-Kern marked this conversation as resolved.
Show resolved Hide resolved
The Fano factor of the spike counts of the input spike trains.
Returns np.NaN if an empty list is specified, or if all spike trains
are empty.
Expand All @@ -313,7 +316,7 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms):
Notes
-----
The check for the equal duration of the input spike trains is performed
only if the input is of type`neo.SpikeTrain`: if you pass a numpy array,
only if the input is of type`neo.SpikeTrain`: if you pass e.g. a numpy array,
please make sure that they all have the same duration manually.
Moritz-Alexander-Kern marked this conversation as resolved.
Show resolved Hide resolved

Examples
Expand All @@ -328,29 +331,40 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms):
0.07142857142857142

"""
# Build array of spike counts (one per spike train)
spike_counts = np.array([len(st) for st in spiketrains])

# Compute FF
if all(count == 0 for count in spike_counts):
# empty list of spiketrains reaches this branch, and NaN is returned
return np.nan

if all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
if not is_time_quantity(warn_tolerance):
raise TypeError("'warn_tolerance' must be a time quantity.")
durations = [(st.t_stop - st.t_start).simplified.item()
for st in spiketrains]
durations_min = min(durations)
durations_max = max(durations)
if durations_max - durations_min > warn_tolerance.simplified.item():
warnings.warn("Fano factor calculated for spike trains of "
"different duration (minimum: {_min}s, maximum "
"{_max}s).".format(_min=durations_min,
_max=durations_max))

fano = spike_counts.var() / spike_counts.mean()
return fano
# Check if parameters are of the correct type
if not is_time_quantity(warn_tolerance):
raise TypeError(f"'warn_tolerance' must be a time quantity, but got {type(warn_tolerance)}")

def _check_input_spiketrains_durations(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity],
List[np.ndarray]]) -> None:
if spiketrains and all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
durations = np.array(tuple(st.duration for st in spiketrains))
if np.max(durations) - np.min(durations) > warn_tolerance:
warnings.warn(f"Fano factor calculated for spike trains of "
f"different duration (minimum: {np.min(durations)}s, maximum "
f"{np.max(durations)}s).")

def _compute_fano(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray]]) -> float:
# Check spike train durations
_check_input_spiketrains_durations(spiketrains)
# Build array of spike counts (one per spike train)
spike_counts = np.array(tuple(len(st) for st in spiketrains))
# Compute FF
if np.all(np.array(spike_counts) == 0):
# empty list of spiketrains reaches this branch, and NaN is returned
return np.nan
else:
return spike_counts.var()/spike_counts.mean()

if isinstance(spiketrains, elephant.trials.Trials):
list_of_lists_of_spiketrains = [
spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)
for trial_no in range(spiketrains.n_trials)]
return [_compute_fano([list_of_lists_of_spiketrains[trial_no][st_no]
for trial_no in range(len(list_of_lists_of_spiketrains))])
for st_no in range(len(list_of_lists_of_spiketrains[0]))]
else: # Legacy behavior
return _compute_fano(spiketrains)


def __variation_check(v, with_nan):
Expand Down
32 changes: 19 additions & 13 deletions elephant/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from elephant import statistics
from elephant.spike_train_generation import StationaryPoissonProcess
from elephant.test.test_trials import _create_trials_block
from elephant.trials import TrialsFromBlock
from elephant.trials import TrialsFromBlock, TrialsFromLists


class IsiTestCase(unittest.TestCase):
Expand Down Expand Up @@ -269,32 +269,34 @@ def test_mean_firing_rate_with_plain_array_and_units_start_stop_typeerror(


class FanoFactorTestCase(unittest.TestCase):
def setUp(self):
@classmethod
def setUpClass(cls):
np.random.seed(100)
num_st = 300
self.test_spiketrains = []
self.test_array = []
self.test_quantity = []
self.test_list = []
self.sp_counts = np.zeros(num_st)
cls.test_spiketrains = []
cls.test_array = []
cls.test_quantity = []
cls.test_list = []
cls.sp_counts = np.zeros(num_st)
for i in range(num_st):
r = np.random.rand(np.random.randint(20) + 1)
st = neo.core.SpikeTrain(r * pq.ms,
t_start=0.0 * pq.ms,
t_stop=20.0 * pq.ms)
self.test_spiketrains.append(st)
self.test_array.append(r)
self.test_quantity.append(r * pq.ms)
self.test_list.append(list(r))
cls.test_spiketrains.append(st)
cls.test_array.append(r)
cls.test_quantity.append(r * pq.ms)
cls.test_list.append(list(r))
# for cross-validation
self.sp_counts[i] = len(st)
cls.sp_counts[i] = len(st)

cls.test_trials = TrialsFromLists([cls.test_spiketrains, cls.test_spiketrains])

def test_fanofactor_spiketrains(self):
# Test with list of spiketrains
self.assertEqual(
np.var(self.sp_counts) / np.mean(self.sp_counts),
statistics.fanofactor(self.test_spiketrains))

# One spiketrain in list
st = self.test_spiketrains[0]
self.assertEqual(statistics.fanofactor([st]), 0.0)
Expand Down Expand Up @@ -352,6 +354,10 @@ def test_fanofactor_wrong_type(self):
self.assertRaises(TypeError, statistics.fanofactor, [st1],
warn_tolerance=1e-4)

def test_fanofactor_trials_pool_trials(self):
results = statistics.fanofactor(self.test_trials)
self.assertEqual(len(results), self.test_trials.n_spiketrains_trial_by_trial[0])


class LVTestCase(unittest.TestCase):
def setUp(self):
Expand Down
Loading