Merge branch 'main' of github.com:JunkyByte/co-tracker

commit 03f3c41e07

README.md (41 changes)

@@ -1,6 +1,6 @@
 # CoTracker: It is Better to Track Together
 
-**[Meta AI Research, FAIR](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**
+**[Meta AI Research, GenAI](https://ai.facebook.com/research/)**; **[University of Oxford, VGG](https://www.robots.ox.ac.uk/~vgg/)**
 
 [Nikita Karaev](https://nikitakaraevv.github.io/), [Ignacio Rocco](https://www.irocco.info/), [Benjamin Graham](https://ai.facebook.com/people/benjamin-graham/), [Natalia Neverova](https://nneverova.github.io/), [Andrea Vedaldi](https://www.robots.ox.ac.uk/~vedaldi/), [Christian Rupprecht](https://chrirupp.github.io/)
 
@@ -15,7 +15,7 @@
 **CoTracker** is a fast transformer-based model that can track any point in a video. It brings to tracking some of the benefits of Optical Flow.
 
 CoTracker can track:
-- **Every pixel** within a video
+- **Every pixel** in a video
 - Points sampled on a regular grid on any video frame
 - Manually selected points
 
@@ -26,16 +26,30 @@ Try these tracking modes for yourself with our [Colab demo](https://colab.resear
 ## Installation Instructions
 Ensure you have both PyTorch and TorchVision installed on your system. Follow the instructions [here](https://pytorch.org/get-started/locally/) for the installation. We strongly recommend installing both PyTorch and TorchVision with CUDA support.
 
-## Steps to Install CoTracker and its dependencies:
+### Pretrained models via PyTorch Hub
+The easiest way to use CoTracker is to load a pretrained model from torch.hub:
+```
+pip install einops timm tqdm
+```
+```
+import torch
+import timm
+import einops
+import tqdm
+
+cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")
+```
+Another option is to install it from this GitHub repo. That's the best way if you need to run our demo or evaluate / train CoTracker:
+### Steps to Install CoTracker and its dependencies:
 ```
 git clone https://github.com/facebookresearch/co-tracker
 cd co-tracker
 pip install -e .
 pip install opencv-python einops timm matplotlib moviepy flow_vis
 ```
 
 
-## Model Weights Download:
+### Download Model Weights:
 ```
 mkdir checkpoints
 cd checkpoints
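Editor's note: a minimal sketch of how the torch.hub entry point documented above might be exercised end to end. The dummy video tensor and the `grid_size` value are assumptions; the call signature and the returned `(tracks, visibility)` pair follow the `CoTrackerPredictor` and `demo.py` changes later in this commit.

```
import torch

# Load the predictor exposed by the new hubconf.py entry point.
cotracker = torch.hub.load("facebookresearch/co-tracker", "cotracker_w8")

# Dummy clip shaped (B, T, C, H, W); replace with a real video tensor.
video = torch.zeros(1, 24, 3, 384, 512)

# Optional GPU path, mirroring the CPU fallback this commit adds to demo.py.
if torch.cuda.is_available():
    cotracker, video = cotracker.cuda(), video.cuda()

pred_tracks, pred_visibility = cotracker(video, grid_size=10)
print(pred_tracks.shape, pred_visibility.shape)  # point tracks and per-frame visibility
```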
@@ -60,24 +74,26 @@ To reproduce the results presented in the paper, download the following datasets
 And install the necessary dependencies:
 ```
-pip install hydra-core==1.1.0 mediapy tensorboard
+pip install hydra-core==1.1.0 mediapy
 ```
 Then, execute the following command to evaluate on BADJA:
 ```
 python ./cotracker/evaluation/evaluate.py --config-name eval_badja exp_dir=./eval_outputs dataset_root=your/badja/path
 ```
 By default, evaluation will be slow since it is done for one target point at a time, which ensures robustness and fairness, as described in the paper.
 
 ## Training
 To train the CoTracker as described in our paper, you first need to generate annotations for [Google Kubric](https://github.com/google-research/kubric) MOVI-f dataset. Instructions for annotation generation can be found [here](https://github.com/deepmind/tapnet).
 
 Once you have the annotated dataset, you need to make sure you followed the steps for evaluation setup and install the training dependencies:
 ```
-pip install pytorch_lightning==1.6.0
+pip install pytorch_lightning==1.6.0 tensorboard
 ```
-launch training on Kubric. Our model was trained using 32 GPUs, and you can adjust the parameters to best suit your hardware setup.
+Now you can launch training on Kubric. Our model was trained for 50000 iterations on 32 GPUs (4 nodes with 8 GPUs).
+Modify *dataset_root* and *ckpt_path* accordingly before running this command:
 ```
 python train.py --batch_size 1 --num_workers 28 \
---num_steps 50000 --ckpt_path ./ --model_name cotracker \
+--num_steps 50000 --ckpt_path ./ --dataset_root ./datasets --model_name cotracker \
 --save_freq 200 --sequence_len 24 --eval_datasets tapvid_davis_first badja \
 --traj_per_sample 256 --sliding_window_len 8 --updateformer_space_depth 6 --updateformer_time_depth 6 \
 --save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4
@@ -86,13 +102,16 @@ python train.py --batch_size 1 --num_workers 28 \
 ## License
 The majority of CoTracker is licensed under CC-BY-NC, however portions of the project are available under separate license terms: Particle Video Revisited is licensed under the MIT license, TAP-Vid is licensed under the Apache 2.0 license.
 
 ## Acknowledgments
 We would like to thank [PIPs](https://github.com/aharley/pips) and [TAP-Vid](https://github.com/deepmind/tapnet) for publicly releasing their code and data. We also want to thank [Luke Melas-Kyriazi](https://lukemelas.github.io/) for proofreading the paper, [Jianyuan Wang](https://jytime.github.io/), [Roman Shapovalov](https://shapovalov.ro/) and [Adam W. Harley](https://adamharley.com/) for the insightful discussions.
 
 ## Citing CoTracker
 If you find our repository useful, please consider giving it a star ⭐ and citing our paper in your work:
 ```
 @article{karaev2023cotracker,
 title={CoTracker: It is Better to Track Together},
 author={Nikita Karaev and Ignacio Rocco and Benjamin Graham and Natalia Neverova and Andrea Vedaldi and Christian Rupprecht},
-journal={arxiv},
+journal={arXiv:2307.07635},
 year={2023}
 }
 ```
@@ -185,7 +185,11 @@ class Evaluator:
             if not all(gotit):
                 print("batch is None")
                 continue
-            dataclass_to_cuda_(sample)
+            if torch.cuda.is_available():
+                dataclass_to_cuda_(sample)
+                device = torch.device("cuda")
+            else:
+                device = torch.device("cpu")
 
             if (
                 not train_mode
@@ -205,7 +209,7 @@ class Evaluator:
                         queries[:, :, 1],
                     ],
                     dim=2,
-                )
+                ).to(device)
             else:
                 queries = torch.cat(
                     [
@@ -213,7 +217,7 @@ class Evaluator:
                         sample.trajectory[:, 0],
                     ],
                     dim=2,
-                )
+                ).to(device)
 
             pred_tracks = model(sample.video, queries)
             if "strided" in dataset_name:
@@ -102,6 +102,8 @@ def run_eval(cfg: DefaultConfig):
         single_point=cfg.single_point,
         n_iters=cfg.n_iters,
     )
+    if torch.cuda.is_available():
+        predictor.model = predictor.model.cuda()
 
     # Setting the random seeds
     torch.manual_seed(cfg.seed)
@@ -12,6 +12,8 @@ from cotracker.models.core.cotracker.cotracker import CoTracker
 def build_cotracker(
     checkpoint: str,
 ):
+    if checkpoint is None:
+        return build_cotracker_stride_4_wind_8()
     model_name = checkpoint.split("/")[-1].split(".")[0]
     if model_name == "cotracker_stride_4_wind_8":
         return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
@@ -25,11 +25,11 @@ from cotracker.models.core.embeddings import (
 torch.manual_seed(0)
 
 
-def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device='cuda'):
+def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0), device="cuda"):
     if grid_size == 1:
-        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
+        return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2], device=device)[
             None, None
-        ].to(device)
+        ]
 
     grid_y, grid_x = meshgrid2d(
         1, grid_size, grid_size, stack=False, norm=False, device=device
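Editor's note: the hunk above moves device placement into tensor construction instead of a trailing `.to(device)`. A rough, self-contained illustration of the same idea; this is not the repository's `meshgrid2d` helper, and the spacing scheme and the name `points_on_grid` are made up for the sketch.

```
import torch

def points_on_grid(grid_size, interp_shape, device="cpu"):
    # Build a grid_size x grid_size set of (x, y) points over an (H, W) frame,
    # created directly on `device` so no later .to(device) call is needed.
    H, W = interp_shape
    ys = torch.linspace(0, H - 1, grid_size, device=device)
    xs = torch.linspace(0, W - 1, grid_size, device=device)
    gy, gx = torch.meshgrid(ys, xs, indexing="ij")
    return torch.stack([gx.reshape(-1), gy.reshape(-1)], dim=-1)[None]  # (1, N, 2)

print(points_on_grid(3, (256, 256)).shape)  # torch.Size([1, 9, 2])
```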
@@ -29,11 +29,10 @@ class EvaluationPredictor(torch.nn.Module):
         self.n_iters = n_iters
 
         self.model = cotracker_model
-        self.model.to("cuda")
         self.model.eval()
 
     def forward(self, video, queries):
-        queries = queries.clone().cuda()
+        queries = queries.clone()
         B, T, C, H, W = video.shape
         B, N, D = queries.shape
 
@@ -42,14 +41,16 @@ class EvaluationPredictor(torch.nn.Module):
 
         rgbs = video.reshape(B * T, C, H, W)
         rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
-        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
+        rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
 
+        device = rgbs.device
+
         queries[:, :, 1] *= self.interp_shape[1] / W
         queries[:, :, 2] *= self.interp_shape[0] / H
 
         if self.single_point:
-            traj_e = torch.zeros((B, T, N, 2)).cuda()
-            vis_e = torch.zeros((B, T, N)).cuda()
+            traj_e = torch.zeros((B, T, N, 2), device=device)
+            vis_e = torch.zeros((B, T, N), device=device)
             for pind in range((N)):
                 query = queries[:, pind : pind + 1]
 
@@ -60,8 +61,10 @@ class EvaluationPredictor(torch.nn.Module):
                 vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
         else:
             if self.grid_size > 0:
-                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
+                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(
+                    device
+                )  #
                 queries = torch.cat([queries, xy], dim=1)  #
 
             traj_e, __, vis_e, __ = self.model(
@@ -91,8 +94,8 @@ class EvaluationPredictor(torch.nn.Module):
             query = torch.cat([query, xy_target], dim=1).to(device)  #
 
             if self.grid_size > 0:
-                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
-                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda()  #
+                xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:], device=device)
+                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)  #
             query = torch.cat([query, xy], dim=1).to(device)  #
             # crop the video to start from the queried frame
             query[0, 0, 0] = 0
@@ -25,8 +25,6 @@ class CoTrackerPredictor(torch.nn.Module):
         model = build_cotracker(checkpoint)
 
         self.model = model
-        self.device = device or 'cuda'
-        self.model.to(self.device)
         self.model.eval()
 
     @torch.no_grad()
@@ -73,7 +71,7 @@ class CoTrackerPredictor(torch.nn.Module):
         grid_width = W // grid_step
         grid_height = H // grid_step
         tracks = visibilities = None
-        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(self.device)
+        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
         grid_pts[0, :, 0] = grid_query_frame
         for offset in tqdm(range(grid_step * grid_step)):
             ox = offset % grid_step
@@ -108,10 +106,8 @@ class CoTrackerPredictor(torch.nn.Module):
         assert B == 1
 
         video = video.reshape(B * T, C, H, W)
-        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").to(self.device)
-        video = video.reshape(
-            B, T, 3, self.interp_shape[0], self.interp_shape[1]
-        ).to(self.device)
+        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear")
+        video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
 
         if queries is not None:
             queries = queries.clone()
@@ -120,7 +116,7 @@ class CoTrackerPredictor(torch.nn.Module):
             queries[:, :, 1] *= self.interp_shape[1] / W
             queries[:, :, 2] *= self.interp_shape[0] / H
         elif grid_size > 0:
-            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=self.device)
+            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
         if segm_mask is not None:
             segm_mask = F.interpolate(
                 segm_mask, tuple(self.interp_shape), mode="nearest"
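Editor's note: the query handling above implies each query row is `(frame_index, x, y)` — column 1 is rescaled by width and column 2 by height. A hedged sketch of manually selected points built in that layout; the coordinates below are purely illustrative.

```
import torch

# Two manually selected points in (frame_index, x, y) layout, with a batch dimension.
queries = torch.tensor(
    [
        [0.0, 100.0, 150.0],   # start tracking at frame 0, pixel (x=100, y=150)
        [10.0, 200.0, 64.0],   # start tracking at frame 10, pixel (x=200, y=64)
    ]
)[None]                        # shape (1, N, 3)

# pred_tracks, pred_visibility = model(video, queries=queries)
```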
@@ -14,7 +14,6 @@ from matplotlib import cm
 import torch.nn.functional as F
 import torchvision.transforms as transforms
 from moviepy.editor import ImageSequenceClip
-from torch.utils.tensorboard import SummaryWriter
 import matplotlib.pyplot as plt
 
 
@@ -67,7 +66,7 @@ class Visualizer:
         gt_tracks: torch.Tensor = None,  # (B,T,N,2)
         segm_mask: torch.Tensor = None,  # (B,1,H,W)
         filename: str = "video",
-        writer: SummaryWriter = None,
+        writer=None,  # tensorboard Summary Writer, used for visualization during training
         step: int = 0,
         query_frame: int = 0,
         save_video: bool = True,

demo.py (12 changes)

@@ -32,11 +32,6 @@ if __name__ == "__main__":
         default="./checkpoints/cotracker_stride_4_wind_8.pth",
         help="cotracker model",
     )
-    parser.add_argument(
-        "--device",
-        default="cuda",
-        help="Device to use for inference",
-    )
     parser.add_argument("--grid_size", type=int, default=0, help="Regular grid size")
     parser.add_argument(
         "--grid_query_frame",
@@ -59,7 +54,12 @@ if __name__ == "__main__":
         segm_mask = np.array(Image.open(os.path.join(args.mask_path)))
         segm_mask = torch.from_numpy(segm_mask)[None, None]
 
-    model = CoTrackerPredictor(checkpoint=args.checkpoint, device=args.device)
+    model = CoTrackerPredictor(checkpoint=args.checkpoint)
+    if torch.cuda.is_available():
+        model = model.cuda()
+        video = video.cuda()
+    else:
+        print("CUDA is not available!")
 
     pred_tracks, pred_visibility = model(
         video,

hubconf.py (new file, 32 lines)

@@ -0,0 +1,32 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+dependencies = ["torch", "einops", "timm", "tqdm"]
+
+_COTRACKER_URL = (
+    "https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
+)
+
+
+def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs):
+    from cotracker.predictor import CoTrackerPredictor
+
+    predictor = CoTrackerPredictor(checkpoint=None)
+    if pretrained:
+        state_dict = torch.hub.load_state_dict_from_url(
+            _COTRACKER_URL, map_location="cpu"
+        )
+        predictor.model.load_state_dict(state_dict)
+    return predictor
+
+
+def cotracker_w8(*, pretrained: bool = True, **kwargs):
+    """
+    CoTracker model with stride 4 and window length 8. (The main model from the paper)
+    """
+    return _make_cotracker_predictor(pretrained=pretrained, **kwargs)
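Editor's note: one way to smoke-test this hubconf.py from a local checkout is torch.hub's `source="local"` mode; that flag is standard torch.hub behaviour rather than anything defined in this repository, and the checkout path below is an assumption.

```
import torch

# Resolve the entry point from the local clone instead of fetching from GitHub.
predictor = torch.hub.load("./co-tracker", "cotracker_w8", source="local", pretrained=True)
print(type(predictor).__name__)  # CoTrackerPredictor
```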
File diff suppressed because one or more lines are too long

train.py (111 changes)

@@ -36,21 +36,6 @@ from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_
 from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
 
 
-# define the handler function
-# for training on a slurm cluster
-def sig_handler(signum, frame):
-    print("caught signal", signum)
-    print(socket.gethostname(), "USR1 signal caught.")
-    # do other stuff to cleanup here
-    print("requeuing job " + os.environ["SLURM_JOB_ID"])
-    os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
-    sys.exit(-1)
-
-
-def term_handler(signum, frame):
-    print("bypassing sigterm", flush=True)
-
-
 def fetch_optimizer(args, model):
     """Create the optimizer and learning rate scheduler"""
     optimizer = optim.AdamW(
@@ -153,6 +138,8 @@ def run_test_eval(evaluator, model, dataloaders, writer, step):
             single_point=False,
             n_iters=6,
         )
+        if torch.cuda.is_available():
+            predictor.model = predictor.model.cuda()
 
         metrics = evaluator.evaluate_sequence(
             model=predictor,
@@ -302,9 +289,7 @@ class Lite(LightningLite):
             eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
 
         if "tapvid_davis_first" in args.eval_datasets:
-            data_root = os.path.join(
-                args.dataset_root, "/tapvid_davis/tapvid_davis.pkl"
-            )
+            data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl")
             eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
             eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
                 eval_dataset,
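Editor's note: the path fix above matters because `os.path.join` discards every preceding component once it meets an absolute path, so the old leading slash silently ignored `dataset_root`:

```
import os

# Old behaviour: the absolute second component throws dataset_root away.
print(os.path.join("./datasets", "/tapvid_davis/tapvid_davis.pkl"))  # /tapvid_davis/tapvid_davis.pkl

# New behaviour: the relative component is appended as intended.
print(os.path.join("./datasets", "tapvid_davis/tapvid_davis.pkl"))   # ./datasets/tapvid_davis/tapvid_davis.pkl
```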
@@ -551,17 +536,15 @@ class Lite(LightningLite):
 
 
 if __name__ == "__main__":
-    signal.signal(signal.SIGUSR1, sig_handler)
-    signal.signal(signal.SIGTERM, term_handler)
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_name", default="cotracker", help="model name")
-    parser.add_argument("--restore_ckpt", help="restore checkpoint")
-    parser.add_argument("--ckpt_path", help="restore checkpoint")
+    parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
+    parser.add_argument("--ckpt_path", help="path to save checkpoints")
     parser.add_argument(
         "--batch_size", type=int, default=4, help="batch size used during training."
     )
     parser.add_argument(
-        "--num_workers", type=int, default=6, help="left right consistency loss"
+        "--num_workers", type=int, default=6, help="number of dataloader workers"
     )
 
     parser.add_argument(
@@ -578,20 +561,34 @@ if __name__ == "__main__":
         "--evaluate_every_n_epoch",
         type=int,
         default=1,
-        help="number of flow-field updates during validation forward pass",
+        help="evaluate during training after every n epochs, after every epoch by default",
     )
     parser.add_argument(
         "--save_every_n_epoch",
         type=int,
         default=1,
-        help="number of flow-field updates during validation forward pass",
+        help="save checkpoints during training after every n epochs, after every epoch by default",
     )
     parser.add_argument(
-        "--validate_at_start", action="store_true", help="use mixed precision"
+        "--validate_at_start",
+        action="store_true",
+        help="whether to run evaluation before training starts",
     )
-    parser.add_argument("--save_freq", type=int, default=100, help="save_freq")
-    parser.add_argument("--traj_per_sample", type=int, default=768, help="save_freq")
-    parser.add_argument("--dataset_root", type=str, help="path lo all the datasets")
+    parser.add_argument(
+        "--save_freq",
+        type=int,
+        default=100,
+        help="frequency of trajectory visualization during training",
+    )
+    parser.add_argument(
+        "--traj_per_sample",
+        type=int,
+        default=768,
+        help="the number of trajectories to sample for training",
+    )
+    parser.add_argument(
+        "--dataset_root", type=str, help="path to all the datasets (train and eval)"
+    )
 
     parser.add_argument(
         "--train_iters",
@@ -605,49 +602,75 @@
     parser.add_argument(
         "--eval_datasets",
         nargs="+",
-        default=["things", "badja", "fastcapture"],
-        help="eval datasets.",
+        default=["things", "badja"],
+        help="what datasets to use for evaluation",
     )
 
     parser.add_argument(
-        "--remove_space_attn", action="store_true", help="use mixed precision"
+        "--remove_space_attn",
+        action="store_true",
+        help="remove space attention from CoTracker",
     )
     parser.add_argument(
-        "--dont_use_augs", action="store_true", help="use mixed precision"
+        "--dont_use_augs",
+        action="store_true",
+        help="don't apply augmentations during training",
    )
     parser.add_argument(
-        "--sample_vis_1st_frame", action="store_true", help="use mixed precision"
+        "--sample_vis_1st_frame",
+        action="store_true",
+        help="only sample trajectories with points visible on the first frame",
     )
     parser.add_argument(
-        "--sliding_window_len", type=int, default=8, help="use mixed precision"
+        "--sliding_window_len",
+        type=int,
+        default=8,
+        help="length of the CoTracker sliding window",
     )
     parser.add_argument(
-        "--updateformer_hidden_size", type=int, default=384, help="use mixed precision"
+        "--updateformer_hidden_size",
+        type=int,
+        default=384,
+        help="hidden dimension of the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_num_heads", type=int, default=8, help="use mixed precision"
+        "--updateformer_num_heads",
+        type=int,
+        default=8,
+        help="number of heads of the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_space_depth", type=int, default=12, help="use mixed precision"
+        "--updateformer_space_depth",
+        type=int,
+        default=12,
+        help="number of group attention layers in the CoTracker transformer model",
     )
     parser.add_argument(
-        "--updateformer_time_depth", type=int, default=12, help="use mixed precision"
+        "--updateformer_time_depth",
+        type=int,
+        default=12,
+        help="number of time attention layers in the CoTracker transformer model",
     )
     parser.add_argument(
-        "--model_stride", type=int, default=8, help="use mixed precision"
+        "--model_stride",
+        type=int,
+        default=8,
+        help="stride of the CoTracker feature network",
     )
     parser.add_argument(
         "--crop_size",
         type=int,
         nargs="+",
         default=[384, 512],
-        help="use mixed precision",
+        help="crop videos to this resolution during training",
     )
     parser.add_argument(
-        "--eval_max_seq_len", type=int, default=1000, help="use mixed precision"
+        "--eval_max_seq_len",
+        type=int,
+        default=1000,
+        help="maximum length of evaluation videos",
     )
     args = parser.parse_args()
 
     logging.basicConfig(
         level=logging.INFO,
         format="%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s",
@@ -661,5 +684,5 @@ if __name__ == "__main__":
         devices="auto",
         accelerator="gpu",
         precision=32,
-        num_nodes=4,
+        # num_nodes=4,
     ).run(args)