@@ -4834,26 +4834,49 @@ def _test_post_localSGD_optimizer_parity(self, averager, grad_is_view):
         BACKEND not in DistTestCases.backend_feature["ddp"],
         f"The {BACKEND} backend does not support DistributedDataParallel"
     )
-    def test_post_localSGD_optimizer_parity(self, grad_is_view):
+    def test_post_localSGD_optimizer_parity(self):
         torch.cuda.set_device(self.rank)
 
         averager = averagers.PeriodicModelAverager(period=4, warmup_steps=10)
         self._test_post_localSGD_optimizer_parity(averager, grad_is_view=False)
+
+    @skip_if_lt_x_gpu(2)
+    @sandcastle_skip_if(
+        BACKEND not in DistTestCases.backend_feature["ddp"],
+        f"The {BACKEND} backend does not support DistributedDataParallel"
+    )
+    def test_post_localSGD_optimizer_parity_grad_is_view(self):
+        torch.cuda.set_device(self.rank)
+
+        averager = averagers.PeriodicModelAverager(period=4, warmup_steps=10)
         self._test_post_localSGD_optimizer_parity(averager, grad_is_view=True)
 
     @skip_if_lt_x_gpu(4)
     @sandcastle_skip_if(
         BACKEND not in DistTestCases.backend_feature["ddp"],
         f"The {BACKEND} backend does not support DistributedDataParallel"
     )
-    def test_post_localSGD_optimizer_parity_with_hierarchical_sgd(self, grad_is_view):
+    def test_post_localSGD_optimizer_parity_with_hierarchical_sgd(self):
         torch.cuda.set_device(self.rank)
 
         period_group_size_dict = OrderedDict([(2, 2), (4, dist.get_world_size())])
         averager = hierarchicalSGD.HierarchicalModelAverager(
             period_group_size_dict=period_group_size_dict, warmup_steps=4
         )
         self._test_post_localSGD_optimizer_parity(averager, grad_is_view=False)
+
+    @skip_if_lt_x_gpu(4)
+    @sandcastle_skip_if(
+        BACKEND not in DistTestCases.backend_feature["ddp"],
+        f"The {BACKEND} backend does not support DistributedDataParallel"
+    )
+    def test_post_localSGD_optimizer_parity_with_hierarchical_sgd_grad_is_view(self):
+        torch.cuda.set_device(self.rank)
+
+        period_group_size_dict = OrderedDict([(2, 2), (4, dist.get_world_size())])
+        averager = hierarchicalSGD.HierarchicalModelAverager(
+            period_group_size_dict=period_group_size_dict, warmup_steps=4
+        )
         self._test_post_localSGD_optimizer_parity(averager, grad_is_view=True)
 
     @sandcastle_skip_if(
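
For context, a minimal usage sketch of the post-local SGD APIs these tests exercise (not part of the diff). It assumes a script launched via torchrun with at least two CUDA devices and the NCCL backend; the model, tensor shapes, and hyperparameters are illustrative only.

import os

import torch
import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.averagers as averagers
from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import (
    PostLocalSGDState,
    post_localSGD_hook,
)
from torch.distributed.optim import PostLocalSGDOptimizer
from torch.nn.parallel import DistributedDataParallel as DDP

rank = int(os.environ["LOCAL_RANK"])
dist.init_process_group("nccl")
torch.cuda.set_device(rank)

model = DDP(torch.nn.Linear(10, 10).cuda(rank), device_ids=[rank])

# Run allreduce-based DDP for the first 10 iterations, then switch to local
# SGD: gradients stay local and parameters are periodically averaged instead.
state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=10)
model.register_comm_hook(state, post_localSGD_hook)

# warmup_steps should match start_localSGD_iter; after warmup, parameters are
# averaged across all ranks every `period` optimizer steps.
averager = averagers.PeriodicModelAverager(period=4, warmup_steps=10)
opt = PostLocalSGDOptimizer(
    optim=torch.optim.SGD(model.parameters(), lr=0.01),
    averager=averager,
)

for _ in range(20):
    opt.zero_grad()
    model(torch.randn(20, 10).cuda(rank)).sum().backward()
    opt.step()  # local optimizer step, then model averaging on the averager's schedule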
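The hierarchical tests swap in HierarchicalModelAverager. A hedged sketch of just that substitution (same training loop as above); period_group_size_dict maps an averaging period to a process-group size and mirrors the dict used in the tests, assuming a world size divisible by each group size.

from collections import OrderedDict

import torch.distributed as dist
import torch.distributed.algorithms.model_averaging.hierarchical_model_averager as hierarchicalSGD

# Every 2 steps, average within subgroups of 2 ranks; every 4 steps, average
# across the full world. When periods coincide, only the larger group runs.
period_group_size_dict = OrderedDict([(2, 2), (4, dist.get_world_size())])
averager = hierarchicalSGD.HierarchicalModelAverager(
    period_group_size_dict=period_group_size_dict, warmup_steps=4
)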