Skip to content

Commit

Permalink
Trainers: add Instance Segmentation Task (#2513)
Browse files Browse the repository at this point in the history
* Add files via upload

* Add files via upload

* Update instancesegmentation.py

* Update and rename instancesegmentation.py to instance_segmentation.py

* Update test_instancesegmentation.py

* Update instance_segmentation.py

* Update __init__.py

* Update instance_segmentation.py

* Update instance_segmentation.py

* Add files via upload

* Update test_instancesegmentation.py

* Update and rename test_instancesegmentation.py to test_trainer_instancesegmentation.py

* Update instance_segmentation.py

* Add files via upload

* Creato con Colab

* Creato con Colab

* Creato con Colab

* Update instance_segmentation.py

* Delete test_trainer.ipynb

* Delete test_trainer_instancesegmentation.py

* Update and rename test_instancesegmentation.py to test_instance_segmentation.py

* Update instance_segmentation.py

* Update test_instance_segmentation.py

* Update instance_segmentation.py

* Update instance_segmentation.py

* Update instance_segmentation.py run ruff

* Ruff

* dos2unix

* Add support for MSI, weights

* Update tests

* timm and torchvision are not compatible

* Finalize trainer code, simpler

* Update VHR10 tests

* Uniformity

* Fix most tests

* 100% coverage

* Fix datasets tests

* Fix weight tests

* Fix MSI support

* Fix parameter replacement

* Fix minimum tests

* Fix minimum tests

* Add all unpacked data

* Fix tests

* Undo FTW changes

* Undo FTW changes

* Undo FTW changes

* Remove dead code

* Remove dead code, match detection style

* Try newer torchmetrics

* Try newer torchmetrics

* Try newer torchmetrics

* More metrics

* Fix mypy

* Fix and test weights=True, num_classes!=91

---------

Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
  • Loading branch information
ariannasole23 and adamjstewart authored Feb 25, 2025
1 parent 5eb2a5e commit 464e45d
Show file tree
Hide file tree
Showing 12 changed files with 422 additions and 35 deletions.
12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,16 @@ dependencies = [
"rasterio>=1.3,!=1.4.0,!=1.4.1,!=1.4.2",
# rtree 1+ required for Python 3.10 wheels
"rtree>=1",
# segmentation-models-pytorch 0.2+ required for smp.losses module
"segmentation-models-pytorch>=0.2",
# segmentation-models-pytorch 0.3.3+ required for timm 0.8+ support
"segmentation-models-pytorch>=0.3.3",
# shapely 1.8+ required for Python 3.10 wheels
"shapely>=1.8",
# timm 0.4.12 required by segmentation-models-pytorch
"timm>=0.4.12",
# timm 0.8+ required for timm.models.adapt_input_conv, 0.9.2 required by SMP
"timm>=0.9.2",
# torch 1.13+ required by torchvision
"torch>=1.13",
# torchmetrics 0.10+ required for binary/multiclass/multilabel classification metrics
"torchmetrics>=0.10",
# torchmetrics 1.2+ required for average argument in mAP metric
"torchmetrics>=1.2",
# torchvision 0.14+ required for torchvision.models.swin_v2_b
"torchvision>=0.14",
# typing-extensions 4.5+ required for typing_extensions.deprecated
Expand Down
6 changes: 3 additions & 3 deletions requirements/min-reqs.old
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ pillow==8.4.0
pyproj==3.3.0
rasterio==1.3.0.post1
rtree==1.0.0
segmentation-models-pytorch==0.2.0
segmentation-models-pytorch==0.3.3
shapely==1.8.0
timm==0.4.12
timm==0.9.2
torch==1.13.0
torchmetrics==0.10.0
torchmetrics==1.2.0
torchvision==0.14.0
typing-extensions==4.5.0

Expand Down
14 changes: 14 additions & 0 deletions tests/conf/vhr10_ins_seg.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
model:
class_path: InstanceSegmentationTask
init_args:
model: 'mask-rcnn'
backbone: 'resnet50'
num_classes: 11
data:
class_path: VHR10DataModule
init_args:
batch_size: 1
num_workers: 0
patch_size: 4
dict_kwargs:
root: 'tests/data/vhr10'
4 changes: 1 addition & 3 deletions tests/conf/vhr10.yaml → tests/conf/vhr10_obj_det.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@ model:
class_path: ObjectDetectionTask
init_args:
model: 'faster-rcnn'
backbone: 'resnet50'
backbone: 'resnet18'
num_classes: 11
lr: 2.5e-5
patience: 10
data:
class_path: VHR10DataModule
init_args:
Expand Down
6 changes: 3 additions & 3 deletions tests/datasets/test_vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,10 @@ def test_plot(self, dataset: VHR10) -> None:
scores = [0.7, 0.3, 0.7]
for i in range(3):
x = dataset[i]
x['prediction_labels'] = x['label']
x['prediction_label'] = x['label']
x['prediction_bbox_xyxy'] = x['bbox_xyxy']
x['prediction_scores'] = torch.Tensor([scores[i]])
x['prediction_score'] = torch.Tensor([scores[i]])
if 'mask' in x:
x['prediction_masks'] = x['mask']
x['prediction_mask'] = x['mask']
dataset.plot(x, show_feats='masks')
plt.close()
2 changes: 1 addition & 1 deletion tests/trainers/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def plot(*args: Any, **kwargs: Any) -> None:


class TestObjectDetectionTask:
@pytest.mark.parametrize('name', ['nasa_marine_debris', 'vhr10'])
@pytest.mark.parametrize('name', ['nasa_marine_debris', 'vhr10_obj_det'])
@pytest.mark.parametrize('model_name', ['faster-rcnn', 'fcos', 'retinanet'])
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, model_name: str, fast_dev_run: bool
Expand Down
125 changes: 125 additions & 0 deletions tests/trainers/test_instance_segmentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import os
from typing import Any

import pytest
from lightning.pytorch import Trainer
from pytest import MonkeyPatch

from torchgeo.datamodules import MisconfigurationException, VHR10DataModule
from torchgeo.datasets import VHR10, RGBBandsMissingError
from torchgeo.main import main
from torchgeo.trainers import InstanceSegmentationTask

# mAP metric requires pycocotools to be installed
pytest.importorskip('pycocotools')


class PredictInstanceSegmentationDataModule(VHR10DataModule):
def setup(self, stage: str) -> None:
self.predict_dataset = VHR10(**self.kwargs)


def plot(*args: Any, **kwargs: Any) -> None:
return None


def plot_missing_bands(*args: Any, **kwargs: Any) -> None:
raise RGBBandsMissingError()


class TestInstanceSegmentationTask:
@pytest.mark.parametrize('name', ['vhr10_ins_seg'])
def test_trainer(
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool
) -> None:
config = os.path.join('tests', 'conf', name + '.yaml')

args = [
'--config',
config,
'--trainer.accelerator',
'cpu',
'--trainer.fast_dev_run',
str(fast_dev_run),
'--trainer.max_epochs',
'1',
'--trainer.log_every_n_steps',
'1',
]

main(['fit', *args])
try:
main(['test', *args])
except MisconfigurationException:
pass
try:
main(['predict', *args])
except MisconfigurationException:
pass

def test_invalid_model(self) -> None:
match = 'Invalid model type'
with pytest.raises(ValueError, match=match):
InstanceSegmentationTask(model='invalid_model')

def test_invalid_backbone(self) -> None:
match = 'Invalid backbone type'
with pytest.raises(ValueError, match=match):
InstanceSegmentationTask(backbone='invalid_backbone')

def test_weights(self) -> None:
InstanceSegmentationTask(weights=True, num_classes=3)
InstanceSegmentationTask(weights=True, num_classes=91)

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(VHR10DataModule, 'plot', plot)
datamodule = VHR10DataModule(
root='tests/data/vhr10', batch_size=1, num_workers=0
)
model = InstanceSegmentationTask(in_channels=3, num_classes=11)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(VHR10DataModule, 'plot', plot_missing_bands)
datamodule = VHR10DataModule(
root='tests/data/vhr10', batch_size=1, num_workers=0
)
model = InstanceSegmentationTask(in_channels=3, num_classes=11)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.validate(model=model, datamodule=datamodule)

def test_predict(self, fast_dev_run: bool) -> None:
datamodule = PredictInstanceSegmentationDataModule(
root='tests/data/vhr10', batch_size=1, num_workers=0
)
model = InstanceSegmentationTask(num_classes=11)
trainer = Trainer(
accelerator='cpu',
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
max_epochs=1,
)
trainer.predict(model=model, datamodule=datamodule)

def test_freeze_backbone(self) -> None:
task = InstanceSegmentationTask(freeze_backbone=True)
for param in task.model.backbone.parameters():
assert param.requires_grad is False

for head in ['rpn', 'roi_heads']:
for param in getattr(task.model, head).parameters():
assert param.requires_grad is True
20 changes: 10 additions & 10 deletions torchgeo/datasets/vhr10.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
sample = self.coco_convert(sample)
sample['class'] = sample['label']['labels']
sample['bbox_xyxy'] = sample['label']['boxes']
sample['mask'] = sample['label']['masks'].float()
sample['mask'] = sample['label']['masks']
sample['label'] = sample.pop('class')

if self.transforms is not None:
Expand Down Expand Up @@ -408,21 +408,21 @@ def plot(
n_gt = len(boxes)

ncols = 1
show_predictions = 'prediction_labels' in sample
show_predictions = 'prediction_label' in sample

if show_predictions:
show_pred_boxes = False
show_pred_masks = False
prediction_labels = sample['prediction_labels'].numpy()
prediction_scores = sample['prediction_scores'].numpy()
prediction_label = sample['prediction_label'].numpy()
prediction_score = sample['prediction_score'].numpy()
if 'prediction_bbox_xyxy' in sample:
prediction_bbox_xyxy = sample['prediction_bbox_xyxy'].numpy()
show_pred_boxes = True
if 'prediction_masks' in sample:
prediction_masks = sample['prediction_masks'].numpy()
if 'prediction_mask' in sample:
prediction_mask = sample['prediction_mask'].numpy()
show_pred_masks = True

n_pred = len(prediction_labels)
n_pred = len(prediction_label)
ncols += 1

# Display image
Expand Down Expand Up @@ -475,11 +475,11 @@ def plot(
axs[0, 1].imshow(image)
axs[0, 1].axis('off')
for i in range(n_pred):
score = prediction_scores[i]
score = prediction_score[i]
if score < 0.5:
continue

class_num = prediction_labels[i]
class_num = prediction_label[i]
color = cm(class_num / len(self.categories))

if show_pred_boxes:
Expand Down Expand Up @@ -511,7 +511,7 @@ def plot(

# Add masks
if show_pred_masks:
mask = prediction_masks[i]
mask = prediction_mask[i]
contours = skimage.measure.find_contours(mask, 0.5)
for verts in contours:
verts = np.fliplr(verts)
Expand Down
8 changes: 1 addition & 7 deletions torchgeo/models/fcsiam.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,7 @@ def __init__(
)
encoder_out_channels = [c * 2 for c in self.encoder.out_channels[1:]]
encoder_out_channels.insert(0, self.encoder.out_channels[0])
try:
# smp 0.3+
UnetDecoder = smp.decoders.unet.decoder.UnetDecoder
except AttributeError:
# smp 0.2
UnetDecoder = smp.unet.decoder.UnetDecoder
self.decoder = UnetDecoder(
self.decoder = smp.decoders.unet.decoder.UnetDecoder(
encoder_channels=encoder_out_channels,
decoder_channels=decoder_channels,
n_blocks=encoder_depth,
Expand Down
2 changes: 2 additions & 0 deletions torchgeo/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .byol import BYOLTask
from .classification import ClassificationTask, MultiLabelClassificationTask
from .detection import ObjectDetectionTask
from .instance_segmentation import InstanceSegmentationTask
from .iobench import IOBenchTask
from .moco import MoCoTask
from .regression import PixelwiseRegressionTask, RegressionTask
Expand All @@ -18,6 +19,7 @@
'BaseTask',
'ClassificationTask',
'IOBenchTask',
'InstanceSegmentationTask',
'MoCoTask',
'MultiLabelClassificationTask',
'ObjectDetectionTask',
Expand Down
4 changes: 2 additions & 2 deletions torchgeo/trainers/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,8 @@ def validation_step(
):
datamodule = self.trainer.datamodule
batch['prediction_bbox_xyxy'] = [b['boxes'].cpu() for b in y_hat]
batch['prediction_labels'] = [b['labels'].cpu() for b in y_hat]
batch['prediction_scores'] = [b['scores'].cpu() for b in y_hat]
batch['prediction_label'] = [b['labels'].cpu() for b in y_hat]
batch['prediction_score'] = [b['scores'].cpu() for b in y_hat]
batch['image'] = batch['image'].cpu()
sample = unbind_samples(batch)[0]
# Convert image to uint8 for plotting
Expand Down
Loading

0 comments on commit 464e45d

Please sign in to comment.