Initial commit
This commit is contained in:
5
cotracker/__init__.py
Normal file
5
cotracker/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
5
cotracker/datasets/__init__.py
Normal file
5
cotracker/datasets/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
390
cotracker/datasets/badja_dataset.py
Normal file
390
cotracker/datasets/badja_dataset.py
Normal file
@@ -0,0 +1,390 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
import json
|
||||
import imageio
|
||||
import cv2
|
||||
|
||||
from enum import Enum
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData, resize_sample
|
||||
|
||||
IGNORE_ANIMALS = [
|
||||
# "bear.json",
|
||||
# "camel.json",
|
||||
"cat_jump.json"
|
||||
# "cows.json",
|
||||
# "dog.json",
|
||||
# "dog-agility.json",
|
||||
# "horsejump-high.json",
|
||||
# "horsejump-low.json",
|
||||
# "impala0.json",
|
||||
# "rs_dog.json"
|
||||
"tiger.json"
|
||||
]
|
||||
|
||||
|
||||
class SMALJointCatalog(Enum):
|
||||
# body_0 = 0
|
||||
# body_1 = 1
|
||||
# body_2 = 2
|
||||
# body_3 = 3
|
||||
# body_4 = 4
|
||||
# body_5 = 5
|
||||
# body_6 = 6
|
||||
# upper_right_0 = 7
|
||||
upper_right_1 = 8
|
||||
upper_right_2 = 9
|
||||
upper_right_3 = 10
|
||||
# upper_left_0 = 11
|
||||
upper_left_1 = 12
|
||||
upper_left_2 = 13
|
||||
upper_left_3 = 14
|
||||
neck_lower = 15
|
||||
# neck_upper = 16
|
||||
# lower_right_0 = 17
|
||||
lower_right_1 = 18
|
||||
lower_right_2 = 19
|
||||
lower_right_3 = 20
|
||||
# lower_left_0 = 21
|
||||
lower_left_1 = 22
|
||||
lower_left_2 = 23
|
||||
lower_left_3 = 24
|
||||
tail_0 = 25
|
||||
# tail_1 = 26
|
||||
# tail_2 = 27
|
||||
tail_3 = 28
|
||||
# tail_4 = 29
|
||||
# tail_5 = 30
|
||||
tail_6 = 31
|
||||
jaw = 32
|
||||
nose = 33 # ADDED JOINT FOR VERTEX 1863
|
||||
# chin = 34 # ADDED JOINT FOR VERTEX 26
|
||||
right_ear = 35 # ADDED JOINT FOR VERTEX 149
|
||||
left_ear = 36 # ADDED JOINT FOR VERTEX 2124
|
||||
|
||||
|
||||
class SMALJointInfo:
|
||||
def __init__(self):
|
||||
# These are the
|
||||
self.annotated_classes = np.array(
|
||||
[
|
||||
8,
|
||||
9,
|
||||
10, # upper_right
|
||||
12,
|
||||
13,
|
||||
14, # upper_left
|
||||
15, # neck
|
||||
18,
|
||||
19,
|
||||
20, # lower_right
|
||||
22,
|
||||
23,
|
||||
24, # lower_left
|
||||
25,
|
||||
28,
|
||||
31, # tail
|
||||
32,
|
||||
33, # head
|
||||
35, # right_ear
|
||||
36,
|
||||
]
|
||||
) # left_ear
|
||||
|
||||
self.annotated_markers = np.array(
|
||||
[
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_TRIANGLE_DOWN,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_STAR,
|
||||
cv2.MARKER_CROSS,
|
||||
cv2.MARKER_CROSS,
|
||||
]
|
||||
)
|
||||
|
||||
self.joint_regions = np.array(
|
||||
[
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
1,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
2,
|
||||
3,
|
||||
3,
|
||||
4,
|
||||
4,
|
||||
4,
|
||||
4,
|
||||
5,
|
||||
5,
|
||||
5,
|
||||
5,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
6,
|
||||
7,
|
||||
7,
|
||||
7,
|
||||
8,
|
||||
9,
|
||||
]
|
||||
)
|
||||
|
||||
self.annotated_joint_region = self.joint_regions[self.annotated_classes]
|
||||
self.region_colors = np.array(
|
||||
[
|
||||
[250, 190, 190], # body, light pink
|
||||
[60, 180, 75], # upper_right, green
|
||||
[230, 25, 75], # upper_left, red
|
||||
[128, 0, 0], # neck, maroon
|
||||
[0, 130, 200], # lower_right, blue
|
||||
[255, 255, 25], # lower_left, yellow
|
||||
[240, 50, 230], # tail, majenta
|
||||
[245, 130, 48], # jaw / nose / chin, orange
|
||||
[29, 98, 115], # right_ear, turquoise
|
||||
[255, 153, 204],
|
||||
]
|
||||
) # left_ear, pink
|
||||
|
||||
self.joint_colors = np.array(self.region_colors)[self.annotated_joint_region]
|
||||
|
||||
|
||||
class BADJAData:
|
||||
def __init__(self, data_root, complete=False):
|
||||
annotations_path = os.path.join(data_root, "joint_annotations")
|
||||
|
||||
self.animal_dict = {}
|
||||
self.animal_count = 0
|
||||
self.smal_joint_info = SMALJointInfo()
|
||||
for __, animal_json in enumerate(sorted(os.listdir(annotations_path))):
|
||||
if animal_json not in IGNORE_ANIMALS:
|
||||
json_path = os.path.join(annotations_path, animal_json)
|
||||
with open(json_path) as json_data:
|
||||
animal_joint_data = json.load(json_data)
|
||||
|
||||
filenames = []
|
||||
segnames = []
|
||||
joints = []
|
||||
visible = []
|
||||
|
||||
first_path = animal_joint_data[0]["segmentation_path"]
|
||||
last_path = animal_joint_data[-1]["segmentation_path"]
|
||||
first_frame = first_path.split("/")[-1]
|
||||
last_frame = last_path.split("/")[-1]
|
||||
|
||||
if not "extra_videos" in first_path:
|
||||
animal = first_path.split("/")[-2]
|
||||
|
||||
first_frame_int = int(first_frame.split(".")[0])
|
||||
last_frame_int = int(last_frame.split(".")[0])
|
||||
|
||||
for fr in range(first_frame_int, last_frame_int + 1):
|
||||
ref_file_name = os.path.join(
|
||||
data_root,
|
||||
"DAVIS/JPEGImages/Full-Resolution/%s/%05d.jpg"
|
||||
% (animal, fr),
|
||||
)
|
||||
ref_seg_name = os.path.join(
|
||||
data_root,
|
||||
"DAVIS/Annotations/Full-Resolution/%s/%05d.png"
|
||||
% (animal, fr),
|
||||
)
|
||||
|
||||
foundit = False
|
||||
for ind, image_annotation in enumerate(animal_joint_data):
|
||||
file_name = os.path.join(
|
||||
data_root, image_annotation["image_path"]
|
||||
)
|
||||
seg_name = os.path.join(
|
||||
data_root, image_annotation["segmentation_path"]
|
||||
)
|
||||
|
||||
if file_name == ref_file_name:
|
||||
foundit = True
|
||||
label_ind = ind
|
||||
|
||||
if foundit:
|
||||
image_annotation = animal_joint_data[label_ind]
|
||||
file_name = os.path.join(
|
||||
data_root, image_annotation["image_path"]
|
||||
)
|
||||
seg_name = os.path.join(
|
||||
data_root, image_annotation["segmentation_path"]
|
||||
)
|
||||
joint = np.array(image_annotation["joints"])
|
||||
vis = np.array(image_annotation["visibility"])
|
||||
else:
|
||||
file_name = ref_file_name
|
||||
seg_name = ref_seg_name
|
||||
joint = None
|
||||
vis = None
|
||||
|
||||
filenames.append(file_name)
|
||||
segnames.append(seg_name)
|
||||
joints.append(joint)
|
||||
visible.append(vis)
|
||||
|
||||
if len(filenames):
|
||||
self.animal_dict[self.animal_count] = (
|
||||
filenames,
|
||||
segnames,
|
||||
joints,
|
||||
visible,
|
||||
)
|
||||
self.animal_count += 1
|
||||
print("Loaded BADJA dataset")
|
||||
|
||||
def get_loader(self):
|
||||
for __ in range(int(1e6)):
|
||||
animal_id = np.random.choice(len(self.animal_dict.keys()))
|
||||
filenames, segnames, joints, visible = self.animal_dict[animal_id]
|
||||
|
||||
image_id = np.random.randint(0, len(filenames))
|
||||
|
||||
seg_file = segnames[image_id]
|
||||
image_file = filenames[image_id]
|
||||
|
||||
joints = joints[image_id].copy()
|
||||
joints = joints[self.smal_joint_info.annotated_classes]
|
||||
visible = visible[image_id][self.smal_joint_info.annotated_classes]
|
||||
|
||||
rgb_img = imageio.imread(image_file) # , mode='RGB')
|
||||
sil_img = imageio.imread(seg_file) # , mode='RGB')
|
||||
|
||||
rgb_h, rgb_w, _ = rgb_img.shape
|
||||
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
|
||||
|
||||
yield rgb_img, sil_img, joints, visible, image_file
|
||||
|
||||
def get_video(self, animal_id):
|
||||
filenames, segnames, joint, visible = self.animal_dict[animal_id]
|
||||
|
||||
rgbs = []
|
||||
segs = []
|
||||
joints = []
|
||||
visibles = []
|
||||
|
||||
for s in range(len(filenames)):
|
||||
image_file = filenames[s]
|
||||
rgb_img = imageio.imread(image_file) # , mode='RGB')
|
||||
rgb_h, rgb_w, _ = rgb_img.shape
|
||||
|
||||
seg_file = segnames[s]
|
||||
sil_img = imageio.imread(seg_file) # , mode='RGB')
|
||||
sil_img = cv2.resize(sil_img, (rgb_w, rgb_h), cv2.INTER_NEAREST)
|
||||
|
||||
jo = joint[s]
|
||||
|
||||
if jo is not None:
|
||||
joi = joint[s].copy()
|
||||
joi = joi[self.smal_joint_info.annotated_classes]
|
||||
vis = visible[s][self.smal_joint_info.annotated_classes]
|
||||
else:
|
||||
joi = None
|
||||
vis = None
|
||||
|
||||
rgbs.append(rgb_img)
|
||||
segs.append(sil_img)
|
||||
joints.append(joi)
|
||||
visibles.append(vis)
|
||||
|
||||
return rgbs, segs, joints, visibles, filenames[0]
|
||||
|
||||
|
||||
class BadjaDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self, data_root, max_seq_len=1000, dataset_resolution=(384, 512)
|
||||
):
|
||||
|
||||
self.data_root = data_root
|
||||
self.badja_data = BADJAData(data_root)
|
||||
self.max_seq_len = max_seq_len
|
||||
self.dataset_resolution = dataset_resolution
|
||||
print(
|
||||
"found %d unique videos in %s"
|
||||
% (self.badja_data.animal_count, self.data_root)
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
|
||||
rgbs, segs, joints, visibles, filename = self.badja_data.get_video(index)
|
||||
S = len(rgbs)
|
||||
H, W, __ = rgbs[0].shape
|
||||
H, W, __ = segs[0].shape
|
||||
|
||||
N, __ = joints[0].shape
|
||||
|
||||
# let's eliminate the Nones
|
||||
# note the first one is guaranteed present
|
||||
for s in range(1, S):
|
||||
if joints[s] is None:
|
||||
joints[s] = np.zeros_like(joints[0])
|
||||
visibles[s] = np.zeros_like(visibles[0])
|
||||
|
||||
# eliminate the mystery dim
|
||||
segs = [seg[:, :, 0] for seg in segs]
|
||||
|
||||
rgbs = np.stack(rgbs, 0)
|
||||
segs = np.stack(segs, 0)
|
||||
trajs = np.stack(joints, 0)
|
||||
visibles = np.stack(visibles, 0)
|
||||
|
||||
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
|
||||
segs = torch.from_numpy(segs).reshape(S, 1, H, W).float()
|
||||
trajs = torch.from_numpy(trajs).reshape(S, N, 2).float()
|
||||
visibles = torch.from_numpy(visibles).reshape(S, N)
|
||||
|
||||
rgbs = rgbs[: self.max_seq_len]
|
||||
segs = segs[: self.max_seq_len]
|
||||
trajs = trajs[: self.max_seq_len]
|
||||
visibles = visibles[: self.max_seq_len]
|
||||
# apparently the coords are in yx order
|
||||
trajs = torch.flip(trajs, [2])
|
||||
|
||||
if "extra_videos" in filename:
|
||||
seq_name = filename.split("/")[-3]
|
||||
else:
|
||||
seq_name = filename.split("/")[-2]
|
||||
|
||||
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
|
||||
|
||||
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
|
||||
|
||||
def __len__(self):
|
||||
return self.badja_data.animal_count
|
72
cotracker/datasets/fast_capture_dataset.py
Normal file
72
cotracker/datasets/fast_capture_dataset.py
Normal file
@@ -0,0 +1,72 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
# from PIL import Image
|
||||
import imageio
|
||||
import numpy as np
|
||||
from cotracker.datasets.utils import CoTrackerData, resize_sample
|
||||
|
||||
|
||||
class FastCaptureDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
max_seq_len=50,
|
||||
max_num_points=20,
|
||||
dataset_resolution=(384, 512),
|
||||
):
|
||||
|
||||
self.data_root = data_root
|
||||
self.seq_names = os.listdir(os.path.join(data_root, "renders_local_rm"))
|
||||
self.pth_dir = os.path.join(data_root, "zju_tracking")
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_num_points = max_num_points
|
||||
self.dataset_resolution = dataset_resolution
|
||||
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
|
||||
|
||||
def __getitem__(self, index):
|
||||
seq_name = self.seq_names[index]
|
||||
spath = os.path.join(self.data_root, "renders_local_rm", seq_name)
|
||||
pthpath = os.path.join(self.pth_dir, seq_name + ".pth")
|
||||
|
||||
rgbs = []
|
||||
img_paths = sorted(os.listdir(spath))
|
||||
for i, img_path in enumerate(img_paths):
|
||||
if i < self.max_seq_len:
|
||||
rgbs.append(imageio.imread(os.path.join(spath, img_path)))
|
||||
|
||||
annot_dict = torch.load(pthpath)
|
||||
traj_2d = annot_dict["traj_2d"][:, :, : self.max_seq_len]
|
||||
visibility = annot_dict["visibility"][:, : self.max_seq_len]
|
||||
|
||||
S = len(rgbs)
|
||||
H, W, __ = rgbs[0].shape
|
||||
*_, S = traj_2d.shape
|
||||
visibile_pts_first_frame_inds = (visibility[:, 0] > 0).nonzero(as_tuple=False)[
|
||||
:, 0
|
||||
]
|
||||
torch.manual_seed(0)
|
||||
point_inds = torch.randperm(len(visibile_pts_first_frame_inds))[
|
||||
: self.max_num_points
|
||||
]
|
||||
visible_inds_sampled = visibile_pts_first_frame_inds[point_inds]
|
||||
|
||||
rgbs = np.stack(rgbs, 0)
|
||||
rgbs = torch.from_numpy(rgbs).reshape(S, H, W, 3).permute(0, 3, 1, 2).float()
|
||||
|
||||
segs = torch.ones(S, 1, H, W).float()
|
||||
trajs = traj_2d[visible_inds_sampled].permute(2, 0, 1).float()
|
||||
visibles = visibility[visible_inds_sampled].permute(1, 0)
|
||||
|
||||
rgbs, trajs, segs = resize_sample(rgbs, trajs, segs, self.dataset_resolution)
|
||||
|
||||
return CoTrackerData(rgbs, segs, trajs, visibles, seq_name=seq_name)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.seq_names)
|
494
cotracker/datasets/kubric_movif_dataset.py
Normal file
494
cotracker/datasets/kubric_movif_dataset.py
Normal file
@@ -0,0 +1,494 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import torch
|
||||
|
||||
import imageio
|
||||
import numpy as np
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
from torchvision.transforms import ColorJitter, GaussianBlur
|
||||
from PIL import Image
|
||||
import cv2
|
||||
|
||||
|
||||
class CoTrackerDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
crop_size=(384, 512),
|
||||
seq_len=24,
|
||||
traj_per_sample=768,
|
||||
sample_vis_1st_frame=False,
|
||||
use_augs=False,
|
||||
):
|
||||
super(CoTrackerDataset, self).__init__()
|
||||
np.random.seed(0)
|
||||
torch.manual_seed(0)
|
||||
self.data_root = data_root
|
||||
self.seq_len = seq_len
|
||||
self.traj_per_sample = traj_per_sample
|
||||
self.sample_vis_1st_frame = sample_vis_1st_frame
|
||||
self.use_augs = use_augs
|
||||
self.crop_size = crop_size
|
||||
|
||||
# photometric augmentation
|
||||
self.photo_aug = ColorJitter(
|
||||
brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14
|
||||
)
|
||||
self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
|
||||
|
||||
self.blur_aug_prob = 0.25
|
||||
self.color_aug_prob = 0.25
|
||||
|
||||
# occlusion augmentation
|
||||
self.eraser_aug_prob = 0.5
|
||||
self.eraser_bounds = [2, 100]
|
||||
self.eraser_max = 10
|
||||
|
||||
# occlusion augmentation
|
||||
self.replace_aug_prob = 0.5
|
||||
self.replace_bounds = [2, 100]
|
||||
self.replace_max = 10
|
||||
|
||||
# spatial augmentations
|
||||
self.pad_bounds = [0, 100]
|
||||
self.crop_size = crop_size
|
||||
self.resize_lim = [0.25, 2.0] # sample resizes from here
|
||||
self.resize_delta = 0.2
|
||||
self.max_crop_offset = 50
|
||||
|
||||
self.do_flip = True
|
||||
self.h_flip_prob = 0.5
|
||||
self.v_flip_prob = 0.5
|
||||
|
||||
def getitem_helper(self, index):
|
||||
return NotImplementedError
|
||||
|
||||
def __getitem__(self, index):
|
||||
gotit = False
|
||||
|
||||
sample, gotit = self.getitem_helper(index)
|
||||
if not gotit:
|
||||
print("warning: sampling failed")
|
||||
# fake sample, so we can still collate
|
||||
sample = CoTrackerData(
|
||||
video=torch.zeros(
|
||||
(self.seq_len, 3, self.crop_size[0], self.crop_size[1])
|
||||
),
|
||||
segmentation=torch.zeros(
|
||||
(self.seq_len, 1, self.crop_size[0], self.crop_size[1])
|
||||
),
|
||||
trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
|
||||
visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||
valid=torch.zeros((self.seq_len, self.traj_per_sample)),
|
||||
)
|
||||
|
||||
return sample, gotit
|
||||
|
||||
def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
|
||||
T, N, _ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
if eraser:
|
||||
############ eraser transform (per image after the first) ############
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
for i in range(1, S):
|
||||
if np.random.rand() < self.eraser_aug_prob:
|
||||
for _ in range(
|
||||
np.random.randint(1, self.eraser_max + 1)
|
||||
): # number of times to occlude
|
||||
|
||||
xc = np.random.randint(0, W)
|
||||
yc = np.random.randint(0, H)
|
||||
dx = np.random.randint(
|
||||
self.eraser_bounds[0], self.eraser_bounds[1]
|
||||
)
|
||||
dy = np.random.randint(
|
||||
self.eraser_bounds[0], self.eraser_bounds[1]
|
||||
)
|
||||
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
|
||||
mean_color = np.mean(
|
||||
rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0
|
||||
)
|
||||
rgbs[i][y0:y1, x0:x1, :] = mean_color
|
||||
|
||||
occ_inds = np.logical_and(
|
||||
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
|
||||
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
|
||||
)
|
||||
visibles[i, occ_inds] = 0
|
||||
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
|
||||
|
||||
if replace:
|
||||
|
||||
rgbs_alt = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs
|
||||
]
|
||||
rgbs_alt = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs_alt
|
||||
]
|
||||
|
||||
############ replace transform (per image after the first) ############
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
|
||||
for i in range(1, S):
|
||||
if np.random.rand() < self.replace_aug_prob:
|
||||
for _ in range(
|
||||
np.random.randint(1, self.replace_max + 1)
|
||||
): # number of times to occlude
|
||||
xc = np.random.randint(0, W)
|
||||
yc = np.random.randint(0, H)
|
||||
dx = np.random.randint(
|
||||
self.replace_bounds[0], self.replace_bounds[1]
|
||||
)
|
||||
dy = np.random.randint(
|
||||
self.replace_bounds[0], self.replace_bounds[1]
|
||||
)
|
||||
x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
|
||||
y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
|
||||
|
||||
wid = x1 - x0
|
||||
hei = y1 - y0
|
||||
y00 = np.random.randint(0, H - hei)
|
||||
x00 = np.random.randint(0, W - wid)
|
||||
fr = np.random.randint(0, S)
|
||||
rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
|
||||
rgbs[i][y0:y1, x0:x1, :] = rep
|
||||
|
||||
occ_inds = np.logical_and(
|
||||
np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
|
||||
np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
|
||||
)
|
||||
visibles[i, occ_inds] = 0
|
||||
rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
|
||||
|
||||
############ photometric augmentation ############
|
||||
if np.random.rand() < self.color_aug_prob:
|
||||
# random per-frame amount of aug
|
||||
rgbs = [
|
||||
np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs
|
||||
]
|
||||
|
||||
if np.random.rand() < self.blur_aug_prob:
|
||||
# random per-frame amount of blur
|
||||
rgbs = [
|
||||
np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8)
|
||||
for rgb in rgbs
|
||||
]
|
||||
|
||||
return rgbs, trajs, visibles
|
||||
|
||||
def add_spatial_augs(self, rgbs, trajs, visibles):
|
||||
T, N, __ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
rgbs = [rgb.astype(np.float32) for rgb in rgbs]
|
||||
|
||||
############ spatial transform ############
|
||||
|
||||
# padding
|
||||
pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
|
||||
|
||||
rgbs = [
|
||||
np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs
|
||||
]
|
||||
trajs[:, :, 0] += pad_x0
|
||||
trajs[:, :, 1] += pad_y0
|
||||
H, W = rgbs[0].shape[:2]
|
||||
|
||||
# scaling + stretching
|
||||
scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
|
||||
scale_x = scale
|
||||
scale_y = scale
|
||||
H_new = H
|
||||
W_new = W
|
||||
|
||||
scale_delta_x = 0.0
|
||||
scale_delta_y = 0.0
|
||||
|
||||
rgbs_scaled = []
|
||||
for s in range(S):
|
||||
if s == 1:
|
||||
scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
|
||||
scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
|
||||
elif s > 1:
|
||||
scale_delta_x = (
|
||||
scale_delta_x * 0.8
|
||||
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
||||
)
|
||||
scale_delta_y = (
|
||||
scale_delta_y * 0.8
|
||||
+ np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
|
||||
)
|
||||
scale_x = scale_x + scale_delta_x
|
||||
scale_y = scale_y + scale_delta_y
|
||||
|
||||
# bring h/w closer
|
||||
scale_xy = (scale_x + scale_y) * 0.5
|
||||
scale_x = scale_x * 0.5 + scale_xy * 0.5
|
||||
scale_y = scale_y * 0.5 + scale_xy * 0.5
|
||||
|
||||
# don't get too crazy
|
||||
scale_x = np.clip(scale_x, 0.2, 2.0)
|
||||
scale_y = np.clip(scale_y, 0.2, 2.0)
|
||||
|
||||
H_new = int(H * scale_y)
|
||||
W_new = int(W * scale_x)
|
||||
|
||||
# make it at least slightly bigger than the crop area,
|
||||
# so that the random cropping can add diversity
|
||||
H_new = np.clip(H_new, self.crop_size[0] + 10, None)
|
||||
W_new = np.clip(W_new, self.crop_size[1] + 10, None)
|
||||
# recompute scale in case we clipped
|
||||
scale_x = W_new / float(W)
|
||||
scale_y = H_new / float(H)
|
||||
|
||||
rgbs_scaled.append(
|
||||
cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR)
|
||||
)
|
||||
trajs[s, :, 0] *= scale_x
|
||||
trajs[s, :, 1] *= scale_y
|
||||
rgbs = rgbs_scaled
|
||||
|
||||
ok_inds = visibles[0, :] > 0
|
||||
vis_trajs = trajs[:, ok_inds] # S,?,2
|
||||
|
||||
if vis_trajs.shape[1] > 0:
|
||||
mid_x = np.mean(vis_trajs[0, :, 0])
|
||||
mid_y = np.mean(vis_trajs[0, :, 1])
|
||||
else:
|
||||
mid_y = self.crop_size[0]
|
||||
mid_x = self.crop_size[1]
|
||||
|
||||
x0 = int(mid_x - self.crop_size[1] // 2)
|
||||
y0 = int(mid_y - self.crop_size[0] // 2)
|
||||
|
||||
offset_x = 0
|
||||
offset_y = 0
|
||||
|
||||
for s in range(S):
|
||||
# on each frame, shift a bit more
|
||||
if s == 1:
|
||||
offset_x = np.random.randint(
|
||||
-self.max_crop_offset, self.max_crop_offset
|
||||
)
|
||||
offset_y = np.random.randint(
|
||||
-self.max_crop_offset, self.max_crop_offset
|
||||
)
|
||||
elif s > 1:
|
||||
offset_x = int(
|
||||
offset_x * 0.8
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
|
||||
* 0.2
|
||||
)
|
||||
offset_y = int(
|
||||
offset_y * 0.8
|
||||
+ np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1)
|
||||
* 0.2
|
||||
)
|
||||
x0 = x0 + offset_x
|
||||
y0 = y0 + offset_y
|
||||
|
||||
H_new, W_new = rgbs[s].shape[:2]
|
||||
if H_new == self.crop_size[0]:
|
||||
y0 = 0
|
||||
else:
|
||||
y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)
|
||||
|
||||
if W_new == self.crop_size[1]:
|
||||
x0 = 0
|
||||
else:
|
||||
x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)
|
||||
|
||||
rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
|
||||
trajs[s, :, 0] -= x0
|
||||
trajs[s, :, 1] -= y0
|
||||
|
||||
H_new = self.crop_size[0]
|
||||
W_new = self.crop_size[1]
|
||||
|
||||
# flip
|
||||
h_flipped = False
|
||||
v_flipped = False
|
||||
if self.do_flip:
|
||||
# h flip
|
||||
if np.random.rand() < self.h_flip_prob:
|
||||
h_flipped = True
|
||||
rgbs = [rgb[:, ::-1] for rgb in rgbs]
|
||||
# v flip
|
||||
if np.random.rand() < self.v_flip_prob:
|
||||
v_flipped = True
|
||||
rgbs = [rgb[::-1] for rgb in rgbs]
|
||||
if h_flipped:
|
||||
trajs[:, :, 0] = W_new - trajs[:, :, 0]
|
||||
if v_flipped:
|
||||
trajs[:, :, 1] = H_new - trajs[:, :, 1]
|
||||
|
||||
return rgbs, trajs
|
||||
|
||||
def crop(self, rgbs, trajs):
|
||||
T, N, _ = trajs.shape
|
||||
|
||||
S = len(rgbs)
|
||||
H, W = rgbs[0].shape[:2]
|
||||
assert S == T
|
||||
|
||||
############ spatial transform ############
|
||||
|
||||
H_new = H
|
||||
W_new = W
|
||||
|
||||
# simple random crop
|
||||
y0 = (
|
||||
0
|
||||
if self.crop_size[0] >= H_new
|
||||
else np.random.randint(0, H_new - self.crop_size[0])
|
||||
)
|
||||
x0 = (
|
||||
0
|
||||
if self.crop_size[1] >= W_new
|
||||
else np.random.randint(0, W_new - self.crop_size[1])
|
||||
)
|
||||
rgbs = [
|
||||
rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
|
||||
for rgb in rgbs
|
||||
]
|
||||
|
||||
trajs[:, :, 0] -= x0
|
||||
trajs[:, :, 1] -= y0
|
||||
|
||||
return rgbs, trajs
|
||||
|
||||
|
||||
class KubricMovifDataset(CoTrackerDataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
crop_size=(384, 512),
|
||||
seq_len=24,
|
||||
traj_per_sample=768,
|
||||
sample_vis_1st_frame=False,
|
||||
use_augs=False,
|
||||
):
|
||||
super(KubricMovifDataset, self).__init__(
|
||||
data_root=data_root,
|
||||
crop_size=crop_size,
|
||||
seq_len=seq_len,
|
||||
traj_per_sample=traj_per_sample,
|
||||
sample_vis_1st_frame=sample_vis_1st_frame,
|
||||
use_augs=use_augs,
|
||||
)
|
||||
|
||||
self.pad_bounds = [0, 25]
|
||||
self.resize_lim = [0.75, 1.25] # sample resizes from here
|
||||
self.resize_delta = 0.05
|
||||
self.max_crop_offset = 15
|
||||
self.seq_names = [
|
||||
fname
|
||||
for fname in os.listdir(data_root)
|
||||
if os.path.isdir(os.path.join(data_root, fname))
|
||||
]
|
||||
print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
|
||||
|
||||
def getitem_helper(self, index):
|
||||
gotit = True
|
||||
seq_name = self.seq_names[index]
|
||||
|
||||
npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
|
||||
rgb_path = os.path.join(self.data_root, seq_name, "frames")
|
||||
|
||||
img_paths = sorted(os.listdir(rgb_path))
|
||||
rgbs = []
|
||||
for i, img_path in enumerate(img_paths):
|
||||
rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
|
||||
|
||||
rgbs = np.stack(rgbs)
|
||||
annot_dict = np.load(npy_path, allow_pickle=True).item()
|
||||
traj_2d = annot_dict["coords"]
|
||||
visibility = annot_dict["visibility"]
|
||||
|
||||
# random crop
|
||||
assert self.seq_len <= len(rgbs)
|
||||
if self.seq_len < len(rgbs):
|
||||
start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]
|
||||
|
||||
rgbs = rgbs[start_ind : start_ind + self.seq_len]
|
||||
traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
|
||||
visibility = visibility[:, start_ind : start_ind + self.seq_len]
|
||||
|
||||
traj_2d = np.transpose(traj_2d, (1, 0, 2))
|
||||
visibility = np.transpose(np.logical_not(visibility), (1, 0))
|
||||
if self.use_augs:
|
||||
rgbs, traj_2d, visibility = self.add_photometric_augs(
|
||||
rgbs, traj_2d, visibility
|
||||
)
|
||||
rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
|
||||
else:
|
||||
rgbs, traj_2d = self.crop(rgbs, traj_2d)
|
||||
|
||||
visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
|
||||
visibility[traj_2d[:, :, 0] < 0] = False
|
||||
visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
|
||||
visibility[traj_2d[:, :, 1] < 0] = False
|
||||
|
||||
visibility = torch.from_numpy(visibility)
|
||||
traj_2d = torch.from_numpy(traj_2d)
|
||||
|
||||
visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
|
||||
|
||||
if self.sample_vis_1st_frame:
|
||||
visibile_pts_inds = visibile_pts_first_frame_inds
|
||||
else:
|
||||
visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(
|
||||
as_tuple=False
|
||||
)[:, 0]
|
||||
visibile_pts_inds = torch.cat(
|
||||
(visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
|
||||
)
|
||||
point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
|
||||
if len(point_inds) < self.traj_per_sample:
|
||||
gotit = False
|
||||
|
||||
visible_inds_sampled = visibile_pts_inds[point_inds]
|
||||
|
||||
trajs = traj_2d[:, visible_inds_sampled].float()
|
||||
visibles = visibility[:, visible_inds_sampled]
|
||||
valids = torch.ones((self.seq_len, self.traj_per_sample))
|
||||
|
||||
rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
|
||||
segs = torch.ones((self.seq_len, 1, self.crop_size[0], self.crop_size[1]))
|
||||
sample = CoTrackerData(
|
||||
video=rgbs,
|
||||
segmentation=segs,
|
||||
trajectory=trajs,
|
||||
visibility=visibles,
|
||||
valid=valids,
|
||||
seq_name=seq_name,
|
||||
)
|
||||
return sample, gotit
|
||||
|
||||
def __len__(self):
|
||||
return len(self.seq_names)
|
218
cotracker/datasets/tap_vid_datasets.py
Normal file
218
cotracker/datasets/tap_vid_datasets.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import io
|
||||
import glob
|
||||
import torch
|
||||
import pickle
|
||||
import numpy as np
|
||||
import mediapy as media
|
||||
|
||||
from PIL import Image
|
||||
from typing import Mapping, Tuple, Union
|
||||
|
||||
from cotracker.datasets.utils import CoTrackerData
|
||||
|
||||
DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
|
||||
|
||||
|
||||
def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
|
||||
"""Resize a video to output_size."""
|
||||
# If you have a GPU, consider replacing this with a GPU-enabled resize op,
|
||||
# such as a jitted jax.image.resize. It will make things faster.
|
||||
return media.resize_video(video, output_size)
|
||||
|
||||
|
||||
def sample_queries_first(
|
||||
target_occluded: np.ndarray,
|
||||
target_points: np.ndarray,
|
||||
frames: np.ndarray,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
||||
Given a set of frames and tracks with no query points, use the first
|
||||
visible point in each track as the query.
|
||||
Args:
|
||||
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
||||
where True indicates occluded.
|
||||
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
||||
is [x,y] scaled between 0 and 1.
|
||||
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
||||
-1 and 1.
|
||||
Returns:
|
||||
A dict with the keys:
|
||||
video: Video tensor of shape [1, n_frames, height, width, 3]
|
||||
query_points: Query points of shape [1, n_queries, 3] where
|
||||
each point is [t, y, x] scaled to the range [-1, 1]
|
||||
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
||||
each point is [x, y] scaled to the range [-1, 1]
|
||||
"""
|
||||
valid = np.sum(~target_occluded, axis=1) > 0
|
||||
target_points = target_points[valid, :]
|
||||
target_occluded = target_occluded[valid, :]
|
||||
|
||||
query_points = []
|
||||
for i in range(target_points.shape[0]):
|
||||
index = np.where(target_occluded[i] == 0)[0][0]
|
||||
x, y = target_points[i, index, 0], target_points[i, index, 1]
|
||||
query_points.append(np.array([index, y, x])) # [t, y, x]
|
||||
query_points = np.stack(query_points, axis=0)
|
||||
|
||||
return {
|
||||
"video": frames[np.newaxis, ...],
|
||||
"query_points": query_points[np.newaxis, ...],
|
||||
"target_points": target_points[np.newaxis, ...],
|
||||
"occluded": target_occluded[np.newaxis, ...],
|
||||
}
|
||||
|
||||
|
||||
def sample_queries_strided(
|
||||
target_occluded: np.ndarray,
|
||||
target_points: np.ndarray,
|
||||
frames: np.ndarray,
|
||||
query_stride: int = 5,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Package a set of frames and tracks for use in TAPNet evaluations.
|
||||
|
||||
Given a set of frames and tracks with no query points, sample queries
|
||||
strided every query_stride frames, ignoring points that are not visible
|
||||
at the selected frames.
|
||||
|
||||
Args:
|
||||
target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
|
||||
where True indicates occluded.
|
||||
target_points: Position, of shape [n_tracks, n_frames, 2], where each point
|
||||
is [x,y] scaled between 0 and 1.
|
||||
frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
|
||||
-1 and 1.
|
||||
query_stride: When sampling query points, search for un-occluded points
|
||||
every query_stride frames and convert each one into a query.
|
||||
|
||||
Returns:
|
||||
A dict with the keys:
|
||||
video: Video tensor of shape [1, n_frames, height, width, 3]. The video
|
||||
has floats scaled to the range [-1, 1].
|
||||
query_points: Query points of shape [1, n_queries, 3] where
|
||||
each point is [t, y, x] scaled to the range [-1, 1].
|
||||
target_points: Target points of shape [1, n_queries, n_frames, 2] where
|
||||
each point is [x, y] scaled to the range [-1, 1].
|
||||
trackgroup: Index of the original track that each query point was
|
||||
sampled from. This is useful for visualization.
|
||||
"""
|
||||
tracks = []
|
||||
occs = []
|
||||
queries = []
|
||||
trackgroups = []
|
||||
total = 0
|
||||
trackgroup = np.arange(target_occluded.shape[0])
|
||||
for i in range(0, target_occluded.shape[1], query_stride):
|
||||
mask = target_occluded[:, i] == 0
|
||||
query = np.stack(
|
||||
[
|
||||
i * np.ones(target_occluded.shape[0:1]),
|
||||
target_points[:, i, 1],
|
||||
target_points[:, i, 0],
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
queries.append(query[mask])
|
||||
tracks.append(target_points[mask])
|
||||
occs.append(target_occluded[mask])
|
||||
trackgroups.append(trackgroup[mask])
|
||||
total += np.array(np.sum(target_occluded[:, i] == 0))
|
||||
|
||||
return {
|
||||
"video": frames[np.newaxis, ...],
|
||||
"query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
|
||||
"target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
|
||||
"occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
|
||||
"trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
|
||||
}
|
||||
|
||||
|
||||
class TapVidDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
data_root,
|
||||
dataset_type="davis",
|
||||
resize_to_256=True,
|
||||
queried_first=True,
|
||||
):
|
||||
self.dataset_type = dataset_type
|
||||
self.resize_to_256 = resize_to_256
|
||||
self.queried_first = queried_first
|
||||
if self.dataset_type == "kinetics":
|
||||
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
|
||||
points_dataset = []
|
||||
for pickle_path in all_paths:
|
||||
with open(pickle_path, "rb") as f:
|
||||
data = pickle.load(f)
|
||||
points_dataset = points_dataset + data
|
||||
self.points_dataset = points_dataset
|
||||
else:
|
||||
with open(data_root, "rb") as f:
|
||||
self.points_dataset = pickle.load(f)
|
||||
if self.dataset_type == "davis":
|
||||
self.video_names = list(self.points_dataset.keys())
|
||||
print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.dataset_type == "davis":
|
||||
video_name = self.video_names[index]
|
||||
else:
|
||||
video_name = index
|
||||
video = self.points_dataset[video_name]
|
||||
frames = video["video"]
|
||||
|
||||
if isinstance(frames[0], bytes):
|
||||
# TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
|
||||
def decode(frame):
|
||||
byteio = io.BytesIO(frame)
|
||||
img = Image.open(byteio)
|
||||
return np.array(img)
|
||||
|
||||
frames = np.array([decode(frame) for frame in frames])
|
||||
|
||||
target_points = self.points_dataset[video_name]["points"]
|
||||
if self.resize_to_256:
|
||||
frames = resize_video(frames, [256, 256])
|
||||
target_points *= np.array([256, 256])
|
||||
else:
|
||||
target_points *= np.array([frames.shape[2], frames.shape[1]])
|
||||
|
||||
T, H, W, C = frames.shape
|
||||
N, T, D = target_points.shape
|
||||
|
||||
target_occ = self.points_dataset[video_name]["occluded"]
|
||||
if self.queried_first:
|
||||
converted = sample_queries_first(target_occ, target_points, frames)
|
||||
else:
|
||||
converted = sample_queries_strided(target_occ, target_points, frames)
|
||||
assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
|
||||
|
||||
trajs = (
|
||||
torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float()
|
||||
) # T, N, D
|
||||
|
||||
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
|
||||
segs = torch.ones(T, 1, H, W).float()
|
||||
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[
|
||||
0
|
||||
].permute(
|
||||
1, 0
|
||||
) # T, N
|
||||
query_points = torch.from_numpy(converted["query_points"])[0] # T, N
|
||||
return CoTrackerData(
|
||||
rgbs,
|
||||
segs,
|
||||
trajs,
|
||||
visibles,
|
||||
seq_name=str(video_name),
|
||||
query_points=query_points,
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.points_dataset)
|
114
cotracker/datasets/utils.py
Normal file
114
cotracker/datasets/utils.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
|
||||
import torch
|
||||
import dataclasses
|
||||
import torch.nn.functional as F
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class CoTrackerData:
|
||||
"""
|
||||
Dataclass for storing video tracks data.
|
||||
"""
|
||||
|
||||
video: torch.Tensor # B, S, C, H, W
|
||||
segmentation: torch.Tensor # B, S, 1, H, W
|
||||
trajectory: torch.Tensor # B, S, N, 2
|
||||
visibility: torch.Tensor # B, S, N
|
||||
# optional data
|
||||
valid: Optional[torch.Tensor] = None # B, S, N
|
||||
seq_name: Optional[str] = None
|
||||
query_points: Optional[torch.Tensor] = None # TapVID evaluation format
|
||||
|
||||
|
||||
def collate_fn(batch):
|
||||
"""
|
||||
Collate function for video tracks data.
|
||||
"""
|
||||
video = torch.stack([b.video for b in batch], dim=0)
|
||||
segmentation = torch.stack([b.segmentation for b in batch], dim=0)
|
||||
trajectory = torch.stack([b.trajectory for b in batch], dim=0)
|
||||
visibility = torch.stack([b.visibility for b in batch], dim=0)
|
||||
query_points = None
|
||||
if batch[0].query_points is not None:
|
||||
query_points = torch.stack([b.query_points for b in batch], dim=0)
|
||||
seq_name = [b.seq_name for b in batch]
|
||||
|
||||
return CoTrackerData(
|
||||
video,
|
||||
segmentation,
|
||||
trajectory,
|
||||
visibility,
|
||||
seq_name=seq_name,
|
||||
query_points=query_points,
|
||||
)
|
||||
|
||||
|
||||
def collate_fn_train(batch):
|
||||
"""
|
||||
Collate function for video tracks data during training.
|
||||
"""
|
||||
gotit = [gotit for _, gotit in batch]
|
||||
video = torch.stack([b.video for b, _ in batch], dim=0)
|
||||
segmentation = torch.stack([b.segmentation for b, _ in batch], dim=0)
|
||||
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
|
||||
visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
|
||||
valid = torch.stack([b.valid for b, _ in batch], dim=0)
|
||||
seq_name = [b.seq_name for b, _ in batch]
|
||||
return (
|
||||
CoTrackerData(video, segmentation, trajectory, visibility, valid, seq_name),
|
||||
gotit,
|
||||
)
|
||||
|
||||
|
||||
def try_to_cuda(t: Any) -> Any:
|
||||
"""
|
||||
Try to move the input variable `t` to a cuda device.
|
||||
|
||||
Args:
|
||||
t: Input.
|
||||
|
||||
Returns:
|
||||
t_cuda: `t` moved to a cuda device, if supported.
|
||||
"""
|
||||
try:
|
||||
t = t.float().cuda()
|
||||
except AttributeError:
|
||||
pass
|
||||
return t
|
||||
|
||||
|
||||
def dataclass_to_cuda_(obj):
|
||||
"""
|
||||
Move all contents of a dataclass to cuda inplace if supported.
|
||||
|
||||
Args:
|
||||
batch: Input dataclass.
|
||||
|
||||
Returns:
|
||||
batch_cuda: `batch` moved to a cuda device, if supported.
|
||||
"""
|
||||
for f in dataclasses.fields(obj):
|
||||
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
|
||||
return obj
|
||||
|
||||
|
||||
def resize_sample(rgbs, trajs_g, segs, interp_shape):
|
||||
S, C, H, W = rgbs.shape
|
||||
S, N, D = trajs_g.shape
|
||||
|
||||
assert D == 2
|
||||
|
||||
rgbs = F.interpolate(rgbs, interp_shape, mode="bilinear")
|
||||
segs = F.interpolate(segs, interp_shape, mode="nearest")
|
||||
|
||||
trajs_g[:, :, 0] *= interp_shape[1] / W
|
||||
trajs_g[:, :, 1] *= interp_shape[0] / H
|
||||
return rgbs, trajs_g, segs
|
5
cotracker/evaluation/__init__.py
Normal file
5
cotracker/evaluation/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
6
cotracker/evaluation/configs/eval_badja.yaml
Normal file
6
cotracker/evaluation/configs/eval_badja.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: badja
|
||||
|
||||
|
6
cotracker/evaluation/configs/eval_fastcapture.yaml
Normal file
6
cotracker/evaluation/configs/eval_fastcapture.yaml
Normal file
@@ -0,0 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: fastcapture
|
||||
|
||||
|
@@ -0,0 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_davis_first
|
||||
|
||||
|
@@ -0,0 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_davis_strided
|
||||
|
||||
|
@@ -0,0 +1,6 @@
|
||||
defaults:
|
||||
- default_config_eval
|
||||
exp_dir: ./outputs/cotracker
|
||||
dataset_name: tapvid_kinetics_first
|
||||
|
||||
|
5
cotracker/evaluation/core/__init__.py
Normal file
5
cotracker/evaluation/core/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
144
cotracker/evaluation/core/eval_utils.py
Normal file
144
cotracker/evaluation/core/eval_utils.py
Normal file
@@ -0,0 +1,144 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import numpy as np
|
||||
|
||||
from typing import Iterable, Mapping, Tuple, Union
|
||||
|
||||
|
||||
def compute_tapvid_metrics(
|
||||
query_points: np.ndarray,
|
||||
gt_occluded: np.ndarray,
|
||||
gt_tracks: np.ndarray,
|
||||
pred_occluded: np.ndarray,
|
||||
pred_tracks: np.ndarray,
|
||||
query_mode: str,
|
||||
) -> Mapping[str, np.ndarray]:
|
||||
"""Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
|
||||
See the TAP-Vid paper for details on the metric computation. All inputs are
|
||||
given in raster coordinates. The first three arguments should be the direct
|
||||
outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
|
||||
The paper metrics assume these are scaled relative to 256x256 images.
|
||||
pred_occluded and pred_tracks are your algorithm's predictions.
|
||||
This function takes a batch of inputs, and computes metrics separately for
|
||||
each video. The metrics for the full benchmark are a simple mean of the
|
||||
metrics across the full set of videos. These numbers are between 0 and 1,
|
||||
but the paper multiplies them by 100 to ease reading.
|
||||
Args:
|
||||
query_points: The query points, an in the format [t, y, x]. Its size is
|
||||
[b, n, 3], where b is the batch size and n is the number of queries
|
||||
gt_occluded: A boolean array of shape [b, n, t], where t is the number
|
||||
of frames. True indicates that the point is occluded.
|
||||
gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
|
||||
in the format [x, y]
|
||||
pred_occluded: A boolean array of predicted occlusions, in the same
|
||||
format as gt_occluded.
|
||||
pred_tracks: An array of track predictions from your algorithm, in the
|
||||
same format as gt_tracks.
|
||||
query_mode: Either 'first' or 'strided', depending on how queries are
|
||||
sampled. If 'first', we assume the prior knowledge that all points
|
||||
before the query point are occluded, and these are removed from the
|
||||
evaluation.
|
||||
Returns:
|
||||
A dict with the following keys:
|
||||
occlusion_accuracy: Accuracy at predicting occlusion.
|
||||
pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
|
||||
predicted to be within the given pixel threshold, ignoring occlusion
|
||||
prediction.
|
||||
jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
|
||||
threshold
|
||||
average_pts_within_thresh: average across pts_within_{x}
|
||||
average_jaccard: average across jaccard_{x}
|
||||
"""
|
||||
|
||||
metrics = {}
|
||||
|
||||
# Don't evaluate the query point. Numpy doesn't have one_hot, so we
|
||||
# replicate it by indexing into an identity matrix.
|
||||
one_hot_eye = np.eye(gt_tracks.shape[2])
|
||||
query_frame = query_points[..., 0]
|
||||
query_frame = np.round(query_frame).astype(np.int32)
|
||||
evaluation_points = one_hot_eye[query_frame] == 0
|
||||
|
||||
# If we're using the first point on the track as a query, don't evaluate the
|
||||
# other points.
|
||||
if query_mode == "first":
|
||||
for i in range(gt_occluded.shape[0]):
|
||||
index = np.where(gt_occluded[i] == 0)[0][0]
|
||||
evaluation_points[i, :index] = False
|
||||
elif query_mode != "strided":
|
||||
raise ValueError("Unknown query mode " + query_mode)
|
||||
|
||||
# Occlusion accuracy is simply how often the predicted occlusion equals the
|
||||
# ground truth.
|
||||
occ_acc = (
|
||||
np.sum(
|
||||
np.equal(pred_occluded, gt_occluded) & evaluation_points,
|
||||
axis=(1, 2),
|
||||
)
|
||||
/ np.sum(evaluation_points)
|
||||
)
|
||||
metrics["occlusion_accuracy"] = occ_acc
|
||||
|
||||
# Next, convert the predictions and ground truth positions into pixel
|
||||
# coordinates.
|
||||
visible = np.logical_not(gt_occluded)
|
||||
pred_visible = np.logical_not(pred_occluded)
|
||||
all_frac_within = []
|
||||
all_jaccard = []
|
||||
for thresh in [1, 2, 4, 8, 16]:
|
||||
# True positives are points that are within the threshold and where both
|
||||
# the prediction and the ground truth are listed as visible.
|
||||
within_dist = (
|
||||
np.sum(
|
||||
np.square(pred_tracks - gt_tracks),
|
||||
axis=-1,
|
||||
)
|
||||
< np.square(thresh)
|
||||
)
|
||||
is_correct = np.logical_and(within_dist, visible)
|
||||
|
||||
# Compute the frac_within_threshold, which is the fraction of points
|
||||
# within the threshold among points that are visible in the ground truth,
|
||||
# ignoring whether they're predicted to be visible.
|
||||
count_correct = np.sum(
|
||||
is_correct & evaluation_points,
|
||||
axis=(1, 2),
|
||||
)
|
||||
count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
|
||||
frac_correct = count_correct / count_visible_points
|
||||
metrics["pts_within_" + str(thresh)] = frac_correct
|
||||
all_frac_within.append(frac_correct)
|
||||
|
||||
true_positives = np.sum(
|
||||
is_correct & pred_visible & evaluation_points, axis=(1, 2)
|
||||
)
|
||||
|
||||
# The denominator of the jaccard metric is the true positives plus
|
||||
# false positives plus false negatives. However, note that true positives
|
||||
# plus false negatives is simply the number of points in the ground truth
|
||||
# which is easier to compute than trying to compute all three quantities.
|
||||
# Thus we just add the number of points in the ground truth to the number
|
||||
# of false positives.
|
||||
#
|
||||
# False positives are simply points that are predicted to be visible,
|
||||
# but the ground truth is not visible or too far from the prediction.
|
||||
gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
|
||||
false_positives = (~visible) & pred_visible
|
||||
false_positives = false_positives | ((~within_dist) & pred_visible)
|
||||
false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
|
||||
jaccard = true_positives / (gt_positives + false_positives)
|
||||
metrics["jaccard_" + str(thresh)] = jaccard
|
||||
all_jaccard.append(jaccard)
|
||||
metrics["average_jaccard"] = np.mean(
|
||||
np.stack(all_jaccard, axis=1),
|
||||
axis=1,
|
||||
)
|
||||
metrics["average_pts_within_thresh"] = np.mean(
|
||||
np.stack(all_frac_within, axis=1),
|
||||
axis=1,
|
||||
)
|
||||
return metrics
|
252
cotracker/evaluation/core/evaluator.py
Normal file
252
cotracker/evaluation/core/evaluator.py
Normal file
@@ -0,0 +1,252 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from collections import defaultdict
|
||||
import os
|
||||
from typing import Optional
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from cotracker.datasets.utils import dataclass_to_cuda_
|
||||
from cotracker.utils.visualizer import Visualizer
|
||||
from cotracker.models.core.model_utils import reduce_masked_mean
|
||||
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
class Evaluator:
|
||||
"""
|
||||
A class defining the CoTracker evaluator.
|
||||
"""
|
||||
|
||||
def __init__(self, exp_dir) -> None:
|
||||
# Visualization
|
||||
self.exp_dir = exp_dir
|
||||
os.makedirs(exp_dir, exist_ok=True)
|
||||
self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
|
||||
self.visualize_dir = os.path.join(exp_dir, "visualisations")
|
||||
|
||||
def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
|
||||
if isinstance(pred_trajectory, tuple):
|
||||
pred_trajectory, pred_visibility = pred_trajectory
|
||||
else:
|
||||
pred_visibility = None
|
||||
if dataset_name == "badja":
|
||||
sample.segmentation = (sample.segmentation > 0).float()
|
||||
*_, N, _ = sample.trajectory.shape
|
||||
accs = []
|
||||
accs_3px = []
|
||||
for s1 in range(1, sample.video.shape[1]): # target frame
|
||||
for n in range(N):
|
||||
vis = sample.visibility[0, s1, n]
|
||||
if vis > 0:
|
||||
coord_e = pred_trajectory[0, s1, n] # 2
|
||||
coord_g = sample.trajectory[0, s1, n] # 2
|
||||
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
|
||||
area = torch.sum(sample.segmentation[0, s1])
|
||||
# print_('0.2*sqrt(area)', 0.2*torch.sqrt(area))
|
||||
thr = 0.2 * torch.sqrt(area)
|
||||
# correct =
|
||||
accs.append((dist < thr).float())
|
||||
# print('thr',thr)
|
||||
accs_3px.append((dist < 3.0).float())
|
||||
|
||||
res = torch.mean(torch.stack(accs)) * 100.0
|
||||
res_3px = torch.mean(torch.stack(accs_3px)) * 100.0
|
||||
metrics[sample.seq_name[0]] = res.item()
|
||||
metrics[sample.seq_name[0] + "_accuracy"] = res_3px.item()
|
||||
print(metrics)
|
||||
print(
|
||||
"avg", np.mean([v for k, v in metrics.items() if "accuracy" not in k])
|
||||
)
|
||||
print(
|
||||
"avg acc 3px",
|
||||
np.mean([v for k, v in metrics.items() if "accuracy" in k]),
|
||||
)
|
||||
elif dataset_name == "fastcapture" or ("kubric" in dataset_name):
|
||||
*_, N, _ = sample.trajectory.shape
|
||||
accs = []
|
||||
for s1 in range(1, sample.video.shape[1]): # target frame
|
||||
for n in range(N):
|
||||
vis = sample.visibility[0, s1, n]
|
||||
if vis > 0:
|
||||
coord_e = pred_trajectory[0, s1, n] # 2
|
||||
coord_g = sample.trajectory[0, s1, n] # 2
|
||||
dist = torch.sqrt(torch.sum((coord_e - coord_g) ** 2, dim=0))
|
||||
thr = 3
|
||||
correct = (dist < thr).float()
|
||||
accs.append(correct)
|
||||
|
||||
res = torch.mean(torch.stack(accs)) * 100.0
|
||||
metrics[sample.seq_name[0] + "_accuracy"] = res.item()
|
||||
print(metrics)
|
||||
print("avg", np.mean([v for v in metrics.values()]))
|
||||
elif "tapvid" in dataset_name:
|
||||
B, T, N, D = sample.trajectory.shape
|
||||
traj = sample.trajectory.clone()
|
||||
thr = 0.9
|
||||
|
||||
if pred_visibility is None:
|
||||
logging.warning("visibility is NONE")
|
||||
pred_visibility = torch.zeros_like(sample.visibility)
|
||||
|
||||
if not pred_visibility.dtype == torch.bool:
|
||||
pred_visibility = pred_visibility > thr
|
||||
|
||||
# pred_trajectory
|
||||
query_points = sample.query_points.clone().cpu().numpy()
|
||||
|
||||
pred_visibility = pred_visibility[:, :, :N]
|
||||
pred_trajectory = pred_trajectory[:, :, :N]
|
||||
|
||||
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
|
||||
gt_occluded = (
|
||||
torch.logical_not(sample.visibility.clone().permute(0, 2, 1))
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
|
||||
pred_occluded = (
|
||||
torch.logical_not(pred_visibility.clone().permute(0, 2, 1))
|
||||
.cpu()
|
||||
.numpy()
|
||||
)
|
||||
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
|
||||
|
||||
out_metrics = compute_tapvid_metrics(
|
||||
query_points,
|
||||
gt_occluded,
|
||||
gt_tracks,
|
||||
pred_occluded,
|
||||
pred_tracks,
|
||||
query_mode="strided" if "strided" in dataset_name else "first",
|
||||
)
|
||||
|
||||
metrics[sample.seq_name[0]] = out_metrics
|
||||
for metric_name in out_metrics.keys():
|
||||
if "avg" not in metrics:
|
||||
metrics["avg"] = {}
|
||||
metrics["avg"][metric_name] = np.mean(
|
||||
[v[metric_name] for k, v in metrics.items() if k != "avg"]
|
||||
)
|
||||
|
||||
logging.info(f"Metrics: {out_metrics}")
|
||||
logging.info(f"avg: {metrics['avg']}")
|
||||
print("metrics", out_metrics)
|
||||
print("avg", metrics["avg"])
|
||||
else:
|
||||
rgbs = sample.video
|
||||
trajs_g = sample.trajectory
|
||||
valids = sample.valid
|
||||
vis_g = sample.visibility
|
||||
|
||||
B, S, C, H, W = rgbs.shape
|
||||
assert C == 3
|
||||
B, S, N, D = trajs_g.shape
|
||||
|
||||
assert torch.sum(valids) == B * S * N
|
||||
|
||||
vis_g = (torch.sum(vis_g, dim=1, keepdim=True) >= 4).float().repeat(1, S, 1)
|
||||
|
||||
ate = torch.norm(pred_trajectory - trajs_g, dim=-1) # B, S, N
|
||||
|
||||
metrics["things_all"] = reduce_masked_mean(ate, valids).item()
|
||||
metrics["things_vis"] = reduce_masked_mean(ate, valids * vis_g).item()
|
||||
metrics["things_occ"] = reduce_masked_mean(
|
||||
ate, valids * (1.0 - vis_g)
|
||||
).item()
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate_sequence(
|
||||
self,
|
||||
model,
|
||||
test_dataloader: torch.utils.data.DataLoader,
|
||||
dataset_name: str,
|
||||
train_mode=False,
|
||||
writer: Optional[SummaryWriter] = None,
|
||||
step: Optional[int] = 0,
|
||||
):
|
||||
metrics = {}
|
||||
|
||||
vis = Visualizer(
|
||||
save_dir=self.exp_dir,
|
||||
fps=7,
|
||||
)
|
||||
|
||||
for ind, sample in enumerate(tqdm(test_dataloader)):
|
||||
if isinstance(sample, tuple):
|
||||
sample, gotit = sample
|
||||
if not all(gotit):
|
||||
print("batch is None")
|
||||
continue
|
||||
dataclass_to_cuda_(sample)
|
||||
|
||||
if (
|
||||
not train_mode
|
||||
and hasattr(model, "sequence_len")
|
||||
and (sample.visibility[:, : model.sequence_len].sum() == 0)
|
||||
):
|
||||
print(f"skipping batch {ind}")
|
||||
continue
|
||||
|
||||
if "tapvid" in dataset_name:
|
||||
queries = sample.query_points.clone().float()
|
||||
|
||||
queries = torch.stack(
|
||||
[
|
||||
queries[:, :, 0],
|
||||
queries[:, :, 2],
|
||||
queries[:, :, 1],
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
else:
|
||||
queries = torch.cat(
|
||||
[
|
||||
torch.zeros_like(sample.trajectory[:, 0, :, :1]),
|
||||
sample.trajectory[:, 0],
|
||||
],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
pred_tracks = model(sample.video, queries)
|
||||
if "strided" in dataset_name:
|
||||
|
||||
inv_video = sample.video.flip(1).clone()
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
|
||||
pred_trj, pred_vsb = pred_tracks
|
||||
inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
|
||||
|
||||
inv_pred_trj = inv_pred_trj.flip(1)
|
||||
inv_pred_vsb = inv_pred_vsb.flip(1)
|
||||
|
||||
mask = pred_trj == 0
|
||||
|
||||
pred_trj[mask] = inv_pred_trj[mask]
|
||||
pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
|
||||
|
||||
pred_tracks = pred_trj, pred_vsb
|
||||
|
||||
if dataset_name == "badja" or dataset_name == "fastcapture":
|
||||
seq_name = sample.seq_name[0]
|
||||
else:
|
||||
seq_name = str(ind)
|
||||
|
||||
vis.visualize(
|
||||
sample.video,
|
||||
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
|
||||
filename=dataset_name + "_" + seq_name,
|
||||
writer=writer,
|
||||
step=step,
|
||||
)
|
||||
|
||||
self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
|
||||
return metrics
|
179
cotracker/evaluation/evaluate.py
Normal file
179
cotracker/evaluation/evaluate.py
Normal file
@@ -0,0 +1,179 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from cotracker.datasets.badja_dataset import BadjaDataset
|
||||
from cotracker.datasets.fast_capture_dataset import FastCaptureDataset
|
||||
from cotracker.datasets.tap_vid_datasets import TapVidDataset
|
||||
from cotracker.datasets.utils import collate_fn
|
||||
|
||||
from cotracker.models.evaluation_predictor import EvaluationPredictor
|
||||
|
||||
from cotracker.evaluation.core.evaluator import Evaluator
|
||||
from cotracker.models.build_cotracker import (
|
||||
build_cotracker,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(eq=False)
|
||||
class DefaultConfig:
|
||||
# Directory where all outputs of the experiment will be saved.
|
||||
exp_dir: str = "./outputs"
|
||||
|
||||
# Name of the dataset to be used for the evaluation.
|
||||
dataset_name: str = "badja"
|
||||
# The root directory of the dataset.
|
||||
dataset_root: str = "./"
|
||||
|
||||
# Path to the pre-trained model checkpoint to be used for the evaluation.
|
||||
# The default value is the path to a specific CoTracker model checkpoint.
|
||||
# Other available options are commented.
|
||||
checkpoint: str = "./checkpoints/cotracker_stride_4_wind_8.pth"
|
||||
# cotracker_stride_4_wind_12
|
||||
# cotracker_stride_8_wind_16
|
||||
|
||||
# EvaluationPredictor parameters
|
||||
# The size (N) of the support grid used in the predictor.
|
||||
# The total number of points is (N*N).
|
||||
grid_size: int = 6
|
||||
# The size (N) of the local support grid.
|
||||
local_grid_size: int = 6
|
||||
# A flag indicating whether to evaluate one ground truth point at a time.
|
||||
single_point: bool = True
|
||||
# The number of iterative updates for each sliding window.
|
||||
n_iters: int = 6
|
||||
|
||||
seed: int = 0
|
||||
gpu_idx: int = 0
|
||||
|
||||
# Override hydra's working directory to current working dir,
|
||||
# also disable storing the .hydra logs:
|
||||
hydra: dict = field(
|
||||
default_factory=lambda: {
|
||||
"run": {"dir": "."},
|
||||
"output_subdir": None,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def run_eval(cfg: DefaultConfig):
|
||||
"""
|
||||
The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
|
||||
|
||||
Args:
|
||||
cfg (DefaultConfig): An instance of DefaultConfig class which includes:
|
||||
- exp_dir (str): The directory path for the experiment.
|
||||
- dataset_name (str): The name of the dataset to be used.
|
||||
- dataset_root (str): The root directory of the dataset.
|
||||
- checkpoint (str): The path to the CoTracker model's checkpoint.
|
||||
- single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
|
||||
- n_iters (int): The number of iterative updates for each sliding window.
|
||||
- seed (int): The seed for setting the random state for reproducibility.
|
||||
- gpu_idx (int): The index of the GPU to be used.
|
||||
"""
|
||||
# Creating the experiment directory if it doesn't exist
|
||||
os.makedirs(cfg.exp_dir, exist_ok=True)
|
||||
|
||||
# Saving the experiment configuration to a .yaml file in the experiment directory
|
||||
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
|
||||
with open(cfg_file, "w") as f:
|
||||
OmegaConf.save(config=cfg, f=f)
|
||||
|
||||
evaluator = Evaluator(cfg.exp_dir)
|
||||
cotracker_model = build_cotracker(cfg.checkpoint)
|
||||
|
||||
# Creating the EvaluationPredictor object
|
||||
predictor = EvaluationPredictor(
|
||||
cotracker_model,
|
||||
grid_size=cfg.grid_size,
|
||||
local_grid_size=cfg.local_grid_size,
|
||||
single_point=cfg.single_point,
|
||||
n_iters=cfg.n_iters,
|
||||
)
|
||||
|
||||
# Setting the random seeds
|
||||
torch.manual_seed(cfg.seed)
|
||||
np.random.seed(cfg.seed)
|
||||
|
||||
# Constructing the specified dataset
|
||||
curr_collate_fn = collate_fn
|
||||
if cfg.dataset_name == "badja":
|
||||
test_dataset = BadjaDataset(data_root=os.path.join(cfg.dataset_root, "BADJA"))
|
||||
elif cfg.dataset_name == "fastcapture":
|
||||
test_dataset = FastCaptureDataset(
|
||||
data_root=os.path.join(cfg.dataset_root, "fastcapture"),
|
||||
max_seq_len=100,
|
||||
max_num_points=20,
|
||||
)
|
||||
elif "tapvid" in cfg.dataset_name:
|
||||
dataset_type = cfg.dataset_name.split("_")[1]
|
||||
if dataset_type == "davis":
|
||||
data_root = os.path.join(cfg.dataset_root, "/tapvid_davis/tapvid_davis.pkl")
|
||||
elif dataset_type == "kinetics":
|
||||
data_root = os.path.join(
|
||||
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
|
||||
)
|
||||
test_dataset = TapVidDataset(
|
||||
dataset_type=dataset_type,
|
||||
data_root=data_root,
|
||||
queried_first=not "strided" in cfg.dataset_name,
|
||||
)
|
||||
|
||||
# Creating the DataLoader object
|
||||
test_dataloader = torch.utils.data.DataLoader(
|
||||
test_dataset,
|
||||
batch_size=1,
|
||||
shuffle=False,
|
||||
num_workers=14,
|
||||
collate_fn=curr_collate_fn,
|
||||
)
|
||||
|
||||
# Timing and conducting the evaluation
|
||||
import time
|
||||
|
||||
start = time.time()
|
||||
evaluate_result = evaluator.evaluate_sequence(
|
||||
predictor,
|
||||
test_dataloader,
|
||||
dataset_name=cfg.dataset_name,
|
||||
)
|
||||
end = time.time()
|
||||
print(end - start)
|
||||
|
||||
# Saving the evaluation results to a .json file
|
||||
if not "tapvid" in cfg.dataset_name:
|
||||
print("evaluate_result", evaluate_result)
|
||||
else:
|
||||
evaluate_result = evaluate_result["avg"]
|
||||
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
|
||||
evaluate_result["time"] = end - start
|
||||
print(f"Dumping eval results to {result_file}.")
|
||||
with open(result_file, "w") as f:
|
||||
json.dump(evaluate_result, f)
|
||||
|
||||
|
||||
cs = hydra.core.config_store.ConfigStore.instance()
|
||||
cs.store(name="default_config_eval", node=DefaultConfig)
|
||||
|
||||
|
||||
@hydra.main(config_path="./configs/", config_name="default_config_eval")
|
||||
def evaluate(cfg: DefaultConfig) -> None:
|
||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
||||
run_eval(cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
evaluate()
|
5
cotracker/models/__init__.py
Normal file
5
cotracker/models/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
70
cotracker/models/build_cotracker.py
Normal file
70
cotracker/models/build_cotracker.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker
|
||||
|
||||
|
||||
def build_cotracker(
|
||||
checkpoint: str,
|
||||
):
|
||||
model_name = checkpoint.split("/")[-1].split(".")[0]
|
||||
if model_name == "cotracker_stride_4_wind_8":
|
||||
return build_cotracker_stride_4_wind_8(checkpoint=checkpoint)
|
||||
elif model_name == "cotracker_stride_4_wind_12":
|
||||
return build_cotracker_stride_4_wind_12(checkpoint=checkpoint)
|
||||
elif model_name == "cotracker_stride_8_wind_16":
|
||||
return build_cotracker_stride_8_wind_16(checkpoint=checkpoint)
|
||||
else:
|
||||
raise ValueError(f"Unknown model name {model_name}")
|
||||
|
||||
|
||||
# model used to produce the results in the paper
|
||||
def build_cotracker_stride_4_wind_8(checkpoint=None):
|
||||
return _build_cotracker(
|
||||
stride=4,
|
||||
sequence_len=8,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def build_cotracker_stride_4_wind_12(checkpoint=None):
|
||||
return _build_cotracker(
|
||||
stride=4,
|
||||
sequence_len=12,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
# the fastest model
|
||||
def build_cotracker_stride_8_wind_16(checkpoint=None):
|
||||
return _build_cotracker(
|
||||
stride=8,
|
||||
sequence_len=16,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def _build_cotracker(
|
||||
stride,
|
||||
sequence_len,
|
||||
checkpoint=None,
|
||||
):
|
||||
cotracker = CoTracker(
|
||||
stride=stride,
|
||||
S=sequence_len,
|
||||
add_space_attn=True,
|
||||
space_depth=6,
|
||||
time_depth=6,
|
||||
)
|
||||
if checkpoint is not None:
|
||||
with open(checkpoint, "rb") as f:
|
||||
state_dict = torch.load(f, map_location="cpu")
|
||||
if "model" in state_dict:
|
||||
state_dict = state_dict["model"]
|
||||
cotracker.load_state_dict(state_dict)
|
||||
return cotracker
|
5
cotracker/models/core/__init__.py
Normal file
5
cotracker/models/core/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
5
cotracker/models/core/cotracker/__init__.py
Normal file
5
cotracker/models/core/cotracker/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
400
cotracker/models/core/cotracker/blocks.py
Normal file
400
cotracker/models/core/cotracker/blocks.py
Normal file
@@ -0,0 +1,400 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from einops import rearrange
|
||||
from timm.models.vision_transformer import Attention, Mlp
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_planes,
|
||||
planes,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
stride=stride,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.conv2 = nn.Conv2d(
|
||||
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
|
||||
)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
num_groups = planes // 8
|
||||
|
||||
if norm_fn == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
|
||||
elif norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2d(planes)
|
||||
self.norm2 = nn.BatchNorm2d(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.BatchNorm2d(planes)
|
||||
|
||||
elif norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2d(planes)
|
||||
self.norm2 = nn.InstanceNorm2d(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.InstanceNorm2d(planes)
|
||||
|
||||
elif norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
self.norm2 = nn.Sequential()
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.Sequential()
|
||||
|
||||
if stride == 1:
|
||||
self.downsample = None
|
||||
|
||||
else:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
y = self.relu(self.norm1(self.conv1(y)))
|
||||
y = self.relu(self.norm2(self.conv2(y)))
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
return self.relu(x + y)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Module):
|
||||
def __init__(
|
||||
self, input_dim=3, output_dim=128, stride=8, norm_fn="batch", dropout=0.0
|
||||
):
|
||||
super(BasicEncoder, self).__init__()
|
||||
self.stride = stride
|
||||
self.norm_fn = norm_fn
|
||||
self.in_planes = 64
|
||||
|
||||
if self.norm_fn == "group":
|
||||
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=self.in_planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups=8, num_channels=output_dim * 2)
|
||||
|
||||
elif self.norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2d(self.in_planes)
|
||||
self.norm2 = nn.BatchNorm2d(output_dim * 2)
|
||||
|
||||
elif self.norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
||||
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
||||
|
||||
elif self.norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
|
||||
self.conv1 = nn.Conv2d(
|
||||
input_dim,
|
||||
self.in_planes,
|
||||
kernel_size=7,
|
||||
stride=2,
|
||||
padding=3,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
|
||||
self.shallow = False
|
||||
if self.shallow:
|
||||
self.layer1 = self._make_layer(64, stride=1)
|
||||
self.layer2 = self._make_layer(96, stride=2)
|
||||
self.layer3 = self._make_layer(128, stride=2)
|
||||
self.conv2 = nn.Conv2d(128 + 96 + 64, output_dim, kernel_size=1)
|
||||
else:
|
||||
self.layer1 = self._make_layer(64, stride=1)
|
||||
self.layer2 = self._make_layer(96, stride=2)
|
||||
self.layer3 = self._make_layer(128, stride=2)
|
||||
self.layer4 = self._make_layer(128, stride=2)
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
128 + 128 + 96 + 64,
|
||||
output_dim * 2,
|
||||
kernel_size=3,
|
||||
padding=1,
|
||||
padding_mode="zeros",
|
||||
)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
||||
|
||||
self.dropout = None
|
||||
if dropout > 0:
|
||||
self.dropout = nn.Dropout2d(p=dropout)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||
if m.weight is not None:
|
||||
nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, dim, stride=1):
|
||||
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||
layers = (layer1, layer2)
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
_, _, H, W = x.shape
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
if self.shallow:
|
||||
a = self.layer1(x)
|
||||
b = self.layer2(a)
|
||||
c = self.layer3(b)
|
||||
a = F.interpolate(
|
||||
a,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
b = F.interpolate(
|
||||
b,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
c = F.interpolate(
|
||||
c,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
x = self.conv2(torch.cat([a, b, c], dim=1))
|
||||
else:
|
||||
a = self.layer1(x)
|
||||
b = self.layer2(a)
|
||||
c = self.layer3(b)
|
||||
d = self.layer4(c)
|
||||
a = F.interpolate(
|
||||
a,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
b = F.interpolate(
|
||||
b,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
c = F.interpolate(
|
||||
c,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
d = F.interpolate(
|
||||
d,
|
||||
(H // self.stride, W // self.stride),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
||||
x = self.norm2(x)
|
||||
x = self.relu2(x)
|
||||
x = self.conv3(x)
|
||||
|
||||
if self.training and self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
return x
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
"""
|
||||
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.attn = Attention(
|
||||
hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
|
||||
)
|
||||
|
||||
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=approx_gelu,
|
||||
drop=0,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.attn(self.norm1(x))
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
|
||||
"""Wrapper for grid_sample, uses pixel coordinates"""
|
||||
H, W = img.shape[-2:]
|
||||
xgrid, ygrid = coords.split([1, 1], dim=-1)
|
||||
# go to 0,1 then 0,2 then -1,1
|
||||
xgrid = 2 * xgrid / (W - 1) - 1
|
||||
ygrid = 2 * ygrid / (H - 1) - 1
|
||||
|
||||
grid = torch.cat([xgrid, ygrid], dim=-1)
|
||||
img = F.grid_sample(img, grid, align_corners=True)
|
||||
|
||||
if mask:
|
||||
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
||||
return img, mask.float()
|
||||
|
||||
return img
|
||||
|
||||
|
||||
class CorrBlock:
|
||||
def __init__(self, fmaps, num_levels=4, radius=4):
|
||||
B, S, C, H, W = fmaps.shape
|
||||
self.S, self.C, self.H, self.W = S, C, H, W
|
||||
|
||||
self.num_levels = num_levels
|
||||
self.radius = radius
|
||||
self.fmaps_pyramid = []
|
||||
|
||||
self.fmaps_pyramid.append(fmaps)
|
||||
for i in range(self.num_levels - 1):
|
||||
fmaps_ = fmaps.reshape(B * S, C, H, W)
|
||||
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
||||
_, _, H, W = fmaps_.shape
|
||||
fmaps = fmaps_.reshape(B, S, C, H, W)
|
||||
self.fmaps_pyramid.append(fmaps)
|
||||
|
||||
def sample(self, coords):
|
||||
r = self.radius
|
||||
B, S, N, D = coords.shape
|
||||
assert D == 2
|
||||
|
||||
H, W = self.H, self.W
|
||||
out_pyramid = []
|
||||
for i in range(self.num_levels):
|
||||
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
||||
_, _, _, H, W = corrs.shape
|
||||
|
||||
dx = torch.linspace(-r, r, 2 * r + 1)
|
||||
dy = torch.linspace(-r, r, 2 * r + 1)
|
||||
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
|
||||
coords.device
|
||||
)
|
||||
|
||||
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
|
||||
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
||||
coords_lvl = centroid_lvl + delta_lvl
|
||||
|
||||
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
|
||||
corrs = corrs.view(B, S, N, -1)
|
||||
out_pyramid.append(corrs)
|
||||
|
||||
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
||||
return out.contiguous().float()
|
||||
|
||||
def corr(self, targets):
|
||||
B, S, N, C = targets.shape
|
||||
assert C == self.C
|
||||
assert S == self.S
|
||||
|
||||
fmap1 = targets
|
||||
|
||||
self.corrs_pyramid = []
|
||||
for fmaps in self.fmaps_pyramid:
|
||||
_, _, _, H, W = fmaps.shape
|
||||
fmap2s = fmaps.view(B, S, C, H * W)
|
||||
corrs = torch.matmul(fmap1, fmap2s)
|
||||
corrs = corrs.view(B, S, N, H, W)
|
||||
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
||||
self.corrs_pyramid.append(corrs)
|
||||
|
||||
|
||||
class UpdateFormer(nn.Module):
|
||||
"""
|
||||
Transformer model that updates track estimates.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
space_depth=12,
|
||||
time_depth=12,
|
||||
input_dim=320,
|
||||
hidden_size=384,
|
||||
num_heads=8,
|
||||
output_dim=130,
|
||||
mlp_ratio=4.0,
|
||||
add_space_attn=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.out_channels = 2
|
||||
self.num_heads = num_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.add_space_attn = add_space_attn
|
||||
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
||||
self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
|
||||
|
||||
self.time_blocks = nn.ModuleList(
|
||||
[
|
||||
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
||||
for _ in range(time_depth)
|
||||
]
|
||||
)
|
||||
|
||||
if add_space_attn:
|
||||
self.space_blocks = nn.ModuleList(
|
||||
[
|
||||
AttnBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio)
|
||||
for _ in range(space_depth)
|
||||
]
|
||||
)
|
||||
assert len(self.time_blocks) >= len(self.space_blocks)
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
def _basic_init(module):
|
||||
if isinstance(module, nn.Linear):
|
||||
torch.nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0)
|
||||
|
||||
self.apply(_basic_init)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
x = self.input_transform(input_tensor)
|
||||
|
||||
j = 0
|
||||
for i in range(len(self.time_blocks)):
|
||||
B, N, T, _ = x.shape
|
||||
x_time = rearrange(x, "b n t c -> (b n) t c", b=B, t=T, n=N)
|
||||
x_time = self.time_blocks[i](x_time)
|
||||
|
||||
x = rearrange(x_time, "(b n) t c -> b n t c ", b=B, t=T, n=N)
|
||||
if self.add_space_attn and (
|
||||
i % (len(self.time_blocks) // len(self.space_blocks)) == 0
|
||||
):
|
||||
x_space = rearrange(x, "b n t c -> (b t) n c ", b=B, t=T, n=N)
|
||||
x_space = self.space_blocks[j](x_space)
|
||||
x = rearrange(x_space, "(b t) n c -> b n t c ", b=B, t=T, n=N)
|
||||
j += 1
|
||||
|
||||
flow = self.flow_head(x)
|
||||
return flow
|
351
cotracker/models/core/cotracker/cotracker.py
Normal file
351
cotracker/models/core/cotracker/cotracker.py
Normal file
@@ -0,0 +1,351 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
from cotracker.models.core.cotracker.blocks import (
|
||||
BasicEncoder,
|
||||
CorrBlock,
|
||||
UpdateFormer,
|
||||
)
|
||||
|
||||
from cotracker.models.core.model_utils import meshgrid2d, bilinear_sample2d, smart_cat
|
||||
from cotracker.models.core.embeddings import (
|
||||
get_2d_embedding,
|
||||
get_1d_sincos_pos_embed_from_grid,
|
||||
get_2d_sincos_pos_embed,
|
||||
)
|
||||
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
|
||||
def get_points_on_a_grid(grid_size, interp_shape, grid_center=(0, 0)):
|
||||
if grid_size == 1:
|
||||
return torch.tensor([interp_shape[1] / 2, interp_shape[0] / 2])[
|
||||
None, None
|
||||
].cuda()
|
||||
|
||||
grid_y, grid_x = meshgrid2d(
|
||||
1, grid_size, grid_size, stack=False, norm=False, device="cuda"
|
||||
)
|
||||
step = interp_shape[1] // 64
|
||||
if grid_center[0] != 0 or grid_center[1] != 0:
|
||||
grid_y = grid_y - grid_size / 2.0
|
||||
grid_x = grid_x - grid_size / 2.0
|
||||
grid_y = step + grid_y.reshape(1, -1) / float(grid_size - 1) * (
|
||||
interp_shape[0] - step * 2
|
||||
)
|
||||
grid_x = step + grid_x.reshape(1, -1) / float(grid_size - 1) * (
|
||||
interp_shape[1] - step * 2
|
||||
)
|
||||
|
||||
grid_y = grid_y + grid_center[0]
|
||||
grid_x = grid_x + grid_center[1]
|
||||
xy = torch.stack([grid_x, grid_y], dim=-1).cuda()
|
||||
return xy
|
||||
|
||||
|
||||
def sample_pos_embed(grid_size, embed_dim, coords):
|
||||
pos_embed = get_2d_sincos_pos_embed(embed_dim=embed_dim, grid_size=grid_size)
|
||||
pos_embed = (
|
||||
torch.from_numpy(pos_embed)
|
||||
.reshape(grid_size[0], grid_size[1], embed_dim)
|
||||
.float()
|
||||
.unsqueeze(0)
|
||||
.to(coords.device)
|
||||
)
|
||||
sampled_pos_embed = bilinear_sample2d(
|
||||
pos_embed.permute(0, 3, 1, 2), coords[:, 0, :, 0], coords[:, 0, :, 1]
|
||||
)
|
||||
return sampled_pos_embed
|
||||
|
||||
|
||||
class CoTracker(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
S=8,
|
||||
stride=8,
|
||||
add_space_attn=True,
|
||||
num_heads=8,
|
||||
hidden_size=384,
|
||||
space_depth=12,
|
||||
time_depth=12,
|
||||
):
|
||||
super(CoTracker, self).__init__()
|
||||
self.S = S
|
||||
self.stride = stride
|
||||
self.hidden_dim = 256
|
||||
self.latent_dim = latent_dim = 128
|
||||
self.corr_levels = 4
|
||||
self.corr_radius = 3
|
||||
self.add_space_attn = add_space_attn
|
||||
self.fnet = BasicEncoder(
|
||||
output_dim=self.latent_dim, norm_fn="instance", dropout=0, stride=stride
|
||||
)
|
||||
|
||||
self.updateformer = UpdateFormer(
|
||||
space_depth=space_depth,
|
||||
time_depth=time_depth,
|
||||
input_dim=456,
|
||||
hidden_size=hidden_size,
|
||||
num_heads=num_heads,
|
||||
output_dim=latent_dim + 2,
|
||||
mlp_ratio=4.0,
|
||||
add_space_attn=add_space_attn,
|
||||
)
|
||||
|
||||
self.norm = nn.GroupNorm(1, self.latent_dim)
|
||||
self.ffeat_updater = nn.Sequential(
|
||||
nn.Linear(self.latent_dim, self.latent_dim),
|
||||
nn.GELU(),
|
||||
)
|
||||
self.vis_predictor = nn.Sequential(
|
||||
nn.Linear(self.latent_dim, 1),
|
||||
)
|
||||
|
||||
def forward_iteration(
|
||||
self,
|
||||
fmaps,
|
||||
coords_init,
|
||||
feat_init=None,
|
||||
vis_init=None,
|
||||
track_mask=None,
|
||||
iters=4,
|
||||
):
|
||||
B, S_init, N, D = coords_init.shape
|
||||
assert D == 2
|
||||
assert B == 1
|
||||
|
||||
B, S, __, H8, W8 = fmaps.shape
|
||||
|
||||
device = fmaps.device
|
||||
|
||||
if S_init < S:
|
||||
coords = torch.cat(
|
||||
[coords_init, coords_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
|
||||
)
|
||||
vis_init = torch.cat(
|
||||
[vis_init, vis_init[:, -1].repeat(1, S - S_init, 1, 1)], dim=1
|
||||
)
|
||||
else:
|
||||
coords = coords_init.clone()
|
||||
|
||||
fcorr_fn = CorrBlock(
|
||||
fmaps, num_levels=self.corr_levels, radius=self.corr_radius
|
||||
)
|
||||
|
||||
ffeats = feat_init.clone()
|
||||
|
||||
times_ = torch.linspace(0, S - 1, S).reshape(1, S, 1)
|
||||
|
||||
pos_embed = sample_pos_embed(
|
||||
grid_size=(H8, W8),
|
||||
embed_dim=456,
|
||||
coords=coords,
|
||||
)
|
||||
pos_embed = rearrange(pos_embed, "b e n -> (b n) e").unsqueeze(1)
|
||||
times_embed = (
|
||||
torch.from_numpy(get_1d_sincos_pos_embed_from_grid(456, times_[0]))[None]
|
||||
.repeat(B, 1, 1)
|
||||
.float()
|
||||
.to(device)
|
||||
)
|
||||
coord_predictions = []
|
||||
|
||||
for __ in range(iters):
|
||||
coords = coords.detach()
|
||||
fcorr_fn.corr(ffeats)
|
||||
|
||||
fcorrs = fcorr_fn.sample(coords) # B, S, N, LRR
|
||||
LRR = fcorrs.shape[3]
|
||||
|
||||
fcorrs_ = fcorrs.permute(0, 2, 1, 3).reshape(B * N, S, LRR)
|
||||
flows_ = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
|
||||
|
||||
flows_cat = get_2d_embedding(flows_, 64, cat_coords=True)
|
||||
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
|
||||
|
||||
if track_mask.shape[1] < vis_init.shape[1]:
|
||||
track_mask = torch.cat(
|
||||
[
|
||||
track_mask,
|
||||
torch.zeros_like(track_mask[:, 0]).repeat(
|
||||
1, vis_init.shape[1] - track_mask.shape[1], 1, 1
|
||||
),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
concat = (
|
||||
torch.cat([track_mask, vis_init], dim=2)
|
||||
.permute(0, 2, 1, 3)
|
||||
.reshape(B * N, S, 2)
|
||||
)
|
||||
|
||||
transformer_input = torch.cat([flows_cat, fcorrs_, ffeats_, concat], dim=2)
|
||||
x = transformer_input + pos_embed + times_embed
|
||||
|
||||
x = rearrange(x, "(b n) t d -> b n t d", b=B)
|
||||
|
||||
delta = self.updateformer(x)
|
||||
|
||||
delta = rearrange(delta, " b n t d -> (b n) t d")
|
||||
|
||||
delta_coords_ = delta[:, :, :2]
|
||||
delta_feats_ = delta[:, :, 2:]
|
||||
|
||||
delta_feats_ = delta_feats_.reshape(B * N * S, self.latent_dim)
|
||||
ffeats_ = ffeats.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
|
||||
|
||||
ffeats_ = self.ffeat_updater(self.norm(delta_feats_)) + ffeats_
|
||||
|
||||
ffeats = ffeats_.reshape(B, N, S, self.latent_dim).permute(
|
||||
0, 2, 1, 3
|
||||
) # B,S,N,C
|
||||
|
||||
coords = coords + delta_coords_.reshape(B, N, S, 2).permute(0, 2, 1, 3)
|
||||
coord_predictions.append(coords * self.stride)
|
||||
|
||||
vis_e = self.vis_predictor(ffeats.reshape(B * S * N, self.latent_dim)).reshape(
|
||||
B, S, N
|
||||
)
|
||||
return coord_predictions, vis_e, feat_init
|
||||
|
||||
def forward(self, rgbs, queries, iters=4, feat_init=None, is_train=False):
|
||||
B, T, C, H, W = rgbs.shape
|
||||
B, N, __ = queries.shape
|
||||
|
||||
device = rgbs.device
|
||||
assert B == 1
|
||||
# INIT for the first sequence
|
||||
# We want to sort points by the first frame they are visible to add them to the tensor of tracked points consequtively
|
||||
first_positive_inds = queries[:, :, 0].long()
|
||||
|
||||
__, sort_inds = torch.sort(first_positive_inds[0], dim=0, descending=False)
|
||||
inv_sort_inds = torch.argsort(sort_inds, dim=0)
|
||||
first_positive_sorted_inds = first_positive_inds[0][sort_inds]
|
||||
|
||||
assert torch.allclose(
|
||||
first_positive_inds[0], first_positive_inds[0][sort_inds][inv_sort_inds]
|
||||
)
|
||||
|
||||
coords_init = queries[:, :, 1:].reshape(B, 1, N, 2).repeat(
|
||||
1, self.S, 1, 1
|
||||
) / float(self.stride)
|
||||
|
||||
rgbs = 2 * (rgbs / 255.0) - 1.0
|
||||
|
||||
traj_e = torch.zeros((B, T, N, 2), device=device)
|
||||
vis_e = torch.zeros((B, T, N), device=device)
|
||||
|
||||
ind_array = torch.arange(T, device=device)
|
||||
ind_array = ind_array[None, :, None].repeat(B, 1, N)
|
||||
|
||||
track_mask = (ind_array >= first_positive_inds[:, None, :]).unsqueeze(-1)
|
||||
# these are logits, so we initialize visibility with something that would give a value close to 1 after softmax
|
||||
vis_init = torch.ones((B, self.S, N, 1), device=device).float() * 10
|
||||
|
||||
ind = 0
|
||||
|
||||
track_mask_ = track_mask[:, :, sort_inds].clone()
|
||||
coords_init_ = coords_init[:, :, sort_inds].clone()
|
||||
vis_init_ = vis_init[:, :, sort_inds].clone()
|
||||
|
||||
prev_wind_idx = 0
|
||||
fmaps_ = None
|
||||
vis_predictions = []
|
||||
coord_predictions = []
|
||||
wind_inds = []
|
||||
while ind < T - self.S // 2:
|
||||
rgbs_seq = rgbs[:, ind : ind + self.S]
|
||||
|
||||
S = S_local = rgbs_seq.shape[1]
|
||||
if S < self.S:
|
||||
rgbs_seq = torch.cat(
|
||||
[rgbs_seq, rgbs_seq[:, -1, None].repeat(1, self.S - S, 1, 1, 1)],
|
||||
dim=1,
|
||||
)
|
||||
S = rgbs_seq.shape[1]
|
||||
rgbs_ = rgbs_seq.reshape(B * S, C, H, W)
|
||||
|
||||
if fmaps_ is None:
|
||||
fmaps_ = self.fnet(rgbs_)
|
||||
else:
|
||||
fmaps_ = torch.cat(
|
||||
[fmaps_[self.S // 2 :], self.fnet(rgbs_[self.S // 2 :])], dim=0
|
||||
)
|
||||
fmaps = fmaps_.reshape(
|
||||
B, S, self.latent_dim, H // self.stride, W // self.stride
|
||||
)
|
||||
|
||||
curr_wind_points = torch.nonzero(first_positive_sorted_inds < ind + self.S)
|
||||
if curr_wind_points.shape[0] == 0:
|
||||
ind = ind + self.S // 2
|
||||
continue
|
||||
wind_idx = curr_wind_points[-1] + 1
|
||||
|
||||
if wind_idx - prev_wind_idx > 0:
|
||||
fmaps_sample = fmaps[
|
||||
:, first_positive_sorted_inds[prev_wind_idx:wind_idx] - ind
|
||||
]
|
||||
|
||||
feat_init_ = bilinear_sample2d(
|
||||
fmaps_sample,
|
||||
coords_init_[:, 0, prev_wind_idx:wind_idx, 0],
|
||||
coords_init_[:, 0, prev_wind_idx:wind_idx, 1],
|
||||
).permute(0, 2, 1)
|
||||
|
||||
feat_init_ = feat_init_.unsqueeze(1).repeat(1, self.S, 1, 1)
|
||||
feat_init = smart_cat(feat_init, feat_init_, dim=2)
|
||||
|
||||
if prev_wind_idx > 0:
|
||||
new_coords = coords[-1][:, self.S // 2 :] / float(self.stride)
|
||||
|
||||
coords_init_[:, : self.S // 2, :prev_wind_idx] = new_coords
|
||||
coords_init_[:, self.S // 2 :, :prev_wind_idx] = new_coords[
|
||||
:, -1
|
||||
].repeat(1, self.S // 2, 1, 1)
|
||||
|
||||
new_vis = vis[:, self.S // 2 :].unsqueeze(-1)
|
||||
vis_init_[:, : self.S // 2, :prev_wind_idx] = new_vis
|
||||
vis_init_[:, self.S // 2 :, :prev_wind_idx] = new_vis[:, -1].repeat(
|
||||
1, self.S // 2, 1, 1
|
||||
)
|
||||
|
||||
coords, vis, __ = self.forward_iteration(
|
||||
fmaps=fmaps,
|
||||
coords_init=coords_init_[:, :, :wind_idx],
|
||||
feat_init=feat_init[:, :, :wind_idx],
|
||||
vis_init=vis_init_[:, :, :wind_idx],
|
||||
track_mask=track_mask_[:, ind : ind + self.S, :wind_idx],
|
||||
iters=iters,
|
||||
)
|
||||
if is_train:
|
||||
vis_predictions.append(torch.sigmoid(vis[:, :S_local]))
|
||||
coord_predictions.append([coord[:, :S_local] for coord in coords])
|
||||
wind_inds.append(wind_idx)
|
||||
|
||||
traj_e[:, ind : ind + self.S, :wind_idx] = coords[-1][:, :S_local]
|
||||
vis_e[:, ind : ind + self.S, :wind_idx] = vis[:, :S_local]
|
||||
|
||||
track_mask_[:, : ind + self.S, :wind_idx] = 0.0
|
||||
ind = ind + self.S // 2
|
||||
|
||||
prev_wind_idx = wind_idx
|
||||
|
||||
traj_e = traj_e[:, :, inv_sort_inds]
|
||||
vis_e = vis_e[:, :, inv_sort_inds]
|
||||
|
||||
vis_e = torch.sigmoid(vis_e)
|
||||
|
||||
train_data = (
|
||||
(vis_predictions, coord_predictions, wind_inds, sort_inds)
|
||||
if is_train
|
||||
else None
|
||||
)
|
||||
return traj_e, feat_init, vis_e, train_data
|
61
cotracker/models/core/cotracker/losses.py
Normal file
61
cotracker/models/core/cotracker/losses.py
Normal file
@@ -0,0 +1,61 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from cotracker.models.core.model_utils import reduce_masked_mean
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def balanced_ce_loss(pred, gt, valid=None):
|
||||
total_balanced_loss = 0.0
|
||||
for j in range(len(gt)):
|
||||
B, S, N = gt[j].shape
|
||||
# pred and gt are the same shape
|
||||
for (a, b) in zip(pred[j].size(), gt[j].size()):
|
||||
assert a == b # some shape mismatch!
|
||||
# if valid is not None:
|
||||
for (a, b) in zip(pred[j].size(), valid[j].size()):
|
||||
assert a == b # some shape mismatch!
|
||||
|
||||
pos = (gt[j] > 0.95).float()
|
||||
neg = (gt[j] < 0.05).float()
|
||||
|
||||
label = pos * 2.0 - 1.0
|
||||
a = -label * pred[j]
|
||||
b = F.relu(a)
|
||||
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
|
||||
|
||||
pos_loss = reduce_masked_mean(loss, pos * valid[j])
|
||||
neg_loss = reduce_masked_mean(loss, neg * valid[j])
|
||||
|
||||
balanced_loss = pos_loss + neg_loss
|
||||
total_balanced_loss += balanced_loss / float(N)
|
||||
return total_balanced_loss
|
||||
|
||||
|
||||
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
|
||||
"""Loss function defined over sequence of flow predictions"""
|
||||
total_flow_loss = 0.0
|
||||
for j in range(len(flow_gt)):
|
||||
B, S, N, D = flow_gt[j].shape
|
||||
assert D == 2
|
||||
B, S1, N = vis[j].shape
|
||||
B, S2, N = valids[j].shape
|
||||
assert S == S1
|
||||
assert S == S2
|
||||
n_predictions = len(flow_preds[j])
|
||||
flow_loss = 0.0
|
||||
for i in range(n_predictions):
|
||||
i_weight = gamma ** (n_predictions - i - 1)
|
||||
flow_pred = flow_preds[j][i]
|
||||
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
|
||||
i_loss = torch.mean(i_loss, dim=3) # B, S, N
|
||||
flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
|
||||
flow_loss = flow_loss / n_predictions
|
||||
total_flow_loss += flow_loss / float(N)
|
||||
return total_flow_loss
|
154
cotracker/models/core/embeddings.py
Normal file
154
cotracker/models/core/embeddings.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
||||
"""
|
||||
grid_size: int of the grid height and width
|
||||
return:
|
||||
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
if isinstance(grid_size, tuple):
|
||||
grid_size_h, grid_size_w = grid_size
|
||||
else:
|
||||
grid_size_h = grid_size_w = grid_size
|
||||
grid_h = np.arange(grid_size_h, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size_w, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate(
|
||||
[np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0
|
||||
)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
assert embed_dim % 2 == 0
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position
|
||||
pos: a list of positions to be encoded: size (M,)
|
||||
out: (M, D)
|
||||
"""
|
||||
assert embed_dim % 2 == 0
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000 ** omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_2d_embedding(xy, C, cat_coords=True):
|
||||
B, N, D = xy.shape
|
||||
assert D == 2
|
||||
|
||||
x = xy[:, :, 0:1]
|
||||
y = xy[:, :, 1:2]
|
||||
div_term = (
|
||||
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
|
||||
).reshape(1, 1, int(C / 2))
|
||||
|
||||
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
||||
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
|
||||
|
||||
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
||||
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
||||
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y], dim=2) # B, N, C*3
|
||||
if cat_coords:
|
||||
pe = torch.cat([xy, pe], dim=2) # B, N, C*3+3
|
||||
return pe
|
||||
|
||||
|
||||
def get_3d_embedding(xyz, C, cat_coords=True):
|
||||
B, N, D = xyz.shape
|
||||
assert D == 3
|
||||
|
||||
x = xyz[:, :, 0:1]
|
||||
y = xyz[:, :, 1:2]
|
||||
z = xyz[:, :, 2:3]
|
||||
div_term = (
|
||||
torch.arange(0, C, 2, device=xyz.device, dtype=torch.float32) * (1000.0 / C)
|
||||
).reshape(1, 1, int(C / 2))
|
||||
|
||||
pe_x = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
||||
pe_y = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
||||
pe_z = torch.zeros(B, N, C, device=xyz.device, dtype=torch.float32)
|
||||
|
||||
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
||||
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
||||
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe_z[:, :, 0::2] = torch.sin(z * div_term)
|
||||
pe_z[:, :, 1::2] = torch.cos(z * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y, pe_z], dim=2) # B, N, C*3
|
||||
if cat_coords:
|
||||
pe = torch.cat([pe, xyz], dim=2) # B, N, C*3+3
|
||||
return pe
|
||||
|
||||
|
||||
def get_4d_embedding(xyzw, C, cat_coords=True):
|
||||
B, N, D = xyzw.shape
|
||||
assert D == 4
|
||||
|
||||
x = xyzw[:, :, 0:1]
|
||||
y = xyzw[:, :, 1:2]
|
||||
z = xyzw[:, :, 2:3]
|
||||
w = xyzw[:, :, 3:4]
|
||||
div_term = (
|
||||
torch.arange(0, C, 2, device=xyzw.device, dtype=torch.float32) * (1000.0 / C)
|
||||
).reshape(1, 1, int(C / 2))
|
||||
|
||||
pe_x = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
pe_y = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
pe_z = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
pe_w = torch.zeros(B, N, C, device=xyzw.device, dtype=torch.float32)
|
||||
|
||||
pe_x[:, :, 0::2] = torch.sin(x * div_term)
|
||||
pe_x[:, :, 1::2] = torch.cos(x * div_term)
|
||||
|
||||
pe_y[:, :, 0::2] = torch.sin(y * div_term)
|
||||
pe_y[:, :, 1::2] = torch.cos(y * div_term)
|
||||
|
||||
pe_z[:, :, 0::2] = torch.sin(z * div_term)
|
||||
pe_z[:, :, 1::2] = torch.cos(z * div_term)
|
||||
|
||||
pe_w[:, :, 0::2] = torch.sin(w * div_term)
|
||||
pe_w[:, :, 1::2] = torch.cos(w * div_term)
|
||||
|
||||
pe = torch.cat([pe_x, pe_y, pe_z, pe_w], dim=2) # B, N, C*3
|
||||
if cat_coords:
|
||||
pe = torch.cat([pe, xyzw], dim=2) # B, N, C*3+3
|
||||
return pe
|
169
cotracker/models/core/model_utils.py
Normal file
169
cotracker/models/core/model_utils.py
Normal file
@@ -0,0 +1,169 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
|
||||
EPS = 1e-6
|
||||
|
||||
|
||||
def smart_cat(tensor1, tensor2, dim):
|
||||
if tensor1 is None:
|
||||
return tensor2
|
||||
return torch.cat([tensor1, tensor2], dim=dim)
|
||||
|
||||
|
||||
def normalize_single(d):
|
||||
# d is a whatever shape torch tensor
|
||||
dmin = torch.min(d)
|
||||
dmax = torch.max(d)
|
||||
d = (d - dmin) / (EPS + (dmax - dmin))
|
||||
return d
|
||||
|
||||
|
||||
def normalize(d):
|
||||
# d is B x whatever. normalize within each element of the batch
|
||||
out = torch.zeros(d.size())
|
||||
if d.is_cuda:
|
||||
out = out.cuda()
|
||||
B = list(d.size())[0]
|
||||
for b in list(range(B)):
|
||||
out[b] = normalize_single(d[b])
|
||||
return out
|
||||
|
||||
|
||||
def meshgrid2d(B, Y, X, stack=False, norm=False, device="cuda"):
|
||||
# returns a meshgrid sized B x Y x X
|
||||
|
||||
grid_y = torch.linspace(0.0, Y - 1, Y, device=torch.device(device))
|
||||
grid_y = torch.reshape(grid_y, [1, Y, 1])
|
||||
grid_y = grid_y.repeat(B, 1, X)
|
||||
|
||||
grid_x = torch.linspace(0.0, X - 1, X, device=torch.device(device))
|
||||
grid_x = torch.reshape(grid_x, [1, 1, X])
|
||||
grid_x = grid_x.repeat(B, Y, 1)
|
||||
|
||||
if stack:
|
||||
# note we stack in xy order
|
||||
# (see https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample)
|
||||
grid = torch.stack([grid_x, grid_y], dim=-1)
|
||||
return grid
|
||||
else:
|
||||
return grid_y, grid_x
|
||||
|
||||
|
||||
def reduce_masked_mean(x, mask, dim=None, keepdim=False):
|
||||
# x and mask are the same shape, or at least broadcastably so < actually it's safer if you disallow broadcasting
|
||||
# returns shape-1
|
||||
# axis can be a list of axes
|
||||
for (a, b) in zip(x.size(), mask.size()):
|
||||
assert a == b # some shape mismatch!
|
||||
prod = x * mask
|
||||
if dim is None:
|
||||
numer = torch.sum(prod)
|
||||
denom = EPS + torch.sum(mask)
|
||||
else:
|
||||
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
||||
denom = EPS + torch.sum(mask, dim=dim, keepdim=keepdim)
|
||||
|
||||
mean = numer / denom
|
||||
return mean
|
||||
|
||||
|
||||
def bilinear_sample2d(im, x, y, return_inbounds=False):
|
||||
# x and y are each B, N
|
||||
# output is B, C, N
|
||||
if len(im.shape) == 5:
|
||||
B, N, C, H, W = list(im.shape)
|
||||
else:
|
||||
B, C, H, W = list(im.shape)
|
||||
N = list(x.shape)[1]
|
||||
|
||||
x = x.float()
|
||||
y = y.float()
|
||||
H_f = torch.tensor(H, dtype=torch.float32)
|
||||
W_f = torch.tensor(W, dtype=torch.float32)
|
||||
|
||||
# inbound_mask = (x>-0.5).float()*(y>-0.5).float()*(x<W_f+0.5).float()*(y<H_f+0.5).float()
|
||||
|
||||
max_y = (H_f - 1).int()
|
||||
max_x = (W_f - 1).int()
|
||||
|
||||
x0 = torch.floor(x).int()
|
||||
x1 = x0 + 1
|
||||
y0 = torch.floor(y).int()
|
||||
y1 = y0 + 1
|
||||
|
||||
x0_clip = torch.clamp(x0, 0, max_x)
|
||||
x1_clip = torch.clamp(x1, 0, max_x)
|
||||
y0_clip = torch.clamp(y0, 0, max_y)
|
||||
y1_clip = torch.clamp(y1, 0, max_y)
|
||||
dim2 = W
|
||||
dim1 = W * H
|
||||
|
||||
base = torch.arange(0, B, dtype=torch.int64, device=x.device) * dim1
|
||||
base = torch.reshape(base, [B, 1]).repeat([1, N])
|
||||
|
||||
base_y0 = base + y0_clip * dim2
|
||||
base_y1 = base + y1_clip * dim2
|
||||
|
||||
idx_y0_x0 = base_y0 + x0_clip
|
||||
idx_y0_x1 = base_y0 + x1_clip
|
||||
idx_y1_x0 = base_y1 + x0_clip
|
||||
idx_y1_x1 = base_y1 + x1_clip
|
||||
|
||||
# use the indices to lookup pixels in the flat image
|
||||
# im is B x C x H x W
|
||||
# move C out to last dim
|
||||
if len(im.shape) == 5:
|
||||
im_flat = (im.permute(0, 3, 4, 1, 2)).reshape(B * H * W, N, C)
|
||||
i_y0_x0 = torch.diagonal(im_flat[idx_y0_x0.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
i_y0_x1 = torch.diagonal(im_flat[idx_y0_x1.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
i_y1_x0 = torch.diagonal(im_flat[idx_y1_x0.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
i_y1_x1 = torch.diagonal(im_flat[idx_y1_x1.long()], dim1=1, dim2=2).permute(
|
||||
0, 2, 1
|
||||
)
|
||||
else:
|
||||
im_flat = (im.permute(0, 2, 3, 1)).reshape(B * H * W, C)
|
||||
i_y0_x0 = im_flat[idx_y0_x0.long()]
|
||||
i_y0_x1 = im_flat[idx_y0_x1.long()]
|
||||
i_y1_x0 = im_flat[idx_y1_x0.long()]
|
||||
i_y1_x1 = im_flat[idx_y1_x1.long()]
|
||||
|
||||
# Finally calculate interpolated values.
|
||||
x0_f = x0.float()
|
||||
x1_f = x1.float()
|
||||
y0_f = y0.float()
|
||||
y1_f = y1.float()
|
||||
|
||||
w_y0_x0 = ((x1_f - x) * (y1_f - y)).unsqueeze(2)
|
||||
w_y0_x1 = ((x - x0_f) * (y1_f - y)).unsqueeze(2)
|
||||
w_y1_x0 = ((x1_f - x) * (y - y0_f)).unsqueeze(2)
|
||||
w_y1_x1 = ((x - x0_f) * (y - y0_f)).unsqueeze(2)
|
||||
|
||||
output = (
|
||||
w_y0_x0 * i_y0_x0 + w_y0_x1 * i_y0_x1 + w_y1_x0 * i_y1_x0 + w_y1_x1 * i_y1_x1
|
||||
)
|
||||
# output is B*N x C
|
||||
output = output.view(B, -1, C)
|
||||
output = output.permute(0, 2, 1)
|
||||
# output is B x C x N
|
||||
|
||||
if return_inbounds:
|
||||
x_valid = (x > -0.5).byte() & (x < float(W_f - 0.5)).byte()
|
||||
y_valid = (y > -0.5).byte() & (y < float(H_f - 0.5)).byte()
|
||||
inbounds = (x_valid & y_valid).float()
|
||||
inbounds = inbounds.reshape(
|
||||
B, N
|
||||
) # something seems wrong here for B>1; i'm getting an error here (or downstream if i put -1)
|
||||
return output, inbounds
|
||||
|
||||
return output # B, C, N
|
103
cotracker/models/evaluation_predictor.py
Normal file
103
cotracker/models/evaluation_predictor.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
from cotracker.models.core.cotracker.cotracker import CoTracker, get_points_on_a_grid
|
||||
|
||||
|
||||
class EvaluationPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
cotracker_model: CoTracker,
|
||||
interp_shape: Tuple[int, int] = (384, 512),
|
||||
grid_size: int = 6,
|
||||
local_grid_size: int = 6,
|
||||
single_point: bool = True,
|
||||
n_iters: int = 6,
|
||||
) -> None:
|
||||
super(EvaluationPredictor, self).__init__()
|
||||
self.grid_size = grid_size
|
||||
self.local_grid_size = local_grid_size
|
||||
self.single_point = single_point
|
||||
self.interp_shape = interp_shape
|
||||
self.n_iters = n_iters
|
||||
|
||||
self.model = cotracker_model
|
||||
self.model.to("cuda")
|
||||
self.model.eval()
|
||||
|
||||
def forward(self, video, queries):
|
||||
queries = queries.clone().cuda()
|
||||
B, T, C, H, W = video.shape
|
||||
B, N, D = queries.shape
|
||||
|
||||
assert D == 3
|
||||
assert B == 1
|
||||
|
||||
rgbs = video.reshape(B * T, C, H, W)
|
||||
rgbs = F.interpolate(rgbs, tuple(self.interp_shape), mode="bilinear")
|
||||
rgbs = rgbs.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]).cuda()
|
||||
|
||||
queries[:, :, 1] *= self.interp_shape[1] / W
|
||||
queries[:, :, 2] *= self.interp_shape[0] / H
|
||||
|
||||
if self.single_point:
|
||||
traj_e = torch.zeros((B, T, N, 2)).cuda()
|
||||
vis_e = torch.zeros((B, T, N)).cuda()
|
||||
for pind in range((N)):
|
||||
query = queries[:, pind : pind + 1]
|
||||
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
traj_e_pind, vis_e_pind = self._process_one_point(rgbs, query)
|
||||
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
|
||||
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
|
||||
else:
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
|
||||
queries = torch.cat([queries, xy], dim=1) #
|
||||
|
||||
traj_e, __, vis_e, __ = self.model(
|
||||
rgbs=rgbs,
|
||||
queries=queries,
|
||||
iters=self.n_iters,
|
||||
)
|
||||
|
||||
traj_e[:, :, :, 0] *= W / float(self.interp_shape[1])
|
||||
traj_e[:, :, :, 1] *= H / float(self.interp_shape[0])
|
||||
return traj_e, vis_e
|
||||
|
||||
def _process_one_point(self, rgbs, query):
|
||||
t = query[0, 0, 0].long()
|
||||
|
||||
device = rgbs.device
|
||||
if self.local_grid_size > 0:
|
||||
xy_target = get_points_on_a_grid(
|
||||
self.local_grid_size,
|
||||
(50, 50),
|
||||
[query[0, 0, 2], query[0, 0, 1]],
|
||||
)
|
||||
|
||||
xy_target = torch.cat(
|
||||
[torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2
|
||||
) #
|
||||
query = torch.cat([query, xy_target], dim=1).to(device) #
|
||||
|
||||
if self.grid_size > 0:
|
||||
xy = get_points_on_a_grid(self.grid_size, rgbs.shape[3:])
|
||||
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).cuda() #
|
||||
query = torch.cat([query, xy], dim=1).to(device) #
|
||||
# crop the video to start from the queried frame
|
||||
query[0, 0, 0] = 0
|
||||
traj_e_pind, __, vis_e_pind, __ = self.model(
|
||||
rgbs=rgbs[:, t:], queries=query, iters=self.n_iters
|
||||
)
|
||||
|
||||
return traj_e_pind, vis_e_pind
|
178
cotracker/predictor.py
Normal file
178
cotracker/predictor.py
Normal file
@@ -0,0 +1,178 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from tqdm import tqdm
|
||||
from cotracker.models.core.cotracker.cotracker import get_points_on_a_grid
|
||||
from cotracker.models.core.model_utils import smart_cat
|
||||
from cotracker.models.build_cotracker import (
|
||||
build_cotracker,
|
||||
)
|
||||
|
||||
|
||||
class CoTrackerPredictor(torch.nn.Module):
|
||||
def __init__(
|
||||
self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth"
|
||||
):
|
||||
super().__init__()
|
||||
self.interp_shape = (384, 512)
|
||||
self.support_grid_size = 6
|
||||
model = build_cotracker(checkpoint)
|
||||
|
||||
self.model = model
|
||||
self.model.to("cuda")
|
||||
self.model.eval()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
video, # (1, T, 3, H, W)
|
||||
# input prompt types:
|
||||
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
|
||||
# *backward_tracking=True* will compute tracks in both directions.
|
||||
# - queries. Queried points of shape (1, N, 3) in format (t, x, y) for frame index and pixel coordinates.
|
||||
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
|
||||
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
|
||||
queries: torch.Tensor = None,
|
||||
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
|
||||
grid_size: int = 0,
|
||||
grid_query_frame: int = 0, # only for dense and regular grid tracks
|
||||
backward_tracking: bool = False,
|
||||
):
|
||||
|
||||
if queries is None and grid_size == 0:
|
||||
tracks, visibilities = self._compute_dense_tracks(
|
||||
video,
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
else:
|
||||
tracks, visibilities = self._compute_sparse_tracks(
|
||||
video,
|
||||
queries,
|
||||
segm_mask,
|
||||
grid_size,
|
||||
add_support_grid=(grid_size == 0 or segm_mask is not None),
|
||||
grid_query_frame=grid_query_frame,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_dense_tracks(
|
||||
self, video, grid_query_frame, grid_size=50, backward_tracking=False
|
||||
):
|
||||
*_, H, W = video.shape
|
||||
grid_step = W // grid_size
|
||||
grid_width = W // grid_step
|
||||
grid_height = H // grid_step
|
||||
tracks = visibilities = None
|
||||
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to("cuda")
|
||||
grid_pts[0, :, 0] = grid_query_frame
|
||||
for offset in tqdm(range(grid_step * grid_step)):
|
||||
ox = offset % grid_step
|
||||
oy = offset // grid_step
|
||||
grid_pts[0, :, 1] = (
|
||||
torch.arange(grid_width).repeat(grid_height) * grid_step + ox
|
||||
)
|
||||
grid_pts[0, :, 2] = (
|
||||
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
|
||||
)
|
||||
tracks_step, visibilities_step = self._compute_sparse_tracks(
|
||||
video=video,
|
||||
queries=grid_pts,
|
||||
backward_tracking=backward_tracking,
|
||||
)
|
||||
tracks = smart_cat(tracks, tracks_step, dim=2)
|
||||
visibilities = smart_cat(visibilities, visibilities_step, dim=2)
|
||||
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_sparse_tracks(
|
||||
self,
|
||||
video,
|
||||
queries,
|
||||
segm_mask=None,
|
||||
grid_size=0,
|
||||
add_support_grid=False,
|
||||
grid_query_frame=0,
|
||||
backward_tracking=False,
|
||||
):
|
||||
B, T, C, H, W = video.shape
|
||||
assert B == 1
|
||||
|
||||
video = video.reshape(B * T, C, H, W)
|
||||
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear").cuda()
|
||||
video = video.reshape(
|
||||
B, T, 3, self.interp_shape[0], self.interp_shape[1]
|
||||
).cuda()
|
||||
|
||||
if queries is not None:
|
||||
queries = queries.clone()
|
||||
B, N, D = queries.shape
|
||||
assert D == 3
|
||||
queries[:, :, 1] *= self.interp_shape[1] / W
|
||||
queries[:, :, 2] *= self.interp_shape[0] / H
|
||||
elif grid_size > 0:
|
||||
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape)
|
||||
if segm_mask is not None:
|
||||
segm_mask = F.interpolate(
|
||||
segm_mask, tuple(self.interp_shape), mode="nearest"
|
||||
)
|
||||
point_mask = segm_mask[0, 0][
|
||||
(grid_pts[0, :, 1]).round().long().cpu(),
|
||||
(grid_pts[0, :, 0]).round().long().cpu(),
|
||||
].bool()
|
||||
grid_pts = grid_pts[:, point_mask]
|
||||
|
||||
queries = torch.cat(
|
||||
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
|
||||
dim=2,
|
||||
)
|
||||
|
||||
if add_support_grid:
|
||||
grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape)
|
||||
grid_pts = torch.cat(
|
||||
[torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2
|
||||
)
|
||||
queries = torch.cat([queries, grid_pts], dim=1)
|
||||
|
||||
tracks, __, visibilities, __ = self.model(rgbs=video, queries=queries, iters=6)
|
||||
|
||||
if backward_tracking:
|
||||
tracks, visibilities = self._compute_backward_tracks(
|
||||
video, queries, tracks, visibilities
|
||||
)
|
||||
if add_support_grid:
|
||||
queries[:, -self.support_grid_size ** 2 :, 0] = T - 1
|
||||
if add_support_grid:
|
||||
tracks = tracks[:, :, : -self.support_grid_size ** 2]
|
||||
visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
|
||||
thr = 0.9
|
||||
visibilities = visibilities > thr
|
||||
tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
|
||||
tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
|
||||
return tracks, visibilities
|
||||
|
||||
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
|
||||
inv_video = video.flip(1).clone()
|
||||
inv_queries = queries.clone()
|
||||
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
|
||||
|
||||
inv_tracks, __, inv_visibilities, __ = self.model(
|
||||
rgbs=inv_video, queries=inv_queries, iters=6
|
||||
)
|
||||
|
||||
inv_tracks = inv_tracks.flip(1)
|
||||
inv_visibilities = inv_visibilities.flip(1)
|
||||
|
||||
mask = tracks == 0
|
||||
|
||||
tracks[mask] = inv_tracks[mask]
|
||||
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
|
||||
return tracks, visibilities
|
5
cotracker/utils/__init__.py
Normal file
5
cotracker/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
291
cotracker/utils/visualizer.py
Normal file
291
cotracker/utils/visualizer.py
Normal file
@@ -0,0 +1,291 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import flow_vis
|
||||
|
||||
from matplotlib import cm
|
||||
import torch.nn.functional as F
|
||||
import torchvision.transforms as transforms
|
||||
from moviepy.editor import ImageSequenceClip
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
|
||||
class Visualizer:
|
||||
def __init__(
|
||||
self,
|
||||
save_dir: str = "./results",
|
||||
grayscale: bool = False,
|
||||
pad_value: int = 0,
|
||||
fps: int = 10,
|
||||
mode: str = "rainbow", # 'cool', 'optical_flow'
|
||||
linewidth: int = 2,
|
||||
show_first_frame: int = 10,
|
||||
tracks_leave_trace: int = 0, # -1 for infinite
|
||||
):
|
||||
self.mode = mode
|
||||
self.save_dir = save_dir
|
||||
if mode == "rainbow":
|
||||
self.color_map = cm.get_cmap("gist_rainbow")
|
||||
elif mode == "cool":
|
||||
self.color_map = cm.get_cmap(mode)
|
||||
self.show_first_frame = show_first_frame
|
||||
self.grayscale = grayscale
|
||||
self.tracks_leave_trace = tracks_leave_trace
|
||||
self.pad_value = pad_value
|
||||
self.linewidth = linewidth
|
||||
self.fps = fps
|
||||
|
||||
def visualize(
|
||||
self,
|
||||
video: torch.Tensor, # (B,T,C,H,W)
|
||||
tracks: torch.Tensor, # (B,T,N,2)
|
||||
gt_tracks: torch.Tensor = None, # (B,T,N,2)
|
||||
segm_mask: torch.Tensor = None, # (B,1,H,W)
|
||||
filename: str = "video",
|
||||
writer: SummaryWriter = None,
|
||||
step: int = 0,
|
||||
query_frame: int = 0,
|
||||
save_video: bool = True,
|
||||
compensate_for_camera_motion: bool = False,
|
||||
):
|
||||
if compensate_for_camera_motion:
|
||||
assert segm_mask is not None
|
||||
if segm_mask is not None:
|
||||
coords = tracks[0, query_frame].round().long()
|
||||
segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
|
||||
|
||||
video = F.pad(
|
||||
video,
|
||||
(self.pad_value, self.pad_value, self.pad_value, self.pad_value),
|
||||
"constant",
|
||||
255,
|
||||
)
|
||||
tracks = tracks + self.pad_value
|
||||
|
||||
if self.grayscale:
|
||||
transform = transforms.Grayscale()
|
||||
video = transform(video)
|
||||
video = video.repeat(1, 1, 3, 1, 1)
|
||||
|
||||
res_video = self.draw_tracks_on_video(
|
||||
video=video,
|
||||
tracks=tracks,
|
||||
segm_mask=segm_mask,
|
||||
gt_tracks=gt_tracks,
|
||||
query_frame=query_frame,
|
||||
compensate_for_camera_motion=compensate_for_camera_motion,
|
||||
)
|
||||
if save_video:
|
||||
self.save_video(res_video, filename=filename, writer=writer, step=step)
|
||||
return res_video
|
||||
|
||||
def save_video(self, video, filename, writer=None, step=0):
|
||||
if writer is not None:
|
||||
writer.add_video(
|
||||
f"{filename}_pred_track",
|
||||
video.to(torch.uint8),
|
||||
global_step=step,
|
||||
fps=self.fps,
|
||||
)
|
||||
else:
|
||||
os.makedirs(self.save_dir, exist_ok=True)
|
||||
wide_list = list(video.unbind(1))
|
||||
wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
|
||||
clip = ImageSequenceClip(wide_list[2:-1], fps=self.fps)
|
||||
|
||||
# Write the video file
|
||||
save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4")
|
||||
clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
|
||||
|
||||
print(f"Video saved to {save_path}")
|
||||
|
||||
def draw_tracks_on_video(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
tracks: torch.Tensor,
|
||||
segm_mask: torch.Tensor = None,
|
||||
gt_tracks=None,
|
||||
query_frame: int = 0,
|
||||
compensate_for_camera_motion=False,
|
||||
):
|
||||
B, T, C, H, W = video.shape
|
||||
_, _, N, D = tracks.shape
|
||||
|
||||
assert D == 2
|
||||
assert C == 3
|
||||
video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
|
||||
tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
|
||||
if gt_tracks is not None:
|
||||
gt_tracks = gt_tracks[0].detach().cpu().numpy()
|
||||
|
||||
res_video = []
|
||||
|
||||
# process input video
|
||||
for rgb in video:
|
||||
res_video.append(rgb.copy())
|
||||
|
||||
vector_colors = np.zeros((T, N, 3))
|
||||
if self.mode == "optical_flow":
|
||||
vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
|
||||
elif segm_mask is None:
|
||||
if self.mode == "rainbow":
|
||||
y_min, y_max = (
|
||||
tracks[query_frame, :, 1].min(),
|
||||
tracks[query_frame, :, 1].max(),
|
||||
)
|
||||
norm = plt.Normalize(y_min, y_max)
|
||||
for n in range(N):
|
||||
color = self.color_map(norm(tracks[query_frame, n, 1]))
|
||||
color = np.array(color[:3])[None] * 255
|
||||
vector_colors[:, n] = np.repeat(color, T, axis=0)
|
||||
else:
|
||||
# color changes with time
|
||||
for t in range(T):
|
||||
color = np.array(self.color_map(t / T)[:3])[None] * 255
|
||||
vector_colors[t] = np.repeat(color, N, axis=0)
|
||||
else:
|
||||
if self.mode == "rainbow":
|
||||
vector_colors[:, segm_mask <= 0, :] = 255
|
||||
|
||||
y_min, y_max = (
|
||||
tracks[0, segm_mask > 0, 1].min(),
|
||||
tracks[0, segm_mask > 0, 1].max(),
|
||||
)
|
||||
norm = plt.Normalize(y_min, y_max)
|
||||
for n in range(N):
|
||||
if segm_mask[n] > 0:
|
||||
color = self.color_map(norm(tracks[0, n, 1]))
|
||||
color = np.array(color[:3])[None] * 255
|
||||
vector_colors[:, n] = np.repeat(color, T, axis=0)
|
||||
|
||||
else:
|
||||
# color changes with segm class
|
||||
segm_mask = segm_mask.cpu()
|
||||
color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
|
||||
color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
|
||||
color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
|
||||
vector_colors = np.repeat(color[None], T, axis=0)
|
||||
|
||||
# draw tracks
|
||||
if self.tracks_leave_trace != 0:
|
||||
for t in range(1, T):
|
||||
first_ind = (
|
||||
max(0, t - self.tracks_leave_trace)
|
||||
if self.tracks_leave_trace >= 0
|
||||
else 0
|
||||
)
|
||||
curr_tracks = tracks[first_ind : t + 1]
|
||||
curr_colors = vector_colors[first_ind : t + 1]
|
||||
if compensate_for_camera_motion:
|
||||
diff = (
|
||||
tracks[first_ind : t + 1, segm_mask <= 0]
|
||||
- tracks[t : t + 1, segm_mask <= 0]
|
||||
).mean(1)[:, None]
|
||||
|
||||
curr_tracks = curr_tracks - diff
|
||||
curr_tracks = curr_tracks[:, segm_mask > 0]
|
||||
curr_colors = curr_colors[:, segm_mask > 0]
|
||||
|
||||
res_video[t] = self._draw_pred_tracks(
|
||||
res_video[t],
|
||||
curr_tracks,
|
||||
curr_colors,
|
||||
)
|
||||
if gt_tracks is not None:
|
||||
res_video[t] = self._draw_gt_tracks(
|
||||
res_video[t], gt_tracks[first_ind : t + 1]
|
||||
)
|
||||
|
||||
# draw points
|
||||
for t in range(T):
|
||||
for i in range(N):
|
||||
coord = (tracks[t, i, 0], tracks[t, i, 1])
|
||||
if coord[0] != 0 and coord[1] != 0:
|
||||
if not compensate_for_camera_motion or (
|
||||
compensate_for_camera_motion and segm_mask[i] > 0
|
||||
):
|
||||
cv2.circle(
|
||||
res_video[t],
|
||||
coord,
|
||||
int(self.linewidth * 2),
|
||||
vector_colors[t, i].tolist(),
|
||||
-1,
|
||||
)
|
||||
|
||||
# construct the final rgb sequence
|
||||
if self.show_first_frame > 0:
|
||||
res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
|
||||
return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
|
||||
|
||||
def _draw_pred_tracks(
|
||||
self,
|
||||
rgb: np.ndarray, # H x W x 3
|
||||
tracks: np.ndarray, # T x 2
|
||||
vector_colors: np.ndarray,
|
||||
alpha: float = 0.5,
|
||||
):
|
||||
T, N, _ = tracks.shape
|
||||
|
||||
for s in range(T - 1):
|
||||
vector_color = vector_colors[s]
|
||||
original = rgb.copy()
|
||||
alpha = (s / T) ** 2
|
||||
for i in range(N):
|
||||
coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
|
||||
coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
|
||||
if coord_y[0] != 0 and coord_y[1] != 0:
|
||||
cv2.line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
vector_color[i].tolist(),
|
||||
self.linewidth,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
if self.tracks_leave_trace > 0:
|
||||
rgb = cv2.addWeighted(rgb, alpha, original, 1 - alpha, 0)
|
||||
return rgb
|
||||
|
||||
def _draw_gt_tracks(
|
||||
self,
|
||||
rgb: np.ndarray, # H x W x 3,
|
||||
gt_tracks: np.ndarray, # T x 2
|
||||
):
|
||||
T, N, _ = gt_tracks.shape
|
||||
color = np.array((211.0, 0.0, 0.0))
|
||||
|
||||
for t in range(T):
|
||||
for i in range(N):
|
||||
gt_tracks = gt_tracks[t][i]
|
||||
# draw a red cross
|
||||
if gt_tracks[0] > 0 and gt_tracks[1] > 0:
|
||||
length = self.linewidth * 3
|
||||
coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
|
||||
coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
|
||||
cv2.line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
color,
|
||||
self.linewidth,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
|
||||
coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
|
||||
cv2.line(
|
||||
rgb,
|
||||
coord_y,
|
||||
coord_x,
|
||||
color,
|
||||
self.linewidth,
|
||||
cv2.LINE_AA,
|
||||
)
|
||||
return rgb
|
Reference in New Issue
Block a user