flash_attention_2 does not work in Qwen3-TTS

#5
by Matilda83 - opened

Hello!

Why doesn't flash_attention_2 work with Qwen3-TTS?
The model architecture is incompatible with FA2, yet the documentation says to enable it to speed up generation.

Generation is very slow on an RTX 5070 Ti! A 30-second sample (~350 tokens) takes 35 seconds to generate. The 0.6B-Base model is loaded like this:

tts = Qwen3TTSModel.from_pretrained(
    MODEL_PATH,
    device_map=device,
    dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # <<<--- no effect
)
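For what it's worth, a quick way to see whether the loader actually honoured the flag is to check what's installed and what the config reports after loading. This is only a minimal sketch; it assumes Qwen3TTSModel wraps a Hugging Face-style model and exposes a .config attribute, which may not be the case in qwen-tts:

# Sketch: check whether flash-attn is importable and which attention
# backend the loaded model ended up with (attribute names are assumptions).
import importlib.util

print("flash-attn installed:", importlib.util.find_spec("flash_attn") is not None)

cfg = getattr(tts, "config", None)
if cfg is not None:
    # Transformers records the selected backend here ("flash_attention_2",
    # "sdpa" or "eager"); if it still says "sdpa"/"eager", FA2 was ignored.
    print("attn implementation:", getattr(cfg, "_attn_implementation", "unknown"))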

....
common_gen_kwargs = dict(
    do_sample=True,
    top_k=30,
    top_p=1.0,
    temperature=0.7,
    repetition_penalty=1.15,
    subtalker_dosample=False,
)
....
wavs, sr = tts.generate_voice_clone(
    text=syn_text,
    max_new_tokens=512,
    flash_attn=True,
    non_streaming_mode=True,
    language="Auto",
    voice_clone_prompt=prompt_items,
    use_cache=True,
    **common_gen_kwargs,
)
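To put a number on the slowdown, it can help to time a single generation and report tokens per second. This is just a rough sketch wrapped around the call above, using the ~350-token figure mentioned earlier (syn_text and prompt_items as already defined):

# Rough throughput check (sketch): time one generation and report tokens/s.
import time
import torch

torch.cuda.synchronize()
t0 = time.perf_counter()
wavs, sr = tts.generate_voice_clone(
    text=syn_text,
    max_new_tokens=512,
    flash_attn=True,
    non_streaming_mode=True,
    language="Auto",
    voice_clone_prompt=prompt_items,
    use_cache=True,
    **common_gen_kwargs,
)
torch.cuda.synchronize()
elapsed = time.perf_counter() - t0
print(f"{elapsed:.1f} s elapsed, ~{350 / elapsed:.1f} tokens/s")  # ~350 tokens per 30 s sample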

Python 3.10

$ pip show torch
Name: torch
Version: 2.10.0+cu130
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org
Author:
Author-email: PyTorch Team packages@pytorch.org
License: BSD-3-Clause
Location: /AI/.venv/lib/python3.10/site-packages
Requires: cuda-bindings, filelock, fsspec, jinja2, networkx, nvidia-cublas, nvidia-cuda-cupti, nvidia-cuda-nvrtc, nvidia-cuda-runtime, nvidia-cudnn-cu13, nvidia-cufft, nvidia-cufile, nvidia-curand, nvidia-cusolver, nvidia-cusparse, nvidia-cusparselt-cu13, nvidia-nccl-cu13, nvidia-nvjitlink, nvidia-nvshmem-cu13, nvidia-nvtx, sympy, triton, typing-extensions
Required-by: accelerate, bitsandbytes, flash-attn, peft, torchaudio, torchvision

$ pip show flash_attn
Name: flash-attn
Version: 2.8.3
Summary: Flash Attention: Fast and Memory-Efficient Exact Attention
Home-page: https://github.com/Dao-AILab/flash-attention
Author: Tri Dao
Author-email: tri@tridao.me
License:
Location: /AI/.venv/lib/python3.10/site-packages
Requires: einops, torch
Required-by:

$ pip show qwen-tts
Name: qwen-tts
Version: 0.0.5
Summary: Qwen-TTS python package
Home-page: https://github.com/Qwen/Qwen3-TTS
Author: Alibaba Qwen Team
Author-email:
License: Apache-2.0
Location: /AI/.venv/lib/python3.10/site-packages
Requires: accelerate, einops, gradio, librosa, onnxruntime, soundfile, sox, torchaudio, transformers
Required-by:

$ pip show transformers
Name: transformers
Version: 4.57.3
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /AI/.venv/lib/python3.10/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: peft, qwen-tts, trl

I'm certainly not an expert, and I have the exact same problem, but could this be caused by cu130? It seems that official flash_attn support ends at cu128. Also, the output of your pip show flash_attn does not indicate that you are using a cu130 build. I had problems building it myself, so I used a ready-made wheel, but the problem remains the same.

[lada]>[/env/tts]>[~]>pip show flash_attn
Name: flash_attn
Version: 2.8.3+cu130torch2.10
Summary: Flash Attention: Fast and Memory-Efficient Exact Attention
Home-page: https://github.com/Dao-AILab/flash-attention
Author: Tri Dao
Author-email: tri@tridao.me
License: 
Location: /home/lada/miniconda3/envs/tts/lib/python3.12/site-packages
Requires: einops, torch
Required-by: 
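As a sanity check that the installed wheel actually matches the running torch/CUDA stack, you can try calling the FA2 kernel directly on a small tensor; if the import or the call fails, the mismatch is in the flash-attn/CUDA build rather than in Qwen3-TTS itself. A minimal sketch:

# Sanity check (sketch): confirm the flash-attn kernel loads and runs
# against the installed torch/CUDA build.
import torch
import flash_attn
from flash_attn import flash_attn_func

print("torch", torch.__version__, "CUDA", torch.version.cuda)
print("flash-attn", flash_attn.__version__)

# Tiny dummy attention call: (batch, seqlen, heads, headdim), bf16 on GPU.
q = torch.randn(1, 128, 8, 64, dtype=torch.bfloat16, device="cuda")
k = torch.randn_like(q)
v = torch.randn_like(q)
out = flash_attn_func(q, k, v, causal=True)
print("flash_attn_func OK:", out.shape)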
