From 9331495fc8f81d94bb2a10dd32325e1da9b0fd06 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 8 Jul 2024 08:46:40 +0800 Subject: [PATCH 01/11] add test for FSDP2 --- test/prototype/test_low_bit_optim.py | 96 +++++++++++++++++++++++++++- 1 file changed, 94 insertions(+), 2 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index e037ef8b11..3232c87dc6 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -1,5 +1,4 @@ import copy -from functools import partial import pytest import torch @@ -10,9 +9,16 @@ parametrize, run_tests, ) +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + apply_activation_checkpointing, + CheckpointWrapper, +) +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.testing._internal.common_distributed import skip_if_lt_x_gpu +from torch.testing._internal.common_fsdp import FSDPTest from torchao.prototype import low_bit_optim from torchao.prototype.low_bit_optim import subclass_8bit, subclass_4bit -from torchao.utils import TORCH_VERSION_AFTER_2_3 +from torchao.utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4 try: import bitsandbytes as bnb @@ -156,6 +162,92 @@ def test_optim_fp8_smoke(self, optim_name, device): optim.zero_grad() +class TestFSDP2(FSDPTest): + @property + def world_size(self) -> int: + return 2 + + @pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="torch >= 2.4 required") + @skip_if_lt_x_gpu(2) + def test_qlora_fsdp2(self): + from torch.distributed._composable.fsdp import CPUOffloadPolicy, OffloadPolicy + + self.run_subtests( + { + "enable_activation_checkpointing": [False, True], + "offload_policy": [ + OffloadPolicy(), + CPUOffloadPolicy(pin_memory=True), + CPUOffloadPolicy(pin_memory=False), + ], + }, + self._test_fsdp2, + ) + + def _test_fsdp2( + self, + enable_activation_checkpointing: bool, + offload_policy: "OffloadPolicy", + ): + from torch.distributed._composable.fsdp import fully_shard + from torch.testing._internal.distributed._tensor.common_dtensor import ( + ModelArgs, + Transformer, + TransformerBlock, + ) + + batch_size = 3 + vocab_size = 1024 + seq_len = 64 + model_args = ModelArgs( + n_layers=3, + n_heads=4, + dim=1024, + vocab_size=vocab_size, + max_seq_len=seq_len, + dropout_p=0, + ) + torch.manual_seed(42) + with torch.device("cuda"): + base_model = Transformer(model_args) + if enable_activation_checkpointing: + apply_activation_checkpointing( + base_model, auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}) + ) + base_optim = low_bit_optim.Adam8bit(base_model.parameters(), lr=1e-2) + + fsdp_kwargs = {"offload_policy": offload_policy} + fsdp_model = copy.deepcopy(base_model) + for m in fsdp_model.modules(): + if enable_activation_checkpointing: + if isinstance(m, CheckpointWrapper): + fully_shard(m, **fsdp_kwargs) + else: + if isinstance(m, TransformerBlock): + fully_shard(m, **fsdp_kwargs) + fully_shard(fsdp_model, **fsdp_kwargs) + fsdp_optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2) + + torch.manual_seed(42 + self.rank + 1) + for iter_idx in range(5): + inp = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda") + fsdp_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + fsdp_loss = fsdp_model(inp).sum() + fsdp_loss.backward() + fsdp_optim.step() + + base_optim.zero_grad(set_to_none=(iter_idx % 2 == 0)) + base_loss = base_model(inp).sum() + base_loss.backward() + for param in base_model.parameters(): + if param.grad is not None: + torch.distributed.all_reduce( + param.grad, op=torch.distributed.ReduceOp.AVG + ) + base_optim.step() + self.assertEqual(fsdp_loss, base_loss) + + instantiate_parametrized_tests(TestQuantize) instantiate_parametrized_tests(TestOptim) From e159e9328efe95f4ef0b7b9a2e08c1ee60b87ef6 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 8 Jul 2024 09:29:13 +0800 Subject: [PATCH 02/11] fix optim --- test/prototype/test_low_bit_optim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 3232c87dc6..a9bb848f50 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -226,7 +226,7 @@ def _test_fsdp2( if isinstance(m, TransformerBlock): fully_shard(m, **fsdp_kwargs) fully_shard(fsdp_model, **fsdp_kwargs) - fsdp_optim = torch.optim.AdamW(fsdp_model.parameters(), lr=1e-2) + fsdp_optim = low_bit_optim.Adam8bit(fsdp_model.parameters(), lr=1e-2) torch.manual_seed(42 + self.rank + 1) for iter_idx in range(5): From 2fc88300595cd9060c51f7c982d6781067b63d90 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 8 Jul 2024 05:53:57 +0000 Subject: [PATCH 03/11] add some fsdp2 ops --- test/prototype/test_low_bit_optim.py | 2 +- .../prototype/low_bit_optim/subclass_8bit.py | 54 +++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index a9bb848f50..cd860085ec 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -169,7 +169,7 @@ def world_size(self) -> int: @pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="torch >= 2.4 required") @skip_if_lt_x_gpu(2) - def test_qlora_fsdp2(self): + def test_fsdp2(self): from torch.distributed._composable.fsdp import CPUOffloadPolicy, OffloadPolicy self.run_subtests( diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 44a3d593cf..35b299bb5c 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -1,5 +1,7 @@ import torch from torch import Tensor +from torch.utils._python_dispatch import return_and_correct_aliasing + from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE @@ -157,6 +159,14 @@ def __repr__(self): f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})" ) + def _apply_fn_to_data(self, fn, *args, **kwargs): + return self.__class__( + fn(self.codes, *args, **kwargs), + fn(self.scale, *args, **kwargs), + fn(self.qmap, *args, **kwargs), + self.signed, + ) + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: @@ -193,6 +203,50 @@ def _(func, *args, **kwargs): return func(*args, **kwargs) +# the following ops are required for FSDP +@OptimState8bit.implements(aten._to_copy.default) +def _(func, *args, **kwargs): + # ignore dtype and layout + kwargs.pop("dtype") + kwargs.pop("layout") + return args[0]._apply_fn_to_data(func, **kwargs) + + +@OptimState8bit.implements(aten.detach.default) +def _(func, *args, **kwargs): + return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) + + +@OptimState8bit.implements(aten.split.Tensor) +def _(func, *args, **kwargs): + tensor: OptimState8bit = args[0] + split_size = args[1] + dim = args[2] if len(args) >= 3 else 0 + + assert dim == 0 # only support splitting dim=0 + assert isinstance(split_size, int) # don't support list + assert tensor.numel() % split_size == 0 + + codes_list = torch.split(tensor.codes, split_size, dim) + scale_list = torch.split(tensor.scale, split_size, dim) + + outputs = [] + for codes, scale in zip(codes_list, scale_list): + outputs.append(OptimState8bit(codes, scale, tensor.qmap.clone(), tensor.signed)) + return tuple(outputs) + + +@OptimState8bit.implements(aten.empty_like.default) +def _(func, *args, **kwargs): + tensor: OptimState8bit = args[0] + return OptimState8bit( + torch.empty_like(tensor.codes), + torch.empty_like(tensor.scale), + tensor.qmap.clone(), + tensor.signed, + ) + + # follow bitsandbytes # only apply quantization for tensor with more than 4096 values # TODO: also skip 1D tensor? e.g. biases and norm scales From e30f142adf68f94e077110cde096efc45721ce8a Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 8 Jul 2024 09:36:20 +0000 Subject: [PATCH 04/11] add DTensor --- .../prototype/low_bit_optim/subclass_8bit.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 35b299bb5c..ff3d66a626 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -252,7 +252,25 @@ def _(func, *args, **kwargs): # TODO: also skip 1D tensor? e.g. biases and norm scales def maybe_new_8bit_zero_buffer(p: Tensor, signed: bool = True, block_size: int = 2048): if p.numel() >= 4096 and p.numel() % block_size == 0: - out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device) + from torch.distributed._tensor import DTensor + + if isinstance(p, DTensor): + p_local = p._local_tensor + out_local = OptimState8bit.zeros(p_local.shape, signed, block_size, device=p_local.device) + out = DTensor( + local_tensor=out_local, + device_mesh=p.device_mesh, + placements=p.placements, + shape=p.size(), + dtype=p_local.dtype, + stride=p.stride(), + requires_grad=p.requires_grad, + ) + + else: + out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device) + else: out = torch.zeros_like(p) + return out From a88750cefa77f65834b6b0a68bb651a9fef5b59d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 8 Jul 2024 21:12:49 +0800 Subject: [PATCH 05/11] try DTensor --- test/prototype/test_low_bit_optim.py | 10 +++------- torchao/prototype/low_bit_optim/adam.py | 22 +++++++++++++++------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index cd860085ec..13439151a5 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -177,18 +177,14 @@ def test_fsdp2(self): "enable_activation_checkpointing": [False, True], "offload_policy": [ OffloadPolicy(), - CPUOffloadPolicy(pin_memory=True), - CPUOffloadPolicy(pin_memory=False), + # CPUOffloadPolicy(pin_memory=True), + # CPUOffloadPolicy(pin_memory=False), ], }, self._test_fsdp2, ) - def _test_fsdp2( - self, - enable_activation_checkpointing: bool, - offload_policy: "OffloadPolicy", - ): + def _test_fsdp2(self, enable_activation_checkpointing, offload_policy): from torch.distributed._composable.fsdp import fully_shard from torch.testing._internal.distributed._tensor.common_dtensor import ( ModelArgs, diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 6595711138..aeaaaa597e 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -3,6 +3,7 @@ import torch from torch import Tensor from torch.optim import Optimizer +from torch.distributed._tensor import DTensor from .subclass_8bit import maybe_new_8bit_zero_buffer from .subclass_4bit import maybe_new_4bit_zero_buffer @@ -48,16 +49,24 @@ def step(self, closure=None): if grad.is_sparse: raise RuntimeError("Sparse gradient is not supported") + # unwrap DTensor + if isinstance(p, DTensor): + p = p._local_tensor.requires_grad_(True) + if isinstance(grad, DTensor): + grad = grad._local_tensor + + # flatten p and grad so that torch.compile won't recompile for tensors with different ndim + p = p.view(-1) + grad = grad.view(-1) state = self.state[p] # State initialization - # state is flattened so that torch.compile won't recompile for tensors with different ndim if len(state) == 0: state["step"] = torch.tensor(0.0, device=p.device) - state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size) - state["exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) + state["exp_avg"] = self._new_buffer(p, True, self.block_size) + state["exp_avg_sq"] = self._new_buffer(p, False, self.block_size) if group["amsgrad"]: - state["max_exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) + state["max_exp_avg_sq"] = self._new_buffer(p, False, self.block_size) state["step"] += 1 @@ -67,10 +76,9 @@ def step(self, closure=None): if not isinstance(group["lr"], Tensor): group["lr"] = torch.tensor(group["lr"], device=p.device) - # flatten p and grad so that torch.compile won't recompile for tensors with different ndim single_param_adam( - p.view(-1), - grad.view(-1), + p, + grad, state["step"], state["exp_avg"], state["exp_avg_sq"], From 9593c0734041f71e8186c20dd84ff4b9095760fd Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 8 Jul 2024 21:46:15 +0800 Subject: [PATCH 06/11] undo changes --- torchao/prototype/low_bit_optim/adam.py | 3 +- .../prototype/low_bit_optim/subclass_8bit.py | 73 +------------------ 2 files changed, 3 insertions(+), 73 deletions(-) diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index aeaaaa597e..7158a4818d 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -50,12 +50,13 @@ def step(self, closure=None): raise RuntimeError("Sparse gradient is not supported") # unwrap DTensor + # set requires_grad for unwrapped param to avoid torch.compile() recompilation if isinstance(p, DTensor): p = p._local_tensor.requires_grad_(True) if isinstance(grad, DTensor): grad = grad._local_tensor - # flatten p and grad so that torch.compile won't recompile for tensors with different ndim + # flatten p and grad to avoid torch.compile() recompilation p = p.view(-1) grad = grad.view(-1) state = self.state[p] diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index ff3d66a626..908ca8c540 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -1,6 +1,5 @@ import torch from torch import Tensor -from torch.utils._python_dispatch import return_and_correct_aliasing from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE @@ -159,14 +158,6 @@ def __repr__(self): f"shape={tuple(self.shape)}, device={self.device}, requires_grad={self.requires_grad})" ) - def _apply_fn_to_data(self, fn, *args, **kwargs): - return self.__class__( - fn(self.codes, *args, **kwargs), - fn(self.scale, *args, **kwargs), - fn(self.qmap, *args, **kwargs), - self.signed, - ) - @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): if func in _ATEN_OP_OR_TORCH_FN_TABLE[cls]: @@ -203,74 +194,12 @@ def _(func, *args, **kwargs): return func(*args, **kwargs) -# the following ops are required for FSDP -@OptimState8bit.implements(aten._to_copy.default) -def _(func, *args, **kwargs): - # ignore dtype and layout - kwargs.pop("dtype") - kwargs.pop("layout") - return args[0]._apply_fn_to_data(func, **kwargs) - - -@OptimState8bit.implements(aten.detach.default) -def _(func, *args, **kwargs): - return return_and_correct_aliasing(func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)) - - -@OptimState8bit.implements(aten.split.Tensor) -def _(func, *args, **kwargs): - tensor: OptimState8bit = args[0] - split_size = args[1] - dim = args[2] if len(args) >= 3 else 0 - - assert dim == 0 # only support splitting dim=0 - assert isinstance(split_size, int) # don't support list - assert tensor.numel() % split_size == 0 - - codes_list = torch.split(tensor.codes, split_size, dim) - scale_list = torch.split(tensor.scale, split_size, dim) - - outputs = [] - for codes, scale in zip(codes_list, scale_list): - outputs.append(OptimState8bit(codes, scale, tensor.qmap.clone(), tensor.signed)) - return tuple(outputs) - - -@OptimState8bit.implements(aten.empty_like.default) -def _(func, *args, **kwargs): - tensor: OptimState8bit = args[0] - return OptimState8bit( - torch.empty_like(tensor.codes), - torch.empty_like(tensor.scale), - tensor.qmap.clone(), - tensor.signed, - ) - - # follow bitsandbytes # only apply quantization for tensor with more than 4096 values # TODO: also skip 1D tensor? e.g. biases and norm scales def maybe_new_8bit_zero_buffer(p: Tensor, signed: bool = True, block_size: int = 2048): if p.numel() >= 4096 and p.numel() % block_size == 0: - from torch.distributed._tensor import DTensor - - if isinstance(p, DTensor): - p_local = p._local_tensor - out_local = OptimState8bit.zeros(p_local.shape, signed, block_size, device=p_local.device) - out = DTensor( - local_tensor=out_local, - device_mesh=p.device_mesh, - placements=p.placements, - shape=p.size(), - dtype=p_local.dtype, - stride=p.stride(), - requires_grad=p.requires_grad, - ) - - else: - out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device) - + out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device) else: out = torch.zeros_like(p) - return out From e3408d878917efb8a5cec8c0d953ebf37eeca0e0 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 8 Jul 2024 21:47:21 +0800 Subject: [PATCH 07/11] add DTensor support for adamw --- torchao/prototype/low_bit_optim/adamw.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index 9397f04c3c..ae257230da 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -3,6 +3,7 @@ import torch from torch import Tensor from torch.optim import Optimizer +from torch.distributed._tensor import DTensor from .subclass_8bit import maybe_new_8bit_zero_buffer from .subclass_4bit import maybe_new_4bit_zero_buffer @@ -48,16 +49,25 @@ def step(self, closure=None): if grad.is_sparse: raise RuntimeError("Sparse gradient is not supported") + # unwrap DTensor + # set requires_grad for unwrapped param to avoid torch.compile() recompilation + if isinstance(p, DTensor): + p = p._local_tensor.requires_grad_(True) + if isinstance(grad, DTensor): + grad = grad._local_tensor + + # flatten p and grad to avoid torch.compile() recompilation + p = p.view(-1) + grad = grad.view(-1) state = self.state[p] # State initialization - # state is flattened so that torch.compile won't recompile for tensors with different ndim if len(state) == 0: state["step"] = torch.tensor(0.0, device=p.device) - state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size) - state["exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) + state["exp_avg"] = self._new_buffer(p, True, self.block_size) + state["exp_avg_sq"] = self._new_buffer(p, False, self.block_size) if group["amsgrad"]: - state["max_exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) + state["max_exp_avg_sq"] = self._new_buffer(p, False, self.block_size) state["step"] += 1 @@ -67,7 +77,6 @@ def step(self, closure=None): if not isinstance(group["lr"], Tensor): group["lr"] = torch.tensor(group["lr"], device=p.device) - # flatten p and grad so that torch.compile won't recompile for tensors with different ndim single_param_adamw( p.view(-1), grad.view(-1), From a71c9bc5a0430ab43a58af6280751e5f27d91c31 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 8 Jul 2024 21:56:54 +0800 Subject: [PATCH 08/11] update imports --- test/prototype/test_low_bit_optim.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/test/prototype/test_low_bit_optim.py b/test/prototype/test_low_bit_optim.py index 13439151a5..4621f8ad76 100644 --- a/test/prototype/test_low_bit_optim.py +++ b/test/prototype/test_low_bit_optim.py @@ -30,6 +30,10 @@ except ImportError: lpmm = None +# for FSDP2 test +if TORCH_VERSION_AFTER_2_4: + from torch.distributed._composable.fsdp import CPUOffloadPolicy, OffloadPolicy, fully_shard + _DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else []) @@ -170,8 +174,6 @@ def world_size(self) -> int: @pytest.mark.skipif(not TORCH_VERSION_AFTER_2_4, reason="torch >= 2.4 required") @skip_if_lt_x_gpu(2) def test_fsdp2(self): - from torch.distributed._composable.fsdp import CPUOffloadPolicy, OffloadPolicy - self.run_subtests( { "enable_activation_checkpointing": [False, True], @@ -185,7 +187,6 @@ def test_fsdp2(self): ) def _test_fsdp2(self, enable_activation_checkpointing, offload_policy): - from torch.distributed._composable.fsdp import fully_shard from torch.testing._internal.distributed._tensor.common_dtensor import ( ModelArgs, Transformer, @@ -207,9 +208,7 @@ def _test_fsdp2(self, enable_activation_checkpointing, offload_policy): with torch.device("cuda"): base_model = Transformer(model_args) if enable_activation_checkpointing: - apply_activation_checkpointing( - base_model, auto_wrap_policy=ModuleWrapPolicy({TransformerBlock}) - ) + apply_activation_checkpointing(base_model, auto_wrap_policy=ModuleWrapPolicy({TransformerBlock})) base_optim = low_bit_optim.Adam8bit(base_model.parameters(), lr=1e-2) fsdp_kwargs = {"offload_policy": offload_policy} @@ -237,9 +236,7 @@ def _test_fsdp2(self, enable_activation_checkpointing, offload_policy): base_loss.backward() for param in base_model.parameters(): if param.grad is not None: - torch.distributed.all_reduce( - param.grad, op=torch.distributed.ReduceOp.AVG - ) + torch.distributed.all_reduce(param.grad, op=torch.distributed.ReduceOp.AVG) base_optim.step() self.assertEqual(fsdp_loss, base_loss) From 1aa460093439e718df3acf93ac428957101962be Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 8 Jul 2024 21:57:31 +0800 Subject: [PATCH 09/11] remove whitespace --- torchao/prototype/low_bit_optim/subclass_8bit.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 908ca8c540..44a3d593cf 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -1,6 +1,5 @@ import torch from torch import Tensor - from torchao.dtypes.utils import _implements, _ATEN_OP_OR_TORCH_FN_TABLE From 17e4cb8462d25407f88b7e3348cc2abd79907976 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Mon, 8 Jul 2024 23:01:46 +0800 Subject: [PATCH 10/11] fix view issue with compiler --- torchao/prototype/low_bit_optim/adam.py | 15 +++++++-------- torchao/prototype/low_bit_optim/adamw.py | 11 +++++------ 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index 7158a4818d..b406ef6403 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -56,18 +56,16 @@ def step(self, closure=None): if isinstance(grad, DTensor): grad = grad._local_tensor - # flatten p and grad to avoid torch.compile() recompilation - p = p.view(-1) - grad = grad.view(-1) state = self.state[p] # State initialization + # flatten buffer to avoid torch.compile() recompilation if len(state) == 0: state["step"] = torch.tensor(0.0, device=p.device) - state["exp_avg"] = self._new_buffer(p, True, self.block_size) - state["exp_avg_sq"] = self._new_buffer(p, False, self.block_size) + state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size) + state["exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) if group["amsgrad"]: - state["max_exp_avg_sq"] = self._new_buffer(p, False, self.block_size) + state["max_exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) state["step"] += 1 @@ -77,9 +75,10 @@ def step(self, closure=None): if not isinstance(group["lr"], Tensor): group["lr"] = torch.tensor(group["lr"], device=p.device) + # flatten p and grad to avoid torch.compile() recompilation single_param_adam( - p, - grad, + p.view(-1), + grad.view(-1), state["step"], state["exp_avg"], state["exp_avg_sq"], diff --git a/torchao/prototype/low_bit_optim/adamw.py b/torchao/prototype/low_bit_optim/adamw.py index ae257230da..dc279a3b5c 100644 --- a/torchao/prototype/low_bit_optim/adamw.py +++ b/torchao/prototype/low_bit_optim/adamw.py @@ -56,18 +56,16 @@ def step(self, closure=None): if isinstance(grad, DTensor): grad = grad._local_tensor - # flatten p and grad to avoid torch.compile() recompilation - p = p.view(-1) - grad = grad.view(-1) state = self.state[p] # State initialization + # flatten buffer to avoid torch.compile() recompilation if len(state) == 0: state["step"] = torch.tensor(0.0, device=p.device) - state["exp_avg"] = self._new_buffer(p, True, self.block_size) - state["exp_avg_sq"] = self._new_buffer(p, False, self.block_size) + state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size) + state["exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) if group["amsgrad"]: - state["max_exp_avg_sq"] = self._new_buffer(p, False, self.block_size) + state["max_exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) state["step"] += 1 @@ -77,6 +75,7 @@ def step(self, closure=None): if not isinstance(group["lr"], Tensor): group["lr"] = torch.tensor(group["lr"], device=p.device) + # flatten p and grad to avoid torch.compile() recompilation single_param_adamw( p.view(-1), grad.view(-1), From 598569a3f81af26138c1673d90c8bd000f2cb439 Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Tue, 9 Jul 2024 01:15:45 +0000 Subject: [PATCH 11/11] add DTensor variant --- torchao/prototype/low_bit_optim/adam.py | 26 ++++++++++++------- .../prototype/low_bit_optim/subclass_8bit.py | 12 ++++++--- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/torchao/prototype/low_bit_optim/adam.py b/torchao/prototype/low_bit_optim/adam.py index b406ef6403..a37a4cab5a 100644 --- a/torchao/prototype/low_bit_optim/adam.py +++ b/torchao/prototype/low_bit_optim/adam.py @@ -51,21 +51,29 @@ def step(self, closure=None): # unwrap DTensor # set requires_grad for unwrapped param to avoid torch.compile() recompilation - if isinstance(p, DTensor): - p = p._local_tensor.requires_grad_(True) - if isinstance(grad, DTensor): - grad = grad._local_tensor + # if isinstance(p, DTensor): + # p = p._local_tensor.requires_grad_(True) + # if isinstance(grad, DTensor): + # grad = grad._local_tensor state = self.state[p] + # uncomment this to flatten tensor to avoid recompiling + # but you will get the following error + # AssertionError: s8 (could be from ["L['grad']._base._local_tensor.size()[0]"]) not in {s3: ["L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]", "L['exp_avg']._local_tensor.scale.size()[0]"], s4: ["L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]", "L['exp_avg_sq']._local_tensor.qmap.size()[0]"], s12: ["L['grad']._local_tensor.size()[0]", "L['grad']._local_tensor.size()[0]"], s10: ["L['grad']._local_tensor.storage_offset()", "L['grad']._local_tensor.storage_offset()"], s16: ["L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]", "L['exp_avg_sq']._local_tensor.scale.size()[0]"], s23: ["L['p']._local_tensor.size()[0]", "L['p']._local_tensor.size()[0]"]}. If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665 + # p = p.view(-1) + # grad = grad.view(-1) + + # without it, you hit cache size limit + # State initialization # flatten buffer to avoid torch.compile() recompilation if len(state) == 0: state["step"] = torch.tensor(0.0, device=p.device) - state["exp_avg"] = self._new_buffer(p.view(-1), True, self.block_size) - state["exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) + state["exp_avg"] = self._new_buffer(p, True, self.block_size) + state["exp_avg_sq"] = self._new_buffer(p, False, self.block_size) if group["amsgrad"]: - state["max_exp_avg_sq"] = self._new_buffer(p.view(-1), False, self.block_size) + state["max_exp_avg_sq"] = self._new_buffer(p, False, self.block_size) state["step"] += 1 @@ -77,8 +85,8 @@ def step(self, closure=None): # flatten p and grad to avoid torch.compile() recompilation single_param_adam( - p.view(-1), - grad.view(-1), + p, + grad, state["step"], state["exp_avg"], state["exp_avg_sq"], diff --git a/torchao/prototype/low_bit_optim/subclass_8bit.py b/torchao/prototype/low_bit_optim/subclass_8bit.py index 44a3d593cf..52b166061a 100644 --- a/torchao/prototype/low_bit_optim/subclass_8bit.py +++ b/torchao/prototype/low_bit_optim/subclass_8bit.py @@ -178,8 +178,8 @@ def _(func, *args, **kwargs): elif isinstance(dst, OptimState8bit): codes, scale = quantize_8bit_with_qmap(src, dst.qmap, dst.block_size) - dst.codes.copy_(codes) - dst.scale.copy_(scale) + dst.codes.copy_(codes.view(dst.codes.shape)) + dst.scale.copy_(scale.view(dst.scale.shape)) else: dst.copy_(src.dequantize()) @@ -198,7 +198,13 @@ def _(func, *args, **kwargs): # TODO: also skip 1D tensor? e.g. biases and norm scales def maybe_new_8bit_zero_buffer(p: Tensor, signed: bool = True, block_size: int = 2048): if p.numel() >= 4096 and p.numel() % block_size == 0: - out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device) + from torch.distributed._tensor import DTensor + + if isinstance(p, DTensor): + out = torch.empty_like(p) + out._local_tensor = OptimState8bit.zeros(out._local_tensor.shape, signed, block_size, device=out._local_tensor.device) + else: + out = OptimState8bit.zeros(p.shape, signed, block_size, device=p.device) else: out = torch.zeros_like(p) return out