Description
Note: I'll work on seeing if this reproduces with a non-torchchat example.

While working on migrating torchchat's `WeightOnlyInt8Quantizer` to AO's `quantize_(model, int8_weight_only())` API, I ran into issues where values go to NaN after a few layers if the model's dtype is initially `float16`. This occurs across multiple platforms (tested with MPS, Mac CPU, x86 CPU), so it does not appear to be hardware-specific. Interestingly, setting the model dtype to `bfloat16` does not hit this error.
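For context, the AO call in the migration is roughly the following (a minimal sketch with a toy module; the real repro quantizes torchchat's llama3.1, and the shapes here are illustrative only):

```python
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

# Toy stand-in for the transformer; the actual repro uses torchchat's llama3.1.
model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256)).to(torch.float16)

# Swap the Linear weights for int8 weight-only quantized versions in place.
quantize_(model, int8_weight_only())

x = torch.randn(1, 256, dtype=torch.float16)
out = model(x)
print(out.isnan().any())  # on the real model, NaNs appear with float16 but not bfloat16
```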
To repro, you can check out this PR with the migration in torchchat and run a model using:
python3 torchchat.py generate llama3.1 --quantize '{"linear:int8": {"groupsize": 256}, "executor":{"accelerator":"mps"}}' --prompt "King in the castle, king in the castle, i have a chair." --num-samples 3 --dtype float16
You'll notice the model just outputs "!" tokens, which correspond to NaN. If you add a debug hook to the model (see the sketch below), you can see that some values in the intermediate tensors get very close to 0 just before NaN values are detected.
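For reference, the debug hook can be as simple as a forward hook that flags modules whose outputs contain non-finite values (a generic sketch, not the exact hook used here):

```python
import torch
import torch.nn as nn

def attach_nonfinite_hooks(model: nn.Module):
    """Register forward hooks that report any module producing NaN/inf outputs."""
    def make_hook(name):
        def hook(module, inputs, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                print(f"non-finite output in {name} ({type(module).__name__})")
        return hook

    # Keep the handles so the hooks can later be removed with handle.remove().
    return [m.register_forward_hook(make_hook(n)) for n, m in model.named_modules()]
```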
Activity
psinger commented on Feb 5, 2025
I can confirm this. I also noticed it the other day but did not dig deeper.
If the base weights are in `float16`, `int8_weight_only` completely breaks the outputs. If the base weights are `bfloat16`, the output is as expected in inference-only mode.

Title changed from "[Needs more investigation] `int8_weight_only` via `quantize_()` API results in NaN values across multiple CPU architectures" to "[Needs more investigation] `int8_weight_only` via `quantize_()` API on `torch.float16` models results in NaN values across multiple CPU architectures"

leslie-fang-intel commented on Feb 11, 2025
Thanks for reporting this issue. I will take a look at it.
leslie-fang-intel commented on Feb 11, 2025
It seems like an overflow issue. Hi @vmpuri @psinger, did you hit the same issue on GPU?
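As a minimal illustration of why `float16` is fragile here while `bfloat16` is not (a sketch in plain fp16 arithmetic, not the actual torchao kernels):

```python
import torch

# float16 maxes out around 65504, so a product of moderately large values overflows.
a = torch.tensor(300.0, dtype=torch.float16)
print(a * a)          # inf  (300 * 300 = 90000 > 65504)
print(a * a - a * a)  # nan  (inf - inf), which then propagates through later layers

# bfloat16 keeps float32's exponent range, so the same arithmetic stays finite.
b = torch.tensor(300.0, dtype=torch.bfloat16)
print(b * b)          # finite (~9.0e4)
```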
Drafted a PR to fix it: #1698
After that, the output with the above command is:
[Inductor][CPP] Fix a CPP GEMM Template output data type issue (#146958)