{
"cells": [
{
"cell_type": "markdown",
"id": "60a7e08e-93c6-4370-9778-3bb102dce78b",
"metadata": {},
"source": [
"Copyright (c) Meta Platforms, Inc. and affiliates."
]
},
{
"cell_type": "markdown",
"id": "3081cd8f-f6f9-4a1a-8c36-8a857b0c3b03",
"metadata": {},
"source": [
"<a target=\"_blank\" href=\"https://colab.research.google.com/github/facebookresearch/co-tracker/blob/main/notebooks/demo.ipynb\">\n",
" <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
"</a>"
]
},
{
"cell_type": "markdown",
"id": "f9f3240f-0354-4802-b8b5-9070930fc957",
"metadata": {},
"source": [
"# CoTracker: It is Better to Track Together\n",
"This is a demo for <a href=\"https://co-tracker.github.io/\">CoTracker</a>, a model that can track any point in a video."
]
},
{
"cell_type": "markdown",
"id": "36ff1fd0-572e-47fb-8221-1e73ac17cfd1",
"metadata": {},
"source": [
"<img src=\"https://www.robots.ox.ac.uk/~nikita/storage/cotracker/bmx-bumps.gif\" alt=\"Logo\" width=\"50%\">"
]
},
{
"cell_type": "markdown",
"id": "6757bfa3-d663-4a54-9722-3e1a7da3307c",
"metadata": {},
"source": [
"Let's import the dependencies:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1745a859-71d4-4ec3-8ef3-027cabe786d4",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import cv2\n",
"import torch\n",
"\n",
"from torchvision.io import read_video\n",
"from cotracker.utils.visualizer import Visualizer\n",
"from IPython.display import HTML"
]
},
{
"cell_type": "markdown",
"id": "7894bd2d-2099-46fa-8286-f0c56298ecd1",
"metadata": {},
"source": [
"Read a video from CO3D:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f1f9ca4d-951e-49d2-8844-91f7bcadfecd",
"metadata": {},
"outputs": [],
"source": [
"video = read_video('../assets/apple.mp4')[0]\n",
"video = video.permute(0, 3, 1, 2)[None].float()"
]
},
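{
"cell_type": "markdown",
"id": "a1f4c9d2-3b7e-4e21-9c0a-2f6d8b1e5a90",
"metadata": {},
"source": [
"`read_video` returns frames as a `(T, H, W, C)` uint8 tensor; the `permute` and `[None]` reshape it into the `(B, T, C, H, W)` float layout the model expects. A quick way to check (illustrative):\n",
"```python\n",
"print(video.shape)  # (batch, frames, channels, height, width)\n",
"```"
]
},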
{
"cell_type": "code",
"execution_count": 3,
"id": "fb4c2e9d-0e85-4c10-81a2-827d0759bf87",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"../assets/apple.mp4\" type=\"video/mp4\">\n",
" </video>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"HTML(\"\"\"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"../assets/apple.mp4\" type=\"video/mp4\">\n",
" </video>\"\"\")"
]
},
{
"cell_type": "markdown",
"id": "6f89ae18-54d0-4384-8a79-ca9247f5f31a",
"metadata": {},
"source": [
"Import CoTrackerPredictor and create an instance of it. We'll use this object to estimate tracks:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "d59ac40b-bde8-46d4-bd57-4ead939f22ca",
"metadata": {},
"outputs": [],
"source": [
"from cotracker.predictor import CoTrackerPredictor\n",
"\n",
"model = CoTrackerPredictor(\n",
" checkpoint=os.path.join(\n",
" '../checkpoints/cotracker_stride_4_wind_8.pth'\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "e8398155-6dae-4ff0-95f3-dbb52ac70d20",
"metadata": {},
"source": [
"Track points sampled on a regular grid of size 30\\*30 on the first frame:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "17fcaae9-7b3c-474c-977a-cce08a09d580",
"metadata": {},
"outputs": [],
"source": [
"pred_tracks, pred_visibility = model(video, grid_size=30)"
]
},
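{
"cell_type": "markdown",
"id": "b2e5d0c3-4c8f-4f32-8d1b-3a7e9c2f6b01",
"metadata": {},
"source": [
"Assuming the usual conventions of the CoTracker predictor, `pred_tracks` holds per-frame point coordinates with shape `(B, T, N, 2)` and `pred_visibility` a per-point visibility flag of shape `(B, T, N)`; exact shapes may differ between versions, so it is worth checking:\n",
"```python\n",
"print(pred_tracks.shape, pred_visibility.shape)\n",
"```"
]
},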
{
"cell_type": "markdown",
"id": "50a58521-a9ba-4f8b-be02-cfdaf79613a2",
"metadata": {},
"source": [
"Visualize and save the result: "
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "7e793ce0-7b77-46ca-a629-155a6a146000",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Video saved to ./videos/teaser_pred_track.mp4\n"
]
}
],
"source": [
"vis = Visualizer(save_dir='./videos', pad_value=100)\n",
"vis.visualize(video=video, tracks=pred_tracks, filename='teaser');"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "2d0733ba-8fe1-4cd4-b963-2085202fba13",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"./videos/teaser_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"HTML(\"\"\"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"./videos/teaser_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>\"\"\")"
]
},
{
"cell_type": "markdown",
"id": "73d88a5f-057c-4b9f-828d-ee0b97d1e72f",
"metadata": {},
"source": [
"## Tracking manually selected points"
]
},
{
"cell_type": "markdown",
"id": "a75bca85-b872-4f4e-be19-ff16f0984037",
"metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [
"We will start by tracking manually selected points.\n",
"A query point is defined as `[time, x coord, y coord]`.\n",
"\n",
"The code below defines four such points, sampled on frames 0, 10, 20, and 30:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c6422e7c-8c6f-4269-92c3-245344afe35b",
"metadata": {},
"outputs": [],
"source": [
"queries = torch.tensor([\n",
" [0., 400., 350.], # point tracked from the first frame\n",
" [10., 600., 500.], # frame number 10\n",
" [20., 750., 600.], # ...\n",
" [30., 900., 200.]\n",
"]).cuda()"
]
},
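{
"cell_type": "markdown",
"id": "c3f6e1d4-5d90-4a43-9e2c-4b8f0d3a7c12",
"metadata": {},
"source": [
"The `queries` tensor has shape `(N, 3)`; the model expects a leading batch dimension, which is why it is passed as `queries[None]` below. Note that `.cuda()` assumes a GPU is available; on a CPU-only machine you would keep the tensor (and the model) on the same CPU device instead, e.g. (sketch):\n",
"```python\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"queries = queries.to(device)\n",
"```"
]
},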
{
"cell_type": "markdown",
"id": "13697a2a-7304-4d18-93be-bfbebf3dc12a",
"metadata": {},
"source": [
"Here is what our query points look like:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d7141079-d7e0-40b3-b031-a28879c4bd6d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<Figure size 640x480 with 4 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"# Create a list of frame numbers corresponding to each point\n",
"frame_numbers = queries[:,0].int().tolist()\n",
"\n",
"fig, axs = plt.subplots(2, 2)\n",
"axs = axs.flatten()\n",
"\n",
"for i, (query, frame_number) in enumerate(zip(queries, frame_numbers)):\n",
" ax = axs[i]\n",
" ax.plot(query[1].item(), query[2].item(), 'ro') \n",
" \n",
" ax.set_title(\"Frame {}\".format(frame_number))\n",
" ax.set_xlim(0, video.shape[4])\n",
" ax.set_ylim(0, video.shape[3])\n",
" ax.invert_yaxis()\n",
" \n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "aec7693b-9d74-48b3-b612-360290ff1e7a",
"metadata": {},
"source": [
"We pass these points as input to the model and track them:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "09008ca9-6a87-494f-8b05-6370cae6a600",
"metadata": {},
"outputs": [],
"source": [
"pred_tracks, __ = model(video, queries=queries[None])"
]
},
{
"cell_type": "markdown",
"id": "b00d2a35-3daf-482d-b40b-b6d4f548ca40",
"metadata": {},
"source": [
"Finally, we visualize the results with tracks leaving traces from the frame where the tracking starts.\n",
"Color encodes time:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "01467f8d-667c-4f41-b418-93132584c659",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Video saved to ./videos/queries_pred_track.mp4\n"
]
}
],
"source": [
"vis = Visualizer(\n",
" save_dir='./videos',\n",
" linewidth=6,\n",
" mode='cool',\n",
" tracks_leave_trace=-1\n",
")\n",
"vis.visualize(\n",
" video=video,\n",
" tracks=pred_tracks, \n",
" filename='queries');"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "fe23d210-ed90-49f1-8311-b7e354c7a9f6",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"./videos/queries_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"HTML(\"\"\"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"./videos/queries_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>\"\"\")"
]
},
{
"cell_type": "markdown",
"id": "87f2a3b4-a8b3-4aeb-87d2-28f056c624ba",
"metadata": {},
"source": [
"## Points on a regular grid"
]
},
{
"cell_type": "markdown",
"id": "a9aac679-19f8-4b78-9cc9-d934c6f83b01",
"metadata": {},
"source": [
"### Tracking forward from a given frame"
]
},
{
"cell_type": "markdown",
"id": "0aeabca9-cc34-4d0f-8b2d-e6a6f797cb20",
"metadata": {},
"source": [
"Let's now sample points on a 30\\*30 regular grid and start tracking from frame 20. "
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "c880f3ca-cf42-4f64-9df6-a0e8de6561dc",
"metadata": {},
"outputs": [],
"source": [
"grid_size = 30\n",
"grid_query_frame = 20"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "3cd58820-7b23-469e-9b6d-5fa81257981f",
"metadata": {},
"outputs": [],
"source": [
"pred_tracks, __ = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "25a85a1d-dce0-4e6b-9f7a-aaf31ade0600",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Video saved to ./videos/grid_query_20_pred_track.mp4\n"
]
}
],
"source": [
"vis = Visualizer(save_dir='./videos', pad_value=100)\n",
"vis.visualize(\n",
" video=video,\n",
" tracks=pred_tracks, \n",
" filename='grid_query_20',\n",
" query_frame=grid_query_frame);"
]
},
{
"cell_type": "markdown",
"id": "ce0fb5b8-d249-4f4e-b59a-51b4f03972c4",
"metadata": {},
"source": [
"Notice that tracking now starts only from frame 20, where the grid points were sampled. This is different from the first-frame grid in the earlier example:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "f0b01d51-9222-472b-a714-188c38d83ad9",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"./videos/grid_query_20_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"HTML(\"\"\"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"./videos/grid_query_20_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>\"\"\")"
]
},
{
"cell_type": "markdown",
"id": "10baad8f-0cb8-4118-9e69-3fb24575715c",
"metadata": {},
"source": [
"### Tracking forward **and backward** from a given frame"
]
},
{
"cell_type": "markdown",
"id": "8409e2f7-9e4e-4228-b198-56a64e2260a7",
"metadata": {},
"source": [
"CoTracker is an online algorithm and tracks points only in one direction. However, we can also run it backward from the query frame to obtain tracks in both directions: "
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "506233dc-1fb3-4a3c-b9eb-5cbd5df49128",
"metadata": {},
"outputs": [],
"source": [
"grid_size = 30\n",
"grid_query_frame = 20"
]
},
{
"cell_type": "markdown",
"id": "495b5fb4-9050-41fe-be98-d757916d0812",
"metadata": {},
"source": [
"Let's activate backward tracking:"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "677cf34e-6c6a-49e3-a21b-f8a4f718f916",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Video saved to ./videos/grid_query_20_backward_pred_track.mp4\n"
]
}
],
"source": [
"pred_tracks, __ = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame, backward_tracking=True)\n",
"vis.visualize(\n",
" video=video,\n",
" tracks=pred_tracks, \n",
" filename='grid_query_20_backward',\n",
" query_frame=grid_query_frame);"
]
},
{
"cell_type": "markdown",
"id": "585a0afa-2cfc-4a07-a6f0-f65924b9ebce",
"metadata": {},
"source": [
"As you can see, points queried in the middle of the video are now tracked from the very first frame:"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "c8d64ab0-7e92-4238-8e7d-178652fc409c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"./videos/grid_query_20_backward_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"HTML(\"\"\"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"./videos/grid_query_20_backward_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>\"\"\")"
]
},
{
"cell_type": "markdown",
"id": "fb55fb01-6d8e-4e06-9346-8b2e9ef489c2",
"metadata": {},
"source": [
"## Regular grid + Segmentation mask"
]
},
{
"cell_type": "markdown",
"id": "e93a6b0a-b173-46fa-a6d2-1661ae6e6779",
"metadata": {},
"source": [
"Let's now sample points on a grid and filter them with a segmentation mask.\n",
"By restricting sampling to the object, we can track points much more densely while consuming less GPU memory."
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "b759548d-1eda-473e-9c90-99e5d3197e20",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from PIL import Image\n",
"grid_size = 120"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "14ae8a8b-fec7-40d1-b6f2-10e333b75db4",
"metadata": {},
"outputs": [],
"source": [
"input_mask = '../assets/apple_mask.png'\n",
"segm_mask = np.array(Image.open(input_mask))"
]
},
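{
"cell_type": "markdown",
"id": "d4a7f2e5-6ea1-4b54-8f3d-5c9a1e4b8d23",
"metadata": {},
"source": [
"Conceptually, the mask filters the regular grid: a grid point is kept only if the mask is non-zero at its `(x, y)` location on the query frame. A minimal sketch of such filtering (illustrative variable names, not the library's exact implementation):\n",
"```python\n",
"# ys, xs: integer pixel coordinates of the sampled grid points\n",
"keep = segm_mask[ys, xs] > 0\n",
"filtered_points = grid_points[keep]\n",
"```"
]
},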
{
"cell_type": "markdown",
"id": "4e3a1520-64bf-4a0d-b6e9-639430e31940",
"metadata": {},
"source": [
"Here is the segmentation mask for the first frame:"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "4d2efd4e-22df-4833-b9a0-a0763d59ee22",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.image.AxesImage at 0x7ff13dd1c8e0>"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Overlay: scale the mask and the first frame to [0, 1] and multiply them\n",
"plt.imshow(segm_mask[..., None] / 255. * video[0, 0].permute(1, 2, 0).cpu().numpy() / 255.)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "b42dce24-7952-4660-8298-4c362d6913cf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Video saved to ./videos/segm_grid_pred_track.mp4\n"
]
}
],
"source": [
"pred_tracks, __ = model(video, grid_size=grid_size, segm_mask=torch.from_numpy(segm_mask)[None, None])\n",
"vis = Visualizer(\n",
" save_dir='./videos',\n",
" pad_value=100,\n",
" linewidth=2,\n",
")\n",
"vis.visualize(\n",
" video=video,\n",
" tracks=pred_tracks, \n",
" filename='segm_grid');"
]
},
{
"cell_type": "markdown",
"id": "5a386308-0d20-4ba3-bbb9-98ea79823a47",
"metadata": {},
"source": [
"We are now tracking only points on and around the object:"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "1810440f-00f4-488a-a174-36be05949e42",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"./videos/segm_grid_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"HTML(\"\"\"<video width=\"640\" height=\"480\" autoplay loop controls>\n",
" <source src=\"./videos/segm_grid_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>\"\"\")"
]
},
{
"cell_type": "markdown",
"id": "a63e89e4-8890-4e1b-91ec-d5dfa3f93309",
"metadata": {},
"source": [
"## Dense Tracks"
]
},
{
"cell_type": "markdown",
"id": "4ae764d8-db7c-41c2-a712-1876e7b4372d",
"metadata": {},
"source": [
"### Tracking forward **and backward** from a given frame"
]
},
{
"cell_type": "markdown",
"id": "0dde3237-ecad-4c9b-b100-28b1f1b3cbe6",
"metadata": {},
"source": [
"CoTracker can also track **every pixel** in a video in a **dense** manner, but this is much slower than the previous examples. Let's downsample the video to speed it up: "
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "379557d9-80ea-4316-91df-4da215193b41",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 48, 3, 719, 1282])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"video.shape"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "c6db5cc7-351d-4d9e-9b9d-3a40f05b077a",
"metadata": {},
"outputs": [],
"source": [
"import torch.nn.functional as F\n",
"video_interp = F.interpolate(video[0], [100,180], mode=\"bilinear\")[None].cuda()"
]
},
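{
"cell_type": "markdown",
"id": "e5b8a3f6-7fb2-4c65-9a4e-6d0b2f5c9e34",
"metadata": {},
"source": [
"`F.interpolate` expects a 4-D `(N, C, H, W)` input, so we index `video[0]` (letting the frames act as the batch) and restore the batch dimension with `[None]` afterwards. A quick sanity check:\n",
"```python\n",
"assert video_interp.shape[-2:] == torch.Size([100, 180])\n",
"```"
]
},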
{
"cell_type": "markdown",
"id": "7ba32cb3-97dc-46f5-b2bd-b93a094dc819",
"metadata": {},
"source": [
"The video now has a much lower resolution:"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "0918f246-5556-43b8-9f6d-88013d5a487e",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([1, 48, 3, 100, 180])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"video_interp.shape"
]
},
{
"cell_type": "markdown",
"id": "bc7d3a2c-5e87-4c8d-ad10-1f9c6d2ffbed",
"metadata": {},
"source": [
"Again, let's track points in both directions. This will only take a couple of minutes:"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "3b852606-5229-4abd-b166-496d35da1009",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 9/9 [02:07<00:00, 14.18s/it]\n"
]
}
],
"source": [
"pred_tracks, __ = model(video_interp, grid_query_frame=20, backward_tracking=True)\n"
]
},
{
"cell_type": "markdown",
"id": "4143ab14-810e-4e65-93f1-5775957cf4da",
"metadata": {},
"source": [
"Visualization with an optical flow color encoding:"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "5394b0ba-1fc7-4843-91d5-6113a6e86bdf",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Video saved to ./videos/dense_pred_track.mp4\n"
]
}
],
"source": [
"vis = Visualizer(\n",
" save_dir='./videos',\n",
" pad_value=20,\n",
" linewidth=1,\n",
" mode='optical_flow'\n",
")\n",
"vis.visualize(\n",
" video=video_interp,\n",
" tracks=pred_tracks, \n",
" query_frame=grid_query_frame,\n",
" filename='dense');"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "9113c2ac-4d25-4ef2-8951-71a1c1be74dd",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<video width=\"320\" height=\"240\" autoplay loop controls>\n",
" <source src=\"./videos/dense_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"execution_count": 30,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"HTML(\"\"\"<video width=\"320\" height=\"240\" autoplay loop controls>\n",
" <source src=\"./videos/dense_pred_track.mp4\" type=\"video/mp4\">\n",
" </video>\"\"\")"
]
},
{
"cell_type": "markdown",
"id": "95e9bce0-382b-4d18-9316-7f92093ada1d",
"metadata": {},
"source": [
"That's all! You can now use CoTracker in your own projects."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54e0ba0c-b532-46a9-af6f-9508de689dd2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "stereoformer",
"language": "python",
"name": "stereoformer"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}