Allow MPS inference. Fix visualization args
This commit is contained in:
		| @@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import ( | |||||||
|  |  | ||||||
| class CoTrackerPredictor(torch.nn.Module): | class CoTrackerPredictor(torch.nn.Module): | ||||||
|     def __init__( |     def __init__( | ||||||
|         self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None |         self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth" | ||||||
|     ): |     ): | ||||||
|         super().__init__() |         super().__init__() | ||||||
|         self.interp_shape = (384, 512) |         self.interp_shape = (384, 512) | ||||||
|   | |||||||
							
								
								
									
										12
									
								
								demo.py
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								demo.py
									
									
									
									
									
								
							| @@ -14,6 +14,9 @@ from PIL import Image | |||||||
| from cotracker.utils.visualizer import Visualizer, read_video_from_path | from cotracker.utils.visualizer import Visualizer, read_video_from_path | ||||||
| from cotracker.predictor import CoTrackerPredictor | from cotracker.predictor import CoTrackerPredictor | ||||||
|  |  | ||||||
|  | DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else | ||||||
|  |                   'mps' if torch.backends.mps.is_available() else | ||||||
|  |                   'cpu') | ||||||
|  |  | ||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     parser = argparse.ArgumentParser() |     parser = argparse.ArgumentParser() | ||||||
| @@ -55,11 +58,8 @@ if __name__ == "__main__": | |||||||
|     segm_mask = torch.from_numpy(segm_mask)[None, None] |     segm_mask = torch.from_numpy(segm_mask)[None, None] | ||||||
|  |  | ||||||
|     model = CoTrackerPredictor(checkpoint=args.checkpoint) |     model = CoTrackerPredictor(checkpoint=args.checkpoint) | ||||||
|     if torch.cuda.is_available(): |     model = model.to(DEFAULT_DEVICE) | ||||||
|         model = model.cuda() |     video = video.to(DEFAULT_DEVICE) | ||||||
|         video = video.cuda() |  | ||||||
|     else: |  | ||||||
|         print("CUDA is not available!") |  | ||||||
|  |  | ||||||
|     pred_tracks, pred_visibility = model( |     pred_tracks, pred_visibility = model( | ||||||
|         video, |         video, | ||||||
| @@ -73,4 +73,4 @@ if __name__ == "__main__": | |||||||
|     # save a video with predicted tracks |     # save a video with predicted tracks | ||||||
|     seq_name = args.video_path.split("/")[-1] |     seq_name = args.video_path.split("/")[-1] | ||||||
|     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) |     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) | ||||||
|     vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame) |     vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame) | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user