Initial commit

commit 6d62d873fa
Author: nikitakaraevv
Date: 2023-07-17 17:49:06 -07:00
41 changed files with 5989 additions and 0 deletions

5
cotracker/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cotracker/datasets/badja_dataset.py Normal file
View File

@@ -0,0 +1,390 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
import os
import json
import imageio
import cv2
from enum import Enum
from cotracker.datasets.utils import CoTrackerData, resize_sample
IGNORE_ANIMALS = [
# "bear.json",
# "camel.json",
"cat_jump.json"
# "cows.json",
# "dog.json",
# "dog-agility.json",
# "horsejump-high.json",
# "horsejump-low.json",
# "impala0.json",
# "rs_dog.json"
"tiger.json"
]
class SMALJointCatalog(Enum):
# body_0 = 0
# body_1 = 1
# body_2 = 2
# body_3 = 3
# body_4 = 4
# body_5 = 5
# body_6 = 6
# upper_right_0 = 7
upper_right_1 = 8
upper_right_2 = 9
upper_right_3 = 10
# upper_left_0 = 11
upper_left_1 = 12
upper_left_2 = 13
upper_left_3 = 14
neck_lower = 15
# neck_upper = 16
# lower_right_0 = 17
lower_right_1 = 18
lower_right_2 = 19
lower_right_3 = 20
# lower_left_0 = 21
lower_left_1 = 22
lower_left_2 = 23
lower_left_3 = 24
tail_0 = 25
# tail_1 = 26
# tail_2 = 27
tail_3 = 28
# tail_4 = 29
# tail_5 = 30
tail_6 = 31
jaw = 32
nose = 33 # ADDED JOINT FOR VERTEX 1863
# chin = 34 # ADDED JOINT FOR VERTEX 26
right_ear = 35 # ADDED JOINT FOR VERTEX 149
left_ear = 36 # ADDED JOINT FOR VERTEX 2124
class SMALJointInfo:
def __init__(self):
# These are the joint classes that carry annotations
self.annotated_classes = np.array(
[
8,
9,
10, # upper_right
12,
13,
14, # upper_left
15, # neck
18,
19,
20, # lower_right
22,
23,
24, # lower_left
25,
28,
31, # tail
32,
33, # head
35, # right_ear
36,
]
) # left_ear
self.annotated_markers = np.array(
[
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_TRIANGLE_DOWN,
cv2.MARKER_CROSS,
cv2.MARKER_STAR,
cv2.MARKER_CROSS,
cv2.MARKER_CROSS,
]
)
self.joint_regions = np.array(
[
0,
0,
0,
0,
0,
0,
0,
1,
1,
1,
1,
2,
2,
2,
2,
3,
3,
4,
4,
4,
4,
5,
5,
5,
5,
6,
6,
6,
6,
6,
6,
6,
7,
7,
7,
8,
9,
]
)
self.annotated_joint_region = self.joint_regions[self.annotated_classes]
self.region_colors = np.array(
[
[250, 190, 190], # body, light pink
[60, 180, 75], # upper_right, green
[230, 25, 75], # upper_left, red
[128, 0, 0], # neck, maroon
[0, 130, 200], # lower_right, blue
[255, 255, 25], # lower_left, yellow
[240, 50, 230], # tail, magenta
[245, 130, 48], # jaw / nose / chin, orange
[29, 98, 115], # right_ear, turquoise
[255, 153, 204],
]
) # left_ear, pink
self.joint_colors = np.array(self.region_colors)[self.annotated_joint_region]
class BADJAData:
def __init__(self, data_root, complete=False):
annotations_path = os.path.join(data_root, "joint_annotations")
self.animal_dict = {}
self.animal_count = 0
self.smal_joint_info = SMALJointInfo()
for __, animal_json in enumerate(sorted(os.listdir(annotations_path))):
if animal_json not in IGNORE_ANIMALS:
json_path = os.path.join(annotations_path, animal_json)
with open(json_path) as json_data:
animal_joint_data = json.load(json_data)
filenames = []
segnames = []
joints = []
visible = []
first_path = animal_joint_data[0]["segmentation_path"]
last_path = animal_joint_data[-1]["segmentation_path"]
first_frame = first_path.split("/")[-1]
last_frame = last_path.split("/")[-1]
if not "extra_videos" in first_path:
animal = first_path.split("/")[-2]
first_frame_int = int(first_frame.split(".")[0])
last_frame_int = int(last_frame.split(".")[0])
for fr in range(first_frame_int, last_frame_int + 1):
ref_file_name = os.path.join(
data_root,
"DAVIS/JPEGImages/Full-Resolution/%s/%05d.jpg"
% (animal, fr),
)
ref_seg_name = os.path.join(
data_root,
"DAVIS/Annotations/Full-Resolution/%s/%05d.png"
% (animal, fr),
)
foundit = False
for ind, image_annotation in enumerate(animal_joint_data):
file_name = os.path.join(
data_root, image_annotation["image_path"]
)
seg_name = os.path.join(
data_root, image_annotation["segmentation_path"]
)
if file_name == ref_file_name:
foundit = True
label_ind = ind
if foundit:
image_annotation = animal_joint_data[label_ind]
file_name = os.path.join(
data_root, image_annotation["image_path"]
)
seg_name = os.path.join(
data_root, image_annotation["segmentation_path"]
)
joint = np.array(image_annotation["joints"])
vis = np.array(image_annotation["visibility"])
else:
file_name = ref_file_name
seg_name = ref_seg_name
joint = None
vis = None
filenames.append(file_name)
segnames.append(seg_name)
joints.append(joint)
visible.append(vis)
if len(filenames):
self.animal_dict[self.animal_count] = (
filenames,
segnames,
joints,
visible,
)
self.animal_count += 1
print("Loaded BADJA dataset")
def get_loader(self):
for __ in range(int(1e6)):
animal_id = np.random.choice(len(self.animal_dict.keys()))
filenames, segnames, joints, visible = self.animal_dict[animal_id]
image_id = np.random.randint(0, len(filenames))
seg_file = segnames[image_id]
image_file = filenames[image_id]
joints = joints[image_id].copy()
joints = joints[self.smal_joint_info.annotated_classes]
visible = visible[image_id][self.smal_joint_info.annotated_classes]
rgb_img = imageio.imread(image_file) # , mode='RGB')
sil_img = imageio.imread(seg_file) # , mode='RGB')
rgb_h, rgb_w, _ = rgb_img.shape
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
yield rgb_img, sil_img, joints, visible, image_file
def get_video(self, animal_id):
filenames, segnames, joint, visible = self.animal_dict[animal_id]
rgbs = []
segs = []
joints = []
visibles = []
for s in range(len(filenames)):
image_file = filenames[s]
rgb_img = imageio.imread(image_file) # , mode='RGB')
rgb_h, rgb_w, _ = rgb_img.shape
seg_file = segnames[s]
sil_img = imageio.imread(seg_file) # , mode='RGB')
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
jo = joint[s]
if jo is not None:
joi = joint[s].copy()
joi = joi[self.smal_joint_info.annotated_classes]
vis = visible[s][self.smal_joint_info.annotated_classes]
else:
joi = None
vis = None
rgbs.append(rgb_img)
segs.append(sil_img)
joints.append(joi)
visibles.append(vis)
return rgbs, segs, joints, visibles, filenames[0]
class BadjaDataset(torch.utils.data.Dataset):
def __init__(
self, data_root, max_seq_len=1000, dataset_resolution=(384, 512)
):
self.data_root = data_root
self.badja_data = BADJAData(data_root)
self.max_seq_len = max_seq_len
self.dataset_resolution = dataset_resolution
print(
"found %d unique videos in %s"
% (self.badja_data.animal_count, self.data_root)
)
def __getitem__(self, index):
rgbs, segs, joints, visibles, filename = self.badja_data.get_video(index)
S = len(rgbs)
H, W, __ = rgbs[0].shape
H, W, __ = segs[0].shape
N, __ = joints[0].shape
# let's eliminate the Nones
# note the first one is guaranteed present
for s in range(1, S):
if joints[s] is None:
joints[s] = np.zeros_like(joints[0])
visibles[s] = np.zeros_like(visibles[0])
# eliminate the mystery dim
segs = [seg[:, :, 0] for seg in segs]
rgbs = np.stack(rgbs, 0)
segs = np.stack(segs, 0)
trajs = np.stack(joints, 0)
visibles = np.stack(visibles, 0)
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
segs = torch.from_numpy(segs).reshape(S, 1, H, W).float()
trajs = torch.from_numpy(trajs).reshape(S, N, 2).float()
visibles = torch.from_numpy(visibles).reshape(S, N)
rgbs = rgbs[: self.max_seq_len]
segs = segs[: self.max_seq_len]
trajs = trajs[: self.max_seq_len]
visibles = visibles[: self.max_seq_len]
# apparently the coords are in yx order
trajs = torch.flip(trajs, [2])
if "extra_videos" in filename:
seq_name = filename.split("/")[-3]
else:
seq_name = filename.split("/")[-2]
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
def __len__(self):
return self.badja_data.animal_count
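
For reference, a minimal usage sketch of the BADJA loader defined above; it assumes the BADJA data has been placed under ./BADJA with the joint_annotations/ and DAVIS/ layout that BADJAData expects.

from cotracker.datasets.badja_dataset import BadjaDataset

dataset = BadjaDataset(data_root="./BADJA", max_seq_len=1000, dataset_resolution=(384, 512))
print(len(dataset))  # number of BADJA videos that were found
sample = dataset[0]  # CoTrackerData: video (S, 3, 384, 512), trajectory (S, N, 2), visibility (S, N)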

cotracker/datasets/fast_capture_dataset.py Normal file
View File

@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
# from PIL import Image
import imageio
import numpy as np
from cotracker.datasets.utils import CoTrackerData, resize_sample
class FastCaptureDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
max_seq_len=50,
max_num_points=20,
dataset_resolution=(384, 512),
):
self.data_root = data_root
self.seq_names = os.listdir(os.path.join(data_root, "renders_local_rm"))
self.pth_dir = os.path.join(data_root, "zju_tracking")
self.max_seq_len = max_seq_len
self.max_num_points = max_num_points
self.dataset_resolution = dataset_resolution
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
def __getitem__(self, index):
seq_name = self.seq_names[index]
spath = os.path.join(self.data_root, "renders_local_rm", seq_name)
pthpath = os.path.join(self.pth_dir, seq_name + ".pth")
rgbs = []
img_paths = sorted(os.listdir(spath))
for i, img_path in enumerate(img_paths):
if i < self.max_seq_len:
rgbs.append(imageio.imread(os.path.join(spath, img_path)))
annot_dict = torch.load(pthpath)
traj_2d = annot_dict["traj_2d"][:, :, : self.max_seq_len]
visibility = annot_dict["visibility"][:, : self.max_seq_len]
S = len(rgbs)
H, W, __ = rgbs[0].shape
*_, S = traj_2d.shape
visibile_pts_first_frame_inds = (visibility[:, 0] > 0).nonzero(as_tuple=False)[
:, 0
]
torch.manual_seed(0)
point_inds = torch.randperm(len(visibile_pts_first_frame_inds))[
: self.max_num_points
]
visible_inds_sampled = visibile_pts_first_frame_inds[point_inds]
rgbs = np.stack(rgbs, 0)
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
segs = torch.ones(S, 1, H, W).float()
trajs = traj_2d[visible_inds_sampled].permute(2, 0, 1).float()
visibles = visibility[visible_inds_sampled].permute(1, 0)
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
def __len__(self):
return len(self.seq_names)
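
Similarly, a minimal sketch for the FastCapture loader above, using the same arguments as the evaluation script later in this commit; the data_root location is an assumption.

from cotracker.datasets.fast_capture_dataset import FastCaptureDataset

dataset = FastCaptureDataset(data_root="./fastcapture", max_seq_len=100, max_num_points=20)
sample = dataset[0]  # CoTrackerData with up to 20 tracks, all visible in the first frame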

View File

@@ -0,0 +1,494 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import torch
import imageio
import numpy as np
from cotracker.datasets.utils import CoTrackerData
from torchvision.transforms import ColorJitter, GaussianBlur
from PIL import Image
import cv2
class CoTrackerDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
crop_size=(384, 512),
seq_len=24,
traj_per_sample=768,
sample_vis_1st_frame=False,
use_augs=False,
):
super(CoTrackerDataset, self).__init__()
np.random.seed(0)
torch.manual_seed(0)
self.data_root = data_root
self.seq_len = seq_len
self.traj_per_sample = traj_per_sample
self.sample_vis_1st_frame = sample_vis_1st_frame
self.use_augs = use_augs
self.crop_size = crop_size
# photometric augmentation
self.photo_aug = ColorJitter(
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14
)
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
self.blur_aug_prob = 0.25
self.color_aug_prob = 0.25
# occlusion augmentation
self.eraser_aug_prob = 0.5
self.eraser_bounds = [2, 100]
self.eraser_max = 10
# occlusion augmentation
self.replace_aug_prob = 0.5
self.replace_bounds = [2, 100]
self.replace_max = 10
# spatial augmentations
self.pad_bounds = [0, 100]
self.crop_size = crop_size
self.resize_lim = [0.25, 2.0] # sample resizes from here
self.resize_delta = 0.2
self.max_crop_offset = 50
self.do_flip = True
self.h_flip_prob = 0.5
self.v_flip_prob = 0.5
def getitem_helper(self, index):
raise NotImplementedError
def __getitem__(self, index):
gotit = False
sample, gotit = self.getitem_helper(index)
if not gotit:
print("warning: sampling failed")
# fake sample, so we can still collate
sample = CoTrackerData(
video=torch.zeros(
(self.seq_len, 3, self.crop_size[0], self.crop_size[1])
),
segmentation=torch.zeros(
(self.seq_len, 1, self.crop_size[0], self.crop_size[1])
),
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
)
return sample, gotit
def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
T, N, _ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
if eraser:
############ eraser transform (per image after the first) ############
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
for i in range(1, S):
if np.random.rand() < self.eraser_aug_prob:
for _ in range(
np.random.randint(1, self.eraser_max + 1)
): # number of times to occlude
xc = np.random.randint(0, W)
yc = np.random.randint(0, H)
dx = np.random.randint(
self.eraser_bounds[0], self.eraser_bounds[1]
)
dy = np.random.randint(
self.eraser_bounds[0], self.eraser_bounds[1]
)
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
mean_color = np.mean(
rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0
)
rgbs[i][y0:y1, x0:x1, :] = mean_color
occ_inds = np.logical_and(
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
)
visibles[i, occ_inds] = 0
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
if replace:
rgbs_alt = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
]
rgbs_alt = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs_alt
]
############ replace transform (per image after the first) ############
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
for i in range(1, S):
if np.random.rand() < self.replace_aug_prob:
for _ in range(
np.random.randint(1, self.replace_max + 1)
): # number of times to occlude
xc = np.random.randint(0, W)
yc = np.random.randint(0, H)
dx = np.random.randint(
self.replace_bounds[0], self.replace_bounds[1]
)
dy = np.random.randint(
self.replace_bounds[0], self.replace_bounds[1]
)
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
wid = x1 - x0
hei = y1 - y0
y00 = np.random.randint(0, H - hei)
x00 = np.random.randint(0, W - wid)
fr = np.random.randint(0, S)
rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
rgbs[i][y0:y1, x0:x1, :] = rep
occ_inds = np.logical_and(
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
)
visibles[i, occ_inds] = 0
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
############ photometric augmentation ############
if np.random.rand() < self.color_aug_prob:
# random per-frame amount of aug
rgbs = [
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
]
if np.random.rand() < self.blur_aug_prob:
# random per-frame amount of blur
rgbs = [
np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8)
for rgb in rgbs
]
return rgbs, trajs, visibles
def add_spatial_augs(self, rgbs, trajs, visibles):
T, N, __ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
############ spatial transform ############
# padding
pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
rgbs = [
np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs
]
trajs[:, :, 0] += pad_x0
trajs[:, :, 1] += pad_y0
H, W = rgbs[0].shape[:2]
# scaling + stretching
scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
scale_x = scale
scale_y = scale
H_new = H
W_new = W
scale_delta_x = 0.0
scale_delta_y = 0.0
rgbs_scaled = []
for s in range(S):
if s == 1:
scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
elif s > 1:
scale_delta_x = (
scale_delta_x * 0.8
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
)
scale_delta_y = (
scale_delta_y * 0.8
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
)
scale_x = scale_x + scale_delta_x
scale_y = scale_y + scale_delta_y
# bring h/w closer
scale_xy = (scale_x + scale_y) * 0.5
scale_x = scale_x * 0.5 + scale_xy * 0.5
scale_y = scale_y * 0.5 + scale_xy * 0.5
# don't get too crazy
scale_x = np.clip(scale_x, 0.2, 2.0)
scale_y = np.clip(scale_y, 0.2, 2.0)
H_new = int(H * scale_y)
W_new = int(W * scale_x)
# make it at least slightly bigger than the crop area,
# so that the random cropping can add diversity
H_new = np.clip(H_new, self.crop_size[0] + 10, None)
W_new = np.clip(W_new, self.crop_size[1] + 10, None)
# recompute scale in case we clipped
scale_x = W_new / float(W)
scale_y = H_new / float(H)
rgbs_scaled.append(
cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR)
)
trajs[s, :, 0] *= scale_x
trajs[s, :, 1] *= scale_y
rgbs = rgbs_scaled
ok_inds = visibles[0, :] > 0
vis_trajs = trajs[:, ok_inds] # S,?,2
if vis_trajs.shape[1] > 0:
mid_x = np.mean(vis_trajs[0, :, 0])
mid_y = np.mean(vis_trajs[0, :, 1])
else:
mid_y = self.crop_size[0]
mid_x = self.crop_size[1]
x0 = int(mid_x - self.crop_size[1] // 2)
y0 = int(mid_y - self.crop_size[0] // 2)
offset_x = 0
offset_y = 0
for s in range(S):
# on each frame, shift a bit more
if s == 1:
offset_x = np.random.randint(
-self.max_crop_offset, self.max_crop_offset
)
offset_y = np.random.randint(
-self.max_crop_offset, self.max_crop_offset
)
elif s > 1:
offset_x = int(
offset_x * 0.8
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
* 0.2
)
offset_y = int(
offset_y * 0.8
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
* 0.2
)
x0 = x0 + offset_x
y0 = y0 + offset_y
H_new, W_new = rgbs[s].shape[:2]
if H_new == self.crop_size[0]:
y0 = 0
else:
y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)
if W_new == self.crop_size[1]:
x0 = 0
else:
x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)
rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
trajs[s, :, 0] -= x0
trajs[s, :, 1] -= y0
H_new = self.crop_size[0]
W_new = self.crop_size[1]
# flip
h_flipped = False
v_flipped = False
if self.do_flip:
# h flip
if np.random.rand() < self.h_flip_prob:
h_flipped = True
rgbs = [rgb[:, ::-1] for rgb in rgbs]
# v flip
if np.random.rand() < self.v_flip_prob:
v_flipped = True
rgbs = [rgb[::-1] for rgb in rgbs]
if h_flipped:
trajs[:, :, 0] = W_new - trajs[:, :, 0]
if v_flipped:
trajs[:, :, 1] = H_new - trajs[:, :, 1]
return rgbs, trajs
def crop(self, rgbs, trajs):
T, N, _ = trajs.shape
S = len(rgbs)
H, W = rgbs[0].shape[:2]
assert S == T
############ spatial transform ############
H_new = H
W_new = W
# simple random crop
y0 = (
0
if self.crop_size[0] >= H_new
else np.random.randint(0, H_new - self.crop_size[0])
)
x0 = (
0
if self.crop_size[1] >= W_new
else np.random.randint(0, W_new - self.crop_size[1])
)
rgbs = [
rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
for rgb in rgbs
]
trajs[:, :, 0] -= x0
trajs[:, :, 1] -= y0
return rgbs, trajs
class KubricMovifDataset(CoTrackerDataset):
def __init__(
self,
data_root,
crop_size=(384, 512),
seq_len=24,
traj_per_sample=768,
sample_vis_1st_frame=False,
use_augs=False,
):
super(KubricMovifDataset, self).__init__(
data_root=data_root,
crop_size=crop_size,
seq_len=seq_len,
traj_per_sample=traj_per_sample,
sample_vis_1st_frame=sample_vis_1st_frame,
use_augs=use_augs,
)
self.pad_bounds = [0, 25]
self.resize_lim = [0.75, 1.25] # sample resizes from here
self.resize_delta = 0.05
self.max_crop_offset = 15
self.seq_names = [
fname
for fname in os.listdir(data_root)
if os.path.isdir(os.path.join(data_root, fname))
]
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
def getitem_helper(self, index):
gotit = True
seq_name = self.seq_names[index]
npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
rgb_path = os.path.join(self.data_root, seq_name, "frames")
img_paths = sorted(os.listdir(rgb_path))
rgbs = []
for i, img_path in enumerate(img_paths):
rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
rgbs = np.stack(rgbs)
annot_dict = np.load(npy_path, allow_pickle=True).item()
traj_2d = annot_dict["coords"]
visibility = annot_dict["visibility"]
# random crop
assert self.seq_len <= len(rgbs)
if self.seq_len < len(rgbs):
start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]
rgbs = rgbs[start_ind : start_ind + self.seq_len]
traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
visibility = visibility[:, start_ind : start_ind + self.seq_len]
traj_2d = np.transpose(traj_2d, (1, 0, 2))
visibility = np.transpose(np.logical_not(visibility), (1, 0))
if self.use_augs:
rgbs, traj_2d, visibility = self.add_photometric_augs(
rgbs, traj_2d, visibility
)
rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
else:
rgbs, traj_2d = self.crop(rgbs, traj_2d)
visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
visibility[traj_2d[:, :, 0] < 0] = False
visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
visibility[traj_2d[:, :, 1] < 0] = False
visibility = torch.from_numpy(visibility)
traj_2d = torch.from_numpy(traj_2d)
visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
if self.sample_vis_1st_frame:
visibile_pts_inds = visibile_pts_first_frame_inds
else:
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(
as_tuple=False
)[:, 0]
visibile_pts_inds = torch.cat(
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
)
point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
if len(point_inds) < self.traj_per_sample:
gotit = False
visible_inds_sampled = visibile_pts_inds[point_inds]
trajs = traj_2d[:, visible_inds_sampled].float()
visibles = visibility[:, visible_inds_sampled]
valids = torch.ones((self.seq_len, self.traj_per_sample))
rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
segs = torch.ones((self.seq_len, 1, self.crop_size[0], self.crop_size[1]))
sample = CoTrackerData(
video=rgbs,
segmentation=segs,
trajectory=trajs,
visibility=visibles,
valid=valids,
seq_name=seq_name,
)
return sample, gotit
def __len__(self):
return len(self.seq_names)
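
A sketch of how the Kubric training loader above might be constructed; the data_root location is an assumption, and its layout (one sub-folder per sequence containing frames/ and <seq_name>.npy) follows what getitem_helper reads.

train_dataset = KubricMovifDataset(
    data_root="./kubric_movif",  # assumed location of the preprocessed sequences
    crop_size=(384, 512),
    seq_len=24,
    traj_per_sample=768,
    use_augs=True,
)
sample, gotit = train_dataset[0]  # (CoTrackerData, gotit); gotit is False if too few tracks were sampled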

cotracker/datasets/tap_vid_datasets.py Normal file
View File

@@ -0,0 +1,218 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import io
import glob
import torch
import pickle
import numpy as np
import mediapy as media
from PIL import Image
from typing import Mapping, Tuple, Union
from cotracker.datasets.utils import CoTrackerData
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
"""Resize a video to output_size."""
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
# such as a jitted jax.image.resize. It will make things faster.
return media.resize_video(video, output_size)
def sample_queries_first(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
) -> Mapping[str, np.ndarray]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, use the first
visible point in each track as the query.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
where True indicates occluded.
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
is [x,y] scaled between 0 and 1.
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
-1 and 1.
Returns:
A dict with the keys:
video: Video tensor of shape [1, n_frames, height, width, 3]
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1]
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1]
"""
valid = np.sum(~target_occluded, axis=1) > 0
target_points = target_points[valid, :]
target_occluded = target_occluded[valid, :]
query_points = []
for i in range(target_points.shape[0]):
index = np.where(target_occluded[i] == 0)[0][0]
x, y = target_points[i, index, 0], target_points[i, index, 1]
query_points.append(np.array([index, y, x])) # [t, y, x]
query_points = np.stack(query_points, axis=0)
return {
"video": frames[np.newaxis, ...],
"query_points": query_points[np.newaxis, ...],
"target_points": target_points[np.newaxis, ...],
"occluded": target_occluded[np.newaxis, ...],
}
def sample_queries_strided(
target_occluded: np.ndarray,
target_points: np.ndarray,
frames: np.ndarray,
query_stride: int = 5,
) -> Mapping[str, np.ndarray]:
"""Package a set of frames and tracks for use in TAPNet evaluations.
Given a set of frames and tracks with no query points, sample queries
strided every query_stride frames, ignoring points that are not visible
at the selected frames.
Args:
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
where True indicates occluded.
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
is [x,y] scaled between 0 and 1.
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
-1 and 1.
query_stride: When sampling query points, search for un-occluded points
every query_stride frames and convert each one into a query.
Returns:
A dict with the keys:
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
has floats scaled to the range [-1, 1].
query_points: Query points of shape [1, n_queries, 3] where
each point is [t, y, x] scaled to the range [-1, 1].
target_points: Target points of shape [1, n_queries, n_frames, 2] where
each point is [x, y] scaled to the range [-1, 1].
trackgroup: Index of the original track that each query point was
sampled from. This is useful for visualization.
"""
tracks = []
occs = []
queries = []
trackgroups = []
total = 0
trackgroup = np.arange(target_occluded.shape[0])
for i in range(0, target_occluded.shape[1], query_stride):
mask = target_occluded[:, i] == 0
query = np.stack(
[
i * np.ones(target_occluded.shape[0:1]),
target_points[:, i, 1],
target_points[:, i, 0],
],
axis=-1,
)
queries.append(query[mask])
tracks.append(target_points[mask])
occs.append(target_occluded[mask])
trackgroups.append(trackgroup[mask])
total += np.array(np.sum(target_occluded[:, i] == 0))
return {
"video": frames[np.newaxis, ...],
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
}
class TapVidDataset(torch.utils.data.Dataset):
def __init__(
self,
data_root,
dataset_type="davis",
resize_to_256=True,
queried_first=True,
):
self.dataset_type = dataset_type
self.resize_to_256 = resize_to_256
self.queried_first = queried_first
if self.dataset_type == "kinetics":
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
points_dataset = []
for pickle_path in all_paths:
with open(pickle_path, "rb") as f:
data = pickle.load(f)
points_dataset = points_dataset + data
self.points_dataset = points_dataset
else:
with open(data_root, "rb") as f:
self.points_dataset = pickle.load(f)
if self.dataset_type == "davis":
self.video_names = list(self.points_dataset.keys())
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
def __getitem__(self, index):
if self.dataset_type == "davis":
video_name = self.video_names[index]
else:
video_name = index
video = self.points_dataset[video_name]
frames = video["video"]
if isinstance(frames[0], bytes):
# TAP-Vid frames are stored as JPEG bytes rather than `np.ndarray`s.
def decode(frame):
byteio = io.BytesIO(frame)
img = Image.open(byteio)
return np.array(img)
frames = np.array([decode(frame) for frame in frames])
target_points = self.points_dataset[video_name]["points"]
if self.resize_to_256:
frames = resize_video(frames, [256, 256])
target_points *= np.array([256, 256])
else:
target_points *= np.array([frames.shape[2], frames.shape[1]])
T, H, W, C = frames.shape
N, T, D = target_points.shape
target_occ = self.points_dataset[video_name]["occluded"]
if self.queried_first:
converted = sample_queries_first(target_occ, target_points, frames)
else:
converted = sample_queries_strided(target_occ, target_points, frames)
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
trajs = (
torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float()
) # T, N, D
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
segs = torch.ones(T, 1, H, W).float()
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[
0
].permute(
1, 0
) # T, N
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
return CoTrackerData(
rgbs,
segs,
trajs,
visibles,
seq_name=str(video_name),
query_points=query_points,
)
def __len__(self):
return len(self.points_dataset)
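
For reference, a minimal sketch for the TAP-Vid loader above, batched with the collate_fn from cotracker/datasets/utils.py; the pickle path is an assumption.

import torch
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.utils import collate_fn

dataset = TapVidDataset(
    data_root="./tapvid_davis/tapvid_davis.pkl",  # assumed download location
    dataset_type="davis",
    queried_first=True,
)
loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)
sample = next(iter(loader))  # CoTrackerData; query_points are in [t, y, x] order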

114
cotracker/datasets/utils.py Normal file
View File

@@ -0,0 +1,114 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import dataclasses
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Any, Optional
@dataclass(eq=False)
class CoTrackerData:
"""
Dataclass for storing video tracks data.
"""
video: torch.Tensor # B, S, C, H, W
segmentation: torch.Tensor # B, S, 1, H, W
trajectory: torch.Tensor # B, S, N, 2
visibility: torch.Tensor # B, S, N
# optional data
valid: Optional[torch.Tensor] = None # B, S, N
seq_name: Optional[str] = None
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
def collate_fn(batch):
"""
Collate function for video tracks data.
"""
video = torch.stack([b.video for b in batch], dim=0)
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
trajectory = torch.stack([b.trajectory for b in batch], dim=0)
visibility = torch.stack([b.visibility for b in batch], dim=0)
query_points = None
if batch[0].query_points is not None:
query_points = torch.stack([b.query_points for b in batch], dim=0)
seq_name = [b.seq_name for b in batch]
return CoTrackerData(
video,
segmentation,
trajectory,
visibility,
seq_name=seq_name,
query_points=query_points,
)
def collate_fn_train(batch):
"""
Collate function for video tracks data during training.
"""
gotit = [gotit for _, gotit in batch]
video = torch.stack([b.video for b, _ in batch], dim=0)
segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0)
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
valid = torch.stack([b.valid for b, _ in batch], dim=0)
seq_name = [b.seq_name for b, _ in batch]
return (
CoTrackerData(video, segmentation, trajectory, visibility, valid, seq_name),
gotit,
)
def try_to_cuda(t: Any) -> Any:
"""
Try to move the input variable `t` to a cuda device.
Args:
t: Input.
Returns:
t_cuda: `t` moved to a cuda device, if supported.
"""
try:
t = t.float().cuda()
except AttributeError:
pass
return t
def dataclass_to_cuda_(obj):
"""
Move all contents of a dataclass to cuda inplace if supported.
Args:
batch: Input dataclass.
Returns:
batch_cuda: `batch` moved to a cuda device, if supported.
"""
for f in dataclasses.fields(obj):
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
return obj
def resize_sample(rgbs, trajs_g, segs, interp_shape):
S, C, H, W = rgbs.shape
S, N, D = trajs_g.shape
assert D == 2
rgbs = F.interpolate(rgbs, interp_shape, mode="bilinear")
segs = F.interpolate(segs, interp_shape, mode="nearest")
trajs_g[:, :, 0] *= interp_shape[1] / W
trajs_g[:, :, 1] *= interp_shape[0] / H
return rgbs, trajs_g, segs
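
A short sketch of how these helpers fit together: any dataset that yields CoTrackerData can be batched with collate_fn and moved to the GPU in place with dataclass_to_cuda_; the dataset variable below is a placeholder.

import torch
from cotracker.datasets.utils import collate_fn, dataclass_to_cuda_

loader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=collate_fn)  # `dataset` yields CoTrackerData
for batch in loader:
    dataclass_to_cuda_(batch)  # tensor fields become float CUDA tensors; seq_name is left as-is
    video = batch.video  # B, S, C, H, W
    break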

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: badja

View File

@@ -0,0 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: fastcapture

View File

@@ -0,0 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_first

View File

@@ -0,0 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_strided

View File

@@ -0,0 +1,6 @@
defaults:
- default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_kinetics_first

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cotracker/evaluation/core/eval_utils.py Normal file
View File

@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import numpy as np
from typing import Iterable, Mapping, Tuple, Union
def compute_tapvid_metrics(
query_points: np.ndarray,
gt_occluded: np.ndarray,
gt_tracks: np.ndarray,
pred_occluded: np.ndarray,
pred_tracks: np.ndarray,
query_mode: str,
) -> Mapping[str, np.ndarray]:
"""Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
See the TAP-Vid paper for details on the metric computation. All inputs are
given in raster coordinates. The first three arguments should be the direct
outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
The paper metrics assume these are scaled relative to 256x256 images.
pred_occluded and pred_tracks are your algorithm's predictions.
This function takes a batch of inputs, and computes metrics separately for
each video. The metrics for the full benchmark are a simple mean of the
metrics across the full set of videos. These numbers are between 0 and 1,
but the paper multiplies them by 100 to ease reading.
Args:
query_points: The query points, in the format [t, y, x]. Its size is
[b, n, 3], where b is the batch size and n is the number of queries.
gt_occluded: A boolean array of shape [b, n, t], where t is the number
of frames. True indicates that the point is occluded.
gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
in the format [x, y]
pred_occluded: A boolean array of predicted occlusions, in the same
format as gt_occluded.
pred_tracks: An array of track predictions from your algorithm, in the
same format as gt_tracks.
query_mode: Either 'first' or 'strided', depending on how queries are
sampled. If 'first', we assume the prior knowledge that all points
before the query point are occluded, and these are removed from the
evaluation.
Returns:
A dict with the following keys:
occlusion_accuracy: Accuracy at predicting occlusion.
pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
predicted to be within the given pixel threshold, ignoring occlusion
prediction.
jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
threshold
average_pts_within_thresh: average across pts_within_{x}
average_jaccard: average across jaccard_{x}
"""
metrics = {}
# Don't evaluate the query point. Numpy doesn't have one_hot, so we
# replicate it by indexing into an identity matrix.
one_hot_eye = np.eye(gt_tracks.shape[2])
query_frame = query_points[..., 0]
query_frame = np.round(query_frame).astype(np.int32)
evaluation_points = one_hot_eye[query_frame] == 0
# If we're using the first point on the track as a query, don't evaluate the
# other points.
if query_mode == "first":
for i in range(gt_occluded.shape[0]):
index = np.where(gt_occluded[i] == 0)[0][0]
evaluation_points[i, :index] = False
elif query_mode != "strided":
raise ValueError("Unknown query mode " + query_mode)
# Occlusion accuracy is simply how often the predicted occlusion equals the
# ground truth.
occ_acc = (
np.sum(
np.equal(pred_occluded, gt_occluded) & evaluation_points,
axis=(1, 2),
)
/ np.sum(evaluation_points)
)
metrics["occlusion_accuracy"] = occ_acc
# Next, convert the predictions and ground truth positions into pixel
# coordinates.
visible = np.logical_not(gt_occluded)
pred_visible = np.logical_not(pred_occluded)
all_frac_within = []
all_jaccard = []
for thresh in [1, 2, 4, 8, 16]:
# True positives are points that are within the threshold and where both
# the prediction and the ground truth are listed as visible.
within_dist = (
np.sum(
np.square(pred_tracks - gt_tracks),
axis=-1,
)
< np.square(thresh)
)
is_correct = np.logical_and(within_dist, visible)
# Compute the frac_within_threshold, which is the fraction of points
# within the threshold among points that are visible in the ground truth,
# ignoring whether they're predicted to be visible.
count_correct = np.sum(
is_correct & evaluation_points,
axis=(1, 2),
)
count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
frac_correct = count_correct / count_visible_points
metrics["pts_within_" + str(thresh)] = frac_correct
all_frac_within.append(frac_correct)
true_positives = np.sum(
is_correct & pred_visible & evaluation_points, axis=(1, 2)
)
# The denominator of the jaccard metric is the true positives plus
# false positives plus false negatives. However, note that true positives
# plus false negatives is simply the number of points in the ground truth
# which is easier to compute than trying to compute all three quantities.
# Thus we just add the number of points in the ground truth to the number
# of false positives.
#
# False positives are simply points that are predicted to be visible,
# but the ground truth is not visible or too far from the prediction.
gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
false_positives = (~visible) & pred_visible
false_positives = false_positives | ((~within_dist) & pred_visible)
false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
jaccard = true_positives / (gt_positives + false_positives)
metrics["jaccard_" + str(thresh)] = jaccard
all_jaccard.append(jaccard)
metrics["average_jaccard"] = np.mean(
np.stack(all_jaccard, axis=1),
axis=1,
)
metrics["average_pts_within_thresh"] = np.mean(
np.stack(all_frac_within, axis=1),
axis=1,
)
return metrics
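
A toy invocation of the metric above, with shapes as in the docstring (one video, two tracks, three frames; the numbers are made up). Every prediction lies within 1 px of the ground truth and the occlusion flags match, so all metrics come out to 1.

import numpy as np
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics

gt_occluded = np.array([[[False, False, True], [False, False, False]]])  # [b=1, n=2, t=3]
gt_tracks = np.zeros((1, 2, 3, 2))  # [b, n, t, 2], [x, y] in 256x256 raster coordinates
query_points = np.array([[[0, 0.0, 0.0], [0, 0.0, 0.0]]])  # [b, n, 3], [t, y, x]
out = compute_tapvid_metrics(
    query_points, gt_occluded, gt_tracks,
    pred_occluded=gt_occluded.copy(), pred_tracks=gt_tracks + 0.5,
    query_mode="first",
)
print(out["occlusion_accuracy"], out["average_jaccard"])  # both 1.0 for this toy input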

cotracker/evaluation/core/evaluator.py Normal file
View File

@@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from collections import defaultdict
import os
from typing import Optional
import torch
from tqdm import tqdm
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from cotracker.datasets.utils import dataclass_to_cuda_
from cotracker.utils.visualizer import Visualizer
from cotracker.models.core.model_utils import reduce_masked_mean
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
import logging
class Evaluator:
"""
A class defining the CoTracker evaluator.
"""
def __init__(self, exp_dir) -> None:
# Visualization
self.exp_dir = exp_dir
os.makedirs(exp_dir, exist_ok=True)
self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
self.visualize_dir = os.path.join(exp_dir, "visualisations")
def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
if isinstance(pred_trajectory, tuple):
pred_trajectory, pred_visibility = pred_trajectory
else:
pred_visibility = None
if dataset_name == "badja":
sample.segmentation = (sample.segmentation > 0).float()
*_, N, _ = sample.trajectory.shape
accs = []
accs_3px = []
for s1 in range(1, sample.video.shape[1]): # target frame
for n in range(N):
vis = sample.visibility[0, s1, n]
if vis > 0:
coord_e = pred_trajectory[0, s1, n] # 2
coord_g = sample.trajectory[0, s1, n] # 2
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
area = torch.sum(sample.segmentation[0, s1])
# print_('0.2*sqrt(area)', 0.2*torch.sqrt(area))
thr = 0.2 * torch.sqrt(area)
# correct =
accs.append((dist < thr).float())
# print('thr',thr)
accs_3px.append((dist < 3.0).float())
res = torch.mean(torch.stack(accs)) * 100.0
res_3px = torch.mean(torch.stack(accs_3px)) * 100.0
metrics[sample.seq_name[0]] = res.item()
metrics[sample.seq_name[0] + "_accuracy"] = res_3px.item()
print(metrics)
print(
"avg", np.mean([v for k, v in metrics.items() if "accuracy" not in k])
)
print(
"avg acc 3px",
np.mean([v for k, v in metrics.items() if "accuracy" in k]),
)
elif dataset_name == "fastcapture" or ("kubric" in dataset_name):
*_, N, _ = sample.trajectory.shape
accs = []
for s1 in range(1, sample.video.shape[1]): # target frame
for n in range(N):
vis = sample.visibility[0, s1, n]
if vis > 0:
coord_e = pred_trajectory[0, s1, n] # 2
coord_g = sample.trajectory[0, s1, n] # 2
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
thr = 3
correct = (dist < thr).float()
accs.append(correct)
res = torch.mean(torch.stack(accs)) * 100.0
metrics[sample.seq_name[0] + "_accuracy"] = res.item()
print(metrics)
print("avg", np.mean([v for v in metrics.values()]))
elif "tapvid" in dataset_name:
B, T, N, D = sample.trajectory.shape
traj = sample.trajectory.clone()
thr = 0.9
if pred_visibility is None:
logging.warning("visibility is NONE")
pred_visibility = torch.zeros_like(sample.visibility)
if not pred_visibility.dtype == torch.bool:
pred_visibility = pred_visibility > thr
# pred_trajectory
query_points = sample.query_points.clone().cpu().numpy()
pred_visibility = pred_visibility[:, :, :N]
pred_trajectory = pred_trajectory[:, :, :N]
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
gt_occluded = (
torch.logical_not(sample.visibility.clone().permute(0, 2, 1))
.cpu()
.numpy()
)
pred_occluded = (
torch.logical_not(pred_visibility.clone().permute(0, 2, 1))
.cpu()
.numpy()
)
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
out_metrics = compute_tapvid_metrics(
query_points,
gt_occluded,
gt_tracks,
pred_occluded,
pred_tracks,
query_mode="strided" if "strided" in dataset_name else "first",
)
metrics[sample.seq_name[0]] = out_metrics
for metric_name in out_metrics.keys():
if "avg" not in metrics:
metrics["avg"] = {}
metrics["avg"][metric_name] = np.mean(
[v[metric_name] for k, v in metrics.items() if k != "avg"]
)
logging.info(f"Metrics: {out_metrics}")
logging.info(f"avg: {metrics['avg']}")
print("metrics", out_metrics)
print("avg", metrics["avg"])
else:
rgbs = sample.video
trajs_g = sample.trajectory
valids = sample.valid
vis_g = sample.visibility
B, S, C, H, W = rgbs.shape
assert C == 3
B, S, N, D = trajs_g.shape
assert torch.sum(valids) == B * S * N
vis_g = (torch.sum(vis_g, dim=1, keepdim=True) >= 4).float().repeat(1, S, 1)
ate = torch.norm(pred_trajectory - trajs_g, dim=-1) # B, S, N
metrics["things_all"] = reduce_masked_mean(ate, valids).item()
metrics["things_vis"] = reduce_masked_mean(ate, valids * vis_g).item()
metrics["things_occ"] = reduce_masked_mean(
ate, valids * (1.0 - vis_g)
).item()
@torch.no_grad()
def evaluate_sequence(
self,
model,
test_dataloader: torch.utils.data.DataLoader,
dataset_name: str,
train_mode=False,
writer: Optional[SummaryWriter] = None,
step: Optional[int] = 0,
):
metrics = {}
vis = Visualizer(
save_dir=self.exp_dir,
fps=7,
)
for ind, sample in enumerate(tqdm(test_dataloader)):
if isinstance(sample, tuple):
sample, gotit = sample
if not all(gotit):
print("batch is None")
continue
dataclass_to_cuda_(sample)
if (
not train_mode
and hasattr(model, "sequence_len")
and (sample.visibility[:, : model.sequence_len].sum() == 0)
):
print(f"skipping batch {ind}")
continue
if "tapvid" in dataset_name:
queries = sample.query_points.clone().float()
queries = torch.stack(
[
queries[:, :, 0],
queries[:, :, 2],
queries[:, :, 1],
],
dim=2,
)
else:
queries = torch.cat(
[
torch.zeros_like(sample.trajectory[:, 0, :, :1]),
sample.trajectory[:, 0],
],
dim=2,
)
pred_tracks = model(sample.video, queries)
if "strided" in dataset_name:
inv_video = sample.video.flip(1).clone()
inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
pred_trj, pred_vsb = pred_tracks
inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
inv_pred_trj = inv_pred_trj.flip(1)
inv_pred_vsb = inv_pred_vsb.flip(1)
mask = pred_trj == 0
pred_trj[mask] = inv_pred_trj[mask]
pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
pred_tracks = pred_trj, pred_vsb
if dataset_name == "badja" or dataset_name == "fastcapture":
seq_name = sample.seq_name[0]
else:
seq_name = str(ind)
vis.visualize(
sample.video,
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
filename=dataset_name + "_" + seq_name,
writer=writer,
step=step,
)
self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
return metrics

View File

@@ -0,0 +1,179 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import json
import os
from dataclasses import dataclass, field
import hydra
import numpy as np
import torch
from omegaconf import OmegaConf
from cotracker.datasets.badja_dataset import BadjaDataset
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.utils import collate_fn
from cotracker.models.evaluation_predictor import EvaluationPredictor
from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.models.build_cotracker import (
build_cotracker,
)
@dataclass(eq=False)
class DefaultConfig:
# Directory where all outputs of the experiment will be saved.
exp_dir: str = "./outputs"
# Name of the dataset to be used for the evaluation.
dataset_name: str = "badja"
# The root directory of the dataset.
dataset_root: str = "./"
# Path to the pre-trained model checkpoint to be used for the evaluation.
# The default value is the path to a specific CoTracker model checkpoint.
# Other available options are commented.
checkpoint: str = "./checkpoints/cotracker_stride_4_wind_8.pth"
# cotracker_stride_4_wind_12
# cotracker_stride_8_wind_16
# EvaluationPredictor parameters
# The size (N) of the support grid used in the predictor.
# The total number of points is (N*N).
grid_size: int = 6
# The size (N) of the local support grid.
local_grid_size: int = 6
# A flag indicating whether to evaluate one ground truth point at a time.
single_point: bool = True
# The number of iterative updates for each sliding window.
n_iters: int = 6
seed: int = 0
gpu_idx: int = 0
# Override hydra's working directory to current working dir,
# also disable storing the .hydra logs:
hydra: dict = field(
default_factory=lambda: {
"run": {"dir": "."},
"output_subdir": None,
}
)
def run_eval(cfg: DefaultConfig):
"""
The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
Args:
cfg (DefaultConfig): An instance of DefaultConfig class which includes:
- exp_dir (str): The directory path for the experiment.
- dataset_name (str): The name of the dataset to be used.
- dataset_root (str): The root directory of the dataset.
- checkpoint (str): The path to the CoTracker model's checkpoint.
- single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
- n_iters (int): The number of iterative updates for each sliding window.
- seed (int): The seed for setting the random state for reproducibility.
- gpu_idx (int): The index of the GPU to be used.
"""
# Creating the experiment directory if it doesn't exist
os.makedirs(cfg.exp_dir, exist_ok=True)
# Saving the experiment configuration to a .yaml file in the experiment directory
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
with open(cfg_file, "w") as f:
OmegaConf.save(config=cfg, f=f)
evaluator = Evaluator(cfg.exp_dir)
cotracker_model = build_cotracker(cfg.checkpoint)
# Creating the EvaluationPredictor object
predictor = EvaluationPredictor(
cotracker_model,
grid_size=cfg.grid_size,
local_grid_size=cfg.local_grid_size,
single_point=cfg.single_point,
n_iters=cfg.n_iters,
)
# Setting the random seeds
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)
# Constructing the specified dataset
curr_collate_fn = collate_fn
if cfg.dataset_name == "badja":
test_dataset = BadjaDataset(data_root=os.path.join(cfg.dataset_root, "BADJA"))
elif cfg.dataset_name == "fastcapture":
test_dataset = FastCaptureDataset(
data_root=os.path.join(cfg.dataset_root, "fastcapture"),
max_seq_len=100,
max_num_points=20,
)
elif "tapvid" in cfg.dataset_name:
dataset_type = cfg.dataset_name.split("_")[1]
if dataset_type == "davis":
data_root = os.path.join(cfg.dataset_root, "/tapvid_davis/tapvid_davis.pkl")
elif dataset_type == "kinetics":
data_root = os.path.join(
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
)
test_dataset = TapVidDataset(
dataset_type=dataset_type,
data_root=data_root,
queried_first="strided" not in cfg.dataset_name,
)
# Creating the DataLoader object
test_dataloader = torch.utils.data.DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=14,
collate_fn=curr_collate_fn,
)
# Timing and conducting the evaluation
import time
start = time.time()
evaluate_result = evaluator.evaluate_sequence(
predictor,
test_dataloader,
dataset_name=cfg.dataset_name,
)
end = time.time()
print(end - start)
# Saving the evaluation results to a .json file
if not "tapvid" in cfg.dataset_name:
print("evaluate_result", evaluate_result)
else:
evaluate_result = evaluate_result["avg"]
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
evaluate_result["time"] = end - start
print(f"Dumping eval results to {result_file}.")
with open(result_file, "w") as f:
json.dump(evaluate_result, f)
cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config_eval", node=DefaultConfig)
@hydra.main(config_path="./configs/", config_name="default_config_eval")
def evaluate(cfg: DefaultConfig) -> None:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
run_eval(cfg)
if __name__ == "__main__":
evaluate()
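
Evaluation can also be driven programmatically with the dataclass above instead of going through hydra; a sketch follows (the checkpoint and dataset locations are assumptions and must exist on disk).

cfg = DefaultConfig(
    exp_dir="./outputs/cotracker",
    dataset_name="tapvid_davis_first",
    dataset_root="./",
    checkpoint="./checkpoints/cotracker_stride_4_wind_8.pth",
)
run_eval(cfg)  # writes expconfig.yaml and result_eval_.json into exp_dir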

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

cotracker/models/build_cotracker.py Normal file
View File

@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
from cotracker.models.core.cotracker.cotracker import CoTracker
def build_cotracker(
checkpoint: str,
):
model_name = checkpoint.split("/")[-1].split(".")[0]
if model_name == "cotracker_stride_4_wind_8":
return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
elif model_name == "cotracker_stride_4_wind_12":
return build_cotracker_stride_4_wind_12(checkpoint=checkpoint)
elif model_name == "cotracker_stride_8_wind_16":
return build_cotracker_stride_8_wind_16(checkpoint=checkpoint)
else:
raise ValueError(f"Unknown model name {model_name}")
# model used to produce the results in the paper
def build_cotracker_stride_4_wind_8(checkpoint=None):
return _build_cotracker(
stride=4,
sequence_len=8,
checkpoint=checkpoint,
)
def build_cotracker_stride_4_wind_12(checkpoint=None):
return _build_cotracker(
stride=4,
sequence_len=12,
checkpoint=checkpoint,
)
# the fastest model
def build_cotracker_stride_8_wind_16(checkpoint=None):
return _build_cotracker(
stride=8,
sequence_len=16,
checkpoint=checkpoint,
)
def _build_cotracker(
stride,
sequence_len,
checkpoint=None,
):
cotracker = CoTracker(
stride=stride,
S=sequence_len,
add_space_attn=True,
space_depth=6,
time_depth=6,
)
if checkpoint is not None:
with open(checkpoint, "rb") as f:
state_dict = torch.load(f, map_location="cpu")
if "model" in state_dict:
state_dict = state_dict["model"]
cotracker.load_state_dict(state_dict)
return cotracker
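
A minimal sketch of loading a checkpoint with the builder above; the checkpoint path is an assumption, and its filename must follow the cotracker_stride_<s>_wind_<w> naming, since that is what selects the architecture.

from cotracker.models.build_cotracker import build_cotracker

model = build_cotracker("./checkpoints/cotracker_stride_4_wind_8.pth")
model.eval()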

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,400 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from timm.models.vision_transformer import Attention, Mlp
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
padding=1,
stride=stride,
padding_mode="zeros",
)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == "none":
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class BasicEncoder(nn.Module):
def __init__(
self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0
):
super(BasicEncoder, self).__init__()
self.stride = stride
self.norm_fn = norm_fn
self.in_planes = 64
if self.norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
elif self.norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(self.in_planes)
self.norm2 = nn.BatchNorm2d(output_dim * 2)
elif self.norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(self.in_planes)
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
elif self.norm_fn == "none":
self.norm1 = nn.Sequential()
self.conv1 = nn.Conv2d(
input_dim,
self.in_planes,
kernel_size=7,
stride=2,
padding=3,
padding_mode="zeros",
)
self.relu1 = nn.ReLU(inplace=True)
self.shallow = False
if self.shallow:
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
else:
self.layer1 = self._make_layer(64, stride=1)
self.layer2 = self._make_layer(96, stride=2)
self.layer3 = self._make_layer(128, stride=2)
self.layer4 = self._make_layer(128, stride=2)
self.conv2 = nn.Conv2d(
128 + 128 + 96 + 64,
output_dim * 2,
kernel_size=3,
padding=1,
padding_mode="zeros",
)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
self.dropout = None
if dropout > 0:
self.dropout = nn.Dropout2d(p=dropout)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
if m.weight is not None:
nn.init.constant_(m.weight, 1)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2)
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
_, _, H, W = x.shape
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
if self.shallow:
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
a = F.interpolate(
a,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
b = F.interpolate(
b,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
c = F.interpolate(
c,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
x = self.conv2(torch.cat([a, b, c], dim=1))
else:
a = self.layer1(x)
b = self.layer2(a)
c = self.layer3(b)
d = self.layer4(c)
a = F.interpolate(
a,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
b = F.interpolate(
b,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
c = F.interpolate(
c,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
d = F.interpolate(
d,
(H // self.stride, W // self.stride),
mode="bilinear",
align_corners=True,
)
x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.norm2(x)
x = self.relu2(x)
x = self.conv3(x)
if self.training and self.dropout is not None:
x = self.dropout(x)
return x
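
# Illustrative sketch: BasicEncoder maps RGB frames to feature maps downsampled
# by `stride`, regardless of how deep the internal pyramid goes, because every
# pyramid level is interpolated back to (H // stride, W // stride).
def _basic_encoder_demo():
    fnet = BasicEncoder(input_dim=3, output_dim=128, stride=8, norm_fn="instance")
    frames = torch.randn(4, 3, 64, 96)  # (B * S, C, H, W)
    fmaps = fnet(frames)
    assert fmaps.shape == (4, 128, 64 // 8, 96 // 8)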
class AttnBlock(nn.Module):
"""
    A plain pre-norm transformer block (multi-head attention followed by an MLP),
    adapted from DiT but without the adaLN-Zero conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = Attention(
hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(
in_features=hidden_size,
hidden_features=mlp_hidden_dim,
act_layer=approx_gelu,
drop=0,
)
def forward(self, x):
x = x + self.attn(self.norm1(x))
x = x + self.mlp(self.norm2(x))
return x
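
# Illustrative sketch: AttnBlock maps token sequences of shape
# (batch, tokens, hidden_size) to the same shape.
def _attn_block_demo():
    blk = AttnBlock(hidden_size=384, num_heads=8, mlp_ratio=4.0)
    tokens = torch.randn(2, 16, 384)
    assert blk(tokens).shape == (2, 16, 384)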
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
"""Wrapper for grid_sample, uses pixel coordinates"""
H, W = img.shape[-2:]
xgrid, ygrid = coords.split([1, 1], dim=-1)
# go to 0,1 then 0,2 then -1,1
xgrid = 2 * xgrid / (W - 1) - 1
ygrid = 2 * ygrid / (H - 1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True)
if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
return img, mask.float()
return img
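
# Illustrative sketch: bilinear_sampler takes pixel coordinates in (x, y) order
# with shape (B, H_out, W_out, 2) and returns features sampled at those points.
def _bilinear_sampler_demo():
    img = torch.randn(1, 32, 10, 12)  # (B, C, H, W)
    ys, xs = torch.meshgrid(torch.arange(4.0), torch.arange(6.0), indexing="ij")
    coords = torch.stack([xs, ys], dim=-1)[None]  # (1, 4, 6, 2), (x, y) order
    out = bilinear_sampler(img, coords)
    assert out.shape == (1, 32, 4, 6)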
class CorrBlock:
def __init__(self, fmaps, num_levels=4, radius=4):
B, S, C, H, W = fmaps.shape
self.S, self.C, self.H, self.W = S, C, H, W
self.num_levels = num_levels
self.radius = radius
self.fmaps_pyramid = []
self.fmaps_pyramid.append(fmaps)
for i in range(self.num_levels - 1):
fmaps_ = fmaps.reshape(B * S, C, H, W)
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
_, _, H, W = fmaps_.shape
fmaps = fmaps_.reshape(B, S, C, H, W)
self.fmaps_pyramid.append(fmaps)
def sample(self, coords):
r = self.radius
B, S, N, D = coords.shape
assert D == 2
H, W = self.H, self.W
out_pyramid = []
for i in range(self.num_levels):
corrs = self.corrs_pyramid[i] # B, S, N, H, W
_, _, _, H, W = corrs.shape
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
coords.device
)
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
corrs = corrs.view(B, S, N, -1)
out_pyramid.append(corrs)
        out = torch.cat(out_pyramid, dim=-1)  # B, S, N, num_levels * (2*r+1)**2
return out.contiguous().float()
def corr(self, targets):
B, S, N, C = targets.shape
assert C == self.C
assert S == self.S
fmap1 = targets
self.corrs_pyramid = []
for fmaps in self.fmaps_pyramid:
_, _, _, H, W = fmaps.shape
fmap2s = fmaps.view(B, S, C, H * W)
corrs = torch.matmul(fmap1, fmap2s)
corrs = corrs.view(B, S, N, H, W)
corrs = corrs / torch.sqrt(torch.tensor(C).float())
self.corrs_pyramid.append(corrs)
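
# Illustrative sketch of the CorrBlock API: build a feature pyramid, correlate
# per-track query features against it with corr(), then sample() local
# correlation patches around the current track coordinates.
def _corr_block_demo():
    B, S, C, H, W, N = 1, 8, 128, 48, 64, 5
    corr_fn = CorrBlock(torch.randn(B, S, C, H, W), num_levels=4, radius=3)
    corr_fn.corr(torch.randn(B, S, N, C))  # one query feature per track and frame
    coords = torch.rand(B, S, N, 2) * torch.tensor([W, H]).float()
    patches = corr_fn.sample(coords)  # (B, S, N, num_levels * (2*r+1)**2)
    assert patches.shape == (B, S, N, 4 * 7 * 7)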
class UpdateFormer(nn.Module):
"""
Transformer model that updates track estimates.
"""
def __init__(
self,
space_depth=12,
time_depth=12,
input_dim=320,
hidden_size=384,
num_heads=8,
output_dim=130,
mlp_ratio=4.0,
add_space_attn=True,
):
super().__init__()
self.out_channels = 2
self.num_heads = num_heads
self.hidden_size = hidden_size
self.add_space_attn = add_space_attn
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
self.time_blocks = nn.ModuleList(
[
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
for _ in range(time_depth)
]
)
if add_space_attn:
self.space_blocks = nn.ModuleList(
[
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
for _ in range(space_depth)
]
)
assert len(self.time_blocks) >= len(self.space_blocks)
self.initialize_weights()
def initialize_weights(self):
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def forward(self, input_tensor):
x = self.input_transform(input_tensor)
j = 0
for i in range(len(self.time_blocks)):
B, N, T, _ = x.shape
x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N)
x_time = self.time_blocks[i](x_time)
x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N)
if self.add_space_attn and (
i % (len(self.time_blocks) // len(self.space_blocks)) == 0
):
x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N)
x_space = self.space_blocks[j](x_space)
x = rearrange(x_space, "(b t) n c -> b n t c ", b=B, t=T, n=N)
j += 1
flow = self.flow_head(x)
return flow
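
# Illustrative sketch: UpdateFormer consumes one token per (track, frame) pair
# with `input_dim` channels and predicts `output_dim` values per token (the
# coordinate update plus the feature update). The depths and hidden size here
# are deliberately small for the example.
def _update_former_demo():
    former = UpdateFormer(
        space_depth=2,
        time_depth=2,
        input_dim=320,
        hidden_size=128,
        num_heads=4,
        output_dim=130,
    )
    x = torch.randn(1, 5, 8, 320)  # (B, N tracks, T frames, input_dim)
    assert former(x).shape == (1, 5, 8, 130)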

View File

@@ -0,0 +1,351 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from einops import rearrange
from cotracker.models.core.cotracker.blocks import (
BasicEncoder,
CorrBlock,
UpdateFormer,
)
from cotracker.models.core.model_utils import meshgrid2d, bilinear_sample2d, smart_cat
from cotracker.models.core.embeddings import (
get_2d_embedding,
get_1d_sincos_pos_embed_from_grid,
get_2d_sincos_pos_embed,
)
torch.manual_seed(0)
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
if grid_size == 1:
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
None, None
].cuda()
grid_y, grid_x = meshgrid2d(
1, grid_size, grid_size, stack=False, norm=False, device="cuda"
)
step = interp_shape[1] // 64
if grid_center[0] != 0 or grid_center[1] != 0:
grid_y = grid_y - grid_size / 2.0
grid_x = grid_x - grid_size / 2.0
grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
interp_shape[0] - step * 2
)
grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
interp_shape[1] - step * 2
)
grid_y = grid_y + grid_center[0]
grid_x = grid_x + grid_center[1]
xy = torch.stack([grid_x, grid_y], dim=-1).cuda()
return xy
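
# Illustrative sketch (a CUDA device is assumed, since the helper allocates its
# grid on the GPU): a 6x6 grid of (x, y) points spread over a 384x512 frame.
def _grid_points_demo():
    xy = get_points_on_a_grid(6, (384, 512))
    assert xy.shape == (1, 36, 2)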
def sample_pos_embed(grid_size, embed_dim, coords):
pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size)
pos_embed = (
torch.from_numpy(pos_embed)
.reshape(grid_size[0], grid_size[1], embed_dim)
.float()
.unsqueeze(0)
.to(coords.device)
)
sampled_pos_embed = bilinear_sample2d(
pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1]
)
return sampled_pos_embed
class CoTracker(nn.Module):
def __init__(
self,
S=8,
stride=8,
add_space_attn=True,
num_heads=8,
hidden_size=384,
space_depth=12,
time_depth=12,
):
super(CoTracker, self).__init__()
self.S = S
self.stride = stride
self.hidden_dim = 256
self.latent_dim = latent_dim = 128
self.corr_levels = 4
self.corr_radius = 3
self.add_space_attn = add_space_attn
self.fnet = BasicEncoder(
output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride
)
self.updateformer = UpdateFormer(
space_depth=space_depth,
time_depth=time_depth,
input_dim=456,
hidden_size=hidden_size,
num_heads=num_heads,
output_dim=latent_dim + 2,
mlp_ratio=4.0,
add_space_attn=add_space_attn,
)
self.norm = nn.GroupNorm(1, self.latent_dim)
self.ffeat_updater = nn.Sequential(
nn.Linear(self.latent_dim, self.latent_dim),
nn.GELU(),
)
self.vis_predictor = nn.Sequential(
nn.Linear(self.latent_dim, 1),
)
def forward_iteration(
self,
fmaps,
coords_init,
feat_init=None,
vis_init=None,
track_mask=None,
iters=4,
):
B, S_init, N, D = coords_init.shape
assert D == 2
assert B == 1
B, S, __, H8, W8 = fmaps.shape
device = fmaps.device
if S_init < S:
coords = torch.cat(
[coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
)
vis_init = torch.cat(
[vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
)
else:
coords = coords_init.clone()
fcorr_fn = CorrBlock(
fmaps, num_levels=self.corr_levels, radius=self.corr_radius
)
ffeats = feat_init.clone()
times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)
pos_embed = sample_pos_embed(
grid_size=(H8, W8),
embed_dim=456,
coords=coords,
)
pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1)
times_embed = (
torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]
.repeat(B, 1, 1)
.float()
.to(device)
)
coord_predictions = []
for __ in range(iters):
coords = coords.detach()
fcorr_fn.corr(ffeats)
fcorrs = fcorr_fn.sample(coords) # B, S, N, LRR
LRR = fcorrs.shape[3]
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)
flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
flows_cat = get_2d_embedding(flows_, 64, cat_coords=True)
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
if track_mask.shape[1] < vis_init.shape[1]:
track_mask = torch.cat(
[
track_mask,
torch.zeros_like(track_mask[:, 0]).repeat(
1, vis_init.shape[1] - track_mask.shape[1], 1, 1
),
],
dim=1,
)
concat = (
torch.cat([track_mask, vis_init], dim=2)
.permute(0, 2, 1, 3)
.reshape(B * N, S, 2)
)
transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)
x = transformer_input + pos_embed + times_embed
x = rearrange(x, "(b n) t d -> b n t d", b=B)
delta = self.updateformer(x)
delta = rearrange(delta, " b n t d -> (b n) t d")
delta_coords_ = delta[:, :, :2]
delta_feats_ = delta[:, :, 2:]
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
ffeats_ = self.ffeat_updater(self.norm(delta_feats_)) + ffeats_
ffeats = ffeats_.reshape(B, N, S, self.latent_dim).permute(
0, 2, 1, 3
) # B,S,N,C
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
coord_predictions.append(coords * self.stride)
vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape(
B, S, N
)
return coord_predictions, vis_e, feat_init
def forward(self, rgbs, queries, iters=4, feat_init=None, is_train=False):
B, T, C, H, W = rgbs.shape
B, N, __ = queries.shape
device = rgbs.device
assert B == 1
# INIT for the first sequence
        # We sort the points by the first frame in which they are visible, so they can be appended to the tensor of tracked points consecutively
first_positive_inds = queries[:, :, 0].long()
__, sort_inds = torch.sort(first_positive_inds[0], dim=0, descending=False)
inv_sort_inds = torch.argsort(sort_inds, dim=0)
first_positive_sorted_inds = first_positive_inds[0][sort_inds]
assert torch.allclose(
first_positive_inds[0], first_positive_inds[0][sort_inds][inv_sort_inds]
)
coords_init = queries[:, :, 1:].reshape(B, 1, N, 2).repeat(
1, self.S, 1, 1
) / float(self.stride)
rgbs = 2 * (rgbs / 255.0) - 1.0
traj_e = torch.zeros((B, T, N, 2), device=device)
vis_e = torch.zeros((B, T, N), device=device)
ind_array = torch.arange(T, device=device)
ind_array = ind_array[None, :, None].repeat(B, 1, N)
track_mask = (ind_array >= first_positive_inds[:, None, :]).unsqueeze(-1)
        # these are logits, so we initialize visibility with something that would give a value close to 1 after sigmoid
vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
ind = 0
track_mask_ = track_mask[:, :, sort_inds].clone()
coords_init_ = coords_init[:, :, sort_inds].clone()
vis_init_ = vis_init[:, :, sort_inds].clone()
prev_wind_idx = 0
fmaps_ = None
vis_predictions = []
coord_predictions = []
wind_inds = []
while ind < T - self.S // 2:
rgbs_seq = rgbs[:, ind : ind + self.S]
S = S_local = rgbs_seq.shape[1]
if S < self.S:
rgbs_seq = torch.cat(
[rgbs_seq, rgbs_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
dim=1,
)
S = rgbs_seq.shape[1]
rgbs_ = rgbs_seq.reshape(B * S, C, H, W)
if fmaps_ is None:
fmaps_ = self.fnet(rgbs_)
else:
fmaps_ = torch.cat(
[fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0
)
fmaps = fmaps_.reshape(
B, S, self.latent_dim, H // self.stride, W // self.stride
)
curr_wind_points = torch.nonzero(first_positive_sorted_inds < ind + self.S)
if curr_wind_points.shape[0] == 0:
ind = ind + self.S // 2
continue
wind_idx = curr_wind_points[-1] + 1
if wind_idx - prev_wind_idx > 0:
fmaps_sample = fmaps[
:, first_positive_sorted_inds[prev_wind_idx:wind_idx] - ind
]
feat_init_ = bilinear_sample2d(
fmaps_sample,
coords_init_[:, 0, prev_wind_idx:wind_idx, 0],
coords_init_[:, 0, prev_wind_idx:wind_idx, 1],
).permute(0, 2, 1)
feat_init_ = feat_init_.unsqueeze(1).repeat(1, self.S, 1, 1)
feat_init = smart_cat(feat_init, feat_init_, dim=2)
if prev_wind_idx > 0:
new_coords = coords[-1][:, self.S // 2 :] / float(self.stride)
coords_init_[:, : self.S // 2, :prev_wind_idx] = new_coords
coords_init_[:, self.S // 2 :, :prev_wind_idx] = new_coords[
:, -1
].repeat(1, self.S // 2, 1, 1)
new_vis = vis[:, self.S // 2 :].unsqueeze(-1)
vis_init_[:, : self.S // 2, :prev_wind_idx] = new_vis
vis_init_[:, self.S // 2 :, :prev_wind_idx] = new_vis[:, -1].repeat(
1, self.S // 2, 1, 1
)
coords, vis, __ = self.forward_iteration(
fmaps=fmaps,
coords_init=coords_init_[:, :, :wind_idx],
feat_init=feat_init[:, :, :wind_idx],
vis_init=vis_init_[:, :, :wind_idx],
track_mask=track_mask_[:, ind : ind + self.S, :wind_idx],
iters=iters,
)
if is_train:
vis_predictions.append(torch.sigmoid(vis[:, :S_local]))
coord_predictions.append([coord[:, :S_local] for coord in coords])
wind_inds.append(wind_idx)
traj_e[:, ind : ind + self.S, :wind_idx] = coords[-1][:, :S_local]
vis_e[:, ind : ind + self.S, :wind_idx] = vis[:, :S_local]
track_mask_[:, : ind + self.S, :wind_idx] = 0.0
ind = ind + self.S // 2
prev_wind_idx = wind_idx
traj_e = traj_e[:, :, inv_sort_inds]
vis_e = vis_e[:, :, inv_sort_inds]
vis_e = torch.sigmoid(vis_e)
train_data = (
(vis_predictions, coord_predictions, wind_inds, sort_inds)
if is_train
else None
)
return traj_e, feat_init, vis_e, train_data
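
# Illustrative end-to-end sketch with a deliberately small configuration: track
# three query points, given as (t, x, y) triplets in pixel coordinates, through
# a 16-frame clip. Shapes follow the forward() contract above.
def _cotracker_demo():
    model = CoTracker(
        S=8, stride=8, space_depth=2, time_depth=2, hidden_size=128, num_heads=4
    )
    video = torch.randint(0, 255, (1, 16, 3, 64, 64)).float()  # (B, T, C, H, W)
    queries = torch.tensor(
        [[[0.0, 10.0, 12.0], [0.0, 30.0, 20.0], [4.0, 40.0, 40.0]]]
    )  # (B, N, 3)
    traj, _, vis, _ = model(rgbs=video, queries=queries, iters=2)
    assert traj.shape == (1, 16, 3, 2) and vis.shape == (1, 16, 3)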

View File

@@ -0,0 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from cotracker.models.core.model_utils import reduce_masked_mean
EPS = 1e-6
def balanced_ce_loss(pred, gt, valid=None):
total_balanced_loss = 0.0
for j in range(len(gt)):
B, S, N = gt[j].shape
# pred and gt are the same shape
for (a, b) in zip(pred[j].size(), gt[j].size()):
assert a == b # some shape mismatch!
        # valid is required here: it masks out entries that should not contribute to the loss
for (a, b) in zip(pred[j].size(), valid[j].size()):
assert a == b # some shape mismatch!
pos = (gt[j] > 0.95).float()
neg = (gt[j] < 0.05).float()
label = pos * 2.0 - 1.0
a = -label * pred[j]
b = F.relu(a)
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
pos_loss = reduce_masked_mean(loss, pos * valid[j])
neg_loss = reduce_masked_mean(loss, neg * valid[j])
balanced_loss = pos_loss + neg_loss
total_balanced_loss += balanced_loss / float(N)
return total_balanced_loss
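
# Illustrative sketch: `pred` holds visibility logits, `gt` holds 0/1 visibility
# labels, and `valid` masks out padded entries; each list element is one window.
def _balanced_ce_loss_demo():
    B, S, N = 1, 8, 4
    pred = [torch.randn(B, S, N)]
    gt = [torch.randint(0, 2, (B, S, N)).float()]
    valid = [torch.ones(B, S, N)]
    loss = balanced_ce_loss(pred, gt, valid)
    assert loss.ndim == 0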
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
"""Loss function defined over sequence of flow predictions"""
total_flow_loss = 0.0
for j in range(len(flow_gt)):
B, S, N, D = flow_gt[j].shape
assert D == 2
B, S1, N = vis[j].shape
B, S2, N = valids[j].shape
assert S == S1
assert S == S2
n_predictions = len(flow_preds[j])
flow_loss = 0.0
for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1)
flow_pred = flow_preds[j][i]
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
i_loss = torch.mean(i_loss, dim=3) # B, S, N
flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
flow_loss = flow_loss / n_predictions
total_flow_loss += flow_loss / float(N)
return total_flow_loss
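
# Illustrative sketch: the outer lists index training windows and the inner list
# holds the per-iteration track predictions for that window, so later iterations
# receive exponentially larger weights (gamma ** (n_predictions - i - 1)).
def _sequence_loss_demo():
    B, S, N = 1, 8, 4
    flow_gt = [torch.randn(B, S, N, 2)]
    flow_preds = [[flow_gt[0] + 0.1 * torch.randn(B, S, N, 2) for _ in range(3)]]
    vis = [torch.ones(B, S, N)]
    valids = [torch.ones(B, S, N)]
    loss = sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8)
    assert loss.ndim == 0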

View File

@@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import numpy as np
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return:
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
if isinstance(grid_size, tuple):
grid_size_h, grid_size_w = grid_size
else:
grid_size_h = grid_size_w = grid_size
grid_h = np.arange(grid_size_h, dtype=np.float32)
grid_w = np.arange(grid_size_w, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate(
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
assert embed_dim % 2 == 0
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000 ** omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
def get_2d_embedding(xy, C, cat_coords=True):
B, N, D = xy.shape
assert D == 2
x = xy[:, :, 0:1]
y = xy[:, :, 1:2]
div_term = (
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
    pe = torch.cat([pe_x, pe_y], dim=2)  # B, N, 2*C
if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # B, N, 2*C + 2
return pe
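
# Illustrative sketch: get_2d_embedding maps (x, y) coordinates to a 2*C-dim
# sinusoidal code, optionally concatenating the raw coordinates in front.
def _embedding_demo():
    xy = torch.rand(1, 5, 2) * 100.0
    pe = get_2d_embedding(xy, C=64, cat_coords=True)
    assert pe.shape == (1, 5, 2 * 64 + 2)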
def get_3d_embedding(xyz, C, cat_coords=True):
B, N, D = xyz.shape
assert D == 3
x = xyz[:, :, 0:1]
y = xyz[:, :, 1:2]
z = xyz[:, :, 2:3]
div_term = (
torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe_z[:, :, 0::2] = torch.sin(z * div_term)
pe_z[:, :, 1::2] = torch.cos(z * div_term)
pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
if cat_coords:
pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
return pe
def get_4d_embedding(xyzw, C, cat_coords=True):
B, N, D = xyzw.shape
assert D == 4
x = xyzw[:, :, 0:1]
y = xyzw[:, :, 1:2]
z = xyzw[:, :, 2:3]
w = xyzw[:, :, 3:4]
div_term = (
torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe_z[:, :, 0::2] = torch.sin(z * div_term)
pe_z[:, :, 1::2] = torch.cos(z * div_term)
pe_w[:, :, 0::2] = torch.sin(w * div_term)
pe_w[:, :, 1::2] = torch.cos(w * div_term)
    pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2)  # B, N, 4*C
if cat_coords:
        pe = torch.cat([pe, xyzw], dim=2)  # B, N, 4*C + 4
return pe

View File

@@ -0,0 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
EPS = 1e-6
def smart_cat(tensor1, tensor2, dim):
if tensor1 is None:
return tensor2
return torch.cat([tensor1, tensor2], dim=dim)
def normalize_single(d):
    # d is a torch tensor of arbitrary shape
dmin = torch.min(d)
dmax = torch.max(d)
d = (d - dmin) / (EPS + (dmax - dmin))
return d
def normalize(d):
    # d is B x whatever; normalize each batch element independently
out = torch.zeros(d.size())
if d.is_cuda:
out = out.cuda()
B = list(d.size())[0]
for b in list(range(B)):
out[b] = normalize_single(d[b])
return out
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
# returns a meshgrid sized B x Y x X
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
grid_y = torch.reshape(grid_y, [1, Y, 1])
grid_y = grid_y.repeat(B, 1, X)
grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
grid_x = torch.reshape(grid_x, [1, 1, X])
grid_x = grid_x.repeat(B, Y, 1)
if stack:
# note we stack in xy order
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
grid = torch.stack([grid_x, grid_y], dim=-1)
return grid
else:
return grid_y, grid_x
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
    # x and mask must have the same shape (broadcasting is deliberately not relied on)
    # returns the mean of x over the entries where mask is nonzero
    # dim can be a single dim or a list of dims
for (a, b) in zip(x.size(), mask.size()):
assert a == b # some shape mismatch!
prod = x * mask
if dim is None:
numer = torch.sum(prod)
denom = EPS + torch.sum(mask)
else:
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
mean = numer / denom
return mean
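
# Illustrative sketch: reduce_masked_mean averages x only where mask is 1.
def _masked_mean_demo():
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
    assert torch.isclose(reduce_masked_mean(x, mask), torch.tensor(1.5))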
def bilinear_sample2d(im, x, y, return_inbounds=False):
# x and y are each B, N
# output is B, C, N
if len(im.shape) == 5:
B, N, C, H, W = list(im.shape)
else:
B, C, H, W = list(im.shape)
N = list(x.shape)[1]
x = x.float()
y = y.float()
H_f = torch.tensor(H, dtype=torch.float32)
W_f = torch.tensor(W, dtype=torch.float32)
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
max_y = (H_f - 1).int()
max_x = (W_f - 1).int()
x0 = torch.floor(x).int()
x1 = x0 + 1
y0 = torch.floor(y).int()
y1 = y0 + 1
x0_clip = torch.clamp(x0, 0, max_x)
x1_clip = torch.clamp(x1, 0, max_x)
y0_clip = torch.clamp(y0, 0, max_y)
y1_clip = torch.clamp(y1, 0, max_y)
dim2 = W
dim1 = W * H
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
base = torch.reshape(base, [B, 1]).repeat([1, N])
base_y0 = base + y0_clip * dim2
base_y1 = base + y1_clip * dim2
idx_y0_x0 = base_y0 + x0_clip
idx_y0_x1 = base_y0 + x1_clip
idx_y1_x0 = base_y1 + x0_clip
idx_y1_x1 = base_y1 + x1_clip
# use the indices to lookup pixels in the flat image
# im is B x C x H x W
# move C out to last dim
if len(im.shape) == 5:
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
0, 2, 1
)
else:
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
i_y0_x0 = im_flat[idx_y0_x0.long()]
i_y0_x1 = im_flat[idx_y0_x1.long()]
i_y1_x0 = im_flat[idx_y1_x0.long()]
i_y1_x1 = im_flat[idx_y1_x1.long()]
# Finally calculate interpolated values.
x0_f = x0.float()
x1_f = x1.float()
y0_f = y0.float()
y1_f = y1.float()
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
output = (
w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
)
# output is B*N x C
output = output.view(B, -1, C)
output = output.permute(0, 2, 1)
# output is B x C x N
if return_inbounds:
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
inbounds = (x_valid & y_valid).float()
inbounds = inbounds.reshape(
B, N
        )  # NOTE: this reshape appears to misbehave for B > 1
return output, inbounds
return output # B, C, N
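
# Illustrative sketch: bilinear_sample2d reads one feature vector per query point
# at floating-point (x, y) locations of a feature map.
def _bilinear_sample2d_demo():
    im = torch.randn(1, 16, 24, 32)  # (B, C, H, W)
    x = torch.tensor([[3.5, 10.0]])  # (B, N)
    y = torch.tensor([[7.25, 2.0]])
    feats = bilinear_sample2d(im, x, y)
    assert feats.shape == (1, 16, 2)  # (B, C, N)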

View File

@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from typing import Tuple
from cotracker.models.core.cotracker.cotracker import CoTracker, get_points_on_a_grid
class EvaluationPredictor(torch.nn.Module):
def __init__(
self,
cotracker_model: CoTracker,
interp_shape: Tuple[int, int] = (384, 512),
grid_size: int = 6,
local_grid_size: int = 6,
single_point: bool = True,
n_iters: int = 6,
) -> None:
super(EvaluationPredictor, self).__init__()
self.grid_size = grid_size
self.local_grid_size = local_grid_size
self.single_point = single_point
self.interp_shape = interp_shape
self.n_iters = n_iters
self.model = cotracker_model
self.model.to("cuda")
self.model.eval()
def forward(self, video, queries):
queries = queries.clone().cuda()
B, T, C, H, W = video.shape
B, N, D = queries.shape
assert D == 3
assert B == 1
rgbs = video.reshape(B * T, C, H, W)
rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
queries[:, :, 1] *= self.interp_shape[1] / W
queries[:, :, 2] *= self.interp_shape[0] / H
if self.single_point:
traj_e = torch.zeros((B, T, N, 2)).cuda()
vis_e = torch.zeros((B, T, N)).cuda()
for pind in range((N)):
query = queries[:, pind : pind + 1]
t = query[0, 0, 0].long()
traj_e_pind, vis_e_pind = self._process_one_point(rgbs, query)
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
else:
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
queries = torch.cat([queries, xy], dim=1) #
traj_e, __, vis_e, __ = self.model(
rgbs=rgbs,
queries=queries,
iters=self.n_iters,
)
traj_e[:, :, :, 0] *= W / float(self.interp_shape[1])
traj_e[:, :, :, 1] *= H / float(self.interp_shape[0])
return traj_e, vis_e
def _process_one_point(self, rgbs, query):
t = query[0, 0, 0].long()
device = rgbs.device
if self.local_grid_size > 0:
xy_target = get_points_on_a_grid(
self.local_grid_size,
(50, 50),
[query[0, 0, 2], query[0, 0, 1]],
)
xy_target = torch.cat(
[torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
) #
query = torch.cat([query, xy_target], dim=1).to(device) #
if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
query = torch.cat([query, xy], dim=1).to(device) #
# crop the video to start from the queried frame
query[0, 0, 0] = 0
traj_e_pind, __, vis_e_pind, __ = self.model(
rgbs=rgbs[:, t:], queries=query, iters=self.n_iters
)
return traj_e_pind, vis_e_pind
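
# Illustrative usage sketch (a CUDA device is assumed, since the predictor moves
# the model and inputs to the GPU); the grid sizes are set to 0 here so only the
# three explicit queries are tracked.
def _evaluation_predictor_demo():
    predictor = EvaluationPredictor(
        CoTracker(S=8, stride=8), grid_size=0, local_grid_size=0, single_point=False
    )
    video = torch.randint(0, 255, (1, 16, 3, 128, 128)).float().cuda()
    queries = torch.tensor(
        [[[0.0, 20.0, 30.0], [0.0, 64.0, 64.0], [2.0, 90.0, 40.0]]]
    ).cuda()
    tracks, visibilities = predictor(video, queries)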

178
cotracker/predictor.py Normal file
View File

@@ -0,0 +1,178 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn.functional as F
from tqdm import tqdm
from cotracker.models.core.cotracker.cotracker import get_points_on_a_grid
from cotracker.models.core.model_utils import smart_cat
from cotracker.models.build_cotracker import (
build_cotracker,
)
class CoTrackerPredictor(torch.nn.Module):
def __init__(
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
):
super().__init__()
self.interp_shape = (384, 512)
self.support_grid_size = 6
model = build_cotracker(checkpoint)
self.model = model
self.model.to("cuda")
self.model.eval()
@torch.no_grad()
def forward(
self,
video, # (1, T, 3, H, W)
# input prompt types:
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
# *backward_tracking=True* will compute tracks in both directions.
# - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
        # - grid_size. Grid of N*N points from the first frame. If segm_mask is provided, the grid is computed only inside the mask.
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
queries: torch.Tensor = None,
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
grid_size: int = 0,
grid_query_frame: int = 0, # only for dense and regular grid tracks
backward_tracking: bool = False,
):
if queries is None and grid_size == 0:
tracks, visibilities = self._compute_dense_tracks(
video,
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
)
else:
tracks, visibilities = self._compute_sparse_tracks(
video,
queries,
segm_mask,
grid_size,
add_support_grid=(grid_size == 0 or segm_mask is not None),
grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking,
)
return tracks, visibilities
def _compute_dense_tracks(
self, video, grid_query_frame, grid_size=50, backward_tracking=False
):
*_, H, W = video.shape
grid_step = W // grid_size
grid_width = W // grid_step
grid_height = H // grid_step
tracks = visibilities = None
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
grid_pts[0, :, 0] = grid_query_frame
for offset in tqdm(range(grid_step * grid_step)):
ox = offset % grid_step
oy = offset // grid_step
grid_pts[0, :, 1] = (
torch.arange(grid_width).repeat(grid_height) * grid_step + ox
)
grid_pts[0, :, 2] = (
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
)
tracks_step, visibilities_step = self._compute_sparse_tracks(
video=video,
queries=grid_pts,
backward_tracking=backward_tracking,
)
tracks = smart_cat(tracks, tracks_step, dim=2)
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
return tracks, visibilities
def _compute_sparse_tracks(
self,
video,
queries,
segm_mask=None,
grid_size=0,
add_support_grid=False,
grid_query_frame=0,
backward_tracking=False,
):
B, T, C, H, W = video.shape
assert B == 1
video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
video = video.reshape(
B, T, 3, self.interp_shape[0], self.interp_shape[1]
).cuda()
if queries is not None:
queries = queries.clone()
B, N, D = queries.shape
assert D == 3
queries[:, :, 1] *= self.interp_shape[1] / W
queries[:, :, 2] *= self.interp_shape[0] / H
elif grid_size > 0:
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
if segm_mask is not None:
segm_mask = F.interpolate(
segm_mask, tuple(self.interp_shape), mode="nearest"
)
point_mask = segm_mask[0, 0][
(grid_pts[0, :, 1]).round().long().cpu(),
(grid_pts[0, :, 0]).round().long().cpu(),
].bool()
grid_pts = grid_pts[:, point_mask]
queries = torch.cat(
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
)
if add_support_grid:
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
grid_pts = torch.cat(
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
)
queries = torch.cat([queries, grid_pts], dim=1)
tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6)
if backward_tracking:
tracks, visibilities = self._compute_backward_tracks(
video, queries, tracks, visibilities
)
if add_support_grid:
queries[:, -self.support_grid_size ** 2 :, 0] = T - 1
if add_support_grid:
tracks = tracks[:, :, : -self.support_grid_size ** 2]
visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
thr = 0.9
visibilities = visibilities > thr
tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
return tracks, visibilities
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
inv_video = video.flip(1).clone()
inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
inv_tracks, __, inv_visibilities, __ = self.model(
rgbs=inv_video, queries=inv_queries, iters=6
)
inv_tracks = inv_tracks.flip(1)
inv_visibilities = inv_visibilities.flip(1)
mask = tracks == 0
tracks[mask] = inv_tracks[mask]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
return tracks, visibilities
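
# Illustrative usage sketch (assumes the default checkpoint has been downloaded
# to cotracker/checkpoints/ and that a CUDA device is available): track a
# regular 10x10 grid of points starting from the first frame.
def _predictor_demo():
    predictor = CoTrackerPredictor(
        checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
    )
    video = torch.randint(0, 255, (1, 24, 3, 256, 256)).float().cuda()
    tracks, visibilities = predictor(video, grid_size=10)
    assert tracks.shape[-1] == 2 and visibilities.dtype == torch.bool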

View File

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

View File

@@ -0,0 +1,291 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import cv2
import torch
import flow_vis
from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
from moviepy.editor import ImageSequenceClip
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
class Visualizer:
def __init__(
self,
save_dir: str = "./results",
grayscale: bool = False,
pad_value: int = 0,
fps: int = 10,
mode: str = "rainbow", # 'cool', 'optical_flow'
linewidth: int = 2,
show_first_frame: int = 10,
tracks_leave_trace: int = 0, # -1 for infinite
):
self.mode = mode
self.save_dir = save_dir
if mode == "rainbow":
self.color_map = cm.get_cmap("gist_rainbow")
elif mode == "cool":
self.color_map = cm.get_cmap(mode)
self.show_first_frame = show_first_frame
self.grayscale = grayscale
self.tracks_leave_trace = tracks_leave_trace
self.pad_value = pad_value
self.linewidth = linewidth
self.fps = fps
def visualize(
self,
video: torch.Tensor, # (B,T,C,H,W)
tracks: torch.Tensor, # (B,T,N,2)
gt_tracks: torch.Tensor = None, # (B,T,N,2)
segm_mask: torch.Tensor = None, # (B,1,H,W)
filename: str = "video",
writer: SummaryWriter = None,
step: int = 0,
query_frame: int = 0,
save_video: bool = True,
compensate_for_camera_motion: bool = False,
):
if compensate_for_camera_motion:
assert segm_mask is not None
if segm_mask is not None:
coords = tracks[0, query_frame].round().long()
segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
video = F.pad(
video,
(self.pad_value, self.pad_value, self.pad_value, self.pad_value),
"constant",
255,
)
tracks = tracks + self.pad_value
if self.grayscale:
transform = transforms.Grayscale()
video = transform(video)
video = video.repeat(1, 1, 3, 1, 1)
res_video = self.draw_tracks_on_video(
video=video,
tracks=tracks,
segm_mask=segm_mask,
gt_tracks=gt_tracks,
query_frame=query_frame,
compensate_for_camera_motion=compensate_for_camera_motion,
)
if save_video:
self.save_video(res_video, filename=filename, writer=writer, step=step)
return res_video
def save_video(self, video, filename, writer=None, step=0):
if writer is not None:
writer.add_video(
f"{filename}_pred_track",
video.to(torch.uint8),
global_step=step,
fps=self.fps,
)
else:
os.makedirs(self.save_dir, exist_ok=True)
wide_list = list(video.unbind(1))
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)
# Write the video file
save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4")
clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
print(f"Video saved to {save_path}")
def draw_tracks_on_video(
self,
video: torch.Tensor,
tracks: torch.Tensor,
segm_mask: torch.Tensor = None,
gt_tracks=None,
query_frame: int = 0,
compensate_for_camera_motion=False,
):
B, T, C, H, W = video.shape
_, _, N, D = tracks.shape
assert D == 2
assert C == 3
video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
if gt_tracks is not None:
gt_tracks = gt_tracks[0].detach().cpu().numpy()
res_video = []
# process input video
for rgb in video:
res_video.append(rgb.copy())
vector_colors = np.zeros((T, N, 3))
if self.mode == "optical_flow":
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
elif segm_mask is None:
if self.mode == "rainbow":
y_min, y_max = (
tracks[query_frame, :, 1].min(),
tracks[query_frame, :, 1].max(),
)
norm = plt.Normalize(y_min, y_max)
for n in range(N):
color = self.color_map(norm(tracks[query_frame, n, 1]))
color = np.array(color[:3])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with time
for t in range(T):
color = np.array(self.color_map(t / T)[:3])[None] * 255
vector_colors[t] = np.repeat(color, N, axis=0)
else:
if self.mode == "rainbow":
vector_colors[:, segm_mask <= 0, :] = 255
y_min, y_max = (
tracks[0, segm_mask > 0, 1].min(),
tracks[0, segm_mask > 0, 1].max(),
)
norm = plt.Normalize(y_min, y_max)
for n in range(N):
if segm_mask[n] > 0:
color = self.color_map(norm(tracks[0, n, 1]))
color = np.array(color[:3])[None] * 255
vector_colors[:, n] = np.repeat(color, T, axis=0)
else:
# color changes with segm class
segm_mask = segm_mask.cpu()
color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
vector_colors = np.repeat(color[None], T, axis=0)
# draw tracks
if self.tracks_leave_trace != 0:
for t in range(1, T):
first_ind = (
max(0, t - self.tracks_leave_trace)
if self.tracks_leave_trace >= 0
else 0
)
curr_tracks = tracks[first_ind : t + 1]
curr_colors = vector_colors[first_ind : t + 1]
if compensate_for_camera_motion:
diff = (
tracks[first_ind : t + 1, segm_mask <= 0]
- tracks[t : t + 1, segm_mask <= 0]
).mean(1)[:, None]
curr_tracks = curr_tracks - diff
curr_tracks = curr_tracks[:, segm_mask > 0]
curr_colors = curr_colors[:, segm_mask > 0]
res_video[t] = self._draw_pred_tracks(
res_video[t],
curr_tracks,
curr_colors,
)
if gt_tracks is not None:
res_video[t] = self._draw_gt_tracks(
res_video[t], gt_tracks[first_ind : t + 1]
)
# draw points
for t in range(T):
for i in range(N):
coord = (tracks[t, i, 0], tracks[t, i, 1])
if coord[0] != 0 and coord[1] != 0:
if not compensate_for_camera_motion or (
compensate_for_camera_motion and segm_mask[i] > 0
):
cv2.circle(
res_video[t],
coord,
int(self.linewidth * 2),
vector_colors[t, i].tolist(),
-1,
)
# construct the final rgb sequence
if self.show_first_frame > 0:
res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
def _draw_pred_tracks(
self,
rgb: np.ndarray, # H x W x 3
        tracks: np.ndarray,  # T x N x 2
vector_colors: np.ndarray,
alpha: float = 0.5,
):
T, N, _ = tracks.shape
for s in range(T - 1):
vector_color = vector_colors[s]
original = rgb.copy()
alpha = (s / T) ** 2
for i in range(N):
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
if coord_y[0] != 0 and coord_y[1] != 0:
cv2.line(
rgb,
coord_y,
coord_x,
vector_color[i].tolist(),
self.linewidth,
cv2.LINE_AA,
)
if self.tracks_leave_trace > 0:
rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
return rgb
def _draw_gt_tracks(
self,
        rgb: np.ndarray,  # H x W x 3
        gt_tracks: np.ndarray,  # T x N x 2
):
T, N, _ = gt_tracks.shape
color = np.array((211.0, 0.0, 0.0))
for t in range(T):
for i in range(N):
                gt_track = gt_tracks[t][i]
                # draw a red cross
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    cv2.line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
cv2.line(
rgb,
coord_y,
coord_x,
color,
self.linewidth,
cv2.LINE_AA,
)
return rgb
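
# Illustrative usage sketch: overlay predicted tracks on a short clip and write
# an .mp4 under ./results (moviepy with a working ffmpeg backend is assumed).
def _visualizer_demo():
    vis = Visualizer(save_dir="./results", fps=10, mode="rainbow")
    video = torch.randint(0, 255, (1, 12, 3, 128, 128)).float()  # (B, T, C, H, W)
    tracks = torch.rand(1, 12, 6, 2) * 128  # (B, T, N, 2) pixel coordinates
    vis.visualize(video, tracks, filename="demo")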