diff --git a/cotracker/__init__.py b/cotracker/__init__.py
index 5277f46..4547e07 100644
--- a/cotracker/__init__.py
+++ b/cotracker/__init__.py
@@ -1,5 +1,5 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/cotracker/datasets/__init__.py b/cotracker/datasets/__init__.py
index 5277f46..4547e07 100644
--- a/cotracker/datasets/__init__.py
+++ b/cotracker/datasets/__init__.py
@@ -1,5 +1,5 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/cotracker/datasets/dataclass_utils.py b/cotracker/datasets/dataclass_utils.py
index 11e103b..35be845 100644
--- a/cotracker/datasets/dataclass_utils.py
+++ b/cotracker/datasets/dataclass_utils.py
@@ -1,166 +1,172 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import json
-import dataclasses
-import numpy as np
-from dataclasses import Field, MISSING
-from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple
-
-_X = TypeVar("_X")
-
-
-def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
- """
- Loads to a @dataclass or collection hierarchy including dataclasses
- from a json recursively.
- Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
- raises KeyError if json has keys not mapping to the dataclass fields.
-
- Args:
- f: Either a path to a file, or a file opened for writing.
- cls: The class of the loaded dataclass.
- binary: Set to True if `f` is a file handle, else False.
- """
- if binary:
- asdict = json.loads(f.read().decode("utf8"))
- else:
- asdict = json.load(f)
-
- # in the list case, run a faster "vectorized" version
- cls = get_args(cls)[0]
- res = list(_dataclass_list_from_dict_list(asdict, cls))
-
- return res
-
-
-def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
- """Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
- if get_origin(type_) is Union:
- args = get_args(type_)
- if len(args) == 2 and args[1] == type(None): # noqa E721
- return True, args[0]
- if type_ is Any:
- return True, Any
-
- return False, type_
-
-
-def _unwrap_type(tp):
- # strips Optional wrapper, if any
- if get_origin(tp) is Union:
- args = get_args(tp)
- if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
- # this is typing.Optional
- return args[0] if args[1] is type(None) else args[1] # noqa: E721
- return tp
-
-
-def _get_dataclass_field_default(field: Field) -> Any:
- if field.default_factory is not MISSING:
- # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
- # dataclasses._DefaultFactory[typing.Any]]` is not a function.
- return field.default_factory()
- elif field.default is not MISSING:
- return field.default
- else:
- return None
-
-
-def _dataclass_list_from_dict_list(dlist, typeannot):
- """
- Vectorised version of `_dataclass_from_dict`.
- The output should be equivalent to
- `[_dataclass_from_dict(d, typeannot) for d in dlist]`.
-
- Args:
- dlist: list of objects to convert.
- typeannot: type of each of those objects.
- Returns:
- iterator or list over converted objects of the same length as `dlist`.
-
- Raises:
- ValueError: it assumes the objects have None's in consistent places across
- objects, otherwise it would ignore some values. This generally holds for
- auto-generated annotations, but otherwise use `_dataclass_from_dict`.
- """
-
- cls = get_origin(typeannot) or typeannot
-
- if typeannot is Any:
- return dlist
- if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
- return dlist
- if any(obj is None for obj in dlist):
- # filter out Nones and recurse on the resulting list
- idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
- idx, notnone = zip(*idx_notnone)
- converted = _dataclass_list_from_dict_list(notnone, typeannot)
- res = [None] * len(dlist)
- for i, obj in zip(idx, converted):
- res[i] = obj
- return res
-
- is_optional, contained_type = _resolve_optional(typeannot)
- if is_optional:
- return _dataclass_list_from_dict_list(dlist, contained_type)
-
- # otherwise, we dispatch by the type of the provided annotation to convert to
- if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
- # For namedtuple, call the function recursively on the lists of corresponding keys
- types = cls.__annotations__.values()
- dlist_T = zip(*dlist)
- res_T = [
- _dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
- ]
- return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
- elif issubclass(cls, (list, tuple)):
- # For list/tuple, call the function recursively on the lists of corresponding positions
- types = get_args(typeannot)
- if len(types) == 1: # probably List; replicate for all items
- types = types * len(dlist[0])
- dlist_T = zip(*dlist)
- res_T = (
- _dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
- )
- if issubclass(cls, tuple):
- return list(zip(*res_T))
- else:
- return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
- elif issubclass(cls, dict):
- # For the dictionary, call the function recursively on concatenated keys and vertices
- key_t, val_t = get_args(typeannot)
- all_keys_res = _dataclass_list_from_dict_list(
- [k for obj in dlist for k in obj.keys()], key_t
- )
- all_vals_res = _dataclass_list_from_dict_list(
- [k for obj in dlist for k in obj.values()], val_t
- )
- indices = np.cumsum([len(obj) for obj in dlist])
- assert indices[-1] == len(all_keys_res)
-
- keys = np.split(list(all_keys_res), indices[:-1])
- all_vals_res_iter = iter(all_vals_res)
- return [cls(zip(k, all_vals_res_iter)) for k in keys]
- elif not dataclasses.is_dataclass(typeannot):
- return dlist
-
- # dataclass node: 2nd recursion base; call the function recursively on the lists
- # of the corresponding fields
- assert dataclasses.is_dataclass(cls)
- fieldtypes = {
- f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
- for f in dataclasses.fields(typeannot)
- }
-
- # NOTE the default object is shared here
- key_lists = (
- _dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
- for k, (type_, default) in fieldtypes.items()
- )
- transposed = zip(*key_lists)
- return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import json
+import dataclasses
+import numpy as np
+from dataclasses import Field, MISSING
+from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple
+
+_X = TypeVar("_X")
+
+
+def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
+ """
+    Recursively loads a JSON file into a @dataclass or a collection
+    hierarchy containing dataclasses.
+    Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
+    Raises KeyError if the JSON has keys that do not map to the dataclass fields.
+
+ Args:
+        f: A file-like object containing the JSON data, opened for reading.
+        cls: The class of the loaded dataclass.
+        binary: Set to True if `f` is opened in binary mode, else False.
+ """
+ if binary:
+ asdict = json.loads(f.read().decode("utf8"))
+ else:
+ asdict = json.load(f)
+
+ # in the list case, run a faster "vectorized" version
+ cls = get_args(cls)[0]
+ res = list(_dataclass_list_from_dict_list(asdict, cls))
+
+ return res
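+
+# Usage sketch (comment only; `annotations_path` is a placeholder name), mirroring
+# how dr_dataset.py below loads Dynamic Replica annotations from gzipped JSON:
+#
+#     with gzip.open(annotations_path, "rt", encoding="utf8") as f:
+#         annots = load_dataclass(f, List[DynamicReplicaFrameAnnotation])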
+
+
+def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
+ """Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
+ if get_origin(type_) is Union:
+ args = get_args(type_)
+ if len(args) == 2 and args[1] == type(None): # noqa E721
+ return True, args[0]
+ if type_ is Any:
+ return True, Any
+
+ return False, type_
+
+
+def _unwrap_type(tp):
+ # strips Optional wrapper, if any
+ if get_origin(tp) is Union:
+ args = get_args(tp)
+ if len(args) == 2 and any(a is type(None) for a in args): # noqa: E721
+ # this is typing.Optional
+ return args[0] if args[1] is type(None) else args[1] # noqa: E721
+ return tp
+
+
+def _get_dataclass_field_default(field: Field) -> Any:
+ if field.default_factory is not MISSING:
+ # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
+ # dataclasses._DefaultFactory[typing.Any]]` is not a function.
+ return field.default_factory()
+ elif field.default is not MISSING:
+ return field.default
+ else:
+ return None
+
+
+def _dataclass_list_from_dict_list(dlist, typeannot):
+ """
+ Vectorised version of `_dataclass_from_dict`.
+ The output should be equivalent to
+ `[_dataclass_from_dict(d, typeannot) for d in dlist]`.
+
+ Args:
+ dlist: list of objects to convert.
+ typeannot: type of each of those objects.
+ Returns:
+ iterator or list over converted objects of the same length as `dlist`.
+
+ Raises:
+ ValueError: it assumes the objects have None's in consistent places across
+ objects, otherwise it would ignore some values. This generally holds for
+ auto-generated annotations, but otherwise use `_dataclass_from_dict`.
+ """
+
+ cls = get_origin(typeannot) or typeannot
+
+ if typeannot is Any:
+ return dlist
+ if all(obj is None for obj in dlist): # 1st recursion base: all None nodes
+ return dlist
+ if any(obj is None for obj in dlist):
+ # filter out Nones and recurse on the resulting list
+ idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
+ idx, notnone = zip(*idx_notnone)
+ converted = _dataclass_list_from_dict_list(notnone, typeannot)
+ res = [None] * len(dlist)
+ for i, obj in zip(idx, converted):
+ res[i] = obj
+ return res
+
+ is_optional, contained_type = _resolve_optional(typeannot)
+ if is_optional:
+ return _dataclass_list_from_dict_list(dlist, contained_type)
+
+ # otherwise, we dispatch by the type of the provided annotation to convert to
+ if issubclass(cls, tuple) and hasattr(cls, "_fields"): # namedtuple
+ # For namedtuple, call the function recursively on the lists of corresponding keys
+ types = cls.__annotations__.values()
+ dlist_T = zip(*dlist)
+ res_T = [
+ _dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
+ ]
+ return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
+ elif issubclass(cls, (list, tuple)):
+ # For list/tuple, call the function recursively on the lists of corresponding positions
+ types = get_args(typeannot)
+ if len(types) == 1: # probably List; replicate for all items
+ types = types * len(dlist[0])
+ dlist_T = zip(*dlist)
+ res_T = (
+ _dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
+ )
+ if issubclass(cls, tuple):
+ return list(zip(*res_T))
+ else:
+ return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
+ elif issubclass(cls, dict):
+        # For the dictionary, call the function recursively on concatenated keys and values
+ key_t, val_t = get_args(typeannot)
+ all_keys_res = _dataclass_list_from_dict_list(
+ [k for obj in dlist for k in obj.keys()], key_t
+ )
+ all_vals_res = _dataclass_list_from_dict_list(
+ [k for obj in dlist for k in obj.values()], val_t
+ )
+ indices = np.cumsum([len(obj) for obj in dlist])
+ assert indices[-1] == len(all_keys_res)
+
+ keys = np.split(list(all_keys_res), indices[:-1])
+ all_vals_res_iter = iter(all_vals_res)
+ return [cls(zip(k, all_vals_res_iter)) for k in keys]
+ elif not dataclasses.is_dataclass(typeannot):
+ return dlist
+
+ # dataclass node: 2nd recursion base; call the function recursively on the lists
+ # of the corresponding fields
+ assert dataclasses.is_dataclass(cls)
+ fieldtypes = {
+ f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
+ for f in dataclasses.fields(typeannot)
+ }
+
+ # NOTE the default object is shared here
+ key_lists = (
+ _dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
+ for k, (type_, default) in fieldtypes.items()
+ )
+ transposed = zip(*key_lists)
+ return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
diff --git a/cotracker/datasets/dr_dataset.py b/cotracker/datasets/dr_dataset.py
index 70af653..9a31884 100644
--- a/cotracker/datasets/dr_dataset.py
+++ b/cotracker/datasets/dr_dataset.py
@@ -1,161 +1,161 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import os
-import gzip
-import torch
-import numpy as np
-import torch.utils.data as data
-from collections import defaultdict
-from dataclasses import dataclass
-from typing import List, Optional, Any, Dict, Tuple
-
-from cotracker.datasets.utils import CoTrackerData
-from cotracker.datasets.dataclass_utils import load_dataclass
-
-
-@dataclass
-class ImageAnnotation:
- # path to jpg file, relative w.r.t. dataset_root
- path: str
- # H x W
- size: Tuple[int, int]
-
-
-@dataclass
-class DynamicReplicaFrameAnnotation:
- """A dataclass used to load annotations from json."""
-
- # can be used to join with `SequenceAnnotation`
- sequence_name: str
- # 0-based, continuous frame number within sequence
- frame_number: int
- # timestamp in seconds from the video start
- frame_timestamp: float
-
- image: ImageAnnotation
- meta: Optional[Dict[str, Any]] = None
-
- camera_name: Optional[str] = None
- trajectories: Optional[str] = None
-
-
-class DynamicReplicaDataset(data.Dataset):
- def __init__(
- self,
- root,
- split="valid",
- traj_per_sample=256,
- crop_size=None,
- sample_len=-1,
- only_first_n_samples=-1,
- rgbd_input=False,
- ):
- super(DynamicReplicaDataset, self).__init__()
- self.root = root
- self.sample_len = sample_len
- self.split = split
- self.traj_per_sample = traj_per_sample
- self.rgbd_input = rgbd_input
- self.crop_size = crop_size
- frame_annotations_file = f"frame_annotations_{split}.jgz"
- self.sample_list = []
- with gzip.open(
- os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
- ) as zipfile:
- frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
- seq_annot = defaultdict(list)
- for frame_annot in frame_annots_list:
- if frame_annot.camera_name == "left":
- seq_annot[frame_annot.sequence_name].append(frame_annot)
-
- for seq_name in seq_annot.keys():
- seq_len = len(seq_annot[seq_name])
-
- step = self.sample_len if self.sample_len > 0 else seq_len
- counter = 0
-
- for ref_idx in range(0, seq_len, step):
- sample = seq_annot[seq_name][ref_idx : ref_idx + step]
- self.sample_list.append(sample)
- counter += 1
- if only_first_n_samples > 0 and counter >= only_first_n_samples:
- break
-
- def __len__(self):
- return len(self.sample_list)
-
- def crop(self, rgbs, trajs):
- T, N, _ = trajs.shape
-
- S = len(rgbs)
- H, W = rgbs[0].shape[:2]
- assert S == T
-
- H_new = H
- W_new = W
-
- # simple random crop
- y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
- x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
- rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
-
- trajs[:, :, 0] -= x0
- trajs[:, :, 1] -= y0
-
- return rgbs, trajs
-
- def __getitem__(self, index):
- sample = self.sample_list[index]
- T = len(sample)
- rgbs, visibilities, traj_2d = [], [], []
-
- H, W = sample[0].image.size
- image_size = (H, W)
-
- for i in range(T):
- traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
- traj = torch.load(traj_path)
-
- visibilities.append(traj["verts_inds_vis"].numpy())
-
- rgbs.append(traj["img"].numpy())
- traj_2d.append(traj["traj_2d"].numpy()[..., :2])
-
- traj_2d = np.stack(traj_2d)
- visibility = np.stack(visibilities)
- T, N, D = traj_2d.shape
- # subsample trajectories for augmentations
- visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
-
- traj_2d = traj_2d[:, visible_inds_sampled]
- visibility = visibility[:, visible_inds_sampled]
-
- if self.crop_size is not None:
- rgbs, traj_2d = self.crop(rgbs, traj_2d)
- H, W, _ = rgbs[0].shape
- image_size = self.crop_size
-
- visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
- visibility[traj_2d[:, :, 0] < 0] = False
- visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
- visibility[traj_2d[:, :, 1] < 0] = False
-
- # filter out points that're visible for less than 10 frames
- visible_inds_resampled = visibility.sum(0) > 10
- traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
- visibility = torch.from_numpy(visibility[:, visible_inds_resampled])
-
- rgbs = np.stack(rgbs, 0)
- video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
- return CoTrackerData(
- video=video,
- trajectory=traj_2d,
- visibility=visibility,
- valid=torch.ones(T, N),
- seq_name=sample[0].sequence_name,
- )
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import os
+import gzip
+import torch
+import numpy as np
+import torch.utils.data as data
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import List, Optional, Any, Dict, Tuple
+
+from cotracker.datasets.utils import CoTrackerData
+from cotracker.datasets.dataclass_utils import load_dataclass
+
+
+@dataclass
+class ImageAnnotation:
+ # path to jpg file, relative w.r.t. dataset_root
+ path: str
+ # H x W
+ size: Tuple[int, int]
+
+
+@dataclass
+class DynamicReplicaFrameAnnotation:
+ """A dataclass used to load annotations from json."""
+
+ # can be used to join with `SequenceAnnotation`
+ sequence_name: str
+ # 0-based, continuous frame number within sequence
+ frame_number: int
+ # timestamp in seconds from the video start
+ frame_timestamp: float
+
+ image: ImageAnnotation
+ meta: Optional[Dict[str, Any]] = None
+
+ camera_name: Optional[str] = None
+ trajectories: Optional[str] = None
+
+
+class DynamicReplicaDataset(data.Dataset):
+ def __init__(
+ self,
+ root,
+ split="valid",
+ traj_per_sample=256,
+ crop_size=None,
+ sample_len=-1,
+ only_first_n_samples=-1,
+ rgbd_input=False,
+ ):
+ super(DynamicReplicaDataset, self).__init__()
+ self.root = root
+ self.sample_len = sample_len
+ self.split = split
+ self.traj_per_sample = traj_per_sample
+ self.rgbd_input = rgbd_input
+ self.crop_size = crop_size
+ frame_annotations_file = f"frame_annotations_{split}.jgz"
+ self.sample_list = []
+ with gzip.open(
+ os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
+ ) as zipfile:
+ frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
+ seq_annot = defaultdict(list)
+ for frame_annot in frame_annots_list:
+ if frame_annot.camera_name == "left":
+ seq_annot[frame_annot.sequence_name].append(frame_annot)
+
+ for seq_name in seq_annot.keys():
+ seq_len = len(seq_annot[seq_name])
+
+ step = self.sample_len if self.sample_len > 0 else seq_len
+ counter = 0
+
+ for ref_idx in range(0, seq_len, step):
+ sample = seq_annot[seq_name][ref_idx : ref_idx + step]
+ self.sample_list.append(sample)
+ counter += 1
+ if only_first_n_samples > 0 and counter >= only_first_n_samples:
+ break
+
+ def __len__(self):
+ return len(self.sample_list)
+
+ def crop(self, rgbs, trajs):
+ T, N, _ = trajs.shape
+
+ S = len(rgbs)
+ H, W = rgbs[0].shape[:2]
+ assert S == T
+
+ H_new = H
+ W_new = W
+
+        # simple center crop
+ y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
+ x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
+ rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
+
+ trajs[:, :, 0] -= x0
+ trajs[:, :, 1] -= y0
+
+ return rgbs, trajs
+
+ def __getitem__(self, index):
+ sample = self.sample_list[index]
+ T = len(sample)
+ rgbs, visibilities, traj_2d = [], [], []
+
+ H, W = sample[0].image.size
+ image_size = (H, W)
+
+ for i in range(T):
+ traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
+ traj = torch.load(traj_path)
+
+ visibilities.append(traj["verts_inds_vis"].numpy())
+
+ rgbs.append(traj["img"].numpy())
+ traj_2d.append(traj["traj_2d"].numpy()[..., :2])
+
+ traj_2d = np.stack(traj_2d)
+ visibility = np.stack(visibilities)
+ T, N, D = traj_2d.shape
+ # subsample trajectories for augmentations
+ visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]
+
+ traj_2d = traj_2d[:, visible_inds_sampled]
+ visibility = visibility[:, visible_inds_sampled]
+
+ if self.crop_size is not None:
+ rgbs, traj_2d = self.crop(rgbs, traj_2d)
+ H, W, _ = rgbs[0].shape
+ image_size = self.crop_size
+
+ visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
+ visibility[traj_2d[:, :, 0] < 0] = False
+ visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
+ visibility[traj_2d[:, :, 1] < 0] = False
+
+        # filter out points that are visible in 10 or fewer frames
+ visible_inds_resampled = visibility.sum(0) > 10
+ traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
+ visibility = torch.from_numpy(visibility[:, visible_inds_resampled])
+
+ rgbs = np.stack(rgbs, 0)
+ video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
+ return CoTrackerData(
+ video=video,
+ trajectory=traj_2d,
+ visibility=visibility,
+ valid=torch.ones(T, N),
+ seq_name=sample[0].sequence_name,
+ )
diff --git a/cotracker/datasets/kubric_movif_dataset.py b/cotracker/datasets/kubric_movif_dataset.py
index 366d738..68ce73c 100644
--- a/cotracker/datasets/kubric_movif_dataset.py
+++ b/cotracker/datasets/kubric_movif_dataset.py
@@ -1,441 +1,441 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-import torch
-import cv2
-
-import imageio
-import numpy as np
-
-from cotracker.datasets.utils import CoTrackerData
-from torchvision.transforms import ColorJitter, GaussianBlur
-from PIL import Image
-
-
-class CoTrackerDataset(torch.utils.data.Dataset):
- def __init__(
- self,
- data_root,
- crop_size=(384, 512),
- seq_len=24,
- traj_per_sample=768,
- sample_vis_1st_frame=False,
- use_augs=False,
- ):
- super(CoTrackerDataset, self).__init__()
- np.random.seed(0)
- torch.manual_seed(0)
- self.data_root = data_root
- self.seq_len = seq_len
- self.traj_per_sample = traj_per_sample
- self.sample_vis_1st_frame = sample_vis_1st_frame
- self.use_augs = use_augs
- self.crop_size = crop_size
-
- # photometric augmentation
- self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
- self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
-
- self.blur_aug_prob = 0.25
- self.color_aug_prob = 0.25
-
- # occlusion augmentation
- self.eraser_aug_prob = 0.5
- self.eraser_bounds = [2, 100]
- self.eraser_max = 10
-
- # occlusion augmentation
- self.replace_aug_prob = 0.5
- self.replace_bounds = [2, 100]
- self.replace_max = 10
-
- # spatial augmentations
- self.pad_bounds = [0, 100]
- self.crop_size = crop_size
- self.resize_lim = [0.25, 2.0] # sample resizes from here
- self.resize_delta = 0.2
- self.max_crop_offset = 50
-
- self.do_flip = True
- self.h_flip_prob = 0.5
- self.v_flip_prob = 0.5
-
- def getitem_helper(self, index):
- return NotImplementedError
-
- def __getitem__(self, index):
- gotit = False
-
- sample, gotit = self.getitem_helper(index)
- if not gotit:
- print("warning: sampling failed")
- # fake sample, so we can still collate
- sample = CoTrackerData(
- video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
- trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
- visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
- valid=torch.zeros((self.seq_len, self.traj_per_sample)),
- )
-
- return sample, gotit
-
- def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
- T, N, _ = trajs.shape
-
- S = len(rgbs)
- H, W = rgbs[0].shape[:2]
- assert S == T
-
- if eraser:
- ############ eraser transform (per image after the first) ############
- rgbs = [rgb.astype(np.float32) for rgb in rgbs]
- for i in range(1, S):
- if np.random.rand() < self.eraser_aug_prob:
- for _ in range(
- np.random.randint(1, self.eraser_max + 1)
- ): # number of times to occlude
- xc = np.random.randint(0, W)
- yc = np.random.randint(0, H)
- dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
- dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
- x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
- x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
- y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
- y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
-
- mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
- rgbs[i][y0:y1, x0:x1, :] = mean_color
-
- occ_inds = np.logical_and(
- np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
- np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
- )
- visibles[i, occ_inds] = 0
- rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
-
- if replace:
- rgbs_alt = [
- np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
- ]
- rgbs_alt = [
- np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
- ]
-
- ############ replace transform (per image after the first) ############
- rgbs = [rgb.astype(np.float32) for rgb in rgbs]
- rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
- for i in range(1, S):
- if np.random.rand() < self.replace_aug_prob:
- for _ in range(
- np.random.randint(1, self.replace_max + 1)
- ): # number of times to occlude
- xc = np.random.randint(0, W)
- yc = np.random.randint(0, H)
- dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
- dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
- x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
- x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
- y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
- y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
-
- wid = x1 - x0
- hei = y1 - y0
- y00 = np.random.randint(0, H - hei)
- x00 = np.random.randint(0, W - wid)
- fr = np.random.randint(0, S)
- rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
- rgbs[i][y0:y1, x0:x1, :] = rep
-
- occ_inds = np.logical_and(
- np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
- np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
- )
- visibles[i, occ_inds] = 0
- rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
-
- ############ photometric augmentation ############
- if np.random.rand() < self.color_aug_prob:
- # random per-frame amount of aug
- rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
-
- if np.random.rand() < self.blur_aug_prob:
- # random per-frame amount of blur
- rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
-
- return rgbs, trajs, visibles
-
- def add_spatial_augs(self, rgbs, trajs, visibles):
- T, N, __ = trajs.shape
-
- S = len(rgbs)
- H, W = rgbs[0].shape[:2]
- assert S == T
-
- rgbs = [rgb.astype(np.float32) for rgb in rgbs]
-
- ############ spatial transform ############
-
- # padding
- pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
- pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
- pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
- pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
-
- rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
- trajs[:, :, 0] += pad_x0
- trajs[:, :, 1] += pad_y0
- H, W = rgbs[0].shape[:2]
-
- # scaling + stretching
- scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
- scale_x = scale
- scale_y = scale
- H_new = H
- W_new = W
-
- scale_delta_x = 0.0
- scale_delta_y = 0.0
-
- rgbs_scaled = []
- for s in range(S):
- if s == 1:
- scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
- scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
- elif s > 1:
- scale_delta_x = (
- scale_delta_x * 0.8
- + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
- )
- scale_delta_y = (
- scale_delta_y * 0.8
- + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
- )
- scale_x = scale_x + scale_delta_x
- scale_y = scale_y + scale_delta_y
-
- # bring h/w closer
- scale_xy = (scale_x + scale_y) * 0.5
- scale_x = scale_x * 0.5 + scale_xy * 0.5
- scale_y = scale_y * 0.5 + scale_xy * 0.5
-
- # don't get too crazy
- scale_x = np.clip(scale_x, 0.2, 2.0)
- scale_y = np.clip(scale_y, 0.2, 2.0)
-
- H_new = int(H * scale_y)
- W_new = int(W * scale_x)
-
- # make it at least slightly bigger than the crop area,
- # so that the random cropping can add diversity
- H_new = np.clip(H_new, self.crop_size[0] + 10, None)
- W_new = np.clip(W_new, self.crop_size[1] + 10, None)
- # recompute scale in case we clipped
- scale_x = (W_new - 1) / float(W - 1)
- scale_y = (H_new - 1) / float(H - 1)
- rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
- trajs[s, :, 0] *= scale_x
- trajs[s, :, 1] *= scale_y
- rgbs = rgbs_scaled
-
- ok_inds = visibles[0, :] > 0
- vis_trajs = trajs[:, ok_inds] # S,?,2
-
- if vis_trajs.shape[1] > 0:
- mid_x = np.mean(vis_trajs[0, :, 0])
- mid_y = np.mean(vis_trajs[0, :, 1])
- else:
- mid_y = self.crop_size[0]
- mid_x = self.crop_size[1]
-
- x0 = int(mid_x - self.crop_size[1] // 2)
- y0 = int(mid_y - self.crop_size[0] // 2)
-
- offset_x = 0
- offset_y = 0
-
- for s in range(S):
- # on each frame, shift a bit more
- if s == 1:
- offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
- offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
- elif s > 1:
- offset_x = int(
- offset_x * 0.8
- + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
- )
- offset_y = int(
- offset_y * 0.8
- + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
- )
- x0 = x0 + offset_x
- y0 = y0 + offset_y
-
- H_new, W_new = rgbs[s].shape[:2]
- if H_new == self.crop_size[0]:
- y0 = 0
- else:
- y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)
-
- if W_new == self.crop_size[1]:
- x0 = 0
- else:
- x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)
-
- rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
- trajs[s, :, 0] -= x0
- trajs[s, :, 1] -= y0
-
- H_new = self.crop_size[0]
- W_new = self.crop_size[1]
-
- # flip
- h_flipped = False
- v_flipped = False
- if self.do_flip:
- # h flip
- if np.random.rand() < self.h_flip_prob:
- h_flipped = True
- rgbs = [rgb[:, ::-1] for rgb in rgbs]
- # v flip
- if np.random.rand() < self.v_flip_prob:
- v_flipped = True
- rgbs = [rgb[::-1] for rgb in rgbs]
- if h_flipped:
- trajs[:, :, 0] = W_new - trajs[:, :, 0]
- if v_flipped:
- trajs[:, :, 1] = H_new - trajs[:, :, 1]
-
- return rgbs, trajs
-
- def crop(self, rgbs, trajs):
- T, N, _ = trajs.shape
-
- S = len(rgbs)
- H, W = rgbs[0].shape[:2]
- assert S == T
-
- ############ spatial transform ############
-
- H_new = H
- W_new = W
-
- # simple random crop
- y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
- x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
- rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
-
- trajs[:, :, 0] -= x0
- trajs[:, :, 1] -= y0
-
- return rgbs, trajs
-
-
-class KubricMovifDataset(CoTrackerDataset):
- def __init__(
- self,
- data_root,
- crop_size=(384, 512),
- seq_len=24,
- traj_per_sample=768,
- sample_vis_1st_frame=False,
- use_augs=False,
- ):
- super(KubricMovifDataset, self).__init__(
- data_root=data_root,
- crop_size=crop_size,
- seq_len=seq_len,
- traj_per_sample=traj_per_sample,
- sample_vis_1st_frame=sample_vis_1st_frame,
- use_augs=use_augs,
- )
-
- self.pad_bounds = [0, 25]
- self.resize_lim = [0.75, 1.25] # sample resizes from here
- self.resize_delta = 0.05
- self.max_crop_offset = 15
- self.seq_names = [
- fname
- for fname in os.listdir(data_root)
- if os.path.isdir(os.path.join(data_root, fname))
- ]
- print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
-
- def getitem_helper(self, index):
- gotit = True
- seq_name = self.seq_names[index]
-
- npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
- rgb_path = os.path.join(self.data_root, seq_name, "frames")
-
- img_paths = sorted(os.listdir(rgb_path))
- rgbs = []
- for i, img_path in enumerate(img_paths):
- rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
-
- rgbs = np.stack(rgbs)
- annot_dict = np.load(npy_path, allow_pickle=True).item()
- traj_2d = annot_dict["coords"]
- visibility = annot_dict["visibility"]
-
- # random crop
- assert self.seq_len <= len(rgbs)
- if self.seq_len < len(rgbs):
- start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]
-
- rgbs = rgbs[start_ind : start_ind + self.seq_len]
- traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
- visibility = visibility[:, start_ind : start_ind + self.seq_len]
-
- traj_2d = np.transpose(traj_2d, (1, 0, 2))
- visibility = np.transpose(np.logical_not(visibility), (1, 0))
- if self.use_augs:
- rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
- rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
- else:
- rgbs, traj_2d = self.crop(rgbs, traj_2d)
-
- visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
- visibility[traj_2d[:, :, 0] < 0] = False
- visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
- visibility[traj_2d[:, :, 1] < 0] = False
-
- visibility = torch.from_numpy(visibility)
- traj_2d = torch.from_numpy(traj_2d)
-
- visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
-
- if self.sample_vis_1st_frame:
- visibile_pts_inds = visibile_pts_first_frame_inds
- else:
- visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
- :, 0
- ]
- visibile_pts_inds = torch.cat(
- (visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
- )
- point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
- if len(point_inds) < self.traj_per_sample:
- gotit = False
-
- visible_inds_sampled = visibile_pts_inds[point_inds]
-
- trajs = traj_2d[:, visible_inds_sampled].float()
- visibles = visibility[:, visible_inds_sampled]
- valids = torch.ones((self.seq_len, self.traj_per_sample))
-
- rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
- sample = CoTrackerData(
- video=rgbs,
- trajectory=trajs,
- visibility=visibles,
- valid=valids,
- seq_name=seq_name,
- )
- return sample, gotit
-
- def __len__(self):
- return len(self.seq_names)
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import torch
+import cv2
+
+import imageio
+import numpy as np
+
+from cotracker.datasets.utils import CoTrackerData
+from torchvision.transforms import ColorJitter, GaussianBlur
+from PIL import Image
+
+
+class CoTrackerDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ data_root,
+ crop_size=(384, 512),
+ seq_len=24,
+ traj_per_sample=768,
+ sample_vis_1st_frame=False,
+ use_augs=False,
+ ):
+ super(CoTrackerDataset, self).__init__()
+ np.random.seed(0)
+ torch.manual_seed(0)
+ self.data_root = data_root
+ self.seq_len = seq_len
+ self.traj_per_sample = traj_per_sample
+ self.sample_vis_1st_frame = sample_vis_1st_frame
+ self.use_augs = use_augs
+ self.crop_size = crop_size
+
+ # photometric augmentation
+ self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
+ self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))
+
+ self.blur_aug_prob = 0.25
+ self.color_aug_prob = 0.25
+
+ # occlusion augmentation
+ self.eraser_aug_prob = 0.5
+ self.eraser_bounds = [2, 100]
+ self.eraser_max = 10
+
+        # replacement augmentation (occlude regions with patches from other frames)
+ self.replace_aug_prob = 0.5
+ self.replace_bounds = [2, 100]
+ self.replace_max = 10
+
+ # spatial augmentations
+ self.pad_bounds = [0, 100]
+ self.crop_size = crop_size
+ self.resize_lim = [0.25, 2.0] # sample resizes from here
+ self.resize_delta = 0.2
+ self.max_crop_offset = 50
+
+ self.do_flip = True
+ self.h_flip_prob = 0.5
+ self.v_flip_prob = 0.5
+
+ def getitem_helper(self, index):
+        raise NotImplementedError
+
+ def __getitem__(self, index):
+ gotit = False
+
+ sample, gotit = self.getitem_helper(index)
+ if not gotit:
+ print("warning: sampling failed")
+ # fake sample, so we can still collate
+ sample = CoTrackerData(
+ video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
+ trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
+ visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
+ valid=torch.zeros((self.seq_len, self.traj_per_sample)),
+ )
+
+ return sample, gotit
+
+ def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
+ T, N, _ = trajs.shape
+
+ S = len(rgbs)
+ H, W = rgbs[0].shape[:2]
+ assert S == T
+
+ if eraser:
+ ############ eraser transform (per image after the first) ############
+ rgbs = [rgb.astype(np.float32) for rgb in rgbs]
+ for i in range(1, S):
+ if np.random.rand() < self.eraser_aug_prob:
+ for _ in range(
+ np.random.randint(1, self.eraser_max + 1)
+ ): # number of times to occlude
+ xc = np.random.randint(0, W)
+ yc = np.random.randint(0, H)
+ dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
+ dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
+ x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
+ x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
+ y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
+ y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
+
+ mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
+ rgbs[i][y0:y1, x0:x1, :] = mean_color
+
+ occ_inds = np.logical_and(
+ np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
+ np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
+ )
+ visibles[i, occ_inds] = 0
+ rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
+
+ if replace:
+ rgbs_alt = [
+ np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
+ ]
+ rgbs_alt = [
+ np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
+ ]
+
+ ############ replace transform (per image after the first) ############
+ rgbs = [rgb.astype(np.float32) for rgb in rgbs]
+ rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
+ for i in range(1, S):
+ if np.random.rand() < self.replace_aug_prob:
+ for _ in range(
+ np.random.randint(1, self.replace_max + 1)
+ ): # number of times to occlude
+ xc = np.random.randint(0, W)
+ yc = np.random.randint(0, H)
+ dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
+ dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
+ x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
+ x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
+ y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
+ y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)
+
+ wid = x1 - x0
+ hei = y1 - y0
+ y00 = np.random.randint(0, H - hei)
+ x00 = np.random.randint(0, W - wid)
+ fr = np.random.randint(0, S)
+ rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
+ rgbs[i][y0:y1, x0:x1, :] = rep
+
+ occ_inds = np.logical_and(
+ np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
+ np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
+ )
+ visibles[i, occ_inds] = 0
+ rgbs = [rgb.astype(np.uint8) for rgb in rgbs]
+
+ ############ photometric augmentation ############
+ if np.random.rand() < self.color_aug_prob:
+ # random per-frame amount of aug
+ rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
+
+ if np.random.rand() < self.blur_aug_prob:
+ # random per-frame amount of blur
+ rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]
+
+ return rgbs, trajs, visibles
+
+ def add_spatial_augs(self, rgbs, trajs, visibles):
+ T, N, __ = trajs.shape
+
+ S = len(rgbs)
+ H, W = rgbs[0].shape[:2]
+ assert S == T
+
+ rgbs = [rgb.astype(np.float32) for rgb in rgbs]
+
+ ############ spatial transform ############
+
+ # padding
+ pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
+ pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
+ pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
+ pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
+
+ rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
+ trajs[:, :, 0] += pad_x0
+ trajs[:, :, 1] += pad_y0
+ H, W = rgbs[0].shape[:2]
+
+ # scaling + stretching
+ scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
+ scale_x = scale
+ scale_y = scale
+ H_new = H
+ W_new = W
+
+ scale_delta_x = 0.0
+ scale_delta_y = 0.0
+
+ rgbs_scaled = []
+ for s in range(S):
+ if s == 1:
+ scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
+ scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
+ elif s > 1:
+ scale_delta_x = (
+ scale_delta_x * 0.8
+ + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
+ )
+ scale_delta_y = (
+ scale_delta_y * 0.8
+ + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
+ )
+ scale_x = scale_x + scale_delta_x
+ scale_y = scale_y + scale_delta_y
+
+ # bring h/w closer
+ scale_xy = (scale_x + scale_y) * 0.5
+ scale_x = scale_x * 0.5 + scale_xy * 0.5
+ scale_y = scale_y * 0.5 + scale_xy * 0.5
+
+ # don't get too crazy
+ scale_x = np.clip(scale_x, 0.2, 2.0)
+ scale_y = np.clip(scale_y, 0.2, 2.0)
+
+ H_new = int(H * scale_y)
+ W_new = int(W * scale_x)
+
+ # make it at least slightly bigger than the crop area,
+ # so that the random cropping can add diversity
+ H_new = np.clip(H_new, self.crop_size[0] + 10, None)
+ W_new = np.clip(W_new, self.crop_size[1] + 10, None)
+ # recompute scale in case we clipped
+ scale_x = (W_new - 1) / float(W - 1)
+ scale_y = (H_new - 1) / float(H - 1)
+ rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
+ trajs[s, :, 0] *= scale_x
+ trajs[s, :, 1] *= scale_y
+ rgbs = rgbs_scaled
+
+ ok_inds = visibles[0, :] > 0
+ vis_trajs = trajs[:, ok_inds] # S,?,2
+
+ if vis_trajs.shape[1] > 0:
+ mid_x = np.mean(vis_trajs[0, :, 0])
+ mid_y = np.mean(vis_trajs[0, :, 1])
+ else:
+ mid_y = self.crop_size[0]
+ mid_x = self.crop_size[1]
+
+ x0 = int(mid_x - self.crop_size[1] // 2)
+ y0 = int(mid_y - self.crop_size[0] // 2)
+
+ offset_x = 0
+ offset_y = 0
+
+ for s in range(S):
+ # on each frame, shift a bit more
+ if s == 1:
+ offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
+ offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
+ elif s > 1:
+ offset_x = int(
+ offset_x * 0.8
+ + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
+ )
+ offset_y = int(
+ offset_y * 0.8
+ + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
+ )
+ x0 = x0 + offset_x
+ y0 = y0 + offset_y
+
+ H_new, W_new = rgbs[s].shape[:2]
+ if H_new == self.crop_size[0]:
+ y0 = 0
+ else:
+ y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)
+
+ if W_new == self.crop_size[1]:
+ x0 = 0
+ else:
+ x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)
+
+ rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
+ trajs[s, :, 0] -= x0
+ trajs[s, :, 1] -= y0
+
+ H_new = self.crop_size[0]
+ W_new = self.crop_size[1]
+
+ # flip
+ h_flipped = False
+ v_flipped = False
+ if self.do_flip:
+ # h flip
+ if np.random.rand() < self.h_flip_prob:
+ h_flipped = True
+ rgbs = [rgb[:, ::-1] for rgb in rgbs]
+ # v flip
+ if np.random.rand() < self.v_flip_prob:
+ v_flipped = True
+ rgbs = [rgb[::-1] for rgb in rgbs]
+ if h_flipped:
+ trajs[:, :, 0] = W_new - trajs[:, :, 0]
+ if v_flipped:
+ trajs[:, :, 1] = H_new - trajs[:, :, 1]
+
+ return rgbs, trajs
+
+ def crop(self, rgbs, trajs):
+ T, N, _ = trajs.shape
+
+ S = len(rgbs)
+ H, W = rgbs[0].shape[:2]
+ assert S == T
+
+ ############ spatial transform ############
+
+ H_new = H
+ W_new = W
+
+ # simple random crop
+ y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
+ x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
+ rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]
+
+ trajs[:, :, 0] -= x0
+ trajs[:, :, 1] -= y0
+
+ return rgbs, trajs
+
+
+class KubricMovifDataset(CoTrackerDataset):
+ def __init__(
+ self,
+ data_root,
+ crop_size=(384, 512),
+ seq_len=24,
+ traj_per_sample=768,
+ sample_vis_1st_frame=False,
+ use_augs=False,
+ ):
+ super(KubricMovifDataset, self).__init__(
+ data_root=data_root,
+ crop_size=crop_size,
+ seq_len=seq_len,
+ traj_per_sample=traj_per_sample,
+ sample_vis_1st_frame=sample_vis_1st_frame,
+ use_augs=use_augs,
+ )
+
+ self.pad_bounds = [0, 25]
+ self.resize_lim = [0.75, 1.25] # sample resizes from here
+ self.resize_delta = 0.05
+ self.max_crop_offset = 15
+ self.seq_names = [
+ fname
+ for fname in os.listdir(data_root)
+ if os.path.isdir(os.path.join(data_root, fname))
+ ]
+ print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))
+
+ def getitem_helper(self, index):
+ gotit = True
+ seq_name = self.seq_names[index]
+
+ npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
+ rgb_path = os.path.join(self.data_root, seq_name, "frames")
+
+ img_paths = sorted(os.listdir(rgb_path))
+ rgbs = []
+ for i, img_path in enumerate(img_paths):
+ rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))
+
+ rgbs = np.stack(rgbs)
+ annot_dict = np.load(npy_path, allow_pickle=True).item()
+ traj_2d = annot_dict["coords"]
+ visibility = annot_dict["visibility"]
+
+        # random temporal crop to seq_len frames
+ assert self.seq_len <= len(rgbs)
+ if self.seq_len < len(rgbs):
+ start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]
+
+ rgbs = rgbs[start_ind : start_ind + self.seq_len]
+ traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
+ visibility = visibility[:, start_ind : start_ind + self.seq_len]
+
+ traj_2d = np.transpose(traj_2d, (1, 0, 2))
+ visibility = np.transpose(np.logical_not(visibility), (1, 0))
+ if self.use_augs:
+ rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
+ rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
+ else:
+ rgbs, traj_2d = self.crop(rgbs, traj_2d)
+
+ visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
+ visibility[traj_2d[:, :, 0] < 0] = False
+ visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
+ visibility[traj_2d[:, :, 1] < 0] = False
+
+ visibility = torch.from_numpy(visibility)
+ traj_2d = torch.from_numpy(traj_2d)
+
+        visible_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]
+
+        if self.sample_vis_1st_frame:
+            visible_pts_inds = visible_pts_first_frame_inds
+        else:
+            visible_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
+                :, 0
+            ]
+            visible_pts_inds = torch.cat(
+                (visible_pts_first_frame_inds, visible_pts_mid_frame_inds), dim=0
+            )
+        point_inds = torch.randperm(len(visible_pts_inds))[: self.traj_per_sample]
+        if len(point_inds) < self.traj_per_sample:
+            gotit = False
+
+        visible_inds_sampled = visible_pts_inds[point_inds]
+
+ trajs = traj_2d[:, visible_inds_sampled].float()
+ visibles = visibility[:, visible_inds_sampled]
+ valids = torch.ones((self.seq_len, self.traj_per_sample))
+
+ rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
+ sample = CoTrackerData(
+ video=rgbs,
+ trajectory=trajs,
+ visibility=visibles,
+ valid=valids,
+ seq_name=seq_name,
+ )
+ return sample, gotit
+
+ def __len__(self):
+ return len(self.seq_names)
diff --git a/cotracker/datasets/tap_vid_datasets.py b/cotracker/datasets/tap_vid_datasets.py
index 72e0001..5597b83 100644
--- a/cotracker/datasets/tap_vid_datasets.py
+++ b/cotracker/datasets/tap_vid_datasets.py
@@ -1,209 +1,209 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-import io
-import glob
-import torch
-import pickle
-import numpy as np
-import mediapy as media
-
-from PIL import Image
-from typing import Mapping, Tuple, Union
-
-from cotracker.datasets.utils import CoTrackerData
-
-DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
-
-
-def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
- """Resize a video to output_size."""
- # If you have a GPU, consider replacing this with a GPU-enabled resize op,
- # such as a jitted jax.image.resize. It will make things faster.
- return media.resize_video(video, output_size)
-
-
-def sample_queries_first(
- target_occluded: np.ndarray,
- target_points: np.ndarray,
- frames: np.ndarray,
-) -> Mapping[str, np.ndarray]:
- """Package a set of frames and tracks for use in TAPNet evaluations.
- Given a set of frames and tracks with no query points, use the first
- visible point in each track as the query.
- Args:
- target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
- where True indicates occluded.
- target_points: Position, of shape [n_tracks, n_frames, 2], where each point
- is [x,y] scaled between 0 and 1.
- frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
- -1 and 1.
- Returns:
- A dict with the keys:
- video: Video tensor of shape [1, n_frames, height, width, 3]
- query_points: Query points of shape [1, n_queries, 3] where
- each point is [t, y, x] scaled to the range [-1, 1]
- target_points: Target points of shape [1, n_queries, n_frames, 2] where
- each point is [x, y] scaled to the range [-1, 1]
- """
- valid = np.sum(~target_occluded, axis=1) > 0
- target_points = target_points[valid, :]
- target_occluded = target_occluded[valid, :]
-
- query_points = []
- for i in range(target_points.shape[0]):
- index = np.where(target_occluded[i] == 0)[0][0]
- x, y = target_points[i, index, 0], target_points[i, index, 1]
- query_points.append(np.array([index, y, x])) # [t, y, x]
- query_points = np.stack(query_points, axis=0)
-
- return {
- "video": frames[np.newaxis, ...],
- "query_points": query_points[np.newaxis, ...],
- "target_points": target_points[np.newaxis, ...],
- "occluded": target_occluded[np.newaxis, ...],
- }
-
-
-def sample_queries_strided(
- target_occluded: np.ndarray,
- target_points: np.ndarray,
- frames: np.ndarray,
- query_stride: int = 5,
-) -> Mapping[str, np.ndarray]:
- """Package a set of frames and tracks for use in TAPNet evaluations.
-
- Given a set of frames and tracks with no query points, sample queries
- strided every query_stride frames, ignoring points that are not visible
- at the selected frames.
-
- Args:
- target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
- where True indicates occluded.
- target_points: Position, of shape [n_tracks, n_frames, 2], where each point
- is [x,y] scaled between 0 and 1.
- frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
- -1 and 1.
- query_stride: When sampling query points, search for un-occluded points
- every query_stride frames and convert each one into a query.
-
- Returns:
- A dict with the keys:
- video: Video tensor of shape [1, n_frames, height, width, 3]. The video
- has floats scaled to the range [-1, 1].
- query_points: Query points of shape [1, n_queries, 3] where
- each point is [t, y, x] scaled to the range [-1, 1].
- target_points: Target points of shape [1, n_queries, n_frames, 2] where
- each point is [x, y] scaled to the range [-1, 1].
- trackgroup: Index of the original track that each query point was
- sampled from. This is useful for visualization.
- """
- tracks = []
- occs = []
- queries = []
- trackgroups = []
- total = 0
- trackgroup = np.arange(target_occluded.shape[0])
- for i in range(0, target_occluded.shape[1], query_stride):
- mask = target_occluded[:, i] == 0
- query = np.stack(
- [
- i * np.ones(target_occluded.shape[0:1]),
- target_points[:, i, 1],
- target_points[:, i, 0],
- ],
- axis=-1,
- )
- queries.append(query[mask])
- tracks.append(target_points[mask])
- occs.append(target_occluded[mask])
- trackgroups.append(trackgroup[mask])
- total += np.array(np.sum(target_occluded[:, i] == 0))
-
- return {
- "video": frames[np.newaxis, ...],
- "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
- "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
- "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
- "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
- }
-
-
-class TapVidDataset(torch.utils.data.Dataset):
- def __init__(
- self,
- data_root,
- dataset_type="davis",
- resize_to_256=True,
- queried_first=True,
- ):
- self.dataset_type = dataset_type
- self.resize_to_256 = resize_to_256
- self.queried_first = queried_first
- if self.dataset_type == "kinetics":
- all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
- points_dataset = []
- for pickle_path in all_paths:
- with open(pickle_path, "rb") as f:
- data = pickle.load(f)
- points_dataset = points_dataset + data
- self.points_dataset = points_dataset
- else:
- with open(data_root, "rb") as f:
- self.points_dataset = pickle.load(f)
- if self.dataset_type == "davis":
- self.video_names = list(self.points_dataset.keys())
- print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
-
- def __getitem__(self, index):
- if self.dataset_type == "davis":
- video_name = self.video_names[index]
- else:
- video_name = index
- video = self.points_dataset[video_name]
- frames = video["video"]
-
- if isinstance(frames[0], bytes):
- # TAP-Vid is stored and JPEG bytes rather than `np.ndarray`s.
- def decode(frame):
- byteio = io.BytesIO(frame)
- img = Image.open(byteio)
- return np.array(img)
-
- frames = np.array([decode(frame) for frame in frames])
-
- target_points = self.points_dataset[video_name]["points"]
- if self.resize_to_256:
- frames = resize_video(frames, [256, 256])
- target_points *= np.array([255, 255]) # 1 should be mapped to 256-1
- else:
- target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
-
- target_occ = self.points_dataset[video_name]["occluded"]
- if self.queried_first:
- converted = sample_queries_first(target_occ, target_points, frames)
- else:
- converted = sample_queries_strided(target_occ, target_points, frames)
- assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
-
- trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
-
- rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
- visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
- 1, 0
- ) # T, N
- query_points = torch.from_numpy(converted["query_points"])[0] # T, N
- return CoTrackerData(
- rgbs,
- trajs,
- visibles,
- seq_name=str(video_name),
- query_points=query_points,
- )
-
- def __len__(self):
- return len(self.points_dataset)
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import io
+import glob
+import torch
+import pickle
+import numpy as np
+import mediapy as media
+
+from PIL import Image
+from typing import Mapping, Tuple, Union
+
+from cotracker.datasets.utils import CoTrackerData
+
+DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]
+
+
+def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
+ """Resize a video to output_size."""
+ # If you have a GPU, consider replacing this with a GPU-enabled resize op,
+ # such as a jitted jax.image.resize. It will make things faster.
+ return media.resize_video(video, output_size)
+
+
+def sample_queries_first(
+ target_occluded: np.ndarray,
+ target_points: np.ndarray,
+ frames: np.ndarray,
+) -> Mapping[str, np.ndarray]:
+ """Package a set of frames and tracks for use in TAPNet evaluations.
+ Given a set of frames and tracks with no query points, use the first
+ visible point in each track as the query.
+ Args:
+ target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
+ where True indicates occluded.
+ target_points: Position, of shape [n_tracks, n_frames, 2], where each point
+ is [x,y] scaled between 0 and 1.
+ frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
+ -1 and 1.
+ Returns:
+ A dict with the keys:
+ video: Video tensor of shape [1, n_frames, height, width, 3]
+ query_points: Query points of shape [1, n_queries, 3] where
+ each point is [t, y, x] scaled to the range [-1, 1]
+ target_points: Target points of shape [1, n_queries, n_frames, 2] where
+ each point is [x, y] scaled to the range [-1, 1]
+ """
+ valid = np.sum(~target_occluded, axis=1) > 0
+ target_points = target_points[valid, :]
+ target_occluded = target_occluded[valid, :]
+
+ query_points = []
+ for i in range(target_points.shape[0]):
+ index = np.where(target_occluded[i] == 0)[0][0]
+ x, y = target_points[i, index, 0], target_points[i, index, 1]
+ query_points.append(np.array([index, y, x])) # [t, y, x]
+ query_points = np.stack(query_points, axis=0)
+
+ return {
+ "video": frames[np.newaxis, ...],
+ "query_points": query_points[np.newaxis, ...],
+ "target_points": target_points[np.newaxis, ...],
+ "occluded": target_occluded[np.newaxis, ...],
+ }
+
+
+def sample_queries_strided(
+ target_occluded: np.ndarray,
+ target_points: np.ndarray,
+ frames: np.ndarray,
+ query_stride: int = 5,
+) -> Mapping[str, np.ndarray]:
+ """Package a set of frames and tracks for use in TAPNet evaluations.
+
+ Given a set of frames and tracks with no query points, sample queries
+ strided every query_stride frames, ignoring points that are not visible
+ at the selected frames.
+
+ Args:
+ target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
+ where True indicates occluded.
+ target_points: Position, of shape [n_tracks, n_frames, 2], where each point
+ is [x,y] scaled between 0 and 1.
+ frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
+ -1 and 1.
+ query_stride: When sampling query points, search for un-occluded points
+ every query_stride frames and convert each one into a query.
+
+ Returns:
+ A dict with the keys:
+ video: Video tensor of shape [1, n_frames, height, width, 3]. The video
+ has floats scaled to the range [-1, 1].
+ query_points: Query points of shape [1, n_queries, 3] where
+ each point is [t, y, x] scaled to the range [-1, 1].
+ target_points: Target points of shape [1, n_queries, n_frames, 2] where
+ each point is [x, y] scaled to the range [-1, 1].
+ trackgroup: Index of the original track that each query point was
+ sampled from. This is useful for visualization.
+ """
+ tracks = []
+ occs = []
+ queries = []
+ trackgroups = []
+ total = 0
+ trackgroup = np.arange(target_occluded.shape[0])
+ for i in range(0, target_occluded.shape[1], query_stride):
+ mask = target_occluded[:, i] == 0
+ query = np.stack(
+ [
+ i * np.ones(target_occluded.shape[0:1]),
+ target_points[:, i, 1],
+ target_points[:, i, 0],
+ ],
+ axis=-1,
+ )
+ queries.append(query[mask])
+ tracks.append(target_points[mask])
+ occs.append(target_occluded[mask])
+ trackgroups.append(trackgroup[mask])
+ total += np.array(np.sum(target_occluded[:, i] == 0))
+
+ return {
+ "video": frames[np.newaxis, ...],
+ "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
+ "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
+ "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
+ "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
+ }
+
+
+class TapVidDataset(torch.utils.data.Dataset):
+ def __init__(
+ self,
+ data_root,
+ dataset_type="davis",
+ resize_to_256=True,
+ queried_first=True,
+ ):
+ self.dataset_type = dataset_type
+ self.resize_to_256 = resize_to_256
+ self.queried_first = queried_first
+ if self.dataset_type == "kinetics":
+ all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
+ points_dataset = []
+ for pickle_path in all_paths:
+ with open(pickle_path, "rb") as f:
+ data = pickle.load(f)
+ points_dataset = points_dataset + data
+ self.points_dataset = points_dataset
+ else:
+ with open(data_root, "rb") as f:
+ self.points_dataset = pickle.load(f)
+ if self.dataset_type == "davis":
+ self.video_names = list(self.points_dataset.keys())
+ print("found %d unique videos in %s" % (len(self.points_dataset), data_root))
+
+ def __getitem__(self, index):
+ if self.dataset_type == "davis":
+ video_name = self.video_names[index]
+ else:
+ video_name = index
+ video = self.points_dataset[video_name]
+ frames = video["video"]
+
+ if isinstance(frames[0], bytes):
+            # TAP-Vid frames are stored as JPEG bytes rather than `np.ndarray`s.
+ def decode(frame):
+ byteio = io.BytesIO(frame)
+ img = Image.open(byteio)
+ return np.array(img)
+
+ frames = np.array([decode(frame) for frame in frames])
+
+ target_points = self.points_dataset[video_name]["points"]
+ if self.resize_to_256:
+ frames = resize_video(frames, [256, 256])
+            target_points *= np.array([255, 255])  # a coordinate of 1.0 maps to 255 (= 256 - 1)
+ else:
+ target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])
+
+ target_occ = self.points_dataset[video_name]["occluded"]
+ if self.queried_first:
+ converted = sample_queries_first(target_occ, target_points, frames)
+ else:
+ converted = sample_queries_strided(target_occ, target_points, frames)
+ assert converted["target_points"].shape[1] == converted["query_points"].shape[1]
+
+ trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float() # T, N, D
+
+ rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
+ visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
+ 1, 0
+ ) # T, N
+        query_points = torch.from_numpy(converted["query_points"])[0]  # N, 3 as (t, y, x)
+ return CoTrackerData(
+ rgbs,
+ trajs,
+ visibles,
+ seq_name=str(video_name),
+ query_points=query_points,
+ )
+
+ def __len__(self):
+ return len(self.points_dataset)
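
For reference, a minimal sketch of loading TAP-Vid DAVIS with the class above and inspecting one returned sample; the pickle path is a placeholder and has to point to a downloaded copy of tapvid_davis.pkl.

from cotracker.datasets.tap_vid_datasets import TapVidDataset

dataset = TapVidDataset(
    data_root="./datasets/tapvid_davis/tapvid_davis.pkl",  # placeholder path
    dataset_type="davis",
    resize_to_256=True,
    queried_first=True,
)
sample = dataset[0]                 # a CoTrackerData instance
print(sample.video.shape)           # (T, 3, 256, 256) float frames
print(sample.trajectory.shape)      # (T, N, 2) ground-truth tracks in pixel coordinates
print(sample.visibility.shape)      # (T, N) boolean visibility flags
print(sample.query_points.shape)    # (N, 3) queries as (t, y, x)
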
diff --git a/cotracker/datasets/utils.py b/cotracker/datasets/utils.py
index 30149f1..09b5ede 100644
--- a/cotracker/datasets/utils.py
+++ b/cotracker/datasets/utils.py
@@ -1,106 +1,106 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-import torch
-import dataclasses
-import torch.nn.functional as F
-from dataclasses import dataclass
-from typing import Any, Optional
-
-
-@dataclass(eq=False)
-class CoTrackerData:
- """
- Dataclass for storing video tracks data.
- """
-
- video: torch.Tensor # B, S, C, H, W
- trajectory: torch.Tensor # B, S, N, 2
- visibility: torch.Tensor # B, S, N
- # optional data
- valid: Optional[torch.Tensor] = None # B, S, N
- segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
- seq_name: Optional[str] = None
- query_points: Optional[torch.Tensor] = None # TapVID evaluation format
-
-
-def collate_fn(batch):
- """
- Collate function for video tracks data.
- """
- video = torch.stack([b.video for b in batch], dim=0)
- trajectory = torch.stack([b.trajectory for b in batch], dim=0)
- visibility = torch.stack([b.visibility for b in batch], dim=0)
- query_points = segmentation = None
- if batch[0].query_points is not None:
- query_points = torch.stack([b.query_points for b in batch], dim=0)
- if batch[0].segmentation is not None:
- segmentation = torch.stack([b.segmentation for b in batch], dim=0)
- seq_name = [b.seq_name for b in batch]
-
- return CoTrackerData(
- video=video,
- trajectory=trajectory,
- visibility=visibility,
- segmentation=segmentation,
- seq_name=seq_name,
- query_points=query_points,
- )
-
-
-def collate_fn_train(batch):
- """
- Collate function for video tracks data during training.
- """
- gotit = [gotit for _, gotit in batch]
- video = torch.stack([b.video for b, _ in batch], dim=0)
- trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
- visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
- valid = torch.stack([b.valid for b, _ in batch], dim=0)
- seq_name = [b.seq_name for b, _ in batch]
- return (
- CoTrackerData(
- video=video,
- trajectory=trajectory,
- visibility=visibility,
- valid=valid,
- seq_name=seq_name,
- ),
- gotit,
- )
-
-
-def try_to_cuda(t: Any) -> Any:
- """
- Try to move the input variable `t` to a cuda device.
-
- Args:
- t: Input.
-
- Returns:
- t_cuda: `t` moved to a cuda device, if supported.
- """
- try:
- t = t.float().cuda()
- except AttributeError:
- pass
- return t
-
-
-def dataclass_to_cuda_(obj):
- """
- Move all contents of a dataclass to cuda inplace if supported.
-
- Args:
- batch: Input dataclass.
-
- Returns:
- batch_cuda: `batch` moved to a cuda device, if supported.
- """
- for f in dataclasses.fields(obj):
- setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
- return obj
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import dataclasses
+import torch.nn.functional as F
+from dataclasses import dataclass
+from typing import Any, Optional
+
+
+@dataclass(eq=False)
+class CoTrackerData:
+ """
+    Dataclass for storing a video together with its point tracks and visibility masks.
+ """
+
+ video: torch.Tensor # B, S, C, H, W
+ trajectory: torch.Tensor # B, S, N, 2
+ visibility: torch.Tensor # B, S, N
+ # optional data
+ valid: Optional[torch.Tensor] = None # B, S, N
+ segmentation: Optional[torch.Tensor] = None # B, S, 1, H, W
+ seq_name: Optional[str] = None
+ query_points: Optional[torch.Tensor] = None # TapVID evaluation format
+
+
+def collate_fn(batch):
+ """
+ Collate function for video tracks data.
+ """
+ video = torch.stack([b.video for b in batch], dim=0)
+ trajectory = torch.stack([b.trajectory for b in batch], dim=0)
+ visibility = torch.stack([b.visibility for b in batch], dim=0)
+ query_points = segmentation = None
+ if batch[0].query_points is not None:
+ query_points = torch.stack([b.query_points for b in batch], dim=0)
+ if batch[0].segmentation is not None:
+ segmentation = torch.stack([b.segmentation for b in batch], dim=0)
+ seq_name = [b.seq_name for b in batch]
+
+ return CoTrackerData(
+ video=video,
+ trajectory=trajectory,
+ visibility=visibility,
+ segmentation=segmentation,
+ seq_name=seq_name,
+ query_points=query_points,
+ )
+
+
+def collate_fn_train(batch):
+ """
+ Collate function for video tracks data during training.
+ """
+ gotit = [gotit for _, gotit in batch]
+ video = torch.stack([b.video for b, _ in batch], dim=0)
+ trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
+ visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
+ valid = torch.stack([b.valid for b, _ in batch], dim=0)
+ seq_name = [b.seq_name for b, _ in batch]
+ return (
+ CoTrackerData(
+ video=video,
+ trajectory=trajectory,
+ visibility=visibility,
+ valid=valid,
+ seq_name=seq_name,
+ ),
+ gotit,
+ )
+
+
+def try_to_cuda(t: Any) -> Any:
+ """
+ Try to move the input variable `t` to a cuda device.
+
+ Args:
+ t: Input.
+
+ Returns:
+ t_cuda: `t` moved to a cuda device, if supported.
+ """
+ try:
+ t = t.float().cuda()
+ except AttributeError:
+ pass
+ return t
+
+
+def dataclass_to_cuda_(obj):
+ """
+    Move all contents of a dataclass to a cuda device in place, if supported.
+
+    Args:
+        obj: Input dataclass.
+
+    Returns:
+        obj: `obj` with its tensor fields moved to a cuda device, if supported.
+ """
+ for f in dataclasses.fields(obj):
+ setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
+ return obj
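
For reference, a minimal sketch of batching CoTrackerData samples with collate_fn and moving the batch to the GPU; the tensors are dummy placeholders and only illustrate the expected shapes.

import torch
from cotracker.datasets.utils import CoTrackerData, collate_fn, dataclass_to_cuda_

samples = [
    CoTrackerData(
        video=torch.zeros(8, 3, 256, 256),   # S frames, C, H, W
        trajectory=torch.zeros(8, 10, 2),    # S, N points, (x, y)
        visibility=torch.ones(8, 10),        # S, N
        seq_name=f"clip_{i}",                # placeholder sequence name
    )
    for i in range(2)
]
batch = collate_fn(samples)       # stacks per-sample tensors along a new batch dim
print(batch.video.shape)          # torch.Size([2, 8, 3, 256, 256])
if torch.cuda.is_available():
    dataclass_to_cuda_(batch)     # moves every tensor field to CUDA in place
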
diff --git a/cotracker/evaluation/__init__.py b/cotracker/evaluation/__init__.py
index 5277f46..4547e07 100644
--- a/cotracker/evaluation/__init__.py
+++ b/cotracker/evaluation/__init__.py
@@ -1,5 +1,5 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/cotracker/evaluation/configs/eval_dynamic_replica.yaml b/cotracker/evaluation/configs/eval_dynamic_replica.yaml
index 7d6fca9..2f9c325 100644
--- a/cotracker/evaluation/configs/eval_dynamic_replica.yaml
+++ b/cotracker/evaluation/configs/eval_dynamic_replica.yaml
@@ -1,6 +1,6 @@
-defaults:
- - default_config_eval
-exp_dir: ./outputs/cotracker
-dataset_name: dynamic_replica
-
+defaults:
+ - default_config_eval
+exp_dir: ./outputs/cotracker
+dataset_name: dynamic_replica
+
\ No newline at end of file
diff --git a/cotracker/evaluation/configs/eval_tapvid_davis_first.yaml b/cotracker/evaluation/configs/eval_tapvid_davis_first.yaml
index d37a6c9..0d72e37 100644
--- a/cotracker/evaluation/configs/eval_tapvid_davis_first.yaml
+++ b/cotracker/evaluation/configs/eval_tapvid_davis_first.yaml
@@ -1,6 +1,6 @@
-defaults:
- - default_config_eval
-exp_dir: ./outputs/cotracker
-dataset_name: tapvid_davis_first
-
+defaults:
+ - default_config_eval
+exp_dir: ./outputs/cotracker
+dataset_name: tapvid_davis_first
+
\ No newline at end of file
diff --git a/cotracker/evaluation/configs/eval_tapvid_davis_strided.yaml b/cotracker/evaluation/configs/eval_tapvid_davis_strided.yaml
index 6e3cf3c..5a687bc 100644
--- a/cotracker/evaluation/configs/eval_tapvid_davis_strided.yaml
+++ b/cotracker/evaluation/configs/eval_tapvid_davis_strided.yaml
@@ -1,6 +1,6 @@
-defaults:
- - default_config_eval
-exp_dir: ./outputs/cotracker
-dataset_name: tapvid_davis_strided
-
+defaults:
+ - default_config_eval
+exp_dir: ./outputs/cotracker
+dataset_name: tapvid_davis_strided
+
\ No newline at end of file
diff --git a/cotracker/evaluation/configs/eval_tapvid_kinetics_first.yaml b/cotracker/evaluation/configs/eval_tapvid_kinetics_first.yaml
index 3be8914..f8651f6 100644
--- a/cotracker/evaluation/configs/eval_tapvid_kinetics_first.yaml
+++ b/cotracker/evaluation/configs/eval_tapvid_kinetics_first.yaml
@@ -1,6 +1,6 @@
-defaults:
- - default_config_eval
-exp_dir: ./outputs/cotracker
-dataset_name: tapvid_kinetics_first
-
+defaults:
+ - default_config_eval
+exp_dir: ./outputs/cotracker
+dataset_name: tapvid_kinetics_first
+
\ No newline at end of file
diff --git a/cotracker/evaluation/core/__init__.py b/cotracker/evaluation/core/__init__.py
index 5277f46..4547e07 100644
--- a/cotracker/evaluation/core/__init__.py
+++ b/cotracker/evaluation/core/__init__.py
@@ -1,5 +1,5 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/cotracker/evaluation/core/eval_utils.py b/cotracker/evaluation/core/eval_utils.py
index 7002fa5..dca0380 100644
--- a/cotracker/evaluation/core/eval_utils.py
+++ b/cotracker/evaluation/core/eval_utils.py
@@ -1,138 +1,138 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import numpy as np
-
-from typing import Iterable, Mapping, Tuple, Union
-
-
-def compute_tapvid_metrics(
- query_points: np.ndarray,
- gt_occluded: np.ndarray,
- gt_tracks: np.ndarray,
- pred_occluded: np.ndarray,
- pred_tracks: np.ndarray,
- query_mode: str,
-) -> Mapping[str, np.ndarray]:
- """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
- See the TAP-Vid paper for details on the metric computation. All inputs are
- given in raster coordinates. The first three arguments should be the direct
- outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
- The paper metrics assume these are scaled relative to 256x256 images.
- pred_occluded and pred_tracks are your algorithm's predictions.
- This function takes a batch of inputs, and computes metrics separately for
- each video. The metrics for the full benchmark are a simple mean of the
- metrics across the full set of videos. These numbers are between 0 and 1,
- but the paper multiplies them by 100 to ease reading.
- Args:
- query_points: The query points, an in the format [t, y, x]. Its size is
- [b, n, 3], where b is the batch size and n is the number of queries
- gt_occluded: A boolean array of shape [b, n, t], where t is the number
- of frames. True indicates that the point is occluded.
- gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
- in the format [x, y]
- pred_occluded: A boolean array of predicted occlusions, in the same
- format as gt_occluded.
- pred_tracks: An array of track predictions from your algorithm, in the
- same format as gt_tracks.
- query_mode: Either 'first' or 'strided', depending on how queries are
- sampled. If 'first', we assume the prior knowledge that all points
- before the query point are occluded, and these are removed from the
- evaluation.
- Returns:
- A dict with the following keys:
- occlusion_accuracy: Accuracy at predicting occlusion.
- pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
- predicted to be within the given pixel threshold, ignoring occlusion
- prediction.
- jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
- threshold
- average_pts_within_thresh: average across pts_within_{x}
- average_jaccard: average across jaccard_{x}
- """
-
- metrics = {}
- # Fixed bug is described in:
- # https://github.com/facebookresearch/co-tracker/issues/20
- eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
-
- if query_mode == "first":
- # evaluate frames after the query frame
- query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
- elif query_mode == "strided":
- # evaluate all frames except the query frame
- query_frame_to_eval_frames = 1 - eye
- else:
- raise ValueError("Unknown query mode " + query_mode)
-
- query_frame = query_points[..., 0]
- query_frame = np.round(query_frame).astype(np.int32)
- evaluation_points = query_frame_to_eval_frames[query_frame] > 0
-
- # Occlusion accuracy is simply how often the predicted occlusion equals the
- # ground truth.
- occ_acc = np.sum(
- np.equal(pred_occluded, gt_occluded) & evaluation_points,
- axis=(1, 2),
- ) / np.sum(evaluation_points)
- metrics["occlusion_accuracy"] = occ_acc
-
- # Next, convert the predictions and ground truth positions into pixel
- # coordinates.
- visible = np.logical_not(gt_occluded)
- pred_visible = np.logical_not(pred_occluded)
- all_frac_within = []
- all_jaccard = []
- for thresh in [1, 2, 4, 8, 16]:
- # True positives are points that are within the threshold and where both
- # the prediction and the ground truth are listed as visible.
- within_dist = np.sum(
- np.square(pred_tracks - gt_tracks),
- axis=-1,
- ) < np.square(thresh)
- is_correct = np.logical_and(within_dist, visible)
-
- # Compute the frac_within_threshold, which is the fraction of points
- # within the threshold among points that are visible in the ground truth,
- # ignoring whether they're predicted to be visible.
- count_correct = np.sum(
- is_correct & evaluation_points,
- axis=(1, 2),
- )
- count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
- frac_correct = count_correct / count_visible_points
- metrics["pts_within_" + str(thresh)] = frac_correct
- all_frac_within.append(frac_correct)
-
- true_positives = np.sum(
- is_correct & pred_visible & evaluation_points, axis=(1, 2)
- )
-
- # The denominator of the jaccard metric is the true positives plus
- # false positives plus false negatives. However, note that true positives
- # plus false negatives is simply the number of points in the ground truth
- # which is easier to compute than trying to compute all three quantities.
- # Thus we just add the number of points in the ground truth to the number
- # of false positives.
- #
- # False positives are simply points that are predicted to be visible,
- # but the ground truth is not visible or too far from the prediction.
- gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
- false_positives = (~visible) & pred_visible
- false_positives = false_positives | ((~within_dist) & pred_visible)
- false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
- jaccard = true_positives / (gt_positives + false_positives)
- metrics["jaccard_" + str(thresh)] = jaccard
- all_jaccard.append(jaccard)
- metrics["average_jaccard"] = np.mean(
- np.stack(all_jaccard, axis=1),
- axis=1,
- )
- metrics["average_pts_within_thresh"] = np.mean(
- np.stack(all_frac_within, axis=1),
- axis=1,
- )
- return metrics
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import numpy as np
+
+from typing import Iterable, Mapping, Tuple, Union
+
+
+def compute_tapvid_metrics(
+ query_points: np.ndarray,
+ gt_occluded: np.ndarray,
+ gt_tracks: np.ndarray,
+ pred_occluded: np.ndarray,
+ pred_tracks: np.ndarray,
+ query_mode: str,
+) -> Mapping[str, np.ndarray]:
+ """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
+ See the TAP-Vid paper for details on the metric computation. All inputs are
+ given in raster coordinates. The first three arguments should be the direct
+ outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
+ The paper metrics assume these are scaled relative to 256x256 images.
+ pred_occluded and pred_tracks are your algorithm's predictions.
+ This function takes a batch of inputs, and computes metrics separately for
+ each video. The metrics for the full benchmark are a simple mean of the
+ metrics across the full set of videos. These numbers are between 0 and 1,
+ but the paper multiplies them by 100 to ease reading.
+ Args:
+        query_points: The query points, in the format [t, y, x]. Its size is
+        [b, n, 3], where b is the batch size and n is the number of queries.
+ gt_occluded: A boolean array of shape [b, n, t], where t is the number
+ of frames. True indicates that the point is occluded.
+ gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
+ in the format [x, y]
+ pred_occluded: A boolean array of predicted occlusions, in the same
+ format as gt_occluded.
+ pred_tracks: An array of track predictions from your algorithm, in the
+ same format as gt_tracks.
+ query_mode: Either 'first' or 'strided', depending on how queries are
+ sampled. If 'first', we assume the prior knowledge that all points
+ before the query point are occluded, and these are removed from the
+ evaluation.
+ Returns:
+ A dict with the following keys:
+ occlusion_accuracy: Accuracy at predicting occlusion.
+ pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
+ predicted to be within the given pixel threshold, ignoring occlusion
+ prediction.
+ jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
+ threshold
+ average_pts_within_thresh: average across pts_within_{x}
+ average_jaccard: average across jaccard_{x}
+ """
+
+ metrics = {}
+    # The bug fixed here is described in:
+ # https://github.com/facebookresearch/co-tracker/issues/20
+ eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
+
+ if query_mode == "first":
+ # evaluate frames after the query frame
+ query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
+ elif query_mode == "strided":
+ # evaluate all frames except the query frame
+ query_frame_to_eval_frames = 1 - eye
+ else:
+ raise ValueError("Unknown query mode " + query_mode)
+
+ query_frame = query_points[..., 0]
+ query_frame = np.round(query_frame).astype(np.int32)
+ evaluation_points = query_frame_to_eval_frames[query_frame] > 0
+
+ # Occlusion accuracy is simply how often the predicted occlusion equals the
+ # ground truth.
+ occ_acc = np.sum(
+ np.equal(pred_occluded, gt_occluded) & evaluation_points,
+ axis=(1, 2),
+ ) / np.sum(evaluation_points)
+ metrics["occlusion_accuracy"] = occ_acc
+
+ # Next, convert the predictions and ground truth positions into pixel
+ # coordinates.
+ visible = np.logical_not(gt_occluded)
+ pred_visible = np.logical_not(pred_occluded)
+ all_frac_within = []
+ all_jaccard = []
+ for thresh in [1, 2, 4, 8, 16]:
+ # True positives are points that are within the threshold and where both
+ # the prediction and the ground truth are listed as visible.
+ within_dist = np.sum(
+ np.square(pred_tracks - gt_tracks),
+ axis=-1,
+ ) < np.square(thresh)
+ is_correct = np.logical_and(within_dist, visible)
+
+ # Compute the frac_within_threshold, which is the fraction of points
+ # within the threshold among points that are visible in the ground truth,
+ # ignoring whether they're predicted to be visible.
+ count_correct = np.sum(
+ is_correct & evaluation_points,
+ axis=(1, 2),
+ )
+ count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
+ frac_correct = count_correct / count_visible_points
+ metrics["pts_within_" + str(thresh)] = frac_correct
+ all_frac_within.append(frac_correct)
+
+ true_positives = np.sum(
+ is_correct & pred_visible & evaluation_points, axis=(1, 2)
+ )
+
+ # The denominator of the jaccard metric is the true positives plus
+ # false positives plus false negatives. However, note that true positives
+ # plus false negatives is simply the number of points in the ground truth
+ # which is easier to compute than trying to compute all three quantities.
+ # Thus we just add the number of points in the ground truth to the number
+ # of false positives.
+ #
+ # False positives are simply points that are predicted to be visible,
+ # but the ground truth is not visible or too far from the prediction.
+ gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
+ false_positives = (~visible) & pred_visible
+ false_positives = false_positives | ((~within_dist) & pred_visible)
+ false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
+ jaccard = true_positives / (gt_positives + false_positives)
+ metrics["jaccard_" + str(thresh)] = jaccard
+ all_jaccard.append(jaccard)
+ metrics["average_jaccard"] = np.mean(
+ np.stack(all_jaccard, axis=1),
+ axis=1,
+ )
+ metrics["average_pts_within_thresh"] = np.mean(
+ np.stack(all_frac_within, axis=1),
+ axis=1,
+ )
+ return metrics
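
For reference, a toy, self-contained call to compute_tapvid_metrics where the prediction equals the ground truth; the arrays are synthetic and only illustrate the expected shapes and a near-perfect score.

import numpy as np
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics

b, n, t = 1, 4, 6                               # batch, queries, frames
query_points = np.zeros((b, n, 3))              # all points queried at frame 0, (t, y, x)
gt_occluded = np.zeros((b, n, t), dtype=bool)   # every point visible in every frame
gt_tracks = np.random.rand(b, n, t, 2) * 255.0  # raster coordinates in a 256x256 frame

metrics = compute_tapvid_metrics(
    query_points,
    gt_occluded,
    gt_tracks,
    pred_occluded=gt_occluded,   # predict the ground truth exactly
    pred_tracks=gt_tracks,
    query_mode="first",
)
print(metrics["average_jaccard"])      # ~1.0 for the single video in the batch
print(metrics["occlusion_accuracy"])   # ~1.0 as well
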
diff --git a/cotracker/evaluation/core/evaluator.py b/cotracker/evaluation/core/evaluator.py
index ffc697e..d8487e6 100644
--- a/cotracker/evaluation/core/evaluator.py
+++ b/cotracker/evaluation/core/evaluator.py
@@ -1,253 +1,253 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from collections import defaultdict
-import os
-from typing import Optional
-import torch
-from tqdm import tqdm
-import numpy as np
-
-from torch.utils.tensorboard import SummaryWriter
-from cotracker.datasets.utils import dataclass_to_cuda_
-from cotracker.utils.visualizer import Visualizer
-from cotracker.models.core.model_utils import reduce_masked_mean
-from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
-
-import logging
-
-
-class Evaluator:
- """
- A class defining the CoTracker evaluator.
- """
-
- def __init__(self, exp_dir) -> None:
- # Visualization
- self.exp_dir = exp_dir
- os.makedirs(exp_dir, exist_ok=True)
- self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
- self.visualize_dir = os.path.join(exp_dir, "visualisations")
-
- def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
- if isinstance(pred_trajectory, tuple):
- pred_trajectory, pred_visibility = pred_trajectory
- else:
- pred_visibility = None
- if "tapvid" in dataset_name:
- B, T, N, D = sample.trajectory.shape
- traj = sample.trajectory.clone()
- thr = 0.9
-
- if pred_visibility is None:
- logging.warning("visibility is NONE")
- pred_visibility = torch.zeros_like(sample.visibility)
-
- if not pred_visibility.dtype == torch.bool:
- pred_visibility = pred_visibility > thr
-
- query_points = sample.query_points.clone().cpu().numpy()
-
- pred_visibility = pred_visibility[:, :, :N]
- pred_trajectory = pred_trajectory[:, :, :N]
-
- gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
- gt_occluded = (
- torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
- )
-
- pred_occluded = (
- torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
- )
- pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
-
- out_metrics = compute_tapvid_metrics(
- query_points,
- gt_occluded,
- gt_tracks,
- pred_occluded,
- pred_tracks,
- query_mode="strided" if "strided" in dataset_name else "first",
- )
-
- metrics[sample.seq_name[0]] = out_metrics
- for metric_name in out_metrics.keys():
- if "avg" not in metrics:
- metrics["avg"] = {}
- metrics["avg"][metric_name] = np.mean(
- [v[metric_name] for k, v in metrics.items() if k != "avg"]
- )
-
- logging.info(f"Metrics: {out_metrics}")
- logging.info(f"avg: {metrics['avg']}")
- print("metrics", out_metrics)
- print("avg", metrics["avg"])
- elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
- *_, N, _ = sample.trajectory.shape
- B, T, N = sample.visibility.shape
- H, W = sample.video.shape[-2:]
- device = sample.video.device
-
- out_metrics = {}
-
- d_vis_sum = d_occ_sum = d_sum_all = 0.0
- thrs = [1, 2, 4, 8, 16]
- sx_ = (W - 1) / 255.0
- sy_ = (H - 1) / 255.0
- sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
- sc_pt = torch.from_numpy(sc_py).float().to(device)
- __, first_visible_inds = torch.max(sample.visibility, dim=1)
-
- frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
- start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))
-
- for thr in thrs:
- d_ = (
- torch.norm(
- pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
- dim=-1,
- )
- < thr
- ).float() # B,S-1,N
- d_occ = (
- reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
- * 100.0
- )
- d_occ_sum += d_occ
- out_metrics[f"accuracy_occ_{thr}"] = d_occ
-
- d_vis = (
- reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
- )
- d_vis_sum += d_vis
- out_metrics[f"accuracy_vis_{thr}"] = d_vis
-
- d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
- d_sum_all += d_all
- out_metrics[f"accuracy_{thr}"] = d_all
-
- d_occ_avg = d_occ_sum / len(thrs)
- d_vis_avg = d_vis_sum / len(thrs)
- d_all_avg = d_sum_all / len(thrs)
-
- sur_thr = 50
- dists = torch.norm(
- pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
- dim=-1,
- ) # B,S,N
- dist_ok = 1 - (dists > sur_thr).float() * sample.visibility # B,S,N
- survival = torch.cumprod(dist_ok, dim=1) # B,S,N
- out_metrics["survival"] = torch.mean(survival).item() * 100.0
-
- out_metrics["accuracy_occ"] = d_occ_avg
- out_metrics["accuracy_vis"] = d_vis_avg
- out_metrics["accuracy"] = d_all_avg
-
- metrics[sample.seq_name[0]] = out_metrics
- for metric_name in out_metrics.keys():
- if "avg" not in metrics:
- metrics["avg"] = {}
- metrics["avg"][metric_name] = float(
- np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
- )
-
- logging.info(f"Metrics: {out_metrics}")
- logging.info(f"avg: {metrics['avg']}")
- print("metrics", out_metrics)
- print("avg", metrics["avg"])
-
- @torch.no_grad()
- def evaluate_sequence(
- self,
- model,
- test_dataloader: torch.utils.data.DataLoader,
- dataset_name: str,
- train_mode=False,
- visualize_every: int = 1,
- writer: Optional[SummaryWriter] = None,
- step: Optional[int] = 0,
- ):
- metrics = {}
-
- vis = Visualizer(
- save_dir=self.exp_dir,
- fps=7,
- )
-
- for ind, sample in enumerate(tqdm(test_dataloader)):
- if isinstance(sample, tuple):
- sample, gotit = sample
- if not all(gotit):
- print("batch is None")
- continue
- if torch.cuda.is_available():
- dataclass_to_cuda_(sample)
- device = torch.device("cuda")
- else:
- device = torch.device("cpu")
-
- if (
- not train_mode
- and hasattr(model, "sequence_len")
- and (sample.visibility[:, : model.sequence_len].sum() == 0)
- ):
- print(f"skipping batch {ind}")
- continue
-
- if "tapvid" in dataset_name:
- queries = sample.query_points.clone().float()
-
- queries = torch.stack(
- [
- queries[:, :, 0],
- queries[:, :, 2],
- queries[:, :, 1],
- ],
- dim=2,
- ).to(device)
- else:
- queries = torch.cat(
- [
- torch.zeros_like(sample.trajectory[:, 0, :, :1]),
- sample.trajectory[:, 0],
- ],
- dim=2,
- ).to(device)
-
- pred_tracks = model(sample.video, queries)
- if "strided" in dataset_name:
- inv_video = sample.video.flip(1).clone()
- inv_queries = queries.clone()
- inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
-
- pred_trj, pred_vsb = pred_tracks
- inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
-
- inv_pred_trj = inv_pred_trj.flip(1)
- inv_pred_vsb = inv_pred_vsb.flip(1)
-
- mask = pred_trj == 0
-
- pred_trj[mask] = inv_pred_trj[mask]
- pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
-
- pred_tracks = pred_trj, pred_vsb
-
- if dataset_name == "badja" or dataset_name == "fastcapture":
- seq_name = sample.seq_name[0]
- else:
- seq_name = str(ind)
- if ind % visualize_every == 0:
- vis.visualize(
- sample.video,
- pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
- filename=dataset_name + "_" + seq_name,
- writer=writer,
- step=step,
- )
-
- self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
- return metrics
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from collections import defaultdict
+import os
+from typing import Optional
+import torch
+from tqdm import tqdm
+import numpy as np
+
+from torch.utils.tensorboard import SummaryWriter
+from cotracker.datasets.utils import dataclass_to_cuda_
+from cotracker.utils.visualizer import Visualizer
+from cotracker.models.core.model_utils import reduce_masked_mean
+from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics
+
+import logging
+
+
+class Evaluator:
+ """
+ A class defining the CoTracker evaluator.
+ """
+
+ def __init__(self, exp_dir) -> None:
+ # Visualization
+ self.exp_dir = exp_dir
+ os.makedirs(exp_dir, exist_ok=True)
+ self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
+ self.visualize_dir = os.path.join(exp_dir, "visualisations")
+
+ def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
+ if isinstance(pred_trajectory, tuple):
+ pred_trajectory, pred_visibility = pred_trajectory
+ else:
+ pred_visibility = None
+ if "tapvid" in dataset_name:
+ B, T, N, D = sample.trajectory.shape
+ traj = sample.trajectory.clone()
+ thr = 0.9
+
+ if pred_visibility is None:
+ logging.warning("visibility is NONE")
+ pred_visibility = torch.zeros_like(sample.visibility)
+
+ if not pred_visibility.dtype == torch.bool:
+ pred_visibility = pred_visibility > thr
+
+ query_points = sample.query_points.clone().cpu().numpy()
+
+ pred_visibility = pred_visibility[:, :, :N]
+ pred_trajectory = pred_trajectory[:, :, :N]
+
+ gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
+ gt_occluded = (
+ torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
+ )
+
+ pred_occluded = (
+ torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
+ )
+ pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()
+
+ out_metrics = compute_tapvid_metrics(
+ query_points,
+ gt_occluded,
+ gt_tracks,
+ pred_occluded,
+ pred_tracks,
+ query_mode="strided" if "strided" in dataset_name else "first",
+ )
+
+ metrics[sample.seq_name[0]] = out_metrics
+ for metric_name in out_metrics.keys():
+ if "avg" not in metrics:
+ metrics["avg"] = {}
+ metrics["avg"][metric_name] = np.mean(
+ [v[metric_name] for k, v in metrics.items() if k != "avg"]
+ )
+
+ logging.info(f"Metrics: {out_metrics}")
+ logging.info(f"avg: {metrics['avg']}")
+ print("metrics", out_metrics)
+ print("avg", metrics["avg"])
+ elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
+ *_, N, _ = sample.trajectory.shape
+ B, T, N = sample.visibility.shape
+ H, W = sample.video.shape[-2:]
+ device = sample.video.device
+
+ out_metrics = {}
+
+ d_vis_sum = d_occ_sum = d_sum_all = 0.0
+ thrs = [1, 2, 4, 8, 16]
+ sx_ = (W - 1) / 255.0
+ sy_ = (H - 1) / 255.0
+ sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
+ sc_pt = torch.from_numpy(sc_py).float().to(device)
+ __, first_visible_inds = torch.max(sample.visibility, dim=1)
+
+ frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
+ start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))
+
+ for thr in thrs:
+ d_ = (
+ torch.norm(
+ pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
+ dim=-1,
+ )
+ < thr
+ ).float() # B,S-1,N
+ d_occ = (
+ reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
+ * 100.0
+ )
+ d_occ_sum += d_occ
+ out_metrics[f"accuracy_occ_{thr}"] = d_occ
+
+ d_vis = (
+ reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
+ )
+ d_vis_sum += d_vis
+ out_metrics[f"accuracy_vis_{thr}"] = d_vis
+
+ d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
+ d_sum_all += d_all
+ out_metrics[f"accuracy_{thr}"] = d_all
+
+ d_occ_avg = d_occ_sum / len(thrs)
+ d_vis_avg = d_vis_sum / len(thrs)
+ d_all_avg = d_sum_all / len(thrs)
+
+ sur_thr = 50
+ dists = torch.norm(
+ pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
+ dim=-1,
+ ) # B,S,N
+ dist_ok = 1 - (dists > sur_thr).float() * sample.visibility # B,S,N
+ survival = torch.cumprod(dist_ok, dim=1) # B,S,N
+ out_metrics["survival"] = torch.mean(survival).item() * 100.0
+
+ out_metrics["accuracy_occ"] = d_occ_avg
+ out_metrics["accuracy_vis"] = d_vis_avg
+ out_metrics["accuracy"] = d_all_avg
+
+ metrics[sample.seq_name[0]] = out_metrics
+ for metric_name in out_metrics.keys():
+ if "avg" not in metrics:
+ metrics["avg"] = {}
+ metrics["avg"][metric_name] = float(
+ np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
+ )
+
+ logging.info(f"Metrics: {out_metrics}")
+ logging.info(f"avg: {metrics['avg']}")
+ print("metrics", out_metrics)
+ print("avg", metrics["avg"])
+
+ @torch.no_grad()
+ def evaluate_sequence(
+ self,
+ model,
+ test_dataloader: torch.utils.data.DataLoader,
+ dataset_name: str,
+ train_mode=False,
+ visualize_every: int = 1,
+ writer: Optional[SummaryWriter] = None,
+ step: Optional[int] = 0,
+ ):
+ metrics = {}
+
+ vis = Visualizer(
+ save_dir=self.exp_dir,
+ fps=7,
+ )
+
+ for ind, sample in enumerate(tqdm(test_dataloader)):
+ if isinstance(sample, tuple):
+ sample, gotit = sample
+ if not all(gotit):
+ print("batch is None")
+ continue
+ if torch.cuda.is_available():
+ dataclass_to_cuda_(sample)
+ device = torch.device("cuda")
+ else:
+ device = torch.device("cpu")
+
+ if (
+ not train_mode
+ and hasattr(model, "sequence_len")
+ and (sample.visibility[:, : model.sequence_len].sum() == 0)
+ ):
+ print(f"skipping batch {ind}")
+ continue
+
+ if "tapvid" in dataset_name:
+ queries = sample.query_points.clone().float()
+
+ queries = torch.stack(
+ [
+ queries[:, :, 0],
+ queries[:, :, 2],
+ queries[:, :, 1],
+ ],
+ dim=2,
+ ).to(device)
+ else:
+ queries = torch.cat(
+ [
+ torch.zeros_like(sample.trajectory[:, 0, :, :1]),
+ sample.trajectory[:, 0],
+ ],
+ dim=2,
+ ).to(device)
+
+ pred_tracks = model(sample.video, queries)
+ if "strided" in dataset_name:
+ inv_video = sample.video.flip(1).clone()
+ inv_queries = queries.clone()
+ inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
+
+ pred_trj, pred_vsb = pred_tracks
+ inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)
+
+ inv_pred_trj = inv_pred_trj.flip(1)
+ inv_pred_vsb = inv_pred_vsb.flip(1)
+
+ mask = pred_trj == 0
+
+ pred_trj[mask] = inv_pred_trj[mask]
+ pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]
+
+ pred_tracks = pred_trj, pred_vsb
+
+ if dataset_name == "badja" or dataset_name == "fastcapture":
+ seq_name = sample.seq_name[0]
+ else:
+ seq_name = str(ind)
+ if ind % visualize_every == 0:
+ vis.visualize(
+ sample.video,
+ pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
+ filename=dataset_name + "_" + seq_name,
+ writer=writer,
+ step=step,
+ )
+
+ self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
+ return metrics
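
For reference, a minimal sketch of driving this Evaluator by hand on TAP-Vid DAVIS; the checkpoint, dataset and output paths are placeholders, and the predictor settings mirror the defaults used in evaluate.py below.

import torch
from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.utils import collate_fn
from cotracker.models.build_cotracker import build_cotracker
from cotracker.models.evaluation_predictor import EvaluationPredictor
from cotracker.evaluation.core.evaluator import Evaluator

test_dataset = TapVidDataset(
    data_root="./datasets/tapvid_davis/tapvid_davis.pkl",  # placeholder path
    dataset_type="davis",
    queried_first=True,
)
test_dataloader = torch.utils.data.DataLoader(
    test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn
)
predictor = EvaluationPredictor(
    build_cotracker("./checkpoints/cotracker2.pth"),  # placeholder checkpoint
    grid_size=5,
    local_grid_size=8,
    single_point=True,
    n_iters=6,
)
if torch.cuda.is_available():
    predictor.model = predictor.model.cuda()

evaluator = Evaluator("./outputs/demo_eval")  # placeholder experiment dir
metrics = evaluator.evaluate_sequence(
    predictor, test_dataloader, dataset_name="tapvid_davis_first"
)
print(metrics["avg"])  # per-metric averages across all evaluated sequences
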
diff --git a/cotracker/evaluation/evaluate.py b/cotracker/evaluation/evaluate.py
index 5d679d2..f12248d 100644
--- a/cotracker/evaluation/evaluate.py
+++ b/cotracker/evaluation/evaluate.py
@@ -1,169 +1,169 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import json
-import os
-from dataclasses import dataclass, field
-
-import hydra
-import numpy as np
-
-import torch
-from omegaconf import OmegaConf
-
-from cotracker.datasets.tap_vid_datasets import TapVidDataset
-from cotracker.datasets.dr_dataset import DynamicReplicaDataset
-from cotracker.datasets.utils import collate_fn
-
-from cotracker.models.evaluation_predictor import EvaluationPredictor
-
-from cotracker.evaluation.core.evaluator import Evaluator
-from cotracker.models.build_cotracker import (
- build_cotracker,
-)
-
-
-@dataclass(eq=False)
-class DefaultConfig:
- # Directory where all outputs of the experiment will be saved.
- exp_dir: str = "./outputs"
-
- # Name of the dataset to be used for the evaluation.
- dataset_name: str = "tapvid_davis_first"
- # The root directory of the dataset.
- dataset_root: str = "./"
-
- # Path to the pre-trained model checkpoint to be used for the evaluation.
- # The default value is the path to a specific CoTracker model checkpoint.
- checkpoint: str = "./checkpoints/cotracker2.pth"
-
- # EvaluationPredictor parameters
- # The size (N) of the support grid used in the predictor.
- # The total number of points is (N*N).
- grid_size: int = 5
- # The size (N) of the local support grid.
- local_grid_size: int = 8
- # A flag indicating whether to evaluate one ground truth point at a time.
- single_point: bool = True
- # The number of iterative updates for each sliding window.
- n_iters: int = 6
-
- seed: int = 0
- gpu_idx: int = 0
-
- # Override hydra's working directory to current working dir,
- # also disable storing the .hydra logs:
- hydra: dict = field(
- default_factory=lambda: {
- "run": {"dir": "."},
- "output_subdir": None,
- }
- )
-
-
-def run_eval(cfg: DefaultConfig):
- """
- The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
-
- Args:
- cfg (DefaultConfig): An instance of DefaultConfig class which includes:
- - exp_dir (str): The directory path for the experiment.
- - dataset_name (str): The name of the dataset to be used.
- - dataset_root (str): The root directory of the dataset.
- - checkpoint (str): The path to the CoTracker model's checkpoint.
- - single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
- - n_iters (int): The number of iterative updates for each sliding window.
- - seed (int): The seed for setting the random state for reproducibility.
- - gpu_idx (int): The index of the GPU to be used.
- """
- # Creating the experiment directory if it doesn't exist
- os.makedirs(cfg.exp_dir, exist_ok=True)
-
- # Saving the experiment configuration to a .yaml file in the experiment directory
- cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
- with open(cfg_file, "w") as f:
- OmegaConf.save(config=cfg, f=f)
-
- evaluator = Evaluator(cfg.exp_dir)
- cotracker_model = build_cotracker(cfg.checkpoint)
-
- # Creating the EvaluationPredictor object
- predictor = EvaluationPredictor(
- cotracker_model,
- grid_size=cfg.grid_size,
- local_grid_size=cfg.local_grid_size,
- single_point=cfg.single_point,
- n_iters=cfg.n_iters,
- )
- if torch.cuda.is_available():
- predictor.model = predictor.model.cuda()
-
- # Setting the random seeds
- torch.manual_seed(cfg.seed)
- np.random.seed(cfg.seed)
-
- # Constructing the specified dataset
- curr_collate_fn = collate_fn
- if "tapvid" in cfg.dataset_name:
- dataset_type = cfg.dataset_name.split("_")[1]
- if dataset_type == "davis":
- data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl")
- elif dataset_type == "kinetics":
- data_root = os.path.join(
- cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
- )
- test_dataset = TapVidDataset(
- dataset_type=dataset_type,
- data_root=data_root,
- queried_first=not "strided" in cfg.dataset_name,
- )
- elif cfg.dataset_name == "dynamic_replica":
- test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1)
-
- # Creating the DataLoader object
- test_dataloader = torch.utils.data.DataLoader(
- test_dataset,
- batch_size=1,
- shuffle=False,
- num_workers=14,
- collate_fn=curr_collate_fn,
- )
-
- # Timing and conducting the evaluation
- import time
-
- start = time.time()
- evaluate_result = evaluator.evaluate_sequence(
- predictor,
- test_dataloader,
- dataset_name=cfg.dataset_name,
- )
- end = time.time()
- print(end - start)
-
- # Saving the evaluation results to a .json file
- evaluate_result = evaluate_result["avg"]
- print("evaluate_result", evaluate_result)
- result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
- evaluate_result["time"] = end - start
- print(f"Dumping eval results to {result_file}.")
- with open(result_file, "w") as f:
- json.dump(evaluate_result, f)
-
-
-cs = hydra.core.config_store.ConfigStore.instance()
-cs.store(name="default_config_eval", node=DefaultConfig)
-
-
-@hydra.main(config_path="./configs/", config_name="default_config_eval")
-def evaluate(cfg: DefaultConfig) -> None:
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
- run_eval(cfg)
-
-
-if __name__ == "__main__":
- evaluate()
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import json
+import os
+from dataclasses import dataclass, field
+
+import hydra
+import numpy as np
+
+import torch
+from omegaconf import OmegaConf
+
+from cotracker.datasets.tap_vid_datasets import TapVidDataset
+from cotracker.datasets.dr_dataset import DynamicReplicaDataset
+from cotracker.datasets.utils import collate_fn
+
+from cotracker.models.evaluation_predictor import EvaluationPredictor
+
+from cotracker.evaluation.core.evaluator import Evaluator
+from cotracker.models.build_cotracker import (
+ build_cotracker,
+)
+
+
+@dataclass(eq=False)
+class DefaultConfig:
+ # Directory where all outputs of the experiment will be saved.
+ exp_dir: str = "./outputs"
+
+ # Name of the dataset to be used for the evaluation.
+ dataset_name: str = "tapvid_davis_first"
+ # The root directory of the dataset.
+ dataset_root: str = "./"
+
+ # Path to the pre-trained model checkpoint to be used for the evaluation.
+ # The default value is the path to a specific CoTracker model checkpoint.
+ checkpoint: str = "./checkpoints/cotracker2.pth"
+
+ # EvaluationPredictor parameters
+ # The size (N) of the support grid used in the predictor.
+ # The total number of points is (N*N).
+ grid_size: int = 5
+ # The size (N) of the local support grid.
+ local_grid_size: int = 8
+ # A flag indicating whether to evaluate one ground truth point at a time.
+ single_point: bool = True
+ # The number of iterative updates for each sliding window.
+ n_iters: int = 6
+
+ seed: int = 0
+ gpu_idx: int = 0
+
+ # Override hydra's working directory to current working dir,
+ # also disable storing the .hydra logs:
+ hydra: dict = field(
+ default_factory=lambda: {
+ "run": {"dir": "."},
+ "output_subdir": None,
+ }
+ )
+
+
+def run_eval(cfg: DefaultConfig):
+ """
+ The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.
+
+ Args:
+ cfg (DefaultConfig): An instance of DefaultConfig class which includes:
+ - exp_dir (str): The directory path for the experiment.
+ - dataset_name (str): The name of the dataset to be used.
+ - dataset_root (str): The root directory of the dataset.
+ - checkpoint (str): The path to the CoTracker model's checkpoint.
+ - single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
+ - n_iters (int): The number of iterative updates for each sliding window.
+ - seed (int): The seed for setting the random state for reproducibility.
+ - gpu_idx (int): The index of the GPU to be used.
+ """
+ # Creating the experiment directory if it doesn't exist
+ os.makedirs(cfg.exp_dir, exist_ok=True)
+
+ # Saving the experiment configuration to a .yaml file in the experiment directory
+ cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
+ with open(cfg_file, "w") as f:
+ OmegaConf.save(config=cfg, f=f)
+
+ evaluator = Evaluator(cfg.exp_dir)
+ cotracker_model = build_cotracker(cfg.checkpoint)
+
+ # Creating the EvaluationPredictor object
+ predictor = EvaluationPredictor(
+ cotracker_model,
+ grid_size=cfg.grid_size,
+ local_grid_size=cfg.local_grid_size,
+ single_point=cfg.single_point,
+ n_iters=cfg.n_iters,
+ )
+ if torch.cuda.is_available():
+ predictor.model = predictor.model.cuda()
+
+ # Setting the random seeds
+ torch.manual_seed(cfg.seed)
+ np.random.seed(cfg.seed)
+
+ # Constructing the specified dataset
+ curr_collate_fn = collate_fn
+ if "tapvid" in cfg.dataset_name:
+ dataset_type = cfg.dataset_name.split("_")[1]
+ if dataset_type == "davis":
+ data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl")
+ elif dataset_type == "kinetics":
+ data_root = os.path.join(
+ cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
+ )
+ test_dataset = TapVidDataset(
+ dataset_type=dataset_type,
+ data_root=data_root,
+        queried_first="strided" not in cfg.dataset_name,
+ )
+ elif cfg.dataset_name == "dynamic_replica":
+ test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1)
+
+ # Creating the DataLoader object
+ test_dataloader = torch.utils.data.DataLoader(
+ test_dataset,
+ batch_size=1,
+ shuffle=False,
+ num_workers=14,
+ collate_fn=curr_collate_fn,
+ )
+
+ # Timing and conducting the evaluation
+ import time
+
+ start = time.time()
+ evaluate_result = evaluator.evaluate_sequence(
+ predictor,
+ test_dataloader,
+ dataset_name=cfg.dataset_name,
+ )
+ end = time.time()
+ print(end - start)
+
+ # Saving the evaluation results to a .json file
+ evaluate_result = evaluate_result["avg"]
+ print("evaluate_result", evaluate_result)
+    result_file = os.path.join(cfg.exp_dir, "result_eval_.json")
+ evaluate_result["time"] = end - start
+ print(f"Dumping eval results to {result_file}.")
+ with open(result_file, "w") as f:
+ json.dump(evaluate_result, f)
+
+
+cs = hydra.core.config_store.ConfigStore.instance()
+cs.store(name="default_config_eval", node=DefaultConfig)
+
+
+@hydra.main(config_path="./configs/", config_name="default_config_eval")
+def evaluate(cfg: DefaultConfig) -> None:
+ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+ os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
+ run_eval(cfg)
+
+
+if __name__ == "__main__":
+ evaluate()
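
For reference, run_eval can also be invoked programmatically through an OmegaConf-structured copy of DefaultConfig; the paths below are placeholders, and the hydra entry point evaluate() remains the intended command-line route.

from omegaconf import OmegaConf

from cotracker.evaluation.evaluate import DefaultConfig, run_eval

cfg = OmegaConf.structured(DefaultConfig)
cfg.exp_dir = "./outputs/tapvid_davis_first"      # placeholder output dir
cfg.dataset_name = "tapvid_davis_first"
cfg.dataset_root = "./datasets"                   # placeholder; must contain tapvid_davis/tapvid_davis.pkl
cfg.checkpoint = "./checkpoints/cotracker2.pth"   # placeholder checkpoint
run_eval(cfg)                                     # writes result_eval_.json into exp_dir
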
diff --git a/cotracker/models/__init__.py b/cotracker/models/__init__.py
index 5277f46..4547e07 100644
--- a/cotracker/models/__init__.py
+++ b/cotracker/models/__init__.py
@@ -1,5 +1,5 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/cotracker/models/__pycache__/__init__.cpython-38.pyc b/cotracker/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..8c571aa
Binary files /dev/null and b/cotracker/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cotracker/models/__pycache__/__init__.cpython-39.pyc b/cotracker/models/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..5a4d55e
Binary files /dev/null and b/cotracker/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/cotracker/models/__pycache__/build_cotracker.cpython-38.pyc b/cotracker/models/__pycache__/build_cotracker.cpython-38.pyc
new file mode 100644
index 0000000..670a0d2
Binary files /dev/null and b/cotracker/models/__pycache__/build_cotracker.cpython-38.pyc differ
diff --git a/cotracker/models/__pycache__/build_cotracker.cpython-39.pyc b/cotracker/models/__pycache__/build_cotracker.cpython-39.pyc
new file mode 100644
index 0000000..d332f8c
Binary files /dev/null and b/cotracker/models/__pycache__/build_cotracker.cpython-39.pyc differ
diff --git a/cotracker/models/build_cotracker.py b/cotracker/models/build_cotracker.py
index 1ae5f90..1448670 100644
--- a/cotracker/models/build_cotracker.py
+++ b/cotracker/models/build_cotracker.py
@@ -1,33 +1,33 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-
-from cotracker.models.core.cotracker.cotracker import CoTracker2
-
-
-def build_cotracker(
- checkpoint: str,
-):
- if checkpoint is None:
- return build_cotracker()
- model_name = checkpoint.split("/")[-1].split(".")[0]
- if model_name == "cotracker":
- return build_cotracker(checkpoint=checkpoint)
- else:
- raise ValueError(f"Unknown model name {model_name}")
-
-
-def build_cotracker(checkpoint=None):
- cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True)
-
- if checkpoint is not None:
- with open(checkpoint, "rb") as f:
- state_dict = torch.load(f, map_location="cpu")
- if "model" in state_dict:
- state_dict = state_dict["model"]
- cotracker.load_state_dict(state_dict)
- return cotracker
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+
+from cotracker.models.core.cotracker.cotracker import CoTracker2
+
+
+def build_cotracker(
+ checkpoint: str,
+):
+ if checkpoint is None:
+ return build_cotracker()
+ model_name = checkpoint.split("/")[-1].split(".")[0]
+ if model_name == "cotracker":
+ return build_cotracker(checkpoint=checkpoint)
+ else:
+ raise ValueError(f"Unknown model name {model_name}")
+
+
+def build_cotracker(checkpoint=None):
+ cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True)
+
+ if checkpoint is not None:
+ with open(checkpoint, "rb") as f:
+ state_dict = torch.load(f, map_location="cpu")
+ if "model" in state_dict:
+ state_dict = state_dict["model"]
+ cotracker.load_state_dict(state_dict)
+ return cotracker
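
Note that the second build_cotracker definition shadows the first, so the effective entry point is the checkpoint=None variant, which always constructs CoTracker2 and loads the weights when a path is given. A minimal sketch of using it (the checkpoint path is a placeholder; passing None returns a randomly initialized model):

from cotracker.models.build_cotracker import build_cotracker

model = build_cotracker("./checkpoints/cotracker2.pth")  # placeholder path
model = model.eval()
print(sum(p.numel() for p in model.parameters()))        # total parameter count
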
diff --git a/cotracker/models/core/__init__.py b/cotracker/models/core/__init__.py
index 5277f46..4547e07 100644
--- a/cotracker/models/core/__init__.py
+++ b/cotracker/models/core/__init__.py
@@ -1,5 +1,5 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/cotracker/models/core/__pycache__/__init__.cpython-38.pyc b/cotracker/models/core/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..7247d61
Binary files /dev/null and b/cotracker/models/core/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cotracker/models/core/__pycache__/__init__.cpython-39.pyc b/cotracker/models/core/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..54e6ba5
Binary files /dev/null and b/cotracker/models/core/__pycache__/__init__.cpython-39.pyc differ
diff --git a/cotracker/models/core/__pycache__/embeddings.cpython-38.pyc b/cotracker/models/core/__pycache__/embeddings.cpython-38.pyc
new file mode 100644
index 0000000..d149482
Binary files /dev/null and b/cotracker/models/core/__pycache__/embeddings.cpython-38.pyc differ
diff --git a/cotracker/models/core/__pycache__/embeddings.cpython-39.pyc b/cotracker/models/core/__pycache__/embeddings.cpython-39.pyc
new file mode 100644
index 0000000..abc6341
Binary files /dev/null and b/cotracker/models/core/__pycache__/embeddings.cpython-39.pyc differ
diff --git a/cotracker/models/core/__pycache__/model_utils.cpython-38.pyc b/cotracker/models/core/__pycache__/model_utils.cpython-38.pyc
new file mode 100644
index 0000000..1ef2dbd
Binary files /dev/null and b/cotracker/models/core/__pycache__/model_utils.cpython-38.pyc differ
diff --git a/cotracker/models/core/__pycache__/model_utils.cpython-39.pyc b/cotracker/models/core/__pycache__/model_utils.cpython-39.pyc
new file mode 100644
index 0000000..2ee9624
Binary files /dev/null and b/cotracker/models/core/__pycache__/model_utils.cpython-39.pyc differ
diff --git a/cotracker/models/core/cotracker/__init__.py b/cotracker/models/core/cotracker/__init__.py
index 5277f46..4547e07 100644
--- a/cotracker/models/core/cotracker/__init__.py
+++ b/cotracker/models/core/cotracker/__init__.py
@@ -1,5 +1,5 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/cotracker/models/core/cotracker/__pycache__/__init__.cpython-38.pyc b/cotracker/models/core/cotracker/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..e970a17
Binary files /dev/null and b/cotracker/models/core/cotracker/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cotracker/models/core/cotracker/__pycache__/__init__.cpython-39.pyc b/cotracker/models/core/cotracker/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..b02b700
Binary files /dev/null and b/cotracker/models/core/cotracker/__pycache__/__init__.cpython-39.pyc differ
diff --git a/cotracker/models/core/cotracker/__pycache__/blocks.cpython-38.pyc b/cotracker/models/core/cotracker/__pycache__/blocks.cpython-38.pyc
new file mode 100644
index 0000000..73f41bd
Binary files /dev/null and b/cotracker/models/core/cotracker/__pycache__/blocks.cpython-38.pyc differ
diff --git a/cotracker/models/core/cotracker/__pycache__/blocks.cpython-39.pyc b/cotracker/models/core/cotracker/__pycache__/blocks.cpython-39.pyc
new file mode 100644
index 0000000..c80e782
Binary files /dev/null and b/cotracker/models/core/cotracker/__pycache__/blocks.cpython-39.pyc differ
diff --git a/cotracker/models/core/cotracker/__pycache__/cotracker.cpython-38.pyc b/cotracker/models/core/cotracker/__pycache__/cotracker.cpython-38.pyc
new file mode 100644
index 0000000..be29aa6
Binary files /dev/null and b/cotracker/models/core/cotracker/__pycache__/cotracker.cpython-38.pyc differ
diff --git a/cotracker/models/core/cotracker/__pycache__/cotracker.cpython-39.pyc b/cotracker/models/core/cotracker/__pycache__/cotracker.cpython-39.pyc
new file mode 100644
index 0000000..1c77573
Binary files /dev/null and b/cotracker/models/core/cotracker/__pycache__/cotracker.cpython-39.pyc differ
diff --git a/cotracker/models/core/cotracker/blocks.py b/cotracker/models/core/cotracker/blocks.py
index 8d61b25..0d4234c 100644
--- a/cotracker/models/core/cotracker/blocks.py
+++ b/cotracker/models/core/cotracker/blocks.py
@@ -1,367 +1,368 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from functools import partial
-from typing import Callable
-import collections
-from torch import Tensor
-from itertools import repeat
-
-from cotracker.models.core.model_utils import bilinear_sampler
-
-
-# From PyTorch internals
-def _ntuple(n):
- def parse(x):
- if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
- return tuple(x)
- return tuple(repeat(x, n))
-
- return parse
-
-
-def exists(val):
- return val is not None
-
-
-def default(val, d):
- return val if exists(val) else d
-
-
-to_2tuple = _ntuple(2)
-
-
-class Mlp(nn.Module):
- """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
-
- def __init__(
- self,
- in_features,
- hidden_features=None,
- out_features=None,
- act_layer=nn.GELU,
- norm_layer=None,
- bias=True,
- drop=0.0,
- use_conv=False,
- ):
- super().__init__()
- out_features = out_features or in_features
- hidden_features = hidden_features or in_features
- bias = to_2tuple(bias)
- drop_probs = to_2tuple(drop)
- linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
-
- self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
- self.act = act_layer()
- self.drop1 = nn.Dropout(drop_probs[0])
- self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
- self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
- self.drop2 = nn.Dropout(drop_probs[1])
-
- def forward(self, x):
- x = self.fc1(x)
- x = self.act(x)
- x = self.drop1(x)
- x = self.fc2(x)
- x = self.drop2(x)
- return x
-
-
-class ResidualBlock(nn.Module):
- def __init__(self, in_planes, planes, norm_fn="group", stride=1):
- super(ResidualBlock, self).__init__()
-
- self.conv1 = nn.Conv2d(
- in_planes,
- planes,
- kernel_size=3,
- padding=1,
- stride=stride,
- padding_mode="zeros",
- )
- self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
- self.relu = nn.ReLU(inplace=True)
-
- num_groups = planes // 8
-
- if norm_fn == "group":
- self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
- self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
- if not stride == 1:
- self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
-
- elif norm_fn == "batch":
- self.norm1 = nn.BatchNorm2d(planes)
- self.norm2 = nn.BatchNorm2d(planes)
- if not stride == 1:
- self.norm3 = nn.BatchNorm2d(planes)
-
- elif norm_fn == "instance":
- self.norm1 = nn.InstanceNorm2d(planes)
- self.norm2 = nn.InstanceNorm2d(planes)
- if not stride == 1:
- self.norm3 = nn.InstanceNorm2d(planes)
-
- elif norm_fn == "none":
- self.norm1 = nn.Sequential()
- self.norm2 = nn.Sequential()
- if not stride == 1:
- self.norm3 = nn.Sequential()
-
- if stride == 1:
- self.downsample = None
-
- else:
- self.downsample = nn.Sequential(
- nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
- )
-
- def forward(self, x):
- y = x
- y = self.relu(self.norm1(self.conv1(y)))
- y = self.relu(self.norm2(self.conv2(y)))
-
- if self.downsample is not None:
- x = self.downsample(x)
-
- return self.relu(x + y)
-
-
-class BasicEncoder(nn.Module):
- def __init__(self, input_dim=3, output_dim=128, stride=4):
- super(BasicEncoder, self).__init__()
- self.stride = stride
- self.norm_fn = "instance"
- self.in_planes = output_dim // 2
-
- self.norm1 = nn.InstanceNorm2d(self.in_planes)
- self.norm2 = nn.InstanceNorm2d(output_dim * 2)
-
- self.conv1 = nn.Conv2d(
- input_dim,
- self.in_planes,
- kernel_size=7,
- stride=2,
- padding=3,
- padding_mode="zeros",
- )
- self.relu1 = nn.ReLU(inplace=True)
- self.layer1 = self._make_layer(output_dim // 2, stride=1)
- self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
- self.layer3 = self._make_layer(output_dim, stride=2)
- self.layer4 = self._make_layer(output_dim, stride=2)
-
- self.conv2 = nn.Conv2d(
- output_dim * 3 + output_dim // 4,
- output_dim * 2,
- kernel_size=3,
- padding=1,
- padding_mode="zeros",
- )
- self.relu2 = nn.ReLU(inplace=True)
- self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
- for m in self.modules():
- if isinstance(m, nn.Conv2d):
- nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
- elif isinstance(m, (nn.InstanceNorm2d)):
- if m.weight is not None:
- nn.init.constant_(m.weight, 1)
- if m.bias is not None:
- nn.init.constant_(m.bias, 0)
-
- def _make_layer(self, dim, stride=1):
- layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
- layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
- layers = (layer1, layer2)
-
- self.in_planes = dim
- return nn.Sequential(*layers)
-
- def forward(self, x):
- _, _, H, W = x.shape
-
- x = self.conv1(x)
- x = self.norm1(x)
- x = self.relu1(x)
-
- a = self.layer1(x)
- b = self.layer2(a)
- c = self.layer3(b)
- d = self.layer4(c)
-
- def _bilinear_intepolate(x):
- return F.interpolate(
- x,
- (H // self.stride, W // self.stride),
- mode="bilinear",
- align_corners=True,
- )
-
- a = _bilinear_intepolate(a)
- b = _bilinear_intepolate(b)
- c = _bilinear_intepolate(c)
- d = _bilinear_intepolate(d)
-
- x = self.conv2(torch.cat([a, b, c, d], dim=1))
- x = self.norm2(x)
- x = self.relu2(x)
- x = self.conv3(x)
- return x
-
-
-class CorrBlock:
- def __init__(
- self,
- fmaps,
- num_levels=4,
- radius=4,
- multiple_track_feats=False,
- padding_mode="zeros",
- ):
- B, S, C, H, W = fmaps.shape
- self.S, self.C, self.H, self.W = S, C, H, W
- self.padding_mode = padding_mode
- self.num_levels = num_levels
- self.radius = radius
- self.fmaps_pyramid = []
- self.multiple_track_feats = multiple_track_feats
-
- self.fmaps_pyramid.append(fmaps)
- for i in range(self.num_levels - 1):
- fmaps_ = fmaps.reshape(B * S, C, H, W)
- fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
- _, _, H, W = fmaps_.shape
- fmaps = fmaps_.reshape(B, S, C, H, W)
- self.fmaps_pyramid.append(fmaps)
-
- def sample(self, coords):
- r = self.radius
- B, S, N, D = coords.shape
- assert D == 2
-
- H, W = self.H, self.W
- out_pyramid = []
- for i in range(self.num_levels):
- corrs = self.corrs_pyramid[i] # B, S, N, H, W
- *_, H, W = corrs.shape
-
- dx = torch.linspace(-r, r, 2 * r + 1)
- dy = torch.linspace(-r, r, 2 * r + 1)
- delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
-
- centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
- delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
- coords_lvl = centroid_lvl + delta_lvl
-
- corrs = bilinear_sampler(
- corrs.reshape(B * S * N, 1, H, W),
- coords_lvl,
- padding_mode=self.padding_mode,
- )
- corrs = corrs.view(B, S, N, -1)
- out_pyramid.append(corrs)
-
- out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
- out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
- return out
-
- def corr(self, targets):
- B, S, N, C = targets.shape
- if self.multiple_track_feats:
- targets_split = targets.split(C // self.num_levels, dim=-1)
- B, S, N, C = targets_split[0].shape
-
- assert C == self.C
- assert S == self.S
-
- fmap1 = targets
-
- self.corrs_pyramid = []
- for i, fmaps in enumerate(self.fmaps_pyramid):
- *_, H, W = fmaps.shape
- fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
- if self.multiple_track_feats:
- fmap1 = targets_split[i]
- corrs = torch.matmul(fmap1, fmap2s)
- corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
- corrs = corrs / torch.sqrt(torch.tensor(C).float())
- self.corrs_pyramid.append(corrs)
-
-
-class Attention(nn.Module):
- def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
- super().__init__()
- inner_dim = dim_head * num_heads
- context_dim = default(context_dim, query_dim)
- self.scale = dim_head**-0.5
- self.heads = num_heads
-
- self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
- self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
- self.to_out = nn.Linear(inner_dim, query_dim)
-
- def forward(self, x, context=None, attn_bias=None):
- B, N1, C = x.shape
- h = self.heads
-
- q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
- context = default(context, x)
- k, v = self.to_kv(context).chunk(2, dim=-1)
-
- N2 = context.shape[1]
- k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
- v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
-
- sim = (q @ k.transpose(-2, -1)) * self.scale
-
- if attn_bias is not None:
- sim = sim + attn_bias
- attn = sim.softmax(dim=-1)
-
- x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
- return self.to_out(x)
-
-
-class AttnBlock(nn.Module):
- def __init__(
- self,
- hidden_size,
- num_heads,
- attn_class: Callable[..., nn.Module] = Attention,
- mlp_ratio=4.0,
- **block_kwargs
- ):
- super().__init__()
- self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
-
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- approx_gelu = lambda: nn.GELU(approximate="tanh")
- self.mlp = Mlp(
- in_features=hidden_size,
- hidden_features=mlp_hidden_dim,
- act_layer=approx_gelu,
- drop=0,
- )
-
- def forward(self, x, mask=None):
- attn_bias = mask
- if mask is not None:
- mask = (
- (mask[:, None] * mask[:, :, None])
- .unsqueeze(1)
- .expand(-1, self.attn.num_heads, -1, -1)
- )
- max_neg_value = -torch.finfo(x.dtype).max
- attn_bias = (~mask) * max_neg_value
- x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
- x = x + self.mlp(self.norm2(x))
- return x
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import partial
+from typing import Callable
+import collections
+from torch import Tensor
+from itertools import repeat
+
+from cotracker.models.core.model_utils import bilinear_sampler
+
+
+# From PyTorch internals
+def _ntuple(n):
+ def parse(x):
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+ return tuple(x)
+ return tuple(repeat(x, n))
+
+ return parse
+
+
+def exists(val):
+ return val is not None
+
+
+def default(val, d):
+ return val if exists(val) else d
+
+
+to_2tuple = _ntuple(2)
+
+
+class Mlp(nn.Module):
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ norm_layer=None,
+ bias=True,
+ drop=0.0,
+ use_conv=False,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ bias = to_2tuple(bias)
+ drop_probs = to_2tuple(drop)
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
+
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
+ self.act = act_layer()
+ self.drop1 = nn.Dropout(drop_probs[0])
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
+ self.drop2 = nn.Dropout(drop_probs[1])
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop1(x)
+ x = self.fc2(x)
+ x = self.drop2(x)
+ return x
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(
+ in_planes,
+ planes,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ padding_mode="zeros",
+ )
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == "group":
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == "batch":
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == "instance":
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == "none":
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
+ )
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x + y)
+
+
+class BasicEncoder(nn.Module):
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
+ super(BasicEncoder, self).__init__()
+ self.stride = stride
+ self.norm_fn = "instance"
+ self.in_planes = output_dim // 2
+
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
+
+ self.conv1 = nn.Conv2d(
+ input_dim,
+ self.in_planes,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ padding_mode="zeros",
+ )
+ self.relu1 = nn.ReLU(inplace=True)
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
+ self.layer3 = self._make_layer(output_dim, stride=2)
+ self.layer4 = self._make_layer(output_dim, stride=2)
+
+ self.conv2 = nn.Conv2d(
+ output_dim * 3 + output_dim // 4,
+ output_dim * 2,
+ kernel_size=3,
+ padding=1,
+ padding_mode="zeros",
+ )
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+ elif isinstance(m, (nn.InstanceNorm2d)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ _, _, H, W = x.shape
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ # four stacked residual stages
+ a = self.layer1(x)
+ b = self.layer2(a)
+ c = self.layer3(b)
+ d = self.layer4(c)
+
+ def _bilinear_intepolate(x):
+ return F.interpolate(
+ x,
+ (H // self.stride, W // self.stride),
+ mode="bilinear",
+ align_corners=True,
+ )
+
+ a = _bilinear_intepolate(a)
+ b = _bilinear_intepolate(b)
+ c = _bilinear_intepolate(c)
+ d = _bilinear_intepolate(d)
+
+ x = self.conv2(torch.cat([a, b, c, d], dim=1))
+ x = self.norm2(x)
+ x = self.relu2(x)
+ x = self.conv3(x)
+ return x
+
+
+class CorrBlock:
+ def __init__(
+ self,
+ fmaps,
+ num_levels=4,
+ radius=4,
+ multiple_track_feats=False,
+ padding_mode="zeros",
+ ):
+ B, S, C, H, W = fmaps.shape
+ self.S, self.C, self.H, self.W = S, C, H, W
+ self.padding_mode = padding_mode
+ self.num_levels = num_levels
+ self.radius = radius
+ self.fmaps_pyramid = []
+ self.multiple_track_feats = multiple_track_feats
+
+ self.fmaps_pyramid.append(fmaps)
+ for i in range(self.num_levels - 1):
+ fmaps_ = fmaps.reshape(B * S, C, H, W)
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
+ _, _, H, W = fmaps_.shape
+ fmaps = fmaps_.reshape(B, S, C, H, W)
+ self.fmaps_pyramid.append(fmaps)
+
+ def sample(self, coords):
+ r = self.radius
+ B, S, N, D = coords.shape
+ assert D == 2
+
+ H, W = self.H, self.W
+ out_pyramid = []
+ for i in range(self.num_levels):
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
+ *_, H, W = corrs.shape
+
+ dx = torch.linspace(-r, r, 2 * r + 1)
+ dy = torch.linspace(-r, r, 2 * r + 1)
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)
+
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
+ coords_lvl = centroid_lvl + delta_lvl
+
+ corrs = bilinear_sampler(
+ corrs.reshape(B * S * N, 1, H, W),
+ coords_lvl,
+ padding_mode=self.padding_mode,
+ )
+ corrs = corrs.view(B, S, N, -1)
+ out_pyramid.append(corrs)
+
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
+ out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
+ return out
+
+ def corr(self, targets):
+ B, S, N, C = targets.shape
+ if self.multiple_track_feats:
+ targets_split = targets.split(C // self.num_levels, dim=-1)
+ B, S, N, C = targets_split[0].shape
+
+ assert C == self.C
+ assert S == self.S
+
+ fmap1 = targets
+
+ self.corrs_pyramid = []
+ for i, fmaps in enumerate(self.fmaps_pyramid):
+ *_, H, W = fmaps.shape
+ fmap2s = fmaps.view(B, S, C, H * W) # B S C H W -> B S C (H W)
+ if self.multiple_track_feats:
+ fmap1 = targets_split[i]
+ corrs = torch.matmul(fmap1, fmap2s)
+ corrs = corrs.view(B, S, N, H, W) # B S N (H W) -> B S N H W
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
+ self.corrs_pyramid.append(corrs)
+
+
+class Attention(nn.Module):
+ def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
+ super().__init__()
+ inner_dim = dim_head * num_heads
+ context_dim = default(context_dim, query_dim)
+ self.scale = dim_head**-0.5
+ self.heads = num_heads
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
+ self.to_out = nn.Linear(inner_dim, query_dim)
+
+ def forward(self, x, context=None, attn_bias=None):
+ B, N1, C = x.shape
+ h = self.heads
+
+ q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
+ context = default(context, x)
+ k, v = self.to_kv(context).chunk(2, dim=-1)
+
+ N2 = context.shape[1]
+ k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
+ v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
+
+ sim = (q @ k.transpose(-2, -1)) * self.scale
+
+ if attn_bias is not None:
+ sim = sim + attn_bias
+ attn = sim.softmax(dim=-1)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
+ return self.to_out(x)
+
+
+class AttnBlock(nn.Module):
+ def __init__(
+ self,
+ hidden_size,
+ num_heads,
+ attn_class: Callable[..., nn.Module] = Attention,
+ mlp_ratio=4.0,
+ **block_kwargs
+ ):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+
+ def forward(self, x, mask=None):
+ attn_bias = mask
+ if mask is not None:
+ mask = (
+ (mask[:, None] * mask[:, :, None])
+ .unsqueeze(1)
+ .expand(-1, self.attn.num_heads, -1, -1)
+ )
+ max_neg_value = -torch.finfo(x.dtype).max
+ attn_bias = (~mask) * max_neg_value
+ x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
+ x = x + self.mlp(self.norm2(x))
+ return x
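As a quick reference for driving the CorrBlock defined above, a toy-sized sketch (outside the patch; every tensor size below is made up): corr() must build the correlation pyramid before sample() can read local cost patches around each track.

import torch
from cotracker.models.core.cotracker.blocks import CorrBlock

B, S, C, H, W, N = 1, 8, 128, 96, 128, 16        # toy sizes, not taken from the repo
fmaps = torch.randn(B, S, C, H, W)                # per-frame feature maps (from BasicEncoder in practice)
track_feat = torch.randn(B, S, N, C)              # one feature vector per track and frame
coords = torch.rand(B, S, N, 2) * torch.tensor([W - 1.0, H - 1.0])  # positions in feature-map units

corr = CorrBlock(fmaps, num_levels=4, radius=3)
corr.corr(track_feat)                             # fills corr.corrs_pyramid
fcorrs = corr.sample(coords)                      # (B*N, S, num_levels * (2*radius + 1)**2)
print(fcorrs.shape)                               # torch.Size([16, 8, 196])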
diff --git a/cotracker/models/core/cotracker/cotracker.py b/cotracker/models/core/cotracker/cotracker.py
index 53178fb..41422ca 100644
--- a/cotracker/models/core/cotracker/cotracker.py
+++ b/cotracker/models/core/cotracker/cotracker.py
@@ -1,503 +1,519 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from cotracker.models.core.model_utils import sample_features4d, sample_features5d
-from cotracker.models.core.embeddings import (
- get_2d_embedding,
- get_1d_sincos_pos_embed_from_grid,
- get_2d_sincos_pos_embed,
-)
-
-from cotracker.models.core.cotracker.blocks import (
- Mlp,
- BasicEncoder,
- AttnBlock,
- CorrBlock,
- Attention,
-)
-
-torch.manual_seed(0)
-
-
-class CoTracker2(nn.Module):
- def __init__(
- self,
- window_len=8,
- stride=4,
- add_space_attn=True,
- num_virtual_tracks=64,
- model_resolution=(384, 512),
- ):
- super(CoTracker2, self).__init__()
- self.window_len = window_len
- self.stride = stride
- self.hidden_dim = 256
- self.latent_dim = 128
- self.add_space_attn = add_space_attn
- self.fnet = BasicEncoder(output_dim=self.latent_dim)
- self.num_virtual_tracks = num_virtual_tracks
- self.model_resolution = model_resolution
- self.input_dim = 456
- self.updateformer = EfficientUpdateFormer(
- space_depth=6,
- time_depth=6,
- input_dim=self.input_dim,
- hidden_size=384,
- output_dim=self.latent_dim + 2,
- mlp_ratio=4.0,
- add_space_attn=add_space_attn,
- num_virtual_tracks=num_virtual_tracks,
- )
-
- time_grid = torch.linspace(0, window_len - 1, window_len).reshape(1, window_len, 1)
-
- self.register_buffer(
- "time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
- )
-
- self.register_buffer(
- "pos_emb",
- get_2d_sincos_pos_embed(
- embed_dim=self.input_dim,
- grid_size=(
- model_resolution[0] // stride,
- model_resolution[1] // stride,
- ),
- ),
- )
- self.norm = nn.GroupNorm(1, self.latent_dim)
- self.track_feat_updater = nn.Sequential(
- nn.Linear(self.latent_dim, self.latent_dim),
- nn.GELU(),
- )
- self.vis_predictor = nn.Sequential(
- nn.Linear(self.latent_dim, 1),
- )
-
- def forward_window(
- self,
- fmaps,
- coords,
- track_feat=None,
- vis=None,
- track_mask=None,
- attention_mask=None,
- iters=4,
- ):
- # B = batch size
- # S = number of frames in the window)
- # N = number of tracks
- # C = channels of a point feature vector
- # E = positional embedding size
- # LRR = local receptive field radius
- # D = dimension of the transformer input tokens
-
- # track_feat = B S N C
- # vis = B S N 1
- # track_mask = B S N 1
- # attention_mask = B S N
-
- B, S_init, N, __ = track_mask.shape
- B, S, *_ = fmaps.shape
-
- track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
- track_mask_vis = (
- torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2)
- )
-
- corr_block = CorrBlock(
- fmaps,
- num_levels=4,
- radius=3,
- padding_mode="border",
- )
-
- sampled_pos_emb = (
- sample_features4d(self.pos_emb.repeat(B, 1, 1, 1), coords[:, 0])
- .reshape(B * N, self.input_dim)
- .unsqueeze(1)
- ) # B E N -> (B N) 1 E
-
- coord_preds = []
- for __ in range(iters):
- coords = coords.detach() # B S N 2
- corr_block.corr(track_feat)
-
- # Sample correlation features around each point
- fcorrs = corr_block.sample(coords) # (B N) S LRR
-
- # Get the flow embeddings
- flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
- flow_emb = get_2d_embedding(flows, 64, cat_coords=True) # N S E
-
- track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
-
- transformer_input = torch.cat([flow_emb, fcorrs, track_feat_, track_mask_vis], dim=2)
- x = transformer_input + sampled_pos_emb + self.time_emb
- x = x.view(B, N, S, -1) # (B N) S D -> B N S D
-
- delta = self.updateformer(
- x,
- attention_mask.reshape(B * S, N), # B S N -> (B S) N
- )
-
- delta_coords = delta[..., :2].permute(0, 2, 1, 3)
- coords = coords + delta_coords
- coord_preds.append(coords * self.stride)
-
- delta_feats_ = delta[..., 2:].reshape(B * N * S, self.latent_dim)
- track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
- track_feat_ = self.track_feat_updater(self.norm(delta_feats_)) + track_feat_
- track_feat = track_feat_.reshape(B, N, S, self.latent_dim).permute(
- 0, 2, 1, 3
- ) # (B N S) C -> B S N C
-
- vis_pred = self.vis_predictor(track_feat).reshape(B, S, N)
- return coord_preds, vis_pred
-
- def get_track_feat(self, fmaps, queried_frames, queried_coords):
- sample_frames = queried_frames[:, None, :, None]
- sample_coords = torch.cat(
- [
- sample_frames,
- queried_coords[:, None],
- ],
- dim=-1,
- )
- sample_track_feats = sample_features5d(fmaps, sample_coords)
- return sample_track_feats
-
- def init_video_online_processing(self):
- self.online_ind = 0
- self.online_track_feat = None
- self.online_coords_predicted = None
- self.online_vis_predicted = None
-
- def forward(self, video, queries, iters=4, is_train=False, is_online=False):
- """Predict tracks
-
- Args:
- video (FloatTensor[B, T, 3]): input videos.
- queries (FloatTensor[B, N, 3]): point queries.
- iters (int, optional): number of updates. Defaults to 4.
- is_train (bool, optional): enables training mode. Defaults to False.
- is_online (bool, optional): enables online mode. Defaults to False. Before enabling, call model.init_video_online_processing().
-
- Returns:
- - coords_predicted (FloatTensor[B, T, N, 2]):
- - vis_predicted (FloatTensor[B, T, N]):
- - train_data: `None` if `is_train` is false, otherwise:
- - all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
- - all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
- - mask (BoolTensor[B, T, N]):
- """
- B, T, C, H, W = video.shape
- B, N, __ = queries.shape
- S = self.window_len
- device = queries.device
-
- # B = batch size
- # S = number of frames in the window of the padded video
- # S_trimmed = actual number of frames in the window
- # N = number of tracks
- # C = color channels (3 for RGB)
- # E = positional embedding size
- # LRR = local receptive field radius
- # D = dimension of the transformer input tokens
-
- # video = B T C H W
- # queries = B N 3
- # coords_init = B S N 2
- # vis_init = B S N 1
-
- assert S >= 2 # A tracker needs at least two frames to track something
- if is_online:
- assert T <= S, "Online mode: video chunk must be <= window size."
- assert self.online_ind is not None, "Call model.init_video_online_processing() first."
- assert not is_train, "Training not supported in online mode."
- step = S // 2 # How much the sliding window moves at every step
- video = 2 * (video / 255.0) - 1.0
-
- # The first channel is the frame number
- # The rest are the coordinates of points we want to track
- queried_frames = queries[:, :, 0].long()
-
- queried_coords = queries[..., 1:]
- queried_coords = queried_coords / self.stride
-
- # We store our predictions here
- coords_predicted = torch.zeros((B, T, N, 2), device=device)
- vis_predicted = torch.zeros((B, T, N), device=device)
- if is_online:
- if self.online_coords_predicted is None:
- # Init online predictions with zeros
- self.online_coords_predicted = coords_predicted
- self.online_vis_predicted = vis_predicted
- else:
- # Pad online predictions with zeros for the current window
- pad = min(step, T - step)
- coords_predicted = F.pad(
- self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
- )
- vis_predicted = F.pad(self.online_vis_predicted, (0, 0, 0, pad), "constant")
- all_coords_predictions, all_vis_predictions = [], []
-
- # Pad the video so that an integer number of sliding windows fit into it
- # TODO: we may drop this requirement because the transformer should not care
- # TODO: pad the features instead of the video
- pad = S - T if is_online else (S - T % S) % S # We don't want to pad if T % S == 0
- video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape(
- B, -1, C, H, W
- )
-
- # Compute convolutional features for the video or for the current chunk in case of online mode
- fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
- B, -1, self.latent_dim, H // self.stride, W // self.stride
- )
-
- # We compute track features
- track_feat = self.get_track_feat(
- fmaps,
- queried_frames - self.online_ind if is_online else queried_frames,
- queried_coords,
- ).repeat(1, S, 1, 1)
- if is_online:
- # We update track features for the current window
- sample_frames = queried_frames[:, None, :, None] # B 1 N 1
- left = 0 if self.online_ind == 0 else self.online_ind + step
- right = self.online_ind + S
- sample_mask = (sample_frames >= left) & (sample_frames < right)
- if self.online_track_feat is None:
- self.online_track_feat = torch.zeros_like(track_feat, device=device)
- self.online_track_feat += track_feat * sample_mask
- track_feat = self.online_track_feat.clone()
- # We process ((num_windows - 1) * step + S) frames in total, so there are
- # (ceil((T - S) / step) + 1) windows
- num_windows = (T - S + step - 1) // step + 1
- # We process only the current video chunk in the online mode
- indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
-
- coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
- vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
- for ind in indices:
- # We copy over coords and vis for tracks that are queried
- # by the end of the previous window, which is ind + overlap
- if ind > 0:
- overlap = S - step
- copy_over = (queried_frames < ind + overlap)[:, None, :, None] # B 1 N 1
- coords_prev = torch.nn.functional.pad(
- coords_predicted[:, ind : ind + overlap] / self.stride,
- (0, 0, 0, 0, 0, step),
- "replicate",
- ) # B S N 2
- vis_prev = torch.nn.functional.pad(
- vis_predicted[:, ind : ind + overlap, :, None].clone(),
- (0, 0, 0, 0, 0, step),
- "replicate",
- ) # B S N 1
- coords_init = torch.where(
- copy_over.expand_as(coords_init), coords_prev, coords_init
- )
- vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init)
-
- # The attention mask is 1 for the spatio-temporal points within
- # a track which is updated in the current window
- attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1) # B S N
-
- # The track mask is 1 for the spatio-temporal points that actually
- # need updating: only after begin queried, and not if contained
- # in a previous window
- track_mask = (
- queried_frames[:, None, :, None]
- <= torch.arange(ind, ind + S, device=device)[None, :, None, None]
- ).contiguous() # B S N 1
-
- if ind > 0:
- track_mask[:, :overlap, :, :] = False
-
- # Predict the coordinates and visibility for the current window
- coords, vis = self.forward_window(
- fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
- coords=coords_init,
- track_feat=attention_mask.unsqueeze(-1) * track_feat,
- vis=vis_init,
- track_mask=track_mask,
- attention_mask=attention_mask,
- iters=iters,
- )
-
- S_trimmed = T if is_online else min(T - ind, S) # accounts for last window duration
- coords_predicted[:, ind : ind + S] = coords[-1][:, :S_trimmed]
- vis_predicted[:, ind : ind + S] = vis[:, :S_trimmed]
- if is_train:
- all_coords_predictions.append([coord[:, :S_trimmed] for coord in coords])
- all_vis_predictions.append(torch.sigmoid(vis[:, :S_trimmed]))
-
- if is_online:
- self.online_ind += step
- self.online_coords_predicted = coords_predicted
- self.online_vis_predicted = vis_predicted
- vis_predicted = torch.sigmoid(vis_predicted)
-
- if is_train:
- mask = queried_frames[:, None] <= torch.arange(0, T, device=device)[None, :, None]
- train_data = (all_coords_predictions, all_vis_predictions, mask)
- else:
- train_data = None
-
- return coords_predicted, vis_predicted, train_data
-
-
-class EfficientUpdateFormer(nn.Module):
- """
- Transformer model that updates track estimates.
- """
-
- def __init__(
- self,
- space_depth=6,
- time_depth=6,
- input_dim=320,
- hidden_size=384,
- num_heads=8,
- output_dim=130,
- mlp_ratio=4.0,
- add_space_attn=True,
- num_virtual_tracks=64,
- ):
- super().__init__()
- self.out_channels = 2
- self.num_heads = num_heads
- self.hidden_size = hidden_size
- self.add_space_attn = add_space_attn
- self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
- self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
- self.num_virtual_tracks = num_virtual_tracks
- self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
- self.time_blocks = nn.ModuleList(
- [
- AttnBlock(
- hidden_size,
- num_heads,
- mlp_ratio=mlp_ratio,
- attn_class=Attention,
- )
- for _ in range(time_depth)
- ]
- )
-
- if add_space_attn:
- self.space_virtual_blocks = nn.ModuleList(
- [
- AttnBlock(
- hidden_size,
- num_heads,
- mlp_ratio=mlp_ratio,
- attn_class=Attention,
- )
- for _ in range(space_depth)
- ]
- )
- self.space_point2virtual_blocks = nn.ModuleList(
- [
- CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
- for _ in range(space_depth)
- ]
- )
- self.space_virtual2point_blocks = nn.ModuleList(
- [
- CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
- for _ in range(space_depth)
- ]
- )
- assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
- self.initialize_weights()
-
- def initialize_weights(self):
- def _basic_init(module):
- if isinstance(module, nn.Linear):
- torch.nn.init.xavier_uniform_(module.weight)
- if module.bias is not None:
- nn.init.constant_(module.bias, 0)
-
- self.apply(_basic_init)
-
- def forward(self, input_tensor, mask=None):
- tokens = self.input_transform(input_tensor)
- B, _, T, _ = tokens.shape
- virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
- tokens = torch.cat([tokens, virtual_tokens], dim=1)
- _, N, _, _ = tokens.shape
-
- j = 0
- for i in range(len(self.time_blocks)):
- time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
- time_tokens = self.time_blocks[i](time_tokens)
-
- tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
- if self.add_space_attn and (
- i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
- ):
- space_tokens = (
- tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
- ) # B N T C -> (B T) N C
- point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
- virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
-
- virtual_tokens = self.space_virtual2point_blocks[j](
- virtual_tokens, point_tokens, mask=mask
- )
- virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
- point_tokens = self.space_point2virtual_blocks[j](
- point_tokens, virtual_tokens, mask=mask
- )
- space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
- tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
- j += 1
- tokens = tokens[:, : N - self.num_virtual_tracks]
- flow = self.flow_head(tokens)
- return flow
-
-
-class CrossAttnBlock(nn.Module):
- def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
- super().__init__()
- self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- self.norm_context = nn.LayerNorm(hidden_size)
- self.cross_attn = Attention(
- hidden_size, context_dim=context_dim, num_heads=num_heads, qkv_bias=True, **block_kwargs
- )
-
- self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
- approx_gelu = lambda: nn.GELU(approximate="tanh")
- self.mlp = Mlp(
- in_features=hidden_size,
- hidden_features=mlp_hidden_dim,
- act_layer=approx_gelu,
- drop=0,
- )
-
- def forward(self, x, context, mask=None):
- if mask is not None:
- if mask.shape[1] == x.shape[1]:
- mask = mask[:, None, :, None].expand(
- -1, self.cross_attn.heads, -1, context.shape[1]
- )
- else:
- mask = mask[:, None, None].expand(-1, self.cross_attn.heads, x.shape[1], -1)
-
- max_neg_value = -torch.finfo(x.dtype).max
- attn_bias = (~mask) * max_neg_value
- x = x + self.cross_attn(
- self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
- )
- x = x + self.mlp(self.norm2(x))
- return x
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from cotracker.models.core.model_utils import sample_features4d, sample_features5d
+from cotracker.models.core.embeddings import (
+ get_2d_embedding,
+ get_1d_sincos_pos_embed_from_grid,
+ get_2d_sincos_pos_embed,
+)
+
+from cotracker.models.core.cotracker.blocks import (
+ Mlp,
+ BasicEncoder,
+ AttnBlock,
+ CorrBlock,
+ Attention,
+)
+
+torch.manual_seed(0)
+
+
+class CoTracker2(nn.Module):
+ def __init__(
+ self,
+ window_len=8,
+ stride=4,
+ add_space_attn=True,
+ num_virtual_tracks=64,
+ model_resolution=(384, 512),
+ ):
+ super(CoTracker2, self).__init__()
+ self.window_len = window_len
+ self.stride = stride
+ self.hidden_dim = 256
+ self.latent_dim = 128
+ self.add_space_attn = add_space_attn
+
+ self.fnet = BasicEncoder(output_dim=self.latent_dim)
+ self.num_virtual_tracks = num_virtual_tracks
+ self.model_resolution = model_resolution
+ self.input_dim = 456
+ self.updateformer = EfficientUpdateFormer(
+ space_depth=6,
+ time_depth=6,
+ input_dim=self.input_dim,
+ hidden_size=384,
+ output_dim=self.latent_dim + 2,
+ mlp_ratio=4.0,
+ add_space_attn=add_space_attn,
+ num_virtual_tracks=num_virtual_tracks,
+ )
+
+ time_grid = torch.linspace(0, window_len - 1, window_len).reshape(1, window_len, 1)
+
+ self.register_buffer(
+ "time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
+ )
+
+ self.register_buffer(
+ "pos_emb",
+ get_2d_sincos_pos_embed(
+ embed_dim=self.input_dim,
+ grid_size=(
+ model_resolution[0] // stride,
+ model_resolution[1] // stride,
+ ),
+ ),
+ )
+ self.norm = nn.GroupNorm(1, self.latent_dim)
+ self.track_feat_updater = nn.Sequential(
+ nn.Linear(self.latent_dim, self.latent_dim),
+ nn.GELU(),
+ )
+ self.vis_predictor = nn.Sequential(
+ nn.Linear(self.latent_dim, 1),
+ )
+
+ def forward_window(
+ self,
+ fmaps,
+ coords,
+ track_feat=None,
+ vis=None,
+ track_mask=None,
+ attention_mask=None,
+ iters=4,
+ ):
+ # B = batch size
+ # S = number of frames in the window
+ # N = number of tracks
+ # C = channels of a point feature vector
+ # E = positional embedding size
+ # LRR = local receptive field radius
+ # D = dimension of the transformer input tokens
+
+ # track_feat = B S N C
+ # vis = B S N 1
+ # track_mask = B S N 1
+ # attention_mask = B S N
+
+ B, S_init, N, __ = track_mask.shape
+ B, S, *_ = fmaps.shape
+
+ # pad track_mask so its frame count matches that of the feature maps
+ track_mask = F.pad(track_mask, (0, 0, 0, 0, 0, S - S_init), "constant")
+ track_mask_vis = (
+ torch.cat([track_mask, vis], dim=-1).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+ )
+
+ corr_block = CorrBlock(
+ fmaps,
+ num_levels=4,
+ radius=3,
+ padding_mode="border",
+ )
+
+ sampled_pos_emb = (
+ sample_features4d(self.pos_emb.repeat(B, 1, 1, 1), coords[:, 0])
+ .reshape(B * N, self.input_dim)
+ .unsqueeze(1)
+ ) # B E N -> (B N) 1 E
+
+ coord_preds = []
+ for __ in range(iters):
+ coords = coords.detach() # B S N 2
+ corr_block.corr(track_feat)
+
+ # Sample correlation features around each point
+ fcorrs = corr_block.sample(coords) # (B N) S LRR
+
+ # Get the flow embeddings
+ flows = (coords - coords[:, 0:1]).permute(0, 2, 1, 3).reshape(B * N, S, 2)
+ flow_emb = get_2d_embedding(flows, 64, cat_coords=True) # N S E
+
+ track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N, S, self.latent_dim)
+
+ transformer_input = torch.cat([flow_emb, fcorrs, track_feat_, track_mask_vis], dim=2)
+ x = transformer_input + sampled_pos_emb + self.time_emb
+ x = x.view(B, N, S, -1) # (B N) S D -> B N S D
+
+ delta = self.updateformer(
+ x,
+ attention_mask.reshape(B * S, N), # B S N -> (B S) N
+ )
+
+ delta_coords = delta[..., :2].permute(0, 2, 1, 3)
+ coords = coords + delta_coords
+ coord_preds.append(coords * self.stride)
+
+ delta_feats_ = delta[..., 2:].reshape(B * N * S, self.latent_dim)
+ track_feat_ = track_feat.permute(0, 2, 1, 3).reshape(B * N * S, self.latent_dim)
+ track_feat_ = self.track_feat_updater(self.norm(delta_feats_)) + track_feat_
+ track_feat = track_feat_.reshape(B, N, S, self.latent_dim).permute(
+ 0, 2, 1, 3
+ ) # (B N S) C -> B S N C
+
+ vis_pred = self.vis_predictor(track_feat).reshape(B, S, N)
+ return coord_preds, vis_pred
+
+ def get_track_feat(self, fmaps, queried_frames, queried_coords):
+ sample_frames = queried_frames[:, None, :, None]
+ sample_coords = torch.cat(
+ [
+ sample_frames,
+ queried_coords[:, None],
+ ],
+ dim=-1,
+ )
+ # bilinear sampling of the per-track features
+ sample_track_feats = sample_features5d(fmaps, sample_coords)
+ return sample_track_feats
+
+ def init_video_online_processing(self):
+ self.online_ind = 0
+ self.online_track_feat = None
+ self.online_coords_predicted = None
+ self.online_vis_predicted = None
+
+ def forward(self, video, queries, iters=4, is_train=False, is_online=False):
+ """Predict tracks
+
+ Args:
+ video (FloatTensor[B, T, 3, H, W]): input videos.
+ queries (FloatTensor[B, N, 3]): point queries.
+ iters (int, optional): number of updates. Defaults to 4.
+ is_train (bool, optional): enables training mode. Defaults to False.
+ is_online (bool, optional): enables online mode. Defaults to False. Before enabling, call model.init_video_online_processing().
+
+ Returns:
+ - coords_predicted (FloatTensor[B, T, N, 2]):
+ - vis_predicted (FloatTensor[B, T, N]):
+ - train_data: `None` if `is_train` is false, otherwise:
+ - all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
+ - all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
+ - mask (BoolTensor[B, T, N]):
+ """
+ B, T, C, H, W = video.shape
+ B, N, __ = queries.shape
+ S = self.window_len
+ device = queries.device
+
+ # B = batch size
+ # S = number of frames in the window of the padded video
+ # S_trimmed = actual number of frames in the window
+ # N = number of tracks
+ # C = color channels (3 for RGB)
+ # E = positional embedding size
+ # LRR = local receptive field radius
+ # D = dimension of the transformer input tokens
+
+ # video = B T C H W
+ # queries = B N 3
+ # coords_init = B S N 2
+ # vis_init = B S N 1
+
+ assert S >= 2 # A tracker needs at least two frames to track something
+ if is_online:
+ assert T <= S, "Online mode: video chunk must be <= window size."
+ assert self.online_ind is not None, "Call model.init_video_online_processing() first."
+ assert not is_train, "Training not supported in online mode."
+ step = S // 2 # How much the sliding window moves at every step
+ video = 2 * (video / 255.0) - 1.0
+
+ # The first channel is the frame number
+ # The rest are the coordinates of points we want to track
+ queried_frames = queries[:, :, 0].long() # frame index of each query
+
+ queried_coords = queries[..., 1:]
+ queried_coords = queried_coords / self.stride # scale to feature-map resolution
+
+ # We store our predictions here
+ coords_predicted = torch.zeros((B, T, N, 2), device=device) # buffer for the predicted coordinates
+ vis_predicted = torch.zeros((B, T, N), device=device)
+ if is_online:
+ # On the first online chunk, initialize the stored predictions with zeros;
+ # on later chunks, pad the stored predictions for the new window.
+ if self.online_coords_predicted is None:
+ # Init online predictions with zeros
+ self.online_coords_predicted = coords_predicted
+ self.online_vis_predicted = vis_predicted
+ else:
+ # Pad online predictions with zeros for the current window
+ pad = min(step, T - step) # never pad beyond the remaining frames
+ coords_predicted = F.pad(
+ self.online_coords_predicted, (0, 0, 0, 0, 0, pad), "constant"
+ )
+ vis_predicted = F.pad(self.online_vis_predicted, (0, 0, 0, pad), "constant")
+ all_coords_predictions, all_vis_predictions = [], []
+
+ # Pad the video so that an integer number of sliding windows fit into it
+ # (in online mode the current chunk is padded up to one full window)
+ # TODO: we may drop this requirement because the transformer should not care
+ # TODO: pad the features instead of the video
+ # number of frames appended at the end of the clip
+ pad = S - T if is_online else (S - T % S) % S # We don't want to pad if T % S == 0
+ # the padding replicates the last frame
+ video = F.pad(video.reshape(B, 1, T, C * H * W), (0, 0, 0, pad), "replicate").reshape(
+ B, -1, C, H, W
+ )
+
+ # Compute convolutional features for the video or for the current chunk in case of online mode
+ # (in online mode only the current chunk is encoded)
+ fmaps = self.fnet(video.reshape(-1, C, H, W)).reshape(
+ B, -1, self.latent_dim, H // self.stride, W // self.stride
+ )
+
+ # We compute track features
+ # (bilinear sampling of the feature maps at the query locations)
+ track_feat = self.get_track_feat(
+ fmaps,
+ queried_frames - self.online_ind if is_online else queried_frames,
+ queried_coords,
+ ).repeat(1, S, 1, 1)
+ if is_online:
+ # We update track features for the current window
+ sample_frames = queried_frames[:, None, :, None] # B 1 N 1
+ left = 0 if self.online_ind == 0 else self.online_ind + step
+ right = self.online_ind + S
+ sample_mask = (sample_frames >= left) & (sample_frames < right)
+ if self.online_track_feat is None:
+ self.online_track_feat = torch.zeros_like(track_feat, device=device)
+ self.online_track_feat += track_feat * sample_mask
+ track_feat = self.online_track_feat.clone()
+ # We process ((num_windows - 1) * step + S) frames in total, so there are
+ # (ceil((T - S) / step) + 1) windows
+ num_windows = (T - S + step - 1) // step + 1
+ # We process only the current video chunk in the online mode
+ indices = [self.online_ind] if is_online else range(0, step * num_windows, step)
+
+ # broadcast the queried coordinates to shape B S N 2
+ coords_init = queried_coords.reshape(B, 1, N, 2).expand(B, S, N, 2).float()
+ vis_init = torch.ones((B, S, N, 1), device=device).float() * 10
+ for ind in indices:
+ # We copy over coords and vis for tracks that are queried
+ # by the end of the previous window, which is ind + overlap
+ # handle the overlap with the previous window
+ if ind > 0:
+ overlap = S - step
+ copy_over = (queried_frames < ind + overlap)[:, None, :, None] # B 1 N 1
+ # reuse the predictions from the previous window
+ coords_prev = torch.nn.functional.pad(
+ coords_predicted[:, ind : ind + overlap] / self.stride,
+ (0, 0, 0, 0, 0, step),
+ "replicate",
+ ) # B S N 2
+ vis_prev = torch.nn.functional.pad(
+ vis_predicted[:, ind : ind + overlap, :, None].clone(),
+ (0, 0, 0, 0, 0, step),
+ "replicate",
+ ) # B S N 1
+ coords_init = torch.where(
+ copy_over.expand_as(coords_init), coords_prev, coords_init
+ ) # where copy_over is True take coords_prev, else keep coords_init
+ vis_init = torch.where(copy_over.expand_as(vis_init), vis_prev, vis_init)
+
+ # The attention mask is 1 for the spatio-temporal points within
+ # a track which is updated in the current window
+ # (queries that start at or after the end of this window are masked out)
+ attention_mask = (queried_frames < ind + S).reshape(B, 1, N).repeat(1, S, 1) # B S N
+
+ # The track mask is 1 for the spatio-temporal points that actually
+ # need updating: only after being queried, and not if contained
+ # in a previous window
+ # (the overlap with the previous window is zeroed out just below)
+ track_mask = (
+ queried_frames[:, None, :, None]
+ <= torch.arange(ind, ind + S, device=device)[None, :, None, None]
+ ).contiguous() # B S N 1
+
+ if ind > 0:
+ track_mask[:, :overlap, :, :] = False
+
+ # Predict the coordinates and visibility for the current window
+ # (forward_window runs the iterative transformer refinement)
+ coords, vis = self.forward_window(
+ fmaps=fmaps if is_online else fmaps[:, ind : ind + S],
+ coords=coords_init,
+ track_feat=attention_mask.unsqueeze(-1) * track_feat,
+ vis=vis_init,
+ track_mask=track_mask,
+ attention_mask=attention_mask,
+ iters=iters,
+ )
+
+ S_trimmed = T if is_online else min(T - ind, S) # accounts for last window duration
+ coords_predicted[:, ind : ind + S] = coords[-1][:, :S_trimmed]
+ vis_predicted[:, ind : ind + S] = vis[:, :S_trimmed]
+ if is_train:
+ all_coords_predictions.append([coord[:, :S_trimmed] for coord in coords])
+ all_vis_predictions.append(torch.sigmoid(vis[:, :S_trimmed]))
+
+ if is_online:
+ self.online_ind += step
+ self.online_coords_predicted = coords_predicted
+ self.online_vis_predicted = vis_predicted
+ vis_predicted = torch.sigmoid(vis_predicted)
+
+ if is_train:
+ mask = queried_frames[:, None] <= torch.arange(0, T, device=device)[None, :, None]
+ train_data = (all_coords_predictions, all_vis_predictions, mask)
+ else:
+ train_data = None
+
+ return coords_predicted, vis_predicted, train_data
+
+
+class EfficientUpdateFormer(nn.Module):
+ """
+ Transformer model that updates track estimates.
+ """
+
+ def __init__(
+ self,
+ space_depth=6,
+ time_depth=6,
+ input_dim=320,
+ hidden_size=384,
+ num_heads=8,
+ output_dim=130,
+ mlp_ratio=4.0,
+ add_space_attn=True,
+ num_virtual_tracks=64,
+ ):
+ super().__init__()
+ self.out_channels = 2
+ self.num_heads = num_heads
+ self.hidden_size = hidden_size
+ self.add_space_attn = add_space_attn
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
+ self.num_virtual_tracks = num_virtual_tracks
+ self.virual_tracks = nn.Parameter(torch.randn(1, num_virtual_tracks, 1, hidden_size))
+ self.time_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=Attention,
+ )
+ for _ in range(time_depth)
+ ]
+ )
+
+ if add_space_attn:
+ self.space_virtual_blocks = nn.ModuleList(
+ [
+ AttnBlock(
+ hidden_size,
+ num_heads,
+ mlp_ratio=mlp_ratio,
+ attn_class=Attention,
+ )
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_point2virtual_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(space_depth)
+ ]
+ )
+ self.space_virtual2point_blocks = nn.ModuleList(
+ [
+ CrossAttnBlock(hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio)
+ for _ in range(space_depth)
+ ]
+ )
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ def _basic_init(module):
+ if isinstance(module, nn.Linear):
+ torch.nn.init.xavier_uniform_(module.weight)
+ if module.bias is not None:
+ nn.init.constant_(module.bias, 0)
+
+ self.apply(_basic_init)
+
+ def forward(self, input_tensor, mask=None):
+ tokens = self.input_transform(input_tensor)
+ B, _, T, _ = tokens.shape
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
+ _, N, _, _ = tokens.shape
+
+ j = 0
+ for i in range(len(self.time_blocks)):
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
+ time_tokens = self.time_blocks[i](time_tokens)
+
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
+ if self.add_space_attn and (
+ i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0
+ ):
+ space_tokens = (
+ tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
+ ) # B N T C -> (B T) N C
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
+
+ virtual_tokens = self.space_virtual2point_blocks[j](
+ virtual_tokens, point_tokens, mask=mask
+ )
+ virtual_tokens = self.space_virtual_blocks[j](virtual_tokens)
+ point_tokens = self.space_point2virtual_blocks[j](
+ point_tokens, virtual_tokens, mask=mask
+ )
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
+ tokens = space_tokens.view(B, T, N, -1).permute(0, 2, 1, 3) # (B T) N C -> B N T C
+ j += 1
+ tokens = tokens[:, : N - self.num_virtual_tracks]
+ flow = self.flow_head(tokens)
+ return flow
+
+
+class CrossAttnBlock(nn.Module):
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
+ super().__init__()
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ self.norm_context = nn.LayerNorm(hidden_size)
+ self.cross_attn = Attention(
+ hidden_size, context_dim=context_dim, num_heads=num_heads, qkv_bias=True, **block_kwargs
+ )
+
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
+ self.mlp = Mlp(
+ in_features=hidden_size,
+ hidden_features=mlp_hidden_dim,
+ act_layer=approx_gelu,
+ drop=0,
+ )
+
+ def forward(self, x, context, mask=None):
+ if mask is not None:
+ if mask.shape[1] == x.shape[1]:
+ mask = mask[:, None, :, None].expand(
+ -1, self.cross_attn.heads, -1, context.shape[1]
+ )
+ else:
+ mask = mask[:, None, None].expand(-1, self.cross_attn.heads, x.shape[1], -1)
+
+ max_neg_value = -torch.finfo(x.dtype).max
+ attn_bias = (~mask) * max_neg_value
+ x = x + self.cross_attn(
+ self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
+ )
+ x = x + self.mlp(self.norm2(x))
+ return x
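Finally, a minimal offline-inference sketch for CoTracker2 (outside the patch; random weights and a made-up clip and queries, shown only to illustrate the expected shapes). The video is expected at the model resolution with values in [0, 255], and each query is (frame_index, x, y) in pixel coordinates. For streaming input one would instead call model.init_video_online_processing() and feed chunks with is_online=True.

import torch
from cotracker.models.core.cotracker.cotracker import CoTracker2

model = CoTracker2(stride=4, window_len=8, add_space_attn=True).eval()

video = torch.randint(0, 256, (1, 16, 3, 384, 512)).float()  # B T C H W, values in [0, 255]
queries = torch.tensor([[[0.0, 100.0, 150.0],                 # (frame, x, y) per query point
                         [4.0, 300.0, 200.0]]])

with torch.no_grad():
    coords, vis, _ = model(video, queries, iters=4)
print(coords.shape, vis.shape)  # torch.Size([1, 16, 2, 2]) torch.Size([1, 16, 2])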
diff --git a/cotracker/models/core/cotracker/losses.py b/cotracker/models/core/cotracker/losses.py
index 2bdcc2e..0168d9d 100644
--- a/cotracker/models/core/cotracker/losses.py
+++ b/cotracker/models/core/cotracker/losses.py
@@ -1,61 +1,61 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn.functional as F
-from cotracker.models.core.model_utils import reduce_masked_mean
-
-EPS = 1e-6
-
-
-def balanced_ce_loss(pred, gt, valid=None):
- total_balanced_loss = 0.0
- for j in range(len(gt)):
- B, S, N = gt[j].shape
- # pred and gt are the same shape
- for (a, b) in zip(pred[j].size(), gt[j].size()):
- assert a == b # some shape mismatch!
- # if valid is not None:
- for (a, b) in zip(pred[j].size(), valid[j].size()):
- assert a == b # some shape mismatch!
-
- pos = (gt[j] > 0.95).float()
- neg = (gt[j] < 0.05).float()
-
- label = pos * 2.0 - 1.0
- a = -label * pred[j]
- b = F.relu(a)
- loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
-
- pos_loss = reduce_masked_mean(loss, pos * valid[j])
- neg_loss = reduce_masked_mean(loss, neg * valid[j])
-
- balanced_loss = pos_loss + neg_loss
- total_balanced_loss += balanced_loss / float(N)
- return total_balanced_loss
-
-
-def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
- """Loss function defined over sequence of flow predictions"""
- total_flow_loss = 0.0
- for j in range(len(flow_gt)):
- B, S, N, D = flow_gt[j].shape
- assert D == 2
- B, S1, N = vis[j].shape
- B, S2, N = valids[j].shape
- assert S == S1
- assert S == S2
- n_predictions = len(flow_preds[j])
- flow_loss = 0.0
- for i in range(n_predictions):
- i_weight = gamma ** (n_predictions - i - 1)
- flow_pred = flow_preds[j][i]
- i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
- i_loss = torch.mean(i_loss, dim=3) # B, S, N
- flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
- flow_loss = flow_loss / n_predictions
- total_flow_loss += flow_loss / float(N)
- return total_flow_loss
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from cotracker.models.core.model_utils import reduce_masked_mean
+
+EPS = 1e-6
+
+
+def balanced_ce_loss(pred, gt, valid=None):
+ total_balanced_loss = 0.0
+ for j in range(len(gt)):
+ B, S, N = gt[j].shape
+ # pred and gt are the same shape
+ for (a, b) in zip(pred[j].size(), gt[j].size()):
+ assert a == b # some shape mismatch!
+        # `valid` is required here despite its None default; it must have the same shape as gt[j]
+ for (a, b) in zip(pred[j].size(), valid[j].size()):
+ assert a == b # some shape mismatch!
+
+ pos = (gt[j] > 0.95).float()
+ neg = (gt[j] < 0.05).float()
+
+ label = pos * 2.0 - 1.0
+ a = -label * pred[j]
+ b = F.relu(a)
+ loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))
+
+ pos_loss = reduce_masked_mean(loss, pos * valid[j])
+ neg_loss = reduce_masked_mean(loss, neg * valid[j])
+
+ balanced_loss = pos_loss + neg_loss
+ total_balanced_loss += balanced_loss / float(N)
+ return total_balanced_loss
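+
+
+# Illustrative check (not used by training): the expression
+# b + log(exp(-b) + exp(a - b)) above, with a = -label * pred and b = relu(a),
+# is the numerically stable form of log(1 + exp(-label * pred)), i.e. binary
+# cross-entropy with logits.
+def _demo_stable_logistic_loss():
+    pred = torch.tensor([-5.0, -0.5, 0.0, 2.0, 10.0])
+    target = torch.tensor([0.0, 1.0, 1.0, 0.0, 1.0])
+    label = target * 2.0 - 1.0
+    a = -label * pred
+    b = F.relu(a)
+    stable = b + torch.log(torch.exp(-b) + torch.exp(a - b))
+    reference = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
+    assert torch.allclose(stable, reference, atol=1e-5)
+    return stable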
+
+
+def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
+ """Loss function defined over sequence of flow predictions"""
+ total_flow_loss = 0.0
+ for j in range(len(flow_gt)):
+ B, S, N, D = flow_gt[j].shape
+ assert D == 2
+ B, S1, N = vis[j].shape
+ B, S2, N = valids[j].shape
+ assert S == S1
+ assert S == S2
+ n_predictions = len(flow_preds[j])
+ flow_loss = 0.0
+ for i in range(n_predictions):
+ i_weight = gamma ** (n_predictions - i - 1)
+ flow_pred = flow_preds[j][i]
+ i_loss = (flow_pred - flow_gt[j]).abs() # B, S, N, 2
+ i_loss = torch.mean(i_loss, dim=3) # B, S, N
+ flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
+ flow_loss = flow_loss / n_predictions
+ total_flow_loss += flow_loss / float(N)
+ return total_flow_loss
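+
+
+# Minimal sketch of the per-iteration weighting used above: with the default
+# gamma = 0.8 and four refinement iterations the weights are roughly
+# [0.512, 0.64, 0.8, 1.0], so later (more refined) predictions dominate.
+def _demo_sequence_loss_weights(n_predictions=4, gamma=0.8):
+    return [gamma ** (n_predictions - i - 1) for i in range(n_predictions)]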
diff --git a/cotracker/models/core/embeddings.py b/cotracker/models/core/embeddings.py
index 897cd5d..2ee4aee 100644
--- a/cotracker/models/core/embeddings.py
+++ b/cotracker/models/core/embeddings.py
@@ -1,120 +1,120 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-from typing import Tuple, Union
-import torch
-
-
-def get_2d_sincos_pos_embed(
- embed_dim: int, grid_size: Union[int, Tuple[int, int]]
-) -> torch.Tensor:
- """
- This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
- It is a wrapper of get_2d_sincos_pos_embed_from_grid.
- Args:
- - embed_dim: The embedding dimension.
- - grid_size: The grid size.
- Returns:
- - pos_embed: The generated 2D positional embedding.
- """
- if isinstance(grid_size, tuple):
- grid_size_h, grid_size_w = grid_size
- else:
- grid_size_h = grid_size_w = grid_size
- grid_h = torch.arange(grid_size_h, dtype=torch.float)
- grid_w = torch.arange(grid_size_w, dtype=torch.float)
- grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
- grid = torch.stack(grid, dim=0)
- grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
- pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
- return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
-
-
-def get_2d_sincos_pos_embed_from_grid(
- embed_dim: int, grid: torch.Tensor
-) -> torch.Tensor:
- """
- This function generates a 2D positional embedding from a given grid using sine and cosine functions.
-
- Args:
- - embed_dim: The embedding dimension.
- - grid: The grid to generate the embedding from.
-
- Returns:
- - emb: The generated 2D positional embedding.
- """
- assert embed_dim % 2 == 0
-
- # use half of dimensions to encode grid_h
- emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
- emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
-
- emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
- return emb
-
-
-def get_1d_sincos_pos_embed_from_grid(
- embed_dim: int, pos: torch.Tensor
-) -> torch.Tensor:
- """
- This function generates a 1D positional embedding from a given grid using sine and cosine functions.
-
- Args:
- - embed_dim: The embedding dimension.
- - pos: The position to generate the embedding from.
-
- Returns:
- - emb: The generated 1D positional embedding.
- """
- assert embed_dim % 2 == 0
- omega = torch.arange(embed_dim // 2, dtype=torch.double)
- omega /= embed_dim / 2.0
- omega = 1.0 / 10000**omega # (D/2,)
-
- pos = pos.reshape(-1) # (M,)
- out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
-
- emb_sin = torch.sin(out) # (M, D/2)
- emb_cos = torch.cos(out) # (M, D/2)
-
- emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
- return emb[None].float()
-
-
-def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
- """
- This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
-
- Args:
- - xy: The coordinates to generate the embedding from.
- - C: The size of the embedding.
- - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
-
- Returns:
- - pe: The generated 2D positional embedding.
- """
- B, N, D = xy.shape
- assert D == 2
-
- x = xy[:, :, 0:1]
- y = xy[:, :, 1:2]
- div_term = (
- torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
- ).reshape(1, 1, int(C / 2))
-
- pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
- pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
-
- pe_x[:, :, 0::2] = torch.sin(x * div_term)
- pe_x[:, :, 1::2] = torch.cos(x * div_term)
-
- pe_y[:, :, 0::2] = torch.sin(y * div_term)
- pe_y[:, :, 1::2] = torch.cos(y * div_term)
-
- pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*3)
- if cat_coords:
- pe = torch.cat([xy, pe], dim=2) # (B, N, C*3+3)
- return pe
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Tuple, Union
+import torch
+
+
+def get_2d_sincos_pos_embed(
+ embed_dim: int, grid_size: Union[int, Tuple[int, int]]
+) -> torch.Tensor:
+ """
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid_size: The grid size.
+ Returns:
+ - pos_embed: The generated 2D positional embedding.
+ """
+ if isinstance(grid_size, tuple):
+ grid_size_h, grid_size_w = grid_size
+ else:
+ grid_size_h = grid_size_w = grid_size
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+ grid = torch.stack(grid, dim=0)
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
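+
+
+# Minimal sketch (shape check only): for a (H, W) grid the wrapper above
+# returns a (1, embed_dim, H, W) positional embedding. Sizes are arbitrary.
+def _demo_get_2d_sincos_pos_embed():
+    pos_embed = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(6, 8))
+    assert pos_embed.shape == (1, 64, 6, 8)
+    return pos_embed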
+
+
+def get_2d_sincos_pos_embed_from_grid(
+ embed_dim: int, grid: torch.Tensor
+) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - grid: The grid to generate the embedding from.
+
+ Returns:
+ - emb: The generated 2D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+
+ # use half of dimensions to encode grid_h
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
+
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
+ return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(
+ embed_dim: int, pos: torch.Tensor
+) -> torch.Tensor:
+ """
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
+
+ Args:
+ - embed_dim: The embedding dimension.
+ - pos: The position to generate the embedding from.
+
+ Returns:
+ - emb: The generated 1D positional embedding.
+ """
+ assert embed_dim % 2 == 0
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
+ omega /= embed_dim / 2.0
+ omega = 1.0 / 10000**omega # (D/2,)
+
+ pos = pos.reshape(-1) # (M,)
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
+
+ emb_sin = torch.sin(out) # (M, D/2)
+ emb_cos = torch.cos(out) # (M, D/2)
+
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
+ return emb[None].float()
+
+
+def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
+ """
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
+
+ Args:
+ - xy: The coordinates to generate the embedding from.
+ - C: The size of the embedding.
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
+
+ Returns:
+ - pe: The generated 2D positional embedding.
+ """
+ B, N, D = xy.shape
+ assert D == 2
+
+ x = xy[:, :, 0:1]
+ y = xy[:, :, 1:2]
+ div_term = (
+ torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
+ ).reshape(1, 1, int(C / 2))
+
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
+
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
+
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
+
+    pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*2)
+    if cat_coords:
+        pe = torch.cat([xy, pe], dim=2)  # (B, N, C*2+2)
+ return pe
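+
+
+# Minimal sketch: with cat_coords=True the embedding above has 2 * C + 2
+# channels per point (sin/cos of x, sin/cos of y, plus the raw coordinates).
+def _demo_get_2d_embedding():
+    xy = torch.rand(2, 5, 2) * 100.0
+    pe = get_2d_embedding(xy, C=64, cat_coords=True)
+    assert pe.shape == (2, 5, 2 * 64 + 2)
+    return pe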
diff --git a/cotracker/models/core/model_utils.py b/cotracker/models/core/model_utils.py
index 321d1ee..12afd4e 100644
--- a/cotracker/models/core/model_utils.py
+++ b/cotracker/models/core/model_utils.py
@@ -1,256 +1,256 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn.functional as F
-from typing import Optional, Tuple
-
-EPS = 1e-6
-
-
-def smart_cat(tensor1, tensor2, dim):
- if tensor1 is None:
- return tensor2
- return torch.cat([tensor1, tensor2], dim=dim)
-
-
-def get_points_on_a_grid(
- size: int,
- extent: Tuple[float, ...],
- center: Optional[Tuple[float, ...]] = None,
- device: Optional[torch.device] = torch.device("cpu"),
-):
- r"""Get a grid of points covering a rectangular region
-
- `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
- :attr:`size` grid fo points distributed to cover a rectangular area
- specified by `extent`.
-
- The `extent` is a pair of integer :math:`(H,W)` specifying the height
- and width of the rectangle.
-
- Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
- specifying the vertical and horizontal center coordinates. The center
- defaults to the middle of the extent.
-
- Points are distributed uniformly within the rectangle leaving a margin
- :math:`m=W/64` from the border.
-
- It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
- points :math:`P_{ij}=(x_i, y_i)` where
-
- .. math::
- P_{ij} = \left(
- c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
- c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
- \right)
-
- Points are returned in row-major order.
-
- Args:
- size (int): grid size.
- extent (tuple): height and with of the grid extent.
- center (tuple, optional): grid center.
- device (str, optional): Defaults to `"cpu"`.
-
- Returns:
- Tensor: grid.
- """
- if size == 1:
- return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
-
- if center is None:
- center = [extent[0] / 2, extent[1] / 2]
-
- margin = extent[1] / 64
- range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
- range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
- grid_y, grid_x = torch.meshgrid(
- torch.linspace(*range_y, size, device=device),
- torch.linspace(*range_x, size, device=device),
- indexing="ij",
- )
- return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
-
-
-def reduce_masked_mean(input, mask, dim=None, keepdim=False):
- r"""Masked mean
-
- `reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
- over a mask :attr:`mask`, returning
-
- .. math::
- \text{output} =
- \frac
- {\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
- {\epsilon + \sum_{i=1}^N \text{mask}_i}
-
- where :math:`N` is the number of elements in :attr:`input` and
- :attr:`mask`, and :math:`\epsilon` is a small constant to avoid
- division by zero.
-
- `reduced_masked_mean(x, mask, dim)` computes the mean of a tensor
- :attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
- Optionally, the dimension can be kept in the output by setting
- :attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
- the same dimension as :attr:`input`.
-
- The interface is similar to `torch.mean()`.
-
- Args:
- inout (Tensor): input tensor.
- mask (Tensor): mask.
- dim (int, optional): Dimension to sum over. Defaults to None.
- keepdim (bool, optional): Keep the summed dimension. Defaults to False.
-
- Returns:
- Tensor: mean tensor.
- """
-
- mask = mask.expand_as(input)
-
- prod = input * mask
-
- if dim is None:
- numer = torch.sum(prod)
- denom = torch.sum(mask)
- else:
- numer = torch.sum(prod, dim=dim, keepdim=keepdim)
- denom = torch.sum(mask, dim=dim, keepdim=keepdim)
-
- mean = numer / (EPS + denom)
- return mean
-
-
-def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
- r"""Sample a tensor using bilinear interpolation
-
- `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
- coordinates :attr:`coords` using bilinear interpolation. It is the same
- as `torch.nn.functional.grid_sample()` but with a different coordinate
- convention.
-
- The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
- :math:`B` is the batch size, :math:`C` is the number of channels,
- :math:`H` is the height of the image, and :math:`W` is the width of the
- image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
- interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
-
- Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
- in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
- that in this case the order of the components is slightly different
- from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
-
- If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
- in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
- left-most image pixel :math:`W-1` to the center of the right-most
- pixel.
-
- If `align_corners` is `False`, the coordinate :math:`x` is assumed to
- be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
- the left-most pixel :math:`W` to the right edge of the right-most
- pixel.
-
- Similar conventions apply to the :math:`y` for the range
- :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
- :math:`[0,T-1]` and :math:`[0,T]`.
-
- Args:
- input (Tensor): batch of input images.
- coords (Tensor): batch of coordinates.
- align_corners (bool, optional): Coordinate convention. Defaults to `True`.
- padding_mode (str, optional): Padding mode. Defaults to `"border"`.
-
- Returns:
- Tensor: sampled points.
- """
-
- sizes = input.shape[2:]
-
- assert len(sizes) in [2, 3]
-
- if len(sizes) == 3:
- # t x y -> x y t to match dimensions T H W in grid_sample
- coords = coords[..., [1, 2, 0]]
-
- if align_corners:
- coords = coords * torch.tensor(
- [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
- )
- else:
- coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
-
- coords -= 1
-
- return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
-
-
-def sample_features4d(input, coords):
- r"""Sample spatial features
-
- `sample_features4d(input, coords)` samples the spatial features
- :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
-
- The field is sampled at coordinates :attr:`coords` using bilinear
- interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
- 3)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
- same convention as :func:`bilinear_sampler` with `align_corners=True`.
-
- The output tensor has one feature per point, and has shape :math:`(B,
- R, C)`.
-
- Args:
- input (Tensor): spatial features.
- coords (Tensor): points.
-
- Returns:
- Tensor: sampled features.
- """
-
- B, _, _, _ = input.shape
-
- # B R 2 -> B R 1 2
- coords = coords.unsqueeze(2)
-
- # B C R 1
- feats = bilinear_sampler(input, coords)
-
- return feats.permute(0, 2, 1, 3).view(
- B, -1, feats.shape[1] * feats.shape[3]
- ) # B C R 1 -> B R C
-
-
-def sample_features5d(input, coords):
- r"""Sample spatio-temporal features
-
- `sample_features5d(input, coords)` works in the same way as
- :func:`sample_features4d` but for spatio-temporal features and points:
- :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
- a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
- x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
-
- Args:
- input (Tensor): spatio-temporal features.
- coords (Tensor): spatio-temporal points.
-
- Returns:
- Tensor: sampled features.
- """
-
- B, T, _, _, _ = input.shape
-
- # B T C H W -> B C T H W
- input = input.permute(0, 2, 1, 3, 4)
-
- # B R1 R2 3 -> B R1 R2 1 3
- coords = coords.unsqueeze(3)
-
- # B C R1 R2 1
- feats = bilinear_sampler(input, coords)
-
- return feats.permute(0, 2, 3, 1, 4).view(
- B, feats.shape[2], feats.shape[3], feats.shape[1]
- ) # B C R1 R2 1 -> B R1 R2 C
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from typing import Optional, Tuple
+
+EPS = 1e-6
+
+
+def smart_cat(tensor1, tensor2, dim):
+ if tensor1 is None:
+ return tensor2
+ return torch.cat([tensor1, tensor2], dim=dim)
+
+
+def get_points_on_a_grid(
+ size: int,
+ extent: Tuple[float, ...],
+ center: Optional[Tuple[float, ...]] = None,
+ device: Optional[torch.device] = torch.device("cpu"),
+):
+ r"""Get a grid of points covering a rectangular region
+
+ `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
+    :attr:`size` grid of points distributed to cover a rectangular area
+ specified by `extent`.
+
+    The `extent` is a pair of integers :math:`(H,W)` specifying the height
+ and width of the rectangle.
+
+ Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
+ specifying the vertical and horizontal center coordinates. The center
+ defaults to the middle of the extent.
+
+ Points are distributed uniformly within the rectangle leaving a margin
+ :math:`m=W/64` from the border.
+
+ It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
+ points :math:`P_{ij}=(x_i, y_i)` where
+
+ .. math::
+ P_{ij} = \left(
+ c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
+ c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
+ \right)
+
+ Points are returned in row-major order.
+
+ Args:
+ size (int): grid size.
+        extent (tuple): height and width of the grid extent.
+ center (tuple, optional): grid center.
+ device (str, optional): Defaults to `"cpu"`.
+
+ Returns:
+ Tensor: grid.
+ """
+ if size == 1:
+ return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]
+
+ if center is None:
+ center = [extent[0] / 2, extent[1] / 2]
+
+ margin = extent[1] / 64
+ range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
+ range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(*range_y, size, device=device),
+ torch.linspace(*range_x, size, device=device),
+ indexing="ij",
+ )
+ return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
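+
+
+# Minimal sketch: a 3x3 grid over a 128x256 (H, W) extent yields 9 (x, y)
+# points with a leading batch dimension. Sizes are arbitrary.
+def _demo_get_points_on_a_grid():
+    pts = get_points_on_a_grid(3, (128, 256))
+    assert pts.shape == (1, 9, 2)
+    return pts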
+
+
+def reduce_masked_mean(input, mask, dim=None, keepdim=False):
+ r"""Masked mean
+
+ `reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
+ over a mask :attr:`mask`, returning
+
+ .. math::
+ \text{output} =
+ \frac
+ {\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
+ {\epsilon + \sum_{i=1}^N \text{mask}_i}
+
+ where :math:`N` is the number of elements in :attr:`input` and
+ :attr:`mask`, and :math:`\epsilon` is a small constant to avoid
+ division by zero.
+
+    `reduce_masked_mean(x, mask, dim)` computes the mean of a tensor
+ :attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
+ Optionally, the dimension can be kept in the output by setting
+ :attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
+ the same dimension as :attr:`input`.
+
+ The interface is similar to `torch.mean()`.
+
+ Args:
+        input (Tensor): input tensor.
+ mask (Tensor): mask.
+ dim (int, optional): Dimension to sum over. Defaults to None.
+ keepdim (bool, optional): Keep the summed dimension. Defaults to False.
+
+ Returns:
+ Tensor: mean tensor.
+ """
+
+ mask = mask.expand_as(input)
+
+ prod = input * mask
+
+ if dim is None:
+ numer = torch.sum(prod)
+ denom = torch.sum(mask)
+ else:
+ numer = torch.sum(prod, dim=dim, keepdim=keepdim)
+ denom = torch.sum(mask, dim=dim, keepdim=keepdim)
+
+ mean = numer / (EPS + denom)
+ return mean
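+
+
+# Minimal sketch: only entries where mask == 1 contribute to the mean.
+def _demo_reduce_masked_mean():
+    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
+    mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
+    out = reduce_masked_mean(x, mask)  # (1 + 2) / 2, up to the EPS in the denominator
+    assert torch.allclose(out, torch.tensor(1.5), atol=1e-4)
+    return out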
+
+
+def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
+ r"""Sample a tensor using bilinear interpolation
+
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
+ convention.
+
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
+ :math:`B` is the batch size, :math:`C` is the number of channels,
+ :math:`H` is the height of the image, and :math:`W` is the width of the
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
+
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
+ that in this case the order of the components is slightly different
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
+
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
+    left-most image pixel and :math:`W-1` to the center of the right-most
+ pixel.
+
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
+    the left-most pixel and :math:`W` to the right edge of the right-most
+ pixel.
+
+ Similar conventions apply to the :math:`y` for the range
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
+ :math:`[0,T-1]` and :math:`[0,T]`.
+
+ Args:
+ input (Tensor): batch of input images.
+ coords (Tensor): batch of coordinates.
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
+
+ Returns:
+ Tensor: sampled points.
+ """
+
+ sizes = input.shape[2:]
+
+ assert len(sizes) in [2, 3]
+
+ if len(sizes) == 3:
+ # t x y -> x y t to match dimensions T H W in grid_sample
+ coords = coords[..., [1, 2, 0]]
+
+ if align_corners:
+ coords = coords * torch.tensor(
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
+ )
+ else:
+ coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)
+
+ coords -= 1
+
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
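+
+
+# Minimal sketch: with align_corners=True, pixel coordinates (0, 0) and
+# (W - 1, H - 1) hit the centers of the corner pixels of a tiny 2x2 image.
+def _demo_bilinear_sampler():
+    image = torch.tensor([[[[0.0, 1.0], [2.0, 3.0]]]])   # (B, C, H, W) = (1, 1, 2, 2)
+    coords = torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]])  # (B, H_o, W_o, 2) as (x, y)
+    sampled = bilinear_sampler(image, coords)            # (1, 1, 1, 2)
+    assert torch.allclose(sampled, torch.tensor([[[[0.0, 3.0]]]]))
+    return sampled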
+
+
+def sample_features4d(input, coords):
+ r"""Sample spatial features
+
+ `sample_features4d(input, coords)` samples the spatial features
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
+
+ The field is sampled at coordinates :attr:`coords` using bilinear
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
+    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
+
+ The output tensor has one feature per point, and has shape :math:`(B,
+ R, C)`.
+
+ Args:
+ input (Tensor): spatial features.
+ coords (Tensor): points.
+
+ Returns:
+ Tensor: sampled features.
+ """
+
+ B, _, _, _ = input.shape
+
+ # B R 2 -> B R 1 2
+ coords = coords.unsqueeze(2)
+
+ # B C R 1
+ feats = bilinear_sampler(input, coords)
+
+ return feats.permute(0, 2, 1, 3).view(
+ B, -1, feats.shape[1] * feats.shape[3]
+ ) # B C R 1 -> B R C
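+
+
+# Minimal sketch (shapes only): one C-dimensional feature per query point.
+def _demo_sample_features4d():
+    feature_map = torch.rand(2, 16, 32, 32)  # (B, C, H, W)
+    points = torch.rand(2, 7, 2) * 31.0      # (B, R, 2) as (x, y) in pixels
+    feats = sample_features4d(feature_map, points)
+    assert feats.shape == (2, 7, 16)
+    return feats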
+
+
+def sample_features5d(input, coords):
+ r"""Sample spatio-temporal features
+
+ `sample_features5d(input, coords)` works in the same way as
+ :func:`sample_features4d` but for spatio-temporal features and points:
+ :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
+ a :math:`(B, R1, R2, 3)` tensor of spatio-temporal point :math:`(t_i,
+ x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.
+
+ Args:
+ input (Tensor): spatio-temporal features.
+ coords (Tensor): spatio-temporal points.
+
+ Returns:
+ Tensor: sampled features.
+ """
+
+ B, T, _, _, _ = input.shape
+
+ # B T C H W -> B C T H W
+ input = input.permute(0, 2, 1, 3, 4)
+
+ # B R1 R2 3 -> B R1 R2 1 3
+ coords = coords.unsqueeze(3)
+
+ # B C R1 R2 1
+ feats = bilinear_sampler(input, coords)
+
+ return feats.permute(0, 2, 3, 1, 4).view(
+ B, feats.shape[2], feats.shape[3], feats.shape[1]
+ ) # B C R1 R2 1 -> B R1 R2 C
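+
+
+# Minimal sketch (shapes only): spatio-temporal sampling at (t, x, y) triplets.
+def _demo_sample_features5d():
+    features = torch.rand(1, 4, 8, 16, 16)  # (B, T, C, H, W)
+    coords = torch.rand(1, 3, 5, 3) * 3.0   # (B, R1, R2, 3) as (t, x, y)
+    feats = sample_features5d(features, coords)
+    assert feats.shape == (1, 3, 5, 8)
+    return feats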
diff --git a/cotracker/models/evaluation_predictor.py b/cotracker/models/evaluation_predictor.py
index 87f8e18..223eb3c 100644
--- a/cotracker/models/evaluation_predictor.py
+++ b/cotracker/models/evaluation_predictor.py
@@ -1,104 +1,104 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn.functional as F
-from typing import Tuple
-
-from cotracker.models.core.cotracker.cotracker import CoTracker2
-from cotracker.models.core.model_utils import get_points_on_a_grid
-
-
-class EvaluationPredictor(torch.nn.Module):
- def __init__(
- self,
- cotracker_model: CoTracker2,
- interp_shape: Tuple[int, int] = (384, 512),
- grid_size: int = 5,
- local_grid_size: int = 8,
- single_point: bool = True,
- n_iters: int = 6,
- ) -> None:
- super(EvaluationPredictor, self).__init__()
- self.grid_size = grid_size
- self.local_grid_size = local_grid_size
- self.single_point = single_point
- self.interp_shape = interp_shape
- self.n_iters = n_iters
-
- self.model = cotracker_model
- self.model.eval()
-
- def forward(self, video, queries):
- queries = queries.clone()
- B, T, C, H, W = video.shape
- B, N, D = queries.shape
-
- assert D == 3
-
- video = video.reshape(B * T, C, H, W)
- video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
- video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
-
- device = video.device
-
- queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
- queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
-
- if self.single_point:
- traj_e = torch.zeros((B, T, N, 2), device=device)
- vis_e = torch.zeros((B, T, N), device=device)
- for pind in range((N)):
- query = queries[:, pind : pind + 1]
-
- t = query[0, 0, 0].long()
-
- traj_e_pind, vis_e_pind = self._process_one_point(video, query)
- traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
- vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
- else:
- if self.grid_size > 0:
- xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
- xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
- queries = torch.cat([queries, xy], dim=1) #
-
- traj_e, vis_e, __ = self.model(
- video=video,
- queries=queries,
- iters=self.n_iters,
- )
-
- traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
- traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
- return traj_e, vis_e
-
- def _process_one_point(self, video, query):
- t = query[0, 0, 0].long()
-
- device = query.device
- if self.local_grid_size > 0:
- xy_target = get_points_on_a_grid(
- self.local_grid_size,
- (50, 50),
- [query[0, 0, 2].item(), query[0, 0, 1].item()],
- )
-
- xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
- device
- ) #
- query = torch.cat([query, xy_target], dim=1) #
-
- if self.grid_size > 0:
- xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
- xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
- query = torch.cat([query, xy], dim=1) #
- # crop the video to start from the queried frame
- query[0, 0, 0] = 0
- traj_e_pind, vis_e_pind, __ = self.model(
- video=video[:, t:], queries=query, iters=self.n_iters
- )
-
- return traj_e_pind, vis_e_pind
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+from typing import Tuple
+
+from cotracker.models.core.cotracker.cotracker import CoTracker2
+from cotracker.models.core.model_utils import get_points_on_a_grid
+
+
+class EvaluationPredictor(torch.nn.Module):
+ def __init__(
+ self,
+ cotracker_model: CoTracker2,
+ interp_shape: Tuple[int, int] = (384, 512),
+ grid_size: int = 5,
+ local_grid_size: int = 8,
+ single_point: bool = True,
+ n_iters: int = 6,
+ ) -> None:
+ super(EvaluationPredictor, self).__init__()
+ self.grid_size = grid_size
+ self.local_grid_size = local_grid_size
+ self.single_point = single_point
+ self.interp_shape = interp_shape
+ self.n_iters = n_iters
+
+ self.model = cotracker_model
+ self.model.eval()
+
+ def forward(self, video, queries):
+ queries = queries.clone()
+ B, T, C, H, W = video.shape
+ B, N, D = queries.shape
+
+ assert D == 3
+
+ video = video.reshape(B * T, C, H, W)
+ video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
+ video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
+
+ device = video.device
+
+ queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
+ queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)
+
+ if self.single_point:
+ traj_e = torch.zeros((B, T, N, 2), device=device)
+ vis_e = torch.zeros((B, T, N), device=device)
+            for pind in range(N):
+ query = queries[:, pind : pind + 1]
+
+ t = query[0, 0, 0].long()
+
+ traj_e_pind, vis_e_pind = self._process_one_point(video, query)
+ traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
+ vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
+ else:
+ if self.grid_size > 0:
+ xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
+ xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
+ queries = torch.cat([queries, xy], dim=1) #
+
+ traj_e, vis_e, __ = self.model(
+ video=video,
+ queries=queries,
+ iters=self.n_iters,
+ )
+
+ traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
+ traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
+ return traj_e, vis_e
+
+ def _process_one_point(self, video, query):
+ t = query[0, 0, 0].long()
+
+ device = query.device
+ if self.local_grid_size > 0:
+ xy_target = get_points_on_a_grid(
+ self.local_grid_size,
+ (50, 50),
+ [query[0, 0, 2].item(), query[0, 0, 1].item()],
+ )
+
+ xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
+ device
+ ) #
+ query = torch.cat([query, xy_target], dim=1) #
+
+ if self.grid_size > 0:
+ xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
+ xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device) #
+ query = torch.cat([query, xy], dim=1) #
+ # crop the video to start from the queried frame
+ query[0, 0, 0] = 0
+ traj_e_pind, vis_e_pind, __ = self.model(
+ video=video[:, t:], queries=query, iters=self.n_iters
+ )
+
+ return traj_e_pind, vis_e_pind
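+
+
+# Sketch of intended usage (hedged; the checkpoint path, grid size and video
+# shape below are illustrative, not asserted anywhere in this file):
+#
+#     from cotracker.models.build_cotracker import build_cotracker
+#     model = build_cotracker("./checkpoints/cotracker2.pth")
+#     predictor = EvaluationPredictor(model, grid_size=5, single_point=True)
+#     video = torch.rand(1, 24, 3, 256, 256)          # (B, T, 3, H, W)
+#     queries = torch.tensor([[[0.0, 120.0, 80.0]]])  # one (t, x, y) query
+#     tracks, visibilities = predictor(video, queries)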
diff --git a/cotracker/predictor.py b/cotracker/predictor.py
index 067b50d..9778a7e 100644
--- a/cotracker/predictor.py
+++ b/cotracker/predictor.py
@@ -1,275 +1,279 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import torch
-import torch.nn.functional as F
-
-from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
-from cotracker.models.build_cotracker import build_cotracker
-
-
-class CoTrackerPredictor(torch.nn.Module):
- def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
- super().__init__()
- self.support_grid_size = 6
- model = build_cotracker(checkpoint)
- self.interp_shape = model.model_resolution
- print(self.interp_shape)
- self.model = model
- self.model.eval()
-
- @torch.no_grad()
- def forward(
- self,
- video, # (B, T, 3, H, W) Batch_size, time, rgb, height, width
- # input prompt types:
- # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
- # *backward_tracking=True* will compute tracks in both directions.
- # - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
- # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
- # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
- queries: torch.Tensor = None,
- segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
- grid_size: int = 0,
- grid_query_frame: int = 0, # only for dense and regular grid tracks
- backward_tracking: bool = False,
- ):
- if queries is None and grid_size == 0:
- tracks, visibilities = self._compute_dense_tracks(
- video,
- grid_query_frame=grid_query_frame,
- backward_tracking=backward_tracking,
- )
- else:
- tracks, visibilities = self._compute_sparse_tracks(
- video,
- queries,
- segm_mask,
- grid_size,
- add_support_grid=(grid_size == 0 or segm_mask is not None),
- grid_query_frame=grid_query_frame,
- backward_tracking=backward_tracking,
- )
-
- return tracks, visibilities
-
- def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
- *_, H, W = video.shape
- grid_step = W // grid_size
- grid_width = W // grid_step
- grid_height = H // grid_step # set the whole video to grid_size number of grids
- tracks = visibilities = None
- grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
- # (batch_size, grid_number, t,x,y)
- grid_pts[0, :, 0] = grid_query_frame
- # iterate every grid
- for offset in range(grid_step * grid_step):
- print(f"step {offset} / {grid_step * grid_step}")
- ox = offset % grid_step
- oy = offset // grid_step
- # initialize
- # for example
- # grid width = 4, grid height = 4, grid step = 10, ox = 1
- # torch.arange(grid_width) = [0,1,2,3]
- # torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3]
- # torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30]
- # get the location in the image
- grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
- grid_pts[0, :, 2] = (
- torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
- )
- tracks_step, visibilities_step = self._compute_sparse_tracks(
- video=video,
- queries=grid_pts,
- backward_tracking=backward_tracking,
- )
- tracks = smart_cat(tracks, tracks_step, dim=2)
- visibilities = smart_cat(visibilities, visibilities_step, dim=2)
-
- return tracks, visibilities
-
- def _compute_sparse_tracks(
- self,
- video,
- queries,
- segm_mask=None,
- grid_size=0,
- add_support_grid=False,
- grid_query_frame=0,
- backward_tracking=False,
- ):
- B, T, C, H, W = video.shape
-
- video = video.reshape(B * T, C, H, W)
- # ? what is interpolate?
- # 将video插值成interp_shape?
- video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
- video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
-
- if queries is not None:
- B, N, D = queries.shape # batch_size, number of points, (t,x,y)
- assert D == 3
- # query 缩放到( interp_shape - 1 ) / (W - 1)
- # 插完值之后缩放
- queries = queries.clone()
- queries[:, :, 1:] *= queries.new_tensor(
- [
- (self.interp_shape[1] - 1) / (W - 1),
- (self.interp_shape[0] - 1) / (H - 1),
- ]
- )
- # 生成grid
- elif grid_size > 0:
- grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
- if segm_mask is not None:
- segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
- point_mask = segm_mask[0, 0][
- (grid_pts[0, :, 1]).round().long().cpu(),
- (grid_pts[0, :, 0]).round().long().cpu(),
- ].bool()
- grid_pts = grid_pts[:, point_mask]
-
- queries = torch.cat(
- [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
- dim=2,
- ).repeat(B, 1, 1)
-
- # 添加支持点
-
- if add_support_grid:
- grid_pts = get_points_on_a_grid(
- self.support_grid_size, self.interp_shape, device=video.device
- )
- grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
- grid_pts = grid_pts.repeat(B, 1, 1)
- queries = torch.cat([queries, grid_pts], dim=1)
-
- tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
-
- if backward_tracking:
- tracks, visibilities = self._compute_backward_tracks(
- video, queries, tracks, visibilities
- )
- if add_support_grid:
- queries[:, -self.support_grid_size**2 :, 0] = T - 1
- if add_support_grid:
- tracks = tracks[:, :, : -self.support_grid_size**2]
- visibilities = visibilities[:, :, : -self.support_grid_size**2]
- thr = 0.9
- visibilities = visibilities > thr
-
- # correct query-point predictions
- # see https://github.com/facebookresearch/co-tracker/issues/28
-
- # TODO: batchify
- for i in range(len(queries)):
- queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
- arange = torch.arange(0, len(queries_t))
-
- # overwrite the predictions with the query points
- tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
-
- # correct visibilities, the query points should be visible
- visibilities[i, queries_t, arange] = True
-
- tracks *= tracks.new_tensor(
- [(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
- )
- return tracks, visibilities
-
- def _compute_backward_tracks(self, video, queries, tracks, visibilities):
- inv_video = video.flip(1).clone()
- inv_queries = queries.clone()
- inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
-
- inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)
-
- inv_tracks = inv_tracks.flip(1)
- inv_visibilities = inv_visibilities.flip(1)
- arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
-
- mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
-
- tracks[mask] = inv_tracks[mask]
- visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
- return tracks, visibilities
-
-
-class CoTrackerOnlinePredictor(torch.nn.Module):
- def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
- super().__init__()
- self.support_grid_size = 6
- model = build_cotracker(checkpoint)
- self.interp_shape = model.model_resolution
- self.step = model.window_len // 2
- self.model = model
- self.model.eval()
-
- @torch.no_grad()
- def forward(
- self,
- video_chunk,
- is_first_step: bool = False,
- queries: torch.Tensor = None,
- grid_size: int = 10,
- grid_query_frame: int = 0,
- add_support_grid=False,
- ):
- B, T, C, H, W = video_chunk.shape
- # Initialize online video processing and save queried points
- # This needs to be done before processing *each new video*
- if is_first_step:
- self.model.init_video_online_processing()
- if queries is not None:
- B, N, D = queries.shape
- assert D == 3
- queries = queries.clone()
- queries[:, :, 1:] *= queries.new_tensor(
- [
- (self.interp_shape[1] - 1) / (W - 1),
- (self.interp_shape[0] - 1) / (H - 1),
- ]
- )
- elif grid_size > 0:
- grid_pts = get_points_on_a_grid(
- grid_size, self.interp_shape, device=video_chunk.device
- )
- queries = torch.cat(
- [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
- dim=2,
- )
- if add_support_grid:
- grid_pts = get_points_on_a_grid(
- self.support_grid_size, self.interp_shape, device=video_chunk.device
- )
- grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
- queries = torch.cat([queries, grid_pts], dim=1)
- self.queries = queries
- return (None, None)
-
- video_chunk = video_chunk.reshape(B * T, C, H, W)
- video_chunk = F.interpolate(
- video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
- )
- video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
-
- tracks, visibilities, __ = self.model(
- video=video_chunk,
- queries=self.queries,
- iters=6,
- is_online=True,
- )
- thr = 0.9
- return (
- tracks
- * tracks.new_tensor(
- [
- (W - 1) / (self.interp_shape[1] - 1),
- (H - 1) / (self.interp_shape[0] - 1),
- ]
- ),
- visibilities > thr,
- )
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn.functional as F
+
+from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
+from cotracker.models.build_cotracker import build_cotracker
+
+
+class CoTrackerPredictor(torch.nn.Module):
+ def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
+ super().__init__()
+ self.support_grid_size = 6
+ model = build_cotracker(checkpoint)
+ self.interp_shape = model.model_resolution
+ print(self.interp_shape)
+ self.model = model
+ self.model.eval()
+
+ @torch.no_grad()
+ def forward(
+ self,
+ video, # (B, T, 3, H, W) Batch_size, time, rgb, height, width
+ # input prompt types:
+ # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
+ # *backward_tracking=True* will compute tracks in both directions.
+ # - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
+ # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
+ # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
+ queries: torch.Tensor = None,
+ segm_mask: torch.Tensor = None, # Segmentation mask of shape (B, 1, H, W)
+ grid_size: int = 0,
+ grid_query_frame: int = 0, # only for dense and regular grid tracks
+ backward_tracking: bool = False,
+ ):
+ if queries is None and grid_size == 0:
+ tracks, visibilities = self._compute_dense_tracks(
+ video,
+ grid_query_frame=grid_query_frame,
+ backward_tracking=backward_tracking,
+ )
+ else:
+ tracks, visibilities = self._compute_sparse_tracks(
+ video,
+ queries,
+ segm_mask,
+ grid_size,
+ add_support_grid=(grid_size == 0 or segm_mask is not None),
+ grid_query_frame=grid_query_frame,
+ backward_tracking=backward_tracking,
+ )
+
+ return tracks, visibilities
+
+    # Notes / TODO: measure dense-inference time on GPU, compare against RAFT
+    # on GPU, look at visual-effects use cases, and integrate RAFT.
+ def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
+ *_, H, W = video.shape
+ grid_step = W // grid_size
+ grid_width = W // grid_step
+ grid_height = H // grid_step # set the whole video to grid_size number of grids
+ tracks = visibilities = None
+ grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
+ # (batch_size, grid_number, t,x,y)
+ grid_pts[0, :, 0] = grid_query_frame
+ # iterate every grid
+ for offset in range(grid_step * grid_step):
+ print(f"step {offset} / {grid_step * grid_step}")
+ ox = offset % grid_step
+ oy = offset // grid_step
+            # Example: grid_width = 4, grid_height = 3, grid_step = 10, ox = 1
+            # torch.arange(grid_width) = [0,1,2,3]
+            # torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3]
+            # torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30]
+            # torch.arange(grid_width).repeat(grid_height) * grid_step + ox = [1,11,21,31,1,11,21,31,1,11,21,31]
+            # i.e. the x pixel locations of this sub-grid within the image
+ grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
+ grid_pts[0, :, 2] = (
+ torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
+ )
+ tracks_step, visibilities_step = self._compute_sparse_tracks(
+ video=video,
+ queries=grid_pts,
+ backward_tracking=backward_tracking,
+ )
+ tracks = smart_cat(tracks, tracks_step, dim=2)
+ visibilities = smart_cat(visibilities, visibilities_step, dim=2)
+
+ return tracks, visibilities
+
+ def _compute_sparse_tracks(
+ self,
+ video,
+ queries,
+ segm_mask=None,
+ grid_size=0,
+ add_support_grid=False,
+ grid_query_frame=0,
+ backward_tracking=False,
+ ):
+ B, T, C, H, W = video.shape
+
+ video = video.reshape(B * T, C, H, W)
+        # Resize the video to the model resolution `interp_shape` via bilinear interpolation.
+ video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
+ video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
+
+ if queries is not None:
+ B, N, D = queries.shape # batch_size, number of points, (t,x,y)
+ assert D == 3
+            # Rescale the query (x, y) coordinates by (interp_shape - 1) / (original size - 1)
+            # so they match the resized video.
+ queries = queries.clone()
+ queries[:, :, 1:] *= queries.new_tensor(
+ [
+ (self.interp_shape[1] - 1) / (W - 1),
+ (self.interp_shape[0] - 1) / (H - 1),
+ ]
+ )
+        # Otherwise, build queries from a regular grid (optionally restricted by segm_mask).
+ elif grid_size > 0:
+ grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
+ if segm_mask is not None:
+ segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
+ point_mask = segm_mask[0, 0][
+ (grid_pts[0, :, 1]).round().long().cpu(),
+ (grid_pts[0, :, 0]).round().long().cpu(),
+ ].bool()
+ grid_pts = grid_pts[:, point_mask]
+
+ queries = torch.cat(
+ [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
+ dim=2,
+ ).repeat(B, 1, 1)
+
+        # Optionally append a support grid of auxiliary query points.
+
+ if add_support_grid:
+ grid_pts = get_points_on_a_grid(
+ self.support_grid_size, self.interp_shape, device=video.device
+ )
+ grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
+ grid_pts = grid_pts.repeat(B, 1, 1)
+ queries = torch.cat([queries, grid_pts], dim=1)
+
+ tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)
+
+ if backward_tracking:
+ tracks, visibilities = self._compute_backward_tracks(
+ video, queries, tracks, visibilities
+ )
+ if add_support_grid:
+ queries[:, -self.support_grid_size**2 :, 0] = T - 1
+ if add_support_grid:
+ tracks = tracks[:, :, : -self.support_grid_size**2]
+ visibilities = visibilities[:, :, : -self.support_grid_size**2]
+ thr = 0.9
+ visibilities = visibilities > thr
+
+ # correct query-point predictions
+ # see https://github.com/facebookresearch/co-tracker/issues/28
+
+ # TODO: batchify
+ for i in range(len(queries)):
+ queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
+ arange = torch.arange(0, len(queries_t))
+
+ # overwrite the predictions with the query points
+ tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]
+
+ # correct visibilities, the query points should be visible
+ visibilities[i, queries_t, arange] = True
+
+ tracks *= tracks.new_tensor(
+ [(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
+ )
+ return tracks, visibilities
+
+ def _compute_backward_tracks(self, video, queries, tracks, visibilities):
+ inv_video = video.flip(1).clone()
+ inv_queries = queries.clone()
+ inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1
+
+ inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)
+
+ inv_tracks = inv_tracks.flip(1)
+ inv_visibilities = inv_visibilities.flip(1)
+ arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
+
+ mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
+
+ tracks[mask] = inv_tracks[mask]
+ visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
+ return tracks, visibilities
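+
+
+# Sketch of typical offline usage (hedged; the video shape and grid size are
+# illustrative, and the default checkpoint path from __init__ is assumed):
+#
+#     predictor = CoTrackerPredictor()
+#     video = torch.rand(1, 48, 3, 384, 512)               # (B, T, 3, H, W)
+#     tracks, visibility = predictor(video, grid_size=10)  # 10x10 grid queried at frame 0
+#     # tracks: (B, T, N, 2) pixel coordinates; visibility: (B, T, N) booleans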
+
+
+class CoTrackerOnlinePredictor(torch.nn.Module):
+ def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
+ super().__init__()
+ self.support_grid_size = 6
+ model = build_cotracker(checkpoint)
+ self.interp_shape = model.model_resolution
+ self.step = model.window_len // 2
+ self.model = model
+ self.model.eval()
+
+ @torch.no_grad()
+ def forward(
+ self,
+ video_chunk,
+ is_first_step: bool = False,
+ queries: torch.Tensor = None,
+ grid_size: int = 10,
+ grid_query_frame: int = 0,
+ add_support_grid=False,
+ ):
+ B, T, C, H, W = video_chunk.shape
+ # Initialize online video processing and save queried points
+ # This needs to be done before processing *each new video*
+ if is_first_step:
+ self.model.init_video_online_processing()
+ if queries is not None:
+ B, N, D = queries.shape
+ assert D == 3
+ queries = queries.clone()
+ queries[:, :, 1:] *= queries.new_tensor(
+ [
+ (self.interp_shape[1] - 1) / (W - 1),
+ (self.interp_shape[0] - 1) / (H - 1),
+ ]
+ )
+ elif grid_size > 0:
+ grid_pts = get_points_on_a_grid(
+ grid_size, self.interp_shape, device=video_chunk.device
+ )
+ queries = torch.cat(
+ [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
+ dim=2,
+ )
+ if add_support_grid:
+ grid_pts = get_points_on_a_grid(
+ self.support_grid_size, self.interp_shape, device=video_chunk.device
+ )
+ grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
+ queries = torch.cat([queries, grid_pts], dim=1)
+ self.queries = queries
+ return (None, None)
+
+ video_chunk = video_chunk.reshape(B * T, C, H, W)
+ video_chunk = F.interpolate(
+ video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
+ )
+ video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])
+
+ tracks, visibilities, __ = self.model(
+ video=video_chunk,
+ queries=self.queries,
+ iters=6,
+ is_online=True,
+ )
+ thr = 0.9
+ return (
+ tracks
+ * tracks.new_tensor(
+ [
+ (W - 1) / (self.interp_shape[1] - 1),
+ (H - 1) / (self.interp_shape[0] - 1),
+ ]
+ ),
+ visibilities > thr,
+ )
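+
+
+# Sketch of the online loop (hedged; chunk sizes and grid size are illustrative,
+# and `video` is assumed to be a (B, T, 3, H, W) float tensor):
+#
+#     predictor = CoTrackerOnlinePredictor()
+#     predictor(video_chunk=video[:, : predictor.step * 2], is_first_step=True, grid_size=10)
+#     for ind in range(0, video.shape[1] - predictor.step, predictor.step):
+#         tracks, visibility = predictor(
+#             video_chunk=video[:, ind : ind + predictor.step * 2]
+#         )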
diff --git a/cotracker/utils/__init__.py b/cotracker/utils/__init__.py
index 5277f46..4547e07 100644
--- a/cotracker/utils/__init__.py
+++ b/cotracker/utils/__init__.py
@@ -1,5 +1,5 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
diff --git a/cotracker/utils/__pycache__/__init__.cpython-38.pyc b/cotracker/utils/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000..2e14358
Binary files /dev/null and b/cotracker/utils/__pycache__/__init__.cpython-38.pyc differ
diff --git a/cotracker/utils/__pycache__/__init__.cpython-39.pyc b/cotracker/utils/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000..57b947a
Binary files /dev/null and b/cotracker/utils/__pycache__/__init__.cpython-39.pyc differ
diff --git a/cotracker/utils/__pycache__/visualizer.cpython-38.pyc b/cotracker/utils/__pycache__/visualizer.cpython-38.pyc
new file mode 100644
index 0000000..7f8f759
Binary files /dev/null and b/cotracker/utils/__pycache__/visualizer.cpython-38.pyc differ
diff --git a/cotracker/utils/__pycache__/visualizer.cpython-39.pyc b/cotracker/utils/__pycache__/visualizer.cpython-39.pyc
new file mode 100644
index 0000000..f029a3e
Binary files /dev/null and b/cotracker/utils/__pycache__/visualizer.cpython-39.pyc differ
diff --git a/cotracker/utils/visualizer.py b/cotracker/utils/visualizer.py
index 88287c3..22ba43a 100644
--- a/cotracker/utils/visualizer.py
+++ b/cotracker/utils/visualizer.py
@@ -1,343 +1,343 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-import os
-import numpy as np
-import imageio
-import torch
-
-from matplotlib import cm
-import torch.nn.functional as F
-import torchvision.transforms as transforms
-import matplotlib.pyplot as plt
-from PIL import Image, ImageDraw
-
-
-def read_video_from_path(path):
-    try:
-        reader = imageio.get_reader(path)
-    except Exception as e:
-        print("Error opening video file: ", e)
-        return None
-    frames = []
-    for i, im in enumerate(reader):
-        frames.append(np.array(im))
-    return np.stack(frames)
-
-
-def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
-    # Create a draw object
-    draw = ImageDraw.Draw(rgb)
-    # Calculate the bounding box of the circle
-    left_up_point = (coord[0] - radius, coord[1] - radius)
-    right_down_point = (coord[0] + radius, coord[1] + radius)
-    # Draw the circle
-    draw.ellipse(
-        [left_up_point, right_down_point],
-        fill=tuple(color) if visible else None,
-        outline=tuple(color),
-    )
-    return rgb
-
-
-def draw_line(rgb, coord_y, coord_x, color, linewidth):
-    draw = ImageDraw.Draw(rgb)
-    draw.line(
-        (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
-        fill=tuple(color),
-        width=linewidth,
-    )
-    return rgb
-
-
-def add_weighted(rgb, alpha, original, beta, gamma):
-    return (rgb * alpha + original * beta + gamma).astype("uint8")
-
-
-class Visualizer:
-    def __init__(
-        self,
-        save_dir: str = "./results",
-        grayscale: bool = False,
-        pad_value: int = 0,
-        fps: int = 10,
-        mode: str = "rainbow", # 'cool', 'optical_flow'
-        linewidth: int = 2,
-        show_first_frame: int = 10,
-        tracks_leave_trace: int = 0, # -1 for infinite
-    ):
-        self.mode = mode
-        self.save_dir = save_dir
-        if mode == "rainbow":
-            self.color_map = cm.get_cmap("gist_rainbow")
-        elif mode == "cool":
-            self.color_map = cm.get_cmap(mode)
-        self.show_first_frame = show_first_frame
-        self.grayscale = grayscale
-        self.tracks_leave_trace = tracks_leave_trace
-        self.pad_value = pad_value
-        self.linewidth = linewidth
-        self.fps = fps
-
-    def visualize(
-        self,
-        video: torch.Tensor, # (B,T,C,H,W)
-        tracks: torch.Tensor, # (B,T,N,2)
-        visibility: torch.Tensor = None, # (B, T, N, 1) bool
-        gt_tracks: torch.Tensor = None, # (B,T,N,2)
-        segm_mask: torch.Tensor = None, # (B,1,H,W)
-        filename: str = "video",
-        writer=None, # tensorboard Summary Writer, used for visualization during training
-        step: int = 0,
-        query_frame: int = 0,
-        save_video: bool = True,
-        compensate_for_camera_motion: bool = False,
-    ):
-        if compensate_for_camera_motion:
-            assert segm_mask is not None
-        if segm_mask is not None:
-            coords = tracks[0, query_frame].round().long()
-            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
-
-        video = F.pad(
-            video,
-            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
-            "constant",
-            255,
-        )
-        tracks = tracks + self.pad_value
-
-        if self.grayscale:
-            transform = transforms.Grayscale()
-            video = transform(video)
-            video = video.repeat(1, 1, 3, 1, 1)
-
-        res_video = self.draw_tracks_on_video(
-            video=video,
-            tracks=tracks,
-            visibility=visibility,
-            segm_mask=segm_mask,
-            gt_tracks=gt_tracks,
-            query_frame=query_frame,
-            compensate_for_camera_motion=compensate_for_camera_motion,
-        )
-        if save_video:
-            self.save_video(res_video, filename=filename, writer=writer, step=step)
-        return res_video
-
-    def save_video(self, video, filename, writer=None, step=0):
-        if writer is not None:
-            writer.add_video(
-                filename,
-                video.to(torch.uint8),
-                global_step=step,
-                fps=self.fps,
-            )
-        else:
-            os.makedirs(self.save_dir, exist_ok=True)
-            wide_list = list(video.unbind(1))
-            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
-
-            # Prepare the video file path
-            save_path = os.path.join(self.save_dir, f"{filename}.mp4")
-
-            # Create a writer object
-            video_writer = imageio.get_writer(save_path, fps=self.fps)
-
-            # Write frames to the video file
-            for frame in wide_list[2:-1]:
-                video_writer.append_data(frame)
-
-            video_writer.close()
-
-            print(f"Video saved to {save_path}")
-
-    def draw_tracks_on_video(
-        self,
-        video: torch.Tensor,
-        tracks: torch.Tensor,
-        visibility: torch.Tensor = None,
-        segm_mask: torch.Tensor = None,
-        gt_tracks=None,
-        query_frame: int = 0,
-        compensate_for_camera_motion=False,
-    ):
-        B, T, C, H, W = video.shape
-        _, _, N, D = tracks.shape
-
-        assert D == 2
-        assert C == 3
-        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
-        tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
-        if gt_tracks is not None:
-            gt_tracks = gt_tracks[0].detach().cpu().numpy()
-
-        res_video = []
-
-        # process input video
-        for rgb in video:
-            res_video.append(rgb.copy())
-        vector_colors = np.zeros((T, N, 3))
-
-        if self.mode == "optical_flow":
-            import flow_vis
-
-            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
-        elif segm_mask is None:
-            if self.mode == "rainbow":
-                y_min, y_max = (
-                    tracks[query_frame, :, 1].min(),
-                    tracks[query_frame, :, 1].max(),
-                )
-                norm = plt.Normalize(y_min, y_max)
-                for n in range(N):
-                    color = self.color_map(norm(tracks[query_frame, n, 1]))
-                    color = np.array(color[:3])[None] * 255
-                    vector_colors[:, n] = np.repeat(color, T, axis=0)
-            else:
-                # color changes with time
-                for t in range(T):
-                    color = np.array(self.color_map(t / T)[:3])[None] * 255
-                    vector_colors[t] = np.repeat(color, N, axis=0)
-        else:
-            if self.mode == "rainbow":
-                vector_colors[:, segm_mask <= 0, :] = 255
-
-                y_min, y_max = (
-                    tracks[0, segm_mask > 0, 1].min(),
-                    tracks[0, segm_mask > 0, 1].max(),
-                )
-                norm = plt.Normalize(y_min, y_max)
-                for n in range(N):
-                    if segm_mask[n] > 0:
-                        color = self.color_map(norm(tracks[0, n, 1]))
-                        color = np.array(color[:3])[None] * 255
-                        vector_colors[:, n] = np.repeat(color, T, axis=0)
-
-            else:
-                # color changes with segm class
-                segm_mask = segm_mask.cpu()
-                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
-                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
-                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
-                vector_colors = np.repeat(color[None], T, axis=0)
-
-        # draw tracks
-        if self.tracks_leave_trace != 0:
-            for t in range(query_frame + 1, T):
-                first_ind = (
-                    max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
-                )
-                curr_tracks = tracks[first_ind : t + 1]
-                curr_colors = vector_colors[first_ind : t + 1]
-                if compensate_for_camera_motion:
-                    diff = (
-                        tracks[first_ind : t + 1, segm_mask <= 0]
-                        - tracks[t : t + 1, segm_mask <= 0]
-                    ).mean(1)[:, None]
-
-                    curr_tracks = curr_tracks - diff
-                    curr_tracks = curr_tracks[:, segm_mask > 0]
-                    curr_colors = curr_colors[:, segm_mask > 0]
-
-                res_video[t] = self._draw_pred_tracks(
-                    res_video[t],
-                    curr_tracks,
-                    curr_colors,
-                )
-                if gt_tracks is not None:
-                    res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
-
-        # draw points
-        for t in range(query_frame, T):
-            img = Image.fromarray(np.uint8(res_video[t]))
-            for i in range(N):
-                coord = (tracks[t, i, 0], tracks[t, i, 1])
-                visibile = True
-                if visibility is not None:
-                    visibile = visibility[0, t, i]
-                if coord[0] != 0 and coord[1] != 0:
-                    if not compensate_for_camera_motion or (
-                        compensate_for_camera_motion and segm_mask[i] > 0
-                    ):
-                        img = draw_circle(
-                            img,
-                            coord=coord,
-                            radius=int(self.linewidth * 2),
-                            color=vector_colors[t, i].astype(int),
-                            visible=visibile,
-                        )
-            res_video[t] = np.array(img)
-
-        # construct the final rgb sequence
-        if self.show_first_frame > 0:
-            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
-        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
-
-    def _draw_pred_tracks(
-        self,
-        rgb: np.ndarray, # H x W x 3
-        tracks: np.ndarray, # T x 2
-        vector_colors: np.ndarray,
-        alpha: float = 0.5,
-    ):
-        T, N, _ = tracks.shape
-        rgb = Image.fromarray(np.uint8(rgb))
-        for s in range(T - 1):
-            vector_color = vector_colors[s]
-            original = rgb.copy()
-            alpha = (s / T) ** 2
-            for i in range(N):
-                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
-                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
-                if coord_y[0] != 0 and coord_y[1] != 0:
-                    rgb = draw_line(
-                        rgb,
-                        coord_y,
-                        coord_x,
-                        vector_color[i].astype(int),
-                        self.linewidth,
-                    )
-            if self.tracks_leave_trace > 0:
-                rgb = Image.fromarray(
-                    np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
-                )
-        rgb = np.array(rgb)
-        return rgb
-
-    def _draw_gt_tracks(
-        self,
-        rgb: np.ndarray, # H x W x 3,
-        gt_tracks: np.ndarray, # T x 2
-    ):
-        T, N, _ = gt_tracks.shape
-        color = np.array((211, 0, 0))
-        rgb = Image.fromarray(np.uint8(rgb))
-        for t in range(T):
-            for i in range(N):
-                gt_tracks = gt_tracks[t][i]
-                # draw a red cross
-                if gt_tracks[0] > 0 and gt_tracks[1] > 0:
-                    length = self.linewidth * 3
-                    coord_y = (int(gt_tracks[0]) + length, int(gt_tracks[1]) + length)
-                    coord_x = (int(gt_tracks[0]) - length, int(gt_tracks[1]) - length)
-                    rgb = draw_line(
-                        rgb,
-                        coord_y,
-                        coord_x,
-                        color,
-                        self.linewidth,
-                    )
-                    coord_y = (int(gt_tracks[0]) - length, int(gt_tracks[1]) + length)
-                    coord_x = (int(gt_tracks[0]) + length, int(gt_tracks[1]) - length)
-                    rgb = draw_line(
-                        rgb,
-                        coord_y,
-                        coord_x,
-                        color,
-                        self.linewidth,
-                    )
-        rgb = np.array(rgb)
-        return rgb
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+import os
+import numpy as np
+import imageio
+import torch
+
+from matplotlib import cm
+import torch.nn.functional as F
+import torchvision.transforms as transforms
+import matplotlib.pyplot as plt
+from PIL import Image, ImageDraw
+
+
+def read_video_from_path(path):
+ try:
+ reader = imageio.get_reader(path)
+ except Exception as e:
+ print("Error opening video file: ", e)
+ return None
+ frames = []
+ for i, im in enumerate(reader):
+ frames.append(np.array(im))
+ return np.stack(frames)
+
+
+def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
+ # Create a draw object
+ draw = ImageDraw.Draw(rgb)
+ # Calculate the bounding box of the circle
+ left_up_point = (coord[0] - radius, coord[1] - radius)
+ right_down_point = (coord[0] + radius, coord[1] + radius)
+ # Draw the circle
+ draw.ellipse(
+ [left_up_point, right_down_point],
+ fill=tuple(color) if visible else None,
+ outline=tuple(color),
+ )
+ return rgb
+
+
+def draw_line(rgb, coord_y, coord_x, color, linewidth):
+ draw = ImageDraw.Draw(rgb)
+ draw.line(
+ (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
+ fill=tuple(color),
+ width=linewidth,
+ )
+ return rgb
+
+
+def add_weighted(rgb, alpha, original, beta, gamma):
+ return (rgb * alpha + original * beta + gamma).astype("uint8")
+
+
+class Visualizer:
+ def __init__(
+ self,
+ save_dir: str = "./results",
+ grayscale: bool = False,
+ pad_value: int = 0,
+ fps: int = 10,
+ mode: str = "rainbow", # 'cool', 'optical_flow'
+ linewidth: int = 2,
+ show_first_frame: int = 10,
+ tracks_leave_trace: int = 0, # -1 for infinite
+ ):
+ self.mode = mode
+ self.save_dir = save_dir
+ if mode == "rainbow":
+ self.color_map = cm.get_cmap("gist_rainbow")
+ elif mode == "cool":
+ self.color_map = cm.get_cmap(mode)
+ self.show_first_frame = show_first_frame
+ self.grayscale = grayscale
+ self.tracks_leave_trace = tracks_leave_trace
+ self.pad_value = pad_value
+ self.linewidth = linewidth
+ self.fps = fps
+
+ def visualize(
+ self,
+ video: torch.Tensor, # (B,T,C,H,W)
+ tracks: torch.Tensor, # (B,T,N,2)
+ visibility: torch.Tensor = None, # (B, T, N, 1) bool
+ gt_tracks: torch.Tensor = None, # (B,T,N,2)
+ segm_mask: torch.Tensor = None, # (B,1,H,W)
+ filename: str = "video",
+ writer=None, # tensorboard Summary Writer, used for visualization during training
+ step: int = 0,
+ query_frame: int = 0,
+ save_video: bool = True,
+ compensate_for_camera_motion: bool = False,
+ ):
+ if compensate_for_camera_motion:
+ assert segm_mask is not None
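+        # reduce the (B,1,H,W) segmentation mask to one label per track by sampling it at the query-frame track coordinates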
+ if segm_mask is not None:
+ coords = tracks[0, query_frame].round().long()
+ segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()
+
+ video = F.pad(
+ video,
+ (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
+ "constant",
+ 255,
+ )
+ tracks = tracks + self.pad_value
+
+ if self.grayscale:
+ transform = transforms.Grayscale()
+ video = transform(video)
+ video = video.repeat(1, 1, 3, 1, 1)
+
+ res_video = self.draw_tracks_on_video(
+ video=video,
+ tracks=tracks,
+ visibility=visibility,
+ segm_mask=segm_mask,
+ gt_tracks=gt_tracks,
+ query_frame=query_frame,
+ compensate_for_camera_motion=compensate_for_camera_motion,
+ )
+ if save_video:
+ self.save_video(res_video, filename=filename, writer=writer, step=step)
+ return res_video
+
+ def save_video(self, video, filename, writer=None, step=0):
+ if writer is not None:
+ writer.add_video(
+ filename,
+ video.to(torch.uint8),
+ global_step=step,
+ fps=self.fps,
+ )
+ else:
+ os.makedirs(self.save_dir, exist_ok=True)
+ wide_list = list(video.unbind(1))
+ wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]
+
+ # Prepare the video file path
+ save_path = os.path.join(self.save_dir, f"{filename}.mp4")
+
+ # Create a writer object
+ video_writer = imageio.get_writer(save_path, fps=self.fps)
+
+ # Write frames to the video file
+ for frame in wide_list[2:-1]:
+ video_writer.append_data(frame)
+
+ video_writer.close()
+
+ print(f"Video saved to {save_path}")
+
+ def draw_tracks_on_video(
+ self,
+ video: torch.Tensor,
+ tracks: torch.Tensor,
+ visibility: torch.Tensor = None,
+ segm_mask: torch.Tensor = None,
+ gt_tracks=None,
+ query_frame: int = 0,
+ compensate_for_camera_motion=False,
+ ):
+ B, T, C, H, W = video.shape
+ _, _, N, D = tracks.shape
+
+ assert D == 2
+ assert C == 3
+ video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy() # S, H, W, C
+ tracks = tracks[0].long().detach().cpu().numpy() # S, N, 2
+ if gt_tracks is not None:
+ gt_tracks = gt_tracks[0].detach().cpu().numpy()
+
+ res_video = []
+
+ # process input video
+ for rgb in video:
+ res_video.append(rgb.copy())
+ vector_colors = np.zeros((T, N, 3))
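+        # assign each track a color, depending on the visualization mode and on whether a segmentation mask is given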
+
+ if self.mode == "optical_flow":
+ import flow_vis
+
+ vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
+ elif segm_mask is None:
+ if self.mode == "rainbow":
+ y_min, y_max = (
+ tracks[query_frame, :, 1].min(),
+ tracks[query_frame, :, 1].max(),
+ )
+ norm = plt.Normalize(y_min, y_max)
+ for n in range(N):
+ color = self.color_map(norm(tracks[query_frame, n, 1]))
+ color = np.array(color[:3])[None] * 255
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
+ else:
+ # color changes with time
+ for t in range(T):
+ color = np.array(self.color_map(t / T)[:3])[None] * 255
+ vector_colors[t] = np.repeat(color, N, axis=0)
+ else:
+ if self.mode == "rainbow":
+ vector_colors[:, segm_mask <= 0, :] = 255
+
+ y_min, y_max = (
+ tracks[0, segm_mask > 0, 1].min(),
+ tracks[0, segm_mask > 0, 1].max(),
+ )
+ norm = plt.Normalize(y_min, y_max)
+ for n in range(N):
+ if segm_mask[n] > 0:
+ color = self.color_map(norm(tracks[0, n, 1]))
+ color = np.array(color[:3])[None] * 255
+ vector_colors[:, n] = np.repeat(color, T, axis=0)
+
+ else:
+ # color changes with segm class
+ segm_mask = segm_mask.cpu()
+ color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
+ color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
+ color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
+ vector_colors = np.repeat(color[None], T, axis=0)
+
+ # draw tracks
+ if self.tracks_leave_trace != 0:
+ for t in range(query_frame + 1, T):
+ first_ind = (
+ max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
+ )
+ curr_tracks = tracks[first_ind : t + 1]
+ curr_colors = vector_colors[first_ind : t + 1]
+ if compensate_for_camera_motion:
+ diff = (
+ tracks[first_ind : t + 1, segm_mask <= 0]
+ - tracks[t : t + 1, segm_mask <= 0]
+ ).mean(1)[:, None]
+
+ curr_tracks = curr_tracks - diff
+ curr_tracks = curr_tracks[:, segm_mask > 0]
+ curr_colors = curr_colors[:, segm_mask > 0]
+
+ res_video[t] = self._draw_pred_tracks(
+ res_video[t],
+ curr_tracks,
+ curr_colors,
+ )
+ if gt_tracks is not None:
+ res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
+
+ # draw points
+ for t in range(query_frame, T):
+ img = Image.fromarray(np.uint8(res_video[t]))
+ for i in range(N):
+ coord = (tracks[t, i, 0], tracks[t, i, 1])
+                visible = True
+ if visibility is not None:
+                    visible = visibility[0, t, i]
+ if coord[0] != 0 and coord[1] != 0:
+ if not compensate_for_camera_motion or (
+ compensate_for_camera_motion and segm_mask[i] > 0
+ ):
+ img = draw_circle(
+ img,
+ coord=coord,
+ radius=int(self.linewidth * 2),
+ color=vector_colors[t, i].astype(int),
+                            visible=visible,
+ )
+ res_video[t] = np.array(img)
+
+ # construct the final rgb sequence
+ if self.show_first_frame > 0:
+ res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
+ return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()
+
+ def _draw_pred_tracks(
+ self,
+ rgb: np.ndarray, # H x W x 3
+        tracks: np.ndarray, # T x N x 2
+ vector_colors: np.ndarray,
+ alpha: float = 0.5,
+ ):
+ T, N, _ = tracks.shape
+ rgb = Image.fromarray(np.uint8(rgb))
+ for s in range(T - 1):
+ vector_color = vector_colors[s]
+ original = rgb.copy()
+ alpha = (s / T) ** 2
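+            # blending weight for the fading trace; recomputed every segment, so the alpha argument is effectively overridden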
+ for i in range(N):
+ coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
+ coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
+ if coord_y[0] != 0 and coord_y[1] != 0:
+ rgb = draw_line(
+ rgb,
+ coord_y,
+ coord_x,
+ vector_color[i].astype(int),
+ self.linewidth,
+ )
+ if self.tracks_leave_trace > 0:
+ rgb = Image.fromarray(
+ np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
+ )
+ rgb = np.array(rgb)
+ return rgb
+
+ def _draw_gt_tracks(
+ self,
+        rgb: np.ndarray, # H x W x 3
+        gt_tracks: np.ndarray, # T x N x 2
+ ):
+ T, N, _ = gt_tracks.shape
+ color = np.array((211, 0, 0))
+ rgb = Image.fromarray(np.uint8(rgb))
+ for t in range(T):
+ for i in range(N):
+                gt_track = gt_tracks[t][i]
+                # draw a red cross
+                if gt_track[0] > 0 and gt_track[1] > 0:
+                    length = self.linewidth * 3
+                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
+                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
+                    rgb = draw_line(
+                        rgb,
+                        coord_y,
+                        coord_x,
+                        color,
+                        self.linewidth,
+                    )
+                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
+                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
+ rgb = draw_line(
+ rgb,
+ coord_y,
+ coord_x,
+ color,
+ self.linewidth,
+ )
+ rgb = np.array(rgb)
+ return rgb
diff --git a/cotracker/version.py b/cotracker/version.py
index 4bdf9b4..d1cdb8f 100644
--- a/cotracker/version.py
+++ b/cotracker/version.py
@@ -1,8 +1,8 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-
-__version__ = "2.0.0"
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+__version__ = "2.0.0"
diff --git a/notebooks/demo.ipynb b/notebooks/demo.ipynb
index 2a09fef..0a3dcfe 100644
--- a/notebooks/demo.ipynb
+++ b/notebooks/demo.ipynb
@@ -54,21 +54,34 @@
"metadata": {},
"outputs": [],
"source": [
- "!git clone https://github.com/facebookresearch/co-tracker\n",
- "%cd co-tracker\n",
- "!pip install -e .\n",
- "!pip install opencv-python einops timm matplotlib moviepy flow_vis\n",
- "!mkdir checkpoints\n",
- "%cd checkpoints\n",
- "!wget https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth"
+ "# !git clone https://github.com/facebookresearch/co-tracker\n",
+ "# %cd co-tracker\n",
+ "# !pip install -e .\n",
+ "# !pip install opencv-python einops timm matplotlib moviepy flow_vis\n",
+ "# !mkdir checkpoints\n",
+ "# %cd checkpoints\n",
+ "# !wget https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "1745a859-71d4-4ec3-8ef3-027cabe786d4",
- "metadata": {},
- "outputs": [],
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-29T20:52:14.487553Z",
+ "start_time": "2024-07-29T20:52:12.423999Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "/mnt/d/cotracker\n"
+ ]
+ }
+ ],
"source": [
"%cd ..\n",
"import os\n",
@@ -79,6 +92,30 @@
"from IPython.display import HTML"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "44342f62abc0ec1e",
+ "metadata": {
+ "ExecuteTime": {
+ "end_time": "2024-07-29T20:52:31.688043Z",
+ "start_time": "2024-07-29T20:52:31.668043Z"
+ }
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "CUDA available\n"
+ ]
+ }
+ ],
+ "source": [
+ "if torch.cuda.is_available():\n",
+ " print('CUDA available')"
+ ]
+ },
{
"cell_type": "markdown",
"id": "7894bd2d-2099-46fa-8286-f0c56298ecd1",
@@ -89,31 +126,31 @@
},
{
"cell_type": "code",
- "execution_count": 3,
+ "execution_count": 4,
"id": "f1f9ca4d-951e-49d2-8844-91f7bcadfecd",
"metadata": {},
"outputs": [],
"source": [
- "video = read_video_from_path('./assets/apple.mp4')\n",
+ "video = read_video_from_path('./assets/F1_shorts.mp4')\n",
"video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()"
]
},
{
"cell_type": "code",
- "execution_count": 4,
+ "execution_count": 5,
"id": "fb4c2e9d-0e85-4c10-81a2-827d0759bf87",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
- ""
+ "