{ "cells": [ { "cell_type": "markdown", "id": "60a7e08e-93c6-4370-9778-3bb102dce78b", "metadata": {}, "source": [ "Copyright (c) Meta Platforms, Inc. and affiliates." ] }, { "cell_type": "markdown", "id": "3081cd8f-f6f9-4a1a-8c36-8a857b0c3b03", "metadata": {}, "source": [ "\n", " \"Open\n", "" ] }, { "cell_type": "markdown", "id": "f9f3240f-0354-4802-b8b5-9070930fc957", "metadata": {}, "source": [ "# CoTracker: It is Better to Track Together\n", "This is a demo for CoTracker, a model that can track any point in a video." ] }, { "cell_type": "markdown", "id": "36ff1fd0-572e-47fb-8221-1e73ac17cfd1", "metadata": {}, "source": [ "\"Logo\"" ] }, { "cell_type": "markdown", "id": "88c6db31", "metadata": {}, "source": [ "Don't forget to turn on GPU support if you're running this demo in Colab. \n", "\n", "**Runtime** -> **Change runtime type** -> **Hardware accelerator** -> **GPU**\n", "\n", "Let's install dependencies for Colab:" ] }, { "cell_type": "code", "execution_count": null, "id": "278876a7", "metadata": {}, "outputs": [], "source": [ "# !git clone https://github.com/facebookresearch/co-tracker\n", "# %cd co-tracker\n", "# !pip install -e .\n", "# !pip install opencv-python einops timm matplotlib moviepy flow_vis\n", "# !mkdir checkpoints\n", "# %cd checkpoints\n", "# !wget https://huggingface.co/facebook/cotracker/resolve/main/cotracker2.pth" ] }, { "cell_type": "code", "execution_count": 2, "id": "1745a859-71d4-4ec3-8ef3-027cabe786d4", "metadata": { "ExecuteTime": { "end_time": "2024-07-29T20:52:14.487553Z", "start_time": "2024-07-29T20:52:12.423999Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/mnt/d/cotracker\n" ] } ], "source": [ "%cd ..\n", "import os\n", "import torch\n", "\n", "from base64 import b64encode\n", "from cotracker.utils.visualizer import Visualizer, read_video_from_path\n", "from IPython.display import HTML" ] }, { "cell_type": "code", "execution_count": 3, "id": "44342f62abc0ec1e", "metadata": { "ExecuteTime": { "end_time": "2024-07-29T20:52:31.688043Z", "start_time": "2024-07-29T20:52:31.668043Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CUDA available\n" ] } ], "source": [ "if torch.cuda.is_available():\n", " print('CUDA available')" ] }, { "cell_type": "markdown", "id": "7894bd2d-2099-46fa-8286-f0c56298ecd1", "metadata": {}, "source": [ "Read a video from CO3D:" ] }, { "cell_type": "code", "execution_count": 4, "id": "f1f9ca4d-951e-49d2-8844-91f7bcadfecd", "metadata": {}, "outputs": [], "source": [ "video = read_video_from_path('./assets/F1_shorts.mp4')\n", "video = torch.from_numpy(video).permute(0, 3, 1, 2)[None].float()" ] }, { "cell_type": "code", "execution_count": 5, "id": "fb4c2e9d-0e85-4c10-81a2-827d0759bf87", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "def show_video(video_path):\n", " video_file = open(video_path, \"r+b\").read()\n", " video_url = f\"data:video/mp4;base64,{b64encode(video_file).decode()}\"\n", " return HTML(f\"\"\"\"\"\")\n", " \n", "show_video(\"./assets/F1_shorts.mp4\")" ] }, { "cell_type": "markdown", "id": "6f89ae18-54d0-4384-8a79-ca9247f5f31a", "metadata": {}, "source": [ "Import CoTrackerPredictor and create an instance of it. 
{ "cell_type": "markdown", "id": "6f89ae18-54d0-4384-8a79-ca9247f5f31a", "metadata": {}, "source": [ "Import CoTrackerPredictor and create an instance of it.\n", "We'll use this object to estimate tracks:" ] }, { "cell_type": "code", "execution_count": null, "id": "d59ac40b-bde8-46d4-bd57-4ead939f22ca", "metadata": {}, "outputs": [], "source": [ "from cotracker.predictor import CoTrackerPredictor\n", "\n", "model = CoTrackerPredictor(\n", "    checkpoint=os.path.join(\n", "        './checkpoints/cotracker2.pth'\n", "    )\n", ")\n", "\n", "# Move the model and the video to the GPU when one is available;\n", "# the queries we build later are also placed on the GPU.\n", "if torch.cuda.is_available():\n", "    model = model.cuda()\n", "    video = video.cuda()" ] }, { "cell_type": "markdown", "id": "e8398155-6dae-4ff0-95f3-dbb52ac70d20", "metadata": {}, "source": [ "Track points sampled on a regular grid of size 30\*30 on the first frame:" ] }, { "cell_type": "code", "execution_count": null, "id": "17fcaae9-7b3c-474c-977a-cce08a09d580", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video, grid_size=30)" ] }, { "cell_type": "markdown", "id": "50a58521-a9ba-4f8b-be02-cfdaf79613a2", "metadata": {}, "source": [ "Visualize and save the result:" ] }, { "cell_type": "code", "execution_count": null, "id": "7e793ce0-7b77-46ca-a629-155a6a146000", "metadata": {}, "outputs": [], "source": [ "vis = Visualizer(save_dir='./videos', pad_value=100)\n", "vis.visualize(video=video, tracks=pred_tracks, visibility=pred_visibility, filename='teaser');" ] }, { "cell_type": "code", "execution_count": null, "id": "2d0733ba-8fe1-4cd4-b963-2085202fba13", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/teaser.mp4\")" ] }, 
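{ "cell_type": "markdown", "id": "5e7f90b2-1c3d-4e5f-a6b7-0000000000a2", "metadata": {}, "source": [ "The predictor returns two tensors: `pred_tracks` holds the predicted (x, y) location of every point in every frame, and `pred_visibility` flags whether a point is visible in a given frame. For a video with `T` frames and a 30\*30 grid (`N = 900` points):\n", "\n", "```python\n", "print(pred_tracks.shape)      # (1, T, N, 2): x, y coordinates per point per frame\n", "print(pred_visibility.shape)  # (1, T, N): per-point, per-frame visibility\n", "```" ] }, { "cell_type": "markdown", "id":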
"73d88a5f-057c-4b9f-828d-ee0b97d1e72f", "metadata": {}, "source": [ "## Tracking manually selected points" ] }, { "cell_type": "markdown", "id": "a75bca85-b872-4f4e-be19-ff16f0984037", "metadata": { "jp-MarkdownHeadingCollapsed": true, "tags": [] }, "source": [ "We will start by tracking points queried manually.\n", "We define a queried point as: [time, x coord, y coord] \n", "\n", "So, the code below defines points with different x and y coordinates sampled on frames 0, 10, 20, and 30:" ] }, { "cell_type": "code", "execution_count": null, "id": "c6422e7c-8c6f-4269-92c3-245344afe35b", "metadata": {}, "outputs": [], "source": [ "queries = torch.tensor([\n", " [0., 400., 350.], # point tracked from the first frame\n", " [10., 600., 500.], # frame number 10\n", " [20., 750., 600.], # ...\n", " [30., 900., 200.]\n", "])\n", "if torch.cuda.is_available():\n", " queries = queries.cuda()" ] }, { "cell_type": "markdown", "id": "13697a2a-7304-4d18-93be-bfbebf3dc12a", "metadata": {}, "source": [ "That's what our queried points look like:" ] }, { "cell_type": "code", "execution_count": null, "id": "d7141079-d7e0-40b3-b031-a28879c4bd6d", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "# Create a list of frame numbers corresponding to each point\n", "frame_numbers = queries[:,0].int().tolist()\n", "\n", "fig, axs = plt.subplots(2, 2)\n", "axs = axs.flatten()\n", "\n", "for i, (query, frame_number) in enumerate(zip(queries, frame_numbers)):\n", " ax = axs[i]\n", " ax.plot(query[1].item(), query[2].item(), 'ro') \n", " \n", " ax.set_title(\"Frame {}\".format(frame_number))\n", " ax.set_xlim(0, video.shape[4])\n", " ax.set_ylim(0, video.shape[3])\n", " ax.invert_yaxis()\n", " \n", "plt.tight_layout()\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "aec7693b-9d74-48b3-b612-360290ff1e7a", "metadata": {}, "source": [ "We pass these points as input to the model and track them:" ] }, { "cell_type": "code", "execution_count": null, "id": "09008ca9-6a87-494f-8b05-6370cae6a600", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video, queries=queries[None])" ] }, { "cell_type": "markdown", "id": "b00d2a35-3daf-482d-b40b-b6d4f548ca40", "metadata": {}, "source": [ "Finally, we visualize the results with tracks leaving traces from the frame where the tracking starts.\n", "Color encodes time:" ] }, { "cell_type": "code", "execution_count": null, "id": "01467f8d-667c-4f41-b418-93132584c659", "metadata": {}, "outputs": [], "source": [ "vis = Visualizer(\n", " save_dir='./videos',\n", " linewidth=6,\n", " mode='cool',\n", " tracks_leave_trace=-1\n", ")\n", "vis.visualize(\n", " video=video,\n", " tracks=pred_tracks,\n", " visibility=pred_visibility,\n", " filename='queries');" ] }, { "cell_type": "code", "execution_count": null, "id": "fe23d210-ed90-49f1-8311-b7e354c7a9f6", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/queries.mp4\")" ] }, { "cell_type": "markdown", "id": "199ecb50-f707-406a-bc64-94bfded827a8", "metadata": {}, "source": [ "Notice that points queried at frames 10, 20, and 30 are tracked **incorrectly** before the query frame. This is because CoTracker is an online algorithm and only tracks points in one direction. However, we can also run it backward from the queried point to track in both directions. 
Let's correct this:" ] }, { "cell_type": "code", "execution_count": null, "id": "b40775f2-6ab0-4bc6-9099-f903935657c9", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video, queries=queries[None], backward_tracking=True)\n", "vis.visualize(\n", "    video=video,\n", "    tracks=pred_tracks,\n", "    visibility=pred_visibility,\n", "    filename='queries_backward');" ] }, { "cell_type": "code", "execution_count": null, "id": "d3120f31-9365-4867-8c85-5638b0708edc", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/queries_backward.mp4\")" ] }, { "cell_type": "markdown", "id": "87f2a3b4-a8b3-4aeb-87d2-28f056c624ba", "metadata": {}, "source": [ "## Points on a regular grid" ] }, { "cell_type": "markdown", "id": "a9aac679-19f8-4b78-9cc9-d934c6f83b01", "metadata": {}, "source": [ "### Tracking forward from frame number x" ] }, { "cell_type": "markdown", "id": "0aeabca9-cc34-4d0f-8b2d-e6a6f797cb20", "metadata": {}, "source": [ "Let's now sample points on a regular grid and start tracking from frame number 20 with a 30\*30 grid." ] }, { "cell_type": "code", "execution_count": null, "id": "c880f3ca-cf42-4f64-9df6-a0e8de6561dc", "metadata": {}, "outputs": [], "source": [ "grid_size = 30\n", "grid_query_frame = 20" ] }, { "cell_type": "code", "execution_count": null, "id": "3cd58820-7b23-469e-9b6d-5fa81257981f", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame)" ] }, { "cell_type": "code", "execution_count": null, "id": "25a85a1d-dce0-4e6b-9f7a-aaf31ade0600", "metadata": {}, "outputs": [], "source": [ "vis = Visualizer(save_dir='./videos', pad_value=100)\n", "vis.visualize(\n", "    video=video,\n", "    tracks=pred_tracks,\n", "    visibility=pred_visibility,\n", "    filename='grid_query_20',\n", "    query_frame=grid_query_frame);" ] }, 
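{ "cell_type": "markdown", "id": "7a8b9c0d-2e3f-4a5b-8c9d-0000000000a4", "metadata": {}, "source": [ "Under the hood, `grid_size` is shorthand for querying a regular grid of points on `grid_query_frame`. A rough sketch of an equivalent query tensor, in the same `[time, x, y]` format as before (illustrative only: `grid_queries` is our name, and the predictor's own sampling may differ in details such as insetting points from the border):\n", "\n", "```python\n", "B, T, C, H, W = video.shape\n", "ys = torch.linspace(0, H - 1, steps=grid_size)\n", "xs = torch.linspace(0, W - 1, steps=grid_size)\n", "grid_y, grid_x = torch.meshgrid(ys, xs, indexing='ij')\n", "t = torch.full((grid_size * grid_size,), float(grid_query_frame))\n", "grid_queries = torch.stack([t, grid_x.reshape(-1), grid_y.reshape(-1)], dim=1)  # (900, 3)\n", "```\n", "\n", "Passing `queries=grid_queries[None].to(video.device)` should then behave much like `grid_size=30` with `grid_query_frame=20`." ] }, 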
{ "cell_type": "markdown", "id": "ce0fb5b8-d249-4f4e-b59a-51b4f03972c4", "metadata": {}, "source": [ "Note that tracking starts only from points sampled on a frame in the middle of the video.\n", "This is different from the grid in the first example:" ] }, { "cell_type": "code", "execution_count": null, "id": "f0b01d51-9222-472b-a714-188c38d83ad9", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/grid_query_20.mp4\")" ] }, { "cell_type": "markdown", "id": "10baad8f-0cb8-4118-9e69-3fb24575715c", "metadata": {}, "source": [ "### Tracking forward **and backward** from frame number x" ] }, { "cell_type": "code", "execution_count": null, "id": "506233dc-1fb3-4a3c-b9eb-5cbd5df49128", "metadata": {}, "outputs": [], "source": [ "grid_size = 30\n", "grid_query_frame = 20" ] }, { "cell_type": "markdown", "id": "495b5fb4-9050-41fe-be98-d757916d0812", "metadata": {}, "source": [ "Let's activate backward tracking:" ] }, { "cell_type": "code", "execution_count": null, "id": "677cf34e-6c6a-49e3-a21b-f8a4f718f916", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video, grid_size=grid_size, grid_query_frame=grid_query_frame, backward_tracking=True)\n", "vis.visualize(\n", "    video=video,\n", "    tracks=pred_tracks,\n", "    visibility=pred_visibility,\n", "    filename='grid_query_20_backward');" ] }, { "cell_type": "markdown", "id": "585a0afa-2cfc-4a07-a6f0-f65924b9ebce", "metadata": {}, "source": [ "As you can see, points queried in the middle of the video are now tracked from the first frame onward:" ] }, { "cell_type": "code", "execution_count": null, "id": "c8d64ab0-7e92-4238-8e7d-178652fc409c", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/grid_query_20_backward.mp4\")" ] }, { "cell_type": "markdown", "id": "fb55fb01-6d8e-4e06-9346-8b2e9ef489c2", "metadata": {}, "source": [ "## Regular grid + Segmentation mask" ] }, { "cell_type": "markdown", "id": "e93a6b0a-b173-46fa-a6d2-1661ae6e6779", "metadata": {}, "source": [ "Let's now sample points on a grid and filter them with a segmentation mask.\n", "Because points that fall outside the mask are discarded before tracking, we can afford a much denser grid on the object itself without running out of GPU memory."
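, "\n", "\n", "Purely as an illustrative sketch of the filtering idea (the predictor does this internally; `grid_queries` refers to the hypothetical grid sketch above, and `segm_mask` is loaded in the next cells):\n", "\n", "```python\n", "mask = torch.from_numpy(segm_mask).bool()  # (H, W), nonzero pixels mark the object\n", "keep = mask[grid_queries[:, 2].long(), grid_queries[:, 1].long()]  # look up each point: rows by y, columns by x\n", "masked_queries = grid_queries[keep]  # only points that land on the object survive\n", "```"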
] }, { "cell_type": "code", "execution_count": null, "id": "b759548d-1eda-473e-9c90-99e5d3197e20", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from PIL import Image\n", "grid_size = 100" ] }, { "cell_type": "code", "execution_count": null, "id": "14ae8a8b-fec7-40d1-b6f2-10e333b75db4", "metadata": {}, "outputs": [], "source": [ "input_mask = './assets/apple_mask.png'\n", "segm_mask = np.array(Image.open(input_mask))" ] }, { "cell_type": "markdown", "id": "4e3a1520-64bf-4a0d-b6e9-639430e31940", "metadata": {}, "source": [ "That's a segmentation mask for the first frame:" ] }, { "cell_type": "code", "execution_count": null, "id": "4d2efd4e-22df-4833-b9a0-a0763d59ee22", "metadata": {}, "outputs": [], "source": [ "plt.imshow((segm_mask[...,None]/255.*video[0,0].permute(1,2,0).cpu().numpy()/255.))" ] }, { "cell_type": "code", "execution_count": null, "id": "b42dce24-7952-4660-8298-4c362d6913cf", "metadata": {}, "outputs": [], "source": [ "pred_tracks, pred_visibility = model(video, grid_size=grid_size, segm_mask=torch.from_numpy(segm_mask)[None, None])\n", "vis = Visualizer(\n", " save_dir='./videos',\n", " pad_value=100,\n", " linewidth=2,\n", ")\n", "vis.visualize(\n", " video=video,\n", " tracks=pred_tracks,\n", " visibility=pred_visibility,\n", " filename='segm_grid');" ] }, { "cell_type": "markdown", "id": "5a386308-0d20-4ba3-bbb9-98ea79823a47", "metadata": {}, "source": [ "We are now only tracking points on the object (and around):" ] }, { "cell_type": "code", "execution_count": null, "id": "1810440f-00f4-488a-a174-36be05949e42", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/segm_grid.mp4\")" ] }, { "cell_type": "markdown", "id": "a63e89e4-8890-4e1b-91ec-d5dfa3f93309", "metadata": {}, "source": [ "## Dense Tracks" ] }, { "cell_type": "markdown", "id": "4ae764d8-db7c-41c2-a712-1876e7b4372d", "metadata": {}, "source": [ "### Tracking forward **and backward** from the frame number x" ] }, { "cell_type": "markdown", "id": "0dde3237-ecad-4c9b-b100-28b1f1b3cbe6", "metadata": {}, "source": [ "CoTracker also has a mode to track **every pixel** in a video in a **dense** manner but it is much slower than in previous examples. Let's downsample the video in order to make it faster: " ] }, { "cell_type": "code", "execution_count": null, "id": "379557d9-80ea-4316-91df-4da215193b41", "metadata": {}, "outputs": [], "source": [ "video.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "c6db5cc7-351d-4d9e-9b9d-3a40f05b077a", "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "video_interp = F.interpolate(video[0], [200, 360], mode=\"bilinear\")[None]" ] }, { "cell_type": "markdown", "id": "7ba32cb3-97dc-46f5-b2bd-b93a094dc819", "metadata": {}, "source": [ "The video now has a much lower resolution:" ] }, { "cell_type": "code", "execution_count": null, "id": "0918f246-5556-43b8-9f6d-88013d5a487e", "metadata": {}, "outputs": [], "source": [ "video_interp.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "e4451ae5-c132-4ef2-9329-133b31305db7", "metadata": {}, "outputs": [], "source": [ "grid_query_frame=20" ] }, { "cell_type": "markdown", "id": "bc7d3a2c-5e87-4c8d-ad10-1f9c6d2ffbed", "metadata": {}, "source": [ "Again, let's track points in both directions. 
{ "cell_type": "markdown", "id": "bc7d3a2c-5e87-4c8d-ad10-1f9c6d2ffbed", "metadata": {}, "source": [ "Again, let's track points in both directions.\n", "This will only take a couple of minutes:" ] }, { "cell_type": "code", "execution_count": null, "id": "3b852606-5229-4abd-b166-496d35da1009", "metadata": {}, "outputs": [], "source": [ "import time\n", "start_time = time.time()\n", "pred_tracks, pred_visibility = model(video_interp, grid_query_frame=grid_query_frame, backward_tracking=True)\n", "end_time = time.time()\n", "print(\"Time taken: {:.2f} seconds\".format(end_time - start_time))" ] }, { "cell_type": "markdown", "id": "4143ab14-810e-4e65-93f1-5775957cf4da", "metadata": {}, "source": [ "Visualization with an optical-flow color encoding:" ] }, { "cell_type": "code", "execution_count": null, "id": "5394b0ba-1fc7-4843-91d5-6113a6e86bdf", "metadata": {}, "outputs": [], "source": [ "vis = Visualizer(\n", "    save_dir='./videos',\n", "    pad_value=20,\n", "    linewidth=1,\n", "    mode='optical_flow'\n", ")\n", "vis.visualize(\n", "    video=video_interp,\n", "    tracks=pred_tracks,\n", "    visibility=pred_visibility,\n", "    filename='dense');" ] }, { "cell_type": "code", "execution_count": null, "id": "9113c2ac-4d25-4ef2-8951-71a1c1be74dd", "metadata": {}, "outputs": [], "source": [ "show_video(\"./videos/dense.mp4\")" ] }, { "cell_type": "markdown", "id": "95e9bce0-382b-4d18-9316-7f92093ada1d", "metadata": {}, "source": [ "That's all! You can now use CoTracker in your own projects." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" }, "vscode": { "interpreter": { "hash": "5c3af07c9b62e9e5a46520c50c603ac08cca123627f098bd5ae6d6af0e7d15f9" } } }, "nbformat": 4, "nbformat_minor": 5 }