# RAFT/demo.py

# 131 lines
# 4.3 KiB
# Python

import sys
sys.path.append('core')
import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image
import time
from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder
# Device every tensor and the model are moved to. Use the GPU when one is
# available and fall back to the CPU otherwise, so the script does not crash
# on machines without CUDA.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
def load_image(imfile):
    """Load an image file as a float tensor of shape (1, 3, H, W) on DEVICE.

    Args:
        imfile: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A ``torch.float32`` tensor with a leading batch dimension of 1 and
        channels in RGB order, holding values in the 0-255 range.
    """
    # Force 3-channel RGB so grayscale, palette and RGBA files also produce an
    # (H, W, 3) array; a 2-D grayscale array would fail in permute(2, 0, 1).
    # The context manager closes the underlying file handle promptly.
    with Image.open(imfile) as pil_img:
        img = np.array(pil_img.convert('RGB'), dtype=np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)
def viz(img, flo):
    """Display an image stacked above its colour-coded flow field.

    Args:
        img: batched image tensor; only the first batch element is shown.
        flo: batched flow tensor; only the first batch element is shown.

    The window is named 'image' and the call blocks in ``cv2.waitKey()``
    until a key is pressed.
    """
    # Take the first batch element and move channels last for NumPy/OpenCV.
    frame = img[0].permute(1, 2, 0).cpu().numpy()
    flow_field = flo[0].permute(1, 2, 0).cpu().numpy()

    # Turn the flow vectors into a colour image.
    flow_image = flow_viz.flow_to_image(flow_field)

    # Image on top, flow visualisation underneath.
    stacked = np.vstack((frame, flow_image))

    # Alternative display via matplotlib:
    # import matplotlib.pyplot as plt
    # plt.imshow(stacked / 255.0)
    # plt.show()

    # Reverse the channel order (presumably RGB -> BGR, which OpenCV expects)
    # and scale by 1/255 before showing the float image.
    reordered = stacked[:, :, [2, 1, 0]]
    cv2.imshow('image', reordered / 255.0)
    cv2.waitKey()
# def demo(args):
# model = torch.nn.DataParallel(RAFT(args))
# model.load_state_dict(torch.load(args.model))
# model = model.module
# model.to(DEVICE)
# model.eval()
# with torch.no_grad():
# images = glob.glob(os.path.join(args.path, '*.png')) + \
# glob.glob(os.path.join(args.path, '*.jpg'))
# images = sorted(images)
# for imfile1, imfile2 in zip(images[:-1], images[1:]):
# image1 = load_image(imfile1)
# image2 = load_image(imfile2)
# padder = InputPadder(image1.shape)
# image1, image2 = padder.pad(image1, image2)
# flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
# viz(image1, flow_up)
def _frame_to_tensor(frame_bgr):
    """Convert an OpenCV BGR frame (H, W, 3) to a (1, 3, H, W) float tensor on DEVICE."""
    # OpenCV decodes frames in BGR order, while the image-file path in this
    # script (PIL) yields RGB; convert so the model sees the same channel order.
    frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    return torch.from_numpy(frame_rgb).permute(2, 0, 1).float()[None].to(DEVICE)


def demo(args):
    """Run RAFT on consecutive frames of a video and save a visualisation.

    For every pair of neighbouring frames the flow from the earlier to the
    later frame is estimated, and one output frame is written consisting of
    the earlier input frame stacked on top of the colour-coded flow. The
    output video therefore has the input width and twice the input height.

    Args:
        args: parsed command-line namespace. ``args.model`` (checkpoint
            path), ``args.video_path`` and ``args.save_path`` are read here;
            the whole namespace is also passed to ``RAFT``.

    Prints "open video error!" and returns without writing anything when the
    video cannot be opened or contains no frame.
    """
    model = torch.nn.DataParallel(RAFT(args))
    print(f'start loading model from {args.model}')
    # map_location places the checkpoint tensors on DEVICE regardless of the
    # device they were saved from.
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))
    print('model loaded')
    model = model.module
    model.to(DEVICE)
    model.eval()

    capture = cv2.VideoCapture(args.video_path)
    out = None
    try:
        # Check the capture *before* touching the first frame; reading from a
        # capture that failed to open returns (False, None).
        if capture.isOpened():
            ret, prev_frame = capture.read()
        else:
            ret, prev_frame = False, None
        if not ret:
            print("open video error!")
            return

        print(prev_frame.shape)
        height, width = prev_frame.shape[:2]

        # Keep the playback speed of the source; fall back to 30 when the
        # container reports no usable frame rate (0, negative or NaN).
        fps = capture.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = 30
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(args.save_path, fourcc, fps, (width, height * 2))

        with torch.no_grad():
            prev_tensor = _frame_to_tensor(prev_frame)
            # All frames share one shape, so a single padder is enough.
            padder = InputPadder(prev_tensor.shape)

            start_time = time.time()
            while True:
                ret, frame = capture.read()
                if not ret:
                    break
                cur_tensor = _frame_to_tensor(frame)

                padded_prev, padded_cur = padder.pad(prev_tensor, cur_tensor)
                _, flow_up = model(padded_prev, padded_cur, iters=20, test_mode=True)

                # Remove the padding again so the flow matches the original
                # frame size (and thus the size the writer was opened with).
                flow = padder.unpad(flow_up)[0].permute(1, 2, 0).cpu().numpy()

                # map flow to rgb image, then to BGR for the video writer
                flow_rgb = flow_viz.flow_to_image(flow)
                flow_bgr = flow_rgb[:, :, ::-1]

                # The untouched BGR input frame goes on top, the flow below.
                img_flo = np.concatenate([prev_frame, flow_bgr], axis=0)
                out.write(np.ascontiguousarray(img_flo, dtype=np.uint8))

                prev_frame, prev_tensor = frame, cur_tensor
            end_time = time.time()
            print("time using:", end_time - start_time)
    finally:
        # Always release the writer and the capture, even if inference fails,
        # so the output file is finalised and handles are not leaked.
        if out is not None:
            out.release()
        capture.release()
        cv2.destroyAllWindows()
def build_arg_parser():
    """Build the command-line parser for the demo script.

    Returns:
        An ``argparse.ArgumentParser`` with the options ``--model``,
        ``--path``, ``--video_path``, ``--save_path``, ``--small``,
        ``--mixed_precision`` and ``--alternate_corr``.
    """
    parser = argparse.ArgumentParser()
    # The checkpoint is mandatory: without it the script would only fail
    # later while loading the model, so reject the call at parse time.
    parser.add_argument('--model', required=True, help="restore checkpoint")
    parser.add_argument('--path', help="dataset for evaluation")
    parser.add_argument('--video_path', default='1.mp4', help="path to video")
    parser.add_argument('--save_path', default='res_1.mp4', help="path to save video")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
    return parser


if __name__ == '__main__':
    args = build_arg_parser().parse_args()
    demo(args)