# SongFormer / handler.py
# Author: Ben Osheroff
# Commit 4510d70: "seems unlikely to work but you know"
"""
HuggingFace Inference Endpoint Handler for SongFormer
Supports binary audio input (WAV, MP3, etc.) via base64 encoding or direct bytes
"""
import os
import sys
import io
import base64
import json
import tempfile
from typing import Dict, Any, Union
import librosa
import numpy as np
import torch
from transformers import AutoModel
class EndpointHandler:
    """
    HuggingFace Inference Endpoint Handler for the SongFormer model.

    Accepts audio under the "inputs" key either as a base64-encoded string
    or as raw audio file bytes (WAV, MP3, FLAC, etc.), matching the module
    docstring's promise of both encodings.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler and load the SongFormer model.

        Args:
            path: Path to the model directory (provided by HuggingFace).
                  Falls back to the current working directory when empty.
        """
        # Expose the model directory to the remote model code and make its
        # modules importable before anything tries to load them.
        self.model_path = path or os.getcwd()
        os.environ["SONGFORMER_LOCAL_DIR"] = self.model_path
        sys.path.insert(0, self.model_path)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading SongFormer model on {self.device}...")

        # Load model without device_map to avoid meta device initialization.
        # The SongFormerModel.__init__ now handles meta device detection.
        self.model = AutoModel.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            device_map=None,
        )
        self.model.to(self.device)
        self.model.eval()

        # Sampling rate the model expects its input audio at.
        self.target_sr = 24000
        print("SongFormer model loaded successfully!")

    def _load_audio_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Load raw audio file bytes as mono samples at the model's rate.

        Args:
            audio_bytes: Complete contents of an audio file (WAV, MP3, etc.).

        Returns:
            numpy array of audio samples resampled to ``self.target_sr``.
        """
        # librosa detects the container/codec from the byte stream itself.
        audio_array, _ = librosa.load(
            io.BytesIO(audio_bytes), sr=self.target_sr, mono=True
        )
        return audio_array

    def _decode_base64_audio(self, audio_b64: str) -> np.ndarray:
        """
        Decode base64-encoded audio to a numpy array.

        Args:
            audio_b64: Base64-encoded audio file contents.

        Returns:
            numpy array of audio samples at 24kHz.

        Raises:
            ValueError: If the string is not valid base64.
        """
        try:
            audio_bytes = base64.b64decode(audio_b64)
        except Exception as e:
            raise ValueError(f"Failed to decode base64 audio data: {e}") from e
        return self._load_audio_bytes(audio_bytes)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process an inference request.

        Expected input:
            {
                "inputs": "<base64-encoded-audio-data>"  # or raw audio bytes
            }

        Returns:
            {
                "segments": [{"label": "intro", "start": 0.0, "end": 15.2}, ...],
                "duration": 180.5,
                "num_segments": 8
            }
            or, on any failure, a structured error dict with empty results.
        """
        try:
            audio_input = data.get("inputs")
            if not audio_input:
                raise ValueError("Missing 'inputs' key with base64-encoded audio")

            # The module docstring advertises "base64 encoding or direct
            # bytes"; accept both. Strings keep the original base64 path so
            # existing callers are unaffected.
            if isinstance(audio_input, (bytes, bytearray)):
                audio_array = self._load_audio_bytes(bytes(audio_input))
            elif isinstance(audio_input, str):
                audio_array = self._decode_base64_audio(audio_input)
            else:
                raise ValueError("Input must be a base64-encoded string")

            # Run inference without tracking gradients.
            with torch.no_grad():
                result = self.model(audio_array)

            # Duration follows from sample count at the fixed target rate.
            duration = len(audio_array) / self.target_sr
            return {
                "segments": result,
                "duration": float(duration),
                "num_segments": len(result),
            }
        except Exception as e:
            # Endpoint contract: never raise — return the error in a
            # structured format so the caller always gets valid JSON.
            return {
                "error": str(e),
                "error_type": type(e).__name__,
                "segments": [],
                "duration": 0.0,
                "num_segments": 0,
            }
# For local testing
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="Test SongFormer handler locally")
    cli.add_argument("audio_file", help="Path to audio file to test")
    cli.add_argument("--model-path", default=".", help="Path to model directory")
    opts = cli.parse_args()

    # Initialize the handler exactly as the endpoint would.
    handler = EndpointHandler(opts.model_path)

    # Read the audio file and base64-encode it, mimicking a request payload.
    with open(opts.audio_file, "rb") as f:
        encoded = base64.b64encode(f.read()).decode('utf-8')

    # Exercise the normal endpoint path with base64 input.
    print("\n=== Testing with base64-encoded audio ===")
    print(json.dumps(handler({"inputs": encoded}), indent=2))

    # Call the underlying model directly with a file path for comparison.
    print("\n=== Testing with direct file path (not typical for endpoint) ===")
    print(json.dumps(handler.model(opts.audio_file), indent=2))