Merge pull request #14 from JunkyByte/main
minor fixes / mps default device when available / occlusion visualization
This commit is contained in:
		| @@ -133,7 +133,7 @@ class CoTrackerPredictor(torch.nn.Module): | ||||
|             ) | ||||
|  | ||||
|         if add_support_grid: | ||||
|             grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape) | ||||
|             grid_pts = get_points_on_a_grid(self.support_grid_size, self.interp_shape, device=video.device) | ||||
|             grid_pts = torch.cat( | ||||
|                 [torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2 | ||||
|             ) | ||||
|   | ||||
| @@ -62,6 +62,7 @@ class Visualizer: | ||||
|         self, | ||||
|         video: torch.Tensor,  # (B,T,C,H,W) | ||||
|         tracks: torch.Tensor,  # (B,T,N,2) | ||||
|         visibility: torch.Tensor = None,  # (B, T, N, 1) bool | ||||
|         gt_tracks: torch.Tensor = None,  # (B,T,N,2) | ||||
|         segm_mask: torch.Tensor = None,  # (B,1,H,W) | ||||
|         filename: str = "video", | ||||
| @@ -93,6 +94,7 @@ class Visualizer: | ||||
|         res_video = self.draw_tracks_on_video( | ||||
|             video=video, | ||||
|             tracks=tracks, | ||||
|             visibility=visibility, | ||||
|             segm_mask=segm_mask, | ||||
|             gt_tracks=gt_tracks, | ||||
|             query_frame=query_frame, | ||||
| @@ -126,6 +128,7 @@ class Visualizer: | ||||
|         self, | ||||
|         video: torch.Tensor, | ||||
|         tracks: torch.Tensor, | ||||
|         visibility: torch.Tensor = None, | ||||
|         segm_mask: torch.Tensor = None, | ||||
|         gt_tracks=None, | ||||
|         query_frame: int = 0, | ||||
| @@ -227,11 +230,13 @@ class Visualizer: | ||||
|                     if not compensate_for_camera_motion or ( | ||||
|                         compensate_for_camera_motion and segm_mask[i] > 0 | ||||
|                     ): | ||||
|  | ||||
|                         cv2.circle( | ||||
|                             res_video[t], | ||||
|                             coord, | ||||
|                             int(self.linewidth * 2), | ||||
|                             vector_colors[t, i].tolist(), | ||||
|                             thickness=-1 if visibility[0, t, i] else 2 | ||||
|                             -1, | ||||
|                         ) | ||||
|  | ||||
|   | ||||
							
								
								
									
										12
									
								
								demo.py
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								demo.py
									
									
									
									
									
								
							| @@ -14,6 +14,9 @@ from PIL import Image | ||||
| from cotracker.utils.visualizer import Visualizer, read_video_from_path | ||||
| from cotracker.predictor import CoTrackerPredictor | ||||
|  | ||||
| DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else | ||||
|                   'mps' if torch.backends.mps.is_available() else | ||||
|                   'cpu') | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser() | ||||
| @@ -55,11 +58,8 @@ if __name__ == "__main__": | ||||
|     segm_mask = torch.from_numpy(segm_mask)[None, None] | ||||
|  | ||||
|     model = CoTrackerPredictor(checkpoint=args.checkpoint) | ||||
|     if torch.cuda.is_available(): | ||||
|         model = model.cuda() | ||||
|         video = video.cuda() | ||||
|     else: | ||||
|         print("CUDA is not available!") | ||||
|     model = model.to(DEFAULT_DEVICE) | ||||
|     video = video.to(DEFAULT_DEVICE) | ||||
|  | ||||
|     pred_tracks, pred_visibility = model( | ||||
|         video, | ||||
| @@ -73,4 +73,4 @@ if __name__ == "__main__": | ||||
|     # save a video with predicted tracks | ||||
|     seq_name = args.video_path.split("/")[-1] | ||||
|     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) | ||||
|     vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame) | ||||
|     vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user