diff --git a/demo1.py b/demo1.py
index 503984e..5a0c7c5 100644
--- a/demo1.py
+++ b/demo1.py
@@ -3,14 +3,25 @@
 import torch
 from base64 import b64encode
 
 
 from cotracker.utils.visualizer import Visualizer, read_video_from_path
+import numpy as np
+from PIL import Image
+import time
 
+device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
+start_time = time.time()
+print(f'Using device: {device}')
+print(f'start loading video')
 video = read_video_from_path('./assets/F1_shorts.mp4')
+print(f'video shape: {video.shape}')
+# video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float().to(device)
 video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
+end_time = time.time()
+print(f'video shape after permute: {video.shape}')
+print("Load video Time taken: {:.2f} seconds".format(end_time - start_time))
 
 from cotracker.predictor import CoTrackerPredictor
 
-device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
 model = CoTrackerPredictor(
     checkpoint=os.path.join(
@@ -27,25 +38,45 @@
 grid_query_frame=20
 
 import torch.nn.functional as F
 # video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device)
+interp_size = (720, 1280)  # target (H, W) for the interpolated video
+video_interp = F.interpolate(video[0], [interp_size[0], interp_size[1]], mode="bilinear")[None].to(device)
+print(f'video_interp shape: {video_interp.shape}')
 
-import time
 start_time = time.time()
 # pred_tracks, pred_visibility = model(video_interp,
-pred_tracks, pred_visibility = model(video,
-    grid_query_frame=grid_query_frame, backward_tracking=True)
+input_mask = './assets/F1_mask.png'
+segm_mask = Image.open(input_mask)
+interp_size = (interp_size[1], interp_size[0])  # PIL resize expects (W, H)
+segm_mask = segm_mask.resize(interp_size, Image.BILINEAR)
+segm_mask = np.array(segm_mask)  # convert the resized mask to a numpy array
+segm_mask = torch.tensor(segm_mask).to(device)
+# pred_tracks, pred_visibility = model(video,
+pred_tracks, pred_visibility = model(video_interp,
+    grid_query_frame=grid_query_frame, backward_tracking=True,
+    segm_mask=segm_mask)
 end_time = time.time()
 print("Time taken: {:.2f} seconds".format(end_time - start_time))
 
+start_time = time.time()
+print(f'start visualizing')
 vis = Visualizer(
     save_dir='./videos',
     pad_value=20,
     linewidth=1,
     mode='optical_flow'
 )
+print(f'vis initialized')
+end_time = time.time()
+print("Time taken: {:.2f} seconds".format(end_time - start_time))
+start_time = time.time()
+print(f'start visualize')
 vis.visualize(
-    # video=video_interp,
-    video=video,
+    video=video_interp,
+    # video=video,
     tracks=pred_tracks,
     visibility=pred_visibility,
-    filename='dense');
\ No newline at end of file
+    filename='dense2');
+print(f'done')
+end_time = time.time()
+print("Time taken: {:.2f} seconds".format(end_time - start_time))
\ No newline at end of file
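
Note on the mask handling in the diff: the script converts the resized mask directly, but it still passes a 2-D H x W tensor to the predictor. The upstream CoTracker demo notebook instead prepares the mask as a 4-D tensor via [None, None]; whether the predictor here requires that shape is an assumption. The following is only a minimal sketch of that preparation, reusing the path and the 720x1280 size from the diff and assuming the mask image is single-channel:

    import numpy as np
    import torch
    from PIL import Image

    input_mask = './assets/F1_mask.png'   # path used in the diff
    interp_h, interp_w = 720, 1280        # (H, W) of video_interp in the diff

    segm_mask = Image.open(input_mask)                                  # assumed single-channel mask image
    segm_mask = segm_mask.resize((interp_w, interp_h), Image.BILINEAR)  # PIL resize takes (W, H)
    segm_mask = torch.from_numpy(np.array(segm_mask))[None, None]       # 1 x 1 x H x W, as in the upstream demo

    # Corresponding call, mirroring the diff (sketch only):
    # pred_tracks, pred_visibility = model(video_interp,
    #                                      grid_query_frame=grid_query_frame,
    #                                      backward_tracking=True,
    #                                      segm_mask=segm_mask)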