diff --git a/demo.py b/demo.py index 88e5975..5abc1da 100644 --- a/demo.py +++ b/demo.py @@ -20,21 +20,9 @@ DEVICE = 'cuda' def load_image(imfile): img = np.array(Image.open(imfile)).astype(np.uint8) img = torch.from_numpy(img).permute(2, 0, 1).float() - return img + return img[None].to(DEVICE) -def load_image_list(image_files): - images = [] - for imfile in sorted(image_files): - images.append(load_image(imfile)) - - images = torch.stack(images, dim=0) - images = images.to(DEVICE) - - padder = InputPadder(images.shape) - return padder.pad(images)[0] - - def viz(img, flo): img = img[0].permute(1,2,0).cpu().numpy() flo = flo[0].permute(1,2,0).cpu().numpy() @@ -43,6 +31,10 @@ def viz(img, flo): flo = flow_viz.flow_to_image(flo) img_flo = np.concatenate([img, flo], axis=0) + # import matplotlib.pyplot as plt + # plt.imshow(img_flo / 255.0) + # plt.show() + cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) cv2.waitKey() @@ -58,11 +50,14 @@ def demo(args): with torch.no_grad(): images = glob.glob(os.path.join(args.path, '*.png')) + \ glob.glob(os.path.join(args.path, '*.jpg')) + + images = sorted(images) + for imfile1, imfile2 in zip(images[:-1], images[1:]): + image1 = load_image(imfile1) + image2 = load_image(imfile2) - images = load_image_list(images) - for i in range(images.shape[0]-1): - image1 = images[i,None] - image2 = images[i+1,None] + padder = InputPadder(image1.shape) + image1, image2 = padder.pad(image1, image2) flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) viz(image1, flow_up)