| import torch |
| import numpy as np |
| from transformers import AutoProcessor |
| from typing import Dict, List, Union |
| import logging |
|
|
| logger = logging.getLogger(__name__) |
|
|
class DataCollatorSimpleSeamless:
    """Collate raw audio/text examples into padded tensor batches.

    Wraps a HuggingFace processor (loaded via ``AutoProcessor``) to turn a
    list of raw examples into model-ready tensors: padded/truncated audio
    features, tokenized text, per-example task flags, and (optionally)
    normalized regression labels.
    """

    # Supported label normalizations; validated eagerly in __init__.
    _VALID_NORMALIZATIONS = ("log1p", "none")

    def __init__(
        self,
        processor: str,
        sample_rate: int = 16000,
        max_audio_length_sec: float = 8.0,
        max_text_length: int = 256,
        normalization_type: str = "none",
    ):
        """Initialize the data collator.

        Args:
            processor: Name or path of the pretrained processor to load.
            sample_rate: Audio sample rate in Hz.
            max_audio_length_sec: Maximum audio length in seconds; audio is
                truncated to ``max_audio_length_sec * sample_rate`` samples.
            max_text_length: Maximum text length in tokens.
            normalization_type: Normalization applied to labels. Options:
                "log1p", "none".

        Raises:
            ValueError: If ``normalization_type`` is not one of the
                supported options (fail fast instead of erroring mid-training).
        """
        # Validate configuration up front rather than on the first labeled
        # batch, so a typo surfaces immediately at construction time.
        if normalization_type not in self._VALID_NORMALIZATIONS:
            raise ValueError(f"Unknown normalization type: {normalization_type}")

        # Lazy %-formatting keeps the message cheap when INFO is disabled.
        logger.info("Loading processor: %s", processor)
        self.processor = AutoProcessor.from_pretrained(processor)

        self.sample_rate = sample_rate
        # Truncation cap expressed in samples (seconds * rate).
        self.max_audio_sample_length = int(max_audio_length_sec * sample_rate)
        self.max_text_length = max_text_length
        self.normalization_type = normalization_type

    def __call__(self, batch: List[Dict[str, Union[np.ndarray, str, float]]]) -> Dict[str, torch.Tensor]:
        """Process a batch of raw features into model inputs.

        Args:
            batch: List of example dicts. Each must contain 'raw_audio'
                (array-like waveform) and 'raw_text' (str); may optionally
                contain 'is_translation', 'language_pair_id', and 'labels'.

        Returns:
            Dict of batched tensors: 'input_features', 'audio_attention_mask'
            (may be None if the processor emits none), 'input_ids',
            'text_attention_mask', 'is_translation', 'language_pair_id', and
            'labels' (only when present in the input examples).
        """
        # Convert waveforms to tensors in a single pass; the processor
        # handles padding/truncation to a uniform length.
        raw_audios = [torch.tensor(item['raw_audio']) for item in batch]
        raw_texts = [item['raw_text'] for item in batch]

        audio_inputs = self.processor(
            audios=raw_audios,
            sampling_rate=self.sample_rate,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_audio_sample_length,
        )

        text_inputs = self.processor(
            text=raw_texts,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_text_length,
        )

        # Per-example task metadata; default to 0 when a key is absent.
        is_translation = torch.tensor(
            [item.get('is_translation', 0) for item in batch], dtype=torch.float32
        )
        language_pair_id = torch.tensor(
            [item.get('language_pair_id', 0) for item in batch], dtype=torch.long
        )

        # Labels are optional (e.g. absent at inference time). Assume that
        # when the first example has labels, all examples do.
        if 'labels' in batch[0]:
            labels = torch.tensor(
                [item['labels'] for item in batch], dtype=torch.float32
            )
            if self.normalization_type == "log1p":
                # Compress the dynamic range of non-negative targets.
                labels = torch.log1p(labels)
            elif self.normalization_type == "none":
                pass
            else:
                # Unreachable after the __init__ check; kept as a guard.
                raise ValueError(f"Unknown normalization type: {self.normalization_type}")
        else:
            labels = None

        return {
            'input_features': audio_inputs['input_features'],
            # Single lookup; yields None when the processor emits no mask
            # (equivalent to the former double-get construction).
            'audio_attention_mask': audio_inputs.get('attention_mask'),
            'input_ids': text_inputs['input_ids'],
            'text_attention_mask': text_inputs['attention_mask'],
            'is_translation': is_translation,
            'language_pair_id': language_pair_id,
            **({'labels': labels} if labels is not None else {}),
        }