Skip to content

Commit

Permalink
cleanup
Browse files (browse the repository at this point in the history)
  • Loading branch information
pmeier committed Apr 4, 2023
1 parent a859c09 commit 05350be
Show file tree
Hide file tree
Showing 4 changed files with 417 additions and 72 deletions.
6 changes: 3 additions & 3 deletions datasets.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -33,9 +33,9 @@ def detection_dataset_builder(*, api_version, rng, num_samples):

dataset = _coco_remove_images_without_annotations(dataset)

idcs = torch.randperm(len(dataset), generator=rng)[:num_samples]
print(f"Caching {num_samples} COCO samples")
return [dataset[idx] for idx in tqdm(idcs.tolist())]
idcs = torch.randperm(len(dataset), generator=rng)[:num_samples].tolist()
print(f"Caching {num_samples} ({idcs[:3]} ... {idcs[-3:]}) COCO samples")
return [dataset[idx] for idx in tqdm(idcs)]


# everything below is copy-pasted from
Expand Down
129 changes: 65 additions & 64 deletions main.py
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
import contextlib
import itertools
import pathlib
import string
import sys
Expand Down Expand Up @@ -33,87 +34,87 @@ def main(*, input_types, tasks, num_samples):
# https://github.com/pytorch/pytorch/blob/19162083f8831be87be01bb84f186310cad1d348/torch/utils/data/_utils/worker.py#L222
torch.set_num_threads(1)

dataset_rng = torch.Generator()
dataset_rng.manual_seed(0)
dataset_rng_state = dataset_rng.get_state()

for task_name in tasks:
print("#" * 60)
print(task_name)
print("#" * 60)

medians = {input_type: {} for input_type in input_types}
for input_type in input_types:
dataset_rng = torch.Generator()
dataset_rng.manual_seed(0)
dataset_rng_state = dataset_rng.get_state()

for api_version in ["v1", "v2"]:
dataset_rng.set_state(dataset_rng_state)
task = make_task(
task_name,
input_type=input_type,
api_version=api_version,
dataset_rng=dataset_rng,
num_samples=num_samples,
)
if task is None:
continue

print(f"{input_type=}, {api_version=}")
print()
print(f"Results computed for {num_samples:_} samples")
print()

pipeline, dataset = task

for sample in dataset:
pipeline(sample)

results = pipeline.extract_times()
field_len = max(len(name) for name in results)
print(f"{' ' * field_len} {'median ':>9} {'std ':>9}")
medians[input_type][api_version] = 0.0
for transform_name, times in results.items():
median = float(times.median())
print(
f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
)
medians[input_type][api_version] += median
for input_type, api_version in itertools.product(input_types, ["v1", "v2"]):
dataset_rng.set_state(dataset_rng_state)
task = make_task(
task_name,
input_type=input_type,
api_version=api_version,
dataset_rng=dataset_rng,
num_samples=num_samples,
)
if task is None:
continue

print(
f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs"
)
print("-" * 60)
print(f"{input_type=}, {api_version=}")
print()
print(f"Results computed for {num_samples:_} samples")
print()

print()
print("Summaries")
print()
pipeline, dataset = task

field_len = max(len(input_type) for input_type in medians)
print(f"{' ' * field_len} v2 / v1")
for input_type, api_versions in medians.items():
if len(api_versions) < 2:
continue
torch.manual_seed(0)
for sample in dataset:
pipeline(sample)

results = pipeline.extract_times()
field_len = max(len(name) for name in results)
print(f"{' ' * field_len} {'median ':>9} {'std ':>9}")
medians[input_type][api_version] = 0.0
for transform_name, times in results.items():
median = float(times.median())
print(
f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
)
medians[input_type][api_version] += median

print(
f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}"
f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs"
)
print("-" * 60)

print()
print()
print("Summaries")
print()

medians_flat = {
f"{input_type}, {api_version}": median
for input_type, api_versions in medians.items()
for api_version, median in api_versions.items()
}
field_len = max(len(label) for label in medians_flat)
field_len = max(len(input_type) for input_type in medians)
print(f"{' ' * field_len} v2 / v1")
for input_type, api_versions in medians.items():
if len(api_versions) < 2:
continue

print(
f"{' ' * (field_len + 5)} {' '.join(f' [{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}"
f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}"
)
for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase):
print(
f"{label:>{field_len}}, [{id}] {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}"
)
print()
print("Slowdown as row / col")

print()

medians_flat = {
f"{input_type}, {api_version}": median
for input_type, api_versions in medians.items()
for api_version, median in api_versions.items()
}
field_len = max(len(label) for label in medians_flat)

print(
f"{' ' * (field_len + 5)} {' '.join(f' [{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}"
)
for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase):
print(
f"{label:>{field_len}}, [{id}] {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}"
)
print()
print("Slowdown as row / col")


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 05350be

Please sign in to comment.