This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit d272138

vkuzo authored and facebook-github-bot committed
remove TE from the codebase (#184)
Summary: as titled

Pull Request resolved: #184

Test Plan:
```
// scripts work
python benchmarks/profile_linear_float8.py ../tmp/ False dynamic
python benchmarks/bench_linear_float8.py -o ../tmp/test.txt -n 1
CUDA_VISIBLE_DEVICES=0,1 python benchmarks/bench_multi_gpu.py

// no hits
grep -r transformer_engine .
```

Reviewed By: drisspg

Differential Revision: D52715981

Pulled By: vkuzo

fbshipit-source-id: 30d8036e3454148d5611585984b5f21ecbe674d3
1 parent: d0af81a · commit: d272138
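For context, the benchmark path this commit keeps is the pure-PyTorch float8 one. Below is a condensed sketch stitched from the retained lines in the diffs further down; the import paths and call signatures are the ones used by these benchmark scripts (treat the exact module locations as assumptions), while the shape, dtype, and warmup count are illustrative, and non-emulated fp8 assumes an fp8-capable GPU such as an H100.

```python
import torch
import torch.nn as nn

# Import paths as used by this repo's benchmarks (assumed locations).
from float8_experimental.float8_linear import Float8Linear
from float8_experimental.float8_linear_utils import (
    swap_linear_with_float8_linear,
    sync_float8_amax_and_scale_history,
)

# Illustrative shape/dtype; (K, N) matches the "attn.wqkv" entry in bench_linear_float8.py.
M, K, N = 4096, 8192, 1280
device, dtype = "cuda", torch.bfloat16

# Swap nn.Linear for Float8Linear in place, as bench_multi_gpu.py does.
m = nn.Sequential(nn.Linear(K, N, bias=False, dtype=dtype)).to(device)
swap_linear_with_float8_linear(m, Float8Linear, emulate=False)

input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)

def float8_forw_backward():
    # Delayed scaling: refresh amax/scale history before each step.
    sync_float8_amax_and_scale_history(m)
    m(input_tensor).sum().backward()

for _ in range(5):  # warmup, as in the benchmark scripts
    float8_forw_backward()
```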

File tree

benchmarks/bench_linear_float8.py
benchmarks/bench_multi_gpu.py
benchmarks/profile_linear_float8.py

3 files changed (+5, -176 lines)

benchmarks/bench_linear_float8.py

Lines changed: 0 additions & 70 deletions
```diff
@@ -18,16 +18,6 @@
 from float8_experimental.float8_linear_utils import sync_float8_amax_and_scale_history
 from tqdm import tqdm

-# Check if transformer_engine is installed
-transformer_engine_installed = False
-try:
-    import transformer_engine.pytorch as te
-    from transformer_engine.common import recipe
-
-    transformer_engine_installed = True
-except ImportError:
-    print("transformer_engine not installed and we won't compare against this")
-
 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N

@@ -66,7 +56,6 @@ class Experiment:
     dtype: torch.dtype
     compiled: bool = False
     float_8_dtype: Optional[torch.dtype] = torch.float8_e4m3fn
-    te_time_sec: Optional[float] = None

     # 3 Times since we are calculating forward backward
     @property
@@ -87,21 +76,6 @@ def float8_tops_sec(self):
     def float8_pct_top_peak(self):
         return self.float8_tops_sec / dtype_to_peak_tops[self.float_8_dtype]

-    @property
-    def te_tops_sec(self):
-        M, K, N = self.shape
-        if self.te_time_sec is not None:
-            return float(3 * (2 * M * K * N)) / self.te_time_sec
-        else:
-            return None
-
-    @property
-    def te_pct_top_peak(self):
-        if self.te_tops_sec is not None:
-            return self.te_tops_sec / dtype_to_peak_tops[self.float_8_dtype]
-        else:
-            return None
-

 def main(
     sweep_path: Path,
@@ -113,7 +87,6 @@ def main(

     # LLaMa 2 70B single-node weight shapes
     # assumes fused attn.wqkv and ffn.w13
-    # source: https://fburl.com/gsheet/g8onr7rh
     name_to_shapes_70b = {
         "attn.wqkv": (8192, 1280),
         "attn.w0": (1024, 8192),
@@ -145,19 +118,6 @@ def float8_forw_backward():
         sync_float8_amax_and_scale_history(linear_float8)
         linear_float8(input_tensor).sum().backward()

-    if transformer_engine_installed:
-        # Use the same recipe as float8_linear.DelayedScalingRecipe
-        fp8_format = recipe.Format.HYBRID
-        fp8_recipe = recipe.DelayedScaling(
-            fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
-        )
-        te_linear = te.Linear(K, N, bias=input_bias).to(device=device, dtype=dtype)
-
-        def te_forw_backward():
-            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-                y = te_linear(input_tensor)
-            y.sum().backward()
-
     def n_times(n, fn, *args, **kwargs):
         def wrapper(*args, **kwargs):
             for _ in range(n):
@@ -169,21 +129,14 @@ def wrapper(*args, **kwargs):

     ref_forw_backward = n_times(REPEAT_N, ref_forw_backward)
     float8_forw_backward = n_times(REPEAT_N, float8_forw_backward)
-    if transformer_engine_installed:
-        te_forw_backward = n_times(REPEAT_N, te_forw_backward)

     if compile:
         ref_forw_backward = torch.compile(ref_forw_backward)
         float8_forw_backward = torch.compile(float8_forw_backward)
-        # Compiling TE_linear fails but they are already compiling under the hood
-        # if transformer_engine_installed:
-        #     te_forw_backward = torch.compile(te_forw_backward)

     for _ in range(5):
         ref_forw_backward()
         float8_forw_backward()
-        if transformer_engine_installed:
-            te_forw_backward()

     ref_time = (
         benchmark_torch_function_in_microseconds(ref_forw_backward)
@@ -195,27 +148,16 @@ def wrapper(*args, **kwargs):
         * 1e-6
         / REPEAT_N
     )
-    if transformer_engine_installed:
-        te_time_sec = (
-            benchmark_torch_function_in_microseconds(te_forw_backward)
-            * 1e-6
-            / REPEAT_N
-        )
-    else:
-        te_time_sec = None
     experiment = Experiment(
         name,
         (M, K, N),
         ref_time,
         float8_time,
         dtype,
         compile,
-        te_time_sec=te_time_sec,
     )
     print(experiment)
     print("float8 speedup", experiment.ref_time_sec / experiment.float8_time_sec)
-    if transformer_engine_installed:
-        print("te speedup", experiment.ref_time_sec / experiment.te_time_sec)
     experiment_list.append(experiment)
     torch._dynamo.reset()
@@ -229,13 +171,10 @@ def wrapper(*args, **kwargs):
         "fp8_dtype",
         "ref_time_sec",
         "pt_fp8_time_sec",
-        "te_fp8_time_sec",
         "ref_tops_sec",
         "ref_pct_top_peak",
         "pt_fp8_tops_sec",
         "pt_fp8_pct_top_peak",
-        "te_fp8_tops_sec",
-        "te_fp8_pct_top_peak",
     ]
     data = []
     for experiment in experiment_list:
@@ -250,22 +189,15 @@ def wrapper(*args, **kwargs):
                 experiment.float_8_dtype,
                 experiment.ref_time_sec,
                 experiment.float8_time_sec,
-                experiment.te_time_sec,
                 experiment.ref_tops_sec,
                 experiment.ref_pct_top_peak,
                 experiment.float8_tops_sec,
                 experiment.float8_pct_top_peak,
-                experiment.te_tops_sec,
-                experiment.te_pct_top_peak,
             ]
         )

     data_pd = pd.DataFrame(data, columns=headers)
     data_pd["pt_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["pt_fp8_time_sec"]
-    if transformer_engine_installed:
-        data_pd["te_fp8_speedup"] = data_pd["ref_time_sec"] / data_pd["te_fp8_time_sec"]
-    else:
-        data_pd["te_fp8_speedup"] = -1.0
     data_pd["shape"] = (
         "("
         + data_pd["M"].astype(str)
@@ -284,9 +216,7 @@ def wrapper(*args, **kwargs):
             "compiled",
             "ref_time_sec",
             "pt_fp8_time_sec",
-            "te_fp8_time_sec",
             "pt_fp8_speedup",
-            "te_fp8_speedup",
         ]
     ]
     print(data_pd_simple)
```
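The deleted `te_tops_sec`/`te_pct_top_peak` properties mirror the `float8_tops_sec` accounting that stays: a linear forward+backward is counted as three matmuls of `2 * M * K * N` FLOPs each (forward, grad_input, grad_weight), and utilization is that rate divided by the device's peak throughput for the dtype. A minimal sketch of the arithmetic; the function names and example numbers are illustrative, not the script's API:

```python
def tops_sec(M: int, K: int, N: int, time_sec: float) -> float:
    # forward + grad_input + grad_weight -> 3 matmuls of 2*M*K*N FLOPs each,
    # the same expression as Experiment.float8_tops_sec / the deleted te_tops_sec
    return float(3 * (2 * M * K * N)) / time_sec

def pct_top_peak(tops: float, peak_tops: float) -> float:
    # fraction of the device's peak throughput (dtype_to_peak_tops in the script)
    return tops / peak_tops

# e.g. the "attn.wqkv" shape above with an assumed M=4096 at 1 ms per fwd+bwd,
# against an assumed ~1979e12 FLOP/s fp8 peak (H100 SXM, dense):
print(pct_top_peak(tops_sec(4096, 8192, 1280, 1e-3), 1979e12))  # ~0.13
```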

benchmarks/bench_multi_gpu.py

Lines changed: 5 additions & 66 deletions
```diff
@@ -26,16 +26,6 @@
     StateDictType,
 )

-# Check if transformer_engine is installed
-transformer_engine_installed = False
-try:
-    import transformer_engine.pytorch as te
-    from transformer_engine.common import recipe
-
-    transformer_engine_installed = True
-except ImportError:
-    print("transformer_engine not installed and we won't compare against this")
-

 torch.manual_seed(0)

@@ -68,26 +58,18 @@ def cleanup():
     dist.destroy_process_group()


-def get_model(K, N, is_fp8, is_te, base_dtype=torch.float32):
+def get_model(K, N, is_fp8, base_dtype=torch.float32):
     modules = [
-        (
-            nn.Linear(K, N, dtype=base_dtype)
-            if not is_te
-            else te.Linear(K, N, params_dtype=base_dtype)
-        ),
+        nn.Linear(K, N, dtype=base_dtype),
         nn.ReLU(),
     ]
     N_LAYERS = 20
     # N linear layers
     for _ in range(N_LAYERS - 1):
-        if is_te:
-            modules.append(te.Linear(N, N, params_dtype=base_dtype))
-        else:
-            modules.append(nn.Linear(N, N, dtype=base_dtype))
+        modules.append(nn.Linear(N, N, dtype=base_dtype))
         modules.append(nn.ReLU())
     m = nn.Sequential(*modules)
     if is_fp8:
-        assert not is_te, "`is_fp8` (using pytorch fp8) can't be used with `is_te`"
         swap_linear_with_float8_linear(m, Float8Linear, emulate=False)
     return m
@@ -105,9 +87,7 @@ def fsdp_main(rank, world_size, args):
     bsz_local_end = int((rank + 1) / world_size * B)
     input_tensor = input_global[bsz_local_start:bsz_local_end].to(rank)

-    fp8_model = get_model(K, N, is_fp8=True, is_te=False, base_dtype=base_dtype).to(
-        rank
-    )
+    fp8_model = get_model(K, N, is_fp8=True, base_dtype=base_dtype).to(rank)
     # Need use_orig_params=True to compile FSDP
     fp8_model = FSDP(fp8_model, use_orig_params=True)
     fp8_optimizer = torch.optim.SGD(fp8_model.parameters(), lr=lr * world_size)
@@ -132,9 +112,7 @@ def float8_forw_backward():
         fp8_optimizer.step()
         sync_float8_func(fp8_model)

-    ref_model = get_model(K, N, is_fp8=False, is_te=False, base_dtype=base_dtype).to(
-        rank
-    )
+    ref_model = get_model(K, N, is_fp8=False, base_dtype=base_dtype).to(rank)
     ref_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size)
     if compile:
         ref_model = torch.compile(ref_model)
@@ -146,30 +124,6 @@ def ref_forw_backward():
         ref_model(input_tensor).sum().backward()
         ref_optimizer.step()

-    if transformer_engine_installed:
-        te_model = FSDP(
-            get_model(K, N, is_fp8=False, is_te=True, base_dtype=base_dtype).to(rank),
-            use_orig_params=True,
-        )
-        fp8_format = recipe.Format.HYBRID
-        fp8_recipe = recipe.DelayedScaling(
-            fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
-        )
-        # Compiling TE_linear fails but they are already compiling under the hood
-        # if transformer_engine_installed:
-        #     te_forw_backward = torch.compile(te_forw_backward)
-        if rank == 0:
-            print(te_model)
-
-        te_optimizer = torch.optim.SGD(ref_model.parameters(), lr=lr * world_size)
-
-        def te_forw_backward():
-            te_optimizer.zero_grad()
-            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-                y = te_model(input_tensor)
-            y.sum().backward()
-            te_optimizer.step()
-
     def run_n_iterations(n, fn):
         for _ in range(n):
             fn()
@@ -179,8 +133,6 @@ def run_n_iterations(n, fn):
     # warmup
     run_n_iterations(50, ref_forw_backward)
     run_n_iterations(50, float8_forw_backward)
-    if transformer_engine_installed:
-        run_n_iterations(50, te_forw_backward)

     N_ITER = 50
     ref_time = (
@@ -197,24 +149,11 @@ def run_n_iterations(n, fn):
         * 1e-6
         / N_ITER
     )
-    if transformer_engine_installed:
-        te_time_sec = (
-            benchmark_torch_function_in_microseconds(
-                run_n_iterations, N_ITER, te_forw_backward
-            )
-            * 1e-6
-            / N_ITER
-        )
-    else:
-        te_time_sec = None

     if rank == 0:
         print("ref_time", ref_time)
         print("float8_time", float8_time)
-        print("te_time_sec", te_time_sec)
         print("float8 speedup", ref_time / float8_time)
-        if transformer_engine_installed:
-            print("te speedup", ref_time / te_time_sec)

     cleanup()
```
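Both benchmark scripts normalize timings the same way: `benchmark_torch_function_in_microseconds` times one call of a wrapper that itself runs N iterations, so multiplying by `1e-6` and dividing by the iteration count yields average seconds per forward+backward, and the reported speedup is a ratio of those per-iteration times. A tiny self-contained sketch of that arithmetic; the helper name `avg_sec_per_iter` and the numbers are made up for illustration:

```python
def avg_sec_per_iter(total_microseconds: float, n_iter: int) -> float:
    # benchmark_torch_function_in_microseconds times one call of the wrapped
    # function, which itself runs n_iter forward+backward iterations
    return total_microseconds * 1e-6 / n_iter

# e.g. 1,250,000 us for 50 iterations -> 0.025 s per iteration
ref_time = avg_sec_per_iter(1_250_000, 50)
float8_time = avg_sec_per_iter(800_000, 50)      # hypothetical fp8 measurement
print("float8 speedup", ref_time / float8_time)  # -> 1.5625
```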

benchmarks/profile_linear_float8.py

Lines changed: 0 additions & 40 deletions
```diff
@@ -76,17 +76,6 @@ def profile_function(
     return prof


-# Check if transformer_engine is installed
-transformer_engine_installed = False
-try:
-    import transformer_engine.pytorch as te
-    from transformer_engine.common import recipe
-
-    transformer_engine_installed = True
-except ImportError:
-    print("transformer_engine not installed and we won't compare against this")
-
-
 @dataclass(frozen=True)
 class LinearParams:
     M: int
@@ -165,35 +154,13 @@ def float8_forw_backward_wrapper(x):
         with record_function("backward"):
             out.sum().backward()

-    if transformer_engine_installed:
-        # Use the same recipe as float8_linear.DelayedScalingRecipe
-        fp8_format = recipe.Format.HYBRID
-        fp8_recipe = recipe.DelayedScaling(
-            fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max"
-        )
-        te_linear = te.Linear(params.K, params.N, bias=params.input_bias).to(
-            device="cuda", dtype=params.ref_dtype
-        )
-
-        def te_forw_backward(x):
-            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
-                with record_function("forward"):
-                    out = te_linear(x)
-                with record_function("backward"):
-                    out.sum().backward()
-
     if params.torch_compile:
         ref_forw_backward = torch.compile(ref_forw_backward)
         float8_forw_backward = torch.compile(float8_forw_backward, fullgraph=True)
-        # Compiling TE_linear fails but they are already compiling under the hood
-        # if transformer_engine_installed:
-        #     te_forw_backward = torch.compile(te_forw_backward)

     for _ in range(5):
         ref_forw_backward(input_tensor)
         float8_forw_backward_wrapper(input_tensor)
-        if transformer_engine_installed:
-            te_forw_backward(input_tensor)

     # Profile Reference Linear
     ref_string = f"linear_ref_dtype_{params.ref_dtype}_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}_compile_{params.torch_compile}.json"
@@ -213,13 +180,6 @@ def te_forw_backward(x):
     )
     profile_function(profile_config, float8_forw_backward_wrapper, input_tensor)

-    te_string = f"linear_transformer_engine_M_{params.M}_K_{params.K}_N_{params.N}_input_bias_{params.input_bias}.json"
-    if transformer_engine_installed:
-        profile_config = ProfileConfig(
-            str(profile_path / te_string), te_string, iters=5, warmup_iters=5, sync=True
-        )
-        profile_function(profile_config, te_forw_backward, input_tensor)
-

 def invoke_main() -> None:
     # Example usage: python benchmarks/profile_linear_float8.py benchmarks/data/profiles --compile=True --linear_type="dynamic"
```
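The profiling flow that remains in `profile_linear_float8.py` wraps forward and backward in `record_function` regions and writes one Chrome-trace JSON per configuration through the script's `ProfileConfig`/`profile_function` helpers. A minimal self-contained sketch of the same pattern using `torch.profiler` directly; the plain CPU `nn.Linear`, sizes, and output filename are stand-ins, not the script's values:

```python
import torch
import torch.nn as nn
from torch.profiler import ProfilerActivity, profile, record_function

lin = nn.Linear(512, 512)
x = torch.randn(64, 512)

def forw_backward(inp):
    # mark the regions that show up as named blocks in the trace
    with record_function("forward"):
        out = lin(inp)
    with record_function("backward"):
        out.sum().backward()

with profile(activities=[ProfilerActivity.CPU]) as prof:
    for _ in range(5):
        forw_backward(x)

# Chrome-trace output, analogous to the per-config JSON files above.
prof.export_chrome_trace("linear_ref_example_trace.json")
```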
