├── .gitignore ├── LICENSE ├── README.md ├── data └── relpose_co3d_subset │ ├── backpack │ └── 346_36107_66565 │ │ ├── images │ │ ├── frame000001.jpg │ │ ├── frame000011.jpg │ │ ├── frame000021.jpg │ │ ├── frame000031.jpg │ │ ├── frame000041.jpg │ │ ├── frame000051.jpg │ │ ├── frame000061.jpg │ │ ├── frame000071.jpg │ │ ├── frame000081.jpg │ │ └── frame000091.jpg │ │ └── masks │ │ ├── frame000001.png │ │ ├── frame000011.png │ │ ├── frame000021.png │ │ ├── frame000031.png │ │ ├── frame000041.png │ │ ├── frame000051.png │ │ ├── frame000061.png │ │ ├── frame000071.png │ │ ├── frame000081.png │ │ └── frame000091.png │ ├── chair │ └── 187_20196_37207 │ │ ├── images │ │ ├── frame000001.jpg │ │ ├── frame000011.jpg │ │ ├── frame000021.jpg │ │ ├── frame000031.jpg │ │ ├── frame000041.jpg │ │ ├── frame000051.jpg │ │ ├── frame000061.jpg │ │ ├── frame000071.jpg │ │ ├── frame000081.jpg │ │ └── frame000091.jpg │ │ └── masks │ │ ├── frame000001.png │ │ ├── frame000011.png │ │ ├── frame000021.png │ │ ├── frame000031.png │ │ ├── frame000041.png │ │ ├── frame000051.png │ │ ├── frame000061.png │ │ ├── frame000071.png │ │ ├── frame000081.png │ │ └── frame000091.png │ ├── hydrant │ └── 194_20925_42241 │ │ ├── images │ │ ├── frame000001.jpg │ │ ├── frame000011.jpg │ │ ├── frame000021.jpg │ │ ├── frame000031.jpg │ │ ├── frame000041.jpg │ │ ├── frame000051.jpg │ │ ├── frame000061.jpg │ │ ├── frame000071.jpg │ │ ├── frame000081.jpg │ │ └── frame000091.jpg │ │ └── masks │ │ ├── frame000001.png │ │ ├── frame000011.png │ │ ├── frame000021.png │ │ ├── frame000031.png │ │ ├── frame000041.png │ │ ├── frame000051.png │ │ ├── frame000061.png │ │ ├── frame000071.png │ │ ├── frame000081.png │ │ └── frame000091.png │ └── plant │ └── 374_41993_84073 │ ├── images │ ├── frame000001.jpg │ ├── frame000011.jpg │ ├── frame000021.jpg │ ├── frame000031.jpg │ ├── frame000041.jpg │ ├── frame000051.jpg │ ├── frame000061.jpg │ ├── frame000071.jpg │ ├── frame000081.jpg │ └── frame000091.jpg │ └── masks │ ├── frame000001.png │ ├── frame000011.png │ ├── frame000021.png │ ├── frame000031.png │ ├── frame000041.png │ ├── frame000051.png │ ├── frame000061.png │ ├── frame000071.png │ ├── frame000081.png │ └── frame000091.png ├── dev └── linter.sh ├── docs ├── dataset.md └── eval.md ├── notebooks └── demo.ipynb ├── preprocess ├── __init__.py ├── copy_co3d.py ├── preprocess_co3d.py └── preprocess_co3dv1.py ├── relpose ├── __init__.py ├── dataset │ ├── __init__.py │ ├── co3d.py │ ├── co3dv1.py │ ├── custom.py │ └── dataloader.py ├── eval │ ├── __init__.py │ ├── eval_driver.py │ ├── eval_joint.py │ ├── eval_pairwise.py │ ├── load_model.py │ └── pairwise.py ├── inference │ ├── __init__.py │ └── joint_inference.py ├── models.py ├── trainer.py └── utils │ ├── __init__.py │ ├── bbox.py │ ├── geometry.py │ ├── permutations.py │ └── visualize.py ├── requirements.txt └── setup.cfg /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | data/ 3 | output 4 | eval_jobs.sh 5 | 6 | # Byte-compiled / optimized / DLL files 7 | __pycache__/ 8 | *.py[cod] 9 | *$py.class 10 | 11 | # C extensions 12 | *.so 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | share/python-wheels/ 29 | *.egg-info/ 30 | .installed.cfg 31 | *.egg 32 | MANIFEST 33 | 34 | # PyInstaller 35 | # Usually these files are written by a python script from a template 36 | # before PyInstaller 
builds the exe, so as to inject date/other infos into it. 37 | *.manifest 38 | *.spec 39 | 40 | # Installer logs 41 | pip-log.txt 42 | pip-delete-this-directory.txt 43 | 44 | # Unit test / coverage reports 45 | htmlcov/ 46 | .tox/ 47 | .nox/ 48 | .coverage 49 | .coverage.* 50 | .cache 51 | nosetests.xml 52 | coverage.xml 53 | *.cover 54 | *.py,cover 55 | .hypothesis/ 56 | .pytest_cache/ 57 | cover/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | .pybuilder/ 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | # For a library or package, you might want to ignore these files since the code is 92 | # intended to run in multiple environments; otherwise, check them in: 93 | # .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # poetry 103 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 104 | # This is especially recommended for binary packages to ensure reproducibility, and is more 105 | # commonly ignored for libraries. 106 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 107 | #poetry.lock 108 | 109 | # pdm 110 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 111 | #pdm.lock 112 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 113 | # in version control. 114 | # https://pdm.fming.dev/#use-with-ide 115 | .pdm.toml 116 | 117 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 118 | __pypackages__/ 119 | 120 | # Celery stuff 121 | celerybeat-schedule 122 | celerybeat.pid 123 | 124 | # SageMath parsed files 125 | *.sage.py 126 | 127 | # Environments 128 | .env 129 | .venv 130 | env/ 131 | venv/ 132 | ENV/ 133 | env.bak/ 134 | venv.bak/ 135 | 136 | # Spyder project settings 137 | .spyderproject 138 | .spyproject 139 | 140 | # Rope project settings 141 | .ropeproject 142 | 143 | # mkdocs documentation 144 | /site 145 | 146 | # mypy 147 | .mypy_cache/ 148 | .dmypy.json 149 | dmypy.json 150 | 151 | # Pyre type checker 152 | .pyre/ 153 | 154 | # pytype static type analyzer 155 | .pytype/ 156 | 157 | # Cython debug symbols 158 | cython_debug/ 159 | 160 | # PyCharm 161 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 162 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 163 | # and can be added to the global gitignore or merged into this file. For a more nuclear 164 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
165 | #.idea/ 166 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2022, Jason Y. Zhang 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | 1. Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | 2. Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | 3. Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RelPose: Predicting Probabilistic Relative Rotation for Single Objects in the Wild 2 | 3 | [[`arXiv`](https://arxiv.org/abs/2208.05963)] 4 | [[`Project Page`](https://jasonyzhang.com/relpose/)] 5 | [[`Bibtex`](#citing-relpose)] 6 | 7 | ## Installation 8 | 9 | Follow the directions for setting up CO3D (v1 or v2) from [here](docs/dataset.md). 10 | 11 | ### Setup 12 | We recommend using conda to manage dependencies. Make sure to install a cudatoolkit 13 | compatible with your GPU. 14 | ``` 15 | git clone --depth 1 https://github.com/jasonyzhang/relpose.git 16 | conda create -n relpose python=3.8 17 | conda activate relpose 18 | conda install pytorch==1.12.0 torchvision==0.13.0 cudatoolkit=11.3 -c pytorch 19 | pip install -r requirements.txt 20 | ``` 21 | 22 | ### Model Weights 23 | 24 | You can download the pre-trained model weights for both CO3Dv1 and CO3Dv2 from 25 | [Google Drive](https://drive.google.com/file/d/1XwRjxOzqj6DXGg_bzYFy83iDlZx8mkQ-/view?usp=share_link). 26 | Alternatively, you can use gdown: 27 | ``` 28 | gdown --output data/pretrained_relpose.zip https://drive.google.com/uc?id=1XwRjxOzqj6DXGg_bzYFy83iDlZx8mkQ- 29 | unzip data/pretrained_relpose.zip -d data 30 | ``` 31 | 32 | ### Installing Pytorch3d 33 | 34 | Here, we list the recommended steps for installing Pytorch3d. Refer to the 35 | [official installation directions](https://github.com/facebookresearch/pytorch3d/blob/master/INSTALL.md) 36 | for troubleshooting and additional details.
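Before building, it can help to confirm that the conda environment created above actually sees the GPU and reports the expected CUDA version. This is an optional sanity check we suggest here, not part of the official Pytorch3d instructions:

```
python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"
```

If this prints `False` or an unexpected CUDA version, fix the pytorch/cudatoolkit installation before compiling Pytorch3d.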
37 | 38 | ``` 39 | mkdir -p external 40 | git clone --depth 1 --branch v0.7.0 https://github.com/facebookresearch/pytorch3d.git external/pytorch3d 41 | cd external/pytorch3d 42 | conda activate relpose 43 | conda install -c conda-forge -c fvcore -c iopath -c bottler fvcore iopath nvidiacub 44 | python setup.py install 45 | ``` 46 | 47 | If you need to compile for multiple architectures (e.g. Turing for 2080TI and Maxwell 48 | for 1080TI), you can pass the architectures as an environment variable, i.e. 49 | `TORCH_CUDA_ARCH_LIST="Maxwell;Pascal;Turing;Volta" python setup.py install`. 50 | 51 | If you get a warning about the default C/C++ compiler on your machine, you should 52 | compile Pytorch3D using the same compiler that your pytorch installation uses, likely 53 | gcc/g++. Try: `CC=gcc CXX=g++ python setup.py install`. 54 | 55 | 56 | ### Dataset Preparation 57 | 58 | Please see [docs/dataset.md](docs/dataset.md) for instructions on preparing the CO3Dv1 dataset or your own dataset. 59 | 60 | ## Training 61 | 62 | Once the datasets are set up, run the following command to train on 4 GPUs on CO3Dv2: 63 | ``` 64 | python -m relpose.trainer --batch_size 64 --num_gpus 4 --output_dir output --dataset co3d 65 | ``` 66 | 67 | With 4 2080TI GPUs, we expect training to take a little less than 2 days. 68 | 69 | ## Inference 70 | 71 | Please see [notebooks/demo.ipynb](notebooks/demo.ipynb) for a demo of visualizing 72 | pairwise relative pose distributions given 2 images as well as recovering camera 73 | rotations using the pairwise predictor. Currently, the demo supports using a Maximum 74 | Spanning Tree and Coordinate Ascent for joint camera pose inference. 75 | 76 | ## Evaluation 77 | 78 | Please see [docs/eval.md](docs/eval.md) for instructions on evaluating with sequential, 79 | MST, and coordinate ascent inference. 80 | 81 | ## Citing RelPose 82 | 83 | If you find this code helpful, please cite: 84 | 85 | ```BibTeX 86 | @InProceedings{zhang2022relpose, 87 | title = {{RelPose}: Predicting Probabilistic Relative Rotation for Single Objects in the Wild}, 88 | author = {Zhang, Jason Y.
and Ramanan, Deva and Tulsiani, Shubham}, 89 | booktitle = {European Conference on Computer Vision}, 90 | year = {2022}, 91 | } 92 | ``` -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000001.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000011.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000021.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000031.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000031.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000041.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000041.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000051.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000051.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000061.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000061.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000071.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000071.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000081.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000081.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000091.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/images/frame000091.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000001.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000011.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000021.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000031.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000031.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000041.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000041.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000051.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000051.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000061.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000061.png 
-------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000071.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000071.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000081.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000081.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000091.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/backpack/346_36107_66565/masks/frame000091.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/images/frame000001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/images/frame000001.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/images/frame000011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/images/frame000011.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/images/frame000021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/images/frame000021.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/images/frame000031.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/images/frame000031.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/images/frame000041.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/images/frame000041.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/images/frame000051.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/images/frame000051.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/images/frame000061.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/images/frame000061.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/images/frame000071.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/images/frame000071.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/images/frame000081.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/images/frame000081.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/images/frame000091.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/images/frame000091.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000001.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000011.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000021.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000031.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000031.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000041.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000041.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000051.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000051.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000061.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000061.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000071.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000071.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000081.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000081.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000091.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/chair/187_20196_37207/masks/frame000091.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000001.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000001.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000011.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000021.jpg -------------------------------------------------------------------------------- 
/data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000031.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000031.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000041.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000041.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000051.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000051.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000061.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000061.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000071.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000071.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000081.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000081.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000091.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/images/frame000091.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000001.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000011.png 
-------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000021.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000031.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000031.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000041.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000041.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000051.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000051.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000061.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000061.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000071.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000071.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000081.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000081.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000091.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/hydrant/194_20925_42241/masks/frame000091.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/images/frame000001.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/images/frame000001.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/images/frame000011.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/images/frame000011.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/images/frame000021.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/images/frame000021.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/images/frame000031.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/images/frame000031.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/images/frame000041.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/images/frame000041.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/images/frame000051.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/images/frame000051.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/images/frame000061.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/images/frame000061.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/images/frame000071.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/images/frame000071.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/images/frame000081.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/images/frame000081.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/images/frame000091.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/images/frame000091.jpg -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000001.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000001.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000011.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000011.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000021.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000021.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000031.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000031.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000041.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000041.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000051.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000051.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000061.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000061.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000071.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000071.png -------------------------------------------------------------------------------- 
/data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000081.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000081.png -------------------------------------------------------------------------------- /data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000091.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/data/relpose_co3d_subset/plant/374_41993_84073/masks/frame000091.png -------------------------------------------------------------------------------- /dev/linter.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash -e 2 | # Borrowed from FAIR 3 | 4 | # Run this script at project root by "./dev/linter.sh" before you commit 5 | 6 | DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" 7 | DIR=$(dirname "${DIR}") 8 | 9 | echo "Running isort..." 10 | # isort -y -sp "${DIR}" 11 | isort "${DIR}" 12 | 13 | echo "Running black..." 14 | black "${DIR}" 15 | 16 | echo "Running flake..." 17 | flake8 "${DIR}" 18 | -------------------------------------------------------------------------------- /docs/dataset.md: -------------------------------------------------------------------------------- 1 | # Dataset Preparation 2 | 3 | ## Preparing CO3Dv1 Dataset 4 | 5 | Follow the directions to download and extract CO3Dv1 from 6 | [here](https://github.com/facebookresearch/co3d/tree/v1). 7 | 8 | You will need to pre-process the dataset to extract the bounding boxes and camera poses. 9 | The bounding boxes are processed separately because computing them takes the most time 10 | and only needs to be done once. 11 | ``` 12 | python -m preprocess.preprocess_co3dv1 --category all --precompute_bbox \ 13 | --co3d_v1_dir /path/to/co3d_v1 14 | python -m preprocess.preprocess_co3dv1 --category all \ 15 | --co3d_v1_dir /path/to/co3d_v1 16 | ``` 17 | 18 | ## Preparing Your Own Dataset 19 | 20 | For inference on your own video, you can use the `CustomDataset` class in 21 | `relpose/dataset/custom.py`. You will need to provide a directory of images and a 22 | directory of masks or bounding boxes. The masks are simply used to compute bounding 23 | boxes. 24 | -------------------------------------------------------------------------------- /docs/eval.md: -------------------------------------------------------------------------------- 1 | # Evaluation 2 | 3 | For a trained model, use `relpose/eval/eval_joint.py` for evaluation. For example, 4 | to evaluate using a maximum spanning tree over 5 frames from seen object categories, 5 | use: 6 | ``` 7 | python -m relpose.eval.eval_joint \ 8 | --checkpoint data/pretrained_co3dv1/checkpoints/ckpt_000400000.pth \ 9 | --num_frames 5 \ 10 | --use_pbar \ 11 | --dataset co3dv1 \ 12 | --categories_type seen \ 13 | --mode mst 14 | ``` 15 | 16 | To simplify evaluation, we provide a script to generate a shell file with the commands 17 | for different numbers of frames, evaluation modes, etc. Please see 18 | `relpose/eval/eval_driver.py`, which generates a script `eval_jobs.sh`. You can then run 19 | this script with `sh eval_jobs.sh`. 20 | 21 | Note: coordinate ascent must be run after the MST evaluation because it is initialized 22 | from the MST solution.
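For reference, the generated `eval_jobs.sh` is simply a flat list of `eval_joint` commands. A minimal hand-written equivalent might look like the sketch below; the checkpoint path is the CO3Dv1 one from the example above, and the flag value for the coordinate-ascent mode is an assumption here (check `relpose/eval/eval_driver.py` for the exact commands it emits). The MST runs come first so that their solutions are available to initialize coordinate ascent:

```
# Hypothetical sketch of eval_jobs.sh; the real file is produced by eval_driver.py.
for N in 3 5 10 20; do
    python -m relpose.eval.eval_joint \
        --checkpoint data/pretrained_co3dv1/checkpoints/ckpt_000400000.pth \
        --num_frames $N --use_pbar --dataset co3dv1 --categories_type seen --mode mst
done
for N in 3 5 10 20; do
    python -m relpose.eval.eval_joint \
        --checkpoint data/pretrained_co3dv1/checkpoints/ckpt_000400000.pth \
        --num_frames $N --use_pbar --dataset co3dv1 --categories_type seen \
        --mode coordinate_ascent  # mode name assumed; see eval_driver.py
done
```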
23 | 24 | ## Expected evaluation results 25 | 26 | These models were retrained and may not match the numbers in the paper. There may also 27 | be some stochasticity in the runs. 28 | 29 | ### CO3Dv2 30 | 31 | Expected evaluation results (Uniform, seen categories): 32 | ``` 33 | Sequential N=3 N=5 N=10 N=20 34 | Acc <15° 0.38 0.36 0.33 0.29 35 | Acc <30° 0.61 0.59 0.57 0.54 36 | 37 | MST N=3 N=5 N=10 N=20 38 | Acc <15° 0.38 0.44 0.46 0.43 39 | Acc <30° 0.61 0.63 0.64 0.61 40 | 41 | Coord Asc N=3 N=5 N=10 N=20 42 | Acc <15° 0.44 0.51 0.54 0.56 43 | Acc <30° 0.63 0.69 0.71 0.72 44 | ``` 45 | 46 | 47 | Expected evaluation results (Uniform, unseen categories): 48 | ``` 49 | Sequential N=3 N=5 N=10 N=20 50 | Acc <15° 0.28 0.27 0.27 0.24 51 | Acc <30° 0.48 0.46 0.47 0.47 52 | 53 | MST N=3 N=5 N=10 N=20 54 | Acc <15° 0.29 0.32 0.37 0.37 55 | Acc <30° 0.48 0.50 0.52 0.53 56 | 57 | Coord Asc N=3 N=5 N=10 N=20 58 | Acc <15° 0.33 0.37 0.43 0.46 59 | Acc <30° 0.51 0.55 0.58 0.61 60 | 61 | ``` 62 | 63 | 64 | ### CO3Dv1 65 | 66 | Expected evaluation results (Uniform, seen categories): 67 | ``` 68 | Sequential N=3 N=5 N=10 N=20 69 | Acc <15° 0.31 0.30 0.30 0.28 70 | Acc <30° 0.54 0.51 0.51 0.51 71 | 72 | MST N=3 N=5 N=10 N=20 73 | Acc <15° 0.30 0.33 0.35 0.34 74 | Acc <30° 0.53 0.54 0.55 0.53 75 | 76 | Coord Asc N=3 N=5 N=10 N=20 77 | Acc <15° 0.35 0.38 0.43 0.45 78 | Acc <30° 0.56 0.58 0.62 0.64 79 | ``` 80 | 81 | 82 | Expected evaluation results (Uniform, unseen categories): 83 | ``` 84 | Sequential N=3 N=5 N=10 N=20 85 | Acc <15° 0.18 0.21 0.23 0.25 86 | Acc <30° 0.39 0.38 0.43 0.46 87 | 88 | MST N=3 N=5 N=10 N=20 89 | Acc <15° 0.19 0.22 0.25 0.27 90 | Acc <30° 0.41 0.42 0.42 0.43 91 | 92 | Coord Asc N=3 N=5 N=10 N=20 93 | Acc <15° 0.20 0.25 0.31 0.34 94 | Acc <30° 0.42 0.45 0.51 0.52 95 | 96 | ``` 97 | -------------------------------------------------------------------------------- /preprocess/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jasonyzhang/relpose/e96c643341d97dfda4921ac19fb272a1bb77c65a/preprocess/__init__.py -------------------------------------------------------------------------------- /preprocess/copy_co3d.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from glob import glob 5 | 6 | from tqdm.auto import tqdm 7 | 8 | 9 | def get_parser(): 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("--source_dir", type=str, default="/data/drive3/jason/co3d") 12 | parser.add_argument("--target_dir", type=str, default="/data/drive2/jason/co3d") 13 | parser.add_argument("--category", default="*") 14 | return parser 15 | 16 | 17 | def main(args): 18 | source_dir = args.source_dir 19 | target_dir = args.target_dir 20 | all_image_paths = sorted( 21 | glob(os.path.join(source_dir, args.category, "*", "images", "*.jpg")) 22 | ) 23 | for image_path in tqdm(all_image_paths): 24 | target_image_path = os.path.join( 25 | target_dir, image_path.replace(source_dir, "")[1:] # drop the first '/' 26 | ) 27 | target_image_dir = os.path.dirname(target_image_path) 28 | if not os.path.exists(target_image_dir): 29 | os.makedirs(target_image_dir, exist_ok=True) 30 | shutil.copy(image_path, target_image_path) 31 | 32 | 33 | if __name__ == "__main__": 34 | main(get_parser().parse_args()) -------------------------------------------------------------------------------- /preprocess/preprocess_co3d.py:
-------------------------------------------------------------------------------- 1 | """ 2 | Script to pre-process camera poses and bounding boxes for CO3Dv2 dataset. This is 3 | important because computing the bounding boxes from the masks is a significant 4 | bottleneck. 5 | 6 | First, you should pre-compute the bounding boxes since this takes a long time. 7 | 8 | Usage: 9 | python -m preprocess.preprocess_co3d --category all --precompute_bbox \ 10 | --co3d_dir /path/to/co3d 11 | python -m preprocess.preprocess_co3d --category all \ 12 | --co3d_dir /path/to/co3d 13 | """ 14 | import argparse 15 | import gzip 16 | import json 17 | import os 18 | import os.path as osp 19 | from glob import glob 20 | 21 | import ipdb 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | from tqdm.auto import tqdm 25 | 26 | # fmt: off 27 | CATEGORIES = [ 28 | "apple", "backpack", "ball", "banana", "baseballbat", "baseballglove", 29 | "bench", "bicycle", "book", "bottle", "bowl", "broccoli", "cake", "car", "carrot", 30 | "cellphone", "chair", "couch", "cup", "donut", "frisbee", "hairdryer", "handbag", 31 | "hotdog", "hydrant", "keyboard", "kite", "laptop", "microwave", "motorcycle", 32 | "mouse", "orange", "parkingmeter", "pizza", "plant", "remote", "sandwich", 33 | "skateboard", "stopsign", "suitcase", "teddybear", "toaster", "toilet", "toybus", 34 | "toyplane", "toytrain", "toytruck", "tv", "umbrella", "vase", "wineglass", 35 | ] 36 | # fmt: on 37 | 38 | 39 | def get_parser(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--category", type=str, default="apple") 42 | parser.add_argument("--output_dir", type=str, default="data/co3d_annotations") 43 | parser.add_argument("--co3d_dir", type=str, default="data/co3d") 44 | parser.add_argument( 45 | "--min_quality", 46 | type=float, 47 | default=0.5, 48 | help="Minimum viewpoint quality score.", 49 | ) 50 | parser.add_argument("--precompute_bbox", action="store_true") 51 | return parser 52 | 53 | 54 | def mask_to_bbox(mask): 55 | """ 56 | xyxy format 57 | """ 58 | mask = mask > 0.4 59 | if not np.any(mask): 60 | return [] 61 | rows = np.any(mask, axis=1) 62 | cols = np.any(mask, axis=0) 63 | rmin, rmax = np.where(rows)[0][[0, -1]] 64 | cmin, cmax = np.where(cols)[0][[0, -1]] 65 | return [int(cmin), int(rmin), int(cmax) + 1, int(rmax) + 1] 66 | 67 | 68 | def precompute_bbox(co3d_dir, category, output_dir): 69 | """ 70 | Precomputes bounding boxes for all frames using the masks. This can be an expensive 71 | operation because it needs to load every mask in the dataset. Thus, we only want to 72 | run this once, whereas processing the rest of the dataset is fast. 
73 | """ 74 | category_dir = osp.join(co3d_dir, category) 75 | print("Precomputing bbox for:", category) 76 | all_masks = sorted(glob(osp.join(category_dir, "*", "masks", "*.png"))) 77 | bboxes = {} 78 | for mask_filename in tqdm(all_masks): 79 | mask = plt.imread(mask_filename) 80 | # /Dataset/category/sequence/masks/mask.png -> category/sequence/mask/mask.png 81 | mask_filename = mask_filename.replace(osp.dirname(category_dir), "")[1:] 82 | try: 83 | bboxes[mask_filename] = mask_to_bbox(mask) 84 | except IndexError: 85 | ipdb.set_trace() 86 | output_file = osp.join(output_dir, f"{category}_bbox.jgz") 87 | with gzip.open(output_file, "w") as f: 88 | f.write(json.dumps(bboxes).encode("utf-8")) 89 | 90 | 91 | def process_poses(co3d_dir, category, output_dir, min_quality): 92 | category_dir = osp.join(co3d_dir, category) 93 | print("Processing category:", category) 94 | frame_file = osp.join(category_dir, "frame_annotations.jgz") 95 | sequence_file = osp.join(category_dir, "sequence_annotations.jgz") 96 | subset_lists_file = osp.join(category_dir, "set_lists/set_lists_fewview_dev.json") 97 | 98 | bbox_file = osp.join(output_dir, f"{category}_bbox.jgz") 99 | 100 | with open(subset_lists_file) as f: 101 | subset_lists_data = json.load(f) 102 | 103 | with gzip.open(sequence_file, "r") as fin: 104 | sequence_data = json.loads(fin.read()) 105 | 106 | with gzip.open(frame_file, "r") as fin: 107 | frame_data = json.loads(fin.read()) 108 | 109 | with gzip.open(bbox_file, "r") as fin: 110 | bbox_data = json.loads(fin.read()) 111 | 112 | frame_data_processed = {} 113 | for f_data in frame_data: 114 | sequence_name = f_data["sequence_name"] 115 | if sequence_name not in frame_data_processed: 116 | frame_data_processed[sequence_name] = {} 117 | frame_data_processed[sequence_name][f_data["frame_number"]] = f_data 118 | 119 | good_quality_sequences = set() 120 | for seq_data in sequence_data: 121 | if seq_data["viewpoint_quality_score"] > min_quality: 122 | good_quality_sequences.add(seq_data["sequence_name"]) 123 | 124 | # Leftover debug code (hard-coded per-user path); disabled since it fails on other machines: 125 | # cat_dir = f"/home/amyxlase/process_co3d_v2/images/{category}" 126 | # if not os.path.exists(cat_dir): os.mkdir(cat_dir) 127 | 128 | for subset in ["train", "test"]: 129 | category_data = {} # {sequence_name: [{filepath, R, T}]} 130 | for seq_name, frame_number, filepath in subset_lists_data[subset]: 131 | if seq_name not in good_quality_sequences: 132 | continue 133 | 134 | if seq_name not in category_data: 135 | category_data[seq_name] = [] 136 | 137 | mask_path = filepath.replace("images", "masks").replace(".jpg", ".png") 138 | bbox = bbox_data[mask_path] 139 | if bbox == []: 140 | # Mask did not include any object.
141 | continue 142 | 143 | frame_data = frame_data_processed[seq_name][frame_number] 144 | category_data[seq_name].append( 145 | { 146 | "filepath": filepath, 147 | "R": frame_data["viewpoint"]["R"], 148 | "T": frame_data["viewpoint"]["T"], 149 | "focal_length": frame_data["viewpoint"]["focal_length"], 150 | "principal_point": frame_data["viewpoint"]["principal_point"], 151 | "bbox": bbox, 152 | } 153 | ) 154 | 155 | output_file = osp.join(output_dir, f"{category}_{subset}.jgz") 156 | with gzip.open(output_file, "w") as f: 157 | f.write(json.dumps(category_data).encode("utf-8")) 158 | 159 | 160 | if __name__ == "__main__": 161 | parser = get_parser() 162 | args = parser.parse_args() 163 | if args.category == "all": 164 | categories = CATEGORIES 165 | else: 166 | categories = [args.category] 167 | if args.precompute_bbox: 168 | for category in categories: 169 | precompute_bbox( 170 | co3d_dir=args.co3d_dir, 171 | category=category, 172 | output_dir=args.output_dir, 173 | ) 174 | else: 175 | for category in categories: 176 | process_poses( 177 | co3d_dir=args.co3d_dir, 178 | category=category, 179 | output_dir=args.output_dir, 180 | min_quality=args.min_quality, 181 | ) 182 | -------------------------------------------------------------------------------- /preprocess/preprocess_co3dv1.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to pre-process camera poses and bounding boxes for CO3Dv1 dataset. This is 3 | important because computing the bounding boxes from the masks is a significant 4 | bottleneck. 5 | 6 | First, you should pre-compute the bounding boxes since this takes a long time. 7 | 8 | Usage: 9 | python -m preprocess.preprocess_co3dv1 --category all --precompute_bbox \ 10 | --co3d_v1_dir /path/to/co3d_v1 11 | python -m preprocess.preprocess_co3dv1 --category all \ 12 | --co3d_v1_dir /path/to/co3d_v1 13 | """ 14 | import argparse 15 | import gzip 16 | import json 17 | import os 18 | import os.path as osp 19 | from glob import glob 20 | 21 | import ipdb 22 | import matplotlib.pyplot as plt 23 | import numpy as np 24 | from tqdm.auto import tqdm 25 | 26 | # fmt: off 27 | CATEGORIES = [ 28 | "apple", "backpack", "ball", "banana", "baseballbat", "baseballglove", 29 | "bench", "bicycle", "book", "bottle", "bowl", "broccoli", "cake", "car", "carrot", 30 | "cellphone", "chair", "couch", "cup", "donut", "frisbee", "hairdryer", "handbag", 31 | "hotdog", "hydrant", "keyboard", "kite", "laptop", "microwave", "motorcycle", 32 | "mouse", "orange", "parkingmeter", "pizza", "plant", "remote", "sandwich", 33 | "skateboard", "stopsign", "suitcase", "teddybear", "toaster", "toilet", "toybus", 34 | "toyplane", "toytrain", "toytruck", "tv", "umbrella", "vase", "wineglass", 35 | ] 36 | # fmt: on 37 | 38 | 39 | def get_parser(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--category", type=str, default="tv") 42 | parser.add_argument("--output_dir", type=str, default="data/co3dv1_annotations") 43 | parser.add_argument("--co3d_v1_dir", type=str, default="data/co3dv1") 44 | parser.add_argument( 45 | "--min_quality", 46 | type=float, 47 | default=0.5, 48 | help="Minimum viewpoint quality score.", 49 | ) 50 | parser.add_argument("--precompute_bbox", action="store_true") 51 | return parser 52 | 53 | 54 | def mask_to_bbox(mask): 55 | """ 56 | xyxy format 57 | """ 58 | mask = mask > 0.4 59 | if not np.any(mask): 60 | return [] 61 | rows = np.any(mask, axis=1) 62 | cols = np.any(mask, axis=0) 63 | rmin, rmax =
np.where(rows)[0][[0, -1]] 64 | cmin, cmax = np.where(cols)[0][[0, -1]] 65 | return [int(cmin), int(rmin), int(cmax) + 1, int(rmax) + 1] 66 | 67 | 68 | def precompute_bbox(co3d_v1_dir, category, output_dir): 69 | """ 70 | Precomputes bounding boxes for all frames using the masks. This can be an expensive 71 | operation because it needs to load every mask in the dataset. Thus, we only want to 72 | run this once, whereas processing the rest of the dataset is fast. 73 | """ 74 | category_dir = osp.join(co3d_v1_dir, category) 75 | print("Precomputing bbox for:", category) 76 | all_masks = sorted(glob(osp.join(category_dir, "*", "masks", "*.png"))) 77 | bboxes = {} 78 | for mask_filename in tqdm(all_masks): 79 | mask = plt.imread(mask_filename) 80 | # /Dataset/category/sequence/masks/mask.png -> category/sequence/mask/mask.png 81 | mask_filename = mask_filename.replace(osp.dirname(category_dir), "")[1:] 82 | try: 83 | bboxes[mask_filename] = mask_to_bbox(mask) 84 | except IndexError: 85 | ipdb.set_trace() 86 | output_file = osp.join(output_dir, f"{category}_bbox.jgz") 87 | os.makedirs(output_dir, exist_ok=True) 88 | with gzip.open(output_file, "w") as f: 89 | f.write(json.dumps(bboxes).encode("utf-8")) 90 | 91 | 92 | def process_poses(co3d_v1_dir, category, output_dir, min_quality): 93 | category_dir = osp.join(co3d_v1_dir, category) 94 | print("Processing category:", category) 95 | frame_file = osp.join(category_dir, "frame_annotations.jgz") 96 | sequence_file = osp.join(category_dir, "sequence_annotations.jgz") 97 | subset_lists_file = osp.join(category_dir, "set_lists.json") 98 | 99 | bbox_file = osp.join(args.output_dir, f"{category}_bbox.jgz") 100 | 101 | with open(subset_lists_file) as f: 102 | # Splits are: train_known, train_unseen, test_known, test_unseen 103 | subset_lists_data = json.load(f) 104 | 105 | with gzip.open(sequence_file, "r") as fin: 106 | sequence_data = json.loads(fin.read()) 107 | 108 | with gzip.open(frame_file, "r") as fin: 109 | frame_data = json.loads(fin.read()) 110 | 111 | with gzip.open(bbox_file, "r") as fin: 112 | bbox_data = json.loads(fin.read()) 113 | 114 | frame_data_processed = {} 115 | for f_data in frame_data: 116 | sequence_name = f_data["sequence_name"] 117 | if sequence_name not in frame_data_processed: 118 | frame_data_processed[sequence_name] = {} 119 | frame_data_processed[sequence_name][f_data["frame_number"]] = f_data 120 | 121 | good_quality_sequences = set() 122 | for seq_data in sequence_data: 123 | if seq_data["viewpoint_quality_score"] > min_quality: 124 | good_quality_sequences.add(seq_data["sequence_name"]) 125 | 126 | for subset in subset_lists_data.keys(): 127 | category_data = {} # {sequence_name: [{filepath, R, T}]} 128 | for seq_name, frame_number, filepath in subset_lists_data[subset]: 129 | if seq_name not in good_quality_sequences: 130 | continue 131 | if seq_name not in category_data: 132 | category_data[seq_name] = [] 133 | 134 | mask_path = filepath.replace("images", "masks").replace(".jpg", ".png") 135 | bbox = bbox_data[mask_path] 136 | if bbox == []: 137 | # Mask did not include any object. 
138 | continue 139 | 140 | frame_data = frame_data_processed[seq_name][frame_number] 141 | category_data[seq_name].append( 142 | { 143 | "filepath": filepath, 144 | "R": frame_data["viewpoint"]["R"], 145 | "T": frame_data["viewpoint"]["T"], 146 | "focal_length": frame_data["viewpoint"]["focal_length"], 147 | "principal_point": frame_data["viewpoint"]["principal_point"], 148 | "bbox": bbox, 149 | } 150 | ) 151 | output_file = osp.join(output_dir, f"{category}_{subset}.jgz") 152 | with gzip.open(output_file, "w") as f: 153 | f.write(json.dumps(category_data).encode("utf-8")) 154 | 155 | 156 | if __name__ == "__main__": 157 | parser = get_parser() 158 | args = parser.parse_args() 159 | if args.category == "all": 160 | categories = CATEGORIES 161 | else: 162 | categories = [args.category] 163 | if args.precompute_bbox: 164 | for category in categories: 165 | precompute_bbox( 166 | co3d_v1_dir=args.co3d_v1_dir, 167 | category=category, 168 | output_dir=args.output_dir, 169 | ) 170 | else: 171 | for category in categories: 172 | process_poses( 173 | co3d_v1_dir=args.co3d_v1_dir, 174 | category=category, 175 | output_dir=args.output_dir, 176 | min_quality=args.min_quality, 177 | ) 178 | -------------------------------------------------------------------------------- /relpose/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval.load_model import get_model 2 | from .models import RelPose 3 | -------------------------------------------------------------------------------- /relpose/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .co3dv1 import Co3dv1Dataset 2 | from .co3d import Co3dDataset 3 | from .custom import CustomDataset 4 | from .dataloader import get_dataloader 5 | -------------------------------------------------------------------------------- /relpose/dataset/co3d.py: -------------------------------------------------------------------------------- 1 | """ 2 | CO3D (v2) dataset. 
3 | """ 4 | 5 | import gzip 6 | import json 7 | import os.path as osp 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image, ImageFile 12 | from torch.utils.data import Dataset 13 | from torchvision import transforms 14 | from tqdm.auto import tqdm 15 | 16 | from relpose.utils.bbox import square_bbox 17 | 18 | CO3D_DIR = "data/co3d" 19 | CO3D_ANNOTATION_DIR = "data/co3d_annotations" 20 | 21 | TRAINING_CATEGORIES = [ 22 | "apple", 23 | "backpack", 24 | "banana", 25 | "baseballbat", 26 | "baseballglove", 27 | "bench", 28 | "bicycle", 29 | "bottle", 30 | "bowl", 31 | "broccoli", 32 | "cake", 33 | "car", 34 | "carrot", 35 | "cellphone", 36 | "chair", 37 | "cup", 38 | "donut", 39 | "hairdryer", 40 | "handbag", 41 | "hydrant", 42 | "keyboard", 43 | "laptop", 44 | "microwave", 45 | "motorcycle", 46 | "mouse", 47 | "orange", 48 | "parkingmeter", 49 | "pizza", 50 | "plant", 51 | "stopsign", 52 | "teddybear", 53 | "toaster", 54 | "toilet", 55 | "toybus", 56 | "toyplane", 57 | "toytrain", 58 | "toytruck", 59 | "tv", 60 | "umbrella", 61 | "vase", 62 | "wineglass", 63 | ] 64 | 65 | TEST_CATEGORIES = [ 66 | "ball", 67 | "book", 68 | "couch", 69 | "frisbee", 70 | "hotdog", 71 | "kite", 72 | "remote", 73 | "sandwich", 74 | "skateboard", 75 | "suitcase", 76 | ] 77 | 78 | Image.MAX_IMAGE_PIXELS = None 79 | ImageFile.LOAD_TRUNCATED_IMAGES = True 80 | 81 | 82 | class Co3dDataset(Dataset): 83 | def __init__( 84 | self, 85 | category=("all",), 86 | split="train", 87 | transform=None, 88 | debug=False, 89 | random_aug=True, 90 | jitter_scale=(1.1, 1.2), 91 | jitter_trans=(-0.07, 0.07), 92 | num_images=2, 93 | ): 94 | """ 95 | Args: 96 | category (list): List of categories to use. 97 | split (str): "train" or "test". 98 | transform (callable): Transformation to apply to the image. 99 | random_aug (bool): Whether to apply random augmentation. 100 | jitter_scale (tuple): Scale jitter range. 101 | jitter_trans (tuple): Translation jitter range. 102 | num_images: Number of images in each batch. 103 | """ 104 | if "all" in category: 105 | category = TRAINING_CATEGORIES 106 | category = sorted(category) 107 | 108 | split_name = split 109 | 110 | self.rotations = {} 111 | self.category_map = {} 112 | iterable = tqdm(category) if len(category) > 1 else category 113 | for c in category: 114 | annotation_file = osp.join(CO3D_ANNOTATION_DIR, f"{c}_{split_name}.jgz") 115 | with gzip.open(annotation_file, "r") as fin: 116 | annotation = json.loads(fin.read()) 117 | for seq_name, seq_data in annotation.items(): 118 | if len(seq_data) < 2: 119 | continue 120 | filtered_data = [] 121 | self.category_map[seq_name] = c 122 | for data in seq_data: 123 | # Ignore all unnecessary information. 
124 | filtered_data.append( 125 | { 126 | "filepath": data["filepath"], 127 | "bbox": data["bbox"], 128 | "R": data["R"], 129 | "focal_length": data["focal_length"], 130 | }, 131 | ) 132 | self.rotations[seq_name] = filtered_data 133 | 134 | self.sequence_list = list(self.rotations.keys()) 135 | self.split = split 136 | self.debug = debug 137 | if transform is None: 138 | self.transform = transforms.Compose( 139 | [ 140 | transforms.ToTensor(), 141 | transforms.Resize(224), 142 | transforms.Normalize( 143 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 144 | ), 145 | ] 146 | ) 147 | else: 148 | self.transform = transform 149 | if random_aug: 150 | self.jitter_scale = jitter_scale 151 | self.jitter_trans = jitter_trans 152 | else: 153 | self.jitter_scale = [1.15, 1.15] 154 | self.jitter_trans = [0, 0] 155 | self.num_images = num_images 156 | 157 | def __len__(self): 158 | return len(self.sequence_list) 159 | 160 | def _jitter_bbox(self, bbox): 161 | bbox = square_bbox(bbox.astype(np.float32)) 162 | s = np.random.uniform(self.jitter_scale[0], self.jitter_scale[1]) 163 | tx, ty = np.random.uniform(self.jitter_trans[0], self.jitter_trans[1], size=2) 164 | 165 | side_length = bbox[2] - bbox[0] 166 | center = (bbox[:2] + bbox[2:]) / 2 + np.array([tx, ty]) * side_length 167 | extent = side_length / 2 * s 168 | 169 | # Final coordinates need to be integer for cropping. 170 | ul = (center - extent).round().astype(int) 171 | lr = ul + np.round(2 * extent).astype(int) 172 | return np.concatenate((ul, lr)) 173 | 174 | def _crop_image(self, image, bbox): 175 | image_crop = transforms.functional.crop( 176 | image, 177 | top=bbox[1], 178 | left=bbox[0], 179 | height=bbox[3] - bbox[1], 180 | width=bbox[2] - bbox[0], 181 | ) 182 | return image_crop 183 | 184 | def __getitem__(self, index): 185 | sequence_name = self.sequence_list[index] 186 | metadata = self.rotations[sequence_name] 187 | ids = np.random.choice(len(metadata), self.num_images) 188 | if self.debug: 189 | # id1, id2 = np.random.choice(5, 2, replace=False) 190 | pass 191 | return self.get_data(index=index, ids=ids) 192 | 193 | def get_data(self, index=None, sequence_name=None, ids=(0, 1)): 194 | if sequence_name is None: 195 | sequence_name = self.sequence_list[index] 196 | metadata = self.rotations[sequence_name] 197 | 198 | annos = [metadata[i] for i in ids] 199 | images = [Image.open(osp.join(CO3D_DIR, anno["filepath"])) for anno in annos] 200 | rotations = [torch.tensor(anno["R"]) for anno in annos] 201 | 202 | additional_data = {} 203 | 204 | images_transformed = [] 205 | for anno, image in zip(annos, images): 206 | if self.transform is None: 207 | images_transformed.append(image) 208 | else: 209 | bbox = np.array(anno["bbox"]) 210 | bbox_jitter = self._jitter_bbox(bbox) 211 | image = self._crop_image(image, bbox_jitter) 212 | images_transformed.append(self.transform(image)) 213 | images = images_transformed 214 | 215 | relative_rotation = rotations[0].T @ rotations[1] 216 | category = self.category_map[sequence_name] 217 | batch = { 218 | "relative_rotation": relative_rotation, 219 | "model_id": sequence_name, 220 | "category": category, 221 | "n": len(metadata), 222 | } 223 | if self.transform is None: 224 | batch["image"] = images 225 | else: 226 | batch["image"] = torch.stack(images) 227 | batch["ind"] = torch.tensor(ids) 228 | batch["R"] = torch.stack(rotations) 229 | batch.update(additional_data) 230 | return batch 231 | -------------------------------------------------------------------------------- 
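A minimal usage sketch for Co3dDataset (illustrative only, not a file in the repository). It assumes the CO3Dv2 frames live under data/co3d, that preprocess/preprocess_co3d.py has already written the per-category annotation files to data/co3d_annotations, and that the first hydrant sequence has at least six frames:

import torch

from relpose.dataset import Co3dDataset

# random_aug=False gives deterministic crops (fixed 1.15 scale, no translation jitter).
dataset = Co3dDataset(category=["hydrant"], split="train", random_aug=False)

# Fetch frames 0 and 5 of the first sequence as a single batch.
batch = dataset.get_data(index=0, ids=(0, 5))
images = batch["image"]             # (2, 3, 224, 224) normalized crops
R_rel = batch["relative_rotation"]  # (3, 3) relative rotation between the two frames

# The dataset's convention is R_rel = R_0^T @ R_1.
assert torch.allclose(R_rel, batch["R"][0].T @ batch["R"][1])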
/relpose/dataset/co3dv1.py: -------------------------------------------------------------------------------- 1 | """ 2 | CO3Dv1 dataset. 3 | """ 4 | 5 | import gzip 6 | import json 7 | import os.path as osp 8 | 9 | import numpy as np 10 | import torch 11 | from PIL import Image, ImageFile 12 | from torch.utils.data import Dataset 13 | from torchvision import transforms 14 | 15 | from relpose.utils.bbox import square_bbox 16 | 17 | CO3D_DIR = "data/co3d_v1" 18 | CO3D_ANNOTATION_DIR = "data/co3dv1_annotations" 19 | 20 | TRAINING_CATEGORIES = [ 21 | "apple", 22 | "backpack", 23 | "banana", 24 | "baseballbat", 25 | "baseballglove", 26 | "bench", 27 | "bicycle", 28 | "bottle", 29 | "bowl", 30 | "broccoli", 31 | "cake", 32 | "car", 33 | "carrot", 34 | "cellphone", 35 | "chair", 36 | "cup", 37 | "donut", 38 | "hairdryer", 39 | "handbag", 40 | "hydrant", 41 | "keyboard", 42 | "laptop", 43 | "microwave", 44 | "motorcycle", 45 | "mouse", 46 | "orange", 47 | "parkingmeter", 48 | "pizza", 49 | "plant", 50 | "stopsign", 51 | "teddybear", 52 | "toaster", 53 | "toilet", 54 | "toybus", 55 | "toyplane", 56 | "toytrain", 57 | "toytruck", 58 | "tv", 59 | "umbrella", 60 | "vase", 61 | "wineglass", 62 | ] 63 | 64 | TEST_CATEGORIES = [ 65 | "ball", 66 | "book", 67 | "couch", 68 | "frisbee", 69 | "hotdog", 70 | "kite", 71 | "remote", 72 | "sandwich", 73 | "skateboard", 74 | "suitcase", 75 | ] 76 | 77 | Image.MAX_IMAGE_PIXELS = None 78 | ImageFile.LOAD_TRUNCATED_IMAGES = True 79 | 80 | 81 | class Co3dv1Dataset(Dataset): 82 | def __init__( 83 | self, 84 | category=("all",), 85 | split="train", 86 | transform=None, 87 | debug=False, 88 | random_aug=True, 89 | jitter_scale=(1.1, 1.2), 90 | jitter_trans=(-0.07, 0.07), 91 | num_images=2, 92 | ): 93 | if "all" in category: 94 | category = TRAINING_CATEGORIES 95 | category = sorted(category) 96 | 97 | if split == "train": 98 | split_name = "train_known" 99 | elif split == "test": 100 | split_name = "test_known" 101 | 102 | self.rotations = {} 103 | self.category_map = {} 104 | for c in category: 105 | annotation_file = osp.join(CO3D_ANNOTATION_DIR, f"{c}_{split_name}.jgz") 106 | with gzip.open(annotation_file, "r") as fin: 107 | annotation = json.loads(fin.read()) 108 | for seq_name, seq_data in annotation.items(): 109 | if len(seq_data) < 2: 110 | continue 111 | filtered_data = [] 112 | self.category_map[seq_name] = c 113 | for data in seq_data: 114 | # Ignore all unnecessary information. 
115 | filtered_data.append( 116 | { 117 | "filepath": data["filepath"], 118 | "bbox": data["bbox"], 119 | "R": data["R"], 120 | "focal_length": data["focal_length"], 121 | }, 122 | ) 123 | self.rotations[seq_name] = filtered_data 124 | 125 | self.sequence_list = list(self.rotations.keys()) 126 | self.split = split 127 | self.debug = debug 128 | if transform is None: 129 | self.transform = transforms.Compose( 130 | [ 131 | transforms.ToTensor(), 132 | transforms.Resize(224), 133 | transforms.Normalize( 134 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 135 | ), 136 | ] 137 | ) 138 | else: 139 | self.transform = transform 140 | if random_aug: 141 | self.jitter_scale = jitter_scale 142 | self.jitter_trans = jitter_trans 143 | else: 144 | self.jitter_scale = [1.15, 1.15] 145 | self.jitter_trans = [0, 0] 146 | self.num_images = num_images 147 | 148 | def __len__(self): 149 | return len(self.sequence_list) 150 | 151 | def _jitter_bbox(self, bbox): 152 | bbox = square_bbox(bbox.astype(np.float32)) 153 | s = np.random.uniform(self.jitter_scale[0], self.jitter_scale[1]) 154 | tx, ty = np.random.uniform(self.jitter_trans[0], self.jitter_trans[1], size=2) 155 | 156 | side_length = bbox[2] - bbox[0] 157 | center = (bbox[:2] + bbox[2:]) / 2 + np.array([tx, ty]) * side_length 158 | extent = side_length / 2 * s 159 | 160 | # Final coordinates need to be integer for cropping. 161 | ul = (center - extent).round().astype(int) 162 | lr = ul + np.round(2 * extent).astype(int) 163 | return np.concatenate((ul, lr)) 164 | 165 | def _crop_image(self, image, bbox): 166 | image_crop = transforms.functional.crop( 167 | image, 168 | top=bbox[1], 169 | left=bbox[0], 170 | height=bbox[3] - bbox[1], 171 | width=bbox[2] - bbox[0], 172 | ) 173 | return image_crop 174 | 175 | def __getitem__(self, index): 176 | sequence_name = self.sequence_list[index] 177 | metadata = self.rotations[sequence_name] 178 | ids = np.random.choice(len(metadata), self.num_images) 179 | if self.debug: 180 | # id1, id2 = np.random.choice(5, 2, replace=False) 181 | pass 182 | return self.get_data(index=index, ids=ids) 183 | 184 | def get_data(self, index=None, sequence_name=None, ids=(0, 1)): 185 | if sequence_name is None: 186 | sequence_name = self.sequence_list[index] 187 | metadata = self.rotations[sequence_name] 188 | 189 | annos = [metadata[i] for i in ids] 190 | images = [Image.open(osp.join(CO3D_DIR, anno["filepath"])) for anno in annos] 191 | rotations = [torch.tensor(anno["R"]) for anno in annos] 192 | 193 | additional_data = {} 194 | 195 | images_transformed = [] 196 | for anno, image in zip(annos, images): 197 | if self.transform is None: 198 | images_transformed.append(image) 199 | else: 200 | bbox = np.array(anno["bbox"]) 201 | bbox_jitter = self._jitter_bbox(bbox) 202 | image = self._crop_image(image, bbox_jitter) 203 | images_transformed.append(self.transform(image)) 204 | images = images_transformed 205 | 206 | relative_rotation = rotations[0].T @ rotations[1] 207 | category = self.category_map[sequence_name] 208 | batch = { 209 | "relative_rotation": relative_rotation, 210 | "model_id": sequence_name, 211 | "category": category, 212 | "n": len(metadata), 213 | } 214 | if self.transform is None: 215 | batch["image"] = images 216 | else: 217 | batch["image"] = torch.stack(images) 218 | batch["ind"] = torch.tensor(ids) 219 | batch["R"] = torch.stack(rotations) 220 | batch.update(additional_data) 221 | return batch 222 | -------------------------------------------------------------------------------- 
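Co3dv1Dataset mirrors Co3dDataset above; the main differences are the split naming (train_known / test_known) and the data paths. For intuition about the crop geometry shared by both classes, here is a small standalone sketch (illustrative only) of the arithmetic in _jitter_bbox when random_aug=False, i.e. a fixed 1.15 scale and zero translation jitter applied to an already-square box:

import numpy as np

bbox = np.array([100.0, 60.0, 300.0, 260.0])  # xyxy, side length 200
side = bbox[2] - bbox[0]                      # 200
center = (bbox[:2] + bbox[2:]) / 2            # (200, 160)
extent = side / 2 * 1.15                      # 115 = half-side after scaling
ul = (center - extent).round().astype(int)    # upper-left  -> [ 85  45]
lr = ul + int(round(2 * extent))              # lower-right -> [315 275]
print(np.concatenate((ul, lr)))               # [ 85  45 315 275]

The enlarged square is then cut out with transforms.functional.crop, resized to 224 x 224, and normalized.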
/relpose/dataset/custom.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataset class for custom datasets. Should provide a directory with images and 3 | optionally a directory of masks. The masks are used to extracting bounding boxes for 4 | each image. If masks are not provided, bounding boxes must be provided directly instead. 5 | 6 | Directory format: 7 | 8 | image_dir 9 | |_ image0001.jpg 10 | mask_dir 11 | |_ mask0001.png 12 | """ 13 | import os 14 | import os.path as osp 15 | 16 | import matplotlib.pyplot as plt 17 | import numpy as np 18 | import torch 19 | from PIL import Image 20 | from torch.utils.data import Dataset 21 | from torchvision import transforms 22 | 23 | from relpose.utils.bbox import mask_to_bbox, square_bbox 24 | 25 | 26 | class CustomDataset(Dataset): 27 | def __init__(self, image_dir, mask_dir=None, bboxes=None): 28 | self.image_dir = image_dir 29 | self.mask_dir = mask_dir 30 | self.bboxes = [] 31 | self.images = [] 32 | 33 | for image_path in sorted(os.listdir(image_dir)): 34 | self.images.append(Image.open(osp.join(image_dir, image_path))) 35 | self.n = len(self.images) 36 | if bboxes is None: 37 | for mask_path in sorted(os.listdir(mask_dir))[: self.n]: 38 | mask = plt.imread(osp.join(mask_dir, mask_path)) 39 | if len(mask.shape) == 3: 40 | mask = mask[:, :, :3] 41 | else: 42 | mask = np.dstack([mask, mask, mask]) 43 | self.bboxes.append(mask_to_bbox(mask)) 44 | else: 45 | self.bboxes = bboxes 46 | self.jitter_scale = [1.15, 1.15] 47 | self.jitter_trans = [0, 0] 48 | self.transform = transforms.Compose( 49 | [ 50 | transforms.ToTensor(), 51 | transforms.Resize(224), 52 | transforms.Normalize( 53 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 54 | ), 55 | ] 56 | ) 57 | 58 | def __len__(self): 59 | return 1 60 | 61 | def _jitter_bbox(self, bbox): 62 | bbox = square_bbox(bbox.astype(np.float32)) 63 | s = np.random.uniform(self.jitter_scale[0], self.jitter_scale[1]) 64 | tx, ty = np.random.uniform(self.jitter_trans[0], self.jitter_trans[1], size=2) 65 | 66 | side_length = bbox[2] - bbox[0] 67 | center = (bbox[:2] + bbox[2:]) / 2 + np.array([tx, ty]) * side_length 68 | extent = side_length / 2 * s 69 | 70 | # Final coordinates need to be integer for cropping. 71 | ul = (center - extent).round().astype(int) 72 | lr = ul + np.round(2 * extent).astype(int) 73 | return np.concatenate((ul, lr)) 74 | 75 | def _crop_image(self, image, bbox): 76 | image_crop = transforms.functional.crop( 77 | image, 78 | top=bbox[1], 79 | left=bbox[0], 80 | height=bbox[3] - bbox[1], 81 | width=bbox[2] - bbox[0], 82 | ) 83 | return image_crop 84 | 85 | def __getitem__(self, index): 86 | # Should use get_data instead. 
87 | ids = np.random.choice(self.n, 2) 88 | return self.get_data(ids=ids) 89 | 90 | def get_data(self, ids=(0, 1)): 91 | images = [self.images[i] for i in ids] 92 | bboxes = [self.bboxes[i] for i in ids] 93 | images_transformed = [] 94 | for _, (bbox, image) in enumerate(zip(bboxes, images)): 95 | bbox = np.array(bbox) 96 | bbox_jitter = self._jitter_bbox(bbox) 97 | image = self._crop_image(image, bbox_jitter) 98 | images_transformed.append(self.transform(image)) 99 | images = images_transformed 100 | return torch.stack(images) 101 | -------------------------------------------------------------------------------- /relpose/dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from relpose.dataset import Co3dDataset, Co3dv1Dataset 4 | 5 | 6 | def get_dataloader( 7 | batch_size=64, 8 | dataset="co3dv1", 9 | category=("apple",), 10 | split="train", 11 | shuffle=True, 12 | num_workers=8, 13 | debug=False, 14 | num_images=2, 15 | ): 16 | if debug: 17 | num_workers = 0 18 | if dataset == "co3dv1": 19 | dataset = Co3dv1Dataset( 20 | category=category, 21 | split=split, 22 | num_images=num_images, 23 | debug=debug, 24 | ) 25 | elif dataset in ["co3d", "co3dv2"]: 26 | dataset = Co3dDataset( 27 | category=category, 28 | split=split, 29 | num_images=num_images, 30 | debug=debug, 31 | ) 32 | else: 33 | raise Exception(f"Unknown dataset: {dataset}") 34 | 35 | return torch.utils.data.DataLoader( 36 | dataset, 37 | batch_size=batch_size, 38 | shuffle=shuffle, 39 | num_workers=num_workers, 40 | pin_memory=True, 41 | drop_last=True, 42 | ) 43 | -------------------------------------------------------------------------------- /relpose/eval/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval_pairwise import compute_angular_error, compute_angular_error_batch 2 | from .load_model import get_model, get_eval_dataset 3 | -------------------------------------------------------------------------------- /relpose/eval/eval_driver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Outputs all evaluation jobs that need to be run to a shell script. 3 | 4 | Update ARGUMENTS with the sweep of parameters you want to run. 
5 | 6 | Usage: 7 | python -m relpose.eval.eval_driver eval_jobs.sh 8 | """ 9 | import argparse 10 | import itertools 11 | 12 | BASE_CMD = "python -m relpose.eval.eval_joint " 13 | ARGUMENTS = { 14 | "checkpoint": ["data/pretrained_co3dv1"], 15 | "mode": ["coord_asc"], 16 | "num_frames": [20, 10, 5, 3], 17 | "dataset": ["co3dv1"], 18 | "categories_type": ["unseen", "seen"], 19 | "index": [0, 1, 2, 3], 20 | "skip": [4], 21 | } 22 | 23 | 24 | def dict_product(dicts): 25 | """ 26 | https://stackoverflow.com/a/40623158 27 | 28 | >>> list(dict_product(dict(number=[1,2], character='ab'))) 29 | [{'character': 'a', 'number': 1}, 30 | {'character': 'a', 'number': 2}, 31 | {'character': 'b', 'number': 1}, 32 | {'character': 'b', 'number': 2}] 33 | """ 34 | return (dict(zip(dicts, x)) for x in itertools.product(*dicts.values())) 35 | 36 | 37 | def main(output_path): 38 | with open(output_path, "w") as f: 39 | for args in dict_product(ARGUMENTS): 40 | f.write(BASE_CMD + " ".join([f"--{k} {v}" for k, v in args.items()]) + "\n") 41 | 42 | 43 | if __name__ == "__main__": 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument("output_path", type=str, default="eval_jobs.sh") 46 | args = parser.parse_args() 47 | main(args.output_path) 48 | -------------------------------------------------------------------------------- /relpose/eval/eval_joint.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation script for relpose. 3 | 4 | Mode can be sequential, mst (maximum spanning tree), or coordinate ascent. 5 | Default is uniform spacing. Use --random_order to use random protocol. 6 | 7 | Example: 8 | python -m relpose.eval.eval_joint \ 9 | --checkpoint /path/to/checkpoint \ 10 | --num_frames 5 \ 11 | --use_pbar \ 12 | --dataset co3dv1 \ 13 | --categories_type seen \ 14 | --mode mst 15 | 16 | """ 17 | 18 | import argparse 19 | import json 20 | import os 21 | import os.path as osp 22 | 23 | import ipdb 24 | import numpy as np 25 | import torch 26 | from tqdm.auto import tqdm 27 | 28 | from relpose.dataset.co3dv1 import TEST_CATEGORIES, TRAINING_CATEGORIES 29 | from relpose.eval import compute_angular_error_batch, get_eval_dataset, get_model 30 | from relpose.inference.joint_inference import ( 31 | compute_mst, 32 | run_coordinate_ascent, 33 | score_hypothesis, 34 | ) 35 | from relpose.utils import get_permutations 36 | 37 | 38 | def get_parser(): 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--checkpoint", type=str, required=True) 41 | parser.add_argument("--num_frames", type=int, default=5) 42 | parser.add_argument("--use_pbar", action="store_true") 43 | parser.add_argument("--dataset", type=str, default="co3dv1") 44 | parser.add_argument( 45 | "--categories_type", type=str, default="seen", choices=["seen", "unseen", "all"] 46 | ) 47 | parser.add_argument("--split", type=str, default="test", choices=["train", "test"]) 48 | parser.add_argument( 49 | "--mode", 50 | type=str, 51 | default="sequential", 52 | choices=["sequential", "mst", "coord_asc"], 53 | ) 54 | parser.add_argument( 55 | "--random_order", 56 | action="store_true", 57 | help="If True, uses random order. 
Else uses uniform spacing.", 58 | ) 59 | parser.add_argument( 60 | "--num_queries", 61 | default=250_000, 62 | type=int, 63 | help="Number of queries to use for coordinate ascent.", 64 | ) 65 | parser.add_argument( 66 | "--num_iterations", 67 | default=200, 68 | type=int, 69 | help="Number of iterations to use for coordinate ascent.", 70 | ) 71 | parser.add_argument( 72 | "--force", action="store_true", help="If True, replaces existing results." 73 | ) 74 | parser.add_argument("--reverse", action="store_true") 75 | parser.add_argument("--index", type=int, default=0) 76 | parser.add_argument("--skip", type=int, default=1) 77 | return parser 78 | 79 | 80 | def evaluate_category_sequential( 81 | model, 82 | category="banana", 83 | split="train", 84 | num_frames=5, 85 | use_pbar=False, 86 | save_dir=None, 87 | force=False, 88 | random_order=False, 89 | dataset="co3dv1", 90 | **kwargs, 91 | ): 92 | if save_dir is not None: 93 | r = "-random" if random_order else "-uniform" 94 | path = osp.join( 95 | save_dir, f"{category}-{split}-sequential-{num_frames:03d}{r}.json" 96 | ) 97 | if osp.exists(path) and not force: 98 | print(f"{path} already exists, skipping") 99 | with open(path, "r") as f: 100 | data = json.load(f) 101 | angular_errors = [] 102 | for d in data.values(): 103 | angular_errors.extend(d["angular_errors"]) 104 | return np.array(angular_errors) 105 | dataset = get_eval_dataset(category=category, split=split, dataset=dataset) 106 | device = next(model.parameters()).device 107 | permutations = get_permutations(num_frames) 108 | iterable = tqdm(dataset) if use_pbar else dataset 109 | all_errors = {} 110 | angular_errors = [] 111 | if random_order: 112 | # order = np.load(osp.join("data", "sequence_order", f"{category}-known.npz")) 113 | with open(osp.join("data", "co3d_v2_random_order_0", f"{category}.json")) as f: 114 | order = json.load(f) 115 | for metadata in iterable: 116 | n = metadata["n"] 117 | sequence_name = metadata["model_id"] 118 | 119 | if random_order: 120 | key_frames = sorted(order[sequence_name][:num_frames]) 121 | else: 122 | key_frames = np.linspace( 123 | 0, n - 1, num=num_frames, endpoint=False, dtype=int 124 | ) 125 | batch = dataset.get_data(sequence_name=sequence_name, ids=key_frames) 126 | images = batch["image"].to(device) 127 | rotations = batch["R"] 128 | rotations_pred = [np.eye(3)] 129 | for i in range(num_frames - 1): 130 | image1 = images[i] 131 | image2 = images[i + 1] 132 | with torch.no_grad(): 133 | queries, logits = model( 134 | images1=image1.unsqueeze(0), 135 | images2=image2.unsqueeze(0), 136 | ) 137 | probabilities = torch.softmax(logits, -1) 138 | probabilities = probabilities[0].detach().cpu().numpy() 139 | best_prob = probabilities.argmax() 140 | best_rotation = queries[0].detach().cpu().numpy()[best_prob] 141 | rotations_pred.append(rotations_pred[-1] @ best_rotation) 142 | rotations_pred = np.stack(rotations_pred) 143 | rotations_gt = rotations.numpy() 144 | permutations = get_permutations(num_frames) 145 | R_pred_batched = rotations_pred[permutations] 146 | R_pred_rel = np.einsum( 147 | "Bij,Bjk ->Bik", 148 | R_pred_batched[:, 0].transpose(0, 2, 1), 149 | R_pred_batched[:, 1], 150 | ) 151 | R_gt_batched = rotations_gt[permutations] 152 | R_gt_rel = np.einsum( 153 | "Bij,Bjk ->Bik", 154 | R_gt_batched[:, 0].transpose(0, 2, 1), 155 | R_gt_batched[:, 1], 156 | ) 157 | errors = compute_angular_error_batch(R_pred_rel, R_gt_rel) 158 | angular_errors.extend(errors) 159 | all_errors[sequence_name] = { 160 | "R_pred": rotations_pred.tolist(), 161 | 
"R_gt": rotations_gt.tolist(), 162 | "angular_errors": errors.tolist(), 163 | } 164 | if save_dir is not None: 165 | with open(path, "w") as f: 166 | json.dump(all_errors, f) 167 | return np.array(angular_errors) 168 | 169 | 170 | def evaluate_category_mst( 171 | model, 172 | category="banana", 173 | split="train", 174 | num_frames=5, 175 | use_pbar=False, 176 | save_dir=None, 177 | force=False, 178 | random_order=False, 179 | dataset="co3dv1", 180 | **kwargs, 181 | ): 182 | if save_dir is not None: 183 | r = "-random" if random_order else "-uniform" 184 | path = osp.join(save_dir, f"{category}-{split}-mst-{num_frames:03d}{r}.json") 185 | if osp.exists(path) and not force: 186 | print(f"{path} already exists, skipping") 187 | with open(path, "r") as f: 188 | data = json.load(f) 189 | angular_errors = [] 190 | for d in data.values(): 191 | angular_errors.extend(d["angular_errors"]) 192 | return np.array(angular_errors) 193 | dataset = get_eval_dataset(category=category, split=split, dataset=dataset) 194 | device = next(model.parameters()).device 195 | permutations = get_permutations(num_frames) 196 | 197 | if random_order: 198 | with open(osp.join("data", "co3d_v2_random_order_0", f"{category}.json")) as f: 199 | order = json.load(f) 200 | 201 | iterable = tqdm(dataset) if use_pbar else dataset 202 | all_errors = {} 203 | angular_errors = [] 204 | for metadata in iterable: 205 | n = metadata["n"] 206 | if num_frames > n: 207 | continue 208 | sequence_name = metadata["model_id"] 209 | 210 | best_rotations = np.zeros((num_frames, num_frames, 3, 3)) 211 | best_probs = np.zeros((num_frames, num_frames)) 212 | if random_order: 213 | key_frames = sorted(order[sequence_name][:num_frames]) 214 | else: 215 | key_frames = np.linspace( 216 | 0, n - 1, num=num_frames, dtype=int, endpoint=False 217 | ) 218 | batch = dataset.get_data(sequence_name=sequence_name, ids=key_frames) 219 | images = batch["image"].to(device) 220 | rotations = batch["R"] 221 | rotations_gt = rotations.numpy() 222 | 223 | for i, j in permutations: 224 | image1 = images[i].unsqueeze(0).to(device) 225 | image2 = images[j].unsqueeze(0).to(device) 226 | with torch.no_grad(): 227 | queries, logits = model( 228 | images1=image1, 229 | images2=image2, 230 | ) 231 | probabilities = torch.softmax(logits, -1) 232 | probabilities = probabilities[0].detach().cpu().numpy() 233 | best_prob = probabilities.max() 234 | best_rotation = queries[0].detach().cpu().numpy()[probabilities.argmax()] 235 | 236 | best_rotations[i, j] = best_rotation 237 | best_probs[i, j] = best_prob 238 | 239 | rotations_pred, edges = compute_mst( 240 | num_frames=num_frames, 241 | best_probs=best_probs, 242 | best_rotations=best_rotations, 243 | ) 244 | 245 | R_pred_batched = rotations_pred[permutations] 246 | R_pred_rel = np.einsum( 247 | "Bij,Bjk ->Bik", 248 | R_pred_batched[:, 0].transpose(0, 2, 1), 249 | R_pred_batched[:, 1], 250 | ) 251 | R_gt_batched = rotations_gt[permutations] 252 | R_gt_rel = np.einsum( 253 | "Bij,Bjk ->Bik", 254 | R_gt_batched[:, 0].transpose(0, 2, 1), 255 | R_gt_batched[:, 1], 256 | ) 257 | errors = compute_angular_error_batch(R_pred_rel, R_gt_rel) 258 | angular_errors.extend(errors) 259 | all_errors[sequence_name] = { 260 | "R_pred": rotations_pred.tolist(), 261 | "R_gt": rotations_gt.tolist(), 262 | "angular_errors": errors.tolist(), 263 | "edges": edges, 264 | } 265 | if save_dir is not None: 266 | with open(path, "w") as f: 267 | json.dump(all_errors, f) 268 | return np.array(angular_errors) 269 | 270 | 271 | def 
evaluate_category_coord_asc( 272 | model, 273 | category, 274 | split="train", 275 | num_iterations=200, 276 | num_frames=5, 277 | use_pbar=False, 278 | save_dir=None, 279 | force=False, 280 | dataset="co3dv1", 281 | num_queries=250_000, 282 | skip=1, 283 | index=0, 284 | random_order=False, 285 | ): 286 | dataset = get_eval_dataset(category=category, split=split, dataset=dataset) 287 | device = next(model.parameters()).device 288 | permutations = get_permutations(num_frames) 289 | 290 | if random_order: 291 | with open(osp.join("data", "co3d_v2_random_order_0", f"{category}.json")) as f: 292 | order = json.load(f) 293 | 294 | angular_errors = [] 295 | iterator = np.arange(len(dataset))[index::skip] 296 | for i in tqdm(iterator): 297 | metadata = dataset[i] 298 | n = metadata["n"] 299 | if num_frames > n: 300 | continue 301 | sequence_name = metadata["model_id"] 302 | 303 | r = "-random" if random_order else "-uniform" 304 | output_file = osp.join( 305 | save_dir, f"{category}-{sequence_name}-{split}-{num_frames:03d}{r}.json" 306 | ) 307 | if osp.exists(output_file) and not force: 308 | with open(output_file) as f: 309 | data = json.load(f) 310 | angular_errors.extend(data["errors"]) 311 | continue 312 | 313 | if random_order: 314 | key_frames = sorted(order[sequence_name][:num_frames]) 315 | else: 316 | key_frames = np.linspace( 317 | 0, n - 1, num=num_frames, dtype=int, endpoint=False, 318 | ) 319 | batch = dataset.get_data(sequence_name=sequence_name, ids=key_frames) 320 | images = batch["image"].to(device) 321 | features = model.feature_extractor(images) 322 | rotations = batch["R"] 323 | rotations_gt = rotations.numpy() 324 | 325 | mst_path = osp.join( 326 | save_dir, "../mst", f"{category}-{split}-mst-{num_frames:03d}{r}.json" 327 | ) 328 | with open(mst_path) as f: 329 | mst_data = json.load(f) 330 | 331 | initial_hypothesis = np.array(mst_data[sequence_name]["R_pred"]) 332 | rotations_pred = run_coordinate_ascent( 333 | model=model, 334 | images=images, 335 | num_frames=num_frames, 336 | initial_hypothesis=initial_hypothesis, 337 | num_iterations=num_iterations, 338 | num_queries=num_queries, 339 | use_pbar=use_pbar, 340 | ) 341 | score = score_hypothesis( 342 | model=model, 343 | hypothesis=rotations_pred, 344 | permutations=torch.from_numpy(permutations), 345 | features=features, 346 | ) 347 | rotations_pred = rotations_pred.cpu().numpy() 348 | R_pred_batched = rotations_pred[permutations] 349 | R_pred_rel = np.einsum( 350 | "Bij,Bjk ->Bik", 351 | R_pred_batched[:, 0].transpose(0, 2, 1), 352 | R_pred_batched[:, 1], 353 | ) 354 | R_gt_batched = rotations_gt[permutations] 355 | R_gt_rel = np.einsum( 356 | "Bij,Bjk ->Bik", 357 | R_gt_batched[:, 0].transpose(0, 2, 1), 358 | R_gt_batched[:, 1], 359 | ) 360 | errors = compute_angular_error_batch(R_pred_rel, R_gt_rel) 361 | 362 | output_data = { 363 | "joint_score": score.item(), 364 | "errors": errors.tolist(), 365 | "R_pred": rotations_pred.tolist(), 366 | } 367 | with open(output_file, "w") as f: 368 | json.dump(output_data, f) 369 | angular_errors.extend(errors) 370 | return np.array(angular_errors) 371 | 372 | 373 | def evaluate_joint( 374 | model=None, 375 | checkpoint_path=None, 376 | dataset="co3dv1", 377 | categories_type="seen", 378 | split="test", 379 | num_frames=5, 380 | print_results=True, 381 | use_pbar=False, 382 | mode="sequential", 383 | save_output=True, 384 | force=False, 385 | num_queries=250_000, 386 | num_iterations=200, 387 | random_order=False, 388 | reverse=False, 389 | index=0, 390 | skip=1, 391 | ): 392 | if 
model is None or params is None: 393 | print(checkpoint_path) 394 | model, params = get_model(checkpoint_path) 395 | 396 | if save_output: 397 | if ".pth" in checkpoint_path: 398 | model_dir = osp.dirname(osp.dirname(checkpoint_path)) 399 | else: 400 | model_dir = checkpoint_path 401 | save_dir = osp.join(model_dir, "eval", mode) 402 | os.makedirs(save_dir, exist_ok=True) 403 | else: 404 | save_dir = None 405 | 406 | eval_map = { 407 | "sequential": evaluate_category_sequential, 408 | "mst": evaluate_category_mst, 409 | "coord_asc": evaluate_category_coord_asc, 410 | } 411 | eval_function = eval_map[mode] 412 | 413 | errors_15 = {} 414 | errors_30 = {} 415 | 416 | if categories_type == "seen": 417 | categories = TRAINING_CATEGORIES 418 | elif categories_type == "unseen": 419 | categories = TEST_CATEGORIES 420 | elif categories_type == "all": 421 | categories = TRAINING_CATEGORIES + TEST_CATEGORIES 422 | else: 423 | raise Exception(f"Unknown categories type: {categories_type}") 424 | categories = categories[index::skip] 425 | if reverse: 426 | categories = categories[::-1] 427 | for category in categories: 428 | angular_errors = eval_function( 429 | model=model, 430 | dataset=dataset, 431 | category=category, 432 | split=split, 433 | num_frames=num_frames, 434 | num_iterations=num_iterations, 435 | use_pbar=use_pbar, 436 | save_dir=save_dir, 437 | force=force, 438 | num_queries=num_queries, 439 | random_order=random_order, 440 | ) 441 | errors_15[category] = np.mean(angular_errors < 15) 442 | errors_30[category] = np.mean(angular_errors < 30) 443 | 444 | errors_15["mean"] = np.mean(list(errors_15.values())) 445 | errors_30["mean"] = np.mean(list(errors_30.values())) 446 | if print_results: 447 | print(f"{'Category':>10s}{'<15':6s}{'<30':6s}") 448 | for category in errors_15.keys(): 449 | print( 450 | f"{category:>10s}{errors_15[category]:6.02f}{errors_30[category]:6.02f}" 451 | ) 452 | if index == 0 and skip == 1: 453 | r = "random" if random_order else "uniform" 454 | output_path = osp.join( 455 | model_dir, 456 | "eval", 457 | f"{categories_type}-{mode}-{r}-N{num_frames:02d}.json", 458 | ) 459 | with open(output_path, "w") as f: 460 | json.dump({"errors_15": errors_15, "errors_30": errors_30}, f) 461 | return errors_15, errors_30 462 | 463 | 464 | if __name__ == "__main__": 465 | args = get_parser().parse_args() 466 | evaluate_joint( 467 | checkpoint_path=args.checkpoint, 468 | num_frames=args.num_frames, 469 | mode=args.mode, 470 | print_results=True, 471 | use_pbar=args.use_pbar, 472 | force=args.force, 473 | split=args.split, 474 | dataset=args.dataset, 475 | num_queries=args.num_queries, 476 | num_iterations=args.num_iterations, 477 | random_order=args.random_order, 478 | reverse=args.reverse, 479 | index=args.index, 480 | skip=args.skip, 481 | categories_type=args.categories_type, 482 | ) 483 | -------------------------------------------------------------------------------- /relpose/eval/eval_pairwise.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for pairwise evaluation of predictor (ie, given 2 images, compute accuracy of 3 | highest scoring mode). 4 | 5 | Note that here, num_frames refers to the number of images sampled from the sequence. 6 | The input frames will be all NP2 permutations of using those image frames for pairwise 7 | evaluation. 
8 | """ 9 | 10 | import argparse 11 | 12 | import numpy as np 13 | import torch 14 | from tqdm.auto import tqdm 15 | 16 | from relpose.dataset.co3d import Co3dDataset 17 | from relpose.dataset.co3dv1 import TEST_CATEGORIES, Co3dv1Dataset 18 | from relpose.eval.load_model import get_model 19 | 20 | 21 | def get_parser(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument("--checkpoint", type=str, required=True) 24 | parser.add_argument("--num_frames", type=int, default=5) 25 | parser.add_argument("--use_pbar", action="store_true") 26 | return parser 27 | 28 | 29 | def compute_angular_error(rotation1, rotation2): 30 | R_rel = rotation1.T @ rotation2 31 | tr = (np.trace(R_rel) - 1) / 2 32 | theta = np.arccos(tr.clip(-1, 1)) 33 | return theta * 180 / np.pi 34 | 35 | 36 | def compute_angular_error_batch(rotation1, rotation2): 37 | R_rel = np.einsum("Bij,Bjk ->Bik", rotation1.transpose(0, 2, 1), rotation2) 38 | t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2 39 | theta = np.arccos(np.clip(t, -1, 1)) 40 | return theta * 180 / np.pi 41 | 42 | 43 | def get_permutations(num_frames): 44 | permutations = [] 45 | for i in range(num_frames): 46 | for j in range(num_frames): 47 | if i != j: 48 | permutations.append((i, j)) 49 | return torch.tensor(permutations) 50 | 51 | 52 | def get_dataset(category="banana", split="train", params={}, dataset="co3dv1"): 53 | if dataset == "co3dv1": 54 | return Co3dv1Dataset( 55 | split=split, 56 | category=[category], 57 | random_aug=False, 58 | ) 59 | elif dataset in ["co3d", "co3dv2"]: 60 | return Co3dDataset( 61 | split=split, 62 | category=[category], 63 | random_aug=False, 64 | ) 65 | else: 66 | raise Exception(f"Unknown dataset {dataset}") 67 | 68 | 69 | def evaluate_category( 70 | model, 71 | params, 72 | category="banana", 73 | split="train", 74 | num_frames=5, 75 | use_pbar=False, 76 | dataset="co3dv1", 77 | ): 78 | dataset = get_dataset( 79 | category=category, split=split, params=params, dataset=dataset 80 | ) 81 | device = next(model.parameters()).device 82 | 83 | permutations = get_permutations(num_frames) 84 | angular_errors = [] 85 | iterable = tqdm(dataset) if use_pbar else dataset 86 | for metadata in iterable: 87 | n = metadata["n"] 88 | sequence_name = metadata["model_id"] 89 | key_frames = np.linspace(0, n - 1, num=num_frames, dtype=int) 90 | batch = dataset.get_data(sequence_name=sequence_name, ids=key_frames) 91 | images = batch["image"] 92 | rotations = batch["R"] 93 | images_permuted = images[permutations] 94 | rotations_permuted = rotations[permutations] 95 | rotations_gt = torch.bmm( 96 | rotations_permuted[:, 0].transpose(1, 2), 97 | rotations_permuted[:, 1], 98 | ) 99 | images1 = images_permuted[:, 0].to(device) 100 | images2 = images_permuted[:, 1].to(device) 101 | 102 | for i in range(len(permutations)): 103 | image1 = images1[i] 104 | image2 = images2[i] 105 | rotation_gt = rotations_gt[i] 106 | 107 | with torch.no_grad(): 108 | queries, logits = model( 109 | images1=image1.unsqueeze(0), 110 | images2=image2.unsqueeze(0), 111 | gt_rotation=rotation_gt.to(device).unsqueeze(0), 112 | ) 113 | 114 | probabilities = torch.softmax(logits, -1) 115 | probabilities = probabilities[0].detach().cpu().numpy() 116 | best_prob = probabilities.argmax() 117 | best_rotation = queries[0].detach().cpu().numpy()[best_prob] 118 | angular_errors.append( 119 | compute_angular_error(rotation_gt.numpy(), best_rotation) 120 | ) 121 | return np.array(angular_errors) 122 | 123 | 124 | def evaluate_pairwise( 125 | model=None, 126 | params=None, 127 | 
checkpoint_path=None, 128 | split="train", 129 | num_frames=5, 130 | print_results=True, 131 | use_pbar=False, 132 | categories=TEST_CATEGORIES, 133 | dataset="co3dv1", 134 | ): 135 | if model is None or params is None: 136 | print(checkpoint_path) 137 | model, params = get_model(checkpoint_path) 138 | 139 | errors_15 = {} 140 | errors_30 = {} 141 | for category in categories: 142 | angular_errors = evaluate_category( 143 | model=model, 144 | params=params, 145 | category=category, 146 | split=split, 147 | num_frames=num_frames, 148 | use_pbar=use_pbar, 149 | dataset=dataset, 150 | ) 151 | errors_15[category] = np.mean(angular_errors < 15) 152 | errors_30[category] = np.mean(angular_errors < 30) 153 | 154 | errors_15["mean"] = np.mean(list(errors_15.values())) 155 | errors_30["mean"] = np.mean(list(errors_30.values())) 156 | if print_results: 157 | print(f"{'Category':>10s}{'<15':6s}{'<30':6s}") 158 | for category in errors_15.keys(): 159 | print( 160 | f"{category:>10s}{errors_15[category]:6.02f}{errors_30[category]:6.02f}" 161 | ) 162 | return errors_15, errors_30 163 | 164 | 165 | if __name__ == "__main__": 166 | args = get_parser().parse_args() 167 | evaluate_pairwise( 168 | checkpoint_path=args.checkpoint, 169 | num_frames=args.num_frames, 170 | print_results=True, 171 | use_pbar=args.use_pbar, 172 | split="test", 173 | ) 174 | -------------------------------------------------------------------------------- /relpose/eval/load_model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import os.path as osp 4 | 5 | import torch 6 | 7 | from relpose.dataset import Co3dDataset, Co3dv1Dataset 8 | from relpose.models import RelPose 9 | 10 | 11 | def get_model(checkpoint, device="cuda:0"): 12 | """ 13 | Loads a model from a checkpoint and any associated metadata. 
14 | """ 15 | if ".pth" not in checkpoint: 16 | checkpoint_dir = osp.join(checkpoint, "checkpoints") 17 | last_checkpoint = sorted(os.listdir(checkpoint_dir))[-1] 18 | print(f"Loading checkpoint {last_checkpoint}") 19 | checkpoint = osp.join(checkpoint_dir, last_checkpoint) 20 | pretrained_weights = torch.load(checkpoint, map_location=device)["state_dict"] 21 | pretrained_weights = { 22 | k.replace("module.", ""): v for k, v in pretrained_weights.items() 23 | } 24 | args_path = osp.join(osp.dirname(osp.dirname(checkpoint)), "args.json") 25 | if osp.exists(args_path): 26 | with open(args_path) as f: 27 | args = json.load(f) 28 | args["output_dir"] = osp.dirname(osp.dirname(checkpoint)) 29 | else: 30 | args = {} 31 | relpose = RelPose(sample_mode="equivolumetric") 32 | relpose.to(device) 33 | relpose.load_state_dict(pretrained_weights) 34 | relpose.eval() 35 | return relpose, args 36 | 37 | 38 | def get_eval_dataset(category, split, dataset="co3dv1"): 39 | if isinstance(category, str): 40 | category = [category] 41 | if dataset == "co3dv1": 42 | dataset = Co3dv1Dataset(category, split, random_aug=False) 43 | elif dataset in ["co3d", "co3dv2"]: 44 | dataset = Co3dDataset(category, split, random_aug=False) 45 | else: 46 | raise ValueError(f"Unknown dataset {dataset}") 47 | return dataset 48 | -------------------------------------------------------------------------------- /relpose/eval/pairwise.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import os.path as osp 5 | 6 | import numpy as np 7 | import torch 8 | import torchvision 9 | from tqdm.auto import tqdm 10 | 11 | from relpose.dataset.co3dv1 import TEST_CATEGORIES 12 | from relpose.eval.load_model import get_eval_dataset, get_model 13 | from relpose.utils import get_permutations 14 | 15 | 16 | def compute_angular_error(rotation1, rotation2): 17 | R_rel = rotation1.T @ rotation2 18 | tr = (np.trace(R_rel) - 1) / 2 19 | theta = np.arccos(tr.clip(-1, 1)) 20 | return theta * 180 / np.pi 21 | 22 | 23 | def compute_angular_error_batch(rotation1, rotation2): 24 | R_rel = np.einsum("Bij,Bjk ->Bik", rotation1.transpose(0, 2, 1), rotation2) 25 | t = (np.trace(R_rel, axis1=1, axis2=2) - 1) / 2 26 | theta = np.arccos(np.clip(t, -1, 1)) 27 | return theta * 180 / np.pi 28 | 29 | 30 | def evaluate_category( 31 | model, 32 | params, 33 | category="banana", 34 | split="train", 35 | num_frames=5, 36 | use_pbar=False, 37 | ): 38 | dataset = get_eval_dataset(category=category, split=split, params=params) 39 | device = next(model.parameters()).device 40 | 41 | permutations = get_permutations(num_frames) 42 | angular_errors = [] 43 | iterable = tqdm(dataset) if use_pbar else dataset 44 | for metadata in iterable: 45 | n = metadata["n"] 46 | sequence_name = metadata["model_id"] 47 | key_frames = np.linspace(0, n - 1, num=num_frames, dtype=int) 48 | batch = dataset.get_data(sequence_name=sequence_name, ids=key_frames) 49 | images = batch["image"] 50 | rotations = batch["R"] 51 | images_permuted = images[permutations] 52 | rotations_permuted = rotations[permutations] 53 | rotations_gt = torch.bmm( 54 | rotations_permuted[:, 1], 55 | rotations_permuted[:, 0].transpose(1, 2), 56 | ) 57 | images1 = images_permuted[:, 0].to(device) 58 | images2 = images_permuted[:, 1].to(device) 59 | 60 | for i in range(len(permutations)): 61 | image1 = images1[i] 62 | image2 = images2[i] 63 | rotation_gt = rotations_gt[i] 64 | 65 | with torch.no_grad(): 66 | queries, logits = model( 67 | 
images1=image1.unsqueeze(0), 68 | images2=image2.unsqueeze(0), 69 | recursion_level=4, 70 | gt_rotation=rotation_gt.to(device).unsqueeze(0), 71 | ) 72 | 73 | probabilities = torch.softmax(logits, -1) 74 | probabilities = probabilities[0].detach().cpu().numpy() 75 | best_prob = probabilities.argmax() 76 | best_rotation = queries[0].detach().cpu().numpy()[best_prob] 77 | angular_errors.append( 78 | compute_angular_error(rotation_gt.numpy(), best_rotation) 79 | ) 80 | return np.array(angular_errors) 81 | 82 | 83 | def evaluate_pairwise( 84 | model=None, 85 | params=None, 86 | checkpoint_path=None, 87 | split="train", 88 | num_frames=5, 89 | print_results=True, 90 | use_pbar=False, 91 | categories=TEST_CATEGORIES, 92 | ): 93 | if model is None or params is None: 94 | print(checkpoint_path) 95 | model, params = get_model(checkpoint_path) 96 | 97 | errors_15 = {} 98 | errors_30 = {} 99 | for category in categories: 100 | angular_errors = evaluate_category( 101 | model=model, 102 | params=params, 103 | category=category, 104 | split=split, 105 | num_frames=num_frames, 106 | use_pbar=use_pbar, 107 | ) 108 | errors_15[category] = np.mean(angular_errors < 15) 109 | errors_30[category] = np.mean(angular_errors < 30) 110 | 111 | errors_15["mean"] = np.mean(list(errors_15.values())) 112 | errors_30["mean"] = np.mean(list(errors_30.values())) 113 | if print_results: 114 | print(f"{'Category':>10s}{'<15':6s}{'<30':6s}") 115 | for category in errors_15.keys(): 116 | print( 117 | f"{category:>10s}{errors_15[category]:6.02f}{errors_30[category]:6.02f}" 118 | ) 119 | return errors_15, errors_30 120 | -------------------------------------------------------------------------------- /relpose/inference/__init__.py: -------------------------------------------------------------------------------- 1 | from .joint_inference import ( 2 | get_permutations, 3 | run_coordinate_ascent, 4 | run_maximum_spanning_tree, 5 | ) 6 | -------------------------------------------------------------------------------- /relpose/inference/joint_inference.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from tqdm.auto import tqdm 4 | 5 | from relpose.utils.geometry import generate_random_rotations 6 | 7 | 8 | def get_permutations(num_frames): 9 | permutations = [] 10 | for i in range(num_frames): 11 | for j in range(num_frames): 12 | if i != j: 13 | permutations.append((i, j)) 14 | return torch.tensor(permutations) 15 | 16 | 17 | def score_hypothesis(hypothesis, model, permutations, features): 18 | R_pred_batched = hypothesis[permutations] 19 | R_pred_rel = torch.einsum( 20 | "bij,bjk ->bik", 21 | R_pred_batched[:, 0].permute(0, 2, 1), 22 | R_pred_batched[:, 1], 23 | ) 24 | features_batched = features[permutations] 25 | _, logits = model( 26 | features1=features_batched[:, 0], 27 | features2=features_batched[:, 1], 28 | queries=R_pred_rel.unsqueeze(0), 29 | ) 30 | score = torch.trace(logits) 31 | return score 32 | 33 | 34 | def compute_mst(num_frames, best_probs, best_rotations): 35 | """ 36 | Computes the maximum spanning tree of the graph defined by the best_probs. Uses 37 | Prim's algorithm (modified for a directed graph). 
38 | Currently a naive O(N^3) implementation :P 39 | """ 40 | current_assigned = {0} 41 | assigned_rotations = np.tile(np.eye(3), [num_frames, 1, 1]) 42 | 43 | edges = [] 44 | 45 | while len(current_assigned) < num_frames: 46 | # Find the highest probability edge that connects an unassigned node to the MST 47 | best_i = -1 48 | best_j = -1 49 | best_p = -1 50 | not_assigned = set(range(num_frames)) - current_assigned 51 | for i in current_assigned: 52 | for j in not_assigned: 53 | if best_probs[i, j] > best_p: 54 | best_p = best_probs[i, j] 55 | best_i = i 56 | best_j = j 57 | if best_probs[j, i] > best_p: 58 | best_p = best_probs[j, i] 59 | best_i = j 60 | best_j = i 61 | 62 | rot = best_rotations[best_i, best_j] 63 | if best_i in current_assigned: 64 | current_assigned.add(best_j) 65 | assigned_rotations[best_j] = assigned_rotations[best_i] @ rot 66 | else: 67 | current_assigned.add(best_i) 68 | assigned_rotations[best_i] = assigned_rotations[best_j] @ rot.T 69 | edges.append((best_i, best_j)) 70 | 71 | return assigned_rotations, edges 72 | 73 | 74 | def run_maximum_spanning_tree(model, images, num_frames): 75 | device = images.device 76 | permutations = get_permutations(num_frames) 77 | best_rotations = np.zeros((num_frames, num_frames, 3, 3)) 78 | best_probs = np.zeros((num_frames, num_frames)) 79 | for i, j in permutations: 80 | image1 = images[i].unsqueeze(0).to(device) 81 | image2 = images[j].unsqueeze(0).to(device) 82 | with torch.no_grad(): 83 | queries, logits = model( 84 | images1=image1, 85 | images2=image2, 86 | ) 87 | probabilities = torch.softmax(logits, -1) 88 | probabilities = probabilities[0].detach().cpu().numpy() 89 | best_prob = probabilities.max() 90 | best_rotation = queries[0].detach().cpu().numpy()[probabilities.argmax()] 91 | 92 | best_rotations[i, j] = best_rotation 93 | best_probs[i, j] = best_prob 94 | 95 | rotations_pred, edges = compute_mst( 96 | num_frames=num_frames, 97 | best_probs=best_probs, 98 | best_rotations=best_rotations, 99 | ) 100 | return rotations_pred 101 | 102 | 103 | def run_coordinate_ascent( 104 | model, 105 | images, 106 | num_frames, 107 | initial_hypothesis, 108 | num_iterations=200, 109 | num_queries=250_000, 110 | use_pbar=True, 111 | ): 112 | """ 113 | Args: 114 | model (nn.Module): RelPose model. 115 | images (torch.Tensor): Tensor of shape (N, 3, H, W) containing the images. 116 | num_frames (int): Number of frames in the sequence. 117 | initial_hypothesis (np.ndarray): Initial hypothesis of shape (N, 3, 3). 118 | num_iterations (int): Number of iterations to run coordinate ascent. Defaults 119 | to 200. 120 | num_queries (int): Number of queries to use for each coordinate ascent. Defaults 121 | to 250,000. 122 | use_pbar (bool): Whether to use a progress bar. Defaults to True. 123 | 124 | Returns: 125 | torch.tensor: Final hypothesis of shape (N, 3, 3). 
126 | """ 127 | device = images.device 128 | hypothesis = torch.from_numpy(initial_hypothesis).to(device).float() 129 | features = model.feature_extractor(images.to(device)) 130 | it = tqdm(range(num_iterations)) if use_pbar else range(num_iterations) 131 | for j in it: 132 | # Randomly sample an index to update 133 | k = np.random.choice(num_frames) 134 | proposals = generate_random_rotations(num_queries, device) 135 | proposals[0] = hypothesis[k] 136 | scores = torch.zeros(1, num_queries, device=device) 137 | for i in range(num_frames): 138 | if i == k: 139 | continue 140 | feature1 = features[i, None] 141 | feature2 = features[k, None] 142 | R_rel = hypothesis[i].T @ proposals 143 | with torch.no_grad(): 144 | _, logits = model( 145 | features1=feature1, 146 | features2=feature2, 147 | queries=R_rel.unsqueeze(0), 148 | ) 149 | scores += logits 150 | _, logits = model( 151 | features1=feature2, 152 | features2=feature1, 153 | queries=R_rel.transpose(1, 2).unsqueeze(0), 154 | ) 155 | scores += logits 156 | best_ind = scores.argmax() 157 | hypothesis[k] = proposals[best_ind] 158 | return hypothesis 159 | -------------------------------------------------------------------------------- /relpose/models.py: -------------------------------------------------------------------------------- 1 | import antialiased_cnns 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from relpose.utils.geometry import generate_random_rotations, generate_superfibonacci 7 | 8 | 9 | def generate_hypotheses(rotations_gt, num_queries=50000): 10 | """ 11 | Args: 12 | rotations_gt (tensor): Batched rotations (B, N_I, 3, 3). 13 | 14 | Returns: 15 | hypotheses (tensor): Hypotheses (B, N_I, N_Q, 3, 3). 16 | """ 17 | batch_size, num_images, _, _ = rotations_gt.shape 18 | hypotheses = generate_random_rotations( 19 | (num_queries - 1) * batch_size * num_images, device=rotations_gt.device 20 | ) 21 | # (B, N_i, N_q - 1, 3, 3) 22 | hypotheses = hypotheses.reshape(batch_size, num_images, (num_queries - 1), 3, 3) 23 | # (B, N_i, N_q, 3, 3) 24 | hypotheses = torch.cat((rotations_gt.unsqueeze(2), hypotheses), dim=2) 25 | return hypotheses 26 | 27 | 28 | def get_feature_extractor(): 29 | """ 30 | Returns a network that takes in images (B, 3, 224, 224) and outputs a feature 31 | vector (B, 2048, 1, 1). 32 | """ 33 | model = antialiased_cnns.resnet50(pretrained=True) 34 | feature_extractor = torch.nn.Sequential(*(list(model.children())[:-1])) 35 | return feature_extractor 36 | 37 | 38 | class RelPose(nn.Module): 39 | def __init__( 40 | self, 41 | feature_extractor=None, 42 | num_pe_bases=8, 43 | num_layers=4, 44 | hidden_size=256, 45 | num_queries=50000, 46 | sample_mode="random", 47 | freeze_encoder=False, 48 | ): 49 | """ 50 | Args: 51 | feature_extractor (nn.Module): Feature extractor. 52 | num_pe_bases (int): Number of positional encoding bases. 53 | num_layers (int): Number of layers in the network. 54 | hidden_size (int): Size of the hidden layer. 55 | num_queries (int): Number of rotations to sample if using random sampling. 56 | sample_mode (str): Sampling mode. Can be equivolumetric or random. 
57 | """ 58 | super().__init__() 59 | if feature_extractor is None: 60 | feature_extractor = get_feature_extractor() 61 | self.num_queries = num_queries 62 | self.sample_mode = sample_mode 63 | 64 | self.feature_extractor = feature_extractor 65 | if freeze_encoder: 66 | self.freeze_encoder() 67 | 68 | self.use_positional_encoding = num_pe_bases > 0 69 | if self.use_positional_encoding: 70 | query_size = num_pe_bases * 2 * 9 71 | self.register_buffer( 72 | "embedding", (2 ** torch.arange(num_pe_bases)).reshape(1, 1, -1) 73 | ) 74 | else: 75 | query_size = 9 76 | 77 | self.embed_feature = nn.Linear(2048 * 2, hidden_size) 78 | self.embed_query = nn.Linear(query_size, hidden_size) 79 | layers = [] 80 | for _ in range(num_layers - 2): 81 | layers.append(nn.LeakyReLU()) 82 | layers.append(nn.Linear(hidden_size, hidden_size)) 83 | layers.append(nn.LeakyReLU()) 84 | layers.append(nn.Linear(hidden_size, 1)) 85 | self.layers = nn.Sequential(*layers) 86 | self.equi_grid = {} 87 | 88 | def freeze_encoder(self): 89 | for param in self.feature_extractor.parameters(): 90 | param.requires_grad = False 91 | 92 | def positional_encoding(self, x): 93 | """ 94 | Args: 95 | x (tensor): Input (B, D). 96 | 97 | Returns: 98 | y (tensor): Positional encoding (B, 2 * D * L). 99 | """ 100 | if not self.use_positional_encoding: 101 | return x 102 | embed = (x[..., None] * self.embedding).view(*x.shape[:-1], -1) 103 | return torch.cat((embed.sin(), embed.cos()), dim=-1) 104 | 105 | def forward( 106 | self, 107 | images1=None, 108 | images2=None, 109 | features1=None, 110 | features2=None, 111 | gt_rotation=None, 112 | num_queries=None, 113 | queries=None, 114 | ): 115 | """ 116 | Must provide either images1 and images2 or features1 and features2. If 117 | gt_rotation is provided, the first query will be the ground truth rotation. 118 | 119 | Args: 120 | images1 (tensor): First set of images (B, 3, 224, 224). 121 | images2 (tensor): Corresponding set of images (B, 3, 224, 224). 122 | gt_rotation (tensor): Ground truth rotation (B, 3, 3). 123 | num_queries (int): Number of rotations to sample if using random sampling. 124 | 125 | Returns: 126 | rotations (tensor): Rotation matrices (B, num_queries, 3, 3). First query 127 | is the ground truth rotation. 128 | logits (tensor): logits (B, num_queries). 
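        Illustrative calls (tensor names are placeholders):

            queries, logits = model(images1=im1, images2=im2)  # sample queries internally
            queries, logits = model(features1=f1, features2=f2, queries=R_rel)  # score given rotations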
129 | """ 130 | 131 | if features1 is None: 132 | features1 = self.feature_extractor(images1) 133 | if features2 is None: 134 | features2 = self.feature_extractor(images2) 135 | features = torch.cat([features1, features2], dim=1) 136 | 137 | batch_size = features1.size(0) 138 | assert batch_size == features2.size(0) 139 | features = features.reshape(batch_size, -1) # (B, 4096) 140 | if queries is None: 141 | if num_queries is None: 142 | num_queries = self.num_queries 143 | if self.sample_mode == "equivolumetric": 144 | if num_queries not in self.equi_grid: 145 | self.equi_grid[num_queries] = generate_superfibonacci( 146 | num_queries, device="cpu" 147 | ) 148 | queries = self.equi_grid[num_queries].to(images1.device) 149 | elif self.sample_mode == "random": 150 | queries = generate_random_rotations(num_queries, device=images1.device) 151 | else: 152 | raise Exception(f"Unknown sampling mode {self.sample_mode}.") 153 | 154 | if gt_rotation is not None: 155 | delta_rot = queries[0].T @ gt_rotation 156 | # First entry will always be the gt rotation 157 | queries = torch.einsum("aij,bjk->baik", queries, delta_rot) 158 | else: 159 | if len(queries.shape) == 3: 160 | queries = queries.unsqueeze(0) 161 | num_queries = queries.shape[1] 162 | else: 163 | num_queries = queries.shape[1] 164 | 165 | queries_pe = self.positional_encoding(queries.reshape(-1, num_queries, 9)) 166 | 167 | e_f = self.embed_feature(features).unsqueeze(1) # (B, 1, H) 168 | e_q = self.embed_query(queries_pe) # (B, n_q, H) 169 | out = self.layers(e_f + e_q) # (B, n_q, 1) 170 | logits = out.reshape(batch_size, num_queries) 171 | return queries, logits 172 | 173 | def predict_probability(self, images1, images2, query_rotation, num_queries=None): 174 | """ 175 | Args: 176 | images1 (tensor): First set of images (B, 3, 224, 224). 177 | images2 (tensor): Corresponding set of images (B, 3, 224, 224). 178 | gt_rotation (tensor): Ground truth rotation (B, 3, 3). 179 | num_queries (int): Number of rotations to sample. If gt_rotation is given 180 | will sample num_queries - batch size. 181 | 182 | Returns: 183 | probabilities 184 | """ 185 | logits = self.forward( 186 | images1, 187 | images2, 188 | gt_rotation=query_rotation, 189 | num_queries=num_queries, 190 | ) 191 | probabilities = torch.softmax(logits, dim=-1) 192 | probabilities = probabilities * num_queries / np.pi**2 193 | return probabilities[:, 0] 194 | -------------------------------------------------------------------------------- /relpose/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trainer for relpose. Default configuration is run with 4 GPUs. 
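Training minimizes the negative log-likelihood of the ground-truth relative
rotation, which is always included as the first query rotation.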
3 | 4 | Usage: 5 | python -m relpose.trainer --batch_size 64 --num_gpus 4 --output_dir output 6 | """ 7 | import argparse 8 | import datetime 9 | import json 10 | import os 11 | import os.path as osp 12 | import shutil 13 | import time 14 | from glob import glob 15 | 16 | import cv2 17 | import matplotlib 18 | import numpy as np 19 | import torch 20 | from torch.nn import DataParallel 21 | from torch.utils.tensorboard import SummaryWriter 22 | 23 | from relpose.dataset import get_dataloader 24 | from relpose.dataset.co3d import TEST_CATEGORIES, TRAINING_CATEGORIES 25 | from relpose.eval.eval_pairwise import evaluate_pairwise 26 | from relpose.models import RelPose 27 | from relpose.utils.visualize import unnormalize_image, visualize_so3_probabilities 28 | 29 | matplotlib.use("Agg") 30 | 31 | 32 | def get_parser(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--category", nargs="+", type=str, default=["all"]) 35 | parser.add_argument("--batch_size", type=int, default=64) 36 | parser.add_argument("--num_iterations", type=int, default=400_000) 37 | parser.add_argument("--interval_checkpoint", type=int, default=1000) 38 | parser.add_argument( 39 | "--interval_delete_checkpoint", 40 | type=int, 41 | default=10000, 42 | help="Interval to delete old checkpoints.", 43 | ) 44 | parser.add_argument("--interval_visualize", type=int, default=1000) 45 | parser.add_argument("--interval_evaluate", type=int, default=25000) 46 | parser.add_argument("--lr", type=float, default=0.001) 47 | parser.add_argument("--output_dir", type=str, default="output") 48 | parser.add_argument( 49 | "--dataset", 50 | type=str, 51 | default="co3d", 52 | help="co3d or co3dv1. co3d refers to co3dv2", 53 | ) 54 | parser.add_argument("--name", type=str, default="") 55 | parser.add_argument( 56 | "--sampling_mode", 57 | type=str, 58 | default="equivolumetric", 59 | help="Sampling mode can be equivolumetric or random.", 60 | ) 61 | parser.add_argument("--resume", default="", type=str, help="Path to directory.") 62 | parser.add_argument("--num_gpus", type=int, default=4) 63 | parser.add_argument( 64 | "--num_workers", type=int, default=None, help="Default: 4 * num_gpus" 65 | ) 66 | parser.add_argument("--debug", action="store_true") 67 | parser.add_argument( 68 | "--pretrained", 69 | default="", 70 | help="Path to pretrained model (to load weights from)", 71 | ) 72 | parser.add_argument( 73 | "--freeze_encoder", 74 | action="store_true", 75 | help="If True, freezes the image encoder.", 76 | ) 77 | return parser 78 | 79 | 80 | def get_permutations(num_images): 81 | for i in range(num_images): 82 | for j in range(num_images): 83 | if i != j: 84 | yield (i, j) 85 | 86 | 87 | class Trainer(object): 88 | def __init__(self, args) -> None: 89 | self.args = args 90 | self.batch_size = args.batch_size 91 | self.num_iterations = int(args.num_iterations) 92 | self.lr = args.lr 93 | self.dataset = args.dataset 94 | self.interval_visualize = args.interval_visualize 95 | self.interval_checkpoint = args.interval_checkpoint 96 | self.interval_delete_checkpoint = args.interval_delete_checkpoint 97 | self.interval_evaluate = args.interval_evaluate 98 | assert self.interval_delete_checkpoint % self.interval_checkpoint == 0 99 | self.debug = args.debug 100 | 101 | # Experiment settings: 102 | self.category = args.category 103 | self.freeze_encoder = args.freeze_encoder 104 | self.sampling_mode = args.sampling_mode 105 | 106 | self.iteration = 0 107 | self.epoch = 0 108 | 109 | num_workers = ( 110 | args.num_gpus * 4 if 
args.num_workers is None else args.num_workers 111 | ) 112 | if self.category[0] == "all": 113 | self.category = TRAINING_CATEGORIES 114 | print("preparing dataloader") 115 | self.dataloader = get_dataloader( 116 | category=self.category, 117 | dataset=self.dataset, 118 | split="train", 119 | batch_size=self.batch_size, 120 | num_workers=num_workers, 121 | debug=self.debug, 122 | ) 123 | self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 124 | self.net = RelPose( 125 | num_layers=4, 126 | num_pe_bases=8, 127 | hidden_size=256, 128 | sample_mode=args.sampling_mode, 129 | num_queries=36864, # To match Healpy recursion level = 3 130 | freeze_encoder=self.freeze_encoder, 131 | ) 132 | self.net = DataParallel(self.net, device_ids=list(range(args.num_gpus))) 133 | self.net.to(self.device) 134 | self.start_time = None 135 | 136 | # Setup output directory. 137 | name = datetime.datetime.now().strftime("%m%d_%H%M") 138 | if self.debug: 139 | name += "_debug" 140 | name += args.name 141 | name += f"_{args.dataset}" 142 | if "co3d" in args.dataset: 143 | if len(self.category) != len(TRAINING_CATEGORIES): 144 | name += f"{'-'.join(sorted(args.category))}" 145 | if args.sampling_mode != "equivolumetric": 146 | name += f"_{args.sampling_mode}" 147 | if self.batch_size != 64: 148 | name += f"_b{args.batch_size}" 149 | if args.lr != 0.001: 150 | name += f"_lr{args.lr}" 151 | 152 | if args.pretrained != "": 153 | name += "_pre" + osp.basename(args.pretrained)[:9] 154 | if args.freeze_encoder: 155 | name += "_freeze" 156 | 157 | # Resume checkpoint. 158 | if args.resume: 159 | self.output_dir = args.resume 160 | self.checkpoint_dir = osp.join(self.output_dir, "checkpoints") 161 | last_checkpoint = sorted(os.listdir(self.checkpoint_dir))[-1] 162 | self.load_model(osp.join(self.checkpoint_dir, last_checkpoint)) 163 | else: 164 | self.output_dir = osp.join(args.output_dir, name) 165 | self.checkpoint_dir = osp.join(self.output_dir, "checkpoints") 166 | os.makedirs(self.checkpoint_dir, exist_ok=True) 167 | 168 | with open(osp.join(self.output_dir, "args.json"), "w") as f: 169 | json.dump(vars(args), f) 170 | # Make a copy of the code. 171 | shutil.copytree("relpose", osp.join(self.output_dir, "relpose")) 172 | print("Output Directory:", self.output_dir) 173 | 174 | if args.pretrained != "": 175 | checkpoint_dir = osp.join(args.pretrained, "checkpoints") 176 | last_checkpoint = sorted(os.listdir(checkpoint_dir))[-1] 177 | self.load_model( 178 | osp.join(checkpoint_dir, last_checkpoint), load_metadata=False 179 | ) 180 | 181 | # Setup tensorboard. 
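        # Training loss, SO(3) visualizations, and pairwise evaluation metrics
        # are all logged to this writer (see train below).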
182 | self.writer = SummaryWriter(log_dir=self.output_dir, flush_secs=30) 183 | 184 | def train(self): 185 | optimizer = torch.optim.AdamW(self.net.parameters(), lr=self.lr) 186 | while self.iteration < self.num_iterations: 187 | for batch in self.dataloader: 188 | images = batch["image"].to(self.device, non_blocking=True) 189 | optimizer.zero_grad() 190 | image1 = images[:, 0] 191 | image2 = images[:, 1] 192 | 193 | relative_rotation = batch["relative_rotation"].to( 194 | self.device, non_blocking=True 195 | ) 196 | queries, logits = self.net( 197 | images1=image1, 198 | images2=image2, 199 | gt_rotation=relative_rotation, 200 | ) 201 | log_prob = torch.log_softmax(logits, dim=-1) 202 | loss = -torch.mean(log_prob[:, 0]) 203 | loss.backward() 204 | optimizer.step() 205 | 206 | if self.iteration % self.interval_checkpoint == 0: 207 | checkpoint_path = osp.join( 208 | self.checkpoint_dir, f"ckpt_{self.iteration:09d}.pth" 209 | ) 210 | self.save_model(checkpoint_path) 211 | 212 | if self.iteration % self.interval_visualize == 0: 213 | visuals = self.make_visualization( 214 | images1=image1, 215 | images2=image2, 216 | rotations=queries, 217 | probabilities=logits.softmax(dim=-1), 218 | model_id=batch["model_id"], 219 | category=batch["category"], 220 | ind1=batch["ind"][:, 0], 221 | ind2=batch["ind"][:, 1], 222 | ) 223 | for v, image in enumerate(visuals): 224 | self.writer.add_image( 225 | f"Visualization/{v}", 226 | image, 227 | self.iteration, 228 | dataformats="HWC", 229 | ) 230 | 231 | if self.iteration % 20 == 0: 232 | if self.start_time is None: 233 | self.start_time = time.time() 234 | time_elapsed = np.round(time.time() - self.start_time) 235 | time_remaining = np.round( 236 | (time.time() - self.start_time) 237 | / (self.iteration + 1) 238 | * (self.num_iterations - self.iteration) 239 | ) 240 | disp = [ 241 | f"Iter: {self.iteration:d}/{self.num_iterations:d}", 242 | f"Epoch: {self.epoch:d}", 243 | f"Loss: {loss.item():.3f}", 244 | f"Elap: {str(datetime.timedelta(seconds=time_elapsed))}", 245 | f"Rem: {str(datetime.timedelta(seconds=time_remaining))}", 246 | ] 247 | print(", ".join(disp)) 248 | self.writer.add_scalar("Loss/train", loss.item(), self.iteration) 249 | 250 | self.iteration += 1 251 | 252 | if self.iteration % self.interval_evaluate == 0: 253 | del images, image1, image2, queries, logits 254 | errors_15, errors_30 = evaluate_pairwise( 255 | self.net, 256 | params=vars(self.args), 257 | split="test", 258 | print_results=True, 259 | use_pbar=True, 260 | categories=TEST_CATEGORIES, 261 | dataset=self.dataset, 262 | ) 263 | for k, v in errors_15.items(): 264 | self.writer.add_scalar(f"Val/{k}@15", v, self.iteration) 265 | for k, v in errors_30.items(): 266 | self.writer.add_scalar(f"Val/{k}@30", v, self.iteration) 267 | 268 | if self.iteration % self.interval_delete_checkpoint == 0: 269 | self.clear_old_checkpoints(self.checkpoint_dir) 270 | 271 | if self.iteration >= self.num_iterations + 1: 272 | break 273 | self.epoch += 1 274 | 275 | def save_model(self, path): 276 | elapsed = time.time() - self.start_time if self.start_time is not None else 0 277 | save_dict = { 278 | "state_dict": self.net.state_dict(), 279 | "iteration": self.iteration, 280 | "epoch": self.epoch, 281 | "elapsed": elapsed, 282 | } 283 | torch.save(save_dict, path) 284 | 285 | def load_model(self, path, load_metadata=True): 286 | save_dict = torch.load(path) 287 | if "state_dict" in save_dict: 288 | self.net.load_state_dict(save_dict["state_dict"]) 289 | if load_metadata: 290 | self.iteration = 
save_dict["iteration"] 291 | self.epoch = save_dict["epoch"] 292 | if "elapsed" in save_dict: 293 | time_elapsed = save_dict["elapsed"] 294 | self.start_time = time.time() - time_elapsed 295 | else: 296 | self.net.load_state_dict(save_dict) 297 | 298 | def clear_old_checkpoints(self, checkpoint_dir): 299 | print("Clearing old checkpoints") 300 | checkpoint_files = glob(osp.join(checkpoint_dir, "ckpt_*.pth")) 301 | for checkpoint_file in checkpoint_files: 302 | checkpoint = osp.basename(checkpoint_file) 303 | checkpoint_iteration = int("".join(filter(str.isdigit, checkpoint))) 304 | if checkpoint_iteration % self.interval_delete_checkpoint != 0: 305 | os.remove(checkpoint_file) 306 | 307 | def make_visualization( 308 | self, 309 | images1, 310 | images2, 311 | rotations, 312 | probabilities, 313 | num_vis=5, 314 | model_id=None, 315 | category=None, 316 | ind1=None, 317 | ind2=None, 318 | ): 319 | images1 = images1[:num_vis].detach().cpu().numpy().transpose(0, 2, 3, 1) 320 | images2 = images2[:num_vis].detach().cpu().numpy().transpose(0, 2, 3, 1) 321 | rotations = rotations[:num_vis].detach().cpu().numpy() 322 | probabilities = probabilities[:num_vis].detach().cpu().numpy() 323 | 324 | visuals = [] 325 | for i in range(len(images1)): 326 | # image1 = unnormalize_image(cv2.resize(images1[i], (448, 448))) 327 | # image2 = unnormalize_image(cv2.resize(images2[i], (448, 448))) 328 | image1 = unnormalize_image(images1[i]) 329 | image2 = unnormalize_image(images2[i]) 330 | so3_vis = visualize_so3_probabilities( 331 | rotations=rotations[i], 332 | probabilities=probabilities[i], 333 | rotations_gt=rotations[i, 0], 334 | to_image=True, 335 | display_threshold_probability=1 / len(probabilities[i]), 336 | dpi=112, 337 | ) 338 | full_image = np.vstack((np.hstack((image1, image2)), so3_vis)) 339 | if model_id is not None: 340 | cv2.putText(full_image, model_id[i], (5, 40), 4, 1, (0, 0, 255)) 341 | cv2.putText(full_image, category[i], (5, 80), 4, 1, (0, 0, 255)) 342 | cv2.putText(full_image, str(int(ind1[i])), (5, 120), 4, 1, (0, 0, 255)) 343 | cv2.putText( 344 | full_image, str(int(ind2[i])), (453, 120), 4, 1, (0, 0, 255) 345 | ) 346 | visuals.append(full_image) 347 | return visuals 348 | 349 | 350 | if __name__ == "__main__": 351 | args = get_parser().parse_args() 352 | trainer = Trainer(args) 353 | trainer.train() 354 | -------------------------------------------------------------------------------- /relpose/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .permutations import get_permutations 2 | from .visualize import visualize_so3_probabilities 3 | -------------------------------------------------------------------------------- /relpose/utils/bbox.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def mask_to_bbox(mask, thresh=0.4): 5 | """ 6 | xyxy format 7 | """ 8 | mask = mask > thresh 9 | if not np.any(mask): 10 | return [] 11 | rows = np.any(mask, axis=1) 12 | cols = np.any(mask, axis=0) 13 | rmin, rmax = np.where(rows)[0][[0, -1]] 14 | cmin, cmax = np.where(cols)[0][[0, -1]] 15 | return [int(cmin), int(rmin), int(cmax) + 1, int(rmax) + 1] 16 | 17 | 18 | def square_bbox(bbox, padding=0.0, astype=None): 19 | """ 20 | Computes a square bounding box, with optional padding parameters. 21 | 22 | Args: 23 | bbox: Bounding box in xyxy format (4,). 24 | 25 | Returns: 26 | square_bbox in xyxy format (4,). 
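    For example (with no padding), square_bbox([10, 20, 50, 40]) gives
    [10, 10, 50, 50].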
27 | """ 28 | if astype is None: 29 | astype = type(bbox[0]) 30 | bbox = np.array(bbox) 31 | center = (bbox[:2] + bbox[2:]) / 2 32 | extents = (bbox[2:] - bbox[:2]) / 2 33 | s = max(extents) * (1 + padding) 34 | square_bbox = np.array( 35 | [center[0] - s, center[1] - s, center[0] + s, center[1] + s], 36 | dtype=astype, 37 | ) 38 | return square_bbox 39 | -------------------------------------------------------------------------------- /relpose/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import numpy as np 4 | import torch 5 | from pytorch3d.transforms import euler_angles_to_matrix, quaternion_to_matrix 6 | 7 | 8 | def generate_random_rotations(n=1, device="cpu"): 9 | quats = torch.randn(n, 4, device=device) 10 | quats = quats / quats.norm(dim=1, keepdim=True) 11 | return quaternion_to_matrix(quats) 12 | 13 | 14 | def generate_superfibonacci(n=1, device="cpu"): 15 | """ 16 | Samples n rotations equivolumetrically using a Super-Fibonacci Spiral. 17 | 18 | Reference: Marc Alexa, Super-Fibonacci Spirals. CVPR 22. 19 | 20 | Args: 21 | n (int): Number of rotations to sample. 22 | device (str): CUDA Device. Defaults to CPU. 23 | 24 | Returns: 25 | (tensor): Rotations (n, 3, 3). 26 | """ 27 | phi = np.sqrt(2.0) 28 | psi = 1.533751168755204288118041 29 | ind = torch.arange(n, device=device) 30 | s = ind + 0.5 31 | r = torch.sqrt(s / n) 32 | R = torch.sqrt(1.0 - s / n) 33 | alpha = 2 * np.pi * s / phi 34 | beta = 2.0 * np.pi * s / psi 35 | Q = torch.stack( 36 | [ 37 | r * torch.sin(alpha), 38 | r * torch.cos(alpha), 39 | R * torch.sin(beta), 40 | R * torch.cos(beta), 41 | ], 42 | 1, 43 | ) 44 | return quaternion_to_matrix(Q).float() 45 | 46 | 47 | def generate_equivolumetric_grid(recursion_level=3): 48 | """ 49 | Generates an equivolumetric grid on SO(3). Deprecated in favor of super-fibonacci 50 | which is more efficient and does not require additional dependencies. 51 | 52 | Uses a Healpix grid on S2 and then tiles 6 * 2 ** recursion level over 2pi. 53 | 54 | Code adapted from https://github.com/google-research/google-research/blob/master/ 55 | implicit_pdf/models.py 56 | 57 | Grid sizes: 58 | 1: 576 59 | 2: 4608 60 | 3: 36864 61 | 4: 294912 62 | 5: 2359296 63 | n: 72 * 8 ** n 64 | 65 | Args: 66 | recursion_level: The recursion level of the Healpix grid. 67 | 68 | Returns: 69 | tensor: rotation matrices (N, 3, 3). 70 | """ 71 | import healpy 72 | 73 | log = logging.getLogger("healpy") 74 | log.setLevel(logging.ERROR) # Supress healpy linking warnings. 75 | 76 | number_per_side = 2**recursion_level 77 | number_pix = healpy.nside2npix(number_per_side) 78 | s2_points = healpy.pix2vec(number_per_side, np.arange(number_pix)) 79 | s2_points = torch.tensor(np.stack([*s2_points], 1)) 80 | 81 | azimuths = torch.atan2(s2_points[:, 1], s2_points[:, 0]) 82 | # torch doesn't have endpoint=False for linspace yet. 
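    # This builds 6 * 2 ** recursion_level evenly spaced tilt angles over
    # [0, 2*pi); a torch-only equivalent would be torch.arange(n) * 2 * np.pi / n.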
83 | tilts = torch.tensor( 84 | np.linspace(0, 2 * np.pi, 6 * 2**recursion_level, endpoint=False) 85 | ) 86 | polars = torch.arccos(s2_points[:, 2]) 87 | grid_rots_mats = [] 88 | for tilt in tilts: 89 | rot_mats = euler_angles_to_matrix( 90 | torch.stack( 91 | [azimuths, torch.zeros(number_pix), torch.zeros(number_pix)], 1 92 | ), 93 | "XYZ", 94 | ) 95 | rot_mats = rot_mats @ euler_angles_to_matrix( 96 | torch.stack([torch.zeros(number_pix), torch.zeros(number_pix), polars], 1), 97 | "XYZ", 98 | ) 99 | rot_mats = rot_mats @ euler_angles_to_matrix( 100 | torch.tensor([[tilt, 0.0, 0.0]]), "XYZ" 101 | ) 102 | grid_rots_mats.append(rot_mats) 103 | 104 | grid_rots_mats = torch.cat(grid_rots_mats, 0) 105 | return grid_rots_mats.float() 106 | -------------------------------------------------------------------------------- /relpose/utils/permutations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def get_permutations(num_images): 5 | permutations = [] 6 | for i in range(num_images): 7 | for j in range(num_images): 8 | if i != j: 9 | permutations.append((i, j)) 10 | return np.array(permutations) 11 | -------------------------------------------------------------------------------- /relpose/utils/visualize.py: -------------------------------------------------------------------------------- 1 | """ 2 | The visualization code is adapted from Implicit-PDF (Murphy et. al.) 3 | github.com/google-research/google-research/blob/master/implicit_pdf/evaluation.py 4 | 5 | Modified so that the rotations are interpretable as yaw (x-axis), pitch (y-axis), and 6 | roll (color). 7 | """ 8 | import matplotlib.pyplot as plt 9 | import numpy as np 10 | import torch 11 | from matplotlib import rc 12 | from PIL import Image 13 | from pytorch3d import transforms 14 | 15 | rc("font", **{"family": "serif", "serif": ["Times New Roman"]}) 16 | 17 | EYE = np.eye(3) 18 | 19 | 20 | def visualize_so3_probabilities( 21 | rotations, 22 | probabilities, 23 | rotations_gt=None, 24 | ax=None, 25 | fig=None, 26 | display_threshold_probability=0, 27 | to_image=True, 28 | show_color_wheel=True, 29 | canonical_rotation=EYE, 30 | gt_size=2500, 31 | y_offset=-30, 32 | dpi=400, 33 | ): 34 | """ 35 | Plot a single distribution on SO(3) using the tilt-colored method. 36 | 37 | Args: 38 | rotations: [N, 3, 3] tensor of rotation matrices 39 | probabilities: [N] tensor of probabilities 40 | rotations_gt: [N_gt, 3, 3] or [3, 3] ground truth rotation matrices 41 | ax: The matplotlib.pyplot.axis object to paint 42 | fig: The matplotlib.pyplot.figure object to paint 43 | display_threshold_probability: The probability threshold below which to omit 44 | the marker 45 | to_image: If True, return a tensor containing the pixels of the finished 46 | figure; if False return the figure itself 47 | show_color_wheel: If True, display the explanatory color wheel which matches 48 | color on the plot with tilt angle 49 | canonical_rotation: A [3, 3] rotation matrix representing the 'display 50 | rotation', to change the view of the distribution. It rotates the 51 | canonical axes so that the view of SO(3) on the plot is different, which 52 | can help obtain a more informative view. 53 | 54 | Returns: 55 | A matplotlib.pyplot.figure object, or a tensor of pixels if to_image=True. 
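    Illustrative call (argument names are placeholders):

        image = visualize_so3_probabilities(rotations, probs, rotations_gt=R_gt,
                                            to_image=True)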
56 | """ 57 | 58 | def _show_single_marker(ax, rotation, marker, edgecolors=True, facecolors=False): 59 | eulers = transforms.matrix_to_euler_angles(torch.tensor(rotation), "ZXY") 60 | eulers = eulers.numpy() 61 | 62 | tilt_angle = eulers[0] 63 | latitude = eulers[1] 64 | longitude = eulers[2] 65 | 66 | color = cmap(0.5 + tilt_angle / 2 / np.pi) 67 | ax.scatter( 68 | longitude, 69 | latitude, 70 | s=gt_size, 71 | edgecolors=color if edgecolors else "none", 72 | facecolors=facecolors if facecolors else "none", 73 | marker=marker, 74 | linewidth=4, 75 | ) 76 | 77 | if ax is None: 78 | fig = plt.figure(figsize=(4, 2), dpi=dpi) 79 | ax = fig.add_subplot(111, projection="mollweide") 80 | if rotations_gt is not None and len(rotations_gt.shape) == 2: 81 | rotations_gt = rotations_gt[None] 82 | 83 | display_rotations = rotations @ canonical_rotation 84 | cmap = plt.cm.hsv 85 | scatterpoint_scaling = 4e3 86 | eulers_queries = transforms.matrix_to_euler_angles( 87 | torch.tensor(display_rotations), "ZXY" 88 | ) 89 | eulers_queries = eulers_queries.numpy() 90 | 91 | tilt_angles = eulers_queries[:, 0] 92 | longitudes = eulers_queries[:, 2] 93 | latitudes = eulers_queries[:, 1] 94 | 95 | which_to_display = probabilities > display_threshold_probability 96 | 97 | if rotations_gt is not None: 98 | display_rotations_gt = rotations_gt @ canonical_rotation 99 | 100 | for rotation in display_rotations_gt: 101 | _show_single_marker(ax, rotation, "o") 102 | # Cover up the centers with white markers 103 | for rotation in display_rotations_gt: 104 | _show_single_marker( 105 | ax, rotation, "o", edgecolors=False, facecolors="#ffffff" 106 | ) 107 | 108 | # Display the distribution 109 | ax.scatter( 110 | longitudes[which_to_display], 111 | latitudes[which_to_display], 112 | s=scatterpoint_scaling * probabilities[which_to_display], 113 | c=cmap(0.5 + tilt_angles[which_to_display] / 2.0 / np.pi), 114 | ) 115 | 116 | yticks = np.array([-60, -30, 0, 30, 60]) 117 | yticks_minor = np.arange(-75, 90, 15) 118 | ax.set_yticks(yticks_minor * np.pi / 180, minor=True) 119 | ax.set_yticks(yticks * np.pi / 180, [f"{y}°" for y in yticks], fontsize=14) 120 | xticks = np.array([-90, 0, 90]) 121 | xticks_minor = np.arange(-150, 180, 30) 122 | ax.set_xticks(xticks * np.pi / 180, []) 123 | ax.set_xticks(xticks_minor * np.pi / 180, minor=True) 124 | 125 | for xtick in xticks: 126 | # Manually set xticks 127 | x = xtick * np.pi / 180 128 | y = y_offset * np.pi / 180 129 | ax.text(x, y, f"{xtick}°", ha="center", va="center", fontsize=14) 130 | 131 | ax.grid(which="minor") 132 | ax.grid(which="major") 133 | 134 | if show_color_wheel: 135 | # Add a color wheel showing the tilt angle to color conversion. 
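        # (The wheel is drawn as a small polar annulus inset in the lower-right
        # corner of the figure.)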
136 | ax = fig.add_axes([0.86, 0.17, 0.12, 0.12], projection="polar") 137 | theta = np.linspace(-3 * np.pi / 2, np.pi / 2, 200) 138 | radii = np.linspace(0.4, 0.5, 2) 139 | _, theta_grid = np.meshgrid(radii, theta) 140 | colormap_val = 0.5 + theta_grid / np.pi / 2.0 141 | ax.pcolormesh(theta, radii, colormap_val.T, cmap=cmap, shading="auto") 142 | ax.set_yticklabels([]) 143 | ax.set_xticks(np.arange(0, 2 * np.pi, np.pi / 2)) 144 | ax.set_xticklabels( 145 | [ 146 | r"90$\degree$", 147 | r"180$\degree$", 148 | r"270$\degree$", 149 | r"0$\degree$", 150 | ], 151 | fontsize=12, 152 | ) 153 | ax.spines["polar"].set_visible(False) 154 | plt.text( 155 | 0.5, 156 | 0.5, 157 | "Roll", 158 | fontsize=10, 159 | horizontalalignment="center", 160 | verticalalignment="center", 161 | transform=ax.transAxes, 162 | ) 163 | 164 | if to_image: 165 | return plot_to_image(fig) 166 | else: 167 | return fig 168 | 169 | 170 | def plot_to_image(fig): 171 | fig.canvas.draw() 172 | image_from_plot = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 173 | image_from_plot = image_from_plot.reshape( 174 | fig.canvas.get_width_height()[::-1] + (3,) 175 | ) 176 | plt.close(fig) 177 | return image_from_plot 178 | 179 | 180 | def antialias(image, level=1): 181 | is_numpy = isinstance(image, np.ndarray) 182 | if is_numpy: 183 | image = Image.fromarray(image) 184 | for _ in range(level): 185 | size = np.array(image.size) // 2 186 | image = image.resize(size, Image.LANCZOS) 187 | if is_numpy: 188 | image = np.array(image) 189 | return image 190 | 191 | 192 | def unnormalize_image(image): 193 | if isinstance(image, torch.Tensor): 194 | image = image.cpu().numpy() 195 | if image.shape[0] == 3: 196 | image = image.transpose(1, 2, 0) 197 | mean = np.array([0.485, 0.456, 0.406]) 198 | std = np.array([0.229, 0.224, 0.225]) 199 | image = image * std + mean 200 | return (image * 255.0).astype(np.uint8) 201 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | antialiased-cnns 2 | black 3 | cython 4 | flake8 5 | gdown 6 | ipdb 7 | isort 8 | jupyter 9 | matplotlib 10 | numpy 11 | opencv-python 12 | plotly 13 | tensorboard 14 | tqdm -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [isort] 2 | profile=black 3 | line_length=88 4 | multi_line_output=3 5 | force_grid_wrap=0 6 | use_parentheses=True 7 | ensure_newline_before_comments=True 8 | include_trailing_comma=True 9 | skip=data,dev,external,output*,co3d,tools 10 | skip_glob=*/__init__.py 11 | known_myself=relpose 12 | sections=FUTURE,STDLIB,THIRDPARTY,FIRSTPARTY,MYSELF,LOCALFOLDER 13 | default_section=THIRDPARTY 14 | 15 | [flake8] 16 | max-line-length = 88 17 | ignore = E203,E501,W503,W605 18 | per-file-ignores = 19 | __init__.py: F401 20 | exclude = data,dev,external,output* 21 | --------------------------------------------------------------------------------