From 63bb625bad5b194c497c05a8fa37ec03969b988d Mon Sep 17 00:00:00 2001 From: Philippe Miron Date: Wed, 20 Dec 2023 13:41:35 -0300 Subject: [PATCH] fix for traj variables --- clouddrift/ragged.py | 21 +++++++++++++++------ tests/ragged_tests.py | 12 ++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py index 34a32189..35e76266 100644 --- a/clouddrift/ragged.py +++ b/clouddrift/ragged.py @@ -831,12 +831,21 @@ def _mask_var( elif isinstance(criterion, (list, np.ndarray, xr.DataArray)): # select multiple values mask = np.isin(var, criterion) - elif callable( - criterion - ): # mask directly created by applying `criterion` to each trajectory - mask = xr.DataArray( - data=apply_ragged(criterion, var, rowsize), dims=[dim_name] - ).astype(bool) + elif callable(criterion): + # mask directly created by applying `criterion` function + if len(var) == len(rowsize): + mask = criterion(var) + else: + mask = xr.DataArray( + data=apply_ragged(criterion, var, rowsize), dims=[dim_name] + ).astype(bool) + + if (len(var) == len(rowsize) and len(mask) != len(var)) or ( + len(var) == np.sum(rowsize) and len(mask) != np.sum(rowsize) + ): + raise ValueError( + "The `Callable` function needs to return a masked array that matches the length of the variable to filter." 
+ ) else: # select one specific value mask = var == criterion return mask diff --git a/tests/ragged_tests.py b/tests/ragged_tests.py index 26432c19..e2fd434b 100644 --- a/tests/ragged_tests.py +++ b/tests/ragged_tests.py @@ -708,6 +708,18 @@ def test_subset_callable(self): self.assertTrue(all(ds_sub["rowsize"] == [3, 1, 2])) self.assertTrue(all(ds_sub["time"] == [1, 3, 5, 4, 2, 4])) + func = lambda arr: arr <= 2 # keep ids less than or equal to 2 + ds_sub = subset(self.ds, {"id": func}) + self.assertTrue(all(ds_sub["id"] == [1, 2])) + self.assertTrue(all(ds_sub["rowsize"] == [5, 4])) + + def test_subset_callable_wrong_dim(self): + func = lambda arr: [arr, arr] # returns 2 values per element + with self.assertRaises(ValueError): + subset(self.ds, {"time": func}) + with self.assertRaises(ValueError): + subset(self.ds, {"id": func}) + class unpack_tests(unittest.TestCase): def test_unpack(self):