Skip to content

Commit 7584fbd

Browse files
sarckkfacebook-github-bot
authored andcommitted
Fix KeyedOptimizer init state test (#1874)
Summary: Pull Request resolved: #1874 `KeyedOptimizer.test_init_state` test started failing since upstream pytorch changes: pytorch/pytorch#122349 (and later follow up: pytorch/pytorch#123757), which only initializes the state for param groups if momentum is enabled for SGD. Updating unit test to enable momentum fixes it. Also adding a new unit test to check state if momentum is disabled. Reviewed By: henrylhtsang Differential Revision: D56076424 fbshipit-source-id: e3d5ae063a5187d2d8702ad7f0bb4b2791b954fe
1 parent 568e116 commit 7584fbd

File tree

1 file changed

+20
-2
lines changed

1 file changed

+20
-2
lines changed

torchrec/optim/tests/test_keyed.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -191,12 +191,12 @@ def test_non_param_state_key(self) -> None:
191191
[{"params": [param_1], "param_group_val_0": 3.0}],
192192
)
193193

194-
def test_init_state(self) -> None:
194+
def test_init_state_with_momentum(self) -> None:
195195
dense = torch.nn.Parameter(torch.ones((2, 3), dtype=torch.float))
196196
sparse = torch.nn.Parameter(torch.ones((1, 4), dtype=torch.float))
197197
opt = KeyedOptimizerWrapper(
198198
{"dense": dense, "sparse": sparse},
199-
lambda params: torch.optim.SGD(params, lr=0.1),
199+
lambda params: torch.optim.SGD(params, lr=0.1, momentum=0.1),
200200
)
201201
opt.init_state({"sparse"})
202202

@@ -208,6 +208,24 @@ def test_init_state(self) -> None:
208208
self.assertTrue(sparse.grad.is_sparse)
209209
self.assertTrue("momentum_buffer" in opt.state_dict()["state"]["sparse"])
210210

211+
def test_init_state_no_momentum(self) -> None:
212+
dense = torch.nn.Parameter(torch.ones((2, 3), dtype=torch.float))
213+
sparse = torch.nn.Parameter(torch.ones((1, 4), dtype=torch.float))
214+
opt = KeyedOptimizerWrapper(
215+
{"dense": dense, "sparse": sparse},
216+
lambda params: torch.optim.SGD(params, lr=0.1),
217+
)
218+
opt.init_state({"sparse"})
219+
220+
self.assertTrue(dense.grad is not None)
221+
self.assertFalse(dense.grad.is_sparse)
222+
223+
self.assertTrue(sparse.grad is not None)
224+
self.assertTrue(sparse.grad.is_sparse)
225+
226+
self.assertTrue("state" in opt.state_dict())
227+
self.assertFalse(opt.state_dict()["state"])
228+
211229
def test_pickle(self) -> None:
212230
dense = torch.nn.Parameter(torch.ones((2, 3), dtype=torch.float))
213231
sparse = torch.nn.Parameter(torch.ones((1, 4), dtype=torch.float))

0 commit comments

Comments
 (0)