Skip to content

Commit

Permalink
add the option to pass a callable as masking criterion
Browse files Browse the repository at this point in the history
  • Loading branch information
vadmbertr committed Dec 19, 2023
1 parent 02afcb4 commit d798c7e
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 8 deletions.
47 changes: 39 additions & 8 deletions clouddrift/ragged.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import numpy as np
from typing import Tuple, Union, Iterable
from typing import Tuple, Union, Iterable, Callable
import xarray as xr
import pandas as pd
from concurrent import futures
Expand Down Expand Up @@ -547,7 +547,8 @@ def subset(
) -> xr.Dataset:
"""Subset the dataset as a function of one or many criteria. The criteria are
passed as a dictionary, where a variable to subset is assigned to either a
range (valuemin, valuemax), a list [value1, value2, valueN], or a single value.
range (valuemin, valuemax), a list [value1, value2, valueN], a single value,
or a masking function applied to every trajectory using ``apply_ragged``
This function relies on specific names of the dataset dimensions and the
rowsize variables. The default expected values are listed in the Parameters
Expand All @@ -559,7 +560,7 @@ def subset(
ds : xr.Dataset
Lagrangian dataset stored in two-dimensional or ragged array format
criteria : dict
dictionary containing the variables and the ranges/values to subset
dictionary containing the variables and the ranges/values/functions to subset
id_var_name : str, optional
Name of the variable containing the ID of the trajectories (default is "id")
rowsize_var_name : str, optional
Expand Down Expand Up @@ -629,6 +630,19 @@ def subset(
>>> subset(ds, {"lat": (21, 31), "lon": (-98, -78), "drogue_status": True, "sst": (303.15, np.inf), "time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))})
Retrieve observations every 24 hours after the first one, trajectory-wise:
>>> def daily_masking(
>>> traj_time: xr.DataArray
>>> ) -> np.ndarray:
>>> traj_time = traj_time.astype("float64")
>>> traj_time /= 1e9 # to seconds
>>> traj_time -= traj_time[0] # start from 0
>>> mask = (traj_time % (24*60*60)) == 0 # get only obs every 24 hours after the first one
>>> rowsize = int(mask.sum()) # the number of obs per traj has to be updated
>>> return mask, rowsize
>>> subset(ds, {"time": daily_masking})
Raises
------
ValueError
Expand All @@ -644,9 +658,11 @@ def subset(
for key in criteria.keys():
if key in ds or key in ds.dims:
if ds[key].dims == (traj_dim_name,):
mask_traj = np.logical_and(mask_traj, _mask_var(ds[key], criteria[key]))
mask_traj = np.logical_and(mask_traj, _mask_var(ds[key], criteria[key], ds[rowsize_var_name],
traj_dim_name))
elif ds[key].dims == (obs_dim_name,):
mask_obs = np.logical_and(mask_obs, _mask_var(ds[key], criteria[key]))
mask_obs = np.logical_and(mask_obs, _mask_var(ds[key], criteria[key], ds[rowsize_var_name],
obs_dim_name))
else:
raise ValueError(f"Unknown variable '{key}'.")

Expand Down Expand Up @@ -752,19 +768,26 @@ def unpack(

def _mask_var(
var: xr.DataArray,
criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int],
criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int, Callable],
rowsize: xr.DataArray = None,
dim_name: str = "dim_0",
) -> xr.DataArray:
"""Return the mask of a subset of the data matching a test criterion.
Parameters
----------
var : xr.DataArray
DataArray to be subset by the criterion
criterion : array-like
The criterion can take three forms:
criterion : array-like or scalar or Callable
The criterion can take four forms:
- tuple: (min, max) defining a range
- list, np.ndarray, or xr.DataArray: An array-like defining multiples values
- scalar: value defining a single value
- function: a function applied against each trajectory using ``apply_ragged`` and returning a mask
rowsize : xr.DataArray, optional
List of integers specifying the number of data points in each row
dim_name : str, optional
Name of the masked dimension (default is "dim_0")
Examples
--------
Expand All @@ -784,6 +807,12 @@ def _mask_var(
array([False, False, False, True, False])
Dimensions without coordinates: dim_0
>>> rowsize = xr.DataArray(data=[2, 3])
>>> _mask_var(x, lambda arr: arr==arr[0]+1, rowsize, "dim_0")
<xarray.DataArray (dim_0: 5)>
array([False, True, False, True, False])
Dimensions without coordinates: dim_0
Returns
-------
mask : xr.DataArray
Expand All @@ -794,6 +823,8 @@ def _mask_var(
elif isinstance(criterion, (list, np.ndarray, xr.DataArray)):
# select multiple values
mask = np.isin(var, criterion)
elif callable(criterion): # mask directly created by applying `criterion` to each trajectory
mask = xr.DataArray(data=apply_ragged(criterion, var, rowsize), dims=[dim_name]).astype(bool)
else: # select one specific value
mask = var == criterion
return mask
7 changes: 7 additions & 0 deletions tests/ragged_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,13 @@ def test_subset_by_rows(self):
self.assertTrue(all(ds_sub["id"] == [1, 2]))
self.assertTrue(all(ds_sub["rowsize"] == [5, 4]))

def test_subset_callable(self):
func = lambda arr: ((arr - arr[0]) % 2) == 0 # test keeping obs every two time intervals
ds_sub = subset(self.ds, {"time": func})
self.assertTrue(all(ds_sub["id"] == [1, 3, 2]))
self.assertTrue(all(ds_sub["rowsize"] == [3, 1, 2]))
self.assertTrue(all(ds_sub["time"] == [1, 3, 5, 4, 2, 4]))


class unpack_tests(unittest.TestCase):
def test_unpack(self):
Expand Down

0 comments on commit d798c7e

Please sign in to comment.