| """ |
| Business Logic Handler |
| Encapsulates all data processing and business logic, acting as a bridge between the model and the UI. |
| """ |
| import os |
|
|
| |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" |
|
|
| import math |
| from copy import deepcopy |
| import tempfile |
| import traceback |
| import re |
| import random |
| import uuid |
| import hashlib |
| import json |
| from contextlib import contextmanager |
| from typing import Optional, Dict, Any, Tuple, List, Union |
|
|
| import torch |
| import torchaudio |
| import soundfile as sf |
| import time |
| from tqdm import tqdm |
| from loguru import logger |
| import warnings |
|
|
| from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM |
| from transformers.generation.streamers import BaseStreamer |
| from diffusers.models import AutoencoderOobleck |
| from acestep.constants import ( |
| TASK_INSTRUCTIONS, |
| SFT_GEN_PROMPT, |
| DEFAULT_DIT_INSTRUCTION, |
| ) |
| from acestep.dit_alignment_score import MusicStampsAligner, MusicLyricScorer |
|
|
|
|
| warnings.filterwarnings("ignore") |
|
|
|
|
| class AceStepHandler: |
| """ACE-Step Business Logic Handler""" |
| |
| def __init__(self): |
| self.model = None |
| self.config = None |
| self.device = "cpu" |
| self.dtype = torch.float32 |
|
|
| |
| self.vae = None |
| |
| |
| self.text_encoder = None |
| self.text_tokenizer = None |
| |
| |
| self.silence_latent = None |
| |
| |
| self.sample_rate = 48000 |
| |
| |
| self.reward_model = None |
| |
| |
| self.batch_size = 2 |
| |
| |
| self.custom_layers_config = {2: [6], 3: [10, 11], 4: [3], 5: [8, 9], 6: [8]} |
| self.offload_to_cpu = False |
| self.offload_dit_to_cpu = False |
| self.current_offload_cost = 0.0 |
| |
| |
| self.lora_loaded = False |
| self.use_lora = False |
| self._base_decoder = None |
| |
| def get_available_checkpoints(self) -> List[str]: |
| """Return a list containing the checkpoints directory path if it exists, otherwise an empty list.""" |
| |
| project_root = self._get_project_root() |
| |
| checkpoint_dir = os.path.join(project_root, "checkpoints") |
| if os.path.exists(checkpoint_dir): |
| return [checkpoint_dir] |
| else: |
| return [] |
| |
| def get_available_acestep_v15_models(self) -> List[str]: |
| """Scan and return all model directory names starting with 'acestep-v15-'""" |
| |
| project_root = self._get_project_root() |
| checkpoint_dir = os.path.join(project_root, "checkpoints") |
| |
| models = [] |
| if os.path.exists(checkpoint_dir): |
| |
| for item in os.listdir(checkpoint_dir): |
| item_path = os.path.join(checkpoint_dir, item) |
| if os.path.isdir(item_path) and item.startswith("acestep-v15-"): |
| models.append(item) |
| |
| |
| models.sort() |
| return models |
| |
| def is_flash_attention_available(self) -> bool: |
| """Check if flash attention is available on the system""" |
| try: |
| import flash_attn |
| return True |
| except ImportError: |
| return False |
| |
| def is_turbo_model(self) -> bool: |
| """Check if the currently loaded model is a turbo model""" |
| if self.config is None: |
| return False |
| return getattr(self.config, 'is_turbo', False) |
| |
| def load_lora(self, lora_path: str) -> str: |
| """Load LoRA adapter into the decoder. |
| |
| Args: |
| lora_path: Path to the LoRA adapter directory (containing adapter_config.json) |
| |
| Returns: |
| Status message |
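| |
| Example (illustrative; the adapter path is a placeholder and the service must already be initialized): |
| status = handler.load_lora("checkpoints/my_lora_adapter") |
| # On success, status is "✅ LoRA loaded from checkpoints/my_lora_adapter" |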
| """ |
| if self.model is None: |
| return "❌ Model not initialized. Please initialize service first." |
| |
| if not lora_path or not lora_path.strip(): |
| return "❌ Please provide a LoRA path." |
| |
| lora_path = lora_path.strip() |
| |
| |
| if not os.path.exists(lora_path): |
| return f"❌ LoRA path not found: {lora_path}" |
| |
| |
| config_file = os.path.join(lora_path, "adapter_config.json") |
| if not os.path.exists(config_file): |
| return f"❌ Invalid LoRA adapter: adapter_config.json not found in {lora_path}" |
| |
| try: |
| from peft import PeftModel, PeftConfig |
| except ImportError: |
| return "❌ PEFT library not installed. Please install with: pip install peft" |
| |
| try: |
| |
| if self._base_decoder is None: |
| self._base_decoder = deepcopy(self.model.decoder) |
| logger.info("Base decoder backed up") |
| else: |
| |
| # Restore a fresh copy of the base decoder before loading a new adapter |
| self.model.decoder = deepcopy(self._base_decoder) |
| logger.info("Restored base decoder before loading new LoRA") |
| |
| |
| logger.info(f"Loading LoRA adapter from {lora_path}") |
| self.model.decoder = PeftModel.from_pretrained( |
| self.model.decoder, |
| lora_path, |
| is_trainable=False, |
| ) |
| self.model.decoder = self.model.decoder.to(self.device).to(self.dtype) |
| self.model.decoder.eval() |
| |
| self.lora_loaded = True |
| self.use_lora = True |
| |
| logger.info(f"LoRA adapter loaded successfully from {lora_path}") |
| return f"✅ LoRA loaded from {lora_path}" |
| |
| except Exception as e: |
| logger.exception("Failed to load LoRA adapter") |
| return f"❌ Failed to load LoRA: {str(e)}" |
| |
| def unload_lora(self) -> str: |
| """Unload LoRA adapter and restore base decoder. |
| |
| Returns: |
| Status message |
| """ |
| if not self.lora_loaded: |
| return "⚠️ No LoRA adapter loaded." |
| |
| if self._base_decoder is None: |
| return "❌ Base decoder backup not found. Cannot restore." |
| |
| try: |
| self.model.decoder = deepcopy(self._base_decoder) |
| self.model.decoder = self.model.decoder.to(self.device).to(self.dtype) |
| self.model.decoder.eval() |
| |
| self.lora_loaded = False |
| self.use_lora = False |
| |
| logger.info("LoRA unloaded, base decoder restored") |
| return "✅ LoRA unloaded, using base model" |
| |
| except Exception as e: |
| logger.exception("Failed to unload LoRA") |
| return f"❌ Failed to unload LoRA: {str(e)}" |
| |
| def set_use_lora(self, use_lora: bool) -> str: |
| """Toggle LoRA usage for inference. |
| |
| Args: |
| use_lora: Whether to use LoRA adapter |
| |
| Returns: |
| Status message |
| """ |
| if use_lora and not self.lora_loaded: |
| return "❌ No LoRA adapter loaded. Please load a LoRA first." |
| |
| self.use_lora = use_lora |
| |
| |
| if self.lora_loaded and hasattr(self.model.decoder, 'disable_adapter_layers'): |
| try: |
| if use_lora: |
| self.model.decoder.enable_adapter_layers() |
| logger.info("LoRA adapter enabled") |
| else: |
| self.model.decoder.disable_adapter_layers() |
| logger.info("LoRA adapter disabled") |
| except Exception as e: |
| logger.warning(f"Could not toggle adapter layers: {e}") |
| |
| status = "enabled" if use_lora else "disabled" |
| return f"✅ LoRA {status}" |
| |
| def get_lora_status(self) -> Dict[str, Any]: |
| """Get current LoRA status. |
| |
| Returns: |
| Dictionary with LoRA status info |
| """ |
| return { |
| "loaded": self.lora_loaded, |
| "active": self.use_lora, |
| } |
| |
| def initialize_service( |
| self, |
| project_root: str, |
| config_path: str, |
| device: str = "auto", |
| use_flash_attention: bool = False, |
| compile_model: bool = False, |
| offload_to_cpu: bool = False, |
| offload_dit_to_cpu: bool = False, |
| quantization: Optional[str] = None, |
| ) -> Tuple[str, bool]: |
| """ |
| Initialize DiT model service |
| |
| Args: |
| project_root: Project root path (may be checkpoints directory, will be handled automatically) |
| config_path: Model config directory name (e.g., "acestep-v15-turbo") |
| device: Device type |
| use_flash_attention: Whether to use flash attention (requires flash_attn package) |
| compile_model: Whether to use torch.compile to optimize the model |
| offload_to_cpu: Whether to offload models to CPU when not in use |
| offload_dit_to_cpu: Whether to offload DiT model to CPU when not in use (only effective if offload_to_cpu is True) |
| quantization: Optional torchao quantization mode ("int8_weight_only", "fp8_weight_only", or "w8a8_dynamic"); requires compile_model=True |
| |
| Returns: |
| (status_message, enable_generate_button) |
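| |
| Example (illustrative; project_root is a placeholder and "acestep-v15-turbo" must exist under checkpoints/): |
| status_msg, ok = handler.initialize_service( |
| project_root=".", |
| config_path="acestep-v15-turbo", |
| device="auto", |
| ) |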
| """ |
| try: |
| if device == "auto": |
| if hasattr(torch, 'xpu') and torch.xpu.is_available(): |
| device = "xpu" |
| elif torch.cuda.is_available(): |
| device = "cuda" |
| elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): |
| device = "mps" |
| else: |
| device = "cpu" |
|
|
| status_msg = "" |
| |
| self.device = device |
| self.offload_to_cpu = offload_to_cpu |
| self.offload_dit_to_cpu = offload_dit_to_cpu |
| |
| self.dtype = torch.bfloat16 if device in ["cuda", "xpu"] else torch.float32 |
| self.quantization = quantization |
| if self.quantization is not None: |
| assert compile_model, "Quantization requires compile_model to be True" |
| try: |
| import torchao |
| except ImportError: |
| raise ImportError("torchao is required for quantization but is not installed. Please install torchao to use quantization features.") |
| |
|
|
| |
| actual_project_root = self._get_project_root() |
| checkpoint_dir = os.path.join(actual_project_root, "checkpoints") |
|
|
| |
| |
| acestep_v15_checkpoint_path = os.path.join(checkpoint_dir, config_path) |
| if os.path.exists(acestep_v15_checkpoint_path): |
| |
| if use_flash_attention and self.is_flash_attention_available(): |
| attn_implementation = "flash_attention_2" |
| self.dtype = torch.bfloat16 |
| else: |
| attn_implementation = "sdpa" |
|
|
| try: |
| logger.info(f"[initialize_service] Attempting to load model with attention implementation: {attn_implementation}") |
| self.model = AutoModel.from_pretrained( |
| acestep_v15_checkpoint_path, |
| trust_remote_code=True, |
| attn_implementation=attn_implementation, |
| dtype="bfloat16" |
| ) |
| except Exception as e: |
| logger.warning(f"[initialize_service] Failed to load model with {attn_implementation}: {e}") |
| if attn_implementation == "sdpa": |
| logger.info("[initialize_service] Falling back to eager attention") |
| attn_implementation = "eager" |
| self.model = AutoModel.from_pretrained( |
| acestep_v15_checkpoint_path, |
| trust_remote_code=True, |
| attn_implementation=attn_implementation |
| ) |
| else: |
| raise e |
|
|
| self.model.config._attn_implementation = attn_implementation |
| self.config = self.model.config |
| |
| if not self.offload_to_cpu: |
| self.model = self.model.to(device).to(self.dtype) |
| else: |
| |
| if not self.offload_dit_to_cpu: |
| logger.info(f"[initialize_service] Keeping main model on {device} (persistent)") |
| self.model = self.model.to(device).to(self.dtype) |
| else: |
| self.model = self.model.to("cpu").to(self.dtype) |
| self.model.eval() |
| |
| if compile_model: |
| self.model = torch.compile(self.model) |
| |
| if self.quantization is not None: |
| from torchao.quantization import quantize_ |
| if self.quantization == "int8_weight_only": |
| from torchao.quantization import Int8WeightOnlyConfig |
| quant_config = Int8WeightOnlyConfig() |
| elif self.quantization == "fp8_weight_only": |
| from torchao.quantization import Float8WeightOnlyConfig |
| quant_config = Float8WeightOnlyConfig() |
| elif self.quantization == "w8a8_dynamic": |
| from torchao.quantization import Int8DynamicActivationInt8WeightConfig, MappingType |
| quant_config = Int8DynamicActivationInt8WeightConfig(act_mapping_type=MappingType.ASYMMETRIC) |
| else: |
| raise ValueError(f"Unsupported quantization type: {self.quantization}") |
| |
| quantize_(self.model, quant_config) |
| logger.info(f"[initialize_service] DiT quantized with: {self.quantization}") |
| |
| |
| silence_latent_path = os.path.join(acestep_v15_checkpoint_path, "silence_latent.pt") |
| if os.path.exists(silence_latent_path): |
| self.silence_latent = torch.load(silence_latent_path, map_location="cpu").transpose(1, 2) |
| |
| |
| self.silence_latent = self.silence_latent.to(device).to(self.dtype) |
| else: |
| raise FileNotFoundError(f"Silence latent not found at {silence_latent_path}") |
| else: |
| raise FileNotFoundError(f"ACE-Step V1.5 checkpoint not found at {acestep_v15_checkpoint_path}") |
| |
| |
| vae_checkpoint_path = os.path.join(checkpoint_dir, "vae") |
| if os.path.exists(vae_checkpoint_path): |
| self.vae = AutoencoderOobleck.from_pretrained(vae_checkpoint_path) |
| |
| vae_dtype = self._get_vae_dtype(device) |
| if not self.offload_to_cpu: |
| self.vae = self.vae.to(device).to(vae_dtype) |
| else: |
| self.vae = self.vae.to("cpu").to(vae_dtype) |
| self.vae.eval() |
| else: |
| raise FileNotFoundError(f"VAE checkpoint not found at {vae_checkpoint_path}") |
|
|
| if compile_model: |
| self.vae = torch.compile(self.vae) |
| |
| |
| text_encoder_path = os.path.join(checkpoint_dir, "Qwen3-Embedding-0.6B") |
| if os.path.exists(text_encoder_path): |
| self.text_tokenizer = AutoTokenizer.from_pretrained(text_encoder_path) |
| self.text_encoder = AutoModel.from_pretrained(text_encoder_path) |
| if not self.offload_to_cpu: |
| self.text_encoder = self.text_encoder.to(device).to(self.dtype) |
| else: |
| self.text_encoder = self.text_encoder.to("cpu").to(self.dtype) |
| self.text_encoder.eval() |
| else: |
| raise FileNotFoundError(f"Text encoder not found at {text_encoder_path}") |
|
|
| |
| actual_attn = getattr(self.config, "_attn_implementation", "eager") |
| |
| status_msg = f"✅ Model initialized successfully on {device}\n" |
| status_msg += f"Main model: {acestep_v15_checkpoint_path}\n" |
| status_msg += f"VAE: {vae_checkpoint_path}\n" |
| status_msg += f"Text encoder: {text_encoder_path}\n" |
| status_msg += f"Dtype: {self.dtype}\n" |
| status_msg += f"Attention: {actual_attn}\n" |
| status_msg += f"Compiled: {compile_model}\n" |
| status_msg += f"Offload to CPU: {self.offload_to_cpu}\n" |
| status_msg += f"Offload DiT to CPU: {self.offload_dit_to_cpu}" |
| |
| return status_msg, True |
| |
| except Exception as e: |
| error_msg = f"❌ Error initializing model: {str(e)}\n\nTraceback:\n{traceback.format_exc()}" |
| logger.exception("[initialize_service] Error initializing model") |
| return error_msg, False |
| |
| def _is_on_target_device(self, tensor, target_device): |
| """Check if tensor is on the target device (handles cuda vs cuda:0 comparison).""" |
| if tensor is None: |
| return True |
| target_type = "cpu" if target_device == "cpu" else "cuda" |
| return tensor.device.type == target_type |
| |
| def _ensure_silence_latent_on_device(self): |
| """Ensure silence_latent is on the correct device (self.device).""" |
| if hasattr(self, "silence_latent") and self.silence_latent is not None: |
| if not self._is_on_target_device(self.silence_latent, self.device): |
| self.silence_latent = self.silence_latent.to(self.device).to(self.dtype) |
| |
| def _move_module_recursive(self, module, target_device, dtype=None, visited=None): |
| """ |
| Recursively move a module and all its submodules to the target device. |
| This handles modules that may not be properly registered. |
| """ |
| if visited is None: |
| visited = set() |
| |
| module_id = id(module) |
| if module_id in visited: |
| return |
| visited.add(module_id) |
| |
| |
| module.to(target_device) |
| if dtype is not None: |
| module.to(dtype) |
| |
| |
| for param_name, param in module._parameters.items(): |
| if param is not None and not self._is_on_target_device(param, target_device): |
| module._parameters[param_name] = param.to(target_device) |
| if dtype is not None: |
| module._parameters[param_name] = module._parameters[param_name].to(dtype) |
| |
| |
| for buf_name, buf in module._buffers.items(): |
| if buf is not None and not self._is_on_target_device(buf, target_device): |
| module._buffers[buf_name] = buf.to(target_device) |
| |
| |
| for name, child in module._modules.items(): |
| if child is not None: |
| self._move_module_recursive(child, target_device, dtype, visited) |
| |
| |
| for attr_name in dir(module): |
| if attr_name.startswith('_'): |
| continue |
| try: |
| attr = getattr(module, attr_name, None) |
| if isinstance(attr, torch.nn.Module) and id(attr) not in visited: |
| self._move_module_recursive(attr, target_device, dtype, visited) |
| except Exception: |
| pass |
| |
| def _recursive_to_device(self, model, device, dtype=None): |
| """ |
| Recursively move all parameters and buffers of a model to the specified device. |
| This is more thorough than model.to() for some custom HuggingFace models. |
| """ |
| target_device = torch.device(device) if isinstance(device, str) else device |
| |
| |
| model.to(target_device) |
| if dtype is not None: |
| model.to(dtype) |
| |
| |
| self._move_module_recursive(model, target_device, dtype) |
| |
| |
| wrong_device_params = [] |
| for name, param in model.named_parameters(): |
| if not self._is_on_target_device(param, device): |
| wrong_device_params.append(name) |
| |
| if wrong_device_params and device != "cpu": |
| logger.warning(f"[_recursive_to_device] {len(wrong_device_params)} parameters on wrong device, using state_dict method") |
| |
| state_dict = model.state_dict() |
| moved_state_dict = {} |
| for key, value in state_dict.items(): |
| if isinstance(value, torch.Tensor): |
| moved_state_dict[key] = value.to(target_device) |
| if dtype is not None and moved_state_dict[key].is_floating_point(): |
| moved_state_dict[key] = moved_state_dict[key].to(dtype) |
| else: |
| moved_state_dict[key] = value |
| model.load_state_dict(moved_state_dict) |
| |
| |
| if device != "cpu" and torch.cuda.is_available(): |
| torch.cuda.synchronize() |
| |
| |
| if device != "cpu": |
| still_wrong = [] |
| for name, param in model.named_parameters(): |
| if not self._is_on_target_device(param, device): |
| still_wrong.append(f"{name} on {param.device}") |
| if still_wrong: |
| logger.error(f"[_recursive_to_device] CRITICAL: {len(still_wrong)} parameters still on wrong device: {still_wrong[:10]}") |
| |
| @contextmanager |
| def _load_model_context(self, model_name: str): |
| """ |
| Context manager to load a model to GPU and offload it back to CPU after use. |
| |
| Args: |
| model_name: Name of the model to load ("text_encoder", "vae", "model") |
| """ |
| if not self.offload_to_cpu: |
| yield |
| return |
|
|
| |
| if model_name == "model" and not self.offload_dit_to_cpu: |
| |
| model = getattr(self, model_name, None) |
| if model is not None: |
| |
| |
| try: |
| param = next(model.parameters()) |
| if param.device.type == "cpu": |
| logger.info(f"[_load_model_context] Moving {model_name} to {self.device} (persistent)") |
| self._recursive_to_device(model, self.device, self.dtype) |
| if hasattr(self, "silence_latent"): |
| self.silence_latent = self.silence_latent.to(self.device).to(self.dtype) |
| except StopIteration: |
| pass |
| yield |
| return |
|
|
| model = getattr(self, model_name, None) |
| if model is None: |
| yield |
| return |
|
|
| |
| logger.info(f"[_load_model_context] Loading {model_name} to {self.device}") |
| start_time = time.time() |
| if model_name == "vae": |
| vae_dtype = self._get_vae_dtype() |
| self._recursive_to_device(model, self.device, vae_dtype) |
| else: |
| self._recursive_to_device(model, self.device, self.dtype) |
| |
| if model_name == "model" and hasattr(self, "silence_latent"): |
| self.silence_latent = self.silence_latent.to(self.device).to(self.dtype) |
| |
| load_time = time.time() - start_time |
| self.current_offload_cost += load_time |
| logger.info(f"[_load_model_context] Loaded {model_name} to {self.device} in {load_time:.4f}s") |
|
|
| try: |
| yield |
| finally: |
| |
| logger.info(f"[_load_model_context] Offloading {model_name} to CPU") |
| start_time = time.time() |
| self._recursive_to_device(model, "cpu") |
| |
| |
| |
| |
| |
| torch.cuda.empty_cache() |
| offload_time = time.time() - start_time |
| self.current_offload_cost += offload_time |
| logger.info(f"[_load_model_context] Offloaded {model_name} to CPU in {offload_time:.4f}s") |
|
|
| def process_target_audio(self, audio_file) -> Optional[torch.Tensor]: |
| """Process target audio""" |
| if audio_file is None: |
| return None |
| |
| try: |
| |
| audio_np, sr = sf.read(audio_file, dtype='float32') |
| |
| if audio_np.ndim == 1: |
| audio = torch.from_numpy(audio_np).unsqueeze(0) |
| else: |
| audio = torch.from_numpy(audio_np.T) |
| |
| |
| audio = self._normalize_audio_to_stereo_48k(audio, sr) |
| |
| return audio |
| except Exception as e: |
| logger.exception("[process_target_audio] Error processing target audio") |
| return None |
| |
| def _parse_audio_code_string(self, code_str: str) -> List[int]: |
| """Extract integer audio codes from prompt tokens like <|audio_code_123|>.""" |
| if not code_str: |
| return [] |
| try: |
| return [int(x) for x in re.findall(r"<\|audio_code_(\d+)\|>", code_str)] |
| except Exception as e: |
| logger.debug(f"[_parse_audio_code_string] Failed to parse audio code string: {e}") |
| return [] |
| |
| def _decode_audio_codes_to_latents(self, code_str: str) -> Optional[torch.Tensor]: |
| """ |
| Convert serialized audio code string into 25Hz latents using model quantizer/detokenizer. |
| """ |
| if not self.model or not hasattr(self.model, 'tokenizer') or not hasattr(self.model, 'detokenizer'): |
| return None |
| |
| code_ids = self._parse_audio_code_string(code_str) |
| if len(code_ids) == 0: |
| return None |
| |
| with self._load_model_context("model"): |
| quantizer = self.model.tokenizer.quantizer |
| detokenizer = self.model.detokenizer |
| |
| num_quantizers = getattr(quantizer, "num_quantizers", 1) |
| |
| indices = torch.tensor(code_ids, device=self.device, dtype=torch.long) |
| |
| indices = indices.unsqueeze(0).unsqueeze(-1) |
| |
| |
| |
| quantized = quantizer.get_output_from_indices(indices) |
| if quantized.dtype != self.dtype: |
| quantized = quantized.to(self.dtype) |
| |
| |
| lm_hints_25hz = detokenizer(quantized) |
| return lm_hints_25hz |
| |
| def _create_default_meta(self) -> str: |
| """Create default metadata string.""" |
| return ( |
| "- bpm: N/A\n" |
| "- timesignature: N/A\n" |
| "- keyscale: N/A\n" |
| "- duration: 30 seconds\n" |
| ) |
| |
| def _dict_to_meta_string(self, meta_dict: Dict[str, Any]) -> str: |
| """Convert metadata dict to formatted string.""" |
| bpm = meta_dict.get('bpm', meta_dict.get('tempo', 'N/A')) |
| timesignature = meta_dict.get('timesignature', meta_dict.get('time_signature', 'N/A')) |
| keyscale = meta_dict.get('keyscale', meta_dict.get('key', meta_dict.get('scale', 'N/A'))) |
| duration = meta_dict.get('duration', meta_dict.get('length', 30)) |
|
|
| |
| if isinstance(duration, (int, float)): |
| duration = f"{int(duration)} seconds" |
| elif not isinstance(duration, str): |
| duration = "30 seconds" |
| |
| return ( |
| f"- bpm: {bpm}\n" |
| f"- timesignature: {timesignature}\n" |
| f"- keyscale: {keyscale}\n" |
| f"- duration: {duration}\n" |
| ) |
| |
| def _parse_metas(self, metas: List[Union[str, Dict[str, Any]]]) -> List[str]: |
| """ |
| Parse and normalize metadata with fallbacks. |
| |
| Args: |
| metas: List of metadata (can be strings, dicts, or None) |
| |
| Returns: |
| List of formatted metadata strings |
| """ |
| parsed_metas = [] |
| for meta in metas: |
| if meta is None: |
| |
| parsed_meta = self._create_default_meta() |
| elif isinstance(meta, str): |
| |
| parsed_meta = meta |
| elif isinstance(meta, dict): |
| |
| parsed_meta = self._dict_to_meta_string(meta) |
| else: |
| |
| parsed_meta = self._create_default_meta() |
| |
| parsed_metas.append(parsed_meta) |
| |
| return parsed_metas |
| |
| def build_dit_inputs( |
| self, |
| task: str, |
| instruction: Optional[str], |
| caption: str, |
| lyrics: str, |
| metas: Optional[Union[str, Dict[str, Any]]] = None, |
| vocal_language: str = "en", |
| ) -> Tuple[str, str]: |
| """ |
| Build text inputs for the caption and lyric branches used by DiT. |
| |
| Args: |
| task: Task name (e.g., text2music, cover, repaint); kept for logging/future branching. |
| instruction: Instruction text; default fallback matches service_generate behavior. |
| caption: Caption string (fallback if not in metas). |
| lyrics: Lyrics string. |
| metas: Metadata (str or dict); follows _parse_metas formatting. |
| May contain 'caption' and 'language' fields from LM CoT output. |
| vocal_language: Language code for lyrics section (fallback if not in metas). |
| |
| Returns: |
| (caption_input_text, lyrics_input_text) |
| |
| Example: |
| caption_input, lyrics_input = handler.build_dit_inputs( |
| task="text2music", |
| instruction=None, |
| caption="A calm piano melody", |
| lyrics="la la la", |
| metas={"bpm": 90, "duration": 45, "caption": "LM generated caption", "language": "en"}, |
| vocal_language="en", |
| ) |
| """ |
| |
| final_instruction = self._format_instruction(instruction or DEFAULT_DIT_INSTRUCTION) |
|
|
| |
| |
| actual_caption = caption |
| actual_language = vocal_language |
| |
| if metas is not None: |
| |
| if isinstance(metas, str): |
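| # Note: _parse_metas returns formatted strings, so parsed_metas[0] is never a dict here; |
| # caption/language overrides are effectively honored only when metas is passed as a dict. |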
| |
| parsed_metas = self._parse_metas([metas]) |
| if parsed_metas and isinstance(parsed_metas[0], dict): |
| meta_dict = parsed_metas[0] |
| else: |
| meta_dict = {} |
| elif isinstance(metas, dict): |
| meta_dict = metas |
| else: |
| meta_dict = {} |
| |
| |
| if 'caption' in meta_dict and meta_dict['caption']: |
| actual_caption = str(meta_dict['caption']) |
| |
| |
| if 'language' in meta_dict and meta_dict['language']: |
| actual_language = str(meta_dict['language']) |
|
|
| parsed_meta = self._parse_metas([metas])[0] |
| caption_input = SFT_GEN_PROMPT.format(final_instruction, actual_caption, parsed_meta) |
| lyrics_input = self._format_lyrics(lyrics, actual_language) |
| return caption_input, lyrics_input |
| |
| def _get_text_hidden_states(self, text_prompt: str) -> Tuple[torch.Tensor, torch.Tensor]: |
| """Get text hidden states from text encoder.""" |
| if self.text_tokenizer is None or self.text_encoder is None: |
| raise ValueError("Text encoder not initialized") |
| |
| with self._load_model_context("text_encoder"): |
| |
| text_inputs = self.text_tokenizer( |
| text_prompt, |
| padding="longest", |
| truncation=True, |
| max_length=256, |
| return_tensors="pt", |
| ) |
| text_input_ids = text_inputs.input_ids.to(self.device) |
| text_attention_mask = text_inputs.attention_mask.to(self.device).bool() |
| |
| |
| with torch.no_grad(): |
| text_outputs = self.text_encoder(text_input_ids) |
| if hasattr(text_outputs, 'last_hidden_state'): |
| text_hidden_states = text_outputs.last_hidden_state |
| elif isinstance(text_outputs, tuple): |
| text_hidden_states = text_outputs[0] |
| else: |
| text_hidden_states = text_outputs |
| |
| text_hidden_states = text_hidden_states.to(self.dtype) |
| |
| return text_hidden_states, text_attention_mask |
| |
| def extract_caption_from_sft_format(self, caption: str) -> str: |
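| """Extract the plain caption text from an SFT-formatted prompt ("# Instruction" / "# Caption" / "# Metas" sections); returns the input unchanged if it is not in that format.""" |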
| try: |
| if "# Instruction" in caption and "# Caption" in caption: |
| pattern = r'#\s*Caption\s*\n(.*?)(?:\n\s*#\s*Metas|$)' |
| match = re.search(pattern, caption, re.DOTALL) |
| if match: |
| return match.group(1).strip() |
| return caption |
| except Exception as e: |
| logger.exception("[extract_caption_from_sft_format] Error extracting caption") |
| return caption |
|
|
| def prepare_seeds(self, actual_batch_size, seed, use_random_seed): |
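| """Resolve per-sample seeds for a batch. |
| |
| Explicit seeds are honored, "-1"/empty entries get fresh random seeds, and a single |
| explicit seed is applied only to the first item when the batch size is larger than 1. |
| |
| Returns: |
| (seed_list, comma_separated_seed_string_for_ui) |
| |
| Example (illustrative): |
| seeds, ui_text = handler.prepare_seeds(2, "123, -1", use_random_seed=False) |
| # seeds[0] == 123, seeds[1] is a freshly drawn random seed |
| """ |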
| actual_seed_list: List[int] = [] |
| seed_value_for_ui = "" |
|
|
| if use_random_seed: |
| |
| actual_seed_list = [random.randint(0, 2 ** 32 - 1) for _ in range(actual_batch_size)] |
| seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list) |
| else: |
| |
| |
| seed_list = [] |
| if isinstance(seed, str): |
| |
| seed_str_list = [s.strip() for s in seed.split(",")] |
| for s in seed_str_list: |
| if s == "-1" or s == "": |
| seed_list.append(-1) |
| else: |
| try: |
| seed_list.append(int(float(s))) |
| except (ValueError, TypeError) as e: |
| logger.debug(f"[prepare_seeds] Failed to parse seed value '{s}': {e}") |
| seed_list.append(-1) |
| elif seed is None or (isinstance(seed, (int, float)) and seed < 0): |
| |
| seed_list = [-1] * actual_batch_size |
| elif isinstance(seed, (int, float)): |
| |
| seed_list = [int(seed)] |
| else: |
| |
| seed_list = [-1] * actual_batch_size |
|
|
| |
| |
| |
| |
| |
| has_single_non_negative_seed = (len(seed_list) == 1 and seed_list[0] != -1) |
|
|
| for i in range(actual_batch_size): |
| if i < len(seed_list): |
| seed_val = seed_list[i] |
| else: |
| |
| seed_val = -1 |
|
|
| |
| |
| if has_single_non_negative_seed and actual_batch_size > 1 and i > 0: |
| |
| actual_seed_list.append(random.randint(0, 2 ** 32 - 1)) |
| elif seed_val == -1: |
| |
| actual_seed_list.append(random.randint(0, 2 ** 32 - 1)) |
| else: |
| actual_seed_list.append(int(seed_val)) |
|
|
| seed_value_for_ui = ", ".join(str(s) for s in actual_seed_list) |
| return actual_seed_list, seed_value_for_ui |
| |
| def prepare_metadata(self, bpm, key_scale, time_signature): |
| """Build metadata dict - use "N/A" as default for empty fields.""" |
| return self._build_metadata_dict(bpm, key_scale, time_signature) |
| |
| def is_silence(self, audio): |
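| """Check whether the audio tensor is (near-)silent, i.e. all sample magnitudes are below 1e-6.""" |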
| return torch.all(audio.abs() < 1e-6) |
| |
| def _get_project_root(self) -> str: |
| """Get project root directory path.""" |
| current_file = os.path.abspath(__file__) |
| return os.path.dirname(os.path.dirname(current_file)) |
| |
| def _get_vae_dtype(self, device: Optional[str] = None) -> torch.dtype: |
| """Get VAE dtype based on device.""" |
| device = device or self.device |
| return torch.bfloat16 if device in ["cuda", "xpu"] else self.dtype |
| |
| def _format_instruction(self, instruction: str) -> str: |
| """Format instruction to ensure it ends with colon.""" |
| if not instruction.endswith(":"): |
| instruction = instruction + ":" |
| return instruction |
| |
| def _normalize_audio_to_stereo_48k(self, audio: torch.Tensor, sr: int) -> torch.Tensor: |
| """ |
| Normalize audio to stereo 48kHz format. |
| |
| Args: |
| audio: Audio tensor [channels, samples] or [samples] |
| sr: Sample rate |
| |
| Returns: |
| Normalized audio tensor [2, samples] at 48kHz |
| """ |
| |
| if audio.shape[0] == 1: |
| audio = torch.cat([audio, audio], dim=0) |
| |
| |
| audio = audio[:2] |
| |
| |
| if sr != 48000: |
| audio = torchaudio.transforms.Resample(sr, 48000)(audio) |
| |
| |
| audio = torch.clamp(audio, -1.0, 1.0) |
| |
| return audio |
| |
| def _normalize_audio_code_hints(self, audio_code_hints: Optional[Union[str, List[str]]], batch_size: int) -> List[Optional[str]]: |
| """Normalize audio_code_hints to list of correct length.""" |
| if audio_code_hints is None: |
| normalized = [None] * batch_size |
| elif isinstance(audio_code_hints, str): |
| normalized = [audio_code_hints] * batch_size |
| elif len(audio_code_hints) == 1 and batch_size > 1: |
| normalized = audio_code_hints * batch_size |
| elif len(audio_code_hints) != batch_size: |
| |
| normalized = list(audio_code_hints[:batch_size]) |
| while len(normalized) < batch_size: |
| normalized.append(None) |
| else: |
| normalized = list(audio_code_hints) |
| |
| |
| normalized = [hint if isinstance(hint, str) and hint.strip() else None for hint in normalized] |
| return normalized |
| |
| def _normalize_instructions(self, instructions: Optional[Union[str, List[str]]], batch_size: int, default: Optional[str] = None) -> List[str]: |
| """Normalize instructions to list of correct length.""" |
| if instructions is None: |
| default_instruction = default or DEFAULT_DIT_INSTRUCTION |
| return [default_instruction] * batch_size |
| elif isinstance(instructions, str): |
| return [instructions] * batch_size |
| elif len(instructions) == 1: |
| return instructions * batch_size |
| elif len(instructions) != batch_size: |
| |
| normalized = list(instructions[:batch_size]) |
| default_instruction = default or DEFAULT_DIT_INSTRUCTION |
| while len(normalized) < batch_size: |
| normalized.append(default_instruction) |
| return normalized |
| else: |
| return list(instructions) |
| |
| def _format_lyrics(self, lyrics: str, language: str) -> str: |
| """Format lyrics text with language header.""" |
| return f"# Languages\n{language}\n\n# Lyric\n{lyrics}<|endoftext|>" |
| |
| def _pad_sequences(self, sequences: List[torch.Tensor], max_length: int, pad_value: int = 0) -> torch.Tensor: |
| """Pad sequences to same length.""" |
| return torch.stack([ |
| torch.nn.functional.pad(seq, (0, max_length - len(seq)), 'constant', pad_value) |
| for seq in sequences |
| ]) |
| |
| def _extract_caption_and_language(self, metas: List[Union[str, Dict[str, Any]]], captions: List[str], vocal_languages: List[str]) -> Tuple[List[str], List[str]]: |
| """Extract caption and language from metas with fallback to provided values.""" |
| actual_captions = list(captions) |
| actual_languages = list(vocal_languages) |
| |
| for i, meta in enumerate(metas): |
| if i >= len(actual_captions): |
| break |
| |
| meta_dict = None |
| if isinstance(meta, str): |
| parsed = self._parse_metas([meta]) |
| if parsed and isinstance(parsed[0], dict): |
| meta_dict = parsed[0] |
| elif isinstance(meta, dict): |
| meta_dict = meta |
| |
| if meta_dict: |
| if 'caption' in meta_dict and meta_dict['caption']: |
| actual_captions[i] = str(meta_dict['caption']) |
| if 'language' in meta_dict and meta_dict['language']: |
| actual_languages[i] = str(meta_dict['language']) |
| |
| return actual_captions, actual_languages |
| |
| def _encode_audio_to_latents(self, audio: torch.Tensor) -> torch.Tensor: |
| """ |
| Encode audio to latents using VAE. |
| |
| Args: |
| audio: Audio tensor [channels, samples] or [batch, channels, samples] |
| |
| Returns: |
| Latents tensor [T, D] or [batch, T, D] |
| """ |
| |
| input_was_2d = (audio.dim() == 2) |
| |
| |
| if input_was_2d: |
| audio = audio.unsqueeze(0) |
| |
| |
| vae_input = audio.to(self.device).to(self.vae.dtype) |
| |
| |
| with torch.no_grad(): |
| latents = self.vae.encode(vae_input).latent_dist.sample() |
| |
| |
| latents = latents.to(self.dtype) |
| |
| |
| latents = latents.transpose(1, 2) |
| |
| |
| if input_was_2d: |
| latents = latents.squeeze(0) |
| |
| return latents |
| |
| def _build_metadata_dict(self, bpm: Optional[Union[int, str]], key_scale: str, time_signature: str, duration: Optional[float] = None) -> Dict[str, Any]: |
| """ |
| Build metadata dictionary with default values. |
| |
| Args: |
| bpm: BPM value (optional) |
| key_scale: Key/scale string |
| time_signature: Time signature string |
| duration: Duration in seconds (optional) |
| |
| Returns: |
| Metadata dictionary |
| """ |
| metadata_dict = {} |
| if bpm: |
| metadata_dict["bpm"] = bpm |
| else: |
| metadata_dict["bpm"] = "N/A" |
|
|
| if key_scale.strip(): |
| metadata_dict["keyscale"] = key_scale |
| else: |
| metadata_dict["keyscale"] = "N/A" |
|
|
| if time_signature.strip() and time_signature != "N/A": |
| metadata_dict["timesignature"] = time_signature |
| else: |
| metadata_dict["timesignature"] = "N/A" |
| |
| |
| if duration is not None: |
| metadata_dict["duration"] = f"{int(duration)} seconds" |
| |
| return metadata_dict |
|
|
| def generate_instruction( |
| self, |
| task_type: str, |
| track_name: Optional[str] = None, |
| complete_track_classes: Optional[List[str]] = None |
| ) -> str: |
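| """Return the instruction template for the given task type; track names and track classes are upper-cased before substitution into the template.""" |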
| if task_type == "text2music": |
| return TASK_INSTRUCTIONS["text2music"] |
| elif task_type == "repaint": |
| return TASK_INSTRUCTIONS["repaint"] |
| elif task_type == "cover": |
| return TASK_INSTRUCTIONS["cover"] |
| elif task_type == "extract": |
| if track_name: |
| |
| track_name_upper = track_name.upper() |
| return TASK_INSTRUCTIONS["extract"].format(TRACK_NAME=track_name_upper) |
| else: |
| return TASK_INSTRUCTIONS["extract_default"] |
| elif task_type == "lego": |
| if track_name: |
| |
| track_name_upper = track_name.upper() |
| return TASK_INSTRUCTIONS["lego"].format(TRACK_NAME=track_name_upper) |
| else: |
| return TASK_INSTRUCTIONS["lego_default"] |
| elif task_type == "complete": |
| if complete_track_classes and len(complete_track_classes) > 0: |
| |
| track_classes_upper = [t.upper() for t in complete_track_classes] |
| complete_track_classes_str = " | ".join(track_classes_upper) |
| return TASK_INSTRUCTIONS["complete"].format(TRACK_CLASSES=complete_track_classes_str) |
| else: |
| return TASK_INSTRUCTIONS["complete_default"] |
| else: |
| return TASK_INSTRUCTIONS["text2music"] |
| |
| def process_reference_audio(self, audio_file) -> Optional[torch.Tensor]: |
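| """Load reference audio, normalize it to stereo 48 kHz, and build a 30-second reference clip by sampling 10-second segments from the front, middle, and back thirds; returns None for silent or unreadable input.""" |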
| if audio_file is None: |
| return None |
| |
| try: |
| |
| audio, sr = torchaudio.load(audio_file) |
| |
| logger.debug(f"[process_reference_audio] Reference audio shape: {audio.shape}") |
| logger.debug(f"[process_reference_audio] Reference audio sample rate: {sr}") |
| logger.debug(f"[process_reference_audio] Reference audio duration: {audio.shape[-1] / 48000.0} seconds") |
| |
| |
| audio = self._normalize_audio_to_stereo_48k(audio, sr) |
| |
| is_silence = self.is_silence(audio) |
| if is_silence: |
| return None |
| |
| |
| target_frames = 30 * 48000 |
| segment_frames = 10 * 48000 |
| |
| |
| if audio.shape[-1] < target_frames: |
| repeat_times = math.ceil(target_frames / audio.shape[-1]) |
| audio = audio.repeat(1, repeat_times) |
| |
| |
| |
| |
| total_frames = audio.shape[-1] |
| segment_size = total_frames // 3 |
| |
| |
| front_start = random.randint(0, max(0, segment_size - segment_frames)) |
| front_audio = audio[:, front_start:front_start + segment_frames] |
| |
| |
| middle_start = segment_size + random.randint(0, max(0, segment_size - segment_frames)) |
| middle_audio = audio[:, middle_start:middle_start + segment_frames] |
| |
| |
| back_start = 2 * segment_size + random.randint(0, max(0, (total_frames - 2 * segment_size) - segment_frames)) |
| back_audio = audio[:, back_start:back_start + segment_frames] |
| |
| |
| audio = torch.cat([front_audio, middle_audio, back_audio], dim=-1) |
| |
| return audio |
| |
| except Exception as e: |
| logger.exception("[process_reference_audio] Error processing reference audio") |
| return None |
|
|
| def process_src_audio(self, audio_file) -> Optional[torch.Tensor]: |
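| """Load source audio and normalize it to stereo 48 kHz; returns None if no file is given or loading fails.""" |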
| if audio_file is None: |
| return None |
| |
| try: |
| |
| audio, sr = torchaudio.load(audio_file) |
| |
| |
| audio = self._normalize_audio_to_stereo_48k(audio, sr) |
| |
| return audio |
| |
| except Exception as e: |
| logger.exception("[process_src_audio] Error processing source audio") |
| return None |
| |
| def convert_src_audio_to_codes(self, audio_file) -> str: |
| """ |
| Convert uploaded source audio to audio codes string. |
| |
| Args: |
| audio_file: Path to audio file or None |
| |
| Returns: |
| Formatted codes string like '<|audio_code_123|><|audio_code_456|>...' or error message |
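| |
| Example (illustrative; the file path is a placeholder and the service must be initialized): |
| codes = handler.convert_src_audio_to_codes("demo_song.wav") |
| # The resulting string can be supplied as an audio code hint for cover-style generation |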
| """ |
| if audio_file is None: |
| return "❌ Please upload source audio first" |
| |
| if self.model is None or self.vae is None: |
| return "❌ Model not initialized. Please initialize the service first." |
| |
| try: |
| |
| processed_audio = self.process_src_audio(audio_file) |
| if processed_audio is None: |
| return "❌ Failed to process audio file" |
| |
| |
| with torch.no_grad(): |
| with self._load_model_context("vae"): |
| |
| if self.is_silence(processed_audio.unsqueeze(0)): |
| return "❌ Audio file appears to be silent" |
| |
| |
| latents = self._encode_audio_to_latents(processed_audio) |
| |
| |
| attention_mask = torch.ones(latents.shape[0], dtype=torch.bool, device=self.device) |
| |
| |
| with self._load_model_context("model"): |
| |
| hidden_states = latents.unsqueeze(0) |
| |
| |
| |
| _, indices, _ = self.model.tokenize(hidden_states, self.silence_latent, attention_mask.unsqueeze(0)) |
| |
| |
| |
| |
| indices_flat = indices.flatten().cpu().tolist() |
| codes_string = "".join([f"<|audio_code_{idx}|>" for idx in indices_flat]) |
| |
| logger.info(f"[convert_src_audio_to_codes] Generated {len(indices_flat)} audio codes") |
| return codes_string |
| |
| except Exception as e: |
| error_msg = f"❌ Error converting audio to codes: {str(e)}\n{traceback.format_exc()}" |
| logger.exception("[convert_src_audio_to_codes] Error converting audio to codes") |
| return error_msg |
| |
| def prepare_batch_data( |
| self, |
| actual_batch_size, |
| processed_src_audio, |
| audio_duration, |
| captions, |
| lyrics, |
| vocal_language, |
| instruction, |
| bpm, |
| key_scale, |
| time_signature |
| ): |
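| """Replicate the caption (extracted from SFT format), instruction, lyrics, and vocal language across the batch, and build per-item metadata dicts with duration taken from the source audio or audio_duration when available.""" |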
| pure_caption = self.extract_caption_from_sft_format(captions) |
| captions_batch = [pure_caption] * actual_batch_size |
| instructions_batch = [instruction] * actual_batch_size |
| lyrics_batch = [lyrics] * actual_batch_size |
| vocal_languages_batch = [vocal_language] * actual_batch_size |
| |
| calculated_duration = None |
| if processed_src_audio is not None: |
| calculated_duration = processed_src_audio.shape[-1] / 48000.0 |
| elif audio_duration is not None and float(audio_duration) > 0: |
| calculated_duration = float(audio_duration) |
|
|
| |
| metadata_dict = self._build_metadata_dict(bpm, key_scale, time_signature, calculated_duration) |
|
|
| |
| |
| metas_batch = [metadata_dict.copy() for _ in range(actual_batch_size)] |
| return captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch |
| |
| def determine_task_type(self, task_type, audio_code_string): |
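| """Derive task flags from the task type; returns (is_repaint_task, is_lego_task, is_cover_task, can_use_repainting). Any non-empty audio code string forces the cover path.""" |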
| |
| |
| is_repaint_task = (task_type == "repaint") |
| is_lego_task = (task_type == "lego") |
| is_cover_task = (task_type == "cover") |
|
|
| has_codes = False |
| if isinstance(audio_code_string, list): |
| has_codes = any((c or "").strip() for c in audio_code_string) |
| else: |
| has_codes = bool(audio_code_string and str(audio_code_string).strip()) |
|
|
| if has_codes: |
| is_cover_task = True |
| |
| can_use_repainting = is_repaint_task or is_lego_task |
| return is_repaint_task, is_lego_task, is_cover_task, can_use_repainting |
|
|
| def create_target_wavs(self, duration_seconds: float) -> torch.Tensor: |
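| """Create a silent stereo waveform at 48 kHz for the requested duration (clamped to at least 0.1 s; falls back to 30 s on error).""" |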
| try: |
| |
| duration_seconds = max(0.1, round(duration_seconds, 1)) |
| |
| frames = int(duration_seconds * 48000) |
| |
| target_wavs = torch.zeros(2, frames) |
| return target_wavs |
| except Exception as e: |
| logger.exception("[create_target_wavs] Error creating target audio") |
| |
| return torch.zeros(2, 30 * 48000) |
| |
| def prepare_padding_info( |
| self, |
| actual_batch_size, |
| processed_src_audio, |
| audio_duration, |
| repainting_start, |
| repainting_end, |
| is_repaint_task, |
| is_lego_task, |
| is_cover_task, |
| can_use_repainting, |
| ): |
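| """Build per-item target waveforms (padded to a common length) and padding info, and derive batch repainting start/end times shifted by any left padding; returns (repainting_start_batch, repainting_end_batch, target_wavs_tensor).""" |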
| target_wavs_batch = [] |
| |
| padding_info_batch = [] |
| for i in range(actual_batch_size): |
| if processed_src_audio is not None: |
| if is_cover_task: |
| |
| batch_target_wavs = processed_src_audio |
| padding_info_batch.append({ |
| 'left_padding_duration': 0.0, |
| 'right_padding_duration': 0.0 |
| }) |
| elif is_repaint_task or is_lego_task: |
| |
| src_audio_duration = processed_src_audio.shape[-1] / 48000.0 |
|
|
| |
| if repainting_end is None or repainting_end < 0: |
| actual_end = src_audio_duration |
| else: |
| actual_end = repainting_end |
|
|
| left_padding_duration = max(0, -repainting_start) if repainting_start is not None else 0 |
| right_padding_duration = max(0, actual_end - src_audio_duration) |
|
|
| |
| left_padding_frames = int(left_padding_duration * 48000) |
| right_padding_frames = int(right_padding_duration * 48000) |
|
|
| if left_padding_frames > 0 or right_padding_frames > 0: |
| |
| batch_target_wavs = torch.nn.functional.pad( |
| processed_src_audio, |
| (left_padding_frames, right_padding_frames), |
| 'constant', 0 |
| ) |
| else: |
| batch_target_wavs = processed_src_audio |
|
|
| |
| padding_info_batch.append({ |
| 'left_padding_duration': left_padding_duration, |
| 'right_padding_duration': right_padding_duration |
| }) |
| else: |
| |
| batch_target_wavs = processed_src_audio |
| padding_info_batch.append({ |
| 'left_padding_duration': 0.0, |
| 'right_padding_duration': 0.0 |
| }) |
| else: |
| padding_info_batch.append({ |
| 'left_padding_duration': 0.0, |
| 'right_padding_duration': 0.0 |
| }) |
| if audio_duration is not None and float(audio_duration) > 0: |
| batch_target_wavs = self.create_target_wavs(float(audio_duration)) |
| else: |
| # No duration provided: fall back to a random target length |
| random_duration = random.uniform(10.0, 120.0) |
| batch_target_wavs = self.create_target_wavs(random_duration) |
| target_wavs_batch.append(batch_target_wavs) |
|
|
| |
| |
| max_frames = max(wav.shape[-1] for wav in target_wavs_batch) |
| padded_target_wavs = [] |
| for wav in target_wavs_batch: |
| if wav.shape[-1] < max_frames: |
| pad_frames = max_frames - wav.shape[-1] |
| padded_wav = torch.nn.functional.pad(wav, (0, pad_frames), 'constant', 0) |
| padded_target_wavs.append(padded_wav) |
| else: |
| padded_target_wavs.append(wav) |
|
|
| target_wavs_tensor = torch.stack(padded_target_wavs, dim=0) |
|
|
| if can_use_repainting: |
| |
| if repainting_start is None: |
| repainting_start_batch = None |
| elif isinstance(repainting_start, (int, float)): |
| if processed_src_audio is not None: |
| adjusted_start = repainting_start + padding_info_batch[0]['left_padding_duration'] |
| repainting_start_batch = [adjusted_start] * actual_batch_size |
| else: |
| repainting_start_batch = [repainting_start] * actual_batch_size |
| else: |
| |
| repainting_start_batch = [] |
| for i in range(actual_batch_size): |
| if processed_src_audio is not None: |
| adjusted_start = repainting_start[i] + padding_info_batch[i]['left_padding_duration'] |
| repainting_start_batch.append(adjusted_start) |
| else: |
| repainting_start_batch.append(repainting_start[i]) |
|
|
| |
| if processed_src_audio is not None: |
| |
| src_audio_duration = processed_src_audio.shape[-1] / 48000.0 |
| if repainting_end is None or repainting_end < 0: |
| |
| adjusted_end = src_audio_duration + padding_info_batch[0]['left_padding_duration'] |
| repainting_end_batch = [adjusted_end] * actual_batch_size |
| else: |
| |
| adjusted_end = repainting_end + padding_info_batch[0]['left_padding_duration'] |
| repainting_end_batch = [adjusted_end] * actual_batch_size |
| else: |
| |
| if repainting_end is None or repainting_end < 0: |
| repainting_end_batch = None |
| elif isinstance(repainting_end, (int, float)): |
| repainting_end_batch = [repainting_end] * actual_batch_size |
| else: |
| |
| repainting_end_batch = [] |
| for i in range(actual_batch_size): |
| if processed_src_audio is not None: |
| adjusted_end = repainting_end[i] + padding_info_batch[i]['left_padding_duration'] |
| repainting_end_batch.append(adjusted_end) |
| else: |
| repainting_end_batch.append(repainting_end[i]) |
| else: |
| |
| |
| repainting_start_batch = None |
| repainting_end_batch = None |
| |
| return repainting_start_batch, repainting_end_batch, target_wavs_tensor |
|
|
| def _prepare_batch( |
| self, |
| captions: List[str], |
| lyrics: List[str], |
| keys: Optional[List[str]] = None, |
| target_wavs: Optional[torch.Tensor] = None, |
| refer_audios: Optional[List[List[torch.Tensor]]] = None, |
| metas: Optional[List[Union[str, Dict[str, Any]]]] = None, |
| vocal_languages: Optional[List[str]] = None, |
| repainting_start: Optional[List[float]] = None, |
| repainting_end: Optional[List[float]] = None, |
| instructions: Optional[List[str]] = None, |
| audio_code_hints: Optional[List[Optional[str]]] = None, |
| audio_cover_strength: float = 1.0, |
| ) -> Dict[str, Any]: |
| """ |
| Prepare batch data with fallbacks for missing inputs. |
| |
| Args: |
| captions: List of text captions (optional, can be empty strings) |
| lyrics: List of lyrics (optional, can be empty strings) |
| keys: List of unique identifiers (optional) |
| target_wavs: Target audio tensors (optional, will use silence if not provided) |
| refer_audios: Reference audio tensors (optional, will use silence if not provided) |
| metas: Metadata (optional, will use defaults if not provided) |
| vocal_languages: Vocal languages (optional, will default to 'en') |
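| repainting_start: Per-item repaint start times in seconds (optional) |
| repainting_end: Per-item repaint end times in seconds (optional) |
| instructions: Per-item instruction strings (optional, defaults to DEFAULT_DIT_INSTRUCTION) |
| audio_code_hints: Per-item serialized audio-code strings used as precomputed LM hints (optional) |
| audio_cover_strength: Cover strength; values below 1.0 additionally build non-cover text inputs |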
| |
| Returns: |
| Batch dictionary ready for model input |
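| |
| Example (illustrative sketch; assumes an initialized handler and a 10-second silent stereo target): |
| wavs = torch.zeros(1, 2, 10 * 48000) |
| batch = handler._prepare_batch( |
| captions=["calm piano"], |
| lyrics=["la la la"], |
| keys=["demo"], |
| target_wavs=wavs, |
| refer_audios=[[]], |
| metas=[{"bpm": 90, "duration": 10}], |
| vocal_languages=["en"], |
| ) |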
| """ |
| batch_size = len(captions) |
| |
| |
| self._ensure_silence_latent_on_device() |
|
|
| |
| audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size) |
| |
| # refer_audios is optional; only move provided reference audio to the target device |
| if refer_audios is not None: |
| for ii, refer_audio_list in enumerate(refer_audios): |
| if isinstance(refer_audio_list, list): |
| for idx, refer_audio in enumerate(refer_audio_list): |
| refer_audio_list[idx] = refer_audio_list[idx].to(self.device).to(torch.bfloat16) |
| elif isinstance(refer_audio_list, torch.Tensor): |
| refer_audios[ii] = refer_audios[ii].to(self.device) |
| |
| if vocal_languages is None: |
| vocal_languages = self._create_fallback_vocal_languages(batch_size) |
| |
| |
| parsed_metas = self._parse_metas(metas) |
| |
| |
| with torch.no_grad(): |
| target_latents_list = [] |
| latent_lengths = [] |
| |
| target_wavs_list = [target_wavs[i].clone() for i in range(batch_size)] |
| if target_wavs.device != self.device: |
| target_wavs = target_wavs.to(self.device) |
| |
| with self._load_model_context("vae"): |
| for i in range(batch_size): |
| code_hint = audio_code_hints[i] |
| |
| if code_hint: |
| logger.info(f"[generate_music] Decoding audio codes for item {i}...") |
| decoded_latents = self._decode_audio_codes_to_latents(code_hint) |
| if decoded_latents is not None: |
| decoded_latents = decoded_latents.squeeze(0) |
| target_latents_list.append(decoded_latents) |
| latent_lengths.append(decoded_latents.shape[0]) |
| |
| frames_from_codes = max(1, int(decoded_latents.shape[0] * 1920)) |
| target_wavs_list[i] = torch.zeros(2, frames_from_codes) |
| continue |
| |
| current_wav = target_wavs_list[i].to(self.device).unsqueeze(0) |
| if self.is_silence(current_wav): |
| expected_latent_length = current_wav.shape[-1] // 1920 |
| target_latent = self.silence_latent[0, :expected_latent_length, :] |
| else: |
| |
| logger.info(f"[generate_music] Encoding target audio to latents for item {i}...") |
| target_latent = self._encode_audio_to_latents(current_wav.squeeze(0)) |
| target_latents_list.append(target_latent) |
| latent_lengths.append(target_latent.shape[0]) |
| |
| |
| max_target_frames = max(wav.shape[-1] for wav in target_wavs_list) |
| padded_target_wavs = [] |
| for wav in target_wavs_list: |
| if wav.shape[-1] < max_target_frames: |
| pad_frames = max_target_frames - wav.shape[-1] |
| wav = torch.nn.functional.pad(wav, (0, pad_frames), "constant", 0) |
| padded_target_wavs.append(wav) |
| target_wavs = torch.stack(padded_target_wavs) |
| wav_lengths = torch.tensor([target_wavs.shape[-1]] * batch_size, dtype=torch.long) |
| |
| |
| max_latent_length = max(latent.shape[0] for latent in target_latents_list) |
| max_latent_length = max(128, max_latent_length) |
| |
| padded_latents = [] |
| for latent in target_latents_list: |
| latent_length = latent.shape[0] |
| |
| if latent.shape[0] < max_latent_length: |
| pad_length = max_latent_length - latent.shape[0] |
| latent = torch.cat([latent, self.silence_latent[0, :pad_length, :]], dim=0) |
| padded_latents.append(latent) |
| |
| target_latents = torch.stack(padded_latents) |
| latent_masks = torch.stack([ |
| torch.cat([ |
| torch.ones(l, dtype=torch.long, device=self.device), |
| torch.zeros(max_latent_length - l, dtype=torch.long, device=self.device) |
| ]) |
| for l in latent_lengths |
| ]) |
| |
| |
| |
| instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION) |
| |
| |
| |
| chunk_masks = [] |
| spans = [] |
| is_covers = [] |
| |
| repainting_ranges = {} |
| |
| for i in range(batch_size): |
| has_code_hint = audio_code_hints[i] is not None |
| |
| has_repainting = False |
| if repainting_start is not None and repainting_end is not None: |
| start_sec = repainting_start[i] if repainting_start[i] is not None else 0.0 |
| end_sec = repainting_end[i] |
| |
| if end_sec is not None and end_sec > start_sec: |
| |
| |
| |
| |
| |
| left_padding_sec = max(0, -start_sec) |
| |
| |
| |
| adjusted_start_sec = start_sec + left_padding_sec |
| adjusted_end_sec = end_sec + left_padding_sec |
| |
| |
| start_latent = int(adjusted_start_sec * self.sample_rate // 1920) |
| end_latent = int(adjusted_end_sec * self.sample_rate // 1920) |
|
|
| |
| start_latent = max(0, min(start_latent, max_latent_length - 1)) |
| end_latent = max(start_latent + 1, min(end_latent, max_latent_length)) |
| |
| mask = torch.zeros(max_latent_length, dtype=torch.bool, device=self.device) |
| mask[start_latent:end_latent] = True |
| chunk_masks.append(mask) |
| spans.append(("repainting", start_latent, end_latent)) |
| |
| repainting_ranges[i] = (start_latent, end_latent) |
| has_repainting = True |
| is_covers.append(False) |
| else: |
| |
| chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device)) |
| spans.append(("full", 0, max_latent_length)) |
| |
| |
| instruction_i = instructions[i] if instructions and i < len(instructions) else "" |
| instruction_lower = instruction_i.lower() |
| |
| is_cover = ("generate audio semantic tokens" in instruction_lower and |
| "based on the given conditions" in instruction_lower) or has_code_hint |
| is_covers.append(is_cover) |
| else: |
| |
| chunk_masks.append(torch.ones(max_latent_length, dtype=torch.bool, device=self.device)) |
| spans.append(("full", 0, max_latent_length)) |
| |
| |
| instruction_i = instructions[i] if instructions and i < len(instructions) else "" |
| instruction_lower = instruction_i.lower() |
| |
| is_cover = ("generate audio semantic tokens" in instruction_lower and |
| "based on the given conditions" in instruction_lower) or has_code_hint |
| is_covers.append(is_cover) |
| |
| chunk_masks = torch.stack(chunk_masks) |
| is_covers = torch.BoolTensor(is_covers).to(self.device) |
| |
| |
| |
| |
| |
| src_latents_list = [] |
| silence_latent_tiled = self.silence_latent[0, :max_latent_length, :] |
| for i in range(batch_size): |
| |
| has_code_hint = audio_code_hints[i] is not None |
| has_target_audio = has_code_hint or (target_wavs is not None and target_wavs[i].abs().sum() > 1e-6) |
| |
| if has_target_audio: |
| |
| |
| item_has_repainting = (i in repainting_ranges) |
| |
| if item_has_repainting: |
| |
| |
| src_latent = target_latents[i].clone() |
| |
| start_latent, end_latent = repainting_ranges[i] |
| src_latent[start_latent:end_latent] = silence_latent_tiled[start_latent:end_latent] |
| src_latents_list.append(src_latent) |
| else: |
| |
| |
| src_latents_list.append(target_latents[i].clone()) |
| else: |
| |
| |
| src_latents_list.append(silence_latent_tiled.clone()) |
| |
| src_latents = torch.stack(src_latents_list) |
| |
| |
| precomputed_lm_hints_25Hz_list = [] |
| for i in range(batch_size): |
| if audio_code_hints[i] is not None: |
| |
| logger.info(f"[generate_music] Decoding audio codes for LM hints for item {i}...") |
| hints = self._decode_audio_codes_to_latents(audio_code_hints[i]) |
| if hints is not None: |
| |
| if hints.shape[1] < max_latent_length: |
| pad_length = max_latent_length - hints.shape[1] |
| pad = self.silence_latent |
| |
| if pad.dim() == 2: |
| pad = pad.unsqueeze(0) |
| if hints.dim() == 2: |
| hints = hints.unsqueeze(0) |
| pad_chunk = pad[:, :pad_length, :] |
| if pad_chunk.device != hints.device or pad_chunk.dtype != hints.dtype: |
| pad_chunk = pad_chunk.to(device=hints.device, dtype=hints.dtype) |
| hints = torch.cat([hints, pad_chunk], dim=1) |
| elif hints.shape[1] > max_latent_length: |
| hints = hints[:, :max_latent_length, :] |
| precomputed_lm_hints_25Hz_list.append(hints[0]) |
| else: |
| precomputed_lm_hints_25Hz_list.append(None) |
| else: |
| precomputed_lm_hints_25Hz_list.append(None) |
| |
| |
| if any(h is not None for h in precomputed_lm_hints_25Hz_list): |
| |
| precomputed_lm_hints_25Hz = torch.stack([ |
| h if h is not None else silence_latent_tiled |
| for h in precomputed_lm_hints_25Hz_list |
| ]) |
| else: |
| precomputed_lm_hints_25Hz = None |
| |
| |
| |
| actual_captions, actual_languages = self._extract_caption_and_language(parsed_metas, captions, vocal_languages) |
| |
| |
| text_inputs = [] |
| text_token_idss = [] |
| text_attention_masks = [] |
| lyric_token_idss = [] |
| lyric_attention_masks = [] |
| |
| for i in range(batch_size): |
| |
| instruction = self._format_instruction(instructions[i] if i < len(instructions) else DEFAULT_DIT_INSTRUCTION) |
| |
| actual_caption = actual_captions[i] |
| actual_language = actual_languages[i] |
| |
| |
| text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i]) |
| |
| |
| text_inputs_dict = self.text_tokenizer( |
| text_prompt, |
| padding="longest", |
| truncation=True, |
| max_length=256, |
| return_tensors="pt", |
| ) |
| text_token_ids = text_inputs_dict.input_ids[0] |
| text_attention_mask = text_inputs_dict.attention_mask[0].bool() |
| |
| |
| lyrics_text = self._format_lyrics(lyrics[i], actual_language) |
| lyrics_inputs_dict = self.text_tokenizer( |
| lyrics_text, |
| padding="longest", |
| truncation=True, |
| max_length=2048, |
| return_tensors="pt", |
| ) |
| lyric_token_ids = lyrics_inputs_dict.input_ids[0] |
| lyric_attention_mask = lyrics_inputs_dict.attention_mask[0].bool() |
| |
| |
| text_input = text_prompt + "\n\n" + lyrics_text |
| |
| text_inputs.append(text_input) |
| text_token_idss.append(text_token_ids) |
| text_attention_masks.append(text_attention_mask) |
| lyric_token_idss.append(lyric_token_ids) |
| lyric_attention_masks.append(lyric_attention_mask) |
| |
| |
| max_text_length = max(len(seq) for seq in text_token_idss) |
| padded_text_token_idss = self._pad_sequences(text_token_idss, max_text_length, self.text_tokenizer.pad_token_id) |
| padded_text_attention_masks = self._pad_sequences(text_attention_masks, max_text_length, 0) |
| |
| max_lyric_length = max(len(seq) for seq in lyric_token_idss) |
| padded_lyric_token_idss = self._pad_sequences(lyric_token_idss, max_lyric_length, self.text_tokenizer.pad_token_id) |
| padded_lyric_attention_masks = self._pad_sequences(lyric_attention_masks, max_lyric_length, 0) |
|
|
| padded_non_cover_text_input_ids = None |
| padded_non_cover_text_attention_masks = None |
| if audio_cover_strength < 1.0: |
| non_cover_text_input_ids = [] |
| non_cover_text_attention_masks = [] |
| for i in range(batch_size): |
| |
| instruction = self._format_instruction(DEFAULT_DIT_INSTRUCTION) |
| |
| |
| actual_caption = actual_captions[i] |
| |
| |
| text_prompt = SFT_GEN_PROMPT.format(instruction, actual_caption, parsed_metas[i]) |
| |
| |
| text_inputs_dict = self.text_tokenizer( |
| text_prompt, |
| padding="longest", |
| truncation=True, |
| max_length=256, |
| return_tensors="pt", |
| ) |
| text_token_ids = text_inputs_dict.input_ids[0] |
| non_cover_text_attention_mask = text_inputs_dict.attention_mask[0].bool() |
| non_cover_text_input_ids.append(text_token_ids) |
| non_cover_text_attention_masks.append(non_cover_text_attention_mask) |
| |
| padded_non_cover_text_input_ids = self._pad_sequences(non_cover_text_input_ids, max_text_length, self.text_tokenizer.pad_token_id) |
| padded_non_cover_text_attention_masks = self._pad_sequences(non_cover_text_attention_masks, max_text_length, 0) |
| |
| if audio_cover_strength < 1.0: |
| assert padded_non_cover_text_input_ids is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_input_ids must not be None" |
| assert padded_non_cover_text_attention_masks is not None, "When audio_cover_strength < 1.0, padded_non_cover_text_attention_masks must not be None" |
| |
| batch = { |
| "keys": keys, |
| "target_wavs": target_wavs.to(self.device), |
| "refer_audioss": refer_audios, |
| "wav_lengths": wav_lengths.to(self.device), |
| "captions": captions, |
| "lyrics": lyrics, |
| "metas": parsed_metas, |
| "vocal_languages": vocal_languages, |
| "target_latents": target_latents, |
| "src_latents": src_latents, |
| "latent_masks": latent_masks, |
| "chunk_masks": chunk_masks, |
| "spans": spans, |
| "text_inputs": text_inputs, |
| "text_token_idss": padded_text_token_idss, |
| "text_attention_masks": padded_text_attention_masks, |
| "lyric_token_idss": padded_lyric_token_idss, |
| "lyric_attention_masks": padded_lyric_attention_masks, |
| "is_covers": is_covers, |
| "precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz, |
| "non_cover_text_input_ids": padded_non_cover_text_input_ids, |
| "non_cover_text_attention_masks": padded_non_cover_text_attention_masks, |
| } |
| |
| for k, v in batch.items(): |
| if isinstance(v, torch.Tensor): |
| batch[k] = v.to(self.device) |
| if torch.is_floating_point(v): |
| batch[k] = v.to(self.dtype) |
| return batch |
| |
| def infer_refer_latent(self, refer_audioss): |
| refer_audio_order_mask = [] |
| refer_audio_latents = [] |
| |
| |
| self._ensure_silence_latent_on_device() |
|
|
| def _normalize_audio_2d(a: torch.Tensor) -> torch.Tensor: |
| """Normalize audio tensor to [2, T] on current device.""" |
| if not isinstance(a, torch.Tensor): |
| raise TypeError(f"refer_audio must be a torch.Tensor, got {type(a)!r}") |
| |
| if a.dim() == 3 and a.shape[0] == 1: |
| a = a.squeeze(0) |
| if a.dim() == 1: |
| a = a.unsqueeze(0) |
| if a.dim() != 2: |
| raise ValueError(f"refer_audio must be [T], [C, T], or [1, C, T]; got shape={tuple(a.shape)}") |
| if a.shape[0] == 1: |
| a = torch.cat([a, a], dim=0) |
| a = a[:2] |
| return a |
|
|
| def _ensure_latent_3d(z: torch.Tensor) -> torch.Tensor: |
| """Ensure latent is [N, T, D] (3D) for packing.""" |
| if z.dim() == 4 and z.shape[0] == 1: |
| z = z.squeeze(0) |
| if z.dim() == 2: |
| z = z.unsqueeze(0) |
| return z |
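| # Shape handling of the helpers above (illustrative): |
| #   _normalize_audio_2d: [T] -> [2, T] (mono duplicated), [1, T] -> [2, T], |
| #                        [1, C, T] -> [C, T], then clipped/duplicated to 2 channels |
| #   _ensure_latent_3d:   [T, D] -> [1, T, D], [1, N, T, D] -> [N, T, D] |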
|
|
| for batch_idx, refer_audios in enumerate(refer_audioss): |
| if len(refer_audios) == 1 and torch.all(refer_audios[0] == 0.0): |
| refer_audio_latent = _ensure_latent_3d(self.silence_latent[:, :750, :]) |
| refer_audio_latents.append(refer_audio_latent) |
| refer_audio_order_mask.append(batch_idx) |
| else: |
| for refer_audio in refer_audios: |
| refer_audio = _normalize_audio_2d(refer_audio) |
| |
| vae_input = refer_audio.unsqueeze(0).to(self.vae.dtype) |
| refer_audio_latent = self.vae.encode(vae_input).latent_dist.sample() |
| |
| refer_audio_latent = refer_audio_latent.to(self.dtype) |
| refer_audio_latents.append(_ensure_latent_3d(refer_audio_latent.transpose(1, 2))) |
| refer_audio_order_mask.append(batch_idx) |
|
|
| refer_audio_latents = torch.cat(refer_audio_latents, dim=0) |
| refer_audio_order_mask = torch.tensor(refer_audio_order_mask, device=self.device, dtype=torch.long) |
| return refer_audio_latents, refer_audio_order_mask |
|
|
| def infer_text_embeddings(self, text_token_idss): |
| with torch.no_grad(): |
| text_embeddings = self.text_encoder(input_ids=text_token_idss, lyric_attention_mask=None).last_hidden_state |
| return text_embeddings |
|
|
| def infer_lyric_embeddings(self, lyric_token_ids): |
| with torch.no_grad(): |
| lyric_embeddings = self.text_encoder.embed_tokens(lyric_token_ids) |
| return lyric_embeddings |
|
|
| def preprocess_batch(self, batch): |
|
|
| |
| |
| target_latents = batch["target_latents"] |
| src_latents = batch["src_latents"] |
| attention_mask = batch["latent_masks"] |
| audio_codes = batch.get("audio_codes", None) |
| audio_attention_mask = attention_mask |
|
|
| dtype = target_latents.dtype |
| bs = target_latents.shape[0] |
| device = target_latents.device |
|
|
| |
| keys = batch["keys"] |
| with self._load_model_context("vae"): |
| refer_audio_acoustic_hidden_states_packed, refer_audio_order_mask = self.infer_refer_latent(batch["refer_audioss"]) |
| if refer_audio_acoustic_hidden_states_packed.dtype != dtype: |
| refer_audio_acoustic_hidden_states_packed = refer_audio_acoustic_hidden_states_packed.to(dtype) |
|
|
| |
| chunk_mask = batch["chunk_masks"] |
| chunk_mask = chunk_mask.to(device).unsqueeze(-1).repeat(1, 1, target_latents.shape[2]) |
|
|
| spans = batch["spans"] |
| |
| text_token_idss = batch["text_token_idss"] |
| text_attention_mask = batch["text_attention_masks"] |
| lyric_token_idss = batch["lyric_token_idss"] |
| lyric_attention_mask = batch["lyric_attention_masks"] |
| text_inputs = batch["text_inputs"] |
|
|
| logger.info("[preprocess_batch] Inferring prompt embeddings...") |
| with self._load_model_context("text_encoder"): |
| text_hidden_states = self.infer_text_embeddings(text_token_idss) |
| logger.info("[preprocess_batch] Inferring lyric embeddings...") |
| lyric_hidden_states = self.infer_lyric_embeddings(lyric_token_idss) |
|
|
| is_covers = batch["is_covers"] |
| |
| |
| precomputed_lm_hints_25Hz = batch.get("precomputed_lm_hints_25Hz", None) |
| |
| |
| non_cover_text_input_ids = batch.get("non_cover_text_input_ids", None) |
| non_cover_text_attention_masks = batch.get("non_cover_text_attention_masks", None) |
| non_cover_text_hidden_states = None |
| if non_cover_text_input_ids is not None: |
| logger.info("[preprocess_batch] Inferring non-cover text embeddings...") |
| non_cover_text_hidden_states = self.infer_text_embeddings(non_cover_text_input_ids) |
|
|
| return ( |
| keys, |
| text_inputs, |
| src_latents, |
| target_latents, |
| |
| text_hidden_states, |
| text_attention_mask, |
| lyric_hidden_states, |
| lyric_attention_mask, |
| audio_attention_mask, |
| refer_audio_acoustic_hidden_states_packed, |
| refer_audio_order_mask, |
| chunk_mask, |
| spans, |
| is_covers, |
| audio_codes, |
| lyric_token_idss, |
| precomputed_lm_hints_25Hz, |
| non_cover_text_hidden_states, |
| non_cover_text_attention_masks, |
| ) |
| |
| @torch.no_grad() |
| def service_generate( |
| self, |
| captions: Union[str, List[str]], |
| lyrics: Union[str, List[str]], |
| keys: Optional[Union[str, List[str]]] = None, |
| target_wavs: Optional[torch.Tensor] = None, |
| refer_audios: Optional[List[List[torch.Tensor]]] = None, |
| metas: Optional[Union[str, Dict[str, Any], List[Union[str, Dict[str, Any]]]]] = None, |
| vocal_languages: Optional[Union[str, List[str]]] = None, |
| infer_steps: int = 60, |
| guidance_scale: float = 7.0, |
| seed: Optional[Union[int, List[int]]] = None, |
| return_intermediate: bool = False, |
| repainting_start: Optional[Union[float, List[float]]] = None, |
| repainting_end: Optional[Union[float, List[float]]] = None, |
| instructions: Optional[Union[str, List[str]]] = None, |
| audio_cover_strength: float = 1.0, |
| use_adg: bool = False, |
| cfg_interval_start: float = 0.0, |
| cfg_interval_end: float = 1.0, |
| shift: float = 1.0, |
| audio_code_hints: Optional[Union[str, List[str]]] = None, |
| infer_method: str = "ode", |
| timesteps: Optional[List[float]] = None, |
| ) -> Dict[str, Any]: |
|
|
| """ |
| Generate music from text inputs. |
| |
| Args: |
| captions: Text caption(s) describing the music (optional, can be empty strings) |
| lyrics: Lyric text(s) (optional, can be empty strings) |
| keys: Unique identifier(s) (optional) |
| target_wavs: Target audio tensor(s) for conditioning (optional) |
| refer_audios: Reference audio tensor(s) for style transfer (optional) |
| metas: Metadata dict(s) or string(s) (optional) |
| vocal_languages: Language code(s) for lyrics (optional, defaults to 'en') |
| infer_steps: Number of inference steps (default: 60) |
| guidance_scale: Guidance scale for generation (default: 7.0) |
| seed: Random seed (optional) |
| return_intermediate: Whether to return intermediate results (default: False) |
| repainting_start: Start time(s) for repainting region in seconds (optional) |
| repainting_end: End time(s) for repainting region in seconds (optional) |
| instructions: Instruction text(s) for generation (optional) |
| audio_cover_strength: Strength of audio cover mode (default: 1.0) |
| use_adg: Whether to use ADG (Adaptive Diffusion Guidance) (default: False) |
| cfg_interval_start: Start of CFG interval (0.0-1.0, default: 0.0) |
| cfg_interval_end: End of CFG interval (0.0-1.0, default: 1.0) |
| shift: Shift parameter passed to the sampler (default: 1.0) |
| audio_code_hints: Audio code string(s) used as LM hints; decoded to latents internally (optional) |
| infer_method: Inference method, e.g. "ode" (default: "ode") |
| timesteps: Explicit sampling timesteps passed to the sampler when provided (optional) |
| |
| Returns: |
| Dictionary of model outputs, including at least: |
| - target_latents: Generated latents (decode with the VAE to obtain audio) |
| - time_costs: Timing information |
| - src_latents: Source latents used for conditioning |
| - target_latents_input: Target latents prepared from the input audio |
| - chunk_masks: Chunk masks used during generation |
| - spans: Generation spans |
| - latent_masks: Latent validity masks |
| - encoder_hidden_states / encoder_attention_mask / context_latents / lyric_token_idss: |
| Cached conditioning tensors for downstream lyric alignment and scoring |
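| Example (illustrative sketch; assumes an initialized handler with model, |
| VAE and text encoder loaded): |
| outputs = handler.service_generate( |
| captions=["mellow lofi hip hop with warm piano"], |
| lyrics=[""], |
| infer_steps=8, |
| seed=42, |
| ) |
| latents = outputs["target_latents"]  # decode with the VAE to obtain audio |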
| """ |
| if self.config.is_turbo: |
| |
| if infer_steps > 8: |
| logger.warning(f"[service_generate] Turbo (DMD-GAN) model: infer_steps {infer_steps} exceeds the maximum of 8, clamping to 8") |
| infer_steps = 8 |
| |
| |
| |
| |
| if isinstance(captions, str): |
| captions = [captions] |
| if isinstance(lyrics, str): |
| lyrics = [lyrics] |
| if isinstance(keys, str): |
| keys = [keys] |
| if isinstance(vocal_languages, str): |
| vocal_languages = [vocal_languages] |
| if isinstance(metas, (str, dict)): |
| metas = [metas] |
| |
| |
| if isinstance(repainting_start, (int, float)): |
| repainting_start = [repainting_start] |
| if isinstance(repainting_end, (int, float)): |
| repainting_end = [repainting_end] |
| |
| |
| batch_size = len(captions) |
|
|
| |
| instructions = self._normalize_instructions(instructions, batch_size, DEFAULT_DIT_INSTRUCTION) if instructions is not None else None |
| audio_code_hints = self._normalize_audio_code_hints(audio_code_hints, batch_size) if audio_code_hints is not None else None |
| |
| |
| if seed is None: |
| seed_list = None |
| elif isinstance(seed, list): |
| # Copy so the caller's list is not mutated |
| seed_list = list(seed) |
| if len(seed_list) < batch_size: |
| # Pad with random seeds up to the batch size |
| while len(seed_list) < batch_size: |
| seed_list.append(random.randint(0, 2**32 - 1)) |
| elif len(seed_list) > batch_size: |
| # Truncate to the batch size |
| seed_list = seed_list[:batch_size] |
| else: |
| # Single seed: reuse it for every item in the batch |
| seed_list = [int(seed)] * batch_size |
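| # Examples of the normalization above: |
| #   seed=42,  batch_size=2 -> [42, 42] |
| #   seed=[7], batch_size=2 -> [7, <random>] |
| #   seed=None              -> a single random seed is drawn later |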
|
|
| |
| |
| |
| batch = self._prepare_batch( |
| captions=captions, |
| lyrics=lyrics, |
| keys=keys, |
| target_wavs=target_wavs, |
| refer_audios=refer_audios, |
| metas=metas, |
| vocal_languages=vocal_languages, |
| repainting_start=repainting_start, |
| repainting_end=repainting_end, |
| instructions=instructions, |
| audio_code_hints=audio_code_hints, |
| audio_cover_strength=audio_cover_strength, |
| ) |
| |
| processed_data = self.preprocess_batch(batch) |
| |
| ( |
| keys, |
| text_inputs, |
| src_latents, |
| target_latents, |
| |
| text_hidden_states, |
| text_attention_mask, |
| lyric_hidden_states, |
| lyric_attention_mask, |
| audio_attention_mask, |
| refer_audio_acoustic_hidden_states_packed, |
| refer_audio_order_mask, |
| chunk_mask, |
| spans, |
| is_covers, |
| audio_codes, |
| lyric_token_idss, |
| precomputed_lm_hints_25Hz, |
| non_cover_text_hidden_states, |
| non_cover_text_attention_masks, |
| ) = processed_data |
|
|
| |
| |
| if seed_list is not None: |
| |
| seed_param = seed_list |
| else: |
| seed_param = random.randint(0, 2**32 - 1) |
| |
| |
| self._ensure_silence_latent_on_device() |
| |
| generate_kwargs = { |
| "text_hidden_states": text_hidden_states, |
| "text_attention_mask": text_attention_mask, |
| "lyric_hidden_states": lyric_hidden_states, |
| "lyric_attention_mask": lyric_attention_mask, |
| "refer_audio_acoustic_hidden_states_packed": refer_audio_acoustic_hidden_states_packed, |
| "refer_audio_order_mask": refer_audio_order_mask, |
| "src_latents": src_latents, |
| "chunk_masks": chunk_mask, |
| "is_covers": is_covers, |
| "silence_latent": self.silence_latent, |
| "seed": seed_param, |
| "non_cover_text_hidden_states": non_cover_text_hidden_states, |
| "non_cover_text_attention_mask": non_cover_text_attention_masks, |
| "precomputed_lm_hints_25Hz": precomputed_lm_hints_25Hz, |
| "audio_cover_strength": audio_cover_strength, |
| "infer_method": infer_method, |
| "infer_steps": infer_steps, |
| "diffusion_guidance_sale": guidance_scale, |
| "use_adg": use_adg, |
| "cfg_interval_start": cfg_interval_start, |
| "cfg_interval_end": cfg_interval_end, |
| "shift": shift, |
| } |
| |
| if timesteps is not None: |
| generate_kwargs["timesteps"] = torch.tensor(timesteps, dtype=torch.float32) |
| logger.info("[service_generate] Generating audio...") |
| with self._load_model_context("model"): |
| |
| encoder_hidden_states, encoder_attention_mask, context_latents = self.model.prepare_condition( |
| text_hidden_states=text_hidden_states, |
| text_attention_mask=text_attention_mask, |
| lyric_hidden_states=lyric_hidden_states, |
| lyric_attention_mask=lyric_attention_mask, |
| refer_audio_acoustic_hidden_states_packed=refer_audio_acoustic_hidden_states_packed, |
| refer_audio_order_mask=refer_audio_order_mask, |
| hidden_states=src_latents, |
| attention_mask=torch.ones(src_latents.shape[0], src_latents.shape[1], device=src_latents.device, dtype=src_latents.dtype), |
| silence_latent=self.silence_latent, |
| src_latents=src_latents, |
| chunk_masks=chunk_mask, |
| is_covers=is_covers, |
| precomputed_lm_hints_25Hz=precomputed_lm_hints_25Hz, |
| ) |
| |
| outputs = self.model.generate_audio(**generate_kwargs) |
| |
| |
| outputs["src_latents"] = src_latents |
| outputs["target_latents_input"] = target_latents |
| outputs["chunk_masks"] = chunk_mask |
| outputs["spans"] = spans |
| outputs["latent_masks"] = batch.get("latent_masks") |
| |
| |
| outputs["encoder_hidden_states"] = encoder_hidden_states |
| outputs["encoder_attention_mask"] = encoder_attention_mask |
| outputs["context_latents"] = context_latents |
| outputs["lyric_token_idss"] = lyric_token_idss |
| |
| return outputs |
|
|
| def tiled_decode(self, latents, chunk_size=512, overlap=64, offload_wav_to_cpu=False): |
| """ |
| Decode latents using tiling to reduce VRAM usage. |
| Uses overlap-discard strategy to avoid boundary artifacts. |
| |
| Args: |
| latents: [Batch, Channels, Length] |
| chunk_size: Size of latent chunk to process at once |
| overlap: Overlap size in latent frames |
| offload_wav_to_cpu: If True, offload decoded wav audio to CPU immediately to save VRAM |
| """ |
| B, C, T = latents.shape |
| |
| |
| if T <= chunk_size: |
| return self.vae.decode(latents).sample |
|
|
| |
| stride = chunk_size - 2 * overlap |
| if stride <= 0: |
| raise ValueError(f"chunk_size {chunk_size} must be > 2 * overlap {overlap}") |
| |
| num_steps = math.ceil(T / stride) |
| |
| if offload_wav_to_cpu: |
| |
| return self._tiled_decode_offload_cpu(latents, B, T, stride, overlap, num_steps) |
| else: |
| |
| return self._tiled_decode_gpu(latents, B, T, stride, overlap, num_steps) |
| |
| def _tiled_decode_gpu(self, latents, B, T, stride, overlap, num_steps): |
| """Standard tiled decode keeping all data on GPU.""" |
| decoded_audio_list = [] |
| upsample_factor = None |
| |
| for i in tqdm(range(num_steps), desc="Decoding audio chunks"): |
| |
| core_start = i * stride |
| core_end = min(core_start + stride, T) |
| |
| |
| win_start = max(0, core_start - overlap) |
| win_end = min(T, core_end + overlap) |
| |
| |
| latent_chunk = latents[:, :, win_start:win_end] |
| |
| |
| |
| audio_chunk = self.vae.decode(latent_chunk).sample |
| |
| |
| if upsample_factor is None: |
| upsample_factor = audio_chunk.shape[-1] / latent_chunk.shape[-1] |
| |
| |
| |
| added_start = core_start - win_start |
| trim_start = int(round(added_start * upsample_factor)) |
| |
| |
| added_end = win_end - core_end |
| trim_end = int(round(added_end * upsample_factor)) |
| |
| |
| audio_len = audio_chunk.shape[-1] |
| end_idx = audio_len - trim_end if trim_end > 0 else audio_len |
| |
| audio_core = audio_chunk[:, :, trim_start:end_idx] |
| decoded_audio_list.append(audio_core) |
| |
| |
| final_audio = torch.cat(decoded_audio_list, dim=-1) |
| return final_audio |
| |
| def _tiled_decode_offload_cpu(self, latents, B, T, stride, overlap, num_steps): |
| """Optimized tiled decode that offloads to CPU immediately to save VRAM.""" |
| |
| first_core_start = 0 |
| first_core_end = min(stride, T) |
| first_win_start = 0 |
| first_win_end = min(T, first_core_end + overlap) |
| |
| first_latent_chunk = latents[:, :, first_win_start:first_win_end] |
| first_audio_chunk = self.vae.decode(first_latent_chunk).sample |
| |
| upsample_factor = first_audio_chunk.shape[-1] / first_latent_chunk.shape[-1] |
| audio_channels = first_audio_chunk.shape[1] |
| |
| |
| total_audio_length = int(round(T * upsample_factor)) |
| final_audio = torch.zeros(B, audio_channels, total_audio_length, |
| dtype=first_audio_chunk.dtype, device='cpu') |
| |
| |
| first_added_end = first_win_end - first_core_end |
| first_trim_end = int(round(first_added_end * upsample_factor)) |
| first_audio_len = first_audio_chunk.shape[-1] |
| first_end_idx = first_audio_len - first_trim_end if first_trim_end > 0 else first_audio_len |
| |
| first_audio_core = first_audio_chunk[:, :, :first_end_idx] |
| audio_write_pos = first_audio_core.shape[-1] |
| final_audio[:, :, :audio_write_pos] = first_audio_core.cpu() |
| |
| |
| del first_audio_chunk, first_audio_core, first_latent_chunk |
| |
| |
| for i in tqdm(range(1, num_steps), desc="Decoding audio chunks"): |
| |
| core_start = i * stride |
| core_end = min(core_start + stride, T) |
| |
| |
| win_start = max(0, core_start - overlap) |
| win_end = min(T, core_end + overlap) |
| |
| |
| latent_chunk = latents[:, :, win_start:win_end] |
| |
| |
| |
| audio_chunk = self.vae.decode(latent_chunk).sample |
| |
| |
| added_start = core_start - win_start |
| trim_start = int(round(added_start * upsample_factor)) |
| |
| added_end = win_end - core_end |
| trim_end = int(round(added_end * upsample_factor)) |
| |
| |
| audio_len = audio_chunk.shape[-1] |
| end_idx = audio_len - trim_end if trim_end > 0 else audio_len |
| |
| audio_core = audio_chunk[:, :, trim_start:end_idx] |
| |
| |
| core_len = audio_core.shape[-1] |
| final_audio[:, :, audio_write_pos:audio_write_pos + core_len] = audio_core.cpu() |
| audio_write_pos += core_len |
| |
| |
| del audio_chunk, audio_core, latent_chunk |
| |
| |
| final_audio = final_audio[:, :, :audio_write_pos] |
| |
| return final_audio |
|
|
| def generate_music( |
| self, |
| captions: str, |
| lyrics: str, |
| bpm: Optional[int] = None, |
| key_scale: str = "", |
| time_signature: str = "", |
| vocal_language: str = "en", |
| inference_steps: int = 8, |
| guidance_scale: float = 7.0, |
| use_random_seed: bool = True, |
| seed: Optional[Union[str, float, int]] = -1, |
| reference_audio=None, |
| audio_duration: Optional[float] = None, |
| batch_size: Optional[int] = None, |
| src_audio=None, |
| audio_code_string: Union[str, List[str]] = "", |
| repainting_start: float = 0.0, |
| repainting_end: Optional[float] = None, |
| instruction: str = DEFAULT_DIT_INSTRUCTION, |
| audio_cover_strength: float = 1.0, |
| task_type: str = "text2music", |
| use_adg: bool = False, |
| cfg_interval_start: float = 0.0, |
| cfg_interval_end: float = 1.0, |
| shift: float = 1.0, |
| infer_method: str = "ode", |
| use_tiled_decode: bool = True, |
| timesteps: Optional[List[float]] = None, |
| progress=None |
| ) -> Dict[str, Any]: |
| """ |
| Main interface for music generation |
| |
| Returns: |
| Dictionary containing: |
| - audios: List of audio dictionaries, each with "tensor" and "sample_rate" |
| - status_message: Status message |
| - extra_outputs: Dictionary with latents, masks, time_costs, etc. |
| - success: Whether generation completed successfully |
| - error: Error message if generation failed |
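| Example (illustrative sketch; assumes an initialized handler): |
| result = handler.generate_music( |
| captions="uplifting electronic pop with female vocals", |
| lyrics="", |
| inference_steps=8, |
| batch_size=1, |
| ) |
| if result["success"]: |
| audio = result["audios"][0] |
| # audio["tensor"] is a [channels, samples] float tensor at |
| # audio["sample_rate"] Hz, e.g. |
| # torchaudio.save("out.wav", audio["tensor"], audio["sample_rate"]) |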
| """ |
| if progress is None: |
| def progress(*args, **kwargs): |
| pass |
|
|
| if self.model is None or self.vae is None or self.text_tokenizer is None or self.text_encoder is None: |
| return { |
| "audios": [], |
| "status_message": "❌ Model not fully initialized. Please initialize all components first.", |
| "extra_outputs": {}, |
| "success": False, |
| "error": "Model not fully initialized", |
| } |
|
|
| def _has_audio_codes(v: Union[str, List[str]]) -> bool: |
| if isinstance(v, list): |
| return any((x or "").strip() for x in v) |
| return bool(v and str(v).strip()) |
|
|
| |
| |
| |
| if task_type == "text2music": |
| if _has_audio_codes(audio_code_string): |
| |
| task_type = "cover" |
| |
| instruction = TASK_INSTRUCTIONS["cover"] |
|
|
| logger.info("[generate_music] Starting generation...") |
| if progress: |
| progress(0.51, desc="Preparing inputs...") |
| logger.info("[generate_music] Preparing inputs...") |
| |
| |
| self.current_offload_cost = 0.0 |
|
|
| |
| |
| actual_batch_size = batch_size if batch_size is not None else self.batch_size |
| actual_batch_size = max(1, actual_batch_size) |
|
|
| actual_seed_list, seed_value_for_ui = self.prepare_seeds(actual_batch_size, seed, use_random_seed) |
| |
| |
| if audio_duration is not None and float(audio_duration) <= 0: |
| audio_duration = None |
| |
| |
| if repainting_end is not None and float(repainting_end) < 0: |
| repainting_end = None |
| |
| try: |
| |
| refer_audios = None |
| if reference_audio is not None: |
| logger.info("[generate_music] Processing reference audio...") |
| processed_ref_audio = self.process_reference_audio(reference_audio) |
| if processed_ref_audio is not None: |
| |
| |
| refer_audios = [[processed_ref_audio] for _ in range(actual_batch_size)] |
| else: |
| refer_audios = [[torch.zeros(2, 30*self.sample_rate)] for _ in range(actual_batch_size)] |
| |
| |
| |
| processed_src_audio = None |
| if src_audio is not None: |
| |
| if _has_audio_codes(audio_code_string): |
| logger.info("[generate_music] Audio codes provided, ignoring src_audio and using codes instead") |
| else: |
| logger.info("[generate_music] Processing source audio...") |
| processed_src_audio = self.process_src_audio(src_audio) |
| |
| |
| captions_batch, instructions_batch, lyrics_batch, vocal_languages_batch, metas_batch = self.prepare_batch_data( |
| actual_batch_size, |
| processed_src_audio, |
| audio_duration, |
| captions, |
| lyrics, |
| vocal_language, |
| instruction, |
| bpm, |
| key_scale, |
| time_signature |
| ) |
| |
| is_repaint_task, is_lego_task, is_cover_task, can_use_repainting = self.determine_task_type(task_type, audio_code_string) |
| |
| repainting_start_batch, repainting_end_batch, target_wavs_tensor = self.prepare_padding_info( |
| actual_batch_size, |
| processed_src_audio, |
| audio_duration, |
| repainting_start, |
| repainting_end, |
| is_repaint_task, |
| is_lego_task, |
| is_cover_task, |
| can_use_repainting |
| ) |
| |
| progress(0.52, desc=f"Generating music (batch size: {actual_batch_size})...") |
| |
| |
| |
| audio_code_hints_batch = None |
| if _has_audio_codes(audio_code_string): |
| if isinstance(audio_code_string, list): |
| audio_code_hints_batch = audio_code_string |
| else: |
| audio_code_hints_batch = [audio_code_string] * actual_batch_size |
|
|
| should_return_intermediate = (task_type == "text2music") |
| outputs = self.service_generate( |
| captions=captions_batch, |
| lyrics=lyrics_batch, |
| metas=metas_batch, |
| vocal_languages=vocal_languages_batch, |
| refer_audios=refer_audios, |
| target_wavs=target_wavs_tensor, |
| infer_steps=inference_steps, |
| guidance_scale=guidance_scale, |
| seed=actual_seed_list, |
| repainting_start=repainting_start_batch, |
| repainting_end=repainting_end_batch, |
| instructions=instructions_batch, |
| audio_cover_strength=audio_cover_strength, |
| use_adg=use_adg, |
| cfg_interval_start=cfg_interval_start, |
| cfg_interval_end=cfg_interval_end, |
| shift=shift, |
| infer_method=infer_method, |
| audio_code_hints=audio_code_hints_batch, |
| return_intermediate=should_return_intermediate, |
| timesteps=timesteps, |
| ) |
| |
| logger.info("[generate_music] Model generation completed. Decoding latents...") |
| pred_latents = outputs["target_latents"] |
| time_costs = outputs["time_costs"] |
| time_costs["offload_time_cost"] = self.current_offload_cost |
| logger.debug(f"[generate_music] pred_latents: {pred_latents.shape}, dtype={pred_latents.dtype} {pred_latents.min()=}, {pred_latents.max()=}, {pred_latents.mean()=} {pred_latents.std()=}") |
| logger.debug(f"[generate_music] time_costs: {time_costs}") |
| if progress: |
| progress(0.8, desc="Decoding audio...") |
| logger.info("[generate_music] Decoding latents with VAE...") |
| |
| |
| start_time = time.time() |
| with torch.no_grad(): |
| with self._load_model_context("vae"): |
| |
| pred_latents_for_decode = pred_latents.transpose(1, 2) |
| |
| pred_latents_for_decode = pred_latents_for_decode.to(self.vae.dtype) |
| |
| if use_tiled_decode: |
| logger.info("[generate_music] Using tiled VAE decode to reduce VRAM usage...") |
| pred_wavs = self.tiled_decode(pred_latents_for_decode) |
| else: |
| pred_wavs = self.vae.decode(pred_latents_for_decode).sample |
| |
| |
| pred_wavs = pred_wavs.to(torch.float32) |
| end_time = time.time() |
| time_costs["vae_decode_time_cost"] = end_time - start_time |
| time_costs["total_time_cost"] = time_costs["total_time_cost"] + time_costs["vae_decode_time_cost"] |
| |
| |
| time_costs["offload_time_cost"] = self.current_offload_cost |
| |
| logger.info("[generate_music] VAE decode completed. Preparing audio tensors...") |
| if progress: |
| progress(0.99, desc="Preparing audio data...") |
| |
| |
| |
| |
| audio_tensors = [] |
| |
| for i in range(actual_batch_size): |
| |
| audio_tensor = pred_wavs[i].cpu().float() |
| audio_tensors.append(audio_tensor) |
| |
| status_message = "✅ Generation completed successfully!" |
| logger.info(f"[generate_music] Done! Generated {len(audio_tensors)} audio tensors.") |
| |
| |
| src_latents = outputs.get("src_latents") |
| target_latents_input = outputs.get("target_latents_input") |
| chunk_masks = outputs.get("chunk_masks") |
| spans = outputs.get("spans", []) |
| latent_masks = outputs.get("latent_masks") |
| |
| |
| encoder_hidden_states = outputs.get("encoder_hidden_states") |
| encoder_attention_mask = outputs.get("encoder_attention_mask") |
| context_latents = outputs.get("context_latents") |
| lyric_token_idss = outputs.get("lyric_token_idss") |
| |
| |
| extra_outputs = { |
| "pred_latents": pred_latents.detach().cpu() if pred_latents is not None else None, |
| "target_latents": target_latents_input.detach().cpu() if target_latents_input is not None else None, |
| "src_latents": src_latents.detach().cpu() if src_latents is not None else None, |
| "chunk_masks": chunk_masks.detach().cpu() if chunk_masks is not None else None, |
| "latent_masks": latent_masks.detach().cpu() if latent_masks is not None else None, |
| "spans": spans, |
| "time_costs": time_costs, |
| "seed_value": seed_value_for_ui, |
| |
| "encoder_hidden_states": encoder_hidden_states.detach().cpu() if encoder_hidden_states is not None else None, |
| "encoder_attention_mask": encoder_attention_mask.detach().cpu() if encoder_attention_mask is not None else None, |
| "context_latents": context_latents.detach().cpu() if context_latents is not None else None, |
| "lyric_token_idss": lyric_token_idss.detach().cpu() if lyric_token_idss is not None else None, |
| } |
| |
| |
| audios = [] |
| for idx, audio_tensor in enumerate(audio_tensors): |
| audio_dict = { |
| "tensor": audio_tensor, |
| "sample_rate": self.sample_rate, |
| } |
| audios.append(audio_dict) |
| |
| return { |
| "audios": audios, |
| "status_message": status_message, |
| "extra_outputs": extra_outputs, |
| "success": True, |
| "error": None, |
| } |
|
|
| except Exception as e: |
| error_msg = f"❌ Error: {str(e)}\n{traceback.format_exc()}" |
| logger.exception("[generate_music] Generation failed") |
| return { |
| "audios": [], |
| "status_message": error_msg, |
| "extra_outputs": {}, |
| "success": False, |
| "error": str(e), |
| } |
|
|
| @torch.no_grad() |
| def get_lyric_timestamp( |
| self, |
| pred_latent: torch.Tensor, |
| encoder_hidden_states: torch.Tensor, |
| encoder_attention_mask: torch.Tensor, |
| context_latents: torch.Tensor, |
| lyric_token_ids: torch.Tensor, |
| total_duration_seconds: float, |
| vocal_language: str = "en", |
| inference_steps: int = 8, |
| seed: int = 42, |
| custom_layers_config: Optional[Dict] = None, |
| ) -> Dict[str, Any]: |
| """ |
| Generate lyrics timestamps from generated audio latents using cross-attention alignment. |
| |
| This method adds noise to the final pred_latent and re-infers one step to get |
| cross-attention matrices, then uses DTW to align lyrics tokens with audio frames. |
| |
| Args: |
| pred_latent: Generated latent tensor [batch, T, D] |
| encoder_hidden_states: Cached encoder hidden states |
| encoder_attention_mask: Cached encoder attention mask |
| context_latents: Cached context latents |
| lyric_token_ids: Tokenized lyrics tensor [batch, seq_len] |
| total_duration_seconds: Total audio duration in seconds |
| vocal_language: Language code for lyrics header parsing |
| inference_steps: Number of inference steps (for noise level calculation) |
| seed: Random seed for noise generation |
| custom_layers_config: Dict mapping layer indices to head indices |
| |
| Returns: |
| Dict containing: |
| - lrc_text: LRC formatted lyrics with timestamps |
| - sentence_timestamps: List of SentenceTimestamp objects |
| - token_timestamps: List of TokenTimestamp objects |
| - success: Whether generation succeeded |
| - error: Error message if failed |
| """ |
| from transformers.cache_utils import EncoderDecoderCache, DynamicCache |
| |
| if self.model is None: |
| return { |
| "lrc_text": "", |
| "sentence_timestamps": [], |
| "token_timestamps": [], |
| "success": False, |
| "error": "Model not initialized" |
| } |
| |
| if custom_layers_config is None: |
| custom_layers_config = self.custom_layers_config |
| |
| try: |
| |
| device = self.device |
| dtype = self.dtype |
| |
| pred_latent = pred_latent.to(device=device, dtype=dtype) |
| encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype) |
| encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype) |
| context_latents = context_latents.to(device=device, dtype=dtype) |
| |
| bsz = pred_latent.shape[0] |
| |
| |
| t_last_val = 1.0 / inference_steps |
| t_curr_tensor = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype) |
| |
| x1 = pred_latent |
| |
| |
| if seed is None: |
| x0 = torch.randn_like(x1) |
| else: |
| generator = torch.Generator(device=device).manual_seed(int(seed)) |
| x0 = torch.randn(x1.shape, generator=generator, device=device, dtype=dtype) |
| |
| |
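| # Re-noise the generated latent to the last timestep t = 1/inference_steps |
| # along the flow-matching path x_t = t * x0 + (1 - t) * x1, where x0 is |
| # noise and x1 is the predicted latent. |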
| xt = t_last_val * x0 + (1.0 - t_last_val) * x1 |
|
|
| xt_in = xt |
| t_in = t_curr_tensor |
| |
| |
| encoder_hidden_states_in = encoder_hidden_states |
| encoder_attention_mask_in = encoder_attention_mask |
| context_latents_in = context_latents |
| latent_length = x1.shape[1] |
| attention_mask = torch.ones(bsz, latent_length, device=device, dtype=dtype) |
| attention_mask_in = attention_mask |
| past_key_values = None |
| |
| |
| with self._load_model_context("model"): |
| decoder = self.model.decoder |
| decoder_outputs = decoder( |
| hidden_states=xt_in, |
| timestep=t_in, |
| timestep_r=t_in, |
| attention_mask=attention_mask_in, |
| encoder_hidden_states=encoder_hidden_states_in, |
| use_cache=False, |
| past_key_values=past_key_values, |
| encoder_attention_mask=encoder_attention_mask_in, |
| context_latents=context_latents_in, |
| output_attentions=True, |
| custom_layers_config=custom_layers_config, |
| enable_early_exit=True |
| ) |
| |
| |
| if decoder_outputs[2] is None: |
| return { |
| "lrc_text": "", |
| "sentence_timestamps": [], |
| "token_timestamps": [], |
| "success": False, |
| "error": "Model did not return attentions" |
| } |
| |
| cross_attns = decoder_outputs[2] |
| |
| captured_layers_list = [] |
| for layer_attn in cross_attns: |
| |
| if layer_attn is None: |
| continue |
| |
| cond_attn = layer_attn[:bsz] |
| layer_matrix = cond_attn.transpose(-1, -2) |
| captured_layers_list.append(layer_matrix) |
| |
| if not captured_layers_list: |
| return { |
| "lrc_text": "", |
| "sentence_timestamps": [], |
| "token_timestamps": [], |
| "success": False, |
| "error": "No valid attention layers returned" |
| } |
| |
| stacked = torch.stack(captured_layers_list) |
| if bsz == 1: |
| all_layers_matrix = stacked.squeeze(1) |
| else: |
| all_layers_matrix = stacked |
| |
| |
| if isinstance(lyric_token_ids, torch.Tensor): |
| raw_lyric_ids = lyric_token_ids[0].tolist() |
| else: |
| raw_lyric_ids = lyric_token_ids |
| |
| |
| header_str = f"# Languages\n{vocal_language}\n\n# Lyric\n" |
| header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False) |
| start_idx = len(header_ids) |
| |
| |
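| # 151643 is assumed to be the tokenizer's end-of-text token id (Qwen-style |
| # tokenizers); tokens after it are padding and are excluded from alignment. |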
| try: |
| end_idx = raw_lyric_ids.index(151643) |
| except ValueError: |
| end_idx = len(raw_lyric_ids) |
| |
| pure_lyric_ids = raw_lyric_ids[start_idx:end_idx] |
| pure_lyric_matrix = all_layers_matrix[:, :, start_idx:end_idx, :] |
| |
| |
| aligner = MusicStampsAligner(self.text_tokenizer) |
| |
| align_info = aligner.stamps_align_info( |
| attention_matrix=pure_lyric_matrix, |
| lyrics_tokens=pure_lyric_ids, |
| total_duration_seconds=total_duration_seconds, |
| custom_config=custom_layers_config, |
| return_matrices=False, |
| violence_level=2.0, |
| medfilt_width=1, |
| ) |
| |
| if align_info.get("calc_matrix") is None: |
| return { |
| "lrc_text": "", |
| "sentence_timestamps": [], |
| "token_timestamps": [], |
| "success": False, |
| "error": align_info.get("error", "Failed to process attention matrix") |
| } |
| |
| |
| result = aligner.get_timestamps_and_lrc( |
| calc_matrix=align_info["calc_matrix"], |
| lyrics_tokens=pure_lyric_ids, |
| total_duration_seconds=total_duration_seconds |
| ) |
| |
| return { |
| "lrc_text": result["lrc_text"], |
| "sentence_timestamps": result["sentence_timestamps"], |
| "token_timestamps": result["token_timestamps"], |
| "success": True, |
| "error": None |
| } |
| |
| except Exception as e: |
| error_msg = f"Error generating timestamps: {str(e)}" |
| logger.exception("[get_lyric_timestamp] Failed") |
| return { |
| "lrc_text": "", |
| "sentence_timestamps": [], |
| "token_timestamps": [], |
| "success": False, |
| "error": error_msg |
| } |
|
|
| @torch.no_grad() |
| def get_lyric_score( |
| self, |
| pred_latent: torch.Tensor, |
| encoder_hidden_states: torch.Tensor, |
| encoder_attention_mask: torch.Tensor, |
| context_latents: torch.Tensor, |
| lyric_token_ids: torch.Tensor, |
| vocal_language: str = "en", |
| inference_steps: int = 8, |
| seed: int = 42, |
| custom_layers_config: Optional[Dict] = None, |
| ) -> Dict[str, Any]: |
| """ |
| Calculate both LM and DiT alignment scores in one pass. |
| |
| - lm_score: Checks structural alignment using pure noise at t=1.0. |
| - dit_score: Checks denoising alignment using regressed latents at t=1/steps. |
| |
| Args: |
| pred_latent: Generated latent tensor [batch, T, D] |
| encoder_hidden_states: Cached encoder hidden states |
| encoder_attention_mask: Cached encoder attention mask |
| context_latents: Cached context latents |
| lyric_token_ids: Tokenized lyrics tensor [batch, seq_len] |
| vocal_language: Language code for lyrics header parsing |
| inference_steps: Number of inference steps (for noise level calculation) |
| seed: Random seed for noise generation |
| custom_layers_config: Dict mapping layer indices to head indices |
| |
| Returns: |
| Dict containing: |
| - lm_score: float |
| - dit_score: float |
| - success: Whether generation succeeded |
| - error: Error message if failed |
| """ |
| from transformers.cache_utils import EncoderDecoderCache, DynamicCache |
|
|
| if self.model is None: |
| return { |
| "lm_score": 0.0, |
| "dit_score": 0.0, |
| "success": False, |
| "error": "Model not initialized" |
| } |
|
|
| if custom_layers_config is None: |
| custom_layers_config = self.custom_layers_config |
|
|
| try: |
| |
| device = self.device |
| dtype = self.dtype |
|
|
| pred_latent = pred_latent.to(device=device, dtype=dtype) |
| encoder_hidden_states = encoder_hidden_states.to(device=device, dtype=dtype) |
| encoder_attention_mask = encoder_attention_mask.to(device=device, dtype=dtype) |
| context_latents = context_latents.to(device=device, dtype=dtype) |
|
|
| bsz = pred_latent.shape[0] |
|
|
| if seed is None: |
| x0 = torch.randn_like(pred_latent) |
| else: |
| generator = torch.Generator(device=device).manual_seed(int(seed)) |
| x0 = torch.randn(pred_latent.shape, generator=generator, device=device, dtype=dtype) |
|
|
| |
| |
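| # LM-score branch: pure noise at t = 1.0 (structural alignment from the |
| # conditioning alone). |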
| t_lm = torch.tensor([1.0] * bsz, device=device, dtype=dtype) |
| xt_lm = x0 |
|
|
| |
| |
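| # DiT-score branch: the generated latent re-noised to the final timestep |
| # t = 1 / inference_steps, the same interpolation used in get_lyric_timestamp. |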
| t_last_val = 1.0 / inference_steps |
| t_dit = torch.tensor([t_last_val] * bsz, device=device, dtype=dtype) |
| |
| xt_dit = t_last_val * x0 + (1.0 - t_last_val) * pred_latent |
|
|
| |
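| # Run both branches in a single decoder forward pass by concatenating them |
| # along the batch dimension: items [0, bsz) are the LM branch, items |
| # [bsz, 2*bsz) are the DiT branch. |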
| xt_in = torch.cat([xt_lm, xt_dit], dim=0) |
| t_in = torch.cat([t_lm, t_dit], dim=0) |
|
|
| |
| encoder_hidden_states_in = torch.cat([encoder_hidden_states, encoder_hidden_states], dim=0) |
| encoder_attention_mask_in = torch.cat([encoder_attention_mask, encoder_attention_mask], dim=0) |
| context_latents_in = torch.cat([context_latents, context_latents], dim=0) |
|
|
| |
| latent_length = xt_in.shape[1] |
| attention_mask_in = torch.ones(2 * bsz, latent_length, device=device, dtype=dtype) |
| past_key_values = None |
|
|
| |
| with self._load_model_context("model"): |
| decoder = self.model.decoder |
| if hasattr(decoder, 'eval'): |
| decoder.eval() |
|
|
| decoder_outputs = decoder( |
| hidden_states=xt_in, |
| timestep=t_in, |
| timestep_r=t_in, |
| attention_mask=attention_mask_in, |
| encoder_hidden_states=encoder_hidden_states_in, |
| use_cache=False, |
| past_key_values=past_key_values, |
| encoder_attention_mask=encoder_attention_mask_in, |
| context_latents=context_latents_in, |
| output_attentions=True, |
| custom_layers_config=custom_layers_config, |
| enable_early_exit=True |
| ) |
|
|
| |
| if decoder_outputs[2] is None: |
| return { |
| "lm_score": 0.0, |
| "dit_score": 0.0, |
| "success": False, |
| "error": "Model did not return attentions" |
| } |
|
|
| cross_attns = decoder_outputs[2] |
|
|
| captured_layers_list = [] |
| for layer_attn in cross_attns: |
| if layer_attn is None: |
| continue |
|
|
| |
| layer_matrix = layer_attn.transpose(-1, -2) |
| captured_layers_list.append(layer_matrix) |
|
|
| if not captured_layers_list: |
| return { |
| "lm_score": 0.0, |
| "dit_score": 0.0, |
| "success": False, |
| "error": "No valid attention layers returned" |
| } |
|
|
| stacked = torch.stack(captured_layers_list) |
|
|
| all_layers_matrix_lm = stacked[:, :bsz, ...] |
| all_layers_matrix_dit = stacked[:, bsz:, ...] |
|
|
| if bsz == 1: |
| all_layers_matrix_lm = all_layers_matrix_lm.squeeze(1) |
| all_layers_matrix_dit = all_layers_matrix_dit.squeeze(1) |
|
|
| |
| if isinstance(lyric_token_ids, torch.Tensor): |
| raw_lyric_ids = lyric_token_ids[0].tolist() |
| else: |
| raw_lyric_ids = lyric_token_ids |
|
|
| |
| header_str = f"# Languages\n{vocal_language}\n\n# Lyric\n" |
| header_ids = self.text_tokenizer.encode(header_str, add_special_tokens=False) |
| start_idx = len(header_ids) |
|
|
| |
| try: |
| end_idx = raw_lyric_ids.index(151643) |
| except ValueError: |
| end_idx = len(raw_lyric_ids) |
|
|
| pure_lyric_ids = raw_lyric_ids[start_idx:end_idx] |
| if start_idx >= all_layers_matrix_lm.shape[-2]: |
| return { |
| "lm_score": 0.0, |
| "dit_score": 0.0, |
| "success": False, |
| "error": "Lyrics indices out of bounds" |
| } |
|
|
| pure_matrix_lm = all_layers_matrix_lm[..., start_idx:end_idx, :] |
| pure_matrix_dit = all_layers_matrix_dit[..., start_idx:end_idx, :] |
|
|
| |
| aligner = MusicLyricScorer(self.text_tokenizer) |
|
|
| def calculate_single_score(matrix): |
| """Helper to run aligner on a matrix""" |
| info = aligner.lyrics_alignment_info( |
| attention_matrix=matrix, |
| token_ids=pure_lyric_ids, |
| custom_config=custom_layers_config, |
| return_matrices=False, |
| medfilt_width=1, |
| ) |
| if info.get("energy_matrix") is None: |
| return 0.0 |
|
|
| res = aligner.calculate_score( |
| energy_matrix=info["energy_matrix"], |
| type_mask=info["type_mask"], |
| path_coords=info["path_coords"], |
| ) |
| |
| return res.get("lyrics_score", res.get("final_score", 0.0)) |
|
|
| lm_score = calculate_single_score(pure_matrix_lm) |
| dit_score = calculate_single_score(pure_matrix_dit) |
|
|
| return { |
| "lm_score": lm_score, |
| "dit_score": dit_score, |
| "success": True, |
| "error": None |
| } |
|
|
| except Exception as e: |
| error_msg = f"Error generating score: {str(e)}" |
| logger.exception("[get_lyric_score] Failed") |
| return { |
| "lm_score": 0.0, |
| "dit_score": 0.0, |
| "success": False, |
| "error": error_msg |
| } |