demo fixes
(Binary image file changed, not shown; 14 KiB before, 14 KiB after.)
@@ -55,7 +55,9 @@ class CoTrackerPredictor(torch.nn.Module):
 
         return tracks, visibilities
 
-    def _compute_dense_tracks(self, video, grid_query_frame, grid_size=30, backward_tracking=False):
+    def _compute_dense_tracks(
+        self, video, grid_query_frame, grid_size=150, backward_tracking=False
+    ):
         *_, H, W = video.shape
         grid_step = W // grid_size
         grid_width = W // grid_step
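Note on the hunk above: the default grid_size for dense tracking is raised from 30 to 150, so grid_step = W // grid_size shrinks and each pass queries a much finer grid. A minimal sketch of the offset-grid idea behind _compute_dense_tracks follows; the helper name, tensor shapes, and the (t, x, y) query layout are assumptions for illustration, not the repo's exact code.

import torch

def dense_query_grids(H, W, grid_size=150, grid_query_frame=0):
    # Cover the image with a coarse grid, then shift that grid by every
    # (oy, ox) offset inside one grid cell; across all offsets, every
    # pixel is eventually used as a query point.
    grid_step = max(W // grid_size, 1)  # guard for tiny inputs
    for offset in range(grid_step * grid_step):
        oy, ox = offset // grid_step, offset % grid_step
        ys = torch.arange(oy, H, grid_step, dtype=torch.float32)
        xs = torch.arange(ox, W, grid_step, dtype=torch.float32)
        yy, xx = torch.meshgrid(ys, xs, indexing="ij")
        pts = torch.stack([xx.flatten(), yy.flatten()], dim=-1)   # (N, 2) as (x, y)
        t = torch.full((pts.shape[0], 1), float(grid_query_frame))
        yield torch.cat([t, pts], dim=-1)[None]                   # (1, N, 3) query batch

Each yielded batch would be passed to the sparse tracker and the results scattered back into the dense track tensor.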
@@ -172,8 +174,9 @@ class CoTrackerPredictor(torch.nn.Module):
 
         inv_tracks = inv_tracks.flip(1)
         inv_visibilities = inv_visibilities.flip(1)
+        arange = torch.arange(video.shape[1], device=queries.device)[None, :, None]
 
-        mask = tracks == 0
+        mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)
 
         tracks[mask] = inv_tracks[mask]
         visibilities[mask[:, :, :, 0]] = inv_visibilities[mask[:, :, :, 0]]
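The second hunk replaces a fragile heuristic when merging the forward and backward passes: instead of overwriting positions wherever the forward track happens to be exactly zero, it back-fills every timestep strictly before each query's start frame from the time-reversed (flipped) backward pass. A toy reproduction of the masking logic, assuming tracks of shape (B, T, N, 2), queries of shape (B, N, 3) in (t, x, y) order, and batch size 1 to match the broadcasting in the hunk:

import torch

B, T, N = 1, 5, 3
tracks = torch.zeros(B, T, N, 2)          # forward pass: undefined before each query frame
inv_tracks = torch.ones(B, T, N, 2)       # backward pass, already flipped back to forward time
queries = torch.tensor([[[0.0, 10, 10],   # starts at t=0: nothing to back-fill
                         [2.0, 20, 20],   # starts at t=2: frames 0-1 come from the backward pass
                         [4.0, 30, 30]]]) # starts at t=4: frames 0-3 come from the backward pass

arange = torch.arange(T)[None, :, None]                                    # (1, T, 1)
mask = (arange < queries[None, :, :, 0]).unsqueeze(-1).repeat(1, 1, 1, 2)  # (1, T, N, 2)
tracks[mask] = inv_tracks[mask]
print(tracks[0, :, 1, 0])   # tensor([1., 1., 0., 0., 0.]) -- frames 0-1 were back-filled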
@@ -226,7 +226,7 @@ class Visualizer:
 
         #  draw tracks
         if self.tracks_leave_trace != 0:
-            for t in range(1, T):
+            for t in range(query_frame + 1, T):
                 first_ind = (
                     max(0, t - self.tracks_leave_trace) if self.tracks_leave_trace >= 0 else 0
                 )
@@ -251,7 +251,7 @@ class Visualizer:
                     res_video[t] = self._draw_gt_tracks(res_video[t], gt_tracks[first_ind : t + 1])
 
         #  draw points
-        for t in range(T):
+        for t in range(query_frame, T):
             img = Image.fromarray(np.uint8(res_video[t]))
             for i in range(N):
                 coord = (tracks[t, i, 0], tracks[t, i, 1])
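Both Visualizer hunks make the drawing loops respect the query frame: trace segments now start at query_frame + 1 and point markers at query_frame, so frames before the query frame are no longer overdrawn with placeholder coordinates. A trivial sketch of the new loop bounds (T, query_frame, and tracks_leave_trace are made-up values):

T, query_frame, tracks_leave_trace = 8, 3, 2

for t in range(query_frame + 1, T):        # trace segments (was range(1, T))
    first_ind = max(0, t - tracks_leave_trace)
    print(f"frame {t}: draw trace from {first_ind} to {t}")

for t in range(query_frame, T):            # point markers (was range(T))
    print(f"frame {t}: draw points")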
demo.py (9 changed lines)
@@ -72,7 +72,7 @@ if __name__ == "__main__":
         model = torch.hub.load("facebookresearch/co-tracker", "cotracker2")
     model = model.to(DEFAULT_DEVICE)
     video = video.to(DEFAULT_DEVICE)
-
+    # video = video[:, :20]
     pred_tracks, pred_visibility = model(
         video,
         grid_size=args.grid_size,
@@ -85,4 +85,9 @@ if __name__ == "__main__":
     # save a video with predicted tracks
     seq_name = args.video_path.split("/")[-1]
     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
-    vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame)
+    vis.visualize(
+        video,
+        pred_tracks,
+        pred_visibility,
+        query_frame=0 if args.backward_tracking else args.grid_query_frame,
+    )
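With this demo change, the visualizer starts drawing at frame 0 when --backward_tracking is on (the backward pass fills the frames before the grid query frame), and at args.grid_query_frame otherwise. A hedged end-to-end sketch of the updated call pattern; the Visualizer import path, the model's keyword arguments, and the dummy video are assumptions based on the demo script rather than verified API:

import torch
from cotracker.utils.visualizer import Visualizer  # assumed import path, as used by the demo

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("facebookresearch/co-tracker", "cotracker2").to(device)

# Dummy clip; in the demo this is a real video loaded as (1, T, 3, H, W) floats in [0, 255].
video = torch.rand(1, 24, 3, 256, 256, device=device) * 255

backward_tracking, grid_query_frame = True, 10
pred_tracks, pred_visibility = model(
    video,
    grid_size=10,
    grid_query_frame=grid_query_frame,
    backward_tracking=backward_tracking,
)

vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3)
vis.visualize(
    video,
    pred_tracks,
    pred_visibility,
    query_frame=0 if backward_tracking else grid_query_frame,
)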
										
											
												File diff suppressed because one or more lines are too long
											
										
									
								
							