| """Custom inference handler for HuggingFace Inference Endpoints.""" |
|
|
| from typing import Any, Dict, List, Union |
|
|
| import torch |
|
|
# Prefer package-relative imports, but fall back to flat imports so the
# handler also works when these files sit beside it in the endpoint repo.
try:
    from .asr_modeling import ASRModel
    from .asr_pipeline import ASRPipeline
except ImportError:
    from asr_modeling import ASRModel
    from asr_pipeline import ASRPipeline

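
# Hugging Face Inference Endpoints load this file as handler.py and call
# EndpointHandler(path) once at startup, then __call__(data) per request.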
class EndpointHandler:
    def __init__(self, path: str = ""):
        import os

        import nltk

        # Ensure the NLTK punkt sentence-tokenizer data is available.
        nltk.download("punkt_tab", quiet=True)

        # Reduce CUDA memory fragmentation in a long-running endpoint.
        os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")

        # Allow TF32 matmuls on Ampere and newer GPUs: faster inference at a
        # negligible precision cost.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision on GPU, full precision on CPU.
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

        if torch.cuda.is_available():
            # Let cuDNN autotune its kernels for the recurring input shapes.
            torch.backends.cudnn.benchmark = True

        model_kwargs = {
            "dtype": self.dtype,
            "low_cpu_mem_usage": True,
        }
        if torch.cuda.is_available():
            # Use FlashAttention 2 when the package is installed; otherwise
            # fall back to PyTorch's scaled-dot-product attention (SDPA).
            model_kwargs["attn_implementation"] = (
                "flash_attention_2" if self._is_flash_attn_available() else "sdpa"
            )

        self.model = ASRModel.from_pretrained(path, **model_kwargs)

        self.pipe = ASRPipeline(
            model=self.model,
            feature_extractor=self.model.feature_extractor,
            tokenizer=self.model.tokenizer,
            device=self.device,
        )

        # torch.compile is on by default; set ENABLE_TORCH_COMPILE=0 to opt out.
        if torch.cuda.is_available() and os.getenv("ENABLE_TORCH_COMPILE", "1") == "1":
            compile_mode = os.getenv("TORCH_COMPILE_MODE", "default")
            self.model = torch.compile(self.model, mode=compile_mode)
            self.pipe.model = self.model

        # Pay the compile and allocation cost at startup, not on first request.
        if torch.cuda.is_available():
            self._warmup()

    def _is_flash_attn_available(self):
        """Check if flash attention is available."""
        import importlib.util

        return importlib.util.find_spec("flash_attn") is not None

    def _warmup(self):
        """Warmup to trigger model compilation and allocate GPU memory."""
        try:
            # One second of random audio at the model's expected sample rate.
            sample_rate = self.pipe.model.config.audio_sample_rate
            dummy_audio = torch.randn(sample_rate, dtype=torch.float32)

            with torch.inference_mode():
                warmup_tokens = self.pipe.model.config.inference_warmup_tokens
                _ = self.pipe(
                    {"raw": dummy_audio, "sampling_rate": sample_rate},
                    max_new_tokens=warmup_tokens,
                )

            if torch.cuda.is_available():
                torch.cuda.synchronize()
                torch.cuda.empty_cache()

        except Exception as e:
            # Warmup is best-effort; never block endpoint startup on it.
            print(f"Warmup skipped due to: {e}")

    def __call__(self, data: Dict[str, Any]) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        inputs = data.get("inputs")
        if inputs is None:
            raise ValueError("Missing 'inputs' in request data")

        # Forward any request parameters (e.g. generation settings) verbatim.
        params = data.get("parameters", {})
        return self.pipe(inputs, **params)
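

# A minimal local smoke test, not used by the endpoint runtime. It assumes a
# checkpoint path is given as the first CLI argument (defaulting to the
# current directory) and mirrors the input format used by _warmup() above.
if __name__ == "__main__":
    import sys

    handler = EndpointHandler(path=sys.argv[1] if len(sys.argv) > 1 else ".")

    sample_rate = handler.pipe.model.config.audio_sample_rate
    dummy_audio = torch.randn(sample_rate, dtype=torch.float32)
    result = handler(
        {
            "inputs": {"raw": dummy_audio, "sampling_rate": sample_rate},
            "parameters": {"max_new_tokens": 16},
        }
    )
    print(result)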