
Commit f99b667

[CPU offload optim] Fix when there are non-trainable params (#1210)
fix non-trainable params
1 parent 59dab15 commit f99b667

File tree: 2 files changed, +4 −0 lines


test/prototype/test_low_bit_optim.py

Lines changed: 1 addition & 0 deletions
@@ -211,6 +211,7 @@ def test_optim_4bit_correctness(self, optim_name):
    def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
        device = "cuda"
        model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+       model1[0].requires_grad_(False)  # make sure it can work in the presence of non-trainable params
        model2 = copy.deepcopy(model1)

        optim1 = torch.optim.AdamW(model1.parameters())

torchao/prototype/low_bit_optim/cpu_offload.py

Lines changed: 3 additions & 0 deletions
@@ -66,6 +66,9 @@ def backward_hook(p_cuda):
            params = param_group.pop("params")

            for p_cuda in params:
+               if not p_cuda.requires_grad:
+                   continue
+
                # pre-allocate CPU params and grads
                p_cpu = torch.empty_like(p_cuda, device="cpu", pin_memory=True)
                p_cpu.grad = torch.empty_like(p_cpu, pin_memory=True)
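
For context, here is a minimal sketch of the scenario this fix enables: a model with a frozen (non-trainable) layer driven by the CPU-offload optimizer. The import path and constructor arguments of CPUOffloadOptimizer below are assumptions based on the module being patched, not something shown in this diff.

# Sketch only: the CPUOffloadOptimizer import path and constructor signature
# are assumptions about the torchao prototype API, not taken from this commit.
import torch
import torch.nn as nn
from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).cuda()
model[0].requires_grad_(False)  # frozen layer: previously broke the offload setup

# All parameters can be passed in; with this fix, params with
# requires_grad=False are skipped when pinned CPU buffers and
# backward hooks are set up, so only trainable params are offloaded.
optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, lr=1e-3)

x = torch.randn(8, 32, device="cuda")
model(x).sum().backward()
optim.step()
optim.zero_grad()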
