Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a masking functionality #271

Merged
merged 18 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ repos:

# Format the code aggressively using black
- repo: https://github.com/psf/black
rev: 21.6b0
rev: 22.3.0
hooks:
- id: black
args: [--line-length=120]
Expand Down
45 changes: 43 additions & 2 deletions geoutils/georaster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ def __pow__(self: RasterType, power: int | float) -> RasterType:
raise ValueError("Power needs to be a number.")

# Calculate the product of arrays and save to new Raster
out_data = self.data ** power
out_data = self.data**power
ndv = self.nodata
if (np.sum(out_data.mask) > 0) & (ndv is None):
ndv = _default_ndv(out_data.dtype)
Expand Down Expand Up @@ -785,6 +785,34 @@ def _update(

self._is_modified = True

def set_mask(self, mask: np.ndarray) -> None:
"""
Mask all pixels of self.data where `mask` is set to True or > 0.

Masking is performed in place.
`mask` must have the same shape as loaded data, unless the first dimension is 1, then it is ignored.

:param mask: The data mask
"""
# Check that mask is a Numpy array
if not isinstance(mask, np.ndarray):
raise ValueError("mask must be a numpy array.")

# Check that new_data has correct shape
if self.is_loaded:
orig_shape = self._data.shape
else:
raise AttributeError("self.data must be loaded first, with e.g. self.load()")

if mask.shape != orig_shape:
# In case first dimension is empty and other dimensions match -> reshape mask
if (orig_shape[0] == 1) & (orig_shape[1:] == mask.shape):
mask = mask.reshape(orig_shape)
else:
raise ValueError(f"mask must be of the same shape as existing data: {orig_shape}.")

self.data[mask > 0] = np.ma.masked

def info(self, stats: bool = False) -> str:
"""
Returns string of information about the raster (filename, coordinate system, number of columns/rows, etc.).
Expand Down Expand Up @@ -1269,6 +1297,7 @@ def save(
filename: str | IO[bytes],
driver: str = "GTiff",
dtype: DTypeLike | None = None,
nodata: AnyNumber | None = None,
compress: str = "deflate",
tiled: bool = False,
blank_value: None | int | float = None,
Expand All @@ -1289,6 +1318,7 @@ def save(
:param filename: Filename to write the file to.
:param driver: the 'GDAL' driver to use to write the file as.
:param dtype: Data Type to write the image as (defaults to dtype of image data)
:param nodata: nodata value to be used.
:param compress: Compression type. Defaults to 'deflate' (equal to GDALs: COMPRESS=DEFLATE)
:param tiled: Whether to write blocks in tiles instead of strips. Improves read performance on large files,
but increases file size.
Expand All @@ -1312,6 +1342,9 @@ def save(
if gcps is None:
gcps = []

# Use nodata set by user, otherwise default to self's
nodata = nodata if nodata is not None else self.nodata

if (self.data is None) & (blank_value is None):
raise AttributeError("No data loaded, and alternative blank_value not set.")
elif blank_value is not None:
Expand All @@ -1323,6 +1356,14 @@ def save(
else:
save_data = self.data

# if masked array, save with masked values replaced by nodata
# In this case, nodata = None is not compatible, so revert to default values
if isinstance(save_data, np.ma.masked_array) & (np.count_nonzero(save_data.mask) > 0):
if nodata is None:
nodata = _default_ndv(save_data.dtype)
warnings.warn(f"No nodata set, will use default value of {nodata}")
save_data = save_data.filled(nodata)

with rio.open(
filename,
"w",
Expand All @@ -1333,7 +1374,7 @@ def save(
dtype=save_data.dtype,
crs=self.ds.crs,
transform=self.ds.transform,
nodata=self.ds.nodata,
nodata=nodata,
compress=compress,
tiled=tiled,
**co_opts,
Expand Down
2 changes: 1 addition & 1 deletion geoutils/spatial_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ def _get_closest_rectangle(size: int) -> tuple[int, int]:
close_cube = int(np.sqrt(size))

# If size has an integer root, return the respective cube.
if close_cube ** 2 == size:
if close_cube**2 == size:
return (close_cube, close_cube)

# One of these rectangles/cubes will cover all cells, so return the first that does.
Expand Down
87 changes: 82 additions & 5 deletions tests/test_georaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import os
import tempfile
import warnings
from tempfile import TemporaryFile
from tempfile import NamedTemporaryFile, TemporaryFile

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -344,6 +344,48 @@ def test_is_modified(self) -> None:
r.data = r.data + 5
assert r.is_modified

@pytest.mark.parametrize("dataset", ["landsat_B4", "landsat_RGB"]) # type: ignore
def test_masking(self, dataset: str) -> None:
"""
Test self.set_mask
"""
# Test boolean mask
r = gr.Raster(datasets.get_path(dataset))
mask = r.data == np.min(r.data)
r.set_mask(mask)
assert (np.count_nonzero(mask) > 0) & np.array_equal(mask > 0, r.data.mask)

# Test non boolean mask with values > 0
r = gr.Raster(datasets.get_path(dataset))
mask = np.where(r.data == np.min(r.data), 32, 0)
r.set_mask(mask)
assert (np.count_nonzero(mask) > 0) & np.array_equal(mask > 0, r.data.mask)

# Test that previous mask is also preserved
mask2 = r.data == np.max(r.data)
assert np.count_nonzero(mask2) > 0
r.set_mask(mask2)
assert np.array_equal((mask > 0) | (mask2 > 0), r.data.mask)
assert np.count_nonzero(~r.data.mask[mask > 0]) == 0

# Test that shape of first dimension is ignored if equal to 1
r = gr.Raster(datasets.get_path(dataset))
if r.data.shape[0] == 1:
mask = (r.data == np.min(r.data)).squeeze()
r.set_mask(mask)
assert (np.count_nonzero(mask) > 0) & np.array_equal(mask > 0, r.data.mask.squeeze())

# Test that proper issue is raised if shape is incorrect
r = gr.Raster(datasets.get_path(dataset))
wrong_shape = np.array(r.data.shape) + 1
mask = np.zeros(wrong_shape)
with pytest.raises(ValueError, match="mask must be of the same shape as existing data*"):
r.set_mask(mask)

# Test that proper issue is raised if mask is not a numpy array
with pytest.raises(ValueError, match="mask must be a numpy array"):
r.set_mask(1)

def test_crop(self) -> None:

r = gr.Raster(datasets.get_path("landsat_B4"))
Expand Down Expand Up @@ -764,13 +806,48 @@ def test_saving(self) -> None:
# Read single band raster
img = gr.Raster(datasets.get_path("landsat_B4"))

# Temporary folder
temp_dir = tempfile.TemporaryDirectory()

# Save file to temporary file, with defaults opts
img.save(TemporaryFile())
temp_file = NamedTemporaryFile(mode="w", delete=False, dir=temp_dir.name)
img.save(temp_file.name)
saved = gr.Raster(temp_file.name)
assert gu.misc.array_equal(img.data, saved.data)

# Test additional options
co_opts = {"TILED": "YES", "COMPRESS": "LZW"}
metadata = {"Type": "test"}
img.save(TemporaryFile(), co_opts=co_opts, metadata=metadata)
temp_file = NamedTemporaryFile(mode="w", delete=False, dir=temp_dir.name)
img.save(temp_file.name, co_opts=co_opts, metadata=metadata)
saved = gr.Raster(temp_file.name)
assert gu.misc.array_equal(img.data, saved.data)
assert saved.ds.tags()["Type"] == "test"

# Test that nodata value is enforced when masking - since value 0 is not used, data should be unchanged
temp_file = NamedTemporaryFile(mode="w", delete=False, dir=temp_dir.name)
img.save(temp_file.name, nodata=0)
saved = gr.Raster(temp_file.name)
assert gu.misc.array_equal(img.data, saved.data)
assert saved.nodata == 0

# Test that mask is preserved
mask = img.data == np.min(img.data)
img.set_mask(mask)
temp_file = NamedTemporaryFile(mode="w", delete=False, dir=temp_dir.name)
img.save(temp_file.name, nodata=0)
saved = gr.Raster(temp_file.name)
assert gu.misc.array_equal(img.data, saved.data)

# Test that a warning is raised if nodata is not set
with pytest.warns(UserWarning):
img.save(TemporaryFile())

# Clean up teporary folder - fails on Windows
try:
temp_dir.cleanup()
except (NotADirectoryError, PermissionError):
pass

def test_coords(self) -> None:

Expand Down Expand Up @@ -1410,5 +1487,5 @@ def test_raise_errors(self, op: str) -> None:
def test_power(self, power: float | int) -> None:

if power > 0: # Integers to negative integer powers are not allowed.
assert self.r1 ** power == self.from_array(self.r1.data ** power, rst_ref=self.r1)
assert self.r1_f32 ** power == self.from_array(self.r1_f32.data ** power, rst_ref=self.r1_f32)
assert self.r1**power == self.from_array(self.r1.data**power, rst_ref=self.r1)
assert self.r1_f32**power == self.from_array(self.r1_f32.data**power, rst_ref=self.r1_f32)