-
Notifications
You must be signed in to change notification settings - Fork 405
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Trainers: add Instance Segmentation Task (#2513)
* Add files via upload * Add files via upload * Update instancesegmentation.py * Update and rename instancesegmentation.py to instance_segmentation.py * Update test_instancesegmentation.py * Update instance_segmentation.py * Update __init__.py * Update instance_segmentation.py * Update instance_segmentation.py * Add files via upload * Update test_instancesegmentation.py * Update and rename test_instancesegmentation.py to test_trainer_instancesegmentation.py * Update instance_segmentation.py * Add files via upload * Creato con Colab * Creato con Colab * Creato con Colab * Update instance_segmentation.py * Delete test_trainer.ipynb * Delete test_trainer_instancesegmentation.py * Update and rename test_instancesegmentation.py to test_instance_segmentation.py * Update instance_segmentation.py * Update test_instance_segmentation.py * Update instance_segmentation.py * Update instance_segmentation.py * Update instance_segmentation.py run ruff * Ruff * dos2unix * Add support for MSI, weights * Update tests * timm and torchvision are not compatible * Finalize trainer code, simpler * Update VHR10 tests * Uniformity * Fix most tests * 100% coverage * Fix datasets tests * Fix weight tests * Fix MSI support * Fix parameter replacement * Fix minimum tests * Fix minimum tests * Add all unpacked data * Fix tests * Undo FTW changes * Undo FTW changes * Undo FTW changes * Remove dead code * Remove dead code, match detection style * Try newer torchmetrics * Try newer torchmetrics * Try newer torchmetrics * More metrics * Fix mypy * Fix and test weights=True, num_classes!=91 --------- Co-authored-by: Adam J. Stewart <ajstewart426@gmail.com>
- Loading branch information
1 parent
5eb2a5e
commit 464e45d
Showing
12 changed files
with
422 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
model: | ||
class_path: InstanceSegmentationTask | ||
init_args: | ||
model: 'mask-rcnn' | ||
backbone: 'resnet50' | ||
num_classes: 11 | ||
data: | ||
class_path: VHR10DataModule | ||
init_args: | ||
batch_size: 1 | ||
num_workers: 0 | ||
patch_size: 4 | ||
dict_kwargs: | ||
root: 'tests/data/vhr10' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. | ||
|
||
import os | ||
from typing import Any | ||
|
||
import pytest | ||
from lightning.pytorch import Trainer | ||
from pytest import MonkeyPatch | ||
|
||
from torchgeo.datamodules import MisconfigurationException, VHR10DataModule | ||
from torchgeo.datasets import VHR10, RGBBandsMissingError | ||
from torchgeo.main import main | ||
from torchgeo.trainers import InstanceSegmentationTask | ||
|
||
# mAP metric requires pycocotools to be installed | ||
pytest.importorskip('pycocotools') | ||
|
||
|
||
class PredictInstanceSegmentationDataModule(VHR10DataModule): | ||
def setup(self, stage: str) -> None: | ||
self.predict_dataset = VHR10(**self.kwargs) | ||
|
||
|
||
def plot(*args: Any, **kwargs: Any) -> None: | ||
return None | ||
|
||
|
||
def plot_missing_bands(*args: Any, **kwargs: Any) -> None: | ||
raise RGBBandsMissingError() | ||
|
||
|
||
class TestInstanceSegmentationTask: | ||
@pytest.mark.parametrize('name', ['vhr10_ins_seg']) | ||
def test_trainer( | ||
self, monkeypatch: MonkeyPatch, name: str, fast_dev_run: bool | ||
) -> None: | ||
config = os.path.join('tests', 'conf', name + '.yaml') | ||
|
||
args = [ | ||
'--config', | ||
config, | ||
'--trainer.accelerator', | ||
'cpu', | ||
'--trainer.fast_dev_run', | ||
str(fast_dev_run), | ||
'--trainer.max_epochs', | ||
'1', | ||
'--trainer.log_every_n_steps', | ||
'1', | ||
] | ||
|
||
main(['fit', *args]) | ||
try: | ||
main(['test', *args]) | ||
except MisconfigurationException: | ||
pass | ||
try: | ||
main(['predict', *args]) | ||
except MisconfigurationException: | ||
pass | ||
|
||
def test_invalid_model(self) -> None: | ||
match = 'Invalid model type' | ||
with pytest.raises(ValueError, match=match): | ||
InstanceSegmentationTask(model='invalid_model') | ||
|
||
def test_invalid_backbone(self) -> None: | ||
match = 'Invalid backbone type' | ||
with pytest.raises(ValueError, match=match): | ||
InstanceSegmentationTask(backbone='invalid_backbone') | ||
|
||
def test_weights(self) -> None: | ||
InstanceSegmentationTask(weights=True, num_classes=3) | ||
InstanceSegmentationTask(weights=True, num_classes=91) | ||
|
||
def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: | ||
monkeypatch.setattr(VHR10DataModule, 'plot', plot) | ||
datamodule = VHR10DataModule( | ||
root='tests/data/vhr10', batch_size=1, num_workers=0 | ||
) | ||
model = InstanceSegmentationTask(in_channels=3, num_classes=11) | ||
trainer = Trainer( | ||
accelerator='cpu', | ||
fast_dev_run=fast_dev_run, | ||
log_every_n_steps=1, | ||
max_epochs=1, | ||
) | ||
trainer.validate(model=model, datamodule=datamodule) | ||
|
||
def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None: | ||
monkeypatch.setattr(VHR10DataModule, 'plot', plot_missing_bands) | ||
datamodule = VHR10DataModule( | ||
root='tests/data/vhr10', batch_size=1, num_workers=0 | ||
) | ||
model = InstanceSegmentationTask(in_channels=3, num_classes=11) | ||
trainer = Trainer( | ||
accelerator='cpu', | ||
fast_dev_run=fast_dev_run, | ||
log_every_n_steps=1, | ||
max_epochs=1, | ||
) | ||
trainer.validate(model=model, datamodule=datamodule) | ||
|
||
def test_predict(self, fast_dev_run: bool) -> None: | ||
datamodule = PredictInstanceSegmentationDataModule( | ||
root='tests/data/vhr10', batch_size=1, num_workers=0 | ||
) | ||
model = InstanceSegmentationTask(num_classes=11) | ||
trainer = Trainer( | ||
accelerator='cpu', | ||
fast_dev_run=fast_dev_run, | ||
log_every_n_steps=1, | ||
max_epochs=1, | ||
) | ||
trainer.predict(model=model, datamodule=datamodule) | ||
|
||
def test_freeze_backbone(self) -> None: | ||
task = InstanceSegmentationTask(freeze_backbone=True) | ||
for param in task.model.backbone.parameters(): | ||
assert param.requires_grad is False | ||
|
||
for head in ['rpn', 'roi_heads']: | ||
for param in getattr(task.model, head).parameters(): | ||
assert param.requires_grad is True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.