Skip to content

Commit

Permalink
Deprecate AugmentationSequential wrapper (#2396)
Browse files Browse the repository at this point in the history
* Deprecate AugmentationSequential wrapper

* Test deprecation

* Remove tests for deprecated wrapper
  • Loading branch information
adamjstewart authored Feb 12, 2025
1 parent e140738 commit 269e179
Show file tree
Hide file tree
Showing 5 changed files with 11 additions and 188 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ dependencies = [
"torchmetrics>=0.10",
# torchvision 0.14+ required for torchvision.models.swin_v2_b
"torchvision>=0.14",
# typing-extensions 4.5+ required for typing_extensions.deprecated
# can be removed once Python 3.13 is minimum supported version
"typing-extensions>=4.5",
]
dynamic = ["version"]

Expand Down
1 change: 1 addition & 0 deletions requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ timm==0.4.12
torch==1.13.0
torchmetrics==0.10.0
torchvision==0.14.0
typing-extensions==4.5.0

# datasets
h5py==3.6.0
Expand Down
1 change: 1 addition & 0 deletions requirements/required.txt
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,4 @@ timm==1.0.14
torch==2.6.0
torchmetrics==1.6.1
torchvision==0.21.0
typing-extensions==4.12.2
185 changes: 0 additions & 185 deletions tests/transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,195 +2,10 @@
# Licensed under the MIT License.

import kornia.augmentation as K
import pytest
import torch
from torch import Tensor

from torchgeo.transforms import indices
from torchgeo.transforms.transforms import _ExtractPatches

# Kornia is very particular about its boxes:
#
# * Boxes must have shape B x 4 x 2
# * Defined in clockwise order: top-left, top-right, bottom-right, bottom-left
# * Coordinates must be in (x, y) order
#
# This seems to change with every release...


@pytest.fixture
def batch_gray() -> dict[str, Tensor]:
return {
'image': torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float),
'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}


@pytest.fixture
def batch_rgb() -> dict[str, Tensor]:
return {
'image': torch.tensor(
[
[
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
]
],
dtype=torch.float,
),
'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}


@pytest.fixture
def batch_multispectral() -> dict[str, Tensor]:
return {
'image': torch.tensor(
[
[
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
]
],
dtype=torch.float,
),
'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}


def assert_matching(output: dict[str, Tensor], expected: dict[str, Tensor]) -> None:
for key in expected:
err = f'output[{key}] != expected[{key}]'
equal = torch.allclose(output[key], expected[key])
assert equal, err


def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None:
expected = {
'image': torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float),
'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}
augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None)
output = augs(batch_gray)
assert_matching(output, expected)


def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None:
expected = {
'image': torch.tensor(
[
[
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
]
],
dtype=torch.float,
),
'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}
augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None)
output = augs(batch_rgb)
assert_matching(output, expected)


def test_augmentation_sequential_multispectral(
batch_multispectral: dict[str, Tensor],
) -> None:
expected = {
'image': torch.tensor(
[
[
[[7, 8, 9], [4, 5, 6], [1, 2, 3]],
[[7, 8, 9], [4, 5, 6], [1, 2, 3]],
[[7, 8, 9], [4, 5, 6], [1, 2, 3]],
[[7, 8, 9], [4, 5, 6], [1, 2, 3]],
[[7, 8, 9], [4, 5, 6], [1, 2, 3]],
]
],
dtype=torch.float,
),
'mask': torch.tensor([[[1, 1, 1], [0, 1, 1], [0, 0, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[0.0, 0.0, 1.0, 1.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}
augs = K.AugmentationSequential(K.RandomVerticalFlip(p=1.0), data_keys=None)
output = augs(batch_multispectral)
assert_matching(output, expected)


def test_augmentation_sequential_image_only(
batch_multispectral: dict[str, Tensor],
) -> None:
expected_image = torch.tensor(
[
[
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
]
],
dtype=torch.float,
)

augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=['image'])
aug_image = augs(batch_multispectral['image'])
assert torch.allclose(aug_image, expected_image)


def test_sequential_transforms_augmentations(
batch_multispectral: dict[str, Tensor],
) -> None:
expected = {
'image': torch.tensor(
[
[
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[3, 2, 1], [6, 5, 4], [9, 8, 7]],
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
[[0, 0, 0], [0, 0, 0], [0, 0, 0]],
]
],
dtype=torch.float,
),
'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long),
'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float),
'labels': torch.tensor([[0, 1]]),
}
train_transforms = K.AugmentationSequential(
indices.AppendNBR(index_nir=0, index_swir=0),
indices.AppendNDBI(index_swir=0, index_nir=0),
indices.AppendNDSI(index_green=0, index_swir=0),
indices.AppendNDVI(index_red=0, index_nir=0),
indices.AppendNDWI(index_green=0, index_nir=0),
K.RandomHorizontalFlip(p=1.0),
data_keys=None,
)
output = train_transforms(batch_multispectral)
assert_matching(output, expected)


def test_extract_patches() -> None:
b, c, h, w = 2, 3, 64, 64
Expand Down
9 changes: 6 additions & 3 deletions torchgeo/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,18 @@
import kornia.augmentation as K
import torch
from einops import rearrange
from kornia.augmentation import AugmentationSequential
from kornia.contrib import extract_tensor_patches
from kornia.geometry import crop_by_indices
from torch import Tensor
from typing_extensions import deprecated

# Only include import redirects
__all__ = ('AugmentationSequential',)

@deprecated('Use kornia.augmentation.AugmentationSequential instead')
class AugmentationSequential(K.AugmentationSequential):
"""Deprecated wrapper around kornia.augmentation.AugmentationSequential."""


# TODO: contribute these to Kornia and delete this file
class _RandomNCrop(K.GeometricAugmentationBase2D):
"""Take N random crops of a tensor."""

Expand Down

0 comments on commit 269e179

Please sign in to comment.