
request for faster inductor kernels for blockwise reduction across dim1 -> write #149982

Open
vkuzo opened this issue Mar 25, 2025 · 0 comments
Assignees
Labels
module: inductor, oncall: pt2, triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments

vkuzo commented Mar 25, 2025

🐛 Describe the bug

We should make the following kernel fast with torch.compile + Inductor. This is important for being able to generate the dim1 cast to MX formats.

from typing import Tuple

import torch


def scale_dim1_reference(x_hp: torch.Tensor, block_size) -> Tuple[torch.Tensor, torch.Tensor]:
    # normalize across dim1: transpose so that each scaling block of
    # `block_size` elements is contiguous, then reduce within each block
    x_hp_d1 = x_hp.t().contiguous()
    x_hp_d1_block = x_hp_d1.reshape(-1, block_size)
    x_hp_d1_block_abs = x_hp_d1_block.abs()
    amax_dim1 = torch.amax(x_hp_d1_block_abs, dim=1).unsqueeze(1)
    x_hp_d1_block_normalized = x_hp_d1_block / amax_dim1
    x_hp_d1_normalized = x_hp_d1_block_normalized.reshape(x_hp_d1.shape)
    return x_hp_d1_normalized.t(), amax_dim1

Currently, I am only hitting 0.6 to 0.7 TB/s on NVIDIA H100. If the reduction and write are across dim0 instead of dim1, I see 2.0-2.2 TB/s. From discussions with @eellison, this is due to uncoalesced reads, and we can fix this.
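
For reference, a minimal sketch of exercising the function above under torch.compile, plus a hypothetical no-transpose counterpart illustrating the fast contiguous case mentioned above (the bfloat16 dtype, tensor sizes, and the scale_dim0_reference helper are assumptions for illustration; the repro gist linked below is the authoritative benchmark):

# assumed input: same size as the first repro run below, bfloat16 dtype
x = torch.randn(4096, 4096, device="cuda", dtype=torch.bfloat16)

# the reduction pattern from this issue, compiled with inductor
scale_dim1_compiled = torch.compile(scale_dim1_reference)
out_d1, amax_d1 = scale_dim1_compiled(x, 32)

# hypothetical no-transpose counterpart: blocks of 32 are contiguous in memory,
# so reads and writes coalesce (the ~2.0-2.2 TB/s case described above)
def scale_dim0_reference(x_hp: torch.Tensor, block_size: int):
    x_hp_block = x_hp.reshape(-1, block_size)
    amax_dim0 = torch.amax(x_hp_block.abs(), dim=1).unsqueeze(1)
    return (x_hp_block / amax_dim0).reshape(x_hp.shape), amax_dim0

out_d0, amax_d0 = torch.compile(scale_dim0_reference)(x, 32)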

Repro script: https://gist.github.com/vkuzo/9eff0d27691be483e45bb10edf66d82c
Repro results on NVIDIA H100:

(pytorch) [[email protected] ~/local/pytorch_scripts/mx_cast_poc (20250325_dim1_cast)]$ python 20250325_dim1_cast.py --M 4096 --K 4096
M 4096 K 4096 BLOCK_SIZE 32
GPU: NVIDIA H100
torch version: 2.8.0a0+gitdd94e94
triton version: 3.2.0
time_reference_compile_us 107.69072608695663
mem_bw_gbps 632.8998092645895
(pytorch) [[email protected] ~/local/pytorch_scripts/mx_cast_poc (20250325_dim1_cast)]$ python 20250325_dim1_cast.py --M 16384 --K 16384
M 16384 K 16384 BLOCK_SIZE 32
GPU: NVIDIA H100
torch version: 2.8.0a0+gitdd94e94
triton version: 3.2.0
time_reference_compile_us 1612.7510689655173
mem_bw_gbps 676.1855942836252
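
As a sanity check on these numbers, mem_bw_gbps presumably corresponds to total bytes moved divided by kernel time; a rough back-of-the-envelope estimate, assuming bfloat16 inputs (one read and one write of the full tensor, ignoring the much smaller amax output):

M = K = 4096
bytes_per_elem = 2                         # assuming bfloat16
bytes_moved = 2 * M * K * bytes_per_elem   # one read + one write of the full tensor
time_s = 107.69e-6                         # time_reference_compile_us from the first run
print(bytes_moved / time_s / 1e9)          # ~623 GB/s, in line with the reported ~633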

TORCH_LOGS=output_code results: https://gist.github.com/vkuzo/4420c5b508ddd560e5d4620758b5936a

Versions

main branch

cc @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @aakhundov

@jansel added the triaged and module: inductor labels Mar 29, 2025
eellison added a commit that referenced this issue Apr 28, 2025
Fix for #149982. 

Summary:

This PR does two main things:
1. Rewrites the tiling heuristics. The previous tiling heuristic would have each dependency generate a tiling; we would then sum up the score for each generated tiling, preferring any 2d tiling over the default. The new tiling heuristic scores each tiling by its globally coalesced memory accesses. This gives both a potentially better tiling (especially for more complicated, 3d patterns) and information we can use in generating block sizes.

2. Analyses memory dependencies for accesses that would become coalesced with additional tiling. The motivating kernel is in #149982, which is a blockwise reduction over 32 elements. A smaller version of it is [here](https://gist.github.com/eellison/0fa9396f5479eb4dba09756e3bf6ff2a). We need to run this kernel once per linear layer in the forward on a contiguous tensor, and once in the backward on a transposed tensor.

While the contiguous kernel has coalesced accesses and is performant on main, the transposed version accesses uncoalesced memory on main and is ~2.8x slower; see this [full log](https://gist.github.com/eellison/fa644bfd9d0ae11dadb62e17a5d48a83) from the above repro. Now, with this PR, it is only ~1.15x slower; see the [updated log](https://gist.github.com/eellison/0b2b653309494d28cf7b48929a022075).

We analyse memory addresses that are not coalesced by any iteration variable. For the following dependency:

`(((32*n0 + n1)//2048)) + 4096*(ModularIndexing(32*n0 + n1, 1, 2048))`, we infer that tiling `n0` by 64 makes the first term coalesced.
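
To make that inference concrete, a small standalone sketch (not Inductor code) evaluating the index expression above, treating ModularIndexing(x, 1, 2048) as x % 2048:

def addr(n0, n1):
    x = 32 * n0 + n1
    return x // 2048 + 4096 * (x % 2048)

# stepping n0 directly moves the address by 131072 -> uncoalesced in n0
print([addr(n0, 0) for n0 in range(4)])     # [0, 131072, 262144, 393216]
# stepping the outer index of an n0-tile of size 64 moves it by 1 -> coalesced
print([addr(64 * t, 0) for t in range(4)])  # [0, 1, 2, 3]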

I'm sure there are still some CI failures to debug.

cc vkuzo 


cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov 

[ghstack-poisoned]