@@ -410,6 +410,7 @@ def get_block_absmax(inpt_tensor: torch.Tensor, block_size: int) -> torch.Tensor
 class NF4Tensor(torch.Tensor):
     """NF4Tensor class for converting a weight to the QLoRA NF4 format"""
 
+    @torch._dynamo.disable
     def __new__(
         cls,
         # Args related for base tensor construction
@@ -450,6 +451,7 @@ def __new__(
         )
         return nf4tensor
 
+    @torch._dynamo.disable
     def __init__(
         self,
         tensor_meta: SubclassTensorArgs,
@@ -758,6 +760,7 @@ def __str__(self):
         return self.to(torch.float32).__str__()
 
     @classmethod
+    @torch._dynamo.disable
    def __torch_dispatch__(cls, func, types, args, kwargs=None):
        """TODO we are not supporting torch dispatch at the moment
        instead we have created a Autograd.Function to handle the linear
@@ -849,7 +852,7 @@ def fsdp_post_all_gather(
             ), f"Expects out's data to be the all-gather output"
             return
 
-        return NF4Tensor(
+        return nf4_constructor(
            tensor_meta,
            block_size,
            n_blocks,
@@ -934,3 +937,27 @@ def function_cpu(*args, **kwargs):
     updated_attrs = call_from_inner_tensors(nf4tensor, "cpu", args[1:], kwargs)
     updated_attrs["device"] = "cpu"
     return NF4Tensor(*construct_nf4_args(nf4tensor, updated_attrs))
+
+@torch._dynamo.allow_in_graph
+def nf4_constructor(
+    tensor_meta: SubclassTensorArgs,
+    block_size: int,
+    n_blocks: int,
+    scaler_block_size: int,
+    quantized_scalers: torch.Tensor,
+    quantization_factor: torch.Tensor,
+    scaler_mean: torch.Tensor,
+    quantized_data: torch.Tensor,
+    nf4: torch.Tensor,
+):
+    return NF4Tensor(
+        tensor_meta,
+        block_size,
+        n_blocks,
+        scaler_block_size,
+        quantized_scalers,
+        quantization_factor,
+        scaler_mean,
+        quantized_data,
+        nf4,
+    )
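Note on the pattern this diff applies: `@torch._dynamo.disable` keeps dynamo from tracing into the subclass's `__new__`, `__init__`, and `__torch_dispatch__`, while `@torch._dynamo.allow_in_graph` lets compiled code call the constructor as a single opaque graph node instead of graph-breaking on it. Below is a minimal sketch of the same pattern on a toy subclass; `ToySubclass` and `toy_constructor` are illustrative names, not part of torchao, and exact `torch.compile` behavior may vary across PyTorch versions.

```python
import torch

class ToySubclass(torch.Tensor):
    @torch._dynamo.disable  # dynamo skips tracing subclass construction
    def __new__(cls, data: torch.Tensor):
        return torch.Tensor._make_subclass(cls, data)

@torch._dynamo.allow_in_graph  # compiled code calls this as one opaque op
def toy_constructor(data: torch.Tensor) -> ToySubclass:
    # Without this wrapper, calling ToySubclass(...) inside a compiled
    # region would hit the disabled __new__ and force a graph break.
    return ToySubclass(data)

@torch.compile
def f(x: torch.Tensor):
    return toy_constructor(x * 2)

print(type(f(torch.randn(4))))  # <class '__main__.ToySubclass'>
```

This is why the `fsdp_post_all_gather` hunk swaps the direct `NF4Tensor(...)` call for `nf4_constructor(...)`: the wrapper gives dynamo a constructor it is allowed to keep in the graph.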