updated demo for longer sequences
This commit is contained in:
parent
25eb2ac723
commit
d3f3840186
29
demo.py
29
demo.py
@ -20,21 +20,9 @@ DEVICE = 'cuda'
|
||||
def load_image(imfile):
|
||||
img = np.array(Image.open(imfile)).astype(np.uint8)
|
||||
img = torch.from_numpy(img).permute(2, 0, 1).float()
|
||||
return img
|
||||
return img[None].to(DEVICE)
|
||||
|
||||
|
||||
def load_image_list(image_files):
|
||||
images = []
|
||||
for imfile in sorted(image_files):
|
||||
images.append(load_image(imfile))
|
||||
|
||||
images = torch.stack(images, dim=0)
|
||||
images = images.to(DEVICE)
|
||||
|
||||
padder = InputPadder(images.shape)
|
||||
return padder.pad(images)[0]
|
||||
|
||||
|
||||
def viz(img, flo):
|
||||
img = img[0].permute(1,2,0).cpu().numpy()
|
||||
flo = flo[0].permute(1,2,0).cpu().numpy()
|
||||
@ -43,6 +31,10 @@ def viz(img, flo):
|
||||
flo = flow_viz.flow_to_image(flo)
|
||||
img_flo = np.concatenate([img, flo], axis=0)
|
||||
|
||||
# import matplotlib.pyplot as plt
|
||||
# plt.imshow(img_flo / 255.0)
|
||||
# plt.show()
|
||||
|
||||
cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
|
||||
cv2.waitKey()
|
||||
|
||||
@ -58,11 +50,14 @@ def demo(args):
|
||||
with torch.no_grad():
|
||||
images = glob.glob(os.path.join(args.path, '*.png')) + \
|
||||
glob.glob(os.path.join(args.path, '*.jpg'))
|
||||
|
||||
images = sorted(images)
|
||||
for imfile1, imfile2 in zip(images[:-1], images[1:]):
|
||||
image1 = load_image(imfile1)
|
||||
image2 = load_image(imfile2)
|
||||
|
||||
images = load_image_list(images)
|
||||
for i in range(images.shape[0]-1):
|
||||
image1 = images[i,None]
|
||||
image2 = images[i+1,None]
|
||||
padder = InputPadder(image1.shape)
|
||||
image1, image2 = padder.pad(image1, image2)
|
||||
|
||||
flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
|
||||
viz(image1, flow_up)
|
||||
|
Loading…
Reference in New Issue
Block a user