9 changes: 9 additions & 0 deletions test/prototype/test_low_bit_optim.py
@@ -226,6 +226,15 @@ def _test_fsdp2(self, optim_cls):
            base_optim.step()
            self.assertEqual(fsdp_loss, base_loss)

+        base_param = base_optim.param_groups[0]["params"][0]
+        base_exp_avg = base_optim.state[base_param]["exp_avg"]
+
+        fsdp_param = fsdp_optim.param_groups[0]["params"][0]
+        fsdp_exp_avg = fsdp_optim.state[fsdp_param]["exp_avg"]
+        full_fsdp_exp_avg = fsdp_exp_avg.full_tensor()
+
+        self.assertEqual(base_exp_avg.dequantize(), full_fsdp_exp_avg.dequantize())
+


instantiate_parametrized_tests(TestQuantize)
instantiate_parametrized_tests(TestOptim)
10 changes: 5 additions & 5 deletions torchao/prototype/low_bit_optim/adam.py
@@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
    def _new_buffer(self, p: Tensor, signed: bool):
        if p.numel() >= 4096 and p.numel() % self.block_size == 0:
            if isinstance(p, DTensor):
-                out = torch.empty_like(p)
-                out._local_tensor = self._subclass_zeros(
-                    out._local_tensor,
-                    signed,
-                    self.block_size,
+                out = DTensor.from_local(
+                    local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
+                    device_mesh=p.device_mesh,
+                    placements=p.placements,
+                    run_check=False,
                )
            else:
                out = self._subclass_zeros(p, signed, self.block_size)
10 changes: 5 additions & 5 deletions torchao/prototype/low_bit_optim/adamw.py
@@ -39,11 +39,11 @@ def _subclass_zeros(p: Tensor, signed: bool, block_size: int):
    def _new_buffer(self, p: Tensor, signed: bool):
        if p.numel() >= 4096 and p.numel() % self.block_size == 0:
            if isinstance(p, DTensor):
-                out = torch.empty_like(p)
-                out._local_tensor = self._subclass_zeros(
-                    out._local_tensor,
-                    signed,
-                    self.block_size,
+                out = DTensor.from_local(
+                    local_tensor=self._subclass_zeros(p.to_local(), signed, self.block_size),
+                    device_mesh=p.device_mesh,
+                    placements=p.placements,
+                    run_check=False,
                )
            else:
                out = self._subclass_zeros(p, signed, self.block_size)
42 changes: 34 additions & 8 deletions torchao/prototype/low_bit_optim/subclass_4bit.py
@@ -8,7 +8,8 @@


aten = torch.ops.aten

+c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional

# https://github.com/thu-ml/low-bit-optimizers/blob/e3e2854728e498c2a606e3fdb88daa27ae94f9a6/lpmm/configs/2nd_moment_group_128.yml
# NOTE: power-1 is linear
Expand Down Expand Up @@ -38,10 +39,7 @@ def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool, sha
        self.qmap = qmap
        self.signed = signed
        self._shape = shape
-
-    @property
-    def block_size(self):
-        return self.codes.numel() * 2 // self.scale.numel()
+        self.block_size = codes.numel() * 2 // scale.numel()
Contributor:

Curious question: is there some description of the codes/scale tensors and their relation to each other?

I can see the pattern that codes has 0.5x (4-bit) and 1x (8-bit) the block_size * scale numel, but does this assert square blocks? I think some description here would be helpful.

Collaborator (author):

I will add some description. Basically, for 8-bit and FP8, codes has the same shape as the "outer shape", while for 4-bit, since there is bit-packing, I find it easier to let codes be a flattened 1D buffer and keep track of the shape manually.

To get the scale, the float tensor is flattened first and reshaped to (-1, block_size). This relaxes the requirement that the last dimension must be divisible by block_size -> now only numel (total size) needs to be divisible by block_size. This is especially needed when the block size is large (the 8-bit optim uses block_size=2048, as done in bnb). Since the optimizer update is element-wise, we don't really need to care whether the original tensor is 1D, 2D, or n-D (though an n-D tensor may have structure that makes flattening less wise). I believe the original implementation in bnb does this as well.

-> scale is always a 1D tensor, with size = original_tensor.numel() // block_size
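A minimal sketch of the blockwise layout described above, using plain torch ops and a toy block_size of 4 (illustration only; the variable names are made up, and the real subclasses additionally map the normalized values through a qmap lookup):

import torch

block_size = 4                          # toy value; the 8-bit optim uses block_size=2048
x = torch.randn(2, 6)                   # any shape works as long as numel % block_size == 0

# flatten first, then group into blocks -> only numel must be divisible by block_size
blocks = x.flatten().view(-1, block_size)             # (numel // block_size, block_size)
scale = blocks.abs().amax(dim=1).clamp(min=1e-12)     # 1D, size = numel // block_size

# normalized per-block values; the real optimizer states quantize these via a qmap
normalized = blocks / scale.unsqueeze(1)

# 8-bit / FP8: one code per element, so codes can keep the outer shape
codes_8bit_like = normalized.view(x.shape)            # shape (2, 6); scale has shape (3,)

# 4-bit: two codes are packed per byte, so codes stays a flat buffer of numel // 2
# elements and the original shape is tracked separately (the subclass's _shape).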

Collaborator (author):

@drisspg Added some docs. Lmk if it is still unclear.


    def __tensor_flatten__(self):
        return self.tensor_attrs, [self.signed, self._shape]
@@ -113,9 +111,37 @@ def _(func, *args, **kwargs):
    return func(*args, **kwargs)


+# this is needed for DTensor.from_local() and for flattening tensor
@OptimState4bit.implements(aten.view.default)
def _(func, *args, **kwargs):
    x, shape = args
-    if len(shape) > 1 or shape[0] != -1:
-        raise ValueError(f"{x.__class__.__name__} only supports .view() with shape=[-1]")
-    return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
+
+    if tuple(x.shape) == tuple(shape):
+        return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, x._shape)
+
+    if len(shape) == 1 and shape[0] == -1:
+        return OptimState4bit(x.codes, x.scale, x.qmap, x.signed, (x.numel(),))
+
+    raise ValueError(f"{x.__class__.__name__} only supports .view() with same shape or shape=[-1]")
+
+
+# this is needed for DTensor.full_tensor()
+@OptimState4bit.implements([
+    c10d_functional.all_gather_into_tensor.default,
+    _c10d_functional.all_gather_into_tensor.default,
+    c10d_functional.wait_tensor.default,
+    _c10d_functional.wait_tensor.default,
+])
+def _(func, *args, **kwargs):
+    x = args[0]
+    if not isinstance(x, OptimState4bit):
+        raise ValueError(f"expecting a OptimState4bit but found {type(x)}")
+
+    codes = func(x.codes, *args[1:], **kwargs)
+    scale = func(x.scale, *args[1:], **kwargs)
+
+    # adjust the first dim
+    shape = (x._shape[0] * codes.numel() // x.codes.numel(),) + x._shape[1:]
+
+    # assume tensors from all ranks have the same signedness
+    return OptimState4bit(codes, scale, x.qmap.clone(), x.signed, shape)
35 changes: 31 additions & 4 deletions torchao/prototype/low_bit_optim/subclass_8bit.py
@@ -6,6 +6,8 @@


aten = torch.ops.aten
+c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional

QMAP_SIGNED = create_dynamic_map(signed=True)
QMAP_UNSIGNED = create_dynamic_map(signed=False)
@@ -33,10 +35,7 @@ def __init__(self, codes: Tensor, scale: Tensor, qmap: Tensor, signed: bool):
        self.scale = scale
        self.qmap = qmap
        self.signed = signed
-
-    @property
-    def block_size(self):
-        return self.codes.numel() // self.scale.numel()
+        self.block_size = codes.numel() // scale.numel()

    def __tensor_flatten__(self):
        return self.tensor_attrs, [self.signed]
@@ -97,3 +96,31 @@ def _(func, *args, **kwargs):
def _(func, *args, **kwargs):
    args = [x.dequantize() if isinstance(x, OptimState8bit) else x for x in args]
    return func(*args, **kwargs)
+
+
+# this is needed for DTensor.from_local()
+@OptimState8bit.implements(aten.view.default)
+def _(func, *args, **kwargs):
+    x, shape = args
+    return OptimState8bit(x.codes.view(shape), x.scale, x.qmap, x.signed)
+
+
+# this is needed for DTensor.full_tensor()
+@OptimState8bit.implements([
+    c10d_functional.all_gather_into_tensor.default,
+    _c10d_functional.all_gather_into_tensor.default,
+    c10d_functional.wait_tensor.default,
+    _c10d_functional.wait_tensor.default,
+])
+def _(func, *args, **kwargs):
+    x = args[0]
+    if not isinstance(x, OptimState8bit):
+        raise ValueError(f"expecting a OptimState8bit but found {type(x)}")
+
+    # assume tensors from all ranks have the same signedness
+    return OptimState8bit(
+        func(x.codes, *args[1:], **kwargs),
+        func(x.scale, *args[1:], **kwargs),
+        x.qmap.clone(),
+        x.signed,
+    )
34 changes: 30 additions & 4 deletions torchao/prototype/low_bit_optim/subclass_fp8.py
@@ -4,6 +4,9 @@


aten = torch.ops.aten
+c10d_functional = torch.ops.c10d_functional
+_c10d_functional = torch.ops._c10d_functional

DTYPE = torch.float8_e4m3fn


@@ -35,10 +38,7 @@ def __init__(self, codes: Tensor, scale: Tensor):
        assert codes.dtype is DTYPE
        self.codes = codes
        self.scale = scale
-
-    @property
-    def block_size(self):
-        return self.codes.numel() // self.scale.numel()
+        self.block_size = codes.numel() // scale.numel()

    def __tensor_flatten__(self):
        return self.tensor_attrs, []
@@ -99,3 +99,29 @@ def _(func, *args, **kwargs):
def _(func, *args, **kwargs):
    args = [x.dequantize() if isinstance(x, OptimStateFp8) else x for x in args]
    return func(*args, **kwargs)
+
+
+# this is needed for DTensor.from_local()
+@OptimStateFp8.implements(aten.view.default)
+def _(func, *args, **kwargs):
+    x, shape = args
+    return OptimStateFp8(x.codes.view(shape), x.scale)
+
+
+# this is needed for DTensor.full_tensor()
+@OptimStateFp8.implements([
+    c10d_functional.all_gather_into_tensor.default,
+    _c10d_functional.all_gather_into_tensor.default,
+    c10d_functional.wait_tensor.default,
+    _c10d_functional.wait_tensor.default,
+])
+def _(func, *args, **kwargs):
+    x = args[0]
+    if not isinstance(x, OptimStateFp8):
+        raise ValueError(f"expecting a OptimStateFp8 but found {type(x)}")
+
+    # assume tensors from all ranks have the same signedness
+    return OptimStateFp8(
+        func(x.codes, *args[1:], **kwargs),
+        func(x.scale, *args[1:], **kwargs),
+    )