[llama4] store expert weights such that we can transpose before grouped mm to have col-major memory layout #1517
Conversation
… col major mem layout
```python
mod.register_parameter(
    "w1",
    nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(2)])),
)  # Column-wise sharding
```
let's not remove these comments on colwise vs. rowwise
please fix files in torchtitan/models/deepseek_v3 as well
will do, follow-up PR ok?
```python
self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
self.use_grouped_mm = use_grouped_mm

def forward(
```
we should fix the for-loop as well, right? Can we do the transpose in this function, not in the subfunctions?
Yep, this wasn't ready for review yet; I was still working on that and running manual tests. It's updated now, and I confirmed by setting use_grouped_mm=False that the loop implementation runs without error.
cc @tianyu-l for review. In a follow-up PR I will update the dsv3 implementation and manually test the changes.
"w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)])) | ||
) # Column-wise sharding | ||
"w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(1)])) | ||
) # Rowwise sharding |
reversed
For 3D weights, sharding on dim 1 applies row-wise sharding for each expert, in my mind. So the new comment is deliberately the reverse of the prior comment. The same applies to the other comments. Can you clarify?
colwise vs. rowwise refers to the logical placement, not to the storage & transpose combination of the actual tensors. In TP, you'd first do intermediate = input x colwise, then output = intermediate x rowwise.
As reference, the original TP API was designed this way too:
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L125
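For illustration, a minimal toy sketch of that colwise-then-rowwise pattern, emulated here with manual chunking instead of a DTensor device mesh (shapes and names are made up, not taken from this PR):

```python
import torch

torch.manual_seed(0)
dim, hidden, tokens, tp = 8, 16, 4, 2

x = torch.randn(tokens, dim, dtype=torch.float64)
w1 = torch.randn(hidden, dim, dtype=torch.float64)  # "colwise" weight, stored (out, in) like nn.Linear
w2 = torch.randn(dim, hidden, dtype=torch.float64)  # "rowwise" weight

# Reference: intermediate = input x colwise, then output = intermediate x rowwise.
ref = (x @ w1.t()) @ w2.t()

# Emulate 2-way TP: the colwise weight is sharded on its output dim, the rowwise
# weight on its input dim; partial outputs are summed (the all-reduce in real TP).
w1_shards = w1.chunk(tp, dim=0)
w2_shards = w2.chunk(tp, dim=1)
out = sum((x @ a.t()) @ b.t() for a, b in zip(w1_shards, w2_shards))
torch.testing.assert_close(out, ref)
```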
Also you are changing some but not others, e.g. w3 is still colwise. We should revert all of them. Feel free to add extra comments to explain the motivation.
> colwise vs. rowwise refers to the logical placement, not to the storage & transpose combination of the actual tensors.

I see, that makes sense then. Updated the comments accordingly.
```diff
-    nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(1)])),
-) # Row-wise sharding
+    nn.Parameter(distribute_tensor(module.w2, device_mesh, [Shard(2)])),
+) # Columnwise sharding
```
reversed
```diff
-    nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(2)])),
-) # Column-wise sharding
+    nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(1)])),
+) # Rowwise sharding
```
reversed
```diff
-    nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(1)])),
-) # Row-wise sharding
+    nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(2)])),
+) # Columnwise sharding
```
reversed
```python
h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1)))
h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1))
h = torch.matmul(h, w2[expert_idx].transpose(-2, -1))
```
can we do this one time in GroupedExperts.forward and reuse it for both grouped mm and the for-loop?
We could, but IMO that would make the code less clear/readable if the transpose happens far away from where it is actually used.
yeah good point
Please revert the colwise / rowwise comments before landing. See inline comments for details.
"w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(2)])) | ||
) # Column-wise sharding | ||
"w1", nn.Parameter(distribute_tensor(module.w1, device_mesh, [Shard(1)])) | ||
) # Rowwise sharding |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
colwise vs. rowwise is referring to logical placement, not depending on the storage & transpose combination of actual tensors. In TP, you'd first do intermediate = input x colwise
then output = intermediate x rowwise
.
As reference, original TP API were designed this way too
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L125
Also you are changing some but not others, e.g. w3
is still colwise. We should revert all of them. Feel free to add extra comments to explain motivation.
h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) | ||
h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) | ||
h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yeah good point
# Summary
Rather than storing the expert weights pre-transposed as (E, in_dim, out_dim), we should store them non-transposed as (E, out_dim, in_dim) and transpose before the grouped GEMM, for (1) compatible dims for the GEMM, and (2) the column-major memory layout required for the right operand of the grouped GEMM. Doing this simple transpose (a metadata change only) is much more efficient than doing this [inefficient memory layout transformation before every GEMM in fp8](https://github.com/pytorch/ao/blob/6e941c87c4d9fb9a74e6f979dd522605c696ca42/torchao/prototype/moe_training/scaled_grouped_mm.py#L96).
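As a minimal illustration of the layout point (toy shapes; torch.bmm stands in here for the actual grouped GEMM, which is not shown):

```python
import torch

E, dim, hidden_dim, tokens_per_expert = 4, 8, 16, 32

# Stored non-transposed, as in this PR: (E, out_dim, in_dim)
w1 = torch.randn(E, hidden_dim, dim)

# The transpose is a metadata-only change (no copy): each expert's
# (dim, hidden_dim) slice becomes column-major, which is the layout the
# grouped GEMM wants for its right operand.
w1_t = w1.transpose(-2, -1)   # shape (E, dim, hidden_dim)
print(w1_t.is_contiguous())   # False
print(w1_t.stride())          # (hidden_dim * dim, 1, dim) == (128, 1, 8)

# Routed tokens grouped per expert; bmm used as a stand-in for the grouped mm.
x = torch.randn(E, tokens_per_expert, dim)
h = torch.bmm(x, w1_t)        # (E, tokens_per_expert, hidden_dim)
```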
# Eager Performance

Llama4 debug model with FSDP=8, using config:

```python
"debugmodel": TransformerModelArgs(
    dim=5120,
    n_layers=4,
    n_heads=40,
    n_kv_heads=8,
    ffn_dim_multiplier=1.2,
    multiple_of=2048,
    rope_theta=500000,
    max_seq_len=10485760,
    num_experts=16,
    interleave_moe_layer_step=1,
),
```

### bfloat16

With change:

```
Median Tokens/Second (excluding step 1): 2147.0
Max Memory Usage: 92.67 GiB
```

Without change:

```
Median Tokens/Second (excluding step 1): 1711.0
Max Memory Usage: 92.67 GiB
```

### fp8 rowwise

With change:

```
(torchtitan) [[email protected] ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh
Median Tokens/Second (excluding step 1): 2675.0
Max Memory Usage: 90.35 GiB
```

Without change:

```
(torchtitan) [[email protected] ~/ao/benchmarks/float8/training (metdata)]$ TORCHTITAN_ROOT=/home/danvm/torchtitan NGPU=8 EXTRA_ARGS="--model.converters="float8" --float8.recipe_name="rowwise" --float8.filter_fqns="output,auto_filter_small_kn" --float8.moe_fqns_prototype="experts"" ./llama4.sh
Median Tokens/Second (excluding step 1): 2360.0
Max Memory Usage: 90.35 GiB
```