"""
Music Theory Engine for TouchGrass.

Understands music theory relationships, scales, chords, progressions.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, List, Dict, Tuple


class MusicTheoryModule(nn.Module):
    """
    Understands music theory relationships.

    The list below describes the intended scope of the module. The rule-based
    helpers in this class implement scales/modes, chord functions and common
    progressions; ``forward`` only exposes linear prediction heads over the
    pooled base-model hidden states.

    Knows:
    - Circle of fifths and key relationships
    - Scale degrees and chord functions (I, ii, iii, IV, V, vi, vii°)
    - All modes: Ionian, Dorian, Phrygian, Lydian, Mixolydian, Aeolian, Locrian
    - Interval relationships (major/minor/perfect/augmented/diminished)
    - Chord tensions and extensions (7ths, 9ths, 11ths, 13ths)
    - Common progressions (I-IV-V, ii-V-I, I-V-vi-IV, 12-bar blues, etc)
    - Voice leading principles
    - Modulation techniques
    """

    # Note spellings, in ascending pitch order starting from C. Enharmonic
    # pairs (e.g. C#/Db) share a pitch, so there are only 12 distinct pitch
    # classes; both spellings of some pitches are listed for flexibility.
    CHROMATIC_NOTES = ["C", "C#", "Db", "D", "Eb", "E", "F", "F#", "Gb", "G", "Ab", "A", "Bb", "B"]

    # Scale degrees in major (Ionian)
    SCALE_DEGREES = ["I", "ii", "iii", "IV", "V", "vi", "vii°"]

    # Common chord types
    CHORD_TYPES = [
        "major", "minor", "diminished", "augmented",
        "major7", "minor7", "dominant7", "half-dim7", "dim7",
        "major9", "minor9", "dominant9",
        "sus2", "sus4", "add9", "6", "maj6",
    ]

    # Modes
    MODES = [
        "ionian", "dorian", "phrygian", "lydian",
        "mixolydian", "aeolian", "locrian",
    ]

    # Common progressions (by scale degrees)
    COMMON_PROGRESSIONS = {
        "I-IV-V-I": "Classical cadential",
        "ii-V-I": "Jazz turnaround",
        "I-V-vi-IV": "Pop progression (4-chord)",
        "vi-IV-I-V": "Pop variant",
        "I-vi-ii-V": "Circle progression",
        "I-vi-IV-V": "50s progression",
        "IV-V-I": "Plagal cadence",
        "V-I": "Authentic cadence",
        "12-bar blues": "Blues",
        "i-iv-v": "Minor blues",
    }

    # Explicit degree sequences for progressions whose name is not a
    # dash-separated list of Roman numerals.
    PROGRESSION_DEGREES = {
        "12-bar blues": ["I", "I", "I", "I", "IV", "IV", "I", "I", "V", "IV", "I", "I"],
    }

    # Semitone offsets from the root for each mode.
    MODE_INTERVALS = {
        "ionian": [0, 2, 4, 5, 7, 9, 11],
        "dorian": [0, 2, 3, 5, 7, 9, 10],
        "phrygian": [0, 1, 3, 5, 7, 8, 10],
        "lydian": [0, 2, 4, 6, 7, 9, 11],
        "mixolydian": [0, 2, 4, 5, 7, 9, 10],
        "aeolian": [0, 2, 3, 5, 7, 8, 10],
        "locrian": [0, 1, 3, 5, 6, 8, 10],
    }

    # Accepted key names and their pitch class (C=0).
    NOTE_TO_SEMITONE = {
        "C": 0, "C#": 1, "Db": 1, "D": 2, "D#": 3, "Eb": 3, "E": 4,
        "F": 5, "F#": 6, "Gb": 6, "G": 7, "G#": 8, "Ab": 8, "A": 9,
        "A#": 10, "Bb": 10, "B": 11,
    }

    # Roman numeral (case-insensitive, without "°") -> 0-based scale index.
    ROMAN_NUMERAL_INDEX = {"i": 0, "ii": 1, "iii": 2, "iv": 3, "v": 4, "vi": 5, "vii": 6}

    # Letter names in scale order and the pitch class of each natural note.
    _LETTERS = "CDEFGAB"
    _NATURAL_SEMITONES = {"C": 0, "D": 2, "E": 4, "F": 5, "G": 7, "A": 9, "B": 11}

    # Chord function for each Roman numeral. Lowercase "i"/"iv" cover the
    # minor-key tonic and subdominant (as produced by suggest_progression).
    _CHORD_FUNCTIONS = {
        "I": "tonic", "vi": "tonic", "i": "tonic",
        "ii": "subdominant", "IV": "subdominant", "iv": "subdominant",
        "V": "dominant", "vii°": "dominant", "iii": "dominant",
    }

    def __init__(self, d_model: int):
        """
        Initialize MusicTheoryModule.

        Args:
            d_model: Hidden dimension from base model
        """
        super().__init__()
        self.d_model = d_model

        # Embeddings (not used by forward() in this file).
        # 12 chromatic notes × 4 octave context = 48 total pitch classes
        self.note_embed = nn.Embedding(48, 128)  # 12 notes × 4 octaves
        # NOTE(review): CHORD_TYPES has 17 entries but this table has only 15
        # rows, so indices 15/16 ("6", "maj6") would be out of range. It is
        # left at 15 so existing checkpoints still load; resize to
        # len(CHORD_TYPES) together with a checkpoint migration.
        self.chord_type_embed = nn.Embedding(15, 128)
        self.mode_embed = nn.Embedding(7, 128)
        self.key_embed = nn.Embedding(24, 128)  # 12 major + 12 minor keys

        # Theory relationship head
        self.relationship_proj = nn.Linear(d_model, d_model)

        # Chord function classifier (tonic, subdominant, dominant)
        self.chord_function_head = nn.Linear(d_model, 3)

        # Scale degree predictor
        self.scale_degree_head = nn.Linear(d_model, 7)

        # Interval classifier (unison through 13th)
        self.interval_head = nn.Linear(d_model, 14)

        # Progression predictor (next chord in progression)
        self.progression_head = nn.Linear(d_model, 7)

        # Key detection head
        self.key_detection_head = nn.Linear(d_model, 24)

        # Mode classifier
        self.mode_classifier = nn.Linear(d_model, 7)

    def forward(
        self,
        hidden_states: torch.Tensor,
        query: Optional[str] = None,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through MusicTheoryModule.

        Args:
            hidden_states: Base model hidden states [batch, seq_len, d_model]
            query: Optional text query about music theory (currently unused;
                kept for API compatibility)
            attention_mask: Optional [batch, seq_len] mask where nonzero marks
                real tokens and 0 marks padding. When given, pooling averages
                only over real tokens; when omitted, all positions are averaged.

        Returns:
            Dictionary with theory-related prediction logits.

        Raises:
            ValueError: if hidden_states is not 3-dimensional.
        """
        if hidden_states.dim() != 3:
            raise ValueError(
                f"hidden_states must be [batch, seq_len, d_model], got shape {tuple(hidden_states.shape)}"
            )

        # Pool hidden states -> [batch, d_model]
        if attention_mask is None:
            pooled = hidden_states.mean(dim=1)
        else:
            mask = attention_mask.to(hidden_states.dtype).unsqueeze(-1)  # [batch, seq, 1]
            summed = (hidden_states * mask).sum(dim=1)
            # clamp avoids division by zero for fully padded rows
            counts = mask.sum(dim=1).clamp(min=1.0)
            pooled = summed / counts

        return {
            "chord_function_logits": self.chord_function_head(pooled),  # [batch, 3]
            "scale_degree_logits": self.scale_degree_head(pooled),  # [batch, 7]
            "interval_logits": self.interval_head(pooled),  # [batch, 14]
            "progression_logits": self.progression_head(pooled),  # [batch, 7]
            "key_logits": self.key_detection_head(pooled),  # [batch, 24]
            "mode_logits": self.mode_classifier(pooled),  # [batch, 7]
        }

    def get_chord_function(self, scale_degree: str) -> str:
        """
        Get chord function (tonic, subdominant, dominant).

        Args:
            scale_degree: Roman numeral (I, ii, V, i, iv, etc)

        Returns:
            "tonic", "subdominant", "dominant", or "unknown" for numerals
            not in the table.
        """
        return self._CHORD_FUNCTIONS.get(scale_degree, "unknown")

    def get_scale_from_key(self, key: str, mode: str = "ionian") -> List[str]:
        """
        Generate scale notes from key and mode.

        Each degree uses the next letter name after the previous one, so every
        letter appears exactly once. For example, E major is spelled
        E F# G# A B C# D#, and Gb major is spelled Gb Ab Bb Cb Db Eb F.
        Sharp-heavy keys may therefore yield double sharps (e.g. "F##").

        Args:
            key: Root note (one of NOTE_TO_SEMITONE's keys, e.g. C, F#, Bb)
            mode: Mode name (ionian, dorian, etc)

        Returns:
            List of notes in the scale

        Raises:
            ValueError: for an unknown mode or key.
        """
        if mode not in self.MODE_INTERVALS:
            raise ValueError(f"Unknown mode: {mode}")

        root_semitone = self.NOTE_TO_SEMITONE.get(key)
        if root_semitone is None:
            raise ValueError(f"Unknown key: {key}")

        root_letter_index = self._LETTERS.index(key[0])
        scale = []
        for degree, interval in enumerate(self.MODE_INTERVALS[mode]):
            letter = self._LETTERS[(root_letter_index + degree) % 7]
            target = (root_semitone + interval) % 12
            # Signed distance from the natural letter to the target pitch,
            # wrapped into -6..+6 so e.g. C->B is -1 (Cb), not +11.
            offset = (target - self._NATURAL_SEMITONES[letter]) % 12
            if offset > 6:
                offset -= 12
            accidental = "#" * offset if offset > 0 else "b" * -offset
            scale.append(letter + accidental)

        return scale

    def _semitone_to_note(self, semitone: int) -> str:
        """Convert a pitch class (0-11, C=0) to a note name using a fixed spelling."""
        semitone_to_note = {
            0: "C", 1: "C#", 2: "D", 3: "Eb", 4: "E", 5: "F",
            6: "F#", 7: "G", 8: "Ab", 9: "A", 10: "Bb", 11: "B",
        }
        return semitone_to_note[semitone]

    def _spelled_note_to_semitone(self, note: str) -> Optional[int]:
        """Return the pitch class (C=0) of a spelled note such as "F#", "Cb" or "F##"; None if unparseable."""
        if not note or note[0] not in self._NATURAL_SEMITONES:
            return None
        accidentals = note[1:]
        # Accidentals must be all sharps or all flats (or absent).
        if accidentals.strip("#") and accidentals.strip("b"):
            return None
        shift = accidentals.count("#") - accidentals.count("b")
        return (self._NATURAL_SEMITONES[note[0]] + shift) % 12

    def _parse_roman_numeral(self, numeral: str) -> Optional[Tuple[int, str]]:
        """Map a Roman numeral to (0-based scale index, chord quality); None if unrecognized."""
        diminished = numeral.endswith("°")
        base = numeral[:-1] if diminished else numeral
        idx = self.ROMAN_NUMERAL_INDEX.get(base.lower())
        if idx is None:
            return None
        if diminished:
            return idx, "diminished"
        if base.isupper():
            return idx, "major"
        if base.islower():
            return idx, "minor"
        return None  # mixed case such as "Iv"

    def get_progression_chords(
        self,
        progression_name: str,
        key: str = "C",
    ) -> List[Tuple[str, str]]:
        """
        Get chord progression as list of (degree, chord).

        Chord roots are taken from the major (ionian) scale of ``key``; the
        numeral's case sets the quality (upper = major, lower = minor,
        trailing "°" = diminished). Lowercase minor-key numerals such as
        "i-iv-v" therefore resolve to minor chords on the same roots.

        Args:
            progression_name: Name of progression (a key of COMMON_PROGRESSIONS)
            key: Root key

        Returns:
            List of (scale_degree, chord) tuples, e.g. ("vi", "A minor")

        Raises:
            ValueError: for an unknown progression name or key.
        """
        if progression_name not in self.COMMON_PROGRESSIONS:
            raise ValueError(f"Unknown progression: {progression_name}")

        degrees = self.PROGRESSION_DEGREES.get(progression_name)
        if degrees is None:
            degrees = progression_name.split("-")

        scale = self.get_scale_from_key(key, mode="ionian")

        chords = []
        for degree in degrees:
            parsed = self._parse_roman_numeral(degree)
            if parsed is None:
                continue
            idx, quality = parsed
            chords.append((degree, f"{scale[idx]} {quality}"))

        return chords

    def suggest_progression(
        self,
        mood: str = "happy",
        genre: str = "pop",
        num_chords: int = 4,
    ) -> List[str]:
        """
        Suggest chord progression based on mood and genre.

        Args:
            mood: Emotional mood (happy, sad, melancholy, tense, dramatic, jazzy)
            genre: Music genre (only consulted for the "happy" mood)
            num_chords: Number of chords; only honored (3 or 4) for happy pop

        Returns:
            List of Roman numerals
        """
        # Simple rule-based suggestions
        if mood == "happy" and genre == "pop":
            if num_chords == 4:
                return ["I", "V", "vi", "IV"]
            if num_chords == 3:
                return ["I", "IV", "V"]
            # Other lengths fall through to the default below.
        elif mood in ("sad", "melancholy"):
            return ["vi", "IV", "I", "V"]
        elif mood in ("tense", "dramatic"):
            return ["i", "iv", "V", "i"]  # Minor with dominant
        elif mood == "jazzy":
            return ["ii", "V", "I", "vi"]

        return ["I", "IV", "V", "I"]  # Default

    def validate_progression(
        self,
        progression: List[str],
        key: str = "C",
    ) -> Tuple[bool, List[str]]:
        """
        Validate chord progression for theoretical correctness.

        Only entries written as "<root> <quality>" (containing a space) are
        checked: their root's pitch class must belong to the major scale of
        ``key``, so enharmonic spellings are accepted (e.g. "A# major" in F).
        Entries without a space (e.g. Roman numerals) are not checked.

        Args:
            progression: List of Roman numerals or chord names
            key: Key center

        Returns:
            (is_valid, issues)

        Raises:
            ValueError: for an unknown key.
        """
        issues = []

        scale = self.get_scale_from_key(key, mode="ionian")
        scale_pitches = {self._spelled_note_to_semitone(note) for note in scale}

        for chord in progression:
            if " " not in chord:
                continue
            root = chord.split(" ")[0]
            # An unparseable root yields None, which is never in scale_pitches.
            if self._spelled_note_to_semitone(root) not in scale_pitches:
                issues.append(f"Chord {chord} has root {root} not in key {key}")

        return len(issues) == 0, issues


def test_music_theory_module():
    """Smoke-test the MusicTheoryModule and print its outputs."""
    d_model = 4096
    module = MusicTheoryModule(d_model=d_model)

    # Forward pass on random input
    batch_size = 2
    seq_len = 10
    hidden_states = torch.randn(batch_size, seq_len, d_model)
    outputs = module.forward(hidden_states)

    print("Music Theory Module outputs:")
    for key, value in outputs.items():
        print(f"  {key}: {value.shape}")

    # Test scale generation
    print("\nScale from C ionian:")
    scale = module.get_scale_from_key("C", "ionian")
    print(f"  {scale}")

    print("\nScale from A dorian:")
    scale = module.get_scale_from_key("A", "dorian")
    print(f"  {scale}")

    # Test progression
    print("\nProgression I-V-vi-IV in C:")
    chords = module.get_progression_chords("I-V-vi-IV", "C")
    for degree, chord in chords:
        print(f"  {degree}: {chord}")

    # Test suggestion
    print("\nSuggested progression (happy, pop, 4 chords):")
    prog = module.suggest_progression(mood="happy", genre="pop", num_chords=4)
    print(f"  {prog}")

    # Test validation
    print("\nValidate progression [I, IV, V, I] in C:")
    valid, issues = module.validate_progression(["I", "IV", "V", "I"], "C")
    print(f"  Valid: {valid}")
    if issues:
        print(f"  Issues: {issues}")

    print("\nMusic Theory Module test complete!")


if __name__ == "__main__":
    test_music_theory_module()