Skip to content

Commit

Permalink
Subset callable using multiple variables (#361)
Browse files Browse the repository at this point in the history
* sevaral variables can be passed to the callable returning the mask

* add some examples. update typings.

* add raise TypeError

* consistent typing. add unit tests

* minor changes

* docstring edits

* refactor and add tests. revert np.all. small updates to the docstring

* docstring edits

* Added one Exception to validate all variable share the same dimension

* bump version

* Update pyproject.toml

Add myself as an author for conda forge invite

* black v24

---------

Co-authored-by: Vadim Bertrand <vadim.bertrand@univ-grenoble-alpes.fr>
Co-authored-by: Philippe Miron <philippe.miron@dtn.com>
Co-authored-by: Shane Elipot <selipot@miami.edu>
Co-authored-by: Kevin Santana <kevinsantana11@gmail.com>
  • Loading branch information
5 people authored Jan 27, 2024
1 parent 2dd835c commit 33103dc
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 25 deletions.
95 changes: 72 additions & 23 deletions clouddrift/ragged.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,20 +545,21 @@ def subset(
obs_dim_name: str = "obs",
full_trajectories=False,
) -> xr.Dataset:
"""Subset a ragged array dataset as a function of one or more criteria.
"""Subset a ragged array xarray dataset as a function of one or more criteria.
The criteria are passed with a dictionary, where a dictionary key
is a variable to subset and the associated dictionary value is either a range
(valuemin, valuemax), a list [value1, value2, valueN], a single value, or a
masking function applied to every row of the ragged array using ``apply_ragged``.
masking function applied to any variable of the dataset.
This function needs to know the names of the dimensions of the ragged array dataset
(`traj_dim_name` and `obs_dim_name`), and the name of the rowsize variable (`rowsize_var_name`).
Default values are provided for these arguments (see below), but they can be changed if needed.
Default values corresponds to the clouddrift convention ("traj", "obs", and "rowsize") but should
be changed as needed.
Parameters
----------
ds : xr.Dataset
Dataset stored as ragged arrays.
Xarray dataset composed of ragged arrays.
criteria : dict
Dictionary containing the variables (as keys) and the ranges/values/functions (as values) to subset.
id_var_name : str, optional
Expand All @@ -570,19 +571,19 @@ def subset(
obs_dim_name : str, optional
Name of the observation dimension (default is "obs").
full_trajectories : bool, optional
If True, it returns the complete trajectories (rows) where at least one observation
matches the criteria, rather than just the segments where the criteria are satisfied.
Default is False.
If True, the function returns complete rows (trajectories) for which the criteria
are matched at least once. Default is False which means that only segments matching the criteria
are returned when filtering along the observation dimension.
Returns
-------
xr.Dataset
subset Dataset matching the criterion(a)
Subset xarray dataset matching the criterion(a).
Examples
--------
Criteria are combined on any data or metadata variables part of the Dataset.
The following examples are based on NOAA GDP datasets which can be accessed with the
Criteria are combined on any data (with dimension "obs") or metadata (with dimension "traj") variables
part of the Dataset. The following examples are based on NOAA GDP datasets which can be accessed with the
``clouddrift.datasets`` module.
Retrieve a region, like the Gulf of Mexico, using ranges of latitude and longitude:
Expand All @@ -607,7 +608,7 @@ def subset(
>>> subset(ds, {"rowsize": (0, 1000)})
Retrieve specific drifters from their IDs:
Retrieve specific drifters using their IDs:
>>> subset(ds, {"id": [2578, 2582, 2583]})
Expand Down Expand Up @@ -637,10 +638,32 @@ def subset(
>>> func = (lambda arr: ((arr - arr[0]) % 2) == 0)
>>> subset(ds, {"time": func})
The filtering function can accept several input variables passed as a tuple. For example, retrieve
drifters released in the Mediterranean Sea, but exclude those released in the Bay of Biscay and the Black Sea:
>>> def mediterranean_mask(lon: xr.DataArray, lat: xr.DataArray) -> xr.DataArray:
>>> # Mediterranean Sea bounding box
>>> in_med = np.logical_and(-6.0327 <= lon, np.logical_and(lon <= 36.2173,
>>> np.logical_and(30.2639 <= lat, lat <= 45.7833)))
>>> # Bay of Biscay
>>> in_biscay = np.logical_and(lon <= -0.1462, lat >= 43.2744)
>>> # Black Sea
>>> in_blacksea = np.logical_and(lon >= 27.4437, lat >= 40.9088)
>>> return np.logical_and(in_med, np.logical_not(np.logical_or(in_biscay, in_blacksea)))
>>> subset(ds, {("start_lon", "start_lat"): mediterranean_mask})
Raises
------
ValueError
If one of the variable in a criterion is not found in the Dataset
If one of the variable in a criterion is not found in the Dataset.
TypeError
If one of the `criteria` key is a tuple while its associated value is not a `Callable` criterion.
TypeError
If variables of a `criterion` key associated to a `Callable` do not share the same dimension.
See Also
--------
:func:`apply_ragged`
"""
mask_traj = xr.DataArray(
data=np.ones(ds.sizes[traj_dim_name], dtype="bool"), dims=[traj_dim_name]
Expand All @@ -650,19 +673,30 @@ def subset(
)

for key in criteria.keys():
if key in ds or key in ds.dims:
if ds[key].dims == (traj_dim_name,):
if np.all(np.isin(key, ds.variables) | np.isin(key, ds.dims)):
if isinstance(key, tuple):
criterion = [ds[k] for k in key]
if not all(c.dims == criterion[0].dims for c in criterion):
raise TypeError(
"Variables passed to the Callable must share the same dimension."
)
criterion_dims = criterion[0].dims
else:
criterion = ds[key]
criterion_dims = criterion.dims

if criterion_dims == (traj_dim_name,):
mask_traj = np.logical_and(
mask_traj,
_mask_var(
ds[key], criteria[key], ds[rowsize_var_name], traj_dim_name
criterion, criteria[key], ds[rowsize_var_name], traj_dim_name
),
)
elif ds[key].dims == (obs_dim_name,):
elif criterion_dims == (obs_dim_name,):
mask_obs = np.logical_and(
mask_obs,
_mask_var(
ds[key], criteria[key], ds[rowsize_var_name], obs_dim_name
criterion, criteria[key], ds[rowsize_var_name], obs_dim_name
),
)
else:
Expand Down Expand Up @@ -769,7 +803,7 @@ def unpack(


def _mask_var(
var: xr.DataArray,
var: Union[xr.DataArray, list[xr.DataArray]],
criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int, Callable],
rowsize: xr.DataArray = None,
dim_name: str = "dim_0",
Expand All @@ -778,8 +812,8 @@ def _mask_var(
Parameters
----------
var : xr.DataArray
DataArray to be subset by the criterion
var : xr.DataArray or list[xr.DataArray]
DataArray or list of DataArray (only applicable if the criterion is a Callable) to be used by the criterion
criterion : array-like or scalar or Callable
The criterion can take four forms:
- tuple: (min, max) defining a range
Expand Down Expand Up @@ -815,26 +849,41 @@ def _mask_var(
array([False, True, False, True, False])
Dimensions without coordinates: dim_0
>>> y = xr.DataArray(data=np.arange(0, 5)+2)
>>> rowsize = xr.DataArray(data=[2, 3])
>>> _mask_var([x, y], lambda var1, var2: ((var1 * var2) % 2) == 0, rowsize, "dim_0")
<xarray.DataArray (dim_0: 5)>
array([True, False, True, False, True])
Dimensions without coordinates: dim_0
Returns
-------
mask : xr.DataArray
The mask of the subset of the data matching the criteria
"""
if not callable(criterion) and isinstance(var, list):
raise TypeError(
"The `var` parameter can be a `list` only if the `criterion` is a `Callable`."
)

if isinstance(criterion, tuple): # min/max defining range
mask = np.logical_and(var >= criterion[0], var <= criterion[1])
elif isinstance(criterion, (list, np.ndarray, xr.DataArray)):
# select multiple values
mask = np.isin(var, criterion)
elif callable(criterion):
# mask directly created by applying `criterion` function
if len(var) == len(rowsize):
mask = criterion(var)
if not isinstance(var, list):
var = [var]

if len(var[0]) == len(rowsize):
mask = criterion(*var)
else:
mask = apply_ragged(criterion, var, rowsize)

mask = xr.DataArray(data=mask, dims=[dim_name]).astype(bool)

if not len(var) == len(mask):
if not len(var[0]) == len(mask):
raise ValueError(
"The `Callable` function must return a masked array that matches the length of the variable to filter."
)
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@ build-backend = "hatchling.build"

[project]
name = "clouddrift"
version = "0.30.0"
version = "0.31.0"
authors = [
{ name="Shane Elipot", email="selipot@miami.edu" },
{ name="Philippe Miron", email="philippemiron@gmail.com" },
{ name="Milan Curcic", email="mcurcic@miami.edu" }
{ name="Milan Curcic", email="mcurcic@miami.edu" },
{ name="Kevin Santana", email="kevinsantana11@gmail.com" }
]
description = "Accelerating the use of Lagrangian data for atmospheric, oceanic, and climate sciences"
readme = "README.md"
Expand Down
29 changes: 29 additions & 0 deletions tests/ragged_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -714,13 +714,42 @@ def test_subset_callable(self):
self.assertTrue(all(ds_sub["id"] == [1, 2]))
self.assertTrue(all(ds_sub["rowsize"] == [5, 4]))

def test_subset_callable_tuple(self):
func = lambda arr1, arr2: np.logical_and(
arr1 >= 0, arr2 >= 30
) # keep positive longitude and latitude larger or equal than 30
ds_sub = subset(self.ds, {("lon", "lat"): func})
self.assertTrue(all(ds_sub["id"] == [1, 2]))
self.assertTrue(all(ds_sub["rowsize"] == [2, 2]))
self.assertTrue(all(ds_sub["lon"] >= 0))
self.assertTrue(all(ds_sub["lat"] >= 30))

def test_subset_callable_wrong_dim(self):
func = lambda arr: [arr, arr] # returns 2 values per element
with self.assertRaises(ValueError):
subset(self.ds, {"time": func})
with self.assertRaises(ValueError):
subset(self.ds, {"id": func})

def test_subset_callable_wrong_type(self):
rows = [0, 2] # test extracting first and third rows
with self.assertRaises(TypeError): # passing a tuple when a string is expected
subset(self.ds, {("traj",): rows})

def test_subset_callable_tuple_unknown_var(self):
func = lambda arr1, arr2: np.logical_and(
arr1 >= 0, arr2 >= 30
) # keep positive longitude and latitude larger or equal than 30
with self.assertRaises(ValueError):
subset(self.ds, {("a", "lat"): func})

def test_subset_callable_tuple_not_same_dimension(self):
func = lambda arr1, arr2: np.logical_and(
arr1 >= 0, arr2 >= 30
) # keep positive longitude and latitude larger or equal than 30
with self.assertRaises(TypeError):
subset(self.ds, {("id", "lat"): func})


class unpack_tests(unittest.TestCase):
def test_unpack(self):
Expand Down

0 comments on commit 33103dc

Please sign in to comment.