
Conversation

@Alex4210987 (Collaborator) commented Oct 10, 2025

Summary by CodeRabbit

  • New Features

    • Expanded AMD FlashAttention with autotuning, deterministic reference outputs, and enhanced backward support.
    • Broadened FP8 support (additional formats and vectors) for HIP, improving compatibility on AMD GPUs.
  • Performance

    • Added multi-stage execution options and improved kernel wiring for AMD FlashAttention.
    • Adjusted attention scaling and exponent usage for potential speed and numerical stability gains.
    • Minor HIP codegen improvements and loop control support.
  • Tests

    • Removed AMD forward example invocations from the test script.
  • Refactor

    • Simplified layout hints in attention examples to streamline execution without changing results.

xinxyxiao and others added 30 commits July 29, 2025 03:26
…nd clarity (tile-ai#668)

- Enhanced buffer index handling to address precision issues by removing redundant operations.
- Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection.
- Updated related documentation to reflect changes in buffer management practices.
…ed flexibility

- Introduced a new input.txt file for configurable parameters.
- Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack.
- Streamlined the main function for better clarity and organization.
- Added a new test script to facilitate running the example with specified parameters.
… example with swizzle layout annotations

- Deleted input.txt and test.sh files as they are no longer needed.
- Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance.
- Reintroduced swizzle usage in the kernel for better performance.
- Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`.
- Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls.
- Removed outdated comments and improved code organization for better readability.
- Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function.
- Streamlined the `main` function parameter formatting for consistency.
- Removed unnecessary blank lines to enhance overall code organization.
- Improved the `example_amd_flash_attn_fwd.py` script for better clarity and organization.
- Added new CI workflows for AMD and documentation publishing.
- Updated various requirements files to include necessary dependencies.
- Introduced new test cases and examples for better coverage and functionality.
- Refactored existing code for improved readability and maintainability.
- Introduced `example_amd_flash_attn_bwd.py` for backward attention computation using TileLang.
- Added `test.sh` script to facilitate running the new example with specified parameters.
- Enhanced the overall structure and organization of the example for better clarity and usability.
- Reduced the number of threads and `num_split_q` options for improved performance.
- Adjusted `panel_size` options to streamline configuration settings.
- Introduced a new example script `example_amd_flash_attn_bwd.py` demonstrating the forward and backward operations of Flash Attention using TileLang.
- Implemented JIT-compiled functions for both forward and backward passes, including preprocessing and postprocessing steps.
- Added a main function to facilitate testing and benchmarking of the attention mechanism with configurable parameters.
- Included reference implementation for validation against PyTorch's attention mechanism.

This addition enhances the examples directory by providing a comprehensive guide for users to understand and utilize Flash Attention in their applications.
- Updated `example_amd_flash_attn_bwd.py` to include more comprehensive testing features for the Flash Attention implementation.
- Improved the main function to allow for better parameter configuration and benchmarking.
- Added validation checks against PyTorch's attention mechanism to ensure accuracy and reliability of the example.

This update aims to provide users with a more robust tool for understanding and utilizing Flash Attention in their applications.
- Updated file name from `intrin_rule_hip.cc` to `intrin_rule_cuda.cc` to reflect the change in focus from HIP to CUDA intrinsic rules.
- Adjusted include paths for better organization and clarity in the code structure.
…installation

- Removed the installation of `flash_attn==2.5.8` to streamline the CI process.
- Added a step to uninstall `torch`, `torchvision`, and `torchaudio` prior to installing pre-release versions, ensuring compatibility and reducing potential conflicts.
…rd example

- Eliminated the allocation of shared memory for `dv_shared` and `dk_shared` in `example_amd_flash_attn_bwd.py` to streamline memory usage and improve performance.
- This change focuses on optimizing the backward pass implementation by reducing unnecessary memory overhead.
- Eliminated the step to uninstall `torch`, `torchvision`, and `torchaudio` in the AMD CI workflow, as it is no longer required for the installation of pre-release versions.
- This change simplifies the CI process and reduces potential overhead during package management.
- Updated the return statement to use std::string for concatenation in the case of 16-bit types, improving code clarity.
- Added a null check for the CallNode pointer in DispatchHIPWarpActiveMask to enhance robustness and prevent potential dereferencing issues.
- Adjusted the formatting of TVM_REGISTER_OP calls for better readability by aligning method chaining.
- No functional changes were made; this update focuses on code style improvements to enhance maintainability.
- Renamed the file from `intrin_rule_cuda.cc` to `intrin_rule_hip.cc` to accurately reflect the focus on HIP intrinsic rules.
- Updated the file documentation to clarify its purpose as related to HIP rather than CUDA.
xinyxiao and others added 21 commits September 10, 2025 09:30
…ttn_bwd.py

- Updated scaling factor application for improved numerical stability in gradient calculations.
- Refined tensor handling to ensure consistency with forward pass operations.
- Optimized atomic operations for writing gradients to dK and dV using fp32 for better precision.
- Adjusted comments for clarity and alignment with standard implementation practices.
…update test.sh

- Increased the range of block sizes and stages for forward and backward configurations to enhance performance tuning.
- Adjusted the test script to include additional parameters for batch size and head dimensions, ensuring consistency with the forward example.
- Improved comments for clarity and alignment with the updated configurations.
…h_attn_bwd.py

- Updated FLOPs calculation to account for both forward and backward passes, clarifying the total computational cost.
- Modified benchmarking functions to evaluate the complete forward and backward performance of both reference and Tile-lang implementations.
- Improved comments for better understanding of the performance metrics and implementation details.
- Removed unnecessary parameter from test.sh to streamline execution.
…rd attention execution for streamlined testing.
…le_amd_flash_attn_bwd.py and example_amd_flash_attn_fwd.py

- Updated the forward function to return both output and log-sum-exp (LSE) values for improved gradient calculations.
- Enhanced autotuner configurations for forward pass, including new parameters for better performance tuning.
- Refined scaling factor calculations for numerical stability in both forward and backward passes.
- Improved comments and documentation for clarity and consistency across implementations.
- Adjusted main function to reflect changes in parameter handling and ensure compatibility with new output requirements.
- Removed outdated comments and improved clarity in the code.
- Enhanced the forward function to consistently return output and log-sum-exp (LSE) values.
- Updated autotuner configurations to include new parameters for better performance tuning.
- Refined tensor handling and scaling factor calculations for improved numerical stability.
- Adjusted the main function to ensure compatibility with updated output requirements and parameter handling.
…ttn_bwd.py

- Updated configuration parameters for backward calculations, including new options for block sizes, threads, and rasterization.
- Added new parameters (k_pack, qk_coalesced_width, v_coalesced_width) to improve performance tuning and memory access patterns.
- Modified tensor copy operations to utilize coalesced widths for optimized memory loads.
- Enhanced GEMM operations with k_pack for improved computational efficiency.
- Refined the configuration generation logic to accommodate the new parameters, ensuring comprehensive coverage for backward pass scenarios.
…n_bwd.py

- Updated backward configuration parameters to include larger block sizes and a wider range of threads for enhanced performance tuning.
- Removed unnecessary parameters (k_pack, qk_coalesced_width, v_coalesced_width) from function signatures and tensor operations to simplify the implementation.
- Optimized tensor copy operations by eliminating coalesced width specifications, streamlining memory access patterns.
- Adjusted GEMM operations to improve computational efficiency without the use of k_pack.
- Added support for additional FP8 types (e4m3, e4m3b11fnuz, e5m2fnuz, e8m0) in codegen_hip.cc to improve compatibility.
- Updated error logging to include unsupported FP8 type details for better debugging.
- Implemented handling for loop break and no-op register management in HIP within VisitExpr_ method.
- Introduced new FP8 vector types (e5 and e8) in hip_fp8.h for enhanced functionality.
- Added overloads for AtomicAdd in common.h to support both pointer and value arguments.
- Expanded FP8 type support in codegen_hip.cc to include additional float8 formats.
- Updated gemm.h to clarify the handling of the accumulator when clear_accum is true.
- Added comments in hip_fp8.h to indicate that E8M0 types are not supported in the current HIP version.
…ple_amd_flash_attn_bwd.py for cleaner output.
…xample_amd_flash_attn_fwd.py by adding spaces for improved readability in configuration parameters and print statements.
- Reorganized and cleaned up code in codegen_hip.cc for better readability and maintainability.
- Enhanced handling of FP8 types, including additional formats and improved error logging for unsupported types.
- Updated AtomicAdd function in common.h to streamline its implementation.
- Refined the PrintVecElemLoadExpr method to handle volatile loads more effectively.
- Added function to manage the addition of new functions in the code generation process.
- Adjusted the indentation of the MFMA call code block in codegen_hip.cc for improved readability and consistency.
- Reintroduced necessary includes and reorganized code in codegen_hip.cc for improved structure and readability.
- Enhanced the GetFP8Type function to support additional FP8 formats and improved error handling for unsupported types.
- Updated PrintType and PrintVecElemLoadExpr methods to better manage type conversions and vector element loading.
- Refined the AddFunction method to streamline function addition in the code generation process.

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run `bash format.sh` in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

…bwd.py

- Updated the GEMM operation to use shared memory for improved performance.
- Adjusted parallelization parameters to enhance efficiency in the backward pass.
coderabbitai bot (Contributor) commented Oct 10, 2025

Walkthrough

The change set expands AMD FlashAttention examples with new forward/backward kernels, reference, autotune configs, diagnostics, and CLI updates; adjusts the forward example’s scaling/exponent; removes two test invocations; simplifies layout annotations in another example; and extends HIP codegen/templates with FP8 types, atomic helpers, loop/no-op handling, and GEMM constraints.

Changes

  • AMD FlashAttention examples (examples/amd/example_amd_flash_attn_bwd.py, examples/amd/example_amd_flash_attn_fwd.py, examples/amd/test.sh): adds ref implementation, forward/backward kernels, autotune configs, diagnostics, benchmarking, and updated CLI; modifies forward scaling (exp2→exp) and autotune stages; removes two forward test runs from the script.
  • FlashAttention example (generic) (examples/flash_attention/example_mha_bwd.py): removes explicit shared-memory layout annotations for select tensors; computation flow unchanged.
  • HIP codegen (src/target/codegen_hip.cc): adds FP8 variants handling, emits loop_break and no-op no_set_max_nreg, adjusts MFMA emission formatting, extends FP8 constant printing.
  • HIP templates (src/tl_templates/hip/common.h, src/tl_templates/hip/gemm.h, src/tl_templates/hip/hip_fp8.h): adds AtomicAdd overloads and AtomicAddRet; relaxes the GemmTensorOp clear_accum assert (commented out); introduces FP8 E5M2 aliases and vector structs (x4/x8/x16) with constructors/conversions.

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant CLI as CLI/Caller
    participant AT as Autotuner
    participant JIT as JIT Compiler
    participant K as GPU Kernel (FWD)
    participant GPU as GPU

    CLI->>AT: fast_flashattn(args, configs)
    AT->>JIT: Select+compile kernel variant
    JIT-->>AT: Compiled kernel handle
    AT->>K: Launch with Q,K,V and params
    K->>GPU: Load tiles, compute QK, mask (causal), softmax, apply V
    GPU-->>K: Tile results
    K-->>CLI: Output + (optionally) lse/logits stats
sequenceDiagram
    autonumber
    participant CLI as CLI/Caller
    participant PRE as bwd_preprocess
    participant AT as Autotuner (BWD)
    participant JIT as JIT Compiler
    participant KB as GPU Kernel (BWD)
    participant POST as bwd_postprocess

    CLI->>PRE: flashattn_bwd_preprocess(Q,K,V,dO,...)
    PRE-->>CLI: Preprocessed aux (e.g., delta/LSE)
    CLI->>AT: flashattn_bwd(args, configs)
    AT->>JIT: Compile selected bwd kernel
    JIT-->>AT: Kernel handle
    AT->>KB: Launch to compute dQ, dK, dV
    KB-->>CLI: dQ, dK, dV
    CLI->>POST: flashattn_bwd_postprocess(dQ,...)
    POST-->>CLI: Final gradients

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • chengyupku

Poem

I twitch my ears at tiles that race,
New FP8 seeds in HIP’s embrace.
I hop through kernels, forward, back—
Softmax meadows, gradients track.
Autotune winds whisper “go!”
Logs and lattices softly glow.
Carrots cached; benchmarks show. 🥕🚀

Pre-merge checks and finishing touches

❌ Failed checks (1 warning, 1 inconclusive)
  • Docstring Coverage (⚠️ Warning): docstring coverage is 7.41%, below the required threshold of 80.00%. You can run @coderabbitai generate docstrings to improve coverage.
  • Title Check (❓ Inconclusive): the title “fix bug&add amd examples” is too generic; it refers to an unspecified bug fix and vaguely mentions AMD examples without describing the substantial FlashAttention kernel additions and HIP codegen updates. Consider a more descriptive title, for example: “Add AMD FlashAttention forward/backward kernels and fix FP8 handling in HIP codegen.”
✅ Passed checks (1 passed)
  • Description Check (✅ Passed): check skipped because CodeRabbit’s high-level summary is enabled.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

…lash_attn_bwd.py for improved code cleanliness.
coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🧹 Nitpick comments (8)
examples/amd/example_amd_flash_attn_fwd.py (2)

37-37: Autotune space widened: expect longer tuning; consider adding num_stages=2 for AMD LDS pipelining

Good to include 0 and 1. On some workloads, stages=2 helps hide memory latency. Consider optionally adding 2, or gating it by dim/seq_len to avoid blowing up the search space.


175-186: Stable softmax path switch to exp is correct; minor micro-optimization optional

The move from exp2 to exp with a scaled running max is numerically sound. You can precompute inv_sqrt_d = scale once and reuse it; you can also fold the subtraction and scaling into a single expression inside the exp for readability:

  • sf = T.exp((m_prev[i] - m_i[i]) * scale)
  • acc_s[i, j] = T.exp((acc_s[i, j] - m_i[i]) * scale)

Current code is correct as-is.
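
For reference, a minimal NumPy sketch of one online-softmax update step in the natural-exp form suggested above (m_prev, m_i, l_prev, acc_s, acc_o mirror the example's buffers; this is an illustration of the math, not the kernel code):

import numpy as np

def online_softmax_step(acc_s, m_prev, l_prev, acc_o, v_tile, scale):
    # acc_s: raw QK^T scores for the current tile, shape [block_M, block_N]
    m_i = np.maximum(m_prev, acc_s.max(axis=1))        # updated running max
    sf = np.exp((m_prev - m_i) * scale)                # rescale factor for previous state
    p = np.exp((acc_s - m_i[:, None]) * scale)         # tile probabilities
    l_i = l_prev * sf + p.sum(axis=1)                  # running softmax denominator
    acc_o = acc_o * sf[:, None] + p @ v_tile           # running (unnormalized) output
    return m_i, l_i, acc_o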

examples/amd/example_amd_flash_attn_bwd.py (5)

34-66: Config space parity with fwd example

Good coverage. To keep parity with the forward example file and avoid regressions, consider constraining threads/panel combos via simple heuristics (e.g., limit threads=1024 to block sizes ≥128) to reduce tuning time without losing best configs.
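
One possible pruning pass over the generated configs (plain-Python sketch; the 512-thread / 128-block thresholds are illustrative, not tuned values from this PR):

def prune_configs(configs):
    pruned = []
    for cfg in configs:
        # skip very wide thread counts on small tiles: they rarely win and inflate tuning time
        if cfg["threads"] >= 512 and min(cfg["block_M"], cfg["block_N"]) < 128:
            continue
        pruned.append(cfg)
    return pruned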


249-273: Preprocess kernel is correct; consider explicit threads kwarg

The Delta = sum(O*dO, axis=dim) implementation is fine. For consistency and predictable occupancy, pass threads=128 (or tuned) to T.Kernel here.
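
For reference, the quantity this kernel materializes can be written in a couple of lines of PyTorch (assuming the [batch, seq_len, heads, dim] layout used in the example; illustrative only):

import torch

def bwd_preprocess_ref(O: torch.Tensor, dO: torch.Tensor) -> torch.Tensor:
    # O, dO: [batch, seq_len, heads, dim]  ->  Delta: [batch, heads, seq_len]
    return (O.float() * dO.float()).sum(dim=-1).transpose(1, 2)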


276-296: Backward kernel math path LGTM; minor clarity tweak

  • P_acc = exp(qkT*sm_scale - lse) and causal masking are standard.
  • dP, dv, dk, dq accumulation and atomic adds look consistent.

Optional: rename qkT -> S_scaled for readability.

Also applies to: 321-365
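
As a cross-check, the math listed above matches this dense per-head PyTorch reference (q, k, v, do are [seq_len, dim] slices; lse and delta are [seq_len] and assumed precomputed by the forward/preprocess kernels; a sketch, not the tiled kernel):

import torch

def attn_bwd_ref(q, k, v, do, lse, delta, sm_scale, causal=True):
    S = q @ k.T                                          # raw scores [seq_q, seq_k]
    if causal:
        iq = torch.arange(S.shape[0], device=S.device)[:, None]
        ik = torch.arange(S.shape[1], device=S.device)[None, :]
        S = S.masked_fill(iq < ik, float("-inf"))
    P = torch.exp(S * sm_scale - lse[:, None])           # probabilities recovered from LSE
    dV = P.T @ do
    dP = do @ v.T
    dS = P * (dP - delta[:, None]) * sm_scale            # delta = rowsum(O * dO)
    dQ = dS @ k
    dK = dS.T @ q
    return dQ, dK, dV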


386-445: Debug/benchmark helpers are useful; minor CUDA guard

You already guard synchronize; optionally also guard torch.cuda.manual_seed(42) and device selection in main by checking CUDA/HIP availability to make CPU fallback easier when running on non-GPU CI.
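
A minimal form of that guard (illustrative; ROCm builds of PyTorch expose the device through the cuda namespace, so the same check covers HIP):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)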


446-486: End-to-end flow is solid; add LSE check and deterministic seeds

Great validation flow. Two small improvements:

  • Validate LSE from fwd_kernel against ref_program to catch rare stability issues.
  • Set torch.backends.cuda.matmul.allow_tf32 and cudnn flags as needed for determinism during checks (keep them default for perf bench).

I can add these guards and checks if you want.

Also applies to: 488-586
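
A sketch of both suggestions, assuming the forward kernel and ref_program each return (output, lse) as in this example (names and tolerances are illustrative):

import torch

# force deterministic matmuls for the correctness check only; restore defaults before benchmarking
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

out_tl, lse_tl = fwd_kernel(q, k, v)
out_ref, lse_ref = ref_program(q, k, v)
torch.testing.assert_close(lse_tl.float(), lse_ref.float(), rtol=5e-2, atol=5e-2)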

src/tl_templates/hip/hip_fp8.h (1)

54-64: Inconsistent API: fp8_e5_4_t is missing constructors and operators present in fp8_e4_4_t.

The fp8_e5_4_t struct only has:

  • Default constructor
  • Constructor from __hip_fp8x4_e5m2_fnuz
  • Conversion operator

However, fp8_e4_4_t (lines 24-40) additionally provides:

  • Constructor from float4
  • Assignment operator from native type

This API inconsistency could lead to confusion and limit usability of E5 types.

Consider adding the missing members to fp8_e5_4_t to match fp8_e4_4_t:

 struct fp8_e5_4_t {
   union {
     __hip_fp8x4_e5m2_fnuz data;
     struct {
       fp8_e5_t x, y, z, w;
     };
   };
   __device__ fp8_e5_4_t() = default;
   __device__ fp8_e5_4_t(const __hip_fp8x4_e5m2_fnuz &val) : data(val) {}
+  __device__ fp8_e5_4_t(const float4 &val) : data(val) {}
   __device__ operator __hip_fp8x4_e5m2_fnuz() const { return data; }
+  __device__ fp8_e5_4_t &operator=(const __hip_fp8x4_e5m2_fnuz &val) {
+    data = val;
+    return *this;
+  }
 };
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7cd0da9 and 8b345ae.

📒 Files selected for processing (8)
  • examples/amd/example_amd_flash_attn_bwd.py (2 hunks)
  • examples/amd/example_amd_flash_attn_fwd.py (3 hunks)
  • examples/amd/test.sh (0 hunks)
  • examples/flash_attention/example_mha_bwd.py (0 hunks)
  • src/target/codegen_hip.cc (4 hunks)
  • src/tl_templates/hip/common.h (1 hunks)
  • src/tl_templates/hip/gemm.h (1 hunks)
  • src/tl_templates/hip/hip_fp8.h (2 hunks)
💤 Files with no reviewable changes (2)
  • examples/amd/test.sh
  • examples/flash_attention/example_mha_bwd.py
🔇 Additional comments (11)
examples/amd/example_amd_flash_attn_fwd.py (1)

86-86: Scale change LGTM

Using natural exp with scale=1/sqrt(d) aligns with the reference softmax.

examples/amd/example_amd_flash_attn_bwd.py (1)

68-88: Forward kernel (in bwd file) numerics and LSE emission look correct

  • Scaling applied before max and exp; mask preserves -inf; LSE written as log(l_i)+m_i with bounds check. This matches the reference.

If desired, add an assert_allclose against ref_program for LSE too to tighten validation (currently only outputs checked via profiler in main).

Also applies to: 98-107, 176-201, 212-222

src/tl_templates/hip/common.h (1)

119-121: LGTM!

The AtomicAddRet implementation correctly performs an atomic add on a reference and returns the result. This is a useful addition for cases where the return value is needed.

src/target/codegen_hip.cc (4)

44-54: LGTM: Extended FP8 type support.

The additions correctly map new FP8 variants (kFloat8_e4m3, kFloat8_e4m3b11fnuz, kFloat8_e5m2fnuz, kFloat8_e8m0fnu) to their respective HIP type representations. The consistent pattern of preserving vector-width suffixes ensures proper type handling across all FP8 variants.


55-55: LGTM: Improved error diagnostics.

Including the actual type in the error message will help developers quickly identify unsupported FP8 types during debugging.


966-972: LGTM: HIP-specific control flow additions.

The loop_break emits a standard break statement, and no_set_max_nreg correctly becomes a no-op for HIP since register management is handled differently than in CUDA. Both implementations are appropriate for the HIP backend.


1178-1179: LGTM: Consistent FP8 constant handling.

The extension of FP8 constant detection to include is_float8_e4m3 and is_float8_e4m3fn aligns with the FP8 type handling additions earlier in the file, ensuring all FP8 variants are properly recognized and emitted as fp8_e4_t constants.

src/tl_templates/hip/hip_fp8.h (3)

8-13: LGTM: Added FP8 E5 type aliases.

The additions correctly provide type aliases for E5M2 FP8 types, maintaining consistency with the existing E4M3 aliases. The commented note about E8M0 not being supported in the current HIP version is helpful for future reference.


24-40: LGTM: Enhanced fp8_e4_4_t wrapper API.

The additions provide a complete set of constructors and operators:

  • Default constructor for uninitialized instances
  • Copy constructor from native type for interoperability
  • Constructor from float4 for conversions
  • Conversion operator for seamless usage with native types
  • Assignment operator for value updates

These enhancements improve usability and type safety for FP8 operations.


66-74: LGTM: E5 vector type scaffolding.

The fp8_e5_8_t and fp8_e5_16_t structs follow the same composition pattern as their E4 counterparts, providing consistent vector type support across FP8 variants.

src/tl_templates/hip/gemm.h (1)

73-75: Verify that accumulator clearing is implemented when clear_accum=true.

The comment states "we'll handle it by manually clearing the accumulator," but reviewing the body() and body_rs() methods (lines 158-291), I don't see explicit accumulator initialization code that checks the clear_accum template parameter.

If clear_accum=true is now allowed but not implemented, this could lead to undefined behavior with uninitialized accumulators.

Run the following script to verify if there's accumulator clearing logic elsewhere or if this needs to be implemented:

Comment on lines +328 to +329
T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)


⚠️ Potential issue | 🔴 Critical

Fix policy reference: use imported GemmWarpPolicy instead of T.GemmWarpPolicy

You import GemmWarpPolicy, but call T.GemmWarpPolicy.*. Unless tilelang.language exposes GemmWarpPolicy, this will raise at JIT time. Align with the forward file usage.

Apply this diff:

-                T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=GemmWarpPolicy.FullRow)
@@
-                T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=GemmWarpPolicy.FullRow)
@@
-                T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(p_cast, do_shared, dv, policy=GemmWarpPolicy.FullRow)
@@
-                T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(p_cast, q_shared, dk, policy=GemmWarpPolicy.FullRow)

Also applies to: 343-347, 353-354

🤖 Prompt for AI Agents
In examples/amd/example_amd_flash_attn_bwd.py around lines 328-329 (and also
343-347, 353-354), the code references T.GemmWarpPolicy.* while GemmWarpPolicy
is imported directly; replace all uses of T.GemmWarpPolicy with the imported
GemmWarpPolicy (e.g., GemmWarpPolicy.FullRow) to match the forward file and
avoid JIT-time attribute errors, updating each affected invocation to use
GemmWarpPolicy.* consistently.

Comment on lines +113 to +117
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 address, T2 val) {
  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}

⚠️ Potential issue | 🔴 Critical

Critical bug: Atomic operation on local copy is meaningless.

This overload takes address by value (a copy), then takes the address of that local copy to perform the atomic operation. Since the atomic operation is performed on a temporary local variable, the result is discarded when the function returns, making this operation meaningless.

If the intent is to provide a convenience wrapper for value-based atomics, the signature should take a reference or pointer instead. As written, this code will compile but produce incorrect behavior.

Apply this diff to remove the broken overload:

-// Overload for when the first argument is a value instead of a pointer
-template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 address, T2 val) {
-  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
-}

If a value-based atomic is truly needed, clarify the intended semantics and implement accordingly.

📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 address, T2 val) {
  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}
🤖 Prompt for AI Agents
In src/tl_templates/hip/common.h around lines 113-117, the AtomicAdd overload
takes the address parameter by value and performs atomicAdd on the address of
that local copy, so the atomic operation targets a temporary and has no effect;
remove this broken overload or change its signature to accept a pointer or
reference to the target storage (e.g., T1* or T1&) and use that directly in
atomicAdd, or if a value-based atomic is intended, document and implement the
correct semantics (e.g., return the new value and operate on a shared location)
instead of operating on a local copy.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

🧹 Nitpick comments (4)
examples/amd/example_amd_flash_attn_fwd.py (1)

31-63: Autotune search space broadened; consider tuning bounds for ROCm

num_stages now includes [0, 1]. Good for overlap coverage. If you observe compile/runtime instability on some HIP stacks, you may want to guard stage=1 via device checks or leave a note in the example.
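
If gating turns out to be necessary, a hedged sketch of one way to do it (the MI300 name check is purely illustrative and would need adjusting per ROCm stack):

import torch

def allowed_num_stages():
    # torch.version.hip is a string on ROCm builds and None on CUDA builds
    if torch.version.hip is None:
        return [0, 1]
    name = torch.cuda.get_device_name(0)
    return [0, 1] if "MI300" in name else [0]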

examples/amd/example_amd_flash_attn_bwd.py (3)

328-347: Unify policy enum usage to avoid ambiguity

You import GemmWarpPolicy but also use T.GemmWarpPolicy in a few places. Prefer one. Using the imported GemmWarpPolicy everywhere avoids reliance on module re-exports.

-                T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=GemmWarpPolicy.FullRow)

-                T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=GemmWarpPolicy.FullRow)

-                T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow)
+                T.gemm(p_cast, q_shared, dk, policy=GemmWarpPolicy.FullRow)

Also applies to: 343-347, 353-358


256-271: Rename ambiguous parameter ‘O’ to satisfy linters and improve clarity

Variable name O triggers E741 (ambiguous name) and reduces readability.

-    def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype),
+    def flash_bwd_prep(Out: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype),
                        Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
@@
-                T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
+                T.copy(Out[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)

Optionally mirror this rename at the call site for consistency in naming (argument name doesn’t need to match, so no functional impact).


518-534: Remove or underscore unused ‘*_mean_diff’ variables

Avoid RUF059 by ignoring the unused means.

-    dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(
+    dq_close, dq_max_diff, _dq_mean = debug_tensor_comparison(
         dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
@@
-    dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(
+    dk_close, dk_max_diff, _dk_mean = debug_tensor_comparison(
         dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
@@
-    dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(
+    dv_close, dv_max_diff, _dv_mean = debug_tensor_comparison(
         dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7cd0da9 and c34315c.

📒 Files selected for processing (8)
  • examples/amd/example_amd_flash_attn_bwd.py (2 hunks)
  • examples/amd/example_amd_flash_attn_fwd.py (3 hunks)
  • examples/amd/test.sh (0 hunks)
  • examples/flash_attention/example_mha_bwd.py (0 hunks)
  • src/target/codegen_hip.cc (4 hunks)
  • src/tl_templates/hip/common.h (1 hunks)
  • src/tl_templates/hip/gemm.h (1 hunks)
  • src/tl_templates/hip/hip_fp8.h (2 hunks)
💤 Files with no reviewable changes (2)
  • examples/flash_attention/example_mha_bwd.py
  • examples/amd/test.sh
🧰 Additional context used
🧬 Code graph analysis (3)
src/tl_templates/hip/hip_fp8.h (1)
src/tl_templates/cuda/cuda_fp8.h (1)
  • fp8_e5_2_t (44-47)
src/target/codegen_hip.cc (2)
tilelang/language/builtin.py (2)
  • loop_break (419-422)
  • no_set_max_nreg (160-163)
tilelang/language/customize.py (1)
  • loop_break (67-73)
examples/amd/example_amd_flash_attn_bwd.py (2)
examples/amd/example_amd_flash_attn_fwd.py (4)
  • ref_program (11-28)
  • fast_flashattn (68-205)
  • main (97-203)
  • main (208-238)
examples/attention_sink/example_gqa_sink_bwd_bhsd.py (7)
  • ref_program (403-447)
  • main (450-519)
  • flashattn_bwd_preprocess (148-173)
  • flash_bwd_prep (154-171)
  • flashattn_bwd (213-321)
  • flashattn_bwd_postprocess (188-205)
  • flash_bwd_post (194-203)
🪛 Ruff (0.13.3)
examples/amd/example_amd_flash_attn_bwd.py

257-257: Ambiguous variable name: O

(E741)


518-518: Unpacked variable dq_mean_diff is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


525-525: Unpacked variable dk_mean_diff is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


532-532: Unpacked variable dv_mean_diff is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-metal
  • GitHub Check: format-check
🔇 Additional comments (12)
src/tl_templates/hip/common.h (1)

119-121: LGTM!

The AtomicAddRet function correctly takes a reference and returns the old value from the atomic operation.

src/target/codegen_hip.cc (4)

44-55: LGTM!

The extended FP8 type handling correctly maps additional FP8 variants to their corresponding HIP types (fp8_e4_t, fp8_e5_t, fp8_e8_t). The error message now includes the type information for better diagnostics.


936-940: LGTM!

The reformatted mfma builtin call maintains the same behavior with improved readability.


966-972: LGTM!

The implementation correctly handles:

  • tl::loop_break() by emitting a break; statement
  • tl::no_set_max_nreg() as a no-op (HIP doesn't require explicit register management like CUDA)

Both behaviors are appropriate for the HIP backend.


1178-1179: LGTM!

The consolidated FP8 constant handling correctly combines multiple E4M3 variants (e4m3fnuz, e4m3, e4m3fn) that all map to the same fp8_e4_t type for HIP. This simplification reduces code duplication.

src/tl_templates/hip/hip_fp8.h (3)

8-13: LGTM!

The FP8 E5M2 typedefs are correctly defined using HIP's underlying types, and the comment appropriately documents that E8M0 types are not yet supported in the current HIP version.


53-74: LGTM!

The FP8 E5M2 vector types (fp8_e5_4_t, fp8_e5_8_t, fp8_e5_16_t) follow the same structure and pattern as the existing E4M3 types, providing consistent API and proper alignment attributes.


76-99: LGTM!

The commented-out E8M0 vector types provide a clear template for future implementation when HIP adds support for this FP8 variant.

examples/amd/example_amd_flash_attn_fwd.py (2)

86-86: Scale computation is correct

(1.0 / dim)**0.5 equals 1/sqrt(dim); matches the reference softmax scaling.


172-186: exp2 → exp change is consistent with the new scale

Using T.exp with the scale as written preserves correctness and removes the need for a log2(e) factor. Looks good.

If you previously cached autotune results, clear the cache to avoid mixing exp2/exp kernels during benchmarking.
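
The identity being relied on, written out (with σ = 1/√d the softmax scale and s_ij the raw scores):

$$\exp_2\!\big(x \log_2 e\big) = e^{x} \;\Longrightarrow\; \exp_2\!\big(\sigma\, s_{ij} \log_2 e\big) = e^{\sigma\, s_{ij}},$$

so dropping the 1.44269504 (= log₂ e) factor and calling T.exp on the scaled scores computes exactly the same probabilities.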

examples/amd/example_amd_flash_attn_bwd.py (2)

88-105: Forward kernel: LSE production and scaling look correct

You now emit LSE and compute it as log(l_i) + m_i in the scaled domain. The earlier multiply-by-scale before reduce_max and exp(acc_s - m_i) is consistent and numerically stable.

Please re-run forward correctness with a larger seq_len (e.g., 8192) to sanity-check stability on HIP: rtol=5e-2, atol=5e-2.

Also applies to: 207-223
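
For completeness, the identity behind that LSE form (σ = 1/√d, s_ij the raw scores, m_i the running max of the scaled scores, l_i the running sum maintained by the kernel):

$$\mathrm{LSE}_i = \log \sum_j e^{\sigma s_{ij}} = m_i + \log \sum_j e^{\sigma s_{ij} - m_i} = m_i + \log l_i.$$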


447-507: End-to-end wiring is clean and debuggable

Nice addition of reference checks, LSE plumbing, and separated pre/post backward stages; improves clarity and makes benchmarking straightforward.

Comment on lines +113 to +117
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 address, T2 val) {
  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}

⚠️ Potential issue | 🔴 Critical

Critical logic error: taking address of pass-by-value parameter.

This overload takes T1 address by value and then uses &address to get a pointer. This retrieves the address of the local copy created for the function parameter, not the address of the caller's original variable. The atomic operation would be performed on the local copy, which is destroyed when the function returns, making the atomic operation completely ineffective.

Either remove this overload entirely if it's not needed, or if the intent was to support non-pointer types, the signature should be AtomicAdd(T1 &address, T2 val) (taking a reference).

Apply this diff to fix the signature to take a reference:

 // Overload for when the first argument is a value instead of a pointer
 template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 address, T2 val) {
-  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
+TL_DEVICE void AtomicAdd(T1 &address, T2 val) {
+  atomicAdd(&address, static_cast<T1>(val));
 }
🤖 Prompt for AI Agents
In src/tl_templates/hip/common.h around lines 113 to 117, the AtomicAdd overload
takes its first parameter by value and then uses &address, which takes the
address of the local copy instead of the caller's variable; change the function
signature to take the address parameter by reference (e.g., T1 &address) so
atomicAdd operates on the caller's memory (or remove the overload if not
needed), and update any callers if necessary to pass lvalues.

Comment on lines +73 to +75
// Note: clear_accum=true is not fully supported in HIP implementation
// but we'll handle it by manually clearing the accumulator
// static_assert(!clear_accum, "clear_accum=true is not supported yet");

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify that clear_accum=true is properly handled.

The static_assert has been removed and replaced with a comment stating that clear_accum will be "manually cleared." However, examining the body() (lines 158-232) and body_rs() (lines 234-291) methods, there is no visible code that clears the accumulator when clear_accum=true.

If clear_accum=true is passed, the accumulator should be zeroed before accumulation begins, but this implementation appears to rely on the caller to pre-zero C_local.

Run the following script to check if clear_accum is handled elsewhere or if callers are responsible for zeroing:


🏁 Script executed:

#!/bin/bash
# Search for clear_accum usage patterns and accumulator initialization
rg -n -C5 'clear_accum.*true|C_local.*=.*0|memset.*C_local' --type=cpp --type=h

Length of output: 3003


Restore or implement clear_accum handling
Either re-enable static_assert(!clear_accum) on lines 73–75 or add code to zero C_local when clear_accum=true, as it’s currently unhandled.

🤖 Prompt for AI Agents
In src/tl_templates/hip/gemm.h around lines 73–75, restore handling for the
clear_accum template flag: either re-enable the static_assert(!clear_accum) to
fail compilation when clear_accum=true, or implement the runtime/compile-time
zeroing of C_local when clear_accum==true. If choosing static_assert, add the
assert back exactly for the HIP path; if choosing implementation, add a guarded
branch (compile-time if constexpr or preprocessor guarded by the template param)
that writes zeros into the accumulator buffer C_local before accumulation
begins, ensuring the zeroing matches the accumulator tile shape and memory
layout and does not break existing non-clear behavior.

coderabbitai bot (Contributor) left a comment

Actionable comments posted: 5

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
examples/amd/example_amd_flash_attn_bwd.py (1)

1-601: Format examples/amd/example_amd_flash_attn_bwd.py using the project’s formatter (e.g., ruff --fix or black) to resolve CI formatting errors.

src/tl_templates/hip/gemm.h (1)

54-65: Add MfmaTraits specialization for fp8_e5_t
HIP codegen emits fp8_e5_t but src/tl_templates/hip/gemm.h only handles fp8_e4_t. Add the following under #if defined(HIP_FP8_ENABLED):

template <> struct MfmaTraits<fp8_e5_t> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const fp8_e5_t* b, const fp8_e5_t* a, AccType* c) {
    int64_t a_val = *reinterpret_cast<const int64_t*>(a);
    int64_t b_val = *reinterpret_cast<const int64_t*>(b);
    *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0);
  }
};

Confirm the intrinsic is correct for E5 on your target compiler.

🧹 Nitpick comments (3)
examples/amd/example_amd_flash_attn_bwd.py (2)

257-258: Address ambiguous variable name.

Static analysis flags variable O as ambiguous (single uppercase letter). While this is idiomatic for mathematical notation in attention mechanisms, consider the project's style guidelines.

Based on learnings

If the codebase prefers descriptive names, apply this diff:

-    def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype),
+    def flash_bwd_prep(Output: T.Tensor(shape, dtype), dOutput: T.Tensor(shape, dtype),
                        Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):

And update references within the function accordingly.


517-517: Remove unused unpacked variables.

Static analysis flags dq_mean_diff, dk_mean_diff, and dv_mean_diff as unpacked but unused.

Apply this diff to suppress warnings:

-    dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(
+    dq_close, dq_max_diff, _ = debug_tensor_comparison(
         dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
     if dq_close:
         print("dQ is correct.")
     else:
         print("dQ mismatch detected.")
 
-    dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(
+    dk_close, dk_max_diff, _ = debug_tensor_comparison(
         dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
     if dk_close:
         print("dK is correct.")
     else:
         print("dK mismatch detected.")
 
-    dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(
+    dv_close, dv_max_diff, _ = debug_tensor_comparison(
         dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)

Also applies to: 524-524, 531-531

src/tl_templates/hip/hip_fp8.h (1)

53-75: E5 FP8 vector wrappers: LGTM; consider parity with E4 helpers

Implementation mirrors E4 forms. For API parity, you may add an assignment operator from __hip_fp8x4_e5m2_fnuz like E4_4 has, but optional.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7cd0da9 and 8b345ae.

📒 Files selected for processing (8)
  • examples/amd/example_amd_flash_attn_bwd.py (2 hunks)
  • examples/amd/example_amd_flash_attn_fwd.py (3 hunks)
  • examples/amd/test.sh (0 hunks)
  • examples/flash_attention/example_mha_bwd.py (0 hunks)
  • src/target/codegen_hip.cc (4 hunks)
  • src/tl_templates/hip/common.h (1 hunks)
  • src/tl_templates/hip/gemm.h (1 hunks)
  • src/tl_templates/hip/hip_fp8.h (2 hunks)
💤 Files with no reviewable changes (2)
  • examples/amd/test.sh
  • examples/flash_attention/example_mha_bwd.py
🧰 Additional context used
🧬 Code graph analysis (3)
src/target/codegen_hip.cc (2)
tilelang/language/builtin.py (2)
  • loop_break (419-422)
  • no_set_max_nreg (160-163)
tilelang/language/customize.py (1)
  • loop_break (67-73)
src/tl_templates/hip/hip_fp8.h (1)
src/tl_templates/cuda/cuda_fp8.h (1)
  • fp8_e5_2_t (44-47)
examples/amd/example_amd_flash_attn_bwd.py (4)
examples/amd/example_amd_flash_attn_fwd.py (4)
  • ref_program (11-28)
  • fast_flashattn (68-205)
  • main (97-203)
  • main (208-238)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (3)
  • ref_program (191-236)
  • main (128-185)
  • main (253-312)
examples/attention_sink/example_mha_sink_bwd_bhsd.py (8)
  • ref_program (399-444)
  • main (447-514)
  • get_bwd_configs (11-19)
  • flashattn_bwd_preprocess (145-170)
  • flash_bwd_prep (151-168)
  • flashattn_bwd (210-322)
  • flashattn_bwd_postprocess (185-202)
  • backward (367-391)
tilelang/tileop/gemm/gemm_base.py (3)
  • K (42-43)
  • k_pack (111-112)
  • policy (119-120)
🪛 GitHub Actions: CI Test on Metal
examples/amd/example_amd_flash_attn_bwd.py

[error] 1-1: Reformatted files. Please review and stage the changes.

🪛 Ruff (0.13.3)
examples/amd/example_amd_flash_attn_bwd.py

257-257: Ambiguous variable name: O

(E741)


517-517: Unpacked variable dq_mean_diff is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


524-524: Unpacked variable dk_mean_diff is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


531-531: Unpacked variable dv_mean_diff is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

🔇 Additional comments (10)
examples/amd/example_amd_flash_attn_fwd.py (2)

37-37: LGTM! Autotune search space expanded.

Adding num_stages = [0, 1] increases the configuration search space, allowing the autotuner to explore pipelined and non-pipelined execution paths for better performance.


86-86: Consistent exp usage verified. All AMD attention examples use T.exp with scale = (1.0/dim)**0.5; no T.exp2 or hardcoded 1.44269504 found.

examples/amd/example_amd_flash_attn_bwd.py (6)

1-11: LGTM! Imports are appropriate.

The imports include necessary dependencies for the backward pass implementation, including numpy for tensor comparison utilities and time for benchmarking.


13-31: LGTM! Reference implementation returns LSE.

The reference implementation now returns both the output and log-sum-exp (LSE), which is required for the backward pass. The variable naming uses K_ref and V_ref for clarity after the repeat_interleave operation.


182-198: Guard against infinity before exponentiation.

The code correctly checks for -T.infinity values before calling T.exp, preventing numerical overflow and NaN propagation:

if m_prev[i] == -T.infinity(accum_dtype):
    scale_factor[i] = 0.0
else:
    scale_factor[i] = T.exp(m_prev[i] - m_i[i])

This defensive pattern is also applied at lines 194-197. Good practice for numerical stability.


276-365: LGTM! Backward kernel uses atomic operations correctly.

The backward pass implementation:

  1. Correctly computes dV and dK accumulations per block
  2. Uses T.atomic_add for dQ accumulation (line 359) to handle race conditions from multiple blocks writing to the same dQ elements
  3. Uses T.atomic_add for final dV and dK writes (lines 362-363) to handle multi-group scenarios

The use of accum_dtype (float32) for dQ accumulation buffer followed by postprocessing to float16 is a sound approach for numerical stability.


386-425: Comprehensive debug utility.

The debug_tensor_comparison function provides detailed diagnostics including max/mean differences, relative errors, NaN/Inf counts, and per-element analysis. This will be valuable for debugging numerical issues during development.
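
A condensed sketch of what such a helper can look like (the actual debug_tensor_comparison in the example may differ in details; the printout format here is illustrative):

import torch

def debug_tensor_comparison(actual, expected, name, rtol=0.05, atol=0.05):
    diff = (actual.float() - expected.float()).abs()
    max_diff = diff.max().item()
    mean_diff = diff.mean().item()
    nan_count = torch.isnan(actual).sum().item()
    inf_count = torch.isinf(actual).sum().item()
    close = torch.allclose(actual.float(), expected.float(), rtol=rtol, atol=atol)
    print(f"{name}: max_diff={max_diff:.4e} mean_diff={mean_diff:.4e} "
          f"nan={nan_count} inf={inf_count} close={close}")
    return close, max_diff, mean_diff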


446-588: Comprehensive test harness with proper verification.

The main function provides:

  1. Forward pass autotuning and verification against PyTorch reference
  2. Backward pass autotuning with preprocessing and postprocessing
  3. Detailed gradient verification with configurable tolerances
  4. Performance benchmarking with proper synchronization
  5. Clear pass/fail reporting

The integration of all components (forward, backward, pre/post-processing) demonstrates a complete end-to-end workflow.

src/target/codegen_hip.cc (1)

966-973: loop_break and no_set_max_nreg lowering: LGTM

Emitting break; and treating no_set_max_nreg() as a no‑op on HIP is fine.

Ensure these intrinsics are only used in statement context (Evaluate) to avoid dangling expressions in value contexts.

src/tl_templates/hip/hip_fp8.h (1)

8-14: E5 FP8 aliases: LGTM

Adds E5M2 aliases consistent with HIP’s fnuz variants.

Confirm target HIP SDK provides __hip_fp8_e5m2_fnuz and __hip_fp8x2_e5m2_fnuz on all CI targets.

Comment on lines +34 to +65
def get_fwd_configs():
    block_M = [32, 64, 128, 256]
    block_N = [32, 64, 128, 256]
    threads = [128, 256, 512]
    num_split_q = [64, 128, 256]
    num_stages = [0, 1]
    enable_rasterization = [True]
    k_pack = [2]
    panel_size = [7, 8, 9, 10]
    qk_coalesced_width = [8]
    v_coalesced_width = [4]

    valid_configs = []

    for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
                                                                  threads, num_stages,
                                                                  enable_rasterization, k_pack,
                                                                  panel_size, qk_coalesced_width,
                                                                  v_coalesced_width):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "num_split_q": s,
            "threads": t,
            "num_stages": stages,
            "enable_rasterization": r,
            "k_pack": k,
            "panel_size": p,
            "qk_coalesced_width": qkw,
            "v_coalesced_width": vw,
        })
    return valid_configs

🛠️ Refactor suggestion | 🟠 Major

🧩 Analysis chain

Verify configuration consistency with forward pass.

The get_fwd_configs() function generates configurations for forward pass autotuning. Note that panel_size has been expanded to [7, 8, 9, 10] compared to the forward-only file.

Ensure these configurations align with the standalone forward example to maintain consistency across the codebase:


🏁 Script executed:

#!/bin/bash
# Description: Compare forward configs between fwd-only and fwd-bwd examples

echo "=== Forward-only example config ==="
rg -A 10 'def get_configs' examples/amd/example_amd_flash_attn_fwd.py

echo ""
echo "=== Forward-backward example config ==="
rg -A 10 'def get_fwd_configs' examples/amd/example_amd_flash_attn_bwd.py

Length of output: 758


Align panel_size between examples. get_fwd_configs() uses [7, 8, 9, 10], but get_configs() uses [7, 8]; update one to match the other.

🤖 Prompt for AI Agents
In examples/amd/example_amd_flash_attn_bwd.py around lines 34 to 65, panel_size
values differ between get_fwd_configs() ([7, 8, 9, 10]) and get_configs() ([7,
8]); make them consistent by updating the get_configs() panel_size list to match
get_fwd_configs() (use [7, 8, 9, 10]) so both functions produce the same
panel_size options.

Comment on lines +44 to 56
  } else if (type.code() == DataType::kFloat8_e4m3) {
    stream << "fp8_e4" << vec << "_t";
  } else if (type.code() == DataType::kFloat8_e4m3b11fnuz) {
    stream << "fp8_e4" << vec << "_t";
  } else if (type.code() == DataType::kFloat8_e5m2) {
    stream << "fp8_e5" << vec << "_t";
  } else if (type.code() == DataType::kFloat8_e5m2fnuz) {
    stream << "fp8_e5" << vec << "_t";
  } else if (type.code() == DataType::kFloat8_e8m0fnu) {
    stream << "fp8_e8" << vec << "_t";
  } else {
-   LOG(FATAL) << "Unsupported FP8 type in HIP codegen";
+   LOG(FATAL) << "Unsupported FP8 type in HIP codegen: " << type;
  }

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

FP8 E8 mapping unsupported by header; and FP8 E5 constants not handled

  • GetFP8Type maps kFloat8_e8m0fnu → fp8_e8*_t, but hip_fp8.h explicitly comments E8 types as unsupported. This will fail to compile when used.
  • E5 mapping is added here, but PrintConst lacks E5 handling, leading to LOG(FATAL) for FloatImm FP8 E5.

Apply:

-  } else if (type.code() == DataType::kFloat8_e8m0fnu) {
-    stream << "fp8_e8" << vec << "_t";
-  } else {
+  } else {
     LOG(FATAL) << "Unsupported FP8 type in HIP codegen: " << type;

And extend PrintConst (see below comment) to support E5 FP8 constants.

Check if any E8 paths are exercised; if so, either add E8 types to hip_fp8.h guarded by feature detection or drop E8 emission entirely.


🏁 Script executed:

#!/bin/bash
set -e

# Locate hip_fp8.h
HIP_FILE=$(fd hip_fp8.h | head -n1)
echo "=== hip_fp8.h ($HIP_FILE) context ==="
rg -n -C2 'E8' "$HIP_FILE"
rg -n -C2 'unsupported' "$HIP_FILE"

# Inspect PrintConst in codegen_hip.cc for FP8 handling
echo "=== PrintConst FP8 cases in codegen_hip.cc ==="
rg -n -C3 'PrintConst' -m1 src/target/codegen_hip.cc | sed -n '1,20p'
rg -n -C3 'kFloat8' src/target/codegen_hip.cc

Length of output: 408


🏁 Script executed:

#!/bin/bash
set -e

# Locate PrintConst in the repository
echo "=== PrintConst definitions/usages ==="
rg -n 'PrintConst' -C3 .

# Show all FP8 kFloat8 cases with context in codegen_hip.cc
echo "=== FP8 cases in codegen_hip.cc ==="
rg -n -C10 'kFloat8' src/target/codegen_hip.cc

Length of output: 5602


🏁 Script executed:

#!/bin/bash
set -e
sed -n '1171,1271p' src/target/codegen_hip.cc

Length of output: 2899


Remove unsupported FP8 E8 case and add FP8 E5 constant support

  • Drop the kFloat8_e8m0fnu branch since hip_fp8.h (lines 11–13) comments out E8 aliases, causing build failures.
  • Extend PrintConst to handle E5 constants:
    if (op->dtype.is_float8_e5m2() || op->dtype.is_float8_e5m2fnuz()) {
      os << "fp8_e5_t";
      os << '(' << std::scientific << op->value << 'f' << ')';
      return;
    }
-  } else if (type.code() == DataType::kFloat8_e8m0fnu) {
-    stream << "fp8_e8" << vec << "_t";
-  } else {
+  } else {
     LOG(FATAL) << "Unsupported FP8 type in HIP codegen: " << type;
🤖 Prompt for AI Agents
In src/target/codegen_hip.cc around lines 44 to 56, remove the branch that
handles DataType::kFloat8_e8m0fnu (delete the else-if that streams "fp8_e8...")
because HIP headers comment out E8 aliases and this causes build failures;
additionally, extend the PrintConst handling for FP8 constants to detect e5
types (is_float8_e5m2 or is_float8_e5m2fnuz) and emit the constant as "fp8_e5_t"
constructed from the literal value (i.e., print "fp8_e5_t(...f)"), ensuring the
same formatting style as other float constant code paths so E5 constants are
supported.

Comment on lines +1178 to 1181
} else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() ||
op->dtype.is_float8_e4m3fn()) {
os << "fp8_e4_t";
os << '(' << std::scientific << op->value << 'f' << ')';

⚠️ Potential issue | 🟠 Major

Add FP8 E5 constant emission to avoid fatal path

Currently only E4 FP8 constants are emitted. Add E5 branch to PrintConst.

-  } else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() ||
-             op->dtype.is_float8_e4m3fn()) {
+  } else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() ||
+             op->dtype.is_float8_e4m3fn()) {
     os << "fp8_e4_t";
     os << '(' << std::scientific << op->value << 'f' << ')';
     return;
+  } else if (op->dtype.is_float8_e5m2fnuz()
+             /* add is_float8_e5m2() if available in your DataType */) {
+    os << "fp8_e5_t";
+    os << '(' << std::scientific << op->value << 'f' << ')';
+    return;
   }
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
  } else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() ||
             op->dtype.is_float8_e4m3fn()) {
    os << "fp8_e4_t";
    os << '(' << std::scientific << op->value << 'f' << ')';
  } else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() ||
             op->dtype.is_float8_e4m3fn()) {
    os << "fp8_e4_t";
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
  } else if (op->dtype.is_float8_e5m2fnuz()
             /* add is_float8_e5m2() if available in your DataType */) {
    os << "fp8_e5_t";
    os << '(' << std::scientific << op->value << 'f' << ')';
    return;
  }
🤖 Prompt for AI Agents
In src/target/codegen_hip.cc around lines 1178 to 1181, the PrintConst handling
only emits FP8 E4 constants; add a parallel branch to handle FP8 E5 types. Add
an else-if after the existing E4 check that tests the E5 dtype predicates (the
analogous is_float8_e5... methods), and emit "fp8_e5_t" with the same formatting
as E4 (e.g. os << "fp8_e5_t"; os << '(' << std::scientific << op->value << 'f'
<< ')';). Ensure the new branch mirrors the E4 branch placement and formatting
so FP8 E5 constants are emitted instead of falling into the fatal path.

Comment on lines +113 to +121
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 address, T2 val) {
  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}

template <typename T1, typename T2> TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val) {
  return atomicAdd(&ref, static_cast<T1>(val));
}

⚠️ Potential issue | 🔴 Critical

Invalid AtomicAdd overload takes a value; atomics must target real memory

AtomicAdd(T1 address, T2 val) atomically updates a local copy via &address, not the intended memory; this is incorrect and can cause undefined behavior on HIP.

Remove this overload and, if needed, provide a reference overload (non‑returning) alongside AtomicAddRet:

-// Overload for when the first argument is a value instead of a pointer
-template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 address, T2 val) {
-  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
-}
+// Preferred non-returning reference overload (mirrors AtomicAddRet)
+template <typename T1, typename T2>
+TL_DEVICE void AtomicAdd(T1& ref, T2 val) {
+  atomicAdd(&ref, static_cast<T1>(val));
+}

Optionally constrain T1 to HIP‑supported atomic types via static_assert.

🤖 Prompt for AI Agents
In src/tl_templates/hip/common.h around lines 113 to 121, the overload
AtomicAdd(T1 address, T2 val) takes the address by value and does atomicAdd on
&address (a local copy) which is invalid; remove this value-taking overload,
replace with a non-returning reference overload AtomicAdd(T1 &address, T2 val)
that calls atomicAdd(&address, static_cast<T1>(val)), and (optionally) add a
static_assert or type trait to constrain T1 to HIP-supported atomic types to
catch misuse at compile time.

Comment on lines +73 to +75
// Note: clear_accum=true is not fully supported in HIP implementation
// but we'll handle it by manually clearing the accumulator
// static_assert(!clear_accum, "clear_accum=true is not supported yet");

⚠️ Potential issue | 🔴 Critical

Do not drop support guard without implementing clear_accum behavior

Removing the static_assert enables clear_accum=true but no code zeros accumulators, yielding nondeterministic results.

Add zeroing when clear_accum is true. Example insertion at the start of body() and body_rs() before the ki loop:

if constexpr (clear_accum) {
  for (int i = 0; i < warp_rows; ++i) {
    for (int j = 0; j < warp_cols; ++j) {
      ((float32x4*)C_local)[i * warp_cols + j] = 0;
    }
  }
}

Optionally re‑enable the static_assert if you cannot guarantee correct zeroing on all code paths.

🤖 Prompt for AI Agents
In src/tl_templates/hip/gemm.h around lines 73-75, the static_assert for
clear_accum was removed but no accumulator zeroing was added, so enabling
clear_accum=true yields nondeterministic results; add a constexpr guard at the
start of both body() and body_rs() (before the ki loop) that zeroes the per-warp
accumulator memory when clear_accum is true (iterate warp_rows and warp_cols and
set the corresponding C_local entries to zero, e.g., by casting C_local to the
appropriate vector type and writing zeros), and if you cannot guarantee zeroing
on all code paths re-enable the static_assert to prevent enabling clear_accum
without proper initialization.

LeiWang1999 (Member) commented:

@Alex4210987 it would be great if we could add some test cases for the fp8 dtype on the HIP backend.
