try to generate the tracking f1 video but too many points
This commit is contained in:
		
							
								
								
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								.gitignore
									
									
									
									
										vendored
									
									
										Normal file
									
								
							@@ -0,0 +1,3 @@
 | 
				
			|||||||
 | 
					__pycache__/
 | 
				
			||||||
 | 
					.vscode/
 | 
				
			||||||
 | 
					cotracker/__pycache__/
 | 
				
			||||||
@@ -21,6 +21,9 @@ class CoTrackerPredictor(torch.nn.Module):
 | 
				
			|||||||
        self.model = model
 | 
					        self.model = model
 | 
				
			||||||
        self.model.eval()
 | 
					        self.model.eval()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
 | 
				
			||||||
 | 
					        self.model.to(self.device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @torch.no_grad()
 | 
					    @torch.no_grad()
 | 
				
			||||||
    def forward(
 | 
					    def forward(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										8
									
								
								demo1.py
									
									
									
									
									
								
							
							
						
						
									
										8
									
								
								demo1.py
									
									
									
									
									
								
							@@ -10,6 +10,8 @@ video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from cotracker.predictor import CoTrackerPredictor
 | 
					from cotracker.predictor import CoTrackerPredictor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
model = CoTrackerPredictor(
 | 
					model = CoTrackerPredictor(
 | 
				
			||||||
    checkpoint=os.path.join(
 | 
					    checkpoint=os.path.join(
 | 
				
			||||||
        './checkpoints/cotracker2.pth'
 | 
					        './checkpoints/cotracker2.pth'
 | 
				
			||||||
@@ -24,11 +26,12 @@ model = CoTrackerPredictor(
 | 
				
			|||||||
grid_query_frame=20
 | 
					grid_query_frame=20
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import torch.nn.functional as F
 | 
					import torch.nn.functional as F
 | 
				
			||||||
video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None]
 | 
					# video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
import time
 | 
					import time
 | 
				
			||||||
start_time = time.time()
 | 
					start_time = time.time()
 | 
				
			||||||
pred_tracks, pred_visibility = model(video_interp, 
 | 
					# pred_tracks, pred_visibility = model(video_interp, 
 | 
				
			||||||
 | 
					pred_tracks, pred_visibility = model(video, 
 | 
				
			||||||
                                     grid_query_frame=grid_query_frame, backward_tracking=True)
 | 
					                                     grid_query_frame=grid_query_frame, backward_tracking=True)
 | 
				
			||||||
end_time = time.time() 
 | 
					end_time = time.time() 
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -41,6 +44,7 @@ vis = Visualizer(
 | 
				
			|||||||
    mode='optical_flow'
 | 
					    mode='optical_flow'
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
vis.visualize(
 | 
					vis.visualize(
 | 
				
			||||||
 | 
					    # video=video_interp,
 | 
				
			||||||
    video=video,
 | 
					    video=video,
 | 
				
			||||||
    tracks=pred_tracks,
 | 
					    tracks=pred_tracks,
 | 
				
			||||||
    visibility=pred_visibility,
 | 
					    visibility=pred_visibility,
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user