| | from typing import Optional, Union |
| |
|
| | import torch |
| | import transformers |
| | from transformers import ProcessorMixin |
| |
|
| | try: |
| | from .asr_config import ASRConfig |
| | except ImportError: |
| | from asr_config import ASRConfig |
| |
|
| |
|
class ASRProcessor(ProcessorMixin):
    """Processor for Whisper-based ASR models.

    Pairs a Whisper-style audio feature extractor with a language-model
    tokenizer.  Calling the processor turns raw audio into mel features and
    builds a chat-formatted prompt in which the audio clip is represented by
    a run of ``<audio>`` placeholder tokens, one per audio embedding the
    projector will emit.
    """

    attributes = ["feature_extractor", "tokenizer"]
    feature_extractor_class = "AutoFeatureExtractor"
    tokenizer_class = "AutoTokenizer"
    AUDIO_TOKEN = "<audio>"
    TRANSCRIBE_PROMPT = "Transcribe speech to text"

    # (padding, kernel_size, stride) for each conv layer in front of the
    # encoder; used to map mel-frame counts to encoder output lengths.
    DEFAULT_ENCODER_CONV_LAYERS = [(1, 3, 1), (1, 3, 2)]

    def __init__(
        self,
        feature_extractor,
        tokenizer,
        projector=None,
        encoder_conv_layers: Optional[list] = None,
    ):
        """Initialize the ASR processor.

        Args:
            feature_extractor: Audio feature extractor (WhisperFeatureExtractor).
            tokenizer: Text tokenizer for the language model.  Should already
                contain ``AUDIO_TOKEN`` in its vocabulary so that
                ``audio_token_id`` resolves to a real token id.
            projector: Audio projector module (for computing output lengths).
                Required when the processor is called with audio.
            encoder_conv_layers: Conv layer specs ``[(pad, kernel, stride), ...]``;
                falls back to ``DEFAULT_ENCODER_CONV_LAYERS`` when None/empty.
        """
        # NOTE(review): attributes are assigned manually rather than via
        # super().__init__(feature_extractor, tokenizer) — confirm that
        # save_pretrained / from_pretrained round-trips as expected.
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer
        self.audio_token_id = tokenizer.convert_tokens_to_ids(self.AUDIO_TOKEN)
        self.projector = projector
        self.encoder_conv_layers = encoder_conv_layers or self.DEFAULT_ENCODER_CONV_LAYERS

    def _compute_encoder_output_length(self, mel_length: int) -> int:
        """Compute encoder output length using conv layer formulas.

        Applies the standard 1-D convolution length formula
        ``L_out = (L_in + 2*pad - (kernel - 1) - 1) // stride + 1``
        once per (padding, kernel_size, stride) spec in ``encoder_conv_layers``.
        """
        length = mel_length
        for padding, kernel_size, stride in self.encoder_conv_layers:
            length = (length + 2 * padding - (kernel_size - 1) - 1) // stride + 1
        return length

    def __call__(
        self,
        audio: Optional[Union[list, "torch.Tensor"]] = None,
        text: Optional[str] = None,
        system_prompt: Optional[str] = None,
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict:
        """Process audio and text inputs for inference.

        Args:
            audio: Raw audio waveform(s)
            text: Target transcription (optional, for training - but use
                DataCollator instead)
            system_prompt: Optional system prompt
            return_tensors: Return format ("pt" for PyTorch)

        Returns:
            Dict with ``input_ids`` and ``attention_mask``; when ``audio`` is
            given, also ``input_features`` and ``audio_attention_mask``.

        Raises:
            ValueError: If ``audio`` is provided but the processor was built
                without a projector (the number of placeholder tokens cannot
                be computed).
        """
        result = {}

        if audio is not None:
            # Fail with a clear message instead of an opaque AttributeError
            # on self.projector.get_output_length below.
            if self.projector is None:
                raise ValueError(
                    "ASRProcessor was created without a projector; a projector "
                    "is required to process audio inputs."
                )

            audio_inputs = self.feature_extractor(
                audio,
                sampling_rate=getattr(self.feature_extractor, "sampling_rate", 16000),
                return_attention_mask=True,
                return_tensors=return_tensors,
                **kwargs,
            )
            result["input_features"] = audio_inputs["input_features"]
            result["audio_attention_mask"] = audio_inputs["attention_mask"]

            # Placeholder count = projector output length for the longest
            # unpadded mel sequence in the batch.
            real_mel_len = int(audio_inputs["attention_mask"].sum(dim=-1).max().item())
            encoder_output_len = self._compute_encoder_output_length(real_mel_len)
            num_audio_tokens = self.projector.get_output_length(encoder_output_len)
        else:
            num_audio_tokens = 0

        # User turn: a run of <audio> placeholders followed by the fixed
        # transcription instruction (instruction only when there is no audio).
        if num_audio_tokens > 0:
            user_content = self.AUDIO_TOKEN * num_audio_tokens + " " + self.TRANSCRIBE_PROMPT
        else:
            user_content = self.TRANSCRIBE_PROMPT

        messages = []
        if system_prompt:
            messages.append({"role": "system", "content": system_prompt})
        messages.append({"role": "user", "content": user_content})
        if text is not None:
            messages.append({"role": "assistant", "content": text})

        # Append the generation prompt only when no target text is supplied
        # (i.e. inference rather than teacher-forced training input).
        tokenized = self.tokenizer.apply_chat_template(
            messages,
            tokenize=True,
            add_generation_prompt=(text is None),
            return_tensors=return_tensors,
            enable_thinking=False,
        )

        if isinstance(tokenized, torch.Tensor):
            input_ids = tokenized
        else:
            # Some tokenizers return a BatchEncoding/dict rather than a tensor.
            input_ids = tokenized.get("input_ids", tokenized.input_ids)

        # Normalize to a (1, seq_len) batch shape.
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        result["input_ids"] = input_ids
        result["attention_mask"] = torch.ones_like(input_ids)

        return result
| |
|
| |
|
# Register the processor so it is saved with custom-code checkpoints
# (trust_remote_code) and discoverable via transformers.AutoProcessor
# for checkpoints carrying an ASRConfig.
ASRProcessor.register_for_auto_class()
transformers.AutoProcessor.register(ASRConfig, ASRProcessor)
| |
|