# stt-gpu-service-python-v4 / app_docker_streaming.py
# Peter Michael Gits
# Fix Dockerfile directory permissions - create /app as root before switching users
# 26096f4
import asyncio
import json
import time
import logging
from typing import Optional
import torch
import numpy as np
import librosa
from fastapi import FastAPI, WebSocket, WebSocketDisconnect, HTTPException
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse
import uvicorn
# Build/version identifiers reported by /health and rendered on the test page.
VERSION, COMMIT_SHA = "1.1.0", "TBD"

logger = logging.getLogger(__name__)
# INFO level so model-loading progress shows up in the container logs.
logging.basicConfig(level=logging.INFO)

# Shared inference state, filled in by load_model() during app startup.
# All three stay None until it runs; load_model() stores the string "mock"
# in model/processor when the real checkpoint cannot be loaded.
model = processor = device = None
async def load_model():
    """Populate the module-level ``model``, ``processor`` and ``device``.

    Selects CUDA when ``torch`` reports it available, otherwise CPU, then
    tries to load the Kyutai STT checkpoint through ``transformers``. On any
    failure the error is logged and the service drops into mock mode by
    storing the string ``"mock"`` in both ``model`` and ``processor``.
    Exceptions are never propagated to the caller.
    """
    global model, processor, device
    try:
        logger.info("Loading STT model...")
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        try:
            # Imported lazily so a transformers build without the Kyutai
            # classes degrades to mock mode instead of failing at import time.
            from transformers import (
                KyutaiSpeechToTextForConditionalGeneration,
                KyutaiSpeechToTextProcessor,
            )

            # NOTE(review): transformers-format Kyutai checkpoints appear to be
            # published under a "-trfs" suffix (e.g. "kyutai/stt-1b-en_fr-trfs");
            # confirm this id loads, otherwise the service silently runs in mock mode.
            checkpoint = "kyutai/stt-1b-en_fr"
            processor = KyutaiSpeechToTextProcessor.from_pretrained(checkpoint)
            loaded = KyutaiSpeechToTextForConditionalGeneration.from_pretrained(checkpoint)
            model = loaded.to(device)
            logger.info(f"Model {checkpoint} loaded successfully")
        except Exception as load_err:
            logger.warning(f"Could not load actual model: {load_err}")
            logger.info("Using mock STT for development")
            model = processor = "mock"
    except Exception as err:
        logger.error(f"Error loading model: {err}")
        model = processor = "mock"
def transcribe_audio(audio_data: np.ndarray, sample_rate: int = 16000) -> str:
    """Transcribe an audio buffer with the globally loaded STT model.

    Args:
        audio_data: Audio samples (presumably mono float PCM -- TODO confirm
            against callers).
        sample_rate: Sampling rate of ``audio_data`` in Hz.

    Returns:
        The decoded transcription. In mock mode, a descriptive placeholder
        string. On any failure -- including being called before
        ``load_model()`` has run -- a string starting with ``"Error: "``;
        this function never raises.
    """
    if model is None or processor is None:
        # load_model() has not run yet; the real path would only fail with an
        # opaque "'NoneType' object is not callable".
        logger.warning("Transcription requested before the STT model was loaded")
        return "Error: STT model is not loaded"
    if isinstance(model, str):
        # load_model() stores the string "mock" when the checkpoint is unavailable.
        return f"Mock transcription: {len(audio_data)} samples at {sample_rate}Hz"
    try:
        # The feature extractor is configured for a fixed rate (24 kHz for Kyutai
        # STT per the Hugging Face docs -- verify for this checkpoint) and
        # transformers feature extractors typically reject a mismatching
        # sampling_rate, so resample instead of failing on every chunk.
        feature_extractor = getattr(processor, "feature_extractor", None)
        target_rate = getattr(feature_extractor, "sampling_rate", None) or sample_rate
        if target_rate != sample_rate:
            audio_data = librosa.resample(
                np.asarray(audio_data, dtype=np.float32),
                orig_sr=sample_rate,
                target_sr=target_rate,
            )
        inputs = processor(audio_data, sampling_rate=target_rate, return_tensors="pt")
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            generated_ids = model.generate(**inputs)
        return processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    except Exception as e:
        logger.error(f"Transcription error: {e}")
        return f"Error: {str(e)}"
# The ASGI application. title/description/version are FastAPI metadata
# (used for its generated OpenAPI schema); version mirrors VERSION.
app = FastAPI(
    version=VERSION,
    title="STT GPU Service Python v4",
    description="Real-time WebSocket STT streaming with kyutai/stt-1b-en_fr",
)
@app.on_event("startup")
async def startup_event():
    """Startup hook: delegate to load_model() to populate the global STT state."""
    # NOTE(review): on_event is deprecated in recent FastAPI in favour of
    # lifespan handlers; migrating also requires changing the FastAPI(...) call.
    return await load_model()
@app.get("/health")
async def health_check():
    """Report liveness, build metadata and model state.

    ``model_loaded`` only says that load_model() has assigned something --
    it is also True after the fallback to the mock model -- so ``mock_mode``
    is reported alongside it to tell a real checkpoint apart from the stub.
    """
    return {
        "status": "healthy",
        "timestamp": time.time(),
        "version": VERSION,
        "commit_sha": COMMIT_SHA,
        "message": "STT WebSocket Service - Real-time streaming ready",
        "space_name": "stt-gpu-service-python-v4",
        "model_loaded": model is not None,
        # load_model() stores the string "mock" when the real model is unavailable.
        "mock_mode": isinstance(model, str),
        "device": str(device) if device else "unknown",
    }
@app.get("/", response_class=HTMLResponse)
async def get_index():
    """Serve a minimal HTML page for manually exercising ``/ws/stream``.

    The page opens the WebSocket, sends one ``audio_chunk`` test message on
    connect, and appends every server message to the output panel as plain
    text (never as HTML, since the server echoes client-supplied fields).
    """
    # f-string: literal CSS/JS braces are doubled ({{ }}); only {VERSION} and
    # {COMMIT_SHA} are interpolated.
    html_content = f"""
<!DOCTYPE html>
<html>
<head>
    <title>STT GPU Service Python v4</title>
    <style>
        body {{ font-family: Arial, sans-serif; margin: 40px; }}
        .container {{ max-width: 800px; margin: 0 auto; }}
        .status {{ background: #f0f0f0; padding: 20px; border-radius: 8px; margin: 20px 0; }}
        button {{ padding: 10px 20px; margin: 5px; background: #007bff; color: white; border: none; border-radius: 4px; cursor: pointer; }}
        button:disabled {{ background: #ccc; }}
        #output {{ background: #f8f9fa; padding: 15px; border-radius: 4px; margin-top: 20px; }}
        #output pre {{ margin: 0 0 10px; white-space: pre-wrap; }}
        .error {{ color: red; }}
        .version {{ font-size: 0.8em; color: #666; margin-top: 20px; }}
    </style>
</head>
<body>
    <div class="container">
        <h1>🎙️ STT GPU Service Python v4</h1>
        <p>Real-time WebSocket speech transcription service</p>
        <div class="status">
            <h3>WebSocket Streaming Test</h3>
            <button onclick="startWebSocket()" id="startBtn">Connect WebSocket</button>
            <button onclick="stopWebSocket()" disabled id="stopBtn">Disconnect</button>
            <p>Status: <span id="wsStatus">Disconnected</span></p>
        </div>
        <div id="output">
            <p>Transcription output will appear here...</p>
        </div>
        <div class="version">
            v{VERSION} (SHA: {COMMIT_SHA})
        </div>
    </div>
    <script>
        let ws = null;

        // Append via textContent so server-echoed data is never parsed as HTML
        // (and without re-parsing the whole panel on every message).
        function appendOutput(text, isError) {{
            const entry = document.createElement('pre');
            entry.textContent = text;
            if (isError) {{
                entry.className = 'error';
            }}
            document.getElementById('output').appendChild(entry);
        }}

        function setConnected(connected) {{
            document.getElementById('wsStatus').textContent = connected ? 'Connected' : 'Disconnected';
            document.getElementById('startBtn').disabled = connected;
            document.getElementById('stopBtn').disabled = !connected;
        }}

        function startWebSocket() {{
            // Ignore repeat clicks while a socket is still open/connecting,
            // otherwise the earlier socket is orphaned.
            if (ws && ws.readyState !== WebSocket.CLOSED) {{
                return;
            }}
            const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
            const wsUrl = `${{protocol}}//${{window.location.host}}/ws/stream`;
            ws = new WebSocket(wsUrl);
            ws.onopen = function(event) {{
                setConnected(true);
                // Send test message
                ws.send(JSON.stringify({{
                    type: 'audio_chunk',
                    data: 'test_audio_data',
                    timestamp: Date.now()
                }}));
            }};
            ws.onmessage = function(event) {{
                let text;
                try {{
                    text = JSON.stringify(JSON.parse(event.data), null, 2);
                }} catch (e) {{
                    text = String(event.data);
                }}
                appendOutput(text, false);
            }};
            ws.onclose = function(event) {{
                setConnected(false);
            }};
            ws.onerror = function(event) {{
                // The error Event carries no useful message; details are in the console.
                appendOutput('WebSocket Error (see browser console for details)', true);
            }};
        }}

        function stopWebSocket() {{
            if (ws) {{
                ws.close();
            }}
        }}
    </script>
</body>
</html>
"""
    return HTMLResponse(content=html_content)
@app.websocket("/ws/stream")
async def websocket_endpoint(websocket: WebSocket):
    """Handle a real-time audio streaming session on ``/ws/stream``.

    Protocol (JSON text frames):
      * on connect the server sends ``{"type": "connection", ...}``;
      * ``{"type": "audio_chunk", "timestamp": ...}`` is answered with
        ``{"type": "transcription", ...}`` (currently a mock result);
      * ``{"type": "ping"}`` is answered with ``{"type": "pong", ...}``;
      * malformed frames, non-object JSON and unknown types are answered with
        ``{"type": "error", ...}`` and the session stays open.

    A client disconnect ends the session quietly; any other unexpected error
    is logged and the socket is closed with code 1011.
    """
    await websocket.accept()
    logger.info("WebSocket connection established")

    async def send_error(message: str) -> None:
        # Shared shape for every error frame sent to the client.
        await websocket.send_json({
            "type": "error",
            "message": message,
            "timestamp": time.time(),
        })

    try:
        # Send initial connection confirmation
        await websocket.send_json({
            "type": "connection",
            "status": "connected",
            "message": "STT WebSocket ready for audio chunks",
            "chunk_size_ms": 80,
            "expected_sample_rate": 16000,
        })
        while True:
            raw = await websocket.receive_text()
            # Parse here instead of receive_json() so one malformed frame is
            # reported to the client rather than tearing down the session.
            try:
                data = json.loads(raw)
            except json.JSONDecodeError as e:
                await send_error(f"Invalid JSON: {e}")
                continue
            if not isinstance(data, dict):
                await send_error("Message must be a JSON object")
                continue

            msg_type = data.get("type")
            if msg_type == "audio_chunk":
                try:
                    # TODO: real pipeline -- decode the base64 audio, convert it
                    # to a numpy array, run it through the STT model and return
                    # the text. For now, mock processing.
                    transcription = f"Mock transcription for chunk at {data.get('timestamp', 'unknown')}"
                    await websocket.send_json({
                        "type": "transcription",
                        "text": transcription,
                        "timestamp": time.time(),
                        "chunk_id": data.get("timestamp"),
                        "confidence": 0.95,
                    })
                except WebSocketDisconnect:
                    # Client went away mid-reply; don't try to send an error frame.
                    raise
                except Exception as e:
                    await send_error(f"Processing error: {str(e)}")
            elif msg_type == "ping":
                await websocket.send_json({
                    "type": "pong",
                    "timestamp": time.time(),
                })
            else:
                await send_error(f"Unknown message type: {msg_type!r}")
    except WebSocketDisconnect:
        logger.info("WebSocket connection closed")
    except Exception:
        logger.exception("WebSocket error")
        # A Close frame reason must fit in a 125-byte control frame, and raw
        # exception text can leak internals, so send a short fixed reason.
        try:
            await websocket.close(code=1011, reason="Internal server error")
        except Exception:
            # The socket may already be closed (e.g. the failure was a send on a
            # dead connection); nothing more can be done for this client.
            logger.debug("Could not close WebSocket after error", exc_info=True)
@app.post("/api/transcribe")
async def api_transcribe(audio_file: Optional[str] = None):
    """REST endpoint for testing; returns a mock transcription.

    Args:
        audio_file: Audio payload as a string. As a plain scalar parameter,
            FastAPI reads it from the query string.

    Returns:
        A dict with a mock ``transcription`` echoing up to the first 50
        characters of ``audio_file``, plus ``timestamp``, ``version`` and
        ``method``.

    Raises:
        HTTPException: 400 when ``audio_file`` is missing or empty.
    """
    if not audio_file:
        raise HTTPException(status_code=400, detail="No audio data provided")
    # Only mark the preview as truncated when something was actually cut off.
    preview = audio_file if len(audio_file) <= 50 else f"{audio_file[:50]}..."
    # Mock transcription
    return {
        "transcription": f"REST API transcription result for: {preview}",
        "timestamp": time.time(),
        "version": VERSION,
        "method": "REST",
    }
if __name__ == "__main__":
    # Serve the app object defined in this module. The import string "app:app"
    # would load a module named ``app`` (a different file, or none at all) when
    # this script is not named app.py. Passing the object means uvicorn's
    # reload/workers options, which require an import string, are unavailable.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
        access_log=True,
    )