@@ -5,7 +5,6 @@
 import os
 import tempfile
 import unittest
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -15,6 +14,7 @@
     BackwardPrefetch,
 )
 from torch.distributed.fsdp.wrap import (
+    always_wrap_policy,
     default_auto_wrap_policy,
     enable_wrap,
     wrap,
@@ -67,6 +67,15 @@ def get_model(cuda=True):
                 sequential = sequential.cuda()
             return sequential
 
+        @staticmethod
+        def verify_model_all_wrapped(cls, model):
+            cls.assertTrue(isinstance(model, FSDP))
+            cls.assertTrue(isinstance(model.module[0], FSDP))
+            cls.assertTrue(isinstance(model.module[1], FSDP))
+            cls.assertTrue(isinstance(model.module[2], FSDP))
+            cls.assertTrue(isinstance(model.module[2].module[0], FSDP))
+            cls.assertTrue(isinstance(model.module[2].module[1], FSDP))
+
         @staticmethod
         def verify_model(cls, model):
             cls.assertTrue(isinstance(model, FSDP))
@@ -257,6 +266,16 @@ def test_wrap_override_defaults(self):
         self.assertEqual(layer.rank, 0)
         self.assertEqual(layer.world_size, 2)
 
+    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
+    def test_always_wrap(self):
+        """
+        Test to ensure that if `always_wrap_policy` is
+        passed into FSDP, all submodules are wrapped.
+        """
+        seq = TestFSDPWrap.NestedSequentialModel.get_model(cuda=True)
+        model = FSDP(seq, process_group=self.process_group, fsdp_auto_wrap_policy=always_wrap_policy)
+        TestFSDPWrap.NestedSequentialModel.verify_model_all_wrapped(self, model)
+
     def test_auto_wrap_api(self):
         """
         Test to ensure with auto wrap, we wrap child modules correctly based on the min_num_params.
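For context, the snippet below is a minimal standalone sketch of what the new test exercises: wrapping a nested Sequential with FSDP under `always_wrap_policy` so every submodule becomes its own FSDP unit. The single-rank process group setup, the hard-coded MASTER_ADDR/MASTER_PORT, and the layer shapes are illustrative assumptions rather than part of this PR; the `fsdp_auto_wrap_policy` keyword follows this diff (later PyTorch releases spell it `auto_wrap_policy`).

# Standalone sketch of the behavior test_always_wrap checks. Assumptions (not
# from this PR): one CUDA device, a single-rank "nccl" process group, and
# illustrative Linear(5, 5) layers in place of NestedSequentialModel.get_model.
import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import always_wrap_policy

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("nccl", rank=0, world_size=1)

# Nested shape mirroring the test model: two leaf layers plus an inner
# Sequential holding two more.
seq = nn.Sequential(
    nn.Linear(5, 5),
    nn.Linear(5, 5),
    nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
).cuda()

# always_wrap_policy returns True for every module the auto-wrapper visits, so
# each child (including the inner Sequential's children) becomes its own FSDP
# unit, which is what verify_model_all_wrapped asserts above.
model = FSDP(seq, fsdp_auto_wrap_policy=always_wrap_policy)
assert isinstance(model.module[2].module[0], FSDP)

dist.destroy_process_group()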