Gradient clipping doesn't work with FSDP CPU offloading #1977

@acisseJZhong

Description

I am running the full_finetune_distributed recipe. Setting `clip_grad_norm: 1.0` together with `fsdp_cpu_offload: True` raises:

```
RuntimeError: No backend type associated with device type cpu
```

Full error stack trace:

```
[rank2]: Traceback (most recent call last):
[rank2]:   File "/torchtune/recipes/full_finetune_distributed.py", line 847, in <module>
[rank2]:     sys.exit(recipe_main())
[rank2]:              ^^^^^^^^^^^^^
[rank2]:   File "/torchtune/torchtune/config/_parse.py", line 99, in wrapper
[rank2]:     sys.exit(recipe_main(conf))
[rank2]:              ^^^^^^^^^^^^^^^^^
[rank2]:   File "/torchtune/recipes/full_finetune_distributed.py", line 842, in recipe_main
[rank2]:     recipe.train()
[rank2]:   File "/torchtune/recipes/full_finetune_distributed.py", line 740, in train
[rank2]:     grad_norm = torch.nn.utils.clip_grad_norm_(
[rank2]:                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py", line 30, in _no_grad_wrapper
[rank2]:     return func(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/nn/utils/clip_grad.py", line 105, in clip_grad_norm_
[rank2]:     clip_coef = max_norm / (total_norm + 1e-6)
[rank2]:                 ~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_tensor.py", line 39, in wrapped
[rank2]:     return f(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_tensor.py", line 1075, in __rdiv__
[rank2]:     return self.reciprocal() * other
[rank2]:            ^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_compile.py", line 32, in inner
[rank2]:     return disable_fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 721, in _fn
[rank2]:     return fn(*args, **kwargs)
[rank2]:            ^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
[rank2]:     return DTensor._op_dispatcher.dispatch(
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 181, in dispatch
[rank2]:     self.redistribute_local_args(
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 317, in redistribute_local_args
[rank2]:     resharded_local_tensor = redistribute_local_tensor(
[rank2]:                              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_redistribute.py", line 208, in redistribute_local_tensor
[rank2]:     new_local_tensor = partial_spec._reduce_value(
[rank2]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/_ops/_math_ops.py", line 126, in _reduce_value
[rank2]:     reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim)
[rank2]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/tensor/placement_types.py", line 599, in _reduce_value
[rank2]:     return funcol.all_reduce(
[rank2]:            ^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/distributed/_functional_collectives.py", line 176, in all_reduce
[rank2]:     tensor = torch.ops._c10d_functional.all_reduce(self, reduceOp.lower(), group_name)
[rank2]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]:   File "/home/jessicazhong/.conda/envs/torchtune/lib/python3.11/site-packages/torch/_ops.py", line 1123, in __call__
[rank2]:     return self._op(*args, **(kwargs or {}))
[rank2]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: RuntimeError: No backend type associated with device type cpu
```
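
For context, I think the failing pattern can be reduced to roughly the following standalone sketch (my own minimal repro, not recipe code; it assumes PyTorch's FSDP2 `fully_shard` / `CPUOffloadPolicy` API that torchtune builds on, and would be launched with torchrun on at least one GPU):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

# Default NCCL-only process group: no collective backend is registered
# for CPU tensors.
dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

model = nn.Linear(16, 16, device="cuda")
fully_shard(model, offload_policy=CPUOffloadPolicy())

model(torch.randn(4, 16, device="cuda")).sum().backward()

# With CPU offload the sharded grads are CPU DTensors, so the total norm
# is a Partial DTensor whose reduction issues an all_reduce on CPU ->
# RuntimeError: No backend type associated with device type cpu
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
```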

I'm wondering how we should fix this error?
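
One possible fix (an assumption on my part, not a confirmed torchtune change): register a CPU-capable backend alongside NCCL when the recipe initializes distributed, using PyTorch's device:backend syntax, so the CPU all_reduce inside clip_grad_norm_ has a backend to run on:

```python
import torch.distributed as dist

# Register Gloo for CPU tensors and NCCL for CUDA tensors in the same
# process group; collectives on CPU-offloaded DTensor grads then route
# through Gloo instead of failing with "No backend type ... cpu".
dist.init_process_group(backend="cuda:nccl,cpu:gloo")
```

In the recipe this would replace the NCCL-only init when fsdp_cpu_offload is enabled; Gloo would only carry the small grad-norm reduction, so the overhead should be negligible.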
