From 006cfa97ff0c9b72ae75c0bf9faa8bf7ed3404de Mon Sep 17 00:00:00 2001 From: "Adam J. Stewart" Date: Fri, 21 Feb 2025 15:29:22 +0100 Subject: [PATCH] Fix most tests --- tests/trainers/test_instance_segmentation.py | 145 +++++++++---------- torchgeo/trainers/instance_segmentation.py | 3 +- 2 files changed, 66 insertions(+), 82 deletions(-) diff --git a/tests/trainers/test_instance_segmentation.py b/tests/trainers/test_instance_segmentation.py index dfb68907c0f..eefb0f0d935 100644 --- a/tests/trainers/test_instance_segmentation.py +++ b/tests/trainers/test_instance_segmentation.py @@ -2,38 +2,61 @@ # Licensed under the MIT License. import os -from pathlib import Path -from typing import Any, cast +from typing import Any import pytest -import timm import torch import torch.nn as nn from lightning.pytorch import Trainer from pytest import MonkeyPatch from torch.nn.modules import Module -from torchvision.models._api import WeightsEnum from torchgeo.datamodules import MisconfigurationException, VHR10DataModule -from torchgeo.datasets import RGBBandsMissingError +from torchgeo.datasets import VHR10, RGBBandsMissingError from torchgeo.main import main -from torchgeo.models import ResNet50_Weights from torchgeo.trainers import InstanceSegmentationTask +# MAP metric requires pycocotools to be installed +pytest.importorskip('pycocotools') -class SegmentationTestModel(Module): - def __init__(self, in_channels: int = 3, classes: int = 3, **kwargs: Any) -> None: - super().__init__() - self.conv1 = nn.Conv2d( - in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0 - ) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return cast(torch.Tensor, self.conv1(x)) +class PredictInstanceSegmentationDataModule(VHR10DataModule): + def setup(self, stage: str) -> None: + self.predict_dataset = VHR10(**self.kwargs) -def create_model(**kwargs: Any) -> Module: - return SegmentationTestModel(**kwargs) +# TODO: This is not even used yet +class InstanceSegmentationTestModel(Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + self.fc = nn.Linear(1, 1) + + def forward(self, images: Any, targets: Any = None) -> Any: + batch_size = len(images) + if self.training: + assert batch_size == len(targets) + # use the Linear layer to generate a tensor that has a gradient + return { + 'loss_classifier': self.fc(torch.rand(1)), + 'loss_box_reg': self.fc(torch.rand(1)), + 'loss_objectness': self.fc(torch.rand(1)), + 'loss_rpn_box_reg': self.fc(torch.rand(1)), + } + else: # eval mode + output = [] + for i in range(batch_size): + boxes = torch.rand(10, 4) + # xmax, ymax must be larger than xmin, ymin + boxes[:, 2:] += 1 + output.append( + { + 'boxes': boxes, + 'masks': torch.randint(2, images.shape[1:]), + 'labels': torch.randint(2, (10,)), + 'scores': torch.rand(10), + } + ) + return output def plot(*args: Any, **kwargs: Any) -> None: @@ -74,74 +97,25 @@ def test_trainer( except MisconfigurationException: pass - @pytest.fixture - def weights(self) -> WeightsEnum: - return ResNet50_Weights.SENTINEL2_ALL_MOCO - - @pytest.fixture - def mocked_weights( - self, - tmp_path: Path, - monkeypatch: MonkeyPatch, - weights: WeightsEnum, - load_state_dict_from_url: None, - ) -> WeightsEnum: - path = tmp_path / f'{weights}.pth' - model = timm.create_model( - weights.meta['model'], in_chans=weights.meta['in_chans'] - ) - torch.save(model.state_dict(), path) - try: - monkeypatch.setattr(weights.value, 'url', str(path)) - except AttributeError: - monkeypatch.setattr(weights, 'url', str(path)) - return weights - - def test_weight_file(self, checkpoint: str) -> None: - InstanceSegmentationTask(backbone='resnet50', weights=checkpoint, num_classes=6) - - def test_weight_enum(self, mocked_weights: WeightsEnum) -> None: - InstanceSegmentationTask( - backbone=mocked_weights.meta['model'], - weights=mocked_weights, - in_channels=mocked_weights.meta['in_chans'], - ) - - def test_weight_str(self, mocked_weights: WeightsEnum) -> None: - InstanceSegmentationTask( - backbone=mocked_weights.meta['model'], - weights=str(mocked_weights), - in_channels=mocked_weights.meta['in_chans'], - ) - - @pytest.mark.slow - def test_weight_enum_download(self, weights: WeightsEnum) -> None: - InstanceSegmentationTask( - backbone=weights.meta['model'], - weights=weights, - in_channels=weights.meta['in_chans'], - ) - - @pytest.mark.slow - def test_weight_str_download(self, weights: WeightsEnum) -> None: - InstanceSegmentationTask( - backbone=weights.meta['model'], - weights=str(weights), - in_channels=weights.meta['in_chans'], - ) - def test_invalid_model(self) -> None: - with pytest.raises(ValueError, match='Invalid model type'): + match = 'Invalid model type' + with pytest.raises(ValueError, match=match): InstanceSegmentationTask(model='invalid_model') + def test_invalid_backbone(self) -> None: + match = 'Invalid backbone type' + with pytest.raises(ValueError, match=match): + InstanceSegmentationTask(backbone='invalid_backbone') + + def test_pretrained_backbone(self) -> None: + InstanceSegmentationTask(backbone='resnet50', weights=True) + def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: monkeypatch.setattr(VHR10DataModule, 'plot', plot) datamodule = VHR10DataModule( root='tests/data/vhr10', batch_size=1, num_workers=0 ) - model = InstanceSegmentationTask( - backbone='resnet50', in_channels=15, num_classes=6 - ) + model = InstanceSegmentationTask(in_channels=3, num_classes=11) trainer = Trainer( accelerator='cpu', fast_dev_run=fast_dev_run, @@ -155,9 +129,7 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: datamodule = VHR10DataModule( root='tests/data/vhr10', batch_size=1, num_workers=0 ) - model = InstanceSegmentationTask( - backbone='resnet50', in_channels=15, num_classes=6 - ) + model = InstanceSegmentationTask(in_channels=3, num_classes=11) trainer = Trainer( accelerator='cpu', fast_dev_run=fast_dev_run, @@ -166,8 +138,21 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: ) trainer.validate(model=model, datamodule=datamodule) + def test_predict(self, fast_dev_run: bool) -> None: + datamodule = PredictInstanceSegmentationDataModule( + root='tests/data/vhr10', batch_size=1, num_workers=0 + ) + model = InstanceSegmentationTask(num_classes=11) + trainer = Trainer( + accelerator='cpu', + fast_dev_run=fast_dev_run, + log_every_n_steps=1, + max_epochs=1, + ) + trainer.predict(model=model, datamodule=datamodule) + def test_freeze_backbone(self) -> None: - task = InstanceSegmentationTask(backbone='resnet50', freeze_backbone=True) + task = InstanceSegmentationTask(freeze_backbone=True) for param in task.model.backbone.parameters(): assert param.requires_grad is False diff --git a/torchgeo/trainers/instance_segmentation.py b/torchgeo/trainers/instance_segmentation.py index 08e1c8e3764..86033dc4b58 100644 --- a/torchgeo/trainers/instance_segmentation.py +++ b/torchgeo/trainers/instance_segmentation.py @@ -56,7 +56,6 @@ def __init__( freeze_backbone: Freeze the backbone network to fine-tune the decoder and segmentation head. """ - self.weights = weights super().__init__() def configure_models(self) -> None: @@ -72,7 +71,7 @@ def configure_models(self) -> None: weights = None weights_backbone = None - if self.weights: + if self.hparams['weights']: weights = MaskRCNN_ResNet50_FPN_Weights.COCO_V1 weights_backbone = ResNet50_Weights.IMAGENET1K_V1