Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add South America Soybean DataModule #1959

Merged
merged 3 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/datamodules.rst
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ Sentinel

.. autoclass:: Sentinel2CDLDataModule
.. autoclass:: Sentinel2NCCMDataModule
.. autoclass:: Sentinel2SouthAmericaSoybeanDataModule

Non-geospatial DataModules
--------------------------
Expand Down
17 changes: 17 additions & 0 deletions tests/conf/sentinel2_south_america_soybean.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Test configuration (tests/conf) wiring a SemanticSegmentationTask to the
# Sentinel2SouthAmericaSoybeanDataModule.
model:
class_path: SemanticSegmentationTask
init_args:
loss: "ce"
model: "deeplabv3+"
backbone: "resnet18"
# NOTE(review): presumably matches the band count of the Sentinel-2 test imagery - confirm.
in_channels: 13
num_classes: 2
num_filters: 1
data:
class_path: Sentinel2SouthAmericaSoybeanDataModule
init_args:
batch_size: 2
patch_size: 16
dict_kwargs:
# Keys carry the "south_america_soybean_" / "sentinel2_" prefixes that the
# datamodule strips before forwarding the values to each dataset.
south_america_soybean_paths: "tests/data/south_america_soybean"
sentinel2_paths: "tests/data/sentinel2"
Binary file modified tests/data/south_america_soybean/SouthAmericaSoybean.zip
Binary file not shown.
Binary file not shown.
Binary file not shown.
13 changes: 3 additions & 10 deletions tests/data/south_america_soybean/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from rasterio.crs import CRS
from rasterio.transform import Affine

SIZE = 32
SIZE = 128


np.random.seed(0)
Expand All @@ -24,15 +24,8 @@ def create_file(path: str, dtype: str):
"driver": "GTiff",
"dtype": dtype,
"count": 1,
"crs": CRS.from_epsg(4326),
"transform": Affine(
0.0002499999999999943131,
0.0,
-82.0005000000000024,
0.0,
-0.0002499999999999943131,
0.0005000000000000,
),
"crs": CRS.from_epsg(32616),
"transform": Affine(10, 0.0, 399960.0, 0.0, -10, 4500000.0),
"height": SIZE,
"width": SIZE,
"compress": "lzw",
Expand Down
1 change: 1 addition & 0 deletions tests/trainers/test_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class TestSemanticSegmentationTask:
"sen12ms_s2_reduced",
"sentinel2_cdl",
"sentinel2_nccm",
"sentinel2_south_america_soybean",
"spacenet1",
"ssl4eo_l_benchmark_cdl",
"ssl4eo_l_benchmark_nlcd",
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/datamodules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .sen12ms import SEN12MSDataModule
from .sentinel2_cdl import Sentinel2CDLDataModule
from .sentinel2_nccm import Sentinel2NCCMDataModule
from .sentinel2_south_america_soybean import Sentinel2SouthAmericaSoybeanDataModule
from .skippd import SKIPPDDataModule
from .so2sat import So2SatDataModule
from .spacenet import SpaceNet1DataModule
Expand All @@ -53,6 +54,7 @@
"NAIPChesapeakeDataModule",
"Sentinel2CDLDataModule",
"Sentinel2NCCMDataModule",
"Sentinel2SouthAmericaSoybeanDataModule",
# NonGeoDataset
"BigEarthNetDataModule",
"ChaBuDDataModule",
Expand Down
123 changes: 123 additions & 0 deletions torchgeo/datamodules/sentinel2_south_america_soybean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.


"""South America Soybean datamodule."""


from typing import Any, Optional, Union

import kornia.augmentation as K
import torch
from kornia.constants import DataKey, Resample
from matplotlib.figure import Figure

from ..datasets import Sentinel2, SouthAmericaSoybean, random_grid_cell_assignment
from ..samplers import GridGeoSampler, RandomBatchGeoSampler
from ..samplers.utils import _to_tuple
from ..transforms import AugmentationSequential
from .geo import GeoDataModule


class Sentinel2SouthAmericaSoybeanDataModule(GeoDataModule):
    """LightningDataModule combining the Sentinel2 and SouthAmericaSoybean datasets.

    The datamodule works on the intersection (``&``) of a
    :class:`~torchgeo.datasets.Sentinel2` dataset and a
    :class:`~torchgeo.datasets.SouthAmericaSoybean` dataset.

    .. versionadded:: 0.6
    """

    def __init__(
        self,
        batch_size: int = 64,
        patch_size: Union[int, tuple[int, int]] = 64,
        length: Optional[int] = None,
        num_workers: int = 0,
        **kwargs: Any,
    ) -> None:
        """Create the datamodule and its augmentation pipelines.

        Args:
            batch_size: Size of each mini-batch.
            patch_size: Size of each patch, either ``size`` or ``(height, width)``.
            length: Length of each training epoch.
            num_workers: Number of workers for parallel data loading.
            **kwargs: Additional keyword arguments. Keys prefixed with
                ``south_america_soybean_`` are forwarded (prefix removed) to
                :class:`~torchgeo.datasets.SouthAmericaSoybean`; keys prefixed
                with ``sentinel2_`` are forwarded (prefix removed) to
                :class:`~torchgeo.datasets.Sentinel2`.
        """
        soybean_prefix = "south_america_soybean_"
        sentinel2_prefix = "sentinel2_"

        # Route each prefixed keyword to its dataset, dropping the prefix.
        self.south_america_soybean_kwargs = {
            name[len(soybean_prefix) :]: value
            for name, value in kwargs.items()
            if name.startswith(soybean_prefix)
        }
        self.sentinel2_kwargs = {
            name[len(sentinel2_prefix) :]: value
            for name, value in kwargs.items()
            if name.startswith(sentinel2_prefix)
        }

        super().__init__(
            SouthAmericaSoybean,
            batch_size=batch_size,
            patch_size=patch_size,
            length=length,
            num_workers=num_workers,
            **kwargs,
        )

        # Masks are resampled with nearest-neighbour interpolation.
        crop_size = _to_tuple(self.patch_size)
        mask_options = {
            DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": None}
        }
        self.train_aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std),
            K.RandomResizedCrop(crop_size, scale=(0.6, 1.0)),
            K.RandomVerticalFlip(p=0.5),
            K.RandomHorizontalFlip(p=0.5),
            data_keys=["image", "mask"],
            extra_args=mask_options,
        )

        # Default pipeline: normalization only.
        self.aug = AugmentationSequential(
            K.Normalize(mean=self.mean, std=self.std), data_keys=["image", "mask"]
        )

    def setup(self, stage: str) -> None:
        """Build the datasets, split them, and create the samplers for ``stage``.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        self.sentinel2 = Sentinel2(**self.sentinel2_kwargs)
        self.south_america_soybean = SouthAmericaSoybean(
            **self.south_america_soybean_kwargs
        )
        self.dataset = self.sentinel2 & self.south_america_soybean

        # Fixed seed for the 80/10/10 grid-cell split.
        seeded = torch.Generator().manual_seed(1)
        splits = random_grid_cell_assignment(
            self.dataset, [0.8, 0.1, 0.1], grid_size=8, generator=seeded
        )
        self.train_dataset, self.val_dataset, self.test_dataset = splits

        if stage == "fit":
            self.train_batch_sampler = RandomBatchGeoSampler(
                self.train_dataset, self.patch_size, self.batch_size, self.length
            )
        if stage in ("fit", "validate"):
            self.val_sampler = GridGeoSampler(
                self.val_dataset, self.patch_size, self.patch_size
            )
        if stage == "test":
            self.test_sampler = GridGeoSampler(
                self.test_dataset, self.patch_size, self.patch_size
            )

    def plot(self, *args: Any, **kwargs: Any) -> Figure:
        """Delegate plotting to the SouthAmericaSoybean dataset's plot method.

        Args:
            *args: Arguments passed to plot method.
            **kwargs: Keyword arguments passed to plot method.

        Returns:
            A matplotlib Figure with the image, ground truth, and predictions.
        """
        plot_fn = self.south_america_soybean.plot
        return plot_fn(*args, **kwargs)
Loading