Low-bit optim support for DTensor [to be closed] #490

Closed · wants to merge 11 commits
89 changes: 87 additions & 2 deletions test/prototype/test_low_bit_optim.py
@@ -1,5 +1,4 @@
import copy
from functools import partial

import pytest
import torch
@@ -10,9 +9,16 @@
parametrize,
run_tests,
)
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
apply_activation_checkpointing,
CheckpointWrapper,
)
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
from torchao.prototype import low_bit_optim
from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit
from torchao.utils import TORCH_VERSION_AFTER_2_3
from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4

try:
import bitsandbytes as bnb
@@ -24,6 +30,10 @@
except ImportError:
lpmm = None

# for FSDP2 test
if TORCH_VERSION_AFTER_2_4:
from torch.distributed._composable.fsdp import CPUOffloadPolicy, OffloadPolicy, fully_shard


_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])

@@ -156,6 +166,81 @@ def test_optim_fp8_smoke(self, optim_name, device):
optim.zero_grad()


class TestFSDP2(FSDPTest):
@property
def world_size(self) -> int:
return 2

@pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="torch >= 2.4 required")
@skip_if_lt_x_gpu(2)
def test_fsdp2(self):
self.run_subtests(
{
"enable_activation_checkpointing": [False, True],
"offload_policy": [
OffloadPolicy(),
# CPUOffloadPolicy(pin_memory=True),
# CPUOffloadPolicy(pin_memory=False),
],
},
self._test_fsdp2,
)

def _test_fsdp2(self, enable_activation_checkpointing, offload_policy):
from torch.testing._internal.distributed._tensor.common_dtensor import (
ModelArgs,
Transformer,
TransformerBlock,
)

batch_size = 3
vocab_size = 1024
seq_len = 64
model_args = ModelArgs(
n_layers=3,
n_heads=4,
dim=1024,
vocab_size=vocab_size,
max_seq_len=seq_len,
dropout_p=0,
)
torch.manual_seed(42)
with torch.device("cuda"):
base_model = Transformer(model_args)
if enable_activation_checkpointing:
apply_activation_checkpointing(base_model, auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}))
base_optim = low_bit_optim.Adam8bit(base_model.parameters(), lr=1e-2)

fsdp_kwargs = {"offload_policy": offload_policy}
fsdp_model = copy.deepcopy(base_model)
for m in fsdp_model.modules():
if enable_activation_checkpointing:
if isinstance(m, CheckpointWrapper):
fully_shard(m, **fsdp_kwargs)
else:
if isinstance(m, TransformerBlock):
fully_shard(m, **fsdp_kwargs)
fully_shard(fsdp_model, **fsdp_kwargs)
fsdp_optim = low_bit_optim.Adam8bit(fsdp_model.parameters(), lr=1e-2)

torch.manual_seed(42 + self.rank + 1)
for iter_idx in range(5):
inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
fsdp_loss = fsdp_model(inp).sum()
fsdp_loss.backward()
fsdp_optim.step()

base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
base_loss = base_model(inp).sum()
base_loss.backward()
for param in base_model.parameters():
if param.grad is not None:
torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG)
base_optim.step()
self.assertEqual(fsdp_loss, base_loss)


instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)

30 changes: 23 additions & 7 deletions torchao/prototype/low_bit_optim/adam.py
@@ -3,6 +3,7 @@
import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.distributed._tensor import DTensor

from .subclass_8bit import maybe_new_8bit_zero_buffer
from .subclass_4bit import maybe_new_4bit_zero_buffer
@@ -48,16 +49,31 @@ def step(self, closure=None):
if grad.is_sparse:
raise RuntimeError("Sparse gradient is not supported")

# unwrap DTensor
# set requires_grad for unwrapped param to avoid torch.compile() recompilation
# if isinstance(p, DTensor):
# p = p._local_tensor.requires_grad_(True)
# if isinstance(grad, DTensor):
# grad = grad._local_tensor

state = self.state[p]

# uncomment this to flatten tensor to avoid recompiling
# but you will get the following error
# AssertionError: s8 (could be from ["L['grad']._base._local_tensor.size()[0]"]) not in {s3: ["L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]"], s4: ["L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]"], s12: ["L['grad']._local_tensor.size()[0]", "L['grad']._local_tensor.size()[0]"], s10: ["L['grad']._local_tensor.storage_offset()", "L['grad']._local_tensor.storage_offset()"], s16: ["L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]"], s23: ["L['p']._local_tensor.size()[0]", "L['p']._local_tensor.size()[0]"]}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
# p = p.view(-1)
# grad = grad.view(-1)
Comment from @gau-nernst (Collaborator, Author), Jul 10, 2024:
@bdhirsh The .view(-1) is here. The snippet in the PR description only shows that torch.compile can't generate a dynamic kernel for DTensor (from what I understand).

To give more context: the low-bit optim here works fine with normal tensors. To avoid re-compilation when p has a different ndim, I call .view(-1) to flatten it (I observe that torch.compile(dynamic=True) still re-compiles when p has a different ndim, hence the .view(-1) trick).

If I use the same trick for DTensor, the error above shows up.

The recompilation issue is worse for DTensor, because it seems torch.compile can't generate a dynamic kernel for it (as shown by the standalone example in the PR description).
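For reference, here is a minimal standalone sketch of the flattening trick for plain tensors (not part of this PR; the toy step function and its hyperparameters are made up for illustration). Flattening keeps ndim fixed at 1, so a single dynamic kernel can be reused for parameters of any shape:

```python
import torch

# Toy single-param update compiled with dynamic shapes. Each distinct ndim of
# the inputs would normally trigger a fresh compile; flattening with .view(-1)
# keeps ndim fixed at 1 so one kernel is reused across all parameter shapes.
@torch.compile(dynamic=True)
def toy_single_param_step(p, grad, lr):
    p.add_(grad, alpha=-lr)

params = [torch.randn(16), torch.randn(8, 4), torch.randn(2, 3, 4)]  # 1D, 2D, 3D
grads = [torch.randn_like(p) for p in params]

for p, g in zip(params, grads):
    # .view(-1) returns a view, so the in-place update still lands in p's storage.
    toy_single_param_step(p.view(-1), g.view(-1), lr=1e-2)
```

The DTensor counterpart isn't shown standalone here since it needs an initialized process group; per the comment above, applying the same flattening to a DTensor trips the dynamic-shape guard quoted in the in-code comment.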


# without it, you hit cache size limit

# State initialization
# state is flattened so that torch.compile won't recompile for tensors with different ndim
# flatten buffer to avoid torch.compile() recompilation
if len(state) == 0:
state["step"] = torch.tensor(0.0, device=p.device)
state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size)
state["exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size)
state["exp_avg"] = self._new_buffer(p, True, self.block_size)
state["exp_avg_sq"] = self._new_buffer(p, False, self.block_size)
if group["amsgrad"]:
state["max_exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size)
state["max_exp_avg_sq"] = self._new_buffer(p, False, self.block_size)

state["step"] += 1

@@ -67,10 +83,10 @@ def step(self, closure=None):
if not isinstance(group["lr"], Tensor):
group["lr"] = torch.tensor(group["lr"], device=p.device)

# flatten p and grad so that torch.compile won't recompile for tensors with different ndim
# flatten p and grad to avoid torch.compile() recompilation
single_param_adam(
p.view(-1),
grad.view(-1),
p,
grad,
state["step"],
state["exp_avg"],
state["exp_avg_sq"],
12 changes: 10 additions & 2 deletions torchao/prototype/low_bit_optim/adamw.py
@@ -3,6 +3,7 @@
import torch
from torch import Tensor
from torch.optim import Optimizer
from torch.distributed._tensor import DTensor

from .subclass_8bit import maybe_new_8bit_zero_buffer
from .subclass_4bit import maybe_new_4bit_zero_buffer
@@ -48,10 +49,17 @@ def step(self, closure=None):
if grad.is_sparse:
raise RuntimeError("Sparse gradient is not supported")

# unwrap DTensor
# set requires_grad for unwrapped param to avoid torch.compile() recompilation
if isinstance(p, DTensor):
p = p._local_tensor.requires_grad_(True)
if isinstance(grad, DTensor):
grad = grad._local_tensor

state = self.state[p]

# State initialization
# state is flattened so that torch.compile won't recompile for tensors with different ndim
# flatten buffer to avoid torch.compile() recompilation
if len(state) == 0:
state["step"] = torch.tensor(0.0, device=p.device)
state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size)
@@ -67,7 +75,7 @@ def step(self, closure=None):
if not isinstance(group["lr"], Tensor):
group["lr"] = torch.tensor(group["lr"], device=p.device)

# flatten p and grad so that torch.compile won't recompile for tensors with different ndim
# flatten p and grad to avoid torch.compile() recompilation
single_param_adamw(
p.view(-1),
grad.view(-1),
12 changes: 9 additions & 3 deletions torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -178,8 +178,8 @@ def _(func, *args, **kwargs):

elif isinstance(dst, OptimState8bit):
codes, scale = quantize_8bit_with_qmap(src, dst.qmap, dst.block_size)
dst.codes.copy_(codes)
dst.scale.copy_(scale)
dst.codes.copy_(codes.view(dst.codes.shape))
dst.scale.copy_(scale.view(dst.scale.shape))

else:
dst.copy_(src.dequantize())
@@ -198,7 +198,13 @@ def _(func, *args, **kwargs):
# TODO: also skip 1D tensor? e.g. biases and norm scales
def maybe_new_8bit_zero_buffer(p: Tensor, signed: bool = True, block_size: int = 2048):
if p.numel() >= 4096 and p.numel() % block_size == 0:
out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device)
from torch.distributed._tensor import DTensor

if isinstance(p, DTensor):
out = torch.empty_like(p)
out._local_tensor = OptimState8bit.zeros(out._local_tensor.shape, signed, block_size, device=out._local_tensor.device)
else:
out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device)
else:
out = torch.zeros_like(p)
return out