# sinama-translator / handler.py
# Uploaded by der02 with huggingface_hub (commit 965496a, verified)
"""
Custom Inference API handler for the Sinama audio classifier.
Receives a raw audio file (WAV, MP3, etc.), extracts Mel Spectrogram
features, runs inference through the CNN, and returns predicted class
probabilities.
"""
import json
import os
import tempfile
import librosa
import numpy as np
import tensorflow as tf
class EndpointHandler:
    """HF Inference Endpoints handler for the Sinama audio classifier.

    Loads a Keras CNN, a label map, and feature-extraction config from
    the model directory, then serves audio-classification requests:
    raw audio bytes in, top-k class probabilities out.
    """

    def __init__(self, path: str = ""):
        """Load model and metadata.

        Parameters
        ----------
        path : str
            Model directory on the endpoint; must contain
            ``best_model.keras``, ``label_map.json`` and ``config.json``.
        """
        model_path = os.path.join(path, "best_model.keras")
        self.model = tf.keras.models.load_model(model_path)
        # label_map.json maps stringified class indices -> label names;
        # convert keys back to int so they align with argsort indices.
        with open(os.path.join(path, "label_map.json"), "r") as f:
            raw = json.load(f)
        self.label_map = {int(k): v for k, v in raw.items()}
        # config.json carries the Mel-spectrogram extraction parameters
        # (sample_rate, duration, n_mels, n_fft, hop_length).
        with open(os.path.join(path, "config.json"), "r") as f:
            self.cfg = json.load(f)

    def preprocess(self, audio_bytes: bytes) -> np.ndarray:
        """Convert raw audio bytes into a normalised Mel Spectrogram.

        Parameters
        ----------
        audio_bytes : bytes
            Encoded audio in any format librosa/soundfile can read
            (WAV, MP3, ...).

        Returns
        -------
        np.ndarray
            Array of shape (1, n_mels, time, 1) ready for the CNN.

        Raises
        ------
        ValueError
            If ``audio_bytes`` is empty.
        """
        if not audio_bytes:
            # Fail fast with a clear message instead of letting librosa
            # raise a cryptic decode error on an empty temp file.
            raise ValueError("Received empty audio payload.")

        sr = self.cfg["sample_rate"]
        duration = self.cfg["duration"]
        n_mels = self.cfg["n_mels"]
        n_fft = self.cfg["n_fft"]
        hop = self.cfg["hop_length"]
        target_len = int(sr * duration)

        # librosa cannot read from a bytes object directly, so spill the
        # payload to a temp file.  delete=False + manual unlink keeps this
        # portable to Windows, where a still-open NamedTemporaryFile
        # cannot be reopened by name.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            tmp.write(audio_bytes)
            tmp_path = tmp.name
        try:
            waveform, _ = librosa.load(tmp_path, sr=sr, mono=True)
        finally:
            os.unlink(tmp_path)

        # Pad with trailing silence / trim so every clip has exactly
        # sr * duration samples, matching the CNN's fixed input width.
        if len(waveform) < target_len:
            waveform = np.pad(waveform, (0, target_len - len(waveform)))
        else:
            waveform = waveform[:target_len]

        mel = librosa.feature.melspectrogram(
            y=waveform, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop
        )
        mel_db = librosa.power_to_db(mel, ref=np.max)

        # Per-clip standardisation; the epsilon guards against division
        # by zero on silent clips (std == 0).
        mean, std = mel_db.mean(), mel_db.std()
        mel_db = (mel_db - mean) / (std + 1e-9)

        # Add batch + channel dims -> (1, freq, time, 1)
        return mel_db[np.newaxis, ..., np.newaxis]

    def __call__(self, data):
        """Handle an inference request.

        Parameters
        ----------
        data : dict or bytes
            Either ``{"inputs": <base64 str or bytes>}`` — optionally with
            ``{"parameters": {"top_k": int}}`` — or the raw request body
            bytes.

        Returns
        -------
        list[dict]
            ``[{"label": "word", "score": 0.95}, ...]`` for the top-k
            classes, sorted by descending confidence (default k = 5).
        """
        top_k = 5
        if isinstance(data, dict):
            audio = data.get("inputs", data.get("body", b""))
            # Standard HF payload shape allows optional inference
            # parameters; honour top_k but keep 5 as the default.
            params = data.get("parameters") or {}
            top_k = int(params.get("top_k", top_k))
        else:
            audio = data

        # JSON clients (e.g. the inference widget) send base64-encoded audio.
        if isinstance(audio, str):
            import base64
            audio = base64.b64decode(audio)

        features = self.preprocess(audio)
        preds = self.model.predict(features, verbose=0)[0]

        # Top-k predictions sorted by confidence, highest first.
        top_indices = np.argsort(preds)[::-1][:top_k]
        return [
            {"label": self.label_map[int(i)], "score": round(float(preds[i]), 4)}
            for i in top_indices
        ]