Skip to content

Commit

Permalink
rename utils function to check is_list_spiketrains, fix tests and rep…
Browse files Browse the repository at this point in the history
…lace raise error with return False
  • Loading branch information
Moritz-Alexander-Kern committed Jan 24, 2025
1 parent 6b47e0e commit 3a96ee3
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 23 deletions.
4 changes: 2 additions & 2 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
import elephant.trials
from elephant.conversion import BinnedSpikeTrain
from elephant.utils import deprecated_alias, check_neo_consistency, \
is_time_quantity, round_binning_errors, is_list_neo_spiketrains
is_time_quantity, round_binning_errors, is_list_spiketrains

# do not import unicode_literals
# (quantities rescale does not work with unicodes)
Expand Down Expand Up @@ -1031,7 +1031,7 @@ def optimal_kernel(st):
sigma=str(kernel.sigma),
invert=kernel.invert)

if is_list_neo_spiketrains(spiketrains) and (pool_spike_trains):
if is_list_spiketrains(spiketrains) and (pool_spike_trains):
rate = np.mean(rate, axis=1)

rate = neo.AnalogSignal(signal=rate,
Expand Down
18 changes: 7 additions & 11 deletions elephant/test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,33 +144,29 @@ def setUp(self):

def test_valid_list_input(self):
valid_list = [self.spiketrain1, self.spiketrain2]
self.assertTrue(utils.is_list_neo_spiketrains(valid_list))
self.assertTrue(utils.is_list_spiketrains(valid_list))

def test_valid_tuple_input(self):
valid_tuple = (self.spiketrain1, self.spiketrain2)
self.assertTrue(utils.is_list_neo_spiketrains(valid_tuple))
self.assertTrue(utils.is_list_spiketrains(valid_tuple))

def test_valid_spiketrainlist_input(self):
valid_spiketrainlist = neo.core.spiketrainlist.SpikeTrainList(items=(self.spiketrain1, self.spiketrain2))
self.assertTrue(utils.is_list_neo_spiketrains(valid_spiketrainlist))
self.assertTrue(utils.is_list_spiketrains(valid_spiketrainlist))

def test_non_iterable_input(self):
with self.assertRaises(TypeError):
utils.is_list_neo_spiketrains(42)
self.assertFalse(utils.is_list_spiketrains(42))

def test_non_spiketrain_objects(self):
invalid_list = [self.spiketrain1, "not a spiketrain"]
with self.assertRaises(TypeError):
utils.is_list_neo_spiketrains(invalid_list)
self.assertFalse(utils.is_list_spiketrains(invalid_list))

def test_mixed_types_input(self):
invalid_mixed = [self.spiketrain1, 42, self.spiketrain2]
with self.assertRaises(TypeError):
utils.is_list_neo_spiketrains(invalid_mixed)
self.assertFalse(utils.is_list_spiketrains(invalid_mixed))

def test_none_input(self):
with self.assertRaises(TypeError):
utils.is_list_neo_spiketrains(None)
self.assertFalse(utils.is_list_spiketrains(None))


if __name__ == '__main__':
Expand Down
18 changes: 8 additions & 10 deletions elephant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
check_neo_consistency
check_same_units
round_binning_errors
is_list_neo_spiketrains
is_list_spiketrains
"""

from __future__ import division, print_function, unicode_literals
Expand All @@ -33,7 +33,7 @@
"check_neo_consistency",
"check_same_units",
"round_binning_errors",
"is_list_neo_spiketrains",
"is_list_spiketrains",
]


Expand Down Expand Up @@ -451,30 +451,28 @@ def wrapper(*args, **kwargs):
return wrapper


def is_list_neo_spiketrains(obj: object) -> bool:
def is_list_spiketrains(obj: object) -> bool:
"""
Check if input is an iterable containing only neo.SpikeTrain objects.
Parameters
----------
obj : object
The object to check. Can be a neo.spiketrainlist, list, tuple or any other iterable.
The object to check.
Returns
-------
bool
True if obj is an iterable containing only neo.SpikeTrain objects.
Raises
------
TypeError
If obj is not an iterable, or if any element is not a neo.SpikeTrain.
"""

if not isinstance(obj, collections.abc.Iterable):
raise TypeError("Input must be an iterable (list, tuple, etc.)")
# Input must be an iterable (list, tuple, etc.)
return False

if not all(isinstance(st, neo.SpikeTrain) for st in obj):
raise TypeError("All elements must be neo.SpikeTrain objects")
# All elements must be neo.SpikeTrain objects
return False

return True

0 comments on commit 3a96ee3

Please sign in to comment.