
Conversation

tianyu-l (Contributor)

Issue pointed out in #1534 (comment) and pytorch/pytorch#160285.

Solution given by @rakkit in #1534 (comment).

meta-cla bot added the CLA Signed label on Aug 11, 2025.
tianyu-l merged commit 59e57a4 into main on Aug 11, 2025 (7 checks passed).
tianyu-l deleted the fix branch on August 11, 2025 at 23:54.
tianyu-l linked an issue on Aug 12, 2025 that may be closed by this pull request.
ruisizhang123 added a commit that referenced this pull request on Oct 9, 2025:
This PR is a follow-up to the SimpleFSDP+EP [PR](#1529). Here, we add a
`gradient_divide_factor` following FSDP2 to ensure that modules wrapped by
FSDP+EP have the correct gradient reduction value.

- The original FSDP2 implementation is in this
[PR](#1551).
- The `gradient_divide_factor` logic is
[here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688)
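For intuition on why an explicit divide factor is needed (a hedged sketch; the exact bookkeeping follows the FSDP2 logic linked above): with EP, an expert module's gradients are reduced over a smaller data-parallel group of size `G` than the full data-parallel degree `D` used for the dense parameters, so a plain per-group average scales them by `D/G` relative to the dense gradients. The divide factor lets the reduction divide by `D` instead:

```math
\hat{g} \;=\; \frac{1}{D}\sum_{r=1}^{G} g_r
\quad\text{instead of}\quad
\frac{1}{G}\sum_{r=1}^{G} g_r
```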

We have two ways of handling `gradient_divide_factor` in
`reduce_scatter`:

1. The first one is to use `ReduceOp.PREMUL_SUM` to handle the
`gradient_divide_factor`. However, DTensor's `_reduce_shard_value` only
accepts `reduce_op` as a str input
([here](https://github.com/pytorch/pytorch/blob/8f705d019a64b1ca882e043b3eb98559273a9e59/torch/distributed/tensor/placement_types.py#L177-L210)).

To make `_reduce_shard_value` work correctly with `ReduceOp.PREMUL_SUM`,
we need to update DTensor's `_reduce_shard_tensor` and
`torch.distributed._functional_collectives.reduce_scatter_tensor` so
that they can pass the factor associated with `ReduceOp.PREMUL_SUM` as an
input (see the first sketch after this list).



2. Another way is to simulate `ReduceOp.PREMUL_SUM` with `ReduceOp.SUM`.
The logic is in this [Diff](https://www.internalfb.com/diff/D76546536):
it performs a `div_` on the gradient before the `ReduceOp.SUM`.
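
For comparison, here is roughly what option 1 looks like when the collective is called directly rather than through DTensor. This is a minimal sketch, not the code in this PR: it assumes the private, NCCL-only helper `_make_nccl_premul_sum` (in `torch.distributed.distributed_c10d` at the time of writing), a plain process group, and an illustrative `gradient_divide_factor` argument.

```python
import torch
import torch.distributed as dist
from torch.distributed.distributed_c10d import _make_nccl_premul_sum


def reduce_scatter_premul_sum(grad: torch.Tensor, gradient_divide_factor: float, group) -> torch.Tensor:
    """Option 1 (sketch): bind the factor into a PREMUL_SUM reduce op."""
    world_size = dist.get_world_size(group)
    # Assumes grad.shape[0] is divisible by the group size.
    out = grad.new_empty(grad.shape[0] // world_size, *grad.shape[1:])
    # PREMUL_SUM multiplies every rank's input by the factor before summing,
    # so passing 1/gradient_divide_factor yields sum(g_r) / gradient_divide_factor.
    op = _make_nccl_premul_sum(1.0 / gradient_divide_factor)
    dist.reduce_scatter_tensor(out, grad, op=op, group=group)
    return out
```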

Currently I'm following option 2, since it requires fewer changes to
`_functional_collectives`.
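
A minimal sketch of option 2 using only standard collectives (the names here are illustrative, not the exact helpers in this PR): divide the local gradient in place, then reduce-scatter with a plain `ReduceOp.SUM`, which is numerically equivalent to the PREMUL_SUM version above.

```python
import torch
import torch.distributed as dist


def reduce_scatter_with_divide(grad: torch.Tensor, gradient_divide_factor: float, group) -> torch.Tensor:
    """Option 2 (sketch): simulate PREMUL_SUM with a div_ followed by SUM."""
    world_size = dist.get_world_size(group)
    # Assumes grad.shape[0] is divisible by the group size.
    out = grad.new_empty(grad.shape[0] // world_size, *grad.shape[1:])
    grad.div_(gradient_divide_factor)  # pre-divide before the collective
    dist.reduce_scatter_tensor(out, grad, op=dist.ReduceOp.SUM, group=group)
    return out
```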


After enabling `reduction_divide_factor`, we see that FSDP (=2) + EP (=4)
runs have identical loss:

![Loss curves with FSDP (=2) + EP (=4)](https://github.com/user-attachments/assets/aaf83109-8db8-4051-973d-c7b6950513de)

Successfully merging this pull request may close these issues:

- Wrong-size gradients in Expert Parallel MoE