add time measure codes and update resolution
This commit is contained in:
parent
40e628ac73
commit
be0891967b
45
demo1.py
45
demo1.py
@ -3,14 +3,25 @@ import torch
|
|||||||
|
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from cotracker.utils.visualizer import Visualizer, read_video_from_path
|
from cotracker.utils.visualizer import Visualizer, read_video_from_path
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
import time
|
||||||
|
|
||||||
|
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
print(f'Using device: {device}')
|
||||||
|
print(f'start loading video')
|
||||||
video = read_video_from_path('./assets/F1_shorts.mp4')
|
video = read_video_from_path('./assets/F1_shorts.mp4')
|
||||||
|
print(f'video shape: {video.shape}')
|
||||||
|
# video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float().to(device)
|
||||||
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
|
video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()
|
||||||
|
end_time = time.time()
|
||||||
|
print(f'video shape after permute: {video.shape}')
|
||||||
|
print("Load video Time taken: {:.2f} seconds".format(end_time - start_time))
|
||||||
|
|
||||||
from cotracker.predictor import CoTrackerPredictor
|
from cotracker.predictor import CoTrackerPredictor
|
||||||
|
|
||||||
device = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')
|
|
||||||
|
|
||||||
model = CoTrackerPredictor(
|
model = CoTrackerPredictor(
|
||||||
checkpoint=os.path.join(
|
checkpoint=os.path.join(
|
||||||
@ -27,25 +38,45 @@ grid_query_frame=20
|
|||||||
|
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
# video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device)
|
# video_interp = F.interpolate(video[0], [200, 360], mode="bilinear")[None].to(device)
|
||||||
|
interp_size = (720, 1280)
|
||||||
|
video_interp = F.interpolate(video[0], [interp_size[0], interp_size[1]], mode="bilinear")[None].to(device)
|
||||||
|
print(f'video_interp shape: {video_interp.shape}')
|
||||||
|
|
||||||
import time
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
# pred_tracks, pred_visibility = model(video_interp,
|
# pred_tracks, pred_visibility = model(video_interp,
|
||||||
pred_tracks, pred_visibility = model(video,
|
input_mask='./assets/F1_mask.png'
|
||||||
grid_query_frame=grid_query_frame, backward_tracking=True)
|
segm_mask = Image.open(input_mask)
|
||||||
|
interp_size = (interp_size[1], interp_size[0])
|
||||||
|
segm_mask = segm_mask.resize(interp_size, Image.BILINEAR)
|
||||||
|
segm_mask = np.array(Image.open(input_mask))
|
||||||
|
segm_mask = torch.tensor(segm_mask).to(device)
|
||||||
|
# pred_tracks, pred_visibility = model(video,
|
||||||
|
pred_tracks, pred_visibility = model(video_interp,
|
||||||
|
grid_query_frame=grid_query_frame, backward_tracking=True,
|
||||||
|
segm_mask=segm_mask )
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
print("Time taken: {:.2f} seconds".format(end_time - start_time))
|
print("Time taken: {:.2f} seconds".format(end_time - start_time))
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
print(f'start visualizing')
|
||||||
vis = Visualizer(
|
vis = Visualizer(
|
||||||
save_dir='./videos',
|
save_dir='./videos',
|
||||||
pad_value=20,
|
pad_value=20,
|
||||||
linewidth=1,
|
linewidth=1,
|
||||||
mode='optical_flow'
|
mode='optical_flow'
|
||||||
)
|
)
|
||||||
|
print(f'vis initialized')
|
||||||
|
end_time = time.time()
|
||||||
|
print("Time taken: {:.2f} seconds".format(end_time - start_time))
|
||||||
|
start_time = time.time()
|
||||||
|
print(f'start visualize')
|
||||||
vis.visualize(
|
vis.visualize(
|
||||||
# video=video_interp,
|
video=video_interp,
|
||||||
video=video,
|
# video=video,
|
||||||
tracks=pred_tracks,
|
tracks=pred_tracks,
|
||||||
visibility=pred_visibility,
|
visibility=pred_visibility,
|
||||||
filename='dense');
|
filename='dense2');
|
||||||
|
print(f'done')
|
||||||
|
end_time = time.time()
|
||||||
|
print("Time taken: {:.2f} seconds".format(end_time - start_time))
|
Loading…
Reference in New Issue
Block a user