Skip to content

Commit

Permalink
fix for traj variables
Browse files Browse the repository at this point in the history
  • Loading branch information
Philippe Miron committed Dec 20, 2023
1 parent ab00bbd commit 63bb625
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 6 deletions.
21 changes: 15 additions & 6 deletions clouddrift/ragged.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,12 +831,21 @@ def _mask_var(
elif isinstance(criterion, (list, np.ndarray, xr.DataArray)):
# select multiple values
mask = np.isin(var, criterion)
elif callable(
criterion
): # mask directly created by applying `criterion` to each trajectory
mask = xr.DataArray(
data=apply_ragged(criterion, var, rowsize), dims=[dim_name]
).astype(bool)
elif callable(criterion):
# mask directly created by applying `criterion` function
if len(var) == len(rowsize):
mask = criterion(var)
else:
mask = xr.DataArray(
data=apply_ragged(criterion, var, rowsize), dims=[dim_name]
).astype(bool)

if (len(var) == len(rowsize) and len(mask) != len(var)) or (
len(var) == np.sum(rowsize) and len(mask) != np.sum(rowsize)
):
raise ValueError(
"The `Callable` function needs to return a masked array that matches the length of the variable to filter."
)
else: # select one specific value
mask = var == criterion
return mask
12 changes: 12 additions & 0 deletions tests/ragged_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,18 @@ def test_subset_callable(self):
self.assertTrue(all(ds_sub["rowsize"] == [3, 1, 2]))
self.assertTrue(all(ds_sub["time"] == [1, 3, 5, 4, 2, 4]))

func = lambda arr: arr <= 2 # keep id larger or equal to 2
ds_sub = subset(self.ds, {"id": func})
self.assertTrue(all(ds_sub["id"] == [1, 2]))
self.assertTrue(all(ds_sub["rowsize"] == [5, 4]))

def test_subset_callable_wrong_dim(self):
func = lambda arr: [arr, arr] # returns 2 values per element
with self.assertRaises(ValueError):
subset(self.ds, {"time": func})
with self.assertRaises(ValueError):
subset(self.ds, {"id": func})


class unpack_tests(unittest.TestCase):
def test_unpack(self):
Expand Down

0 comments on commit 63bb625

Please sign in to comment.