@@ -1,6 +1,7 @@
 import contextlib
 import itertools
 import pathlib
+import string
 import sys
 from datetime import datetime

@@ -23,97 +24,111 @@ def write(self, message):
         self.stdout.write(message)
         self.file.write(message)

+    def flush(self):
+        self.stdout.flush()
+        self.file.flush()
+

 def main(*, input_types, tasks, num_samples):
     # This is hardcoded when using a DataLoader with multiple workers:
     # https://github.com/pytorch/pytorch/blob/19162083f8831be87be01bb84f186310cad1d348/torch/utils/data/_utils/worker.py#L222
     torch.set_num_threads(1)

+    dataset_rng = torch.Generator()
+    dataset_rng.manual_seed(0)
+    dataset_rng_state = dataset_rng.get_state()
+
     for task_name in tasks:
         print("#" * 60)
         print(task_name)
         print("#" * 60)

         medians = {input_type: {} for input_type in input_types}
-        for input_type in input_types:
-            dataset_rng = torch.Generator()
-            dataset_rng.manual_seed(0)
-            dataset_rng_state = dataset_rng.get_state()
-
-            for api_version in ["v1", "v2"]:
-                dataset_rng.set_state(dataset_rng_state)
-                task = make_task(
-                    task_name,
-                    input_type=input_type,
-                    api_version=api_version,
-                    dataset_rng=dataset_rng,
-                    num_samples=num_samples,
-                )
-                if task is None:
-                    continue
-
-                print(f"{input_type=}, {api_version=}")
-                print()
-                print(f"Results computed for {num_samples:_} samples")
-                print()
-
-                pipeline, dataset = task
-
-                for sample in dataset:
-                    pipeline(sample)
-
-                results = pipeline.extract_times()
-                field_len = max(len(name) for name in results)
-                print(f"{' ' * field_len}  {'median ':>9}  {'std ':>9}")
-                medians[input_type][api_version] = 0.0
-                for transform_name, times in results.items():
-                    median = float(times.median())
-                    print(
-                        f"{transform_name:{field_len}}  {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
-                    )
-                    medians[input_type][api_version] += median
+        for input_type, api_version in itertools.product(input_types, ["v1", "v2"]):
+            dataset_rng.set_state(dataset_rng_state)
+            task = make_task(
+                task_name,
+                input_type=input_type,
+                api_version=api_version,
+                dataset_rng=dataset_rng,
+                num_samples=num_samples,
+            )
+            if task is None:
+                continue

-                print(
-                    f"\n{'total':{field_len}}  {medians[input_type][api_version] * 1e6:6.0f} µs"
-                )
-                print("-" * 60)
+            print(f"{input_type=}, {api_version=}")
+            print()
+            print(f"Results computed for {num_samples:_} samples")
+            print()

-        print()
-        print("Summaries")
-        print()
+            pipeline, dataset = task

-        field_len = max(len(input_type) for input_type in medians)
-        print(f"{' ' * field_len}  v2 / v1")
-        for input_type, api_versions in medians.items():
-            if len(api_versions) < 2:
-                continue
+            torch.manual_seed(0)
+            for sample in dataset:
+                pipeline(sample)
+
+            results = pipeline.extract_times()
+            field_len = max(len(name) for name in results)
+            print(f"{' ' * field_len}  {'median ':>9}  {'std ':>9}")
+            medians[input_type][api_version] = 0.0
+            for transform_name, times in results.items():
+                median = float(times.median())
+                print(
+                    f"{transform_name:{field_len}}  {median * 1e6:6.0f} µs +- {float(times.std()) * 1e6:6.0f} µs"
+                )
+                medians[input_type][api_version] += median

             print(
-                f"{input_type:{field_len}}  {api_versions['v2'] / api_versions['v1']:>7.2f}"
+                f"\n{'total':{field_len}}  {medians[input_type][api_version] * 1e6:6.0f} µs"
             )
+            print("-" * 60)
+
+        print()
+        print("Summaries")
+        print()

-        print()
+        field_len = max(len(input_type) for input_type in medians)
+        print(f"{' ' * field_len}  v2 / v1")
+        for input_type, api_versions in medians.items():
+            if len(api_versions) < 2:
+                continue

-        median_ref = medians["PIL"]["v1"]
-        medians_flat = {
-            f"{input_type}, {api_version}": median
-            for input_type, api_versions in medians.items()
-            for api_version, median in api_versions.items()
-        }
-        field_len = max(len(label) for label in medians_flat)
-        print(f"{' ' * field_len}  x / PIL, v1")
-        for label, median in medians_flat.items():
-            print(f"{label:{field_len}}  {median / median_ref:>11.2f}")
+            print(
+                f"{input_type:{field_len}}  {api_versions['v2'] / api_versions['v1']:>7.2f}"
+            )
+
+        print()
+
+        medians_flat = {
+            f"{input_type}, {api_version}": median
+            for input_type, api_versions in medians.items()
+            for api_version, median in api_versions.items()
+        }
+        field_len = max(len(label) for label in medians_flat)
+
+        print(
+            f"{' ' * (field_len + 5)}  {' '.join(f'[{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}"
+        )
+        for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase):
+            print(
+                f"{label:>{field_len}}, [{id}]  {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}"
+            )
+        print()
+        print("Slowdown as row / col")


 if __name__ == "__main__":
     tee = Tee(stdout=sys.stdout)

     with contextlib.redirect_stdout(tee):
         main(
-            tasks=["classification-simple", "classification-complex"],
+            tasks=[
+                "classification-simple",
+                "classification-complex",
+                "detection-ssdlite",
+            ],
             input_types=["Tensor", "PIL", "Datapoint"],
-            num_samples=10_000,
+            num_samples=1_000,
         )

     print("#" * 60)
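The new summary block replaces the single `x / PIL, v1` column with a full pairwise matrix: each `(input_type, api_version)` combination gets a letter id, and each cell divides the row's total median by the column's. A standalone sketch with hypothetical medians (the real values come from `pipeline.extract_times()`):

```python
import string

# Hypothetical medians in seconds, standing in for real benchmark output.
medians_flat = {"PIL, v1": 0.012, "PIL, v2": 0.010, "Tensor, v2": 0.008}
field_len = max(len(label) for label in medians_flat)

# Header row: one letter id per column, aligned past the row labels.
print(f"{' ' * (field_len + 5)}  {' '.join(f'[{id}]' for _, id in zip(range(len(medians_flat)), string.ascii_lowercase))}")
for (label, val), id in zip(medians_flat.items(), string.ascii_lowercase):
    # Each cell divides this row's median by every column's median.
    print(f"{label:>{field_len}}, [{id}]  {' '.join(f'{val / ref:4.2f}' for ref in medians_flat.values())}")
print()
print("Slowdown as row / col")
```

With these placeholder numbers the first row prints `PIL, v1, [a]  1.00 1.20 1.50`, read as: row [a] is 1.20x slower than column [b] and 1.50x slower than column [c].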