add time measure codes and update resolution

This commit is contained in:
mhz 2024-08-12 22:37:51 +02:00
parent 40e628ac73
commit be0891967b

View File

@ -3,14 +3,25 @@ import torch
from base64 import b64encode from base64 import b64encode
from cotracker.utils.visualizer import Visualizer, read_video_from_path from cotracker.utils.visualizer import Visualizer, read_video_from_path
import numpy as np
from PIL import Image
import time
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
start_time = time.time()
print(f'Using device: {device}')
print(f'start loading video')
video = read_video_from_path('./assets/F1_shorts.mp4') video = read_video_from_path('./assets/F1_shorts.mp4')
print(f'video shape: {video.shape}')
# video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float().to(device)
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float() video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
end_time = time.time()
print(f'video shape after permute: {video.shape}')
print("Load video Time taken: {:.2f} seconds".format(end_time - start_time))
from cotracker.predictor import CoTrackerPredictor from cotracker.predictor import CoTrackerPredictor
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
model = CoTrackerPredictor( model = CoTrackerPredictor(
checkpoint=os.path.join( checkpoint=os.path.join(
@ -27,25 +38,45 @@ grid_query_frame=20
import torch.nn.functional as F import torch.nn.functional as F
# video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device) # video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device)
interp_size = (720, 1280)
video_interp = F.interpolate(video[0], [interp_size[0], interp_size[1]], mode="bilinear")[None].to(device)
print(f'video_interp shape: {video_interp.shape}')
import time
start_time = time.time() start_time = time.time()
# pred_tracks, pred_visibility = model(video_interp, # pred_tracks, pred_visibility = model(video_interp,
pred_tracks, pred_visibility = model(video, input_mask='./assets/F1_mask.png'
grid_query_frame=grid_query_frame, backward_tracking=True) segm_mask = Image.open(input_mask)
interp_size = (interp_size[1], interp_size[0])
segm_mask = segm_mask.resize(interp_size, Image.BILINEAR)
segm_mask = np.array(Image.open(input_mask))
segm_mask = torch.tensor(segm_mask).to(device)
# pred_tracks, pred_visibility = model(video,
pred_tracks, pred_visibility = model(video_interp,
grid_query_frame=grid_query_frame, backward_tracking=True,
segm_mask=segm_mask )
end_time = time.time() end_time = time.time()
print("Time taken: {:.2f} seconds".format(end_time - start_time)) print("Time taken: {:.2f} seconds".format(end_time - start_time))
start_time = time.time()
print(f'start visualizing')
vis = Visualizer( vis = Visualizer(
save_dir='./videos', save_dir='./videos',
pad_value=20, pad_value=20,
linewidth=1, linewidth=1,
mode='optical_flow' mode='optical_flow'
) )
print(f'vis initialized')
end_time = time.time()
print("Time taken: {:.2f} seconds".format(end_time - start_time))
start_time = time.time()
print(f'start visualize')
vis.visualize( vis.visualize(
# video=video_interp, video=video_interp,
video=video, # video=video,
tracks=pred_tracks, tracks=pred_tracks,
visibility=pred_visibility, visibility=pred_visibility,
filename='dense'); filename='dense2');
print(f'done')
end_time = time.time()
print("Time taken: {:.2f} seconds".format(end_time - start_time))