""" Song Writing Assistant Module for TouchGrass. Assists with song composition across all elements. """ import torch import torch.nn as nn import torch.nn.functional as F from typing import Optional, List, Dict, Tuple class SongwritingModule(nn.Module): """ Assists with song composition across all elements. Features: - Chord progression suggestions based on mood/genre - Lyric writing assistance with rhyme scheme awareness - Song structure templates (verse-chorus-bridge, AABA, etc) - Genre-appropriate production suggestions - Melody writing guidance - Hook development Understands song structure tokens: [VERSE], [CHORUS], [BRIDGE], [PRE-CHORUS], [OUTRO], [INTRO] """ # Song structures SONG_STRUCTURES = { "verse-chorus": ["INTRO", "VERSE", "CHORUS", "VERSE", "CHORUS", "BRIDGE", "CHORUS", "OUTRO"], "aaba": ["INTRO", "A", "A", "B", "A", "OUTRO"], "through-composed": ["INTRO", "VERSE", "VERSE", "VERSE", "VERSE", "OUTRO"], "pop": ["INTRO", "VERSE", "PRE-CHORUS", "CHORUS", "VERSE", "PRE-CHORUS", "CHORUS", "BRIDGE", "CHORUS", "OUTRO"], "blues": ["INTRO", "VERSE", "VERSE", "VERSE", "VERSE", "OUTRO"], # 12-bar each verse "sonata": ["EXPOSITION", "DEVELOPMENT", "RECAPITULATION"], } # Genres GENRES = [ "pop", "rock", "country", "folk", "blues", "jazz", "r&b", "soul", "hip-hop", "electronic", "classical", "metal", "punk", "indie", "folk-rock", "singer-songwriter", ] # Moods MOODS = [ "happy", "sad", "angry", "romantic", "melancholy", "uplifting", "dark", "energetic", "peaceful", "dramatic", "nostalgic", "hopeful", ] # Rhyme schemes RHYME_SCHEMES = { "AABB": "Couplet", "ABAB": "Alternating", "ABBA": "Enclosed", "ABCB": "Ballad", "free": "Free verse", } # Common rhyme families (simplified phonetics) RHYME_FAMILIES = { "ight": ["light", "night", "right", "fight", "bright", "sight"], "ine": ["shine", "mine", "fine", "line", "sign", "time"], "all": ["fall", "call", "wall", "tall", "ball", "small"], "ing": ["sing", "ring", "bring", "spring", "thing", "wing"], "ay": ["say", "day", "way", "stay", "play", "away"], "own": ["down", "crown", "frown", "town", "gown", "clown"], } # Hook types HOOK_TYPES = [ "melodic_hook", # catchy melody "lyrical_hook", # memorable phrase "rhythmic_hook", # distinctive rhythm "sonic_hook", # unique sound/texture ] # Production elements by genre GENRE_PRODUCTION = { "pop": ["reverb", "compression", "auto-tune", "synth pads", "four-on-the-floor"], "rock": ["distortion", "overdrive", "guitar amps", "live drums"], "country": ["acoustic guitar", "steel guitar", "reverb", "warm vocal"], "folk": ["acoustic", "minimal", "room mic", "organic"], "blues": ["tube amp", "overdrive", "blues harp", "shuffle rhythm"], "jazz": ["room recording", "minimal compression", "acoustic piano", "brass"], "hip-hop": ["808 bass", "hi-hats", "samples", "sidechain"], "electronic": ["synths", "drum machines", "reverb", "delay", "automation"], "metal": ["high gain", "double kick", "scream vocals", "fast tempo"], } def __init__(self, d_model: int, num_genres: int = 20): """ Initialize SongwritingModule. Args: d_model: Hidden dimension from base model num_genres: Number of genre categories """ super().__init__() self.d_model = d_model self.num_genres = num_genres # Embeddings self.genre_embed = nn.Embedding(num_genres, 128) self.structure_embed = nn.Embedding(10, 64) # song sections self.mood_embed = nn.Embedding(15, 64) # moods self.section_type_embed = nn.Embedding(8, 64) # verse/chorus/etc # Rhyme suggestion head self.rhyme_head = nn.Linear(d_model, d_model) # Chord progression type predictor self.progression_head = nn.Linear(d_model, 32) # Hook generator self.hook_generator = nn.GRU( input_size=d_model + 128, # hidden + genre hidden_size=d_model, num_layers=1, batch_first=True, ) # Lyric line generator self.lyric_generator = nn.GRU( input_size=d_model + 64, # hidden + section type hidden_size=d_model, num_layers=2, batch_first=True, dropout=0.1, ) # Genre classifier self.genre_classifier = nn.Linear(d_model, num_genres) # Mood classifier self.mood_classifier = nn.Linear(d_model, 15) # Section type classifier self.section_classifier = nn.Linear(d_model, 8) # Production suggestion head self.production_head = nn.Linear(d_model + num_genres, 64) def forward( self, hidden_states: torch.Tensor, genre: Optional[str] = None, mood: Optional[str] = None, structure: Optional[str] = None, ) -> Dict[str, torch.Tensor]: """ Forward pass through SongwritingModule. Args: hidden_states: Base model hidden states [batch, seq_len, d_model] genre: Optional genre string mood: Optional mood string structure: Optional song structure name Returns: Dictionary with songwriting predictions """ batch_size, seq_len, _ = hidden_states.shape # Pool hidden states pooled = hidden_states.mean(dim=1) # [batch, d_model] # Classify genre genre_logits = self.genre_classifier(pooled) # [batch, num_genres] # Classify mood mood_logits = self.mood_classifier(pooled) # [batch, 15] # Classify section type section_logits = self.section_classifier(pooled) # [batch, 8] # Predict chord progression type progression_logits = self.progression_head(pooled) # [batch, 32] # Generate hook (if genre provided) hook_output = None if genre: genre_idx = self._genre_to_idx(genre) genre_emb = self.genre_embed(torch.tensor([genre_idx], device=hidden_states.device)) genre_emb = genre_emb.expand(batch_size, -1) # Generate hook sequence hook_input = torch.cat([pooled.unsqueeze(1), genre_emb.unsqueeze(1)], dim=2) hook_output, _ = self.hook_generator(hook_input) # Generate lyrics (if section type provided) lyric_output = None if structure: section_idx = self._section_to_idx(structure) section_emb = self.section_type_embed(torch.tensor([section_idx], device=hidden_states.device)) section_emb = section_emb.expand(batch_size, -1) lyric_input = torch.cat([pooled.unsqueeze(1), section_emb.unsqueeze(1)], dim=2) lyric_output, _ = self.lyric_generator(lyric_input) outputs = { "genre_logits": genre_logits, "mood_logits": mood_logits, "section_logits": section_logits, "progression_logits": progression_logits, } if hook_output is not None: outputs["hook_output"] = hook_output if lyric_output is not None: outputs["lyric_output"] = lyric_output return outputs def get_song_structure(self, structure_name: str) -> List[str]: """ Get song structure template. Args: structure_name: Name of structure (verse-chorus, aaba, etc) Returns: List of section names in order """ return self.SONG_STRUCTURES.get(structure_name, self.SONG_STRUCTURES["verse-chorus"]) def suggest_progression( self, mood: str = "happy", genre: str = "pop", num_chords: int = 4, key: str = "C", ) -> List[Tuple[str, str]]: """ Suggest chord progression based on mood and genre. Args: mood: Emotional mood genre: Music genre num_chords: Number of chords key: Key signature Returns: List of (chord_degree, chord_name) tuples """ # Genre-specific progressions genre_progressions = { "pop": { "happy": ["I", "V", "vi", "IV"], "sad": ["vi", "IV", "I", "V"], "uplifting": ["I", "IV", "V", "I"], "romantic": ["ii", "V", "I", "vi"], }, "rock": { "energetic": ["I", "IV", "V", "IV"], "dark": ["i", "VI", "III", "VII"], "angry": ["i", "iv", "V", "i"], }, "blues": { "sad": ["I", "IV", "I", "I", "IV", "IV", "I", "I", "V", "IV", "I", "V"], "happy": ["I", "IV", "I", "I", "IV", "IV", "I", "I", "V", "IV", "I", "I"], }, "jazz": { "sophisticated": ["ii", "V", "I", "vi"], "jazzy": ["I", "vi", "ii", "V"], }, "folk": { "nostalgic": ["I", "V", "vi", "iii", "IV", "I", "IV", "V"], "peaceful": ["I", "IV", "I", "V", "I"], }, } # Get progression for genre/mood if genre in genre_progressions and mood in genre_progressions[genre]: progression = genre_progressions[genre][mood] else: # Default to pop happy progression = ["I", "V", "vi", "IV"] # Trim or extend to requested length if len(progression) > num_chords: progression = progression[:num_chords] elif len(progression) < num_chords: # Repeat or extend while len(progression) < num_chords: progression.append(progression[-1]) # Convert to chord names chords = self._degrees_to_chords(progression, key) return list(zip(progression, chords)) def _degrees_to_chords(self, degrees: List[str], key: str) -> List[str]: """Convert Roman numerals to chord names.""" # Major scale degrees major_scale = ["C", "D", "E", "F", "G", "A", "B"] minor_scale = ["C", "D", "Eb", "F", "G", "Ab", "Bb"] # Determine if key is major or minor is_minor = key.endswith("m") or "minor" in key root = key.rstrip("m").strip() scale = minor_scale if is_minor else major_scale # Map degree to chord degree_map = { "I": (0, "major"), "ii": (1, "minor"), "iii": (2, "minor"), "IV": (3, "major"), "V": (4, "major"), "vi": (5, "minor"), "vii°": (6, "diminished"), "i": (0, "minor"), "iv": (3, "minor"), "v": (4, "minor"), "VI": (5, "major"), "III": (2, "major"), "VII": (6, "major"), } chords = [] for degree in degrees: if degree in degree_map: idx, quality = degree_map[degree] root_note = scale[idx] if quality == "major": chord = f"{root_note} major" elif quality == "minor": chord = f"{root_note} minor" else: chord = f"{root_note} {quality}" chords.append(chord) else: chords.append(degree) # Keep as-is return chords def find_rhymes( self, word: str, rhyme_scheme: str = "AABB", num_rhymes: int = 4, ) -> List[str]: """ Find rhyming words. Args: word: Target word to rhyme rhyme_scheme: Rhyme scheme pattern num_rhymes: Number of rhymes to return Returns: List of rhyming words """ word = word.lower().strip() # Check rhyme families for ending, family in self.RHYME_FAMILIES.items(): if word.endswith(ending): rhymes = [w for w in family if w != word] return rhymes[:num_rhymes] # Fallback: simple suffix matching # (In production, use CMU pronunciation dictionary) common_endings = ["ing", "ed", "er", "ly", "tion", "sion", "ity", "ness"] for ending in common_endings: if word.endswith(ending) and len(word) > len(ending) + 2: # Generate placeholder rhymes base = word[:-len(ending)] rhymes = [base + ending] * num_rhymes # Placeholder return rhymes return [word] # No rhyme found def suggest_lyric_line( self, section_type: str, rhyme_with: Optional[str] = None, syllable_count: Optional[int] = None, mood: str = "happy", ) -> str: """ Suggest a lyric line. Args: section_type: Section (verse, chorus, bridge, etc) rhyme_with: Optional word to rhyme with syllable_count: Optional syllable count target mood: Emotional mood Returns: Suggested lyric line """ import random # Section-specific templates section_templates = { "VERSE": [ "Walking down this road again", "Memories of you remain", "Sunlight through the window pane", "Whispers in the pouring rain", ], "CHORUS": [ "This is our time, our moment now", "Forever you, forever me", "Hearts beating as one somehow", "Never gonna let you go", ], "BRIDGE": [ "But what if everything changes", "In the silence, I hear clearly", "Time reveals the truth within", "Sometimes the hardest thing to do is", ], "PRE-CHORUS": [ "Building up to something more", "Can you feel it coming now", "The tension rises, can't ignore", "Almost there, just take a bow", ], "OUTRO": [ "And so we fade into the night", "The story ends but love remains", "Goodbye for now, but not goodbye", "Echoes linger, fade away", ], } templates = section_templates.get(section_type, section_templates["VERSE"]) line = random.choice(templates) # Apply rhyme if specified if rhyme_with: rhymes = self.find_rhymes(rhyme_with) if rhymes: # Replace last word with rhyme words = line.split() if words: words[-1] = random.choice(rhymes) line = " ".join(words) return line def generate_hook( self, genre: str = "pop", mood: str = "happy", length: int = 4, ) -> Dict[str, str]: """ Generate a song hook (catchy phrase/melody). Args: genre: Music genre mood: Emotional mood length: Number of lines/phrases Returns: Dictionary with hook components """ import random # Hook templates by genre/mood hook_templates = { "pop": { "happy": [ "Feel the rhythm in your soul", "Dance like nobody's watching", "We are young, we are free", "This is our destiny", ], "sad": [ "But I still hear your voice", "Missing you, missing me", "Tears fall like rain tonight", "How could you say goodbye", ], }, "rock": { "energetic": [ "Break the chains, feel the fire", "We will never surrender", "Rising up from the ground", "Hear the sound all around", ], "angry": [ "I won't take it anymore", "Stand up and fight back", "This is my rebellion", "Breaking through the walls", ], }, "folk": { "nostalgic": [ "Remember those days gone by", "The old road leads us home", "Stories told by the fire", "Where the wild rivers flow", ], }, } # Get hooks for genre/mood hooks = [] if genre in hook_templates and mood in hook_templates[genre]: hooks = hook_templates[genre][mood] else: # Generic hooks hooks = [ "This is the hook that sticks", "Catchy melody, memorable line", "Sing along, feel the vibe", "The part you can't forget", ] # Select random hooks selected = random.sample(hooks, min(length, len(hooks))) return { "hook_lines": selected, "genre": genre, "mood": mood, "type": "lyrical_hook", } def suggest_production_elements( self, genre: str, mood: str, instruments: Optional[List[str]] = None, ) -> Dict[str, List[str]]: """ Suggest production elements for genre. Args: genre: Music genre mood: Emotional mood instruments: Optional instrument list Returns: Dictionary with production suggestions """ production = self.GENRE_PRODUCTION.get(genre, ["acoustic", "vocals", "drums"]) # Mood adjustments mood_effects = { "happy": ["bright reverb", "warm compression", "upbeat tempo"], "sad": ["hall reverb", "minimal", "slow tempo"], "dark": ["distortion", "low-pass filter", "dense reverb"], "energetic": ["compression", "sidechain", "fast tempo"], "peaceful": ["room tone", "natural reverb", "minimal processing"], } effects = mood_effects.get(mood, []) return { "genre_elements": production, "mood_effects": effects, "suggested_instruments": instruments or self._suggest_instruments(genre, mood), "mixing_tips": self._get_mixing_tips(genre), } def _suggest_instruments(self, genre: str, mood: str) -> List[str]: """Suggest instruments based on genre and mood.""" genre_instruments = { "pop": ["vocals", "synth", "drums", "bass", "guitar"], "rock": ["electric guitar", "drums", "bass", "vocals"], "country": ["acoustic guitar", "steel guitar", "fiddle", "vocals"], "folk": ["acoustic guitar", "harmonica", "vocals"], "blues": ["electric guitar", "harmonica", "drums", "bass"], "jazz": ["saxophone", "piano", "bass", "drums", "trumpet"], "hip-hop": ["drums", "bass", "synth", "samples"], "electronic": ["synth", "drum machine", "bass", "samples"], } instruments = genre_instruments.get(genre, ["guitar", "vocals", "drums"]) # Mood adjustments if mood == "sad" or mood == "peaceful": instruments = [inst for inst in instruments if "electric" not in inst] elif mood == "energetic" or mood == "angry": instruments = [inst for inst in instruments if "acoustic" not in inst] return instruments def _get_mixing_tips(self, genre: str) -> List[str]: """Get mixing tips for genre.""" tips = { "pop": [ "Vocal upfront in the mix", "Sidechain kick and bass", "Bright high-end on synths", ], "rock": [ "Guitars wide in stereo", "Drums punchy and present", "Bass tight and compressed", ], "folk": [ "Natural, room-filling sound", "Minimal processing", "Acoustic instruments front and center", ], "hip-hop": [ "808 bass sub-bass frequencies", "Hi-hats crisp and present", "Vocals front and center", ], } return tips.get(genre, ["Balance all elements", "Check on multiple speakers"]) def _genre_to_idx(self, genre: str) -> int: """Convert genre to index.""" try: return self.GENRES.index(genre) except ValueError: return 0 def _section_to_idx(self, section: str) -> int: """Convert section type to index.""" section_map = { "INTRO": 0, "VERSE": 1, "PRE-CHORUS": 2, "CHORUS": 3, "BRIDGE": 4, "OUTRO": 5, "A": 6, "B": 7, } return section_map.get(section.upper(), 1) def test_songwriting_module(): """Test the SongwritingModule.""" import torch # Create module module = SongwritingModule(d_model=4096, num_genres=20) # Test input batch_size = 2 seq_len = 10 d_model = 4096 hidden_states = torch.randn(batch_size, seq_len, d_model) # Forward pass outputs = module.forward( hidden_states, genre="pop", mood="happy", structure="CHORUS", ) print("Songwriting Module outputs:") for key, value in outputs.items(): if isinstance(value, torch.Tensor): print(f" {key}: {value.shape}") else: print(f" {key}: {value}") # Test song structure print("\nSong structure (verse-chorus):") structure = module.get_song_structure("verse-chorus") print(f" {' -> '.join(structure)}") # Test chord progression print("\nChord progression (pop, happy, 4 chords, key of C):") progression = module.suggest_progression(mood="happy", genre="pop", num_chords=4, key="C") for degree, chord in progression: print(f" {degree}: {chord}") # Test rhyme finder print("\nRhymes for 'light':") rhymes = module.find_rhymes("light", num_rhymes=5) print(f" {', '.join(rhymes)}") # Test lyric suggestion print("\nLyric suggestion (chorus, rhyme with 'now'):") lyric = module.suggest_lyric_line(section_type="CHORUS", rhyme_with="now") print(f" {lyric}") # Test hook generation print("\nHook generation (pop, happy, 2 lines):") hook = module.generate_hook(genre="pop", mood="happy", length=2) print(f" Hook: {hook['hook_lines']}") # Test production suggestions print("\nProduction suggestions (rock, energetic):") prod = module.suggest_production_elements(genre="rock", mood="energetic") print(f" Instruments: {', '.join(prod['suggested_instruments'])}") print(f" Effects: {', '.join(prod['mood_effects'])}") print(f" Mixing tips: {', '.join(prod['mixing_tips'])}") print("\nSongwriting Module test complete!") if __name__ == "__main__": test_songwriting_module()