Skip to content

Commit

Permalink
lint format
Browse files Browse the repository at this point in the history
  • Loading branch information
selipot committed Dec 19, 2023
1 parent d798c7e commit ab00bbd
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
24 changes: 18 additions & 6 deletions clouddrift/ragged.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,11 +658,19 @@ def subset(
for key in criteria.keys():
if key in ds or key in ds.dims:
if ds[key].dims == (traj_dim_name,):
mask_traj = np.logical_and(mask_traj, _mask_var(ds[key], criteria[key], ds[rowsize_var_name],
traj_dim_name))
mask_traj = np.logical_and(
mask_traj,
_mask_var(
ds[key], criteria[key], ds[rowsize_var_name], traj_dim_name
),
)
elif ds[key].dims == (obs_dim_name,):
mask_obs = np.logical_and(mask_obs, _mask_var(ds[key], criteria[key], ds[rowsize_var_name],
obs_dim_name))
mask_obs = np.logical_and(
mask_obs,
_mask_var(
ds[key], criteria[key], ds[rowsize_var_name], obs_dim_name
),
)
else:
raise ValueError(f"Unknown variable '{key}'.")

Expand Down Expand Up @@ -823,8 +831,12 @@ def _mask_var(
elif isinstance(criterion, (list, np.ndarray, xr.DataArray)):
# select multiple values
mask = np.isin(var, criterion)
elif callable(criterion): # mask directly created by applying `criterion` to each trajectory
mask = xr.DataArray(data=apply_ragged(criterion, var, rowsize), dims=[dim_name]).astype(bool)
elif callable(
criterion
): # mask directly created by applying `criterion` to each trajectory
mask = xr.DataArray(
data=apply_ragged(criterion, var, rowsize), dims=[dim_name]
).astype(bool)
else: # select one specific value
mask = var == criterion
return mask
4 changes: 3 additions & 1 deletion tests/ragged_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,7 +700,9 @@ def test_subset_by_rows(self):
self.assertTrue(all(ds_sub["rowsize"] == [5, 4]))

def test_subset_callable(self):
func = lambda arr: ((arr - arr[0]) % 2) == 0 # test keeping obs every two time intervals
func = (
lambda arr: ((arr - arr[0]) % 2) == 0
) # test keeping obs every two time intervals
ds_sub = subset(self.ds, {"time": func})
self.assertTrue(all(ds_sub["id"] == [1, 3, 2]))
self.assertTrue(all(ds_sub["rowsize"] == [3, 1, 2]))
Expand Down

0 comments on commit ab00bbd

Please sign in to comment.