Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainers: add Instance Segmentation Task #2513

Merged
merged 58 commits into from
Feb 25, 2025
Merged

Conversation

ariannasole23
Copy link
Contributor

No description provided.

@github-actions github-actions bot added testing Continuous integration testing trainers PyTorch Lightning trainers labels Jan 13, 2025
@adamjstewart adamjstewart changed the title PR Trainers: add Instance Segmentation Task Jan 13, 2025
@adamjstewart adamjstewart added this to the 0.7.0 milestone Jan 13, 2025
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just need to make the code and tests match our existing trainers and run ruff: https://torchgeo.readthedocs.io/en/latest/user/contributing.html#linters

@adamjstewart
Copy link
Collaborator

To solve the import issue, you also need to add 2 lines to torchgeo/trainers/__init__.py, just copy the style used by other classes.

Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you run ruff on the code to make it more uniform? This will make it easier to review.

@github-actions github-actions bot added the dependencies Packaging and dependencies label Feb 21, 2025
InstanceSegmentationTask(backbone='invalid_backbone')

def test_weights(self) -> None:
InstanceSegmentationTask(weights=True, num_classes=91)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yet another test that downloads data on the fly (#1088), but let's fix that another day when we figure out how to support custom weights.

@adamjstewart
Copy link
Collaborator

Hmm, looks like FTW isn't supported because it only has masks, not bboxes.

@adamjstewart adamjstewart added the backwards-incompatible Changes that are not backwards compatible label Feb 21, 2025
@github-actions github-actions bot added the models Models and pretrained weights label Feb 21, 2025
adamjstewart
adamjstewart previously approved these changes Feb 22, 2025
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests are finally passing! @ariannasole23 maybe you can check VHR-10 again to make sure I didn't break anything and it still works. Would like a couple other people to review this too, but otherwise this is ready to merge once you accept the CLA.

@ariannasole23
Copy link
Contributor Author

@microsoft-github-policy-service agree

adamjstewart
adamjstewart previously approved these changes Feb 23, 2025
Copy link
Collaborator

@adamjstewart adamjstewart left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will give others a few days to test before merging.

@ashnair1
Copy link
Collaborator

The current InstanceSegmentationTask implementation isn't really flexible. I tried training a model with it recently and found the following:

  • Training a model with weights=True i.e. with a pretrained backbone crashes as the FPN Weights trained on COCO require number of classes to be 91 whereas the VHR10 has 11 classes (incl background). Currently, the only way to train this model would be to have weights=None. See below for sample training config
  • Using the Resnet50_FPN_Weights means that users can only use the resnet50 backbone. The ObjectDetectionTask allows users to choose their preferred backbone i.e. resnet18, resnet34, resnet50 etc.

Train config used:

trainer:
  min_epochs: 15
  max_epochs: 100

model:
  class_path: InstanceSegmentationTask
  init_args:
    model: 'mask-rcnn'
    backbone: 'resnet50'  # Only resnet50 can be used
    weights: None  # Using pretrained weights (weights:True) will not work in this implementation
    num_classes: 11
    lr: 3.0e-5 

I'd suggest following the implementation of how Faster-RCNN is defined in ObjectDetectionTask while defining Mask-RCNN in InstanceSegmentationTask.

@adamjstewart
Copy link
Collaborator

  • weights=True, num_channels!=91 has been "fixed" and tested in the latest commit. I say "fixed" because I just skip the decoder weights when num_channels!=91. A smarter solution would be to drop the last layer of the weights, but I was lazy.
  • Yep, this trainer currently lacks a lot of options, and only supports a single model and backbone. I agree that the implementation in ObjectDetectionTask is more flexible, but @ariannasole23 is busy with final exams and I'm busy preparing the release. I would suggest merging this as is and you can open a follow-up PR to add more backbones if you want.

@ashnair1 ashnair1 merged commit 464e45d into microsoft:main Feb 25, 2025
22 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backwards-incompatible Changes that are not backwards compatible datasets Geospatial or benchmark datasets dependencies Packaging and dependencies models Models and pretrained weights testing Continuous integration testing trainers PyTorch Lightning trainers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants