Skip to content

Commit

Permalink
Add masking functionality; fix mask issue upon saving (#271)
Browse files Browse the repository at this point in the history
* Add a set_mask method to Raster

* Add associated tests

* Improve tests

* Update version of black and fix linting issues

* Ensure that masked rasters have a nodata value set upon saving

* Fix issue with change in test, 2 commits ago

* Only enforce nodata value upon saving if masked values exist

* Update and improve tests for Raster.save

* Linting

* Try fixing issue with Windows for NamedTemporaryFile

* Try fixing issue with Windows for NamedTemporaryFile

* Attempt 2 to fix Windows issue with temporary file

* Although previous commit worked, test solution similar to xdem

* Revert to previous option, since last test did not work

* Attempt 5 to fix Windows issue and delete temporary files

* Add a fix for Windows

* Update to previous commit
  • Loading branch information
adehecq authored May 18, 2022
1 parent 52f4cae commit c101c3e
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 9 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:

# Format the code aggressively using black
- repo: https://github.com/psf/black
rev: 21.6b0
rev: 22.3.0
hooks:
- id: black
args: [--line-length=120]
Expand Down
45 changes: 43 additions & 2 deletions geoutils/georaster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def __pow__(self: RasterType, power: int | float) -> RasterType:
raise ValueError("Power needs to be a number.")

# Calculate the product of arrays and save to new Raster
out_data = self.data ** power
out_data = self.data**power
ndv = self.nodata
if (np.sum(out_data.mask) > 0) & (ndv is None):
ndv = _default_ndv(out_data.dtype)
Expand Down Expand Up @@ -785,6 +785,34 @@ def _update(

self._is_modified = True

def set_mask(self, mask: np.ndarray) -> None:
    """
    Mask all pixels of self.data where `mask` is set to True or > 0.

    Masking is performed in place. Pixels are only ever added to the mask: entries
    where `mask` is False or <= 0 are left untouched, so previously masked pixels
    stay masked.

    `mask` must have the same shape as the loaded data. The only exception is when
    the first dimension of the data has length 1: a mask matching the remaining
    dimensions is then accepted and reshaped to the shape of the data.

    :param mask: The data mask
    :raises ValueError: If `mask` is not a numpy array, or if its shape is not compatible with the data.
    :raises AttributeError: If the data has not been loaded yet.
    """
    # Check that mask is a Numpy array
    if not isinstance(mask, np.ndarray):
        raise ValueError("mask must be a numpy array.")

    # The shape of the mask can only be validated against data that is loaded
    if not self.is_loaded:
        raise AttributeError("self.data must be loaded first, with e.g. self.load()")
    orig_shape = self._data.shape

    # Check that mask has the correct shape
    if mask.shape != orig_shape:
        # In case first dimension has length 1 and other dimensions match -> reshape mask
        if orig_shape[0] == 1 and orig_shape[1:] == mask.shape:
            mask = mask.reshape(orig_shape)
        else:
            raise ValueError(f"mask must be of the same shape as existing data: {orig_shape}.")

    # Assigning np.ma.masked only flags the selected pixels; the rest of the mask is unchanged
    self.data[mask > 0] = np.ma.masked

def info(self, stats: bool = False) -> str:
"""
Returns string of information about the raster (filename, coordinate system, number of columns/rows, etc.).
Expand Down Expand Up @@ -1269,6 +1297,7 @@ def save(
filename: str | IO[bytes],
driver: str = "GTiff",
dtype: DTypeLike | None = None,
nodata: AnyNumber | None = None,
compress: str = "deflate",
tiled: bool = False,
blank_value: None | int | float = None,
Expand All @@ -1289,6 +1318,7 @@ def save(
:param filename: Filename to write the file to.
:param driver: the 'GDAL' driver to use to write the file as.
:param dtype: Data Type to write the image as (defaults to dtype of image data)
:param nodata: nodata value to be used.
:param compress: Compression type. Defaults to 'deflate' (equal to GDALs: COMPRESS=DEFLATE)
:param tiled: Whether to write blocks in tiles instead of strips. Improves read performance on large files,
but increases file size.
Expand All @@ -1312,6 +1342,9 @@ def save(
if gcps is None:
gcps = []

# Use nodata set by user, otherwise default to self's
nodata = nodata if nodata is not None else self.nodata

if (self.data is None) & (blank_value is None):
raise AttributeError("No data loaded, and alternative blank_value not set.")
elif blank_value is not None:
Expand All @@ -1323,6 +1356,14 @@ def save(
else:
save_data = self.data

# if masked array, save with masked values replaced by nodata
# In this case, nodata = None is not compatible, so revert to default values
if isinstance(save_data, np.ma.masked_array) & (np.count_nonzero(save_data.mask) > 0):
if nodata is None:
nodata = _default_ndv(save_data.dtype)
warnings.warn(f"No nodata set, will use default value of {nodata}")
save_data = save_data.filled(nodata)

with rio.open(
filename,
"w",
Expand All @@ -1333,7 +1374,7 @@ def save(
dtype=save_data.dtype,
crs=self.ds.crs,
transform=self.ds.transform,
nodata=self.ds.nodata,
nodata=nodata,
compress=compress,
tiled=tiled,
**co_opts,
Expand Down
2 changes: 1 addition & 1 deletion geoutils/spatial_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ def _get_closest_rectangle(size: int) -> tuple[int, int]:
close_cube = int(np.sqrt(size))

# If size has an integer root, return the respective cube.
if close_cube ** 2 == size:
if close_cube**2 == size:
return (close_cube, close_cube)

# One of these rectangles/cubes will cover all cells, so return the first that does.
Expand Down
87 changes: 82 additions & 5 deletions tests/test_georaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import tempfile
import warnings
from tempfile import TemporaryFile
from tempfile import NamedTemporaryFile, TemporaryFile

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -344,6 +344,48 @@ def test_is_modified(self) -> None:
r.data = r.data + 5
assert r.is_modified

@pytest.mark.parametrize("dataset", ["landsat_B4", "landsat_RGB"])  # type: ignore
def test_masking(self, dataset: str) -> None:
    """
    Test self.set_mask: boolean and numeric masks, accumulation of successive masks,
    the squeezed-first-dimension shortcut, and the errors raised on invalid input.
    """
    # Test boolean mask
    r = gr.Raster(datasets.get_path(dataset))
    mask = r.data == np.min(r.data)
    r.set_mask(mask)
    # Separate asserts so that a failure shows which condition broke
    assert np.count_nonzero(mask) > 0
    assert np.array_equal(mask > 0, r.data.mask)

    # Test non boolean mask with values > 0
    r = gr.Raster(datasets.get_path(dataset))
    mask = np.where(r.data == np.min(r.data), 32, 0)
    r.set_mask(mask)
    assert np.count_nonzero(mask) > 0
    assert np.array_equal(mask > 0, r.data.mask)

    # Test that previous mask is also preserved
    mask2 = r.data == np.max(r.data)
    assert np.count_nonzero(mask2) > 0
    r.set_mask(mask2)
    assert np.array_equal((mask > 0) | (mask2 > 0), r.data.mask)
    assert np.count_nonzero(~r.data.mask[mask > 0]) == 0

    # Test that shape of first dimension is ignored if equal to 1
    r = gr.Raster(datasets.get_path(dataset))
    if r.data.shape[0] == 1:
        mask = (r.data == np.min(r.data)).squeeze()
        r.set_mask(mask)
        assert np.count_nonzero(mask) > 0
        assert np.array_equal(mask > 0, r.data.mask.squeeze())

    # Test that proper issue is raised if shape is incorrect
    r = gr.Raster(datasets.get_path(dataset))
    wrong_shape = np.array(r.data.shape) + 1
    mask = np.zeros(wrong_shape)
    # `match` is a regular expression: no trailing glob-style "*" is needed for a prefix match
    with pytest.raises(ValueError, match="mask must be of the same shape as existing data"):
        r.set_mask(mask)

    # Test that proper issue is raised if mask is not a numpy array
    with pytest.raises(ValueError, match="mask must be a numpy array"):
        r.set_mask(1)

def test_crop(self) -> None:

r = gr.Raster(datasets.get_path("landsat_B4"))
Expand Down Expand Up @@ -764,13 +806,48 @@ def test_saving(self) -> None:
# Read single band raster
img = gr.Raster(datasets.get_path("landsat_B4"))

# Temporary folder
temp_dir = tempfile.TemporaryDirectory()

# Save file to temporary file, with defaults opts
img.save(TemporaryFile())
temp_file = NamedTemporaryFile(mode="w", delete=False, dir=temp_dir.name)
img.save(temp_file.name)
saved = gr.Raster(temp_file.name)
assert gu.misc.array_equal(img.data, saved.data)

# Test additional options
co_opts = {"TILED": "YES", "COMPRESS": "LZW"}
metadata = {"Type": "test"}
img.save(TemporaryFile(), co_opts=co_opts, metadata=metadata)
temp_file = NamedTemporaryFile(mode="w", delete=False, dir=temp_dir.name)
img.save(temp_file.name, co_opts=co_opts, metadata=metadata)
saved = gr.Raster(temp_file.name)
assert gu.misc.array_equal(img.data, saved.data)
assert saved.ds.tags()["Type"] == "test"

# Test that nodata value is enforced when masking - since value 0 is not used, data should be unchanged
temp_file = NamedTemporaryFile(mode="w", delete=False, dir=temp_dir.name)
img.save(temp_file.name, nodata=0)
saved = gr.Raster(temp_file.name)
assert gu.misc.array_equal(img.data, saved.data)
assert saved.nodata == 0

# Test that mask is preserved
mask = img.data == np.min(img.data)
img.set_mask(mask)
temp_file = NamedTemporaryFile(mode="w", delete=False, dir=temp_dir.name)
img.save(temp_file.name, nodata=0)
saved = gr.Raster(temp_file.name)
assert gu.misc.array_equal(img.data, saved.data)

# Test that a warning is raised if nodata is not set
with pytest.warns(UserWarning):
img.save(TemporaryFile())

# Clean up temporary folder - fails on Windows
try:
temp_dir.cleanup()
except (NotADirectoryError, PermissionError):
pass

def test_coords(self) -> None:

Expand Down Expand Up @@ -1410,5 +1487,5 @@ def test_raise_errors(self, op: str) -> None:
def test_power(self, power: float | int) -> None:

if power > 0: # Integers to negative integer powers are not allowed.
assert self.r1 ** power == self.from_array(self.r1.data ** power, rst_ref=self.r1)
assert self.r1_f32 ** power == self.from_array(self.r1_f32.data ** power, rst_ref=self.r1_f32)
assert self.r1**power == self.from_array(self.r1.data**power, rst_ref=self.r1)
assert self.r1_f32**power == self.from_array(self.r1_f32.data**power, rst_ref=self.r1_f32)

0 comments on commit c101c3e

Please sign in to comment.