| | |
| | |
| |
|
| | |
| | |
| |
|
| | import os |
| | import random |
| | import torch |
| | import signal |
| | import socket |
| | import sys |
| | import json |
| |
|
| | import numpy as np |
| | import argparse |
| | import logging |
| | from pathlib import Path |
| | from tqdm import tqdm |
| | import torch.optim as optim |
| | from torch.utils.data import DataLoader |
| | from torch.cuda.amp import GradScaler |
| |
|
| | from torch.utils.tensorboard import SummaryWriter |
| | from pytorch_lightning.lite import LightningLite |
| |
|
| | from cotracker.models.evaluation_predictor import EvaluationPredictor |
| | from cotracker.models.core.cotracker.cotracker import CoTracker2 |
| | from cotracker.utils.visualizer import Visualizer |
| | from cotracker.datasets.tap_vid_datasets import TapVidDataset |
| |
|
| | from cotracker.datasets.dr_dataset import DynamicReplicaDataset |
| | from cotracker.evaluation.core.evaluator import Evaluator |
| | from cotracker.datasets import kubric_movif_dataset |
| | from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_ |
| | from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss |
| |
|
| |
|
| | |
| | |
def sig_handler(signum, frame):
    """Signal handler that requeues the current SLURM job and exits.

    Prints the received signal number and the host name, then runs
    ``scontrol requeue`` for the job id found in the ``SLURM_JOB_ID``
    environment variable and terminates the process with exit status -1.

    Args:
        signum: Number of the signal that was delivered.
        frame: Current stack frame supplied by the ``signal`` module (unused).

    Raises:
        KeyError: If ``SLURM_JOB_ID`` is not set in the environment.
        SystemExit: Always, once the requeue command has been issued.
    """
    print("caught signal", signum)
    print(socket.gethostname(), "USR1 signal caught.")

    # The environment is only consulted after the two diagnostics above have
    # been printed, so they still appear when the variable is missing.
    job_id = os.environ["SLURM_JOB_ID"]
    print(f"requeuing job {job_id}")
    os.system(f"scontrol requeue {job_id}")
    sys.exit(-1)
| |
|
| |
|
def term_handler(signum, frame):
    """Signal handler that only reports the signal instead of terminating.

    Args:
        signum: Number of the signal that was delivered (unused).
        frame: Current stack frame supplied by the ``signal`` module (unused).
    """
    notice = "bypassing sigterm"
    # Flush immediately so the message is not left sitting in a buffer.
    print(notice, flush=True)
| |
|
| |
|
def fetch_optimizer(args, model):
    """Build the AdamW optimizer and the one-cycle learning-rate schedule.

    Args:
        args: Namespace providing ``lr`` (peak learning rate), ``wdecay``
            (weight decay) and ``num_steps`` (planned number of steps).
        model: Module whose ``parameters()`` are optimized.

    Returns:
        Tuple ``(optimizer, scheduler)``.
    """
    adamw = optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.wdecay,
        eps=1e-8,
    )

    # The schedule is 100 steps longer than ``num_steps``; presumably slack so
    # the scheduler is never stepped past its end -- TODO confirm.
    schedule_length = args.num_steps + 100
    one_cycle = optim.lr_scheduler.OneCycleLR(
        adamw,
        max_lr=args.lr,
        total_steps=schedule_length,
        pct_start=0.05,
        cycle_momentum=False,
        anneal_strategy="linear",
    )
    return adamw, one_cycle
| |
|
| |
|
def forward_batch(batch, model, args):
    """Run one training forward pass and compute the track and visibility losses.

    A query (frame index plus the first two trajectory coordinates at that
    frame) is built for every track. For the first ``N // 4`` tracks the query
    frame is drawn at random among the frames where the track's visibility is
    non-zero; for the remaining tracks it is the index ``torch.max`` reports
    for the visibility along the time axis.

    Args:
        batch: Object exposing ``video`` (unpacked as B, T, C, H, W with
            C == 3), ``trajectory`` (unpacked as B, T, N, D), ``visibility``
            and ``valid``.
        model: Callable invoked with ``video``, ``queries``, ``iters`` and
            ``is_train=True``; must return ``(predictions, visibility,
            train_data)`` where ``train_data`` unpacks into coordinate
            predictions, visibility predictions and a validity mask.
        args: Namespace providing ``train_iters``, ``sliding_window_len`` and
            ``sequence_len``.

    Returns:
        Dict with ``"flow"`` and ``"visibility"`` entries, each holding a
        scalar ``"loss"`` and the detached ``"predictions"`` of the first
        batch element. The visibility loss is multiplied by 10.
    """
    video = batch.video
    gt_tracks = batch.trajectory
    gt_vis = batch.visibility
    gt_valid = batch.valid

    B, T, C, H, W = video.shape
    assert C == 3
    B, T, N, D = gt_tracks.shape
    device = video.device

    _, query_frames = torch.max(gt_vis, dim=1)

    # A random visible frame is drawn for every track (keeping the number of
    # random draws independent of the split), but only the first quarter of
    # the tracks actually uses it as its query frame.
    num_random = N // 4
    for b in range(B):
        sampled = []
        for i in range(N):
            visible_frames = torch.nonzero(gt_vis[b, :, i])
            choice = torch.randint(len(visible_frames), size=(1,))
            sampled.append(visible_frames[choice])
        random_frames = torch.cat(sampled, dim=1)
        query_frames[b] = torch.cat(
            [random_frames[:, :num_random], query_frames[b : b + 1, num_random:]],
            dim=1,
        )

    # Sanity check: every track must be visible at its chosen query frame.
    frame_grid = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
    at_query_frame = frame_grid == query_frames[:, None, :]
    assert torch.allclose(
        gt_vis[at_query_frame],
        torch.ones(1, device=device),
    )

    # Gather, for every (query frame, track) pair, the ground-truth position;
    # the diagonal then keeps each track at its own query frame.
    gather_index = query_frames[:, :, None, None].repeat(1, 1, N, D)
    gathered = torch.gather(gt_tracks, 1, gather_index)
    query_xy = torch.diagonal(gathered, dim1=1, dim2=2).permute(0, 2, 1)

    queries = torch.cat([query_frames[:, :, None], query_xy[:, :, :2]], dim=2)

    predictions, visibility, train_data = model(
        video=video, queries=queries, iters=args.train_iters, is_train=True
    )
    coord_predictions, vis_predictions, valid_mask = train_data

    # Slice the ground truth into windows of ``sliding_window_len`` frames
    # that start every half window.
    window = args.sliding_window_len
    hop = window // 2
    starts = range(0, args.sequence_len - hop, hop)
    vis_gts = [gt_vis[:, s : s + window] for s in starts]
    traj_gts = [gt_tracks[:, s : s + window] for s in starts]
    valids_gts = [
        gt_valid[:, s : s + window] * valid_mask[:, s : s + window] for s in starts
    ]

    seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8)
    vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)

    return {
        "flow": {
            "predictions": predictions[0].detach(),
            "loss": seq_loss.mean(),
        },
        "visibility": {
            "loss": vis_loss.mean() * 10.0,
            "predictions": visibility[0].detach(),
        },
    }
| |
|
| |
|
def run_test_eval(evaluator, model, dataloaders, writer, step):
    """Evaluate ``model`` on every given dataset and log the averaged metrics.

    The model is switched to eval mode and is not switched back; callers
    that continue training must call ``model.train()`` themselves.

    Args:
        evaluator: Object providing ``evaluate_sequence``.
        model: Wrapped model; the network handed to the predictor is
            ``model.module.module``.
        dataloaders: Iterable of ``(dataset_name, dataloader)`` pairs.
        writer: Summary writer providing ``add_scalars``.
        step: Global step under which the metrics are logged.
    """
    model.eval()
    for ds_name, dataloader in dataloaders:
        is_tapvid = "tapvid" in ds_name

        # Per-dataset visualization cadence and support-grid size.
        if ds_name == "dynamic_replica":
            visualize_every, grid_size = 8, 0
        elif is_tapvid:
            visualize_every, grid_size = 5, 5
        else:
            visualize_every, grid_size = 1, 5

        predictor = EvaluationPredictor(
            model.module.module,
            grid_size=grid_size,
            local_grid_size=0,
            single_point=False,
            n_iters=6,
        )
        if torch.cuda.is_available():
            predictor.model = predictor.model.cuda()

        metrics = evaluator.evaluate_sequence(
            model=predictor,
            test_dataloader=dataloader,
            dataset_name=ds_name,
            train_mode=True,
            writer=writer,
            step=step,
            visualize_every=visualize_every,
        )

        if ds_name in ("dynamic_replica", "kubric"):
            # Log every averaged metric under a dataset-prefixed name.
            metrics = {
                f"{ds_name}_avg_{name}": value for name, value in metrics["avg"].items()
            }

        if is_tapvid:
            # TAP-Vid: keep only the three headline metrics, under short names.
            averages = metrics["avg"]
            metrics = {
                f"{ds_name}_avg_OA": averages["occlusion_accuracy"],
                f"{ds_name}_avg_delta": averages["average_pts_within_thresh"],
                f"{ds_name}_avg_Jaccard": averages["average_jaccard"],
            }

        writer.add_scalars(f"Eval_{ds_name}", metrics, step)
| |
|
| |
|
class Logger:
    """Accumulate training metrics and write them to TensorBoard.

    Metrics pushed through :meth:`push` are summed per key; whenever the step
    counter satisfies ``total_steps % SUM_FREQ == SUM_FREQ - 1`` the sums are
    divided by ``SUM_FREQ``, logged, written to the summary writer and reset.
    """

    # Divisor used for the running averages and period of the status reports.
    SUM_FREQ = 100

    def __init__(self, model, scheduler, log_dir=None):
        """Create the logger and open its summary writer.

        Args:
            model: Model being trained (stored, not otherwise used here).
            scheduler: Learning-rate scheduler (stored, not otherwise used here).
            log_dir: Directory for the TensorBoard event files. When ``None``
                the previous behavior is kept: the directory is
                ``<args.ckpt_path>/runs`` with ``args`` taken from the
                module-level namespace created by the script entry point.
        """
        self.model = model
        self.scheduler = scheduler
        self.total_steps = 0
        self.running_loss = {}
        self.log_dir = log_dir
        self.writer = None
        self._ensure_writer()

    def _ensure_writer(self):
        """Return the summary writer, creating it first if there is none."""
        if getattr(self, "writer", None) is None:
            log_dir = getattr(self, "log_dir", None)
            if log_dir is None:
                # Backward-compatible fallback on the script's global ``args``.
                log_dir = os.path.join(args.ckpt_path, "runs")
            self.writer = SummaryWriter(log_dir=log_dir)
        return self.writer

    def _print_training_status(self):
        """Log the running averages, write them as scalars and zero the sums."""
        averages = {
            key: self.running_loss[key] / Logger.SUM_FREQ
            for key in sorted(self.running_loss.keys())
        }
        metrics_data = list(averages.values())
        training_str = "[{:6d}] ".format(self.total_steps + 1)
        metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)

        logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")

        writer = self._ensure_writer()
        for key in self.running_loss:
            writer.add_scalar(key, self.running_loss[key] / Logger.SUM_FREQ, self.total_steps)
            self.running_loss[key] = 0.0

    def push(self, metrics, task):
        """Add one step of metrics for ``task`` and report when a period ends.

        Args:
            metrics: Mapping from metric name to numeric value.
            task: Suffix appended to every metric name as ``"<name>_<task>"``.
        """
        self.total_steps += 1

        for key in metrics:
            task_key = str(key) + "_" + task
            self.running_loss[task_key] = self.running_loss.get(task_key, 0.0) + metrics[key]

        if self.total_steps % Logger.SUM_FREQ == Logger.SUM_FREQ - 1:
            self._print_training_status()
            self.running_loss = {}

    def write_dict(self, results):
        """Write every ``name -> value`` pair of ``results`` at the current step."""
        writer = self._ensure_writer()
        for key in results:
            writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        """Close the summary writer if one is open."""
        if self.writer is not None:
            self.writer.close()
| |
|
| |
|
class Lite(LightningLite):
    """LightningLite runner whose ``run`` method trains a CoTracker2 model."""

    def run(self, args):
        """Train the model configured by ``args``.

        The method seeds the random number generators, builds the evaluation
        dataloaders, evaluator and visualizer on rank 0, creates the model,
        the Kubric training dataloader, the optimizer and the scheduler,
        resumes from the lexicographically last non-final ``.pth`` file in
        ``args.ckpt_path`` (or, if there is none, from ``args.restore_ckpt``),
        and then trains until ``total_steps`` exceeds ``args.num_steps``.
        Rank 0 logs scalars and visualizations, saves checkpoints and runs the
        evaluation at the end of an epoch according to the ``*_every_n_epoch``
        arguments, and finally writes ``<model_name>_final.pth``.

        Args:
            args: Parsed command-line namespace of the training script.

        Raises:
            ValueError: If ``args.model_name`` is not ``"cotracker"``.
        """

        def seed_everything(seed: int):
            random.seed(seed)
            os.environ["PYTHONHASHSEED"] = str(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            torch.cuda.manual_seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False

        seed_everything(0)

        def seed_worker(worker_id):
            # Derive the numpy / random seeds of each dataloader worker from
            # the torch seed of that worker.
            worker_seed = torch.initial_seed() % 2**32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        g = torch.Generator()
        g.manual_seed(0)

        # Evaluation objects exist on rank 0 only; every later use is guarded
        # by the same rank check.
        if self.global_rank == 0:
            eval_dataloaders = []
            if "dynamic_replica" in args.eval_datasets:
                eval_dataset = DynamicReplicaDataset(
                    sample_len=60, only_first_n_samples=1, rgbd_input=False
                )
                eval_dataloader_dr = torch.utils.data.DataLoader(
                    eval_dataset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=1,
                    collate_fn=collate_fn,
                )
                eval_dataloaders.append(("dynamic_replica", eval_dataloader_dr))

            if "tapvid_davis_first" in args.eval_datasets:
                data_root = os.path.join(args.dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
                eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
                eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
                    eval_dataset,
                    batch_size=1,
                    shuffle=False,
                    num_workers=1,
                    collate_fn=collate_fn,
                )
                eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))

            evaluator = Evaluator(args.ckpt_path)

            visualizer = Visualizer(
                save_dir=args.ckpt_path,
                pad_value=80,
                fps=1,
                show_first_frame=0,
                tracks_leave_trace=0,
            )

        if args.model_name == "cotracker":
            model = CoTracker2(
                stride=args.model_stride,
                window_len=args.sliding_window_len,
                add_space_attn=not args.remove_space_attn,
                num_virtual_tracks=args.num_virtual_tracks,
                model_resolution=args.crop_size,
            )
        else:
            raise ValueError(f"Model {args.model_name} doesn't exist")

        # Record the full configuration next to the checkpoints.
        with open(args.ckpt_path + "/meta.json", "w") as file:
            json.dump(vars(args), file, sort_keys=True, indent=4)

        model.cuda()

        train_dataset = kubric_movif_dataset.KubricMovifDataset(
            data_root=os.path.join(args.dataset_root, "kubric", "kubric_movi_f_tracks"),
            crop_size=args.crop_size,
            seq_len=args.sequence_len,
            traj_per_sample=args.traj_per_sample,
            sample_vis_1st_frame=args.sample_vis_1st_frame,
            use_augs=not args.dont_use_augs,
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=args.num_workers,
            worker_init_fn=seed_worker,
            generator=g,
            pin_memory=True,
            collate_fn=collate_fn_train,
            drop_last=True,
        )

        train_loader = self.setup_dataloaders(train_loader, move_to_device=False)
        print("LEN TRAIN LOADER", len(train_loader))
        optimizer, scheduler = fetch_optimizer(args, model)

        total_steps = 0
        if self.global_rank == 0:
            logger = Logger(model, scheduler)

        # Candidate checkpoints to resume from. ``os.listdir`` yields bare
        # names, so the directory test must be made on the path joined with
        # the checkpoint folder rather than on the name alone (which would be
        # resolved against the current working directory).
        folder_ckpts = [
            f
            for f in os.listdir(args.ckpt_path)
            if f.endswith(".pth")
            and "final" not in f
            and not os.path.isdir(os.path.join(args.ckpt_path, f))
        ]
        if folder_ckpts:
            # File names embed the zero-padded step count, so the
            # lexicographically last name is taken as the newest checkpoint.
            ckpt_path = sorted(folder_ckpts)[-1]
            ckpt = self.load(os.path.join(args.ckpt_path, ckpt_path))
            logging.info(f"Loading checkpoint {ckpt_path}")
            if "model" in ckpt:
                model.load_state_dict(ckpt["model"])
            else:
                model.load_state_dict(ckpt)
            if "optimizer" in ckpt:
                logging.info("Load optimizer")
                optimizer.load_state_dict(ckpt["optimizer"])
            if "scheduler" in ckpt:
                logging.info("Load scheduler")
                scheduler.load_state_dict(ckpt["scheduler"])
            if "total_steps" in ckpt:
                total_steps = ckpt["total_steps"]
                logging.info(f"Load total_steps {total_steps}")

        elif args.restore_ckpt is not None:
            assert args.restore_ckpt.endswith((".pth", ".pt"))
            logging.info("Loading checkpoint...")

            state_dict = self.load(args.restore_ckpt)
            if "model" in state_dict:
                state_dict = state_dict["model"]

            # Drop the "module." wrapper prefix from the parameter names.
            if list(state_dict.keys())[0].startswith("module."):
                state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
            model.load_state_dict(state_dict, strict=True)

            logging.info("Done loading checkpoint")
        model, optimizer = self.setup(model, optimizer, move_to_device=False)

        model.train()

        save_freq = args.save_freq
        scaler = GradScaler(enabled=args.mixed_precision)

        should_keep_training = True
        epoch = -1

        while should_keep_training:
            epoch += 1
            for i_batch, batch in enumerate(tqdm(train_loader)):
                batch, gotit = batch
                if not all(gotit):
                    # NOTE(review): this skips the barriers below on this rank
                    # only -- confirm it cannot desynchronize the ranks.
                    print("batch is None")
                    continue
                dataclass_to_cuda_(batch)

                optimizer.zero_grad()

                assert model.training

                output = forward_batch(batch, model, args)

                loss = 0
                for k, v in output.items():
                    if "loss" in v:
                        loss += v["loss"]

                if self.global_rank == 0:
                    for k, v in output.items():
                        if "loss" in v:
                            logger.writer.add_scalar(
                                f"live_{k}_loss", v["loss"].item(), total_steps
                            )
                        if "metrics" in v:
                            logger.push(v["metrics"], k)
                    if total_steps % save_freq == save_freq - 1:
                        visualizer.visualize(
                            video=batch.video.clone(),
                            tracks=batch.trajectory.clone(),
                            filename="train_gt_traj",
                            writer=logger.writer,
                            step=total_steps,
                        )

                        visualizer.visualize(
                            video=batch.video.clone(),
                            tracks=output["flow"]["predictions"][None],
                            filename="train_pred_traj",
                            writer=logger.writer,
                            step=total_steps,
                        )

                    if len(output) > 1:
                        logger.writer.add_scalar("live_total_loss", loss.item(), total_steps)
                    logger.writer.add_scalar(
                        "learning_rate", optimizer.param_groups[0]["lr"], total_steps
                    )

                self.barrier()

                self.backward(scaler.scale(loss))

                # Gradients are unscaled before clipping so the threshold
                # applies to their true magnitude.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)

                scaler.step(optimizer)
                scheduler.step()
                scaler.update()
                total_steps += 1
                if self.global_rank == 0:
                    # End of an epoch, or the very first step when an initial
                    # validation was requested.
                    if (i_batch >= len(train_loader) - 1) or (
                        total_steps == 1 and args.validate_at_start
                    ):
                        if (epoch + 1) % args.save_every_n_epoch == 0:
                            ckpt_iter = f"{total_steps:06d}"
                            save_path = Path(
                                f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth"
                            )

                            save_dict = {
                                "model": model.module.module.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "scheduler": scheduler.state_dict(),
                                "total_steps": total_steps,
                            }

                            logging.info(f"Saving file {save_path}")
                            self.save(save_dict, save_path)

                        if (epoch + 1) % args.evaluate_every_n_epoch == 0 or (
                            args.validate_at_start and epoch == 0
                        ):
                            run_test_eval(
                                evaluator,
                                model,
                                eval_dataloaders,
                                logger.writer,
                                total_steps,
                            )
                            # run_test_eval leaves the model in eval mode.
                            model.train()
                            torch.cuda.empty_cache()

                self.barrier()
                if total_steps > args.num_steps:
                    should_keep_training = False
                    break
        if self.global_rank == 0:
            print("FINISHED TRAINING")

            PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
            torch.save(model.module.module.state_dict(), PATH)
            run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
            logger.close()
| |
|
| |
|
# Script entry point: install the signal handlers, parse the command line,
# create the checkpoint directory and launch distributed training.
if __name__ == "__main__":
    signal.signal(signal.SIGUSR1, sig_handler)
    signal.signal(signal.SIGTERM, term_handler)
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", default="cotracker", help="model name")
    parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
    # Required: the directory is created below and used throughout training,
    # so a missing value could only fail later with an opaque TypeError.
    parser.add_argument("--ckpt_path", required=True, help="path to save checkpoints")
    parser.add_argument(
        "--batch_size", type=int, default=4, help="batch size used during training."
    )
    parser.add_argument("--num_nodes", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=10, help="number of dataloader workers")

    parser.add_argument("--mixed_precision", action="store_true", help="use mixed precision")
    parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
    parser.add_argument("--wdecay", type=float, default=0.00001, help="Weight decay in optimizer.")
    parser.add_argument(
        "--num_steps", type=int, default=200000, help="length of training schedule."
    )
    parser.add_argument(
        "--evaluate_every_n_epoch",
        type=int,
        default=1,
        help="evaluate during training after every n epochs, after every epoch by default",
    )
    parser.add_argument(
        "--save_every_n_epoch",
        type=int,
        default=1,
        help="save checkpoints during training after every n epochs, after every epoch by default",
    )
    parser.add_argument(
        "--validate_at_start",
        action="store_true",
        help="whether to run evaluation before training starts",
    )
    parser.add_argument(
        "--save_freq",
        type=int,
        default=100,
        help="frequency of trajectory visualization during training",
    )
    parser.add_argument(
        "--traj_per_sample",
        type=int,
        default=768,
        help="the number of trajectories to sample for training",
    )
    # Required: the training dataset path is always derived from it.
    parser.add_argument(
        "--dataset_root",
        type=str,
        required=True,
        help="path to all the datasets (train and eval)",
    )

    parser.add_argument(
        "--train_iters",
        type=int,
        default=4,
        help="number of model update iterations in each training forward pass.",
    )
    parser.add_argument("--sequence_len", type=int, default=8, help="train sequence length")
    parser.add_argument(
        "--eval_datasets",
        nargs="+",
        default=["tapvid_davis_first"],
        help="what datasets to use for evaluation",
    )

    parser.add_argument(
        "--remove_space_attn",
        action="store_true",
        help="remove space attention from CoTracker",
    )
    parser.add_argument(
        "--num_virtual_tracks",
        type=int,
        default=None,
        help="number of virtual tracks passed to CoTracker",
    )
    parser.add_argument(
        "--dont_use_augs",
        action="store_true",
        help="don't apply augmentations during training",
    )
    parser.add_argument(
        "--sample_vis_1st_frame",
        action="store_true",
        help="only sample trajectories with points visible on the first frame",
    )
    parser.add_argument(
        "--sliding_window_len",
        type=int,
        default=8,
        help="length of the CoTracker sliding window",
    )
    parser.add_argument(
        "--model_stride",
        type=int,
        default=8,
        help="stride of the CoTracker feature network",
    )
    parser.add_argument(
        "--crop_size",
        type=int,
        nargs="+",
        default=[384, 512],
        help="crop videos to this resolution during training",
    )
    parser.add_argument(
        "--eval_max_seq_len",
        type=int,
        default=1000,
        help="maximum length of evaluation videos",
    )
    args = parser.parse_args()
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
    )

    Path(args.ckpt_path).mkdir(exist_ok=True, parents=True)
    from pytorch_lightning.strategies import DDPStrategy

    Lite(
        strategy=DDPStrategy(find_unused_parameters=False),
        devices="auto",
        accelerator="gpu",
        precision=32,
        num_nodes=args.num_nodes,
    ).run(args)
| |
|