diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..33fb94c --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +__pycache__/ +.vscode/ +cotracker/__pycache__/ diff --git a/cotracker/predictor.py b/cotracker/predictor.py index 9778a7e..43b2120 100644 --- a/cotracker/predictor.py +++ b/cotracker/predictor.py @@ -21,6 +21,9 @@ class CoTrackerPredictor(torch.nn.Module): self.model = model self.model.eval() + self.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu") + self.model.to(self.device) + @torch.no_grad() def forward( self, diff --git a/demo1.py b/demo1.py index b313217..503984e 100644 --- a/demo1.py +++ b/demo1.py @@ -10,6 +10,8 @@ video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float() from cotracker.predictor import CoTrackerPredictor +device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu') + model = CoTrackerPredictor( checkpoint=os.path.join( './checkpoints/cotracker2.pth' @@ -24,11 +26,12 @@ model = CoTrackerPredictor( grid_query_frame=20 import torch.nn.functional as F -video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None] +# video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device) import time start_time = time.time() -pred_tracks, pred_visibility = model(video_interp, +# pred_tracks, pred_visibility = model(video_interp, +pred_tracks, pred_visibility = model(video, grid_query_frame=grid_query_frame, backward_tracking=True) end_time = time.time() @@ -41,6 +44,7 @@ vis = Visualizer( mode='optical_flow' ) vis.visualize( + # video=video_interp, video=video, tracks=pred_tracks, visibility=pred_visibility,