Make Snap raster block more efficient.
arjanverkerk committed Oct 22, 2024
1 parent 0f3dec7 commit 4cbf064
Showing 4 changed files with 88 additions and 57 deletions.
2 changes: 1 addition & 1 deletion CHANGES.rst
@@ -4,7 +4,7 @@ Changelog of dask-geomodeling
2.5.2 (unreleased)
------------------

- Nothing changed yet.
- Make Snap raster block more efficient.


2.5.1 (2024-09-30)
105 changes: 49 additions & 56 deletions dask_geomodeling/raster/temporal.py
@@ -14,6 +14,7 @@
    get_dtype_max,
    parse_percentile_statistic,
    dtype_for_statistic,
    find_nearest,
)

from .base import RasterBlock, BaseSingle
@@ -77,7 +78,7 @@ def period(self):
    @property
    def timedelta(self):
        return self.index.timedelta

    @property
    def temporal(self):
        return self.index.temporal
@@ -105,74 +106,66 @@ def get_sources_and_requests(self, **request):

        # if any store is empty, Snap will be empty
        if store_period is None or index_period is None:
            return [(dict(snap_mode="noop"), None)]
            return [(None, None)]

        # time requests are easy: just pass them to self.index
        if request["mode"] == "time":
            return [(dict(snap_mode="noop"), None), (self.index, request)]

        start = request.get("start", index_period[1])
        stop = request.get("stop", None)
            return [(None, None), (self.index, request)]

        # query the index time
        start = request.get("start")
        stop = request.get("stop")
        index_result = self.index.get_data(mode="time", start=start, stop=stop)
        if index_result is None:
            return [(dict(snap_mode="noop"), None)]
            return [(None, None)]
        index_time = index_result["time"]

        # special case: the store has only one frame. repeat it.
        # for single frame results, query the store with the time from the index
        if stop is None:
            request["start"] = index_time[0]
            return [(None, None), (self.store, request)]

        # multiband request; knowledge of the time structure of the store near start,
        # between start and stop, and near stop is required - result frames can be
        # nearest to store frames outside the requested [start, stop] interval.
        if store_period[0] == store_period[1]:
            # new request only gets the last frame
            request["start"] = store_period[0]
            request["stop"] = None
            return [
                (dict(snap_mode="repeat", repeats=len(index_time)), None),
                (self.store, request),
            ]

        # Return a list of requests with snapped times. Times that occur more
        # than once will not be evaluated multiple times due to caching.
        requests = [(dict(snap_mode="concat"), None)]
        request["stop"] = None
        for time in index_time:
            store_time = self.store.get_data(mode="time", start=time)["time"]
            _request = request.copy()
            _request["start"] = store_time[0]
            requests.append((self.store, _request))
        return requests
            # there is only one frame in the store
            store_time = [store_period[0]]
        else:
            # obtain time near start, between start and stop, and near stop
            def get_store_time_set(start=None, stop=None):
                result = self.store.get_data(mode="time", start=start, stop=stop)
                if result is None:
                    return set()
                return set(result["time"])
            store_time = sorted(
                get_store_time_set(start=start)
                | get_store_time_set(start=start, stop=stop)
                | get_store_time_set(start=stop)
            )

        # return a request to query the store; the actual frames to pick
        # for the result (`nearest`) are passed via the `process_kwargs`
        request["start"] = store_time[0]
        request["stop"] = store_time[-1]
        nearest = find_nearest(store_time, index_time)
        process_kwargs = {"nearest": nearest}
        return [(process_kwargs, None), (self.store, request)]

    @staticmethod
    def process(process_kwargs, *args):
        if len(args) == 0:
            return None
    def process(process_kwargs, data=None):
        if process_kwargs is None:
            return data

        nearest = process_kwargs["nearest"]

        if "values" in data:
            data["values"] = data["values"][nearest]
            return data

        snap_mode = process_kwargs["snap_mode"]

        if snap_mode == "noop":
            return args[0]

        if snap_mode == "repeat":
            data = args[0]
            repeats = process_kwargs["repeats"]
            if "values" in data:
                return {
                    "values": np.repeat(data["values"], repeats, axis=0),
                    "no_data_value": data["no_data_value"],
                }
            elif "meta" in data:
                return {"meta": data["meta"] * repeats}

        # we have a bunch of single frame results that need to be concatenated
        if snap_mode == "concat":
            if any((arg is None for arg in args)):
                return None

            # combine the args
            if "values" in args[0]:
                values = np.concatenate([x["values"] for x in args], 0)
                return {"values": values, "no_data_value": args[0]["no_data_value"]}
            elif "meta" in args[0]:
                return {"meta": [x["meta"][0] for x in args]}
        if "meta" in data:
            data["meta"] = [data["meta"][i] for i in nearest]
            return data


class Shift(BaseSingle):
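The gist of the Snap change above: instead of issuing one store request per index frame and concatenating the results, get_sources_and_requests now issues a single ranged store request and hands a "nearest" index array to process, which picks the snapped frames by fancy indexing. Below is a minimal numpy sketch of that selection step; the times and values are made up for illustration, and the real blocks exchange request/response dictionaries rather than bare arrays.

import numpy as np

# hypothetical store frames at 00:00, 06:00 and 12:00; index frames at 01:00 and 11:00
store_time = np.array(["2024-01-01T00", "2024-01-01T06", "2024-01-01T12"], dtype="datetime64[h]")
index_time = np.array(["2024-01-01T01", "2024-01-01T11"], dtype="datetime64[h]")
store_values = np.arange(3 * 2 * 2).reshape(3, 2, 2)  # one 2x2 frame per store time

# find_nearest: split the time axis at the midpoints between consecutive store frames
midpoints = store_time[:-1] + (store_time[1:] - store_time[:-1]) / 2
nearest = np.searchsorted(midpoints, index_time)  # array([0, 2])

# a single ranged store request fetches store_values once; fancy indexing then selects
# (and repeats, if needed) the frame closest to each index time
snapped_values = store_values[nearest]  # shape (2, 2, 2)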
18 changes: 18 additions & 0 deletions dask_geomodeling/tests/test_utils.py
@@ -1,3 +1,4 @@
from datetime import datetime as Datetime
from unittest import mock
import unittest
import pytest
@@ -453,3 +454,20 @@ def test_rasterize_categorical_float(self):
            self.geoseries, values=pd.Series([1.2, 2.4], dtype="category"), **self.box
        )
        self.assertEqual(np.float64, raster["values"].dtype)


class TestFindNearest(unittest.TestCase):
    def test_find_nearest_one_element(self):
        self.assertEqual(utils.find_nearest([42], [43, 44, 45]).tolist(), [0, 0, 0])

    def test_find_nearest_number(self):
        array = [2, 5]
        value = [1, 2, 3, 4, 5, 6]
        expected = [0, 0, 0, 1, 1, 1]
        self.assertEqual(utils.find_nearest(array, value).tolist(), expected)

    def test_find_nearest_datetime(self):
        array = [Datetime(2001, 2, d) for d in (2, 5)]
        value = [Datetime(2001, 2, d) for d in (1, 2, 3, 4, 5, 6)]
        expected = [0, 0, 0, 1, 1, 1]
        self.assertEqual(utils.find_nearest(array, value).tolist(), expected)
20 changes: 20 additions & 0 deletions dask_geomodeling/utils.py
@@ -476,6 +476,7 @@ def _geopandas_set_srs(df_or_series, crs):
    except AttributeError: # geopandas < 0.9
        df_or_series.crs = crs


def geoseries_transform(geoseries, src_srs, dst_srs):
"""
Transform a GeoSeries to a different SRS. Returns a copy.
@@ -948,3 +949,22 @@ def dt_to_ms(dt):

def filter_none(lst):
    return [x for x in lst if x is not None]


def find_nearest(array, value):
    """
    Return the index of the nearest element of array for each element of value.
    Args:
        array: 1-D array_like, must be sorted in ascending order.
        value: array_like, values for which to find the nearest neighbour.
    """
    array = np.asarray(array)
    value = np.asarray(value)

    if array.size == 1:
        return np.zeros(value.shape, dtype=int)

    # determine midpoints in a way that works for datetimes, too
    midpoints = array[:-1] + (array[1:] - array[:-1]) / 2
    return np.searchsorted(midpoints, value)
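For reference, a quick sketch of what the new helper returns; the input numbers are arbitrary. Because np.searchsorted defaults to side="left", a value exactly halfway between two array elements (3.5 below) maps to the earlier one.

from dask_geomodeling.utils import find_nearest

print(find_nearest([2.0, 5.0, 9.0], [1.0, 3.4, 3.5, 7.1, 42.0]))  # [0 0 0 2 2]
print(find_nearest([42], [43, 44, 45]))  # [0 0 0], a single-element array always maps to index 0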
