fixed a bug in compute_tapvid_metrics
This commit is contained in:
parent
4f297a92fe
commit
cd226f3e6f
@ -55,32 +55,29 @@ def compute_tapvid_metrics(
|
||||
"""
|
||||
|
||||
metrics = {}
|
||||
# Fixed bug is described in:
|
||||
# https://github.com/facebookresearch/co-tracker/issues/20
|
||||
eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
|
||||
|
||||
if query_mode == "first":
|
||||
# evaluate frames after the query frame
|
||||
query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
|
||||
elif query_mode == "strided":
|
||||
# evaluate all frames except the query frame
|
||||
query_frame_to_eval_frames = 1 - eye
|
||||
else:
|
||||
raise ValueError("Unknown query mode " + query_mode)
|
||||
|
||||
# Don't evaluate the query point. Numpy doesn't have one_hot, so we
|
||||
# replicate it by indexing into an identity matrix.
|
||||
one_hot_eye = np.eye(gt_tracks.shape[2])
|
||||
query_frame = query_points[..., 0]
|
||||
query_frame = np.round(query_frame).astype(np.int32)
|
||||
evaluation_points = one_hot_eye[query_frame] == 0
|
||||
|
||||
# If we're using the first point on the track as a query, don't evaluate the
|
||||
# other points.
|
||||
if query_mode == "first":
|
||||
for i in range(gt_occluded.shape[0]):
|
||||
index = np.where(gt_occluded[i] == 0)[0][0]
|
||||
evaluation_points[i, :index] = False
|
||||
elif query_mode != "strided":
|
||||
raise ValueError("Unknown query mode " + query_mode)
|
||||
evaluation_points = query_frame_to_eval_frames[query_frame] > 0
|
||||
|
||||
# Occlusion accuracy is simply how often the predicted occlusion equals the
|
||||
# ground truth.
|
||||
occ_acc = (
|
||||
np.sum(
|
||||
np.equal(pred_occluded, gt_occluded) & evaluation_points,
|
||||
axis=(1, 2),
|
||||
)
|
||||
/ np.sum(evaluation_points)
|
||||
)
|
||||
occ_acc = np.sum(
|
||||
np.equal(pred_occluded, gt_occluded) & evaluation_points,
|
||||
axis=(1, 2),
|
||||
) / np.sum(evaluation_points)
|
||||
metrics["occlusion_accuracy"] = occ_acc
|
||||
|
||||
# Next, convert the predictions and ground truth positions into pixel
|
||||
@ -92,13 +89,10 @@ def compute_tapvid_metrics(
|
||||
for thresh in [1, 2, 4, 8, 16]:
|
||||
# True positives are points that are within the threshold and where both
|
||||
# the prediction and the ground truth are listed as visible.
|
||||
within_dist = (
|
||||
np.sum(
|
||||
np.square(pred_tracks - gt_tracks),
|
||||
axis=-1,
|
||||
)
|
||||
< np.square(thresh)
|
||||
)
|
||||
within_dist = np.sum(
|
||||
np.square(pred_tracks - gt_tracks),
|
||||
axis=-1,
|
||||
) < np.square(thresh)
|
||||
is_correct = np.logical_and(within_dist, visible)
|
||||
|
||||
# Compute the frac_within_threshold, which is the fraction of points
|
||||
|
Loading…
Reference in New Issue
Block a user