Distributing ao tensor subclasses in .safetensors checkpoints #2338

@mikaylagawarecki

Description

Context

The status quo for distributing ao weights on huggingface is checkpoints produced by torch.save (examples here). The reasoning is that, by default, safetensors only supports saving state dictionaries of plain tensors, whereas ao weights tend to be tensor subclasses.

@drisspg and I had a conversation last week about what would be necessary for ao weights to be distributed as .safetensors files rather than .pt files on huggingface. My understanding from him is that this might be necessary for checkpoints to be marked as official on huggingface.

Since ao tensor subclasses are wrapper tensor subclasses that hold plain tensors plus metadata, it should theoretically be possible to (1) decompose each subclass into its constituent tensors and metadata, (2) save those tensors and metadata to a safetensors file with safetensors.torch.save_file (with special handling for non-JSON-serializable metadata), and (3) provide a helper on top of safetensors.torch.load_file that reconstructs the subclass.
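
One wrinkle worth noting: the metadata argument of safetensors.torch.save_file is a flat str-to-str mapping, so structured attributes have to be encoded into strings (e.g. via json.dumps) on save and decoded on load. A minimal sketch of that constraint:

import json

import torch
from safetensors.torch import save_file

weight = torch.randn(32, 64)
save_file(
    {"original_weight": weight},  # values must be plain tensors
    "weights.safetensors",
    # metadata values must be strings, so structured attributes are JSON-encoded
    metadata={"quant_kwargs": json.dumps({"activation_scale_ub": 0.3})},
)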

I wrote a simple prototype of what this would look like and am looking:

  1. To stimulate discussion on what saving ao subclasses in the safetensors format might look like, and whether it is needed
  2. For feedback on whether a solution like the one below covers the important cases and is viable

Simple example

We took the more straightforward case of a LinearActivationQuantizedTensor with dynamic quantization to FbgemmFp8Tensor, which I understand from Driss might be the main case to target.

My understanding here is that a LinearActivationQuantizedTensor would take:

  • original_weight_tensor: a plain tensor
  • input_quant_func: in this case to_fbgemm_fp8, which maps to FbgemmFp8Tensor.from_float
  • quant_kwargs: in this case, activation_scale_ub
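
For concreteness, constructing such a tensor looks like this (mirroring the __main__ block of the script below):

import torch
from torchao.dtypes.fbgemm_fp8_tensor import to_fbgemm_fp8
from torchao.quantization.linear_activation_quantized_tensor import to_linear_activation_quantized

fp8_weight = to_linear_activation_quantized(
    torch.randn(32, 64),                        # original_weight_tensor
    input_quant_func=to_fbgemm_fp8,             # serialized by import path on save
    quant_kwargs={"activation_scale_ub": 0.3},  # JSON-serializable metadata
)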

Given this, I generated a rudimentary script that would:

  1. Extract tensor and non-tensor attributes from LinearActivationQuantizedTensor
  2. Pass all information necessary to reconstruct the subclass via the metadata argument of safetensors.torch.save_file
    i. In particular, the function torchao.dtypes.fbgemm_fp8_tensor.to_fbgemm_fp8 is serialized as the string "{__module__}.{__qualname__}", and the actual function object is looked up in an allowlist dict during loading.
  3. Load tensors via safetensors.torch.load_file and manually read the metadata from the .safetensors header in order to reconstruct the subclasses (an alternative using safetensors.safe_open is sketched after this list)
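
As an aside, the manual header read in step 3 could likely be replaced with safetensors.safe_open, which exposes the same metadata directly (a sketch, not what the prototype below does):

from safetensors import safe_open

with safe_open("fp8_weights_multi.safetensors", framework="pt") as f:
    metadata = f.metadata()  # the "__metadata__" dict from the header, or None
    tensors = {key: f.get_tensor(key) for key in f.keys()}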

Click me for script


import json
import struct
from typing import Callable, Dict, Optional, Tuple

import torch
from safetensors.torch import save_file, load_file

import torchao
from torchao.dtypes.fbgemm_fp8_tensor import to_fbgemm_fp8
from torchao.quantization.linear_activation_quantized_tensor import to_linear_activation_quantized

# Allowlist of quantization functions that may be referenced by path in
# checkpoint metadata; anything not listed here is rejected at load time.
ALLOWED_QUANT_FUNCTIONS = {
    "torchao.dtypes.fbgemm_fp8_tensor.to_fbgemm_fp8": torchao.dtypes.fbgemm_fp8_tensor.to_fbgemm_fp8,
    "torchao.dtypes.fbgemm_fp8_tensor.FbgemmFp8Tensor.from_float": torchao.dtypes.fbgemm_fp8_tensor.FbgemmFp8Tensor.from_float,
    # add to me
}

def get_function_path(func: Callable) -> str:
    """Get the import path for a function."""
    fullpath = f"{func.__module__}.{func.__qualname__}"
    
    assert fullpath in ALLOWED_QUANT_FUNCTIONS, f"{fullpath} is not in ALLOWED_QUANT_FUNCTIONS"
    return fullpath

def create_metadata_for_tensor_subclass(tensor: torch.Tensor) -> Tuple[Dict[str, str], Dict[str, torch.Tensor]]:
    """
    Create metadata for tensor subclasses from torchao.
    
    Args:
        tensor: A tensor subclass (e.g., LinearActivationQuantizedTensor)
        
    Returns:
        Tuple of (metadata, tensors_dict) where:
        - metadata: Dictionary with metadata needed to reconstruct the tensor
        - tensors_dict: Dictionary with tensors to save
    """
    metadata = {}
    tensors_dict = {}
    
    if tensor.__class__.__name__ == "LinearActivationQuantizedTensor":
        metadata["tensor_type"] = "LinearActivationQuantizedTensor"
        
        quant_func_path = get_function_path(tensor.input_quant_func)
        metadata["input_quant_func"] = quant_func_path
        
        if hasattr(tensor, "quant_kwargs"):
            metadata["quant_kwargs"] = json.dumps(tensor.quant_kwargs)
        
        tensors_dict["original_weight"] = tensor.original_weight_tensor
    else:
        raise ValueError(f"Unsupported tensor type: {tensor.__class__.__name__}")
    
    return metadata, tensors_dict

def save_tensor_subclass(tensor: torch.Tensor, file_path: str, additional_metadata: Optional[Dict[str, str]] = None):
    """
    Save a tensor subclass with appropriate metadata.
    
    Args:
        tensor: The tensor subclass to save
        file_path: Path where to save the tensor
        additional_metadata: Optional additional metadata to include
    """
    metadata, tensors_dict = create_metadata_for_tensor_subclass(tensor)
    
    if additional_metadata:
        metadata.update(additional_metadata)
    
    save_file(tensors_dict, file_path, metadata=metadata)
    print(f"Saved tensor subclass to {file_path} with metadata")

def save_tensor_subclass_dict(tensor_dict: Dict[str, torch.Tensor], file_path: str, 
                             additional_metadata: Optional[Dict[str, str]] = None):
    """
    Save a dictionary of tensor subclasses with appropriate metadata.
    
    Args:
        tensor_dict: Dictionary of tensor subclasses to save, with keys as tensor names
        file_path: Path where to save the tensors
        additional_metadata: Optional additional metadata to include
    """
    combined_metadata = {}
    combined_tensors_dict = {}
    
    for tensor_name, tensor in tensor_dict.items():
        # TODO: handle case where tensor is a plain tensor
        metadata, tensors_dict = create_metadata_for_tensor_subclass(tensor)
        
        prefixed_tensors_dict = {f"{tensor_name}:{key}": value for key, value in tensors_dict.items()}
        
        for key, value in metadata.items():
            combined_metadata[f"{tensor_name}:{key}"] = value
        
        combined_tensors_dict.update(prefixed_tensors_dict)
    
    combined_metadata["tensor_names"] = json.dumps(list(tensor_dict.keys()))
    
    if additional_metadata:
        combined_metadata.update(additional_metadata)
    
    save_file(combined_tensors_dict, file_path, metadata=combined_metadata)
    print(f"Saved {len(tensor_dict)} tensor subclasses to {file_path} with metadata")

def load_tensor_subclass(file_path: str) -> torch.Tensor:
    """
    Load a tensor subclass from a safetensors file.
    
    Args:
        file_path: Path to the safetensors file
        
    Returns:
        The reconstructed tensor subclass
    """
    loaded_tensors = load_file(file_path)
    
    # safetensors layout: an 8-byte little-endian header size, then a JSON
    # header whose "__metadata__" field holds the user-supplied metadata.
    with open(file_path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_size))
        metadata = header.get("__metadata__", {})

    # This entry point handles single-tensor files; multi-tensor files
    # (which carry a "tensor_names" key) go through load_tensor_subclass_dict.
    assert "tensor_names" not in metadata
    
    tensor_type = metadata.get("tensor_type")
    
    if tensor_type == "LinearActivationQuantizedTensor":
        original_weight = loaded_tensors["original_weight"]
        
        quant_func_path = metadata.get("input_quant_func")
        if quant_func_path not in ALLOWED_QUANT_FUNCTIONS:
            raise ValueError(f"Security error: Quantization function '{quant_func_path}' is not in the allowed list")
        
        quant_func = ALLOWED_QUANT_FUNCTIONS.get(quant_func_path)
        
        quant_kwargs = json.loads(metadata.get("quant_kwargs", "{}"))
        
        return to_linear_activation_quantized(
            original_weight,
            input_quant_func=quant_func,
            quant_kwargs=quant_kwargs,
        )
    else:
        return loaded_tensors

def load_tensor_subclass_dict(file_path: str) -> Dict[str, torch.Tensor]:
    """
    Load a dictionary of tensor subclasses from a safetensors file.
    
    Args:
        file_path: Path to the safetensors file
        
    Returns:
        Dictionary of reconstructed tensor subclasses
    """
    loaded_tensors = load_file(file_path)
    
    # Same header parsing as in load_tensor_subclass.
    with open(file_path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_size))
        metadata = header.get("__metadata__", {})
    
    if "tensor_names" not in metadata:
        tensor = load_tensor_subclass(file_path)
        return {"tensor": tensor}
    
    tensor_names = json.loads(metadata["tensor_names"])
    result = {}
    
    for tensor_name in tensor_names:
        tensor_metadata = {}
        for key, value in metadata.items():
            if key.startswith(f"{tensor_name}:"):
                # Remove the prefix
                tensor_metadata[key[len(tensor_name)+1:]] = value
        
        tensor_tensors = {}
        for key, value in loaded_tensors.items():
            if key.startswith(f"{tensor_name}:"):
                # Remove the prefix
                tensor_tensors[key[len(tensor_name)+1:]] = value
        
        tensor_type = tensor_metadata.get("tensor_type")
        
        if tensor_type == "LinearActivationQuantizedTensor":
            original_weight = tensor_tensors["original_weight"]
            
            quant_func_path = tensor_metadata.get("input_quant_func")
            if quant_func_path not in ALLOWED_QUANT_FUNCTIONS:
                raise ValueError(f"Security error: Quantization function '{quant_func_path}' is not in the allowed list")
            
            quant_func = ALLOWED_QUANT_FUNCTIONS.get(quant_func_path)
            
            quant_kwargs = json.loads(tensor_metadata.get("quant_kwargs", "{}"))
            
            result[tensor_name] = to_linear_activation_quantized(
                original_weight,
                input_quant_func=quant_func,
                quant_kwargs=quant_kwargs,
            )
        
        else:
            result[tensor_name] = tensor_tensors
    
    return result

if __name__ == "__main__":
    
    weight1 = torch.randn(32, 64, dtype=torch.float32)
    weight2 = torch.randn(64, 128, dtype=torch.float32)
    weight3 = torch.randn(128, 256, dtype=torch.float32)
    
    fp8_weight1 = to_linear_activation_quantized(
        weight1,
        input_quant_func=to_fbgemm_fp8,
        quant_kwargs={"activation_scale_ub": 0.3}
    )
    
    fp8_weight2 = to_linear_activation_quantized(
        weight2,
        input_quant_func=to_fbgemm_fp8,
        quant_kwargs={"activation_scale_ub": 0.5}
    )
    
    fp8_weight3 = to_linear_activation_quantized(
        weight3,
        input_quant_func=to_fbgemm_fp8,
        quant_kwargs={"activation_scale_ub": 0.7}
    )
    
    tensor_dict = {
        "layer1.weight": fp8_weight1,
        "layer2.weight": fp8_weight2,
        "layer3.weight": fp8_weight3
    }
    print("Saving tensor subclasses...")
    print(tensor_dict)

    save_tensor_subclass_dict(tensor_dict, "fp8_weights_multi.safetensors")
    
    reconstructed_dict = load_tensor_subclass_dict("fp8_weights_multi.safetensors")
    
    print(f"Loaded {len(reconstructed_dict)} tensors:")
    for name, tensor in reconstructed_dict.items():
        print(name, tensor)
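
A quick sanity check could be appended to the script's __main__ to verify that the round trip preserves both the wrapped plain weights and the subclass types (a sketch reusing tensor_dict and reconstructed_dict from above):

for name in tensor_dict:
    # The reconstructed subclass should wrap the same plain weight tensor...
    torch.testing.assert_close(
        tensor_dict[name].original_weight_tensor,
        reconstructed_dict[name].original_weight_tensor,
    )
    # ...and have the same subclass type as what was saved.
    assert type(reconstructed_dict[name]) is type(tensor_dict[name])
print("Round trip preserved wrapped weights and subclass types")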

Questions

  1. Does a solution like the above sufficiently handle the main cases ao cares about?
    a. Does this approach scale to subclasses other than LinearActivationQuantizedTensor in ao, or are there more non-JSON-serializable attributes that we might need to handle or might not be able to handle?
  2. Do we have control over all the points where these checkpoints would be loaded? (There would need to be additional helper code on top of safetensors.load_file to reconstruct the subclasses.)

cc @jerryzh168 @drisspg
