import decord
import numpy as np
import torch
import torch.nn.functional as F
from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import BertTokenizer
|
|
|
|
# Preprocessing hyperparameters.
FRAMES = 50          # number of frames sampled per video clip
H, W = 128, 128      # spatial size every frame is resized to
BATCH_SIZE = 8
# NOTE(review): 3000 exceeds bert-base-uncased's 512 position embeddings;
# sequences this long will break a downstream BERT encoder — confirm intent.
TEXT_MAX_LEN = 3000  # token length captions are padded/truncated to
|
|
|
|
# Load the WebVid-10M training split from the Hugging Face hub.
# NOTE(review): rows are assumed to expose "video_path" and "caption"
# fields (read in VideoDataset.__getitem__) — verify against this dataset.
dataset = load_dataset("gaussalgo/webvid-10m", split="train")
|
|
|
|
# Module-level tokenizer used by VideoDataset.__getitem__ to encode captions.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
|
|
|
class VideoDataset(torch.utils.data.Dataset):
    """Decode videos into fixed-shape clips and tokenize their captions.

    Each item is a dict:
        "video": float tensor of shape (3, FRAMES, H, W), values in [-1, 1]
        "text":  tensor of shape (TEXT_MAX_LEN,) of BERT input ids
    """

    def __init__(self, dataset):
        self.dataset = dataset
        # Decode on CPU; each DataLoader worker gets its own context.
        self.decord_ctx = decord.cpu(0)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]

        # NOTE(review): assumes each row carries a local file path under
        # "video_path"; WebVid metadata often stores URLs instead — confirm.
        vr = decord.VideoReader(item["video_path"], ctx=self.decord_ctx)
        # FRAMES indices spread evenly across the clip (indices repeat when
        # the video has fewer than FRAMES frames).
        frame_indices = np.linspace(0, len(vr) - 1, FRAMES, dtype=int)
        # BUG FIX: decord's NDArray exposes .asnumpy(), not .numpy() — the
        # original raised AttributeError under the default (native) bridge.
        video = vr.get_batch(frame_indices).asnumpy()  # (T, h, w, C) uint8
        video = torch.from_numpy(video).permute(3, 0, 1, 2).float()  # (C, T, h, w)

        # Resize the trailing spatial dims. On a 4D input, F.interpolate
        # treats dim 0 as batch and dim 1 as channels, so with a (C, T, h, w)
        # layout every frame is resized independently to (H, W).
        video = F.interpolate(video, size=(H, W), mode="bilinear", align_corners=False)
        # Map the uint8 range [0, 255] onto [-1, 1].
        video = (video / 255.0) * 2 - 1

        # Pad/truncate the caption to TEXT_MAX_LEN tokens; squeeze the batch
        # dim so "text" is a flat (TEXT_MAX_LEN,) tensor.
        # NOTE(review): TEXT_MAX_LEN (3000) exceeds BERT's 512 positions —
        # verify the downstream text encoder can accept this length.
        text = tokenizer(
            item["caption"],
            padding="max_length",
            truncation=True,
            max_length=TEXT_MAX_LEN,
            return_tensors="pt",
        ).input_ids.squeeze(0)

        return {"video": video, "text": text}
|
|
| |
# Rebind `dataset` to the decoding wrapper, then batch it. Video decoding
# and tokenization run inside the 4 worker processes, not the main process.
dataset = VideoDataset(dataset)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)