# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import os
import gzip
import torch
import numpy as np
import torch.utils.data as data
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Optional, Any, Dict, Tuple

from cotracker.datasets.utils import CoTrackerData
from cotracker.datasets.dataclass_utils import load_dataclass
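

# Annotation schema for the gzipped json files (`frame_annotations_{split}.jgz`):
# each entry describes a single camera frame of a Dynamic Replica sequence.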
@dataclass
class ImageAnnotation:
    # path to jpg file, relative w.r.t. dataset_root
    path: str
    # H x W
    size: Tuple[int, int]


@dataclass
class DynamicReplicaFrameAnnotation:
    """A dataclass used to load annotations from json."""

    # can be used to join with `SequenceAnnotation`
    sequence_name: str
    # 0-based, continuous frame number within sequence
    frame_number: int
    # timestamp in seconds from the video start
    frame_timestamp: float

    image: ImageAnnotation
    meta: Optional[Dict[str, Any]] = None

    camera_name: Optional[str] = None
    # loaded from json as a dict with a "path" key (see __getitem__ below),
    # so this is annotated as a dict rather than a str
    trajectories: Optional[Dict[str, Any]] = None


class DynamicReplicaDataset(data.Dataset):
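    """Clips of Dynamic Replica frames with 2D point trajectories.

    Argument summary (inferred from the loader code below):
        root: dataset root directory.
        split: split name; selects the `{root}/{split}` subdirectory and the
            `frame_annotations_{split}.jgz` annotation file.
        traj_per_sample: maximum number of trajectories randomly kept per clip.
        crop_size: optional (H, W); frames and trajectories are center-cropped.
        sample_len: clip length in frames; values <= 0 keep whole sequences.
        only_first_n_samples: if > 0, keep at most this many clips per sequence.
        rgbd_input: stored on the instance; unused in this loader.
    """
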
    def __init__(
        self,
        root,
        split="valid",
        traj_per_sample=256,
        crop_size=None,
        sample_len=-1,
        only_first_n_samples=-1,
        rgbd_input=False,
    ):
        super(DynamicReplicaDataset, self).__init__()
        self.root = root
        self.sample_len = sample_len
        self.split = split
        self.traj_per_sample = traj_per_sample
        self.rgbd_input = rgbd_input
        self.crop_size = crop_size
        frame_annotations_file = f"frame_annotations_{split}.jgz"
        self.sample_list = []
        # the annotation file is a gzipped json list of frame annotations
        with gzip.open(
            os.path.join(root, split, frame_annotations_file), "rt", encoding="utf8"
        ) as zipfile:
            frame_annots_list = load_dataclass(zipfile, List[DynamicReplicaFrameAnnotation])
        # group frames by sequence, keeping only the left camera of each stereo pair
        seq_annot = defaultdict(list)
        for frame_annot in frame_annots_list:
            if frame_annot.camera_name == "left":
                seq_annot[frame_annot.sequence_name].append(frame_annot)

        # split every sequence into non-overlapping clips of `sample_len` frames
        # (or keep the whole sequence when sample_len <= 0)
        for seq_name in seq_annot.keys():
            seq_len = len(seq_annot[seq_name])

            step = self.sample_len if self.sample_len > 0 else seq_len
            counter = 0

            for ref_idx in range(0, seq_len, step):
                sample = seq_annot[seq_name][ref_idx : ref_idx + step]
                self.sample_list.append(sample)
                counter += 1
                if only_first_n_samples > 0 and counter >= only_first_n_samples:
                    break

    def __len__(self):
        return len(self.sample_list)

    def crop(self, rgbs, trajs):
        T, N, _ = trajs.shape

        S = len(rgbs)
        H, W = rgbs[0].shape[:2]
        assert S == T

        H_new = H
        W_new = W

        # deterministic center crop (the offsets are fixed, not sampled at random)
        y0 = 0 if self.crop_size[0] >= H_new else (H_new - self.crop_size[0]) // 2
        x0 = 0 if self.crop_size[1] >= W_new else (W_new - self.crop_size[1]) // 2
        rgbs = [rgb[y0 : y0 + self.crop_size[0], x0 : x0 + self.crop_size[1]] for rgb in rgbs]

        # shift trajectories into the cropped coordinate frame
        trajs[:, :, 0] -= x0
        trajs[:, :, 1] -= y0

        return rgbs, trajs

    def __getitem__(self, index):
        sample = self.sample_list[index]
        T = len(sample)
        rgbs, visibilities, traj_2d = [], [], []

        H, W = sample[0].image.size
        image_size = (H, W)

        for i in range(T):
            # each frame annotation points to a torch file holding the image,
            # the 2D trajectories, and per-point visibility flags
            traj_path = os.path.join(self.root, self.split, sample[i].trajectories["path"])
            traj = torch.load(traj_path)

            visibilities.append(traj["verts_inds_vis"].numpy())
            rgbs.append(traj["img"].numpy())
            traj_2d.append(traj["traj_2d"].numpy()[..., :2])

        traj_2d = np.stack(traj_2d)
        visibility = np.stack(visibilities)
        T, N, D = traj_2d.shape
        # randomly subsample at most `traj_per_sample` trajectories
        visible_inds_sampled = torch.randperm(N)[: self.traj_per_sample]

        traj_2d = traj_2d[:, visible_inds_sampled]
        visibility = visibility[:, visible_inds_sampled]

        if self.crop_size is not None:
            rgbs, traj_2d = self.crop(rgbs, traj_2d)
            H, W, _ = rgbs[0].shape
            image_size = self.crop_size

        # mark points that fall outside the image bounds as invisible
        visibility[traj_2d[:, :, 0] > image_size[1] - 1] = False
        visibility[traj_2d[:, :, 0] < 0] = False
        visibility[traj_2d[:, :, 1] > image_size[0] - 1] = False
        visibility[traj_2d[:, :, 1] < 0] = False

        # keep only points that are visible in more than 10 frames
        visible_inds_resampled = visibility.sum(0) > 10
        traj_2d = torch.from_numpy(traj_2d[:, visible_inds_resampled])
        visibility = torch.from_numpy(visibility[:, visible_inds_resampled])

        rgbs = np.stack(rgbs, 0)
        video = torch.from_numpy(rgbs).reshape(T, H, W, 3).permute(0, 3, 1, 2).float()
        return CoTrackerData(
            video=video,
            trajectory=traj_2d,
            visibility=visibility,
            # sized to the filtered trajectories so `valid` matches `trajectory`
            valid=torch.ones(T, traj_2d.shape[1]),
            seq_name=sample[0].sequence_name,
        )
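

if __name__ == "__main__":
    # Minimal usage sketch. The dataset root below is a placeholder: it assumes
    # the Dynamic Replica "valid" split has already been downloaded there.
    dataset = DynamicReplicaDataset(
        root="./dynamic_replica_data",
        split="valid",
        traj_per_sample=256,
    )
    print(f"loaded {len(dataset)} clips")

    item = dataset[0]
    # video: (T, 3, H, W) float tensor; trajectory: (T, N, 2); visibility: (T, N)
    print(item.video.shape, item.trajectory.shape, item.visibility.shape)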