RAFT/demo.py

80 lines
2.0 KiB
Python
Raw Normal View History

2020-03-27 04:19:08 +01:00
import sys
sys.path.append('core')
import argparse
import os
import cv2
2020-07-26 01:36:17 +02:00
import glob
2020-03-27 04:19:08 +01:00
import numpy as np
import torch
from PIL import Image
from raft import RAFT
2020-07-26 01:36:17 +02:00
from utils import flow_viz
from utils.utils import InputPadder
2020-03-27 04:19:08 +01:00
2020-07-26 01:36:17 +02:00
# Run on the GPU when one is available; otherwise fall back to the CPU so the
# demo still runs (slowly) on machines without CUDA instead of crashing.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
2020-03-27 04:19:08 +01:00
def load_image(imfile):
    """Load an image file as a float tensor of shape (3, H, W).

    Pixel values stay in the 0-255 range; no normalisation is applied here.
    The image is converted to RGB first, so grayscale, palette and RGBA
    inputs all yield exactly three channels. Previously a 2-D grayscale array
    crashed in ``permute`` and RGBA produced a 4-channel tensor.

    Args:
        imfile: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        Float tensor of shape (3, H, W) with values in [0, 255].
    """
    # Context manager closes the underlying file promptly instead of leaving
    # it open until the PIL object is garbage collected.
    with Image.open(imfile) as pil_img:
        img = np.array(pil_img.convert('RGB')).astype(np.uint8)
    # HWC (numpy/PIL layout) -> CHW (torch layout).
    img = torch.from_numpy(img).permute(2, 0, 1).float()
    return img
2020-03-27 04:19:08 +01:00
2020-07-26 01:36:17 +02:00
def load_image_list(image_files):
    """Load, stack and pad a set of images into one batch tensor.

    Files are processed in sorted (lexicographic) order, so frames named
    e.g. frame_0001.png, frame_0002.png end up adjacent in the batch.

    Args:
        image_files: Iterable of image paths. All images must share the same
            shape, because they are combined with ``torch.stack``.

    Returns:
        Tensor of shape (N, C, H', W') on ``DEVICE``. H' and W' are whatever
        ``InputPadder`` pads to (presumably sizes the network accepts --
        see utils.utils.InputPadder).

    Raises:
        ValueError: If ``image_files`` is empty. Without this check,
            ``torch.stack`` fails with an opaque error.
    """
    image_files = sorted(image_files)
    if not image_files:
        raise ValueError('load_image_list: no image files were given')

    images = torch.stack([load_image(imfile) for imfile in image_files], dim=0)
    images = images.to(DEVICE)

    # pad() returns a sequence; only the first element (the padded batch) is
    # kept, so the padder itself is discarded after this call.
    padder = InputPadder(images.shape)
    return padder.pad(images)[0]
2020-03-27 04:19:08 +01:00
2020-07-26 01:36:17 +02:00
def viz(img, flo):
    """Display the first image of a batch above its colour-coded flow.

    Opens an OpenCV window named 'image' and blocks until a key is pressed.

    Args:
        img: Batched image tensor (rank 4, channels first); only element 0
            is shown. Values are assumed to be in [0, 255].
        flo: Batched flow tensor, channels first (presumably (N, 2, H, W) --
            confirm against the model output); only element 0 is shown.
    """
    # First sample of each batch, moved to host memory as HWC numpy arrays.
    frame = img[0].permute(1, 2, 0).cpu().numpy()
    flow = flo[0].permute(1, 2, 0).cpu().numpy()

    # Frame on top, flow rendered as an RGB image (flow_viz) underneath.
    stacked = np.concatenate([frame, flow_viz.flow_to_image(flow)], axis=0)

    # Reorder RGB -> BGR for OpenCV and scale to [0, 1] for a float image.
    bgr = stacked[:, :, [2, 1, 0]] / 255.0
    cv2.imshow('image', bgr)
    cv2.waitKey()
def demo(args):
    """Run RAFT on consecutive frames in ``args.path`` and display the flow.

    Args:
        args: Parsed command-line namespace. ``args.model`` is the checkpoint
            path and ``args.path`` a directory of .png/.jpg frames. The whole
            namespace is also passed to ``RAFT`` as its configuration.
    """
    # The checkpoint is loaded into a DataParallel wrapper and then unwrapped,
    # presumably because it was saved with "module."-prefixed keys.
    # map_location lets a checkpoint saved on one device load onto whatever
    # DEVICE this machine uses (e.g. a GPU checkpoint on a CPU-only host).
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load(args.model, map_location=DEVICE))
    model = model.module
    model.to(DEVICE)
    model.eval()

    with torch.no_grad():
        image_files = glob.glob(os.path.join(args.path, '*.png')) + \
            glob.glob(os.path.join(args.path, '*.jpg'))

        images = load_image_list(image_files)
        # Estimate flow between each pair of consecutive frames.
        for i in range(images.shape[0] - 1):
            # Indexing with None keeps a leading batch dimension of size 1.
            image1 = images[i, None]
            image2 = images[i + 1, None]

            # Only the upsampled flow is visualised; the low-res one is unused.
            _, flow_up = model(image1, image2, iters=20, test_mode=True)
            viz(image1, flow_up)
2020-03-27 04:19:08 +01:00
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # --model and --path are mandatory: demo() passes them straight to
    # torch.load and os.path.join, which fail with unhelpful errors on None.
    parser.add_argument('--model', required=True, help="restore checkpoint")
    parser.add_argument('--path', required=True, help="dataset for evaluation")
    parser.add_argument('--small', action='store_true', help='use small model')
    parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
    args = parser.parse_args()

    demo(args)