Misc. bug: The model's reasoning performance has significantly decreased across different llama.cpp versions, despite using the same model, identical parameters, and the same set of questions. #12816
Comments
I am not 100% certain whether the flash attention update in b4759 has any bugs. After multiple tests, I've become confused myself. I hope someone can conduct some comparative testing on this FA update as well. I look forward to seeing the results of your tests.
Is this with the CUDA backend? What hardware?
Yes, the hardware is an RTX 4090 GPU with CUDA 12.8 on Debian 12. After disabling the -fa option, all versions exhibit consistent performance. I tested several inference scenarios and noticed some "abnormal" behavior in b4759 and later versions. Reviewing the code, the only difference between b4756 and b4759 is the CUDA implementation of the FA feature. The QwQ reasoning model clearly shows distinguishable reasoning capability between these two versions.
I can confirm results similar to your findings. In my case I ran a QwQ IQ4_XS quant on two 4070s (1 RPC worker and 1 master), using greedy sampling with no speculation, 1 server slot, and effectively no prompt cache for the tests, so results are fully deterministic. If you don't set up a deterministic run you can get different results every time, making it very hard to track down problems.

Flash attention on: inconsistent across versions. Flash attention off: consistent across versions.

Failure to solve the problem with FA off might not be a problem in itself, as the model could just be unstable with greedy sampling at the quant I'm running. Inconsistent results across versions with FA on, however, suggest a possible issue with the FA-related changes.

Note that in other tests unique to my downstream server version I was able to get the model to solve the cipher very efficiently, in about 1500 tokens, by speculating it with an R1 distill, so token processing was being done in batches > 1 by the target QwQ model. It might be a coincidence that it worked, or it might point to some issue with batch size 1 in the new FA code.
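If anyone wants to reproduce this kind of deterministic A/B comparison, a minimal sketch against the llama-server HTTP API is below. It assumes two builds (e.g. b4756 and b4759) are already running locally with the same model and launch flags; the ports, prompt, and token budget are placeholders, not the exact setup described above.

```python
# Minimal sketch: send the same prompt to two llama-server builds with greedy
# sampling and prompt caching disabled, then compare how many tokens each
# build generates. Ports, seed and the token budget are assumptions.
import requests

PROMPT = 'Can you help me decrypt this cipher I received?\n"K nkmg rncakpi hqqvdcnn."'

def run(base_url: str) -> dict:
    payload = {
        "prompt": PROMPT,
        "n_predict": 8192,      # generous budget so reasoning is not cut off
        "temperature": 0.0,     # greedy sampling
        "top_k": 1,
        "cache_prompt": False,  # avoid reusing cached KV state between runs
        "seed": 42,
    }
    return requests.post(f"{base_url}/completion", json=payload, timeout=600).json()

for build, url in [("b4756", "http://localhost:8080"), ("b4759", "http://localhost:8081")]:
    result = run(url)
    print(build, "tokens predicted:", result.get("tokens_predicted"))
```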
pinging @JohannesGaessler - This sounds like a possible precision issue introduced in #12014?
Btw, I recently fixed the Metal FA kernels to always use
Generally speaking, there is no guarantee that results will be bit-for-bit identical across versions. I changed the order in which floating point operations are done, and as such the results will inevitably change due to floating point rounding error. Especially at e.g. the beginnings of sentences the token distribution is very flat, and small changes can result in a different token being sampled. A single prompt is unfortunately simply not enough data to tell whether or not there is a statistically significant difference in the average reasoning ability of the model.

I'm currently working on code for evaluating language model benchmarks using the llama.cpp server. So far I have support for MMLU (14k questions) and GSM8K (1300 problems). I'll prioritize adding support for models that need more than 24 GB of memory and investigate whether the FlashAttention changes have made a statistically significant difference.
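To make the "statistically significant" point concrete, here is a rough sketch (my addition, with made-up counts) of a two-proportion z-test one could apply to accuracy numbers from two builds; at 100 questions even a 5-point gap is well within noise.

```python
# Illustrative sketch: two-proportion z-test for "do builds A and B have the
# same accuracy on this benchmark?". The counts below are placeholders, not
# measured results from this thread.
from math import erf, sqrt

def two_proportion_p_value(correct_a: int, correct_b: int, n: int) -> float:
    """Two-sided p-value for H0: both builds have equal accuracy on n questions."""
    pa, pb = correct_a / n, correct_b / n
    pooled = (correct_a + correct_b) / (2 * n)
    se = sqrt(pooled * (1 - pooled) * (2 / n))
    if se == 0:
        return 1.0
    z = abs(pa - pb) / se
    # Convert z to a two-sided p-value via the standard normal CDF.
    return 2 * (1 - 0.5 * (1 + erf(z / sqrt(2))))

# e.g. 81/100 vs 76/100 correct -> p ≈ 0.39, nowhere near significant
print(two_proportion_p_value(81, 76, 100))
```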
I tested QwQ 32b q8_0 on 100 questions of MMLU and GSM8K. The two setups I currently have are "instant", where the model is forced to provide an answer immediately, and "normal", where the model is allowed to reason. I'm using greedy sampling with no modifiers to the token probabilities except for a grammar that forces the model to choose between the four answers of MMLU when asked for a final answer at the end.
It seems that there are indeed changes, but they are small and not consistently better or worse.
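For anyone unfamiliar with the grammar-constrained final answer mentioned above, below is a sketch of what such a request could look like against llama-server's /completion endpoint; the GBNF rule and prompt skeleton are my own illustration, not the exact harness used for these runs.

```python
# Illustration only: constrain the final answer to exactly one of the four
# MMLU choices using a GBNF grammar. The prompt skeleton is a placeholder.
import requests

ANSWER_GRAMMAR = 'root ::= "A" | "B" | "C" | "D"'

payload = {
    "prompt": "Question: ...\nA) ...\nB) ...\nC) ...\nD) ...\nFinal answer: ",
    "n_predict": 1,             # a single constrained token is enough here
    "temperature": 0.0,         # greedy, as in the runs above
    "grammar": ANSWER_GRAMMAR,  # forces the output to be A, B, C or D
}
response = requests.post("http://localhost:8080/completion", json=payload).json()
print("model chose:", response["content"])
```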
Some perplexity results to track changes vs. version: my downstream server can compute perplexity at batch size 1 or batch size 128, and the results are also compared against llama-perplexity across builds b4742, b4759, and b5121.
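For reference, the quantity being tracked is just the exponential of the average negative log-likelihood over the evaluated tokens, so any backend change that nudges per-token log-probabilities moves it. A minimal sketch of the definition (not the downstream server's actual code):

```python
# Minimal sketch of the perplexity definition: exp of the mean negative
# log-likelihood of the evaluated tokens (natural-log probabilities).
from math import exp

def perplexity(token_logprobs: list[float]) -> float:
    return exp(-sum(token_logprobs) / len(token_logprobs))

# e.g. three tokens with probabilities 0.5, 0.25 and 0.125 -> PPL ≈ 4.0
print(perplexity([-0.6931, -1.3863, -2.0794]))
```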
I think QwQ 32b is just an unstable checkpoint, at least at 4-bit quants. QwQ Preview and the DeepSeek R1 distill based on Qwen 32b did not show hypersensitive behavior like this at 4-bit quants. QwQ 32b often gets stuck in think mode, and often gets into short-term repeats while thinking. I believe small changes in backend math are exposing model instability.
I can confirm that using … I haven't had time to find the exact details though, as I'm away.
Sorry if this was a confusing post: it was linked from #12801 (comment) and I thought it was specifically about draft models.
I saw that version b5028 was released (llama: add option to override model tensor buffers #11397), and I was excited to compile and use this official version. However, I was surprised to find that the models I normally use now behave like a completely different person: they've lost their previous rationality and conciseness, becoming more emotional, verbose, and even unstable in mood. After multiple comparative tests, I found that this change was caused by the update to the attention calculation in this version.

For now I'm forced to stay on the older branch at https://github.com/ggml-org/llama.cpp/tree/sl/custom-tensor-offload (which has since been deleted). That branch combines the old FA implementation with the ability to offload expert tensors to system memory. I'm hoping that someone with expertise can help restore the main branch to its former rational and concise behavior.
I saved a copy of this branch here if it's any use: https://github.com/jukofyork/llama.cpp/tree/custom-tensor-offload
I spent half an hour trying to fork it but couldn't get it to work. I gave up. I still haven't mastered using GitHub properly. However, I had saved the ZIP source code package earlier, so I can still use that.
Name and Version
built with cc (Debian 12.2.0-14) 12.2.0 for x86_64-linux-gnu
llama.cpp-b4702
llama.cpp-b4751
llama.cpp-b4756 ************** (last version behaving as expected)
llama.cpp-b4759 ************** (first version with degraded behavior)
llama.cpp-b4761
llama.cpp-b4762
llama.cpp-b4769
llama.cpp-b4775
llama.cpp-b4800
llama.cpp-b4900
llama.cpp-b4940
llama.cpp-b4990
llama.cpp-b5026
llama.cpp-b5030
Operating systems
Linux
Which llama.cpp modules do you know to be affected?
llama-server, llama-cli
Command line
Problem description & steps to reproduce
Testing with Different Versions of the llama.cpp Server for the Same Inference Task
Using two versions of the llama.cpp server to address the same problem:
llama.cpp-b4756
llama.cpp-b4759
Both versions employ identical parameters and models, yet exhibit significant performance differences.
Key observations:
Performance degradation:
b4759 is noticeably less capable than b4756 (more than twice as bad in some cases).
Token consumption for the same task:
b4756: ~3,000 tokens
b4759: ~6,000 tokens
Version comparison:
b4702 (an older version) performs better than b4756.
The test problem used:
Can you help me decrypt this cipher I received?
"K nkmg rncakpi hqqvdcnn."
This behavior is reproducible across multiple tests. After extensive testing, b4759 was identified as the version where performance drastically degraded.
If you can reproduce similar findings, please share your test cases!
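For reference, a worked solution to the test prompt: the ciphertext is a Caesar cipher with every letter shifted forward by two, so shifting back by two recovers "I like playing football." A small decryption sketch:

```python
# Decrypt the test cipher: each letter was shifted forward by 2 (Caesar),
# so shifting back by 2 recovers the plaintext.
def caesar_shift(text: str, shift: int) -> str:
    out = []
    for ch in text:
        if ch.isalpha():
            base = ord("A") if ch.isupper() else ord("a")
            out.append(chr((ord(ch) - base + shift) % 26 + base))
        else:
            out.append(ch)
    return "".join(out)

print(caesar_shift("K nkmg rncakpi hqqvdcnn.", -2))  # -> I like playing football.
```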
First Bad Commit
No response
Relevant log output