make raft can process single video
This commit is contained in:
parent
3fa0bb0a9c
commit
73892c4392
83
demo.py
83
demo.py
@ -8,6 +8,7 @@ import glob
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
import time
|
||||||
|
|
||||||
from raft import RAFT
|
from raft import RAFT
|
||||||
from utils import flow_viz
|
from utils import flow_viz
|
||||||
@ -39,34 +40,88 @@ def viz(img, flo):
|
|||||||
cv2.waitKey()
|
cv2.waitKey()
|
||||||
|
|
||||||
|
|
||||||
|
# def demo(args):
|
||||||
|
# model = torch.nn.DataParallel(RAFT(args))
|
||||||
|
# model.load_state_dict(torch.load(args.model))
|
||||||
|
|
||||||
|
# model = model.module
|
||||||
|
# model.to(DEVICE)
|
||||||
|
# model.eval()
|
||||||
|
|
||||||
|
# with torch.no_grad():
|
||||||
|
# images = glob.glob(os.path.join(args.path, '*.png')) + \
|
||||||
|
# glob.glob(os.path.join(args.path, '*.jpg'))
|
||||||
|
|
||||||
|
# images = sorted(images)
|
||||||
|
# for imfile1, imfile2 in zip(images[:-1], images[1:]):
|
||||||
|
# image1 = load_image(imfile1)
|
||||||
|
# image2 = load_image(imfile2)
|
||||||
|
|
||||||
|
# padder = InputPadder(image1.shape)
|
||||||
|
# image1, image2 = padder.pad(image1, image2)
|
||||||
|
|
||||||
|
# flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
|
||||||
|
# viz(image1, flow_up)
|
||||||
def demo(args):
|
def demo(args):
|
||||||
model = torch.nn.DataParallel(RAFT(args))
|
model = torch.nn.DataParallel(RAFT(args))
|
||||||
|
print(f'start loading model from {args.model}')
|
||||||
model.load_state_dict(torch.load(args.model))
|
model.load_state_dict(torch.load(args.model))
|
||||||
|
print('model loaded')
|
||||||
|
|
||||||
|
|
||||||
model = model.module
|
model = model.module
|
||||||
model.to(DEVICE)
|
model.to(DEVICE)
|
||||||
model.eval()
|
model.eval()
|
||||||
|
i=0
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
images = glob.glob(os.path.join(args.path, '*.png')) + \
|
capture = cv2.VideoCapture(args.video_path)
|
||||||
glob.glob(os.path.join(args.path, '*.jpg'))
|
# fps = capture.get(cv2.CAP_PROP_FPS)
|
||||||
|
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
||||||
images = sorted(images)
|
# out = cv2.VideoWriter('./F1_1280.mp4',fourcc,fps,(1280,740))
|
||||||
for imfile1, imfile2 in zip(images[:-1], images[1:]):
|
ret,image1 = capture.read()
|
||||||
image1 = load_image(imfile1)
|
# image1 = cv2.resize(image1,(1280,720))
|
||||||
image2 = load_image(imfile2)
|
# out.write(image1)
|
||||||
|
print(image1.shape)
|
||||||
padder = InputPadder(image1.shape)
|
width = int(image1.shape[1])
|
||||||
image1, image2 = padder.pad(image1, image2)
|
height = int(image1.shape[0])
|
||||||
|
image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
|
||||||
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
|
image1 = image1[None].to(DEVICE)
|
||||||
viz(image1, flow_up)
|
#width = int(img.shape[1])*2
|
||||||
|
out = cv2.VideoWriter(args.save_path,fourcc,30,(width,height*2))
|
||||||
|
if capture.isOpened():
|
||||||
|
start_time = time.time()
|
||||||
|
while True:
|
||||||
|
ret,image2 = capture.read()
|
||||||
|
if not ret:break
|
||||||
|
image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
|
||||||
|
image2 = image2[None].to(DEVICE)
|
||||||
|
pre = image2
|
||||||
|
padder = InputPadder(image1.shape)
|
||||||
|
image1, image2 = padder.pad(image1, image2)
|
||||||
|
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
|
||||||
|
image1 = image1[0].permute(1,2,0).cpu().numpy()
|
||||||
|
flow_up = flow_up[0].permute(1,2,0).cpu().numpy()
|
||||||
|
# map flow to rgb image
|
||||||
|
flow_up = flow_viz.flow_to_image(flow_up)
|
||||||
|
img_flo = np.concatenate([image1, flow_up], axis=0)
|
||||||
|
img_flo = img_flo[:, :, [2,1,0]]
|
||||||
|
out.write(np.uint8(img_flo))
|
||||||
|
image1 = pre
|
||||||
|
end_time = time.time()
|
||||||
|
print("time using:",end_time-start_time)
|
||||||
|
|
||||||
|
else:
|
||||||
|
print("open video error!")
|
||||||
|
out.release()
|
||||||
|
capture.release()
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument('--model', help="restore checkpoint")
|
parser.add_argument('--model', help="restore checkpoint")
|
||||||
parser.add_argument('--path', help="dataset for evaluation")
|
parser.add_argument('--path', help="dataset for evaluation")
|
||||||
|
parser.add_argument('--video_path', default='1.mp4', help="path to video")
|
||||||
|
parser.add_argument('--save_path', default='res_1.mp4', help="path to save video")
|
||||||
parser.add_argument('--small', action='store_true', help='use small model')
|
parser.add_argument('--small', action='store_true', help='use small model')
|
||||||
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
|
||||||
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
|
parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
|
||||||
|
Loading…
Reference in New Issue
Block a user