Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rework generated ragged array to utilize new conventions for coordinates #374

Merged
11 changes: 6 additions & 5 deletions clouddrift/adapters/gdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
import xarray as xr

  from clouddrift.adapters.utils import download_with_progress
+ from clouddrift.raggedarray import DimNames

- GDP_COORDS = [
-     "ids",
-     "time",
+ GDP_COORDS: list[tuple[str, DimNames]] = [
+     ("id", "traj"),
+     ("time", "obs"),
  ]

GDP_METADATA = [
Expand Down Expand Up @@ -179,7 +180,7 @@
idx : list
Unique set of drifter IDs sorted by their start date.
"""
-     return df.ID[np.where(np.in1d(df.ID, idx))[0]].values
+     return df.ID[np.where(np.in1d(df.ID, idx))[0]].values  # type: ignore

Check warning on line 183 in clouddrift/adapters/gdp.py

View check run for this annotation

Codecov / codecov/patch

clouddrift/adapters/gdp.py#L183

Added line #L183 was not covered by tests


def fetch_netcdf(url: str, file: str):
Expand Down Expand Up @@ -281,7 +282,7 @@
return charar


- def drogue_presence(lost_time, time) -> bool:
+ def drogue_presence(lost_time, time) -> np.ndarray:
"""Create drogue status from the drogue lost time and the trajectory time.

Parameters
Expand Down
11 changes: 3 additions & 8 deletions clouddrift/adapters/gdp1h.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@
filelist: Sequence[str] = re.compile(pattern).findall(string) # noqa: F821
else:
filelist = [filename_pattern.format(id=did) for did in drifter_ids]
-         filelist = np.unique(filelist)
+         filelist = list(np.unique(filelist))

Check warning on line 97 in clouddrift/adapters/gdp1h.py

View check run for this annotation

Codecov / codecov/patch

clouddrift/adapters/gdp1h.py#L97

Added line #L97 was not covered by tests

# retrieve only a subset of n_random_id trajectories
if n_random_id:
Expand Down Expand Up @@ -204,7 +204,6 @@
warnings.warn(f"Variable {var} not found in upstream data; skipping.")

# new variables
-     ds["ids"] = (["traj", "obs"], [np.repeat(ds.ID.values, ds.sizes["obs"])])
ds["drogue_status"] = (
["traj", "obs"],
[gdp.drogue_presence(ds.drogue_lost_date.data, ds.time.data[0])],
Expand Down Expand Up @@ -284,10 +283,6 @@
"longitude": {"long_name": "Longitude", "units": "degrees_east"},
"latitude": {"long_name": "Latitude", "units": "degrees_north"},
"time": {"long_name": "Time", "units": "seconds since 1970-01-01 00:00:00"},
-         "ids": {
-             "long_name": "Global Drifter Program Buoy ID repeated along observations",
-             "units": "-",
-         },
"rowsize": {
"long_name": "Number of observations per trajectory",
"sample_dimension": "obs",
Expand Down Expand Up @@ -501,7 +496,7 @@
ds.attrs = attrs

# rename variables
-     ds = ds.rename_vars({"longitude": "lon", "latitude": "lat"})
+     ds = ds.rename_vars({"longitude": "lon", "latitude": "lat", "ID": "id"})

Check warning on line 499 in clouddrift/adapters/gdp1h.py

View check run for this annotation

Codecov / codecov/patch

clouddrift/adapters/gdp1h.py#L499

Added line #L499 was not covered by tests

# Cast float64 variables to float32 to reduce memory footprint.
ds = gdp.cast_float64_variables_to_float32(ds)
Expand Down Expand Up @@ -586,7 +581,7 @@
ra = RaggedArray.from_files(
indices=ids,
preprocess_func=preprocess,
-         name_coords=gdp.GDP_COORDS,
+         coord_dim_map=gdp.GDP_COORDS,
name_meta=gdp.GDP_METADATA,
name_data=GDP_DATA,
rowsize_func=gdp.rowsize,
Expand Down
18 changes: 9 additions & 9 deletions clouddrift/adapters/gdp6h.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import numpy as np
import xarray as xr
- from numpy.typing import ArrayLike

import clouddrift.adapters.gdp as gdp
from clouddrift.adapters.utils import download_with_progress
Expand Down Expand Up @@ -80,12 +79,13 @@
if drifter_ids is None:
urlpath = urllib.request.urlopen(url)
string = urlpath.read().decode("utf-8")
-         drifter_urls: ArrayLike = []
+         drifter_urls: list[str] = []

Check warning on line 82 in clouddrift/adapters/gdp6h.py

View check run for this annotation

Codecov / codecov/patch

clouddrift/adapters/gdp6h.py#L82

Added line #L82 was not covered by tests
for dir in directory_list:
urlpath = urllib.request.urlopen(os.path.join(url, dir))
string = urlpath.read().decode("utf-8")
filelist = list(set(re.compile(pattern).findall(string)))
-             drifter_urls += [os.path.join(url, dir, f) for f in filelist]
+             for f in filelist:
+                 drifter_urls.append(os.path.join(url, dir, f))

Check warning on line 88 in clouddrift/adapters/gdp6h.py

View check run for this annotation

Codecov / codecov/patch

clouddrift/adapters/gdp6h.py#L87-L88

Added lines #L87 - L88 were not covered by tests

# retrieve only a subset of n_random_id trajectories
if n_random_id:
Expand All @@ -95,7 +95,7 @@
)
else:
rng = np.random.RandomState(42)
-             drifter_urls = rng.choice(drifter_urls, n_random_id, replace=False)
+             drifter_urls = list(rng.choice(drifter_urls, n_random_id, replace=False))

Check warning on line 98 in clouddrift/adapters/gdp6h.py

View check run for this annotation

Codecov / codecov/patch

clouddrift/adapters/gdp6h.py#L98

Added line #L98 was not covered by tests

download_with_progress(
[
Expand Down Expand Up @@ -204,7 +204,7 @@
ds["BuoyTypeSensorArray"] = (("traj"), gdp.cut_str(ds.BuoyTypeSensorArray, 20))
ds["CurrentProgram"] = (
("traj"),
-         np.int32([gdp.str_to_float(ds.CurrentProgram, -1)]),
+         np.int32(gdp.str_to_float(ds.CurrentProgram, -1)),
)
ds["PurchaserFunding"] = (("traj"), gdp.cut_str(ds.PurchaserFunding, 20))
ds["SensorUpgrade"] = (("traj"), gdp.cut_str(ds.SensorUpgrade, 20))
Expand All @@ -218,16 +218,16 @@
) # remove non ascii char
ds["ManufactureYear"] = (
("traj"),
-         np.int16([gdp.str_to_float(ds.ManufactureYear, -1)]),
+         np.int16(gdp.str_to_float(ds.ManufactureYear, -1)),
)
ds["ManufactureMonth"] = (
("traj"),
-         np.int16([gdp.str_to_float(ds.ManufactureMonth, -1)]),
+         np.int16(gdp.str_to_float(ds.ManufactureMonth, -1)),
)
ds["ManufactureSensorType"] = (("traj"), gdp.cut_str(ds.ManufactureSensorType, 20))
ds["ManufactureVoltage"] = (
("traj"),
-         np.int16([gdp.str_to_float(ds.ManufactureVoltage[:-6], -1)]),
+         np.int16(gdp.str_to_float(ds.ManufactureVoltage[:-6], -1)),
) # e.g. 56 V
ds["FloatDiameter"] = (
("traj"),
Expand Down Expand Up @@ -485,7 +485,7 @@
ra = RaggedArray.from_files(
indices=ids,
preprocess_func=preprocess,
-         name_coords=gdp.GDP_COORDS,
+         coord_dim_map=gdp.GDP_COORDS,
name_meta=gdp.GDP_METADATA,
name_data=GDP_DATA,
rowsize_func=gdp.rowsize,
Expand Down
4 changes: 2 additions & 2 deletions clouddrift/adapters/subsurface_floats.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import tempfile
import warnings
from datetime import datetime
- from typing import Union
+ from typing import Hashable, List, Union

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -49,7 +49,7 @@ def to_xarray(
source_data = scipy.io.loadmat(local_file)

# metadata
-     meta_variables = [
+     meta_variables: List[Hashable] = [
"expList",
"expName",
"expOrg",
Expand Down
6 changes: 3 additions & 3 deletions clouddrift/adapters/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@

_CHUNK_SIZE = 1024
_logger = logging.getLogger(__name__)
- _standard_retry_protocol = retry(
+ _standard_retry_protocol: Callable[[WrappedFn], WrappedFn] = retry(
retry=retry_if_exception(
lambda ex: isinstance(ex, (requests.Timeout, requests.HTTPError))
),
wait=wait_exponential_jitter(initial=0.25),
stop=stop_after_attempt(10),
before=lambda rcs: _logger.debug(
-         f"calling {rcs.fn.__module__}.{rcs.fn.__name__}, attempt: {rcs.attempt_number}"
+         f"calling {str(rcs.fn)}, attempt: {rcs.attempt_number}"
),
)

Expand All @@ -42,7 +42,7 @@
if custom_retry_protocol is None:
retry_protocol = _standard_retry_protocol
else:
-         retry_protocol = custom_retry_protocol
+         retry_protocol = custom_retry_protocol  # type: ignore

Check warning on line 45 in clouddrift/adapters/utils.py

View check run for this annotation

Codecov / codecov/patch

clouddrift/adapters/utils.py#L45

Added line #L45 was not covered by tests

executor = concurrent.futures.ThreadPoolExecutor()
futures: dict[
Expand Down
Loading
Loading