From 60ad02378f54ef91ce72fe21f0756242a94a369d Mon Sep 17 00:00:00 2001 From: Caleb Robinson Date: Sun, 26 Feb 2023 13:36:46 -0800 Subject: [PATCH] Speed up detection tests (#1148) --- tests/trainers/test_detection.py | 43 ++++++++++++++++++++++++++++++++ torchgeo/trainers/detection.py | 11 ++++---- 2 files changed, 49 insertions(+), 5 deletions(-) diff --git a/tests/trainers/test_detection.py b/tests/trainers/test_detection.py index 26e66ac247f..5bbf20548af 100644 --- a/tests/trainers/test_detection.py +++ b/tests/trainers/test_detection.py @@ -5,9 +5,13 @@ from typing import Any, Dict, Type, cast import pytest +import torch +import torch.nn as nn +import torchvision.models.detection from _pytest.monkeypatch import MonkeyPatch from omegaconf import OmegaConf from pytorch_lightning import LightningDataModule, Trainer +from torch.nn.modules import Module from torchgeo.datamodules import MisconfigurationException, NASAMarineDebrisDataModule from torchgeo.datasets import NASAMarineDebris @@ -19,6 +23,35 @@ def setup(self, stage: str) -> None: self.predict_dataset = NASAMarineDebris(**self.kwargs) +class ObjectDetectionTestModel(Module): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + self.fc = nn.Linear(1, 1) + + def forward(self, images: Any, targets: Any = None) -> Any: + batch_size = len(images) + if self.training: + assert batch_size == len(targets) + # use the Linear layer to generate a tensor that has a gradient + return { + "loss_classifier": self.fc(torch.rand(1)), + "loss_box_reg": self.fc(torch.rand(1)), + "loss_objectness": self.fc(torch.rand(1)), + "loss_rpn_box_reg": self.fc(torch.rand(1)), + } + else: # eval mode + output = [] + for i in range(batch_size): + output.append( + { + "boxes": torch.rand(10, 4), + "labels": torch.randint(0, 2, (10,)), + "scores": torch.rand(10), + } + ) + return output + + def plot(*args: Any, **kwargs: Any) -> None: raise ValueError @@ -30,6 +63,7 @@ class TestObjectDetectionTask: @pytest.mark.parametrize("model_name", ["faster-rcnn", "fcos", "retinanet"]) def test_trainer( self, + monkeypatch: MonkeyPatch, name: str, classname: Type[LightningDataModule], model_name: str, @@ -44,6 +78,15 @@ def test_trainer( datamodule = classname(**datamodule_kwargs) # Instantiate model + monkeypatch.setattr( + torchvision.models.detection, "FasterRCNN", ObjectDetectionTestModel + ) + monkeypatch.setattr( + torchvision.models.detection, "FCOS", ObjectDetectionTestModel + ) + monkeypatch.setattr( + torchvision.models.detection, "RetinaNet", ObjectDetectionTestModel + ) model_kwargs = conf_dict["module"] model_kwargs["model"] = model_name model = ObjectDetectionTask(**model_kwargs) diff --git a/torchgeo/trainers/detection.py b/torchgeo/trainers/detection.py index 1eafdf1b47b..4e1ece02c83 100644 --- a/torchgeo/trainers/detection.py +++ b/torchgeo/trainers/detection.py @@ -9,11 +9,11 @@ import matplotlib.pyplot as plt import pytorch_lightning as pl import torch +import torchvision.models.detection from torch import Tensor from torch.optim.lr_scheduler import ReduceLROnPlateau from torchmetrics.detection.mean_ap import MeanAveragePrecision from torchvision.models import resnet as R -from torchvision.models.detection import FCOS, FasterRCNN, RetinaNet from torchvision.models.detection.backbone_utils import resnet_fpn_backbone from torchvision.models.detection.retinanet import RetinaNetHead from torchvision.models.detection.rpn import AnchorGenerator @@ -95,7 +95,7 @@ def config_task(self) -> None: roi_pooler = MultiScaleRoIAlign( featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2 ) - self.model = FasterRCNN( + self.model = torchvision.models.detection.FasterRCNN( backbone, num_classes, rpn_anchor_generator=anchor_generator, @@ -113,8 +113,9 @@ def config_task(self) -> None: aspect_ratios=((1.0,), (1.0,), (1.0,), (1.0,), (1.0,), (1.0,)), ) - self.model = FCOS(backbone, num_classes, anchor_generator=anchor_generator) - + self.model = torchvision.models.detection.FCOS( + backbone, num_classes, anchor_generator=anchor_generator + ) elif self.hyperparams["model"] == "retinanet": kwargs["extra_blocks"] = feature_pyramid_network.LastLevelP6P7( latent_dim, 256 @@ -139,7 +140,7 @@ def config_task(self) -> None: norm_layer=partial(torch.nn.GroupNorm, 32), ) - self.model = RetinaNet( + self.model = torchvision.models.detection.RetinaNet( backbone, num_classes, anchor_generator=anchor_generator, head=head ) else: