add some comments

2024-08-05 23:36:58 +02:00
parent 36d1566750
commit 6e7bcd2d26
55 changed files with 3946 additions and 4095 deletions


@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

4 binary files changed (contents not shown)

@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


@@ -1,166 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import dataclasses
import numpy as np
from dataclasses import Field, MISSING
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple

_X = TypeVar("_X")


def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
    """
    Loads to a @dataclass or collection hierarchy including dataclasses
    from a json recursively.
    Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
    raises KeyError if json has keys not mapping to the dataclass fields.

    Args:
        f: Either a path to a file, or a file opened for reading.
        cls: The class of the loaded dataclass.
        binary: Set to True if `f` is a file handle, else False.
    """
    if binary:
        asdict = json.loads(f.read().decode("utf8"))
    else:
        asdict = json.load(f)

    # in the list case, run a faster "vectorized" version
    cls = get_args(cls)[0]
    res = list(_dataclass_list_from_dict_list(asdict, cls))

    return res


def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
    """Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
    if get_origin(type_) is Union:
        args = get_args(type_)
        if len(args) == 2 and args[1] == type(None):  # noqa E721
            return True, args[0]
    if type_ is Any:
        return True, Any

    return False, type_


def _unwrap_type(tp):
    # strips Optional wrapper, if any
    if get_origin(tp) is Union:
        args = get_args(tp)
        if len(args) == 2 and any(a is type(None) for a in args):  # noqa: E721
            # this is typing.Optional
            return args[0] if args[1] is type(None) else args[1]  # noqa: E721
    return tp


def _get_dataclass_field_default(field: Field) -> Any:
    if field.default_factory is not MISSING:
        # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
        #  dataclasses._DefaultFactory[typing.Any]]` is not a function.
        return field.default_factory()
    elif field.default is not MISSING:
        return field.default
    else:
        return None


def _dataclass_list_from_dict_list(dlist, typeannot):
    """
    Vectorised version of `_dataclass_from_dict`.
    The output should be equivalent to
    `[_dataclass_from_dict(d, typeannot) for d in dlist]`.

    Args:
        dlist: list of objects to convert.
        typeannot: type of each of those objects.
    Returns:
        iterator or list over converted objects of the same length as `dlist`.

    Raises:
        ValueError: it assumes the objects have None's in consistent places across
            objects, otherwise it would ignore some values. This generally holds for
            auto-generated annotations, but otherwise use `_dataclass_from_dict`.
    """

    cls = get_origin(typeannot) or typeannot

    if typeannot is Any:
        return dlist
    if all(obj is None for obj in dlist):  # 1st recursion base: all None nodes
        return dlist
    if any(obj is None for obj in dlist):
        # filter out Nones and recurse on the resulting list
        idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
        idx, notnone = zip(*idx_notnone)
        converted = _dataclass_list_from_dict_list(notnone, typeannot)
        res = [None] * len(dlist)
        for i, obj in zip(idx, converted):
            res[i] = obj
        return res

    is_optional, contained_type = _resolve_optional(typeannot)
    if is_optional:
        return _dataclass_list_from_dict_list(dlist, contained_type)

    # otherwise, we dispatch by the type of the provided annotation to convert to
    if issubclass(cls, tuple) and hasattr(cls, "_fields"):  # namedtuple
        # For namedtuple, call the function recursively on the lists of corresponding keys
        types = cls.__annotations__.values()
        dlist_T = zip(*dlist)
        res_T = [
            _dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
        ]
        return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
    elif issubclass(cls, (list, tuple)):
        # For list/tuple, call the function recursively on the lists of corresponding positions
        types = get_args(typeannot)
        if len(types) == 1:  # probably List; replicate for all items
            types = types * len(dlist[0])
        dlist_T = zip(*dlist)
        res_T = (
            _dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
        )
        if issubclass(cls, tuple):
            return list(zip(*res_T))
        else:
            return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
    elif issubclass(cls, dict):
        # For the dictionary, call the function recursively on concatenated keys and vertices
        key_t, val_t = get_args(typeannot)
        all_keys_res = _dataclass_list_from_dict_list(
            [k for obj in dlist for k in obj.keys()], key_t
        )
        all_vals_res = _dataclass_list_from_dict_list(
            [k for obj in dlist for k in obj.values()], val_t
        )
        indices = np.cumsum([len(obj) for obj in dlist])
        assert indices[-1] == len(all_keys_res)

        keys = np.split(list(all_keys_res), indices[:-1])
        all_vals_res_iter = iter(all_vals_res)
        return [cls(zip(k, all_vals_res_iter)) for k in keys]
    elif not dataclasses.is_dataclass(typeannot):
        return dlist

    # dataclass node: 2nd recursion base; call the function recursively on the lists
    # of the corresponding fields
    assert dataclasses.is_dataclass(cls)
    fieldtypes = {
        f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
        for f in dataclasses.fields(typeannot)
    }

    # NOTE the default object is shared here
    key_lists = (
        _dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
        for k, (type_, default) in fieldtypes.items()
    )
    transposed = zip(*key_lists)
    return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
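

# Usage sketch: a minimal illustration of how load_dataclass above can be called.
# FrameAnnotation and the annotations path are hypothetical placeholders, not part
# of this commit.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FrameAnnotation:  # hypothetical schema, for illustration only
    sequence_name: str
    frame_number: int
    camera_name: Optional[str] = None


# load_dataclass expects a List[...] annotation and recurses into nested
# dataclasses, so a JSON array of objects maps directly onto FrameAnnotation.
with open("annotations.json", "r") as f:  # placeholder path
    annots = load_dataclass(f, List[FrameAnnotation])
print(annots[0].sequence_name)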


@@ -1,161 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import gzip
import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple

from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass


@dataclass
class ImageAnnotation:
    # path to jpg file, relative w.r.t. dataset_root
    path: str
    # H x W
    size: Tuple[int, int]


@dataclass
class DynamicReplicaFrameAnnotation:
    """A dataclass used to load annotations from json."""

    # can be used to join with `SequenceAnnotation`
    sequence_name: str
    # 0-based, continuous frame number within sequence
    frame_number: int
    # timestamp in seconds from the video start
    frame_timestamp: float

    image: ImageAnnotation
    meta: Optional[Dict[str, Any]] = None

    camera_name: Optional[str] = None
    trajectories: Optional[str] = None


class DynamicReplicaDataset(data.Dataset):
    def __init__(
        self,
        root,
        split="valid",
        traj_per_sample=256,
        crop_size=None,
        sample_len=-1,
        only_first_n_samples=-1,
        rgbd_input=False,
    ):
        super(DynamicReplicaDataset, self).__init__()
        self.root = root
        self.sample_len = sample_len
        self.split = split
        self.traj_per_sample = traj_per_sample
        self.rgbd_input = rgbd_input
        self.crop_size = crop_size
        frame_annotations_file = f"frame_annotations_{split}.jgz"
        self.sample_list = []
        with gzip.open(
            os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
        seq_annot = defaultdict(list)
        for frame_annot in frame_annots_list:
            if frame_annot.camera_name == "left":
                seq_annot[frame_annot.sequence_name].append(frame_annot)

        for seq_name in seq_annot.keys():
            seq_len = len(seq_annot[seq_name])

            step = self.sample_len if self.sample_len > 0 else seq_len
            counter = 0

            for ref_idx in range(0, seq_len, step):
                sample = seq_annot[seq_name][ref_idx : ref_idx + step]
                self.sample_list.append(sample)
                counter += 1
                if only_first_n_samples > 0 and counter >= only_first_n_samples:
                    break

    def __len__(self):
        return len(self.sample_list)

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        H_new = H
        W_new = W

        # simple random crop
        y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
        x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs

    def __getitem__(self, index):
        sample = self.sample_list[index]
        T = len(sample)
        rgbs, visibilities, traj_2d = [], [], []

        H, W = sample[0].image.size
        image_size = (H, W)

        for i in range(T):
            traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
            traj = torch.load(traj_path)

            visibilities.append(traj["verts_inds_vis"].numpy())

            rgbs.append(traj["img"].numpy())
            traj_2d.append(traj["traj_2d"].numpy()[..., :2])

        traj_2d = np.stack(traj_2d)
        visibility = np.stack(visibilities)
        T, N, D = traj_2d.shape
        # subsample trajectories for augmentations
        visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]

        traj_2d = traj_2d[:, visible_inds_sampled]
        visibility = visibility[:, visible_inds_sampled]

        if self.crop_size is not None:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)
            H, W, _ = rgbs[0].shape
            image_size = self.crop_size

        visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        # filter out points that're visible for less than 10 frames
        visible_inds_resampled = visibility.sum(0) > 10
        traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
        visibility = torch.from_numpy(visibility[:, visible_inds_resampled])

        rgbs = np.stack(rgbs, 0)
        video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
        return CoTrackerData(
            video=video,
            trajectory=traj_2d,
            visibility=visibility,
            valid=torch.ones(T, N),
            seq_name=sample[0].sequence_name,
        )
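

# Usage sketch: a minimal example of constructing the dataset above. The root
# path is a placeholder and the on-disk layout must match what __init__ and
# __getitem__ expect (frame_annotations_<split>.jgz plus the per-frame trajectory
# files referenced by the annotations).
dataset = DynamicReplicaDataset(root="./dynamic_replica", split="valid", sample_len=60)
print(len(dataset))
sample = dataset[0]             # a CoTrackerData instance
print(sample.video.shape)       # (T, 3, H, W)
print(sample.trajectory.shape)  # (T, N_filtered, 2)
print(sample.visibility.shape)  # (T, N_filtered)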


@@ -1,441 +1,441 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import cv2

import imageio
import numpy as np

from cotracker.datasets.utils import CoTrackerData
from torchvision.transforms import ColorJitter, GaussianBlur
from PIL import Image


class CoTrackerDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_root,
        crop_size=(384, 512),
        seq_len=24,
        traj_per_sample=768,
        sample_vis_1st_frame=False,
        use_augs=False,
    ):
        super(CoTrackerDataset, self).__init__()
        np.random.seed(0)
        torch.manual_seed(0)

        self.data_root = data_root
        self.seq_len = seq_len
        self.traj_per_sample = traj_per_sample
        self.sample_vis_1st_frame = sample_vis_1st_frame
        self.use_augs = use_augs
        self.crop_size = crop_size

        # photometric augmentation
        self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
        self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))

        self.blur_aug_prob = 0.25
        self.color_aug_prob = 0.25

        # occlusion augmentation
        self.eraser_aug_prob = 0.5
        self.eraser_bounds = [2, 100]
        self.eraser_max = 10

        # occlusion augmentation
        self.replace_aug_prob = 0.5
        self.replace_bounds = [2, 100]
        self.replace_max = 10

        # spatial augmentations
        self.pad_bounds = [0, 100]
        self.crop_size = crop_size
        self.resize_lim = [0.25, 2.0]  # sample resizes from here
        self.resize_delta = 0.2
        self.max_crop_offset = 50

        self.do_flip = True
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.5

    def getitem_helper(self, index):
        return NotImplementedError

    def __getitem__(self, index):
        gotit = False

        sample, gotit = self.getitem_helper(index)
        if not gotit:
            print("warning: sampling failed")
            # fake sample, so we can still collate
            sample = CoTrackerData(
                video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
                trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
                visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
                valid=torch.zeros((self.seq_len, self.traj_per_sample)),
            )

        return sample, gotit

    def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        if eraser:
            ############ eraser transform (per image after the first) ############
            rgbs = [rgb.astype(np.float32) for rgb in rgbs]
            for i in range(1, S):
                if np.random.rand() < self.eraser_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.eraser_max + 1)
                    ):  # number of times to occlude
                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
                        dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
                        rgbs[i][y0:y1, x0:x1, :] = mean_color

                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        if replace:
            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
            ]
            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
            ]

            ############ replace transform (per image after the first) ############
            rgbs = [rgb.astype(np.float32) for rgb in rgbs]
            rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
            for i in range(1, S):
                if np.random.rand() < self.replace_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.replace_max + 1)
                    ):  # number of times to occlude
                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
                        dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        wid = x1 - x0
                        hei = y1 - y0
                        y00 = np.random.randint(0, H - hei)
                        x00 = np.random.randint(0, W - wid)
                        fr = np.random.randint(0, S)
                        rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
                        rgbs[i][y0:y1, x0:x1, :] = rep

                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        ############ photometric augmentation ############
        if np.random.rand() < self.color_aug_prob:
            # random per-frame amount of aug
            rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]

        if np.random.rand() < self.blur_aug_prob:
            # random per-frame amount of blur
            rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]

        return rgbs, trajs, visibles

    def add_spatial_augs(self, rgbs, trajs, visibles):
        T, N, __ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        rgbs = [rgb.astype(np.float32) for rgb in rgbs]

        ############ spatial transform ############

        # padding
        pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])

        rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
        trajs[:, :, 0] += pad_x0
        trajs[:, :, 1] += pad_y0
        H, W = rgbs[0].shape[:2]

        # scaling + stretching
        scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
        scale_x = scale
        scale_y = scale
        H_new = H
        W_new = W

        scale_delta_x = 0.0
        scale_delta_y = 0.0

        rgbs_scaled = []
        for s in range(S):
            if s == 1:
                scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
                scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
            elif s > 1:
                scale_delta_x = (
                    scale_delta_x * 0.8
                    + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
                )
                scale_delta_y = (
                    scale_delta_y * 0.8
                    + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
                )
            scale_x = scale_x + scale_delta_x
            scale_y = scale_y + scale_delta_y

            # bring h/w closer
            scale_xy = (scale_x + scale_y) * 0.5
            scale_x = scale_x * 0.5 + scale_xy * 0.5
            scale_y = scale_y * 0.5 + scale_xy * 0.5

            # don't get too crazy
            scale_x = np.clip(scale_x, 0.2, 2.0)
            scale_y = np.clip(scale_y, 0.2, 2.0)

            H_new = int(H * scale_y)
            W_new = int(W * scale_x)

            # make it at least slightly bigger than the crop area,
            # so that the random cropping can add diversity
            H_new = np.clip(H_new, self.crop_size[0] + 10, None)
            W_new = np.clip(W_new, self.crop_size[1] + 10, None)
            # recompute scale in case we clipped
            scale_x = (W_new - 1) / float(W - 1)
            scale_y = (H_new - 1) / float(H - 1)

            rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
            trajs[s, :, 0] *= scale_x
            trajs[s, :, 1] *= scale_y
        rgbs = rgbs_scaled

        ok_inds = visibles[0, :] > 0
        vis_trajs = trajs[:, ok_inds]  # S,?,2

        if vis_trajs.shape[1] > 0:
            mid_x = np.mean(vis_trajs[0, :, 0])
            mid_y = np.mean(vis_trajs[0, :, 1])
        else:
            mid_y = self.crop_size[0]
            mid_x = self.crop_size[1]

        x0 = int(mid_x - self.crop_size[1] // 2)
        y0 = int(mid_y - self.crop_size[0] // 2)

        offset_x = 0
        offset_y = 0

        for s in range(S):
            # on each frame, shift a bit more
            if s == 1:
                offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
                offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
            elif s > 1:
                offset_x = int(
                    offset_x * 0.8
                    + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
                )
                offset_y = int(
                    offset_y * 0.8
                    + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
                )
            x0 = x0 + offset_x
            y0 = y0 + offset_y

            H_new, W_new = rgbs[s].shape[:2]
            if H_new == self.crop_size[0]:
                y0 = 0
            else:
                y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)

            if W_new == self.crop_size[1]:
                x0 = 0
            else:
                x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)

            rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
            trajs[s, :, 0] -= x0
            trajs[s, :, 1] -= y0

        H_new = self.crop_size[0]
        W_new = self.crop_size[1]

        # flip
        h_flipped = False
        v_flipped = False
        if self.do_flip:
            # h flip
            if np.random.rand() < self.h_flip_prob:
                h_flipped = True
                rgbs = [rgb[:, ::-1] for rgb in rgbs]
            # v flip
            if np.random.rand() < self.v_flip_prob:
                v_flipped = True
                rgbs = [rgb[::-1] for rgb in rgbs]
        if h_flipped:
            trajs[:, :, 0] = W_new - trajs[:, :, 0]
        if v_flipped:
            trajs[:, :, 1] = H_new - trajs[:, :, 1]

        return rgbs, trajs

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        ############ spatial transform ############

        H_new = H
        W_new = W

        # simple random crop
        y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
        x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs


class KubricMovifDataset(CoTrackerDataset):
    def __init__(
        self,
        data_root,
        crop_size=(384, 512),
        seq_len=24,
        traj_per_sample=768,
        sample_vis_1st_frame=False,
        use_augs=False,
    ):
        super(KubricMovifDataset, self).__init__(
            data_root=data_root,
            crop_size=crop_size,
            seq_len=seq_len,
            traj_per_sample=traj_per_sample,
            sample_vis_1st_frame=sample_vis_1st_frame,
            use_augs=use_augs,
        )

        self.pad_bounds = [0, 25]
        self.resize_lim = [0.75, 1.25]  # sample resizes from here
        self.resize_delta = 0.05
        self.max_crop_offset = 15
        self.seq_names = [
            fname
            for fname in os.listdir(data_root)
            if os.path.isdir(os.path.join(data_root, fname))
        ]
        print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))

    def getitem_helper(self, index):
        gotit = True
        seq_name = self.seq_names[index]

        npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
        rgb_path = os.path.join(self.data_root, seq_name, "frames")

        img_paths = sorted(os.listdir(rgb_path))
        rgbs = []
        for i, img_path in enumerate(img_paths):
            rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))

        rgbs = np.stack(rgbs)
        annot_dict = np.load(npy_path, allow_pickle=True).item()
        traj_2d = annot_dict["coords"]
        visibility = annot_dict["visibility"]

        # random crop
        assert self.seq_len <= len(rgbs)
        if self.seq_len < len(rgbs):
            start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]

            rgbs = rgbs[start_ind : start_ind + self.seq_len]
            traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
            visibility = visibility[:, start_ind : start_ind + self.seq_len]

        traj_2d = np.transpose(traj_2d, (1, 0, 2))
        visibility = np.transpose(np.logical_not(visibility), (1, 0))
        if self.use_augs:
            rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
            rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
        else:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)

        visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        visibility = torch.from_numpy(visibility)
        traj_2d = torch.from_numpy(traj_2d)

        visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]

        if self.sample_vis_1st_frame:
            visibile_pts_inds = visibile_pts_first_frame_inds
        else:
            visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
                :, 0
            ]
            visibile_pts_inds = torch.cat(
                (visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
            )
        point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
        if len(point_inds) < self.traj_per_sample:
            gotit = False

        visible_inds_sampled = visibile_pts_inds[point_inds]

        trajs = traj_2d[:, visible_inds_sampled].float()
        visibles = visibility[:, visible_inds_sampled]
        valids = torch.ones((self.seq_len, self.traj_per_sample))

        rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
        sample = CoTrackerData(
            video=rgbs,
            trajectory=trajs,
            visibility=visibles,
            valid=valids,
            seq_name=seq_name,
        )
        return sample, gotit

    def __len__(self):
        return len(self.seq_names)
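

# Usage sketch: one way the Kubric loader above can be constructed. data_root is
# a placeholder; each clip is expected as <seq>/<seq>.npy plus a <seq>/frames/
# directory, as getitem_helper reads them.
train_dataset = KubricMovifDataset(
    data_root="./kubric_movif",  # placeholder path
    crop_size=(384, 512),
    seq_len=24,
    traj_per_sample=768,
    use_augs=True,
)

# __getitem__ returns (CoTrackerData, gotit); gotit is False when fewer than
# traj_per_sample visible tracks could be sampled, so callers should drop or
# re-draw such samples.
sample, gotit = train_dataset[0]
if gotit:
    print(sample.video.shape)       # (seq_len, 3, 384, 512)
    print(sample.trajectory.shape)  # (seq_len, traj_per_sample, 2)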


@@ -1,209 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import io
import glob
import torch
import pickle
import numpy as np
import mediapy as media

from PIL import Image
from typing import Mapping, Tuple, Union

from cotracker.datasets.utils import CoTrackerData

DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]


def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
    """Resize a video to output_size."""
    # If you have a GPU, consider replacing this with a GPU-enabled resize op,
    # such as a jitted jax.image.resize. It will make things faster.
    return media.resize_video(video, output_size)


def sample_queries_first(
    target_occluded: np.ndarray,
    target_points: np.ndarray,
    frames: np.ndarray,
) -> Mapping[str, np.ndarray]:
    """Package a set of frames and tracks for use in TAPNet evaluations.

    Given a set of frames and tracks with no query points, use the first
    visible point in each track as the query.

    Args:
      target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
        where True indicates occluded.
      target_points: Position, of shape [n_tracks, n_frames, 2], where each point
        is [x,y] scaled between 0 and 1.
      frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
        -1 and 1.

    Returns:
      A dict with the keys:
        video: Video tensor of shape [1, n_frames, height, width, 3]
        query_points: Query points of shape [1, n_queries, 3] where
          each point is [t, y, x] scaled to the range [-1, 1]
        target_points: Target points of shape [1, n_queries, n_frames, 2] where
          each point is [x, y] scaled to the range [-1, 1]
    """
    valid = np.sum(~target_occluded, axis=1) > 0
    target_points = target_points[valid, :]
    target_occluded = target_occluded[valid, :]

    query_points = []
    for i in range(target_points.shape[0]):
        index = np.where(target_occluded[i] == 0)[0][0]
        x, y = target_points[i, index, 0], target_points[i, index, 1]
        query_points.append(np.array([index, y, x]))  # [t, y, x]
    query_points = np.stack(query_points, axis=0)

    return {
        "video": frames[np.newaxis, ...],
        "query_points": query_points[np.newaxis, ...],
        "target_points": target_points[np.newaxis, ...],
        "occluded": target_occluded[np.newaxis, ...],
    }
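

# Toy shape check for sample_queries_first above, with synthetic inputs; the
# values are random and purely illustrative.
n_tracks, n_frames = 4, 8
occluded = np.zeros((n_tracks, n_frames), dtype=bool)  # nothing occluded
points = np.random.rand(n_tracks, n_frames, 2)         # [x, y] in [0, 1]
frames = np.random.rand(n_frames, 64, 64, 3) * 2 - 1   # scaled to [-1, 1]

packed = sample_queries_first(occluded, points, frames)
print(packed["video"].shape)          # (1, 8, 64, 64, 3)
print(packed["query_points"].shape)   # (1, 4, 3): one [t, y, x] query per track
print(packed["target_points"].shape)  # (1, 4, 8, 2)
print(packed["occluded"].shape)       # (1, 4, 8)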


def sample_queries_strided(
    target_occluded: np.ndarray,
    target_points: np.ndarray,
    frames: np.ndarray,
    query_stride: int = 5,
) -> Mapping[str, np.ndarray]:
    """Package a set of frames and tracks for use in TAPNet evaluations.

    Given a set of frames and tracks with no query points, sample queries
    strided every query_stride frames, ignoring points that are not visible
    at the selected frames.

    Args:
      target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
        where True indicates occluded.
      target_points: Position, of shape [n_tracks, n_frames, 2], where each point
        is [x,y] scaled between 0 and 1.
      frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
        -1 and 1.
      query_stride: When sampling query points, search for un-occluded points
        every query_stride frames and convert each one into a query.

    Returns:
      A dict with the keys:
        video: Video tensor of shape [1, n_frames, height, width, 3]. The video
          has floats scaled to the range [-1, 1].
        query_points: Query points of shape [1, n_queries, 3] where
          each point is [t, y, x] scaled to the range [-1, 1].
        target_points: Target points of shape [1, n_queries, n_frames, 2] where
          each point is [x, y] scaled to the range [-1, 1].
        trackgroup: Index of the original track that each query point was
          sampled from. This is useful for visualization.
    """
    tracks = []
    occs = []
    queries = []
    trackgroups = []
    total = 0
    trackgroup = np.arange(target_occluded.shape[0])
    for i in range(0, target_occluded.shape[1], query_stride):
        mask = target_occluded[:, i] == 0
        query = np.stack(
            [
                i * np.ones(target_occluded.shape[0:1]),
                target_points[:, i, 1],
                target_points[:, i, 0],
            ],
            axis=-1,
        )
        queries.append(query[mask])
        tracks.append(target_points[mask])
        occs.append(target_occluded[mask])
        trackgroups.append(trackgroup[mask])
        total += np.array(np.sum(target_occluded[:, i] == 0))

    return {
        "video": frames[np.newaxis, ...],
        "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
        "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
        "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
        "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
    }
class TapVidDataset(torch.utils.data.Dataset): class TapVidDataset(torch.utils.data.Dataset):
def __init__( def __init__(
self, self,
data_root, data_root,
dataset_type="davis", dataset_type="davis",
resize_to_256=True, resize_to_256=True,
queried_first=True, queried_first=True,
): ):
self.dataset_type = dataset_type self.dataset_type = dataset_type
self.resize_to_256 = resize_to_256 self.resize_to_256 = resize_to_256
self.queried_first = queried_first self.queried_first = queried_first
if self.dataset_type == "kinetics": if self.dataset_type == "kinetics":
all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl")) all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
points_dataset = [] points_dataset = []
for pickle_path in all_paths: for pickle_path in all_paths:
with open(pickle_path, "rb") as f: with open(pickle_path, "rb") as f:
data = pickle.load(f) data = pickle.load(f)
points_dataset = points_dataset + data points_dataset = points_dataset + data
self.points_dataset = points_dataset self.points_dataset = points_dataset
else: else:
with open(data_root, "rb") as f: with open(data_root, "rb") as f:
self.points_dataset = pickle.load(f) self.points_dataset = pickle.load(f)
if self.dataset_type == "davis": if self.dataset_type == "davis":
self.video_names = list(self.points_dataset.keys()) self.video_names = list(self.points_dataset.keys())
print("found %d unique videos in %s" % (len(self.points_dataset), data_root)) print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
def __getitem__(self, index): def __getitem__(self, index):
if self.dataset_type == "davis": if self.dataset_type == "davis":
video_name = self.video_names[index] video_name = self.video_names[index]
else: else:
video_name = index video_name = index
video = self.points_dataset[video_name] video = self.points_dataset[video_name]
frames = video["video"] frames = video["video"]
if isinstance(frames[0], bytes): if isinstance(frames[0], bytes):
# TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s. # TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
def decode(frame): def decode(frame):
byteio = io.BytesIO(frame) byteio = io.BytesIO(frame)
img = Image.open(byteio) img = Image.open(byteio)
return np.array(img) return np.array(img)
frames = np.array([decode(frame) for frame in frames]) frames = np.array([decode(frame) for frame in frames])
target_points = self.points_dataset[video_name]["points"] target_points = self.points_dataset[video_name]["points"]
if self.resize_to_256: if self.resize_to_256:
frames = resize_video(frames, [256, 256]) frames = resize_video(frames, [256, 256])
target_points *= np.array([255, 255]) # 1 should be mapped to 256-1 target_points *= np.array([255, 255]) # 1 should be mapped to 256-1
else: else:
target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1]) target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
target_occ = self.points_dataset[video_name]["occluded"] target_occ = self.points_dataset[video_name]["occluded"]
if self.queried_first: if self.queried_first:
converted = sample_queries_first(target_occ, target_points, frames) converted = sample_queries_first(target_occ, target_points, frames)
else: else:
converted = sample_queries_strided(target_occ, target_points, frames) converted = sample_queries_strided(target_occ, target_points, frames)
assert converted["target_points"].shape[1] == converted["query_points"].shape[1] assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float() rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute( visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
1, 0 1, 0
) # T, N ) # T, N
query_points = torch.from_numpy(converted["query_points"])[0] # N, 3: (t, y, x) query_points = torch.from_numpy(converted["query_points"])[0] # N, 3: (t, y, x)
return CoTrackerData( return CoTrackerData(
rgbs, rgbs,
trajs, trajs,
visibles, visibles,
seq_name=str(video_name), seq_name=str(video_name),
query_points=query_points, query_points=query_points,
) )
def __len__(self): def __len__(self):
return len(self.points_dataset) return len(self.points_dataset)
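For context, a minimal usage sketch of the reader above; the pickle path and the presence of the TAP-Vid DAVIS download are assumptions, not something this diff guarantees:

from cotracker.datasets.tap_vid_datasets import TapVidDataset

# Hypothetical location of the TAP-Vid DAVIS pickle; adjust to your setup.
dataset = TapVidDataset(
    data_root="./tapvid_davis/tapvid_davis.pkl",
    dataset_type="davis",
    resize_to_256=True,
    queried_first=True,
)
sample = dataset[0]
# CoTrackerData fields produced above: video (T, 3, 256, 256), trajectory (T, N, 2),
# visibility (T, N), query_points (N, 3) in (t, y, x) order.
print(sample.seq_name, sample.video.shape, sample.trajectory.shape)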


@@ -1,106 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
import dataclasses import dataclasses
import torch.nn.functional as F import torch.nn.functional as F
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Optional from typing import Any, Optional
@dataclass(eq=False) @dataclass(eq=False)
class CoTrackerData: class CoTrackerData:
""" """
Dataclass for storing video tracks data. Dataclass for storing video tracks data.
""" """
video: torch.Tensor # B, S, C, H, W video: torch.Tensor # B, S, C, H, W
trajectory: torch.Tensor # B, S, N, 2 trajectory: torch.Tensor # B, S, N, 2
visibility: torch.Tensor # B, S, N visibility: torch.Tensor # B, S, N
# optional data # optional data
valid: Optional[torch.Tensor] = None # B, S, N valid: Optional[torch.Tensor] = None # B, S, N
segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
seq_name: Optional[str] = None seq_name: Optional[str] = None
query_points: Optional[torch.Tensor] = None # TAP-Vid evaluation format query_points: Optional[torch.Tensor] = None # TAP-Vid evaluation format
def collate_fn(batch): def collate_fn(batch):
""" """
Collate function for video tracks data. Collate function for video tracks data.
""" """
video = torch.stack([b.video for b in batch], dim=0) video = torch.stack([b.video for b in batch], dim=0)
trajectory = torch.stack([b.trajectory for b in batch], dim=0) trajectory = torch.stack([b.trajectory for b in batch], dim=0)
visibility = torch.stack([b.visibility for b in batch], dim=0) visibility = torch.stack([b.visibility for b in batch], dim=0)
query_points = segmentation = None query_points = segmentation = None
if batch[0].query_points is not None: if batch[0].query_points is not None:
query_points = torch.stack([b.query_points for b in batch], dim=0) query_points = torch.stack([b.query_points for b in batch], dim=0)
if batch[0].segmentation is not None: if batch[0].segmentation is not None:
segmentation = torch.stack([b.segmentation for b in batch], dim=0) segmentation = torch.stack([b.segmentation for b in batch], dim=0)
seq_name = [b.seq_name for b in batch] seq_name = [b.seq_name for b in batch]
return CoTrackerData( return CoTrackerData(
video=video, video=video,
trajectory=trajectory, trajectory=trajectory,
visibility=visibility, visibility=visibility,
segmentation=segmentation, segmentation=segmentation,
seq_name=seq_name, seq_name=seq_name,
query_points=query_points, query_points=query_points,
) )
def collate_fn_train(batch): def collate_fn_train(batch):
""" """
Collate function for video tracks data during training. Collate function for video tracks data during training.
""" """
gotit = [gotit for _, gotit in batch] gotit = [gotit for _, gotit in batch]
video = torch.stack([b.video for b, _ in batch], dim=0) video = torch.stack([b.video for b, _ in batch], dim=0)
trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0) trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
visibility = torch.stack([b.visibility for b, _ in batch], dim=0) visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
valid = torch.stack([b.valid for b, _ in batch], dim=0) valid = torch.stack([b.valid for b, _ in batch], dim=0)
seq_name = [b.seq_name for b, _ in batch] seq_name = [b.seq_name for b, _ in batch]
return ( return (
CoTrackerData( CoTrackerData(
video=video, video=video,
trajectory=trajectory, trajectory=trajectory,
visibility=visibility, visibility=visibility,
valid=valid, valid=valid,
seq_name=seq_name, seq_name=seq_name,
), ),
gotit, gotit,
) )
def try_to_cuda(t: Any) -> Any: def try_to_cuda(t: Any) -> Any:
""" """
Try to move the input variable `t` to a cuda device. Try to move the input variable `t` to a cuda device.
Args: Args:
t: Input. t: Input.
Returns: Returns:
t_cuda: `t` moved to a cuda device, if supported. t_cuda: `t` moved to a cuda device, if supported.
""" """
try: try:
t = t.float().cuda() t = t.float().cuda()
except AttributeError: except AttributeError:
pass pass
return t return t
def dataclass_to_cuda_(obj): def dataclass_to_cuda_(obj):
""" """
Move all contents of a dataclass to cuda inplace if supported. Move all contents of a dataclass to cuda inplace if supported.
Args: Args:
obj: Input dataclass. obj: Input dataclass.
Returns: Returns:
obj_cuda: `obj` moved to a cuda device, if supported. obj_cuda: `obj` moved to a cuda device, if supported.
""" """
for f in dataclasses.fields(obj): for f in dataclasses.fields(obj):
setattr(obj, f.name, try_to_cuda(getattr(obj, f.name))) setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
return obj return obj
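A small, self-contained sketch of how these helpers fit together (dummy tensors only; the shapes are illustrative):

import torch
from cotracker.datasets.utils import CoTrackerData, collate_fn, dataclass_to_cuda_

# Two dummy samples: 8 frames, 4 tracked points each.
samples = [
    CoTrackerData(
        video=torch.zeros(8, 3, 64, 64),
        trajectory=torch.zeros(8, 4, 2),
        visibility=torch.ones(8, 4),
        seq_name=f"dummy_{i}",
    )
    for i in range(2)
]
batch = collate_fn(samples)    # stacks tensor fields along a new batch dimension
print(batch.video.shape)       # torch.Size([2, 8, 3, 64, 64])
if torch.cuda.is_available():
    dataclass_to_cuda_(batch)  # tensor fields move to the GPU; seq_name is left alone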


@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.


@@ -1,6 +1,6 @@
defaults: defaults:
- default_config_eval - default_config_eval
exp_dir: ./outputs/cotracker exp_dir: ./outputs/cotracker
dataset_name: dynamic_replica dataset_name: dynamic_replica


@@ -1,6 +1,6 @@
defaults: defaults:
- default_config_eval - default_config_eval
exp_dir: ./outputs/cotracker exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_first dataset_name: tapvid_davis_first


@@ -1,6 +1,6 @@
defaults: defaults:
- default_config_eval - default_config_eval
exp_dir: ./outputs/cotracker exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_strided dataset_name: tapvid_davis_strided


@@ -1,6 +1,6 @@
defaults: defaults:
- default_config_eval - default_config_eval
exp_dir: ./outputs/cotracker exp_dir: ./outputs/cotracker
dataset_name: tapvid_kinetics_first dataset_name: tapvid_kinetics_first


@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.


@@ -1,138 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import numpy as np import numpy as np
from typing import Iterable, Mapping, Tuple, Union from typing import Iterable, Mapping, Tuple, Union
def compute_tapvid_metrics( def compute_tapvid_metrics(
query_points: np.ndarray, query_points: np.ndarray,
gt_occluded: np.ndarray, gt_occluded: np.ndarray,
gt_tracks: np.ndarray, gt_tracks: np.ndarray,
pred_occluded: np.ndarray, pred_occluded: np.ndarray,
pred_tracks: np.ndarray, pred_tracks: np.ndarray,
query_mode: str, query_mode: str,
) -> Mapping[str, np.ndarray]: ) -> Mapping[str, np.ndarray]:
"""Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.) """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
See the TAP-Vid paper for details on the metric computation. All inputs are See the TAP-Vid paper for details on the metric computation. All inputs are
given in raster coordinates. The first three arguments should be the direct given in raster coordinates. The first three arguments should be the direct
outputs of the reader: the 'query_points', 'occluded', and 'target_points'. outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
The paper metrics assume these are scaled relative to 256x256 images. The paper metrics assume these are scaled relative to 256x256 images.
pred_occluded and pred_tracks are your algorithm's predictions. pred_occluded and pred_tracks are your algorithm's predictions.
This function takes a batch of inputs, and computes metrics separately for This function takes a batch of inputs, and computes metrics separately for
each video. The metrics for the full benchmark are a simple mean of the each video. The metrics for the full benchmark are a simple mean of the
metrics across the full set of videos. These numbers are between 0 and 1, metrics across the full set of videos. These numbers are between 0 and 1,
but the paper multiplies them by 100 to ease reading. but the paper multiplies them by 100 to ease reading.
Args: Args:
query_points: The query points, in the format [t, y, x]. Its size is query_points: The query points, in the format [t, y, x]. Its size is
[b, n, 3], where b is the batch size and n is the number of queries. [b, n, 3], where b is the batch size and n is the number of queries.
gt_occluded: A boolean array of shape [b, n, t], where t is the number gt_occluded: A boolean array of shape [b, n, t], where t is the number
of frames. True indicates that the point is occluded. of frames. True indicates that the point is occluded.
gt_tracks: The target points, of shape [b, n, t, 2]. Each point is gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
in the format [x, y] in the format [x, y]
pred_occluded: A boolean array of predicted occlusions, in the same pred_occluded: A boolean array of predicted occlusions, in the same
format as gt_occluded. format as gt_occluded.
pred_tracks: An array of track predictions from your algorithm, in the pred_tracks: An array of track predictions from your algorithm, in the
same format as gt_tracks. same format as gt_tracks.
query_mode: Either 'first' or 'strided', depending on how queries are query_mode: Either 'first' or 'strided', depending on how queries are
sampled. If 'first', we assume the prior knowledge that all points sampled. If 'first', we assume the prior knowledge that all points
before the query point are occluded, and these are removed from the before the query point are occluded, and these are removed from the
evaluation. evaluation.
Returns: Returns:
A dict with the following keys: A dict with the following keys:
occlusion_accuracy: Accuracy at predicting occlusion. occlusion_accuracy: Accuracy at predicting occlusion.
pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
predicted to be within the given pixel threshold, ignoring occlusion predicted to be within the given pixel threshold, ignoring occlusion
prediction. prediction.
jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
threshold threshold
average_pts_within_thresh: average across pts_within_{x} average_pts_within_thresh: average across pts_within_{x}
average_jaccard: average across jaccard_{x} average_jaccard: average across jaccard_{x}
""" """
metrics = {} metrics = {}
# Fixed bug is described in: # Fixed bug is described in:
# https://github.com/facebookresearch/co-tracker/issues/20 # https://github.com/facebookresearch/co-tracker/issues/20
eye = np.eye(gt_tracks.shape[2], dtype=np.int32) eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
if query_mode == "first": if query_mode == "first":
# evaluate frames after the query frame # evaluate frames after the query frame
query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
elif query_mode == "strided": elif query_mode == "strided":
# evaluate all frames except the query frame # evaluate all frames except the query frame
query_frame_to_eval_frames = 1 - eye query_frame_to_eval_frames = 1 - eye
else: else:
raise ValueError("Unknown query mode " + query_mode) raise ValueError("Unknown query mode " + query_mode)
query_frame = query_points[..., 0] query_frame = query_points[..., 0]
query_frame = np.round(query_frame).astype(np.int32) query_frame = np.round(query_frame).astype(np.int32)
evaluation_points = query_frame_to_eval_frames[query_frame] > 0 evaluation_points = query_frame_to_eval_frames[query_frame] > 0
# Occlusion accuracy is simply how often the predicted occlusion equals the # Occlusion accuracy is simply how often the predicted occlusion equals the
# ground truth. # ground truth.
occ_acc = np.sum( occ_acc = np.sum(
np.equal(pred_occluded, gt_occluded) & evaluation_points, np.equal(pred_occluded, gt_occluded) & evaluation_points,
axis=(1, 2), axis=(1, 2),
) / np.sum(evaluation_points) ) / np.sum(evaluation_points)
metrics["occlusion_accuracy"] = occ_acc metrics["occlusion_accuracy"] = occ_acc
# Next, convert the predictions and ground truth positions into pixel # Next, convert the predictions and ground truth positions into pixel
# coordinates. # coordinates.
visible = np.logical_not(gt_occluded) visible = np.logical_not(gt_occluded)
pred_visible = np.logical_not(pred_occluded) pred_visible = np.logical_not(pred_occluded)
all_frac_within = [] all_frac_within = []
all_jaccard = [] all_jaccard = []
for thresh in [1, 2, 4, 8, 16]: for thresh in [1, 2, 4, 8, 16]:
# True positives are points that are within the threshold and where both # True positives are points that are within the threshold and where both
# the prediction and the ground truth are listed as visible. # the prediction and the ground truth are listed as visible.
within_dist = np.sum( within_dist = np.sum(
np.square(pred_tracks - gt_tracks), np.square(pred_tracks - gt_tracks),
axis=-1, axis=-1,
) < np.square(thresh) ) < np.square(thresh)
is_correct = np.logical_and(within_dist, visible) is_correct = np.logical_and(within_dist, visible)
# Compute the frac_within_threshold, which is the fraction of points # Compute the frac_within_threshold, which is the fraction of points
# within the threshold among points that are visible in the ground truth, # within the threshold among points that are visible in the ground truth,
# ignoring whether they're predicted to be visible. # ignoring whether they're predicted to be visible.
count_correct = np.sum( count_correct = np.sum(
is_correct & evaluation_points, is_correct & evaluation_points,
axis=(1, 2), axis=(1, 2),
) )
count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2)) count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
frac_correct = count_correct / count_visible_points frac_correct = count_correct / count_visible_points
metrics["pts_within_" + str(thresh)] = frac_correct metrics["pts_within_" + str(thresh)] = frac_correct
all_frac_within.append(frac_correct) all_frac_within.append(frac_correct)
true_positives = np.sum( true_positives = np.sum(
is_correct & pred_visible & evaluation_points, axis=(1, 2) is_correct & pred_visible & evaluation_points, axis=(1, 2)
) )
# The denominator of the jaccard metric is the true positives plus # The denominator of the jaccard metric is the true positives plus
# false positives plus false negatives. However, note that true positives # false positives plus false negatives. However, note that true positives
# plus false negatives is simply the number of points in the ground truth # plus false negatives is simply the number of points in the ground truth
# which is easier to compute than trying to compute all three quantities. # which is easier to compute than trying to compute all three quantities.
# Thus we just add the number of points in the ground truth to the number # Thus we just add the number of points in the ground truth to the number
# of false positives. # of false positives.
# #
# False positives are simply points that are predicted to be visible, # False positives are simply points that are predicted to be visible,
# but the ground truth is not visible or too far from the prediction. # but the ground truth is not visible or too far from the prediction.
gt_positives = np.sum(visible & evaluation_points, axis=(1, 2)) gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
false_positives = (~visible) & pred_visible false_positives = (~visible) & pred_visible
false_positives = false_positives | ((~within_dist) & pred_visible) false_positives = false_positives | ((~within_dist) & pred_visible)
false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2)) false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
jaccard = true_positives / (gt_positives + false_positives) jaccard = true_positives / (gt_positives + false_positives)
metrics["jaccard_" + str(thresh)] = jaccard metrics["jaccard_" + str(thresh)] = jaccard
all_jaccard.append(jaccard) all_jaccard.append(jaccard)
metrics["average_jaccard"] = np.mean( metrics["average_jaccard"] = np.mean(
np.stack(all_jaccard, axis=1), np.stack(all_jaccard, axis=1),
axis=1, axis=1,
) )
metrics["average_pts_within_thresh"] = np.mean( metrics["average_pts_within_thresh"] = np.mean(
np.stack(all_frac_within, axis=1), np.stack(all_frac_within, axis=1),
axis=1, axis=1,
) )
return metrics return metrics
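For a given threshold, the Jaccard score computed above is TP / (gt_visible + FP), i.e. true positives over true positives plus false positives plus false negatives, as the inline comments explain. A toy call with synthetic arrays (shapes follow the docstring; the numbers themselves mean nothing):

import numpy as np
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics

rng = np.random.default_rng(0)
b, n, t = 1, 5, 24
query_points = np.zeros((b, n, 3))                  # all points queried at frame 0, (t, y, x)
gt_tracks = rng.uniform(0, 255, size=(b, n, t, 2))  # raster coordinates in a 256x256 frame
gt_occluded = np.zeros((b, n, t), dtype=bool)
pred_tracks = gt_tracks + rng.normal(0, 2, size=gt_tracks.shape)  # mildly noisy predictions
pred_occluded = np.zeros_like(gt_occluded)

metrics = compute_tapvid_metrics(
    query_points, gt_occluded, gt_tracks, pred_occluded, pred_tracks, query_mode="first"
)
print(metrics["average_jaccard"], metrics["average_pts_within_thresh"])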


@@ -1,253 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from collections import defaultdict from collections import defaultdict
import os import os
from typing import Optional from typing import Optional
import torch import torch
from tqdm import tqdm from tqdm import tqdm
import numpy as np import numpy as np
from torch.utils.tensorboard import SummaryWriter from torch.utils.tensorboard import SummaryWriter
from cotracker.datasets.utils import dataclass_to_cuda_ from cotracker.datasets.utils import dataclass_to_cuda_
from cotracker.utils.visualizer import Visualizer from cotracker.utils.visualizer import Visualizer
from cotracker.models.core.model_utils import reduce_masked_mean from cotracker.models.core.model_utils import reduce_masked_mean
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
import logging import logging
class Evaluator: class Evaluator:
""" """
A class defining the CoTracker evaluator. A class defining the CoTracker evaluator.
""" """
def __init__(self, exp_dir) -> None: def __init__(self, exp_dir) -> None:
# Visualization # Visualization
self.exp_dir = exp_dir self.exp_dir = exp_dir
os.makedirs(exp_dir, exist_ok=True) os.makedirs(exp_dir, exist_ok=True)
self.visualization_filepaths = defaultdict(lambda: defaultdict(list)) self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
self.visualize_dir = os.path.join(exp_dir, "visualisations") self.visualize_dir = os.path.join(exp_dir, "visualisations")
def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name): def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
if isinstance(pred_trajectory, tuple): if isinstance(pred_trajectory, tuple):
pred_trajectory, pred_visibility = pred_trajectory pred_trajectory, pred_visibility = pred_trajectory
else: else:
pred_visibility = None pred_visibility = None
if "tapvid" in dataset_name: if "tapvid" in dataset_name:
B, T, N, D = sample.trajectory.shape B, T, N, D = sample.trajectory.shape
traj = sample.trajectory.clone() traj = sample.trajectory.clone()
thr = 0.9 thr = 0.9
if pred_visibility is None: if pred_visibility is None:
logging.warning("visibility is NONE") logging.warning("visibility is NONE")
pred_visibility = torch.zeros_like(sample.visibility) pred_visibility = torch.zeros_like(sample.visibility)
if not pred_visibility.dtype == torch.bool: if not pred_visibility.dtype == torch.bool:
pred_visibility = pred_visibility > thr pred_visibility = pred_visibility > thr
query_points = sample.query_points.clone().cpu().numpy() query_points = sample.query_points.clone().cpu().numpy()
pred_visibility = pred_visibility[:, :, :N] pred_visibility = pred_visibility[:, :, :N]
pred_trajectory = pred_trajectory[:, :, :N] pred_trajectory = pred_trajectory[:, :, :N]
gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy() gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
gt_occluded = ( gt_occluded = (
torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy() torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
) )
pred_occluded = ( pred_occluded = (
torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy() torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
) )
pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy() pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
out_metrics = compute_tapvid_metrics( out_metrics = compute_tapvid_metrics(
query_points, query_points,
gt_occluded, gt_occluded,
gt_tracks, gt_tracks,
pred_occluded, pred_occluded,
pred_tracks, pred_tracks,
query_mode="strided" if "strided" in dataset_name else "first", query_mode="strided" if "strided" in dataset_name else "first",
) )
metrics[sample.seq_name[0]] = out_metrics metrics[sample.seq_name[0]] = out_metrics
for metric_name in out_metrics.keys(): for metric_name in out_metrics.keys():
if "avg" not in metrics: if "avg" not in metrics:
metrics["avg"] = {} metrics["avg"] = {}
metrics["avg"][metric_name] = np.mean( metrics["avg"][metric_name] = np.mean(
[v[metric_name] for k, v in metrics.items() if k != "avg"] [v[metric_name] for k, v in metrics.items() if k != "avg"]
) )
logging.info(f"Metrics: {out_metrics}") logging.info(f"Metrics: {out_metrics}")
logging.info(f"avg: {metrics['avg']}") logging.info(f"avg: {metrics['avg']}")
print("metrics", out_metrics) print("metrics", out_metrics)
print("avg", metrics["avg"]) print("avg", metrics["avg"])
elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey": elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
*_, N, _ = sample.trajectory.shape *_, N, _ = sample.trajectory.shape
B, T, N = sample.visibility.shape B, T, N = sample.visibility.shape
H, W = sample.video.shape[-2:] H, W = sample.video.shape[-2:]
device = sample.video.device device = sample.video.device
out_metrics = {} out_metrics = {}
d_vis_sum = d_occ_sum = d_sum_all = 0.0 d_vis_sum = d_occ_sum = d_sum_all = 0.0
thrs = [1, 2, 4, 8, 16] thrs = [1, 2, 4, 8, 16]
sx_ = (W - 1) / 255.0 sx_ = (W - 1) / 255.0
sy_ = (H - 1) / 255.0 sy_ = (H - 1) / 255.0
sc_py = np.array([sx_, sy_]).reshape([1, 1, 2]) sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
sc_pt = torch.from_numpy(sc_py).float().to(device) sc_pt = torch.from_numpy(sc_py).float().to(device)
__, first_visible_inds = torch.max(sample.visibility, dim=1) __, first_visible_inds = torch.max(sample.visibility, dim=1)
frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N) frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1)) start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))
for thr in thrs: for thr in thrs:
d_ = ( d_ = (
torch.norm( torch.norm(
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt, pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
dim=-1, dim=-1,
) )
< thr < thr
).float() # B,S-1,N ).float() # B,S-1,N
d_occ = ( d_occ = (
reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item() reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
* 100.0 * 100.0
) )
d_occ_sum += d_occ d_occ_sum += d_occ
out_metrics[f"accuracy_occ_{thr}"] = d_occ out_metrics[f"accuracy_occ_{thr}"] = d_occ
d_vis = ( d_vis = (
reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0 reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
) )
d_vis_sum += d_vis d_vis_sum += d_vis
out_metrics[f"accuracy_vis_{thr}"] = d_vis out_metrics[f"accuracy_vis_{thr}"] = d_vis
d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0 d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
d_sum_all += d_all d_sum_all += d_all
out_metrics[f"accuracy_{thr}"] = d_all out_metrics[f"accuracy_{thr}"] = d_all
d_occ_avg = d_occ_sum / len(thrs) d_occ_avg = d_occ_sum / len(thrs)
d_vis_avg = d_vis_sum / len(thrs) d_vis_avg = d_vis_sum / len(thrs)
d_all_avg = d_sum_all / len(thrs) d_all_avg = d_sum_all / len(thrs)
sur_thr = 50 sur_thr = 50
dists = torch.norm( dists = torch.norm(
pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt, pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
dim=-1, dim=-1,
) # B,S,N ) # B,S,N
dist_ok = 1 - (dists > sur_thr).float() * sample.visibility # B,S,N dist_ok = 1 - (dists > sur_thr).float() * sample.visibility # B,S,N
survival = torch.cumprod(dist_ok, dim=1) # B,S,N survival = torch.cumprod(dist_ok, dim=1) # B,S,N
out_metrics["survival"] = torch.mean(survival).item() * 100.0 out_metrics["survival"] = torch.mean(survival).item() * 100.0
out_metrics["accuracy_occ"] = d_occ_avg out_metrics["accuracy_occ"] = d_occ_avg
out_metrics["accuracy_vis"] = d_vis_avg out_metrics["accuracy_vis"] = d_vis_avg
out_metrics["accuracy"] = d_all_avg out_metrics["accuracy"] = d_all_avg
metrics[sample.seq_name[0]] = out_metrics metrics[sample.seq_name[0]] = out_metrics
for metric_name in out_metrics.keys(): for metric_name in out_metrics.keys():
if "avg" not in metrics: if "avg" not in metrics:
metrics["avg"] = {} metrics["avg"] = {}
metrics["avg"][metric_name] = float( metrics["avg"][metric_name] = float(
np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"]) np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
) )
logging.info(f"Metrics: {out_metrics}") logging.info(f"Metrics: {out_metrics}")
logging.info(f"avg: {metrics['avg']}") logging.info(f"avg: {metrics['avg']}")
print("metrics", out_metrics) print("metrics", out_metrics)
print("avg", metrics["avg"]) print("avg", metrics["avg"])
@torch.no_grad() @torch.no_grad()
def evaluate_sequence( def evaluate_sequence(
self, self,
model, model,
test_dataloader: torch.utils.data.DataLoader, test_dataloader: torch.utils.data.DataLoader,
dataset_name: str, dataset_name: str,
train_mode=False, train_mode=False,
visualize_every: int = 1, visualize_every: int = 1,
writer: Optional[SummaryWriter] = None, writer: Optional[SummaryWriter] = None,
step: Optional[int] = 0, step: Optional[int] = 0,
): ):
metrics = {} metrics = {}
vis = Visualizer( vis = Visualizer(
save_dir=self.exp_dir, save_dir=self.exp_dir,
fps=7, fps=7,
) )
for ind, sample in enumerate(tqdm(test_dataloader)): for ind, sample in enumerate(tqdm(test_dataloader)):
if isinstance(sample, tuple): if isinstance(sample, tuple):
sample, gotit = sample sample, gotit = sample
if not all(gotit): if not all(gotit):
print("batch is None") print("batch is None")
continue continue
if torch.cuda.is_available(): if torch.cuda.is_available():
dataclass_to_cuda_(sample) dataclass_to_cuda_(sample)
device = torch.device("cuda") device = torch.device("cuda")
else: else:
device = torch.device("cpu") device = torch.device("cpu")
if ( if (
not train_mode not train_mode
and hasattr(model, "sequence_len") and hasattr(model, "sequence_len")
and (sample.visibility[:, : model.sequence_len].sum() == 0) and (sample.visibility[:, : model.sequence_len].sum() == 0)
): ):
print(f"skipping batch {ind}") print(f"skipping batch {ind}")
continue continue
if "tapvid" in dataset_name: if "tapvid" in dataset_name:
queries = sample.query_points.clone().float() queries = sample.query_points.clone().float()
queries = torch.stack( queries = torch.stack(
[ [
queries[:, :, 0], queries[:, :, 0],
queries[:, :, 2], queries[:, :, 2],
queries[:, :, 1], queries[:, :, 1],
], ],
dim=2, dim=2,
).to(device) ).to(device)
else: else:
queries = torch.cat( queries = torch.cat(
[ [
torch.zeros_like(sample.trajectory[:, 0, :, :1]), torch.zeros_like(sample.trajectory[:, 0, :, :1]),
sample.trajectory[:, 0], sample.trajectory[:, 0],
], ],
dim=2, dim=2,
).to(device) ).to(device)
pred_tracks = model(sample.video, queries) pred_tracks = model(sample.video, queries)
if "strided" in dataset_name: if "strided" in dataset_name:
inv_video = sample.video.flip(1).clone() inv_video = sample.video.flip(1).clone()
inv_queries = queries.clone() inv_queries = queries.clone()
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1 inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
pred_trj, pred_vsb = pred_tracks pred_trj, pred_vsb = pred_tracks
inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries) inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
inv_pred_trj = inv_pred_trj.flip(1) inv_pred_trj = inv_pred_trj.flip(1)
inv_pred_vsb = inv_pred_vsb.flip(1) inv_pred_vsb = inv_pred_vsb.flip(1)
mask = pred_trj == 0 mask = pred_trj == 0
pred_trj[mask] = inv_pred_trj[mask] pred_trj[mask] = inv_pred_trj[mask]
pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]] pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
pred_tracks = pred_trj, pred_vsb pred_tracks = pred_trj, pred_vsb
if dataset_name == "badja" or dataset_name == "fastcapture": if dataset_name == "badja" or dataset_name == "fastcapture":
seq_name = sample.seq_name[0] seq_name = sample.seq_name[0]
else: else:
seq_name = str(ind) seq_name = str(ind)
if ind % visualize_every == 0: if ind % visualize_every == 0:
vis.visualize( vis.visualize(
sample.video, sample.video,
pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks, pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
filename=dataset_name + "_" + seq_name, filename=dataset_name + "_" + seq_name,
writer=writer, writer=writer,
step=step, step=step,
) )
self.compute_metrics(metrics, sample, pred_tracks, dataset_name) self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
return metrics return metrics
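Any callable with the interface model(video, queries) -> (tracks, visibility) can be dropped into evaluate_sequence above; a sketch with a trivial stay-put baseline (the output directory is made up, and test_dataloader is assumed to come from the datasets shown earlier):

import torch
from cotracker.evaluation.core.evaluator import Evaluator

class ConstantTracker(torch.nn.Module):
    # Toy baseline: every query point is predicted to stay at its query location.
    def forward(self, video, queries):
        B, T = video.shape[:2]
        N = queries.shape[1]
        tracks = queries[:, None, :, 1:].expand(B, T, N, 2).contiguous()  # (B, T, N, 2)
        visibility = torch.ones(B, T, N, device=video.device)             # always visible
        return tracks, visibility

evaluator = Evaluator(exp_dir="./outputs/sanity_check")  # hypothetical output dir
# metrics = evaluator.evaluate_sequence(
#     ConstantTracker(), test_dataloader, dataset_name="tapvid_davis_first"
# )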


@@ -1,169 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import json import json
import os import os
from dataclasses import dataclass, field from dataclasses import dataclass, field
import hydra import hydra
import numpy as np import numpy as np
import torch import torch
from omegaconf import OmegaConf from omegaconf import OmegaConf
from cotracker.datasets.tap_vid_datasets import TapVidDataset from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.dr_dataset import DynamicReplicaDataset from cotracker.datasets.dr_dataset import DynamicReplicaDataset
from cotracker.datasets.utils import collate_fn from cotracker.datasets.utils import collate_fn
from cotracker.models.evaluation_predictor import EvaluationPredictor from cotracker.models.evaluation_predictor import EvaluationPredictor
from cotracker.evaluation.core.evaluator import Evaluator from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.models.build_cotracker import ( from cotracker.models.build_cotracker import (
build_cotracker, build_cotracker,
) )
@dataclass(eq=False) @dataclass(eq=False)
class DefaultConfig: class DefaultConfig:
# Directory where all outputs of the experiment will be saved. # Directory where all outputs of the experiment will be saved.
exp_dir: str = "./outputs" exp_dir: str = "./outputs"
# Name of the dataset to be used for the evaluation. # Name of the dataset to be used for the evaluation.
dataset_name: str = "tapvid_davis_first" dataset_name: str = "tapvid_davis_first"
# The root directory of the dataset. # The root directory of the dataset.
dataset_root: str = "./" dataset_root: str = "./"
# Path to the pre-trained model checkpoint to be used for the evaluation. # Path to the pre-trained model checkpoint to be used for the evaluation.
# The default value is the path to a specific CoTracker model checkpoint. # The default value is the path to a specific CoTracker model checkpoint.
checkpoint: str = "./checkpoints/cotracker2.pth" checkpoint: str = "./checkpoints/cotracker2.pth"
# EvaluationPredictor parameters # EvaluationPredictor parameters
# The size (N) of the support grid used in the predictor. # The size (N) of the support grid used in the predictor.
# The total number of points is (N*N). # The total number of points is (N*N).
grid_size: int = 5 grid_size: int = 5
# The size (N) of the local support grid. # The size (N) of the local support grid.
local_grid_size: int = 8 local_grid_size: int = 8
# A flag indicating whether to evaluate one ground truth point at a time. # A flag indicating whether to evaluate one ground truth point at a time.
single_point: bool = True single_point: bool = True
# The number of iterative updates for each sliding window. # The number of iterative updates for each sliding window.
n_iters: int = 6 n_iters: int = 6
seed: int = 0 seed: int = 0
gpu_idx: int = 0 gpu_idx: int = 0
# Override hydra's working directory to current working dir, # Override hydra's working directory to current working dir,
# also disable storing the .hydra logs: # also disable storing the .hydra logs:
hydra: dict = field( hydra: dict = field(
default_factory=lambda: { default_factory=lambda: {
"run": {"dir": "."}, "run": {"dir": "."},
"output_subdir": None, "output_subdir": None,
} }
) )
def run_eval(cfg: DefaultConfig): def run_eval(cfg: DefaultConfig):
""" """
The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration. The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
Args: Args:
cfg (DefaultConfig): An instance of DefaultConfig class which includes: cfg (DefaultConfig): An instance of DefaultConfig class which includes:
- exp_dir (str): The directory path for the experiment. - exp_dir (str): The directory path for the experiment.
- dataset_name (str): The name of the dataset to be used. - dataset_name (str): The name of the dataset to be used.
- dataset_root (str): The root directory of the dataset. - dataset_root (str): The root directory of the dataset.
- checkpoint (str): The path to the CoTracker model's checkpoint. - checkpoint (str): The path to the CoTracker model's checkpoint.
- single_point (bool): A flag indicating whether to evaluate one ground truth point at a time. - single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
- n_iters (int): The number of iterative updates for each sliding window. - n_iters (int): The number of iterative updates for each sliding window.
- seed (int): The seed for setting the random state for reproducibility. - seed (int): The seed for setting the random state for reproducibility.
- gpu_idx (int): The index of the GPU to be used. - gpu_idx (int): The index of the GPU to be used.
""" """
# Creating the experiment directory if it doesn't exist # Creating the experiment directory if it doesn't exist
os.makedirs(cfg.exp_dir, exist_ok=True) os.makedirs(cfg.exp_dir, exist_ok=True)
# Saving the experiment configuration to a .yaml file in the experiment directory # Saving the experiment configuration to a .yaml file in the experiment directory
cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml") cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
with open(cfg_file, "w") as f: with open(cfg_file, "w") as f:
OmegaConf.save(config=cfg, f=f) OmegaConf.save(config=cfg, f=f)
evaluator = Evaluator(cfg.exp_dir) evaluator = Evaluator(cfg.exp_dir)
cotracker_model = build_cotracker(cfg.checkpoint) cotracker_model = build_cotracker(cfg.checkpoint)
# Creating the EvaluationPredictor object # Creating the EvaluationPredictor object
predictor = EvaluationPredictor( predictor = EvaluationPredictor(
cotracker_model, cotracker_model,
grid_size=cfg.grid_size, grid_size=cfg.grid_size,
local_grid_size=cfg.local_grid_size, local_grid_size=cfg.local_grid_size,
single_point=cfg.single_point, single_point=cfg.single_point,
n_iters=cfg.n_iters, n_iters=cfg.n_iters,
) )
if torch.cuda.is_available(): if torch.cuda.is_available():
predictor.model = predictor.model.cuda() predictor.model = predictor.model.cuda()
# Setting the random seeds # Setting the random seeds
torch.manual_seed(cfg.seed) torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed) np.random.seed(cfg.seed)
# Constructing the specified dataset # Constructing the specified dataset
curr_collate_fn = collate_fn curr_collate_fn = collate_fn
if "tapvid" in cfg.dataset_name: if "tapvid" in cfg.dataset_name:
dataset_type = cfg.dataset_name.split("_")[1] dataset_type = cfg.dataset_name.split("_")[1]
if dataset_type == "davis": if dataset_type == "davis":
data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl") data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl")
elif dataset_type == "kinetics": elif dataset_type == "kinetics":
data_root = os.path.join( data_root = os.path.join(
cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics" cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
) )
test_dataset = TapVidDataset( test_dataset = TapVidDataset(
dataset_type=dataset_type, dataset_type=dataset_type,
data_root=data_root, data_root=data_root,
queried_first=not "strided" in cfg.dataset_name, queried_first=not "strided" in cfg.dataset_name,
) )
elif cfg.dataset_name == "dynamic_replica": elif cfg.dataset_name == "dynamic_replica":
test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1) test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1)
# Creating the DataLoader object # Creating the DataLoader object
test_dataloader = torch.utils.data.DataLoader( test_dataloader = torch.utils.data.DataLoader(
test_dataset, test_dataset,
batch_size=1, batch_size=1,
shuffle=False, shuffle=False,
num_workers=14, num_workers=14,
collate_fn=curr_collate_fn, collate_fn=curr_collate_fn,
) )
# Timing and conducting the evaluation # Timing and conducting the evaluation
import time import time
start = time.time() start = time.time()
evaluate_result = evaluator.evaluate_sequence( evaluate_result = evaluator.evaluate_sequence(
predictor, predictor,
test_dataloader, test_dataloader,
dataset_name=cfg.dataset_name, dataset_name=cfg.dataset_name,
) )
end = time.time() end = time.time()
print(end - start) print(end - start)
# Saving the evaluation results to a .json file # Saving the evaluation results to a .json file
evaluate_result = evaluate_result["avg"] evaluate_result = evaluate_result["avg"]
print("evaluate_result", evaluate_result) print("evaluate_result", evaluate_result)
result_file = os.path.join(cfg.exp_dir, f"result_eval_.json") result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
evaluate_result["time"] = end - start evaluate_result["time"] = end - start
print(f"Dumping eval results to {result_file}.") print(f"Dumping eval results to {result_file}.")
with open(result_file, "w") as f: with open(result_file, "w") as f:
json.dump(evaluate_result, f) json.dump(evaluate_result, f)
cs = hydra.core.config_store.ConfigStore.instance() cs = hydra.core.config_store.ConfigStore.instance()
cs.store(name="default_config_eval", node=DefaultConfig) cs.store(name="default_config_eval", node=DefaultConfig)
@hydra.main(config_path="./configs/", config_name="default_config_eval") @hydra.main(config_path="./configs/", config_name="default_config_eval")
def evaluate(cfg: DefaultConfig) -> None: def evaluate(cfg: DefaultConfig) -> None:
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx) os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
run_eval(cfg) run_eval(cfg)
if __name__ == "__main__": if __name__ == "__main__":
evaluate() evaluate()
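The same evaluation can also be launched without Hydra by filling in the dataclass directly; the paths below are placeholders for wherever the checkpoint and datasets actually live:

from cotracker.evaluation.evaluate import DefaultConfig, run_eval  # module path assumed

cfg = DefaultConfig(
    exp_dir="./outputs/davis_first_eval",
    dataset_name="tapvid_davis_first",
    dataset_root="./datasets",
    checkpoint="./checkpoints/cotracker2.pth",
)
run_eval(cfg)  # writes expconfig.yaml and result_eval_.json into exp_dir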


@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.

Binary file not shown.

Binary file not shown.


@@ -1,33 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
from cotracker.models.core.cotracker.cotracker import CoTracker2 from cotracker.models.core.cotracker.cotracker import CoTracker2
def build_cotracker( def build_cotracker(
checkpoint: str, checkpoint: str,
): ):
if checkpoint is None: if checkpoint is None:
return build_cotracker() return build_cotracker()
model_name = checkpoint.split("/")[-1].split(".")[0] model_name = checkpoint.split("/")[-1].split(".")[0]
if model_name == "cotracker": if model_name == "cotracker":
return build_cotracker(checkpoint=checkpoint) return build_cotracker(checkpoint=checkpoint)
else: else:
raise ValueError(f"Unknown model name {model_name}") raise ValueError(f"Unknown model name {model_name}")
def build_cotracker(checkpoint=None): def build_cotracker(checkpoint=None):
cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True) cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True)
if checkpoint is not None: if checkpoint is not None:
with open(checkpoint, "rb") as f: with open(checkpoint, "rb") as f:
state_dict = torch.load(f, map_location="cpu") state_dict = torch.load(f, map_location="cpu")
if "model" in state_dict: if "model" in state_dict:
state_dict = state_dict["model"] state_dict = state_dict["model"]
cotracker.load_state_dict(state_dict) cotracker.load_state_dict(state_dict)
return cotracker return cotracker
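A sketch of how this builder is combined with the EvaluationPredictor constructed in the evaluation script above; the call signature mirrors how the evaluator invokes the predictor (the checkpoint path is an assumption, and a GPU may be needed for reasonable speed):

import torch
from cotracker.models.build_cotracker import build_cotracker
from cotracker.models.evaluation_predictor import EvaluationPredictor

model = build_cotracker("./checkpoints/cotracker2.pth")
predictor = EvaluationPredictor(
    model, grid_size=5, local_grid_size=8, single_point=True, n_iters=6
)
video = torch.zeros(1, 16, 3, 256, 256)          # B, T, C, H, W dummy clip
queries = torch.tensor([[[0.0, 128.0, 128.0]]])  # one query: (frame, x, y)
tracks, visibility = predictor(video, queries)   # (B, T, N, 2), (B, T, N)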


@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.


@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.


@@ -1,367 +1,368 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from functools import partial from functools import partial
from typing import Callable from typing import Callable
import collections import collections
from torch import Tensor from torch import Tensor
from itertools import repeat from itertools import repeat
from cotracker.models.core.model_utils import bilinear_sampler from cotracker.models.core.model_utils import bilinear_sampler
# From PyTorch internals # From PyTorch internals
def _ntuple(n): def _ntuple(n):
def parse(x): def parse(x):
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
return tuple(x) return tuple(x)
return tuple(repeat(x, n)) return tuple(repeat(x, n))
return parse return parse
def exists(val): def exists(val):
return val is not None return val is not None
def default(val, d): def default(val, d):
return val if exists(val) else d return val if exists(val) else d
to_2tuple = _ntuple(2) to_2tuple = _ntuple(2)
class Mlp(nn.Module): class Mlp(nn.Module):
"""MLP as used in Vision Transformer, MLP-Mixer and related networks""" """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
def __init__( def __init__(
self, self,
in_features, in_features,
hidden_features=None, hidden_features=None,
out_features=None, out_features=None,
act_layer=nn.GELU, act_layer=nn.GELU,
norm_layer=None, norm_layer=None,
bias=True, bias=True,
drop=0.0, drop=0.0,
use_conv=False, use_conv=False,
): ):
super().__init__() super().__init__()
out_features = out_features or in_features out_features = out_features or in_features
hidden_features = hidden_features or in_features hidden_features = hidden_features or in_features
bias = to_2tuple(bias) bias = to_2tuple(bias)
drop_probs = to_2tuple(drop) drop_probs = to_2tuple(drop)
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
self.act = act_layer() self.act = act_layer()
self.drop1 = nn.Dropout(drop_probs[0]) self.drop1 = nn.Dropout(drop_probs[0])
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
self.drop2 = nn.Dropout(drop_probs[1]) self.drop2 = nn.Dropout(drop_probs[1])
def forward(self, x): def forward(self, x):
x = self.fc1(x) x = self.fc1(x)
x = self.act(x) x = self.act(x)
x = self.drop1(x) x = self.drop1(x)
x = self.fc2(x) x = self.fc2(x)
x = self.drop2(x) x = self.drop2(x)
return x return x
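A quick shape check for the Mlp block above (dimensions are arbitrary):

import torch
from cotracker.models.core.cotracker.blocks import Mlp  # module path assumed

mlp = Mlp(in_features=128, hidden_features=512, drop=0.1)
tokens = torch.randn(2, 196, 128)  # (batch, tokens, channels)
out = mlp(tokens)                  # shape preserved: (2, 196, 128)
print(out.shape)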
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn="group", stride=1): def __init__(self, in_planes, planes, norm_fn="group", stride=1):
super(ResidualBlock, self).__init__() super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
in_planes, in_planes,
planes, planes,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
stride=stride, stride=stride,
padding_mode="zeros", padding_mode="zeros",
) )
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros") self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8 num_groups = planes // 8
if norm_fn == "group": if norm_fn == "group":
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
if not stride == 1: if not stride == 1:
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
elif norm_fn == "batch": elif norm_fn == "batch":
self.norm1 = nn.BatchNorm2d(planes) self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes) self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1: if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes) self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == "instance": elif norm_fn == "instance":
self.norm1 = nn.InstanceNorm2d(planes) self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes) self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1: if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes) self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == "none": elif norm_fn == "none":
self.norm1 = nn.Sequential() self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential() self.norm2 = nn.Sequential()
if not stride == 1: if not stride == 1:
self.norm3 = nn.Sequential() self.norm3 = nn.Sequential()
if stride == 1: if stride == 1:
self.downsample = None self.downsample = None
else: else:
self.downsample = nn.Sequential( self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3 nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
) )
def forward(self, x): def forward(self, x):
y = x y = x
y = self.relu(self.norm1(self.conv1(y))) y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y))) y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None: if self.downsample is not None:
x = self.downsample(x) x = self.downsample(x)
return self.relu(x + y) return self.relu(x + y)
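A similar check for ResidualBlock: with stride 2 the 1x1 downsample convolution matches the residual path, halving the spatial size (again arbitrary dimensions, module path assumed):

import torch
from cotracker.models.core.cotracker.blocks import ResidualBlock  # module path assumed

block = ResidualBlock(in_planes=64, planes=96, norm_fn="instance", stride=2)
x = torch.randn(1, 64, 48, 48)
print(block(x).shape)  # torch.Size([1, 96, 24, 24])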
class BasicEncoder(nn.Module): class BasicEncoder(nn.Module):
def __init__(self, input_dim=3, output_dim=128, stride=4): def __init__(self, input_dim=3, output_dim=128, stride=4):
super(BasicEncoder, self).__init__() super(BasicEncoder, self).__init__()
self.stride = stride self.stride = stride
self.norm_fn = "instance" self.norm_fn = "instance"
self.in_planes = output_dim // 2 self.in_planes = output_dim // 2
self.norm1 = nn.InstanceNorm2d(self.in_planes) self.norm1 = nn.InstanceNorm2d(self.in_planes)
self.norm2 = nn.InstanceNorm2d(output_dim * 2) self.norm2 = nn.InstanceNorm2d(output_dim * 2)
self.conv1 = nn.Conv2d( self.conv1 = nn.Conv2d(
input_dim, input_dim,
self.in_planes, self.in_planes,
kernel_size=7, kernel_size=7,
stride=2, stride=2,
padding=3, padding=3,
padding_mode="zeros", padding_mode="zeros",
) )
self.relu1 = nn.ReLU(inplace=True) self.relu1 = nn.ReLU(inplace=True)
self.layer1 = self._make_layer(output_dim // 2, stride=1) self.layer1 = self._make_layer(output_dim // 2, stride=1)
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2) self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
self.layer3 = self._make_layer(output_dim, stride=2) self.layer3 = self._make_layer(output_dim, stride=2)
self.layer4 = self._make_layer(output_dim, stride=2) self.layer4 = self._make_layer(output_dim, stride=2)
self.conv2 = nn.Conv2d( self.conv2 = nn.Conv2d(
output_dim * 3 + output_dim // 4, output_dim * 3 + output_dim // 4,
output_dim * 2, output_dim * 2,
kernel_size=3, kernel_size=3,
padding=1, padding=1,
padding_mode="zeros", padding_mode="zeros",
) )
self.relu2 = nn.ReLU(inplace=True) self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1) self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
for m in self.modules(): for m in self.modules():
if isinstance(m, nn.Conv2d): if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, (nn.InstanceNorm2d)): elif isinstance(m, (nn.InstanceNorm2d)):
if m.weight is not None: if m.weight is not None:
nn.init.constant_(m.weight, 1) nn.init.constant_(m.weight, 1)
if m.bias is not None: if m.bias is not None:
nn.init.constant_(m.bias, 0) nn.init.constant_(m.bias, 0)
def _make_layer(self, dim, stride=1): def _make_layer(self, dim, stride=1):
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
layers = (layer1, layer2) layers = (layer1, layer2)
self.in_planes = dim self.in_planes = dim
return nn.Sequential(*layers) return nn.Sequential(*layers)
def forward(self, x): def forward(self, x):
_, _, H, W = x.shape _, _, H, W = x.shape
x = self.conv1(x) x = self.conv1(x)
x = self.norm1(x) x = self.norm1(x)
x = self.relu1(x) x = self.relu1(x)
a = self.layer1(x) # 四层残差块
b = self.layer2(a) a = self.layer1(x)
c = self.layer3(b) b = self.layer2(a)
d = self.layer4(c) c = self.layer3(b)
d = self.layer4(c)
def _bilinear_intepolate(x):
return F.interpolate( def _bilinear_intepolate(x):
x, return F.interpolate(
(H // self.stride, W // self.stride), x,
mode="bilinear", (H // self.stride, W // self.stride),
align_corners=True, mode="bilinear",
) align_corners=True,
)
a = _bilinear_intepolate(a)
b = _bilinear_intepolate(b) a = _bilinear_intepolate(a)
c = _bilinear_intepolate(c) b = _bilinear_intepolate(b)
d = _bilinear_intepolate(d) c = _bilinear_intepolate(c)
d = _bilinear_intepolate(d)
x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.norm2(x) x = self.conv2(torch.cat([a, b, c, d], dim=1))
x = self.relu2(x) x = self.norm2(x)
x = self.conv3(x) x = self.relu2(x)
return x x = self.conv3(x)
return x
class CorrBlock:
def __init__( class CorrBlock:
self, def __init__(
fmaps, self,
num_levels=4, fmaps,
radius=4, num_levels=4,
multiple_track_feats=False, radius=4,
padding_mode="zeros", multiple_track_feats=False,
): padding_mode="zeros",
B, S, C, H, W = fmaps.shape ):
self.S, self.C, self.H, self.W = S, C, H, W B, S, C, H, W = fmaps.shape
self.padding_mode = padding_mode self.S, self.C, self.H, self.W = S, C, H, W
self.num_levels = num_levels self.padding_mode = padding_mode
self.radius = radius self.num_levels = num_levels
self.fmaps_pyramid = [] self.radius = radius
self.multiple_track_feats = multiple_track_feats self.fmaps_pyramid = []
self.multiple_track_feats = multiple_track_feats
self.fmaps_pyramid.append(fmaps)
for i in range(self.num_levels - 1): self.fmaps_pyramid.append(fmaps)
fmaps_ = fmaps.reshape(B * S, C, H, W) for i in range(self.num_levels - 1):
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2) fmaps_ = fmaps.reshape(B * S, C, H, W)
_, _, H, W = fmaps_.shape fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
fmaps = fmaps_.reshape(B, S, C, H, W) _, _, H, W = fmaps_.shape
self.fmaps_pyramid.append(fmaps) fmaps = fmaps_.reshape(B, S, C, H, W)
self.fmaps_pyramid.append(fmaps)
def sample(self, coords):
r = self.radius def sample(self, coords):
B, S, N, D = coords.shape r = self.radius
assert D == 2 B, S, N, D = coords.shape
assert D == 2
H, W = self.H, self.W
out_pyramid = [] H, W = self.H, self.W
for i in range(self.num_levels): out_pyramid = []
corrs = self.corrs_pyramid[i] # B, S, N, H, W for i in range(self.num_levels):
*_, H, W = corrs.shape corrs = self.corrs_pyramid[i] # B, S, N, H, W
*_, H, W = corrs.shape
dx = torch.linspace(-r, r, 2 * r + 1)
dy = torch.linspace(-r, r, 2 * r + 1) dx = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device) dy = torch.linspace(-r, r, 2 * r + 1)
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
coords_lvl = centroid_lvl + delta_lvl delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
coords_lvl = centroid_lvl + delta_lvl
corrs = bilinear_sampler(
corrs.reshape(B * S * N, 1, H, W), corrs = bilinear_sampler(
coords_lvl, corrs.reshape(B * S * N, 1, H, W),
padding_mode=self.padding_mode, coords_lvl,
) padding_mode=self.padding_mode,
corrs = corrs.view(B, S, N, -1) )
out_pyramid.append(corrs) corrs = corrs.view(B, S, N, -1)
out_pyramid.append(corrs)
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float() out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
return out out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
return out
def corr(self, targets):
B, S, N, C = targets.shape def corr(self, targets):
if self.multiple_track_feats: B, S, N, C = targets.shape
targets_split = targets.split(C // self.num_levels, dim=-1) if self.multiple_track_feats:
B, S, N, C = targets_split[0].shape targets_split = targets.split(C // self.num_levels, dim=-1)
B, S, N, C = targets_split[0].shape
assert C == self.C
assert S == self.S assert C == self.C
assert S == self.S
fmap1 = targets
fmap1 = targets
self.corrs_pyramid = []
for i, fmaps in enumerate(self.fmaps_pyramid): self.corrs_pyramid = []
*_, H, W = fmaps.shape for i, fmaps in enumerate(self.fmaps_pyramid):
fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W) *_, H, W = fmaps.shape
if self.multiple_track_feats: fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
fmap1 = targets_split[i] if self.multiple_track_feats:
corrs = torch.matmul(fmap1, fmap2s) fmap1 = targets_split[i]
corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W corrs = torch.matmul(fmap1, fmap2s)
corrs = corrs / torch.sqrt(torch.tensor(C).float()) corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
self.corrs_pyramid.append(corrs) corrs = corrs / torch.sqrt(torch.tensor(C).float())
self.corrs_pyramid.append(corrs)
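# A minimal usage sketch of the correlation pyramid above; all tensor sizes are made up,
# only the CorrBlock API (constructor, corr, sample) comes from this file:
import torch

fmaps = torch.randn(1, 8, 128, 48, 64)                 # (B, S, C, H, W) per-frame feature maps
targets = torch.randn(1, 8, 16, 128)                   # (B, S, N, C) per-track feature vectors
coords = torch.rand(1, 8, 16, 2) * torch.tensor([63.0, 47.0])   # (x, y) track positions

corr_block = CorrBlock(fmaps, num_levels=4, radius=3)
corr_block.corr(targets)                               # builds one (B, S, N, H_l, W_l) volume per level
corr_feats = corr_block.sample(coords)                 # (B * N, S, num_levels * (2 * radius + 1) ** 2)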
class Attention(nn.Module):
def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False): class Attention(nn.Module):
super().__init__() def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
inner_dim = dim_head * num_heads super().__init__()
context_dim = default(context_dim, query_dim) inner_dim = dim_head * num_heads
self.scale = dim_head**-0.5 context_dim = default(context_dim, query_dim)
self.heads = num_heads self.scale = dim_head**-0.5
self.heads = num_heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias) self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
self.to_out = nn.Linear(inner_dim, query_dim) self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
self.to_out = nn.Linear(inner_dim, query_dim)
def forward(self, x, context=None, attn_bias=None):
B, N1, C = x.shape def forward(self, x, context=None, attn_bias=None):
h = self.heads B, N1, C = x.shape
h = self.heads
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
context = default(context, x) q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
k, v = self.to_kv(context).chunk(2, dim=-1) context = default(context, x)
k, v = self.to_kv(context).chunk(2, dim=-1)
N2 = context.shape[1]
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) N2 = context.shape[1]
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3) k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
sim = (q @ k.transpose(-2, -1)) * self.scale
sim = (q @ k.transpose(-2, -1)) * self.scale
if attn_bias is not None:
sim = sim + attn_bias if attn_bias is not None:
attn = sim.softmax(dim=-1) sim = sim + attn_bias
attn = sim.softmax(dim=-1)
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
return self.to_out(x) x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
return self.to_out(x)
class AttnBlock(nn.Module):
def __init__( class AttnBlock(nn.Module):
self, def __init__(
hidden_size, self,
num_heads, hidden_size,
attn_class: Callable[..., nn.Module] = Attention, num_heads,
mlp_ratio=4.0, attn_class: Callable[..., nn.Module] = Attention,
**block_kwargs mlp_ratio=4.0,
): **block_kwargs
super().__init__() ):
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) super().__init__()
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio) self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
approx_gelu = lambda: nn.GELU(approximate="tanh") mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.mlp = Mlp( approx_gelu = lambda: nn.GELU(approximate="tanh")
in_features=hidden_size, self.mlp = Mlp(
hidden_features=mlp_hidden_dim, in_features=hidden_size,
act_layer=approx_gelu, hidden_features=mlp_hidden_dim,
drop=0, act_layer=approx_gelu,
) drop=0,
)
def forward(self, x, mask=None):
attn_bias = mask def forward(self, x, mask=None):
if mask is not None: attn_bias = mask
mask = ( if mask is not None:
(mask[:, None] * mask[:, :, None]) mask = (
.unsqueeze(1) (mask[:, None] * mask[:, :, None])
.expand(-1, self.attn.heads, -1, -1) .unsqueeze(1)
) .expand(-1, self.attn.heads, -1, -1)
max_neg_value = -torch.finfo(x.dtype).max )
attn_bias = (~mask) * max_neg_value max_neg_value = -torch.finfo(x.dtype).max
x = x + self.attn(self.norm1(x), attn_bias=attn_bias) attn_bias = (~mask) * max_neg_value
x = x + self.mlp(self.norm2(x)) x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
return x x = x + self.mlp(self.norm2(x))
return x
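# A minimal sketch of the masking path above: a boolean token mask becomes a pairwise
# additive attention bias. Sizes are made up; hidden_size matches the default dim_head * num_heads.
import torch

block = AttnBlock(hidden_size=384, num_heads=8)
x = torch.randn(2, 10, 384)                            # (B, N, hidden_size)
mask = torch.ones(2, 10, dtype=torch.bool)
mask[:, -3:] = False                                   # pretend the last 3 tokens are padding
out = block(x, mask=mask)                              # masked keys receive a large negative bias
print(out.shape)                                       # torch.Size([2, 10, 384])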

File diff suppressed because it is too large Load Diff

View File

@@ -1,61 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from cotracker.models.core.model_utils import reduce_masked_mean from cotracker.models.core.model_utils import reduce_masked_mean
EPS = 1e-6 EPS = 1e-6
def balanced_ce_loss(pred, gt, valid=None): def balanced_ce_loss(pred, gt, valid=None):
total_balanced_loss = 0.0 total_balanced_loss = 0.0
for j in range(len(gt)): for j in range(len(gt)):
B, S, N = gt[j].shape B, S, N = gt[j].shape
# pred and gt are the same shape # pred and gt are the same shape
for (a, b) in zip(pred[j].size(), gt[j].size()): for (a, b) in zip(pred[j].size(), gt[j].size()):
assert a == b # some shape mismatch! assert a == b # some shape mismatch!
# if valid is not None: # if valid is not None:
for (a, b) in zip(pred[j].size(), valid[j].size()): for (a, b) in zip(pred[j].size(), valid[j].size()):
assert a == b # some shape mismatch! assert a == b # some shape mismatch!
pos = (gt[j] > 0.95).float() pos = (gt[j] > 0.95).float()
neg = (gt[j] < 0.05).float() neg = (gt[j] < 0.05).float()
label = pos * 2.0 - 1.0 label = pos * 2.0 - 1.0
a = -label * pred[j] a = -label * pred[j]
b = F.relu(a) b = F.relu(a)
loss = b + torch.log(torch.exp(-b) + torch.exp(a - b)) loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
pos_loss = reduce_masked_mean(loss, pos * valid[j]) pos_loss = reduce_masked_mean(loss, pos * valid[j])
neg_loss = reduce_masked_mean(loss, neg * valid[j]) neg_loss = reduce_masked_mean(loss, neg * valid[j])
balanced_loss = pos_loss + neg_loss balanced_loss = pos_loss + neg_loss
total_balanced_loss += balanced_loss / float(N) total_balanced_loss += balanced_loss / float(N)
return total_balanced_loss return total_balanced_loss
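# The a / b / loss lines above are a numerically stable form of the logistic loss:
# b + log(exp(-b) + exp(a - b)) == log(1 + exp(a)) == softplus(-label * pred).
# A small sanity check with made-up values:
import torch
import torch.nn.functional as F

pred = torch.randn(4, 8, 16)
label = (torch.rand(4, 8, 16) > 0.5).float() * 2 - 1   # labels in {-1, +1}
a = -label * pred
b = F.relu(a)
stable = b + torch.log(torch.exp(-b) + torch.exp(a - b))
assert torch.allclose(stable, F.softplus(a), atol=1e-6)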
def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8): def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
"""Loss function defined over sequence of flow predictions""" """Loss function defined over sequence of flow predictions"""
total_flow_loss = 0.0 total_flow_loss = 0.0
for j in range(len(flow_gt)): for j in range(len(flow_gt)):
B, S, N, D = flow_gt[j].shape B, S, N, D = flow_gt[j].shape
assert D == 2 assert D == 2
B, S1, N = vis[j].shape B, S1, N = vis[j].shape
B, S2, N = valids[j].shape B, S2, N = valids[j].shape
assert S == S1 assert S == S1
assert S == S2 assert S == S2
n_predictions = len(flow_preds[j]) n_predictions = len(flow_preds[j])
flow_loss = 0.0 flow_loss = 0.0
for i in range(n_predictions): for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1) i_weight = gamma ** (n_predictions - i - 1)
flow_pred = flow_preds[j][i] flow_pred = flow_preds[j][i]
i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2 i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
i_loss = torch.mean(i_loss, dim=3) # B, S, N i_loss = torch.mean(i_loss, dim=3) # B, S, N
flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j]) flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
flow_loss = flow_loss / n_predictions flow_loss = flow_loss / n_predictions
total_flow_loss += flow_loss / float(N) total_flow_loss += flow_loss / float(N)
return total_flow_loss return total_flow_loss
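# The i_weight term above decays exponentially with distance from the last refinement
# iteration, so later predictions dominate. For example, with 4 iterations and gamma = 0.8:
gamma, n_predictions = 0.8, 4
weights = [gamma ** (n_predictions - i - 1) for i in range(n_predictions)]
print(weights)                                          # approximately [0.512, 0.64, 0.8, 1.0]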

View File

@@ -1,120 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
from typing import Tuple, Union from typing import Tuple, Union
import torch import torch
def get_2d_sincos_pos_embed( def get_2d_sincos_pos_embed(
embed_dim: int, grid_size: Union[int, Tuple[int, int]] embed_dim: int, grid_size: Union[int, Tuple[int, int]]
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function initializes a grid and generates a 2D positional embedding using sine and cosine functions. This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
It is a wrapper of get_2d_sincos_pos_embed_from_grid. It is a wrapper of get_2d_sincos_pos_embed_from_grid.
Args: Args:
- embed_dim: The embedding dimension. - embed_dim: The embedding dimension.
- grid_size: The grid size. - grid_size: The grid size.
Returns: Returns:
- pos_embed: The generated 2D positional embedding. - pos_embed: The generated 2D positional embedding.
""" """
if isinstance(grid_size, tuple): if isinstance(grid_size, tuple):
grid_size_h, grid_size_w = grid_size grid_size_h, grid_size_w = grid_size
else: else:
grid_size_h = grid_size_w = grid_size grid_size_h = grid_size_w = grid_size
grid_h = torch.arange(grid_size_h, dtype=torch.float) grid_h = torch.arange(grid_size_h, dtype=torch.float)
grid_w = torch.arange(grid_size_w, dtype=torch.float) grid_w = torch.arange(grid_size_w, dtype=torch.float)
grid = torch.meshgrid(grid_w, grid_h, indexing="xy") grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
grid = torch.stack(grid, dim=0) grid = torch.stack(grid, dim=0)
grid = grid.reshape([2, 1, grid_size_h, grid_size_w]) grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2) return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
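# A minimal shape check for the wrapper above (grid size made up):
pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(6, 8))
print(pos.shape)                                        # torch.Size([1, 64, 6, 8]): one 64-d sin/cos code per cell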
def get_2d_sincos_pos_embed_from_grid( def get_2d_sincos_pos_embed_from_grid(
embed_dim: int, grid: torch.Tensor embed_dim: int, grid: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function generates a 2D positional embedding from a given grid using sine and cosine functions. This function generates a 2D positional embedding from a given grid using sine and cosine functions.
Args: Args:
- embed_dim: The embedding dimension. - embed_dim: The embedding dimension.
- grid: The grid to generate the embedding from. - grid: The grid to generate the embedding from.
Returns: Returns:
- emb: The generated 2D positional embedding. - emb: The generated 2D positional embedding.
""" """
assert embed_dim % 2 == 0 assert embed_dim % 2 == 0
# use half of dimensions to encode grid_h # use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D) emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
return emb return emb
def get_1d_sincos_pos_embed_from_grid( def get_1d_sincos_pos_embed_from_grid(
embed_dim: int, pos: torch.Tensor embed_dim: int, pos: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function generates a 1D positional embedding from a given grid using sine and cosine functions. This function generates a 1D positional embedding from a given grid using sine and cosine functions.
Args: Args:
- embed_dim: The embedding dimension. - embed_dim: The embedding dimension.
- pos: The position to generate the embedding from. - pos: The position to generate the embedding from.
Returns: Returns:
- emb: The generated 1D positional embedding. - emb: The generated 1D positional embedding.
""" """
assert embed_dim % 2 == 0 assert embed_dim % 2 == 0
omega = torch.arange(embed_dim // 2, dtype=torch.double) omega = torch.arange(embed_dim // 2, dtype=torch.double)
omega /= embed_dim / 2.0 omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,) omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,) pos = pos.reshape(-1) # (M,)
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = torch.sin(out) # (M, D/2) emb_sin = torch.sin(out) # (M, D/2)
emb_cos = torch.cos(out) # (M, D/2) emb_cos = torch.cos(out) # (M, D/2)
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D) emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
return emb[None].float() return emb[None].float()
def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor: def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
""" """
This function generates a 2D positional embedding from given coordinates using sine and cosine functions. This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
Args: Args:
- xy: The coordinates to generate the embedding from. - xy: The coordinates to generate the embedding from.
- C: The size of the embedding. - C: The size of the embedding.
- cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding. - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
Returns: Returns:
- pe: The generated 2D positional embedding. - pe: The generated 2D positional embedding.
""" """
B, N, D = xy.shape B, N, D = xy.shape
assert D == 2 assert D == 2
x = xy[:, :, 0:1] x = xy[:, :, 0:1]
y = xy[:, :, 1:2] y = xy[:, :, 1:2]
div_term = ( div_term = (
torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C) torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
).reshape(1, 1, int(C / 2)) ).reshape(1, 1, int(C / 2))
pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32) pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
pe_x[:, :, 0::2] = torch.sin(x * div_term) pe_x[:, :, 0::2] = torch.sin(x * div_term)
pe_x[:, :, 1::2] = torch.cos(x * div_term) pe_x[:, :, 1::2] = torch.cos(x * div_term)
pe_y[:, :, 0::2] = torch.sin(y * div_term) pe_y[:, :, 0::2] = torch.sin(y * div_term)
pe_y[:, :, 1::2] = torch.cos(y * div_term) pe_y[:, :, 1::2] = torch.cos(y * div_term)
pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*2) pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*2)
if cat_coords: if cat_coords:
pe = torch.cat([xy, pe], dim=2) # (B, N, C*2+2) pe = torch.cat([xy, pe], dim=2) # (B, N, C*2+2)
return pe return pe
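# A corresponding sketch for get_2d_embedding, with a couple of made-up pixel coordinates:
import torch

xy = torch.tensor([[[12.0, 40.0], [100.0, 7.0]]])       # (B=1, N=2) points in (x, y)
pe = get_2d_embedding(xy, C=64, cat_coords=True)
print(pe.shape)                                          # torch.Size([1, 2, 130]): 64 dims for x, 64 for y, plus raw (x, y)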

View File

@@ -1,256 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional, Tuple from typing import Optional, Tuple
EPS = 1e-6 EPS = 1e-6
def smart_cat(tensor1, tensor2, dim): def smart_cat(tensor1, tensor2, dim):
if tensor1 is None: if tensor1 is None:
return tensor2 return tensor2
return torch.cat([tensor1, tensor2], dim=dim) return torch.cat([tensor1, tensor2], dim=dim)
def get_points_on_a_grid( def get_points_on_a_grid(
size: int, size: int,
extent: Tuple[float, ...], extent: Tuple[float, ...],
center: Optional[Tuple[float, ...]] = None, center: Optional[Tuple[float, ...]] = None,
device: Optional[torch.device] = torch.device("cpu"), device: Optional[torch.device] = torch.device("cpu"),
): ):
r"""Get a grid of points covering a rectangular region r"""Get a grid of points covering a rectangular region
`get_points_on_a_grid(size, extent)` generates a :attr:`size` by `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
:attr:`size` grid of points distributed to cover a rectangular area :attr:`size` grid of points distributed to cover a rectangular area
specified by `extent`. specified by `extent`.
The `extent` is a pair of integers :math:`(H,W)` specifying the height The `extent` is a pair of integers :math:`(H,W)` specifying the height
and width of the rectangle. and width of the rectangle.
Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)` Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
specifying the vertical and horizontal center coordinates. The center specifying the vertical and horizontal center coordinates. The center
defaults to the middle of the extent. defaults to the middle of the extent.
Points are distributed uniformly within the rectangle leaving a margin Points are distributed uniformly within the rectangle leaving a margin
:math:`m=W/64` from the border. :math:`m=W/64` from the border.
It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
points :math:`P_{ij}=(x_i, y_i)` where points :math:`P_{ij}=(x_i, y_i)` where
.. math:: .. math::
P_{ij} = \left( P_{ij} = \left(
c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~ c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
\right) \right)
Points are returned in row-major order. Points are returned in row-major order.
Args: Args:
size (int): grid size. size (int): grid size.
extent (tuple): height and width of the grid extent. extent (tuple): height and width of the grid extent.
center (tuple, optional): grid center. center (tuple, optional): grid center.
device (str, optional): Defaults to `"cpu"`. device (str, optional): Defaults to `"cpu"`.
Returns: Returns:
Tensor: grid. Tensor: grid.
""" """
if size == 1: if size == 1:
return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None] return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
if center is None: if center is None:
center = [extent[0] / 2, extent[1] / 2] center = [extent[0] / 2, extent[1] / 2]
margin = extent[1] / 64 margin = extent[1] / 64
range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin) range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin) range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
grid_y, grid_x = torch.meshgrid( grid_y, grid_x = torch.meshgrid(
torch.linspace(*range_y, size, device=device), torch.linspace(*range_y, size, device=device),
torch.linspace(*range_x, size, device=device), torch.linspace(*range_x, size, device=device),
indexing="ij", indexing="ij",
) )
return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2) return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
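# A short usage sketch of the grid helper above, with a made-up frame size:
H, W = 384, 512
xy = get_points_on_a_grid(size=10, extent=(H, W))        # (1, 100, 2) points in (x, y), row-major
print(xy.shape, xy[0, 0])                                 # the first point sits at the top-left margin (8.0, 8.0)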
def reduce_masked_mean(input, mask, dim=None, keepdim=False): def reduce_masked_mean(input, mask, dim=None, keepdim=False):
r"""Masked mean r"""Masked mean
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input` `reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
over a mask :attr:`mask`, returning over a mask :attr:`mask`, returning
.. math:: .. math::
\text{output} = \text{output} =
\frac \frac
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i} {\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
{\epsilon + \sum_{i=1}^N \text{mask}_i} {\epsilon + \sum_{i=1}^N \text{mask}_i}
where :math:`N` is the number of elements in :attr:`input` and where :math:`N` is the number of elements in :attr:`input` and
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid :attr:`mask`, and :math:`\epsilon` is a small constant to avoid
division by zero. division by zero.
`reduce_masked_mean(x, mask, dim)` computes the mean of a tensor `reduce_masked_mean(x, mask, dim)` computes the mean of a tensor
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`. :attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
Optionally, the dimension can be kept in the output by setting Optionally, the dimension can be kept in the output by setting
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to :attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
the same dimension as :attr:`input`. the same dimension as :attr:`input`.
The interface is similar to `torch.mean()`. The interface is similar to `torch.mean()`.
Args: Args:
input (Tensor): input tensor. input (Tensor): input tensor.
mask (Tensor): mask. mask (Tensor): mask.
dim (int, optional): Dimension to sum over. Defaults to None. dim (int, optional): Dimension to sum over. Defaults to None.
keepdim (bool, optional): Keep the summed dimension. Defaults to False. keepdim (bool, optional): Keep the summed dimension. Defaults to False.
Returns: Returns:
Tensor: mean tensor. Tensor: mean tensor.
""" """
mask = mask.expand_as(input) mask = mask.expand_as(input)
prod = input * mask prod = input * mask
if dim is None: if dim is None:
numer = torch.sum(prod) numer = torch.sum(prod)
denom = torch.sum(mask) denom = torch.sum(mask)
else: else:
numer = torch.sum(prod, dim=dim, keepdim=keepdim) numer = torch.sum(prod, dim=dim, keepdim=keepdim)
denom = torch.sum(mask, dim=dim, keepdim=keepdim) denom = torch.sum(mask, dim=dim, keepdim=keepdim)
mean = numer / (EPS + denom) mean = numer / (EPS + denom)
return mean return mean
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"): def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
r"""Sample a tensor using bilinear interpolation r"""Sample a tensor using bilinear interpolation
`bilinear_sampler(input, coords)` samples a tensor :attr:`input` at `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
coordinates :attr:`coords` using bilinear interpolation. It is the same coordinates :attr:`coords` using bilinear interpolation. It is the same
as `torch.nn.functional.grid_sample()` but with a different coordinate as `torch.nn.functional.grid_sample()` but with a different coordinate
convention. convention.
The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
:math:`B` is the batch size, :math:`C` is the number of channels, :math:`B` is the batch size, :math:`C` is the number of channels,
:math:`H` is the height of the image, and :math:`W` is the width of the :math:`H` is the height of the image, and :math:`W` is the width of the
image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`. interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`, Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
that in this case the order of the components is slightly different that in this case the order of the components is slightly different
from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`. from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
If `align_corners` is `True`, the coordinate :math:`x` is assumed to be If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
in the range :math:`[0,W-1]`, with 0 corresponding to the center of the in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
left-most image pixel and :math:`W-1` to the center of the right-most left-most image pixel and :math:`W-1` to the center of the right-most
pixel. pixel.
If `align_corners` is `False`, the coordinate :math:`x` is assumed to If `align_corners` is `False`, the coordinate :math:`x` is assumed to
be in the range :math:`[0,W]`, with 0 corresponding to the left edge of be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
the left-most pixel and :math:`W` to the right edge of the right-most the left-most pixel and :math:`W` to the right edge of the right-most
pixel. pixel.
Similar conventions apply to the :math:`y` for the range Similar conventions apply to the :math:`y` for the range
:math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
:math:`[0,T-1]` and :math:`[0,T]`. :math:`[0,T-1]` and :math:`[0,T]`.
Args: Args:
input (Tensor): batch of input images. input (Tensor): batch of input images.
coords (Tensor): batch of coordinates. coords (Tensor): batch of coordinates.
align_corners (bool, optional): Coordinate convention. Defaults to `True`. align_corners (bool, optional): Coordinate convention. Defaults to `True`.
padding_mode (str, optional): Padding mode. Defaults to `"border"`. padding_mode (str, optional): Padding mode. Defaults to `"border"`.
Returns: Returns:
Tensor: sampled points. Tensor: sampled points.
""" """
sizes = input.shape[2:] sizes = input.shape[2:]
assert len(sizes) in [2, 3] assert len(sizes) in [2, 3]
if len(sizes) == 3: if len(sizes) == 3:
# t x y -> x y t to match dimensions T H W in grid_sample # t x y -> x y t to match dimensions T H W in grid_sample
coords = coords[..., [1, 2, 0]] coords = coords[..., [1, 2, 0]]
if align_corners: if align_corners:
coords = coords * torch.tensor( coords = coords * torch.tensor(
[2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
) )
else: else:
coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device) coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
coords -= 1 coords -= 1
return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode) return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
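# To make the coordinate convention above concrete: with align_corners=True, integer
# coordinates land exactly on pixel centers. A small sketch with a made-up 3x4 map:
import torch

feat = torch.arange(12.0).reshape(1, 1, 3, 4)             # values 0..11
coords = torch.tensor([[[[2.0, 1.0]]]])                   # one sample at (x=2, y=1)
print(bilinear_sampler(feat, coords))                     # tensor([[[[6.]]]]): the value at row 1, column 2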
def sample_features4d(input, coords): def sample_features4d(input, coords):
r"""Sample spatial features r"""Sample spatial features
`sample_features4d(input, coords)` samples the spatial features `sample_features4d(input, coords)` samples the spatial features
:attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`. :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
The field is sampled at coordinates :attr:`coords` using bilinear The field is sampled at coordinates :attr:`coords` using bilinear
interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R, interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
same convention as :func:`bilinear_sampler` with `align_corners=True`. same convention as :func:`bilinear_sampler` with `align_corners=True`.
The output tensor has one feature per point, and has shape :math:`(B, The output tensor has one feature per point, and has shape :math:`(B,
R, C)`. R, C)`.
Args: Args:
input (Tensor): spatial features. input (Tensor): spatial features.
coords (Tensor): points. coords (Tensor): points.
Returns: Returns:
Tensor: sampled features. Tensor: sampled features.
""" """
B, _, _, _ = input.shape B, _, _, _ = input.shape
# B R 2 -> B R 1 2 # B R 2 -> B R 1 2
coords = coords.unsqueeze(2) coords = coords.unsqueeze(2)
# B C R 1 # B C R 1
feats = bilinear_sampler(input, coords) feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 1, 3).view( return feats.permute(0, 2, 1, 3).view(
B, -1, feats.shape[1] * feats.shape[3] B, -1, feats.shape[1] * feats.shape[3]
) # B C R 1 -> B R C ) # B C R 1 -> B R C
def sample_features5d(input, coords): def sample_features5d(input, coords):
r"""Sample spatio-temporal features r"""Sample spatio-temporal features
`sample_features5d(input, coords)` works in the same way as `sample_features5d(input, coords)` works in the same way as
:func:`sample_features4d` but for spatio-temporal features and points: :func:`sample_features4d` but for spatio-temporal features and points:
:attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i, a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`. x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
Args: Args:
input (Tensor): spatio-temporal features. input (Tensor): spatio-temporal features.
coords (Tensor): spatio-temporal points. coords (Tensor): spatio-temporal points.
Returns: Returns:
Tensor: sampled features. Tensor: sampled features.
""" """
B, T, _, _, _ = input.shape B, T, _, _, _ = input.shape
# B T C H W -> B C T H W # B T C H W -> B C T H W
input = input.permute(0, 2, 1, 3, 4) input = input.permute(0, 2, 1, 3, 4)
# B R1 R2 3 -> B R1 R2 1 3 # B R1 R2 3 -> B R1 R2 1 3
coords = coords.unsqueeze(3) coords = coords.unsqueeze(3)
# B C R1 R2 1 # B C R1 R2 1
feats = bilinear_sampler(input, coords) feats = bilinear_sampler(input, coords)
return feats.permute(0, 2, 3, 1, 4).view( return feats.permute(0, 2, 3, 1, 4).view(
B, feats.shape[2], feats.shape[3], feats.shape[1] B, feats.shape[2], feats.shape[3], feats.shape[1]
) # B C R1 R2 1 -> B R1 R2 C ) # B C R1 R2 1 -> B R1 R2 C
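# A short sketch of the 4D sampler (shapes made up):
import torch

feats = torch.randn(2, 128, 48, 64)                       # (B, C, H, W) feature maps
points = torch.rand(2, 30, 2) * torch.tensor([63.0, 47.0])  # (B, R, 2) points in (x, y)
out = sample_features4d(feats, points)
print(out.shape)                                           # torch.Size([2, 30, 128]): one feature vector per point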

View File

@@ -1,104 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from typing import Tuple from typing import Tuple
from cotracker.models.core.cotracker.cotracker import CoTracker2 from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.models.core.model_utils import get_points_on_a_grid from cotracker.models.core.model_utils import get_points_on_a_grid
class EvaluationPredictor(torch.nn.Module): class EvaluationPredictor(torch.nn.Module):
def __init__( def __init__(
self, self,
cotracker_model: CoTracker2, cotracker_model: CoTracker2,
interp_shape: Tuple[int, int] = (384, 512), interp_shape: Tuple[int, int] = (384, 512),
grid_size: int = 5, grid_size: int = 5,
local_grid_size: int = 8, local_grid_size: int = 8,
single_point: bool = True, single_point: bool = True,
n_iters: int = 6, n_iters: int = 6,
) -> None: ) -> None:
super(EvaluationPredictor, self).__init__() super(EvaluationPredictor, self).__init__()
self.grid_size = grid_size self.grid_size = grid_size
self.local_grid_size = local_grid_size self.local_grid_size = local_grid_size
self.single_point = single_point self.single_point = single_point
self.interp_shape = interp_shape self.interp_shape = interp_shape
self.n_iters = n_iters self.n_iters = n_iters
self.model = cotracker_model self.model = cotracker_model
self.model.eval() self.model.eval()
def forward(self, video, queries): def forward(self, video, queries):
queries = queries.clone() queries = queries.clone()
B, T, C, H, W = video.shape B, T, C, H, W = video.shape
B, N, D = queries.shape B, N, D = queries.shape
assert D == 3 assert D == 3
video = video.reshape(B * T, C, H, W) video = video.reshape(B * T, C, H, W)
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True) video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
device = video.device device = video.device
queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1) queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1) queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
if self.single_point: if self.single_point:
traj_e = torch.zeros((B, T, N, 2), device=device) traj_e = torch.zeros((B, T, N, 2), device=device)
vis_e = torch.zeros((B, T, N), device=device) vis_e = torch.zeros((B, T, N), device=device)
for pind in range((N)): for pind in range((N)):
query = queries[:, pind : pind + 1] query = queries[:, pind : pind + 1]
t = query[0, 0, 0].long() t = query[0, 0, 0].long()
traj_e_pind, vis_e_pind = self._process_one_point(video, query) traj_e_pind, vis_e_pind = self._process_one_point(video, query)
traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1] traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1] vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
else: else:
if self.grid_size > 0: if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, video.shape[3:]) xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) # xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
queries = torch.cat([queries, xy], dim=1) # queries = torch.cat([queries, xy], dim=1) #
traj_e, vis_e, __ = self.model( traj_e, vis_e, __ = self.model(
video=video, video=video,
queries=queries, queries=queries,
iters=self.n_iters, iters=self.n_iters,
) )
traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1) traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1) traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
return traj_e, vis_e return traj_e, vis_e
def _process_one_point(self, video, query): def _process_one_point(self, video, query):
t = query[0, 0, 0].long() t = query[0, 0, 0].long()
device = query.device device = query.device
if self.local_grid_size > 0: if self.local_grid_size > 0:
xy_target = get_points_on_a_grid( xy_target = get_points_on_a_grid(
self.local_grid_size, self.local_grid_size,
(50, 50), (50, 50),
[query[0, 0, 2].item(), query[0, 0, 1].item()], [query[0, 0, 2].item(), query[0, 0, 1].item()],
) )
xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to( xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
device device
) # ) #
query = torch.cat([query, xy_target], dim=1) # query = torch.cat([query, xy_target], dim=1) #
if self.grid_size > 0: if self.grid_size > 0:
xy = get_points_on_a_grid(self.grid_size, video.shape[3:]) xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) # xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
query = torch.cat([query, xy], dim=1) # query = torch.cat([query, xy], dim=1) #
# crop the video to start from the queried frame # crop the video to start from the queried frame
query[0, 0, 0] = 0 query[0, 0, 0] = 0
traj_e_pind, vis_e_pind, __ = self.model( traj_e_pind, vis_e_pind, __ = self.model(
video=video[:, t:], queries=query, iters=self.n_iters video=video[:, t:], queries=query, iters=self.n_iters
) )
return traj_e_pind, vis_e_pind return traj_e_pind, vis_e_pind
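# A hedged end-to-end sketch of driving this evaluation wrapper; the checkpoint path and
# the tensor sizes are assumptions for illustration, not part of this file:
import torch
from cotracker.models.build_cotracker import build_cotracker

model = build_cotracker("./checkpoints/cotracker2.pth")
predictor = EvaluationPredictor(model, grid_size=5, single_point=True)
video = torch.randn(1, 24, 3, 480, 640)                   # (B, T, 3, H, W)
queries = torch.tensor([[[0.0, 320.0, 240.0]]])           # (B, N, 3) as (t, x, y)
traj, vis = predictor(video, queries)                     # traj: (B, T, N, 2), vis: (B, T, N)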

View File

@@ -1,275 +1,279 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
from cotracker.models.build_cotracker import build_cotracker from cotracker.models.build_cotracker import build_cotracker
class CoTrackerPredictor(torch.nn.Module): class CoTrackerPredictor(torch.nn.Module):
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"): def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
super().__init__() super().__init__()
self.support_grid_size = 6 self.support_grid_size = 6
model = build_cotracker(checkpoint) model = build_cotracker(checkpoint)
self.interp_shape = model.model_resolution self.interp_shape = model.model_resolution
print(self.interp_shape) print(self.interp_shape)
self.model = model self.model = model
self.model.eval() self.model.eval()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
self, self,
video, # (B, T, 3, H, W) Batch_size, time, rgb, height, width video, # (B, T, 3, H, W) Batch_size, time, rgb, height, width
# input prompt types: # input prompt types:
# - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame. # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
# *backward_tracking=True* will compute tracks in both directions. # *backward_tracking=True* will compute tracks in both directions.
# - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates. # - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
# - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask. # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
# You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks. # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
queries: torch.Tensor = None, queries: torch.Tensor = None,
segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W) segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
grid_size: int = 0, grid_size: int = 0,
grid_query_frame: int = 0, # only for dense and regular grid tracks grid_query_frame: int = 0, # only for dense and regular grid tracks
backward_tracking: bool = False, backward_tracking: bool = False,
): ):
if queries is None and grid_size == 0: if queries is None and grid_size == 0:
tracks, visibilities = self._compute_dense_tracks( tracks, visibilities = self._compute_dense_tracks(
video, video,
grid_query_frame=grid_query_frame, grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking, backward_tracking=backward_tracking,
) )
else: else:
tracks, visibilities = self._compute_sparse_tracks( tracks, visibilities = self._compute_sparse_tracks(
video, video,
queries, queries,
segm_mask, segm_mask,
grid_size, grid_size,
add_support_grid=(grid_size == 0 or segm_mask is not None), add_support_grid=(grid_size == 0 or segm_mask is not None),
grid_query_frame=grid_query_frame, grid_query_frame=grid_query_frame,
backward_tracking=backward_tracking, backward_tracking=backward_tracking,
) )
return tracks, visibilities return tracks, visibilities
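# A hedged usage sketch for the offline predictor above; the video tensor is made up and
# the checkpoint path is simply the default from __init__:
import torch

predictor = CoTrackerPredictor()                          # loads ./checkpoints/cotracker2.pth
video = torch.randint(0, 255, (1, 48, 3, 384, 512)).float()   # (B, T, 3, H, W)

tracks, visibilities = predictor(video, grid_size=10)     # regular 10x10 grid from frame 0

queries = torch.tensor([[[0.0, 100.0, 200.0], [10.0, 300.0, 150.0]]])   # (t, x, y) per query
tracks, visibilities = predictor(video, queries=queries, backward_tracking=True)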
def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False): # gpu dense inference time
*_, H, W = video.shape # raft gpu comparison
grid_step = W // grid_size # vision effects
grid_width = W // grid_step # raft integrated
grid_height = H // grid_step # cover the whole frame with about grid_size x grid_size grid cells def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
tracks = visibilities = None *_, H, W = video.shape
grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device) grid_step = W // grid_size
# (batch_size, grid_number, t,x,y) grid_width = W // grid_step
grid_pts[0, :, 0] = grid_query_frame grid_height = H // grid_step # cover the whole frame with about grid_size x grid_size grid cells
# iterate every grid tracks = visibilities = None
for offset in range(grid_step * grid_step): grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
print(f"step {offset} / {grid_step * grid_step}") # (batch_size, grid_number, t,x,y)
ox = offset % grid_step grid_pts[0, :, 0] = grid_query_frame
oy = offset // grid_step # iterate every grid
# initialize for offset in range(grid_step * grid_step):
# for example print(f"step {offset} / {grid_step * grid_step}")
# grid width = 4, grid height = 4, grid step = 10, ox = 1 ox = offset % grid_step
# torch.arange(grid_width) = [0,1,2,3] oy = offset // grid_step
# torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3] # initialize
# torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30] # for example
# get the location in the image # grid width = 4, grid height = 4, grid step = 10, ox = 1
grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox # torch.arange(grid_width) = [0,1,2,3]
grid_pts[0, :, 2] = ( # torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3]
torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy # torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30]
) # get the location in the image
tracks_step, visibilities_step = self._compute_sparse_tracks( grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
video=video, grid_pts[0, :, 2] = (
queries=grid_pts, torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
backward_tracking=backward_tracking, )
) tracks_step, visibilities_step = self._compute_sparse_tracks(
tracks = smart_cat(tracks, tracks_step, dim=2) video=video,
visibilities = smart_cat(visibilities, visibilities_step, dim=2) queries=grid_pts,
backward_tracking=backward_tracking,
return tracks, visibilities )
tracks = smart_cat(tracks, tracks_step, dim=2)
def _compute_sparse_tracks( visibilities = smart_cat(visibilities, visibilities_step, dim=2)
self,
video, return tracks, visibilities
queries,
segm_mask=None, def _compute_sparse_tracks(
grid_size=0, self,
add_support_grid=False, video,
grid_query_frame=0, queries,
backward_tracking=False, segm_mask=None,
): grid_size=0,
B, T, C, H, W = video.shape add_support_grid=False,
grid_query_frame=0,
video = video.reshape(B * T, C, H, W) backward_tracking=False,
# ? what is interpolate? ):
# interpolate the video to interp_shape? B, T, C, H, W = video.shape
video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) video = video.reshape(B * T, C, H, W)
# ? what is interpolate?
if queries is not None: # interpolate the video to interp_shape?
B, N, D = queries.shape # batch_size, number of points, (t,x,y) video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
assert D == 3 video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
# scale the query by ( interp_shape - 1 ) / (W - 1)
# rescale after interpolation if queries is not None:
queries = queries.clone() B, N, D = queries.shape # batch_size, number of points, (t,x,y)
queries[:, :, 1:] *= queries.new_tensor( assert D == 3
[ # scale the query by ( interp_shape - 1 ) / (W - 1)
(self.interp_shape[1] - 1) / (W - 1), # rescale after interpolation
(self.interp_shape[0] - 1) / (H - 1), queries = queries.clone()
] queries[:, :, 1:] *= queries.new_tensor(
) [
# generate the grid (self.interp_shape[1] - 1) / (W - 1),
elif grid_size > 0: (self.interp_shape[0] - 1) / (H - 1),
grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device) ]
if segm_mask is not None: )
segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest") # generate the grid
point_mask = segm_mask[0, 0][ elif grid_size > 0:
(grid_pts[0, :, 1]).round().long().cpu(), grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
(grid_pts[0, :, 0]).round().long().cpu(), if segm_mask is not None:
].bool() segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
grid_pts = grid_pts[:, point_mask] point_mask = segm_mask[0, 0][
(grid_pts[0, :, 1]).round().long().cpu(),
queries = torch.cat( (grid_pts[0, :, 0]).round().long().cpu(),
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], ].bool()
dim=2, grid_pts = grid_pts[:, point_mask]
).repeat(B, 1, 1)
queries = torch.cat(
# add support points [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
dim=2,
if add_support_grid: ).repeat(B, 1, 1)
grid_pts = get_points_on_a_grid(
self.support_grid_size, self.interp_shape, device=video.device # add support points
)
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2) if add_support_grid:
grid_pts = grid_pts.repeat(B, 1, 1) grid_pts = get_points_on_a_grid(
queries = torch.cat([queries, grid_pts], dim=1) self.support_grid_size, self.interp_shape, device=video.device
)
tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6) grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
grid_pts = grid_pts.repeat(B, 1, 1)
if backward_tracking: queries = torch.cat([queries, grid_pts], dim=1)
tracks, visibilities = self._compute_backward_tracks(
video, queries, tracks, visibilities tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
)
if add_support_grid: if backward_tracking:
queries[:, -self.support_grid_size**2 :, 0] = T - 1 tracks, visibilities = self._compute_backward_tracks(
if add_support_grid: video, queries, tracks, visibilities
tracks = tracks[:, :, : -self.support_grid_size**2] )
visibilities = visibilities[:, :, : -self.support_grid_size**2] if add_support_grid:
thr = 0.9 queries[:, -self.support_grid_size**2 :, 0] = T - 1
visibilities = visibilities > thr if add_support_grid:
tracks = tracks[:, :, : -self.support_grid_size**2]
# correct query-point predictions visibilities = visibilities[:, :, : -self.support_grid_size**2]
# see https://github.com/facebookresearch/co-tracker/issues/28 thr = 0.9
visibilities = visibilities > thr
# TODO: batchify
for i in range(len(queries)): # correct query-point predictions
queries_t = queries[i, : tracks.size(2), 0].to(torch.int64) # see https://github.com/facebookresearch/co-tracker/issues/28
arange = torch.arange(0, len(queries_t))
# TODO: batchify
# overwrite the predictions with the query points for i in range(len(queries)):
tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:] queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
arange = torch.arange(0, len(queries_t))
# correct visibilities, the query points should be visible
visibilities[i, queries_t, arange] = True # overwrite the predictions with the query points
tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
tracks *= tracks.new_tensor(
[(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)] # correct visibilities, the query points should be visible
) visibilities[i, queries_t, arange] = True
return tracks, visibilities
tracks *= tracks.new_tensor(
def _compute_backward_tracks(self, video, queries, tracks, visibilities): [(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
inv_video = video.flip(1).clone() )
inv_queries = queries.clone() return tracks, visibilities
inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
def _compute_backward_tracks(self, video, queries, tracks, visibilities):
inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6) inv_video = video.flip(1).clone()
inv_queries = queries.clone()
inv_tracks = inv_tracks.flip(1) inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
inv_visibilities = inv_visibilities.flip(1)
arange = torch.arange(video.shape[1], device=queries.device)[None, :, None] inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)
mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2) inv_tracks = inv_tracks.flip(1)
inv_visibilities = inv_visibilities.flip(1)
tracks[mask] = inv_tracks[mask] arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
return tracks, visibilities mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
tracks[mask] = inv_tracks[mask]
class CoTrackerOnlinePredictor(torch.nn.Module): visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
def __init__(self, checkpoint="./checkpoints/cotracker2.pth"): return tracks, visibilities
super().__init__()
self.support_grid_size = 6
model = build_cotracker(checkpoint) class CoTrackerOnlinePredictor(torch.nn.Module):
self.interp_shape = model.model_resolution def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
self.step = model.window_len // 2 super().__init__()
self.model = model self.support_grid_size = 6
self.model.eval() model = build_cotracker(checkpoint)
self.interp_shape = model.model_resolution
@torch.no_grad() self.step = model.window_len // 2
def forward( self.model = model
self, self.model.eval()
video_chunk,
is_first_step: bool = False, @torch.no_grad()
queries: torch.Tensor = None, def forward(
grid_size: int = 10, self,
grid_query_frame: int = 0, video_chunk,
add_support_grid=False, is_first_step: bool = False,
): queries: torch.Tensor = None,
B, T, C, H, W = video_chunk.shape grid_size: int = 10,
# Initialize online video processing and save queried points grid_query_frame: int = 0,
# This needs to be done before processing *each new video* add_support_grid=False,
if is_first_step: ):
self.model.init_video_online_processing() B, T, C, H, W = video_chunk.shape
if queries is not None: # Initialize online video processing and save queried points
B, N, D = queries.shape # This needs to be done before processing *each new video*
assert D == 3 if is_first_step:
queries = queries.clone() self.model.init_video_online_processing()
queries[:, :, 1:] *= queries.new_tensor( if queries is not None:
[ B, N, D = queries.shape
(self.interp_shape[1] - 1) / (W - 1), assert D == 3
(self.interp_shape[0] - 1) / (H - 1), queries = queries.clone()
] queries[:, :, 1:] *= queries.new_tensor(
) [
elif grid_size > 0: (self.interp_shape[1] - 1) / (W - 1),
grid_pts = get_points_on_a_grid( (self.interp_shape[0] - 1) / (H - 1),
grid_size, self.interp_shape, device=video_chunk.device ]
) )
queries = torch.cat( elif grid_size > 0:
[torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts], grid_pts = get_points_on_a_grid(
dim=2, grid_size, self.interp_shape, device=video_chunk.device
) )
if add_support_grid: queries = torch.cat(
grid_pts = get_points_on_a_grid( [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
self.support_grid_size, self.interp_shape, device=video_chunk.device dim=2,
) )
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2) if add_support_grid:
queries = torch.cat([queries, grid_pts], dim=1) grid_pts = get_points_on_a_grid(
self.queries = queries self.support_grid_size, self.interp_shape, device=video_chunk.device
return (None, None) )
grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
video_chunk = video_chunk.reshape(B * T, C, H, W) queries = torch.cat([queries, grid_pts], dim=1)
video_chunk = F.interpolate( self.queries = queries
video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True return (None, None)
)
video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1]) video_chunk = video_chunk.reshape(B * T, C, H, W)
video_chunk = F.interpolate(
tracks, visibilities, __ = self.model( video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
video=video_chunk, )
queries=self.queries, video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
iters=6,
is_online=True, tracks, visibilities, __ = self.model(
) video=video_chunk,
thr = 0.9 queries=self.queries,
return ( iters=6,
tracks is_online=True,
* tracks.new_tensor( )
[ thr = 0.9
(W - 1) / (self.interp_shape[1] - 1), return (
(H - 1) / (self.interp_shape[0] - 1), tracks
] * tracks.new_tensor(
), [
visibilities > thr, (W - 1) / (self.interp_shape[1] - 1),
) (H - 1) / (self.interp_shape[0] - 1),
]
),
visibilities > thr,
)

View File

@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -1,343 +1,343 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
# This source code is licensed under the license found in the # This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import os
import numpy as np
import imageio
import torch

from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
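
# Read every frame of a video from disk and stack them into a single
# (num_frames, H, W, 3) numpy array; returns None if the file cannot be opened.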
def read_video_from_path(path):
    try:
        reader = imageio.get_reader(path)
    except Exception as e:
        print("Error opening video file: ", e)
        return None
    frames = []
    for i, im in enumerate(reader):
        frames.append(np.array(im))
    return np.stack(frames)
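
# Draw a circle of the given radius onto a PIL image: filled when the point is
# visible, outline-only when it is not.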
def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
    # Create a draw object
    draw = ImageDraw.Draw(rgb)
    # Calculate the bounding box of the circle
    left_up_point = (coord[0] - radius, coord[1] - radius)
    right_down_point = (coord[0] + radius, coord[1] + radius)
    # Draw the circle
    draw.ellipse(
        [left_up_point, right_down_point],
        fill=tuple(color) if visible else None,
        outline=tuple(color),
    )
    return rgb
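
# Draw a straight line segment between two points onto a PIL image.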
def draw_line(rgb, coord_y, coord_x, color, linewidth):
    draw = ImageDraw.Draw(rgb)
    draw.line(
        (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
        fill=tuple(color),
        width=linewidth,
    )
    return rgb
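
# Blend two images as alpha * rgb + beta * original + gamma and cast back to uint8.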
def add_weighted(rgb, alpha, original, beta, gamma):
    return (rgb * alpha + original * beta + gamma).astype("uint8")
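
# Draws predicted (and optionally ground-truth) point tracks on top of a video
# and saves the result to disk or to a TensorBoard writer.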
class Visualizer:
    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",  # 'cool', 'optical_flow'
        linewidth: int = 2,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,  # -1 for infinite
    ):
        self.mode = mode
        self.save_dir = save_dir
        if mode == "rainbow":
            self.color_map = cm.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = cm.get_cmap(mode)
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps
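
    # Pad the video, color each track, draw tracks and points on every frame,
    # and optionally save the rendered video.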
    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        writer=None,  # tensorboard Summary Writer, used for visualization during training
        step: int = 0,
        query_frame: int = 0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
    ):
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()

        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )
        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            video = video.repeat(1, 1, 3, 1, 1)

        res_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
        )
        if save_video:
            self.save_video(res_video, filename=filename, writer=writer, step=step)
        return res_video
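
    # Write the rendered frames either to TensorBoard (if a writer is given)
    # or to an .mp4 file inside save_dir.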
    def save_video(self, video, filename, writer=None, step=0):
        if writer is not None:
            writer.add_video(
                filename,
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
        else:
            os.makedirs(self.save_dir, exist_ok=True)
            wide_list = list(video.unbind(1))
            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]

            # Prepare the video file path
            save_path = os.path.join(self.save_dir, f"{filename}.mp4")

            # Create a writer object
            video_writer = imageio.get_writer(save_path, fps=self.fps)

            # Write frames to the video file
            for frame in wide_list[2:-1]:
                video_writer.append_data(frame)

            video_writer.close()

            print(f"Video saved to {save_path}")
    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame: int = 0,
        compensate_for_camera_motion=False,
    ):
        B, T, C, H, W = video.shape
        _, _, N, D = tracks.shape

        assert D == 2
        assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C
        tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2
        if gt_tracks is not None:
            gt_tracks = gt_tracks[0].detach().cpu().numpy()

        res_video = []

        # process input video
        for rgb in video:
            res_video.append(rgb.copy())
        vector_colors = np.zeros((T, N, 3))

        if self.mode == "optical_flow":
            import flow_vis

            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
        elif segm_mask is None:
            if self.mode == "rainbow":
                y_min, y_max = (
                    tracks[query_frame, :, 1].min(),
                    tracks[query_frame, :, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    color = self.color_map(norm(tracks[query_frame, n, 1]))
                    color = np.array(color[:3])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with time
                for t in range(T):
                    color = np.array(self.color_map(t / T)[:3])[None] * 255
                    vector_colors[t] = np.repeat(color, N, axis=0)
        else:
            if self.mode == "rainbow":
                vector_colors[:, segm_mask <= 0, :] = 255

                y_min, y_max = (
                    tracks[0, segm_mask > 0, 1].min(),
                    tracks[0, segm_mask > 0, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if segm_mask[n] > 0:
                        color = self.color_map(norm(tracks[0, n, 1]))
                        color = np.array(color[:3])[None] * 255
                        vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with segm class
                segm_mask = segm_mask.cpu()
                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
                vector_colors = np.repeat(color[None], T, axis=0)

        # draw tracks
        if self.tracks_leave_trace != 0:
            for t in range(query_frame + 1, T):
                first_ind = (
                    max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]

                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])

        # draw points
        for t in range(query_frame, T):
            img = Image.fromarray(np.uint8(res_video[t]))
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                visible = True
                if visibility is not None:
                    visible = visibility[0, t, i]
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or (
                        compensate_for_camera_motion and segm_mask[i] > 0
                    ):
                        img = draw_circle(
                            img,
                            coord=coord,
                            radius=int(self.linewidth * 2),
                            color=vector_colors[t, i].astype(int),
                            visible=visible,
                        )
            res_video[t] = np.array(img)

        # construct the final rgb sequence
        if self.show_first_frame > 0:
            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
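
    # Draw line segments between consecutive track positions; when traces are
    # enabled, older segments are blended with the frame for a fading effect.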
    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x 2
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        T, N, _ = tracks.shape
        rgb = Image.fromarray(np.uint8(rgb))
        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            alpha = (s / T) ** 2
            for i in range(N):
                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if coord_y[0] != 0 and coord_y[1] != 0:
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        vector_color[i].astype(int),
                        self.linewidth,
                    )
            if self.tracks_leave_trace > 0:
                rgb = Image.fromarray(
                    np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
                )
        rgb = np.array(rgb)
        return rgb
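
    # Mark ground-truth track positions with red crosses.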
    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        gt_tracks: np.ndarray,  # T x 2
    ):
        T, N, _ = gt_tracks.shape
        color = np.array((211, 0, 0))
        rgb = Image.fromarray(np.uint8(rgb))
        for t in range(T):
            for i in range(N):
                gt_track = gt_tracks[t][i]
                # draw a red cross
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
        rgb = np.array(rgb)
        return rgb
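
A short usage sketch for the Visualizer above; the tensors, shapes, and file names are illustrative assumptions, not taken from this diff.

# Hypothetical example: overlay predicted tracks on a video and write ./results/demo.mp4.
# `video` is assumed to be a (1, T, 3, H, W) uint8 tensor, with `pred_tracks` and
# `pred_visibility` shaped as documented in `visualize` above.
vis = Visualizer(save_dir="./results", pad_value=120, linewidth=3, tracks_leave_trace=-1)
vis.visualize(video=video, tracks=pred_tracks, visibility=pred_visibility, filename="demo")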


@@ -1,8 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

__version__ = "2.0.0"

File diff suppressed because one or more lines are too long