From 269e179f27ce7cde47f321401fd8624b3f95150d Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Wed, 12 Feb 2025 13:48:34 +0100 Subject: [PATCH] Deprecate AugmentationSequential wrapper (#2396) * Deprecate AugmentationSequential wrapper * Test deprecation * Remove tests for deprecated wrapper --- pyproject.toml | 3 + requirements/min-reqs.old | 1 + requirements/required.txt | 1 + tests/transforms/test_transforms.py | 185 ---------------------------- torchgeo/transforms/transforms.py | 9 +- 5 files changed, 11 insertions(+), 188 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index df0c51cd2b7..2a4611fd450 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,6 +79,9 @@ dependencies = [ "torchmetrics>=0.10", # torchvision 0.14+ required for torchvision.models.swin_v2_b "torchvision>=0.14", + # typing-extensions 4.5+ required for typing_extensions.deprecated + # can be removed once Python 3.13 is minimum supported version + "typing-extensions>=4.5", ] dynamic = ["version"] diff --git a/requirements/min-reqs.old b/requirements/min-reqs.old index e1058142ac6..0f7748a718a 100644 --- a/requirements/min-reqs.old +++ b/requirements/min-reqs.old @@ -20,6 +20,7 @@ timm==0.4.12 torch==1.13.0 torchmetrics==0.10.0 torchvision==0.14.0 +typing-extensions==4.5.0 # datasets h5py==3.6.0 diff --git a/requirements/required.txt b/requirements/required.txt index d8d966e13d0..09de7cacf2d 100644 --- a/requirements/required.txt +++ b/requirements/required.txt @@ -20,3 +20,4 @@ timm==1.0.14 torch==2.6.0 torchmetrics==1.6.1 torchvision==0.21.0 +typing-extensions==4.12.2 diff --git a/tests/transforms/test_transforms.py b/tests/transforms/test_transforms.py index 8b47765ffa5..7352dd2aba7 100644 --- a/tests/transforms/test_transforms.py +++ b/tests/transforms/test_transforms.py @@ -2,195 +2,10 @@ # Licensed under the MIT License. import kornia.augmentation as K -import pytest import torch -from torch import Tensor -from torchgeo.transforms import indices from torchgeo.transforms.transforms import _ExtractPatches -# Kornia is very particular about its boxes: -# -# * Boxes must have shape B x 4 x 2 -# * Defined in clockwise order: top-left, top-right, bottom-right, bottom-left -# * Coordinates must be in (x, y) order -# -# This seems to change with every release... - - -@pytest.fixture -def batch_gray() -> dict[str, Tensor]: - return { - 'image': torch.tensor([[[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]], dtype=torch.float), - 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - 'bbox_xyxy': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), - 'labels': torch.tensor([[0, 1]]), - } - - -@pytest.fixture -def batch_rgb() -> dict[str, Tensor]: - return { - 'image': torch.tensor( - [ - [ - [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - ] - ], - dtype=torch.float, - ), - 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - 'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float), - 'labels': torch.tensor([[0, 1]]), - } - - -@pytest.fixture -def batch_multispectral() -> dict[str, Tensor]: - return { - 'image': torch.tensor( - [ - [ - [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - [[1, 2, 3], [4, 5, 6], [7, 8, 9]], - ] - ], - dtype=torch.float, - ), - 'mask': torch.tensor([[[0, 0, 1], [0, 1, 1], [1, 1, 1]]], dtype=torch.long), - 'bbox_xyxy': torch.tensor([[0.0, 1.0, 1.0, 2.0]], dtype=torch.float), - 'labels': torch.tensor([[0, 1]]), - } - - -def assert_matching(output: dict[str, Tensor], expected: dict[str, Tensor]) -> None: - for key in expected: - err = f'output[{key}] != expected[{key}]' - equal = torch.allclose(output[key], expected[key]) - assert equal, err - - -def test_augmentation_sequential_gray(batch_gray: dict[str, Tensor]) -> None: - expected = { - 'image': torch.tensor([[[[3, 2, 1], [6, 5, 4], [9, 8, 7]]]], dtype=torch.float), - 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - 'bbox_xyxy': torch.tensor([[0.0, 0.0, 2.0, 2.0]], dtype=torch.float), - 'labels': torch.tensor([[0, 1]]), - } - augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None) - output = augs(batch_gray) - assert_matching(output, expected) - - -def test_augmentation_sequential_rgb(batch_rgb: dict[str, Tensor]) -> None: - expected = { - 'image': torch.tensor( - [ - [ - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - ] - ], - dtype=torch.float, - ), - 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - 'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float), - 'labels': torch.tensor([[0, 1]]), - } - augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=None) - output = augs(batch_rgb) - assert_matching(output, expected) - - -def test_augmentation_sequential_multispectral( - batch_multispectral: dict[str, Tensor], -) -> None: - expected = { - 'image': torch.tensor( - [ - [ - [[7, 8, 9], [4, 5, 6], [1, 2, 3]], - [[7, 8, 9], [4, 5, 6], [1, 2, 3]], - [[7, 8, 9], [4, 5, 6], [1, 2, 3]], - [[7, 8, 9], [4, 5, 6], [1, 2, 3]], - [[7, 8, 9], [4, 5, 6], [1, 2, 3]], - ] - ], - dtype=torch.float, - ), - 'mask': torch.tensor([[[1, 1, 1], [0, 1, 1], [0, 0, 1]]], dtype=torch.long), - 'bbox_xyxy': torch.tensor([[0.0, 0.0, 1.0, 1.0]], dtype=torch.float), - 'labels': torch.tensor([[0, 1]]), - } - augs = K.AugmentationSequential(K.RandomVerticalFlip(p=1.0), data_keys=None) - output = augs(batch_multispectral) - assert_matching(output, expected) - - -def test_augmentation_sequential_image_only( - batch_multispectral: dict[str, Tensor], -) -> None: - expected_image = torch.tensor( - [ - [ - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - ] - ], - dtype=torch.float, - ) - - augs = K.AugmentationSequential(K.RandomHorizontalFlip(p=1.0), data_keys=['image']) - aug_image = augs(batch_multispectral['image']) - assert torch.allclose(aug_image, expected_image) - - -def test_sequential_transforms_augmentations( - batch_multispectral: dict[str, Tensor], -) -> None: - expected = { - 'image': torch.tensor( - [ - [ - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[3, 2, 1], [6, 5, 4], [9, 8, 7]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - [[0, 0, 0], [0, 0, 0], [0, 0, 0]], - ] - ], - dtype=torch.float, - ), - 'mask': torch.tensor([[[1, 0, 0], [1, 1, 0], [1, 1, 1]]], dtype=torch.long), - 'bbox_xyxy': torch.tensor([[1.0, 1.0, 2.0, 2.0]], dtype=torch.float), - 'labels': torch.tensor([[0, 1]]), - } - train_transforms = K.AugmentationSequential( - indices.AppendNBR(index_nir=0, index_swir=0), - indices.AppendNDBI(index_swir=0, index_nir=0), - indices.AppendNDSI(index_green=0, index_swir=0), - indices.AppendNDVI(index_red=0, index_nir=0), - indices.AppendNDWI(index_green=0, index_nir=0), - K.RandomHorizontalFlip(p=1.0), - data_keys=None, - ) - output = train_transforms(batch_multispectral) - assert_matching(output, expected) - def test_extract_patches() -> None: b, c, h, w = 2, 3, 64, 64 diff --git a/torchgeo/transforms/transforms.py b/torchgeo/transforms/transforms.py index 60e35aa4a88..0aedee5e330 100644 --- a/torchgeo/transforms/transforms.py +++ b/torchgeo/transforms/transforms.py @@ -8,15 +8,18 @@ import kornia.augmentation as K import torch from einops import rearrange -from kornia.augmentation import AugmentationSequential from kornia.contrib import extract_tensor_patches from kornia.geometry import crop_by_indices from torch import Tensor +from typing_extensions import deprecated -# Only include import redirects -__all__ = ('AugmentationSequential',) +@deprecated('Use kornia.augmentation.AugmentationSequential instead') +class AugmentationSequential(K.AugmentationSequential): + """Deprecated wrapper around kornia.augmentation.AugmentationSequential.""" + +# TODO: contribute these to Kornia and delete this file class _RandomNCrop(K.GeometricAugmentationBase2D): """Take N random crops of a tensor."""