kurisu6912 (Collaborator)

Tilelang JITv2

In this PR we introduce Tilelang JITv2, a new frontend for Tilelang with modern and attractive features.

Features

Kernel Declaration

Function declaration has been simplified:

[Screenshot: kernel declaration, before vs. after]
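
The screenshot is not reproduced here; as a rough sketch (the "before" factory style and the exact annotations are assumptions, based on the examples later in this description), the comparison looks roughly like this:

# Before (sketch of the classic factory style; names hypothetical):
# shapes and dtype must be bound when the kernel is built.
def matmul(M, N, K, dtype):
    @T.prim_func
    def main(A: T.Tensor((M, K), dtype), B: T.Tensor((K, N), dtype)):
        ...
    return main

# After (JITv2 sketch, using the annotation style from the examples
# below): a single decorated function; shapes and dtypes come from the
# arguments at call time.
@tl.jit
def gemm(A: tl.Tensor[int, int], B: tl.Tensor[int, int]):
    ...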

Kernel Call

When calling functions, tensor shapes, strides, and dtypes are automatically inferred:

# before
ker_1 = matmul(1024, 1024, 1024, 'float32')
c1 = ker_1(a1, b1)
ker_2 = matmul(1024, 1024, 512, 'float32')
c2 = ker_2(a2, b2)

# after
gemm(a1, b1)
gemm(a2, b2)

Auto Tuning

Auto tuning can be done via default arguments:

@tl.jit
def add(
    A: tl.Tensor[int],
    B: tl.Tensor[int],
    block: int = tune([128, 256, 512])
):
    ...

Or on-the-fly:

add(A, B, tune([64, 128]))

Smarter Static Evaluation

JITv2 preserves as much Python code as possible, allowing calls to custom Python functions or conditional kernel generation:

@tl.jit
def gemm(
    ...
    split_k: bool = False
):
    block_size = my_super_block_size_heuristic(M, N, K)
    if split_k: # split_k is a constant value
        with tl.Kernel(...) as ...:
            ...
    else:
        with tl.Kernel(...) as ...:
            ...

    return C

Smarter Type Hinting

JITv2 not only eliminates annoying type warnings but also adds extensive type annotations. These let you see each Tensor's dimensionality at a glance and mark whether a value lives on the Python side or the kernel side. Even generated functions and the JIT-compiled kernels get friendly type hints:

[Screenshot: editor type hints for a JITv2 function]
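
The screenshot is not reproduced here; as a hedged illustration (the explicit hints below are assumptions, not taken from the PR), the idea is that the annotations carry enough structure for a checker or editor to surface:

# Sketch: the earlier 1-D `add` example with explicit hints. A type
# checker can see that A and B are 1-D integer tensors (kernel-side
# values), while `block` is a plain Python int (host-side constant).
@tl.jit
def add(
    A: tl.Tensor[int],    # 1-D tensor; dims visible in the hint
    B: tl.Tensor[int],
    block: int = 128,     # Python-side constant
) -> tl.Tensor[int]:
    ...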

Extremely Low Overhead

JITv2's Python overhead has been optimized aggressively. On the fast path, only dynamic parameters are checked, which brings the overhead in line with calling a torch function (e.g., torch.add):

A = torch.randn(128, dtype=torch.float16, device="cuda")
B = torch.randn(128, dtype=torch.float16, device="cuda")

# torch.add:  ~6.5 µs
C_1 = A + B
# jit kernel: ~7.5 µs (cached)
C_2 = add(A, B)

Architecture

The Tilelang JIT workflow:

  1. A Python-to-Python pass generates two pieces of code: an argument parser and a JIT function generator
  2. Fast path (~1.5 µs): on a kernel call, the argument parser separates static and dynamic parameters; on a static-cache hit, the C++ library function is called directly
  3. Slow path: on a static-cache miss, the kernel is recompiled (see the sketch below)
[Whiteboard diagram: the JITv2 compilation workflow]
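
In pseudocode, the dispatch above looks roughly like this (a sketch only; parse_args, static_cache, and compile_kernel are illustrative names, not the actual implementation):

def call_jit_kernel(*args):
    # Generated fast-path parser: splits arguments into static (const)
    # and dynamic (dyn) parts (~1.5 µs total on the fast path).
    const_args, dyn_args = parse_args(*args)
    kernel = static_cache.get(const_args)
    if kernel is None:
        # Slow path: static cache miss, (re)compile the kernel.
        kernel = compile_kernel(const_args)
        static_cache[const_args] = kernel
    # Fast path: static cache hit, call into the C++ library directly.
    return kernel(*dyn_args)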

Static & Dynamic Arguments

JITv2 inspects function signatures to determine which parameters are const and which are dyn (a sketch follows the diagram below):

  • dyn supports only int, float, and ptr; these values are treated as tir.Var
  • const can be any type (simple types are preferred; prefer Tuple over List)
  • dyn types must be annotated explicitly; Tensor must always be annotated, because its data_ptr is always dynamic
  • const arguments may deviate from their annotation (e.g., annotating int but passing a dict); note that fully validating them is hard (it would amount to writing something like pydantic)
[Whiteboard diagram: classifying static (const) vs. dynamic (dyn) arguments]
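
A sketch of these rules, borrowing the dyn[...] syntax from the generated-parser example in the next section (the function itself is hypothetical):

_N = dyn[int, '_N']                  # dynamic int, treated as a tir.Var

@tl.jit
def scale(
    A: Tensor[float, _N],            # Tensor must be annotated: data_ptr is always dyn
    factor: dyn[float, 'factor'],    # dyn must be annotated explicitly
    block: int = 128,                # const: baked into the compiled kernel
    dims: tuple = (0, 1),            # const: any type works; prefer Tuple over List
):
    ...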

Argument Parser

JITv2 generates Python code for the fast path, which unpacks const and dyn arguments and then invokes the kernel:

  • Optimized Python statements: each statement in the fast path is carefully chosen to compile to fast bytecode; the total overhead is minimal, even slightly lower than torch.to_dlpack
  • Static check cache: the fast path performs no type checks on const variables; these are checked at compile time instead (e.g., a wrong tensor shape causes a cache miss, the kernel is compiled, and value ranges are checked there)
  • Dynamic type checks: the fast path performs only simple dynamic checks, e.g., asserting that two K values are equal; more complex asserts may be compiled into host code (not yet supported)
_K = dyn[int, '_K']
def foo(
    a: Tensor[int, _K],
    b: Tensor[int, _K],
    c: int,
):
    pass
# generated code
def foo_fastpath(a, b, c):
    # 1. Unpack type info
    # 1.1 Unpacking a tensor ~600 ns; each of the following lines takes ~200 ns, heavily optimized
    assert a.device != __device_cpu__, "Expected a non-CPU tensor"
    a__shape_0, a__shape_1 = a.shape
    a__stride_0, a__stride_1 = a.stride()
    assert b.device != __device_cpu__, "Expected a non-CPU tensor"
    #                  ^- note: torch.device('cpu') costs 200+ ns; using closure trick, __device_cpu__ costs 5 ns
    b__shape_0, b__shape_1 = b.shape
    b__stride_0, b__stride_1 = b.stride()
    # 2. Construct argument lists ~20–50 ns
    __const_args__ = (
        a.dtype, a__shape_0, a__shape_1, a__stride_0, a__stride_1,
        b.dtype, b__shape_0, b__shape_1, b__stride_0, b__stride_1,
        c)
    __dyn_args__ = (a.data_ptr(), b.data_ptr())
    return __const_args__, __dyn_args__

Memory Allocation & Return Values

Inside functions, use T.alloc_global to create global buffers:

  • T.alloc_global is friendlier for type linting — it’s translated into torch.empty
  • T.alloc_xxx must be assigned to a variable (x = T.alloc_xxx()), not passed directly as a function parameter (e.g., foo(T.alloc_shared(...)) is not allowed)
  • Return objects must be global buffers; returning Python objects is not supported (e.g., returning BLOCK_M + BLOCK_N is not allowed):
@T.prim_func
def gemm(
    A: T.Tensor[int, int],
    B: T.Tensor[int, int],
    out_ty  = torch.half,
    BLOCK_M = T.tune([64, 128, 256]),
    BLOCK_N = T.tune([64, 128, 256]),
):
    # Quickly get dimensions
    (N, K), (M, K2) = A.shape, B.shape
    assert K == K2, "Expect 2 matrices with identical K dimension"
    # Allocate memory for output
    out = T.alloc_global((N, M), dtype=out_ty)
    with T.Kernel((T.ceildiv(M, BLOCK_M), T.ceildiv(N, BLOCK_N)), threads=128) as (bx, by):
        pass
    return out
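
A hypothetical call site for the kernel above: shapes and dtypes are inferred from the arguments, and the returned value is the global buffer allocated inside gemm.

A = torch.randn(1024, 512, device="cuda", dtype=torch.half)   # (N, K)
B = torch.randn(2048, 512, device="cuda", dtype=torch.half)   # (M, K)
C = gemm(A, B)   # C is the `out` buffer: shape (N, M) == (1024, 2048)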

TODOs

  • Integrate with tl.language
  • Add auto tuner


LeiWang1999 (Member)

This is huge!

kurisu6912 closed this on Oct 13, 2025.