| |
| |
|
|
| |
| |
|
|
| import os |
| import torch |
| import argparse |
| import imageio.v3 as iio |
| import numpy as np |
|
|
| from cotracker.utils.visualizer import Visualizer |
| from cotracker.predictor import CoTrackerOnlinePredictor |
|
|
| |
| |
|
|
# Pick the best available compute device: CUDA GPU first, then Apple-silicon
# MPS, falling back to CPU. Stored as a string so it can be passed directly
# to `torch.Tensor.to(...)` / `torch.tensor(..., device=...)`.
DEFAULT_DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument( |
| "--video_path", |
| default="./assets/apple.mp4", |
| help="path to a video", |
| ) |
| parser.add_argument( |
| "--checkpoint", |
| default=None, |
| help="CoTracker model parameters", |
| ) |
| parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size") |
| parser.add_argument( |
| "--grid_query_frame", |
| type=int, |
| default=0, |
| help="Compute dense and grid tracks starting from this frame", |
| ) |
|
|
| args = parser.parse_args() |
|
|
| if not os.path.isfile(args.video_path): |
| raise ValueError("Video file does not exist") |
|
|
| if args.checkpoint is not None: |
| model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint) |
| else: |
| model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online") |
| model = model.to(DEFAULT_DEVICE) |
|
|
| window_frames = [] |
|
|
| def _process_step(window_frames, is_first_step, grid_size, grid_query_frame): |
| video_chunk = ( |
| torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE) |
| .float() |
| .permute(0, 3, 1, 2)[None] |
| ) |
| return model( |
| video_chunk, |
| is_first_step=is_first_step, |
| grid_size=grid_size, |
| grid_query_frame=grid_query_frame, |
| ) |
|
|
| |
| is_first_step = True |
| for i, frame in enumerate( |
| iio.imiter( |
| args.video_path, |
| plugin="FFMPEG", |
| ) |
| ): |
| if i % model.step == 0 and i != 0: |
| pred_tracks, pred_visibility = _process_step( |
| window_frames, |
| is_first_step, |
| grid_size=args.grid_size, |
| grid_query_frame=args.grid_query_frame, |
| ) |
| is_first_step = False |
| window_frames.append(frame) |
| |
| pred_tracks, pred_visibility = _process_step( |
| window_frames[-(i % model.step) - model.step - 1 :], |
| is_first_step, |
| grid_size=args.grid_size, |
| grid_query_frame=args.grid_query_frame, |
| ) |
|
|
| print("Tracks are computed") |
|
|
| |
| seq_name = args.video_path.split("/")[-1] |
| video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None] |
| vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) |
| vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame) |
|
|