Skip to content

Conversation

Hzfengsy
Copy link
Collaborator

No description provided.

Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@Hzfengsy Hzfengsy changed the base branch from refactor to main June 24, 2025 09:23
@LeiWang1999 LeiWang1999 marked this pull request as ready for review June 24, 2025 09:26
Copy link
Contributor

Warning

Gemini encountered an error creating the summary. You can try again by commenting /gemini summary.

@LeiWang1999
Copy link
Member

/gemini summary

@LeiWang1999 LeiWang1999 self-requested a review June 24, 2025 09:45
@LeiWang1999 LeiWang1999 self-assigned this Jun 24, 2025
Copy link
Contributor

Summary of Changes

This pull request undertakes a significant refactoring of the TileLang codebase to align it with a newer upstream version of TVM. The changes involve adapting to updated TVM APIs, introducing new Tensor IR (TIR) optimization passes for improved performance and memory management, and enhancing development tooling. This modernization effort aims to ensure continued compatibility, improve code quality, and enable better utilization of hardware capabilities.

Highlights

  • TVM API Modernization: Updated core C++ and Python FFI interactions to align with recent TVM API changes, ensuring compatibility and leveraging modern TVM features. This includes widespread changes from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL, TVMArgs to PackedArgs, and NullOpt to std::nullopt.
  • New TIR Optimization Passes: Introduced three new passes: LowerStmtGroupBlock for block node removal, LowerThreadAllreduce for device-specific thread-level reductions, and StorageRewrite for advanced memory access pattern analysis and optimization.
  • FP8 Data Type Alignment: Standardized FP8 data type naming conventions across examples and internal operations (e.g., e4m3_float8 to float8_e4m3fn) for consistency with upstream TVM.
  • Improved Development Tooling: Integrated clang-tidy into the project's formatting and linting pipeline, enhancing C++ code quality and maintainability, and updated CMake configurations for better TVM FFI and Cython integration.
  • Codebase Refinements: Performed extensive code cleanup, including removal of unused includes, namespace adjustments, and minor formatting fixes across various source and test files.
Changelog
  • .gitignore
    • Added config.cmake to ignore user-specific CMake configurations.
  • 3rdparty/tvm
    • Updated TVM submodule to commit f77817bf7d03e836bd16380c456ed86ac2770502.
  • CMakeLists.txt
    • Changed default build type from RelWithDebInfo to Release (L41-L42).
    • Added ffi/include to TILE_LANG_INCLUDES (L148).
    • Linked tilelang shared library with tvm_ffi_header (L203).
    • Added tvm_cython as a dependency for tilelang when not prebuilt (L217-L220).
  • examples/cast/example_group_per_split_token_cast_to_fp8.py
    • Renamed e4m3_float8 to float8_e4m3fn for FP8 tensor types (L22).
  • examples/cast/example_per_token_cast_to_fp8.py
    • Renamed e4m3_float8 to float8_e4m3fn for FP8 tensor types (L20).
  • examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py
    • Renamed e4m3_float8 to float8_e4m3fn and e5m2_float8 to float8_e5m2 for FP8 tensor types (L81-L82).
  • format.sh
    • Integrated clang-tidy for C/C++ static analysis, with options for all files or changed files (L253-L307).
  • requirements-lint.txt
    • Added clang-tidy==15.0.7 as a linting requirement (L8).
  • src/ir.cc
    • Replaced NullOpt with std::nullopt for optional arguments (L71, L105, L163).
    • Updated TVM_REGISTER_GLOBAL macros to TVM_FFI_REGISTER_GLOBAL for FFI functions (L285-L288, L365).
  • src/layout/layout.cc
    • Replaced NullOpt with std::nullopt in Fragment constructors (L161, L169, L182, L204).
    • Updated FFI registration macros and argument types (TVMArgs/TVMRetValue to PackedArgs/Any) (L441-L498).
  • src/layout/utils.cc
    • Adjusted Var comparison logic for IterMark source (L136).
  • src/op/bulk_copy.cc
    • Unified FP8 data type checks to dtype.is_float8() (L46).
  • src/op/elem.cc
    • Replaced NullOpt with std::nullopt for optional arguments (L152).
    • Updated Fragment object handling in layout inference (L247, L249, L362-L363, L365, L367-L368).
  • src/op/logical.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L10).
  • src/op/math.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L10).
  • src/op/op.h
    • Changed TypedPackedFunc to ffi::TypedFunction (L28).
  • src/op/parallel.cc
    • Updated IntImmNode access (L200).
    • Replaced NullOpt with std::nullopt for optional arguments (L248, L249, L302, L329).
  • src/op/reduce.cc
    • Replaced NullOpt with std::nullopt for optional arguments (L190, L291).
    • Updated Var optional handling (L202-L203).
  • src/runtime/runtime.cc
    • Updated FFI header includes and removed using namespace runtime; (L13-L14, L18).
    • Adapted FFI argument extraction and registration to PackedArgs/Any and TVM_FFI_REGISTER_GLOBAL (L44-L75, L96-L100, L125-L161, L186-L190).
  • src/target/codegen_cpp.cc
    • Updated FFI context variable name from __tvm_module_ctx to __tvm_ffi_library_ctx (L40, L57).
    • Removed unused header includes (L25-L26, L30, L34, L38).
    • Replaced TVMValue with TVMFFIAny (L382, L388, L397).
  • src/target/codegen_cuda.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L10).
    • Updated global function retrieval from runtime::Registry::Get to ffi::Function::GetGlobal (L1100-L1101).
  • src/target/codegen_hip.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L10).
  • src/target/codegen_webgpu.cc
    • Updated FFI function types from PackedFunc to ffi::Function (L707, L711).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L776).
  • src/target/rt_mod_cpp.cc
    • Removed tvm::runtime::Registry usage (L10).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L72).
  • src/target/rt_mod_cuda.cc
    • Commented out grid_constant argument handling (L21-L27).
    • Removed tvm::runtime::Registry usage (L42).
    • Updated FFI function retrieval from Registry::Get to ffi::Function::GetGlobal (L57, L62).
    • Updated FFI registration macros from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L94, L96).
  • src/target/rt_mod_hip.cc
    • Updated FFI registration macros from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L101, L103).
  • src/transform/align_dynamic_shared_memory_allocations.cc
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L153).
  • src/transform/annotate_device_regions.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L25).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L90).
  • src/transform/cluster_planning.cc
    • Updated copyright header (L1-L2).
    • Added FFI header include tvm/ffi/function.h (L9).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L120).
  • src/transform/common/loop_vectorization_utils.h
    • Replaced NullOpt with std::nullopt (L604).
    • Removed ProducerStoreNode visit method (L686-L689).
  • src/transform/config_index_bitwidth.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L5).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L91).
  • src/transform/eliminate_storage_sync_for_mbarrier.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L10).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L120).
  • src/transform/flatten_buffer.cc
    • Removed IsFromLegacyTESchedule check (L355-L360).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L366).
  • src/transform/frontend_legalize.cc
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L91).
  • src/transform/if_stmt_binding.cc
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L85).
  • src/transform/inject_fence_proxy.cc
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L196).
  • src/transform/inject_pipeline.cc
    • Updated Map value types from ObjectRef to Any (L740).
    • Replaced NullOpt with std::nullopt (L751).
    • Adjusted Array<Integer> downcasting from annot to annot.value() (L958).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L1041).
  • src/transform/inject_ptx_async_copy.cc
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L236).
  • src/transform/inject_tma_barrier.cc
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L307).
  • src/transform/layout_inference.cc
    • Updated Fragment node handling (L294, L296).
    • Replaced NullOpt with std::nullopt (L457).
    • Adjusted Map annotation retrieval (L498).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L659).
  • src/transform/legalize_safe_memory_access.cc
    • Adjusted Map annotation retrieval (L319).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L359).
  • src/transform/legalize_vectorized_loop.cc
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L91).
  • src/transform/loop_vectorize_dynamic.cc
    • Unified FP8 data type checks (L151).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L538).
  • src/transform/lower_device_storage_access_info.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L25).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L144).
  • src/transform/lower_hopper_intrin.cc
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L153).
  • src/transform/lower_l2_persistent_annotation.cc
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L103).
  • src/transform/lower_stmt_group_block.cc
    • New file: Added a pass to remove Block nodes with stmt_group annotation (L1-L73).
  • src/transform/lower_thread_allreduce.cc
    • New file: Added a pass to lower tvm_thread_allreduce to device-specific IR, including warp-level reductions (L1-L948).
  • src/transform/lower_tile_op.cc
    • Adjusted Map annotation retrieval (L114-L121).
    • Replaced NullOpt with std::nullopt (L234).
    • Changed Array<RelayExpr> to Array<RelaxExpr> (L312).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L489).
  • src/transform/make_packed_api.cc
    • Updated FFI header includes (L23-L24).
    • Removed unordered_set include (L33).
    • Changed FFI type codes and TVMValue to TVMFFIAny (L78-L111, L348-L373, L382, L388, L397).
    • Removed func_ptr->checked_type_ = func_ptr->func_type_annotation(); (L471).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L515).
  • src/transform/merge_if_stmt.cc
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L96).
  • src/transform/merge_shared_memory_allocations.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L27).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L1026).
  • src/transform/multi_version_buffer_rewriter.cc
    • Updated num_stages_anno handling (L223, L230).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L343).
  • src/transform/persist_threadblock.cc
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L64).
  • src/transform/pipeline_planning.cc
    • Updated Map value types from ObjectRef to Any (L251, L445).
    • Updated annotation handling (L226, L230, L264, L273).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L495).
  • src/transform/simplify.cc
    • Replaced NullOpt with std::nullopt (L213, L448, L457).
    • Added used_in_buffer_def_ tracking for unused variable removal (L245-L248).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L473).
  • src/transform/storage_rewrite.cc
    • New file: Added a pass for memory access pattern analysis and optimization, including data type legalization and buffer merging (L1-L1859).
  • src/transform/thread_partial_sync.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L7).
    • Replaced NullOpt with std::nullopt (L275).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L377).
  • src/transform/thread_storage_sync.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L23).
    • Replaced NullOpt with std::nullopt (L284).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L554).
  • src/transform/vectorize_loop.cc
    • Changed FFI header include from tvm/runtime/registry.h to tvm/ffi/function.h (L25).
    • Replaced NullOpt with std::nullopt (L634).
    • Removed ProducerStoreNode visit method (L691-L694).
    • Updated TargetHasSVE usage (L799).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L840).
  • src/transform/warp_specialized_rewriter.cc
    • Updated Map value types from ObjectRef to Any (L445).
    • Updated annotation handling (L634, L646-L655, L684, L705).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L1216).
  • src/transform/wgmma_sync_rewriter.cc
    • Updated annotation handling (L134).
    • Updated FFI registration macro from TVM_REGISTER_GLOBAL to TVM_FFI_REGISTER_GLOBAL (L284).
  • testing/python/autotune/test_tilelang_autotune.py
    • Minor formatting adjustment (L37).
  • testing/python/cpu/test_tilelang_cpu_gemm.py
    • Modified if __name__ == "__main__": block for specific test execution (L120-L121).
  • testing/python/kernel/test_tilelang_kernel_bf16_gemm_mma.py
    • Modified if __name__ == "__main__": block for specific test execution (L226-L227).
  • testing/python/kernel/test_tilelang_kernel_dequantize_gemm.py
    • Added print(program.script()) for debugging (L102).
    • Modified if __name__ == "__main__": block for specific test execution (L643-L644).
  • testing/python/kernel/test_tilelang_kernel_int4_gemm_mma.py
    • Added tilelang.disable_cache() (L20).
    • Commented out @simplify_prim_func decorator (L22).
    • Added pass_configs to tilelang.compile call and print(kernel.get_kernel_source()) (L170-L173).
    • Modified if __name__ == "__main__": block for specific test execution (L409-L410).
  • testing/python/language/test_tilelang_language_annotate_pad.py
    • Removed print(kernel.get_kernel_source()) (L45).
  • testing/python/language/test_tilelang_language_copy.py
    • Added import tilelang.testing (L7).
  • testing/python/transform/test_tilelang_transform_cluster_planning.py
    • Cast clusterIdx.y annotation value to T.int32 (L48).
  • testing/python/transform/test_tilelang_transform_make_packed_api.py
    • Rewrote test cases to align with new TVM API for PrimFunc and IRModule (L28-L225).
    • Updated argument count assertions for MakePackedAPI (L65).
    • Removed requires_llvm decorators (L51, L163, L181, L194, L211, L236, L260, L290, L334, L353, L393, L414, L438, L461).
  • testing/python/transform/test_tilelang_transform_multi_version_buffer.py
    • Cast num_stages annotation value to T.int32 (L51, L84).
  • testing/python/transform/test_tilelang_transform_pipeline_planning.py
    • Cast software_pipeline_async_stages, software_pipeline_order, and software_pipeline_stage annotation values to T.int32 (L56-L58).
  • testing/python/transform/test_tilelang_transform_thread_sync.py
    • Removed tvm.tir.transform.StorageFlatten call (L11).
    • Added new test case test_sync_shared_dyn (L104-L142).
    • Removed requires_cuda decorator from test_sync_let_stmt (L145).
  • testing/python/transform/test_tilelang_transform_vectorize_loop.py
    • Removed tilelang import (L4).
    • Removed several test cases related to scalable vector errors and dtype mismatch (L58-L128, L329-L341).
    • Updated test_vectorize_with_if to use T.match_buffer (L128-L129, L141).
    • Added new test cases for buffer predication (L491-L583).
    • Removed requires_llvm decorators (L16, L58, L73, L86, L103, L120, L155, L177, L196, L215, L236, L260, L283, L334, L353, L393, L414, L438, L461).
  • testing/python/transform/test_tilelang_transform_warp_specialized.py
    • Cast num_stages annotation value to T.int32 (L49).
    • Modified if __name__ == "__main__": block for specific test execution (L123).
  • tilelang/init.py
    • Updated TVM FFI base import from tvm._ffi.base to tvm.base (L63, L75).
  • tilelang/_ffi_api.py
    • Updated TVM FFI import from tvm._ffi to tvm.ffi (L5).
  • tilelang/carver/analysis.py
    • Updated TVM FFI import (L8).
  • tilelang/carver/matmul_analysis.py
    • Removed skip_simplify=True from sch.reindex calls (L513, L515, L517).
  • tilelang/contrib/cc.py
    • Updated TVM FFI base import (L27).
  • tilelang/contrib/hipcc.py
    • Updated TVM FFI imports (L14, L17).
  • tilelang/contrib/nvcc.py
    • Updated TVM FFI imports (L13, L16, L186, L193, L258, L396, L409, L426).
  • tilelang/contrib/rocm.py
    • Updated TVM FFI imports (L24, L25, L103, L127, L229).
  • tilelang/engine/lower.py
    • Changed default target_host from stackvm to c (L135).
    • Updated TVM FFI global function retrieval (L150, L152, L164, L166, L178, L181, L184, L186, L188).
  • tilelang/engine/phase.py
    • Corrected module imports for StorageRewrite and LowerThreadAllreduce (L123, L144).
    • Adjusted allow_warp_specialized condition (L19).
  • tilelang/jit/adapter/ctypes/adapter.py
    • Updated TensorType import from tvm.relay to tvm.relax (L11).
  • tilelang/jit/adapter/cython/adapter.py
    • Updated TensorType import from tvm.relay to tvm.relax (L12).
  • tilelang/jit/adapter/wrapper.py
    • Added new FP8 data types (float8_e4m3fn, float8_e5m2) to CUDA type mapping (L186-L187).
  • tilelang/language/ast/_ffi_api.py
    • Updated TVM FFI import (L20).
  • tilelang/language/frame.py
    • Updated TVM FFI import (L5).
  • tilelang/language/kernel.py
    • Updated TVM FFI import (L10).
  • tilelang/language/logical.py
    • Updated BufferRegion import to include BufferLoad (L6).
    • Adjusted any_of and all_of functions to use BufferLoad and r.min for region access (L32-L41, L63-L74).
  • tilelang/language/memscope.py
    • Updated TVM FFI registry import (L4).
  • tilelang/language/parser/operation.py
    • Removed tvm._ffi.runtime_ctypes import (L24).
    • Minor formatting in _auto_broadcast function (L89, L92).
  • tilelang/language/tir/entry.py
    • Changed default check_well_formed parameter to False in prim_func decorator (L14).
  • tilelang/language/tir/op.py
    • Minor formatting in pow_of_int docstring (L2606).
  • tilelang/language/warpgroup.py
    • Updated TVM FFI import (L6).
  • tilelang/layout/fragment.py
    • Updated TVM FFI import (L14).
    • Added initialization for forward_index in Fragment constructor (L99-L100).
  • tilelang/layout/layout.py
    • Updated TVM FFI import (L14).
  • tilelang/transform/init.py
    • Added new transform passes (LowerStmtGroupBlock, LowerThreadAllreduce, StorageRewrite) (L62-L70, L386-L394, L397-L405).
Activity
  • The pull request was opened by @Hzfengsy.
  • A bot (github-actions[bot]) commented to remind the author to run bash format.sh (2025-06-24 09:23:36 UTC).
  • A bot (gemini-code-assist[bot]) attempted to create a summary but encountered an error (2025-06-24 09:28:55 UTC).
  • @LeiWang1999 requested a summary again from the bot (2025-06-24 09:45:14 UTC).

@Hzfengsy Hzfengsy force-pushed the refactor branch 4 times, most recently from 764d8be to e3f31b6 Compare July 6, 2025 05:37
@LeiWang1999 LeiWang1999 force-pushed the main branch 2 times, most recently from 7aea41e to b060c9f Compare July 20, 2025 05:30
LeiWang1999 and others added 16 commits July 24, 2025 15:29
…to False for improved flexibility in TIR function parsing.
Introduced the StorageRewrite function in the tilelang.transform module, which returns a TVM transform pass. This addition enhances the functionality of the module by providing a new transformation option for users.
- Updated instances of `NullOpt` to `std::nullopt` in `ir.cc` and `parallel.cc` for consistency with modern C++ practices.
- Enhanced layout inference logic in `layout_inference.cc` to improve type safety by replacing `as<Fragment>().get()` with `as<FragmentNode>()`.
- Adjusted error handling in `multi_version_buffer_rewriter.cc` and `persist_threadblock.cc` to use more concise null checks.
- Cleaned up test files by commenting out `tilelang.testing.main()` and replacing it with specific test function calls for better clarity.
- Removed unused test file `test_tilelang_kernel_deepseek_nsa.py` to streamline the testing suite.
…n handling

- Updated the TVM subproject to a dirty commit state.
- Refactored copyright headers in `cluster_planning.cc` to reflect the new licensing.
- Enhanced error handling in `lower_tile_op.cc` to check for missing padding map annotations.
- Modified test files to improve clarity and functionality, including adjustments to kernel compilation and test assertions.
- Updated various test cases to ensure proper handling of annotations and configurations in the TileLang testing framework.
- Changed the annotation type in the `test_warp_specialized` function from a literal integer to `T.int32(3)` for improved type safety and consistency with the TileLang framework.
- Replaced the direct call to `test_warp_specialized()` with `tilelang.testing.main()` in the test file to standardize test execution and improve integration with the TileLang testing framework.
…nce (tile-ai#594)

- Introduced a `strict_layout_map` to enhance layout inference by ensuring that buffers with strict layout requirements are properly accounted for during the inference process.
- Updated the inference logic to check for the presence of buffers in the `strict_layout_map` before applying layout changes, improving the accuracy of layout assignments.
- Refactored the layout inference steps to include the copying of layouts into the new strict map, ensuring a clear separation of layout handling based on inference levels.
* [Example] Update kernel compilation in examples to use @tilelang.jit

- Refactored multiple examples to eliminate the use of `tilelang.compile` for kernel creation, directly invoking the functions instead.
- Added `@tilelang.jit` decorators with appropriate output indices to enhance performance and maintainability.
- Improved code clarity by simplifying the kernel invocation process across various examples, ensuring consistency in how kernels are defined and executed.

* format

* Update example_tilelang_sparse_gqa_decode_varlen_indice.py

* Update example_dequant_gemm_fine_grained.py

* Update example_gemm_autotune.py

---------

Co-authored-by: Lei Wang <[email protected]>
…shared range checks (tile-ai#599)

* [Enhancement] Improve error messaging for global and shared range legality checks in LowerBulkCopy

- Updated error messages in the LowerBulkCopy function to provide clearer context when global and shared ranges are illegal.
- Enhanced the readability of the error output by including tensor names, improving debugging and validation processes during bulk copy operations.

* [Enhancement] Refine error messaging in LowerBulkCopy for global and shared range checks

- Improved the clarity of error messages in the LowerBulkCopy function by enhancing the output format.
- Included additional context in error messages to aid debugging when global and shared ranges are found to be illegal, ensuring better traceability during bulk copy operations.
…Y_MERGE` to enable aggressive shared memory reuse (tile-ai#602)

* [Enhancement] Add aggressive shared memory merge option in memory allocation

- Introduced a new configuration option `tl.enable_aggressive_shared_memory_merge` to enable aggressive merging of shared memory allocations.
- Updated the `SharedMemLinearAccessPatternFinder` class to support an aggressive merge strategy, allowing for improved memory reuse.
- Modified the `MergeSharedMemoryAllocations` function to incorporate the new merging strategy based on the configuration.
- Enhanced the `PassConfigKey` enumeration to include the new aggressive merge option, ensuring it can be configured appropriately.

* lint fix

* [Enhancement] Add aggressive shared memory merge configuration option

- Introduced a new configuration option `kEnableAggressiveSharedMemoryMerge` to enable aggressive merging of shared memory allocations, enhancing memory management capabilities.

* [Enhancement] Update MergeSharedMemoryAllocations to support aggressive merge option

- Modified the `MergeSharedMemoryAllocations` function to accept an `enable_aggressive_merge` parameter, allowing for more flexible memory management.
- Introduced a new helper function `should_enable_aggressive_merge` to determine the aggressive merge configuration based on the pass context and target.
- Updated the relevant calls in the `phase.py` and `__init__.py` files to utilize the new aggressive merge functionality, enhancing the overall memory allocation strategy.
- Replaced the use of `tiled_mma.accumulate_ = GMMA::ScaleOut::Zero` with a call to `clear(acc)` for better clarity and maintainability in the accumulation logic.
- This change enhances the readability of the code by standardizing the approach to clearing accumulation values across multiple sections of the file.
@LeiWang1999
Copy link
Member

LeiWang1999 commented Jul 28, 2025

Likely the CI passed, but we still need to finish these items:

  • Check for performance regression
  • Clean up the codebase

@LeiWang1999
Copy link
Member

image GEMM verified successfully, and no change in the generated CUDA

@LeiWang1999
Copy link
Member

Summarize part of the rebase pr:

  1. Support T.thread_return() → CUDA return syntax
    Added support for translating T.thread_return() to CUDA's native return statement.

  2. Dynamic type support for function inputs
    Functions now accept dynamically typed parameters using typing:

    dyn_type = T.int32 or T.float
    @T.prim_func
    def main(
        a: dyn_type,
    )
  3. Device Function Codegen
    Added support for generating __device__ functions in CUDA:

    @I.ir_module
    class Module:
        @T.prim_func(private=True)
        def add(a: T.int32, b: T.int32) -> T.int32:
            return a + b
    
        @T.prim_func
        def main(
            A: T.Buffer((128, 128), "int32"),
            B: T.Buffer((128, 128), "int32"),
            C: T.Buffer((128, 128), "int32"),
        ):
            T.func_attr({"global_symbol": "main"})
            length: T.int32 = Module.add(64, 64)  # Host call
            for bx in T.thread_binding(length, "blockIdx.x"):
                for tx in T.thread_binding(length, "threadIdx.x"):
                    C[bx, tx] = Module.add(A[bx, tx], B[bx, tx])  # Device call

    After compilation, add becomes a CUDA __device__ function.

  4. Cython-based Python/C++ interop
    Replaced ctypes with Cython for all Python/C++ interactions:

    • Python → C++ calls
    • C++ → Cython calls
      This improves performance by around 100x and reduces CPU overhead during compile/runtime.
  5. FP8 data type standardization
    Migrated e5m2_float8 and similar types to Torch-standardized variantsfloat8_e5m2 and etc.

@LeiWang1999 LeiWang1999 merged commit a7c9a8b into tile-ai:main Jul 30, 2025
2 of 3 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.