Commit dccb065

vkuzo authored and jainapurva committed

update float8 benchmarks to be more useful for smaller shapes (#615)

Summary: adds the ability to select various shape generation strategies in benchmarks/float8/bench_matmul.py, and adds options to profile smaller building blocks in benchmarks/float8/profile_linear_float8.py, targeting roughly a 50% split of GPU time between the gemms and the float8 overhead. This will make it easier to hunt for performance improvements relevant to small shapes.

Test Plan: Reviewers: Subscribers: Tasks: Tags:

1 parent 2330a0b · commit dccb065
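For orientation, a minimal usage sketch of the shape-generation helper this commit adds to benchmarks/float8/utils.py (see the diff below); the strategy names and shapes come directly from that diff, and the snippet assumes it is run from the benchmarks/float8/ directory so that `utils` is importable:

# Minimal sketch of get_name_to_shapes_iter from this commit.
# Assumes the current working directory is benchmarks/float8/ so that
# `utils` resolves to benchmarks/float8/utils.py.
from utils import get_name_to_shapes_iter

# 'llama': LLaMa 2 70B linear shapes with M fixed at 4 * 4096 (the old hardcoded default)
for name, (M, K, N) in get_name_to_shapes_iter("llama", None, None, None):
    print(name, M, K, N)  # e.g. attn.wqkv 16384 8192 1280

# 'square': M = K = N swept over powers of two from 32 to 65536
for name, (M, K, N) in get_name_to_shapes_iter("square", None, None, None):
    print(name, M, K, N)  # e.g. 0 32 32 32

# 'custom': a single user-specified shape, useful for small-shape hunting
for name, (M, K, N) in get_name_to_shapes_iter("custom", 512, 1024, 2048):
    print(name, M, K, N)  # 1 512 1024 2048

The same strategies are selectable from the command line of bench_linear_float8.py via the new --shape_gen_name, --M, --K, and --N flags added in the diff below.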

File tree: 4 files changed, +98 −76 lines

  benchmarks/float8/bench_linear_float8.py
  benchmarks/float8/bench_matmul.py
  benchmarks/float8/profile_linear_float8.py
  benchmarks/float8/utils.py

benchmarks/float8/bench_linear_float8.py

Lines changed: 22 additions & 13 deletions
@@ -21,6 +21,7 @@
     sync_float8_amax_and_scale_history,
 )
 from torchao.float8.float8_tensor import ScaledMMConfig
+from utils import get_name_to_shapes_iter
 from tqdm import tqdm
 
 # estimating TOPs for matmuls in fp32, fp16, fp8
@@ -96,6 +97,11 @@ def main(
     n_limit: Optional[int] = None,
     fast_accum_filter: Optional[bool] = None,
     shape_name_filter: Optional[str] = None,
+    *,
+    shape_gen_name: str = 'llama',
+    M: Optional[int] = None,
+    K: Optional[int] = None,
+    N: Optional[int] = None,
     scaling_type_input: str = "dynamic",
     scaling_type_weight: str = "dynamic",
     scaling_type_grad_output: str = "dynamic",
@@ -112,26 +118,19 @@ def main(
         cast_config_grad_output=CastConfig(scaling_type=scaling_type_grad_output),
     )
 
-    # LLaMa 2 70B single-node weight shapes
-    # assumes fused attn.wqkv and ffn.w13
-    name_to_shapes_70b = {
-        "attn.wqkv": (8192, 1280),
-        "attn.w0": (1024, 8192),
-        "ffn.w13": (8192, 7168),
-        "ffn.w2": (3584, 8192),
-    }
+    name_to_shapes = get_name_to_shapes_iter(shape_gen_name, M, K, N)
     input_bias = False
     if fast_accum_filter is not None:
         use_fast_accum = [fast_accum_filter]
     else:
         use_fast_accum = [True, False]
     if shape_name_filter is not None:
         k = shape_name_filter
-        name_to_shapes_70b = {k: name_to_shapes_70b[k]}
+        name_to_shapes = ((k, v) for (k, v) in name_to_shapes if k == shape_name_filter)
     experiment_list: List[Experiment] = []
     dtype = torch.bfloat16
-    for idx, (fast_accum, (name, (K, N))) in enumerate(
-        tqdm(list(product(use_fast_accum, name_to_shapes_70b.items())))
+    for idx, (fast_accum, (name, (M, K, N))) in enumerate(
+        tqdm(list(product(use_fast_accum, name_to_shapes)))
     ):
         if n_limit is not None and idx >= n_limit:
             break
@@ -150,8 +149,6 @@ def main(
         else:
             linear_float8.forward_config = ScaledMMConfig(False, False, False)
 
-        bsz, seq_len = 4, 4096
-        M = bsz * seq_len
         input_tensor = torch.randn(M, K, device=device, dtype=dtype, requires_grad=True)
         ref_forw_backward = lambda: linear_ref(input_tensor).sum().backward()
 
@@ -279,6 +276,10 @@ def invoke_main() -> None:
     parser.add_argument("-o", "--output_path", type=str, required=False)
     parser.add_argument("--disable_compile", action="store_true")
    parser.add_argument("-n", "--n_limit", type=int, required=False)
+    parser.add_argument("--shape_gen_name", type=str, required=False)
+    parser.add_argument("--M", type=int, required=False)
+    parser.add_argument("--K", type=int, required=False)
+    parser.add_argument("--N", type=int, required=False)
     parser.add_argument("--fast_accum_filter", type=bool, required=False)
     parser.add_argument("--shape_name_filter", type=str, required=False)
     parser.add_argument("--scaling_type_input", type=str, required=False)
@@ -287,6 +288,14 @@ def invoke_main() -> None:
     args = parser.parse_args()
     output_path = Path(args.output_path) if args.output_path is not None else None
     kwargs = {}
+    if args.shape_gen_name is not None:
+        kwargs["shape_gen_name"] = args.shape_gen_name
+    if args.M is not None:
+        kwargs["M"] = args.M,
+    if args.K is not None:
+        kwargs["K"] = args.K,
+    if args.N is not None:
+        kwargs["N"] = args.N,
     if args.scaling_type_input is not None:
         kwargs["scaling_type_input"] = args.scaling_type_input
     if args.scaling_type_weight is not None:
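A small illustration (with hypothetical data; the shapes mirror the 'llama' entries) of the shape_name_filter handling above: name_to_shapes is now any iterable of (name, (M, K, N)) pairs, so filtering uses a generator expression instead of rebuilding a dict:

# Hypothetical data in the (name, (M, K, N)) format produced by get_name_to_shapes_iter
name_to_shapes = [("attn.wqkv", (16384, 8192, 1280)), ("ffn.w2", (16384, 3584, 8192))]
shape_name_filter = "ffn.w2"
# Same filtering pattern as in the diff above
name_to_shapes = ((k, v) for (k, v) in name_to_shapes if k == shape_name_filter)
print(list(name_to_shapes))  # [('ffn.w2', (16384, 3584, 8192))]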

benchmarks/float8/bench_matmul.py

Lines changed: 2 additions & 61 deletions
@@ -13,6 +13,8 @@
 import torch.nn as nn
 import torch.utils.benchmark as benchmark
 
+from utils import get_name_to_shapes_iter
+
 # estimating TOPs for matmuls in fp32, fp16, fp8
 # assuming A * B = C, with A being M * K, B being K * N, C being M * N
 
@@ -48,67 +50,6 @@ def do_benchmarks(tops, peak_tops, f, *args, **kwargs):
     return time_sec, tops_sec, pct_top_peak
 
 
-def get_name_to_shapes_iter(
-    shape_gen_name: str,
-    M: Optional[int],
-    K: Optional[int],
-    N: Optional[int],
-):
-    if shape_gen_name == 'llama':
-        assert M == K == N == None, \
-            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
-        bsz, seq_len = 4, 4096
-        M = bsz * seq_len
-        # LLaMa 2 70B single-node weight shapes
-        # assumes fused attn.wqkv and ffn.w13
-        # source: https://fburl.com/gsheet/g8onr7rh
-        name_to_shapes_70b = {
-            "attn.wqkv": (M, 8192, 1280),
-            "attn.w0": (M, 1024, 8192),
-            "ffn.w13": (M, 8192, 7168),
-            "ffn.w2": (M, 3584, 8192),
-        }
-        return name_to_shapes_70b.items()
-
-    elif shape_gen_name == 'square':
-        assert M == K == N == None, \
-            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
-        name_to_shapes = {}
-        min_power_of_2 = 5  # 32
-        max_power_of_2 = 16  # 65,536
-        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
-            val = 2 ** power_of_2
-            name_to_shapes[idx] = val, val, val
-        return name_to_shapes.items()
-
-    elif shape_gen_name == 'sweep':
-        assert M == K == N == None, \
-            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
-        name_to_shapes = {}
-        min_p2 = 5  # 32
-        max_p2 = 16  # 65,536
-        counter = 0
-        for M_p2 in range(min_p2, max_p2 + 1):
-            M = 2 ** M_p2
-            for K_p2 in range(min_p2, max_p2 + 1):
-                K = 2 ** K_p2
-                for N_p2 in range(min_p2, max_p2 + 1):
-                    N = 2 ** N_p2
-                    name_to_shapes[counter] = M, K, N
-                    counter += 1
-        return name_to_shapes.items()
-
-    elif shape_gen_name == 'custom':
-        assert M is not None and K is not None and N is not None, \
-            'M, K, N must be specified for custom shape_gen'
-        name_to_shapes = {
-            1: (M, K, N),
-        }
-        return name_to_shapes.items()
-
-    raise AssertionError(f'unknown shape_gen_name {shape_gen_name}')
-
-
 @torch.inference_mode()
 def run(
     n_limit: Optional[int] = None,

benchmarks/float8/profile_linear_float8.py

Lines changed: 12 additions & 2 deletions
@@ -210,7 +210,7 @@ def main(
     model_type: str = "linear",
     dtype_filter: str = "both",
 ):
-    assert model_type in ("linear", "ln_linear", "norm_ffn_norm"), "unsupported"
+    assert model_type in ("linear", "ln_linear", "norm_ffn_norm", "norm_ffn_norm_small"), "unsupported"
     assert dtype_filter in ("both", "float8", "bfloat16")
 
     scaling_type_input = ScalingType(scaling_type_input)
@@ -250,8 +250,18 @@ def main(
         input_tensor = torch.randn(
             1, 8192, 4096, device=device, dtype=ref_dtype
         ).requires_grad_()
+    elif model_type == "norm_ffn_norm_small":
+        m_ref = NormFFNResidualNorm(
+            dim=4096,
+            hidden_dim=4096,
+            multiple_of=1024,
+            ffn_dim_multiplier=1.0,
+        )
+        input_tensor = torch.randn(
+            1, 2048, 4096, device=device, dtype=ref_dtype
+        ).requires_grad_()
     else:
-        M, K, N = 4 * 4096, 8192, 7168
+        M, K, N = 4096, 4096, 4096
         m_ref = torch.nn.Sequential(
             torch.nn.Linear(K, N, bias=False),
         )

benchmarks/float8/utils.py

Lines changed: 62 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 import collections
 import re
+from typing import Optional
 
 
 def profiler_output_to_time_by_kernel_name(prof):
@@ -81,3 +82,64 @@ def parse_bw_and_kernel_name(line):
         return result.group(1), result.group(2)
     else:
         return None, None
+
+
+def get_name_to_shapes_iter(
+    shape_gen_name: str,
+    M: Optional[int],
+    K: Optional[int],
+    N: Optional[int],
+):
+    if shape_gen_name == 'llama':
+        assert M == K == N == None, \
+            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
+        bsz, seq_len = 4, 4096
+        M = bsz * seq_len
+        # LLaMa 2 70B single-node weight shapes
+        # assumes fused attn.wqkv and ffn.w13
+        # source: https://fburl.com/gsheet/g8onr7rh
+        name_to_shapes_70b = {
+            "attn.wqkv": (M, 8192, 1280),
+            "attn.w0": (M, 1024, 8192),
+            "ffn.w13": (M, 8192, 7168),
+            "ffn.w2": (M, 3584, 8192),
+        }
+        return name_to_shapes_70b.items()
+
+    elif shape_gen_name == 'square':
+        assert M == K == N == None, \
+            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
+        name_to_shapes = {}
+        min_power_of_2 = 5  # 32
+        max_power_of_2 = 16  # 65,536
+        for idx, power_of_2 in enumerate(range(min_power_of_2, max_power_of_2 + 1)):
+            val = 2 ** power_of_2
+            name_to_shapes[idx] = val, val, val
+        return name_to_shapes.items()
+
+    elif shape_gen_name == 'sweep':
+        assert M == K == N == None, \
+            f'M, K, N arguments not supported for shape_gen_name {shape_gen_name}'
+        name_to_shapes = {}
+        min_p2 = 5  # 32
+        max_p2 = 16  # 65,536
+        counter = 0
+        for M_p2 in range(min_p2, max_p2 + 1):
+            M = 2 ** M_p2
+            for K_p2 in range(min_p2, max_p2 + 1):
+                K = 2 ** K_p2
+                for N_p2 in range(min_p2, max_p2 + 1):
+                    N = 2 ** N_p2
+                    name_to_shapes[counter] = M, K, N
+                    counter += 1
+        return name_to_shapes.items()
+
+    elif shape_gen_name == 'custom':
+        assert M is not None and K is not None and N is not None, \
+            'M, K, N must be specified for custom shape_gen'
+        name_to_shapes = {
+            1: (M, K, N),
+        }
+        return name_to_shapes.items()
+
+    raise AssertionError(f'unknown shape_gen_name {shape_gen_name}')
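A quick sanity check (a sketch; assumes it is run from benchmarks/float8/ so that `utils` is importable) of how many shapes the 'square' and 'sweep' strategies above generate:

from utils import get_name_to_shapes_iter

# 'square' produces one shape per power of two: 2**5 .. 2**16 -> 12 shapes
print(len(get_name_to_shapes_iter("square", None, None, None)))  # 12

# 'sweep' takes the full product over M, K, N -> 12**3 = 1728 shapes
print(len(get_name_to_shapes_iter("sweep", None, None, None)))   # 1728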
