| |
| |
| |
|
|
| |
| |
|
|
| import io |
| import os |
| import math |
| import requests |
|
|
| import torch |
| import torch.nn as nn |
| import torchaudio |
| import logging |
|
|
| |
| from .multimodal_preprocessors import SimpleTokenizer |
| from PIL import Image |
| |
| |
| |
|
|
| from torchvision import transforms |
| from torchvision.transforms._transforms_video import NormalizeVideo |
|
|
# Frame shift (milliseconds, per the name) passed as `frame_shift` to the
# Kaldi fbank extractor in waveform2melspec.
DEFAULT_AUDIO_FRAME_SHIFT_MS = 10


# Path handed to SimpleTokenizer as `bpe_path` in load_and_transform_text.
BPE_PATH = "bpe/bpe_simple_vocab_16e6.txt.gz"
|
|
|
|
def waveform2melspec(waveform, sample_rate, num_mel_bins, target_length):
    """Convert a waveform into a fixed-length Kaldi-style mel filterbank.

    The waveform is mean-centred, run through
    ``torchaudio.compliance.kaldi.fbank`` and then padded with zeros or
    truncated along the frame axis so that it has exactly ``target_length``
    frames.

    Args:
        waveform: Audio tensor in the layout accepted by
            ``torchaudio.compliance.kaldi.fbank``. It is NOT modified.
        sample_rate: Sampling frequency of ``waveform``, forwarded as
            ``sample_frequency``.
        num_mel_bins: Number of mel bins to compute.
        target_length: Number of frames the returned spectrogram must have.

    Returns:
        The filterbank with a leading singleton dimension added, i.e. the
        transposed fbank output padded/cropped to ``target_length`` on its
        last axis.
    """
    # Subtract out of place: the original `waveform -= mean` wrote into the
    # caller's tensor, which silently corrupts shared storage when the
    # argument is a slice/view of a longer recording.
    waveform = waveform - waveform.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform,
        htk_compat=True,
        sample_frequency=sample_rate,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=num_mel_bins,
        dither=0.0,
        frame_length=25,
        frame_shift=DEFAULT_AUDIO_FRAME_SHIFT_MS,
    )

    # Put the frame axis last so padding/cropping below acts on time.
    fbank = fbank.transpose(0, 1)

    n_frames = fbank.size(1)
    p = target_length - n_frames

    # A mismatch above 20% usually points at a misconfigured target length.
    if abs(p) / n_frames > 0.2:
        logging.warning(
            "Large gap between audio n_frames(%d) and "
            "target_length (%d). Is the audio_target_length "
            "setting correct?",
            n_frames,
            target_length,
        )

    if p > 0:
        # Too short: right-pad the frame axis with zeros.
        fbank = torch.nn.functional.pad(fbank, (0, p), mode="constant", value=0)
    elif p < 0:
        # Too long: keep only the first `target_length` frames.
        fbank = fbank[:, 0:target_length]

    # Add a leading singleton dimension.
    fbank = fbank.unsqueeze(0)
    return fbank
|
|
|
|
def get_clip_timepoints(clip_sampler, duration):
    """Collect the (start, end) pair of every clip a sampler yields.

    Args:
        clip_sampler: Callable invoked as
            ``clip_sampler(last_end, duration, annotation=None)`` that returns
            a 5-tuple whose first, second and fifth items are the clip start,
            the clip end and an "is last clip" flag.
        duration: Total duration passed through to the sampler unchanged.

    Returns:
        A list of ``(start, end)`` tuples, in the order they were sampled.
    """
    timepoints = []
    cursor = 0.0  # end of the previous clip; the sampler resumes from here
    while True:
        clip_start, cursor, _, _, finished = clip_sampler(
            cursor, duration, annotation=None
        )
        timepoints.append((clip_start, cursor))
        if finished:
            return timepoints
|
|
|
|
def load_and_transform_vision_data(image_paths, device, client=None):
    """Load images from disk, S3 or HTTP and turn them into a normalized batch.

    Each image is converted to RGB, resized (bicubic) so its short side is
    224, center-cropped to 224x224, converted to a tensor, normalized with
    fixed per-channel mean/std and moved to ``device``.

    Args:
        image_paths: Iterable of path strings, or ``None``. An entry may be an
            existing local file, an ``s3://`` URI (only when ``client`` is
            given) or a URL starting with ``http``.
        device: Target device passed to ``Tensor.to``.
        client: Optional object whose ``get(path)`` returns the raw bytes of
            an ``s3://`` object.

    Returns:
        ``None`` when ``image_paths`` is ``None``; otherwise the per-image
        tensors stacked along a new leading dimension.

    Raises:
        ValueError: If an entry matches none of the supported sources.
        requests.HTTPError: If an HTTP download answers with an error status.
    """
    if image_paths is None:
        return None

    # The pipeline is identical for every image, so build it once instead of
    # once per loop iteration.
    data_transform = transforms.Compose(
        [
            transforms.Resize(
                224, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    image_outputs = []
    for image_path in image_paths:
        if os.path.exists(image_path):
            with open(image_path, "rb") as fopen:
                # convert() forces the pixel data to be read before the file
                # handle is closed.
                image = Image.open(fopen).convert("RGB")
        elif image_path.startswith("s3://") and client is not None:
            image = Image.open(io.BytesIO(client.get(image_path))).convert("RGB")
        elif image_path.startswith("http"):
            # Close the streamed response deterministically and surface HTTP
            # errors as such rather than as an opaque image-decoding failure.
            with requests.get(image_path, stream=True) as response:
                response.raise_for_status()
                image = Image.open(response.raw).convert("RGB")
        else:
            raise ValueError(f"Invalid image path: {image_path}")

        image_outputs.append(data_transform(image).to(device))
    return torch.stack(image_outputs, dim=0)
|
|
def transform_vision_data(images, device):
    """Apply the standard vision preprocessing to already-loaded images.

    Every image is resized (bicubic) so its short side is 224, center-cropped
    to 224x224, converted to a tensor, normalized with fixed per-channel
    mean/std and moved to ``device``.

    Args:
        images: Iterable of images accepted by the torchvision transforms used
            here (presumably PIL images -- confirm against callers).
        device: Target device passed to ``Tensor.to``.

    Returns:
        The per-image tensors stacked along a new leading dimension.
    """
    # Build the (loop-invariant) pipeline once rather than once per image.
    data_transform = transforms.Compose(
        [
            transforms.Resize(
                224, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=(0.48145466, 0.4578275, 0.40821073),
                std=(0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )
    image_outputs = [data_transform(img).to(device) for img in images]
    return torch.stack(image_outputs, dim=0)
|
|
|
|
def load_and_transform_thermal_data(thermal_paths, device):
    """Load thermal images from disk as single-channel tensors.

    Each file is opened, converted to 8-bit grayscale (PIL mode ``"L"``),
    resized (bicubic) so its short side is 224, center-cropped to 224x224,
    converted to a tensor and moved to ``device``. No normalization is
    applied.

    Args:
        thermal_paths: Iterable of local file paths, or ``None``.
        device: Target device passed to ``Tensor.to``.

    Returns:
        ``None`` when ``thermal_paths`` is ``None``; otherwise the per-image
        tensors stacked along a new leading dimension.
    """
    if thermal_paths is None:
        return None

    # Build the (loop-invariant) pipeline once rather than once per file.
    data_transform = transforms.Compose(
        [
            transforms.Resize(
                224, interpolation=transforms.InterpolationMode.BICUBIC
            ),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
        ]
    )

    thermal_outputs = []
    for thermal_path in thermal_paths:
        with open(thermal_path, "rb") as fopen:
            # convert() reads the pixel data before the handle is closed.
            thermal = Image.open(fopen).convert("L")
        thermal_outputs.append(data_transform(thermal).to(device))
    return torch.stack(thermal_outputs, dim=0)
|
|
|
|
def load_and_transform_text(text, device):
    """Tokenize a collection of strings into one batched token tensor.

    Args:
        text: Iterable of items to tokenize (presumably strings), or ``None``.
        device: Target device passed to ``Tensor.to``.

    Returns:
        ``None`` when ``text`` is ``None``; otherwise the tokenizer outputs,
        each given a leading batch dimension, concatenated along dim 0.
    """
    if text is None:
        return None

    tokenizer = SimpleTokenizer(bpe_path=BPE_PATH)
    rows = []
    for entry in text:
        encoded = tokenizer(entry)
        rows.append(encoded.unsqueeze(0).to(device))
    return torch.cat(rows, dim=0)
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
|
|
|
|
# NOTE: this module defines get_clip_timepoints twice with the same body;
# Python binds the name to this later definition.
def get_clip_timepoints(clip_sampler, duration):
    """Return the (start, end) time pair of each clip produced by a sampler.

    The sampler is called repeatedly as
    ``clip_sampler(previous_end, duration, annotation=None)``, starting from
    ``0.0``, until the fifth element of its 5-tuple result is truthy.

    Args:
        clip_sampler: Callable returning ``(start, end, _, _, is_last_clip)``.
        duration: Total duration forwarded to the sampler on every call.

    Returns:
        List of ``(start, end)`` tuples in sampling order.
    """
    spans = []
    previous_end, done = 0.0, False
    while not done:
        sampled = clip_sampler(previous_end, duration, annotation=None)
        begin, previous_end, _, _, done = sampled
        spans.append((begin, previous_end))
    return spans
|
|
|
|
def crop_boxes(boxes, x_offset, y_offset):
    """
    Shift bounding boxes into the coordinate frame of a crop.
    Args:
        boxes (ndarray): bounding boxes to shift, with dimension
            `num boxes` x 4. Columns 0 and 2 are treated as x coordinates,
            columns 1 and 3 as y coordinates. The input is not modified.
        x_offset (int): left edge of the crop along the x axis.
        y_offset (int): top edge of the crop along the y axis.
    Returns:
        cropped_boxes (ndarray): a copy of `boxes` with the offsets
            subtracted, dimension `num boxes` x 4.
    """
    shifted = boxes.copy()
    # Each pair: (columns holding that axis' coordinates, offset to remove).
    for columns, offset in (([0, 2], x_offset), ([1, 3], y_offset)):
        shifted[:, columns] = boxes[:, columns] - offset
    return shifted
|
|
|
|
def uniform_crop(images, size, spatial_idx, boxes=None, scale_size=None):
    """
    Take one of three evenly spaced square crops from a stack of images.
    Args:
        images (tensor): images to crop, with dimension
            `num frames` x `channel` x `height` x `width`. A 3-dimensional
            tensor is also accepted and is returned 3-dimensional.
        size (int): side length of the square crop.
        spatial_idx (int): 0, 1 or 2. When height is larger than width these
            select the top, center and bottom crop; otherwise the left,
            center and right crop.
        boxes (ndarray or None): optional boxes belonging to the images,
            dimension `num boxes` x 4.
        scale_size (int): optional. If not None, the images are first resized
            (bilinear) so that their shorter side equals scale_size.
    Returns:
        cropped (tensor): images with dimension
            `num frames` x `channel` x `size` x `size`.
        cropped_boxes (ndarray or None): boxes shifted by the crop offsets,
            or None when no boxes were given.
    """
    assert spatial_idx in (0, 1, 2)

    # Remember whether a batch dimension has to be removed again at the end.
    was_unbatched = len(images.shape) == 3
    if was_unbatched:
        images = images.unsqueeze(0)
    height, width = images.shape[2], images.shape[3]

    if scale_size is not None:
        # Scale the short side to `scale_size`, keeping the aspect ratio.
        if width <= height:
            height = int(height / width * scale_size)
            width = scale_size
        else:
            width = int(width / height * scale_size)
            height = scale_size
        images = torch.nn.functional.interpolate(
            images,
            size=(height, width),
            mode="bilinear",
            align_corners=False,
        )

    # Default to the centered crop, then slide along the longer axis.
    y_offset = int(math.ceil((height - size) / 2))
    x_offset = int(math.ceil((width - size) / 2))

    is_tall = height > width
    if spatial_idx == 0:
        if is_tall:
            y_offset = 0
        else:
            x_offset = 0
    elif spatial_idx == 2:
        if is_tall:
            y_offset = height - size
        else:
            x_offset = width - size

    rows = slice(y_offset, y_offset + size)
    cols = slice(x_offset, x_offset + size)
    cropped = images[:, :, rows, cols]

    if boxes is None:
        cropped_boxes = None
    else:
        cropped_boxes = crop_boxes(boxes, x_offset, y_offset)

    if was_unbatched:
        cropped = cropped.squeeze(0)
    return cropped, cropped_boxes
|
|
|
|
class SpatialCrop(nn.Module):
    """
    Expand every video into a fixed set of spatial crops.

    Intended to run after temporal cropping, on full frames. With the
    default of 3 crops each input video yields crops 0, 1 and 2 of
    `uniform_crop`; with 1 crop only the center crop (index 1) is taken.
    """

    def __init__(self, crop_size: int = 224, num_crops: int = 3):
        super().__init__()
        if num_crops == 3:
            crop_indices = [0, 1, 2]
        elif num_crops == 1:
            crop_indices = [1]
        else:
            raise NotImplementedError("Nothing else supported yet")
        self.crop_size = crop_size
        self.crops_to_ext = crop_indices
        # No flipped crops are configured for either supported setting.
        self.flipped_crops_to_ext = []

    def forward(self, videos):
        """
        Args:
            videos: A list of C, T, H, W videos.
        Returns:
            A flat list holding, for each input video in order, one
            C, T, H', W' crop per configured spatial index (followed by the
            crops of the horizontally flipped video, if any are configured).
        """
        assert isinstance(videos, list), "Must be a list of videos after temporal crops"
        ranks_ok = all([clip.ndim == 4 for clip in videos])
        assert ranks_ok, "Must be (C,T,H,W)"

        crops = []
        for clip in videos:
            crops.extend(
                uniform_crop(clip, self.crop_size, idx)[0]
                for idx in self.crops_to_ext
            )
            if self.flipped_crops_to_ext:
                mirrored = transforms.functional.hflip(clip)
                crops.extend(
                    uniform_crop(mirrored, self.crop_size, idx)[0]
                    for idx in self.flipped_crops_to_ext
                )
        return crops
|
|
| """ |
| def load_and_transform_video_data( |
| video_paths, |
| device, |
| clip_duration=2, |
| clips_per_video=5, |
| sample_rate=16000, |
| ): |
| if video_paths is None: |
| return None |
| |
| video_outputs = [] |
| video_transform = transforms.Compose( |
| [ |
| pv_transforms.ShortSideScale(224), |
| NormalizeVideo( |
| mean=(0.48145466, 0.4578275, 0.40821073), |
| std=(0.26862954, 0.26130258, 0.27577711), |
| ), |
| ] |
| ) |
| |
| clip_sampler = ConstantClipsPerVideoSampler( |
| clip_duration=clip_duration, clips_per_video=clips_per_video |
| ) |
| frame_sampler = pv_transforms.UniformTemporalSubsample(num_samples=clip_duration) |
| |
| for video_path in video_paths: |
| video = EncodedVideo.from_path( |
| video_path, |
| decoder="decord", |
| decode_audio=False, |
| **{"sample_rate": sample_rate}, |
| ) |
| |
| all_clips_timepoints = get_clip_timepoints(clip_sampler, video.duration) |
| |
| all_video = [] |
| for clip_timepoints in all_clips_timepoints: |
| # Read the clip, get frames |
| clip = video.get_clip(clip_timepoints[0], clip_timepoints[1]) |
| if clip is None: |
| raise ValueError("No clip found") |
| video_clip = frame_sampler(clip["video"]) |
| video_clip = video_clip / 255.0 # since this is float, need 0-1 |
| |
| all_video.append(video_clip) |
| |
| all_video = [video_transform(clip) for clip in all_video] |
| all_video = SpatialCrop(224, num_crops=3)(all_video) |
| |
| all_video = torch.stack(all_video, dim=0) |
| video_outputs.append(all_video) |
| |
| return torch.stack(video_outputs, dim=0).to(device) |
| """ |