-
Notifications
You must be signed in to change notification settings - Fork 337
safetensors support #2881
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
safetensors support #2881
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2881
Note: Links to docs will display an error until the docs builds have been completed. ⏳ No Failures, 1 PendingAs of commit c4e9165 with merge base 2f78cfe ( This comment was automatically generated by Dr. CI and updates every 15 minutes. |
torchao/prototype/quantization/QuantizeTensorToFloat8KwargsJSON.py
Outdated
Show resolved
Hide resolved
torchao/prototype/quantization/QuantizeTensorToFloat8KwargsJSON.py
Outdated
Show resolved
Hide resolved
from torchao.quantization import Float8Tensor | ||
from torchao.quantization.quantize_.common import KernelPreference | ||
from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like some imports are not used?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
these are used to create the ALLOWED_CLASSES dict
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for restricting to ALLOWED_CLASSES
Context
Currently we serialize and distribute torchao quantized model with pytorch native APIs, specifically
torch.save(model.state_dict())
andmodel.load_state_dict(state_dict, assign=True)
Summary
This PR builds out the functionality
save_tensor_subclass_dict
andload_tensor_subclass_dict
usingsave_file
andload_file
from safetensors library for FP8 based off of the script here.Test Plan
Outputs match after saving/loading model.state_dict()