Skip to content

Commit

Permalink
Fix weight tests
Browse files Browse the repository at this point in the history
  • Loading branch information
adamjstewart committed Feb 21, 2025
1 parent 1d80adc commit d8e8fe6
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 2 additions & 2 deletions tests/trainers/test_instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def test_invalid_backbone(self) -> None:
with pytest.raises(ValueError, match=match):
InstanceSegmentationTask(backbone='invalid_backbone')

def test_pretrained_backbone(self) -> None:
InstanceSegmentationTask(backbone='resnet50', weights=True)
def test_weights(self) -> None:
InstanceSegmentationTask(weights=True, num_classes=91)

def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
monkeypatch.setattr(VHR10DataModule, 'plot', plot)
Expand Down
4 changes: 3 additions & 1 deletion torchgeo/trainers/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,9 @@ def configure_models(self) -> None:
weights_backbone = None
if self.hparams['weights']:
weights = MaskRCNN_ResNet50_FPN_Weights.COCO_V1
weights_backbone = ResNet50_Weights.IMAGENET1K_V1
# TODO: drop last layer of weights
if num_classes == 91:
weights_backbone = ResNet50_Weights.IMAGENET1K_V1

# Create model
if model == 'mask_rcnn':
Expand Down

0 comments on commit d8e8fe6

Please sign in to comment.