|
1 | 1 | import contextlib
|
| 2 | +import itertools |
2 | 3 | import pathlib
|
3 | 4 | import string
|
4 | 5 | import sys
|
@@ -33,87 +34,87 @@ def main(*, input_types, tasks, num_samples):
|
33 | 34 | # https://github.com/pytorch/pytorch/blob/19162083f8831be87be01bb84f186310cad1d348/torch/utils/data/_utils/worker.py#L222
|
34 | 35 | torch.set_num_threads(1)
|
35 | 36 |
|
| 37 | + dataset_rng = torch.Generator() |
| 38 | + dataset_rng.manual_seed(0) |
| 39 | + dataset_rng_state = dataset_rng.get_state() |
| 40 | + |
36 | 41 | for task_name in tasks:
|
37 | 42 | print("#" * 60)
|
38 | 43 | print(task_name)
|
39 | 44 | print("#" * 60)
|
40 | 45 |
|
41 | 46 | medians = {input_type: {} for input_type in input_types}
|
42 |
| - for input_type in input_types: |
43 |
| - dataset_rng = torch.Generator() |
44 |
| - dataset_rng.manual_seed(0) |
45 |
| - dataset_rng_state = dataset_rng.get_state() |
46 |
| - |
47 |
| - for api_version in ["v1", "v2"]: |
48 |
| - dataset_rng.set_state(dataset_rng_state) |
49 |
| - task = make_task( |
50 |
| - task_name, |
51 |
| - input_type=input_type, |
52 |
| - api_version=api_version, |
53 |
| - dataset_rng=dataset_rng, |
54 |
| - num_samples=num_samples, |
55 |
| - ) |
56 |
| - if task is None: |
57 |
| - continue |
58 |
| - |
59 |
| - print(f"{input_type=}, {api_version=}") |
60 |
| - print() |
61 |
| - print(f"Results computed for {num_samples:_} samples") |
62 |
| - print() |
63 |
| - |
64 |
| - pipeline, dataset = task |
65 |
| - |
66 |
| - for sample in dataset: |
67 |
| - pipeline(sample) |
68 |
| - |
69 |
| - results = pipeline.extract_times() |
70 |
| - field_len = max(len(name) for name in results) |
71 |
| - print(f"{' ' * field_len} {'median ':>9} {'std ':>9}") |
72 |
| - medians[input_type][api_version] = 0.0 |
73 |
| - for transform_name, times in results.items(): |
74 |
| - median = float(times.median()) |
75 |
| - print( |
76 |
| - f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs" |
77 |
| - ) |
78 |
| - medians[input_type][api_version] += median |
| 47 | + for input_type, api_version in itertools.product(input_types, ["v1", "v2"]): |
| 48 | + dataset_rng.set_state(dataset_rng_state) |
| 49 | + task = make_task( |
| 50 | + task_name, |
| 51 | + input_type=input_type, |
| 52 | + api_version=api_version, |
| 53 | + dataset_rng=dataset_rng, |
| 54 | + num_samples=num_samples, |
| 55 | + ) |
| 56 | + if task is None: |
| 57 | + continue |
79 | 58 |
|
80 |
| - print( |
81 |
| - f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs" |
82 |
| - ) |
83 |
| - print("-" * 60) |
| 59 | + print(f"{input_type=}, {api_version=}") |
| 60 | + print() |
| 61 | + print(f"Results computed for {num_samples:_} samples") |
| 62 | + print() |
84 | 63 |
|
85 |
| - print() |
86 |
| - print("Summaries") |
87 |
| - print() |
| 64 | + pipeline, dataset = task |
88 | 65 |
|
89 |
| - field_len = max(len(input_type) for input_type in medians) |
90 |
| - print(f"{' ' * field_len} v2 / v1") |
91 |
| - for input_type, api_versions in medians.items(): |
92 |
| - if len(api_versions) < 2: |
93 |
| - continue |
| 66 | + torch.manual_seed(0) |
| 67 | + for sample in dataset: |
| 68 | + pipeline(sample) |
| 69 | + |
| 70 | + results = pipeline.extract_times() |
| 71 | + field_len = max(len(name) for name in results) |
| 72 | + print(f"{' ' * field_len} {'median ':>9} {'std ':>9}") |
| 73 | + medians[input_type][api_version] = 0.0 |
| 74 | + for transform_name, times in results.items(): |
| 75 | + median = float(times.median()) |
| 76 | + print( |
| 77 | + f"{transform_name:{field_len}} {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs" |
| 78 | + ) |
| 79 | + medians[input_type][api_version] += median |
94 | 80 |
|
95 | 81 | print(
|
96 |
| - f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}" |
| 82 | + f"\n{'total':{field_len}} {medians[input_type][api_version] * 1e6:6.0f} µs" |
97 | 83 | )
|
| 84 | + print("-" * 60) |
98 | 85 |
|
99 |
| - print() |
| 86 | + print() |
| 87 | + print("Summaries") |
| 88 | + print() |
100 | 89 |
|
101 |
| - medians_flat = { |
102 |
| - f"{input_type}, {api_version}": median |
103 |
| - for input_type, api_versions in medians.items() |
104 |
| - for api_version, median in api_versions.items() |
105 |
| - } |
106 |
| - field_len = max(len(label) for label in medians_flat) |
| 90 | + field_len = max(len(input_type) for input_type in medians) |
| 91 | + print(f"{' ' * field_len} v2 / v1") |
| 92 | + for input_type, api_versions in medians.items(): |
| 93 | + if len(api_versions) < 2: |
| 94 | + continue |
107 | 95 |
|
108 | 96 | print(
|
109 |
| - f"{' ' * (field_len + 5)} {' '.join(f' [{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}" |
| 97 | + f"{input_type:{field_len}} {api_versions['v2'] / api_versions['v1']:>7.2f}" |
110 | 98 | )
|
111 |
| - for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase): |
112 |
| - print( |
113 |
| - f"{label:>{field_len}}, [{id}] {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}" |
114 |
| - ) |
115 |
| - print() |
116 |
| - print("Slowdown as row / col") |
| 99 | + |
| 100 | + print() |
| 101 | + |
| 102 | + medians_flat = { |
| 103 | + f"{input_type}, {api_version}": median |
| 104 | + for input_type, api_versions in medians.items() |
| 105 | + for api_version, median in api_versions.items() |
| 106 | + } |
| 107 | + field_len = max(len(label) for label in medians_flat) |
| 108 | + |
| 109 | + print( |
| 110 | + f"{' ' * (field_len + 5)} {' '.join(f' [{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}" |
| 111 | + ) |
| 112 | + for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase): |
| 113 | + print( |
| 114 | + f"{label:>{field_len}}, [{id}] {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}" |
| 115 | + ) |
| 116 | + print() |
| 117 | + print("Slowdown as row / col") |
117 | 118 |
|
118 | 119 |
|
119 | 120 | if __name__ == "__main__":
|
|
0 commit comments