├── .gitignore ├── ACKNOWLEDGEMENTS.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE.md ├── LICENSE_MODEL.md ├── README.md ├── demo.py ├── get_pretrained_models.sh ├── posetrack21_eval ├── README.md ├── datasets │ ├── __init__.py │ ├── pt_sequence.py │ └── pt_warper.py ├── evaluate_mot.py └── motmetrics │ ├── __init__.py │ ├── distances.py │ ├── io.py │ ├── lap.py │ ├── math_util.py │ ├── metrics.py │ ├── mot.py │ ├── preprocess.py │ ├── tests │ ├── __init__.py │ ├── test_distances.py │ ├── test_io.py │ ├── test_issue19.py │ ├── test_lap.py │ ├── test_metrics.py │ ├── test_mot.py │ └── test_utils.py │ └── utils.py ├── pyproject.toml ├── samples ├── sample_info.txt ├── teaser_01.gif ├── teaser_02.gif ├── teaser_03.gif ├── teaser_04.gif ├── teaser_05.gif └── teaser_06.gif └── src └── comotion_demo ├── data └── smpl │ └── extra_smpl_reference.pt ├── models ├── __init__.py ├── backbones │ ├── __init__.py │ ├── _registry.py │ └── convnext.py ├── comotion.py ├── detect.py ├── layers.py └── refine.py └── utils ├── __init__.py ├── dataloading.py ├── helper.py ├── smpl_kinematics.py └── track.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.mp4 3 | 4 | .DS_Store 5 | 6 | src/comotion_demo/data 7 | src/comotion_demo.egg-info/ 8 | 9 | data/ 10 | tmp/ 11 | results/ 12 | -------------------------------------------------------------------------------- /ACKNOWLEDGEMENTS.md: -------------------------------------------------------------------------------- 1 | Acknowledgements 2 | Portions of this Software may utilize the following copyrighted 3 | material, the use of which is hereby acknowledged. 4 | 5 | ------------------------------------------------ 6 | PoseTrack21 evaluation code: 7 | 8 | MIT License 9 | 10 | Copyright (c) 2022 Andreas Doering 11 | 12 | Permission is hereby granted, free of charge, to any person obtaining a copy 13 | of this software and associated documentation files (the "Software"), to deal 14 | in the Software without restriction, including without limitation the rights 15 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 16 | copies of the Software, and to permit persons to whom the Software is 17 | furnished to do so, subject to the following conditions: 18 | 19 | The above copyright notice and this permission notice shall be included in all 20 | copies or substantial portions of the Software. 21 | 22 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 23 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 24 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 25 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 26 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 27 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 28 | SOFTWARE. 
29 | 30 | ------------------------------------------------ 31 | MOTMetrics: 32 | 33 | MIT License 34 | 35 | Copyright (c) 2017-2020 Christoph Heindl 36 | Copyright (c) 2018 Toka 37 | Copyright (c) 2019-2020 Jack Valmadre 38 | 39 | Permission is hereby granted, free of charge, to any person obtaining a copy 40 | of this software and associated documentation files (the "Software"), to deal 41 | in the Software without restriction, including without limitation the rights 42 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 43 | copies of the Software, and to permit persons to whom the Software is 44 | furnished to do so, subject to the following conditions: 45 | 46 | The above copyright notice and this permission notice shall be included in all 47 | copies or substantial portions of the Software. 48 | 49 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 50 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 51 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 52 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 53 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 54 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 55 | SOFTWARE. 56 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. 
If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | -------------------------------------------------------------------------------- /LICENSE_MODEL.md: -------------------------------------------------------------------------------- 1 | Disclaimer: IMPORTANT: This Apple Machine Learning Research Model is specifically developed and released by Apple Inc. ("Apple") for the sole purpose of scientific research of artificial intelligence and machine-learning technology. “Apple Machine 2 | Learning Research Model” means the model, including but not limited to algorithms, formulas, trained model weights, parameters, configurations, checkpoints, and any related materials (including documentation). 3 | 4 | This Apple Machine Learning Research Model is provided to You by Apple in consideration of your agreement to the following terms, and your use, modification, creation of Model Derivatives, and or redistribution of the Apple Machine Learning Research Model constitutes acceptance of this Agreement. If You do not agree with these terms, please do not use, modify, create Model Derivatives of, or distribute this Apple Machine Learning Research Model or Model Derivatives. 5 | 6 | 1. 
License Scope: In consideration of your agreement to abide by the following terms, and subject to these terms, Apple hereby grants you a personal, non-exclusive, worldwide, non-transferable, royalty-free, revocable, and limited license, to use, copy, modify, distribute, and create Model Derivatives (defined below) of the Apple Machine Learning Research Model exclusively for Research Purposes. You agree that any Model Derivatives You may create or that may be created for You will be limited to Research Purposes as well. “Research Purposes” means non-commercial scientific research and academic development activities, such as experimentation, analysis, testing conducted by You with the sole intent to advance scientific knowledge and research. “Research Purposes” does not include any commercial exploitation, product development or use in any commercial product or service. 7 | 8 | 2. Distribution of Apple Machine Learning Research Model and Model Derivatives: If you choose to redistribute Apple Machine Learning Research Model or its Model Derivatives, you must provide a copy of this Agreement to such third party, and ensure that the following attribution notice be provided: “Apple Machine Learning Research Model is licensed under the Apple Machine Learning Research Model License Agreement.” Additionally, all Model Derivatives must clearly be identified as such, including disclosure of modifications and changes made to the Apple Machine Learning Research Model. The name, trademarks, service marks or logos of Apple may not be used 9 | to endorse or promote Model Derivatives or the relationship between You and Apple. “Model Derivatives” means any models or any other artifacts created by modifications, improvements, adaptations, alterations to the architecture, 10 | algorithm or training processes of the Apple Machine Learning Research Model, or by any retraining, fine-tuning of the Apple Machine Learning Research Model. 11 | 12 | 3. No Other License: Except as expressly stated in this notice, no other rights or licenses, express or implied, are granted by Apple herein, including but not limited to any patent, trademark, and similar intellectual property rights worldwide that may be infringed by the Apple Machine Learning Research Model, the Model Derivatives or by other works in which the Apple Machine Learning Research Model may be incorporated. 13 | 14 | 4. Compliance with Laws: Your use of Apple Machine Learning Research Model must be in compliance with all applicable laws and regulations. 15 | 16 | 5. Term and Termination: The term of this Agreement will begin upon your acceptance of this Agreement or use of the Apple Machine Learning Research Model and will continue until terminated in accordance with the following terms. Apple may terminate this Agreement at any time if You are in breach of any term or condition of this Agreement. Upon termination of this Agreement, You must cease to use all Apple Machine Learning Research Models and Model Derivatives and permanently delete any copy thereof. Sections 3, 6 and 7 will survive termination. 17 | 18 | 6. Disclaimer and Limitation of Liability: This Apple Machine Learning Research Model and any outputs generated by the Apple Machine Learning Research Model are provided on an “AS IS” basis. 
APPLE MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE, REGARDING THE APPLE MACHINE LEARNING RESEARCH MODEL OR OUTPUTS GENERATED BY THE APPLE MACHINE LEARNING RESEARCH MODEL. You are solely responsible for determining the appropriateness of using or redistributing the Apple Machine Learning Research Model and any outputs of the Apple Machine Learning Research Model and assume any risks associated with Your use of the Apple Machine Learning Research Model and any output and results. IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, MODIFICATION AND/OR DISTRIBUTION OF THE APPLE MACHINE LEARNING RESEARCH MODEL AND ANY OUTPUTS OF THE APPLE MACHINE LEARNING RESEARCH MODEL, HOWEVER CAUSED AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 19 | 20 | 7. Governing Law: This Agreement will be governed by and construed under the laws of the State of California without regard to its choice of law principles. The Convention on Contracts for the International Sale of Goods shall not apply to the Agreement except that the arbitration clause and any arbitration hereunder shall be governed by the Federal Arbitration Act, Chapters 1 and 2. 21 | 22 | Copyright (c) 2025 Apple Inc. All Rights Reserved. 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoMotion: Concurrent Multi-person 3D Motion 2 | 3 | This software project accompanies the research paper: 4 | **[CoMotion: Concurrent Multi-person 3D Motion](https://openreview.net/forum?id=qKu6KWPgxt)**, 5 | _Alejandro Newell, Peiyun Hu, Lahav Lipson, Stephan R. Richter, and Vladlen Koltun_. 6 | 7 |
8 | <!-- Teaser animations: samples/teaser_01.gif through samples/teaser_06.gif -->
18 |
19 | We introduce CoMotion, an approach for detecting and tracking detailed 3D poses of multiple people from a single monocular camera stream. Our system maintains temporally coherent predictions in crowded scenes filled with difficult poses and occlusions. Our model performs both strong per-frame detection and a learned pose update to track people from frame to frame. Rather than match detections across time, poses are updated directly from a new input image, which enables online tracking through occlusion.
20 |
21 | The code in this directory provides helper functions and scripts for inference and visualization.
22 |
23 | ## Getting Started
24 |
25 | ### Installation
26 |
27 | ```bash
28 | conda create -n comotion -y python=3.10
29 | conda activate comotion
30 | pip install -e '.[all]'
31 | ```
32 |
33 | ### Download models
34 |
35 | To download pretrained checkpoints, run:
36 |
37 | ```bash
38 | bash get_pretrained_models.sh
39 | ```
40 |
41 | Checkpoint data will be downloaded to `src/comotion_demo/data`. You will find pretrained weights for the detection stage, which includes the main vision backbone (`comotion_detection_checkpoint.pt`), as well as a separate checkpoint for the update stage (`comotion_refine_checkpoint.pt`). You can use the detection stage standalone for single-image multi-person pose estimation.
42 |
43 | For macOS, we provide a pre-compiled `coreml` version of the detection stage of the model, which offers significant speedups when running locally on a personal device.
44 |
45 | ### Download SMPL body model
46 |
47 | In order to run CoMotion and the corresponding visualization, the neutral SMPL body model is required. Please go to the [SMPL website](https://smpl.is.tue.mpg.de/) and follow the provided instructions to download the model (version 1.1.0). After downloading, copy `basicmodel_neutral_lbs_10_207_0_v1.1.0.pkl` to `src/comotion_demo/data/smpl/SMPL_NEUTRAL.pkl` (we rename the file to be compatible with the visualization library `aitviewer`).
48 |
49 | ### Run CoMotion
50 |
51 | We provide a demo script that takes either a video file or a directory of images as input. To run it, call:
52 |
53 | ```bash
54 | python demo.py -i path/to/video.mp4 -o results/
55 | ```
56 |
57 | Optional arguments include `--start-frame` and `--num-frames` to select subsets of the video to run on. The network will save a `.pt` file with all of the detected SMPL pose parameters as well as a rendered `.mp4` with the predictions overlaid on the input video. We also automatically produce a `.txt` file in the `MOT` format with bounding boxes compatible with most standard tracking evaluation code. If you wish to skip the visualization, add the flag `--skip-visualization`.
58 |
59 | The demo also supports running on a single image, which is inferred automatically if the provided input path has a `.png` or `.jpeg/.jpg` suffix:
60 |
61 | ```bash
62 | python demo.py -i path/to/image.jpg -o results/
63 | ```
64 |
65 | In this case, we save a `.pt` file with the detected SMPL poses as well as 2D and 3D coordinates and confidences associated with each detection.
66 |
67 | > [!TIP]
68 | >
69 | > - If you encounter an error that `libc++.1.dylib` is not found, resolve it with `conda install libcxx`.
70 | > - For headless rendering on a remote server, you may encounter an error like `XOpenDisplay: cannot open display`. In this case, start a virtual display using `Xvfb :0 -screen 0 640x480x24 & export DISPLAY=:0.0`. You may need to install `xvfb` first (`apt install xvfb`).
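For video inputs, the saved `.pt` file can be inspected directly with PyTorch. The snippet below is a minimal sketch: the path `results/video.pt` is only an example, and the listed keys (`id`, `frame_idx`, `pose`, `trans`, `betas`) are the ones written by `demo.py` when tracking a video.

```python
import torch

# Load the predictions saved by demo.py (example path).
preds = torch.load("results/video.pt", map_location="cpu", weights_only=False)

# Keys for video inputs: "id", "frame_idx", "pose", "trans", "betas".
for key, value in preds.items():
    print(key, tuple(value.shape))

# Collect the per-frame root translation of one track.
track_id = preds["id"][0].item()
mask = preds["id"] == track_id
trajectory = preds["trans"][mask]  # one (x, y, z) translation per matched frame
print(f"track {track_id}: {trajectory.shape[0]} frames")
```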
71 | 72 | ## Citation 73 | 74 | If you find our work useful, please cite the following paper: 75 | 76 | ```bibtex 77 | @inproceedings{newell2025comotion, 78 | title = {CoMotion: Concurrent Multi-person 3D Motion}, 79 | author = {Alejandro Newell and Peiyun Hu and Lahav Lipson and Stephan R. Richter and Vladlen Koltun}, 80 | booktitle = {International Conference on Learning Representations}, 81 | year = {2025}, 82 | url = {https://openreview.net/forum?id=qKu6KWPgxt}, 83 | } 84 | ``` 85 | 86 | ## License 87 | 88 | This sample code is released under the [LICENSE](LICENSE.md) terms. 89 | 90 | The model weights are released under the [MODEL LICENSE](LICENSE_MODEL.md) terms. 91 | 92 | ## Acknowledgements 93 | 94 | Our codebase is built using multiple open source contributions, please see [Acknowledgements](ACKNOWLEDGEMENTS.md) for more details. 95 | 96 | Please check the paper for a complete list of references and datasets used in this work. 97 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | """Demo CoMotion with a video file or a directory of images.""" 3 | 4 | import logging 5 | import os 6 | import shutil 7 | import tempfile 8 | from pathlib import Path 9 | 10 | import click 11 | import numpy as np 12 | import torch 13 | from PIL import Image 14 | from tqdm import tqdm 15 | 16 | from comotion_demo.models import comotion 17 | from comotion_demo.utils import dataloading, helper 18 | from comotion_demo.utils import track as track_utils 19 | 20 | try: 21 | from aitviewer.configuration import CONFIG 22 | from aitviewer.headless import HeadlessRenderer 23 | from aitviewer.renderables.billboard import Billboard 24 | from aitviewer.renderables.smpl import SMPLLayer, SMPLSequence 25 | from aitviewer.scene.camera import OpenCVCamera 26 | 27 | comotion_model_dir = Path(comotion.__file__).parent 28 | CONFIG.smplx_models = os.path.join(comotion_model_dir, "../data") 29 | CONFIG.window_type = "pyqt6" 30 | aitviewer_available = True 31 | 32 | except ModuleNotFoundError: 33 | print( 34 | "WARNING: Skipped aitviewer import, ensure it is installed to run visualization." 
35 | ) 36 | aitviewer_available = False 37 | 38 | 39 | logging.basicConfig( 40 | level=logging.INFO, 41 | format="%(asctime)s - %(levelname)s - %(funcName)s - %(message)s", 42 | datefmt="%Y-%m-%d %H:%M:%S", 43 | ) 44 | 45 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 46 | use_mps = torch.mps.is_available() 47 | 48 | 49 | def prepare_scene(viewer, width, height, K, image_paths, fps=30): 50 | """Prepare the scene for AITViewer rendering.""" 51 | viewer.reset() 52 | viewer.scene.floor.enabled = False 53 | viewer.scene.origin.enabled = False 54 | extrinsics = np.eye(4)[:3] 55 | 56 | # Initialize camera 57 | cam = OpenCVCamera(K, extrinsics, cols=width, rows=height, viewer=viewer) 58 | viewer.scene.add(cam) 59 | viewer.scene.camera.position = [0, 0, -5] 60 | viewer.scene.camera.target = [0, 0, 10] 61 | viewer.auto_set_camera_target = False 62 | viewer.set_temp_camera(cam) 63 | viewer.playback_fps = fps 64 | 65 | # "billboard" display for video frames 66 | billboard = Billboard.from_camera_and_distance( 67 | cam, 100.0, cols=width, rows=height, textures=image_paths 68 | ) 69 | viewer.scene.add(billboard) 70 | 71 | 72 | def add_pose_to_scene( 73 | viewer, 74 | smpl_layer, 75 | betas, 76 | pose, 77 | trans, 78 | color=(0.6, 0.6, 0.6), 79 | alpha=1, 80 | color_ref=None, 81 | ): 82 | """Add estimated poses to the rendered scene.""" 83 | if betas.ndim == 2: 84 | betas = betas[None] 85 | pose = pose[None] 86 | trans = trans[None] 87 | 88 | poses_root = pose[..., :3] 89 | poses_body = pose[..., 3:] 90 | max_people = pose.shape[1] 91 | 92 | if (betas != 0).any(): 93 | for person_idx in range(max_people): 94 | if color_ref is None: 95 | person_color = color 96 | else: 97 | person_color = color_ref[person_idx % len(color_ref)] * 0.4 + 0.3 98 | person_color = [c_ for c_ in person_color] + [alpha] 99 | 100 | valid_vals = (betas[:, person_idx] != 0).any(-1) 101 | idx_range = valid_vals.nonzero() 102 | if len(idx_range) > 0: 103 | trans[~valid_vals][..., 2] = -10000 104 | viewer.scene.add( 105 | SMPLSequence( 106 | smpl_layer=smpl_layer, 107 | betas=betas[:, person_idx], 108 | poses_root=poses_root[:, person_idx], 109 | poses_body=poses_body[:, person_idx], 110 | trans=trans[:, person_idx], 111 | color=person_color, 112 | ) 113 | ) 114 | 115 | 116 | def visualize_poses( 117 | input_path, 118 | cache_path, 119 | video_path, 120 | start_frame, 121 | num_frames, 122 | frameskip=1, 123 | color=(0.6, 0.6, 0.6), 124 | alpha=1, 125 | fps=30, 126 | ): 127 | """Visualize SMPL poses.""" 128 | logging.info(f"Rendering SMPL video: {input_path}") 129 | 130 | # Prepare temporary directory with saved images 131 | tmp_vis_dir = Path(tempfile.mkdtemp()) 132 | 133 | frame_idx = 0 134 | image_paths = [] 135 | for image, K in dataloading.yield_image_and_K( 136 | input_path, start_frame, num_frames, frameskip 137 | ): 138 | image_height, image_width = image.shape[-2:] 139 | image = dataloading.convert_tensor_to_image(image) 140 | image_paths.append(f"{tmp_vis_dir}/{frame_idx:06d}.jpg") 141 | Image.fromarray(image).save(image_paths[-1]) 142 | frame_idx += 1 143 | 144 | # Initialize viewer 145 | viewer = HeadlessRenderer(size=(image_width, image_height)) 146 | 147 | if dataloading.is_a_video(input_path): 148 | fps = int(dataloading.get_input_video_fps(input_path)) 149 | 150 | prepare_scene(viewer, image_width, image_height, K.cpu().numpy(), image_paths, fps) 151 | 152 | # Prepare SMPL poses 153 | smpl_layer = SMPLLayer(model_type="smpl", gender="neutral") 154 | if not cache_path.exists(): 155 | 
logging.warning("No detections found.") 156 | else: 157 | preds = torch.load(cache_path, weights_only=False, map_location="cpu") 158 | track_subset = track_utils.query_range(preds, 0, frame_idx - 1) 159 | id_lookup = track_subset["id"].max(0)[0] 160 | color_ref = helper.color_ref[id_lookup % len(helper.color_ref)] 161 | if len(id_lookup) == 1: 162 | color_ref = [color_ref] 163 | 164 | betas = track_subset["betas"] 165 | pose = track_subset["pose"] 166 | trans = track_subset["trans"] 167 | 168 | add_pose_to_scene( 169 | viewer, smpl_layer, betas, pose, trans, color, alpha, color_ref 170 | ) 171 | 172 | # Save rendered scene 173 | viewer.save_video( 174 | video_dir=str(video_path), 175 | output_fps=fps, 176 | ensure_no_overwrite=False, 177 | ) 178 | 179 | # Remove temporary directory 180 | shutil.rmtree(tmp_vis_dir) 181 | 182 | 183 | def run_detection(input_path, cache_path, skip_visualization=False, model=None): 184 | """Run model and visualize detections on single image.""" 185 | if model is None: 186 | model = comotion.CoMotion(use_coreml=use_mps) 187 | model.to(device).eval() 188 | 189 | # Load image 190 | image = np.array(Image.open(input_path)) 191 | image = dataloading.convert_image_to_tensor(image) 192 | K = dataloading.get_default_K(image) 193 | cropped_image, cropped_K = dataloading.prepare_network_inputs(image, K, device) 194 | 195 | # Get detections 196 | detections = model.detection_model(cropped_image, cropped_K) 197 | detections = comotion.detect.decode_network_outputs( 198 | K.to(device), 199 | model.smpl_decoder, 200 | detections, 201 | std=0.15, # Adjust NMS sensitivity 202 | conf_thr=0.25, # Adjust confidence cutoff 203 | ) 204 | 205 | detections = {k: v[0].cpu() for k, v in detections.items()} 206 | torch.save(detections, cache_path) 207 | 208 | if not skip_visualization: 209 | # Initialize viewer 210 | image_height, image_width = image.shape[-2:] 211 | viewer = HeadlessRenderer(size=(image_width, image_height)) 212 | prepare_scene( 213 | viewer, image_width, image_height, K.cpu().numpy(), [str(input_path)] 214 | ) 215 | 216 | # Prepare SMPL poses 217 | smpl_layer = SMPLLayer(model_type="smpl", gender="neutral") 218 | add_pose_to_scene( 219 | viewer, 220 | smpl_layer, 221 | detections["betas"], 222 | detections["pose"], 223 | detections["trans"], 224 | ) 225 | 226 | # Save rendered scene 227 | viewer.save_frame(str(cache_path).replace(".pt", ".png")) 228 | 229 | 230 | def track_poses( 231 | input_path, cache_path, start_frame, num_frames, frameskip=1, model=None 232 | ): 233 | """Track poses over a video or a directory of images.""" 234 | if model is None: 235 | model = comotion.CoMotion(use_coreml=use_mps) 236 | model.to(device).eval() 237 | 238 | detections = [] 239 | tracks = [] 240 | 241 | initialized = False 242 | for image, K in tqdm( 243 | dataloading.yield_image_and_K(input_path, start_frame, num_frames, frameskip), 244 | desc="Running CoMotion", 245 | ): 246 | if not initialized: 247 | image_res = image.shape[-2:] 248 | model.init_tracks(image_res) 249 | initialized = True 250 | 251 | detection, track = model(image, K, use_mps=use_mps) 252 | detection = {k: v.cpu() for k, v in detection.items()} 253 | track = track.cpu() 254 | detections.append(detection) 255 | tracks.append(track) 256 | 257 | detections = {k: [d[k] for d in detections] for k in detections[0].keys()} 258 | tracks = torch.stack(tracks, 1) 259 | tracks = {k: getattr(tracks, k) for k in ["id", "pose", "trans", "betas"]} 260 | 261 | track_ref = track_utils.cleanup_tracks( 262 | {"detections": 
detections, "tracks": tracks}, 263 | K, 264 | model.smpl_decoder.cpu(), 265 | min_matched_frames=1, 266 | ) 267 | if track_ref: 268 | frame_idxs, track_idxs = track_utils.convert_to_idxs( 269 | track_ref, tracks["id"][0].squeeze(-1).long() 270 | ) 271 | preds = {k: v[0, frame_idxs, track_idxs] for k, v in tracks.items()} 272 | preds["id"] = preds["id"].squeeze(-1).long() 273 | preds["frame_idx"] = frame_idxs 274 | torch.save(preds, cache_path) 275 | 276 | # Save bounding box tracks in MOT format 277 | bboxes = track_utils.bboxes_from_smpl( 278 | model.smpl_decoder, 279 | {k: preds[k] for k in ["betas", "pose", "trans"]}, 280 | image_res, 281 | K, 282 | ) 283 | with open(str(cache_path).replace(".pt", ".txt"), "w") as f: 284 | f.write(track_utils.convert_to_mot(preds["id"], preds["frame_idx"], bboxes)) 285 | 286 | 287 | @click.command() 288 | @click.option( 289 | "-i", 290 | "--input-path", 291 | required=True, 292 | type=click.Path(exists=True, path_type=Path), 293 | help="Path to the input video, a directory of images, or a single input image.", 294 | ) 295 | @click.option( 296 | "-o", 297 | "--output-dir", 298 | required=True, 299 | type=click.Path(exists=False, path_type=Path), 300 | help="Path to the output directory.", 301 | ) 302 | @click.option( 303 | "-s", 304 | "--start-frame", 305 | default=0, 306 | type=int, 307 | help="Frame to start with.", 308 | ) 309 | @click.option( 310 | "-n", 311 | "--num-frames", 312 | default=1_000_000_000, 313 | type=int, 314 | help="Number of frames to process.", 315 | ) 316 | @click.option( 317 | "--skip-visualization", 318 | is_flag=True, 319 | help="Whether to skip rendering the output SMPL meshes.", 320 | ) 321 | @click.option( 322 | "--frameskip", 323 | default=1, 324 | type=int, 325 | help="Subsample video frames (e.g. frameskip=2 processes every other frame).", 326 | ) 327 | def main( 328 | input_path, output_dir, start_frame, num_frames, skip_visualization, frameskip 329 | ): 330 | """Demo entry point.""" 331 | output_dir.mkdir(parents=True, exist_ok=True) 332 | input_name = input_path.stem 333 | skip_visualization = skip_visualization | (not aitviewer_available) 334 | 335 | cache_path = output_dir / f"{input_name}.pt" 336 | if input_path.suffix.lower() in dataloading.IMAGE_EXTENSIONS: 337 | # Run and visualize detections for a single image 338 | run_detection(input_path, cache_path, skip_visualization) 339 | else: 340 | # Run unrolled tracking on a full video 341 | track_poses(input_path, cache_path, start_frame, num_frames, frameskip) 342 | if not skip_visualization: 343 | video_path = output_dir / f"{input_name}.mp4" 344 | visualize_poses( 345 | input_path, cache_path, video_path, start_frame, num_frames, frameskip 346 | ) 347 | 348 | 349 | if __name__ == "__main__": 350 | main() 351 | -------------------------------------------------------------------------------- /get_pretrained_models.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # 3 | # For licensing see accompanying LICENSE file. 4 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 5 | 6 | cd src/comotion_demo/data 7 | 8 | # Download and extract model checkpoints 9 | wget https://ml-site.cdn-apple.com/models/comotion/demo_checkpoints.tar.gz 10 | tar zxf demo_checkpoints.tar.gz 11 | rm demo_checkpoints.tar.gz 12 | 13 | cd ../../.. 
14 |
--------------------------------------------------------------------------------
/posetrack21_eval/README.md:
--------------------------------------------------------------------------------
1 | # PoseTrack21 MOT Evaluation
2 |
3 | This is a slight adaptation of the [PoseTrack21 evaluation code](https://github.com/anDoer/PoseTrack21). We opt to use the bounding-box MOT evaluation here since our model does not output the appropriate format for the keypoint-based metrics. We make several adjustments to the MOT evaluation code.
4 |
5 | Major changes that affect tracking evaluation:
6 | - as discussed in the paper, we change the IOU calculation for ignore regions (see the sketch at the end of this file). This is a one-line change:
7 | ```
8 | # previous
9 | region_ious[j, i] = poly_intersection / poly_union
10 | # ours
11 | region_ious[j, i] = poly_intersection / det_boxes[j].area
12 | ```
13 | - some of the PoseTrack sequences are not annotated at every frame, but the code by default would penalize false positives on un-annotated frames. We have updated the code to ignore detections on frames with no ground-truth annotations.
14 |
15 | Minor updates:
16 | - we no longer load images in the `PTSequence` class
17 | - we update the dtype behavior in `motmetrics` to support more recent versions of `numpy`
18 | - turn on `use_ignore_regions` by default (no flag to set at runtime)
19 |
20 | ## Setup
21 |
22 | A few extra packages are needed to run the evaluation code; these can be installed with:
23 |
24 | ```
25 | conda install geos
26 | pip install lap pandas shapely==1.7.1 xmltodict
27 | ```
28 |
29 | Follow the instructions provided by the authors in the [original repository](https://github.com/anDoer/PoseTrack21) to obtain a copy of the dataset and annotations.
30 |
31 | ## Running the evaluation
32 |
33 | To run:
34 | ```
35 | python evaluate_mot.py --dataset_path $PATH_TO_DATASET_ROOT \
36 |                        --mot_path $PATH_TO_RESPECTIVE_MOT_FOLDER \
37 |                        --result_path $FOLDER_WITH_YOUR_RESULTS
38 | ```
39 |
40 | We also support evaluation on individual sequences with:
41 | ```
42 | python evaluate_mot.py ... --sequence_choice [SEQUENCE_ID_0] [SEQUENCE_ID_1] ...
43 | ```
44 | where `SEQUENCE_ID` is the integer sequence ID for a given PoseTrack example (e.g. `000342`); there is no need to include the `_mpii_test` suffix.
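The modified criterion above measures how much of a detection falls inside an ignore region relative to the detection's own area, rather than intersection-over-union. A minimal sketch of that check using `shapely` (installed in the Setup step) is shown below; the helper name and toy coordinates are illustrative and not part of the evaluation code, and the `0.1` threshold matches the default `ignore_iou_thres`.

```python
from shapely.geometry import Polygon

def ignore_overlap(det_box: Polygon, ignore_region: Polygon) -> float:
    """Fraction of the detection box covered by the ignore region."""
    inter = ignore_region.intersection(det_box).area
    return inter / det_box.area if det_box.area > 0 else 0.0

# Toy example: a detection box half covered by an ignore region.
det = Polygon([(0, 0), (2, 0), (2, 2), (0, 2)])
region = Polygon([(1, 0), (3, 0), (3, 2), (1, 2)])
print(ignore_overlap(det, region))  # 0.5 -> flagged as ignorable (above the 0.1 threshold)
```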
45 | -------------------------------------------------------------------------------- /posetrack21_eval/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/posetrack21_eval/datasets/__init__.py -------------------------------------------------------------------------------- /posetrack21_eval/datasets/pt_sequence.py: -------------------------------------------------------------------------------- 1 | import configparser 2 | import csv 3 | import os 4 | import os.path as osp 5 | 6 | import numpy as np 7 | from PIL import Image 8 | 9 | import json 10 | 11 | from shapely.geometry import box, Polygon, MultiPolygon 12 | 13 | def ignore_regions(): 14 | # remove boxes from in ignore regions 15 | ignore_box_candidates = set() 16 | ########################## 17 | ignore_iou_thres = 0.1 18 | ########################## 19 | 20 | if len(ignore_x) > 0: 21 | if not isinstance(ignore_x[0], list): 22 | ignore_x = [ignore_x] 23 | ignore_y = [ignore_y] 24 | 25 | # build ignore regions: 26 | ignore_regions = [] 27 | for r_idx in range(len(ignore_x)): 28 | region = [] 29 | 30 | for x, y in zip(ignore_x[r_idx], ignore_y[r_idx]): 31 | region.append([x, y]) 32 | 33 | ignore_region = Polygon(region) 34 | ignore_regions.append(ignore_region) 35 | 36 | region_ious = np.zeros((len(ignore_regions), num_det), dtype=np.float32) 37 | det_boxes = [] 38 | for j in rrange(num_det): 39 | x1 = det[j, 0] 40 | y1 = det[j, 1] 41 | x2 = det[j, 2] 42 | y2 = det[j, 3] 43 | 44 | box_poly = Polygon([ 45 | [x1, y1], 46 | [x2, y1], 47 | [x2, y2], 48 | [x1, y2], 49 | [x1, y1], 50 | ]) 51 | 52 | det_boxes.append(box_poly) 53 | 54 | for i in xrange(len(ignore_regions)): 55 | for j in xrange(num_det): 56 | if ignore_regions[i].is_valid: 57 | poly_intersection = ignore_regions[i].intersection(det_boxes[j]).area 58 | poly_union = ignore_regions[i].union(det_boxes[j]).area 59 | else: 60 | multi_poly = ignore_regions[i].buffer(0) 61 | poly_intersection = 0 62 | poly_union = 0 63 | 64 | if isinstance(multi_poly, Polygon): 65 | poly_intersection = multi_poly.intersection(det_boxes[j]).area 66 | poly_union = multi_poly.union(det_boxes[j]).area 67 | else: 68 | multi_poly = ignore_regions[i].buffer(0) 69 | poly_intersection = 0 70 | poly_union = 0 71 | 72 | if isinstance(multi_poly, Polygon): 73 | poly_intersection = multi_poly.intersection(det_boxes[j]).area 74 | poly_union = multi_poly.union(det_boxes[j]).area 75 | else: 76 | for poly in multi_poly: 77 | poly_intersection += poly.intersection(det_boxes[j]).area 78 | poly_union += poly.union(det_boxes[j]).area 79 | 80 | region_ious[i, j] = poly_intersection / poly_union 81 | 82 | candidates = np.argwhere(region_ious[i] > ignore_iou_thres) 83 | if len(candidates) > 0: 84 | candidates = candidates[:, 0].tolist() 85 | ignore_box_candidates.update(candidates) 86 | 87 | ious = np.zeros((num_gt, num_det), dtype=np.float32) 88 | for i in xrange(num_gt): 89 | for j in xrange(num_det): 90 | ious[i, j] = _compute_iou(gt_boxes[i], det[j, :4]) 91 | tfmat = (ious >= iou_thresh) 92 | # for each det, keep only the largest iou of all the gt 93 | for j in xrange(num_det): 94 | largest_ind = np.argmax(ious[:, j]) 95 | for i in xrange(num_gt): 96 | if i != largest_ind: 97 | tfmat[i, j] = False 98 | # for each gt, keep only the largest iou of all the det 99 | for i in xrange(num_gt): 100 | largest_ind = np.argmax(ious[i, :]) 101 | for j in xrange(num_det): 102 | if j != 
largest_ind: 103 | tfmat[i, j] = False 104 | for j in xrange(num_det): 105 | if j in ignore_box_candidates and not tfmat[:, j].any(): 106 | # we have a detection in ignore region 107 | detections_to_ignore[gallery_idx].append(j) 108 | continue 109 | 110 | y_score.append(det[j, -1]) 111 | if tfmat[:, j].any(): 112 | y_true.append(True) 113 | else: 114 | y_true.append(False) 115 | count_tp += tfmat.sum() 116 | count_gt += num_gt 117 | 118 | 119 | class PTSequence(): 120 | """Multiple Object Tracking Dataset. 121 | """ 122 | def __init__(self, seq_name, mot_dir, dataset_path, vis_threshold=0.0): 123 | self._seq_name = seq_name 124 | self._vis_threshold = vis_threshold 125 | 126 | self._mot_dir = mot_dir 127 | self.dataset_path = dataset_path 128 | self._folders = os.listdir(self._mot_dir) 129 | 130 | if seq_name is not None: 131 | assert seq_name in self._folders, \ 132 | 'Image set does not exist: {}'.format(seq_name) 133 | 134 | self.data, self.no_gt, self.ignore = self._sequence() 135 | else: 136 | self.data = [] 137 | self.no_gt = False 138 | 139 | def __len__(self): 140 | return len(self.data) 141 | 142 | def __getitem__(self, idx): 143 | """Return the ith image converted to blob""" 144 | data = self.data[idx] 145 | # img = Image.open(data['im_path']).convert("RGB") 146 | # img = np.asarray(img) 147 | 148 | sample = {} 149 | # sample['img'] = img 150 | sample['dets'] = np.array([det[:4] for det in data['dets']]) 151 | sample['img_path'] = data['im_path'] 152 | sample['gt'] = data['gt'] 153 | sample['vis'] = data['vis'] 154 | 155 | return sample 156 | 157 | def _sequence(self): 158 | seq_name = self._seq_name 159 | seq_path = osp.join(self._mot_dir, seq_name) 160 | 161 | config_file = os.path.join(seq_path, 'image_info.json') 162 | 163 | assert osp.exists(config_file), \ 164 | 'Config file does not exist: {}'.format(config_file) 165 | 166 | with open(config_file, 'r') as file: 167 | image_info = json.load(file) 168 | 169 | seqLength = len(image_info) 170 | gt_file = osp.join(seq_path, 'gt', 'gt.txt') 171 | 172 | total = [] 173 | 174 | visibility = {} 175 | boxes = {} 176 | dets = {} 177 | ignore_x = {} 178 | ignore_y = {} 179 | 180 | for i in range(seqLength): 181 | frame_index = image_info[i]['frame_index'] 182 | boxes[frame_index] = {} 183 | visibility[frame_index] = {} 184 | dets[frame_index] = [] 185 | ignore_x[frame_index] = image_info[i]['ignore_regions_x'] 186 | ignore_y[frame_index] = image_info[i]['ignore_regions_y'] 187 | 188 | no_gt = False 189 | if osp.exists(gt_file): 190 | with open(gt_file, "r") as inf: 191 | reader = csv.reader(inf, delimiter=',') 192 | for row in reader: 193 | # class person, certainity 1, visibility >= 0.25 194 | row[7] = '1' 195 | row[8] = '1' 196 | if int(row[6]) == 1 and int(row[7]) == 1 and float(row[8]) >= self._vis_threshold: 197 | # Make pixel indexes 0-based, should already be 0-based (or not) 198 | x1 = int(float(row[2])) # - 1 199 | y1 = int(float(row[3])) # - 1 200 | # This -1 accounts for the width (width of 1 x1=x2) 201 | x2 = x1 + int(float(row[4])) #- 1 202 | y2 = y1 + int(float(row[5])) # - 1 203 | bb = np.array([x1,y1,x2,y2], dtype=np.float32) 204 | 205 | boxes[int(row[0])][int(row[1])] = bb 206 | visibility[int(row[0])][int(row[1])] = float(row[8]) 207 | else: 208 | no_gt = True 209 | 210 | det_file = osp.join(seq_path, 'det', 'det.txt') 211 | 212 | if osp.exists(det_file): 213 | with open(det_file, "r") as inf: 214 | reader = csv.reader(inf, delimiter=',') 215 | for row in reader: 216 | x1 = float(row[2]) - 1 217 | y1 = float(row[3]) - 
1 218 | # This -1 accounts for the width (width of 1 x1=x2) 219 | x2 = x1 + float(row[4]) - 1 220 | y2 = y1 + float(row[5]) - 1 221 | score = float(row[6]) 222 | bb = np.array([x1,y1,x2,y2, score], dtype=np.float32) 223 | dets[int(float(row[0]))].append(bb) 224 | 225 | dataset_path = self.dataset_path 226 | for i in range(seqLength): 227 | im_path = os.path.join(dataset_path, image_info[i]['file_name']) 228 | frame_index = image_info[i]['frame_index'] 229 | 230 | sample = {'gt':boxes[frame_index], 231 | 'im_path':im_path, 232 | 'vis':visibility[frame_index], 233 | 'dets':dets[frame_index]} 234 | 235 | total.append(sample) 236 | 237 | return total, no_gt, {'ignore_x': ignore_x, 'ignore_y': ignore_y} 238 | 239 | def __str__(self): 240 | return self._seq_name 241 | 242 | def write_results(self, all_tracks, output_dir): 243 | """Write the tracks in the format for MOT16/MOT17 sumbission 244 | 245 | all_tracks: dictionary with 1 dictionary for every track with {..., i:np.array([x1,y1,x2,y2]), ...} at key track_num 246 | 247 | Each file contains these lines: 248 | , , , , , , , , , 249 | """ 250 | 251 | #format_str = "{}, -1, {}, {}, {}, {}, {}, -1, -1, -1" 252 | 253 | if not os.path.exists(output_dir): 254 | os.makedirs(output_dir) 255 | 256 | with open(osp.join(output_dir, self._seq_name), "w") as of: 257 | writer = csv.writer(of, delimiter=',') 258 | for i, track in all_tracks.items(): 259 | for frame, bb in track.items(): 260 | x1 = bb[0] 261 | y1 = bb[1] 262 | x2 = bb[2] 263 | y2 = bb[3] 264 | writer.writerow( 265 | [frame + 1, 266 | i + 1, 267 | x1 + 1, 268 | y1 + 1, 269 | x2 - x1 + 1, 270 | y2 - y1 + 1, 271 | -1, -1, -1, -1]) 272 | 273 | def load_results(self, output_dir): 274 | file_path = osp.join(output_dir, self._seq_name) 275 | results = {} 276 | 277 | if not os.path.isfile(file_path): 278 | if os.path.isfile(f'{file_path}.txt'): 279 | file_path = f'{file_path}.txt' 280 | else: 281 | return results 282 | 283 | with open(file_path, "r") as of: 284 | csv_reader = csv.reader(of, delimiter=',') 285 | for row in csv_reader: 286 | frame_id, track_id = int(row[0]) - 1, int(row[1]) - 1 287 | 288 | if not track_id in results: 289 | results[track_id] = {} 290 | 291 | x1 = float(row[2]) - 1 292 | y1 = float(row[3]) - 1 293 | x2 = float(row[4]) - 1 + x1 294 | y2 = float(row[5]) - 1 + y1 295 | 296 | results[track_id][frame_id] = [x1, y1, x2, y2] 297 | 298 | return results 299 | -------------------------------------------------------------------------------- /posetrack21_eval/datasets/pt_warper.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import os 3 | #import torch 4 | #from torch.utils.data import Dataset 5 | 6 | from .pt_sequence import PTSequence 7 | 8 | 9 | class PTWrapper(): 10 | """A Wrapper for the MOT_Sequence class to return multiple sequences.""" 11 | 12 | def __init__(self, mot_dir, dataset_path, vis_threshold=0.0): 13 | 'seq_name, mot_dir, dataset_path, vis_threshold=0.0' 14 | 15 | sequences = os.listdir(mot_dir) 16 | 17 | self._data = [] 18 | for s in sequences: 19 | self._data.append(PTSequence(s, mot_dir, dataset_path, vis_threshold)) 20 | 21 | def __len__(self): 22 | return len(self._data) 23 | 24 | def __getitem__(self, idx): 25 | return self._data[idx] 26 | -------------------------------------------------------------------------------- /posetrack21_eval/evaluate_mot.py: -------------------------------------------------------------------------------- 1 | if __name__ == "__main__": 2 | from datasets.pt_warper 
import PTWrapper 3 | import motmetrics as mm 4 | else: 5 | from .datasets.pt_warper import PTWrapper 6 | from . import motmetrics as mm 7 | 8 | import numpy as np 9 | mm.lap.default_solver = 'lap' 10 | from shapely.geometry import box, Polygon, MultiPolygon 11 | import argparse 12 | import os 13 | from tqdm import tqdm 14 | 15 | def get_mot_accum(results, seq, ignore_iou_thres=0.1, use_ignore_regions=False): 16 | def get_frame_idx(impath): 17 | return int(impath.split("/")[-1].replace(".jpg", "")) 18 | 19 | valid_frames = [get_frame_idx(d["im_path"]) for d in seq.data if len(d["gt"]) > 0] 20 | 21 | mot_accum = mm.MOTAccumulator(auto_id=True) 22 | 23 | ignore_regions_x = seq.ignore['ignore_x'] 24 | ignore_regions_y = seq.ignore['ignore_y'] 25 | for i, data in enumerate(seq): 26 | if i in valid_frames: 27 | # i corresponds to image index 28 | ignore_x = ignore_regions_x[i + 1] 29 | ignore_y = ignore_regions_y[i + 1] 30 | 31 | if len(ignore_x) > 0: 32 | if not isinstance(ignore_x[0], list): 33 | ignore_x = [ignore_x] 34 | if len(ignore_y) > 0: 35 | if not isinstance(ignore_y[0], list): 36 | ignore_y = [ignore_y] 37 | 38 | ignore_regions = [] 39 | for r_idx in range(len(ignore_x)): 40 | region = [] 41 | if len(ignore_x[r_idx]) == 0 or len(ignore_y[r_idx]) == 0: 42 | continue 43 | 44 | for x, y in zip(ignore_x[r_idx], ignore_y[r_idx]): 45 | region.append([x, y]) 46 | 47 | try: 48 | ignore_region = Polygon(region) 49 | ignore_regions.append(ignore_region) 50 | except: 51 | assert False 52 | 53 | gt = data['gt'] 54 | gt_ids = [] 55 | gt_boxes = [] 56 | 57 | if gt: 58 | for gt_id, box in gt.items(): 59 | gt_ids.append(gt_id) 60 | gt_boxes.append(box) 61 | 62 | gt_boxes = np.stack(gt_boxes, axis=0) 63 | # x1, y1, x2, y2 --> x1, y1, width, height 64 | gt_boxes = np.stack((gt_boxes[:, 0], 65 | gt_boxes[:, 1], 66 | gt_boxes[:, 2] - gt_boxes[:, 0], 67 | gt_boxes[:, 3] - gt_boxes[:, 1]), 68 | axis=1) 69 | else: 70 | gt_boxes = np.array([]) 71 | 72 | track_ids = [] 73 | track_boxes = [] 74 | det_boxes = [] 75 | ignore_candidates = [] 76 | 77 | for track_id, frames in results.items(): 78 | if i in frames: 79 | track_ids.append(track_id) 80 | # frames = x1, y1, x2, y2, score 81 | track_boxes.append(frames[i][:4]) 82 | box = frames[i][:4] 83 | 84 | if use_ignore_regions: 85 | x1 = box[0] 86 | y1 = box[1] 87 | x2 = box[2] 88 | y2 = box[3] 89 | 90 | box_poly = Polygon([ 91 | [x1, y1], 92 | [x2, y1], 93 | [x2, y2], 94 | [x1, y2], 95 | [x1, y1], 96 | ]) 97 | 98 | det_boxes.append(box_poly) 99 | 100 | if use_ignore_regions: 101 | # obtain candidate detections for ignore regions 102 | region_ious = np.zeros((len(det_boxes), len(ignore_regions)), dtype=np.float32) 103 | for i in range(len(ignore_regions)): 104 | for j in range(len(det_boxes)): 105 | if ignore_regions[i].is_valid: 106 | poly_intersection = ignore_regions[i].intersection(det_boxes[j]).area 107 | poly_union = ignore_regions[i].union(det_boxes[j]).area 108 | else: 109 | multi_poly = ignore_regions[i].buffer(0) 110 | poly_intersection = 0 111 | poly_union = 0 112 | 113 | if isinstance(multi_poly, Polygon): 114 | poly_intersection = multi_poly.intersection(det_boxes[j]).area 115 | poly_union = multi_poly.union(det_boxes[j]).area 116 | elif isinstance(multi_poly, MultiPolygon): 117 | poly_intersection = multi_poly.intersection(det_boxes[j]).area 118 | poly_union = multi_poly.union(det_boxes[j]).area 119 | else: 120 | for poly in multi_poly: 121 | poly_intersection += poly.intersection(det_boxes[j]).area 122 | poly_union += poly.union(det_boxes[j]).area 
123 | 124 | # For reference: original IOU calculation 125 | # - region_ious[j, i] = poly_intersection / poly_union 126 | 127 | # Updated IOU calculation 128 | region_ious[j, i] = poly_intersection / det_boxes[j].area 129 | 130 | if len(ignore_regions) > 0 and len(det_boxes) > 0: 131 | ignore_candidates = np.where((region_ious > ignore_iou_thres).max(axis=1))[0].tolist() 132 | 133 | if track_ids: 134 | track_boxes = np.stack(track_boxes, axis=0) 135 | # x1, y1, x2, y2 --> x1, y1, width, height 136 | track_boxes = np.stack((track_boxes[:, 0], 137 | track_boxes[:, 1], 138 | track_boxes[:, 2] - track_boxes[:, 0], 139 | track_boxes[:, 3] - track_boxes[:, 1]), 140 | axis=1) 141 | else: 142 | track_boxes = np.array([]) 143 | 144 | distance = mm.distances.iou_matrix(gt_boxes, track_boxes, max_iou=0.5) 145 | 146 | mot_accum.update( 147 | gt_ids, 148 | track_ids, 149 | distance, 150 | ignore_candidates=ignore_candidates) 151 | 152 | return mot_accum 153 | 154 | 155 | def evaluate_mot_accums(accums, names, generate_overall=False): 156 | mh = mm.metrics.create() 157 | summary = mh.compute_many( 158 | accums, 159 | metrics=mm.metrics.motchallenge_metrics + ['num_matches', 'num_objects'], 160 | names=names, 161 | generate_overall=generate_overall, ) 162 | 163 | str_summary = mm.io.render_summary( 164 | summary, 165 | formatters=mh.formatters, 166 | namemap=mm.io.motchallenge_metric_names, ) 167 | print(str_summary) 168 | 169 | return str_summary 170 | 171 | def main(): 172 | parser = argparse.ArgumentParser() 173 | parser.add_argument('--mot_path', required=True, help='Path to GT MOT files') 174 | parser.add_argument('--dataset_path', required=True) 175 | parser.add_argument('--result_path', required=True) 176 | parser.add_argument('--ignore_iou_thres', default=0.1) 177 | parser.add_argument('--sequence_choice', type=str, nargs="+", default=[]) 178 | args = parser.parse_args() 179 | 180 | mot_accums = [] 181 | 182 | mot_path = args.mot_path 183 | dataset_path = args.dataset_path 184 | result_path = args.result_path 185 | use_ignore_regions = True 186 | ignore_iou_thres = args.ignore_iou_thres 187 | 188 | dataset = PTWrapper(mot_path, dataset_path, vis_threshold=0.1) 189 | sequence_names = [str(s) for s in dataset if not s.no_gt] 190 | if len(args.sequence_choice) > 0: 191 | sequence_names = [f"{s}_mpii_test" for s in args.sequence_choice] 192 | 193 | for seq_idx, seq in enumerate(tqdm(dataset)): 194 | if seq._seq_name not in sequence_names: 195 | continue 196 | 197 | if not os.path.isdir(result_path): 198 | raise FileNotFoundError(f"result path {result_path} does not exist") 199 | results = seq.load_results(result_path) 200 | if len(results) == 0: 201 | print(f"Results not provided for gt sequence {seq._seq_name}") 202 | mot_accums.append(get_mot_accum(results, 203 | seq, 204 | ignore_iou_thres=ignore_iou_thres, 205 | use_ignore_regions=use_ignore_regions)) 206 | 207 | if mot_accums: 208 | evaluate_mot_accums(mot_accums, sequence_names, generate_overall=True) 209 | 210 | if __name__ == '__main__': 211 | main() 212 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/__init__.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 
7 | 8 | """py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 9 | 10 | Christoph Heindl, 2017 11 | https://github.com/cheind/py-motmetrics 12 | """ 13 | 14 | from __future__ import absolute_import 15 | from __future__ import division 16 | from __future__ import print_function 17 | 18 | __all__ = [ 19 | 'distances', 20 | 'io', 21 | 'lap', 22 | 'metrics', 23 | 'utils', 24 | 'MOTAccumulator', 25 | ] 26 | 27 | from . import distances 28 | from . import io 29 | from . import lap 30 | from . import metrics 31 | from . import utils 32 | from .mot import MOTAccumulator 33 | 34 | # Needs to be last line 35 | __version__ = '1.2.0' 36 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/distances.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 7 | 8 | """Functions for comparing predictions and ground-truth.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import numpy as np 15 | 16 | from . import math_util 17 | 18 | 19 | def norm2squared_matrix(objs, hyps, max_d2=float('inf')): 20 | """Computes the squared Euclidean distance matrix between object and hypothesis points. 21 | 22 | Params 23 | ------ 24 | objs : NxM array 25 | Object points of dim M in rows 26 | hyps : KxM array 27 | Hypothesis points of dim M in rows 28 | 29 | Kwargs 30 | ------ 31 | max_d2 : float 32 | Maximum tolerable squared Euclidean distance. Object / hypothesis points 33 | with larger distance are set to np.nan signalling do-not-pair. Defaults 34 | to +inf 35 | 36 | Returns 37 | ------- 38 | C : NxK array 39 | Distance matrix containing pairwise distances or np.nan. 40 | """ 41 | 42 | objs = np.atleast_2d(objs).astype(float) 43 | hyps = np.atleast_2d(hyps).astype(float) 44 | 45 | if objs.size == 0 or hyps.size == 0: 46 | return np.empty((0, 0)) 47 | 48 | assert hyps.shape[1] == objs.shape[1], "Dimension mismatch" 49 | 50 | delta = objs[:, np.newaxis] - hyps[np.newaxis, :] 51 | C = np.sum(delta ** 2, axis=-1) 52 | 53 | C[C > max_d2] = np.nan 54 | return C 55 | 56 | 57 | def rect_min_max(r): 58 | min_pt = r[..., :2] 59 | size = r[..., 2:] 60 | max_pt = min_pt + size 61 | return min_pt, max_pt 62 | 63 | 64 | def boxiou(a, b): 65 | """Computes IOU of two rectangles.""" 66 | a_min, a_max = rect_min_max(a) 67 | b_min, b_max = rect_min_max(b) 68 | # Compute intersection. 69 | i_min = np.maximum(a_min, b_min) 70 | i_max = np.minimum(a_max, b_max) 71 | i_size = np.maximum(i_max - i_min, 0) 72 | i_vol = np.prod(i_size, axis=-1) 73 | # Get volume of union. 74 | a_size = np.maximum(a_max - a_min, 0) 75 | b_size = np.maximum(b_max - b_min, 0) 76 | a_vol = np.prod(a_size, axis=-1) 77 | b_vol = np.prod(b_size, axis=-1) 78 | u_vol = a_vol + b_vol - i_vol 79 | return np.where(i_vol == 0, np.zeros_like(i_vol, dtype=float), 80 | math_util.quiet_divide(i_vol, u_vol)) 81 | 82 | 83 | def iou_matrix(objs, hyps, max_iou=1.): 84 | """Computes 'intersection over union (IoU)' distance matrix between object and hypothesis rectangles. 85 | 86 | The IoU is computed as 87 | 88 | IoU(a,b) = 1. 
- isect(a, b) / union(a, b) 89 | 90 | where isect(a,b) is the area of intersection of two rectangles and union(a, b) the area of union. The 91 | IoU is bounded between zero and one. 0 when the rectangles overlap perfectly and 1 when the overlap is 92 | zero. 93 | 94 | Params 95 | ------ 96 | objs : Nx4 array 97 | Object rectangles (x,y,w,h) in rows 98 | hyps : Kx4 array 99 | Hypothesis rectangles (x,y,w,h) in rows 100 | 101 | Kwargs 102 | ------ 103 | max_iou : float 104 | Maximum tolerable overlap distance. Object / hypothesis points 105 | with larger distance are set to np.nan signalling do-not-pair. Defaults 106 | to 0.5 107 | 108 | Returns 109 | ------- 110 | C : NxK array 111 | Distance matrix containing pairwise distances or np.nan. 112 | """ 113 | 114 | if np.size(objs) == 0 or np.size(hyps) == 0: 115 | return np.empty((0, 0)) 116 | 117 | objs = np.asfarray(objs) 118 | hyps = np.asfarray(hyps) 119 | assert objs.shape[1] == 4 120 | assert hyps.shape[1] == 4 121 | iou = boxiou(objs[:, None], hyps[None, :]) 122 | dist = 1 - iou 123 | return np.where(dist > max_iou, np.nan, dist) 124 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/io.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 7 | 8 | """Functions for loading data and writing summaries.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | from enum import Enum 15 | import io 16 | 17 | import numpy as np 18 | import pandas as pd 19 | import scipy.io 20 | import xmltodict 21 | 22 | 23 | class Format(Enum): 24 | """Enumerates supported file formats.""" 25 | 26 | MOT16 = 'mot16' 27 | """Milan, Anton, et al. "Mot16: A benchmark for multi-object tracking." arXiv preprint arXiv:1603.00831 (2016).""" 28 | 29 | MOT15_2D = 'mot15-2D' 30 | """Leal-Taixe, Laura, et al. "MOTChallenge 2015: Towards a benchmark for multi-target tracking." arXiv preprint arXiv:1504.01942 (2015).""" 31 | 32 | VATIC_TXT = 'vatic-txt' 33 | """Vondrick, Carl, Donald Patterson, and Deva Ramanan. "Efficiently scaling up crowdsourced video annotation." International Journal of Computer Vision 101.1 (2013): 184-204. 34 | https://github.com/cvondrick/vatic 35 | """ 36 | 37 | DETRAC_MAT = 'detrac-mat' 38 | """Wen, Longyin et al. "UA-DETRAC: A New Benchmark and Protocol for Multi-Object Detection and Tracking." arXiv preprint arXiv:arXiv:1511.04136 (2016). 39 | http://detrac-db.rit.albany.edu/download 40 | """ 41 | 42 | DETRAC_XML = 'detrac-xml' 43 | """Wen, Longyin et al. "UA-DETRAC: A New Benchmark and Protocol for Multi-Object Detection and Tracking." arXiv preprint arXiv:arXiv:1511.04136 (2016). 44 | http://detrac-db.rit.albany.edu/download 45 | """ 46 | 47 | 48 | def load_motchallenge(fname, **kwargs): 49 | r"""Load MOT challenge data. 50 | 51 | Params 52 | ------ 53 | fname : str 54 | Filename to load data from 55 | 56 | Kwargs 57 | ------ 58 | sep : str 59 | Allowed field separators, defaults to '\s+|\t+|,' 60 | min_confidence : float 61 | Rows with confidence less than this threshold are removed. 62 | Defaults to -1. 
You should set this to 1 when loading
63 |         ground truth MOTChallenge data, so that invalid rectangles in
64 |         the ground truth are not considered during matching.
65 | 
66 |     Returns
67 |     ------
68 |     df : pandas.DataFrame
69 |         The returned dataframe has the following columns
70 |             'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility'
71 |         The dataframe is indexed by ('FrameId', 'Id')
72 |     """
73 | 
74 |     sep = kwargs.pop('sep', r'\s+|\t+|,')
75 |     min_confidence = kwargs.pop('min_confidence', -1)
76 |     df = pd.read_csv(
77 |         fname,
78 |         sep=sep,
79 |         index_col=[0, 1],
80 |         skipinitialspace=True,
81 |         header=None,
82 |         names=['FrameId', 'Id', 'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility', 'unused'],
83 |         engine='python'
84 |     )
85 | 
86 |     # Account for matlab convention.
87 |     df[['X', 'Y']] -= (1, 1)
88 | 
89 |     # Removed trailing column
90 |     del df['unused']
91 | 
92 |     # Remove all rows without sufficient confidence
93 |     return df[df['Confidence'] >= min_confidence]
94 | 
95 | 
96 | def load_vatictxt(fname, **kwargs):
97 |     """Load Vatic text format.
98 | 
99 |     Loads the vatic CSV text having the following columns per row
100 | 
101 |         0 Track ID. All rows with the same ID belong to the same path.
102 |         1 xmin. The top left x-coordinate of the bounding box.
103 |         2 ymin. The top left y-coordinate of the bounding box.
104 |         3 xmax. The bottom right x-coordinate of the bounding box.
105 |         4 ymax. The bottom right y-coordinate of the bounding box.
106 |         5 frame. The frame that this annotation represents.
107 |         6 lost. If 1, the annotation is outside of the view screen.
108 |         7 occluded. If 1, the annotation is occluded.
109 |         8 generated. If 1, the annotation was automatically interpolated.
110 |         9 label. The label for this annotation, enclosed in quotation marks.
111 |         10+ attributes. Each column after this is an attribute set in the current frame
112 | 
113 |     Params
114 |     ------
115 |     fname : str
116 |         Filename to load data from
117 | 
118 |     Returns
119 |     ------
120 |     df : pandas.DataFrame
121 |         The returned dataframe has the following columns
122 |             'X', 'Y', 'Width', 'Height', 'Lost', 'Occluded', 'Generated', 'ClassId', '<Attr1>', '<Attr2>', ...
123 |         where '<AttrN>' is a placeholder for the actual attribute name, capitalized (first letter). The order of attribute
124 |         columns is sorted by attribute name.
The dataframe is indexed by ('FrameId', 'Id') 125 | """ 126 | # pylint: disable=too-many-locals 127 | 128 | sep = kwargs.pop('sep', ' ') 129 | 130 | with io.open(fname) as f: 131 | # First time going over file, we collect the set of all variable activities 132 | activities = set() 133 | for line in f: 134 | for c in line.rstrip().split(sep)[10:]: 135 | activities.add(c) 136 | activitylist = sorted(list(activities)) 137 | 138 | # Second time we construct artificial binary columns for each activity 139 | data = [] 140 | f.seek(0) 141 | for line in f: 142 | fields = line.rstrip().split() 143 | attrs = ['0'] * len(activitylist) 144 | for a in fields[10:]: 145 | attrs[activitylist.index(a)] = '1' 146 | fields = fields[:10] 147 | fields.extend(attrs) 148 | data.append(' '.join(fields)) 149 | 150 | strdata = '\n'.join(data) 151 | 152 | dtype = { 153 | 'Id': np.int64, 154 | 'X': np.float32, 155 | 'Y': np.float32, 156 | 'Width': np.float32, 157 | 'Height': np.float32, 158 | 'FrameId': np.int64, 159 | 'Lost': bool, 160 | 'Occluded': bool, 161 | 'Generated': bool, 162 | 'ClassId': str, 163 | } 164 | 165 | # Remove quotes from activities 166 | activitylist = [a.replace('\"', '').capitalize() for a in activitylist] 167 | 168 | # Add dtypes for activities 169 | for a in activitylist: 170 | dtype[a] = bool 171 | 172 | # Read from CSV 173 | names = ['Id', 'X', 'Y', 'Width', 'Height', 'FrameId', 'Lost', 'Occluded', 'Generated', 'ClassId'] 174 | names.extend(activitylist) 175 | df = pd.read_csv(io.StringIO(strdata), names=names, index_col=['FrameId', 'Id'], header=None, sep=' ') 176 | 177 | # Correct Width and Height which are actually XMax, Ymax in files. 178 | w = df['Width'] - df['X'] 179 | h = df['Height'] - df['Y'] 180 | df['Width'] = w 181 | df['Height'] = h 182 | 183 | return df 184 | 185 | 186 | def load_detrac_mat(fname): 187 | """Loads UA-DETRAC annotations data from mat files 188 | 189 | Competition Site: http://detrac-db.rit.albany.edu/download 190 | 191 | File contains a nested structure of 2d arrays for indexed by frame id 192 | and Object ID. Separate arrays for top, left, width and height are given. 193 | 194 | Params 195 | ------ 196 | fname : str 197 | Filename to load data from 198 | 199 | Kwargs 200 | ------ 201 | Currently none of these arguments used. 
202 | 203 | Returns 204 | ------ 205 | df : pandas.DataFrame 206 | The returned dataframe has the following columns 207 | 'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility' 208 | The dataframe is indexed by ('FrameId', 'Id') 209 | """ 210 | 211 | matData = scipy.io.loadmat(fname) 212 | 213 | frameList = matData['gtInfo'][0][0][4][0] 214 | leftArray = matData['gtInfo'][0][0][0] 215 | topArray = matData['gtInfo'][0][0][1] 216 | widthArray = matData['gtInfo'][0][0][3] 217 | heightArray = matData['gtInfo'][0][0][2] 218 | 219 | parsedGT = [] 220 | for f in frameList: 221 | ids = [i + 1 for i, v in enumerate(leftArray[f - 1]) if v > 0] 222 | for i in ids: 223 | row = [] 224 | row.append(f) 225 | row.append(i) 226 | row.append(leftArray[f - 1, i - 1] - widthArray[f - 1, i - 1] / 2) 227 | row.append(topArray[f - 1, i - 1] - heightArray[f - 1, i - 1]) 228 | row.append(widthArray[f - 1, i - 1]) 229 | row.append(heightArray[f - 1, i - 1]) 230 | row.append(1) 231 | row.append(-1) 232 | row.append(-1) 233 | row.append(-1) 234 | parsedGT.append(row) 235 | 236 | df = pd.DataFrame(parsedGT, 237 | columns=['FrameId', 'Id', 'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility', 'unused']) 238 | df.set_index(['FrameId', 'Id'], inplace=True) 239 | 240 | # Account for matlab convention. 241 | df[['X', 'Y']] -= (1, 1) 242 | 243 | # Removed trailing column 244 | del df['unused'] 245 | 246 | return df 247 | 248 | 249 | def load_detrac_xml(fname): 250 | """Loads UA-DETRAC annotations data from xml files 251 | 252 | Competition Site: http://detrac-db.rit.albany.edu/download 253 | 254 | Params 255 | ------ 256 | fname : str 257 | Filename to load data from 258 | 259 | Kwargs 260 | ------ 261 | Currently none of these arguments used. 262 | 263 | Returns 264 | ------ 265 | df : pandas.DataFrame 266 | The returned dataframe has the following columns 267 | 'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility' 268 | The dataframe is indexed by ('FrameId', 'Id') 269 | """ 270 | 271 | with io.open(fname) as fd: 272 | doc = xmltodict.parse(fd.read()) 273 | frameList = doc['sequence']['frame'] 274 | 275 | parsedGT = [] 276 | for f in frameList: 277 | fid = int(f['@num']) 278 | targetList = f['target_list']['target'] 279 | if not isinstance(targetList, list): 280 | targetList = [targetList] 281 | 282 | for t in targetList: 283 | row = [] 284 | row.append(fid) 285 | row.append(int(t['@id'])) 286 | row.append(float(t['box']['@left'])) 287 | row.append(float(t['box']['@top'])) 288 | row.append(float(t['box']['@width'])) 289 | row.append(float(t['box']['@height'])) 290 | row.append(1) 291 | row.append(-1) 292 | row.append(-1) 293 | row.append(-1) 294 | parsedGT.append(row) 295 | 296 | df = pd.DataFrame(parsedGT, 297 | columns=['FrameId', 'Id', 'X', 'Y', 'Width', 'Height', 'Confidence', 'ClassId', 'Visibility', 'unused']) 298 | df.set_index(['FrameId', 'Id'], inplace=True) 299 | 300 | # Account for matlab convention. 
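    # As in the other loaders above, the raw annotations use 1-based (MATLAB)
    # image coordinates; subtracting 1 converts X/Y to the 0-based convention
    # used throughout this package.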
301 | df[['X', 'Y']] -= (1, 1) 302 | 303 | # Removed trailing column 304 | del df['unused'] 305 | 306 | return df 307 | 308 | 309 | def loadtxt(fname, fmt=Format.MOT15_2D, **kwargs): 310 | """Load data from any known format.""" 311 | fmt = Format(fmt) 312 | 313 | switcher = { 314 | Format.MOT16: load_motchallenge, 315 | Format.MOT15_2D: load_motchallenge, 316 | Format.VATIC_TXT: load_vatictxt, 317 | Format.DETRAC_MAT: load_detrac_mat, 318 | Format.DETRAC_XML: load_detrac_xml 319 | } 320 | func = switcher.get(fmt) 321 | return func(fname, **kwargs) 322 | 323 | 324 | def render_summary(summary, formatters=None, namemap=None, buf=None): 325 | """Render metrics summary to console friendly tabular output. 326 | 327 | Params 328 | ------ 329 | summary : pd.DataFrame 330 | Dataframe containing summaries in rows. 331 | 332 | Kwargs 333 | ------ 334 | buf : StringIO-like, optional 335 | Buffer to write to 336 | formatters : dict, optional 337 | Dicionary defining custom formatters for individual metrics. 338 | I.e `{'mota': '{:.2%}'.format}`. You can get preset formatters 339 | from MetricsHost.formatters 340 | namemap : dict, optional 341 | Dictionary defining new metric names for display. I.e 342 | `{'num_false_positives': 'FP'}`. 343 | 344 | Returns 345 | ------- 346 | string 347 | Formatted string 348 | """ 349 | 350 | if namemap is not None: 351 | summary = summary.rename(columns=namemap) 352 | if formatters is not None: 353 | formatters = {namemap.get(c, c): f for c, f in formatters.items()} 354 | 355 | output = summary.to_string( 356 | buf=buf, 357 | formatters=formatters, 358 | ) 359 | 360 | return output 361 | 362 | 363 | motchallenge_metric_names = { 364 | 'idf1': 'IDF1', 365 | 'idp': 'IDP', 366 | 'idr': 'IDR', 367 | 'recall': 'Rcll', 368 | 'precision': 'Prcn', 369 | 'num_unique_objects': 'GT', 370 | 'mostly_tracked': 'MT', 371 | 'partially_tracked': 'PT', 372 | 'mostly_lost': 'ML', 373 | 'num_false_positives': 'FP', 374 | 'num_misses': 'FN', 375 | 'num_switches': 'IDs', 376 | 'num_fragmentations': 'FM', 377 | 'mota': 'MOTA', 378 | 'motp': 'MOTP', 379 | 'num_transfer': 'IDt', 380 | 'num_ascend': 'IDa', 381 | 'num_migrate': 'IDm', 382 | } 383 | """A list mappings for metric names to comply with MOTChallenge.""" 384 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/lap.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 
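#
# A minimal usage sketch (illustrative only; assumes at least one backend such
# as scipy is installed):
#
#     import numpy as np
#     from motmetrics import lap
#     costs = np.asfarray([[6, 9, 1], [10, 3, 2], [8, 7, 4]])
#     rids, cids = lap.linear_sum_assignment(costs, solver='scipy')
#     # rows [0, 1, 2] are matched to columns [2, 1, 0]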
7 | 8 | """Tools for solving linear assignment problems.""" 9 | 10 | # pylint: disable=import-outside-toplevel 11 | 12 | from __future__ import absolute_import 13 | from __future__ import division 14 | from __future__ import print_function 15 | 16 | from contextlib import contextmanager 17 | import warnings 18 | 19 | import numpy as np 20 | 21 | 22 | def _module_is_available_py2(name): 23 | try: 24 | imp.find_module(name) 25 | return True 26 | except ImportError: 27 | return False 28 | 29 | 30 | def _module_is_available_py3(name): 31 | return importlib.util.find_spec(name) is not None 32 | 33 | 34 | try: 35 | import importlib.util 36 | except ImportError: 37 | import imp 38 | _module_is_available = _module_is_available_py2 39 | else: 40 | _module_is_available = _module_is_available_py3 41 | 42 | 43 | def linear_sum_assignment(costs, solver=None): 44 | """Solve a linear sum assignment problem (LSA). 45 | 46 | For large datasets solving the minimum cost assignment becomes the dominant runtime part. 47 | We therefore support various solvers out of the box (currently lapsolver, scipy, ortools, munkres) 48 | 49 | Params 50 | ------ 51 | costs : np.array 52 | numpy matrix containing costs. Use NaN/Inf values for unassignable 53 | row/column pairs. 54 | 55 | Kwargs 56 | ------ 57 | solver : callable or str, optional 58 | When str: name of solver to use. 59 | When callable: function to invoke 60 | When None: uses first available solver 61 | """ 62 | costs = np.asarray(costs) 63 | if not costs.size: 64 | return np.array([], dtype=int), np.array([], dtype=int) 65 | 66 | solver = solver or default_solver 67 | 68 | if isinstance(solver, str): 69 | # Try resolve from string 70 | solver = solver_map.get(solver, None) 71 | 72 | assert callable(solver), 'Invalid LAP solver.' 73 | rids, cids = solver(costs) 74 | rids = np.asarray(rids).astype(int) 75 | cids = np.asarray(cids).astype(int) 76 | return rids, cids 77 | 78 | 79 | def add_expensive_edges(costs): 80 | """Replaces non-edge costs (nan, inf) with large number. 81 | 82 | If the optimal solution includes one of these edges, 83 | then the original problem was infeasible. 84 | 85 | Parameters 86 | ---------- 87 | costs : np.ndarray 88 | """ 89 | # The graph is probably already dense if we are doing this. 90 | assert isinstance(costs, np.ndarray) 91 | # The linear_sum_assignment function in scipy does not support missing edges. 92 | # Replace nan with a large constant that ensures it is not chosen. 93 | # If it is chosen, that means the problem was infeasible. 94 | valid = np.isfinite(costs) 95 | if valid.all(): 96 | return costs.copy() 97 | if not valid.any(): 98 | return np.zeros_like(costs) 99 | r = min(costs.shape) 100 | # Assume all edges costs are within [-c, c], c >= 0. 101 | # The cost of an invalid edge must be such that... 102 | # choosing this edge once and the best-possible edge (r - 1) times 103 | # is worse than choosing the worst-possible edge r times. 104 | # l + (r - 1) (-c) > r c 105 | # l > r c + (r - 1) c 106 | # l > (2 r - 1) c 107 | # Choose l = 2 r c + 1 > (2 r - 1) c. 108 | c = np.abs(costs[valid]).max() + 1 # Doesn't hurt to add 1 here. 
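    # Worked example: in a 3x3 matrix whose largest absolute finite cost is 4,
    # c = 5 and large_constant = 2*3*5 + 1 = 31, so one invalid edge plus two
    # best-case edges (31 - 5 - 5 = 21) still costs more than three worst-case
    # edges (3*5 = 15), and an invalid edge is never chosen while a feasible
    # assignment exists.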
109 | large_constant = 2 * r * c + 1 110 | return np.where(valid, costs, large_constant) 111 | 112 | 113 | def _exclude_missing_edges(costs, rids, cids): 114 | subset = [ 115 | index for index, (i, j) in enumerate(zip(rids, cids)) 116 | if np.isfinite(costs[i, j]) 117 | ] 118 | return rids[subset], cids[subset] 119 | 120 | 121 | def lsa_solve_scipy(costs): 122 | """Solves the LSA problem using the scipy library.""" 123 | 124 | from scipy.optimize import linear_sum_assignment as scipy_solve 125 | 126 | # scipy (1.3.3) does not support nan or inf values 127 | finite_costs = add_expensive_edges(costs) 128 | rids, cids = scipy_solve(finite_costs) 129 | rids, cids = _exclude_missing_edges(costs, rids, cids) 130 | return rids, cids 131 | 132 | 133 | def lsa_solve_lapsolver(costs): 134 | """Solves the LSA problem using the lapsolver library.""" 135 | from lapsolver import solve_dense 136 | 137 | # Note that lapsolver will add expensive finite edges internally. 138 | # However, older versions did not add a large enough edge. 139 | finite_costs = add_expensive_edges(costs) 140 | rids, cids = solve_dense(finite_costs) 141 | rids, cids = _exclude_missing_edges(costs, rids, cids) 142 | return rids, cids 143 | 144 | 145 | def lsa_solve_munkres(costs): 146 | """Solves the LSA problem using the Munkres library.""" 147 | from munkres import Munkres 148 | 149 | m = Munkres() 150 | # The munkres package may hang if the problem is not feasible. 151 | # Therefore, add expensive edges instead of using munkres.DISALLOWED. 152 | finite_costs = add_expensive_edges(costs) 153 | # Ensure that matrix is square. 154 | finite_costs = _zero_pad_to_square(finite_costs) 155 | indices = np.array(m.compute(finite_costs), dtype=int) 156 | # Exclude extra matches from extension to square matrix. 157 | indices = indices[(indices[:, 0] < costs.shape[0]) 158 | & (indices[:, 1] < costs.shape[1])] 159 | rids, cids = indices[:, 0], indices[:, 1] 160 | rids, cids = _exclude_missing_edges(costs, rids, cids) 161 | return rids, cids 162 | 163 | 164 | def _zero_pad_to_square(costs): 165 | num_rows, num_cols = costs.shape 166 | if num_rows == num_cols: 167 | return costs 168 | n = max(num_rows, num_cols) 169 | padded = np.zeros((n, n), dtype=costs.dtype) 170 | padded[:num_rows, :num_cols] = costs 171 | return padded 172 | 173 | 174 | def lsa_solve_ortools(costs): 175 | """Solves the LSA problem using Google's optimization tools. """ 176 | from ortools.graph import pywrapgraph 177 | 178 | if costs.shape[0] != costs.shape[1]: 179 | # ortools assumes that the problem is square. 180 | # Non-square problem will be infeasible. 181 | # Default to scipy solver rather than add extra zeros. 182 | # (This maintains the same behaviour as previous versions.) 183 | return linear_sum_assignment(costs, solver='scipy') 184 | 185 | rs, cs = np.isfinite(costs).nonzero() # pylint: disable=unbalanced-tuple-unpacking 186 | finite_costs = costs[rs, cs] 187 | scale = find_scale_for_integer_approximation(finite_costs) 188 | if scale != 1: 189 | warnings.warn('costs are not integers; using approximation') 190 | int_costs = np.round(scale * finite_costs).astype(int) 191 | 192 | assignment = pywrapgraph.LinearSumAssignment() 193 | # OR-Tools does not like to receive indices of type np.int64. 
194 |     rs = rs.tolist()  # pylint: disable=no-member
195 |     cs = cs.tolist()
196 |     int_costs = int_costs.tolist()
197 |     for r, c, int_cost in zip(rs, cs, int_costs):
198 |         assignment.AddArcWithCost(r, c, int_cost)
199 | 
200 |     status = assignment.Solve()
201 |     try:
202 |         _ortools_assert_is_optimal(pywrapgraph, status)
203 |     except AssertionError:
204 |         # Default to scipy solver rather than add finite edges.
205 |         # (This maintains the same behaviour as previous versions.)
206 |         return linear_sum_assignment(costs, solver='scipy')
207 | 
208 |     return _ortools_extract_solution(assignment)
209 | 
210 | 
211 | def find_scale_for_integer_approximation(costs, base=10, log_max_scale=8, log_safety=2):
212 |     """Returns a multiplicative factor to use before rounding to integers.
213 | 
214 |     Tries to find scale = base ** j (for j integer) such that:
215 |         abs(diff(unique(costs))) <= 1 / (scale * safety)
216 |     where safety = base ** log_safety.
217 | 
218 |     Logs a warning if the desired resolution could not be achieved.
219 |     """
220 |     costs = np.asarray(costs)
221 |     costs = costs[np.isfinite(costs)]  # Exclude non-edges (nan, inf) and -inf.
222 |     if np.size(costs) == 0:
223 |         # No edges with numeric value. Scale does not matter.
224 |         return 1
225 |     unique = np.unique(costs)
226 |     if np.size(unique) == 1:
227 |         # All costs have equal values. Scale does not matter.
228 |         return 1
229 |     try:
230 |         _assert_integer(costs)
231 |     except AssertionError:
232 |         pass
233 |     else:
234 |         # The costs are already integers.
235 |         return 1
236 | 
237 |     # Find scale = base ** e such that:
238 |     # 1 / scale <= tol, or
239 |     # e = log(scale) >= -log(tol)
240 |     # where tol = min(diff(unique(costs)))
241 |     min_diff = np.diff(unique).min()
242 |     e = np.ceil(-np.log(min_diff) / np.log(base)).astype(int).item()
243 |     # Add optional non-negative safety factor to reduce quantization noise.
244 |     e += max(log_safety, 0)
245 |     # Ensure that we do not reduce the magnitude of the costs.
246 |     e = max(e, 0)
247 |     # Ensure that the scale is not too large.
248 |     if e > log_max_scale:
249 |         warnings.warn('could not achieve desired resolution for approximation: '
250 |                       'want exponent %d but max is %d' % (e, log_max_scale))
251 |         e = log_max_scale
252 |     scale = base ** e
253 |     return scale
254 | 
255 | 
256 | def _assert_integer(costs):
257 |     # Check that costs are not changed by rounding.
258 |     # Note: Elements of cost matrix may be nan, inf, -inf.
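    # np.round leaves nan and +/-inf unchanged, so this only raises when some
    # finite cost has a fractional part.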
259 | np.testing.assert_equal(np.round(costs), costs) 260 | 261 | 262 | def _ortools_assert_is_optimal(pywrapgraph, status): 263 | if status == pywrapgraph.LinearSumAssignment.OPTIMAL: 264 | pass 265 | elif status == pywrapgraph.LinearSumAssignment.INFEASIBLE: 266 | raise AssertionError('ortools: infeasible assignment problem') 267 | elif status == pywrapgraph.LinearSumAssignment.POSSIBLE_OVERFLOW: 268 | raise AssertionError('ortools: possible overflow in assignment problem') 269 | else: 270 | raise AssertionError('ortools: unknown status') 271 | 272 | 273 | def _ortools_extract_solution(assignment): 274 | if assignment.NumNodes() == 0: 275 | return np.array([], dtype=int), np.array([], dtype=int) 276 | 277 | pairings = [] 278 | for i in range(assignment.NumNodes()): 279 | pairings.append([i, assignment.RightMate(i)]) 280 | 281 | indices = np.array(pairings, dtype=int) 282 | return indices[:, 0], indices[:, 1] 283 | 284 | 285 | def lsa_solve_lapjv(costs): 286 | """Solves the LSA problem using lap.lapjv().""" 287 | 288 | from lap import lapjv 289 | 290 | # The lap.lapjv function supports +inf edges but there are some issues. 291 | # https://github.com/gatagat/lap/issues/20 292 | # Therefore, replace nans with large finite cost. 293 | finite_costs = add_expensive_edges(costs) 294 | row_to_col, _ = lapjv(finite_costs, return_cost=False, extend_cost=True) 295 | indices = np.array([np.arange(costs.shape[0]), row_to_col], dtype=int).T 296 | # Exclude unmatched rows (in case of unbalanced problem). 297 | indices = indices[indices[:, 1] != -1] # pylint: disable=unsubscriptable-object 298 | rids, cids = indices[:, 0], indices[:, 1] 299 | # Ensure that no missing edges were chosen. 300 | rids, cids = _exclude_missing_edges(costs, rids, cids) 301 | return rids, cids 302 | 303 | 304 | available_solvers = None 305 | default_solver = None 306 | solver_map = None 307 | 308 | 309 | def _init_standard_solvers(): 310 | global available_solvers, default_solver, solver_map # pylint: disable=global-statement 311 | 312 | solvers = [ 313 | ('lapsolver', lsa_solve_lapsolver), 314 | ('lap', lsa_solve_lapjv), 315 | ('scipy', lsa_solve_scipy), 316 | ('munkres', lsa_solve_munkres), 317 | ('ortools', lsa_solve_ortools), 318 | ] 319 | 320 | solver_map = dict(solvers) 321 | 322 | available_solvers = [s[0] for s in solvers if _module_is_available(s[0])] 323 | if len(available_solvers) == 0: 324 | default_solver = None 325 | warnings.warn('No standard LAP solvers found. Consider `pip install lapsolver` or `pip install scipy`', category=RuntimeWarning) 326 | else: 327 | default_solver = available_solvers[0] 328 | 329 | 330 | _init_standard_solvers() 331 | 332 | 333 | @contextmanager 334 | def set_default_solver(newsolver): 335 | """Change the default solver within context. 336 | 337 | Intended usage 338 | 339 | costs = ... 340 | mysolver = lambda x: ... 
# solver code that returns pairings 341 | 342 | with lap.set_default_solver(mysolver): 343 | rids, cids = lap.linear_sum_assignment(costs) 344 | 345 | Params 346 | ------ 347 | newsolver : callable or str 348 | new solver function 349 | """ 350 | 351 | global default_solver # pylint: disable=global-statement 352 | 353 | oldsolver = default_solver 354 | try: 355 | default_solver = newsolver 356 | yield 357 | finally: 358 | default_solver = oldsolver 359 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/math_util.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 7 | 8 | """Math utility functions.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import warnings 15 | 16 | import numpy as np 17 | 18 | 19 | def quiet_divide(a, b): 20 | """Quiet divide function that does not warn about (0 / 0).""" 21 | with warnings.catch_warnings(): 22 | warnings.simplefilter('ignore', RuntimeWarning) 23 | return np.true_divide(a, b) 24 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/preprocess.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 7 | 8 | """Preprocess data for CLEAR_MOT_M.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | from configparser import ConfigParser 15 | import logging 16 | import time 17 | 18 | import numpy as np 19 | 20 | from . import distances as mmd 21 | from .lap import linear_sum_assignment 22 | 23 | 24 | def preprocessResult(res, gt, inifile): 25 | """Preprocesses data for utils.CLEAR_MOT_M. 26 | 27 | Returns a subset of the predictions. 28 | """ 29 | # pylint: disable=too-many-locals 30 | st = time.time() 31 | labels = [ 32 | 'ped', # 1 33 | 'person_on_vhcl', # 2 34 | 'car', # 3 35 | 'bicycle', # 4 36 | 'mbike', # 5 37 | 'non_mot_vhcl', # 6 38 | 'static_person', # 7 39 | 'distractor', # 8 40 | 'occluder', # 9 41 | 'occluder_on_grnd', # 10 42 | 'occluder_full', # 11 43 | 'reflection', # 12 44 | 'crowd', # 13 45 | ] 46 | distractors = ['person_on_vhcl', 'static_person', 'distractor', 'reflection'] 47 | is_distractor = {i + 1: x in distractors for i, x in enumerate(labels)} 48 | for i in distractors: 49 | is_distractor[i] = 1 50 | seqIni = ConfigParser() 51 | seqIni.read(inifile, encoding='utf8') 52 | F = int(seqIni['Sequence']['seqLength']) 53 | todrop = [] 54 | for t in range(1, F + 1): 55 | if t not in res.index or t not in gt.index: 56 | continue 57 | resInFrame = res.loc[t] 58 | 59 | GTInFrame = gt.loc[t] 60 | A = GTInFrame[['X', 'Y', 'Width', 'Height']].values 61 | B = resInFrame[['X', 'Y', 'Width', 'Height']].values 62 | disM = mmd.iou_matrix(A, B, max_iou=0.5) 63 | le, ri = linear_sum_assignment(disM) 64 | flags = [ 65 | 1 if is_distractor[it['ClassId']] or it['Visibility'] < 0. 
else 0 66 | for i, (k, it) in enumerate(GTInFrame.iterrows()) 67 | ] 68 | hid = [k for k, it in resInFrame.iterrows()] 69 | for i, j in zip(le, ri): 70 | if not np.isfinite(disM[i, j]): 71 | continue 72 | if flags[i]: 73 | todrop.append((t, hid[j])) 74 | ret = res.drop(labels=todrop) 75 | logging.info('Preprocess take %.3f seconds and remove %d boxes.', 76 | time.time() - st, len(todrop)) 77 | return ret 78 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/posetrack21_eval/motmetrics/tests/__init__.py -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/tests/test_distances.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 7 | 8 | """Tests distance computation.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import numpy as np 15 | 16 | import motmetrics as mm 17 | 18 | 19 | def test_norm2squared(): 20 | """Tests norm2squared_matrix.""" 21 | a = np.asfarray([ 22 | [1, 2], 23 | [2, 2], 24 | [3, 2], 25 | ]) 26 | 27 | b = np.asfarray([ 28 | [0, 0], 29 | [1, 1], 30 | ]) 31 | 32 | C = mm.distances.norm2squared_matrix(a, b) 33 | np.testing.assert_allclose( 34 | C, 35 | [ 36 | [5, 1], 37 | [8, 2], 38 | [13, 5] 39 | ] 40 | ) 41 | 42 | C = mm.distances.norm2squared_matrix(a, b, max_d2=5) 43 | np.testing.assert_allclose( 44 | C, 45 | [ 46 | [5, 1], 47 | [np.nan, 2], 48 | [np.nan, 5] 49 | ] 50 | ) 51 | 52 | 53 | def test_norm2squared_empty(): 54 | """Tests norm2squared_matrix with an empty input.""" 55 | a = [] 56 | b = np.asfarray([[0, 0], [1, 1]]) 57 | C = mm.distances.norm2squared_matrix(a, b) 58 | assert C.size == 0 59 | C = mm.distances.norm2squared_matrix(b, a) 60 | assert C.size == 0 61 | 62 | 63 | def test_iou_matrix(): 64 | """Tests iou_matrix.""" 65 | a = np.array([ 66 | [0, 0, 1, 2], 67 | ]) 68 | 69 | b = np.array([ 70 | [0, 0, 1, 2], 71 | [0, 0, 1, 1], 72 | [1, 1, 1, 1], 73 | [0.5, 0, 1, 1], 74 | [0, 1, 1, 1], 75 | ]) 76 | np.testing.assert_allclose( 77 | mm.distances.iou_matrix(a, b), 78 | [[0, 0.5, 1, 0.8, 0.5]], 79 | atol=1e-4 80 | ) 81 | 82 | np.testing.assert_allclose( 83 | mm.distances.iou_matrix(a, b, max_iou=0.5), 84 | [[0, 0.5, np.nan, np.nan, 0.5]], 85 | atol=1e-4 86 | ) 87 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/tests/test_io.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 
7 | 8 | """Tests IO functions.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import os 15 | 16 | import pandas as pd 17 | 18 | import motmetrics as mm 19 | 20 | DATA_DIR = os.path.join(os.path.dirname(__file__), '../data') 21 | 22 | 23 | def test_load_vatic(): 24 | """Tests VATIC_TXT format.""" 25 | df = mm.io.loadtxt(os.path.join(DATA_DIR, 'iotest/vatic.txt'), fmt=mm.io.Format.VATIC_TXT) 26 | 27 | expected = pd.DataFrame([ 28 | # F,ID,Y,W,H,L,O,G,F,A1,A2,A3,A4 29 | (0, 0, 412, 0, 430, 124, 0, 0, 0, 'worker', 0, 0, 0, 0), 30 | (1, 0, 412, 10, 430, 114, 0, 0, 1, 'pc', 1, 0, 1, 0), 31 | (1, 1, 412, 0, 430, 124, 0, 0, 1, 'pc', 0, 1, 0, 0), 32 | (2, 2, 412, 0, 430, 124, 0, 0, 1, 'worker', 1, 1, 0, 1) 33 | ]) 34 | 35 | assert (df.reset_index().values == expected.values).all() 36 | 37 | 38 | def test_load_motchallenge(): 39 | """Tests MOT15_2D format.""" 40 | df = mm.io.loadtxt(os.path.join(DATA_DIR, 'iotest/motchallenge.txt'), fmt=mm.io.Format.MOT15_2D) 41 | 42 | expected = pd.DataFrame([ 43 | (1, 1, 398, 181, 121, 229, 1, -1, -1), # Note -1 on x and y for correcting matlab 44 | (1, 2, 281, 200, 92, 184, 1, -1, -1), 45 | (2, 2, 268, 201, 87, 182, 1, -1, -1), 46 | (2, 3, 70, 150, 100, 284, 1, -1, -1), 47 | (2, 4, 199, 205, 55, 137, 1, -1, -1), 48 | ]) 49 | 50 | assert (df.reset_index().values == expected.values).all() 51 | 52 | 53 | def test_load_detrac_mat(): 54 | """Tests DETRAC_MAT format.""" 55 | df = mm.io.loadtxt(os.path.join(DATA_DIR, 'iotest/detrac.mat'), fmt=mm.io.Format.DETRAC_MAT) 56 | 57 | expected = pd.DataFrame([ 58 | (1., 1., 745., 356., 148., 115., 1., -1., -1.), 59 | (2., 1., 738., 350., 145., 111., 1., -1., -1.), 60 | (3., 1., 732., 343., 142., 107., 1., -1., -1.), 61 | (4., 1., 725., 336., 139., 104., 1., -1., -1.) 62 | ]) 63 | 64 | assert (df.reset_index().values == expected.values).all() 65 | 66 | 67 | def test_load_detrac_xml(): 68 | """Tests DETRAC_XML format.""" 69 | df = mm.io.loadtxt(os.path.join(DATA_DIR, 'iotest/detrac.xml'), fmt=mm.io.Format.DETRAC_XML) 70 | 71 | expected = pd.DataFrame([ 72 | (1., 1., 744.6, 356.33, 148.2, 115.14, 1., -1., -1.), 73 | (2., 1., 738.2, 349.51, 145.21, 111.29, 1., -1., -1.), 74 | (3., 1., 731.8, 342.68, 142.23, 107.45, 1., -1., -1.), 75 | (4., 1., 725.4, 335.85, 139.24, 103.62, 1., -1., -1.) 76 | ]) 77 | 78 | assert (df.reset_index().values == expected.values).all() 79 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/tests/test_issue19.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 7 | 8 | """Tests issue 19. 
9 | 10 | https://github.com/cheind/py-motmetrics/issues/19 11 | """ 12 | 13 | from __future__ import absolute_import 14 | from __future__ import division 15 | from __future__ import print_function 16 | 17 | import numpy as np 18 | 19 | import motmetrics as mm 20 | 21 | 22 | def test_issue19(): 23 | """Tests issue 19.""" 24 | acc = mm.MOTAccumulator() 25 | 26 | g0 = [0, 1] 27 | p0 = [0, 1] 28 | d0 = [[0.2, np.nan], [np.nan, 0.2]] 29 | 30 | g1 = [2, 3] 31 | p1 = [2, 3, 4, 5] 32 | d1 = [[0.28571429, 0.5, 0.0, np.nan], [np.nan, 0.44444444, np.nan, 0.0]] 33 | 34 | acc.update(g0, p0, d0, 0) 35 | acc.update(g1, p1, d1, 1) 36 | 37 | mh = mm.metrics.create() 38 | mh.compute(acc) 39 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/tests/test_lap.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 7 | 8 | """Tests linear assignment problem solvers.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import warnings 15 | 16 | import numpy as np 17 | import pytest 18 | 19 | from motmetrics import lap 20 | 21 | DESIRED_SOLVERS = ['lap', 'lapsolver', 'munkres', 'ortools', 'scipy'] 22 | SOLVERS = lap.available_solvers 23 | 24 | 25 | @pytest.mark.parametrize('solver', DESIRED_SOLVERS) 26 | def test_solver_is_available(solver): 27 | if solver not in lap.available_solvers: 28 | warnings.warn('solver not available: ' + solver) 29 | 30 | 31 | @pytest.mark.parametrize('solver', SOLVERS) 32 | def test_assign_easy(solver): 33 | """Problem that could be solved by a greedy algorithm.""" 34 | costs = np.asfarray([[6, 9, 1], [10, 3, 2], [8, 7, 4]]) 35 | costs_copy = costs.copy() 36 | result = lap.linear_sum_assignment(costs, solver=solver) 37 | 38 | expected = np.array([[0, 1, 2], [2, 1, 0]]) 39 | np.testing.assert_equal(result, expected) 40 | np.testing.assert_equal(costs, costs_copy) 41 | 42 | 43 | @pytest.mark.parametrize('solver', SOLVERS) 44 | def test_assign_full(solver): 45 | """Problem that would be incorrect using a greedy algorithm.""" 46 | costs = np.asfarray([[5, 5, 6], [1, 2, 5], [2, 4, 5]]) 47 | costs_copy = costs.copy() 48 | result = lap.linear_sum_assignment(costs, solver=solver) 49 | 50 | # Optimal matching is (0, 2), (1, 1), (2, 0) for 6 + 2 + 2. 51 | expected = np.asfarray([[0, 1, 2], [2, 1, 0]]) 52 | np.testing.assert_equal(result, expected) 53 | np.testing.assert_equal(costs, costs_copy) 54 | 55 | 56 | @pytest.mark.parametrize('solver', SOLVERS) 57 | def test_assign_full_negative(solver): 58 | costs = -7 + np.asfarray([[5, 5, 6], [1, 2, 5], [2, 4, 5]]) 59 | costs_copy = costs.copy() 60 | result = lap.linear_sum_assignment(costs, solver=solver) 61 | 62 | # Optimal matching is (0, 2), (1, 1), (2, 0) for 5 + 1 + 1. 
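    # (With the -7 offset those entries are -1, -5 and -5; adding a constant to
    # every cost does not change which assignment is optimal.)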
63 | expected = np.array([[0, 1, 2], [2, 1, 0]]) 64 | np.testing.assert_equal(result, expected) 65 | np.testing.assert_equal(costs, costs_copy) 66 | 67 | 68 | @pytest.mark.parametrize('solver', SOLVERS) 69 | def test_assign_empty(solver): 70 | costs = np.asfarray([[]]) 71 | costs_copy = costs.copy() 72 | result = lap.linear_sum_assignment(costs, solver=solver) 73 | 74 | np.testing.assert_equal(np.size(result), 0) 75 | np.testing.assert_equal(costs, costs_copy) 76 | 77 | 78 | @pytest.mark.parametrize('solver', SOLVERS) 79 | def test_assign_infeasible(solver): 80 | """Tests that minimum-cost solution with most edges is found.""" 81 | costs = np.asfarray([[np.nan, np.nan, 2], 82 | [np.nan, np.nan, 1], 83 | [8, 7, 4]]) 84 | costs_copy = costs.copy() 85 | result = lap.linear_sum_assignment(costs, solver=solver) 86 | 87 | # Optimal matching is (1, 2), (2, 1). 88 | expected = np.array([[1, 2], [2, 1]]) 89 | np.testing.assert_equal(result, expected) 90 | np.testing.assert_equal(costs, costs_copy) 91 | 92 | 93 | @pytest.mark.parametrize('solver', SOLVERS) 94 | def test_assign_disallowed(solver): 95 | costs = np.asfarray([[5, 9, np.nan], [10, np.nan, 2], [8, 7, 4]]) 96 | costs_copy = costs.copy() 97 | result = lap.linear_sum_assignment(costs, solver=solver) 98 | 99 | expected = np.array([[0, 1, 2], [0, 2, 1]]) 100 | np.testing.assert_equal(result, expected) 101 | np.testing.assert_equal(costs, costs_copy) 102 | 103 | 104 | @pytest.mark.parametrize('solver', SOLVERS) 105 | def test_assign_non_integer(solver): 106 | costs = (1. / 9) * np.asfarray([[5, 9, np.nan], [10, np.nan, 2], [8, 7, 4]]) 107 | costs_copy = costs.copy() 108 | result = lap.linear_sum_assignment(costs, solver=solver) 109 | 110 | expected = np.array([[0, 1, 2], [0, 2, 1]]) 111 | np.testing.assert_equal(result, expected) 112 | np.testing.assert_equal(costs, costs_copy) 113 | 114 | 115 | @pytest.mark.parametrize('solver', SOLVERS) 116 | def test_assign_attractive_disallowed(solver): 117 | """Graph contains an attractive edge that cannot be used.""" 118 | costs = np.asfarray([[-10000, -1], [-1, np.nan]]) 119 | costs_copy = costs.copy() 120 | result = lap.linear_sum_assignment(costs, solver=solver) 121 | 122 | # The optimal solution is (0, 1), (1, 0) for a cost of -2. 123 | # Ensure that the algorithm does not choose the (0, 0) edge. 124 | # This would not be a perfect matching. 125 | expected = np.array([[0, 1], [1, 0]]) 126 | np.testing.assert_equal(result, expected) 127 | np.testing.assert_equal(costs, costs_copy) 128 | 129 | 130 | @pytest.mark.parametrize('solver', SOLVERS) 131 | def test_assign_attractive_broken_ring(solver): 132 | """Graph contains cheap broken ring and expensive unbroken ring.""" 133 | costs = np.asfarray([[np.nan, 1000, np.nan], [np.nan, 1, 1000], [1000, np.nan, 1]]) 134 | costs_copy = costs.copy() 135 | result = lap.linear_sum_assignment(costs, solver=solver) 136 | 137 | # Optimal solution is (0, 1), (1, 2), (2, 0) with cost 1000 + 1000 + 1000. 138 | # Solver might choose (0, 0), (1, 1), (2, 2) with cost inf + 1 + 1. 
139 | expected = np.array([[0, 1, 2], [1, 2, 0]]) 140 | np.testing.assert_equal(result, expected) 141 | np.testing.assert_equal(costs, costs_copy) 142 | 143 | 144 | @pytest.mark.parametrize('solver', SOLVERS) 145 | def test_unbalanced_wide(solver): 146 | costs = np.asfarray([[6, 4, 1], [10, 8, 2]]) 147 | costs_copy = costs.copy() 148 | result = lap.linear_sum_assignment(costs, solver=solver) 149 | 150 | expected = np.array([[0, 1], [1, 2]]) 151 | np.testing.assert_equal(result, expected) 152 | np.testing.assert_equal(costs, costs_copy) 153 | 154 | 155 | @pytest.mark.parametrize('solver', SOLVERS) 156 | def test_unbalanced_tall(solver): 157 | costs = np.asfarray([[6, 10], [4, 8], [1, 2]]) 158 | costs_copy = costs.copy() 159 | result = lap.linear_sum_assignment(costs, solver=solver) 160 | 161 | expected = np.array([[1, 2], [0, 1]]) 162 | np.testing.assert_equal(result, expected) 163 | np.testing.assert_equal(costs, costs_copy) 164 | 165 | 166 | @pytest.mark.parametrize('solver', SOLVERS) 167 | def test_unbalanced_disallowed_wide(solver): 168 | costs = np.asfarray([[np.nan, 11, 8], [8, np.nan, 7]]) 169 | costs_copy = costs.copy() 170 | result = lap.linear_sum_assignment(costs, solver=solver) 171 | 172 | expected = np.array([[0, 1], [2, 0]]) 173 | np.testing.assert_equal(result, expected) 174 | np.testing.assert_equal(costs, costs_copy) 175 | 176 | 177 | @pytest.mark.parametrize('solver', SOLVERS) 178 | def test_unbalanced_disallowed_tall(solver): 179 | costs = np.asfarray([[np.nan, 9], [11, np.nan], [8, 7]]) 180 | costs_copy = costs.copy() 181 | result = lap.linear_sum_assignment(costs, solver=solver) 182 | 183 | expected = np.array([[0, 2], [1, 0]]) 184 | np.testing.assert_equal(result, expected) 185 | np.testing.assert_equal(costs, costs_copy) 186 | 187 | 188 | @pytest.mark.parametrize('solver', SOLVERS) 189 | def test_unbalanced_infeasible(solver): 190 | """Tests that minimum-cost solution with most edges is found.""" 191 | costs = np.asfarray([[np.nan, np.nan, 2], 192 | [np.nan, np.nan, 1], 193 | [np.nan, np.nan, 3], 194 | [8, 7, 4]]) 195 | costs_copy = costs.copy() 196 | result = lap.linear_sum_assignment(costs, solver=solver) 197 | 198 | # Optimal matching is (1, 2), (3, 1). 199 | expected = np.array([[1, 3], [2, 1]]) 200 | np.testing.assert_equal(result, expected) 201 | np.testing.assert_equal(costs, costs_copy) 202 | 203 | 204 | def test_change_solver(): 205 | """Tests effect of lap.set_default_solver.""" 206 | 207 | def mysolver(_): 208 | mysolver.called += 1 209 | return np.array([]), np.array([]) 210 | mysolver.called = 0 211 | 212 | costs = np.asfarray([[6, 9, 1], [10, 3, 2], [8, 7, 4]]) 213 | 214 | with lap.set_default_solver(mysolver): 215 | lap.linear_sum_assignment(costs) 216 | assert mysolver.called == 1 217 | lap.linear_sum_assignment(costs) 218 | assert mysolver.called == 1 219 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/tests/test_metrics.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 
7 | 8 | """Tests computation of metrics from accumulator.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import os 15 | 16 | import numpy as np 17 | import pandas as pd 18 | from pytest import approx 19 | 20 | import motmetrics as mm 21 | 22 | DATA_DIR = os.path.join(os.path.dirname(__file__), '../data') 23 | 24 | 25 | def test_metricscontainer_1(): 26 | """Tests registration of events with dependencies.""" 27 | m = mm.metrics.MetricsHost() 28 | m.register(lambda df: 1., name='a') 29 | m.register(lambda df: 2., name='b') 30 | m.register(lambda df, a, b: a + b, deps=['a', 'b'], name='add') 31 | m.register(lambda df, a, b: a - b, deps=['a', 'b'], name='sub') 32 | m.register(lambda df, a, b: a * b, deps=['add', 'sub'], name='mul') 33 | summary = m.compute(mm.MOTAccumulator.new_event_dataframe(), metrics=['mul', 'add'], name='x') 34 | assert summary.columns.values.tolist() == ['mul', 'add'] 35 | assert summary.iloc[0]['mul'] == -3. 36 | assert summary.iloc[0]['add'] == 3. 37 | 38 | 39 | def test_metricscontainer_autodep(): 40 | """Tests automatic dependencies from argument names.""" 41 | m = mm.metrics.MetricsHost() 42 | m.register(lambda df: 1., name='a') 43 | m.register(lambda df: 2., name='b') 44 | m.register(lambda df, a, b: a + b, name='add', deps='auto') 45 | m.register(lambda df, a, b: a - b, name='sub', deps='auto') 46 | m.register(lambda df, add, sub: add * sub, name='mul', deps='auto') 47 | summary = m.compute(mm.MOTAccumulator.new_event_dataframe(), metrics=['mul', 'add']) 48 | assert summary.columns.values.tolist() == ['mul', 'add'] 49 | assert summary.iloc[0]['mul'] == -3. 50 | assert summary.iloc[0]['add'] == 3. 51 | 52 | 53 | def test_metricscontainer_autoname(): 54 | """Tests automatic names (and dependencies) from inspection.""" 55 | 56 | def constant_a(_): 57 | """Constant a help.""" 58 | return 1. 59 | 60 | def constant_b(_): 61 | return 2. 62 | 63 | def add(_, constant_a, constant_b): 64 | return constant_a + constant_b 65 | 66 | def sub(_, constant_a, constant_b): 67 | return constant_a - constant_b 68 | 69 | def mul(_, add, sub): 70 | return add * sub 71 | 72 | m = mm.metrics.MetricsHost() 73 | m.register(constant_a, deps='auto') 74 | m.register(constant_b, deps='auto') 75 | m.register(add, deps='auto') 76 | m.register(sub, deps='auto') 77 | m.register(mul, deps='auto') 78 | 79 | assert m.metrics['constant_a']['help'] == 'Constant a help.' 80 | 81 | summary = m.compute(mm.MOTAccumulator.new_event_dataframe(), metrics=['mul', 'add']) 82 | assert summary.columns.values.tolist() == ['mul', 'add'] 83 | assert summary.iloc[0]['mul'] == -3. 84 | assert summary.iloc[0]['add'] == 3. 85 | 86 | 87 | def test_metrics_with_no_events(): 88 | """Tests metrics when accumulator is empty.""" 89 | acc = mm.MOTAccumulator() 90 | 91 | mh = mm.metrics.create() 92 | metr = mh.compute(acc, return_dataframe=False, return_cached=True, metrics=[ 93 | 'mota', 'motp', 'num_predictions', 'num_objects', 'num_detections', 'num_frames', 94 | ]) 95 | assert np.isnan(metr['mota']) 96 | assert np.isnan(metr['motp']) 97 | assert metr['num_predictions'] == 0 98 | assert metr['num_objects'] == 0 99 | assert metr['num_detections'] == 0 100 | assert metr['num_frames'] == 0 101 | 102 | 103 | def test_assignment_metrics_with_empty_groundtruth(): 104 | """Tests metrics when there are no ground-truth objects.""" 105 | acc = mm.MOTAccumulator(auto_id=True) 106 | # Empty groundtruth. 
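    # With no ground truth, every hypothesis counts as a false positive:
    # 4 hypotheses x 4 frames = 16, matching the assertions below.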
107 | acc.update([], [1, 2, 3, 4], []) 108 | acc.update([], [1, 2, 3, 4], []) 109 | acc.update([], [1, 2, 3, 4], []) 110 | acc.update([], [1, 2, 3, 4], []) 111 | 112 | mh = mm.metrics.create() 113 | metr = mh.compute(acc, return_dataframe=False, metrics=[ 114 | 'num_matches', 'num_false_positives', 'num_misses', 115 | 'idtp', 'idfp', 'idfn', 'num_frames', 116 | ]) 117 | assert metr['num_matches'] == 0 118 | assert metr['num_false_positives'] == 16 119 | assert metr['num_misses'] == 0 120 | assert metr['idtp'] == 0 121 | assert metr['idfp'] == 16 122 | assert metr['idfn'] == 0 123 | assert metr['num_frames'] == 4 124 | 125 | 126 | def test_assignment_metrics_with_empty_predictions(): 127 | """Tests metrics when there are no predictions.""" 128 | acc = mm.MOTAccumulator(auto_id=True) 129 | # Empty predictions. 130 | acc.update([1, 2, 3, 4], [], []) 131 | acc.update([1, 2, 3, 4], [], []) 132 | acc.update([1, 2, 3, 4], [], []) 133 | acc.update([1, 2, 3, 4], [], []) 134 | 135 | mh = mm.metrics.create() 136 | metr = mh.compute(acc, return_dataframe=False, metrics=[ 137 | 'num_matches', 'num_false_positives', 'num_misses', 138 | 'idtp', 'idfp', 'idfn', 'num_frames', 139 | ]) 140 | assert metr['num_matches'] == 0 141 | assert metr['num_false_positives'] == 0 142 | assert metr['num_misses'] == 16 143 | assert metr['idtp'] == 0 144 | assert metr['idfp'] == 0 145 | assert metr['idfn'] == 16 146 | assert metr['num_frames'] == 4 147 | 148 | 149 | def test_assignment_metrics_with_both_empty(): 150 | """Tests metrics when there are no ground-truth objects or predictions.""" 151 | acc = mm.MOTAccumulator(auto_id=True) 152 | # Empty groundtruth and empty predictions. 153 | acc.update([], [], []) 154 | acc.update([], [], []) 155 | acc.update([], [], []) 156 | acc.update([], [], []) 157 | 158 | mh = mm.metrics.create() 159 | metr = mh.compute(acc, return_dataframe=False, metrics=[ 160 | 'num_matches', 'num_false_positives', 'num_misses', 161 | 'idtp', 'idfp', 'idfn', 'num_frames', 162 | ]) 163 | assert metr['num_matches'] == 0 164 | assert metr['num_false_positives'] == 0 165 | assert metr['num_misses'] == 0 166 | assert metr['idtp'] == 0 167 | assert metr['idfp'] == 0 168 | assert metr['idfn'] == 0 169 | assert metr['num_frames'] == 4 170 | 171 | 172 | def _extract_counts(acc): 173 | df_map = mm.metrics.events_to_df_map(acc.events) 174 | return mm.metrics.extract_counts_from_df_map(df_map) 175 | 176 | 177 | def test_extract_counts(): 178 | """Tests events_to_df_map() and extract_counts_from_df_map().""" 179 | acc = mm.MOTAccumulator() 180 | # All FP 181 | acc.update([], [1, 2], [], frameid=0) 182 | # All miss 183 | acc.update([1, 2], [], [], frameid=1) 184 | # Match 185 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=2) 186 | # Switch 187 | acc.update([1, 2], [1, 2], [[0.2, np.nan], [np.nan, 0.1]], frameid=3) 188 | # Match. Better new match is available but should prefer history 189 | acc.update([1, 2], [1, 2], [[5, 1], [1, 5]], frameid=4) 190 | # No data 191 | acc.update([], [], [], frameid=5) 192 | 193 | ocs, hcs, tps = _extract_counts(acc) 194 | 195 | assert ocs == {1: 4, 2: 4} 196 | assert hcs == {1: 4, 2: 4} 197 | expected_tps = { 198 | (1, 1): 3, 199 | (1, 2): 2, 200 | (2, 1): 2, 201 | (2, 2): 3, 202 | } 203 | assert tps == expected_tps 204 | 205 | 206 | def test_extract_pandas_series_issue(): 207 | """Reproduce issue that arises with pd.Series but not pd.DataFrame. 
208 | 209 | >>> data = [[0, 1, 0.1], [0, 1, 0.2], [0, 1, 0.3]] 210 | >>> df = pd.DataFrame(data, columns=['x', 'y', 'z']).set_index(['x', 'y']) 211 | >>> df['z'].groupby(['x', 'y']).count() 212 | {(0, 1): 3} 213 | 214 | >>> data = [[0, 1, 0.1], [0, 1, 0.2]] 215 | >>> df = pd.DataFrame(data, columns=['x', 'y', 'z']).set_index(['x', 'y']) 216 | >>> df['z'].groupby(['x', 'y']).count() 217 | {'x': 1, 'y': 1} 218 | 219 | >>> df[['z']].groupby(['x', 'y'])['z'].count().to_dict() 220 | {(0, 1): 2} 221 | """ 222 | acc = mm.MOTAccumulator(auto_id=True) 223 | acc.update([0], [1], [[0.1]]) 224 | acc.update([0], [1], [[0.1]]) 225 | ocs, hcs, tps = _extract_counts(acc) 226 | assert ocs == {0: 2} 227 | assert hcs == {1: 2} 228 | assert tps == {(0, 1): 2} 229 | 230 | 231 | def test_benchmark_extract_counts(benchmark): 232 | """Benchmarks events_to_df_map() and extract_counts_from_df_map().""" 233 | rand = np.random.RandomState(0) 234 | acc = _accum_random_uniform( 235 | rand, seq_len=100, num_objs=50, num_hyps=5000, 236 | objs_per_frame=20, hyps_per_frame=40) 237 | benchmark(_extract_counts, acc) 238 | 239 | 240 | def _accum_random_uniform(rand, seq_len, num_objs, num_hyps, objs_per_frame, hyps_per_frame): 241 | acc = mm.MOTAccumulator(auto_id=True) 242 | for _ in range(seq_len): 243 | # Choose subset of objects present in this frame. 244 | objs = rand.choice(num_objs, objs_per_frame, replace=False) 245 | # Choose subset of hypotheses present in this frame. 246 | hyps = rand.choice(num_hyps, hyps_per_frame, replace=False) 247 | dist = rand.uniform(size=(objs_per_frame, hyps_per_frame)) 248 | acc.update(objs, hyps, dist) 249 | return acc 250 | 251 | 252 | def test_mota_motp(): 253 | """Tests values of MOTA and MOTP.""" 254 | acc = mm.MOTAccumulator() 255 | 256 | # All FP 257 | acc.update([], [1, 2], [], frameid=0) 258 | # All miss 259 | acc.update([1, 2], [], [], frameid=1) 260 | # Match 261 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=2) 262 | # Switch 263 | acc.update([1, 2], [1, 2], [[0.2, np.nan], [np.nan, 0.1]], frameid=3) 264 | # Match. Better new match is available but should prefer history 265 | acc.update([1, 2], [1, 2], [[5, 1], [1, 5]], frameid=4) 266 | # No data 267 | acc.update([], [], [], frameid=5) 268 | 269 | mh = mm.metrics.create() 270 | metr = mh.compute(acc, return_dataframe=False, return_cached=True, metrics=[ 271 | 'num_matches', 'num_false_positives', 'num_misses', 'num_switches', 'num_detections', 272 | 'num_objects', 'num_predictions', 'mota', 'motp', 'num_frames' 273 | ]) 274 | 275 | assert metr['num_matches'] == 4 276 | assert metr['num_false_positives'] == 2 277 | assert metr['num_misses'] == 2 278 | assert metr['num_switches'] == 2 279 | assert metr['num_detections'] == 6 280 | assert metr['num_objects'] == 8 281 | assert metr['num_predictions'] == 8 282 | assert metr['mota'] == approx(1. 
- (2 + 2 + 2) / 8) 283 | assert metr['motp'] == approx(11.1 / 6) 284 | assert metr['num_frames'] == 6 285 | 286 | 287 | def test_ids(): 288 | """Test metrics with frame IDs specified manually.""" 289 | acc = mm.MOTAccumulator() 290 | 291 | # No data 292 | acc.update([], [], [], frameid=0) 293 | # Match 294 | acc.update([1, 2], [1, 2], [[1, 0], [0, 1]], frameid=1) 295 | # Switch also Transfer 296 | acc.update([1, 2], [1, 2], [[0.4, np.nan], [np.nan, 0.4]], frameid=2) 297 | # Match 298 | acc.update([1, 2], [1, 2], [[0, 1], [1, 0]], frameid=3) 299 | # Ascend (switch) 300 | acc.update([1, 2], [2, 3], [[1, 0], [0.4, 0.7]], frameid=4) 301 | # Migrate (transfer) 302 | acc.update([1, 3], [2, 3], [[1, 0], [0.4, 0.7]], frameid=5) 303 | # No data 304 | acc.update([], [], [], frameid=6) 305 | 306 | mh = mm.metrics.create() 307 | metr = mh.compute(acc, return_dataframe=False, return_cached=True, metrics=[ 308 | 'num_matches', 'num_false_positives', 'num_misses', 'num_switches', 309 | 'num_transfer', 'num_ascend', 'num_migrate', 310 | 'num_detections', 'num_objects', 'num_predictions', 311 | 'mota', 'motp', 'num_frames', 312 | ]) 313 | assert metr['num_matches'] == 7 314 | assert metr['num_false_positives'] == 0 315 | assert metr['num_misses'] == 0 316 | assert metr['num_switches'] == 3 317 | assert metr['num_transfer'] == 3 318 | assert metr['num_ascend'] == 1 319 | assert metr['num_migrate'] == 1 320 | assert metr['num_detections'] == 10 321 | assert metr['num_objects'] == 10 322 | assert metr['num_predictions'] == 10 323 | assert metr['mota'] == approx(1. - (0 + 0 + 3) / 10) 324 | assert metr['motp'] == approx(1.6 / 10) 325 | assert metr['num_frames'] == 7 326 | 327 | 328 | def test_correct_average(): 329 | """Tests what is depicted in figure 3 of 'Evaluating MOT Performance'.""" 330 | acc = mm.MOTAccumulator(auto_id=True) 331 | 332 | # No track 333 | acc.update([1, 2, 3, 4], [], []) 334 | acc.update([1, 2, 3, 4], [], []) 335 | acc.update([1, 2, 3, 4], [], []) 336 | acc.update([1, 2, 3, 4], [], []) 337 | 338 | # Track single 339 | acc.update([4], [4], [0]) 340 | acc.update([4], [4], [0]) 341 | acc.update([4], [4], [0]) 342 | acc.update([4], [4], [0]) 343 | 344 | mh = mm.metrics.create() 345 | metr = mh.compute(acc, metrics='mota', return_dataframe=False) 346 | assert metr['mota'] == approx(0.2) 347 | 348 | 349 | def test_motchallenge_files(): 350 | """Tests metrics for sequences TUD-Campus and TUD-Stadtmitte.""" 351 | dnames = [ 352 | 'TUD-Campus', 353 | 'TUD-Stadtmitte', 354 | ] 355 | 356 | def compute_motchallenge(dname): 357 | df_gt = mm.io.loadtxt(os.path.join(dname, 'gt.txt')) 358 | df_test = mm.io.loadtxt(os.path.join(dname, 'test.txt')) 359 | return mm.utils.compare_to_groundtruth(df_gt, df_test, 'iou', distth=0.5) 360 | 361 | accs = [compute_motchallenge(os.path.join(DATA_DIR, d)) for d in dnames] 362 | 363 | # For testing 364 | # [a.events.to_pickle(n) for (a,n) in zip(accs, dnames)] 365 | 366 | mh = mm.metrics.create() 367 | summary = mh.compute_many(accs, metrics=mm.metrics.motchallenge_metrics, names=dnames, generate_overall=True) 368 | 369 | print() 370 | print(mm.io.render_summary(summary, namemap=mm.io.motchallenge_metric_names, formatters=mh.formatters)) 371 | # assert ((summary['num_transfer'] - summary['num_migrate']) == (summary['num_switches'] - summary['num_ascend'])).all() # False assertion 372 | summary = summary[mm.metrics.motchallenge_metrics[:15]] 373 | expected = pd.DataFrame([ 374 | [0.557659, 0.729730, 0.451253, 0.582173, 0.941441, 8.0, 1, 6, 1, 13, 150, 7, 7, 0.526462, 
0.277201], 375 | [0.644619, 0.819760, 0.531142, 0.608997, 0.939920, 10.0, 5, 4, 1, 45, 452, 7, 6, 0.564014, 0.345904], 376 | [0.624296, 0.799176, 0.512211, 0.602640, 0.940268, 18.0, 6, 10, 2, 58, 602, 14, 13, 0.555116, 0.330177], 377 | ]) 378 | np.testing.assert_allclose(summary, expected, atol=1e-3) 379 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/tests/test_mot.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 7 | 8 | """Tests behavior of MOTAccumulator.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import numpy as np 15 | import pandas as pd 16 | import pytest 17 | 18 | import motmetrics as mm 19 | 20 | 21 | def test_events(): 22 | """Tests that expected events are created by MOTAccumulator.update().""" 23 | acc = mm.MOTAccumulator() 24 | 25 | # All FP 26 | acc.update([], [1, 2], [], frameid=0) 27 | # All miss 28 | acc.update([1, 2], [], [], frameid=1) 29 | # Match 30 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=2) 31 | # Switch 32 | acc.update([1, 2], [1, 2], [[0.2, np.nan], [np.nan, 0.1]], frameid=3) 33 | # Match. Better new match is available but should prefer history 34 | acc.update([1, 2], [1, 2], [[5, 1], [1, 5]], frameid=4) 35 | # No data 36 | acc.update([], [], [], frameid=5) 37 | 38 | expect = mm.MOTAccumulator.new_event_dataframe() 39 | expect.loc[(0, 0), :] = ['RAW', np.nan, np.nan, np.nan] 40 | expect.loc[(0, 1), :] = ['RAW', np.nan, 1, np.nan] 41 | expect.loc[(0, 2), :] = ['RAW', np.nan, 2, np.nan] 42 | expect.loc[(0, 3), :] = ['FP', np.nan, 1, np.nan] 43 | expect.loc[(0, 4), :] = ['FP', np.nan, 2, np.nan] 44 | 45 | expect.loc[(1, 0), :] = ['RAW', np.nan, np.nan, np.nan] 46 | expect.loc[(1, 1), :] = ['RAW', 1, np.nan, np.nan] 47 | expect.loc[(1, 2), :] = ['RAW', 2, np.nan, np.nan] 48 | expect.loc[(1, 3), :] = ['MISS', 1, np.nan, np.nan] 49 | expect.loc[(1, 4), :] = ['MISS', 2, np.nan, np.nan] 50 | 51 | expect.loc[(2, 0), :] = ['RAW', np.nan, np.nan, np.nan] 52 | expect.loc[(2, 1), :] = ['RAW', 1, 1, 1.0] 53 | expect.loc[(2, 2), :] = ['RAW', 1, 2, 0.5] 54 | expect.loc[(2, 3), :] = ['RAW', 2, 1, 0.3] 55 | expect.loc[(2, 4), :] = ['RAW', 2, 2, 1.0] 56 | expect.loc[(2, 5), :] = ['MATCH', 1, 2, 0.5] 57 | expect.loc[(2, 6), :] = ['MATCH', 2, 1, 0.3] 58 | 59 | expect.loc[(3, 0), :] = ['RAW', np.nan, np.nan, np.nan] 60 | expect.loc[(3, 1), :] = ['RAW', 1, 1, 0.2] 61 | expect.loc[(3, 2), :] = ['RAW', 2, 2, 0.1] 62 | expect.loc[(3, 3), :] = ['TRANSFER', 1, 1, 0.2] 63 | expect.loc[(3, 4), :] = ['SWITCH', 1, 1, 0.2] 64 | expect.loc[(3, 5), :] = ['TRANSFER', 2, 2, 0.1] 65 | expect.loc[(3, 6), :] = ['SWITCH', 2, 2, 0.1] 66 | 67 | expect.loc[(4, 0), :] = ['RAW', np.nan, np.nan, np.nan] 68 | expect.loc[(4, 1), :] = ['RAW', 1, 1, 5.] 69 | expect.loc[(4, 2), :] = ['RAW', 1, 2, 1.] 70 | expect.loc[(4, 3), :] = ['RAW', 2, 1, 1.] 71 | expect.loc[(4, 4), :] = ['RAW', 2, 2, 5.] 72 | expect.loc[(4, 5), :] = ['MATCH', 1, 1, 5.] 73 | expect.loc[(4, 6), :] = ['MATCH', 2, 2, 5.] 
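    # The frames above exercise each event type in turn: frame 0 is all false
    # positives, frame 1 all misses, frame 2 a plain match, frame 3 a switch
    # (with accompanying TRANSFER rows), and frame 4 a match where the existing
    # pairing is kept even though a lower-distance alternative is available.
    # Frame 5 below contributes only the per-frame RAW placeholder row.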
74 | 75 | expect.loc[(5, 0), :] = ['RAW', np.nan, np.nan, np.nan] 76 | 77 | pd.util.testing.assert_frame_equal(acc.events, expect) 78 | 79 | 80 | def test_max_switch_time(): 81 | """Tests max_switch_time option.""" 82 | acc = mm.MOTAccumulator(max_switch_time=1) 83 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=1) # 1->a, 2->b 84 | frameid = acc.update([1, 2], [1, 2], [[0.5, np.nan], [np.nan, 0.5]], frameid=2) # 1->b, 2->a 85 | 86 | df = acc.events.loc[frameid] 87 | assert ((df.Type == 'SWITCH') | (df.Type == 'RAW') | (df.Type == 'TRANSFER')).all() 88 | 89 | acc = mm.MOTAccumulator(max_switch_time=1) 90 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=1) # 1->a, 2->b 91 | frameid = acc.update([1, 2], [1, 2], [[0.5, np.nan], [np.nan, 0.5]], frameid=5) # Later frame 1->b, 2->a 92 | 93 | df = acc.events.loc[frameid] 94 | assert ((df.Type == 'MATCH') | (df.Type == 'RAW') | (df.Type == 'TRANSFER')).all() 95 | 96 | 97 | def test_auto_id(): 98 | """Tests auto_id option.""" 99 | acc = mm.MOTAccumulator(auto_id=True) 100 | acc.update([1, 2, 3, 4], [], []) 101 | acc.update([1, 2, 3, 4], [], []) 102 | assert acc.events.index.levels[0][-1] == 1 103 | acc.update([1, 2, 3, 4], [], []) 104 | assert acc.events.index.levels[0][-1] == 2 105 | 106 | with pytest.raises(AssertionError): 107 | acc.update([1, 2, 3, 4], [], [], frameid=5) 108 | 109 | acc = mm.MOTAccumulator(auto_id=False) 110 | with pytest.raises(AssertionError): 111 | acc.update([1, 2, 3, 4], [], []) 112 | 113 | 114 | def test_merge_dataframes(): 115 | """Tests merge_event_dataframes().""" 116 | # pylint: disable=too-many-statements 117 | acc = mm.MOTAccumulator() 118 | 119 | acc.update([], [1, 2], [], frameid=0) 120 | acc.update([1, 2], [], [], frameid=1) 121 | acc.update([1, 2], [1, 2], [[1, 0.5], [0.3, 1]], frameid=2) 122 | acc.update([1, 2], [1, 2], [[0.2, np.nan], [np.nan, 0.1]], frameid=3) 123 | 124 | r, mappings = mm.MOTAccumulator.merge_event_dataframes([acc.events, acc.events], return_mappings=True) 125 | 126 | expect = mm.MOTAccumulator.new_event_dataframe() 127 | 128 | expect.loc[(0, 0), :] = ['RAW', np.nan, np.nan, np.nan] 129 | expect.loc[(0, 1), :] = ['RAW', np.nan, mappings[0]['hid_map'][1], np.nan] 130 | expect.loc[(0, 2), :] = ['RAW', np.nan, mappings[0]['hid_map'][2], np.nan] 131 | expect.loc[(0, 3), :] = ['FP', np.nan, mappings[0]['hid_map'][1], np.nan] 132 | expect.loc[(0, 4), :] = ['FP', np.nan, mappings[0]['hid_map'][2], np.nan] 133 | 134 | expect.loc[(1, 0), :] = ['RAW', np.nan, np.nan, np.nan] 135 | expect.loc[(1, 1), :] = ['RAW', mappings[0]['oid_map'][1], np.nan, np.nan] 136 | expect.loc[(1, 2), :] = ['RAW', mappings[0]['oid_map'][2], np.nan, np.nan] 137 | expect.loc[(1, 3), :] = ['MISS', mappings[0]['oid_map'][1], np.nan, np.nan] 138 | expect.loc[(1, 4), :] = ['MISS', mappings[0]['oid_map'][2], np.nan, np.nan] 139 | 140 | expect.loc[(2, 0), :] = ['RAW', np.nan, np.nan, np.nan] 141 | expect.loc[(2, 1), :] = ['RAW', mappings[0]['oid_map'][1], mappings[0]['hid_map'][1], 1] 142 | expect.loc[(2, 2), :] = ['RAW', mappings[0]['oid_map'][1], mappings[0]['hid_map'][2], 0.5] 143 | expect.loc[(2, 3), :] = ['RAW', mappings[0]['oid_map'][2], mappings[0]['hid_map'][1], 0.3] 144 | expect.loc[(2, 4), :] = ['RAW', mappings[0]['oid_map'][2], mappings[0]['hid_map'][2], 1.0] 145 | expect.loc[(2, 5), :] = ['MATCH', mappings[0]['oid_map'][1], mappings[0]['hid_map'][2], 0.5] 146 | expect.loc[(2, 6), :] = ['MATCH', mappings[0]['oid_map'][2], mappings[0]['hid_map'][1], 0.3] 147 | 148 | expect.loc[(3, 0), :] = ['RAW', 
np.nan, np.nan, np.nan] 149 | expect.loc[(3, 1), :] = ['RAW', mappings[0]['oid_map'][1], mappings[0]['hid_map'][1], 0.2] 150 | expect.loc[(3, 2), :] = ['RAW', mappings[0]['oid_map'][2], mappings[0]['hid_map'][2], 0.1] 151 | expect.loc[(3, 3), :] = ['TRANSFER', mappings[0]['oid_map'][1], mappings[0]['hid_map'][1], 0.2] 152 | expect.loc[(3, 4), :] = ['SWITCH', mappings[0]['oid_map'][1], mappings[0]['hid_map'][1], 0.2] 153 | expect.loc[(3, 5), :] = ['TRANSFER', mappings[0]['oid_map'][2], mappings[0]['hid_map'][2], 0.1] 154 | expect.loc[(3, 6), :] = ['SWITCH', mappings[0]['oid_map'][2], mappings[0]['hid_map'][2], 0.1] 155 | 156 | # Merge duplication 157 | expect.loc[(4, 0), :] = ['RAW', np.nan, np.nan, np.nan] 158 | expect.loc[(4, 1), :] = ['RAW', np.nan, mappings[1]['hid_map'][1], np.nan] 159 | expect.loc[(4, 2), :] = ['RAW', np.nan, mappings[1]['hid_map'][2], np.nan] 160 | expect.loc[(4, 3), :] = ['FP', np.nan, mappings[1]['hid_map'][1], np.nan] 161 | expect.loc[(4, 4), :] = ['FP', np.nan, mappings[1]['hid_map'][2], np.nan] 162 | 163 | expect.loc[(5, 0), :] = ['RAW', np.nan, np.nan, np.nan] 164 | expect.loc[(5, 1), :] = ['RAW', mappings[1]['oid_map'][1], np.nan, np.nan] 165 | expect.loc[(5, 2), :] = ['RAW', mappings[1]['oid_map'][2], np.nan, np.nan] 166 | expect.loc[(5, 3), :] = ['MISS', mappings[1]['oid_map'][1], np.nan, np.nan] 167 | expect.loc[(5, 4), :] = ['MISS', mappings[1]['oid_map'][2], np.nan, np.nan] 168 | 169 | expect.loc[(6, 0), :] = ['RAW', np.nan, np.nan, np.nan] 170 | expect.loc[(6, 1), :] = ['RAW', mappings[1]['oid_map'][1], mappings[1]['hid_map'][1], 1] 171 | expect.loc[(6, 2), :] = ['RAW', mappings[1]['oid_map'][1], mappings[1]['hid_map'][2], 0.5] 172 | expect.loc[(6, 3), :] = ['RAW', mappings[1]['oid_map'][2], mappings[1]['hid_map'][1], 0.3] 173 | expect.loc[(6, 4), :] = ['RAW', mappings[1]['oid_map'][2], mappings[1]['hid_map'][2], 1.0] 174 | expect.loc[(6, 5), :] = ['MATCH', mappings[1]['oid_map'][1], mappings[1]['hid_map'][2], 0.5] 175 | expect.loc[(6, 6), :] = ['MATCH', mappings[1]['oid_map'][2], mappings[1]['hid_map'][1], 0.3] 176 | 177 | expect.loc[(7, 0), :] = ['RAW', np.nan, np.nan, np.nan] 178 | expect.loc[(7, 1), :] = ['RAW', mappings[1]['oid_map'][1], mappings[1]['hid_map'][1], 0.2] 179 | expect.loc[(7, 2), :] = ['RAW', mappings[1]['oid_map'][2], mappings[1]['hid_map'][2], 0.1] 180 | expect.loc[(7, 3), :] = ['TRANSFER', mappings[1]['oid_map'][1], mappings[1]['hid_map'][1], 0.2] 181 | expect.loc[(7, 4), :] = ['SWITCH', mappings[1]['oid_map'][1], mappings[1]['hid_map'][1], 0.2] 182 | expect.loc[(7, 5), :] = ['TRANSFER', mappings[1]['oid_map'][2], mappings[1]['hid_map'][2], 0.1] 183 | expect.loc[(7, 6), :] = ['SWITCH', mappings[1]['oid_map'][2], mappings[1]['hid_map'][2], 0.1] 184 | 185 | pd.util.testing.assert_frame_equal(r, expect) 186 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/tests/test_utils.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 
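# A minimal sketch of the input layout assumed by compare_to_groundtruth() and
# constructed by _tracks_to_dataframe() below: a DataFrame indexed by
# (FrameId, Id) with one column per distance field, e.g.
#
#   frame_df = pd.DataFrame(
#       [{"FrameId": 1, "Id": 1, "Position": 0.0}]
#   ).set_index(["FrameId", "Id"])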
7 | 8 | """Tests accumulation of events using utility functions.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import itertools 15 | 16 | import numpy as np 17 | import pandas as pd 18 | 19 | import motmetrics as mm 20 | 21 | 22 | def test_annotations_xor_predictions_present(): 23 | """Tests frames that contain only annotations or predictions.""" 24 | _ = None 25 | anno_tracks = { 26 | 1: [0, 2, 4, 6, _, _, _], 27 | 2: [_, _, 0, 2, 4, _, _], 28 | } 29 | pred_tracks = { 30 | 1: [_, _, 3, 5, 7, 7, 7], 31 | } 32 | anno = _tracks_to_dataframe(anno_tracks) 33 | pred = _tracks_to_dataframe(pred_tracks) 34 | acc = mm.utils.compare_to_groundtruth(anno, pred, 'euc', distfields=['Position'], distth=2) 35 | mh = mm.metrics.create() 36 | metrics = mh.compute(acc, return_dataframe=False, metrics=[ 37 | 'num_objects', 'num_predictions', 'num_unique_objects', 38 | ]) 39 | np.testing.assert_equal(metrics['num_objects'], 7) 40 | np.testing.assert_equal(metrics['num_predictions'], 5) 41 | np.testing.assert_equal(metrics['num_unique_objects'], 2) 42 | 43 | 44 | def _tracks_to_dataframe(tracks): 45 | rows = [] 46 | for track_id, track in tracks.items(): 47 | for frame_id, position in zip(itertools.count(1), track): 48 | if position is None: 49 | continue 50 | rows.append({ 51 | 'FrameId': frame_id, 52 | 'Id': track_id, 53 | 'Position': position, 54 | }) 55 | return pd.DataFrame(rows).set_index(['FrameId', 'Id']) 56 | -------------------------------------------------------------------------------- /posetrack21_eval/motmetrics/utils.py: -------------------------------------------------------------------------------- 1 | # py-motmetrics - Metrics for multiple object tracker (MOT) benchmarking. 2 | # https://github.com/cheind/py-motmetrics/ 3 | # 4 | # MIT License 5 | # Copyright (c) 2017-2020 Christoph Heindl, Jack Valmadre and others. 6 | # See LICENSE file for terms. 7 | 8 | """Functions for populating event accumulators.""" 9 | 10 | from __future__ import absolute_import 11 | from __future__ import division 12 | from __future__ import print_function 13 | 14 | import numpy as np 15 | 16 | from .distances import iou_matrix, norm2squared_matrix 17 | from .mot import MOTAccumulator 18 | from .preprocess import preprocessResult 19 | 20 | 21 | def compare_to_groundtruth(gt, dt, dist='iou', distfields=None, distth=0.5): 22 | """Compare groundtruth and detector results. 23 | 24 | This method assumes both results are given in terms of DataFrames with at least the following fields 25 | - `FrameId` First level index used for matching ground-truth and test frames. 26 | - `Id` Secondary level index marking available object / hypothesis ids 27 | 28 | Depending on the distance to be used relevant distfields need to be specified. 29 | 30 | Params 31 | ------ 32 | gt : pd.DataFrame 33 | Dataframe for ground-truth 34 | test : pd.DataFrame 35 | Dataframe for detector results 36 | 37 | Kwargs 38 | ------ 39 | dist : str, optional 40 | String identifying distance to be used. Defaults to intersection over union. 41 | distfields: array, optional 42 | Fields relevant for extracting distance information. Defaults to ['X', 'Y', 'Width', 'Height'] 43 | distth: float, optional 44 | Maximum tolerable distance. Pairs exceeding this threshold are marked 'do-not-pair'. 
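
    Example
    -------
    A minimal sketch, assuming df_gt and df_test are DataFrames indexed by
    (FrameId, Id) with the default X/Y/Width/Height fields:

        acc = compare_to_groundtruth(df_gt, df_test, dist='iou', distth=0.5)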
45 | """ 46 | # pylint: disable=too-many-locals 47 | if distfields is None: 48 | distfields = ['X', 'Y', 'Width', 'Height'] 49 | 50 | def compute_iou(a, b): 51 | return iou_matrix(a, b, max_iou=distth) 52 | 53 | def compute_euc(a, b): 54 | return norm2squared_matrix(a, b, max_d2=distth) 55 | 56 | compute_dist = compute_iou if dist.upper() == 'IOU' else compute_euc 57 | 58 | acc = MOTAccumulator() 59 | 60 | # We need to account for all frames reported either by ground truth or 61 | # detector. In case a frame is missing in GT this will lead to FPs, in 62 | # case a frame is missing in detector results this will lead to FNs. 63 | allframeids = gt.index.union(dt.index).levels[0] 64 | 65 | gt = gt[distfields] 66 | dt = dt[distfields] 67 | fid_to_fgt = dict(iter(gt.groupby('FrameId'))) 68 | fid_to_fdt = dict(iter(dt.groupby('FrameId'))) 69 | 70 | for fid in allframeids: 71 | oids = np.empty(0) 72 | hids = np.empty(0) 73 | dists = np.empty((0, 0)) 74 | if fid in fid_to_fgt: 75 | fgt = fid_to_fgt[fid] 76 | oids = fgt.index.get_level_values('Id') 77 | if fid in fid_to_fdt: 78 | fdt = fid_to_fdt[fid] 79 | hids = fdt.index.get_level_values('Id') 80 | if len(oids) > 0 and len(hids) > 0: 81 | dists = compute_dist(fgt.values, fdt.values) 82 | acc.update(oids, hids, dists, frameid=fid) 83 | 84 | return acc 85 | 86 | 87 | def CLEAR_MOT_M(gt, dt, inifile, dist='iou', distfields=None, distth=0.5, include_all=False, vflag=''): 88 | """Compare groundtruth and detector results. 89 | 90 | This method assumes both results are given in terms of DataFrames with at least the following fields 91 | - `FrameId` First level index used for matching ground-truth and test frames. 92 | - `Id` Secondary level index marking available object / hypothesis ids 93 | 94 | Depending on the distance to be used relevant distfields need to be specified. 95 | 96 | Params 97 | ------ 98 | gt : pd.DataFrame 99 | Dataframe for ground-truth 100 | test : pd.DataFrame 101 | Dataframe for detector results 102 | 103 | Kwargs 104 | ------ 105 | dist : str, optional 106 | String identifying distance to be used. Defaults to intersection over union. 107 | distfields: array, optional 108 | Fields relevant for extracting distance information. Defaults to ['X', 'Y', 'Width', 'Height'] 109 | distth: float, optional 110 | Maximum tolerable distance. Pairs exceeding this threshold are marked 'do-not-pair'. 111 | """ 112 | # pylint: disable=too-many-locals 113 | if distfields is None: 114 | distfields = ['X', 'Y', 'Width', 'Height'] 115 | 116 | def compute_iou(a, b): 117 | return iou_matrix(a, b, max_iou=distth) 118 | 119 | def compute_euc(a, b): 120 | return norm2squared_matrix(a, b, max_d2=distth) 121 | 122 | compute_dist = compute_iou if dist.upper() == 'IOU' else compute_euc 123 | 124 | acc = MOTAccumulator() 125 | dt = preprocessResult(dt, gt, inifile) 126 | if include_all: 127 | gt = gt[gt['Confidence'] >= 0.99] 128 | else: 129 | gt = gt[(gt['Confidence'] >= 0.99) & (gt['ClassId'] == 1)] 130 | # We need to account for all frames reported either by ground truth or 131 | # detector. In case a frame is missing in GT this will lead to FPs, in 132 | # case a frame is missing in detector results this will lead to FNs. 
133 | allframeids = gt.index.union(dt.index).levels[0] 134 | analysis = {'hyp': {}, 'obj': {}} 135 | for fid in allframeids: 136 | oids = np.empty(0) 137 | hids = np.empty(0) 138 | dists = np.empty((0, 0)) 139 | 140 | if fid in gt.index: 141 | fgt = gt.loc[fid] 142 | oids = fgt.index.values 143 | for oid in oids: 144 | oid = int(oid) 145 | if oid not in analysis['obj']: 146 | analysis['obj'][oid] = 0 147 | analysis['obj'][oid] += 1 148 | 149 | if fid in dt.index: 150 | fdt = dt.loc[fid] 151 | hids = fdt.index.values 152 | for hid in hids: 153 | hid = int(hid) 154 | if hid not in analysis['hyp']: 155 | analysis['hyp'][hid] = 0 156 | analysis['hyp'][hid] += 1 157 | 158 | if oids.shape[0] > 0 and hids.shape[0] > 0: 159 | dists = compute_dist(fgt[distfields].values, fdt[distfields].values) 160 | 161 | acc.update(oids, hids, dists, frameid=fid, vf=vflag) 162 | 163 | return acc, analysis 164 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "comotion_demo" 3 | description = "CoMotion: Concurrent Multi-person 3D Motion." 4 | version = "0.1" 5 | authors = [ 6 | {name = "Alejandro Newell"}, 7 | {name = "Peiyun Hu"}, 8 | {name = "Lahav Lipson"}, 9 | {name = "Stephan R. Richter"}, 10 | {name = "Vladlen Koltun"}, 11 | ] 12 | readme = "README.md" 13 | dependencies = [ 14 | "click", 15 | "coremltools", 16 | "einops", 17 | "ffmpeg-python", 18 | "opencv-python", 19 | "pypose", 20 | "PyQt6", 21 | "ruff", 22 | "scenedetect", 23 | "tensordict", 24 | "timm==1.0.13", 25 | "torch==2.5.1", 26 | "transformers==4.48.0", 27 | "tqdm", 28 | "chumpy @ git+https://github.com/mattloper/chumpy@9b045ff5d6588a24a0bab52c83f032e2ba433e17", 29 | ] 30 | requires-python = ">=3.10" 31 | 32 | [project.optional-dependencies] 33 | all = [ 34 | "aitviewer", 35 | ] 36 | colab = [ 37 | "pyrender", 38 | "smplx[all]", 39 | "more-itertools", 40 | "trimesh" 41 | ] 42 | 43 | [project.urls] 44 | Homepage = "https://github.com/apple/ml-comotion" 45 | Repository = "https://github.com/apple/ml-comotion" 46 | 47 | [build-system] 48 | requires = ["setuptools", "setuptools-scm"] 49 | build-backend = "setuptools.build_meta" 50 | 51 | [tool.setuptools.packages.find] 52 | where = ["src"] 53 | -------------------------------------------------------------------------------- /samples/sample_info.txt: -------------------------------------------------------------------------------- 1 | All sample visualizations provided under CC-by-NC-ND. 
2 | 3 | Sources: 4 | https://www.pexels.com/video/a-man-and-a-woman-dancing-the-tango-on-the-sidewalk-8281172/ 5 | https://www.pexels.com/video/a-man-doing-breakdancing-8688465/ 6 | https://www.pexels.com/video/bearded-man-doing-breakdancing-9344627/ 7 | https://www.pexels.com/video/a-woman-showing-her-ballet-skill-in-turning-one-footed-5385885/ 8 | https://www.pexels.com/video/a-couple-practicing-acrobatics-6809489/ 9 | -------------------------------------------------------------------------------- /samples/teaser_01.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_01.gif -------------------------------------------------------------------------------- /samples/teaser_02.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_02.gif -------------------------------------------------------------------------------- /samples/teaser_03.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_03.gif -------------------------------------------------------------------------------- /samples/teaser_04.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_04.gif -------------------------------------------------------------------------------- /samples/teaser_05.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_05.gif -------------------------------------------------------------------------------- /samples/teaser_06.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/samples/teaser_06.gif -------------------------------------------------------------------------------- /src/comotion_demo/data/smpl/extra_smpl_reference.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-comotion/adb7e22f85f58c8f52279ba0e996af83608f904a/src/comotion_demo/data/smpl/extra_smpl_reference.pt -------------------------------------------------------------------------------- /src/comotion_demo/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | """CoMotion model.""" 3 | 4 | from . import backbones # noqa 5 | from .comotion import CoMotion # noqa 6 | from .detect import CoMotionDetect # noqa 7 | -------------------------------------------------------------------------------- /src/comotion_demo/models/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
2 | """Backbone network for CoMotion.""" 3 | 4 | from ._registry import _lookup 5 | from .convnext import ConvNextV2 # noqa 6 | 7 | 8 | def initialize(backbone_choice): 9 | """Initialize a registered backbone.""" 10 | assert backbone_choice in _lookup, ( 11 | f"Backbone choice '{backbone_choice}' not found in registry." 12 | ) 13 | return _lookup[backbone_choice]() 14 | -------------------------------------------------------------------------------- /src/comotion_demo/models/backbones/_registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | _lookup = {} 3 | 4 | 5 | def register_model(fn, model_name=None): 6 | """Register a model under a name. When unspecified, use the class name.""" 7 | if model_name is None: 8 | model_name = fn.__name__ 9 | _lookup[model_name] = fn 10 | 11 | return fn 12 | -------------------------------------------------------------------------------- /src/comotion_demo/models/backbones/convnext.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | """ConvNextV2 vision backbone.""" 3 | 4 | import einops as eo 5 | from timm import create_model 6 | from timm.layers import LayerNorm2d 7 | from torch import nn 8 | 9 | from ._registry import register_model 10 | 11 | 12 | @register_model 13 | class ConvNextV2(nn.Module): 14 | """Extract a feature pyramid out of pre-trained ConvNextV2.""" 15 | 16 | def __init__(self, output_dim=256, size="l", dropout=0.2, pretrained=False): 17 | """Set up a variant of ConvNextV2 and the feature extractors.""" 18 | super().__init__() 19 | 20 | self.output_dim = output_dim 21 | self.size = size 22 | self.dropout = dropout 23 | self.pretrained = pretrained 24 | 25 | # Initialize network 26 | backbone_ref = { 27 | "t": ["convnextv2_tiny.fcmae", [96, 192, 384, 768]], 28 | "b": ["convnextv2_base.fcmae", [128, 256, 512, 1024]], 29 | "l": ["convnextv2_large.fcmae", [192, 384, 768, 1536]], 30 | "h": ["convnextv2_huge.fcmae", [352, 704, 1408, 2816]], 31 | } 32 | network_name, feat_dims = backbone_ref[self.size] 33 | network = create_model(network_name, pretrained=self.pretrained) 34 | self.stem = network.stem 35 | self.stages = network.stages 36 | 37 | # Layers to fuse features across scales 38 | self._init_feature_fusion(feat_dims, output_dim) 39 | self.norm_px_feats = LayerNorm2d(output_dim) 40 | if self.dropout > 0: 41 | self.px_dropout = nn.Dropout2d(self.dropout) 42 | 43 | def _init_feature_fusion(self, feat_dims, output_dim): 44 | """Initialize linear layers for feature extraction at different levels.""" 45 | self.project_0 = nn.Conv2d(feat_dims[0], output_dim, 2, 2) 46 | self.project_1 = nn.Conv2d(feat_dims[1], output_dim, 1, 1) 47 | self.project_2 = nn.Conv2d(feat_dims[2], output_dim * 4, 1, 1) 48 | self.project_3 = nn.Conv2d(feat_dims[3], output_dim * 16, 1, 1) 49 | 50 | def forward(self, x): 51 | """Run ConvNextV2 inference and return a list of feature maps.""" 52 | # Run ConvNext stages 53 | x = self.stem(x) 54 | feat_pyramid = [] 55 | for stage in self.stages: 56 | x = stage(x) 57 | feat_pyramid.append(x) 58 | 59 | # Fuse into single tensor at 1/8th input resolution 60 | f0 = self.project_0(feat_pyramid[0]) 61 | f1 = self.project_1(feat_pyramid[1]) 62 | f2 = self.project_2(feat_pyramid[2]) 63 | f3 = self.project_3(feat_pyramid[3]) 64 | f2 = eo.rearrange(f2, "... (c h2 w2) h w -> ... c (h h2) (w w2)", h2=2, w2=2) 65 | f3 = eo.rearrange(f3, "... 
(c h2 w2) h w -> ... c (h h2) (w w2)", h2=4, w2=4) 66 | px_feats = f0 + f1 + f2 + f3 67 | px_feats = self.norm_px_feats(px_feats) 68 | if self.dropout > 0: 69 | px_feats = self.px_dropout(px_feats) 70 | 71 | return px_feats 72 | -------------------------------------------------------------------------------- /src/comotion_demo/models/comotion.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | import torch 3 | from scenedetect.detectors import ContentDetector 4 | from torch import nn 5 | 6 | from ..utils import dataloading, helper, smpl_kinematics, track 7 | from . import detect, refine 8 | 9 | 10 | class CoMotion(nn.Module): 11 | """CoMotion network.""" 12 | 13 | def __init__(self, use_coreml=False, pretrained=True): 14 | """Initialize CoMotion. 15 | 16 | Args: 17 | ---- 18 | use_coreml: use CoreML version of the detection model (macOS only). 19 | pretrained: load pre-trained CoMotion modules. 20 | 21 | """ 22 | super().__init__() 23 | 24 | if use_coreml: 25 | self.detection_model = detect.CoMotionDetectCoreML() 26 | else: 27 | self.detection_model = detect.CoMotionDetect(pretrained=pretrained) 28 | 29 | self.update_step = refine.PoseRefinement( 30 | self.detection_model.feat_dim, 31 | self.detection_model.cfg.pose_embed_dim, 32 | pretrained=pretrained, 33 | ) 34 | 35 | self.smpl_decoder = smpl_kinematics.SMPLKinematics() 36 | self.shot_detector = ContentDetector(threshold=50.0, min_scene_len=3) 37 | self.frame_count = 0 38 | 39 | def init_tracks(self, image_res): 40 | """Initialize track handler.""" 41 | self.handler = track.TrackHandler(track.default_dims, image_res) 42 | 43 | @torch.inference_mode() 44 | def forward(self, image, K, detection_only=False, use_mps=False): 45 | """Perform detection and tracking given a new image. 46 | 47 | Input images are accepted at any resolution, resizing and cropping is 48 | handled automatically and output 2D keypoints will be provided at the 49 | original input resolution. 
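
        A minimal usage sketch (variable names are illustrative): call
        init_tracks(image_res) once per video, then process each frame with
        detections, tracks = comotion(image, K); with detection_only=True
        only the NMS-filtered detections are returned.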
50 | 51 | Args: 52 | ---- 53 | image: Input image (C x H x W) float (0-1) or uint8 (0-255) tensor 54 | K: Intrinsics matrix (2 x 3) float tensor 55 | detection_only: Flag whether to only run initial detection stage 56 | use_mps: Flag whether to run update step on MPS on MacOS 57 | 58 | """ 59 | # Prepare inputs 60 | device = next(self.parameters()).device 61 | if use_mps: 62 | self.update_step.to("mps") 63 | device = "cpu" 64 | 65 | # Prepare inputs 66 | cropped_image, cropped_K = dataloading.prepare_network_inputs(image, K, device) 67 | K = K.to(device) 68 | 69 | # Get detections 70 | detect_out = self.detection_model(cropped_image, cropped_K) 71 | nms_out = detect.decode_network_outputs( 72 | K, 73 | self.smpl_decoder, 74 | detect_out, 75 | std=0.08, 76 | iou_thr=0.4, 77 | conf_thr=0.1, 78 | ) 79 | 80 | if detection_only: 81 | return nms_out 82 | 83 | def call_update(s: track.TrackTensorState): 84 | # Prepare inputs 85 | update_args = [ 86 | detect_out.image_features, 87 | cropped_K, 88 | s.betas, 89 | s.pose, 90 | s.trans, 91 | s.pred_3d, 92 | s.hidden, 93 | ] 94 | if use_mps: 95 | update_args = [arg.to("mps") for arg in update_args] 96 | 97 | # Run update step 98 | updated_params = self.update_step(*update_args) 99 | updated_params = refine.RefineOutput( 100 | **{k: v.to(device) for k, v in updated_params._asdict().items()} 101 | ) 102 | 103 | # Update track state 104 | s.pose = detect.get_smpl_pose( 105 | updated_params.delta_root_orient, 106 | updated_params.delta_body_pose, 107 | ) 108 | s.trans = updated_params.trans 109 | s.hidden = updated_params.hidden 110 | s.pred_3d = self.smpl_decoder( 111 | s.betas, s.pose, s.trans, output_format="joints_face" 112 | ) 113 | s.pred_2d = helper.project_to_2d(K, s.pred_3d) 114 | 115 | # Detect shot changes 116 | image_np = image.detach().to("cpu").permute(1, 2, 0).numpy() 117 | image_np = image_np[:, :, ::-1] # RGB2BGR 118 | shots = self.shot_detector.process_frame(self.frame_count, image_np) 119 | self.frame_count += 1 120 | is_new_shot = len(shots) > 1 if self.frame_count > 1 else False 121 | return nms_out, self.handler.update( 122 | nms_out, call_update, shot_reset=is_new_shot 123 | ) 124 | -------------------------------------------------------------------------------- /src/comotion_demo/models/detect.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | import os 3 | from collections import namedtuple 4 | from dataclasses import dataclass 5 | 6 | import einops as eo 7 | import torch 8 | from torch import nn 9 | 10 | from ..utils import helper, smpl_kinematics 11 | from . import backbones, layers 12 | 13 | curr_dir = os.path.abspath(os.path.dirname(__file__)) 14 | PYTORCH_CHECKPOINT_PATH = f"{curr_dir}/../data/comotion_detection_checkpoint.pt" 15 | COREML_CHECKPOINT_PATH = f"{curr_dir}/../data/comotion_detection.mlpackage" 16 | 17 | DetectionOutput = namedtuple( 18 | "DetectOutput", 19 | [ 20 | "image_features", 21 | "betas", 22 | "delta_root_orient", 23 | "delta_body_pose", 24 | "trans", 25 | "conf", 26 | ], 27 | ) 28 | 29 | 30 | @dataclass 31 | class CoMotionDetectConfig: 32 | backbone_choice: str = "ConvNextV2" 33 | pose_embed_dim: int = 256 34 | hidden_dim: int = 512 35 | rot_embed_dim: int = 8 36 | 37 | 38 | class DetectionHead(nn.Module): 39 | """CoMotion detection head. 40 | 41 | Accepts as input image features and an intrinsics matrix to produce a large 42 | pool of candidate SMPL poses and corresponding confidences. 
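
    With the default 512x512 input crop this yields 1344 candidates per image
    (1024 + 256 + 64 across the three pyramid levels decoded in forward).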
43 | """ 44 | 45 | def __init__( 46 | self, 47 | input_dim, 48 | output_split, 49 | hidden_dim=512, 50 | num_slots=4, 51 | depth_adj=128, 52 | dropout=0.0, 53 | ): 54 | """Initialize CoMotion detection head.""" 55 | super().__init__() 56 | 57 | self.input_dim = input_dim 58 | self.hidden_dim = hidden_dim 59 | self.num_slots = num_slots 60 | self.output_split = output_split 61 | self.depth_adj = depth_adj 62 | self.dropout = dropout 63 | 64 | # Blocks for feature pyramid encoding. 65 | self.enc1 = nn.Sequential( 66 | layers.DownsampleConvNextBlock(input_dim, hidden_dim), 67 | layers.DownsampleConvNextBlock(hidden_dim, 2 * hidden_dim, dropout=dropout), 68 | ) 69 | self.enc2 = layers.DownsampleConvNextBlock(2 * hidden_dim, dropout=dropout) 70 | self.enc3 = layers.DownsampleConvNextBlock(2 * hidden_dim, dropout=dropout) 71 | 72 | out_dim = sum(self.output_split) * self.num_slots 73 | self.decoders = nn.ModuleList( 74 | [ 75 | nn.Sequential( 76 | nn.Conv2d(2 * hidden_dim, 2 * hidden_dim, 1), 77 | layers.LayerNorm2d(2 * hidden_dim), 78 | nn.ReLU(), 79 | nn.Conv2d(2 * hidden_dim, out_dim, 1), 80 | ) 81 | for _ in range(3) 82 | ] 83 | ) 84 | 85 | def forward(self, px_feats, K, pooling=8, return_feats=False): 86 | """Detect poses from image features and intrinsics. 87 | 88 | Args: 89 | ---- 90 | px_feats: image features from the backbone network. 91 | K: image intrinsics. 92 | pooling: downsampling factor from input image to features. 93 | return_feats: return image feature pyramid. 94 | 95 | Return: 96 | ------ 97 | detections: 1344 candidates = 1024 + 256 + 64 from 3 levels. 98 | feat_pyramid: image feature pyramid. 99 | 100 | """ 101 | # Rescale factor 102 | ht, wd = px_feats.shape[-2:] 103 | rescale = max(ht, wd) * pooling 104 | calib_adj = K[..., 0, 0][:, None, None] 105 | 106 | # Apply encoding and decoding layers 107 | x1 = self.enc1(px_feats) 108 | x2 = self.enc2(x1) 109 | x3 = self.enc3(x2) 110 | feat_pyramid = [x1, x2, x3] 111 | 112 | # Post-process to produce full set of detections 113 | pred_state = [] 114 | for pyr_idx, feats in enumerate(feat_pyramid): 115 | feats = feats.float() 116 | 117 | ht, wd = feats.shape[-2:] 118 | ref_xy = helper.get_grid(ht, wd, device=feats.device) + 0.5 / max(ht, wd) 119 | ref_xy *= rescale 120 | ref_xy_slots = eo.repeat(ref_xy, "h w c -> (h w n) c", n=self.num_slots) 121 | 122 | pred = self.decoders[pyr_idx](feats) 123 | pred = eo.rearrange(pred, "b (n c) h w -> b (h w n) c", n=self.num_slots) 124 | 125 | ( 126 | init_betas, 127 | init_pose, 128 | init_rot, 129 | init_xy, 130 | init_z, 131 | _, 132 | conf, 133 | ) = pred.split_with_sizes(dim=-1, split_sizes=self.output_split) 134 | 135 | # Adjust translation offset based on camera intrinsics 136 | default_depth = calib_adj / (self.depth_adj * 2**pyr_idx) 137 | init_z = default_depth / (torch.exp(init_z) + 0.05) 138 | init_xy = init_xy + ref_xy_slots 139 | init_xy = helper.px_to_world(K.unsqueeze(1), init_xy) * init_z 140 | init_trans = torch.cat([init_xy, init_z], -1) 141 | 142 | pred_state.append( 143 | { 144 | "betas": init_betas, 145 | "delta_root_orient": init_rot, 146 | "pose_embedding": init_pose, 147 | "trans": init_trans, 148 | "conf": conf, 149 | } 150 | ) 151 | 152 | detections = {} 153 | for k in pred_state[0].keys(): 154 | detections[k] = torch.cat([p[k] for p in pred_state], 1).float() 155 | 156 | if return_feats: 157 | return detections, feat_pyramid 158 | else: 159 | return detections 160 | 161 | 162 | class CoMotionDetect(nn.Module): 163 | """CoMotion detection module. 
164 | 165 | Module responsible for initial feature extraction from a ConvNext backbone 166 | as well as producing candidate per-frame detections. 167 | """ 168 | 169 | def __init__( 170 | self, cfg: CoMotionDetectConfig | None = None, pretrained: bool = True 171 | ): 172 | """Initialize CoMotion detection module. 173 | 174 | Args: 175 | ---- 176 | cfg: Detection config defining various model hyperparameters. 177 | pretrained: Whether to load pretrained detection checkpoint. 178 | 179 | """ 180 | super().__init__() 181 | cfg = CoMotionDetectConfig() if cfg is None else cfg 182 | 183 | self.cfg = cfg 184 | self.pos_embedding = layers.PosEmbed() 185 | self.rot_embedding = layers.RotaryEmbed(cfg.rot_embed_dim) 186 | self.kn = smpl_kinematics.SMPLKinematics() 187 | self.body_keys = ["betas", "pose", "trans"] 188 | 189 | # Instantiate an image backbone network 190 | self.image_backbone = backbones.initialize(cfg.backbone_choice) 191 | self.feat_dim = self.image_backbone.output_dim 192 | 193 | # Detection head 194 | # Output split: betas, pose embedding, root_orient, xy, z, scale, confidence 195 | output_split = [smpl_kinematics.BETA_DOF, cfg.pose_embed_dim, 3, 2, 1, 1, 1] 196 | self.detection_head = DetectionHead( 197 | input_dim=self.feat_dim, 198 | hidden_dim=self.cfg.hidden_dim, 199 | output_split=output_split, 200 | ) 201 | 202 | self.pose_decoder = nn.Sequential( 203 | nn.LayerNorm(cfg.pose_embed_dim), 204 | nn.Linear(cfg.pose_embed_dim, cfg.hidden_dim), 205 | layers.ResidualMLP(cfg.hidden_dim), 206 | nn.LayerNorm(cfg.hidden_dim), 207 | nn.GELU(), 208 | nn.Linear(cfg.hidden_dim, smpl_kinematics.POSE_DOF - 3), 209 | ) 210 | 211 | self.fuse_features = layers.FusePyramid( 212 | self.feat_dim, 213 | 2 * cfg.hidden_dim, 214 | ) 215 | 216 | if pretrained: 217 | checkpoint = torch.load(PYTORCH_CHECKPOINT_PATH, weights_only=True) 218 | self.load_state_dict(checkpoint) 219 | 220 | def _intrinsics_conditioning(self, feats, K, pooling=8, normalize_factor=1024): 221 | """Condition image features on pixel and world coordinate mapping. 222 | 223 | From input intrinsics matrix we define a coordinate grid and use rotary 224 | embeddings to update the extracted image features. 
225 | """ 226 | batch_size = feats.shape[0] 227 | feats = feats.clone() 228 | device = feats.device 229 | 230 | with torch.no_grad(): 231 | # Get reference pixel positions 232 | ht, wd = feats.shape[-2:] 233 | ref_xy = helper.get_grid(ht, wd, device=device) + 0.5 / max(ht, wd) 234 | ref_xy *= max(ht, wd) * pooling 235 | 236 | # Reduce scale of pixel values 237 | K = K / normalize_factor 238 | ref_xy = ref_xy / normalize_factor 239 | 240 | # Adjust into world coordinate frame 241 | ref_xy_world = helper.px_to_world(K[:, None, None], ref_xy[None]) 242 | 243 | # Get rotary embeddings 244 | xy_cs, xy_sn = self.rot_embedding(ref_xy) 245 | xy_world_cs, xy_world_sn = self.rot_embedding(ref_xy_world) 246 | 247 | # Rearrange from BHWC -> BCHW 248 | xy_cs = eo.repeat(xy_cs, f"h w d0 d1 -> {batch_size} (d0 d1) h w") 249 | xy_sn = eo.repeat(xy_sn, f"h w d0 d1 -> {batch_size} (d0 d1) h w") 250 | xy_world_cs = eo.rearrange(xy_world_cs, "b h w d0 d1 -> b (d0 d1) h w") 251 | xy_world_sn = eo.rearrange(xy_world_sn, "b h w d0 d1 -> b (d0 d1) h w") 252 | 253 | # Apply rotary embeddings 254 | cs = torch.cat([xy_cs, xy_world_cs, xy_cs, xy_world_cs], 1) 255 | sn = torch.cat([xy_sn, xy_world_sn, xy_sn, xy_world_sn], 1) 256 | embed_dim = sn.shape[1] 257 | f0 = feats[:, :embed_dim] 258 | f0 = layers.apply_rotary_pos_emb(f0, cs, sn) 259 | feats = torch.cat([f0, feats[:, embed_dim:]], 1) 260 | 261 | return feats 262 | 263 | @torch.inference_mode 264 | def forward(self, img, K) -> DetectionOutput: 265 | """Extract backbone features and detect poses. 266 | 267 | Args: 268 | ---- 269 | img: input image tensor of shape (B, 3, 512, 512) 270 | K: input intrinsic tensor of shape (B, 2, 3) 271 | 272 | Return: 273 | ------ 274 | DetectionOutput: NamedTuple that includes all output detection parameters. 
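
        The raw candidates are typically post-processed with
        decode_network_outputs(K, smpl_decoder, detect_out) below (a minimal
        sketch; smpl_decoder is an SMPLKinematics instance), which decodes
        SMPL poses, projects keypoints, and applies non-maximum suppression.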
275 | 276 | """ 277 | outputs = {} 278 | 279 | # Get backbone features 280 | feats = self.image_backbone(img) 281 | feats = self._intrinsics_conditioning(feats, K) 282 | 283 | # Get detections 284 | detections, feature_pyramid = self.detection_head(feats, K, return_feats=True) 285 | for k, v in detections.items(): 286 | if k == "pose_embedding": 287 | # Remap latent pose embedding to joint angles 288 | # Note: these are residual terms applied to a default pose 289 | outputs["delta_body_pose"] = self.pose_decoder(v) * 0.3 290 | else: 291 | outputs[k] = v 292 | 293 | # Fuse feature pyramid 294 | feature_pyramid = [feats] + feature_pyramid 295 | outputs["image_features"] = self.fuse_features(*feature_pyramid) 296 | 297 | return DetectionOutput(**outputs) 298 | 299 | 300 | class CoMotionDetectCoreML: 301 | """A CoreML wrapper for CoMotion detection module.""" 302 | 303 | def __init__(self): 304 | """Initialize the CoreML model.""" 305 | import coremltools as ct 306 | 307 | self.model = ct.models.MLModel(COREML_CHECKPOINT_PATH) 308 | self.cfg = CoMotionDetectConfig() 309 | self.feat_dim = 256 310 | 311 | def __call__(self, img, K) -> DetectionOutput: 312 | """Run inference for the CoreML model.""" 313 | outputs = self.model.predict({"image": img, "K": K}) 314 | outputs = {k: torch.tensor(v) for k, v in outputs.items()} 315 | return DetectionOutput(**outputs) 316 | 317 | 318 | def get_smpl_pose(delta_root_orient, delta_body_pose): 319 | """Apply predicted delta poses to default mean pose.""" 320 | device = delta_root_orient.device 321 | default_pose = smpl_kinematics.extra_ref["mean_pose"].clone().to(device) 322 | delta_pose = torch.cat([delta_root_orient, delta_body_pose], -1) 323 | return smpl_kinematics.update_pose(default_pose, delta_pose) 324 | 325 | 326 | def decode_network_outputs( 327 | K: torch.Tensor, 328 | smpl_decoder: smpl_kinematics.SMPLKinematics, 329 | detect_out: DetectionOutput, 330 | sample_idx: int = 0, 331 | **nms_kwargs, 332 | ): 333 | """Postprocessing to get detections from network output.""" 334 | # Decode output SMPL pose and joint coordinates 335 | pose = get_smpl_pose(detect_out.delta_root_orient, detect_out.delta_body_pose) 336 | 337 | pred_3d = smpl_decoder.forward( 338 | detect_out.betas, 339 | pose, 340 | detect_out.trans, 341 | output_format="joints_face", 342 | ) 343 | pred_2d = helper.project_to_2d(K, pred_3d) 344 | 345 | # Perform non-maximum suppression 346 | nms_idxs = helper.nms_detections( 347 | pred_2d[sample_idx] / 1024, detect_out.conf[sample_idx].flatten(), **nms_kwargs 348 | ) 349 | 350 | detections = { 351 | "betas": detect_out.betas, # (1, 1344, 10) 352 | "pose": pose, # (1, 1344, 72) 353 | "trans": detect_out.trans, # (1, 1344, 3) 354 | "pred_3d": pred_3d, # (1, 1344, 27, 3) 355 | "pred_2d": pred_2d, # (1, 1344, 27, 2) 356 | "conf": detect_out.conf, # (1, 1344, 1) 357 | } 358 | 359 | # Index into selected subset of detections, add singleton batch dimension 360 | for k, v in detections.items(): 361 | detections[k] = v[sample_idx, nms_idxs][None] 362 | 363 | return detections 364 | -------------------------------------------------------------------------------- /src/comotion_demo/models/layers.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
2 | import einops as eo 3 | import torch 4 | import transformers 5 | from timm.layers import LayerNorm2d 6 | from timm.models.convnext import ConvNeXtBlock 7 | from torch import nn 8 | from torch.functional import F 9 | from transformers.models.llama import modeling_llama 10 | 11 | from ..utils import helper 12 | 13 | 14 | class PosEmbed(nn.Module): 15 | def __init__(self, embed_dim=16): 16 | super().__init__() 17 | self.register_buffer( 18 | "freq_bands", 2 ** torch.linspace(0, embed_dim - 1, embed_dim) 19 | ) 20 | self.out_dim = 2 * (1 + len(self.freq_bands)) 21 | 22 | def forward(self, x): 23 | # Input: ... x 2 24 | emb = torch.sin(x.unsqueeze(-1) * self.freq_bands) 25 | emb = eo.rearrange(emb, "... d0 d1 -> ... (d0 d1)") 26 | 27 | return torch.cat([x, emb], -1) 28 | 29 | 30 | class RotaryEmbed(nn.Module): 31 | def __init__(self, dim=32, rescale=1000, base=10000): 32 | super().__init__() 33 | self.dim = dim 34 | self.rescale = rescale 35 | self.base = base 36 | 37 | inv_freq = rescale / (base ** (torch.arange(0, dim).float() / dim)) 38 | self.register_buffer("inv_freq", inv_freq) 39 | self.out_dim = dim 40 | 41 | def forward(self, x): 42 | """Apply frequency band and return cos and sin terms.""" 43 | emb = x.unsqueeze(-1) * self.inv_freq 44 | return emb.cos(), emb.sin() 45 | 46 | 47 | def apply_rotary_pos_emb(x, cs, sn): 48 | half_channels = x.shape[1] // 2 49 | x0 = x[:, :half_channels] 50 | x1 = x[:, half_channels:] 51 | x_ = torch.cat((-x1, x0), dim=1) 52 | return cs * x + sn * x_ 53 | 54 | 55 | def _gru_update(h, z, q): 56 | return (1 - z) * h + z * q 57 | 58 | 59 | class GRU(nn.Module): 60 | def __init__(self, hdim=128, f_in=128): 61 | super().__init__() 62 | self.fz = nn.Linear(hdim + f_in, hdim) 63 | self.fr = nn.Linear(hdim + f_in, hdim) 64 | self.fq = nn.Linear(hdim + f_in, hdim) 65 | 66 | def forward(self, h, x): 67 | hx = torch.cat([h, x], dim=-1) 68 | 69 | z = torch.sigmoid(self.fz(hx)) 70 | r = torch.sigmoid(self.fr(hx)) 71 | q = torch.tanh(self.fq(torch.cat([r * h, x], dim=-1))) 72 | 73 | h = _gru_update(h, z, q) 74 | return h 75 | 76 | 77 | class ResidualMLP(nn.Module): 78 | """Residual Multilayer Perceptron.""" 79 | 80 | def __init__(self, dim, out_dim=None, hidden_dim=None, pre_ln=True): 81 | super().__init__() 82 | 83 | if out_dim is None: 84 | out_dim = dim 85 | if hidden_dim is None: 86 | hidden_dim = 2 * dim 87 | self.is_residual = dim == out_dim 88 | self.pre_ln = pre_ln 89 | 90 | if pre_ln: 91 | self.ln0 = nn.LayerNorm(dim) 92 | self.l1 = nn.Linear(dim, hidden_dim) 93 | self.ln1 = nn.LayerNorm(hidden_dim) 94 | self.act = nn.GELU() 95 | self.l2 = nn.Linear(hidden_dim, out_dim) 96 | 97 | def forward(self, x): 98 | inp = x 99 | if self.pre_ln: 100 | x = self.ln0(x) 101 | x = self.l1(x) 102 | x = self.act(self.ln1(x)) 103 | x = self.l2(x) 104 | 105 | if self.is_residual: 106 | return inp + x 107 | else: 108 | return x 109 | 110 | 111 | class DownsampleConvNextBlock(nn.Module): 112 | """Module to downsample 2x and apply ConvNext layers. 113 | 114 | Note we assume the input has already had normalization applied, 115 | and apply LayerNorm as the last operation. 
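
    With the default arguments the block halves the spatial resolution via a
    stride-2 convolution and keeps the channel count unchanged (output_dim
    defaults to input_dim).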
116 | """ 117 | 118 | def __init__(self, input_dim, output_dim=None, dropout=0.0): 119 | super().__init__() 120 | 121 | if output_dim is None: 122 | output_dim = input_dim 123 | 124 | self.layers = nn.Sequential( 125 | nn.Conv2d(input_dim, output_dim, 2, 2, 0), # Downsample 126 | ConvNeXtBlock(output_dim), 127 | LayerNorm2d(output_dim), 128 | nn.Dropout(p=dropout), 129 | ) 130 | 131 | def forward(self, x): 132 | return self.layers(x) 133 | 134 | 135 | @helper.fixed_dim_op(d=3) 136 | def _call_llama(x, tf): 137 | kargs = { 138 | "position_embeddings": ( 139 | torch.ones_like(x[..., :1]), 140 | torch.zeros_like(x[..., :1]), 141 | ) 142 | } 143 | for tf_ in tf: 144 | x = tf_(x, **kargs)[0] 145 | return x 146 | 147 | 148 | class DecodeFromTokens(nn.Module): 149 | """Decode token into hidden update.""" 150 | 151 | def __init__(self, hidden_dim): 152 | super().__init__() 153 | 154 | self.ln = nn.LayerNorm(hidden_dim) 155 | self.token_weight = nn.Linear(hidden_dim, hidden_dim) 156 | self.token_to_value = nn.Linear(hidden_dim, hidden_dim) 157 | self.decoder = ResidualMLP(hidden_dim) 158 | 159 | def forward(self, x): 160 | x = self.ln(x) 161 | v = self.token_to_value(x) 162 | w = torch.sigmoid(self.token_weight(x)) 163 | v = v * w 164 | x = v.mean(-2) 165 | x = self.decoder(x) 166 | 167 | return x 168 | 169 | 170 | class CrossAttention(nn.Module): 171 | """Cross attention module.""" 172 | 173 | def __init__( 174 | self, 175 | num_tokens, 176 | num_heads, 177 | hidden_dim, 178 | token_dim, 179 | num_layers=2, 180 | drop_rate=0.1, 181 | use_global_attention=2, 182 | ): 183 | super().__init__() 184 | self.num_tokens = num_tokens 185 | self.num_heads = num_heads 186 | self.hidden_dim = hidden_dim 187 | self.token_dim = token_dim 188 | self.num_layers = num_layers 189 | self.drop_rate = drop_rate 190 | self.use_global_attention = use_global_attention 191 | 192 | self.token_to_query = nn.Sequential( 193 | ResidualMLP(self.hidden_dim), 194 | nn.LayerNorm(self.hidden_dim), 195 | nn.Linear(self.hidden_dim, self.token_dim), 196 | ) 197 | 198 | self.post_attention = nn.Sequential( 199 | ResidualMLP(self.token_dim), 200 | nn.LayerNorm(self.token_dim), 201 | nn.Linear(self.token_dim, self.hidden_dim), 202 | ) 203 | 204 | cfg = transformers.LlamaConfig() 205 | cfg.hidden_size = self.hidden_dim 206 | cfg.intermediate_size = self.hidden_dim * 2 207 | cfg.num_attention_heads = 16 208 | cfg.num_key_value_heads = 16 209 | cfg.attention_dropout = self.drop_rate 210 | 211 | if self.use_global_attention > 0: 212 | self.global_attention = nn.ModuleList( 213 | [ 214 | modeling_llama.LlamaDecoderLayer(cfg, i) 215 | for i in range(self.num_layers) 216 | ] 217 | ) 218 | 219 | self.indiv_attention = nn.ModuleList( 220 | [modeling_llama.LlamaDecoderLayer(cfg, i) for i in range(self.num_layers)] 221 | ) 222 | 223 | def forward(self, image_key, image_value, tokens): 224 | q = self.token_to_query(tokens) 225 | q = eo.rearrange( 226 | q, "b n d0 (h d1) -> b h (n d0) d1", h=self.num_heads 227 | ).contiguous() 228 | px_feedback = F.scaled_dot_product_attention(q, image_key, image_value) 229 | px_feedback = eo.rearrange( 230 | px_feedback, "b h (n d0) d1 -> b n d0 (h d1)", d0=self.num_tokens 231 | ) 232 | tokens = tokens + self.post_attention(px_feedback) 233 | 234 | # Attention across all people 235 | if self.use_global_attention == 1: 236 | # Concatenate all tokens together 237 | tokens = eo.rearrange(tokens, "b n d0 d1 -> b (n d0) d1") 238 | tokens = _call_llama(tokens, self.global_attention) 239 | tokens = eo.rearrange( 240 | 
tokens, "b (n d0) d1 -> b n d0 d1", d0=self.num_tokens 241 | ) 242 | elif self.use_global_attention == 2: 243 | # Do attention over people separately per-token 244 | tokens = eo.rearrange(tokens, "b n d0 d1 -> b d0 n d1") 245 | tokens = _call_llama(tokens, self.global_attention) 246 | tokens = eo.rearrange(tokens, "b d0 n d1 -> b n d0 d1") 247 | 248 | # Separate attention update per-person 249 | tokens = _call_llama(tokens, self.indiv_attention) 250 | 251 | return tokens 252 | 253 | 254 | class FusePyramid(nn.Module): 255 | """Module for fusing the feature pyramid.""" 256 | 257 | def __init__(self, in_dim=256, hidden_dim=512): 258 | """Initialize FusePyramid module.""" 259 | super().__init__() 260 | 261 | self.dc0 = nn.ConvTranspose2d(hidden_dim, hidden_dim, 2, 2) 262 | self.proj1 = nn.Conv2d(hidden_dim, hidden_dim, 1) 263 | self.dc1 = nn.ConvTranspose2d(hidden_dim, hidden_dim, 2, 2) 264 | self.ln1 = LayerNorm2d(hidden_dim) 265 | self.proj2 = nn.Conv2d(hidden_dim, hidden_dim, 1) 266 | self.dc2 = nn.ConvTranspose2d(hidden_dim, hidden_dim, 2, 2) 267 | self.ln2 = LayerNorm2d(hidden_dim) 268 | self.dc3 = nn.ConvTranspose2d(hidden_dim, in_dim, 4, 4) 269 | self.proj3 = nn.Conv2d(in_dim, in_dim, 1) 270 | self.ln3 = LayerNorm2d(in_dim) 271 | 272 | @helper.fixed_dim_op(nargs=4, is_class_fn=True) 273 | def forward(self, f64, f16, f8, f4): 274 | """Aggregate features across feature pyramid. 275 | 276 | Args: 277 | ---- 278 | f64: features of shape (B, 3, 64, 64), assuming input res is 512x512. 279 | f16: features of shape (B, 3, 16, 16) 280 | f8: features of shape (B, 3, 8, 8) 281 | f4: features of shape (B, 3, 4, 4) 282 | 283 | Return: 284 | ------ 285 | output features of shape (B, in_dim, 64, 64) with the same res as f64. 286 | 287 | """ 288 | x = self.dc0(f4) + self.proj1(f8) 289 | x = self.ln1(F.gelu(x)) 290 | x = self.dc1(x) + self.proj2(f16) 291 | x = self.ln2(F.gelu(x)) 292 | x = self.dc3(x) + self.proj3(f64) 293 | x = self.ln3(x) 294 | 295 | return x 296 | -------------------------------------------------------------------------------- /src/comotion_demo/models/refine.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | import os 3 | from collections import namedtuple 4 | from dataclasses import dataclass 5 | 6 | import einops as eo 7 | import torch 8 | from torch import nn 9 | 10 | from ..utils import helper, smpl_kinematics 11 | from . import layers 12 | 13 | curr_dir = os.path.abspath(os.path.dirname(__file__)) 14 | PYTORCH_CHECKPOINT_PATH = f"{curr_dir}/../data/comotion_refine_checkpoint.pt" 15 | 16 | RefineOutput = namedtuple( 17 | "RefineOutput", 18 | [ 19 | "delta_root_orient", 20 | "delta_body_pose", 21 | "trans", 22 | "hidden", 23 | ], 24 | ) 25 | 26 | 27 | @dataclass 28 | class PoseRefinementConfig: 29 | num_tokens: int = 24 30 | num_heads: int = 4 31 | token_dim: int = 256 32 | hidden_dim: int = 512 33 | pose_embed_dim: int = 256 34 | normalizing_factor: int = 1024 35 | 36 | 37 | class PoseRefinement(nn.Module): 38 | """CoMotion refinement module.""" 39 | 40 | def __init__( 41 | self, 42 | image_feat_dim: int, 43 | pose_embed_dim: int, 44 | cfg: PoseRefinementConfig | None = None, 45 | pretrained: bool = True, 46 | ): 47 | """Initialize refinement module. 48 | 49 | Args: 50 | ---- 51 | image_feat_dim: image backbone output feature dimension. 52 | pose_embed_dim: pose embedding dimension. 53 | cfg: pose refinement config. 54 | pretrained: whether to load pre-trained weights. 
55 | 56 | """ 57 | super().__init__() 58 | cfg = PoseRefinementConfig() if cfg is None else cfg 59 | self.cfg = cfg 60 | self.image_feat_dim = image_feat_dim 61 | self.pose_embed_dim = pose_embed_dim 62 | self.pos_embedding = layers.PosEmbed() 63 | self.encode_grid = nn.Linear(self.pos_embedding.out_dim, cfg.token_dim) 64 | token_kv_in, token_kv_out = image_feat_dim + cfg.token_dim, cfg.token_dim 65 | 66 | self.get_px_key = nn.Linear(token_kv_in, token_kv_out) 67 | self.get_px_value = nn.Linear(token_kv_in, token_kv_out) 68 | 69 | # Output reference 70 | self.split_ref = [ 71 | smpl_kinematics.BETA_DOF, 72 | pose_embed_dim, 73 | 3, 74 | smpl_kinematics.TRANS_DOF, 75 | ] 76 | 77 | # Hidden update layers 78 | self.tokens_to_hidden = layers.DecodeFromTokens(cfg.hidden_dim) 79 | self.feedback_gru_update = layers.GRU(cfg.hidden_dim, cfg.hidden_dim) 80 | 81 | def init_token_encoder(in_dim): 82 | return layers.ResidualMLP( 83 | in_dim, 84 | out_dim=cfg.num_tokens * cfg.hidden_dim, 85 | hidden_dim=2 * cfg.hidden_dim, 86 | pre_ln=False, 87 | ) 88 | 89 | # Convert coordinates and hidden state to tokens 90 | cd_dim = smpl_kinematics.NUM_PARTS_AUX * (self.pos_embedding.out_dim + 1) 91 | dof = ( 92 | smpl_kinematics.BETA_DOF 93 | + smpl_kinematics.POSE_DOF 94 | + smpl_kinematics.TRANS_DOF 95 | ) 96 | 97 | self.encode_coords = init_token_encoder(cd_dim) 98 | self.encode_hidden = init_token_encoder(cfg.hidden_dim) 99 | self.encode_smpl = init_token_encoder(dof) 100 | 101 | cross_attention_kargs = { 102 | "num_tokens": cfg.num_tokens, 103 | "num_heads": cfg.num_heads, 104 | "hidden_dim": cfg.hidden_dim, 105 | "token_dim": cfg.token_dim, 106 | } 107 | 108 | self.cross_attention = nn.Sequential( 109 | layers.CrossAttention(**cross_attention_kargs), 110 | layers.CrossAttention(**cross_attention_kargs), 111 | ) 112 | 113 | decode = layers.DecodeFromTokens 114 | 115 | self.get_px_feedback = decode(cfg.hidden_dim) 116 | self.tokens_to_hidden = nn.Sequential( 117 | decode(cfg.hidden_dim), 118 | nn.LayerNorm(cfg.hidden_dim), 119 | ) 120 | self.project_hidden = nn.Linear(cfg.hidden_dim, cfg.hidden_dim) 121 | 122 | self.get_pose_update = nn.Sequential( 123 | layers.ResidualMLP(cfg.hidden_dim), 124 | nn.LayerNorm(cfg.hidden_dim), 125 | nn.GELU(), 126 | nn.Linear(cfg.hidden_dim, sum(self.split_ref)), 127 | ) 128 | 129 | # This is the same module used in the detection step (with identical weights) 130 | # We reinstantiate it to support using the separate CoreML detection stage 131 | self.pose_decoder = nn.Sequential( 132 | nn.LayerNorm(cfg.pose_embed_dim), 133 | nn.Linear(cfg.pose_embed_dim, cfg.hidden_dim), 134 | layers.ResidualMLP(cfg.hidden_dim), 135 | nn.LayerNorm(cfg.hidden_dim), 136 | nn.GELU(), 137 | nn.Linear(cfg.hidden_dim, smpl_kinematics.POSE_DOF - 3), 138 | ) 139 | 140 | if pretrained: 141 | checkpoint = torch.load(PYTORCH_CHECKPOINT_PATH, weights_only=True) 142 | self.load_state_dict(checkpoint) 143 | 144 | @helper.fixed_dim_op(d=4, is_class_fn=True) 145 | def compute_image_kv(self, image_feats, pooling=8): 146 | """From image features, get flattened set of key, value token pairs.""" 147 | batch_size = image_feats.shape[0] 148 | res = image_feats.shape[-2:] 149 | device = image_feats.device 150 | 151 | # Rearrange and flatten image features 152 | image_feats = eo.rearrange(image_feats, "b c h w -> b (h w) c") 153 | 154 | # Calculate reference grid of pixel positions 155 | with torch.no_grad(): 156 | px_scale_factor = max(res) * pooling 157 | grid = helper.get_grid(res[0], res[1], device) 158 | grid = grid 
* px_scale_factor / self.cfg.normalizing_factor 159 | 160 | grid_embed = self.encode_grid(self.pos_embedding(grid)) 161 | grid_embed = eo.repeat(grid_embed, "h w d -> b (h w) d", b=batch_size) 162 | image_feats = torch.cat([image_feats, grid_embed], -1) 163 | 164 | # Apply a linear layer to get keys and values 165 | pixel_k = eo.rearrange( 166 | self.get_px_key(image_feats), 167 | "... d0 (h d1) -> ... h d0 d1", 168 | h=self.cfg.num_heads, 169 | ).contiguous() 170 | pixel_v = eo.rearrange( 171 | self.get_px_value(image_feats), 172 | "... d0 (h d1) -> ... h d0 d1", 173 | h=self.cfg.num_heads, 174 | ).contiguous() 175 | 176 | return pixel_k, pixel_v 177 | 178 | def calib_adjusted_trans(self, K, trans, delta_trans, depth_adj=128, eps=0.05): 179 | """Update SMPL translation term based on provided intrinsics. 180 | 181 | This same operation is performed during detection, the output x, y 182 | terms are in pixel space and mapped to world coordinates using K. 183 | """ 184 | delta_xy, delta_z = delta_trans.split_with_sizes(dim=-1, split_sizes=[2, 1]) 185 | 186 | # Get new depth estimate 187 | default_depth = K[..., 0, 0][:, None, None] / depth_adj 188 | z = default_depth / (torch.exp(delta_z) + eps) 189 | 190 | # Apply delta to current x, y position in pixel space 191 | base_xy = helper.project_to_2d(K.unsqueeze(1), trans) 192 | xy = helper.px_to_world(K.unsqueeze(1), base_xy + delta_xy) * z 193 | return torch.cat([xy, z], -1) 194 | 195 | def encode_state( 196 | self, 197 | K, 198 | hidden, 199 | pred_3d, 200 | body_params, 201 | ): 202 | """Encode tracks into tokens. 203 | 204 | Hidden state, SMPL parameters, and 2D keypoint coordinates are all 205 | passed through an MLP to produce a set of tokens per-person. 206 | """ 207 | # Project to 2D 208 | K = K[:, None, None] / self.cfg.normalizing_factor 209 | xy = helper.project_to_2d(K, pred_3d).clamp(-5, 5) 210 | z = pred_3d[..., 2] 211 | cd_embed = eo.rearrange(self.pos_embedding(xy), "... d0 d1 -> ... (d0 d1)") 212 | cd_embed = torch.cat([cd_embed, z], -1) 213 | 214 | # Encode tokens 215 | tokens = self.encode_coords(cd_embed) 216 | tokens = tokens + self.encode_hidden(hidden) 217 | tokens = tokens + self.encode_smpl(body_params) 218 | tokens = eo.rearrange( 219 | tokens, "... (d0 d1) -> ... d0 d1", d0=self.cfg.num_tokens 220 | ) 221 | 222 | return tokens 223 | 224 | def perform_update(self, K, pixel_k, pixel_v, tokens, trans, hidden): 225 | """Attend to image features and calculate final outputs. 226 | 227 | Args: 228 | ---- 229 | K: Input intrinsics 230 | pixel_k: Flattened set of image token keys. 231 | pixel_v: Flattened set of image token values. 232 | tokens: Feature encoding of current tracks. 233 | trans: SMPL translation term for each track. 234 | hidden: Hidden state for each track. 
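
        Return:
        ------
        RefineOutput with updated delta_root_orient, delta_body_pose, trans,
        and hidden state for each track.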
235 | 236 | """ 237 | # Perform cross attention to get feedback from image features 238 | for ca in self.cross_attention: 239 | tokens = ca(pixel_k, pixel_v, tokens) 240 | 241 | # Update hidden state 242 | hidden_update = self.tokens_to_hidden(tokens) 243 | hidden = self.feedback_gru_update(hidden, hidden_update) 244 | 245 | # Update current state 246 | px_feedback = self.get_px_feedback(tokens) 247 | px_feedback = px_feedback + self.project_hidden(hidden) 248 | 249 | delta_smpl = self.get_pose_update(px_feedback) 250 | delta_smpl = delta_smpl.split_with_sizes(dim=-1, split_sizes=self.split_ref) 251 | _, pose_embedding, delta_root_orient, delta_trans = delta_smpl 252 | 253 | delta_body_pose = self.pose_decoder(pose_embedding) * 0.3 254 | trans = self.calib_adjusted_trans(K, trans, delta_trans) 255 | return RefineOutput(delta_root_orient, delta_body_pose, trans, hidden) 256 | 257 | def forward(self, image_feats, K, betas, pose, trans, pred_3d, hidden, pooling=8): 258 | """Predict new poses given image features and current tracks. 259 | 260 | Args: 261 | ---- 262 | image_feats: Per-pixel image features. 263 | K: Input intrinsics. 264 | betas: SMPL beta parameters for each track. 265 | pose: SMPL pose parameters for each track. 266 | trans: SMPL translation term for each track. 267 | pred_3d: 3D keypoints in camera coordinate frame for each track. 268 | hidden: Hidden state for each track. 269 | pooling: Indicator of how image features have been pooled from input image. 270 | 271 | """ 272 | body_params = torch.cat([betas, pose, trans], -1) 273 | pixel_k, pixel_v = self.compute_image_kv(image_feats, pooling=pooling) 274 | tokens = self.encode_state(K, hidden, pred_3d, body_params) 275 | return self.perform_update(K, pixel_k, pixel_v, tokens, trans, hidden) 276 | -------------------------------------------------------------------------------- /src/comotion_demo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | """CoMotion utility functions.""" 3 | -------------------------------------------------------------------------------- /src/comotion_demo/utils/dataloading.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 
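# A minimal usage sketch of the loaders defined below, assuming a hypothetical
# local file "video.mp4" and CPU inference; each frame is paired with a default
# intrinsic matrix, then cropped to 512x512, ImageNet-normalized, and batched:
#
#     from pathlib import Path
#
#     for image, K in yield_image_and_K(Path("video.mp4"), start_frame=0, num_frames=8):
#         cropped_image, cropped_K = prepare_network_inputs(image, K, device="cpu")
#         # cropped_image: 1 x 3 x 512 x 512, cropped_K: 1 x 2 x 3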
2 | import logging 3 | from pathlib import Path 4 | from typing import Generator, Tuple 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | from numpy.typing import NDArray 10 | from PIL import Image 11 | from torchvision import transforms 12 | 13 | IMG_MEAN = torch.tensor([0.4850, 0.4560, 0.4060]).view(-1, 1, 1) 14 | IMG_STD = torch.tensor([0.2290, 0.2240, 0.2250]).view(-1, 1, 1) 15 | INTERNAL_RESOLUTION = (512, 512) 16 | VIDEO_EXTENSIONS = {".mp4"} 17 | IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"} 18 | 19 | 20 | def normalize_image(image: torch.Tensor) -> torch.Tensor: 21 | """Apply ImageNet normalization to an image tensor.""" 22 | if not isinstance(image, torch.Tensor): 23 | raise ValueError("Expect input to be a torch.Tensor") 24 | return (image - IMG_MEAN) / IMG_STD 25 | 26 | 27 | def unnormalize_image(image: torch.Tensor) -> torch.Tensor: 28 | """Undo ImageNet normalization to an image tensor.""" 29 | if not isinstance(image, torch.Tensor): 30 | raise ValueError("Expect input to be a torch.Tensor") 31 | return image * IMG_STD + IMG_MEAN 32 | 33 | 34 | def convert_image_to_tensor(image: NDArray[np.uint8]) -> torch.Tensor: 35 | """Convert an uint8 numpy array of shape HWC to a float tensor of shape CHW.""" 36 | if not isinstance(image, np.ndarray): 37 | raise ValueError("Expect input to be a numpy array.") 38 | if image.dtype != np.uint8: 39 | raise ValueError("Expect input to be np.uint8 typed.") 40 | return torch.from_numpy(image).permute(2, 0, 1) 41 | 42 | 43 | def convert_tensor_to_image(tensor: torch.Tensor) -> NDArray[np.uint8]: 44 | """Convert a float tensor of shape CHW to an uint8 numpy array of shape HWC.""" 45 | if not isinstance(tensor, torch.Tensor): 46 | raise ValueError("Expect input to be a torch Tensor") 47 | return tensor.permute(1, 2, 0).cpu().numpy().astype(np.uint8) 48 | 49 | 50 | def crop_image_and_update_K( 51 | image: torch.Tensor, 52 | K: torch.Tensor, 53 | target_resolution: Tuple[int, int] = INTERNAL_RESOLUTION, 54 | ) -> Tuple[torch.Tensor, torch.Tensor]: 55 | """Pad and resize image to target resolution and update intrinsics.""" 56 | target_height, target_width = target_resolution 57 | target_hw_ratio = target_height / target_width 58 | source_height, source_width = torch.tensor(image.shape[-2:]) 59 | source_hw_ratio = source_height / source_width 60 | 61 | if source_hw_ratio >= target_hw_ratio: 62 | # pad needed along the width axis 63 | crop_height = source_height 64 | crop_width = int(source_height / target_hw_ratio) 65 | else: 66 | # pad needed along the height axis 67 | crop_height = int(source_width * target_hw_ratio) 68 | crop_width = source_width 69 | 70 | img_center_x = source_width / 2 71 | img_center_y = source_height / 2 72 | 73 | offset_x = int(img_center_x - crop_width / 2) 74 | offset_y = int(img_center_y - crop_height / 2) 75 | 76 | crop_args = [ 77 | offset_y, 78 | offset_x, 79 | crop_height, 80 | crop_width, 81 | target_resolution, 82 | transforms.InterpolationMode.BILINEAR, 83 | ] 84 | # Pad, crop, and resize image 85 | cropped_image = transforms.functional.resized_crop( 86 | image, *crop_args, antialias=True 87 | ) 88 | 89 | scale_y = target_height / crop_height 90 | scale_x = target_width / crop_width 91 | 92 | cropped_K = K.clone() 93 | cropped_K[..., 0, 0] *= scale_x 94 | cropped_K[..., 1, 1] *= scale_y 95 | cropped_K[..., 0, 2] = (cropped_K[..., 0, 2] - offset_x) * scale_x 96 | cropped_K[..., 1, 2] = (cropped_K[..., 1, 2] - offset_y) * scale_y 97 | 98 | return cropped_image, cropped_K 99 | 100 | 101 | def 
yield_image_from_directory( 102 | directory: Path, 103 | start_frame: int, 104 | num_frames: int, 105 | frameskip: int = 1, 106 | ) -> Generator[NDArray[np.uint8], None, None]: 107 | """Generate the next frame from a directory.""" 108 | if not directory.is_dir(): 109 | raise ValueError(f"Path is not a directory: {directory}") 110 | image_files = sorted( 111 | [ 112 | file 113 | for file in directory.glob("*") 114 | if file.is_file() and file.suffix.lower() in IMAGE_EXTENSIONS 115 | ] 116 | ) 117 | if not image_files: 118 | raise ValueError(f"No images found in directory: {directory}") 119 | 120 | image_files = image_files[start_frame::frameskip] 121 | 122 | if len(image_files) > num_frames: 123 | image_files = image_files[:num_frames] 124 | 125 | for image_file in image_files: 126 | image = np.array(Image.open(image_file).convert("RGB")) 127 | yield image 128 | 129 | 130 | def yield_image_from_video( 131 | filepath: Path, 132 | start_frame: int, 133 | num_frames: int, 134 | frameskip: int = 1, 135 | ) -> Generator[NDArray[np.uint8], None, None]: 136 | """Generate the next frame from a video.""" 137 | if filepath.suffix.lower() not in VIDEO_EXTENSIONS: 138 | raise ValueError(f"Input file is not a video: {filepath}") 139 | 140 | video = cv2.VideoCapture(filepath.as_posix()) 141 | 142 | if not video.isOpened(): 143 | raise ValueError(f"Could not open video file: {filepath}") 144 | 145 | max_frames = video.get(cv2.CAP_PROP_FRAME_COUNT) 146 | logging.info( 147 | f"Yielding {num_frames} frames from {filepath} ({max_frames} in total)." 148 | ) 149 | 150 | if max_frames <= start_frame: 151 | logging.warning(f"Cannot start on frame {start_frame}.") 152 | video.release() 153 | return 154 | 155 | success = video.set(cv2.CAP_PROP_POS_FRAMES, start_frame) 156 | if success: 157 | logging.info(f"Starting from frame {start_frame}") 158 | 159 | frame_limit = min(num_frames, max_frames - start_frame) 160 | frame_count = 0 161 | while frame_count < frame_limit: 162 | success, frame = video.read() 163 | if not success: 164 | break 165 | image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) 166 | yield image 167 | frame_count += 1 168 | 169 | # Skip intermediate frames if frameskip > 1 170 | for _ in range(1, frameskip): 171 | success, frame = video.read() 172 | if not success: 173 | break 174 | 175 | video.release() 176 | 177 | 178 | def is_a_video(input_path: Path): 179 | return input_path.suffix in VIDEO_EXTENSIONS 180 | 181 | 182 | def get_input_video_fps(input_path: Path) -> float: 183 | cap = cv2.VideoCapture(input_path.as_posix()) 184 | if not cap.isOpened(): 185 | raise RuntimeError(f"Failed to load video at {input_path}.") 186 | 187 | fps = cap.get(cv2.CAP_PROP_FPS) 188 | if fps == 0: 189 | raise RuntimeError("Failed to retrieve FPS.") 190 | 191 | return fps 192 | 193 | 194 | def yield_image( 195 | input_path: Path, 196 | start_frame: int, 197 | num_frames: int, 198 | frameskip: int = 1, 199 | ) -> Generator[NDArray[np.uint8], None, None]: 200 | """Generate the next frame from either a video or a directory.""" 201 | if is_a_video(input_path): 202 | yield from yield_image_from_video( 203 | input_path, start_frame, num_frames, frameskip 204 | ) 205 | elif input_path.is_dir(): 206 | yield from yield_image_from_directory( 207 | input_path, start_frame, num_frames, frameskip 208 | ) 209 | else: 210 | raise ValueError("Input path must point to a video file or a directory") 211 | 212 | 213 | def get_default_K(image: torch.Tensor) -> torch.Tensor: 214 | """Get a default approximate intrinsic matrix.""" 215 | res = 
image.shape[-2:] 216 | max_res = max(res) 217 | K = torch.tensor([[2 * max_res, 0, 0.5 * res[1]], [0, 2 * max_res, 0.5 * res[0]]]) 218 | return K 219 | 220 | 221 | def yield_image_and_K( 222 | input_path: Path, 223 | start_frame: int, 224 | num_frames: int, 225 | frameskip: int = 1, 226 | ) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]: 227 | """Generate image and the intrinsic matrix.""" 228 | for image in yield_image(input_path, start_frame, num_frames, frameskip): 229 | image = convert_image_to_tensor(image) 230 | K = get_default_K(image) 231 | yield (image, K) 232 | 233 | 234 | def prepare_network_inputs( 235 | image: torch.Tensor, 236 | K: torch.Tensor, 237 | device: torch.device | str = "cpu", 238 | ) -> Tuple[torch.Tensor, torch.Tensor]: 239 | """Image and intrinsics prep before inference. 240 | 241 | We crop and pad the input to a 512x512 image and update the intrinsics 242 | accordingly. The input is also expected to be ImageNet normalized. 243 | 244 | This demo code only supports processing individual samples. Some operations 245 | assume the existence of a batch dimension, so we add a singleton batch 246 | dimension here. Other operations such as NMS and track management do not 247 | correctly handle batched inputs. 248 | 249 | Args: 250 | ---- 251 | image: Input image (C x H x W) float (0-1) or uint8 (0-255) tensor 252 | K: Intrinsics matrix (2 x 3) float tensor 253 | device: Target device for inference 254 | 255 | """ 256 | if image.dtype == torch.uint8: 257 | # Cast to float and normalize to 0-1 258 | image = image.float() / 255 259 | cropped_image, cropped_K = crop_image_and_update_K(image, K) 260 | cropped_image = normalize_image(cropped_image) 261 | 262 | # Add "batch" dimension and cast to target device 263 | cropped_image = cropped_image[None].to(device) 264 | cropped_K = cropped_K[None].to(device) 265 | 266 | return cropped_image, cropped_K 267 | -------------------------------------------------------------------------------- /src/comotion_demo/utils/helper.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | """Miscellaneous helper functions.""" 3 | 4 | import functools 5 | 6 | import einops as eo 7 | import numpy as np 8 | import torch 9 | 10 | 11 | def px_to_world(K, pts): 12 | """Convert values from pixel coordinates to world coordinates.""" 13 | x, y = pts.unbind(-1) 14 | return torch.stack( 15 | [(x - K[..., 0, 2]) / K[..., 0, 0], (y - K[..., 1, 2]) / K[..., 1, 1]], -1 16 | ) 17 | 18 | 19 | def project_to_2d(K, pts3d, min_depth=0.001): 20 | """Project 3D points to pixel coordinates. 21 | 22 | Args: 23 | ---- 24 | K: ... x 2 x 3 25 | pts3d: ... x 3 26 | min_depth: minimum depth for clamping (default: 0.001). 27 | 28 | """ 29 | # Normalize by depth 30 | z = pts3d[..., -1:].clamp_min(min_depth) 31 | pts3d_norml = pts3d / z 32 | pts3d_norml[..., -1].fill_(1.0) 33 | 34 | # Apply intrinsics 35 | pts_2d = eo.einsum(K, pts3d_norml, "... i j, ... j -> ... 
i") 36 | 37 | return pts_2d 38 | 39 | 40 | def _merge_aux(arr, dim=0, n_merged=2, unmerge=False, vals=None): 41 | n_extra = -(dim + 1) if dim < 0 else dim 42 | extra_d = " ".join(map(lambda x: f"a{x}", range(n_extra))) 43 | di = " ".join(map(lambda x: f"d{x}", range(n_merged))) 44 | df = f"({di})" 45 | kargs = {} 46 | 47 | if unmerge: 48 | di, df = df, di 49 | if isinstance(vals, int): 50 | vals = [vals] 51 | for v_idx, v in enumerate(vals): 52 | kargs[f"d{v_idx}"] = v 53 | 54 | if dim < 0: 55 | return eo.rearrange(arr, f"... {di} {extra_d} -> ... {df} {extra_d}", **kargs) 56 | else: 57 | return eo.rearrange(arr, f"{extra_d} {di} ... -> {extra_d} {df} ...", **kargs) 58 | 59 | 60 | def _merge_d(arr, dim=0, n_merged=2): 61 | return _merge_aux(arr, dim, n_merged) 62 | 63 | 64 | def _unmerge_d(arr, vals, dim=0, n_merged=2): 65 | return _merge_aux(arr, dim, n_merged, True, vals) 66 | 67 | 68 | def fixed_dim_op(fn_=None, d=4, nargs=1, is_class_fn=False): 69 | def _decorator(fn): 70 | @functools.wraps(fn) 71 | def run_fixed_dim(*args, **kargs): 72 | args = [a for a in args] 73 | if is_class_fn: 74 | self = args[0] 75 | args = args[1:] 76 | 77 | tmp_nargs = min(len(args), nargs) 78 | ndim = args[0].ndim 79 | if ndim > d: 80 | # If input is greater than required dims 81 | # flatten first dimensions together 82 | n = ndim - d + 1 83 | d_ref = args[0].shape[:n] 84 | for i in range(tmp_nargs): 85 | args[i] = _merge_d(args[i], 0, n) 86 | 87 | def adjust_fn(x): 88 | return _unmerge_d(x, d_ref, 0, n) if x is not None else x 89 | 90 | elif ndim < d: 91 | # If input is less than the required dims 92 | # prepend extra singleton dimensions 93 | n = d - ndim 94 | one_str = " ".join(["1"] * n) 95 | for i in range(tmp_nargs): 96 | args[i] = eo.rearrange(args[i], f"... -> {one_str} ...") 97 | 98 | def adjust_fn(x): 99 | return ( 100 | eo.rearrange(x, f"{one_str} ... -> ...") if x is not None else x 101 | ) 102 | 103 | else: 104 | # Otherwise, do nothing 105 | def adjust_fn(x): 106 | return x 107 | 108 | if is_class_fn: 109 | args = [self, *args] 110 | fn_out = fn(*args, **kargs) 111 | 112 | if isinstance(fn_out, tuple) or isinstance(fn_out, list): 113 | return [adjust_fn(v) for v in fn_out] 114 | elif isinstance(fn_out, dict): 115 | return {k: adjust_fn(v) for k, v in fn_out.items()} 116 | else: 117 | return adjust_fn(fn_out) 118 | 119 | return run_fixed_dim 120 | 121 | if fn_ is not None: 122 | return _decorator(fn_) 123 | else: 124 | return _decorator 125 | 126 | 127 | def get_grid(ht, wd, device="cpu"): 128 | """Get coordinate grid for a set of features. 
129 | 130 | Assume pixels are normalized to 1 (aspect-ratio adjusted so 1 is max(ht,wd)) 131 | """ 132 | grid = torch.meshgrid( 133 | torch.arange(wd, device=device), torch.arange(ht, device=device), indexing="xy" 134 | ) 135 | grid = torch.stack(grid, -1) / float(max(ht, wd)) 136 | return grid 137 | 138 | 139 | _link_ref = torch.tensor( 140 | [ 141 | (23, 24), # eyes 142 | (22, 16), # nose -> shoulder 143 | (22, 17), # nose -> shoulder 144 | (16, 17), # shoulders 145 | (16, 1), # left shoulder -> hip 146 | (17, 2), # right shoulder -> hip 147 | (16, 18), # left shoulder -> elbow 148 | (17, 19), # right shoulder -> elbow 149 | (20, 18), # left wrist -> elbow 150 | (21, 19), # right wrist -> elbow 151 | (4, 1), # left knee -> hip 152 | (5, 2), # right knee -> hip 153 | (4, 7), # left knee -> ankle 154 | (5, 8), # right knee -> ankle 155 | ] 156 | ).T 157 | 158 | 159 | # Calculated based on default SMPL with above links 160 | _link_dists_ref = [ 161 | 0.07, 162 | 0.3, 163 | 0.3, 164 | 0.35, 165 | 0.55, 166 | 0.55, 167 | 0.26, 168 | 0.26, 169 | 0.25, 170 | 0.25, 171 | 0.38, 172 | 0.38, 173 | 0.4, 174 | 0.4, 175 | ] 176 | 177 | 178 | def _get_skeleton_scale(kps: torch.Tensor, valid=None): 179 | """Return approximate "size" of person in image. 180 | 181 | Instead of bounding box dimensions, we compare pairs of keypoints as defined 182 | in the links above. We ignore any distances that involve an "invalid" keypoint. 183 | We assume any values set to (0, 0) are invalid and should be ignored. 184 | """ 185 | invalid = (kps == 0).all(-1) 186 | if valid is not None: 187 | invalid |= ~(valid > 0) 188 | 189 | # Calculate link distances 190 | dists = (kps[..., _link_ref[0], :] - kps[..., _link_ref[1], :]).norm(dim=-1) 191 | 192 | # Compare to reference distances 193 | ratio = dists / torch.tensor(_link_dists_ref, device=dists.device) 194 | 195 | # Zero out any invalid links 196 | invalid = invalid[..., _link_ref[0]] | invalid[..., _link_ref[1]] 197 | ratio *= (~invalid).float() 198 | 199 | # Return max ratio which corresponds to limb with least foreshortening 200 | max_ratio = ratio.max(-1)[0] 201 | return max_ratio.clamp_min(0.001) 202 | 203 | 204 | def normalized_weighted_score( 205 | kp0, 206 | c0, 207 | kp1, 208 | c1, 209 | std=0.08, 210 | return_dists=False, 211 | ref_scale=None, 212 | min_scale=0.02, 213 | max_scale=1, 214 | fixed_scale=None, 215 | ): 216 | """Measure similarity of two sets of body pose keypoints. 217 | 218 | This is a modified version of the COCO object keypoint similarity (OKS) 219 | calculation using a Cauchy distribution instead of a Gaussian. We also 220 | compute a normalizing scale on the fly based on projected limb proportions. 221 | """ 222 | # Combine confidence terms 223 | # conf: ... num_people x num_people x num_points 224 | conf = (c0.unsqueeze(-2) * c1.unsqueeze(-3)) ** 0.5 225 | 226 | # Calculate scale adjustment 227 | scale0 = _get_skeleton_scale(kp0, c0) 228 | scale1 = _get_skeleton_scale(kp1, c1) 229 | if ref_scale is None: 230 | scale = scale0.unsqueeze(-1).maximum(scale1.unsqueeze(-2)) 231 | elif ref_scale == 0: 232 | scale = scale0.unsqueeze(-1) 233 | elif ref_scale == 1: 234 | scale = scale1.unsqueeze(-2) 235 | 236 | # Set scale bounds 237 | zero_mask = scale == 0 238 | scale.clamp_(min_scale, max_scale) 239 | scale[zero_mask] = 1e-6 240 | if fixed_scale is not None: 241 | scale[:] = fixed_scale 242 | 243 | # Scale-adjusted distance calculation 244 | # kp: ... num_people x num_pts x 2 245 | # dists: ... 
num_people x num_people x num_pts 246 | dists = (kp0.unsqueeze(-3) - kp1.unsqueeze(-4)).norm(dim=-1) 247 | dists = dists / scale.unsqueeze(-1) 248 | 249 | total_conf = conf.sum(-1) 250 | zero_filt = total_conf == 0 251 | scores = 1 / (1 + (dists / std) ** 2) 252 | scores = (scores * conf).sum(-1) 253 | scores = scores / total_conf 254 | scores[zero_filt] = 0 255 | 256 | if return_dists: 257 | return scores, dists 258 | else: 259 | return scores 260 | 261 | 262 | def nms_detections(pred_2d, conf, conf_thr=0.2, iou_thr=0.4, std=0.08): 263 | """Compare keypoint estimates and return indices of nonoverlapping estimates. 264 | 265 | Note: input predictions are expected to be normalized (e.g. ranging from 0 to 1), 266 | not in pixel coordinate space. 267 | """ 268 | conf_sigmoid = torch.sigmoid(conf) 269 | 270 | # Get indices of estimates above confidence threshold, sorted by confidence 271 | sorted_idxs = (-conf).argsort() 272 | sorted_idxs = sorted_idxs[conf_sigmoid[sorted_idxs] > conf_thr] 273 | 274 | # Compare all pairs of keypoint estimates 275 | p = pred_2d[sorted_idxs] 276 | c = torch.ones_like(p[..., 0]) 277 | ious = normalized_weighted_score(p, c, p, c, std=std).cpu() 278 | 279 | # Identify estimates that are lower confidence and too similar to another estimate 280 | triu = torch.triu_indices(len(p), len(p), offset=1) 281 | ious = ious[triu[0], triu[1]] 282 | to_remove = triu[1][ious > iou_thr] 283 | 284 | # Return the subset of estimates to keep 285 | to_keep = np.setdiff1d(np.arange(len(p)), to_remove.unique().numpy()) 286 | final_idxs = sorted_idxs.cpu()[to_keep] 287 | return final_idxs 288 | 289 | 290 | def check_inbounds(kps, res): 291 | """Return binary mask indicating which keypoints are inbounds.""" 292 | in_x = (kps[..., 0] > 0) & (kps[..., 0] < res[1]) 293 | in_y = (kps[..., 1] > 0) & (kps[..., 1] < res[0]) 294 | inbounds = in_x & in_y 295 | return inbounds 296 | 297 | 298 | def points_to_bbox2d(pts, pad_dims=None): 299 | """Get min and max range of set of keypoints to define 2d bounding-box. 300 | 301 | Output format is ((x0, y0), (x1, y1)). 302 | 303 | Args: 304 | ---- 305 | pts: ... x K x 2 306 | pad_dims: optional padding (as percentage of bounding-box size) 307 | 308 | Output: 309 | ---- 310 | bboxes: ... 
x 2 x 2 311 | 312 | """ 313 | p = torch.stack([pts.min(-2)[0], pts.max(-2)[0]], -2) 314 | 315 | if pad_dims is not None: 316 | dimensions = p[..., 1, :] - p[..., 0, :] 317 | scale = dimensions.max(dim=-1)[0] 318 | pad_x = scale * pad_dims[0] 319 | pad_y = scale * pad_dims[1] 320 | 321 | p[..., 0, 0] -= pad_x 322 | p[..., 0, 1] -= pad_y 323 | p[..., 1, 0] += pad_x 324 | p[..., 1, 1] += pad_y 325 | 326 | return p 327 | 328 | 329 | def hsv2rgb(hsv): 330 | """Convert a tuple from HSV to RGB.""" 331 | h, s, v = hsv 332 | vals = np.tile(v, [3] + [1] * v.ndim) 333 | vals[1:] *= 1 - s[None] 334 | 335 | h[h > (5 / 6)] -= 1 336 | diffs = np.tile(h, [3] + [1] * h.ndim) - (np.arange(3) / 3).reshape( 337 | 3, *[1] * h.ndim 338 | ) 339 | max_idx = np.abs(diffs).argmin(0) 340 | 341 | final_rgb = np.zeros_like(vals) 342 | 343 | for i in range(3): 344 | tmp_d = diffs[i] * (max_idx == i) 345 | dv = tmp_d * 6 * s * v 346 | vals[1] += np.maximum(0, dv) 347 | vals[2] += np.maximum(0, -dv) 348 | 349 | final_rgb += np.roll(vals, i, axis=0) * (max_idx == i) 350 | 351 | return final_rgb.transpose(*list(np.arange(h.ndim) + 1), 0) 352 | 353 | 354 | def init_color_ref(n, seed=12345): 355 | """Sample N colors in HSV and convert them to RGB.""" 356 | rand_state = np.random.RandomState(seed) 357 | rand_hsv = rand_state.rand(n, 3) 358 | rand_hsv[:, 1:] = 1 - rand_hsv[:, 1:] * 0.3 359 | color_ref = hsv2rgb(rand_hsv.T) 360 | 361 | return color_ref 362 | 363 | 364 | color_ref = init_color_ref(2000) 365 | -------------------------------------------------------------------------------- /src/comotion_demo/utils/smpl_kinematics.py: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2025 Apple Inc. All Rights Reserved. 2 | 3 | import os 4 | import pickle 5 | 6 | import chumpy 7 | import einops as eo 8 | import numpy as np 9 | import pypose 10 | import torch 11 | from torch import nn 12 | 13 | BETA_DOF = 10 14 | POSE_DOF = 72 15 | TRANS_DOF = 3 16 | FACE_VERTEX_IDXS = [332, 6260, 2800, 4071, 583] 17 | NUM_PARTS_AUX = 27 18 | 19 | smpl_dir = os.path.join(os.path.dirname(__file__), "../data/smpl") 20 | extra_ref = torch.load(f"{smpl_dir}/extra_smpl_reference.pt", weights_only=True) 21 | smpl_model_path = f"{smpl_dir}/SMPL_NEUTRAL.pkl" 22 | assert os.path.exists(smpl_model_path), ( 23 | "Please download the neutral SMPL body model from https://smpl.is.tue.mpg.de/ and" 24 | "rename it to SMPL_NEUTRAL.pkl, copying it into src/comotion_demo/data/smpl/" 25 | ) 26 | 27 | 28 | def to_rotmat(theta): 29 | """Convert axis-angle to rotation matrix.""" 30 | rotmat = pypose.so3(theta.view(-1, 3)).matrix() 31 | return rotmat.reshape(*theta.shape, 3) 32 | 33 | 34 | @torch.compiler.disable 35 | def update_pose(pose, delta_pose): 36 | """Apply residual to an existing set of joint angles. 37 | 38 | Instead of adding the update to the current pose directly, we apply the 39 | update as a rotation. 40 | """ 41 | # Convert flattened pose to K x 3 set of axis-angle terms 42 | pose = eo.rearrange(pose, "... (k c) -> ... k c", c=3).contiguous() 43 | delta_pose = eo.rearrange(delta_pose, "... (k c) -> ... k c", c=3).contiguous() 44 | 45 | # Change to a rotation matrix and multiply the matrices together 46 | pose: pypose.SO3_type = pypose.so3(delta_pose).Exp() * pypose.so3(pose).Exp() 47 | 48 | # Map back to axis-angle representation and flatten again 49 | pose = pose.Log().tensor() 50 | return eo.rearrange(pose, "... k c -> ... 
(k c)") 51 | 52 | 53 | class SMPLKinematics(nn.Module): 54 | """Parse SMPL parameters to get mesh and joint coordinates.""" 55 | 56 | def __init__(self): 57 | """Initialize SMPLKinematics.""" 58 | super().__init__() 59 | 60 | self.num_tf = 24 61 | 62 | # Load SMPL neutral model 63 | with open(smpl_model_path, "rb") as f: 64 | smpl_neutral = pickle.load(f, encoding="latin1") 65 | 66 | # Prepare model data and convert to torch tensors 67 | smpl_keys = [ 68 | "J_regressor", 69 | "v_template", 70 | "posedirs", 71 | "shapedirs", 72 | "weights", 73 | "kintree_table", 74 | ] 75 | for k in smpl_keys: 76 | v = smpl_neutral[k] 77 | if isinstance(v, chumpy.ch.Ch): 78 | v = np.array(v) 79 | elif not isinstance(v, np.ndarray): 80 | v = v.toarray() 81 | v = torch.tensor(v).float() 82 | 83 | if k == "shapedirs": 84 | v = v[..., :BETA_DOF].view(-1, BETA_DOF).contiguous() 85 | elif k == "posedirs": 86 | v = v.view(-1, v.shape[-1]).T 87 | elif k == "kintree_table": 88 | k = "parents" 89 | v = v[0].long() 90 | v[0] = -1 91 | 92 | self.register_buffer(k, v, persistent=False) 93 | 94 | # Precompute intermediate matrices 95 | self.register_buffer( 96 | "joint_shapedirs", 97 | (self.J_regressor @ self.shapedirs.view(-1, 3 * BETA_DOF)).view( 98 | self.num_tf * 3, BETA_DOF 99 | ), 100 | persistent=False, 101 | ) 102 | self.register_buffer( 103 | "joint_template", self.J_regressor @ self.v_template, persistent=False 104 | ) 105 | 106 | # Load additional reference tensors 107 | self.register_buffer("mean_pose", extra_ref["mean_pose"], persistent=False) 108 | self.register_buffer("coco_map", extra_ref["coco_map"], persistent=False) 109 | 110 | def forward_kinematics(self, rot_mat, unposed_tx): 111 | # Get relative joint positions 112 | relative_tx = unposed_tx.clone() 113 | relative_tx[1:] -= relative_tx[self.parents[1:]] 114 | 115 | # Define transformation matrices 116 | tf = torch.cat([rot_mat, relative_tx.unsqueeze(-1)], -1) 117 | tf = torch.functional.F.pad(tf, [0, 0, 0, 1]) 118 | tf[..., 3, 3] = 1 119 | 120 | # Follow kinematic chain 121 | tf_chain = [tf[0]] 122 | for child_idx in range(1, self.num_tf): 123 | parent_idx = self.parents[child_idx] 124 | tf_chain.append(tf_chain[parent_idx] @ tf[child_idx]) 125 | tf_chain = torch.stack(tf_chain) 126 | 127 | # Get output joint positions 128 | joints = eo.rearrange(tf_chain[..., :3, 3], "k ... c -> ... k c") 129 | return tf_chain, joints 130 | 131 | def get_smpl_template(self, subsample_rate=1, idx_subset=None): 132 | """Return reference tensors to derive SMPL mesh. 133 | 134 | Optional flags allow us to subsample a subset of vertices from the mesh to 135 | avoid computation over the complete mesh. 
136 | """ 137 | posedir_dim = self.posedirs.shape[0] 138 | shapedirs = eo.rearrange(self.shapedirs, "(n c) m -> n c m", c=3) 139 | 140 | if idx_subset is None: 141 | v_template = self.v_template[..., ::subsample_rate, :] 142 | posedirs = self.posedirs.view(posedir_dim, -1, 3)[ 143 | :, ::subsample_rate 144 | ].reshape(posedir_dim, -1) 145 | shapedirs = eo.rearrange(shapedirs[::subsample_rate], "n c m -> (n c) m") 146 | weights = self.weights[::subsample_rate] 147 | else: 148 | v_template = self.v_template[..., idx_subset, :] 149 | posedirs = self.posedirs.view(posedir_dim, -1, 3)[:, idx_subset].reshape( 150 | posedir_dim, -1 151 | ) 152 | shapedirs = eo.rearrange(shapedirs[idx_subset], "n c m -> (n c) m") 153 | weights = self.weights[idx_subset] 154 | 155 | return v_template, posedirs, shapedirs, weights 156 | 157 | def forward( 158 | self, 159 | betas, 160 | pose, 161 | trans=None, 162 | output_format="joints", 163 | subsample_rate=1, 164 | idx_subset=None, 165 | ): 166 | """Get joints and mesh vertices. 167 | 168 | Valid output formats: 169 | - joints: 24x3 set of SMPL joints 170 | - joints_face: 27x3 joints where the first 22 values are original SMPL joints 171 | while the last 5 correspond to face keypoints 172 | - joints_coco: 17x3 joints in COCO order that match COCO annotation format 173 | e.g. hips are higher and wider than in SMPL 174 | - mesh: 6890x3 set of SMPL mesh vertices 175 | if idx_subset is not None, returns len(idx_subset) vertices 176 | if subsample rate > 1, returns (6078 // subsample_rate) vertices 177 | """ 178 | # Convert pose to rotation matrices 179 | pose = eo.rearrange(pose, "... (k c) -> k ... c", c=3).contiguous() 180 | rot_mat = to_rotmat(pose) 181 | 182 | # Adjust mesh based on betas 183 | blend_shape = eo.rearrange( 184 | self.joint_shapedirs @ betas.unsqueeze(-1), "... (k c) 1 -> ... k c", c=3 185 | ) 186 | unposed_tx = self.joint_template + blend_shape 187 | unposed_tx = eo.rearrange(unposed_tx, "... k c -> k ... c") 188 | 189 | # Run forward kinematics 190 | tf_chain, joints = self.forward_kinematics(rot_mat, unposed_tx) 191 | 192 | if output_format == "joints": 193 | smpl_output = joints 194 | 195 | else: 196 | if output_format == "joints_face": 197 | idx_subset = FACE_VERTEX_IDXS 198 | elif output_format == "joints_coco": 199 | subsample_rate = 1 200 | idx_subset = None 201 | 202 | v_template, posedirs, shapedirs, weights = self.get_smpl_template( 203 | subsample_rate, idx_subset 204 | ) 205 | 206 | # Adjust mesh based on betas 207 | blend_shape = eo.rearrange( 208 | shapedirs @ betas.unsqueeze(-1), "... (k c) 1 -> ... k c", c=3 209 | ) 210 | v_shaped = v_template + blend_shape 211 | 212 | # Get relative transforms (rotate unposed joints and subtract) 213 | tf_offset = (tf_chain[..., :3, :3] @ unposed_tx.unsqueeze(-1)).squeeze(-1) 214 | tf_relative = tf_chain.clone() 215 | tf_relative[..., :3, 3] -= tf_offset 216 | A = eo.rearrange(tf_relative, "n ... d0 d1 -> ... n (d0 d1)") 217 | 218 | # Flatten all rotation matrices (except root rotation), remove identity 219 | eye = torch.eye(3, device=rot_mat.device) 220 | pose_feature = eo.rearrange( 221 | rot_mat[1:] - eye, "k ... d0 d1 -> ... (k d0 d1)" 222 | ) 223 | 224 | # Adjust base vertex positions 225 | pose_offsets = eo.rearrange( 226 | pose_feature @ posedirs, "... (n c) -> ... n c", c=3 227 | ) 228 | v_posed = v_shaped + pose_offsets 229 | 230 | # Transform vertices based on pose 231 | T = weights @ A 232 | T = eo.rearrange(T, "... (d0 d1) -> ... 
d0 d1", d0=4) 233 | vertices = torch.functional.F.pad(v_posed, [0, 1], value=1) 234 | vertices = (T @ vertices.unsqueeze(-1)).squeeze(-1) 235 | vertices = vertices[..., :-1] / vertices[..., -1:] 236 | 237 | smpl_output = vertices 238 | if output_format == "joints_face": 239 | smpl_output = torch.cat([joints[..., :22, :], vertices], -2) 240 | elif output_format == "joints_coco": 241 | smpl_output = eo.einsum( 242 | vertices, self.coco_map, "... i j, i k -> ... k j" 243 | ) 244 | 245 | if trans is not None: 246 | # Apply translation offset 247 | smpl_output = smpl_output + trans.unsqueeze(-2) 248 | 249 | return smpl_output 250 | --------------------------------------------------------------------------------
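
As a closing illustration of the SMPL interface above, here is a minimal sketch of driving SMPLKinematics directly. It assumes the demo package and its dependencies are installed so that comotion_demo is importable, that SMPL_NEUTRAL.pkl has been downloaded into src/comotion_demo/data/smpl/ as the module requires, and an arbitrary batch of two people with zero-initialized parameters:

import torch

from comotion_demo.utils import smpl_kinematics

# Flattened SMPL parameters for two people: 10 betas, 72 pose DOFs
# (24 joints x 3 axis-angle values), and a 3-DOF root translation.
betas = torch.zeros(2, smpl_kinematics.BETA_DOF)
pose = torch.zeros(2, smpl_kinematics.POSE_DOF)
trans = torch.zeros(2, smpl_kinematics.TRANS_DOF)

model = smpl_kinematics.SMPLKinematics()

# 24 x 3 SMPL joint positions (the translation offset is applied last).
joints = model(betas, pose, trans, output_format="joints")
print(joints.shape)  # torch.Size([2, 24, 3])

# 17 x 3 COCO-ordered keypoints derived from the posed mesh vertices.
joints_coco = model(betas, pose, trans, output_format="joints_coco")
print(joints_coco.shape)  # torch.Size([2, 17, 3])

# update_pose applies a pose residual as a rotation rather than an addition.
new_pose = smpl_kinematics.update_pose(pose, 0.1 * torch.randn_like(pose))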