import os
import gzip
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.utils.data as data

from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass

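# Expected on-disk layout, inferred from the paths built in this module
# (verify against the Dynamic Replica release you are using):
#   <root>/<split>/frame_annotations_<split>.jgz
#   <root>/<split>/<trajectories["path"]>   (per-frame files loaded with
#   torch.load, containing "img", "traj_2d" and "verts_inds_vis")
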
@dataclass
class ImageAnnotation:
    # Path to the image file, relative to the dataset root.
    path: str
    # Image size as (height, width).
    size: Tuple[int, int]


@dataclass
class DynamicReplicaFrameAnnotation:
    """A dataclass used to load annotations from json."""

    # Can be used to join with a sequence-level annotation.
    sequence_name: str
    # Frame index within the sequence.
    frame_number: int
    # Timestamp in seconds from the start of the video.
    frame_timestamp: float

    image: ImageAnnotation
    meta: Optional[Dict[str, Any]] = None

    camera_name: Optional[str] = None
    # A mapping with a "path" key pointing at the per-frame trajectory file
    # (see __getitem__ below), hence Dict rather than str.
    trajectories: Optional[Dict[str, Any]] = None


class DynamicReplicaDataset(data.Dataset):
    def __init__(
        self,
        root,
        split="valid",
        traj_per_sample=256,
        crop_size=None,
        sample_len=-1,
        only_first_n_samples=-1,
        rgbd_input=False,
    ):
        super(DynamicReplicaDataset, self).__init__()
        self.root = root
        self.sample_len = sample_len
        self.split = split
        self.traj_per_sample = traj_per_sample
        self.rgbd_input = rgbd_input
        self.crop_size = crop_size
        frame_annotations_file = f"frame_annotations_{split}.jgz"
        self.sample_list = []
        with gzip.open(
            os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
        seq_annot = defaultdict(list)
        for frame_annot in frame_annots_list:
            # Keep only the left camera of each stereo pair.
            if frame_annot.camera_name == "left":
                seq_annot[frame_annot.sequence_name].append(frame_annot)

        for seq_name in seq_annot.keys():
            seq_len = len(seq_annot[seq_name])

            # Chop each sequence into non-overlapping clips of sample_len
            # frames (the whole sequence if sample_len <= 0).
            step = self.sample_len if self.sample_len > 0 else seq_len
            counter = 0

            for ref_idx in range(0, seq_len, step):
                sample = seq_annot[seq_name][ref_idx : ref_idx + step]
                self.sample_list.append(sample)
                counter += 1
                if only_first_n_samples > 0 and counter >= only_first_n_samples:
                    break

    def __len__(self):
        return len(self.sample_list)

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        H_new = H
        W_new = W

        # Center crop; a dimension is left untouched when the requested
        # crop size is at least as large as the frame.
        y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
        x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        # Shift trajectories into the cropped coordinate frame.
        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs

    def __getitem__(self, index):
        sample = self.sample_list[index]
        T = len(sample)
        rgbs, visibilities, traj_2d = [], [], []

        H, W = sample[0].image.size
        image_size = (H, W)

        for i in range(T):
            traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
            traj = torch.load(traj_path)

            visibilities.append(traj["verts_inds_vis"].numpy())

            rgbs.append(traj["img"].numpy())
            traj_2d.append(traj["traj_2d"].numpy()[..., :2])

        traj_2d = np.stack(traj_2d)
        visibility = np.stack(visibilities)
        T, N, D = traj_2d.shape

        # Randomly subsample at most traj_per_sample trajectories.
        visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]

        traj_2d = traj_2d[:, visible_inds_sampled]
        visibility = visibility[:, visible_inds_sampled]

        if self.crop_size is not None:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)
            H, W, _ = rgbs[0].shape
            image_size = self.crop_size

        # Mark points that fall outside the (possibly cropped) frame as occluded.
        visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        # Keep only trajectories that stay visible for more than 10 frames.
        visible_inds_resampled = visibility.sum(0) > 10
        traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
        visibility = torch.from_numpy(visibility[:, visible_inds_resampled])

        rgbs = np.stack(rgbs, 0)
        video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
        return CoTrackerData(
            video=video,
            trajectory=traj_2d,
            visibility=visibility,
            # `valid` must match the number of trajectories that survived
            # filtering; the earlier N is stale after subsampling.
            valid=torch.ones(T, traj_2d.shape[1]),
            seq_name=sample[0].sequence_name,
        )
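
# --- Usage sketch (not part of the original module) --------------------------
# A minimal example of how this loader might be driven, assuming a local copy
# of Dynamic Replica under ./dynamic_replica_data (hypothetical path) laid out
# as described at the top of this file.
if __name__ == "__main__":
    dataset = DynamicReplicaDataset(
        root="./dynamic_replica_data",  # hypothetical path; adjust to your setup
        split="valid",
        traj_per_sample=256,
        crop_size=(384, 512),  # (H, W) center crop
        sample_len=-1,  # one clip per sequence
    )
    print(f"{len(dataset)} clips")
    item = dataset[0]
    # video: (T, 3, H, W); trajectory: (T, N, 2); visibility: (T, N)
    print(item.video.shape, item.trajectory.shape, item.visibility.shape)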