@@ -196,8 +196,23 @@ def generate_ttir(
 
     assert isinstance(kernel, JITFunction)
 
+    context = triton._C.libtriton.ir.context()
+    target = triton.runtime.driver.active.get_current_target()
+    backend = triton.compiler.compiler.make_backend(target)
+    options = backend.parse_options({})
+
+    # ignore backend-specific kwargs same way as in the native Triton code
+    # https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596
+    # why this is important for user-defined Triton kernels on AMD: https://github.com/pytorch/pytorch/issues/140800
+    for name in list(kwargs):
+        if name not in kernel.arg_names and name in options.__dict__:
+            kwargs.pop(name)
+
     if len(kwargs) != len(kernel.arg_names):
-        raise ValueError("Incorrect number of arguments passed to kernel")
+        raise ValueError(
+            "Incorrect number of arguments passed to kernel: "
+            f"passed {list(kwargs.keys())}, expected {kernel.arg_names}."
+        )
 
     # Replace all SymExprs with a regular value for TTIR generation
     # Replace all FakeTensor/TensorBox with real tensors
@@ -239,10 +254,6 @@ def _get_specialization(args):  # type: ignore[no-untyped-def]
         if i not in kernel.constexprs
     }
 
-    context = triton._C.libtriton.ir.context()
-    target = triton.runtime.driver.active.get_current_target()
-    backend = triton.compiler.compiler.make_backend(target)
-    options = backend.parse_options({})
 
     triton._C.libtriton.ir.load_dialects(context)
     backend.load_dialects(context)
 
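
Below is a minimal, illustrative sketch (not part of the diff) of what the relocated filtering step does, assuming a working Triton install with an active GPU backend. The kernel argument names and launch kwargs are hypothetical; `waves_per_eu` is used as an example of a backend-specific option (it exists on the AMD backend), and the exact option names depend on whichever backend is active.

import triton

# Resolve the options object for the active backend, exactly as in the diff above.
target = triton.runtime.driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
options = backend.parse_options({})

# Hypothetical kernel argument names and launch kwargs; a caller may mix
# backend-specific tuning options (e.g. waves_per_eu on AMD) into the kwargs.
kernel_arg_names = ["x_ptr", "y_ptr", "n_elements", "BLOCK_SIZE"]
launch_kwargs = {"x_ptr": 0, "y_ptr": 0, "n_elements": 1024,
                 "BLOCK_SIZE": 256, "waves_per_eu": 2}

# Drop kwargs that are backend options rather than kernel arguments,
# mirroring the loop added in the diff.
for name in list(launch_kwargs):
    if name not in kernel_arg_names and name in options.__dict__:
        launch_kwargs.pop(name)

# On an AMD backend, waves_per_eu is removed here, so the subsequent
# len(kwargs) != len(kernel.arg_names) check no longer raises.
print(sorted(launch_kwargs))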