fixed a bug in compute_tapvid_metrics

This commit is contained in:
Nikita Karaev 2023-10-30 11:35:42 +00:00
parent 4f297a92fe
commit cd226f3e6f

View File

@ -55,32 +55,29 @@ def compute_tapvid_metrics(
"""
metrics = {}
# Fixed bug is described in:
# https://github.com/facebookresearch/co-tracker/issues/20
eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
if query_mode == "first":
# evaluate frames after the query frame
query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
elif query_mode == "strided":
# evaluate all frames except the query frame
query_frame_to_eval_frames = 1 - eye
else:
raise ValueError("Unknown query mode " + query_mode)
# Don't evaluate the query point. Numpy doesn't have one_hot, so we
# replicate it by indexing into an identity matrix.
one_hot_eye = np.eye(gt_tracks.shape[2])
query_frame = query_points[..., 0]
query_frame = np.round(query_frame).astype(np.int32)
evaluation_points = one_hot_eye[query_frame] == 0
# If we're using the first point on the track as a query, don't evaluate the
# other points.
if query_mode == "first":
for i in range(gt_occluded.shape[0]):
index = np.where(gt_occluded[i] == 0)[0][0]
evaluation_points[i, :index] = False
elif query_mode != "strided":
raise ValueError("Unknown query mode " + query_mode)
evaluation_points = query_frame_to_eval_frames[query_frame] > 0
# Occlusion accuracy is simply how often the predicted occlusion equals the
# ground truth.
occ_acc = (
np.sum(
np.equal(pred_occluded, gt_occluded) & evaluation_points,
axis=(1, 2),
)
/ np.sum(evaluation_points)
)
occ_acc = np.sum(
np.equal(pred_occluded, gt_occluded) & evaluation_points,
axis=(1, 2),
) / np.sum(evaluation_points)
metrics["occlusion_accuracy"] = occ_acc
# Next, convert the predictions and ground truth positions into pixel
@ -92,13 +89,10 @@ def compute_tapvid_metrics(
for thresh in [1, 2, 4, 8, 16]:
# True positives are points that are within the threshold and where both
# the prediction and the ground truth are listed as visible.
within_dist = (
np.sum(
np.square(pred_tracks - gt_tracks),
axis=-1,
)
< np.square(thresh)
)
within_dist = np.sum(
np.square(pred_tracks - gt_tracks),
axis=-1,
) < np.square(thresh)
is_correct = np.logical_and(within_dist, visible)
# Compute the frac_within_threshold, which is the fraction of points