import asyncio
import json
import time
import logging
from typing import Optional

import torch
import numpy as np
import librosa
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
import uvicorn

# Version tracking
VERSION = "1.1.1"
COMMIT_SHA = "TBD"

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global model state, populated by load_model() at startup.
# After loading, `model` and `processor` are either real transformers objects
# or the string "mock" (fallback used when the real model cannot be loaded).
model = None
processor = None
device = None  # "cuda" or "cpu" once load_model() has run

# RFC 6455: a close frame's payload is at most 125 bytes, 2 of which are the
# status code, leaving 123 bytes of UTF-8 for the reason text.
_MAX_CLOSE_REASON_BYTES = 123


def _truncate_close_reason(reason: str) -> str:
    """Trim *reason* to fit a WebSocket close frame without splitting a UTF-8 character."""
    encoded = reason.encode("utf-8")[:_MAX_CLOSE_REASON_BYTES]
    # "ignore" drops a multi-byte character that the byte slice cut in half.
    return encoded.decode("utf-8", "ignore")


async def load_model():
    """Load the STT model and processor into the module-level globals.

    Tries to load ``kyutai/stt-1b-en_fr`` through transformers onto CUDA when
    available (CPU otherwise). On any failure the service falls back to mock
    mode, where ``model`` and ``processor`` are both set to the string "mock".
    """
    global model, processor, device
    try:
        logger.info("Loading STT model...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")

        # Try to load the actual model - fall back to mock if not available.
        # Note: from_pretrained runs synchronously here, so startup waits for
        # the download/load to finish before the hook returns.
        try:
            # Imported lazily so the service can still start (in mock mode)
            # when transformers or the Kyutai classes are unavailable.
            from transformers import KyutaiSpeechToTextProcessor, KyutaiSpeechToTextForConditionalGeneration

            model_id = "kyutai/stt-1b-en_fr"
            logger.info(f"Loading processor from {model_id}...")
            processor = KyutaiSpeechToTextProcessor.from_pretrained(model_id)
            logger.info(f"Loading model from {model_id}...")
            model = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(model_id).to(device)
            logger.info(f"Model {model_id} loaded successfully on {device}")
        except Exception as model_error:
            logger.warning(f"Could not load actual model: {model_error}")
            logger.info("Using mock STT for development")
            model = "mock"
            processor = "mock"
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        model = "mock"
        processor = "mock"


def transcribe_audio(audio_data: np.ndarray, sample_rate: int = 24000) -> str:
    """Transcribe a waveform with the loaded STT model (Kyutai STT expects 24kHz).

    Args:
        audio_data: Waveform samples. NOTE(review): presumably mono float PCM;
            librosa.resample rejects non-floating-point input, so integer PCM
            only works when it is already at 24kHz — confirm with callers.
        sample_rate: Sample rate of ``audio_data`` in Hz. Audio at any other
            rate is resampled to 24000Hz before inference.

    Returns:
        The transcription text. In mock mode, a synthetic description of the
        input. Failures are not raised: they are logged and returned as
        ``"Error: <message>"``.
    """
    try:
        if model == "mock":
            # Mock transcription for development
            duration = len(audio_data) / sample_rate
            return f"Mock transcription: {duration:.2f}s audio at {sample_rate}Hz ({len(audio_data)} samples)"

        if model is None or processor is None:
            # load_model() has not completed; report it clearly instead of
            # failing with "'NoneType' object is not callable".
            logger.warning("transcribe_audio called before the STT model was loaded")
            return "Error: STT model is not loaded"

        # Real transcription - Kyutai STT expects 24kHz
        if sample_rate != 24000:
            logger.info(f"Resampling from {sample_rate}Hz to 24000Hz")
            audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=24000)

        inputs = processor(audio_data, sampling_rate=24000, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}

        with torch.no_grad():
            generated_ids = model.generate(**inputs)

        transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription
    except Exception as e:
        # logger.exception keeps the traceback, which the bare message loses.
        logger.exception(f"Transcription error: {e}")
        return f"Error: {str(e)}"


# FastAPI app
app = FastAPI(
    title="STT GPU Service Python v4",
    description="Real-time WebSocket STT streaming with kyutai/stt-1b-en_fr (24kHz)",
    version=VERSION
)


@app.on_event("startup")
async def startup_event():
    """Load model on startup"""
    await load_model()


@app.get("/health")
async def health_check():
    """Report service status, version, and model/device state."""
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "version": VERSION,
        "commit_sha": COMMIT_SHA,
        "message": "STT WebSocket Service - Real-time streaming ready",
        "space_name": "stt-gpu-service-python-v4",
        "model_loaded": model is not None,
        # model_loaded is also True for the mock fallback; this flag tells them apart.
        "mock_mode": model == "mock",
        "device": str(device) if device else "unknown",
        "expected_sample_rate": "24000Hz"
    }


@app.get("/", response_class=HTMLResponse)
async def get_index():
    """Serve a simple test page showing service info and version."""
    # NOTE(review): this page body contains no HTML tags or script, so a
    # browser renders it as run-together plain text and it cannot drive the
    # WebSocket test itself; restore the markup if an interactive page is intended.
    html_content = f""" STT GPU Service Python v4

🎙️ STT GPU Service Python v4

Real-time WebSocket speech transcription service (24kHz audio)

WebSocket Streaming Test

Status: Disconnected

Expected: 24kHz audio chunks (80ms = ~1920 samples)

Transcription output will appear here...

v{VERSION} (SHA: {COMMIT_SHA})
"""
    return HTMLResponse(content=html_content)


@app.websocket("/ws/stream")
async def websocket_endpoint(websocket: WebSocket):
    """Exchange JSON messages for real-time audio streaming.

    Protocol (JSON text frames):
      * on connect the server sends ``{"type": "connection", ...}``;
      * ``{"type": "audio_chunk", "timestamp": ...}`` is answered with
        ``{"type": "transcription", ...}`` (currently mocked: the audio
        payload is not decoded);
      * ``{"type": "ping"}`` is answered with ``{"type": "pong", ...}``;
      * malformed JSON, non-object messages and unknown types are answered
        with ``{"type": "error", ...}`` and the connection stays open.

    Unexpected server errors close the socket with code 1011.
    """
    await websocket.accept()
    logger.info("WebSocket connection established")

    try:
        # Send initial connection confirmation
        await websocket.send_json({
            "type": "connection",
            "status": "connected",
            "message": "STT WebSocket ready for audio chunks",
            "chunk_size_ms": 80,
            "expected_sample_rate": 24000,
            "expected_chunk_samples": 1920  # 80ms at 24kHz = 1920 samples
        })

        while True:
            # Parse JSON ourselves so one bad frame gets an error reply
            # instead of tearing down the whole connection.
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError as e:
                await websocket.send_json({
                    "type": "error",
                    "message": f"Invalid JSON: {e}",
                    "timestamp": time.time()
                })
                continue

            if not isinstance(data, dict):
                await websocket.send_json({
                    "type": "error",
                    "message": "Message must be a JSON object",
                    "timestamp": time.time()
                })
                continue

            msg_type = data.get("type")
            if msg_type == "audio_chunk":
                try:
                    # Process 80ms audio chunk (1920 samples at 24kHz)
                    # In real implementation, you would:
                    # 1. Decode base64 audio data
                    # 2. Convert to numpy array (24kHz)
                    # 3. Process with STT model
                    # 4. Return transcription

                    # For now, mock processing
                    transcription = f"Mock transcription for 24kHz chunk at {data.get('timestamp', 'unknown')}"

                    # Send transcription result
                    await websocket.send_json({
                        "type": "transcription",
                        "text": transcription,
                        "timestamp": time.time(),
                        "chunk_id": data.get("timestamp"),
                        "confidence": 0.95
                    })
                except WebSocketDisconnect:
                    # The client went away; don't try to send an error to it.
                    raise
                except Exception as e:
                    await websocket.send_json({
                        "type": "error",
                        "message": f"Processing error: {str(e)}",
                        "timestamp": time.time()
                    })
            elif msg_type == "ping":
                # Respond to ping
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": time.time()
                })
            else:
                await websocket.send_json({
                    "type": "error",
                    "message": f"Unknown message type: {msg_type!r}",
                    "timestamp": time.time()
                })

    except WebSocketDisconnect:
        logger.info("WebSocket connection closed")
    except Exception as e:
        logger.exception(f"WebSocket error: {e}")
        try:
            # An untruncated exception message can exceed the close-frame limit
            # and make the close itself fail.
            await websocket.close(code=1011, reason=_truncate_close_reason(f"Server error: {str(e)}"))
        except Exception:
            # Best effort: the socket may already be closed (e.g. the error
            # was a send on a dead connection).
            logger.debug("Could not send WebSocket close frame", exc_info=True)


@app.post("/api/transcribe")
async def api_transcribe(audio_file: Optional[str] = None):
    """Mock REST transcription endpoint for testing.

    Args:
        audio_file: Audio payload as a string (a scalar parameter, which
            FastAPI reads from the query string). Required.

    Returns:
        A dict with a mock transcription echoing up to the first 50
        characters of ``audio_file``, plus version metadata.

    Raises:
        HTTPException: 400 when ``audio_file`` is missing or empty.
    """
    if not audio_file:
        raise HTTPException(status_code=400, detail="No audio data provided")

    # Only mark the preview with an ellipsis when it was actually truncated.
    preview = audio_file if len(audio_file) <= 50 else f"{audio_file[:50]}..."

    # Mock transcription
    result = {
        "transcription": f"REST API transcription result for: {preview}",
        "timestamp": time.time(),
        "version": VERSION,
        "method": "REST",
        "expected_sample_rate": "24kHz"
    }

    return result


if __name__ == "__main__":
    # Run the server. Pass the app object instead of the "app:app" import
    # string so this works whatever the file is named and the module is not
    # imported a second time (no reload/workers are used, which are the
    # options that require an import string).
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        access_log=True
    )