Skip to content

Commit

Permalink
Incremental commit on accessor
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet committed Nov 22, 2024
1 parent 3db7212 commit 5c4c0f9
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 28 deletions.
11 changes: 7 additions & 4 deletions tests/test_dem/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,10 @@ class TestClassVsAccessorConsistency:
longyearbyen_path = examples.get_path("longyearbyen_ref_dem")

# Test common attributes
attributes = ["vcrs", "vcrs_grid", "vcrs_name"]
attributes_raster = ["crs", "transform", "nodata", "area_or_point", "res", "count", "height", "width", "footprint",
"shape", "bands", "indexes", "_is_xr", "is_loaded"]
attributes_dem = ["vcrs", "vcrs_grid", "vcrs_name"]
attributes = attributes_dem + attributes_raster

@pytest.mark.parametrize("path_dem", [longyearbyen_path]) # type: ignore
@pytest.mark.parametrize("attr", attributes) # type: ignore
Expand All @@ -103,10 +106,10 @@ def test_attributes(self, path_dem: str, attr: str) -> None:

# Get attribute for each object
output_dem = getattr(dem, attr)
output_ds = getattr(getattr(ds, "rst"), attr)
output_ds = getattr(getattr(ds, "dem"), attr)

# Assert equality
if attr != "is_xr": # Only attribute that is (purposely) not the same, but the opposite
if attr != "_is_xr": # Only attribute that is (purposely) not the same, but the opposite
assert output_equal(output_dem, output_ds)
else:
assert output_dem != output_ds
Expand Down Expand Up @@ -139,7 +142,7 @@ def test_methods(self, path_dem: str, method: str) -> None:

# Apply method for each class
output_dem = getattr(dem, method)(**args)
output_ds = getattr(getattr(ds, "rst"), method)(**args)
output_ds = getattr(getattr(ds, "dem"), method)(**args)

# Assert equality of output
assert output_equal(output_dem, output_ds)
51 changes: 34 additions & 17 deletions xdem/dem/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@
import geopandas as gpd
import numpy as np

import rasterio as rio
from affine import Affine
from geoutils import Raster
from geoutils.raster import Mask, RasterType
from pyproj import CRS
from pyproj.crs import CompoundCRS, VerticalCRS
Expand Down Expand Up @@ -62,7 +59,9 @@ class DEMBase(RasterBase):
"""

def __init__(self):
"""Initialize additional DEM metadata as None, for it to be overridden in sublasses."""
"""
Initialize additional DEM metadata as None, for it to be overridden in sublasses.
"""

super().__init__()
self._vcrs: VerticalCRS | Literal["Ellipsoid"] | None = None
Expand Down Expand Up @@ -248,7 +247,9 @@ def to_vcrs(

@copy_doc(terrain, remove_dem_res_params=True)
def slope(self, method: str = "Horn", degrees: bool = True) -> RasterType:
return terrain.slope(self, method=method, degrees=degrees)
slope = terrain.slope(dem=self.data, resolution=self.res, method=method, degrees=degrees)
return self.copy(new_array=slope)


@copy_doc(terrain, remove_dem_res_params=True)
def aspect(
Expand All @@ -257,63 +258,79 @@ def aspect(
degrees: bool = True,
) -> RasterType:

return terrain.aspect(self, method=method, degrees=degrees)
aspect = terrain.aspect(self.data, method=method, degrees=degrees)
return self.copy(new_array=aspect)

@copy_doc(terrain, remove_dem_res_params=True)
def hillshade(
self, method: str = "Horn", azimuth: float = 315.0, altitude: float = 45.0, z_factor: float = 1.0
) -> RasterType:

return terrain.hillshade(self, method=method, azimuth=azimuth, altitude=altitude, z_factor=z_factor)
hillshade = terrain.hillshade(self.data, resolution=self.res, method=method, azimuth=azimuth, altitude=altitude, z_factor=z_factor)
return self.copy(new_array=hillshade)

@copy_doc(terrain, remove_dem_res_params=True)
def curvature(self) -> RasterType:

return terrain.curvature(self)
curv = terrain.curvature(self.data, resolution=self.res)
return self.copy(new_array=curv)

@copy_doc(terrain, remove_dem_res_params=True)
def planform_curvature(self) -> RasterType:

return terrain.planform_curvature(self)
plan_curv = terrain.planform_curvature(self.data, resolution=self.res)
return self.copy(new_array=plan_curv)

@copy_doc(terrain, remove_dem_res_params=True)
def profile_curvature(self) -> RasterType:

return terrain.profile_curvature(self)
prof_curv = terrain.profile_curvature(self.data, resolution=self.res)
return self.copy(new_array=prof_curv)

@copy_doc(terrain, remove_dem_res_params=True)
def maximum_curvature(self) -> RasterType:

return terrain.maximum_curvature(self)
max_curv = terrain.maximum_curvature(self)
return self.copy(new_array=max_curv)

@copy_doc(terrain, remove_dem_res_params=True)
def topographic_position_index(self, window_size: int = 3) -> RasterType:

return terrain.topographic_position_index(self, window_size=window_size)
tpi = terrain.topographic_position_index(self, window_size=window_size)
return self.copy(new_array=tpi)

@copy_doc(terrain, remove_dem_res_params=True)
def terrain_ruggedness_index(self, method: str = "Riley", window_size: int = 3) -> RasterType:

return terrain.terrain_ruggedness_index(self, method=method, window_size=window_size)
tri = terrain.terrain_ruggedness_index(self, method=method, window_size=window_size)
return self.copy(new_array=tri)

@copy_doc(terrain, remove_dem_res_params=True)
def roughness(self, window_size: int = 3) -> RasterType:

return terrain.roughness(self, window_size=window_size)
roughness = terrain.roughness(self, window_size=window_size)
return self.copy(new_array=roughness)

@copy_doc(terrain, remove_dem_res_params=True)
def rugosity(self) -> RasterType:

return terrain.rugosity(self)
rugosity = terrain.rugosity(self)
return self.copy(new_array=rugosity)

@copy_doc(terrain, remove_dem_res_params=True)
def fractal_roughness(self, window_size: int = 13) -> RasterType:

return terrain.fractal_roughness(self, window_size=window_size)
frac_roughness = terrain.fractal_roughness(self, window_size=window_size)
return self.copy(new_array=frac_roughness)

@copy_doc(terrain, remove_dem_res_params=True)
def get_terrain_attribute(self, attribute: str | list[str], **kwargs: Any) -> RasterType | list[RasterType]:
return terrain.get_terrain_attribute(self, attribute=attribute, **kwargs)
attrs = terrain.get_terrain_attribute(self, attribute=attribute, **kwargs)

if isinstance(attrs, list):
return [self.copy(new_array=a) for a in attrs]
else:
return self.copy(new_array=attrs)

def coregister_3d(
self,
Expand Down
9 changes: 5 additions & 4 deletions xdem/dem/dem.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,12 @@
_vcrs_from_crs,
_vcrs_from_user_input,
)
from xdem.dem.base import DEMBase

dem_attrs = ["_vcrs", "_vcrs_name", "_vcrs_grid"]


class DEM(Raster): # type: ignore
class DEM(Raster, DEMBase): # type: ignore
"""
The digital elevation model.
Expand Down Expand Up @@ -76,7 +77,7 @@ def __init__(
parse_sensor_metadata: bool = False,
silent: bool = True,
downsample: int = 1,
nodata: int | float | None = None,
force_nodata: int | float | None = None,
) -> None:
"""
Instantiate a digital elevation model.
Expand All @@ -93,7 +94,7 @@ def __init__(
:param parse_sensor_metadata: Whether to parse sensor metadata from filename and similarly-named metadata files.
:param silent: Whether to display vertical reference parsing.
:param downsample: Downsample the array once loaded by a round factor. Default is no downsampling.
:param nodata: Nodata value to be used (overwrites the metadata). Default reads from metadata.
:param force_nodata: Force nodata value to be used (overwrites the metadata). Default reads from metadata.
"""

self.data: NDArrayf
Expand All @@ -116,7 +117,7 @@ def __init__(
parse_sensor_metadata=parse_sensor_metadata,
silent=silent,
downsample=downsample,
nodata=nodata,
force_nodata=force_nodata,
)

# Ensure DEM has only one band: self.bands can be None when data is not loaded through the Raster class
Expand Down
5 changes: 2 additions & 3 deletions xdem/dem/xr_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def open_dem(filename: str, **kwargs):
return ds

@xr.register_dataarray_accessor("dem")
class DEMAccessor(DEMBase, RasterAccessor):
class DEMAccessor(RasterAccessor, DEMBase):
"""
This class defines the Xarray accessor 'dem' for digital elevation models.
Expand All @@ -28,7 +28,6 @@ class DEMAccessor(DEMBase, RasterAccessor):
"""
def __init__(self, xarray_obj: xr.DataArray):

super().__init__()
super().__init__(xarray_obj=xarray_obj)

self._obj = xarray_obj
self._area_or_point = self._obj.attrs.get("AREA_OR_POINT", None)

0 comments on commit 5c4c0f9

Please sign in to comment.