Skip to content

Commit

Permalink
test(GeoTIFFDataset): test wrong transform on multiband rasters
Browse files Browse the repository at this point in the history
  • Loading branch information
gionnid committed Jan 25, 2025
1 parent 1a5e0fa commit a141321
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 0 deletions.
44 changes: 44 additions & 0 deletions kedro-datasets/kedro_datasets_experimental/rioxarray/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest
import rasterio.io
import numpy as np
import xarray as xr


@pytest.fixture(scope="session")
def data_array_of_ones_from_shape():
def data_array_of_ones_from_shape(shape: tuple[int, int, int]) -> xr.DataArray:
data = xr.DataArray(
data=np.ones(shape, dtype="float32"),
dims=("band", "y", "x"),
coords={
"band": np.arange(1, shape[0] + 1),
"y": np.arange(shape[1]),
"x": np.arange(shape[2]),
},
)

data.rio.write_crs("epsg:4326", inplace=True)
data.rio.set_spatial_dims(x_dim="x", y_dim="y", inplace=True)

data = data.assign_attrs(
{f"band_{i}_description": f"band_{i}" for i in range(1, shape[0] + 1)}
)
return data

return data_array_of_ones_from_shape


@pytest.fixture(scope="session")
def one_band_raster_file_path(data_array_of_ones_from_shape):
memfile = rasterio.io.MemoryFile()
data_array_of_ones_from_shape((1, 10, 10)).rio.to_raster(memfile.name)
yield memfile.name
memfile.close()


@pytest.fixture(scope="session")
def three_band_raster_file_path(data_array_of_ones_from_shape):
memfile = rasterio.io.MemoryFile()
data_array_of_ones_from_shape((3, 10, 10)).rio.to_raster(memfile.name)
yield memfile.name
memfile.close()
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from .geotiff_dataset import GeoTIFFDataset
import numpy as np
import rasterio.io
import rioxarray as rxr
import xarray
import pytest


@pytest.mark.parametrize(
"raster_file_path",
[
"one_band_raster_file_path",
"three_band_raster_file_path",
],
)
class TestUnitGeoTIFFDataset:
@staticmethod
def test_raster_file_correctly_mocked(request, raster_file_path):
raster_file_path = request.getfixturevalue(raster_file_path)

with rasterio.open(raster_file_path) as dataset:
profile = dataset.profile.data
array = dataset.read()

assert array.shape == (profile["count"], profile["height"], profile["width"])
assert np.all(array == 1)

@staticmethod
def test_geotiffdataset_load(request, raster_file_path):
raster_file_path = request.getfixturevalue(raster_file_path)

with rasterio.open(raster_file_path) as dataset:
reference_array = dataset.read()
reference_profile = dataset.profile.data

dataset = GeoTIFFDataset(filepath=raster_file_path)
assert isinstance(dataset.load(), xarray.DataArray)

raster = dataset.load()
assert raster.shape == reference_array.shape
assert raster.dtype == reference_array.dtype

assert raster.rio.transform() == reference_profile["transform"]
assert raster.rio.crs == reference_profile["crs"]

@staticmethod
def test_geotiffdataset_save(
request, raster_file_path, data_array_of_ones_from_shape
):
raster_file_path = request.getfixturevalue(raster_file_path)
with rasterio.open(raster_file_path) as dataset:
reference_profile = dataset.profile.data

data_array = data_array_of_ones_from_shape(
(
reference_profile["count"],
reference_profile["height"],
reference_profile["width"],
)
)

with rasterio.io.MemoryFile() as memfile:
GeoTIFFDataset(filepath=memfile.name).save(data_array)

data = GeoTIFFDataset(filepath=memfile.name).load()
assert data.shape == data_array.shape
assert data.dtype == data_array.dtype
assert data.rio.transform() == data_array.rio.transform()
assert data.rio.crs == data_array.rio.crs

0 comments on commit a141321

Please sign in to comment.