Try to generate the tracking F1 video, but there are too many points
.gitignore (new file, vendored) | 3 +++

@@ -0,0 +1,3 @@
+__pycache__/
+.vscode/
+cotracker/__pycache__/
cotracker/predictor.py | 3 +++

@@ -21,6 +21,9 @@ class CoTrackerPredictor(torch.nn.Module):
         self.model = model
         self.model.eval()
 
+        self.device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
+        self.model.to(self.device)
+
     @torch.no_grad()
     def forward(
         self,
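Note: the lines added to the predictor pin the model to GPU index 2 whenever CUDA is available. A minimal sketch of the same device-selection pattern, with the index pulled out into a helper (the select_device name and its argument are illustrative, not part of this commit):

    import torch

    def select_device(gpu_index: int = 2) -> torch.device:
        # torch.cuda.is_available() only says that some CUDA device exists;
        # cuda:2 additionally assumes at least three GPUs are visible.
        if torch.cuda.is_available():
            return torch.device(f"cuda:{gpu_index}")
        return torch.device("cpu")

    device = select_device(2)
    # The commit does the equivalent inline and then calls self.model.to(self.device).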
demo1.py | 8 ++++++--

@@ -10,6 +10,8 @@ video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
 
 from cotracker.predictor import CoTrackerPredictor
 
+device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
+
 model = CoTrackerPredictor(
     checkpoint=os.path.join(
         './checkpoints/cotracker2.pth'
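Note: a self-contained sketch of the setup this hunk touches, following the shape convention in the hunk header above (frames arrive as (T, H, W, 3) uint8 and become a (1, T, 3, H, W) float tensor); the zero-filled placeholder clip is illustrative and stands in for the real footage:

    import os
    import numpy as np
    import torch
    from cotracker.predictor import CoTrackerPredictor

    # Placeholder clip: 48 frames of 720x1280 RGB (replace with the real video).
    frames = np.zeros((48, 720, 1280, 3), dtype=np.uint8)
    video = torch.from_numpy(frames).permute(0, 3, 1, 2)[None].float()

    # Same device choice as the line added above.
    device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')

    # Build the predictor from the local CoTracker v2 checkpoint, as in demo1.py.
    model = CoTrackerPredictor(
        checkpoint=os.path.join('./checkpoints/cotracker2.pth')
    )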
@@ -24,11 +26,12 @@ model = CoTrackerPredictor(
 grid_query_frame=20
 
 import torch.nn.functional as F
-video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None]
+# video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device)
 
 import time
 start_time = time.time()
-pred_tracks, pred_visibility = model(video_interp, 
+# pred_tracks, pred_visibility = model(video_interp, 
+pred_tracks, pred_visibility = model(video, 
                                      grid_query_frame=grid_query_frame, backward_tracking=True)
 end_time = time.time()
 
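Note: this hunk is where the "too many points" in the commit title comes from. The call is switched from the downsampled video_interp to the full-resolution video, and since no query grid is passed here, the predictor tracks points densely, so the number of tracks grows with the frame area. A sketch of two ways to keep the point count manageable, assuming the model and video objects from the demo (the grid_size value is illustrative; [200, 360] is the downsampled size this commit comments out):

    import time
    import torch.nn.functional as F

    # Option 1: track a sparse regular grid instead of (nearly) every pixel.
    # grid_size=30 samples roughly a 30x30 grid of query points on frame 20.
    start_time = time.time()
    pred_tracks, pred_visibility = model(
        video,
        grid_size=30,
        grid_query_frame=20,
        backward_tracking=True,
    )
    print(f"sparse grid tracking took {time.time() - start_time:.1f}s")

    # Option 2: keep dense tracking but shrink the frames first, so fewer
    # pixels (and therefore fewer tracks) have to be followed.
    video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None]
    pred_tracks, pred_visibility = model(
        video_interp,
        grid_query_frame=20,
        backward_tracking=True,
    )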
@@ -41,6 +44,7 @@ vis = Visualizer(
     mode='optical_flow'
 )
 vis.visualize(
+    # video=video_interp,
     video=video,
     tracks=pred_tracks,
     visibility=pred_visibility,
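Note: the visualisation step keeps the full-resolution clip while the tracks may have been computed on the downsampled one; if the tracks come from video_interp, their coordinates live in the 200x360 frame and would need rescaling before being drawn on video. A minimal sketch of the rendering call, assuming CoTracker's Visualizer with the mode used here; save_dir and filename are illustrative:

    from cotracker.utils.visualizer import Visualizer

    # Draw the predicted tracks on the clip using optical-flow style colouring.
    vis = Visualizer(save_dir='./videos', pad_value=0, mode='optical_flow')
    vis.visualize(
        video=video,                 # (1, T, 3, H, W) frames to draw on
        tracks=pred_tracks,          # (1, T, N, 2) per-frame point coordinates
        visibility=pred_visibility,  # (1, T, N) per-point visibility flags
        filename='f1_tracking',      # illustrative output name
    )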