"""
HuggingFace Inference Endpoint Handler for SongFormer.

Supports binary audio input (WAV, MP3, etc.) via base64 encoding or direct bytes.
"""

import base64
import io
import json
import os
import sys
import tempfile
from typing import Any, Dict, Union

import librosa
import numpy as np
import torch
from transformers import AutoModel


class EndpointHandler:
    """
    HuggingFace Inference Endpoint Handler for SongFormer model.

    Accepts base64-encoded audio (WAV, MP3, FLAC, etc.) as a string, or the
    raw encoded audio bytes themselves.
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler and load the SongFormer model.

        Args:
            path: Path to the model directory (provided by HuggingFace).
                Falls back to the current working directory when empty.
        """
        # Set up environment: expose the model directory through an env var
        # and put it first on sys.path.
        # NOTE(review): presumably this lets the model's remote code
        # (trust_remote_code=True below) locate and import its sibling
        # modules -- confirm against the model repository.
        self.model_path = path or os.getcwd()
        os.environ["SONGFORMER_LOCAL_DIR"] = self.model_path
        sys.path.insert(0, self.model_path)

        # Load the model
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Loading SongFormer model on {self.device}...")

        # Load model without device_map to avoid meta device initialization.
        # (Original author's note: SongFormerModel.__init__ now handles meta
        # device detection.)
        self.model = AutoModel.from_pretrained(
            self.model_path,
            trust_remote_code=True,
            device_map=None,
        )
        self.model.to(self.device)
        self.model.eval()

        # Expected sampling rate for the model
        self.target_sr = 24000

        print("SongFormer model loaded successfully!")

    def _load_audio_bytes(self, audio_bytes: bytes) -> np.ndarray:
        """
        Load encoded audio bytes into a mono numpy array.

        Args:
            audio_bytes: Contents of an audio file (WAV, MP3, etc.).

        Returns:
            Mono audio samples resampled to ``self.target_sr``.

        Raises:
            ValueError: If ``audio_bytes`` is empty.
        """
        # Fail early with a clear message instead of letting the audio
        # backend raise an obscure error on an empty buffer.
        if not audio_bytes:
            raise ValueError("Decoded audio payload is empty")

        # librosa accepts a file-like object, so no temporary file is needed.
        audio_array, _ = librosa.load(
            io.BytesIO(audio_bytes), sr=self.target_sr, mono=True
        )
        return audio_array

    def _decode_base64_audio(self, audio_b64: str) -> np.ndarray:
        """
        Decode base64-encoded audio to numpy array.

        Args:
            audio_b64: Base64-encoded audio string.

        Returns:
            Mono audio samples resampled to ``self.target_sr`` (24 kHz).

        Raises:
            ValueError: If the string is not valid base64 or decodes to an
                empty payload.
        """
        try:
            audio_bytes = base64.b64decode(audio_b64)
        except (ValueError, TypeError) as e:
            # binascii.Error (bad padding etc.) is a ValueError subclass.
            # Chain the cause so the original traceback is preserved.
            raise ValueError(f"Failed to decode base64 audio data: {e}") from e

        return self._load_audio_bytes(audio_bytes)

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process inference request with base64-encoded or raw-bytes audio.

        Expected input:
            {
                "inputs": <base64-encoded audio string, or raw audio bytes>
            }

        Returns:
            On success, the model output wrapped with summary fields:
            {
                "segments": [
                    {"label": "intro", "start": 0.0, "end": 15.2},
                    ...
                ],
                "duration": 180.5,
                "num_segments": 8
            }
            "segments" is whatever the model returns; the example shape
            above is taken from the original documentation.

            On any failure, no exception propagates. Instead the same keys
            are returned empty, plus "error" (message) and "error_type"
            (exception class name).
        """
        try:
            payload = data.get("inputs")
            if not payload:
                raise ValueError("Missing 'inputs' key with base64-encoded audio")

            if isinstance(payload, str):
                audio_array = self._decode_base64_audio(payload)
            elif isinstance(payload, (bytes, bytearray)):
                # Binary request bodies arrive as raw bytes; load them
                # directly, as the module docstring promises.
                audio_array = self._load_audio_bytes(bytes(payload))
            else:
                raise ValueError(
                    "Input must be a base64-encoded string or raw audio bytes"
                )

            # Run inference without tracking gradients
            with torch.no_grad():
                result = self.model(audio_array)

            # Duration in seconds of the resampled signal
            duration = len(audio_array) / self.target_sr

            return {
                "segments": result,
                "duration": float(duration),
                "num_segments": len(result),
            }

        except Exception as e:
            # Top-level endpoint boundary: report the failure in a
            # structured format rather than raising.
            return {
                "error": str(e),
                "error_type": type(e).__name__,
                "segments": [],
                "duration": 0.0,
                "num_segments": 0,
            }


def _main() -> None:
    """Run the handler locally against an audio file given on the command line."""
    import argparse

    parser = argparse.ArgumentParser(description="Test SongFormer handler locally")
    parser.add_argument("audio_file", help="Path to audio file to test")
    parser.add_argument("--model-path", default=".", help="Path to model directory")
    args = parser.parse_args()

    # Initialize handler
    handler = EndpointHandler(args.model_path)

    # Read and encode audio file
    with open(args.audio_file, "rb") as f:
        audio_bytes = f.read()
    audio_b64 = base64.b64encode(audio_bytes).decode('utf-8')

    # Test with base64 input
    print("\n=== Testing with base64-encoded audio ===")
    result = handler({"inputs": audio_b64})
    print(json.dumps(result, indent=2))

    # Test with file path directly (for comparison)
    print("\n=== Testing with direct file path (not typical for endpoint) ===")
    result_direct = handler.model(args.audio_file)
    print(json.dumps(result_direct, indent=2))


# For local testing
if __name__ == "__main__":
    _main()