Allows MPS inference. Fix visualization args
This commit is contained in:
		| @@ -17,7 +17,7 @@ from cotracker.models.build_cotracker import ( | ||||
|  | ||||
| class CoTrackerPredictor(torch.nn.Module): | ||||
|     def __init__( | ||||
|         self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth", device=None | ||||
|         self, checkpoint="cotracker/checkpoints/cotracker_stride_4_wind_8.pth" | ||||
|     ): | ||||
|         super().__init__() | ||||
|         self.interp_shape = (384, 512) | ||||
|   | ||||
							
								
								
									
										12
									
								
								demo.py
									
									
									
									
									
								
							
							
						
						
									
										12
									
								
								demo.py
									
									
									
									
									
								
							| @@ -14,6 +14,9 @@ from PIL import Image | ||||
| from cotracker.utils.visualizer import Visualizer, read_video_from_path | ||||
| from cotracker.predictor import CoTrackerPredictor | ||||
|  | ||||
| DEFAULT_DEVICE = ('cuda' if torch.cuda.is_available() else | ||||
|                   'mps' if torch.backends.mps.is_available() else | ||||
|                   'cpu') | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     parser = argparse.ArgumentParser() | ||||
| @@ -55,11 +58,8 @@ if __name__ == "__main__": | ||||
|     segm_mask = torch.from_numpy(segm_mask)[None, None] | ||||
|  | ||||
|     model = CoTrackerPredictor(checkpoint=args.checkpoint) | ||||
|     if torch.cuda.is_available(): | ||||
|         model = model.cuda() | ||||
|         video = video.cuda() | ||||
|     else: | ||||
|         print("CUDA is not available!") | ||||
|     model = model.to(DEFAULT_DEVICE) | ||||
|     video = video.to(DEFAULT_DEVICE) | ||||
|  | ||||
|     pred_tracks, pred_visibility = model( | ||||
|         video, | ||||
| @@ -73,4 +73,4 @@ if __name__ == "__main__": | ||||
|     # save a video with predicted tracks | ||||
|     seq_name = args.video_path.split("/")[-1] | ||||
|     vis = Visualizer(save_dir="./saved_videos", pad_value=120, linewidth=3) | ||||
|     vis.visualize(video, pred_tracks, query_frame=args.grid_query_frame) | ||||
|     vis.visualize(video, pred_tracks, pred_visibility, query_frame=args.grid_query_frame) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user