From ab00bbd5982fa394da885f96553ac24400d59ff0 Mon Sep 17 00:00:00 2001 From: Shane Elipot Date: Tue, 19 Dec 2023 06:07:03 -0500 Subject: [PATCH] lint format --- clouddrift/ragged.py | 24 ++++++++++++++++++------ tests/ragged_tests.py | 4 +++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py index 7ded60e3..34a32189 100644 --- a/clouddrift/ragged.py +++ b/clouddrift/ragged.py @@ -658,11 +658,19 @@ def subset( for key in criteria.keys(): if key in ds or key in ds.dims: if ds[key].dims == (traj_dim_name,): - mask_traj = np.logical_and(mask_traj, _mask_var(ds[key], criteria[key], ds[rowsize_var_name], - traj_dim_name)) + mask_traj = np.logical_and( + mask_traj, + _mask_var( + ds[key], criteria[key], ds[rowsize_var_name], traj_dim_name + ), + ) elif ds[key].dims == (obs_dim_name,): - mask_obs = np.logical_and(mask_obs, _mask_var(ds[key], criteria[key], ds[rowsize_var_name], - obs_dim_name)) + mask_obs = np.logical_and( + mask_obs, + _mask_var( + ds[key], criteria[key], ds[rowsize_var_name], obs_dim_name + ), + ) else: raise ValueError(f"Unknown variable '{key}'.") @@ -823,8 +831,12 @@ def _mask_var( elif isinstance(criterion, (list, np.ndarray, xr.DataArray)): # select multiple values mask = np.isin(var, criterion) - elif callable(criterion): # mask directly created by applying `criterion` to each trajectory - mask = xr.DataArray(data=apply_ragged(criterion, var, rowsize), dims=[dim_name]).astype(bool) + elif callable( + criterion + ): # mask directly created by applying `criterion` to each trajectory + mask = xr.DataArray( + data=apply_ragged(criterion, var, rowsize), dims=[dim_name] + ).astype(bool) else: # select one specific value mask = var == criterion return mask diff --git a/tests/ragged_tests.py b/tests/ragged_tests.py index 8f3dc14a..26432c19 100644 --- a/tests/ragged_tests.py +++ b/tests/ragged_tests.py @@ -700,7 +700,9 @@ def test_subset_by_rows(self): self.assertTrue(all(ds_sub["rowsize"] == [5, 4])) def test_subset_callable(self): - func = lambda arr: ((arr - arr[0]) % 2) == 0 # test keeping obs every two time intervals + func = ( + lambda arr: ((arr - arr[0]) % 2) == 0 + ) # test keeping obs every two time intervals ds_sub = subset(self.ds, {"time": func}) self.assertTrue(all(ds_sub["id"] == [1, 3, 2])) self.assertTrue(all(ds_sub["rowsize"] == [3, 1, 2]))