| | """ |
| | HuggingFace Inference Endpoint Handler for SongFormer |
| | Supports binary audio input (WAV, MP3, etc.) via base64 encoding or direct bytes |
| | """ |
| |
|
| | import os |
| | import sys |
| | import io |
| | import base64 |
| | import json |
| | import tempfile |
| | from typing import Dict, Any, Union |
| | import librosa |
| | import numpy as np |
| | import torch |
| | from transformers import AutoModel |
| |
|
class EndpointHandler:
    """
    HuggingFace Inference Endpoint Handler for the SongFormer model.

    Accepts audio (WAV, MP3, FLAC, etc.) under the ``"inputs"`` key either as
    a base64-encoded string or as raw encoded-audio bytes, and returns the
    model's predicted song-structure segments.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler and load the SongFormer model.

        Args:
            path: Path to the model directory (provided by HuggingFace).
                  Falls back to the current working directory when empty.
        """
        self.model_path = path or os.getcwd()
        # The checkpoint's remote code locates its assets through this env
        # var and through sys.path, so both must point at the model dir.
        os.environ["SONGFORMER_LOCAL_DIR"] = self.model_path
        sys.path.insert(0, self.model_path)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading SongFormer model on {self.device}...")

        # device_map=None: placement is handled explicitly via .to() below.
        self.model = AutoModel.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            device_map=None,
        )
        self.model.to(self.device)
        self.model.eval()

        # Sample rate the model expects; all incoming audio is resampled to
        # this rate (mono) before inference.
        self.target_sr = 24000

        print("SongFormer model loaded successfully!")

    def _load_audio_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Decode raw encoded-audio bytes (WAV/MP3/FLAC container data) into a
        mono float array resampled to ``self.target_sr``.
        """
        audio_io = io.BytesIO(audio_bytes)
        audio_array, _ = librosa.load(audio_io, sr=self.target_sr, mono=True)
        return audio_array

    def _decode_base64_audio(self, audio_b64: str) -> np.ndarray:
        """
        Decode base64-encoded audio to a numpy array.

        Args:
            audio_b64: Base64-encoded audio string.

        Returns:
            numpy array of audio samples at 24kHz, mono.

        Raises:
            ValueError: If the string is not valid base64.
        """
        try:
            audio_bytes = base64.b64decode(audio_b64)
        except Exception as e:
            raise ValueError(f"Failed to decode base64 audio data: {e}")
        return self._load_audio_bytes(audio_bytes)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an inference request.

        Expected input:
            {
                "inputs": "<base64-encoded-audio-data>"   # or raw bytes
            }

        Returns:
            {
                "segments": [
                    {"label": "intro", "start": 0.0, "end": 15.2},
                    ...
                ],
                "duration": 180.5,
                "num_segments": 8
            }

        On failure, returns an error payload with the same keys plus
        "error" / "error_type" instead of raising (endpoint convention:
        never crash the serving process).
        """
        try:
            audio_input = data.get("inputs")
            if not audio_input:
                raise ValueError("Missing 'inputs' key with base64-encoded audio")

            # The module contract advertises both base64 strings and raw
            # bytes, so accept either here.
            if isinstance(audio_input, (bytes, bytearray)):
                audio_array = self._load_audio_bytes(bytes(audio_input))
            elif isinstance(audio_input, str):
                audio_array = self._decode_base64_audio(audio_input)
            else:
                raise ValueError(
                    "Input must be a base64-encoded string or raw audio bytes"
                )

            with torch.no_grad():
                result = self.model(audio_array)

            duration = len(audio_array) / self.target_sr

            # NOTE(review): assumes the remote code returns a list of segment
            # dicts (len() gives the segment count) — confirm against the
            # checkpoint's modeling code.
            output = {
                "segments": result,
                "duration": float(duration),
                "num_segments": len(result)
            }

            return output

        except Exception as e:
            # Top-level serving boundary: report the error in-band rather
            # than letting the endpoint return a bare 500.
            return {
                "error": str(e),
                "error_type": type(e).__name__,
                "segments": [],
                "duration": 0.0,
                "num_segments": 0
            }
| |
|
| |
|
| | |
if __name__ == "__main__":
    import argparse

    def _main() -> None:
        """Smoke-test the handler locally against one audio file."""
        parser = argparse.ArgumentParser(description="Test SongFormer handler locally")
        parser.add_argument("audio_file", help="Path to audio file to test")
        parser.add_argument("--model-path", default=".", help="Path to model directory")
        args = parser.parse_args()

        handler = EndpointHandler(args.model_path)

        # Read the file and encode it exactly as an endpoint client would.
        with open(args.audio_file, "rb") as fh:
            encoded = base64.b64encode(fh.read()).decode('utf-8')

        print("\n=== Testing with base64-encoded audio ===")
        print(json.dumps(handler({"inputs": encoded}), indent=2))

        # Exercise the underlying model's path-based entry point directly.
        print("\n=== Testing with direct file path (not typical for endpoint) ===")
        print(json.dumps(handler.model(args.audio_file), indent=2))

    _main()