fixed a bug in compute_tapvid_metrics
This commit is contained in:
		@@ -55,32 +55,29 @@ def compute_tapvid_metrics(
 | 
				
			|||||||
    """
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    metrics = {}
 | 
					    metrics = {}
 | 
				
			||||||
 | 
					    # Fixed bug is described in:
 | 
				
			||||||
 | 
					    # https://github.com/facebookresearch/co-tracker/issues/20
 | 
				
			||||||
 | 
					    eye = np.eye(gt_tracks.shape[2], dtype=np.int32)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    if query_mode == "first":
 | 
				
			||||||
 | 
					        # evaluate frames after the query frame
 | 
				
			||||||
 | 
					        query_frame_to_eval_frames = np.cumsum(eye, axis=1) - eye
 | 
				
			||||||
 | 
					    elif query_mode == "strided":
 | 
				
			||||||
 | 
					        # evaluate all frames except the query frame
 | 
				
			||||||
 | 
					        query_frame_to_eval_frames = 1 - eye
 | 
				
			||||||
 | 
					    else:
 | 
				
			||||||
 | 
					        raise ValueError("Unknown query mode " + query_mode)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Don't evaluate the query point.  Numpy doesn't have one_hot, so we
 | 
					 | 
				
			||||||
    # replicate it by indexing into an identity matrix.
 | 
					 | 
				
			||||||
    one_hot_eye = np.eye(gt_tracks.shape[2])
 | 
					 | 
				
			||||||
    query_frame = query_points[..., 0]
 | 
					    query_frame = query_points[..., 0]
 | 
				
			||||||
    query_frame = np.round(query_frame).astype(np.int32)
 | 
					    query_frame = np.round(query_frame).astype(np.int32)
 | 
				
			||||||
    evaluation_points = one_hot_eye[query_frame] == 0
 | 
					    evaluation_points = query_frame_to_eval_frames[query_frame] > 0
 | 
				
			||||||
 | 
					 | 
				
			||||||
    # If we're using the first point on the track as a query, don't evaluate the
 | 
					 | 
				
			||||||
    # other points.
 | 
					 | 
				
			||||||
    if query_mode == "first":
 | 
					 | 
				
			||||||
        for i in range(gt_occluded.shape[0]):
 | 
					 | 
				
			||||||
            index = np.where(gt_occluded[i] == 0)[0][0]
 | 
					 | 
				
			||||||
            evaluation_points[i, :index] = False
 | 
					 | 
				
			||||||
    elif query_mode != "strided":
 | 
					 | 
				
			||||||
        raise ValueError("Unknown query mode " + query_mode)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Occlusion accuracy is simply how often the predicted occlusion equals the
 | 
					    # Occlusion accuracy is simply how often the predicted occlusion equals the
 | 
				
			||||||
    # ground truth.
 | 
					    # ground truth.
 | 
				
			||||||
    occ_acc = (
 | 
					    occ_acc = np.sum(
 | 
				
			||||||
        np.sum(
 | 
					        np.equal(pred_occluded, gt_occluded) & evaluation_points,
 | 
				
			||||||
            np.equal(pred_occluded, gt_occluded) & evaluation_points,
 | 
					        axis=(1, 2),
 | 
				
			||||||
            axis=(1, 2),
 | 
					    ) / np.sum(evaluation_points)
 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        / np.sum(evaluation_points)
 | 
					 | 
				
			||||||
    )
 | 
					 | 
				
			||||||
    metrics["occlusion_accuracy"] = occ_acc
 | 
					    metrics["occlusion_accuracy"] = occ_acc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # Next, convert the predictions and ground truth positions into pixel
 | 
					    # Next, convert the predictions and ground truth positions into pixel
 | 
				
			||||||
@@ -92,13 +89,10 @@ def compute_tapvid_metrics(
 | 
				
			|||||||
    for thresh in [1, 2, 4, 8, 16]:
 | 
					    for thresh in [1, 2, 4, 8, 16]:
 | 
				
			||||||
        # True positives are points that are within the threshold and where both
 | 
					        # True positives are points that are within the threshold and where both
 | 
				
			||||||
        # the prediction and the ground truth are listed as visible.
 | 
					        # the prediction and the ground truth are listed as visible.
 | 
				
			||||||
        within_dist = (
 | 
					        within_dist = np.sum(
 | 
				
			||||||
            np.sum(
 | 
					            np.square(pred_tracks - gt_tracks),
 | 
				
			||||||
                np.square(pred_tracks - gt_tracks),
 | 
					            axis=-1,
 | 
				
			||||||
                axis=-1,
 | 
					        ) < np.square(thresh)
 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            < np.square(thresh)
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        is_correct = np.logical_and(within_dist, visible)
 | 
					        is_correct = np.logical_and(within_dist, visible)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Compute the frac_within_threshold, which is the fraction of points
 | 
					        # Compute the frac_within_threshold, which is the fraction of points
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user