"""
Enhanced profiling script for ACE-Step inference with deep LLM analysis.

This script helps diagnose why LLM generation is slow by tracking:
1. Total tokens generated vs expected throughput (200 tokens/sec baseline)
2. Per-iteration timing to detect compilation overhead or slow operations
3. Constrained decoding overhead
4. CFG overhead (2x forward passes)
5. Model forward time vs sampling/processing time

Usage:
    python profile_inference.py                # Standard profiling with warmup
    python profile_inference.py --no-warmup    # Profile first run (includes compilation)
    python profile_inference.py --llm-debug    # Deep LLM performance debugging
    python profile_inference.py --detailed     # Add cProfile function-level analysis

Inference mode options:
    python profile_inference.py --thinking                    # Enable CoT for audio code generation
    python profile_inference.py --use-constrained-decoding    # Use FSM constrained decoding
    python profile_inference.py --use-cot-metas                # Enable LM to generate metadata via CoT
"""
|
|
import time
import argparse
import sys
import os
from contextlib import contextmanager
from collections import defaultdict
import json
from typing import Tuple, Dict, Any, List
from functools import wraps


# Make the project root importable before pulling in project modules.
project_root = os.path.abspath(os.path.dirname(__file__))
if project_root not in sys.path:
    sys.path.insert(0, project_root)
|
|
|
|
def load_env_config():
    """Load model/runtime configuration defaults from the project's .env file."""
    env_config = {
        'ACESTEP_CONFIG_PATH': 'acestep-v15-turbo',
        'ACESTEP_LM_MODEL_PATH': 'acestep-5Hz-lm-0.6B',
        'ACESTEP_DEVICE': 'auto',
        'ACESTEP_LM_BACKEND': 'vllm',
    }

    env_file = os.path.join(project_root, '.env')
    if os.path.exists(env_file):
        with open(env_file, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                # Skip blank lines and comments
                if not line or line.startswith('#'):
                    continue
                # Parse KEY=VALUE pairs; only known keys override the defaults
                if '=' in line:
                    key, value = line.split('=', 1)
                    key = key.strip()
                    value = value.strip()
                    if key in env_config and value:
                        env_config[key] = value

    return env_config
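
# Illustrative .env layout that load_env_config() understands (a sketch, not a
# required file; the keys below are the only ones the parser reads, and the
# values shown simply repeat the built-in defaults):
#
#     # ACE-Step profiling defaults
#     ACESTEP_CONFIG_PATH=acestep-v15-turbo
#     ACESTEP_LM_MODEL_PATH=acestep-5Hz-lm-0.6B
#     ACESTEP_DEVICE=auto
#     ACESTEP_LM_BACKEND=vllm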
|
|
import torch
from acestep.inference import generate_music, GenerationParams, GenerationConfig
from acestep.handler import AceStepHandler
from acestep.llm_inference import LLMHandler
|
|
|
|
class PreciseTimer:
    """High-precision timer with CUDA synchronization for accurate GPU timing"""

    def __init__(self, device="cuda"):
        self.device = device
        self.timings = defaultdict(list)
        self.enabled = True

    def sync(self):
        """Synchronize CUDA operations for accurate timing"""
        if self.enabled and self.device.startswith("cuda") and torch.cuda.is_available():
            torch.cuda.synchronize()

    @contextmanager
    def time(self, name: str):
        """Time a code section with CUDA synchronization"""
        if not self.enabled:
            yield
            return

        self.sync()
        start = time.perf_counter()
        try:
            yield
        finally:
            self.sync()
            elapsed = time.perf_counter() - start
            self.timings[name].append(elapsed)

    def get_total(self, name: str) -> float:
        """Get total accumulated time for a section"""
        return sum(self.timings.get(name, []))

    def get_mean(self, name: str) -> float:
        """Get mean time per call for a section"""
        times = self.timings.get(name, [])
        return sum(times) / len(times) if times else 0.0

    def get_count(self, name: str) -> int:
        """Get number of calls for a section"""
        return len(self.timings.get(name, []))

    def get_all(self, name: str) -> List[float]:
        """Get all timing samples for a section"""
        return self.timings.get(name, [])
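
# Minimal usage sketch for PreciseTimer (illustrative; `run_model` is a
# hypothetical callable, and CUDA synchronization only happens when the device
# string starts with "cuda" and a GPU is available):
#
#     timer = PreciseTimer(device="cuda")
#     with timer.time("llm_inference"):
#         run_model()
#     print(timer.get_total("llm_inference"), timer.get_mean("llm_inference"))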
|
|
|
|
class LLMDebugger:
    """Track detailed LLM performance metrics to diagnose slow generation"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all metrics"""
        self.total_tokens = 0
        self.generation_start = None
        self.generation_end = None
        self.output_text = ""
        self.prompt_length = 0

    def start(self, prompt_length: int = 0):
        """Mark generation start"""
        self.generation_start = time.perf_counter()
        self.prompt_length = prompt_length

    def end(self, output_text: str = ""):
        """Mark generation end and store output"""
        self.generation_end = time.perf_counter()
        self.output_text = output_text

    def set_token_count(self, count: int):
        """Set total token count"""
        self.total_tokens = count

    def get_throughput(self) -> float:
        """Calculate actual tokens per second"""
        if self.generation_start and self.generation_end and self.total_tokens > 0:
            total_time = self.generation_end - self.generation_start
            if total_time > 0:
                return self.total_tokens / total_time
        return 0.0

    def print_analysis(self):
        """Print detailed LLM performance analysis"""
        if not self.generation_start or not self.generation_end:
            return

        print("\n" + "=" * 100)
        print("🔍 LLM PERFORMANCE DEEP DIVE")
        print("=" * 100)

        total_time = self.generation_end - self.generation_start
        throughput = self.get_throughput()

        # Core throughput metrics
        print(f"\n{'Metric':<40} {'Value':<20} {'Notes'}")
        print("-" * 100)
        print(f"{'Total Tokens Generated:':<40} {self.total_tokens:<20} (new tokens only)")
        print(f"{'Prompt Length (estimate):':<40} {self.prompt_length:<20} (input tokens)")
        print(f"{'Total Generation Time:':<40} {total_time:<20.3f} seconds")
        print(f"{'Measured Throughput:':<40} {throughput:<20.1f} tokens/sec")
        print(f"{'Expected Throughput:':<40} {'200':<20} tokens/sec (baseline)")

        # Comparison against the 200 tokens/sec baseline
        if throughput > 0:
            slowdown = 200.0 / throughput
            efficiency = (throughput / 200.0) * 100
            print(f"{'Performance vs Baseline:':<40} {efficiency:<20.1f}% of expected")
            print(f"{'Slowdown Factor:':<40} {slowdown:<20.2f}x slower")

        # Break down the generated text
        if self.output_text:
            print(f"\n{'Output Analysis:':<40}")
            print(f"{'  Output length:':<40} {len(self.output_text):<20} characters")

            # Count generated audio codes (the model emits 5 codes per second of audio)
            import re
            code_pattern = r'<\|audio_code_\d+\|>'
            codes = re.findall(code_pattern, self.output_text)
            if codes:
                print(f"{'  Audio codes generated:':<40} {len(codes):<20} codes")
                print(f"{'  Expected audio duration:':<40} {f'~{len(codes)/5:.1f}s':<20} (5 codes per second)")
                if total_time > 0:
                    print(f"{'  Time per audio code:':<40} {f'{total_time/len(codes)*1000:.1f}ms':<20}")

            # Estimate CoT tokens (rough heuristic: ~4 characters per token)
            if '<think>' in self.output_text and '</think>' in self.output_text:
                cot_start = self.output_text.find('<think>')
                cot_end = self.output_text.find('</think>') + len('</think>')
                cot_section = self.output_text[cot_start:cot_end]
                cot_token_est = len(cot_section) // 4
                print(f"{'  CoT section tokens (estimate):':<40} {f'~{cot_token_est}':<20}")

        # Interpret the measured throughput
        print("\n" + "=" * 100)
        print("🔧 DIAGNOSTIC GUIDANCE")
        print("=" * 100)

        if throughput < 50:
            print("\n⚠️ CRITICAL: Throughput is extremely low (<50 tokens/sec)")
            print("\nThis is at least ~4x slower than expected. Likely causes:")
            print("  1. ❗ Constrained decoding FSM overhead")
            print("     → Each token triggers FSM state machine validation")
            print("     → Try: set use_constrained_decoding=False in config")
            print("  2. ❗ CFG with double forward passes")
            print("     → cfg_scale > 1.0 means running the model twice per token")
            print("     → Check: params.lm_cfg_scale value")
            print("  3. ❗ Running in eager mode without compilation")
            print("     → PyTorch should compile kernels after warmup")
            print("     → Check: torch._dynamo.config settings")

        elif throughput < 100:
            print("\n⚠️ WARNING: Throughput is low (50-100 tokens/sec)")
            print("\nLikely causes:")
            print("  1. Constrained decoding overhead (~30-50% slowdown expected)")
            print("  2. CFG enabled (2x compute per token if cfg_scale > 1.0)")
            print("  3. Small model or inefficient GPU utilization")

        elif throughput < 150:
            print("\n⚠️ Throughput is below baseline but acceptable (100-150 tokens/sec)")
            print("\nMinor overhead from:")
            print("  - Constrained decoding: ~20-30% overhead")
            print("  - Profiling instrumentation: ~5-10% overhead")

        else:
            print(f"\n✓ Throughput is good ({throughput:.1f} tokens/sec)")
            print("  Performance is within acceptable range")
|
|
|
|
# Module-level profiling state, initialized in main()
timer = None
llm_debugger = None
|
|
|
|
def wrap_method_with_timing(obj, method_name: str, timing_key: str):
    """Wrap a method with timing instrumentation"""
    original_method = getattr(obj, method_name)

    @wraps(original_method)
    def timed_wrapper(*args, **kwargs):
        with timer.time(timing_key):
            return original_method(*args, **kwargs)

    setattr(obj, method_name, timed_wrapper)
    return original_method
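
# Illustrative use of wrap_method_with_timing (this is the same pattern that
# instrument_handlers()/restore_handlers() below apply; `dit_handler` is any
# object exposing the named method):
#
#     original = wrap_method_with_timing(dit_handler, "tiled_decode", "vae_decode")
#     ...  # run generation; every tiled_decode call is timed under "vae_decode"
#     dit_handler.tiled_decode = original  # restore when done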
|
|
|
|
def wrap_llm_with_debug_tracking(llm_handler):
    """Wrap LLM generation with detailed performance tracking"""
    original_method = llm_handler.generate_with_stop_condition

    @wraps(original_method)
    def debug_wrapper(*args, **kwargs):
        # Estimate prompt size from caption + lyrics (~4 characters per token)
        caption = kwargs.get('caption', args[0] if len(args) > 0 else "")
        lyrics = kwargs.get('lyrics', args[1] if len(args) > 1 else "")
        prompt_estimate = len(caption) + len(lyrics)
        prompt_tokens_estimate = prompt_estimate // 4

        llm_debugger.reset()
        llm_debugger.start(prompt_length=prompt_tokens_estimate)

        with timer.time('llm_inference'):
            result = original_method(*args, **kwargs)

        # Reassemble the generated text from the handler's return value
        output_text = ""
        if isinstance(result, tuple) and len(result) >= 2:
            if isinstance(result[1], list):
                # Second element is a list of generated chunks
                output_text = "".join(result[1])
            else:
                # First element may be a dict of CoT fields; concatenate its strings
                cot_output = ""
                if isinstance(result[0], dict):
                    for v in result[0].values():
                        if isinstance(v, str):
                            cot_output += v
                output_text = cot_output + str(result[1])

        # Count tokens: one per audio code, plus ~4 characters per token for the rest
        import re
        code_pattern = r'<\|audio_code_\d+\|>'
        codes = re.findall(code_pattern, output_text)
        remaining_text = re.sub(code_pattern, '', output_text)
        cot_tokens_estimate = len(remaining_text) // 4
        total_tokens = len(codes) + cot_tokens_estimate

        llm_debugger.set_token_count(total_tokens)
        llm_debugger.end(output_text)

        return result

    llm_handler.generate_with_stop_condition = debug_wrapper
    return original_method
|
|
|
|
def instrument_handlers(dit_handler, llm_handler, enable_llm_debug=False):
    """Add timing instrumentation to handler methods"""
    originals = {}

    # LLM: either deep debug tracking or plain timing
    if llm_handler and llm_handler.llm_initialized:
        if enable_llm_debug:
            originals['llm_generate'] = wrap_llm_with_debug_tracking(llm_handler)
        else:
            originals['llm_generate'] = wrap_method_with_timing(
                llm_handler, 'generate_with_stop_condition', 'llm_inference'
            )

    # DiT pipeline: batch preparation, diffusion, and VAE decode
    originals['dit_prepare'] = wrap_method_with_timing(
        dit_handler, 'prepare_batch_data', 'prepare_batch_data'
    )
    originals['dit_generate'] = wrap_method_with_timing(
        dit_handler, 'service_generate', 'dit_inference'
    )
    originals['dit_decode'] = wrap_method_with_timing(
        dit_handler, 'tiled_decode', 'vae_decode'
    )

    return originals
|
|
|
|
def restore_handlers(dit_handler, llm_handler, originals):
    """Restore original handler methods after profiling"""
    if llm_handler and 'llm_generate' in originals:
        llm_handler.generate_with_stop_condition = originals['llm_generate']

    dit_handler.prepare_batch_data = originals['dit_prepare']
    dit_handler.service_generate = originals['dit_generate']
    dit_handler.tiled_decode = originals['dit_decode']
|
|
|
|
def print_profiling_results(total_time: float, show_llm_debug: bool = False):
    """Print comprehensive profiling results with performance insights"""
    print("\n" + "=" * 100)
    print("🎯 PROFILING RESULTS")
    print("=" * 100)

    # Timing keys grouped into model vs non-model work
    model_sections = {
        'llm_inference': 'LLM Inference (5Hz Language Model)',
        'dit_inference': 'DiT Inference (Diffusion Transformer)',
        'vae_decode': 'VAE Decode (Audio Decoder)',
    }

    non_model_sections = {
        'prepare_batch_data': 'Prepare Batch Data (embedding, formatting)',
    }

    model_time = sum(timer.get_total(k) for k in model_sections.keys())
    non_model_time = sum(timer.get_total(k) for k in non_model_sections.keys())
    other_time = total_time - model_time - non_model_time

    print(f"\n{'CATEGORY':<50} {'TIME (s)':<12} {'%':<8} {'CALLS':<8}")
    print("-" * 100)

    # Model time breakdown
    print(f"\n{'🤖 MODEL TIME (Total)':<50} {model_time:<12.3f} {100*model_time/total_time:>6.1f}% {'':<8}")
    for key, desc in model_sections.items():
        t = timer.get_total(key)
        c = timer.get_count(key)
        if c > 0:
            mean = timer.get_mean(key)
            pct = 100 * t / total_time
            print(f"  {'├─ ' + desc:<48} {t:<12.3f} {pct:>6.1f}% {c:<8} (avg: {mean:.3f}s)")

    # Non-model time breakdown
    print(f"\n{'⚙️ NON-MODEL TIME (Total)':<50} {non_model_time:<12.3f} {100*non_model_time/total_time:>6.1f}% {'':<8}")
    for key, desc in non_model_sections.items():
        t = timer.get_total(key)
        c = timer.get_count(key)
        if c > 0:
            mean = timer.get_mean(key)
            pct = 100 * t / total_time
            print(f"  {'├─ ' + desc:<48} {t:<12.3f} {pct:>6.1f}% {c:<8} (avg: {mean:.3f}s)")

    # Everything not captured by the instrumented sections
    if other_time > 0.01:
        pct = 100 * other_time / total_time
        print(f"\n{'📦 OTHER TIME (I/O, overhead, audio save)':<50} {other_time:<12.3f} {pct:>6.1f}% {'':<8}")

    print(f"\n{'📊 TOTAL TIME':<50} {total_time:<12.3f} {'100.0%':>6} {'':<8}")

    if show_llm_debug:
        llm_debugger.print_analysis()

    print("\n" + "=" * 100)
    print("💡 PERFORMANCE INSIGHTS")
    print("=" * 100)

    llm_t = timer.get_total('llm_inference')
    dit_t = timer.get_total('dit_inference')
    vae_t = timer.get_total('vae_decode')
    prep_t = timer.get_total('prepare_batch_data')

    if model_time > 0:
        print(f"\n✓ Model operations: {model_time:.3f}s ({100*model_time/total_time:.1f}% of total)")

        if llm_t > 0:
            print(f"  - LLM: {llm_t:.3f}s ({100*llm_t/model_time:.1f}% of model time)")
        if dit_t > 0:
            print(f"  - DiT: {dit_t:.3f}s ({100*dit_t/model_time:.1f}% of model time)")
        if vae_t > 0:
            print(f"  - VAE: {vae_t:.3f}s ({100*vae_t/model_time:.1f}% of model time)")

    # Flag the LLM as the bottleneck when it dominates and is non-trivial
    if llm_t > dit_t and llm_t > 5.0:
        print(f"\n⚠️ LLM IS THE BOTTLENECK: {llm_t:.3f}s ({100*llm_t/total_time:.1f}% of total)")
        print("\n  Possible causes:")
        print("  1. Generating too many tokens → use --llm-debug to verify")
        print("  2. Constrained decoding overhead → FSM validation per token")
        print("  3. CFG overhead → cfg_scale > 1.0 = 2x forward passes")
        print("  4. First-token latency → warmup should help")
        print("  5. KV cache inefficiency → should be ~5-10ms/token")

    if non_model_time / total_time > 0.1:
        print(f"\n⚠️ Non-model operations: {non_model_time:.3f}s ({100*non_model_time/total_time:.1f}%)")
        if prep_t > 0.1:
            print(f"  - Batch preparation: {prep_t:.3f}s")

    if other_time / total_time > 0.2:
        print(f"\n⚠️ Overhead/I/O: {other_time:.3f}s ({100*other_time/total_time:.1f}%)")

    print("\n" + "=" * 100)
    print("🚀 OPTIMIZATION RECOMMENDATIONS")
    print("=" * 100)

    if llm_t > dit_t * 2:
        print("\n🎯 Priority: Optimize LLM")
        print("  1. Run: python profile_inference.py --llm-debug")
        print("     → Shows estimated token count and throughput")
        print("  2. Check constrained decoding overhead")
        print("  3. Check CFG scaling (lm_cfg_scale parameter)")
        print("  4. Profile nanovllm engine step() timing")
        print("  5. Compare vllm vs transformers backends")
|
|
|
|
def run_profiled_generation(dit_handler, llm_handler, params, config,
                            enable_cprofile=False, enable_llm_debug=False):
    """Execute generation with full profiling instrumentation"""
    originals = instrument_handlers(dit_handler, llm_handler, enable_llm_debug)

    try:
        print("\n[Profiling] Starting generation...")
        timer.sync()
        total_start = time.perf_counter()

        # Optional function-level profiling on top of the section timers
        prof = None
        if enable_cprofile:
            import cProfile
            prof = cProfile.Profile()
            prof.enable()

        result = generate_music(dit_handler, llm_handler, params, config, save_dir="./")

        timer.sync()
        total_time = time.perf_counter() - total_start

        if enable_cprofile and prof:
            prof.disable()

            import pstats
            import io

            # Full report goes to a file; the top 20 entries are echoed to stdout
            output_file = "profile_cprofile_detailed.txt"
            with open(output_file, 'w') as f:
                ps = pstats.Stats(prof, stream=f)
                ps.sort_stats('cumulative')
                ps.print_stats(100)

            print("\n" + "=" * 100)
            print("📊 TOP 20 FUNCTIONS BY CUMULATIVE TIME (cProfile)")
            print("=" * 100)
            s = io.StringIO()
            ps = pstats.Stats(prof, stream=s)
            ps.sort_stats('cumulative')
            ps.print_stats(20)
            print(s.getvalue())

            print(f"\nFull report: {output_file}")

        print_profiling_results(total_time, show_llm_debug=enable_llm_debug)

        return result, total_time

    finally:
        restore_handlers(dit_handler, llm_handler, originals)
|
|
|
|
def load_example_config(example_file: str) -> Tuple[GenerationParams, GenerationConfig]:
    """Load configuration from example JSON file"""
    try:
        with open(example_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        params = GenerationParams(
            caption=data.get('caption', ''),
            lyrics=data.get('lyrics', ''),
            bpm=data.get('bpm'),
            keyscale=data.get('keyscale', ''),
            timesignature=data.get('timesignature', ''),
            vocal_language=data.get('language', 'unknown'),
            duration=data.get('duration'),
            thinking=data.get('think', False),
            inference_steps=data.get('inference_steps', 8),
            seed=data.get('seed', 42),
        )

        config = GenerationConfig(batch_size=data.get('batch_size', 1), seeds=[42])

        return params, config

    except Exception as e:
        print(f"  ❌ Failed to load: {e}")
        return None, None
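
# Illustrative shape of an example JSON file (a sketch; only the keys read
# above are used, anything else is ignored, and all values here are made up):
#
#     {
#         "caption": "Upbeat synth-pop with bright vocals",
#         "lyrics": "[Verse 1] ...",
#         "bpm": 120,
#         "keyscale": "C major",
#         "timesignature": "4/4",
#         "language": "en",
#         "duration": 30,
#         "think": false,
#         "inference_steps": 8,
#         "seed": 42,
#         "batch_size": 1
#     }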
|
|
|
|
def main():
    global timer, llm_debugger

    # Defaults come from .env (if present) and can be overridden via CLI flags
    env_config = load_env_config()

    parser = argparse.ArgumentParser(
        description="Profile ACE-Step inference with LLM debugging"
    )
    parser.add_argument("--checkpoint-dir", type=str, default="./checkpoints")
    parser.add_argument("--config-path", type=str, default=env_config['ACESTEP_CONFIG_PATH'],
                        help=f"Model config path (default from .env: {env_config['ACESTEP_CONFIG_PATH']})")
    parser.add_argument("--device", type=str, default=env_config['ACESTEP_DEVICE'],
                        help=f"Device (default from .env: {env_config['ACESTEP_DEVICE']})")
    parser.add_argument("--lm-model", type=str, default=env_config['ACESTEP_LM_MODEL_PATH'],
                        help=f"LLM model path (default from .env: {env_config['ACESTEP_LM_MODEL_PATH']})")
    parser.add_argument("--lm-backend", type=str, default=env_config['ACESTEP_LM_BACKEND'],
                        help=f"LLM backend (default from .env: {env_config['ACESTEP_LM_BACKEND']})")
    parser.add_argument("--no-warmup", action="store_true")
    parser.add_argument("--detailed", action="store_true")
    parser.add_argument("--llm-debug", action="store_true",
                        help="Enable deep LLM debugging (token count, throughput)")
    parser.add_argument("--example", type=str, default="example_05.json")

    # Inference mode flags
    parser.add_argument("--thinking", action="store_true",
                        help="Enable CoT reasoning for LM to generate audio codes")
    parser.add_argument("--use-constrained-decoding", action="store_true",
                        help="Use FSM-based constrained decoding for meta generation")
    parser.add_argument("--use-cot-metas", action="store_true",
                        help="Enable LLM to generate music metadata via CoT reasoning")

    args = parser.parse_args()

    timer = PreciseTimer(device=args.device)
    llm_debugger = LLMDebugger()

    print("=" * 100)
    print("🎵 ACE-Step Inference Profiler (LLM Performance Analysis)")
    print("=" * 100)
    print("\nModel configuration (loaded from .env):")
    print(f"  DiT model: {args.config_path}")
    print(f"  LLM model: {args.lm_model}")
    print("\nRuntime configuration:")
    print(f"  Device: {args.device}")
    print(f"  LLM Backend: {args.lm_backend}")
    print(f"  LLM Debug: {'Enabled' if args.llm_debug else 'Disabled'}")
    print(f"  Warmup: {'Disabled' if args.no_warmup else 'Enabled'}")
    print("\nInference Mode:")
    print(f"  Thinking (CoT): {'Enabled' if args.thinking else 'Disabled'}")
    print(f"  Constrained Decoding: {'Enabled' if args.use_constrained_decoding else 'Disabled'}")
    print(f"  Use CoT for Metas: {'Enabled' if args.use_cot_metas else 'Disabled'}")

    # Initialize the DiT handler and, if a CoT mode is requested, the LLM handler
    print("\nInitializing models...")

    dit_handler = AceStepHandler()
    llm_handler = LLMHandler()

    print("  🎹 Initializing DiT...")
    status_dit, success_dit = dit_handler.initialize_service(
        project_root=project_root,
        config_path=args.config_path,
        device=args.device,
        use_flash_attention=True,
    )
    if not success_dit:
        print(f"  ❌ Failed: {status_dit}")
        sys.exit(1)
    print("  ✓ DiT ready")

    print("  🧠 Initializing LLM...")
    if args.thinking or args.use_cot_metas:
        status_llm, success_llm = llm_handler.initialize(
            checkpoint_dir=args.checkpoint_dir,
            lm_model_path=args.lm_model,
            backend=args.lm_backend,
            device=args.device,
        )
        if success_llm:
            print(f"  ✓ LLM ready ({args.lm_backend})")
        else:
            print(f"  ⚠ Failed: {status_llm}")
    else:
        print("  ✓ LLM not initialized (thinking and use_cot_metas are both disabled)")

    # Load the example prompt/config
    example_file = os.path.join(project_root, "examples", "text2music", args.example)
    if not os.path.exists(example_file):
        print(f"\n❌ Not found: {example_file}")
        sys.exit(1)

    print(f"\n📄 Loading: {args.example}")
    params, config = load_example_config(example_file)

    if not params or not config:
        print("❌ Failed to load config")
        sys.exit(1)

    print(f"  Caption: {params.caption[:60]}...")
    print(f"  Batch: {config.batch_size}, Steps: {params.inference_steps}, Thinking: {params.thinking}")

    # Optional warmup run so compilation and first-call overhead don't skew the profiled run
    if not args.no_warmup:
        print("\n" + "=" * 100)
        print("🔥 WARMUP RUN")
        print("=" * 100)

        warmup_params = GenerationParams(
            caption=params.caption,
            lyrics=params.lyrics,
            bpm=params.bpm,
            keyscale=params.keyscale,
            timesignature=params.timesignature,
            vocal_language=params.vocal_language,
            duration=params.duration,
            thinking=args.thinking,
            use_cot_metas=args.use_cot_metas,
            inference_steps=params.inference_steps,
            seed=params.seed,
        )
        warmup_config = GenerationConfig(batch_size=1, seeds=[42])
        warmup_config.use_constrained_decoding = args.use_constrained_decoding

        warmup_start = time.perf_counter()
        warmup_result = generate_music(dit_handler, llm_handler, warmup_params, warmup_config, save_dir="./")
        warmup_time = time.perf_counter() - warmup_start

        print(f"\n✓ Warmup: {warmup_time:.2f}s")
        if not warmup_result.success:
            print(f"⚠️ Warning: {warmup_result.error}")

        # Discard warmup timings so the profiled run starts clean
        timer = PreciseTimer(device=args.device)
        llm_debugger = LLMDebugger()

    # Profiled run
    print("\n" + "=" * 100)
    print("⏱️ PROFILING RUN")
    print("=" * 100)

    config.use_constrained_decoding = args.use_constrained_decoding

    if args.thinking:
        params.thinking = True
    if args.use_cot_metas:
        params.use_cot_metas = True

    result, total_time = run_profiled_generation(
        dit_handler, llm_handler, params, config,
        enable_cprofile=args.detailed,
        enable_llm_debug=args.llm_debug
    )

    if not result.success:
        print(f"\n❌ Failed: {result.error}")
        sys.exit(1)

    print(f"\n✅ Success! Generated {len(result.audios)} audio file(s)")

    # Pointers for the next step
    if args.detailed:
        print("\n💡 Check profile_cprofile_detailed.txt for function-level analysis")
    elif not args.llm_debug:
        print("\n💡 Run with --llm-debug to see LLM token count and throughput analysis")
|
|
|
|
if __name__ == "__main__":
    main()
|
|