
Conversation

sanandaraj5597 (Contributor)

This PR adds support for gradient fusion for MCore FSDP.

Selvaraj Anandaraj and others added 2 commits September 20, 2025 22:31
@timmoon10 (Collaborator) left a comment


I don't think this makes sense. If you configure a TE module with fuse_wgrad_accumulation=True (e.g. here), the correct behavior is to fuse wgrad accumulation. If Mcore FSDP doesn't support it, then it should be Mcore's responsibility to not set that arg.
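
For reference, a minimal sketch (not from this PR) of what that configuration implies, assuming a CUDA device and a Megatron-style main_grad buffer owned by the caller:

    import torch
    import transformer_engine.pytorch as te

    # Caller opts in to fused wgrad accumulation when building the module.
    layer = te.Linear(1024, 1024, fuse_wgrad_accumulation=True)

    # The caller (e.g. Mcore DDP) owns the persistent FP32 main_grad buffer.
    layer.weight.main_grad = torch.zeros_like(layer.weight, dtype=torch.float32)

    x = torch.randn(8, 1024, device="cuda", requires_grad=True)
    layer(x).sum().backward()
    # The wgrad GEMM accumulates directly into layer.weight.main_grad
    # instead of populating layer.weight.grad.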

@timmoon10 (Collaborator)

The root problem is that Mcore DDP and FSDP have different behaviors and require different contracts with TE:

  • DDP uses persistent main_grad buffers and expects TE to accumulate into them. To adhere to this contract, Mcore zeroes out main_grad before the first microbatch step.
  • FSDP uses temporary main_grad buffers and expects TE to overwrite them (sketched below).
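
A minimal sketch of the two contracts (illustrative code, not actual TE or Mcore internals):

    import torch

    def write_wgrad(weight: torch.nn.Parameter, wgrad: torch.Tensor, contract: str) -> None:
        """Illustrative only: where the wgrad lands under each contract."""
        if contract == "ddp":
            # Persistent buffer, zeroed by Mcore before the first microbatch;
            # TE accumulates so gradients sum across microbatches.
            weight.main_grad += wgrad
        else:  # "fsdp"
            # Temporary buffer handed out for this backward pass; TE overwrites it.
            weight.main_grad.copy_(wgrad)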

I don't like this PR's approach of switching between these two cases based on whether Mcore is using DDP or FSDP, since the DDP/FSDP distinction is not what actually matters; the accumulate-versus-overwrite behavior is. It also needlessly blocks some possible optimizations (DDP might want to overwrite main_grads in the first microbatch, FSDP might want to accumulate into main_grads if a weight is shared).

There are a few possible redesigns:

  1. Deprecate the fuse_wgrad_accumulation kwarg in favor of something like output_wgrad_to_main_grad. Then check param flags to decide whether to overwrite or accumulate into the main_grad:
    # Decide where the wgrad GEMM writes and whether it accumulates.
    grad_weight: torch.Tensor
    accumulate: bool = False
    if output_wgrad_to_main_grad:
        # Write directly into the param's main_grad buffer.
        if getattr(weight, "get_main_grad", None) is not None:
            grad_weight = weight.get_main_grad()
        else:
            grad_weight = weight.main_grad
        # Accumulate unless the param is flagged to be overwritten
        # (defaulting to accumulate preserves the current fused behavior).
        accumulate = not getattr(weight, "_overwrite_main_grad", False)
    else:
        grad_weight = torch.empty(...)

    gemm(..., out=grad_weight, accumulate=accumulate)

Ensuring backward compatibility will be tricky.

  2. Have separate kwargs for fuse_wgrad_accumulation and overwrite_wgrad_main_grad. This means that the two cases are separate code paths and backward compatibility is easier to maintain. However, it also means we can't change behavior between steps.
  3. Keep the fuse_wgrad_accumulation kwarg and purely control behavior with param flags (a sketch of how Mcore could set such flags follows this list). This is basically the approach used in this PR, although it could be improved by using better names rather than just checking weight.__fsdp_param__. One problem is that fuse_wgrad_accumulation will no longer be an accurate name.
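
To make the param-flag approach concrete (options 1 and 3), here is a hypothetical sketch of how Mcore could tag its params; `_overwrite_main_grad` is the illustrative flag from the snippet above, and `ddp_module`, `fsdp_module`, and `get_temporary_grad_buffer` are placeholders, not existing Mcore APIs:

    import torch

    # Mcore DDP: persistent buffers, zeroed once per optimizer step; TE accumulates.
    for param in ddp_module.parameters():
        param.main_grad = torch.zeros_like(param, dtype=torch.float32)
        param._overwrite_main_grad = False

    # Mcore FSDP: temporary buffers reissued each backward pass; TE overwrites.
    for param in fsdp_module.parameters():
        param.main_grad = get_temporary_grad_buffer(param)  # placeholder helper
        param._overwrite_main_grad = True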
