Commit 4fe68da

Split the unit tests
1 parent: 3491f4c · commit: 4fe68da

1 file changed

torch/testing/_internal/distributed/distributed_test.py

Lines changed: 25 additions & 2 deletions
@@ -4834,26 +4834,49 @@ def _test_post_localSGD_optimizer_parity(self, averager, grad_is_view):
         BACKEND not in DistTestCases.backend_feature["ddp"],
         f"The {BACKEND} backend does not support DistributedDataParallel"
     )
-    def test_post_localSGD_optimizer_parity(self, grad_is_view):
+    def test_post_localSGD_optimizer_parity(self):
         torch.cuda.set_device(self.rank)

         averager = averagers.PeriodicModelAverager(period=4, warmup_steps=10)
         self._test_post_localSGD_optimizer_parity(averager, grad_is_view=False)
+
+    @skip_if_lt_x_gpu(2)
+    @sandcastle_skip_if(
+        BACKEND not in DistTestCases.backend_feature["ddp"],
+        f"The {BACKEND} backend does not support DistributedDataParallel"
+    )
+    def test_post_localSGD_optimizer_parity_grad_is_view(self):
+        torch.cuda.set_device(self.rank)
+
+        averager = averagers.PeriodicModelAverager(period=4, warmup_steps=10)
         self._test_post_localSGD_optimizer_parity(averager, grad_is_view=True)

     @skip_if_lt_x_gpu(4)
     @sandcastle_skip_if(
         BACKEND not in DistTestCases.backend_feature["ddp"],
         f"The {BACKEND} backend does not support DistributedDataParallel"
     )
-    def test_post_localSGD_optimizer_parity_with_hierarchical_sgd(self, grad_is_view):
+    def test_post_localSGD_optimizer_parity_with_hierarchical_sgd(self):
         torch.cuda.set_device(self.rank)

         period_group_size_dict = OrderedDict([(2, 2), (4, dist.get_world_size())])
         averager = hierarchicalSGD.HierarchicalModelAverager(
             period_group_size_dict=period_group_size_dict, warmup_steps=4
         )
         self._test_post_localSGD_optimizer_parity(averager, grad_is_view=False)
+
+    @skip_if_lt_x_gpu(4)
+    @sandcastle_skip_if(
+        BACKEND not in DistTestCases.backend_feature["ddp"],
+        f"The {BACKEND} backend does not support DistributedDataParallel"
+    )
+    def test_post_localSGD_optimizer_parity_with_hierarchical_sgd_grad_is_view(self):
+        torch.cuda.set_device(self.rank)
+
+        period_group_size_dict = OrderedDict([(2, 2), (4, dist.get_world_size())])
+        averager = hierarchicalSGD.HierarchicalModelAverager(
+            period_group_size_dict=period_group_size_dict, warmup_steps=4
+        )
         self._test_post_localSGD_optimizer_parity(averager, grad_is_view=True)

     @sandcastle_skip_if(
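
For context on what the split tests exercise: each new test builds a model averager and hands it to the shared _test_post_localSGD_optimizer_parity helper, with the grad_is_view variants presumably toggling DDP's gradient_as_bucket_view option. Below is a minimal sketch of the post-local SGD setup these tests revolve around, based on the public PyTorch APIs rather than the helper's actual body (build_post_localSGD_optimizer and its parameters are illustrative names only):

import torch
import torch.nn as nn
from torch.distributed.algorithms.model_averaging import averagers
from torch.distributed.optim import PostLocalSGDOptimizer

# Illustrative sketch, not the test helper itself; assumes the default process
# group is already initialized and `rank` identifies this process's GPU.
def build_post_localSGD_optimizer(model: nn.Module, rank: int, grad_is_view: bool):
    # The *_grad_is_view test variants are assumed to map to DDP's
    # gradient_as_bucket_view option, mirrored here.
    ddp_model = nn.parallel.DistributedDataParallel(
        model.to(rank),
        device_ids=[rank],
        gradient_as_bucket_view=grad_is_view,
    )
    local_optim = torch.optim.SGD(ddp_model.parameters(), lr=0.01)
    # Average model parameters across ranks every `period` steps after a
    # warm-up phase, matching the averager constructed in the tests above.
    averager = averagers.PeriodicModelAverager(period=4, warmup_steps=10)
    # PostLocalSGDOptimizer wraps the local optimizer and runs the averager
    # after each optimizer step.
    return ddp_model, PostLocalSGDOptimizer(optim=local_optim, averagers=averager)

The real helper presumably also registers the post-local SGD DDP communication hook and performs the actual parity comparison; the sketch only shows how an averager plugs into PostLocalSGDOptimizer.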
