Commit 0ae9027

Merge pull request #3 from pmeier/detection

benchmark ssdlite detection pipeline

2 parents ef9b660 + 05350be

File tree

6 files changed: +819, -329 lines

datasets.py

Lines changed: 92 additions & 3 deletions
@@ -1,13 +1,102 @@
-import torch
+import pathlib
+
+from torch.hub import tqdm
 
+from torchvision import datasets
 from torchvision.transforms import functional as F_v1
 
+COCO_ROOT = "~/datasets/coco"
+
+__all__ = ["classification_dataset_builder", "detection_dataset_builder"]
 
-def classification_dataset_builder(*, input_type, api_version, rng, num_samples):
+
+def classification_dataset_builder(*, api_version, rng, num_samples):
     return [
         F_v1.to_pil_image(
             # average size of images in ImageNet
-            torch.randint(0, 256, (3, 469, 387), dtype=torch.uint8, generator=rng)
+            torch.randint(0, 256, (3, 469, 387), dtype=torch.uint8, generator=rng),
         )
         for _ in range(num_samples)
     ]
+
+
+def detection_dataset_builder(*, api_version, rng, num_samples):
+    root = pathlib.Path(COCO_ROOT).expanduser().resolve()
+    image_folder = str(root / "train2017")
+    annotation_file = str(root / "annotations" / "instances_train2017.json")
+    if api_version == "v1":
+        dataset = CocoDetectionV1(image_folder, annotation_file, transforms=None)
+    elif api_version == "v2":
+        dataset = datasets.CocoDetection(image_folder, annotation_file)
+    else:
+        raise ValueError(f"Got {api_version=}")
+
+    dataset = _coco_remove_images_without_annotations(dataset)
+
+    idcs = torch.randperm(len(dataset), generator=rng)[:num_samples].tolist()
+    print(f"Caching {num_samples} ({idcs[:3]} ... {idcs[-3:]}) COCO samples")
+    return [dataset[idx] for idx in tqdm(idcs)]
+
+
+# everything below is copy-pasted from
+# https://github.com/pytorch/vision/blob/main/references/detection/coco_utils.py
+
+import torch
+import torchvision
+
+
+class CocoDetectionV1(torchvision.datasets.CocoDetection):
+    def __init__(self, img_folder, ann_file, transforms):
+        super().__init__(img_folder, ann_file)
+        self._transforms = transforms
+
+    def __getitem__(self, idx):
+        img, target = super().__getitem__(idx)
+        image_id = self.ids[idx]
+        target = dict(image_id=image_id, annotations=target)
+        if self._transforms is not None:
+            img, target = self._transforms(img, target)
+        return img, target
+
+
+def _coco_remove_images_without_annotations(dataset, cat_list=None):
+    def _has_only_empty_bbox(anno):
+        return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
+
+    def _count_visible_keypoints(anno):
+        return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
+
+    min_keypoints_per_image = 10
+
+    def _has_valid_annotation(anno):
+        # if it's empty, there is no annotation
+        if len(anno) == 0:
+            return False
+        # if all boxes have close to zero area, there is no annotation
+        if _has_only_empty_bbox(anno):
+            return False
+        # keypoints task have a slight different criteria for considering
+        # if an annotation is valid
+        if "keypoints" not in anno[0]:
+            return True
+        # for keypoint detection tasks, only consider valid images those
+        # containing at least min_keypoints_per_image
+        if _count_visible_keypoints(anno) >= min_keypoints_per_image:
+            return True
+        return False
+
+    if not isinstance(dataset, torchvision.datasets.CocoDetection):
+        raise TypeError(
+            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
+        )
+    ids = []
+    for ds_idx, img_id in enumerate(dataset.ids):
+        ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
+        anno = dataset.coco.loadAnns(ann_ids)
+        if cat_list:
+            anno = [obj for obj in anno if obj["category_id"] in cat_list]
+        if _has_valid_annotation(anno):
+            ids.append(ds_idx)
+
+    dataset = torch.utils.data.Subset(dataset, ids)
+    return dataset

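As a usage sketch (not part of the commit): assuming COCO train2017 images and annotations are downloaded under COCO_ROOT, the builders above can be driven directly. Seeding a generator once and replaying its saved state is what lets main.py below benchmark v1 and v2 on the identical random subset:

import torch

from datasets import classification_dataset_builder, detection_dataset_builder

# Seed once, remember the state, and replay it so that every API version
# sees the same random subset of COCO indices.
rng = torch.Generator()
rng.manual_seed(0)
state = rng.get_state()

for api_version in ["v1", "v2"]:
    rng.set_state(state)
    samples = detection_dataset_builder(api_version=api_version, rng=rng, num_samples=10)
    print(f"{api_version}: cached {len(samples)} (image, target) pairs")

# The classification builder needs no local data; it synthesizes random
# PIL images of the average ImageNet size.
images = classification_dataset_builder(api_version="v1", rng=rng, num_samples=5)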
main.py

Lines changed: 78 additions & 63 deletions
@@ -1,6 +1,7 @@
 import contextlib
 import itertools
 import pathlib
+import string
 import sys
 from datetime import datetime
 
@@ -23,97 +24,111 @@ def write(self, message):
         self.stdout.write(message)
         self.file.write(message)
 
+    def flush(self):
+        self.stdout.flush()
+        self.file.flush()
+
 
 def main(*, input_types, tasks, num_samples):
     # This is hardcoded when using a DataLoader with multiple workers:
     # https://github.com/pytorch/pytorch/blob/19162083f8831be87be01bb84f186310cad1d348/torch/utils/data/_utils/worker.py#L222
     torch.set_num_threads(1)
 
+    dataset_rng = torch.Generator()
+    dataset_rng.manual_seed(0)
+    dataset_rng_state = dataset_rng.get_state()
+
     for task_name in tasks:
         print("#" * 60)
         print(task_name)
         print("#" * 60)
 
         medians = {input_type: {} for input_type in input_types}
-        for input_type in input_types:
-            dataset_rng = torch.Generator()
-            dataset_rng.manual_seed(0)
-            dataset_rng_state = dataset_rng.get_state()
-
-            for api_version in ["v1", "v2"]:
-                dataset_rng.set_state(dataset_rng_state)
-                task = make_task(
-                    task_name,
-                    input_type=input_type,
-                    api_version=api_version,
-                    dataset_rng=dataset_rng,
-                    num_samples=num_samples,
-                )
-                if task is None:
-                    continue
-
-                print(f"{input_type=}, {api_version=}")
-                print()
-                print(f"Results computed for {num_samples:_} samples")
-                print()
-
-                pipeline, dataset = task
-
-                for sample in dataset:
-                    pipeline(sample)
-
-                results = pipeline.extract_times()
-                field_len = max(len(name) for name in results)
-                print(f"{' ' * field_len} {'median ':>9} {'std ':>9}")
-                medians[input_type][api_version] = 0.0
-                for transform_name, times in results.items():
-                    median = float(times.median())
-                    print(
-                        f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
-                    )
-                    medians[input_type][api_version] += median
+        for input_type, api_version in itertools.product(input_types, ["v1", "v2"]):
+            dataset_rng.set_state(dataset_rng_state)
+            task = make_task(
+                task_name,
+                input_type=input_type,
+                api_version=api_version,
+                dataset_rng=dataset_rng,
+                num_samples=num_samples,
+            )
+            if task is None:
+                continue
 
-            print(
-                f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs"
-            )
-            print("-" * 60)
+            print(f"{input_type=}, {api_version=}")
+            print()
+            print(f"Results computed for {num_samples:_} samples")
+            print()
 
-        print()
-        print("Summaries")
-        print()
+            pipeline, dataset = task
 
-        field_len = max(len(input_type) for input_type in medians)
-        print(f"{' ' * field_len} v2 / v1")
-        for input_type, api_versions in medians.items():
-            if len(api_versions) < 2:
-                continue
+            torch.manual_seed(0)
+            for sample in dataset:
+                pipeline(sample)
+
+            results = pipeline.extract_times()
+            field_len = max(len(name) for name in results)
+            print(f"{' ' * field_len} {'median ':>9} {'std ':>9}")
+            medians[input_type][api_version] = 0.0
+            for transform_name, times in results.items():
+                median = float(times.median())
+                print(
+                    f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
+                )
+                medians[input_type][api_version] += median
 
             print(
-                f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}"
+                f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs"
             )
+            print("-" * 60)
+
+        print()
+        print("Summaries")
+        print()
 
-        print()
+        field_len = max(len(input_type) for input_type in medians)
+        print(f"{' ' * field_len} v2 / v1")
+        for input_type, api_versions in medians.items():
+            if len(api_versions) < 2:
+                continue
 
-        median_ref = medians["PIL"]["v1"]
-        medians_flat = {
-            f"{input_type}, {api_version}": median
-            for input_type, api_versions in medians.items()
-            for api_version, median in api_versions.items()
-        }
-        field_len = max(len(label) for label in medians_flat)
-        print(f"{' ' * field_len} x / PIL, v1")
-        for label, median in medians_flat.items():
-            print(f"{label:{field_len}} {median / median_ref:>11.2f}")
+            print(
+                f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}"
+            )
+
+        print()
+
+        medians_flat = {
+            f"{input_type}, {api_version}": median
+            for input_type, api_versions in medians.items()
+            for api_version, median in api_versions.items()
+        }
+        field_len = max(len(label) for label in medians_flat)
+
+        print(
+            f"{' ' * (field_len + 5)} {' '.join(f' [{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}"
+        )
+        for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase):
+            print(
+                f"{label:>{field_len}}, [{id}] {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}"
+            )
+        print()
+        print("Slowdown as row / col")
 
 
 if __name__ == "__main__":
     tee = Tee(stdout=sys.stdout)
 
     with contextlib.redirect_stdout(tee):
         main(
-            tasks=["classification-simple", "classification-complex"],
+            tasks=[
+                "classification-simple",
+                "classification-complex",
+                "detection-ssdlite",
+            ],
             input_types=["Tensor", "PIL", "Datapoint"],
-            num_samples=10_000,
+            num_samples=1_000,
         )
 
     print("#" * 60)

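The reworked summary in main replaces the single "x / PIL, v1" column with a pairwise matrix: each total median is divided by every other total, with rows and columns labeled [a], [b], ... The following self-contained sketch reproduces that layout with invented medians (placeholder values, not benchmark results):

import string

# Placeholder totals in seconds, invented purely to show the layout.
medians_flat = {"PIL, v1": 1.00e-3, "PIL, v2": 0.90e-3, "Tensor, v2": 1.20e-3}
field_len = max(len(label) for label in medians_flat)

# Header row: one bracketed letter per benchmarked configuration.
print(
    f"{' ' * (field_len + 5)} {' '.join(f' [{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}"
)
# Cell (row, col) is row median / col median, i.e. how many times slower the
# row configuration is than the column configuration.
for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase):
    print(
        f"{label:>{field_len}}, [{id}] {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}"
    )
print()
print("Slowdown as row / col")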