Support loading for static quant weight fp8 act fp8 #730
Conversation
Signed-off-by: yiliu30 <[email protected]>
Pull Request Overview
This PR adds support for loading static quantized models with FP8 weights and FP8 activations by implementing a new quantized linear layer class and updating the model conversion infrastructure.
Key changes:
- Implemented a WeightFP8ActFP8StaticQuantLinear class for handling FP8 weight and activation quantization (a sketch follows this list)
- Updated model conversion logic to detect and handle FP8 static quantization configurations
- Enhanced test coverage to verify both export and loading functionality for static FP8 quantization
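For reference, a minimal sketch of what such a static W8A8-FP8 linear layer can look like is shown below. It is illustrative only, written against PyTorch's float8_e4m3fn dtype; the buffer names (weight_scale, input_scale) and the _qdq_input helper are assumptions, not necessarily the names used in this PR.

```python
import torch


class WeightFP8ActFP8StaticQuantLinear(torch.nn.Module):
    """Illustrative FP8-weight / static-FP8-activation linear (reference QDQ path, no fused FP8 kernels)."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # FP8 weight, per-channel weight scale, and a pre-computed (static) per-tensor
        # activation scale, all loaded from the exported checkpoint.
        self.register_buffer("weight", torch.empty(out_features, in_features, dtype=torch.float8_e4m3fn))
        self.register_buffer("weight_scale", torch.ones(out_features, 1))
        self.register_buffer("input_scale", torch.ones(1))
        self.register_buffer("bias", torch.zeros(out_features) if bias else None)

    def _qdq_input(self, x: torch.Tensor) -> torch.Tensor:
        # Quantize the activation to FP8 with the static scale, then dequantize.
        finfo = torch.finfo(torch.float8_e4m3fn)
        scale = self.input_scale.to(x.dtype)
        q = torch.clamp(x / scale, finfo.min, finfo.max).to(torch.float8_e4m3fn)
        return q.to(x.dtype) * scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._qdq_input(x)
        # Dequantize the FP8 weight to the activation dtype and fall back to a plain matmul.
        w = (self.weight.to(torch.float32) * self.weight_scale).to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return torch.nn.functional.linear(x, w, bias)
```

Because it dequantizes before the matmul, a layer like this is mainly useful for loading and accuracy evaluation rather than for speed, which matches the stated purpose of the PR.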
Reviewed Changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 3 comments.
File | Description
---|---
test/test_cpu/test_export.py | Extended the test to verify loading of static FP8 quantized models and renamed the test method
auto_round/inference/convert_model.py | Added support for the act_dynamic parameter and FP8 static quantization detection in model conversion
auto_round/inference/backend.py | Added an FP8 static quantization detection function and updated the dynamic import logic
auto_round/export/export_to_autoround/export_to_fp8_woq.py | Implemented the new WeightFP8ActFP8StaticQuantLinear class with quantization/dequantization methods
This PR is unnecessary for now; you need to work with Heng to fix the FP8
@wenhuach21 The purpose of this PR is to support loading an existing qmodel from disk and then evaluating its accuracy. cc @n1ck-guo
Yes, but the primary purpose is evaluation, which the fake model should cover well (#731). This is not a product feature, and it involves changes to critical product code. As discussed earlier, please hold this PR for now, or move the code elsewhere without modifying the important HF model inference code.
Signed-off-by: yiliu30 <[email protected]>
auto_round/inference/backend.py (outdated)
bits, group_size, sym = config["bits"], config["group_size"], config["sym"]

if is_weight_fp8_activation_static_fp8(config):
    from auto_round.export.export_to_autoround.export_to_fp8_woq import WeightFP8ActFP8StaticQuantLinear
Please refactor the code to comply with our backend coding style: backend.py#L191. Use the Torch backend to support w8fp8. You may add extra keys if the existing ones do not meet your requirements, but ensure all calls go through the Torch backend.
You could also create a new backend and add an alias "torch".
Thanks, added a new backend auto_round:torch_fp8_static.
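For context, the kind of registration such a backend needs might look roughly like the sketch below. It is not the PR's code: the minimal BackendInfo dataclass and BackendInfos dict here are stand-ins that merely mirror the field names visible in the review fragments (sym, dtype, bits, group_size, priority, feature_checks, alias, requirements).

```python
from dataclasses import dataclass, field


@dataclass
class BackendInfo:
    # Minimal stand-in for the real dataclass in auto_round/inference/backend.py.
    sym: list
    dtype: list
    bits: list
    group_size: list
    priority: int = 0
    feature_checks: list = field(default_factory=list)
    alias: list = field(default_factory=list)
    requirements: list = field(default_factory=list)


BackendInfos = {}

# Register the static W8A8-FP8 path under its own key, keeping "torch" as an alias.
BackendInfos["auto_round:torch_fp8_static"] = BackendInfo(
    sym=[True],
    dtype=["float32", "float16", "bfloat16"],
    bits=[8],
    group_size=[-1],  # per-channel scales only, as discussed later in this review
    priority=0,
    feature_checks=[],
    alias=["auto_round", "torch"],
    requirements=["auto-round>0.6.0"],
)
```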
Signed-off-by: yiliu30 <[email protected]>
@@ -61,7 +62,7 @@ def skip_not_convert_modules(model, quantization_config, layer_names, layer_conf
    try:  # transformers new api
        modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert, add_default_skips=True)
    except:
-       modules_to_not_convert = get_modules_to_not_convert(model, modules_to_not_convert)
+       modules_to_not_convert = _get_modules_to_not_convert(model, modules_to_not_convert)
Why change this?
- "clip" (bool): Whether weight clipping is enabled. | ||
""" | ||
bits = quantization_config.bits | ||
group_size = quantization_config.group_size | ||
data_type = getattr(quantization_config, "data_type", "int") # Default to "int" if not specified | ||
sym = quantization_config.sym | ||
|
||
act_dynamic = getattr(quantization_config, "act_dynamic", False) |
If you add this one, I think you need to introduce act_bits, act_group_size, etc. too.
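As a rough illustration of that suggestion, the full activation-side settings could be collected alongside act_dynamic as below; the helper name, the extra attribute names, and their fallbacks are assumptions for illustration, not taken from this PR.

```python
def get_act_quant_args(quantization_config):
    """Collect activation-quantization settings next to act_dynamic (hypothetical helper)."""
    return dict(
        act_dynamic=getattr(quantization_config, "act_dynamic", False),
        act_bits=getattr(quantization_config, "act_bits", 16),
        act_group_size=getattr(quantization_config, "act_group_size", getattr(quantization_config, "group_size", -1)),
        act_sym=getattr(quantization_config, "act_sym", getattr(quantization_config, "sym", True)),
        act_data_type=getattr(quantization_config, "act_data_type", getattr(quantization_config, "data_type", "int")),
    )
```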
priority=0,
feature_checks=[],
alias=["auto_round", "torch"],
requirements=["auto-round>=0.6.1.dev0"],
"> 0.6.0"
sym=[True],
dtype=["float32", "float16", "bfloat16"],
bits=[8],
priority=0,
Add checkers to differentiate between w8a8-FP8 dynamic, w8a8-int, w8a4, etc.
def _is_weight_fp8_activation_static_fp8(bit, group_size, sym, data_type, act_dynamic):
    return bit == 8 and group_size == -1 and sym and data_type == "fp8" and not act_dynamic
Add type hints and a return type.
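A possible annotated version, keeping the logic from the fragment above unchanged:

```python
def _is_weight_fp8_activation_static_fp8(
    bit: int, group_size: int, sym: bool, data_type: str, act_dynamic: bool
) -> bool:
    """True for static W8A8-FP8: 8-bit, per-channel (group_size == -1), symmetric FP8 weights with static FP8 activations."""
    return bit == 8 and group_size == -1 and sym and data_type == "fp8" and not act_dynamic
```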
dtype=["float32", "float16", "bfloat16"],
bits=[8],
priority=0,
feature_checks=[],
A group_size checker is also needed; as mentioned in the comment, only per-channel is supported for now.
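One way to express that restriction is a dedicated check added to feature_checks; the checker name and the signature it is called with are assumptions about backend.py's conventions, not code from this PR.

```python
def fp8_static_group_size_check(in_features: int, out_features: int, group_size: int = -1) -> bool:
    # Hypothetical feature check: only per-channel weight scales (group_size == -1) are supported for now.
    return group_size == -1
```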
Signed-off-by: yiliu30 <[email protected]>
Added auto_round:torch_fp8_static for loading and running inference with static W8A8-FP8 models.
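A possible end-to-end loading sketch is below. The model path is a placeholder, and selecting the backend through AutoRoundConfig is an assumption about the intended usage rather than something verified against this PR.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRoundConfig  # enables the AutoRound quantization loader

model_path = "path/to/w8afp8-static-quantized-model"  # placeholder

model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.bfloat16,
    # Assumption: the new backend is selected via the backend key of AutoRoundConfig.
    quantization_config=AutoRoundConfig(backend="auto_round:torch_fp8_static"),
)
tokenizer = AutoTokenizer.from_pretrained(model_path)

inputs = tokenizer("The capital of France is", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=16)[0]))
```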