"""Run RAFT optical flow on a video and save a frame/flow visualisation video.

Each output frame stacks a source video frame on top of the colour-coded flow
estimated between that frame and the next one.
"""
import sys
sys.path.append('core')

import argparse
import os
import cv2
import glob
import numpy as np
import torch
from PIL import Image
import time

from raft import RAFT
from utils import flow_viz
from utils.utils import InputPadder


# Fall back to the CPU so the demo still runs on machines without a GPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'


def load_image(imfile):
    """Load an image file as a (1, 3, H, W) float tensor on DEVICE.

    The image is converted to RGB so grayscale/RGBA files still yield exactly
    three channels. Pixel values stay in the 0-255 range.
    """
    with Image.open(imfile) as pil_img:
        img = np.array(pil_img.convert('RGB')).astype(np.uint8)
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img[None].to(DEVICE)


def viz(img, flo):
    """Show an RGB image and its colour-coded flow stacked vertically.

    Args:
        img: (1, 3, H, W) RGB image tensor with values in 0-255, e.g. from
            load_image.
        flo: flow tensor whose [0] slice is (C, H, W); presumably C == 2
            (u, v) — TODO confirm against the RAFT output.

    Blocks until a key is pressed in the OpenCV window.
    """
    img = img[0].permute(1, 2, 0).cpu().numpy()
    flo = flo[0].permute(1, 2, 0).cpu().numpy()

    # map flow to rgb image
    flo = flow_viz.flow_to_image(flo)
    img_flo = np.concatenate([img, flo], axis=0)

    # Both halves are RGB here; cv2.imshow expects BGR in [0, 1] for floats.
    cv2.imshow('image', img_flo[:, :, [2, 1, 0]] / 255.0)
    cv2.waitKey()


def _load_model(args):
    """Build RAFT from the CLI args, load args.model and switch to eval mode."""
    model = torch.nn.DataParallel(RAFT(args))
    print(f'start loading model from {args.model}')
    # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))
    print('model loaded')
    # Weights are loaded into the DataParallel wrapper (presumably because the
    # checkpoint keys carry its "module." prefix), then the wrapper is dropped.
    model = model.module
    model.to(DEVICE)
    model.eval()
    return model


def _frame_to_tensor(frame):
    """Convert an (H, W, 3) OpenCV frame to a (1, 3, H, W) float tensor on DEVICE."""
    return torch.from_numpy(frame).permute(2, 0, 1).float()[None].to(DEVICE)


def _compose_frame(frame, flow):
    """Stack a BGR frame above its colour-coded flow as one uint8 BGR image.

    Args:
        frame: (H, W, 3) BGR frame as returned by cv2.VideoCapture.read.
        flow: flow tensor whose [0] slice is (C, H, W) with the same H, W as
            `frame` (i.e. already unpadded).
    """
    flow = flow[0].permute(1, 2, 0).cpu().numpy()
    # flow_to_image yields RGB (viz() reverses channels before cv2.imshow), while
    # the video frame is already BGR: reverse only the flow half for OpenCV.
    flow_img = flow_viz.flow_to_image(flow)[:, :, ::-1]
    return np.concatenate([frame, flow_img], axis=0).astype(np.uint8)


def demo(args):
    """Estimate flow between consecutive frames of args.video_path and save it.

    Each output frame is the earlier frame of a pair stacked above the flow
    visualisation. The output video is therefore (width, 2 * height) and has
    one frame fewer than the input. It is written to args.save_path at
    args.fps (30 when the attribute is absent).
    """
    model = _load_model(args)
    fps = getattr(args, 'fps', 30)

    capture = cv2.VideoCapture(args.video_path)
    out = None
    try:
        # read() also returns ret == False when the capture failed to open.
        ret, prev_frame = capture.read()
        if not ret:
            print("open video error!")
            return
        print(prev_frame.shape)
        height, width = prev_frame.shape[:2]

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(args.save_path, fourcc, fps, (width, height * 2))

        prev_tensor = _frame_to_tensor(prev_frame)
        # Every frame of a video has the same size, so one padder suffices.
        padder = InputPadder(prev_tensor.shape)

        start_time = time.time()
        with torch.no_grad():
            while True:
                ret, frame = capture.read()
                if not ret:
                    break
                tensor = _frame_to_tensor(frame)
                image1, image2 = padder.pad(prev_tensor, tensor)
                _, flow_up = model(image1, image2, iters=20, test_mode=True)
                # Crop the padding off so the composed frame matches the size
                # the VideoWriter was opened with.
                out.write(_compose_frame(prev_frame, padder.unpad(flow_up)))
                prev_frame, prev_tensor = frame, tensor
        print("time using:", time.time() - start_time)
    finally:
        if out is not None:
            out.release()
        capture.release()
        cv2.destroyAllWindows()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', help="restore checkpoint")
    parser.add_argument('--path', help="dataset for evaluation")
    parser.add_argument('--video_path', default='1.mp4', help="path to video")
    parser.add_argument('--save_path', default='res_1.mp4', help="path to save video")
    parser.add_argument('--fps', type=float, default=30, help="frame rate of the saved video")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')
    args = parser.parse_args()

    demo(args)