make RAFT process a single video

mhz 2024-09-02 23:45:54 +02:00
parent 3fa0bb0a9c
commit 73892c4392

demo.py (83 changed lines)

@@ -8,6 +8,7 @@ import glob
 import numpy as np
 import torch
 from PIL import Image
+import time
 
 from raft import RAFT
 from utils import flow_viz
@@ -39,34 +40,88 @@ def viz(img, flo):
     cv2.waitKey()
 
 
+# def demo(args):
+#     model = torch.nn.DataParallel(RAFT(args))
+#     model.load_state_dict(torch.load(args.model))
+
+#     model = model.module
+#     model.to(DEVICE)
+#     model.eval()
+
+#     with torch.no_grad():
+#         images = glob.glob(os.path.join(args.path, '*.png')) + \
+#                  glob.glob(os.path.join(args.path, '*.jpg'))
+
+#         images = sorted(images)
+#         for imfile1, imfile2 in zip(images[:-1], images[1:]):
+#             image1 = load_image(imfile1)
+#             image2 = load_image(imfile2)
+
+#             padder = InputPadder(image1.shape)
+#             image1, image2 = padder.pad(image1, image2)
+
+#             flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
+#             viz(image1, flow_up)
+
 def demo(args):
     model = torch.nn.DataParallel(RAFT(args))
+    print(f'start loading model from {args.model}')
     model.load_state_dict(torch.load(args.model))
+    print('model loaded')
 
     model = model.module
     model.to(DEVICE)
     model.eval()
+    i=0
 
     with torch.no_grad():
-        images = glob.glob(os.path.join(args.path, '*.png')) + \
-                 glob.glob(os.path.join(args.path, '*.jpg'))
-
-        images = sorted(images)
-        for imfile1, imfile2 in zip(images[:-1], images[1:]):
-            image1 = load_image(imfile1)
-            image2 = load_image(imfile2)
-
-            padder = InputPadder(image1.shape)
-            image1, image2 = padder.pad(image1, image2)
-
-            flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
-            viz(image1, flow_up)
+        capture = cv2.VideoCapture(args.video_path)
+        # fps = capture.get(cv2.CAP_PROP_FPS)
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        # out = cv2.VideoWriter('./F1_1280.mp4',fourcc,fps,(1280,740))
+        ret,image1 = capture.read()
+        # image1 = cv2.resize(image1,(1280,720))
+        # out.write(image1)
+        print(image1.shape)
+        width = int(image1.shape[1])
+        height = int(image1.shape[0])
+        image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
+        image1 = image1[None].to(DEVICE)
+        #width = int(img.shape[1])*2
+        out = cv2.VideoWriter(args.save_path,fourcc,30,(width,height*2))
+        if capture.isOpened():
+            start_time = time.time()
+            while True:
+                ret,image2 = capture.read()
+                if not ret:break
+                image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
+                image2 = image2[None].to(DEVICE)
+                pre = image2
+                padder = InputPadder(image1.shape)
+                image1, image2 = padder.pad(image1, image2)
+                flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
+                image1 = image1[0].permute(1,2,0).cpu().numpy()
+                flow_up = flow_up[0].permute(1,2,0).cpu().numpy()
+                # map flow to rgb image
+                flow_up = flow_viz.flow_to_image(flow_up)
+                img_flo = np.concatenate([image1, flow_up], axis=0)
+                img_flo = img_flo[:, :, [2,1,0]]
+                out.write(np.uint8(img_flo))
+                image1 = pre
+            end_time = time.time()
+            print("time using:",end_time-start_time)
+        else:
+            print("open video error!")
+        out.release()
+        capture.release()
+        cv2.destroyAllWindows()
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
     parser.add_argument('--model', help="restore checkpoint")
     parser.add_argument('--path', help="dataset for evaluation")
+    parser.add_argument('--video_path', default='1.mp4', help="path to video")
+    parser.add_argument('--save_path', default='res_1.mp4', help="path to save video")
     parser.add_argument('--small', action='store_true', help='use small model')
     parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
     parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
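
With this change the script would be run against a video rather than an image folder, e.g. python demo.py --model=models/raft-things.pth --video_path=1.mp4 --save_path=res_1.mp4 (the checkpoint path is an assumption; any trained RAFT checkpoint fits --model). The sketch below is a minimal, hypothetical programmatic equivalent: it builds the argparse.Namespace that the new demo() reads, using the flag names added in this commit; every value is a placeholder, not something shipped with the repo.

    # Minimal sketch (not part of the commit): call the updated demo() from Python.
    # Assumes it runs from the repository root so demo.py and its core/ imports resolve.
    import argparse
    from demo import demo

    args = argparse.Namespace(
        model='models/raft-things.pth',  # assumed checkpoint location
        path=None,                       # image-folder flag; unused by the video code path
        video_path='1.mp4',              # input video (default from the new flag)
        save_path='res_1.mp4',           # output: input frame stacked above the flow image
        small=False,
        mixed_precision=False,
        alternate_corr=False,
    )
    demo(args)

One thing to watch when trying this: InputPadder pads both frames to a multiple of 8 pixels, so the concatenated frame passed to out.write() can be slightly larger than the (width, height*2) size declared for the cv2.VideoWriter, and OpenCV may then drop those frames silently; cropping img_flo back to the original frame size before writing is a likely fix.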