diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py index f10a86e4..7ded60e3 100644 --- a/clouddrift/ragged.py +++ b/clouddrift/ragged.py @@ -3,7 +3,7 @@ """ import numpy as np -from typing import Tuple, Union, Iterable +from typing import Tuple, Union, Iterable, Callable import xarray as xr import pandas as pd from concurrent import futures @@ -547,7 +547,8 @@ def subset( ) -> xr.Dataset: """Subset the dataset as a function of one or many criteria. The criteria are passed as a dictionary, where a variable to subset is assigned to either a - range (valuemin, valuemax), a list [value1, value2, valueN], or a single value. + range (valuemin, valuemax), a list [value1, value2, valueN], a single value, + or a masking function applied to every trajectory using ``apply_ragged``. This function relies on specific names of the dataset dimensions and the rowsize variables. The default expected values are listed in the Parameters @@ -559,7 +560,7 @@ def subset( ds : xr.Dataset Lagrangian dataset stored in two-dimensional or ragged array format criteria : dict - dictionary containing the variables and the ranges/values to subset + dictionary containing the variables and the ranges/values/functions to subset id_var_name : str, optional Name of the variable containing the ID of the trajectories (default is "id") rowsize_var_name : str, optional @@ -629,6 +630,19 @@ def subset( >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78), "drogue_status": True, "sst": (303.15, np.inf), "time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))}) + Retrieve observations every 24 hours after the first one, trajectory-wise: + + >>> def daily_masking( + >>> traj_time: xr.DataArray + >>> ) -> np.ndarray: + >>> traj_time = traj_time.astype("float64") + >>> traj_time /= 1e9 # to seconds + >>> traj_time -= traj_time[0] # start from 0 + >>> mask = (traj_time % (24*60*60)) == 0 # get only obs every 24 hours after the first one + >>> # return only the boolean mask (one value per observation); subset updates rowsize itself
+ >>> return mask + >>> subset(ds, {"time": daily_masking}) + Raises ------ ValueError @@ -644,9 +658,11 @@ def subset( for key in criteria.keys(): if key in ds or key in ds.dims: if ds[key].dims == (traj_dim_name,): - mask_traj = np.logical_and(mask_traj, _mask_var(ds[key], criteria[key])) + mask_traj = np.logical_and(mask_traj, _mask_var(ds[key], criteria[key], ds[rowsize_var_name], + traj_dim_name)) elif ds[key].dims == (obs_dim_name,): - mask_obs = np.logical_and(mask_obs, _mask_var(ds[key], criteria[key])) + mask_obs = np.logical_and(mask_obs, _mask_var(ds[key], criteria[key], ds[rowsize_var_name], + obs_dim_name)) else: raise ValueError(f"Unknown variable '{key}'.") @@ -752,7 +768,9 @@ def unpack( def _mask_var( var: xr.DataArray, - criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int], + criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int, Callable], + rowsize: xr.DataArray = None, + dim_name: str = "dim_0", ) -> xr.DataArray: """Return the mask of a subset of the data matching a test criterion. 
@@ -760,11 +778,16 @@ def _mask_var( ---------- var : xr.DataArray DataArray to be subset by the criterion - criterion : array-like - The criterion can take three forms: + criterion : array-like or scalar or Callable + The criterion can take four forms: - tuple: (min, max) defining a range - list, np.ndarray, or xr.DataArray: An array-like defining multiples values - scalar: value defining a single value + - function: a function applied against each trajectory using ``apply_ragged`` and returning a mask + rowsize : xr.DataArray, optional + List of integers specifying the number of data points in each row + dim_name : str, optional + Name of the masked dimension (default is "dim_0") Examples -------- @@ -784,6 +807,12 @@ def _mask_var( array([False, False, False, True, False]) Dimensions without coordinates: dim_0 + >>> rowsize = xr.DataArray(data=[2, 3]) + >>> _mask_var(x, lambda arr: arr==arr[0]+1, rowsize, "dim_0") + + array([False, True, False, True, False]) + Dimensions without coordinates: dim_0 + Returns ------- mask : xr.DataArray @@ -794,6 +823,8 @@ def _mask_var( elif isinstance(criterion, (list, np.ndarray, xr.DataArray)): # select multiple values mask = np.isin(var, criterion) + elif callable(criterion): # mask directly created by applying `criterion` to each trajectory + mask = xr.DataArray(data=apply_ragged(criterion, var, rowsize), dims=[dim_name]).astype(bool) else: # select one specific value mask = var == criterion return mask diff --git a/tests/ragged_tests.py b/tests/ragged_tests.py index 371d7340..8f3dc14a 100644 --- a/tests/ragged_tests.py +++ b/tests/ragged_tests.py @@ -699,6 +699,13 @@ def test_subset_by_rows(self): self.assertTrue(all(ds_sub["id"] == [1, 2])) self.assertTrue(all(ds_sub["rowsize"] == [5, 4])) + def test_subset_callable(self): + func = lambda arr: ((arr - arr[0]) % 2) == 0 # test keeping obs every two time intervals + ds_sub = subset(self.ds, {"time": func}) + self.assertTrue(all(ds_sub["id"] == [1, 3, 2])) + 
self.assertTrue(all(ds_sub["rowsize"] == [3, 1, 2])) + self.assertTrue(all(ds_sub["time"] == [1, 3, 5, 4, 2, 4])) + class unpack_tests(unittest.TestCase): def test_unpack(self):