Broken MTP (speculative decoding)

#1 opened by Arien0

Hi! Thank you for your quants, really appreciated!

However, speculative decoding via native MTP is failing with the quants compared to the full-size version.

Here is a look at the vLLM (0.19.0) log; note the 0.0% draft acceptance rate with the quant:

(APIServer pid=138557) INFO 04-06 00:33:37 [loggers.py:259] Engine 000: Avg prompt throughput: 38.9 tokens/s, Avg generation throughput: 7.5 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 3.9%, Prefix cache hit rate: 0.0%
(APIServer pid=138557) INFO 04-06 00:33:37 [metrics.py:101] SpecDecoding metrics: Mean acceptance length: 1.00, Accepted throughput: 0.00 tokens/s, Drafted throughput: 11.75 tokens/s, Accepted: 0 tokens, Drafted: 296 tokens, Per-position acceptance rate: 0.000, 0.000, 0.000, 0.000, Avg Draft acceptance rate: 0.0%
(APIServer pid=138557) INFO 04-06 00:33:47 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 7.2 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=138557) INFO 04-06 00:33:47 [metrics.py:101] SpecDecoding metrics: Mean acceptance length: 1.00, Accepted throughput: 0.00 tokens/s, Drafted throughput: 28.80 tokens/s, Accepted: 0 tokens, Drafted: 288 tokens, Per-position acceptance rate: 0.000, 0.000, 0.000, 0.000, Avg Draft acceptance rate: 0.0%
(APIServer pid=138557) INFO 04-06 00:33:57 [loggers.py:259] Engine 000: Avg prompt throughput: 43.9 tokens/s, Avg generation throughput: 15.7 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 4.1%, Prefix cache hit rate: 0.0%
(APIServer pid=138557) INFO 04-06 00:33:57 [metrics.py:101] SpecDecoding metrics: Mean acceptance length: 1.00, Accepted throughput: 0.00 tokens/s, Drafted throughput: 62.40 tokens/s, Accepted: 0 tokens, Drafted: 624 tokens, Per-position acceptance rate: 0.000, 0.000, 0.000, 0.000, Avg Draft acceptance rate: 0.0%
(APIServer pid=138557) INFO 04-06 00:34:07 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 3.5 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%
(APIServer pid=138557) INFO 04-06 00:34:07 [metrics.py:101] SpecDecoding metrics: Mean acceptance length: 1.00, Accepted throughput: 0.00 tokens/s, Drafted throughput: 14.00 tokens/s, Accepted: 0 tokens, Drafted: 140 tokens, Per-position acceptance rate: 0.000, 0.000, 0.000, 0.000, Avg Draft acceptance rate: 0.0%
(APIServer pid=138557) INFO 04-06 00:34:17 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 0.0 tokens/s, Running: 0 reqs, Waiting: 0 reqs, GPU KV cache usage: 0.0%, Prefix cache hit rate: 0.0%

And compare to the non-quantized BF16 model:

(APIServer pid=947) INFO 04-04 11:25:21 [loggers.py:259] Engine 000: Avg prompt throughput: 555.5 tokens/s, Avg generation throughput: 37.8 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 14.9%, Prefix cache hit rate: 0.0%
(APIServer pid=947) INFO 04-04 11:25:21 [metrics.py:101] SpecDecoding metrics: Mean acceptance length: 3.02, Accepted throughput: 25.20 tokens/s, Drafted throughput: 50.00 tokens/s, Accepted: 252 tokens, Drafted: 500 tokens, Per-position acceptance rate: 0.800, 0.592, 0.376, 0.248, Avg Draft acceptance rate: 50.4%
(APIServer pid=947) INFO 04-04 11:25:31 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 58.5 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 15.4%, Prefix cache hit rate: 0.0%
(APIServer pid=947) INFO 04-04 11:25:31 [metrics.py:101] SpecDecoding metrics: Mean acceptance length: 2.97, Accepted throughput: 38.80 tokens/s, Drafted throughput: 78.80 tokens/s, Accepted: 388 tokens, Drafted: 788 tokens, Per-position acceptance rate: 0.731, 0.528, 0.396, 0.315, Avg Draft acceptance rate: 49.2%
(APIServer pid=947) INFO 04-04 11:25:41 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 64.1 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 16.0%, Prefix cache hit rate: 0.0%
(APIServer pid=947) INFO 04-04 11:25:41 [metrics.py:101] SpecDecoding metrics: Mean acceptance length: 3.25, Accepted throughput: 44.40 tokens/s, Drafted throughput: 78.80 tokens/s, Accepted: 444 tokens, Drafted: 788 tokens, Per-position acceptance rate: 0.822, 0.635, 0.452, 0.345, Avg Draft acceptance rate: 56.3%
(APIServer pid=947) INFO 04-04 11:25:51 [loggers.py:259] Engine 000: Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 78.2 tokens/s, Running: 1 reqs, Waiting: 0 reqs, GPU KV cache usage: 17.1%, Prefix cache hit rate: 0.0%
(APIServer pid=947) INFO 04-04 11:25:51 [metrics.py:101] SpecDecoding metrics: Mean acceptance length: 3.99, Accepted throughput: 58.60 tokens/s, Drafted throughput: 78.40 tokens/s, Accepted: 586 tokens, Drafted: 784 tokens, Per-position acceptance rate: 0.918, 0.786, 0.694, 0.592, Avg Draft acceptance rate: 74.7%

In other words, the BF16 run accepts roughly half of the drafted tokens (mean acceptance length around 3, so about two extra tokens per verification step), while the quant accepts none and only pays the drafting overhead on top of plain decoding. Maybe something's wrong with GPTQModel?

I’m not entirely convinced that this issue originates from GPTQModel during the quantization process. Based on the PPL evaluation in my README, the quantized model performs almost identically to the full-precision model.

That said, I’m not fully certain how much degradation quantization can introduce in specific downstream tasks. I think many people, myself included, used to assume that weight-only W8A16 or even W8A8 quantization should be nearly lossless at inference. However, in more complex application scenarios explored over the past year, even INT8 quantization can sometimes lead to noticeable degradation.

For example, in this paper, Table 10 shows that on Mistral-7B-v0.3, several metrics suffer significant drops, or even collapse, under INT8 quantization.
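
If anyone wants to quantify this beyond PPL, a side-by-side run with lm-evaluation-harness on a downstream task would be one way to check. The sketch below is only illustrative; the model paths and task list are placeholders, not what I actually ran:

# Illustrative sketch: compare the full-precision and quantized checkpoints on
# perplexity (wikitext) and one downstream task (gsm8k) with lm-evaluation-harness,
# using its vLLM backend. Paths and tasks are placeholders.
for MODEL in /path/to/full-precision-model /path/to/quantized-model; do
    lm_eval --model vllm \
        --model_args "pretrained=$MODEL,dtype=bfloat16,gpu_memory_utilization=0.9,max_model_len=4096" \
        --tasks wikitext,gsm8k \
        --batch_size auto
done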

Would you be willing to share more details about your use case and the exact commands you are running? I may not be able to fully resolve the issue, but more context could help the community have a more productive discussion.

Sure! 1x RTX 3090 with vLLM 0.19.0 and transformers 4.57.6.
Launch script (the same script is used for the INT8 quant):

#!/bin/bash
source /opt/llm/vllm/venv/bin/activate
export CUDA_DEVICE_ORDER=PCI_BUS_ID
export CUDA_VISIBLE_DEVICES=0
export OMP_NUM_THREADS=12
export PYTORCH_ALLOC_CONF="expandable_segments:True,garbage_collection_threshold:0.5"
export VLLM_NO_USAGE_STATS=1
export VLLM_DO_NOT_TRACK=1
export VLLM_ENABLE_CUDAGRAPH_GC=1
export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=1
#export VLLM_LOGGING_LEVEL=DEBUG
export CUDA_HOME=/usr/local/cuda-12.8
export PATH=$CUDA_HOME/bin:$PATH
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH
export VLLM_USE_FLASHINFER_SAMPLER=1
MODEL="/opt/llm/modelos/Qwopus3.5-9B-v3"
PORT="8000"
HOST="0.0.0.0"
MAX_LEN="32768"
DTYPE="bfloat16"
KV_CACHE_DTYPE="auto"
GPU_MEMORY_UTIL="0.94"
NUM_SEQS="16"
BAT_TOK="8192"

taskset -c 0-11 vllm serve "$MODEL" \
    --language-model-only \
    --no-use-tqdm-on-load \
    --served-model-name Qwen3.5-Claude \
    --tokenizer-mode auto \
    --max-model-len $MAX_LEN \
    --enable-auto-tool-choice \
    --dtype $DTYPE \
    --block-size 32 \
    --seed 3407 \
    --optimization-level 3 \
    --max-num-batched-tokens $BAT_TOK \
    --enable-chunked-prefill \
    --enable-prefix-caching \
    --async-scheduling \
    --performance-mode interactivity \
    --mamba-cache-mode align \
    --mamba-block-size 8 \
    --kv-cache-dtype $KV_CACHE_DTYPE \
    --gpu-memory-utilization $GPU_MEMORY_UTIL \
    --reasoning-parser qwen3 \
    --tool-call-parser qwen3_coder \
    --host $HOST \
    --port $PORT \
    --max-num-seqs $NUM_SEQS \
    --attention-backend FLASHINFER \
    --speculative-config '{"method":"mtp","num_speculative_tokens":5}'
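
For what it's worth, the acceptance counters can also be read directly from the server's Prometheus endpoint instead of waiting for the periodic log lines. The spec-decode metric names differ a bit between vLLM versions, so the grep below is deliberately loose:

# Pull the raw speculative-decoding counters from the running server.
# (Metric names vary across vLLM versions, hence the loose match.)
curl -s http://localhost:8000/metrics | grep -i spec_decode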
