From 99d166c00ba8057e6e467b9570fd57e325cf2e04 Mon Sep 17 00:00:00 2001 From: Kevin Santana Date: Sun, 29 Sep 2024 14:50:25 -0700 Subject: [PATCH] =?UTF-8?q?=E2=9A=A1=20use=20dask=20to=20improve=20memory?= =?UTF-8?q?=20and=20compute=20scalability=20*=20also=20remove?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- clouddrift/adapters/gdp/gdpsource.py | 256 ++++++----------------- clouddrift/datasets.py | 6 +- environment.yml | 3 +- pyproject.toml | 3 +- tests/adapters/gdp/source_integ_tests.py | 2 +- 5 files changed, 66 insertions(+), 204 deletions(-) diff --git a/clouddrift/adapters/gdp/gdpsource.py b/clouddrift/adapters/gdp/gdpsource.py index 595f533f..f6d7887c 100644 --- a/clouddrift/adapters/gdp/gdpsource.py +++ b/clouddrift/adapters/gdp/gdpsource.py @@ -1,19 +1,18 @@ from __future__ import annotations -import asyncio import datetime import logging import os import tempfile import warnings -from collections import defaultdict -from concurrent.futures import Future, ProcessPoolExecutor, as_completed from typing import Callable +import dask.dataframe as dd import numpy as np import pandas as pd import xarray as xr -from tqdm.asyncio import tqdm +from dask import delayed +from tqdm import tqdm from clouddrift.adapters.gdp import get_gdp_metadata from clouddrift.adapters.utils import download_with_progress @@ -25,12 +24,11 @@ _FILENAME_TEMPLATE = "buoydata_{start}_{end}_{suffix}.dat.gz" _SECONDS_IN_DAY = 86_400 -_COORDS = ["id", "obs_index"] +_COORDS = ["id", "position_datetime"] _DATA_VARS = [ "latitude", "longitude", - "position_datetime", "sensor_datetime", "drogue", "sst", @@ -108,15 +106,15 @@ _INPUT_COLS_DTYPES = { "id": np.int64, - "posObsMonth": np.int8, + "posObsMonth": np.float32, "posObsDay": np.float64, - "posObsYear": np.int16, + "posObsYear": np.float32, "latitude": np.float32, "longitude": np.float32, "qualityIndex": np.float32, - "senObsMonth": np.int8, + "senObsMonth": np.float32, "senObsDay": np.float64, - "senObsYear": np.int16, + "senObsYear": np.float32, "drogue": np.float32, "sst": np.float32, "voltage": np.float32, @@ -125,6 +123,14 @@ "sensor6": np.float32, } +_INPUT_COLS_PREFILTER_DTYPES: dict[str, type[object]] = { + "posObsMonth": np.str_, + "posObsYear": np.float64, + "senObsMonth": np.str_, + "senObsYear": np.float64, + "drogue": np.str_, +} + VARS_ATTRS: dict = { "id": {"long_name": "Global Drifter Program Buoy ID", "units": "-"}, @@ -323,9 +329,9 @@ def _preprocess(id_, **kwargs) -> xr.Dataset: coords = { "id": (["traj"], np.array([id_]).astype(np.int64)), - "obs_index": ( + "position_datetime": ( ["obs"], - traj_data_df[["obs_index"]].values.flatten().astype(np.int32), + traj_data_df[["position_datetime"]].values.flatten().astype(np.datetime64), ), } @@ -374,20 +380,18 @@ def _parse_datetime_with_day_ratio( return np.array(values).astype("datetime64[ns]") -def _process_chunk( - df_chunk: pd.DataFrame, - start_idx: int, - end_idx: int, +def _process( + df: dd.DataFrame, gdp_metadata_df: pd.DataFrame, use_fill_values: bool, -) -> dict[int, xr.Dataset]: +) -> xr.Dataset: """Process each dataframe chunk. Return a dictionary mapping each drifter to a unique xarray Dataset.""" # Transform the initial dataframe filtering out rows with really anomolous values # examples include: years in the future, years way in the past before GDP program, etc... 
- preremove_df_chunk = df_chunk.assign(obs_index=range(start_idx, end_idx)) + preremove_df = df.compute() df_chunk = _apply_remove( - preremove_df_chunk, + preremove_df, filters=[ # Filter out year values that are in the future or predating the GDP program lambda df: (df["posObsYear"] > datetime.datetime.now().year) @@ -405,7 +409,7 @@ def _process_chunk( drifter_ds_map = dict[int, xr.Dataset]() - preremove_len = len(preremove_df_chunk) + preremove_len = len(preremove_df) postremove_len = len(df_chunk) if preremove_len != postremove_len: @@ -455,153 +459,16 @@ def _process_chunk( md_df=gdp_metadata_df, data_df=df_chunk, use_fill_values=use_fill_values, - tqdm=dict(disable=True), - ) - ds = ra.to_xarray() - - for id_ in ids_with_md: - id_f_ds = subset(ds, dict(id=id_), row_dim_name="traj") - drifter_ds_map[id_] = id_f_ds - return drifter_ds_map - - -def _combine_chunked_drifter_datasets(datasets: list[xr.Dataset]) -> xr.Dataset: - """Combines several drifter observations found in separate chunks, ordering them - by the observations row index. - """ - traj_dataset = xr.concat( - datasets, dim="obs", coords="minimal", data_vars=_DATA_VARS, compat="override" + tqdm={"disable": True}, ) + return ra.to_xarray() - new_rowsize = sum([ds.rowsize.values[0] for ds in datasets]) - traj_dataset["rowsize"] = xr.DataArray( - np.array([new_rowsize], dtype=np.int64), coords=traj_dataset["rowsize"].coords - ) - - sort_coord = traj_dataset.coords["obs_index"] - vals: np.ndarray = sort_coord.data - sort_coord_dim = sort_coord.dims[-1] - sort_key = vals.argsort() - - for coord_name in _COORDS: - coord = traj_dataset.coords[coord_name] - dim = coord.dims[-1] - - if dim == sort_coord_dim: - sorted_coord = coord.isel({dim: sort_key}) - traj_dataset.coords[coord_name] = sorted_coord - - for varname in _DATA_VARS: - var = traj_dataset[varname] - dim = var.dims[-1] - sorted_var = var.isel({dim: sort_key}) - traj_dataset[varname] = sorted_var - - return traj_dataset - - -async def _parallel_get( - sources: list[str], - gdp_metadata_df: pd.DataFrame, - chunk_size: int, - tmp_path: str, - use_fill_values: bool, - max_chunks: int | None, -) -> list[xr.Dataset]: - """Parallel process dataset in chunks leveraging multiprocessing.""" - max_workers = (os.cpu_count() or 0) // 2 - with ProcessPoolExecutor(max_workers=max_workers) as ppe: - drifter_chunked_datasets: dict[int, list[xr.Dataset]] = defaultdict(list) - start_idx = 0 - for fp in tqdm( - sources, - desc="Loading files", - unit="file", - ncols=80, - total=len(sources), - position=0, - ): - file_chunks = pd.read_csv( - fp, - sep=r"\s+", - header=None, - names=_INPUT_COLS, - engine="c", - compression="gzip", - chunksize=chunk_size, - ) - - joblist = list[Future]() - jobmap = dict[Future, pd.DataFrame]() - for idx, chunk in enumerate(file_chunks): - if max_chunks is not None and idx >= max_chunks: - break - ajob = ppe.submit( - _process_chunk, - chunk, - start_idx, - start_idx + len(chunk), - gdp_metadata_df, - use_fill_values, - ) - start_idx += len(chunk) - jobmap[ajob] = chunk - joblist.append(ajob) - - bar = tqdm( - desc="Processing file chunks", - unit="chunk", - ncols=80, - total=len(joblist), - position=1, - ) - - for ajob in as_completed(jobmap.keys()): - if (exc := ajob.exception()) is not None: - chunk = jobmap[ajob] - _logger.warn(f"bad chunk detected, exception: {ajob.exception()}") - raise exc - - job_drifter_ds_map: dict[int, xr.Dataset] = ajob.result() - for id_ in job_drifter_ds_map.keys(): - drifter_ds = job_drifter_ds_map[id_] - 
drifter_chunked_datasets[id_].append(drifter_ds) - bar.update() - - combine_jobmap = dict[Future, int]() - for id_ in drifter_chunked_datasets.keys(): - datasets = drifter_chunked_datasets[id_] - - combine_job = ppe.submit(_combine_chunked_drifter_datasets, datasets) - combine_jobmap[combine_job] = id_ - - bar.close() - bar = tqdm( - desc="merging drifter chunks", - unit="drifter", - ncols=80, - total=len(drifter_chunked_datasets.keys()), - position=2, - ) - - os.makedirs(os.path.join(tmp_path, "drifters"), exist_ok=True) - - drifter_datasets = list[xr.Dataset]() - for combine_job in as_completed(combine_jobmap.keys()): - dataset: xr.Dataset = combine_job.result() - drifter_datasets.append(dataset) - bar.update() - bar.close() - return drifter_datasets def to_raggedarray( tmp_path: str = _TMP_PATH, - skip_download: bool = False, max: int | None = None, - chunk_size: int = 100_000, use_fill_values: bool = True, - max_chunks: int | None = None, ) -> xr.Dataset: """Get the GDP source dataset.""" @@ -611,52 +478,49 @@ def to_raggedarray( # Filter down for testing purposes. if max: - requests = [requests[max]] + requests = requests[:max] # Download necessary data and metadata files. - if not skip_download: - download_with_progress(requests) + download_with_progress(requests) gdp_metadata_df = get_gdp_metadata(tmp_path) - # Run async process to parallelize data processing. - drifter_datasets = asyncio.run( - _parallel_get( - [dst for (_, dst) in requests], - gdp_metadata_df, - chunk_size, - tmp_path, - use_fill_values, - max_chunks, - ) + import gzip + + data_files = list() + for compressed_data_file in tqdm( + [dst for (_, dst) in requests], desc="Decompressing files", unit="file" + ): + decompressed_fp = compressed_data_file[:-3] + data_files.append(decompressed_fp) + if not os.path.exists(decompressed_fp): + with ( + gzip.open(compressed_data_file, "rb") as compr, + open(decompressed_fp, "wb") as decompr, + ): + decompr.write(compr.read()) + + wanted_dtypes = dict() + wanted_dtypes.update(_INPUT_COLS_DTYPES) + wanted_dtypes.update(_INPUT_COLS_PREFILTER_DTYPES) + + df: dd.DataFrame = dd.read_csv( + data_files, + sep=r"\s+", + header=None, + names=_INPUT_COLS, + dtype=wanted_dtypes, + engine="c", + blocksize="1GB", + assume_missing=True, ) + ds = _process(df, gdp_metadata_df, use_fill_values) # Sort the drifters by their start date. - deploy_date_id_map = { - ds["id"].data[0]: ds["start_date"].data[0] for ds in drifter_datasets - } - deploy_date_sort_key = np.argsort(list(deploy_date_id_map.values())) - sorted_drifter_datasets = [drifter_datasets[idx] for idx in deploy_date_sort_key] - - # Concatenate drifter data and metadata variables separately. - obs_ds = xr.concat( - [ds.drop_dims("traj") for ds in sorted_drifter_datasets], - dim="obs", - data_vars=_DATA_VARS, - ) - traj_ds = xr.concat( - [ds.drop_dims("obs") for ds in sorted_drifter_datasets], - dim="traj", - data_vars=_METADATA_VARS, - ) - - # Merge the separate datasets. - agg_ds = xr.merge([obs_ds, traj_ds]) - # Add variable metadata. 
for var_name in _DATA_VARS + _METADATA_VARS: if var_name in VARS_ATTRS.keys(): - agg_ds[var_name].attrs = VARS_ATTRS[var_name] - agg_ds.attrs = ATTRS + ds[var_name].attrs = VARS_ATTRS[var_name] + ds.attrs = ATTRS - return agg_ds + return ds diff --git a/clouddrift/datasets.py b/clouddrift/datasets.py index 36864ef1..b431eb08 100644 --- a/clouddrift/datasets.py +++ b/clouddrift/datasets.py @@ -156,7 +156,6 @@ def gdp6h(decode_times: bool = True) -> xr.Dataset: def gdp_source( tmp_path: str = adapters.gdp_source._TMP_PATH, max: int | None = None, - skip_download: bool = False, use_fill_values: bool = True, decode_times: bool = True, ) -> xr.Dataset: @@ -178,9 +177,6 @@ def gdp_source( max: int, optional Maximum number of files to retrieve and parse to generate the aggregate file. Mainly used for testing purposes. - skip_download: bool, False (default) - If True, skips downloading the data files and the code assumes the files have already been downloaded. - This is mainly used to skip downloading files if the remote doesn't provide the HTTP Last-Modified header. use_fill_values: bool, True (default) When True, missing metadata fields are replaced with fill values. When False and no metadata is found for a given drifter its observations are ignored. @@ -225,7 +221,7 @@ def gdp_source( f"gdpsource_agg_{file_selection_label}.zarr", decode_times, lambda: adapters.gdp_source.to_raggedarray( - tmp_path, skip_download, max, use_fill_values=use_fill_values + tmp_path, max, use_fill_values=use_fill_values ), ) diff --git a/environment.yml b/environment.yml index a8a3dccd..6e209e5e 100644 --- a/environment.yml +++ b/environment.yml @@ -5,7 +5,7 @@ dependencies: - python>=3.10 - numpy>=1.21.6 - xarray>=2023.5.0 - - pandas>=1.3.4 + - pandas>=2.0.0 - h5netcdf>=1.3.0 - netcdf4>=1.6.4 - pyarrow>=9.0.0 @@ -19,3 +19,4 @@ dependencies: - scipy>=1.11.2 - zarr>=2.14.2 - tenacity>=8.2.3 + - dask>=2024.5.0 diff --git a/pyproject.toml b/pyproject.toml index eb966240..78c0c422 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,8 @@ dependencies = [ "scipy>=1.11.2", "xarray>=2023.5.0", "zarr>=2.14.2", - "tenacity>=8.2.3" + "tenacity>=8.2.3", + "dask>=2024.5.0" ] [project.optional-dependencies] diff --git a/tests/adapters/gdp/source_integ_tests.py b/tests/adapters/gdp/source_integ_tests.py index 7d91bb2d..cbcae203 100644 --- a/tests/adapters/gdp/source_integ_tests.py +++ b/tests/adapters/gdp/source_integ_tests.py @@ -19,7 +19,7 @@ def test_load_and_create_aggregate(self): may contain rows for one drifter (chunking can split a drifters trajectory) we join these partitioned segments in parallel per drifter. """ - ds = gdp_source.to_raggedarray(max=1, chunk_size=1_000, max_chunks=100) + ds = gdp_source.to_raggedarray(max=1) assert ds is not None all_drifter_obs_index = unpack(ds["obs_index"].data, ds["rowsize"])
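
For context, the patch replaces the ProcessPoolExecutor/chunked-pandas pipeline with a lazily partitioned dask read followed by a single compute in _process. The minimal sketch below illustrates that read pattern only; it is not code from the adapter, and the file name, the six-column subset, the blocksize, and the year bounds are assumptions chosen for demonstration.

import datetime

import dask.dataframe as dd
import numpy as np

# Illustrative column subset and dtypes, in the spirit of _INPUT_COLS_DTYPES;
# a real GDP source file has more columns than this.
cols = ["id", "posObsMonth", "posObsDay", "posObsYear", "latitude", "longitude"]
dtypes = {
    "id": np.int64,
    "posObsMonth": np.float32,
    "posObsDay": np.float64,
    "posObsYear": np.float32,
    "latitude": np.float32,
    "longitude": np.float32,
}

# Hypothetical decompressed data file containing exactly the columns above.
data_files = ["buoydata_sample.dat"]

# dask splits each file into partitions of roughly `blocksize` bytes instead of
# loading it whole, which is where the memory headroom over the removed
# chunksize-based pandas reader comes from.
df = dd.read_csv(
    data_files,
    sep=r"\s+",
    header=None,
    names=cols,
    dtype=dtypes,
    blocksize="64MB",
    assume_missing=True,
)

# Row filters stay lazy; only .compute() materializes a pandas DataFrame,
# analogous to what _process does before building the ragged array.
# The bounds here are illustrative (GDP deployments began in 1979).
this_year = datetime.datetime.now().year
valid = df[(df["posObsYear"] >= 1979) & (df["posObsYear"] <= this_year)]
pdf = valid.compute()
print(pdf.groupby("id").size())

One trade-off worth noting: because _process calls df.compute(), the filtered data is still eventually held in memory as a single pandas DataFrame, so the scalability gain comes mainly from the partitioned read and dtype coercion rather than from the final aggregation step.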