[torchao safetensors] integrate torchao safetensors support with transformers #40735
Conversation
cc @MekkCyber for quantization
[For maintainers] Suggested jobs to run (before merge): run-slow: torchao_integration
```diff
@@ -399,7 +401,7 @@ def test_autoquant(self):

     check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)

-    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
+    EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
```
should this be reverted?
I double-checked that this fails on main as well, so I just added the correction.
Looks good, please update the PR summary to align with the recent code changes as well. cc @SunMarc @MekkCyber please check if the API changes make sense
Would be very nice to have this propagated to
Thanks for this nice PR! This is a very nice feature that will bring more adoption for torchao! Excited to have this soon in diffusers also. Left a couple of comments.
```python
if hf_quantizer.quantization_config.quant_method is QuantizationMethod.TORCHAO:
    state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
else:
    state_dict = hf_quantizer.get_state_dict(self)
metadata["format"] = "pt"
```
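The branching above can be illustrated with a small stdlib-only mock (hypothetical names; this is not the transformers implementation): the torchao path returns both a state dict and string-valued metadata describing the tensor subclasses, while other quantizers only return a state dict, and the `"format"` key is always present because safetensors expects it.

```python
# Illustrative stdlib-only mock -- NOT the real transformers code paths.
TORCHAO = "torchao"


class MockQuantizer:
    """Stands in for TorchAoHfQuantizer / other HF quantizers (hypothetical)."""

    def __init__(self, method):
        self.method = method

    def get_state_dict(self, model):
        # Non-torchao path: plain tensors, no extra metadata needed.
        return dict(model)

    def get_state_dict_and_metadata(self, model, safe_serialization=True):
        # torchao path: flatten tensor subclasses and record how to
        # rebuild them in the (string-valued) safetensors metadata.
        return dict(model), {"tensor_subclasses": "Float8Tensor"}


def collect_for_saving(model, quantizer, safe_serialization=True):
    metadata = {}
    if quantizer.method == TORCHAO:
        state_dict, metadata = quantizer.get_state_dict_and_metadata(
            model, safe_serialization
        )
    else:
        state_dict = quantizer.get_state_dict(model)
    # safetensors metadata is a flat str -> str dict; "format" is always set.
    metadata["format"] = "pt"
    return state_dict, metadata
```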
Maybe we remove `get_state_dict` completely and only keep `get_state_dict_and_metadata`. I think this will be clearer.
renaming in a separate PR
Context

Currently, we need to use `safe_serialization=False` while saving models, as shown here. This PR enables safetensors support for torchao so that users can now save and load checkpoints using safetensors. Currently, only `Float8Tensor` is supported (`Float8DynamicActivationFloat8WeightConfig`, `Float8WeightOnlyConfig`), but allowing other subclasses should involve minimal code changes.

Summary
Changes to the transformers code include:

- In `TorchAoHfQuantizer`, we provide `get_state_dict` and `update_state_dict_with_metadata`, which flatten/unflatten a model state dict with tensor subclasses by calling functionality built out in this PR.
- In `modeling_utils.py`, we make appropriate changes to support propagating the metadata from tensor subclasses. We also add logic similar to `hqq` and `bnb` to directly load onto `cpu` rather than `meta`.

Test Plan
Modified the unit test to allow safe serialization. Run using `python tests/quantization/torchao_integration/test_torchao.py`.

Reference https://huggingface.co/torchao-testing/opt-125m-Float8WeightOnlyConfig-v2-0.14.0.dev-safetensors for an example of a serialized model and test script.
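The flatten/unflatten idea described in the Summary can be sketched with a stdlib-only mock (hypothetical structure, not torchao's actual `Float8Tensor` internals): a tensor subclass carries several inner tensors (e.g. quantized data plus scale), safetensors can only store flat tensors, so the subclass is split into flat `<name>.<field>` keys plus metadata recording how to rebuild it on load.

```python
# Hypothetical sketch -- the real flattening lives in torchao, not here.
# A "subclass" entry is modeled as a dict carrying a '_subclass' tag.


def flatten_state_dict(state_dict):
    """Split each subclass entry into flat '<name>.<field>' keys and a
    metadata record describing how to reassemble it."""
    flat, metadata = {}, {}
    for name, value in state_dict.items():
        if isinstance(value, dict) and "_subclass" in value:
            fields = {k: v for k, v in value.items() if k != "_subclass"}
            for field, inner in fields.items():
                flat[f"{name}.{field}"] = inner
            metadata[name] = {"type": value["_subclass"], "fields": sorted(fields)}
        else:
            flat[name] = value
    return flat, metadata


def unflatten_state_dict(flat, metadata):
    """Inverse of flatten_state_dict: regroup flattened fields into
    subclass entries using the stored metadata."""
    out = dict(flat)
    for name, info in metadata.items():
        rebuilt = {"_subclass": info["type"]}
        for field in info["fields"]:
            rebuilt[field] = out.pop(f"{name}.{field}")
        out[name] = rebuilt
    return out
```

The round trip is lossless: flattening then unflattening reproduces the original state dict, which is the property the safetensors path relies on.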