updated demo for longer sequences
This commit is contained in:
parent
25eb2ac723
commit
d3f3840186
29
demo.py
29
demo.py
@ -20,21 +20,9 @@ DEVICE = 'cuda'
|
|||||||
def load_image(imfile):
|
def load_image(imfile):
|
||||||
img = np.array(Image.open(imfile)).astype(np.uint8)
|
img = np.array(Image.open(imfile)).astype(np.uint8)
|
||||||
img = torch.from_numpy(img).permute(2, 0, 1).float()
|
img = torch.from_numpy(img).permute(2, 0, 1).float()
|
||||||
return img
|
return img[None].to(DEVICE)
|
||||||
|
|
||||||
|
|
||||||
def load_image_list(image_files):
|
|
||||||
images = []
|
|
||||||
for imfile in sorted(image_files):
|
|
||||||
images.append(load_image(imfile))
|
|
||||||
|
|
||||||
images = torch.stack(images, dim=0)
|
|
||||||
images = images.to(DEVICE)
|
|
||||||
|
|
||||||
padder = InputPadder(images.shape)
|
|
||||||
return padder.pad(images)[0]
|
|
||||||
|
|
||||||
|
|
||||||
def viz(img, flo):
|
def viz(img, flo):
|
||||||
img = img[0].permute(1,2,0).cpu().numpy()
|
img = img[0].permute(1,2,0).cpu().numpy()
|
||||||
flo = flo[0].permute(1,2,0).cpu().numpy()
|
flo = flo[0].permute(1,2,0).cpu().numpy()
|
||||||
@ -43,6 +31,10 @@ def viz(img, flo):
|
|||||||
flo = flow_viz.flow_to_image(flo)
|
flo = flow_viz.flow_to_image(flo)
|
||||||
img_flo = np.concatenate([img, flo], axis=0)
|
img_flo = np.concatenate([img, flo], axis=0)
|
||||||
|
|
||||||
|
# import matplotlib.pyplot as plt
|
||||||
|
# plt.imshow(img_flo / 255.0)
|
||||||
|
# plt.show()
|
||||||
|
|
||||||
cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
|
cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
|
||||||
cv2.waitKey()
|
cv2.waitKey()
|
||||||
|
|
||||||
@ -58,11 +50,14 @@ def demo(args):
|
|||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
images = glob.glob(os.path.join(args.path, '*.png')) + \
|
images = glob.glob(os.path.join(args.path, '*.png')) + \
|
||||||
glob.glob(os.path.join(args.path, '*.jpg'))
|
glob.glob(os.path.join(args.path, '*.jpg'))
|
||||||
|
|
||||||
|
images = sorted(images)
|
||||||
|
for imfile1, imfile2 in zip(images[:-1], images[1:]):
|
||||||
|
image1 = load_image(imfile1)
|
||||||
|
image2 = load_image(imfile2)
|
||||||
|
|
||||||
images = load_image_list(images)
|
padder = InputPadder(image1.shape)
|
||||||
for i in range(images.shape[0]-1):
|
image1, image2 = padder.pad(image1, image2)
|
||||||
image1 = images[i,None]
|
|
||||||
image2 = images[i+1,None]
|
|
||||||
|
|
||||||
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
|
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
|
||||||
viz(image1, flow_up)
|
viz(image1, flow_up)
|
||||||
|
Loading…
Reference in New Issue
Block a user