import os
import re
from dataclasses import dataclass
from typing import Any, Dict, Final, List, Optional, Tuple, Union, cast

import torch
import torchaudio
from transformers import processing_utils

# Register the custom `audio_tokenizer` modality so that ProcessorMixin
# accepts a PreTrainedModel-typed attribute for it.
processing_utils.MODALITY_TO_BASE_CLASS_MAPPING["audio_tokenizer"] = "PreTrainedModel"

from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    BatchFeature,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    logging,
)

from .configuration_moss_tts import MossTTSDelayConfig


logger = logging.get_logger(__name__)


AUDIO_PLACEHOLDER: Final[str] = "<|audio|>"


@dataclass
class Message:
    def to_dict(self) -> Dict[str, Any]:
        raise NotImplementedError


@dataclass
class UserMessage(Message):
    text: Optional[str] = None
    reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None
    instruction: Optional[str] = None
    tokens: Optional[int] = None
    quality: Optional[str] = None
    sound_event: Optional[str] = None
    ambient_sound: Optional[str] = None
    language: Optional[str] = None

    def __post_init__(self):
        template = """<user_inst>
- Reference(s):
{reference}
- Instruction:
{instruction}
- Tokens:
{tokens}
- Quality:
{quality}
- Sound Event:
{sound_event}
- Ambient Sound:
{ambient_sound}
- Language:
{language}
- Text:
{text}
</user_inst>"""

        audio_codes_list = []
        if self.reference is None:
            reference = "None"
        elif isinstance(self.reference, list):
            reference = []
            for speaker_idx, speaker_reference in enumerate(self.reference):
                if speaker_reference is not None:
                    reference.append(f"[S{speaker_idx + 1}]:\n{AUDIO_PLACEHOLDER}")
            reference = "\n".join(reference)
            audio_codes_list = [
                speaker_reference
                for speaker_reference in self.reference
                if speaker_reference is not None
            ]
        else:
            raise TypeError("`reference` must be a list when it is not None.")

        # `str.replace` is used instead of `str.format` so that braces inside
        # user-provided text cannot break the template.
        content = (
            template.replace("{reference}", str(reference))
            .replace("{instruction}", str(self.instruction))
            .replace("{tokens}", str(self.tokens))
            .replace("{quality}", str(self.quality))
            .replace("{sound_event}", str(self.sound_event))
            .replace("{ambient_sound}", str(self.ambient_sound))
            .replace("{language}", str(self.language))
            .replace("{text}", str(self.text))
        )

        self._content = content
        self._audio_codes_list = audio_codes_list

    def to_dict(self):
        return {
            "role": "user",
            "content": self._content,
            "audio_codes_list": self._audio_codes_list,
        }


@dataclass
class AssistantMessage(Message):
    audio_codes_list: List[Union[str, torch.Tensor]]
    content: str = AUDIO_PLACEHOLDER

    def to_dict(self):
        return {
            "role": "assistant",
            "content": self.content,
            "audio_codes_list": self.audio_codes_list,
        }


USER_MESSAGE_FIELDS = (
    "text",
    "reference",
    "instruction",
    "tokens",
    "quality",
    "sound_event",
    "ambient_sound",
    "language",
)
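
# A minimal sketch of how these message dataclasses render (illustrative
# values only; `ref_codes` stands in for a (T, n_vq) tensor of codec ids):
#
#   >>> msg = UserMessage(text="Hello!", reference=[ref_codes], language="en")
#   >>> d = msg.to_dict()
#   >>> d["role"]
#   'user'
#   >>> AUDIO_PLACEHOLDER in d["content"]   # the reference slot becomes a placeholder
#   True
#   >>> d["audio_codes_list"] == [ref_codes]
#   True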


class MossTTSDelayProcessor(ProcessorMixin):
    tokenizer_class = "AutoTokenizer"
    audio_tokenizer_class = "AutoModel"

    tokenizer: PreTrainedTokenizerBase
    audio_tokenizer: Any

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        audio_tokenizer: Any = None,
        model_config: Optional[MossTTSDelayConfig] = None,
        **kwargs,
    ):
        super().__init__(tokenizer=tokenizer, audio_tokenizer=audio_tokenizer, **kwargs)

        self.tokenizer = tokenizer
        self.audio_tokenizer = audio_tokenizer
        if model_config is None:
            model_config = MossTTSDelayConfig()
        self.model_config = model_config

        self.imstart_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
        self.imend_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
        # Hard-coded newline token id for the paired tokenizer.
        self.newline_token_id = 198

        def _id_to_token(token_id: int) -> str:
            tok = tokenizer.convert_ids_to_tokens(int(token_id))
            if isinstance(tok, list):
                return tok[0] if len(tok) > 0 else ""
            return cast(str, tok)

        self.audio_user_slot_token = _id_to_token(
            self.model_config.audio_user_slot_token_id
        )
        self.audio_assistant_gen_slot_token = _id_to_token(
            self.model_config.audio_assistant_gen_slot_token_id
        )
        self.audio_assistant_delay_slot_token = _id_to_token(
            self.model_config.audio_assistant_delay_slot_token_id
        )
        self.audio_start_token = _id_to_token(self.model_config.audio_start_token_id)
        self.audio_end_token = _id_to_token(self.model_config.audio_end_token_id)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        trust_remote_code = kwargs.pop("trust_remote_code", True)
        kwargs.pop("_from_auto", None)

        audio_tokenizer_name_or_path = kwargs.pop(
            "codec_path", "OpenMOSS-Team/MOSS-Audio-Tokenizer"
        )

        model_config = cast(
            MossTTSDelayConfig,
            AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                *args,
                trust_remote_code=trust_remote_code,
                **kwargs,
            ),
        )
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            *args,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
        audio_tokenizer = AutoModel.from_pretrained(
            audio_tokenizer_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )

        return cls(
            tokenizer=tokenizer,
            audio_tokenizer=audio_tokenizer,
            model_config=model_config,
            **kwargs,
        )
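
    # A minimal loading sketch (the checkpoint path is a hypothetical
    # placeholder; `codec_path` defaults to the hub id above):
    #
    #   >>> processor = MossTTSDelayProcessor.from_pretrained(
    #   ...     "path/to/moss-tts-delay-checkpoint",
    #   ...     codec_path="OpenMOSS-Team/MOSS-Audio-Tokenizer",
    #   ... )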

    def __call__(self, *args, **kwargs) -> BatchFeature:
        conversations = args[0] if len(args) > 0 else kwargs.pop("conversations")
        mode: str = kwargs.pop("mode", "generation")
        apply_chat_template: bool = kwargs.pop("apply_chat_template", True)
        n_vq: Optional[int] = kwargs.pop("n_vq", None)

        # Tokenizer-style kwargs are accepted but ignored; padding is handled
        # by `_pad` below.
        kwargs.pop("return_tensors", None)
        kwargs.pop("padding", None)
        kwargs.pop("truncation", None)

        # `mode` only takes effect when a Message is converted to a dict.
        if mode not in {"generation", "continuation"}:
            raise ValueError(
                f"`mode` must be 'generation' or 'continuation', got {mode!r}."
            )

        if isinstance(conversations, (Message, dict)):
            conversations = [conversations]

        truncation = False
        if mode == "continuation":
            truncation = True

        input_ids_list = []
        for conversation in conversations:
            if isinstance(conversation, (Message, dict)):
                conversation = [conversation]

            conversation = [self._normalize_message(m) for m in conversation]

            if (mode == "generation") ^ (len(conversation) % 2 != 0):
                raise ValueError(
                    "'generation' mode expects an odd number of messages "
                    "(ending with a user turn); 'continuation' mode expects "
                    "an even number."
                )

            if (mode == "generation") ^ (conversation[-1]["role"] == "user"):
                raise ValueError(
                    "'generation' mode expects the last message to come from "
                    "the user; 'continuation' mode expects it to come from "
                    "the assistant."
                )

            unified_codes = []
            for message_idx, message in enumerate(conversation):
                if apply_chat_template:
                    add_generation_prompt = (
                        mode == "generation" and message_idx == len(conversation) - 1
                    )
                    try:
                        content = self.tokenizer.apply_chat_template(
                            [{"role": message["role"], "content": message["content"]}],
                            add_generation_prompt=add_generation_prompt,
                            tokenize=False,
                        )
                    except TypeError:
                        try:
                            content = self.tokenizer.apply_chat_template(
                                [
                                    {
                                        "role": message["role"],
                                        "content": message["content"],
                                    }
                                ],
                                add_generation_prompt=add_generation_prompt,
                            )
                        except Exception:
                            logger.warning(
                                "apply_chat_template failed; falling back to raw content."
                            )
                            content = message["content"]
                else:
                    content = message["content"]

                if not isinstance(content, str):
                    content = str(content)

                # Audio items may be pre-tokenized codes (tensors) or paths;
                # paths are batched through `encode_audios_from_path` while
                # preserving their original positions.
                raw_audio_items = message.get("audio_codes_list", [])

                audio_codes_list: List[torch.Tensor] = []
                if len(raw_audio_items) > 0:
                    encoded_items: List[Optional[torch.Tensor]] = [None] * len(
                        raw_audio_items
                    )
                    paths: List[str] = []
                    path_positions: List[int] = []

                    for idx, item in enumerate(raw_audio_items):
                        if isinstance(item, torch.Tensor):
                            if n_vq is not None and item.shape[1] != n_vq:
                                raise RuntimeError(
                                    "The audio codes' n_vq does not equal the "
                                    "`n_vq` parameter. You can set `n_vq` to "
                                    "None if the wavs are already tokenized."
                                )
                            encoded_items[idx] = item
                            continue

                        if isinstance(item, (str, os.PathLike)):
                            paths.append(str(item))
                            path_positions.append(idx)
                            continue

                        raise TypeError(
                            "Each audio item must be a torch.Tensor of codes or a path-like string."
                        )

                    if len(paths) > 0:
                        encoded_from_paths = self.encode_audios_from_path(paths, n_vq)
                        if len(encoded_from_paths) != len(paths):
                            raise RuntimeError(
                                "encode_audios_from_path returned an unexpected number of items."
                            )
                        for pos, codes in zip(path_positions, encoded_from_paths):
                            encoded_items[pos] = codes

                    audio_codes_list = [cast(torch.Tensor, t) for t in encoded_items]

                unified_codes.append(
                    self._get_unified_codes(
                        message["role"], content, audio_codes_list, truncation
                    )
                )

            unified_codes = torch.cat(unified_codes)
            input_ids_list.append(unified_codes)

        return BatchFeature(data=self._pad(input_ids_list))

    @staticmethod
    def build_user_message(
        text: Optional[str] = None,
        reference: Optional[List[Optional[Union[str, torch.Tensor]]]] = None,
        instruction: Optional[str] = None,
        tokens: Optional[int] = None,
        quality: Optional[str] = None,
        sound_event: Optional[str] = None,
        ambient_sound: Optional[str] = None,
        language: Optional[str] = None,
    ) -> Dict:
        if reference is not None and not isinstance(reference, list):
            reference = [reference]
        return UserMessage(
            text=text,
            reference=reference,
            instruction=instruction,
            tokens=tokens,
            quality=quality,
            sound_event=sound_event,
            ambient_sound=ambient_sound,
            language=language,
        ).to_dict()

    @staticmethod
    def build_assistant_message(
        audio_codes_list: List[Union[str, torch.Tensor]],
        content: str = AUDIO_PLACEHOLDER,
    ) -> Dict:
        return AssistantMessage(
            audio_codes_list=audio_codes_list,
            content=content,
        ).to_dict()
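
    # A minimal end-to-end input-preparation sketch (assumes `processor` was
    # loaded via `from_pretrained` and that "ref.wav" exists; both names are
    # illustrative):
    #
    #   >>> user = MossTTSDelayProcessor.build_user_message(
    #   ...     text="Hello there!", reference=["ref.wav"], language="en"
    #   ... )
    #   >>> batch = processor([user], mode="generation")
    #   >>> batch["input_ids"].shape       # (batch, seq_len, 1 + n_vq)
    #   >>> batch["attention_mask"].shape  # (batch, seq_len)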

    def _normalize_message(self, message: Union[Message, Dict]) -> Dict:
        if isinstance(message, Message):
            return message.to_dict()
        if not isinstance(message, dict):
            raise TypeError("Each message must be a Message or dict.")
        if "role" not in message:
            raise ValueError("Message dict must include a 'role' field.")
        if "content" in message and "audio_codes_list" in message:
            return message
        role = message["role"]
        if role == "user":
            kwargs = {key: message.get(key) for key in USER_MESSAGE_FIELDS}
            return self.build_user_message(**kwargs)
        if role == "assistant":
            return self.build_assistant_message(
                audio_codes_list=message.get("audio_codes_list", []),
                content=message.get("content", AUDIO_PLACEHOLDER),
            )
        raise ValueError(f"Unsupported role: {role}")

    def _pad(self, input_ids_list: List[torch.Tensor]):
        device = input_ids_list[0].device
        lengths = torch.tensor([w.shape[0] for w in input_ids_list], device=device)
        # Left-pad so that generation starts flush with the sequence end.
        # Note: `padding_side` requires a reasonably recent PyTorch.
        pad_input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids_list,
            batch_first=True,
            padding_value=self.model_config.audio_pad_code,
            padding_side="left",
        )
        # Mark the left-padded positions and overwrite their text channel
        # (channel 0) with the text pad token.
        other_channel_mask = (pad_input_ids.shape[1] - lengths).unsqueeze(
            1
        ) > torch.arange(pad_input_ids.shape[1], device=device).unsqueeze(0)
        pad_input_ids[..., 0][other_channel_mask] = self.model_config.pad_token_id
        attention_mask = torch.zeros(
            pad_input_ids.shape[0], pad_input_ids.shape[1], device=device
        )
        attention_mask[~other_channel_mask] = 1
        attention_mask = attention_mask.bool()
        return {
            "input_ids": pad_input_ids,
            "attention_mask": attention_mask,
        }
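
    # A toy left-padding example (illustrative shapes): with sequences of
    # lengths 2 and 3, the batch is padded to length 3, the padded step's
    # text channel is rewritten to `pad_token_id`, and the mask becomes:
    #
    #   lengths [2, 3]  ->  attention_mask [[False, True, True],
    #                                       [True,  True, True]]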

    @staticmethod
    def _replace_audio_placeholders(
        content: str,
        lengths: List[int],
        n_vq: int,
        gen_slot_token: str,
        delay_slot_token: str,
        audio_start_token: str,
        audio_end_token: str,
    ) -> str:
        if n_vq < 1:
            raise ValueError(f"n_vq must be >= 1, got {n_vq}")

        num_placeholders = content.count(AUDIO_PLACEHOLDER)
        if num_placeholders != len(lengths):
            raise ValueError(
                f"Number of {AUDIO_PLACEHOLDER} ({num_placeholders}) "
                f"does not match lengths ({len(lengths)})"
            )

        def build_audio_block(length: int) -> str:
            if length < 0:
                raise ValueError(f"length must be >= 0, got {length}")

            if length == 0:
                return f"{audio_start_token}{audio_end_token}"

            step_tokens = gen_slot_token * length + delay_slot_token * (n_vq - 1)
            return f"{audio_start_token}{step_tokens}{audio_end_token}"

        lengths_iter = iter(lengths)

        def replacer(_match: re.Match) -> str:
            length = next(lengths_iter)
            return build_audio_block(length)

        return re.sub(re.escape(AUDIO_PLACEHOLDER), replacer, content)
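
    # A sketch of the expansion, assuming start/end/slot tokens "<a>", "</a>",
    # "G" and "D" (stand-ins for the real special tokens) with length=3 and
    # n_vq=2:
    #
    #   "<|audio|>"  ->  "<a>GGGD</a>"
    #
    # i.e. `length` generation slots followed by `n_vq - 1` delay slots, which
    # account for the extra steps introduced by the delay pattern.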

    @staticmethod
    def _merge_consecutive_audio_placeholders(
        content: str,
        audio_codes_list: List[torch.Tensor],
    ) -> Tuple[str, List[torch.Tensor]]:
        matches = list(re.finditer(re.escape(AUDIO_PLACEHOLDER), content))
        if len(matches) <= 1:
            return content, audio_codes_list

        if len(matches) != len(audio_codes_list):
            raise ValueError(
                "Audio placeholders do not match the provided audio codes list."
            )

        new_audio_codes_list = []
        new_parts = []
        last_pos = 0
        i = 0
        while i < len(matches):
            # Extend the run while only whitespace separates the placeholders.
            j = i
            while (
                j + 1 < len(matches)
                and content[matches[j].end() : matches[j + 1].start()].strip() == ""
            ):
                j += 1

            new_parts.append(content[last_pos : matches[i].start()])
            new_parts.append(AUDIO_PLACEHOLDER)
            last_pos = matches[j].end()

            if j == i:
                new_audio_codes_list.append(audio_codes_list[i])
            else:
                new_audio_codes_list.append(
                    torch.cat(audio_codes_list[i : j + 1], dim=0)
                )

            i = j + 1

        new_parts.append(content[last_pos:])
        return "".join(new_parts), new_audio_codes_list

    @staticmethod
    def apply_delay_pattern(codes: torch.Tensor, pad_code: int) -> torch.Tensor:
        """Shift codebook channel ``i`` down by ``i`` steps, filling with ``pad_code``."""
        delayed_tokens = torch.full(
            (codes.shape[0] + codes.shape[1] - 1, codes.shape[1]),
            pad_code,
            device=codes.device,
            dtype=codes.dtype,
        )
        for i in range(codes.shape[1]):
            delayed_tokens[i : i + codes.shape[0], i] = codes[:, i]
        return delayed_tokens

    @staticmethod
    def apply_de_delay_pattern(delay_codes: torch.Tensor) -> torch.Tensor:
        """Invert :meth:`apply_delay_pattern`, re-aligning all codebook channels."""
        tokens = torch.full(
            (delay_codes.shape[0] - delay_codes.shape[1] + 1, delay_codes.shape[1]),
            0,
            device=delay_codes.device,
            dtype=delay_codes.dtype,
        )
        for i in range(delay_codes.shape[1]):
            tokens[:, i] = delay_codes[i : i + tokens.shape[0], i]
        return tokens
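
    # A worked round trip on a (T=3, n_vq=2) toy tensor with pad_code=0
    # (illustrative values only):
    #
    #   codes = [[1, 4],       apply_delay_pattern        [[1, 0],
    #            [2, 5],  ------------------------------>  [2, 4],
    #            [3, 6]]  (channel 1 delayed by one step)  [3, 5],
    #                                                      [0, 6]]
    #
    # `apply_de_delay_pattern` reads channel i starting at row i, recovering
    # the original (3, 2) tensor exactly.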

    def _get_unified_codes(
        self,
        role: str,
        content: str,
        audio_codes_list: List[torch.Tensor],
        truncation: bool,
    ) -> torch.Tensor:
        """
        At this point, `content` already carries the chat-template formatting.
        """
        if role == "user":
            audio_gen_slot_token = audio_delay_slot_token = self.audio_user_slot_token
            truncation = False
        else:
            audio_gen_slot_token = self.audio_assistant_gen_slot_token
            audio_delay_slot_token = self.audio_assistant_delay_slot_token

        if len(audio_codes_list):
            n_vq = audio_codes_list[0].shape[1]
        else:
            n_vq = self.model_config.n_vq

        if len(audio_codes_list) > 1 and AUDIO_PLACEHOLDER in content:
            content, audio_codes_list = self._merge_consecutive_audio_placeholders(
                content, audio_codes_list
            )
        content = self._replace_audio_placeholders(
            content=content,
            lengths=[len(audio_codes) for audio_codes in audio_codes_list],
            n_vq=n_vq,
            gen_slot_token=audio_gen_slot_token,
            delay_slot_token=audio_delay_slot_token,
            audio_start_token=self.audio_start_token,
            audio_end_token=self.audio_end_token,
        )
        text_codes = torch.tensor(
            self.tokenizer.encode(content),
            device=audio_codes_list[0].device if audio_codes_list else None,
        )

        audio_start_indices = torch.where(
            text_codes == self.model_config.audio_start_token_id
        )[0]
        audio_end_indices = torch.where(
            text_codes == self.model_config.audio_end_token_id
        )[0]
        if len(audio_start_indices) != len(audio_codes_list) or len(
            audio_end_indices
        ) != len(audio_codes_list):
            raise ValueError(
                "Audio placeholders do not match the provided audio codes list."
            )

        delay_audio_codes_list = []
        if len(audio_codes_list) == 0:
            delay_audio_codes_list.append(
                torch.full(
                    (len(text_codes), n_vq),
                    self.model_config.audio_pad_code,
                    device=text_codes.device,
                    dtype=text_codes.dtype,
                )
            )
        else:
            prefix_idx = 0
            for audio_start_idx_t, audio_end_idx_t, audio_codes in zip(
                audio_start_indices, audio_end_indices, audio_codes_list
            ):
                audio_start_idx = int(audio_start_idx_t.item())
                audio_end_idx = int(audio_end_idx_t.item())
                delay_audio_codes = self.apply_delay_pattern(
                    audio_codes, self.model_config.audio_pad_code
                )
                pad_codes = torch.full(
                    (audio_start_idx - prefix_idx + 1, n_vq),
                    self.model_config.audio_pad_code,
                    device=audio_codes.device,
                    dtype=audio_codes.dtype,
                )
                delay_audio_codes_list.extend([pad_codes, delay_audio_codes])
                prefix_idx = audio_end_idx

            if truncation:
                # Drop the trailing delay rows so generation can continue
                # mid-stream; with n_vq == 1 there is nothing to trim.
                if n_vq > 1:
                    delay_audio_codes_list[-1] = delay_audio_codes_list[-1][
                        : -(n_vq - 1), :
                    ]
            else:
                last_audio_end_idx = int(audio_end_indices[-1].item())
                pad_codes = torch.full(
                    (len(text_codes) - last_audio_end_idx, n_vq),
                    self.model_config.audio_pad_code,
                    device=audio_codes_list[0].device,
                    dtype=audio_codes_list[0].dtype,
                )
                delay_audio_codes_list.append(pad_codes)

        delay_audio_codes = torch.cat(delay_audio_codes_list)

        if text_codes.shape[0] != delay_audio_codes.shape[0]:
            text_codes = text_codes[: delay_audio_codes.shape[0]]

        unified_codes = torch.cat(
            [text_codes.unsqueeze(1), delay_audio_codes], dim=1
        )
        return unified_codes

    def _parse_text_codes(self, start_length, text_codes):
        text = cast(str, self.tokenizer.decode(text_codes))
        prefix = cast(str, self.tokenizer.decode(text_codes[:start_length]))
        text = text[len(prefix) :]

        # The special tokens may contain regex metacharacters (e.g. "|"), so
        # they must be escaped before being interpolated into the pattern.
        audio_pattern = re.compile(
            rf"(?:{re.escape(self.audio_start_token)})?"
            rf"(?:{re.escape(self.audio_assistant_gen_slot_token)})*"
            rf"(?:{re.escape(self.audio_assistant_delay_slot_token)})*"
            rf"{re.escape(self.audio_end_token)}"
        )

        def normalize_audio_segments(text: str) -> str:
            def repl(match: re.Match) -> str:
                seg = match.group(0)
                # A segment containing generation slots stands for real audio.
                if self.audio_assistant_gen_slot_token in seg:
                    return AUDIO_PLACEHOLDER
                # Otherwise it is pure delay/end bookkeeping; drop it.
                return ""

            return audio_pattern.sub(repl, text)

        return normalize_audio_segments(text)

    def _parse_audio_codes(self, start_length, audio_codes):
        # Undo the delay pattern so every row is a time-aligned code frame.
        audio_codes = self.apply_de_delay_pattern(audio_codes)

        # Rows that are all pad codes separate distinct audio segments.
        is_pad = (audio_codes == self.model_config.audio_pad_code).all(dim=1)
        non_pad = ~is_pad
        if not non_pad.any():
            return []

        idx = torch.nonzero(non_pad).squeeze(1)
        breaks = torch.where(idx[1:] != idx[:-1] + 1)[0] + 1
        if breaks.numel() == 0:
            segments_idx = [idx]
        else:
            segments_idx = torch.split(idx, breaks.tolist())

        audio_codes_list = [audio_codes[s] for s in segments_idx]

        decoded_audio_list = self.decode_audio_codes(audio_codes_list)

        # Trim the part of the first waveform that corresponds to the prompt
        # prefix, approximated proportionally by code count.
        if (
            start_length > 0
            and len(audio_codes_list) > 0
            and len(decoded_audio_list) > 0
        ):
            first_codes_length = audio_codes_list[0].shape[0]
            if first_codes_length > 0:
                trim_ratio = max(
                    0.0, min(float(start_length) / float(first_codes_length), 1.0)
                )
                first_audio = decoded_audio_list[0]
                if trim_ratio >= 1.0:
                    decoded_audio_list = decoded_audio_list[1:]
                elif trim_ratio > 0.0:
                    trim_samples = int(first_audio.shape[-1] * trim_ratio)
                    decoded_audio_list[0] = first_audio[..., trim_samples:]

        return decoded_audio_list

    def decode(self, output: List[Tuple[int, torch.Tensor]]):
        """
        1. A complete sequence of assistant generation ids is required here,
           regardless of how it was produced;
        2. Truncation from an arbitrary position is supported.
        """
        generated_messages = []
        for start_length, generation_ids in output:
            content = self._parse_text_codes(start_length, generation_ids[:, 0])
            audio_codes_list = self._parse_audio_codes(
                start_length, generation_ids[:, 1:]
            )
            if content == "":
                message = None
            else:
                message = AssistantMessage(
                    content=content,
                    audio_codes_list=cast(
                        List[Union[str, torch.Tensor]], audio_codes_list
                    ),
                )
            generated_messages.append(message)
        return generated_messages
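
    # A minimal decoding sketch (assumes `generation_ids` is a
    # (seq_len, 1 + n_vq) tensor from the model and `prompt_len` is the
    # prompt length; both names are illustrative):
    #
    #   >>> messages = processor.decode([(prompt_len, generation_ids)])
    #   >>> messages[0].content           # text with "<|audio|>" placeholders
    #   >>> messages[0].audio_codes_list  # list of decoded waveforms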

    @staticmethod
    def loudness_normalize(
        wav: torch.Tensor,
        target_dbfs: float = -20,
        gain_range: Tuple[float, float] = (-3.0, 3.0),
    ) -> torch.Tensor:
        wav = wav.to(torch.float32)
        if wav.numel() == 0:
            return wav
        # RMS level in dBFS; the epsilon guards against log10(0) on silence.
        current_dbfs = 10.0 * torch.log10(torch.mean(wav**2) + 1e-9)
        gain = float(target_dbfs - current_dbfs)
        gain = max(gain_range[0], min(gain, gain_range[1]))
        factor = 10.0 ** (gain / 20.0)
        return wav * factor
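
    # A worked example of the gain math (illustrative numbers): a waveform
    # with RMS 0.05 has mean(wav**2) = 0.0025, i.e. about -26.0 dBFS. The raw
    # gain toward -20 dBFS would be +6.0 dB, which the (-3, 3) range clamps
    # to +3 dB, so the signal is scaled by 10 ** (3 / 20) ≈ 1.41.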

    def _get_audio_tokenizer_device(self) -> torch.device:
        """Best-effort device inference for `self.audio_tokenizer`.

        Notes:
            - The old TAC wrapper exposed `.device`, but a standard `torch.nn.Module` does not.
            - The new MossAudioTokenizerModel is a `PreTrainedModel`; its parameters define its device.
        """
        audio_tokenizer = getattr(self, "audio_tokenizer", None)
        if audio_tokenizer is None:
            logger.warning(
                "audio_tokenizer is not set on processor. Using CPU as default."
            )
            return torch.device("cpu")

        device_attr = getattr(audio_tokenizer, "device", None)
        if isinstance(device_attr, torch.device):
            return device_attr

        try:
            return next(audio_tokenizer.parameters()).device
        except StopIteration:
            # A parameterless module; fall back to CPU.
            logger.warning(
                "No parameters found on audio_tokenizer. Using CPU as default."
            )
            return torch.device("cpu")

    def encode_audios_from_wav(
        self,
        wav_list: List[torch.Tensor],
        sampling_rate: int,
        n_vq: Optional[int] = None,
    ):
        if self.audio_tokenizer is None:
            raise RuntimeError("audio_tokenizer is not set on processor.")
        audio_tokenizer = self.audio_tokenizer

        if isinstance(wav_list, torch.Tensor):
            wav_list = [wav_list]
        wav_list_ = []
        resample = False
        if sampling_rate != self.model_config.sampling_rate:
            resample = True
        device = self._get_audio_tokenizer_device()
        for wav in wav_list:
            # Downmix multi-channel audio to mono before tokenization.
            if wav.shape[0] > 1:
                wav = torch.mean(wav, dim=0, keepdim=True)
            if resample:
                wav = torchaudio.functional.resample(
                    waveform=wav,
                    orig_freq=sampling_rate,
                    new_freq=self.model_config.sampling_rate,
                )
            wav = wav.to(device)
            wav_list_.append(self.loudness_normalize(wav.squeeze(0)))

        # Prefer the tokenizer's batched API when available.
        if hasattr(audio_tokenizer, "batch_encode"):
            enc = audio_tokenizer.batch_encode(wav_list_, num_quantizers=n_vq)
            audio_codes = enc.audio_codes
            audio_codes_lengths = enc.audio_codes_lengths
        else:
            # Fallback: zero-pad to a common length and encode with a mask.
            max_len = max(int(wav.shape[-1]) for wav in wav_list_)
            input_values = torch.zeros(
                len(wav_list_), 1, max_len, device=device, dtype=torch.float32
            )
            padding_mask = torch.zeros(
                len(wav_list_), max_len, device=device, dtype=torch.bool
            )
            for i, wav in enumerate(wav_list_):
                this_len = int(wav.shape[-1])
                input_values[i, 0, :this_len] = wav
                padding_mask[i, :this_len] = True
            enc = audio_tokenizer.encode(
                input_values,
                padding_mask=padding_mask,
                num_quantizers=n_vq,
                return_dict=True,
            )
            audio_codes = enc.audio_codes
            audio_codes_lengths = enc.audio_codes_lengths

        if audio_codes is None or audio_codes_lengths is None:
            raise RuntimeError(
                "audio_tokenizer.encode() returned empty outputs (audio_codes/audio_codes_lengths)."
            )

        # Slice each item to its true length and return (T, n_vq) code
        # tensors on the CPU.
        codes_list: List[torch.Tensor] = []
        for i in range(int(audio_codes.shape[1])):
            length_i = int(audio_codes_lengths[i].item())
            codes_i = (
                audio_codes[:, i, :length_i]
                .transpose(0, 1)
                .contiguous()
                .to(torch.long)
                .cpu()
            )
            codes_list.append(codes_i)
        return codes_list
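
    # A minimal encoding sketch (assumes "speech.wav" exists; the file name
    # is illustrative):
    #
    #   >>> wav, sr = torchaudio.load("speech.wav")
    #   >>> codes_list = processor.encode_audios_from_wav([wav], sr)
    #   >>> codes_list[0].shape   # (T, n_vq) long tensor of codec ids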

    def encode_audios_from_path(
        self, wav_path_list: Union[str, List[str]], n_vq: Optional[int] = None
    ):
        if isinstance(wav_path_list, str):
            wav_path_list = [wav_path_list]

        if len(wav_path_list) == 0:
            raise ValueError("Empty wav_path_list")

        # Load and resample everything to the model's sampling rate up front,
        # so the batched encode below sees a single rate.
        target_sr = int(self.model_config.sampling_rate)
        wav_list: List[torch.Tensor] = []
        for wav_path in wav_path_list:
            wav, sr = torchaudio.load(wav_path)
            if int(sr) != target_sr:
                wav = torchaudio.functional.resample(
                    waveform=wav,
                    orig_freq=int(sr),
                    new_freq=target_sr,
                )
            wav_list.append(wav)

        return self.encode_audios_from_wav(wav_list, target_sr, n_vq)

    def decode_audio_codes(
        self, audio_tokens_list: Union[torch.Tensor, List[torch.Tensor]]
    ):
        if self.audio_tokenizer is None:
            raise RuntimeError("audio_tokenizer is not set on processor.")
        audio_tokenizer = self.audio_tokenizer

        if isinstance(audio_tokens_list, torch.Tensor):
            audio_tokens_list = [audio_tokens_list]
        if len(audio_tokens_list) == 0:
            return []

        device = self._get_audio_tokenizer_device()

        # Move each (T, n_vq) item to (n_vq, T) layout on the tokenizer's device.
        codes_list = [
            codes.transpose(0, 1).contiguous().to(device=device, dtype=torch.long)
            for codes in audio_tokens_list
        ]

        # Zero-pad along time to a common length and decode as one batch.
        nq = int(codes_list[0].shape[0])
        max_t = max(int(c.shape[1]) for c in codes_list)
        audio_codes = torch.zeros(
            nq, len(codes_list), max_t, device=device, dtype=torch.long
        )
        padding_mask = torch.zeros(
            len(codes_list), max_t, device=device, dtype=torch.bool
        )
        for i, c in enumerate(codes_list):
            t = int(c.shape[1])
            audio_codes[:, i, :t] = c
            padding_mask[i, :t] = True
        dec = audio_tokenizer.decode(
            audio_codes, padding_mask=padding_mask, return_dict=True, chunk_duration=8
        )
        audio = dec.audio
        audio_lengths = dec.audio_lengths

        if audio is None or audio_lengths is None:
            raise RuntimeError(
                "audio_tokenizer.decode() returned empty outputs (audio/audio_lengths)."
            )

        # Strip the padding from each decoded waveform.
        wav_list: List[torch.Tensor] = []
        for i in range(int(audio.shape[0])):
            length_i = int(audio_lengths[i].item())
            wav = audio[i, 0, :length_i].contiguous().to(torch.float32).cpu()
            wav_list.append(wav)
        return wav_list
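
    # A minimal export sketch (assumes `codes` is a (T, n_vq) long tensor of
    # codec ids; the output file name is illustrative):
    #
    #   >>> wavs = processor.decode_audio_codes(codes)
    #   >>> torchaudio.save(
    #   ...     "out.wav",
    #   ...     wavs[0].unsqueeze(0),
    #   ...     processor.model_config.sampling_rate,
    #   ... )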