@@ -136,8 +136,23 @@ def generate_ttir(kernel, kwargs):
136
136
137
137
assert isinstance (kernel , JITFunction )
138
138
139
+ context = triton ._C .libtriton .ir .context ()
140
+ target = triton .runtime .driver .active .get_current_target ()
141
+ backend = triton .compiler .compiler .make_backend (target )
142
+ options = backend .parse_options ({})
143
+
144
+ # Ignore backend-specific kwargs the same way the native Triton code does:
145
+ # https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596
146
+ # For why this matters for user-defined Triton kernels on AMD, see https://github.com/pytorch/pytorch/issues/140800
147
+ for name in list (kwargs ):
148
+ if name not in kernel .arg_names and name in options .__dict__ :
149
+ kwargs .pop (name )
150
+
139
151
if len (kwargs ) != len (kernel .arg_names ):
140
- raise ValueError ("Incorrect number of arguments passed to kernel" )
152
+ raise ValueError (
153
+ "Incorrect number of arguments passed to kernel: "
154
+ f"passed { list (kwargs .keys ())} , expected { kernel .arg_names } ."
155
+ )
141
156
142
157
# Replace all SymExprs with a regular value for TTIR generation
143
158
# Replace all FakeTensor/TensorBox with real tensors
@@ -168,10 +183,6 @@ def generate_ttir(kernel, kwargs):
168
183
if i not in kernel .constexprs
169
184
}
170
185
171
- context = triton ._C .libtriton .ir .context ()
172
- target = triton .runtime .driver .active .get_current_target ()
173
- backend = triton .compiler .compiler .make_backend (target )
174
- options = backend .parse_options ({})
175
186
triton ._C .libtriton .ir .load_dialects (context )
176
187
backend .load_dialects (context )
177
188
0 commit comments