@@ -16,6 +16,7 @@
 import platform
 import re
 import shutil
+import statistics
 import sys
 import tempfile
 import textwrap
@@ -169,6 +170,77 @@ class GraphPartitionMap:
     constant_names: list[str]


+def fp8_bench(fn: Callable[[], Any], warmup: int = 25, rep: int = 100) -> float:
+    """
+    Return the mean runtime of fn in milliseconds, using torch profiler
+    events to exclude the fp8 amax ("fused_abs_max") kernels from the
+    measurement. Examining profiler events can be more accurate as it
+    doesn't count CPU-side overhead. However, it also requires manually
+    excluding irrelevant events, e.g. the vectorized_elementwise_kernel
+    used to fill the L2 cache, various CUDA events, etc., so it can also
+    be fragile.
+    """
+
+    # Run once before estimating so one-time compilation cost is not timed
+    fn()
+    torch.cuda.synchronize()
+    # ~128 MB fp16 buffer; zeroing it before each call flushes the L2 cache
+    cache = torch.empty(int(256e6 // 4), dtype=torch.float16, device="cuda")
+
+    # Estimate the runtime of the function
+    start_event = torch.cuda.Event(enable_timing=True)
+    end_event = torch.cuda.Event(enable_timing=True)
+    start_event.record()
+    for _ in range(5):
+        cache.zero_()
+        fn()
+    end_event.record()
+    torch.cuda.synchronize()
+    estimate_ms = start_event.elapsed_time(end_event) / 5
+
+    # Compute the number of warmup and repeat iterations so that each phase
+    # runs for roughly `warmup` ms and `rep` ms respectively
+    n_warmup = max(1, int(warmup / estimate_ms))
+    n_repeat = max(1, int(rep / estimate_ms))
+
+    # Warm-up
+    for _ in range(n_warmup):
+        fn()
+
+    start_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
+    end_event = [torch.cuda.Event(enable_timing=True) for _ in range(n_repeat)]
+    # Time each iteration with CUDA events while the profiler records
+    # kernel-level activity for the filtering step below
+    with torch.profiler.profile(
+        activities=[
+            torch.profiler.ProfilerActivity.CUDA,
+        ]
+    ) as p:
+        torch.cuda.synchronize()
+        for i in range(n_repeat):
+            cache.zero_()
+            start_event[i].record()
+            with torch.cuda.nvtx.range("RunCudaModule"):
+                fn()
+            end_event[i].record()
+        torch.cuda.synchronize()
+    times = torch.tensor(
+        [s.elapsed_time(e) for s, e in zip(start_event, end_event)]
+    )
+
+    res = torch.mean(times).item()
+    log.debug("raw events")
+    log.debug(p.key_averages().table(sort_by="self_device_time_total", row_limit=-1))
+    # Keep only the CUDA-side amax kernels and deduct their mean device time
+    # (device_time_total is in us, hence the division by 1000) from the result
+    filtered_events = EventList(
+        [
+            event
+            for event in p.events()
+            if event.device_type == DeviceType.CUDA and "fused_abs_max_0" in event.name
+        ]
+    )
+    if filtered_events:
+        res -= (
+            statistics.mean(event.device_time_total for event in filtered_events)
+            / 1000.0
+        )
+
+    log.debug("profiling results: %s ms", res)
+    return res
+
+
 def do_bench_using_profiling(
     fn: Callable[[], Any], warmup: int = 25, rep: int = 100
 ) -> float:
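For context, a sketch of how fp8_bench might be invoked. The workload below is hypothetical; fp8_bench itself only assumes a zero-argument callable that launches CUDA work, and torch._scaled_mm with float8_e4m3fn inputs is just one plausible fp8 kernel to measure:

    import torch

    # Hypothetical fp8 matmul workload (any zero-argument CUDA callable works)
    a = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn)
    b = torch.randn(4096, 4096, device="cuda").to(torch.float8_e4m3fn).t()
    scale = torch.tensor(1.0, device="cuda")

    def run() -> torch.Tensor:
        return torch._scaled_mm(
            a, b, scale_a=scale, scale_b=scale, out_dtype=torch.bfloat16
        )

    ms = fp8_bench(run)  # mean ms per call, minus any fused_abs_max kernel time

Subtracting the amax kernels only matters for workloads whose compiled graph actually contains a "fused_abs_max_0" kernel; for anything else the filtered event list is empty and the function degenerates to plain CUDA-event timing.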