Skip to content

Commit 05350be

Browse files
committed
cleanup
1 parent a859c09 commit 05350be

File tree

4 files changed

+417
-72
lines changed

4 files changed

+417
-72
lines changed

datasets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ def detection_dataset_builder(*, api_version, rng, num_samples):
3333

3434
dataset = _coco_remove_images_without_annotations(dataset)
3535

36-
idcs = torch.randperm(len(dataset), generator=rng)[:num_samples]
37-
print(f"Caching {num_samples} COCO samples")
38-
return [dataset[idx] for idx in tqdm(idcs.tolist())]
36+
idcs = torch.randperm(len(dataset), generator=rng)[:num_samples].tolist()
37+
print(f"Caching {num_samples} ({idcs[:3]} ... {idcs[-3:]}) COCO samples")
38+
return [dataset[idx] for idx in tqdm(idcs)]
3939

4040

4141
# everything below is copy-pasted from

main.py

Lines changed: 65 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import contextlib
2+
import itertools
23
import pathlib
34
import string
45
import sys
@@ -33,87 +34,87 @@ def main(*, input_types, tasks, num_samples):
3334
# https://github.com/pytorch/pytorch/blob/19162083f8831be87be01bb84f186310cad1d348/torch/utils/data/_utils/worker.py#L222
3435
torch.set_num_threads(1)
3536

37+
dataset_rng = torch.Generator()
38+
dataset_rng.manual_seed(0)
39+
dataset_rng_state = dataset_rng.get_state()
40+
3641
for task_name in tasks:
3742
print("#" * 60)
3843
print(task_name)
3944
print("#" * 60)
4045

4146
medians = {input_type: {} for input_type in input_types}
42-
for input_type in input_types:
43-
dataset_rng = torch.Generator()
44-
dataset_rng.manual_seed(0)
45-
dataset_rng_state = dataset_rng.get_state()
46-
47-
for api_version in ["v1", "v2"]:
48-
dataset_rng.set_state(dataset_rng_state)
49-
task = make_task(
50-
task_name,
51-
input_type=input_type,
52-
api_version=api_version,
53-
dataset_rng=dataset_rng,
54-
num_samples=num_samples,
55-
)
56-
if task is None:
57-
continue
58-
59-
print(f"{input_type=}, {api_version=}")
60-
print()
61-
print(f"Results computed for {num_samples:_} samples")
62-
print()
63-
64-
pipeline, dataset = task
65-
66-
for sample in dataset:
67-
pipeline(sample)
68-
69-
results = pipeline.extract_times()
70-
field_len = max(len(name) for name in results)
71-
print(f"{' ' * field_len} {'median ':>9} {'std ':>9}")
72-
medians[input_type][api_version] = 0.0
73-
for transform_name, times in results.items():
74-
median = float(times.median())
75-
print(
76-
f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
77-
)
78-
medians[input_type][api_version] += median
47+
for input_type, api_version in itertools.product(input_types, ["v1", "v2"]):
48+
dataset_rng.set_state(dataset_rng_state)
49+
task = make_task(
50+
task_name,
51+
input_type=input_type,
52+
api_version=api_version,
53+
dataset_rng=dataset_rng,
54+
num_samples=num_samples,
55+
)
56+
if task is None:
57+
continue
7958

80-
print(
81-
f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs"
82-
)
83-
print("-" * 60)
59+
print(f"{input_type=}, {api_version=}")
60+
print()
61+
print(f"Results computed for {num_samples:_} samples")
62+
print()
8463

85-
print()
86-
print("Summaries")
87-
print()
64+
pipeline, dataset = task
8865

89-
field_len = max(len(input_type) for input_type in medians)
90-
print(f"{' ' * field_len} v2 / v1")
91-
for input_type, api_versions in medians.items():
92-
if len(api_versions) < 2:
93-
continue
66+
torch.manual_seed(0)
67+
for sample in dataset:
68+
pipeline(sample)
69+
70+
results = pipeline.extract_times()
71+
field_len = max(len(name) for name in results)
72+
print(f"{' ' * field_len} {'median ':>9} {'std ':>9}")
73+
medians[input_type][api_version] = 0.0
74+
for transform_name, times in results.items():
75+
median = float(times.median())
76+
print(
77+
f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
78+
)
79+
medians[input_type][api_version] += median
9480

9581
print(
96-
f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}"
82+
f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs"
9783
)
84+
print("-" * 60)
9885

99-
print()
86+
print()
87+
print("Summaries")
88+
print()
10089

101-
medians_flat = {
102-
f"{input_type}, {api_version}": median
103-
for input_type, api_versions in medians.items()
104-
for api_version, median in api_versions.items()
105-
}
106-
field_len = max(len(label) for label in medians_flat)
90+
field_len = max(len(input_type) for input_type in medians)
91+
print(f"{' ' * field_len} v2 / v1")
92+
for input_type, api_versions in medians.items():
93+
if len(api_versions) < 2:
94+
continue
10795

10896
print(
109-
f"{' ' * (field_len + 5)} {' '.join(f' [{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}"
97+
f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}"
11098
)
111-
for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase):
112-
print(
113-
f"{label:>{field_len}}, [{id}] {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}"
114-
)
115-
print()
116-
print("Slowdown as row / col")
99+
100+
print()
101+
102+
medians_flat = {
103+
f"{input_type}, {api_version}": median
104+
for input_type, api_versions in medians.items()
105+
for api_version, median in api_versions.items()
106+
}
107+
field_len = max(len(label) for label in medians_flat)
108+
109+
print(
110+
f"{' ' * (field_len + 5)} {' '.join(f' [{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}"
111+
)
112+
for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase):
113+
print(
114+
f"{label:>{field_len}}, [{id}] {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}"
115+
)
116+
print()
117+
print("Slowdown as row / col")
117118

118119

119120
if __name__ == "__main__":

0 commit comments

Comments (0)