
Commit 54df1da

fix simplefsdp gradient_divide_factor
1 parent 5d8e2d5 commit 54df1da

2 files changed: +109, -7 lines


torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py

Lines changed: 1 addition & 4 deletions
@@ -132,11 +132,8 @@ def parallelize_deepseekv3(
             ac_mode=job_config.activation_checkpoint.mode,
             mp_policy=mp_policy,
             shard_dim=experts_shard_dim,
+            gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
         )
-        # TODO(ruisizhang123): support set_gradient_divide_factor in simplefsdp
-        # transformer_block.moe.experts.set_gradient_divide_factor(
-        #     parallel_dims.fsdp_gradient_divide_factor,
-        # )
 
     model = data_parallel(
         model,
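
The removed TODO is resolved by passing the factor at wrap time instead of calling a setter afterwards. Roughly, the expert call site now has this shape; this is a sketch only, where the assignment target and the elided positional arguments are assumptions based on the names in the diff above, not copied from the file:

transformer_block.moe.experts = data_parallel(
    transformer_block.moe.experts,
    ...,  # device-mesh / mode arguments used at this call site (elided)
    ac_mode=job_config.activation_checkpoint.mode,
    mp_policy=mp_policy,
    shard_dim=experts_shard_dim,
    gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
)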

torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 108 additions & 3 deletions
@@ -7,7 +7,7 @@
 from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Union
 
 import torch
 import torch.nn as nn
@@ -20,6 +20,8 @@
     Shard,
 )
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.distributed_c10d import ReduceOp
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import _div_if_needed
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -49,6 +51,101 @@ class MixedPrecisionPolicy:
     reduce_dtype: Optional[torch.dtype] = None
 
 
+@dataclass(frozen=True)
+class SimpleFSDPPartial(Partial):
+    gradient_divide_factor: Optional[float] = None
+    reduce_dtype: Optional[torch.dtype] = None
+
+    def _get_gradient_divide_factors(
+        self,
+    ) -> tuple[
+        Optional[float],
+        Optional[float],
+        Union[ReduceOp, ReduceOp.RedOpType],
+        Union[ReduceOp, ReduceOp.RedOpType],
+    ]:
+        """
+        the logic follows
+        https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688
+        """
+        if self.gradient_divide_factor is None:
+            return None, None, None, None
+        overflow_risk = self.reduce_dtype not in (torch.float32, torch.bfloat16)
+        pre_factor: Optional[float] = None
+        post_factor: Optional[float] = None
+        reduce_scatter_op, all_reduce_op = ReduceOp.SUM, ReduceOp.SUM
+        if overflow_risk:
+            # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
+            # overflow/underflow. For N data parallel workers, each worker computes
+            # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
+            # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
+            pre_factor = 1
+            while (
+                self.gradient_divide_factor % pre_factor == 0
+                and self.gradient_divide_factor / pre_factor > pre_factor
+            ):
+                pre_factor *= 2
+            post_factor = self.gradient_divide_factor / pre_factor
+        else:
+            reduce_scatter_op = torch.distributed._make_nccl_premul_sum(
+                1 / self.gradient_divide_factor
+            )
+            all_reduce_op = ReduceOp.SUM
+        return pre_factor, post_factor, reduce_scatter_op, all_reduce_op
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        (
+            pre_factor,
+            post_factor,
+            reduce_scatter_op,
+            all_reduce_op,
+        ) = self._get_gradient_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_value(tensor, mesh, mesh_dim)
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+    def _partition_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        (
+            pre_factor,
+            post_factor,
+            reduce_scatter_op,
+            all_reduce_op,
+        ) = self._get_gradient_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_value(tensor, mesh, mesh_dim)
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        (
+            pre_factor,
+            post_factor,
+            reduce_scatter_op,
+            all_reduce_op,
+        ) = self._get_gradient_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec)
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+
 def _distribute_dtensor(
     tensor: DTensor,
     device_mesh: DeviceMesh,
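
Worked example of the pre/post split above: for overflow-prone reduce dtypes such as fp16, the loop grows the pre-division by powers of two while it still divides gradient_divide_factor evenly and stays below the remaining post-division, so the two factors multiply back to the full divisor while each stays close to sqrt(N). A minimal standalone sketch (plain Python, no torch required; the function name is made up for illustration):

def split_divide_factor(factor: float) -> tuple[float, float]:
    # Mirrors the overflow_risk branch of _get_gradient_divide_factors.
    pre = 1.0
    while factor % pre == 0 and factor / pre > pre:
        pre *= 2
    return pre, factor / pre

# split_divide_factor(8.0)   -> (4.0, 2.0)
# split_divide_factor(64.0)  -> (8.0, 8.0)
# split_divide_factor(128.0) -> (16.0, 8.0)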
@@ -192,18 +289,24 @@ def __init__(
         mode,
         regional_ac,
         mp_policy,
+        gradient_divide_factor,
     ):
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
         self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
+        self.grad_placements = [
+            SimpleFSDPPartial(
+                reduce_op="avg",
+                gradient_divide_factor=gradient_divide_factor,
+                reduce_dtype=mp_policy.reduce_dtype,
+            )
+        ] * self.device_mesh.ndim
         self.regional_ac = regional_ac
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
-        self.ep_mesh_name, self.tp_mesh_name = "ep", "tp"
 
     def replicate_compute(self, x):
         # data parallel runtime replicate parameters and do local compute
@@ -286,6 +389,7 @@ def data_parallel(
     ac_mode: str = "none",
     mp_policy: Optional[MixedPrecisionPolicy] = None,
     shard_dim: int = 0,
+    gradient_divide_factor: Optional[float] = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)
@@ -348,6 +452,7 @@ def data_parallel(
             mode,
             regional_ac,
             mp_policy=mp_policy,
+            gradient_divide_factor=gradient_divide_factor,
         ),
     )
     return model
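
After this change the divide factor travels with the gradient placement itself. A quick sketch of how the new placement reports its factors for an fp16 reduce dtype (assumes this commit is applied and torchtitan is importable; the numbers follow the power-of-two split shown earlier):

import torch
from torchtitan.experiments.simple_fsdp.simple_fsdp import SimpleFSDPPartial

# fp16 is treated as overflow-prone, so the factor is split into a
# pre-reduction and a post-reduction division (8 * 8 == 64 here).
placement = SimpleFSDPPartial(
    reduce_op="avg",
    gradient_divide_factor=64,
    reduce_dtype=torch.float16,
)
pre, post, _, _ = placement._get_gradient_divide_factors()
print(pre, post)  # 8 8.0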
