diff --git a/docs/api/datamodules.rst b/docs/api/datamodules.rst index cb6dea18f0d..87fd8b48bf1 100644 --- a/docs/api/datamodules.rst +++ b/docs/api/datamodules.rst @@ -36,6 +36,7 @@ Sentinel .. autoclass:: Sentinel2CDLDataModule .. autoclass:: Sentinel2NCCMDataModule +.. autoclass:: Sentinel2SouthAmericaSoybeanDataModule Non-geospatial DataModules -------------------------- diff --git a/tests/conf/sentinel2_south_america_soybean.yaml b/tests/conf/sentinel2_south_america_soybean.yaml new file mode 100644 index 00000000000..1d1b91ba067 --- /dev/null +++ b/tests/conf/sentinel2_south_america_soybean.yaml @@ -0,0 +1,17 @@ +model: + class_path: SemanticSegmentationTask + init_args: + loss: "ce" + model: "deeplabv3+" + backbone: "resnet18" + in_channels: 13 + num_classes: 2 + num_filters: 1 +data: + class_path: Sentinel2SouthAmericaSoybeanDataModule + init_args: + batch_size: 2 + patch_size: 16 + dict_kwargs: + south_america_soybean_paths: "tests/data/south_america_soybean" + sentinel2_paths: "tests/data/sentinel2" diff --git a/tests/data/south_america_soybean/SouthAmericaSoybean.zip b/tests/data/south_america_soybean/SouthAmericaSoybean.zip index 5453b89fc25..630cc6c94c9 100644 Binary files a/tests/data/south_america_soybean/SouthAmericaSoybean.zip and b/tests/data/south_america_soybean/SouthAmericaSoybean.zip differ diff --git a/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2002.tif b/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2002.tif index 95667ce067c..23156a6a1ef 100644 Binary files a/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2002.tif and b/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2002.tif differ diff --git a/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif b/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif index a220b500677..7aab8f3db65 100644 Binary files a/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif and b/tests/data/south_america_soybean/SouthAmericaSoybean/South_America_Soybean_2021.tif differ diff --git a/tests/data/south_america_soybean/data.py b/tests/data/south_america_soybean/data.py index fbe7d7b23d1..63e93ebdc1d 100644 --- a/tests/data/south_america_soybean/data.py +++ b/tests/data/south_america_soybean/data.py @@ -11,7 +11,7 @@ from rasterio.crs import CRS from rasterio.transform import Affine -SIZE = 32 +SIZE = 128 np.random.seed(0) @@ -24,15 +24,8 @@ def create_file(path: str, dtype: str): "driver": "GTiff", "dtype": dtype, "count": 1, - "crs": CRS.from_epsg(4326), - "transform": Affine( - 0.0002499999999999943131, - 0.0, - -82.0005000000000024, - 0.0, - -0.0002499999999999943131, - 0.0005000000000000, - ), + "crs": CRS.from_epsg(32616), + "transform": Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0), "height": SIZE, "width": SIZE, "compress": "lzw", diff --git a/tests/trainers/test_segmentation.py b/tests/trainers/test_segmentation.py index 59c7c16cff5..40da4efead6 100644 --- a/tests/trainers/test_segmentation.py +++ b/tests/trainers/test_segmentation.py @@ -75,6 +75,7 @@ class TestSemanticSegmentationTask: "sen12ms_s2_reduced", "sentinel2_cdl", "sentinel2_nccm", + "sentinel2_south_america_soybean", "spacenet1", "ssl4eo_l_benchmark_cdl", "ssl4eo_l_benchmark_nlcd", diff --git a/torchgeo/datamodules/__init__.py b/torchgeo/datamodules/__init__.py index 1c268bb807c..a2d3eb1a666 100644 --- a/torchgeo/datamodules/__init__.py +++ b/torchgeo/datamodules/__init__.py @@ -31,6 +31,7 @@ from .sen12ms import SEN12MSDataModule from .sentinel2_cdl import Sentinel2CDLDataModule from .sentinel2_nccm import Sentinel2NCCMDataModule +from .sentinel2_south_america_soybean import Sentinel2SouthAmericaSoybeanDataModule from .skippd import SKIPPDDataModule from .so2sat import So2SatDataModule from .spacenet import SpaceNet1DataModule @@ -53,6 +54,7 @@ "NAIPChesapeakeDataModule", "Sentinel2CDLDataModule", "Sentinel2NCCMDataModule", + "Sentinel2SouthAmericaSoybeanDataModule", # NonGeoDataset "BigEarthNetDataModule", "ChaBuDDataModule", diff --git a/torchgeo/datamodules/sentinel2_south_america_soybean.py b/torchgeo/datamodules/sentinel2_south_america_soybean.py new file mode 100644 index 00000000000..6ce54000cbc --- /dev/null +++ b/torchgeo/datamodules/sentinel2_south_america_soybean.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + + +"""South America Soybean datamodule.""" + + +from typing import Any, Optional, Union + +import kornia.augmentation as K +import torch +from kornia.constants import DataKey, Resample +from matplotlib.figure import Figure + +from ..datasets import Sentinel2, SouthAmericaSoybean, random_grid_cell_assignment +from ..samplers import GridGeoSampler, RandomBatchGeoSampler +from ..samplers.utils import _to_tuple +from ..transforms import AugmentationSequential +from .geo import GeoDataModule + + +class Sentinel2SouthAmericaSoybeanDataModule(GeoDataModule): + """LightningDataModule for SouthAmericaSoybean and Sentinel2 datasets. + + .. versionadded:: 0.6 + """ + + def __init__( + self, + batch_size: int = 64, + patch_size: Union[int, tuple[int, int]] = 64, + length: Optional[int] = None, + num_workers: int = 0, + **kwargs: Any, + ) -> None: + """Initialize a new Sentinel2SouthAmericaSoybeanDataModule instance. + + Args: + batch_size: Size of each mini-batch. + patch_size: Size of each patch, either ``size`` or ``(height, width)``. + length: Length of each training epoch. + num_workers: Number of workers for parallel data loading. + **kwargs: Additional keyword arguments passed to + :class:`~torchgeo.datasets.SouthAmericaSoybean` + (prefix keys with ``south_america_soybean_``) and + :class:`~torchgeo.datasets.Sentinel2` + (prefix keys with ``sentinel2_``). + """ + self.south_america_soybean_kwargs = {} + self.sentinel2_kwargs = {} + for key, val in kwargs.items(): + if key.startswith("south_america_soybean_"): + self.south_america_soybean_kwargs[key[22:]] = val + elif key.startswith("sentinel2_"): + self.sentinel2_kwargs[key[10:]] = val + + super().__init__( + SouthAmericaSoybean, + batch_size=batch_size, + patch_size=patch_size, + length=length, + num_workers=num_workers, + **kwargs, + ) + + self.train_aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), + K.RandomResizedCrop(_to_tuple(self.patch_size), scale=(0.6, 1.0)), + K.RandomVerticalFlip(p=0.5), + K.RandomHorizontalFlip(p=0.5), + data_keys=["image", "mask"], + extra_args={ + DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None} + }, + ) + + self.aug = AugmentationSequential( + K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"] + ) + + def setup(self, stage: str) -> None: + """Set up datasets and samplers. + + Args: + stage: Either 'fit', 'validate', 'test', or 'predict'. + """ + self.sentinel2 = Sentinel2(**self.sentinel2_kwargs) + self.south_america_soybean = SouthAmericaSoybean( + **self.south_america_soybean_kwargs + ) + self.dataset = self.sentinel2 & self.south_america_soybean + + generator = torch.Generator().manual_seed(1) + (self.train_dataset, self.val_dataset, self.test_dataset) = ( + random_grid_cell_assignment( + self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=generator + ) + ) + + if stage in ["fit"]: + self.train_batch_sampler = RandomBatchGeoSampler( + self.train_dataset, self.patch_size, self.batch_size, self.length + ) + if stage in ["fit", "validate"]: + self.val_sampler = GridGeoSampler( + self.val_dataset, self.patch_size, self.patch_size + ) + if stage in ["test"]: + self.test_sampler = GridGeoSampler( + self.test_dataset, self.patch_size, self.patch_size + ) + + def plot(self, *args: Any, **kwargs: Any) -> Figure: + """Run SouthAmericaSoybean plot method. + + Args: + *args: Arguments passed to plot method. + **kwargs: Keyword arguments passed to plot method. + + Returns: + A matplotlib Figure with the image, ground truth, and predictions. + """ + return self.south_america_soybean.plot(*args, **kwargs)