Hi all, I was giving the `CPUOffloadOptimizer` a try and found two issues when using it with QLoRA single device in torchtune:
- When using an LR scheduler I got the error below. Maybe there is a way to inherit the optimizer class?
```
File "/data/users/felipemello/torchtune/torchtune/training/lr_schedulers.py", line 58, in get_cosine_schedule_with_warmup
    return LambdaLR(optimizer, lr_lambda, last_epoch)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 336, in __init__
    super().__init__(optimizer, last_epoch, verbose)
File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/optim/lr_scheduler.py", line 99, in __init__
    raise TypeError(f"{type(optimizer).__name__} is not an Optimizer")
TypeError: CPUOffloadOptimizer is not an Optimizer
```
- When passing `model.parameters()` I got the error below. I imagine a simple fix is to keep only the params that require grad, like the AdamW implementation does.
```
File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torchao/prototype/low_bit_optim/cpu_offload.py", line 76, in __init__
    p_cuda.register_post_accumulate_grad_hook(backward_hook)
File "/home/felipemello/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_tensor.py", line 678, in register_post_accumulate_grad_hook
    raise RuntimeError(
RuntimeError: cannot register a hook on a tensor that doesn't require gradient
```
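For the second issue, a minimal caller-side workaround is to filter out frozen parameters before handing them to the optimizer, so no hook is registered on a tensor that doesn't require grad. This is an illustrative sketch, not torchtune's actual recipe code, and the commented `CPUOffloadOptimizer` call is an assumption about its signature:

```python
import torch
import torch.nn as nn

# Toy model standing in for a QLoRA setup: base layer frozen, head trainable.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
for p in model[0].parameters():
    p.requires_grad_(False)  # "frozen base" weights

# Pass only trainable params; this avoids registering
# post-accumulate-grad hooks on frozen tensors.
trainable = [p for p in model.parameters() if p.requires_grad]
# optimizer = CPUOffloadOptimizer(trainable, torch.optim.AdamW, lr=1e-4)  # hypothetical usage
```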
cc: @gau-nernst
gau-nernst commented on Nov 1, 2024
1 is a known issue. You can see my view here: #959 (comment). I will look into the `torch.optim.Optimizer` base class to see what could go wrong if I make `CPUOffloadOptimizer` inherit from it. For example, off the top of my head, `CPUOffloadOptimizer` will not have `self.state`. In the meantime, `CPUOffloadOptimizer` requires setting the LR manually: #584 (comment)

For 2, it's an oversight on my part. We can simply add a requires-grad check here. Will push a fix.
ao/torchao/prototype/low_bit_optim/cpu_offload.py
Lines 68 to 77 in 2761917
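The manual-LR workaround mentioned above can be sketched as follows. This is a minimal illustration using `AdamW` as a stand-in (it assumes the optimizer exposes `param_groups`; the cosine schedule here is just an example, not torchtune's exact formula):

```python
import math

import torch

# Stand-in trainable params and optimizer; in practice this would be
# CPUOffloadOptimizer wrapping the QLoRA trainable params.
params = [torch.zeros(4, requires_grad=True)]
optimizer = torch.optim.AdamW(params, lr=1e-4)

# Since the offload optimizer cannot be passed to an LR scheduler,
# compute the LR yourself each step and write it into every param group.
total_steps, base_lr = 100, 1e-4
for step in range(total_steps):
    new_lr = base_lr * 0.5 * (1 + math.cos(math.pi * step / total_steps))
    for group in optimizer.param_groups:
        group["lr"] = new_lr
    # ... forward / backward / optimizer.step() would go here ...
```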
fzyzcjy commented on Nov 18, 2024
Hi, are there any updates? Thanks! It would be great if it could be plugged directly into Hugging Face transformers, but currently it fails with the scheduler issue above.
gau-nernst commented on Nov 19, 2024
@fzyzcjy To unblock your case, you can try making `CPUOffloadOptimizer` a subclass of `torch.optim.Optimizer`, i.e. change the following line

ao/torchao/prototype/low_bit_optim/cpu_offload.py
Line 9 in aeff75b

to `class CPUOffloadOptimizer(Optimizer):`. Make sure not to call `super().__init__()`, as this is just a workaround to pass the class check in the PyTorch LR scheduler. I will investigate whether this causes other issues before merging the fix.

IMO, since Python is duck-typed, the PyTorch LR scheduler should not explicitly check for the optimizer class.
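The workaround can be sketched with a minimal stand-in class (the name `OffloadOptimizerStub` and its internals are hypothetical; the real `CPUOffloadOptimizer` keeps its own constructor logic and only gains the base class):

```python
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR

class OffloadOptimizerStub(Optimizer):  # hypothetical stand-in for CPUOffloadOptimizer
    def __init__(self, params, lr=1e-3):
        # Deliberately NOT calling super().__init__(): we inherit only so that
        # isinstance(optimizer, Optimizer) passes inside the LR scheduler.
        self.param_groups = [{"params": list(params), "lr": lr}]
        self.defaults = {"lr": lr}

    def step(self, closure=None):
        pass  # the real class steps the per-param CPU optimizers here

    def zero_grad(self, set_to_none=True):
        pass

opt = OffloadOptimizerStub([torch.zeros(2, requires_grad=True)], lr=0.1)
sched = LambdaLR(opt, lambda step: 0.5)  # no longer raises TypeError
```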
fzyzcjy commentedon Nov 19, 2024
Thank you!