Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster RCNN Model + Pascal VOC DataModule #157

Merged
merged 10 commits into from
Aug 22, 2020
20 changes: 16 additions & 4 deletions pl_bolts/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,24 @@
from pl_bolts.datamodules.async_dataloader import AsynchronousLoader
from pl_bolts.datamodules.cifar10_datamodule import CIFAR10DataModule, TinyCIFAR10DataModule
from pl_bolts.datamodules.cifar10_datamodule import (
CIFAR10DataModule,
TinyCIFAR10DataModule,
)
from pl_bolts.datamodules.dummy_dataset import DummyDataset
from pl_bolts.datamodules.experience_source import (ExperienceSourceDataset, ExperienceSource,
DiscountedExperienceSource)
from pl_bolts.datamodules.experience_source import (
ExperienceSourceDataset,
ExperienceSource,
DiscountedExperienceSource,
)
from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule
from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule
from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule
from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule
from pl_bolts.datamodules.sklearn_datamodule import SklearnDataset, SklearnDataModule, TensorDataset, TensorDataModule
from pl_bolts.datamodules.sklearn_datamodule import (
SklearnDataset,
SklearnDataModule,
TensorDataset,
TensorDataModule,
)
from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule
from pl_bolts.datamodules.stl10_datamodule import STL10DataModule
from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule
199 changes: 199 additions & 0 deletions pl_bolts/datamodules/vocdetection_datamodule.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from torchvision.datasets import VOCDetection
import torchvision.transforms as T


class Compose(object):
    """
    Like `torchvision.transforms.Compose` but works for (image, target) pairs.

    Every transform is called as ``transform(image, target)`` and has to hand
    back the (possibly modified) pair, which is then fed to the next transform.
    """

    def __init__(self, transforms):
        # Iterable of callables, applied in iteration order by ``__call__``.
        self.transforms = transforms

    def __call__(self, image, target):
        """Apply all transforms in order and return the final (image, target)."""
        current_image, current_target = image, target
        for step in self.transforms:
            current_image, current_target = step(current_image, current_target)
        return current_image, current_target


def _collate_fn(batch):
return tuple(zip(*batch))


CLASSES = (
"__background__ ",
"aeroplane",
"bicycle",
"bird",
"boat",
"bottle",
"bus",
"car",
"cat",
"chair",
"cow",
"diningtable",
"dog",
"horse",
"motorbike",
"person",
"pottedplant",
"sheep",
"sofa",
"train",
"tvmonitor",
)


def _prepare_voc_instance(image, target):
"""
Prepares VOC dataset into appropriate target for fasterrcnn

https://github.com/pytorch/vision/issues/1097#issuecomment-508917489
"""
anno = target["annotation"]
h, w = anno["size"]["height"], anno["size"]["width"]
boxes = []
classes = []
area = []
iscrowd = []
objects = anno["object"]
if not isinstance(objects, list):
objects = [objects]
for obj in objects:
bbox = obj["bndbox"]
bbox = [int(bbox[n]) - 1 for n in ["xmin", "ymin", "xmax", "ymax"]]
boxes.append(bbox)
classes.append(CLASSES.index(obj["name"]))
iscrowd.append(int(obj["difficult"]))
area.append((bbox[2] - bbox[0]) * (bbox[3] - bbox[1]))

boxes = torch.as_tensor(boxes, dtype=torch.float32)
classes = torch.as_tensor(classes)
area = torch.as_tensor(area)
iscrowd = torch.as_tensor(iscrowd)

image_id = anno["filename"][5:-4]
image_id = torch.as_tensor([int(image_id)])

target = {}
target["boxes"] = boxes
target["labels"] = classes
target["image_id"] = image_id

# for conversion to coco api
target["area"] = area
target["iscrowd"] = iscrowd

return image, target


class VOCDetectionDataModule(LightningDataModule):
    """
    DataModule that serves torchvision's ``VOCDetection`` dataset in the
    target format produced by ``_prepare_voc_instance``.
    """

    name = "vocdetection"

    def __init__(
        self,
        data_dir: str,
        year: str = "2012",
        num_workers: int = 16,
        normalize: bool = False,
        *args,
        **kwargs,
    ):
        """
        Args:
            data_dir: directory the VOC files are downloaded to and read from
            year: dataset year, passed to ``VOCDetection``
            num_workers: number of worker processes for each ``DataLoader``
            normalize: if true, the default transforms normalize images with
                mean ``[0.485, 0.456, 0.406]`` and std ``[0.229, 0.224, 0.225]``
        """

        super().__init__(*args, **kwargs)
        self.year = year
        self.data_dir = data_dir
        self.num_workers = num_workers
        self.normalize = normalize

    @property
    def num_classes(self):
        """
        Return:
            21
        """
        return 21

    def prepare_data(self):
        """
        Saves VOCDetection files to data_dir
        """
        VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
        VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)

    def train_dataloader(self, batch_size=1, transforms=None):
        """
        VOCDetection train set uses the `train` subset

        Args:
            batch_size: size of batch
            transforms: custom transforms
        """
        transforms = transforms or self.train_transforms or self._default_transforms()
        return self._build_loader("train", batch_size, transforms, shuffle=True)

    def val_dataloader(self, batch_size=1, transforms=None):
        """
        VOCDetection val set uses the `val` subset

        Args:
            batch_size: size of batch
            transforms: custom transforms
        """
        transforms = transforms or self.val_transforms or self._default_transforms()
        return self._build_loader("val", batch_size, transforms, shuffle=False)

    def _build_loader(self, image_set, batch_size, transforms, shuffle):
        """Create the ``DataLoader`` for one VOC subset.

        The annotation is always converted by ``_prepare_voc_instance`` first;
        ``transforms`` then runs on the converted ``(image, target)`` pair.
        """
        dataset = VOCDetection(
            self.data_dir,
            year=self.year,
            image_set=image_set,
            transforms=Compose([_prepare_voc_instance, transforms]),
        )
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
            # samples hold variable-length targets, so keep them as tuples
            collate_fn=_collate_fn,
        )

    def _default_transforms(self):
        """Return a callable ``(image, target) -> (tensor_image, target)``.

        The image is converted with ``ToTensor`` and, when ``self.normalize``
        is set, normalized afterwards; the target is passed through untouched.
        """
        steps = [T.ToTensor()]
        if self.normalize:
            steps.append(
                T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
            )
        # built once here instead of on every sample
        image_transform = T.Compose(steps)
        # must be a single callable (not a tuple wrapping one) because
        # ``Compose`` invokes it directly
        return lambda image, target: (image_transform(image), target)
1 change: 1 addition & 0 deletions pl_bolts/models/detection/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from pl_bolts.models.detection.faster_rcnn import FasterRCNN
146 changes: 146 additions & 0 deletions pl_bolts/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import torch
from torch import nn
from torchvision.models.detection import faster_rcnn, fasterrcnn_resnet50_fpn
from torchvision.ops import box_iou

import pytorch_lightning as pl

from pytorch_lightning.metrics import IoU
from argparse import ArgumentParser

from pl_bolts.datamodules import VOCDetectionDataModule


def _evaluate_iou(target, pred):
"""
Evaluate intersection over union (IOU) for target from dataset and output prediction
from model
"""
if pred["boxes"].shape[0] == 0:
# no box detected, 0 IOU
return torch.tensor(0.0, device=pred["boxes"].device)
return box_iou(target["boxes"], pred["boxes"]).diag().mean()


class FasterRCNN(pl.LightningModule):
    def __init__(
        self,
        learning_rate: float = 0.0001,
        num_classes: int = 91,
        pretrained: bool = False,
        pretrained_backbone: bool = True,
        trainable_backbone_layers: int = 3,
        replace_head: bool = True,
        **kwargs,
    ):
        """
        PyTorch Lightning implementation of `Faster R-CNN: Towards Real-Time Object Detection with
        Region Proposal Networks <https://arxiv.org/abs/1506.01497>`_.

        Paper authors: Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun

        Model implemented by:
            - `Teddy Koker <https://github.com/teddykoker>`

        During training, the model expects both the input tensors, as well as targets (list of dictionary), containing:
            - boxes (`FloatTensor[N, 4]`): the ground truth boxes in `[x1, y1, x2, y2]` format.
            - labels (`Int64Tensor[N]`): the class label for each ground truth box

        CLI command::

            # PascalVOC
            python faster_rcnn.py --gpus 1 --pretrained True

        Args:
            learning_rate: the learning rate
            num_classes: number of detection classes (including background)
            pretrained: if true, returns a model pre-trained on COCO train2017
            pretrained_backbone: if true, returns a model with backbone pre-trained on Imagenet
            trainable_backbone_layers: number of trainable resnet layers starting from final block
            replace_head: if true, the box predictor is replaced by a fresh one with
                ``num_classes`` outputs; required whenever ``num_classes`` is not 91
            **kwargs: ignored; lets the whole argparse namespace be passed in

        Raises:
            ValueError: if ``replace_head`` is false and ``num_classes`` is not 91
        """
        super().__init__()

        # Validate before building the (expensive) model. A real exception is
        # used because ``assert`` statements vanish under ``python -O``.
        if not replace_head and num_classes != 91:
            raise ValueError("replace_head must be true to change num_classes")

        model = fasterrcnn_resnet50_fpn(
            pretrained=pretrained,
            pretrained_backbone=pretrained_backbone,
            trainable_backbone_layers=trainable_backbone_layers,
        )

        if replace_head:
            # swap the box predictor for one sized to ``num_classes``
            in_features = model.roi_heads.box_predictor.cls_score.in_features
            head = faster_rcnn.FastRCNNPredictor(in_features, num_classes)
            model.roi_heads.box_predictor = head

        self.model = model
        self.learning_rate = learning_rate

    def forward(self, x):
        """Run inference on ``x``; note this switches the wrapped model to eval mode."""
        self.model.eval()
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute the summed detection loss for one ``(images, targets)`` batch."""
        images, targets = batch
        # shallow-copy each target dict before handing it to the model
        targets = [dict(t) for t in targets]

        # fasterrcnn takes both images and targets for training, returns
        # a dict of losses
        loss_dict = self.model(images, targets)
        loss = sum(loss_dict.values())
        return {"loss": loss, "log": loss_dict}

    def validation_step(self, batch, batch_idx):
        """Return the mean IoU between targets and predictions of one batch."""
        images, targets = batch
        # fasterrcnn takes only images for eval() mode
        outs = self.model(images)
        iou = torch.stack([_evaluate_iou(t, o) for t, o in zip(targets, outs)]).mean()
        return {"val_iou": iou}

    def validation_epoch_end(self, outs):
        """Average the per-batch ``val_iou`` values and log the result."""
        avg_iou = torch.stack([o["val_iou"] for o in outs]).mean()
        logs = {"val_iou": avg_iou}
        return {"avg_val_iou": avg_iou, "log": logs}

    def configure_optimizers(self):
        """SGD over all model parameters with momentum 0.9 and weight decay 0.005."""
        return torch.optim.SGD(
            self.model.parameters(),
            lr=self.learning_rate,
            momentum=0.9,
            weight_decay=0.005,
        )

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Return a parser extending ``parent_parser`` with the model and data options."""

        def parse_bool(value):
            # ``type=bool`` would turn every non-empty string (including
            # "False") into True, so boolean flags are parsed explicitly.
            # argparse reports the ValueError as an invalid argument value.
            text = str(value).strip().lower()
            if text in ("true", "t", "yes", "y", "1"):
                return True
            if text in ("false", "f", "no", "n", "0", ""):
                return False
            raise ValueError(f"expected a boolean value, got {value!r}")

        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument("--learning_rate", type=float, default=0.0001)
        parser.add_argument("--num_classes", type=int, default=91)
        parser.add_argument("--pretrained", type=parse_bool, default=False)
        parser.add_argument("--pretrained_backbone", type=parse_bool, default=True)
        parser.add_argument("--trainable_backbone_layers", type=int, default=3)
        parser.add_argument("--replace_head", type=parse_bool, default=True)

        parser.add_argument("--data_dir", type=str, default=".")
        parser.add_argument("--batch_size", type=int, default=1)
        return parser


def cli_main():
    """
    Command line entry point: train ``FasterRCNN`` on the VOC detection datamodule.

    Trainer, model and data options all come from one argument parser; the
    model's ``num_classes`` is overwritten with the datamodule's value.
    """
    pl.seed_everything(42)

    # trainer flags first, then the model/data flags layered on top
    parser = FasterRCNN.add_model_specific_args(
        pl.Trainer.add_argparse_args(ArgumentParser())
    )
    args = parser.parse_args()

    # NOTE(review): ``--batch_size`` is parsed, but nothing in this function
    # hands it to the dataloaders -- confirm whether that is intended.
    datamodule = VOCDetectionDataModule.from_argparse_args(args)
    args.num_classes = datamodule.num_classes

    model = FasterRCNN(**vars(args))
    trainer = pl.Trainer.from_argparse_args(args)
    trainer.fit(model, datamodule)


if __name__ == "__main__":
    cli_main()
12 changes: 12 additions & 0 deletions tests/models/test_detection.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import torch
from pl_bolts.models.detection import FasterRCNN


def test_fasterrcnn(tmpdir):
    """Smoke test: a forward pass yields one detection dict per input image."""
    # NOTE: we probably want to test training, but the detection datasets are quite large
    # so it could be time consuming on the test server

    # randomly initialised backbone: the test should not need to fetch
    # pretrained weights
    model = FasterRCNN(pretrained_backbone=False)

    image = torch.rand(1, 3, 400, 400)
    outs = model(image)

    assert len(outs) == 1
    assert {"boxes", "labels", "scores"} <= set(outs[0])