
Commit 780af58

rohan-varma authored and cyyever committed
[FSDP] Add always_wrap policy (#73687)
Summary: Add a small helper policy that always returns True, so that all FSDP submodules are automatically wrapped. This is the first and simplest step in providing a set of policies that let users seamlessly experiment with different FSDP configurations.

More Context: pytorch/pytorch#68789

Pull Request resolved: pytorch/pytorch#73687
Reviewed By: jbschlosser, zhaojuanmao
Differential Revision: D34625801
Pulled By: rohan-varma
fbshipit-source-id: f20c951f8d62ea29b504543c93acd546247d8206
(cherry picked from commit 3b0bf02bc8bb236ee09e2fa986d52bbf5231efc3)
1 parent 8612b1e commit 780af58
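
For illustration, a minimal usage sketch based on the test added in this PR. Only the always_wrap_policy import and the fsdp_auto_wrap_policy argument come from the diff; the toy model, device placement, and the assumption that a default process group is already initialized are illustrative only.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import always_wrap_policy

# Assumes torch.distributed has been initialized (e.g. via init_process_group),
# so FSDP can fall back to the default process group.
model = nn.Sequential(nn.Linear(8, 8), nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 4)))
wrapped = FSDP(model.cuda(), fsdp_auto_wrap_policy=always_wrap_policy)
# Every submodule, including the inner nn.Sequential and its children,
# is now wrapped in its own FSDP instance.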

File tree: 2 files changed, +29 −1 lines changed


test/distributed/fsdp/test_wrap.py

Lines changed: 20 additions & 1 deletion

@@ -5,7 +5,6 @@
 import os
 import tempfile
 import unittest
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -15,6 +14,7 @@
     BackwardPrefetch,
 )
 from torch.distributed.fsdp.wrap import (
+    always_wrap_policy,
     default_auto_wrap_policy,
     enable_wrap,
     wrap,
@@ -67,6 +67,15 @@ def get_model(cuda=True):
                 sequential = sequential.cuda()
             return sequential

+        @staticmethod
+        def verify_model_all_wrapped(cls, model):
+            cls.assertTrue(isinstance(model, FSDP))
+            cls.assertTrue(isinstance(model.module[0], FSDP))
+            cls.assertTrue(isinstance(model.module[1], FSDP))
+            cls.assertTrue(isinstance(model.module[2], FSDP))
+            cls.assertTrue(isinstance(model.module[2].module[0], FSDP))
+            cls.assertTrue(isinstance(model.module[2].module[1], FSDP))
+
         @staticmethod
         def verify_model(cls, model):
             cls.assertTrue(isinstance(model, FSDP))
@@ -257,6 +266,16 @@ def test_wrap_override_defaults(self):
         self.assertEqual(layer.rank, 0)
         self.assertEqual(layer.world_size, 2)

+    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
+    def test_always_wrap(self):
+        """
+        Test to ensure that if `always_wrap_policy` is
+        passed into FSDP, all submodules are wrapped.
+        """
+        seq = TestFSDPWrap.NestedSequentialModel.get_model(cuda=True)
+        model = FSDP(seq, process_group=self.process_group, fsdp_auto_wrap_policy=always_wrap_policy)
+        TestFSDPWrap.NestedSequentialModel.verify_model_all_wrapped(self, model)
+
     def test_auto_wrap_api(self):
         """
         Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params.
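
The verify_model_all_wrapped helper added above checks two levels of nesting. NestedSequentialModel.get_model is not shown in this diff, so the layer types and sizes below are assumptions; the sketch only illustrates the kind of structure whose indices (module[0] through module[2].module[1]) the assertions exercise.

import torch.nn as nn

# Hypothetical stand-in for the model returned by get_model(); the real
# layers are not part of this diff.
nested = nn.Sequential(
    nn.Linear(8, 8),       # checked via model.module[0]
    nn.Linear(8, 8),       # checked via model.module[1]
    nn.Sequential(         # checked via model.module[2]
        nn.Linear(8, 8),   # checked via model.module[2].module[0]
        nn.Linear(8, 8),   # checked via model.module[2].module[1]
    ),
)
# With always_wrap_policy, FSDP wraps the root and each of these submodules
# as a separate FSDP instance, which is what the assertions verify.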

torch/distributed/fsdp/wrap.py

Lines changed: 9 additions & 0 deletions

@@ -9,6 +9,15 @@
 import torch.nn as nn


+def always_wrap_policy(*args, **kwargs) -> bool:
+    """
+    A simple wrapper policy that always returns ``True``,
+    i.e. when passed as the `auto_wrap_policy` into FSDP,
+    this will result in all submodules being wrapped as
+    distinct FSDP instances.
+    """
+    return True
+
 def default_auto_wrap_policy(
     module: nn.Module,
     recurse: bool,
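
The new always_wrap_policy simply short-circuits this decision. For comparison, any callable with the same calling convention that returns a bool can serve as a policy. Below is a rough sketch of a type-based policy; it assumes, based only on the truncated signature of default_auto_wrap_policy visible above, that the policy is called with at least module and recurse, and that any extra keyword arguments can be ignored.

import torch.nn as nn

def wrap_linear_only_policy(module: nn.Module, recurse: bool, **kwargs) -> bool:
    # While recursing, keep descending into children; when asked whether to
    # wrap a specific module, wrap it only if it is an nn.Linear.
    # (Extra keyword arguments, e.g. parameter counts, are ignored in this sketch.)
    if recurse:
        return True
    return isinstance(module, nn.Linear)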
