diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py
index e037ef8b11..4621f8ad76 100644
--- a/test/prototype/test_low_bit_optim.py
+++ b/test/prototype/test_low_bit_optim.py
@@ -1,5 +1,4 @@
 import copy
-from functools import partial
 
 import pytest
 import torch
@@ -10,9 +9,16 @@
     parametrize,
     run_tests,
 )
+from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
+    apply_activation_checkpointing,
+    CheckpointWrapper,
+)
+from torch.distributed.fsdp.wrap import ModuleWrapPolicy
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_fsdp import FSDPTest
 
 from torchao.prototype import low_bit_optim
 from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit
-from torchao.utils import TORCH_VERSION_AFTER_2_3
+from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
 
 try:
     import bitsandbytes as bnb
@@ -24,6 +30,10 @@ except ImportError:
     lpmm = None
 
+# for FSDP2 test
+if TORCH_VERSION_AFTER_2_4:
+    from torch.distributed._composable.fsdp import CPUOffloadPolicy, OffloadPolicy, fully_shard
+
 _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
 
 
@@ -156,6 +166,81 @@ def test_optim_fp8_smoke(self, optim_name, device):
             optim.zero_grad()
 
 
+class TestFSDP2(FSDPTest):
+    @property
+    def world_size(self) -> int:
+        return 2
+
+    @pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="torch >= 2.4 required")
+    @skip_if_lt_x_gpu(2)
+    def test_fsdp2(self):
+        self.run_subtests(
+            {
+                "enable_activation_checkpointing": [False, True],
+                "offload_policy": [
+                    OffloadPolicy(),
+                    # CPUOffloadPolicy(pin_memory=True),
+                    # CPUOffloadPolicy(pin_memory=False),
+                ],
+            },
+            self._test_fsdp2,
+        )
+
+    def _test_fsdp2(self, enable_activation_checkpointing, offload_policy):
+        from torch.testing._internal.distributed._tensor.common_dtensor import (
+            ModelArgs,
+            Transformer,
+            TransformerBlock,
+        )
+
+        batch_size = 3
+        vocab_size = 1024
+        seq_len = 64
+        model_args = ModelArgs(
+            n_layers=3,
+            n_heads=4,
+            dim=1024,
+            vocab_size=vocab_size,
+            max_seq_len=seq_len,
+            dropout_p=0,
+        )
+        torch.manual_seed(42)
+        with torch.device("cuda"):
+            base_model = Transformer(model_args)
+        if enable_activation_checkpointing:
+            apply_activation_checkpointing(base_model, auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}))
+        base_optim = low_bit_optim.Adam8bit(base_model.parameters(), lr=1e-2)
+
+        fsdp_kwargs = {"offload_policy": offload_policy}
+        fsdp_model = copy.deepcopy(base_model)
+        for m in fsdp_model.modules():
+            if enable_activation_checkpointing:
+                if isinstance(m, CheckpointWrapper):
+                    fully_shard(m, **fsdp_kwargs)
+            else:
+                if isinstance(m, TransformerBlock):
+                    fully_shard(m, **fsdp_kwargs)
+        fully_shard(fsdp_model, **fsdp_kwargs)
+        fsdp_optim = low_bit_optim.Adam8bit(fsdp_model.parameters(), lr=1e-2)
+
+        torch.manual_seed(42 + self.rank + 1)
+        for iter_idx in range(5):
+            inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
+            fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
+            fsdp_loss = fsdp_model(inp).sum()
+            fsdp_loss.backward()
+            fsdp_optim.step()
+
+            base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
+            base_loss = base_model(inp).sum()
+            base_loss.backward()
+            for param in base_model.parameters():
+                if param.grad is not None:
+                    torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
+            base_optim.step()
+            self.assertEqual(fsdp_loss, base_loss)
+
+
 instantiate_parametrized_tests(TestQuantize)
 instantiate_parametrized_tests(TestOptim)
diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py
index 6595711138..a37a4cab5a 100644
--- a/torchao/prototype/low_bit_optim/adam.py
+++ b/torchao/prototype/low_bit_optim/adam.py
@@ -3,6 +3,7 @@
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
+from torch.distributed._tensor import DTensor
 
 from .subclass_8bit import maybe_new_8bit_zero_buffer
 from .subclass_4bit import maybe_new_4bit_zero_buffer
@@ -48,16 +49,31 @@ def step(self, closure=None):
                 if grad.is_sparse:
                     raise RuntimeError("Sparse gradient is not supported")
 
+                # unwrap DTensor
+                # set requires_grad for unwrapped param to avoid torch.compile() recompilation
+                # if isinstance(p, DTensor):
+                #     p = p._local_tensor.requires_grad_(True)
+                # if isinstance(grad, DTensor):
+                #     grad = grad._local_tensor
+
                 state = self.state[p]
 
+                # uncomment this to flatten tensor to avoid recompiling
+                # but you will get the following error
+                # AssertionError: s8 (could be from ["L['grad']._base._local_tensor.size()[0]"]) not in {s3: ["L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]"], s4: ["L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]"], s12: ["L['grad']._local_tensor.size()[0]", "L['grad']._local_tensor.size()[0]"], s10: ["L['grad']._local_tensor.storage_offset()", "L['grad']._local_tensor.storage_offset()"], s16: ["L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]"], s23: ["L['p']._local_tensor.size()[0]", "L['p']._local_tensor.size()[0]"]}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
+                # p = p.view(-1)
+                # grad = grad.view(-1)
+
+                # without it, you hit cache size limit
+
                 # State initialization
-                # state is flattened so that torch.compile won't recompile for tensors with different ndim
+                # flatten buffer to avoid torch.compile() recompilation
                 if len(state) == 0:
                     state["step"] = torch.tensor(0.0, device=p.device)
-                    state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size)
-                    state["exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size)
+                    state["exp_avg"] = self._new_buffer(p, True, self.block_size)
+                    state["exp_avg_sq"] = self._new_buffer(p, False, self.block_size)
                     if group["amsgrad"]:
-                        state["max_exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size)
+                        state["max_exp_avg_sq"] = self._new_buffer(p, False, self.block_size)
 
                 state["step"] += 1
 
@@ -67,10 +83,10 @@
                 if not isinstance(group["lr"], Tensor):
                     group["lr"] = torch.tensor(group["lr"], device=p.device)
 
-                # flatten p and grad so that torch.compile won't recompile for tensors with different ndim
+                # flatten p and grad to avoid torch.compile() recompilation
                 single_param_adam(
-                    p.view(-1),
-                    grad.view(-1),
+                    p,
+                    grad,
                     state["step"],
                     state["exp_avg"],
                     state["exp_avg_sq"],
diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py
index 9397f04c3c..dc279a3b5c 100644
--- a/torchao/prototype/low_bit_optim/adamw.py
+++ b/torchao/prototype/low_bit_optim/adamw.py
@@ -3,6 +3,7 @@
 import torch
 from torch import Tensor
 from torch.optim import Optimizer
+from torch.distributed._tensor import DTensor
 
 from .subclass_8bit import maybe_new_8bit_zero_buffer
 from .subclass_4bit import maybe_new_4bit_zero_buffer
@@ -48,10 +49,17 @@ def step(self, closure=None):
                 if grad.is_sparse:
                     raise RuntimeError("Sparse gradient is not supported")
 
+                # unwrap DTensor
+                # set requires_grad for unwrapped param to avoid torch.compile() recompilation
+                if isinstance(p, DTensor):
+                    p = p._local_tensor.requires_grad_(True)
+                if isinstance(grad, DTensor):
+                    grad = grad._local_tensor
+
                 state = self.state[p]
 
                 # State initialization
-                # state is flattened so that torch.compile won't recompile for tensors with different ndim
+                # flatten buffer to avoid torch.compile() recompilation
                 if len(state) == 0:
                     state["step"] = torch.tensor(0.0, device=p.device)
                     state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size)
@@ -67,7 +75,7 @@
                 if not isinstance(group["lr"], Tensor):
                     group["lr"] = torch.tensor(group["lr"], device=p.device)
 
-                # flatten p and grad so that torch.compile won't recompile for tensors with different ndim
+                # flatten p and grad to avoid torch.compile() recompilation
                 single_param_adamw(
                     p.view(-1),
                     grad.view(-1),
diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py
index 44a3d593cf..52b166061a 100644
--- a/torchao/prototype/low_bit_optim/subclass_8bit.py
+++ b/torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -178,8 +178,8 @@ def _(func, *args, **kwargs):
 
     elif isinstance(dst, OptimState8bit):
         codes, scale = quantize_8bit_with_qmap(src, dst.qmap, dst.block_size)
-        dst.codes.copy_(codes)
-        dst.scale.copy_(scale)
+        dst.codes.copy_(codes.view(dst.codes.shape))
+        dst.scale.copy_(scale.view(dst.scale.shape))
 
     else:
         dst.copy_(src.dequantize())
@@ -198,7 +198,13 @@ def _(func, *args, **kwargs):
 # TODO: also skip 1D tensor? e.g. biases and norm scales
 def maybe_new_8bit_zero_buffer(p: Tensor, signed: bool = True, block_size: int = 2048):
     if p.numel() >= 4096 and p.numel() % block_size == 0:
-        out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device)
+        from torch.distributed._tensor import DTensor
+
+        if isinstance(p, DTensor):
+            out = torch.empty_like(p)
+            out._local_tensor = OptimState8bit.zeros(out._local_tensor.shape, signed, block_size, device=out._local_tensor.device)
+        else:
+            out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device)
     else:
         out = torch.zeros_like(p)
     return out
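
For reference, a minimal sketch (not part of the patch) of the wrapping idea behind the DTensor branch in maybe_new_8bit_zero_buffer: allocate the quantized 8-bit state for the local shard only, then wrap it back into a DTensor that mirrors the parameter's device mesh and placements. It uses DTensor.from_local instead of the empty_like + _local_tensor assignment above; whether from_local composes cleanly with the OptimState8bit subclass is an assumption, and the helper name below is hypothetical.

import torch
from torch.distributed._tensor import DTensor

from torchao.prototype.low_bit_optim.subclass_8bit import OptimState8bit


def new_8bit_state_for(p: torch.Tensor, signed: bool = True, block_size: int = 2048):
    # Hypothetical helper: build 8-bit optimizer state matching a (possibly DTensor) param.
    if isinstance(p, DTensor):
        # Quantize only the local shard, then re-wrap with the parameter's mesh/placements.
        local = OptimState8bit.zeros(p._local_tensor.shape, signed, block_size, device=p._local_tensor.device)
        return DTensor.from_local(local, p.device_mesh, p.placements, run_check=False)
    return OptimState8bit.zeros(p.shape, signed, block_size, device=p.device)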