16
16
from ..virtualized import V
17
17
from .aoti_hipify_utils import maybe_hipify_code_wrapper
18
18
from .codegen_device_driver import cuda_kernel_driver , cuda_kernel_header
19
- from .cpp_utils import DTYPE_TO_CPP
19
+ from .cpp_utils import cexpr , DTYPE_TO_CPP
20
20
from .cpp_wrapper_cpu import CppWrapperCpu
21
21
from .wrapper import SymbolicCallArg
22
22
@@ -61,6 +61,98 @@ def _new_line(self, line):
61
61
return DeferredCudaKernelLine (self .kernel_name , line , self .keys )
62
62
63
63
64
class DeferredCudaDefaultGrid:
    """
    A marker to compute the default Triton launch grid in a deferred way,
    because the grid computation needs to read the kernel's autotuned block
    config (XBLOCK/YBLOCK/ZBLOCK), which only exists after autotuning runs
    and stores its result in CudaKernelParamCache.
    """

    def __init__(
        self,
        kernel_name: str,
        grid,
        grid_callable: Optional[Callable[..., Any]] = None,
        **grid_extra_kwargs,
    ):
        # kernel_name keys into CudaKernelParamCache once autotuning has run.
        self.kernel_name = kernel_name
        self.grid = grid
        self.grid_callable = grid_callable
        self.grid_extra_kwargs = grid_extra_kwargs

    def __call__(self):
        """Resolve the concrete grid now that autotuned params are cached."""
        grid = self.grid
        assert isinstance(grid, (list, tuple)), f"expected {grid=} to be a list"
        # Unwrap symbolic call args to their underlying sympy expressions.
        grid = [e.inner_expr if isinstance(e, SymbolicCallArg) else e for e in grid]
        grid_callable = self.grid_callable or default_grid
        if not self.grid_extra_kwargs:
            grid_fn = grid_callable(*grid)
        else:
            grid_fn = grid_callable(*grid, **self.grid_extra_kwargs)

        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"
        # Feed the autotuned block sizes to the heuristic grid function.
        block_cfg = {
            "XBLOCK": params["x_block"],
            "YBLOCK": params["y_block"],
            "ZBLOCK": params["z_block"],
        }
        return grid_fn(block_cfg)
103
class DeferredCudaGridLine(DeferredLineBase):
    """
    When using cpp wrapper, CUDA kernel load and launch needs to wait for Triton
    kernels to be tuned and stored as cubin files, so use a deferred line to
    backfill the grid declaration once autotuning results are available.
    """

    def __init__(
        self,
        kernel_name: str,
        grid_var: str,
        grid,
        autotune_configs,
    ):
        super().__init__("")
        self.kernel_name = kernel_name
        self.grid_var = grid_var
        self.grid = grid
        self.autotune_configs = autotune_configs

    def _pick_user_defined_grid(self, params):
        # A non-None autotune_configs indicates a user-defined Triton kernel:
        # choose the grid whose config matches the autotuned meta params.
        if len(self.grid) == 1:
            return self.grid[0]
        chosen = None
        for idx, cfg in enumerate(self.autotune_configs):
            if all(arg == params["meta"][key] for key, arg in cfg.kwargs.items()):
                chosen = self.grid[idx]
                break
        assert chosen is not None
        return chosen

    def __call__(self):
        """Emit the C++ Grid declaration for this kernel launch."""
        params = CudaKernelParamCache.get(self.kernel_name)
        assert (
            params is not None
        ), f"{self.kernel_name} not found in CudaKernelParamCache"

        if self.autotune_configs is not None:
            grid = self._pick_user_defined_grid(params)
        elif isinstance(self.grid, DeferredCudaDefaultGrid):
            # Default grid computation was deferred; resolve it now.
            grid = self.grid()
        else:
            grid = self.grid

        assert len(grid) != 0, "Grid can't be empty"
        grid_args_str = ", ".join(
            [cexpr(V.graph.sizevars.simplify(item)) for item in grid]
        )
        return f"Grid {self.grid_var} = Grid({grid_args_str});"

    def _new_line(self, line):
        # The replacement line is recomputed from our own state; the incoming
        # text is intentionally ignored.
        return DeferredCudaGridLine(
            self.kernel_name, self.grid_var, self.grid, self.autotune_configs
        )
64
156
class CppWrapperCuda (CppWrapperCpu ):
65
157
"""
66
158
Generates cpp wrapper for running on GPU and calls CUDA kernels
@@ -116,28 +208,20 @@ def generate(self, is_inference):
116
208
return super ().generate (is_inference )
117
209
118
210
def generate_user_defined_triton_kernel (
119
- self , kernel_name , raw_args , grid , configs , triton_meta , constexprs
211
+ self ,
212
+ kernel_name : str ,
213
+ raw_args : List [Any ],
214
+ grid : List [Any ],
215
+ configs ,
216
+ triton_meta ,
217
+ constexprs ,
120
218
):
121
219
# in C++ wrapper, we don't pass constexpr args, as they don't
122
220
# get added as parameters to the PTX code compiled from the
123
221
# user-defined Triton kernel (only non-constexpr args do)
124
222
raw_args = [
125
223
raw_arg for i , raw_arg in enumerate (raw_args ) if i not in constexprs
126
224
]
127
-
128
- assert len (grid ) != 0
129
- if len (grid ) == 1 :
130
- grid_decision = grid [0 ]
131
- else :
132
- meta = CudaKernelParamCache .get (kernel_name )
133
- assert meta is not None
134
- grid_decision = None
135
- for i , c in enumerate (configs ):
136
- if all (arg == meta ["meta" ][key ] for key , arg in c .kwargs .items ()):
137
- grid_decision = grid [i ]
138
- break
139
- assert grid_decision is not None
140
-
141
225
args = [self .val_to_arg_str (v ) for v in raw_args ]
142
226
arg_types = [
143
227
arg .get_dtype () if hasattr (arg , "get_dtype" ) else type (arg )
@@ -147,10 +231,12 @@ def generate_user_defined_triton_kernel(
147
231
kernel_name ,
148
232
args ,
149
233
arg_types = arg_types ,
150
- grid = grid_decision ,
234
+ raw_args = raw_args ,
235
+ grid = grid ,
151
236
cuda = True ,
152
237
triton = True ,
153
238
triton_meta = triton_meta ,
239
+ autotune_configs = configs ,
154
240
)
155
241
156
242
@functools .lru_cache (None ) # noqa: B019
@@ -228,39 +314,27 @@ def generate_args_decl(self, call_args, arg_types):
228
314
229
315
def generate_default_grid(
    self,
    kernel_name: str,
    grid: List[Any],
    cuda: bool = True,
    grid_callable: Optional[Callable[..., Any]] = None,
    **grid_extra_kwargs,
):
    """
    Generate grid configs for launching a CUDA kernel using the grid
    function from triton_heuristics. Because its computation needs
    to read kernel config after autotune, it is done in a deferred way
    using DeferredCudaDefaultGrid.
    """
    if cuda:
        # Defer: block sizes are unknown until autotuning has populated
        # CudaKernelParamCache.
        return DeferredCudaDefaultGrid(
            kernel_name, grid, grid_callable, **grid_extra_kwargs
        )
    # Non-CUDA kernels launch with the grid as-is; nothing to defer.
    return grid
260
334
261
335
def generate_kernel_call (
262
336
self ,
263
- kernel_name ,
337
+ kernel_name : str ,
264
338
call_args ,
265
339
grid = None ,
266
340
device_index = None ,
@@ -270,6 +344,7 @@ def generate_kernel_call(
270
344
raw_args = None ,
271
345
grid_fn : str = "grid" ,
272
346
triton_meta = None ,
347
+ autotune_configs = None ,
273
348
grid_extra_kwargs = "" ,
274
349
):
275
350
assert arg_types is not None and len (call_args ) == len (
@@ -279,7 +354,18 @@ def generate_kernel_call(
279
354
if not cuda :
280
355
# Even in CppWrapperCuda, we may see cpp kernels
281
356
return super ().generate_kernel_call (
282
- kernel_name , call_args , grid , device_index , cuda , triton , arg_types
357
+ kernel_name ,
358
+ call_args ,
359
+ grid ,
360
+ device_index ,
361
+ cuda ,
362
+ triton ,
363
+ arg_types ,
364
+ raw_args ,
365
+ grid_fn ,
366
+ triton_meta ,
367
+ autotune_configs ,
368
+ grid_extra_kwargs ,
283
369
)
284
370
285
371
device_index , call_args = self .prepare_triton_kernel_call (
@@ -307,33 +393,26 @@ def generate_kernel_call(
307
393
if V .graph .aot_mode
308
394
else self .write_get_raw_stream (device_index , V .graph )
309
395
)
310
- grid_name = f"{ kernel_name } _grid_{ next (self .grid_id )} "
311
- assert isinstance (
312
- grid , (list , tuple )
313
- ), f"expected grid to be a list or tuple but got: { grid = } "
314
-
315
- grid = [V .graph .sizevars .simplify (item ) for item in grid ]
316
- grid_uses_symbolic_shapes = any (item .free_symbols for item in grid )
317
- grid_args = [self .expr_printer (item ) for item in grid ]
318
- grid_args_str = ", " .join (grid_args )
319
- self .writeline (f"Grid { grid_name } = Grid({ grid_args_str } );" )
320
-
321
- if grid_uses_symbolic_shapes :
322
- self .writeline (f"if ({ grid_name } .is_non_zero()) {{" )
396
+
397
+ grid_var = f"{ kernel_name } _grid_{ next (self .grid_id )} "
398
+ self .writeline (
399
+ DeferredCudaGridLine (kernel_name , grid_var , grid , autotune_configs )
400
+ )
401
+
323
402
kernel_var_name = f"kernels.{ kernel_name } " if V .graph .aot_mode else kernel_name
403
+ self .writeline (f"if ({ grid_var } .is_non_zero()) {{" )
324
404
self .writeline (
325
405
DeferredCudaKernelLine (
326
406
kernel_name ,
327
407
r"launchKernel({}, {}, {}, {}, %s, %s, {}, {});" .format (
328
408
kernel_var_name ,
329
- f"{ grid_name } .grid_x" ,
330
- f"{ grid_name } .grid_y" ,
331
- f"{ grid_name } .grid_z" ,
409
+ f"{ grid_var } .grid_x" ,
410
+ f"{ grid_var } .grid_y" ,
411
+ f"{ grid_var } .grid_z" ,
332
412
kernel_args_var ,
333
413
stream ,
334
414
),
335
415
("num_warps" , "shared_mem" ),
336
416
),
337
417
)
338
- if grid_uses_symbolic_shapes :
339
- self .writeline ("}" )
418
+ self .writeline ("}" )
0 commit comments