import asyncio import json import time import logging from typing import Optional import torch import numpy as np import librosa from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException from fastapi.responses import JSONResponse from fastapi.staticfiles import StaticFiles from fastapi.responses import HTMLResponse import uvicorn # Version tracking VERSION = "1.1.0" COMMIT_SHA = "TBD" # Configure logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Global model variables model = None processor = None device = None async def load_model(): """Load STT model on startup""" global model, processor, device try: logger.info("Loading STT model...") device = "cuda" if torch.cuda.is_available() else "cpu" logger.info(f"Using device: {device}") # Try to load the actual model - fallback to mock if not available try: from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration model_id = "kyutai/stt-1b-en_fr" processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id) model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id).to(device) logger.info(f"Model {model_id} loaded successfully") except Exception as model_error: logger.warning(f"Could not load actual model: {model_error}") logger.info("Using mock STT for development") model = "mock" processor = "mock" except Exception as e: logger.error(f"Error loading model: {e}") model = "mock" processor = "mock" def transcribe_audio(audio_data: np.ndarray, sample_rate: int = 16000) -> str: """Transcribe audio data""" try: if model == "mock": # Mock transcription for development return f"Mock transcription: {len(audio_data)} samples at {sample_rate}Hz" # Real transcription inputs = processor(audio_data, sampling_rate=sample_rate, return_tensors="pt") inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): generated_ids = model.generate(**inputs) transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] return transcription except Exception as e: logger.error(f"Transcription error: {e}") return f"Error: {str(e)}" # FastAPI app app = FastAPI( title="STT GPU Service Python v4", description="Real-time WebSocket STT streaming with kyutai/stt-1b-en_fr", version=VERSION ) @app.on_event("startup") async def startup_event(): """Load model on startup""" await load_model() @app.get("/health") async def health_check(): """Health check endpoint""" return { "status": "healthy", "timestamp": time.time(), "version": VERSION, "commit_sha": COMMIT_SHA, "message": "STT WebSocket Service - Real-time streaming ready", "space_name": "stt-gpu-service-python-v4", "model_loaded": model is not None, "device": str(device) if device else "unknown" } @app.get("/", response_class=HTMLResponse) async def get_index(): """Simple HTML interface for testing""" html_content = f"""
Real-time WebSocket speech transcription service
Status: Disconnected
Transcription output will appear here...