Merge branch 'main' of github.com:JunkyByte/co-tracker

README.md (41 changed lines)
@@ -1,6 +1,6 @@
 # CoTracker: It is Better to Track Together
 
-**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**
+**[Meta AI Research, GenAI](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**
 
 [Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)
 
@@ -15,7 +15,7 @@
 **CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
 
 CoTracker can track:
-- **Every pixel** within a video
+- **Every pixel** in a video
 - Points sampled on a regular grid on any video frame
 - Manually selected points
 
@@ -26,16 +26,30 @@ Try these tracking modes for yourself with our [Colab demo](https://colab.resear
 ## Installation Instructions
 Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation. We strongly recommend installing both PyTorch and TorchVision with CUDA support.
 
-## Steps to Install CoTracker and its dependencies:
+### Pretrained models via PyTorch Hub
+The easiest way to use CoTracker is to load a pretrained model from torch.hub:
+```
+pip install einops timm tqdm
+```
+```
+import torch
+import timm
+import einops
+import tqdm
+
+cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")
+```
+Another option is to install it from this GitHub repo. That's the best way if you need to run our demo or evaluate / train CoTracker:
+### Steps to Install CoTracker and its dependencies:
 ```
 git clone https://github.com/facebookresearch/co-tracker
 cd co-tracker
 pip install -e .
 pip install opencv-python einops timm matplotlib moviepy flow_vis
 ```
 
 
-## Model Weights Download:
+### Download Model Weights:
 ```
 mkdir checkpoints
 cd checkpoints
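
For context on the new torch.hub path added above, here is a minimal usage sketch. It assumes a video tensor of shape (B, T, 3, H, W) and the predictor interface exercised by demo.py later in this diff (grid sampling via `grid_size`, a `(pred_tracks, pred_visibility)` return); treat the exact preprocessing as an assumption rather than the official recipe.

```
import torch

# Load the pretrained predictor through the new hub entrypoint (defined in hubconf.py below).
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")

# Hypothetical input: one 24-frame RGB clip as a float tensor of shape (B, T, 3, H, W).
video = torch.randint(0, 255, (1, 24, 3, 384, 512)).float()

# Same CPU/GPU fallback this commit introduces elsewhere in the repo.
if torch.cuda.is_available():
    cotracker, video = cotracker.cuda(), video.cuda()

# Track points sampled on a regular grid on the first frame.
pred_tracks, pred_visibility = cotracker(video, grid_size=10)
print(pred_tracks.shape)      # (B, T, N, 2) point coordinates
print(pred_visibility.shape)  # (B, T, N) per-point visibility
```
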
@@ -60,24 +74,26 @@ To reproduce the results presented in the paper, download the following datasets
 
 And install the necessary dependencies:
 ```
-pip install hydra-core==1.1.0 mediapy tensorboard
+pip install hydra-core==1.1.0 mediapy
 ```
 Then, execute the following command to evaluate on BADJA:
 ```
 python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path
 ```
 By default, evaluation will be slow since it is done for one target point at a time, which ensures robustness and fairness, as described in the paper.
 
 ## Training
 To train the CoTracker as described in our paper, you first need to generate annotations for [Google Kubric](https://github.com/google-research/kubric) MOVI-f dataset. Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet).
 
 Once you have the annotated dataset, you need to make sure you followed the steps for evaluation setup and install the training dependencies:
 ```
-pip install pytorch_lightning==1.6.0
+pip install pytorch_lightning==1.6.0 tensorboard
 ```
- launch training on Kubric. Our model was trained using 32 GPUs, and you can adjust the parameters to best suit your hardware setup.
+Now you can launch training on Kubric. Our model was trained for 50000 iterations on 32 GPUs (4 nodes with 8 GPUs).
+Modify *dataset_root* and *ckpt_path* accordingly before running this command:
 ```
 python train.py --batch_size 1 --num_workers 28 \
---num_steps 50000 --ckpt_path ./ --model_name cotracker \
+--num_steps 50000 --ckpt_path ./ --dataset_root ./datasets --model_name cotracker \
 --save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \
 --traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \
 --save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4
@@ -86,13 +102,16 @@ python train.py --batch_size 1 --num_workers 28 \
 ## License
 The majority of CoTracker is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license, TAP-Vid is licensed under the Apache 2.0 license.
 
 ## Acknowledgments
 We would like to thank [PIPs](https://github.com/aharley/pips) and [TAP-Vid](https://github.com/deepmind/tapnet) for publicly releasing their code and data. We also want to thank [Luke Melas-Kyriazi](https://lukemelas.github.io/) for proofreading the paper, [Jianyuan Wang](https://jytime.github.io/), [Roman Shapovalov](https://shapovalov.ro/) and [Adam W. Harley](https://adamharley.com/) for the insightful discussions.
 
 ## Citing CoTracker
 If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:
 ```
 @article{karaev2023cotracker,
   title={CoTracker: It is Better to Track Together},
   author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
-  journal={arxiv},
+  journal={arXiv:2307.07635},
   year={2023}
 }
 ```
@@ -185,7 +185,11 @@ class Evaluator:
                 if not all(gotit):
                     print("batch is None")
                     continue
-            dataclass_to_cuda_(sample)
+            if torch.cuda.is_available():
+                dataclass_to_cuda_(sample)
+                device = torch.device("cuda")
+            else:
+                device = torch.device("cpu")
 
             if (
                 not train_mode
@@ -205,7 +209,7 @@ class Evaluator:
                         queries[:, :, 1],
                     ],
                     dim=2,
-                )
+                ).to(device)
             else:
                 queries = torch.cat(
                     [
@@ -213,7 +217,7 @@ class Evaluator:
                         sample.trajectory[:, 0],
                     ],
                     dim=2,
-                )
+                ).to(device)
 
             pred_tracks = model(sample.video, queries)
             if "strided" in dataset_name:
@@ -102,6 +102,8 @@ def run_eval(cfg: DefaultConfig):
         single_point=cfg.single_point,
         n_iters=cfg.n_iters,
     )
+    if torch.cuda.is_available():
+        predictor.model = predictor.model.cuda()
 
     # Setting the random seeds
     torch.manual_seed(cfg.seed)
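
The Evaluator and run_eval changes above replace hard-coded `.cuda()` calls with a runtime check, so evaluation also runs on CPU-only machines. A generic sketch of the same idiom, using a placeholder module and batch rather than CoTracker APIs:

```
import torch

# Pick the device once, then move both model and data to it.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model = torch.nn.Linear(8, 2).to(device)   # placeholder module
batch = torch.randn(4, 8, device=device)   # placeholder batch

with torch.no_grad():
    out = model(batch)                     # runs on GPU when available, CPU otherwise
print(out.device)
```
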
@@ -12,6 +12,8 @@ from cotracker.models.core.cotracker.cotracker import CoTracker
 def build_cotracker(
     checkpoint: str,
 ):
+    if checkpoint is None:
+        return build_cotracker_stride_4_wind_8()
     model_name = checkpoint.split("/")[-1].split(".")[0]
     if model_name == "cotracker_stride_4_wind_8":
         return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
@@ -25,11 +25,11 @@ from cotracker.models.core.embeddings import (
 torch.manual_seed(0)
 
 
-def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device='cuda'):
+def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"):
     if grid_size == 1:
-        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
+        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
             None, None
-        ].to(device)
+        ]
 
     grid_y, grid_x = meshgrid2d(
         1, grid_size, grid_size, stack=False, norm=False, device=device
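
get_points_on_a_grid now allocates its single-point tensor directly on the requested device instead of building it on the CPU and moving it with `.to(device)`, so nothing in the helper assumes CUDA. A standalone illustration of the two idioms (not CoTracker code; the coordinate values are arbitrary):

```
import torch

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Old style: allocate on the CPU, then copy to the target device.
a = torch.tensor([256.0, 192.0])[None, None].to(device)

# New style: allocate on the target device in one step, with no intermediate CPU tensor.
b = torch.tensor([256.0, 192.0], device=device)[None, None]

assert torch.equal(a, b) and a.device == b.device
```
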
@@ -29,11 +29,10 @@ class EvaluationPredictor(torch.nn.Module):
         self.n_iters = n_iters
 
         self.model = cotracker_model
-        self.model.to("cuda")
         self.model.eval()
 
     def forward(self, video, queries):
-        queries = queries.clone().cuda()
+        queries = queries.clone()
         B, T, C, H, W = video.shape
         B, N, D = queries.shape
@@ -42,14 +41,16 @@ class EvaluationPredictor(torch.nn.Module):
 
         rgbs = video.reshape(B * T, C, H, W)
         rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
-        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
+        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
+
+        device = rgbs.device
 
         queries[:, :, 1] *= self.interp_shape[1] / W
         queries[:, :, 2] *= self.interp_shape[0] / H
 
         if self.single_point:
-            traj_e = torch.zeros((B, T, N, 2)).cuda()
-            vis_e = torch.zeros((B, T, N)).cuda()
+            traj_e = torch.zeros((B, T, N, 2), device=device)
+            vis_e = torch.zeros((B, T, N), device=device)
             for pind in range((N)):
                 query = queries[:, pind : pind + 1]
@@ -60,8 +61,10 @@ class EvaluationPredictor(torch.nn.Module):
                 vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
         else:
             if self.grid_size > 0:
-                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
+                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
+                    device
+                )  #
                 queries = torch.cat([queries, xy], dim=1)  #
 
             traj_e, __, vis_e, __ = self.model(
@@ -91,8 +94,8 @@ class EvaluationPredictor(torch.nn.Module):
             query = torch.cat([query, xy_target], dim=1).to(device)  #
 
         if self.grid_size > 0:
-            xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
+            xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)  #
             query = torch.cat([query, xy], dim=1).to(device)  #
         # crop the video to start from the queried frame
         query[0, 0, 0] = 0
@@ -25,8 +25,6 @@ class CoTrackerPredictor(torch.nn.Module):
         model = build_cotracker(checkpoint)
 
         self.model = model
-        self.device = device or 'cuda'
-        self.model.to(self.device)
         self.model.eval()
 
     @torch.no_grad()
@@ -73,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module):
         grid_width = W // grid_step
         grid_height = H // grid_step
         tracks = visibilities = None
-        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
+        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
         grid_pts[0, :, 0] = grid_query_frame
         for offset in tqdm(range(grid_step * grid_step)):
             ox = offset % grid_step
@@ -108,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module):
         assert B == 1
 
         video = video.reshape(B * T, C, H, W)
-        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
-        video = video.reshape(
-            B, T, 3, self.interp_shape[0], self.interp_shape[1]
-        ).to(self.device)
+        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
+        video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
 
         if queries is not None:
             queries = queries.clone()
@@ -120,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module):
             queries[:, :, 1] *= self.interp_shape[1] / W
             queries[:, :, 2] *= self.interp_shape[0] / H
         elif grid_size > 0:
-            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
+            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
             if segm_mask is not None:
                 segm_mask = F.interpolate(
                     segm_mask, tuple(self.interp_shape), mode="nearest"
@@ -14,7 +14,6 @@ from matplotlib import cm
 import torch.nn.functional as F
 import torchvision.transforms as transforms
 from moviepy.editor import ImageSequenceClip
-from torch.utils.tensorboard import SummaryWriter
 import matplotlib.pyplot as plt
 
 
@@ -67,7 +66,7 @@ class Visualizer:
         gt_tracks: torch.Tensor = None,  # (B,T,N,2)
         segm_mask: torch.Tensor = None,  # (B,1,H,W)
         filename: str = "video",
-        writer: SummaryWriter = None,
+        writer=None,  # tensorboard Summary Writer, used for visualization during training
         step: int = 0,
         query_frame: int = 0,
         save_video: bool = True,
demo.py (12 changed lines)
@@ -32,11 +32,6 @@ if __name__ == "__main__":
         default="./checkpoints/cotracker_stride_4_wind_8.pth",
         help="cotracker model",
     )
-    parser.add_argument(
-        "--device",
-        default="cuda",
-        help="Device to use for inference",
-    )
     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
     parser.add_argument(
         "--grid_query_frame",
@@ -59,7 +54,12 @@ if __name__ == "__main__":
     segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
     segm_mask = torch.from_numpy(segm_mask)[None, None]
 
-    model = CoTrackerPredictor(checkpoint=args.checkpoint, device=args.device)
+    model = CoTrackerPredictor(checkpoint=args.checkpoint)
+    if torch.cuda.is_available():
+        model = model.cuda()
+        video = video.cuda()
+    else:
+        print("CUDA is not available!")
 
     pred_tracks, pred_visibility = model(
         video,
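
The updated demo builds the predictor without a --device flag and moves it to CUDA only when available. A condensed sketch of that flow (the checkpoint path matches the demo's default; the video tensor and the grid_size / grid_query_frame values are illustrative assumptions):

```
import torch
from cotracker.predictor import CoTrackerPredictor

model = CoTrackerPredictor(checkpoint="./checkpoints/cotracker_stride_4_wind_8.pth")

video = torch.randint(0, 255, (1, 24, 3, 384, 512)).float()  # placeholder clip, (B, T, 3, H, W)
if torch.cuda.is_available():
    model = model.cuda()
    video = video.cuda()

pred_tracks, pred_visibility = model(
    video,
    grid_size=10,        # sample a regular grid of query points
    grid_query_frame=0,  # start the grid on the first frame
)
```
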
							
								
								
									
hubconf.py (32 lines, new file)
@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+dependencies = ["torch", "einops", "timm", "tqdm"]
+
+_COTRACKER_URL = (
+    "https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
+)
+
+
+def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs):
+    from cotracker.predictor import CoTrackerPredictor
+
+    predictor = CoTrackerPredictor(checkpoint=None)
+    if pretrained:
+        state_dict = torch.hub.load_state_dict_from_url(
+            _COTRACKER_URL, map_location="cpu"
+        )
+        predictor.model.load_state_dict(state_dict)
+    return predictor
+
+
+def cotracker_w8(*, pretrained: bool = True, **kwargs):
+    """
+    CoTracker model with stride 4 and window length 8. (The main model from the paper)
+    """
+    return _make_cotracker_predictor(pretrained=pretrained, **kwargs)
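
hubconf.py is what makes the `torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")` call from the README work: `dependencies` lists the packages torch.hub verifies are importable, and each non-underscore top-level function is an entrypoint. If you already have a clone of the repo, the same entrypoint can be loaded without hitting GitHub using torch.hub's standard local mode; a sketch, assuming the clone lives at ./co-tracker:

```
import torch

# Load the entrypoint defined in hubconf.py from a local checkout of the repo.
cotracker = torch.hub.load("./co-tracker", "cotracker_w8", source="local")

# pretrained=False is forwarded to the entrypoint and skips downloading the checkpoint.
cotracker_untrained = torch.hub.load(
    "./co-tracker", "cotracker_w8", source="local", pretrained=False
)
```
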
										
											
(File diff suppressed because one or more lines are too long.)
									
								
								train.py
									
									
									
									
									
								
							
							
						
						
									
										111
									
								
								train.py
									
									
									
									
									
								
@@ -36,21 +36,6 @@ from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_
 from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
 
 
-# define the handler function
-# for training on a slurm cluster
-def sig_handler(signum, frame):
-    print("caught signal", signum)
-    print(socket.gethostname(), "USR1 signal caught.")
-    # do other stuff to cleanup here
-    print("requeuing job " + os.environ["SLURM_JOB_ID"])
-    os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
-    sys.exit(-1)
-
-
-def term_handler(signum, frame):
-    print("bypassing sigterm", flush=True)
-
-
 def fetch_optimizer(args, model):
     """Create the optimizer and learning rate scheduler"""
     optimizer = optim.AdamW(
@@ -153,6 +138,8 @@ def run_test_eval(evaluator, model, dataloaders, writer, step):
             single_point=False,
             n_iters=6,
         )
+        if torch.cuda.is_available():
+            predictor.model = predictor.model.cuda()
 
         metrics = evaluator.evaluate_sequence(
             model=predictor,
@@ -302,9 +289,7 @@ class Lite(LightningLite):
             eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
 
         if "tapvid_davis_first" in args.eval_datasets:
-            data_root = os.path.join(
-                args.dataset_root, "/tapvid_davis/tapvid_davis.pkl"
-            )
+            data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl")
             eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
             eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
                 eval_dataset,
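
The data_root change above is a genuine bug fix, not just a reflow: os.path.join discards everything before a component that starts with "/", so the configured dataset root was silently ignored. A quick illustration:

```
import os.path

# A leading slash on the second component makes os.path.join drop the root entirely.
print(os.path.join("./datasets", "/tapvid_davis/tapvid_davis.pkl"))
# -> /tapvid_davis/tapvid_davis.pkl

# Without the leading slash the two parts are joined as intended.
print(os.path.join("./datasets", "tapvid_davis/tapvid_davis.pkl"))
# -> ./datasets/tapvid_davis/tapvid_davis.pkl
```
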
@@ -551,17 +536,15 @@ class Lite(LightningLite):
 
 
 if __name__ == "__main__":
-    signal.signal(signal.SIGUSR1, sig_handler)
-    signal.signal(signal.SIGTERM, term_handler)
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_name", default="cotracker", help="model name")
-    parser.add_argument("--restore_ckpt", help="restore checkpoint")
-    parser.add_argument("--ckpt_path", help="restore checkpoint")
+    parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
+    parser.add_argument("--ckpt_path", help="path to save checkpoints")
     parser.add_argument(
         "--batch_size", type=int, default=4, help="batch size used during training."
     )
     parser.add_argument(
-        "--num_workers", type=int, default=6, help="left right consistency loss"
+        "--num_workers", type=int, default=6, help="number of dataloader workers"
     )
 
     parser.add_argument(
@@ -578,20 +561,34 @@ if __name__ == "__main__":
         "--evaluate_every_n_epoch",
         type=int,
         default=1,
-        help="number of flow-field updates during validation forward pass",
+        help="evaluate during training after every n epochs, after every epoch by default",
     )
     parser.add_argument(
         "--save_every_n_epoch",
         type=int,
         default=1,
-        help="number of flow-field updates during validation forward pass",
+        help="save checkpoints during training after every n epochs, after every epoch by default",
     )
     parser.add_argument(
-        "--validate_at_start", action="store_true", help="use mixed precision"
+        "--validate_at_start",
+        action="store_true",
+        help="whether to run evaluation before training starts",
     )
-    parser.add_argument("--save_freq", type=int, default=100, help="save_freq")
-    parser.add_argument("--traj_per_sample", type=int, default=768, help="save_freq")
-    parser.add_argument("--dataset_root", type=str, help="path lo all the datasets")
+    parser.add_argument(
+        "--save_freq",
+        type=int,
+        default=100,
+        help="frequency of trajectory visualization during training",
+    )
+    parser.add_argument(
+        "--traj_per_sample",
+        type=int,
+        default=768,
+        help="the number of trajectories to sample for training",
+    )
+    parser.add_argument(
+        "--dataset_root", type=str, help="path to all the datasets (train and eval)"
+    )
 
     parser.add_argument(
         "--train_iters",
@@ -605,49 +602,75 @@ if __name__ == "__main__":
     parser.add_argument(
         "--eval_datasets",
         nargs="+",
-        default=["things", "badja", "fastcapture"],
-        help="eval datasets.",
+        default=["things", "badja"],
+        help="what datasets to use for evaluation",
     )
 
     parser.add_argument(
-        "--remove_space_attn", action="store_true", help="use mixed precision"
+        "--remove_space_attn",
+        action="store_true",
+        help="remove space attention from CoTracker",
     )
     parser.add_argument(
-        "--dont_use_augs", action="store_true", help="use mixed precision"
+        "--dont_use_augs",
+        action="store_true",
+        help="don't apply augmentations during training",
     )
     parser.add_argument(
-        "--sample_vis_1st_frame", action="store_true", help="use mixed precision"
+        "--sample_vis_1st_frame",
+        action="store_true",
+        help="only sample trajectories with points visible on the first frame",
    )
     parser.add_argument(
-        "--sliding_window_len", type=int, default=8, help="use mixed precision"
+        "--sliding_window_len",
+        type=int,
+        default=8,
+        help="length of the CoTracker sliding window",
     )
     parser.add_argument(
-        "--updateformer_hidden_size", type=int, default=384, help="use mixed precision"
+        "--updateformer_hidden_size",
+        type=int,
+        default=384,
+        help="hidden dimension of the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_num_heads", type=int, default=8, help="use mixed precision"
+        "--updateformer_num_heads",
+        type=int,
+        default=8,
+        help="number of heads of the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_space_depth", type=int, default=12, help="use mixed precision"
+        "--updateformer_space_depth",
+        type=int,
+        default=12,
+        help="number of group attention layers in the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_time_depth", type=int, default=12, help="use mixed precision"
+        "--updateformer_time_depth",
+        type=int,
+        default=12,
+        help="number of time attention layers in the CoTracker transformer model",
     )
     parser.add_argument(
-        "--model_stride", type=int, default=8, help="use mixed precision"
+        "--model_stride",
+        type=int,
+        default=8,
+        help="stride of the CoTracker feature network",
     )
     parser.add_argument(
         "--crop_size",
         type=int,
         nargs="+",
         default=[384, 512],
-        help="use mixed precision",
+        help="crop videos to this resolution during training",
     )
     parser.add_argument(
-        "--eval_max_seq_len", type=int, default=1000, help="use mixed precision"
+        "--eval_max_seq_len",
+        type=int,
+        default=1000,
+        help="maximum length of evaluation videos",
     )
     args = parser.parse_args()
 
     logging.basicConfig(
         level=logging.INFO,
         format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
@@ -661,5 +684,5 @@ if __name__ == "__main__":
         devices="auto",
         accelerator="gpu",
         precision=32,
-        num_nodes=4,
+        # num_nodes=4,
     ).run(args)