from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
from tqdm import tqdm

-# Check if transformer_engine is installed
-transformer_engine_installed = False
-try:
-    import transformer_engine.pytorch as te
-    from transformer_engine.common import recipe
-
-    transformer_engine_installed = True
-except ImportError:
-    print("transformer_engine not installed and we won't compare against this")
-
# estimating TOPs for matmuls in fp32, fp16, fp8
# assuming A * B = C, with A being M * K, B being K * N, C being M * N
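As a rough, self-contained illustration of the FLOP accounting described in the two comments above (the shape and timing values below are made up, not taken from the script):

# one (M x K) @ (K x N) matmul performs about 2 * M * K * N floating-point ops
# (one multiply plus one add per accumulated output element)
M, K, N = 8192, 8192, 8192
matmul_flops = 2 * M * K * N
elapsed_sec = 0.01                                # hypothetical measured runtime
achieved_ops_per_sec = matmul_flops / elapsed_sec
print(f"{achieved_ops_per_sec / 1e12:.1f} TFLOP/s")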
@@ -66,7 +56,6 @@ class Experiment:
    dtype: torch.dtype
    compiled: bool = False
    float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
-    te_time_sec: Optional[float] = None

    # 3 Times since we are calculating forward backward
    @property
@@ -87,21 +76,6 @@ def float8_tops_sec(self):
    def float8_pct_top_peak(self):
        return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype]

-    @property
-    def te_tops_sec(self):
-        M, K, N = self.shape
-        if self.te_time_sec is not None:
-            return float(3 * (2 * M * K * N)) / self.te_time_sec
-        else:
-            return None
-
-    @property
-    def te_pct_top_peak(self):
-        if self.te_tops_sec is not None:
-            return self.te_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
-        else:
-            return None
-

def main(
    sweep_path: Path,
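The removed te_tops_sec / te_pct_top_peak mirror the accounting the remaining properties use; a minimal standalone sketch of that math (peak_ops_sec stands in for the script's dtype_to_peak_tops lookup):

def tops_sec(M, K, N, time_sec):
    # forward + backward of a linear layer is counted as roughly 3 matmuls,
    # hence the factor of 3 on top of the 2 * M * K * N per-matmul cost
    return float(3 * (2 * M * K * N)) / time_sec

def pct_top_peak(M, K, N, time_sec, peak_ops_sec):
    # fraction of the hardware's advertised peak throughput that was achieved
    return tops_sec(M, K, N, time_sec) / peak_ops_sec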
@@ -113,7 +87,6 @@ def main(

    # LLaMa 2 70B single-node weight shapes
    # assumes fused attn.wqkv and ffn.w13
-    # source: https://fburl.com/gsheet/g8onr7rh
    name_to_shapes_70b = {
        "attn.wqkv": (8192, 1280),
        "attn.w0": (1024, 8192),
@@ -145,19 +118,6 @@ def float8_forw_backward():
        sync_float8_amax_and_scale_history(linear_float8)
        linear_float8(input_tensor).sum().backward()

-    if transformer_engine_installed:
-        # Use the same recipe as float8_linear.DelayedScalingRecipe
-        fp8_format = recipe.Format.HYBRID
-        fp8_recipe = recipe.DelayedScaling(
-            fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
-        )
-        te_linear = te.Linear(K, N, bias=input_bias).to(device=device, dtype=dtype)
-
-        def te_forw_backward():
-            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-                y = te_linear(input_tensor)
-            y.sum().backward()
-
    def n_times(n, fn, *args, **kwargs):
        def wrapper(*args, **kwargs):
            for _ in range(n):
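The hunk cuts n_times off mid-body; the complete helper presumably just repeats the wrapped call, along these lines:

def n_times(n, fn, *args, **kwargs):
    def wrapper(*args, **kwargs):
        # call fn n times per invocation so short kernels are measurable
        for _ in range(n):
            fn(*args, **kwargs)
    return wrapper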
@@ -169,21 +129,14 @@ def wrapper(*args, **kwargs):

    ref_forw_backward = n_times(REPEAT_N, ref_forw_backward)
    float8_forw_backward = n_times(REPEAT_N, float8_forw_backward)
-    if transformer_engine_installed:
-        te_forw_backward = n_times(REPEAT_N, te_forw_backward)

    if compile:
        ref_forw_backward = torch.compile(ref_forw_backward)
        float8_forw_backward = torch.compile(float8_forw_backward)
-        # Compiling TE_linear fails but they are already compiling under the hood
-        # if transformer_engine_installed:
-        #     te_forw_backward = torch.compile(te_forw_backward)

    for _ in range(5):
        ref_forw_backward()
        float8_forw_backward()
-        if transformer_engine_installed:
-            te_forw_backward()

    ref_time = (
        benchmark_torch_function_in_microseconds(ref_forw_backward)
@@ -195,27 +148,16 @@ def wrapper(*args, **kwargs):
        * 1e-6
        / REPEAT_N
    )
-    if transformer_engine_installed:
-        te_time_sec = (
-            benchmark_torch_function_in_microseconds(te_forw_backward)
-            * 1e-6
-            / REPEAT_N
-        )
-    else:
-        te_time_sec = None
    experiment = Experiment(
        name,
        (M, K, N),
        ref_time,
        float8_time,
        dtype,
        compile,
-        te_time_sec=te_time_sec,
    )
    print(experiment)
    print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
-    if transformer_engine_installed:
-        print("te speedup", experiment.ref_time_sec / experiment.te_time_sec)
    experiment_list.append(experiment)
    torch._dynamo.reset()
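Because each benchmarked callable runs REPEAT_N iterations internally, the measured time is scaled back to a single forward+backward before the speedup is computed; a minimal sketch with placeholder numbers:

REPEAT_N = 100                               # placeholder repeat count
ref_us, fp8_us = 12_500.0, 9_800.0           # hypothetical benchmark results (microseconds)
ref_time_sec = ref_us * 1e-6 / REPEAT_N      # seconds per forward+backward
float8_time_sec = fp8_us * 1e-6 / REPEAT_N
speedup = ref_time_sec / float8_time_sec     # > 1.0 means the float8 path was faster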
@@ -229,13 +171,10 @@ def wrapper(*args, **kwargs):
        "fp8_dtype",
        "ref_time_sec",
        "pt_fp8_time_sec",
-        "te_fp8_time_sec",
        "ref_tops_sec",
        "ref_pct_top_peak",
        "pt_fp8_tops_sec",
        "pt_fp8_pct_top_peak",
-        "te_fp8_tops_sec",
-        "te_fp8_pct_top_peak",
    ]
    data = []
    for experiment in experiment_list:
@@ -250,22 +189,15 @@ def wrapper(*args, **kwargs):
                experiment.float_8_dtype,
                experiment.ref_time_sec,
                experiment.float8_time_sec,
-                experiment.te_time_sec,
                experiment.ref_tops_sec,
                experiment.ref_pct_top_peak,
                experiment.float8_tops_sec,
                experiment.float8_pct_top_peak,
-                experiment.te_tops_sec,
-                experiment.te_pct_top_peak,
            ]
        )

    data_pd = pd.DataFrame(data, columns=headers)
    data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"]
-    if transformer_engine_installed:
-        data_pd["te_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["te_fp8_time_sec"]
-    else:
-        data_pd["te_fp8_speedup"] = -1.0
    data_pd["shape"] = (
        "("
        + data_pd["M"].astype(str)
@@ -284,9 +216,7 @@ def wrapper(*args, **kwargs):
            "compiled",
            "ref_time_sec",
            "pt_fp8_time_sec",
-            "te_fp8_time_sec",
            "pt_fp8_speedup",
-            "te_fp8_speedup",
        ]
    ]
    print(data_pd_simple)