import cv2
import numpy as np
import torch
import torchaudio

from torchvision import transforms
from transformers import ProcessorMixin, BatchEncoding
from transformers.image_processing_utils import BatchFeature
from torch.nn import functional as F


def make_list_of_images(x):
    if not isinstance(x, list):
        return [x]
    return x


def torchaudio_loader(path):
    # Returns a (waveform, sample_rate) tuple.
    return torchaudio.load(path)


def int16_to_float32_torch(x):
    # Scale int16 PCM samples to float32 in [-1, 1].
    return (x / 32767.0).type(torch.float32)


def float32_to_int16_torch(x):
    # Clamp to [-1, 1] and rescale to int16 PCM.
    x = torch.clamp(x, min=-1., max=1.)
    return (x * 32767.).type(torch.int16)


DEFAULT_AUDIO_FRAME_SHIFT_MS = 10  # fbank frame shift, in milliseconds


class AudioTransform:
    def __init__(self, config):
        self.sample_rate = config.audio_sample_rate
        self.num_mel_bins = config.num_mel_bins
        self.target_length = config.target_length
        self.audio_mean = config.audio_mean
        self.audio_std = config.audio_std
        self.norm = transforms.Normalize(mean=self.audio_mean, std=self.audio_std)

    def __call__(self, audio_data_and_origin_sr):
        audio_data, origin_sr = audio_data_and_origin_sr
        if self.sample_rate != origin_sr:
            # Resample to the target sample rate before computing mel features.
            audio_data = torchaudio.functional.resample(audio_data, orig_freq=origin_sr, new_freq=self.sample_rate)
        # Use the first channel only.
        waveform_melspec = self.waveform2melspec(audio_data[0])
        return self.norm(waveform_melspec)

    def waveform2melspec(self, audio_data):
        # Maximum number of samples that fits into target_length frames at a
        # 10 ms frame shift (sample_rate / 100 samples per frame).
        max_len = self.target_length * self.sample_rate // 100
        if audio_data.shape[-1] > max_len:
            # Audio is longer than the target: compute the full mel spectrogram
            # and fuse three chunks taken from the front, middle, and back.
            mel = self.get_mel(audio_data)
            chunk_frames = self.target_length
            total_frames = mel.shape[0]
            ranges = np.array_split(list(range(0, total_frames - chunk_frames + 1)), 3)
            # If the audio is only slightly longer than the target, the middle
            # or back split can be empty; fall back to the first chunk.
            if len(ranges[1]) == 0:
                ranges[1] = [0]
            if len(ranges[2]) == 0:
                ranges[2] = [0]
            # Deterministically take the first index of each split.
            idx_front = ranges[0][0]
            idx_middle = ranges[1][0]
            idx_back = ranges[2][0]
            mel_chunk_front = mel[idx_front:idx_front + chunk_frames, :]
            mel_chunk_middle = mel[idx_middle:idx_middle + chunk_frames, :]
            mel_chunk_back = mel[idx_back:idx_back + chunk_frames, :]
            mel_fusion = torch.stack([mel_chunk_front, mel_chunk_middle, mel_chunk_back], dim=0)
        elif audio_data.shape[-1] < max_len:
            # Audio is shorter than the target: tile it, zero-pad to max_len,
            # and repeat the single mel spectrogram three times.
            n_repeat = int(max_len / len(audio_data))
            audio_data = audio_data.repeat(n_repeat)
            audio_data = F.pad(
                audio_data,
                (0, max_len - len(audio_data)),
                mode="constant",
                value=0,
            )
            mel = self.get_mel(audio_data)
            mel_fusion = torch.stack([mel, mel, mel], dim=0)
        else:
            mel = self.get_mel(audio_data)
            mel_fusion = torch.stack([mel, mel, mel], dim=0)

        # Pad or crop along the frame axis so every sample has exactly
        # target_length frames.
        p = self.target_length - mel_fusion.shape[1]
        if p > 0:
            m = torch.nn.ZeroPad2d((0, 0, 0, p))
            mel_fusion = m(mel_fusion)
        elif p < 0:
            mel_fusion = mel_fusion[:, 0:self.target_length, :]

        # (3, target_length, num_mel_bins) -> (3, num_mel_bins, target_length)
        mel_fusion = mel_fusion.transpose(1, 2)
        return mel_fusion

    def get_mel(self, audio_data):
        # Mean-normalize the waveform, then compute a Kaldi-compatible
        # log-mel filterbank with a 25 ms window and 10 ms shift.
        audio_data -= audio_data.mean()
        mel = torchaudio.compliance.kaldi.fbank(
            audio_data.unsqueeze(0),
            htk_compat=True,
            sample_frequency=self.sample_rate,
            use_energy=False,
            window_type="hanning",
            num_mel_bins=self.num_mel_bins,
            dither=0.0,
            frame_length=25,
            frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
        )
        return mel


def get_audio_transform(config):
    # The audio preprocessing parameters are stored on the model config's vision_config.
    config = config.vision_config
    return AudioTransform(config)


def load_and_transform_audio(
        audio_path,
        transform,
):
    # Load (waveform, sample_rate) from disk and apply the mel-spectrogram transform.
    waveform_and_sr = torchaudio_loader(audio_path)
    audio_outputs = transform(waveform_and_sr)
    return audio_outputs

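
# Illustrative standalone usage (a sketch; `model_config` is a hypothetical name
# for any config whose `vision_config` carries audio_sample_rate, num_mel_bins,
# target_length, audio_mean and audio_std):
#
#     transform = get_audio_transform(model_config)
#     mel = load_and_transform_audio("example.wav", transform)
#     # mel has shape (3, num_mel_bins, target_length)

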
class LanguageBindAudioProcessor(ProcessorMixin):
    attributes = []
    tokenizer_class = "LanguageBindAudioTokenizer"

    def __init__(self, config, tokenizer=None, **kwargs):
        super().__init__(**kwargs)
        self.config = config
        self.transform = get_audio_transform(config)
        self.image_processor = load_and_transform_audio
        self.tokenizer = tokenizer

    def __call__(self, images=None, text=None, context_length=77, return_tensors=None, **kwargs):
        if text is None and images is None:
            raise ValueError("You have to specify either text or images. Both cannot be none.")

        if text is not None:
            encoding = self.tokenizer(text, max_length=context_length, padding='max_length',
                                      truncation=True, return_tensors=return_tensors, **kwargs)

        if images is not None:
            # Despite the name, `images` holds audio file paths; each one is
            # loaded and converted to a fused mel spectrogram.
            images = make_list_of_images(images)
            image_features = [self.image_processor(image, self.transform) for image in images]
            image_features = torch.stack(image_features)

        if text is not None and images is not None:
            encoding["pixel_values"] = image_features
            return encoding
        elif text is not None:
            return encoding
        else:
            return {"pixel_values": image_features}

    def batch_decode(self, skip_special_tokens=True, *args, **kwargs):
        """
        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
        refer to the docstring of this method for more information.
        """
        return self.tokenizer.batch_decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)

    def decode(self, skip_special_tokens=True, *args, **kwargs):
        """
        This method forwards all its arguments to CLIPTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
        the docstring of this method for more information.
        """
        return self.tokenizer.decode(*args, skip_special_tokens=skip_special_tokens, **kwargs)
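

# Minimal smoke test; a sketch rather than part of the library. The config fields
# mirror what AudioTransform reads, but the numeric values are illustrative
# assumptions (not checkpoint values), and no tokenizer is needed because only
# the audio branch is exercised.
if __name__ == "__main__":
    import math
    from types import SimpleNamespace

    cfg = SimpleNamespace(vision_config=SimpleNamespace(
        audio_sample_rate=16000,  # assumed target sample rate
        num_mel_bins=112,         # assumed number of mel bins
        target_length=1036,       # assumed number of output frames
        audio_mean=-4.27,         # assumed normalization mean
        audio_std=4.57,           # assumed normalization std
    ))
    transform = get_audio_transform(cfg)

    # One second of a 440 Hz tone at 48 kHz, shaped (channels, samples).
    waveform = torch.sin(2 * math.pi * 440 * torch.linspace(0., 1., 48000)).unsqueeze(0)
    mel = transform((waveform, 48000))  # expects a (waveform, original_sample_rate) pair
    print(mel.shape)  # (3, num_mel_bins, target_length)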