+    title="🎨 CoTracker: It is Better to Track Together",
+    description="\
+    Welcome to CoTracker! This space demonstrates point (pixel) tracking in videos. \
+    Points are sampled on a regular grid and are tracked jointly. \
+    \
+    To get started, simply upload your .mp4 video in landscape orientation or click on one of the example videos to load them. The shorter the video, the faster the processing. We recommend submitting short videos of length 2-7 seconds. \
+    \
+    - The total number of grid points is the square of Grid Size. \
+    - To specify the starting frame for tracking, adjust Grid Query Frame. Tracks will be visualized only after the selected frame. \
+    - Use Backward Tracking to track points from the selected frame in both directions. \
+    - Check Visualize Track Traces to visualize traces of all the tracked points. \
+    \
+    For more details, check out our GitHub Repo ⭐ \
+    ",
-
fn=cotracker_demo,
inputs=[
gr.Video(label="Input video", interactive=True),
gr.Slider(minimum=1, maximum=30, step=1, value=10, label="Grid Size"),
gr.Slider(minimum=0, maximum=30, step=1, value=0, label="Grid Query Frame"),
        gr.Checkbox(label="Backward Tracking"),
gr.Checkbox(label="Visualize Track Traces"),
],
outputs=gr.Video(label="Video with predicted tracks"),
examples=[
- [ "./assets/apple.mp4", 20, 0, False, False ],
- [ "./assets/apple.mp4", 10, 30, True, False ],
+ ["./assets/apple.mp4", 20, 0, False, False],
+ ["./assets/apple.mp4", 10, 30, True, False],
],
- cache_examples=False
+ cache_examples=False,
)
app.launch(share=False)
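For reference, the `Interface` wiring above implies a handler of roughly this shape. This is a sketch only: `cotracker_demo` is defined earlier in gradio_demo/app.py and does not appear in this diff, so the parameter names and return value below are assumptions inferred from the inputs/outputs lists.

```python
# Hypothetical signature inferred from the gr.Interface inputs above;
# the real cotracker_demo lives elsewhere in gradio_demo/app.py.
def cotracker_demo(
    input_video,              # filepath of the uploaded .mp4
    grid_size=10,             # grid_size**2 points are sampled and tracked
    grid_query_frame=0,       # frame from which the grid points are queried
    backward_tracking=False,  # track in both directions from the query frame
    tracks_leave_trace=False, # draw full traces instead of moving dots
):
    ...  # run CoTracker and render the visualization
    return "path/to/video_with_tracks.mp4"
```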
diff --git a/gradio_demo/requirements.txt b/gradio_demo/requirements.txt
index 73d9745..67afbdf 100644
--- a/gradio_demo/requirements.txt
+++ b/gradio_demo/requirements.txt
@@ -1,7 +1,3 @@
-einops
-timm
-tqdm
-opencv-python
matplotlib
moviepy
flow_vis
diff --git a/hubconf.py b/hubconf.py
index c9ceac4..da13030 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -6,27 +6,33 @@
import torch
-dependencies = ["torch", "einops", "timm", "tqdm"]
-
-_COTRACKER_URL = (
- "https://dl.fbaipublicfiles.com/cotracker/cotracker_stride_4_wind_8.pth"
-)
+_COTRACKER_URL = "https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth"
-def _make_cotracker_predictor(*, pretrained: bool = True, **kwargs):
- from cotracker.predictor import CoTrackerPredictor
+def _make_cotracker_predictor(*, pretrained: bool = True, online=False, **kwargs):
+ if online:
+ from cotracker.predictor import CoTrackerOnlinePredictor
- predictor = CoTrackerPredictor(checkpoint=None)
+ predictor = CoTrackerOnlinePredictor(checkpoint=None)
+ else:
+ from cotracker.predictor import CoTrackerPredictor
+
+ predictor = CoTrackerPredictor(checkpoint=None)
if pretrained:
- state_dict = torch.hub.load_state_dict_from_url(
- _COTRACKER_URL, map_location="cpu"
- )
+ state_dict = torch.hub.load_state_dict_from_url(_COTRACKER_URL, map_location="cpu")
predictor.model.load_state_dict(state_dict)
return predictor
-def cotracker_w8(*, pretrained: bool = True, **kwargs):
+def cotracker2(*, pretrained: bool = True, **kwargs):
"""
- CoTracker model with stride 4 and window length 8. (The main model from the paper)
+ CoTracker2 with stride 4 and window length 8. Can track up to 265*265 points jointly.
"""
- return _make_cotracker_predictor(pretrained=pretrained, **kwargs)
+ return _make_cotracker_predictor(pretrained=pretrained, online=False, **kwargs)
+
+
+def cotracker2_online(*, pretrained: bool = True, **kwargs):
+ """
+ Online CoTracker2 with stride 4 and window length 8. Can track up to 265*265 points jointly.
+ """
+ return _make_cotracker_predictor(pretrained=pretrained, online=True, **kwargs)
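For context, these hub entry points are consumed exactly as online_demo.py does below; a minimal sketch:

```python
import torch

# Offline predictor: expects the entire video tensor in a single call.
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2")

# Online predictor: processes the video window by window (see online_demo.py).
online_model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
```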
diff --git a/launch_training.sh b/launch_training.sh
new file mode 100644
index 0000000..555cfe3
--- /dev/null
+++ b/launch_training.sh
@@ -0,0 +1,24 @@
+#!/bin/bash
+
+EXP_DIR=$1
+EXP_NAME=$2
+DATE=$3
+DATASET_ROOT=$4
+NUM_STEPS=$5
+
+
+echo `which python`
+
+mkdir -p ${EXP_DIR}/${DATE}_${EXP_NAME}/logs/;
+
+export PYTHONPATH=`(cd ../ && pwd)`:`pwd`:$PYTHONPATH
+sbatch --comment=${EXP_NAME} --partition=learn --time=39:00:00 --gpus-per-node=8 --nodes=4 --ntasks-per-node=8 \
+--job-name=${EXP_NAME} --cpus-per-task=10 --signal=USR1@60 --open-mode=append \
+--output=${EXP_DIR}/${DATE}_${EXP_NAME}/logs/%j_%x_%A_%a_%N.out \
+--error=${EXP_DIR}/${DATE}_${EXP_NAME}/logs/%j_%x_%A_%a_%N.err \
+--wrap="srun --label python ./train.py --batch_size 1 \
+--num_steps ${NUM_STEPS} --ckpt_path ${EXP_DIR}/${DATE}_${EXP_NAME} --model_name cotracker \
+--save_freq 200 --sequence_len 24 --eval_datasets dynamic_replica tapvid_davis_first \
+--traj_per_sample 768 --sliding_window_len 8 \
+--save_every_n_epoch 10 --evaluate_every_n_epoch 10 --model_stride 4 --dataset_root ${DATASET_ROOT} --num_nodes 4 \
+--num_virtual_tracks 64"
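launch_training.sh takes five positional arguments (EXP_DIR, EXP_NAME, DATE, DATASET_ROOT, NUM_STEPS), so a run would be launched as, e.g., `./launch_training.sh /checkpoint/cotracker cotracker2 2024-01-01 /datasets 50000` (values here are hypothetical). The sbatch flags assume a SLURM partition named `learn` with 8 GPUs per node on 4 nodes, matching the `--num_nodes 4` passed to train.py.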
diff --git a/online_demo.py b/online_demo.py
new file mode 100644
index 0000000..e05ed41
--- /dev/null
+++ b/online_demo.py
@@ -0,0 +1,90 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import argparse
+import imageio.v3 as iio
+import numpy as np
+
+from cotracker.utils.visualizer import Visualizer
+from cotracker.predictor import CoTrackerOnlinePredictor
+
+# Unfortunately MPS acceleration does not support all the features we require,
+# but we may be able to enable it in the future
+
+DEFAULT_DEVICE = (
+ # "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+ "cuda"
+ if torch.cuda.is_available()
+ else "cpu"
+)
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--video_path",
+ default="./assets/apple.mp4",
+ help="path to a video",
+ )
+ parser.add_argument(
+ "--checkpoint",
+ default=None,
+ help="CoTracker model parameters",
+ )
+ parser.add_argument("--grid_size", type=int, default=10, help="Regular grid size")
+ parser.add_argument(
+ "--grid_query_frame",
+ type=int,
+ default=0,
+ help="Compute dense and grid tracks starting from this frame",
+ )
+
+ args = parser.parse_args()
+
+ if args.checkpoint is not None:
+ model = CoTrackerOnlinePredictor(checkpoint=args.checkpoint)
+ else:
+ model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online")
+ model = model.to(DEFAULT_DEVICE)
+
+ window_frames = []
+
+ def _process_step(window_frames, is_first_step, grid_size):
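+        # The model advances by `model.step` frames per call, so each call
+        # consumes the last 2 * model.step frames, i.e. one full sliding window.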
+ video_chunk = (
+ torch.tensor(np.stack(window_frames[-model.step * 2 :]), device=DEFAULT_DEVICE)
+ .float()
+ .permute(0, 3, 1, 2)[None]
+ ) # (1, T, 3, H, W)
+ return model(video_chunk, is_first_step=is_first_step, grid_size=grid_size)
+
+ # Iterating over video frames, processing one window at a time:
+ is_first_step = True
+ for i, frame in enumerate(
+ iio.imiter(
+            args.video_path,
+ plugin="FFMPEG",
+ )
+ ):
+ if i % model.step == 0 and i != 0:
+ pred_tracks, pred_visibility = _process_step(
+ window_frames, is_first_step, grid_size=args.grid_size
+ )
+ is_first_step = False
+ window_frames.append(frame)
+ # Processing the final video frames in case video length is not a multiple of model.step
+ pred_tracks, pred_visibility = _process_step(
+ window_frames[-(i % model.step) - model.step - 1 :],
+ is_first_step,
+ grid_size=args.grid_size,
+ )
+
+ print("Tracks are computed")
+
+ # save a video with predicted tracks
+ seq_name = args.video_path.split("/")[-1]
+ video = torch.tensor(np.stack(window_frames), device=DEFAULT_DEVICE).permute(0, 3, 1, 2)[None]
+ vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
+ vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
diff --git a/setup.py b/setup.py
index ad6dc97..c67b1e1 100644
--- a/setup.py
+++ b/setup.py
@@ -8,11 +8,11 @@ from setuptools import find_packages, setup
setup(
name="cotracker",
- version="1.0",
+ version="2.0",
install_requires=[],
packages=find_packages(exclude="notebooks"),
extras_require={
- "all": ["matplotlib", "opencv-python"],
+ "all": ["matplotlib"],
"dev": ["flake8", "black"],
},
)
diff --git a/tests/test_bilinear_sample.py b/tests/test_bilinear_sample.py
new file mode 100644
index 0000000..29e5322
--- /dev/null
+++ b/tests/test_bilinear_sample.py
@@ -0,0 +1,51 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import unittest
+
+from cotracker.models.core.model_utils import bilinear_sampler
+
+
+class TestBilinearSampler(unittest.TestCase):
+ # Sample from an image (4d)
+ def _test4d(self, align_corners):
+ H, W = 4, 5
+        # Construct a grid to obtain identity sampling
+ input = torch.randn(H * W).view(1, 1, H, W).float()
+ coords = torch.meshgrid(torch.arange(H), torch.arange(W))
+ coords = torch.stack(coords[::-1], dim=-1).float()[None]
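+        # With align_corners=False, pixel centers sit at half-integer
+        # coordinates, hence the 0.5 shift below for identity sampling.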
+ if not align_corners:
+ coords = coords + 0.5
+ sampled_input = bilinear_sampler(input, coords, align_corners=align_corners)
+ torch.testing.assert_close(input, sampled_input)
+
+ # Sample from a video (5d)
+ def _test5d(self, align_corners):
+ T, H, W = 3, 4, 5
+        # Construct a grid to obtain identity sampling
+ input = torch.randn(H * W).view(1, 1, H, W).float()
+ input = torch.stack([input, input + 1, input + 2], dim=2)
+ coords = torch.meshgrid(torch.arange(T), torch.arange(W), torch.arange(H))
+ coords = torch.stack(coords, dim=-1).float().permute(0, 2, 1, 3)[None]
+
+ if not align_corners:
+ coords = coords + 0.5
+ sampled_input = bilinear_sampler(input, coords, align_corners=align_corners)
+ torch.testing.assert_close(input, sampled_input)
+
+ def test4d(self):
+ self._test4d(align_corners=True)
+ self._test4d(align_corners=False)
+
+ def test5d(self):
+ self._test5d(align_corners=True)
+ self._test5d(align_corners=False)
+
+
+# run the tests when executed directly
+if __name__ == "__main__":
+    unittest.main()
diff --git a/train.py b/train.py
index 13ad8ef..c2b354f 100644
--- a/train.py
+++ b/train.py
@@ -25,22 +25,35 @@ from torch.utils.tensorboard import SummaryWriter
from pytorch_lightning.lite import LightningLite
from cotracker.models.evaluation_predictor import EvaluationPredictor
-from cotracker.models.core.cotracker.cotracker import CoTracker
+from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.utils.visualizer import Visualizer
from cotracker.datasets.tap_vid_datasets import TapVidDataset
-from cotracker.datasets.badja_dataset import BadjaDataset
-from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
+
+from cotracker.datasets.dr_dataset import DynamicReplicaDataset
from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.datasets import kubric_movif_dataset
from cotracker.datasets.utils import collate_fn, collate_fn_train, dataclass_to_cuda_
from cotracker.models.core.cotracker.losses import sequence_loss, balanced_ce_loss
+# Signal handlers for requeueing the job
+# when training on a SLURM cluster
+def sig_handler(signum, frame):
+ print("caught signal", signum)
+ print(socket.gethostname(), "USR1 signal caught.")
+    # perform any additional cleanup here before requeueing
+ print("requeuing job " + os.environ["SLURM_JOB_ID"])
+ os.system("scontrol requeue " + os.environ["SLURM_JOB_ID"])
+ sys.exit(-1)
+
+
+def term_handler(signum, frame):
+ print("bypassing sigterm", flush=True)
+
+
def fetch_optimizer(args, model):
"""Create the optimizer and learning rate scheduler"""
- optimizer = optim.AdamW(
- model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8
- )
+ optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wdecay, eps=1e-8)
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
args.lr,
@@ -53,69 +66,61 @@ def fetch_optimizer(args, model):
return optimizer, scheduler
-def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0):
- rgbs = batch.video
+def forward_batch(batch, model, args):
+ video = batch.video
trajs_g = batch.trajectory
vis_g = batch.visibility
valids = batch.valid
- B, T, C, H, W = rgbs.shape
+ B, T, C, H, W = video.shape
assert C == 3
B, T, N, D = trajs_g.shape
- device = rgbs.device
+ device = video.device
__, first_positive_inds = torch.max(vis_g, dim=1)
# We want to make sure that during training the model sees visible points
# that it does not need to track just yet: they are visible but queried from a later frame
N_rand = N // 4
    # for each point, indices of frames where it is visible
- nonzero_inds = [torch.nonzero(vis_g[0, :, i]) for i in range(N)]
- rand_vis_inds = torch.cat(
- [
- nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
- for nonzero_row in nonzero_inds
- ],
- dim=1,
- )
- first_positive_inds = torch.cat(
- [rand_vis_inds[:, :N_rand], first_positive_inds[:, N_rand:]], dim=1
- )
+ nonzero_inds = [[torch.nonzero(vis_g[b, :, i]) for i in range(N)] for b in range(B)]
+
+ for b in range(B):
+ rand_vis_inds = torch.cat(
+ [
+ nonzero_row[torch.randint(len(nonzero_row), size=(1,))]
+ for nonzero_row in nonzero_inds[b]
+ ],
+ dim=1,
+ )
+ first_positive_inds[b] = torch.cat(
+ [rand_vis_inds[:, :N_rand], first_positive_inds[b : b + 1, N_rand:]], dim=1
+ )
+
ind_array_ = torch.arange(T, device=device)
ind_array_ = ind_array_[None, :, None].repeat(B, 1, N)
assert torch.allclose(
vis_g[ind_array_ == first_positive_inds[:, None, :]],
- torch.ones_like(vis_g),
- )
- assert torch.allclose(
- vis_g[ind_array_ == rand_vis_inds[:, None, :]], torch.ones_like(vis_g)
- )
-
- gather = torch.gather(
- trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
+ torch.ones(1, device=device),
)
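+    # Gather each point's (x, y) at its query frame; queries are (t, x, y) triplets.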
+ gather = torch.gather(trajs_g, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, D))
xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
- queries = torch.cat([first_positive_inds[:, :, None], xys], dim=2)
+ queries = torch.cat([first_positive_inds[:, :, None], xys[:, :, :2]], dim=2)
- predictions, __, visibility, train_data = model(
- rgbs=rgbs, queries=queries, iters=args.train_iters, is_train=True
+ predictions, visibility, train_data = model(
+ video=video, queries=queries, iters=args.train_iters, is_train=True
)
- vis_predictions, coord_predictions, wind_inds, sort_inds = train_data
-
- trajs_g = trajs_g[:, :, sort_inds]
- vis_g = vis_g[:, :, sort_inds]
- valids = valids[:, :, sort_inds]
+ coord_predictions, vis_predictions, valid_mask = train_data
vis_gts = []
traj_gts = []
valids_gts = []
- for i, wind_idx in enumerate(wind_inds):
- ind = i * (args.sliding_window_len // 2)
-
- vis_gts.append(vis_g[:, ind : ind + args.sliding_window_len, :wind_idx])
- traj_gts.append(trajs_g[:, ind : ind + args.sliding_window_len, :wind_idx])
- valids_gts.append(valids[:, ind : ind + args.sliding_window_len, :wind_idx])
-
+ S = args.sliding_window_len
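+    # Chunk the ground truth into windows of length S overlapping by S // 2,
+    # matching the model's sliding-window predictions.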
+ for ind in range(0, args.sequence_len - S // 2, S // 2):
+ vis_gts.append(vis_g[:, ind : ind + S])
+ traj_gts.append(trajs_g[:, ind : ind + S])
+ valids_gts.append(valids[:, ind : ind + S] * valid_mask[:, ind : ind + S])
+
seq_loss = sequence_loss(coord_predictions, traj_gts, vis_gts, valids_gts, 0.8)
vis_loss = balanced_ce_loss(vis_predictions, vis_gts, valids_gts)
@@ -131,9 +136,17 @@ def forward_batch(batch, model, args, loss_fn=None, writer=None, step=0):
def run_test_eval(evaluator, model, dataloaders, writer, step):
model.eval()
for ds_name, dataloader in dataloaders:
+ visualize_every = 1
+ grid_size = 5
+ if ds_name == "dynamic_replica":
+ visualize_every = 8
+ grid_size = 0
+ elif "tapvid" in ds_name:
+ visualize_every = 5
+
predictor = EvaluationPredictor(
model.module.module,
- grid_size=6,
+ grid_size=grid_size,
local_grid_size=0,
single_point=False,
n_iters=6,
@@ -148,37 +161,23 @@ def run_test_eval(evaluator, model, dataloaders, writer, step):
train_mode=True,
writer=writer,
step=step,
+ visualize_every=visualize_every,
)
- if ds_name == "badja" or ds_name == "fastcapture" or ("kubric" in ds_name):
-
- metrics = {
- **{
- f"{ds_name}_avg": np.mean(
- [v for k, v in metrics.items() if "accuracy" not in k]
- )
- },
- **{
- f"{ds_name}_avg_accuracy": np.mean(
- [v for k, v in metrics.items() if "accuracy" in k]
- )
- },
- }
- print("avg", np.mean([v for v in metrics.values()]))
+ if ds_name == "dynamic_replica" or ds_name == "kubric":
+ metrics = {f"{ds_name}_avg_{k}": v for k, v in metrics["avg"].items()}
if "tapvid" in ds_name:
metrics = {
- f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"] * 100,
- f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"]
- * 100,
- f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"] * 100,
+ f"{ds_name}_avg_OA": metrics["avg"]["occlusion_accuracy"],
+ f"{ds_name}_avg_delta": metrics["avg"]["average_pts_within_thresh"],
+ f"{ds_name}_avg_Jaccard": metrics["avg"]["average_jaccard"],
}
- writer.add_scalars(f"Eval", metrics, step)
+ writer.add_scalars(f"Eval_{ds_name}", metrics, step)
class Logger:
-
SUM_FREQ = 100
def __init__(self, model, scheduler):
@@ -190,24 +189,19 @@ class Logger:
def _print_training_status(self):
metrics_data = [
- self.running_loss[k] / Logger.SUM_FREQ
- for k in sorted(self.running_loss.keys())
+ self.running_loss[k] / Logger.SUM_FREQ for k in sorted(self.running_loss.keys())
]
training_str = "[{:6d}] ".format(self.total_steps + 1)
metrics_str = ("{:10.4f}, " * len(metrics_data)).format(*metrics_data)
# print the training status
- logging.info(
- f"Training Metrics ({self.total_steps}): {training_str + metrics_str}"
- )
+ logging.info(f"Training Metrics ({self.total_steps}): {training_str + metrics_str}")
if self.writer is None:
self.writer = SummaryWriter(log_dir=os.path.join(args.ckpt_path, "runs"))
for k in self.running_loss:
- self.writer.add_scalar(
- k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps
- )
+ self.writer.add_scalar(k, self.running_loss[k] / Logger.SUM_FREQ, self.total_steps)
self.running_loss[k] = 0.0
def push(self, metrics, task):
@@ -249,79 +243,56 @@ class Lite(LightningLite):
seed_everything(0)
def seed_worker(worker_id):
- worker_seed = torch.initial_seed() % 2 ** 32
+ worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
g = torch.Generator()
g.manual_seed(0)
+ if self.global_rank == 0:
+ eval_dataloaders = []
+ if "dynamic_replica" in args.eval_datasets:
+ eval_dataset = DynamicReplicaDataset(
+ sample_len=60, only_first_n_samples=1, rgbd_input=False
+ )
+ eval_dataloader_dr = torch.utils.data.DataLoader(
+ eval_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=1,
+ collate_fn=collate_fn,
+ )
+ eval_dataloaders.append(("dynamic_replica", eval_dataloader_dr))
- eval_dataloaders = []
- if "badja" in args.eval_datasets:
- eval_dataset = BadjaDataset(
- data_root=os.path.join(args.dataset_root, "BADJA"),
- max_seq_len=args.eval_max_seq_len,
- dataset_resolution=args.crop_size,
+ if "tapvid_davis_first" in args.eval_datasets:
+ data_root = os.path.join(args.dataset_root, "tapvid/tapvid_davis/tapvid_davis.pkl")
+ eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
+ eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
+ eval_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=1,
+ collate_fn=collate_fn,
+ )
+ eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))
+
+ evaluator = Evaluator(args.ckpt_path)
+
+ visualizer = Visualizer(
+ save_dir=args.ckpt_path,
+ pad_value=80,
+ fps=1,
+ show_first_frame=0,
+ tracks_leave_trace=0,
)
- eval_dataloader_badja = torch.utils.data.DataLoader(
- eval_dataset,
- batch_size=1,
- shuffle=False,
- num_workers=8,
- collate_fn=collate_fn,
- )
- eval_dataloaders.append(("badja", eval_dataloader_badja))
-
- if "fastcapture" in args.eval_datasets:
- eval_dataset = FastCaptureDataset(
- data_root=os.path.join(args.dataset_root, "fastcapture"),
- max_seq_len=min(100, args.eval_max_seq_len),
- max_num_points=40,
- dataset_resolution=args.crop_size,
- )
- eval_dataloader_fastcapture = torch.utils.data.DataLoader(
- eval_dataset,
- batch_size=1,
- shuffle=False,
- num_workers=1,
- collate_fn=collate_fn,
- )
- eval_dataloaders.append(("fastcapture", eval_dataloader_fastcapture))
-
- if "tapvid_davis_first" in args.eval_datasets:
- data_root = os.path.join(args.dataset_root, "tapvid_davis/tapvid_davis.pkl")
- eval_dataset = TapVidDataset(dataset_type="davis", data_root=data_root)
- eval_dataloader_tapvid_davis = torch.utils.data.DataLoader(
- eval_dataset,
- batch_size=1,
- shuffle=False,
- num_workers=1,
- collate_fn=collate_fn,
- )
- eval_dataloaders.append(("tapvid_davis", eval_dataloader_tapvid_davis))
-
- evaluator = Evaluator(args.ckpt_path)
-
- visualizer = Visualizer(
- save_dir=args.ckpt_path,
- pad_value=80,
- fps=1,
- show_first_frame=0,
- tracks_leave_trace=0,
- )
-
- loss_fn = None
if args.model_name == "cotracker":
-
- model = CoTracker(
+ model = CoTracker2(
stride=args.model_stride,
- S=args.sliding_window_len,
+ window_len=args.sliding_window_len,
add_space_attn=not args.remove_space_attn,
- num_heads=args.updateformer_num_heads,
- hidden_size=args.updateformer_hidden_size,
- space_depth=args.updateformer_space_depth,
- time_depth=args.updateformer_time_depth,
+ num_virtual_tracks=args.num_virtual_tracks,
+ model_resolution=args.crop_size,
)
else:
raise ValueError(f"Model {args.model_name} doesn't exist")
@@ -332,7 +303,7 @@ class Lite(LightningLite):
model.cuda()
train_dataset = kubric_movif_dataset.KubricMovifDataset(
- data_root=os.path.join(args.dataset_root, "kubric_movi_f"),
+ data_root=os.path.join(args.dataset_root, "kubric", "kubric_movi_f_tracks"),
crop_size=args.crop_size,
seq_len=args.sequence_len,
traj_per_sample=args.traj_per_sample,
@@ -357,7 +328,8 @@ class Lite(LightningLite):
optimizer, scheduler = fetch_optimizer(args, model)
total_steps = 0
- logger = Logger(model, scheduler)
+ if self.global_rank == 0:
+ logger = Logger(model, scheduler)
folder_ckpts = [
f
@@ -383,9 +355,7 @@ class Lite(LightningLite):
logging.info(f"Load total_steps {total_steps}")
elif args.restore_ckpt is not None:
- assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(
- ".pt"
- )
+ assert args.restore_ckpt.endswith(".pth") or args.restore_ckpt.endswith(".pt")
logging.info("Loading checkpoint...")
strict = True
@@ -394,9 +364,7 @@ class Lite(LightningLite):
state_dict = state_dict["model"]
if list(state_dict.keys())[0].startswith("module."):
- state_dict = {
- k.replace("module.", ""): v for k, v in state_dict.items()
- }
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
model.load_state_dict(state_dict, strict=strict)
logging.info(f"Done loading checkpoint")
@@ -424,33 +392,22 @@ class Lite(LightningLite):
assert model.training
- output = forward_batch(
- batch,
- model,
- args,
- loss_fn=loss_fn,
- writer=logger.writer,
- step=total_steps,
- )
+ output = forward_batch(batch, model, args)
loss = 0
for k, v in output.items():
if "loss" in v:
loss += v["loss"]
- logger.writer.add_scalar(
- f"live_{k}_loss", v["loss"].item(), total_steps
- )
- if "metrics" in v:
- logger.push(v["metrics"], k)
if self.global_rank == 0:
- if total_steps % save_freq == save_freq - 1:
- if args.model_name == "motion_diffuser":
- pred_coords = model.module.module.forward_batch_test(
- batch, interp_shape=args.crop_size
+ for k, v in output.items():
+ if "loss" in v:
+ logger.writer.add_scalar(
+ f"live_{k}_loss", v["loss"].item(), total_steps
)
-
- output["flow"] = {"predictions": pred_coords[0].detach()}
+ if "metrics" in v:
+ logger.push(v["metrics"], k)
+ if total_steps % save_freq == save_freq - 1:
visualizer.visualize(
video=batch.video.clone(),
tracks=batch.trajectory.clone(),
@@ -468,9 +425,7 @@ class Lite(LightningLite):
)
if len(output) > 1:
- logger.writer.add_scalar(
- f"live_total_loss", loss.item(), total_steps
- )
+ logger.writer.add_scalar(f"live_total_loss", loss.item(), total_steps)
logger.writer.add_scalar(
f"learning_rate", optimizer.param_groups[0]["lr"], total_steps
)
@@ -492,9 +447,7 @@ class Lite(LightningLite):
total_steps == 1 and args.validate_at_start
):
if (epoch + 1) % args.save_every_n_epoch == 0:
- ckpt_iter = "0" * (6 - len(str(total_steps))) + str(
- total_steps
- )
+ ckpt_iter = "0" * (6 - len(str(total_steps))) + str(total_steps)
save_path = Path(
f"{args.ckpt_path}/model_{args.model_name}_{ckpt_iter}.pth"
)
@@ -526,16 +479,18 @@ class Lite(LightningLite):
if total_steps > args.num_steps:
should_keep_training = False
break
+ if self.global_rank == 0:
+ print("FINISHED TRAINING")
- print("FINISHED TRAINING")
-
- PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
- torch.save(model.module.module.state_dict(), PATH)
- run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
- logger.close()
+ PATH = f"{args.ckpt_path}/{args.model_name}_final.pth"
+ torch.save(model.module.module.state_dict(), PATH)
+ run_test_eval(evaluator, model, eval_dataloaders, logger.writer, total_steps)
+ logger.close()
if __name__ == "__main__":
+ signal.signal(signal.SIGUSR1, sig_handler)
+ signal.signal(signal.SIGTERM, term_handler)
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", default="cotracker", help="model name")
parser.add_argument("--restore_ckpt", help="path to restore a checkpoint")
@@ -543,17 +498,12 @@ if __name__ == "__main__":
parser.add_argument(
"--batch_size", type=int, default=4, help="batch size used during training."
)
- parser.add_argument(
- "--num_workers", type=int, default=6, help="number of dataloader workers"
- )
+ parser.add_argument("--num_nodes", type=int, default=1)
+ parser.add_argument("--num_workers", type=int, default=10, help="number of dataloader workers")
- parser.add_argument(
- "--mixed_precision", action="store_true", help="use mixed precision"
- )
+ parser.add_argument("--mixed_precision", action="store_true", help="use mixed precision")
parser.add_argument("--lr", type=float, default=0.0005, help="max learning rate.")
- parser.add_argument(
- "--wdecay", type=float, default=0.00001, help="Weight decay in optimizer."
- )
+ parser.add_argument("--wdecay", type=float, default=0.00001, help="Weight decay in optimizer.")
parser.add_argument(
"--num_steps", type=int, default=200000, help="length of training schedule."
)
@@ -596,13 +546,11 @@ if __name__ == "__main__":
default=4,
help="number of updates to the disparity field in each forward pass.",
)
- parser.add_argument(
- "--sequence_len", type=int, default=8, help="train sequence length"
- )
+ parser.add_argument("--sequence_len", type=int, default=8, help="train sequence length")
parser.add_argument(
"--eval_datasets",
nargs="+",
- default=["things", "badja"],
+ default=["tapvid_davis_first"],
help="what datasets to use for evaluation",
)
@@ -611,6 +559,12 @@ if __name__ == "__main__":
action="store_true",
help="remove space attention from CoTracker",
)
+ parser.add_argument(
+ "--num_virtual_tracks",
+ type=int,
+ default=None,
+        help="number of virtual tracks",
+ )
parser.add_argument(
"--dont_use_augs",
action="store_true",
@@ -627,30 +581,6 @@ if __name__ == "__main__":
default=8,
help="length of the CoTracker sliding window",
)
- parser.add_argument(
- "--updateformer_hidden_size",
- type=int,
- default=384,
- help="hidden dimension of the CoTracker transformer model",
- )
- parser.add_argument(
- "--updateformer_num_heads",
- type=int,
- default=8,
- help="number of heads of the CoTracker transformer model",
- )
- parser.add_argument(
- "--updateformer_space_depth",
- type=int,
- default=12,
- help="number of group attention layers in the CoTracker transformer model",
- )
- parser.add_argument(
- "--updateformer_time_depth",
- type=int,
- default=12,
- help="number of time attention layers in the CoTracker transformer model",
- )
parser.add_argument(
"--model_stride",
type=int,
@@ -680,9 +610,9 @@ if __name__ == "__main__":
from pytorch_lightning.strategies import DDPStrategy
Lite(
- strategy=DDPStrategy(find_unused_parameters=True),
+ strategy=DDPStrategy(find_unused_parameters=False),
devices="auto",
accelerator="gpu",
precision=32,
- # num_nodes=4,
+ num_nodes=args.num_nodes,
).run(args)