 from collections.abc import Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import List, Optional, Union

 import torch
 import torch.nn as nn

     Shard,
 )
 from torch.distributed.device_mesh import _mesh_resources, DeviceMesh
+from torch.distributed.distributed_c10d import ReduceOp
+from torch.distributed.fsdp._fully_shard._fsdp_collectives import _div_if_needed
 from torch.distributed.tensor._dtensor_spec import DTensorSpec
 from torch.distributed.tensor._redistribute import redistribute_local_tensor
 from torch.distributed.tensor.placement_types import _StridedShard, Placement
@@ -49,6 +51,101 @@ class MixedPrecisionPolicy:
     reduce_dtype: Optional[torch.dtype] = None


+@dataclass(frozen=True)
+class SimpleFSDPPartial(Partial):
+    gradient_divide_factor: Optional[float] = None
+    reduce_dtype: Optional[torch.dtype] = None
+
+    def _get_gradient_divide_factors(
+        self,
+    ) -> tuple[
+        Optional[float],
+        Optional[float],
+        Optional[Union[ReduceOp, ReduceOp.RedOpType]],
+        Optional[Union[ReduceOp, ReduceOp.RedOpType]],
+    ]:
+        """
+        the logic follows
+        https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688
+        """
+        if self.gradient_divide_factor is None:
+            return None, None, None, None
+        overflow_risk = self.reduce_dtype not in (torch.float32, torch.bfloat16)
+        pre_factor: Optional[float] = None
+        post_factor: Optional[float] = None
+        reduce_scatter_op, all_reduce_op = ReduceOp.SUM, ReduceOp.SUM
+        if overflow_risk:
+            # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid
+            # overflow/underflow. For N data parallel workers, each worker computes
+            # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid
+            # overflow/underflow, we divide by ~sqrt(N) before/after the reduction.
+            pre_factor = 1
+            while (
+                self.gradient_divide_factor % pre_factor == 0
+                and self.gradient_divide_factor / pre_factor > pre_factor
+            ):
+                pre_factor *= 2
+            post_factor = self.gradient_divide_factor / pre_factor
+        else:
+            reduce_scatter_op = torch.distributed._make_nccl_premul_sum(
+                1 / self.gradient_divide_factor
+            )
+            all_reduce_op = ReduceOp.SUM
+        return pre_factor, post_factor, reduce_scatter_op, all_reduce_op
+
+    def _reduce_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        (
+            pre_factor,
+            post_factor,
+            reduce_scatter_op,
+            all_reduce_op,
+        ) = self._get_gradient_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_value(tensor, mesh, mesh_dim)
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+    def _partition_value(
+        self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int
+    ) -> torch.Tensor:
+        (
+            pre_factor,
+            post_factor,
+            reduce_scatter_op,
+            all_reduce_op,
+        ) = self._get_gradient_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        partitioned = super()._partition_value(tensor, mesh, mesh_dim)
+        if post_factor is not None:
+            _div_if_needed(partitioned, post_factor)
+        return partitioned
+
+    def _reduce_shard_value(
+        self,
+        tensor: torch.Tensor,
+        mesh: DeviceMesh,
+        mesh_dim: int,
+        shard_spec: Placement,
+    ) -> torch.Tensor:
+        (
+            pre_factor,
+            post_factor,
+            reduce_scatter_op,
+            all_reduce_op,
+        ) = self._get_gradient_divide_factors()
+        if pre_factor is not None:
+            _div_if_needed(tensor, pre_factor)
+        reduced = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec)
+        if post_factor is not None:
+            _div_if_needed(reduced, post_factor)
+        return reduced
+
+
 def _distribute_dtensor(
     tensor: DTensor,
     device_mesh: DeviceMesh,
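Note on the pre/post split above: for fp16 reductions the patch divides by roughly sqrt(gradient_divide_factor) before the collective and by the remainder afterwards. A minimal standalone sketch of that arithmetic (the helper name split_divide_factor is illustrative, not part of the patch):

def split_divide_factor(factor: float) -> tuple[int, float]:
    # Grow the pre-division factor by powers of two until it reaches ~sqrt(factor),
    # mirroring the loop in SimpleFSDPPartial._get_gradient_divide_factors above.
    pre = 1
    while factor % pre == 0 and factor / pre > pre:
        pre *= 2
    return pre, factor / pre

print(split_divide_factor(64))  # (8, 8.0): divide by 8 before and by 8 after the reduction
print(split_divide_factor(24))  # (8, 3.0): the pre-division factor stays a power of two

For fp32/bf16 there is no overflow risk, so the else branch instead builds a premultiplied-sum reduce op via torch.distributed._make_nccl_premul_sum(1 / gradient_divide_factor).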
@@ -192,18 +289,24 @@ def __init__(
         mode,
         regional_ac,
         mp_policy,
+        gradient_divide_factor,
     ):
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
         self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [Partial(reduce_op="avg")] * self.device_mesh.ndim
+        self.grad_placements = [
+            SimpleFSDPPartial(
+                reduce_op="avg",
+                gradient_divide_factor=gradient_divide_factor,
+                reduce_dtype=mp_policy.reduce_dtype if mp_policy else None,
+            )
+        ] * self.device_mesh.ndim
         self.regional_ac = regional_ac
         mp_policy = mp_policy or MixedPrecisionPolicy()
         self.param_dtype = mp_policy.param_dtype
         self.reduce_dtype = mp_policy.reduce_dtype
-        self.ep_mesh_name, self.tp_mesh_name = "ep", "tp"

     def replicate_compute(self, x):
         # data parallel runtime replicate parameters and do local compute
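The gradient placements above now carry the divide factor and reduce dtype, so when DTensor redistributes gradients out of these Partial placements (all-reduce via _reduce_value, reduce-scatter via _reduce_shard_value) it goes through the overrides defined earlier in this patch. A hedged sketch of what self.grad_placements ends up holding for a hypothetical 2-D data-parallel mesh, assuming SimpleFSDPPartial from this file is in scope (the factor and dtype values are illustrative):

grad_placements = [
    SimpleFSDPPartial(
        reduce_op="avg",               # same reduce op as the old plain Partial
        gradient_divide_factor=64.0,   # illustrative; None preserves the old averaging behavior
        reduce_dtype=torch.float16,    # taken from MixedPrecisionPolicy.reduce_dtype
    )
] * 2  # one placement per device-mesh dimension (device_mesh.ndim == 2 here)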
@@ -286,6 +389,7 @@ def data_parallel(
     ac_mode: str = "none",
     mp_policy: Optional[MixedPrecisionPolicy] = None,
     shard_dim: int = 0,
+    gradient_divide_factor: Optional[float] = None,
 ):
     if mode == "replicate":
         param_sharding = (Replicate(),)
@@ -348,6 +452,7 @@ def data_parallel(
             mode,
             regional_ac,
             mp_policy=mp_policy,
+            gradient_divide_factor=gradient_divide_factor,
         ),
     )
     return model
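A hedged usage sketch of the new keyword; the leading positional parameters of data_parallel (model, device_mesh) and the "fully_shard" mode name are assumptions, since only the tail of the signature is visible in this diff:

sharded_model = data_parallel(
    model,
    device_mesh,                     # assumed data-parallel DeviceMesh argument
    mode="fully_shard",              # assumed mode name; only "replicate" appears above
    mp_policy=MixedPrecisionPolicy(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float16,  # fp16 reduce dtype exercises the pre/post split
    ),
    gradient_divide_factor=float(device_mesh.size()),  # e.g. divide by the DP world size
)

Leaving gradient_divide_factor at its default of None keeps the previous plain "avg" reduction behavior.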