├── .gitattributes ├── .github └── workflows │ ├── deploy.yaml │ └── test.yaml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── diffpose ├── __init__.py ├── _modidx.py ├── calibration.py ├── deepfluoro.py ├── jacobians.py ├── ljubljana.py ├── metrics.py ├── registration.py └── visualization.py ├── environment.yml ├── experiments ├── deepfluoro │ ├── evaluate.py │ ├── register.py │ └── train.py ├── ljubljana │ ├── register.py │ └── train.py └── test_time_optimization.gif ├── notebooks ├── _quarto.yml ├── api │ ├── 00_deepfluoro.ipynb │ ├── 01_ljubljana.ipynb │ ├── 02_calibration.ipynb │ ├── 03_registration.ipynb │ ├── 04_metrics.ipynb │ ├── 05_visualization.ipynb │ └── 06_jacobians.ipynb ├── experiments │ ├── 00_3D_visualization.ipynb │ ├── 01_pose_recovery.ipynb │ ├── 02_loss_landscapes.ipynb │ ├── 03_sparse_rendering.ipynb │ ├── render.html │ └── test_time_optimization.gif ├── favicon.png ├── index.ipynb ├── nbdev.yml ├── sidebar.yml └── styles.css ├── settings.ini └── setup.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored=true 2 | *.ipynb merge=nbdev-merge 3 | *.html linguist-documentation 4 | -------------------------------------------------------------------------------- /.github/workflows/deploy.yaml: -------------------------------------------------------------------------------- 1 | name: Docs 2 | 3 | permissions: 4 | contents: write 5 | pages: write 6 | 7 | on: 8 | push: 9 | branches: [ "main", "master" ] 10 | workflow_dispatch: 11 | 12 | jobs: 13 | deploy: 14 | runs-on: ubuntu-latest 15 | steps: 16 | - name: Checkout repository 17 | uses: actions/checkout@v3 18 | 19 | - name: Activate conda env with environment.yml 20 | uses: mamba-org/setup-micromamba@v1 21 | with: 22 | environment-file: environment.yml 23 | cache-environment: true 24 | post-cleanup: 'all' 25 | 26 | - name: Install nbdev 27 | shell: bash -l {0} 28 | run: | 29 | pip install -U nbdev 30 | 31 | - name: Doing editable install 32 | shell: bash -l {0} 33 | run: | 34 | test -f setup.py && pip install -e ".[dev]" 35 | 36 | - name: Run nbdev_docs 37 | shell: bash -l {0} 38 | run: | 39 | nbdev_docs 40 | 41 | - name: Deploy to GitHub Pages 42 | uses: peaceiris/actions-gh-pages@v3 43 | with: 44 | github_token: ${{ github.token }} 45 | force_orphan: true 46 | publish_dir: ./_docs 47 | # The following lines assign commit authorship to the official GH-Actions bot for deploys to `gh-pages` branch. 48 | # You can swap them out with your own user credentials. 
49 | user_name: github-actions[bot] 50 | user_email: 41898282+github-actions[bot]@users.noreply.github.com 51 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: [workflow_dispatch, pull_request, push] 3 | 4 | jobs: 5 | test: 6 | runs-on: ubuntu-latest 7 | steps: 8 | - name: Checkout repository 9 | uses: actions/checkout@v3 10 | 11 | - name: Activate conda env with environment.yml 12 | uses: mamba-org/setup-micromamba@v1 13 | with: 14 | environment-file: environment.yml 15 | cache-environment: true 16 | post-cleanup: 'all' 17 | 18 | - name: Install nbdev 19 | shell: bash -l {0} 20 | run: | 21 | pip install -U nbdev 22 | 23 | - name: Doing editable install 24 | shell: bash -l {0} 25 | run: | 26 | test -f setup.py && pip install -e ".[dev]" 27 | 28 | - name: Check we are starting with clean git checkout 29 | shell: bash -l {0} 30 | run: | 31 | if [[ `git status --porcelain -uno` ]]; then 32 | git diff 33 | echo "git status is not clean" 34 | false 35 | fi 36 | 37 | - name: Trying to strip out notebooks 38 | shell: bash -l {0} 39 | run: | 40 | nbdev_clean 41 | git status -s # display the status to see which nbs need cleaning up 42 | if [[ `git status --porcelain -uno` ]]; then 43 | git status -uno 44 | echo -e "!!! Detected notebooks that have not been stripped\n!!! Remember to run nbdev_install_hooks" 45 | echo -e "This error can also happen if you are using an older version of nbdev relative to what is in CI. Please try to upgrade nbdev with the command 'pip install -U nbdev'" 46 | false 47 | fi 48 | 49 | - name: Run nbdev_export 50 | shell: bash -l {0} 51 | run: | 52 | nbdev_export 53 | if [[ `git status --porcelain -uno` ]]; then 54 | echo "::error::Notebooks and library are not in sync. Please run nbdev_export." 55 | git status -uno 56 | git diff 57 | exit 1; 58 | fi 59 | 60 | - name: Run nbdev_test 61 | shell: bash -l {0} 62 | run: | 63 | nbdev_test 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/ 2 | evaluations/ 3 | data/ 4 | logs/ 5 | runs/ 6 | 7 | _docs/ 8 | _proc/ 9 | 10 | *.bak 11 | .gitattributes 12 | .last_checked 13 | .gitconfig 14 | *.bak 15 | *.log 16 | *~ 17 | ~* 18 | _tmp* 19 | tmp* 20 | tags 21 | *.pkg 22 | 23 | # Byte-compiled / optimized / DLL files 24 | __pycache__/ 25 | *.py[cod] 26 | *$py.class 27 | 28 | # C extensions 29 | *.so 30 | 31 | # Distribution / packaging 32 | .Python 33 | env/ 34 | build/ 35 | develop-eggs/ 36 | dist/ 37 | downloads/ 38 | eggs/ 39 | .eggs/ 40 | lib/ 41 | lib64/ 42 | parts/ 43 | sdist/ 44 | var/ 45 | wheels/ 46 | *.egg-info/ 47 | .installed.cfg 48 | *.egg 49 | 50 | # PyInstaller 51 | # Usually these files are written by a python script from a template 52 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
53 | *.manifest 54 | *.spec 55 | 56 | # Installer logs 57 | pip-log.txt 58 | pip-delete-this-directory.txt 59 | 60 | # Unit test / coverage reports 61 | htmlcov/ 62 | .tox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | .hypothesis/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | 79 | # Flask stuff: 80 | instance/ 81 | .webassets-cache 82 | 83 | # Scrapy stuff: 84 | .scrapy 85 | 86 | # Sphinx documentation 87 | docs/_build/ 88 | 89 | # PyBuilder 90 | target/ 91 | 92 | # Jupyter Notebook 93 | .ipynb_checkpoints 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # celery beat schedule file 99 | celerybeat-schedule 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # dotenv 105 | .env 106 | 107 | # virtualenv 108 | .venv 109 | venv/ 110 | ENV/ 111 | 112 | # Spyder project settings 113 | .spyderproject 114 | .spyproject 115 | 116 | # Rope project settings 117 | .ropeproject 118 | 119 | # mkdocs documentation 120 | /site 121 | 122 | # mypy 123 | .mypy_cache/ 124 | 125 | .vscode 126 | *.swp 127 | 128 | # osx generated files 129 | .DS_Store 130 | .DS_Store? 131 | .Trashes 132 | ehthumbs.db 133 | Thumbs.db 134 | .idea 135 | 136 | # pytest 137 | .pytest_cache 138 | 139 | # tools/trust-doc-nbs 140 | docs_src/.last_checked 141 | 142 | # symlinks to fastai 143 | docs_src/fastai 144 | tools/fastai 145 | 146 | # link checker 147 | checklink/cookies.txt 148 | 149 | # .gitconfig is now autogenerated 150 | .gitconfig 151 | 152 | # Quarto installer 153 | .deb 154 | .pkg 155 | 156 | # Quarto 157 | .quarto 158 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Vivek Gopalakrishnan 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include settings.ini 2 | include LICENSE 3 | include CONTRIBUTING.md 4 | include README.md 5 | recursive-exclude * __pycache__ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffPose 2 | 3 | 4 | 5 | > Intraoperative 2D/3D registration via differentiable X-ray rendering 6 | 7 | [![CI](https://github.com/eigenvivek/DiffPose/actions/workflows/test.yaml/badge.svg)](https://github.com/eigenvivek/DiffPose/actions/workflows/test.yaml) 8 | [![Paper 9 | shield](https://img.shields.io/badge/arXiv-2312.06358-red.svg)](https://arxiv.org/abs/2312.06358) 10 | [![License: 11 | MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) 12 | [![Docs](https://github.com/eigenvivek/DiffPose/actions/workflows/deploy.yaml/badge.svg)](https://vivekg.dev/DiffPose) 13 | [![Code style: 14 | black](https://img.shields.io/badge/Code%20style-black-black.svg)](https://github.com/psf/black) 15 | 16 | ![](experiments/test_time_optimization.gif) 17 | 18 | > [!NOTE] 19 | > 20 | > If you're considering using `DiffPose` for your own data, instead use [`xvr`](https://github.com/eigenvivek/xvr), the 21 | > successor project from the same authors. 22 | > 23 | > `xvr` is actively maintained and offers many features not present in `DiffPose`, such as a Python API and CLI for 24 | > training your own patient-specific pose regression networks (in about 5 minutes!) and performing test-time optimization. 25 | 26 | ## Install 27 | 28 | To install `DiffPose` and the requirements in 29 | [`environment.yml`](https://github.com/eigenvivek/DiffPose/blob/main/environment.yml), 30 | run: 31 | 32 | ``` zsh 33 | pip install diffpose 34 | ``` 35 | 36 | The differentiable X-ray renderer that powers the backend of `DiffPose` 37 | is available at [`DiffDRR`](https://github.com/eigenvivek/DiffDRR). 38 | 39 | ## Datasets 40 | 41 | We evaluate `DiffPose` networks on the following open-source datasets: 42 | 43 | | **Dataset** | **Anatomy** | **\# of Subjects** | **\# of 2D Images** | **CTs** | **X-rays** | **Fiducials** | 44 | |----------------------------------------------------------------------------|--------------------|:------------------:|:-------------------:|:-------:|:----------:|:---------:| 45 | | [`DeepFluoro`](https://github.com/rg2/DeepFluoroLabeling-IPCAI2020) | Pelvis | 6 | 366 | ✅ | ✅ | ❌ | 46 | | [`Ljubljana`](https://lit.fe.uni-lj.si/en/research/resources/3D-2D-GS-CA/) | Cerebrovasculature | 10 | 20 | ✅ | ✅ | ✅ | 47 | 48 | 50 | 51 | - `DeepFluoro` ([**Grupp et al., 52 | 2020**](https://link.springer.com/article/10.1007/s11548-020-02162-7)) 53 | provides paired X-ray fluoroscopy images and CT volumes of the pelvis. 54 | The data were collected from six cadaveric subjects at Johns Hopkins 55 | University. Ground truth camera poses were estimated with an offline 56 | registration process. A visualization of one X-ray / CT pair in the 57 | `DeepFluoro` dataset is available 58 | [here](https://vivekg.dev/DiffPose/experiments/render.html).
59 | 60 | ``` zsh 61 | mkdir -p data/ 62 | wget --no-check-certificate -O data/ipcai_2020_full_res_data.zip "http://archive.data.jhu.edu/api/access/datafile/:persistentId/?persistentId=doi:10.7281/T1/IFSXNV/EAN9GH" 63 | unzip -o data/ipcai_2020_full_res_data.zip -d data 64 | rm data/ipcai_2020_full_res_data.zip 65 | ``` 66 | 67 | - `Ljubljana` ([**Mitrovic et al., 68 | 2013**](https://ieeexplore.ieee.org/abstract/document/6507588)) 69 | provides paired 2D/3D digital subtraction angiography (DSA) images. 70 | The data were collected from 10 patients undergoing endovascular 71 | image-guided interventions at the University of Ljubljana. Ground 72 | truth camera poses were estimated by registering surface fiducial 73 | markers. 74 | 75 | ``` zsh 76 | mkdir -p data/ 77 | wget --no-check-certificate -O data/ljubljana.zip "https://drive.google.com/uc?export=download&confirm=yes&id=1x585pGLI8QGk21qZ2oGwwQ9LMJ09Tqrx" 78 | unzip -o data/ljubljana.zip -d data 79 | rm data/ljubljana.zip 80 | ``` 81 | 82 | 84 | 85 | ## Experiments 86 | 87 | To run the experiments in `DiffPose`, run the following scripts (ensure 88 | you've downloaded the data first): 89 | 90 | ``` zsh 91 | # DeepFluoro dataset 92 | cd experiments/deepfluoro 93 | srun python train.py # Pretrain pose regression CNN on synthetic X-rays 94 | srun python register.py # Run test-time optimization with the best network per subject 95 | ``` 96 | 97 | ``` zsh 98 | # Ljubljana dataset 99 | cd experiments/ljubljana 100 | srun python train.py 101 | srun python register.py 102 | ``` 103 | 104 | The training and test-time optimization scripts use SLURM to run on all 105 | subjects in parallel: 106 | 107 | - `experiments/deepfluoro/train.py` is configured to run across six 108 | A6000 GPUs 109 | - `experiments/deepfluoro/register.py` is configured to run across six 110 | 2080 Ti GPUs 111 | - `experiments/ljubljana/train.py` is configured to run across twenty 112 | 2080 Ti GPUs 113 | - `experiments/ljubljana/register.py` is configured to run across twenty 114 | 2080 Ti GPUs 115 | 116 | The GPU configurations can be changed at the end of each script using 117 | [`submitit`](https://github.com/facebookincubator/submitit) (see the sketch below). 118 | 119 | ## Development 120 | 121 | The `DiffPose` package, docs, and CI are all built using 122 | [`nbdev`](https://nbdev.fast.ai/). To get set up with `nbdev`, install 123 | the following: 124 | 125 | ``` zsh 126 | conda install jupyterlab nbdev -c fastai -c conda-forge 127 | nbdev_install_quarto # To build docs 128 | nbdev_install_hooks # Make notebooks git-friendly 129 | pip install -e ".[dev]" # Install the development version of DiffPose 130 | ``` 131 | 132 | Running `nbdev_help` will give you the full list of options. The most 133 | important ones are: 134 | 135 | ``` zsh 136 | nbdev_preview # Render docs locally and inspect in browser 137 | nbdev_clean # NECESSARY BEFORE PUSHING 138 | nbdev_test # Test the notebooks 139 | nbdev_export # Build the package from the notebooks 140 | nbdev_readme # Render the README 141 | ``` 142 | 143 | For more details, follow this [in-depth 144 | tutorial](https://nbdev.fast.ai/tutorials/tutorial.html).
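As noted in the Experiments section above, each script dispatches its per-subject jobs with [`submitit`](https://github.com/facebookincubator/submitit). A minimal sketch of that pattern (the function body, executor parameters, and subject IDs here are illustrative, not the repo's exact values):

``` python
import submitit


def main(id_number):
    """Train or register a single subject."""
    ...


# Configure a SLURM executor; logs are written to the given folder
executor = submitit.AutoExecutor(folder="logs")
executor.update_parameters(
    timeout_min=60 * 24,        # Maximum job duration
    gpus_per_node=1,            # One GPU per subject
    slurm_array_parallelism=6,  # e.g., run all six DeepFluoro subjects at once
)

# Launch one job per subject as a SLURM job array
jobs = executor.map_array(main, [1, 2, 3, 4, 5, 6])
```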
145 | 146 | ## Citing `DiffPose` 147 | 148 | If you find `DiffPose` or 149 | [`DiffDRR`](https://github.com/eigenvivek/DiffDRR) useful in your work, 150 | please cite the appropriate papers: 151 | 152 | ``` 153 | @article{gopalakrishnan2023intraoperative, 154 | title={Intraoperative {2D/3D} Image Registration via Differentiable X-ray Rendering}, 155 | author={Gopalakrishnan, Vivek and Dey, Neel and Golland, Polina}, 156 | journal={arXiv preprint arXiv:2312.06358}, 157 | year={2023} 158 | } 159 | 160 | @inproceedings{gopalakrishnan2022fast, 161 | title={Fast Auto-Differentiable Digitally Reconstructed Radiographs for Solving Inverse Problems in Intraoperative Imaging}, 162 | author={Gopalakrishnan, Vivek and Golland, Polina}, 163 | booktitle={Workshop on Clinical Image-Based Procedures}, 164 | pages={1--11}, 165 | year={2022}, 166 | organization={Springer} 167 | } 168 | ``` 169 | -------------------------------------------------------------------------------- /diffpose/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | -------------------------------------------------------------------------------- /diffpose/_modidx.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by nbdev 2 | 3 | d = { 'settings': { 'branch': 'main', 4 | 'doc_baseurl': '/DiffPose', 5 | 'doc_host': 'https://vivekg.dev', 6 | 'git_url': 'https://github.com/eigenvivek/DiffPose', 7 | 'lib_path': 'diffpose'}, 8 | 'syms': { 'diffpose.calibration': { 'diffpose.calibration.RigidTransform': ( 'api/calibration.html#rigidtransform', 9 | 'diffpose/calibration.py'), 10 | 'diffpose.calibration.RigidTransform.__init__': ( 'api/calibration.html#rigidtransform.__init__', 11 | 'diffpose/calibration.py'), 12 | 'diffpose.calibration.RigidTransform.clone': ( 'api/calibration.html#rigidtransform.clone', 13 | 'diffpose/calibration.py'), 14 | 'diffpose.calibration.RigidTransform.compose': ( 'api/calibration.html#rigidtransform.compose', 15 | 'diffpose/calibration.py'), 16 | 'diffpose.calibration.RigidTransform.get_rotation': ( 'api/calibration.html#rigidtransform.get_rotation', 17 | 'diffpose/calibration.py'), 18 | 'diffpose.calibration.RigidTransform.get_se3_log': ( 'api/calibration.html#rigidtransform.get_se3_log', 19 | 'diffpose/calibration.py'), 20 | 'diffpose.calibration.RigidTransform.get_translation': ( 'api/calibration.html#rigidtransform.get_translation', 21 | 'diffpose/calibration.py'), 22 | 'diffpose.calibration.RigidTransform.inverse': ( 'api/calibration.html#rigidtransform.inverse', 23 | 'diffpose/calibration.py'), 24 | 'diffpose.calibration.convert': ('api/calibration.html#convert', 'diffpose/calibration.py'), 25 | 'diffpose.calibration.perspective_projection': ( 'api/calibration.html#perspective_projection', 26 | 'diffpose/calibration.py')}, 27 | 'diffpose.deepfluoro': { 'diffpose.deepfluoro.DeepFluoroDataset': ( 'api/deepfluoro.html#deepfluorodataset', 28 | 'diffpose/deepfluoro.py'), 29 | 'diffpose.deepfluoro.DeepFluoroDataset.__getitem__': ( 'api/deepfluoro.html#deepfluorodataset.__getitem__', 30 | 'diffpose/deepfluoro.py'), 31 | 'diffpose.deepfluoro.DeepFluoroDataset.__init__': ( 'api/deepfluoro.html#deepfluorodataset.__init__', 32 | 'diffpose/deepfluoro.py'), 33 | 'diffpose.deepfluoro.DeepFluoroDataset.__iter__': ( 'api/deepfluoro.html#deepfluorodataset.__iter__', 34 | 'diffpose/deepfluoro.py'), 35 | 'diffpose.deepfluoro.DeepFluoroDataset.__len__': ( 'api/deepfluoro.html#deepfluorodataset.__len__', 36 
| 'diffpose/deepfluoro.py'), 37 | 'diffpose.deepfluoro.DeepFluoroDataset._rot_180_for_up': ( 'api/deepfluoro.html#deepfluorodataset._rot_180_for_up', 38 | 'diffpose/deepfluoro.py'), 39 | 'diffpose.deepfluoro.DeepFluoroDataset.get_2d_fiducials': ( 'api/deepfluoro.html#deepfluorodataset.get_2d_fiducials', 40 | 'diffpose/deepfluoro.py'), 41 | 'diffpose.deepfluoro.Evaluator': ('api/deepfluoro.html#evaluator', 'diffpose/deepfluoro.py'), 42 | 'diffpose.deepfluoro.Evaluator.__call__': ( 'api/deepfluoro.html#evaluator.__call__', 43 | 'diffpose/deepfluoro.py'), 44 | 'diffpose.deepfluoro.Evaluator.__init__': ( 'api/deepfluoro.html#evaluator.__init__', 45 | 'diffpose/deepfluoro.py'), 46 | 'diffpose.deepfluoro.Evaluator.project': ( 'api/deepfluoro.html#evaluator.project', 47 | 'diffpose/deepfluoro.py'), 48 | 'diffpose.deepfluoro.Transforms': ('api/deepfluoro.html#transforms', 'diffpose/deepfluoro.py'), 49 | 'diffpose.deepfluoro.Transforms.__call__': ( 'api/deepfluoro.html#transforms.__call__', 50 | 'diffpose/deepfluoro.py'), 51 | 'diffpose.deepfluoro.Transforms.__init__': ( 'api/deepfluoro.html#transforms.__init__', 52 | 'diffpose/deepfluoro.py'), 53 | 'diffpose.deepfluoro.convert_deepfluoro_to_diffdrr': ( 'api/deepfluoro.html#convert_deepfluoro_to_diffdrr', 54 | 'diffpose/deepfluoro.py'), 55 | 'diffpose.deepfluoro.convert_diffdrr_to_deepfluoro': ( 'api/deepfluoro.html#convert_diffdrr_to_deepfluoro', 56 | 'diffpose/deepfluoro.py'), 57 | 'diffpose.deepfluoro.get_3d_fiducials': ( 'api/deepfluoro.html#get_3d_fiducials', 58 | 'diffpose/deepfluoro.py'), 59 | 'diffpose.deepfluoro.get_random_offset': ( 'api/deepfluoro.html#get_random_offset', 60 | 'diffpose/deepfluoro.py'), 61 | 'diffpose.deepfluoro.load_deepfluoro_dataset': ( 'api/deepfluoro.html#load_deepfluoro_dataset', 62 | 'diffpose/deepfluoro.py'), 63 | 'diffpose.deepfluoro.parse_proj_params': ( 'api/deepfluoro.html#parse_proj_params', 64 | 'diffpose/deepfluoro.py'), 65 | 'diffpose.deepfluoro.parse_volume': ('api/deepfluoro.html#parse_volume', 'diffpose/deepfluoro.py'), 66 | 'diffpose.deepfluoro.preprocess': ('api/deepfluoro.html#preprocess', 'diffpose/deepfluoro.py')}, 67 | 'diffpose.jacobians': { 'diffpose.jacobians.JacobianDRR': ('api/jacobians.html#jacobiandrr', 'diffpose/jacobians.py'), 68 | 'diffpose.jacobians.JacobianDRR.__init__': ( 'api/jacobians.html#jacobiandrr.__init__', 69 | 'diffpose/jacobians.py'), 70 | 'diffpose.jacobians.JacobianDRR.cast': ('api/jacobians.html#jacobiandrr.cast', 'diffpose/jacobians.py'), 71 | 'diffpose.jacobians.JacobianDRR.forward': ( 'api/jacobians.html#jacobiandrr.forward', 72 | 'diffpose/jacobians.py'), 73 | 'diffpose.jacobians.JacobianDRR.permute': ( 'api/jacobians.html#jacobiandrr.permute', 74 | 'diffpose/jacobians.py'), 75 | 'diffpose.jacobians.gradient_matching': ( 'api/jacobians.html#gradient_matching', 76 | 'diffpose/jacobians.py'), 77 | 'diffpose.jacobians.plot_img_jacobian': ( 'api/jacobians.html#plot_img_jacobian', 78 | 'diffpose/jacobians.py')}, 79 | 'diffpose.ljubljana': { 'diffpose.ljubljana.Evaluator': ('api/ljubljana.html#evaluator', 'diffpose/ljubljana.py'), 80 | 'diffpose.ljubljana.Evaluator.__call__': ( 'api/ljubljana.html#evaluator.__call__', 81 | 'diffpose/ljubljana.py'), 82 | 'diffpose.ljubljana.Evaluator.__init__': ( 'api/ljubljana.html#evaluator.__init__', 83 | 'diffpose/ljubljana.py'), 84 | 'diffpose.ljubljana.Evaluator.project': ( 'api/ljubljana.html#evaluator.project', 85 | 'diffpose/ljubljana.py'), 86 | 'diffpose.ljubljana.LjubljanaDataset': ('api/ljubljana.html#ljubljanadataset', 
'diffpose/ljubljana.py'), 87 | 'diffpose.ljubljana.LjubljanaDataset.__getitem__': ( 'api/ljubljana.html#ljubljanadataset.__getitem__', 88 | 'diffpose/ljubljana.py'), 89 | 'diffpose.ljubljana.LjubljanaDataset.__init__': ( 'api/ljubljana.html#ljubljanadataset.__init__', 90 | 'diffpose/ljubljana.py'), 91 | 'diffpose.ljubljana.LjubljanaDataset.__iter__': ( 'api/ljubljana.html#ljubljanadataset.__iter__', 92 | 'diffpose/ljubljana.py'), 93 | 'diffpose.ljubljana.LjubljanaDataset.__len__': ( 'api/ljubljana.html#ljubljanadataset.__len__', 94 | 'diffpose/ljubljana.py'), 95 | 'diffpose.ljubljana.Transforms': ('api/ljubljana.html#transforms', 'diffpose/ljubljana.py'), 96 | 'diffpose.ljubljana.Transforms.__call__': ( 'api/ljubljana.html#transforms.__call__', 97 | 'diffpose/ljubljana.py'), 98 | 'diffpose.ljubljana.Transforms.__init__': ( 'api/ljubljana.html#transforms.__init__', 99 | 'diffpose/ljubljana.py'), 100 | 'diffpose.ljubljana.get_random_offset': ( 'api/ljubljana.html#get_random_offset', 101 | 'diffpose/ljubljana.py')}, 102 | 'diffpose.metrics': { 'diffpose.metrics.CustomMetric': ('api/metrics.html#custommetric', 'diffpose/metrics.py'), 103 | 'diffpose.metrics.CustomMetric.__init__': ( 'api/metrics.html#custommetric.__init__', 104 | 'diffpose/metrics.py'), 105 | 'diffpose.metrics.CustomMetric.compute': ('api/metrics.html#custommetric.compute', 'diffpose/metrics.py'), 106 | 'diffpose.metrics.CustomMetric.update': ('api/metrics.html#custommetric.update', 'diffpose/metrics.py'), 107 | 'diffpose.metrics.DoubleGeodesic': ('api/metrics.html#doublegeodesic', 'diffpose/metrics.py'), 108 | 'diffpose.metrics.DoubleGeodesic.__init__': ( 'api/metrics.html#doublegeodesic.__init__', 109 | 'diffpose/metrics.py'), 110 | 'diffpose.metrics.DoubleGeodesic.forward': ( 'api/metrics.html#doublegeodesic.forward', 111 | 'diffpose/metrics.py'), 112 | 'diffpose.metrics.GeodesicSE3': ('api/metrics.html#geodesicse3', 'diffpose/metrics.py'), 113 | 'diffpose.metrics.GeodesicSE3.__init__': ('api/metrics.html#geodesicse3.__init__', 'diffpose/metrics.py'), 114 | 'diffpose.metrics.GeodesicSE3.forward': ('api/metrics.html#geodesicse3.forward', 'diffpose/metrics.py'), 115 | 'diffpose.metrics.GeodesicSO3': ('api/metrics.html#geodesicso3', 'diffpose/metrics.py'), 116 | 'diffpose.metrics.GeodesicSO3.__init__': ('api/metrics.html#geodesicso3.__init__', 'diffpose/metrics.py'), 117 | 'diffpose.metrics.GeodesicSO3.forward': ('api/metrics.html#geodesicso3.forward', 'diffpose/metrics.py'), 118 | 'diffpose.metrics.GeodesicTranslation': ('api/metrics.html#geodesictranslation', 'diffpose/metrics.py'), 119 | 'diffpose.metrics.GeodesicTranslation.__init__': ( 'api/metrics.html#geodesictranslation.__init__', 120 | 'diffpose/metrics.py'), 121 | 'diffpose.metrics.GeodesicTranslation.forward': ( 'api/metrics.html#geodesictranslation.forward', 122 | 'diffpose/metrics.py'), 123 | 'diffpose.metrics.GradientNormalizedCrossCorrelation': ( 'api/metrics.html#gradientnormalizedcrosscorrelation', 124 | 'diffpose/metrics.py'), 125 | 'diffpose.metrics.GradientNormalizedCrossCorrelation.__init__': ( 'api/metrics.html#gradientnormalizedcrosscorrelation.__init__', 126 | 'diffpose/metrics.py'), 127 | 'diffpose.metrics.MultiscaleNormalizedCrossCorrelation': ( 'api/metrics.html#multiscalenormalizedcrosscorrelation', 128 | 'diffpose/metrics.py'), 129 | 'diffpose.metrics.MultiscaleNormalizedCrossCorrelation.__init__': ( 'api/metrics.html#multiscalenormalizedcrosscorrelation.__init__', 130 | 'diffpose/metrics.py'), 131 | 'diffpose.metrics.NormalizedCrossCorrelation': ( 
'api/metrics.html#normalizedcrosscorrelation', 132 | 'diffpose/metrics.py'), 133 | 'diffpose.metrics.NormalizedCrossCorrelation.__init__': ( 'api/metrics.html#normalizedcrosscorrelation.__init__', 134 | 'diffpose/metrics.py')}, 135 | 'diffpose.registration': { 'diffpose.registration.PoseRegressor': ( 'api/registration.html#poseregressor', 136 | 'diffpose/registration.py'), 137 | 'diffpose.registration.PoseRegressor.__init__': ( 'api/registration.html#poseregressor.__init__', 138 | 'diffpose/registration.py'), 139 | 'diffpose.registration.PoseRegressor.forward': ( 'api/registration.html#poseregressor.forward', 140 | 'diffpose/registration.py'), 141 | 'diffpose.registration.SparseRegistration': ( 'api/registration.html#sparseregistration', 142 | 'diffpose/registration.py'), 143 | 'diffpose.registration.SparseRegistration.__init__': ( 'api/registration.html#sparseregistration.__init__', 144 | 'diffpose/registration.py'), 145 | 'diffpose.registration.SparseRegistration.forward': ( 'api/registration.html#sparseregistration.forward', 146 | 'diffpose/registration.py'), 147 | 'diffpose.registration.SparseRegistration.get_current_pose': ( 'api/registration.html#sparseregistration.get_current_pose', 148 | 'diffpose/registration.py'), 149 | 'diffpose.registration.VectorizedNormalizedCrossCorrelation2d': ( 'api/registration.html#vectorizednormalizedcrosscorrelation2d', 150 | 'diffpose/registration.py'), 151 | 'diffpose.registration.VectorizedNormalizedCrossCorrelation2d.__init__': ( 'api/registration.html#vectorizednormalizedcrosscorrelation2d.__init__', 152 | 'diffpose/registration.py'), 153 | 'diffpose.registration.VectorizedNormalizedCrossCorrelation2d.forward': ( 'api/registration.html#vectorizednormalizedcrosscorrelation2d.forward', 154 | 'diffpose/registration.py'), 155 | 'diffpose.registration.VectorizedNormalizedCrossCorrelation2d.forward_compute': ( 'api/registration.html#vectorizednormalizedcrosscorrelation2d.forward_compute', 156 | 'diffpose/registration.py'), 157 | 'diffpose.registration.VectorizedNormalizedCrossCorrelation2d.norm': ( 'api/registration.html#vectorizednormalizedcrosscorrelation2d.norm', 158 | 'diffpose/registration.py'), 159 | 'diffpose.registration.img_to_patches': ( 'api/registration.html#img_to_patches', 160 | 'diffpose/registration.py'), 161 | 'diffpose.registration.mask_to_img': ( 'api/registration.html#mask_to_img', 162 | 'diffpose/registration.py'), 163 | 'diffpose.registration.pred_to_patches': ( 'api/registration.html#pred_to_patches', 164 | 'diffpose/registration.py'), 165 | 'diffpose.registration.preprocess': ('api/registration.html#preprocess', 'diffpose/registration.py'), 166 | 'diffpose.registration.vector_to_img': ( 'api/registration.html#vector_to_img', 167 | 'diffpose/registration.py')}, 168 | 'diffpose.visualization': { 'diffpose.visualization._overlay_edges': ( 'api/visualization.html#_overlay_edges', 169 | 'diffpose/visualization.py'), 170 | 'diffpose.visualization.fiducials_3d_to_projected_fiducials_3d': ( 'api/visualization.html#fiducials_3d_to_projected_fiducials_3d', 171 | 'diffpose/visualization.py'), 172 | 'diffpose.visualization.fiducials_to_mesh': ( 'api/visualization.html#fiducials_to_mesh', 173 | 'diffpose/visualization.py'), 174 | 'diffpose.visualization.lines_to_mesh': ( 'api/visualization.html#lines_to_mesh', 175 | 'diffpose/visualization.py'), 176 | 'diffpose.visualization.overlay_edges': ( 'api/visualization.html#overlay_edges', 177 | 'diffpose/visualization.py')}}} 178 | 
-------------------------------------------------------------------------------- /diffpose/calibration.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/02_calibration.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['RigidTransform', 'convert', 'perspective_projection'] 5 | 6 | # %% ../notebooks/api/02_calibration.ipynb 4 7 | import torch 8 | 9 | # %% ../notebooks/api/02_calibration.ipynb 6 10 | from typing import Optional 11 | 12 | from beartype import beartype 13 | from diffdrr.utils import Transform3d 14 | from diffdrr.utils import convert as convert_so3 15 | from diffdrr.utils import se3_exp_map, se3_log_map 16 | from jaxtyping import Float, jaxtyped 17 | 18 | # %% ../notebooks/api/02_calibration.ipynb 7 19 | @beartype 20 | class RigidTransform(Transform3d): 21 | """Wrapper of pytorch3d.transforms.Transform3d with extra functionalities.""" 22 | 23 | @jaxtyped(typechecker=beartype) 24 | def __init__( 25 | self, 26 | R: Float[torch.Tensor, "..."], 27 | t: Float[torch.Tensor, "... 3"], 28 | parameterization: str = "matrix", 29 | convention: Optional[str] = None, 30 | device=None, 31 | dtype=torch.float32, 32 | ): 33 | if device is None and (R.device == t.device): 34 | device = R.device 35 | 36 | R = convert_so3(R, parameterization, "matrix", convention) 37 | if R.dim() == 2 and t.dim() == 1: 38 | R = R.unsqueeze(0) 39 | t = t.unsqueeze(0) 40 | assert (batch_size := len(R)) == len(t), "R and t need same batch size" 41 | 42 | matrix = torch.zeros(batch_size, 4, 4, device=device, dtype=dtype) 43 | matrix[..., :3, :3] = R.transpose(-1, -2) 44 | matrix[..., 3, :3] = t 45 | matrix[..., 3, 3] = 1 46 | 47 | super().__init__(matrix=matrix, device=device, dtype=dtype) 48 | 49 | def get_rotation(self, parameterization=None, convention=None): 50 | R = self.get_matrix()[..., :3, :3].transpose(-1, -2) 51 | if parameterization is not None: 52 | R = convert_so3(R, "matrix", parameterization, None, convention) 53 | return R 54 | 55 | def get_translation(self): 56 | return self.get_matrix()[..., 3, :3] 57 | 58 | def inverse(self): 59 | """Closed-form inverse for rigid transforms.""" 60 | R = self.get_rotation().transpose(-1, -2) 61 | t = self.get_translation() 62 | t = -torch.einsum("bij,bj->bi", R, t) 63 | return RigidTransform(R, t, device=self.device, dtype=self.dtype) 64 | 65 | def compose(self, other): 66 | T = super().compose(other) 67 | R = T.get_matrix()[..., :3, :3].transpose(-1, -2) 68 | t = T.get_matrix()[..., 3, :3] 69 | return RigidTransform(R, t, device=self.device, dtype=self.dtype) 70 | 71 | def clone(self): 72 | R = self.get_matrix()[..., :3, :3].transpose(-1, -2).clone() 73 | t = self.get_matrix()[..., 3, :3].clone() 74 | return RigidTransform(R, t, device=self.device, dtype=self.dtype) 75 | 76 | def get_se3_log(self): 77 | return se3_log_map(self.get_matrix()) 78 | 79 | # %% ../notebooks/api/02_calibration.ipynb 8 80 | def convert( 81 | transform, 82 | input_parameterization, 83 | output_parameterization, 84 | input_convention=None, 85 | output_convention=None, 86 | ): 87 | """Convert between representations of SE(3).""" 88 | 89 | # Convert any input parameterization to a RigidTransform 90 | if input_parameterization == "se3_log_map": 91 | transform = torch.concat([transform[1], transform[0]], axis=-1) 92 | matrix = se3_exp_map(transform).transpose(-1, -2) 93 | transform = RigidTransform( 94 | R=matrix[..., :3, :3], 95 | t=matrix[..., :3, 3], 96 | device=matrix.device, 97 | dtype=matrix.dtype, 
98 | ) 99 | elif input_parameterization == "se3_exp_map": 100 | pass 101 | else: 102 | transform = RigidTransform( 103 | R=transform[0], 104 | t=transform[1], 105 | parameterization=input_parameterization, 106 | convention=input_convention, 107 | ) 108 | 109 | # Convert the RigidTransform to any output 110 | if output_parameterization == "se3_exp_map": 111 | return transform 112 | elif output_parameterization == "se3_log_map": 113 | se3_log = transform.get_se3_log() 114 | log_t_vee = se3_log[..., :3] 115 | log_R_vee = se3_log[..., 3:] 116 | return log_R_vee, log_t_vee 117 | else: 118 | return ( 119 | transform.get_rotation(output_parameterization, output_convention), 120 | transform.get_translation(), 121 | ) 122 | 123 | # %% ../notebooks/api/02_calibration.ipynb 10 124 | @jaxtyped(typechecker=beartype) 125 | def perspective_projection( 126 | extrinsic: RigidTransform, # Extrinsic camera matrix (world to camera) 127 | intrinsic: Float[torch.Tensor, "3 3"], # Intrinsic camera matrix (camera to image) 128 | x: Float[torch.Tensor, "b n 3"], # World coordinates 129 | ) -> Float[torch.Tensor, "b n 2"]: 130 | x = extrinsic.transform_points(x) 131 | x = torch.einsum("ij, bnj -> bni", intrinsic, x) 132 | z = x[..., -1].unsqueeze(-1).clone() 133 | x = x / z 134 | return x[..., :2] 135 | -------------------------------------------------------------------------------- /diffpose/deepfluoro.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/00_deepfluoro.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['DeepFluoroDataset', 'convert_deepfluoro_to_diffdrr', 'convert_diffdrr_to_deepfluoro', 'Evaluator', 'preprocess', 5 | 'get_random_offset', 'Transforms'] 6 | 7 | # %% ../notebooks/api/00_deepfluoro.ipynb 3 8 | from pathlib import Path 9 | from typing import Optional, Union 10 | 11 | import h5py 12 | import numpy as np 13 | import torch 14 | from beartype import beartype 15 | 16 | from .calibration import RigidTransform, perspective_projection 17 | 18 | # %% ../notebooks/api/00_deepfluoro.ipynb 5 19 | @beartype 20 | class DeepFluoroDataset(torch.utils.data.Dataset): 21 | """ 22 | Get X-ray projections and poses from specimens in the `DeepFluoro` dataset. 23 | 24 | Given a specimen ID and projection index, returns the projection and the camera matrix for DiffDRR. 
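Example (a sketch, not in the original file; assumes the DeepFluoro HDF5 file was downloaded to `data/ipcai_2020_full_res_data.h5`, per the README): >>> specimen = DeepFluoroDataset(id_number=1) >>> img, pose = specimen[0] # X-ray tensor and its ground truth pose as a RigidTransform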
25 | """ 26 | 27 | def __init__( 28 | self, 29 | id_number: int, # Specimen number (1-6) 30 | filename: Optional[Union[str, Path]] = None, # Path to DeepFluoro h5 file 31 | preprocess: bool = True, # Preprocess X-rays 32 | ): 33 | # Load the volume 34 | ( 35 | self.specimen, 36 | self.projections, 37 | self.volume, 38 | self.spacing, 39 | self.lps2volume, 40 | self.intrinsic, 41 | self.extrinsic, 42 | self.focal_len, 43 | self.x0, 44 | self.y0, 45 | ) = load_deepfluoro_dataset(id_number, filename) 46 | self.preprocess = preprocess 47 | 48 | # Get the isocenter pose (AP viewing angle at volume isocenter) 49 | isocenter_rot = torch.tensor([[torch.pi / 2, 0.0, -torch.pi / 2]]) 50 | isocenter_xyz = torch.tensor(self.volume.shape) * self.spacing / 2 51 | isocenter_xyz = isocenter_xyz.unsqueeze(0) 52 | self.isocenter_pose = RigidTransform( 53 | isocenter_rot, isocenter_xyz, "euler_angles", "ZYX" 54 | ) 55 | 56 | # Camera matrices and fiducials for the specimen 57 | self.fiducials = get_3d_fiducials(self.specimen) 58 | 59 | # Miscellaneous transformation matrices for wrangling SE(3) poses 60 | self.flip_xz = RigidTransform( 61 | torch.tensor([[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]), 62 | torch.zeros(3), 63 | ) 64 | self.translate = RigidTransform( 65 | torch.eye(3), 66 | torch.tensor([-self.focal_len / 2, 0.0, 0.0]), 67 | ) 68 | self.flip_180 = RigidTransform( 69 | torch.tensor([[1.0, 0.0, 0.0], [0.0, -1.0, 0.0], [0.0, 0.0, -1.0]]), 70 | torch.zeros(3), 71 | ) 72 | 73 | def __len__(self): 74 | return len(self.projections) 75 | 76 | def __iter__(self): 77 | return iter(self[idx] for idx in range(len(self))) 78 | 79 | def __getitem__(self, idx): 80 | """ 81 | (1) Swap the x- and z-axes 82 | (2) Reverse the x-axis to make the matrix E(3) -> SE(3) 83 | (3) Move the camera to the origin 84 | (4) Rotate the detector plane by 180, if offset 85 | (5) Form the full SE(3) transformation matrix 86 | """ 87 | projection = self.projections[f"{idx:03d}"] 88 | img = torch.from_numpy(projection["image/pixels"][:]) 89 | world2volume = torch.from_numpy(projection["gt-poses/cam-to-pelvis-vol"][:]) 90 | world2volume = RigidTransform(world2volume[:3, :3], world2volume[:3, 3]) 91 | pose = convert_deepfluoro_to_diffdrr(self, world2volume) 92 | 93 | # Handle rotations in the imaging dataset 94 | if self._rot_180_for_up(idx): 95 | img = torch.rot90(img, k=2) 96 | pose = self.flip_180.compose(pose) 97 | 98 | # Optionally, preprocess the images 99 | img = img.unsqueeze(0).unsqueeze(0) 100 | if self.preprocess: 101 | img = preprocess(img) 102 | 103 | return img, pose 104 | 105 | def get_2d_fiducials(self, idx, pose): 106 | # Get the fiducials from the true camera pose 107 | _, true_pose = self.__getitem__(idx) 108 | extrinsic = ( 109 | self.lps2volume.inverse() 110 | .compose(true_pose.inverse()) 111 | .compose(self.translate) 112 | .compose(self.flip_xz) 113 | ) 114 | true_fiducials = perspective_projection( 115 | extrinsic, self.intrinsic, self.fiducials 116 | ) 117 | 118 | # Get the fiducials from the predicted camera pose 119 | extrinsic = ( 120 | self.lps2volume.inverse() 121 | .compose(pose.cpu().inverse()) 122 | .compose(self.translate) 123 | .compose(self.flip_xz) 124 | ) 125 | pred_fiducials = perspective_projection( 126 | extrinsic, self.intrinsic, self.fiducials 127 | ) 128 | 129 | if self.preprocess: 130 | true_fiducials -= 50 131 | pred_fiducials -= 50 132 | 133 | return true_fiducials, pred_fiducials 134 | 135 | def _rot_180_for_up(self, idx): 136 | return 
self.projections[f"{idx:03d}"]["rot-180-for-up"][()] 137 | 138 | # %% ../notebooks/api/00_deepfluoro.ipynb 6 139 | def convert_deepfluoro_to_diffdrr(specimen, pose: RigidTransform): 140 | """Transform the camera coordinate system used in DeepFluoro to the convention used by DiffDRR.""" 141 | return ( 142 | specimen.translate.compose(specimen.flip_xz) 143 | .compose(specimen.extrinsic.inverse()) 144 | .compose(pose) 145 | .compose(specimen.lps2volume.inverse()) 146 | ) 147 | 148 | 149 | def convert_diffdrr_to_deepfluoro(specimen, pose: RigidTransform): 150 | """Transform the camera coordinate system used in DiffDRR to the convention used by DeepFluoro.""" 151 | return ( 152 | specimen.lps2volume.inverse() 153 | .compose(pose.inverse()) 154 | .compose(specimen.translate) 155 | .compose(specimen.flip_xz) 156 | ) 157 | 158 | # %% ../notebooks/api/00_deepfluoro.ipynb 7 159 | from torch.nn.functional import pad 160 | 161 | from .calibration import perspective_projection 162 | 163 | 164 | class Evaluator: 165 | def __init__(self, specimen, idx): 166 | # Save matrices to device 167 | self.translate = specimen.translate 168 | self.flip_xz = specimen.flip_xz 169 | self.intrinsic = specimen.intrinsic 170 | self.intrinsic_inv = specimen.intrinsic.inverse() 171 | 172 | # Get gt fiducial locations 173 | self.specimen = specimen 174 | self.fiducials = specimen.fiducials 175 | gt_pose = specimen[idx][1] 176 | self.true_projected_fiducials = self.project(gt_pose) 177 | 178 | def project(self, pose): 179 | extrinsic = convert_diffdrr_to_deepfluoro(self.specimen, pose) 180 | x = perspective_projection(extrinsic, self.intrinsic, self.fiducials) 181 | x = -self.specimen.focal_len * torch.einsum( 182 | "ij, bnj -> bni", 183 | self.intrinsic_inv, 184 | pad(x, (0, 1), value=1), # Convert to homogeneous coordinates 185 | ) 186 | extrinsic = ( 187 | self.flip_xz.inverse().compose(self.translate.inverse()).compose(pose) 188 | ) 189 | return extrinsic.transform_points(x) 190 | 191 | def __call__(self, pose): 192 | pred_projected_fiducials = self.project(pose) 193 | registration_error = ( 194 | (self.true_projected_fiducials - pred_projected_fiducials) 195 | .norm(dim=-1) 196 | .mean() 197 | ) 198 | registration_error *= 0.194 # Pixel spacing is 0.194 mm / pixel isotropic 199 | return registration_error 200 | 201 | # %% ../notebooks/api/00_deepfluoro.ipynb 8 202 | from diffdrr.utils import parse_intrinsic_matrix 203 | 204 | 205 | def load_deepfluoro_dataset(id_number, filename): 206 | # Open the H5 file for the dataset 207 | if filename is None: 208 | root = Path(__file__).parent.parent.absolute() 209 | filename = root / "data/ipcai_2020_full_res_data.h5" 210 | f = h5py.File(filename, "r") 211 | ( 212 | intrinsic, 213 | extrinsic, 214 | num_cols, 215 | num_rows, 216 | proj_col_spacing, 217 | proj_row_spacing, 218 | ) = parse_proj_params(f) 219 | focal_len, x0, y0 = parse_intrinsic_matrix( 220 | intrinsic, 221 | num_rows, 222 | num_cols, 223 | proj_row_spacing, 224 | proj_col_spacing, 225 | ) 226 | 227 | # Try to load the particular specimen 228 | assert id_number in {1, 2, 3, 4, 5, 6} 229 | specimen_id = [ 230 | "17-1882", 231 | "18-1109", 232 | "18-0725", 233 | "18-2799", 234 | "18-2800", 235 | "17-1905", 236 | ][id_number - 1] 237 | specimen = f[specimen_id] 238 | projections = specimen["projections"] 239 | 240 | # Parse the volume 241 | volume, spacing, lps2volume = parse_volume(specimen) 242 | return ( 243 | specimen, 244 | projections, 245 | volume, 246 | spacing, 247 | lps2volume, 248 | intrinsic, 249 |
extrinsic, 250 | focal_len, 251 | x0, 252 | y0, 253 | ) 254 | 255 | 256 | def parse_volume(specimen): 257 | # Parse the volume 258 | spacing = specimen["vol/spacing"][:].flatten() 259 | volume = specimen["vol/pixels"][:].astype(np.float32) 260 | volume = np.swapaxes(volume, 0, 2)[::-1].copy() 261 | 262 | # Parse the translation matrix from LPS coordinates to volume coordinates 263 | origin = torch.from_numpy(specimen["vol/origin"][:].flatten()) 264 | lps2volume = RigidTransform(torch.eye(3), origin) 265 | return volume, spacing, lps2volume 266 | 267 | 268 | def parse_proj_params(f): 269 | proj_params = f["proj-params"] 270 | extrinsic = torch.from_numpy(proj_params["extrinsic"][:]) 271 | extrinsic = RigidTransform(extrinsic[..., :3, :3], extrinsic[:3, 3]) 272 | intrinsic = torch.from_numpy(proj_params["intrinsic"][:]) 273 | num_cols = float(proj_params["num-cols"][()]) 274 | num_rows = float(proj_params["num-rows"][()]) 275 | proj_col_spacing = float(proj_params["pixel-col-spacing"][()]) 276 | proj_row_spacing = float(proj_params["pixel-row-spacing"][()]) 277 | return intrinsic, extrinsic, num_cols, num_rows, proj_col_spacing, proj_row_spacing 278 | 279 | 280 | def get_3d_fiducials(specimen): 281 | fiducials = [] 282 | for landmark in specimen["vol-landmarks"]: 283 | pt_3d = specimen["vol-landmarks"][landmark][:] 284 | pt_3d = torch.from_numpy(pt_3d) 285 | fiducials.append(pt_3d) 286 | return torch.stack(fiducials, dim=0).permute(2, 0, 1) 287 | 288 | # %% ../notebooks/api/00_deepfluoro.ipynb 9 289 | from torchvision.transforms.functional import center_crop, gaussian_blur 290 | 291 | 292 | def preprocess(img, size=None, initial_energy=torch.tensor(65487.0)): 293 | r""" 294 | Recover the line integral: $L[i,j] = \log I_0 - \log I_f[i,j]$ 295 | 296 | (1) Remove the edge induced by the collimator 297 | (2) Smooth the image to reduce noise 298 | (3) Subtract the log initial energy for each ray 299 | (4) Recover the line integral image 300 | (5) Rescale image to [0, 1] 301 | """ 302 | img = center_crop(img, (1436, 1436)) 303 | img = gaussian_blur(img, (5, 5), sigma=1.0) 304 | img = initial_energy.log() - img.log() 305 | img = (img - img.min()) / (img.max() - img.min()) 306 | return img 307 | 308 | # %% ../notebooks/api/00_deepfluoro.ipynb 26 309 | from .calibration import RigidTransform, convert 310 | 311 | 312 | @beartype 313 | def get_random_offset(batch_size: int, device) -> RigidTransform: 314 | r1 = torch.distributions.Normal(0, 0.2).sample((batch_size,)) 315 | r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,)) 316 | r3 = torch.distributions.Normal(0, 0.25).sample((batch_size,)) 317 | t1 = torch.distributions.Normal(10, 70).sample((batch_size,)) 318 | t2 = torch.distributions.Normal(250, 90).sample((batch_size,)) 319 | t3 = torch.distributions.Normal(5, 50).sample((batch_size,)) 320 | log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device) 321 | log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device) 322 | return convert( 323 | [log_R_vee, log_t_vee], 324 | "se3_log_map", 325 | "se3_exp_map", 326 | ) 327 | 328 | # %% ../notebooks/api/00_deepfluoro.ipynb 32 329 | from torchvision.transforms import Compose, Lambda, Normalize, Resize 330 | 331 | 332 | class Transforms: 333 | def __init__( 334 | self, 335 | size: int, # Dimension to resize image 336 | eps: float = 1e-6, 337 | ): 338 | """Transform X-rays and DRRs before inputting to CNN.""" 339 | self.transforms = Compose( 340 | [ 341 | Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + eps)), 342 | Resize((size, size), antialias=True), 343
| Normalize(mean=0.3080, std=0.1494), 344 | ] 345 | ) 346 | 347 | def __call__(self, x): 348 | return self.transforms(x) 349 | -------------------------------------------------------------------------------- /diffpose/jacobians.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/06_jacobians.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['JacobianDRR', 'gradient_matching', 'plot_img_jacobian'] 5 | 6 | # %% ../notebooks/api/06_jacobians.ipynb 3 7 | import torch 8 | 9 | 10 | class JacobianDRR(torch.nn.Module): 11 | """Computes the Jacobian of a DRR wrt pose parameters.""" 12 | 13 | def __init__(self, drr, rotation, translation, parameterization, convention=None): 14 | super().__init__() 15 | self.drr = drr 16 | self.rotation = torch.nn.Parameter(rotation.clone()) 17 | self.translation = torch.nn.Parameter(translation.clone()) 18 | self.parameterization = parameterization 19 | self.convention = convention 20 | 21 | def forward(self): 22 | I = self.cast(self.rotation, self.translation) 23 | J = torch.autograd.functional.jacobian( 24 | self.cast, 25 | (self.rotation, self.translation), 26 | vectorize=True, 27 | strategy="forward-mode", 28 | ) 29 | J = torch.concat([self.permute(j) for j in J], dim=0) 30 | return I, J 31 | 32 | def cast(self, rotation, translation): 33 | return self.drr(rotation, translation, self.parameterization, self.convention) 34 | 35 | def permute(self, x): 36 | return x.permute(-1, 0, 2, 3, 1, 4)[..., 0, 0] 37 | 38 | # %% ../notebooks/api/06_jacobians.ipynb 4 39 | def gradient_matching(J0, J1): 40 | J0 /= J0.norm(dim=[-1, -2], keepdim=True) 41 | J1 /= J1.norm(dim=[-1, -2], keepdim=True) 42 | return (J0 - J1).norm() 43 | 44 | # %% ../notebooks/api/06_jacobians.ipynb 5 45 | import matplotlib.pyplot as plt 46 | from matplotlib.ticker import FuncFormatter 47 | 48 | 49 | def plot_img_jacobian(I, J, **kwargs): 50 | def fmt(x, pos): 51 | a, b = f"{x:.0e}".split("e") 52 | a = float(a) 53 | b = int(b) 54 | if a == 0: 55 | return "0" 56 | elif b == 0: 57 | if a < 0: 58 | return "-1" 59 | else: 60 | return "1" 61 | elif a < 0: 62 | return rf"$-10^{{{b}}}$" 63 | else: 64 | return rf"$10^{{{b}}}$" 65 | 66 | plt.figure(figsize=(10, 4), dpi=300, constrained_layout=True) 67 | plt.subplot(2, 4, 2) 68 | plt.title("J(yaw)") 69 | plt.imshow(J[0].squeeze().cpu().detach(), **kwargs) 70 | plt.colorbar(format=FuncFormatter(fmt)) 71 | plt.axis("off") 72 | plt.subplot(2, 4, 3) 73 | plt.title("J(pitch)") 74 | plt.imshow(J[1].squeeze().cpu().detach(), **kwargs) 75 | plt.colorbar(format=FuncFormatter(fmt)) 76 | plt.axis("off") 77 | plt.subplot(2, 4, 4) 78 | plt.title("J(roll)") 79 | plt.imshow(J[2].squeeze().cpu().detach(), **kwargs) 80 | plt.colorbar(format=FuncFormatter(fmt)) 81 | plt.axis("off") 82 | plt.subplot(2, 4, 6) 83 | plt.title("J(x)") 84 | plt.imshow(J[3].squeeze().cpu().detach(), **kwargs) 85 | plt.colorbar(format=FuncFormatter(fmt)) 86 | plt.axis("off") 87 | plt.subplot(2, 4, 7) 88 | plt.title("J(y)") 89 | plt.imshow(J[4].squeeze().cpu().detach(), **kwargs) 90 | plt.colorbar(format=FuncFormatter(fmt)) 91 | plt.axis("off") 92 | plt.subplot(2, 4, 8) 93 | plt.title("J(z)") 94 | plt.imshow(J[5].squeeze().cpu().detach(), **kwargs) 95 | plt.colorbar(format=FuncFormatter(fmt)) 96 | plt.axis("off") 97 | plt.subplot(2, 4, 1) 98 | plt.title("img") 99 | plt.imshow(I.cpu().detach().squeeze(), cmap="gray") 100 | plt.axis("off") 101 | plt.colorbar() 102 | plt.show() 103 | 
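A usage sketch for `JacobianDRR` (editorial addition, not part of the repo; assumes an already-configured `diffdrr.drr.DRR` instance `drr`, and the pose values are illustrative):

``` python
import torch

from diffpose.jacobians import JacobianDRR

# Pose at which to evaluate the Jacobian (rotations in radians, translations in mm)
rotation = torch.tensor([[0.0, 0.0, 0.0]])
translation = torch.tensor([[0.0, 850.0, 0.0]])

jacobian_drr = JacobianDRR(
    drr, rotation, translation, parameterization="euler_angles", convention="ZYX"
)
I, J = jacobian_drr()  # I: rendered DRR; J: six per-pixel derivatives (yaw, pitch, roll, x, y, z)
```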
-------------------------------------------------------------------------------- /diffpose/ljubljana.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/01_ljubljana.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['LjubljanaDataset', 'get_random_offset', 'Evaluator', 'Transforms'] 5 | 6 | # %% ../notebooks/api/01_ljubljana.ipynb 3 7 | from pathlib import Path 8 | from typing import Optional, Union 9 | 10 | import h5py 11 | import torch 12 | from beartype import beartype 13 | 14 | from .calibration import RigidTransform 15 | 16 | # %% ../notebooks/api/01_ljubljana.ipynb 5 17 | from diffdrr.utils import parse_intrinsic_matrix 18 | 19 | 20 | @beartype 21 | class LjubljanaDataset(torch.utils.data.Dataset): 22 | """ 23 | Get X-ray projections and poses from subjects in the `Ljubljana` dataset. 24 | 25 | Given a projection view and subject index, returns the projection and the camera matrix for DiffDRR. 26 | """ 27 | 28 | def __init__( 29 | self, 30 | view: str, # "ap" or "lat" or "ap-max" or "lat-max" 31 | filename: Optional[Union[str, Path]] = None, # Path to Ljubljana h5 file 32 | preprocess: bool = True, # Preprocess X-rays 33 | ): 34 | self.view = view 35 | self.preprocess = preprocess 36 | 37 | if filename is None: 38 | root = Path(__file__).parent.parent.absolute() 39 | filename = root / "data/ljubljana.h5" 40 | self.f = h5py.File(filename) 41 | 42 | self.flip_xz = RigidTransform( 43 | torch.tensor( 44 | [ 45 | [0.0, 0.0, -1.0], 46 | [0.0, 1.0, 0.0], 47 | [1.0, 0.0, 0.0], 48 | ] 49 | ), 50 | torch.zeros(3), 51 | ) 52 | 53 | def __len__(self): 54 | return 10 55 | 56 | def __iter__(self): 57 | return iter(self[idx] for idx in range(len(self))) 58 | 59 | def __getitem__(self, idx): 60 | idx += 1 61 | extrinsic = self.f[f"subject{idx:02d}/proj-{self.view}/extrinsic"][:] 62 | extrinsic = torch.from_numpy(extrinsic) 63 | extrinsic = RigidTransform(extrinsic[..., :3, :3], extrinsic[:3, 3]) 64 | 65 | intrinsic = self.f[f"subject{idx:02d}/proj-{self.view}/intrinsic"][:] 66 | intrinsic = torch.from_numpy(intrinsic) 67 | 68 | delx = self.f[f"subject{idx:02d}/proj-{self.view}/col-spacing"][()] 69 | dely = self.f[f"subject{idx:02d}/proj-{self.view}/row-spacing"][()] 70 | 71 | img = torch.from_numpy(self.f[f"subject{idx:02d}/proj-{self.view}/pixels"][:]) 72 | if self.preprocess: 73 | img += 1 74 | img = img.max().log() - img.log() 75 | height, width = img.shape 76 | img = img.unsqueeze(0).unsqueeze(0) 77 | 78 | focal_len, x0, y0 = parse_intrinsic_matrix( 79 | intrinsic, 80 | height, 81 | width, 82 | dely, 83 | delx, 84 | ) 85 | 86 | translate = RigidTransform( 87 | torch.eye(3), 88 | torch.tensor([-focal_len / 2, 0.0, 0.0]), 89 | ) 90 | pose = translate.compose(self.flip_xz).compose(extrinsic.inverse()) 91 | 92 | volume = self.f[f"subject{idx:02d}/volume/pixels"][:] 93 | spacing = self.f[f"subject{idx:02d}/volume/spacing"][:] 94 | 95 | isocenter_rot = torch.tensor([[torch.pi / 2, 0.0, -torch.pi / 2]]) 96 | isocenter_xyz = torch.tensor(volume.shape) * spacing / 2 97 | isocenter_xyz = isocenter_xyz.unsqueeze(0) 98 | isocenter_pose = RigidTransform( 99 | isocenter_rot, isocenter_xyz, "euler_angles", "ZYX" 100 | ) 101 | 102 | return ( 103 | volume, 104 | spacing, 105 | focal_len, 106 | height, 107 | width, 108 | delx, 109 | dely, 110 | x0, 111 | y0, 112 | img, 113 | pose, 114 | isocenter_pose, 115 | ) 116 | 117 | # %% ../notebooks/api/01_ljubljana.ipynb 7 118 | from .calibration import RigidTransform, convert 119 |
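# NOTE (editorial comment, not in the original file): `get_random_offset` samples view-specific Gaussian perturbations in the log-space of SE(3), with rotational components (r1, r2, r3) in radians and translational components (t1, t2, t3) in millimeters, and maps them back to a RigidTransform via the exponential map.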
120 | 121 | @beartype 122 | def get_random_offset(view, batch_size: int, device) -> RigidTransform: 123 | if view == "ap": 124 | t1 = torch.distributions.Normal(-6, 30).sample((batch_size,)) 125 | t2 = torch.distributions.Normal(175, 30).sample((batch_size,)) 126 | t3 = torch.distributions.Normal(-5, 30).sample((batch_size,)) 127 | r1 = torch.distributions.Normal(0, 0.1).sample((batch_size,)) 128 | r2 = torch.distributions.Normal(0, 0.1).sample((batch_size,)) 129 | r3 = torch.distributions.Normal(-0.15, 0.25).sample((batch_size,)) 130 | elif view == "lat": 131 | t1 = torch.distributions.Normal(75, 30).sample((batch_size,)) 132 | t2 = torch.distributions.Normal(-80, 30).sample((batch_size,)) 133 | t3 = torch.distributions.Normal(-5, 30).sample((batch_size,)) 134 | r1 = torch.distributions.Normal(0, 0.1).sample((batch_size,)) 135 | r2 = torch.distributions.Normal(0, 0.05).sample((batch_size,)) 136 | r3 = torch.distributions.Normal(1.55, 0.05).sample((batch_size,)) 137 | else: 138 | raise ValueError(f"view must be 'ap' or 'lat', not '{view}'") 139 | 140 | log_R_vee = torch.stack([r1, r2, r3], dim=1).to(device) 141 | log_t_vee = torch.stack([t1, t2, t3], dim=1).to(device) 142 | return convert( 143 | [log_R_vee, log_t_vee], 144 | "se3_log_map", 145 | "se3_exp_map", 146 | ) 147 | 148 | # %% ../notebooks/api/01_ljubljana.ipynb 9 149 | from torch.nn.functional import pad 150 | 151 | from .calibration import perspective_projection 152 | 153 | 154 | class Evaluator: 155 | def __init__(self, subject, idx): 156 | # Get transformation matrices of the camera 157 | (_, _, focal_len, _, _, _, _, _, _, _, gt_pose, _) = subject[idx] 158 | self.focal_len = focal_len 159 | intrinsic = subject.f[f"subject{idx + 1:02d}/proj-{subject.view}/intrinsic"][:] 160 | self.intrinsic = torch.from_numpy(intrinsic) 161 | self.translate = RigidTransform( 162 | torch.eye(3), 163 | torch.tensor([-self.focal_len / 2, 0.0, 0.0]), 164 | ) 165 | self.flip_xz = subject.flip_xz 166 | 167 | # Get the ground truth projections 168 | self.points = torch.from_numpy( 169 | subject.f[f"subject{idx + 1:02d}/points"][:] 170 | ).unsqueeze(0) 171 | self.true_projected_fiducials = self.project(gt_pose) 172 | 173 | def project(self, pose): 174 | extrinsic = pose.inverse().compose(self.translate).compose(self.flip_xz) 175 | x = perspective_projection(extrinsic, self.intrinsic, self.points) 176 | x = self.focal_len * torch.einsum( 177 | "ij, bnj -> bni", 178 | self.intrinsic.inverse(), 179 | pad(x, (0, 1), value=1), # Convert to homogeneous coordinates 180 | ) 181 | extrinsic = ( 182 | self.flip_xz.inverse().compose(self.translate.inverse()).compose(pose) 183 | ) 184 | return extrinsic.transform_points(x) 185 | 186 | def __call__(self, pose): 187 | pred_projected_fiducials = self.project(pose) 188 | registration_error = ( 189 | (self.true_projected_fiducials - pred_projected_fiducials) 190 | .norm(dim=-1) 191 | .mean() 192 | ) 193 | registration_error *= 0.154 # Pixel spacing is 0.154 mm / pixel isotropic 194 | return registration_error 195 | 196 | # %% ../notebooks/api/01_ljubljana.ipynb 12 197 | from torchvision.transforms import Compose, Lambda, Normalize, Resize 198 | 199 | 200 | class Transforms: 201 | def __init__( 202 | self, 203 | height: int, 204 | width: int, 205 | eps: float = 1e-6, 206 | ): 207 | """Transform X-rays and DRRs before inputting to CNN.""" 208 | self.transforms = Compose( 209 | [ 210 | Lambda(lambda x: (x - x.min()) / (x.max() - x.min() + eps)), 211 | Resize((height, width), antialias=True), 212 | Normalize(mean=0.0774,
std=0.0569), 213 | ] 214 | ) 215 | 216 | def __call__(self, x): 217 | return self.transforms(x) 218 | -------------------------------------------------------------------------------- /diffpose/metrics.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/04_metrics.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['NormalizedCrossCorrelation', 'MultiscaleNormalizedCrossCorrelation', 'GradientNormalizedCrossCorrelation', 5 | 'GeodesicSO3', 'GeodesicTranslation', 'GeodesicSE3', 'DoubleGeodesic'] 6 | 7 | # %% ../notebooks/api/04_metrics.ipynb 3 8 | from diffdrr.metrics import ( 9 | GradientNormalizedCrossCorrelation2d, 10 | MultiscaleNormalizedCrossCorrelation2d, 11 | NormalizedCrossCorrelation2d, 12 | ) 13 | from torchmetrics import Metric 14 | 15 | # %% ../notebooks/api/04_metrics.ipynb 5 16 | class CustomMetric(Metric): 17 | is_differentiable = True 18 | 19 | def __init__(self, LossClass, **kwargs): 20 | super().__init__() 21 | self.lossfn = LossClass(**kwargs) 22 | self.add_state("loss", default=torch.tensor(0.0), dist_reduce_fx="sum") 23 | self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") 24 | 25 | def update(self, preds, target): 26 | self.loss += self.lossfn(preds, target).sum() 27 | self.count += len(preds) 28 | 29 | def compute(self): 30 | return self.loss.float() / self.count 31 | 32 | # %% ../notebooks/api/04_metrics.ipynb 7 33 | class NormalizedCrossCorrelation(CustomMetric): 34 | """`torchmetrics` wrapper for NCC.""" 35 | 36 | higher_is_better = True 37 | 38 | def __init__(self, patch_size=None): 39 | super().__init__(NormalizedCrossCorrelation2d, patch_size=patch_size) 40 | 41 | 42 | class MultiscaleNormalizedCrossCorrelation(CustomMetric): 43 | """`torchmetrics` wrapper for Multiscale NCC.""" 44 | 45 | higher_is_better = True 46 | 47 | def __init__(self, patch_sizes, patch_weights): 48 | super().__init__( 49 | MultiscaleNormalizedCrossCorrelation2d, 50 | patch_sizes=patch_sizes, 51 | patch_weights=patch_weights, 52 | ) 53 | 54 | 55 | class GradientNormalizedCrossCorrelation(CustomMetric): 56 | """`torchmetrics` wrapper for GradNCC.""" 57 | 58 | higher_is_better = True 59 | 60 | def __init__(self, patch_size=None): 61 | super().__init__(GradientNormalizedCrossCorrelation2d, patch_size=patch_size) 62 | 63 | # %% ../notebooks/api/04_metrics.ipynb 9 64 | import torch 65 | from beartype import beartype 66 | from diffdrr.utils import ( 67 | convert, 68 | so3_log_map, 69 | so3_relative_angle, 70 | so3_rotation_angle, 71 | standardize_quaternion, 72 | ) 73 | from jaxtyping import Float, jaxtyped 74 | 75 | from .calibration import RigidTransform 76 | 77 | # %% ../notebooks/api/04_metrics.ipynb 10 78 | class GeodesicSO3(torch.nn.Module): 79 | """Calculate the angular distance between two rotations in SO(3).""" 80 | 81 | def __init__(self): 82 | super().__init__() 83 | 84 | @jaxtyped(typechecker=beartype) 85 | def forward( 86 | self, 87 | pose_1: RigidTransform, 88 | pose_2: RigidTransform, 89 | ) -> Float[torch.Tensor, "b"]: 90 | r1 = pose_1.get_rotation() 91 | r2 = pose_2.get_rotation() 92 | rdiff = r1 @ r2.transpose(-1, -2) 93 | return so3_log_map(rdiff).norm(dim=-1) 94 | 95 | 96 | class GeodesicTranslation(torch.nn.Module): 97 | """Calculate the Euclidean distance between the translations of two rigid transforms.""" 98 | 99 | def __init__(self): 100 | super().__init__() 101 | 102 | @jaxtyped(typechecker=beartype) 103 | def forward( 104 | self, 105 | pose_1: RigidTransform, 106 | pose_2: RigidTransform,
107 | ) -> Float[torch.Tensor, "b"]: 108 | t1 = pose_1.get_translation() 109 | t2 = pose_2.get_translation() 110 | return (t1 - t2).norm(dim=1) 111 | 112 | # %% ../notebooks/api/04_metrics.ipynb 11 113 | class GeodesicSE3(torch.nn.Module): 114 | """Calculate the distance between transforms in the log-space of SE(3).""" 115 | 116 | def __init__(self): 117 | super().__init__() 118 | 119 | @jaxtyped(typechecker=beartype) 120 | def forward( 121 | self, 122 | pose_1: RigidTransform, 123 | pose_2: RigidTransform, 124 | ) -> Float[torch.Tensor, "b"]: 125 | return pose_2.compose(pose_1.inverse()).get_se3_log().norm(dim=1) 126 | 127 | # %% ../notebooks/api/04_metrics.ipynb 12 128 | @beartype 129 | class DoubleGeodesic(torch.nn.Module): 130 | """Calculate the angular and translational geodesics between two SE(3) transformation matrices.""" 131 | 132 | def __init__( 133 | self, 134 | sdr: float, # Source-to-detector radius 135 | eps: float = 1e-4, # Avoid overflows in sqrt 136 | ): 137 | super().__init__() 138 | self.sdr = sdr 139 | self.eps = eps 140 | 141 | self.rotation = GeodesicSO3() 142 | self.translation = GeodesicTranslation() 143 | 144 | @jaxtyped(typechecker=beartype) 145 | def forward(self, pose_1: RigidTransform, pose_2: RigidTransform): 146 | angular_geodesic = self.sdr * self.rotation(pose_1, pose_2) 147 | translation_geodesic = self.translation(pose_1, pose_2) 148 | double_geodesic = ( 149 | (angular_geodesic).square() + translation_geodesic.square() + self.eps 150 | ).sqrt() 151 | return angular_geodesic, translation_geodesic, double_geodesic 152 | -------------------------------------------------------------------------------- /diffpose/registration.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/03_registration.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['PoseRegressor', 'SparseRegistration', 'VectorizedNormalizedCrossCorrelation2d'] 5 | 6 | # %% ../notebooks/api/03_registration.ipynb 3 7 | import timm 8 | import torch 9 | 10 | # %% ../notebooks/api/03_registration.ipynb 5 11 | from .calibration import RigidTransform, convert 12 | 13 | 14 | class PoseRegressor(torch.nn.Module): 15 | """ 16 | A PoseRegressor is comprised of a pretrained backbone model that extracts features 17 | from an input X-ray and two linear layers that decode these features into rotational 18 | and translational camera pose parameters, respectively. 
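    A minimal usage sketch (the backbone and parameterization below are
    illustrative choices taken from the training scripts, not fixed defaults):

        model = PoseRegressor("resnet18", parameterization="se3_log_map")
        x = torch.randn(1, 1, 256, 256)  # a batch of single-channel X-rays
        offset = model(x)  # one RigidTransform (an SE(3) camera pose) per image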
19 | """ 20 | 21 | def __init__( 22 | self, 23 | model_name, 24 | parameterization, 25 | convention=None, 26 | pretrained=False, 27 | **kwargs, 28 | ): 29 | super().__init__() 30 | 31 | self.parameterization = parameterization 32 | self.convention = convention 33 | n_angular_components = N_ANGULAR_COMPONENTS[parameterization] 34 | 35 | # Get the size of the output from the backbone 36 | self.backbone = timm.create_model( 37 | model_name, 38 | pretrained, 39 | num_classes=0, 40 | in_chans=1, 41 | **kwargs, 42 | ) 43 | output = self.backbone(torch.randn(1, 1, 256, 256)).shape[-1] 44 | self.xyz_regression = torch.nn.Linear(output, 3) 45 | self.rot_regression = torch.nn.Linear(output, n_angular_components) 46 | 47 | def forward(self, x): 48 | x = self.backbone(x) 49 | rot = self.rot_regression(x) 50 | xyz = self.xyz_regression(x) 51 | return convert( 52 | [rot, xyz], 53 | input_parameterization=self.parameterization, 54 | output_parameterization="se3_exp_map", 55 | input_convention=self.convention, 56 | ) 57 | 58 | # %% ../notebooks/api/03_registration.ipynb 6 59 | N_ANGULAR_COMPONENTS = { 60 | "axis_angle": 3, 61 | "euler_angles": 3, 62 | "se3_log_map": 3, 63 | "quaternion": 4, 64 | "rotation_6d": 6, 65 | "rotation_10d": 10, 66 | "quaternion_adjugate": 10, 67 | } 68 | 69 | # %% ../notebooks/api/03_registration.ipynb 11 70 | from diffdrr.detector import make_xrays 71 | from diffdrr.drr import DRR 72 | from diffdrr.siddon import siddon_raycast 73 | 74 | from .calibration import RigidTransform 75 | 76 | 77 | class SparseRegistration(torch.nn.Module): 78 | def __init__( 79 | self, 80 | drr: DRR, 81 | pose: RigidTransform, 82 | parameterization: str, 83 | convention: str = None, 84 | features=None, # Used to compute biased estimate of mNCC 85 | n_patches: int = None, # If n_patches is None, render the whole image 86 | patch_size: int = 13, 87 | ): 88 | super().__init__() 89 | self.drr = drr 90 | 91 | # Parse the input pose 92 | rotation, translation = convert( 93 | pose, 94 | input_parameterization="se3_exp_map", 95 | output_parameterization=parameterization, 96 | output_convention=convention, 97 | ) 98 | self.parameterization = parameterization 99 | self.convention = convention 100 | self.rotation = torch.nn.Parameter(rotation) 101 | self.translation = torch.nn.Parameter(translation) 102 | 103 | # Crop pixels off the edge such that pixels don't fall outside the image 104 | self.n_patches = n_patches 105 | self.patch_size = patch_size 106 | self.patch_radius = self.patch_size // 2 + 1 107 | self.height = self.drr.detector.height 108 | self.width = self.drr.detector.width 109 | self.f_height = self.height - 2 * self.patch_radius 110 | self.f_width = self.width - 2 * self.patch_radius 111 | 112 | # Define the distribution over patch centers 113 | if features is None: 114 | features = torch.ones( 115 | self.height, self.width, device=self.rotation.device 116 | ) / (self.height * self.width) 117 | self.patch_centers = torch.distributions.categorical.Categorical( 118 | probs=features.squeeze()[ 119 | self.patch_radius : -self.patch_radius, 120 | self.patch_radius : -self.patch_radius, 121 | ].flatten() 122 | ) 123 | 124 | def forward(self, n_patches=None, patch_size=None): 125 | # Parse initial density 126 | if not hasattr(self.drr, "density"): 127 | self.drr.set_bone_attenuation_multiplier( 128 | self.drr.bone_attenuation_multiplier 129 | ) 130 | 131 | if n_patches is not None or patch_size is not None: 132 | self.n_patches = n_patches 133 | self.patch_size = patch_size 134 | 135 | # Make the mask for 
sparse rendering 136 | if self.n_patches is None: 137 | mask = torch.ones( 138 | 1, 139 | self.height, 140 | self.width, 141 | dtype=torch.bool, 142 | device=self.rotation.device, 143 | ) 144 | else: 145 | mask = torch.zeros( 146 | self.n_patches, 147 | self.height, 148 | self.width, 149 | dtype=torch.bool, 150 | device=self.rotation.device, 151 | ) 152 | radius = self.patch_size // 2 153 | idxs = self.patch_centers.sample(sample_shape=torch.Size([self.n_patches])) 154 | idxs, jdxs = ( 155 | idxs // self.f_height + self.patch_radius, 156 | idxs % self.f_width + self.patch_radius, 157 | ) 158 | 159 | idx = torch.arange(-radius, radius + 1, device=self.rotation.device) 160 | patches = torch.cartesian_prod(idx, idx).expand(self.n_patches, -1, -1) 161 | patches = patches + torch.stack([idxs, jdxs], dim=-1).unsqueeze(1) 162 | patches = torch.concat( 163 | [ 164 | torch.arange(self.n_patches, device=self.rotation.device) 165 | .unsqueeze(-1) 166 | .expand(-1, self.patch_size**2) 167 | .unsqueeze(-1), 168 | patches, 169 | ], 170 | dim=-1, 171 | ) 172 | mask[ 173 | patches[..., 0], 174 | patches[..., 1], 175 | patches[..., 2], 176 | ] = True 177 | 178 | # Get the source and target 179 | pose = convert( 180 | [self.rotation, self.translation], 181 | input_parameterization=self.parameterization, 182 | output_parameterization="se3_exp_map", 183 | input_convention=self.convention, 184 | ) 185 | source, target = make_xrays( 186 | pose, 187 | self.drr.detector.source, 188 | self.drr.detector.target, 189 | ) 190 | 191 | # Render the sparse image 192 | target = target[mask.any(dim=0).view(1, -1)] 193 | img = siddon_raycast(source, target, self.drr.density, self.drr.spacing) 194 | if self.n_patches is None: 195 | img = self.drr.reshape_transform(img, batch_size=len(self.rotation)) 196 | return img, mask 197 | 198 | def get_current_pose(self): 199 | return convert( 200 | [self.rotation, self.translation], 201 | input_parameterization=self.parameterization, 202 | output_parameterization="se3_exp_map", 203 | input_convention=self.convention, 204 | ) 205 | 206 | # %% ../notebooks/api/03_registration.ipynb 13 207 | def preprocess(x, eps=1e-4): 208 | x = (x - x.min()) / (x.max() - x.min() + eps) 209 | return (x - 0.3080) / 0.1494 210 | 211 | 212 | def pred_to_patches(pred_img, mask, n_patches, patch_size): 213 | return pred_img.expand(-1, n_patches, -1)[..., mask[..., mask.any(dim=0)]].reshape( 214 | 1, n_patches, -1 215 | ) 216 | 217 | 218 | def img_to_patches(img, mask, n_patches, patch_size): 219 | return img.expand(-1, n_patches, -1, -1)[..., mask].reshape(1, n_patches, -1) 220 | 221 | 222 | def mask_to_img(img, mask): 223 | return img[..., mask.any(dim=0)] 224 | 225 | 226 | def vector_to_img(pred_img, mask): 227 | patches = [pred_img] 228 | filled = torch.zeros(1, 1, *mask[0].shape, device=pred_img.device) 229 | filled[...] 
= torch.nan 230 | for idx in range(len(mask)): 231 | patch = pred_img[:, mask[idx][mask.any(dim=0)]] 232 | filled[..., mask[idx]] = patch 233 | patches.append(patch) 234 | return filled 235 | 236 | # %% ../notebooks/api/03_registration.ipynb 14 237 | class VectorizedNormalizedCrossCorrelation2d(torch.nn.Module): 238 | def __init__(self, eps=1e-4): 239 | super().__init__() 240 | self.eps = eps 241 | 242 | def forward(self, img, pred_img, mask, n_patches, patch_size): 243 | pred_img = preprocess(pred_img).unsqueeze(0) 244 | sub_img = mask_to_img(img, mask) 245 | pred_patches = pred_to_patches(pred_img, mask, n_patches, patch_size) 246 | img_patches = img_to_patches(img, mask, n_patches, patch_size) 247 | 248 | local_ncc = self.forward_compute(pred_patches, img_patches) 249 | global_ncc = self.forward_compute(pred_img, sub_img) 250 | return (local_ncc + global_ncc) / 2 251 | 252 | def forward_compute(self, x1, x2): 253 | assert x1.shape == x2.shape, "Input images must be the same size" 254 | x1, x2 = self.norm(x1), self.norm(x2) 255 | ncc = (x1 * x2).mean(dim=[-1, -2]) 256 | return ncc 257 | 258 | def norm(self, x): 259 | mu = x.mean(dim=-1, keepdim=True) 260 | var = x.var(dim=-1, keepdim=True, correction=0) + self.eps 261 | std = var.sqrt() 262 | return (x - mu) / std 263 | -------------------------------------------------------------------------------- /diffpose/visualization.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/api/05_visualization.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['overlay_edges', 'fiducials_to_mesh', 'lines_to_mesh'] 5 | 6 | # %% ../notebooks/api/05_visualization.ipynb 4 7 | from io import BytesIO 8 | 9 | import matplotlib.pyplot as plt 10 | import numpy as np 11 | import torch 12 | from PIL import Image 13 | from skimage.feature import canny 14 | from torchvision.utils import make_grid 15 | 16 | # %% ../notebooks/api/05_visualization.ipynb 5 17 | def _overlay_edges(target, pred, sigma, eps=1e-5): 18 | pred = (pred - pred.min()) / (pred.max() - pred.min() + eps) 19 | edges = canny(pred, sigma=sigma) 20 | edges = np.ma.masked_where(~edges, edges) 21 | 22 | buffer = BytesIO() 23 | plt.subplot() 24 | plt.imshow(target, cmap="gray") 25 | plt.imshow(edges, cmap="cool_r", interpolation="none", vmin=0.0, vmax=1.0) 26 | plt.axis("off") 27 | plt.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0, dpi=300) 28 | arr = np.array(Image.open(buffer)) 29 | plt.close() 30 | return arr 31 | 32 | # %% ../notebooks/api/05_visualization.ipynb 6 33 | def overlay_edges(target, pred, sigma=1.5): 34 | """Generate edge overlays for a batch of targets and predictions.""" 35 | edges = [] 36 | for i, p in zip(target, pred): 37 | edge = _overlay_edges(i[0].cpu().numpy(), p[0].cpu().numpy(), sigma) 38 | edges.append(edge) 39 | edges = torch.from_numpy(np.stack(edges)).permute(0, -1, 1, 2) 40 | edges = make_grid(edges).permute(1, 2, 0) 41 | return edges 42 | 43 | # %% ../notebooks/api/05_visualization.ipynb 8 44 | import pyvista 45 | from torch.nn.functional import pad 46 | 47 | from .calibration import RigidTransform, perspective_projection 48 | 49 | # %% ../notebooks/api/05_visualization.ipynb 9 50 | def fiducials_3d_to_projected_fiducials_3d(specimen, pose): 51 | # Extrinsic camera matrix 52 | extrinsic = ( 53 | specimen.lps2volume.inverse() 54 | .compose(pose.inverse()) 55 | .compose(specimen.translate) 56 | .compose(specimen.flip_xz) 57 | ) 58 | 59 | # Intrinsic projection -> in 3D 60 | 
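    # Note: perspective_projection() returns 2D pixel coordinates; padding to
    # homogeneous coordinates and applying the inverse intrinsic matrix lifts
    # those pixels back into 3D camera space, and scaling by -focal_len places
    # them on the detector plane so they can be rendered with the 3D scene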
x = perspective_projection(extrinsic, specimen.intrinsic, specimen.fiducials) 61 | x = -specimen.focal_len * torch.einsum( 62 | "ij, bnj -> bni", 63 | specimen.intrinsic.inverse(), 64 | pad(x, (0, 1), value=1), # Convert to homogenous coordinates 65 | ) 66 | 67 | # Some command-z 68 | extrinsic = ( 69 | specimen.flip_xz.inverse().compose(specimen.translate.inverse()).compose(pose) 70 | ) 71 | return extrinsic.transform_points(x) 72 | 73 | # %% ../notebooks/api/05_visualization.ipynb 10 74 | def fiducials_to_mesh( 75 | specimen, 76 | rotation=None, 77 | translation=None, 78 | parameterization=None, 79 | convention=None, 80 | detector=None, 81 | ): 82 | """ 83 | Use camera matrices to get 2D projections of 3D fiducials for a given pose. 84 | If the detector is passed, 2D projections will be filtered for those that lie 85 | on the detector plane. 86 | """ 87 | # Location of fiducials in 3D 88 | fiducials_3d = specimen.lps2volume.inverse().transform_points(specimen.fiducials) 89 | fiducials_3d = pyvista.PolyData(fiducials_3d.squeeze().numpy()) 90 | if rotation is None and translation is None and parameterization is None: 91 | return fiducials_3d 92 | 93 | # Embedding of fiducials in 2D 94 | pose = RigidTransform( 95 | rotation, translation, parameterization, convention, device="cpu" 96 | ) 97 | fiducials_2d = fiducials_3d_to_projected_fiducials_3d(specimen, pose) 98 | fiducials_2d = fiducials_2d.squeeze().numpy() 99 | 100 | # Optionally, only render 2D fiducials that lie on the detector plane 101 | if detector is not None: 102 | corners = detector.points.reshape( 103 | detector["height"][0], detector["width"][0], 3 104 | )[ 105 | [0, 0, -1, -1], 106 | [0, -1, 0, -1], 107 | ] 108 | exclude = np.logical_or( 109 | fiducials_2d < corners.min(0), 110 | fiducials_2d > corners.max(0), 111 | ).any(1) 112 | fiducials_2d = fiducials_2d[~exclude] 113 | 114 | fiducials_2d = pyvista.PolyData(fiducials_2d) 115 | return fiducials_3d, fiducials_2d 116 | 117 | # %% ../notebooks/api/05_visualization.ipynb 11 118 | def lines_to_mesh(camera, fiducials_2d): 119 | """Draw lines from the camera to the 2D fiducials.""" 120 | lines = [] 121 | for pt in fiducials_2d.points: 122 | line = pyvista.Line(pt, camera.center) 123 | lines.append(line) 124 | return lines 125 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: diffpose 2 | channels: 3 | - conda-forge 4 | - pytorch 5 | - nvidia 6 | dependencies: 7 | - pip 8 | - pytorch 9 | - torchvision 10 | - pip: 11 | - diffdrr==0.3.9 12 | - h5py 13 | - scikit-image 14 | - seaborn 15 | - pytorch-transformers 16 | - timm 17 | - torchmetrics 18 | - tqdm 19 | - beartype 20 | - jaxtyping 21 | -------------------------------------------------------------------------------- /experiments/deepfluoro/evaluate.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import pandas as pd 4 | import submitit 5 | import torch 6 | from tqdm import tqdm 7 | 8 | from diffpose.deepfluoro import DeepFluoroDataset, Evaluator, Transforms 9 | from diffpose.registration import PoseRegressor 10 | 11 | 12 | def load_specimen(id_number, device): 13 | specimen = DeepFluoroDataset(id_number) 14 | isocenter_pose = specimen.isocenter_pose.to(device) 15 | return specimen, isocenter_pose 16 | 17 | 18 | def load_model(model_name, device): 19 | ckpt = torch.load(model_name) 20 | model = PoseRegressor( 21 | 
ckpt["model_name"], 22 | ckpt["parameterization"], 23 | ckpt["convention"], 24 | norm_layer=ckpt["norm_layer"], 25 | ).to(device) 26 | model.load_state_dict(ckpt["model_state_dict"]) 27 | transforms = Transforms(ckpt["height"]) 28 | return model, transforms 29 | 30 | 31 | def evaluate(specimen, isocenter_pose, model, transforms, device): 32 | error = [] 33 | model.eval() 34 | for idx in tqdm(range(len(specimen)), ncols=100): 35 | target_registration_error = Evaluator(specimen, idx) 36 | img, _ = specimen[idx] 37 | img = img.to(device) 38 | img = transforms(img) 39 | with torch.no_grad(): 40 | offset = model(img) 41 | pred_pose = isocenter_pose.compose(offset) 42 | mtre = target_registration_error(pred_pose.cpu()).item() 43 | error.append(mtre) 44 | return error 45 | 46 | 47 | def main(id_number): 48 | device = torch.device("cuda") 49 | specimen, isocenter_pose = load_specimen(id_number, device) 50 | models = sorted(Path("checkpoints/").glob(f"specimen_{id_number:02d}_epoch*.ckpt")) 51 | 52 | errors = [] 53 | for model_name in models: 54 | model, transforms = load_model(model_name, device) 55 | error = evaluate(specimen, isocenter_pose, model, transforms, device) 56 | errors.append([model_name.stem] + error) 57 | 58 | df = pd.DataFrame(errors) 59 | df.to_csv(f"evaluations/subject{id_number}.csv", index=False) 60 | 61 | 62 | if __name__ == "__main__": 63 | seed = 123 64 | torch.manual_seed(seed) 65 | torch.cuda.manual_seed_all(seed) 66 | torch.backends.cudnn.benchmark = False 67 | torch.backends.cudnn.deterministic = True 68 | 69 | Path("evaluations").mkdir(exist_ok=True) 70 | id_numbers = [1, 2, 3, 4, 5, 6] 71 | 72 | executor = submitit.AutoExecutor(folder="logs") 73 | executor.update_parameters( 74 | name="eval", 75 | gpus_per_node=1, 76 | mem_gb=10.0, 77 | slurm_array_parallelism=len(id_numbers), 78 | slurm_exclude="curcum", 79 | slurm_partition="2080ti", 80 | timeout_min=10_000, 81 | ) 82 | jobs = executor.map_array(main, id_numbers) 83 | -------------------------------------------------------------------------------- /experiments/deepfluoro/register.py: -------------------------------------------------------------------------------- 1 | import time 2 | from itertools import product 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | import submitit 7 | import torch 8 | from diffdrr.drr import DRR 9 | from diffdrr.metrics import MultiscaleNormalizedCrossCorrelation2d 10 | from torchvision.transforms.functional import resize 11 | from tqdm import tqdm 12 | 13 | from diffpose.calibration import RigidTransform, convert 14 | from diffpose.deepfluoro import DeepFluoroDataset, Evaluator, Transforms 15 | from diffpose.metrics import DoubleGeodesic, GeodesicSE3 16 | from diffpose.registration import PoseRegressor, SparseRegistration 17 | 18 | 19 | class Registration: 20 | def __init__( 21 | self, 22 | drr, 23 | specimen, 24 | model, 25 | parameterization, 26 | convention=None, 27 | n_iters=500, 28 | verbose=False, 29 | device="cuda", 30 | ): 31 | self.device = torch.device(device) 32 | self.drr = drr.to(self.device) 33 | self.model = model.to(self.device) 34 | model.eval() 35 | 36 | self.specimen = specimen 37 | self.isocenter_pose = specimen.isocenter_pose.to(self.device) 38 | 39 | self.geodesics = GeodesicSE3() 40 | self.doublegeo = DoubleGeodesic(sdr=self.specimen.focal_len / 2) 41 | self.criterion = MultiscaleNormalizedCrossCorrelation2d([None, 13], [0.5, 0.5]) 42 | self.transforms = Transforms(self.drr.detector.height) 43 | self.parameterization = parameterization 44 | 
self.convention = convention 45 | 46 | self.n_iters = n_iters 47 | self.verbose = verbose 48 | 49 | def initialize_registration(self, img): 50 | with torch.no_grad(): 51 | offset = self.model(img) 52 | features = self.model.backbone.forward_features(img) 53 | features = resize( 54 | features, 55 | (self.drr.detector.height, self.drr.detector.width), 56 | interpolation=3, 57 | antialias=True, 58 | ) 59 | features = features.sum(dim=[0, 1], keepdim=True) 60 | features -= features.min() 61 | features /= features.max() - features.min() 62 | features /= features.sum() 63 | pred_pose = self.isocenter_pose.compose(offset) 64 | 65 | return SparseRegistration( 66 | self.drr, 67 | pose=pred_pose, 68 | parameterization=self.parameterization, 69 | convention=self.convention, 70 | features=features, 71 | ) 72 | 73 | def initialize_optimizer(self, registration): 74 | optimizer = torch.optim.Adam( 75 | [ 76 | {"params": [registration.rotation], "lr": 7.5e-3}, 77 | {"params": [registration.translation], "lr": 7.5e0}, 78 | ], 79 | maximize=True, 80 | ) 81 | scheduler = torch.optim.lr_scheduler.StepLR( 82 | optimizer, 83 | step_size=25, 84 | gamma=0.9, 85 | ) 86 | return optimizer, scheduler 87 | 88 | def evaluate(self, registration): 89 | est_pose = registration.get_current_pose() 90 | rot = est_pose.get_rotation("euler_angles", "ZYX") 91 | xyz = est_pose.get_translation() 92 | alpha, beta, gamma = rot.squeeze().tolist() 93 | bx, by, bz = xyz.squeeze().tolist() 94 | param = [alpha, beta, gamma, bx, by, bz] 95 | geo = ( 96 | torch.concat( 97 | [ 98 | *self.doublegeo(est_pose, self.pose), 99 | self.geodesics(est_pose, self.pose), 100 | ] 101 | ) 102 | .squeeze() 103 | .tolist() 104 | ) 105 | tre = self.target_registration_error(est_pose.cpu()).item() 106 | return param, geo, tre 107 | 108 | def run(self, idx): 109 | img, pose = self.specimen[idx] 110 | img = self.transforms(img).to(self.device) 111 | self.pose = pose.to(self.device) 112 | 113 | registration = self.initialize_registration(img) 114 | optimizer, scheduler = self.initialize_optimizer(registration) 115 | self.target_registration_error = Evaluator(self.specimen, idx) 116 | 117 | # Initial loss 118 | param, geo, tre = self.evaluate(registration) 119 | params = [param] 120 | losses = [] 121 | geodesic = [geo] 122 | fiducial = [tre] 123 | times = [] 124 | 125 | itr = ( 126 | tqdm(range(self.n_iters), ncols=75) if self.verbose else range(self.n_iters) 127 | ) 128 | for _ in itr: 129 | t0 = time.perf_counter() 130 | optimizer.zero_grad() 131 | pred_img, mask = registration() 132 | loss = self.criterion(pred_img, img) 133 | loss.backward() 134 | optimizer.step() 135 | scheduler.step() 136 | t1 = time.perf_counter() 137 | 138 | param, geo, tre = self.evaluate(registration) 139 | params.append(param) 140 | losses.append(loss.item()) 141 | geodesic.append(geo) 142 | fiducial.append(tre) 143 | times.append(t1 - t0) 144 | 145 | # Loss at final iteration 146 | pred_img, mask = registration() 147 | loss = self.criterion(pred_img, img) 148 | losses.append(loss.item()) 149 | times.append(0) 150 | 151 | # Write results to dataframe 152 | df = pd.DataFrame(params, columns=["alpha", "beta", "gamma", "bx", "by", "bz"]) 153 | df["ncc"] = losses 154 | df[["geo_r", "geo_t", "geo_d", "geo_se3"]] = geodesic 155 | df["fiducial"] = fiducial 156 | df["time"] = times 157 | df["idx"] = idx 158 | df["parameterization"] = self.parameterization 159 | return df 160 | 161 | 162 | def main(id_number, parameterization): 163 | ckpt = 
torch.load(f"checkpoints/specimen_{id_number:02d}_best.ckpt") 164 | model = PoseRegressor( 165 | ckpt["model_name"], 166 | ckpt["parameterization"], 167 | ckpt["convention"], 168 | norm_layer=ckpt["norm_layer"], 169 | ) 170 | model.load_state_dict(ckpt["model_state_dict"]) 171 | 172 | specimen = DeepFluoroDataset(id_number) 173 | height = ckpt["height"] 174 | subsample = (1536 - 100) / height 175 | delx = 0.194 * subsample 176 | 177 | drr = DRR( 178 | specimen.volume, 179 | specimen.spacing, 180 | sdr=specimen.focal_len / 2, 181 | height=height, 182 | delx=delx, 183 | x0=specimen.x0, 184 | y0=specimen.y0, 185 | reverse_x_axis=True, 186 | bone_attenuation_multiplier=2.5, 187 | ) 188 | 189 | registration = Registration( 190 | drr, 191 | specimen, 192 | model, 193 | parameterization, 194 | ) 195 | for idx in tqdm(range(len(specimen)), ncols=100): 196 | df = registration.run(idx) 197 | df.to_csv( 198 | f"runs/specimen{id_number:02d}_xray{idx:03d}_{parameterization}.csv", 199 | index=False, 200 | ) 201 | 202 | 203 | if __name__ == "__main__": 204 | seed = 123 205 | torch.manual_seed(seed) 206 | torch.cuda.manual_seed_all(seed) 207 | torch.backends.cudnn.benchmark = False 208 | torch.backends.cudnn.deterministic = True 209 | 210 | id_numbers = [1, 2, 3, 4, 5, 6] 211 | parameterizations = [ 212 | "se3_log_map", 213 | "so3_log_map", 214 | "axis_angle", 215 | "euler_angles", 216 | "quaternion", 217 | "rotation_6d", 218 | "rotation_10d", 219 | "quaternion_adjugate", 220 | ] 221 | id_numbers = [i for i, _ in product(id_numbers, parameterizations)] 222 | parameterizations = [p for _, p in product(id_numbers, parameterizations)] 223 | Path("runs").mkdir(exist_ok=True) 224 | 225 | executor = submitit.AutoExecutor(folder="logs") 226 | executor.update_parameters( 227 | name="registration", 228 | gpus_per_node=1, 229 | mem_gb=10.0, 230 | slurm_array_parallelism=12, 231 | slurm_partition="2080ti", 232 | timeout_min=10_000, 233 | ) 234 | jobs = executor.map_array(main, id_numbers, parameterizations) 235 | -------------------------------------------------------------------------------- /experiments/deepfluoro/train.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import submitit 4 | import torch 5 | from diffdrr.drr import DRR 6 | from diffdrr.metrics import MultiscaleNormalizedCrossCorrelation2d 7 | from pytorch_transformers.optimization import WarmupCosineSchedule 8 | from timm.utils.agc import adaptive_clip_grad as adaptive_clip_grad_ 9 | from tqdm import tqdm 10 | 11 | from diffpose.deepfluoro import DeepFluoroDataset, Transforms, get_random_offset 12 | from diffpose.metrics import DoubleGeodesic, GeodesicSE3 13 | from diffpose.registration import PoseRegressor 14 | 15 | 16 | def load(id_number, height, device): 17 | specimen = DeepFluoroDataset(id_number) 18 | isocenter_pose = specimen.isocenter_pose.to(device) 19 | 20 | subsample = (1536 - 100) / height 21 | delx = 0.194 * subsample 22 | drr = DRR( 23 | specimen.volume, 24 | specimen.spacing, 25 | specimen.focal_len / 2, 26 | height, 27 | delx, 28 | x0=specimen.x0, 29 | y0=specimen.y0, 30 | reverse_x_axis=True, 31 | ).to(device) 32 | transforms = Transforms(height) 33 | 34 | return specimen, isocenter_pose, transforms, drr 35 | 36 | 37 | def train( 38 | id_number, 39 | model, 40 | optimizer, 41 | scheduler, 42 | drr, 43 | transforms, 44 | specimen, 45 | isocenter_pose, 46 | device, 47 | batch_size, 48 | n_epochs, 49 | n_batches_per_epoch, 50 | model_params, 51 | ): 52 | metric = 
MultiscaleNormalizedCrossCorrelation2d(eps=1e-4) 53 | geodesic = GeodesicSE3() 54 | double = DoubleGeodesic(drr.detector.sdr) 55 | contrast_distribution = torch.distributions.Uniform(1.0, 10.0) 56 | 57 | best_loss = torch.inf 58 | 59 | model.train() 60 | for epoch in range(n_epochs + 1): 61 | losses = [] 62 | for _ in (itr := tqdm(range(n_batches_per_epoch), leave=False)): 63 | contrast = contrast_distribution.sample().item() 64 | offset = get_random_offset(batch_size, device) 65 | pose = isocenter_pose.compose(offset) 66 | img = drr(None, None, None, pose=pose, bone_attenuation_multiplier=contrast) 67 | img = transforms(img) 68 | 69 | pred_offset = model(img) 70 | pred_pose = isocenter_pose.compose(pred_offset) 71 | pred_img = drr(None, None, None, pose=pred_pose) 72 | pred_img = transforms(pred_img) 73 | 74 | ncc = metric(pred_img, img) 75 | log_geodesic = geodesic(pred_pose, pose) 76 | geodesic_rot, geodesic_xyz, double_geodesic = double(pred_pose, pose) 77 | loss = 1 - ncc + 1e-2 * (log_geodesic + double_geodesic) 78 | if loss.isnan().any(): 79 | print("Aaaaaaand we've crashed...") 80 | print(ncc) 81 | print(log_geodesic) 82 | print(geodesic_rot) 83 | print(geodesic_xyz) 84 | print(double_geodesic) 85 | print(pose.get_matrix()) 86 | print(pred_pose.get_matrix()) 87 | torch.save( 88 | { 89 | "model_state_dict": model.state_dict(), 90 | "optimizer_state_dict": optimizer.state_dict(), 91 | "height": drr.detector.height, 92 | "epoch": epoch, 93 | "batch_size": batch_size, 94 | "n_epochs": n_epochs, 95 | "n_batches_per_epoch": n_batches_per_epoch, 96 | "pose": pose.get_matrix().cpu(), 97 | "pred_pose": pred_pose.get_matrix().cpu(), 98 | "img": img.cpu(), 99 | "pred_img": pred_img.cpu(), 100 | **model_params, 101 | }, 102 | f"checkpoints/specimen_{id_number:02d}_crashed.ckpt", 103 | ) 104 | raise RuntimeError("NaN loss") 105 | 106 | optimizer.zero_grad() 107 | loss.mean().backward() 108 | adaptive_clip_grad_(model.parameters()) 109 | optimizer.step() 110 | scheduler.step() 111 | 112 | losses.append(loss.mean().item()) 113 | 114 | # Update progress bar 115 | itr.set_description(f"Epoch [{epoch}/{n_epochs}]") 116 | itr.set_postfix( 117 | geodesic_rot=geodesic_rot.mean().item(), 118 | geodesic_xyz=geodesic_xyz.mean().item(), 119 | geodesic_dou=double_geodesic.mean().item(), 120 | geodesic_se3=log_geodesic.mean().item(), 121 | loss=loss.mean().item(), 122 | ncc=ncc.mean().item(), 123 | ) 124 | 125 | prev_pose = pose 126 | prev_pred_pose = pred_pose 127 | 128 | losses = torch.tensor(losses) 129 | tqdm.write(f"Epoch {epoch + 1:04d} | Loss {losses.mean().item():.4f}") 130 | if losses.mean() < best_loss and not losses.isnan().any(): 131 | best_loss = losses.mean().item() 132 | torch.save( 133 | { 134 | "model_state_dict": model.state_dict(), 135 | "optimizer_state_dict": optimizer.state_dict(), 136 | "height": drr.detector.height, 137 | "epoch": epoch, 138 | "loss": losses.mean().item(), 139 | "batch_size": batch_size, 140 | "n_epochs": n_epochs, 141 | "n_batches_per_epoch": n_batches_per_epoch, 142 | **model_params, 143 | }, 144 | f"checkpoints/specimen_{id_number:02d}_best.ckpt", 145 | ) 146 | 147 | if epoch % 50 == 0: 148 | torch.save( 149 | { 150 | "model_state_dict": model.state_dict(), 151 | "optimizer_state_dict": optimizer.state_dict(), 152 | "height": drr.detector.height, 153 | "epoch": epoch, 154 | "loss": losses.mean().item(), 155 | "batch_size": batch_size, 156 | "n_epochs": n_epochs, 157 | "n_batches_per_epoch": n_batches_per_epoch, 158 | **model_params, 159 | }, 160 | 
f"checkpoints/specimen_{id_number:02d}_epoch{epoch:03d}.ckpt", 161 | ) 162 | 163 | 164 | def main( 165 | id_number, 166 | height=256, 167 | restart=None, 168 | model_name="resnet18", 169 | parameterization="se3_log_map", 170 | convention=None, 171 | lr=1e-3, 172 | batch_size=8, 173 | n_epochs=1000, 174 | n_batches_per_epoch=100, 175 | ): 176 | id_number = int(id_number) 177 | 178 | device = torch.device("cuda") 179 | specimen, isocenter_pose, transforms, drr = load(id_number, height, device) 180 | 181 | model_params = { 182 | "model_name": model_name, 183 | "parameterization": parameterization, 184 | "convention": convention, 185 | "norm_layer": "groupnorm", 186 | } 187 | model = PoseRegressor(**model_params) 188 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 189 | if restart is not None: 190 | ckpt = torch.load(restart) 191 | model.load_state_dict(ckpt["model_state_dict"]) 192 | optimizer.load_state_dict(ckpt["optimizer_state_dict"]) 193 | model = model.to(device) 194 | 195 | scheduler = WarmupCosineSchedule( 196 | optimizer, 197 | 5 * n_batches_per_epoch, 198 | n_epochs * n_batches_per_epoch - 5 * n_batches_per_epoch, 199 | ) 200 | 201 | train( 202 | id_number, 203 | model, 204 | optimizer, 205 | scheduler, 206 | drr, 207 | transforms, 208 | specimen, 209 | isocenter_pose, 210 | device, 211 | batch_size, 212 | n_epochs, 213 | n_batches_per_epoch, 214 | model_params, 215 | ) 216 | 217 | 218 | if __name__ == "__main__": 219 | id_numbers = [1, 2, 3, 4, 5, 6] 220 | Path("checkpoints").mkdir(exist_ok=True) 221 | 222 | executor = submitit.AutoExecutor(folder="logs") 223 | executor.update_parameters( 224 | name="deepfluoro", 225 | gpus_per_node=1, 226 | mem_gb=43.5, 227 | slurm_array_parallelism=len(id_numbers), 228 | slurm_partition="A6000", 229 | slurm_exclude="sumac,fennel", 230 | timeout_min=10_000, 231 | ) 232 | jobs = executor.map_array(main, id_numbers) 233 | -------------------------------------------------------------------------------- /experiments/ljubljana/register.py: -------------------------------------------------------------------------------- 1 | import time 2 | from itertools import product 3 | from pathlib import Path 4 | 5 | import pandas as pd 6 | import submitit 7 | import torch 8 | from diffdrr.drr import DRR 9 | from diffdrr.metrics import MultiscaleNormalizedCrossCorrelation2d 10 | from torchvision.transforms.functional import resize 11 | from tqdm import tqdm 12 | 13 | from diffpose.calibration import RigidTransform, convert 14 | from diffpose.ljubljana import Evaluator, LjubljanaDataset, Transforms 15 | from diffpose.metrics import DoubleGeodesic, GeodesicSE3 16 | from diffpose.registration import PoseRegressor, SparseRegistration 17 | 18 | 19 | def initialize(id_number, view, subsample=8): 20 | # Load the model 21 | ckpt = torch.load(f"checkpoints/specimen_{id_number:02d}_{view}_best.ckpt") 22 | model = PoseRegressor( 23 | ckpt["model_name"], 24 | ckpt["parameterization"], 25 | ckpt["convention"], 26 | norm_layer=ckpt["norm_layer"], 27 | ) 28 | model.load_state_dict(ckpt["model_state_dict"]) 29 | model = model.cuda() 30 | model.eval() 31 | 32 | # Load the subject 33 | subject = LjubljanaDataset(view) 34 | ( 35 | volume, 36 | spacing, 37 | focal_len, 38 | height, 39 | width, 40 | delx, 41 | dely, 42 | x0, 43 | y0, 44 | img, 45 | pose, 46 | isocenter_pose, 47 | ) = subject[id_number] 48 | volume[volume < 1000] = 0.0 49 | isocenter_pose = isocenter_pose.cuda() 50 | evaluator = Evaluator(subject, id_number) 51 | 52 | # Initialize the DRR 53 | height //= 
subsample 54 | width //= subsample 55 | delx *= subsample 56 | dely *= subsample 57 | drr = DRR( 58 | volume, 59 | spacing, 60 | focal_len / 2, 61 | height, 62 | delx, 63 | width, 64 | dely, 65 | x0, 66 | y0, 67 | reverse_x_axis=True, 68 | ).to("cuda") 69 | transforms = Transforms(height, width) 70 | 71 | # Get the estimated pose and features 72 | pose = pose.to("cuda") 73 | img = transforms(img).to("cuda") 74 | with torch.no_grad(): 75 | offset = model(img) 76 | features = model.backbone.forward_features(img) 77 | features = resize( 78 | features, 79 | (height, width), 80 | interpolation=3, 81 | antialias=True, 82 | ) 83 | features = features.sum(dim=[0, 1], keepdim=True) 84 | features -= features.min() 85 | features /= features.max() - features.min() 86 | features /= features.sum() 87 | pred_pose = isocenter_pose.compose(offset) 88 | 89 | return drr, img, pose, pred_pose, features, evaluator 90 | 91 | 92 | class Registration: 93 | def __init__( 94 | self, 95 | drr, 96 | img, 97 | pose, 98 | pred_pose, 99 | features, 100 | evaluator, 101 | parameterization, 102 | convention="ZYX", 103 | lr_rot=1e-3, 104 | lr_xyz=1e0, 105 | n_iters=5000, 106 | verbose=True, 107 | ): 108 | self.parameterization = parameterization 109 | self.convention = convention 110 | self.registration, self.optimizer, self.scheduler = self.initialize( 111 | drr, pred_pose, features, lr_rot, lr_xyz 112 | ) 113 | 114 | self.img = img 115 | self.pose = pose 116 | 117 | self.geodesics = GeodesicSE3() 118 | self.doublegeo = DoubleGeodesic(drr.detector.sdr) 119 | self.criterion = MultiscaleNormalizedCrossCorrelation2d( 120 | [None, 13], # None corresponds to global 121 | [0.5, 0.5], 122 | ) 123 | self.target_registration_error = evaluator 124 | 125 | self.n_iters = n_iters 126 | self.verbose = verbose 127 | 128 | def initialize(self, drr, pred_pose, features, lr_rot, lr_xyz): 129 | registration = SparseRegistration( 130 | drr, 131 | pose=pred_pose, 132 | parameterization=self.parameterization, 133 | convention=self.convention, 134 | features=features, 135 | ) 136 | optimizer = torch.optim.Adam( 137 | [ 138 | {"params": [registration.rotation], "lr": lr_rot}, 139 | {"params": [registration.translation], "lr": lr_xyz}, 140 | ], 141 | maximize=True, 142 | ) 143 | scheduler = torch.optim.lr_scheduler.StepLR( 144 | optimizer, 145 | step_size=500, 146 | gamma=0.9, 147 | ) 148 | return registration, optimizer, scheduler 149 | 150 | def evaluate(self): 151 | est_pose = self.registration.get_current_pose() 152 | rot = est_pose.get_rotation("euler_angles", "ZYX") 153 | xyz = est_pose.get_translation() 154 | alpha, beta, gamma = rot.squeeze().tolist() 155 | bx, by, bz = xyz.squeeze().tolist() 156 | param = [alpha, beta, gamma, bx, by, bz] 157 | geo = ( 158 | torch.concat( 159 | [ 160 | *self.doublegeo(est_pose, self.pose), 161 | self.geodesics(est_pose, self.pose), 162 | ] 163 | ) 164 | .squeeze() 165 | .tolist() 166 | ) 167 | tre = self.target_registration_error(est_pose.cpu()).item() 168 | return param, geo, tre 169 | 170 | def run(self): 171 | # Initial loss 172 | param, geo, tre = self.evaluate() 173 | params = [param] 174 | losses = [] 175 | geodesic = [geo] 176 | fiducial = [tre] 177 | times = [] 178 | 179 | itr = ( 180 | tqdm(range(self.n_iters), ncols=75) if self.verbose else range(self.n_iters) 181 | ) 182 | for _ in itr: 183 | t0 = time.perf_counter() 184 | self.optimizer.zero_grad() 185 | pred_img, mask = self.registration(n_patches=None) 186 | loss = self.criterion(pred_img, self.img) 187 | loss.backward() 188 | 
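            # The optimizer was constructed with maximize=True, so this step
            # performs gradient ascent on the multiscale NCC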
self.optimizer.step() 189 | self.scheduler.step() 190 | t1 = time.perf_counter() 191 | 192 | param, geo, tre = self.evaluate() 193 | params.append(param) 194 | losses.append(loss.item()) 195 | geodesic.append(geo) 196 | fiducial.append(tre) 197 | times.append(t1 - t0) 198 | 199 | # Loss at final iteration 200 | pred_img, mask = self.registration() 201 | loss = self.criterion(pred_img, self.img) 202 | losses.append(loss.item()) 203 | times.append(0) 204 | 205 | # Write results to dataframe 206 | df = pd.DataFrame(params, columns=["alpha", "beta", "gamma", "bx", "by", "bz"]) 207 | df["ncc"] = losses 208 | df[["geo_r", "geo_t", "geo_d", "geo_se3"]] = geodesic 209 | df["fiducial"] = fiducial 210 | df["time"] = times 211 | df["parameterization"] = self.parameterization 212 | return df 213 | 214 | 215 | def main(id_number, view, parameterization="se3_log_map"): 216 | drr, img, pose, pred_pose, features, evaluator = initialize(id_number, view) 217 | registration = Registration(drr, img, pose, pred_pose, features, evaluator, parameterization) 218 | df = registration.run() 219 | df.to_csv( 220 | f"runs/specimen{id_number:02d}_{view}_{parameterization}.csv", 221 | index=False, 222 | ) 223 | 224 | 225 | if __name__ == "__main__": 226 | seed = 123 227 | torch.manual_seed(seed) 228 | torch.cuda.manual_seed_all(seed) 229 | torch.backends.cudnn.benchmark = False 230 | torch.backends.cudnn.deterministic = True 231 | 232 | id_numbers = list(range(10)) 233 | views = ["ap", "lat"] 234 | pairs = list(product(id_numbers, views)) 235 | id_numbers, views = zip(*pairs) 236 | 237 | Path("runs").mkdir(exist_ok=True) 238 | 239 | executor = submitit.AutoExecutor(folder="logs") 240 | executor.update_parameters( 241 | name="registration", 242 | gpus_per_node=1, 243 | mem_gb=10.0, 244 | slurm_array_parallelism=len(id_numbers), 245 | slurm_exclude="sassafras", 246 | slurm_partition="2080ti", 247 | timeout_min=10_000, 248 | ) 249 | jobs = executor.map_array(main, id_numbers, views) 250 | -------------------------------------------------------------------------------- /experiments/ljubljana/train.py: -------------------------------------------------------------------------------- 1 | from itertools import product 2 | from pathlib import Path 3 | 4 | import submitit 5 | import torch 6 | from diffdrr.drr import DRR 7 | from diffdrr.metrics import MultiscaleNormalizedCrossCorrelation2d 8 | from pytorch_transformers.optimization import WarmupCosineSchedule 9 | from timm.utils.agc import adaptive_clip_grad as adaptive_clip_grad_ 10 | from tqdm import tqdm 11 | 12 | from diffpose.ljubljana import LjubljanaDataset, Transforms, get_random_offset 13 | from diffpose.metrics import DoubleGeodesic, GeodesicSE3 14 | from diffpose.registration import PoseRegressor 15 | 16 | 17 | def load(id_number, view, subsample, device): 18 | # Load the subject 19 | subject = LjubljanaDataset(view) 20 | ( 21 | volume, 22 | spacing, 23 | focal_len, 24 | height, 25 | width, 26 | delx, 27 | dely, 28 | x0, 29 | y0, 30 | _, 31 | _, 32 | isocenter_pose, 33 | ) = subject[id_number] 34 | volume[volume < 500] = 0.0 35 | isocenter_pose = isocenter_pose.to(device) 36 | 37 | # Make the DRR 38 | height //= subsample 39 | width //= subsample 40 | delx *= subsample 41 | dely *= subsample 42 | 43 | drr = DRR( 44 | volume, 45 | spacing, 46 | focal_len / 2, 47 | height, 48 | delx, 49 | width, 50 | dely, 51 | x0, 52 | y0, 53 | reverse_x_axis=True, 54 | ).to(device) 55 | transforms = Transforms(height, width) 56 | 57 | return drr, 
isocenter_pose, transforms 58 | 59 | 60 | def train( 61 | model, 62 | optimizer, 63 | scheduler, 64 | drr, 65 | transforms, 66 | isocenter_pose, 67 | batch_size, 68 | n_epochs, 69 | n_batches_per_epoch, 70 | model_params, 71 | id_number, 72 | view, 73 | device, 74 | ): 75 | metric = MultiscaleNormalizedCrossCorrelation2d(eps=1e-4) 76 | geodesic = GeodesicSE3() 77 | double = DoubleGeodesic(drr.detector.sdr) 78 | 79 | best_loss = torch.inf 80 | 81 | model.train() 82 | for epoch in range(n_epochs): 83 | losses = [] 84 | for _ in (itr := tqdm(range(n_batches_per_epoch), leave=False)): 85 | try: 86 | offset = get_random_offset(view, batch_size, device) 87 | pose = isocenter_pose.compose(offset) 88 | img = drr(None, None, None, pose=pose) 89 | img = transforms(img) 90 | 91 | pred_offset = model(img) 92 | pred_pose = isocenter_pose.compose(pred_offset) 93 | pred_img = drr(None, None, None, pose=pred_pose) 94 | pred_img = transforms(pred_img) 95 | 96 | ncc = metric(pred_img, img) 97 | log_geodesic = geodesic(pred_pose, pose) 98 | geodesic_rot, geodesic_xyz, double_geodesic = double(pred_pose, pose) 99 | loss = 1 - ncc + 1e-2 * (log_geodesic + double_geodesic) 100 | 101 | optimizer.zero_grad() 102 | loss.mean().backward() 103 | adaptive_clip_grad_(model.parameters()) 104 | optimizer.step() 105 | scheduler.step() 106 | 107 | losses.append(loss.mean().item()) 108 | 109 | # Update progress bar 110 | itr.set_description(f"Epoch [{epoch}/{n_epochs}]") 111 | itr.set_postfix( 112 | geodesic_rot=geodesic_rot.mean().item(), 113 | geodesic_xyz=geodesic_xyz.mean().item(), 114 | geodesic_dou=double_geodesic.mean().item(), 115 | geodesic_se3=log_geodesic.mean().item(), 116 | loss=loss.mean().item(), 117 | ncc=ncc.mean().item(), 118 | ) 119 | 120 | prev_pose = pose 121 | prev_pred_pose = pred_pose 122 | except: 123 | print("Aaaaaaand we've crashed...") 124 | print(ncc) 125 | print(log_geodesic) 126 | print(geodesic_rot) 127 | print(geodesic_xyz) 128 | print(double_geodesic) 129 | print(pose.get_matrix()) 130 | print(pred_pose.get_matrix()) 131 | torch.save( 132 | { 133 | "model_state_dict": model.state_dict(), 134 | "optimizer_state_dict": optimizer.state_dict(), 135 | "height": drr.detector.height, 136 | "width": drr.detector.width, 137 | "epoch": epoch, 138 | "batch_size": batch_size, 139 | "n_epochs": n_epochs, 140 | "n_batches_per_epoch": n_batches_per_epoch, 141 | "pose": pose.get_matrix().cpu(), 142 | "pred_pose": pred_pose.get_matrix().cpu(), 143 | **model_params, 144 | }, 145 | f"checkpoints/specimen_{id_number:02d}_{view}_crashed.ckpt", 146 | ) 147 | raise RuntimeError("NaN loss") 148 | 149 | losses = torch.tensor(losses) 150 | tqdm.write(f"Epoch {epoch + 1:04d} | Loss {losses.mean().item():.4f}") 151 | if losses.mean() < best_loss and not losses.isnan().any(): 152 | best_loss = losses.mean().item() 153 | torch.save( 154 | { 155 | "model_state_dict": model.state_dict(), 156 | "optimizer_state_dict": optimizer.state_dict(), 157 | "height": drr.detector.height, 158 | "width": drr.detector.width, 159 | "epoch": epoch, 160 | "loss": losses.mean().item(), 161 | "batch_size": batch_size, 162 | "n_epochs": n_epochs, 163 | "n_batches_per_epoch": n_batches_per_epoch, 164 | **model_params, 165 | }, 166 | f"checkpoints/specimen_{id_number:02d}_{view}_best.ckpt", 167 | ) 168 | 169 | if epoch % 25 == 0 and epoch != 0: 170 | torch.save( 171 | { 172 | "model_state_dict": model.state_dict(), 173 | "optimizer_state_dict": optimizer.state_dict(), 174 | "height": drr.detector.height, 175 | "width": drr.detector.width, 176 
| "epoch": epoch, 177 | "loss": losses.mean().item(), 178 | "batch_size": batch_size, 179 | "n_epochs": n_epochs, 180 | "n_batches_per_epoch": n_batches_per_epoch, 181 | **model_params, 182 | }, 183 | f"checkpoints/specimen_{id_number:02d}_{view}_epoch{epoch:03d}.ckpt", 184 | ) 185 | 186 | def main( 187 | id_number, 188 | view, 189 | subsample=8, 190 | restart=None, 191 | model_name="resnet18", 192 | parameterization="se3_log_map", 193 | convention=None, 194 | lr=1e-3, 195 | batch_size=1, 196 | n_epochs=1000, 197 | n_batches_per_epoch=100, 198 | ): 199 | device = torch.device("cuda") 200 | drr, isocenter_pose, transforms = load(id_number, view, subsample, device) 201 | 202 | model_params = { 203 | "model_name": model_name, 204 | "parameterization": parameterization, 205 | "convention": convention, 206 | "norm_layer": "groupnorm", 207 | } 208 | model = PoseRegressor(**model_params) 209 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 210 | if restart is not None: 211 | ckpt = torch.load(restart) 212 | model.load_state_dict(ckpt["model_state_dict"]) 213 | optimizer.load_state_dict(ckpt["optimizer_state_dict"]) 214 | model = model.to(device) 215 | 216 | scheduler = WarmupCosineSchedule( 217 | optimizer, 218 | 5 * n_batches_per_epoch, 219 | n_epochs * n_batches_per_epoch - 5 * n_batches_per_epoch, 220 | ) 221 | 222 | train( 223 | model, 224 | optimizer, 225 | scheduler, 226 | drr, 227 | transforms, 228 | isocenter_pose, 229 | batch_size, 230 | n_epochs, 231 | n_batches_per_epoch, 232 | model_params, 233 | id_number, 234 | view, 235 | device, 236 | ) 237 | 238 | if __name__ == "__main__": 239 | id_numbers = list(range(10)) 240 | views = ["ap", "lat"] 241 | id_numbers = [i for i, _ in product(id_numbers, views)] 242 | views = [v for _, v in product(id_numbers, views)] 243 | 244 | Path("checkpoints").mkdir(exist_ok=True) 245 | 246 | executor = submitit.AutoExecutor(folder="logs") 247 | executor.update_parameters( 248 | name="ljubljana", 249 | gpus_per_node=1, 250 | mem_gb=10.0, 251 | slurm_array_parallelism=len(id_numbers), 252 | slurm_partition="2080ti", 253 | timeout_min=10_000, 254 | ) 255 | jobs = executor.map_array(main, id_numbers, views) 256 | -------------------------------------------------------------------------------- /experiments/test_time_optimization.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eigenvivek/DiffPose/afc0159b3c1db4ea1e1bd55cd4fc85b0408c3974/experiments/test_time_optimization.gif -------------------------------------------------------------------------------- /notebooks/_quarto.yml: -------------------------------------------------------------------------------- 1 | project: 2 | type: website 3 | 4 | format: 5 | html: 6 | theme: cosmo 7 | css: styles.css 8 | toc: true 9 | 10 | website: 11 | twitter-card: true 12 | open-graph: true 13 | repo-actions: [issue] 14 | favicon: favicon.png 15 | navbar: 16 | background: primary 17 | search: true 18 | right: 19 | - icon: github 20 | href: "https://github.com/eigenvivek/DiffPose" 21 | sidebar: 22 | style: floating 23 | 24 | metadata-files: [nbdev.yml, sidebar.yml] -------------------------------------------------------------------------------- /notebooks/api/02_calibration.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "raw", 5 | "id": "a08b0eaa-f104-49a3-bcdf-a1b69ea31bc7", 6 | "metadata": {}, 7 | "source": [ 8 | "---\n", 9 | "title: calibration\n", 10 | 
subtitle: Rigid transforms with camera calibration matrices\n", 11 | "---" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "id": "03a5387b-a349-4b8e-a144-a56d30972ffa", 17 | "metadata": {}, 18 | "source": [ 19 | "An X-ray C-arm can be modeled as a pinhole camera with its own extrinsic and intrinsic matrices. \n", 20 | "This module provides utilities for parsing these matrices and working with rigid transforms." 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "id": "8634b454-9289-4857-80f8-232d913de6a8", 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "#| default_exp calibration" 31 | ] 32 | }, 33 | { 34 | "cell_type": "code", 35 | "execution_count": null, 36 | "id": "359cb689-f4ab-4336-a5a9-b75be9492fb7", 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "#| hide\n", 41 | "from nbdev.showdoc import *" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "9f67e02b-580c-4f37-9a90-fcca5713de08", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "#| export\n", 52 | "import torch" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "id": "594553a4-4ec9-46e5-96b4-4430edc84135", 58 | "metadata": {}, 59 | "source": [ 60 | "## Rigid transformations\n", 61 | "\n", 62 | "We represent rigid transforms as $4 \times 4$ matrices (following the row-vector convention of `PyTorch3D`),\n", 63 | "\n", 64 | "\begin{equation}\n", 65 | "\begin{bmatrix}\n", 66 | " \mathbf R^T & \mathbf 0 \\\n", 67 | " \mathbf t^T & 1\n", 68 | "\end{bmatrix}\n", 69 | "\in \mathbf{SE}(3) \,,\n", 70 | "\end{equation}\n", 71 | "\n", 72 | "where $\mathbf R \in \mathbf{SO}(3)$ is a rotation matrix and $\mathbf t\in \mathbb R^3$ represents a translation.\n", 73 | "\n", 74 | "Note that since rotation matrices are orthogonal, we have a simple closed-form equation for the inverse:\n", 75 | "\begin{equation}\n", 76 | "\begin{bmatrix}\n", 77 | " \mathbf R^T & \mathbf 0 \\\n", 78 | " \mathbf t^T & 1\n", 79 | "\end{bmatrix}^{-1} =\n", 80 | "\begin{bmatrix}\n", 81 | " \mathbf R & \mathbf 0 \\\n", 82 | " -\mathbf t^T \mathbf R & 1\n", 83 | "\end{bmatrix} \,.\n", 84 | "\end{equation}\n", 85 | "\n", 86 | "For convenience, we add a wrapper of `pytorch3d.transforms.Transform3d` that can be constructed from a (batched) rotation matrix and translation vector. This module also includes the closed-form inverse specific to rigid transforms."
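,
"\n",
"\n",
"As a quick sanity check (ours, not part of the original derivation), multiplying the matrix by this inverse in the same row-vector convention recovers the identity:\n",
"\begin{equation}\n",
"\begin{bmatrix}\n",
"  \mathbf R^T & \mathbf 0 \\\n",
"  \mathbf t^T & 1\n",
"\end{bmatrix}\n",
"\begin{bmatrix}\n",
"  \mathbf R & \mathbf 0 \\\n",
"  -\mathbf t^T \mathbf R & 1\n",
"\end{bmatrix} =\n",
"\begin{bmatrix}\n",
"  \mathbf R^T \mathbf R & \mathbf 0 \\\n",
"  \mathbf t^T \mathbf R - \mathbf t^T \mathbf R & 1\n",
"\end{bmatrix} = \mathbf I \,.\n",
"\end{equation}"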
87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": null, 92 | "id": "008781b6-de62-473f-b2d7-c2f7eb051b41", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "#| export\n", 97 | "from typing import Optional\n", 98 | "\n", 99 | "from beartype import beartype\n", 100 | "from diffdrr.utils import Transform3d\n", 101 | "from diffdrr.utils import convert as convert_so3\n", 102 | "from diffdrr.utils import se3_exp_map, se3_log_map\n", 103 | "from jaxtyping import Float, jaxtyped" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "86492beb-3f1f-40ac-aff6-018ba1d04067", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "#| export\n", 114 | "@beartype\n", 115 | "class RigidTransform(Transform3d):\n", 116 | " \"\"\"Wrapper of pytorch3d.transforms.Transform3d with extra functionalities.\"\"\"\n", 117 | "\n", 118 | " @jaxtyped(typechecker=beartype)\n", 119 | " def __init__(\n", 120 | " self,\n", 121 | " R: Float[torch.Tensor, \"...\"],\n", 122 | " t: Float[torch.Tensor, \"... 3\"],\n", 123 | " parameterization: str = \"matrix\",\n", 124 | " convention: Optional[str] = None,\n", 125 | " device=None,\n", 126 | " dtype=torch.float32,\n", 127 | " ):\n", 128 | " if device is None and (R.device == t.device):\n", 129 | " device = R.device\n", 130 | "\n", 131 | " R = convert_so3(R, parameterization, \"matrix\", convention)\n", 132 | " if R.dim() == 2 and t.dim() == 1:\n", 133 | " R = R.unsqueeze(0)\n", 134 | " t = t.unsqueeze(0)\n", 135 | " assert (batch_size := len(R)) == len(t), \"R and t need same batch size\"\n", 136 | "\n", 137 | " matrix = torch.zeros(batch_size, 4, 4, device=device, dtype=dtype)\n", 138 | " matrix[..., :3, :3] = R.transpose(-1, -2)\n", 139 | " matrix[..., 3, :3] = t\n", 140 | " matrix[..., 3, 3] = 1\n", 141 | "\n", 142 | " super().__init__(matrix=matrix, device=device, dtype=dtype)\n", 143 | "\n", 144 | " def get_rotation(self, parameterization=None, convention=None):\n", 145 | " R = self.get_matrix()[..., :3, :3].transpose(-1, -2)\n", 146 | " if parameterization is not None:\n", 147 | " R = convert_so3(R, \"matrix\", parameterization, None, convention)\n", 148 | " return R\n", 149 | "\n", 150 | " def get_translation(self):\n", 151 | " return self.get_matrix()[..., 3, :3]\n", 152 | "\n", 153 | " def inverse(self):\n", 154 | " \"\"\"Closed-form inverse for rigid transforms.\"\"\"\n", 155 | " R = self.get_rotation().transpose(-1, -2)\n", 156 | " t = self.get_translation()\n", 157 | " t = -torch.einsum(\"bij,bj->bi\", R, t)\n", 158 | " return RigidTransform(R, t, device=self.device, dtype=self.dtype)\n", 159 | "\n", 160 | " def compose(self, other):\n", 161 | " T = super().compose(other)\n", 162 | " R = T.get_matrix()[..., :3, :3].transpose(-1, -2)\n", 163 | " t = T.get_matrix()[..., 3, :3]\n", 164 | " return RigidTransform(R, t, device=self.device, dtype=self.dtype)\n", 165 | "\n", 166 | " def clone(self):\n", 167 | " R = self.get_matrix()[..., :3, :3].transpose(-1, -2).clone()\n", 168 | " t = self.get_matrix()[..., 3, :3].clone()\n", 169 | " return RigidTransform(R, t, device=self.device, dtype=self.dtype)\n", 170 | "\n", 171 | " def get_se3_log(self):\n", 172 | " return se3_log_map(self.get_matrix())" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": null, 178 | "id": "ff7984c3-a8f5-435f-b504-dce55787f517", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "#| export\n", 183 | "def convert(\n", 184 | " transform,\n", 185 | " 
input_parameterization,\n", 186 | " output_parameterization,\n", 187 | " input_convention=None,\n", 188 | " output_convention=None,\n", 189 | "):\n", 190 | " \"\"\"Convert between representations of SE(3).\"\"\"\n", 191 | "\n", 192 | " # Convert any input parameterization to a RigidTransform\n", 193 | " if input_parameterization == \"se3_log_map\":\n", 194 | " transform = torch.concat([transform[1], transform[0]], axis=-1)\n", 195 | " matrix = se3_exp_map(transform).transpose(-1, -2)\n", 196 | " transform = RigidTransform(\n", 197 | " R=matrix[..., :3, :3],\n", 198 | " t=matrix[..., :3, 3],\n", 199 | " device=matrix.device,\n", 200 | " dtype=matrix.dtype,\n", 201 | " )\n", 202 | " elif input_parameterization == \"se3_exp_map\":\n", 203 | " pass\n", 204 | " else:\n", 205 | " transform = RigidTransform(\n", 206 | " R=transform[0],\n", 207 | " t=transform[1],\n", 208 | " parameterization=input_parameterization,\n", 209 | " convention=input_convention,\n", 210 | " )\n", 211 | "\n", 212 | " # Convert the RigidTransform to any output\n", 213 | " if output_parameterization == \"se3_exp_map\":\n", 214 | " return transform\n", 215 | " elif output_parameterization == \"se3_log_map\":\n", 216 | " se3_log = transform.get_se3_log()\n", 217 | " log_t_vee = se3_log[..., :3]\n", 218 | " log_R_vee = se3_log[..., 3:]\n", 219 | " return log_R_vee, log_t_vee\n", 220 | " else:\n", 221 | " return (\n", 222 | " transform.get_rotation(output_parameterization, output_convention),\n", 223 | " transform.get_translation(),\n", 224 | " )" 225 | ] 226 | }, 227 | { 228 | "cell_type": "markdown", 229 | "id": "7968b28c-ba34-49c7-8acc-bb080c0e4556", 230 | "metadata": {}, 231 | "source": [ 232 | "## Computing a perspective projection\n", 233 | "\n", 234 | "Given an `extrinsic` and `intrinsic` camera matrix, we can compute the perspective projection of a batch of points.\n", 235 | "This is used for computing where fiducials in world coordinates get mapped onto the image plane." 
236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "id": "8347585b-1a3b-4246-9fc4-d5f4dad5944b", 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [ 245 | "#| export\n", 246 | "@jaxtyped(typechecker=beartype)\n", 247 | "def perspective_projection(\n", 248 | " extrinsic: RigidTransform, # Extrinsic camera matrix (world to camera)\n", 249 | " intrinsic: Float[torch.Tensor, \"3 3\"], # Intrinsic camera matrix (camera to image)\n", 250 | " x: Float[torch.Tensor, \"b n 3\"], # World coordinates\n", 251 | ") -> Float[torch.Tensor, \"b n 2\"]:\n", 252 | " x = extrinsic.transform_points(x)\n", 253 | " x = torch.einsum(\"ij, bnj -> bni\", intrinsic, x)\n", 254 | " z = x[..., -1].unsqueeze(-1).clone()\n", 255 | " x = x / z\n", 256 | " return x[..., :2]" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "id": "b18dbb07-f6cd-4de2-ad25-a784ad621471", 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "#| hide\n", 267 | "import nbdev\n", 268 | "\n", 269 | "nbdev.nbdev_export()" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": null, 275 | "id": "a304b6f4-6558-4d63-99f8-dd9c82b9fddf", 276 | "metadata": {}, 277 | "outputs": [], 278 | "source": [] 279 | } 280 | ], 281 | "metadata": { 282 | "kernelspec": { 283 | "display_name": "python3", 284 | "language": "python", 285 | "name": "python3" 286 | } 287 | }, 288 | "nbformat": 4, 289 | "nbformat_minor": 5 290 | } 291 | -------------------------------------------------------------------------------- /notebooks/api/04_metrics.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "raw", 5 | "id": "c6e3d313-4b94-4b76-932c-fd67a999361a", 6 | "metadata": {}, 7 | "source": [ 8 | "---\n", 9 | "title: metrics\n", 10 | "subtitle: Image similarity metrics and geodesic distances for camera poses\n", 11 | "---" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "3a3d1f8e-5b89-41b3-889d-2b054f63c125", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "#| default_exp metrics" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "8a22f019-d10f-4c94-aa9e-fc6d32f74f4b", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "#| hide\n", 32 | "from nbdev.showdoc import *" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": null, 38 | "id": "9a452921-86fb-49c1-a3c6-749fad621a55", 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "#| export\n", 43 | "from diffdrr.metrics import (\n", 44 | " GradientNormalizedCrossCorrelation2d,\n", 45 | " MultiscaleNormalizedCrossCorrelation2d,\n", 46 | " NormalizedCrossCorrelation2d,\n", 47 | ")\n", 48 | "from torchmetrics import Metric" 49 | ] 50 | }, 51 | { 52 | "cell_type": "markdown", 53 | "id": "870e3dc8-fc8d-46ef-9204-66847c1cf4df", 54 | "metadata": {}, 55 | "source": [ 56 | "## Image similarity metrics\n", 57 | "\n", 58 | "Used to quantify the similarity between ground truth X-rays ($\\mathbf I$) and synthetic X-rays generated from estimated camera poses ($\\hat{\\mathbf I}$). If a metric is differentiable, it can be used to optimize camera poses with `DiffDRR`." 
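,
"\n",
"\n",
"For reference (our notation), the global normalized cross correlation over pixels $p \in \Omega$ is\n",
"\begin{equation}\n",
"  \mathrm{NCC}(\mathbf I, \hat{\mathbf I}) = \frac{1}{|\Omega|} \sum_{p \in \Omega} \frac{(\mathbf I[p] - \mu_{\mathbf I})(\hat{\mathbf I}[p] - \mu_{\hat{\mathbf I}})}{\sigma_{\mathbf I} \sigma_{\hat{\mathbf I}}} \,,\n",
"\end{equation}\n",
"where $\mu$ and $\sigma$ are per-image means and standard deviations; the patchwise and multiscale variants below apply the same quantity over local windows."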
59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "3d5b1780-fa0a-4d81-8103-40572c938431", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "#| exporti\n", 69 | "class CustomMetric(Metric):\n", 70 | " is_differentiable: bool = True\n", 71 | "\n", 72 | " def __init__(self, LossClass, **kwargs):\n", 73 | " super().__init__()\n", 74 | " self.lossfn = LossClass(**kwargs)\n", 75 | " self.add_state(\"loss\", default=torch.tensor(0.0), dist_reduce_fx=\"sum\")\n", 76 | " self.add_state(\"count\", default=torch.tensor(0), dist_reduce_fx=\"sum\")\n", 77 | "\n", 78 | " def update(self, preds, target):\n", 79 | " self.loss += self.lossfn(preds, target).sum()\n", 80 | " self.count += len(preds)\n", 81 | "\n", 82 | " def compute(self):\n", 83 | " return self.loss.float() / self.count" 84 | ] 85 | }, 86 | { 87 | "cell_type": "markdown", 88 | "id": "cbfb768c-3745-47b6-97c8-760debc83fa6", 89 | "metadata": {}, 90 | "source": [ 91 | "`NCC` and `GradNCC` were originally implemented in [`diffdrr.metrics`](https://github.com/eigenvivek/DiffDRR/blob/main/notebooks/api/05_metrics.ipynb).\n", 92 | "`DiffPose` provides `torchmetrics` wrappers for these functions." 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "id": "dc49ea90-0841-4ac5-9b17-2faab1191d10", 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "#| export\n", 103 | "class NormalizedCrossCorrelation(CustomMetric):\n", 104 | " \"\"\"`torchmetric` wrapper for NCC.\"\"\"\n", 105 | "\n", 106 | " higher_is_better: bool = True\n", 107 | "\n", 108 | " def __init__(self, patch_size=None):\n", 109 | " super().__init__(NormalizedCrossCorrelation2d, patch_size=patch_size)\n", 110 | "\n", 111 | "\n", 112 | "class MultiscaleNormalizedCrossCorrelation(CustomMetric):\n", 113 | " \"\"\"`torchmetric` wrapper for Multiscale NCC.\"\"\"\n", 114 | "\n", 115 | " higher_is_better: bool = True\n", 116 | "\n", 117 | " def __init__(self, patch_sizes, patch_weights):\n", 118 | " super().__init__(\n", 119 | " MultiscaleNormalizedCrossCorrelation2d,\n", 120 | " patch_sizes=patch_sizes,\n", 121 | " patch_weights=patch_weights,\n", 122 | " )\n", 123 | "\n", 124 | "\n", 125 | "class GradientNormalizedCrossCorrelation(CustomMetric):\n", 126 | " \"\"\"`torchmetric` wrapper for GradNCC.\"\"\"\n", 127 | "\n", 128 | " higher_is_better: bool = True\n", 129 | "\n", 130 | " def __init__(self, patch_size=None):\n", 131 | " super().__init__(GradientNormalizedCrossCorrelation2d, patch_size=patch_size)" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "id": "f394f290-bac4-4831-bec6-5ac88dcdd914", 137 | "metadata": {}, 138 | "source": [ 139 | "## Geodesic distances for SO(3) and SE(3)\n", 140 | "\n", 141 | "One can define geodesic pseudo-distances on SO(3) and SE(3).[^1] This lets us measure registration error (in radians and millimeters, respectively) on poses, rather than needing to compute the projection of fiducials.\n", 142 | "\n", 143 | "- **For SO(3)**, the geodesic distance between two rotation matrices $\mathbf R_A$ and $\mathbf R_B$ is\n", 144 | "\begin{equation}\n", 145 | " d_\theta(\mathbf R_A, \mathbf R_B; r) = r \left| \arccos \left( \frac{\mathrm{trace}(\mathbf R_A^* \mathbf R_B) - 1}{2} \right ) \right| \,,\n", 146 | "\end{equation}\n", 147 | "where $r$, the source-to-detector radius, is used to convert radians to millimeters.\n", 148 | "\n", 149 | "- **For SE(3)**, we decompose the transformation matrix into a rotation and a translation, i.e., $\mathbf T = (\mathbf R, \mathbf 
t)$.\n", 150 | "Then, we compute the geodesic on translations (just Euclidean distance),\n", 151 | "\begin{equation}\n", 152 | " d_t(\mathbf t_A, \mathbf t_B) = \| \mathbf t_A - \mathbf t_B \|_2 \,.\n", 153 | "\end{equation}\n", 154 | "Finally, we compute the *double geodesic* on the rotations and translations:\n", 155 | "\begin{equation}\n", 156 | " d(\mathbf T_A, \mathbf T_B) = \sqrt{d_\theta(\mathbf R_A, \mathbf R_B)^2 + d_t(\mathbf t_A, \mathbf t_B)^2} \,.\n", 157 | "\end{equation}\n", 158 | "\n", 159 | "[^1]: [https://vnav.mit.edu/material/04-05-LieGroups-notes.pdf](https://vnav.mit.edu/material/04-05-LieGroups-notes.pdf)" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "id": "6400a4b1-417e-48b5-aa7f-207c8cdf893d", 166 | "metadata": {}, 167 | "outputs": [], 168 | "source": [ 169 | "#| export\n", 170 | "import torch\n", 171 | "from beartype import beartype\n", 172 | "from diffdrr.utils import (\n", 173 | " convert,\n", 174 | " so3_log_map,\n", 175 | " so3_relative_angle,\n", 176 | " so3_rotation_angle,\n", 177 | " standardize_quaternion,\n", 178 | ")\n", 179 | "from jaxtyping import Float, jaxtyped\n", 180 | "\n", 181 | "from diffpose.calibration import RigidTransform" 182 | ] 183 | }, 184 | { 185 | "cell_type": "code", 186 | "execution_count": null, 187 | "id": "1ff308dc-4807-46dd-bd10-ef9dca35c4c9", 188 | "metadata": {}, 189 | "outputs": [], 190 | "source": [ 191 | "#| export\n", 192 | "class GeodesicSO3(torch.nn.Module):\n", 193 | " \"\"\"Calculate the angular distance between two rotations in SO(3).\"\"\"\n", 194 | "\n", 195 | " def __init__(self):\n", 196 | " super().__init__()\n", 197 | "\n", 198 | " @jaxtyped(typechecker=beartype)\n", 199 | " def forward(\n", 200 | " self,\n", 201 | " pose_1: RigidTransform,\n", 202 | " pose_2: RigidTransform,\n", 203 | " ) -> Float[torch.Tensor, \"b\"]:\n", 204 | " r1 = pose_1.get_rotation()\n", 205 | " r2 = pose_2.get_rotation()\n", 206 | " rdiff = r1 @ r2.transpose(-1, -2)\n", 207 | " return so3_log_map(rdiff).norm(dim=-1)\n", 208 | "\n", 209 | "\n", 210 | "class GeodesicTranslation(torch.nn.Module):\n", 211 | " \"\"\"Calculate the Euclidean distance between the translations of two poses.\"\"\"\n", 212 | "\n", 213 | " def __init__(self):\n", 214 | " super().__init__()\n", 215 | "\n", 216 | " @jaxtyped(typechecker=beartype)\n", 217 | " def forward(\n", 218 | " self,\n", 219 | " pose_1: RigidTransform,\n", 220 | " pose_2: RigidTransform,\n", 221 | " ) -> Float[torch.Tensor, \"b\"]:\n", 222 | " t1 = pose_1.get_translation()\n", 223 | " t2 = pose_2.get_translation()\n", 224 | " return (t1 - t2).norm(dim=1)" 225 | ] 226 | }, 227 | { 228 | "cell_type": "code", 229 | "execution_count": null, 230 | "id": "1ad99b83-9759-4909-9930-598cf9c8433a", 231 | "metadata": {}, 232 | "outputs": [], 233 | "source": [ 234 | "#| export\n", 235 | "class GeodesicSE3(torch.nn.Module):\n", 236 | " \"\"\"Calculate the distance between transforms in the log-space of SE(3).\"\"\"\n", 237 | "\n", 238 | " def __init__(self):\n", 239 | " super().__init__()\n", 240 | "\n", 241 | " @jaxtyped(typechecker=beartype)\n", 242 | " def forward(\n", 243 | " self,\n", 244 | " pose_1: RigidTransform,\n", 245 | " pose_2: RigidTransform,\n", 246 | " ) -> Float[torch.Tensor, \"b\"]:\n", 247 | " return pose_2.compose(pose_1.inverse()).get_se3_log().norm(dim=1)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": null, 253 | "id": "099c7c47-0d3b-4c4b-8ce9-c50051fa115d", 254 | "metadata": {}, 255 | "outputs": [], 256 
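The formulas above can be checked numerically with plain `torch`. In the sketch below, the rotations, translations, and source-to-detector radius are made-up values, chosen so the result can be compared against the `DoubleGeodesic` example output further below:

```python
# Standalone numeric check of the SO(3) and double geodesics (illustrative values).
import math

import torch

def so3_geodesic(RA, RB):
    """|arccos((trace(RA^T RB) - 1) / 2)|, in radians."""
    cos = ((RA.transpose(-1, -2) @ RB).diagonal(dim1=-2, dim2=-1).sum(-1) - 1) / 2
    return torch.arccos(cos.clamp(-1.0, 1.0)).abs()

def Rz(theta):
    """Rotation by theta radians about the z-axis."""
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

RA, RB = Rz(0.0), Rz(0.1)               # Rotations 0.1 rad apart
tA, tB = torch.zeros(3), torch.ones(3)  # Translations sqrt(3) mm apart
sdr = 510.0                             # Source-to-detector radius (mm)

d_theta = sdr * so3_geodesic(RA, RB)    # tensor(51.0000)
d_t = (tA - tB).norm()                  # tensor(1.7321)
print((d_theta**2 + d_t**2).sqrt())     # tensor(51.0294)
```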
| "source": [ 257 | "#| export\n", 258 | "@beartype\n", 259 | "class DoubleGeodesic(torch.nn.Module):\n", 260 | " \"\"\"Calculate the angular and translational geodesics between two SE(3) transformation matrices.\"\"\"\n", 261 | "\n", 262 | " def __init__(\n", 263 | " self,\n", 264 | " sdr: float, # Source-to-detector radius\n", 265 | " eps: float = 1e-4, # Avoid overflows in sqrt\n", 266 | " ):\n", 267 | " super().__init__()\n", 268 | " self.sdr = sdr\n", 269 | " self.eps = eps\n", 270 | "\n", 271 | " self.rotation = GeodesicSO3()\n", 272 | " self.translation = GeodesicTranslation()\n", 273 | "\n", 274 | " @jaxtyped(typechecker=beartype)\n", 275 | " def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):\n", 276 | " angular_geodesic = self.sdr * self.rotation(pose_1, pose_2)\n", 277 | " translation_geodesic = self.translation(pose_1, pose_2)\n", 278 | " double_geodesic = (\n", 279 | " (angular_geodesic).square() + translation_geodesic.square() + self.eps\n", 280 | " ).sqrt()\n", 281 | " return angular_geodesic, translation_geodesic, double_geodesic" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": null, 287 | "id": "708e9b0b-3fd1-4671-b9a8-a993ced0d187", 288 | "metadata": {}, 289 | "outputs": [ 290 | { 291 | "name": "stdout", 292 | "output_type": "stream", 293 | "text": [ 294 | "tensor([0.])\n", 295 | "tensor([0.1000])\n" 296 | ] 297 | } 298 | ], 299 | "source": [ 300 | "# SO(3) distance\n", 301 | "geodesic_so3 = GeodesicSO3()\n", 302 | "\n", 303 | "pose_1 = RigidTransform(\n", 304 | " torch.tensor([[0.1, 1.0, torch.pi]]),\n", 305 | " torch.ones(1, 3),\n", 306 | " parameterization=\"euler_angles\",\n", 307 | " convention=\"ZYX\",\n", 308 | ")\n", 309 | "pose_2 = RigidTransform(\n", 310 | " torch.tensor([[0.1, 1.0, torch.pi]]),\n", 311 | " torch.ones(1, 3),\n", 312 | " parameterization=\"euler_angles\",\n", 313 | " convention=\"ZYX\",\n", 314 | ")\n", 315 | "\n", 316 | "print(geodesic_so3(pose_1, pose_2)) # Angular distance in radians\n", 317 | "\n", 318 | "pose_1 = RigidTransform(\n", 319 | " torch.tensor([[0.1, 1.0, torch.pi]]),\n", 320 | " torch.ones(1, 3),\n", 321 | " parameterization=\"euler_angles\",\n", 322 | " convention=\"ZYX\",\n", 323 | ")\n", 324 | "pose_2 = RigidTransform(\n", 325 | " torch.tensor([[0.1, 1.1, torch.pi]]),\n", 326 | " torch.ones(1, 3),\n", 327 | " parameterization=\"euler_angles\",\n", 328 | " convention=\"ZYX\",\n", 329 | ")\n", 330 | "\n", 331 | "print(geodesic_so3(pose_1, pose_2)) # Angular distance in radians" 332 | ] 333 | }, 334 | { 335 | "cell_type": "code", 336 | "execution_count": null, 337 | "id": "c0ff1bce-f8a1-4c40-bb2f-4e9dfa87ac17", 338 | "metadata": {}, 339 | "outputs": [ 340 | { 341 | "data": { 342 | "text/plain": [ 343 | "tensor([1.7355])" 344 | ] 345 | }, 346 | "execution_count": null, 347 | "metadata": {}, 348 | "output_type": "execute_result" 349 | } 350 | ], 351 | "source": [ 352 | "# SE(3) distance\n", 353 | "geodesic_se3 = GeodesicSE3()\n", 354 | "\n", 355 | "pose_1 = RigidTransform(\n", 356 | " torch.tensor([[0.1, 1.0, torch.pi]]),\n", 357 | " torch.ones(1, 3),\n", 358 | " parameterization=\"euler_angles\",\n", 359 | " convention=\"ZYX\",\n", 360 | ")\n", 361 | "pose_2 = RigidTransform(\n", 362 | " torch.tensor([[0.1, 1.1, torch.pi]]),\n", 363 | " torch.zeros(1, 3),\n", 364 | " parameterization=\"euler_angles\",\n", 365 | " convention=\"ZYX\",\n", 366 | ")\n", 367 | "\n", 368 | "geodesic_se3(pose_1, pose_2)" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": null, 374 | 
"id": "8b69987b-98dc-431e-b885-b9e992e0e91a", 375 | "metadata": {}, 376 | "outputs": [ 377 | { 378 | "data": { 379 | "text/plain": [ 380 | "(tensor([51.0000]), tensor([1.7321]), tensor([51.0294]))" 381 | ] 382 | }, 383 | "execution_count": null, 384 | "metadata": {}, 385 | "output_type": "execute_result" 386 | } 387 | ], 388 | "source": [ 389 | "# Angular distance and translational distance both in mm\n", 390 | "double_geodesic = DoubleGeodesic(1020 / 2)\n", 391 | "\n", 392 | "pose_1 = RigidTransform(\n", 393 | " torch.tensor([[0.1, 1.0, torch.pi]]),\n", 394 | " torch.ones(1, 3),\n", 395 | " parameterization=\"euler_angles\",\n", 396 | " convention=\"ZYX\",\n", 397 | ")\n", 398 | "pose_2 = RigidTransform(\n", 399 | " torch.tensor([[0.1, 1.1, torch.pi]]),\n", 400 | " torch.zeros(1, 3),\n", 401 | " parameterization=\"euler_angles\",\n", 402 | " convention=\"ZYX\",\n", 403 | ")\n", 404 | "\n", 405 | "double_geodesic(pose_1, pose_2)" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": null, 411 | "id": "ba77dd5d-cd83-4a1e-8520-5d27d9b972fb", 412 | "metadata": {}, 413 | "outputs": [], 414 | "source": [ 415 | "#| hide\n", 416 | "import nbdev\n", 417 | "\n", 418 | "nbdev.nbdev_export()" 419 | ] 420 | }, 421 | { 422 | "cell_type": "code", 423 | "execution_count": null, 424 | "id": "fa4fbd2b-75ce-48d0-90ff-56ddb421a547", 425 | "metadata": {}, 426 | "outputs": [], 427 | "source": [] 428 | } 429 | ], 430 | "metadata": { 431 | "kernelspec": { 432 | "display_name": "python3", 433 | "language": "python", 434 | "name": "python3" 435 | } 436 | }, 437 | "nbformat": 4, 438 | "nbformat_minor": 5 439 | } 440 | -------------------------------------------------------------------------------- /notebooks/api/05_visualization.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "raw", 5 | "id": "b50fb661-62ae-4706-a86e-1b8f50f08f0d", 6 | "metadata": {}, 7 | "source": [ 8 | "---\n", 9 | "title: visualization\n", 10 | "subtitle: Plots for registration and 3D visualization\n", 11 | "---" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "3a3d1f8e-5b89-41b3-889d-2b054f63c125", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "#| default_exp visualization" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": null, 27 | "id": "8a22f019-d10f-4c94-aa9e-fc6d32f74f4b", 28 | "metadata": {}, 29 | "outputs": [], 30 | "source": [ 31 | "#| hide\n", 32 | "from nbdev.showdoc import *" 33 | ] 34 | }, 35 | { 36 | "cell_type": "markdown", 37 | "id": "aefed743-3d9c-4316-a376-04e24b04a636", 38 | "metadata": {}, 39 | "source": [ 40 | "## Overlay over predicted edges on target images" 41 | ] 42 | }, 43 | { 44 | "cell_type": "code", 45 | "execution_count": null, 46 | "id": "9a452921-86fb-49c1-a3c6-749fad621a55", 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "#| export\n", 51 | "from io import BytesIO\n", 52 | "\n", 53 | "import matplotlib.pyplot as plt\n", 54 | "import numpy as np\n", 55 | "import torch\n", 56 | "from PIL import Image\n", 57 | "from skimage.feature import canny\n", 58 | "from torchvision.utils import make_grid" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "3d5b1780-fa0a-4d81-8103-40572c938431", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "#| exporti\n", 69 | "def _overlay_edges(target, pred, sigma, eps=1e-5):\n", 70 | " pred = (pred - pred.min()) / (pred.max() - 
pred.min() + eps)\n", 71 | " edges = canny(pred, sigma=sigma)\n", 72 | " edges = np.ma.masked_where(~edges, edges)\n", 73 | "\n", 74 | " buffer = BytesIO()\n", 75 | " plt.subplot()\n", 76 | " plt.imshow(target, cmap=\"gray\")\n", 77 | " plt.imshow(edges, cmap=\"cool_r\", interpolation=\"none\", vmin=0.0, vmax=1.0)\n", 78 | " plt.axis(\"off\")\n", 79 | " plt.savefig(buffer, format=\"png\", bbox_inches=\"tight\", pad_inches=0, dpi=300)\n", 80 | " arr = np.array(Image.open(buffer))\n", 81 | " plt.close()\n", 82 | " return arr" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "id": "3dc560ca-c5f3-4569-990f-efa40d7c8481", 89 | "metadata": {}, 90 | "outputs": [], 91 | "source": [ 92 | "#| export\n", 93 | "def overlay_edges(target, pred, sigma=1.5):\n", 94 | " \"\"\"Generate edge overlays for a batch of targets and predictions.\"\"\"\n", 95 | " edges = []\n", 96 | " for i, p in zip(target, pred):\n", 97 | " edge = _overlay_edges(i[0].cpu().numpy(), p[0].cpu().numpy(), sigma)\n", 98 | " edges.append(edge)\n", 99 | " edges = torch.from_numpy(np.stack(edges)).permute(0, -1, 1, 2)\n", 100 | " edges = make_grid(edges).permute(1, 2, 0)\n", 101 | " return edges" 102 | ] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "id": "ab90aa75-5869-49d7-aaa2-5f9041295840", 107 | "metadata": {}, 108 | "source": [ 109 | "## Using PyVista to visualize 3D geometry" 110 | ] 111 | }, 112 | { 113 | "cell_type": "code", 114 | "execution_count": null, 115 | "id": "65d197c9-dfbd-4029-bd23-05b4fda02a76", 116 | "metadata": {}, 117 | "outputs": [], 118 | "source": [ 119 | "#| export\n", 120 | "import pyvista\n", 121 | "from torch.nn.functional import pad\n", 122 | "\n", 123 | "from diffpose.calibration import RigidTransform, perspective_projection" 124 | ] 125 | }, 126 | { 127 | "cell_type": "code", 128 | "execution_count": null, 129 | "id": "5b5897c4-605e-44eb-aa78-097a273211bc", 130 | "metadata": {}, 131 | "outputs": [], 132 | "source": [ 133 | "#| exporti\n", 134 | "def fiducials_3d_to_projected_fiducials_3d(specimen, pose):\n", 135 | " # Extrinsic camera matrix\n", 136 | " extrinsic = (\n", 137 | " specimen.lps2volume.inverse()\n", 138 | " .compose(pose.inverse())\n", 139 | " .compose(specimen.translate)\n", 140 | " .compose(specimen.flip_xz)\n", 141 | " )\n", 142 | "\n", 143 | " # Intrinsic projection -> in 3D\n", 144 | " x = perspective_projection(extrinsic, specimen.intrinsic, specimen.fiducials)\n", 145 | " x = -specimen.focal_len * torch.einsum(\n", 146 | " \"ij, bnj -> bni\",\n", 147 | " specimen.intrinsic.inverse(),\n", 148 | " pad(x, (0, 1), value=1), # Convert to homogeneous coordinates\n", 149 | " )\n", 150 | "\n", 151 | " # Undo the earlier extrinsic composition (map the projections back to world coordinates)\n", 152 | " extrinsic = (\n", 153 | " specimen.flip_xz.inverse().compose(specimen.translate.inverse()).compose(pose)\n", 154 | " )\n", 155 | " return extrinsic.transform_points(x)" 156 | ] 157 | }, 158 | { 159 | "cell_type": "code", 160 | "execution_count": null, 161 | "id": "902878a6-16b1-4ea0-b4b7-3693c5e7607d", 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "#| export\n", 166 | "def fiducials_to_mesh(\n", 167 | " specimen,\n", 168 | " rotation=None,\n", 169 | " translation=None,\n", 170 | " parameterization=None,\n", 171 | " convention=None,\n", 172 | " detector=None,\n", 173 | "):\n", 174 | " \"\"\"\n", 175 | " Use camera matrices to get 2D projections of 3D fiducials for a given pose.\n", 176 | " If the detector is passed, 2D projections will be filtered for those that lie\n", 177 | " on the detector plane.\n", 178
| " \"\"\"\n", 179 | " # Location of fiducials in 3D\n", 180 | " fiducials_3d = specimen.lps2volume.inverse().transform_points(specimen.fiducials)\n", 181 | " fiducials_3d = pyvista.PolyData(fiducials_3d.squeeze().numpy())\n", 182 | " if rotation is None and translation is None and parameterization is None:\n", 183 | " return fiducials_3d\n", 184 | "\n", 185 | " # Embedding of fiducials in 2D\n", 186 | " pose = RigidTransform(rotation, translation, parameterization, convention, device=\"cpu\")\n", 187 | " fiducials_2d = fiducials_3d_to_projected_fiducials_3d(specimen, pose)\n", 188 | " fiducials_2d = fiducials_2d.squeeze().numpy()\n", 189 | "\n", 190 | " # Optionally, only render 2D fiducials that lie on the detector plane\n", 191 | " if detector is not None:\n", 192 | " corners = detector.points.reshape(\n", 193 | " detector[\"height\"][0], detector[\"width\"][0], 3\n", 194 | " )[\n", 195 | " [0, 0, -1, -1],\n", 196 | " [0, -1, 0, -1],\n", 197 | " ]\n", 198 | " exclude = np.logical_or(\n", 199 | " fiducials_2d < corners.min(0),\n", 200 | " fiducials_2d > corners.max(0),\n", 201 | " ).any(1)\n", 202 | " fiducials_2d = fiducials_2d[~exclude]\n", 203 | "\n", 204 | " fiducials_2d = pyvista.PolyData(fiducials_2d)\n", 205 | " return fiducials_3d, fiducials_2d" 206 | ] 207 | }, 208 | { 209 | "cell_type": "code", 210 | "execution_count": null, 211 | "id": "87699278-de70-44c7-887e-7b63ba3f1981", 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "#| export\n", 216 | "def lines_to_mesh(camera, fiducials_2d):\n", 217 | " \"\"\"Draw lines from the camera to the 2D fiducials.\"\"\"\n", 218 | " lines = []\n", 219 | " for pt in fiducials_2d.points:\n", 220 | " line = pyvista.Line(pt, camera.center)\n", 221 | " lines.append(line)\n", 222 | " return lines" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": null, 228 | "id": "ba77dd5d-cd83-4a1e-8520-5d27d9b972fb", 229 | "metadata": {}, 230 | "outputs": [], 231 | "source": [ 232 | "#| hide\n", 233 | "import nbdev\n", 234 | "\n", 235 | "nbdev.nbdev_export()" 236 | ] 237 | }, 238 | { 239 | "cell_type": "code", 240 | "execution_count": null, 241 | "id": "fa4fbd2b-75ce-48d0-90ff-56ddb421a547", 242 | "metadata": {}, 243 | "outputs": [], 244 | "source": [] 245 | } 246 | ], 247 | "metadata": { 248 | "kernelspec": { 249 | "display_name": "python3", 250 | "language": "python", 251 | "name": "python3" 252 | } 253 | }, 254 | "nbformat": 4, 255 | "nbformat_minor": 5 256 | } 257 | -------------------------------------------------------------------------------- /notebooks/experiments/test_time_optimization.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eigenvivek/DiffPose/afc0159b3c1db4ea1e1bd55cd4fc85b0408c3974/notebooks/experiments/test_time_optimization.gif -------------------------------------------------------------------------------- /notebooks/favicon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eigenvivek/DiffPose/afc0159b3c1db4ea1e1bd55cd4fc85b0408c3974/notebooks/favicon.png -------------------------------------------------------------------------------- /notebooks/index.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "raw", 5 | "metadata": {}, 6 | "source": [ 7 | "---\n", 8 | "title: DiffPose\n", 9 | "---" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": 
[ 16 | "> Intraoperative 2D/3D registration via differentiable X-ray rendering\n", 17 | "\n", 18 | "[![CI](https://github.com/eigenvivek/DiffPose/actions/workflows/test.yaml/badge.svg)](https://github.com/eigenvivek/DiffPose/actions/workflows/test.yaml)\n", 19 | "[![Paper shield](https://img.shields.io/badge/arXiv-2312.06358-red.svg)](https://arxiv.org/abs/2312.06358)\n", 20 | "[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE)\n", 21 | "[![Docs](https://github.com/eigenvivek/DiffPose/actions/workflows/deploy.yaml/badge.svg)](https://vivekg.dev/DiffPose)\n", 22 | "[![Code style: black](https://img.shields.io/badge/Code%20style-black-black.svg)](https://github.com/psf/black)\n", 23 | "\n", 24 | "![](experiments/test_time_optimization.gif)" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "## Install\n", 32 | "\n", 33 | "To install `DiffPose` and the requirements in [`environment.yml`](https://github.com/eigenvivek/DiffPose/blob/main/environment.yml), run:\n", 34 | "\n", 35 | "```zsh\n", 36 | "pip install diffpose\n", 37 | "```\n", 38 | "\n", 39 | "The differentiable X-ray renderer that powers the backend of `DiffPose` is available at [`DiffDRR`](https://github.com/eigenvivek/DiffDRR)." 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## Datasets\n", 47 | "\n", 48 | "We evaluate `DiffPose` networks on the following open-source datasets:\n", 49 | "\n", 50 | "| **Dataset** | **Anatomy** | **\# of Subjects** | **\# of 2D Images** | **CTs** | **X-rays** | Fiducials |\n", 51 | "|----------------------------------------------------------------------------|--------------------|:------------------:|:-------------------:|:-------:|:----------:|:---------:|\n", 52 | "| [`DeepFluoro`](https://github.com/rg2/DeepFluoroLabeling-IPCAI2020) | Pelvis | 6 | 366 | ✅ | ✅ | ❌ |\n", 53 | "| [`Ljubljana`](https://lit.fe.uni-lj.si/en/research/resources/3D-2D-GS-CA/) | Cerebrovasculature | 10 | 20 | ✅ | ✅ | ✅ |\n", 54 | "\n", 56 | "\n", 57 | "- `DeepFluoro` ([**Grupp et al., 2020**](https://link.springer.com/article/10.1007/s11548-020-02162-7)) provides paired X-ray fluoroscopy images and CT volumes of the pelvis. The data were collected from six cadaveric subjects at Johns Hopkins University. Ground truth camera poses were estimated with an offline registration process. A visualization of one X-ray / CT pair in the `DeepFluoro` dataset is available [here](https://vivekg.dev/DiffPose/experiments/render.html).\n", 58 | "\n", 59 | "```zsh\n", 60 | "mkdir -p data/\n", 61 | "wget --no-check-certificate -O data/ipcai_2020_full_res_data.zip \"http://archive.data.jhu.edu/api/access/datafile/:persistentId/?persistentId=doi:10.7281/T1/IFSXNV/EAN9GH\"\n", 62 | "unzip -o data/ipcai_2020_full_res_data.zip -d data\n", 63 | "rm data/ipcai_2020_full_res_data.zip\n", 64 | "```\n", 65 | "\n", 66 | "- `Ljubljana` ([**Mitrovic et al., 2013**](https://ieeexplore.ieee.org/abstract/document/6507588)) provides paired 2D/3D digital subtraction angiography (DSA) images. The data were collected from 10 patients undergoing endovascular image-guided interventions at the University of Ljubljana. 
Ground truth camera poses were estimated by registering surface fiducial markers.\n", 67 | "\n", 68 | "```zsh\n", 69 | "mkdir -p data/\n", 70 | "wget --no-check-certificate -O data/ljubljana.zip \"https://drive.google.com/uc?export=download&confirm=yes&id=1x585pGLI8QGk21qZ2oGwwQ9LMJ09Tqrx\"\n", 71 | "unzip -o data/ljubljana.zip -d data\n", 72 | "rm data/ljubljana.zip\n", 73 | "```\n", 74 | "\n", 75 | "" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "## Experiments\n", 85 | "\n", 86 | "To run the experiments in `DiffPose`, run the following scripts (ensure you've downloaded the data first):\n", 87 | "\n", 88 | "```zsh\n", 89 | "# DeepFluoro dataset\n", 90 | "cd experiments/deepfluoro\n", 91 | "srun python train.py # Pretrain pose regression CNN on synthetic X-rays\n", 92 | "srun python register.py # Run test-time optimization with the best network per subject\n", 93 | "```\n", 94 | "\n", 95 | "```zsh\n", 96 | "# Ljubljana dataset\n", 97 | "cd experiments/ljubljana\n", 98 | "srun python train.py\n", 99 | "srun python register.py\n", 100 | "```\n", 101 | "\n", 102 | "The training and test-time optimization scripts use SLURM to run on all subjects in parallel:\n", 103 | "\n", 104 | "- `experiments/deepfluoro/train.py` is configured to run across six A6000 GPUs\n", 105 | "- `experiments/deepfluoro/register.py` is configured to run across six 2080 Ti GPUs\n", 106 | "- `experiments/ljubljana/train.py` is configured to run across twenty 2080 Ti GPUs\n", 107 | "- `experiments/ljubljana/register.py` is configured to run across twenty 2080 Ti GPUs\n", 108 | "\n", 109 | "The GPU configurations can be changed at the end of each script using [`submitit`](https://github.com/facebookincubator/submitit); see the sketch after this section." 110 | ] 111 | }, 112 | { 113 | "cell_type": "markdown", 114 | "metadata": {}, 115 | "source": [ 116 | "## Development\n", 117 | "\n", 118 | "The `DiffPose` package, docs, and CI are all built using [`nbdev`](https://nbdev.fast.ai/).\n", 119 | "To get set up with `nbdev`, install the following:\n", 120 | "\n", 121 | "```zsh\n", 122 | "conda install jupyterlab nbdev -c fastai -c conda-forge \n", 123 | "nbdev_install_quarto # To build docs\n", 124 | "nbdev_install_hooks # Make notebooks git-friendly\n", 125 | "pip install -e \".[dev]\" # Install the development version of DiffPose\n", 126 | "```\n", 127 | "\n", 128 | "Running `nbdev_help` will give you the full list of options. The most important ones are\n", 129 | "\n", 130 | "```zsh\n", 131 | "nbdev_preview # Render docs locally and inspect in browser\n", 132 | "nbdev_clean # NECESSARY BEFORE PUSHING\n", 133 | "nbdev_test # Test the notebooks\n", 134 | "nbdev_export # Build the package from the notebooks\n", 135 | "nbdev_readme # Render the readme\n", 136 | "```\n", 137 | "\n", 138 | "For more details, follow this [in-depth tutorial](https://nbdev.fast.ai/tutorials/tutorial.html)." 
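For reference, the `submitit` launchers at the end of the experiment scripts follow the general pattern sketched below. This is a sketch only: the `run` function body, partition name, and resource numbers are illustrative assumptions, not the repo's exact configuration:

```python
# Generic shape of a submitit launcher; all parameter values are hypothetical.
import submitit

def run(subject_id):
    ...  # Train or register a single subject

executor = submitit.AutoExecutor(folder="logs")
executor.update_parameters(
    name="diffpose",
    gpus_per_node=1,        # One GPU per subject
    timeout_min=60 * 24,
    slurm_partition="gpu",  # Hypothetical partition name
)
jobs = executor.map_array(run, list(range(6)))  # e.g., six DeepFluoro subjects
```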
139 | ] 140 | }, 141 | { 142 | "cell_type": "markdown", 143 | "metadata": {}, 144 | "source": [ 145 | "## Citing `DiffPose`\n", 146 | "\n", 147 | "If you find `DiffPose` or [`DiffDRR`](https://github.com/eigenvivek/DiffDRR) useful in your work, please cite the appropriate papers:\n", 148 | "\n", 149 | "```\n", 150 | "@misc{gopalakrishnan2022diffpose,\n", 151 | " title={Intraoperative 2D/3D Image Registration via Differentiable X-ray Rendering}, \n", 152 | " author={Vivek Gopalakrishnan and Neel Dey and Polina Golland},\n", 153 | " year={2023},\n", 154 | " eprint={2312.06358},\n", 155 | " archivePrefix={arXiv},\n", 156 | " primaryClass={cs.CV}\n", 157 | "}\n", 158 | "\n", 159 | "@inproceedings{gopalakrishnan2022diffdrr,\n", 160 | " author={Gopalakrishnan, Vivek and Golland, Polina},\n", 161 | " title={Fast Auto-Differentiable Digitally Reconstructed Radiographs for Solving Inverse Problems in Intraoperative Imaging},\n", 162 | " year={2022},\n", 163 | " booktitle={Clinical Image-based Procedures: 11th International Workshop, CLIP 2022, Held in Conjunction with MICCAI 2022, Singapore, Proceedings},\n", 164 | " series={Lecture Notes in Computer Science},\n", 165 | " publisher={Springer},\n", 166 | " doi={https://doi.org/10.1007/978-3-031-23179-7_1},\n", 167 | "}\n", 168 | "```" 169 | ] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": null, 174 | "metadata": {}, 175 | "outputs": [], 176 | "source": [] 177 | } 178 | ], 179 | "metadata": { 180 | "kernelspec": { 181 | "display_name": "python3", 182 | "language": "python", 183 | "name": "python3" 184 | }, 185 | "widgets": { 186 | "application/vnd.jupyter.widget-state+json": { 187 | "state": {}, 188 | "version_major": 2, 189 | "version_minor": 0 190 | } 191 | } 192 | }, 193 | "nbformat": 4, 194 | "nbformat_minor": 4 195 | } 196 | -------------------------------------------------------------------------------- /notebooks/nbdev.yml: -------------------------------------------------------------------------------- 1 | project: 2 | output-dir: _docs 3 | 4 | website: 5 | title: "diffpose" 6 | site-url: "https://vivekg.dev/DiffPose" 7 | description: "Patient-specific intraoperative 2D/3D registration via differentiable rendering" 8 | repo-branch: main 9 | repo-url: "https://github.com/eigenvivek/DiffPose" 10 | -------------------------------------------------------------------------------- /notebooks/sidebar.yml: -------------------------------------------------------------------------------- 1 | website: 2 | sidebar: 3 | contents: 4 | - index.ipynb 5 | - section: api 6 | contents: 7 | - api/00_deepfluoro.ipynb 8 | - api/01_ljubljana.ipynb 9 | - api/02_calibration.ipynb 10 | - api/03_registration.ipynb 11 | - api/04_metrics.ipynb 12 | - api/05_visualization.ipynb 13 | - api/06_jacobians.ipynb 14 | - section: experiments 15 | contents: 16 | - experiments/00_3D_visualization.ipynb 17 | - experiments/01_pose_recovery.ipynb 18 | - experiments/02_loss_landscapes.ipynb 19 | - experiments/03_sparse_rendering.ipynb 20 | - experiments/render.html 21 | -------------------------------------------------------------------------------- /notebooks/styles.css: -------------------------------------------------------------------------------- 1 | .cell { 2 | margin-bottom: 1rem; 3 | } 4 | 5 | .cell > .sourceCode { 6 | margin-bottom: 0; 7 | } 8 | 9 | .cell-output > pre { 10 | margin-bottom: 0; 11 | } 12 | 13 | .cell-output > pre, .cell-output > .sourceCode > pre, .cell-output-stdout > pre { 14 | margin-left: 0.8rem; 15 | margin-top: 0; 16 | background: 
none; 17 | border-left: 2px solid lightsalmon; 18 | border-top-left-radius: 0; 19 | border-top-right-radius: 0; 20 | } 21 | 22 | .cell-output > .sourceCode { 23 | border: none; 24 | } 25 | 26 | .cell-output > .sourceCode { 27 | background: none; 28 | margin-top: 0; 29 | } 30 | 31 | div.description { 32 | padding-left: 2px; 33 | padding-top: 5px; 34 | font-style: italic; 35 | font-size: 135%; 36 | opacity: 70%; 37 | } 38 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | # All sections below are required unless otherwise specified. 3 | # See https://github.com/fastai/nbdev/blob/master/settings.ini for examples. 4 | 5 | ### Python library ### 6 | repo = DiffPose 7 | lib_name = diffpose 8 | version = 0.0.1 9 | min_python = 3.7 10 | license = mit 11 | black_formatting = True 12 | 13 | ### nbdev ### 14 | doc_path = _docs 15 | lib_path = diffpose 16 | nbs_path = notebooks 17 | recursive = True 18 | tst_flags = notest 19 | put_version_in_init = True 20 | 21 | ### Docs ### 22 | branch = main 23 | custom_sidebar = False 24 | doc_host = https://vivekg.dev 25 | doc_baseurl = /DiffPose 26 | git_url = https://github.com/eigenvivek/DiffPose 27 | title = diffpose 28 | readme_nb = index.ipynb 29 | 30 | ### PyPI ### 31 | audience = Developers 32 | author = Vivek Gopalakrishnan 33 | author_email = vivekg@mit.edu 34 | copyright = 2023 onwards, Vivek Gopalakrishnan 35 | description = Patient-specific intraoperative 2D/3D registration via differentiable rendering 36 | keywords = nbdev jupyter notebook python 37 | language = English 38 | status = 3 39 | user = eigenvivek 40 | 41 | ### Optional ### 42 | requirements = diffdrr h5py scikit-image seaborn torch torchvision timm pytorch-transformers torchmetrics tqdm beartype jaxtyping 43 | dev_requirements = jupyterlab_code_formatter black flake8 isort nbdev ipykernel jupyter-server-proxy 44 | optional_requirements = submitit 45 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import parse_version 2 | from configparser import ConfigParser 3 | import setuptools, shlex 4 | assert parse_version(setuptools.__version__)>=parse_version('36.2') 5 | 6 | # note: all settings are in settings.ini; edit there, not here 7 | config = ConfigParser(delimiters=['=']) 8 | config.read('settings.ini', encoding='utf-8') 9 | cfg = config['DEFAULT'] 10 | 11 | cfg_keys = 'version description keywords author author_email'.split() 12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split() 13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o) 14 | setup_cfg = {o:cfg[o] for o in cfg_keys} 15 | 16 | licenses = { 17 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'), 18 | 'mit': ('MIT License', 'OSI Approved :: MIT License'), 19 | 'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'), 20 | 'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'), 21 | 'bsd3': ('BSD License', 'OSI Approved :: BSD License'), 22 | } 23 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha', 24 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ] 25 | py_versions = '3.6 3.7 3.8 3.9 3.10'.split() 26 | 27 |
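# The dependency lists below come straight from settings.ini: `requirements` is
# shlex-split (so quoted, multi-word specifiers survive) and feeds install_requires,
# while `dev_requirements` feeds the optional [dev] extra installed by
# `pip install -e ".[dev]"` in the CI workflows.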
requirements = shlex.split(cfg.get('requirements', '')) 28 | if cfg.get('pip_requirements'): requirements += shlex.split(cfg.get('pip_requirements', '')) 29 | min_python = cfg['min_python'] 30 | lic = licenses.get(cfg['license'].lower(), (cfg['license'], None)) 31 | dev_requirements = (cfg.get('dev_requirements') or '').split() 32 | 33 | setuptools.setup( 34 | name = cfg['lib_name'], 35 | license = lic[0], 36 | classifiers = [ 37 | 'Development Status :: ' + statuses[int(cfg['status'])], 38 | 'Intended Audience :: ' + cfg['audience'].title(), 39 | 'Natural Language :: ' + cfg['language'].title(), 40 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]] + (['License :: ' + lic[1] ] if lic[1] else []), 41 | url = cfg['git_url'], 42 | packages = setuptools.find_packages(), 43 | include_package_data = True, 44 | install_requires = requirements, 45 | extras_require={ 'dev': dev_requirements }, 46 | dependency_links = cfg.get('dep_links','').split(), 47 | python_requires = '>=' + cfg['min_python'], 48 | long_description = open('README.md', encoding='utf-8').read(), 49 | long_description_content_type = 'text/markdown', 50 | zip_safe = False, 51 | entry_points = { 52 | 'console_scripts': cfg.get('console_scripts','').split(), 53 | 'nbdev': [f'{cfg.get("lib_path")}={cfg.get("lib_path")}._modidx:d'] 54 | }, 55 | **setup_cfg) 56 | 57 | 58 | --------------------------------------------------------------------------------