Skip to content

Commit b740a1b

Browse files
aakhundov authored and pytorchmergebot committed
[user triton] Ignore backend-specific args in the TTIR analysis (pytorch#141062)
Fixes pytorch#140800. On AMD, backend-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` are passed either directly to the kernel or via `triton.Config`, even though they don't exist as kernel parameters. Native Triton code handles those extra args [here](https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596). In this PR, we add similar handling to the TTIR analysis code to avoid bailing out. Pull Request resolved: pytorch#141062 Approved by: https://github.com/oulgen
1 parent 7c7c346 commit b740a1b

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,8 +196,23 @@ def generate_ttir(
196196

197197
assert isinstance(kernel, JITFunction)
198198

199+
context = triton._C.libtriton.ir.context()
200+
target = triton.runtime.driver.active.get_current_target()
201+
backend = triton.compiler.compiler.make_backend(target)
202+
options = backend.parse_options({})
203+
204+
# ignore backend-specific kwargs same way as in the native Triton code
205+
# https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596
206+
# why this is important for user-defined Triton kernels on AMD: https://github.com/pytorch/pytorch/issues/140800
207+
for name in list(kwargs):
208+
if name not in kernel.arg_names and name in options.__dict__:
209+
kwargs.pop(name)
210+
199211
if len(kwargs) != len(kernel.arg_names):
200-
raise ValueError("Incorrect number of arguments passed to kernel")
212+
raise ValueError(
213+
"Incorrect number of arguments passed to kernel: "
214+
f"passed {list(kwargs.keys())}, expected {kernel.arg_names}."
215+
)
201216

202217
# Replace all SymExprs with a regular value for TTIR generation
203218
# Replace all FakeTensor/TensorBox with real tensors
@@ -239,10 +254,6 @@ def _get_specialization(args): # type: ignore[no-untyped-def]
239254
if i not in kernel.constexprs
240255
}
241256

242-
context = triton._C.libtriton.ir.context()
243-
target = triton.runtime.driver.active.get_current_target()
244-
backend = triton.compiler.compiler.make_backend(target)
245-
options = backend.parse_options({})
246257
triton._C.libtriton.ir.load_dialects(context)
247258
backend.load_dialects(context)
248259

0 commit comments

Comments
 (0)