Add dask chunk size checks to ABI l1b tests
djhoese committed Oct 31, 2023
1 parent 8b5c450 commit 14f59c4
Showing 2 changed files with 79 additions and 33 deletions.
42 changes: 31 additions & 11 deletions satpy/readers/abi_base.py
@@ -18,21 +18,22 @@
"""Advance Baseline Imager reader base class for the Level 1b and l2+ reader."""

import logging
import math
from contextlib import suppress
from datetime import datetime

import dask
import numpy as np
import xarray as xr
from pyresample import geometry

from satpy._compat import cached_property
from satpy.readers import open_file_or_filename
from satpy.readers.file_handlers import BaseFileHandler
from satpy.utils import get_legacy_chunk_size
from satpy.utils import get_dask_chunk_size_in_bytes

logger = logging.getLogger(__name__)

CHUNK_SIZE = get_legacy_chunk_size()
PLATFORM_NAMES = {
"g16": "GOES-16",
"g17": "GOES-17",
@@ -62,15 +63,8 @@ def __init__(self, filename, filename_info, filetype_info):
     @cached_property
     def nc(self):
         """Get the xarray dataset for this file."""
-        import math
-
-        from satpy.utils import get_dask_chunk_size_in_bytes
-        chunk_size_for_high_res = math.sqrt(get_dask_chunk_size_in_bytes() / 4)  # 32-bit floats
-        chunk_size_for_high_res = np.round(max(chunk_size_for_high_res / (4 * 226), 1)) * (4 * 226)
-        low_res_factor = int(self.filetype_info.get("resolution", 2000) // 500)
-        res_chunk_bytes = int(chunk_size_for_high_res / low_res_factor) * 4
-        import dask
-        with dask.config.set({"array.chunk-size": res_chunk_bytes}):
+        chunk_bytes = self._chunk_bytes_for_resolution()
+        with dask.config.set({"array.chunk-size": chunk_bytes}):
             f_obj = open_file_or_filename(self.filename)
             nc = xr.open_dataset(f_obj,
                                  decode_cf=True,
@@ -79,6 +73,32 @@ def nc(self):
         nc = self._rename_dims(nc)
         return nc
 
+    def _chunk_bytes_for_resolution(self) -> int:
+        """Get a best-guess optimal chunk size for resolution-based chunking.
+
+        First a chunk size is chosen for the provided Dask setting `array.chunk-size`
+        and then aligned with a hardcoded on-disk chunk size of 226. This is then
+        adjusted to match the current resolution.
+
+        This should result in 500 meter data having 4 times as many pixels per
+        dask array chunk (2 in each dimension) as 1km data and 8 times as many
+        as 2km data. As data is combined or upsampled geographically the arrays
+        should not need to be rechunked. Care is taken to make sure that array
+        chunks are aligned with on-disk file chunks at all resolutions, but at
+        the cost of flexibility due to a hardcoded on-disk chunk size of 226
+        elements per dimension.
+        """
+        num_high_res_elems_per_dim = math.sqrt(get_dask_chunk_size_in_bytes() / 4)  # 32-bit floats
+        # assume on-disk chunk size of 226
+        # this is true for all CSPP Geo GRB output (226 for all sectors) and full disk from other sources
+        # 250 has been seen for AWS/CLASS CONUS, Mesoscale 1, and Mesoscale 2 files
+        # we align this with 4 on-disk chunks at 500m, so it will be 2 on-disk chunks for 1km, and 1 for 2km
+        high_res_elems_disk_aligned = np.round(max(num_high_res_elems_per_dim / (4 * 226), 1)) * (4 * 226)
+        low_res_factor = int(self.filetype_info.get("resolution", 2000) // 500)
+        res_elems_per_dim = int(high_res_elems_disk_aligned / low_res_factor)
+        return (res_elems_per_dim ** 2) * 4
+
     @staticmethod
     def _rename_dims(nc):
         if "t" in nc.dims or "t" in nc.coords:
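The arithmetic in `_chunk_bytes_for_resolution` is easier to follow with concrete numbers. Below is a standalone sketch of the same computation (not part of the commit), assuming dask's default 128 MiB `array.chunk-size` and using a plain `resolution` argument in place of `self.filetype_info.get("resolution", 2000)`:

```python
import math

import numpy as np


def chunk_bytes_for_resolution(resolution: int, dask_chunk_bytes: int = 128 * 1024 ** 2) -> int:
    """Mirror the reader's chunk-size arithmetic for a resolution in meters."""
    num_high_res_elems_per_dim = math.sqrt(dask_chunk_bytes / 4)  # 32-bit floats
    # align to 4 on-disk chunks (4 * 226 = 904 elements) at 500m
    high_res_elems_disk_aligned = np.round(max(num_high_res_elems_per_dim / (4 * 226), 1)) * (4 * 226)
    low_res_factor = resolution // 500
    res_elems_per_dim = int(high_res_elems_disk_aligned / low_res_factor)
    return (res_elems_per_dim ** 2) * 4


print(chunk_bytes_for_resolution(500))   # 5424 elems/dim -> 117,679,104 bytes (~112 MiB)
print(chunk_bytes_for_resolution(1000))  # 2712 elems/dim -> 29,419,776 bytes (~28 MiB)
print(chunk_bytes_for_resolution(2000))  # 1356 elems/dim -> 7,354,944 bytes (~7 MiB)
```

Every per-dimension count is a multiple of 226 (the disk-aligned 904 at 500 m, divided by the 2x or 4x resolution factor), so dask chunk boundaries coincide with on-disk chunk boundaries at all three resolutions.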
70 changes: 48 additions & 22 deletions satpy/tests/reader_tests/test_abi_l1b.py
@@ -23,6 +23,7 @@
 from typing import Any, Callable
 from unittest import mock
 
+import dask
 import dask.array as da
 import numpy as np
 import numpy.typing as npt
@@ -36,9 +37,12 @@

 RAD_SHAPE = {
     500: (3000, 5000),  # conus - 500m
-    1000: (1500, 2500),  # conus - 1km
-    2000: (750, 1250),  # conus - 2km
 }
+# RAD_SHAPE = {
+#     500: (21696, 21696),  # fldk - 500m
+# }
+RAD_SHAPE[1000] = (RAD_SHAPE[500][0] // 2, RAD_SHAPE[500][1] // 2)
+RAD_SHAPE[2000] = (RAD_SHAPE[500][0] // 4, RAD_SHAPE[500][1] // 4)
 
 
 def _create_fake_rad_dataarray(
@@ -54,7 +58,7 @@ def _create_fake_rad_dataarray(
     rad_data = (rad_data + 1.0) / 0.5
     rad_data = rad_data.astype(np.int16)
     rad = xr.DataArray(
-        da.from_array(rad_data),
+        da.from_array(rad_data, chunks=226),
         dims=("y", "x"),
         attrs={
             "scale_factor": 0.5,
@@ -134,15 +138,21 @@ def generate_l1b_filename(chan_name: str) -> str:

 @pytest.fixture()
 def c01_refl(tmp_path) -> xr.DataArray:
-    scn = _create_scene_for_data(tmp_path, "C01", None, 1000)
-    scn.load(["C01"])
+    # 4 bytes for 32-bit floats
+    # 4 on-disk chunks for 500 meter data
+    # 226 on-disk chunk size
+    # Square (**2) for 2D size
+    with dask.config.set({"array.chunk-size": ((226 * 4) ** 2) * 4}):
+        scn = _create_scene_for_data(tmp_path, "C01", None, 1000)
+        scn.load(["C01"])
     return scn["C01"]
 
 
 @pytest.fixture()
 def c01_rad(tmp_path) -> xr.DataArray:
-    scn = _create_scene_for_data(tmp_path, "C01", None, 1000)
-    scn.load([DataQuery(name="C01", calibration="radiance")])
+    with dask.config.set({"array.chunk-size": ((226 * 4) ** 2) * 4}):
+        scn = _create_scene_for_data(tmp_path, "C01", None, 1000)
+        scn.load([DataQuery(name="C01", calibration="radiance")])
     return scn["C01"]
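Pinning `array.chunk-size` in these fixtures makes the reader's chunking deterministic regardless of the dask configuration on the test machine. Plugging the pinned value into the reader's arithmetic (a sketch of why the exact chunk assertions later in this file hold):

```python
import math

chunk_size_setting = ((226 * 4) ** 2) * 4       # 3,268,864 bytes (~3.1 MiB)

elems_500m = math.sqrt(chunk_size_setting / 4)  # 904.0 == 4 * 226, already disk-aligned
elems_1km = elems_500m / (1000 // 500)          # 452.0 -> 2 on-disk chunks per dimension
elems_2km = elems_500m / (2000 // 500)          # 226.0 -> 1 on-disk chunk per dimension
```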


@@ -153,7 +163,7 @@ def c01_rad_h5netcdf(tmp_path) -> xr.DataArray:
     rad_data = (rad_data + 1.0) / 0.5
     rad_data = rad_data.astype(np.int16)
     rad = xr.DataArray(
-        da.from_array(rad_data),
+        da.from_array(rad_data, chunks=226),
         dims=("y", "x"),
         attrs={
             "scale_factor": 0.5,
@@ -163,15 +173,17 @@ def c01_rad_h5netcdf(tmp_path) -> xr.DataArray:
"valid_range": (0, 4095),
},
)
scn = _create_scene_for_data(tmp_path, "C01", rad, 1000)
scn.load([DataQuery(name="C01", calibration="radiance")])
with dask.config.set({"array.chunk-size": ((226 * 4) ** 2) * 4}):
scn = _create_scene_for_data(tmp_path, "C01", rad, 1000)
scn.load([DataQuery(name="C01", calibration="radiance")])
return scn["C01"]


@pytest.fixture()
def c01_counts(tmp_path) -> xr.DataArray:
scn = _create_scene_for_data(tmp_path, "C01", None, 1000)
scn.load([DataQuery(name="C01", calibration="counts")])
with dask.config.set({"array.chunk-size": ((226 * 4) ** 2) * 4}):
scn = _create_scene_for_data(tmp_path, "C01", None, 1000)
scn.load([DataQuery(name="C01", calibration="counts")])
return scn["C01"]


@@ -181,14 +193,15 @@ def _load_data_array(
         clip_negative_radiances: bool = False,
     ):
         rad = _fake_c07_data()
-        scn = _create_scene_for_data(
-            tmp_path,
-            "C07",
-            rad,
-            2000,
-            {"clip_negative_radiances": clip_negative_radiances},
-        )
-        scn.load(["C07"])
+        with dask.config.set({"array.chunk-size": ((226 * 4) ** 2) * 4}):
+            scn = _create_scene_for_data(
+                tmp_path,
+                "C07",
+                rad,
+                2000,
+                {"clip_negative_radiances": clip_negative_radiances},
+            )
+            scn.load(["C07"])
         return scn["C07"]
 
     return _load_data_array
@@ -202,7 +215,7 @@ def _fake_c07_data() -> xr.DataArray:
     rad_data = (rad_data + 1.3) / 0.5
     data = rad_data.astype(np.int16)
     rad = xr.DataArray(
-        da.from_array(data),
+        da.from_array(data, chunks=226),
         dims=("y", "x"),
         attrs={
             "scale_factor": 0.5,
@@ -225,7 +238,12 @@ def _create_scene_for_data(
     filename = generate_l1b_filename(channel_name)
     data_path = tmp_path / filename
     dataset = _create_fake_rad_dataset(rad=rad, resolution=resolution)
-    dataset.to_netcdf(data_path)
+    dataset.to_netcdf(
+        data_path,
+        encoding={
+            "Rad": {"chunksizes": [226, 226]},
+        },
+    )
     scn = Scene(
         reader="abi_l1b",
         filenames=[str(data_path)],
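The `encoding` argument gives the fake files the same 226 x 226 on-disk (HDF5) chunking the reader's alignment logic assumes. A hypothetical sanity check, not part of the tests, for confirming the layout of a written file (assumes the `netCDF4` package and a `data_path` as in `_create_scene_for_data`):

```python
import netCDF4

with netCDF4.Dataset(data_path) as ds:
    # Variable.chunking() returns the per-dimension chunk sizes,
    # or the string "contiguous" if the variable is not chunked.
    assert ds.variables["Rad"].chunking() == [226, 226]
```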
@@ -236,10 +254,18 @@

 def _get_and_check_array(data_arr: xr.DataArray, exp_dtype: npt.DTypeLike) -> npt.NDArray:
     data_np = data_arr.data.compute()
+    assert isinstance(data_arr, xr.DataArray)
+    assert isinstance(data_arr.data, da.Array)
+    assert isinstance(data_np, np.ndarray)
+    res = 1000 if RAD_SHAPE[1000][0] == data_np.shape[0] else 2000
+    assert data_arr.chunks[0][0] == 226 * (4 / (res / 500))
+    assert data_arr.chunks[1][0] == 226 * (4 / (res / 500))
+
+    assert data_np.dtype == data_arr.dtype
     assert data_np.dtype == exp_dtype
     return data_np
 
 
 def _check_area(data_arr: xr.DataArray) -> None:
     from pyresample.geometry import AreaDefinition

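Spelled out, the new chunk assertions expect the leading dask chunk in each dimension to cover a whole number of 226-element on-disk chunks (plain arithmetic, not from the commit):

```python
# res = 1000: 226 * (4 / (1000 / 500)) == 452.0  (2 on-disk chunks per dask chunk)
# res = 2000: 226 * (4 / (2000 / 500)) == 226.0  (1 on-disk chunk per dask chunk)
for res, rows in ((1000, 1500), (2000, 750)):  # row counts from RAD_SHAPE
    exp_chunk = int(226 * (4 / (res / 500)))
    print(res, exp_chunk, divmod(rows, exp_chunk))  # 1000: (3, 144); 2000: (3, 72)
```

Only the first chunk per dimension is asserted because the fake CONUS grids are not exact multiples of the chunk size; for example, 1500 rows at 1 km cover three full 452-row chunks with a smaller trailing chunk.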
