import torch
import librosa
import os
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, ClapTextModelWithProjection
from src.models.transformer import Dasheng_Encoder
from src.models.sed_decoder import Decoder, TSED_Wrapper
from src.utils import load_yaml_with_includes


class FlexSED:
    def __init__(
        self,
        config_path='src/configs/model.yml',
        ckpt_path='ckpts/flexsed_as.pt',
        ckpt_url='https://huggingface.co/Higobeatz/FlexSED/resolve/main/ckpts/flexsed_as.pt',
        device='cuda'
    ):
        """
        Initialize FlexSED with model, CLAP, and tokenizer loaded once.
        If the checkpoint is not available locally, it will be downloaded automatically.
        """
        self.device = device
        params = load_yaml_with_includes(config_path)

        # Load the checkpoint, downloading it first if it is missing locally.
        if not os.path.exists(ckpt_path):
            print(f"[FlexSED] Downloading checkpoint from {ckpt_url} ...")
            state_dict = torch.hub.load_state_dict_from_url(ckpt_url, map_location="cpu")
        else:
            state_dict = torch.load(ckpt_path, map_location="cpu")

        # Build the encoder/decoder and restore the trained weights.
        encoder = Dasheng_Encoder(**params['encoder']).to(self.device)
        decoder = Decoder(**params['decoder']).to(self.device)
        self.model = TSED_Wrapper(encoder, decoder, params['ft_blocks'], params['frozen_encoder'])
        self.model.load_state_dict(state_dict['model'])
        self.model.eval()

        # CLAP text encoder and tokenizer used to embed event-name queries.
        self.clap = ClapTextModelWithProjection.from_pretrained("laion/clap-htsat-unfused")
        self.clap.eval()
        self.tokenizer = AutoTokenizer.from_pretrained("laion/clap-htsat-unfused")

    def run_inference(self, audio_path, events, norm_audio=True):
        """
        Run inference on an audio file for the given event names and
        return per-frame probabilities (sigmoid already applied).
        """
        audio, sr = librosa.load(audio_path, sr=16000)
        audio = torch.tensor([audio]).to(self.device)  # shape (1, num_samples)

        if norm_audio:
            # Peak-normalize to [-1, 1].
            eps = 1e-9
            max_val = torch.max(torch.abs(audio))
            audio = audio / (max_val + eps)

        # Encode each event name into a CLAP text embedding (one query per event).
        clap_embeds = []
        with torch.no_grad():
            for event in events:
                text = f"The sound of {event.replace('_', ' ').capitalize()}"
                inputs = self.tokenizer([text], padding=True, return_tensors="pt")
                outputs = self.clap(**inputs)
                text_embeds = outputs.text_embeds.unsqueeze(1)
                clap_embeds.append(text_embeds)

        query = torch.cat(clap_embeds, dim=1).to(self.device)
        with torch.no_grad():
            mel = self.model.forward_to_spec(audio)
            preds = self.model(mel, query)
        preds = torch.sigmoid(preds).cpu()

        return preds
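
    # Example post-processing (a sketch, not from the original FlexSED code):
    # threshold the probabilities into discrete event segments.
    @staticmethod
    def preds_to_events(preds, events, threshold=0.5, sr=25):
        """
        Illustrative helper (an assumption, not part of the FlexSED API):
        convert frame-level probabilities into (event, onset_s, offset_s)
        segments by simple thresholding. Assumes `preds` has the
        (num_events, 1, T) shape returned by run_inference, at `sr` output
        frames per second.
        """
        preds_np = preds.squeeze(1).numpy()
        segments = []
        for name, probs in zip(events, preds_np):
            active = probs >= threshold
            # Rising (+1) and falling (-1) edges of the binary activity mask.
            edges = np.diff(active.astype(int), prepend=0, append=0)
            onsets = np.where(edges == 1)[0]
            offsets = np.where(edges == -1)[0]
            for on, off in zip(onsets, offsets):
                segments.append((name, on / sr, off / sr))
        return segments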

    @staticmethod
    def plot_and_save_multi(preds, events, sr=25, out_dir="./plots", fname="all_events"):
        """Plot the per-event probability matrix as a heatmap and save it as a PNG."""
        os.makedirs(out_dir, exist_ok=True)
        preds_np = preds.squeeze(1).numpy()
        T = preds_np.shape[1]

        plt.figure(figsize=(12, len(events) * 0.6 + 2))
        plt.imshow(
            preds_np,
            aspect="auto",
            cmap="Blues",
            extent=[0, T / sr, 0, len(events)],
            vmin=0, vmax=1, origin="lower"
        )
        plt.colorbar(label="Probability")
        plt.yticks(np.arange(len(events)) + 0.5, events)
        plt.xlabel("Time (s)")
        plt.ylabel("Events")
        plt.title("Event Predictions")

        save_path = os.path.join(out_dir, f"{fname}.png")
        plt.savefig(save_path, dpi=200, bbox_inches="tight")
        plt.close()
        return save_path

    def to_multi_plot(self, preds, events, out_dir="./plots", fname="all_events"):
        return self.plot_and_save_multi(preds, events, out_dir=out_dir, fname=fname)

    @staticmethod
    def make_multi_event_video(preds, events, sr=25, out_dir="./videos",
                               audio_path=None, fps=25, highlight=True, fname="all_events"):
        """Render the predictions as a video, optionally muxed with the input audio."""
        # Imported lazily so moviepy/tqdm are only needed for video export.
        from moviepy.editor import ImageSequenceClip, AudioFileClip
        from tqdm import tqdm

        os.makedirs(out_dir, exist_ok=True)
        preds_np = preds.squeeze(1).numpy()
        T = preds_np.shape[1]
        duration = T / sr

        frames = []
        n_frames = int(duration * fps)

        for i in tqdm(range(n_frames)):
            t = int(i * T / n_frames)
            plt.figure(figsize=(12, len(events) * 0.6 + 2))

            if highlight:
                # Keep the full time axis and reveal probabilities up to frame t.
                mask = np.zeros_like(preds_np)
                mask[:, :t + 1] = preds_np[:, :t + 1]
                plt.imshow(
                    mask,
                    aspect="auto",
                    cmap="Blues",
                    extent=[0, T / sr, 0, len(events)],
                    vmin=0, vmax=1, origin="lower"
                )
            else:
                # Grow the time axis together with the revealed probabilities.
                plt.imshow(
                    preds_np[:, :t + 1],
                    aspect="auto",
                    cmap="Blues",
                    extent=[0, (t + 1) / sr, 0, len(events)],
                    vmin=0, vmax=1, origin="lower"
                )

            plt.colorbar(label="Probability")
            plt.yticks(np.arange(len(events)) + 0.5, events)
            plt.xlabel("Time (s)")
            plt.ylabel("Events")
            plt.title("Event Predictions")

            frame_path = f"/tmp/frame_{i:04d}.png"
            plt.savefig(frame_path, dpi=150, bbox_inches="tight")
            plt.close()
            frames.append(frame_path)

        clip = ImageSequenceClip(frames, fps=fps)
        if audio_path is not None:
            audio = AudioFileClip(audio_path).subclip(0, duration)
            clip = clip.set_audio(audio)

        save_path = os.path.join(out_dir, f"{fname}.mp4")
        clip.write_videofile(
            save_path,
            fps=fps,
            codec="mpeg4",
            audio_codec="aac"
        )

        # Clean up the temporary frame images.
        for f in frames:
            os.remove(f)

        return save_path

    def to_multi_video(self, preds, events, audio_path, out_dir="./videos", fname="all_events"):
        return self.make_multi_event_video(
            preds, events, audio_path=audio_path, out_dir=out_dir, fname=fname
        )


if __name__ == "__main__":
    flexsed = FlexSED(device='cuda')

    events = ["Door", "Laughter", "Dog"]
    preds = flexsed.run_inference("example2.wav", events)

    flexsed.to_multi_plot(preds, events, fname="example2")
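
    # Illustrative: print thresholded (event, onset_s, offset_s) segments
    # using the preds_to_events sketch defined above.
    for name, onset, offset in flexsed.preds_to_events(preds, events):
        print(f"{name}: {onset:.2f}s - {offset:.2f}s")

    # Optional: also render a video of the predictions synced to the audio
    # (requires moviepy with an ffmpeg backend).
    flexsed.to_multi_video(preds, events, "example2.wav", fname="example2")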