| import subprocess |
| import torch |
|
|
| |
| |
| |
| |
| |
|
|
| import whisperx |
| import os, gc |
|
|
| import time |
| import json |
| import base64 |
| import numpy as np |
|
|
# Sink for ffmpeg's stderr. subprocess.DEVNULL is accepted directly by
# Popen/run and avoids keeping a never-closed file handle open for the
# whole lifetime of the process.
DEVNULL = subprocess.DEVNULL
|
|
|
|
| |
| from typing import Dict, List, Any |
|
|
| import logging |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


# Sample rate (Hz) used as the default for load_audio() and when decoding
# request audio in EndpointHandler.__call__.
SAMPLE_RATE = 16000
|
|
|
|
def whisper_config():
    """Return the runtime settings used to load and run the Whisper model.

    Returns
    -------
    A ``(device, batch_size, compute_type, whisper_model)`` tuple. When CUDA
    is available the GPU settings are returned, otherwise the CPU ones.
    """
    whisper_model = "large-v3"
    if torch.cuda.is_available():
        return "cuda", 48, "float16", whisper_model
    return "cpu", 1, "int8", whisper_model
|
|
|
|
| |
| |
def ffmpeg_load_audio(filename, sr=44100, mono=False, normalize=True, in_type=np.int16, out_type=np.float32):
    """
    Decode an audio file with the ffmpeg CLI and return its samples.

    Parameters
    ----------
    filename: str
        Input passed to ``ffmpeg -i``.

    sr: int
        Sample rate requested from ffmpeg (``-ar``).

    mono: bool
        Request one channel when True, two channels otherwise.

    normalize: bool
        When ``out_type`` is a floating type, divide the samples by their
        absolute peak. When False and ``in_type`` is an integer type, divide
        by that integer type's maximum value instead.

    in_type: NumPy scalar type
        Raw PCM sample type requested from ffmpeg. Must be one of
        float64, float32, int16, int32 or uint32.

    out_type: NumPy scalar type
        Type of the returned array.

    Returns
    -------
    A NumPy array of ``out_type``: shape ``(samples,)`` when ``mono`` is True,
    ``(2, samples)`` otherwise. An empty array is returned when ffmpeg
    produced no output.

    Raises
    ------
    KeyError
        If ``in_type`` is not one of the supported sample types.
    """
    channels = 1 if mono else 2
    format_strings = {
        np.float64: "f64le",
        np.float32: "f32le",
        np.int16: "s16le",
        np.int32: "s32le",
        np.uint32: "u32le",
    }
    format_string = format_strings[in_type]
    command = [
        "ffmpeg",
        "-i",
        filename,
        "-f",
        format_string,
        "-acodec",
        "pcm_" + format_string,
        "-ar",
        str(sr),
        "-ac",
        str(channels),
        "-",
    ]
    # run() drains stdout in one go and waits for ffmpeg to exit, so no
    # zombie process is left behind and no quadratic bytes concatenation
    # is needed. The exit status is deliberately not checked: a failed
    # decode simply yields an empty array, as before.
    raw = subprocess.run(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
        check=False,
    ).stdout

    # frombuffer replaces the deprecated binary mode of np.fromstring;
    # astype() copies, so the result is writable for the in-place scaling.
    audio = np.frombuffer(raw, dtype=in_type).astype(out_type)
    if channels > 1:
        audio = audio.reshape((-1, channels)).transpose()
    if audio.size == 0:
        # Same return type as the non-empty path (an array, not a tuple).
        return audio
    if issubclass(out_type, np.floating):
        if normalize:
            peak = np.abs(audio).max()
            if peak > 0:
                audio /= peak
        elif issubclass(in_type, np.integer):
            audio /= np.iinfo(in_type).max
    return audio
|
|
|
|
| |
def ffmpeg_read(bpayload: bytes, sampling_rate: int) -> np.array:
    """
    Decode in-memory audio bytes through ffmpeg.

    The payload is piped to ffmpeg's stdin and converted to single-channel
    32-bit float PCM at ``sampling_rate``; the result is returned as a
    float32 NumPy array. A ValueError is raised when the ffmpeg executable
    cannot be found or when ffmpeg produces no samples.
    """
    command = [
        "ffmpeg",
        "-i",
        "pipe:0",
        "-ac",
        "1",
        "-ar",
        f"{sampling_rate}",
        "-f",
        "f32le",
        "-hide_banner",
        "-loglevel",
        "quiet",
        "pipe:1",
    ]

    try:
        completed = subprocess.run(command, input=bpayload, stdout=subprocess.PIPE)
    except FileNotFoundError as error:
        raise ValueError("ffmpeg was not found but is required to load audio files from filename") from error

    waveform = np.frombuffer(completed.stdout, np.float32)
    if waveform.shape[0] == 0:
        raise ValueError(
            "Soundfile is either not in the correct format or is malformed. Ensure that the soundfile has "
            "a valid audio file extension (e.g. wav, flac or mp3) and is not corrupted. If reading from a remote "
            "URL, ensure that the URL is the full address to **download** the audio file."
        )
    return waveform
|
|
|
|
| |
def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
    Read an audio file as a mono waveform via the ffmpeg CLI.

    Parameters
    ----------
    file: str
        The audio file to open

    sr: int
        The sample rate ffmpeg is asked to output

    Returns
    -------
    A NumPy float32 array holding the waveform, scaled from 16-bit PCM by
    dividing by 32768.

    Raises
    ------
    RuntimeError
        If ffmpeg exits with a non-zero status; the message carries ffmpeg's
        stderr output.
    """
    ffmpeg_args = [
        "ffmpeg",
        "-nostdin",
        "-threads",
        "0",
        "-i",
        file,
        "-f",
        "s16le",
        "-ac",
        "1",
        "-acodec",
        "pcm_s16le",
        "-ar",
        str(sr),
        "-",
    ]
    try:
        completed = subprocess.run(ffmpeg_args, capture_output=True, check=True)
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    pcm = np.frombuffer(completed.stdout, np.int16)
    return pcm.flatten().astype(np.float32) / 32768.0
|
|
|
|
def display_gpu_infos():
    """Return a one-line summary of the CUDA devices, or "NO CUDA" without CUDA."""
    if not torch.cuda.is_available():
        return "NO CUDA"

    # Queried in the same order as they appear in the resulting string.
    fields = [
        "torch.cuda.current_device(): " + str(torch.cuda.current_device()),
        "torch.cuda.device(0): " + str(torch.cuda.device(0)),
        "torch.cuda.device_count(): " + str(torch.cuda.device_count()),
        "torch.cuda.get_device_name(0): " + str(torch.cuda.get_device_name(0)),
    ]
    return ", ".join(fields)
|
|
|
|
class EndpointHandler:
    """Endpoint handler chaining WhisperX transcription, alignment and diarization.

    The Whisper model and the diarization pipeline are loaded once in
    ``__init__``; each call then decodes the request audio and runs the
    stages enabled in the request options.
    """

    # Environment variables probed, in order, for the Hugging Face access
    # token needed by the diarization pipeline.
    # SECURITY: the token used to be hard-coded in this file. It must be
    # supplied through the environment instead, and the previously committed
    # token should be revoked.
    _HF_TOKEN_ENV_VARS = ("HF_TOKEN", "HUGGING_FACE_HUB_TOKEN")

    def __init__(self, path=""):
        """Load the Whisper model and the speaker-diarization pipeline.

        Args:
            path: unused by this handler; kept for interface compatibility.
        """
        device, _batch_size, compute_type, whisper_model = whisper_config()
        self.model = whisperx.load_model(whisper_model, device=device, compute_type=compute_type, language="fr")

        hf_token = next(
            (os.environ[name] for name in self._HF_TOKEN_ENV_VARS if os.environ.get(name)),
            None,
        )
        if hf_token is None:
            logger.warning(
                "No Hugging Face token found in %s; loading the diarization model may fail.",
                " / ".join(self._HF_TOKEN_ENV_VARS),
            )

        self.diarize_model = whisperx.DiarizationPipeline(
            "pyannote/speaker-diarization-3.1", use_auth_token=hf_token, device=device
        )

        logger.info("Model for diarization initialized")

    def __call__(self, data: Any) -> Dict[str, Any]:
        """Run the requested processing stages on a base64-encoded audio file.

        Args:
            data: request dictionary (mutated: the keys below are popped).
                - "inputs": the base64-encoded audio when the "transcription"
                  option is set; otherwise a pair
                  ``(base64_audio, transcription)`` carrying an existing
                  transcription result to align and/or diarize.
                - "options": flags "info" (also print results and timings),
                  "transcription", "alignment", "diarization"; all default
                  to False.
                - "parameters": "language" (default "fr"), "min_speakers"
                  (default 2), "max_speakers" (default 25).

        Returns:
            The transcription result produced by the last stage that ran.
            When transcription yields no usable first segment,
            ``{"transcription": segments}`` is returned early instead.

        Raises:
            ValueError: if no processing stage is enabled in "options".
        """
        stage_start = time.time()

        logger.info("--------------- CONFIGURATION ------------------------")
        device, batch_size, _compute_type, _whisper_model = whisper_config()
        logger.info(display_gpu_infos())

        # A missing (or null) "parameters"/"options" entry means "all defaults".
        parameters = data.pop("parameters", None) or {}
        options = data.pop("options", None) or {}

        info = options.get("info", False)
        transcribe = options.get("transcription", False)
        alignment = options.get("alignment", False)
        diarization = options.get("diarization", False)
        language = parameters.get("language", "fr")
        min_speakers = parameters.get("min_speakers", 2)
        max_speakers = parameters.get("max_speakers", 25)

        if not (transcribe or alignment or diarization):
            raise ValueError(
                "Nothing to do: enable at least one of the 'transcription', "
                "'alignment' or 'diarization' options."
            )

        payload = data.pop("inputs", data)
        if transcribe:
            inputs_encoded = payload
            transcription = None
        else:
            # Without a transcription stage the caller must supply an
            # existing transcription next to the audio.
            inputs_encoded, transcription = payload

        audio_nparray = self._decode_audio(inputs_encoded)
        # Restart the timer after every stage so each reported duration
        # covers that stage only.
        stage_start = self._report_elapsed("audio processing", stage_start, info)

        if transcribe:
            self._free_memory()
            logger.info("--------------- STARTING TRANSCRIPTION ------------------------")
            transcription = self.model.transcribe(audio_nparray, batch_size=batch_size, language=language)
            segments = transcription["segments"]
            self._report_segments(segments, info)

            if not segments or "text" not in segments[0]:
                logger.warning("No transcription")
                return {"transcription": segments}

            stage_start = self._report_elapsed("audio transcription", stage_start, info)

        if alignment:
            self._free_memory()
            logger.info("--------------- STARTING ALIGNMENT ------------------------")
            model_a, metadata = whisperx.load_align_model(language_code=transcription["language"], device=device)
            transcription = whisperx.align(
                transcription["segments"], model_a, metadata, audio_nparray, device, return_char_alignments=False
            )
            # Drop the alignment model so the next cleanup can reclaim it.
            del model_a
            self._report_segments(transcription["segments"], info)

            stage_start = self._report_elapsed("alignment", stage_start, info)

        if diarization:
            self._free_memory()
            logger.info("--------------- STARTING DIARIZATION ------------------------")
            if not transcription:
                logger.warning("No transcription to diarize")

            diarize_segments = self.diarize_model(audio_nparray, min_speakers=min_speakers, max_speakers=max_speakers)
            if info:
                print(diarize_segments)
            else:
                logger.info(diarize_segments)

            transcription = whisperx.assign_word_speakers(diarize_segments, transcription)

            stage_start = self._report_elapsed("audio diarization", stage_start, info)

        self._free_memory()
        return transcription

    @staticmethod
    def _decode_audio(inputs_encoded):
        """Decode base64 audio and load it as a mono waveform at SAMPLE_RATE."""
        # Local import: tempfile is not among the module's top-level imports.
        import tempfile

        audio_bytes = base64.b64decode(inputs_encoded)
        logger.info("inputs decoded.")

        # A unique file per request avoids collisions between concurrent
        # calls; delete=False lets ffmpeg reopen the file by name.
        tmp_file = tempfile.NamedTemporaryFile(suffix=".tmp", delete=False)
        try:
            with tmp_file:
                tmp_file.write(audio_bytes)
            logger.info("inputs saved.")

            audio = load_audio(tmp_file.name, sr=SAMPLE_RATE)
            logger.info("inputs loaded as mono 16kHz.")
        finally:
            # Always clean up, even when decoding fails.
            os.remove(tmp_file.name)
            logger.info("temp file removed.")
        return audio

    @staticmethod
    def _free_memory():
        """Run the garbage collector and release cached CUDA memory."""
        gc.collect()
        torch.cuda.empty_cache()

    @staticmethod
    def _report_segments(segments, info):
        """Print the segments when ``info`` is set, otherwise log a shorter slice."""
        if info:
            print(segments[0:10_000])
        else:
            logger.info(segments[0:1_000])

    @staticmethod
    def _report_elapsed(label, started, info):
        """Log (and print when ``info`` is set) the stage duration; return a new start time."""
        elapsed_time = time.time() - started
        message = f"TIME for {label} : {elapsed_time:.2f} seconds"
        logger.info(message)
        if info:
            print(message)
        return time.time()
|
|