├── .github
│   └── workflows
│       ├── docker-build-test.yml
│       ├── pylint.yml
│       └── pytest.yml
├── .gitignore
├── CONTRIBUTING.md
├── Dockerfile
├── README.md
├── docs
│   ├── install_cli.md
│   ├── install_macos.md
│   ├── install_windows.md
│   └── user_guide.md
├── gradio_app.py
├── img
│   ├── browser_gradio_interface.PNG
│   ├── contrail.png
│   ├── debug.png
│   ├── docker_desktop_images.PNG
│   ├── docker_desktop_optional_settings.PNG
│   ├── docker_desktop_search.PNG
│   ├── dot.png
│   ├── fully_connected.png
│   ├── fully_connected_neon.png
│   ├── highlight_line.png
│   ├── lagging_dot.png
│   ├── line.png
│   ├── line_constant_acc.png
│   ├── line_constant_vel.png
│   └── neon_line.png
├── model_training
│   ├── Dockerfile
│   ├── train.py
│   └── training_start.sh
├── project_goals.md
├── start
│   ├── start_linux.sh
│   └── start_linux_cpu.sh
├── test
│   ├── __init__.py
│   ├── test_detection.py
│   ├── test_draw.py
│   ├── test_effects.py
│   ├── test_matching.py
│   └── test_pipeline.py
├── test_assets
│   ├── ball_on_tarmac.jpg
│   ├── ball_on_tarmac.mp4
│   └── stormy.mp4
└── traccc
    ├── detect.py
    ├── detectors.py
    ├── draw.py
    ├── draw_detections.py
    ├── effects.py
    ├── filters.py
    ├── track.py
    └── trackers.py
/.github/workflows/docker-build-test.yml: -------------------------------------------------------------------------------- 1 | name: Docker build 2 | 3 | on: [push] 4 | jobs: 5 | build: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v2 9 | - name: Build Docker image 10 | run: docker build -t balltracking . 11 | -------------------------------------------------------------------------------- /.github/workflows/pylint.yml: -------------------------------------------------------------------------------- 1 | name: Pylint 2 | 3 | on: [push] 4 | 5 | jobs: 6 | build: 7 | runs-on: ubuntu-latest 8 | steps: 9 | - uses: actions/checkout@v3 10 | - name: Build Docker image 11 | run: docker build -t balltracking . 12 | - name: Analysing the code with pylint inside docker 13 | run: docker run -v $(pwd):/balltracking --ipc host balltracking pylint --fail-under=5 $(git ls-files '*.py') 14 | -------------------------------------------------------------------------------- /.github/workflows/pytest.yml: -------------------------------------------------------------------------------- 1 | name: Docker pull and Pytest 2 | 3 | on: [push] 4 | jobs: 5 | build: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - uses: actions/checkout@v2 9 | - name: Pull docker image from the hub 10 | run: docker pull sinclairhudson/balltracking 11 | - name: Run tests inside the container 12 | run: docker run -v $(pwd):/balltracking --ipc host sinclairhudson/balltracking pytest 13 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/* 2 | test/__pycache__/* 3 | io/* 4 | internal/* 5 | *.mp4 6 | traccc/__pycache__/* 7 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## 📝 Contributing 📝 2 | 3 | This project uses Docker, both for running the app itself and for development.
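A typical development loop (sketched here under the assumption that you build the image locally under the tag `balltracking`, as the CI workflows above do) is to mount the whole repository into the container, the way the scripts in `start/` do, so code edits on the host are picked up immediately:

```bash
docker build -t balltracking .
docker run -v $(pwd):/balltracking -p 7860:7860 -it --ipc host balltracking
```

The tests run in the same environment, mirroring the pytest workflow: `docker run -v $(pwd):/balltracking --ipc host balltracking pytest`.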
4 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-runtime 2 | ARG DEBIAN_FRONTEND=noninteractive 3 | 4 | RUN apt-get update 5 | RUN apt-get install ffmpeg libsm6 libxext6 -y 6 | 7 | RUN pip install filterpy transformers timm sk-video opencv-python pylint 8 | RUN pip install pytest 9 | RUN conda install -c conda-forge gradio # both installs are needed for the docker build to run 10 | RUN pip3 install gradio==3.40.0 11 | RUN echo 'alias py3="python3"' >> ~/.bashrc 12 | RUN echo 'alias python="python3"' >> ~/.bashrc 13 | 14 | RUN mkdir -p /balltracking/io 15 | RUN mkdir -p /balltracking/internal 16 | 17 | WORKDIR /balltracking 18 | ADD *.py /balltracking/ 19 | CMD ["python3", "gradio_app.py"] 20 | EXPOSE 7860 21 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Traccc 2 | 3 | A tool to track sports balls, and add interesting visual effects. 4 | 5 | ![build and test badge](https://github.com/SinclairHudson/traccc/actions/workflows/docker-build-test.yml/badge.svg) 6 | ![pytest tests](https://github.com/SinclairHudson/traccc/actions/workflows/pytest.yml/badge.svg) 7 | ![linting](https://github.com/SinclairHudson/traccc/actions/workflows/pylint.yml/badge.svg) 8 | 9 | ||| 10 | |---|---| 11 | |![debug output of the pipeline](img/debug.png)|![sample contrail output](img/contrail.png)| 12 | 13 | It's very fun to apply to juggling 🤹 and sports 🏐🏀 videos. 14 | 15 | 16 | ## 🔥 Application Quickstart 🔥 17 | 18 | 19 | For Windows users, see the [Windows Installation Guide](docs/install_windows.md). 20 | For MacOS users, see the [MacOS Installation Guide](docs/install_macos.md). For Linux and other command-line users, there is also a [CLI Installation Guide](docs/install_cli.md). 21 | 22 | After installation, you can refer to the [User Guide](docs/user_guide.md) for an 23 | overview of all that the software has to offer. 24 | 25 | Running this project is fairly demanding. 26 | A GPU is very helpful for running the neural networks for ball detection quickly, but it isn't required. 27 | At minimum: **10GB** disk space. 28 | 29 | --- 30 | 31 | ## 📝 Contributing 📝 32 | 33 | Contributions are welcome! Open an issue or a pull request, and I'll get to it when I can. 34 | Please see [The Contributing Guidelines](CONTRIBUTING.md) before making a pull request; it will also help you get started with development. 35 | Adding new effects is an easy contribution to make, and a good place to start. 36 | -------------------------------------------------------------------------------- /docs/install_cli.md: -------------------------------------------------------------------------------- 1 | # Command Line Interface (CLI) installation guide 2 | 3 | This guide is for advanced users who are comfortable in the command prompt 4 | of their respective operating system. If that's not you, check out the installation 5 | guides for [MacOS](install_macos.md) or [Windows](install_windows.md). 6 | 7 | This guide uses Docker in the command line only, skipping the GUI of Docker Desktop. 8 | 9 | 1. Install `docker` 10 | 11 | Here is the [official installation guide](https://docs.docker.com/engine/install/), 12 | and there are plenty of tutorials online as well. 13 | 14 | 15 | 2. Pull the docker image 16 | ``` 17 | docker pull sinclairhudson/balltracking:latest 18 | ``` 19 | 20 | 3.
Run the docker image with special options 21 | 22 | ``` 23 | docker run -v VIDEO_IO:/balltracking/io -p 7860:7860 -it --ipc host sinclairhudson/balltracking 24 | ``` 25 | 26 | Replace `VIDEO_IO` with the folder you'd like to use for input and output. 27 | 28 | Optionally, if your Docker installation can access your NVIDIA GPU, then this software can make 29 | use of it for running neural networks. This will result in much faster detections 30 | and a much better user experience. 31 | Instead of the above command, run this very similar one: 32 | 33 | ``` 34 | docker run -v VIDEO_IO:/balltracking/io -p 7860:7860 -it --ipc host --gpus all sinclairhudson/balltracking 35 | ``` 36 | 37 | 4. Go to `localhost:7860` in your web browser. 38 | You should see a GUI like the one in the image below: 39 | ![GUI](../img/browser_gradio_interface.PNG) 40 | -------------------------------------------------------------------------------- /docs/install_macos.md: -------------------------------------------------------------------------------- 1 | # MacOS installation guide 2 | 3 | 1. Install [Docker Desktop](https://docs.docker.com/desktop/install/mac-install/) 4 | 5 | Docker is an application that allows virtualization, sort of like a virtual machine. 6 | This tracking application only runs natively on the Linux operating system, but Docker allows it 7 | to run on other operating systems (like Windows and MacOS). 8 | 9 | This is the hardest step; installing Docker can be a pain. 10 | 11 | --- 12 | 13 | 2. Pull the docker image 14 | Use the Docker Desktop app to download the docker image `sinclairhudson/balltracking:latest`. 15 | This is a version of the app posted to Docker Hub. 16 | Search for the image in the top search bar of Docker Desktop, and "pull" (download) it from Docker Hub. 17 | ![docker pull in docker desktop](../img/docker_desktop_search.PNG) 18 | 19 | You should see the image appear in the "Images" tab of Docker Desktop: 20 | 21 | ![docker image pulled](../img/docker_desktop_images.PNG) 22 | 23 | --- 24 | 25 | 3. Run the docker image with options 26 | There are two things that need to be done: port access and file system access for inputs and outputs. 27 | 28 | ![docker run with options](../img/docker_desktop_optional_settings.PNG) 29 | 30 | --- 31 | 32 | 4. Verify 33 | Wait a few seconds after clicking "run", and then open your browser and go to `localhost:7860`. 34 | If everything is working, then you should see the GUI in the browser. 35 | 36 | ![GUI](../img/browser_gradio_interface.PNG) 37 | 38 | --- 39 | 40 | Head over to the [User Guide](user_guide.md) for guidance on how to use the app! 41 | 42 | If you're not seeing the GUI, check that you ran the docker image with the correct options. 43 | When it's running, you can also check the "logs" tab in Docker Desktop to see if it's giving an 44 | error message. If you think there's something wrong, please open an issue on the GitHub page. 45 | -------------------------------------------------------------------------------- /docs/install_windows.md: -------------------------------------------------------------------------------- 1 | # Windows Installation Guide 2 | 3 | 1. Install [Docker Desktop](https://docs.docker.com/desktop/install/windows-install/) 4 | 5 | This is the hardest step; installing Docker can be a pain. 6 | 7 | Docker is an application that allows virtualization, sort of like a virtual machine. 8 | This tracking application only runs natively on the Linux operating system, but Docker allows it 9 | to run on other operating systems (like Windows and MacOS).
It's possible that after 10 | installing the application, it will ask you to change some settings in your BIOS. 11 | This may seem daunting but it's not that hard. Unfortunately, the process is motherboard-specific, 12 | so it's hard to write an in-depth guide. 13 | 14 | --- 15 | 16 | 2. Pull the docker image 17 | Use the Docker Desktop app to download the docker image `sinclairhudson/balltracking:latest`. 18 | This is a version of the app posted to Docker Hub. 19 | Search for the image in the top search bar of Docker Desktop, and "pull" (download) it from Docker Hub. 20 | ![docker pull in docker desktop](../img/docker_desktop_search.PNG) 21 | 22 | You should see the image appear in the "Images" tab of Docker Desktop: 23 | 24 | ![docker image pulled](../img/docker_desktop_images.PNG) 25 | 26 | --- 27 | 28 | 3. Run the docker image with options 29 | There are two things that need to be done: port access and file system access for inputs and outputs. 30 | 31 | ![docker run with options](../img/docker_desktop_optional_settings.PNG) 32 | 33 | --- 34 | 35 | 4. Verify 36 | Wait a few seconds after clicking "run", and then open your browser and go to `localhost:7860`. 37 | If everything is working, then you should see the GUI in the browser. 38 | 39 | ![GUI](../img/browser_gradio_interface.PNG) 40 | 41 | --- 42 | 43 | Head over to the [User Guide](user_guide.md) for guidance on how to use the app! 44 | 45 | If you're not seeing the GUI, check that you ran the docker image with the correct options. 46 | When it's running, you can also check the "logs" tab in Docker Desktop to see if it's giving an 47 | error message. If you think there's something wrong, please open an issue on the GitHub page. 48 | -------------------------------------------------------------------------------- /docs/user_guide.md: -------------------------------------------------------------------------------- 1 | # Traccc User Guide 2 | 3 | This is the user guide for the ball tracking software, usually used via the 4 | GUI in the browser. The normal workflow is to go through the 3 stages (Detect, Track, Draw) sequentially 5 | with a single video. The __Project Name__ must remain the same for all three stages. 6 | The intermediate data is saved between each step. This means that if you don't like the effect you chose, 7 | you can go back and change the effect and re-draw the video. Similarly, if you feel the need to adjust 8 | the tracking parameters, you can go back to that step, re-track all the balls, and then re-draw your effect 9 | with the new output. 10 | 11 | ## Detection 12 | 13 | Currently, there are two supported detectors: 14 | |Name|Speed|Accuracy| 15 | |---|---|---| 16 | |DETR|1/5|4/5| 17 | |RN50|3/5|2/5, with quite a few false positives| 18 | 19 | Since detection only needs to be run once per video, I would recommend DETR for its 20 | higher accuracy, even though it'll take a few more minutes compared to RN50. 21 | 22 | ## Tracking 23 | 24 | There are two options for ball trackers: **Constant Velocity** and **Constant Acceleration**. 25 | The Constant Velocity tracker generally outputs smoother trajectories, but it's less accurate overall. 26 | The Constant Acceleration tracker follows objects more closely, but at times extrapolates too far and makes sharp corrections.
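Under the hood, both options are Kalman filters (see `traccc/trackers.py`); they differ in the motion model they assume between frames. Below is a minimal sketch of that difference for a single image axis with a one-frame time step; it is illustrative only, not the exact state layout the code uses:

```python
import numpy as np

dt = 1.0  # one video frame

# Constant Velocity: state is [position, velocity].
# Velocity is assumed fixed, so the filter resists sudden
# direction changes and produces smoother trajectories.
F_cv = np.array([[1.0, dt],
                 [0.0, 1.0]])

# Constant Acceleration: state is [position, velocity, acceleration].
# Velocity itself evolves, so the filter can follow a ball under
# gravity more closely, but it also extrapolates harder.
F_ca = np.array([[1.0, dt, 0.5 * dt ** 2],
                 [0.0, 1.0, dt],
                 [0.0, 0.0, 1.0]])

state = np.array([0.0, 4.0])  # at pixel 0, moving 4 px/frame
print(F_cv @ state)           # [4. 4.]: predicted state one frame later
```

That extra acceleration term is why the Constant Acceleration tracker hugs the detections more tightly, but it is also what lets it overshoot when a ball abruptly changes direction.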
27 | 28 | Here's an example that illustrates the difference between the two: 29 | 30 | |Constant Velocity|Constant Acceleration| 31 | |---|---| 32 | |![constant_vel](../img/line_constant_vel.png)|![constant_acc](../img/line_constant_acc.png)| 33 | 34 | Using constant velocity is recommended, though if the system is _losing track of balls_ 35 | then you might want to try constant acceleration. 36 | 37 | ## Effects 38 | 39 | Here are all the effects listed. Note that not all effects take all arguments. 40 | For example, the effect `fully_connected` doesn't change when the "length" attribute 41 | is changed, because it doesn't have a fixed length. 42 | 43 | |Effect Name|Example|Speed| 44 | |---|---|---| 45 | |dot|![dot](../img/dot.png)|5/5| 46 | |lagging_dot|![lagging_dot](../img/lagging_dot.png)|5/5| 47 | |line|![line](../img/line.png)|5/5| 48 | |highlight_line|![highlight_line](../img/highlight_line.png)|2/5| 49 | |neon_line|![neon_line](../img/neon_line.png)|2/5| 50 | |contrail|![contrail](../img/contrail.png)|1/5| 51 | |fully_connected|![fully_connected](../img/fully_connected.png)|4/5| 52 | |fully_connected_neon|![fully_connected_neon](../img/fully_connected_neon.png)|4/5| 53 | |debug|![debug](../img/debug.png)|2/5| 54 | 55 | ## Best Practices 56 | 57 | 1. Use a **high framerate**. Most cameras, even phone cameras, can shoot 60 frames per second now; go as high as you can. 58 | Every frame is additional information for the tracker, so it's better able to track the balls. 59 | 2. Use a **fast shutter speed**. This is usually implied by a high framerate, but it's technically a separate setting. 60 | A fast shutter speed will **reduce motion blur**, making the balls in every frame look more like balls 61 | and thus making them easier to detect. 62 | 3. Expose for the things you want to track. If the exposure of the camera is too high or too low, the objects will lose detail 63 | and they'll be almost impossible to detect, because they won't look like much of anything. Even lighting is best, 64 | so the camera doesn't have to expose for the highlights or the shadows. 65 | 4. Use a **small file to test ideas**. Some of the effects and detections are very slow to run, 66 | especially without a GPU. It's best to test and tweak an effect on a short clip, iterate fast, and then 67 | apply it to a longer clip once you're sure you have what you want. 68 | -------------------------------------------------------------------------------- /gradio_app.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | import os 3 | from traccc.draw import run_draw 4 | from traccc.track import run_track 5 | from traccc.detect import run_detect 6 | 7 | def sanitize_run_detect(project_name: str, model_select: str, input_file: str, 8 | progress=gr.Progress(track_tqdm=True)): 9 | if not os.path.exists("io/" + input_file): 10 | raise gr.Error(f"Input file '{input_file}' does not exist. Is the file" + \ 11 | " in the specified io folder? Is the folder mounted correctly?") 12 | 13 | return run_detect(project_name, model_select, "io/" + input_file) 14 | 15 | def sanitize_run_track(name: str, track_type: str, death_time: int, iou_threshold: float, conf_threshold: float, max_cost: float): 16 | if not os.path.exists(f"internal/{name}.npz"): 17 | raise gr.Error(f"Couldn't find detections for this project.
Is the project name" + \ 18 | " correct?") 19 | 20 | return run_track(name, track_type, death_time, iou_threshold, conf_threshold, max_cost) 21 | 22 | def sanitize_run_draw(name: str, input_video: str, output: str, effect_name: str, 23 | colour: str, size: float, length: int, min_age: int, progress=gr.Progress(track_tqdm=True)): 24 | 25 | if not os.path.exists(f"internal/{name}.yaml"): 26 | raise gr.Error(f"Couldn't find tracks for this project. Is the project name" + \ 27 | " correct?") 28 | 29 | if not os.path.exists("io/" + input_video): 30 | raise gr.Error(f"Couldn't find input video '{input_video}'. Is the input video path correct?") 31 | 32 | return run_draw(name, "io/" + input_video, "io/" + output, effect_name, colour, size, length, min_age) 33 | 34 | with gr.Blocks() as demo: 35 | gr.Markdown("Create cool ball tracking videos with this one simple trick!") 36 | with gr.Tab("Detect"): 37 | text_input = gr.Textbox(placeholder="fireball", label="Project Name", 38 | info="The name of the clip being processed. Remember \ 39 | this name and make it unique, because it's used in the \ 40 | next two steps as well. Using the same name will \ 41 | overwrite previous data!") 42 | # video_upload = gr.inputs.Video(label="Video File") 43 | input_file = gr.Textbox( 44 | placeholder="fireball.mp4", label="Input File") 45 | model_select = gr.components.Radio(["DETR", "RN50"], label="Model") 46 | detect_button = gr.Button("Detect", variant="primary") 47 | debug_textbox = gr.Textbox(label="Output") 48 | detect_button.click(sanitize_run_detect, inputs=[ 49 | text_input, model_select, input_file], outputs=[debug_textbox]) 50 | 51 | with gr.Tab("Track"): 52 | track_name_input = gr.Textbox( 53 | placeholder="fireball", label="Project Name") 54 | track_type_input = gr.components.Radio( 55 | ["Constant Acceleration", "Constant Velocity"], label="Track Type", value="Constant Acceleration") 56 | death_time = gr.Slider(label="Death Time", minimum=1, 57 | maximum=20, value=5, interactive=True, step=1) 58 | iou_threshold = gr.Slider( 59 | label="IoU Threshold", minimum=0.01, maximum=1, value=0.20, interactive=True) 60 | confidence_treshold = gr.Slider( 61 | label="Confidence Threshold", minimum=0, maximum=1, value=0.05, interactive=True) 62 | max_cost = gr.Slider(label="Maximum Matching Cost", 63 | minimum=0, maximum=1000, value=200, interactive=True) 64 | track_button = gr.Button("Track", variant="primary") 65 | track_debug_textbox = gr.Textbox(label="Output") 66 | track_button.click(sanitize_run_track, inputs=[track_name_input, track_type_input, death_time, 67 | iou_threshold, confidence_treshold, max_cost], outputs=[track_debug_textbox]) 68 | 69 | with gr.Tab("Draw"): 70 | draw_name = gr.Textbox(placeholder="fireball", label="Project Name") 71 | draw_input_file = gr.Textbox(placeholder="fireball.mp4", label="Input File", info="The input file \ 72 | should be the same as the one used in the Detect step.") 73 | output_file = gr.Textbox(placeholder="fireball_with_effect.mp4", label="Output File", 74 | info="The video file to be created") 75 | effect_name = gr.components.Radio(["dot", "lagging_dot", 76 | "line", "highlight_line", "neon_line", 77 | "contrail", "fully_connected", 78 | "fully_connected_neon", "debug"], label="Effect") 79 | colour = gr.ColorPicker(label="Colour", value="#ff0000") 80 | size = gr.Slider(label="size", info="size of the effect, proportional to \ 81 | the width of the object being tracked.", 82 | minimum=0, maximum=20, value=1, interactive=True) 83 | length = gr.Slider(label="length", 
info="length of the effect in frames", 84 | minimum=1, maximum=50, value=7, interactive=True, step=1) 85 | min_age = gr.Slider(label="minimum age", info="minimum age requirement (in frames) for \ 86 | a track to be visualized. Increasing this value will remove tracks that \ 87 | are short-lived, possibly false-positives.", 88 | minimum=1, maximum=50, value=7, interactive=True, step=1) 89 | 90 | draw_button = gr.Button("Draw Effect", variant="primary") 91 | 92 | draw_debug_textbox = gr.Textbox(label="Output") 93 | draw_button.click(sanitize_run_draw, inputs=[draw_name, draw_input_file, output_file, effect_name, colour, size, length, min_age], 94 | outputs=draw_debug_textbox) 95 | 96 | demo.queue().launch(server_name="0.0.0.0") 97 | -------------------------------------------------------------------------------- /img/browser_gradio_interface.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/browser_gradio_interface.PNG -------------------------------------------------------------------------------- /img/contrail.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/contrail.png -------------------------------------------------------------------------------- /img/debug.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/debug.png -------------------------------------------------------------------------------- /img/docker_desktop_images.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/docker_desktop_images.PNG -------------------------------------------------------------------------------- /img/docker_desktop_optional_settings.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/docker_desktop_optional_settings.PNG -------------------------------------------------------------------------------- /img/docker_desktop_search.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/docker_desktop_search.PNG -------------------------------------------------------------------------------- /img/dot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/dot.png -------------------------------------------------------------------------------- /img/fully_connected.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/fully_connected.png -------------------------------------------------------------------------------- /img/fully_connected_neon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/fully_connected_neon.png 
-------------------------------------------------------------------------------- /img/highlight_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/highlight_line.png -------------------------------------------------------------------------------- /img/lagging_dot.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/lagging_dot.png -------------------------------------------------------------------------------- /img/line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/line.png -------------------------------------------------------------------------------- /img/line_constant_acc.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/line_constant_acc.png -------------------------------------------------------------------------------- /img/line_constant_vel.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/line_constant_vel.png -------------------------------------------------------------------------------- /img/neon_line.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/img/neon_line.png -------------------------------------------------------------------------------- /model_training/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:21.12-py3 2 | 3 | WORKDIR /workspace 4 | 5 | RUN mkdir /datasets 6 | 7 | RUN pip install filterpy 8 | 9 | -------------------------------------------------------------------------------- /model_training/train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | -------------------------------------------------------------------------------- /model_training/training_start.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | dir=$(dirname $(realpath -s $0)) 4 | 5 | docker run -v "$dir:/workspace" -v "$dataset_root:/datasets" -it --gpus all ball-tracking /bin/bash 6 | -------------------------------------------------------------------------------- /project_goals.md: -------------------------------------------------------------------------------- 1 | # Project Goals 2 | 3 | This project aims to create an application that tracks sports balls (tennis, basketball, volleyball, etc) through the air, 4 | using some object detector trained on the COCO dataset, and kalman filters to track. Additionally, this application will 5 | be able to add different effects to each track, to visualize them in a video format. 
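The application will be split into decoupled stages (see Structure below); to make the hand-off between them concrete, here is a minimal sketch of how per-frame detections can be persisted and reloaded so the detector only ever runs once. Note that the shipped code (`traccc/detect.py`) ended up persisting detections with NumPy's `.npz` format rather than the YAML proposed below, and the `[conf, cx, cy, w, h]` row layout is the one the detectors emit:

```python
import numpy as np

# One (N, 5) array of [conf, cx, cy, w, h] rows per video frame,
# written to disk once so later stages never re-run the detector.
per_frame_detections = [
    np.array([[0.90, 120.0, 80.0, 20.0, 20.0]]),  # frame 0: one ball
    np.array([[0.85, 125.0, 84.0, 20.0, 20.0]]),  # frame 1
]
np.savez("example_detections.npz", *per_frame_detections)

# A later run (the tracking stage) reloads the same arrays, in order.
archive = np.load("example_detections.npz")
reloaded = [archive[key] for key in archive.files]
assert all(np.array_equal(a, b)
           for a, b in zip(per_frame_detections, reloaded))
```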
6 | 7 | ## Structure 8 | 9 | There are 4 parts to the project: 10 | * Model training on the COCO dataset 11 | * Model inference on individual frames of an input video 12 | * Tracking objects based on detections 13 | * Adding effects based on tracks 14 | 15 | The 4 parts should be as de-coupled as possible. Specifically, we want to save 16 | computation and only run the object detections once for any given video. 17 | 18 | ## Technology and frameworks 19 | 20 | The whole project will be dockerized with cuda support (and cpu support), to make 21 | the application easy to run for users as well as developers. 22 | 23 | For anything machine learning, PyTorch 24 | For most arrays, numpy 25 | For Kalman filters, pykalman 26 | For effects, image and video I/O, OpenCV 27 | For communication between different parts of the application, YAML files. 28 | 29 | 30 | ## Specifications 31 | 32 | ### Model training 33 | The model training component just needs to produce some sort of object detection 34 | model, trained exclusively to detect the "sports ball" class in the COCO dataset. 35 | It won't be part of the app itself, but should operate in the same docker container. 36 | 37 | ### Model Inference 38 | Input: User video, in some standard IO directory 39 | Output: YAML file containing all the detections of the entire video. 40 | Probably some dictionary where every key is a frame and the entries are lists 41 | of coordinates in the image space. 42 | Example of running inference should be something like 43 | 44 | ``` 45 | python3 detect.py --model model.pth --video input.mp4 --name unique 46 | ``` 47 | 48 | ### Object tracking 49 | Input: yaml file generated from the model inference section (detections) 50 | Output: YAML of tracks associated with the detections 51 | Every track has: 52 | * Unique ID 53 | * starting frame (frame at which it appears first) 54 | * list of positions in the image space 55 | 56 | ``` 57 | python3 track.py --detections unique_detections.yaml 58 | ``` 59 | 60 | ### Effects 61 | Input: A video file and the tracks associated with them. The video file name should 62 | be in the yaml file already though. 63 | Also need some kind of user input to specify which effects to apply. 64 | An effect is a function that takes in a track and a reference to the video file and 65 | modifies the video file, applying an effect to the track. 66 | Output: video file with all the effects on each of the tracks, in IO directory 67 | 68 | ``` 69 | python3 add_effects.py --effect sparkles --tracks unique_tracks.yaml 70 | ``` 71 | 72 | 73 | -------------------------------------------------------------------------------- /start/start_linux.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | current=`pwd` 5 | # for development, mount the whole repo (as below); when just using the app, mount only an io folder (last line).
6 | #docker run -v $current:/balltracking -p 7860:7860 -it --ipc host --gpus all balltracking 7 | docker run -v $current:/balltracking -p 7860:7860 -it --ipc host --gpus all balltracking 8 | #docker run -v ~/balltracking_io:/balltracking/io -p 7860:7860 -it --ipc host --gpus all balltracking 9 | -------------------------------------------------------------------------------- /start/start_linux_cpu.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | current=`pwd` 5 | docker run -v $current:/balltracking -it -p 7860:7860 --ipc host balltracking 6 | -------------------------------------------------------------------------------- /test/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/test/__init__.py -------------------------------------------------------------------------------- /test/test_detection.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import skvideo.io 3 | 4 | from traccc.detectors import HuggingFaceDETR, PretrainedRN50Detector 5 | 6 | 7 | @pytest.mark.parametrize("DetectorClass", [HuggingFaceDETR, PretrainedRN50Detector]) 8 | def test_detector_single_ball_tarmac(DetectorClass): 9 | """ 10 | Test that the detector can detect a single ball in a test video, in every single frame. 11 | Duplicate detections are acceptable because NMS is not being run. 12 | """ 13 | model = DetectorClass() 14 | # this has 120 frames 15 | vid_generator = skvideo.io.vreader(f"test_assets/ball_on_tarmac.mp4") 16 | metadata = skvideo.io.ffprobe(f"test_assets/ball_on_tarmac.mp4") 17 | frame_count = int(metadata['video']['@nb_frames']) 18 | result = model.detect_video(vid_generator, frame_count=frame_count) 19 | assert len(result) == frame_count 20 | for frame_detections in result: 21 | # there must be a ball detected. there could be multiple boxes because no NMS 22 | assert len(frame_detections) >= 1 23 | -------------------------------------------------------------------------------- /test/test_draw.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import skvideo.io 3 | 4 | from traccc.detectors import HuggingFaceDETR, PretrainedRN50Detector 5 | 6 | # TODO integration tests, whole pipeline tests here. 7 | @pytest.mark.parametrize("DetectorClass", [HuggingFaceDETR, PretrainedRN50Detector]) 8 | def test_detector_single_ball_tarmac(DetectorClass): 9 | model = DetectorClass() 10 | # this has 120 frames 11 | vid_generator = skvideo.io.vreader(f"test_assets/ball_on_tarmac.mp4") 12 | metadata = skvideo.io.ffprobe(f"test_assets/ball_on_tarmac.mp4") 13 | frame_count = int(metadata['video']['@nb_frames']) 14 | result = model.detect_video(vid_generator, frame_count=frame_count) 15 | assert len(result) == frame_count 16 | for frame_detections in result: 17 | # there must be a ball detected. 
there could be multiple boxes because no NMS 18 | assert len(frame_detections) >= 1 19 | -------------------------------------------------------------------------------- /test/test_effects.py: -------------------------------------------------------------------------------- 1 | from traccc.effects import * 2 | import os 3 | import pytest 4 | 5 | def get_stormy(): 6 | if not os.path.exists("internal/stormy.npz"): 7 | os.system("python3 detect.py stormy --input test_assets/stormy.mp4") 8 | os.system("python3 track.py stormy") 9 | 10 | def get_stormy_vid_generator(): 11 | pass 12 | 13 | def test_fully_connected(): 14 | get_stormy() 15 | 16 | @pytest.mark.parametrize("effect_class", [FullyConnected, 17 | FullyConnectedNeon, 18 | Dot, 19 | LaggingDot, 20 | Line, 21 | Debug, 22 | HighlightLine, 23 | Contrail]) 24 | def test_effect(effect_class): 25 | pass 26 | 27 | -------------------------------------------------------------------------------- /test/test_matching.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | import pytest 3 | 4 | import numpy as np 5 | 6 | from traccc.track import hungarian_matching, track 7 | from traccc.trackers import Track, AccelTrack 8 | 9 | 10 | def test_matching(): 11 | """ 12 | Tests that the Hungarian matching matches close tracks to close detections 13 | """ 14 | tracks = [Track(0, np.array([0.92, 0, 0, 0, 0]), 0), 15 | Track(1, np.array([0.99, 20, 20, 0, 0]), 0), 16 | Track(2, np.array([0.39, 50, 40, 0, 0]), 0)] 17 | detections = np.array([[0.92, 1, 1, 2, 2], 18 | [0.99, 50, 50, 2, 4], 19 | [0.8, 20, 30, 2, 2]]) 20 | cost, row_ind, col_ind = hungarian_matching(tracks, detections) 21 | assert np.all(row_ind == np.array([0, 1, 2])) 22 | assert np.all(col_ind == np.array([0, 2, 1])) 23 | assert cost == 10 + 10 + sqrt(2) 24 | 25 | 26 | def test_reject_large_cost(): 27 | """ 28 | Tests that the matching algorithm rejects a matching with a cost that is 29 | far too large 30 | """ 31 | tracks = [Track(0, np.array([0.92, 0, 0, 0, 0]), 0), 32 | Track(1, np.array([0.99, 20, 20, 0, 0]), 0), 33 | Track(2, np.array([0.39, 50000, 400000, 0, 0]), 0)] 34 | detections = np.array([[0.92, 1, 1, 2, 2], 35 | [0.99, 50, 50, 2, 4], 36 | [0.8, 20, 30, 2, 2]]) 37 | cost, row_ind, col_ind = hungarian_matching( 38 | tracks, detections, max_cost=100) 39 | 40 | # track 0 matches detection 0 (0, 0) 41 | # track 1 matches detection 2 (1, 2) 42 | # track 2 matches nothing, as does detection 1 43 | assert np.all(row_ind == np.array([0, 1])) 44 | assert np.all(col_ind == np.array([0, 2])) 45 | assert cost == 10 + sqrt(2) 46 | 47 | 48 | def test_more_detections_than_tracks(): 49 | """ 50 | This function tests that the correct matching is made when there are more 51 | detections than tracks.
52 | """ 53 | tracks = [Track(0, np.array([0.92, 0, 0, 0, 0]), 0), 54 | Track(1, np.array([0.99, 20, 20, 0, 0]), 0)] 55 | detections = np.array([[0.92, 1, 1, 2, 2], 56 | [0.99, 50, 50, 2, 4], 57 | [0.91, 60, 50, 2, 4], 58 | [0.8, 20, 30, 2, 2]]) 59 | # track 0 matches detection 0 (0, 0) 60 | # track 1 matchest detection 3 (1, 3) 61 | # detection 1 and 2 are unmatched 62 | cost, row_ind, col_ind = hungarian_matching( 63 | tracks, detections, max_cost=100) 64 | assert np.all(row_ind == np.array([0, 1])) 65 | assert np.all(col_ind == np.array([0, 3])) 66 | assert cost == 10 + sqrt(2) 67 | 68 | 69 | def test_more_tracks_than_detections(): 70 | """ 71 | This function tests that when there are more tracks than detections, 72 | some of the tracks go unmatched and the cost is still correctly computed. 73 | """ 74 | tracks = [Track(0, np.array([0.92, 0, 0, 0, 0]), 0), 75 | Track(1, np.array([0.99, 20, 20, 0, 0]), 0), 76 | Track(2, np.array([0.99, 20, 10, 0, 0]), 0), 77 | Track(3, np.array([0.99, 50, 50, 0, 0]), 0), 78 | Track(4, np.array([0.39, 30, 30, 0, 0]), 0)] 79 | detections = np.array([[0.92, 21, 21, 2, 2], 80 | [0.8, 31, 31, 2, 2]]) 81 | # track 1 matches detection 0 (1, 0) with euclidean distance sqrt(2) 82 | # track 4 matchest detection 1 (4, 1) with euclidean distance sqrt(2) 83 | # tracks 0, 2, and 3 are unmatched 84 | cost, row_ind, col_ind = hungarian_matching( 85 | tracks, detections, max_cost=100) 86 | assert np.all(row_ind == np.array([1, 4])) 87 | assert np.all(col_ind == np.array([0, 1])) 88 | assert cost == sqrt(2) + sqrt(2) 89 | 90 | 91 | @pytest.mark.parametrize("tracker", [Track, AccelTrack]) 92 | def test_track(tracker): 93 | """ 94 | Simple test case to test that the tracker can take these detections and string 95 | them together. 96 | """ 97 | detections = [np.array([[0.92, 0, 0, 5, 5]]), 98 | np.array([[0.99, 20, 20, 5, 5]]), 99 | np.array([[0.90, 41, 38, 5, 5]]), 100 | np.array([[0.94, 58, 64, 5, 5]])] 101 | 102 | tracks = track(detections, tracker) 103 | assert len(tracks) == 1 104 | assert tracks[0].age == 3 105 | 106 | 107 | @pytest.mark.parametrize("tracker", [Track, AccelTrack]) 108 | def test_no_track_switch(tracker): 109 | """ 110 | Simple test case to ensure the tracker doesn't switch tracks when it shouldn't. 
111 | In this test case, the trajectories of two objects cross, but the tracks shouldn't 112 | switch 113 | """ 114 | detections = [np.array([[0.99, 0, 0, 0, 0], [0.99, 100, 0, 0, 0]]), 115 | np.array([[0.99, 20, 20, 0, 0], [0.99, 80, 20, 0, 0]]), 116 | np.array([[0.99, 40, 40, 0, 0], [0.99, 60, 40, 0, 0]]), 117 | np.array([[0.99, 60, 60, 0, 0], [0.99, 40, 60, 0, 0]]), 118 | np.array([[0.99, 80, 80, 0, 0], [0.99, 20, 80, 0, 0]]), 119 | np.array([[0.99, 100, 100, 0, 0], [0.99, 0, 100, 0, 0]])] 120 | 121 | tracks = track(detections, tracker) 122 | assert len(tracks) == 2 # there should only be two tracks 123 | 124 | # and there should be a track that connects (0, 0) to (100, 100) 125 | # or connects (100, 0) to (0, 100) 126 | if np.all(tracks[0].prev_states[0] == np.array([0, 0, 0, 0])): 127 | assert np.allclose(tracks[0].prev_states[-1], 128 | np.array([100, 100, 20, 20]), atol=1e-2, rtol=0) 129 | elif np.all(tracks[0].prev_states[0] == np.array([100, 0, 20, 20])): 130 | assert np.allclose(tracks[0].prev_states[-1], 131 | np.array([0, 100, -20, 20]), atol=1e-2, rtol=0) 132 | -------------------------------------------------------------------------------- /test/test_pipeline.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | 4 | def test_pipeline(): 5 | """ 6 | Tests that the whole pipeline can be run end to end, which checks quite a few 7 | things like imports being good, etc. 8 | """ 9 | 10 | os.system("python3 detect.py stormy --input test_assets/stormy.mp4") 11 | os.system("python3 track.py stormy") 12 | os.system("python3 draw.py stormy --input test_assets/stormy.mp4 --effect line --length 8 --colour blue --output test_stormy.mp4") 13 | -------------------------------------------------------------------------------- /test_assets/ball_on_tarmac.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/test_assets/ball_on_tarmac.jpg -------------------------------------------------------------------------------- /test_assets/ball_on_tarmac.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/test_assets/ball_on_tarmac.mp4 -------------------------------------------------------------------------------- /test_assets/stormy.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SinclairHudson/traccc/b7cc6824525412da06286fc5874e2a52552294ae/test_assets/stormy.mp4 -------------------------------------------------------------------------------- /traccc/detect.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used for detecting objects in a video, and saving the detections to an output file.
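Example usage, mirroring the argparse options defined below (this is also how the test suite invokes it):

    python3 detect.py stormy --model DETR --input test_assets/stormy.mp4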
3 | """ 4 | import argparse 5 | import os 6 | import gradio as gr 7 | 8 | import skvideo.io 9 | 10 | from traccc.detectors import HuggingFaceDETR, PretrainedRN50Detector 11 | 12 | model_selector = { 13 | "DETR": HuggingFaceDETR, 14 | "RN50": PretrainedRN50Detector 15 | } 16 | 17 | def run_detect(name: str, model: str, input_file: str, progress=gr.Progress(track_tqdm=True)): 18 | vid_generator = skvideo.io.vreader(input_file) 19 | metadata = skvideo.io.ffprobe(input_file) 20 | frame_count = int(metadata['video']['@nb_frames']) 21 | 22 | detector = model_selector[model]() 23 | 24 | if not os.path.exists(f"internal"): 25 | os.system("mkdir internal") # make internal if it doesn't exist 26 | detector.detect( 27 | vid_generator, filename=f"internal/{name}.npz", frame_count=frame_count) 28 | return f"Completed detection for project {name} using {model}." 29 | 30 | if __name__ == "__main__": 31 | parser = argparse.ArgumentParser( 32 | description="Tracks objects using detections as input.") 33 | parser.add_argument("name", help="name of the project to be tracked.") 34 | parser.add_argument("--model", help="choice of model", default="DETR") 35 | parser.add_argument("--input", help="video file to be used", default=None) 36 | 37 | 38 | # input sanitization 39 | args = parser.parse_args() 40 | name = args.name 41 | input_file = args.input 42 | if input_file is None: 43 | input_file = f"io/{name}.mp4" 44 | 45 | 46 | assert os.path.exists(input_file), f"Input file {input_file} does not exist." 47 | assert args.model in model_selector, f"Model {args.model} isn't supported" 48 | 49 | run_detect(name, args.model, input_file) 50 | -------------------------------------------------------------------------------- /traccc/detectors.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | import torch 6 | import torch.functional as F 7 | from torchvision.io import write_video 8 | from torchvision.models.detection import fasterrcnn_resnet50_fpn 9 | from torchvision.ops import box_convert 10 | from torchvision.transforms import Normalize 11 | from torchvision.utils import draw_bounding_boxes 12 | from tqdm import tqdm 13 | from transformers import DetrFeatureExtractor, DetrForObjectDetection 14 | 15 | SPORTS_BALL = 37 # from coco class mapping 16 | 17 | 18 | def show(imgs): 19 | """ 20 | Helper from pytorch tutorials 21 | """ 22 | if not isinstance(imgs, list): 23 | imgs = [imgs] 24 | _, axs = plt.subplots(ncols=len(imgs), squeeze=False) 25 | for i, img in enumerate(imgs): 26 | img = img.detach() 27 | img = F.to_pil_image(img) # pylint: disable=no-member 28 | axs[0, i].imshow(np.asarray(img)) 29 | axs[0, i].set(xticklabels=[], yticklabels=[], xticks=[], yticks=[]) 30 | 31 | 32 | class Detector(ABC): 33 | def __init__(self): 34 | raise NotImplementedError 35 | 36 | def detect_video(self, video, bbox_format="cxcywh", frame_count: int = None): 37 | raise NotImplementedError 38 | 39 | def detect(self, video, filename="internal/detections.npz", frame_count: int = None): 40 | detections = self.detect_video(video, frame_count=frame_count) 41 | np.savez(filename, *detections) 42 | return f"Successfully saved detections in {filename}" 43 | 44 | def display_detections_in_video(self, video: torch.Tensor, outfile: str) -> None: 45 | detections = self.detect_video(video, bbox_format="xyxy") 46 | print("drawing bounding boxes") 47 | for i, frame in tqdm(enumerate(video)): 48 | # draw boxes on single frame 49 | 
# write frame to mp4 50 | CHW = torch.permute(frame, (2, 0, 1)) # move channels to front 51 | video[i] = torch.permute(draw_bounding_boxes(CHW, torch.Tensor( 52 | detections[i][:, 1:]), colors="red", width=5), (1, 2, 0)) # drop the confidence column; move C back to end, save in tensor 53 | print("writing the video") 54 | write_video(outfile, video, video_codec="h264", fps=60.0) 55 | 56 | 57 | class PretrainedRN50Detector(Detector): 58 | def __init__(self): 59 | self.model = fasterrcnn_resnet50_fpn(pretrained=True, num_classes=91, 60 | pretrained_backbone=True) 61 | 62 | self.device = "cuda:0" if torch.cuda.is_available() else "cpu" 63 | self.model.eval().to(self.device) 64 | 65 | @torch.no_grad() 66 | def detect_video(self, video, bbox_format="cxcywh", frame_count: int = None): 67 | """ 68 | video is a generator 69 | output is a list of numpy arrays denoting bounding boxes for each frame 70 | """ 71 | # TODO batch these calls 72 | video_detections = [] # list of list of detections 73 | ZeroOne = Normalize((0, 0, 0), (255, 255, 255)) # divide to 0 to 1 74 | # num_batches = len(video) // batch_size # last frames may be cut 75 | 76 | print("detecting balls in the video") 77 | # for frame in tqdm(video, total=frame_count): 78 | for frame, _ in zip(video, tqdm(range(frame_count))): 79 | batch = torch.Tensor(frame).unsqueeze(0) 80 | batch = torch.moveaxis(batch, 3, 1) # move channels to position 1 81 | batch = ZeroOne(batch.float()) 82 | batched_result = self.model(batch.to(self.device)) 83 | for res in batched_result: 84 | xyxy = res["boxes"][res["labels"] == SPORTS_BALL] 85 | conf = res["scores"][res["labels"] == SPORTS_BALL] 86 | xywh = box_convert(xyxy, in_fmt="xyxy", out_fmt=bbox_format) 87 | cxywh = torch.cat((conf.unsqueeze(1), xywh), 88 | dim=1) # add confidences 89 | video_detections.append(cxywh.cpu().numpy()) 90 | return video_detections 91 | 92 | 93 | class HuggingFaceDETR(Detector): 94 | def __init__(self): 95 | self.feature_extractor = DetrFeatureExtractor.from_pretrained( 96 | 'facebook/detr-resnet-101-dc5') 97 | self.model = DetrForObjectDetection.from_pretrained( 98 | 'facebook/detr-resnet-101-dc5') 99 | 100 | self.device = "cuda:0" if torch.cuda.is_available() else "cpu" 101 | self.model.eval().to(self.device) 102 | 103 | @torch.no_grad() 104 | def detect_video(self, video, bbox_format="cxcywh", frame_count=None): 105 | """ 106 | video is a generator 107 | output is a list of numpy arrays denoting bounding boxes for each frame 108 | """ 109 | video_detections = [] # list of list of detections 110 | # num_batches = len(video) // batch_size # last frames may be cut 111 | 112 | print("detecting balls in the video") 113 | # for frame in tqdm(video, total=frame_count): 114 | for frame, _ in zip(video, tqdm(range(frame_count))): 115 | height, width, _ = frame.shape # frame is (rows, cols, channels) 116 | inputs = self.feature_extractor(images=frame, return_tensors="pt") 117 | inputs["pixel_values"] = inputs["pixel_values"].to(self.device) 118 | inputs["pixel_mask"] = inputs["pixel_mask"].to(self.device) 119 | 120 | outputs = self.model(**inputs) 121 | outputs["logits"] = outputs["logits"].squeeze( 122 | 0) # remove singleton batch dim 123 | outputs["pred_boxes"] = outputs["pred_boxes"].squeeze(0) 124 | confs = torch.nn.functional.softmax(outputs["logits"], dim=1) 125 | conf_scores, indices = torch.max(confs, dim=1) 126 | xywh = outputs["pred_boxes"][indices == SPORTS_BALL] 127 | conf_scores = conf_scores[indices == SPORTS_BALL] 128 | 129 | xywh[:, 0] *= width 130 | xywh[:, 2] *= width 131 | xywh[:, 1] *= height 132 | xywh[:, 3] *= height 133 | cxywh =
torch.cat((conf_scores.unsqueeze(1), xywh), 134 | dim=1) # add confidences 135 | video_detections.append(cxywh.cpu().numpy()) 136 | return video_detections 137 | -------------------------------------------------------------------------------- /traccc/draw.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | import argparse 3 | import skvideo.io 4 | from traccc.effects import * 5 | from tqdm import tqdm 6 | from traccc.filters import * 7 | import cv2 8 | from typing import Tuple 9 | import gradio as gr 10 | 11 | def hex_to_rgb(rgb_hex: str) -> Tuple[int, int, int]: 12 | rgb_hex = rgb_hex.lstrip('#') 13 | rgb = [int(rgb_hex[i:i+2], 16) for i in (0, 2, 4)] 14 | return tuple(rgb) 15 | 16 | def run_draw(name: str, input_video: str, output: str, effect_name: str, 17 | colour: str, size: float, length: int, min_age: int, progress=gr.Progress(track_tqdm=True)): 18 | """ 19 | Inputs are already expected to be sanitized 20 | """ 21 | vid_generator = skvideo.io.vreader(input_video) 22 | metadata = skvideo.io.ffprobe(input_video) 23 | frame_count = int(metadata['video']['@nb_frames']) 24 | fps = metadata['video']['@r_frame_rate'] 25 | numerator, denominator = map(int, fps.split('/')) 26 | fps = numerator / denominator 27 | width = int(metadata['video']['@width']) 28 | # we also probably need to rotate 29 | height = int(metadata['video']['@height']) 30 | fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v') 31 | opencv_out = cv2.VideoWriter( 32 | output, fourcc, fps, (width, height)) # TODO explore why sometimes this needs to be flipped. 33 | 34 | with open(f"internal/{name}.yaml", 'r') as f: 35 | track_dictionary = yaml.safe_load(f) 36 | 37 | rgb_color = hex_to_rgb(colour) 38 | effect = { 39 | "dot": Dot, 40 | "lagging_dot": LaggingDot, 41 | "line": Line, 42 | "highlight_line": HighlightLine, 43 | "neon_line": NeonLine, 44 | "contrail": Contrail, 45 | "fully_connected": FullyConnected, 46 | "fully_connected_neon": FullyConnectedNeon, 47 | "debug": Debug 48 | }[effect_name](rgb_color, length, size) 49 | 50 | tracks = track_dictionary["tracks"] 51 | 52 | # filter out all the tracks that we deem not good enough 53 | tracks = [track for track in tracks if standard_filter( 54 | track, min_age=min_age)] 55 | 56 | print("adding effect") 57 | # TODO workaround to issue https://github.com/gradio-app/gradio/issues/3841 58 | # revert to the below line when bug is fixed 59 | # for i, frame in tqdm(enumerate(vid_generator), total=frame_count): 60 | for (frame, i) in zip(vid_generator, tqdm(range(frame_count))): 61 | relevant_tracks = [ 62 | track for track in tracks if effect.relevant(track, i)] 63 | 64 | # loop through all tracks, draw each on the frame 65 | out_frame = effect.draw_tracks(frame, relevant_tracks, i) 66 | 67 | bgr_frame = cv2.cvtColor(out_frame, cv2.COLOR_RGB2BGR) 68 | opencv_out.write(bgr_frame) 69 | 70 | opencv_out.release() 71 | return f"successfully wrote video {output}" 72 | 73 | 74 | if __name__ == "__main__": 75 | parser = argparse.ArgumentParser( 76 | description="Draws effects on the video, based on the tracks") 77 | parser.add_argument("name", help="name of the project") 78 | parser.add_argument("--input", help="video file to be used", default=None) 79 | parser.add_argument( 80 | "--effect", help="name of effect you wish to use", default="line") 81 | parser.add_argument( 82 | "--min_age", help="tracks below this age don't get drawn", default=0) 83 | parser.add_argument("--colour", help="colour of the effect", default="#ff0000") 84 |
parser.add_argument( 85 | "--length", help="length of the effect in frames", default=10) 86 | parser.add_argument( 87 | "--size", help="size or width of the effect", default=1.0) 88 | parser.add_argument( 89 | "--output", help="the output file", default=None) 90 | args = parser.parse_args() 91 | name = args.name 92 | effect = args.effect 93 | colour = args.colour 94 | length = int(args.length) 95 | size = float(args.size) 96 | if args.output is None: 97 | output = f"io/{name}_out.mp4" 98 | else: 99 | output = args.output 100 | 101 | if args.input is None: 102 | input_video = f"io/{name}.mp4" 103 | else: 104 | input_video = args.input 105 | 106 | run_draw(name, input_video, output, effect, args.colour, 107 | size, length, int(args.min_age)) 108 | 109 | -------------------------------------------------------------------------------- /traccc/draw_detections.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | from torchvision.utils import draw_bounding_boxes 4 | from torchvision.ops import box_convert 5 | import skvideo.io 6 | import torch 7 | from tqdm import tqdm 8 | 9 | if __name__ == "__main__": 10 | parser = argparse.ArgumentParser( 11 | description="Tracks objects using detections as input.") 12 | parser.add_argument("name", help="name of the project to be tracked.") 13 | parser.add_argument( 14 | "--conf_threshold", help="confidence threshold for removing uncertain predictions, must be in the range [0, 1].", default=0.0) 15 | args = parser.parse_args() 16 | name = args.name 17 | print("reading video") 18 | vid_generator = skvideo.io.vreader(f"io/{name}.mp4") 19 | vid_writer = skvideo.io.FFmpegWriter(f"io/{name}_detections.mp4") 20 | metadata = skvideo.io.ffprobe(f"io/{name}.mp4") 21 | frame_count = int(metadata['video']['@nb_frames']) 22 | 23 | detections = np.load(f"internal/{name}.npz") 24 | detections_list = [] 25 | for frame_number, frame_name in enumerate(detections): 26 | detections_list.append(detections[frame_name]) 27 | 28 | print("drawing detections") 29 | for i, frame in tqdm(enumerate(vid_generator), total=frame_count): 30 | if len(detections_list[i]) > 0: 31 | # move channels to front 32 | CHW = torch.permute(torch.tensor( 33 | frame, dtype=torch.uint8), (2, 0, 1)) 34 | confs = detections_list[i][:, 0] 35 | keep = confs > float(args.conf_threshold) # boolean mask, so boxes and confidences stay aligned 36 | detections_list[i] = detections_list[i][keep] confs = confs[keep] 37 | boxes_xyxy = box_convert(torch.Tensor( 38 | detections_list[i][:, 1:]), in_fmt="cxcywh", out_fmt="xyxy") 39 | drawn = CHW 40 | for j, box in enumerate(boxes_xyxy): # don't shadow the frame index i 41 | conf = confs[j] 42 | drawn = draw_bounding_boxes(drawn, box.unsqueeze(0), colors=( 43 | int(conf * 255), int(conf * 255), 255), width=5) # draw onto the accumulating image so every box is kept 44 | 45 | vid_writer.writeFrame(torch.permute(drawn, (1, 2, 0)).numpy()) 46 | else: 47 | vid_writer.writeFrame(frame) 48 | 49 | vid_writer.close() 50 | print(f"finished writing io/{name}_detections.mp4") 51 | -------------------------------------------------------------------------------- /traccc/effects.py: -------------------------------------------------------------------------------- 1 | from abc import ABC 2 | import cv2 3 | import numpy as np 4 | from typing import Tuple, List 5 | 6 | # pylint: disable=invalid-name, no-member 7 | 8 | class Effect(ABC): 9 | def __init__(self, colour: Tuple[int], length: int, size: float = 1.0): 10 | self.colour = colour 11 | self.size = size 12 | self.length = length 13 | 14 | def relevant(self, track: dict, frame_number) -> bool: 15 | """ 16 | returns True if the frame
needs to be modified because of this effect 17 | """ 18 | if track["start_frame"] <= frame_number and frame_number < track["start_frame"] + track["age"]: 19 | return True 20 | else: 21 | return False 22 | 23 | def draw_tracks(self, frame, tracks: List[dict], frame_number: int): 24 | """ 25 | The default is to draw every track independently 26 | """ 27 | out_frame = frame 28 | for track in tracks: 29 | # print(f"drawing track {track['id']}") 30 | out_frame = self.draw(out_frame, track, frame_number) 31 | return out_frame 32 | 33 | def draw(self, frame, track, frame_number): 34 | raise NotImplementedError 35 | 36 | 37 | class FullyConnected(Effect): 38 | def draw_tracks(self, frame, tracks: List[dict], frame_number: int): 39 | out_frame = frame 40 | state_indexes = [frame_number - track["start_frame"] 41 | for track in tracks] 42 | states = [track["states"][max(0, i-1)] # TODO investigate why this offset looks better 43 | for track, i in zip(tracks, state_indexes)] 44 | # draw a line between each pair of points 45 | for i, state_1 in enumerate(states): 46 | for _, state_2 in enumerate(states[i+1:]): 47 | # white lines 48 | (x1, y1) = state_1[:2] 49 | (x2, y2) = state_2[:2] 50 | w2 = state_2[4] 51 | out_frame = cv2.line(out_frame, (int(x1), int(y1)), (int( 52 | x2), int(y2)), self.colour, thickness=int(w2*self.size)) 53 | return out_frame 54 | 55 | 56 | class FullyConnectedNeon(Effect): 57 | def draw_tracks(self, frame, tracks: List[dict], frame_number: int): 58 | blank = np.zeros_like(frame, dtype=np.uint8) 59 | state_indexes = [frame_number - track["start_frame"] 60 | for track in tracks] 61 | states = [track["states"][max(0, i-1)] # TODO investigate why this offset looks better 62 | for track, i in zip(tracks, state_indexes)] 63 | # draw a line between each pair of points 64 | if len(states) < 1: 65 | return frame 66 | 67 | width = states[0][4] 68 | for i, state_1 in enumerate(states): 69 | for _, state_2 in enumerate(states[i+1:]): 70 | # white lines 71 | (x1, y1) = state_1[:2] 72 | (x2, y2) = state_2[:2] 73 | w2 = state_2[4] 74 | blank = cv2.line(blank, (int(x1), int(y1)), (int( 75 | x2), int(y2)), self.colour, thickness=int(self.size*w2*2)) 76 | 77 | kernel_size = int(width * self.size * 3) 78 | if kernel_size % 2 == 0: 79 | kernel_size += 1 80 | 81 | blank = cv2.GaussianBlur( 82 | blank, (kernel_size, kernel_size), self.size//2) 83 | 84 | out_frame = cv2.addWeighted(frame, 1, blank, 1, 0) 85 | 86 | for i, state_1 in enumerate(states): 87 | for _, state_2 in enumerate(states[i+1:]): 88 | # white lines 89 | (x1, y1) = state_1[:2] 90 | (x2, y2) = state_2[:2] 91 | w2 = state_2[4] 92 | out_frame = cv2.line(out_frame, (int(x1), int(y1)), (int( 93 | x2), int(y2)), (255, 255, 255), thickness=int(w2*self.size/2)) 94 | 95 | return out_frame 96 | 97 | 98 | class Dot(Effect): 99 | def draw(self, frame, track, frame_number): 100 | i = frame_number - track["start_frame"] 101 | (x, y) = track["states"][i][:2] 102 | w = track["states"][i][4] 103 | return cv2.circle(frame, (int(x), int(y)), radius=int(w*self.size/2), 104 | color=self.colour, thickness=-1) 105 | 106 | 107 | class LaggingDot(Effect): 108 | def __init__(self, colour: Tuple[int], length: int, size: float = 1.0): 109 | super().__init__(colour, length, size) 110 | self.time_lag = length # assumption: the "length" setting doubles as the lag in frames; self.time_lag was never set anywhere, so relevant() and draw() would otherwise raise AttributeError 111 | 112 | def relevant(self, track: dict, frame_number: int) -> bool: 113 | """ 114 | returns True if the frame needs to be modified because of this effect 115 | """ 116 | if track["age"] >= self.time_lag and \ 117 | track["start_frame"] <= frame_number + self.time_lag and \ 118 | frame_number < track["start_frame"] + track["age"] + self.time_lag: 119 | return True 120 | else: 121 | return False 122 | 123 | def
119 |     def draw(self, frame: np.ndarray, track: dict, frame_number: int) -> np.ndarray:
120 |         i = frame_number - track["start_frame"] - self.time_lag
121 |         if i >= 0:
122 |             (x, y) = track["states"][i][:2]
123 |             w = track["states"][i][4]
124 |         else:
125 |             # for times where the track is too young
126 |             (x, y) = track["states"][0][:2]
127 |             w = track["states"][0][4]
128 |         return cv2.circle(frame, (int(x), int(y)), radius=int(w*self.size/2),
129 |                           color=self.colour, thickness=-1)
130 | 
131 | 
132 | class Line(Effect):
133 |     def draw(self, frame: np.ndarray, track: dict, frame_number: int) -> np.ndarray:
134 |         start_line = max(1, frame_number -
135 |                          self.length - track["start_frame"])
136 |         end_line = frame_number - track["start_frame"]
137 | 
138 |         for i in range(start_line, end_line):
139 |             (x, y) = track["states"][i][:2]
140 |             w = track["states"][i][4]
141 |             (x2, y2) = track["states"][i-1][:2]
142 |             frame = cv2.line(frame, (int(x), int(y)), (int(x2), int(y2)),
143 |                              color=self.colour, thickness=int(self.size*w))
144 |         return frame
145 | 
146 | 
147 | def draw_x(frame, x, y, colour, size: float):
148 |     """
149 |     draws a small x on the frame at the given coordinates
150 |     """
151 |     frame = cv2.line(frame, (int(x-size), int(y-size)), (int(x+size),
152 |                      int(y+size)), color=colour, thickness=max(1, int(size/3)))
153 |     return cv2.line(frame, (int(x-size), int(y+size)), (int(x+size), int(y-size)), color=colour, thickness=max(1, int(size/3)))
154 | 
155 | 
156 | class Debug(Effect):
157 |     def __init__(self, color, length_in_frames: int = 15, size: float = 0.15):
158 |         self.length_in_frames = length_in_frames
159 |         self.size = size
160 |         # color not used
161 |         self.colours = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
162 |                         (255, 0, 230), (252, 132, 0)]
163 | 
164 |     def draw(self, frame: np.ndarray, track: dict, frame_number: int) -> np.ndarray:
165 |         start_line = max(1, frame_number -
166 |                          self.length_in_frames - track["start_frame"])
167 |         end_line = frame_number - track["start_frame"]
168 | 
169 |         colour = self.colours[track["id"] % len(self.colours)]
170 |         for i in range(start_line, end_line):
171 |             (x, y) = track["states"][i][:2]
172 |             (x2, y2, _, _, w, _) = track["states"][i-1][:6]
173 |             frame = cv2.line(frame, (int(x), int(y)), (int(x2), int(y2)),
174 |                              color=colour, thickness=int(w * self.size))
175 | 
176 |             meas = track["measurements"][i]
177 |             if meas is None:
178 |                 frame = draw_x(frame, x, y, colour, self.size * w)
179 |             else:  # a matched detection
180 |                 mx, my = meas[1:3]
181 |                 frame = cv2.circle(frame, (int(x), int(y)), color=colour,
182 |                                    radius=int(self.size + 3), thickness=-1)
183 |                 frame = cv2.circle(frame, (int(mx), int(my)),
184 |                                    color=(255, 255, 255), radius=int(self.size * w), thickness=-1)
185 |                 frame = cv2.line(frame, (int(mx), int(my)),
186 |                                  (int(x), int(y)), color=(255, 255, 255),
187 |                                  thickness=int(w * self.size))
188 | 
189 |         return frame
190 | 
191 | 
192 | class HighlightLine(Effect):
193 |     def draw(self, frame: np.ndarray, track: dict, frame_number: int) -> np.ndarray:
194 |         start_line = max(1, frame_number -
195 |                          self.length - track["start_frame"])
196 |         end_line = frame_number - track["start_frame"]
197 | 
198 |         blank = np.zeros_like(frame, dtype=np.uint8)
199 | 
200 |         width = track["states"][0][4]
201 |         for i in range(start_line, end_line):
202 |             (x, y) = track["states"][i][:2]
203 |             (x2, y2) = track["states"][i-1][:2]
204 |             blank = cv2.line(blank, (int(x), int(y)), (int(x2), int(y2)),
205 |                              color=self.colour, thickness=int(self.size*width))
206 | 
207 |         kernel_size = int(min(width, 1) * self.size * 3)
208 |         if kernel_size % 2 == 0:
209 |             kernel_size += 1  # GaussianBlur kernels must be odd
210 |         blank = cv2.GaussianBlur(blank, (kernel_size, kernel_size), kernel_size//2)
211 | 
212 |         return cv2.addWeighted(frame, 1, blank, 1, 0)
213 | 
214 | class NeonLine(Effect):
215 |     def draw(self, frame: np.ndarray, track: dict, frame_number: int) -> np.ndarray:
216 |         start_line = max(1, frame_number -
217 |                          self.length - track["start_frame"])
218 |         end_line = frame_number - track["start_frame"]
219 | 
220 |         blank = np.zeros_like(frame, dtype=np.uint8)
221 | 
222 |         width = track["states"][0][4]
223 |         for i in range(start_line, end_line):
224 |             (x, y) = track["states"][i][:2]
225 |             (x2, y2) = track["states"][i-1][:2]
226 |             blank = cv2.line(blank, (int(x), int(y)), (int(x2), int(y2)),
227 |                              color=self.colour, thickness=int(self.size*width*2))
228 | 
229 |         kernel_size = int(min(width, 10) * self.size * 3)
230 |         if kernel_size % 2 == 0:
231 |             kernel_size += 1
232 |         blank = cv2.GaussianBlur(blank, (kernel_size, kernel_size), self.size//2)
233 | 
234 |         intermediate = cv2.addWeighted(frame, 1, blank, 1, 0)
235 | 
236 |         for i in range(start_line, end_line):
237 |             (x, y) = track["states"][i][:2]
238 |             (x2, y2) = track["states"][i-1][:2]
239 |             intermediate = cv2.line(intermediate, (int(x), int(y)), (int(x2), int(y2)),
240 |                                     color=(255, 255, 255), thickness=int(self.size*width/2))
241 |         return intermediate
242 | 
243 | 
244 | class Contrail(Effect):
245 |     def __init__(self, colour=(255, 0, 0), length_in_frames: int = 15, size: int = 10):
246 |         self.length_in_frames = length_in_frames
247 |         self.colour = colour
248 |         self.size = size if size % 2 == 1 else size + 1  # must be odd
249 | 
250 |     def draw(self, frame: np.ndarray, track: dict, frame_number: int) -> np.ndarray:
251 |         start_line = max(1, frame_number -
252 |                          self.length_in_frames - track["start_frame"])
253 |         end_line = frame_number - track["start_frame"]
254 | 
255 |         blank = np.zeros_like(frame, dtype=np.uint8)
256 | 
257 |         for i in range(start_line, end_line):
258 |             (x, y) = track["states"][i][:2]
259 |             (x2, y2) = track["states"][i-1][:2]
260 |             w = track["states"][i][4]
261 |             kernel_size = int(self.size*w)
262 |             if kernel_size % 2 == 0:
263 |                 kernel_size += 1  # make odd for gaussian blur
264 | 
265 |             blank = cv2.line(blank, (int(x), int(y)), (int(x2), int(y2)),
266 |                              color=self.colour, thickness=kernel_size)
267 |             blank = cv2.GaussianBlur(
268 |                 blank, (kernel_size, kernel_size), kernel_size//2)
269 | 
270 |         return cv2.addWeighted(frame, 1, blank, 1, 0)
271 | 
--------------------------------------------------------------------------------
/traccc/filters.py:
--------------------------------------------------------------------------------
1 | from traccc.trackers import Track
2 | 
3 | def standard_filter(track: dict, min_age: int=30) -> bool:
4 |     """
5 |     A filter for determining if the track produced by the tracker is good for
6 |     visualization.
7 |     """
8 |     if track["age"] < min_age:
9 |         return False
10 |     else:
11 |         return True
12 | 
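`standard_filter` consumes the track dictionaries produced by `Track.encode_in_dictionary` (see `traccc/trackers.py` below) after they have been loaded back from YAML. A minimal usage sketch, assuming a tracks file previously written by `run_track` in `traccc/track.py` (the project name `myproject` is illustrative):

import yaml

from traccc.filters import standard_filter

with open("internal/myproject.yaml") as f:  # written by run_track; name is illustrative
    tracks = yaml.safe_load(f)["tracks"]

# keep only tracks that lived long enough to be worth visualizing
visualizable = [t for t in tracks if standard_filter(t, min_age=30)]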
--------------------------------------------------------------------------------
/traccc/track.py:
--------------------------------------------------------------------------------
1 | """
2 | track.py. The purpose of this file is to generate tracks from a list of detections
3 | """
4 | import argparse
5 | from math import sqrt
6 | from typing import List
7 | import os
8 | 
9 | import numpy as np
10 | import torch
11 | import yaml
12 | from scipy.optimize import linear_sum_assignment
13 | from torchvision.ops import box_convert, nms
14 | from tqdm import tqdm
15 | 
16 | from traccc.trackers import Track, AccelTrack
17 | 
18 | 
19 | def euclidean_distance(track: Track, detection):
20 |     """
21 |     calculates euclidean distance between a track and a detection in pixel space
22 |     """
23 |     assert len(detection) == 5  # confidence, x, y, w, h
24 |     return sqrt((track.kf.x[0] - detection[1]) ** 2 + (track.kf.x[1] - detection[2]) ** 2)
25 | 
26 | 
27 | def hungarian_matching(tracks, detections, cost_function=euclidean_distance, max_cost=np.inf):
28 |     """
29 |     Finds the minimum cost matching between tracks and detections, based on
30 |     some distance metric. If there are more detections than tracks, the extra
31 |     detections are left unmatched.
32 |     Returns the scalar cost of all the matches, as well as two arrays of equal
33 |     length: row_ind[x] is the index of the track that matches with the
34 |     col_ind[x] detection.
35 |     """
36 |     cost_matrix = np.zeros((len(tracks), len(detections)))
37 |     for i, track in enumerate(tracks):
38 |         for j, detection in enumerate(detections):
39 |             cost_matrix[i][j] = cost_function(track, detection)
40 | 
41 |     row_ind, col_ind = linear_sum_assignment(cost_matrix)
42 | 
43 |     # remove ridiculous matches; it's better to leave them unpaired
44 |     for i, j in zip(row_ind, col_ind):
45 |         if cost_matrix[i, j] > max_cost:
46 |             row_ind = np.delete(row_ind, np.where(row_ind == i))
47 |             col_ind = np.delete(col_ind, np.where(col_ind == j))
48 |     cost = cost_matrix[row_ind, col_ind].sum()
49 |     return cost, row_ind, col_ind
50 | 
51 | 
52 | def track(detections, track_class, death_time: int = 5, max_cost: float = np.inf):
53 |     """
54 |     detections is a list of detections, every entry is a frame
55 |     """
56 |     next_track_id = 0  # counter for track IDs
57 |     inactive_tracks = []
58 |     tracks = []
59 |     for frame_number, frame_detections in enumerate(tqdm(detections)):
60 | 
61 |         # handle deaths; if a track hasn't been seen in a few frames, deactivate it
62 |         dead_tracks = [track for track in tracks if not track.active]
63 |         active_tracks = [track for track in tracks if track.active]
64 |         inactive_tracks.extend(dead_tracks)
65 | 
66 |         # for dead in dead_tracks:
67 |         #     print(f"killed track {dead.id} tracks on frame {frame_number}")
68 | 
69 |         tracks = active_tracks
70 |         # print(f"active tracks: {len(active_tracks)}")
71 |         # print(f"detections: {len(frame_detections)}")
72 |         for track in tracks:
73 |             track.predict()  # advance the Kalman Filter, to get the prior for this timestep
74 | 
75 |         if len(frame_detections) == 0:
76 |             for track in tracks:
77 |                 track.update(None)  # no data for any of the tracks
78 | 
79 |         else:  # there are some detections
80 |             if len(tracks) > 0:
81 |                 _, row_ind, col_ind = hungarian_matching(
82 |                     tracks, frame_detections, max_cost=max_cost)
83 |                 for i in range(len(row_ind)):
84 |                     # update the matched tracks
85 |                     tracks[row_ind[i]].update(frame_detections[col_ind[i]])
86 | 
87 |                 unmatched_track_indices = set(
88 |                     range(len(tracks))) - set(row_ind)
89 |                 for i in unmatched_track_indices:
90 |                     # print(
91 |                     #     f"track {tracks[i].id} unmatched on frame {frame_number}")
92 |                     tracks[i].update(None)
93 | 
94 |                 # births
95 |                 unmatched_detection_indices = set(
96 |                     range(len(frame_detections))) - set(col_ind)
97 |                 for i in unmatched_detection_indices:
98 |                     # print(f"birthed track {next_track_id} on frame {frame_number}")
99 |                     tracks.append(
100 |                         track_class(next_track_id, frame_detections[i], frame_number,
101 |                                     death_time=death_time))
102 |                     next_track_id += 1
103 | 
104 |             else:  # no tracks, but detections
105 |                 # births
106 |                 for detection in frame_detections:
107 |                     # print(
108 |                     #     f"birthed track {next_track_id} on frame {frame_number}")
109 |                     tracks.append(track_class(next_track_id, detection,
110 |                                   frame_number, death_time=death_time))
111 |                     next_track_id += 1
112 | 
113 |     inactive_tracks.extend(tracks)
114 |     return inactive_tracks
115 | 
116 | 
117 | def filter_detections(detections, conf_threshold=0.0, iou_threshold=0.5) -> List[np.ndarray]:
118 |     """
119 |     Applies confidence filtering and Non-Max Suppression
120 |     """
121 |     filtered_detections = []
122 |     for frame_detections in detections:
123 |         conf = frame_detections[:, 0]
124 |         frame_detections = torch.Tensor(
125 |             frame_detections[conf > conf_threshold])
126 | 
127 |         xyxy = box_convert(
128 |             frame_detections[:, 1:], in_fmt="cxcywh", out_fmt="xyxy")
129 |         best_candidates = nms(
130 |             xyxy, frame_detections[:, 0], iou_threshold=iou_threshold)
131 |         filtered_detections.append(frame_detections[best_candidates].numpy())
132 |     return filtered_detections
133 | 
134 | 
135 | track_type_dict = {
136 |     "Constant Velocity": Track,
137 |     "Constant Acceleration": AccelTrack,
138 | }
139 | 
140 | 
141 | def run_track(name: str, track_type: str, death_time: int, iou_threshold: float, conf_threshold: float, max_cost: float):
142 | 
143 |     # load the raw detections
144 |     detections = np.load(f"internal/{name}.npz")
145 |     detections_list = []
146 |     for frame_name in detections:
147 |         detections_list.append(detections[frame_name])
148 | 
149 |     # filter out the detections that won't be used
150 |     detections_list = filter_detections(
151 |         detections_list, conf_threshold, iou_threshold)
152 | 
153 |     # run the tracking algorithm
154 |     tracks = track(
155 |         detections_list, track_type_dict[track_type], death_time=death_time, max_cost=max_cost)
156 | 
157 |     # dump to yaml
158 |     track_lives = [track.encode_in_dictionary() for track in tracks]
159 |     dictionary = {
160 |         "detections_file": f"internal/{name}.npz",
161 |         "tracks": track_lives
162 |     }
163 |     with open(f"internal/{name}.yaml", 'w') as f:
164 |         yaml.safe_dump(dictionary, f)
165 | 
166 |     return f"Successfully saved {len(track_lives)} tracks to internal/{name}.yaml"
167 | 
168 | 
169 | if __name__ == "__main__":
170 |     parser = argparse.ArgumentParser(
171 |         description="Tracks objects using detections as input.")
172 |     parser.add_argument("name", help="name of the project to be tracked.")
173 |     parser.add_argument(
174 |         "--track_type", help="motion model: 'Constant Velocity' or 'Constant Acceleration'", default="Constant Velocity")
175 |     parser.add_argument(
176 |         "--death_time", help="number of frames without an observation before track deletion", default=5)
177 |     parser.add_argument(
178 |         "--iou_threshold", help="IoU threshold used in Non-Max Suppression filtering, must be in the range [0, 1].", default=0.2)
179 |     parser.add_argument(
180 |         "--conf_threshold", help="confidence threshold for removing uncertain predictions, must be in the range [0, 1].", default=0.05)
181 |     parser.add_argument(
182 |         "--max_cost", help="the maximum cost tolerated to match a track to a detection.", default=200)
183 |     args = parser.parse_args()
184 |     name = args.name
185 | 
186 |     assert os.path.exists(
187 |         f"internal/{name}.npz"), f"Could not find internal/{name}.npz"
188 |     message = run_track(args.name, args.track_type, int(args.death_time), float(
189 |         args.iou_threshold), float(args.conf_threshold), float(args.max_cost))
190 |     print(message)
191 | 
--------------------------------------------------------------------------------
/traccc/trackers.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from filterpy.kalman import KalmanFilter
3 | 
4 | class Track:
5 |     """
6 |     This is a class representing a single track: ideally, a single object and its movements
7 |     """
8 | 
9 |     def __init__(self, track_id: int, initial_pos: np.ndarray, start_frame: int, death_time: int = 5):
10 |         self.id = track_id
11 |         self.prev_states = []  # tracks all previous estimates of position, and velocity
12 |         self.start_frame = start_frame
13 |         self.death_time = death_time
14 | 
15 |         assert len(initial_pos) == 5  # cxywh = [confidence, x, y, w, h]
16 |         self.kf = KalmanFilter(dim_x=6, dim_z=4)
17 |         # initial_pos is [confidence, x, y, w, h]
18 |         # state (x vector) is [x, y, vx, vy, w, h]
19 |         self.kf.x = np.array([initial_pos[1], initial_pos[2], 0, 0, initial_pos[3], initial_pos[4]])
20 |         self.kf.F = np.array([[1, 0, 1, 0, 0, 0],  # x = x + vx
21 |                               [0, 1, 0, 1, 0, 0],  # y = y + vy
22 |                               [0, 0, 1, 0, 0, 0],  # vx = vx
23 |                               [0, 0, 0, 1, 0, 0],  # vy = vy
24 |                               [0, 0, 0, 0, 1, 0],  # w = w
25 |                               [0, 0, 0, 0, 0, 1]   # h = h
26 |                               ])
27 | 
28 |         self.kf.H = np.array([[1, 0, 0, 0, 0, 0],  # we measure position, width, and height
29 |                               [0, 1, 0, 0, 0, 0],
30 |                               [0, 0, 0, 0, 1, 0],
31 |                               [0, 0, 0, 0, 0, 1]
32 |                               ])
33 |         self.kf.P *= 1000
34 | 
35 |         self.age = 0  # in the first frame, age is 0
36 |         self.time_missing = 0
37 |         self.active = True
38 |         self.prev_measurements = []
39 | 
40 |     def predict(self) -> None:
41 |         """
42 |         Advances the KalmanFilter, predicting the current state based on the prior
43 |         """
44 |         self.kf.predict()
45 | 
46 |     def update(self, measurement: np.ndarray) -> None:
47 |         """
48 |         Update our estimate of the state given the measurement. Calculate the posterior.
49 |         """
50 |         self.prev_states.append(self.kf.x)
51 |         self.prev_measurements.append(measurement)
52 |         self.age += 1
53 |         if measurement is None:  # on this iteration, didn't see this object
54 |             self.time_missing += 1
55 |             if self.time_missing > self.death_time:
56 |                 self.active = False
57 |         else:
58 |             # measurement comes in as cxywh = [confidence, x, y, w, h]
59 |             assert len(measurement) == 5
60 |             measurement_xywh = measurement[1:5]
61 |             self.kf.update(measurement_xywh)
62 | 
63 |     def encode_in_dictionary(self) -> dict:
64 |         """
65 |         Encodes the track in a dictionary to be saved and used downstream.
66 |         Uses vanilla python datatypes to allow for YAML serialization.
67 |         """
68 |         life = {
69 |             "id": self.id,
70 |             "start_frame": self.start_frame,
71 |             "states": [a.tolist() for a in self.prev_states],
72 |             "measurements": [a.tolist() if a is not None else None
73 |                              for a in self.prev_measurements],
74 |             "age": self.age
75 |         }
76 |         return life
77 | 
78 | 
79 | class AccelTrack(Track):
80 |     def __init__(self, track_id: int, initial_pos: np.ndarray, start_frame: int, death_time: int = 5):
81 |         self.id = track_id
82 |         self.prev_states = []  # tracks all previous estimates of position, and velocity
83 |         self.start_frame = start_frame
84 |         self.death_time = death_time
85 | 
86 |         assert len(initial_pos) == 5  # cxywh = [confidence, x, y, w, h]
87 |         self.kf = KalmanFilter(dim_x=8, dim_z=4)
88 |         # initial_pos is [confidence, x, y, w, h]
89 |         # state (x vector) is [x, y, vx, vy, w, h, ax, ay]
90 |         self.kf.x = np.array([initial_pos[1], initial_pos[2], 0, 0, initial_pos[3], initial_pos[4], 0, 0])
91 |         self.kf.F = np.array([[1, 0, 1, 0, 0, 0, 0, 0],  # x = x + vx
92 |                               [0, 1, 0, 1, 0, 0, 0, 0],  # y = y + vy
93 |                               [0, 0, 1, 0, 0, 0, 1, 0],  # vx = vx + ax
94 |                               [0, 0, 0, 1, 0, 0, 0, 1],  # vy = vy + ay
95 |                               [0, 0, 0, 0, 1, 0, 0, 0],  # w = w
96 |                               [0, 0, 0, 0, 0, 1, 0, 0],  # h = h
97 |                               [0, 0, 0, 0, 0, 0, 1, 0],  # ax = ax
98 |                               [0, 0, 0, 0, 0, 0, 0, 1],  # ay = ay
99 |                               ])
100 | 
101 |         self.kf.H = np.array([[1, 0, 0, 0, 0, 0, 0, 0],  # we measure position and width and height
102 |                               [0, 1, 0, 0, 0, 0, 0, 0],
103 |                               [0, 0, 0, 0, 1, 0, 0, 0],
104 |                               [0, 0, 0, 0, 0, 1, 0, 0]
105 |                               ])
106 |         self.kf.P *= 1000
107 | 
108 |         self.age = 0  # in the first frame, age is 0
109 |         self.time_missing = 0
110 |         self.active = True
111 |         self.prev_measurements = []
112 | 
--------------------------------------------------------------------------------
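To illustrate how `track`, `hungarian_matching`, and the `Track` Kalman filter fit together, here is a minimal end-to-end sketch on synthetic detections (all coordinates and confidences below are made up for illustration):

import numpy as np

from traccc.track import track as run_tracking  # aliased to avoid shadowing
from traccc.trackers import Track

# three frames of synthetic detections; each row is [confidence, x, y, w, h]
detections = [
    np.array([[0.9, 100.0, 200.0, 30.0, 30.0]]),
    np.array([[0.8, 110.0, 205.0, 30.0, 30.0]]),
    np.array([[0.7, 120.0, 210.0, 30.0, 30.0]]),
]

# on each frame: predict() advances every live track, hungarian_matching pairs
# tracks with detections, matched tracks get update(detection), unmatched
# tracks get update(None), and unmatched detections are born as new tracks
tracks = run_tracking(detections, Track, death_time=5)

for t in tracks:
    life = t.encode_in_dictionary()  # YAML-serializable dict
    # for this input: track 0, born on frame 0, updated on frames 1 and 2 (age 2)
    print(life["id"], life["start_frame"], life["age"])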