Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions examples/dynamic_shape/example_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
import tilelang.testing
from tilelang import tvm as tvm

tilelang.testing.set_random_seed(0)
tilelang.disable_cache()


@tilelang.jit(pass_configs={"tl.disable_dynamic_tail_split": True, "tl.dynamic_alignment": 8})
def matmul_dynamic_mnk(
Expand Down
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,6 @@ setuptools
einops
attrs
decorator
flash-attn
flash-attn<=2.8.0
scipy
tornado
119 changes: 95 additions & 24 deletions tilelang/jit/adapter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,37 +107,108 @@ def get_annotated_mod(


def pythonic_expr(expr: tvm.tir.PrimExpr) -> str:
"""
Converts a TVM PrimExpr into a Python-style string, correctly handling operator precedence.

Args:
expr: The TVM PrimExpr to convert.

Returns:
A string representation of the expression.
"""
if not isinstance(expr, tvm.tir.PrimExpr):
return str(expr)
python_str = ""
node_to_str_map = {} # Stores string representation for each node

def _pythonic_visitor(node):
# 1. Define operator precedence (higher value means higher precedence)
# Based on Python's operator precedence
PRECEDENCE = {
tvm.tir.Call: 20, # Includes min, max
tvm.tir.Cast: 20, # Treated like a function call
tvm.tir.Mul: 13,
tvm.tir.FloorDiv: 13,
tvm.tir.Div: 13, # For tvm.tir.Div if it appears
tvm.tir.FloorMod: 13,
tvm.tir.Add: 12,
tvm.tir.Sub: 12,
tvm.tir.LT: 10,
tvm.tir.LE: 10,
tvm.tir.GT: 10,
tvm.tir.GE: 10,
tvm.tir.EQ: 10,
tvm.tir.NE: 10,
tvm.tir.And: 5,
tvm.tir.Or: 4,
# Atoms (Var, IntImm) have the highest precedence implicitly
}
# By default, atomic expressions (variables, constants) have the highest precedence
ATOMIC_PRECEDENCE = 100

node_to_result_map = {} # Stores (string, precedence) for each node

def _visitor(node):
# 2. Visitor returns (str, precedence) tuple
if node in node_to_result_map:
return

if isinstance(node, tvm.tir.Var):
s = node.name
s, p = node.name, ATOMIC_PRECEDENCE
elif isinstance(node, (tvm.tir.IntImm, tvm.tir.FloatImm)):
# Integer constant: use value directly (ignore type)
s = str(node.value)
s, p = str(node.value), ATOMIC_PRECEDENCE
elif isinstance(node, tvm.tir.Cast):
# Type cast: represent as (type)value
dtype_map = {"int64": "int64_t", "int32": "int32_t", "int8": "int8_t"}
dtype = dtype_map.get(str(node.dtype), str(node.dtype))
value_str = node_to_str_map.get(node.value, str(node.value))
s = f"({dtype}){value_str}"
elif isinstance(node, tvm.tir.Mul):
# Multiplication: format as 'left * right'
a_str = node_to_str_map.get(node.a, str(node.a))
b_str = node_to_str_map.get(node.b, str(node.b))
s = f"{a_str} * {b_str}"
# C-style cast has high precedence
value_str, _ = node_to_result_map[node.value]
s = f"({node.dtype}){value_str}"
p = PRECEDENCE.get(type(node), ATOMIC_PRECEDENCE)
elif isinstance(
node,
(tvm.tir.Mul, tvm.tir.FloorDiv, tvm.tir.Add, tvm.tir.Sub, tvm.tir.FloorMod, tvm.tir.LT,
tvm.tir.LE, tvm.tir.GT, tvm.tir.GE, tvm.tir.EQ, tvm.tir.NE, tvm.tir.And, tvm.tir.Or)):
op_map = {
tvm.tir.Mul: "*",
tvm.tir.FloorDiv: "/",
tvm.tir.Add: "+",
tvm.tir.Sub: "-",
tvm.tir.FloorMod: "%",
tvm.tir.LT: "<",
tvm.tir.LE: "<=",
tvm.tir.GT: ">",
tvm.tir.GE: ">=",
tvm.tir.EQ: "==",
tvm.tir.NE: "!=",
tvm.tir.And: "and",
tvm.tir.Or: "or",
}
op_str = f" {op_map[type(node)]} "
my_precedence = PRECEDENCE[type(node)]

a_str, a_precedence = node_to_result_map[node.a]
b_str, b_precedence = node_to_result_map[node.b]

# 3. Add parentheses intelligently
# Add parentheses if the left operand's precedence is lower than the current operator
if a_precedence < my_precedence:
a_str = f"({a_str})"
# Add parentheses if the right operand's precedence is lower than or equal to the current operator
# 'Equal' is to handle non-associative operations, e.g., a - (b - c)
if b_precedence <= my_precedence:
b_str = f"({b_str})"

s = f"{a_str}{op_str}{b_str}"
p = my_precedence
elif isinstance(node, (tvm.tir.Min, tvm.tir.Max)):
op_name = "min" if isinstance(node, tvm.tir.Min) else "max"
a_str, _ = node_to_result_map[node.a]
b_str, _ = node_to_result_map[node.b]
s = f"{op_name}({a_str}, {b_str})"
# Function calls have high precedence
p = PRECEDENCE.get(tvm.tir.Call, ATOMIC_PRECEDENCE)
else:
# Other nodes: use default string representation
s = str(node)
# Fallback for unhandled expression types
s, p = str(node), 0

# Store current node's string representation
node_to_str_map[node] = s
nonlocal python_str
python_str = s # Update global string (retain root node in the end)
node_to_result_map[node] = (s, p)

# Perform post-order traversal
tvm.tir.stmt_functor.post_order_visit(expr, _pythonic_visitor)
return python_str
tvm.tir.stmt_functor.post_order_visit(expr, _visitor)

return next(iter(node_to_result_map[expr]), "")
27 changes: 5 additions & 22 deletions tilelang/jit/adapter/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,15 +281,6 @@ def maybe_desc(name: str, matches: List[str], i: int):
call_args.append(match)
return call_args

def legalize_c(p):
# Convert TIR expressions to legal C expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p).replace("//", "/")

has_l2_persistent_map = False
for function_name, _ in function_informations.items():
if function_name in self.l2_persistent_map:
Expand All @@ -315,12 +306,13 @@ def legalize_c(p):
index = code.index("{", index)

block_str = "dim3({}, {}, {})".format(
legalize_c(block_info[0]),
legalize_c(block_info[1]),
legalize_c(block_info[2]),
pythonic_expr(block_info[0]),
pythonic_expr(block_info[1]),
pythonic_expr(block_info[2]),
)
grid_str = "dim3({}, {}, {})".format(
legalize_c(grid_info[0]), legalize_c(grid_info[1]), legalize_c(grid_info[2]))
pythonic_expr(grid_info[0]), pythonic_expr(grid_info[1]),
pythonic_expr(grid_info[2]))
smem_str = 0 if dynamic_smem_buf is None else dynamic_smem_buf
init_l2_persistent_map = self.generate_l2_persistent_map(function_name)
kernel_launch_code += init_l2_persistent_map
Expand Down Expand Up @@ -894,15 +886,6 @@ def func_call_args(s, function_args):
call_args.append(match)
return call_args

def legalize_c(p):
# Convert TIR expressions to legal C expressions
# Directly convert to string since the special case handling
# does not alter the string representation for `tvm.tir.Var` and `IntImm`.
# Replace Python's floor division operator with C's division operator
if isinstance(p, tvm.tir.IntImm):
p = int(p)
return str(p).replace("//", "/")

_call_str = """"""

for function_name, _ in function_informations.items():
Expand Down
Loading