Skip to content

Commit

Permalink
Merge pull request #3 from pmeier/detection
Browse files Browse the repository at this point in the history
benchmark ssdlite detection pipeline
  • Loading branch information
pmeier authored Apr 4, 2023
2 parents ef9b660 + 05350be commit 0ae9027
Show file tree
Hide file tree
Showing 6 changed files with 819 additions and 329 deletions.
95 changes: 92 additions & 3 deletions datasets.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,102 @@
import torch
import pathlib

from torch.hub import tqdm

from torchvision import datasets
from torchvision.transforms import functional as F_v1

COCO_ROOT = "~/datasets/coco"

__all__ = ["classification_dataset_builder", "detection_dataset_builder"]

def classification_dataset_builder(*, api_version, rng, num_samples):
    """Build an in-memory classification dataset of random PIL images.

    Args:
        api_version: accepted for interface parity with
            ``detection_dataset_builder``; the generated samples do not
            differ between "v1" and "v2".
        rng: ``torch.Generator`` driving the pixel draw, so repeated builds
            (one per API version) can produce identical samples by resetting
            the generator state.
        num_samples: number of images to generate.

    Returns:
        A list of ``num_samples`` PIL images, each converted from a random
        uint8 tensor of shape (3, 469, 387).
    """
    return [
        F_v1.to_pil_image(
            # average size of images in ImageNet
            torch.randint(0, 256, (3, 469, 387), dtype=torch.uint8, generator=rng),
        )
        for _ in range(num_samples)
    ]


def detection_dataset_builder(*, api_version, rng, num_samples):
    """Load and cache a random subset of COCO train2017 detection samples.

    Args:
        api_version: "v1" selects the reference-style ``CocoDetectionV1``
            wrapper, "v2" the plain ``torchvision`` dataset; anything else
            raises ``ValueError``.
        rng: ``torch.Generator`` used to pick which samples are cached.
        num_samples: how many samples to draw from the filtered dataset.

    Returns:
        A list of ``num_samples`` materialized dataset samples.
    """
    coco_root = pathlib.Path(COCO_ROOT).expanduser().resolve()
    images = str(coco_root / "train2017")
    annotations = str(coco_root / "annotations" / "instances_train2017.json")

    if api_version == "v1":
        dataset = CocoDetectionV1(images, annotations, transforms=None)
    elif api_version == "v2":
        dataset = datasets.CocoDetection(images, annotations)
    else:
        raise ValueError(f"Got {api_version=}")

    # Drop images whose annotations are empty or degenerate, mirroring the
    # torchvision detection reference pipeline.
    dataset = _coco_remove_images_without_annotations(dataset)

    selected = torch.randperm(len(dataset), generator=rng)[:num_samples].tolist()
    print(f"Caching {num_samples} ({selected[:3]} ... {selected[-3:]}) COCO samples")
    # Materialize the samples up front so benchmark timings exclude disk I/O.
    return [dataset[i] for i in tqdm(selected)]


# everything below is copy-pasted from
# https://github.com/pytorch/vision/blob/main/references/detection/coco_utils.py

import torch
import torchvision


class CocoDetectionV1(torchvision.datasets.CocoDetection):
    """``CocoDetection`` variant yielding v1-reference-style targets.

    Each sample's target is a dict with the image id and the raw COCO
    annotation list, optionally run through a user-supplied transform.
    """

    def __init__(self, img_folder, ann_file, transforms):
        super().__init__(img_folder, ann_file)
        self._transforms = transforms

    def __getitem__(self, idx):
        image, annotations = super().__getitem__(idx)
        target = {"image_id": self.ids[idx], "annotations": annotations}
        if self._transforms is not None:
            image, target = self._transforms(image, target)
        return image, target


def _coco_remove_images_without_annotations(dataset, cat_list=None):
    """Filter a COCO detection dataset down to images with usable annotations.

    Copied from torchvision's detection reference (see the comment above);
    kept structurally identical so it stays recognizable against upstream.

    Args:
        dataset: a ``torchvision.datasets.CocoDetection`` instance.
        cat_list: optional list of category ids; when given, only annotations
            of those categories count towards an image being valid.

    Returns:
        a ``torch.utils.data.Subset`` of ``dataset`` containing only the
        valid images.

    Raises:
        TypeError: if ``dataset`` is not a ``CocoDetection``.
    """

    # True when every box is degenerate, i.e. has width or height <= 1
    # ("bbox" is [x, y, w, h], so bbox[2:] is (w, h)).
    def _has_only_empty_bbox(anno):
        return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)

    # Count keypoints whose visibility flag v (every third entry of the flat
    # [x, y, v, ...] keypoint list) is > 0.
    def _count_visible_keypoints(anno):
        return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)

    min_keypoints_per_image = 10

    def _has_valid_annotation(anno):
        # if it's empty, there is no annotation
        if len(anno) == 0:
            return False
        # if all boxes have close to zero area, there is no annotation
        if _has_only_empty_bbox(anno):
            return False
        # keypoints tasks have a slightly different criterion for considering
        # whether an annotation is valid
        if "keypoints" not in anno[0]:
            return True
        # for keypoint detection tasks, only consider valid images those
        # containing at least min_keypoints_per_image
        if _count_visible_keypoints(anno) >= min_keypoints_per_image:
            return True
        return False

    if not isinstance(dataset, torchvision.datasets.CocoDetection):
        raise TypeError(
            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
        )
    ids = []
    for ds_idx, img_id in enumerate(dataset.ids):
        # Fetch this image's raw annotation dicts from the pycocotools index.
        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        anno = dataset.coco.loadAnns(ann_ids)
        if cat_list:
            anno = [obj for obj in anno if obj["category_id"] in cat_list]
        if _has_valid_annotation(anno):
            ids.append(ds_idx)

    dataset = torch.utils.data.Subset(dataset, ids)
    return dataset
141 changes: 78 additions & 63 deletions main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import contextlib
import itertools
import pathlib
import string
import sys
from datetime import datetime

Expand All @@ -23,97 +24,111 @@ def write(self, message):
self.stdout.write(message)
self.file.write(message)

def flush(self):
self.stdout.flush()
self.file.flush()


def main(*, input_types, tasks, num_samples):
# This is hardcoded when using a DataLoader with multiple workers:
# https://github.com/pytorch/pytorch/blob/19162083f8831be87be01bb84f186310cad1d348/torch/utils/data/_utils/worker.py#L222
torch.set_num_threads(1)

dataset_rng = torch.Generator()
dataset_rng.manual_seed(0)
dataset_rng_state = dataset_rng.get_state()

for task_name in tasks:
print("#" * 60)
print(task_name)
print("#" * 60)

medians = {input_type: {} for input_type in input_types}
for input_type in input_types:
dataset_rng = torch.Generator()
dataset_rng.manual_seed(0)
dataset_rng_state = dataset_rng.get_state()

for api_version in ["v1", "v2"]:
dataset_rng.set_state(dataset_rng_state)
task = make_task(
task_name,
input_type=input_type,
api_version=api_version,
dataset_rng=dataset_rng,
num_samples=num_samples,
)
if task is None:
continue

print(f"{input_type=}, {api_version=}")
print()
print(f"Results computed for {num_samples:_} samples")
print()

pipeline, dataset = task

for sample in dataset:
pipeline(sample)

results = pipeline.extract_times()
field_len = max(len(name) for name in results)
print(f"{' ' * field_len} {'median ':>9} {'std ':>9}")
medians[input_type][api_version] = 0.0
for transform_name, times in results.items():
median = float(times.median())
print(
f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
)
medians[input_type][api_version] += median
for input_type, api_version in itertools.product(input_types, ["v1", "v2"]):
dataset_rng.set_state(dataset_rng_state)
task = make_task(
task_name,
input_type=input_type,
api_version=api_version,
dataset_rng=dataset_rng,
num_samples=num_samples,
)
if task is None:
continue

print(
f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs"
)
print("-" * 60)
print(f"{input_type=}, {api_version=}")
print()
print(f"Results computed for {num_samples:_} samples")
print()

print()
print("Summaries")
print()
pipeline, dataset = task

field_len = max(len(input_type) for input_type in medians)
print(f"{' ' * field_len} v2 / v1")
for input_type, api_versions in medians.items():
if len(api_versions) < 2:
continue
torch.manual_seed(0)
for sample in dataset:
pipeline(sample)

results = pipeline.extract_times()
field_len = max(len(name) for name in results)
print(f"{' ' * field_len} {'median ':>9} {'std ':>9}")
medians[input_type][api_version] = 0.0
for transform_name, times in results.items():
median = float(times.median())
print(
f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
)
medians[input_type][api_version] += median

print(
f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}"
f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs"
)
print("-" * 60)

print()
print("Summaries")
print()

print()
field_len = max(len(input_type) for input_type in medians)
print(f"{' ' * field_len} v2 / v1")
for input_type, api_versions in medians.items():
if len(api_versions) < 2:
continue

median_ref = medians["PIL"]["v1"]
medians_flat = {
f"{input_type}, {api_version}": median
for input_type, api_versions in medians.items()
for api_version, median in api_versions.items()
}
field_len = max(len(label) for label in medians_flat)
print(f"{' ' * field_len} x / PIL, v1")
for label, median in medians_flat.items():
print(f"{label:{field_len}} {median / median_ref:>11.2f}")
print(
f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}"
)

print()

medians_flat = {
f"{input_type}, {api_version}": median
for input_type, api_versions in medians.items()
for api_version, median in api_versions.items()
}
field_len = max(len(label) for label in medians_flat)

print(
f"{' ' * (field_len + 5)} {' '.join(f' [{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}"
)
for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase):
print(
f"{label:>{field_len}}, [{id}] {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}"
)
print()
print("Slowdown as row / col")


if __name__ == "__main__":
tee = Tee(stdout=sys.stdout)

with contextlib.redirect_stdout(tee):
main(
tasks=["classification-simple", "classification-complex"],
tasks=[
"classification-simple",
"classification-complex",
"detection-ssdlite",
],
input_types=["Tensor", "PIL", "Datapoint"],
num_samples=10_000,
num_samples=1_000,
)

print("#" * 60)
Expand Down
Loading

0 comments on commit 0ae9027

Please sign in to comment.