Skip to content

Commit

Permalink
Merge pull request #350 from vadmbertr/subset-callable
Browse files Browse the repository at this point in the history
Allow `Callable` criterion in `ragged.subset`
  • Loading branch information
kevinsantana11 authored Jan 20, 2024
2 parents 22ea170 + c438051 commit f702928
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 18 deletions.
78 changes: 61 additions & 17 deletions clouddrift/ragged.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

import numpy as np
from typing import Tuple, Union, Iterable
from typing import Tuple, Union, Iterable, Callable
import xarray as xr
import pandas as pd
from concurrent import futures
Expand Down Expand Up @@ -545,21 +545,22 @@ def subset(
obs_dim_name: str = "obs",
full_trajectories=False,
) -> xr.Dataset:
"""Subset the dataset as a function of one or many criteria. The criteria are
passed as a dictionary, where a variable to subset is assigned to either a
range (valuemin, valuemax), a list [value1, value2, valueN], or a single value.
"""Subset a ragged array dataset as a function of one or more criteria.
The criteria are passed as a dictionary, where each dictionary key
is a variable to subset and the associated dictionary value is either a range
(valuemin, valuemax), a list [value1, value2, valueN], a single value, or a
masking function. A masking function is applied to every row of the ragged array using ``apply_ragged``, or directly to the variable when the variable holds one value per row.
This function relies on specific names of the dataset dimensions and the
rowsize variables. The default expected values are listed in the Parameters
section, however, if your dataset uses different names for these dimensions
and variables, you can specify them using the optional arguments.
This function needs to know the names of the dimensions of the ragged array dataset
(`traj_dim_name` and `obs_dim_name`), and the name of the rowsize variable (`rowsize_var_name`).
Default values are provided for these arguments (see below), but they can be changed if needed.
Parameters
----------
ds : xr.Dataset
Lagrangian dataset stored in two-dimensional or ragged array format
Dataset stored as ragged arrays
criteria : dict
dictionary containing the variables and the ranges/values to subset
Dictionary containing the variables (as keys) and the ranges/values/functions (as values) to subset
id_var_name : str, optional
Name of the variable containing the ID of the trajectories (default is "id")
rowsize_var_name : str, optional
Expand All @@ -569,7 +570,7 @@ def subset(
obs_dim_name : str, optional
Name of the observation dimension (default is "obs")
full_trajectories : bool, optional
If True, it returns the complete trajectories where at least one observation
If True, it returns the complete trajectories (rows) where at least one observation
matches the criteria, rather than just the segments where the criteria are satisfied.
Default is False.
Expand All @@ -581,7 +582,8 @@ def subset(
Examples
--------
Criteria are combined on any data or metadata variables part of the Dataset.
The following examples are based on the GDP dataset.
The following examples are based on NOAA GDP datasets which can be accessed with the
``clouddrift.datasets`` module.
Retrieve a region, like the Gulf of Mexico, using ranges of latitude and longitude:
Expand Down Expand Up @@ -629,6 +631,12 @@ def subset(
>>> subset(ds, {"lat": (21, 31), "lon": (-98, -78), "drogue_status": True, "sst": (303.15, np.inf), "time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))})
You can also use a function to filter the data. For example, retrieve every other observation
of each trajectory (row):
>>> func = (lambda arr: ((arr - arr[0]) % 2) == 0)
>>> subset(ds, {"time": func})
Raises
------
ValueError
Expand All @@ -644,9 +652,19 @@ def subset(
for key in criteria.keys():
if key in ds or key in ds.dims:
if ds[key].dims == (traj_dim_name,):
mask_traj = np.logical_and(mask_traj, _mask_var(ds[key], criteria[key]))
mask_traj = np.logical_and(
mask_traj,
_mask_var(
ds[key], criteria[key], ds[rowsize_var_name], traj_dim_name
),
)
elif ds[key].dims == (obs_dim_name,):
mask_obs = np.logical_and(mask_obs, _mask_var(ds[key], criteria[key]))
mask_obs = np.logical_and(
mask_obs,
_mask_var(
ds[key], criteria[key], ds[rowsize_var_name], obs_dim_name
),
)
else:
raise ValueError(f"Unknown variable '{key}'.")

Expand Down Expand Up @@ -752,19 +770,26 @@ def unpack(

def _mask_var(
var: xr.DataArray,
criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int],
criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int, Callable],
rowsize: xr.DataArray = None,
dim_name: str = "dim_0",
) -> xr.DataArray:
"""Return the mask of a subset of the data matching a test criterion.
Parameters
----------
var : xr.DataArray
DataArray to be subset by the criterion
criterion : array-like
The criterion can take three forms:
criterion : array-like or scalar or Callable
The criterion can take four forms:
- tuple: (min, max) defining a range
- list, np.ndarray, or xr.DataArray: An array-like defining multiple values
- scalar: value defining a single value
- function: a function returning a mask, applied to each row using ``apply_ragged`` (or directly to `var` when `var` has one value per row)
rowsize : xr.DataArray, optional
List of integers specifying the number of data points in each row (needed when `criterion` is a function)
dim_name : str, optional
Name of the masked dimension (default is "dim_0")
Examples
--------
Expand All @@ -784,6 +809,12 @@ def _mask_var(
array([False, False, False, True, False])
Dimensions without coordinates: dim_0
>>> rowsize = xr.DataArray(data=[2, 3])
>>> _mask_var(x, lambda arr: arr==arr[0]+1, rowsize, "dim_0")
<xarray.DataArray (dim_0: 5)>
array([False, True, False, True, False])
Dimensions without coordinates: dim_0
Returns
-------
mask : xr.DataArray
Expand All @@ -794,6 +825,19 @@ def _mask_var(
elif isinstance(criterion, (list, np.ndarray, xr.DataArray)):
# select multiple values
mask = np.isin(var, criterion)
elif callable(criterion):
# mask directly created by applying `criterion` function
if len(var) == len(rowsize):
mask = criterion(var)
else:
mask = apply_ragged(criterion, var, rowsize)

mask = xr.DataArray(data=mask, dims=[dim_name]).astype(bool)

if not len(var) == len(mask):
raise ValueError(
"The `Callable` function must return a masked array that matches the length of the variable to filter."
)
else: # select one specific value
mask = var == criterion
return mask
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "hatchling.build"

[project]
name = "clouddrift"
version = "0.29.0"
version = "0.30.0"
authors = [
{ name="Shane Elipot", email="selipot@miami.edu" },
{ name="Philippe Miron", email="philippemiron@gmail.com" },
Expand Down
21 changes: 21 additions & 0 deletions tests/ragged_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -700,6 +700,27 @@ def test_subset_by_rows(self):
self.assertTrue(all(ds_sub["id"] == [1, 2]))
self.assertTrue(all(ds_sub["rowsize"] == [5, 4]))

def test_subset_callable(self):
    """Check that `subset` accepts a function as a criterion.

    Two cases are exercised: a function filtering on the observation
    variable "time", and a function filtering on the row variable "id".
    """

    def every_two_time_intervals(arr):
        # Keep observations every two time intervals, counted from arr[0].
        return ((arr - arr[0]) % 2) == 0

    ds_sub = subset(self.ds, {"time": every_two_time_intervals})
    self.assertTrue(all(ds_sub["id"] == [1, 3, 2]))
    self.assertTrue(all(ds_sub["rowsize"] == [3, 1, 2]))
    self.assertTrue(all(ds_sub["time"] == [1, 3, 5, 4, 2, 4]))

    def id_at_most_two(arr):
        # Keep ids smaller than or equal to 2.
        return arr <= 2

    ds_sub = subset(self.ds, {"id": id_at_most_two})
    self.assertTrue(all(ds_sub["id"] == [1, 2]))
    self.assertTrue(all(ds_sub["rowsize"] == [5, 4]))

def test_subset_callable_wrong_dim(self):
    """Check that a function returning a wrongly sized result raises ValueError.

    The check is made once for the "time" variable and once for "id".
    """

    def doubled(arr):
        # returns 2 values per element
        return [arr, arr]

    for var_name in ("time", "id"):
        with self.assertRaises(ValueError):
            subset(self.ds, {var_name: doubled})


class unpack_tests(unittest.TestCase):
def test_unpack(self):
Expand Down

0 comments on commit f702928

Please sign in to comment.