From 81e516bbadddd1a55e73a7379b89a48786badf0e Mon Sep 17 00:00:00 2001
From: Thien Tran
Date: Fri, 1 Nov 2024 13:37:38 +0800
Subject: [PATCH] fix non-trainable params

---
 test/prototype/test_low_bit_optim.py           | 1 +
 torchao/prototype/low_bit_optim/cpu_offload.py | 3 +++
 2 files changed, 4 insertions(+)

diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index 8db22ad86e..39f97896bf 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -211,6 +211,7 @@ def test_optim_4bit_correctness(self, optim_name):
     def test_optim_cpu_offload_correctness(self, offload_grad, grad_accum):
         device = "cuda"
         model1 = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
+        model1[0].requires_grad_(False)  # make sure it can work in the presence of non-trainable params
         model2 = copy.deepcopy(model1)
         optim1 = torch.optim.AdamW(model1.parameters())
diff --git a/torchao/prototype/low_bit_optim/cpu_offload.py b/torchao/prototype/low_bit_optim/cpu_offload.py
index 1a4ed58168..6a8671082c 100644
--- a/torchao/prototype/low_bit_optim/cpu_offload.py
+++ b/torchao/prototype/low_bit_optim/cpu_offload.py
@@ -66,6 +66,9 @@ def backward_hook(p_cuda):
             params = param_group.pop("params")

             for p_cuda in params:
+                if not p_cuda.requires_grad:
+                    continue
+
                 # pre-allocate CPU params and grads
                 p_cpu = torch.empty_like(p_cuda, device="cpu", pin_memory=True)
                 p_cpu.grad = torch.empty_like(p_cpu, pin_memory=True)
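
Usage sketch (not part of the patch): a minimal example, assuming CPUOffloadOptimizer is importable from torchao.prototype.low_bit_optim and accepts an optimizer class plus an offload_gradients keyword as in cpu_offload.py above. It reproduces the scenario covered by the new test, where a frozen layer contributes parameters with requires_grad=False that the optimizer should now skip instead of offloading.

    import torch
    from torch import nn
    from torchao.prototype.low_bit_optim import CPUOffloadOptimizer

    device = "cuda"
    model = nn.Sequential(nn.Linear(32, 1024), nn.ReLU(), nn.Linear(1024, 128)).to(device)
    model[0].requires_grad_(False)  # frozen layer: its params have requires_grad=False

    # with this patch, non-trainable params are skipped during CPU pre-allocation
    optim = CPUOffloadOptimizer(model.parameters(), torch.optim.AdamW, offload_gradients=True)

    x = torch.randn(4, 32, device=device)
    model(x).sum().backward()
    optim.step()
    optim.zero_grad()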