add some comments
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
BIN  cotracker/__pycache__/__init__.cpython-38.pyc  (new file, binary not shown)
BIN  cotracker/__pycache__/__init__.cpython-39.pyc  (new file, binary not shown)
BIN  cotracker/__pycache__/predictor.cpython-38.pyc  (new file, binary not shown)
BIN  cotracker/__pycache__/predictor.cpython-39.pyc  (new file, binary not shown)
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
@@ -1,166 +1,166 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import json
import dataclasses
import numpy as np
from dataclasses import Field, MISSING
from typing import IO, TypeVar, Type, get_args, get_origin, Union, Any, Tuple

_X = TypeVar("_X")


def load_dataclass(f: IO, cls: Type[_X], binary: bool = False) -> _X:
    """
    Loads to a @dataclass or collection hierarchy including dataclasses
    from a json recursively.
    Call it like load_dataclass(f, typing.List[FrameAnnotationAnnotation]).
    raises KeyError if json has keys not mapping to the dataclass fields.

    Args:
        f: Either a path to a file, or a file opened for writing.
        cls: The class of the loaded dataclass.
        binary: Set to True if `f` is a file handle, else False.
    """
    if binary:
        asdict = json.loads(f.read().decode("utf8"))
    else:
        asdict = json.load(f)

    # in the list case, run a faster "vectorized" version
    cls = get_args(cls)[0]
    res = list(_dataclass_list_from_dict_list(asdict, cls))

    return res


def _resolve_optional(type_: Any) -> Tuple[bool, Any]:
    """Check whether `type_` is equivalent to `typing.Optional[T]` for some T."""
    if get_origin(type_) is Union:
        args = get_args(type_)
        if len(args) == 2 and args[1] == type(None):  # noqa E721
            return True, args[0]
    if type_ is Any:
        return True, Any

    return False, type_


def _unwrap_type(tp):
    # strips Optional wrapper, if any
    if get_origin(tp) is Union:
        args = get_args(tp)
        if len(args) == 2 and any(a is type(None) for a in args):  # noqa: E721
            # this is typing.Optional
            return args[0] if args[1] is type(None) else args[1]  # noqa: E721
    return tp


def _get_dataclass_field_default(field: Field) -> Any:
    if field.default_factory is not MISSING:
        # pyre-fixme[29]: `Union[dataclasses._MISSING_TYPE,
        #  dataclasses._DefaultFactory[typing.Any]]` is not a function.
        return field.default_factory()
    elif field.default is not MISSING:
        return field.default
    else:
        return None


def _dataclass_list_from_dict_list(dlist, typeannot):
    """
    Vectorised version of `_dataclass_from_dict`.
    The output should be equivalent to
    `[_dataclass_from_dict(d, typeannot) for d in dlist]`.

    Args:
        dlist: list of objects to convert.
        typeannot: type of each of those objects.
    Returns:
        iterator or list over converted objects of the same length as `dlist`.

    Raises:
        ValueError: it assumes the objects have None's in consistent places across
            objects, otherwise it would ignore some values. This generally holds for
            auto-generated annotations, but otherwise use `_dataclass_from_dict`.
    """

    cls = get_origin(typeannot) or typeannot

    if typeannot is Any:
        return dlist
    if all(obj is None for obj in dlist):  # 1st recursion base: all None nodes
        return dlist
    if any(obj is None for obj in dlist):
        # filter out Nones and recurse on the resulting list
        idx_notnone = [(i, obj) for i, obj in enumerate(dlist) if obj is not None]
        idx, notnone = zip(*idx_notnone)
        converted = _dataclass_list_from_dict_list(notnone, typeannot)
        res = [None] * len(dlist)
        for i, obj in zip(idx, converted):
            res[i] = obj
        return res

    is_optional, contained_type = _resolve_optional(typeannot)
    if is_optional:
        return _dataclass_list_from_dict_list(dlist, contained_type)

    # otherwise, we dispatch by the type of the provided annotation to convert to
    if issubclass(cls, tuple) and hasattr(cls, "_fields"):  # namedtuple
        # For namedtuple, call the function recursively on the lists of corresponding keys
        types = cls.__annotations__.values()
        dlist_T = zip(*dlist)
        res_T = [
            _dataclass_list_from_dict_list(key_list, tp) for key_list, tp in zip(dlist_T, types)
        ]
        return [cls(*converted_as_tuple) for converted_as_tuple in zip(*res_T)]
    elif issubclass(cls, (list, tuple)):
        # For list/tuple, call the function recursively on the lists of corresponding positions
        types = get_args(typeannot)
        if len(types) == 1:  # probably List; replicate for all items
            types = types * len(dlist[0])
        dlist_T = zip(*dlist)
        res_T = (
            _dataclass_list_from_dict_list(pos_list, tp) for pos_list, tp in zip(dlist_T, types)
        )
        if issubclass(cls, tuple):
            return list(zip(*res_T))
        else:
            return [cls(converted_as_tuple) for converted_as_tuple in zip(*res_T)]
    elif issubclass(cls, dict):
        # For the dictionary, call the function recursively on concatenated keys and vertices
        key_t, val_t = get_args(typeannot)
        all_keys_res = _dataclass_list_from_dict_list(
            [k for obj in dlist for k in obj.keys()], key_t
        )
        all_vals_res = _dataclass_list_from_dict_list(
            [k for obj in dlist for k in obj.values()], val_t
        )
        indices = np.cumsum([len(obj) for obj in dlist])
        assert indices[-1] == len(all_keys_res)

        keys = np.split(list(all_keys_res), indices[:-1])
        all_vals_res_iter = iter(all_vals_res)
        return [cls(zip(k, all_vals_res_iter)) for k in keys]
    elif not dataclasses.is_dataclass(typeannot):
        return dlist

    # dataclass node: 2nd recursion base; call the function recursively on the lists
    # of the corresponding fields
    assert dataclasses.is_dataclass(cls)
    fieldtypes = {
        f.name: (_unwrap_type(f.type), _get_dataclass_field_default(f))
        for f in dataclasses.fields(typeannot)
    }

    # NOTE the default object is shared here
    key_lists = (
        _dataclass_list_from_dict_list([obj.get(k, default) for obj in dlist], type_)
        for k, (type_, default) in fieldtypes.items()
    )
    transposed = zip(*key_lists)
    return [cls(*vals_as_tuple) for vals_as_tuple in transposed]
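For reference, a minimal usage sketch of load_dataclass following the call pattern in its docstring; the FrameRecord dataclass and the file name below are illustrative placeholders, not part of this commit:

import gzip
from dataclasses import dataclass
from typing import List


@dataclass
class FrameRecord:  # hypothetical record type used only for this sketch
    sequence_name: str
    frame_number: int


# "rt" mode yields text, so binary stays False (the default)
with gzip.open("frame_annotations_valid.jgz", "rt", encoding="utf8") as f:
    records = load_dataclass(f, List[FrameRecord])  # -> list of FrameRecord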
@@ -1,161 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import os
import gzip
import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple

from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass


@dataclass
class ImageAnnotation:
    # path to jpg file, relative w.r.t. dataset_root
    path: str
    # H x W
    size: Tuple[int, int]


@dataclass
class DynamicReplicaFrameAnnotation:
    """A dataclass used to load annotations from json."""

    # can be used to join with `SequenceAnnotation`
    sequence_name: str
    # 0-based, continuous frame number within sequence
    frame_number: int
    # timestamp in seconds from the video start
    frame_timestamp: float

    image: ImageAnnotation
    meta: Optional[Dict[str, Any]] = None

    camera_name: Optional[str] = None
    trajectories: Optional[str] = None


class DynamicReplicaDataset(data.Dataset):
    def __init__(
        self,
        root,
        split="valid",
        traj_per_sample=256,
        crop_size=None,
        sample_len=-1,
        only_first_n_samples=-1,
        rgbd_input=False,
    ):
        super(DynamicReplicaDataset, self).__init__()
        self.root = root
        self.sample_len = sample_len
        self.split = split
        self.traj_per_sample = traj_per_sample
        self.rgbd_input = rgbd_input
        self.crop_size = crop_size
        frame_annotations_file = f"frame_annotations_{split}.jgz"
        self.sample_list = []
        with gzip.open(
            os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
        seq_annot = defaultdict(list)
        for frame_annot in frame_annots_list:
            if frame_annot.camera_name == "left":
                seq_annot[frame_annot.sequence_name].append(frame_annot)

        for seq_name in seq_annot.keys():
            seq_len = len(seq_annot[seq_name])

            step = self.sample_len if self.sample_len > 0 else seq_len
            counter = 0

            for ref_idx in range(0, seq_len, step):
                sample = seq_annot[seq_name][ref_idx : ref_idx + step]
                self.sample_list.append(sample)
                counter += 1
                if only_first_n_samples > 0 and counter >= only_first_n_samples:
                    break

    def __len__(self):
        return len(self.sample_list)

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        H_new = H
        W_new = W

        # simple random crop
        y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
        x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs

    def __getitem__(self, index):
        sample = self.sample_list[index]
        T = len(sample)
        rgbs, visibilities, traj_2d = [], [], []

        H, W = sample[0].image.size
        image_size = (H, W)

        for i in range(T):
            traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
            traj = torch.load(traj_path)

            visibilities.append(traj["verts_inds_vis"].numpy())

            rgbs.append(traj["img"].numpy())
            traj_2d.append(traj["traj_2d"].numpy()[..., :2])

        traj_2d = np.stack(traj_2d)
        visibility = np.stack(visibilities)
        T, N, D = traj_2d.shape
        # subsample trajectories for augmentations
        visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]

        traj_2d = traj_2d[:, visible_inds_sampled]
        visibility = visibility[:, visible_inds_sampled]

        if self.crop_size is not None:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)
            H, W, _ = rgbs[0].shape
            image_size = self.crop_size

        visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        # filter out points that are visible for less than 10 frames
        visible_inds_resampled = visibility.sum(0) > 10
        traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
        visibility = torch.from_numpy(visibility[:, visible_inds_resampled])

        rgbs = np.stack(rgbs, 0)
        video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
        return CoTrackerData(
            video=video,
            trajectory=traj_2d,
            visibility=visibility,
            valid=torch.ones(T, N),
            seq_name=sample[0].sequence_name,
        )
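A rough usage sketch of the loader above; the dataset root is a placeholder, and the directory layout (root/split/frame_annotations_split.jgz plus the per-frame trajectory files it references) is inferred from __init__ and __getitem__:

dataset = DynamicReplicaDataset(
    root="./dynamic_replica_data",  # hypothetical path
    split="valid",
    traj_per_sample=256,
    sample_len=60,
)
sample = dataset[0]
# CoTrackerData fields as used here: video (T, 3, H, W), trajectory (T, N, 2),
# visibility (T, N), valid (T, N), seq_name
print(sample.video.shape, sample.trajectory.shape, sample.visibility.shape)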
@@ -1,441 +1,441 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import torch
import cv2

import imageio
import numpy as np

from cotracker.datasets.utils import CoTrackerData
from torchvision.transforms import ColorJitter, GaussianBlur
from PIL import Image


class CoTrackerDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_root,
        crop_size=(384, 512),
        seq_len=24,
        traj_per_sample=768,
        sample_vis_1st_frame=False,
        use_augs=False,
    ):
        super(CoTrackerDataset, self).__init__()
        np.random.seed(0)
        torch.manual_seed(0)
        self.data_root = data_root
        self.seq_len = seq_len
        self.traj_per_sample = traj_per_sample
        self.sample_vis_1st_frame = sample_vis_1st_frame
        self.use_augs = use_augs
        self.crop_size = crop_size

        # photometric augmentation
        self.photo_aug = ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.25 / 3.14)
        self.blur_aug = GaussianBlur(11, sigma=(0.1, 2.0))

        self.blur_aug_prob = 0.25
        self.color_aug_prob = 0.25

        # occlusion augmentation
        self.eraser_aug_prob = 0.5
        self.eraser_bounds = [2, 100]
        self.eraser_max = 10

        # occlusion augmentation
        self.replace_aug_prob = 0.5
        self.replace_bounds = [2, 100]
        self.replace_max = 10

        # spatial augmentations
        self.pad_bounds = [0, 100]
        self.crop_size = crop_size
        self.resize_lim = [0.25, 2.0]  # sample resizes from here
        self.resize_delta = 0.2
        self.max_crop_offset = 50

        self.do_flip = True
        self.h_flip_prob = 0.5
        self.v_flip_prob = 0.5

    def getitem_helper(self, index):
        return NotImplementedError

    def __getitem__(self, index):
        gotit = False

        sample, gotit = self.getitem_helper(index)
        if not gotit:
            print("warning: sampling failed")
            # fake sample, so we can still collate
            sample = CoTrackerData(
                video=torch.zeros((self.seq_len, 3, self.crop_size[0], self.crop_size[1])),
                trajectory=torch.zeros((self.seq_len, self.traj_per_sample, 2)),
                visibility=torch.zeros((self.seq_len, self.traj_per_sample)),
                valid=torch.zeros((self.seq_len, self.traj_per_sample)),
            )

        return sample, gotit

    def add_photometric_augs(self, rgbs, trajs, visibles, eraser=True, replace=True):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        if eraser:
            ############ eraser transform (per image after the first) ############
            rgbs = [rgb.astype(np.float32) for rgb in rgbs]
            for i in range(1, S):
                if np.random.rand() < self.eraser_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.eraser_max + 1)
                    ):  # number of times to occlude
                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
                        dy = np.random.randint(self.eraser_bounds[0], self.eraser_bounds[1])
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        mean_color = np.mean(rgbs[i][y0:y1, x0:x1, :].reshape(-1, 3), axis=0)
                        rgbs[i][y0:y1, x0:x1, :] = mean_color

                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        if replace:
            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs
            ]
            rgbs_alt = [
                np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs_alt
            ]

            ############ replace transform (per image after the first) ############
            rgbs = [rgb.astype(np.float32) for rgb in rgbs]
            rgbs_alt = [rgb.astype(np.float32) for rgb in rgbs_alt]
            for i in range(1, S):
                if np.random.rand() < self.replace_aug_prob:
                    for _ in range(
                        np.random.randint(1, self.replace_max + 1)
                    ):  # number of times to occlude
                        xc = np.random.randint(0, W)
                        yc = np.random.randint(0, H)
                        dx = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
                        dy = np.random.randint(self.replace_bounds[0], self.replace_bounds[1])
                        x0 = np.clip(xc - dx / 2, 0, W - 1).round().astype(np.int32)
                        x1 = np.clip(xc + dx / 2, 0, W - 1).round().astype(np.int32)
                        y0 = np.clip(yc - dy / 2, 0, H - 1).round().astype(np.int32)
                        y1 = np.clip(yc + dy / 2, 0, H - 1).round().astype(np.int32)

                        wid = x1 - x0
                        hei = y1 - y0
                        y00 = np.random.randint(0, H - hei)
                        x00 = np.random.randint(0, W - wid)
                        fr = np.random.randint(0, S)
                        rep = rgbs_alt[fr][y00 : y00 + hei, x00 : x00 + wid, :]
                        rgbs[i][y0:y1, x0:x1, :] = rep

                        occ_inds = np.logical_and(
                            np.logical_and(trajs[i, :, 0] >= x0, trajs[i, :, 0] < x1),
                            np.logical_and(trajs[i, :, 1] >= y0, trajs[i, :, 1] < y1),
                        )
                        visibles[i, occ_inds] = 0
            rgbs = [rgb.astype(np.uint8) for rgb in rgbs]

        ############ photometric augmentation ############
        if np.random.rand() < self.color_aug_prob:
            # random per-frame amount of aug
            rgbs = [np.array(self.photo_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]

        if np.random.rand() < self.blur_aug_prob:
            # random per-frame amount of blur
            rgbs = [np.array(self.blur_aug(Image.fromarray(rgb)), dtype=np.uint8) for rgb in rgbs]

        return rgbs, trajs, visibles

    def add_spatial_augs(self, rgbs, trajs, visibles):
        T, N, __ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        rgbs = [rgb.astype(np.float32) for rgb in rgbs]

        ############ spatial transform ############

        # padding
        pad_x0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_x1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_y0 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])
        pad_y1 = np.random.randint(self.pad_bounds[0], self.pad_bounds[1])

        rgbs = [np.pad(rgb, ((pad_y0, pad_y1), (pad_x0, pad_x1), (0, 0))) for rgb in rgbs]
        trajs[:, :, 0] += pad_x0
        trajs[:, :, 1] += pad_y0
        H, W = rgbs[0].shape[:2]

        # scaling + stretching
        scale = np.random.uniform(self.resize_lim[0], self.resize_lim[1])
        scale_x = scale
        scale_y = scale
        H_new = H
        W_new = W

        scale_delta_x = 0.0
        scale_delta_y = 0.0

        rgbs_scaled = []
        for s in range(S):
            if s == 1:
                scale_delta_x = np.random.uniform(-self.resize_delta, self.resize_delta)
                scale_delta_y = np.random.uniform(-self.resize_delta, self.resize_delta)
            elif s > 1:
                scale_delta_x = (
                    scale_delta_x * 0.8
                    + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
                )
                scale_delta_y = (
                    scale_delta_y * 0.8
                    + np.random.uniform(-self.resize_delta, self.resize_delta) * 0.2
                )
            scale_x = scale_x + scale_delta_x
            scale_y = scale_y + scale_delta_y

            # bring h/w closer
            scale_xy = (scale_x + scale_y) * 0.5
            scale_x = scale_x * 0.5 + scale_xy * 0.5
            scale_y = scale_y * 0.5 + scale_xy * 0.5

            # don't get too crazy
            scale_x = np.clip(scale_x, 0.2, 2.0)
            scale_y = np.clip(scale_y, 0.2, 2.0)

            H_new = int(H * scale_y)
            W_new = int(W * scale_x)

            # make it at least slightly bigger than the crop area,
            # so that the random cropping can add diversity
            H_new = np.clip(H_new, self.crop_size[0] + 10, None)
            W_new = np.clip(W_new, self.crop_size[1] + 10, None)
            # recompute scale in case we clipped
            scale_x = (W_new - 1) / float(W - 1)
            scale_y = (H_new - 1) / float(H - 1)
            rgbs_scaled.append(cv2.resize(rgbs[s], (W_new, H_new), interpolation=cv2.INTER_LINEAR))
            trajs[s, :, 0] *= scale_x
            trajs[s, :, 1] *= scale_y
        rgbs = rgbs_scaled

        ok_inds = visibles[0, :] > 0
        vis_trajs = trajs[:, ok_inds]  # S,?,2

        if vis_trajs.shape[1] > 0:
            mid_x = np.mean(vis_trajs[0, :, 0])
            mid_y = np.mean(vis_trajs[0, :, 1])
        else:
            mid_y = self.crop_size[0]
            mid_x = self.crop_size[1]

        x0 = int(mid_x - self.crop_size[1] // 2)
        y0 = int(mid_y - self.crop_size[0] // 2)

        offset_x = 0
        offset_y = 0

        for s in range(S):
            # on each frame, shift a bit more
            if s == 1:
                offset_x = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
                offset_y = np.random.randint(-self.max_crop_offset, self.max_crop_offset)
            elif s > 1:
                offset_x = int(
                    offset_x * 0.8
                    + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
                )
                offset_y = int(
                    offset_y * 0.8
                    + np.random.randint(-self.max_crop_offset, self.max_crop_offset + 1) * 0.2
                )
            x0 = x0 + offset_x
            y0 = y0 + offset_y

            H_new, W_new = rgbs[s].shape[:2]
            if H_new == self.crop_size[0]:
                y0 = 0
            else:
                y0 = min(max(0, y0), H_new - self.crop_size[0] - 1)

            if W_new == self.crop_size[1]:
                x0 = 0
            else:
                x0 = min(max(0, x0), W_new - self.crop_size[1] - 1)

            rgbs[s] = rgbs[s][y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]]
            trajs[s, :, 0] -= x0
            trajs[s, :, 1] -= y0

        H_new = self.crop_size[0]
        W_new = self.crop_size[1]

        # flip
        h_flipped = False
        v_flipped = False
        if self.do_flip:
            # h flip
            if np.random.rand() < self.h_flip_prob:
                h_flipped = True
                rgbs = [rgb[:, ::-1] for rgb in rgbs]
            # v flip
            if np.random.rand() < self.v_flip_prob:
                v_flipped = True
                rgbs = [rgb[::-1] for rgb in rgbs]
        if h_flipped:
            trajs[:, :, 0] = W_new - trajs[:, :, 0]
        if v_flipped:
            trajs[:, :, 1] = H_new - trajs[:, :, 1]

        return rgbs, trajs

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        ############ spatial transform ############

        H_new = H
        W_new = W

        # simple random crop
        y0 = 0 if self.crop_size[0] >= H_new else np.random.randint(0, H_new - self.crop_size[0])
        x0 = 0 if self.crop_size[1] >= W_new else np.random.randint(0, W_new - self.crop_size[1])
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs


class KubricMovifDataset(CoTrackerDataset):
    def __init__(
        self,
        data_root,
        crop_size=(384, 512),
        seq_len=24,
        traj_per_sample=768,
        sample_vis_1st_frame=False,
        use_augs=False,
    ):
        super(KubricMovifDataset, self).__init__(
            data_root=data_root,
            crop_size=crop_size,
            seq_len=seq_len,
            traj_per_sample=traj_per_sample,
            sample_vis_1st_frame=sample_vis_1st_frame,
            use_augs=use_augs,
        )

        self.pad_bounds = [0, 25]
        self.resize_lim = [0.75, 1.25]  # sample resizes from here
        self.resize_delta = 0.05
        self.max_crop_offset = 15
        self.seq_names = [
            fname
            for fname in os.listdir(data_root)
            if os.path.isdir(os.path.join(data_root, fname))
        ]
        print("found %d unique videos in %s" % (len(self.seq_names), self.data_root))

    def getitem_helper(self, index):
        gotit = True
        seq_name = self.seq_names[index]

        npy_path = os.path.join(self.data_root, seq_name, seq_name + ".npy")
        rgb_path = os.path.join(self.data_root, seq_name, "frames")

        img_paths = sorted(os.listdir(rgb_path))
        rgbs = []
        for i, img_path in enumerate(img_paths):
            rgbs.append(imageio.v2.imread(os.path.join(rgb_path, img_path)))

        rgbs = np.stack(rgbs)
        annot_dict = np.load(npy_path, allow_pickle=True).item()
        traj_2d = annot_dict["coords"]
        visibility = annot_dict["visibility"]

        # random crop
        assert self.seq_len <= len(rgbs)
        if self.seq_len < len(rgbs):
            start_ind = np.random.choice(len(rgbs) - self.seq_len, 1)[0]

            rgbs = rgbs[start_ind : start_ind + self.seq_len]
            traj_2d = traj_2d[:, start_ind : start_ind + self.seq_len]
            visibility = visibility[:, start_ind : start_ind + self.seq_len]

        traj_2d = np.transpose(traj_2d, (1, 0, 2))
        visibility = np.transpose(np.logical_not(visibility), (1, 0))
        if self.use_augs:
            rgbs, traj_2d, visibility = self.add_photometric_augs(rgbs, traj_2d, visibility)
            rgbs, traj_2d = self.add_spatial_augs(rgbs, traj_2d, visibility)
        else:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)

        visibility[traj_2d[:, :, 0] > self.crop_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > self.crop_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        visibility = torch.from_numpy(visibility)
        traj_2d = torch.from_numpy(traj_2d)

        visibile_pts_first_frame_inds = (visibility[0]).nonzero(as_tuple=False)[:, 0]

        if self.sample_vis_1st_frame:
            visibile_pts_inds = visibile_pts_first_frame_inds
        else:
            visibile_pts_mid_frame_inds = (visibility[self.seq_len // 2]).nonzero(as_tuple=False)[
                :, 0
            ]
            visibile_pts_inds = torch.cat(
                (visibile_pts_first_frame_inds, visibile_pts_mid_frame_inds), dim=0
            )
        point_inds = torch.randperm(len(visibile_pts_inds))[: self.traj_per_sample]
        if len(point_inds) < self.traj_per_sample:
            gotit = False

        visible_inds_sampled = visibile_pts_inds[point_inds]

        trajs = traj_2d[:, visible_inds_sampled].float()
        visibles = visibility[:, visible_inds_sampled]
        valids = torch.ones((self.seq_len, self.traj_per_sample))

        rgbs = torch.from_numpy(np.stack(rgbs)).permute(0, 3, 1, 2).float()
        sample = CoTrackerData(
            video=rgbs,
            trajectory=trajs,
            visibility=visibles,
            valid=valids,
            seq_name=seq_name,
        )
        return sample, gotit

    def __len__(self):
        return len(self.seq_names)
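A hedged sketch of wrapping KubricMovifDataset in a DataLoader; the data_root path is a placeholder, and the collate function is just one reasonable way to drop samples whose getitem_helper reported gotit=False, not necessarily the exact training setup:

import torch


def collate_fn(batch):
    # batch is a list of (CoTrackerData, gotit) pairs; keep only successful samples
    return [sample for sample, gotit in batch if gotit]


train_dataset = KubricMovifDataset(
    data_root="./kubric_movi_f",  # hypothetical path: one sub-folder per video
    crop_size=(384, 512),
    seq_len=24,
    traj_per_sample=768,
    use_augs=True,
)
loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=1, shuffle=True, num_workers=4, collate_fn=collate_fn
)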
@@ -1,209 +1,209 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import io
import glob
import torch
import pickle
import numpy as np
import mediapy as media

from PIL import Image
from typing import Mapping, Tuple, Union

from cotracker.datasets.utils import CoTrackerData

DatasetElement = Mapping[str, Mapping[str, Union[np.ndarray, str]]]


def resize_video(video: np.ndarray, output_size: Tuple[int, int]) -> np.ndarray:
    """Resize a video to output_size."""
    # If you have a GPU, consider replacing this with a GPU-enabled resize op,
    # such as a jitted jax.image.resize. It will make things faster.
    return media.resize_video(video, output_size)


def sample_queries_first(
    target_occluded: np.ndarray,
    target_points: np.ndarray,
    frames: np.ndarray,
) -> Mapping[str, np.ndarray]:
    """Package a set of frames and tracks for use in TAPNet evaluations.
    Given a set of frames and tracks with no query points, use the first
    visible point in each track as the query.
    Args:
      target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
        where True indicates occluded.
      target_points: Position, of shape [n_tracks, n_frames, 2], where each point
        is [x,y] scaled between 0 and 1.
      frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
        -1 and 1.
    Returns:
      A dict with the keys:
        video: Video tensor of shape [1, n_frames, height, width, 3]
        query_points: Query points of shape [1, n_queries, 3] where
          each point is [t, y, x] scaled to the range [-1, 1]
        target_points: Target points of shape [1, n_queries, n_frames, 2] where
          each point is [x, y] scaled to the range [-1, 1]
    """
    valid = np.sum(~target_occluded, axis=1) > 0
    target_points = target_points[valid, :]
    target_occluded = target_occluded[valid, :]

    query_points = []
    for i in range(target_points.shape[0]):
        index = np.where(target_occluded[i] == 0)[0][0]
        x, y = target_points[i, index, 0], target_points[i, index, 1]
        query_points.append(np.array([index, y, x]))  # [t, y, x]
    query_points = np.stack(query_points, axis=0)

    return {
        "video": frames[np.newaxis, ...],
        "query_points": query_points[np.newaxis, ...],
        "target_points": target_points[np.newaxis, ...],
        "occluded": target_occluded[np.newaxis, ...],
    }


def sample_queries_strided(
    target_occluded: np.ndarray,
    target_points: np.ndarray,
    frames: np.ndarray,
    query_stride: int = 5,
) -> Mapping[str, np.ndarray]:
    """Package a set of frames and tracks for use in TAPNet evaluations.

    Given a set of frames and tracks with no query points, sample queries
    strided every query_stride frames, ignoring points that are not visible
    at the selected frames.

    Args:
      target_occluded: Boolean occlusion flag, of shape [n_tracks, n_frames],
        where True indicates occluded.
      target_points: Position, of shape [n_tracks, n_frames, 2], where each point
        is [x,y] scaled between 0 and 1.
      frames: Video tensor, of shape [n_frames, height, width, 3]. Scaled between
        -1 and 1.
      query_stride: When sampling query points, search for un-occluded points
        every query_stride frames and convert each one into a query.

    Returns:
      A dict with the keys:
        video: Video tensor of shape [1, n_frames, height, width, 3]. The video
          has floats scaled to the range [-1, 1].
        query_points: Query points of shape [1, n_queries, 3] where
          each point is [t, y, x] scaled to the range [-1, 1].
        target_points: Target points of shape [1, n_queries, n_frames, 2] where
          each point is [x, y] scaled to the range [-1, 1].
        trackgroup: Index of the original track that each query point was
          sampled from. This is useful for visualization.
    """
    tracks = []
    occs = []
    queries = []
    trackgroups = []
    total = 0
    trackgroup = np.arange(target_occluded.shape[0])
    for i in range(0, target_occluded.shape[1], query_stride):
        mask = target_occluded[:, i] == 0
        query = np.stack(
            [
                i * np.ones(target_occluded.shape[0:1]),
                target_points[:, i, 1],
                target_points[:, i, 0],
            ],
            axis=-1,
        )
        queries.append(query[mask])
        tracks.append(target_points[mask])
        occs.append(target_occluded[mask])
        trackgroups.append(trackgroup[mask])
        total += np.array(np.sum(target_occluded[:, i] == 0))

    return {
        "video": frames[np.newaxis, ...],
        "query_points": np.concatenate(queries, axis=0)[np.newaxis, ...],
        "target_points": np.concatenate(tracks, axis=0)[np.newaxis, ...],
        "occluded": np.concatenate(occs, axis=0)[np.newaxis, ...],
        "trackgroup": np.concatenate(trackgroups, axis=0)[np.newaxis, ...],
    }


class TapVidDataset(torch.utils.data.Dataset):
    def __init__(
        self,
        data_root,
        dataset_type="davis",
        resize_to_256=True,
        queried_first=True,
    ):
        self.dataset_type = dataset_type
        self.resize_to_256 = resize_to_256
        self.queried_first = queried_first
        if self.dataset_type == "kinetics":
            all_paths = glob.glob(os.path.join(data_root, "*_of_0010.pkl"))
            points_dataset = []
            for pickle_path in all_paths:
                with open(pickle_path, "rb") as f:
                    data = pickle.load(f)
                    points_dataset = points_dataset + data
            self.points_dataset = points_dataset
        else:
            with open(data_root, "rb") as f:
                self.points_dataset = pickle.load(f)
            if self.dataset_type == "davis":
                self.video_names = list(self.points_dataset.keys())
        print("found %d unique videos in %s" % (len(self.points_dataset), data_root))

    def __getitem__(self, index):
        if self.dataset_type == "davis":
            video_name = self.video_names[index]
        else:
            video_name = index
        video = self.points_dataset[video_name]
        frames = video["video"]

        if isinstance(frames[0], bytes):
            # TAP-Vid is stored as JPEG bytes rather than `np.ndarray`s.
            def decode(frame):
                byteio = io.BytesIO(frame)
                img = Image.open(byteio)
                return np.array(img)

            frames = np.array([decode(frame) for frame in frames])

        target_points = self.points_dataset[video_name]["points"]
        if self.resize_to_256:
            frames = resize_video(frames, [256, 256])
            target_points *= np.array([255, 255])  # 1 should be mapped to 256-1
        else:
            target_points *= np.array([frames.shape[2] - 1, frames.shape[1] - 1])

        target_occ = self.points_dataset[video_name]["occluded"]
        if self.queried_first:
            converted = sample_queries_first(target_occ, target_points, frames)
        else:
            converted = sample_queries_strided(target_occ, target_points, frames)
        assert converted["target_points"].shape[1] == converted["query_points"].shape[1]

        trajs = torch.from_numpy(converted["target_points"])[0].permute(1, 0, 2).float()  # T, N, D

        rgbs = torch.from_numpy(frames).permute(0, 3, 1, 2).float()
        visibles = torch.logical_not(torch.from_numpy(converted["occluded"]))[0].permute(
            1, 0
        )  # T, N
        query_points = torch.from_numpy(converted["query_points"])[0]  # N, 3
        return CoTrackerData(
            rgbs,
            trajs,
            visibles,
            seq_name=str(video_name),
            query_points=query_points,
        )

    def __len__(self):
        return len(self.points_dataset)
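For reference, a minimal sketch of reading one DAVIS sequence with the dataset above. The pickle path is an assumption and mirrors the layout used by evaluate.py at the end of this diff:

from cotracker.datasets.tap_vid_datasets import TapVidDataset

# Sketch (not part of this commit): assumes the TAP-Vid DAVIS pickle was downloaded here.
dataset = TapVidDataset(
    data_root="./tapvid_davis/tapvid_davis.pkl",
    dataset_type="davis",
    queried_first=True,  # "first" query protocol; False gives the strided protocol
)
sample = dataset[0]
# video: (T, 3, 256, 256), trajectory: (T, N, 2), visibility: (T, N), query_points: (N, 3)
print(sample.video.shape, sample.trajectory.shape, sample.visibility.shape)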
@@ -1,106 +1,106 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch
import dataclasses
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(eq=False)
class CoTrackerData:
    """
    Dataclass for storing video tracks data.
    """

    video: torch.Tensor  # B, S, C, H, W
    trajectory: torch.Tensor  # B, S, N, 2
    visibility: torch.Tensor  # B, S, N
    # optional data
    valid: Optional[torch.Tensor] = None  # B, S, N
    segmentation: Optional[torch.Tensor] = None  # B, S, 1, H, W
    seq_name: Optional[str] = None
    query_points: Optional[torch.Tensor] = None  # TapVID evaluation format


def collate_fn(batch):
    """
    Collate function for video tracks data.
    """
    video = torch.stack([b.video for b in batch], dim=0)
    trajectory = torch.stack([b.trajectory for b in batch], dim=0)
    visibility = torch.stack([b.visibility for b in batch], dim=0)
    query_points = segmentation = None
    if batch[0].query_points is not None:
        query_points = torch.stack([b.query_points for b in batch], dim=0)
    if batch[0].segmentation is not None:
        segmentation = torch.stack([b.segmentation for b in batch], dim=0)
    seq_name = [b.seq_name for b in batch]

    return CoTrackerData(
        video=video,
        trajectory=trajectory,
        visibility=visibility,
        segmentation=segmentation,
        seq_name=seq_name,
        query_points=query_points,
    )


def collate_fn_train(batch):
    """
    Collate function for video tracks data during training.
    """
    gotit = [gotit for _, gotit in batch]
    video = torch.stack([b.video for b, _ in batch], dim=0)
    trajectory = torch.stack([b.trajectory for b, _ in batch], dim=0)
    visibility = torch.stack([b.visibility for b, _ in batch], dim=0)
    valid = torch.stack([b.valid for b, _ in batch], dim=0)
    seq_name = [b.seq_name for b, _ in batch]
    return (
        CoTrackerData(
            video=video,
            trajectory=trajectory,
            visibility=visibility,
            valid=valid,
            seq_name=seq_name,
        ),
        gotit,
    )


def try_to_cuda(t: Any) -> Any:
    """
    Try to move the input variable `t` to a cuda device.

    Args:
        t: Input.

    Returns:
        t_cuda: `t` moved to a cuda device, if supported.
    """
    try:
        t = t.float().cuda()
    except AttributeError:
        pass
    return t


def dataclass_to_cuda_(obj):
    """
    Move all contents of a dataclass to cuda in place if supported.

    Args:
        obj: Input dataclass.

    Returns:
        obj_cuda: `obj` moved to a cuda device, if supported.
    """
    for f in dataclasses.fields(obj):
        setattr(obj, f.name, try_to_cuda(getattr(obj, f.name)))
    return obj
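A short sketch of how these helpers compose: collate a few CoTrackerData samples with collate_fn, then move every tensor field of the resulting dataclass to the GPU with dataclass_to_cuda_. The dummy samples here are fabricated purely to exercise the helpers:

import torch
from cotracker.datasets.utils import CoTrackerData, collate_fn, dataclass_to_cuda_

# Sketch (not part of this commit): two dummy single-sequence samples, T=8 frames, N=4 tracks.
dummy = [
    CoTrackerData(
        video=torch.zeros(8, 3, 64, 64),
        trajectory=torch.zeros(8, 4, 2),
        visibility=torch.ones(8, 4),
        seq_name=f"seq_{i}",
    )
    for i in range(2)
]
batch = collate_fn(dummy)  # stacks tensor fields along a new batch dimension
print(batch.video.shape)   # torch.Size([2, 8, 3, 64, 64])
if torch.cuda.is_available():
    dataclass_to_cuda_(batch)  # in place: tensors gain .float().cuda(), strings are left alone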
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
@@ -1,6 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: dynamic_replica

@@ -1,6 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_first

@@ -1,6 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_davis_strided

@@ -1,6 +1,6 @@
defaults:
  - default_config_eval
exp_dir: ./outputs/cotracker
dataset_name: tapvid_kinetics_first

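These Hydra configs only pin exp_dir and dataset_name on top of default_config_eval; all remaining options come from the DefaultConfig dataclass in evaluate.py at the end of this diff. A hedged sketch of driving the same evaluation from Python, assuming the script lives at cotracker/evaluation/evaluate.py and the checkpoint and dataset paths exist locally:

# Sketch (not part of this commit): programmatic equivalent of the Hydra configs above.
from cotracker.evaluation.evaluate import DefaultConfig, run_eval

cfg = DefaultConfig()
cfg.exp_dir = "./outputs/cotracker"
cfg.dataset_name = "tapvid_davis_first"           # matches the config above
cfg.dataset_root = "./"                           # folder containing tapvid_davis/tapvid_davis.pkl
cfg.checkpoint = "./checkpoints/cotracker2.pth"   # assumed checkpoint location
run_eval(cfg)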
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
@@ -1,138 +1,138 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

from typing import Iterable, Mapping, Tuple, Union


def compute_tapvid_metrics(
    query_points: np.ndarray,
    gt_occluded: np.ndarray,
    gt_tracks: np.ndarray,
    pred_occluded: np.ndarray,
    pred_tracks: np.ndarray,
    query_mode: str,
) -> Mapping[str, np.ndarray]:
    """Computes TAP-Vid metrics (Jaccard, Pts. Within Thresh, Occ. Acc.)
    See the TAP-Vid paper for details on the metric computation. All inputs are
    given in raster coordinates. The first three arguments should be the direct
    outputs of the reader: the 'query_points', 'occluded', and 'target_points'.
    The paper metrics assume these are scaled relative to 256x256 images.
    pred_occluded and pred_tracks are your algorithm's predictions.
    This function takes a batch of inputs, and computes metrics separately for
    each video. The metrics for the full benchmark are a simple mean of the
    metrics across the full set of videos. These numbers are between 0 and 1,
    but the paper multiplies them by 100 to ease reading.
    Args:
      query_points: The query points, in the format [t, y, x]. Its size is
        [b, n, 3], where b is the batch size and n is the number of queries
      gt_occluded: A boolean array of shape [b, n, t], where t is the number
        of frames. True indicates that the point is occluded.
      gt_tracks: The target points, of shape [b, n, t, 2]. Each point is
        in the format [x, y]
      pred_occluded: A boolean array of predicted occlusions, in the same
        format as gt_occluded.
      pred_tracks: An array of track predictions from your algorithm, in the
        same format as gt_tracks.
      query_mode: Either 'first' or 'strided', depending on how queries are
        sampled. If 'first', we assume the prior knowledge that all points
        before the query point are occluded, and these are removed from the
        evaluation.
    Returns:
      A dict with the following keys:
        occlusion_accuracy: Accuracy at predicting occlusion.
        pts_within_{x} for x in [1, 2, 4, 8, 16]: Fraction of points
          predicted to be within the given pixel threshold, ignoring occlusion
          prediction.
        jaccard_{x} for x in [1, 2, 4, 8, 16]: Jaccard metric for the given
          threshold
        average_pts_within_thresh: average across pts_within_{x}
        average_jaccard: average across jaccard_{x}
    """

    metrics = {}
    # Fixed bug is described in:
    # https://github.com/facebookresearch/co-tracker/issues/20
    eye = np.eye(gt_tracks.shape[2], dtype=np.int32)

    if query_mode == "first":
        # evaluate frames after the query frame
        query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
    elif query_mode == "strided":
        # evaluate all frames except the query frame
        query_frame_to_eval_frames = 1 - eye
    else:
        raise ValueError("Unknown query mode " + query_mode)

    query_frame = query_points[..., 0]
    query_frame = np.round(query_frame).astype(np.int32)
    evaluation_points = query_frame_to_eval_frames[query_frame] > 0

    # Occlusion accuracy is simply how often the predicted occlusion equals the
    # ground truth.
    occ_acc = np.sum(
        np.equal(pred_occluded, gt_occluded) & evaluation_points,
        axis=(1, 2),
    ) / np.sum(evaluation_points)
    metrics["occlusion_accuracy"] = occ_acc

    # Next, convert the predictions and ground truth positions into pixel
    # coordinates.
    visible = np.logical_not(gt_occluded)
    pred_visible = np.logical_not(pred_occluded)
    all_frac_within = []
    all_jaccard = []
    for thresh in [1, 2, 4, 8, 16]:
        # True positives are points that are within the threshold and where both
        # the prediction and the ground truth are listed as visible.
        within_dist = np.sum(
            np.square(pred_tracks - gt_tracks),
            axis=-1,
        ) < np.square(thresh)
        is_correct = np.logical_and(within_dist, visible)

        # Compute the frac_within_threshold, which is the fraction of points
        # within the threshold among points that are visible in the ground truth,
        # ignoring whether they're predicted to be visible.
        count_correct = np.sum(
            is_correct & evaluation_points,
            axis=(1, 2),
        )
        count_visible_points = np.sum(visible & evaluation_points, axis=(1, 2))
        frac_correct = count_correct / count_visible_points
        metrics["pts_within_" + str(thresh)] = frac_correct
        all_frac_within.append(frac_correct)

        true_positives = np.sum(
            is_correct & pred_visible & evaluation_points, axis=(1, 2)
        )

        # The denominator of the jaccard metric is the true positives plus
        # false positives plus false negatives. However, note that true positives
        # plus false negatives is simply the number of points in the ground truth
        # which is easier to compute than trying to compute all three quantities.
        # Thus we just add the number of points in the ground truth to the number
        # of false positives.
        #
        # False positives are simply points that are predicted to be visible,
        # but the ground truth is not visible or too far from the prediction.
        gt_positives = np.sum(visible & evaluation_points, axis=(1, 2))
        false_positives = (~visible) & pred_visible
        false_positives = false_positives | ((~within_dist) & pred_visible)
        false_positives = np.sum(false_positives & evaluation_points, axis=(1, 2))
        jaccard = true_positives / (gt_positives + false_positives)
        metrics["jaccard_" + str(thresh)] = jaccard
        all_jaccard.append(jaccard)
    metrics["average_jaccard"] = np.mean(
        np.stack(all_jaccard, axis=1),
        axis=1,
    )
    metrics["average_pts_within_thresh"] = np.mean(
        np.stack(all_frac_within, axis=1),
        axis=1,
    )
    return metrics
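compute_tapvid_metrics operates on plain NumPy arrays, so it can be sanity-checked in isolation. A minimal sketch with random inputs (shapes only; the numbers are meaningless):

import numpy as np
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics

# Sketch (not part of this commit): exercise the metric with synthetic data.
b, n, t = 1, 5, 24                              # batch, queries, frames
rng = np.random.default_rng(0)
query_points = np.zeros((b, n, 3))              # [t, y, x]; all queries at frame 0
gt_occluded = rng.random((b, n, t)) > 0.8
gt_tracks = rng.random((b, n, t, 2)) * 255.0    # raster coordinates in a 256x256 image
pred_occluded = rng.random((b, n, t)) > 0.8
pred_tracks = gt_tracks + rng.normal(0.0, 2.0, gt_tracks.shape)

metrics = compute_tapvid_metrics(
    query_points, gt_occluded, gt_tracks, pred_occluded, pred_tracks, query_mode="first"
)
print(metrics["average_jaccard"], metrics["average_pts_within_thresh"])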
@@ -1,253 +1,253 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from collections import defaultdict
import os
from typing import Optional
import torch
from tqdm import tqdm
import numpy as np

from torch.utils.tensorboard import SummaryWriter
from cotracker.datasets.utils import dataclass_to_cuda_
from cotracker.utils.visualizer import Visualizer
from cotracker.models.core.model_utils import reduce_masked_mean
from cotracker.evaluation.core.eval_utils import compute_tapvid_metrics

import logging


class Evaluator:
    """
    A class defining the CoTracker evaluator.
    """

    def __init__(self, exp_dir) -> None:
        # Visualization
        self.exp_dir = exp_dir
        os.makedirs(exp_dir, exist_ok=True)
        self.visualization_filepaths = defaultdict(lambda: defaultdict(list))
        self.visualize_dir = os.path.join(exp_dir, "visualisations")

    def compute_metrics(self, metrics, sample, pred_trajectory, dataset_name):
        if isinstance(pred_trajectory, tuple):
            pred_trajectory, pred_visibility = pred_trajectory
        else:
            pred_visibility = None
        if "tapvid" in dataset_name:
            B, T, N, D = sample.trajectory.shape
            traj = sample.trajectory.clone()
            thr = 0.9

            if pred_visibility is None:
                logging.warning("visibility is NONE")
                pred_visibility = torch.zeros_like(sample.visibility)

            if not pred_visibility.dtype == torch.bool:
                pred_visibility = pred_visibility > thr

            query_points = sample.query_points.clone().cpu().numpy()

            pred_visibility = pred_visibility[:, :, :N]
            pred_trajectory = pred_trajectory[:, :, :N]

            gt_tracks = traj.permute(0, 2, 1, 3).cpu().numpy()
            gt_occluded = (
                torch.logical_not(sample.visibility.clone().permute(0, 2, 1)).cpu().numpy()
            )

            pred_occluded = (
                torch.logical_not(pred_visibility.clone().permute(0, 2, 1)).cpu().numpy()
            )
            pred_tracks = pred_trajectory.permute(0, 2, 1, 3).cpu().numpy()

            out_metrics = compute_tapvid_metrics(
                query_points,
                gt_occluded,
                gt_tracks,
                pred_occluded,
                pred_tracks,
                query_mode="strided" if "strided" in dataset_name else "first",
            )

            metrics[sample.seq_name[0]] = out_metrics
            for metric_name in out_metrics.keys():
                if "avg" not in metrics:
                    metrics["avg"] = {}
                metrics["avg"][metric_name] = np.mean(
                    [v[metric_name] for k, v in metrics.items() if k != "avg"]
                )

            logging.info(f"Metrics: {out_metrics}")
            logging.info(f"avg: {metrics['avg']}")
            print("metrics", out_metrics)
            print("avg", metrics["avg"])
        elif dataset_name == "dynamic_replica" or dataset_name == "pointodyssey":
            *_, N, _ = sample.trajectory.shape
            B, T, N = sample.visibility.shape
            H, W = sample.video.shape[-2:]
            device = sample.video.device

            out_metrics = {}

            d_vis_sum = d_occ_sum = d_sum_all = 0.0
            thrs = [1, 2, 4, 8, 16]
            sx_ = (W - 1) / 255.0
            sy_ = (H - 1) / 255.0
            sc_py = np.array([sx_, sy_]).reshape([1, 1, 2])
            sc_pt = torch.from_numpy(sc_py).float().to(device)
            __, first_visible_inds = torch.max(sample.visibility, dim=1)

            frame_ids_tensor = torch.arange(T, device=device)[None, :, None].repeat(B, 1, N)
            start_tracking_mask = frame_ids_tensor > (first_visible_inds.unsqueeze(1))

            for thr in thrs:
                d_ = (
                    torch.norm(
                        pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
                        dim=-1,
                    )
                    < thr
                ).float()  # B,S-1,N
                d_occ = (
                    reduce_masked_mean(d_, (1 - sample.visibility) * start_tracking_mask).item()
                    * 100.0
                )
                d_occ_sum += d_occ
                out_metrics[f"accuracy_occ_{thr}"] = d_occ

                d_vis = (
                    reduce_masked_mean(d_, sample.visibility * start_tracking_mask).item() * 100.0
                )
                d_vis_sum += d_vis
                out_metrics[f"accuracy_vis_{thr}"] = d_vis

                d_all = reduce_masked_mean(d_, start_tracking_mask).item() * 100.0
                d_sum_all += d_all
                out_metrics[f"accuracy_{thr}"] = d_all

            d_occ_avg = d_occ_sum / len(thrs)
            d_vis_avg = d_vis_sum / len(thrs)
            d_all_avg = d_sum_all / len(thrs)

            sur_thr = 50
            dists = torch.norm(
                pred_trajectory[..., :2] / sc_pt - sample.trajectory[..., :2] / sc_pt,
                dim=-1,
            )  # B,S,N
            dist_ok = 1 - (dists > sur_thr).float() * sample.visibility  # B,S,N
            survival = torch.cumprod(dist_ok, dim=1)  # B,S,N
            out_metrics["survival"] = torch.mean(survival).item() * 100.0

            out_metrics["accuracy_occ"] = d_occ_avg
            out_metrics["accuracy_vis"] = d_vis_avg
            out_metrics["accuracy"] = d_all_avg

            metrics[sample.seq_name[0]] = out_metrics
            for metric_name in out_metrics.keys():
                if "avg" not in metrics:
                    metrics["avg"] = {}
                metrics["avg"][metric_name] = float(
                    np.mean([v[metric_name] for k, v in metrics.items() if k != "avg"])
                )

            logging.info(f"Metrics: {out_metrics}")
            logging.info(f"avg: {metrics['avg']}")
            print("metrics", out_metrics)
            print("avg", metrics["avg"])

    @torch.no_grad()
    def evaluate_sequence(
        self,
        model,
        test_dataloader: torch.utils.data.DataLoader,
        dataset_name: str,
        train_mode=False,
        visualize_every: int = 1,
        writer: Optional[SummaryWriter] = None,
        step: Optional[int] = 0,
    ):
        metrics = {}

        vis = Visualizer(
            save_dir=self.exp_dir,
            fps=7,
        )

        for ind, sample in enumerate(tqdm(test_dataloader)):
            if isinstance(sample, tuple):
                sample, gotit = sample
                if not all(gotit):
                    print("batch is None")
                    continue
            if torch.cuda.is_available():
                dataclass_to_cuda_(sample)
                device = torch.device("cuda")
            else:
                device = torch.device("cpu")

            if (
                not train_mode
                and hasattr(model, "sequence_len")
                and (sample.visibility[:, : model.sequence_len].sum() == 0)
            ):
                print(f"skipping batch {ind}")
                continue

            if "tapvid" in dataset_name:
                queries = sample.query_points.clone().float()

                queries = torch.stack(
                    [
                        queries[:, :, 0],
                        queries[:, :, 2],
                        queries[:, :, 1],
                    ],
                    dim=2,
                ).to(device)
            else:
                queries = torch.cat(
                    [
                        torch.zeros_like(sample.trajectory[:, 0, :, :1]),
                        sample.trajectory[:, 0],
                    ],
                    dim=2,
                ).to(device)

            pred_tracks = model(sample.video, queries)
            if "strided" in dataset_name:
                inv_video = sample.video.flip(1).clone()
                inv_queries = queries.clone()
                inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1

                pred_trj, pred_vsb = pred_tracks
                inv_pred_trj, inv_pred_vsb = model(inv_video, inv_queries)

                inv_pred_trj = inv_pred_trj.flip(1)
                inv_pred_vsb = inv_pred_vsb.flip(1)

                mask = pred_trj == 0

                pred_trj[mask] = inv_pred_trj[mask]
                pred_vsb[mask[:, :, :, 0]] = inv_pred_vsb[mask[:, :, :, 0]]

                pred_tracks = pred_trj, pred_vsb

            if dataset_name == "badja" or dataset_name == "fastcapture":
                seq_name = sample.seq_name[0]
            else:
                seq_name = str(ind)
            if ind % visualize_every == 0:
                vis.visualize(
                    sample.video,
                    pred_tracks[0] if isinstance(pred_tracks, tuple) else pred_tracks,
                    filename=dataset_name + "_" + seq_name,
                    writer=writer,
                    step=step,
                )

            self.compute_metrics(metrics, sample, pred_tracks, dataset_name)
        return metrics
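Evaluator.evaluate_sequence only assumes the model is a callable that takes (video, queries) and returns tracks or a (tracks, visibility) tuple. A sketch of wiring it to the EvaluationPredictor used by evaluate.py below; the constructor arguments mirror the DefaultConfig defaults, and the checkpoint path is an assumption:

import torch
from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.models.build_cotracker import build_cotracker
from cotracker.models.evaluation_predictor import EvaluationPredictor

# Sketch (not part of this commit).
model = build_cotracker("./checkpoints/cotracker2.pth")  # assumed checkpoint location
predictor = EvaluationPredictor(
    model, grid_size=5, local_grid_size=8, single_point=True, n_iters=6
)
if torch.cuda.is_available():
    predictor.model = predictor.model.cuda()

evaluator = Evaluator(exp_dir="./outputs/cotracker")
# `test_dataloader` would be the TapVid / Dynamic Replica loader built in evaluate.py below:
# metrics = evaluator.evaluate_sequence(predictor, test_dataloader, dataset_name="tapvid_davis_first")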
@@ -1,169 +1,169 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import json
import os
from dataclasses import dataclass, field

import hydra
import numpy as np

import torch
from omegaconf import OmegaConf

from cotracker.datasets.tap_vid_datasets import TapVidDataset
from cotracker.datasets.dr_dataset import DynamicReplicaDataset
from cotracker.datasets.utils import collate_fn

from cotracker.models.evaluation_predictor import EvaluationPredictor

from cotracker.evaluation.core.evaluator import Evaluator
from cotracker.models.build_cotracker import (
    build_cotracker,
)


@dataclass(eq=False)
class DefaultConfig:
    # Directory where all outputs of the experiment will be saved.
    exp_dir: str = "./outputs"

    # Name of the dataset to be used for the evaluation.
    dataset_name: str = "tapvid_davis_first"
    # The root directory of the dataset.
    dataset_root: str = "./"

    # Path to the pre-trained model checkpoint to be used for the evaluation.
    # The default value is the path to a specific CoTracker model checkpoint.
    checkpoint: str = "./checkpoints/cotracker2.pth"

    # EvaluationPredictor parameters
    # The size (N) of the support grid used in the predictor.
    # The total number of points is (N*N).
    grid_size: int = 5
    # The size (N) of the local support grid.
    local_grid_size: int = 8
    # A flag indicating whether to evaluate one ground truth point at a time.
    single_point: bool = True
    # The number of iterative updates for each sliding window.
    n_iters: int = 6

    seed: int = 0
    gpu_idx: int = 0

    # Override hydra's working directory to current working dir,
    # also disable storing the .hydra logs:
    hydra: dict = field(
        default_factory=lambda: {
            "run": {"dir": "."},
            "output_subdir": None,
        }
    )


def run_eval(cfg: DefaultConfig):
    """
    The function evaluates CoTracker on a specified benchmark dataset based on a provided configuration.

    Args:
        cfg (DefaultConfig): An instance of DefaultConfig class which includes:
            - exp_dir (str): The directory path for the experiment.
            - dataset_name (str): The name of the dataset to be used.
            - dataset_root (str): The root directory of the dataset.
            - checkpoint (str): The path to the CoTracker model's checkpoint.
            - single_point (bool): A flag indicating whether to evaluate one ground truth point at a time.
            - n_iters (int): The number of iterative updates for each sliding window.
            - seed (int): The seed for setting the random state for reproducibility.
            - gpu_idx (int): The index of the GPU to be used.
    """
    # Creating the experiment directory if it doesn't exist
    os.makedirs(cfg.exp_dir, exist_ok=True)

    # Saving the experiment configuration to a .yaml file in the experiment directory
    cfg_file = os.path.join(cfg.exp_dir, "expconfig.yaml")
    with open(cfg_file, "w") as f:
        OmegaConf.save(config=cfg, f=f)

    evaluator = Evaluator(cfg.exp_dir)
    cotracker_model = build_cotracker(cfg.checkpoint)

    # Creating the EvaluationPredictor object
    predictor = EvaluationPredictor(
        cotracker_model,
        grid_size=cfg.grid_size,
        local_grid_size=cfg.local_grid_size,
        single_point=cfg.single_point,
        n_iters=cfg.n_iters,
    )
    if torch.cuda.is_available():
        predictor.model = predictor.model.cuda()

    # Setting the random seeds
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    # Constructing the specified dataset
    curr_collate_fn = collate_fn
    if "tapvid" in cfg.dataset_name:
        dataset_type = cfg.dataset_name.split("_")[1]
        if dataset_type == "davis":
            data_root = os.path.join(cfg.dataset_root, "tapvid_davis", "tapvid_davis.pkl")
        elif dataset_type == "kinetics":
            data_root = os.path.join(
                cfg.dataset_root, "/kinetics/kinetics-dataset/k700-2020/tapvid_kinetics"
            )
        test_dataset = TapVidDataset(
            dataset_type=dataset_type,
            data_root=data_root,
            queried_first=not "strided" in cfg.dataset_name,
        )
    elif cfg.dataset_name == "dynamic_replica":
        test_dataset = DynamicReplicaDataset(sample_len=300, only_first_n_samples=1)

    # Creating the DataLoader object
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=14,
        collate_fn=curr_collate_fn,
    )

    # Timing and conducting the evaluation
    import time

    start = time.time()
    evaluate_result = evaluator.evaluate_sequence(
        predictor,
        test_dataloader,
        dataset_name=cfg.dataset_name,
    )
    end = time.time()
    print(end - start)

    # Saving the evaluation results to a .json file
    evaluate_result = evaluate_result["avg"]
    print("evaluate_result", evaluate_result)
    result_file = os.path.join(cfg.exp_dir, f"result_eval_.json")
    evaluate_result["time"] = end - start
    print(f"Dumping eval results to {result_file}.")
    with open(result_file, "w") as f:
        json.dump(evaluate_result, f)


cs = hydra.core.config_store.ConfigStore.instance()
|
||||||
cs.store(name="default_config_eval", node=DefaultConfig)
|
cs.store(name="default_config_eval", node=DefaultConfig)
|
||||||
|
|
||||||
|
|
||||||
@hydra.main(config_path="./configs/", config_name="default_config_eval")
|
@hydra.main(config_path="./configs/", config_name="default_config_eval")
|
||||||
def evaluate(cfg: DefaultConfig) -> None:
|
def evaluate(cfg: DefaultConfig) -> None:
|
||||||
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
|
||||||
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
os.environ["CUDA_VISIBLE_DEVICES"] = str(cfg.gpu_idx)
|
||||||
run_eval(cfg)
|
run_eval(cfg)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
evaluate()
|
evaluate()
|
||||||
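
For reference, a minimal sketch of driving the evaluation programmatically instead of through the Hydra entry point; the paths below are placeholders and it assumes every DefaultConfig field has a usable default:

# Illustrative only: exp_dir / checkpoint are hypothetical paths, and
# "tapvid_davis_first" is assumed to be a dataset_name accepted by the parsing
# logic in run_eval (it contains "tapvid" and no "strided").
cfg = DefaultConfig()
cfg.exp_dir = "./outputs/eval_davis"
cfg.dataset_name = "tapvid_davis_first"
cfg.dataset_root = "./datasets"
cfg.checkpoint = "./checkpoints/cotracker2.pth"
run_eval(cfg)
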
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

BIN  cotracker/models/__pycache__/__init__.cpython-38.pyc  Normal file
Binary file not shown.
BIN  cotracker/models/__pycache__/__init__.cpython-39.pyc  Normal file
Binary file not shown.
BIN  cotracker/models/__pycache__/build_cotracker.cpython-38.pyc  Normal file
Binary file not shown.
BIN  cotracker/models/__pycache__/build_cotracker.cpython-39.pyc  Normal file
Binary file not shown.
@@ -1,33 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from cotracker.models.core.cotracker.cotracker import CoTracker2


def build_cotracker(
    checkpoint: str,
):
    if checkpoint is None:
        return build_cotracker()
    model_name = checkpoint.split("/")[-1].split(".")[0]
    if model_name == "cotracker":
        return build_cotracker(checkpoint=checkpoint)
    else:
        raise ValueError(f"Unknown model name {model_name}")


def build_cotracker(checkpoint=None):
    cotracker = CoTracker2(stride=4, window_len=8, add_space_attn=True)

    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f, map_location="cpu")
            if "model" in state_dict:
                state_dict = state_dict["model"]
        cotracker.load_state_dict(state_dict)
    return cotracker
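
A minimal sketch of restoring the tracker from a saved checkpoint (the path below is a placeholder):

model = build_cotracker("./checkpoints/cotracker2.pth")  # hypothetical checkpoint path
model = model.eval()
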
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

BIN  cotracker/models/core/__pycache__/__init__.cpython-38.pyc  Normal file
Binary file not shown.
BIN  cotracker/models/core/__pycache__/__init__.cpython-39.pyc  Normal file
Binary file not shown.
BIN  cotracker/models/core/__pycache__/embeddings.cpython-38.pyc  Normal file
Binary file not shown.
BIN  cotracker/models/core/__pycache__/embeddings.cpython-39.pyc  Normal file
Binary file not shown.
BIN  cotracker/models/core/__pycache__/model_utils.cpython-38.pyc  Normal file
Binary file not shown.
BIN  cotracker/models/core/__pycache__/model_utils.cpython-39.pyc  Normal file
Binary file not shown.
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,367 +1,368 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
from typing import Callable
import collections
from torch import Tensor
from itertools import repeat

from cotracker.models.core.model_utils import bilinear_sampler


# From PyTorch internals
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


to_2tuple = _ntuple(2)


class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
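
Quick shape check for the MLP above (sizes are illustrative): it acts on the last dimension and keeps the token layout unchanged.

mlp = Mlp(in_features=256, hidden_features=1024)
tokens = torch.randn(2, 100, 256)   # (batch, tokens, channels)
out = mlp(tokens)                   # -> torch.Size([2, 100, 256])
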
class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            padding=1,
            stride=stride,
            padding_mode="zeros",
        )
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, padding_mode="zeros")
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == "group":
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
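
Illustrative shape check: with stride=2 the residual block halves the spatial resolution, and the 1x1 downsample keeps the skip path compatible.

block = ResidualBlock(in_planes=64, planes=96, norm_fn="instance", stride=2)
x = torch.randn(1, 64, 64, 64)
y = block(x)                        # -> torch.Size([1, 96, 32, 32])
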
class BasicEncoder(nn.Module):
    def __init__(self, input_dim=3, output_dim=128, stride=4):
        super(BasicEncoder, self).__init__()
        self.stride = stride
        self.norm_fn = "instance"
        self.in_planes = output_dim // 2

        self.norm1 = nn.InstanceNorm2d(self.in_planes)
        self.norm2 = nn.InstanceNorm2d(output_dim * 2)

        self.conv1 = nn.Conv2d(
            input_dim,
            self.in_planes,
            kernel_size=7,
            stride=2,
            padding=3,
            padding_mode="zeros",
        )
        self.relu1 = nn.ReLU(inplace=True)
        self.layer1 = self._make_layer(output_dim // 2, stride=1)
        self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
        self.layer3 = self._make_layer(output_dim, stride=2)
        self.layer4 = self._make_layer(output_dim, stride=2)

        self.conv2 = nn.Conv2d(
            output_dim * 3 + output_dim // 4,
            output_dim * 2,
            kernel_size=3,
            padding=1,
            padding_mode="zeros",
        )
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.InstanceNorm2d)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        _, _, H, W = x.shape

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        # four residual stages
        a = self.layer1(x)
        b = self.layer2(a)
        c = self.layer3(b)
        d = self.layer4(c)

        def _bilinear_intepolate(x):
            return F.interpolate(
                x,
                (H // self.stride, W // self.stride),
                mode="bilinear",
                align_corners=True,
            )

        a = _bilinear_intepolate(a)
        b = _bilinear_intepolate(b)
        c = _bilinear_intepolate(c)
        d = _bilinear_intepolate(d)

        x = self.conv2(torch.cat([a, b, c, d], dim=1))
        x = self.norm2(x)
        x = self.relu2(x)
        x = self.conv3(x)
        return x
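
With the default stride of 4, the encoder produces one output_dim-channel feature map at a quarter of the input resolution (sizes below are illustrative).

encoder = BasicEncoder(input_dim=3, output_dim=128, stride=4)
frames = torch.randn(1, 3, 384, 512)   # a single RGB frame
fmap = encoder(frames)                 # -> torch.Size([1, 128, 96, 128])
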
class CorrBlock:
    def __init__(
        self,
        fmaps,
        num_levels=4,
        radius=4,
        multiple_track_feats=False,
        padding_mode="zeros",
    ):
        B, S, C, H, W = fmaps.shape
        self.S, self.C, self.H, self.W = S, C, H, W
        self.padding_mode = padding_mode
        self.num_levels = num_levels
        self.radius = radius
        self.fmaps_pyramid = []
        self.multiple_track_feats = multiple_track_feats

        self.fmaps_pyramid.append(fmaps)
        for i in range(self.num_levels - 1):
            fmaps_ = fmaps.reshape(B * S, C, H, W)
            fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
            _, _, H, W = fmaps_.shape
            fmaps = fmaps_.reshape(B, S, C, H, W)
            self.fmaps_pyramid.append(fmaps)

    def sample(self, coords):
        r = self.radius
        B, S, N, D = coords.shape
        assert D == 2

        H, W = self.H, self.W
        out_pyramid = []
        for i in range(self.num_levels):
            corrs = self.corrs_pyramid[i]  # B, S, N, H, W
            *_, H, W = corrs.shape

            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(coords.device)

            centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corrs = bilinear_sampler(
                corrs.reshape(B * S * N, 1, H, W),
                coords_lvl,
                padding_mode=self.padding_mode,
            )
            corrs = corrs.view(B, S, N, -1)
            out_pyramid.append(corrs)

        out = torch.cat(out_pyramid, dim=-1)  # B, S, N, LRR*2
        out = out.permute(0, 2, 1, 3).contiguous().view(B * N, S, -1).float()
        return out

    def corr(self, targets):
        B, S, N, C = targets.shape
        if self.multiple_track_feats:
            targets_split = targets.split(C // self.num_levels, dim=-1)
            B, S, N, C = targets_split[0].shape

        assert C == self.C
        assert S == self.S

        fmap1 = targets

        self.corrs_pyramid = []
        for i, fmaps in enumerate(self.fmaps_pyramid):
            *_, H, W = fmaps.shape
            fmap2s = fmaps.view(B, S, C, H * W)  # B S C H W -> B S C (H W)
            if self.multiple_track_feats:
                fmap1 = targets_split[i]
            corrs = torch.matmul(fmap1, fmap2s)
            corrs = corrs.view(B, S, N, H, W)  # B S N (H W) -> B S N H W
            corrs = corrs / torch.sqrt(torch.tensor(C).float())
            self.corrs_pyramid.append(corrs)
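
The intended calling order is corr() first, to build the per-level cost volumes, and then sample() at the current track coordinates; a sketch with illustrative sizes:

fmaps = torch.randn(1, 8, 128, 96, 128)      # B, S, C, H, W feature maps
corr_block = CorrBlock(fmaps, num_levels=4, radius=3)

track_feats = torch.randn(1, 8, 16, 128)     # B, S, N, C features of 16 tracks
coords = torch.rand(1, 8, 16, 2) * 64        # B, S, N, 2 track positions
corr_block.corr(track_feats)                 # fills corr_block.corrs_pyramid
corr_feats = corr_block.sample(coords)       # -> torch.Size([16, 8, 196]), i.e. (B*N, S, levels*(2r+1)^2)
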
class Attention(nn.Module):
    def __init__(self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False):
        super().__init__()
        inner_dim = dim_head * num_heads
        context_dim = default(context_dim, query_dim)
        self.scale = dim_head**-0.5
        self.heads = num_heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
        self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
        self.to_out = nn.Linear(inner_dim, query_dim)

    def forward(self, x, context=None, attn_bias=None):
        B, N1, C = x.shape
        h = self.heads

        q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
        context = default(context, x)
        k, v = self.to_kv(context).chunk(2, dim=-1)

        N2 = context.shape[1]
        k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
        v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)

        sim = (q @ k.transpose(-2, -1)) * self.scale

        if attn_bias is not None:
            sim = sim + attn_bias
        attn = sim.softmax(dim=-1)

        x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
        return self.to_out(x)


class AttnBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = Attention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = attn_class(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x, mask=None):
        attn_bias = mask
        if mask is not None:
            mask = (
                (mask[:, None] * mask[:, :, None])
                .unsqueeze(1)
                .expand(-1, self.attn.num_heads, -1, -1)
            )
            max_neg_value = -torch.finfo(x.dtype).max
            attn_bias = (~mask) * max_neg_value
        x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
        x = x + self.mlp(self.norm2(x))
        return x
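
Usage sketch for the attention block (sizes illustrative); note that hidden_size must equal num_heads times the default dim_head of 48 for the reshape in Attention.forward to hold:

blk = AttnBlock(hidden_size=384, num_heads=8)   # 8 * 48 = 384
tokens = torch.randn(2, 100, 384)
out = blk(tokens)                               # -> torch.Size([2, 100, 384])
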
File diff suppressed because it is too large
@@ -1,61 +1,61 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from cotracker.models.core.model_utils import reduce_masked_mean

EPS = 1e-6


def balanced_ce_loss(pred, gt, valid=None):
    total_balanced_loss = 0.0
    for j in range(len(gt)):
        B, S, N = gt[j].shape
        # pred and gt are the same shape
        for (a, b) in zip(pred[j].size(), gt[j].size()):
            assert a == b  # some shape mismatch!
        # if valid is not None:
        for (a, b) in zip(pred[j].size(), valid[j].size()):
            assert a == b  # some shape mismatch!

        pos = (gt[j] > 0.95).float()
        neg = (gt[j] < 0.05).float()

        label = pos * 2.0 - 1.0
        a = -label * pred[j]
        b = F.relu(a)
        loss = b + torch.log(torch.exp(-b) + torch.exp(a - b))

        pos_loss = reduce_masked_mean(loss, pos * valid[j])
        neg_loss = reduce_masked_mean(loss, neg * valid[j])

        balanced_loss = pos_loss + neg_loss
        total_balanced_loss += balanced_loss / float(N)
    return total_balanced_loss


def sequence_loss(flow_preds, flow_gt, vis, valids, gamma=0.8):
    """Loss function defined over sequence of flow predictions"""
    total_flow_loss = 0.0
    for j in range(len(flow_gt)):
        B, S, N, D = flow_gt[j].shape
        assert D == 2
        B, S1, N = vis[j].shape
        B, S2, N = valids[j].shape
        assert S == S1
        assert S == S2
        n_predictions = len(flow_preds[j])
        flow_loss = 0.0
        for i in range(n_predictions):
            i_weight = gamma ** (n_predictions - i - 1)
            flow_pred = flow_preds[j][i]
            i_loss = (flow_pred - flow_gt[j]).abs()  # B, S, N, 2
            i_loss = torch.mean(i_loss, dim=3)  # B, S, N
            flow_loss += i_weight * reduce_masked_mean(i_loss, valids[j])
        flow_loss = flow_loss / n_predictions
        total_flow_loss += flow_loss / float(N)
    return total_flow_loss
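
The gamma term weights later refinement iterations more heavily; for example, with gamma=0.8 and four predictions per window the per-iteration weights are:

gamma, n_predictions = 0.8, 4
weights = [gamma ** (n_predictions - i - 1) for i in range(n_predictions)]
# -> [0.512, 0.640, 0.800, 1.000]: the final (most refined) prediction counts the most
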
@@ -1,120 +1,120 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from typing import Tuple, Union
import torch


def get_2d_sincos_pos_embed(
    embed_dim: int, grid_size: Union[int, Tuple[int, int]]
) -> torch.Tensor:
    """
    This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
    It is a wrapper of get_2d_sincos_pos_embed_from_grid.
    Args:
    - embed_dim: The embedding dimension.
    - grid_size: The grid size.
    Returns:
    - pos_embed: The generated 2D positional embedding.
    """
    if isinstance(grid_size, tuple):
        grid_size_h, grid_size_w = grid_size
    else:
        grid_size_h = grid_size_w = grid_size
    grid_h = torch.arange(grid_size_h, dtype=torch.float)
    grid_w = torch.arange(grid_size_w, dtype=torch.float)
    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
    grid = torch.stack(grid, dim=0)
    grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)


def get_2d_sincos_pos_embed_from_grid(
    embed_dim: int, grid: torch.Tensor
) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from a given grid using sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension.
    - grid: The grid to generate the embedding from.

    Returns:
    - emb: The generated 2D positional embedding.
    """
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = torch.cat([emb_h, emb_w], dim=2)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(
    embed_dim: int, pos: torch.Tensor
) -> torch.Tensor:
    """
    This function generates a 1D positional embedding from a given grid using sine and cosine functions.

    Args:
    - embed_dim: The embedding dimension.
    - pos: The position to generate the embedding from.

    Returns:
    - emb: The generated 1D positional embedding.
    """
    assert embed_dim % 2 == 0
    omega = torch.arange(embed_dim // 2, dtype=torch.double)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = torch.sin(out)  # (M, D/2)
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.cat([emb_sin, emb_cos], dim=1)  # (M, D)
    return emb[None].float()


def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
    """
    This function generates a 2D positional embedding from given coordinates using sine and cosine functions.

    Args:
    - xy: The coordinates to generate the embedding from.
    - C: The size of the embedding.
    - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.

    Returns:
    - pe: The generated 2D positional embedding.
    """
    B, N, D = xy.shape
    assert D == 2

    x = xy[:, :, 0:1]
    y = xy[:, :, 1:2]
    div_term = (
        torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)
    ).reshape(1, 1, int(C / 2))

    pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
    pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)

    pe_x[:, :, 0::2] = torch.sin(x * div_term)
    pe_x[:, :, 1::2] = torch.cos(x * div_term)

    pe_y[:, :, 0::2] = torch.sin(y * div_term)
    pe_y[:, :, 1::2] = torch.cos(y * div_term)

    pe = torch.cat([pe_x, pe_y], dim=2)  # (B, N, C*3)
    if cat_coords:
        pe = torch.cat([xy, pe], dim=2)  # (B, N, C*3+3)
    return pe
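
Illustrative shapes for the two entry points above:

pos = get_2d_sincos_pos_embed(embed_dim=128, grid_size=(24, 32))
# -> torch.Size([1, 128, 24, 32]): one 128-d sin/cos code per grid cell

xy = torch.rand(1, 16, 2) * 100.0                  # 16 (x, y) points
pe = get_2d_embedding(xy, C=64, cat_coords=True)
# -> torch.Size([1, 16, 130]): 64 dims for x, 64 for y, plus the raw coordinates
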
@@ -1,256 +1,256 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from typing import Optional, Tuple

EPS = 1e-6


def smart_cat(tensor1, tensor2, dim):
    if tensor1 is None:
        return tensor2
    return torch.cat([tensor1, tensor2], dim=dim)


def get_points_on_a_grid(
    size: int,
    extent: Tuple[float, ...],
    center: Optional[Tuple[float, ...]] = None,
    device: Optional[torch.device] = torch.device("cpu"),
):
    r"""Get a grid of points covering a rectangular region

    `get_points_on_a_grid(size, extent)` generates a :attr:`size` by
    :attr:`size` grid of points distributed to cover a rectangular area
    specified by `extent`.

    The `extent` is a pair of integer :math:`(H,W)` specifying the height
    and width of the rectangle.

    Optionally, the :attr:`center` can be specified as a pair :math:`(c_y,c_x)`
    specifying the vertical and horizontal center coordinates. The center
    defaults to the middle of the extent.

    Points are distributed uniformly within the rectangle leaving a margin
    :math:`m=W/64` from the border.

    It returns a :math:`(1, \text{size} \times \text{size}, 2)` tensor of
    points :math:`P_{ij}=(x_i, y_i)` where

    .. math::
        P_{ij} = \left(
             c_x + m -\frac{W}{2} + \frac{W - 2m}{\text{size} - 1}\, j,~
             c_y + m -\frac{H}{2} + \frac{H - 2m}{\text{size} - 1}\, i
        \right)

    Points are returned in row-major order.

    Args:
        size (int): grid size.
        extent (tuple): height and width of the grid extent.
        center (tuple, optional): grid center.
        device (str, optional): Defaults to `"cpu"`.

    Returns:
        Tensor: grid.
    """
    if size == 1:
        return torch.tensor([extent[1] / 2, extent[0] / 2], device=device)[None, None]

    if center is None:
        center = [extent[0] / 2, extent[1] / 2]

    margin = extent[1] / 64
    range_y = (margin - extent[0] / 2 + center[0], extent[0] / 2 + center[0] - margin)
    range_x = (margin - extent[1] / 2 + center[1], extent[1] / 2 + center[1] - margin)
    grid_y, grid_x = torch.meshgrid(
        torch.linspace(*range_y, size, device=device),
        torch.linspace(*range_x, size, device=device),
        indexing="ij",
    )
    return torch.stack([grid_x, grid_y], dim=-1).reshape(1, -1, 2)
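
A worked instance of the formula above: for a 10x10 grid on a 384x512 extent the margin is m = 512/64 = 8, so x runs from 8 to 504 and y from 8 to 376.

pts = get_points_on_a_grid(size=10, extent=(384, 512))
# pts.shape -> torch.Size([1, 100, 2]); pts[0, 0] is (8., 8.) and pts[0, -1] is (504., 376.)
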
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
    r"""Masked mean

    `reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
    over a mask :attr:`mask`, returning

    .. math::
        \text{output} =
        \frac
        {\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
        {\epsilon + \sum_{i=1}^N \text{mask}_i}

    where :math:`N` is the number of elements in :attr:`input` and
    :attr:`mask`, and :math:`\epsilon` is a small constant to avoid
    division by zero.

    `reduce_masked_mean(x, mask, dim)` computes the mean of a tensor
    :attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
    Optionally, the dimension can be kept in the output by setting
    :attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
    the same dimension as :attr:`input`.

    The interface is similar to `torch.mean()`.

    Args:
        input (Tensor): input tensor.
        mask (Tensor): mask.
        dim (int, optional): Dimension to sum over. Defaults to None.
        keepdim (bool, optional): Keep the summed dimension. Defaults to False.

    Returns:
        Tensor: mean tensor.
    """

    mask = mask.expand_as(input)

    prod = input * mask

    if dim is None:
        numer = torch.sum(prod)
        denom = torch.sum(mask)
    else:
        numer = torch.sum(prod, dim=dim, keepdim=keepdim)
        denom = torch.sum(mask, dim=dim, keepdim=keepdim)

    mean = numer / (EPS + denom)
    return mean
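
A tiny numerical check of the masked mean above:

values = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])
reduce_masked_mean(values, mask)   # -> tensor(1.5000), up to the EPS in the denominator
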
def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
    r"""Sample a tensor using bilinear interpolation

    `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
    coordinates :attr:`coords` using bilinear interpolation. It is the same
    as `torch.nn.functional.grid_sample()` but with a different coordinate
    convention.

    The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
    :math:`B` is the batch size, :math:`C` is the number of channels,
    :math:`H` is the height of the image, and :math:`W` is the width of the
    image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
    interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.

    Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
    in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
    that in this case the order of the components is slightly different
    from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.

    If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
    in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
    left-most image pixel and :math:`W-1` to the center of the right-most
    pixel.

    If `align_corners` is `False`, the coordinate :math:`x` is assumed to
    be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
    the left-most pixel and :math:`W` to the right edge of the right-most
    pixel.

    Similar conventions apply to :math:`y` for the ranges
    :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the ranges
    :math:`[0,T-1]` and :math:`[0,T]`.

    Args:
        input (Tensor): batch of input images.
        coords (Tensor): batch of coordinates.
        align_corners (bool, optional): Coordinate convention. Defaults to `True`.
        padding_mode (str, optional): Padding mode. Defaults to `"border"`.

    Returns:
        Tensor: sampled points.
    """

    sizes = input.shape[2:]

    assert len(sizes) in [2, 3]

    if len(sizes) == 3:
        # t x y -> x y t to match dimensions T H W in grid_sample
        coords = coords[..., [1, 2, 0]]

    if align_corners:
        coords = coords * torch.tensor(
            [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device
        )
    else:
        coords = coords * torch.tensor([2 / size for size in reversed(sizes)], device=coords.device)

    coords -= 1

    return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)

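# A minimal usage sketch (editorial, not part of the repository) illustrating the
# pixel-coordinate convention above: with align_corners=True, integer (x, y)
# coordinates land exactly on pixel centers, so sampling at (2, 1) returns
# input[..., 1, 2].
import torch
import torch.nn.functional as F

example_input = torch.arange(16.0).reshape(1, 1, 4, 4)    # (B, C, H, W)
example_coords = torch.tensor([[[[2.0, 1.0]]]])            # (B, H_o, W_o, 2), one (x, y) point
sampled = bilinear_sampler(example_input, example_coords)
assert torch.allclose(sampled, example_input[:, :, 1, 2].reshape(1, 1, 1, 1))  # value 6.0
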
def sample_features4d(input, coords):
    r"""Sample spatial features

    `sample_features4d(input, coords)` samples the spatial features
    :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.

    The field is sampled at coordinates :attr:`coords` using bilinear
    interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
    2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
    same convention as :func:`bilinear_sampler` with `align_corners=True`.

    The output tensor has one feature per point, and has shape :math:`(B,
    R, C)`.

    Args:
        input (Tensor): spatial features.
        coords (Tensor): points.

    Returns:
        Tensor: sampled features.
    """

    B, _, _, _ = input.shape

    # B R 2 -> B R 1 2
    coords = coords.unsqueeze(2)

    # B C R 1
    feats = bilinear_sampler(input, coords)

    return feats.permute(0, 2, 1, 3).view(
        B, -1, feats.shape[1] * feats.shape[3]
    )  # B C R 1 -> B R C

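# Illustrative shape check (an editor's sketch, not repository code): R per-point
# features pulled from a (B, C, H, W) feature map come back as (B, R, C).
import torch

feat_map = torch.randn(2, 128, 48, 64)                                 # (B, C, H, W)
pts_xy = torch.tensor([[[10.0, 20.0], [3.5, 7.25]]]).repeat(2, 1, 1)   # (B, R, 2) as (x, y)
per_point = sample_features4d(feat_map, pts_xy)
assert per_point.shape == (2, 2, 128)                                  # (B, R, C)
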
def sample_features5d(input, coords):
    r"""Sample spatio-temporal features

    `sample_features5d(input, coords)` works in the same way as
    :func:`sample_features4d` but for spatio-temporal features and points:
    :attr:`input` is a 5D tensor :math:`(B, T, C, H, W)`, :attr:`coords` is
    a :math:`(B, R1, R2, 3)` tensor of spatio-temporal points :math:`(t_i,
    x_i, y_i)`. The output tensor has shape :math:`(B, R1, R2, C)`.

    Args:
        input (Tensor): spatio-temporal features.
        coords (Tensor): spatio-temporal points.

    Returns:
        Tensor: sampled features.
    """

    B, T, _, _, _ = input.shape

    # B T C H W -> B C T H W
    input = input.permute(0, 2, 1, 3, 4)

    # B R1 R2 3 -> B R1 R2 1 3
    coords = coords.unsqueeze(3)

    # B C R1 R2 1
    feats = bilinear_sampler(input, coords)

    return feats.permute(0, 2, 3, 1, 4).view(
        B, feats.shape[2], feats.shape[3], feats.shape[1]
    )  # B C R1 R2 1 -> B R1 R2 C

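# A companion sketch for the spatio-temporal case (illustrative only): query
# points are (t, x, y) triplets and the features come back as (B, R1, R2, C).
import torch

st_feats = torch.randn(1, 8, 64, 32, 32)     # (B, T, C, H, W)
st_pts = torch.zeros(1, 3, 5, 3)             # (B, R1, R2, 3) as (t, x, y)
st_pts[..., 0], st_pts[..., 1], st_pts[..., 2] = 4.0, 16.0, 8.0
assert sample_features5d(st_feats, st_pts).shape == (1, 3, 5, 64)
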
@@ -1,104 +1,104 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F
from typing import Tuple

from cotracker.models.core.cotracker.cotracker import CoTracker2
from cotracker.models.core.model_utils import get_points_on_a_grid


class EvaluationPredictor(torch.nn.Module):
    def __init__(
        self,
        cotracker_model: CoTracker2,
        interp_shape: Tuple[int, int] = (384, 512),
        grid_size: int = 5,
        local_grid_size: int = 8,
        single_point: bool = True,
        n_iters: int = 6,
    ) -> None:
        super(EvaluationPredictor, self).__init__()
        self.grid_size = grid_size
        self.local_grid_size = local_grid_size
        self.single_point = single_point
        self.interp_shape = interp_shape
        self.n_iters = n_iters

        self.model = cotracker_model
        self.model.eval()

    def forward(self, video, queries):
        queries = queries.clone()
        B, T, C, H, W = video.shape
        B, N, D = queries.shape

        assert D == 3

        video = video.reshape(B * T, C, H, W)
        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
        video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])

        device = video.device

        queries[:, :, 1] *= (self.interp_shape[1] - 1) / (W - 1)
        queries[:, :, 2] *= (self.interp_shape[0] - 1) / (H - 1)

        if self.single_point:
            traj_e = torch.zeros((B, T, N, 2), device=device)
            vis_e = torch.zeros((B, T, N), device=device)
            for pind in range(N):
                query = queries[:, pind : pind + 1]

                t = query[0, 0, 0].long()

                traj_e_pind, vis_e_pind = self._process_one_point(video, query)
                traj_e[:, t:, pind : pind + 1] = traj_e_pind[:, :, :1]
                vis_e[:, t:, pind : pind + 1] = vis_e_pind[:, :, :1]
        else:
            if self.grid_size > 0:
                xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
                xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)  #
                queries = torch.cat([queries, xy], dim=1)  #

            traj_e, vis_e, __ = self.model(
                video=video,
                queries=queries,
                iters=self.n_iters,
            )

        traj_e[:, :, :, 0] *= (W - 1) / float(self.interp_shape[1] - 1)
        traj_e[:, :, :, 1] *= (H - 1) / float(self.interp_shape[0] - 1)
        return traj_e, vis_e

    def _process_one_point(self, video, query):
        t = query[0, 0, 0].long()

        device = query.device
        if self.local_grid_size > 0:
            xy_target = get_points_on_a_grid(
                self.local_grid_size,
                (50, 50),
                [query[0, 0, 2].item(), query[0, 0, 1].item()],
            )

            xy_target = torch.cat([torch.zeros_like(xy_target[:, :, :1]), xy_target], dim=2).to(
                device
            )  #
            query = torch.cat([query, xy_target], dim=1)  #

        if self.grid_size > 0:
            xy = get_points_on_a_grid(self.grid_size, video.shape[3:])
            xy = torch.cat([torch.zeros_like(xy[:, :, :1]), xy], dim=2).to(device)  #
            query = torch.cat([query, xy], dim=1)  #
        # crop the video to start from the queried frame
        query[0, 0, 0] = 0
        traj_e_pind, vis_e_pind, __ = self.model(
            video=video[:, t:], queries=query, iters=self.n_iters
        )

        return traj_e_pind, vis_e_pind

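# Worked example (an editor's illustration) of the coordinate rescaling used in
# forward() above: with align_corners=True, pixel x = W - 1 in the source video
# must map to x = interp_W - 1 in the resized video, hence the
# (interp - 1) / (orig - 1) factors; predicted tracks are mapped back with the inverse.
W, H = 1280, 720
interp_H, interp_W = 384, 512

x = 1279.0                                      # right-most pixel centre of the source frame
x_resized = x * (interp_W - 1) / (W - 1)        # -> 511.0, right-most pixel of the resized frame
x_back = x_resized * (W - 1) / (interp_W - 1)   # -> 1279.0 again
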
@@ -1,275 +1,279 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn.functional as F

from cotracker.models.core.model_utils import smart_cat, get_points_on_a_grid
from cotracker.models.build_cotracker import build_cotracker


class CoTrackerPredictor(torch.nn.Module):
    def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
        super().__init__()
        self.support_grid_size = 6
        model = build_cotracker(checkpoint)
        self.interp_shape = model.model_resolution
        print(self.interp_shape)
        self.model = model
        self.model.eval()

    @torch.no_grad()
    def forward(
        self,
        video,  # (B, T, 3, H, W) Batch_size, time, rgb, height, width
        # input prompt types:
        # - None. Dense tracks are computed in this case. You can adjust *query_frame* to compute tracks starting from a specific frame.
        # *backward_tracking=True* will compute tracks in both directions.
        # - queries. Queried points of shape (B, N, 3) in format (t, x, y) for frame index and pixel coordinates.
        # - grid_size. Grid of N*N points from the first frame. if segm_mask is provided, then computed only for the mask.
        # You can adjust *query_frame* and *backward_tracking* for the regular grid in the same way as for dense tracks.
        queries: torch.Tensor = None,
        segm_mask: torch.Tensor = None,  # Segmentation mask of shape (B, 1, H, W)
        grid_size: int = 0,
        grid_query_frame: int = 0,  # only for dense and regular grid tracks
        backward_tracking: bool = False,
    ):
        if queries is None and grid_size == 0:
            tracks, visibilities = self._compute_dense_tracks(
                video,
                grid_query_frame=grid_query_frame,
                backward_tracking=backward_tracking,
            )
        else:
            tracks, visibilities = self._compute_sparse_tracks(
                video,
                queries,
                segm_mask,
                grid_size,
                add_support_grid=(grid_size == 0 or segm_mask is not None),
                grid_query_frame=grid_query_frame,
                backward_tracking=backward_tracking,
            )

        return tracks, visibilities

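# A hypothetical usage sketch for the "queries" prompt type described in the
# signature above (the checkpoint path and the random video are placeholders,
# not repository code): each query row is (t, x, y), i.e. a start-frame index
# plus pixel coordinates in the original video resolution.
import torch

demo_video = torch.randn(1, 48, 3, 480, 640)         # (B, T, 3, H, W)
demo_queries = torch.tensor(
    [
        [0.0, 320.0, 240.0],                          # track from frame 0 at pixel (320, 240)
        [10.0, 100.0, 50.0],                          # track from frame 10 at pixel (100, 50)
    ]
)[None]                                               # (B, N, 3)
# predictor = CoTrackerPredictor(checkpoint="./checkpoints/cotracker2.pth")
# tracks, visibilities = predictor(demo_video, queries=demo_queries)  # (B, T, N, 2), (B, T, N)
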
    # gpu dense inference time
    # raft gpu comparison
    # vision effects
    # raft integrated
    def _compute_dense_tracks(self, video, grid_query_frame, grid_size=80, backward_tracking=False):
        *_, H, W = video.shape
        grid_step = W // grid_size
        grid_width = W // grid_step
        grid_height = H // grid_step  # set the whole video to grid_size number of grids
        tracks = visibilities = None
        grid_pts = torch.zeros((1, grid_width * grid_height, 3)).to(video.device)
        # (batch_size, grid_number, t,x,y)
        grid_pts[0, :, 0] = grid_query_frame
        # iterate every grid
        for offset in range(grid_step * grid_step):
            print(f"step {offset} / {grid_step * grid_step}")
            ox = offset % grid_step
            oy = offset // grid_step
            # initialize
            # for example
            # grid width = 4, grid height = 4, grid step = 10, ox = 1
            # torch.arange(grid_width) = [0,1,2,3]
            # torch.arange(grid_width).repeat(grid_height) = [0,1,2,3,0,1,2,3,0,1,2,3]
            # torch.arange(grid_width).repeat(grid_height) * grid_step = [0,10,20,30,0,10,20,30,0,10,20,30]
            # get the location in the image
            grid_pts[0, :, 1] = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
            grid_pts[0, :, 2] = (
                torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
            )
            tracks_step, visibilities_step = self._compute_sparse_tracks(
                video=video,
                queries=grid_pts,
                backward_tracking=backward_tracking,
            )
            tracks = smart_cat(tracks, tracks_step, dim=2)
            visibilities = smart_cat(visibilities, visibilities_step, dim=2)

        return tracks, visibilities

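# Standalone illustration (an editor's sketch mirroring the inline example
# above) of how one offset of the dense grid is laid out, for grid_width = 4,
# grid_height = 4, grid_step = 10, ox = 1, oy = 0:
import torch

grid_width, grid_height, grid_step, ox, oy = 4, 4, 10, 1, 0
xs = torch.arange(grid_width).repeat(grid_height) * grid_step + ox
ys = torch.arange(grid_height).repeat_interleave(grid_width) * grid_step + oy
# xs -> tensor([ 1, 11, 21, 31,  1, 11, 21, 31,  1, 11, 21, 31,  1, 11, 21, 31])
# ys -> tensor([ 0,  0,  0,  0, 10, 10, 10, 10, 20, 20, 20, 20, 30, 30, 30, 30])
# i.e. a grid_height x grid_width lattice of pixel locations shifted by (ox, oy).
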
    def _compute_sparse_tracks(
        self,
        video,
        queries,
        segm_mask=None,
        grid_size=0,
        add_support_grid=False,
        grid_query_frame=0,
        backward_tracking=False,
    ):
        B, T, C, H, W = video.shape

        video = video.reshape(B * T, C, H, W)
        # ? what is interpolate?
        # resize the video to interp_shape via interpolation?
        video = F.interpolate(video, tuple(self.interp_shape), mode="bilinear", align_corners=True)
        video = video.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])

        if queries is not None:
            B, N, D = queries.shape  # batch_size, number of points, (t,x,y)
            assert D == 3
            # scale the queries by (interp_shape - 1) / (W - 1)
            # i.e. rescale them after the video has been resized
            queries = queries.clone()
            queries[:, :, 1:] *= queries.new_tensor(
                [
                    (self.interp_shape[1] - 1) / (W - 1),
                    (self.interp_shape[0] - 1) / (H - 1),
                ]
            )
        # generate the grid
        elif grid_size > 0:
            grid_pts = get_points_on_a_grid(grid_size, self.interp_shape, device=video.device)
            if segm_mask is not None:
                segm_mask = F.interpolate(segm_mask, tuple(self.interp_shape), mode="nearest")
                point_mask = segm_mask[0, 0][
                    (grid_pts[0, :, 1]).round().long().cpu(),
                    (grid_pts[0, :, 0]).round().long().cpu(),
                ].bool()
                grid_pts = grid_pts[:, point_mask]

            queries = torch.cat(
                [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
                dim=2,
            ).repeat(B, 1, 1)

        # add support points
        if add_support_grid:
            grid_pts = get_points_on_a_grid(
                self.support_grid_size, self.interp_shape, device=video.device
            )
            grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
            grid_pts = grid_pts.repeat(B, 1, 1)
            queries = torch.cat([queries, grid_pts], dim=1)

        tracks, visibilities, __ = self.model.forward(video=video, queries=queries, iters=6)

        if backward_tracking:
            tracks, visibilities = self._compute_backward_tracks(
                video, queries, tracks, visibilities
            )
            if add_support_grid:
                queries[:, -self.support_grid_size**2 :, 0] = T - 1
        if add_support_grid:
            tracks = tracks[:, :, : -self.support_grid_size**2]
            visibilities = visibilities[:, :, : -self.support_grid_size**2]
        thr = 0.9
        visibilities = visibilities > thr

        # correct query-point predictions
        # see https://github.com/facebookresearch/co-tracker/issues/28

        # TODO: batchify
        for i in range(len(queries)):
            queries_t = queries[i, : tracks.size(2), 0].to(torch.int64)
            arange = torch.arange(0, len(queries_t))

            # overwrite the predictions with the query points
            tracks[i, queries_t, arange] = queries[i, : tracks.size(2), 1:]

            # correct visibilities, the query points should be visible
            visibilities[i, queries_t, arange] = True

        tracks *= tracks.new_tensor(
            [(W - 1) / (self.interp_shape[1] - 1), (H - 1) / (self.interp_shape[0] - 1)]
        )
        return tracks, visibilities

    def _compute_backward_tracks(self, video, queries, tracks, visibilities):
        inv_video = video.flip(1).clone()
        inv_queries = queries.clone()
        inv_queries[:, :, 0] = inv_video.shape[1] - inv_queries[:, :, 0] - 1

        inv_tracks, inv_visibilities, __ = self.model(video=inv_video, queries=inv_queries, iters=6)

        inv_tracks = inv_tracks.flip(1)
        inv_visibilities = inv_visibilities.flip(1)
        arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]

        mask = (arange < queries[:, None, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)

        tracks[mask] = inv_tracks[mask]
        visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
        return tracks, visibilities

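# Small sketch (editorial, with toy numbers) of the time-reversal bookkeeping in
# _compute_backward_tracks: a query at frame t becomes frame T - 1 - t in the
# flipped clip, and the boolean mask selects the frames before each query frame,
# where only the backward pass produces estimates.
import torch

T = 8
query_frames = torch.tensor([[0.0, 3.0, 6.0]])        # (B, N) query frame indices
inv_query_frames = T - query_frames - 1                # -> [[7., 4., 1.]]
arange = torch.arange(T)[None, :, None]                # (B, T, 1)
mask = arange < query_frames[:, None, :]               # (B, T, N)
# mask[0, :, 1] is True for frames 0..2 and False from frame 3 onwards.
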
class CoTrackerOnlinePredictor(torch.nn.Module):
    def __init__(self, checkpoint="./checkpoints/cotracker2.pth"):
        super().__init__()
        self.support_grid_size = 6
        model = build_cotracker(checkpoint)
        self.interp_shape = model.model_resolution
        self.step = model.window_len // 2
        self.model = model
        self.model.eval()

    @torch.no_grad()
    def forward(
        self,
        video_chunk,
        is_first_step: bool = False,
        queries: torch.Tensor = None,
        grid_size: int = 10,
        grid_query_frame: int = 0,
        add_support_grid=False,
    ):
        B, T, C, H, W = video_chunk.shape
        # Initialize online video processing and save queried points
        # This needs to be done before processing *each new video*
        if is_first_step:
            self.model.init_video_online_processing()
            if queries is not None:
                B, N, D = queries.shape
                assert D == 3
                queries = queries.clone()
                queries[:, :, 1:] *= queries.new_tensor(
                    [
                        (self.interp_shape[1] - 1) / (W - 1),
                        (self.interp_shape[0] - 1) / (H - 1),
                    ]
                )
            elif grid_size > 0:
                grid_pts = get_points_on_a_grid(
                    grid_size, self.interp_shape, device=video_chunk.device
                )
                queries = torch.cat(
                    [torch.ones_like(grid_pts[:, :, :1]) * grid_query_frame, grid_pts],
                    dim=2,
                )
            if add_support_grid:
                grid_pts = get_points_on_a_grid(
                    self.support_grid_size, self.interp_shape, device=video_chunk.device
                )
                grid_pts = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)
                queries = torch.cat([queries, grid_pts], dim=1)
            self.queries = queries
            return (None, None)

        video_chunk = video_chunk.reshape(B * T, C, H, W)
        video_chunk = F.interpolate(
            video_chunk, tuple(self.interp_shape), mode="bilinear", align_corners=True
        )
        video_chunk = video_chunk.reshape(B, T, 3, self.interp_shape[0], self.interp_shape[1])

        tracks, visibilities, __ = self.model(
            video=video_chunk,
            queries=self.queries,
            iters=6,
            is_online=True,
        )
        thr = 0.9
        return (
            tracks
            * tracks.new_tensor(
                [
                    (W - 1) / (self.interp_shape[1] - 1),
                    (H - 1) / (self.interp_shape[0] - 1),
                ]
            ),
            visibilities > thr,
        )

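# A hedged driver sketch for CoTrackerOnlinePredictor (editorial; the frame
# iterator and helper name are assumptions, not repository code): the first call
# only registers the queries, and later calls process a sliding window that
# advances by model.window_len // 2 frames.
import torch

def run_online_tracking(model, frame_iterator, grid_size=10):
    # model: a CoTrackerOnlinePredictor instance; frame_iterator yields (C, H, W) float tensors.
    window_frames, is_first_step = [], True
    pred_tracks = pred_visibility = None
    for i, frame in enumerate(frame_iterator):
        window_frames.append(frame)
        if i % model.step == 0 and i != 0:
            video_chunk = torch.stack(window_frames[-model.step * 2 :])[None]  # (1, T, C, H, W)
            pred_tracks, pred_visibility = model(
                video_chunk, is_first_step=is_first_step, grid_size=grid_size
            )
            is_first_step = False
    return pred_tracks, pred_visibility
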
@@ -1,5 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

BIN cotracker/utils/__pycache__/__init__.cpython-38.pyc Normal file (binary file not shown)
BIN cotracker/utils/__pycache__/__init__.cpython-39.pyc Normal file (binary file not shown)
BIN cotracker/utils/__pycache__/visualizer.cpython-38.pyc Normal file (binary file not shown)
BIN cotracker/utils/__pycache__/visualizer.cpython-39.pyc Normal file (binary file not shown)
@@ -1,343 +1,343 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import imageio
import torch

from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw


def read_video_from_path(path):
    try:
        reader = imageio.get_reader(path)
    except Exception as e:
        print("Error opening video file: ", e)
        return None
    frames = []
    for i, im in enumerate(reader):
        frames.append(np.array(im))
    return np.stack(frames)


def draw_circle(rgb, coord, radius, color=(255, 0, 0), visible=True):
    # Create a draw object
    draw = ImageDraw.Draw(rgb)
    # Calculate the bounding box of the circle
    left_up_point = (coord[0] - radius, coord[1] - radius)
    right_down_point = (coord[0] + radius, coord[1] + radius)
    # Draw the circle
    draw.ellipse(
        [left_up_point, right_down_point],
        fill=tuple(color) if visible else None,
        outline=tuple(color),
    )
    return rgb


def draw_line(rgb, coord_y, coord_x, color, linewidth):
    draw = ImageDraw.Draw(rgb)
    draw.line(
        (coord_y[0], coord_y[1], coord_x[0], coord_x[1]),
        fill=tuple(color),
        width=linewidth,
    )
    return rgb


def add_weighted(rgb, alpha, original, beta, gamma):
    return (rgb * alpha + original * beta + gamma).astype("uint8")

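# Quick illustrative check of the drawing helpers above (an editor's sketch):
# they take a PIL image, draw the primitive, and return the image; add_weighted
# then blends two numpy frames into a uint8 result.
import numpy as np
from PIL import Image

canvas = Image.fromarray(np.zeros((100, 100, 3), dtype=np.uint8))
canvas = draw_circle(canvas, coord=(50, 50), radius=5, color=(0, 255, 0), visible=True)
canvas = draw_line(canvas, (10, 10), (90, 90), (255, 0, 0), linewidth=2)
blended = add_weighted(np.array(canvas), 0.7, np.zeros((100, 100, 3)), 0.3, 0)
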
class Visualizer:
    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",  # 'cool', 'optical_flow'
        linewidth: int = 2,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,  # -1 for infinite
    ):
        self.mode = mode
        self.save_dir = save_dir
        if mode == "rainbow":
            self.color_map = cm.get_cmap("gist_rainbow")
        elif mode == "cool":
            self.color_map = cm.get_cmap(mode)
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # (B,1,H,W)
        filename: str = "video",
        writer=None,  # tensorboard Summary Writer, used for visualization during training
        step: int = 0,
        query_frame: int = 0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
    ):
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()

        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )
        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            video = video.repeat(1, 1, 3, 1, 1)

        res_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
        )
        if save_video:
            self.save_video(res_video, filename=filename, writer=writer, step=step)
        return res_video

    def save_video(self, video, filename, writer=None, step=0):
        if writer is not None:
            writer.add_video(
                filename,
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
        else:
            os.makedirs(self.save_dir, exist_ok=True)
            wide_list = list(video.unbind(1))
            wide_list = [wide[0].permute(1, 2, 0).cpu().numpy() for wide in wide_list]

            # Prepare the video file path
            save_path = os.path.join(self.save_dir, f"{filename}.mp4")

            # Create a writer object
            video_writer = imageio.get_writer(save_path, fps=self.fps)

            # Write frames to the video file
            for frame in wide_list[2:-1]:
                video_writer.append_data(frame)

            video_writer.close()

            print(f"Video saved to {save_path}")

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame: int = 0,
        compensate_for_camera_motion=False,
    ):
        B, T, C, H, W = video.shape
        _, _, N, D = tracks.shape

        assert D == 2
        assert C == 3
        video = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # S, H, W, C
        tracks = tracks[0].long().detach().cpu().numpy()  # S, N, 2
        if gt_tracks is not None:
            gt_tracks = gt_tracks[0].detach().cpu().numpy()

        res_video = []

        # process input video
        for rgb in video:
            res_video.append(rgb.copy())
        vector_colors = np.zeros((T, N, 3))

        if self.mode == "optical_flow":
            import flow_vis

            vector_colors = flow_vis.flow_to_color(tracks - tracks[query_frame][None])
        elif segm_mask is None:
            if self.mode == "rainbow":
                y_min, y_max = (
                    tracks[query_frame, :, 1].min(),
                    tracks[query_frame, :, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    color = self.color_map(norm(tracks[query_frame, n, 1]))
                    color = np.array(color[:3])[None] * 255
                    vector_colors[:, n] = np.repeat(color, T, axis=0)
            else:
                # color changes with time
                for t in range(T):
                    color = np.array(self.color_map(t / T)[:3])[None] * 255
                    vector_colors[t] = np.repeat(color, N, axis=0)
        else:
            if self.mode == "rainbow":
                vector_colors[:, segm_mask <= 0, :] = 255

                y_min, y_max = (
                    tracks[0, segm_mask > 0, 1].min(),
                    tracks[0, segm_mask > 0, 1].max(),
                )
                norm = plt.Normalize(y_min, y_max)
                for n in range(N):
                    if segm_mask[n] > 0:
                        color = self.color_map(norm(tracks[0, n, 1]))
                        color = np.array(color[:3])[None] * 255
                        vector_colors[:, n] = np.repeat(color, T, axis=0)

            else:
                # color changes with segm class
                segm_mask = segm_mask.cpu()
                color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
                color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
                color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
                vector_colors = np.repeat(color[None], T, axis=0)

        # draw tracks
        if self.tracks_leave_trace != 0:
            for t in range(query_frame + 1, T):
                first_ind = (
                    max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
                )
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]

                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]

                res_video[t] = self._draw_pred_tracks(
                    res_video[t],
                    curr_tracks,
                    curr_colors,
                )
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])

        # draw points
        for t in range(query_frame, T):
            img = Image.fromarray(np.uint8(res_video[t]))
            for i in range(N):
                coord = (tracks[t, i, 0], tracks[t, i, 1])
                visibile = True
                if visibility is not None:
                    visibile = visibility[0, t, i]
                if coord[0] != 0 and coord[1] != 0:
                    if not compensate_for_camera_motion or (
                        compensate_for_camera_motion and segm_mask[i] > 0
                    ):
                        img = draw_circle(
                            img,
                            coord=coord,
                            radius=int(self.linewidth * 2),
                            color=vector_colors[t, i].astype(int),
                            visible=visibile,
                        )
            res_video[t] = np.array(img)

        # construct the final rgb sequence
        if self.show_first_frame > 0:
            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x 2
        vector_colors: np.ndarray,
        alpha: float = 0.5,
    ):
        T, N, _ = tracks.shape
        rgb = Image.fromarray(np.uint8(rgb))
        for s in range(T - 1):
            vector_color = vector_colors[s]
            original = rgb.copy()
            alpha = (s / T) ** 2
            for i in range(N):
                coord_y = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                coord_x = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if coord_y[0] != 0 and coord_y[1] != 0:
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        vector_color[i].astype(int),
                        self.linewidth,
                    )
            if self.tracks_leave_trace > 0:
                rgb = Image.fromarray(
                    np.uint8(add_weighted(np.array(rgb), alpha, np.array(original), 1 - alpha, 0))
                )
        rgb = np.array(rgb)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3,
        gt_tracks: np.ndarray,  # T x 2
    ):
        T, N, _ = gt_tracks.shape
        color = np.array((211, 0, 0))
        rgb = Image.fromarray(np.uint8(rgb))
        for t in range(T):
            for i in range(N):
                gt_track = gt_tracks[t][i]
                # draw a red cross
                if gt_track[0] > 0 and gt_track[1] > 0:
                    length = self.linewidth * 3
                    coord_y = (int(gt_track[0]) + length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) - length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
                    coord_y = (int(gt_track[0]) - length, int(gt_track[1]) + length)
                    coord_x = (int(gt_track[0]) + length, int(gt_track[1]) - length)
                    rgb = draw_line(
                        rgb,
                        coord_y,
                        coord_x,
                        color,
                        self.linewidth,
                    )
        rgb = np.array(rgb)
        return rgb

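# A minimal end-to-end sketch (editorial; random tensors stand in for a real clip
# and real CoTracker outputs, and saving the result assumes imageio's ffmpeg
# backend is available):
import torch

vis = Visualizer(save_dir="./results", pad_value=16, linewidth=2)
toy_video = torch.rand(1, 12, 3, 256, 256) * 255           # (B, T, C, H, W), values in [0, 255]
toy_tracks = torch.rand(1, 12, 5, 2) * 255                  # (B, T, N, 2) pixel coordinates
toy_visibility = torch.ones(1, 12, 5, 1, dtype=torch.bool)  # (B, T, N, 1)
frames = vis.visualize(toy_video, toy_tracks, toy_visibility, filename="toy_example")
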
@@ -1,8 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


__version__ = "2.0.0"

File diff suppressed because one or more lines are too long