| import io |
| import os |
| from huggingface_hub import hf_hub_download |
| from models.audio import AudioFormat, FORMAT_BACKENDS |
| import tempfile |
| import logging |
| import torchaudio |
| from fastapi import HTTPException |
| from fastapi.responses import JSONResponse |
| import torch |
| from typing import Tuple |
| from utils.custom_component import CustomRQBottleneckTransformer |
| logger = logging.getLogger(__name__) |
|
|
|
|
class AudioTokenizerService:
    """Turn raw audio bytes into WhisperVQ sound-token strings.

    Constructing the service:
      * records which torchaudio backends are available,
      * downloads the WhisperVQ checkpoint from the Hugging Face Hub into the
        directory above the one containing this file, if it is not there yet,
      * loads the quantizer and its encoder on CUDA when available, else CPU.
    """

    # Checkpoint location on the Hugging Face Hub and its local file name.
    MODEL_REPO_ID = "jan-hq/WhisperVQ"
    MODEL_FILENAME = "whisper-vq-stoks-v3-7lang-fixed.model"

    def __init__(self):
        self.available_backends = torchaudio.list_audio_backends()
        logger.info(f"Available backends: {self.available_backends}")

        self.has_ffmpeg = "ffmpeg" in self.available_backends
        if not self.has_ffmpeg:
            logger.warning(
                "FFMPEG backend not available. Some formats may not be supported")

        # Chosen once so the model and every input tensor share a device.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        main_directory = os.path.dirname(
            os.path.dirname(os.path.realpath(__file__)))
        model_path = os.path.join(main_directory, self.MODEL_FILENAME)
        if not os.path.exists(model_path):
            hf_hub_download(
                repo_id=self.MODEL_REPO_ID,
                filename=self.MODEL_FILENAME,
                local_dir=main_directory,
            )

        self.vq_model = CustomRQBottleneckTransformer.load_vq_only(
            model_path
        ).to(self.device)
        self.vq_model.load_encoder(self.device)
        self.vq_model.eval()

    def _get_best_backend(self, format: AudioFormat) -> str:
        """Return the first backend listed for ``format`` that is available.

        Backends are tried in the order given by ``FORMAT_BACKENDS[format]``.

        Raises:
            ValueError: if none of the listed backends is available.
        """
        for backend in FORMAT_BACKENDS[format]:
            if backend in self.available_backends:
                return backend
        raise ValueError(f"No available backend supports format {format}")

    def _decode_with_torchaudio(
        self,
        file_obj: bytes,
        format: AudioFormat,
        backend: str
    ) -> Tuple[torch.Tensor, int]:
        """Decode an encoded audio payload with torchaudio using ``backend``."""
        # The backend is passed per call: torchaudio.set_audio_backend() is
        # deprecated, and it mutated process-wide state on every request.
        if os.name == "nt":
            # On Windows the payload is decoded from memory instead of from a
            # named temporary file.
            return torchaudio.load(io.BytesIO(file_obj), backend=backend)

        # Use the enum's value for the extension; formatting the member itself
        # can yield "AudioFormat.WAV" depending on the Enum type/Python version.
        extension = getattr(format, "value", format)
        with tempfile.NamedTemporaryFile(suffix=f".{extension}") as temp_file:
            temp_file.write(file_obj)
            # Flush so the decoder, which reopens the file by name, sees the data.
            temp_file.flush()
            return torchaudio.load(temp_file.name, backend=backend)

    def load_audio(
        self,
        file_obj: bytes,
        format: AudioFormat,
        target_sr: int = 16000
    ) -> Tuple[torch.Tensor, int]:
        """
        Load audio from a bytes object with format handling.

        PCM input is interpreted as headerless 16-bit signed samples of a
        single channel that are already at ``target_sr``; the samples are
        scaled by 1/32768. Every other format is decoded by torchaudio,
        averaged down to one channel and resampled to ``target_sr`` if needed.

        Args:
            file_obj: Audio file bytes
            format: Audio format enum
            target_sr: Target sample rate (default: 16000)

        Returns:
            Tuple[torch.Tensor, int]: Audio tensor with one channel and the
            sample rate (always ``target_sr``)

        Raises:
            HTTPException: status 400 when the audio cannot be loaded,
                including when no backend supports ``format``.
        """
        try:
            backend = self._get_best_backend(format)
            logger.info(f"Using {backend} backend for {format} format")

            if format == AudioFormat.PCM:
                # Copy into a writable buffer: torch.frombuffer warns on
                # immutable ``bytes``. ``.float()`` then makes its own copy.
                wav = torch.frombuffer(bytearray(file_obj), dtype=torch.int16)
                wav = wav.float() / 32768.0
                wav = wav.unsqueeze(0)
                sr = target_sr
            else:
                wav, sr = self._decode_with_torchaudio(file_obj, format, backend)

            # Down-mix multi-channel audio by averaging the channels.
            if wav.shape[0] > 1:
                wav = torch.mean(wav, dim=0, keepdim=True)

            if sr != target_sr:
                wav = torchaudio.functional.resample(wav, sr, target_sr)
                sr = target_sr

            return wav, sr

        except Exception as e:
            logger.error(f"Error loading audio: {e}")
            raise HTTPException(
                status_code=400,
                detail=f"Error processing {format} audio: {str(e)}"
            ) from e

    def get_format_info(self) -> dict:
        """Get information about supported formats.

        Returns:
            dict: one entry per ``AudioFormat`` member, mapping it to
            ``{"supported": bool, "backend": str | None}``.
        """
        supported_formats = {}
        for format in AudioFormat:
            try:
                backend = self._get_best_backend(format)
            except ValueError:
                supported_formats[format] = {
                    "supported": False,
                    "backend": None
                }
            else:
                supported_formats[format] = {
                    "supported": True,
                    "backend": backend
                }
        return supported_formats

    def tokenize(self, audio_data: bytes, format: AudioFormat = "wav"):
        """Encode audio bytes into a ``<|sound_start|>…<|sound_end|>`` string.

        Args:
            audio_data: Audio file bytes
            format: Audio format of ``audio_data``

        Returns:
            JSONResponse with the keys ``model_name``, ``tokens``, ``format``,
            ``sample_rate`` and ``backend_used``.

        Raises:
            HTTPException: status 400 when the audio cannot be loaded (raised
                by ``load_audio``), status 500 for any other failure.
        """
        try:
            wav, sr = self.load_audio(audio_data, format)
            wav = wav.to(self.device)

            with torch.no_grad():
                codes = self.vq_model.encode_audio(wav)
                # Only the first element of the encoder output is used.
                codes = codes[0].cpu().tolist()

            result = ''.join(f'<|sound_{num:04d}|>' for num in codes)

            return JSONResponse(content={
                "model_name": self.MODEL_FILENAME,
                "tokens": f'<|sound_start|>{result}<|sound_end|>',
                "format": format,
                "sample_rate": sr,
                "backend_used": self._get_best_backend(format)
            })

        except HTTPException:
            # Keep the status code chosen further down (e.g. 400 for bad
            # audio) instead of masking it as a 500.
            raise
        except Exception as e:
            logger.error(f"Error processing request: {e}")
            raise HTTPException(
                status_code=500,
                detail=f"Error processing request: {str(e)}"
            ) from e
|
|
|
|
# Module-wide service instance; stays None until the accessor below is called.
_audio_tokenizer_service = None


def get_audio_tokenizer_service():
    """Return the shared AudioTokenizerService, creating it on first use."""
    global _audio_tokenizer_service
    if _audio_tokenizer_service is not None:
        return _audio_tokenizer_service
    _audio_tokenizer_service = AudioTokenizerService()
    return _audio_tokenizer_service
|
|