correct query-point predictions (#32)
This commit is contained in:
parent
7d18c58cce
commit
4f297a92fe
@ -152,6 +152,21 @@ class CoTrackerPredictor(torch.nn.Module):
|
|||||||
visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
|
visibilities = visibilities[:, :, : -self.support_grid_size ** 2]
|
||||||
thr = 0.9
|
thr = 0.9
|
||||||
visibilities = visibilities > thr
|
visibilities = visibilities > thr
|
||||||
|
|
||||||
|
# correct query-point predictions
|
||||||
|
# see https://github.com/facebookresearch/co-tracker/issues/28
|
||||||
|
|
||||||
|
# TODO: batchify
|
||||||
|
for i in range(len(queries)):
|
||||||
|
queries_t = queries[i, :tracks.size(2), 0].to(torch.int64)
|
||||||
|
arange = torch.arange(0, len(queries_t))
|
||||||
|
|
||||||
|
# overwrite the predictions with the query points
|
||||||
|
tracks[i, queries_t, arange] = queries[i, :tracks.size(2), 1:]
|
||||||
|
|
||||||
|
# correct visibilities, the query points should be visible
|
||||||
|
visibilities[i, queries_t, arange] = True
|
||||||
|
|
||||||
tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
|
tracks[:, :, :, 0] *= W / float(self.interp_shape[1])
|
||||||
tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
|
tracks[:, :, :, 1] *= H / float(self.interp_shape[0])
|
||||||
return tracks, visibilities
|
return tracks, visibilities
|
||||||
|
Loading…
Reference in New Issue
Block a user