Skip to content

Commit

Permalink
Automatically standardize dtypes (#117)
Browse files Browse the repository at this point in the history
  • Loading branch information
maxrjones authored Mar 20, 2024
1 parent 9711b6e commit d1ece2d
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 22 deletions.
55 changes: 44 additions & 11 deletions ndpyramid/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ def set_zarr_encoding(
codec_config: dict | None = None,
float_dtype: npt.DTypeLike | None = None,
int_dtype: npt.DTypeLike | None = None,
datetime_dtype: npt.DTypeLike | None = None,
object_dtype: npt.DTypeLike | None = None,
) -> xr.Dataset:
"""Set zarr encoding for each variable in the dataset
Expand All @@ -77,11 +79,26 @@ def set_zarr_encoding(
The default is {'id': 'zlib', 'level': 1}
float_dtype : str or dtype, optional
Dtype to cast floating point variables to
int_dtype : str or dtype, optional
Dtype to cast integer variables to
datetime_dtype : str or dtype, optional
Dtype to encode numpy.datetime64 variables as.
Time coordinates are encoded as 'int32' if cf_xarray
is able to identify the coordinates representing time,
even if `datetime_dtype` is None.
object_dtype : str or dtype, optional
Dtype to cast object variables to.
Returns
-------
ds : xr.Dataset
Output dataset with updated variable encodings
Notes
-----
The *_dtype parameters can be used to coerce variables into data types
readable by Zarr implementations in other languages.
"""
import numcodecs

Expand All @@ -93,26 +110,30 @@ def set_zarr_encoding(

time_vars = ds.cf.axes.get('T', []) + ds.cf.bounds.get('T', [])
for varname, da in ds.variables.items():
# maybe cast float type
# remove old encoding
da.encoding.clear()

# maybe cast data type
if np.issubdtype(da.dtype, np.floating) and float_dtype is not None:
da = da.astype(float_dtype)

if np.issubdtype(da.dtype, np.integer) and int_dtype is not None:
da.encoding['dtype'] = str(float_dtype)
elif np.issubdtype(da.dtype, np.integer) and int_dtype is not None:
da = da.astype(int_dtype)

# remove old encoding
da.encoding.clear()
da.encoding['dtype'] = str(int_dtype)
elif da.dtype == 'O' and object_dtype is not None:
da = da.astype(object_dtype)
da.encoding['dtype'] = str(object_dtype)
elif np.issubdtype(da.dtype, np.datetime64) and datetime_dtype is not None:
da.encoding['dtype'] = str(datetime_dtype)
elif varname in time_vars:
da.encoding['dtype'] = 'int32'

# update with new encoding
da.encoding['compressor'] = compressor
with contextlib.suppress(KeyError):
del da.attrs['_FillValue']
da.encoding['_FillValue'] = default_fillvals.get(da.dtype.str[-2:], None)

# TODO: handle date/time types
# set encoding for time and time_bnds
if varname in time_vars:
da.encoding['dtype'] = 'int32'
ds[varname] = da

return ds
Expand Down Expand Up @@ -145,6 +166,13 @@ def add_metadata_and_zarr_encoding(
-------
dt.DataTree
Updated data pyramid with metadata / encoding set
Notes
-----
The variables within the pyramid are coerced into data types readable by
`@carbonplan/maps`. See https://ndpyramid.readthedocs.io/en/latest/schema.html
for more information. Please open an issue at
https://github.com/carbonplan/ndpyramid if more flexibility is needed.
'''
chunks = {'x': pixels_per_tile, 'y': pixels_per_tile}
if other_chunks is not None:
Expand All @@ -160,7 +188,12 @@ def add_metadata_and_zarr_encoding(

# set dataset encoding
pyramid[slevel].ds = set_zarr_encoding(
pyramid[slevel].ds, codec_config={'id': 'zlib', 'level': 1}, float_dtype='float32'
pyramid[slevel].ds,
codec_config={'id': 'zlib', 'level': 1},
float_dtype='float32',
int_dtype='int32',
datetime_dtype='int32',
object_dtype='str',
)

# set global metadata
Expand Down
21 changes: 10 additions & 11 deletions notebooks/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -141,16 +141,13 @@
" .squeeze()\n",
" .reset_coords([\"band\"], drop=True)\n",
")\n",
"ds2[\"climate\"] = ds2[\"climate\"].astype(\"float32\")\n",
"ds2[\"climate\"].values[ds2[\"climate\"].values == ds2[\"climate\"].values[0, 0]] = ds1[\"climate\"].values[\n",
" 0, 0\n",
"]\n",
"ds = xr.concat([ds1, ds2], pd.Index([\"tavg\", \"prec\"], name=\"band\"))\n",
"ds[\"band\"] = ds[\"band\"].astype(\"str\")\n",
"\n",
"# create the pyramid\n",
"dt = pyramid_reproject(ds, levels=LEVELS, other_chunks={'band': 2}, clear_attrs=True)\n",
"dt.ds.attrs\n",
"\n",
"# write the pyramid to zarr\n",
"dt.to_zarr(store_3d, consolidated=True)"
Expand Down Expand Up @@ -189,11 +186,9 @@
" )\n",
" ds_all.append(ds)\n",
"ds = xr.concat(ds_all, pd.Index(months, name=\"month\"))\n",
"ds[\"month\"] = ds[\"month\"].astype(\"int32\")\n",
"\n",
"# create the pyramid\n",
"dt = pyramid_reproject(ds, levels=LEVELS, other_chunks={'month': 12}, clear_attrs=True)\n",
"dt.ds.attrs\n",
"\n",
"# write the pyramid to zarr\n",
"dt.to_zarr(store_3d_1var, consolidated=True)"
Expand Down Expand Up @@ -243,14 +238,10 @@
" )\n",
" ds2_all.append(ds)\n",
"ds2 = xr.concat(ds2_all, pd.Index(months, name=\"month\"))\n",
"ds1[\"month\"] = ds1[\"month\"].astype(\"int32\")\n",
"ds2[\"month\"] = ds2[\"month\"].astype(\"int32\")\n",
"ds2[\"climate\"] = ds2[\"climate\"].astype(\"float32\")\n",
"ds2[\"climate\"].values[ds2[\"climate\"].values == ds2[\"climate\"].values[0, 0, 0]] = ds1[\n",
" \"climate\"\n",
"].values[0, 0, 0]\n",
"ds = xr.concat([ds1, ds2], pd.Index([\"tavg\", \"prec\"], name=\"band\"))\n",
"ds[\"band\"] = ds[\"band\"].astype(\"str\")\n",
"\n",
"# create the pyramid\n",
"dt = pyramid_reproject(\n",
Expand All @@ -259,8 +250,16 @@
"dt.ds.attrs\n",
"\n",
"# write the pyramid to zarr\n",
"dt.to_zarr(store_4d, consolidated=True)"
"dt.to_zarr(store_4d, consolidated=True, mode=\"w\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "11",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -274,7 +273,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.7"
"version": "3.12.2"
}
},
"nbformat": 4,
Expand Down

0 comments on commit d1ece2d

Please sign in to comment.