| """ |
| ACE-Step Inference API Module |
| |
| This module provides a standardized inference interface for music generation, |
| designed for third-party integration. It offers both a simplified API and |
| backward-compatible Gradio UI support. |
| """ |
|
|
| import math |
| import os |
| import tempfile |
| from typing import Optional, Union, List, Dict, Any, Tuple |
| from dataclasses import dataclass, field, asdict |
| from loguru import logger |
|
|
| from acestep.audio_utils import AudioSaver, generate_uuid_from_params |
|
|
|
|
@dataclass
class GenerationParams:
    """Configuration for music generation parameters.

    Attributes:
        # Text Inputs
        caption: A short text prompt describing the desired music (main prompt). < 512 characters
        lyrics: Lyrics for the music. Use "[Instrumental]" for instrumental songs. < 4096 characters
        instrumental: If True, generate instrumental music regardless of lyrics.

        # Music Metadata
        bpm: BPM (beats per minute), e.g., 120. Set to None for automatic estimation. 30 ~ 300
        keyscale: Musical key (e.g., "C Major", "Am"). Leave empty for auto-detection. A-G, #/♭, major/minor
        timesignature: Time signature (2 for '2/4', 3 for '3/4', 4 for '4/4', 6 for '6/8'). Leave empty for auto-detection.
        vocal_language: Language code for vocals, e.g., "en", "zh", "ja", or "unknown". see acestep/constants.py:VALID_LANGUAGES
        duration: Target audio length in seconds. If <0 or None, model chooses automatically. 10 ~ 600

        # Generation Parameters
        inference_steps: Number of diffusion steps (e.g., 8 for turbo, 32–100 for base model).
        guidance_scale: CFG (classifier-free guidance) strength. Higher means following the prompt more strictly. Only support for non-turbo model.
        seed: Integer seed for reproducibility. -1 means use random seed each time.

        # Advanced DiT Parameters
        use_adg: Whether to use Adaptive Dual Guidance (only works for base model).
        cfg_interval_start: Start ratio (0.0–1.0) to apply CFG.
        cfg_interval_end: End ratio (0.0–1.0) to apply CFG.
        shift: Timestep shift factor (default 1.0). When != 1.0, applies t = shift * t / (1 + (shift - 1) * t) to timesteps.
        infer_method: Sampling method string forwarded to the DiT handler (default "ode").
        timesteps: Optional explicit list of diffusion timesteps forwarded to the DiT handler; None uses the handler's default schedule.

        # Task-Specific Parameters
        task_type: Type of generation task. One of: "text2music", "cover", "repaint", "lego", "extract", "complete".
        reference_audio: Path to a reference audio file for style transfer or cover tasks.
        src_audio: Path to a source audio file for audio-to-audio tasks.
        audio_codes: Audio semantic codes as a string (advanced use, for code-control generation).
        repainting_start: For repaint/lego tasks: start time in seconds for region to repaint.
        repainting_end: For repaint/lego tasks: end time in seconds for region to repaint (-1 for until end).
        audio_cover_strength: Strength of reference audio/codes influence (range 0.0–1.0). set smaller (0.2) for style transfer tasks.
        instruction: Optional task instruction prompt. If empty, auto-generated by system.

        # 5Hz Language Model Parameters for CoT reasoning
        thinking: If True, enable 5Hz Language Model "Chain-of-Thought" reasoning for semantic/music metadata and codes.
        lm_temperature: Sampling temperature for the LLM (0.0–2.0). Higher = more creative/varied results.
        lm_cfg_scale: Classifier-free guidance scale for the LLM.
        lm_top_k: LLM top-k sampling (0 = disabled).
        lm_top_p: LLM top-p nucleus sampling (1.0 = disabled).
        lm_negative_prompt: Negative prompt to use for LLM (for control).
        use_cot_metas: Whether to let LLM generate music metadata via CoT reasoning.
        use_cot_caption: Whether to let LLM rewrite or format the input caption via CoT reasoning.
        use_cot_language: Whether to let LLM detect vocal language via CoT.
        use_cot_lyrics: Whether to let LLM rewrite lyrics via CoT. NOTE(review): declared but
            not referenced in this module — presumably consumed by the handlers; verify.
        use_constrained_decoding: Whether to use constrained decoding for LM outputs
            (forwarded to the LM handler by generate_music).

        # CoT Outputs (written back by generate_music)
        cot_bpm, cot_keyscale, cot_timesignature, cot_duration, cot_vocal_language,
        cot_caption, cot_lyrics: LM-generated values recorded on this object when the
            corresponding user-supplied field was left unset.
    """

    # --- Task selection ---
    task_type: str = "text2music"
    instruction: str = "Fill the audio semantic mask based on the given conditions:"

    # --- Audio file inputs ---
    reference_audio: Optional[str] = None
    src_audio: Optional[str] = None

    # Pre-computed audio semantic codes; when non-empty, generate_music skips
    # LM code generation and feeds these straight to the DiT.
    audio_codes: str = ""

    # --- Text inputs ---
    caption: str = ""
    lyrics: str = ""
    instrumental: bool = False

    # --- Music metadata (empty string / None / negative = auto-detect) ---
    vocal_language: str = "unknown"
    bpm: Optional[int] = None
    keyscale: str = ""
    timesignature: str = ""
    duration: float = -1.0

    # --- Diffusion / DiT parameters ---
    inference_steps: int = 8
    seed: int = -1
    guidance_scale: float = 7.0
    use_adg: bool = False
    cfg_interval_start: float = 0.0
    cfg_interval_end: float = 1.0
    shift: float = 1.0
    infer_method: str = "ode"

    # Explicit timestep schedule override (None = handler default).
    timesteps: Optional[List[float]] = None

    # --- Repaint / cover controls ---
    repainting_start: float = 0.0
    repainting_end: float = -1
    audio_cover_strength: float = 1.0

    # --- 5Hz LM (CoT) parameters ---
    thinking: bool = True
    lm_temperature: float = 0.85
    lm_cfg_scale: float = 2.0
    lm_top_k: int = 0
    lm_top_p: float = 0.9
    lm_negative_prompt: str = "NO USER INPUT"
    use_cot_metas: bool = True
    use_cot_caption: bool = True
    use_cot_lyrics: bool = False
    use_cot_language: bool = True
    use_constrained_decoding: bool = True

    # --- CoT outputs: filled in by generate_music when the LM supplies a value
    # for a field the user left unset ---
    cot_bpm: Optional[int] = None
    cot_keyscale: str = ""
    cot_timesignature: str = ""
    cot_duration: Optional[float] = None
    cot_vocal_language: str = "unknown"
    cot_caption: str = ""
    cot_lyrics: str = ""

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)
|
|
|
|
@dataclass
class GenerationConfig:
    """Configuration for music generation.

    Attributes:
        batch_size: Number of audio samples to generate
        allow_lm_batch: Whether to allow batch processing in LM
        use_random_seed: Whether to use random seed
        seeds: Seed(s) for batch generation. Can be:
            - None: Use random seeds (when use_random_seed=True) or params.seed (when use_random_seed=False)
            - List[int]: List of seeds, will be padded with random seeds if fewer than batch_size
            - int: Single seed value (will be converted to list and padded)
        lm_batch_chunk_size: Batch chunk size for LM processing
        constrained_decoding_debug: Whether to enable constrained decoding debug
        audio_format: Output audio format, one of "mp3", "wav", "flac". Default: "flac"
    """
    # Number of samples produced per call; None is treated as 1 by generate_music.
    batch_size: int = 2
    # Allow the LM to process the whole batch in one pass.
    allow_lm_batch: bool = False
    # Randomize seeds when none are supplied explicitly.
    use_random_seed: bool = True
    # Explicit seed list; generate_music joins it into a comma-separated string
    # for the DiT handler's prepare_seeds().
    seeds: Optional[List[int]] = None
    # Maximum LM sub-batch size; generate_music splits the batch into chunks
    # of this size (<=0 means "process the whole batch at once").
    lm_batch_chunk_size: int = 8
    # Verbose logging for constrained decoding.
    constrained_decoding_debug: bool = False
    # Output container format; falsy values fall back to "flac" in generate_music.
    audio_format: str = "flac"

    def to_dict(self) -> Dict[str, Any]:
        """Convert config to dictionary for JSON serialization."""
        return asdict(self)
|
|
|
|
@dataclass
class GenerationResult:
    """Result of music generation.

    Attributes:
        audios: List of audio dictionaries; each entry carries the keys
            "path", "tensor", "key", "sample_rate", and "params"
            (as assembled by generate_music).
        status_message: Status message from generation
        extra_outputs: Extra outputs from generation (generate_music populates
            "lm_metadata" and a unified "time_costs" dict here)
        success: Whether generation completed successfully
        error: Error message if generation failed
    """

    # One dict per generated sample.
    audios: List[Dict[str, Any]] = field(default_factory=list)

    status_message: str = ""
    extra_outputs: Dict[str, Any] = field(default_factory=dict)

    success: bool = True
    # Set only on failure; None on success.
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
|
|
|
|
@dataclass
class UnderstandResult:
    """Result of music understanding from audio codes.

    Attributes:
        # Metadata Fields
        caption: Generated caption describing the music
        lyrics: Generated or extracted lyrics
        bpm: Beats per minute (None if not detected)
        duration: Duration in seconds (None if not detected)
        keyscale: Musical key (e.g., "C Major")
        language: Vocal language code (e.g., "en", "zh")
        timesignature: Time signature (e.g., "4/4")

        # Status
        status_message: Status message from understanding
        success: Whether understanding completed successfully
        error: Error message if understanding failed
    """

    caption: str = ""
    lyrics: str = ""
    # None when the LM returned "N/A"/empty or an unparseable value.
    bpm: Optional[int] = None
    duration: Optional[float] = None
    # Empty string when the LM returned the placeholder "N/A".
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""

    status_message: str = ""
    success: bool = True
    # Set only on failure; None on success.
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
|
|
|
|
| def _update_metadata_from_lm( |
| metadata: Dict[str, Any], |
| bpm: Optional[int], |
| key_scale: str, |
| time_signature: str, |
| audio_duration: Optional[float], |
| vocal_language: str, |
| caption: str, |
| lyrics: str, |
| ) -> Tuple[Optional[int], str, str, Optional[float]]: |
| """Update metadata fields from LM output if not provided by user.""" |
|
|
| if bpm is None and metadata.get('bpm'): |
| bpm_value = metadata.get('bpm') |
| if bpm_value not in ["N/A", ""]: |
| try: |
| bpm = int(bpm_value) |
| except (ValueError, TypeError): |
| pass |
|
|
| if not key_scale and metadata.get('keyscale'): |
| key_scale_value = metadata.get('keyscale', metadata.get('key_scale', "")) |
| if key_scale_value != "N/A": |
| key_scale = key_scale_value |
|
|
| if not time_signature and metadata.get('timesignature'): |
| time_signature_value = metadata.get('timesignature', metadata.get('time_signature', "")) |
| if time_signature_value != "N/A": |
| time_signature = time_signature_value |
|
|
| if audio_duration is None: |
| audio_duration_value = metadata.get('duration', -1) |
| if audio_duration_value not in ["N/A", ""]: |
| try: |
| audio_duration = float(audio_duration_value) |
| except (ValueError, TypeError): |
| pass |
|
|
| if not vocal_language and metadata.get('vocal_language'): |
| vocal_language = metadata.get('vocal_language') |
| if not caption and metadata.get('caption'): |
| caption = metadata.get('caption') |
| if not lyrics and metadata.get('lyrics'): |
| lyrics = metadata.get('lyrics') |
| return bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics |
|
|
|
|
def generate_music(
    dit_handler,
    llm_handler,
    params: GenerationParams,
    config: GenerationConfig,
    save_dir: Optional[str] = None,
    progress=None,
) -> GenerationResult:
    """Generate music using ACE-Step model with optional LM reasoning.

    The pipeline has two phases:
      1. Optional 5Hz LM phase that fills in missing metadata (bpm, key,
         duration, ...) and/or generates audio semantic codes.
      2. DiT diffusion phase that renders the final audio.

    Args:
        dit_handler: Initialized DiT model handler (AceStepHandler instance)
        llm_handler: Initialized LLM handler (LLMHandler instance)
        params: Generation parameters (GenerationParams instance)
        config: Generation configuration (GenerationConfig instance)
        save_dir: Directory to save generated audio into; None skips saving.
        progress: Optional progress callback forwarded to both handlers.

    Returns:
        GenerationResult with generated audio files and metadata
    """
    try:
        audio_code_string_to_use = params.audio_codes
        lm_generated_metadata = None
        lm_generated_audio_codes_list = []
        # LM timing accumulated across all chunks.
        lm_total_time_costs = {
            "phase1_time": 0.0,
            "phase2_time": 0.0,
            "total_time": 0.0,
        }

        # Working copies of user-supplied metadata; the LM may fill gaps below.
        bpm = params.bpm
        key_scale = params.keyscale
        time_signature = params.timesignature
        audio_duration = params.duration
        dit_input_caption = params.caption
        dit_input_vocal_language = params.vocal_language
        dit_input_lyrics = params.lyrics

        user_provided_audio_codes = bool(params.audio_codes and str(params.audio_codes).strip())
        # Only ask the LM for codes when the user did not supply any.
        need_audio_codes = not user_provided_audio_codes

        actual_batch_size = config.batch_size if config.batch_size is not None else 1

        # Normalize the optional seed list to the comma-separated string format
        # expected by the handler's prepare_seeds().
        seed_for_generation = ""
        if config.seeds is not None and len(config.seeds) > 0:
            if isinstance(config.seeds, list):
                seed_for_generation = ",".join(str(s) for s in config.seeds)

        actual_seed_list, _ = dit_handler.prepare_seeds(actual_batch_size, seed_for_generation, config.use_random_seed)

        # Tasks that operate directly on existing audio bypass the LM phase.
        skip_lm_tasks = {"cover", "repaint"}

        need_lm_for_cot = params.use_cot_caption or params.use_cot_language or params.use_cot_metas
        use_lm = (params.thinking or need_lm_for_cot) and llm_handler.llm_initialized and params.task_type not in skip_lm_tasks
        lm_status = []

        if params.task_type in skip_lm_tasks:
            logger.info(f"Skipping LM for task_type='{params.task_type}' - using DiT directly")

        logger.info(f"[generate_music] LLM usage decision: thinking={params.thinking}, "
                    f"use_cot_caption={params.use_cot_caption}, use_cot_language={params.use_cot_language}, "
                    f"use_cot_metas={params.use_cot_metas}, need_lm_for_cot={need_lm_for_cot}, "
                    f"llm_initialized={llm_handler.llm_initialized if llm_handler else False}, use_lm={use_lm}")

        if use_lm:
            # Normalize "disabled" sampling knobs to None for the LM API.
            top_k_value = None if not params.lm_top_k or params.lm_top_k == 0 else int(params.lm_top_k)
            top_p_value = None if not params.lm_top_p or params.lm_top_p >= 1.0 else params.lm_top_p

            # Collect user-provided metadata so the LM treats those fields as fixed.
            user_metadata = {}
            if bpm is not None:
                try:
                    bpm_value = float(bpm)
                    if bpm_value > 0:
                        user_metadata['bpm'] = int(bpm_value)
                except (ValueError, TypeError):
                    pass

            if key_scale and key_scale.strip():
                key_scale_clean = key_scale.strip()
                if key_scale_clean.lower() not in ["n/a", ""]:
                    user_metadata['keyscale'] = key_scale_clean

            if time_signature and time_signature.strip():
                time_sig_clean = time_signature.strip()
                if time_sig_clean.lower() not in ["n/a", ""]:
                    user_metadata['timesignature'] = time_sig_clean

            if audio_duration is not None:
                try:
                    duration_value = float(audio_duration)
                    if duration_value > 0:
                        user_metadata['duration'] = int(duration_value)
                except (ValueError, TypeError):
                    pass

            user_metadata_to_pass = user_metadata if user_metadata else None

            # "llm_dit" also generates audio codes; "dit" generates metadata only.
            infer_type = "llm_dit" if need_audio_codes and params.thinking else "dit"

            # Split the batch into LM-sized chunks to bound peak memory.
            max_inference_batch_size = int(config.lm_batch_chunk_size) if config.lm_batch_chunk_size > 0 else actual_batch_size
            num_chunks = math.ceil(actual_batch_size / max_inference_batch_size)

            all_metadata_list = []
            all_audio_codes_list = []

            for chunk_idx in range(num_chunks):
                chunk_start = chunk_idx * max_inference_batch_size
                chunk_end = min(chunk_start + max_inference_batch_size, actual_batch_size)
                chunk_size = chunk_end - chunk_start
                chunk_seeds = actual_seed_list[chunk_start:chunk_end] if chunk_start < len(actual_seed_list) else None

                logger.info(f"LM chunk {chunk_idx+1}/{num_chunks} (infer_type={infer_type}) "
                            f"(size: {chunk_size}, seeds: {chunk_seeds})")

                result = llm_handler.generate_with_stop_condition(
                    caption=params.caption or "",
                    lyrics=params.lyrics or "",
                    infer_type=infer_type,
                    temperature=params.lm_temperature,
                    cfg_scale=params.lm_cfg_scale,
                    negative_prompt=params.lm_negative_prompt,
                    top_k=top_k_value,
                    top_p=top_p_value,
                    user_metadata=user_metadata_to_pass,
                    use_cot_caption=params.use_cot_caption,
                    use_cot_language=params.use_cot_language,
                    use_cot_metas=params.use_cot_metas,
                    use_constrained_decoding=params.use_constrained_decoding,
                    constrained_decoding_debug=config.constrained_decoding_debug,
                    batch_size=chunk_size,
                    seeds=chunk_seeds,
                    progress=progress,
                )

                if not result.get("success", False):
                    error_msg = result.get("error", "Unknown LM error")
                    lm_status.append(f"❌ LM Error: {error_msg}")
                    return GenerationResult(
                        audios=[],
                        status_message=f"❌ LM generation failed: {error_msg}",
                        extra_outputs={},
                        success=False,
                        error=error_msg,
                    )

                # Batched chunks return lists; single-sample chunks return scalars.
                if chunk_size > 1:
                    metadata_list = result.get("metadata", [])
                    audio_codes_list = result.get("audio_codes", [])
                    all_metadata_list.extend(metadata_list)
                    all_audio_codes_list.extend(audio_codes_list)
                else:
                    metadata = result.get("metadata", {})
                    audio_codes = result.get("audio_codes", "")
                    all_metadata_list.append(metadata)
                    all_audio_codes_list.append(audio_codes)

                # Accumulate per-chunk timing into the pipeline totals.
                lm_extra = result.get("extra_outputs", {})
                lm_chunk_time_costs = lm_extra.get("time_costs", {})
                if lm_chunk_time_costs:
                    for key in ["phase1_time", "phase2_time", "total_time"]:
                        if key in lm_chunk_time_costs:
                            lm_total_time_costs[key] += lm_chunk_time_costs[key]

                    time_str = ", ".join([f"{k}: {v:.2f}s" for k, v in lm_chunk_time_costs.items()])
                    lm_status.append(f"✅ LM chunk {chunk_idx+1}: {time_str}")

            # First sample's metadata is used to fill user gaps below.
            lm_generated_metadata = all_metadata_list[0] if all_metadata_list else None
            lm_generated_audio_codes_list = all_audio_codes_list

            if infer_type == "llm_dit":
                # LM produced codes: hand them to the DiT (a list for batches).
                if actual_batch_size > 1:
                    audio_code_string_to_use = all_audio_codes_list
                else:
                    audio_code_string_to_use = all_audio_codes_list[0] if all_audio_codes_list else ""
            else:
                # Metadata-only LM pass: keep whatever codes the user supplied.
                audio_code_string_to_use = params.audio_codes

            if lm_generated_metadata:
                bpm, key_scale, time_signature, audio_duration, vocal_language, caption, lyrics = _update_metadata_from_lm(
                    metadata=lm_generated_metadata,
                    bpm=bpm,
                    key_scale=key_scale,
                    time_signature=time_signature,
                    audio_duration=audio_duration,
                    vocal_language=dit_input_vocal_language,
                    caption=dit_input_caption,
                    lyrics=dit_input_lyrics)
                # Record LM-derived values on params for fields the user left unset.
                # NOTE(review): `not` treats 0/None/"" as unset; the -1.0 duration
                # default is truthy, so cot_duration is only set when the user
                # passed a falsy duration — confirm this is intended.
                if not params.bpm:
                    params.cot_bpm = bpm
                if not params.keyscale:
                    params.cot_keyscale = key_scale
                if not params.timesignature:
                    params.cot_timesignature = time_signature
                if not params.duration:
                    params.cot_duration = audio_duration
                if not params.vocal_language:
                    params.cot_vocal_language = vocal_language
                if not params.caption:
                    params.cot_caption = caption
                if not params.lyrics:
                    params.cot_lyrics = lyrics

                # Feed CoT-rewritten caption/language into the DiT inputs when enabled.
                if params.use_cot_caption:
                    dit_input_caption = lm_generated_metadata.get("caption", dit_input_caption)
                if params.use_cot_language:
                    dit_input_vocal_language = lm_generated_metadata.get("vocal_language", dit_input_vocal_language)

        # Phase 2: render audio with the DiT diffusion model.
        result = dit_handler.generate_music(
            captions=dit_input_caption,
            lyrics=dit_input_lyrics,
            bpm=bpm,
            key_scale=key_scale,
            time_signature=time_signature,
            vocal_language=dit_input_vocal_language,
            inference_steps=params.inference_steps,
            guidance_scale=params.guidance_scale,
            use_random_seed=config.use_random_seed,
            seed=seed_for_generation,
            reference_audio=params.reference_audio,
            audio_duration=audio_duration,
            batch_size=actual_batch_size,
            src_audio=params.src_audio,
            audio_code_string=audio_code_string_to_use,
            repainting_start=params.repainting_start,
            repainting_end=params.repainting_end,
            instruction=params.instruction,
            audio_cover_strength=params.audio_cover_strength,
            task_type=params.task_type,
            use_adg=params.use_adg,
            cfg_interval_start=params.cfg_interval_start,
            cfg_interval_end=params.cfg_interval_end,
            shift=params.shift,
            infer_method=params.infer_method,
            timesteps=params.timesteps,
            progress=progress,
        )

        if not result.get("success", False):
            return GenerationResult(
                audios=[],
                status_message=result.get("status_message", ""),
                extra_outputs={},
                success=False,
                error=result.get("error"),
            )

        dit_audios = result.get("audios", [])
        status_message = result.get("status_message", "")
        dit_extra_outputs = result.get("extra_outputs", {})

        # Shared base for each per-sample params dict.
        base_params_dict = params.to_dict()

        audio_format = config.audio_format if config.audio_format else "flac"
        audio_saver = AudioSaver(default_format=audio_format)

        if save_dir is not None:
            os.makedirs(save_dir, exist_ok=True)

        # Package each generated sample with its own seed/codes/key.
        audios = []
        for idx, dit_audio in enumerate(dit_audios):
            audio_params = base_params_dict.copy()
            audio_params["seed"] = actual_seed_list[idx] if idx < len(actual_seed_list) else None

            # Record the codes actually used for this sample, if the LM made any.
            if lm_generated_audio_codes_list and idx < len(lm_generated_audio_codes_list):
                audio_params["audio_codes"] = lm_generated_audio_codes_list[idx]

            audio_tensor = dit_audio.get("tensor")
            sample_rate = dit_audio.get("sample_rate", 48000)

            # Deterministic key derived from the full per-sample parameter set.
            audio_key = generate_uuid_from_params(audio_params)

            audio_path = None
            if audio_tensor is not None and save_dir is not None:
                try:
                    audio_file = os.path.join(save_dir, f"{audio_key}.{audio_format}")
                    audio_path = audio_saver.save_audio(audio_tensor,
                                                       audio_file,
                                                       sample_rate=sample_rate,
                                                       format=audio_format,
                                                       channels_first=True)
                except Exception as e:
                    # Best-effort save: keep the in-memory tensor even when
                    # writing to disk fails.
                    logger.error(f"[generate_music] Failed to save audio file: {e}")
                    audio_path = ""

            audio_dict = {
                "path": audio_path or "",
                "tensor": audio_tensor,
                "key": audio_key,
                "sample_rate": sample_rate,
                "params": audio_params,
            }

            audios.append(audio_dict)

        extra_outputs = dit_extra_outputs.copy()
        extra_outputs["lm_metadata"] = lm_generated_metadata

        # Merge LM and DiT timings into one namespaced dict.
        unified_time_costs = {}

        if use_lm:
            for key, value in lm_total_time_costs.items():
                unified_time_costs[f"lm_{key}"] = value

        dit_time_costs = dit_extra_outputs.get("time_costs", {})
        if dit_time_costs:
            for key, value in dit_time_costs.items():
                unified_time_costs[f"dit_{key}"] = value

        if unified_time_costs:
            lm_total = unified_time_costs.get("lm_total_time", 0.0)
            dit_total = unified_time_costs.get("dit_total_time_cost", 0.0)
            unified_time_costs["pipeline_total_time"] = lm_total + dit_total

        extra_outputs["time_costs"] = unified_time_costs

        # Prepend the per-chunk LM status lines, when the LM ran.
        if lm_status:
            status_message = "\n".join(lm_status) + "\n" + status_message

        return GenerationResult(
            audios=audios,
            status_message=status_message,
            extra_outputs=extra_outputs,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Music generation failed")
        return GenerationResult(
            audios=[],
            status_message=f"Error: {str(e)}",
            extra_outputs={},
            success=False,
            error=str(e),
        )
|
|
|
|
def understand_music(
    llm_handler,
    audio_codes: str,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> UnderstandResult:
    """Analyze audio semantic codes with the 5Hz Language Model.

    Produces music metadata (caption, lyrics, BPM, duration, key scale,
    language, time signature) from a string of audio code tokens. A blank
    or "NO USER INPUT" code string asks the LM to invent a sample example
    instead of analyzing existing codes.

    Note: cfg_scale and negative_prompt are not supported in understand mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance)
        audio_codes: Audio code token string
            (e.g., "<|audio_code_123|><|audio_code_456|>..."); empty or
            "NO USER INPUT" generates a sample example.
        temperature: Sampling temperature (0.0-2.0); higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Use FSM-based constrained decoding for metadata
        constrained_decoding_debug: Enable constrained-decoding debug logging

    Returns:
        UnderstandResult with parsed metadata fields and status

    Example:
        >>> result = understand_music(llm_handler, audio_codes="<|audio_code_123|>...")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
        ...     print(f"BPM: {result.bpm}")
        ...     print(f"Lyrics: {result.lyrics}")
    """
    if not llm_handler.llm_initialized:
        return UnderstandResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    # Blank / whitespace-only input switches the LM into sample-example mode.
    if not (audio_codes and audio_codes.strip()):
        audio_codes = "NO USER INPUT"

    try:
        metadata, status = llm_handler.understand_audio_from_codes(
            audio_codes=audio_codes,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        if not metadata:
            return UnderstandResult(
                status_message=status or "Failed to understand audio codes",
                success=False,
                error=status or "Empty metadata returned",
            )

        def _parse_number(raw, cast):
            # Missing / placeholder / unparseable values map to None.
            if raw is None or raw == 'N/A' or raw == '':
                return None
            try:
                return cast(raw)
            except (ValueError, TypeError):
                return None

        def _scrub(text):
            # The LM emits the literal string "N/A" for unknown text fields.
            return '' if text == 'N/A' else text

        return UnderstandResult(
            caption=metadata.get('caption', ''),
            lyrics=metadata.get('lyrics', ''),
            bpm=_parse_number(metadata.get('bpm'), int),
            duration=_parse_number(metadata.get('duration'), float),
            keyscale=_scrub(metadata.get('keyscale', '')),
            language=_scrub(metadata.get('language', metadata.get('vocal_language', ''))),
            timesignature=_scrub(metadata.get('timesignature', '')),
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Music understanding failed")
        return UnderstandResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
|
|
|
@dataclass
class CreateSampleResult:
    """Result of creating a music sample from a natural language query.

    This is used by the "Simple Mode" / "Inspiration Mode" feature where users
    provide a natural language description and the LLM generates a complete
    sample with caption, lyrics, and metadata.

    Attributes:
        # Metadata Fields
        caption: Generated detailed music description/caption
        lyrics: Generated lyrics (or "[Instrumental]" for instrumental music)
        bpm: Beats per minute (None if not generated)
        duration: Duration in seconds (None if not generated)
        keyscale: Musical key (e.g., "C Major")
        language: Vocal language code (e.g., "en", "zh")
        timesignature: Time signature (e.g., "4")
        instrumental: Whether this is an instrumental piece

        # Status
        status_message: Status message from sample creation
        success: Whether sample creation completed successfully
        error: Error message if sample creation failed
    """

    caption: str = ""
    lyrics: str = ""
    # None when the LM returned "N/A"/empty or an unparseable value.
    bpm: Optional[int] = None
    duration: Optional[float] = None
    # Empty string when the LM returned the placeholder "N/A".
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""
    instrumental: bool = False

    status_message: str = ""
    success: bool = True
    # Set only on failure; None on success.
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert result to dictionary for JSON serialization."""
        return asdict(self)
|
|
|
|
def create_sample(
    llm_handler,
    query: str,
    instrumental: bool = False,
    vocal_language: Optional[str] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> CreateSampleResult:
    """Generate a complete music sample from a natural language query.

    Implements the "Simple Mode" / "Inspiration Mode" feature: given a free-form
    music description, the 5Hz Language Model produces a detailed caption,
    lyrics (unless instrumental), and metadata (BPM, duration, key, language,
    time signature).

    Note: cfg_scale and negative_prompt are not supported in create_sample mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance)
        query: Natural language music description (e.g., "a soft Bengali love song")
        instrumental: Generate instrumental music (no vocals)
        vocal_language: Allowed vocal language for constrained decoding
            (e.g., "en", "zh"). None or "unknown" applies no language constraint.
        temperature: Sampling temperature (0.0-2.0); higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Use FSM-based constrained decoding
        constrained_decoding_debug: Enable debug logging

    Returns:
        CreateSampleResult with generated sample fields and status

    Example:
        >>> result = create_sample(llm_handler, "a soft Bengali love song for a quiet evening", vocal_language="bn")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
        ...     print(f"Lyrics: {result.lyrics}")
        ...     print(f"BPM: {result.bpm}")
    """
    if not llm_handler.llm_initialized:
        return CreateSampleResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    try:
        metadata, status = llm_handler.create_sample_from_query(
            query=query,
            instrumental=instrumental,
            vocal_language=vocal_language,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        if not metadata:
            return CreateSampleResult(
                status_message=status or "Failed to create sample",
                success=False,
                error=status or "Empty metadata returned",
            )

        def _parse_number(raw, cast):
            # Missing / placeholder / unparseable values map to None.
            if raw is None or raw == 'N/A' or raw == '':
                return None
            try:
                return cast(raw)
            except (ValueError, TypeError):
                return None

        def _scrub(text):
            # The LM emits the literal string "N/A" for unknown text fields.
            return '' if text == 'N/A' else text

        return CreateSampleResult(
            caption=metadata.get('caption', ''),
            lyrics=metadata.get('lyrics', ''),
            bpm=_parse_number(metadata.get('bpm'), int),
            duration=_parse_number(metadata.get('duration'), float),
            keyscale=_scrub(metadata.get('keyscale', '')),
            language=_scrub(metadata.get('language', metadata.get('vocal_language', ''))),
            timesignature=_scrub(metadata.get('timesignature', '')),
            instrumental=metadata.get('instrumental', instrumental),
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        logger.exception("Sample creation failed")
        return CreateSampleResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
|
|
|
@dataclass
class FormatSampleResult:
    """Structured output of the caption/lyrics "Format" feature.

    Produced when a user submits a caption and lyrics and the LLM rewrites
    them into structured music metadata plus an enhanced description.

    Attributes:
        caption: Enhanced/formatted music description.
        lyrics: Formatted lyrics (may echo the input unchanged).
        bpm: Beats per minute, or None when not detected.
        duration: Length in seconds, or None when not detected.
        keyscale: Musical key such as "C Major" ("" when unknown).
        language: Vocal language code such as "en" or "zh" ("" when unknown).
        timesignature: Time signature such as "4" ("" when unknown).
        status_message: Human-readable status from the formatting run.
        success: True when formatting completed without error.
        error: Error description when success is False, otherwise None.
    """

    # -- formatted metadata --
    caption: str = ""
    lyrics: str = ""
    bpm: Optional[int] = None
    duration: Optional[float] = None
    keyscale: str = ""
    language: str = ""
    timesignature: str = ""

    # -- outcome / status --
    status_message: str = ""
    success: bool = True
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict copy of this result for JSON serialization."""
        return asdict(self)
|
|
|
|
def _parse_optional_int(value: Any) -> Optional[int]:
    """Coerce an LLM metadata value to int.

    Returns None for missing values, the 'N/A' sentinel, empty strings, and
    anything that cannot be parsed. Goes through float() first so that
    float-formatted strings such as "120.0" are accepted (plain int("120.0")
    would raise ValueError and silently drop the value).
    """
    if value is None or value == 'N/A' or value == '':
        return None
    try:
        return int(float(value))
    except (ValueError, TypeError):
        return None


def _parse_optional_float(value: Any) -> Optional[float]:
    """Coerce an LLM metadata value to float.

    Returns None for missing values, the 'N/A' sentinel, empty strings, and
    anything that cannot be parsed.
    """
    if value is None or value == 'N/A' or value == '':
        return None
    try:
        return float(value)
    except (ValueError, TypeError):
        return None


def _clean_na(value: str) -> str:
    """Map the LLM's 'N/A' sentinel to an empty string; pass everything else through."""
    return '' if value == 'N/A' else value


def format_sample(
    llm_handler,
    caption: str,
    lyrics: str,
    user_metadata: Optional[Dict[str, Any]] = None,
    temperature: float = 0.85,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    repetition_penalty: float = 1.0,
    use_constrained_decoding: bool = True,
    constrained_decoding_debug: bool = False,
) -> FormatSampleResult:
    """Format user-provided caption and lyrics using the 5Hz Language Model.

    This function takes user input (caption and lyrics) and generates structured
    music metadata including an enhanced caption, BPM, duration, key, language,
    and time signature.

    If user_metadata is provided, those values will be used to constrain the
    decoding, ensuring the output matches user-specified values.

    Note: cfg_scale and negative_prompt are not supported in format mode.

    Args:
        llm_handler: Initialized LLM handler (LLMHandler instance)
        caption: User's caption/description (e.g., "Latin pop, reggaeton")
        lyrics: User's lyrics with structure tags
        user_metadata: Optional dict with user-provided metadata to constrain decoding.
            Supported keys: bpm, duration, keyscale, timesignature, language
        temperature: Sampling temperature for generation (0.0-2.0). Higher = more creative.
        top_k: Top-K sampling (None or 0 = disabled)
        top_p: Top-P (nucleus) sampling (None or 1.0 = disabled)
        repetition_penalty: Repetition penalty (1.0 = no penalty)
        use_constrained_decoding: Whether to use FSM-based constrained decoding for metadata
        constrained_decoding_debug: Whether to enable debug logging for constrained decoding

    Returns:
        FormatSampleResult with formatted metadata fields and status. Never
        raises: any exception from the handler is caught and reported via
        success=False / error.

    Example:
        >>> result = format_sample(llm_handler, "Latin pop, reggaeton", "[Verse 1]\\nHola mundo...")
        >>> if result.success:
        ...     print(f"Caption: {result.caption}")
        ...     print(f"BPM: {result.bpm}")
        ...     print(f"Lyrics: {result.lyrics}")
    """
    # Fail fast with a user-facing message if the LM was never initialized.
    if not llm_handler.llm_initialized:
        return FormatSampleResult(
            status_message="5Hz LM not initialized. Please initialize it first.",
            success=False,
            error="LLM not initialized",
        )

    try:
        metadata, status = llm_handler.format_sample_from_input(
            caption=caption,
            lyrics=lyrics,
            user_metadata=user_metadata,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            use_constrained_decoding=use_constrained_decoding,
            constrained_decoding_debug=constrained_decoding_debug,
        )

        # An empty/None metadata dict means the LM produced nothing usable.
        if not metadata:
            return FormatSampleResult(
                status_message=status or "Failed to format input",
                success=False,
                error=status or "Empty metadata returned",
            )

        return FormatSampleResult(
            caption=metadata.get('caption', ''),
            # Fall back to the user's original lyrics if the LM omitted them.
            lyrics=metadata.get('lyrics', lyrics),
            bpm=_parse_optional_int(metadata.get('bpm')),
            duration=_parse_optional_float(metadata.get('duration')),
            keyscale=_clean_na(metadata.get('keyscale', '')),
            # Older handler versions emit 'vocal_language' instead of 'language'.
            language=_clean_na(metadata.get('language', metadata.get('vocal_language', ''))),
            timesignature=_clean_na(metadata.get('timesignature', '')),
            status_message=status,
            success=True,
            error=None,
        )

    except Exception as e:
        # Never propagate: callers rely on always getting a result object.
        logger.exception("Format sample failed")
        return FormatSampleResult(
            status_message=f"Error: {str(e)}",
            success=False,
            error=str(e),
        )
|
|