Skip to content

Commit 304a97e

Browse files
aakhundovjataylo
authored andcommitted
[user triton] Ignore backend-specific args in the TTIR analysis (pytorch#141062)
Fixes pytorch#140800. On AMD, backend-specific args like `matrix_instr_nonkdim`, `waves_per_eu` and `kpack` are passed either direclty to the kernel or via `triton.Config`, whereas they don't exist as kernel parameters. Native Triton code handles those excessive args [here](https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596). In this PR, we add similar handling to the TTIR analysis code to avoid bailing out. Pull Request resolved: pytorch#141062 Approved by: https://github.com/oulgen (cherry picked from commit b740a1b)
1 parent abbfe77 commit 304a97e

File tree

1 file changed

+16
-5
lines changed

1 file changed

+16
-5
lines changed

torch/_higher_order_ops/triton_kernel_wrap.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,23 @@ def generate_ttir(kernel, kwargs):
136136

137137
assert isinstance(kernel, JITFunction)
138138

139+
context = triton._C.libtriton.ir.context()
140+
target = triton.runtime.driver.active.get_current_target()
141+
backend = triton.compiler.compiler.make_backend(target)
142+
options = backend.parse_options({})
143+
144+
# ignore backend-specific kwargs same way as in the native Triton code
145+
# https://github.com/triton-lang/triton/blob/a6bb57d6285e723c58e87dd7cba263db6efff789/python/triton/runtime/jit.py#L594-L596
146+
# why this is important for user-defined Triton kernels on AMD: https://github.com/pytorch/pytorch/issues/140800
147+
for name in list(kwargs):
148+
if name not in kernel.arg_names and name in options.__dict__:
149+
kwargs.pop(name)
150+
139151
if len(kwargs) != len(kernel.arg_names):
140-
raise ValueError("Incorrect number of arguments passed to kernel")
152+
raise ValueError(
153+
"Incorrect number of arguments passed to kernel: "
154+
f"passed {list(kwargs.keys())}, expected {kernel.arg_names}."
155+
)
141156

142157
# Replace all SymExprs with a regular value for TTIR generation
143158
# Replace all FakeTensor/TensorBox with real tensors
@@ -168,10 +183,6 @@ def generate_ttir(kernel, kwargs):
168183
if i not in kernel.constexprs
169184
}
170185

171-
context = triton._C.libtriton.ir.context()
172-
target = triton.runtime.driver.active.get_current_target()
173-
backend = triton.compiler.compiler.make_backend(target)
174-
options = backend.parse_options({})
175186
triton._C.libtriton.ir.load_dialects(context)
176187
backend.load_dialects(context)
177188

0 commit comments

Comments
 (0)