
Conversation

@danielvegamyhre (Contributor) commented Aug 1, 2025

Summary

Rather than storing the expert weights pre-transposed (E, in_dim, out_dim), we should store them non-transposed (E, out_dim, in_dim) and then transpose before the grouped GEMM, for (1) compatible dims for the GEMM, and (2) the column-major memory layout required for the right operand of the grouped GEMM. This simple transpose (a metadata-only change) is much more efficient than the memory layout transformation that would otherwise be needed before every GEMM in fp8 (https://github.com/pytorch/ao/blob/6e941c87c4d9fb9a74e6f979dd522605c696ca42/torchao/prototype/moe_training/scaled_grouped_mm.py#L96).
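
To illustrate the point, here is a minimal sketch (not the PR's actual code; the shapes are arbitrary placeholders) showing that transposing the last two dims of a non-transposed `(E, out_dim, in_dim)` weight is a metadata-only change: the storage is untouched, and each expert matrix becomes column-major, which is the layout the grouped GEMM wants for its right operand.

```python
import torch

E, out_dim, in_dim = 4, 1024, 512

# Store the expert weights non-transposed: (E, out_dim, in_dim), contiguous / row-major.
w = torch.randn(E, out_dim, in_dim)

# Transpose the last two dims right before the grouped GEMM. This only swaps
# strides (a metadata change); no memory is copied or rearranged.
w_t = w.transpose(-2, -1)

print(w_t.shape)                       # (E, in_dim, out_dim) = torch.Size([4, 512, 1024])
print(w.stride(), w_t.stride())        # (524288, 512, 1) vs (524288, 1, 512)
print(w_t.is_contiguous())             # False: each expert matrix is now column-major,
                                       # the layout the grouped GEMM (e.g. torch._grouped_mm
                                       # in recent PyTorch) expects for its right operand
print(w_t.data_ptr() == w.data_ptr())  # True: same storage, no copy
```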

Eager Performance

Llama4 debug model with FSDP=8, using config:

    "debugmodel": TransformerModelArgs(
        dim=5120,
        n_layers=4,
        n_heads=40,
        n_kv_heads=8,
        ffn_dim_multiplier=1.2,
        multiple_of=2048,
        rope_theta=500000,
        max_seq_len=10485760,
        num_experts=16,
        interleave_moe_layer_step=1,
    ),

bfloat16

With change:

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2147.0
Max Memory Usage: 92.67 GiB

Without change:

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 1711.0
Max Memory Usage: 92.67 GiB

fp8 rowwise

With change:

(torchtitan) [[email protected] ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2675.0
Max Memory Usage: 90.35 GiB

Without change:

(torchtitan) [[email protected] ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh

=====================================================
 Calculating training performance metrics
=====================================================
Median Tokens/Second (excluding step 1): 2360.0
Max Memory Usage: 90.35 GiB

@danielvegamyhre marked this pull request as draft August 1, 2025 22:53
@meta-cla bot added the CLA Signed label Aug 1, 2025
@danielvegamyhre marked this pull request as ready for review August 1, 2025 23:16
mod.register_parameter(
"w1",
nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(2)])),
) # Column-wise sharding
Contributor

let's not remove these comments on colwise vs. rowwise

Contributor

please fix files in torchtitan/models/deepseek_v3 as well

Contributor Author

will do, follow up PR ok?

self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
self.use_grouped_mm = use_grouped_mm

def forward(
Contributor

we should fix the for-loop as well, right? Can we do the transpose in this function, not in the subfunctions?

Contributor Author

Yep, this wasn't ready for review yet; I was still working on that and running manual tests. It's updated now, and I confirmed via setting use_grouped_mm=False that the loop implementation runs without error.
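
For context, a simplified sketch of what the use_grouped_mm=False loop path looks like with non-transposed (E, out_dim, in_dim) weights; the class name, init, and routing below are placeholders, not the actual torchtitan implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GroupedExpertsSketch(nn.Module):
    """Simplified sketch: expert weights stored non-transposed as (E, out_dim, in_dim)."""

    def __init__(self, dim: int, hidden_dim: int, num_experts: int):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
        self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
        for w in (self.w1, self.w2, self.w3):
            nn.init.normal_(w, std=0.02)

    def forward(self, x_expert: torch.Tensor, expert_idx: int) -> torch.Tensor:
        # Transpose at the point of use: a metadata-only view that lines the
        # dims up for (tokens, dim) @ (dim, hidden_dim).
        h = F.silu(torch.matmul(x_expert, self.w1[expert_idx].transpose(-2, -1)))
        h = h * torch.matmul(x_expert, self.w3[expert_idx].transpose(-2, -1))
        return torch.matmul(h, self.w2[expert_idx].transpose(-2, -1))


# Smoke test of the loop path: run a dummy batch through each expert in turn.
experts = GroupedExpertsSketch(dim=64, hidden_dim=128, num_experts=4)
tokens = torch.randn(8, 64)
out = torch.stack([experts(tokens, e) for e in range(4)]).mean(dim=0)
print(out.shape)  # torch.Size([8, 64])
```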

@danielvegamyhre (Contributor Author)

cc @tianyu-l for review. In a follow-up PR I will update the dsv3 implementation and manually test the changes.

"w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)]))
) # Column-wise sharding
"w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(1)]))
) # Rowwise sharding
Contributor

reversed

Contributor Author (@danielvegamyhre, Aug 2, 2025)

For 3D weights, sharding on dim 1 applies rowwise sharding for each expert, in my mind. So the new comment is deliberately the reverse of the prior comment. The same applies to the other comments. Can you clarify?

Contributor

colwise vs. rowwise refers to logical placement, not to the storage & transpose combination of the actual tensors. In TP, you'd first do `intermediate = input x colwise`, then `output = intermediate x rowwise`.

As a reference, the original TP API was designed this way too:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L125

Also, you are changing some but not others, e.g. w3 is still colwise. We should revert all of them. Feel free to add extra comments to explain the motivation.

Contributor Author

> colwise vs. rowwise refers to logical placement, not to the storage & transpose combination of the actual tensors.

I see, that makes sense then. Updated the comments accordingly.
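
To make the placements concrete, here is a small illustration of what Shard(1) vs Shard(2) means for a 3D expert weight stored as (E, out_dim, in_dim). torch.chunk stands in for the TP mesh, and the shapes are placeholders, so this is not the PR's actual DTensor code.

```python
import torch
import torch.nn.functional as F

E, dim, hidden_dim, tp = 4, 64, 256, 2

# Weights stored non-transposed as (E, out_dim, in_dim).
w1 = torch.randn(E, hidden_dim, dim)  # produces the intermediate -> logically colwise in TP
w2 = torch.randn(E, dim, hidden_dim)  # consumes the intermediate -> logically rowwise in TP

# Shard(1) on w1 splits hidden_dim (each expert's output features) across TP ranks:
# column-wise sharding of the logical dim -> hidden_dim projection.
w1_shards = torch.chunk(w1, tp, dim=1)
print([tuple(s.shape) for s in w1_shards])  # [(4, 128, 64), (4, 128, 64)]

# Shard(2) on w2 splits hidden_dim (the contraction dim) across TP ranks:
# row-wise sharding of the logical hidden_dim -> dim projection.
w2_shards = torch.chunk(w2, tp, dim=2)
print([tuple(s.shape) for s in w2_shards])  # [(4, 64, 128), (4, 64, 128)]

# Per-rank computation for one expert (w3 omitted for brevity), then a sum
# standing in for the all-reduce that real rowwise TP would perform.
x = torch.randn(8, dim)
expert = 0
partials = []
for r in range(tp):
    inter = F.silu(x @ w1_shards[r][expert].transpose(-2, -1))       # (8, hidden_dim // tp)
    partials.append(inter @ w2_shards[r][expert].transpose(-2, -1))  # (8, dim), partial result
out = sum(partials)
print(out.shape)  # torch.Size([8, 64])
```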

nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])),
) # Row-wise sharding
nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(2)])),
) # Columnwise sharding
Contributor

reversed

nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(2)])),
) # Column-wise sharding
nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(1)])),
) # Rowwise sharding
Contributor

reversed

nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(1)])),
) # Row-wise sharding
nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(2)])),
) # Columnwise sharding
Contributor

reversed

Comment on lines +72 to +74
h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
Contributor

can we do this one time in GroupedExperts.forward and reuse it for both grouped mm and the for-loop?

Contributor Author

We could, but IMO that would make the code less clear/readable if the transpose happens far away from where it is actually used.

Contributor

yeah good point

@tianyu-l (Contributor) left a comment

Please revert the colwise / rowwise comments before landing. See inline comments for details.

"w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)]))
) # Column-wise sharding
"w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(1)]))
) # Rowwise sharding
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

colwise vs. rowwise is referring to logical placement, not depending on the storage & transpose combination of actual tensors. In TP, you'd first do intermediate = input x colwise then output = intermediate x rowwise.

As reference, original TP API were designed this way too
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L125

Also you are changing some but not others, e.g. w3 is still colwise. We should revert all of them. Feel free to add extra comments to explain motivation.

Comment on lines +72 to +74
h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah good point

@danielvegamyhre merged commit ed288bc into pytorch:main Aug 3, 2025
5 checks passed
bentherien pushed a commit to bentherien/torchtitan_ that referenced this pull request Aug 5, 2025
…ed mm to have col-major memory layout (pytorch#1517)

joellidin pushed a commit to one-covenant/torchtitan that referenced this pull request Aug 8, 2025
…ed mm to have col-major memory layout (pytorch#1517)

joellidin pushed a commit to one-covenant/torchtitan that referenced this pull request Aug 8, 2025
…ed mm to have col-major memory layout (pytorch#1517)
