Skip to content

Commit

Permalink
Fix most tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Feb 21, 2025
1 parent 4f201fd commit 006cfa9
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 82 deletions.
145 changes: 65 additions & 80 deletions tests/trainers/test_instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,38 +2,61 @@
# Licensed under the MIT License.

import os
from pathlib import Path
from typing import Any, cast
from typing import Any

import pytest
import timm
import torch
import torch.nn as nn
from lightning.pytorch import Trainer
from pytest import MonkeyPatch
from torch.nn.modules import Module
from torchvision.models._api import WeightsEnum

from torchgeo.datamodules import MisconfigurationException, VHR10DataModule
from torchgeo.datasets import RGBBandsMissingError
from torchgeo.datasets import VHR10, RGBBandsMissingError
from torchgeo.main import main
from torchgeo.models import ResNet50_Weights
from torchgeo.trainers import InstanceSegmentationTask

# MAP metric requires pycocotools to be installed
pytest.importorskip('pycocotools')

class SegmentationTestModel(Module):
def __init__(self, in_channels: int = 3, classes: int = 3, **kwargs: Any) -> None:
super().__init__()
self.conv1 = nn.Conv2d(
in_channels=in_channels, out_channels=classes, kernel_size=1, padding=0
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return cast(torch.Tensor, self.conv1(x))
class PredictInstanceSegmentationDataModule(VHR10DataModule):
def setup(self, stage: str) -> None:
self.predict_dataset = VHR10(**self.kwargs)


def create_model(**kwargs: Any) -> Module:
return SegmentationTestModel(**kwargs)
# TODO: This is not even used yet
class InstanceSegmentationTestModel(Module):
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__()
self.fc = nn.Linear(1, 1)

def forward(self, images: Any, targets: Any = None) -> Any:
batch_size = len(images)
if self.training:
assert batch_size == len(targets)
# use the Linear layer to generate a tensor that has a gradient
return {
'loss_classifier': self.fc(torch.rand(1)),
'loss_box_reg': self.fc(torch.rand(1)),
'loss_objectness': self.fc(torch.rand(1)),
'loss_rpn_box_reg': self.fc(torch.rand(1)),
}
else: # eval mode
output = []
for i in range(batch_size):
boxes = torch.rand(10, 4)
# xmax, ymax must be larger than xmin, ymin
boxes[:, 2:] += 1
output.append(
{
'boxes': boxes,
'masks': torch.randint(2, images.shape[1:]),
'labels': torch.randint(2, (10,)),
'scores': torch.rand(10),
}
)
return output


def plot(*args: Any, **kwargs: Any) -> None:
Expand Down Expand Up @@ -74,74 +97,25 @@ def test_trainer(
except MisconfigurationException:
pass

@pytest.fixture
def weights(self) -> WeightsEnum:
return ResNet50_Weights.SENTINEL2_ALL_MOCO

@pytest.fixture
def mocked_weights(
self,
tmp_path: Path,
monkeypatch: MonkeyPatch,
weights: WeightsEnum,
load_state_dict_from_url: None,
) -> WeightsEnum:
path = tmp_path / f'{weights}.pth'
model = timm.create_model(
weights.meta['model'], in_chans=weights.meta['in_chans']
)
torch.save(model.state_dict(), path)
try:
monkeypatch.setattr(weights.value, 'url', str(path))
except AttributeError:
monkeypatch.setattr(weights, 'url', str(path))
return weights

def test_weight_file(self, checkpoint: str) -> None:
InstanceSegmentationTask(backbone='resnet50', weights=checkpoint, num_classes=6)

def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
InstanceSegmentationTask(
backbone=mocked_weights.meta['model'],
weights=mocked_weights,
in_channels=mocked_weights.meta['in_chans'],
)

def test_weight_str(self, mocked_weights: WeightsEnum) -> None:
InstanceSegmentationTask(
backbone=mocked_weights.meta['model'],
weights=str(mocked_weights),
in_channels=mocked_weights.meta['in_chans'],
)

@pytest.mark.slow
def test_weight_enum_download(self, weights: WeightsEnum) -> None:
InstanceSegmentationTask(
backbone=weights.meta['model'],
weights=weights,
in_channels=weights.meta['in_chans'],
)

@pytest.mark.slow
def test_weight_str_download(self, weights: WeightsEnum) -> None:
InstanceSegmentationTask(
backbone=weights.meta['model'],
weights=str(weights),
in_channels=weights.meta['in_chans'],
)

def test_invalid_model(self) -> None:
with pytest.raises(ValueError, match='Invalid model type'):
match = 'Invalid model type'
with pytest.raises(ValueError, match=match):
InstanceSegmentationTask(model='invalid_model')

def test_invalid_backbone(self) -> None:
match = 'Invalid backbone type'
with pytest.raises(ValueError, match=match):
InstanceSegmentationTask(backbone='invalid_backbone')

def test_pretrained_backbone(self) -> None:
InstanceSegmentationTask(backbone='resnet50', weights=True)

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(VHR10DataModule, 'plot', plot)
datamodule = VHR10DataModule(
root='tests/data/vhr10', batch_size=1, num_workers=0
)
model = InstanceSegmentationTask(
backbone='resnet50', in_channels=15, num_classes=6
)
model = InstanceSegmentationTask(in_channels=3, num_classes=11)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
Expand All @@ -155,9 +129,7 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
datamodule = VHR10DataModule(
root='tests/data/vhr10', batch_size=1, num_workers=0
)
model = InstanceSegmentationTask(
backbone='resnet50', in_channels=15, num_classes=6
)
model = InstanceSegmentationTask(in_channels=3, num_classes=11)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
Expand All @@ -166,8 +138,21 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictInstanceSegmentationDataModule(
root='tests/data/vhr10', batch_size=1, num_workers=0
)
model = InstanceSegmentationTask(num_classes=11)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)

def test_freeze_backbone(self) -> None:
task = InstanceSegmentationTask(backbone='resnet50', freeze_backbone=True)
task = InstanceSegmentationTask(freeze_backbone=True)
for param in task.model.backbone.parameters():
assert param.requires_grad is False

Expand Down
3 changes: 1 addition & 2 deletions torchgeo/trainers/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def __init__(
freeze_backbone: Freeze the backbone network to fine-tune the
decoder and segmentation head.
"""
self.weights = weights
super().__init__()

def configure_models(self) -> None:
Expand All @@ -72,7 +71,7 @@ def configure_models(self) -> None:

weights = None
weights_backbone = None
if self.weights:
if self.hparams['weights']:
weights = MaskRCNN_ResNet50_FPN_Weights.COCO_V1
weights_backbone = ResNet50_Weights.IMAGENET1K_V1

Expand Down

0 comments on commit 006cfa9

Please sign in to comment.