File size: 3,329 Bytes
965496a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
"""

Custom Inference API handler for the Sinama audio classifier.

Receives a raw audio file (WAV, MP3, etc.), extracts log-scaled Mel
spectrogram features, runs inference through the CNN, and returns the
top-scoring class labels with their scores.

"""

import json
import os
import tempfile

import librosa
import numpy as np
import tensorflow as tf


class EndpointHandler:
    """HF Inference Endpoints handler for the Sinama audio classifier.

    Loads a Keras model, its label map and its preprocessing config from the
    model directory; each request's audio is turned into a standardised Mel
    spectrogram, scored by the model, and the highest-scoring labels are
    returned.
    """

    # Number of predictions returned when the request does not specify
    # ``parameters.top_k``.
    DEFAULT_TOP_K = 5

    def __init__(self, path: str = ""):
        """Load the model, label map and preprocessing config.

        Parameters
        ----------
        path : str
            Model directory on the endpoint. It must contain
            ``best_model.keras``, ``label_map.json`` and ``config.json``.
        """
        model_path = os.path.join(path, "best_model.keras")
        self.model = tf.keras.models.load_model(model_path)

        # JSON is UTF-8 by specification; be explicit so non-ASCII labels are
        # not mis-decoded on hosts whose locale encoding differs.
        with open(os.path.join(path, "label_map.json"), "r", encoding="utf-8") as f:
            raw = json.load(f)
        # JSON object keys are always strings; predictions are indexed by int.
        self.label_map = {int(k): v for k, v in raw.items()}

        # Keys read by preprocess(): sample_rate, duration, n_mels, n_fft,
        # hop_length.
        with open(os.path.join(path, "config.json"), "r", encoding="utf-8") as f:
            self.cfg = json.load(f)

    def preprocess(self, audio_bytes: bytes) -> np.ndarray:
        """Convert raw audio bytes into a standardised Mel spectrogram batch.

        The audio is loaded with librosa (resampled to ``cfg["sample_rate"]``
        and mixed down to mono), zero-padded or trimmed to
        ``int(sample_rate * duration)`` samples, turned into a Mel power
        spectrogram in dB relative to its peak, and standardised to zero mean
        and unit variance over the whole clip.

        Parameters
        ----------
        audio_bytes : bytes
            Encoded audio file contents.

        Returns
        -------
        np.ndarray
            The spectrogram with batch and channel axes added:
            ``(1, freq, time, 1)``.
        """
        sr = self.cfg["sample_rate"]
        duration = self.cfg["duration"]
        n_mels = self.cfg["n_mels"]
        n_fft = self.cfg["n_fft"]
        hop = self.cfg["hop_length"]
        target_len = int(sr * duration)

        # librosa needs a file path. The temp file is created before the
        # try/finally so it is removed even when the write itself fails.
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        try:
            with tmp:
                tmp.write(audio_bytes)
            # NOTE(review): the ".wav" suffix is only a file name; whether
            # other formats (e.g. MP3) decode depends on the audio backends
            # installed for librosa -- confirm on the endpoint image.
            waveform, _ = librosa.load(tmp.name, sr=sr, mono=True)
        finally:
            os.unlink(tmp.name)

        # Pad / trim to a fixed length so every input yields the same frame count.
        if len(waveform) < target_len:
            waveform = np.pad(waveform, (0, target_len - len(waveform)))
        else:
            waveform = waveform[:target_len]

        # Mel spectrogram, in dB relative to the loudest bin.
        mel = librosa.feature.melspectrogram(
            y=waveform, sr=sr, n_mels=n_mels, n_fft=n_fft, hop_length=hop
        )
        mel_db = librosa.power_to_db(mel, ref=np.max)

        # Per-clip standardisation; epsilon avoids division by zero on
        # constant (e.g. silent) input.
        mean, std = mel_db.mean(), mel_db.std()
        mel_db = (mel_db - mean) / (std + 1e-9)

        # Add batch + channel dims  -> (1, freq, time, 1)
        return mel_db[np.newaxis, ..., np.newaxis]

    def __call__(self, data):
        """Handle an inference request.

        Parameters
        ----------
        data : dict or bytes
            Either ``{"inputs": <base64 str or bytes>}`` for audio data
            (``"body"`` is accepted as a fallback key), or the raw request
            body bytes. A base64 string may also be a ``data:`` URI. An
            optional ``{"parameters": {"top_k": int}}`` entry sets how many
            predictions are returned (default ``DEFAULT_TOP_K``).

        Returns
        -------
        list of dict
            ``[{"label": "word", "score": 0.95}, ...]`` sorted by descending
            score, with at most ``top_k`` entries.

        Raises
        ------
        ValueError
            If the request carries no audio, the base64 payload is malformed,
            or ``top_k`` is not a positive integer.
        TypeError
            If the audio payload is neither bytes-like nor a base64 string.
        """
        # Validate the whole request before doing any expensive work.
        audio = self._extract_audio(data)
        top_k = self._requested_top_k(data)

        features = self.preprocess(audio)
        preds = self.model.predict(features, verbose=0)[0]
        return self._rank(preds, top_k)

    @staticmethod
    def _extract_audio(data) -> bytes:
        """Return the raw audio bytes carried by an inference request."""
        import base64

        if isinstance(data, dict):
            audio = data.get("inputs", data.get("body"))
        else:
            audio = data

        if audio is None:
            raise ValueError(
                "No audio found in request: expected an 'inputs' (or 'body') field."
            )

        if isinstance(audio, str):
            # Strip a "data:<mime>;base64," prefix; ':' never occurs in plain
            # base64, so bare base64 strings are unaffected.
            if audio.startswith("data:") and "," in audio:
                audio = audio.split(",", 1)[1]
            # binascii.Error (raised on bad padding) subclasses ValueError.
            audio = base64.b64decode(audio)
        elif isinstance(audio, (bytearray, memoryview)):
            audio = bytes(audio)

        if not isinstance(audio, bytes):
            raise TypeError(
                f"Audio payload must be bytes or a base64 string, got {type(audio).__name__}"
            )
        if not audio:
            raise ValueError("Request contains an empty audio payload.")
        return audio

    @classmethod
    def _requested_top_k(cls, data) -> int:
        """Return ``data["parameters"]["top_k"]`` if given, else the default."""
        params = data.get("parameters") if isinstance(data, dict) else None
        if not isinstance(params, dict) or params.get("top_k") is None:
            return cls.DEFAULT_TOP_K
        requested = params["top_k"]
        try:
            top_k = int(requested)
        except (TypeError, ValueError) as err:
            raise ValueError(
                f"'top_k' must be a positive integer, got {requested!r}"
            ) from err
        if top_k < 1:
            raise ValueError(f"'top_k' must be a positive integer, got {requested!r}")
        return top_k

    def _rank(self, preds: np.ndarray, top_k: int) -> list:
        """Map the ``top_k`` highest per-class scores to label/score dicts.

        ``preds`` is indexed by class id (presumably softmax probabilities --
        confirm against the model's output layer).
        """
        top_indices = np.argsort(preds)[::-1][:top_k]
        return [
            {"label": self.label_map[int(i)], "score": round(float(preds[i]), 4)}
            for i in top_indices
        ]