Commit e24a354

Merge branch 'site' into fix-previous-versions
2 parents 8f638ba + 75215d4 commit e24a354

30 files changed: +691 / -128 lines

_community_blog/datathon-2025.md (new file, +8)

@@ -0,0 +1,8 @@
---
title: "Solve Real-World AI Challenges with PyTorch at Datathon 2025: DataOrbit"
author: "Aakash Senthilnathan"
ext_url: /blog/datathon-2025/
date: Feb 12, 2025
---

**We’re excited to have PyTorch sponsor [Datathon 2025: DataOrbit](https://dataorbit-2025.devpost.com/)**, a place where students can collaborate with a team to solve problems using real-world datasets! This event, hosted by Data Science UCSB in collaboration with Gaucho Sports Analytics and ACM@UCSB, will take place on **February 22–23, 2025 at UC Santa Barbara**, with the incredible opportunity to present your project to a panel of corporate and faculty judges – **including the executive director of PyTorch!** – for a chance to win prizes up to $3000.

_events/ai-programming.md (+1 / -1)

@@ -16,4 +16,4 @@ In this talk, Anton will share how he built an AI agent that ranked #1 in the fi
 
 Anton Pidkuiko is a Software Engineer at Meta, Reality Labs in London. He is currently working on applying the power of Large Language Models to Metaverse Avatar product experiences.
 
-[Register now to join the event](/ai-powered-competitive-programming)
+[More info on this event.](/ai-powered-competitive-programming)

_events/pt-26-live-q-a.md (+1 / -1)

@@ -17,4 +17,4 @@ Nikita is a Software Engineer at Meta where he is, among other things, responsib
 
 Bring your PyTorch 2.6 questions for Nikita during this live Q&A session.
 
-[Register now to join the event](/pt-26-live-q-a)
+[More info on this event.](/pt-26-live-q-a)

_get_started/previous-versions.md (+12 / -12)

@@ -25,45 +25,45 @@ your convenience.
 
 ```
 # conda
-conda install pytorch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0 -c pytorch
+conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 -c pytorch
 ```
 
 ##### Linux and Windows
 
 ```
 # CUDA 11.8
-conda install pytorch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0 pytorch-cuda=11.8 -c pytorch -c nvidia
+conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=11.8 -c pytorch -c nvidia
 # CUDA 12.1
-conda install pytorch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0 pytorch-cuda=12.1 -c pytorch -c nvidia
+conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.1 -c pytorch -c nvidia
 # CUDA 12.4
-conda install pytorch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0 pytorch-cuda=12.4 -c pytorch -c nvidia
+conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 pytorch-cuda=12.4 -c pytorch -c nvidia
 # CPU Only
-conda install pytorch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0 cpuonly -c pytorch
+conda install pytorch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 cpuonly -c pytorch
 ```
 
 #### Wheel
 
 ##### OSX
 
 ```
-pip install torch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0
+pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1
 ```
 
 ##### Linux and Windows
 
 ```
 # ROCM 6.1 (Linux only)
-pip install torch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/rocm6.1
+pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/rocm6.1
 # ROCM 6.2 (Linux only)
-pip install torch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/rocm6.2
+pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/rocm6.2
 # CUDA 11.8
-pip install torch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu118
+pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu118
 # CUDA 12.1
-pip install torch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu121
+pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu121
 # CUDA 12.4
-pip install torch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cu124
+pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
 # CPU only
-pip install torch==2.5.1 torchvision==0.20.0 torchaudio==2.5.0 --index-url https://download.pytorch.org/whl/cpu
+pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cpu
 ```
 
 ### v2.5.0

_posts/2024-08-07-flexattention.md (+5 / -5)

@@ -1,7 +1,7 @@
 ---
 layout: blog_detail
 title: "FlexAttention: The Flexibility of PyTorch with the Performance of FlashAttention"
-author: "Team PyTorch: Horace He, Driss Guessous, Yanbo Liang, Joy Dong"
+author: "Team PyTorch: Driss Guessous, Yanbo Liang, Joy Dong, Horace He"
 ---
 
 ![a cartoon chart flexing his muscles](/assets/images/flexattention/fg1.jpg){:style="width:100%"}
@@ -131,7 +131,7 @@ Alibi is similar to relative positional encodings with one exception \- it has a
 alibi_bias = generate_alibi_bias() # [num_heads]
 
 def alibi(score, b, h, q_idx, kv_idx):
-    bias = alibi_bias[h] * (q_idx - kv_idx)
+    bias = alibi_bias[h] * (kv_idx - q_idx)
     return score + bias
 ```
 
@@ -218,12 +218,12 @@ def sliding_window_causal(b, h, q_idx, kv_idx):
     return causal_mask & window_mask
 
 # If you want to be cute...
-from torch.nn.attention import or_masks
+from torch.nn.attention import and_masks
 
 def sliding_window(b, h, q_idx, kv_idx)
     return q_idx - kv_idx <= SLIDING_WINDOW
 
-sliding_window_causal = or_masks(causal_mask, sliding_window)
+sliding_window_causal = and_masks(causal_mask, sliding_window)
 ```
 
 We benchmark it against `F.scaled_dot_product_attention` with a sliding window mask as well as FA2 with a causal mask (as a reference point for performance). Not only are we significantly faster than `F.scaled_dot_product_attention`, we’re *also* significantly faster than FA2 with a causal mask as this mask has significantly more sparsity.
@@ -479,4 +479,4 @@ We want to highlight some prior work (and people) that have inspired FlexAttenti
 - The Jax team's work on SplashAttention
 - Philippe Tillet and Keren Zhou for helping us with Triton
 - Ali Hassani for discussions on neighborhood attention
-- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
+- Everybody who's complained about attention kernels not supporting their favorite attention variant :)
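For context on the `and_masks` fix above, here is a minimal sketch (not part of this commit) of how the corrected sliding-window-causal mask is consumed. It assumes the PyTorch 2.5+ FlexAttention API in `torch.nn.attention.flex_attention`; shapes and the window size are illustrative:

```
import torch
from torch.nn.attention.flex_attention import flex_attention, create_block_mask, and_masks

SLIDING_WINDOW = 1024

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

def sliding_window(b, h, q_idx, kv_idx):
    return q_idx - kv_idx <= SLIDING_WINDOW

# and_masks keeps a (q, kv) position only if BOTH predicates allow it,
# which is what a sliding-window causal mask requires.
sliding_window_causal = and_masks(causal_mask, sliding_window)

B, H, S, D = 1, 8, 4096, 64
q, k, v = (torch.randn(B, H, S, D, device="cuda", dtype=torch.float16) for _ in range(3))

# B=None / H=None broadcast the mask over batch and heads.
block_mask = create_block_mask(sliding_window_causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)  # optionally wrap flex_attention in torch.compile
```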
New file (+112)

@@ -0,0 +1,112 @@
---
layout: blog_detail
title: "Enabling advanced GPU features in PyTorch - Warp Specialization"
author: "Meta and NVIDIA"
---

**Meta**: Hongtao Yu, Manman Ren, Bert Maher, Shane Nay
**NVIDIA**: Gustav Zhu, Shuhao Jiang

Over the past few months, we have been working on enabling advanced GPU features for PyTorch and Triton users through the Triton compiler. One of our key goals has been to introduce warp specialization support on NVIDIA Hopper GPUs. Today, we are thrilled to announce that our efforts have resulted in the rollout of fully automated Triton warp specialization, now available to users in the upcoming release of Triton [3.2](https://github.com/triton-lang/triton/tree/release/3.2.x), which will ship with PyTorch 2.6. PyTorch users can leverage this feature by [implementing user-defined Triton kernels](https://pytorch.org/tutorials/recipes/torch_compile_user_defined_triton_kernel_tutorial.html). This work leveraged an initial implementation of warp specialization in Triton by NVIDIA and we look forward to further development with the community in the future.
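As a quick illustration of the user-defined-Triton-kernel path mentioned above, here is a minimal sketch (not taken from the post; `add_kernel` and `triton_add` are hypothetical names): a plain `@triton.jit` kernel wrapped in a Python launcher and called from a `torch.compile`'d function.

```
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Each program handles one contiguous block of elements.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def triton_add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)
    add_kernel[grid](x, y, out, n, BLOCK_SIZE=1024)
    return out

@torch.compile(fullgraph=True)
def fused(x, y):
    # torch.compile traces through the user-defined Triton kernel call.
    return triton_add(x, y) * 2

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
torch.testing.assert_close(fused(x, y), (x + y) * 2)
```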
Warp specialization (WS) is a GPU programming technique where warps (a group of 32 threads on NVIDIA GPUs) within a threadblock are assigned distinct roles or tasks. This approach optimizes performance by enabling efficient execution of workloads that require task differentiation or cooperative processing. It enhances kernel performance by leveraging an asynchronous execution model, where different parts of the kernel are managed by separate hardware units. Data communication between these units, facilitated via shared memory on the NVIDIA H100, is highly efficient. Compared to a uniform warp approach, warp specialization allows the hardware multitasking warp scheduler to operate more effectively, maximizing resource utilization and overall performance.

Using GEMM as an example, a typical uniform warp approach on the H100 GPU involves 8 warps per thread block collectively computing a tile of the output tensor. These 8 warps are divided into two warp groups (WG), with each group cooperatively computing half of the tile using efficient warp-group-level MMA (WGMMA) instructions, as illustrated in Figure 1.

![Figure 1. GEMM K-loop Body with Uniform Warps](/assets/images/warp-specialization/fg1.jpg){:style="width:100%"}

Figure 1. GEMM K-loop Body with Uniform Warps

The implementation is clean, easy to understand, and generally performs well, thanks to an elegant software pipeliner. The pipeliner's purpose is to enhance instruction-level parallelism by executing non-dependent operations on different hardware units. For instance, load operations from the next loop iteration can be executed simultaneously with WGMMA operations in the current iteration. However, this approach relies heavily on the compiler to craft an instruction sequence that ensures load and WGMMA instructions are issued at precisely the right time. While this is relatively straightforward for GEMM, which involves a limited number of operations, it becomes significantly more challenging for more complex kernels, such as flash attention.

On the other hand, warp specialization addresses programming challenges by separating operations intended to run simultaneously on different hardware units into distinct warps, synchronizing them efficiently using low-cost barriers in shared memory. This allows each warp to have its own instruction sequence, enabling instructions to be issued and executed continuously without being interrupted by other operations, thanks to the multi-way warp scheduler. An illustration of a warp-specialized GEMM can be seen in Figure 2.

![Figure 2. GEMM K-loop Body with Specialized Warps](/assets/images/warp-specialization/fg2.jpg){:style="width:100%"}

Figure 2. GEMM K-loop Body with Specialized Warps
## How to enable WS

To enable warp specialization, users simply need to specify two autotune flags: `num_consumer_groups` and `num_buffers_warp_spec`. For example, a warp-specialized GEMM implementation might look as shown below. Users can enable warp specialization by setting a non-zero value for `num_consumer_groups`, which defines the number of consumer warp groups. There is no corresponding flag to set the number of producer warp groups, as currently only one producer is supported. The `num_buffers_warp_spec` flag specifies the number of buffers the producer warp group will use to communicate with the consumer warp groups. A working example of a warp-specialized kernel is provided in the persistent GEMM [tutorial](https://github.com/triton-lang/triton/blob/6771065cb3137f7e64454cc047b9b74d577cbf7f/python/tutorials/09-persistent-matmul.py#L620).
```
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(
            {
                # Config keys must match the kernel's constexpr parameter names below.
                "BLOCK_M": 128,
                "BLOCK_N": 256,
                "BLOCK_K": 64,
                "GROUP_SIZE_M": 8,
            },
            num_stages=2,
            num_warps=4,
            num_consumer_groups=2,
            num_buffers_warp_spec=3,
        ),
    ],
    key=["M", "N", "K"],
)
@triton.jit
def matmul_persistent_ws_kernel(
    a_ptr, b_ptr, c_ptr, M, N, K,
    stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,  # accepted from the config; unused in this simplified kernel
):
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(M, BLOCK_M)
    num_pid_n = tl.cdiv(N, BLOCK_N)
    pid_m = pid // num_pid_n  # row-major mapping over the (num_pid_m x num_pid_n) tile grid
    pid_n = pid % num_pid_n
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_K)):
        a = tl.load(a_ptrs)
        b = tl.load(b_ptrs)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk
    c = acc.to(tl.float16)
    c_ptrs = c_ptr + stride_cm * offs_m[:, None] + stride_cn * offs_n[None, :]
    tl.store(c_ptrs, c)
```
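For orientation, here is a minimal launch sketch (not part of the post; the tensor shapes and fp16 dtype are our assumptions, and it requires a Triton build with the warp-specialization flags, i.e. Triton 3.2 as shipped with PyTorch 2.6). The autotuner selects the config, so the grid callable only needs the chosen block sizes:

```
import torch
import triton

M, N, K = 4096, 4096, 4096
a = torch.randn((M, K), device="cuda", dtype=torch.float16)
b = torch.randn((K, N), device="cuda", dtype=torch.float16)
c = torch.empty((M, N), device="cuda", dtype=torch.float16)

# One program per output tile; the autotuner substitutes BLOCK_M / BLOCK_N.
grid = lambda META: (triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),)

matmul_persistent_ws_kernel[grid](
    a, b, c, M, N, K,
    a.stride(0), a.stride(1),
    b.stride(0), b.stride(1),
    c.stride(0), c.stride(1),
)

torch.testing.assert_close(c, a @ b, rtol=1e-2, atol=1e-2)
```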
## Under the Hood

Warp specialization uses a set of hierarchical compiler transformations and IR changes to translate a user's non-warp-specialized kernel into warp-specialized machine code. These include:

* **Task Partitioning**: The entire kernel is automatically divided into asynchronous tasks based on predefined heuristics. The compiler determines how to utilize one producer warp group and a user-specified number of consumer warp groups to execute the kernel. It assigns task IDs to specific anchor operations, which then influence the task assignments for remaining operations through asynchronous task ID propagation and dependency analysis. Since shared memory is the most efficient method for data transfer between warp groups across all supported platforms, the compiler optimizes task partitions to minimize register spills to shared memory, ensuring efficient execution.
* **Data Partitioning for Multiple Consumer Groups**: Efficiently partitioning data among multiple consumer groups is key to optimizing workload distribution. On the H100 GPU, the compiler, by default, attempts to partition the input tensor `A` along the `M` dimension, allowing each consumer group to compute half of the output tensor independently. This strategy, known as [cooperative partitioning](https://github.com/NVIDIA/cutlass/blob/main/media/docs/efficient_gemm.md#warp-specialization), maximizes efficiency under most conditions. However, if this split leads to inefficiencies—such as producing a workload smaller than the native WGMMA instruction size—the compiler dynamically adjusts and partitions along the `N` dimension instead (see the sketch after this list).
* **Dataflow Pipelining**: The compiler creates cyclic shared memory buffers to pipeline dataflows across multi-dimensional loops. Warp-specialized pipelining supports complex control flow. For example, our warp-specialized persistent GEMM kernel uses a doubly-nested loop, allowing the producer to begin fetching data for the next output tile while the consumer is finishing the compute for the prior tile.
* **Communication Operations**: We introduced four high-level Triton GPU IR (TTGIR) communication operations—`ProducerAcquireOp`, `ProducerCommitOp`, `ConsumerWaitOp`, and `ConsumerReleaseOp`—to manage pipelined dataflows. These support both TMA and non-TMA memory operations.
* **Code Partitioning**: Each async task is outlined into its own standalone code region, guarded by warp group ID checks. Control dependencies are duplicated as needed.
* **TTGIR to LLVM/PTX Materialization**: TTGIR communication operations are materialized into corresponding LLVM/PTX barrier operations.
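To make the cooperative-partitioning choice in the data-partitioning bullet concrete, here is a toy sketch (our simplification, not the compiler's actual heuristic; it assumes 64 as the native WGMMA tile size along `M` on H100):

```
# Toy model of the M-vs-N split decision for consumer warp groups.
WGMMA_M = 64  # assumed native WGMMA instruction size along M

def choose_split(block_m, block_n, num_consumer_groups=2):
    """Pick the dimension each consumer group's sub-tile is carved from."""
    if block_m // num_consumer_groups >= WGMMA_M:
        return ("M", block_m // num_consumer_groups, block_n)
    return ("N", block_m, block_n // num_consumer_groups)

print(choose_split(128, 256))  # ('M', 64, 256): each group computes half the rows
print(choose_split(64, 256))   # ('N', 64, 128): an M-split would drop below the WGMMA size
```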
## Performance

The [warp specialization release](https://github.com/triton-lang/triton/pull/5622) introduces a range of Triton compiler transformations that collectively convert user code into warp-specialized kernels. This feature has been applied to several key kernels, including Flash Attention and FP8 row-wise GEMM, resulting in significant performance gains of 10% to 15%. Below, we highlight the latest performance metrics for these high-impact kernels.

![bar chart](/assets/images/warp-specialization/fg3.png){:style="width:100%"}

![bar chart](/assets/images/warp-specialization/fg4.png){:style="width:100%"}

## Future Work

Looking ahead, we plan to further enhance Triton's warp specialization support by introducing new features such as Ping-Pong scheduling, expanded buffer sharing support, improved transparent handling for TMA, and refined partitioning heuristics for upcoming NVIDIA hardware.
