├── .gitignore
├── ACKNOWLEDGEMENTS.md
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE.md
├── LICENSE_MODEL.md
├── README.md
├── demo.py
├── get_pretrained_models.sh
├── posetrack21_eval
│   ├── README.md
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── pt_sequence.py
│   │   └── pt_warper.py
│   ├── evaluate_mot.py
│   └── motmetrics
│       ├── __init__.py
│       ├── distances.py
│       ├── io.py
│       ├── lap.py
│       ├── math_util.py
│       ├── metrics.py
│       ├── mot.py
│       ├── preprocess.py
│       ├── tests
│       │   ├── __init__.py
│       │   ├── test_distances.py
│       │   ├── test_io.py
│       │   ├── test_issue19.py
│       │   ├── test_lap.py
│       │   ├── test_metrics.py
│       │   ├── test_mot.py
│       │   └── test_utils.py
│       └── utils.py
├── pyproject.toml
├── samples
│   ├── sample_info.txt
│   ├── teaser_01.gif
│   ├── teaser_02.gif
│   ├── teaser_03.gif
│   ├── teaser_04.gif
│   ├── teaser_05.gif
│   └── teaser_06.gif
└── src
    └── comotion_demo
        ├── data
        │   └── smpl
        │       └── extra_smpl_reference.pt
        ├── models
        │   ├── __init__.py
        │   ├── backbones
        │   │   ├── __init__.py
        │   │   ├── _registry.py
        │   │   └── convnext.py
        │   ├── comotion.py
        │   ├── detect.py
        │   ├── layers.py
        │   └── refine.py
        └── utils
            ├── __init__.py
            ├── dataloading.py
            ├── helper.py
            ├── smpl_kinematics.py
            └── track.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.mp4
3 |
4 | .DS_Store
5 |
6 | src/comotion_demo/data
7 | src/comotion_demo.egg-info/
8 |
9 | data/
10 | tmp/
11 | results/
12 |
--------------------------------------------------------------------------------
/ACKNOWLEDGEMENTS.md:
--------------------------------------------------------------------------------
1 | Acknowledgements
2 | Portions of this Software may utilize the following copyrighted
3 | material, the use of which is hereby acknowledged.
4 |
5 | ------------------------------------------------
6 | PoseTrack21 evaluation code:
7 |
8 | MIT License
9 |
10 | Copyright (c) 2022 Andreas Doering
11 |
12 | Permission is hereby granted, free of charge, to any person obtaining a copy
13 | of this software and associated documentation files (the "Software"), to deal
14 | in the Software without restriction, including without limitation the rights
15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16 | copies of the Software, and to permit persons to whom the Software is
17 | furnished to do so, subject to the following conditions:
18 |
19 | The above copyright notice and this permission notice shall be included in all
20 | copies or substantial portions of the Software.
21 |
22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28 | SOFTWARE.
29 |
30 | ------------------------------------------------
31 | MOTMetrics:
32 |
33 | MIT License
34 |
35 | Copyright (c) 2017-2020 Christoph Heindl
36 | Copyright (c) 2018 Toka
37 | Copyright (c) 2019-2020 Jack Valmadre
38 |
39 | Permission is hereby granted, free of charge, to any person obtaining a copy
40 | of this software and associated documentation files (the "Software"), to deal
41 | in the Software without restriction, including without limitation the rights
42 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
43 | copies of the Software, and to permit persons to whom the Software is
44 | furnished to do so, subject to the following conditions:
45 |
46 | The above copyright notice and this permission notice shall be included in all
47 | copies or substantial portions of the Software.
48 |
49 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
50 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
51 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
52 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
53 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
54 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
55 | SOFTWARE.
56 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | ## Our Pledge
4 |
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to making participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 |
12 | ## Our Standards
13 |
14 | Examples of behavior that contributes to creating a positive environment
15 | include:
16 |
17 | * Using welcoming and inclusive language
18 | * Being respectful of differing viewpoints and experiences
19 | * Gracefully accepting constructive criticism
20 | * Focusing on what is best for the community
21 | * Showing empathy towards other community members
22 |
23 | Examples of unacceptable behavior by participants include:
24 |
25 | * The use of sexualized language or imagery and unwelcome sexual attention or
26 | advances
27 | * Trolling, insulting/derogatory comments, and personal or political attacks
28 | * Public or private harassment
29 | * Publishing others' private information, such as a physical or electronic
30 | address, without explicit permission
31 | * Other conduct which could reasonably be considered inappropriate in a
32 | professional setting
33 |
34 | ## Our Responsibilities
35 |
36 | Project maintainers are responsible for clarifying the standards of acceptable
37 | behavior and are expected to take appropriate and fair corrective action in
38 | response to any instances of unacceptable behavior.
39 |
40 | Project maintainers have the right and responsibility to remove, edit, or
41 | reject comments, commits, code, wiki edits, issues, and other contributions
42 | that are not aligned to this Code of Conduct, or to ban temporarily or
43 | permanently any contributor for other behaviors that they deem inappropriate,
44 | threatening, offensive, or harmful.
45 |
46 | ## Scope
47 |
48 | This Code of Conduct applies within all project spaces, and it also applies when
49 | an individual is representing the project or its community in public spaces.
50 | Examples of representing a project or community include using an official
51 | project e-mail address, posting via an official social media account, or acting
52 | as an appointed representative at an online or offline event. Representation of
53 | a project may be further defined and clarified by project maintainers.
54 |
55 | ## Enforcement
56 |
57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be
58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All
59 | complaints will be reviewed and investigated and will result in a response that
60 | is deemed necessary and appropriate to the circumstances. The project team is
61 | obligated to maintain confidentiality with regard to the reporter of an incident.
62 | Further details of specific enforcement policies may be posted separately.
63 |
64 | Project maintainers who do not follow or enforce the Code of Conduct in good
65 | faith may face temporary or permanent repercussions as determined by other
66 | members of the project's leadership.
67 |
68 | ## Attribution
69 |
70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4,
71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html)
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contribution Guide
2 |
3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository.
4 |
5 | While we welcome new pull requests and issues, please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged.
6 |
7 | ## Before you get started
8 |
9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE.md).
10 |
11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md).
12 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 |
3 | IMPORTANT: This Apple software is supplied to you by Apple
4 | Inc. ("Apple") in consideration of your agreement to the following
5 | terms, and your use, installation, modification or redistribution of
6 | this Apple software constitutes acceptance of these terms. If you do
7 | not agree with these terms, please do not use, install, modify or
8 | redistribute this Apple software.
9 |
10 | In consideration of your agreement to abide by the following terms, and
11 | subject to these terms, Apple grants you a personal, non-exclusive
12 | license, under Apple's copyrights in this original Apple software (the
13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple
14 | Software, with or without modifications, in source and/or binary forms;
15 | provided that if you redistribute the Apple Software in its entirety and
16 | without modifications, you must retain this notice and the following
17 | text and disclaimers in all such redistributions of the Apple Software.
18 | Neither the name, trademarks, service marks or logos of Apple Inc. may
19 | be used to endorse or promote products derived from the Apple Software
20 | without specific prior written permission from Apple. Except as
21 | expressly stated in this notice, no other rights or licenses, express or
22 | implied, are granted by Apple herein, including but not limited to any
23 | patent rights that may be infringed by your derivative works or by other
24 | works in which the Apple Software may be incorporated.
25 |
26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE
27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
31 |
32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
39 | POSSIBILITY OF SUCH DAMAGE.
40 |
--------------------------------------------------------------------------------
/LICENSE_MODEL.md:
--------------------------------------------------------------------------------
1 | Disclaimer: IMPORTANT: This Apple Machine Learning Research Model is specifically developed and released by Apple Inc. ("Apple") for the sole purpose of scientific research of artificial intelligence and machine-learning technology. “Apple Machine
2 | Learning Research Model” means the model, including but not limited to algorithms, formulas, trained model weights, parameters, configurations, checkpoints, and any related materials (including documentation).
3 |
4 | This Apple Machine Learning Research Model is provided to You by Apple in consideration of your agreement to the following terms, and your use, modification, creation of Model Derivatives, and or redistribution of the Apple Machine Learning Research Model constitutes acceptance of this Agreement. If You do not agree with these terms, please do not use, modify, create Model Derivatives of, or distribute this Apple Machine Learning Research Model or Model Derivatives.
5 |
6 | 1. License Scope: In consideration of your agreement to abide by the following terms, and subject to these terms, Apple hereby grants you a personal, non-exclusive, worldwide, non-transferable, royalty-free, revocable, and limited license, to use, copy, modify, distribute, and create Model Derivatives (defined below) of the Apple Machine Learning Research Model exclusively for Research Purposes. You agree that any Model Derivatives You may create or that may be created for You will be limited to Research Purposes as well. “Research Purposes” means non-commercial scientific research and academic development activities, such as experimentation, analysis, testing conducted by You with the sole intent to advance scientific knowledge and research. “Research Purposes” does not include any commercial exploitation, product development or use in any commercial product or service.
7 |
8 | 2. Distribution of Apple Machine Learning Research Model and Model Derivatives: If you choose to redistribute Apple Machine Learning Research Model or its Model Derivatives, you must provide a copy of this Agreement to such third party, and ensure that the following attribution notice be provided: “Apple Machine Learning Research Model is licensed under the Apple Machine Learning Research Model License Agreement.” Additionally, all Model Derivatives must clearly be identified as such, including disclosure of modifications and changes made to the Apple Machine Learning Research Model. The name, trademarks, service marks or logos of Apple may not be used
9 | to endorse or promote Model Derivatives or the relationship between You and Apple. “Model Derivatives” means any models or any other artifacts created by modifications, improvements, adaptations, alterations to the architecture,
10 | algorithm or training processes of the Apple Machine Learning Research Model, or by any retraining, fine-tuning of the Apple Machine Learning Research Model.
11 |
12 | 3. No Other License: Except as expressly stated in this notice, no other rights or licenses, express or implied, are granted by Apple herein, including but not limited to any patent, trademark, and similar intellectual property rights worldwide that may be infringed by the Apple Machine Learning Research Model, the Model Derivatives or by other works in which the Apple Machine Learning Research Model may be incorporated.
13 |
14 | 4. Compliance with Laws: Your use of Apple Machine Learning Research Model must be in compliance with all applicable laws and regulations.
15 |
16 | 5. Term and Termination: The term of this Agreement will begin upon your acceptance of this Agreement or use of the Apple Machine Learning Research Model and will continue until terminated in accordance with the following terms. Apple may terminate this Agreement at any time if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You must cease to use all Apple Machine Learning Research Models and Model Derivatives and permanently delete any copy thereof. Sections 3, 6 and 7 will survive termination.
17 |
18 | 6. Disclaimer and Limitation of Liability: This Apple Machine Learning Research Model and any outputs generated by the Apple Machine Learning Research Model are provided on an “AS IS” basis. APPLE MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, REGARDING THE APPLE MACHINE LEARNING RESEARCH MODEL OR OUTPUTS GENERATED BY THE APPLE MACHINE LEARNING RESEARCH MODEL. You are solely responsible for determining the appropriateness of using or redistributing the Apple Machine Learning Research Model and any outputs of the Apple Machine Learning Research Model and assume any risks associated with Your use of the Apple Machine Learning Research Model and any output and results. IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF THE APPLE MACHINE LEARNING RESEARCH MODEL AND ANY OUTPUTS OF THE APPLE MACHINE LEARNING RESEARCH MODEL, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
19 |
20 | 7. Governing Law: This Agreement will be governed by and construed under the laws of the State of California without regard to its choice of law principles. The Convention on Contracts for the International Sale of Goods shall not apply to the Agreement except that the arbitration clause and any arbitration hereunder shall be governed by the Federal Arbitration Act, Chapters 1 and 2.
21 |
22 | Copyright (c) 2025 Apple Inc. All Rights Reserved.
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CoMotion: Concurrent Multi-person 3D Motion
2 |
3 | This software project accompanies the research paper:
4 | **[CoMotion: Concurrent Multi-person 3D Motion](https://openreview.net/forum?id=qKu6KWPgxt)**,
5 | _Alejandro Newell, Peiyun Hu, Lahav Lipson, Stephan R. Richter, and Vladlen Koltun_.
6 |
7 |
12 |
13 |
18 |
19 | We introduce CoMotion, an approach for detecting and tracking detailed 3D poses of multiple people from a single monocular camera stream. Our system maintains temporally coherent predictions in crowded scenes filled with difficult poses and occlusions. Our model performs both strong per-frame detection and a learned pose update to track people from frame to frame. Rather than match detections across time, poses are updated directly from a new input image, which enables online tracking through occlusion.
20 |
21 | The code in this directory provides helper functions and scripts for inference and visualization.
22 |
23 | ## Getting Started
24 |
25 | ### Installation
26 |
27 | ```bash
28 | conda create -n comotion -y python=3.10
29 | conda activate comotion
30 | pip install -e '.[all]'
31 | ```
32 |
33 | ### Download models
34 |
35 | To download pretrained checkpoints, run:
36 |
37 | ```bash
38 | bash get_pretrained_models.sh
39 | ```
40 |
41 | Checkpoint data will be downloaded to `src/comotion_demo/data`. You will find pretrained weights for the detection stage, which includes the main vision backbone (`comotion_detection_checkpoint.pt`), as well as a separate checkpoint for the update stage (`comotion_refine_checkpoint.pt`). You can use the detection stage standalone for single-image multi-person pose estimation.
42 |
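If you only need per-frame detections without tracking, the detection stage can be driven directly from Python. Below is a minimal sketch condensed from `run_detection` in `demo.py`; the image path and thresholds are illustrative.

```python
import numpy as np
import torch
from PIL import Image

from comotion_demo.models import comotion
from comotion_demo.utils import dataloading

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Load the full model; only the detection stage is used below.
model = comotion.CoMotion(use_coreml=False)  # use_coreml=True to use the CoreML detection model on macOS
model.to(device).eval()

# Prepare a single image and default camera intrinsics.
image = dataloading.convert_image_to_tensor(np.array(Image.open("path/to/image.jpg")))
K = dataloading.get_default_K(image)
cropped_image, cropped_K = dataloading.prepare_network_inputs(image, K, device)

# Run detection and decode per-person SMPL parameters.
detections = model.detection_model(cropped_image, cropped_K)
detections = comotion.detect.decode_network_outputs(
    K.to(device),
    model.smpl_decoder,
    detections,
    std=0.15,       # NMS sensitivity
    conf_thr=0.25,  # confidence cutoff
)
```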
43 | For macOS, we provide a pre-compiled `coreml` version of the detection stage of the model, which offers significant speedups when running locally on a personal device.
44 |
45 | ### Download SMPL body model
46 |
47 | In order to run CoMotion and the corresponding visualization, the neutral SMPL body model is required. Please go to the [SMPL website](https://smpl.is.tue.mpg.de/) and follow the provided instructions to download the model (version 1.1.0). After downloading, copy `basicmodel_neutral_lbs_10_207_0_v1.1.0.pkl` to `src/comotion_demo/data/smpl/SMPL_NEUTRAL.pkl` (we rename the file to be compatible with the visualization library `aitviewer`).
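Equivalently, the copy can be done from Python; the source path below is a placeholder for wherever you saved the download.

```python
import shutil
from pathlib import Path

src = Path("path/to/basicmodel_neutral_lbs_10_207_0_v1.1.0.pkl")  # downloaded SMPL 1.1.0 neutral model
dst = Path("src/comotion_demo/data/smpl/SMPL_NEUTRAL.pkl")        # name expected by aitviewer
dst.parent.mkdir(parents=True, exist_ok=True)
shutil.copy(src, dst)
```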
48 |
49 | ### Run CoMotion
50 |
51 | We provide a demo script that takes either a video file or a directory of images as input. To run it, call:
52 |
53 | ```bash
54 | python demo.py -i path/to/video.mp4 -o results/
55 | ```
56 |
57 | Optional arguments include `--start-frame` and `--num-frames` to select a subset of the video to run on. The demo will save a `.pt` file with all of the detected SMPL pose parameters as well as a rendered `.mp4` with the predictions overlaid on the input video. We also automatically produce a `.txt` file in the `MOT` format with bounding boxes compatible with most standard tracking evaluation code. If you wish to skip the visualization, pass the `--skip-visualization` flag.
58 |
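The saved `.pt` file can be loaded directly for further analysis. A minimal sketch, assuming the input was `video.mp4` and the output directory was `results/`; the keys match what `demo.py` saves:

```python
import torch

# Tracking output from demo.py: one entry per retained detection.
preds = torch.load("results/video.pt", weights_only=False, map_location="cpu")

print(sorted(preds.keys()))     # betas, frame_idx, id, pose, trans
print(preds["id"][:10])         # persistent track ids
print(preds["frame_idx"][:10])  # frame index of each detection
print(preds["pose"].shape)      # SMPL pose parameters per detection
```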
59 | The demo also supports running on a single image; this is inferred automatically when the input path has a `.png`, `.jpeg`, or `.jpg` suffix:
60 |
61 | ```bash
62 | python demo.py -i path/to/image.jpg -o results/
63 | ```
64 |
65 | In this case, we save a `.pt` file with the detected SMPL poses as well as 2D and 3D coordinates and confidences associated with each detection.
66 |
67 | > [!TIP]
68 | >
69 | > - If you encounter an error that `libc++.1.dylib` is not found, resolve it with `conda install libcxx`.
70 | > - For headless rendering on a remote server, you may encounter an error like `XOpenDisplay: cannot open display`. In this case start a virtual display using `Xvfb :0 -screen 0 640x480x24 & export DISPLAY=:0.0`. You may need to install `xvfb` first (`apt install xvfb`).
71 |
72 | ## Citation
73 |
74 | If you find our work useful, please cite the following paper:
75 |
76 | ```bibtex
77 | @inproceedings{newell2025comotion,
78 | title = {CoMotion: Concurrent Multi-person 3D Motion},
79 | author = {Alejandro Newell and Peiyun Hu and Lahav Lipson and Stephan R. Richter and Vladlen Koltun},
80 | booktitle = {International Conference on Learning Representations},
81 | year = {2025},
82 | url = {https://openreview.net/forum?id=qKu6KWPgxt},
83 | }
84 | ```
85 |
86 | ## License
87 |
88 | This sample code is released under the [LICENSE](LICENSE.md) terms.
89 |
90 | The model weights are released under the [MODEL LICENSE](LICENSE_MODEL.md) terms.
91 |
92 | ## Acknowledgements
93 |
94 | Our codebase is built using multiple open-source contributions; please see [Acknowledgements](ACKNOWLEDGEMENTS.md) for more details.
95 |
96 | Please check the paper for a complete list of references and datasets used in this work.
97 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | """Demo CoMotion with a video file or a directory of images."""
3 |
4 | import logging
5 | import os
6 | import shutil
7 | import tempfile
8 | from pathlib import Path
9 |
10 | import click
11 | import numpy as np
12 | import torch
13 | from PIL import Image
14 | from tqdm import tqdm
15 |
16 | from comotion_demo.models import comotion
17 | from comotion_demo.utils import dataloading, helper
18 | from comotion_demo.utils import track as track_utils
19 |
20 | try:
21 | from aitviewer.configuration import CONFIG
22 | from aitviewer.headless import HeadlessRenderer
23 | from aitviewer.renderables.billboard import Billboard
24 | from aitviewer.renderables.smpl import SMPLLayer, SMPLSequence
25 | from aitviewer.scene.camera import OpenCVCamera
26 |
27 | comotion_model_dir = Path(comotion.__file__).parent
28 | CONFIG.smplx_models = os.path.join(comotion_model_dir, "../data")
29 | CONFIG.window_type = "pyqt6"
30 | aitviewer_available = True
31 |
32 | except ModuleNotFoundError:
33 | print(
34 | "WARNING: Skipped aitviewer import, ensure it is installed to run visualization."
35 | )
36 | aitviewer_available = False
37 |
38 |
39 | logging.basicConfig(
40 | level=logging.INFO,
41 | format="%(asctime)s - %(levelname)s - %(funcName)s - %(message)s",
42 | datefmt="%Y-%m-%d %H:%M:%S",
43 | )
44 |
45 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
46 | use_mps = torch.mps.is_available()
47 |
48 |
49 | def prepare_scene(viewer, width, height, K, image_paths, fps=30):
50 | """Prepare the scene for AITViewer rendering."""
51 | viewer.reset()
52 | viewer.scene.floor.enabled = False
53 | viewer.scene.origin.enabled = False
54 | extrinsics = np.eye(4)[:3]
55 |
56 | # Initialize camera
57 | cam = OpenCVCamera(K, extrinsics, cols=width, rows=height, viewer=viewer)
58 | viewer.scene.add(cam)
59 | viewer.scene.camera.position = [0, 0, -5]
60 | viewer.scene.camera.target = [0, 0, 10]
61 | viewer.auto_set_camera_target = False
62 | viewer.set_temp_camera(cam)
63 | viewer.playback_fps = fps
64 |
65 | # "billboard" display for video frames
66 | billboard = Billboard.from_camera_and_distance(
67 | cam, 100.0, cols=width, rows=height, textures=image_paths
68 | )
69 | viewer.scene.add(billboard)
70 |
71 |
72 | def add_pose_to_scene(
73 | viewer,
74 | smpl_layer,
75 | betas,
76 | pose,
77 | trans,
78 | color=(0.6, 0.6, 0.6),
79 | alpha=1,
80 | color_ref=None,
81 | ):
82 | """Add estimated poses to the rendered scene."""
83 | if betas.ndim == 2:
84 | betas = betas[None]
85 | pose = pose[None]
86 | trans = trans[None]
87 |
88 | poses_root = pose[..., :3]
89 | poses_body = pose[..., 3:]
90 | max_people = pose.shape[1]
91 |
92 | if (betas != 0).any():
93 | for person_idx in range(max_people):
94 | if color_ref is None:
95 | person_color = color
96 | else:
97 | person_color = color_ref[person_idx % len(color_ref)] * 0.4 + 0.3
98 | person_color = [c_ for c_ in person_color] + [alpha]
99 |
100 | valid_vals = (betas[:, person_idx] != 0).any(-1)
101 | idx_range = valid_vals.nonzero()
102 | if len(idx_range) > 0:
103 | trans[~valid_vals][..., 2] = -10000
104 | viewer.scene.add(
105 | SMPLSequence(
106 | smpl_layer=smpl_layer,
107 | betas=betas[:, person_idx],
108 | poses_root=poses_root[:, person_idx],
109 | poses_body=poses_body[:, person_idx],
110 | trans=trans[:, person_idx],
111 | color=person_color,
112 | )
113 | )
114 |
115 |
116 | def visualize_poses(
117 | input_path,
118 | cache_path,
119 | video_path,
120 | start_frame,
121 | num_frames,
122 | frameskip=1,
123 | color=(0.6, 0.6, 0.6),
124 | alpha=1,
125 | fps=30,
126 | ):
127 | """Visualize SMPL poses."""
128 | logging.info(f"Rendering SMPL video: {input_path}")
129 |
130 | # Prepare temporary directory with saved images
131 | tmp_vis_dir = Path(tempfile.mkdtemp())
132 |
133 | frame_idx = 0
134 | image_paths = []
135 | for image, K in dataloading.yield_image_and_K(
136 | input_path, start_frame, num_frames, frameskip
137 | ):
138 | image_height, image_width = image.shape[-2:]
139 | image = dataloading.convert_tensor_to_image(image)
140 | image_paths.append(f"{tmp_vis_dir}/{frame_idx:06d}.jpg")
141 | Image.fromarray(image).save(image_paths[-1])
142 | frame_idx += 1
143 |
144 | # Initialize viewer
145 | viewer = HeadlessRenderer(size=(image_width, image_height))
146 |
147 | if dataloading.is_a_video(input_path):
148 | fps = int(dataloading.get_input_video_fps(input_path))
149 |
150 | prepare_scene(viewer, image_width, image_height, K.cpu().numpy(), image_paths, fps)
151 |
152 | # Prepare SMPL poses
153 | smpl_layer = SMPLLayer(model_type="smpl", gender="neutral")
154 | if not cache_path.exists():
155 | logging.warning("No detections found.")
156 | else:
157 | preds = torch.load(cache_path, weights_only=False, map_location="cpu")
158 | track_subset = track_utils.query_range(preds, 0, frame_idx - 1)
159 | id_lookup = track_subset["id"].max(0)[0]
160 | color_ref = helper.color_ref[id_lookup % len(helper.color_ref)]
161 | if len(id_lookup) == 1:
162 | color_ref = [color_ref]
163 |
164 | betas = track_subset["betas"]
165 | pose = track_subset["pose"]
166 | trans = track_subset["trans"]
167 |
168 | add_pose_to_scene(
169 | viewer, smpl_layer, betas, pose, trans, color, alpha, color_ref
170 | )
171 |
172 | # Save rendered scene
173 | viewer.save_video(
174 | video_dir=str(video_path),
175 | output_fps=fps,
176 | ensure_no_overwrite=False,
177 | )
178 |
179 | # Remove temporary directory
180 | shutil.rmtree(tmp_vis_dir)
181 |
182 |
183 | def run_detection(input_path, cache_path, skip_visualization=False, model=None):
184 | """Run model and visualize detections on single image."""
185 | if model is None:
186 | model = comotion.CoMotion(use_coreml=use_mps)
187 | model.to(device).eval()
188 |
189 | # Load image
190 | image = np.array(Image.open(input_path))
191 | image = dataloading.convert_image_to_tensor(image)
192 | K = dataloading.get_default_K(image)
193 | cropped_image, cropped_K = dataloading.prepare_network_inputs(image, K, device)
194 |
195 | # Get detections
196 | detections = model.detection_model(cropped_image, cropped_K)
197 | detections = comotion.detect.decode_network_outputs(
198 | K.to(device),
199 | model.smpl_decoder,
200 | detections,
201 | std=0.15, # Adjust NMS sensitivity
202 | conf_thr=0.25, # Adjust confidence cutoff
203 | )
204 |
205 | detections = {k: v[0].cpu() for k, v in detections.items()}
206 | torch.save(detections, cache_path)
207 |
208 | if not skip_visualization:
209 | # Initialize viewer
210 | image_height, image_width = image.shape[-2:]
211 | viewer = HeadlessRenderer(size=(image_width, image_height))
212 | prepare_scene(
213 | viewer, image_width, image_height, K.cpu().numpy(), [str(input_path)]
214 | )
215 |
216 | # Prepare SMPL poses
217 | smpl_layer = SMPLLayer(model_type="smpl", gender="neutral")
218 | add_pose_to_scene(
219 | viewer,
220 | smpl_layer,
221 | detections["betas"],
222 | detections["pose"],
223 | detections["trans"],
224 | )
225 |
226 | # Save rendered scene
227 | viewer.save_frame(str(cache_path).replace(".pt", ".png"))
228 |
229 |
230 | def track_poses(
231 | input_path, cache_path, start_frame, num_frames, frameskip=1, model=None
232 | ):
233 | """Track poses over a video or a directory of images."""
234 | if model is None:
235 | model = comotion.CoMotion(use_coreml=use_mps)
236 | model.to(device).eval()
237 |
238 | detections = []
239 | tracks = []
240 |
241 | initialized = False
242 | for image, K in tqdm(
243 | dataloading.yield_image_and_K(input_path, start_frame, num_frames, frameskip),
244 | desc="Running CoMotion",
245 | ):
246 | if not initialized:
247 | image_res = image.shape[-2:]
248 | model.init_tracks(image_res)
249 | initialized = True
250 |
251 | detection, track = model(image, K, use_mps=use_mps)
252 | detection = {k: v.cpu() for k, v in detection.items()}
253 | track = track.cpu()
254 | detections.append(detection)
255 | tracks.append(track)
256 |
257 | detections = {k: [d[k] for d in detections] for k in detections[0].keys()}
258 | tracks = torch.stack(tracks, 1)
259 | tracks = {k: getattr(tracks, k) for k in ["id", "pose", "trans", "betas"]}
260 |
261 | track_ref = track_utils.cleanup_tracks(
262 | {"detections": detections, "tracks": tracks},
263 | K,
264 | model.smpl_decoder.cpu(),
265 | min_matched_frames=1,
266 | )
267 | if track_ref:
268 | frame_idxs, track_idxs = track_utils.convert_to_idxs(
269 | track_ref, tracks["id"][0].squeeze(-1).long()
270 | )
271 | preds = {k: v[0, frame_idxs, track_idxs] for k, v in tracks.items()}
272 | preds["id"] = preds["id"].squeeze(-1).long()
273 | preds["frame_idx"] = frame_idxs
274 | torch.save(preds, cache_path)
275 |
276 | # Save bounding box tracks in MOT format
277 | bboxes = track_utils.bboxes_from_smpl(
278 | model.smpl_decoder,
279 | {k: preds[k] for k in ["betas", "pose", "trans"]},
280 | image_res,
281 | K,
282 | )
283 | with open(str(cache_path).replace(".pt", ".txt"), "w") as f:
284 | f.write(track_utils.convert_to_mot(preds["id"], preds["frame_idx"], bboxes))
285 |
286 |
287 | @click.command()
288 | @click.option(
289 | "-i",
290 | "--input-path",
291 | required=True,
292 | type=click.Path(exists=True, path_type=Path),
293 | help="Path to the input video, a directory of images, or a single input image.",
294 | )
295 | @click.option(
296 | "-o",
297 | "--output-dir",
298 | required=True,
299 | type=click.Path(exists=False, path_type=Path),
300 | help="Path to the output directory.",
301 | )
302 | @click.option(
303 | "-s",
304 | "--start-frame",
305 | default=0,
306 | type=int,
307 | help="Frame to start with.",
308 | )
309 | @click.option(
310 | "-n",
311 | "--num-frames",
312 | default=1_000_000_000,
313 | type=int,
314 | help="Number of frames to process.",
315 | )
316 | @click.option(
317 | "--skip-visualization",
318 | is_flag=True,
319 | help="Whether to skip rendering the output SMPL meshes.",
320 | )
321 | @click.option(
322 | "--frameskip",
323 | default=1,
324 | type=int,
325 | help="Subsample video frames (e.g. frameskip=2 processes every other frame).",
326 | )
327 | def main(
328 | input_path, output_dir, start_frame, num_frames, skip_visualization, frameskip
329 | ):
330 | """Demo entry point."""
331 | output_dir.mkdir(parents=True, exist_ok=True)
332 | input_name = input_path.stem
333 | skip_visualization = skip_visualization | (not aitviewer_available)
334 |
335 | cache_path = output_dir / f"{input_name}.pt"
336 | if input_path.suffix.lower() in dataloading.IMAGE_EXTENSIONS:
337 | # Run and visualize detections for a single image
338 | run_detection(input_path, cache_path, skip_visualization)
339 | else:
340 | # Run unrolled tracking on a full video
341 | track_poses(input_path, cache_path, start_frame, num_frames, frameskip)
342 | if not skip_visualization:
343 | video_path = output_dir / f"{input_name}.mp4"
344 | visualize_poses(
345 | input_path, cache_path, video_path, start_frame, num_frames, frameskip
346 | )
347 |
348 |
349 | if __name__ == "__main__":
350 | main()
351 |
--------------------------------------------------------------------------------
/get_pretrained_models.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | #
3 | # For licensing see accompanying LICENSE file.
4 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
5 |
6 | cd src/comotion_demo/data
7 |
8 | # Download and extract model checkpoints
9 | wget https://ml-site.cdn-apple.com/models/comotion/demo_checkpoints.tar.gz
10 | tar zxf demo_checkpoints.tar.gz
11 | rm demo_checkpoints.tar.gz
12 |
13 | cd ../../..
14 |
--------------------------------------------------------------------------------
/posetrack21_eval/README.md:
--------------------------------------------------------------------------------
1 | # PoseTrack21 MOT Evaluation
2 |
3 | This is a slight adaptation of the [PoseTrack21 evaluation code](https://github.com/anDoer/PoseTrack21). We opt to use the bounding-box MOT evaluation here since our model does not output the appropriate format for the keypoint-based metrics. We make several adjustments to the MOT evaluation code.
4 |
5 | Major changes that affect tracking evaluation:
6 | - as discussed in the paper, we change the IOU calculation for ignore regions. This is a one-line change:
7 | ```python
8 | # previous
9 | region_ious[j, i] = poly_intersection / poly_union
10 | # ours
11 | region_ious[j, i] = poly_intersection / det_boxes[j].area
12 | ```
13 | - some of the PoseTrack sequences are not annotated at every frame, but the code by default would penalize false positives on un-annotated frames. We have updated the code to ignore detections on frames with no ground-truth annotations (see the sketch after these lists).
14 |
15 | Minor updates:
16 | - we no longer load images in the `PTSequence` class
17 | - we update the dtype behavior in `motmetrics` to support more recent versions of `numpy`
18 | - we turn on `use_ignore_regions` by default (there is no runtime flag to set)
19 |
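For reference, the un-annotated-frame handling mentioned above boils down to accumulating only frames that carry ground-truth boxes (the real check lives in `get_mot_accum` in `evaluate_mot.py`). A small self-contained sketch with stand-in data:

```python
# Stand-in per-frame data mimicking a PTSequence: "gt" is empty on frames
# that carry no ground-truth annotations.
seq_data = [
    {"gt": {0: [10, 10, 50, 80]}},  # annotated frame
    {"gt": {}},                     # un-annotated frame: detections here are ignored
    {"gt": {0: [12, 11, 52, 82]}},  # annotated frame
]

valid_frames = [i for i, d in enumerate(seq_data) if len(d["gt"]) > 0]

for i, data in enumerate(seq_data):
    if i not in valid_frames:
        continue  # skip frames without annotations entirely
    print(f"frame {i}: evaluating against {len(data['gt'])} ground-truth boxes")
```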
20 | ## Setup
21 |
22 | A few extra packages are needed to run the evaluation code; they can be installed with:
23 |
24 | ```
25 | conda install geos
26 | pip install lap pandas shapely==1.7.1 xmltodict
27 | ```
28 |
29 | Follow the instructions provided by the authors in the [original repository](https://github.com/anDoer/PoseTrack21) to obtain a copy of the dataset and annotations.
30 |
31 | ## Running the evaluation
32 |
33 | To run:
34 | ```
35 | python evaluate_mot.py --dataset_path $PATH_TO_DATASET_ROOT \
36 |     --mot_path $PATH_TO_RESPECTIVE_MOT_FOLDER \
37 |     --result_path $FOLDER_WITH_YOUR_RESULTS
38 | ```
39 |
40 | We also support evaluation on individual sequences with:
41 | ```
42 | python evaluate_mot.py ... --sequence_choice [SEQUENCE_ID_0] [SEQUENCE_ID_1] ...
43 | ```
44 | where `SEQUENCE_ID` is the integer identifier of a given PoseTrack sequence (e.g. `000342`); there is no need to include the `_mpii_test` suffix.
45 |
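Result files are plain MOT-style CSV, one file per sequence and named after the sequence (an extra `.txt` suffix is also accepted). A sketch of how each row is interpreted, following `PTSequence.load_results`; the path is illustrative, frame and track ids are stored 1-based, and boxes are given as `x, y, width, height`:

```python
import csv

results = {}  # track_id -> {frame_id: [x1, y1, x2, y2]}
with open("results/000342_mpii_test.txt") as f:  # illustrative path
    for row in csv.reader(f):
        frame_id, track_id = int(row[0]) - 1, int(row[1]) - 1  # stored 1-based
        x1, y1 = float(row[2]) - 1, float(row[3]) - 1
        x2 = x1 + float(row[4]) - 1  # row[4], row[5] are width and height
        y2 = y1 + float(row[5]) - 1
        results.setdefault(track_id, {})[frame_id] = [x1, y1, x2, y2]
```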
--------------------------------------------------------------------------------
/posetrack21_eval/datasets/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/posetrack21_eval/datasets/__init__.py
--------------------------------------------------------------------------------
/posetrack21_eval/datasets/pt_sequence.py:
--------------------------------------------------------------------------------
1 | import configparser
2 | import csv
3 | import os
4 | import os.path as osp
5 |
6 | import numpy as np
7 | from PIL import Image
8 |
9 | import json
10 |
11 | from shapely.geometry import box, Polygon, MultiPolygon
12 |
13 | def ignore_regions():
14 |     # Remove detection boxes that fall in ignore regions. Note: this reference snippet relies on variables (ignore_x, ignore_y, det, num_det, gt_boxes, ...) defined in the original evaluation context.
15 | ignore_box_candidates = set()
16 | ##########################
17 | ignore_iou_thres = 0.1
18 | ##########################
19 |
20 | if len(ignore_x) > 0:
21 | if not isinstance(ignore_x[0], list):
22 | ignore_x = [ignore_x]
23 | ignore_y = [ignore_y]
24 |
25 | # build ignore regions:
26 | ignore_regions = []
27 | for r_idx in range(len(ignore_x)):
28 | region = []
29 |
30 | for x, y in zip(ignore_x[r_idx], ignore_y[r_idx]):
31 | region.append([x, y])
32 |
33 | ignore_region = Polygon(region)
34 | ignore_regions.append(ignore_region)
35 |
36 | region_ious = np.zeros((len(ignore_regions), num_det), dtype=np.float32)
37 | det_boxes = []
38 |     for j in range(num_det):
39 | x1 = det[j, 0]
40 | y1 = det[j, 1]
41 | x2 = det[j, 2]
42 | y2 = det[j, 3]
43 |
44 | box_poly = Polygon([
45 | [x1, y1],
46 | [x2, y1],
47 | [x2, y2],
48 | [x1, y2],
49 | [x1, y1],
50 | ])
51 |
52 | det_boxes.append(box_poly)
53 |
54 |         for i in range(len(ignore_regions)):
55 |             for j in range(num_det):
56 | if ignore_regions[i].is_valid:
57 | poly_intersection = ignore_regions[i].intersection(det_boxes[j]).area
58 | poly_union = ignore_regions[i].union(det_boxes[j]).area
59 | else:
60 | multi_poly = ignore_regions[i].buffer(0)
61 | poly_intersection = 0
62 | poly_union = 0
63 |
64 | if isinstance(multi_poly, Polygon):
65 | poly_intersection = multi_poly.intersection(det_boxes[j]).area
66 | poly_union = multi_poly.union(det_boxes[j]).area
67 |                     else:
68 |                         # buffered geometry is not a single Polygon: accumulate over its parts
69 |                         for poly in multi_poly:
70 |                             poly_intersection += poly.intersection(det_boxes[j]).area
71 |                             poly_union += poly.union(det_boxes[j]).area
79 |
80 | region_ious[i, j] = poly_intersection / poly_union
81 |
82 | candidates = np.argwhere(region_ious[i] > ignore_iou_thres)
83 | if len(candidates) > 0:
84 | candidates = candidates[:, 0].tolist()
85 | ignore_box_candidates.update(candidates)
86 |
87 | ious = np.zeros((num_gt, num_det), dtype=np.float32)
88 |     for i in range(num_gt):
89 |         for j in range(num_det):
90 | ious[i, j] = _compute_iou(gt_boxes[i], det[j, :4])
91 | tfmat = (ious >= iou_thresh)
92 | # for each det, keep only the largest iou of all the gt
93 |     for j in range(num_det):
94 | largest_ind = np.argmax(ious[:, j])
95 |         for i in range(num_gt):
96 | if i != largest_ind:
97 | tfmat[i, j] = False
98 | # for each gt, keep only the largest iou of all the det
99 |     for i in range(num_gt):
100 | largest_ind = np.argmax(ious[i, :])
101 |         for j in range(num_det):
102 | if j != largest_ind:
103 | tfmat[i, j] = False
104 |     for j in range(num_det):
105 | if j in ignore_box_candidates and not tfmat[:, j].any():
106 | # we have a detection in ignore region
107 | detections_to_ignore[gallery_idx].append(j)
108 | continue
109 |
110 | y_score.append(det[j, -1])
111 | if tfmat[:, j].any():
112 | y_true.append(True)
113 | else:
114 | y_true.append(False)
115 | count_tp += tfmat.sum()
116 | count_gt += num_gt
117 |
118 |
119 | class PTSequence():
120 | """Multiple Object Tracking Dataset.
121 | """
122 | def __init__(self, seq_name, mot_dir, dataset_path, vis_threshold=0.0):
123 | self._seq_name = seq_name
124 | self._vis_threshold = vis_threshold
125 |
126 | self._mot_dir = mot_dir
127 | self.dataset_path = dataset_path
128 | self._folders = os.listdir(self._mot_dir)
129 |
130 | if seq_name is not None:
131 | assert seq_name in self._folders, \
132 | 'Image set does not exist: {}'.format(seq_name)
133 |
134 | self.data, self.no_gt, self.ignore = self._sequence()
135 | else:
136 | self.data = []
137 | self.no_gt = False
138 |
139 | def __len__(self):
140 | return len(self.data)
141 |
142 | def __getitem__(self, idx):
143 | """Return the ith image converted to blob"""
144 | data = self.data[idx]
145 | # img = Image.open(data['im_path']).convert("RGB")
146 | # img = np.asarray(img)
147 |
148 | sample = {}
149 | # sample['img'] = img
150 | sample['dets'] = np.array([det[:4] for det in data['dets']])
151 | sample['img_path'] = data['im_path']
152 | sample['gt'] = data['gt']
153 | sample['vis'] = data['vis']
154 |
155 | return sample
156 |
157 | def _sequence(self):
158 | seq_name = self._seq_name
159 | seq_path = osp.join(self._mot_dir, seq_name)
160 |
161 | config_file = os.path.join(seq_path, 'image_info.json')
162 |
163 | assert osp.exists(config_file), \
164 | 'Config file does not exist: {}'.format(config_file)
165 |
166 | with open(config_file, 'r') as file:
167 | image_info = json.load(file)
168 |
169 | seqLength = len(image_info)
170 | gt_file = osp.join(seq_path, 'gt', 'gt.txt')
171 |
172 | total = []
173 |
174 | visibility = {}
175 | boxes = {}
176 | dets = {}
177 | ignore_x = {}
178 | ignore_y = {}
179 |
180 | for i in range(seqLength):
181 | frame_index = image_info[i]['frame_index']
182 | boxes[frame_index] = {}
183 | visibility[frame_index] = {}
184 | dets[frame_index] = []
185 | ignore_x[frame_index] = image_info[i]['ignore_regions_x']
186 | ignore_y[frame_index] = image_info[i]['ignore_regions_y']
187 |
188 | no_gt = False
189 | if osp.exists(gt_file):
190 | with open(gt_file, "r") as inf:
191 | reader = csv.reader(inf, delimiter=',')
192 | for row in reader:
193 |                     # class person, certainty 1, visibility >= 0.25
194 | row[7] = '1'
195 | row[8] = '1'
196 | if int(row[6]) == 1 and int(row[7]) == 1 and float(row[8]) >= self._vis_threshold:
197 | # Make pixel indexes 0-based, should already be 0-based (or not)
198 | x1 = int(float(row[2])) # - 1
199 | y1 = int(float(row[3])) # - 1
200 | # This -1 accounts for the width (width of 1 x1=x2)
201 | x2 = x1 + int(float(row[4])) #- 1
202 | y2 = y1 + int(float(row[5])) # - 1
203 | bb = np.array([x1,y1,x2,y2], dtype=np.float32)
204 |
205 | boxes[int(row[0])][int(row[1])] = bb
206 | visibility[int(row[0])][int(row[1])] = float(row[8])
207 | else:
208 | no_gt = True
209 |
210 | det_file = osp.join(seq_path, 'det', 'det.txt')
211 |
212 | if osp.exists(det_file):
213 | with open(det_file, "r") as inf:
214 | reader = csv.reader(inf, delimiter=',')
215 | for row in reader:
216 | x1 = float(row[2]) - 1
217 | y1 = float(row[3]) - 1
218 | # This -1 accounts for the width (width of 1 x1=x2)
219 | x2 = x1 + float(row[4]) - 1
220 | y2 = y1 + float(row[5]) - 1
221 | score = float(row[6])
222 | bb = np.array([x1,y1,x2,y2, score], dtype=np.float32)
223 | dets[int(float(row[0]))].append(bb)
224 |
225 | dataset_path = self.dataset_path
226 | for i in range(seqLength):
227 | im_path = os.path.join(dataset_path, image_info[i]['file_name'])
228 | frame_index = image_info[i]['frame_index']
229 |
230 | sample = {'gt':boxes[frame_index],
231 | 'im_path':im_path,
232 | 'vis':visibility[frame_index],
233 | 'dets':dets[frame_index]}
234 |
235 | total.append(sample)
236 |
237 | return total, no_gt, {'ignore_x': ignore_x, 'ignore_y': ignore_y}
238 |
239 | def __str__(self):
240 | return self._seq_name
241 |
242 | def write_results(self, all_tracks, output_dir):
243 | """Write the tracks in the format for MOT16/MOT17 sumbission
244 |
245 | all_tracks: dictionary with 1 dictionary for every track with {..., i:np.array([x1,y1,x2,y2]), ...} at key track_num
246 |
247 | Each file contains these lines:
248 |         <frame>, <id>, <bb_left>, <bb_top>, <bb_width>, <bb_height>, <conf>, <x>, <y>, <z>
249 | """
250 |
251 | #format_str = "{}, -1, {}, {}, {}, {}, {}, -1, -1, -1"
252 |
253 | if not os.path.exists(output_dir):
254 | os.makedirs(output_dir)
255 |
256 | with open(osp.join(output_dir, self._seq_name), "w") as of:
257 | writer = csv.writer(of, delimiter=',')
258 | for i, track in all_tracks.items():
259 | for frame, bb in track.items():
260 | x1 = bb[0]
261 | y1 = bb[1]
262 | x2 = bb[2]
263 | y2 = bb[3]
264 | writer.writerow(
265 | [frame + 1,
266 | i + 1,
267 | x1 + 1,
268 | y1 + 1,
269 | x2 - x1 + 1,
270 | y2 - y1 + 1,
271 | -1, -1, -1, -1])
272 |
273 | def load_results(self, output_dir):
274 | file_path = osp.join(output_dir, self._seq_name)
275 | results = {}
276 |
277 | if not os.path.isfile(file_path):
278 | if os.path.isfile(f'{file_path}.txt'):
279 | file_path = f'{file_path}.txt'
280 | else:
281 | return results
282 |
283 | with open(file_path, "r") as of:
284 | csv_reader = csv.reader(of, delimiter=',')
285 | for row in csv_reader:
286 | frame_id, track_id = int(row[0]) - 1, int(row[1]) - 1
287 |
288 |                 if track_id not in results:
289 | results[track_id] = {}
290 |
291 | x1 = float(row[2]) - 1
292 | y1 = float(row[3]) - 1
293 | x2 = float(row[4]) - 1 + x1
294 | y2 = float(row[5]) - 1 + y1
295 |
296 | results[track_id][frame_id] = [x1, y1, x2, y2]
297 |
298 | return results
299 |
--------------------------------------------------------------------------------
/posetrack21_eval/datasets/pt_warper.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import os
3 | #import torch
4 | #from torch.utils.data import Dataset
5 |
6 | from .pt_sequence import PTSequence
7 |
8 |
9 | class PTWrapper():
10 | """A Wrapper for the MOT_Sequence class to return multiple sequences."""
11 |
12 | def __init__(self, mot_dir, dataset_path, vis_threshold=0.0):
13 | 'seq_name, mot_dir, dataset_path, vis_threshold=0.0'
14 |
15 | sequences = os.listdir(mot_dir)
16 |
17 | self._data = []
18 | for s in sequences:
19 | self._data.append(PTSequence(s, mot_dir, dataset_path, vis_threshold))
20 |
21 | def __len__(self):
22 | return len(self._data)
23 |
24 | def __getitem__(self, idx):
25 | return self._data[idx]
26 |
--------------------------------------------------------------------------------
/posetrack21_eval/evaluate_mot.py:
--------------------------------------------------------------------------------
1 | if __name__ == "__main__":
2 | from datasets.pt_warper import PTWrapper
3 | import motmetrics as mm
4 | else:
5 | from .datasets.pt_warper import PTWrapper
6 | from . import motmetrics as mm
7 |
8 | import numpy as np
9 | mm.lap.default_solver = 'lap'
10 | from shapely.geometry import box, Polygon, MultiPolygon
11 | import argparse
12 | import os
13 | from tqdm import tqdm
14 |
15 | def get_mot_accum(results, seq, ignore_iou_thres=0.1, use_ignore_regions=False):
16 | def get_frame_idx(impath):
17 | return int(impath.split("/")[-1].replace(".jpg", ""))
18 |
19 | valid_frames = [get_frame_idx(d["im_path"]) for d in seq.data if len(d["gt"]) > 0]
20 |
21 | mot_accum = mm.MOTAccumulator(auto_id=True)
22 |
23 | ignore_regions_x = seq.ignore['ignore_x']
24 | ignore_regions_y = seq.ignore['ignore_y']
25 | for i, data in enumerate(seq):
26 | if i in valid_frames:
27 | # i corresponds to image index
28 | ignore_x = ignore_regions_x[i + 1]
29 | ignore_y = ignore_regions_y[i + 1]
30 |
31 | if len(ignore_x) > 0:
32 | if not isinstance(ignore_x[0], list):
33 | ignore_x = [ignore_x]
34 | if len(ignore_y) > 0:
35 | if not isinstance(ignore_y[0], list):
36 | ignore_y = [ignore_y]
37 |
38 | ignore_regions = []
39 | for r_idx in range(len(ignore_x)):
40 | region = []
41 | if len(ignore_x[r_idx]) == 0 or len(ignore_y[r_idx]) == 0:
42 | continue
43 |
44 | for x, y in zip(ignore_x[r_idx], ignore_y[r_idx]):
45 | region.append([x, y])
46 |
47 | try:
48 | ignore_region = Polygon(region)
49 | ignore_regions.append(ignore_region)
50 | except:
51 | assert False
52 |
53 | gt = data['gt']
54 | gt_ids = []
55 | gt_boxes = []
56 |
57 | if gt:
58 | for gt_id, box in gt.items():
59 | gt_ids.append(gt_id)
60 | gt_boxes.append(box)
61 |
62 | gt_boxes = np.stack(gt_boxes, axis=0)
63 | # x1, y1, x2, y2 --> x1, y1, width, height
64 | gt_boxes = np.stack((gt_boxes[:, 0],
65 | gt_boxes[:, 1],
66 | gt_boxes[:, 2] - gt_boxes[:, 0],
67 | gt_boxes[:, 3] - gt_boxes[:, 1]),
68 | axis=1)
69 | else:
70 | gt_boxes = np.array([])
71 |
72 | track_ids = []
73 | track_boxes = []
74 | det_boxes = []
75 | ignore_candidates = []
76 |
77 | for track_id, frames in results.items():
78 | if i in frames:
79 | track_ids.append(track_id)
80 | # frames = x1, y1, x2, y2, score
81 | track_boxes.append(frames[i][:4])
82 | box = frames[i][:4]
83 |
84 | if use_ignore_regions:
85 | x1 = box[0]
86 | y1 = box[1]
87 | x2 = box[2]
88 | y2 = box[3]
89 |
90 | box_poly = Polygon([
91 | [x1, y1],
92 | [x2, y1],
93 | [x2, y2],
94 | [x1, y2],
95 | [x1, y1],
96 | ])
97 |
98 | det_boxes.append(box_poly)
99 |
100 | if use_ignore_regions:
101 | # obtain candidate detections for ignore regions
102 | region_ious = np.zeros((len(det_boxes), len(ignore_regions)), dtype=np.float32)
103 | for i in range(len(ignore_regions)):
104 | for j in range(len(det_boxes)):
105 | if ignore_regions[i].is_valid:
106 | poly_intersection = ignore_regions[i].intersection(det_boxes[j]).area
107 | poly_union = ignore_regions[i].union(det_boxes[j]).area
108 | else:
109 | multi_poly = ignore_regions[i].buffer(0)
110 | poly_intersection = 0
111 | poly_union = 0
112 |
113 | if isinstance(multi_poly, Polygon):
114 | poly_intersection = multi_poly.intersection(det_boxes[j]).area
115 | poly_union = multi_poly.union(det_boxes[j]).area
116 | elif isinstance(multi_poly, MultiPolygon):
117 | poly_intersection = multi_poly.intersection(det_boxes[j]).area
118 | poly_union = multi_poly.union(det_boxes[j]).area
119 | else:
120 | for poly in multi_poly:
121 | poly_intersection += poly.intersection(det_boxes[j]).area
122 | poly_union += poly.union(det_boxes[j]).area
123 |
124 | # For reference: original IOU calculation
125 | # - region_ious[j, i] = poly_intersection / poly_union
126 |
127 | # Updated IOU calculation
128 | region_ious[j, i] = poly_intersection / det_boxes[j].area
129 |
130 | if len(ignore_regions) > 0 and len(det_boxes) > 0:
131 | ignore_candidates = np.where((region_ious > ignore_iou_thres).max(axis=1))[0].tolist()
132 |
133 | if track_ids:
134 | track_boxes = np.stack(track_boxes, axis=0)
135 | # x1, y1, x2, y2 --> x1, y1, width, height
136 | track_boxes = np.stack((track_boxes[:, 0],
137 | track_boxes[:, 1],
138 | track_boxes[:, 2] - track_boxes[:, 0],
139 | track_boxes[:, 3] - track_boxes[:, 1]),
140 | axis=1)
141 | else:
142 | track_boxes = np.array([])
143 |
144 | distance = mm.distances.iou_matrix(gt_boxes, track_boxes, max_iou=0.5)
145 |
146 | mot_accum.update(
147 | gt_ids,
148 | track_ids,
149 | distance,
150 | ignore_candidates=ignore_candidates)
151 |
152 | return mot_accum
153 |
154 |
155 | def evaluate_mot_accums(accums, names, generate_overall=False):
156 | mh = mm.metrics.create()
157 | summary = mh.compute_many(
158 | accums,
159 | metrics=mm.metrics.motchallenge_metrics + ['num_matches', 'num_objects'],
160 | names=names,
161 | generate_overall=generate_overall, )
162 |
163 | str_summary = mm.io.render_summary(
164 | summary,
165 | formatters=mh.formatters,
166 | namemap=mm.io.motchallenge_metric_names, )
167 | print(str_summary)
168 |
169 | return str_summary
170 |
171 | def main():
172 | parser = argparse.ArgumentParser()
173 | parser.add_argument('--mot_path', required=True, help='Path to GT MOT files')
174 | parser.add_argument('--dataset_path', required=True)
175 | parser.add_argument('--result_path', required=True)
176 |     parser.add_argument('--ignore_iou_thres', type=float, default=0.1)
177 | parser.add_argument('--sequence_choice', type=str, nargs="+", default=[])
178 | args = parser.parse_args()
179 |
180 | mot_accums = []
181 |
182 | mot_path = args.mot_path
183 | dataset_path = args.dataset_path
184 | result_path = args.result_path
185 | use_ignore_regions = True
186 | ignore_iou_thres = args.ignore_iou_thres
187 |
188 | dataset = PTWrapper(mot_path, dataset_path, vis_threshold=0.1)
189 | sequence_names = [str(s) for s in dataset if not s.no_gt]
190 | if len(args.sequence_choice) > 0:
191 | sequence_names = [f"{s}_mpii_test" for s in args.sequence_choice]
192 |
193 | for seq_idx, seq in enumerate(tqdm(dataset)):
194 | if seq._seq_name not in sequence_names:
195 | continue
196 |
197 | if not os.path.isdir(result_path):
198 | raise FileNotFoundError(f"result path {result_path} does not exist")
199 | results = seq.load_results(result_path)
200 | if len(results) == 0:
201 | print(f"Results not provided for gt sequence {seq._seq_name}")
202 | mot_accums.append(get_mot_accum(results,
203 | seq,
204 | ignore_iou_thres=ignore_iou_thres,
205 | use_ignore_regions=use_ignore_regions))
206 |
207 | if mot_accums:
208 | evaluate_mot_accums(mot_accums, sequence_names, generate_overall=True)
209 |
210 | if __name__ == '__main__':
211 | main()
212 |
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/__init__.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
9 |
10 | Christoph Heindl, 2017
11 | https://github.com/cheind/py-motmetrics
12 | """
13 |
14 | from __future__ import absolute_import
15 | from __future__ import division
16 | from __future__ import print_function
17 |
18 | __all__ = [
19 | 'distances',
20 | 'io',
21 | 'lap',
22 | 'metrics',
23 | 'utils',
24 | 'MOTAccumulator',
25 | ]
26 |
27 | from . import distances
28 | from . import io
29 | from . import lap
30 | from . import metrics
31 | from . import utils
32 | from .mot import MOTAccumulator
33 |
34 | # Needs to be last line
35 | __version__ = '1.2.0'
36 |
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/distances.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Functions for comparing predictions and ground-truth."""
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import numpy as np
15 |
16 | from . import math_util
17 |
18 |
19 | def norm2squared_matrix(objs, hyps, max_d2=float('inf')):
20 | """Computes the squared Euclidean distance matrix between object and hypothesis points.
21 |
22 | Params
23 | ------
24 | objs : NxM array
25 | Object points of dim M in rows
26 | hyps : KxM array
27 | Hypothesis points of dim M in rows
28 |
29 | Kwargs
30 | ------
31 | max_d2 : float
32 | Maximum tolerable squared Euclidean distance. Object / hypothesis points
33 | with larger distance are set to np.nan signalling do-not-pair. Defaults
34 | to +inf
35 |
36 | Returns
37 | -------
38 | C : NxK array
39 | Distance matrix containing pairwise distances or np.nan.
40 | """
41 |
42 | objs = np.atleast_2d(objs).astype(float)
43 | hyps = np.atleast_2d(hyps).astype(float)
44 |
45 | if objs.size == 0 or hyps.size == 0:
46 | return np.empty((0, 0))
47 |
48 | assert hyps.shape[1] == objs.shape[1], "Dimension mismatch"
49 |
50 | delta = objs[:, np.newaxis] - hyps[np.newaxis, :]
51 | C = np.sum(delta ** 2, axis=-1)
52 |
53 | C[C > max_d2] = np.nan
54 | return C
55 |
56 |
57 | def rect_min_max(r):
58 | min_pt = r[..., :2]
59 | size = r[..., 2:]
60 | max_pt = min_pt + size
61 | return min_pt, max_pt
62 |
63 |
64 | def boxiou(a, b):
65 | """Computes IOU of two rectangles."""
66 | a_min, a_max = rect_min_max(a)
67 | b_min, b_max = rect_min_max(b)
68 | # Compute intersection.
69 | i_min = np.maximum(a_min, b_min)
70 | i_max = np.minimum(a_max, b_max)
71 | i_size = np.maximum(i_max - i_min, 0)
72 | i_vol = np.prod(i_size, axis=-1)
73 | # Get volume of union.
74 | a_size = np.maximum(a_max - a_min, 0)
75 | b_size = np.maximum(b_max - b_min, 0)
76 | a_vol = np.prod(a_size, axis=-1)
77 | b_vol = np.prod(b_size, axis=-1)
78 | u_vol = a_vol + b_vol - i_vol
79 | return np.where(i_vol == 0, np.zeros_like(i_vol, dtype=float),
80 | math_util.quiet_divide(i_vol, u_vol))
81 |
82 |
83 | def iou_matrix(objs, hyps, max_iou=1.):
84 | """Computes 'intersection over union (IoU)' distance matrix between object and hypothesis rectangles.
85 |
86 | The IoU is computed as
87 |
88 | IoU(a,b) = 1. - isect(a, b) / union(a, b)
89 |
90 | where isect(a,b) is the area of intersection of two rectangles and union(a, b) the area of union. The
91 | resulting distance is bounded between zero and one: 0 when the rectangles overlap perfectly and 1 when
92 | the overlap is zero.
93 |
94 | Params
95 | ------
96 | objs : Nx4 array
97 | Object rectangles (x,y,w,h) in rows
98 | hyps : Kx4 array
99 | Hypothesis rectangles (x,y,w,h) in rows
100 |
101 | Kwargs
102 | ------
103 | max_iou : float
104 | Maximum tolerable overlap distance. Object / hypothesis points
105 | with larger distance are set to np.nan signalling do-not-pair. Defaults
106 | to 1.0
107 |
108 | Returns
109 | -------
110 | C : NxK array
111 | Distance matrix containing pairwise distances or np.nan.
112 | """
113 |
114 | if np.size(objs) == 0 or np.size(hyps) == 0:
115 | return np.empty((0, 0))
116 |
117 | objs = np.asfarray(objs)
118 | hyps = np.asfarray(hyps)
119 | assert objs.shape[1] == 4
120 | assert hyps.shape[1] == 4
121 | iou = boxiou(objs[:, None], hyps[None, :])
122 | dist = 1 - iou
123 | return np.where(dist > max_iou, np.nan, dist)
124 |
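A short, hedged example of the two distance helpers above (values chosen for illustration): boxes are (x, y, w, h) rows, and entries beyond the cutoff become np.nan, signalling do-not-pair.

import numpy as np
from motmetrics import distances

objs = np.array([[0., 0., 2., 2.]])             # one ground-truth box
hyps = np.array([[0., 0., 2., 2.],              # perfect overlap -> distance 0
                 [1., 1., 2., 2.]])             # IoU = 1/7 -> distance ~0.857
print(distances.iou_matrix(objs, hyps, max_iou=0.5))   # [[0. nan]]

pts_a = np.array([[0., 0.], [1., 1.]])
pts_b = np.array([[0., 1.]])
print(distances.norm2squared_matrix(pts_a, pts_b, max_d2=2.))  # [[1.] [1.]]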
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/io.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Functions for loading data and writing summaries."""
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | from enum import Enum
15 | import io
16 |
17 | import numpy as np
18 | import pandas as pd
19 | import scipy.io
20 | import xmltodict
21 |
22 |
23 | class Format(Enum):
24 | """Enumerates supported file formats."""
25 |
26 | MOT16 = 'mot16'
27 | """Milan, Anton, et al. "Mot16: A benchmark for multi-object tracking." arXiv preprint arXiv:1603.00831 (2016)."""
28 |
29 | MOT15_2D = 'mot15-2D'
30 | """Leal-Taixe, Laura, et al. "MOTChallenge 2015: Towards a benchmark for multi-target tracking." arXiv preprint arXiv:1504.01942 (2015)."""
31 |
32 | VATIC_TXT = 'vatic-txt'
33 | """Vondrick, Carl, Donald Patterson, and Deva Ramanan. "Efficiently scaling up crowdsourced video annotation." International Journal of Computer Vision 101.1 (2013): 184-204.
34 | https://github.com/cvondrick/vatic
35 | """
36 |
37 | DETRAC_MAT = 'detrac-mat'
38 | """Wen, Longyin et al. "UA-DETRAC: A New Benchmark and Protocol for Multi-Object Detection and Tracking." arXiv preprint arXiv:arXiv:1511.04136 (2016).
39 | http://detrac-db.rit.albany.edu/download
40 | """
41 |
42 | DETRAC_XML = 'detrac-xml'
43 | """Wen, Longyin et al. "UA-DETRAC: A New Benchmark and Protocol for Multi-Object Detection and Tracking." arXiv preprint arXiv:arXiv:1511.04136 (2016).
44 | http://detrac-db.rit.albany.edu/download
45 | """
46 |
47 |
48 | def load_motchallenge(fname, **kwargs):
49 | r"""Load MOT challenge data.
50 |
51 | Params
52 | ------
53 | fname : str
54 | Filename to load data from
55 |
56 | Kwargs
57 | ------
58 | sep : str
59 | Allowed field separators, defaults to '\s+|\t+|,'
60 | min_confidence : float
61 | Rows with confidence less than this threshold are removed.
62 | Defaults to -1. You should set this to 1 when loading
63 | ground truth MOTChallenge data, so that invalid rectangles in
64 | the ground truth are not considered during matching.
65 |
66 | Returns
67 | ------
68 | df : pandas.DataFrame
69 | The returned dataframe has the following columns
70 | 'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility'
71 | The dataframe is indexed by ('FrameId', 'Id')
72 | """
73 |
74 | sep = kwargs.pop('sep', r'\s+|\t+|,')
75 | min_confidence = kwargs.pop('min_confidence', -1)
76 | df = pd.read_csv(
77 | fname,
78 | sep=sep,
79 | index_col=[0, 1],
80 | skipinitialspace=True,
81 | header=None,
82 | names=['FrameId', 'Id', 'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility', 'unused'],
83 | engine='python'
84 | )
85 |
86 | # Account for matlab convention.
87 | df[['X', 'Y']] -= (1, 1)
88 |
89 | # Remove trailing column
90 | del df['unused']
91 |
92 | # Remove all rows without sufficient confidence
93 | return df[df['Confidence'] >= min_confidence]
94 |
95 |
96 | def load_vatictxt(fname, **kwargs):
97 | """Load Vatic text format.
98 |
99 | Loads the vatic CSV text having the following columns per row
100 |
101 | 0 Track ID. All rows with the same ID belong to the same path.
102 | 1 xmin. The top left x-coordinate of the bounding box.
103 | 2 ymin. The top left y-coordinate of the bounding box.
104 | 3 xmax. The bottom right x-coordinate of the bounding box.
105 | 4 ymax. The bottom right y-coordinate of the bounding box.
106 | 5 frame. The frame that this annotation represents.
107 | 6 lost. If 1, the annotation is outside of the view screen.
108 | 7 occluded. If 1, the annotation is occluded.
109 | 8 generated. If 1, the annotation was automatically interpolated.
110 | 9 label. The label for this annotation, enclosed in quotation marks.
111 | 10+ attributes. Each column after this is an attribute set in the current frame
112 |
113 | Params
114 | ------
115 | fname : str
116 | Filename to load data from
117 |
118 | Returns
119 | ------
120 | df : pandas.DataFrame
121 | The returned dataframe has the following columns
122 | 'X', 'Y', 'Width', 'Height', 'Lost', 'Occluded', 'Generated', 'ClassId', '<Attr1>', '<Attr2>', ...
123 | where '<AttrN>' is a placeholder for the actual attribute name, capitalized (first letter). The order of
124 | attribute columns is sorted by attribute name. The dataframe is indexed by ('FrameId', 'Id')
125 | """
126 | # pylint: disable=too-many-locals
127 |
128 | sep = kwargs.pop('sep', ' ')
129 |
130 | with io.open(fname) as f:
131 | # First time going over file, we collect the set of all variable activities
132 | activities = set()
133 | for line in f:
134 | for c in line.rstrip().split(sep)[10:]:
135 | activities.add(c)
136 | activitylist = sorted(list(activities))
137 |
138 | # Second time we construct artificial binary columns for each activity
139 | data = []
140 | f.seek(0)
141 | for line in f:
142 | fields = line.rstrip().split()
143 | attrs = ['0'] * len(activitylist)
144 | for a in fields[10:]:
145 | attrs[activitylist.index(a)] = '1'
146 | fields = fields[:10]
147 | fields.extend(attrs)
148 | data.append(' '.join(fields))
149 |
150 | strdata = '\n'.join(data)
151 |
152 | dtype = {
153 | 'Id': np.int64,
154 | 'X': np.float32,
155 | 'Y': np.float32,
156 | 'Width': np.float32,
157 | 'Height': np.float32,
158 | 'FrameId': np.int64,
159 | 'Lost': bool,
160 | 'Occluded': bool,
161 | 'Generated': bool,
162 | 'ClassId': str,
163 | }
164 |
165 | # Remove quotes from activities
166 | activitylist = [a.replace('\"', '').capitalize() for a in activitylist]
167 |
168 | # Add dtypes for activities
169 | for a in activitylist:
170 | dtype[a] = bool
171 |
172 | # Read from CSV
173 | names = ['Id', 'X', 'Y', 'Width', 'Height', 'FrameId', 'Lost', 'Occluded', 'Generated', 'ClassId']
174 | names.extend(activitylist)
175 | df = pd.read_csv(io.StringIO(strdata), names=names, index_col=['FrameId', 'Id'], header=None, sep=' ')
176 |
177 | # Correct Width and Height which are actually XMax, Ymax in files.
178 | w = df['Width'] - df['X']
179 | h = df['Height'] - df['Y']
180 | df['Width'] = w
181 | df['Height'] = h
182 |
183 | return df
184 |
185 |
186 | def load_detrac_mat(fname):
187 | """Loads UA-DETRAC annotations data from mat files
188 |
189 | Competition Site: http://detrac-db.rit.albany.edu/download
190 |
191 | The file contains a nested structure of 2d arrays indexed by frame id
192 | and object id. Separate arrays for top, left, width and height are given.
193 |
194 | Params
195 | ------
196 | fname : str
197 | Filename to load data from
198 |
199 | Kwargs
200 | ------
201 | Currently none of these arguments is used.
202 |
203 | Returns
204 | ------
205 | df : pandas.DataFrame
206 | The returned dataframe has the following columns
207 | 'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility'
208 | The dataframe is indexed by ('FrameId', 'Id')
209 | """
210 |
211 | matData = scipy.io.loadmat(fname)
212 |
213 | frameList = matData['gtInfo'][0][0][4][0]
214 | leftArray = matData['gtInfo'][0][0][0]
215 | topArray = matData['gtInfo'][0][0][1]
216 | widthArray = matData['gtInfo'][0][0][3]
217 | heightArray = matData['gtInfo'][0][0][2]
218 |
219 | parsedGT = []
220 | for f in frameList:
221 | ids = [i + 1 for i, v in enumerate(leftArray[f - 1]) if v > 0]
222 | for i in ids:
223 | row = []
224 | row.append(f)
225 | row.append(i)
226 | row.append(leftArray[f - 1, i - 1] - widthArray[f - 1, i - 1] / 2)
227 | row.append(topArray[f - 1, i - 1] - heightArray[f - 1, i - 1])
228 | row.append(widthArray[f - 1, i - 1])
229 | row.append(heightArray[f - 1, i - 1])
230 | row.append(1)
231 | row.append(-1)
232 | row.append(-1)
233 | row.append(-1)
234 | parsedGT.append(row)
235 |
236 | df = pd.DataFrame(parsedGT,
237 | columns=['FrameId', 'Id', 'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility', 'unused'])
238 | df.set_index(['FrameId', 'Id'], inplace=True)
239 |
240 | # Account for matlab convention.
241 | df[['X', 'Y']] -= (1, 1)
242 |
243 | # Remove trailing column
244 | del df['unused']
245 |
246 | return df
247 |
248 |
249 | def load_detrac_xml(fname):
250 | """Loads UA-DETRAC annotations data from xml files
251 |
252 | Competition Site: http://detrac-db.rit.albany.edu/download
253 |
254 | Params
255 | ------
256 | fname : str
257 | Filename to load data from
258 |
259 | Kwargs
260 | ------
261 | Currently none of these arguments is used.
262 |
263 | Returns
264 | ------
265 | df : pandas.DataFrame
266 | The returned dataframe has the following columns
267 | 'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility'
268 | The dataframe is indexed by ('FrameId', 'Id')
269 | """
270 |
271 | with io.open(fname) as fd:
272 | doc = xmltodict.parse(fd.read())
273 | frameList = doc['sequence']['frame']
274 |
275 | parsedGT = []
276 | for f in frameList:
277 | fid = int(f['@num'])
278 | targetList = f['target_list']['target']
279 | if not isinstance(targetList, list):
280 | targetList = [targetList]
281 |
282 | for t in targetList:
283 | row = []
284 | row.append(fid)
285 | row.append(int(t['@id']))
286 | row.append(float(t['box']['@left']))
287 | row.append(float(t['box']['@top']))
288 | row.append(float(t['box']['@width']))
289 | row.append(float(t['box']['@height']))
290 | row.append(1)
291 | row.append(-1)
292 | row.append(-1)
293 | row.append(-1)
294 | parsedGT.append(row)
295 |
296 | df = pd.DataFrame(parsedGT,
297 | columns=['FrameId', 'Id', 'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility', 'unused'])
298 | df.set_index(['FrameId', 'Id'], inplace=True)
299 |
300 | # Account for matlab convention.
301 | df[['X', 'Y']] -= (1, 1)
302 |
303 | # Remove trailing column
304 | del df['unused']
305 |
306 | return df
307 |
308 |
309 | def loadtxt(fname, fmt=Format.MOT15_2D, **kwargs):
310 | """Load data from any known format."""
311 | fmt = Format(fmt)
312 |
313 | switcher = {
314 | Format.MOT16: load_motchallenge,
315 | Format.MOT15_2D: load_motchallenge,
316 | Format.VATIC_TXT: load_vatictxt,
317 | Format.DETRAC_MAT: load_detrac_mat,
318 | Format.DETRAC_XML: load_detrac_xml
319 | }
320 | func = switcher.get(fmt)
321 | return func(fname, **kwargs)
322 |
323 |
324 | def render_summary(summary, formatters=None, namemap=None, buf=None):
325 | """Render metrics summary to console friendly tabular output.
326 |
327 | Params
328 | ------
329 | summary : pd.DataFrame
330 | Dataframe containing summaries in rows.
331 |
332 | Kwargs
333 | ------
334 | buf : StringIO-like, optional
335 | Buffer to write to
336 | formatters : dict, optional
337 | Dictionary defining custom formatters for individual metrics,
338 | e.g. `{'mota': '{:.2%}'.format}`. You can get preset formatters
339 | from MetricsHost.formatters
340 | namemap : dict, optional
341 | Dictionary defining new metric names for display, e.g.
342 | `{'num_false_positives': 'FP'}`.
343 |
344 | Returns
345 | -------
346 | string
347 | Formatted string
348 | """
349 |
350 | if namemap is not None:
351 | summary = summary.rename(columns=namemap)
352 | if formatters is not None:
353 | formatters = {namemap.get(c, c): f for c, f in formatters.items()}
354 |
355 | output = summary.to_string(
356 | buf=buf,
357 | formatters=formatters,
358 | )
359 |
360 | return output
361 |
362 |
363 | motchallenge_metric_names = {
364 | 'idf1': 'IDF1',
365 | 'idp': 'IDP',
366 | 'idr': 'IDR',
367 | 'recall': 'Rcll',
368 | 'precision': 'Prcn',
369 | 'num_unique_objects': 'GT',
370 | 'mostly_tracked': 'MT',
371 | 'partially_tracked': 'PT',
372 | 'mostly_lost': 'ML',
373 | 'num_false_positives': 'FP',
374 | 'num_misses': 'FN',
375 | 'num_switches': 'IDs',
376 | 'num_fragmentations': 'FM',
377 | 'mota': 'MOTA',
378 | 'motp': 'MOTP',
379 | 'num_transfer': 'IDt',
380 | 'num_ascend': 'IDa',
381 | 'num_migrate': 'IDm',
382 | }
383 | """A list mappings for metric names to comply with MOTChallenge."""
384 |
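A sketch of the typical loading path through this module, assuming MOT15-style 'gt.txt' and 'test.txt' files exist on disk (the paths are placeholders); `compare_to_groundtruth` lives in the sibling utils module, as used by the tests further below:

import motmetrics as mm

gt = mm.io.loadtxt('gt.txt', fmt=mm.io.Format.MOT15_2D, min_confidence=1)
ts = mm.io.loadtxt('test.txt', fmt=mm.io.Format.MOT15_2D)

# Build an accumulator from the two dataframes and render a MOTChallenge-style table.
acc = mm.utils.compare_to_groundtruth(gt, ts, 'iou', distth=0.5)
mh = mm.metrics.create()
summary = mh.compute(acc, metrics=mm.metrics.motchallenge_metrics, name='seq')
print(mm.io.render_summary(summary,
                           formatters=mh.formatters,
                           namemap=mm.io.motchallenge_metric_names))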
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/lap.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Tools for solving linear assignment problems."""
9 |
10 | # pylint: disable=import-outside-toplevel
11 |
12 | from __future__ import absolute_import
13 | from __future__ import division
14 | from __future__ import print_function
15 |
16 | from contextlib import contextmanager
17 | import warnings
18 |
19 | import numpy as np
20 |
21 |
22 | def _module_is_available_py2(name):
23 | try:
24 | imp.find_module(name)
25 | return True
26 | except ImportError:
27 | return False
28 |
29 |
30 | def _module_is_available_py3(name):
31 | return importlib.util.find_spec(name) is not None
32 |
33 |
34 | try:
35 | import importlib.util
36 | except ImportError:
37 | import imp
38 | _module_is_available = _module_is_available_py2
39 | else:
40 | _module_is_available = _module_is_available_py3
41 |
42 |
43 | def linear_sum_assignment(costs, solver=None):
44 | """Solve a linear sum assignment problem (LSA).
45 |
46 | For large datasets solving the minimum cost assignment becomes the dominant runtime part.
47 | We therefore support various solvers out of the box (currently lapsolver, scipy, ortools, munkres)
48 |
49 | Params
50 | ------
51 | costs : np.array
52 | numpy matrix containing costs. Use NaN/Inf values for unassignable
53 | row/column pairs.
54 |
55 | Kwargs
56 | ------
57 | solver : callable or str, optional
58 | When str: name of solver to use.
59 | When callable: function to invoke
60 | When None: uses first available solver
61 | """
62 | costs = np.asarray(costs)
63 | if not costs.size:
64 | return np.array([], dtype=int), np.array([], dtype=int)
65 |
66 | solver = solver or default_solver
67 |
68 | if isinstance(solver, str):
69 | # Try resolve from string
70 | solver = solver_map.get(solver, None)
71 |
72 | assert callable(solver), 'Invalid LAP solver.'
73 | rids, cids = solver(costs)
74 | rids = np.asarray(rids).astype(int)
75 | cids = np.asarray(cids).astype(int)
76 | return rids, cids
77 |
78 |
79 | def add_expensive_edges(costs):
80 | """Replaces non-edge costs (nan, inf) with large number.
81 |
82 | If the optimal solution includes one of these edges,
83 | then the original problem was infeasible.
84 |
85 | Parameters
86 | ----------
87 | costs : np.ndarray
88 | """
89 | # The graph is probably already dense if we are doing this.
90 | assert isinstance(costs, np.ndarray)
91 | # The linear_sum_assignment function in scipy does not support missing edges.
92 | # Replace nan with a large constant that ensures it is not chosen.
93 | # If it is chosen, that means the problem was infeasible.
94 | valid = np.isfinite(costs)
95 | if valid.all():
96 | return costs.copy()
97 | if not valid.any():
98 | return np.zeros_like(costs)
99 | r = min(costs.shape)
100 | # Assume all edges costs are within [-c, c], c >= 0.
101 | # The cost of an invalid edge must be such that...
102 | # choosing this edge once and the best-possible edge (r - 1) times
103 | # is worse than choosing the worst-possible edge r times.
104 | # l + (r - 1) (-c) > r c
105 | # l > r c + (r - 1) c
106 | # l > (2 r - 1) c
107 | # Choose l = 2 r c + 1 > (2 r - 1) c.
108 | c = np.abs(costs[valid]).max() + 1 # Doesn't hurt to add 1 here.
109 | large_constant = 2 * r * c + 1
110 | return np.where(valid, costs, large_constant)
111 |
112 |
113 | def _exclude_missing_edges(costs, rids, cids):
114 | subset = [
115 | index for index, (i, j) in enumerate(zip(rids, cids))
116 | if np.isfinite(costs[i, j])
117 | ]
118 | return rids[subset], cids[subset]
119 |
120 |
121 | def lsa_solve_scipy(costs):
122 | """Solves the LSA problem using the scipy library."""
123 |
124 | from scipy.optimize import linear_sum_assignment as scipy_solve
125 |
126 | # scipy (1.3.3) does not support nan or inf values
127 | finite_costs = add_expensive_edges(costs)
128 | rids, cids = scipy_solve(finite_costs)
129 | rids, cids = _exclude_missing_edges(costs, rids, cids)
130 | return rids, cids
131 |
132 |
133 | def lsa_solve_lapsolver(costs):
134 | """Solves the LSA problem using the lapsolver library."""
135 | from lapsolver import solve_dense
136 |
137 | # Note that lapsolver will add expensive finite edges internally.
138 | # However, older versions did not add a large enough edge.
139 | finite_costs = add_expensive_edges(costs)
140 | rids, cids = solve_dense(finite_costs)
141 | rids, cids = _exclude_missing_edges(costs, rids, cids)
142 | return rids, cids
143 |
144 |
145 | def lsa_solve_munkres(costs):
146 | """Solves the LSA problem using the Munkres library."""
147 | from munkres import Munkres
148 |
149 | m = Munkres()
150 | # The munkres package may hang if the problem is not feasible.
151 | # Therefore, add expensive edges instead of using munkres.DISALLOWED.
152 | finite_costs = add_expensive_edges(costs)
153 | # Ensure that matrix is square.
154 | finite_costs = _zero_pad_to_square(finite_costs)
155 | indices = np.array(m.compute(finite_costs), dtype=int)
156 | # Exclude extra matches from extension to square matrix.
157 | indices = indices[(indices[:, 0] < costs.shape[0])
158 | & (indices[:, 1] < costs.shape[1])]
159 | rids, cids = indices[:, 0], indices[:, 1]
160 | rids, cids = _exclude_missing_edges(costs, rids, cids)
161 | return rids, cids
162 |
163 |
164 | def _zero_pad_to_square(costs):
165 | num_rows, num_cols = costs.shape
166 | if num_rows == num_cols:
167 | return costs
168 | n = max(num_rows, num_cols)
169 | padded = np.zeros((n, n), dtype=costs.dtype)
170 | padded[:num_rows, :num_cols] = costs
171 | return padded
172 |
173 |
174 | def lsa_solve_ortools(costs):
175 | """Solves the LSA problem using Google's optimization tools. """
176 | from ortools.graph import pywrapgraph
177 |
178 | if costs.shape[0] != costs.shape[1]:
179 | # ortools assumes that the problem is square.
180 | # Non-square problem will be infeasible.
181 | # Default to scipy solver rather than add extra zeros.
182 | # (This maintains the same behaviour as previous versions.)
183 | return linear_sum_assignment(costs, solver='scipy')
184 |
185 | rs, cs = np.isfinite(costs).nonzero() # pylint: disable=unbalanced-tuple-unpacking
186 | finite_costs = costs[rs, cs]
187 | scale = find_scale_for_integer_approximation(finite_costs)
188 | if scale != 1:
189 | warnings.warn('costs are not integers; using approximation')
190 | int_costs = np.round(scale * finite_costs).astype(int)
191 |
192 | assignment = pywrapgraph.LinearSumAssignment()
193 | # OR-Tools does not like to receive indices of type np.int64.
194 | rs = rs.tolist() # pylint: disable=no-member
195 | cs = cs.tolist()
196 | int_costs = int_costs.tolist()
197 | for r, c, int_cost in zip(rs, cs, int_costs):
198 | assignment.AddArcWithCost(r, c, int_cost)
199 |
200 | status = assignment.Solve()
201 | try:
202 | _ortools_assert_is_optimal(pywrapgraph, status)
203 | except AssertionError:
204 | # Default to scipy solver rather than add finite edges.
205 | # (This maintains the same behaviour as previous versions.)
206 | return linear_sum_assignment(costs, solver='scipy')
207 |
208 | return _ortools_extract_solution(assignment)
209 |
210 |
211 | def find_scale_for_integer_approximation(costs, base=10, log_max_scale=8, log_safety=2):
212 | """Returns a multiplicative factor to use before rounding to integers.
213 |
214 | Tries to find scale = base ** j (for j integer) such that:
215 | abs(diff(unique(costs))) <= 1 / (scale * safety)
216 | where safety = base ** log_safety.
217 |
218 | Logs a warning if the desired resolution could not be achieved.
219 | """
220 | costs = np.asarray(costs)
221 | costs = costs[np.isfinite(costs)] # Exclude non-edges (nan, inf) and -inf.
222 | if np.size(costs) == 0:
223 | # No edges with numeric value. Scale does not matter.
224 | return 1
225 | unique = np.unique(costs)
226 | if np.size(unique) == 1:
227 | # All costs have equal values. Scale does not matter.
228 | return 1
229 | try:
230 | _assert_integer(costs)
231 | except AssertionError:
232 | pass
233 | else:
234 | # The costs are already integers.
235 | return 1
236 |
237 | # Find scale = base ** e such that:
238 | # 1 / scale <= tol, or
239 | # e = log(scale) >= -log(tol)
240 | # where tol = min(diff(unique(costs)))
241 | min_diff = np.diff(unique).min()
242 | e = np.ceil(-np.log(min_diff) / np.log(base)).astype(int).item()
243 | # Add optional non-negative safety factor to reduce quantization noise.
244 | e += max(log_safety, 0)
245 | # Ensure that we do not reduce the magnitude of the costs.
246 | e = max(e, 0)
247 | # Ensure that the scale is not too large.
248 | if e > log_max_scale:
249 | warnings.warn('could not achieve desired resolution for approximation: '
250 | 'want exponent %d but max is %d' % (e, log_max_scale))
251 | e = log_max_scale
252 | scale = base ** e
253 | return scale
254 |
255 |
256 | def _assert_integer(costs):
257 | # Check that costs are not changed by rounding.
258 | # Note: Elements of cost matrix may be nan, inf, -inf.
259 | np.testing.assert_equal(np.round(costs), costs)
260 |
261 |
262 | def _ortools_assert_is_optimal(pywrapgraph, status):
263 | if status == pywrapgraph.LinearSumAssignment.OPTIMAL:
264 | pass
265 | elif status == pywrapgraph.LinearSumAssignment.INFEASIBLE:
266 | raise AssertionError('ortools: infeasible assignment problem')
267 | elif status == pywrapgraph.LinearSumAssignment.POSSIBLE_OVERFLOW:
268 | raise AssertionError('ortools: possible overflow in assignment problem')
269 | else:
270 | raise AssertionError('ortools: unknown status')
271 |
272 |
273 | def _ortools_extract_solution(assignment):
274 | if assignment.NumNodes() == 0:
275 | return np.array([], dtype=int), np.array([], dtype=int)
276 |
277 | pairings = []
278 | for i in range(assignment.NumNodes()):
279 | pairings.append([i, assignment.RightMate(i)])
280 |
281 | indices = np.array(pairings, dtype=int)
282 | return indices[:, 0], indices[:, 1]
283 |
284 |
285 | def lsa_solve_lapjv(costs):
286 | """Solves the LSA problem using lap.lapjv()."""
287 |
288 | from lap import lapjv
289 |
290 | # The lap.lapjv function supports +inf edges but there are some issues.
291 | # https://github.com/gatagat/lap/issues/20
292 | # Therefore, replace nans with large finite cost.
293 | finite_costs = add_expensive_edges(costs)
294 | row_to_col, _ = lapjv(finite_costs, return_cost=False, extend_cost=True)
295 | indices = np.array([np.arange(costs.shape[0]), row_to_col], dtype=int).T
296 | # Exclude unmatched rows (in case of unbalanced problem).
297 | indices = indices[indices[:, 1] != -1] # pylint: disable=unsubscriptable-object
298 | rids, cids = indices[:, 0], indices[:, 1]
299 | # Ensure that no missing edges were chosen.
300 | rids, cids = _exclude_missing_edges(costs, rids, cids)
301 | return rids, cids
302 |
303 |
304 | available_solvers = None
305 | default_solver = None
306 | solver_map = None
307 |
308 |
309 | def _init_standard_solvers():
310 | global available_solvers, default_solver, solver_map # pylint: disable=global-statement
311 |
312 | solvers = [
313 | ('lapsolver', lsa_solve_lapsolver),
314 | ('lap', lsa_solve_lapjv),
315 | ('scipy', lsa_solve_scipy),
316 | ('munkres', lsa_solve_munkres),
317 | ('ortools', lsa_solve_ortools),
318 | ]
319 |
320 | solver_map = dict(solvers)
321 |
322 | available_solvers = [s[0] for s in solvers if _module_is_available(s[0])]
323 | if len(available_solvers) == 0:
324 | default_solver = None
325 | warnings.warn('No standard LAP solvers found. Consider `pip install lapsolver` or `pip install scipy`', category=RuntimeWarning)
326 | else:
327 | default_solver = available_solvers[0]
328 |
329 |
330 | _init_standard_solvers()
331 |
332 |
333 | @contextmanager
334 | def set_default_solver(newsolver):
335 | """Change the default solver within context.
336 |
337 | Intended usage
338 |
339 | costs = ...
340 | mysolver = lambda x: ... # solver code that returns pairings
341 |
342 | with lap.set_default_solver(mysolver):
343 | rids, cids = lap.linear_sum_assignment(costs)
344 |
345 | Params
346 | ------
347 | newsolver : callable or str
348 | new solver function
349 | """
350 |
351 | global default_solver # pylint: disable=global-statement
352 |
353 | oldsolver = default_solver
354 | try:
355 | default_solver = newsolver
356 | yield
357 | finally:
358 | default_solver = oldsolver
359 |
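A hedged example of the solver front-end above, forcing the scipy backend (assumed to be installed); np.nan marks forbidden pairings, as documented:

import numpy as np
from motmetrics import lap

costs = np.array([[6., 9., 1.],
                  [10., 3., 2.],
                  [8., 7., np.nan]])
rids, cids = lap.linear_sum_assignment(costs, solver='scipy')
print(list(zip(rids, cids)))   # optimal pairs (0, 2), (1, 1), (2, 0): total cost 12

# The default solver can be swapped temporarily within a context, e.g.:
with lap.set_default_solver('scipy'):
    lap.linear_sum_assignment(costs)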
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/math_util.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Math utility functions."""
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import warnings
15 |
16 | import numpy as np
17 |
18 |
19 | def quiet_divide(a, b):
20 | """Quiet divide function that does not warn about (0 / 0)."""
21 | with warnings.catch_warnings():
22 | warnings.simplefilter('ignore', RuntimeWarning)
23 | return np.true_divide(a, b)
24 |
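A tiny illustration of why the wrapper above exists: 0/0 normally emits a RuntimeWarning, while quiet_divide silently returns nan, which the metric code treats as "undefined":

import numpy as np
from motmetrics.math_util import quiet_divide

print(quiet_divide(0.0, 0.0))                                 # nan, no warning emitted
print(quiet_divide(np.array([1., 0.]), np.array([0., 0.])))   # [inf nan]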
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/preprocess.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Preprocess data for CLEAR_MOT_M."""
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | from configparser import ConfigParser
15 | import logging
16 | import time
17 |
18 | import numpy as np
19 |
20 | from . import distances as mmd
21 | from .lap import linear_sum_assignment
22 |
23 |
24 | def preprocessResult(res, gt, inifile):
25 | """Preprocesses data for utils.CLEAR_MOT_M.
26 |
27 | Returns a subset of the predictions.
28 | """
29 | # pylint: disable=too-many-locals
30 | st = time.time()
31 | labels = [
32 | 'ped', # 1
33 | 'person_on_vhcl', # 2
34 | 'car', # 3
35 | 'bicycle', # 4
36 | 'mbike', # 5
37 | 'non_mot_vhcl', # 6
38 | 'static_person', # 7
39 | 'distractor', # 8
40 | 'occluder', # 9
41 | 'occluder_on_grnd', # 10
42 | 'occluder_full', # 11
43 | 'reflection', # 12
44 | 'crowd', # 13
45 | ]
46 | distractors = ['person_on_vhcl', 'static_person', 'distractor', 'reflection']
47 | is_distractor = {i + 1: x in distractors for i, x in enumerate(labels)}
48 | for i in distractors:
49 | is_distractor[i] = 1
50 | seqIni = ConfigParser()
51 | seqIni.read(inifile, encoding='utf8')
52 | F = int(seqIni['Sequence']['seqLength'])
53 | todrop = []
54 | for t in range(1, F + 1):
55 | if t not in res.index or t not in gt.index:
56 | continue
57 | resInFrame = res.loc[t]
58 |
59 | GTInFrame = gt.loc[t]
60 | A = GTInFrame[['X', 'Y', 'Width', 'Height']].values
61 | B = resInFrame[['X', 'Y', 'Width', 'Height']].values
62 | disM = mmd.iou_matrix(A, B, max_iou=0.5)
63 | le, ri = linear_sum_assignment(disM)
64 | flags = [
65 | 1 if is_distractor[it['ClassId']] or it['Visibility'] < 0. else 0
66 | for i, (k, it) in enumerate(GTInFrame.iterrows())
67 | ]
68 | hid = [k for k, it in resInFrame.iterrows()]
69 | for i, j in zip(le, ri):
70 | if not np.isfinite(disM[i, j]):
71 | continue
72 | if flags[i]:
73 | todrop.append((t, hid[j]))
74 | ret = res.drop(labels=todrop)
75 | logging.info('Preprocessing took %.3f seconds and removed %d boxes.',
76 | time.time() - st, len(todrop))
77 | return ret
78 |
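A sketch of how this preprocessing is meant to be used, assuming a MOT16-style directory layout (the paths below are placeholders): tracker boxes matched to distractor or low-visibility ground truth are dropped before scoring.

import motmetrics as mm
from motmetrics.preprocess import preprocessResult

gt = mm.io.loadtxt('MOT16-02/gt/gt.txt', fmt=mm.io.Format.MOT16, min_confidence=1)
res = mm.io.loadtxt('MOT16-02.txt', fmt=mm.io.Format.MOT16)
res = preprocessResult(res, gt, 'MOT16-02/seqinfo.ini')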
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/posetrack21_eval/motmetrics/tests/__init__.py
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/tests/test_distances.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Tests distance computation."""
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import numpy as np
15 |
16 | import motmetrics as mm
17 |
18 |
19 | def test_norm2squared():
20 | """Tests norm2squared_matrix."""
21 | a = np.asfarray([
22 | [1, 2],
23 | [2, 2],
24 | [3, 2],
25 | ])
26 |
27 | b = np.asfarray([
28 | [0, 0],
29 | [1, 1],
30 | ])
31 |
32 | C = mm.distances.norm2squared_matrix(a, b)
33 | np.testing.assert_allclose(
34 | C,
35 | [
36 | [5, 1],
37 | [8, 2],
38 | [13, 5]
39 | ]
40 | )
41 |
42 | C = mm.distances.norm2squared_matrix(a, b, max_d2=5)
43 | np.testing.assert_allclose(
44 | C,
45 | [
46 | [5, 1],
47 | [np.nan, 2],
48 | [np.nan, 5]
49 | ]
50 | )
51 |
52 |
53 | def test_norm2squared_empty():
54 | """Tests norm2squared_matrix with an empty input."""
55 | a = []
56 | b = np.asfarray([[0, 0], [1, 1]])
57 | C = mm.distances.norm2squared_matrix(a, b)
58 | assert C.size == 0
59 | C = mm.distances.norm2squared_matrix(b, a)
60 | assert C.size == 0
61 |
62 |
63 | def test_iou_matrix():
64 | """Tests iou_matrix."""
65 | a = np.array([
66 | [0, 0, 1, 2],
67 | ])
68 |
69 | b = np.array([
70 | [0, 0, 1, 2],
71 | [0, 0, 1, 1],
72 | [1, 1, 1, 1],
73 | [0.5, 0, 1, 1],
74 | [0, 1, 1, 1],
75 | ])
76 | np.testing.assert_allclose(
77 | mm.distances.iou_matrix(a, b),
78 | [[0, 0.5, 1, 0.8, 0.5]],
79 | atol=1e-4
80 | )
81 |
82 | np.testing.assert_allclose(
83 | mm.distances.iou_matrix(a, b, max_iou=0.5),
84 | [[0, 0.5, np.nan, np.nan, 0.5]],
85 | atol=1e-4
86 | )
87 |
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/tests/test_io.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Tests IO functions."""
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import os
15 |
16 | import pandas as pd
17 |
18 | import motmetrics as mm
19 |
20 | DATA_DIR = os.path.join(os.path.dirname(__file__), '../data')
21 |
22 |
23 | def test_load_vatic():
24 | """Tests VATIC_TXT format."""
25 | df = mm.io.loadtxt(os.path.join(DATA_DIR, 'iotest/vatic.txt'), fmt=mm.io.Format.VATIC_TXT)
26 |
27 | expected = pd.DataFrame([
28 | # FrameId, Id, X, Y, Width, Height, Lost, Occluded, Generated, ClassId, A1, A2, A3, A4
29 | (0, 0, 412, 0, 430, 124, 0, 0, 0, 'worker', 0, 0, 0, 0),
30 | (1, 0, 412, 10, 430, 114, 0, 0, 1, 'pc', 1, 0, 1, 0),
31 | (1, 1, 412, 0, 430, 124, 0, 0, 1, 'pc', 0, 1, 0, 0),
32 | (2, 2, 412, 0, 430, 124, 0, 0, 1, 'worker', 1, 1, 0, 1)
33 | ])
34 |
35 | assert (df.reset_index().values == expected.values).all()
36 |
37 |
38 | def test_load_motchallenge():
39 | """Tests MOT15_2D format."""
40 | df = mm.io.loadtxt(os.path.join(DATA_DIR, 'iotest/motchallenge.txt'), fmt=mm.io.Format.MOT15_2D)
41 |
42 | expected = pd.DataFrame([
43 | (1, 1, 398, 181, 121, 229, 1, -1, -1), # Note -1 on x and y for correcting matlab
44 | (1, 2, 281, 200, 92, 184, 1, -1, -1),
45 | (2, 2, 268, 201, 87, 182, 1, -1, -1),
46 | (2, 3, 70, 150, 100, 284, 1, -1, -1),
47 | (2, 4, 199, 205, 55, 137, 1, -1, -1),
48 | ])
49 |
50 | assert (df.reset_index().values == expected.values).all()
51 |
52 |
53 | def test_load_detrac_mat():
54 | """Tests DETRAC_MAT format."""
55 | df = mm.io.loadtxt(os.path.join(DATA_DIR, 'iotest/detrac.mat'), fmt=mm.io.Format.DETRAC_MAT)
56 |
57 | expected = pd.DataFrame([
58 | (1., 1., 745., 356., 148., 115., 1., -1., -1.),
59 | (2., 1., 738., 350., 145., 111., 1., -1., -1.),
60 | (3., 1., 732., 343., 142., 107., 1., -1., -1.),
61 | (4., 1., 725., 336., 139., 104., 1., -1., -1.)
62 | ])
63 |
64 | assert (df.reset_index().values == expected.values).all()
65 |
66 |
67 | def test_load_detrac_xml():
68 | """Tests DETRAC_XML format."""
69 | df = mm.io.loadtxt(os.path.join(DATA_DIR, 'iotest/detrac.xml'), fmt=mm.io.Format.DETRAC_XML)
70 |
71 | expected = pd.DataFrame([
72 | (1., 1., 744.6, 356.33, 148.2, 115.14, 1., -1., -1.),
73 | (2., 1., 738.2, 349.51, 145.21, 111.29, 1., -1., -1.),
74 | (3., 1., 731.8, 342.68, 142.23, 107.45, 1., -1., -1.),
75 | (4., 1., 725.4, 335.85, 139.24, 103.62, 1., -1., -1.)
76 | ])
77 |
78 | assert (df.reset_index().values == expected.values).all()
79 |
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/tests/test_issue19.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Tests issue 19.
9 |
10 | https://github.com/cheind/py-motmetrics/issues/19
11 | """
12 |
13 | from __future__ import absolute_import
14 | from __future__ import division
15 | from __future__ import print_function
16 |
17 | import numpy as np
18 |
19 | import motmetrics as mm
20 |
21 |
22 | def test_issue19():
23 | """Tests issue 19."""
24 | acc = mm.MOTAccumulator()
25 |
26 | g0 = [0, 1]
27 | p0 = [0, 1]
28 | d0 = [[0.2, np.nan], [np.nan, 0.2]]
29 |
30 | g1 = [2, 3]
31 | p1 = [2, 3, 4, 5]
32 | d1 = [[0.28571429, 0.5, 0.0, np.nan], [np.nan, 0.44444444, np.nan, 0.0]]
33 |
34 | acc.update(g0, p0, d0, 0)
35 | acc.update(g1, p1, d1, 1)
36 |
37 | mh = mm.metrics.create()
38 | mh.compute(acc)
39 |
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/tests/test_lap.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Tests linear assignment problem solvers."""
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import warnings
15 |
16 | import numpy as np
17 | import pytest
18 |
19 | from motmetrics import lap
20 |
21 | DESIRED_SOLVERS = ['lap', 'lapsolver', 'munkres', 'ortools', 'scipy']
22 | SOLVERS = lap.available_solvers
23 |
24 |
25 | @pytest.mark.parametrize('solver', DESIRED_SOLVERS)
26 | def test_solver_is_available(solver):
27 | if solver not in lap.available_solvers:
28 | warnings.warn('solver not available: ' + solver)
29 |
30 |
31 | @pytest.mark.parametrize('solver', SOLVERS)
32 | def test_assign_easy(solver):
33 | """Problem that could be solved by a greedy algorithm."""
34 | costs = np.asfarray([[6, 9, 1], [10, 3, 2], [8, 7, 4]])
35 | costs_copy = costs.copy()
36 | result = lap.linear_sum_assignment(costs, solver=solver)
37 |
38 | expected = np.array([[0, 1, 2], [2, 1, 0]])
39 | np.testing.assert_equal(result, expected)
40 | np.testing.assert_equal(costs, costs_copy)
41 |
42 |
43 | @pytest.mark.parametrize('solver', SOLVERS)
44 | def test_assign_full(solver):
45 | """Problem that would be incorrect using a greedy algorithm."""
46 | costs = np.asfarray([[5, 5, 6], [1, 2, 5], [2, 4, 5]])
47 | costs_copy = costs.copy()
48 | result = lap.linear_sum_assignment(costs, solver=solver)
49 |
50 | # Optimal matching is (0, 2), (1, 1), (2, 0) for 6 + 2 + 2.
51 | expected = np.asfarray([[0, 1, 2], [2, 1, 0]])
52 | np.testing.assert_equal(result, expected)
53 | np.testing.assert_equal(costs, costs_copy)
54 |
55 |
56 | @pytest.mark.parametrize('solver', SOLVERS)
57 | def test_assign_full_negative(solver):
58 | costs = -7 + np.asfarray([[5, 5, 6], [1, 2, 5], [2, 4, 5]])
59 | costs_copy = costs.copy()
60 | result = lap.linear_sum_assignment(costs, solver=solver)
61 |
62 | # Optimal matching is (0, 2), (1, 1), (2, 0) for 5 + 1 + 1.
63 | expected = np.array([[0, 1, 2], [2, 1, 0]])
64 | np.testing.assert_equal(result, expected)
65 | np.testing.assert_equal(costs, costs_copy)
66 |
67 |
68 | @pytest.mark.parametrize('solver', SOLVERS)
69 | def test_assign_empty(solver):
70 | costs = np.asfarray([[]])
71 | costs_copy = costs.copy()
72 | result = lap.linear_sum_assignment(costs, solver=solver)
73 |
74 | np.testing.assert_equal(np.size(result), 0)
75 | np.testing.assert_equal(costs, costs_copy)
76 |
77 |
78 | @pytest.mark.parametrize('solver', SOLVERS)
79 | def test_assign_infeasible(solver):
80 | """Tests that minimum-cost solution with most edges is found."""
81 | costs = np.asfarray([[np.nan, np.nan, 2],
82 | [np.nan, np.nan, 1],
83 | [8, 7, 4]])
84 | costs_copy = costs.copy()
85 | result = lap.linear_sum_assignment(costs, solver=solver)
86 |
87 | # Optimal matching is (1, 2), (2, 1).
88 | expected = np.array([[1, 2], [2, 1]])
89 | np.testing.assert_equal(result, expected)
90 | np.testing.assert_equal(costs, costs_copy)
91 |
92 |
93 | @pytest.mark.parametrize('solver', SOLVERS)
94 | def test_assign_disallowed(solver):
95 | costs = np.asfarray([[5, 9, np.nan], [10, np.nan, 2], [8, 7, 4]])
96 | costs_copy = costs.copy()
97 | result = lap.linear_sum_assignment(costs, solver=solver)
98 |
99 | expected = np.array([[0, 1, 2], [0, 2, 1]])
100 | np.testing.assert_equal(result, expected)
101 | np.testing.assert_equal(costs, costs_copy)
102 |
103 |
104 | @pytest.mark.parametrize('solver', SOLVERS)
105 | def test_assign_non_integer(solver):
106 | costs = (1. / 9) * np.asfarray([[5, 9, np.nan], [10, np.nan, 2], [8, 7, 4]])
107 | costs_copy = costs.copy()
108 | result = lap.linear_sum_assignment(costs, solver=solver)
109 |
110 | expected = np.array([[0, 1, 2], [0, 2, 1]])
111 | np.testing.assert_equal(result, expected)
112 | np.testing.assert_equal(costs, costs_copy)
113 |
114 |
115 | @pytest.mark.parametrize('solver', SOLVERS)
116 | def test_assign_attractive_disallowed(solver):
117 | """Graph contains an attractive edge that cannot be used."""
118 | costs = np.asfarray([[-10000, -1], [-1, np.nan]])
119 | costs_copy = costs.copy()
120 | result = lap.linear_sum_assignment(costs, solver=solver)
121 |
122 | # The optimal solution is (0, 1), (1, 0) for a cost of -2.
123 | # Ensure that the algorithm does not choose the (0, 0) edge.
124 | # This would not be a perfect matching.
125 | expected = np.array([[0, 1], [1, 0]])
126 | np.testing.assert_equal(result, expected)
127 | np.testing.assert_equal(costs, costs_copy)
128 |
129 |
130 | @pytest.mark.parametrize('solver', SOLVERS)
131 | def test_assign_attractive_broken_ring(solver):
132 | """Graph contains cheap broken ring and expensive unbroken ring."""
133 | costs = np.asfarray([[np.nan, 1000, np.nan], [np.nan, 1, 1000], [1000, np.nan, 1]])
134 | costs_copy = costs.copy()
135 | result = lap.linear_sum_assignment(costs, solver=solver)
136 |
137 | # Optimal solution is (0, 1), (1, 2), (2, 0) with cost 1000 + 1000 + 1000.
138 | # Solver might choose (0, 0), (1, 1), (2, 2) with cost inf + 1 + 1.
139 | expected = np.array([[0, 1, 2], [1, 2, 0]])
140 | np.testing.assert_equal(result, expected)
141 | np.testing.assert_equal(costs, costs_copy)
142 |
143 |
144 | @pytest.mark.parametrize('solver', SOLVERS)
145 | def test_unbalanced_wide(solver):
146 | costs = np.asfarray([[6, 4, 1], [10, 8, 2]])
147 | costs_copy = costs.copy()
148 | result = lap.linear_sum_assignment(costs, solver=solver)
149 |
150 | expected = np.array([[0, 1], [1, 2]])
151 | np.testing.assert_equal(result, expected)
152 | np.testing.assert_equal(costs, costs_copy)
153 |
154 |
155 | @pytest.mark.parametrize('solver', SOLVERS)
156 | def test_unbalanced_tall(solver):
157 | costs = np.asfarray([[6, 10], [4, 8], [1, 2]])
158 | costs_copy = costs.copy()
159 | result = lap.linear_sum_assignment(costs, solver=solver)
160 |
161 | expected = np.array([[1, 2], [0, 1]])
162 | np.testing.assert_equal(result, expected)
163 | np.testing.assert_equal(costs, costs_copy)
164 |
165 |
166 | @pytest.mark.parametrize('solver', SOLVERS)
167 | def test_unbalanced_disallowed_wide(solver):
168 | costs = np.asfarray([[np.nan, 11, 8], [8, np.nan, 7]])
169 | costs_copy = costs.copy()
170 | result = lap.linear_sum_assignment(costs, solver=solver)
171 |
172 | expected = np.array([[0, 1], [2, 0]])
173 | np.testing.assert_equal(result, expected)
174 | np.testing.assert_equal(costs, costs_copy)
175 |
176 |
177 | @pytest.mark.parametrize('solver', SOLVERS)
178 | def test_unbalanced_disallowed_tall(solver):
179 | costs = np.asfarray([[np.nan, 9], [11, np.nan], [8, 7]])
180 | costs_copy = costs.copy()
181 | result = lap.linear_sum_assignment(costs, solver=solver)
182 |
183 | expected = np.array([[0, 2], [1, 0]])
184 | np.testing.assert_equal(result, expected)
185 | np.testing.assert_equal(costs, costs_copy)
186 |
187 |
188 | @pytest.mark.parametrize('solver', SOLVERS)
189 | def test_unbalanced_infeasible(solver):
190 | """Tests that minimum-cost solution with most edges is found."""
191 | costs = np.asfarray([[np.nan, np.nan, 2],
192 | [np.nan, np.nan, 1],
193 | [np.nan, np.nan, 3],
194 | [8, 7, 4]])
195 | costs_copy = costs.copy()
196 | result = lap.linear_sum_assignment(costs, solver=solver)
197 |
198 | # Optimal matching is (1, 2), (3, 1).
199 | expected = np.array([[1, 3], [2, 1]])
200 | np.testing.assert_equal(result, expected)
201 | np.testing.assert_equal(costs, costs_copy)
202 |
203 |
204 | def test_change_solver():
205 | """Tests effect of lap.set_default_solver."""
206 |
207 | def mysolver(_):
208 | mysolver.called += 1
209 | return np.array([]), np.array([])
210 | mysolver.called = 0
211 |
212 | costs = np.asfarray([[6, 9, 1], [10, 3, 2], [8, 7, 4]])
213 |
214 | with lap.set_default_solver(mysolver):
215 | lap.linear_sum_assignment(costs)
216 | assert mysolver.called == 1
217 | lap.linear_sum_assignment(costs)
218 | assert mysolver.called == 1
219 |
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/tests/test_metrics.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Tests computation of metrics from accumulator."""
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import os
15 |
16 | import numpy as np
17 | import pandas as pd
18 | from pytest import approx
19 |
20 | import motmetrics as mm
21 |
22 | DATA_DIR = os.path.join(os.path.dirname(__file__), '../data')
23 |
24 |
25 | def test_metricscontainer_1():
26 | """Tests registration of events with dependencies."""
27 | m = mm.metrics.MetricsHost()
28 | m.register(lambda df: 1., name='a')
29 | m.register(lambda df: 2., name='b')
30 | m.register(lambda df, a, b: a + b, deps=['a', 'b'], name='add')
31 | m.register(lambda df, a, b: a - b, deps=['a', 'b'], name='sub')
32 | m.register(lambda df, a, b: a * b, deps=['add', 'sub'], name='mul')
33 | summary = m.compute(mm.MOTAccumulator.new_event_dataframe(), metrics=['mul', 'add'], name='x')
34 | assert summary.columns.values.tolist() == ['mul', 'add']
35 | assert summary.iloc[0]['mul'] == -3.
36 | assert summary.iloc[0]['add'] == 3.
37 |
38 |
39 | def test_metricscontainer_autodep():
40 | """Tests automatic dependencies from argument names."""
41 | m = mm.metrics.MetricsHost()
42 | m.register(lambda df: 1., name='a')
43 | m.register(lambda df: 2., name='b')
44 | m.register(lambda df, a, b: a + b, name='add', deps='auto')
45 | m.register(lambda df, a, b: a - b, name='sub', deps='auto')
46 | m.register(lambda df, add, sub: add * sub, name='mul', deps='auto')
47 | summary = m.compute(mm.MOTAccumulator.new_event_dataframe(), metrics=['mul', 'add'])
48 | assert summary.columns.values.tolist() == ['mul', 'add']
49 | assert summary.iloc[0]['mul'] == -3.
50 | assert summary.iloc[0]['add'] == 3.
51 |
52 |
53 | def test_metricscontainer_autoname():
54 | """Tests automatic names (and dependencies) from inspection."""
55 |
56 | def constant_a(_):
57 | """Constant a help."""
58 | return 1.
59 |
60 | def constant_b(_):
61 | return 2.
62 |
63 | def add(_, constant_a, constant_b):
64 | return constant_a + constant_b
65 |
66 | def sub(_, constant_a, constant_b):
67 | return constant_a - constant_b
68 |
69 | def mul(_, add, sub):
70 | return add * sub
71 |
72 | m = mm.metrics.MetricsHost()
73 | m.register(constant_a, deps='auto')
74 | m.register(constant_b, deps='auto')
75 | m.register(add, deps='auto')
76 | m.register(sub, deps='auto')
77 | m.register(mul, deps='auto')
78 |
79 | assert m.metrics['constant_a']['help'] == 'Constant a help.'
80 |
81 | summary = m.compute(mm.MOTAccumulator.new_event_dataframe(), metrics=['mul', 'add'])
82 | assert summary.columns.values.tolist() == ['mul', 'add']
83 | assert summary.iloc[0]['mul'] == -3.
84 | assert summary.iloc[0]['add'] == 3.
85 |
86 |
87 | def test_metrics_with_no_events():
88 | """Tests metrics when accumulator is empty."""
89 | acc = mm.MOTAccumulator()
90 |
91 | mh = mm.metrics.create()
92 | metr = mh.compute(acc, return_dataframe=False, return_cached=True, metrics=[
93 | 'mota', 'motp', 'num_predictions', 'num_objects', 'num_detections', 'num_frames',
94 | ])
95 | assert np.isnan(metr['mota'])
96 | assert np.isnan(metr['motp'])
97 | assert metr['num_predictions'] == 0
98 | assert metr['num_objects'] == 0
99 | assert metr['num_detections'] == 0
100 | assert metr['num_frames'] == 0
101 |
102 |
103 | def test_assignment_metrics_with_empty_groundtruth():
104 | """Tests metrics when there are no ground-truth objects."""
105 | acc = mm.MOTAccumulator(auto_id=True)
106 | # Empty groundtruth.
107 | acc.update([], [1, 2, 3, 4], [])
108 | acc.update([], [1, 2, 3, 4], [])
109 | acc.update([], [1, 2, 3, 4], [])
110 | acc.update([], [1, 2, 3, 4], [])
111 |
112 | mh = mm.metrics.create()
113 | metr = mh.compute(acc, return_dataframe=False, metrics=[
114 | 'num_matches', 'num_false_positives', 'num_misses',
115 | 'idtp', 'idfp', 'idfn', 'num_frames',
116 | ])
117 | assert metr['num_matches'] == 0
118 | assert metr['num_false_positives'] == 16
119 | assert metr['num_misses'] == 0
120 | assert metr['idtp'] == 0
121 | assert metr['idfp'] == 16
122 | assert metr['idfn'] == 0
123 | assert metr['num_frames'] == 4
124 |
125 |
126 | def test_assignment_metrics_with_empty_predictions():
127 | """Tests metrics when there are no predictions."""
128 | acc = mm.MOTAccumulator(auto_id=True)
129 | # Empty predictions.
130 | acc.update([1, 2, 3, 4], [], [])
131 | acc.update([1, 2, 3, 4], [], [])
132 | acc.update([1, 2, 3, 4], [], [])
133 | acc.update([1, 2, 3, 4], [], [])
134 |
135 | mh = mm.metrics.create()
136 | metr = mh.compute(acc, return_dataframe=False, metrics=[
137 | 'num_matches', 'num_false_positives', 'num_misses',
138 | 'idtp', 'idfp', 'idfn', 'num_frames',
139 | ])
140 | assert metr['num_matches'] == 0
141 | assert metr['num_false_positives'] == 0
142 | assert metr['num_misses'] == 16
143 | assert metr['idtp'] == 0
144 | assert metr['idfp'] == 0
145 | assert metr['idfn'] == 16
146 | assert metr['num_frames'] == 4
147 |
148 |
149 | def test_assignment_metrics_with_both_empty():
150 | """Tests metrics when there are no ground-truth objects or predictions."""
151 | acc = mm.MOTAccumulator(auto_id=True)
152 | # Empty groundtruth and empty predictions.
153 | acc.update([], [], [])
154 | acc.update([], [], [])
155 | acc.update([], [], [])
156 | acc.update([], [], [])
157 |
158 | mh = mm.metrics.create()
159 | metr = mh.compute(acc, return_dataframe=False, metrics=[
160 | 'num_matches', 'num_false_positives', 'num_misses',
161 | 'idtp', 'idfp', 'idfn', 'num_frames',
162 | ])
163 | assert metr['num_matches'] == 0
164 | assert metr['num_false_positives'] == 0
165 | assert metr['num_misses'] == 0
166 | assert metr['idtp'] == 0
167 | assert metr['idfp'] == 0
168 | assert metr['idfn'] == 0
169 | assert metr['num_frames'] == 4
170 |
171 |
172 | def _extract_counts(acc):
173 | df_map = mm.metrics.events_to_df_map(acc.events)
174 | return mm.metrics.extract_counts_from_df_map(df_map)
175 |
176 |
177 | def test_extract_counts():
178 | """Tests events_to_df_map() and extract_counts_from_df_map()."""
179 | acc = mm.MOTAccumulator()
180 | # All FP
181 | acc.update([], [1, 2], [], frameid=0)
182 | # All miss
183 | acc.update([1, 2], [], [], frameid=1)
184 | # Match
185 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=2)
186 | # Switch
187 | acc.update([1, 2], [1, 2], [[0.2, np.nan], [np.nan, 0.1]], frameid=3)
188 | # Match. Better new match is available but should prefer history
189 | acc.update([1, 2], [1, 2], [[5, 1], [1, 5]], frameid=4)
190 | # No data
191 | acc.update([], [], [], frameid=5)
192 |
193 | ocs, hcs, tps = _extract_counts(acc)
194 |
195 | assert ocs == {1: 4, 2: 4}
196 | assert hcs == {1: 4, 2: 4}
197 | expected_tps = {
198 | (1, 1): 3,
199 | (1, 2): 2,
200 | (2, 1): 2,
201 | (2, 2): 3,
202 | }
203 | assert tps == expected_tps
204 |
205 |
206 | def test_extract_pandas_series_issue():
207 | """Reproduce issue that arises with pd.Series but not pd.DataFrame.
208 |
209 | >>> data = [[0, 1, 0.1], [0, 1, 0.2], [0, 1, 0.3]]
210 | >>> df = pd.DataFrame(data, columns=['x', 'y', 'z']).set_index(['x', 'y'])
211 | >>> df['z'].groupby(['x', 'y']).count()
212 | {(0, 1): 3}
213 |
214 | >>> data = [[0, 1, 0.1], [0, 1, 0.2]]
215 | >>> df = pd.DataFrame(data, columns=['x', 'y', 'z']).set_index(['x', 'y'])
216 | >>> df['z'].groupby(['x', 'y']).count()
217 | {'x': 1, 'y': 1}
218 |
219 | >>> df[['z']].groupby(['x', 'y'])['z'].count().to_dict()
220 | {(0, 1): 2}
221 | """
222 | acc = mm.MOTAccumulator(auto_id=True)
223 | acc.update([0], [1], [[0.1]])
224 | acc.update([0], [1], [[0.1]])
225 | ocs, hcs, tps = _extract_counts(acc)
226 | assert ocs == {0: 2}
227 | assert hcs == {1: 2}
228 | assert tps == {(0, 1): 2}
229 |
230 |
231 | def test_benchmark_extract_counts(benchmark):
232 | """Benchmarks events_to_df_map() and extract_counts_from_df_map()."""
233 | rand = np.random.RandomState(0)
234 | acc = _accum_random_uniform(
235 | rand, seq_len=100, num_objs=50, num_hyps=5000,
236 | objs_per_frame=20, hyps_per_frame=40)
237 | benchmark(_extract_counts, acc)
238 |
239 |
240 | def _accum_random_uniform(rand, seq_len, num_objs, num_hyps, objs_per_frame, hyps_per_frame):
241 | acc = mm.MOTAccumulator(auto_id=True)
242 | for _ in range(seq_len):
243 | # Choose subset of objects present in this frame.
244 | objs = rand.choice(num_objs, objs_per_frame, replace=False)
245 | # Choose subset of hypotheses present in this frame.
246 | hyps = rand.choice(num_hyps, hyps_per_frame, replace=False)
247 | dist = rand.uniform(size=(objs_per_frame, hyps_per_frame))
248 | acc.update(objs, hyps, dist)
249 | return acc
250 |
251 |
252 | def test_mota_motp():
253 | """Tests values of MOTA and MOTP."""
254 | acc = mm.MOTAccumulator()
255 |
256 | # All FP
257 | acc.update([], [1, 2], [], frameid=0)
258 | # All miss
259 | acc.update([1, 2], [], [], frameid=1)
260 | # Match
261 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=2)
262 | # Switch
263 | acc.update([1, 2], [1, 2], [[0.2, np.nan], [np.nan, 0.1]], frameid=3)
264 | # Match. Better new match is available but should prefer history
265 | acc.update([1, 2], [1, 2], [[5, 1], [1, 5]], frameid=4)
266 | # No data
267 | acc.update([], [], [], frameid=5)
268 |
269 | mh = mm.metrics.create()
270 | metr = mh.compute(acc, return_dataframe=False, return_cached=True, metrics=[
271 | 'num_matches', 'num_false_positives', 'num_misses', 'num_switches', 'num_detections',
272 | 'num_objects', 'num_predictions', 'mota', 'motp', 'num_frames'
273 | ])
274 |
275 | assert metr['num_matches'] == 4
276 | assert metr['num_false_positives'] == 2
277 | assert metr['num_misses'] == 2
278 | assert metr['num_switches'] == 2
279 | assert metr['num_detections'] == 6
280 | assert metr['num_objects'] == 8
281 | assert metr['num_predictions'] == 8
282 | assert metr['mota'] == approx(1. - (2 + 2 + 2) / 8)
283 | assert metr['motp'] == approx(11.1 / 6)
284 | assert metr['num_frames'] == 6
285 |
286 |
287 | def test_ids():
288 | """Test metrics with frame IDs specified manually."""
289 | acc = mm.MOTAccumulator()
290 |
291 | # No data
292 | acc.update([], [], [], frameid=0)
293 | # Match
294 | acc.update([1, 2], [1, 2], [[1, 0], [0, 1]], frameid=1)
295 | # Switch also Transfer
296 | acc.update([1, 2], [1, 2], [[0.4, np.nan], [np.nan, 0.4]], frameid=2)
297 | # Match
298 | acc.update([1, 2], [1, 2], [[0, 1], [1, 0]], frameid=3)
299 | # Ascend (switch)
300 | acc.update([1, 2], [2, 3], [[1, 0], [0.4, 0.7]], frameid=4)
301 | # Migrate (transfer)
302 | acc.update([1, 3], [2, 3], [[1, 0], [0.4, 0.7]], frameid=5)
303 | # No data
304 | acc.update([], [], [], frameid=6)
305 |
306 | mh = mm.metrics.create()
307 | metr = mh.compute(acc, return_dataframe=False, return_cached=True, metrics=[
308 | 'num_matches', 'num_false_positives', 'num_misses', 'num_switches',
309 | 'num_transfer', 'num_ascend', 'num_migrate',
310 | 'num_detections', 'num_objects', 'num_predictions',
311 | 'mota', 'motp', 'num_frames',
312 | ])
313 | assert metr['num_matches'] == 7
314 | assert metr['num_false_positives'] == 0
315 | assert metr['num_misses'] == 0
316 | assert metr['num_switches'] == 3
317 | assert metr['num_transfer'] == 3
318 | assert metr['num_ascend'] == 1
319 | assert metr['num_migrate'] == 1
320 | assert metr['num_detections'] == 10
321 | assert metr['num_objects'] == 10
322 | assert metr['num_predictions'] == 10
323 | assert metr['mota'] == approx(1. - (0 + 0 + 3) / 10)
324 | assert metr['motp'] == approx(1.6 / 10)
325 | assert metr['num_frames'] == 7
326 |
327 |
328 | def test_correct_average():
329 | """Tests what is depicted in figure 3 of 'Evaluating MOT Performance'."""
330 | acc = mm.MOTAccumulator(auto_id=True)
331 |
332 | # No track
333 | acc.update([1, 2, 3, 4], [], [])
334 | acc.update([1, 2, 3, 4], [], [])
335 | acc.update([1, 2, 3, 4], [], [])
336 | acc.update([1, 2, 3, 4], [], [])
337 |
338 | # Track single
339 | acc.update([4], [4], [0])
340 | acc.update([4], [4], [0])
341 | acc.update([4], [4], [0])
342 | acc.update([4], [4], [0])
343 |
344 | mh = mm.metrics.create()
345 | metr = mh.compute(acc, metrics='mota', return_dataframe=False)
346 | assert metr['mota'] == approx(0.2)
347 |
348 |
349 | def test_motchallenge_files():
350 | """Tests metrics for sequences TUD-Campus and TUD-Stadtmitte."""
351 | dnames = [
352 | 'TUD-Campus',
353 | 'TUD-Stadtmitte',
354 | ]
355 |
356 | def compute_motchallenge(dname):
357 | df_gt = mm.io.loadtxt(os.path.join(dname, 'gt.txt'))
358 | df_test = mm.io.loadtxt(os.path.join(dname, 'test.txt'))
359 | return mm.utils.compare_to_groundtruth(df_gt, df_test, 'iou', distth=0.5)
360 |
361 | accs = [compute_motchallenge(os.path.join(DATA_DIR, d)) for d in dnames]
362 |
363 | # For testing
364 | # [a.events.to_pickle(n) for (a,n) in zip(accs, dnames)]
365 |
366 | mh = mm.metrics.create()
367 | summary = mh.compute_many(accs, metrics=mm.metrics.motchallenge_metrics, names=dnames, generate_overall=True)
368 |
369 | print()
370 | print(mm.io.render_summary(summary, namemap=mm.io.motchallenge_metric_names, formatters=mh.formatters))
371 | # assert ((summary['num_transfer'] - summary['num_migrate']) == (summary['num_switches'] - summary['num_ascend'])).all() # False assertion
372 | summary = summary[mm.metrics.motchallenge_metrics[:15]]
373 | expected = pd.DataFrame([
374 | [0.557659, 0.729730, 0.451253, 0.582173, 0.941441, 8.0, 1, 6, 1, 13, 150, 7, 7, 0.526462, 0.277201],
375 | [0.644619, 0.819760, 0.531142, 0.608997, 0.939920, 10.0, 5, 4, 1, 45, 452, 7, 6, 0.564014, 0.345904],
376 | [0.624296, 0.799176, 0.512211, 0.602640, 0.940268, 18.0, 6, 10, 2, 58, 602, 14, 13, 0.555116, 0.330177],
377 | ])
378 | np.testing.assert_allclose(summary, expected, atol=1e-3)
379 |
--------------------------------------------------------------------------------
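
The expected values in test_mota_motp above follow directly from the CLEAR MOT
definitions; a minimal sketch of the arithmetic, using only the counts asserted
in the test:

    # CLEAR MOT arithmetic behind the MOTA/MOTP assertions above.
    num_objects = 8                                   # ground-truth entries across all frames
    false_positives, misses, switches = 2, 2, 2
    mota = 1.0 - (false_positives + misses + switches) / num_objects    # 0.25

    total_match_distance = 0.5 + 0.3 + 0.2 + 0.1 + 5 + 5                # the 6 matched pairs
    num_detections = 6
    motp = total_match_distance / num_detections                        # 11.1 / 6 = 1.85
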
/posetrack21_eval/motmetrics/tests/test_mot.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Tests behavior of MOTAccumulator."""
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import numpy as np
15 | import pandas as pd
16 | import pytest
17 |
18 | import motmetrics as mm
19 |
20 |
21 | def test_events():
22 | """Tests that expected events are created by MOTAccumulator.update()."""
23 | acc = mm.MOTAccumulator()
24 |
25 | # All FP
26 | acc.update([], [1, 2], [], frameid=0)
27 | # All miss
28 | acc.update([1, 2], [], [], frameid=1)
29 | # Match
30 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=2)
31 | # Switch
32 | acc.update([1, 2], [1, 2], [[0.2, np.nan], [np.nan, 0.1]], frameid=3)
33 | # Match. Better new match is available but should prefer history
34 | acc.update([1, 2], [1, 2], [[5, 1], [1, 5]], frameid=4)
35 | # No data
36 | acc.update([], [], [], frameid=5)
37 |
38 | expect = mm.MOTAccumulator.new_event_dataframe()
39 | expect.loc[(0, 0), :] = ['RAW', np.nan, np.nan, np.nan]
40 | expect.loc[(0, 1), :] = ['RAW', np.nan, 1, np.nan]
41 | expect.loc[(0, 2), :] = ['RAW', np.nan, 2, np.nan]
42 | expect.loc[(0, 3), :] = ['FP', np.nan, 1, np.nan]
43 | expect.loc[(0, 4), :] = ['FP', np.nan, 2, np.nan]
44 |
45 | expect.loc[(1, 0), :] = ['RAW', np.nan, np.nan, np.nan]
46 | expect.loc[(1, 1), :] = ['RAW', 1, np.nan, np.nan]
47 | expect.loc[(1, 2), :] = ['RAW', 2, np.nan, np.nan]
48 | expect.loc[(1, 3), :] = ['MISS', 1, np.nan, np.nan]
49 | expect.loc[(1, 4), :] = ['MISS', 2, np.nan, np.nan]
50 |
51 | expect.loc[(2, 0), :] = ['RAW', np.nan, np.nan, np.nan]
52 | expect.loc[(2, 1), :] = ['RAW', 1, 1, 1.0]
53 | expect.loc[(2, 2), :] = ['RAW', 1, 2, 0.5]
54 | expect.loc[(2, 3), :] = ['RAW', 2, 1, 0.3]
55 | expect.loc[(2, 4), :] = ['RAW', 2, 2, 1.0]
56 | expect.loc[(2, 5), :] = ['MATCH', 1, 2, 0.5]
57 | expect.loc[(2, 6), :] = ['MATCH', 2, 1, 0.3]
58 |
59 | expect.loc[(3, 0), :] = ['RAW', np.nan, np.nan, np.nan]
60 | expect.loc[(3, 1), :] = ['RAW', 1, 1, 0.2]
61 | expect.loc[(3, 2), :] = ['RAW', 2, 2, 0.1]
62 | expect.loc[(3, 3), :] = ['TRANSFER', 1, 1, 0.2]
63 | expect.loc[(3, 4), :] = ['SWITCH', 1, 1, 0.2]
64 | expect.loc[(3, 5), :] = ['TRANSFER', 2, 2, 0.1]
65 | expect.loc[(3, 6), :] = ['SWITCH', 2, 2, 0.1]
66 |
67 | expect.loc[(4, 0), :] = ['RAW', np.nan, np.nan, np.nan]
68 | expect.loc[(4, 1), :] = ['RAW', 1, 1, 5.]
69 | expect.loc[(4, 2), :] = ['RAW', 1, 2, 1.]
70 | expect.loc[(4, 3), :] = ['RAW', 2, 1, 1.]
71 | expect.loc[(4, 4), :] = ['RAW', 2, 2, 5.]
72 | expect.loc[(4, 5), :] = ['MATCH', 1, 1, 5.]
73 | expect.loc[(4, 6), :] = ['MATCH', 2, 2, 5.]
74 |
75 | expect.loc[(5, 0), :] = ['RAW', np.nan, np.nan, np.nan]
76 |
77 | pd.util.testing.assert_frame_equal(acc.events, expect)
78 |
79 |
80 | def test_max_switch_time():
81 | """Tests max_switch_time option."""
82 | acc = mm.MOTAccumulator(max_switch_time=1)
83 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=1) # 1->a, 2->b
84 | frameid = acc.update([1, 2], [1, 2], [[0.5, np.nan], [np.nan, 0.5]], frameid=2) # 1->b, 2->a
85 |
86 | df = acc.events.loc[frameid]
87 | assert ((df.Type == 'SWITCH') | (df.Type == 'RAW') | (df.Type == 'TRANSFER')).all()
88 |
89 | acc = mm.MOTAccumulator(max_switch_time=1)
90 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=1) # 1->a, 2->b
91 | frameid = acc.update([1, 2], [1, 2], [[0.5, np.nan], [np.nan, 0.5]], frameid=5) # Later frame 1->b, 2->a
92 |
93 | df = acc.events.loc[frameid]
94 | assert ((df.Type == 'MATCH') | (df.Type == 'RAW') | (df.Type == 'TRANSFER')).all()
95 |
96 |
97 | def test_auto_id():
98 | """Tests auto_id option."""
99 | acc = mm.MOTAccumulator(auto_id=True)
100 | acc.update([1, 2, 3, 4], [], [])
101 | acc.update([1, 2, 3, 4], [], [])
102 | assert acc.events.index.levels[0][-1] == 1
103 | acc.update([1, 2, 3, 4], [], [])
104 | assert acc.events.index.levels[0][-1] == 2
105 |
106 | with pytest.raises(AssertionError):
107 | acc.update([1, 2, 3, 4], [], [], frameid=5)
108 |
109 | acc = mm.MOTAccumulator(auto_id=False)
110 | with pytest.raises(AssertionError):
111 | acc.update([1, 2, 3, 4], [], [])
112 |
113 |
114 | def test_merge_dataframes():
115 | """Tests merge_event_dataframes()."""
116 | # pylint: disable=too-many-statements
117 | acc = mm.MOTAccumulator()
118 |
119 | acc.update([], [1, 2], [], frameid=0)
120 | acc.update([1, 2], [], [], frameid=1)
121 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=2)
122 | acc.update([1, 2], [1, 2], [[0.2, np.nan], [np.nan, 0.1]], frameid=3)
123 |
124 | r, mappings = mm.MOTAccumulator.merge_event_dataframes([acc.events, acc.events], return_mappings=True)
125 |
126 | expect = mm.MOTAccumulator.new_event_dataframe()
127 |
128 | expect.loc[(0, 0), :] = ['RAW', np.nan, np.nan, np.nan]
129 | expect.loc[(0, 1), :] = ['RAW', np.nan, mappings[0]['hid_map'][1], np.nan]
130 | expect.loc[(0, 2), :] = ['RAW', np.nan, mappings[0]['hid_map'][2], np.nan]
131 | expect.loc[(0, 3), :] = ['FP', np.nan, mappings[0]['hid_map'][1], np.nan]
132 | expect.loc[(0, 4), :] = ['FP', np.nan, mappings[0]['hid_map'][2], np.nan]
133 |
134 | expect.loc[(1, 0), :] = ['RAW', np.nan, np.nan, np.nan]
135 | expect.loc[(1, 1), :] = ['RAW', mappings[0]['oid_map'][1], np.nan, np.nan]
136 | expect.loc[(1, 2), :] = ['RAW', mappings[0]['oid_map'][2], np.nan, np.nan]
137 | expect.loc[(1, 3), :] = ['MISS', mappings[0]['oid_map'][1], np.nan, np.nan]
138 | expect.loc[(1, 4), :] = ['MISS', mappings[0]['oid_map'][2], np.nan, np.nan]
139 |
140 | expect.loc[(2, 0), :] = ['RAW', np.nan, np.nan, np.nan]
141 | expect.loc[(2, 1), :] = ['RAW', mappings[0]['oid_map'][1], mappings[0]['hid_map'][1], 1]
142 | expect.loc[(2, 2), :] = ['RAW', mappings[0]['oid_map'][1], mappings[0]['hid_map'][2], 0.5]
143 | expect.loc[(2, 3), :] = ['RAW', mappings[0]['oid_map'][2], mappings[0]['hid_map'][1], 0.3]
144 | expect.loc[(2, 4), :] = ['RAW', mappings[0]['oid_map'][2], mappings[0]['hid_map'][2], 1.0]
145 | expect.loc[(2, 5), :] = ['MATCH', mappings[0]['oid_map'][1], mappings[0]['hid_map'][2], 0.5]
146 | expect.loc[(2, 6), :] = ['MATCH', mappings[0]['oid_map'][2], mappings[0]['hid_map'][1], 0.3]
147 |
148 | expect.loc[(3, 0), :] = ['RAW', np.nan, np.nan, np.nan]
149 | expect.loc[(3, 1), :] = ['RAW', mappings[0]['oid_map'][1], mappings[0]['hid_map'][1], 0.2]
150 | expect.loc[(3, 2), :] = ['RAW', mappings[0]['oid_map'][2], mappings[0]['hid_map'][2], 0.1]
151 | expect.loc[(3, 3), :] = ['TRANSFER', mappings[0]['oid_map'][1], mappings[0]['hid_map'][1], 0.2]
152 | expect.loc[(3, 4), :] = ['SWITCH', mappings[0]['oid_map'][1], mappings[0]['hid_map'][1], 0.2]
153 | expect.loc[(3, 5), :] = ['TRANSFER', mappings[0]['oid_map'][2], mappings[0]['hid_map'][2], 0.1]
154 | expect.loc[(3, 6), :] = ['SWITCH', mappings[0]['oid_map'][2], mappings[0]['hid_map'][2], 0.1]
155 |
156 | # Merge duplication
157 | expect.loc[(4, 0), :] = ['RAW', np.nan, np.nan, np.nan]
158 | expect.loc[(4, 1), :] = ['RAW', np.nan, mappings[1]['hid_map'][1], np.nan]
159 | expect.loc[(4, 2), :] = ['RAW', np.nan, mappings[1]['hid_map'][2], np.nan]
160 | expect.loc[(4, 3), :] = ['FP', np.nan, mappings[1]['hid_map'][1], np.nan]
161 | expect.loc[(4, 4), :] = ['FP', np.nan, mappings[1]['hid_map'][2], np.nan]
162 |
163 | expect.loc[(5, 0), :] = ['RAW', np.nan, np.nan, np.nan]
164 | expect.loc[(5, 1), :] = ['RAW', mappings[1]['oid_map'][1], np.nan, np.nan]
165 | expect.loc[(5, 2), :] = ['RAW', mappings[1]['oid_map'][2], np.nan, np.nan]
166 | expect.loc[(5, 3), :] = ['MISS', mappings[1]['oid_map'][1], np.nan, np.nan]
167 | expect.loc[(5, 4), :] = ['MISS', mappings[1]['oid_map'][2], np.nan, np.nan]
168 |
169 | expect.loc[(6, 0), :] = ['RAW', np.nan, np.nan, np.nan]
170 | expect.loc[(6, 1), :] = ['RAW', mappings[1]['oid_map'][1], mappings[1]['hid_map'][1], 1]
171 | expect.loc[(6, 2), :] = ['RAW', mappings[1]['oid_map'][1], mappings[1]['hid_map'][2], 0.5]
172 | expect.loc[(6, 3), :] = ['RAW', mappings[1]['oid_map'][2], mappings[1]['hid_map'][1], 0.3]
173 | expect.loc[(6, 4), :] = ['RAW', mappings[1]['oid_map'][2], mappings[1]['hid_map'][2], 1.0]
174 | expect.loc[(6, 5), :] = ['MATCH', mappings[1]['oid_map'][1], mappings[1]['hid_map'][2], 0.5]
175 | expect.loc[(6, 6), :] = ['MATCH', mappings[1]['oid_map'][2], mappings[1]['hid_map'][1], 0.3]
176 |
177 | expect.loc[(7, 0), :] = ['RAW', np.nan, np.nan, np.nan]
178 | expect.loc[(7, 1), :] = ['RAW', mappings[1]['oid_map'][1], mappings[1]['hid_map'][1], 0.2]
179 | expect.loc[(7, 2), :] = ['RAW', mappings[1]['oid_map'][2], mappings[1]['hid_map'][2], 0.1]
180 | expect.loc[(7, 3), :] = ['TRANSFER', mappings[1]['oid_map'][1], mappings[1]['hid_map'][1], 0.2]
181 | expect.loc[(7, 4), :] = ['SWITCH', mappings[1]['oid_map'][1], mappings[1]['hid_map'][1], 0.2]
182 | expect.loc[(7, 5), :] = ['TRANSFER', mappings[1]['oid_map'][2], mappings[1]['hid_map'][2], 0.1]
183 | expect.loc[(7, 6), :] = ['SWITCH', mappings[1]['oid_map'][2], mappings[1]['hid_map'][2], 0.1]
184 |
185 | pd.util.testing.assert_frame_equal(r, expect)
186 |
--------------------------------------------------------------------------------
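
As test_events above shows, MOTAccumulator.events interleaves RAW candidate rows
with the resolved event types (MATCH, SWITCH, TRANSFER, FP, MISS). When
inspecting an accumulator by hand it is often convenient to filter the RAW rows
out; a minimal sketch:

    import motmetrics as mm

    acc = mm.MOTAccumulator(auto_id=True)
    acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]])

    # Keep only the resolved events; for this frame that is two MATCH rows
    # (object 1 -> hypothesis 2 at 0.5, object 2 -> hypothesis 1 at 0.3).
    resolved = acc.events[acc.events.Type != 'RAW']
    print(resolved)
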
/posetrack21_eval/motmetrics/tests/test_utils.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Tests accumulation of events using utility functions."""
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import itertools
15 |
16 | import numpy as np
17 | import pandas as pd
18 |
19 | import motmetrics as mm
20 |
21 |
22 | def test_annotations_xor_predictions_present():
23 | """Tests frames that contain only annotations or predictions."""
24 | _ = None
25 | anno_tracks = {
26 | 1: [0, 2, 4, 6, _, _, _],
27 | 2: [_, _, 0, 2, 4, _, _],
28 | }
29 | pred_tracks = {
30 | 1: [_, _, 3, 5, 7, 7, 7],
31 | }
32 | anno = _tracks_to_dataframe(anno_tracks)
33 | pred = _tracks_to_dataframe(pred_tracks)
34 | acc = mm.utils.compare_to_groundtruth(anno, pred, 'euc', distfields=['Position'], distth=2)
35 | mh = mm.metrics.create()
36 | metrics = mh.compute(acc, return_dataframe=False, metrics=[
37 | 'num_objects', 'num_predictions', 'num_unique_objects',
38 | ])
39 | np.testing.assert_equal(metrics['num_objects'], 7)
40 | np.testing.assert_equal(metrics['num_predictions'], 5)
41 | np.testing.assert_equal(metrics['num_unique_objects'], 2)
42 |
43 |
44 | def _tracks_to_dataframe(tracks):
45 | rows = []
46 | for track_id, track in tracks.items():
47 | for frame_id, position in zip(itertools.count(1), track):
48 | if position is None:
49 | continue
50 | rows.append({
51 | 'FrameId': frame_id,
52 | 'Id': track_id,
53 | 'Position': position,
54 | })
55 | return pd.DataFrame(rows).set_index(['FrameId', 'Id'])
56 |
--------------------------------------------------------------------------------
/posetrack21_eval/motmetrics/utils.py:
--------------------------------------------------------------------------------
1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking.
2 | # https://github.com/cheind/py-motmetrics/
3 | #
4 | # MIT License
5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others.
6 | # See LICENSE file for terms.
7 |
8 | """Functions for populating event accumulators."""
9 |
10 | from __future__ import absolute_import
11 | from __future__ import division
12 | from __future__ import print_function
13 |
14 | import numpy as np
15 |
16 | from .distances import iou_matrix, norm2squared_matrix
17 | from .mot import MOTAccumulator
18 | from .preprocess import preprocessResult
19 |
20 |
21 | def compare_to_groundtruth(gt, dt, dist='iou', distfields=None, distth=0.5):
22 | """Compare groundtruth and detector results.
23 |
24 | This method assumes both results are given in terms of DataFrames with at least the following fields
25 | - `FrameId` First level index used for matching ground-truth and test frames.
26 | - `Id` Secondary level index marking available object / hypothesis ids
27 |
28 | Depending on the distance to be used relevant distfields need to be specified.
29 |
30 | Params
31 | ------
32 | gt : pd.DataFrame
33 | Dataframe for ground-truth
34 |     dt : pd.DataFrame
35 | Dataframe for detector results
36 |
37 | Kwargs
38 | ------
39 | dist : str, optional
40 | String identifying distance to be used. Defaults to intersection over union.
41 | distfields: array, optional
42 | Fields relevant for extracting distance information. Defaults to ['X', 'Y', 'Width', 'Height']
43 | distth: float, optional
44 | Maximum tolerable distance. Pairs exceeding this threshold are marked 'do-not-pair'.
45 | """
46 | # pylint: disable=too-many-locals
47 | if distfields is None:
48 | distfields = ['X', 'Y', 'Width', 'Height']
49 |
50 | def compute_iou(a, b):
51 | return iou_matrix(a, b, max_iou=distth)
52 |
53 | def compute_euc(a, b):
54 | return norm2squared_matrix(a, b, max_d2=distth)
55 |
56 | compute_dist = compute_iou if dist.upper() == 'IOU' else compute_euc
57 |
58 | acc = MOTAccumulator()
59 |
60 | # We need to account for all frames reported either by ground truth or
61 | # detector. In case a frame is missing in GT this will lead to FPs, in
62 | # case a frame is missing in detector results this will lead to FNs.
63 | allframeids = gt.index.union(dt.index).levels[0]
64 |
65 | gt = gt[distfields]
66 | dt = dt[distfields]
67 | fid_to_fgt = dict(iter(gt.groupby('FrameId')))
68 | fid_to_fdt = dict(iter(dt.groupby('FrameId')))
69 |
70 | for fid in allframeids:
71 | oids = np.empty(0)
72 | hids = np.empty(0)
73 | dists = np.empty((0, 0))
74 | if fid in fid_to_fgt:
75 | fgt = fid_to_fgt[fid]
76 | oids = fgt.index.get_level_values('Id')
77 | if fid in fid_to_fdt:
78 | fdt = fid_to_fdt[fid]
79 | hids = fdt.index.get_level_values('Id')
80 | if len(oids) > 0 and len(hids) > 0:
81 | dists = compute_dist(fgt.values, fdt.values)
82 | acc.update(oids, hids, dists, frameid=fid)
83 |
84 | return acc
85 |
86 |
87 | def CLEAR_MOT_M(gt, dt, inifile, dist='iou', distfields=None, distth=0.5, include_all=False, vflag=''):
88 | """Compare groundtruth and detector results.
89 |
90 | This method assumes both results are given in terms of DataFrames with at least the following fields
91 | - `FrameId` First level index used for matching ground-truth and test frames.
92 | - `Id` Secondary level index marking available object / hypothesis ids
93 |
94 | Depending on the distance to be used relevant distfields need to be specified.
95 |
96 | Params
97 | ------
98 | gt : pd.DataFrame
99 | Dataframe for ground-truth
100 |     dt : pd.DataFrame
101 | Dataframe for detector results
102 |
103 | Kwargs
104 | ------
105 | dist : str, optional
106 | String identifying distance to be used. Defaults to intersection over union.
107 | distfields: array, optional
108 | Fields relevant for extracting distance information. Defaults to ['X', 'Y', 'Width', 'Height']
109 | distth: float, optional
110 | Maximum tolerable distance. Pairs exceeding this threshold are marked 'do-not-pair'.
111 | """
112 | # pylint: disable=too-many-locals
113 | if distfields is None:
114 | distfields = ['X', 'Y', 'Width', 'Height']
115 |
116 | def compute_iou(a, b):
117 | return iou_matrix(a, b, max_iou=distth)
118 |
119 | def compute_euc(a, b):
120 | return norm2squared_matrix(a, b, max_d2=distth)
121 |
122 | compute_dist = compute_iou if dist.upper() == 'IOU' else compute_euc
123 |
124 | acc = MOTAccumulator()
125 | dt = preprocessResult(dt, gt, inifile)
126 | if include_all:
127 | gt = gt[gt['Confidence'] >= 0.99]
128 | else:
129 | gt = gt[(gt['Confidence'] >= 0.99) & (gt['ClassId'] == 1)]
130 | # We need to account for all frames reported either by ground truth or
131 | # detector. In case a frame is missing in GT this will lead to FPs, in
132 | # case a frame is missing in detector results this will lead to FNs.
133 | allframeids = gt.index.union(dt.index).levels[0]
134 | analysis = {'hyp': {}, 'obj': {}}
135 | for fid in allframeids:
136 | oids = np.empty(0)
137 | hids = np.empty(0)
138 | dists = np.empty((0, 0))
139 |
140 | if fid in gt.index:
141 | fgt = gt.loc[fid]
142 | oids = fgt.index.values
143 | for oid in oids:
144 | oid = int(oid)
145 | if oid not in analysis['obj']:
146 | analysis['obj'][oid] = 0
147 | analysis['obj'][oid] += 1
148 |
149 | if fid in dt.index:
150 | fdt = dt.loc[fid]
151 | hids = fdt.index.values
152 | for hid in hids:
153 | hid = int(hid)
154 | if hid not in analysis['hyp']:
155 | analysis['hyp'][hid] = 0
156 | analysis['hyp'][hid] += 1
157 |
158 | if oids.shape[0] > 0 and hids.shape[0] > 0:
159 | dists = compute_dist(fgt[distfields].values, fdt[distfields].values)
160 |
161 | acc.update(oids, hids, dists, frameid=fid, vf=vflag)
162 |
163 | return acc, analysis
164 |
--------------------------------------------------------------------------------
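
For reference, the typical end-to-end use of compare_to_groundtruth above
mirrors test_motchallenge_files: load MOTChallenge-format ground truth and
results, build one accumulator per sequence, and hand the accumulators to a
metrics host. A condensed sketch with placeholder paths:

    import motmetrics as mm

    # Placeholder paths; mm.io.loadtxt expects MOTChallenge-style text files.
    gt = mm.io.loadtxt('TUD-Campus/gt.txt')
    dt = mm.io.loadtxt('TUD-Campus/test.txt')

    # IoU matching on the default ['X', 'Y', 'Width', 'Height'] fields; pairs
    # whose IoU distance exceeds 0.5 are treated as do-not-pair.
    acc = mm.utils.compare_to_groundtruth(gt, dt, 'iou', distth=0.5)

    mh = mm.metrics.create()
    summary = mh.compute_many([acc], metrics=mm.metrics.motchallenge_metrics,
                              names=['TUD-Campus'])
    print(mm.io.render_summary(summary, namemap=mm.io.motchallenge_metric_names,
                               formatters=mh.formatters))
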
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "comotion_demo"
3 | description = "CoMotion: Concurrent Multi-person 3D Motion."
4 | version = "0.1"
5 | authors = [
6 | {name = "Alejandro Newell"},
7 | {name = "Peiyun Hu"},
8 | {name = "Lahav Lipson"},
9 | {name = "Stephan R. Richter"},
10 | {name = "Vladlen Koltun"},
11 | ]
12 | readme = "README.md"
13 | dependencies = [
14 | "click",
15 | "coremltools",
16 | "einops",
17 | "ffmpeg-python",
18 | "opencv-python",
19 | "pypose",
20 | "PyQt6",
21 | "ruff",
22 | "scenedetect",
23 | "tensordict",
24 | "timm==1.0.13",
25 | "torch==2.5.1",
26 | "transformers==4.48.0",
27 | "tqdm",
28 | "chumpy @ git+https://github.com/mattloper/chumpy@9b045ff5d6588a24a0bab52c83f032e2ba433e17",
29 | ]
30 | requires-python = ">=3.10"
31 |
32 | [project.optional-dependencies]
33 | all = [
34 | "aitviewer",
35 | ]
36 | colab = [
37 | "pyrender",
38 | "smplx[all]",
39 | "more-itertools",
40 | "trimesh"
41 | ]
42 |
43 | [project.urls]
44 | Homepage = "https://github.com/apple/ml-comotion"
45 | Repository = "https://github.com/apple/ml-comotion"
46 |
47 | [build-system]
48 | requires = ["setuptools", "setuptools-scm"]
49 | build-backend = "setuptools.build_meta"
50 |
51 | [tool.setuptools.packages.find]
52 | where = ["src"]
53 |
--------------------------------------------------------------------------------
/samples/sample_info.txt:
--------------------------------------------------------------------------------
1 | All sample visualizations provided under CC-by-NC-ND.
2 |
3 | Sources:
4 | https://www.pexels.com/video/a-man-and-a-woman-dancing-the-tango-on-the-sidewalk-8281172/
5 | https://www.pexels.com/video/a-man-doing-breakdancing-8688465/
6 | https://www.pexels.com/video/bearded-man-doing-breakdancing-9344627/
7 | https://www.pexels.com/video/a-woman-showing-her-ballet-skill-in-turning-one-footed-5385885/
8 | https://www.pexels.com/video/a-couple-practicing-acrobatics-6809489/
9 |
--------------------------------------------------------------------------------
/samples/teaser_01.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_01.gif
--------------------------------------------------------------------------------
/samples/teaser_02.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_02.gif
--------------------------------------------------------------------------------
/samples/teaser_03.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_03.gif
--------------------------------------------------------------------------------
/samples/teaser_04.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_04.gif
--------------------------------------------------------------------------------
/samples/teaser_05.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_05.gif
--------------------------------------------------------------------------------
/samples/teaser_06.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_06.gif
--------------------------------------------------------------------------------
/src/comotion_demo/data/smpl/extra_smpl_reference.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/src/comotion_demo/data/smpl/extra_smpl_reference.pt
--------------------------------------------------------------------------------
/src/comotion_demo/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | """CoMotion model."""
3 |
4 | from . import backbones # noqa
5 | from .comotion import CoMotion # noqa
6 | from .detect import CoMotionDetect # noqa
7 |
--------------------------------------------------------------------------------
/src/comotion_demo/models/backbones/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | """Backbone network for CoMotion."""
3 |
4 | from ._registry import _lookup
5 | from .convnext import ConvNextV2 # noqa
6 |
7 |
8 | def initialize(backbone_choice):
9 | """Initialize a registered backbone."""
10 | assert backbone_choice in _lookup, (
11 | f"Backbone choice '{backbone_choice}' not found in registry."
12 | )
13 | return _lookup[backbone_choice]()
14 |
--------------------------------------------------------------------------------
/src/comotion_demo/models/backbones/_registry.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | _lookup = {}
3 |
4 |
5 | def register_model(fn, model_name=None):
6 | """Register a model under a name. When unspecified, use the class name."""
7 | if model_name is None:
8 | model_name = fn.__name__
9 | _lookup[model_name] = fn
10 |
11 | return fn
12 |
--------------------------------------------------------------------------------
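
backbones/__init__.py and _registry.py above form a small name-to-constructor
registry: decorating a class with register_model stores it under its class
name, and backbones.initialize() looks it up and instantiates it. A
hypothetical sketch (TinyBackbone is not part of the repository; CoMotion
itself registers ConvNextV2 this way):

    from torch import nn

    from comotion_demo.models import backbones
    from comotion_demo.models.backbones._registry import register_model


    @register_model
    class TinyBackbone(nn.Module):
        """Hypothetical backbone, registered under its class name."""

        output_dim = 64  # CoMotionDetect reads output_dim off the backbone instance

        def forward(self, x):
            return x


    backbone = backbones.initialize("TinyBackbone")  # -> TinyBackbone()
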
/src/comotion_demo/models/backbones/convnext.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | """ConvNextV2 vision backbone."""
3 |
4 | import einops as eo
5 | from timm import create_model
6 | from timm.layers import LayerNorm2d
7 | from torch import nn
8 |
9 | from ._registry import register_model
10 |
11 |
12 | @register_model
13 | class ConvNextV2(nn.Module):
14 | """Extract a feature pyramid out of pre-trained ConvNextV2."""
15 |
16 | def __init__(self, output_dim=256, size="l", dropout=0.2, pretrained=False):
17 | """Set up a variant of ConvNextV2 and the feature extractors."""
18 | super().__init__()
19 |
20 | self.output_dim = output_dim
21 | self.size = size
22 | self.dropout = dropout
23 | self.pretrained = pretrained
24 |
25 | # Initialize network
26 | backbone_ref = {
27 | "t": ["convnextv2_tiny.fcmae", [96, 192, 384, 768]],
28 | "b": ["convnextv2_base.fcmae", [128, 256, 512, 1024]],
29 | "l": ["convnextv2_large.fcmae", [192, 384, 768, 1536]],
30 | "h": ["convnextv2_huge.fcmae", [352, 704, 1408, 2816]],
31 | }
32 | network_name, feat_dims = backbone_ref[self.size]
33 | network = create_model(network_name, pretrained=self.pretrained)
34 | self.stem = network.stem
35 | self.stages = network.stages
36 |
37 | # Layers to fuse features across scales
38 | self._init_feature_fusion(feat_dims, output_dim)
39 | self.norm_px_feats = LayerNorm2d(output_dim)
40 | if self.dropout > 0:
41 | self.px_dropout = nn.Dropout2d(self.dropout)
42 |
43 | def _init_feature_fusion(self, feat_dims, output_dim):
44 | """Initialize linear layers for feature extraction at different levels."""
45 | self.project_0 = nn.Conv2d(feat_dims[0], output_dim, 2, 2)
46 | self.project_1 = nn.Conv2d(feat_dims[1], output_dim, 1, 1)
47 | self.project_2 = nn.Conv2d(feat_dims[2], output_dim * 4, 1, 1)
48 | self.project_3 = nn.Conv2d(feat_dims[3], output_dim * 16, 1, 1)
49 |
50 | def forward(self, x):
51 | """Run ConvNextV2 inference and return a list of feature maps."""
52 | # Run ConvNext stages
53 | x = self.stem(x)
54 | feat_pyramid = []
55 | for stage in self.stages:
56 | x = stage(x)
57 | feat_pyramid.append(x)
58 |
59 | # Fuse into single tensor at 1/8th input resolution
60 | f0 = self.project_0(feat_pyramid[0])
61 | f1 = self.project_1(feat_pyramid[1])
62 | f2 = self.project_2(feat_pyramid[2])
63 | f3 = self.project_3(feat_pyramid[3])
64 | f2 = eo.rearrange(f2, "... (c h2 w2) h w -> ... c (h h2) (w w2)", h2=2, w2=2)
65 | f3 = eo.rearrange(f3, "... (c h2 w2) h w -> ... c (h h2) (w w2)", h2=4, w2=4)
66 | px_feats = f0 + f1 + f2 + f3
67 | px_feats = self.norm_px_feats(px_feats)
68 | if self.dropout > 0:
69 | px_feats = self.px_dropout(px_feats)
70 |
71 | return px_feats
72 |
--------------------------------------------------------------------------------
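
The fusion step in ConvNextV2.forward above relies on a channel-to-space
rearrange: project_2 and project_3 emit output_dim * 4 and output_dim * 16
channels so that regrouping channels into 2x2 and 4x4 spatial blocks brings the
coarser maps back to 1/8 input resolution before summation. A standalone shape
check of the same einops pattern (sizes assume a 512x512 input):

    import einops as eo
    import torch

    output_dim = 256
    f2 = torch.randn(1, output_dim * 4, 32, 32)    # projected features at 1/16 resolution
    f3 = torch.randn(1, output_dim * 16, 16, 16)   # projected features at 1/32 resolution

    f2 = eo.rearrange(f2, "... (c h2 w2) h w -> ... c (h h2) (w w2)", h2=2, w2=2)
    f3 = eo.rearrange(f3, "... (c h2 w2) h w -> ... c (h h2) (w w2)", h2=4, w2=4)

    print(f2.shape, f3.shape)  # both torch.Size([1, 256, 64, 64]), i.e. 1/8 resolution
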
/src/comotion_demo/models/comotion.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | import torch
3 | from scenedetect.detectors import ContentDetector
4 | from torch import nn
5 |
6 | from ..utils import dataloading, helper, smpl_kinematics, track
7 | from . import detect, refine
8 |
9 |
10 | class CoMotion(nn.Module):
11 | """CoMotion network."""
12 |
13 | def __init__(self, use_coreml=False, pretrained=True):
14 | """Initialize CoMotion.
15 |
16 | Args:
17 | ----
18 | use_coreml: use CoreML version of the detection model (macOS only).
19 | pretrained: load pre-trained CoMotion modules.
20 |
21 | """
22 | super().__init__()
23 |
24 | if use_coreml:
25 | self.detection_model = detect.CoMotionDetectCoreML()
26 | else:
27 | self.detection_model = detect.CoMotionDetect(pretrained=pretrained)
28 |
29 | self.update_step = refine.PoseRefinement(
30 | self.detection_model.feat_dim,
31 | self.detection_model.cfg.pose_embed_dim,
32 | pretrained=pretrained,
33 | )
34 |
35 | self.smpl_decoder = smpl_kinematics.SMPLKinematics()
36 | self.shot_detector = ContentDetector(threshold=50.0, min_scene_len=3)
37 | self.frame_count = 0
38 |
39 | def init_tracks(self, image_res):
40 | """Initialize track handler."""
41 | self.handler = track.TrackHandler(track.default_dims, image_res)
42 |
43 | @torch.inference_mode()
44 | def forward(self, image, K, detection_only=False, use_mps=False):
45 | """Perform detection and tracking given a new image.
46 |
47 |         Input images are accepted at any resolution; resizing and cropping are
48 |         handled automatically, and output 2D keypoints are provided at the
49 |         original input resolution.
50 |
51 | Args:
52 | ----
53 | image: Input image (C x H x W) float (0-1) or uint8 (0-255) tensor
54 | K: Intrinsics matrix (2 x 3) float tensor
55 | detection_only: Flag whether to only run initial detection stage
56 |             use_mps: Flag whether to run the update step on MPS on macOS
57 |
58 | """
59 |         # Set up devices
60 | device = next(self.parameters()).device
61 | if use_mps:
62 | self.update_step.to("mps")
63 | device = "cpu"
64 |
65 | # Prepare inputs
66 | cropped_image, cropped_K = dataloading.prepare_network_inputs(image, K, device)
67 | K = K.to(device)
68 |
69 | # Get detections
70 | detect_out = self.detection_model(cropped_image, cropped_K)
71 | nms_out = detect.decode_network_outputs(
72 | K,
73 | self.smpl_decoder,
74 | detect_out,
75 | std=0.08,
76 | iou_thr=0.4,
77 | conf_thr=0.1,
78 | )
79 |
80 | if detection_only:
81 | return nms_out
82 |
83 | def call_update(s: track.TrackTensorState):
84 | # Prepare inputs
85 | update_args = [
86 | detect_out.image_features,
87 | cropped_K,
88 | s.betas,
89 | s.pose,
90 | s.trans,
91 | s.pred_3d,
92 | s.hidden,
93 | ]
94 | if use_mps:
95 | update_args = [arg.to("mps") for arg in update_args]
96 |
97 | # Run update step
98 | updated_params = self.update_step(*update_args)
99 | updated_params = refine.RefineOutput(
100 | **{k: v.to(device) for k, v in updated_params._asdict().items()}
101 | )
102 |
103 | # Update track state
104 | s.pose = detect.get_smpl_pose(
105 | updated_params.delta_root_orient,
106 | updated_params.delta_body_pose,
107 | )
108 | s.trans = updated_params.trans
109 | s.hidden = updated_params.hidden
110 | s.pred_3d = self.smpl_decoder(
111 | s.betas, s.pose, s.trans, output_format="joints_face"
112 | )
113 | s.pred_2d = helper.project_to_2d(K, s.pred_3d)
114 |
115 | # Detect shot changes
116 | image_np = image.detach().to("cpu").permute(1, 2, 0).numpy()
117 | image_np = image_np[:, :, ::-1] # RGB2BGR
118 | shots = self.shot_detector.process_frame(self.frame_count, image_np)
119 | self.frame_count += 1
120 | is_new_shot = len(shots) > 1 if self.frame_count > 1 else False
121 | return nms_out, self.handler.update(
122 | nms_out, call_update, shot_reset=is_new_shot
123 | )
124 |
--------------------------------------------------------------------------------
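
A minimal, illustrative driver for the CoMotion module above (this is not the
repository's demo script). It assumes the pretrained checkpoints referenced by
detect.py and refine.py have already been downloaded, and it assumes
init_tracks expects the input resolution as (height, width); both are
assumptions of this sketch:

    import torch

    from comotion_demo.models import CoMotion

    model = CoMotion(use_coreml=False, pretrained=True).eval()
    model.init_tracks(image_res=(720, 1280))  # assumed (height, width) order

    image = torch.randint(0, 256, (3, 720, 1280), dtype=torch.uint8)  # C x H x W
    K = torch.tensor([[1000.0, 0.0, 640.0],
                      [0.0, 1000.0, 360.0]])                          # 2 x 3 intrinsics

    detections = model(image, K, detection_only=True)  # candidate poses only
    detections, tracks = model(image, K)               # detection + track update
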
/src/comotion_demo/models/detect.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | import os
3 | from collections import namedtuple
4 | from dataclasses import dataclass
5 |
6 | import einops as eo
7 | import torch
8 | from torch import nn
9 |
10 | from ..utils import helper, smpl_kinematics
11 | from . import backbones, layers
12 |
13 | curr_dir = os.path.abspath(os.path.dirname(__file__))
14 | PYTORCH_CHECKPOINT_PATH = f"{curr_dir}/../data/comotion_detection_checkpoint.pt"
15 | COREML_CHECKPOINT_PATH = f"{curr_dir}/../data/comotion_detection.mlpackage"
16 |
17 | DetectionOutput = namedtuple(
18 | "DetectOutput",
19 | [
20 | "image_features",
21 | "betas",
22 | "delta_root_orient",
23 | "delta_body_pose",
24 | "trans",
25 | "conf",
26 | ],
27 | )
28 |
29 |
30 | @dataclass
31 | class CoMotionDetectConfig:
32 | backbone_choice: str = "ConvNextV2"
33 | pose_embed_dim: int = 256
34 | hidden_dim: int = 512
35 | rot_embed_dim: int = 8
36 |
37 |
38 | class DetectionHead(nn.Module):
39 | """CoMotion detection head.
40 |
41 | Accepts as input image features and an intrinsics matrix to produce a large
42 | pool of candidate SMPL poses and corresponding confidences.
43 | """
44 |
45 | def __init__(
46 | self,
47 | input_dim,
48 | output_split,
49 | hidden_dim=512,
50 | num_slots=4,
51 | depth_adj=128,
52 | dropout=0.0,
53 | ):
54 | """Initialize CoMotion detection head."""
55 | super().__init__()
56 |
57 | self.input_dim = input_dim
58 | self.hidden_dim = hidden_dim
59 | self.num_slots = num_slots
60 | self.output_split = output_split
61 | self.depth_adj = depth_adj
62 | self.dropout = dropout
63 |
64 | # Blocks for feature pyramid encoding.
65 | self.enc1 = nn.Sequential(
66 | layers.DownsampleConvNextBlock(input_dim, hidden_dim),
67 | layers.DownsampleConvNextBlock(hidden_dim, 2 * hidden_dim, dropout=dropout),
68 | )
69 | self.enc2 = layers.DownsampleConvNextBlock(2 * hidden_dim, dropout=dropout)
70 | self.enc3 = layers.DownsampleConvNextBlock(2 * hidden_dim, dropout=dropout)
71 |
72 | out_dim = sum(self.output_split) * self.num_slots
73 | self.decoders = nn.ModuleList(
74 | [
75 | nn.Sequential(
76 | nn.Conv2d(2 * hidden_dim, 2 * hidden_dim, 1),
77 | layers.LayerNorm2d(2 * hidden_dim),
78 | nn.ReLU(),
79 | nn.Conv2d(2 * hidden_dim, out_dim, 1),
80 | )
81 | for _ in range(3)
82 | ]
83 | )
84 |
85 | def forward(self, px_feats, K, pooling=8, return_feats=False):
86 | """Detect poses from image features and intrinsics.
87 |
88 | Args:
89 | ----
90 | px_feats: image features from the backbone network.
91 | K: image intrinsics.
92 | pooling: downsampling factor from input image to features.
93 | return_feats: return image feature pyramid.
94 |
95 | Return:
96 | ------
97 | detections: 1344 candidates = 1024 + 256 + 64 from 3 levels.
98 | feat_pyramid: image feature pyramid.
99 |
100 | """
101 | # Rescale factor
102 | ht, wd = px_feats.shape[-2:]
103 | rescale = max(ht, wd) * pooling
104 | calib_adj = K[..., 0, 0][:, None, None]
105 |
106 | # Apply encoding and decoding layers
107 | x1 = self.enc1(px_feats)
108 | x2 = self.enc2(x1)
109 | x3 = self.enc3(x2)
110 | feat_pyramid = [x1, x2, x3]
111 |
112 | # Post-process to produce full set of detections
113 | pred_state = []
114 | for pyr_idx, feats in enumerate(feat_pyramid):
115 | feats = feats.float()
116 |
117 | ht, wd = feats.shape[-2:]
118 | ref_xy = helper.get_grid(ht, wd, device=feats.device) + 0.5 / max(ht, wd)
119 | ref_xy *= rescale
120 | ref_xy_slots = eo.repeat(ref_xy, "h w c -> (h w n) c", n=self.num_slots)
121 |
122 | pred = self.decoders[pyr_idx](feats)
123 | pred = eo.rearrange(pred, "b (n c) h w -> b (h w n) c", n=self.num_slots)
124 |
125 | (
126 | init_betas,
127 | init_pose,
128 | init_rot,
129 | init_xy,
130 | init_z,
131 | _,
132 | conf,
133 | ) = pred.split_with_sizes(dim=-1, split_sizes=self.output_split)
134 |
135 | # Adjust translation offset based on camera intrinsics
136 | default_depth = calib_adj / (self.depth_adj * 2**pyr_idx)
137 | init_z = default_depth / (torch.exp(init_z) + 0.05)
138 | init_xy = init_xy + ref_xy_slots
139 | init_xy = helper.px_to_world(K.unsqueeze(1), init_xy) * init_z
140 | init_trans = torch.cat([init_xy, init_z], -1)
141 |
142 | pred_state.append(
143 | {
144 | "betas": init_betas,
145 | "delta_root_orient": init_rot,
146 | "pose_embedding": init_pose,
147 | "trans": init_trans,
148 | "conf": conf,
149 | }
150 | )
151 |
152 | detections = {}
153 | for k in pred_state[0].keys():
154 | detections[k] = torch.cat([p[k] for p in pred_state], 1).float()
155 |
156 | if return_feats:
157 | return detections, feat_pyramid
158 | else:
159 | return detections
160 |
161 |
162 | class CoMotionDetect(nn.Module):
163 | """CoMotion detection module.
164 |
165 | Module responsible for initial feature extraction from a ConvNext backbone
166 | as well as producing candidate per-frame detections.
167 | """
168 |
169 | def __init__(
170 | self, cfg: CoMotionDetectConfig | None = None, pretrained: bool = True
171 | ):
172 | """Initialize CoMotion detection module.
173 |
174 | Args:
175 | ----
176 | cfg: Detection config defining various model hyperparameters.
177 | pretrained: Whether to load pretrained detection checkpoint.
178 |
179 | """
180 | super().__init__()
181 | cfg = CoMotionDetectConfig() if cfg is None else cfg
182 |
183 | self.cfg = cfg
184 | self.pos_embedding = layers.PosEmbed()
185 | self.rot_embedding = layers.RotaryEmbed(cfg.rot_embed_dim)
186 | self.kn = smpl_kinematics.SMPLKinematics()
187 | self.body_keys = ["betas", "pose", "trans"]
188 |
189 | # Instantiate an image backbone network
190 | self.image_backbone = backbones.initialize(cfg.backbone_choice)
191 | self.feat_dim = self.image_backbone.output_dim
192 |
193 | # Detection head
194 | # Output split: betas, pose embedding, root_orient, xy, z, scale, confidence
195 | output_split = [smpl_kinematics.BETA_DOF, cfg.pose_embed_dim, 3, 2, 1, 1, 1]
196 | self.detection_head = DetectionHead(
197 | input_dim=self.feat_dim,
198 | hidden_dim=self.cfg.hidden_dim,
199 | output_split=output_split,
200 | )
201 |
202 | self.pose_decoder = nn.Sequential(
203 | nn.LayerNorm(cfg.pose_embed_dim),
204 | nn.Linear(cfg.pose_embed_dim, cfg.hidden_dim),
205 | layers.ResidualMLP(cfg.hidden_dim),
206 | nn.LayerNorm(cfg.hidden_dim),
207 | nn.GELU(),
208 | nn.Linear(cfg.hidden_dim, smpl_kinematics.POSE_DOF - 3),
209 | )
210 |
211 | self.fuse_features = layers.FusePyramid(
212 | self.feat_dim,
213 | 2 * cfg.hidden_dim,
214 | )
215 |
216 | if pretrained:
217 | checkpoint = torch.load(PYTORCH_CHECKPOINT_PATH, weights_only=True)
218 | self.load_state_dict(checkpoint)
219 |
220 | def _intrinsics_conditioning(self, feats, K, pooling=8, normalize_factor=1024):
221 | """Condition image features on pixel and world coordinate mapping.
222 |
223 | From input intrinsics matrix we define a coordinate grid and use rotary
224 | embeddings to update the extracted image features.
225 | """
226 | batch_size = feats.shape[0]
227 | feats = feats.clone()
228 | device = feats.device
229 |
230 | with torch.no_grad():
231 | # Get reference pixel positions
232 | ht, wd = feats.shape[-2:]
233 | ref_xy = helper.get_grid(ht, wd, device=device) + 0.5 / max(ht, wd)
234 | ref_xy *= max(ht, wd) * pooling
235 |
236 | # Reduce scale of pixel values
237 | K = K / normalize_factor
238 | ref_xy = ref_xy / normalize_factor
239 |
240 | # Adjust into world coordinate frame
241 | ref_xy_world = helper.px_to_world(K[:, None, None], ref_xy[None])
242 |
243 | # Get rotary embeddings
244 | xy_cs, xy_sn = self.rot_embedding(ref_xy)
245 | xy_world_cs, xy_world_sn = self.rot_embedding(ref_xy_world)
246 |
247 | # Rearrange from BHWC -> BCHW
248 | xy_cs = eo.repeat(xy_cs, f"h w d0 d1 -> {batch_size} (d0 d1) h w")
249 | xy_sn = eo.repeat(xy_sn, f"h w d0 d1 -> {batch_size} (d0 d1) h w")
250 | xy_world_cs = eo.rearrange(xy_world_cs, "b h w d0 d1 -> b (d0 d1) h w")
251 | xy_world_sn = eo.rearrange(xy_world_sn, "b h w d0 d1 -> b (d0 d1) h w")
252 |
253 | # Apply rotary embeddings
254 | cs = torch.cat([xy_cs, xy_world_cs, xy_cs, xy_world_cs], 1)
255 | sn = torch.cat([xy_sn, xy_world_sn, xy_sn, xy_world_sn], 1)
256 | embed_dim = sn.shape[1]
257 | f0 = feats[:, :embed_dim]
258 | f0 = layers.apply_rotary_pos_emb(f0, cs, sn)
259 | feats = torch.cat([f0, feats[:, embed_dim:]], 1)
260 |
261 | return feats
262 |
263 | @torch.inference_mode
264 | def forward(self, img, K) -> DetectionOutput:
265 | """Extract backbone features and detect poses.
266 |
267 | Args:
268 | ----
269 | img: input image tensor of shape (B, 3, 512, 512)
270 | K: input intrinsic tensor of shape (B, 2, 3)
271 |
272 | Return:
273 | ------
274 | DetectionOutput: NamedTuple that includes all output detection parameters.
275 |
276 | """
277 | outputs = {}
278 |
279 | # Get backbone features
280 | feats = self.image_backbone(img)
281 | feats = self._intrinsics_conditioning(feats, K)
282 |
283 | # Get detections
284 | detections, feature_pyramid = self.detection_head(feats, K, return_feats=True)
285 | for k, v in detections.items():
286 | if k == "pose_embedding":
287 | # Remap latent pose embedding to joint angles
288 | # Note: these are residual terms applied to a default pose
289 | outputs["delta_body_pose"] = self.pose_decoder(v) * 0.3
290 | else:
291 | outputs[k] = v
292 |
293 | # Fuse feature pyramid
294 | feature_pyramid = [feats] + feature_pyramid
295 | outputs["image_features"] = self.fuse_features(*feature_pyramid)
296 |
297 | return DetectionOutput(**outputs)
298 |
299 |
300 | class CoMotionDetectCoreML:
301 | """A CoreML wrapper for CoMotion detection module."""
302 |
303 | def __init__(self):
304 | """Initialize the CoreML model."""
305 | import coremltools as ct
306 |
307 | self.model = ct.models.MLModel(COREML_CHECKPOINT_PATH)
308 | self.cfg = CoMotionDetectConfig()
309 | self.feat_dim = 256
310 |
311 | def __call__(self, img, K) -> DetectionOutput:
312 | """Run inference for the CoreML model."""
313 | outputs = self.model.predict({"image": img, "K": K})
314 | outputs = {k: torch.tensor(v) for k, v in outputs.items()}
315 | return DetectionOutput(**outputs)
316 |
317 |
318 | def get_smpl_pose(delta_root_orient, delta_body_pose):
319 | """Apply predicted delta poses to default mean pose."""
320 | device = delta_root_orient.device
321 | default_pose = smpl_kinematics.extra_ref["mean_pose"].clone().to(device)
322 | delta_pose = torch.cat([delta_root_orient, delta_body_pose], -1)
323 | return smpl_kinematics.update_pose(default_pose, delta_pose)
324 |
325 |
326 | def decode_network_outputs(
327 | K: torch.Tensor,
328 | smpl_decoder: smpl_kinematics.SMPLKinematics,
329 | detect_out: DetectionOutput,
330 | sample_idx: int = 0,
331 | **nms_kwargs,
332 | ):
333 | """Postprocessing to get detections from network output."""
334 | # Decode output SMPL pose and joint coordinates
335 | pose = get_smpl_pose(detect_out.delta_root_orient, detect_out.delta_body_pose)
336 |
337 | pred_3d = smpl_decoder.forward(
338 | detect_out.betas,
339 | pose,
340 | detect_out.trans,
341 | output_format="joints_face",
342 | )
343 | pred_2d = helper.project_to_2d(K, pred_3d)
344 |
345 | # Perform non-maximum suppression
346 | nms_idxs = helper.nms_detections(
347 | pred_2d[sample_idx] / 1024, detect_out.conf[sample_idx].flatten(), **nms_kwargs
348 | )
349 |
350 | detections = {
351 | "betas": detect_out.betas, # (1, 1344, 10)
352 | "pose": pose, # (1, 1344, 72)
353 | "trans": detect_out.trans, # (1, 1344, 3)
354 | "pred_3d": pred_3d, # (1, 1344, 27, 3)
355 | "pred_2d": pred_2d, # (1, 1344, 27, 2)
356 | "conf": detect_out.conf, # (1, 1344, 1)
357 | }
358 |
359 | # Index into selected subset of detections, add singleton batch dimension
360 | for k, v in detections.items():
361 | detections[k] = v[sample_idx, nms_idxs][None]
362 |
363 | return detections
364 |
--------------------------------------------------------------------------------
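
The "1344 candidates" figure in the DetectionHead docstring above is simply the
per-level grid sizes times the slot count: for a 512x512 input the fused
features are 64x64, enc1/enc2/enc3 reduce them to 16x16, 8x8 and 4x4, and every
cell at every level carries num_slots = 4 pose hypotheses:

    levels = [16, 8, 4]    # spatial side lengths of the three decoder inputs
    num_slots = 4
    per_level = [side * side * num_slots for side in levels]
    print(per_level, sum(per_level))  # [1024, 256, 64] 1344
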
/src/comotion_demo/models/layers.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | import einops as eo
3 | import torch
4 | import transformers
5 | from timm.layers import LayerNorm2d
6 | from timm.models.convnext import ConvNeXtBlock
7 | from torch import nn
8 | from torch.functional import F
9 | from transformers.models.llama import modeling_llama
10 |
11 | from ..utils import helper
12 |
13 |
14 | class PosEmbed(nn.Module):
15 | def __init__(self, embed_dim=16):
16 | super().__init__()
17 | self.register_buffer(
18 | "freq_bands", 2 ** torch.linspace(0, embed_dim - 1, embed_dim)
19 | )
20 | self.out_dim = 2 * (1 + len(self.freq_bands))
21 |
22 | def forward(self, x):
23 | # Input: ... x 2
24 | emb = torch.sin(x.unsqueeze(-1) * self.freq_bands)
25 | emb = eo.rearrange(emb, "... d0 d1 -> ... (d0 d1)")
26 |
27 | return torch.cat([x, emb], -1)
28 |
29 |
30 | class RotaryEmbed(nn.Module):
31 | def __init__(self, dim=32, rescale=1000, base=10000):
32 | super().__init__()
33 | self.dim = dim
34 | self.rescale = rescale
35 | self.base = base
36 |
37 | inv_freq = rescale / (base ** (torch.arange(0, dim).float() / dim))
38 | self.register_buffer("inv_freq", inv_freq)
39 | self.out_dim = dim
40 |
41 | def forward(self, x):
42 | """Apply frequency band and return cos and sin terms."""
43 | emb = x.unsqueeze(-1) * self.inv_freq
44 | return emb.cos(), emb.sin()
45 |
46 |
47 | def apply_rotary_pos_emb(x, cs, sn):
48 | half_channels = x.shape[1] // 2
49 | x0 = x[:, :half_channels]
50 | x1 = x[:, half_channels:]
51 | x_ = torch.cat((-x1, x0), dim=1)
52 | return cs * x + sn * x_
53 |
54 |
55 | def _gru_update(h, z, q):
56 | return (1 - z) * h + z * q
57 |
58 |
59 | class GRU(nn.Module):
60 | def __init__(self, hdim=128, f_in=128):
61 | super().__init__()
62 | self.fz = nn.Linear(hdim + f_in, hdim)
63 | self.fr = nn.Linear(hdim + f_in, hdim)
64 | self.fq = nn.Linear(hdim + f_in, hdim)
65 |
66 | def forward(self, h, x):
67 | hx = torch.cat([h, x], dim=-1)
68 |
69 | z = torch.sigmoid(self.fz(hx))
70 | r = torch.sigmoid(self.fr(hx))
71 | q = torch.tanh(self.fq(torch.cat([r * h, x], dim=-1)))
72 |
73 | h = _gru_update(h, z, q)
74 | return h
75 |
76 |
77 | class ResidualMLP(nn.Module):
78 | """Residual Multilayer Perceptron."""
79 |
80 | def __init__(self, dim, out_dim=None, hidden_dim=None, pre_ln=True):
81 | super().__init__()
82 |
83 | if out_dim is None:
84 | out_dim = dim
85 | if hidden_dim is None:
86 | hidden_dim = 2 * dim
87 | self.is_residual = dim == out_dim
88 | self.pre_ln = pre_ln
89 |
90 | if pre_ln:
91 | self.ln0 = nn.LayerNorm(dim)
92 | self.l1 = nn.Linear(dim, hidden_dim)
93 | self.ln1 = nn.LayerNorm(hidden_dim)
94 | self.act = nn.GELU()
95 | self.l2 = nn.Linear(hidden_dim, out_dim)
96 |
97 | def forward(self, x):
98 | inp = x
99 | if self.pre_ln:
100 | x = self.ln0(x)
101 | x = self.l1(x)
102 | x = self.act(self.ln1(x))
103 | x = self.l2(x)
104 |
105 | if self.is_residual:
106 | return inp + x
107 | else:
108 | return x
109 |
110 |
111 | class DownsampleConvNextBlock(nn.Module):
112 | """Module to downsample 2x and apply ConvNext layers.
113 |
114 | Note we assume the input has already had normalization applied,
115 | and apply LayerNorm as the last operation.
116 | """
117 |
118 | def __init__(self, input_dim, output_dim=None, dropout=0.0):
119 | super().__init__()
120 |
121 | if output_dim is None:
122 | output_dim = input_dim
123 |
124 | self.layers = nn.Sequential(
125 | nn.Conv2d(input_dim, output_dim, 2, 2, 0), # Downsample
126 | ConvNeXtBlock(output_dim),
127 | LayerNorm2d(output_dim),
128 | nn.Dropout(p=dropout),
129 | )
130 |
131 | def forward(self, x):
132 | return self.layers(x)
133 |
134 |
135 | @helper.fixed_dim_op(d=3)
136 | def _call_llama(x, tf):
137 | kargs = {
138 | "position_embeddings": (
139 | torch.ones_like(x[..., :1]),
140 | torch.zeros_like(x[..., :1]),
141 | )
142 | }
143 | for tf_ in tf:
144 | x = tf_(x, **kargs)[0]
145 | return x
146 |
147 |
148 | class DecodeFromTokens(nn.Module):
149 | """Decode token into hidden update."""
150 |
151 | def __init__(self, hidden_dim):
152 | super().__init__()
153 |
154 | self.ln = nn.LayerNorm(hidden_dim)
155 | self.token_weight = nn.Linear(hidden_dim, hidden_dim)
156 | self.token_to_value = nn.Linear(hidden_dim, hidden_dim)
157 | self.decoder = ResidualMLP(hidden_dim)
158 |
159 | def forward(self, x):
160 | x = self.ln(x)
161 | v = self.token_to_value(x)
162 | w = torch.sigmoid(self.token_weight(x))
163 | v = v * w
164 | x = v.mean(-2)
165 | x = self.decoder(x)
166 |
167 | return x
168 |
169 |
170 | class CrossAttention(nn.Module):
171 | """Cross attention module."""
172 |
173 | def __init__(
174 | self,
175 | num_tokens,
176 | num_heads,
177 | hidden_dim,
178 | token_dim,
179 | num_layers=2,
180 | drop_rate=0.1,
181 | use_global_attention=2,
182 | ):
183 | super().__init__()
184 | self.num_tokens = num_tokens
185 | self.num_heads = num_heads
186 | self.hidden_dim = hidden_dim
187 | self.token_dim = token_dim
188 | self.num_layers = num_layers
189 | self.drop_rate = drop_rate
190 | self.use_global_attention = use_global_attention
191 |
192 | self.token_to_query = nn.Sequential(
193 | ResidualMLP(self.hidden_dim),
194 | nn.LayerNorm(self.hidden_dim),
195 | nn.Linear(self.hidden_dim, self.token_dim),
196 | )
197 |
198 | self.post_attention = nn.Sequential(
199 | ResidualMLP(self.token_dim),
200 | nn.LayerNorm(self.token_dim),
201 | nn.Linear(self.token_dim, self.hidden_dim),
202 | )
203 |
204 | cfg = transformers.LlamaConfig()
205 | cfg.hidden_size = self.hidden_dim
206 | cfg.intermediate_size = self.hidden_dim * 2
207 | cfg.num_attention_heads = 16
208 | cfg.num_key_value_heads = 16
209 | cfg.attention_dropout = self.drop_rate
210 |
211 | if self.use_global_attention > 0:
212 | self.global_attention = nn.ModuleList(
213 | [
214 | modeling_llama.LlamaDecoderLayer(cfg, i)
215 | for i in range(self.num_layers)
216 | ]
217 | )
218 |
219 | self.indiv_attention = nn.ModuleList(
220 | [modeling_llama.LlamaDecoderLayer(cfg, i) for i in range(self.num_layers)]
221 | )
222 |
223 | def forward(self, image_key, image_value, tokens):
224 | q = self.token_to_query(tokens)
225 | q = eo.rearrange(
226 | q, "b n d0 (h d1) -> b h (n d0) d1", h=self.num_heads
227 | ).contiguous()
228 | px_feedback = F.scaled_dot_product_attention(q, image_key, image_value)
229 | px_feedback = eo.rearrange(
230 | px_feedback, "b h (n d0) d1 -> b n d0 (h d1)", d0=self.num_tokens
231 | )
232 | tokens = tokens + self.post_attention(px_feedback)
233 |
234 | # Attention across all people
235 | if self.use_global_attention == 1:
236 | # Concatenate all tokens together
237 | tokens = eo.rearrange(tokens, "b n d0 d1 -> b (n d0) d1")
238 | tokens = _call_llama(tokens, self.global_attention)
239 | tokens = eo.rearrange(
240 | tokens, "b (n d0) d1 -> b n d0 d1", d0=self.num_tokens
241 | )
242 | elif self.use_global_attention == 2:
243 | # Do attention over people separately per-token
244 | tokens = eo.rearrange(tokens, "b n d0 d1 -> b d0 n d1")
245 | tokens = _call_llama(tokens, self.global_attention)
246 | tokens = eo.rearrange(tokens, "b d0 n d1 -> b n d0 d1")
247 |
248 | # Separate attention update per-person
249 | tokens = _call_llama(tokens, self.indiv_attention)
250 |
251 | return tokens
252 |
253 |
254 | class FusePyramid(nn.Module):
255 | """Module for fusing the feature pyramid."""
256 |
257 | def __init__(self, in_dim=256, hidden_dim=512):
258 | """Initialize FusePyramid module."""
259 | super().__init__()
260 |
261 | self.dc0 = nn.ConvTranspose2d(hidden_dim, hidden_dim, 2, 2)
262 | self.proj1 = nn.Conv2d(hidden_dim, hidden_dim, 1)
263 | self.dc1 = nn.ConvTranspose2d(hidden_dim, hidden_dim, 2, 2)
264 | self.ln1 = LayerNorm2d(hidden_dim)
265 | self.proj2 = nn.Conv2d(hidden_dim, hidden_dim, 1)
266 | self.dc2 = nn.ConvTranspose2d(hidden_dim, hidden_dim, 2, 2)
267 | self.ln2 = LayerNorm2d(hidden_dim)
268 | self.dc3 = nn.ConvTranspose2d(hidden_dim, in_dim, 4, 4)
269 | self.proj3 = nn.Conv2d(in_dim, in_dim, 1)
270 | self.ln3 = LayerNorm2d(in_dim)
271 |
272 | @helper.fixed_dim_op(nargs=4, is_class_fn=True)
273 | def forward(self, f64, f16, f8, f4):
274 | """Aggregate features across feature pyramid.
275 |
276 | Args:
277 | ----
278 |         f64: features of shape (B, in_dim, 64, 64), assuming input res is 512x512.
279 |         f16: features of shape (B, hidden_dim, 16, 16)
280 |         f8: features of shape (B, hidden_dim, 8, 8)
281 |         f4: features of shape (B, hidden_dim, 4, 4)
282 |
283 | Return:
284 | ------
285 | output features of shape (B, in_dim, 64, 64) with the same res as f64.
286 |
287 | """
288 | x = self.dc0(f4) + self.proj1(f8)
289 | x = self.ln1(F.gelu(x))
290 | x = self.dc1(x) + self.proj2(f16)
291 | x = self.ln2(F.gelu(x))
292 | x = self.dc3(x) + self.proj3(f64)
293 | x = self.ln3(x)
294 |
295 | return x
296 |
--------------------------------------------------------------------------------
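
RotaryEmbed and apply_rotary_pos_emb above are used by detect.py's
_intrinsics_conditioning to mix pixel- and world-coordinate phases into the
leading feature channels. A standalone shape check, assuming the package and
its dependencies are importable (the channel count is kept small for
illustration):

    import einops as eo
    import torch

    from comotion_demo.models.layers import RotaryEmbed, apply_rotary_pos_emb

    rot = RotaryEmbed(dim=8)
    ref_xy = torch.rand(64, 64, 2)             # per-pixel (x, y) coordinates
    cs, sn = rot(ref_xy)                       # each of shape (64, 64, 2, 8)

    cs = eo.repeat(cs, "h w d0 d1 -> 1 (d0 d1) h w")  # (1, 16, 64, 64)
    sn = eo.repeat(sn, "h w d0 d1 -> 1 (d0 d1) h w")

    feats = torch.randn(1, 16, 64, 64)         # channels must equal (d0 d1) = 16
    rotated = apply_rotary_pos_emb(feats, cs, sn)
    print(rotated.shape)                       # torch.Size([1, 16, 64, 64])
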
/src/comotion_demo/models/refine.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | import os
3 | from collections import namedtuple
4 | from dataclasses import dataclass
5 |
6 | import einops as eo
7 | import torch
8 | from torch import nn
9 |
10 | from ..utils import helper, smpl_kinematics
11 | from . import layers
12 |
13 | curr_dir = os.path.abspath(os.path.dirname(__file__))
14 | PYTORCH_CHECKPOINT_PATH = f"{curr_dir}/../data/comotion_refine_checkpoint.pt"
15 |
16 | RefineOutput = namedtuple(
17 | "RefineOutput",
18 | [
19 | "delta_root_orient",
20 | "delta_body_pose",
21 | "trans",
22 | "hidden",
23 | ],
24 | )
25 |
26 |
27 | @dataclass
28 | class PoseRefinementConfig:
29 | num_tokens: int = 24
30 | num_heads: int = 4
31 | token_dim: int = 256
32 | hidden_dim: int = 512
33 | pose_embed_dim: int = 256
34 | normalizing_factor: int = 1024
35 |
36 |
37 | class PoseRefinement(nn.Module):
38 | """CoMotion refinement module."""
39 |
40 | def __init__(
41 | self,
42 | image_feat_dim: int,
43 | pose_embed_dim: int,
44 | cfg: PoseRefinementConfig | None = None,
45 | pretrained: bool = True,
46 | ):
47 | """Initialize refinement module.
48 |
49 | Args:
50 | ----
51 | image_feat_dim: image backbone output feature dimension.
52 | pose_embed_dim: pose embedding dimension.
53 | cfg: pose refinement config.
54 | pretrained: whether to load pre-trained weights.
55 |
56 | """
57 | super().__init__()
58 | cfg = PoseRefinementConfig() if cfg is None else cfg
59 | self.cfg = cfg
60 | self.image_feat_dim = image_feat_dim
61 | self.pose_embed_dim = pose_embed_dim
62 | self.pos_embedding = layers.PosEmbed()
63 | self.encode_grid = nn.Linear(self.pos_embedding.out_dim, cfg.token_dim)
64 | token_kv_in, token_kv_out = image_feat_dim + cfg.token_dim, cfg.token_dim
65 |
66 | self.get_px_key = nn.Linear(token_kv_in, token_kv_out)
67 | self.get_px_value = nn.Linear(token_kv_in, token_kv_out)
68 |
69 | # Output reference
70 | self.split_ref = [
71 | smpl_kinematics.BETA_DOF,
72 | pose_embed_dim,
73 | 3,
74 | smpl_kinematics.TRANS_DOF,
75 | ]
76 |
77 | # Hidden update layers
78 | self.tokens_to_hidden = layers.DecodeFromTokens(cfg.hidden_dim)
79 | self.feedback_gru_update = layers.GRU(cfg.hidden_dim, cfg.hidden_dim)
80 |
81 | def init_token_encoder(in_dim):
82 | return layers.ResidualMLP(
83 | in_dim,
84 | out_dim=cfg.num_tokens * cfg.hidden_dim,
85 | hidden_dim=2 * cfg.hidden_dim,
86 | pre_ln=False,
87 | )
88 |
89 | # Convert coordinates and hidden state to tokens
90 | cd_dim = smpl_kinematics.NUM_PARTS_AUX * (self.pos_embedding.out_dim + 1)
91 | dof = (
92 | smpl_kinematics.BETA_DOF
93 | + smpl_kinematics.POSE_DOF
94 | + smpl_kinematics.TRANS_DOF
95 | )
96 |
97 | self.encode_coords = init_token_encoder(cd_dim)
98 | self.encode_hidden = init_token_encoder(cfg.hidden_dim)
99 | self.encode_smpl = init_token_encoder(dof)
100 |
101 | cross_attention_kargs = {
102 | "num_tokens": cfg.num_tokens,
103 | "num_heads": cfg.num_heads,
104 | "hidden_dim": cfg.hidden_dim,
105 | "token_dim": cfg.token_dim,
106 | }
107 |
108 | self.cross_attention = nn.Sequential(
109 | layers.CrossAttention(**cross_attention_kargs),
110 | layers.CrossAttention(**cross_attention_kargs),
111 | )
112 |
113 | decode = layers.DecodeFromTokens
114 |
115 | self.get_px_feedback = decode(cfg.hidden_dim)
116 | self.tokens_to_hidden = nn.Sequential(
117 | decode(cfg.hidden_dim),
118 | nn.LayerNorm(cfg.hidden_dim),
119 | )
120 | self.project_hidden = nn.Linear(cfg.hidden_dim, cfg.hidden_dim)
121 |
122 | self.get_pose_update = nn.Sequential(
123 | layers.ResidualMLP(cfg.hidden_dim),
124 | nn.LayerNorm(cfg.hidden_dim),
125 | nn.GELU(),
126 | nn.Linear(cfg.hidden_dim, sum(self.split_ref)),
127 | )
128 |
129 | # This is the same module used in the detection step (with identical weights)
130 | # We reinstantiate it to support using the separate CoreML detection stage
131 | self.pose_decoder = nn.Sequential(
132 | nn.LayerNorm(cfg.pose_embed_dim),
133 | nn.Linear(cfg.pose_embed_dim, cfg.hidden_dim),
134 | layers.ResidualMLP(cfg.hidden_dim),
135 | nn.LayerNorm(cfg.hidden_dim),
136 | nn.GELU(),
137 | nn.Linear(cfg.hidden_dim, smpl_kinematics.POSE_DOF - 3),
138 | )
139 |
140 | if pretrained:
141 | checkpoint = torch.load(PYTORCH_CHECKPOINT_PATH, weights_only=True)
142 | self.load_state_dict(checkpoint)
143 |
144 | @helper.fixed_dim_op(d=4, is_class_fn=True)
145 | def compute_image_kv(self, image_feats, pooling=8):
146 | """From image features, get flattened set of key, value token pairs."""
147 | batch_size = image_feats.shape[0]
148 | res = image_feats.shape[-2:]
149 | device = image_feats.device
150 |
151 | # Rearrange and flatten image features
152 | image_feats = eo.rearrange(image_feats, "b c h w -> b (h w) c")
153 |
154 | # Calculate reference grid of pixel positions
155 | with torch.no_grad():
156 | px_scale_factor = max(res) * pooling
157 | grid = helper.get_grid(res[0], res[1], device)
158 | grid = grid * px_scale_factor / self.cfg.normalizing_factor
159 |
160 | grid_embed = self.encode_grid(self.pos_embedding(grid))
161 | grid_embed = eo.repeat(grid_embed, "h w d -> b (h w) d", b=batch_size)
162 | image_feats = torch.cat([image_feats, grid_embed], -1)
163 |
164 | # Apply a linear layer to get keys and values
165 | pixel_k = eo.rearrange(
166 | self.get_px_key(image_feats),
167 | "... d0 (h d1) -> ... h d0 d1",
168 | h=self.cfg.num_heads,
169 | ).contiguous()
170 | pixel_v = eo.rearrange(
171 | self.get_px_value(image_feats),
172 | "... d0 (h d1) -> ... h d0 d1",
173 | h=self.cfg.num_heads,
174 | ).contiguous()
175 |
176 | return pixel_k, pixel_v
177 |
178 | def calib_adjusted_trans(self, K, trans, delta_trans, depth_adj=128, eps=0.05):
179 | """Update SMPL translation term based on provided intrinsics.
180 |
181 | This same operation is performed during detection, the output x, y
182 | terms are in pixel space and mapped to world coordinates using K.
183 | """
184 | delta_xy, delta_z = delta_trans.split_with_sizes(dim=-1, split_sizes=[2, 1])
185 |
186 | # Get new depth estimate
187 | default_depth = K[..., 0, 0][:, None, None] / depth_adj
188 | z = default_depth / (torch.exp(delta_z) + eps)
189 |
190 | # Apply delta to current x, y position in pixel space
191 | base_xy = helper.project_to_2d(K.unsqueeze(1), trans)
192 | xy = helper.px_to_world(K.unsqueeze(1), base_xy + delta_xy) * z
193 | return torch.cat([xy, z], -1)
194 |
195 | def encode_state(
196 | self,
197 | K,
198 | hidden,
199 | pred_3d,
200 | body_params,
201 | ):
202 | """Encode tracks into tokens.
203 |
204 | Hidden state, SMPL parameters, and 2D keypoint coordinates are all
205 | passed through an MLP to produce a set of tokens per-person.
206 | """
207 | # Project to 2D
208 | K = K[:, None, None] / self.cfg.normalizing_factor
209 | xy = helper.project_to_2d(K, pred_3d).clamp(-5, 5)
210 | z = pred_3d[..., 2]
211 | cd_embed = eo.rearrange(self.pos_embedding(xy), "... d0 d1 -> ... (d0 d1)")
212 | cd_embed = torch.cat([cd_embed, z], -1)
213 |
214 | # Encode tokens
215 | tokens = self.encode_coords(cd_embed)
216 | tokens = tokens + self.encode_hidden(hidden)
217 | tokens = tokens + self.encode_smpl(body_params)
218 | tokens = eo.rearrange(
219 | tokens, "... (d0 d1) -> ... d0 d1", d0=self.cfg.num_tokens
220 | )
221 |
222 | return tokens
223 |
224 | def perform_update(self, K, pixel_k, pixel_v, tokens, trans, hidden):
225 | """Attend to image features and calculate final outputs.
226 |
227 | Args:
228 | ----
229 | K: Input intrinsics
230 | pixel_k: Flattened set of image token keys.
231 | pixel_v: Flattened set of image token values.
232 | tokens: Feature encoding of current tracks.
233 | trans: SMPL translation term for each track.
234 | hidden: Hidden state for each track.
235 |
236 | """
237 | # Perform cross attention to get feedback from image features
238 | for ca in self.cross_attention:
239 | tokens = ca(pixel_k, pixel_v, tokens)
240 |
241 | # Update hidden state
242 | hidden_update = self.tokens_to_hidden(tokens)
243 | hidden = self.feedback_gru_update(hidden, hidden_update)
244 |
245 | # Update current state
246 | px_feedback = self.get_px_feedback(tokens)
247 | px_feedback = px_feedback + self.project_hidden(hidden)
248 |
249 | delta_smpl = self.get_pose_update(px_feedback)
250 | delta_smpl = delta_smpl.split_with_sizes(dim=-1, split_sizes=self.split_ref)
251 | _, pose_embedding, delta_root_orient, delta_trans = delta_smpl
252 |
253 | delta_body_pose = self.pose_decoder(pose_embedding) * 0.3
254 | trans = self.calib_adjusted_trans(K, trans, delta_trans)
255 | return RefineOutput(delta_root_orient, delta_body_pose, trans, hidden)
256 |
257 | def forward(self, image_feats, K, betas, pose, trans, pred_3d, hidden, pooling=8):
258 | """Predict new poses given image features and current tracks.
259 |
260 | Args:
261 | ----
262 | image_feats: Per-pixel image features.
263 | K: Input intrinsics.
264 | betas: SMPL beta parameters for each track.
265 | pose: SMPL pose parameters for each track.
266 | trans: SMPL translation term for each track.
267 | pred_3d: 3D keypoints in camera coordinate frame for each track.
268 | hidden: Hidden state for each track.
269 |             pooling: Downsampling factor of the image features relative to the input image.
270 |
271 | """
272 | body_params = torch.cat([betas, pose, trans], -1)
273 | pixel_k, pixel_v = self.compute_image_kv(image_feats, pooling=pooling)
274 | tokens = self.encode_state(K, hidden, pred_3d, body_params)
275 | return self.perform_update(K, pixel_k, pixel_v, tokens, trans, hidden)
276 |
--------------------------------------------------------------------------------
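
As a quick illustration of the depth and translation arithmetic in calib_adjusted_trans above, the sketch below restates the same update with plain torch for a single track, inlining helper.project_to_2d and helper.px_to_world. It is not part of the repository; the intrinsics, translation, and delta values are made up for the example, while depth_adj=128 and eps=0.05 match the function defaults.

import torch

# Illustrative 2x3 intrinsics and one track's current translation / predicted deltas.
K = torch.tensor([[1000.0, 0.0, 256.0],
                  [0.0, 1000.0, 256.0]])
trans = torch.tensor([0.1, 0.2, 3.0])         # current SMPL translation (x, y, z)
delta_trans = torch.tensor([5.0, -3.0, 0.1])  # predicted (dx_px, dy_px, dz)
depth_adj, eps = 128, 0.05

delta_xy, delta_z = delta_trans[:2], delta_trans[2:]

# New depth: the focal length sets a reference depth, which delta_z rescales.
default_depth = K[0, 0] / depth_adj
z = default_depth / (torch.exp(delta_z) + eps)

# Project the current position to pixels, nudge it by delta_xy, then unproject at the new depth.
base_xy = K @ torch.cat([trans[:2] / trans[2], torch.ones(1)])   # project_to_2d
xy = (base_xy + delta_xy - K[:, 2]) / K[[0, 1], [0, 1]] * z      # px_to_world * z

new_trans = torch.cat([xy, z])                # updated (x, y, z) translation
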
/src/comotion_demo/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | """CoMotion utility functions."""
3 |
--------------------------------------------------------------------------------
/src/comotion_demo/utils/dataloading.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | import logging
3 | from pathlib import Path
4 | from typing import Generator, Tuple
5 |
6 | import cv2
7 | import numpy as np
8 | import torch
9 | from numpy.typing import NDArray
10 | from PIL import Image
11 | from torchvision import transforms
12 |
13 | IMG_MEAN = torch.tensor([0.4850, 0.4560, 0.4060]).view(-1, 1, 1)
14 | IMG_STD = torch.tensor([0.2290, 0.2240, 0.2250]).view(-1, 1, 1)
15 | INTERNAL_RESOLUTION = (512, 512)
16 | VIDEO_EXTENSIONS = {".mp4"}
17 | IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}
18 |
19 |
20 | def normalize_image(image: torch.Tensor) -> torch.Tensor:
21 | """Apply ImageNet normalization to an image tensor."""
22 | if not isinstance(image, torch.Tensor):
23 | raise ValueError("Expect input to be a torch.Tensor")
24 | return (image - IMG_MEAN) / IMG_STD
25 |
26 |
27 | def unnormalize_image(image: torch.Tensor) -> torch.Tensor:
28 |     """Undo ImageNet normalization on an image tensor."""
29 | if not isinstance(image, torch.Tensor):
30 | raise ValueError("Expect input to be a torch.Tensor")
31 | return image * IMG_STD + IMG_MEAN
32 |
33 |
34 | def convert_image_to_tensor(image: NDArray[np.uint8]) -> torch.Tensor:
35 |     """Convert a uint8 numpy array of shape HWC to a uint8 tensor of shape CHW."""
36 | if not isinstance(image, np.ndarray):
37 | raise ValueError("Expect input to be a numpy array.")
38 | if image.dtype != np.uint8:
39 | raise ValueError("Expect input to be np.uint8 typed.")
40 | return torch.from_numpy(image).permute(2, 0, 1)
41 |
42 |
43 | def convert_tensor_to_image(tensor: torch.Tensor) -> NDArray[np.uint8]:
44 |     """Convert a float tensor of shape CHW to a uint8 numpy array of shape HWC."""
45 | if not isinstance(tensor, torch.Tensor):
46 | raise ValueError("Expect input to be a torch Tensor")
47 | return tensor.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
48 |
49 |
50 | def crop_image_and_update_K(
51 | image: torch.Tensor,
52 | K: torch.Tensor,
53 | target_resolution: Tuple[int, int] = INTERNAL_RESOLUTION,
54 | ) -> Tuple[torch.Tensor, torch.Tensor]:
55 | """Pad and resize image to target resolution and update intrinsics."""
56 | target_height, target_width = target_resolution
57 | target_hw_ratio = target_height / target_width
58 | source_height, source_width = torch.tensor(image.shape[-2:])
59 | source_hw_ratio = source_height / source_width
60 |
61 | if source_hw_ratio >= target_hw_ratio:
62 | # pad needed along the width axis
63 | crop_height = source_height
64 | crop_width = int(source_height / target_hw_ratio)
65 | else:
66 | # pad needed along the height axis
67 | crop_height = int(source_width * target_hw_ratio)
68 | crop_width = source_width
69 |
70 | img_center_x = source_width / 2
71 | img_center_y = source_height / 2
72 |
73 | offset_x = int(img_center_x - crop_width / 2)
74 | offset_y = int(img_center_y - crop_height / 2)
75 |
76 | crop_args = [
77 | offset_y,
78 | offset_x,
79 | crop_height,
80 | crop_width,
81 | target_resolution,
82 | transforms.InterpolationMode.BILINEAR,
83 | ]
84 | # Pad, crop, and resize image
85 | cropped_image = transforms.functional.resized_crop(
86 | image, *crop_args, antialias=True
87 | )
88 |
89 | scale_y = target_height / crop_height
90 | scale_x = target_width / crop_width
91 |
92 | cropped_K = K.clone()
93 | cropped_K[..., 0, 0] *= scale_x
94 | cropped_K[..., 1, 1] *= scale_y
95 | cropped_K[..., 0, 2] = (cropped_K[..., 0, 2] - offset_x) * scale_x
96 | cropped_K[..., 1, 2] = (cropped_K[..., 1, 2] - offset_y) * scale_y
97 |
98 | return cropped_image, cropped_K
99 |
100 |
101 | def yield_image_from_directory(
102 | directory: Path,
103 | start_frame: int,
104 | num_frames: int,
105 | frameskip: int = 1,
106 | ) -> Generator[NDArray[np.uint8], None, None]:
107 | """Generate the next frame from a directory."""
108 | if not directory.is_dir():
109 | raise ValueError(f"Path is not a directory: {directory}")
110 | image_files = sorted(
111 | [
112 | file
113 | for file in directory.glob("*")
114 | if file.is_file() and file.suffix.lower() in IMAGE_EXTENSIONS
115 | ]
116 | )
117 | if not image_files:
118 | raise ValueError(f"No images found in directory: {directory}")
119 |
120 | image_files = image_files[start_frame::frameskip]
121 |
122 | if len(image_files) > num_frames:
123 | image_files = image_files[:num_frames]
124 |
125 | for image_file in image_files:
126 | image = np.array(Image.open(image_file).convert("RGB"))
127 | yield image
128 |
129 |
130 | def yield_image_from_video(
131 | filepath: Path,
132 | start_frame: int,
133 | num_frames: int,
134 | frameskip: int = 1,
135 | ) -> Generator[NDArray[np.uint8], None, None]:
136 | """Generate the next frame from a video."""
137 | if filepath.suffix.lower() not in VIDEO_EXTENSIONS:
138 | raise ValueError(f"Input file is not a video: {filepath}")
139 |
140 | video = cv2.VideoCapture(filepath.as_posix())
141 |
142 | if not video.isOpened():
143 | raise ValueError(f"Could not open video file: {filepath}")
144 |
145 | max_frames = video.get(cv2.CAP_PROP_FRAME_COUNT)
146 | logging.info(
147 | f"Yielding {num_frames} frames from {filepath} ({max_frames} in total)."
148 | )
149 |
150 | if max_frames <= start_frame:
151 | logging.warning(f"Cannot start on frame {start_frame}.")
152 | video.release()
153 | return
154 |
155 | success = video.set(cv2.CAP_PROP_POS_FRAMES, start_frame)
156 | if success:
157 | logging.info(f"Starting from frame {start_frame}")
158 |
159 | frame_limit = min(num_frames, max_frames - start_frame)
160 | frame_count = 0
161 | while frame_count < frame_limit:
162 | success, frame = video.read()
163 | if not success:
164 | break
165 | image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
166 | yield image
167 | frame_count += 1
168 |
169 | # Skip intermediate frames if frameskip > 1
170 | for _ in range(1, frameskip):
171 | success, frame = video.read()
172 | if not success:
173 | break
174 |
175 | video.release()
176 |
177 |
178 | def is_a_video(input_path: Path) -> bool:
179 |     return input_path.suffix.lower() in VIDEO_EXTENSIONS
180 |
181 |
182 | def get_input_video_fps(input_path: Path) -> float:
183 | cap = cv2.VideoCapture(input_path.as_posix())
184 | if not cap.isOpened():
185 | raise RuntimeError(f"Failed to load video at {input_path}.")
186 |
187 | fps = cap.get(cv2.CAP_PROP_FPS)
188 | if fps == 0:
189 | raise RuntimeError("Failed to retrieve FPS.")
190 |
191 | return fps
192 |
193 |
194 | def yield_image(
195 | input_path: Path,
196 | start_frame: int,
197 | num_frames: int,
198 | frameskip: int = 1,
199 | ) -> Generator[NDArray[np.uint8], None, None]:
200 | """Generate the next frame from either a video or a directory."""
201 | if is_a_video(input_path):
202 | yield from yield_image_from_video(
203 | input_path, start_frame, num_frames, frameskip
204 | )
205 | elif input_path.is_dir():
206 | yield from yield_image_from_directory(
207 | input_path, start_frame, num_frames, frameskip
208 | )
209 | else:
210 | raise ValueError("Input path must point to a video file or a directory")
211 |
212 |
213 | def get_default_K(image: torch.Tensor) -> torch.Tensor:
214 | """Get a default approximate intrinsic matrix."""
215 | res = image.shape[-2:]
216 | max_res = max(res)
217 | K = torch.tensor([[2 * max_res, 0, 0.5 * res[1]], [0, 2 * max_res, 0.5 * res[0]]])
218 | return K
219 |
220 |
221 | def yield_image_and_K(
222 | input_path: Path,
223 | start_frame: int,
224 | num_frames: int,
225 | frameskip: int = 1,
226 | ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
227 | """Generate image and the intrinsic matrix."""
228 | for image in yield_image(input_path, start_frame, num_frames, frameskip):
229 | image = convert_image_to_tensor(image)
230 | K = get_default_K(image)
231 | yield (image, K)
232 |
233 |
234 | def prepare_network_inputs(
235 | image: torch.Tensor,
236 | K: torch.Tensor,
237 | device: torch.device | str = "cpu",
238 | ) -> Tuple[torch.Tensor, torch.Tensor]:
239 | """Image and intrinsics prep before inference.
240 |
241 | We crop and pad the input to a 512x512 image and update the intrinsics
242 |     accordingly. ImageNet normalization is also applied to the image here.
243 |
244 | This demo code only supports processing individual samples. Some operations
245 | assume the existence of a batch dimension, so we add a singleton batch
246 | dimension here. Other operations such as NMS and track management do not
247 | correctly handle batched inputs.
248 |
249 | Args:
250 | ----
251 | image: Input image (C x H x W) float (0-1) or uint8 (0-255) tensor
252 | K: Intrinsics matrix (2 x 3) float tensor
253 | device: Target device for inference
254 |
255 | """
256 | if image.dtype == torch.uint8:
257 | # Cast to float and normalize to 0-1
258 | image = image.float() / 255
259 | cropped_image, cropped_K = crop_image_and_update_K(image, K)
260 | cropped_image = normalize_image(cropped_image)
261 |
262 | # Add "batch" dimension and cast to target device
263 | cropped_image = cropped_image[None].to(device)
264 | cropped_K = cropped_K[None].to(device)
265 |
266 | return cropped_image, cropped_K
267 |
--------------------------------------------------------------------------------
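
To show how these loaders fit together, here is a small driver sketch. It is not part of the repository, the input path is a placeholder, and the package is assumed to be importable as comotion_demo; default intrinsics come from get_default_K when no calibration is available.

from pathlib import Path

from comotion_demo.utils import dataloading

input_path = Path("samples/example.mp4")  # placeholder: any .mp4 file or image directory

for image, K in dataloading.yield_image_and_K(input_path, start_frame=0, num_frames=16):
    # image: uint8 CHW tensor at the source resolution; K: approximate 2x3 intrinsics
    cropped_image, cropped_K = dataloading.prepare_network_inputs(image, K, device="cpu")
    # cropped_image: 1 x 3 x 512 x 512 ImageNet-normalized float tensor; cropped_K: 1 x 2 x 3
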
/src/comotion_demo/utils/helper.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 | """Miscellaneous helper functions."""
3 |
4 | import functools
5 |
6 | import einops as eo
7 | import numpy as np
8 | import torch
9 |
10 |
11 | def px_to_world(K, pts):
12 | """Convert values from pixel coordinates to world coordinates."""
13 | x, y = pts.unbind(-1)
14 | return torch.stack(
15 | [(x - K[..., 0, 2]) / K[..., 0, 0], (y - K[..., 1, 2]) / K[..., 1, 1]], -1
16 | )
17 |
18 |
19 | def project_to_2d(K, pts3d, min_depth=0.001):
20 | """Project 3D points to pixel coordinates.
21 |
22 | Args:
23 | ----
24 | K: ... x 2 x 3
25 | pts3d: ... x 3
26 | min_depth: minimum depth for clamping (default: 0.001).
27 |
28 | """
29 | # Normalize by depth
30 | z = pts3d[..., -1:].clamp_min(min_depth)
31 | pts3d_norml = pts3d / z
32 | pts3d_norml[..., -1].fill_(1.0)
33 |
34 | # Apply intrinsics
35 | pts_2d = eo.einsum(K, pts3d_norml, "... i j, ... j -> ... i")
36 |
37 | return pts_2d
38 |
39 |
40 | def _merge_aux(arr, dim=0, n_merged=2, unmerge=False, vals=None):
41 | n_extra = -(dim + 1) if dim < 0 else dim
42 | extra_d = " ".join(map(lambda x: f"a{x}", range(n_extra)))
43 | di = " ".join(map(lambda x: f"d{x}", range(n_merged)))
44 | df = f"({di})"
45 | kargs = {}
46 |
47 | if unmerge:
48 | di, df = df, di
49 | if isinstance(vals, int):
50 | vals = [vals]
51 | for v_idx, v in enumerate(vals):
52 | kargs[f"d{v_idx}"] = v
53 |
54 | if dim < 0:
55 | return eo.rearrange(arr, f"... {di} {extra_d} -> ... {df} {extra_d}", **kargs)
56 | else:
57 | return eo.rearrange(arr, f"{extra_d} {di} ... -> {extra_d} {df} ...", **kargs)
58 |
59 |
60 | def _merge_d(arr, dim=0, n_merged=2):
61 | return _merge_aux(arr, dim, n_merged)
62 |
63 |
64 | def _unmerge_d(arr, vals, dim=0, n_merged=2):
65 | return _merge_aux(arr, dim, n_merged, True, vals)
66 |
67 |
68 | def fixed_dim_op(fn_=None, d=4, nargs=1, is_class_fn=False):
69 | def _decorator(fn):
70 | @functools.wraps(fn)
71 | def run_fixed_dim(*args, **kargs):
72 | args = [a for a in args]
73 | if is_class_fn:
74 | self = args[0]
75 | args = args[1:]
76 |
77 | tmp_nargs = min(len(args), nargs)
78 | ndim = args[0].ndim
79 | if ndim > d:
80 | # If input is greater than required dims
81 | # flatten first dimensions together
82 | n = ndim - d + 1
83 | d_ref = args[0].shape[:n]
84 | for i in range(tmp_nargs):
85 | args[i] = _merge_d(args[i], 0, n)
86 |
87 | def adjust_fn(x):
88 | return _unmerge_d(x, d_ref, 0, n) if x is not None else x
89 |
90 | elif ndim < d:
91 | # If input is less than the required dims
92 | # prepend extra singleton dimensions
93 | n = d - ndim
94 | one_str = " ".join(["1"] * n)
95 | for i in range(tmp_nargs):
96 | args[i] = eo.rearrange(args[i], f"... -> {one_str} ...")
97 |
98 | def adjust_fn(x):
99 | return (
100 | eo.rearrange(x, f"{one_str} ... -> ...") if x is not None else x
101 | )
102 |
103 | else:
104 | # Otherwise, do nothing
105 | def adjust_fn(x):
106 | return x
107 |
108 | if is_class_fn:
109 | args = [self, *args]
110 | fn_out = fn(*args, **kargs)
111 |
112 | if isinstance(fn_out, tuple) or isinstance(fn_out, list):
113 | return [adjust_fn(v) for v in fn_out]
114 | elif isinstance(fn_out, dict):
115 | return {k: adjust_fn(v) for k, v in fn_out.items()}
116 | else:
117 | return adjust_fn(fn_out)
118 |
119 | return run_fixed_dim
120 |
121 | if fn_ is not None:
122 | return _decorator(fn_)
123 | else:
124 | return _decorator
125 |
126 |
127 | def get_grid(ht, wd, device="cpu"):
128 | """Get coordinate grid for a set of features.
129 |
130 |     Coordinates are normalized so that max(ht, wd) maps to 1 (aspect ratio preserved).
131 | """
132 | grid = torch.meshgrid(
133 | torch.arange(wd, device=device), torch.arange(ht, device=device), indexing="xy"
134 | )
135 | grid = torch.stack(grid, -1) / float(max(ht, wd))
136 | return grid
137 |
138 |
139 | _link_ref = torch.tensor(
140 | [
141 | (23, 24), # eyes
142 | (22, 16), # nose -> shoulder
143 | (22, 17), # nose -> shoulder
144 | (16, 17), # shoulders
145 | (16, 1), # left shoulder -> hip
146 | (17, 2), # right shoulder -> hip
147 | (16, 18), # left shoulder -> elbow
148 | (17, 19), # right shoulder -> elbow
149 | (20, 18), # left wrist -> elbow
150 | (21, 19), # right wrist -> elbow
151 | (4, 1), # left knee -> hip
152 | (5, 2), # right knee -> hip
153 | (4, 7), # left knee -> ankle
154 | (5, 8), # right knee -> ankle
155 | ]
156 | ).T
157 |
158 |
159 | # Calculated based on default SMPL with above links
160 | _link_dists_ref = [
161 | 0.07,
162 | 0.3,
163 | 0.3,
164 | 0.35,
165 | 0.55,
166 | 0.55,
167 | 0.26,
168 | 0.26,
169 | 0.25,
170 | 0.25,
171 | 0.38,
172 | 0.38,
173 | 0.4,
174 | 0.4,
175 | ]
176 |
177 |
178 | def _get_skeleton_scale(kps: torch.Tensor, valid=None):
179 | """Return approximate "size" of person in image.
180 |
181 | Instead of bounding box dimensions, we compare pairs of keypoints as defined
182 | in the links above. We ignore any distances that involve an "invalid" keypoint.
183 | We assume any values set to (0, 0) are invalid and should be ignored.
184 | """
185 | invalid = (kps == 0).all(-1)
186 | if valid is not None:
187 | invalid |= ~(valid > 0)
188 |
189 | # Calculate link distances
190 | dists = (kps[..., _link_ref[0], :] - kps[..., _link_ref[1], :]).norm(dim=-1)
191 |
192 | # Compare to reference distances
193 | ratio = dists / torch.tensor(_link_dists_ref, device=dists.device)
194 |
195 | # Zero out any invalid links
196 | invalid = invalid[..., _link_ref[0]] | invalid[..., _link_ref[1]]
197 | ratio *= (~invalid).float()
198 |
199 | # Return max ratio which corresponds to limb with least foreshortening
200 | max_ratio = ratio.max(-1)[0]
201 | return max_ratio.clamp_min(0.001)
202 |
203 |
204 | def normalized_weighted_score(
205 | kp0,
206 | c0,
207 | kp1,
208 | c1,
209 | std=0.08,
210 | return_dists=False,
211 | ref_scale=None,
212 | min_scale=0.02,
213 | max_scale=1,
214 | fixed_scale=None,
215 | ):
216 | """Measure similarity of two sets of body pose keypoints.
217 |
218 | This is a modified version of the COCO object keypoint similarity (OKS)
219 | calculation using a Cauchy distribution instead of a Gaussian. We also
220 | compute a normalizing scale on the fly based on projected limb proportions.
221 | """
222 | # Combine confidence terms
223 | # conf: ... num_people x num_people x num_points
224 | conf = (c0.unsqueeze(-2) * c1.unsqueeze(-3)) ** 0.5
225 |
226 | # Calculate scale adjustment
227 | scale0 = _get_skeleton_scale(kp0, c0)
228 | scale1 = _get_skeleton_scale(kp1, c1)
229 | if ref_scale is None:
230 | scale = scale0.unsqueeze(-1).maximum(scale1.unsqueeze(-2))
231 | elif ref_scale == 0:
232 | scale = scale0.unsqueeze(-1)
233 | elif ref_scale == 1:
234 | scale = scale1.unsqueeze(-2)
235 |
236 | # Set scale bounds
237 | zero_mask = scale == 0
238 | scale.clamp_(min_scale, max_scale)
239 | scale[zero_mask] = 1e-6
240 | if fixed_scale is not None:
241 | scale[:] = fixed_scale
242 |
243 | # Scale-adjusted distance calculation
244 | # kp: ... num_people x num_pts x 2
245 | # dists: ... num_people x num_people x num_pts
246 | dists = (kp0.unsqueeze(-3) - kp1.unsqueeze(-4)).norm(dim=-1)
247 | dists = dists / scale.unsqueeze(-1)
248 |
249 | total_conf = conf.sum(-1)
250 | zero_filt = total_conf == 0
251 | scores = 1 / (1 + (dists / std) ** 2)
252 | scores = (scores * conf).sum(-1)
253 | scores = scores / total_conf
254 | scores[zero_filt] = 0
255 |
256 | if return_dists:
257 | return scores, dists
258 | else:
259 | return scores
260 |
261 |
262 | def nms_detections(pred_2d, conf, conf_thr=0.2, iou_thr=0.4, std=0.08):
263 | """Compare keypoint estimates and return indices of nonoverlapping estimates.
264 |
265 | Note: input predictions are expected to be normalized (e.g. ranging from 0 to 1),
266 | not in pixel coordinate space.
267 | """
268 | conf_sigmoid = torch.sigmoid(conf)
269 |
270 | # Get indices of estimates above confidence threshold, sorted by confidence
271 | sorted_idxs = (-conf).argsort()
272 | sorted_idxs = sorted_idxs[conf_sigmoid[sorted_idxs] > conf_thr]
273 |
274 | # Compare all pairs of keypoint estimates
275 | p = pred_2d[sorted_idxs]
276 | c = torch.ones_like(p[..., 0])
277 | ious = normalized_weighted_score(p, c, p, c, std=std).cpu()
278 |
279 | # Identify estimates that are lower confidence and too similar to another estimate
280 | triu = torch.triu_indices(len(p), len(p), offset=1)
281 | ious = ious[triu[0], triu[1]]
282 | to_remove = triu[1][ious > iou_thr]
283 |
284 | # Return the subset of estimates to keep
285 | to_keep = np.setdiff1d(np.arange(len(p)), to_remove.unique().numpy())
286 | final_idxs = sorted_idxs.cpu()[to_keep]
287 | return final_idxs
288 |
289 |
290 | def check_inbounds(kps, res):
291 | """Return binary mask indicating which keypoints are inbounds."""
292 | in_x = (kps[..., 0] > 0) & (kps[..., 0] < res[1])
293 | in_y = (kps[..., 1] > 0) & (kps[..., 1] < res[0])
294 | inbounds = in_x & in_y
295 | return inbounds
296 |
297 |
298 | def points_to_bbox2d(pts, pad_dims=None):
299 |     """Get the min/max extent of a set of keypoints to define a 2D bounding box.
300 |
301 | Output format is ((x0, y0), (x1, y1)).
302 |
303 | Args:
304 | ----
305 | pts: ... x K x 2
306 |     pad_dims: optional (pad_x, pad_y) padding, as a fraction of the larger bounding-box dimension
307 |
308 | Output:
309 | ----
310 | bboxes: ... x 2 x 2
311 |
312 | """
313 | p = torch.stack([pts.min(-2)[0], pts.max(-2)[0]], -2)
314 |
315 | if pad_dims is not None:
316 | dimensions = p[..., 1, :] - p[..., 0, :]
317 | scale = dimensions.max(dim=-1)[0]
318 | pad_x = scale * pad_dims[0]
319 | pad_y = scale * pad_dims[1]
320 |
321 | p[..., 0, 0] -= pad_x
322 | p[..., 0, 1] -= pad_y
323 | p[..., 1, 0] += pad_x
324 | p[..., 1, 1] += pad_y
325 |
326 | return p
327 |
328 |
329 | def hsv2rgb(hsv):
330 |     """Convert a tuple of (h, s, v) arrays to RGB."""
331 | h, s, v = hsv
332 | vals = np.tile(v, [3] + [1] * v.ndim)
333 | vals[1:] *= 1 - s[None]
334 |
335 | h[h > (5 / 6)] -= 1
336 | diffs = np.tile(h, [3] + [1] * h.ndim) - (np.arange(3) / 3).reshape(
337 | 3, *[1] * h.ndim
338 | )
339 | max_idx = np.abs(diffs).argmin(0)
340 |
341 | final_rgb = np.zeros_like(vals)
342 |
343 | for i in range(3):
344 | tmp_d = diffs[i] * (max_idx == i)
345 | dv = tmp_d * 6 * s * v
346 | vals[1] += np.maximum(0, dv)
347 | vals[2] += np.maximum(0, -dv)
348 |
349 | final_rgb += np.roll(vals, i, axis=0) * (max_idx == i)
350 |
351 | return final_rgb.transpose(*list(np.arange(h.ndim) + 1), 0)
352 |
353 |
354 | def init_color_ref(n, seed=12345):
355 | """Sample N colors in HSV and convert them to RGB."""
356 | rand_state = np.random.RandomState(seed)
357 | rand_hsv = rand_state.rand(n, 3)
358 | rand_hsv[:, 1:] = 1 - rand_hsv[:, 1:] * 0.3
359 | color_ref = hsv2rgb(rand_hsv.T)
360 |
361 | return color_ref
362 |
363 |
364 | color_ref = init_color_ref(2000)
365 |
--------------------------------------------------------------------------------
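
fixed_dim_op has no docstring, so a brief usage sketch may help: the decorator flattens extra leading dimensions (or pads missing ones) so the wrapped function always sees a fixed number of dims, then restores the original leading shape on the output. The sketch is not part of the repository and spatial_mean is a hypothetical function.

import torch

from comotion_demo.utils import helper


@helper.fixed_dim_op(d=4)
def spatial_mean(x):
    """Average over the last two (spatial) dims; x is always 4-D (b c h w) in here."""
    return x.mean(dim=(-2, -1))


# Extra leading dims are flattened before the call and unflattened afterwards.
print(spatial_mean(torch.randn(2, 5, 3, 8, 8)).shape)  # torch.Size([2, 5, 3])

# Missing dims are padded with singletons and squeezed back out.
print(spatial_mean(torch.randn(3, 8, 8)).shape)        # torch.Size([3])
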
/src/comotion_demo/utils/smpl_kinematics.py:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved.
2 |
3 | import os
4 | import pickle
5 |
6 | import chumpy
7 | import einops as eo
8 | import numpy as np
9 | import pypose
10 | import torch
11 | from torch import nn
12 |
13 | BETA_DOF = 10
14 | POSE_DOF = 72
15 | TRANS_DOF = 3
16 | FACE_VERTEX_IDXS = [332, 6260, 2800, 4071, 583]
17 | NUM_PARTS_AUX = 27
18 |
19 | smpl_dir = os.path.join(os.path.dirname(__file__), "../data/smpl")
20 | extra_ref = torch.load(f"{smpl_dir}/extra_smpl_reference.pt", weights_only=True)
21 | smpl_model_path = f"{smpl_dir}/SMPL_NEUTRAL.pkl"
22 | assert os.path.exists(smpl_model_path), (
23 |     "Please download the neutral SMPL body model from https://smpl.is.tue.mpg.de/ and "
24 | "rename it to SMPL_NEUTRAL.pkl, copying it into src/comotion_demo/data/smpl/"
25 | )
26 |
27 |
28 | def to_rotmat(theta):
29 | """Convert axis-angle to rotation matrix."""
30 | rotmat = pypose.so3(theta.view(-1, 3)).matrix()
31 | return rotmat.reshape(*theta.shape, 3)
32 |
33 |
34 | @torch.compiler.disable
35 | def update_pose(pose, delta_pose):
36 | """Apply residual to an existing set of joint angles.
37 |
38 | Instead of adding the update to the current pose directly, we apply the
39 | update as a rotation.
40 | """
41 | # Convert flattened pose to K x 3 set of axis-angle terms
42 | pose = eo.rearrange(pose, "... (k c) -> ... k c", c=3).contiguous()
43 | delta_pose = eo.rearrange(delta_pose, "... (k c) -> ... k c", c=3).contiguous()
44 |
45 | # Change to a rotation matrix and multiply the matrices together
46 | pose: pypose.SO3_type = pypose.so3(delta_pose).Exp() * pypose.so3(pose).Exp()
47 |
48 | # Map back to axis-angle representation and flatten again
49 | pose = pose.Log().tensor()
50 | return eo.rearrange(pose, "... k c -> ... (k c)")
51 |
52 |
53 | class SMPLKinematics(nn.Module):
54 | """Parse SMPL parameters to get mesh and joint coordinates."""
55 |
56 | def __init__(self):
57 | """Initialize SMPLKinematics."""
58 | super().__init__()
59 |
60 | self.num_tf = 24
61 |
62 | # Load SMPL neutral model
63 | with open(smpl_model_path, "rb") as f:
64 | smpl_neutral = pickle.load(f, encoding="latin1")
65 |
66 | # Prepare model data and convert to torch tensors
67 | smpl_keys = [
68 | "J_regressor",
69 | "v_template",
70 | "posedirs",
71 | "shapedirs",
72 | "weights",
73 | "kintree_table",
74 | ]
75 | for k in smpl_keys:
76 | v = smpl_neutral[k]
77 | if isinstance(v, chumpy.ch.Ch):
78 | v = np.array(v)
79 | elif not isinstance(v, np.ndarray):
80 | v = v.toarray()
81 | v = torch.tensor(v).float()
82 |
83 | if k == "shapedirs":
84 | v = v[..., :BETA_DOF].view(-1, BETA_DOF).contiguous()
85 | elif k == "posedirs":
86 | v = v.view(-1, v.shape[-1]).T
87 | elif k == "kintree_table":
88 | k = "parents"
89 | v = v[0].long()
90 | v[0] = -1
91 |
92 | self.register_buffer(k, v, persistent=False)
93 |
94 | # Precompute intermediate matrices
95 | self.register_buffer(
96 | "joint_shapedirs",
97 | (self.J_regressor @ self.shapedirs.view(-1, 3 * BETA_DOF)).view(
98 | self.num_tf * 3, BETA_DOF
99 | ),
100 | persistent=False,
101 | )
102 | self.register_buffer(
103 | "joint_template", self.J_regressor @ self.v_template, persistent=False
104 | )
105 |
106 | # Load additional reference tensors
107 | self.register_buffer("mean_pose", extra_ref["mean_pose"], persistent=False)
108 | self.register_buffer("coco_map", extra_ref["coco_map"], persistent=False)
109 |
110 | def forward_kinematics(self, rot_mat, unposed_tx):
111 | # Get relative joint positions
112 | relative_tx = unposed_tx.clone()
113 | relative_tx[1:] -= relative_tx[self.parents[1:]]
114 |
115 | # Define transformation matrices
116 | tf = torch.cat([rot_mat, relative_tx.unsqueeze(-1)], -1)
117 | tf = torch.functional.F.pad(tf, [0, 0, 0, 1])
118 | tf[..., 3, 3] = 1
119 |
120 | # Follow kinematic chain
121 | tf_chain = [tf[0]]
122 | for child_idx in range(1, self.num_tf):
123 | parent_idx = self.parents[child_idx]
124 | tf_chain.append(tf_chain[parent_idx] @ tf[child_idx])
125 | tf_chain = torch.stack(tf_chain)
126 |
127 | # Get output joint positions
128 | joints = eo.rearrange(tf_chain[..., :3, 3], "k ... c -> ... k c")
129 | return tf_chain, joints
130 |
131 | def get_smpl_template(self, subsample_rate=1, idx_subset=None):
132 | """Return reference tensors to derive SMPL mesh.
133 |
134 |         Optional arguments let us select a subset of mesh vertices to avoid
135 |         computing over the complete mesh.
136 | """
137 | posedir_dim = self.posedirs.shape[0]
138 | shapedirs = eo.rearrange(self.shapedirs, "(n c) m -> n c m", c=3)
139 |
140 | if idx_subset is None:
141 | v_template = self.v_template[..., ::subsample_rate, :]
142 | posedirs = self.posedirs.view(posedir_dim, -1, 3)[
143 | :, ::subsample_rate
144 | ].reshape(posedir_dim, -1)
145 | shapedirs = eo.rearrange(shapedirs[::subsample_rate], "n c m -> (n c) m")
146 | weights = self.weights[::subsample_rate]
147 | else:
148 | v_template = self.v_template[..., idx_subset, :]
149 | posedirs = self.posedirs.view(posedir_dim, -1, 3)[:, idx_subset].reshape(
150 | posedir_dim, -1
151 | )
152 | shapedirs = eo.rearrange(shapedirs[idx_subset], "n c m -> (n c) m")
153 | weights = self.weights[idx_subset]
154 |
155 | return v_template, posedirs, shapedirs, weights
156 |
157 | def forward(
158 | self,
159 | betas,
160 | pose,
161 | trans=None,
162 | output_format="joints",
163 | subsample_rate=1,
164 | idx_subset=None,
165 | ):
166 | """Get joints and mesh vertices.
167 |
168 | Valid output formats:
169 | - joints: 24x3 set of SMPL joints
170 | - joints_face: 27x3 joints where the first 22 values are original SMPL joints
171 | while the last 5 correspond to face keypoints
172 | - joints_coco: 17x3 joints in COCO order that match COCO annotation format
173 | e.g. hips are higher and wider than in SMPL
174 | - mesh: 6890x3 set of SMPL mesh vertices
175 |                 if idx_subset is not None, returns len(idx_subset) vertices
176 |                 if subsample_rate > 1, returns every subsample_rate-th vertex (roughly 6890 / subsample_rate)
177 | """
178 | # Convert pose to rotation matrices
179 | pose = eo.rearrange(pose, "... (k c) -> k ... c", c=3).contiguous()
180 | rot_mat = to_rotmat(pose)
181 |
182 | # Adjust mesh based on betas
183 | blend_shape = eo.rearrange(
184 | self.joint_shapedirs @ betas.unsqueeze(-1), "... (k c) 1 -> ... k c", c=3
185 | )
186 | unposed_tx = self.joint_template + blend_shape
187 | unposed_tx = eo.rearrange(unposed_tx, "... k c -> k ... c")
188 |
189 | # Run forward kinematics
190 | tf_chain, joints = self.forward_kinematics(rot_mat, unposed_tx)
191 |
192 | if output_format == "joints":
193 | smpl_output = joints
194 |
195 | else:
196 | if output_format == "joints_face":
197 | idx_subset = FACE_VERTEX_IDXS
198 | elif output_format == "joints_coco":
199 | subsample_rate = 1
200 | idx_subset = None
201 |
202 | v_template, posedirs, shapedirs, weights = self.get_smpl_template(
203 | subsample_rate, idx_subset
204 | )
205 |
206 | # Adjust mesh based on betas
207 | blend_shape = eo.rearrange(
208 | shapedirs @ betas.unsqueeze(-1), "... (k c) 1 -> ... k c", c=3
209 | )
210 | v_shaped = v_template + blend_shape
211 |
212 | # Get relative transforms (rotate unposed joints and subtract)
213 | tf_offset = (tf_chain[..., :3, :3] @ unposed_tx.unsqueeze(-1)).squeeze(-1)
214 | tf_relative = tf_chain.clone()
215 | tf_relative[..., :3, 3] -= tf_offset
216 | A = eo.rearrange(tf_relative, "n ... d0 d1 -> ... n (d0 d1)")
217 |
218 | # Flatten all rotation matrices (except root rotation), remove identity
219 | eye = torch.eye(3, device=rot_mat.device)
220 | pose_feature = eo.rearrange(
221 | rot_mat[1:] - eye, "k ... d0 d1 -> ... (k d0 d1)"
222 | )
223 |
224 | # Adjust base vertex positions
225 | pose_offsets = eo.rearrange(
226 | pose_feature @ posedirs, "... (n c) -> ... n c", c=3
227 | )
228 | v_posed = v_shaped + pose_offsets
229 |
230 | # Transform vertices based on pose
231 | T = weights @ A
232 | T = eo.rearrange(T, "... (d0 d1) -> ... d0 d1", d0=4)
233 | vertices = torch.functional.F.pad(v_posed, [0, 1], value=1)
234 | vertices = (T @ vertices.unsqueeze(-1)).squeeze(-1)
235 | vertices = vertices[..., :-1] / vertices[..., -1:]
236 |
237 | smpl_output = vertices
238 | if output_format == "joints_face":
239 | smpl_output = torch.cat([joints[..., :22, :], vertices], -2)
240 | elif output_format == "joints_coco":
241 | smpl_output = eo.einsum(
242 | vertices, self.coco_map, "... i j, i k -> ... k j"
243 | )
244 |
245 | if trans is not None:
246 | # Apply translation offset
247 | smpl_output = smpl_output + trans.unsqueeze(-2)
248 |
249 | return smpl_output
250 |
--------------------------------------------------------------------------------
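
A minimal forward pass through SMPLKinematics might look like the sketch below. It is not part of the repository, it assumes SMPL_NEUTRAL.pkl has been placed in src/comotion_demo/data/smpl/ as required by the assert at the top of the module, and it uses placeholder shape, pose, and translation values.

import torch

from comotion_demo.utils.smpl_kinematics import BETA_DOF, POSE_DOF, SMPLKinematics

smpl = SMPLKinematics().eval()
betas = torch.zeros(1, BETA_DOF)          # neutral body shape
pose = torch.zeros(1, POSE_DOF)           # zero pose, including root orientation
trans = torch.tensor([[0.0, 0.0, 3.0]])   # place the body 3 m in front of the camera

with torch.no_grad():
    joints = smpl(betas, pose, trans, output_format="joints")       # 1 x 24 x 3
    coco = smpl(betas, pose, trans, output_format="joints_coco")    # 1 x 17 x 3
    verts = smpl(betas, pose, trans, output_format="mesh", subsample_rate=10)  # 1 x 689 x 3
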