"""Fast audio captioning: CLAP tags + Silero VAD + faster-whisper lyrics.

Provides mood/genre/instrument tagging via CLAP zero-shot classification,
speech detection via Silero VAD, and lyrics extraction via faster-whisper.
All models run on CPU; total runtime is roughly 3-5 minutes per file.

Usage:
    from caption_fast import caption_audio
    result = caption_audio("song.mp3")
    # {"caption": "Pop, Energetic, Guitar, Melodic, Upbeat",
    #  "lyrics": "[Verse]\nSome lyrics here...",
    #  "bpm": 120, "key": "C major", "signature": "4/4",
    #  "tags": ["Pop", "Energetic", "Guitar", ...]}
"""

from __future__ import annotations

import logging
import os
from typing import Any, Dict, List

logger = logging.getLogger(__name__)

# Tag list for CLAP zero-shot classification (from clap-interrogator)
TAGS = [
    "Fast", "Slow", "Upbeat", "Downbeat", "Moderate",
    "Happy", "Sad", "Energetic", "Relaxed", "Melancholic", "Uplifting",
    "Aggressive", "Peaceful", "Romantic", "Dark", "Light", "Mysterious",
    "Dreamy", "Somber", "Hopeful", "Gloomy", "Cheerful", "Reflective",
    "Nostalgic", "Tense", "Calm",
    "Piano", "Guitar", "Violin", "Drums", "Bass", "Synthesizer",
    "Saxophone", "Trumpet", "Flute", "Cello", "Clarinet", "Harp",
    "Percussion", "Organ", "Accordion", "Electronic", "Acoustic",
    "Electric Guitar", "Acoustic Guitar", "Synth Pad", "Keyboards",
    "Rock", "Pop", "Jazz", "Classical", "Electronic", "Folk", "Hip-Hop",
    "Blues", "Ambient", "Country", "Reggae", "Funk", "Soul", "Metal",
    "Dance", "Disco", "House", "Techno", "Trance", "Soundtrack", "World",
    "Indie", "Alternative", "R&B", "EDM", "Chillwave", "Dubstep",
    "Lo-fi Hip-Hop", "Drum and Bass", "Jazz Fusion", "Neo-Soul", "Trap",
    "K-Pop", "J-Pop", "Reggaeton", "Punk", "Grunge",
    "Bright", "Warm", "Smooth", "Distorted", "Clean", "Lo-fi",
    "Layered", "Minimalist", "Cinematic", "Atmospheric", "Ethereal",
    "Groovy", "Rhythmic", "Melodic", "Harmonic",
    "Live", "Studio", "Instrumental",
]

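# Lazily-loaded singletons: each _load_* helper fills in its global on first
# use, so repeated calls reuse the already-loaded model.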
_clap_model = None
_clap_processor = None
_whisper_model = None
_vad_model = None


def _load_clap():
    global _clap_model, _clap_processor
    if _clap_model is not None:
        return _clap_model, _clap_processor
    from transformers import ClapModel, ClapProcessor
    logger.info("[CLAP] Loading laion/larger_clap_music...")
    _clap_processor = ClapProcessor.from_pretrained("laion/larger_clap_music")
    _clap_model = ClapModel.from_pretrained("laion/larger_clap_music")
    _clap_model.eval()
    logger.info("[CLAP] Ready (~780MB)")
    return _clap_model, _clap_processor


def _load_whisper():
    global _whisper_model
    if _whisper_model is not None:
        return _whisper_model
    from faster_whisper import WhisperModel
    logger.info("[Whisper] Loading large-v3-turbo (int8, CPU)...")
    _whisper_model = WhisperModel(
        "large-v3-turbo",
        device="cpu",
        compute_type="int8",
    )
    logger.info("[Whisper] Ready (~1.5GB)")
    return _whisper_model


def _load_vad():
    global _vad_model
    if _vad_model is not None:
        return _vad_model
    import torch
    logger.info("[VAD] Loading Silero VAD...")
    _vad_model, _ = torch.hub.load(
        repo_or_dir="snakers4/silero-vad",
        model="silero_vad",
        onnx=True,
        trust_repo=True,
    )
    logger.info("[VAD] Ready (~2MB)")
    return _vad_model


def unload_caption_models():
    """Free all captioning models from memory."""
    global _clap_model, _clap_processor, _whisper_model, _vad_model
    import gc
    _clap_model = None
    _clap_processor = None
    _whisper_model = None
    _vad_model = None
    gc.collect()
    logger.info("[Caption] All models unloaded")


def tag_audio(audio_path: str, top_n: int = 10) -> List[str]:
    """Get top-N CLAP tags for an audio file."""
    import librosa
    import torch

    model, processor = _load_clap()
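    # CLAP's audio encoder was trained on 48 kHz input, so resample on load.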
    audio, sr = librosa.load(audio_path, sr=48000, mono=True)

    inputs = processor(
        text=TAGS,
        audio=[audio],
        sampling_rate=48000,
        return_tensors="pt",
        padding=True,
    )

    with torch.no_grad():
        outputs = model(**inputs)

    probs = outputs.logits_per_audio.softmax(dim=-1)
    _, top_indices = probs.topk(top_n, dim=1)
    return [TAGS[i] for i in top_indices[0].tolist()]


def detect_speech(audio_path: str, threshold: float = 5.0) -> bool:
    """Check if audio contains speech using Silero VAD.
    Returns True if speech detected for more than `threshold` seconds.
    """
    import torch
    import librosa

    vad = _load_vad()
    vad.reset_states()  # the VAD keeps internal RNN state across calls; reset per file
    y, sr = librosa.load(audio_path, sr=16000, mono=True)
    wav = torch.from_numpy(y).unsqueeze(0)

    speech_timestamps = []
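    # Feed fixed 512-sample windows (~32 ms at 16 kHz); Silero VAD expects
    # exactly this chunk size for 16 kHz audio.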
    window_size = 512
    for i in range(0, wav.shape[1], window_size):
        chunk = wav[0, i:i + window_size]
        if len(chunk) < window_size:
            break
        prob = vad(chunk, 16000).item()
        if prob > 0.5:
            speech_timestamps.append(i / 16000)

    speech_duration = len(speech_timestamps) * (window_size / 16000)
    logger.info("[VAD] Speech: %.1fs detected in %s", speech_duration, os.path.basename(audio_path))
    return speech_duration > threshold


def transcribe_lyrics(audio_path: str) -> str:
    """Extract lyrics from audio using faster-whisper."""
    model = _load_whisper()

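    # vad_filter runs Silero VAD inside faster-whisper to skip non-speech
    # spans, which cuts down hallucinated text over instrumental passages.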
    segments, info = model.transcribe(
        audio_path,
        language=None,
        beam_size=5,
        vad_filter=True,
    )

    lines = []
    for segment in segments:
        text = segment.text.strip()
        if text:
            lines.append(text)

    lyrics = "\n".join(lines)
    if not lyrics.strip():
        return "[Instrumental]"

    logger.info("[Whisper] Transcribed %d lines (lang=%s, prob=%.2f)",
                len(lines), info.language, info.language_probability)
    return lyrics


def get_bpm_key(audio_path: str) -> Dict[str, str]:
    """Get BPM and key via librosa."""
    import librosa
    import numpy as np

    y, sr = librosa.load(audio_path, sr=None, mono=True)

    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
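    # librosa >= 0.10 returns tempo as a 1-element ndarray, older versions a
    # plain float; the hasattr guard below handles both.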
    bpm = int(round(float(tempo.item() if hasattr(tempo, 'item') else tempo)))

    chroma = librosa.feature.chroma_cens(y=y, sr=sr)
    chroma_avg = np.mean(chroma, axis=1)
    keys = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
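    # Krumhansl-Schmuckler key profiles: empirical ratings of how well each
    # pitch class fits a given major or minor tonic. The key whose rotated
    # profile correlates best with the averaged chroma wins.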
    major_profile = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
    minor_profile = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])

    best_corr = -1
    best_key = "C major"
    for i in range(12):
        maj_corr = float(np.corrcoef(np.roll(major_profile, i), chroma_avg)[0, 1])
        min_corr = float(np.corrcoef(np.roll(minor_profile, i), chroma_avg)[0, 1])
        if maj_corr > best_corr:
            best_corr = maj_corr
            best_key = f"{keys[i]} major"
        if min_corr > best_corr:
            best_corr = min_corr
            best_key = f"{keys[i]} minor"

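    # Time-signature detection is not implemented; assume common time.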
    return {"bpm": str(bpm), "key": best_key, "signature": "4/4"}


def caption_audio(
    audio_path: str,
    top_n: int = 10,
    extract_lyrics: bool = True,
    speech_threshold: float = 5.0,
) -> Dict[str, Any]:
    """Full fast captioning pipeline for one audio file.

    Returns dict with: caption, lyrics, bpm, key, signature, tags
    """
    fname = os.path.basename(audio_path)
    logger.info("[Caption] Processing %s...", fname)

    # 1. CLAP tags (mood, genre, instruments)
    tags = tag_audio(audio_path, top_n=top_n)
    caption = ", ".join(tags)
    logger.info("[Caption] %s: tags=%s", fname, caption)

    # 2. BPM + key via librosa
    bpm_key = get_bpm_key(audio_path)
    logger.info("[Caption] %s: BPM=%s, key=%s", fname, bpm_key["bpm"], bpm_key["key"])

    # 3. Speech detection + lyrics
    lyrics = "[Instrumental]"
    if extract_lyrics:
        has_speech = detect_speech(audio_path, threshold=speech_threshold)
        if has_speech:
            logger.info("[Caption] %s: speech detected, transcribing lyrics...", fname)
            lyrics = transcribe_lyrics(audio_path)
        else:
            logger.info("[Caption] %s: no speech, marking instrumental", fname)

    return {
        "caption": caption,
        "lyrics": lyrics,
        "bpm": bpm_key["bpm"],
        "key": bpm_key["key"],
        "signature": bpm_key["signature"],
        "tags": tags,
    }
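

if __name__ == "__main__":
    # Minimal CLI smoke test: a sketch, not part of the module's public API.
    # Captions the single audio file given on the command line and prints the
    # result; assumes the optional dependencies (librosa, torch, transformers,
    # faster-whisper, onnxruntime) are installed.
    import json
    import sys

    logging.basicConfig(level=logging.INFO)
    if len(sys.argv) != 2:
        print(f"usage: python {sys.argv[0]} <audio_file>")
        raise SystemExit(1)
    print(json.dumps(caption_audio(sys.argv[1]), indent=2, ensure_ascii=False))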