Custom semantic segmentation tutorial #2588
base: main
Conversation
A much-needed tutorial, thanks!
"cell_type": "markdown",
"metadata": {},
"source": [
"flake8: noqa: E501\n",
Can remove the flake8 directive.
{
"cells": [
{
"cell_type": "markdown",
Let's make this a code cell; otherwise tools like ruff (CPY rule) can't check this file.
"\n",
"_Written by: Caleb Robinson_\n",
"\n",
"In this tutorial, we demonstrate how to extend a TorchGeo [\"trainer class\"](https://torchgeo.readthedocs.io/en/latest/api/trainers.html). In TorchGeo there exist several trainer classes that are pre-made PyTorch Lightning Modules designed to allow for the easy training of models on semantic segmentation, classification, change detection, etc. tasks using TorchGeo's [prebuilt DataModules](https://torchgeo.readthedocs.io/en/latest/api/datamodules.html). While the trainers aim to provide sensible defaults and customization options for common tasks, they will not be able to cover all situations (e.g. researchers will likely want to implement and use their own architectures, loss functions, optimizers, etc. in the training routine). If you run into such a situation, then you can simply extend the trainer class you are interested in, and write custom logic to override the default functionality.\n",
Suggested change:
- "In this tutorial, we demonstrate how to extend a TorchGeo [\"trainer class\"](https://torchgeo.readthedocs.io/en/latest/api/trainers.html). In TorchGeo there exist several trainer classes that are pre-made PyTorch Lightning Modules designed to allow for the easy training of models on semantic segmentation, classification, change detection, etc. tasks using TorchGeo's [prebuilt DataModules](https://torchgeo.readthedocs.io/en/latest/api/datamodules.html). While the trainers aim to provide sensible defaults and customization options for common tasks, they will not be able to cover all situations (e.g. researchers will likely want to implement and use their own architectures, loss functions, optimizers, etc. in the training routine). If you run into such a situation, then you can simply extend the trainer class you are interested in, and write custom logic to override the default functionality.\n",
+ "In this tutorial, we demonstrate how to extend a TorchGeo [\"trainer class\"](https://torchgeo.readthedocs.io/en/latest/api/trainers.html). In TorchGeo there exist several trainer classes that are pre-made PyTorch Lightning Modules designed to allow for the easy training of models on semantic segmentation, classification, change detection, etc. tasks using TorchGeo's [prebuilt DataModules](https://torchgeo.readthedocs.io/en/latest/api/datamodules.html). While the trainers aim to provide sensible defaults and customization options for common tasks, they will not be able to cover all situations (e.g., researchers will likely want to implement and use their own architectures, loss functions, optimizers, etc. in the training routine). If you run into such a situation, then you can simply extend the trainer class you are interested in, and write custom logic to override the default functionality.\n",
Let's add custom metrics to this list; that's also a common use case.
"- `configure_callbacks`: We demonstrate how to stack `ModelCheckpoint` callbacks to save the best checkpoint as well as periodic checkpoints\n",
"- `on_train_epoch_start`: We log the learning rate at the start of each epoch so we can easily see how it decays over a training run\n",
"\n",
"Overall these demonstrate how to customize the training routine to investigate specific research questions (e.g. of the scheduler on test performance)."
Suggested change:
- "Overall these demonstrate how to customize the training routine to investigate specific research questions (e.g. of the scheduler on test performance)."
+ "Overall these demonstrate how to customize the training routine to investigate specific research questions (e.g., of the scheduler on test performance)."
"metadata": {},
"outputs": [],
"source": [
"# Get rid of the pesky warnings raised by kornia\n",
I would like to not suppress warnings. Can we instead fix them?
"\n",
"# You can use the following for actual training runs\n",
"# from torchgeo.datamodules import LandCoverAIDataModule\n",
"# dm = LandCoverAIDataModule(root='data', batch_size=64, num_workers=8, download=True)"
Not sure if this comment is necessary, but I'll let you decide.
"name": "stderr",
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/resnet50-19c8e357.pth\" to /home/davrob/.cache/torch/hub/checkpoints/resnet50-19c8e357.pth\n",
I usually strip the outputs to keep the diffs small. But I know you disagree with this approach.
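For context, stripping outputs is mechanical because a notebook is plain JSON; here is a minimal stdlib sketch of what tools like nbstripout automate (the single-cell notebook dict below is made up for illustration):

```python
import json

def strip_outputs(nb: dict) -> dict:
    """Clear code-cell outputs and execution counts in a notebook dict."""
    for cell in nb.get('cells', []):
        if cell.get('cell_type') == 'code':
            cell['outputs'] = []
            cell['execution_count'] = None
    return nb

# Made-up single-cell notebook, just to exercise the function
nb = {
    'cells': [
        {
            'cell_type': 'code',
            'source': ["print('hi')"],
            'outputs': [{'output_type': 'stream', 'text': '...'}],
            'execution_count': 3,
        }
    ]
}
print(json.dumps(strip_outputs(nb), indent=1))
```

In practice one would run nbstripout (or an equivalent pre-commit hook) rather than hand-rolling this, but the effect on the diff is the same: outputs and execution counts disappear from version control.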
" limit_val_batches=1,\n",
" num_sanity_val_steps=0,\n",
" max_epochs=1,\n",
" accelerator='gpu' if torch.cuda.is_available() else 'cpu',\n",
We can remove this line; it doesn't cover the other dozen accelerators, and the `auto` default does.
"source": [
"# The following Trainer config is useful just for testing the code in this notebook.\n",
"trainer = pl.Trainer(\n",
" limit_train_batches=1,\n",
All of this can be replaced by `fast_dev_run=True`. See the existing Trainers tutorial for how to mock this so that you can do 10+ epochs in the tutorial but 1 step in the tests.
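One way to mock this is an environment toggle: docs builds train normally while the test suite runs a single step. `FAST_DEV_RUN` is a hypothetical variable name chosen for illustration, and the `Trainer` line is left commented since it needs Lightning installed:

```python
import os

# Hypothetical flag: the test harness would set FAST_DEV_RUN=1,
# while the docs build leaves it unset and trains for real.
fast_dev_run = os.environ.get('FAST_DEV_RUN', '0') == '1'

# With Lightning available, the whole limit_*_batches block collapses to:
# trainer = pl.Trainer(fast_dev_run=fast_dev_run, max_epochs=10)
print(fast_dev_run)
```

`fast_dev_run=True` runs one training, validation, and test batch and then exits, which is exactly what a notebook smoke test needs.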
"name": "stderr",
"output_type": "stream",
"text": [
"Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.\n",
This is one of the reasons I like to strip the output
Re-upped version of #1897
For review convenience, here's a link to the notebook.