diff --git a/demo.py b/demo.py index 5abc1da..7bde016 100644 --- a/demo.py +++ b/demo.py @@ -8,6 +8,7 @@ import glob import numpy as np import torch from PIL import Image +import time from raft import RAFT from utils import flow_viz @@ -39,34 +40,88 @@ def viz(img, flo): cv2.waitKey() +# def demo(args): +# model = torch.nn.DataParallel(RAFT(args)) +# model.load_state_dict(torch.load(args.model)) + +# model = model.module +# model.to(DEVICE) +# model.eval() + +# with torch.no_grad(): +# images = glob.glob(os.path.join(args.path, '*.png')) + \ +# glob.glob(os.path.join(args.path, '*.jpg')) + +# images = sorted(images) +# for imfile1, imfile2 in zip(images[:-1], images[1:]): +# image1 = load_image(imfile1) +# image2 = load_image(imfile2) + +# padder = InputPadder(image1.shape) +# image1, image2 = padder.pad(image1, image2) + +# flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) +# viz(image1, flow_up) def demo(args): model = torch.nn.DataParallel(RAFT(args)) + print(f'start loading model from {args.model}') model.load_state_dict(torch.load(args.model)) + print('model loaded') + model = model.module model.to(DEVICE) model.eval() - + i=0 with torch.no_grad(): - images = glob.glob(os.path.join(args.path, '*.png')) + \ - glob.glob(os.path.join(args.path, '*.jpg')) - - images = sorted(images) - for imfile1, imfile2 in zip(images[:-1], images[1:]): - image1 = load_image(imfile1) - image2 = load_image(imfile2) - - padder = InputPadder(image1.shape) - image1, image2 = padder.pad(image1, image2) - - flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) - viz(image1, flow_up) + capture = cv2.VideoCapture(args.video_path) + # fps = capture.get(cv2.CAP_PROP_FPS) + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + # out = cv2.VideoWriter('./F1_1280.mp4',fourcc,fps,(1280,740)) + ret,image1 = capture.read() + # image1 = cv2.resize(image1,(1280,720)) + # out.write(image1) + print(image1.shape) + width = int(image1.shape[1]) + height = int(image1.shape[0]) + image1 = torch.from_numpy(image1).permute(2, 0, 1).float() + image1 = image1[None].to(DEVICE) + #width = int(img.shape[1])*2 + out = cv2.VideoWriter(args.save_path,fourcc,30,(width,height*2)) + if capture.isOpened(): + start_time = time.time() + while True: + ret,image2 = capture.read() + if not ret:break + image2 = torch.from_numpy(image2).permute(2, 0, 1).float() + image2 = image2[None].to(DEVICE) + pre = image2 + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) + flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) + image1 = image1[0].permute(1,2,0).cpu().numpy() + flow_up = flow_up[0].permute(1,2,0).cpu().numpy() + # map flow to rgb image + flow_up = flow_viz.flow_to_image(flow_up) + img_flo = np.concatenate([image1, flow_up], axis=0) + img_flo = img_flo[:, :, [2,1,0]] + out.write(np.uint8(img_flo)) + image1 = pre + end_time = time.time() + print("time using:",end_time-start_time) + else: + print("open video error!") + out.release() + capture.release() + cv2.destroyAllWindows() if __name__ == '__main__': parser = argparse.ArgumentParser() parser.add_argument('--model', help="restore checkpoint") parser.add_argument('--path', help="dataset for evaluation") + parser.add_argument('--video_path', default='1.mp4', help="path to video") + parser.add_argument('--save_path', default='res_1.mp4', help="path to save video") parser.add_argument('--small', action='store_true', help='use small model') parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation')