RAFT/demo.py

131 lines
4.3 KiB
Python
Raw Permalink Normal View History

2020-03-27 04:19:08 +01:00
import sys
sys.path.append('core')
import argparse
import os
import cv2
2020-07-26 01:36:17 +02:00
import glob
2020-03-27 04:19:08 +01:00
import numpy as np
import torch
from PIL import Image
2024-09-02 23:45:54 +02:00
import time
2020-03-27 04:19:08 +01:00
from raft import RAFT
2020-07-26 01:36:17 +02:00
from utils import flow_viz
from utils.utils import InputPadder
2020-03-27 04:19:08 +01:00
2020-07-26 01:36:17 +02:00
# Run on the GPU when one is available; fall back to CPU so the demo still
# works (slowly) on machines without CUDA instead of failing on `.to('cuda')`.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
2020-03-27 04:19:08 +01:00
def load_image(imfile):
    """Load an image file as a 1x3xHxW float tensor on DEVICE.

    Pixel values stay in the 0-255 range (no normalisation is applied here).

    Args:
        imfile: path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        A float tensor of shape (1, 3, H, W) in RGB channel order.
    """
    # Use a context manager so the underlying file is closed as soon as the
    # pixels are read. Force RGB so grayscale, palette and RGBA inputs all
    # produce exactly 3 channels; otherwise a 2-D array breaks the permute
    # below, or a 4-channel tensor is produced.
    with Image.open(imfile) as pil_img:
        img = np.array(pil_img.convert('RGB'))  # uint8, (H, W, 3)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)
2020-07-26 01:36:17 +02:00
def viz(img, flo):
    """Display an image stacked above a colour-coded rendering of its flow.

    Only the first element of each batch is shown. ``img`` and ``flo`` are
    indexed as rank-4 channel-first tensors (presumably (N, 3, H, W) and
    (N, 2, H, W) -- confirm against callers). Blocks until a key is pressed
    in the OpenCV window.
    """
    frame = img[0].permute(1, 2, 0).cpu().numpy()
    flow_rgb = flow_viz.flow_to_image(flo[0].permute(1, 2, 0).cpu().numpy())
    stacked = np.concatenate((frame, flow_rgb), axis=0)

    # Reverse the channel order for cv2 (BGR) and scale into [0, 1], which is
    # how cv2.imshow interprets floating-point images.
    bgr = stacked[:, :, [2, 1, 0]] / 255.0
    cv2.imshow('image', bgr)
    cv2.waitKey()
2024-09-02 23:45:54 +02:00
# def demo(args):
# model = torch.nn.DataParallel(RAFT(args))
# model.load_state_dict(torch.load(args.model))
# model = model.module
# model.to(DEVICE)
# model.eval()
# with torch.no_grad():
# images = glob.glob(os.path.join(args.path, '*.png')) + \
# glob.glob(os.path.join(args.path, '*.jpg'))
# images = sorted(images)
# for imfile1, imfile2 in zip(images[:-1], images[1:]):
# image1 = load_image(imfile1)
# image2 = load_image(imfile2)
# padder = InputPadder(image1.shape)
# image1, image2 = padder.pad(image1, image2)
# flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
# viz(image1, flow_up)
2020-03-27 04:19:08 +01:00
def _frame_to_tensor(frame):
    """Convert an OpenCV frame (H, W, 3, BGR uint8) into a 1x3xHxW float RGB tensor on DEVICE."""
    # cv2 decodes video frames as BGR, whereas load_image feeds the model RGB;
    # convert so video frames reach the network in the same channel order.
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    return torch.from_numpy(rgb).permute(2, 0, 1).float()[None].to(DEVICE)


def _render_frame(image, flow):
    """Stack an RGB frame tensor above its colour-coded flow; return a BGR uint8 array for cv2.VideoWriter."""
    image = image[0].permute(1, 2, 0).cpu().numpy()
    flow = flow[0].permute(1, 2, 0).cpu().numpy()
    # flow_to_image is called with its default colour order (treated as RGB
    # here, consistent with viz()).
    flow_rgb = flow_viz.flow_to_image(flow)
    stacked = np.concatenate([image, flow_rgb], axis=0)
    # Both halves are RGB; VideoWriter expects BGR.
    return np.uint8(stacked[:, :, [2, 1, 0]])


def _load_model(args):
    """Build RAFT from the parsed CLI ``args``, load ``args.model`` and return it on DEVICE in eval mode."""
    model = torch.nn.DataParallel(RAFT(args))
    print(f'start loading model from {args.model}')
    # NOTE(review): torch.load unpickles the file -- only pass trusted
    # checkpoints. map_location lets a GPU-saved checkpoint load on DEVICE.
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))
    print('model loaded')
    # The state dict is loaded through the DataParallel wrapper (presumably
    # because the checkpoint keys carry its "module." prefix); unwrap after.
    model = model.module
    model.to(DEVICE)
    model.eval()
    return model


def demo(args):
    """Estimate optical flow between consecutive frames of ``args.video_path``.

    For every pair of consecutive frames one frame is written to
    ``args.save_path``: the earlier frame stacked above a colour rendering of
    the flow to the next frame. Output frames are therefore twice the input
    height, and the output has one frame fewer than the input. The output
    uses the source frame rate (30 if it cannot be determined).

    Prints "open video error!" and returns if no frame can be read.
    """
    model = _load_model(args)
    capture = cv2.VideoCapture(args.video_path)
    writer = None
    try:
        with torch.no_grad():
            # Check the capture before touching the first frame: on a capture
            # that failed to open (or an empty video) there is no frame to use.
            ok, frame = capture.read() if capture.isOpened() else (False, None)
            if not ok:
                print("open video error!")
                return
            print(frame.shape)
            height, width = frame.shape[:2]

            fps = capture.get(cv2.CAP_PROP_FPS)
            if not fps > 0:  # also catches NaN
                fps = 30
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            writer = cv2.VideoWriter(args.save_path, fourcc, fps, (width, height * 2))

            prev = _frame_to_tensor(frame)
            # All frames of a video share one size, so one padder serves every pair.
            padder = InputPadder(prev.shape)
            start_time = time.time()
            while True:
                ok, frame = capture.read()
                if not ok:
                    break
                curr = _frame_to_tensor(frame)
                image1, image2 = padder.pad(prev, curr)
                _, flow_up = model(image1, image2, iters=20, test_mode=True)
                # Crop the padding off the flow and render with the unpadded
                # frame, so the result matches the (width, height * 2) size the
                # writer was opened with. Frames of another size are typically
                # not written at all.
                writer.write(_render_frame(prev, padder.unpad(flow_up)))
                prev = curr
            print("time using:", time.time() - start_time)
    finally:
        if writer is not None:
            writer.release()
        capture.release()
        cv2.destroyAllWindows()
2020-03-27 04:19:08 +01:00
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # demo() hands this path straight to torch.load, so it cannot be omitted.
    parser.add_argument('--model', required=True, help="restore checkpoint")
    # Not read by demo(), which processes --video_path instead.
    parser.add_argument('--path', help="dataset for evaluation")
    parser.add_argument('--video_path', default='1.mp4', help="path to video")
    parser.add_argument('--save_path', default='res_1.mp4', help="path to save video")
    # The flags below are passed along inside `args` to RAFT(args)
    # (presumably read by the model constructor -- confirm in core/raft.py).
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
    args = parser.parse_args()
    demo(args)