| |
| |
|
|
| |
| |
|
|
|
|
import dataclasses
from dataclasses import dataclass
from typing import Any, List, Optional, Union

import torch
import torch.nn.functional as F
|
|
|
|
@dataclass(eq=False)
class CoTrackerData:
    """
    Dataclass for storing video tracks data.

    Holds either a single sample or a collated batch. `collate_fn` and
    `collate_fn_train` build a batch by stacking each tensor field with
    ``torch.stack(..., dim=0)`` and by gathering ``seq_name`` into a list.

    ``eq=False`` keeps identity-based ``==``. This is presumably intentional:
    a generated ``__eq__`` would compare tensor fields element-wise, which does
    not produce a single bool.
    """

    video: torch.Tensor
    trajectory: torch.Tensor
    visibility: torch.Tensor

    # Optional fields; left as None when not provided.
    valid: Optional[torch.Tensor] = None
    segmentation: Optional[torch.Tensor] = None
    # A single name for one sample. The collate functions store a list of
    # names (one per sample) here for a batch.
    seq_name: Optional[Union[str, List[str]]] = None
    query_points: Optional[torch.Tensor] = None
|
|
|
|
def _stack_optional_field(batch, name):
    """
    Stack attribute `name` across `batch` along a new dim 0, or return None.

    Presence is decided from ``batch[0]`` only, and a missing attribute counts
    as None. All samples are expected to agree: a None in a later sample makes
    ``torch.stack`` raise.
    """
    if getattr(batch[0], name, None) is None:
        return None
    return torch.stack([getattr(b, name) for b in batch], dim=0)


def collate_fn(batch):
    """
    Collate function for video tracks data.

    Stacks the per-sample ``video``, ``trajectory`` and ``visibility`` tensors
    along a new leading dimension. Optional tensors (``valid``,
    ``segmentation``, ``query_points``) are stacked the same way when the
    first sample provides them and are otherwise left as None. The samples'
    ``seq_name`` values are gathered into a list.

    Args:
        batch: Sequence of samples exposing the CoTrackerData fields
            (presumably CoTrackerData instances).

    Returns:
        A CoTrackerData holding the collated batch.
    """
    video = torch.stack([b.video for b in batch], dim=0)
    trajectory = torch.stack([b.trajectory for b in batch], dim=0)
    visibility = torch.stack([b.visibility for b in batch], dim=0)
    # `valid` used to be dropped here even when samples carried it. It is now
    # forwarded, consistent with collate_fn_train.
    valid = _stack_optional_field(batch, "valid")
    segmentation = _stack_optional_field(batch, "segmentation")
    query_points = _stack_optional_field(batch, "query_points")
    seq_name = [b.seq_name for b in batch]

    return CoTrackerData(
        video=video,
        trajectory=trajectory,
        visibility=visibility,
        valid=valid,
        segmentation=segmentation,
        seq_name=seq_name,
        query_points=query_points,
    )
|
|
|
|
def collate_fn_train(batch):
    """
    Collate function for video tracks data during training.

    Each element of `batch` is a ``(sample, gotit)`` pair. The samples'
    ``video``, ``trajectory``, ``visibility`` and ``valid`` tensors are stacked
    along a new leading dimension, so ``valid`` must be present on every
    sample. Their ``seq_name`` values are gathered into a list. The second
    elements of the pairs are returned unchanged as a list.

    Args:
        batch: Sequence of ``(sample, gotit)`` pairs. ``gotit`` is presumably
            a per-sample status flag from the dataset; confirm with the caller.

    Returns:
        Tuple ``(CoTrackerData, gotit_list)``.
    """
    gotit = [status for _, status in batch]
    samples = [sample for sample, _ in batch]

    def stacked(attr):
        # Stack one tensor attribute across all samples.
        return torch.stack([getattr(s, attr) for s in samples], dim=0)

    collated = CoTrackerData(
        video=stacked("video"),
        trajectory=stacked("trajectory"),
        visibility=stacked("visibility"),
        valid=stacked("valid"),
        seq_name=[s.seq_name for s in samples],
    )
    return collated, gotit
|
|
|
|
def try_to_cuda(t: Any) -> Any:
    """
    Best-effort move of `t` to a cuda device.

    The conversion is ``t.float().cuda()``. For torch tensors this also casts
    to float32, including bool/int masks. Only AttributeError is swallowed, so
    objects without ``.float()`` / ``.cuda()`` (e.g. None, str, list) come
    back untouched. Any other error raised by the conversion propagates, e.g.
    when no CUDA device is usable.

    Args:
        t: Input.

    Returns:
        ``t.float().cuda()`` if `t` supports it, otherwise `t` itself.
    """
    try:
        on_device = t.float().cuda()
    except AttributeError:
        # Not tensor-like: hand it back as-is.
        return t
    return on_device
|
|
|
|
def dataclass_to_cuda_(obj):
    """
    Move all contents of a dataclass to cuda in place, where supported.

    Each field of `obj` is passed through `try_to_cuda` and written back with
    ``setattr``, so `obj` itself is mutated and must not be a frozen
    dataclass. Tensor fields are also cast to float by `try_to_cuda`. Fields
    without ``.float()`` / ``.cuda()`` (e.g. None or a list of names) are left
    unchanged.

    Args:
        obj: Input dataclass instance.

    Returns:
        obj: The same instance, after its fields were updated in place.
    """
    for f in dataclasses.fields(obj):
        setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
    return obj
|
|