[Draft] Tilelang JITv2: Simpler Kernel Declaration, Smart Code Generation, More Syntax Sugar and Extremely Low Overhead #916
This is huge!
Tilelang JITv2
In this PR we introduce Tilelang JITv2, a new frontend for Tilelang with modern and attractive features.
Features
Kernel Declaration
Function declaration has been simplified:
Kernel Call
When calling functions, tensor shapes, strides, and dtypes are automatically inferred:
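As a rough illustration of what inferring this metadata involves, the sketch below collects shape, stride, and dtype from each argument into a tuple. `FakeTensor` and `infer_signature` are hypothetical stand-ins written for this example; they are not TileLang or PyTorch APIs, and the PR's actual inference logic may differ.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a framework tensor (illustration only)."""
    shape: tuple
    dtype: str

    def stride(self):
        # Row-major (contiguous) strides, measured in elements.
        strides, step = [], 1
        for dim in reversed(self.shape):
            strides.append(step)
            step *= dim
        return tuple(reversed(strides))


def infer_signature(*tensors):
    """Collect (shape, stride, dtype) per argument; the caller states none of it."""
    return tuple((t.shape, t.stride(), t.dtype) for t in tensors)


a = FakeTensor((128, 64), "float16")
b = FakeTensor((64, 32), "float16")
signature = infer_signature(a, b)
```

A tuple like `signature` is hashable, so under this assumption it could also serve as a cache key for compiled kernels.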
Auto Tuning
Auto tuning can be done via default arguments:
Or on-the-fly:
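The general idea of tuning over a set of candidate configurations can be sketched in plain Python as an exhaustive search. Everything here (`autotune`, `build`, `measure`, and the parameter names) is hypothetical and chosen for illustration; it is not the JITv2 tuning API, and a real tuner would measure kernel run time rather than use a toy cost function.

```python
import itertools


def autotune(build, measure, **choices):
    """Try every combination in `choices` and return the cheapest config.

    build(**cfg) returns a candidate kernel; measure(kernel) returns a cost.
    Both are supplied by the caller (hypothetical interface).
    """
    keys = list(choices)
    best_cfg, best_cost = None, float("inf")
    for values in itertools.product(*(choices[k] for k in keys)):
        cfg = dict(zip(keys, values))
        cost = measure(build(**cfg))
        if cost < best_cost:
            best_cfg, best_cost = cfg, cost
    return best_cfg


# Toy stand-ins: the "kernel" is just its config, and the cost is the
# distance from an arbitrary sweet spot of (64, 128).
def build(block_m, block_n):
    return (block_m, block_n)


def measure(kernel):
    return abs(kernel[0] - 64) + abs(kernel[1] - 128)


best = autotune(build, measure, block_m=[32, 64, 128], block_n=[64, 128])
```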
Smarter Static Evaluation
JITv2 preserves as much Python code as possible, allowing calls to custom Python functions or conditional kernel generation:
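To make the idea concrete, here is a minimal pure-Python sketch of build-time evaluation: an ordinary helper function runs while the kernel is being constructed, and a Python `if` decides which operations are emitted at all. The "kernel" is just a list of strings, and `build_kernel` and `next_power_of_two` are hypothetical names, not part of TileLang.

```python
def next_power_of_two(n: int) -> int:
    """Ordinary Python helper, evaluated while the kernel is being built."""
    return 1 << (n - 1).bit_length()


def build_kernel(block: int, use_relu: bool) -> list:
    """Hypothetical builder that returns the 'kernel' as a list of op names."""
    ops = [f"load(tile={next_power_of_two(block)})"]
    if use_relu:
        # Resolved at build time: the branch itself never reaches the kernel.
        ops.append("relu")
    ops.append("store")
    return ops


with_relu = build_kernel(48, use_relu=True)
without_relu = build_kernel(64, use_relu=False)
```

The two calls produce different op lists from the same source, which is the sense in which kernel generation is conditional.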
Smarter Type Hinting
JITv2 not only eliminates annoying type warnings, but also adds extensive type annotations. This helps you clearly see each Tensor’s dimensions, and marks whether a value is on the Python side or kernel side. Even generated functions and the JIT-compiled kernels have friendly type hints:
Extremely Low Overhead
JITv2's Python overhead has been optimized to the extreme. In the fast path, only dynamic parameters are checked, bringing overhead in line with calling a torch function (e.g., torch.add):
Architecture
The Tilelang JIT workflow:
Static & Dynamic Arguments
JITv2 inspects function signatures to determine which parameters are const and which are dyn:
- int, float, and ptr; treated as tir.Var
- [...] Tuple over List)
- Tensor must be explicitly annotated because its data_ptr is always dynamic
- [...] int but pass a dict); note: validation is hard (like writing a pydantic)
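Signature inspection of this kind can be sketched with the standard library's `inspect` module. The rule used below (parameters annotated `int` or `float` are dyn, everything else is const) is an assumption made for this example only; the PR's real classification rules, including its handling of pointers and tensors, are not reproduced here. `split_params` and `example` are hypothetical names.

```python
import inspect

# Assumption for this sketch only: these annotations mark dynamic parameters.
DYN_TYPES = (int, float)


def split_params(fn):
    """Return (const_names, dyn_names) based on the parameter annotations."""
    const, dyn = [], []
    for name, param in inspect.signature(fn).parameters.items():
        if param.annotation in DYN_TYPES:
            dyn.append(name)
        else:
            const.append(name)
    return const, dyn


def example(n: int, scale: float, block: tuple, mode: str = "fast"):
    """Placeholder function whose signature is inspected."""


const_names, dyn_names = split_params(example)
```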
Argument Parser
JITv2 generates Python code for the fast path, which unpacks const and dyn arguments and then invokes the kernel:
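A minimal sketch of the general technique, generating the source of a wrapper and compiling it with `exec`, might look like the following. It assumes a hypothetical `kernels` dictionary that maps a tuple of const values to an already compiled callable; `make_fast_path` and all parameter names are illustrative and are not taken from the PR.

```python
def make_fast_path(const_names, dyn_names, kernels):
    """Generate and compile a wrapper that splits const and dyn arguments.

    The const values form a lookup key and the dyn values are forwarded to
    the selected kernel. `kernels` is a hypothetical cache for this sketch.
    """
    params = ", ".join(const_names + dyn_names)
    key_expr = "(" + "".join(f"{name}, " for name in const_names) + ")"
    args_expr = ", ".join(dyn_names)
    src = (
        f"def fast_path({params}):\n"
        f"    key = {key_expr}\n"
        f"    return kernels[key]({args_expr})\n"
    )
    namespace = {"kernels": kernels}
    exec(src, namespace)  # compile the generated source once, up front
    return namespace["fast_path"], src


# Toy "compiled kernel" specialised for block == 64.
kernels = {(64,): lambda x, n: x * n}
fast_path, source = make_fast_path(["block"], ["x", "n"], kernels)
result = fast_path(64, 3, 4)
```

Generating the wrapper as source means the per-call work is a tuple build, a dictionary lookup, and one call, with no loop over parameter metadata at call time.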
- [...] torch.to_dlpack
- [...] K. More complex asserts may be compiled to host code (not yet supported)
Memory Allocation & Return Values
Inside functions, use T.alloc_global to create global buffers:
- T.alloc_global is friendlier for type linting; it is translated into torch.empty
- T.alloc_xxx must be assigned to a variable (x = T.alloc_xxx()), not passed directly as a function parameter (e.g., foo(T.alloc_shared(...)) is not allowed)
- [...] BLOCK_M + BLOCK_N is not allowed):
TODOs
- [...] tl.language