
Commit aba1a10

add fp8 support
1 parent e5548b7 commit aba1a10

File tree

1 file changed: 14 additions, 1 deletion


scripts/hf_eval.py

Lines changed: 14 additions & 1 deletion
@@ -56,6 +56,19 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
         change_linear_weights_to_int4_woqtensors(model.to(device=device))
     elif quantization == "autoquant":
         model = autoquant(model.to(device=device))
+    elif quantization == "fp8":
+        from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
+        from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
+        model.to(device)
+        swap_linear_with_float8_linear(
+            model,
+            Float8DynamicLinear,
+            from_float_kwargs={
+                "pre_quantize_weight": True,
+            },
+        )
+        pass # no quantization applied, model is already on device and precision dtype.
+
     with torch.no_grad():
         result = evaluate(
             HFLM(
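The new branch uses float8_experimental (since upstreamed into torchao as torchao.float8) to swap every nn.Linear in the model for a Float8DynamicLinear, which casts activations to float8 dynamically at runtime; with "pre_quantize_weight": True the weights are presumably converted once up front rather than on every forward. A minimal sketch of the module-swapping idea, for orientation only (the real swap_linear_with_float8_linear also handles module filtering and builds replacements via the target class's from_float constructor):

```python
import torch.nn as nn

def swap_linears(module: nn.Module, make_replacement) -> nn.Module:
    # Recursively walk the module tree and replace each nn.Linear
    # child with whatever make_replacement(child) returns.
    # Illustrative sketch only, not float8_experimental's implementation.
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            setattr(module, name, make_replacement(child))
        else:
            swap_linears(child, make_replacement)
    return module
```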
@@ -78,7 +91,7 @@ def run_evaluation(repo_id, tasks, limit, device, precision, quantization, compi
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "None"], help='Which quantization technique to apply')
+    parser.add_argument('-q', '--quantization', default = "None", choices=["int8dq", "int8wo", "int4wo","autoquant", "fp8", "None"], help='Which quantization technique to apply')
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--batch_size', type=int, default=1, help='Batch size to use for evaluation, note int8wo and int4wo work best with small batchsizes, int8dq works better with large batchsizes')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
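With "fp8" wired into the argparse choices, an evaluation run selecting it might look like the following. The --repo_id and --tasks flags are assumed from run_evaluation's signature and do not appear in this diff, and the model name is a placeholder; only -q, --precision, --device, --limit, --compile, --batch_size, and --max_length are shown above:

```sh
# Hypothetical invocation; --repo_id/--tasks values are placeholders.
python scripts/hf_eval.py \
    --repo_id meta-llama/Llama-2-7b-hf \
    --tasks wikitext \
    --precision bfloat16 \
    --device cuda \
    -q fp8 \
    --limit 100
```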
