correct query-point predictions (#32)

This commit is contained in:
Ernie Chu 2023-09-14 18:20:02 +08:00 committed by GitHub
parent 7d18c58cce
commit 4f297a92fe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -152,6 +152,21 @@ class CoTrackerPredictor(torch.nn.Module):
visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
thr = 0.9
visibilities = visibilities > thr
# correct query-point predictions
# see https://github.com/facebookresearch/co-tracker/issues/28
# TODO: batchify
for i in range(len(queries)):
queries_t = queries[i, :tracks.size(2), 0].to(torch.int64)
arange = torch.arange(0, len(queries_t))
# overwrite the predictions with the query points
tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:]
# correct visibilities, the query points should be visible
visibilities[i, queries_t, arange] = True
tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
return tracks, visibilities