|
|
| from enum import Enum, auto |
| from typing import Optional, Dict, Any, Tuple, List, Callable, Set |
| from loguru import logger |
| from transformers import AutoTokenizer |
| from transformers.generation.logits_process import LogitsProcessor |
| import os |
| import torch |
| from acestep.constants import ( |
| VALID_LANGUAGES, |
| KEYSCALE_NOTES, |
| KEYSCALE_ACCIDENTALS, |
| KEYSCALE_MODES, |
| VALID_KEYSCALES, |
| BPM_MIN, |
| BPM_MAX, |
| DURATION_MIN, |
| DURATION_MAX, |
| VALID_TIME_SIGNATURES, |
| ) |
|
|
|
|
| |
| |
| |
class FSMState(Enum):
    """Finite State Machine states for metadata generation"""
    # Opening <think> tag and the newline that follows it.
    THINK_TAG = auto()
    NEWLINE_AFTER_THINK = auto()
    # bpm field: literal "bpm:" name, numeric value, trailing newline.
    BPM_NAME = auto()
    BPM_VALUE = auto()
    NEWLINE_AFTER_BPM = auto()
    # caption field (free text; may be skipped — see skip_caption).
    CAPTION_NAME = auto()
    CAPTION_VALUE = auto()
    # duration field (seconds, numeric).
    DURATION_NAME = auto()
    DURATION_VALUE = auto()
    NEWLINE_AFTER_DURATION = auto()
    # genres field (vocabulary-constrained; may be skipped — see skip_genres).
    GENRES_NAME = auto()
    GENRES_VALUE = auto()
    NEWLINE_AFTER_GENRES = auto()
    # keyscale field (e.g. "G major").
    KEYSCALE_NAME = auto()
    KEYSCALE_VALUE = auto()
    NEWLINE_AFTER_KEYSCALE = auto()
    # language field (ISO-style code; may be skipped — see skip_language).
    LANGUAGE_NAME = auto()
    LANGUAGE_VALUE = auto()
    # time signature field (numeric).
    TIMESIG_NAME = auto()
    TIMESIG_VALUE = auto()
    NEWLINE_AFTER_TIMESIG = auto()
    # Closing </think> tag, then audio-codes generation, then done.
    THINK_END_TAG = auto()
    CODES_GENERATION = auto()
    COMPLETED = auto()
|
|
|
|
| class MetadataConstrainedLogitsProcessor(LogitsProcessor): |
| """ |
| FSM-driven LogitsProcessor that constrains generation to produce valid metadata. |
| |
| This processor enforces the following format: |
| <think> |
| bpm: [30-300] |
| caption: [text without code blocks, ends with period + newline] |
| duration: [10-600] |
| keyscale: [A-G][#/♭]? [major/minor] |
| language: [en/zh/ja/ko/es/fr/de/uk/ru/...] |
| timesignature: [2/3/4/6] |
| </think> |
| |
| It uses token masking (setting invalid token logits to -inf) to enforce constraints. |
| For numeric fields, it uses early-blocking to prevent out-of-range values. |
| For field transitions (e.g., end of numeric value), it compares P(newline) vs P(digit). |
| For caption field, it blocks code blocks and newlines, and only transitions when |
| the previous token was a period and newline has the highest probability. |
| """ |
| |
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        enabled: bool = True,
        debug: bool = False,
        genres_vocab_path: Optional[str] = None,
        skip_genres: bool = True,
    ):
        """
        Initialize the constrained logits processor.

        This processor should be initialized once when loading the LLM and reused
        for all generations.

        Args:
            tokenizer: The tokenizer to use for encoding/decoding
            enabled: Whether to enable constrained decoding
            debug: Whether to print debug information
            genres_vocab_path: Optional path to the genres vocabulary file;
                defaults to "genres_vocab.txt" next to this module.
            skip_genres: Whether to skip the genres field during generation.
        """
        self.tokenizer = tokenizer
        self.enabled = enabled
        self.debug = debug
        # Per-field skip flags; _build_state_transitions honors them.
        self.skip_genres = skip_genres
        self.skip_caption = False
        self.skip_language = False
        # Caption text, when one is available to the processor.
        self.caption: Optional[str] = None

        # Fields supplied by the caller; a non-None value is emitted verbatim
        # instead of being generated (see set_user_metadata).
        self.user_provided_metadata: Dict[str, Optional[str]] = {
            "bpm": None,
            "caption": None,
            "duration": None,
            "keyscale": None,
            "language": None,
            "timesignature": None,
            "genres": None,
        }

        # Optional per-phase sampling temperatures (None = leave unchanged).
        self.metadata_temperature: Optional[float] = None
        self.codes_temperature: Optional[float] = None

        # Target duration / codes budget and running count for codes phase.
        self.target_duration: Optional[float] = None
        self.target_codes: Optional[int] = None
        self.codes_count: int = 0

        # When True, generation stops right after </think> (see setter).
        self.stop_at_reasoning: bool = False

        # One of "cot", "codes", "understand" (see set_generation_phase).
        self.generation_phase: str = "cot"

        # FSM cursor: current state, offset within the state's fixed string,
        # and the value text / token IDs accumulated for the current field.
        self.state = FSMState.THINK_TAG
        self.position_in_state = 0
        self.accumulated_value = ""
        self.accumulated_token_ids: List[int] = []

        # Caption-specific generation bookkeeping.
        self.caption_after_newline = False
        self.caption_token_count = 0
        self.caption_ending = False
        self.pending_field_name = ""

        # Queue of tokens to force-emit for a user-provided field value.
        self.user_field_token_queue: List[int] = []
        self.current_user_field: Optional[str] = None

        # Token-ID lookups (digits, newline, notes, ...) must exist before
        # the prefix trees below are built.
        self._precompute_tokens()

        # Genres vocabulary state (hot-reloadable from disk by mtime).
        self.genres_vocab_path = genres_vocab_path or os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "genres_vocab.txt"
        )
        self.genres_vocab: List[str] = []
        self.genres_vocab_mtime: float = 0.0
        self.genres_trie: Dict = {}
        self.caption_genres_trie: Dict = {}
        self.caption_matched_genres: List[str] = []

        # char -> candidate token IDs (filled by _precompute_char_token_mapping).
        self._char_to_tokens: Dict[str, set] = {}

        self._precompute_char_token_mapping()

        # Numeric field ranges / valid sets from acestep.constants.
        self.field_specs = {
            "bpm": {"min": BPM_MIN, "max": BPM_MAX},
            "duration": {"min": DURATION_MIN, "max": DURATION_MAX},
            "timesignature": {"valid_values": VALID_TIME_SIGNATURES},
        }

        # Enumerate every legal value string for early-blocking of digits.
        self.valid_bpm_values = [str(v) for v in range(self.field_specs["bpm"]["min"], self.field_specs["bpm"]["max"] + 1)]
        self.valid_duration_values = [str(v) for v in range(self.field_specs["duration"]["min"], self.field_specs["duration"]["max"] + 1)]
        self.valid_timesig_values = [str(v) for v in self.field_specs["timesignature"]["valid_values"]]

        # Token-level prefix trees (keyed by tuples of token IDs).
        self.keyscale_prefix_tree = self._build_keyscale_prefix_tree()

        self.bpm_prefix_tree = self._build_numeric_prefix_tree(
            self.valid_bpm_values,
            context_prefix_for_matching="bpm:",
            context_prefix_for_tokenization="bpm: "
        )
        self.duration_prefix_tree = self._build_numeric_prefix_tree(
            self.valid_duration_values,
            context_prefix_for_matching="duration:",
            context_prefix_for_tokenization="duration: "
        )
        self.timesig_prefix_tree = self._build_numeric_prefix_tree(
            self.valid_timesig_values,
            context_prefix_for_matching="timesignature:",
            context_prefix_for_tokenization="timesignature: "
        )

        self.language_prefix_tree = self._build_language_prefix_tree()

        self._load_genres_vocab()

        # Literal strings the FSM emits verbatim in each fixed-text state.
        self.fixed_strings = {
            FSMState.THINK_TAG: "<think>",
            FSMState.NEWLINE_AFTER_THINK: "\n",
            FSMState.BPM_NAME: "bpm:",
            FSMState.CAPTION_NAME: "caption:",
            FSMState.DURATION_NAME: "duration:",
            FSMState.GENRES_NAME: "genres:",
            FSMState.KEYSCALE_NAME: "keyscale:",
            FSMState.LANGUAGE_NAME: "language:",
            FSMState.TIMESIG_NAME: "timesignature:",
            FSMState.THINK_END_TAG: "</think>",
        }

        # Build the state transition map (depends on the skip flags).
        self._build_state_transitions()
| |
| def _get_next_field_state(self, current_field: str) -> Optional[FSMState]: |
| """ |
| Get the next field state. Always returns the next field's NAME state, |
| even if the field is user-provided (we still need to generate the field name). |
| |
| Args: |
| current_field: Current field name ("bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature") |
| |
| Returns: |
| Next FSMState (NAME state of next field), or THINK_END_TAG if no more fields |
| """ |
| |
| |
| field_order = ["bpm", "caption", "duration", "genres", "keyscale", "language", "timesignature"] |
| field_to_state = { |
| "bpm": FSMState.BPM_NAME, |
| "caption": FSMState.CAPTION_NAME, |
| "duration": FSMState.DURATION_NAME, |
| "genres": FSMState.GENRES_NAME, |
| "keyscale": FSMState.KEYSCALE_NAME, |
| "language": FSMState.LANGUAGE_NAME, |
| "timesignature": FSMState.TIMESIG_NAME, |
| } |
| |
| try: |
| current_idx = field_order.index(current_field) |
| except ValueError: |
| return FSMState.THINK_END_TAG |
| |
| |
| for i in range(current_idx + 1, len(field_order)): |
| field = field_order[i] |
| |
| |
| if field == "genres" and self.skip_genres: |
| continue |
| if field == "caption" and self.skip_caption: |
| continue |
| if field == "language" and self.skip_language: |
| continue |
| |
| |
| return field_to_state[field] |
| |
| |
| return FSMState.THINK_END_TAG |
| |
| def _build_state_transitions(self): |
| """Build state transition map based on user-provided metadata.""" |
| self.next_state = { |
| FSMState.THINK_TAG: FSMState.NEWLINE_AFTER_THINK, |
| FSMState.NEWLINE_AFTER_THINK: FSMState.BPM_NAME, |
| FSMState.THINK_END_TAG: FSMState.CODES_GENERATION, |
| FSMState.CODES_GENERATION: FSMState.COMPLETED, |
| } |
| |
| |
| |
| |
| |
| self.next_state[FSMState.BPM_NAME] = FSMState.BPM_VALUE |
| self.next_state[FSMState.BPM_VALUE] = self._get_next_field_state("bpm") |
| |
| |
| if not self.skip_caption: |
| self.next_state[FSMState.CAPTION_NAME] = FSMState.CAPTION_VALUE |
| self.next_state[FSMState.CAPTION_VALUE] = self._get_next_field_state("caption") |
| |
| |
| self.next_state[FSMState.DURATION_NAME] = FSMState.DURATION_VALUE |
| self.next_state[FSMState.DURATION_VALUE] = self._get_next_field_state("duration") |
|
|
| |
| if not self.skip_genres: |
| self.next_state[FSMState.GENRES_NAME] = FSMState.GENRES_VALUE |
| self.next_state[FSMState.GENRES_VALUE] = self._get_next_field_state("genres") |
| |
| |
| self.next_state[FSMState.KEYSCALE_NAME] = FSMState.KEYSCALE_VALUE |
| self.next_state[FSMState.KEYSCALE_VALUE] = self._get_next_field_state("keyscale") |
| |
| |
| if not self.skip_language: |
| self.next_state[FSMState.LANGUAGE_NAME] = FSMState.LANGUAGE_VALUE |
| self.next_state[FSMState.LANGUAGE_VALUE] = self._get_next_field_state("language") |
| |
| |
| self.next_state[FSMState.TIMESIG_NAME] = FSMState.TIMESIG_VALUE |
| self.next_state[FSMState.TIMESIG_VALUE] = FSMState.THINK_END_TAG |
|
|
| def set_skip_genres(self, skip: bool): |
| """Set whether to skip genres generation and rebuild state transitions.""" |
| self.skip_genres = skip |
| self._build_state_transitions() |
| |
| def set_skip_caption(self, skip: bool): |
| """Set whether to skip caption generation and rebuild state transitions.""" |
| self.skip_caption = skip |
| self._build_state_transitions() |
| |
| def set_skip_language(self, skip: bool): |
| """Set whether to skip language generation and rebuild state transitions.""" |
| self.skip_language = skip |
| self._build_state_transitions() |
| |
| @staticmethod |
| def postprocess_caption(caption: str) -> str: |
| """ |
| Post-process caption to remove YAML multi-line formatting. |
| Converts YAML-style multi-line text (with newlines and leading spaces) |
| to a single-line string. |
| |
| Example: |
| Input: "An emotional ballad.\\n The track opens with piano.\\n More text." |
| Output: "An emotional ballad. The track opens with piano. More text." |
| |
| Args: |
| caption: Raw caption text with possible YAML formatting |
| |
| Returns: |
| Clean single-line caption |
| """ |
| if not caption: |
| return caption |
| |
| |
| lines = caption.split('\n') |
| |
| |
| cleaned_lines = [] |
| for line in lines: |
| stripped = line.strip() |
| if stripped: |
| cleaned_lines.append(stripped) |
| |
| |
| return ' '.join(cleaned_lines) |
| |
| def set_stop_at_reasoning(self, stop: bool): |
| """ |
| Set whether to stop generation after </think> tag. |
| |
| Args: |
| stop: If True, generation will stop immediately after </think> tag is generated. |
| If False, generation continues to codes generation phase. |
| """ |
| self.stop_at_reasoning = stop |
| |
| def set_generation_phase(self, phase: str): |
| """ |
| Set the generation phase. |
| |
| Args: |
| phase: "cot" for CoT metadata generation, "codes" for audio codes generation, |
| or "understand" for audio understanding (codes → metadata + lyrics). |
| When phase is "codes" and the input prompt already contains </think>, |
| the FSM will skip metadata generation and go directly to codes generation. |
| When phase is "understand", generate CoT metadata then free-form lyrics. |
| """ |
| if phase not in ("cot", "codes", "understand"): |
| raise ValueError(f"Invalid generation phase: {phase!r}. Must be 'cot', 'codes', or 'understand'") |
| self.generation_phase = phase |
| |
| def set_user_metadata(self, metadata: Optional[Dict[str, Optional[str]]] = None): |
| """ |
| Set user-provided metadata fields. Fields that are provided will be used directly |
| instead of generating. Fields that are None will be generated. |
| |
| Args: |
| metadata: Dictionary with optional fields: |
| - "bpm": Optional[str] - e.g., "120" |
| - "caption": Optional[str] - e.g., "A melodic piano piece..." |
| - "duration": Optional[str] - e.g., "234" |
| - "keyscale": Optional[str] - e.g., "G major" |
| - "language": Optional[str] - e.g., "en" |
| - "timesignature": Optional[str] - e.g., "4" |
| - "genres": Optional[str] - e.g., "Pop Rock" |
| If None, clears all user-provided metadata. |
| """ |
| if metadata is None: |
| metadata = {} |
| |
| |
| for field in ["bpm", "caption", "duration", "keyscale", "language", "timesignature", "genres"]: |
| if field in metadata: |
| self.user_provided_metadata[field] = metadata[field] |
| else: |
| self.user_provided_metadata[field] = None |
| |
| |
| self._build_state_transitions() |
| |
| if self.debug: |
| provided_fields = [k for k, v in self.user_provided_metadata.items() if v is not None] |
| if provided_fields: |
| logger.debug(f"User provided metadata fields: {provided_fields}") |
| else: |
| logger.debug("No user-provided metadata, all fields will be generated") |
| |
    def _precompute_tokens(self):
        """Pre-compute commonly used token IDs for efficiency.

        Every lookup takes the LAST token of the encoding — presumably because
        the tokenizer may prepend prefix/space tokens and the final token
        carries the literal character(s).  # NOTE(review): confirm for the target tokenizer.
        """
        # Token ID for each single digit 0-9.
        self.digit_tokens = {}
        for d in range(10):
            tokens = self.tokenizer.encode(str(d), add_special_tokens=False)
            if tokens:
                self.digit_tokens[d] = tokens[-1]

        # Newline token (field terminator in the metadata block).
        newline_tokens = self.tokenizer.encode("\n", add_special_tokens=False)
        self.newline_token = newline_tokens[-1] if newline_tokens else None

        # Note letters (A-G) for keyscale values.
        self.note_tokens = {}
        for note in KEYSCALE_NOTES:
            tokens = self.tokenizer.encode(note, add_special_tokens=False)
            if tokens:
                self.note_tokens[note] = tokens[-1]

        # Accidentals: sharp ("#", "♯") and flat ("b", "♭") variants.
        self.sharp_tokens = []
        for s in ["#", "♯"]:
            tokens = self.tokenizer.encode(s, add_special_tokens=False)
            if tokens:
                self.sharp_tokens.append(tokens[-1])

        self.flat_tokens = []
        for f in ["b", "♭"]:
            tokens = self.tokenizer.encode(f, add_special_tokens=False)
            if tokens:
                self.flat_tokens.append(tokens[-1])

        # Single space token.
        space_tokens = self.tokenizer.encode(" ", add_special_tokens=False)
        self.space_token = space_tokens[-1] if space_tokens else None

        # Tokens that can start "major"/"minor".
        # NOTE(review): prefix.lower() == "m" is True for BOTH "m" and "M",
        # so both lists end up containing both cases (identical contents).
        # Looks like intended case-insensitivity — confirm.
        self.major_start_tokens = []
        self.minor_start_tokens = []
        for prefix in ["m", "M"]:
            tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
            if tokens:
                if prefix.lower() == "m":
                    self.minor_start_tokens.append(tokens[-1])
                self.major_start_tokens.append(tokens[-1])

        self.vocab_size = len(self.tokenizer)

        # Punctuation tokens used by caption/genres handling.
        comma_tokens = self.tokenizer.encode(",", add_special_tokens=False)
        self.comma_token = comma_tokens[-1] if comma_tokens else None

        self.eos_token_id = self.tokenizer.eos_token_id

        period_tokens = self.tokenizer.encode(".", add_special_tokens=False)
        self.period_token = period_tokens[-1] if period_tokens else None

        # Backtick — captions must not contain code blocks (class docstring).
        backtick_tokens = self.tokenizer.encode("`", add_special_tokens=False)
        self.backtick_token = backtick_tokens[-1] if backtick_tokens else None

        self.valid_languages = VALID_LANGUAGES

        # Audio code token IDs and their additive score masks.
        self.audio_code_token_ids: Set[int] = set()
        self._precompute_audio_code_tokens()

        self.audio_code_mask: Optional[torch.Tensor] = None

        self.non_audio_code_mask: Optional[torch.Tensor] = None
        self._build_audio_code_mask()

        # Copy so later mutation cannot corrupt the shared constant.
        self.valid_keyscales = VALID_KEYSCALES.copy()
| |
| |
| |
| |
| def _precompute_audio_code_tokens(self): |
| """ |
| Precompute audio code token IDs (tokens matching <|audio_code_\\d+|>). |
| These tokens should be blocked during caption generation. |
| """ |
| import re |
| audio_code_pattern = re.compile(r'^<\|audio_code_\d+\|>$') |
| |
| |
| for token_id in range(self.vocab_size): |
| try: |
| token_text = self.tokenizer.decode([token_id]) |
| if audio_code_pattern.match(token_text): |
| self.audio_code_token_ids.add(token_id) |
| except Exception: |
| continue |
| |
| if self.debug: |
| logger.debug(f"Found {len(self.audio_code_token_ids)} audio code tokens") |
| |
| def _build_audio_code_mask(self): |
| """ |
| Build a precomputed mask tensor for blocking audio code tokens. |
| This mask can be added to scores in O(1) time instead of O(n) loop. |
| |
| The mask is [1, vocab_size] tensor with -inf at audio code token positions. |
| |
| Also builds the inverse mask (non_audio_code_mask) for CODES_GENERATION state, |
| which blocks all non-audio-code tokens. |
| """ |
| if not self.audio_code_token_ids: |
| self.audio_code_mask = None |
| self.non_audio_code_mask = None |
| return |
| |
| |
| |
| mask = torch.zeros(1, self.vocab_size, dtype=torch.float32) |
| |
| |
| audio_code_indices = list(self.audio_code_token_ids) |
| |
| |
| mask[0, audio_code_indices] = float('-inf') |
| |
| self.audio_code_mask = mask |
| |
| |
| |
| inverse_mask = torch.full((1, self.vocab_size), float('-inf'), dtype=torch.float32) |
| inverse_mask[0, audio_code_indices] = 0 |
| |
| |
| if self.eos_token_id is not None: |
| inverse_mask[0, self.eos_token_id] = 0 |
| |
| self.non_audio_code_mask = inverse_mask |
| |
| if self.debug: |
| logger.debug(f"Built audio code masks for {len(self.audio_code_token_ids)} tokens") |
|
|
| def _apply_whitelist_inplace(self, scores: torch.Tensor, allowed_tokens: List[int]) -> None: |
| """ |
| Apply whitelist constraint inplace: only allow specified tokens, block all others. |
| |
| This is more efficient than creating a mask tensor because: |
| 1. No memory allocation for mask |
| 2. No tensor addition operation |
| |
| Args: |
| scores: [1, vocab_size] scores tensor to modify inplace |
| allowed_tokens: List of token IDs to allow (all others will be set to -inf) |
| """ |
| if not allowed_tokens: |
| |
| scores.fill_(float('-inf')) |
| return |
| |
| |
| allowed_indices = torch.tensor(allowed_tokens, device=scores.device, dtype=torch.long) |
| saved_values = scores[0, allowed_indices].clone() |
| |
| |
| scores.fill_(float('-inf')) |
| |
| |
| scores[0, allowed_indices] = saved_values |
|
|
    def _build_keyscale_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]:
        """
        Build keyscale prefix to allowed tokens mapping based on ACTUAL tokenization.

        IMPORTANT: Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches.

        CRITICAL FIX: The tokenizer may merge the context's trailing space into the next token.
        For example:
        - "keyscale: " tokenizes to [10563, 2246, 25, 220] -> ['keys', 'cale', ':', ' ']
        - "keyscale: G major" tokenizes to [10563, 2246, 25, 479, 3598] -> ['keys', 'cale', ':', ' G', ' major']
        The space ' ' (220) is merged into ' G' (479), so we can't use simple slicing.

        Strategy:
        1. For each keyscale (e.g., "G major"), encode the FULL string "keyscale: G major"
        2. Tokenize to get: [10563, 2246, 25, 479, 3598] -> ['keys', 'cale', ':', ' G', ' major']
        3. Find where context prefix ends by matching token sequences (handling space merging)
        4. Extract keyscale value tokens: [479, 3598] (for "G major")
        5. Build prefix tree using token ID sequences as keys

        This ensures we get the exact tokenization that occurs during generation.

        Returns:
            Dict mapping token-ID-tuple prefix -> set of token IDs allowed next.
        """
        prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {}

        # The FSM emits "keyscale:" without a trailing space; the space ends
        # up merged into the first value token, hence the two context forms.
        context_prefix_for_matching = "keyscale:"
        context_prefix_for_tokenization = "keyscale: "

        context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False)

        if self.debug:
            context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids]
            logger.debug(f"Context for matching 'keyscale:' tokenizes to {context_token_ids} -> {context_tokens_str}")

        for keyscale in self.valid_keyscales:
            # Tokenize the value WITH its context so merges match generation.
            full_text = context_prefix_for_tokenization + keyscale
            full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False)

            # Locate where the context tokens end in the full encoding.
            context_end_idx = None

            if len(full_token_ids) >= len(context_token_ids):
                if full_token_ids[:len(context_token_ids)] == context_token_ids:
                    context_end_idx = len(context_token_ids)

            if context_end_idx is None:
                if self.debug:
                    logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping")
                continue

            # Tokens that make up just the keyscale value.
            keyscale_token_ids = full_token_ids[context_end_idx:]

            if not keyscale_token_ids:
                if self.debug:
                    logger.warning(f"No tokens extracted for keyscale '{keyscale}', skipping")
                continue

            # Sanity check: the first value token should decode to a note
            # letter (A-G), possibly with a leading space.
            first_token_id = keyscale_token_ids[0]
            first_token_str = self.tokenizer.decode([first_token_id])

            first_char = first_token_str.lstrip()[0].upper() if first_token_str.lstrip() else ""
            if first_char not in "ABCDEFG":
                if self.debug:
                    logger.debug(f"Skipping keyscale '{keyscale}': first token is '{first_token_str}' (id={first_token_id}), not a note")
                continue

            # Register every prefix of the token sequence; the full sequence
            # (i == len) additionally allows the newline terminator.
            for i in range(len(keyscale_token_ids) + 1):
                token_prefix = tuple(keyscale_token_ids[:i])

                if token_prefix not in prefix_to_tokens:
                    prefix_to_tokens[token_prefix] = set()

                if i < len(keyscale_token_ids):
                    next_token_id = keyscale_token_ids[i]
                    prefix_to_tokens[token_prefix].add(next_token_id)
                else:
                    # NOTE(review): truthiness check — a newline token with
                    # ID 0 would be treated as missing; confirm intended.
                    if self.newline_token:
                        prefix_to_tokens[token_prefix].add(self.newline_token)

        if self.debug:
            logger.debug(f"Built keyscale prefix tree with {len(prefix_to_tokens)} token sequence prefixes")

            empty_prefix = tuple()
            if empty_prefix in prefix_to_tokens:
                first_tokens = prefix_to_tokens[empty_prefix]
                decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)]
                logger.debug(f"First tokens allowed (empty prefix): {decoded_first}")

        return prefix_to_tokens
| |
| def _build_numeric_prefix_tree( |
| self, |
| valid_values: List[str], |
| context_prefix_for_matching: str = "", |
| context_prefix_for_tokenization: str = "" |
| ) -> Dict[Tuple[int, ...], Set[int]]: |
| """ |
| Build prefix tree for numeric field based on actual tokenization with context. |
| |
| IMPORTANT: Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches. |
| |
| Args: |
| valid_values: List of valid numeric strings (e.g., ["30", "31", ..., "300"]) |
| context_prefix_for_matching: Context string that state machine generates (e.g., "bpm:") - no space |
| context_prefix_for_tokenization: Context string for tokenization (e.g., "bpm: ") - with space |
| |
| Returns: |
| Dict mapping token ID sequence prefix -> set of allowed token IDs |
| """ |
| prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {} |
| |
| |
| context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False) if context_prefix_for_matching else [] |
| |
| |
| for value_str in valid_values: |
| |
| full_text = context_prefix_for_tokenization + value_str |
| token_ids = self.tokenizer.encode(full_text, add_special_tokens=False) |
| |
| |
| context_end_idx = None |
| if len(token_ids) >= len(context_token_ids): |
| if token_ids[:len(context_token_ids)] == context_token_ids: |
| context_end_idx = len(context_token_ids) |
| |
| if context_end_idx is None: |
| if self.debug: |
| logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping") |
| continue |
| |
| |
| value_token_ids = token_ids[context_end_idx:] |
| |
| |
| for i in range(len(value_token_ids) + 1): |
| |
| token_prefix = tuple(value_token_ids[:i]) |
| |
| if token_prefix not in prefix_to_tokens: |
| prefix_to_tokens[token_prefix] = set() |
| |
| if i < len(value_token_ids): |
| |
| next_token_id = value_token_ids[i] |
| prefix_to_tokens[token_prefix].add(next_token_id) |
| else: |
| |
| if self.newline_token: |
| prefix_to_tokens[token_prefix].add(self.newline_token) |
| |
| return prefix_to_tokens |
| |
| def _build_language_prefix_tree(self) -> Dict[Tuple[int, ...], Set[int]]: |
| """ |
| Build language prefix to allowed tokens mapping based on ACTUAL tokenization. |
| Similar to keyscale prefix tree but for language codes. |
| |
| Uses token ID sequences as keys, NOT strings, to avoid tokenization mismatches. |
| """ |
| prefix_to_tokens: Dict[Tuple[int, ...], Set[int]] = {} |
| |
| context_prefix_for_matching = "language:" |
| context_prefix_for_tokenization = "language: " |
| |
| context_token_ids = self.tokenizer.encode(context_prefix_for_matching, add_special_tokens=False) |
| |
| if self.debug: |
| context_tokens_str = [self.tokenizer.decode([t]) for t in context_token_ids] |
| logger.debug(f"Context for matching 'language:' tokenizes to {context_token_ids} -> {context_tokens_str}") |
| |
| for lang in self.valid_languages: |
| full_text = context_prefix_for_tokenization + lang |
| full_token_ids = self.tokenizer.encode(full_text, add_special_tokens=False) |
| |
| context_end_idx = None |
| if len(full_token_ids) >= len(context_token_ids): |
| if full_token_ids[:len(context_token_ids)] == context_token_ids: |
| context_end_idx = len(context_token_ids) |
| |
| if context_end_idx is None: |
| if self.debug: |
| logger.warning(f"Could not find context prefix in full tokenization of '{full_text}', skipping") |
| continue |
| |
| lang_token_ids = full_token_ids[context_end_idx:] |
| |
| if not lang_token_ids: |
| if self.debug: |
| logger.warning(f"No tokens extracted for language '{lang}', skipping") |
| continue |
| |
| for i in range(len(lang_token_ids) + 1): |
| token_prefix = tuple(lang_token_ids[:i]) |
| |
| if token_prefix not in prefix_to_tokens: |
| prefix_to_tokens[token_prefix] = set() |
| |
| if i < len(lang_token_ids): |
| next_token_id = lang_token_ids[i] |
| prefix_to_tokens[token_prefix].add(next_token_id) |
| else: |
| if self.newline_token: |
| prefix_to_tokens[token_prefix].add(self.newline_token) |
| |
| if self.debug: |
| logger.debug(f"Built language prefix tree with {len(prefix_to_tokens)} token sequence prefixes") |
| empty_prefix = tuple() |
| if empty_prefix in prefix_to_tokens: |
| first_tokens = prefix_to_tokens[empty_prefix] |
| decoded_first = [(t, repr(self.tokenizer.decode([t]))) for t in sorted(first_tokens)] |
| logger.debug(f"First tokens allowed for language (empty prefix): {decoded_first}") |
| |
| return prefix_to_tokens |
| |
| def diagnose_keyscale_prefix_tree(self): |
| """ |
| Diagnose the keyscale prefix tree to help debug generation bias. |
| Call this method to print detailed information about allowed tokens at each prefix. |
| """ |
| print("=" * 60) |
| print("KEYSCALE PREFIX TREE DIAGNOSIS") |
| print("=" * 60) |
| |
| |
| if "" in self.keyscale_prefix_tree: |
| first_tokens = self.keyscale_prefix_tree[""] |
| print(f"\n[Empty prefix] Allowed first tokens ({len(first_tokens)} total):") |
| for t in sorted(first_tokens): |
| decoded = self.tokenizer.decode([t]) |
| print(f" Token {t}: {repr(decoded)}") |
| else: |
| print("\nWARNING: Empty prefix not in tree!") |
| |
| |
| test_prefixes = ["A", "B", "C", "D", "E", "F", "G"] |
| for prefix in test_prefixes: |
| |
| for test_key in [prefix, prefix + " "]: |
| if test_key in self.keyscale_prefix_tree: |
| tokens = self.keyscale_prefix_tree[test_key] |
| print(f"\n[Prefix {repr(test_key)}] Allowed tokens ({len(tokens)}):") |
| for t in sorted(tokens): |
| decoded = self.tokenizer.decode([t]) |
| print(f" Token {t}: {repr(decoded)}") |
| |
| |
| print(f"\n[Valid keyscales] Total: {len(self.valid_keyscales)}") |
| sample = sorted(list(self.valid_keyscales))[:10] |
| for ks in sample: |
| print(f" {repr(ks)}") |
| |
| print("=" * 60) |
|
|
| |
| def _load_genres_vocab(self): |
| """ |
| Load genres vocabulary from file. Supports hot reload by checking file mtime. |
| File format: one genre per line, lines starting with # are comments. |
| """ |
| if not os.path.exists(self.genres_vocab_path): |
| if self.debug: |
| logger.debug(f"Genres vocab file not found: {self.genres_vocab_path}") |
| return |
| |
| try: |
| mtime = os.path.getmtime(self.genres_vocab_path) |
| if mtime <= self.genres_vocab_mtime: |
| return |
| |
| with open(self.genres_vocab_path, 'r', encoding='utf-8') as f: |
| genres = [] |
| for line in f: |
| line = line.strip() |
| if line and not line.startswith('#'): |
| genres.append(line.lower()) |
| |
| self.genres_vocab = genres |
| self.genres_vocab_mtime = mtime |
| self._build_genres_trie() |
| |
| if self.debug: |
| logger.debug(f"Loaded {len(self.genres_vocab)} genres from {self.genres_vocab_path}") |
| except Exception as e: |
| logger.warning(f"Failed to load genres vocab: {e}") |
| |
| def _build_genres_trie(self): |
| """ |
| Build a trie (prefix tree) from genres vocabulary for efficient prefix matching. |
| Each node is a dict with: |
| - '_end': True if this node represents a complete genre |
| - other keys: next characters in the trie |
| """ |
| self.genres_trie = {} |
| |
| for genre in self.genres_vocab: |
| node = self.genres_trie |
| for char in genre: |
| if char not in node: |
| node[char] = {} |
| node = node[char] |
| node['_end'] = True |
| |
| if self.debug: |
| logger.debug(f"Built genres trie with {len(self.genres_vocab)} entries") |
| |
| def _extract_caption_genres(self, caption: str): |
| """ |
| Extract genres from the user's caption that match entries in the vocabulary. |
| This creates a smaller trie for faster and more relevant genre generation. |
| |
| Strategy (optimized - O(words * max_genre_len) instead of O(vocab_size)): |
| 1. Extract words/phrases from caption |
| 2. For each word, use trie to find all vocab entries that START with this word |
| 3. Build a separate trie from matched genres |
| """ |
| if not caption or not self.genres_vocab: |
| return |
| |
| caption_lower = caption.lower() |
| matched_genres = set() |
| |
| |
| import re |
| words = re.split(r'[,\s\-_/\\|]+', caption_lower) |
| words = [w.strip() for w in words if w.strip() and len(w.strip()) >= 2] |
| |
| |
| for word in words: |
| |
| node = self._get_genres_trie_node(word) |
| if node is not None: |
| |
| self._collect_complete_genres(node, word, matched_genres) |
| |
| |
| |
| genres_set = set(self.genres_vocab) |
| for word in words: |
| if word in genres_set: |
| matched_genres.add(word) |
| |
| if not matched_genres: |
| if self.debug: |
| logger.debug(f"No genres matched in caption, using full vocab") |
| return |
| |
| |
| self.caption_matched_genres = list(matched_genres) |
| self.caption_genres_trie = {} |
| |
| for genre in matched_genres: |
| node = self.caption_genres_trie |
| for char in genre: |
| if char not in node: |
| node[char] = {} |
| node = node[char] |
| node['_end'] = True |
| |
| if self.debug: |
| logger.debug(f"Matched {len(matched_genres)} genres from caption: {list(matched_genres)[:5]}...") |
| |
| def _collect_complete_genres(self, node: Dict, prefix: str, result: set, max_depth: int = 50): |
| """ |
| Recursively collect all complete genres under a trie node. |
| Limited depth to avoid too many matches. |
| """ |
| if max_depth <= 0: |
| return |
| |
| if node.get('_end', False): |
| result.add(prefix) |
| |
| |
| if len(result) >= 100: |
| return |
| |
| for char, child_node in node.items(): |
| if char not in ('_end', '_tokens'): |
| self._collect_complete_genres(child_node, prefix + char, result, max_depth - 1) |
| |
    def _precompute_char_token_mapping(self):
        """
        Precompute mapping from characters to token IDs and token decoded texts.
        This allows O(1) lookup instead of calling tokenizer.encode()/decode() at runtime.

        Time complexity: O(vocab_size) - runs once during initialization

        Note: Many subword tokenizers (like Qwen) add space prefixes to tokens.
        We need to handle both the raw first char and the first non-space char.
        """
        # Reset both tables (self._char_to_tokens is also set in __init__).
        self._char_to_tokens: Dict[str, set] = {}
        self._token_to_text: Dict[int, str] = {}

        for token_id in range(self.vocab_size):
            try:
                text = self.tokenizer.decode([token_id])

                # Skip tokens that decode to nothing.
                if not text:
                    continue

                # Normalized lowercase form: trailing whitespace dropped;
                # whitespace-only tokens collapse to a single space.
                text_lower = text.lower()
                if text_lower.strip():
                    normalized_text = text_lower.rstrip()
                else:
                    normalized_text = " "
                self._token_to_text[token_id] = normalized_text

                # Index by the raw first character...
                first_char = text[0].lower()
                if first_char not in self._char_to_tokens:
                    self._char_to_tokens[first_char] = set()
                self._char_to_tokens[first_char].add(token_id)

                # ...and by the first non-space character, so space-prefixed
                # tokens (e.g. " rock") are also reachable from 'r'.
                stripped_text = text.lstrip()
                if stripped_text and stripped_text != text:
                    first_nonspace_char = stripped_text[0].lower()
                    if first_nonspace_char not in self._char_to_tokens:
                        self._char_to_tokens[first_nonspace_char] = set()
                    self._char_to_tokens[first_nonspace_char].add(token_id)

            except Exception:
                # Some IDs may fail to decode; ignore them.
                continue

        if self.debug:
            logger.debug(f"Precomputed char->token mapping for {len(self._char_to_tokens)} unique characters")
| |
| def _try_reload_genres_vocab(self): |
| """Check if genres vocab file has been updated and reload if necessary.""" |
| if not os.path.exists(self.genres_vocab_path): |
| return |
| |
| try: |
| mtime = os.path.getmtime(self.genres_vocab_path) |
| if mtime > self.genres_vocab_mtime: |
| self._load_genres_vocab() |
| except Exception: |
| pass |
| |
| def _get_genres_trie_node(self, prefix: str) -> Optional[Dict]: |
| """ |
| Get the trie node for a given prefix. |
| Returns None if the prefix is not valid (no genres start with this prefix). |
| """ |
| node = self.genres_trie |
| for char in prefix.lower(): |
| if char not in node: |
| return None |
| node = node[char] |
| return node |
| |
| def _is_complete_genre(self, text: str) -> bool: |
| """Check if the given text is a complete genre in the vocabulary.""" |
| node = self._get_genres_trie_node(text.strip()) |
| return node is not None and node.get('_end', False) |
| |
| def _get_trie_node_from_trie(self, trie: Dict, prefix: str) -> Optional[Dict]: |
| """Get a trie node from a specific trie (helper for caption vs full trie).""" |
| node = trie |
| for char in prefix.lower(): |
| if char not in node: |
| return None |
| node = node[char] |
| return node |
|
|
| def _get_allowed_genres_tokens(self) -> List[int]: |
| """ |
| Get allowed tokens for genres field based on trie matching. |
| |
| The entire genres string (including commas) must match a complete entry in the vocab. |
| For example, if vocab contains "pop, rock, jazz", the generated string must exactly |
| match that entry - we don't treat commas as separators for individual genres. |
| |
| Strategy: |
| 1. If caption-matched genres exist, use that smaller trie first (faster + more relevant) |
| 2. If no caption matches or prefix not in caption trie, fallback to full vocab trie |
| 3. Get valid next characters from current trie node |
| 4. For each candidate token, verify the full decoded text forms a valid trie prefix |
| """ |
| if not self.genres_vocab: |
| |
| return [] |
| |
| |
| accumulated = self.accumulated_value.lower() |
| current_genre_prefix = accumulated.strip() |
| |
| |
| use_caption_trie = False |
| current_node = None |
| |
| |
| if self.caption_genres_trie: |
| if current_genre_prefix == "": |
| current_node = self.caption_genres_trie |
| use_caption_trie = True |
| else: |
| current_node = self._get_trie_node_from_trie(self.caption_genres_trie, current_genre_prefix) |
| if current_node is not None: |
| use_caption_trie = True |
| |
| |
| if current_node is None: |
| if current_genre_prefix == "": |
| current_node = self.genres_trie |
| else: |
| current_node = self._get_genres_trie_node(current_genre_prefix) |
| |
| if current_node is None: |
| |
| if self.newline_token: |
| return [self.newline_token] |
| return [] |
| |
| |
| valid_next_chars = set(k for k in current_node.keys() if k not in ('_end', '_tokens')) |
| |
| |
| is_complete = current_node.get('_end', False) |
| |
| if not valid_next_chars: |
| |
| allowed = set() |
| if is_complete and self.newline_token: |
| allowed.add(self.newline_token) |
| return list(allowed) |
| |
| |
| candidate_tokens = set() |
| for char in valid_next_chars: |
| if char in self._char_to_tokens: |
| candidate_tokens.update(self._char_to_tokens[char]) |
| |
| |
| active_trie = self.caption_genres_trie if use_caption_trie else self.genres_trie |
| |
| |
| allowed = set() |
| for token_id in candidate_tokens: |
| |
| decoded_normalized = self._token_to_text.get(token_id, "") |
| |
| if not decoded_normalized or not decoded_normalized.strip(): |
| |
| if ' ' in valid_next_chars or ',' in valid_next_chars: |
| allowed.add(token_id) |
| continue |
| |
| |
| |
| if decoded_normalized.startswith(' ') or decoded_normalized.startswith(','): |
| |
| new_prefix = current_genre_prefix + decoded_normalized |
| else: |
| new_prefix = current_genre_prefix + decoded_normalized |
| |
| |
| new_node = self._get_trie_node_from_trie(active_trie, new_prefix) |
| if new_node is not None: |
| allowed.add(token_id) |
| |
| |
| if is_complete and self.newline_token: |
| allowed.add(self.newline_token) |
| |
| return list(allowed) |
| |
| def reset(self): |
| """Reset the processor state for a new generation.""" |
| self.state = FSMState.THINK_TAG |
| self.position_in_state = 0 |
| self.accumulated_value = "" |
| self.accumulated_token_ids = [] |
| self.codes_count = 0 |
| self.user_field_token_queue = [] |
| self.current_user_field = None |
| self.caption_after_newline = False |
| self.caption_token_count = 0 |
| self.caption_ending = False |
| self.pending_field_name = "" |
| |
| def set_target_duration(self, duration: Optional[float]): |
| """ |
| Set the target duration for codes generation. |
| |
| Args: |
| duration: Target duration in seconds. If None, no duration constraint is applied. |
| 5 codes = 1 second, so target_codes = duration * 5. |
| """ |
| self.target_duration = duration |
| if duration is not None and duration > 0: |
| self.target_codes = int(duration * 5) |
| if self.debug: |
| logger.debug(f"Set target duration: {duration}s -> {self.target_codes} codes") |
| else: |
| self.target_codes = None |
| if self.debug: |
| logger.debug("Target duration cleared, no duration constraint") |
| |
    def _get_allowed_tokens_for_fixed_string(self, fixed_str: str) -> List[int]:
        """
        Get the token IDs that can continue the fixed string from current position.
        Returns list of allowed token IDs.

        Strategy: Find the longest prefix that encodes to a single token, and return that token.
        This ensures we generate by tokens, not character-by-character.

        Falls back to collecting first-tokens of shorter prefixes when no
        single-token prefix exists (longest-prefix candidates first).
        """
        # The part of the fixed string not yet emitted.
        remaining = fixed_str[self.position_in_state:]
        if not remaining:
            return []

        if self.debug:
            logger.debug(f"_get_allowed_tokens_for_fixed_string: fixed_str={repr(fixed_str)}, position_in_state={self.position_in_state}, remaining={repr(remaining)}")

        best_token = None
        best_prefix_len = 0

        # Try prefixes from longest to shortest; take the first one that the
        # tokenizer encodes as exactly one token.
        for end in range(len(remaining), 0, -1):
            prefix = remaining[:end]
            tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
            if tokens and len(tokens) == 1:
                best_token = tokens[0]
                best_prefix_len = end
                if self.debug:
                    logger.debug(f"Found single-token match: prefix={repr(prefix)}, token_id={best_token}, token_text={repr(self.tokenizer.decode([best_token]))}")
                break

        if best_token is not None:
            return [best_token]

        # Fallback: no single-token prefix. Collect the first token of each
        # short prefix (up to 19 chars) whose decoded text is consistent with
        # the prefix (one contains the other after lstrip+lowercase).
        # Maps token ID -> longest prefix length it was seen for.
        allowed_tokens = {}
        for end in range(1, min(len(remaining) + 1, 20)):
            prefix = remaining[:end]
            tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
            if tokens:
                first_token = tokens[0]
                decoded_token = self.tokenizer.decode([first_token])
                # Normalize both sides before comparing, since subword
                # tokenizers may add leading spaces or change case.
                normalized_prefix = prefix.lstrip().lower()
                normalized_decoded = decoded_token.lstrip().lower()
                if normalized_decoded.startswith(normalized_prefix) or normalized_prefix.startswith(normalized_decoded):
                    # Keep the longest prefix length each token accounts for.
                    if first_token not in allowed_tokens or end > allowed_tokens[first_token]:
                        allowed_tokens[first_token] = end

        # Prefer tokens that cover more of the fixed string.
        sorted_tokens = sorted(allowed_tokens.items(), key=lambda x: x[1], reverse=True)
        result = [token for token, _ in sorted_tokens] if sorted_tokens else []

        if self.debug:
            logger.debug(f"Fallback: returning {len(result)} tokens: {[(t, repr(self.tokenizer.decode([t]))) for t in result[:5]]}")
            if result:
                logger.debug(f"Fixed string: {repr(fixed_str)}, position: {self.position_in_state}, remaining: {repr(remaining)}")

        return result
| |
| def _get_allowed_digit_tokens(self, min_val: int, max_val: int) -> List[int]: |
| """ |
| Get allowed digit tokens based on accumulated value and range constraints. |
| Uses early-blocking to prevent out-of-range values. |
| """ |
| if not self.accumulated_value: |
| |
| allowed_digits = set() |
| for v in range(min_val, max_val + 1): |
| allowed_digits.add(int(str(v)[0])) |
| return [self.digit_tokens[d] for d in allowed_digits if d in self.digit_tokens] |
| |
| current = int(self.accumulated_value) |
| allowed = [] |
| |
| for d in range(10): |
| new_value = int(self.accumulated_value + str(d)) |
| |
| |
| |
| |
| |
| |
| if new_value > max_val: |
| continue |
| |
| |
| |
| |
| if new_value >= min_val: |
| allowed.append(d) |
| elif new_value * 10 <= max_val: |
| |
| allowed.append(d) |
| |
| return [self.digit_tokens[d] for d in allowed if d in self.digit_tokens] |
| |
| def _get_allowed_numeric_tokens(self, prefix_tree: Dict[Tuple[int, ...], Set[int]]) -> List[int]: |
| """ |
| Get allowed tokens for numeric field using the precomputed prefix tree. |
| |
| IMPORTANT: Uses token ID sequence as key (not string) to avoid tokenization mismatches. |
| |
| Args: |
| prefix_tree: Precomputed prefix tree mapping token ID sequence -> set of allowed token IDs |
| |
| Returns: |
| List of allowed token IDs for current accumulated_token_ids |
| """ |
| token_prefix = tuple(self.accumulated_token_ids) |
| |
| if token_prefix in prefix_tree: |
| return list(prefix_tree[token_prefix]) |
| |
| |
| |
| return [] |
| |
| def _should_end_numeric_field(self, logits: torch.Tensor, min_val: int, max_val: int) -> bool: |
| """ |
| Determine if we should end the current numeric field. |
| Returns True if P(newline) > P(any valid digit) AND current value is valid. |
| """ |
| if not self.accumulated_value: |
| return False |
| |
| current = int(self.accumulated_value) |
| if current < min_val or current > max_val: |
| return False |
| |
| |
| probs = torch.softmax(logits, dim=-1) |
| |
| newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0 |
| |
| |
| allowed_digits = self._get_allowed_digit_tokens(min_val, max_val) |
| if not allowed_digits: |
| return True |
| |
| max_digit_prob = max(probs[0, t].item() for t in allowed_digits) |
| |
| if self.debug: |
| logger.debug(f"Numeric field decision: newline_prob={newline_prob:.4f}, max_digit_prob={max_digit_prob:.4f}") |
| |
| return newline_prob > max_digit_prob |
|
|
| |
| def _should_end_text_field(self, logits: torch.Tensor) -> bool: |
| """ |
| Determine if we should end a text field (genres). |
| Returns True if P(newline) > P(any other token) AND we have some content. |
| """ |
| if not self.accumulated_value.strip(): |
| return False |
| |
| probs = torch.softmax(logits, dim=-1) |
| newline_prob = probs[0, self.newline_token].item() if self.newline_token else 0 |
| |
| |
| masked_probs = probs.clone() |
| if self.newline_token: |
| masked_probs[0, self.newline_token] = 0 |
| max_other_prob = masked_probs[0].max().item() |
| |
| return newline_prob > max_other_prob |
| |
| def _get_allowed_keyscale_tokens(self) -> List[int]: |
| """ |
| Get allowed tokens for keyscale field using the precomputed prefix tree. |
| Uses token ID sequence as key (not string) to avoid tokenization mismatches. |
| """ |
| |
| token_prefix = tuple(self.accumulated_token_ids) |
| |
| if token_prefix in self.keyscale_prefix_tree: |
| return list(self.keyscale_prefix_tree[token_prefix]) |
| |
| |
| |
| return [] |
| |
| def _is_keyscale_complete(self) -> bool: |
| """ |
| Check if keyscale value is complete and valid. |
| Uses token ID sequence to check if current prefix allows newline. |
| """ |
| token_prefix = tuple(self.accumulated_token_ids) |
| |
| if token_prefix in self.keyscale_prefix_tree: |
| return self.newline_token in self.keyscale_prefix_tree[token_prefix] |
| return False |
| |
| def _get_allowed_language_tokens(self) -> List[int]: |
| """ |
| Get allowed tokens for language field using the precomputed prefix tree. |
| Uses token ID sequence as key (not string) to avoid tokenization mismatches. |
| Similar to keyscale. |
| """ |
| token_prefix = tuple(self.accumulated_token_ids) |
| |
| if token_prefix in self.language_prefix_tree: |
| return list(self.language_prefix_tree[token_prefix]) |
| |
| |
| return [] |
| |
| def _get_allowed_timesig_tokens(self) -> List[int]: |
| """ |
| Get allowed tokens for timesignature field using the precomputed prefix tree. |
| Uses token ID sequence as key (not string) to avoid tokenization mismatches. |
| """ |
| token_prefix = tuple(self.accumulated_token_ids) |
| |
| if token_prefix in self.timesig_prefix_tree: |
| return list(self.timesig_prefix_tree[token_prefix]) |
| |
| |
| |
| return [] |
| |
    def __call__(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """
        Apply constrained decoding by modifying logits.

        Args:
            input_ids: [batch_size, seq_len] input token IDs
            scores: [batch_size, vocab_size] logits for next token

        Returns:
            Modified scores with invalid tokens masked to -inf and temperature scaling applied
        """
        # Constraints disabled: only temperature scaling applies.
        if not self.enabled:
            return self._apply_temperature_scaling(scores)

        if self.state == FSMState.COMPLETED:
            # In the "understand" phase, keep suppressing audio-code tokens
            # even after the FSM has finished (free-form lyrics follow).
            if self.generation_phase == "understand" and self.audio_code_mask is not None:
                # Lazily move the additive mask to the scores' device/dtype.
                if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype:
                    self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype)
                scores = scores + self.audio_code_mask
            return self._apply_temperature_scaling(scores)

        # Codes phase: if the prompt already contains </think>, skip the
        # metadata FSM entirely and start counting codes.
        if self.generation_phase == "codes" and self.state == FSMState.THINK_TAG:
            if self._input_contains_think_end_tag(input_ids):
                self.state = FSMState.CODES_GENERATION
                self.codes_count = 0
                if self.debug:
                    logger.debug("Codes phase: detected </think> in input, skipping to CODES_GENERATION")

        if self.state == FSMState.CODES_GENERATION:
            # Restrict generation to audio-code tokens via additive mask.
            if self.non_audio_code_mask is not None:
                if self.non_audio_code_mask.device != scores.device or self.non_audio_code_mask.dtype != scores.dtype:
                    self.non_audio_code_mask = self.non_audio_code_mask.to(device=scores.device, dtype=scores.dtype)
                scores = scores + self.non_audio_code_mask

            # Enforce the target code count: block EOS until reached, then force it.
            if self.target_codes is not None and self.eos_token_id is not None:
                if self.codes_count < self.target_codes:
                    scores[:, self.eos_token_id] = float('-inf')
                    if self.debug:
                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, blocking EOS")
                else:
                    # Force EOS: keep only its original logit, -inf everywhere else.
                    eos_scores = scores[:, self.eos_token_id].clone()
                    scores.fill_(float('-inf'))
                    scores[:, self.eos_token_id] = eos_scores
                    if self.debug:
                        logger.debug(f"Codes generation: {self.codes_count}/{self.target_codes}, forcing EOS")
            return self._apply_temperature_scaling(scores)

        batch_size = scores.shape[0]

        # Metadata states: process each sequence independently.
        # NOTE(review): the FSM state is shared across the batch — presumably
        # batch_size is 1 during metadata generation; confirm with callers.
        for b in range(batch_size):
            result = self._process_single_sequence(input_ids[b], scores[b:b+1])
            scores[b] = result[0]

        return self._apply_temperature_scaling(scores)
| |
| def _input_contains_think_end_tag(self, input_ids: torch.LongTensor) -> bool: |
| """ |
| Check if input contains the </think> closing tag. |
| |
| Args: |
| input_ids: [batch_size, seq_len] input token IDs |
| |
| Returns: |
| True if </think> is found in the input (any sequence in batch) |
| """ |
| |
| think_end_tokens = self.tokenizer.encode("</think>", add_special_tokens=False) |
| if not think_end_tokens: |
| return False |
| |
| |
| for b in range(input_ids.shape[0]): |
| seq = input_ids[b].tolist() |
| |
| for i in range(len(seq) - len(think_end_tokens) + 1): |
| if seq[i:i+len(think_end_tokens)] == think_end_tokens: |
| return True |
| |
| return False |
| |
| def _apply_temperature_scaling(self, scores: torch.FloatTensor) -> torch.FloatTensor: |
| """ |
| Apply temperature scaling based on current generation phase. |
| |
| Temperature scaling: logits = logits / temperature |
| - Lower temperature (< 1.0) makes distribution sharper (more deterministic) |
| - Higher temperature (> 1.0) makes distribution flatter (more diverse) |
| |
| Args: |
| scores: [batch_size, vocab_size] logits |
| |
| Returns: |
| Temperature-scaled logits |
| """ |
| |
| if self.state == FSMState.CODES_GENERATION or self.state == FSMState.COMPLETED: |
| temperature = self.codes_temperature |
| else: |
| temperature = self.metadata_temperature |
| |
| |
| if temperature is None: |
| return scores |
| |
| |
| if temperature <= 0: |
| temperature = 1e-6 |
| |
| |
| return scores / temperature |
| |
| def _get_user_provided_field_tokens(self, field_name: str) -> Optional[List[int]]: |
| """ |
| Get token sequence for a user-provided field (field_name + value + newline). |
| Uses the same tokenization logic as prefix tree building. |
| |
| Args: |
| field_name: Field name ("bpm", "caption", "duration", "keyscale", "language", "timesignature") |
| |
| Returns: |
| List of token IDs for the complete field, or None if field is not provided |
| """ |
| value = self.user_provided_metadata.get(field_name) |
| if value is None: |
| return None |
| |
| |
| field_to_prefix = { |
| "bpm": "bpm: ", |
| "caption": "caption: ", |
| "duration": "duration: ", |
| "keyscale": "keyscale: ", |
| "language": "language: ", |
| "timesignature": "timesignature: ", |
| "genres": "genres: ", |
| } |
| prefix = field_to_prefix[field_name] |
| full_text = f"{prefix}{value}\n" |
| |
| |
| tokens = self.tokenizer.encode(full_text, add_special_tokens=False) |
| |
| |
| |
| prefix_for_matching = field_name + ":" |
| prefix_tokens = self.tokenizer.encode(prefix_for_matching, add_special_tokens=False) |
| |
| |
| if len(tokens) >= len(prefix_tokens) and tokens[:len(prefix_tokens)] == prefix_tokens: |
| |
| return tokens[len(prefix_tokens):] |
| else: |
| |
| if self.debug: |
| logger.warning(f"Could not match prefix tokens for field {field_name}, using all tokens") |
| return tokens |
| |
    def _process_single_sequence(
        self,
        input_ids: torch.LongTensor,
        scores: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Process a single sequence and return modified scores (inplace when possible).

        Dispatches on the current FSM state and whitelists (or blacklists)
        next tokens accordingly. ``scores`` is a [1, vocab_size] slice.
        """
        # Highest priority: when a user-provided field is being injected,
        # force the next queued token regardless of FSM state.
        if self.user_field_token_queue:
            next_token = self.user_field_token_queue[0]
            self._apply_whitelist_inplace(scores, [next_token])
            return scores

        if self.state in self.fixed_strings:
            # Fixed literal state (e.g. "<think>\n", a field name): force
            # tokens that continue the literal from the current position.
            fixed_str = self.fixed_strings[self.state]
            allowed = self._get_allowed_tokens_for_fixed_string(fixed_str)

            if allowed:
                # stop_at_reasoning: cut over to EOS near the end of the
                # closing </think> tag instead of emitting it fully.
                if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                    remaining_chars = len(fixed_str) - self.position_in_state
                    if remaining_chars <= 10:
                        if self.eos_token_id is not None:
                            self._apply_whitelist_inplace(scores, [self.eos_token_id])
                        if self.debug:
                            logger.debug(f"stop_at_reasoning=True: forcing EOS near end of </think> tag (remaining: {remaining_chars} chars)")
                        return scores

                self._apply_whitelist_inplace(scores, allowed)
            else:
                # Fixed string fully consumed: decide what comes next.
                if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                    if self.eos_token_id is not None:
                        self._apply_whitelist_inplace(scores, [self.eos_token_id])
                    if self.debug:
                        logger.debug(f"stop_at_reasoning=True: forcing EOS after completing </think> tag")
                    return scores

                old_state = self.state
                self._transition_to_next_state()

                if self.state in self.fixed_strings:
                    # Guard against unbounded recursion when the next state
                    # is also a fixed string.
                    if self.debug:
                        logger.warning(f"State transition from {old_state.name} to {self.state.name} still in fixed_strings, avoiding recursion")
                    return scores

                # Clear the logits and re-dispatch in the new state.
                scores.zero_()
                return self._process_single_sequence(input_ids, scores)

        elif self.state == FSMState.BPM_VALUE:
            # Inject the user-provided bpm value verbatim, if any.
            if self.user_provided_metadata["bpm"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
                value = self.user_provided_metadata["bpm"]
                value_text = f" {value}\n"
                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
                if value_tokens:
                    self.user_field_token_queue = value_tokens
                    self.current_user_field = "bpm"
                    self._apply_whitelist_inplace(scores, [value_tokens[0]])
                    return scores

            # Otherwise constrain via the precomputed bpm prefix tree.
            allowed = self._get_allowed_numeric_tokens(self.bpm_prefix_tree)

            # Permit newline when the tree marks the current prefix complete.
            token_prefix = tuple(self.accumulated_token_ids)
            if token_prefix in self.bpm_prefix_tree and self.newline_token in self.bpm_prefix_tree[token_prefix]:
                allowed = allowed + [self.newline_token]

            self._apply_whitelist_inplace(scores, allowed)

        elif self.state == FSMState.CAPTION_VALUE:
            # Inject the user-provided caption verbatim, if any.
            if self.user_provided_metadata["caption"] is not None and not self.user_field_token_queue and not self.accumulated_value:
                value = self.user_provided_metadata["caption"]
                value_text = f" {value}\n"
                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
                if value_tokens:
                    self.user_field_token_queue = value_tokens
                    self.current_user_field = "caption"
                    self._apply_whitelist_inplace(scores, [value_tokens[0]])
                    return scores

            # After a newline inside the caption, peek at the model's top
            # token: non-whitespace means the caption ended and the next
            # field name is starting.
            if self.caption_after_newline:
                top_token_id = torch.argmax(scores[0]).item()
                top_token_text = self.tokenizer.decode([top_token_id])

                if len(top_token_text) > 0 and top_token_text[0] not in ' \t':
                    # Caption finished; begin collecting the next field name.
                    self.caption_after_newline = False
                    self.caption_ending = True
                    self.pending_field_name = ""
                    return scores
                else:
                    # Whitespace continuation: the caption goes on.
                    self.caption_after_newline = False

            # While ending, leave logits untouched so the field name can be
            # generated freely (update_state parses it).
            if self.caption_ending:
                return scores

            # Block backticks (no code blocks inside captions).
            if self.backtick_token is not None:
                scores[0, self.backtick_token] = float('-inf')

            # Suppress audio-code tokens inside the caption text.
            if self.audio_code_mask is not None:
                if self.audio_code_mask.device != scores.device or self.audio_code_mask.dtype != scores.dtype:
                    self.audio_code_mask = self.audio_code_mask.to(device=scores.device, dtype=scores.dtype)
                scores = scores + self.audio_code_mask

            # Hard length cap: force a newline after 512 caption tokens.
            if self.caption_token_count >= 512:
                if self.newline_token is not None:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
                return scores

            return scores

        elif self.state == FSMState.DURATION_VALUE:
            # Inject the user-provided duration verbatim, if any.
            if self.user_provided_metadata["duration"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
                value = self.user_provided_metadata["duration"]
                value_text = f" {value}\n"
                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
                if value_tokens:
                    self.user_field_token_queue = value_tokens
                    self.current_user_field = "duration"
                    self._apply_whitelist_inplace(scores, [value_tokens[0]])
                    return scores

            if self.target_duration is not None:
                # A target duration is set: emit its digits one by one,
                # then force the terminating newline.
                target_str = str(int(self.target_duration))
                current_pos = len(self.accumulated_value)

                if current_pos < len(target_str):
                    next_digit = int(target_str[current_pos])
                    if next_digit in self.digit_tokens:
                        self._apply_whitelist_inplace(scores, [self.digit_tokens[next_digit]])
                else:
                    # All target digits emitted: terminate the field.
                    if self.newline_token:
                        self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                # No target: constrain via the duration prefix tree.
                allowed = self._get_allowed_numeric_tokens(self.duration_prefix_tree)

                token_prefix = tuple(self.accumulated_token_ids)
                if token_prefix in self.duration_prefix_tree and self.newline_token in self.duration_prefix_tree[token_prefix]:
                    allowed = allowed + [self.newline_token]

                self._apply_whitelist_inplace(scores, allowed)

        elif self.state == FSMState.GENRES_VALUE:
            # Inject the user-provided genres verbatim, if any.
            if self.user_provided_metadata["genres"] is not None and not self.user_field_token_queue and not self.accumulated_value:
                value = self.user_provided_metadata["genres"]
                value_text = f" {value}\n"
                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
                if value_tokens:
                    self.user_field_token_queue = value_tokens
                    self.current_user_field = "genres"
                    self._apply_whitelist_inplace(scores, [value_tokens[0]])
                    return scores

            # Pick up external changes to the genres vocab file.
            self._try_reload_genres_vocab()

            allowed = self._get_allowed_genres_tokens()

            if allowed:
                self._apply_whitelist_inplace(scores, allowed)
            elif self.genres_vocab:
                # Vocab loaded but no continuation: end the field.
                if self.newline_token:
                    if self.debug:
                        logger.debug(f"No valid genre continuation for '{self.accumulated_value}', forcing newline")
                    self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                # No vocab: free text, terminated when newline dominates.
                if self._should_end_text_field(scores):
                    if self.newline_token:
                        self._apply_whitelist_inplace(scores, [self.newline_token])
                    self._transition_to_next_state()
                else:
                    # Don't allow an empty genres value to terminate.
                    if not self.accumulated_value.strip():
                        if self.newline_token:
                            scores[0, self.newline_token] = float('-inf')

        elif self.state == FSMState.KEYSCALE_VALUE:
            # Inject the user-provided keyscale verbatim, if any.
            if self.user_provided_metadata["keyscale"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
                value = self.user_provided_metadata["keyscale"]
                value_text = f" {value}\n"
                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
                if value_tokens:
                    self.user_field_token_queue = value_tokens
                    self.current_user_field = "keyscale"
                    self._apply_whitelist_inplace(scores, [value_tokens[0]])
                    return scores

            # Complete keyscale: force the terminating newline.
            token_prefix = tuple(self.accumulated_token_ids)
            if token_prefix in self.keyscale_prefix_tree and self.newline_token in self.keyscale_prefix_tree[token_prefix]:
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                allowed = self._get_allowed_keyscale_tokens()
                if allowed:
                    self._apply_whitelist_inplace(scores, allowed)
                else:
                    # Dead end in the tree: bail out with a newline.
                    if self.newline_token:
                        self._apply_whitelist_inplace(scores, [self.newline_token])

        elif self.state == FSMState.LANGUAGE_VALUE:
            # Inject the user-provided language verbatim, if any.
            if self.user_provided_metadata["language"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
                value = self.user_provided_metadata["language"]
                value_text = f" {value}\n"
                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
                if value_tokens:
                    self.user_field_token_queue = value_tokens
                    self.current_user_field = "language"
                    self._apply_whitelist_inplace(scores, [value_tokens[0]])
                    return scores

            if not self.accumulated_token_ids:
                # First language token: greedily pick the single most likely
                # candidate among the tree's valid starts.
                empty_prefix = tuple()
                if empty_prefix in self.language_prefix_tree:
                    candidate_tokens = list(self.language_prefix_tree[empty_prefix])

                    if candidate_tokens:
                        candidate_indices = torch.tensor(candidate_tokens, device=scores.device, dtype=torch.long)
                        candidate_scores = scores[0, candidate_indices]

                        best_idx = torch.argmax(candidate_scores).item()
                        top_token_id = candidate_tokens[best_idx]

                        self._apply_whitelist_inplace(scores, [top_token_id])

                        if self.debug:
                            top_token_text = self.tokenizer.decode([top_token_id])
                            logger.debug(f"Language field: selected top-1 token {top_token_id} ({repr(top_token_text)}) from {len(candidate_tokens)} candidates")
                    else:
                        # Empty candidate set: terminate the field.
                        if self.newline_token:
                            self._apply_whitelist_inplace(scores, [self.newline_token])
                else:
                    # No tree entry at all: terminate the field.
                    if self.newline_token:
                        self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                # Subsequent tokens: complete value forces newline,
                # otherwise follow the tree (or bail out with newline).
                token_prefix = tuple(self.accumulated_token_ids)
                if token_prefix in self.language_prefix_tree and self.newline_token in self.language_prefix_tree[token_prefix]:
                    if self.newline_token:
                        self._apply_whitelist_inplace(scores, [self.newline_token])
                else:
                    allowed = self._get_allowed_language_tokens()
                    if allowed:
                        self._apply_whitelist_inplace(scores, allowed)
                    else:
                        if self.newline_token:
                            self._apply_whitelist_inplace(scores, [self.newline_token])

        elif self.state == FSMState.TIMESIG_VALUE:
            # Inject the user-provided time signature verbatim, if any.
            if self.user_provided_metadata["timesignature"] is not None and not self.user_field_token_queue and not self.accumulated_token_ids:
                value = self.user_provided_metadata["timesignature"]
                value_text = f" {value}\n"
                value_tokens = self.tokenizer.encode(value_text, add_special_tokens=False)
                if value_tokens:
                    self.user_field_token_queue = value_tokens
                    self.current_user_field = "timesignature"
                    self._apply_whitelist_inplace(scores, [value_tokens[0]])
                    return scores

            # Complete value forces newline, otherwise follow the tree.
            token_prefix = tuple(self.accumulated_token_ids)
            if token_prefix in self.timesig_prefix_tree and self.newline_token in self.timesig_prefix_tree[token_prefix]:
                if self.newline_token:
                    self._apply_whitelist_inplace(scores, [self.newline_token])
            else:
                allowed = self._get_allowed_timesig_tokens()
                self._apply_whitelist_inplace(scores, allowed)

        return scores
| |
| def _transition_to_next_state(self): |
| """Transition to the next FSM state.""" |
| if self.state in self.next_state: |
| old_state = self.state |
| next_state = self.next_state[self.state] |
| |
| |
| |
| |
| |
| if old_state == FSMState.THINK_END_TAG: |
| if self.generation_phase == "understand": |
| |
| |
| next_state = FSMState.COMPLETED |
| if self.debug: |
| logger.debug(f"generation_phase='understand': allowing free-form lyrics after </think>") |
| |
| |
| self.state = next_state |
| self.position_in_state = 0 |
| self.accumulated_value = "" |
| self.accumulated_token_ids = [] |
| self.caption_after_newline = False |
| self.caption_token_count = 0 |
| self.caption_ending = False |
| self.pending_field_name = "" |
| if self.debug: |
| logger.debug(f"FSM transition: {old_state.name} -> {self.state.name}") |
| |
    def update_state(self, generated_token_id: int) -> None:
        """
        Update internal state after a token has been generated.
        This should be called after each token generation.

        Advances the FSM one step: depending on the current ``self.state`` this
        either counts codes, consumes a forced user-field token, tracks progress
        through a fixed literal string, or accumulates a field value and
        transitions on newline.

        Args:
            generated_token_id: The token ID that was just generated
        """
        # Constraint processing disabled entirely — nothing to track.
        if not self.enabled:
            return

        # Terminal state: no further transitions are possible.
        if self.state == FSMState.COMPLETED:
            return

        if self.state == FSMState.CODES_GENERATION:
            # Past the metadata header; every generated token is counted as one
            # "code" toward the target length (enforcement of target_codes
            # presumably happens elsewhere — not visible in this method).
            self.codes_count += 1
            if self.debug and self.target_codes is not None:
                logger.debug(f"Codes count: {self.codes_count}/{self.target_codes}")
            return

        # A user-provided field value is being force-injected token by token;
        # the queue holds the remaining expected token ids.
        if self.user_field_token_queue:
            # The generated token should match the head of the queue; a mismatch
            # is only logged, and the queue advances regardless.
            expected_token = self.user_field_token_queue[0]
            if generated_token_id != expected_token:
                if self.debug:
                    logger.warning(f"Expected token {expected_token} but got {generated_token_id} for user-provided field {self.current_user_field}")

            # Consume the head of the queue (NOTE(review): list.pop(0) is O(n);
            # fine for short field values, a deque would avoid it).
            self.user_field_token_queue.pop(0)

            # Queue drained → the whole user field has been emitted; move on.
            if not self.user_field_token_queue:
                if self.debug:
                    logger.debug(f"Completed injection of user-provided field: {self.current_user_field}")
                field_name = self.current_user_field
                self.current_user_field = None

                # Prefer an explicit next state derived from the field name
                # (helper defined elsewhere); fall back to the default
                # sequential transition when it returns a falsy value.
                next_state = self._get_next_field_state(field_name)
                if next_state:
                    old_state = self.state
                    self.state = next_state
                    # Reset per-state tracking for the new state.
                    self.position_in_state = 0
                    self.accumulated_value = ""
                    self.accumulated_token_ids = []
                    if self.debug:
                        logger.debug(f"FSM transition (after user field injection): {old_state.name} -> {self.state.name}")
                else:
                    # No mapped state — take the default transition.
                    self._transition_to_next_state()
            return

        token_str = self.tokenizer.decode([generated_token_id])

        if self.debug:
            logger.debug(f"Generated token: {repr(token_str)} (id={generated_token_id}), state={self.state.name}")

        if self.state in self.fixed_strings:
            # Currently emitting a fixed literal (e.g. a field name like
            # "bpm: " or a tag like "</think>"); track how many characters of
            # it have been produced so far.
            fixed_str = self.fixed_strings[self.state]
            self.position_in_state += len(token_str)

            # Literal fully emitted (>= because a token may overshoot the
            # literal's end) → transition.
            if self.position_in_state >= len(fixed_str):
                # Special case: caller only wants the reasoning block — stop
                # right after "</think>" instead of entering codes generation.
                if self.state == FSMState.THINK_END_TAG and self.stop_at_reasoning:
                    old_state = self.state
                    self.state = FSMState.COMPLETED
                    self.position_in_state = 0
                    self.accumulated_value = ""
                    self.accumulated_token_ids = []
                    if self.debug:
                        logger.debug(f"FSM transition (stop_at_reasoning): {old_state.name} -> {self.state.name}")
                else:
                    self._transition_to_next_state()

        elif self.state in [FSMState.BPM_VALUE, FSMState.DURATION_VALUE, FSMState.TIMESIG_VALUE]:
            # Numeric fields: a newline ends the value, otherwise accumulate
            # digit characters for range validation elsewhere.
            if generated_token_id == self.newline_token:
                # NOTE(review): old_state is assigned but never used in this
                # branch (other branches log it) — confirm it isn't needed.
                old_state = self.state
                self._transition_to_next_state()

                # If the next state is a fixed string, stop here; its progress
                # tracking starts with the next token.
                if self.state in self.fixed_strings:
                    return
            else:
                self.accumulated_token_ids.append(generated_token_id)

                # Only keep clean digit tokens; mixed tokens are ignored.
                if token_str.strip().isdigit():
                    self.accumulated_value += token_str.strip()

        elif self.state == FSMState.GENRES_VALUE:
            if generated_token_id == self.newline_token:
                # Newline terminates the genres value.
                self._transition_to_next_state()

                # Same early-return as the numeric branch: the new fixed-string
                # state starts fresh on the next token.
                if self.state in self.fixed_strings:
                    return
            else:
                # NOTE(review): unlike KEYSCALE/LANGUAGE below, this branch does
                # not append to accumulated_token_ids — verify that is intended.
                self.accumulated_value += token_str

        elif self.state == FSMState.CAPTION_VALUE:
            # Free-text caption: count tokens (length limits presumably applied
            # elsewhere) and accumulate the raw text.
            self.caption_token_count += 1

            self.accumulated_value += token_str

            # Track whether we just crossed a newline; used to detect the start
            # of the next "field:" line while still in caption mode.
            if '\n' in token_str:
                self.caption_after_newline = True
            else:
                self.caption_after_newline = False

            # caption_ending is set elsewhere (not visible here); while set, we
            # buffer tokens to sniff out which field name the model is starting.
            if self.caption_ending:
                self.pending_field_name += token_str

                # A ':' means the field name is complete, e.g. "duration:".
                # NOTE(review): the second condition is subsumed by the first
                # (`':' in token_str` already covers token_str.strip() == ':').
                if ':' in token_str or token_str.strip() == ':':
                    field_name_full = self.pending_field_name.strip()
                    # Strip the trailing colon and normalize case for lookup.
                    field_name = field_name_full.rstrip(':').strip().lower()

                    if self.debug:
                        logger.debug(f"Detected field name after caption: {repr(field_name)}")

                    # Map the detected field name straight to its VALUE state,
                    # skipping the fixed-string NAME state (the model already
                    # emitted the name itself).
                    field_name_to_value_state = {
                        "duration": FSMState.DURATION_VALUE,
                        "genres": FSMState.GENRES_VALUE,
                        "keyscale": FSMState.KEYSCALE_VALUE,
                        "language": FSMState.LANGUAGE_VALUE,
                        "timesignature": FSMState.TIMESIG_VALUE,
                    }

                    if field_name in field_name_to_value_state:
                        old_state = self.state
                        self.state = field_name_to_value_state[field_name]
                        self.position_in_state = 0
                        self.accumulated_value = ""
                        self.accumulated_token_ids = []
                        self.caption_ending = False
                        self.pending_field_name = ""

                        if self.debug:
                            logger.debug(f"FSM transition (caption ending): {old_state.name} -> {self.state.name}")
                    else:
                        # Unrecognized field name — clear the sniffing state and
                        # fall back to the default sequential transition.
                        if self.debug:
                            logger.warning(f"Unknown field name after caption: {repr(field_name)}, forcing transition")
                        self.caption_ending = False
                        self.pending_field_name = ""
                        self._transition_to_next_state()

        elif self.state == FSMState.KEYSCALE_VALUE:
            if generated_token_id == self.newline_token:
                # Newline terminates the keyscale value.
                self._transition_to_next_state()

                # Early return before the next fixed-string state, as above.
                if self.state in self.fixed_strings:
                    return
            else:
                self.accumulated_token_ids.append(generated_token_id)

                self.accumulated_value += token_str

        elif self.state == FSMState.LANGUAGE_VALUE:
            if generated_token_id == self.newline_token:
                # Newline terminates the language value.
                self._transition_to_next_state()
                if self.state in self.fixed_strings:
                    return
            else:
                self.accumulated_token_ids.append(generated_token_id)

                self.accumulated_value += token_str
|
|
|
|