Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Initial shift Capability to DEM Coregistration #650

Merged
merged 5 commits into from
Dec 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions doc/source/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ To build and pass your coregistration pipeline to {func}`~xdem.DEM.coregister_3d
coreg.Coreg.meta
```

#### Quick coregistration
```{eval-rst}
.. autosummary::
:toctree: gen_modules/

coreg.workflows.dem_coregistration
```

### Affine coregistration

#### Parent object (to define custom methods)
Expand Down
9 changes: 8 additions & 1 deletion doc/source/coregistration.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ my_coreg_pipeline = xdem.coreg.ICP() + xdem.coreg.NuthKaab()
my_coreg_pipeline = xdem.coreg.NuthKaab()
```

Then, coregistering a pair of elevation data can be done by calling {func}`xdem.DEM.coregister_3d` from the DEM that should be aligned.
Then, coregistering a pair of elevation data can be done by calling {func}`xdem.coreg.workflows.dem_coregistration`, or
{func}`xdem.DEM.coregister_3d` from the DEM that should be aligned.

```{code-cell} ipython3
:tags: [hide-cell]
Expand All @@ -66,12 +67,18 @@ Then, coregistering a pair of elevation data can be done by calling {func}`xdem.
import geoutils as gu
import numpy as np
import matplotlib.pyplot as plt
from xdem.coreg.workflows import dem_coregistration

# Open a reference and to-be-aligned DEM
ref_dem = xdem.DEM(xdem.examples.get_path("longyearbyen_ref_dem"))
tba_dem = xdem.DEM(xdem.examples.get_path("longyearbyen_tba_dem"))
```

```{code-cell} ipython3
# Coregister by calling the dem_coregistration function
aligned_dem = dem_coregistration(tba_dem, ref_dem, coreg_method=my_coreg_pipeline)
```

```{code-cell} ipython3
# Coregister by calling the DEM method
aligned_dem = tba_dem.coregister_3d(ref_dem, my_coreg_pipeline)
Expand Down
82 changes: 78 additions & 4 deletions tests/test_coreg/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,83 @@ def test_dem_coregistration(self) -> None:
out_fig.close()

# Testing different coreg method
dem_coreg, coreg_method, coreg_stats, inlier_mask = dem_coregistration(
dem_coreg2, coreg_method2, coreg_stats2, inlier_mask2 = dem_coregistration(
tba_dem, ref_dem, coreg_method=xdem.coreg.Deramp()
)
assert isinstance(coreg_method, xdem.coreg.Deramp)
assert abs(coreg_stats["med_orig"].values) > abs(coreg_stats["med_coreg"].values)
assert coreg_stats["nmad_orig"].values > coreg_stats["nmad_coreg"].values
assert isinstance(coreg_method2, xdem.coreg.Deramp)
assert abs(coreg_stats2["med_orig"].values) > abs(coreg_stats2["med_coreg"].values)
assert coreg_stats2["nmad_orig"].values > coreg_stats2["nmad_coreg"].values

# Testing with initial shift
test_shift_list = [10, 5]
tba_dem_origin = tba_dem.copy()
coreg_pipeline = xdem.coreg.affine.NuthKaab() + xdem.coreg.affine.VerticalShift()

dem_coreg2, coreg_method2, coreg_stats2, inlier_mask2 = dem_coregistration(
tba_dem, ref_dem, coreg_method=coreg_pipeline, estimated_initial_shift=test_shift_list, random_state=42
)
dem_coreg3, coreg_method3, coreg_stats3, inlier_mask3 = dem_coregistration(
tba_dem, ref_dem, coreg_method=coreg_pipeline, random_state=42
)
assert tba_dem.raster_equal(tba_dem_origin)
assert isinstance(coreg_method2, xdem.coreg.CoregPipeline)
assert isinstance(coreg_method3, xdem.coreg.CoregPipeline)
assert isinstance(coreg_method2.pipeline[0], xdem.coreg.AffineCoreg)
assert isinstance(coreg_method3.pipeline[0], xdem.coreg.AffineCoreg)
assert (
coreg_method2.pipeline[0].meta["outputs"]["affine"]["shift_x"]
== coreg_method3.pipeline[0].meta["outputs"]["affine"]["shift_x"]
)
assert (
coreg_method2.pipeline[0].meta["outputs"]["affine"]["shift_y"]
== coreg_method3.pipeline[0].meta["outputs"]["affine"]["shift_y"]
)

# Testing without coreg pipeline
test_shift_tuple = (-5, 2) # tuple
coreg_simple = xdem.coreg.affine.DhMinimize()

dem_coreg2, coreg_method2, coreg_stats2, inlier_mask2 = dem_coregistration(
tba_dem, ref_dem, coreg_method=coreg_simple, estimated_initial_shift=test_shift_tuple, random_state=42
)
dem_coreg3, coreg_method3, coreg_stats3, inlier_mask3 = dem_coregistration(
tba_dem, ref_dem, coreg_method=coreg_simple, random_state=42
)
assert isinstance(coreg_method2, xdem.coreg.AffineCoreg)
assert isinstance(coreg_method3, xdem.coreg.AffineCoreg)
assert coreg_method2.meta["outputs"]["affine"]["shift_x"] == coreg_method3.meta["outputs"]["affine"]["shift_x"]
assert coreg_method2.meta["outputs"]["affine"]["shift_y"] == coreg_method3.meta["outputs"]["affine"]["shift_y"]

# Check if the appropriate exception is raised with an initial shift and without affine coreg
with pytest.raises(TypeError, match=r".*affine.*"):
dem_coregistration(
tba_dem,
ref_dem,
coreg_method=xdem.coreg.Deramp(),
estimated_initial_shift=test_shift_tuple,
random_state=42,
)
with pytest.raises(TypeError, match=r".*affine.*"):
dem_coregistration(
tba_dem,
ref_dem,
coreg_method=xdem.coreg.Deramp() + xdem.coreg.TerrainBias(),
estimated_initial_shift=test_shift_tuple,
random_state=42,
)

# Check if the appropriate exception is raised with a wrong type initial shift
with pytest.raises(ValueError, match=r".*two numerical values.*"):
dem_coregistration(
tba_dem,
ref_dem,
estimated_initial_shift=["2", 2],
random_state=42,
)
with pytest.raises(ValueError, match=r".*two numerical values.*"):
dem_coregistration(
tba_dem,
ref_dem,
estimated_initial_shift=[2, 3, 5],
random_state=42,
)
82 changes: 69 additions & 13 deletions xdem/coreg/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from geoutils.raster import RasterType

from xdem._typing import NDArrayf
from xdem.coreg import AffineCoreg, CoregPipeline
from xdem.coreg.affine import NuthKaab, VerticalShift
from xdem.coreg.base import Coreg
from xdem.dem import DEM
Expand Down Expand Up @@ -148,7 +149,7 @@ def dem_coregistration(
src_dem_path: str | RasterType,
ref_dem_path: str | RasterType,
out_dem_path: str | None = None,
coreg_method: Coreg | None = None,
coreg_method: Coreg | CoregPipeline | None = None,
grid: str = "ref",
resample: bool = False,
resampling: rio.warp.Resampling | None = rio.warp.Resampling.bilinear,
Expand All @@ -161,7 +162,8 @@ def dem_coregistration(
random_state: int | np.random.Generator | None = None,
plot: bool = False,
out_fig: str = None,
) -> tuple[DEM, Coreg, pd.DataFrame, NDArrayf]:
estimated_initial_shift: list[Number] | tuple[Number, Number] | None = None,
) -> tuple[DEM, Coreg | CoregPipeline, pd.DataFrame, NDArrayf]:
"""
A one-line function to coregister a selected DEM to a reference DEM.

Expand All @@ -173,7 +175,7 @@ def dem_coregistration(
:param src_dem_path: Path to the input DEM to be coregistered
:param ref_dem_path: Path to the reference DEM
:param out_dem_path: Path where to save the coregistered DEM. If set to None (default), will not save to file.
:param coreg_method: Coregistration method or pipeline. Defaults to NuthKaab + VerticalShift.
:param coreg_method: Coregistration method, or pipeline.
:param grid: The grid to be used during coregistration, set either to "ref" or "src".
:param resample: If set to True, will reproject output Raster on the same grid as input. Otherwise, only \
the array/transform will be updated (if possible) and no resampling is done. Useful to avoid spreading data gaps.
Expand All @@ -189,6 +191,8 @@ def dem_coregistration(
:param random_state: Random state or seed number to use for subsampling and optimizer.
:param plot: Set to True to plot a figure of elevation diff before/after coregistration.
:param out_fig: Path to the output figure. If None will display to screen.
:param estimated_initial_shift: List containing x and y shifts (in pixels). These shifts are applied before \
the coregistration process begins.

:returns: A tuple containing 1) coregistered DEM as an xdem.DEM instance 2) the coregistration method \
3) DataFrame of coregistration statistics (count of obs, median and NMAD over stable terrain) before and after \
Expand Down Expand Up @@ -221,21 +225,52 @@ def dem_coregistration(
if grid not in ["ref", "src"]:
raise ValueError(f"Argument `grid` must be either 'ref' or 'src' - currently set to {grid}.")

# Ensure that if an initial shift is provided, at least one coregistration method is affine.
if estimated_initial_shift:
if not (
isinstance(estimated_initial_shift, (list, tuple))
and len(estimated_initial_shift) == 2
and all(isinstance(val, (float, int)) for val in estimated_initial_shift)
):
raise ValueError(
"Argument `estimated_initial_shift` must be a list or tuple of exactly two numerical values."
)
if isinstance(coreg_method, CoregPipeline):
if not any(isinstance(step, AffineCoreg) for step in coreg_method.pipeline):
raise TypeError(
"An initial shift has been provided, but none of the coregistration methods in the pipeline "
"are affine. At least one affine coregistration method (e.g., AffineCoreg) is required."
)
elif not isinstance(coreg_method, AffineCoreg):
raise TypeError(
"An initial shift has been provided, but the coregistration method is not affine. "
"An affine coregistration method (e.g., AffineCoreg) is required."
)

# Load both DEMs
logging.info("Loading and reprojecting input data")

if isinstance(ref_dem_path, str):
if grid == "ref":
ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=0)
elif grid == "src":
ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path], ref_grid=1)
else:
ref_dem, src_dem = gu.raster.load_multiple_rasters([ref_dem_path, src_dem_path])

elif isinstance(src_dem_path, gu.Raster):
ref_dem = ref_dem_path
src_dem = src_dem_path
if grid == "ref":
src_dem = src_dem.reproject(ref_dem, silent=True)
elif grid == "src":
ref_dem = ref_dem.reproject(src_dem, silent=True)
src_dem = src_dem_path.copy()

# If an initial shift is provided, apply it before coregistration
if estimated_initial_shift:

# convert shift
shift_x = estimated_initial_shift[0] * src_dem.res[0]
shift_y = estimated_initial_shift[1] * src_dem.res[1]

# Apply the shift to the source dem
src_dem.translate(shift_x, shift_y, inplace=True)

if grid == "ref":
src_dem = src_dem.reproject(ref_dem, silent=True)
elif grid == "src":
ref_dem = ref_dem.reproject(src_dem, silent=True)

# Convert to DEM instance with Float32 dtype
# TODO: Could only convert types int into float, but any other float dtype should yield very similar results
Expand Down Expand Up @@ -268,6 +303,27 @@ def dem_coregistration(
coreg_method.fit(ref_dem, src_dem, inlier_mask, random_state=random_state)
dem_coreg = coreg_method.apply(src_dem, resample=resample, resampling=resampling)

# Add the initial shift to the calculated shift
if estimated_initial_shift:

def update_shift(
coreg_method: Coreg | CoregPipeline, shift_x: float = shift_x, shift_y: float = shift_y
) -> None:
if isinstance(coreg_method, CoregPipeline):
for step in coreg_method.pipeline:
update_shift(step)
else:
# check if the keys exists
if "outputs" in coreg_method.meta and "affine" in coreg_method.meta["outputs"]:
if "shift_x" in coreg_method.meta["outputs"]["affine"]:
coreg_method.meta["outputs"]["affine"]["shift_x"] += shift_x
logging.debug(f"Updated shift_x by {shift_x} in {coreg_method}")
if "shift_y" in coreg_method.meta["outputs"]["affine"]:
coreg_method.meta["outputs"]["affine"]["shift_y"] += shift_y
logging.debug(f"Updated shift_y by {shift_y} in {coreg_method}")

update_shift(coreg_method)

# Calculate coregistered ddem (might need resampling if resample set to False), needed for stats and plot only
ddem_coreg = dem_coreg.reproject(ref_dem, silent=True) - ref_dem

Expand Down
Loading