| |
| import os |
| import tempfile |
| import torch |
| from pydub import AudioSegment |
| import soundfile as sf |
| from pyannote.audio import Pipeline |
| from transformers import pipeline as hf_pipeline |
|
|
| |
# Pretrained pyannote speaker-diarization pipeline id on the Hugging Face hub.
DIAR_PYMODEL = "pyannote/speaker-diarization"
# Optional Hugging Face auth token read from the environment (pyannote models
# on the hub are gated; None means anonymous access).
HF_TOKEN = os.environ.get("HF_TOKEN", None)
# transformers device index: first CUDA GPU when available, else CPU (-1).
DEVICE = 0 if torch.cuda.is_available() else -1


# Lazily-initialized module-level singletons so the heavy models are loaded
# at most once per process (see get_diar_pipeline / get_asr_pipeline).
DIAR_PIPE = None
ASR_PIPE_CACHE = {}
|
|
def get_diar_pipeline():
    """Return the shared pyannote diarization pipeline, loading it on first use."""
    global DIAR_PIPE
    if DIAR_PIPE is not None:
        return DIAR_PIPE
    # First call: download/load the pretrained pipeline and memoize it.
    DIAR_PIPE = Pipeline.from_pretrained(DIAR_PYMODEL, use_auth_token=HF_TOKEN)
    return DIAR_PIPE
|
|
def get_asr_pipeline(model_id):
    """Return a cached transformers ASR pipeline for ``model_id``, creating it on demand."""
    try:
        return ASR_PIPE_CACHE[model_id]
    except KeyError:
        pipe = hf_pipeline("automatic-speech-recognition", model=model_id, device=DEVICE)
        ASR_PIPE_CACHE[model_id] = pipe
        return pipe
|
|
def diarize_audio_to_segments(audio_path):
    """
    Run speaker diarization on the audio file at ``audio_path``.

    Returns list of segments: [{'start': float, 'end': float, 'speaker': 'SPEAKER_00'}, ...]
    """
    diarization = get_diar_pipeline()(audio_path)
    # One dict per speaker turn, in the order pyannote yields them.
    return [
        {"start": float(turn.start), "end": float(turn.end), "speaker": label}
        for turn, _, label in diarization.itertracks(yield_label=True)
    ]
|
|
def extract_audio_segment(orig_path, start_s, end_s):
    """
    Write the [start_s, end_s) slice of the audio at ``orig_path`` to a
    temporary WAV file and return that file's path.

    The caller owns the returned file and is responsible for deleting it.
    """
    audio = AudioSegment.from_file(orig_path)
    start_ms, end_ms = int(start_s * 1000), int(end_s * 1000)
    chunk = audio[start_ms:end_ms]
    # Use mkstemp and close the descriptor immediately: the original
    # NamedTemporaryFile(delete=False) handle was never closed, leaking one
    # open fd per segment, and an open handle also prevents pydub from
    # re-opening the path for writing on Windows.
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    chunk.export(tmp_path, format="wav")
    return tmp_path
|
|
def diarized_transcribe(audio_path, model_id):
    """
    Runs diarization then ASR per speaker segment.

    Returns a list of speaker-attributed segments:
    [{'start': float, 'end': float, 'speaker': str, 'text': str}, ...]

    ASR failure on an individual segment is reported inline in that segment's
    ``text`` field ("[ASR error: ...]") rather than aborting the transcript.
    """
    segments = diarize_audio_to_segments(audio_path)
    asr = get_asr_pipeline(model_id)

    speaker_results = []
    for seg in segments:
        seg_path = extract_audio_segment(audio_path, seg["start"], seg["end"])
        try:
            try:
                out = asr(seg_path)
                # The ASR pipeline normally returns {"text": ...}; fall back
                # to str() instead of assuming dict (the original called
                # out.get() unconditionally, turning any non-dict result
                # into a spurious "[ASR error: ...]" entry).
                text = out.get("text", str(out)) if isinstance(out, dict) else str(out)
            except Exception as e:
                # Best-effort per segment: record the error and keep going.
                text = f"[ASR error: {e}]"
            speaker_results.append({
                "start": seg["start"],
                "end": seg["end"],
                "speaker": seg["speaker"],
                "text": text,
            })
        finally:
            # finally guarantees the temp chunk is removed even if an
            # unexpected error escapes; except OSError (not a bare except:,
            # which also swallowed KeyboardInterrupt/SystemExit) covers a
            # file that is already gone or unremovable.
            try:
                os.unlink(seg_path)
            except OSError:
                pass

    return speaker_results
|
|