def set_mask(self, mask: np.ndarray) -> None:
    """
    Flag as masked every pixel of self.data whose `mask` value is True or strictly positive.

    The operation modifies self.data in place and returns nothing.
    `mask` has to match the shape of the loaded data. The one exception is a mask that lacks the
    first dimension while that dimension has length 1 in the data: it is reshaped to the data shape.

    :param mask: The data mask

    :raises ValueError: If `mask` is not a numpy array, or if its shape does not fit the loaded data.
    :raises AttributeError: If the raster data has not been loaded.
    """
    # Anything other than a numpy array is rejected up front
    if not isinstance(mask, np.ndarray):
        raise ValueError("mask must be a numpy array.")

    # The reference shape is read from the loaded data, so loading is a precondition
    if not self.is_loaded:
        raise AttributeError("self.data must be loaded first, with e.g. self.load()")
    data_shape = self._data.shape

    if mask.shape != data_shape:
        # Tolerate a missing leading dimension of length 1 when the remaining dimensions agree
        fits_without_first_dim = data_shape[0] == 1 and data_shape[1:] == mask.shape
        if not fits_without_first_dim:
            raise ValueError(f"mask must be of the same shape as existing data: {data_shape}.")
        mask = mask.reshape(data_shape)

    # Boolean True and any value > 0 both select the pixel for masking
    selected = mask > 0
    self.data[selected] = np.ma.masked
@@ -1312,6 +1342,9 @@ def save( if gcps is None: gcps = [] + # Use nodata set by user, otherwise default to self's + nodata = nodata if nodata is not None else self.nodata + if (self.data is None) & (blank_value is None): raise AttributeError("No data loaded, and alternative blank_value not set.") elif blank_value is not None: @@ -1323,6 +1356,14 @@ def save( else: save_data = self.data + # if masked array, save with masked values replaced by nodata + # In this case, nodata = None is not compatible, so revert to default values + if isinstance(save_data, np.ma.masked_array) & (np.count_nonzero(save_data.mask) > 0): + if nodata is None: + nodata = _default_ndv(save_data.dtype) + warnings.warn(f"No nodata set, will use default value of {nodata}") + save_data = save_data.filled(nodata) + with rio.open( filename, "w", @@ -1333,7 +1374,7 @@ def save( dtype=save_data.dtype, crs=self.ds.crs, transform=self.ds.transform, - nodata=self.ds.nodata, + nodata=nodata, compress=compress, tiled=tiled, **co_opts, diff --git a/geoutils/spatial_tools.py b/geoutils/spatial_tools.py index 0ffc99af..d602df74 100644 --- a/geoutils/spatial_tools.py +++ b/geoutils/spatial_tools.py @@ -320,7 +320,7 @@ def _get_closest_rectangle(size: int) -> tuple[int, int]: close_cube = int(np.sqrt(size)) # If size has an integer root, return the respective cube. - if close_cube ** 2 == size: + if close_cube**2 == size: return (close_cube, close_cube) # One of these rectangles/cubes will cover all cells, so return the first that does. 
@pytest.mark.parametrize("dataset", ["landsat_B4", "landsat_RGB"])  # type: ignore
def test_masking(self, dataset: str) -> None:
    """
    Test self.set_mask

    Covers boolean and numeric masks, accumulation over successive calls, a mask given without
    the first dimension when that dimension has length 1, and the errors raised for a wrongly
    shaped mask or a mask that is not a numpy array.

    :param dataset: Name of the example dataset to load through datasets.get_path.
    """
    # Test boolean mask
    r = gr.Raster(datasets.get_path(dataset))
    mask = r.data == np.min(r.data)
    r.set_mask(mask)
    # Separate asserts so a failure shows which condition broke
    assert np.count_nonzero(mask) > 0
    assert np.array_equal(mask > 0, r.data.mask)

    # Test non boolean mask with values > 0
    r = gr.Raster(datasets.get_path(dataset))
    mask = np.where(r.data == np.min(r.data), 32, 0)
    r.set_mask(mask)
    assert np.count_nonzero(mask) > 0
    assert np.array_equal(mask > 0, r.data.mask)

    # Test that previous mask is also preserved
    mask2 = r.data == np.max(r.data)
    assert np.count_nonzero(mask2) > 0
    r.set_mask(mask2)
    assert np.array_equal((mask > 0) | (mask2 > 0), r.data.mask)
    assert np.count_nonzero(~r.data.mask[mask > 0]) == 0

    # Test that shape of first dimension is ignored if equal to 1
    r = gr.Raster(datasets.get_path(dataset))
    if r.data.shape[0] == 1:
        mask = (r.data == np.min(r.data)).squeeze()
        r.set_mask(mask)
        assert np.count_nonzero(mask) > 0
        assert np.array_equal(mask > 0, r.data.mask.squeeze())

    # Test that proper issue is raised if shape is incorrect
    r = gr.Raster(datasets.get_path(dataset))
    wrong_shape = np.array(r.data.shape) + 1
    mask = np.zeros(wrong_shape)
    # `match` is a regular expression searched in the message, so no trailing wildcard is needed
    with pytest.raises(ValueError, match="mask must be of the same shape as existing data"):
        r.set_mask(mask)

    # Test that proper issue is raised if mask is not a numpy array
    with pytest.raises(ValueError, match="mask must be a numpy array"):
        r.set_mask(1)
None: def test_power(self, power: float | int) -> None: if power > 0: # Integers to negative integer powers are not allowed. - assert self.r1 ** power == self.from_array(self.r1.data ** power, rst_ref=self.r1) - assert self.r1_f32 ** power == self.from_array(self.r1_f32.data ** power, rst_ref=self.r1_f32) + assert self.r1**power == self.from_array(self.r1.data**power, rst_ref=self.r1) + assert self.r1_f32**power == self.from_array(self.r1_f32.data**power, rst_ref=self.r1_f32)