
Commit a4b44dd

desertfire authored and pytorchmergebot committed
[AOTI] Introduce DeferredCudaGridLine for cuda cpp wrapper (pytorch#129268)
Summary: Similar to pytorch#129135, use DeferredCudaGridLine to create a deferred grid computation line when generating the cpp wrapper.

Differential Revision: [D61800622](https://our.internmc.facebook.com/intern/diff/D61800622)

Pull Request resolved: pytorch#129268
Approved by: https://github.com/angelayi
1 parent 5fd670e commit a4b44dd
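
The deferred-line idea behind this change, in a minimal self-contained sketch: during wrapper codegen a placeholder object is written into the line buffer instead of a finished string, and the placeholder is only rendered once autotuning has populated a kernel-parameter cache. All names below (DeferredGridLine, PARAM_CACHE, resolve) are hypothetical simplifications for illustration, not the Inductor classes touched by this commit.

from typing import Dict, List, Union

PARAM_CACHE: Dict[str, Dict[str, int]] = {}  # filled once autotuning has run


class DeferredGridLine:
    """Placeholder emitted during codegen; resolved to text later."""

    def __init__(self, kernel_name: str, grid_var: str):
        self.kernel_name = kernel_name
        self.grid_var = grid_var

    def __call__(self) -> str:
        # Only read the cache at resolve time, after autotuning has run.
        params = PARAM_CACHE[self.kernel_name]
        dims = ", ".join(str(params[k]) for k in ("grid_x", "grid_y", "grid_z"))
        return f"Grid {self.grid_var} = Grid({dims});"


def resolve(lines: List[Union[str, DeferredGridLine]]) -> str:
    # Second pass over the buffered wrapper lines: deferred entries are
    # materialized into strings, plain strings pass through unchanged.
    return "\n".join(line() if callable(line) else line for line in lines)


lines = ["// prologue", DeferredGridLine("triton_poi_0", "triton_poi_0_grid_0")]
PARAM_CACHE["triton_poi_0"] = {"grid_x": 128, "grid_y": 1, "grid_z": 1}
print(resolve(lines))

The diffs below apply this pattern to the grid computation for CUDA kernel launches in the AOTI cpp wrapper.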

File tree (3 files changed, +153 / -62 lines)

torch/_inductor/codegen/cpp_wrapper_cpu.py
torch/_inductor/codegen/cpp_wrapper_cuda.py
torch/_inductor/codegen/wrapper.py

torch/_inductor/codegen/cpp_wrapper_cpu.py

Lines changed: 9 additions & 4 deletions
@@ -71,7 +71,7 @@ def __init__(self):

     def generate_kernel_call(
         self,
-        name,
+        kernel_name: str,
         call_args,
         grid=None,
         device_index=None,
@@ -81,6 +81,7 @@ def generate_kernel_call(
         raw_args=None,
         grid_fn: str = "grid",
         triton_meta=None,
+        autotune_configs=None,
         grid_extra_kwargs="",
     ):
         """
@@ -94,14 +95,18 @@ def generate_kernel_call(
         """
         if cuda:
             return super().generate_kernel_call(
-                name,
+                kernel_name,
                 call_args,
                 grid,
                 device_index,
                 cuda,
                 triton,
                 arg_types,
+                raw_args,
                 grid_fn,
+                triton_meta,
+                autotune_configs,
+                grid_extra_kwargs,
             )
         else:
             if config.abi_compatible:
@@ -119,9 +124,9 @@ def generate_kernel_call(
                 else:
                     # arg is a scalar
                     new_args.append(arg)
-                self.writeline(self.wrap_kernel_call(name, new_args))
+                self.writeline(self.wrap_kernel_call(kernel_name, new_args))
             else:
-                self.writeline(self.wrap_kernel_call(name, call_args))
+                self.writeline(self.wrap_kernel_call(kernel_name, call_args))

     def write_constant(self, name, hashed):
         # include a hash so our code cache gives different constants different files
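
The change in this file is mostly mechanical: the CUDA branch now forwards raw_args, triton_meta, autotune_configs, and grid_extra_kwargs to the parent generate_kernel_call positionally, so the forwarding order has to mirror the parent signature. A toy sketch of why that ordering matters (BaseWrapper and CpuWrapper are hypothetical stand-ins, not the Inductor classes):

class BaseWrapper:
    def generate_kernel_call(self, kernel_name, call_args, grid=None,
                             triton_meta=None, autotune_configs=None):
        return {"kernel": kernel_name, "grid": grid,
                "triton_meta": triton_meta, "autotune_configs": autotune_configs}


class CpuWrapper(BaseWrapper):
    def generate_kernel_call(self, kernel_name, call_args, grid=None,
                             triton_meta=None, autotune_configs=None):
        # Forward positionally in the parent's order; dropping triton_meta
        # here would silently bind autotune_configs to the triton_meta slot.
        return super().generate_kernel_call(
            kernel_name, call_args, grid, triton_meta, autotune_configs
        )


print(CpuWrapper().generate_kernel_call(
    "k0", [], grid=[1, 1, 1], autotune_configs=["cfg"]
))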

torch/_inductor/codegen/cpp_wrapper_cuda.py

Lines changed: 135 additions & 56 deletions
@@ -16,7 +16,7 @@
 from ..virtualized import V
 from .aoti_hipify_utils import maybe_hipify_code_wrapper
 from .codegen_device_driver import cuda_kernel_driver, cuda_kernel_header
-from .cpp_utils import DTYPE_TO_CPP
+from .cpp_utils import cexpr, DTYPE_TO_CPP
 from .cpp_wrapper_cpu import CppWrapperCpu
 from .wrapper import SymbolicCallArg

@@ -61,6 +61,98 @@ def _new_line(self, line):
         return DeferredCudaKernelLine(self.kernel_name, line, self.keys)


+class DeferredCudaDefaultGrid:
+    """
+    A marker to
+    """
+
+    def __init__(
+        self,
+        kernel_name: str,
+        grid,
+        grid_callable: Optional[Callable[..., Any]] = None,
+        **grid_extra_kwargs,
+    ):
+        self.kernel_name = kernel_name
+        self.grid = grid
+        self.grid_callable = grid_callable
+        self.grid_extra_kwargs = grid_extra_kwargs
+
+    def __call__(self):
+        grid = self.grid
+        assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
+        grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
+        grid_callable = self.grid_callable or default_grid
+        if not self.grid_extra_kwargs:
+            grid_fn = grid_callable(*grid)
+        else:
+            grid_fn = grid_callable(*grid, **self.grid_extra_kwargs)
+
+        params = CudaKernelParamCache.get(self.kernel_name)
+        assert (
+            params is not None
+        ), f"{self.kernel_name} not found in CudaKernelParamCache"
+        block_cfg = {
+            "XBLOCK": params["x_block"],
+            "YBLOCK": params["y_block"],
+            "ZBLOCK": params["z_block"],
+        }
+        return grid_fn(block_cfg)
+
+
+class DeferredCudaGridLine(DeferredLineBase):
+    """
+    When using cpp wrapper, CUDA kernel load and launch needs to wait for Triton kernels
+    to be tuned and stored as cubin files, so use a deferred line to backfill those information
+    """
+
+    def __init__(
+        self,
+        kernel_name: str,
+        grid_var: str,
+        grid,
+        autotune_configs,
+    ):
+        super().__init__("")
+        self.kernel_name = kernel_name
+        self.grid_var = grid_var
+        self.grid = grid
+        self.autotune_configs = autotune_configs
+
+    def __call__(self):
+        params = CudaKernelParamCache.get(self.kernel_name)
+        assert (
+            params is not None
+        ), f"{self.kernel_name} not found in CudaKernelParamCache"
+
+        if self.autotune_configs is not None:
+            # This indicates the Triton kernel is a user-defined one.
+            grid = None
+            if len(self.grid) == 1:
+                grid = self.grid[0]
+            else:
+                for i, c in enumerate(self.autotune_configs):
+                    if all(arg == params["meta"][key] for key, arg in c.kwargs.items()):
+                        grid = self.grid[i]
+                        break
+            assert grid is not None
+        elif isinstance(self.grid, DeferredCudaDefaultGrid):
+            grid = self.grid()
+        else:
+            grid = self.grid
+
+        assert len(grid) != 0, "Grid can't be empty"
+        grid_args_str = ", ".join(
+            [cexpr(V.graph.sizevars.simplify(item)) for item in grid]
+        )
+        return f"Grid {self.grid_var} = Grid({grid_args_str});"
+
+    def _new_line(self, line):
+        return DeferredCudaGridLine(
+            self.kernel_name, self.grid_var, self.grid, self.autotune_configs
+        )
+
+
 class CppWrapperCuda(CppWrapperCpu):
     """
     Generates cpp wrapper for running on GPU and calls CUDA kernels
@@ -116,28 +208,20 @@ def generate(self, is_inference):
         return super().generate(is_inference)

     def generate_user_defined_triton_kernel(
-        self, kernel_name, raw_args, grid, configs, triton_meta, constexprs
+        self,
+        kernel_name: str,
+        raw_args: List[Any],
+        grid: List[Any],
+        configs,
+        triton_meta,
+        constexprs,
     ):
         # in C++ wrapper, we don't pass constexpr args, as they don't
         # get added as parameters to the PTX code compiled from the
         # user-defined Triton kernel (only non-constexpr args do)
         raw_args = [
             raw_arg for i, raw_arg in enumerate(raw_args) if i not in constexprs
         ]
-
-        assert len(grid) != 0
-        if len(grid) == 1:
-            grid_decision = grid[0]
-        else:
-            meta = CudaKernelParamCache.get(kernel_name)
-            assert meta is not None
-            grid_decision = None
-            for i, c in enumerate(configs):
-                if all(arg == meta["meta"][key] for key, arg in c.kwargs.items()):
-                    grid_decision = grid[i]
-                    break
-            assert grid_decision is not None
-
         args = [self.val_to_arg_str(v) for v in raw_args]
         arg_types = [
             arg.get_dtype() if hasattr(arg, "get_dtype") else type(arg)
@@ -147,10 +231,12 @@ def generate_user_defined_triton_kernel(
             kernel_name,
             args,
             arg_types=arg_types,
-            grid=grid_decision,
+            raw_args=raw_args,
+            grid=grid,
             cuda=True,
             triton=True,
             triton_meta=triton_meta,
+            autotune_configs=configs,
         )

     @functools.lru_cache(None)  # noqa: B019
@@ -228,39 +314,27 @@ def generate_args_decl(self, call_args, arg_types):

     def generate_default_grid(
         self,
-        name: str,
+        kernel_name: str,
         grid: List[Any],
         cuda: bool = True,
         grid_callable: Optional[Callable[..., Any]] = None,
         **grid_extra_kwargs,
     ):
         """
         Generate grid configs for launching a CUDA kernel using the grid
-        function from triton_heuristics.
+        function from triton_heuristics. Because its computation needs
+        to read kernel config after autotune, it is done in a deferred way
+        using DeferredCudaDefaultGrid.
         """
         if not cuda:
             return grid
-        assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
-        grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
-        grid_callable = grid_callable or default_grid
-        if not grid_extra_kwargs:
-            grid_fn = grid_callable(*grid)
-        else:
-            grid_fn = grid_callable(*grid, **grid_extra_kwargs)
-        params = CudaKernelParamCache.get(name)
-        assert (
-            params is not None
-        ), f"cuda kernel parameters for {name} should already exist at this moment, only found {CudaKernelParamCache.get_keys()}"
-        block_cfg = {
-            "XBLOCK": params["x_block"],
-            "YBLOCK": params["y_block"],
-            "ZBLOCK": params["z_block"],
-        }
-        return grid_fn(block_cfg)
+        return DeferredCudaDefaultGrid(
+            kernel_name, grid, grid_callable, **grid_extra_kwargs
+        )

     def generate_kernel_call(
         self,
-        kernel_name,
+        kernel_name: str,
         call_args,
         grid=None,
         device_index=None,
@@ -270,6 +344,7 @@ def generate_kernel_call(
         raw_args=None,
         grid_fn: str = "grid",
         triton_meta=None,
+        autotune_configs=None,
         grid_extra_kwargs="",
     ):
         assert arg_types is not None and len(call_args) == len(
@@ -279,7 +354,18 @@ def generate_kernel_call(
         if not cuda:
             # Even in CppWrapperCuda, we may see cpp kernels
             return super().generate_kernel_call(
-                kernel_name, call_args, grid, device_index, cuda, triton, arg_types
+                kernel_name,
+                call_args,
+                grid,
+                device_index,
+                cuda,
+                triton,
+                arg_types,
+                raw_args,
+                grid_fn,
+                triton_meta,
+                autotune_configs,
+                grid_extra_kwargs,
             )

         device_index, call_args = self.prepare_triton_kernel_call(
@@ -307,33 +393,26 @@ def generate_kernel_call(
             if V.graph.aot_mode
             else self.write_get_raw_stream(device_index, V.graph)
         )
-        grid_name = f"{kernel_name}_grid_{next(self.grid_id)}"
-        assert isinstance(
-            grid, (list, tuple)
-        ), f"expected grid to be a list or tuple but got: {grid=}"
-
-        grid = [V.graph.sizevars.simplify(item) for item in grid]
-        grid_uses_symbolic_shapes = any(item.free_symbols for item in grid)
-        grid_args = [self.expr_printer(item) for item in grid]
-        grid_args_str = ", ".join(grid_args)
-        self.writeline(f"Grid {grid_name} = Grid({grid_args_str});")
-
-        if grid_uses_symbolic_shapes:
-            self.writeline(f"if ({grid_name}.is_non_zero()) {{")
+
+        grid_var = f"{kernel_name}_grid_{next(self.grid_id)}"
+        self.writeline(
+            DeferredCudaGridLine(kernel_name, grid_var, grid, autotune_configs)
+        )
+
         kernel_var_name = f"kernels.{kernel_name}" if V.graph.aot_mode else kernel_name
+        self.writeline(f"if ({grid_var}.is_non_zero()) {{")
         self.writeline(
             DeferredCudaKernelLine(
                 kernel_name,
                 r"launchKernel({}, {}, {}, {}, %s, %s, {}, {});".format(
                     kernel_var_name,
-                    f"{grid_name}.grid_x",
-                    f"{grid_name}.grid_y",
-                    f"{grid_name}.grid_z",
+                    f"{grid_var}.grid_x",
+                    f"{grid_var}.grid_y",
+                    f"{grid_var}.grid_z",
                     kernel_args_var,
                     stream,
                 ),
                 ("num_warps", "shared_mem"),
             ),
         )
-        if grid_uses_symbolic_shapes:
-            self.writeline("}")
+        self.writeline("}")
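
For user-defined Triton kernels, the new DeferredCudaGridLine above selects the launch grid by matching each autotune config's kwargs against the meta recorded for the compiled kernel, falling back to the single candidate when only one grid exists. A hedged sketch of that selection step with toy stand-ins (ToyConfig, pick_grid, and cached_meta are illustrative, not CudaKernelParamCache or triton.Config):

from typing import Any, Dict, List


class ToyConfig:
    """Stand-in for an autotune config: just a kwargs dict."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs


def pick_grid(
    configs: List[ToyConfig],
    grids: List[List[int]],
    cached_meta: Dict[str, Any],
) -> List[int]:
    # One candidate grid per autotune config; the winner is the config whose
    # kwargs match the meta recorded for the compiled kernel.
    if len(grids) == 1:
        return grids[0]
    for cfg, grid in zip(configs, grids):
        if all(cached_meta.get(k) == v for k, v in cfg.kwargs.items()):
            return grid
    raise AssertionError("no autotune config matched the cached kernel meta")


configs = [ToyConfig(BLOCK=64), ToyConfig(BLOCK=128)]
grids = [[32, 1, 1], [16, 1, 1]]
cached_meta = {"BLOCK": 128}  # what autotuning actually selected
print(pick_grid(configs, grids, cached_meta))  # -> [16, 1, 1]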

torch/_inductor/codegen/wrapper.py

Lines changed: 9 additions & 2 deletions
@@ -788,7 +788,13 @@ def generate_extern_kernel_out(
         self.writeline(f"{kernel}({', '.join(args)})")

     def generate_user_defined_triton_kernel(
-        self, kernel_name, raw_args, grid, configs, triton_meta, constexprs
+        self,
+        kernel_name: str,
+        raw_args: List[Any],
+        grid: List[Any],
+        configs,
+        triton_meta,
+        constexprs,
     ):
         grid_fn, code = user_defined_kernel_grid_fn_code(
             kernel_name, configs, grid, wrapper=self
@@ -1541,7 +1547,7 @@ def generate_save_uncompiled_kernels(self):

     def generate_default_grid(
         self,
-        name: str,
+        kernel_name: str,
         grid: List[Any],
         cuda: bool = True,
         grid_callable: Optional[Callable[..., Any]] = None,
@@ -1632,6 +1638,7 @@ def generate_kernel_call(
         raw_args=None,
         grid_fn: str = "grid",
         triton_meta=None,
+        autotune_configs=None,
         grid_extra_kwargs="",
     ):
         """
