Commit e5b705c

don't let dynamo inline inside of NF4 constructors or __torch_dispatch__ (#544)

* don't let dynamo inline inside of NF4 constructors or __torch_dispatch__
* allow_in_graph the constructor instead of disabling it
1 parent afde175 · commit e5b705c
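
The change follows a single pattern: mark the NF4Tensor subclass constructors and __torch_dispatch__ with @torch._dynamo.disable so dynamo never tries to inline them, and expose a module-level constructor wrapped in @torch._dynamo.allow_in_graph so compiled code can still build instances as one opaque call. A minimal, self-contained sketch of that pattern (the toy class and helper below are illustrative stand-ins, not the torchao code):

import torch


class ToyWrapperTensor(torch.Tensor):
    # Illustrative wrapper subclass; NF4Tensor applies the same decorators.

    @torch._dynamo.disable  # dynamo will not trace into subclass construction
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_wrapper_subclass(
            cls, data.shape, dtype=data.dtype, device=data.device
        )

    @torch._dynamo.disable
    def __init__(self, data: torch.Tensor):
        self._data = data

    @classmethod
    @torch._dynamo.disable  # keep dynamo out of dispatch as well
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        # toy behavior: unwrap to the plain payload and run the op on it
        unwrapped = [a._data if isinstance(a, ToyWrapperTensor) else a for a in args]
        return func(*unwrapped, **kwargs)


@torch._dynamo.allow_in_graph  # compiled graphs may call this without inlining it
def toy_constructor(data: torch.Tensor) -> ToyWrapperTensor:
    return ToyWrapperTensor(data)


t = toy_constructor(torch.randn(4, 4))
print(type(t).__name__, t._data.shape)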

File tree

1 file changed: +28, -1 lines changed

torchao/dtypes/nf4tensor.py

Lines changed: 28 additions & 1 deletion
@@ -410,6 +410,7 @@ def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor
 class NF4Tensor(torch.Tensor):
     """NF4Tensor class for converting a weight to the QLoRA NF4 format"""
 
+    @torch._dynamo.disable
     def __new__(
         cls,
         # Args related for base tensor construction
@@ -450,6 +451,7 @@ def __new__(
         )
         return nf4tensor
 
+    @torch._dynamo.disable
     def __init__(
         self,
         tensor_meta: SubclassTensorArgs,
@@ -758,6 +760,7 @@ def __str__(self):
         return self.to(torch.float32).__str__()
 
     @classmethod
+    @torch._dynamo.disable
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
         """TODO we are not supporting torch dispatch at the moment
         instead we have created a Autograd.Function to handle the linear
@@ -849,7 +852,7 @@ def fsdp_post_all_gather(
             ), f"Expects out's data to be the all-gather output"
             return
 
-        return NF4Tensor(
+        return nf4_constructor(
             tensor_meta,
             block_size,
             n_blocks,
@@ -934,3 +937,27 @@ def function_cpu(*args, **kwargs):
     updated_attrs = call_from_inner_tensors(nf4tensor, "cpu", args[1:], kwargs)
     updated_attrs["device"] = "cpu"
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+
+@torch._dynamo.allow_in_graph
+def nf4_constructor(
+    tensor_meta: SubclassTensorArgs,
+    block_size: int,
+    n_blocks: int,
+    scaler_block_size: int,
+    quantized_scalers: torch.Tensor,
+    quantization_factor: torch.Tensor,
+    scaler_mean: torch.Tensor,
+    quantized_data: torch.Tensor,
+    nf4: torch.Tensor,
+):
+    return NF4Tensor(
+        tensor_meta,
+        block_size,
+        n_blocks,
+        scaler_block_size,
+        quantized_scalers,
+        quantization_factor,
+        scaler_mean,
+        quantized_data,
+        nf4,
+    )
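
The commit message's second bullet ("allow_in_graph the constructor instead of disabling it") comes down to how torch.compile treats the two decorators: a @torch._dynamo.disable function forces a graph break at its call site, while a @torch._dynamo.allow_in_graph function is recorded in the compiled graph as a single opaque call, which is why fsdp_post_all_gather now goes through nf4_constructor. A standalone sketch of that difference (none of the functions below are torchao code):

import torch


@torch._dynamo.disable
def disabled_fn(x):
    return x * 2  # dynamo skips this body and graph-breaks around the call


@torch._dynamo.allow_in_graph
def allowed_fn(x):
    return x * 2  # dynamo records the call as a single node, without inlining


@torch.compile(fullgraph=True)
def uses_allowed(x):
    return allowed_fn(x) + 1  # compiles to one graph


@torch.compile  # fullgraph=True here would error because of the graph break
def uses_disabled(x):
    return disabled_fn(x) + 1


x = torch.randn(4)
print(uses_allowed(x))
print(uses_disabled(x))  # still runs, but split around the disabled call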
