import logging
import os
import subprocess
import tempfile

import uvicorn
from fastapi import FastAPI, UploadFile, File, HTTPException
from fastapi.middleware.cors import CORSMiddleware

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="Speaker Diarization Service")

# Wildcard CORS: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Assigned by load_pipeline() at startup; stays None when loading fails.
pipeline = None


# ─────────────────────────────────────────────────────────────
# Convert webm → wav (REQUIRED for pyannote)
# ─────────────────────────────────────────────────────────────
def convert_to_wav(input_path):
    """Transcode an audio file to a mono, 16 kHz WAV file using ffmpeg.

    Args:
        input_path: Path of the audio file to convert.

    Returns:
        Path of the WAV file that was written. It is derived from
        ``input_path`` by swapping the extension for ``.wav`` and is always
        different from ``input_path``.

    Raises:
        RuntimeError: If ffmpeg cannot be started or exits with a non-zero
            status.
    """
    # Swap only the real extension. A plain str.replace(".webm", ".wav")
    # leaves non-webm names untouched, which makes ffmpeg's output path equal
    # to its input path (ffmpeg refuses that), and it also rewrites ".webm"
    # wherever it occurs earlier in the path.
    stem, ext = os.path.splitext(input_path)
    if ext.lower() == ".wav":
        output_path = stem + ".converted.wav"
    else:
        output_path = stem + ".wav"

    command = [
        "ffmpeg",
        "-hide_banner",
        "-loglevel", "error",  # keep stderr limited to real errors
        "-y",
        "-i", input_path,
        "-ac", "1",  # mono
        "-ar", "16000",  # 16kHz (required)
        output_path,
    ]

    try:
        subprocess.run(
            command,
            check=True,
            stdout=subprocess.DEVNULL,
            stderr=subprocess.PIPE,  # captured so failures can be diagnosed
        )
    except (subprocess.CalledProcessError, OSError) as e:
        # OSError covers the case where the ffmpeg binary cannot be executed;
        # only CalledProcessError carries captured stderr.
        raw_stderr = getattr(e, "stderr", None)
        details = raw_stderr.decode(errors="replace").strip() if raw_stderr else ""
        logger.error("FFmpeg conversion failed: %s %s", e, details)

        # Best-effort removal of a partially written output; the caller never
        # learns this path when conversion fails, so it cannot clean it up.
        try:
            os.unlink(output_path)
        except OSError:
            pass

        raise RuntimeError("Audio conversion failed (ffmpeg error)") from e

    return output_path


# ─────────────────────────────────────────────────────────────
# Load diarization pipeline
# ─────────────────────────────────────────────────────────────
@app.on_event("startup")
async def load_pipeline():
    """Load the pyannote diarization pipeline into the module-level ``pipeline``.

    Reads the Hugging Face token from the ``HF_TOKEN`` environment variable.
    When the token is missing or loading raises, the failure is logged and
    ``pipeline`` is left as ``None`` instead of aborting startup.
    """
    global pipeline

    hf_token = os.environ.get("HF_TOKEN")
    logger.info(f"HF_TOKEN exists: {bool(hf_token)}")

    if not hf_token:
        logger.error("HF_TOKEN not set — diarization will not work")
        return

    try:
        # Imported lazily so the module itself imports without these packages.
        from pyannote.audio import Pipeline
        import torch

        logger.info("Loading pyannote speaker diarization pipeline...")

        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization-3.1",
            use_auth_token=hf_token
        )

        pipeline = pipeline.to(torch.device("cpu"))

        logger.info("Pipeline loaded successfully on cpu")

    except Exception as e:
        # logger.exception records the message together with the traceback.
        logger.exception("Failed to load pipeline: %s", e)
        pipeline = None


# ─────────────────────────────────────────────────────────────
# Health check
# ─────────────────────────────────────────────────────────────
@app.get("/health")
def health():
    """Report that the service is up and whether the pipeline is loaded."""
    return {
        "status": "ok",
        "pipeline_loaded": pipeline is not None
    }


# ─────────────────────────────────────────────────────────────
# Diarization endpoint
# ─────────────────────────────────────────────────────────────
@app.post("/diarize")
async def diarize(
    file: UploadFile = File(...),
    num_speakers: int = None
):
    """Run speaker diarization on an uploaded audio file.

    Args:
        file: Uploaded audio. Its filename extension (default ``.webm``) is
            used as the suffix of the temporary file.
        num_speakers: Optional speaker-count hint. It is forwarded to the
            pipeline only when greater than 1.

    Returns:
        A dict with ``segments`` (a list of ``start``/``end``/``speaker``
        entries, times rounded to 3 decimals) and ``num_speakers_detected``.

    Raises:
        HTTPException: 503 when the pipeline is not loaded, 400 when the
            upload is empty, 500 for any other failure.
    """
    if pipeline is None:
        raise HTTPException(
            status_code=503,
            detail="Diarization pipeline not loaded. Check HF_TOKEN and logs."
        )

    suffix = os.path.splitext(file.filename or "audio.webm")[1] or ".webm"

    tmp_path = None
    wav_path = None

    try:
        # Save uploaded file
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            # Record the path before reading/writing so the `finally` block
            # removes the file even when one of those steps raises.
            tmp_path = tmp.name
            content = await file.read()
            tmp.write(content)

        if not content:
            # Reject early with a client error instead of letting the
            # conversion step fail with an opaque 500.
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        logger.info(
            f"Diarizing {file.filename} ({len(content)/1024:.1f}KB), "
            f"num_speakers={num_speakers}"
        )

        # ── Convert to WAV (CRITICAL FIX) ───────────────────────
        wav_path = convert_to_wav(tmp_path)

        diarize_kwargs = {}
        if num_speakers and num_speakers > 1:
            diarize_kwargs["num_speakers"] = num_speakers

        # NOTE(review): this call runs synchronously inside an async handler,
        # so other requests wait while it executes. Moving it to a worker
        # thread would need confirmation that the pipeline tolerates
        # concurrent calls.
        diarization = pipeline(wav_path, **diarize_kwargs)

        segments = [
            {
                "start": round(turn.start, 3),
                "end": round(turn.end, 3),
                "speaker": speaker
            }
            for turn, _, speaker in diarization.itertracks(yield_label=True)
        ]
        speakers_seen = {segment["speaker"] for segment in segments}

        logger.info(f"Done: {len(segments)} segments, {len(speakers_seen)} speakers")

        return {
            "segments": segments,
            "num_speakers_detected": len(speakers_seen)
        }

    except HTTPException:
        # Deliberate HTTP errors keep their own status code.
        raise
    except Exception as e:
        logger.exception("Diarization error: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e

    finally:
        # Cleanup temp files
        for leftover in (tmp_path, wav_path):
            if leftover and os.path.exists(leftover):
                os.unlink(leftover)


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)