Skip to content

Commit

Permalink
ENH: Allow the user to change the default resampling per product and …
Browse files Browse the repository at this point in the history
…free the `resolution` keyword in `load` and `stack` #103
  • Loading branch information
remi-braun committed Feb 14, 2025
1 parent b98ea6d commit ce6ad4b
Show file tree
Hide file tree
Showing 18 changed files with 132 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## 0.21.10 (2024-mm-dd)

- ENH: Allow the user to change the default resampling per product and free the `resolution` keyword in `load` and `stack` ([#103](https://github.com/sertit/eoreader/discussions/103))
- FIX: Fix Sentinel-2 Theia footprints when the nodata area is wider than the data area ([#201](https://github.com/sertit/eoreader/issues/201))
- FIX: Pop `driver` keyword in stack function to only use it for writing, allowing people to drive stack as COGs ([#181](https://github.com/sertit/eoreader/issues/181), [#202](https://github.com/sertit/eoreader/discussions/202))
- FIX: Simplify L7 footprint ([#198](https://github.com/sertit/eoreader/issues/198))
Expand Down
2 changes: 2 additions & 0 deletions ci/on_push/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import pytest
import xarray as xr
from rasterio.enums import Resampling
from sertit import AnyPath, path

from ci.scripts_utils import (
Expand Down Expand Up @@ -205,6 +206,7 @@ def _test_core(
pixel_size=pixel_size,
stack_path=curr_path,
clean_optical="clean",
resamplig=Resampling.bilinear,
**kwargs,
)

Expand Down
37 changes: 36 additions & 1 deletion ci/on_push/test_others.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import rasterio
import tempenv
import xarray as xr
from rasterio.enums import Resampling
from rasterio.windows import Window
from sertit import AnyPath, path, unistra

Expand Down Expand Up @@ -399,8 +400,9 @@ def test_reader_methods():

@s3_env
def test_windowed_reading():
"""Tets windowed reading"""
# Get paths
prod_path = opt_path().joinpath("LC08_L1TP_200030_20201220_20210310_02_T1.tar")
prod_path = opt_path().joinpath("LT05_L1TP_200030_20111110_20200820_02_T1")
window_path = others_path().joinpath(
"20201220T104856_L8_200030_OLI_TIRS_window.geojson"
)
Expand All @@ -421,6 +423,39 @@ def test_windowed_reading():
np.testing.assert_array_equal(red_raw.data, red_clean.data)


@s3_env
def test_custom_resamplings():
"""Test custom resamplings"""
# Get paths
prod_path = opt_path().joinpath("LT05_L1TP_200030_20111110_20200820_02_T1")
window_path = others_path().joinpath(
"20201220T104856_L8_200030_OLI_TIRS_window.geojson"
)

os.environ["EOREADER_BAND_RESAMPLING"] = str(Resampling.nearest)
prod = READER.open(prod_path, remove_tmp=True)
red_default = prod.load(RED, window=window_path, pixel_size=600)[RED]

prod.clean_tmp()
red_nearest = prod.load(
RED, window=window_path, pixel_size=600, resampling=Resampling.nearest
)[RED]

prod.clean_tmp()
red_bilinear = prod.load(
RED, window=window_path, pixel_size=600, resampling=Resampling.bilinear
)[RED]

assert red_default.shape == red_bilinear.shape == red_nearest.shape

# The arrays should be equal
np.testing.assert_array_equal(red_default.data, red_nearest.data)

# The arrays shouldn't be equal (resampling has changed)
with pytest.raises(AssertionError):
np.testing.assert_array_equal(red_default.data, red_bilinear.data)


def test_deprecation():
"""Test deprecation warning"""

Expand Down
20 changes: 20 additions & 0 deletions docs/main_features.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,26 @@ Some additional arguments can be passed to this function, please see {meth}`~eor
[`EOREADER_TILE_SIZE` environment variable](https://eoreader.readthedocs.io/latest/api/eoreader.env_vars.TILE_SIZE.html#eoreader.env_vars.TILE_SIZE).
The `TILE_SIZE` default value is 2048.

💡 By default the band will be resampled following a bilinear resampling.
To override this behaviour, modify the `EOREADER_BAND_RESAMPLING` environment variable.
Note that for discrete files such as masks, the nearest resampling is set in stone.

Available values (use the number and see rasterio's Resampling for more details and limitations):
- `nearest` = `0`
- `bilinear` = `1`
- `cubic` = `2`
- `cubic_spline` = `3`
- `lanczos` = `4`
- `average` = `5`
- `mode` = `6`
- `gauss` = `7`
- `max` = `8`
- `min` = `9`
- `med` = `10`
- `q1` = `11`
- `q3` = `12`
- `sum` = `13`
- `rms` = `14`

## Stack

Expand Down
34 changes: 34 additions & 0 deletions eoreader/env_vars.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,37 @@
Only used if :code:`EOREADER_USE_DASK` is set to 1.
Not used in case of :code:`EOREADER_USE_DASK` set as 'auto'.
"""

BAND_RESAMPLING = "EOREADER_BAND_RESAMPLING"
"""
Overrides the default resampling (bilinear) used when loading bands.
Note that for discrete files such as masks, the nearest resampling is set in stone.
Available values (use the number and see rasterio's Resampling for more details and limitations):
- nearest = 0
- bilinear = 1
- cubic = 2
- cubic_spline = 3
- lanczos = 4
- average = 5
- mode = 6
- gauss = 7
- max = 8
- min = 9
- med = 10
- q1 = 11
- q3 = 12
- sum = 13
- rms = 14
Examples:
>>> import os
>>>
>>> # Nearest
>>> os.environ["EOREADER_BAND_RESAMPLING"] = "0"
>>>
>>> # Cubic
>>> from rasterio.enums import Resampling
>>> os.environ["EOREADER_BAND_RESAMPLING"] = str(Resampling.cubic)
"""
2 changes: 1 addition & 1 deletion eoreader/products/custom_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ def _read_band(
band_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=kwargs.pop("resampling", self.band_resampling),
indexes=[self.bands[band].id],
as_type=np.float32,
**kwargs,
Expand Down
2 changes: 1 addition & 1 deletion eoreader/products/optical/hls_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,7 @@ def _read_band(
band_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=kwargs.pop("resampling", self.band_resampling),
**kwargs,
)

Expand Down
6 changes: 4 additions & 2 deletions eoreader/products/optical/landsat_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,6 +1115,8 @@ def _read_band(
Returns:
xr.DataArray: Band xarray
"""
resampling = kwargs.pop("resampling", self.band_resampling)

if self.is_archived:
filename = path.get_filename(str(band_path).split("!")[-1])
else:
Expand All @@ -1131,13 +1133,13 @@ def _read_band(
**kwargs,
)
else:
# Read band (call superclass generic method)
# Read band
band_arr = utils.read(
band_path,
pixel_size=pixel_size,
size=size,
as_type=np.float32,
resampling=Resampling.bilinear,
resampling=resampling,
**kwargs,
)

Expand Down
5 changes: 3 additions & 2 deletions eoreader/products/optical/planet_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,7 @@ def _read_band(
xr.DataArray: Band xarray
"""
resampling = kwargs.pop("resampling", self.band_resampling)
with rasterio.open(str(band_path)) as dst:
# Manage the case if we open a simple band (EOReader processed bands)
if dst.count == 1:
Expand All @@ -468,7 +469,7 @@ def _read_band(
band_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=resampling,
**kwargs,
)

Expand All @@ -478,7 +479,7 @@ def _read_band(
band_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=resampling,
indexes=[self.bands[band].id],
**kwargs,
)
Expand Down
2 changes: 1 addition & 1 deletion eoreader/products/optical/s2_e84_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def _read_band(
band_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=kwargs.pop("resampling", self.band_resampling),
**kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion eoreader/products/optical/s2_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ def _read_band(
geocoded_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=kwargs.pop("resampling", self.band_resampling),
**kwargs,
)

Expand Down
2 changes: 1 addition & 1 deletion eoreader/products/optical/s2_theia_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def _read_band(
band_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=kwargs.pop("resampling", self.band_resampling),
**kwargs,
)

Expand Down
3 changes: 1 addition & 2 deletions eoreader/products/optical/s3_olci_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,8 +575,7 @@ def _rad_2_refl(
utils.write(sza_nc, sza_path)

with rasterio.open(sza_path) as ds_sza:
# Values can be easily interpolated at pixels from Tie Points by linear interpolation using the
# image column coordinate.
# Values can be easily interpolated at pixels from Tie Points by linear interpolation using the image column coordinate.
sza, _ = rasters_rio.read(
ds_sza,
size=(band_arr.rio.width, band_arr.rio.height),
Expand Down
2 changes: 1 addition & 1 deletion eoreader/products/optical/s3_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ def _read_band(
band_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=kwargs.pop("resampling", self.band_resampling),
**kwargs,
)

Expand Down
16 changes: 10 additions & 6 deletions eoreader/products/optical/vhr_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,13 @@ def _reproject(
if src_xda.rio.crs is None:
src_xda.rio.write_crs(self._get_raw_crs(), inplace=True)

resampling = kwargs.pop("resampling", self.band_resampling)

try:
out_xda = src_xda.rio.reproject(
dst_crs=self.crs(),
resolution=self.pixel_size,
resampling=Resampling.bilinear,
resampling=resampling,
nodata=self._raw_nodata,
num_threads=utils.get_max_cores(),
rpcs=rpcs,
Expand Down Expand Up @@ -354,7 +356,7 @@ def _reproject(
# out_xda = src_xda.odc.reproject(
# how=self.crs(),
# resolution=self.pixel_size,
# resampling=Resampling.bilinear,
# resampling=kwargs.pop("resampling", self.band_resampling),
# dst_nodata=self._raw_nodata,
# num_threads=utils.get_max_cores(),
# rpcs=rpcs,
Expand Down Expand Up @@ -382,7 +384,7 @@ def _reproject(
dst_resolution=self.pixel_size,
dst_nodata=self._raw_nodata, # input data should be in integer
num_threads=utils.get_max_cores(),
resampling=Resampling.bilinear,
resampling=resampling,
**kwargs,
)
# Get dims
Expand Down Expand Up @@ -434,6 +436,8 @@ def _read_band(
Returns:
xr.DataArray: Band xarray
"""
resampling = kwargs.pop("resampling", self.band_resampling)

with rasterio.open(str(band_path)) as dst:
dst_crs = dst.crs

Expand Down Expand Up @@ -463,7 +467,7 @@ def _read_band(
reproj_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=resampling,
indexes=[self.bands[band].id],
**kwargs,
)
Expand All @@ -475,7 +479,7 @@ def _read_band(
band_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=resampling,
**kwargs,
)

Expand All @@ -486,7 +490,7 @@ def _read_band(
band_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=resampling,
indexes=[self.bands[band].id],
**kwargs,
)
Expand Down
3 changes: 3 additions & 0 deletions eoreader/products/product.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ def __init__(
self.is_stacked = False
"""True if the bands are stacked (like for VHR data)."""

self.band_resampling = utils.get_band_resampling()
"""Band resampling (default: bilinear). Overriden by the env variable "EOREADER_BAND_RESAMPLING", if existing and valid."""

self._stac = None

# Manage output
Expand Down
2 changes: 1 addition & 1 deletion eoreader/products/sar/sar_product.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,7 @@ def _read_band(
band_path,
pixel_size=pixel_size,
size=size,
resampling=Resampling.bilinear,
resampling=kwargs.pop("resampling", self.band_resampling),
as_type=np.float32,
**kwargs,
)
Expand Down
12 changes: 11 additions & 1 deletion eoreader/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@

from eoreader import EOREADER_NAME, cache
from eoreader.bands import is_index, is_sat_band
from eoreader.env_vars import NOF_BANDS_IN_CHUNKS, TILE_SIZE, USE_DASK
from eoreader.env_vars import NOF_BANDS_IN_CHUNKS, TILE_SIZE, USE_DASK, BAND_RESAMPLING
from eoreader.exceptions import InvalidProductError
from eoreader.keywords import _prune_keywords

Expand Down Expand Up @@ -616,3 +616,13 @@ def get_archived_rio_path(
def is_uint16(band_arr: xr.DataArray):
"""Is this array saved as uint16 on disk?"""
return band_arr.encoding.get("dtype") in ["uint16", np.uint16]


def get_band_resampling():
"""Overrides the default band resampling (bilinear) with the env variable "EOREADER_BAND_RESAMPLING", if existing and valid"""
resampling = Resampling.bilinear
with contextlib.suppress(ValueError, TypeError):
resampling = Resampling(int(os.getenv(BAND_RESAMPLING)))
LOGGER.debug(f"Band resampling overriden to '{resampling.name}'.")

return resampling

0 comments on commit ce6ad4b

Please sign in to comment.