fix bug&add amd examples #966
base: main
Conversation
…nd clarity (tile-ai#668) - Enhanced buffer index handling to address precision issues by removing redundant operations. - Streamlined the logic for determining buffer overlaps, ensuring more accurate conflict detection. - Updated related documentation to reflect changes in buffer management practices.
…ed flexibility - Introduced a new input.txt file for configurable parameters. - Modified the example_amd_flash_attn_fwd.py script to allow for a wider range of configurations, including additional options for num_stages, enable_rasterization, and k_pack. - Streamlined the main function for better clarity and organization. - Added a new test script to facilitate running the example with specified parameters.
… example with swizzle layout annotations - Deleted input.txt and test.sh files as they are no longer needed. - Updated example_amd_flash_attn_fwd.py to include swizzle layout annotations for shared memory, improving bank conflict avoidance. - Reintroduced swizzle usage in the kernel for better performance.
- Updated function names for clarity, changing `get_v2_configs` to `get_configs` and `fast_flashattn_v2` to `fast_flashattn`. - Streamlined the main function by renaming `main_v2` to `main` and adjusting the corresponding calls. - Removed outdated comments and improved code organization for better readability.
- Improved code readability by adjusting line breaks and indentation in the `fast_flashattn` function. - Streamlined the `main` function parameter formatting for consistency. - Removed unnecessary blank lines to enhance overall code organization.
- Improved the `example_amd_flash_attn_fwd.py` script for better clarity and organization. - Added new CI workflows for AMD and documentation publishing. - Updated various requirements files to include necessary dependencies. - Introduced new test cases and examples for better coverage and functionality. - Refactored existing code for improved readability and maintainability.
- Introduced `example_amd_flash_attn_bwd.py` for backward attention computation using TileLang. - Added `test.sh` script to facilitate running the new example with specified parameters. - Enhanced the overall structure and organization of the example for better clarity and usability.
- Reduced the number of threads and `num_split_q` options for improved performance. - Adjusted `panel_size` options to streamline configuration settings.
- Introduced a new example script `example_amd_flash_attn_bwd.py` demonstrating the forward and backward operations of Flash Attention using TileLang. - Implemented JIT-compiled functions for both forward and backward passes, including preprocessing and postprocessing steps. - Added a main function to facilitate testing and benchmarking of the attention mechanism with configurable parameters. - Included reference implementation for validation against PyTorch's attention mechanism. This addition enhances the examples directory by providing a comprehensive guide for users to understand and utilize Flash Attention in their applications.
- Updated `example_amd_flash_attn_bwd.py` to include more comprehensive testing features for the Flash Attention implementation. - Improved the main function to allow for better parameter configuration and benchmarking. - Added validation checks against PyTorch's attention mechanism to ensure accuracy and reliability of the example. This update aims to provide users with a more robust tool for understanding and utilizing Flash Attention in their applications.
- Updated file name from `intrin_rule_hip.cc` to `intrin_rule_cuda.cc` to reflect the change in focus from HIP to CUDA intrinsic rules. - Adjusted include paths for better organization and clarity in the code structure.
…installation - Removed the installation of `flash_attn==2.5.8` to streamline the CI process. - Added a step to uninstall `torch`, `torchvision`, and `torchaudio` prior to installing pre-release versions, ensuring compatibility and reducing potential conflicts.
…rd example - Eliminated the allocation of shared memory for `dv_shared` and `dk_shared` in `example_amd_flash_attn_bwd.py` to streamline memory usage and improve performance. - This change focuses on optimizing the backward pass implementation by reducing unnecessary memory overhead.
- Eliminated the step to uninstall `torch`, `torchvision`, and `torchaudio` in the AMD CI workflow, as it is no longer required for the installation of pre-release versions. - This change simplifies the CI process and reduces potential overhead during package management.
- Updated the return statement to use std::string for concatenation in the case of 16-bit types, improving code clarity. - Added a null check for the CallNode pointer in DispatchHIPWarpActiveMask to enhance robustness and prevent potential dereferencing issues.
- Adjusted the formatting of TVM_REGISTER_OP calls for better readability by aligning method chaining. - No functional changes were made; this update focuses on code style improvements to enhance maintainability.
- Renamed the file from `intrin_rule_cuda.cc` to `intrin_rule_hip.cc` to accurately reflect the focus on HIP intrinsic rules. - Updated the file documentation to clarify its purpose as related to HIP rather than CUDA.
…ttn_bwd.py - Updated scaling factor application for improved numerical stability in gradient calculations. - Refined tensor handling to ensure consistency with forward pass operations. - Optimized atomic operations for writing gradients to dK and dV using fp32 for better precision. - Adjusted comments for clarity and alignment with standard implementation practices.
…update test.sh - Increased the range of block sizes and stages for forward and backward configurations to enhance performance tuning. - Adjusted the test script to include additional parameters for batch size and head dimensions, ensuring consistency with the forward example. - Improved comments for clarity and alignment with the updated configurations.
…h_attn_bwd.py - Updated FLOPs calculation to account for both forward and backward passes, clarifying the total computational cost. - Modified benchmarking functions to evaluate the complete forward and backward performance of both reference and Tile-lang implementations. - Improved comments for better understanding of the performance metrics and implementation details. - Removed unnecessary parameter from test.sh to streamline execution.
…rd attention execution for streamlined testing.
…le_amd_flash_attn_bwd.py and example_amd_flash_attn_fwd.py - Updated the forward function to return both output and log-sum-exp (LSE) values for improved gradient calculations. - Enhanced autotuner configurations for forward pass, including new parameters for better performance tuning. - Refined scaling factor calculations for numerical stability in both forward and backward passes. - Improved comments and documentation for clarity and consistency across implementations. - Adjusted main function to reflect changes in parameter handling and ensure compatibility with new output requirements.
- Removed outdated comments and improved clarity in the code. - Enhanced the forward function to consistently return output and log-sum-exp (LSE) values. - Updated autotuner configurations to include new parameters for better performance tuning. - Refined tensor handling and scaling factor calculations for improved numerical stability. - Adjusted the main function to ensure compatibility with updated output requirements and parameter handling.
…ttn_bwd.py - Updated configuration parameters for backward calculations, including new options for block sizes, threads, and rasterization. - Added new parameters (k_pack, qk_coalesced_width, v_coalesced_width) to improve performance tuning and memory access patterns. - Modified tensor copy operations to utilize coalesced widths for optimized memory loads. - Enhanced GEMM operations with k_pack for improved computational efficiency. - Refined the configuration generation logic to accommodate the new parameters, ensuring comprehensive coverage for backward pass scenarios.
…n_bwd.py - Updated backward configuration parameters to include larger block sizes and a wider range of threads for enhanced performance tuning. - Removed unnecessary parameters (k_pack, qk_coalesced_width, v_coalesced_width) from function signatures and tensor operations to simplify the implementation. - Optimized tensor copy operations by eliminating coalesced width specifications, streamlining memory access patterns. - Adjusted GEMM operations to improve computational efficiency without the use of k_pack.
- Added support for additional FP8 types (e4m3, e4m3b11fnuz, e5m2fnuz, e8m0) in codegen_hip.cc to improve compatibility. - Updated error logging to include unsupported FP8 type details for better debugging. - Implemented handling for loop break and no-op register management in HIP within VisitExpr_ method. - Introduced new FP8 vector types (e5 and e8) in hip_fp8.h for enhanced functionality. - Added overloads for AtomicAdd in common.h to support both pointer and value arguments.
- Expanded FP8 type support in codegen_hip.cc to include additional float8 formats. - Updated gemm.h to clarify the handling of the accumulator when clear_accum is true. - Added comments in hip_fp8.h to indicate that E8M0 types are not supported in the current HIP version.
…ample_amd_flash_attn_bwd.py
…ple_amd_flash_attn_bwd.py for cleaner output.
…xample_amd_flash_attn_fwd.py by adding spaces for improved readability in configuration parameters and print statements.
- Reorganized and cleaned up code in codegen_hip.cc for better readability and maintainability. - Enhanced handling of FP8 types, including additional formats and improved error logging for unsupported types. - Updated AtomicAdd function in common.h to streamline its implementation. - Refined the PrintVecElemLoadExpr method to handle volatile loads more effectively. - Added function to manage the addition of new functions in the code generation process.
- Adjusted the indentation of the MFMA call code block in codegen_hip.cc for improved readability and consistency.
- Reintroduced necessary includes and reorganized code in codegen_hip.cc for improved structure and readability. - Enhanced the GetFP8Type function to support additional FP8 formats and improved error handling for unsupported types. - Updated PrintType and PrintVecElemLoadExpr methods to better manage type conversions and vector element loading. - Refined the AddFunction method to streamline function addition in the code generation process.
…proved code cleanliness.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the project's format checks before pushing. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
…bwd.py - Updated the GEMM operation to use shared memory for improved performance. - Adjusted parallelization parameters to enhance efficiency in the backward pass.
Walkthrough
The change set expands AMD FlashAttention examples with new forward/backward kernels, reference, autotune configs, diagnostics, and CLI updates; adjusts the forward example’s scaling/exponent; removes two test invocations; simplifies layout annotations in another example; and extends HIP codegen/templates with FP8 types, atomic helpers, loop/no-op handling, and GEMM constraints.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant CLI as CLI/Caller
participant AT as Autotuner
participant JIT as JIT Compiler
participant K as GPU Kernel (FWD)
participant GPU as GPU
CLI->>AT: fast_flashattn(args, configs)
AT->>JIT: Select+compile kernel variant
JIT-->>AT: Compiled kernel handle
AT->>K: Launch with Q,K,V and params
K->>GPU: Load tiles, compute QK, mask (causal), softmax, apply V
GPU-->>K: Tile results
K-->>CLI: Output + (optionally) lse/logits stats
sequenceDiagram
autonumber
participant CLI as CLI/Caller
participant PRE as bwd_preprocess
participant AT as Autotuner (BWD)
participant JIT as JIT Compiler
participant KB as GPU Kernel (BWD)
participant POST as bwd_postprocess
CLI->>PRE: flashattn_bwd_preprocess(Q,K,V,dO,...)
PRE-->>CLI: Preprocessed aux (e.g., delta/LSE)
CLI->>AT: flashattn_bwd(args, configs)
AT->>JIT: Compile selected bwd kernel
JIT-->>AT: Kernel handle
AT->>KB: Launch to compute dQ, dK, dV
KB-->>CLI: dQ, dK, dV
CLI->>POST: flashattn_bwd_postprocess(dQ,...)
POST-->>CLI: Final gradients
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (1 passed)
…lash_attn_bwd.py for improved code cleanliness.
Actionable comments posted: 2
🧹 Nitpick comments (8)
examples/amd/example_amd_flash_attn_fwd.py (2)
37-37
: Autotune space widened: expect longer tuning; consider adding num_stages=2 for AMD LDS pipelining
Good to include 0 and 1. On some workloads, stages=2 helps hide memory latency. Consider adding 2 (optional) or gate by dim/seq_len to avoid explosion.
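A minimal sketch of the optional gating, assuming dim and seq_len are already in scope; the thresholds below are illustrative assumptions, not tuned cut-offs:
# Hypothetical gating of the extra pipeline stage.
def stage_candidates(dim, seq_len):
    stages = [0, 1]
    if dim >= 128 and seq_len >= 4096:
        stages.append(2)  # only worth exploring for larger tiles/sequences
    return stages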
175-186
: Stable softmax path switch to exp is correct; minor micro-optimization optional
The move from exp2 to exp with scaled m is numerically sound. You can precompute inv_sqrt_d = scale once and reuse; also fold the two multiplications into a single fused multiply-sub for readability:
- sf = T.exp((m_prev[i] - m_i[i]) * scale)
- acc_s[i, j] = T.exp((acc_s[i, j] - m_i[i]) * scale)
Current code is correct as-is.
examples/amd/example_amd_flash_attn_bwd.py (5)
34-66
: Config space parity with fwd example
Good coverage. To keep parity with the forward example file and avoid regressions, consider constraining threads/panel combos via simple heuristics (e.g., limit threads=1024 to block sizes ≥128) to reduce tuning time without losing best configs.
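A rough sketch of such a pruning pass; the threshold values come from the heuristic mentioned above and are assumptions, not benchmarked numbers:
def prune_configs(configs):
    pruned = []
    for c in configs:
        # Hypothetical heuristic: very high thread counts rarely pay off with small tiles.
        if c["threads"] >= 1024 and min(c["block_M"], c["block_N"]) < 128:
            continue
        pruned.append(c)
    return pruned
Applied to the generated config list before autotuning, this trims the combinations least likely to win while keeping the large-tile/high-thread ones.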
249-273
: Preprocess kernel is correct; consider explicit threads kwarg
The Delta = sum(O*dO, axis=dim) implementation is fine. For consistency and predictable occupancy, pass threads=128 (or tuned) to T.Kernel here.
276-296
: Backward kernel math path LGTM; minor clarity tweak
- P_acc = exp(qkT*sm_scale - lse) and causal masking are standard.
- dP, dv, dk, dq accumulation and atomic adds look consistent.
Optional: rename qkT -> S_scaled for readability.
Also applies to: 321-365
386-445
: Debug/benchmark helpers are useful; minor CUDA guard
You already guard synchronize; optionally also guard torch.cuda.manual_seed(42) and device selection in main by checking CUDA/HIP availability to make CPU fallback easier when running on non-GPU CI.
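A minimal sketch of that guard, assuming the example keeps using a fixed seed of 42:
import torch

torch.manual_seed(42)
if torch.cuda.is_available():      # True on both CUDA and ROCm builds of PyTorch
    torch.cuda.manual_seed(42)
    device = torch.device("cuda")
else:
    device = torch.device("cpu")   # allows a CPU fallback on non-GPU CI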
446-486
: End-to-end flow is solid; add LSE check and deterministic seeds
Great validation flow. Two small improvements:
- Validate LSE from fwd_kernel against ref_program to catch rare stability issues.
- Set torch.backends.cuda.matmul.allow_tf32 and cudnn flags as needed for determinism during checks (keep them default for perf bench).
I can add these guards and checks if you want.
Also applies to: 488-586
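A sketch of both suggestions as small, self-contained helpers; the tensor names lse_tl and lse_ref are assumptions about what main() would pass in:
import torch
from contextlib import contextmanager

def check_lse(lse_tl: torch.Tensor, lse_ref: torch.Tensor, rtol: float = 5e-2, atol: float = 5e-2):
    # Compare the kernel's log-sum-exp against the reference (names are hypothetical).
    torch.testing.assert_close(lse_tl.float(), lse_ref.float(), rtol=rtol, atol=atol)

@contextmanager
def tf32_disabled():
    # Force full-precision matmuls only for the correctness check, then restore
    # the previous flags so the performance benchmark is unaffected.
    prev_matmul = torch.backends.cuda.matmul.allow_tf32
    prev_cudnn = torch.backends.cudnn.allow_tf32
    torch.backends.cuda.matmul.allow_tf32 = False
    torch.backends.cudnn.allow_tf32 = False
    try:
        yield
    finally:
        torch.backends.cuda.matmul.allow_tf32 = prev_matmul
        torch.backends.cudnn.allow_tf32 = prev_cudnn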
src/tl_templates/hip/hip_fp8.h (1)
54-64
: Inconsistent API: fp8_e5_4_t is missing constructors and operators present in fp8_e4_4_t.
The fp8_e5_4_t struct only has:
- Default constructor
- Constructor from __hip_fp8x4_e5m2_fnuz
- Conversion operator
However, fp8_e4_4_t (lines 24-40) additionally provides:
- Constructor from float4
- Assignment operator from native type
This API inconsistency could lead to confusion and limit usability of E5 types.
Consider adding the missing members to fp8_e5_4_t to match fp8_e4_4_t:
struct fp8_e5_4_t {
  union {
    __hip_fp8x4_e5m2_fnuz data;
    struct {
      fp8_e5_t x, y, z, w;
    };
  };
  __device__ fp8_e5_4_t() = default;
  __device__ fp8_e5_4_t(const __hip_fp8x4_e5m2_fnuz &val) : data(val) {}
+ __device__ fp8_e5_4_t(const float4 &val) : data(val) {}
  __device__ operator __hip_fp8x4_e5m2_fnuz() const { return data; }
+ __device__ fp8_e5_4_t &operator=(const __hip_fp8x4_e5m2_fnuz &val) {
+   data = val;
+   return *this;
+ }
};
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/amd/example_amd_flash_attn_bwd.py (2 hunks)
- examples/amd/example_amd_flash_attn_fwd.py (3 hunks)
- examples/amd/test.sh (0 hunks)
- examples/flash_attention/example_mha_bwd.py (0 hunks)
- src/target/codegen_hip.cc (4 hunks)
- src/tl_templates/hip/common.h (1 hunks)
- src/tl_templates/hip/gemm.h (1 hunks)
- src/tl_templates/hip/hip_fp8.h (2 hunks)
💤 Files with no reviewable changes (2)
- examples/amd/test.sh
- examples/flash_attention/example_mha_bwd.py
🔇 Additional comments (11)
examples/amd/example_amd_flash_attn_fwd.py (1)
86-86
: Scale change LGTM
Using natural exp with scale=1/sqrt(d) aligns with the reference softmax.
examples/amd/example_amd_flash_attn_bwd.py (1)
68-88
: Forward kernel (in bwd file) numerics and LSE emission look correct
- Scaling applied before max and exp; mask preserves -inf; LSE written as log(l_i)+m_i with bounds check. This matches the reference.
If desired, add an assert_allclose against ref_program for LSE too to tighten validation (currently only outputs checked via profiler in main).
Also applies to: 98-107, 176-201, 212-222
src/tl_templates/hip/common.h (1)
119-121
: LGTM!
The AtomicAddRet implementation correctly performs an atomic add on a reference and returns the result. This is a useful addition for cases where the return value is needed.
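A hypothetical device-side usage, assuming common.h is included; the kernel and buffer names are illustrative only:
// Each thread atomically bumps a shared counter and uses the returned old
// value as a unique slot index.
__global__ void claim_slots(float *counter, float *slots) {
  float old = AtomicAddRet(*counter, 1.0f);  // returns the value before this thread's add
  slots[threadIdx.x] = old;
}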
src/target/codegen_hip.cc (4)
44-54
: LGTM: Extended FP8 type support.
The additions correctly map new FP8 variants (kFloat8_e4m3, kFloat8_e4m3b11fnuz, kFloat8_e5m2fnuz, kFloat8_e8m0fnu) to their respective HIP type representations. The consistent pattern of preserving vector-width suffixes ensures proper type handling across all FP8 variants.
55-55
: LGTM: Improved error diagnostics.
Including the actual type in the error message will help developers quickly identify unsupported FP8 types during debugging.
966-972
: LGTM: HIP-specific control flow additions.
The loop_break emits a standard break statement, and no_set_max_nreg correctly becomes a no-op for HIP since register management is handled differently than in CUDA. Both implementations are appropriate for the HIP backend.
1178-1179
: LGTM: Consistent FP8 constant handling.
The extension of FP8 constant detection to include is_float8_e4m3 and is_float8_e4m3fn aligns with the FP8 type handling additions earlier in the file, ensuring all FP8 variants are properly recognized and emitted as fp8_e4_t constants.
src/tl_templates/hip/hip_fp8.h (3)
8-13
: LGTM: Added FP8 E5 type aliases.
The FP8 E5M2 typedefs are correctly defined using HIP's underlying types, maintaining consistency with the existing E4M3 aliases. The commented note about E8M0 not being supported in the current HIP version is helpful for future reference.
24-40
: LGTM: Enhanced fp8_e4_4_t wrapper API.
The additions provide a complete set of constructors and operators:
- Default constructor for uninitialized instances
- Copy constructor from native type for interoperability
- Constructor from float4 for conversions
- Conversion operator for seamless usage with native types
- Assignment operator for value updates
These enhancements improve usability and type safety for FP8 operations.
66-74
: LGTM: E5 vector type scaffolding.
The fp8_e5_8_t and fp8_e5_16_t structs follow the same composition pattern as their E4 counterparts, providing consistent vector type support across FP8 variants.
src/tl_templates/hip/gemm.h (1)
73-75
: Verify that accumulator clearing is implemented when clear_accum=true.
The comment states "we'll handle it by manually clearing the accumulator," but reviewing the body() and body_rs() methods (lines 158-291), I don't see explicit accumulator initialization code that checks the clear_accum template parameter.
If clear_accum=true is now allowed but not implemented, this could lead to undefined behavior with uninitialized accumulators.
Run the following script to verify if there's accumulator clearing logic elsewhere or if this needs to be implemented:
T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
Fix policy reference: use imported GemmWarpPolicy instead of T.GemmWarpPolicy
You import GemmWarpPolicy, but call T.GemmWarpPolicy.*. Unless tilelang.language exposes GemmWarpPolicy, this will raise at JIT time. Align with the forward file usage.
Apply this diff:
- T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+ T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=GemmWarpPolicy.FullRow)
@@
- T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+ T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=GemmWarpPolicy.FullRow)
@@
- T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow)
+ T.gemm(p_cast, do_shared, dv, policy=GemmWarpPolicy.FullRow)
@@
- T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow)
+ T.gemm(p_cast, q_shared, dk, policy=GemmWarpPolicy.FullRow)
Also applies to: 343-347, 353-354
🤖 Prompt for AI Agents
In examples/amd/example_amd_flash_attn_bwd.py around lines 328-329 (and also
343-347, 353-354), the code references T.GemmWarpPolicy.* while GemmWarpPolicy
is imported directly; replace all uses of T.GemmWarpPolicy with the imported
GemmWarpPolicy (e.g., GemmWarpPolicy.FullRow) to match the forward file and
avoid JIT-time attribute errors, updating each affected invocation to use
GemmWarpPolicy.* consistently.
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 address, T2 val) {
  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}
Critical bug: Atomic operation on local copy is meaningless.
This overload takes address by value (a copy), then takes the address of that local copy to perform the atomic operation. Since the atomic operation is performed on a temporary local variable, the result is discarded when the function returns, making this operation meaningless.
If the intent is to provide a convenience wrapper for value-based atomics, the signature should take a reference or pointer instead. As written, this code will compile but produce incorrect behavior.
Apply this diff to remove the broken overload:
-// Overload for when the first argument is a value instead of a pointer
-template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 address, T2 val) {
- atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
-}
If a value-based atomic is truly needed, clarify the intended semantics and implement accordingly.
🤖 Prompt for AI Agents
In src/tl_templates/hip/common.h around lines 113-117, the AtomicAdd overload
takes the address parameter by value and performs atomicAdd on the address of
that local copy, so the atomic operation targets a temporary and has no effect;
remove this broken overload or change its signature to accept a pointer or
reference to the target storage (e.g., T1* or T1&) and use that directly in
atomicAdd, or if a value-based atomic is intended, document and implement the
correct semantics (e.g., return the new value and operate on a shared location)
instead of operating on a local copy.
Actionable comments posted: 2
🧹 Nitpick comments (4)
examples/amd/example_amd_flash_attn_fwd.py (1)
31-63
: Autotune search space broadened; consider tuning bounds for ROCm
num_stages now includes [0, 1]. Good for overlap coverage. If you observe compile/runtime instability on some HIP stacks, you may want to guard stage=1 via device checks or leave a note in the example.
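One way to express that guard, sketched here with torch's HIP detection; the exact condition is an assumption and should be adapted to whatever device query the example already uses:
import torch

num_stages = [0]
# Hypothetical guard: only explore the pipelined variant on a ROCm build
# where it has been observed to compile and run cleanly.
if torch.cuda.is_available() and torch.version.hip is not None:
    num_stages.append(1)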
examples/amd/example_amd_flash_attn_bwd.py (3)
328-347
: Unify policy enum usage to avoid ambiguity
You import GemmWarpPolicy but also use T.GemmWarpPolicy in a few places. Prefer one. Using the imported GemmWarpPolicy everywhere avoids reliance on module re-exports.
- T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+ T.gemm(K_shared, q_shared, qkT, transpose_B=True, policy=GemmWarpPolicy.FullRow)
- T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow)
+ T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=GemmWarpPolicy.FullRow)
- T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow)
+ T.gemm(p_cast, q_shared, dk, policy=GemmWarpPolicy.FullRow)
Also applies to: 343-347, 353-358
256-271
: Rename ambiguous parameter ‘O’ to satisfy linters and improve clarity
Variable name O triggers E741 (ambiguous name) and reduces readability.
- def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype),
+ def flash_bwd_prep(Out: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype),
                     Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):
@@
- T.copy(O[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
+ T.copy(Out[bz, by * blk:(by + 1) * blk, bx, k * blk:(k + 1) * blk], o)
Optionally mirror this rename at the call site for consistency in naming (argument name doesn’t need to match, so no functional impact).
518-534
: Remove or underscore unused ‘*_mean_diff’ variables
Avoid RUF059 by ignoring the unused means.
- dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison(
+ dq_close, dq_max_diff, _dq_mean = debug_tensor_comparison(
      dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05)
@@
- dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison(
+ dk_close, dk_max_diff, _dk_mean = debug_tensor_comparison(
      dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05)
@@
- dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison(
+ dv_close, dv_max_diff, _dv_mean = debug_tensor_comparison(
      dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/amd/example_amd_flash_attn_bwd.py (2 hunks)
- examples/amd/example_amd_flash_attn_fwd.py (3 hunks)
- examples/amd/test.sh (0 hunks)
- examples/flash_attention/example_mha_bwd.py (0 hunks)
- src/target/codegen_hip.cc (4 hunks)
- src/tl_templates/hip/common.h (1 hunks)
- src/tl_templates/hip/gemm.h (1 hunks)
- src/tl_templates/hip/hip_fp8.h (2 hunks)
💤 Files with no reviewable changes (2)
- examples/flash_attention/example_mha_bwd.py
- examples/amd/test.sh
🧰 Additional context used
🧬 Code graph analysis (3)
src/tl_templates/hip/hip_fp8.h (1)
- src/tl_templates/cuda/cuda_fp8.h (1): fp8_e5_2_t (44-47)
src/target/codegen_hip.cc (2)
- tilelang/language/builtin.py (2): loop_break (419-422), no_set_max_nreg (160-163)
- tilelang/language/customize.py (1): loop_break (67-73)
examples/amd/example_amd_flash_attn_bwd.py (2)
- examples/amd/example_amd_flash_attn_fwd.py (4): ref_program (11-28), fast_flashattn (68-205), main (97-203), main (208-238)
- examples/attention_sink/example_gqa_sink_bwd_bhsd.py (7): ref_program (403-447), main (450-519), flashattn_bwd_preprocess (148-173), flash_bwd_prep (154-171), flashattn_bwd (213-321), flashattn_bwd_postprocess (188-205), flash_bwd_post (194-203)
🪛 Ruff (0.13.3)
examples/amd/example_amd_flash_attn_bwd.py
257-257: Ambiguous variable name: O (E741)
518-518: Unpacked variable dq_mean_diff is never used. Prefix it with an underscore or any other dummy variable pattern (RUF059)
525-525: Unpacked variable dk_mean_diff is never used. Prefix it with an underscore or any other dummy variable pattern (RUF059)
532-532: Unpacked variable dv_mean_diff is never used. Prefix it with an underscore or any other dummy variable pattern (RUF059)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-metal
- GitHub Check: format-check
🔇 Additional comments (12)
src/tl_templates/hip/common.h (1)
119-121
: LGTM!
The AtomicAddRet function correctly takes a reference and returns the old value from the atomic operation.
src/target/codegen_hip.cc (4)
44-55
: LGTM!
The extended FP8 type handling correctly maps additional FP8 variants to their corresponding HIP types (fp8_e4_t, fp8_e5_t, fp8_e8_t). The error message now includes the type information for better diagnostics.
936-940
: LGTM!
The reformatted mfma builtin call maintains the same behavior with improved readability.
966-972
: LGTM!
The implementation correctly handles:
- tl::loop_break() by emitting a break; statement
- tl::no_set_max_nreg() as a no-op (HIP doesn't require explicit register management like CUDA)
Both behaviors are appropriate for the HIP backend.
1178-1179
: LGTM!
The consolidated FP8 constant handling correctly combines multiple E4M3 variants (e4m3fnuz, e4m3, e4m3fn) that all map to the same fp8_e4_t type for HIP. This simplification reduces code duplication.
src/tl_templates/hip/hip_fp8.h (3)
8-13
: LGTM!
The FP8 E5M2 typedefs are correctly defined using HIP's underlying types, and the comment appropriately documents that E8M0 types are not yet supported in the current HIP version.
53-74
: LGTM!
The FP8 E5M2 vector types (fp8_e5_4_t, fp8_e5_8_t, fp8_e5_16_t) follow the same structure and pattern as the existing E4M3 types, providing consistent API and proper alignment attributes.
76-99
: LGTM!
The commented-out E8M0 vector types provide a clear template for future implementation when HIP adds support for this FP8 variant.
examples/amd/example_amd_flash_attn_fwd.py (2)
86-86
: Scale computation is correct
(1.0 / dim)**0.5 equals 1/sqrt(dim); matches the reference softmax scaling.
172-186
: exp2 → exp change is consistent with the new scale
Using T.exp with the scale as written preserves correctness and removes the need for a log2(e) factor. Looks good.
If you previously cached autotune results, clear the cache to avoid mixing exp2/exp kernels during benchmarking.
examples/amd/example_amd_flash_attn_bwd.py (2)
88-105
: Forward kernel: LSE production and scaling look correct
You now emit LSE and compute it as log(l_i) + m_i in the scaled domain. The earlier multiply-by-scale before reduce_max and exp(acc_s - m_i) is consistent and numerically stable.
Please re-run forward correctness with a larger seq_len (e.g., 8192) to sanity-check stability on HIP: rtol=5e-2, atol=5e-2.
Also applies to: 207-223
447-507
: End-to-end wiring is clean and debuggable
Nice addition of reference checks, LSE plumbing, and separated pre/post backward stages; improves clarity and makes benchmarking straightforward.
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 address, T2 val) {
  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}
Critical logic error: taking address of pass-by-value parameter.
This overload takes T1 address by value and then uses &address to get a pointer. This retrieves the address of the local copy created for the function parameter, not the address of the caller's original variable. The atomic operation would be performed on the local copy, which is destroyed when the function returns, making the atomic operation completely ineffective.
Either remove this overload entirely if it's not needed, or if the intent was to support non-pointer types, the signature should be AtomicAdd(T1 &address, T2 val) (taking a reference).
Apply this diff to fix the signature to take a reference:
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 address, T2 val) {
- atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
+TL_DEVICE void AtomicAdd(T1 &address, T2 val) {
+ atomicAdd(&address, static_cast<T1>(val));
}
🤖 Prompt for AI Agents
In src/tl_templates/hip/common.h around lines 113 to 117, the AtomicAdd overload
takes its first parameter by value and then uses &address, which takes the
address of the local copy instead of the caller's variable; change the function
signature to take the address parameter by reference (e.g., T1 &address) so
atomicAdd operates on the caller's memory (or remove the overload if not
needed), and update any callers if necessary to pass lvalues.
// Note: clear_accum=true is not fully supported in HIP implementation
// but we'll handle it by manually clearing the accumulator
// static_assert(!clear_accum, "clear_accum=true is not supported yet");
🧩 Analysis chain
Verify that clear_accum=true is properly handled.
The static_assert has been removed and replaced with a comment stating that clear_accum will be "manually cleared." However, examining the body()
(lines 158-232) and body_rs()
(lines 234-291) methods, there is no visible code that clears the accumulator when clear_accum=true
.
If clear_accum=true
is passed, the accumulator should be zeroed before accumulation begins, but this implementation appears to rely on the caller to pre-zero C_local.
Run the following script to check if clear_accum is handled elsewhere or if callers are responsible for zeroing:
🏁 Script executed:
#!/bin/bash
# Search for clear_accum usage patterns and accumulator initialization
rg -n -C5 'clear_accum.*true|C_local.*=.*0|memset.*C_local' --type=cpp --type=h
Length of output: 3003
Restore or implement clear_accum handling
Either re-enable static_assert(!clear_accum) on lines 73–75 or add code to zero C_local when clear_accum=true, as it's currently unhandled.
🤖 Prompt for AI Agents
In src/tl_templates/hip/gemm.h around lines 73–75, restore handling for the
clear_accum template flag: either re-enable the static_assert(!clear_accum) to
fail compilation when clear_accum=true, or implement the runtime/compile-time
zeroing of C_local when clear_accum==true. If choosing static_assert, add the
assert back exactly for the HIP path; if choosing implementation, add a guarded
branch (compile-time if constexpr or preprocessor guarded by the template param)
that writes zeros into the accumulator buffer C_local before accumulation
begins, ensuring the zeroing matches the accumulator tile shape and memory
layout and does not break existing non-clear behavior.
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/amd/example_amd_flash_attn_bwd.py (1)
1-601
: Format examples/amd/example_amd_flash_attn_bwd.py using the project's formatter (e.g., ruff --fix or black) to resolve CI formatting errors.
src/tl_templates/hip/gemm.h (1)
54-65
: Add MfmaTraits specialization for fp8_e5_t
HIP codegen emits fp8_e5_t but src/tl_templates/hip/gemm.h only handles fp8_e4_t. Add the following under #if defined(HIP_FP8_ENABLED):
template <> struct MfmaTraits<fp8_e5_t> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const fp8_e5_t *b, const fp8_e5_t *a, AccType *c) {
    int64_t a_val = *reinterpret_cast<const int64_t *>(a);
    int64_t b_val = *reinterpret_cast<const int64_t *>(b);
    *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0);
  }
};
Confirm the intrinsic is correct for E5 on your target compiler.
🧹 Nitpick comments (3)
examples/amd/example_amd_flash_attn_bwd.py (2)
257-258
: Address ambiguous variable name.Static analysis flags variable
O
as ambiguous (single uppercase letter). While this is idiomatic for mathematical notation in attention mechanisms, consider the project's style guidelines.Based on learnings
If the codebase prefers descriptive names, apply this diff:
- def flash_bwd_prep(O: T.Tensor(shape, dtype), dO: T.Tensor(shape, dtype), + def flash_bwd_prep(Output: T.Tensor(shape, dtype), dOutput: T.Tensor(shape, dtype), Delta: T.Tensor([batch, heads, seq_len], accum_dtype)):And update references within the function accordingly.
517-517
: Remove unused unpacked variables.Static analysis flags
dq_mean_diff
,dk_mean_diff
, anddv_mean_diff
as unpacked but unused.Apply this diff to suppress warnings:
- dq_close, dq_max_diff, dq_mean_diff = debug_tensor_comparison( + dq_close, dq_max_diff, _ = debug_tensor_comparison( dQ_tl, q_ref.grad, "dQ", rtol=0.05, atol=0.05) if dq_close: print("dQ is correct.") else: print("dQ mismatch detected.") - dk_close, dk_max_diff, dk_mean_diff = debug_tensor_comparison( + dk_close, dk_max_diff, _ = debug_tensor_comparison( dK_tl.to(torch.float16), k_ref.grad, "dK", rtol=0.05, atol=0.05) if dk_close: print("dK is correct.") else: print("dK mismatch detected.") - dv_close, dv_max_diff, dv_mean_diff = debug_tensor_comparison( + dv_close, dv_max_diff, _ = debug_tensor_comparison( dV_tl.to(torch.float16), v_ref.grad, "dV", rtol=0.05, atol=0.05)Also applies to: 524-524, 531-531
src/tl_templates/hip/hip_fp8.h (1)
53-75
: E5 FP8 vector wrappers: LGTM; consider parity with E4 helpers
Implementation mirrors E4 forms. For API parity, you may add an assignment operator from __hip_fp8x4_e5m2_fnuz like E4_4 has, but optional.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (8)
- examples/amd/example_amd_flash_attn_bwd.py (2 hunks)
- examples/amd/example_amd_flash_attn_fwd.py (3 hunks)
- examples/amd/test.sh (0 hunks)
- examples/flash_attention/example_mha_bwd.py (0 hunks)
- src/target/codegen_hip.cc (4 hunks)
- src/tl_templates/hip/common.h (1 hunks)
- src/tl_templates/hip/gemm.h (1 hunks)
- src/tl_templates/hip/hip_fp8.h (2 hunks)
💤 Files with no reviewable changes (2)
- examples/amd/test.sh
- examples/flash_attention/example_mha_bwd.py
🧰 Additional context used
🧬 Code graph analysis (3)
src/target/codegen_hip.cc (2)
tilelang/language/builtin.py (2)
loop_break
(419-422)no_set_max_nreg
(160-163)tilelang/language/customize.py (1)
loop_break
(67-73)
src/tl_templates/hip/hip_fp8.h (1)
src/tl_templates/cuda/cuda_fp8.h (1)
fp8_e5_2_t
(44-47)
examples/amd/example_amd_flash_attn_bwd.py (4)
examples/amd/example_amd_flash_attn_fwd.py (4)
ref_program
(11-28)fast_flashattn
(68-205)main
(97-203)main
(208-238)examples/attention_sink/example_mha_sink_fwd_bhsd.py (3)
ref_program
(191-236)main
(128-185)main
(253-312)examples/attention_sink/example_mha_sink_bwd_bhsd.py (8)
ref_program
(399-444)main
(447-514)get_bwd_configs
(11-19)flashattn_bwd_preprocess
(145-170)flash_bwd_prep
(151-168)flashattn_bwd
(210-322)flashattn_bwd_postprocess
(185-202)backward
(367-391)tilelang/tileop/gemm/gemm_base.py (3)
K
(42-43)k_pack
(111-112)policy
(119-120)
🪛 GitHub Actions: CI Test on Metal
examples/amd/example_amd_flash_attn_bwd.py
[error] 1-1: Reformatted files. Please review and stage the changes.
🪛 Ruff (0.13.3)
examples/amd/example_amd_flash_attn_bwd.py
257-257: Ambiguous variable name: O
(E741)
517-517: Unpacked variable dq_mean_diff
is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
524-524: Unpacked variable dk_mean_diff
is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
531-531: Unpacked variable dv_mean_diff
is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
🔇 Additional comments (10)
examples/amd/example_amd_flash_attn_fwd.py (2)
37-37
: LGTM! Autotune search space expanded.Adding
num_stages = [0, 1]
increases the configuration search space, allowing the autotuner to explore pipelined and non-pipelined execution paths for better performance.
86-86
: Consistentexp
usage verified. All AMD attention examples useT.exp
withscale = (1.0/dim)**0.5
; noT.exp2
or hardcoded1.44269504
found.examples/amd/example_amd_flash_attn_bwd.py (6)
1-11
: LGTM! Imports are appropriate.The imports include necessary dependencies for the backward pass implementation, including
numpy
for tensor comparison utilities andtime
for benchmarking.
13-31
: LGTM! Reference implementation returns LSE.The reference implementation now returns both the output and log-sum-exp (LSE), which is required for the backward pass. The variable naming uses
K_ref
andV_ref
for clarity after therepeat_interleave
operation.
182-198
: Guard against infinity before exponentiation.The code correctly checks for
-T.infinity
values before callingT.exp
, preventing numerical overflow and NaN propagation:if m_prev[i] == -T.infinity(accum_dtype): scale_factor[i] = 0.0 else: scale_factor[i] = T.exp(m_prev[i] - m_i[i])This defensive pattern is also applied at lines 194-197. Good practice for numerical stability.
276-365
: LGTM! Backward kernel uses atomic operations correctly.The backward pass implementation:
- Correctly computes dV and dK accumulations per block
- Uses
T.atomic_add
for dQ accumulation (line 359) to handle race conditions from multiple blocks writing to the same dQ elements- Uses
T.atomic_add
for final dV and dK writes (lines 362-363) to handle multi-group scenariosThe use of
accum_dtype
(float32) for dQ accumulation buffer followed by postprocessing to float16 is a sound approach for numerical stability.
386-425
: Comprehensive debug utility.The
debug_tensor_comparison
function provides detailed diagnostics including max/mean differences, relative errors, NaN/Inf counts, and per-element analysis. This will be valuable for debugging numerical issues during development.
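For readers following along without the full diff, here is a rough sketch of the kind of helper being described; this is an illustration only, not the PR's actual implementation:
import torch

def debug_tensor_comparison_sketch(got, ref, name, rtol=0.05, atol=0.05):
    # Summarize how far two tensors are apart and whether they count as "close".
    diff = (got.float() - ref.float()).abs()
    close = torch.allclose(got.float(), ref.float(), rtol=rtol, atol=atol)
    print(f"{name}: max diff {diff.max().item():.4e}, mean diff {diff.mean().item():.4e}, "
          f"NaNs {torch.isnan(got).sum().item()}, Infs {torch.isinf(got).sum().item()}")
    return close, diff.max().item(), diff.mean().item()
The three-value return mirrors the (close, max_diff, mean_diff) tuples unpacked in the example's verification code.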
446-588
: Comprehensive test harness with proper verification.The main function provides:
- Forward pass autotuning and verification against PyTorch reference
- Backward pass autotuning with preprocessing and postprocessing
- Detailed gradient verification with configurable tolerances
- Performance benchmarking with proper synchronization
- Clear pass/fail reporting
The integration of all components (forward, backward, pre/post-processing) demonstrates a complete end-to-end workflow.
src/target/codegen_hip.cc (1)
966-973
: loop_break and no_set_max_nreg lowering: LGTMEmitting break; and treating no_set_max_nreg() as a no‑op on HIP is fine.
Ensure these intrinsics are only used in statement context (Evaluate) to avoid dangling expressions in value contexts.
src/tl_templates/hip/hip_fp8.h (1)
8-14
: E5 FP8 aliases: LGTM
Adds E5M2 aliases consistent with HIP’s fnuz variants.
Confirm target HIP SDK provides __hip_fp8_e5m2_fnuz and __hip_fp8x2_e5m2_fnuz on all CI targets.
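If some CI images might lack those types, a guarded definition along these lines could fail loudly at compile time; the HIP_FP8_TYPE_FNUZ macro is an assumption about the ROCm headers, so substitute whatever feature macro the target SDK actually provides:
// Sketch only: gate the E5M2 aliases on a feature macro instead of assuming
// every HIP SDK ships the fnuz types.
#if defined(HIP_FP8_TYPE_FNUZ) && HIP_FP8_TYPE_FNUZ
using fp8_e5_t = __hip_fp8_e5m2_fnuz;
using fp8_e5_2_t = __hip_fp8x2_e5m2_fnuz;
#else
#error "This HIP SDK does not provide __hip_fp8_e5m2_fnuz; disable E5M2 support"
#endif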
def get_fwd_configs():
    block_M = [32, 64, 128, 256]
    block_N = [32, 64, 128, 256]
    threads = [128, 256, 512]
    num_split_q = [64, 128, 256]
    num_stages = [0, 1]
    enable_rasterization = [True]
    k_pack = [2]
    panel_size = [7, 8, 9, 10]
    qk_coalesced_width = [8]
    v_coalesced_width = [4]

    valid_configs = []

    for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product(block_M, block_N, num_split_q,
                                                                  threads, num_stages,
                                                                  enable_rasterization, k_pack,
                                                                  panel_size, qk_coalesced_width,
                                                                  v_coalesced_width):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "num_split_q": s,
            "threads": t,
            "num_stages": stages,
            "enable_rasterization": r,
            "k_pack": k,
            "panel_size": p,
            "qk_coalesced_width": qkw,
            "v_coalesced_width": vw,
        })
    return valid_configs
🛠️ Refactor suggestion | 🟠 Major
🧩 Analysis chain
Verify configuration consistency with forward pass.
The get_fwd_configs()
function generates configurations for forward pass autotuning. Note that panel_size
has been expanded to [7, 8, 9, 10]
compared to the forward-only file.
Ensure these configurations align with the standalone forward example to maintain consistency across the codebase:
🏁 Script executed:
#!/bin/bash
# Description: Compare forward configs between fwd-only and fwd-bwd examples
echo "=== Forward-only example config ==="
rg -A 10 'def get_configs' examples/amd/example_amd_flash_attn_fwd.py
echo ""
echo "=== Forward-backward example config ==="
rg -A 10 'def get_fwd_configs' examples/amd/example_amd_flash_attn_bwd.py
Length of output: 758
Align panel_size between examples. get_fwd_configs()
uses [7, 8, 9, 10], but get_configs()
uses [7, 8]; update one to match the other.
🤖 Prompt for AI Agents
In examples/amd/example_amd_flash_attn_bwd.py around lines 34 to 65, panel_size
values differ between get_fwd_configs() ([7, 8, 9, 10]) and get_configs() ([7,
8]); make them consistent by updating the get_configs() panel_size list to match
get_fwd_configs() (use [7, 8, 9, 10]) so both functions produce the same
panel_size options.
} else if (type.code() == DataType::kFloat8_e4m3) {
  stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e4m3b11fnuz) {
  stream << "fp8_e4" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e5m2) {
  stream << "fp8_e5" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e5m2fnuz) {
  stream << "fp8_e5" << vec << "_t";
} else if (type.code() == DataType::kFloat8_e8m0fnu) {
  stream << "fp8_e8" << vec << "_t";
} else {
- LOG(FATAL) << "Unsupported FP8 type in HIP codegen";
+ LOG(FATAL) << "Unsupported FP8 type in HIP codegen: " << type;
}
🧩 Analysis chain
FP8 E8 mapping unsupported by header; and FP8 E5 constants not handled
- GetFP8Type maps kFloat8_e8m0fnu → fp8_e8*_t, but hip_fp8.h explicitly comments E8 types as unsupported. This will fail to compile when used.
- E5 mapping is added here, but PrintConst lacks E5 handling, leading to LOG(FATAL) for FloatImm FP8 E5.
Apply:
- } else if (type.code() == DataType::kFloat8_e8m0fnu) {
- stream << "fp8_e8" << vec << "_t";
- } else {
+ } else {
LOG(FATAL) << "Unsupported FP8 type in HIP codegen: " << type;
And extend PrintConst (see below comment) to support E5 FP8 constants.
Check if any E8 paths are exercised; if so, either add E8 types to hip_fp8.h guarded by feature detection or drop E8 emission entirely.
🏁 Script executed:
#!/bin/bash
set -e
# Locate hip_fp8.h
HIP_FILE=$(fd hip_fp8.h | head -n1)
echo "=== hip_fp8.h ($HIP_FILE) context ==="
rg -n -C2 'E8' "$HIP_FILE"
rg -n -C2 'unsupported' "$HIP_FILE"
# Inspect PrintConst in codegen_hip.cc for FP8 handling
echo "=== PrintConst FP8 cases in codegen_hip.cc ==="
rg -n -C3 'PrintConst' -m1 src/target/codegen_hip.cc | sed -n '1,20p'
rg -n -C3 'kFloat8' src/target/codegen_hip.cc
Length of output: 408
🏁 Script executed:
#!/bin/bash
set -e
# Locate PrintConst in the repository
echo "=== PrintConst definitions/usages ==="
rg -n 'PrintConst' -C3 .
# Show all FP8 kFloat8 cases with context in codegen_hip.cc
echo "=== FP8 cases in codegen_hip.cc ==="
rg -n -C10 'kFloat8' src/target/codegen_hip.cc
Length of output: 5602
🏁 Script executed:
#!/bin/bash
set -e
sed -n '1171,1271p' src/target/codegen_hip.cc
Length of output: 2899
Remove unsupported FP8 E8 case and add FP8 E5 constant support
- Drop the
kFloat8_e8m0fnu
branch sincehip_fp8.h
(lines 11–13) comments out E8 aliases, causing build failures. - Extend
PrintConst
to handle E5 constants:if (op->dtype.is_float8_e5m2() || op->dtype.is_float8_e5m2fnuz()) { os << "fp8_e5_t"; os << '(' << std::scientific << op->value << 'f' << ')'; return; }
- } else if (type.code() == DataType::kFloat8_e8m0fnu) {
- stream << "fp8_e8" << vec << "_t";
- } else {
+ } else {
LOG(FATAL) << "Unsupported FP8 type in HIP codegen: " << type;
🤖 Prompt for AI Agents
In src/target/codegen_hip.cc around lines 44 to 56, remove the branch that
handles DataType::kFloat8_e8m0fnu (delete the else-if that streams "fp8_e8...")
because HIP headers comment out E8 aliases and this causes build failures;
additionally, extend the PrintConst handling for FP8 constants to detect e5
types (is_float8_e5m2 or is_float8_e5m2fnuz) and emit the constant as "fp8_e5_t"
constructed from the literal value (i.e., print "fp8_e5_t(...f)"), ensuring the
same formatting style as other float constant code paths so E5 constants are
supported.
} else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() ||
           op->dtype.is_float8_e4m3fn()) {
  os << "fp8_e4_t";
  os << '(' << std::scientific << op->value << 'f' << ')';
Add FP8 E5 constant emission to avoid fatal path
Currently only E4 FP8 constants are emitted. Add E5 branch to PrintConst.
- } else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() ||
- op->dtype.is_float8_e4m3fn()) {
+ } else if (op->dtype.is_float8_e4m3fnuz() || op->dtype.is_float8_e4m3() ||
+ op->dtype.is_float8_e4m3fn()) {
os << "fp8_e4_t";
os << '(' << std::scientific << op->value << 'f' << ')';
return;
+ } else if (op->dtype.is_float8_e5m2fnuz()
+ /* add is_float8_e5m2() if available in your DataType */) {
+ os << "fp8_e5_t";
+ os << '(' << std::scientific << op->value << 'f' << ')';
+ return;
}
🤖 Prompt for AI Agents
In src/target/codegen_hip.cc around lines 1178 to 1181, the PrintConst handling
only emits FP8 E4 constants; add a parallel branch to handle FP8 E5 types. Add
an else-if after the existing E4 check that tests the E5 dtype predicates (the
analogous is_float8_e5... methods), and emit "fp8_e5_t" with the same formatting
as E4 (e.g. os << "fp8_e5_t"; os << '(' << std::scientific << op->value << 'f'
<< ')';). Ensure the new branch mirrors the E4 branch placement and formatting
so FP8 E5 constants are emitted instead of falling into the fatal path.
// Overload for when the first argument is a value instead of a pointer
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 address, T2 val) {
  atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
}

template <typename T1, typename T2> TL_DEVICE T1 AtomicAddRet(T1 &ref, T2 val) {
  return atomicAdd(&ref, static_cast<T1>(val));
}
Invalid AtomicAdd overload takes a value; atomics must target real memory
AtomicAdd(T1 address, T2 val)
atomically updates a local copy via &address
, not the intended memory; this is incorrect and can cause undefined behavior on HIP.
Remove this overload and, if needed, provide a reference overload (non‑returning) alongside AtomicAddRet:
-// Overload for when the first argument is a value instead of a pointer
-template <typename T1, typename T2>
-TL_DEVICE void AtomicAdd(T1 address, T2 val) {
- atomicAdd(reinterpret_cast<T1 *>(&address), static_cast<T1>(val));
-}
+// Preferred non-returning reference overload (mirrors AtomicAddRet)
+template <typename T1, typename T2>
+TL_DEVICE void AtomicAdd(T1& ref, T2 val) {
+ atomicAdd(&ref, static_cast<T1>(val));
+}
Optionally constrain T1 to HIP‑supported atomic types via static_assert.
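A sketch of that constraint layered on the reference overload; the accepted type list is an assumption and should be trimmed or extended to match what HIP's atomicAdd supports on the targets you build for:
#include <type_traits>

// Reject unsupported element types at compile time instead of failing at runtime.
template <typename T1, typename T2>
TL_DEVICE void AtomicAdd(T1 &ref, T2 val) {
  static_assert(std::is_same<T1, int>::value || std::is_same<T1, unsigned int>::value ||
                    std::is_same<T1, unsigned long long>::value ||
                    std::is_same<T1, float>::value || std::is_same<T1, double>::value,
                "AtomicAdd: T1 is not a type supported by HIP atomicAdd");
  atomicAdd(&ref, static_cast<T1>(val));
}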
🤖 Prompt for AI Agents
In src/tl_templates/hip/common.h around lines 113 to 121, the overload
AtomicAdd(T1 address, T2 val) takes the address by value and does atomicAdd on
&address (a local copy) which is invalid; remove this value-taking overload,
replace with a non-returning reference overload AtomicAdd(T1 &address, T2 val)
that calls atomicAdd(&address, static_cast<T1>(val)), and (optionally) add a
static_assert or type trait to constrain T1 to HIP-supported atomic types to
catch misuse at compile time.
// Note: clear_accum=true is not fully supported in HIP implementation
// but we'll handle it by manually clearing the accumulator
// static_assert(!clear_accum, "clear_accum=true is not supported yet");
Do not drop support guard without implementing clear_accum behavior
Removing the static_assert enables clear_accum=true but no code zeros accumulators, yielding nondeterministic results.
Add zeroing when clear_accum is true. Example insertion at the start of body() and body_rs() before the ki loop:
if constexpr (clear_accum) {
for (int i = 0; i < warp_rows; ++i) {
for (int j = 0; j < warp_cols; ++j) {
((float32x4*)C_local)[i * warp_cols + j] = 0;
}
}
}
Optionally re‑enable the static_assert if you cannot guarantee correct zeroing on all code paths.
🤖 Prompt for AI Agents
In src/tl_templates/hip/gemm.h around lines 73-75, the static_assert for
clear_accum was removed but no accumulator zeroing was added, so enabling
clear_accum=true yields nondeterministic results; add a constexpr guard at the
start of both body() and body_rs() (before the ki loop) that zeroes the per-warp
accumulator memory when clear_accum is true (iterate warp_rows and warp_cols and
set the corresponding C_local entries to zero, e.g., by casting C_local to the
appropriate vector type and writing zeros), and if you cannot guarantee zeroing
on all code paths re-enable the static_assert to prevent enabling clear_accum
without proper initialization.
@Alex4210987 would be great if we can add some test case for fp8 dtype for hip backend.
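A minimal sketch of what such a test might look like; the dtype string, the T.prim_func/T.Kernel structure, the tilelang.compile entry point, and the kernel call convention are all assumptions that would need to be aligned with the project's existing HIP test helpers:
import torch
import tilelang
import tilelang.language as T

def _fp8_copy_kernel(M=128, N=128, dtype="float8_e4m3fnuz"):

    @T.prim_func
    def main(A: T.Tensor((M, N), dtype), B: T.Tensor((M, N), dtype)):
        # One block per 32-row tile; copy the tile from A to B.
        with T.Kernel(T.ceildiv(M, 32), threads=128) as bx:
            T.copy(A[bx * 32:(bx + 1) * 32, 0:N], B[bx * 32:(bx + 1) * 32, 0:N])

    return main

def test_fp8_copy_hip():
    kernel = tilelang.compile(_fp8_copy_kernel())  # assumed compile entry point
    a = torch.randn(128, 128, device="cuda").to(torch.float8_e4m3fnuz)
    b = torch.empty_like(a)
    kernel(a, b)
    torch.testing.assert_close(b.float(), a.float())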