Summary
To reproduce:

```shell
TRITON_ALWAYS_COMPILE=1 TRITON_DUMP_DIR=my_directory_2 TRITON_KERNEL_DUMP=1 pytest -s -v test/prototype/mx_formats/test_custom_cast.py -k "test_fp4_triton_unscaled_cast"
```
TTIR (bad on the left, good on the right; no real differences):
https://www.diffchecker.com/ueX5YZw4
TTGIR:
https://www.diffchecker.com/M5PS6QJg/
PTX (this is where the differences show up):
https://www.diffchecker.com/8mseNnKA/
Activity
(Title changed from "New Pytorch Triton breaks custom cast kernel" to "New Pytorch Triton breaks custom cast kernel MX")

CliveUnger commented on Mar 20, 2025
I started to take a look at this and found that it fails on both Hopper and Blackwell machines. I narrowed it down to a single culprit commit in Triton where bf16 op lowering is offloaded to LLVM instead of a custom code conversion.
Here is the commit and related PR:
Looking closer at the kernel, it seems that the non-denormal exponents are not being biased correctly. Currently I'm trying to understand what is going wrong in the vectorized LLVM code.
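For anyone following along, here is a minimal reference sketch of what "biasing the exponent" means for this cast. It assumes the fp4 target is the standard MX e2m1 layout (1 sign bit, 2 exponent bits with bias 1, 1 mantissa bit); it is an illustration of the expected math, not the Triton kernel's actual code path:

```python
import math

# e2m1 fp4: sign(1) | exponent(2, bias=1) | mantissa(1)
# Normal values:   (1 + m/2) * 2^(e_biased - 1)  for e_biased in {1, 2, 3}
# Denormal values: (m/2)                          when e_biased == 0
E2M1_BIAS = 1

def e2m1_encode(x: float) -> int:
    """Encode an exactly-representable float as a 4-bit e2m1 value."""
    sign = 1 if x < 0 else 0
    x = abs(x)
    if x < 1.0:
        # Denormal range: only 0.0 and 0.5 are representable.
        e_biased = 0
        m = int(round(x * 2))
    else:
        e_unbiased = int(math.floor(math.log2(x)))
        e_biased = e_unbiased + E2M1_BIAS   # <-- the biasing step in question
        m = int(round((x / 2 ** e_unbiased - 1) * 2))
    return (sign << 3) | (e_biased << 1) | m

def e2m1_decode(bits: int) -> float:
    """Decode a 4-bit e2m1 value back to a float."""
    sign = -1.0 if bits & 0b1000 else 1.0
    e_biased = (bits >> 1) & 0b11
    m = bits & 0b1
    if e_biased == 0:
        return sign * (m / 2)
    return sign * (1 + m / 2) * 2 ** (e_biased - E2M1_BIAS)
```

If the bias is dropped or applied twice on the normal path, every non-denormal value lands one binade off, which would match the symptom described above.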
davidberard98 commented on Mar 24, 2025
@drisspg @CliveUnger @danielvegamyhre any ideas why this failure didn't show up in CI?
jerryzh168 commented on May 1, 2025
@drisspg is this fixed now?