Finalize trainer code, simpler
adamjstewart committed Feb 21, 2025
1 parent 7c4e30c commit 7c34d4a
Showing 3 changed files with 67 additions and 104 deletions.
12 changes: 6 additions & 6 deletions tests/trainers/test_instance_segmentation.py
@@ -17,7 +17,7 @@
 from torchgeo.datamodules import MisconfigurationException, VHR10DataModule
 from torchgeo.datasets import RGBBandsMissingError
 from torchgeo.main import main
-from torchgeo.models import ResNet18_Weights
+from torchgeo.models import ResNet50_Weights
 from torchgeo.trainers import InstanceSegmentationTask


@@ -76,7 +76,7 @@ def test_trainer(

     @pytest.fixture
     def weights(self) -> WeightsEnum:
-        return ResNet18_Weights.SENTINEL2_ALL_MOCO
+        return ResNet50_Weights.SENTINEL2_ALL_MOCO

     @pytest.fixture
     def mocked_weights(
@@ -98,7 +98,7 @@ def mocked_weights(
         return weights

     def test_weight_file(self, checkpoint: str) -> None:
-        InstanceSegmentationTask(backbone='resnet18', weights=checkpoint, num_classes=6)
+        InstanceSegmentationTask(backbone='resnet50', weights=checkpoint, num_classes=6)

     def test_weight_enum(self, mocked_weights: WeightsEnum) -> None:
         InstanceSegmentationTask(
@@ -140,7 +140,7 @@ def test_no_plot_method(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> N
             root='tests/data/vhr10', batch_size=1, num_workers=0
         )
         model = InstanceSegmentationTask(
-            backbone='resnet18', in_channels=15, num_classes=6
+            backbone='resnet50', in_channels=15, num_classes=6
         )
         trainer = Trainer(
             accelerator='cpu',
@@ -156,7 +156,7 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
             root='tests/data/vhr10', batch_size=1, num_workers=0
         )
         model = InstanceSegmentationTask(
-            backbone='resnet18', in_channels=15, num_classes=6
+            backbone='resnet50', in_channels=15, num_classes=6
         )
         trainer = Trainer(
             accelerator='cpu',
@@ -167,7 +167,7 @@ def test_no_rgb(self, monkeypatch: MonkeyPatch, fast_dev_run: bool) -> None:
         trainer.validate(model=model, datamodule=datamodule)

     def test_freeze_backbone(self) -> None:
-        task = InstanceSegmentationTask(backbone='resnet18', freeze_backbone=True)
+        task = InstanceSegmentationTask(backbone='resnet50', freeze_backbone=True)
         for param in task.model.backbone.parameters():
             assert param.requires_grad is False

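For context, a minimal sketch (not part of this commit) of pairing the updated resnet50 backbone with the pretrained weights referenced in these tests; the in_channels value is an assumption and must match the band count the chosen weights expect:

from torchgeo.models import ResNet50_Weights
from torchgeo.trainers import InstanceSegmentationTask

# Sketch only: resnet50 backbone with Sentinel-2 all-band MoCo weights.
# in_channels=13 is an assumption (13 Sentinel-2 bands); adjust to your data.
weights = ResNet50_Weights.SENTINEL2_ALL_MOCO
task = InstanceSegmentationTask(
    backbone='resnet50',
    weights=weights,
    in_channels=13,
    num_classes=6,
)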
20 changes: 10 additions & 10 deletions torchgeo/datasets/vhr10.py
@@ -250,7 +250,7 @@ def __getitem__(self, index: int) -> dict[str, Any]:
            sample = self.coco_convert(sample)
            sample['class'] = sample['label']['labels']
            sample['bbox_xyxy'] = sample['label']['boxes']
-            sample['mask'] = sample['label']['masks'].float()
+            sample['mask'] = sample['label']['masks']
            sample['label'] = sample.pop('class')

        if self.transforms is not None:
@@ -409,21 +409,21 @@ def plot(
        n_gt = len(boxes)

        ncols = 1
-        show_predictions = 'prediction_labels' in sample
+        show_predictions = 'prediction_label' in sample

        if show_predictions:
            show_pred_boxes = False
            show_pred_masks = False
-            prediction_labels = sample['prediction_labels'].numpy()
-            prediction_scores = sample['prediction_scores'].numpy()
+            prediction_label = sample['prediction_label'].numpy()
+            prediction_score = sample['prediction_score'].numpy()
            if 'prediction_bbox_xyxy' in sample:
                prediction_bbox_xyxy = sample['prediction_bbox_xyxy'].numpy()
                show_pred_boxes = True
-            if 'prediction_masks' in sample:
-                prediction_masks = sample['prediction_masks'].numpy()
+            if 'prediction_mask' in sample:
+                prediction_mask = sample['prediction_mask'].numpy()
                show_pred_masks = True

-            n_pred = len(prediction_labels)
+            n_pred = len(prediction_label)
            ncols += 1

        # Display image
@@ -476,11 +476,11 @@
            axs[0, 1].imshow(image)
            axs[0, 1].axis('off')
            for i in range(n_pred):
-                score = prediction_scores[i]
+                score = prediction_score[i]
                if score < 0.5:
                    continue

-                class_num = prediction_labels[i]
+                class_num = prediction_label[i]
                color = cm(class_num / len(self.categories))

                if show_pred_boxes:
@@ -512,7 +512,7 @@

                # Add masks
                if show_pred_masks:
-                    mask = prediction_masks[i]
+                    mask = prediction_mask[i]
                    contours = skimage.measure.find_contours(mask, 0.5)
                    for verts in contours:
                        verts = np.fliplr(verts)
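
For reference, a minimal sketch of the sample dict that plot() consumes after this renaming — ground-truth keys plus the singular prediction_* keys. The tensor shapes and the dataset call are illustrative assumptions, not taken from this commit:

import torch

# Two ground-truth and two predicted instances on a 3-band 256x256 image.
sample = {
    'image': torch.rand(3, 256, 256),
    'label': torch.tensor([1, 2]),
    'bbox_xyxy': torch.tensor([[10.0, 10.0, 80.0, 80.0], [90.0, 40.0, 150.0, 120.0]]),
    'mask': torch.zeros(2, 256, 256, dtype=torch.uint8),
    'prediction_label': torch.tensor([1, 2]),
    'prediction_score': torch.tensor([0.9, 0.4]),  # plot() skips scores below 0.5
    'prediction_bbox_xyxy': torch.tensor([[12.0, 11.0, 78.0, 82.0], [88.0, 42.0, 149.0, 118.0]]),
    'prediction_mask': torch.zeros(2, 256, 256, dtype=torch.uint8),
}
# fig = VHR10(root='data/vhr10').plot(sample)  # hypothetical dataset instance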
139 changes: 51 additions & 88 deletions torchgeo/trainers/instance_segmentation.py
@@ -27,6 +27,10 @@ class InstanceSegmentationTask(BaseTask):
    .. versionadded:: 0.7
    """

+    ignore = None
+    monitor = 'val_map'
+    mode = 'max'
+
    def __init__(
        self,
        model: str = 'mask_rcnn',
@@ -108,14 +112,13 @@ def configure_metrics(self) -> None:
        - Uses Mean Average Precision (mAP) for masks (IOU-based metric).
        """
        metrics = MetricCollection([MeanAveragePrecision(iou_type='segm')])
-        self.train_metrics = metrics.clone(prefix='train_')
        self.val_metrics = metrics.clone(prefix='val_')
        self.test_metrics = metrics.clone(prefix='test_')

    def training_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
-        """Compute the training loss and additional metrics.
+        """Compute the training loss.
        Args:
            batch: The output of your DataLoader.
@@ -125,60 +128,40 @@
        Returns:
            The loss tensor.
        """
-        images = batch['image'].unbind()
+        images = batch['image']
        targets = {
            'boxes': batch['bbox_xyxy'],
            'labels': batch['label'],
            'masks': batch['mask'],
        }
-        loss_dict = self(images, unbind_samples(targets))
-        loss = sum(loss for loss in loss_dict.values())
-        self.log('train_loss', loss, batch_size=len(images))
+        losses = self(images.unbind(), unbind_samples(targets))
+        self.log_dict(losses, batch_size=len(images))
+        loss: Tensor = sum(losses.values())
        return loss

    def validation_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
-        """Compute the validation loss and additional metrics.
+        """Compute the validation metrics.
        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
-        images, targets = batch['image'], batch['target']
-        batch_size = images.shape[0]
-
-        outputs = self.model(images)
-        loss_dict_list = self.model(images, targets)  # list of dictionaries
-        total_loss = sum(
-            sum(loss_item for loss_item in loss_dict.values() if loss_item.ndim == 0)
-            for loss_dict in loss_dict_list
-        )
-
-        for target in targets:
-            target['masks'] = (target['masks'] > 0).to(torch.uint8)
-            target['boxes'] = target['boxes'].to(torch.float32)
-            target['labels'] = target['labels'].to(torch.int64)
-
-        for output in outputs:
-            if 'masks' in output:
-                output['masks'] = (output['masks'] > 0.5).squeeze(1).to(torch.uint8)
-
-        self.log('val_loss', total_loss, batch_size=batch_size)
-
-        metrics = self.val_metrics(outputs, targets)
-        # Log only scalar values from metrics
-        scalar_metrics = {}
-        for key, value in metrics.items():
-            if isinstance(value, torch.Tensor) and value.numel() > 1:
-                # Cast to float if integer and compute mean
-                value = value.to(torch.float32).mean()
-            scalar_metrics[key] = value
-
-        self.log_dict(scalar_metrics, batch_size=batch_size)
-
-        # check
+        images = batch['image']
+        targets = {'masks': batch['mask'], 'labels': batch['label']}
+        predictions = self(images.unbind())
+        for pred in predictions:
+            pred['masks'] = (pred['masks'] > 0.5).squeeze(1).to(torch.uint8)
+
+        metrics = self.val_metrics(predictions, unbind_samples(targets))
+
+        # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714
+        metrics.pop('val_classes', None)
+
+        self.log_dict(metrics, batch_size=len(images))

        if (
            batch_idx < 10
            and hasattr(self.trainer, 'datamodule')
@@ -189,7 +172,10 @@
        ):
            datamodule = self.trainer.datamodule

-            batch['prediction_masks'] = [output['masks'].cpu() for output in outputs]
+            batch['prediction_bbox_xyxy'] = [
+                pred['boxes'].cpu() for pred in predictions
+            ]
+            batch['prediction_mask'] = [pred['masks'].cpu() for pred in predictions]
            batch['image'] = batch['image'].cpu()

            sample = unbind_samples(batch)[0]
@@ -208,70 +194,47 @@
                plt.close()

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
-        """Compute the test loss and additional metrics.
+        """Compute the test metrics.
        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        """
-        images, targets = batch['image'], batch['target']
-        batch_size = images.shape[0]
-
-        outputs = self.model(images)
-        loss_dict_list = self.model(
-            images, targets
-        )  # Compute all losses, list of dictonaries (one for every batch element)
-        total_loss = sum(
-            sum(loss_item for loss_item in loss_dict.values() if loss_item.ndim == 0)
-            for loss_dict in loss_dict_list
-        )
-
-        for target in targets:
-            target['masks'] = target['masks'].to(torch.uint8)
-            target['boxes'] = target['boxes'].to(torch.float32)
-            target['labels'] = target['labels'].to(torch.int64)
-
-        for output in outputs:
-            if 'masks' in output:
-                output['masks'] = (output['masks'] > 0.5).squeeze(1).to(torch.uint8)
-
-        self.log('test_loss', total_loss, batch_size=batch_size)
-
-        metrics = self.val_metrics(outputs, targets)
-        # Log only scalar values from metrics
-        scalar_metrics = {}
-        for key, value in metrics.items():
-            if isinstance(value, torch.Tensor) and value.numel() > 1:
-                # Cast to float if integer and compute mean
-                value = value.to(torch.float32).mean()
-            scalar_metrics[key] = value
-
-        self.log_dict(scalar_metrics, batch_size=batch_size)
+        images = batch['image']
+        targets = {'masks': batch['mask'], 'labels': batch['label']}
+        predictions = self(images.unbind())
+        for prediction in predictions:
+            prediction['masks'] = (prediction['masks'] > 0.5).squeeze(1).to(torch.uint8)
+
+        metrics = self.test_metrics(predictions, unbind_samples(targets))
+
+        # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714
+        metrics.pop('test_classes', None)
+
+        self.log_dict(metrics, batch_size=len(images))

    def predict_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
-    ) -> Tensor:
-        """Compute the predicted class probabilities.
+    ) -> list[dict[str, Tensor]]:
+        """Compute the predicted masks.
        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.
        Returns:
-            Output predicted probabilities.
+            Output predicted masks.
        """
        images = batch['image']
+        predictions: list[dict[str, Tensor]] = self(images.unbind())

-        with torch.no_grad():
-            outputs = self.model(images)
-
-        for output in outputs:
-            keep = output['scores'] > 0.05
-            output['boxes'] = output['boxes'][keep]
-            output['labels'] = output['labels'][keep]
-            output['scores'] = output['scores'][keep]
-            output['masks'] = (output['masks'] > 0.5).squeeze(1).to(torch.uint8)[keep]
+        for pred in predictions:
+            keep = pred['scores'] > 0.05
+            pred['boxes'] = pred['boxes'][keep]
+            pred['labels'] = pred['labels'][keep]
+            pred['scores'] = pred['scores'][keep]
+            pred['masks'] = (pred['masks'] > 0.5).squeeze(1).to(torch.uint8)[keep]

-        return outputs
+        return predictions
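
The simplified steps above lean on the usual torchvision detection convention: called with images and targets (training_step), a Mask R-CNN-style model returns a dict of losses, while called with images alone (validation, test, and predict steps) it returns one prediction dict per image with boxes, labels, scores, and masks, which is what MeanAveragePrecision consumes. A minimal sketch of driving the task end to end, reusing the fixtures from the tests above (paths and hyperparameters are illustrative, not a recommended configuration):

from lightning.pytorch import Trainer

from torchgeo.datamodules import VHR10DataModule
from torchgeo.trainers import InstanceSegmentationTask

# Sketch only: mirrors the test fixtures and runs a single fast_dev_run batch on CPU.
datamodule = VHR10DataModule(root='tests/data/vhr10', batch_size=1, num_workers=0)
task = InstanceSegmentationTask(backbone='resnet50', num_classes=6)
trainer = Trainer(accelerator='cpu', fast_dev_run=True)
trainer.fit(task, datamodule=datamodule)       # training_step logs each loss term and returns their sum
trainer.validate(task, datamodule=datamodule)  # validation_step logs val_map from MeanAveragePrecision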
