From 3a96ee360af2d1a7925348465f5595572733afe3 Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Fri, 24 Jan 2025 13:40:01 +0100 Subject: [PATCH] rename utils function to check is_list_spiketrains, fix tests and replace raise error with return False --- elephant/statistics.py | 4 ++-- elephant/test/test_utils.py | 18 +++++++----------- elephant/utils.py | 18 ++++++++---------- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 291adefc0..191aad4b9 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -81,7 +81,7 @@ import elephant.trials from elephant.conversion import BinnedSpikeTrain from elephant.utils import deprecated_alias, check_neo_consistency, \ - is_time_quantity, round_binning_errors, is_list_neo_spiketrains + is_time_quantity, round_binning_errors, is_list_spiketrains # do not import unicode_literals # (quantities rescale does not work with unicodes) @@ -1031,7 +1031,7 @@ def optimal_kernel(st): sigma=str(kernel.sigma), invert=kernel.invert) - if is_list_neo_spiketrains(spiketrains) and (pool_spike_trains): + if is_list_spiketrains(spiketrains) and (pool_spike_trains): rate = np.mean(rate, axis=1) rate = neo.AnalogSignal(signal=rate, diff --git a/elephant/test/test_utils.py b/elephant/test/test_utils.py index e66c07ef6..421ea9d00 100644 --- a/elephant/test/test_utils.py +++ b/elephant/test/test_utils.py @@ -144,33 +144,29 @@ def setUp(self): def test_valid_list_input(self): valid_list = [self.spiketrain1, self.spiketrain2] - self.assertTrue(utils.is_list_neo_spiketrains(valid_list)) + self.assertTrue(utils.is_list_spiketrains(valid_list)) def test_valid_tuple_input(self): valid_tuple = (self.spiketrain1, self.spiketrain2) - self.assertTrue(utils.is_list_neo_spiketrains(valid_tuple)) + self.assertTrue(utils.is_list_spiketrains(valid_tuple)) def test_valid_spiketrainlist_input(self): valid_spiketrainlist = neo.core.spiketrainlist.SpikeTrainList(items=(self.spiketrain1, self.spiketrain2)) - self.assertTrue(utils.is_list_neo_spiketrains(valid_spiketrainlist)) + self.assertTrue(utils.is_list_spiketrains(valid_spiketrainlist)) def test_non_iterable_input(self): - with self.assertRaises(TypeError): - utils.is_list_neo_spiketrains(42) + self.assertFalse(utils.is_list_spiketrains(42)) def test_non_spiketrain_objects(self): invalid_list = [self.spiketrain1, "not a spiketrain"] - with self.assertRaises(TypeError): - utils.is_list_neo_spiketrains(invalid_list) + self.assertFalse(utils.is_list_spiketrains(invalid_list)) def test_mixed_types_input(self): invalid_mixed = [self.spiketrain1, 42, self.spiketrain2] - with self.assertRaises(TypeError): - utils.is_list_neo_spiketrains(invalid_mixed) + self.assertFalse(utils.is_list_spiketrains(invalid_mixed)) def test_none_input(self): - with self.assertRaises(TypeError): - utils.is_list_neo_spiketrains(None) + self.assertFalse(utils.is_list_spiketrains(None)) if __name__ == '__main__': diff --git a/elephant/utils.py b/elephant/utils.py index aed0b15f6..906f513c4 100644 --- a/elephant/utils.py +++ b/elephant/utils.py @@ -7,7 +7,7 @@ check_neo_consistency check_same_units round_binning_errors - is_list_neo_spiketrains + is_list_spiketrains """ from __future__ import division, print_function, unicode_literals @@ -33,7 +33,7 @@ "check_neo_consistency", "check_same_units", "round_binning_errors", - "is_list_neo_spiketrains", + "is_list_spiketrains", ] @@ -451,30 +451,28 @@ def wrapper(*args, **kwargs): return wrapper -def is_list_neo_spiketrains(obj: object) -> bool: +def is_list_spiketrains(obj: object) -> bool: """ Check if input is an iterable containing only neo.SpikeTrain objects. Parameters ---------- obj : object - The object to check. Can be a neo.spiketrainlist, list, tuple or any other iterable. + The object to check. Returns ------- bool True if obj is an iterable containing only neo.SpikeTrain objects. - Raises - ------ - TypeError - If obj is not an iterable, or if any element is not a neo.SpikeTrain. """ if not isinstance(obj, collections.abc.Iterable): - raise TypeError("Input must be an iterable (list, tuple, etc.)") + # Input must be an iterable (list, tuple, etc.) + return False if not all(isinstance(st, neo.SpikeTrain) for st in obj): - raise TypeError("All elements must be neo.SpikeTrain objects") + # All elements must be neo.SpikeTrain objects + return False return True