| """Processor class for MossSpeech.""" |
|
|
| from __future__ import annotations |
|
|
| import os |
| import re |
| from dataclasses import asdict, dataclass |
| from typing import Any, Mapping, Optional, Sequence, Union |
|
|
| from transformers import AutoTokenizer, AutoModel |
| from transformers.processing_utils import ProcessingKwargs, ProcessorMixin |
| from transformers.tokenization_utils_base import BatchEncoding |
| from transformers.utils import is_torch_available, logging |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
| if is_torch_available(): |
| import torch |
| else: |
| torch = None |
|
|
| _TEXT_PLACEHOLDER_TOKEN_ID = 151667 |
| _AUDIO_PAD_TOKEN_ID = 512 |
| _SOSP_TOKEN_ID = 151646 |
| _EOSP_TOKEN_ID = 16384 |
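
# Grid layout sketch (token values other than the special ids above are
# illustrative, not real vocabulary entries): for a two-token text segment
# followed by a three-token audio segment, `_MossSpeechInputSegment.to_tensor`
# produces a 2 x T grid like
#
#   text  row: [  t0,  t1, 151646, 151667, 151667, 151667, 151667 ]
#   audio row: [ 512, 512,    512,     87,    341,    129,  16384 ]
#
# where t0/t1 stand for tokenizer ids and 87/341/129 for made-up codec ids:
# SOSP opens the speech span in the text row, EOSP closes it in the audio row.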


class MossSpeechProcessorKwargs(ProcessingKwargs, total=False):
    """Default keyword argument groups supported by :class:`MossSpeechProcessor`."""

    _defaults = {
        "common_kwargs": {
            "return_tensors": "pt",
            "padding": True,
        }
    }
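
# Callers can override these defaults per call (e.g. `processor(...,
# return_tensors="pt")`), but note that `padding=False` is rejected in
# `__call__` below: only padded batches are currently supported.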


@dataclass
class MossSpeechChatSample:
    """Per-sample MossSpeech input with a two-channel token grid."""

    input_ids_2d: "torch.LongTensor"
    label_ids_2d: Optional["torch.LongTensor"] = None


@dataclass
class MossSpeechBatchInput:
    """Batched MossSpeech tensors returned by the processor."""

    input_ids: "torch.LongTensor"
    attention_mask: "torch.LongTensor"
    labels: Optional["torch.LongTensor"] = None


@dataclass
class MossSpeechResponse:
    """Decoded MossSpeech output item containing text and optional audio."""

    audio: Optional["torch.Tensor"] = None
    generated_text: str = ""
    sampling_rate: Optional[int] = None


@dataclass
class _MossSpeechInputSegment:
    """Internal helper representing either text or audio tokens."""

    text: Optional[str] = None
    audio_tokens: Optional["torch.Tensor"] = None
    tokenized_text: Optional["torch.Tensor"] = None

    def __post_init__(self) -> None:
        if self.text is None and self.tokenized_text is None:
            self.text = ""
        if self.audio_tokens is not None and self.audio_tokens.dim() != 1:
            raise ValueError("`audio_tokens` must be a 1D tensor of codec token ids.")

    def to_tensor(self, tokenizer: AutoTokenizer) -> "torch.Tensor":
        if self.tokenized_text is None:
            tokenized = tokenizer(
                self.text,
                return_tensors="pt",
                truncation=True,
                max_length=999999,  # effectively disables truncation
                padding=False,
                add_special_tokens=False,
            )["input_ids"].to(dtype=torch.long)
        else:
            tokenized = self.tokenized_text.unsqueeze(0).to(dtype=torch.long)

        # Text-only segment: pad the audio channel with the codec pad token.
        if self.audio_tokens is None:
            audio_channel = torch.full_like(tokenized, _AUDIO_PAD_TOKEN_ID)
            return torch.cat([tokenized, audio_channel], dim=0)

        # Audio segment: placeholders fill the text channel, and the span is
        # framed by SOSP (text channel) and EOSP (audio channel) markers.
        audio_tokens = self.audio_tokens.reshape(1, -1).to(dtype=torch.long)
        text_channel = torch.full((1, audio_tokens.shape[1]), _TEXT_PLACEHOLDER_TOKEN_ID, dtype=torch.long)

        sosp = torch.tensor([[_SOSP_TOKEN_ID]], dtype=torch.long)
        text_pad = torch.tensor([[_TEXT_PLACEHOLDER_TOKEN_ID]], dtype=torch.long)
        eosp = torch.tensor([[_EOSP_TOKEN_ID]], dtype=torch.long)
        audio_pad = torch.tensor([[_AUDIO_PAD_TOKEN_ID]], dtype=torch.long)

        text_channel = torch.cat([sosp, text_channel, text_pad], dim=1)
        audio_channel = torch.cat([audio_pad, audio_tokens, eosp], dim=1)
        return torch.cat([text_channel, audio_channel], dim=0)


class MossSpeechSampleProcessor:
    """Formats a structured conversation into MossSpeech grid tokens."""

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        audio_codec: AutoModel,
        default_system_prompts: Mapping[str, str],
    ) -> None:
        self.tokenizer = tokenizer
        self.audio_codec = audio_codec
        self.default_system_prompts = default_system_prompts

    def prepare_sample(
        self,
        conversation: Sequence[Mapping[str, Any]],
        output_modality: str,
    ) -> MossSpeechChatSample:
        if len(conversation) == 0:
            raise ValueError("`conversation` must contain at least one turn.")

        segments: list[_MossSpeechInputSegment] = []

        # When the conversation does not open with an explicit system turn,
        # prepend the default system prompt for the requested modality.
        if conversation[0].get("role") != "system":
            system_prompt = self.default_system_prompts.get(output_modality)
            if system_prompt is None:
                raise KeyError(f"Missing default system prompt for modality `{output_modality}`.")
            segments.extend(
                [
                    _MossSpeechInputSegment(text="<|im_start|>system\n"),
                    _MossSpeechInputSegment(text=system_prompt),
                    _MossSpeechInputSegment(text="<|im_end|>\n"),
                ]
            )

        for turn in conversation:
            role = turn.get("role")
            if role not in {"user", "assistant", "system"}:
                raise ValueError(f"Unsupported role `{role}` detected.")

            segments.append(_MossSpeechInputSegment(text=f"<|im_start|>{role}\n"))
            content = turn.get("content")
            if isinstance(content, Mapping):
                # Mapping content denotes an audio turn referenced by file path.
                audio_path = content.get("path")
                if audio_path is None:
                    raise ValueError("Audio turn content must include a `path` entry.")
                encoded = self.audio_codec.encode([audio_path])[0]
                segments.append(_MossSpeechInputSegment(audio_tokens=torch.tensor(encoded, dtype=torch.long)))
            else:
                segments.append(_MossSpeechInputSegment(text=str(content)))
            segments.append(_MossSpeechInputSegment(text="<|im_end|>\n"))

        # Open the assistant turn that generation will continue.
        if output_modality == "text":
            segments.append(_MossSpeechInputSegment(text="<|im_start|>assistant\n"))
        elif output_modality == "audio":
            segments.append(_MossSpeechInputSegment(text="<|im_start|>assistant\n<|object_ref_start|>"))
        else:
            raise NotImplementedError("Supported modalities are `text` and `audio`.")

        input_tensors = [segment.to_tensor(self.tokenizer) for segment in segments]
        input_ids = torch.cat(input_tensors, dim=1)
        return MossSpeechChatSample(input_ids_2d=input_ids)
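
    # Hedged usage sketch (assumes `sample_processor` is an instance of this
    # class; the wav path is a placeholder):
    #
    #     sample = sample_processor.prepare_sample(
    #         [{"role": "user", "content": {"path": "question.wav"}}],
    #         output_modality="audio",
    #     )
    #     sample.input_ids_2d.shape  # -> torch.Size([2, T])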

    def collate(
        self,
        samples: Sequence[MossSpeechChatSample],
        *,
        pad_token_id: int,
        audio_pad_token_id: int,
    ) -> MossSpeechBatchInput:
        if len(samples) == 0:
            raise ValueError("`samples` must not be empty.")

        channel_count = samples[0].input_ids_2d.shape[0]
        max_length = max(sample.input_ids_2d.shape[1] for sample in samples)

        padded_inputs: list["torch.Tensor"] = []
        attention_masks: list["torch.Tensor"] = []

        for sample in samples:
            seq_len = sample.input_ids_2d.shape[1]
            pad_len = max_length - seq_len

            # Left-pad: the audio channel uses the codec pad id, the text
            # channel the tokenizer pad id.
            pad_grid = torch.full((channel_count, pad_len), audio_pad_token_id, dtype=torch.long)
            pad_grid[0] = pad_token_id
            padded_inputs.append(torch.cat([pad_grid, sample.input_ids_2d], dim=1))

            attention_prefix = torch.zeros(pad_len, dtype=torch.long)
            attention_body = torch.ones(seq_len, dtype=torch.long)
            attention_masks.append(torch.cat([attention_prefix, attention_body], dim=0))

        # (batch, channels, seq) -> (batch, seq, channels)
        input_ids = torch.stack(padded_inputs).permute(0, 2, 1)
        attention_mask = torch.stack(attention_masks)

        return MossSpeechBatchInput(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=None,
        )
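
    # Padding sketch for two samples of lengths 3 and 5 (letters stand for
    # token columns, P for pad columns):
    #
    #     sample 0: [P, P, a, b, c]   attention_mask: [0, 0, 1, 1, 1]
    #     sample 1: [d, e, f, g, h]   attention_mask: [1, 1, 1, 1, 1]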


class MossSpeechProcessor(ProcessorMixin):
    r"""Combines the MossSpeech tokenizer and codec for unified text/audio processing."""

    tokenizer_class = "AutoTokenizer"
    audio_codec_class = "PreTrainedModel"
    attributes = ["tokenizer", "audio_codec"]

    def __init__(
        self,
        tokenizer,
        audio_codec,
        **kwargs,
    ) -> None:
        super().__init__(tokenizer=tokenizer, audio_codec=audio_codec, **kwargs)
        self.default_system_prompts = {
            "text": "You are a helpful assistant. Respond with text outputs.",
            "audio": "You are a helpful assistant. Respond with spoken outputs.",
        }
        self.sample_processor = MossSpeechSampleProcessor(
            tokenizer=self.tokenizer,
            audio_codec=self.audio_codec,
            default_system_prompts=self.default_system_prompts,
        )
        self.sosp_token_id = _SOSP_TOKEN_ID
        self.eosp_token_id = _EOSP_TOKEN_ID

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike[str]],
        trust_remote_code: bool = True,
        **kwargs: Any,
    ) -> "MossSpeechProcessor":
        """Load the tokenizer and the audio codec; `codec_path` is required."""
        kwargs.pop("_from_auto", None)
        codec_path = kwargs.pop("codec_path", None)
        if codec_path is None:
            raise ValueError("`codec_path` must be supplied to load the MossSpeech codec.")

        device = kwargs.pop("device", None) or "cpu"
        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path,
            trust_remote_code=trust_remote_code,
            **kwargs,
        )
        audio_codec = AutoModel.from_pretrained(codec_path, trust_remote_code=True).to(device)
        return cls(tokenizer=tokenizer, audio_codec=audio_codec)
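
    # Hedged loading sketch (the repo ids below are placeholders, not
    # published checkpoints):
    #
    #     processor = MossSpeechProcessor.from_pretrained(
    #         "org/moss-speech",
    #         codec_path="org/moss-speech-codec",
    #         device="cuda",
    #     )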

    def __call__(
        self,
        data: Union[Mapping[str, Any], Sequence[Sequence[Mapping[str, Any]]]],
        output_modalities: Union[str, Sequence[str]],
        **kwargs: Any,
    ) -> BatchEncoding:
        """Convert conversations into a padded two-channel batch."""
        # Normalize `data` to a list of conversations, where each conversation
        # is a sequence of turn mappings.
        if isinstance(data, Mapping):
            # A single turn mapping becomes a one-turn conversation.
            samples: list[Sequence[Mapping[str, Any]]] = [[data]]
        elif isinstance(data, Sequence):
            if len(data) == 0:
                raise ValueError("`data` must contain at least one sample.")
            if all(isinstance(turn, Mapping) for turn in data):
                samples = [data]
            else:
                samples = list(data)
        else:
            raise TypeError("`data` must be a conversation dictionary or a sequence of conversations.")

        if isinstance(output_modalities, str):
            output_modalities = [output_modalities] * len(samples)
        elif len(output_modalities) != len(samples):
            raise ValueError("`output_modalities` length must match the number of samples.")

        merged_kwargs = self._merge_kwargs(MossSpeechProcessorKwargs, **kwargs)
        common_kwargs = merged_kwargs["common_kwargs"]
        padding = common_kwargs.get("padding", True)
        if not padding:
            raise NotImplementedError("Only padded batches are currently supported.")
        return_tensors = common_kwargs.get("return_tensors", "pt")

        chat_samples = [
            self.sample_processor.prepare_sample(conversation, modality)
            for conversation, modality in zip(samples, output_modalities)
        ]

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            raise ValueError("Tokenizer must define `pad_token_id` for MossSpeech processing.")

        batch_inputs = self.sample_processor.collate(
            chat_samples,
            pad_token_id=pad_token_id,
            audio_pad_token_id=_AUDIO_PAD_TOKEN_ID,
        )
        payload = {key: value for key, value in asdict(batch_inputs).items() if value is not None}
        return BatchEncoding(payload, tensor_type=return_tensors)
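
    # Hedged call sketch (assumes `processor` is an instance of this class):
    #
    #     batch = processor(
    #         [{"role": "user", "content": "Read this aloud."}],
    #         output_modalities="audio",
    #     )
    #     batch["input_ids"].shape       # -> (1, T, 2)
    #     batch["attention_mask"].shape  # -> (1, T)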

    def decode(
        self,
        token_ids: "torch.Tensor",
        output_modalities: Union[str, Sequence[str]],
        *args: Any,
        **kwargs: Any,
    ) -> list[MossSpeechResponse]:
        """Decode generated token grids into text and optional waveforms."""
        if token_ids.dim() != 3:
            raise ValueError("`token_ids` must be shaped as (batch, sequence_length, channels).")

        if isinstance(output_modalities, str):
            output_modalities = [output_modalities] * token_ids.shape[0]
        elif len(output_modalities) != token_ids.shape[0]:
            raise ValueError("`output_modalities` length must equal the batch size.")

        if token_ids.shape[0] != 1:
            raise NotImplementedError("Batch decoding is not yet implemented for MossSpeech.")

        responses: list[MossSpeechResponse] = []
        for batch_index, modality in enumerate(output_modalities):
            tokens = token_ids[batch_index].int().cpu()
            # Normalize to channels-first layout: (2, sequence_length).
            if tokens.shape[0] != 2:
                if tokens.shape[-1] == 2:
                    tokens = tokens.transpose(0, 1)
                else:
                    raise ValueError(
                        "Decoded tensor must contain exactly two channels (text and audio)."
                    )

            # Audio generations continue an already-open speech span, so
            # restore the SOSP/pad column that the prompt contributed.
            if modality == "audio":
                prefix = torch.tensor([[_SOSP_TOKEN_ID], [_AUDIO_PAD_TOKEN_ID]], dtype=torch.long)
                tokens = torch.cat([prefix, tokens], dim=1)

            text_channel = tokens[0, :-1]
            audio_channel = tokens[1, :-1]
            decoded_text = (
                self.tokenizer.decode(text_channel, skip_special_tokens=True)
                .replace("<|empty|>", ".")
                .replace("<|end_empty|>", ":")
            )

            sosp_indices = (text_channel == self.sosp_token_id).nonzero(as_tuple=True)[0]
            eosp_indices = (audio_channel == self.eosp_token_id).nonzero(as_tuple=True)[0]

            waveform: Optional["torch.Tensor"] = None
            if len(sosp_indices) > 0:
                # The speech span runs from just after SOSP in the text channel
                # to EOSP in the audio channel (or the end of the sequence when
                # no EOSP was generated).
                start_idx = sosp_indices[0].item() + 1
                stop_idx = eosp_indices[0].item() if len(eosp_indices) > 0 else text_channel.shape[0]
                codec_tokens = audio_channel[start_idx:stop_idx]
                codec_tensor = codec_tokens.reshape(1, 1, -1).to(dtype=torch.long)

                prompt_path = kwargs.get("decoder_audio_prompt_path")
                if prompt_path is None:
                    raise ValueError("`decoder_audio_prompt_path` must be provided to decode audio outputs.")
                codec_output = self.audio_codec.decode(codec_tensor, prompt_speech=prompt_path)
                waveform = codec_output["syn_wav_list"][0].reshape(1, -1).detach().cpu()

            responses.append(
                MossSpeechResponse(
                    audio=waveform,
                    generated_text=decoded_text,
                    sampling_rate=24000 if waveform is not None else None,
                )
            )

        return responses
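
    # Hedged decode sketch (`generated` stands for a (1, T, 2) tensor from
    # `model.generate`, and the prompt wav path is a placeholder):
    #
    #     responses = processor.decode(
    #         generated,
    #         output_modalities="audio",
    #         decoder_audio_prompt_path="speaker_prompt.wav",
    #     )
    #     responses[0].generated_text  # transcript channel
    #     responses[0].audio           # waveform tensor or None
    #     responses[0].sampling_rate   # 24000 when audio is present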


__all__ = [
    "MossSpeechProcessor",
    "MossSpeechProcessorKwargs",
    "MossSpeechResponse",
]