| """
|
| Custom Inference API handler for the Sinama audio classifier.
|
|
|
| Receives a raw audio file (WAV, MP3, etc.), extracts Mel Spectrogram
|
| features, runs inference through the CNN, and returns predicted class
|
| probabilities.
|
| """
|
|
|
import base64
import json
import os
import tempfile

import librosa
import numpy as np
import tensorflow as tf
|
|
|
|
|
class EndpointHandler:
    """Hugging Face Inference Endpoints handler for the Sinama audio classifier.

    Loads a Keras CNN, its label map, and its preprocessing config from
    *path*; converts incoming raw audio into a normalized Mel Spectrogram;
    and returns the top predicted classes with their probabilities.
    """

    # Number of predictions returned when the request does not specify one.
    DEFAULT_TOP_K = 5

    def __init__(self, path: str = ""):
        """Load model artifacts from *path*.

        Parameters
        ----------
        path : str
            Directory containing ``best_model.keras``, ``label_map.json``,
            and ``config.json``.
        """
        self.model = tf.keras.models.load_model(
            os.path.join(path, "best_model.keras")
        )

        # label_map.json maps stringified class indices to label names
        # (JSON keys are always strings); convert keys back to int so the
        # map can be indexed directly with argsort/argmax results.
        with open(os.path.join(path, "label_map.json"), "r", encoding="utf-8") as f:
            self.label_map = {int(k): v for k, v in json.load(f).items()}

        # Preprocessing parameters (sample_rate, duration, n_mels, n_fft,
        # hop_length) — these must match the values used during training.
        with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as f:
            self.cfg = json.load(f)

    @staticmethod
    def _pad_or_trim(waveform: np.ndarray, target_len: int) -> np.ndarray:
        """Zero-pad at the end, or truncate, *waveform* to exactly *target_len* samples."""
        if len(waveform) < target_len:
            return np.pad(waveform, (0, target_len - len(waveform)))
        return waveform[:target_len]

    def preprocess(self, audio_bytes: bytes) -> np.ndarray:
        """Convert raw audio bytes into a model-ready Mel Spectrogram array.

        Parameters
        ----------
        audio_bytes : bytes
            Encoded audio (WAV, MP3, etc. — anything librosa can decode).

        Returns
        -------
        np.ndarray
            Standardized spectrogram with batch and channel axes added,
            shape ``(1, n_mels, time_frames, 1)``.

        Raises
        ------
        ValueError
            If *audio_bytes* is empty.
        """
        if not audio_bytes:
            raise ValueError("Empty audio payload")

        sr = self.cfg["sample_rate"]
        target_len = int(sr * self.cfg["duration"])

        # librosa decodes most reliably from a real file path, so round-trip
        # the bytes through a temp file; unlink it even if decoding fails.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            waveform, _ = librosa.load(tmp_path, sr=sr, mono=True)
        finally:
            os.unlink(tmp_path)

        # Fix the clip to the training duration so the CNN input shape is constant.
        waveform = self._pad_or_trim(waveform, target_len)

        mel = librosa.feature.melspectrogram(
            y=waveform,
            sr=sr,
            n_mels=self.cfg["n_mels"],
            n_fft=self.cfg["n_fft"],
            hop_length=self.cfg["hop_length"],
        )
        mel_db = librosa.power_to_db(mel, ref=np.max)

        # Per-sample standardization; the epsilon guards against silent
        # clips whose spectrogram has zero variance.
        mel_db = (mel_db - mel_db.mean()) / (mel_db.std() + 1e-9)

        # Add the batch and channel axes expected by the CNN.
        return mel_db[np.newaxis, ..., np.newaxis]

    def _format_predictions(self, preds: np.ndarray, top_k: int) -> list:
        """Return the *top_k* highest-scoring classes as label/score dicts, best first."""
        top_indices = np.argsort(preds)[::-1][:top_k]
        return [
            {"label": self.label_map[int(i)], "score": round(float(preds[i]), 4)}
            for i in top_indices
        ]

    def __call__(self, data):
        """Handle an inference request.

        Parameters
        ----------
        data : dict or bytes
            Either ``{"inputs": <base64 str or bytes>}`` — optionally with
            ``{"parameters": {"top_k": N}}`` to control how many results
            are returned — or the raw request body bytes.

        Returns
        -------
        list[dict]
            ``[{"label": "word", "score": 0.95}, ...]`` sorted by score.
        """
        top_k = self.DEFAULT_TOP_K
        if isinstance(data, dict):
            audio = data.get("inputs", data.get("body", b""))
            # Backward-compatible generalization: honor the standard
            # "parameters" field if the caller requests a custom top_k.
            top_k = int(data.get("parameters", {}).get("top_k", top_k))
        else:
            audio = data

        # JSON payloads carry audio as a base64 string rather than raw bytes.
        if isinstance(audio, str):
            audio = base64.b64decode(audio)

        features = self.preprocess(audio)
        preds = self.model.predict(features, verbose=0)[0]
        return self._format_predictions(preds, top_k)
|
|
|