correct query-point predictions (#32)
This commit is contained in:
parent
7d18c58cce
commit
4f297a92fe
@ -152,6 +152,21 @@ class CoTrackerPredictor(torch.nn.Module):
|
||||
visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
|
||||
thr = 0.9
|
||||
visibilities = visibilities > thr
|
||||
|
||||
# correct query-point predictions
|
||||
# see https://github.com/facebookresearch/co-tracker/issues/28
|
||||
|
||||
# TODO: batchify
|
||||
for i in range(len(queries)):
|
||||
queries_t = queries[i, :tracks.size(2), 0].to(torch.int64)
|
||||
arange = torch.arange(0, len(queries_t))
|
||||
|
||||
# overwrite the predictions with the query points
|
||||
tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:]
|
||||
|
||||
# correct visibilities, the query points should be visible
|
||||
visibilities[i, queries_t, arange] = True
|
||||
|
||||
tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
|
||||
tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
|
||||
return tracks, visibilities
|
||||
|
Loading…
Reference in New Issue
Block a user