import glob
import io
import os
import warnings

import numpy as np
import torch
import torchvision
from torch.utils.data import Dataset

from decord import VideoReader
# pcache_fileio appears to be imported for its side effects, so that the
# builtin open() can read pcache/OSS paths when a file_list is provided.
from pcache_fileio import fileio  # noqa: F401
from pcache_fileio.oss_conf import OssConfigFactory  # noqa: F401
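

# `temporal_sample` is expected to be a callable mapping a total frame count
# to a (start, end) frame window. The original file does not define one, so
# the sampler below is only a sketch; its name and behaviour are assumptions.
class TemporalRandomCrop:
    """Randomly choose a window of `size` consecutive frame indices."""

    def __init__(self, size):
        self.size = size

    def __call__(self, total_frames):
        # Clamp so the window always fits inside the video.
        begin = np.random.randint(0, max(total_frames - self.size, 0) + 1)
        end = min(begin + self.size, total_frames)
        return begin, end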


class SakugaRefDataset(Dataset):
    def __init__(
        self,
        video_frames=25,
        ref_jump_frames=36,
        base_folder='data/samples/',
        file_list=None,
        temporal_sample=None,
        transform=None,
        seed=42,
    ):
        """
        Args:
            video_frames (int): Number of clip frames to sample (the reference
                frame is prepended, so video_frames + 1 frames are returned).
            ref_jump_frames (int): Minimum gap in frames between the reference
                frame and the first clip frame.
            base_folder (str): Directory containing the .mp4 files.
            file_list (str, optional): Text file listing video paths relative
                to base_folder, one per line. If None, base_folder is globbed
                for .mp4 files.
            temporal_sample (callable, optional): Maps a total frame count to
                a (start, end) frame window.
            transform (callable, optional): Applied to the stacked frame tensor.
            seed (int): Random seed.
        """
        self.base_folder = base_folder

        self.file_list = file_list
        if file_list is None:
            # No list given: index every .mp4 directly under base_folder.
            self.video_lists = glob.glob(os.path.join(self.base_folder, '*.mp4'))
        else:
            # Read relative paths from the list file, one video per line.
            self.video_lists = []
            with open(file_list, 'r') as f:
                for line in f:
                    video_path = line.strip()
                    self.video_lists.append(os.path.join(self.base_folder, video_path))

        self.num_samples = len(self.video_lists)
        self.channels = 3

        self.video_frames = video_frames
        self.ref_jump_frames = ref_jump_frames
        self.temporal_sample = temporal_sample
        self.transform = transform

        self.seed = seed

    def __len__(self):
        return self.num_samples

    def get_sample(self, idx):
        """
        Args:
            idx (int): Index of the sample to return.

        Returns:
            dict: A dictionary with a 'pixel_values' tensor of shape
                (video_frames + 1, channels, H, W), where index 0 is the
                reference frame.
        """
        path = self.video_lists[idx]

        if self.file_list is not None:
            # Read the file into memory so the decord reader does not depend
            # on a handle that is closed as soon as the `with` block exits.
            with open(path, 'rb') as f:
                vframes = VideoReader(io.BytesIO(f.read()))
        else:
            vframes, _, _ = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)

        # Pick a window, anchor the reference frame at its start, then sample
        # video_frames evenly spaced frames after a gap of ref_jump_frames.
        ref_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        if end_frame_ind - ref_frame_ind < self.video_frames + self.ref_jump_frames:
            raise ValueError(f'video {path} does not have enough frames')
        start_frame_ind = ref_frame_ind + self.ref_jump_frames
        frame_indices = np.linspace(start_frame_ind, end_frame_ind - 1, self.video_frames, dtype=int)
        frame_indices = np.insert(frame_indices, 0, ref_frame_ind)
        if self.file_list is not None:
            # decord returns frames as THWC; convert to TCHW to match the
            # torchvision branch above.
            video = torch.from_numpy(vframes.get_batch(frame_indices).asnumpy()).permute(0, 3, 1, 2).contiguous()
        else:
            video = vframes[frame_indices]

        pixel_values = self.transform(video) if self.transform is not None else video

        return {'pixel_values': pixel_values}

    def __getitem__(self, idx):
        # Decoding can fail on truncated or corrupt files; fall back to a
        # random other sample and keep retrying until one loads.
        while True:
            try:
                return self.get_sample(idx)
            except Exception:
                warnings.warn(f'failed to load {self.video_lists[idx]}, resampling')
                idx = np.random.randint(0, len(self.video_lists))
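

# A minimal usage sketch under assumed settings: the crop size, resolution,
# and normalization below are illustrative, not part of the original code.
if __name__ == '__main__':
    from torchvision import transforms

    temporal_sample = TemporalRandomCrop(25 + 36 + 1)  # clip + jump + reference
    transform = transforms.Compose([
        transforms.Resize((320, 512)),
        transforms.Lambda(lambda x: x / 255.0),
    ])
    dataset = SakugaRefDataset(
        base_folder='data/samples/',
        temporal_sample=temporal_sample,
        transform=transform,
    )
    sample = dataset[0]
    print(sample['pixel_values'].shape)  # (video_frames + 1, 3, 320, 512)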