From d798c7e90476ae4a695b4bc2b194fc1f66723589 Mon Sep 17 00:00:00 2001 From: vadimbertrand Date: Tue, 19 Dec 2023 11:14:17 +0100 Subject: [PATCH 01/10] add the option to pass a callable as masking criterion --- clouddrift/ragged.py | 47 +++++++++++++++++++++++++++++++++++-------- tests/ragged_tests.py | 7 +++++++ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py index f10a86e4..7ded60e3 100644 --- a/clouddrift/ragged.py +++ b/clouddrift/ragged.py @@ -3,7 +3,7 @@ """ import numpy as np -from typing import Tuple, Union, Iterable +from typing import Tuple, Union, Iterable, Callable import xarray as xr import pandas as pd from concurrent import futures @@ -547,7 +547,8 @@ def subset( ) -> xr.Dataset: """Subset the dataset as a function of one or many criteria. The criteria are passed as a dictionary, where a variable to subset is assigned to either a - range (valuemin, valuemax), a list [value1, value2, valueN], or a single value. + range (valuemin, valuemax), a list [value1, value2, valueN], a single value, + or a masking function applied to every trajectory using ``apply_ragged`` This function relies on specific names of the dataset dimensions and the rowsize variables. The default expected values are listed in the Parameters @@ -559,7 +560,7 @@ def subset( ds : xr.Dataset Lagrangian dataset stored in two-dimensional or ragged array format criteria : dict - dictionary containing the variables and the ranges/values to subset + dictionary containing the variables and the ranges/values/functions to subset id_var_name : str, optional Name of the variable containing the ID of the trajectories (default is "id") rowsize_var_name : str, optional @@ -629,6 +630,19 @@ def subset( >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78), "drogue_status": True, "sst": (303.15, np.inf), "time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))}) + Retrieve observations every 24 hours after the first one, trajectory-wise: + + >>> def daily_masking( + >>> traj_time: xr.DataArray + >>> ) -> np.ndarray: + >>> traj_time = traj_time.astype("float64") + >>> traj_time /= 1e9 # to seconds + >>> traj_time -= traj_time[0] # start from 0 + >>> mask = (traj_time % (24*60*60)) == 0 # get only obs every 24 hours after the first one + >>> rowsize = int(mask.sum()) # the number of obs per traj has to be updated + >>> return mask, rowsize + >>> subset(ds, {"time": daily_masking}) + Raises ------ ValueError @@ -644,9 +658,11 @@ def subset( for key in criteria.keys(): if key in ds or key in ds.dims: if ds[key].dims == (traj_dim_name,): - mask_traj = np.logical_and(mask_traj, _mask_var(ds[key], criteria[key])) + mask_traj = np.logical_and(mask_traj, _mask_var(ds[key], criteria[key], ds[rowsize_var_name], + traj_dim_name)) elif ds[key].dims == (obs_dim_name,): - mask_obs = np.logical_and(mask_obs, _mask_var(ds[key], criteria[key])) + mask_obs = np.logical_and(mask_obs, _mask_var(ds[key], criteria[key], ds[rowsize_var_name], + obs_dim_name)) else: raise ValueError(f"Unknown variable '{key}'.") @@ -752,7 +768,9 @@ def unpack( def _mask_var( var: xr.DataArray, - criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int], + criterion: Union[tuple, list, np.ndarray, xr.DataArray, bool, float, int, Callable], + rowsize: xr.DataArray = None, + dim_name: str = "dim_0", ) -> xr.DataArray: """Return the mask of a subset of the data matching a test criterion. 
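
The callable criterion introduced by this patch is evaluated row by row through
``apply_ragged``. A minimal usage sketch, assuming a ragged-array dataset ``ds``
whose ``time`` observation variable holds seconds (``daily_mask`` is an
illustrative name, not part of the library's API):

    from clouddrift.ragged import subset

    def daily_mask(traj_time):
        # keep one observation every 24 hours after the first one of each row
        elapsed = traj_time - traj_time[0]
        return (elapsed % (24 * 60 * 60)) == 0

    ds_daily = subset(ds, {"time": daily_mask})
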
@@ -760,11 +778,16 @@ def _mask_var( ---------- var : xr.DataArray DataArray to be subset by the criterion - criterion : array-like - The criterion can take three forms: + criterion : array-like or scalar or Callable + The criterion can take four forms: - tuple: (min, max) defining a range - list, np.ndarray, or xr.DataArray: An array-like defining multiples values - scalar: value defining a single value + - function: a function applied against each trajectory using ``apply_ragged`` and returning a mask + rowsize : xr.DataArray, optional + List of integers specifying the number of data points in each row + dim_name : str, optional + Name of the masked dimension (default is "dim_0") Examples -------- @@ -784,6 +807,12 @@ def _mask_var( array([False, False, False, True, False]) Dimensions without coordinates: dim_0 + >>> rowsize = xr.DataArray(data=[2, 3]) + >>> _mask_var(x, lambda arr: arr==arr[0]+1, rowsize, "dim_0") + + array([False, True, False, True, False]) + Dimensions without coordinates: dim_0 + Returns ------- mask : xr.DataArray @@ -794,6 +823,8 @@ def _mask_var( elif isinstance(criterion, (list, np.ndarray, xr.DataArray)): # select multiple values mask = np.isin(var, criterion) + elif callable(criterion): # mask directly created by applying `criterion` to each trajectory + mask = xr.DataArray(data=apply_ragged(criterion, var, rowsize), dims=[dim_name]).astype(bool) else: # select one specific value mask = var == criterion return mask diff --git a/tests/ragged_tests.py b/tests/ragged_tests.py index 371d7340..8f3dc14a 100644 --- a/tests/ragged_tests.py +++ b/tests/ragged_tests.py @@ -699,6 +699,13 @@ def test_subset_by_rows(self): self.assertTrue(all(ds_sub["id"] == [1, 2])) self.assertTrue(all(ds_sub["rowsize"] == [5, 4])) + def test_subset_callable(self): + func = lambda arr: ((arr - arr[0]) % 2) == 0 # test keeping obs every two time intervals + ds_sub = subset(self.ds, {"time": func}) + self.assertTrue(all(ds_sub["id"] == [1, 3, 2])) + self.assertTrue(all(ds_sub["rowsize"] == [3, 1, 2])) + self.assertTrue(all(ds_sub["time"] == [1, 3, 5, 4, 2, 4])) + class unpack_tests(unittest.TestCase): def test_unpack(self): From ab00bbd5982fa394da885f96553ac24400d59ff0 Mon Sep 17 00:00:00 2001 From: Shane Elipot Date: Tue, 19 Dec 2023 06:07:03 -0500 Subject: [PATCH 02/10] lint format --- clouddrift/ragged.py | 24 ++++++++++++++++++------ tests/ragged_tests.py | 4 +++- 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py index 7ded60e3..34a32189 100644 --- a/clouddrift/ragged.py +++ b/clouddrift/ragged.py @@ -658,11 +658,19 @@ def subset( for key in criteria.keys(): if key in ds or key in ds.dims: if ds[key].dims == (traj_dim_name,): - mask_traj = np.logical_and(mask_traj, _mask_var(ds[key], criteria[key], ds[rowsize_var_name], - traj_dim_name)) + mask_traj = np.logical_and( + mask_traj, + _mask_var( + ds[key], criteria[key], ds[rowsize_var_name], traj_dim_name + ), + ) elif ds[key].dims == (obs_dim_name,): - mask_obs = np.logical_and(mask_obs, _mask_var(ds[key], criteria[key], ds[rowsize_var_name], - obs_dim_name)) + mask_obs = np.logical_and( + mask_obs, + _mask_var( + ds[key], criteria[key], ds[rowsize_var_name], obs_dim_name + ), + ) else: raise ValueError(f"Unknown variable '{key}'.") @@ -823,8 +831,12 @@ def _mask_var( elif isinstance(criterion, (list, np.ndarray, xr.DataArray)): # select multiple values mask = np.isin(var, criterion) - elif callable(criterion): # mask directly created by applying `criterion` to each trajectory - 
mask = xr.DataArray(data=apply_ragged(criterion, var, rowsize), dims=[dim_name]).astype(bool) + elif callable( + criterion + ): # mask directly created by applying `criterion` to each trajectory + mask = xr.DataArray( + data=apply_ragged(criterion, var, rowsize), dims=[dim_name] + ).astype(bool) else: # select one specific value mask = var == criterion return mask diff --git a/tests/ragged_tests.py b/tests/ragged_tests.py index 8f3dc14a..26432c19 100644 --- a/tests/ragged_tests.py +++ b/tests/ragged_tests.py @@ -700,7 +700,9 @@ def test_subset_by_rows(self): self.assertTrue(all(ds_sub["rowsize"] == [5, 4])) def test_subset_callable(self): - func = lambda arr: ((arr - arr[0]) % 2) == 0 # test keeping obs every two time intervals + func = ( + lambda arr: ((arr - arr[0]) % 2) == 0 + ) # test keeping obs every two time intervals ds_sub = subset(self.ds, {"time": func}) self.assertTrue(all(ds_sub["id"] == [1, 3, 2])) self.assertTrue(all(ds_sub["rowsize"] == [3, 1, 2])) From 63bb625bad5b194c497c05a8fa37ec03969b988d Mon Sep 17 00:00:00 2001 From: Philippe Miron Date: Wed, 20 Dec 2023 13:41:35 -0300 Subject: [PATCH 03/10] fix for traj variables --- clouddrift/ragged.py | 21 +++++++++++++++------ tests/ragged_tests.py | 12 ++++++++++++ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py index 34a32189..35e76266 100644 --- a/clouddrift/ragged.py +++ b/clouddrift/ragged.py @@ -831,12 +831,21 @@ def _mask_var( elif isinstance(criterion, (list, np.ndarray, xr.DataArray)): # select multiple values mask = np.isin(var, criterion) - elif callable( - criterion - ): # mask directly created by applying `criterion` to each trajectory - mask = xr.DataArray( - data=apply_ragged(criterion, var, rowsize), dims=[dim_name] - ).astype(bool) + elif callable(criterion): + # mask directly created by applying `criterion` function + if len(var) == len(rowsize): + mask = criterion(var) + else: + mask = xr.DataArray( + data=apply_ragged(criterion, var, rowsize), dims=[dim_name] + ).astype(bool) + + if (len(var) == len(rowsize) and len(mask) != len(var)) or ( + len(var) == np.sum(rowsize) and len(mask) != np.sum(rowsize) + ): + raise ValueError( + "The `Callable` function needs to return a masked array that matches the length of the variable to filter." 
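+                # e.g. `lambda arr: arr > 0` is a valid criterion (it yields one
+                # boolean per element), while `lambda arr: [arr, arr]` is not
+                # (it yields two outputs per element)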
+            )
         else:  # select one specific value
             mask = var == criterion
         return mask
diff --git a/tests/ragged_tests.py b/tests/ragged_tests.py
index 26432c19..e2fd434b 100644
--- a/tests/ragged_tests.py
+++ b/tests/ragged_tests.py
@@ -708,6 +708,18 @@ def test_subset_callable(self):
         self.assertTrue(all(ds_sub["rowsize"] == [3, 1, 2]))
         self.assertTrue(all(ds_sub["time"] == [1, 3, 5, 4, 2, 4]))
 
+        func = lambda arr: arr <= 2  # keep rows whose id is less than or equal to 2
+        ds_sub = subset(self.ds, {"id": func})
+        self.assertTrue(all(ds_sub["id"] == [1, 2]))
+        self.assertTrue(all(ds_sub["rowsize"] == [5, 4]))
+
+    def test_subset_callable_wrong_dim(self):
+        func = lambda arr: [arr, arr]  # returns 2 values per element
+        with self.assertRaises(ValueError):
+            subset(self.ds, {"time": func})
+        with self.assertRaises(ValueError):
+            subset(self.ds, {"id": func})
+
 
 class unpack_tests(unittest.TestCase):
     def test_unpack(self):

From d99547ad21d3b5ef08fc394f7c33cd73a2b469cc Mon Sep 17 00:00:00 2001
From: Philippe Miron
Date: Wed, 20 Dec 2023 13:44:37 -0300
Subject: [PATCH 04/10] simplify

---
 clouddrift/ragged.py | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py
index 35e76266..2771732d 100644
--- a/clouddrift/ragged.py
+++ b/clouddrift/ragged.py
@@ -840,11 +840,9 @@ def _mask_var(
                 data=apply_ragged(criterion, var, rowsize), dims=[dim_name]
             ).astype(bool)
 
-        if (len(var) == len(rowsize) and len(mask) != len(var)) or (
-            len(var) == np.sum(rowsize) and len(mask) != np.sum(rowsize)
-        ):
+        if len(var) != len(mask):
             raise ValueError(
-                "The `Callable` function needs to return a masked array that matches the length of the variable to filter."
+                "The `Callable` function must return a boolean mask with the same length as the variable to filter."
             )
     else:  # select one specific value
         mask = var == criterion

From 676d6f4d4fb475db5c3f306628c74143326c6169 Mon Sep 17 00:00:00 2001
From: Philippe Miron
Date: Thu, 21 Dec 2023 10:35:18 -0300
Subject: [PATCH 05/10] bump version

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 73e9bda1..24b2d529 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "clouddrift"
-version = "0.29.0"
+version = "0.30.0"
 authors = [
     { name="Shane Elipot", email="selipot@miami.edu" },
     { name="Philippe Miron", email="philippemiron@gmail.com" },

From fa1421093c82ce679739d6cf47392ea16eac8ba4 Mon Sep 17 00:00:00 2001
From: Shane Elipot
Date: Thu, 21 Dec 2023 12:49:55 -0500
Subject: [PATCH 06/10] docstring edits

---
 clouddrift/ragged.py | 42 ++++++++++++++++++------------------------
 1 file changed, 18 insertions(+), 24 deletions(-)

diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py
index 2771732d..5753877a 100644
--- a/clouddrift/ragged.py
+++ b/clouddrift/ragged.py
@@ -545,22 +545,22 @@ def subset(
     obs_dim_name: str = "obs",
     full_trajectories=False,
 ) -> xr.Dataset:
-    """Subset the dataset as a function of one or many criteria. The criteria are
-    passed as a dictionary, where a variable to subset is assigned to either a
-    range (valuemin, valuemax), a list [value1, value2, valueN], a single value,
-    or a masking function applied to every trajectory using ``apply_ragged``
+    """Subset a ragged array dataset as a function of one or more criteria.
+ The criteria are passed with a dictionary, where a dictionary key + is a variable to subset and the associated dictionary value is either a range + (valuemin, valuemax), a list [value1, value2, valueN], a single value, or a + masking function applied to every row of the ragged array using ``apply_ragged``. - This function relies on specific names of the dataset dimensions and the - rowsize variables. The default expected values are listed in the Parameters - section, however, if your dataset uses different names for these dimensions - and variables, you can specify them using the optional arguments. + This function needs to know the names of the dimensions of the ragged array dataset + (traj_dim_name and obs_dim_name), and the name of the rowsize variable (rowsize_var_name). + Default values are provided for these arguments (see below), but they can be changed if needed. Parameters ---------- ds : xr.Dataset - Lagrangian dataset stored in two-dimensional or ragged array format + Dataset stored as ragged arrays criteria : dict - dictionary containing the variables and the ranges/values/functions to subset + dictionary containing the variables (as keys) and the ranges/values/functions (as values) to subset id_var_name : str, optional Name of the variable containing the ID of the trajectories (default is "id") rowsize_var_name : str, optional @@ -570,7 +570,7 @@ def subset( obs_dim_name : str, optional Name of the observation dimension (default is "obs") full_trajectories : bool, optional - If True, it returns the complete trajectories where at least one observation + If True, it returns the complete trajectories (rows) where at least one observation matches the criteria, rather than just the segments where the criteria are satisfied. Default is False. @@ -582,7 +582,8 @@ def subset( Examples -------- Criteria are combined on any data or metadata variables part of the Dataset. - The following examples are based on the GDP dataset. + The following examples are based on NOAA GDP datasets which can be accessed with the + clouddrift.datasets module. Retrieve a region, like the Gulf of Mexico, using ranges of latitude and longitude: @@ -630,18 +631,11 @@ def subset( >>> subset(ds, {"lat": (21, 31), "lon": (-98, -78), "drogue_status": True, "sst": (303.15, np.inf), "time": (np.datetime64("2000-01-01"), np.datetime64("2020-01-31"))}) - Retrieve observations every 24 hours after the first one, trajectory-wise: - - >>> def daily_masking( - >>> traj_time: xr.DataArray - >>> ) -> np.ndarray: - >>> traj_time = traj_time.astype("float64") - >>> traj_time /= 1e9 # to seconds - >>> traj_time -= traj_time[0] # start from 0 - >>> mask = (traj_time % (24*60*60)) == 0 # get only obs every 24 hours after the first one - >>> rowsize = int(mask.sum()) # the number of obs per traj has to be updated - >>> return mask, rowsize - >>> subset(ds, {"time": daily_masking}) + You can also use a function to filter the data. 
For example, retrieve every other observation + of each trajectory (row): + + >>> func = (lambda arr: ((arr - arr[0]) % 2) == 0) + >>> subset(ds, {"time": func}) Raises ------ From b44e3a51285a840743132ee9763d2f9552baf998 Mon Sep 17 00:00:00 2001 From: Shane Elipot Date: Thu, 21 Dec 2023 12:51:21 -0500 Subject: [PATCH 07/10] lint --- clouddrift/ragged.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py index 5753877a..e67bf110 100644 --- a/clouddrift/ragged.py +++ b/clouddrift/ragged.py @@ -545,13 +545,13 @@ def subset( obs_dim_name: str = "obs", full_trajectories=False, ) -> xr.Dataset: - """Subset a ragged array dataset as a function of one or more criteria. - The criteria are passed with a dictionary, where a dictionary key - is a variable to subset and the associated dictionary value is either a range - (valuemin, valuemax), a list [value1, value2, valueN], a single value, or a + """Subset a ragged array dataset as a function of one or more criteria. + The criteria are passed with a dictionary, where a dictionary key + is a variable to subset and the associated dictionary value is either a range + (valuemin, valuemax), a list [value1, value2, valueN], a single value, or a masking function applied to every row of the ragged array using ``apply_ragged``. - This function needs to know the names of the dimensions of the ragged array dataset + This function needs to know the names of the dimensions of the ragged array dataset (traj_dim_name and obs_dim_name), and the name of the rowsize variable (rowsize_var_name). Default values are provided for these arguments (see below), but they can be changed if needed. From df0ed7a9c6dec1ab78c369f8e19d56cff637ea89 Mon Sep 17 00:00:00 2001 From: Shane Elipot Date: Thu, 21 Dec 2023 13:08:43 -0500 Subject: [PATCH 08/10] backticks in docstring --- clouddrift/ragged.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py index e67bf110..c2a6e33c 100644 --- a/clouddrift/ragged.py +++ b/clouddrift/ragged.py @@ -552,7 +552,7 @@ def subset( masking function applied to every row of the ragged array using ``apply_ragged``. This function needs to know the names of the dimensions of the ragged array dataset - (traj_dim_name and obs_dim_name), and the name of the rowsize variable (rowsize_var_name). + (`traj_dim_name` and `obs_dim_name`), and the name of the rowsize variable (`rowsize_var_name`). Default values are provided for these arguments (see below), but they can be changed if needed. Parameters @@ -583,7 +583,7 @@ def subset( -------- Criteria are combined on any data or metadata variables part of the Dataset. The following examples are based on NOAA GDP datasets which can be accessed with the - clouddrift.datasets module. + ``clouddrift.datasets`` module. 
Retrieve a region, like the Gulf of Mexico, using ranges of latitude and longitude: From e66db2dfbdc25d29218e7ec923caf57b02ba96dc Mon Sep 17 00:00:00 2001 From: Vadim Bertrand Date: Thu, 28 Dec 2023 10:56:42 +0100 Subject: [PATCH 09/10] minor change: for consistency the mask is cast to DataArray after the if/else statement --- clouddrift/ragged.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/clouddrift/ragged.py b/clouddrift/ragged.py index c2a6e33c..d5dc7be6 100644 --- a/clouddrift/ragged.py +++ b/clouddrift/ragged.py @@ -830,9 +830,9 @@ def _mask_var( if len(var) == len(rowsize): mask = criterion(var) else: - mask = xr.DataArray( - data=apply_ragged(criterion, var, rowsize), dims=[dim_name] - ).astype(bool) + mask = apply_ragged(criterion, var, rowsize) + + mask = xr.DataArray(data=mask, dims=[dim_name]).astype(bool) if not len(var) == len(mask): raise ValueError( From c438051a92172313e7a7de9075e5b239e034ba9d Mon Sep 17 00:00:00 2001 From: Kevin Santana Date: Fri, 19 Jan 2024 11:51:58 -0500 Subject: [PATCH 10/10] Update gdp datasets (fix) (#354) Update gdp datasets; centralize download logic and enhance with retry mechanism. closes #353 --- clouddrift/adapters/__init__.py | 27 ++++-- clouddrift/adapters/andro.py | 4 +- clouddrift/adapters/gdp.py | 8 +- clouddrift/adapters/gdp1h.py | 45 +++++---- clouddrift/adapters/gdp6h.py | 59 ++++++------ clouddrift/adapters/glad.py | 15 +-- clouddrift/adapters/mosaic.py | 15 ++- clouddrift/adapters/subsurface_floats.py | 11 +-- clouddrift/adapters/utils.py | 111 +++++++++++++++++++++++ clouddrift/adapters/yomaha.py | 54 ++--------- clouddrift/datasets.py | 81 ++++++++++------- environment.yml | 2 + pyproject.toml | 2 + tests/datasets_tests.py | 52 +++++------ tests/ragged_tests.py | 5 +- 15 files changed, 286 insertions(+), 205 deletions(-) create mode 100644 clouddrift/adapters/utils.py diff --git a/clouddrift/adapters/__init__.py b/clouddrift/adapters/__init__.py index 7ec7cfdf..9a83363e 100644 --- a/clouddrift/adapters/__init__.py +++ b/clouddrift/adapters/__init__.py @@ -8,10 +8,23 @@ in the future. """ -import clouddrift.adapters.andro -import clouddrift.adapters.gdp1h -import clouddrift.adapters.gdp6h -import clouddrift.adapters.glad -import clouddrift.adapters.mosaic -import clouddrift.adapters.subsurface_floats -import clouddrift.adapters.yomaha +import clouddrift.adapters.andro as andro +import clouddrift.adapters.gdp1h as gdp1h +import clouddrift.adapters.gdp6h as gdp6h +import clouddrift.adapters.glad as glad +import clouddrift.adapters.mosaic as mosaic +import clouddrift.adapters.subsurface_floats as subsurface_floats +import clouddrift.adapters.yomaha as yomaha +import clouddrift.adapters.utils as utils + + +__all__ = [ + "andro", + "gdp1h", + "gdp6h", + "glad", + "mosaic", + "subsurface_floats", + "yomaha", + "utils", +] diff --git a/clouddrift/adapters/andro.py b/clouddrift/adapters/andro.py index 84572c38..937654ec 100644 --- a/clouddrift/adapters/andro.py +++ b/clouddrift/adapters/andro.py @@ -17,7 +17,7 @@ SEANOE. 
https://doi.org/10.17882/47077 """ -from clouddrift.adapters.yomaha import download_with_progress +from clouddrift.adapters.utils import download_with_progress from datetime import datetime import numpy as np import os @@ -39,7 +39,7 @@ def to_xarray(tmp_path: str = None): # get or update dataset local_file = f"{tmp_path}/{ANDRO_URL.split('/')[-1]}" - download_with_progress(ANDRO_URL, local_file) + download_with_progress([(ANDRO_URL, local_file)]) # parse with panda col_names = [ diff --git a/clouddrift/adapters/gdp.py b/clouddrift/adapters/gdp.py index d275fcd1..47a24cca 100644 --- a/clouddrift/adapters/gdp.py +++ b/clouddrift/adapters/gdp.py @@ -5,12 +5,11 @@ and six-hourly (``clouddrift.adapters.gdp6h``) GDP modules. """ +from clouddrift.adapters.utils import download_with_progress import numpy as np import os import pandas as pd import xarray as xr -import urllib.request -import warnings GDP_COORDS = [ "ids", @@ -188,10 +187,7 @@ def fetch_netcdf(url: str, file: str): file : str Name of the file to save. """ - if not os.path.isfile(file): - urllib.request.urlretrieve(url, file) - else: - warnings.warn(f"{file} already exists; skip download.") + download_with_progress([(url, file)]) def decode_date(t): diff --git a/clouddrift/adapters/gdp1h.py b/clouddrift/adapters/gdp1h.py index 09459340..acd85e17 100644 --- a/clouddrift/adapters/gdp1h.py +++ b/clouddrift/adapters/gdp1h.py @@ -6,13 +6,12 @@ import clouddrift.adapters.gdp as gdp from clouddrift.raggedarray import RaggedArray -from datetime import datetime +from clouddrift.adapters.utils import download_with_progress +from datetime import datetime, timedelta import numpy as np import urllib.request -import concurrent.futures import re import tempfile -from tqdm import tqdm from typing import Optional import os import warnings @@ -20,7 +19,7 @@ GDP_VERSION = "2.01" -GDP_DATA_URL = "https://www.aoml.noaa.gov/ftp/pub/phod/lumpkin/hourly/v2.01/netcdf/" +GDP_DATA_URL = "https://www.aoml.noaa.gov/ftp/pub/phod/buoydata/hourly_product/v2.01/" GDP_DATA_URL_EXPERIMENTAL = ( "https://www.aoml.noaa.gov/ftp/pub/phod/lumpkin/hourly/experimental/" ) @@ -108,25 +107,11 @@ def download( rng = np.random.RandomState(42) drifter_ids = sorted(rng.choice(drifter_ids, n_random_id, replace=False)) - with concurrent.futures.ThreadPoolExecutor() as executor: - # create list of urls and paths - urls = [] - files = [] - for i in drifter_ids: - file = filename_pattern.format(id=i) - urls.append(os.path.join(url, file)) - files.append(os.path.join(tmp_path, file)) - - # parallel retrieving of individual netCDF files - list( - tqdm( - executor.map(gdp.fetch_netcdf, urls, files), - total=len(files), - desc="Downloading files", - ncols=80, - ) - ) - + download_requests = [ + (os.path.join(url, file_name), os.path.join(tmp_path, file_name)) + for file_name in map(lambda d_id: filename_pattern.format(id=d_id), drifter_ids) + ] + download_with_progress(download_requests) # Download the metadata so we can order the drifter IDs by end date. gdp_metadata = gdp.get_gdp_metadata() @@ -490,6 +475,8 @@ def preprocess(index: int, **kwargs) -> xr.Dataset: "title": "Global Drifter Program hourly drifting buoy collection", "history": f"version {GDP_VERSION}. 
Metadata from dirall.dat and deplog.dat", "Conventions": "CF-1.6", + "time_coverage_start": "", + "time_coverage_end": "", "date_created": datetime.now().isoformat(), "publisher_name": "GDP Drifter DAC", "publisher_email": "aoml.dftr@noaa.gov", @@ -602,7 +589,7 @@ def to_raggedarray( else: raise ValueError(f"url must be {GDP_DATA_URL} or {GDP_DATA_URL_EXPERIMENTAL}.") - return RaggedArray.from_files( + ra = RaggedArray.from_files( indices=ids, preprocess_func=preprocess, name_coords=gdp.GDP_COORDS, @@ -612,3 +599,13 @@ def to_raggedarray( filename_pattern=filename_pattern, tmp_path=tmp_path, ) + + # set dynamic global attributes + ra.attrs_global[ + "time_coverage_start" + ] = f"{datetime(1970,1,1) + timedelta(seconds=int(np.min(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" + ra.attrs_global[ + "time_coverage_end" + ] = f"{datetime(1970,1,1) + timedelta(seconds=int(np.max(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" + + return ra diff --git a/clouddrift/adapters/gdp6h.py b/clouddrift/adapters/gdp6h.py index 6e08de3b..941deae3 100644 --- a/clouddrift/adapters/gdp6h.py +++ b/clouddrift/adapters/gdp6h.py @@ -5,21 +5,21 @@ """ import clouddrift.adapters.gdp as gdp +from clouddrift.adapters.utils import download_with_progress from clouddrift.raggedarray import RaggedArray -from datetime import datetime +from datetime import datetime, timedelta import numpy as np import urllib.request -import concurrent.futures import re import tempfile -from tqdm import tqdm from typing import Optional import os import warnings import xarray as xr +GDP_VERSION = "September 2023" -GDP_DATA_URL = "https://www.aoml.noaa.gov/ftp/pub/phod/lumpkin/netcdf/" +GDP_DATA_URL = "https://www.aoml.noaa.gov/ftp/pub/phod/buoydata/6h/" GDP_TMP_PATH = os.path.join(tempfile.gettempdir(), "clouddrift", "gdp6h") GDP_DATA = [ "lon", @@ -57,7 +57,7 @@ def download( Returns ------- out : list - List of retrived drifters + List of retrieved drifters """ print(f"Downloading GDP 6-hourly data to {tmp_path}...") @@ -65,12 +65,12 @@ def download( # Create a temporary directory if doesn't already exists. os.makedirs(tmp_path, exist_ok=True) - pattern = "drifter_[0-9]*.nc" + pattern = "drifter_6h_[0-9]*.nc" directory_list = [ - "buoydata_1_5000", - "buoydata_5001_10000", - "buoydata_10001_15000", - "buoydata_15001_oct22", + "netcdf_1_5000", + "netcdf_5001_10000", + "netcdf_10001_15000", + "netcdf_15001_current", ] # retrieve all drifter ID numbers @@ -94,25 +94,14 @@ def download( rng = np.random.RandomState(42) drifter_urls = rng.choice(drifter_urls, n_random_id, replace=False) - with concurrent.futures.ThreadPoolExecutor() as executor: - # Asynchronously download individual netCDF files - list( - tqdm( - executor.map( - gdp.fetch_netcdf, - drifter_urls, - [os.path.join(tmp_path, os.path.basename(f)) for f in drifter_urls], - ), - total=len(drifter_urls), - desc="Downloading files", - ncols=80, - ) - ) + download_with_progress( + [(url, os.path.join(tmp_path, os.path.basename(url))) for url in drifter_urls] + ) # Download the metadata so we can order the drifter IDs by end date. gdp_metadata = gdp.get_gdp_metadata() drifter_ids = [ - int(os.path.basename(f).split("_")[1].split(".")[0]) for f in drifter_urls + int(os.path.basename(f).split("_")[2].split(".")[0]) for f in drifter_urls ] return gdp.order_by_date(gdp_metadata, drifter_ids) @@ -392,9 +381,11 @@ def preprocess(index: int, **kwargs) -> xr.Dataset: # global attributes attrs = { - "title": "Global Drifter Program hourly drifting buoy collection", - "history": f"version {gdp.GDP_VERSION}. 
Metadata from dirall.dat and deplog.dat", + "title": "Global Drifter Program drifting buoy collection", + "history": f"version {GDP_VERSION}. Metadata from dirall.dat and deplog.dat", "Conventions": "CF-1.6", + "time_coverage_start": "", + "time_coverage_end": "", "date_created": datetime.now().isoformat(), "publisher_name": "GDP Drifter DAC", "publisher_email": "aoml.dftr@noaa.gov", @@ -485,13 +476,23 @@ def to_raggedarray( """ ids = download(drifter_ids, n_random_id, GDP_DATA_URL, tmp_path) - return RaggedArray.from_files( + ra = RaggedArray.from_files( indices=ids, preprocess_func=preprocess, name_coords=gdp.GDP_COORDS, name_meta=gdp.GDP_METADATA, name_data=GDP_DATA, rowsize_func=gdp.rowsize, - filename_pattern="drifter_{id}.nc", + filename_pattern="drifter_6h_{id}.nc", tmp_path=tmp_path, ) + + # update dynamic global attributes + ra.attrs_global[ + "time_coverage_start" + ] = f"{datetime(1970,1,1) + timedelta(seconds=int(np.min(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" + ra.attrs_global[ + "time_coverage_end" + ] = f"{datetime(1970,1,1) + timedelta(seconds=int(np.max(ra.coords['time']))):%Y-%m-%d:%H:%M:%SZ}" + + return ra diff --git a/clouddrift/adapters/glad.py b/clouddrift/adapters/glad.py index 96959124..2fe5cc78 100644 --- a/clouddrift/adapters/glad.py +++ b/clouddrift/adapters/glad.py @@ -13,11 +13,10 @@ --------- Özgökmen, Tamay. 2013. GLAD experiment CODE-style drifter trajectories (low-pass filtered, 15 minute interval records), northern Gulf of Mexico near DeSoto Canyon, July-October 2012. Distributed by: Gulf of Mexico Research Initiative Information and Data Cooperative (GRIIDC), Harte Research Institute, Texas A&M University–Corpus Christi. doi:10.7266/N7VD6WC8 """ -from io import StringIO +from clouddrift.adapters.utils import download_with_progress +from io import BytesIO import numpy as np import pandas as pd -import requests -import tqdm import xarray as xr @@ -27,15 +26,9 @@ def get_dataframe() -> pd.DataFrame: # GRIIDC server doesn't provide Content-Length header, so we'll hardcode # the expected data length here. 
     file_size = 155330876
-    r = requests.get(url, stream=True)
-    progress_bar = tqdm.tqdm(total=file_size, unit="iB", unit_scale=True)
-    buf = StringIO()
-    for chunk in r.iter_content(chunk_size=1024):
-        if chunk:  # filter out keep-alive new chunks
-            buf.write(chunk.decode("utf-8"))
-            progress_bar.update(len(chunk))
+    buf = BytesIO(b"")
+    download_with_progress([(url, buf)])
     buf.seek(0)
-    progress_bar.close()
     column_names = [
         "id",
         "date",
diff --git a/clouddrift/adapters/mosaic.py b/clouddrift/adapters/mosaic.py
index eb4a78ff..0f28cfb2 100644
--- a/clouddrift/adapters/mosaic.py
+++ b/clouddrift/adapters/mosaic.py
@@ -18,8 +18,8 @@
 >>> from clouddrift.adapters import mosaic
 >>> ds = mosaic.to_xarray()
 """
-from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
+from io import BytesIO
 import numpy as np
 import pandas as pd
 import requests
@@ -27,6 +27,8 @@
 import xarray as xr
 import xml.etree.ElementTree as ET
 
+from clouddrift.adapters.utils import download_with_progress
+
 
 MOSAIC_VERSION = "2022"
@@ -56,15 +58,10 @@ def get_dataframes() -> tuple[pd.DataFrame, pd.DataFrame]:
         range(len(sensor_ids)), key=lambda k: order_index[sensor_ids[k]]
     )
     sorted_data_urls = [data_urls[i] for i in sorted_indices]
+    buffers = [BytesIO(b"") for _ in sorted_data_urls]
 
-    with ThreadPoolExecutor() as executor:
-        dfs = tqdm(
-            executor.map(pd.read_csv, sorted_data_urls),
-            total=len(sorted_data_urls),
-            desc="Downloading data",
-            ncols=80,
-        )
-
+    download_with_progress(list(zip(sorted_data_urls, buffers)))
+    dfs = [pd.read_csv(b) for b in buffers]
     obs_df = pd.concat(dfs)
 
     # Use the index of the concatenated DataFrame to determine the count/rowsize
diff --git a/clouddrift/adapters/subsurface_floats.py b/clouddrift/adapters/subsurface_floats.py
index b4723307..ed34f239 100644
--- a/clouddrift/adapters/subsurface_floats.py
+++ b/clouddrift/adapters/subsurface_floats.py
@@ -17,10 +17,11 @@
 import pandas as pd
 import scipy.io
 import tempfile
-import urllib.request
 import xarray as xr
 import warnings
 
+from clouddrift.adapters.utils import download_with_progress
+
 SUBSURFACE_FLOATS_DATA_URL = (
     "https://www.aoml.noaa.gov/phod/float_traj/files/allFloats_12122017.mat"
 )
@@ -31,13 +32,7 @@ def download(file: str):
-    if not os.path.isfile(file):
-        print(
-            f"Downloading Subsurface float trajectories from {SUBSURFACE_FLOATS_DATA_URL} to {file}..."
- ) - urllib.request.urlretrieve(SUBSURFACE_FLOATS_DATA_URL, file) - else: - warnings.warn(f"{file} already exists; skip download.") + download_with_progress([(SUBSURFACE_FLOATS_DATA_URL, file)]) def to_xarray( diff --git a/clouddrift/adapters/utils.py b/clouddrift/adapters/utils.py new file mode 100644 index 00000000..e63d80e0 --- /dev/null +++ b/clouddrift/adapters/utils.py @@ -0,0 +1,111 @@ +from io import BufferedIOBase +from typing import Callable, List, NamedTuple, Union +import os +import datetime +from tqdm import tqdm +import requests +import warnings +from tenacity import ( + retry, + wait_exponential_jitter, + stop_after_attempt, + retry_if_exception, +) +import concurrent.futures + + +_CHUNK_SIZE = 1024 +_ID_FUNC = lambda x: x + + +class _DownloadRequest(NamedTuple): + src: str + dst: Union[BufferedIOBase, str] + + +def download_with_progress( + download_map: List[_DownloadRequest], prewrite_func=_ID_FUNC +): + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = { + executor.submit(_download_with_progress, src, dst, prewrite_func): src + for (src, dst) in download_map + } + for fut in concurrent.futures.as_completed(futures): + url = futures[fut] + print(f"Finished downloading: {url}") + + +@retry( + retry=retry_if_exception( + lambda ex: isinstance(ex, [requests.Timeout, requests.HTTPError]) + ), + wait=wait_exponential_jitter(initial=0.25), + stop=stop_after_attempt(10), +) +def _download_with_progress( + url: str, + output: Union[BufferedIOBase, str], + prewrite_func: Callable[[bytes], Union[str, bytes]], +): + if isinstance(output, str) and os.path.exists(output): + print(f"File exists {output} checking for updates...") + local_last_modified = os.path.getmtime(output) + + # Get last modified time of the remote file + with requests.head(url, timeout=5) as response: + if "Last-Modified" in response.headers: + remote_last_modified = datetime.strptime( + response.headers.get("Last-Modified"), "%a, %d %b %Y %H:%M:%S %Z" + ) + + # compare with local modified time + if local_last_modified >= remote_last_modified.timestamp(): + warnings.warn( + f"{output} already exists and is up to date; skip download." + ) + return False + else: + warnings.warn( + "Cannot determine the file has been updated on the remote source. \ + 'Last-Modified' header not present." 
+                )
+    print(f"Downloading from {url} to {output}...")
+
+    force_close = False
+    response = buffer = bar = None
+    try:
+        response = requests.get(url, timeout=5, stream=True)
+        if isinstance(output, str):
+            buffer = open(output, "wb")
+        else:
+            buffer = output
+        bar = tqdm(
+            desc=url,
+            total=int(response.headers.get("Content-Length", 0)),
+            unit="B",
+            unit_scale=True,
+            unit_divisor=1024,
+        )
+
+        for chunk in response.iter_content(_CHUNK_SIZE):
+            if not chunk:
+                break
+            buffer.write(prewrite_func(chunk))
+            bar.update(len(chunk))
+    except Exception as e:
+        import traceback as tb
+
+        force_close = True
+        error_msg = f"Error downloading data file: {url} to: {output}, error: {e}"
+        print(error_msg)
+        tb.print_exc()
+        raise RuntimeError(error_msg)
+    finally:
+        if response is not None:
+            response.close()
+        if buffer is not None and (not isinstance(buffer, BufferedIOBase) or force_close):
+            print(f"closing buffer {buffer}")
+            buffer.close()
+        if bar is not None:
+            bar.close()
+    return True
diff --git a/clouddrift/adapters/yomaha.py b/clouddrift/adapters/yomaha.py
index f8c38e07..f844126a 100644
--- a/clouddrift/adapters/yomaha.py
+++ b/clouddrift/adapters/yomaha.py
@@ -22,11 +22,11 @@
 import os
 import pandas as pd
 import tempfile
-from tqdm import tqdm
-import urllib.request
 import xarray as xr
 import warnings
 
+from clouddrift.adapters.utils import download_with_progress
+
 
 YOMAHA_URLS = [
     # order of the URLs is important
@@ -40,56 +40,16 @@
 YOMAHA_TMP_PATH = os.path.join(tempfile.gettempdir(), "clouddrift", "yomaha")
 
 
-def download_with_progress(url, output_file):
-    if os.path.isfile(output_file):
-        local_last_modified = os.path.getmtime(output_file)
-
-        # Get last modified time of the remote file
-        with urllib.request.urlopen(url) as response:
-            remote_last_modified = datetime.strptime(
-                response.headers.get("Last-Modified"), "%a, %d %b %Y %H:%M:%S %Z"
-            )
-            # compare with local modified time
-            if local_last_modified >= remote_last_modified.timestamp():
-                warnings.warn(
-                    f"{output_file} already exists and is up to date; skip download."
- ) - return False - - print(f"Downloading from {url} to {output_file}...") - with urllib.request.urlopen(url) as response, open( - output_file, "wb" - ) as outfile, tqdm( - desc=url, - total=int(response.headers["Content-Length"] or 0), - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as bar: - chunk_size = 1024 - while True: - chunk = response.read(chunk_size) - if not chunk: - break - outfile.write(chunk) - bar.update(len(chunk)) - return True - - def download(tmp_path: str): - for i in range(0, len(YOMAHA_URLS) - 1): - print("Downloading: " + str(YOMAHA_URLS[i])) - outfile = f"{tmp_path}/{YOMAHA_URLS[i].split('/')[-1]}" - download_with_progress(YOMAHA_URLS[i], outfile) + download_requests = [ + (url, f"{tmp_path}/{url.split('/')[-1]}") for url in YOMAHA_URLS[:-1] + ] + download_with_progress(download_requests) filename_gz = f"{tmp_path}/{YOMAHA_URLS[-1].split('/')[-1]}" filename = filename_gz[:-3] - if download_with_progress(YOMAHA_URLS[-1], filename_gz) or not os.path.isfile( - filename - ): - with open(filename_gz, "rb") as f_gz, open(filename, "wb") as f: - f.write(gzip.decompress(f_gz.read())) + download_with_progress([(YOMAHA_URLS[-1], filename)], gzip.decompress) def to_xarray(tmp_path: str = None): diff --git a/clouddrift/datasets.py b/clouddrift/datasets.py index 8f41355c..c64e59fa 100644 --- a/clouddrift/datasets.py +++ b/clouddrift/datasets.py @@ -4,8 +4,10 @@ they will be downloaded from their upstream repositories and stored for later access (~/.clouddrift for UNIX-based systems). """ +from io import BufferedReader, BytesIO from clouddrift import adapters import os +import platform import xarray as xr @@ -85,7 +87,11 @@ def gdp6h(decode_times: bool = True) -> xr.Dataset: The data is accessed from a public HTTPS server at NOAA's Atlantic Oceanographic and Meteorological Laboratory (AOML) accessible at - https://www.aoml.noaa.gov/phod/gdp/index.php. + https://www.aoml.noaa.gov/phod/gdp/index.php. It should be noted that the data loading + method is platform dependent. Linux and Darwin (macOS) machines lazy load the datasets leveraging the + byte-range feature of the netCDF-c library (dataset loading engine used by xarray). + Windows machines download the entire dataset into a memory buffer which is then passed + to xarray. Parameters ---------- @@ -105,47 +111,54 @@ def gdp6h(decode_times: bool = True) -> xr.Dataset: >>> ds = gdp6h() >>> ds - Dimensions: (traj: 26843, obs: 44544647) + Dimensions: (traj: 27647, obs: 46535470) Coordinates: - id (traj) int64 ... - time (obs) datetime64[ns] ... - lon (obs) float32 ... - lat (obs) float32 ... + ids (obs) int64 7702204 7702204 ... 300234061198840 + time (obs) float64 2.879e+08 2.879e+08 ... 1.697e+09 Dimensions without coordinates: traj, obs - Data variables: (12/44) - rowsize (traj) int32 ... - WMO (traj) int32 ... - expno (traj) int32 ... - deploy_date (traj) datetime64[ns] ... - deploy_lat (traj) float32 ... + Data variables: (12/50) + ID (traj) int64 7702204 7702201 ... 300234061198840 + rowsize (traj) int32 92 1747 1943 1385 1819 ... 54 53 51 28 + WMO (traj) int32 0 0 0 0 ... 6203890 6203888 4101885 + expno (traj) int32 40 40 40 40 ... 31412 21421 21421 31412 + deploy_date (traj) float32 2.878e+08 2.878e+08 ... 1.696e+09 nan + deploy_lat (traj) float32 -7.798 -4.9 -3.18 ... 9.9 11.9 nan ... ... - vn (obs) float32 ... - temp (obs) float32 ... - err_lat (obs) float32 ... - err_lon (obs) float32 ... - err_temp (obs) float32 ... - drogue_status (obs) bool ... 
-    Attributes: (12/16)
-        title:            Global Drifter Program six-hourly drifting buoy collec...
-        history:          Last update July 2022.  Metadata from dirall.dat and d...
-        Conventions:      CF-1.6
-        date_created:     2022-12-08T18:44:27.784441
-        publisher_name:   GDP Drifter DAC
-        publisher_email:  aoml.dftr@noaa.gov
-        ...               ...
-        contributor_name: NOAA Global Drifter Program
-        contributor_role: Data Acquisition Center
-        institution:      NOAA Atlantic Oceanographic and Meteorological Laboratory
-        acknowledgement:  Lumpkin, Rick; Centurioni, Luca (2019). NOAA Global Dr...
-        summary:          Global Drifter Program six-hourly data
-        doi:              10.25921/7ntx-z961
+        vn                     (obs) float32 nan 0.1056 0.04974 ... 0.7384 nan
+        temp                   (obs) float32 28.35 28.3 nan ... 29.08 28.97 28.92
+        err_lat                (obs) float32 0.009737 0.007097 ... 0.001659 0.001687
+        err_lon                (obs) float32 0.00614 0.004583 ... 0.002471 0.002545
+        err_temp               (obs) float32 0.08666 0.08757 ... 0.03665 0.03665
+        drogue_status          (obs) bool False False False False ... True True True
+    Attributes: (12/18)
+        title:                Global Drifter Program drifting buoy collection
+        history:              version September 2023. Metadata from dirall.dat an...
+        Conventions:          CF-1.6
+        time_coverage_start:  1979-02-15:00:00:00Z
+        time_coverage_end:    2023-10-18:18:00:00Z
+        date_created:         2023-12-22T17:50:22.242943
+        ...                   ...
+        contributor_name:     NOAA Global Drifter Program
+        contributor_role:     Data Acquisition Center
+        institution:          NOAA Atlantic Oceanographic and Meteorological Labo...
+        acknowledgement:      Lumpkin, Rick; Centurioni, Luca (2019). NOAA Global...
+        summary:              Global Drifter Program six-hourly data
+        doi:                  10.25921/7ntx-z961
 
     See Also
    --------
     :func:`gdp1h`
     """
-    url = "https://www.aoml.noaa.gov/ftp/pub/phod/buoydata/gdp_jul22_ragged_6h.nc#mode=bytes"
-    ds = xr.open_dataset(url, decode_times=decode_times)
+    url = "https://www.aoml.noaa.gov/ftp/pub/phod/buoydata/gdp6h_ragged_may23.nc#mode=bytes"
+
+    if platform.system() == "Windows":
+        buffer = BytesIO()
+        adapters.utils.download_with_progress([(url, buffer)])
+        buffer.seek(0)
+        reader = BufferedReader(buffer)
+        ds = xr.open_dataset(reader, decode_times=decode_times)
+    else:
+        ds = xr.open_dataset(url, decode_times=decode_times)
+
+    ds = ds.rename_vars({"ID": "id"}).assign_coords({"id": ds.ID}).drop_vars(["ids"])
     return ds
 
 
diff --git a/environment.yml b/environment.yml
index 8dea8e0d..95f4d4cb 100644
--- a/environment.yml
+++ b/environment.yml
@@ -6,6 +6,7 @@ dependencies:
   - numpy>=1.21.6
   - xarray>=2023.5.0
   - pandas>=1.3.4
+  - h5netcdf>=1.3.0
   - netcdf4>=1.6.4
   - pyarrow>=9.0.0
   - tqdm>=4.64.1
@@ -17,5 +18,6 @@ dependencies:
   - requests>=2.31.0
   - scipy>=1.11.2
   - zarr>=2.14.2
+  - tenacity>=8.2.3
   - pip:
     - git+https://github.com/Cloud-Drift/clouddrift.git
diff --git a/pyproject.toml b/pyproject.toml
index 24b2d529..91d70a14 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,6 +24,7 @@ dependencies = [
     "awkward>=2.0.0",
     "fsspec>=2022.3.0",
     "netcdf4>=1.6.4",
+    "h5netcdf>=1.3.0",
     "numpy>=1.22.4",
     "pandas>=1.3.4",
     "pyarrow>=8.0.0",
@@ -32,6 +33,7 @@ dependencies = [
     "scipy>=1.11.2",
     "xarray>=2023.5.0",
     "zarr>=2.14.2",
+    "tenacity>=8.2.3"
 ]
 
 [project.urls]
diff --git a/tests/datasets_tests.py b/tests/datasets_tests.py
index 68893a65..0c6935ec 100644
--- a/tests/datasets_tests.py
+++ b/tests/datasets_tests.py
@@ -10,45 +10,45 @@
 
 class datasets_tests(unittest.TestCase):
     def test_gdp1h(self):
-        ds = datasets.gdp1h()
-        self.assertTrue(ds)
+        with datasets.gdp1h() as ds:
+            self.assertTrue(ds)
 
     def test_gdp6h(self):
-        ds = datasets.gdp6h()
-        self.assertTrue(ds)
+        with datasets.gdp6h() as ds:
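+            # the context-manager form closes the dataset's underlying
+            # file handle once the assertion completes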
+ self.assertTrue(ds) def test_glad(self): - ds = datasets.glad() - self.assertTrue(ds) + with datasets.glad() as ds: + self.assertTrue(ds) def test_glad_dims_coords(self): - ds = datasets.glad() - self.assertTrue(len(ds.sizes) == 2) - self.assertTrue("obs" in ds.dims) - self.assertTrue("traj" in ds.dims) - self.assertTrue(len(ds.coords) == 2) - self.assertTrue("time" in ds.coords) - self.assertTrue("id" in ds.coords) + with datasets.glad() as ds: + self.assertTrue(len(ds.sizes) == 2) + self.assertTrue("obs" in ds.dims) + self.assertTrue("traj" in ds.dims) + self.assertTrue(len(ds.coords) == 2) + self.assertTrue("time" in ds.coords) + self.assertTrue("id" in ds.coords) def test_glad_subset_and_apply_ragged_work(self): - ds = datasets.glad() - ds_sub = subset(ds, {"id": ["CARTHE_001", "CARTHE_002"]}, id_var_name="id") - self.assertTrue(ds_sub) - mean_lon = apply_ragged(np.mean, [ds_sub.longitude], ds_sub.rowsize) - self.assertTrue(mean_lon.size == 2) + with datasets.glad() as ds: + ds_sub = subset(ds, {"id": ["CARTHE_001", "CARTHE_002"]}, id_var_name="id") + self.assertTrue(ds_sub) + mean_lon = apply_ragged(np.mean, [ds_sub.longitude], ds_sub.rowsize) + self.assertTrue(mean_lon.size == 2) def test_spotters_opens(self): - ds = datasets.spotters() - self.assertTrue(ds) + with datasets.spotters() as ds: + self.assertTrue(ds) def test_subsurface_floats_opens(self): - ds = datasets.subsurface_floats() - self.assertTrue(ds) + with datasets.subsurface_floats() as ds: + self.assertTrue(ds) def test_andro_opens(self): - ds = datasets.andro() - self.assertTrue(ds) + with datasets.andro() as ds: + self.assertTrue(ds) def test_yomaha_opens(self): - ds = datasets.yomaha() - self.assertTrue(ds) + with datasets.yomaha() as ds: + self.assertTrue(ds) diff --git a/tests/ragged_tests.py b/tests/ragged_tests.py index e2fd434b..7157a362 100644 --- a/tests/ragged_tests.py +++ b/tests/ragged_tests.py @@ -628,8 +628,9 @@ def test_combine(self): self.assertTrue(all(ds_sub.lat == [-90, -45, 10])) def test_empty(self): - ds_sub = subset(self.ds, {"id": 3, "lon": (-180, 0)}) - self.assertTrue(ds_sub.dims == {}) + with self.assertWarns(UserWarning): + ds_sub = subset(self.ds, {"id": 3, "lon": (-180, 0)}) + self.assertTrue(len(ds_sub.sizes) == 0) def test_unknown_var(self): with self.assertRaises(ValueError):
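
For reference, a sketch of how the new ``download_with_progress`` helper in
clouddrift/adapters/utils.py is meant to be driven by the adapters above (the
URLs are placeholders; each destination may be either an output path on disk or
any writable binary buffer):

    from io import BytesIO
    from clouddrift.adapters.utils import download_with_progress

    buf = BytesIO()
    download_with_progress(
        [
            ("https://example.com/a.nc", "/tmp/a.nc"),  # write to a file on disk
            ("https://example.com/b.csv", buf),  # write into an in-memory buffer
        ]
    )
    buf.seek(0)  # rewind before handing the buffer to a reader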