├── .gitignore ├── LICENSE ├── README.md ├── images ├── l2r_m.png ├── method_overview.png ├── rankings.png ├── sc_graphic2-2.png ├── tasks.png └── test.txt ├── l2r_2020_convexAdam_CuRIOUS.py ├── l2r_2020_curious_landmarks.zip ├── l2r_2021_convexAdam_task1_docker.py ├── l2r_2021_convexAdam_task2_docker.py ├── l2r_2021_convexAdam_task3_docker.py ├── pyproject.toml ├── requirements.txt ├── requirements_dev.txt ├── self_configuring ├── adam_run_paired_mind_shiftSpline.py ├── adam_run_withconfig_shiftSpline.py ├── convexAdam_hyper_util.py ├── convex_adam_MIND.py ├── convex_adam_MIND_testset.py ├── convex_adam_nnUNet.py ├── convex_adam_nnUNet_testset.py ├── convex_run_paired_mind.py ├── convex_run_withconfig.py ├── infer_convexadam.py ├── l2r3.py ├── main_for_l2r3_MIND.py ├── main_for_l2r3_MIND_testset.py ├── main_for_l2r3_nnUNet.py └── main_for_l2r3_nnUNet_testset.py ├── setup.cfg ├── setup.py ├── src └── convexAdam │ ├── __init__.py │ ├── apply_convex.py │ ├── convex_adam_MIND.py │ ├── convex_adam_nnUNet.py │ ├── convex_adam_translation.py │ └── convex_adam_utils.py └── tests ├── Development-README.md ├── helper_functions.py ├── input └── 10000 │ ├── 10000_1000000_adc.mha │ ├── 10000_1000000_hbv.mha │ ├── 10000_1000000_prostate_seg.nii.gz │ └── 10000_1000000_t2w.mha ├── output-expected └── 10000 │ └── 10000_1000000_adc_warped.mha ├── output └── .gitkeep ├── test_convex_adam_mind.py ├── test_convex_adam_mind_aniso.py └── test_convex_adam_mind_translation.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | 
MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # VSCode stuff 118 | .vscode/ 119 | 120 | # Rope project settings 121 | .ropeproject 122 | 123 | # mkdocs documentation 124 | /site 125 | 126 | # mypy 127 | .mypy_cache/ 128 | .dmypy.json 129 | dmypy.json 130 | 131 | # Pyre type checker 132 | .pyre/ 133 | 134 | # macOS 135 | .DS_Store 136 | 137 | # custom 138 | tests/output/* 139 | !tests/output/.gitkeep 140 | tests/input/P-65591573/ 141 | tests/output-expected/P-65591573/ 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 
22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. 
You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. 
(Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # convexAdam 2 | 3 | News: 4 | * :zap: Our [Journal paper extension](https://ieeexplore.ieee.org/abstract/document/10681158) got accepted to IEEE TMI! :tada: 5 | * :zap: Easy installation with pip available! 6 | 7 | ## :zap: Fast and accurate optimisation for registration with little learning 8 | ![MethodOverview](images/method_overview.png?raw=true "Selfconfiguring") 9 | 10 | 11 | ## :star: ConvexAdam ranks first for the [Learn2Reg Challenge](https://learn2reg.grand-challenge.org/) Datasets! 
:star: 12 | 13 | ![MethodOverview](images/l2r_m.png?raw=true "Learn2RegResults") 14 | 15 | ## :floppy_disk: Installation 16 | 17 | You can run ConvexAdam out of the box with 18 | ``` 19 | pip install convexAdam 20 | ```` 21 | 22 | ## :bar_chart: Self-configuring hyperparameter optimisation 23 | 24 | ![ConceptOverview](images/sc_graphic2-2.png?raw=true "Selfconfiguring") 25 | 26 | To obtain an automatic estimate of the best choice of all various hyperparameter configurations, we propose a rank-based multi-metric two-stage search mechanism that leverages the fast dual optimisation employed in ConvexAdam to rapidly evaluate hundreds of settings. 27 | 28 | We consider two scenarios: with and without available automatic semantic segmentation features using a pre-trained nnUNet. In the latter case we employ the handcraft MIND-SSC feature descriptor. For the former all infered train/test segmentations for the Learn2Reg tasks can be obtained at https://cloud.imi.uni-luebeck.de/s/cgXJfjDZNNgKRZe 29 | 30 | Next we create a small config file for a new task that is similar to the Learn2Reg dataset.json and contains information on which training/validation pairs to use and how many (if any) labels are available for test/evaluation. 31 | 32 | The entire self-configuring hyperparameter optimisation can usually be run in 1 hour or less and comprises two scripts that are executed after another. 33 | 34 | ``convex_run_withconfig.py`` and ``adam_run_with_config.py`` 35 | 36 | Each will test various settings, run online validation on the training/validation data and create a small log of all obtained scores that are ranked across those individual settings using a simplified version of Learn2Reg's evaluation (normalised per metric ranking w/o statistical significance and a geometric mean across metrics). 
37 | 38 | Finally you can use infer_convexadam.py to apply the best parameter setting to the test data and refer to https://github.com/MDL-UzL/L2R/tree/main/evaluation for the official evaluation. 39 | 40 | ## :books: Citations 41 | 42 | If you find our work helpful, please cite: 43 | 44 | ``` 45 | %convexAdam + Hyperparameter Optimisation TMI 46 | @article{siebert2024convexadam, 47 | title={ConvexAdam: Self-Configuring Dual-Optimisation-Based 3D Multitask Medical Image Registration}, 48 | author={Siebert, Hanna and Gro{\ss}br{\"o}hmer, Christoph and Hansen, Lasse and Heinrich, Mattias P}, 49 | journal={IEEE Transactions on Medical Imaging}, 50 | year={2024}, 51 | publisher={IEEE} 52 | } 53 | % Original Learn2Reg2021 Submission 54 | @inproceedings{siebert2021fast, 55 | title={Fast 3D registration with accurate optimisation and little learning for Learn2Reg 2021}, 56 | author={Siebert, Hanna and Hansen, Lasse and Heinrich, Mattias P}, 57 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 58 | pages={174--179}, 59 | year={2021}, 60 | organization={Springer} 61 | } 62 | % Registration with Convex Optimisation 63 | @inproceedings{heinrich2014non, 64 | title={Non-parametric discrete registration with convex optimisation}, 65 | author={Heinrich, Mattias P and Papie{\.z}, Bartlomiej W and Schnabel, Julia A and Handels, Heinz}, 66 | booktitle={Biomedical Image Registration: 6th International Workshop, WBIR 2014, London, UK, July 7-8, 2014. 
Proceedings 6}, 67 | pages={51--61}, 68 | year={2014}, 69 | organization={Springer} 70 | } 71 | ``` 72 | -------------------------------------------------------------------------------- /images/l2r_m.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/images/l2r_m.png -------------------------------------------------------------------------------- /images/method_overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/images/method_overview.png -------------------------------------------------------------------------------- /images/rankings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/images/rankings.png -------------------------------------------------------------------------------- /images/sc_graphic2-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/images/sc_graphic2-2.png -------------------------------------------------------------------------------- /images/tasks.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/images/tasks.png -------------------------------------------------------------------------------- /images/test.txt: -------------------------------------------------------------------------------- 1 | test 2 | -------------------------------------------------------------------------------- /l2r_2020_curious_landmarks.zip: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/l2r_2020_curious_landmarks.zip -------------------------------------------------------------------------------- /l2r_2021_convexAdam_task2_docker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import nibabel as nib 3 | import struct 4 | import scipy.ndimage 5 | from scipy.ndimage import zoom, map_coordinates 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | print(torch.__version__) 12 | import sys 13 | import time 14 | from scipy.ndimage import distance_transform_edt as edt 15 | 16 | 17 | 18 | sys.path.append('voxelmorph/pytorch/') 19 | import losses 20 | print(losses.mind_loss) 21 | 22 | def gpu_usage(): 23 | print('gpu usage (current/max): {:.2f} / {:.2f} GB'.format(torch.cuda.memory_allocated()*1e-9, torch.cuda.max_memory_allocated()*1e-9)) 24 | 25 | 26 | 27 | H = 192 28 | W = 192 29 | D = 208 30 | 31 | A = torch.ones(32,32).cuda() 32 | A.requires_grad = True 33 | A.sum().backward() 34 | 35 | def load_case(nu): 36 | fixed = torch.from_numpy(nib.load('/data_supergrover2/heinrich/L2R_Lung/scans/case_0'+str(nu).zfill(2)+'_exp.nii.gz').get_fdata()).float() 37 | moving = torch.from_numpy(nib.load('/data_supergrover2/heinrich/L2R_Lung/scans/case_0'+str(nu).zfill(2)+'_insp.nii.gz').get_fdata()).float() 38 | 39 | 40 | fixed_mask = torch.from_numpy(nib.load('/data_supergrover2/heinrich/L2R_Lung/lungMasks/case_0'+str(nu).zfill(2)+'_exp.nii.gz').get_fdata()).float() 41 | moving_mask = torch.from_numpy(nib.load('/data_supergrover2/heinrich/L2R_Lung/lungMasks/case_0'+str(nu).zfill(2)+'_insp.nii.gz').get_fdata()).float() 42 | return fixed,moving,fixed_mask,moving_mask 43 | 44 | #correlation layer: dense discretised displacements to compute SSD cost volume with box-filter 45 | def 
correlate(mind_fix,mind_mov,disp_hw,grid_sp): 46 | torch.cuda.synchronize() 47 | t0 = time.time() 48 | with torch.no_grad(): 49 | mind_unfold = F.unfold(F.pad(mind_mov,(disp_hw,disp_hw,disp_hw,disp_hw,disp_hw,disp_hw)).squeeze(0),disp_hw*2+1) 50 | mind_unfold = mind_unfold.view(12,-1,(disp_hw*2+1)**2,W//grid_sp,D//grid_sp) 51 | 52 | 53 | ssd = torch.zeros((disp_hw*2+1)**3,H//grid_sp,W//grid_sp,D//grid_sp,dtype=mind_fix.dtype, device=mind_fix.device)#.cuda().half() 54 | ssd_argmin = torch.zeros(H//grid_sp,W//grid_sp,D//grid_sp).long() 55 | with torch.no_grad(): 56 | for i in range(disp_hw*2+1): 57 | mind_sum = (mind_fix.permute(1,2,0,3,4)-mind_unfold[:,i:i+H//grid_sp]).pow(2).sum(0,keepdim=True) 58 | #5,stride=1,padding=2 59 | #3,stride=1,padding=1 60 | ssd[i::(disp_hw*2+1)] = F.avg_pool3d(mind_sum.transpose(2,1),3,stride=1,padding=1).squeeze(1) 61 | ssd = ssd.view(disp_hw*2+1,disp_hw*2+1,disp_hw*2+1,H//grid_sp,W//grid_sp,D//grid_sp).transpose(1,0).reshape((disp_hw*2+1)**3,H//grid_sp,W//grid_sp,D//grid_sp) 62 | ssd_argmin = torch.argmin(ssd,0)# 63 | #ssd = F.softmax(-ssd*1000,0) 64 | torch.cuda.synchronize() 65 | 66 | t1 = time.time() 67 | print(t1-t0,'sec (ssd)') 68 | gpu_usage() 69 | return ssd,ssd_argmin 70 | 71 | #solve two coupled convex optimisation problems for efficient global regularisation 72 | def coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp): 73 | disp_soft = F.avg_pool3d(disp_mesh_t.view(3,-1)[:,ssd_argmin.view(-1)].reshape(1,3,H//grid_sp,W//grid_sp,D//grid_sp),3,padding=1,stride=1) 74 | 75 | 76 | coeffs = torch.tensor([0.003,0.01,0.03,0.1,0.3,1]) 77 | for j in range(6): 78 | ssd_coupled_argmin = torch.zeros_like(ssd_argmin) 79 | with torch.no_grad(): 80 | for i in range(H//grid_sp): 81 | 82 | coupled = ssd[:,i,:,:]+coeffs[j]*(disp_mesh_t-disp_soft[:,:,i].view(3,1,-1)).pow(2).sum(0).view(-1,W//grid_sp,D//grid_sp) 83 | ssd_coupled_argmin[i] = torch.argmin(coupled,0) 84 | #print(coupled.shape) 85 | 86 | disp_soft = 
F.avg_pool3d(disp_mesh_t.view(3,-1)[:,ssd_coupled_argmin.view(-1)].reshape(1,3,H//grid_sp,W//grid_sp,D//grid_sp),3,padding=1,stride=1) 87 | 88 | return disp_soft 89 | 90 | #enforce inverse consistency of forward and backward transform 91 | def inverse_consistency(disp_field1s,disp_field2s,iter=20): 92 | #factor = 1 93 | B,C,H,W,D = disp_field1s.size() 94 | #make inverse consistent 95 | with torch.no_grad(): 96 | disp_field1i = disp_field1s.clone() 97 | disp_field2i = disp_field2s.clone() 98 | 99 | identity = F.affine_grid(torch.eye(3,4).unsqueeze(0),(1,1,H,W,D)).permute(0,4,1,2,3).to(disp_field1s.device).to(disp_field1s.dtype) 100 | for i in range(iter): 101 | disp_field1s = disp_field1i.clone() 102 | disp_field2s = disp_field2i.clone() 103 | 104 | disp_field1i = 0.5*(disp_field1s-F.grid_sample(disp_field2s,(identity+disp_field1s).permute(0,2,3,4,1))) 105 | disp_field2i = 0.5*(disp_field2s-F.grid_sample(disp_field1s,(identity+disp_field2s).permute(0,2,3,4,1))) 106 | 107 | return disp_field1i,disp_field2i 108 | 109 | def combineDeformation3d(disp_1st,disp_2nd,identity): 110 | disp_composition = disp_2nd + F.grid_sample(disp_1st,disp_2nd.permute(0,2,3,4,1)+identity) 111 | return disp_composition 112 | 113 | def kpts_pt(kpts_world, shape): 114 | device = kpts_world.device 115 | H, W, D = shape 116 | return (kpts_world.flip(-1) / (torch.tensor([D, W, H]).to(device) - 1)) * 2 - 1 117 | 118 | def kpts_world(kpts_pt, shape): 119 | device = kpts_pt.device 120 | H, W, D = shape 121 | return ((kpts_pt.flip(-1) + 1) / 2) * (torch.tensor([H, W, D]).to(device) - 1) 122 | 123 | import math 124 | import torch 125 | import torch.nn.functional as F 126 | 127 | class TPS: 128 | @staticmethod 129 | def fit(c, f, lambd=0.): 130 | device = c.device 131 | 132 | n = c.shape[0] 133 | f_dim = f.shape[1] 134 | 135 | U = TPS.u(TPS.d(c, c)) 136 | K = U + torch.eye(n, device=device) * lambd 137 | 138 | P = torch.ones((n, 4), device=device) 139 | P[:, 1:] = c 140 | 141 | v = torch.zeros((n+4, 
f_dim), device=device) 142 | v[:n, :] = f 143 | 144 | A = torch.zeros((n+4, n+4), device=device) 145 | A[:n, :n] = K 146 | A[:n, -4:] = P 147 | A[-4:, :n] = P.t() 148 | 149 | theta = torch.solve(v, A)[0] 150 | return theta 151 | 152 | @staticmethod 153 | def d(a, b): 154 | ra = (a**2).sum(dim=1).view(-1, 1) 155 | rb = (b**2).sum(dim=1).view(1, -1) 156 | dist = ra + rb - 2.0 * torch.mm(a, b.permute(1, 0)) 157 | dist.clamp_(0.0, float('inf')) 158 | return torch.sqrt(dist) 159 | 160 | @staticmethod 161 | def u(r): 162 | return (r**2) * torch.log(r + 1e-6) 163 | 164 | @staticmethod 165 | def z(x, c, theta): 166 | U = TPS.u(TPS.d(x, c)) 167 | w, a = theta[:-4], theta[-4:].unsqueeze(2) 168 | b = torch.matmul(U, w) 169 | return (a[0] + a[1] * x[:, 0] + a[2] * x[:, 1] + a[3] * x[:, 2] + b.t()).t() 170 | 171 | def thin_plate_dense(x1, y1, shape, step, lambd=.0, unroll_step_size=2**12): 172 | device = x1.device 173 | D, H, W = shape 174 | D1, H1, W1 = D//step, H//step, W//step 175 | 176 | x2 = F.affine_grid(torch.eye(3, 4, device=device).unsqueeze(0), (1, 1, D1, H1, W1), align_corners=True).view(-1, 3) 177 | tps = TPS() 178 | theta = tps.fit(x1[0], y1[0], lambd) 179 | 180 | y2 = torch.zeros((1, D1 * H1 * W1, 3), device=device) 181 | N = D1*H1*W1 182 | n = math.ceil(N/unroll_step_size) 183 | for j in range(n): 184 | j1 = j * unroll_step_size 185 | j2 = min((j + 1) * unroll_step_size, N) 186 | y2[0, j1:j2, :] = tps.z(x2[j1:j2], x1[0], theta) 187 | 188 | y2 = y2.view(1, D1, H1, W1, 3).permute(0, 4, 1, 2, 3) 189 | y2 = F.interpolate(y2, (D, H, W), mode='trilinear', align_corners=True).permute(0, 2, 3, 4, 1) 190 | 191 | return y2 192 | 193 | 194 | time_all = torch.zeros(10) 195 | 196 | 197 | A = torch.ones(128,128).cuda() 198 | A.requires_grad = True 199 | A.sum().backward() 200 | torch.cuda.synchronize() 201 | 202 | 203 | for nu in range(21,31): 204 | torch.cuda.synchronize() 205 | t0 = time.time() 206 | fixed,moving,fixed_mask,moving_mask = load_case(nu) 207 | 208 | 
torch.cuda.synchronize() 209 | t0a = time.time() 210 | grid_sp = 4#4 211 | disp_hw = 6#6 212 | 213 | 214 | #replicate masking!!! 215 | avg3 = nn.Sequential(nn.ReplicationPad3d(1),nn.AvgPool3d(3,stride=1)) 216 | avg3.cuda() 217 | mask = (avg3(fixed_mask.view(1,1,H,W,D).cuda())>0.9).float() 218 | dist,idx = edt((mask[0,0,::2,::2,::2]==0).squeeze().cpu().numpy(),return_indices=True) 219 | fixed_r = F.interpolate((fixed[::2,::2,::2].cuda().reshape(-1)[idx[0]*104*96+idx[1]*104+idx[2]]).unsqueeze(0).unsqueeze(0),scale_factor=2,mode='trilinear') 220 | fixed_r.view(-1)[mask.view(-1)!=0] = fixed.cuda().reshape(-1)[mask.view(-1)!=0] 221 | #fixed_r = fixed.cuda().view(1,1,H,W,D)*mask 222 | mask = (avg3(moving_mask.view(1,1,H,W,D).cuda())>0.9).float() 223 | dist,idx = edt((mask[0,0,::2,::2,::2]==0).squeeze().cpu().numpy(),return_indices=True) 224 | moving_r = F.interpolate((moving[::2,::2,::2].cuda().reshape(-1)[idx[0]*104*96+idx[1]*104+idx[2]]).unsqueeze(0).unsqueeze(0),scale_factor=2,mode='trilinear') 225 | moving_r.view(-1)[mask.view(-1)!=0] = moving.cuda().reshape(-1)[mask.view(-1)!=0] 226 | #moving_r = moving.cuda().view(1,1,H,W,D)*mask 227 | 228 | 229 | #compute MIND descriptors and downsample (using average pooling) 230 | with torch.no_grad(): 231 | mindssc_fix = losses.MINDSSC(fixed_r,1,2).half()#*fixed_mask.cuda().half()#.cpu() 232 | mindssc_mov = losses.MINDSSC(moving_r,1,2).half()#*moving_mask.cuda().half()#.cpu() 233 | 234 | mind_fix = F.avg_pool3d(mindssc_fix,grid_sp,stride=grid_sp) 235 | mind_mov = F.avg_pool3d(mindssc_mov,grid_sp,stride=grid_sp) 236 | 237 | 238 | ssd,ssd_argmin = correlate(mind_fix,mind_mov,disp_hw,grid_sp) 239 | disp_mesh_t = F.affine_grid(disp_hw*torch.eye(3,4).cuda().half().unsqueeze(0),(1,1,disp_hw*2+1,disp_hw*2+1,disp_hw*2+1),align_corners=True).permute(0,4,1,2,3).reshape(3,-1,1) 240 | 241 | disp_soft = coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp) 242 | 243 | #disp_soft = 
F.avg_pool3d(disp_mesh_t.view(3,-1)[:,ssd_argmin.view(-1)].reshape(1,3,H//grid_sp,W//grid_sp,D//grid_sp),3,padding=1,stride=1) 244 | 245 | #ssd_,ssd_argmin_ = correlate(mind_mov,mind_fix,disp_hw,grid_sp) 246 | #disp_soft_ = coupled_convex(ssd_,ssd_argmin_,disp_mesh_t,grid_sp) 247 | scale = torch.tensor([H//grid_sp-1,W//grid_sp-1,D//grid_sp-1]).view(1,3,1,1,1).cuda().half()/2 248 | #disp_ice,_ = inverse_consistency((disp_soft/scale).flip(1),(disp_soft_/scale).flip(1),iter=10) 249 | 250 | 251 | disp_hr = F.interpolate(disp_soft*grid_sp,size=(H,W,D),mode='trilinear',align_corners=False) 252 | 253 | 254 | 255 | grid_sp = 2 256 | 257 | with torch.no_grad(): 258 | mind_fix = F.avg_pool3d(mindssc_fix,grid_sp,stride=grid_sp) 259 | mind_mov = F.avg_pool3d(mindssc_mov,grid_sp,stride=grid_sp) 260 | 261 | 262 | #patch_mind_fix = nn.Flatten(5,)(F.pad(mind_fix,(1,1,1,1,1,1)).unfold(2,3,1).unfold(3,3,1).unfold(4,3,1)).permute(0,1,5,2,3,4).reshape(1,-1,H//grid_sp,W//grid_sp,D//grid_sp) 263 | #patch_mind_mov = nn.Flatten(5,)(F.pad(mind_mov,(1,1,1,1,1,1)).unfold(2,3,1).unfold(3,3,1).unfold(4,3,1)).permute(0,1,5,2,3,4).reshape(1,-1,H//grid_sp,W//grid_sp,D//grid_sp) 264 | 265 | 266 | #create optimisable displacement grid 267 | grid_sp = 2 268 | disp_lr = F.interpolate(disp_hr,size=(H//grid_sp,W//grid_sp,D//grid_sp),mode='trilinear',align_corners=False) 269 | 270 | 271 | net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp,W//grid_sp,D//grid_sp),bias=False)) 272 | net[0].weight.data[:] = disp_lr.float().cpu().data/grid_sp 273 | net.cuda() 274 | optimizer = torch.optim.Adam(net.parameters(), lr=1) 275 | #torch.cuda.synchronize() 276 | #t0 = time.time() 277 | grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H//grid_sp,W//grid_sp,D//grid_sp),align_corners=False) 278 | 279 | #run Adam optimisation with diffusion regularisation and B-spline smoothing 280 | lambda_weight = .65# with tps: .5, without:0.7 281 | for iter in range(50):#80 282 | optimizer.zero_grad() 283 | 284 | disp_sample 
= F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(net[0].weight,3,stride=1,padding=1),3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1) 285 | reg_loss = lambda_weight*((disp_sample[0,:,1:,:]-disp_sample[0,:,:-1,:])**2).mean()+\ 286 | lambda_weight*((disp_sample[0,1:,:,:]-disp_sample[0,:-1,:,:])**2).mean()+\ 287 | lambda_weight*((disp_sample[0,:,:,1:]-disp_sample[0,:,:,:-1])**2).mean() 288 | 289 | #grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/torch.tensor([63/2,63/2,68/2]).unsqueeze(0).cuda()).flip(1) 290 | 291 | scale = torch.tensor([(H//grid_sp-1)/2,(W//grid_sp-1)/2,(D//grid_sp-1)/2]).cuda().unsqueeze(0) 292 | grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/scale).flip(1).float() 293 | 294 | patch_mov_sampled = F.grid_sample(mind_mov.float(),grid_disp.view(1,H//grid_sp,W//grid_sp,D//grid_sp,3).cuda(),align_corners=False,mode='bilinear')#,padding_mode='border') 295 | #patch_mov_sampled_sq = F.grid_sample(mind_mov.pow(2).float(),grid_disp.view(1,H//grid_sp,W//grid_sp,D//grid_sp,3).cuda(),align_corners=True,mode='bilinear') 296 | #sampled_cost = (patch_mov_sampled_sq-2*patch_mov_sampled*mind_fix+mind_fix.pow(2)).mean(1)*12 297 | 298 | sampled_cost = (patch_mov_sampled-mind_fix).pow(2).mean(1)*12 299 | #sampled_cost = F.grid_sample(ssd2.view(-1,1,17,17,17).float(),disp_sample.view(-1,1,1,1,3)/disp_hw,align_corners=True,padding_mode='border') 300 | loss = sampled_cost.mean() 301 | (loss+reg_loss).backward() 302 | optimizer.step() 303 | 304 | fitted_grid = disp_sample.permute(0,4,1,2,3).detach() 305 | disp_hr = F.interpolate(fitted_grid*grid_sp,size=(H,W,D),mode='trilinear',align_corners=False) 306 | disp_smooth = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr,3,padding=1,stride=1),3,padding=1,stride=1),3,padding=1,stride=1) 307 | 308 | disp_field = F.interpolate(disp_smooth,scale_factor = 0.5,mode='trilinear',align_corners=False) 309 | 310 | 311 | torch.cuda.synchronize() 312 | t0b = time.time() 313 | x1 = 
disp_field[0,0,:,:,:].cpu().float().data.numpy() 314 | y1 = disp_field[0,1,:,:,:].cpu().float().data.numpy() 315 | z1 = disp_field[0,2,:,:,:].cpu().float().data.numpy() 316 | 317 | #x1 = zoom(x,1/2,order=2).astype('float16') 318 | #y1 = zoom(y,1/2,order=2).astype('float16') 319 | #z1 = zoom(z,1/2,order=2).astype('float16') 320 | 321 | 322 | np.savez_compressed('/data_supergrover2/heinrich/L2R2021/convexAdam/submission/task_02/disp_'+str(nu).zfill(4)+'_'+str(nu).zfill(4)+'.npz',np.stack((x1,y1,z1),0)) 323 | 324 | torch.cuda.synchronize() 325 | #t1 = time.time() 326 | #print(t1-t0,'sec (optim)') 327 | t1 = time.time() 328 | print(t1-t0,'time all sec',t0a-t0+t1-t0b,'read/write') 329 | time_all[nu-21] = t1-t0 330 | 331 | print('time all',time_all.mean()) 332 | torch.save(time_all,'/data_supergrover2/heinrich/L2R2021/convexAdam/task2_times.pth') 333 | -------------------------------------------------------------------------------- /l2r_2021_convexAdam_task3_docker.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import nibabel as nib 3 | import struct 4 | import scipy.ndimage 5 | from scipy.ndimage import zoom, map_coordinates 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | 11 | print(torch.__version__) 12 | import sys 13 | import time 14 | 15 | 16 | sys.path.append('voxelmorph/pytorch/') 17 | import losses 18 | print(losses.mind_loss) 19 | 20 | def gpu_usage(): 21 | print('gpu usage (current/max): {:.2f} / {:.2f} GB'.format(torch.cuda.memory_allocated()*1e-9, torch.cuda.max_memory_allocated()*1e-9)) 22 | 23 | 24 | H = 160 25 | W = 192 26 | D = 224 27 | 28 | def dice_coeff(outputs, labels, max_label): 29 | dice = torch.FloatTensor(max_label-1).fill_(0) 30 | for label_num in range(1, max_label): 31 | iflat = (outputs==label_num).view(-1).float() 32 | tflat = (labels==label_num).view(-1).float() 33 | intersection = torch.mean(iflat * tflat) 34 | dice[label_num-1] = (2. 
* intersection) / (1e-8 + torch.mean(iflat) + torch.mean(tflat)) 35 | return dice 36 | 37 | 38 | identity = np.stack(np.meshgrid(np.arange(H), np.arange(W), np.arange(D), indexing='ij')) 39 | #print(fixed.shape) 40 | #correlation layer: dense discretised displacements to compute SSD cost volume with box-filter 41 | def correlate(mind_fix,mind_mov,disp_hw,grid_sp): 42 | torch.cuda.synchronize() 43 | C_mind = mind_fix.shape[1] 44 | t0 = time.time() 45 | with torch.no_grad(): 46 | mind_unfold = F.unfold(F.pad(mind_mov,(disp_hw,disp_hw,disp_hw,disp_hw,disp_hw,disp_hw)).squeeze(0),disp_hw*2+1) 47 | mind_unfold = mind_unfold.view(C_mind,-1,(disp_hw*2+1)**2,W//grid_sp,D//grid_sp) 48 | 49 | 50 | ssd = torch.zeros((disp_hw*2+1)**3,H//grid_sp,W//grid_sp,D//grid_sp,dtype=mind_fix.dtype, device=mind_fix.device)#.cuda().half() 51 | ssd_argmin = torch.zeros(H//grid_sp,W//grid_sp,D//grid_sp).long() 52 | with torch.no_grad(): 53 | for i in range(disp_hw*2+1): 54 | mind_sum = (mind_fix.permute(1,2,0,3,4)-mind_unfold[:,i:i+H//grid_sp]).abs().sum(0,keepdim=True) 55 | 56 | ssd[i::(disp_hw*2+1)] = F.avg_pool3d(mind_sum.transpose(2,1),3,stride=1,padding=1).squeeze(1) 57 | ssd = ssd.view(disp_hw*2+1,disp_hw*2+1,disp_hw*2+1,H//grid_sp,W//grid_sp,D//grid_sp).transpose(1,0).reshape((disp_hw*2+1)**3,H//grid_sp,W//grid_sp,D//grid_sp) 58 | ssd_argmin = torch.argmin(ssd,0)# 59 | #ssd = F.softmax(-ssd*1000,0) 60 | torch.cuda.synchronize() 61 | 62 | t1 = time.time() 63 | print(t1-t0,'sec (ssd)') 64 | gpu_usage() 65 | return ssd,ssd_argmin 66 | 67 | #solve two coupled convex optimisation problems for efficient global regularisation 68 | def coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp): 69 | disp_soft = F.avg_pool3d(disp_mesh_t.view(3,-1)[:,ssd_argmin.view(-1)].reshape(1,3,H//grid_sp,W//grid_sp,D//grid_sp),3,padding=1,stride=1) 70 | 71 | 72 | coeffs = torch.tensor([0.003,0.01,0.03,0.1,0.3,1]) 73 | for j in range(6): 74 | ssd_coupled_argmin = torch.zeros_like(ssd_argmin) 75 | with 
torch.no_grad(): 76 | for i in range(H//grid_sp): 77 | 78 | coupled = ssd[:,i,:,:]+coeffs[j]*(disp_mesh_t-disp_soft[:,:,i].view(3,1,-1)).pow(2).sum(0).view(-1,W//grid_sp,D//grid_sp) 79 | ssd_coupled_argmin[i] = torch.argmin(coupled,0) 80 | #print(coupled.shape) 81 | 82 | disp_soft = F.avg_pool3d(disp_mesh_t.view(3,-1)[:,ssd_coupled_argmin.view(-1)].reshape(1,3,H//grid_sp,W//grid_sp,D//grid_sp),3,padding=1,stride=1) 83 | 84 | return disp_soft 85 | 86 | #enforce inverse consistency of forward and backward transform 87 | def inverse_consistency(disp_field1s,disp_field2s,iter=20): 88 | #factor = 1 89 | B,C,H,W,D = disp_field1s.size() 90 | #make inverse consistent 91 | with torch.no_grad(): 92 | disp_field1i = disp_field1s.clone() 93 | disp_field2i = disp_field2s.clone() 94 | 95 | identity = F.affine_grid(torch.eye(3,4).unsqueeze(0),(1,1,H,W,D)).permute(0,4,1,2,3).to(disp_field1s.device).to(disp_field1s.dtype) 96 | for i in range(iter): 97 | disp_field1s = disp_field1i.clone() 98 | disp_field2s = disp_field2i.clone() 99 | 100 | disp_field1i = 0.5*(disp_field1s-F.grid_sample(disp_field2s,(identity+disp_field1s).permute(0,2,3,4,1))) 101 | disp_field2i = 0.5*(disp_field2s-F.grid_sample(disp_field1s,(identity+disp_field2s).permute(0,2,3,4,1))) 102 | 103 | return disp_field1i,disp_field2i 104 | 105 | def combineDeformation3d(disp_1st,disp_2nd,identity): 106 | disp_composition = disp_2nd + F.grid_sample(disp_1st,disp_2nd.permute(0,2,3,4,1)+identity) 107 | return disp_composition 108 | 109 | grid_sp = 2 110 | disp_hw = 3 111 | 112 | nu = 438 113 | 114 | fixed = torch.from_numpy(nib.load('L2R2021/Task3/skull_stripped/nnunet/img0'+str(nu)+'.nii.gz').get_fdata()).float() 115 | moving = torch.from_numpy(nib.load('L2R2021/Task3/skull_stripped/nnunet/img0'+str(nu+1)+'.nii.gz').get_fdata()).float() 116 | 117 | 118 | weight = 1/(torch.bincount(fixed.long().reshape(-1))+torch.bincount(moving.long().reshape(-1))).float().pow(.3) 119 | weight /= weight.mean() 120 | print(weight) 121 | 
case_time = torch.zeros(38) 122 | torch.cuda.synchronize() 123 | t0_ = time.time() 124 | for nu in range(1,39): 125 | torch.cuda.synchronize() 126 | t0 = time.time() 127 | fixed = torch.from_numpy(nib.load('/data_supergrover2/heinrich/nnUNet_predict/L2R_2021_Task3_test/img'+str(nu).zfill(4)+'_0000.nii.gz').get_fdata()).float() 128 | # mindssc_fix_ = 10*(F.one_hot(fixed.cuda().view(1,H,W,D).long()).float().permute(0,4,1,2,3).contiguous()*weight.view(1,-1,1,1,1).cuda()).half() 129 | #RuntimeError: The size of tensor a (2) must match the size of tensor b (36) at non-singleton dimension 1 130 | 131 | moving = torch.from_numpy(nib.load('/data_supergrover2/heinrich/nnUNet_predict/L2R_2021_Task3_test/img'+str(nu+1).zfill(4)+'_0000.nii.gz').get_fdata()).float() 132 | 133 | #/share/data_supergrover1/hansen/temp/nnUNet/nnUNet_results/nnUNet/3d_fullres/Task509_OASIS 134 | #OASIS001_img.nii.gz 135 | #/data_supergrover2/heinrich/nnUNet_predict/L2R_Task3_Test/nnunet/img 136 | fixed_seg = torch.from_numpy(nib.load('L2R2021/Task3/Lasse_nnUNet/img'+str(nu).zfill(4)+'.nii.gz').get_fdata()).float() 137 | moving_seg = torch.from_numpy(nib.load('L2R2021/Task3/Lasse_nnUNet/img'+str(nu+1).zfill(4)+'.nii.gz').get_fdata()).float() 138 | #moving_seg = torch.from_numpy(nib.load('/data_supergrover2/heinrich/nnUNet_predict/L2R_Task3_Test/nnunet/img'+str(nu+1).zfill(4)+'.nii.gz').get_fdata()).float() 139 | with torch.no_grad(): 140 | mindssc_fix_ = 10*(F.one_hot(fixed_seg.cuda().view(1,H,W,D).long()).float().permute(0,4,1,2,3).contiguous()*weight.view(1,-1,1,1,1).cuda()).half() 141 | mindssc_mov_ = 10*(F.one_hot(moving_seg.cuda().view(1,H,W,D).long()).float().permute(0,4,1,2,3).contiguous()*weight.view(1,-1,1,1,1).cuda()).half() 142 | mind_fix_ = F.avg_pool3d(mindssc_fix_,grid_sp,stride=grid_sp) 143 | mind_mov_ = F.avg_pool3d(mindssc_mov_,grid_sp,stride=grid_sp) 144 | ssd,ssd_argmin = correlate(mind_fix_,mind_mov_,disp_hw,grid_sp) 145 | disp_mesh_t = 
F.affine_grid(disp_hw*torch.eye(3,4).cuda().half().unsqueeze(0),(1,1,disp_hw*2+1,disp_hw*2+1,disp_hw*2+1),align_corners=True).permute(0,4,1,2,3).reshape(3,-1,1) 146 | disp_soft = coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp) 147 | 148 | del ssd,mind_fix_,mind_mov_ 149 | 150 | 151 | print(disp_soft.shape) 152 | #del ssd 153 | #del disp_mesh_t 154 | torch.cuda.empty_cache() 155 | gpu_usage() 156 | 157 | 158 | 159 | disp_lr = F.interpolate(disp_soft*grid_sp,size=(H//2,W//2,D//2),mode='trilinear',align_corners=False) 160 | #disp_soft*grid_sp/2# 161 | 162 | 163 | 164 | grid_sp = 2 165 | 166 | 167 | #extract one-hot patches 168 | torch.cuda.synchronize() 169 | t0 = time.time() 170 | 171 | with torch.no_grad(): 172 | mind_fix_ = F.avg_pool3d(mindssc_fix_,grid_sp,stride=grid_sp) 173 | mind_mov_ = F.avg_pool3d(mindssc_mov_,grid_sp,stride=grid_sp) 174 | del mindssc_fix_,mindssc_mov_ 175 | 176 | 177 | #extract one-hot patches 178 | 179 | #create optimisable displacement grid 180 | net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp,W//grid_sp,D//grid_sp),bias=False)) 181 | net[0].weight.data[:] = disp_lr/grid_sp 182 | net.cuda() 183 | optimizer = torch.optim.Adam(net.parameters(), lr=1) 184 | grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H//grid_sp,W//grid_sp,D//grid_sp),align_corners=False) 185 | 186 | #run Adam optimisation with diffusion regularisation and B-spline smoothing 187 | lambda_weight = 1.25# sad: 10, ssd:0.75 188 | for iter in range(100): 189 | optimizer.zero_grad() 190 | 191 | disp_sample = F.avg_pool3d(F.avg_pool3d(net[0].weight,3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1) 192 | reg_loss = lambda_weight*((disp_sample[0,:,1:,:]-disp_sample[0,:,:-1,:])**2).mean()+\ 193 | lambda_weight*((disp_sample[0,1:,:,:]-disp_sample[0,:-1,:,:])**2).mean()+\ 194 | lambda_weight*((disp_sample[0,:,:,1:]-disp_sample[0,:,:,:-1])**2).mean() 195 | 196 | scale = torch.tensor([(H//grid_sp-1)/2,(W//grid_sp-1)/2,(D//grid_sp-1)/2]).cuda().unsqueeze(0) 
197 | grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/scale).flip(1).float() 198 | 199 | patch_mov_sampled = F.grid_sample(mind_mov_.float(),grid_disp.view(1,H//grid_sp,W//grid_sp,D//grid_sp,3).cuda(),align_corners=False,mode='bilinear')#,padding_mode='border') 200 | #patch_mov_sampled_sq = F.grid_sample(mind_mov_.pow(2).float(),grid_disp.view(1,H//grid_sp,W//grid_sp,D//grid_sp,3).cuda(),align_corners=True,mode='bilinear') 201 | #sampled_cost = (patch_mov_sampled_sq-2*patch_mov_sampled*mind_fix_+mind_fix_.pow(2)).mean(1)*12 202 | sampled_cost = (patch_mov_sampled-mind_fix_).pow(2).mean(1)*12 203 | 204 | 205 | loss = sampled_cost.mean() 206 | (loss+reg_loss).backward() 207 | optimizer.step() 208 | torch.cuda.synchronize() 209 | t1 = time.time() 210 | print(t1-t0,'sec (optim)') 211 | 212 | fitted_grid = disp_sample.permute(0,4,1,2,3).detach() 213 | disp_hr = F.interpolate(fitted_grid*grid_sp,size=(H,W,D),mode='trilinear',align_corners=False) 214 | disp_field = F.interpolate(disp_hr,scale_factor = 0.5,mode='trilinear',align_corners=False) 215 | x1 = disp_field[0,0,:,:,:].cpu().float().data.numpy() 216 | y1 = disp_field[0,1,:,:,:].cpu().float().data.numpy() 217 | z1 = disp_field[0,2,:,:,:].cpu().float().data.numpy() 218 | 219 | #x1 = zoom(x,1/2,order=2).astype('float16') 220 | #y1 = zoom(y,1/2,order=2).astype('float16') 221 | #z1 = zoom(z,1/2,order=2).astype('float16') 222 | 223 | 224 | np.savez_compressed('/data_supergrover2/heinrich/L2R2021/convexAdam/submission/task_03/disp_'+str(nu).zfill(4)+'_'+str(nu+1).zfill(4)+'.npz',np.stack((x1,y1,z1),0)) 225 | 226 | torch.cuda.synchronize() 227 | t1 = time.time() 228 | case_time[nu-1] = t1-t0 229 | 230 | torch.cuda.synchronize() 231 | t1_ = time.time() 232 | print('total time',t1_-t0_) 233 | torch.save(case_time,'/data_supergrover2/heinrich/L2R2021/convexAdam/task3_times.pth') 234 | -------------------------------------------------------------------------------- /pyproject.toml: 
-------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=42.0", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.pytest.ini_options] 6 | addopts = "--cov=convexAdam" 7 | testpaths = [ 8 | "tests", 9 | ] 10 | 11 | [tool.mypy] 12 | mypy_path = "src" 13 | check_untyped_defs = true 14 | disallow_any_generics = true 15 | ignore_missing_imports = true 16 | no_implicit_optional = true 17 | show_error_codes = true 18 | strict_equality = true 19 | warn_redundant_casts = true 20 | warn_return_any = true 21 | warn_unreachable = true 22 | warn_unused_configs = true 23 | no_implicit_reexport = true 24 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | nibabel 2 | numpy 3 | scikit-learn 4 | SimpleITK 5 | torch 6 | -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | flake8==3.9.2 2 | tox==3.24.3 3 | pytest==6.2.5 4 | pytest-cov==2.12.1 5 | mypy===0.910 6 | build 7 | twine 8 | -------------------------------------------------------------------------------- /self_configuring/adam_run_paired_mind_shiftSpline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import time 4 | import warnings 5 | 6 | import nibabel as nib 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | 12 | warnings.filterwarnings("ignore") 13 | import os 14 | 15 | from convexAdam_hyper_util import (GaussianSmoothing, correlate, 16 | coupled_convex, extract_features, gpu_usage, 17 | inverse_consistency, 18 | jacobian_determinant_3d, kovesi_spline, 19 | sort_rank) 20 | from tqdm.auto import tqdm, trange 21 | 22 | 23 | def 
get_data_train(topk,HWD,f_img,f_key,f_mask): 24 | l2r_base_folder = './' 25 | H,W,D = HWD[0],HWD[1],HWD[2] 26 | 27 | imgs_fixed = [] 28 | keypts_fixed = [] 29 | masks_fixed = [] 30 | imgs_moving = [] 31 | keypts_moving = [] 32 | masks_moving = [] 33 | 34 | for i in tqdm(topk): 35 | file_img = f_img.replace('xxxx',str(i).zfill(4)) 36 | file_key = f_key.replace('xxxx',str(i).zfill(4)) 37 | file_mask = f_mask.replace('xxxx',str(i).zfill(4)) 38 | 39 | img_fixed = torch.from_numpy(nib.load(file_img).get_fdata()).float().cuda().contiguous() 40 | key_fixed = torch.from_numpy(np.loadtxt(file_key,delimiter=',')).float() 41 | mask_fixed = torch.from_numpy(nib.load(file_mask).get_fdata()).float().cuda().contiguous() 42 | imgs_fixed.append(img_fixed) 43 | keypts_fixed.append(key_fixed) 44 | masks_fixed.append(mask_fixed) 45 | 46 | file_img = file_img.replace('0000',str(1).zfill(4)) 47 | file_key = file_key.replace('0000',str(1).zfill(4)) 48 | file_mask = file_mask.replace('0000',str(1).zfill(4)) 49 | 50 | img_moving = torch.from_numpy(nib.load(file_img).get_fdata()).float().cuda().contiguous() 51 | key_moving = torch.from_numpy(np.loadtxt(file_key,delimiter=',')).float() 52 | mask_moving = torch.from_numpy(nib.load(file_mask).get_fdata()).float().cuda().contiguous() 53 | imgs_moving.append(img_moving) 54 | keypts_moving.append(key_moving) 55 | masks_moving.append(mask_moving) 56 | 57 | return imgs_fixed,keypts_fixed,masks_fixed,imgs_moving,keypts_moving,masks_moving 58 | 59 | def main(gpunum,configfile,convex_s): 60 | 61 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 62 | os.environ['CUDA_VISIBLE_DEVICES'] = str((gpunum)) 63 | print(torch.cuda.get_device_name()) 64 | 65 | with open(configfile, 'r') as f: 66 | config = json.load(f) 67 | topk = config['topk'] 68 | 69 | print('using 15 registration pairs') 70 | imgs_fixed,keypts_fixed,masks_fixed,imgs_moving,keypts_moving,masks_moving = get_data_train(topk,config['HWD'],config['f_img'],config['f_key'],config['f_mask']) 71 | 
robust30 = [] 72 | for i in range(len(topk)): 73 | tre0 = (keypts_fixed[i]-keypts_moving[i]).square().sum(-1).sqrt() 74 | robust30.append(tre0.topk(int(len(tre0)*.3),largest=True).indices) 75 | 76 | torch.manual_seed(1004) 77 | settings = (torch.rand(100,4)*torch.tensor([3,3,4,6])+torch.tensor([0.5,0.5,1.5,1.5])).round() 78 | #print(settings[1]) 79 | settings[settings[:,2]==2,3] = torch.minimum(settings[settings[:,2]==2,3],torch.tensor([5])) 80 | 81 | print(settings.min(0).values,settings.max(0).values,) 82 | 83 | mind_r = int(settings[convex_s,0])#1 84 | mind_d = int(settings[convex_s,1])#1 85 | grid_sp = int(settings[convex_s,2])#6 86 | disp_hw = int(settings[convex_s,3])#4 87 | 88 | 89 | print('using predetermined setting s=',convex_s) 90 | 91 | print('setting mind_r',mind_r,'mind_d',mind_d,'grid_sp',grid_sp,'disp_hw',disp_hw) 92 | tre_convex = torch.zeros(len(topk)) 93 | ##APPLY BEST CONVEX TO TRAIN 94 | disps_lr = [] 95 | for i in trange(len(topk)): 96 | 97 | t0 = time.time() 98 | 99 | img_fixed = imgs_fixed[i].cuda() 100 | key_fixed = keypts_fixed[i].cuda() 101 | mask_fixed = masks_fixed[i].cuda() 102 | 103 | img_moving = imgs_moving[i].cuda() 104 | key_moving = keypts_moving[i].cuda() 105 | mask_moving = masks_moving[i].cuda() 106 | 107 | H, W, D = img_fixed.shape[-3:] 108 | grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H,W,D),align_corners=False) 109 | torch.cuda.synchronize() 110 | t0 = time.time() 111 | 112 | # compute features and downsample (using average pooling) 113 | with torch.no_grad(): 114 | 115 | features_fix,features_mov = extract_features(img_fixed,img_moving,mind_r,mind_d,True,mask_fixed,mask_moving) 116 | 117 | features_fix_smooth = F.avg_pool3d(features_fix,grid_sp,stride=grid_sp) 118 | features_mov_smooth = F.avg_pool3d(features_mov,grid_sp,stride=grid_sp) 119 | 120 | n_ch = features_fix_smooth.shape[1] 121 | t1 = time.time() 122 | #with torch.cuda.amp.autocast(dtype=torch.bfloat16): 123 | # compute correlation volume with 
SSD 124 | ssd,ssd_argmin = correlate(features_fix_smooth,features_mov_smooth,disp_hw,grid_sp,(H,W,D), n_ch) 125 | 126 | # provide auxiliary mesh grid 127 | disp_mesh_t = F.affine_grid(disp_hw*torch.eye(3,4).cuda().half().unsqueeze(0),(1,1,disp_hw*2+1,disp_hw*2+1,disp_hw*2+1),align_corners=True).permute(0,4,1,2,3).reshape(3,-1,1) 128 | 129 | # perform coupled convex optimisation 130 | disp_soft = coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp,(H,W,D)) 131 | 132 | # if "ic" flag is set: make inverse consistent 133 | scale = torch.tensor([H//grid_sp-1,W//grid_sp-1,D//grid_sp-1]).view(1,3,1,1,1).cuda().half()/2 134 | 135 | ssd_,ssd_argmin_ = correlate(features_mov_smooth,features_fix_smooth,disp_hw,grid_sp,(H,W,D), n_ch) 136 | 137 | disp_soft_ = coupled_convex(ssd_,ssd_argmin_,disp_mesh_t,grid_sp,(H,W,D)) 138 | disp_ice,_ = inverse_consistency((disp_soft/scale).flip(1),(disp_soft_/scale).flip(1),iter=15) 139 | disp_lr = disp_ice.flip(1)*scale*grid_sp 140 | 141 | disp_hr = F.interpolate(disp_lr,size=(H,W,D),mode='trilinear',align_corners=False) 142 | t2 = time.time() 143 | disps_lr.append(disp_lr.cpu()) 144 | scale1 = torch.tensor([D-1,W-1,H-1]).cuda()/2 145 | 146 | lms_fix1 = (key_fixed.flip(1)/scale1-1).cuda().view(1,-1,1,1,3) 147 | disp_sampled = F.grid_sample(disp_hr.float().cuda(),lms_fix1).squeeze().t().cpu().data 148 | #TRE0 = (key_fixed.cpu()-key_moving.cpu()).square().sum(-1).sqrt() 149 | TRE1 = (key_fixed.cpu()-key_moving.cpu()+disp_sampled).square().sum(-1).sqrt() 150 | tre_convex[i] = TRE1.mean() 151 | #print(TRE0.mean(),'>',TRE1.mean()) 152 | print('TRE convex',tre_convex.mean()) 153 | del disp_soft; del disp_soft_; del ssd_; del ssd; del disp_hr; del features_fix; del features_mov; del features_fix_smooth; del features_mov_smooth; 154 | 155 | ##FIND OPTIMAL ADAM SETTING 156 | 157 | avgs = [GaussianSmoothing(.7).cuda(),\ 158 | 
GaussianSmoothing(1).cuda(),kovesi_spline(1.3,4).cuda(),kovesi_spline(1.6,4).cuda(),kovesi_spline(1.9,4).cuda(),kovesi_spline(2.2,4).cuda(),kovesi_spline(2.5,4).cuda(),kovesi_spline(2.8,4).cuda()] 159 | 160 | 161 | torch.manual_seed(2004) 162 | #settings_adam = (torch.rand(50,5)*torch.tensor([2,2,3,5,7])+torch.tensor([0.5,0.5,0.5,.5,1.5])).round() #new 163 | 164 | # settings_adam = (torch.rand(50,5)*torch.tensor([2,2,4,5,7])+torch.tensor([0.5,0.5,0.5,-.49,1.5])).round() 165 | settings_adam = (torch.rand(75,5)*torch.tensor([2,2,4,5,7])+torch.tensor([0.5,0.5,0.5,.5,1.5])).round() 166 | settings_adam[:,4] *= .2 167 | #settings_adam[0] = torch.tensor([1,2,2,3,1.5]) 168 | #settings_adam[1] = torch.tensor([1,2,1,4,1.5]) 169 | #print('s0',settings_adam[0]) 170 | #print('s1',settings_adam[1]) 171 | #settings[settings[:,2]==2,3] = torch.minimum(settings[settings[:,2]==2,3],torch.tensor([5])) 172 | #print(settings[1]) 173 | torch.cuda.empty_cache() 174 | print(settings_adam.min(0).values,settings_adam.max(0).values,gpu_usage()) 175 | 176 | jstd2 = torch.zeros(75,4,4,2) 177 | tre2 = torch.zeros(75,4,4,2) 178 | tre_min = 100 179 | for s in trange(75): 180 | for i in trange(len(topk)): 181 | 182 | t0 = time.time() 183 | 184 | img_fixed = imgs_fixed[i].cuda() 185 | key_fixed = keypts_fixed[i].cuda() 186 | mask_fixed = masks_fixed[i].cuda() 187 | 188 | img_moving = imgs_moving[i].cuda() 189 | key_moving = keypts_moving[i].cuda() 190 | mask_moving = masks_moving[i].cuda() 191 | 192 | mind_r = int(settings_adam[s,0])#1 193 | mind_d = int(settings_adam[s,1])#2 194 | grid_sp_adam = int(settings_adam[s,2])#6 195 | avg_n = int(settings_adam[s,3])#6 196 | if(grid_sp_adam==1): 197 | avg_n += 2 198 | if(grid_sp_adam==2): 199 | avg_n += 1 200 | lambda_weight = float(settings_adam[s,4])#4 201 | 202 | t0 = time.time() 203 | 204 | 205 | H, W, D = img_fixed.shape[-3:] 206 | grid0_hr = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H,W,D),align_corners=False) 207 | 
torch.cuda.synchronize() 208 | t0 = time.time() 209 | 210 | # compute features and downsample (using average pooling) 211 | with torch.no_grad(): 212 | features_fix,features_mov = extract_features(img_fixed,img_moving,mind_r,mind_d,True,mask_fixed,mask_moving) 213 | n_ch = features_mov.shape[1] 214 | # run Adam instance optimisation 215 | with torch.no_grad(): 216 | patch_features_fix = F.avg_pool3d(features_fix,grid_sp_adam,stride=grid_sp_adam) 217 | patch_features_mov = F.avg_pool3d(features_mov,grid_sp_adam,stride=grid_sp_adam) 218 | 219 | disp_hr = F.interpolate(disps_lr[i].cuda().float(),size=(H,W,D),mode='trilinear',align_corners=False) 220 | #create optimisable displacement grid 221 | disp_lr = F.interpolate(disp_hr,size=(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),mode='trilinear',align_corners=False) 222 | 223 | net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),bias=False)) 224 | net[0].weight.data[:] = disp_lr.float().cpu().data/grid_sp_adam 225 | net.cuda() 226 | optimizer = torch.optim.Adam(net.parameters(), lr=1) 227 | 228 | grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),align_corners=False) 229 | #run Adam optimisation with diffusion regularisation and B-spline smoothing 230 | for iter in range(120): 231 | optimizer.zero_grad() 232 | 233 | disp_sample = (avgs[avg_n](net[0].weight)).permute(0,2,3,4,1)#,3,stride=1,padding=1),3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1) 234 | reg_loss = lambda_weight*((disp_sample[0,:,1:,:]-disp_sample[0,:,:-1,:])**2).mean()+\ 235 | lambda_weight*((disp_sample[0,1:,:,:]-disp_sample[0,:-1,:,:])**2).mean()+\ 236 | lambda_weight*((disp_sample[0,:,:,1:]-disp_sample[0,:,:,:-1])**2).mean() 237 | 238 | scale = torch.tensor([(H//grid_sp_adam-1)/2,(W//grid_sp_adam-1)/2,(D//grid_sp_adam-1)/2]).cuda().unsqueeze(0) 239 | grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/scale).flip(1).float() 
240 | 241 | patch_mov_sampled = F.grid_sample(patch_features_mov.float(),grid_disp.view(1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam,3).cuda(),align_corners=False,mode='bilinear') 242 | 243 | sampled_cost = (patch_mov_sampled-patch_features_fix).pow(2).mean(1)*n_ch 244 | loss = sampled_cost.mean() 245 | (loss+reg_loss).backward() 246 | optimizer.step() 247 | scale1 = torch.tensor([D-1,W-1,H-1]).cuda()/2 248 | #lms_fix1 = (lms_fixed.flip(1)/scale1-1).cuda().view(1,-1,1,1,3) 249 | 250 | if(iter>=59): 251 | with torch.no_grad(): 252 | if((iter-59)%20==0): 253 | 254 | fitted_grid = disp_sample.detach().permute(0,4,1,2,3) 255 | disp_hr = F.interpolate(fitted_grid*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False) 256 | 257 | kernel_smooth = 3; padding_smooth = kernel_smooth//2 258 | 259 | ii = int((iter-59)//20) 260 | for kk in range(4): 261 | if(kk>0): 262 | disp_hr = F.avg_pool3d(disp_hr,kernel_smooth,padding=padding_smooth,stride=1) 263 | #disp_sampled = F.grid_sample(disp_hr.float().cuda(),lms_fix1).squeeze().t().cpu().data 264 | lms_fix1 = (key_fixed.flip(1)/scale1-1).cuda().view(1,-1,1,1,3) 265 | disp_sampled = F.grid_sample(disp_hr.float().cuda(),lms_fix1).squeeze().t().cpu().data 266 | jac_det = jacobian_determinant_3d(disp_hr.float(),False) 267 | 268 | TRE1 = (key_fixed.cpu()-key_moving.cpu()+disp_sampled).square().sum(-1).sqrt() 269 | 270 | #t_mind[s] += t1-t0 271 | #t_convex[s] += t2-t1 272 | tre2[s,ii,kk,0] += 1/len(topk)*TRE1.mean() 273 | tre2[s,ii,kk,1] += 1/len(topk)*TRE1[robust30[i]].mean() 274 | jac_det_log = jac_det.add(3).clamp_(0.000000001, 1000000000).log()#.std() 275 | jstd2[s,ii,kk,0] += 1/len(topk)*(jac_det_log).std().cpu() 276 | jstd2[s,ii,kk,1] += 1/len(topk)*((jac_det<0).float().mean()).cpu() 277 | 278 | 279 | torch.save([tre2,jstd2],config['output_adam']) 280 | loss.cpu(); del loss 281 | reg_loss.cpu(); del reg_loss; 282 | net.cpu(); del net; 283 | if(tre2[s,:,:,0].min() 0: 78 | niter_adam = 80 79 | with torch.no_grad(): 80 
| 81 | patch_features_fix = F.avg_pool3d(features_fix,grid_sp_adam,stride=grid_sp_adam) 82 | patch_features_mov = F.avg_pool3d(features_mov,grid_sp_adam,stride=grid_sp_adam) 83 | 84 | 85 | #create optimisable displacement grid 86 | disp_lr = F.interpolate(disp_hr,size=(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),mode='trilinear',align_corners=False) 87 | 88 | 89 | net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),bias=False)) 90 | net[0].weight.data[:] = disp_lr.float().cpu().data/grid_sp_adam 91 | net.cuda() 92 | optimizer = torch.optim.Adam(net.parameters(), lr=1) 93 | 94 | grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),align_corners=False) 95 | 96 | #run Adam optimisation with diffusion regularisation and B-spline smoothing 97 | for iter in range(niter_adam): 98 | optimizer.zero_grad() 99 | 100 | disp_sample = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(net[0].weight,3,stride=1,padding=1),3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1) 101 | reg_loss = lambda_weight*((disp_sample[0,:,1:,:]-disp_sample[0,:,:-1,:])**2).mean()+\ 102 | lambda_weight*((disp_sample[0,1:,:,:]-disp_sample[0,:-1,:,:])**2).mean()+\ 103 | lambda_weight*((disp_sample[0,:,:,1:]-disp_sample[0,:,:,:-1])**2).mean() 104 | 105 | scale = torch.tensor([(H//grid_sp_adam-1)/2,(W//grid_sp_adam-1)/2,(D//grid_sp_adam-1)/2]).cuda().unsqueeze(0) 106 | grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/scale).flip(1).float() 107 | 108 | patch_mov_sampled = F.grid_sample(patch_features_mov.float(),grid_disp.view(1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam,3).cuda(),align_corners=False,mode='bilinear') 109 | 110 | sampled_cost = (patch_mov_sampled-patch_features_fix).pow(2).mean(1)*12 111 | loss = sampled_cost.mean() 112 | (loss+reg_loss).backward() 113 | optimizer.step() 114 | 115 | if iter == 39: 116 | fitted_grid_40 = disp_sample.detach().permute(0,4,1,2,3) 117 | t40 = 
time.time() 118 | if iter == 59: 119 | fitted_grid_60 = disp_sample.detach().permute(0,4,1,2,3) 120 | t60 = time.time() 121 | if iter == 79: 122 | fitted_grid_80 = disp_sample.detach().permute(0,4,1,2,3) 123 | t80 = time.time() 124 | 125 | disp_hr_40 = F.interpolate(fitted_grid_40*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False) 126 | disp_hr_60 = F.interpolate(fitted_grid_60*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False) 127 | disp_hr_80 = F.interpolate(fitted_grid_80*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False) 128 | 129 | kernel_smooth = 5 130 | padding_smooth = kernel_smooth//2 131 | disp_hr_40_smooth5 = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_40,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 132 | disp_hr_60_smooth5 = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_60,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 133 | disp_hr_80_smooth5 = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_80,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 134 | 135 | kernel_smooth = 3 136 | padding_smooth = kernel_smooth//2 137 | disp_hr_40_smooth3 = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_40,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 138 | disp_hr_60_smooth3 = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_60,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 139 | disp_hr_80_smooth3 = 
F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_80,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 140 | 141 | 142 | torch.cuda.synchronize() 143 | t1 = time.time() 144 | case_time = t1-t0 145 | print('case time: ', case_time) 146 | 147 | # no smoothing 148 | x = disp_hr_40[0,0,:,:,:].cpu().half().data.numpy() 149 | y = disp_hr_40[0,1,:,:,:].cpu().half().data.numpy() 150 | z = disp_hr_40[0,2,:,:,:].cpu().half().data.numpy() 151 | displacements_40 = np.stack((x,y,z),3).astype(float) 152 | 153 | x = disp_hr_60[0,0,:,:,:].cpu().half().data.numpy() 154 | y = disp_hr_60[0,1,:,:,:].cpu().half().data.numpy() 155 | z = disp_hr_60[0,2,:,:,:].cpu().half().data.numpy() 156 | displacements_60 = np.stack((x,y,z),3).astype(float) 157 | 158 | x = disp_hr_80[0,0,:,:,:].cpu().half().data.numpy() 159 | y = disp_hr_80[0,1,:,:,:].cpu().half().data.numpy() 160 | z = disp_hr_80[0,2,:,:,:].cpu().half().data.numpy() 161 | displacements_80 = np.stack((x,y,z),3).astype(float) 162 | 163 | # smoothing kernel=3 164 | x = disp_hr_40_smooth3[0,0,:,:,:].cpu().half().data.numpy() 165 | y = disp_hr_40_smooth3[0,1,:,:,:].cpu().half().data.numpy() 166 | z = disp_hr_40_smooth3[0,2,:,:,:].cpu().half().data.numpy() 167 | displacements_40_smooth3 = np.stack((x,y,z),3).astype(float) 168 | 169 | x = disp_hr_60_smooth3[0,0,:,:,:].cpu().half().data.numpy() 170 | y = disp_hr_60_smooth3[0,1,:,:,:].cpu().half().data.numpy() 171 | z = disp_hr_60_smooth3[0,2,:,:,:].cpu().half().data.numpy() 172 | displacements_60_smooth3 = np.stack((x,y,z),3).astype(float) 173 | 174 | x = disp_hr_80_smooth3[0,0,:,:,:].cpu().half().data.numpy() 175 | y = disp_hr_80_smooth3[0,1,:,:,:].cpu().half().data.numpy() 176 | z = disp_hr_80_smooth3[0,2,:,:,:].cpu().half().data.numpy() 177 | displacements_80_smooth3 = np.stack((x,y,z),3).astype(float) 178 | 179 | # smoothing kernel=5 180 | x = disp_hr_40_smooth5[0,0,:,:,:].cpu().half().data.numpy() 181 
| y = disp_hr_40_smooth5[0,1,:,:,:].cpu().half().data.numpy() 182 | z = disp_hr_40_smooth5[0,2,:,:,:].cpu().half().data.numpy() 183 | displacements_40_smooth5 = np.stack((x,y,z),3).astype(float) 184 | 185 | x = disp_hr_60_smooth5[0,0,:,:,:].cpu().half().data.numpy() 186 | y = disp_hr_60_smooth5[0,1,:,:,:].cpu().half().data.numpy() 187 | z = disp_hr_60_smooth5[0,2,:,:,:].cpu().half().data.numpy() 188 | displacements_60_smooth5 = np.stack((x,y,z),3).astype(float) 189 | 190 | x = disp_hr_80_smooth5[0,0,:,:,:].cpu().half().data.numpy() 191 | y = disp_hr_80_smooth5[0,1,:,:,:].cpu().half().data.numpy() 192 | z = disp_hr_80_smooth5[0,2,:,:,:].cpu().half().data.numpy() 193 | displacements_80_smooth5 = np.stack((x,y,z),3).astype(float) 194 | 195 | 196 | return displacements_40, displacements_60, displacements_80, displacements_40_smooth3, displacements_60_smooth3, displacements_80_smooth3, displacements_40_smooth5, displacements_60_smooth5, displacements_80_smooth5, case_time 197 | -------------------------------------------------------------------------------- /self_configuring/convex_adam_MIND_testset.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from scipy.ndimage import zoom as zoom 9 | 10 | from convexAdam.convex_adam_MIND import extract_features 11 | from convexAdam.convex_adam_utils import (correlate, coupled_convex, 12 | inverse_consistency) 13 | 14 | warnings.filterwarnings("ignore") 15 | 16 | 17 | # coupled convex optimisation with adam instance optimisation 18 | def convex_adam(img_fixed, 19 | img_moving, 20 | mind_r, 21 | mind_d, 22 | use_mask, 23 | mask_fixed, 24 | mask_moving, 25 | lambda_weight, 26 | grid_sp, 27 | disp_hw, 28 | selected_niter, 29 | selected_smooth, 30 | grid_sp_adam=2, 31 | ic=True): 32 | 33 | H,W,D = img_fixed.shape 34 | 35 | torch.cuda.synchronize() 36 | t0 = 
time.time() 37 | 38 | #compute features and downsample (using average pooling) 39 | with torch.no_grad(): 40 | 41 | features_fix, features_mov = extract_features(img_fixed=img_fixed, 42 | img_moving=img_moving, 43 | mind_r=mind_r, 44 | mind_d=mind_d, 45 | use_mask=use_mask, 46 | mask_fixed=mask_fixed, 47 | mask_moving=mask_moving) 48 | 49 | features_fix_smooth = F.avg_pool3d(features_fix,grid_sp,stride=grid_sp) 50 | features_mov_smooth = F.avg_pool3d(features_mov,grid_sp,stride=grid_sp) 51 | 52 | n_ch = features_fix_smooth.shape[1] 53 | 54 | # compute correlation volume with SSD 55 | ssd,ssd_argmin = correlate(features_fix_smooth,features_mov_smooth,disp_hw,grid_sp,(H,W,D), n_ch) 56 | 57 | # provide auxiliary mesh grid 58 | disp_mesh_t = F.affine_grid(disp_hw*torch.eye(3,4).cuda().half().unsqueeze(0),(1,1,disp_hw*2+1,disp_hw*2+1,disp_hw*2+1),align_corners=True).permute(0,4,1,2,3).reshape(3,-1,1) 59 | 60 | # perform coupled convex optimisation 61 | disp_soft = coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp,(H,W,D)) 62 | 63 | # if "ic" flag is set: make inverse consistent 64 | if ic: 65 | scale = torch.tensor([H//grid_sp-1,W//grid_sp-1,D//grid_sp-1]).view(1,3,1,1,1).cuda().half()/2 66 | 67 | ssd_,ssd_argmin_ = correlate(features_mov_smooth,features_fix_smooth,disp_hw,grid_sp,(H,W,D), n_ch) 68 | 69 | disp_soft_ = coupled_convex(ssd_,ssd_argmin_,disp_mesh_t,grid_sp,(H,W,D)) 70 | disp_ice,_ = inverse_consistency((disp_soft/scale).flip(1),(disp_soft_/scale).flip(1),iter=15) 71 | 72 | disp_hr = F.interpolate(disp_ice.flip(1)*scale*grid_sp,size=(H,W,D),mode='trilinear',align_corners=False) 73 | 74 | else: 75 | disp_hr=disp_soft 76 | 77 | # run Adam instance optimisation 78 | if lambda_weight > 0: 79 | with torch.no_grad(): 80 | 81 | patch_features_fix = F.avg_pool3d(features_fix,grid_sp_adam,stride=grid_sp_adam) 82 | patch_features_mov = F.avg_pool3d(features_mov,grid_sp_adam,stride=grid_sp_adam) 83 | 84 | 85 | #create optimisable displacement grid 86 | disp_lr = 
F.interpolate(disp_hr,size=(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),mode='trilinear',align_corners=False) 87 | 88 | 89 | net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),bias=False)) 90 | net[0].weight.data[:] = disp_lr.float().cpu().data/grid_sp_adam 91 | net.cuda() 92 | optimizer = torch.optim.Adam(net.parameters(), lr=1) 93 | 94 | grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),align_corners=False) 95 | 96 | #run Adam optimisation with diffusion regularisation and B-spline smoothing 97 | for iter in range(selected_niter): 98 | optimizer.zero_grad() 99 | 100 | disp_sample = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(net[0].weight,3,stride=1,padding=1),3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1) 101 | reg_loss = lambda_weight*((disp_sample[0,:,1:,:]-disp_sample[0,:,:-1,:])**2).mean()+\ 102 | lambda_weight*((disp_sample[0,1:,:,:]-disp_sample[0,:-1,:,:])**2).mean()+\ 103 | lambda_weight*((disp_sample[0,:,:,1:]-disp_sample[0,:,:,:-1])**2).mean() 104 | 105 | scale = torch.tensor([(H//grid_sp_adam-1)/2,(W//grid_sp_adam-1)/2,(D//grid_sp_adam-1)/2]).cuda().unsqueeze(0) 106 | grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/scale).flip(1).float() 107 | 108 | patch_mov_sampled = F.grid_sample(patch_features_mov.float(),grid_disp.view(1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam,3).cuda(),align_corners=False,mode='bilinear') 109 | 110 | sampled_cost = (patch_mov_sampled-patch_features_fix).pow(2).mean(1)*12 111 | loss = sampled_cost.mean() 112 | (loss+reg_loss).backward() 113 | optimizer.step() 114 | 115 | 116 | fitted_grid = disp_sample.detach().permute(0,4,1,2,3) 117 | disp_hr = F.interpolate(fitted_grid*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False) 118 | 119 | if selected_smooth == 5: 120 | kernel_smooth = 5 121 | padding_smooth = kernel_smooth//2 122 | disp_hr = 
F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 123 | 124 | 125 | if selected_smooth == 3: 126 | kernel_smooth = 3 127 | padding_smooth = kernel_smooth//2 128 | disp_hr = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 129 | 130 | 131 | torch.cuda.synchronize() 132 | t1 = time.time() 133 | case_time = t1-t0 134 | print('case time: ', case_time) 135 | 136 | x = disp_hr[0,0,:,:,:].cpu().half().data.numpy() 137 | y = disp_hr[0,1,:,:,:].cpu().half().data.numpy() 138 | z = disp_hr[0,2,:,:,:].cpu().half().data.numpy() 139 | displacements = np.stack((x,y,z),3).astype(float) 140 | 141 | return displacements, case_time 142 | -------------------------------------------------------------------------------- /self_configuring/convex_adam_nnUNet.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from convexAdam.convex_adam_utils import (correlate, coupled_convex, 10 | inverse_consistency) 11 | 12 | warnings.filterwarnings("ignore") 13 | 14 | 15 | # extract MIND and/or semantic nnUNet features 16 | def extract_features(pred_fixed, 17 | pred_moving): 18 | 19 | eps=1e-32 20 | H,W,D = pred_fixed.shape[-3:] 21 | 22 | #weight = 1/((torch.bincount(pred_fixed.long().reshape(-1))+torch.bincount(pred_moving.long().reshape(-1)))+eps).float().pow(.3) 23 | #weight /= weight.mean() 24 | 25 | combined_bins = torch.bincount(pred_fixed.long().reshape(-1))+torch.bincount(pred_moving.long().reshape(-1)) 26 | 27 | pos = torch.nonzero(combined_bins).reshape(-1) 28 | 29 | #features_fix = 
10*(pred_fixed[:,1:,:,:,:].data.float().permute(0,1,4,3,2).contiguous()*weight[1:].view(1,-1,1,1,1).cuda()).half() 30 | #features_mov = 10*(pred_moving[:,1:,:,:,:].data.float().permute(0,1,4,3,2).contiguous()*weight[1:].view(1,-1,1,1,1).cuda()).half() 31 | 32 | pred_fixed = F.one_hot(pred_fixed.cuda().view(1,H,W,D).long())[:,:,:,:,pos] 33 | pred_moving = F.one_hot(pred_moving.cuda().view(1,H,W,D).long())[:,:,:,:,pos] 34 | 35 | weight = 1/((torch.bincount(pred_fixed.permute(0,4,1,2,3).argmax(1).long().reshape(-1))+torch.bincount(pred_moving.permute(0,4,1,2,3).argmax(1).long().reshape(-1)))+eps).float().pow(.3) 36 | weight /= weight.mean() 37 | 38 | features_fix = 10*(pred_fixed.data.float().permute(0,4,1,2,3).contiguous()*weight.view(1,-1,1,1,1).cuda()).half() 39 | features_mov = 10*(pred_moving.data.float().permute(0,4,1,2,3).contiguous()*weight.view(1,-1,1,1,1).cuda()).half() 40 | 41 | return features_fix, features_mov 42 | 43 | # coupled convex optimisation with adam instance optimisation 44 | def convex_adam(img_fixed, 45 | img_moving, 46 | pred_fixed, 47 | pred_moving, 48 | use_mask, 49 | mask_fixed, 50 | mask_moving, 51 | lambda_weight, 52 | grid_sp, 53 | disp_hw, 54 | grid_sp_adam=2, 55 | ic=True): 56 | 57 | H,W,D = img_fixed.shape 58 | 59 | torch.cuda.synchronize() 60 | t0 = time.time() 61 | 62 | #compute features and downsample (using average pooling) 63 | with torch.no_grad(): 64 | 65 | features_fix, features_mov = extract_features(pred_fixed=pred_fixed, 66 | pred_moving=pred_moving) 67 | 68 | features_fix_smooth = F.avg_pool3d(features_fix,grid_sp,stride=grid_sp) 69 | features_mov_smooth = F.avg_pool3d(features_mov,grid_sp,stride=grid_sp) 70 | 71 | n_ch = features_fix_smooth.shape[1] 72 | 73 | # compute correlation volume with SSD 74 | ssd,ssd_argmin = correlate(features_fix_smooth,features_mov_smooth,disp_hw,grid_sp,(H,W,D), n_ch) 75 | 76 | # provide auxiliary mesh grid 77 | disp_mesh_t = 
F.affine_grid(disp_hw*torch.eye(3,4).cuda().half().unsqueeze(0),(1,1,disp_hw*2+1,disp_hw*2+1,disp_hw*2+1),align_corners=True).permute(0,4,1,2,3).reshape(3,-1,1) 78 | 79 | # perform coupled convex optimisation 80 | disp_soft = coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp,(H,W,D)) 81 | 82 | # if "ic" flag is set: make inverse consistent 83 | if ic: 84 | scale = torch.tensor([H//grid_sp-1,W//grid_sp-1,D//grid_sp-1]).view(1,3,1,1,1).cuda().half()/2 85 | 86 | ssd_,ssd_argmin_ = correlate(features_mov_smooth,features_fix_smooth,disp_hw,grid_sp,(H,W,D), n_ch) 87 | 88 | disp_soft_ = coupled_convex(ssd_,ssd_argmin_,disp_mesh_t,grid_sp,(H,W,D)) 89 | disp_ice,_ = inverse_consistency((disp_soft/scale).flip(1),(disp_soft_/scale).flip(1),iter=15) 90 | 91 | disp_hr = F.interpolate(disp_ice.flip(1)*scale*grid_sp,size=(H,W,D),mode='trilinear',align_corners=False) 92 | 93 | else: 94 | disp_hr=disp_soft 95 | 96 | # run Adam instance optimisation 97 | if lambda_weight > 0: 98 | niter_adam = 80 99 | with torch.no_grad(): 100 | 101 | patch_features_fix = F.avg_pool3d(features_fix,grid_sp_adam,stride=grid_sp_adam) 102 | patch_features_mov = F.avg_pool3d(features_mov,grid_sp_adam,stride=grid_sp_adam) 103 | 104 | 105 | #create optimisable displacement grid 106 | disp_lr = F.interpolate(disp_hr,size=(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),mode='trilinear',align_corners=False) 107 | 108 | 109 | net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),bias=False)) 110 | net[0].weight.data[:] = disp_lr.float().cpu().data/grid_sp_adam 111 | net.cuda() 112 | optimizer = torch.optim.Adam(net.parameters(), lr=1) 113 | 114 | grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),align_corners=False) 115 | 116 | #run Adam optimisation with diffusion regularisation and B-spline smoothing 117 | for iter in range(niter_adam): 118 | optimizer.zero_grad() 119 | 120 | disp_sample = 
F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(net[0].weight,3,stride=1,padding=1),3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1) 121 | reg_loss = lambda_weight*((disp_sample[0,:,1:,:]-disp_sample[0,:,:-1,:])**2).mean()+\ 122 | lambda_weight*((disp_sample[0,1:,:,:]-disp_sample[0,:-1,:,:])**2).mean()+\ 123 | lambda_weight*((disp_sample[0,:,:,1:]-disp_sample[0,:,:,:-1])**2).mean() 124 | 125 | scale = torch.tensor([(H//grid_sp_adam-1)/2,(W//grid_sp_adam-1)/2,(D//grid_sp_adam-1)/2]).cuda().unsqueeze(0) 126 | grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/scale).flip(1).float() 127 | 128 | patch_mov_sampled = F.grid_sample(patch_features_mov.float(),grid_disp.view(1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam,3).cuda(),align_corners=False,mode='bilinear') 129 | 130 | sampled_cost = (patch_mov_sampled-patch_features_fix).pow(2).mean(1)*12 131 | loss = sampled_cost.mean() 132 | (loss+reg_loss).backward() 133 | optimizer.step() 134 | 135 | if iter == 39: 136 | fitted_grid_40 = disp_sample.detach().permute(0,4,1,2,3) 137 | t40 = time.time() 138 | if iter == 59: 139 | fitted_grid_60 = disp_sample.detach().permute(0,4,1,2,3) 140 | t60 = time.time() 141 | if iter == 79: 142 | fitted_grid_80 = disp_sample.detach().permute(0,4,1,2,3) 143 | t80 = time.time() 144 | 145 | disp_hr_40 = F.interpolate(fitted_grid_40*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False) 146 | disp_hr_60 = F.interpolate(fitted_grid_60*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False) 147 | disp_hr_80 = F.interpolate(fitted_grid_80*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False) 148 | 149 | kernel_smooth = 5 150 | padding_smooth = kernel_smooth//2 151 | disp_hr_40_smooth5 = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_40,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 152 | disp_hr_60_smooth5 = 
F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_60,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 153 | disp_hr_80_smooth5 = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_80,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 154 | 155 | kernel_smooth = 3 156 | padding_smooth = kernel_smooth//2 157 | disp_hr_40_smooth3 = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_40,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 158 | disp_hr_60_smooth3 = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_60,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 159 | disp_hr_80_smooth3 = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr_80,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 160 | 161 | 162 | torch.cuda.synchronize() 163 | t1 = time.time() 164 | case_time = t1-t0 165 | print('case time: ', case_time) 166 | 167 | # no smoothing 168 | x = disp_hr_40[0,0,:,:,:].cpu().half().data.numpy() 169 | y = disp_hr_40[0,1,:,:,:].cpu().half().data.numpy() 170 | z = disp_hr_40[0,2,:,:,:].cpu().half().data.numpy() 171 | displacements_40 = np.stack((x,y,z),3).astype(float) 172 | 173 | x = disp_hr_60[0,0,:,:,:].cpu().half().data.numpy() 174 | y = disp_hr_60[0,1,:,:,:].cpu().half().data.numpy() 175 | z = disp_hr_60[0,2,:,:,:].cpu().half().data.numpy() 176 | displacements_60 = np.stack((x,y,z),3).astype(float) 177 | 178 | x = disp_hr_80[0,0,:,:,:].cpu().half().data.numpy() 179 | y = disp_hr_80[0,1,:,:,:].cpu().half().data.numpy() 180 | z = disp_hr_80[0,2,:,:,:].cpu().half().data.numpy() 181 | displacements_80 = np.stack((x,y,z),3).astype(float) 
182 | 183 | # smoothing kernel=3 184 | x = disp_hr_40_smooth3[0,0,:,:,:].cpu().half().data.numpy() 185 | y = disp_hr_40_smooth3[0,1,:,:,:].cpu().half().data.numpy() 186 | z = disp_hr_40_smooth3[0,2,:,:,:].cpu().half().data.numpy() 187 | displacements_40_smooth3 = np.stack((x,y,z),3).astype(float) 188 | 189 | x = disp_hr_60_smooth3[0,0,:,:,:].cpu().half().data.numpy() 190 | y = disp_hr_60_smooth3[0,1,:,:,:].cpu().half().data.numpy() 191 | z = disp_hr_60_smooth3[0,2,:,:,:].cpu().half().data.numpy() 192 | displacements_60_smooth3 = np.stack((x,y,z),3).astype(float) 193 | 194 | x = disp_hr_80_smooth3[0,0,:,:,:].cpu().half().data.numpy() 195 | y = disp_hr_80_smooth3[0,1,:,:,:].cpu().half().data.numpy() 196 | z = disp_hr_80_smooth3[0,2,:,:,:].cpu().half().data.numpy() 197 | displacements_80_smooth3 = np.stack((x,y,z),3).astype(float) 198 | 199 | # smoothing kernel=5 200 | x = disp_hr_40_smooth5[0,0,:,:,:].cpu().half().data.numpy() 201 | y = disp_hr_40_smooth5[0,1,:,:,:].cpu().half().data.numpy() 202 | z = disp_hr_40_smooth5[0,2,:,:,:].cpu().half().data.numpy() 203 | displacements_40_smooth5 = np.stack((x,y,z),3).astype(float) 204 | 205 | x = disp_hr_60_smooth5[0,0,:,:,:].cpu().half().data.numpy() 206 | y = disp_hr_60_smooth5[0,1,:,:,:].cpu().half().data.numpy() 207 | z = disp_hr_60_smooth5[0,2,:,:,:].cpu().half().data.numpy() 208 | displacements_60_smooth5 = np.stack((x,y,z),3).astype(float) 209 | 210 | x = disp_hr_80_smooth5[0,0,:,:,:].cpu().half().data.numpy() 211 | y = disp_hr_80_smooth5[0,1,:,:,:].cpu().half().data.numpy() 212 | z = disp_hr_80_smooth5[0,2,:,:,:].cpu().half().data.numpy() 213 | displacements_80_smooth5 = np.stack((x,y,z),3).astype(float) 214 | 215 | 216 | return displacements_40, displacements_60, displacements_80, displacements_40_smooth3, displacements_60_smooth3, displacements_80_smooth3, displacements_40_smooth5, displacements_60_smooth5, displacements_80_smooth5, case_time 217 | 
-------------------------------------------------------------------------------- /self_configuring/convex_adam_nnUNet_testset.py: -------------------------------------------------------------------------------- 1 | import time 2 | import warnings 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | from convexAdam.convex_adam_utils import (correlate, coupled_convex, 10 | inverse_consistency) 11 | 12 | warnings.filterwarnings("ignore") 13 | 14 | 15 | # extract MIND and/or semantic nnUNet features 16 | # extract MIND and/or semantic nnUNet features 17 | def extract_features(pred_fixed, 18 | pred_moving): 19 | 20 | eps=1e-32 21 | H,W,D = pred_fixed.shape[-3:] 22 | 23 | #weight = 1/((torch.bincount(pred_fixed.long().reshape(-1))+torch.bincount(pred_moving.long().reshape(-1)))+eps).float().pow(.3) 24 | #weight /= weight.mean() 25 | 26 | combined_bins = torch.bincount(pred_fixed.long().reshape(-1))+torch.bincount(pred_moving.long().reshape(-1)) 27 | 28 | pos = torch.nonzero(combined_bins).reshape(-1) 29 | 30 | #features_fix = 10*(pred_fixed[:,1:,:,:,:].data.float().permute(0,1,4,3,2).contiguous()*weight[1:].view(1,-1,1,1,1).cuda()).half() 31 | #features_mov = 10*(pred_moving[:,1:,:,:,:].data.float().permute(0,1,4,3,2).contiguous()*weight[1:].view(1,-1,1,1,1).cuda()).half() 32 | 33 | pred_fixed = F.one_hot(pred_fixed.cuda().view(1,H,W,D).long())[:,:,:,:,pos] 34 | pred_moving = F.one_hot(pred_moving.cuda().view(1,H,W,D).long())[:,:,:,:,pos] 35 | 36 | weight = 1/((torch.bincount(pred_fixed.permute(0,4,1,2,3).argmax(1).long().reshape(-1))+torch.bincount(pred_moving.permute(0,4,1,2,3).argmax(1).long().reshape(-1)))+eps).float().pow(.3) 37 | weight /= weight.mean() 38 | 39 | features_fix = 10*(pred_fixed.data.float().permute(0,4,1,2,3).contiguous()*weight.view(1,-1,1,1,1).cuda()).half() 40 | features_mov = 10*(pred_moving.data.float().permute(0,4,1,2,3).contiguous()*weight.view(1,-1,1,1,1).cuda()).half() 41 | return 
features_fix, features_mov 42 | 43 | # coupled convex optimisation with adam instance optimisation 44 | def convex_adam(img_fixed, 45 | img_moving, 46 | pred_fixed, 47 | pred_moving, 48 | use_mask, 49 | mask_fixed, 50 | mask_moving, 51 | lambda_weight, 52 | grid_sp, 53 | disp_hw, 54 | selected_niter, 55 | selected_smooth, 56 | grid_sp_adam=2, 57 | ic=True): 58 | 59 | H,W,D = img_fixed.shape 60 | 61 | torch.cuda.synchronize() 62 | t0 = time.time() 63 | 64 | #compute features and downsample (using average pooling) 65 | with torch.no_grad(): 66 | 67 | features_fix, features_mov = extract_features(pred_fixed=pred_fixed, 68 | pred_moving=pred_moving) 69 | 70 | features_fix_smooth = F.avg_pool3d(features_fix,grid_sp,stride=grid_sp) 71 | features_mov_smooth = F.avg_pool3d(features_mov,grid_sp,stride=grid_sp) 72 | 73 | n_ch = features_fix_smooth.shape[1] 74 | 75 | # compute correlation volume with SSD 76 | ssd,ssd_argmin = correlate(features_fix_smooth,features_mov_smooth,disp_hw,grid_sp,(H,W,D), n_ch) 77 | 78 | # provide auxiliary mesh grid 79 | disp_mesh_t = F.affine_grid(disp_hw*torch.eye(3,4).cuda().half().unsqueeze(0),(1,1,disp_hw*2+1,disp_hw*2+1,disp_hw*2+1),align_corners=True).permute(0,4,1,2,3).reshape(3,-1,1) 80 | 81 | # perform coupled convex optimisation 82 | disp_soft = coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp,(H,W,D)) 83 | 84 | # if "ic" flag is set: make inverse consistent 85 | if ic: 86 | scale = torch.tensor([H//grid_sp-1,W//grid_sp-1,D//grid_sp-1]).view(1,3,1,1,1).cuda().half()/2 87 | 88 | ssd_,ssd_argmin_ = correlate(features_mov_smooth,features_fix_smooth,disp_hw,grid_sp,(H,W,D), n_ch) 89 | 90 | disp_soft_ = coupled_convex(ssd_,ssd_argmin_,disp_mesh_t,grid_sp,(H,W,D)) 91 | disp_ice,_ = inverse_consistency((disp_soft/scale).flip(1),(disp_soft_/scale).flip(1),iter=15) 92 | 93 | disp_hr = F.interpolate(disp_ice.flip(1)*scale*grid_sp,size=(H,W,D),mode='trilinear',align_corners=False) 94 | 95 | else: 96 | disp_hr=disp_soft 97 | 98 | # run Adam 
instance optimisation 99 | if lambda_weight > 0: 100 | with torch.no_grad(): 101 | 102 | patch_features_fix = F.avg_pool3d(features_fix,grid_sp_adam,stride=grid_sp_adam) 103 | patch_features_mov = F.avg_pool3d(features_mov,grid_sp_adam,stride=grid_sp_adam) 104 | 105 | 106 | #create optimisable displacement grid 107 | disp_lr = F.interpolate(disp_hr,size=(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),mode='trilinear',align_corners=False) 108 | 109 | 110 | net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),bias=False)) 111 | net[0].weight.data[:] = disp_lr.float().cpu().data/grid_sp_adam 112 | net.cuda() 113 | optimizer = torch.optim.Adam(net.parameters(), lr=1) 114 | 115 | grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),align_corners=False) 116 | 117 | #run Adam optimisation with diffusion regularisation and B-spline smoothing 118 | for iter in range(selected_niter): 119 | optimizer.zero_grad() 120 | 121 | disp_sample = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(net[0].weight,3,stride=1,padding=1),3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1) 122 | reg_loss = lambda_weight*((disp_sample[0,:,1:,:]-disp_sample[0,:,:-1,:])**2).mean()+\ 123 | lambda_weight*((disp_sample[0,1:,:,:]-disp_sample[0,:-1,:,:])**2).mean()+\ 124 | lambda_weight*((disp_sample[0,:,:,1:]-disp_sample[0,:,:,:-1])**2).mean() 125 | 126 | scale = torch.tensor([(H//grid_sp_adam-1)/2,(W//grid_sp_adam-1)/2,(D//grid_sp_adam-1)/2]).cuda().unsqueeze(0) 127 | grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/scale).flip(1).float() 128 | 129 | patch_mov_sampled = F.grid_sample(patch_features_mov.float(),grid_disp.view(1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam,3).cuda(),align_corners=False,mode='bilinear') 130 | 131 | sampled_cost = (patch_mov_sampled-patch_features_fix).pow(2).mean(1)*12 132 | loss = sampled_cost.mean() 133 | (loss+reg_loss).backward() 134 | 
optimizer.step() 135 | 136 | 137 | fitted_grid = disp_sample.detach().permute(0,4,1,2,3) 138 | disp_hr = F.interpolate(fitted_grid*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False) 139 | 140 | if selected_smooth == 5: 141 | kernel_smooth = 5 142 | padding_smooth = kernel_smooth//2 143 | disp_hr = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 144 | 145 | 146 | if selected_smooth == 3: 147 | kernel_smooth = 3 148 | padding_smooth = kernel_smooth//2 149 | disp_hr = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1) 150 | 151 | 152 | torch.cuda.synchronize() 153 | t1 = time.time() 154 | case_time = t1-t0 155 | print('case time: ', case_time) 156 | 157 | x = disp_hr[0,0,:,:,:].cpu().half().data.numpy() 158 | y = disp_hr[0,1,:,:,:].cpu().half().data.numpy() 159 | z = disp_hr[0,2,:,:,:].cpu().half().data.numpy() 160 | displacements = np.stack((x,y,z),3).astype(float) 161 | 162 | return displacements, case_time 163 | -------------------------------------------------------------------------------- /self_configuring/convex_run_paired_mind.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | import warnings 4 | 5 | import nibabel as nib 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | warnings.filterwarnings("ignore") 11 | import json 12 | import os 13 | 14 | from convexAdam_hyper_util import (correlate, coupled_convex, extract_features, 15 | inverse_consistency, 16 | jacobian_determinant_3d, sort_rank) 17 | from tqdm.auto import tqdm, trange 18 | 19 | 20 | def get_data_train(topk,HWD,f_img,f_key,f_mask): 21 | l2r_base_folder = './' 22 | 
#~/storage/staff/christophgrossbroeh/data/Learn2Reg/Learn2Reg_Dataset_release_v1.1 23 | #topk = (1,2,3,4,5,16,17,18,19,20) 24 | 25 | 26 | # #### topk = (3,5,6,10,11,12,13,16) 27 | H,W,D = HWD[0],HWD[1],HWD[2] 28 | 29 | imgs_fixed = [] 30 | keypts_fixed = [] 31 | masks_fixed = [] 32 | imgs_moving = [] 33 | keypts_moving = [] 34 | masks_moving = [] 35 | 36 | for i in tqdm(topk): 37 | file_img = f_img.replace('xxxx',str(i).zfill(4)) 38 | file_key = f_key.replace('xxxx',str(i).zfill(4)) 39 | file_mask = f_mask.replace('xxxx',str(i).zfill(4)) 40 | 41 | img_fixed = torch.from_numpy(nib.load(file_img).get_fdata()).float().cuda().contiguous() 42 | key_fixed = torch.from_numpy(np.loadtxt(file_key,delimiter=',')).float() 43 | mask_fixed = torch.from_numpy(nib.load(file_mask).get_fdata()).float().cuda().contiguous() 44 | imgs_fixed.append(img_fixed) 45 | keypts_fixed.append(key_fixed) 46 | masks_fixed.append(mask_fixed) 47 | 48 | file_img = file_img.replace('0000',str(1).zfill(4)) 49 | file_key = file_key.replace('0000',str(1).zfill(4)) 50 | file_mask = file_mask.replace('0000',str(1).zfill(4)) 51 | 52 | img_moving = torch.from_numpy(nib.load(file_img).get_fdata()).float().cuda().contiguous() 53 | key_moving = torch.from_numpy(np.loadtxt(file_key,delimiter=',')).float() 54 | mask_moving = torch.from_numpy(nib.load(file_mask).get_fdata()).float().cuda().contiguous() 55 | imgs_moving.append(img_moving) 56 | keypts_moving.append(key_moving) 57 | masks_moving.append(mask_moving) 58 | 59 | return imgs_fixed,keypts_fixed,masks_fixed,imgs_moving,keypts_moving,masks_moving 60 | 61 | def main(gpunum,configfile): 62 | 63 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 64 | os.environ['CUDA_VISIBLE_DEVICES'] = str((gpunum)) 65 | print(torch.cuda.get_device_name()) 66 | 67 | with open(configfile, 'r') as f: 68 | config = json.load(f) 69 | topk = config['topk'] 70 | 71 | 72 | #starting full run 0 out of 100 73 | # setting mind_r 3 mind_d 2 grid_sp 3 disp_hw 7 74 | 
#100%|███████████████████████████████████████████████████████████████████████████████████| 15/15 [00:38<00:00, 2.59s/it] 75 | #s 0 1.623 1.865 jstd tensor(0.0968) 76 | #tensor(2) 77 | #tensor([1.9585, 2.1685]) tensor([0.0729, 0.0000]) tensor(3.5369) 78 | #tensor([2., 2., 4., 4.]) 79 | 80 | 81 | #topk_pair = config['topk_pair']# ((2,4),(4,9),(3,4),(0,4),(1,4),(4,7),(4,5),(2,8)) 82 | #topk_pair = [] 83 | #for i in range(0,10): 84 | # for j in range(0,10): 85 | # if(i',TRE1.mean()) 146 | t_mind[s] += t1-t0 147 | t_convex[s] += t2-t1 148 | dice[s,0] += 1/len(topk_pair)*DICE1.mean() 149 | dice[s,1] += 1/len(topk_pair)*DICE1[robust30[i]].mean() 150 | jac_det_log = jac_det.add(3).clamp_(0.000000001, 1000000000).log()#.std() 151 | jstd[s,0] += 1/len(topk_pair)*(jac_det_log).std().cpu() 152 | jstd[s,1] += 1/len(topk_pair)*((jac_det<0).float().mean()).cpu() 153 | hd95[s] += 1/len(topk_pair)*(HD95).mean().cpu() 154 | 155 | 156 | torch.save([dice,jstd,hd95,t_convex],config['output']) 157 | 158 | if(dice[s,0]>dice_min): 159 | print('s',s,'%0.3f'%dice[s,0].item(),'%0.3f'%dice[s,1].item(),'jstd',jstd[s,0]) 160 | dice_min = dice[s,0] 161 | 162 | rank1 = sort_rank(-dice[:,0]) 163 | rank1 *= sort_rank(-dice[:,1]) 164 | 165 | rank1 *= sort_rank(hd95[:]) 166 | rank1 *= sort_rank(jstd[:,0])#.sqrt() 167 | 168 | rank1 = rank1.pow(1/4) 169 | print(rank1.argmax()) 170 | print(dice[rank1.argmax()],jstd[rank1.argmax()],t_convex[rank1.argmax()]) 171 | print(settings[rank1.argmax()]) 172 | torch.save([rank1,dice,jstd,hd95,t_convex],config['output']) 173 | 174 | 175 | 176 | use_mask = True 177 | if __name__ == "__main__": 178 | gpu_id = int(sys.argv[1]) 179 | configfile = str(sys.argv[2]) 180 | main(gpu_id,configfile) 181 | -------------------------------------------------------------------------------- /self_configuring/infer_convexadam.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | import time 4 | import warnings 5 | 6 | 
import nibabel as nib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

warnings.filterwarnings("ignore")
import os

from convexAdam_hyper_util import (GaussianSmoothing, correlate,
                                   coupled_convex, extract_features_nnunet,
                                   inverse_consistency, kovesi_spline)
from tqdm.auto import trange


def get_data_train(topk,HWD,f_predict,f_gt):
    """Load nnU-Net label predictions for the requested cases onto the GPU.

    Despite the name, this loads *test* data: the 'sTr/' path component of
    ``f_predict``/``f_gt`` is rewritten to 'sTs/' before loading.

    Args:
        topk: iterable of case indices; each is zero-padded to 4 digits and
            substituted for the 'xxxx' placeholder in ``f_predict``.
        HWD: 3-sequence of volume dimensions (unpacked but not otherwise
            used here).
        f_predict: path template of the predicted-label NIfTI files.
        f_gt: path template of ground-truth files (effectively unused —
            ground truth is not available for test scans, see comments).

    Returns:
        (preds_fixed, segs_fixed): list of CUDA float tensors with the
        loaded predictions, and a parallel list of ``None`` placeholders
        (kept so callers that expect a segmentation list still work).
    """
    l2r_base_folder = './'
    print('reading test data')
    # switch from training ('sTr/') to test ('sTs/') directories
    f_predict = f_predict.replace('sTr/','sTs/')
    f_gt = f_gt.replace('sTr/','sTs/')

    H,W,D = HWD[0],HWD[1],HWD[2]
    #robustify
    preds_fixed = []
    segs_fixed = []
    preds_moving = []
    segs_moving = []

    for i in topk:
        file_pred = f_predict.replace('xxxx',str(i).zfill(4))
        #file_seg = f_gt.replace('xxxx',str(i).zfill(4)) #not available for test-scans

        pred_fixed = torch.from_numpy(nib.load(file_pred).get_fdata()).float().cuda().contiguous()
        #seg_fixed = torch.from_numpy(nib.load(file_seg).get_fdata()).float().cuda().contiguous()
        segs_fixed.append(None)#seg_fixed)
        preds_fixed.append(pred_fixed)




    return preds_fixed,segs_fixed
def main(gpunum,configfile,convex_s,adam_s1,adam_s2):
    """Run ConvexAdam test-set inference with predetermined hyper-parameter settings.

    Two stages per registration pair: (1) discrete coupled-convex
    optimisation on downsampled nnU-Net features, made inverse-consistent;
    (2) Adam-based instance optimisation with diffusion regularisation and
    B-spline/Gaussian smoothing. Writes one displacement-field NIfTI per
    pair into '<stem>/fieldsTs/'.

    Args:
        gpunum: CUDA device index, pinned via CUDA_VISIBLE_DEVICES.
        configfile: JSON config with keys 'test', 'num_labels', 'test_pair',
            'HWD', 'f_predict', 'f_gt'.
        convex_s: row index into the seeded random grid of convex settings.
        adam_s1: row index into the seeded random grid of Adam settings.
        adam_s2: encodes iteration count (adam_s2//4) and number of extra
            smoothing passes (adam_s2%4).
    """

    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ['CUDA_VISIBLE_DEVICES'] = str((gpunum))
    print(torch.cuda.get_device_name())

    with open(configfile, 'r') as f:
        config = json.load(f)
    topk = config['test']

    num_labels = config['num_labels']-1
    eval_labels = num_labels
    topk_pair = config['test_pair']

    print('using all test',len(topk_pair),' registration pairs')
    preds_fixed,segs_fixed = get_data_train(topk,config['HWD'],config['f_predict'],config['f_gt'])


    # Re-create the pseudo-random hyper-parameter table used during
    # self-configuration; the fixed seed makes the table reproducible, so
    # convex_s indexes the same setting that was selected on training data.
    torch.manual_seed(1004)
    settings = (torch.rand(100,3)*torch.tensor([6,4,6])+torch.tensor([.5,1.5,1.5])).round()
    #print(settings[1])
    settings[:,0] *= 2.5
    # for grid spacing 2, cap the displacement search radius at 5
    settings[settings[:,1]==2,2] = torch.minimum(settings[settings[:,1]==2,2],torch.tensor([5]))
    print(settings.min(0).values,settings.max(0).values,)

    print('using predetermined setting s=',convex_s)

    nn_mult = int(settings[convex_s,0])#1
    grid_sp = int(settings[convex_s,1])#6
    disp_hw = int(settings[convex_s,2])#4
    print('setting nn_mult',nn_mult,'grid_sp',grid_sp,'disp_hw',disp_hw)

    ##APPLY BEST CONVEX TO TRAIN
    # Stage 1: coarse, inverse-consistent convex displacement per pair,
    # cached on CPU in disps_lr for the Adam stage below.
    disps_lr = []
    for i in trange(len(topk_pair)):

        t0 = time.time()

        pred_fixed = preds_fixed[int(topk_pair[i][0])].float()
        pred_moving = preds_fixed[int(topk_pair[i][1])].float()

        H, W, D = pred_fixed.shape[-3:]
        grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H,W,D),align_corners=False)
        torch.cuda.synchronize()
        t0 = time.time()

        # compute features and downsample (using average pooling)
        with torch.no_grad():

            features_fix, features_mov = extract_features_nnunet(pred_fixed=pred_fixed,
                                                                 pred_moving=pred_moving)

            features_fix_smooth = F.avg_pool3d(features_fix,grid_sp,stride=grid_sp)
            features_mov_smooth = F.avg_pool3d(features_mov,grid_sp,stride=grid_sp)

            n_ch = features_fix_smooth.shape[1]
        t1 = time.time()
        #with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        # compute correlation volume with SSD
        ssd,ssd_argmin = correlate(features_fix_smooth,features_mov_smooth,disp_hw,grid_sp,(H,W,D), n_ch)

        # provide auxiliary mesh grid
        disp_mesh_t = F.affine_grid(disp_hw*torch.eye(3,4).cuda().half().unsqueeze(0),(1,1,disp_hw*2+1,disp_hw*2+1,disp_hw*2+1),align_corners=True).permute(0,4,1,2,3).reshape(3,-1,1)

        # perform coupled convex optimisation
        disp_soft = coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp,(H,W,D))

        # if "ic" flag is set: make inverse consistent
        # (here always done: forward and backward fields are averaged)
        scale = torch.tensor([H//grid_sp-1,W//grid_sp-1,D//grid_sp-1]).view(1,3,1,1,1).cuda().half()/2

        ssd_,ssd_argmin_ = correlate(features_mov_smooth,features_fix_smooth,disp_hw,grid_sp,(H,W,D), n_ch)

        disp_soft_ = coupled_convex(ssd_,ssd_argmin_,disp_mesh_t,grid_sp,(H,W,D))
        disp_ice,_ = inverse_consistency((disp_soft/scale).flip(1),(disp_soft_/scale).flip(1),iter=15)
        # back from normalised [-1,1] coordinates to voxel units
        disp_lr = disp_ice.flip(1)*scale*grid_sp
        disps_lr.append(disp_lr.data.cpu())
        disp_hr = F.interpolate(disp_lr,size=(H,W,D),mode='trilinear',align_corners=False)
        t2 = time.time()
        scale1 = torch.tensor([D-1,W-1,H-1]).cuda()/2  # NOTE(review): unused here


        # free GPU memory before the next pair
        del disp_soft; del disp_soft_; del ssd_; del ssd; del disp_hr; del features_fix; del features_mov; del features_fix_smooth; del features_mov_smooth;


    ##FIND OPTIMAL ADAM SETTING
    # bank of smoothing operators indexed by avg_n below
    avgs = [GaussianSmoothing(.7).cuda(),GaussianSmoothing(1).cuda(),kovesi_spline(1.3,4).cuda(),kovesi_spline(1.6,4).cuda(),kovesi_spline(1.9,4).cuda(),kovesi_spline(2.2,4).cuda()]

    # same reproducible-table trick as above, for the Adam stage
    torch.manual_seed(2004)
    settings_adam = (torch.rand(75,3)*torch.tensor([4,5,7])+torch.tensor([0.5,.5,1.5])).round()
    settings_adam[:,2] *= .2


    torch.cuda.empty_cache()


    grid_sp_adam = int(settings_adam[adam_s1,0])#6
    avg_n = int(settings_adam[adam_s1,1])#6
    ##SHIFT-SPLINE
    # finer optimisation grids get stronger smoothing
    if(grid_sp_adam==1):
        avg_n += 2
    if(grid_sp_adam==2):
        avg_n += 1

    lambda_weight = float(settings_adam[adam_s1,2])#4

    iters = (adam_s2//4)*20+60
    kks = (adam_s2%4)
    print('setting grid_sp_adam',grid_sp_adam,'avg_n',avg_n,'lambda_weight',lambda_weight,'iters',iters,'kks',kks)

    # metric buffers (not filled in this test-set script; kept for parity
    # with the training/validation variant)
    jstd2 = torch.zeros(len(topk_pair),2)
    dice2 = torch.zeros(len(topk_pair),eval_labels)
    hd95_2 = torch.zeros(len(topk_pair),eval_labels)

    dice_ident = torch.zeros(len(topk_pair),eval_labels)

    # Stage 2: Adam instance optimisation, initialised from the cached
    # convex displacement, then written out as a NIfTI field per pair.
    for i in trange(len(topk_pair)):
        file = config['f_gt'].split('/')[-1]
        stem = file.split('_')[0]
        # e.g. disp_<moving>_<fixed>.nii.gz under <stem>/fieldsTs/
        file_field = stem+'/fieldsTs/'+file.replace(stem,'disp').replace('0000',str(int(topk[topk_pair[i][1]])).zfill(4)).replace('xxxx',str(int(topk[topk_pair[i][0]])).zfill(4))
        print('writing output-nii to ',file_field)
        t0 = time.time()


        pred_fixed = preds_fixed[int(topk_pair[i][0])].float()
        pred_moving = preds_fixed[int(topk_pair[i][1])].float()
        #seg_fixed = segs_fixed[int(topk_pair[i][0])].float()
        #seg_moving = segs_fixed[int(topk_pair[i][1])].float()
        #mind_r = int(settings_adam[s,0])#1
        #mind_d = int(settings_adam[s,1])#2

        t0 = time.time()


        H, W, D = pred_fixed.shape[-3:]
        grid0_hr = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H,W,D),align_corners=False)
        torch.cuda.synchronize()
        t0 = time.time()

        # compute features and downsample (using average pooling)
        with torch.no_grad():

            features_fix, features_mov = extract_features_nnunet(pred_fixed=pred_fixed,
                                                                 pred_moving=pred_moving)


            n_ch = features_fix.shape[1]
        # run Adam instance optimisation
        with torch.no_grad():
            patch_features_fix = F.avg_pool3d(features_fix,grid_sp_adam,stride=grid_sp_adam)
            patch_features_mov = F.avg_pool3d(features_mov,grid_sp_adam,stride=grid_sp_adam)

            disp_hr = F.interpolate(disps_lr[i].float().cuda(),size=(H,W,D),mode='trilinear',align_corners=False)

            #create optimisable displacement grid
            disp_lr = F.interpolate(disp_hr,size=(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),mode='trilinear',align_corners=False)

        # a 1x1 "network" whose weight IS the low-res displacement field;
        # Adam then optimises the field directly
        net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),bias=False))
        net[0].weight.data[:] = disp_lr.float().cpu().data/grid_sp_adam
        net.cuda()
        optimizer = torch.optim.Adam(net.parameters(), lr=1)

        grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),align_corners=False)
        #run Adam optimisation with diffusion regularisation and B-spline smoothing
        # NOTE(review): `iter` shadows the builtin; left unchanged here.
        for iter in range(iters):
            optimizer.zero_grad()

            # smoothed displacement sample, channels-last
            disp_sample = (avgs[avg_n](net[0].weight)).permute(0,2,3,4,1)#,3,stride=1,padding=1),3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1)
            # diffusion regulariser: squared finite differences along each axis
            reg_loss = lambda_weight*((disp_sample[0,:,1:,:]-disp_sample[0,:,:-1,:])**2).mean()+\
                lambda_weight*((disp_sample[0,1:,:,:]-disp_sample[0,:-1,:,:])**2).mean()+\
                lambda_weight*((disp_sample[0,:,:,1:]-disp_sample[0,:,:,:-1])**2).mean()

            scale = torch.tensor([(H//grid_sp_adam-1)/2,(W//grid_sp_adam-1)/2,(D//grid_sp_adam-1)/2]).cuda().unsqueeze(0)
            # identity grid plus normalised displacement (grid_sample expects x,y,z order, hence flip)
            grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/scale).flip(1).float()

            patch_mov_sampled = F.grid_sample(patch_features_mov.float(),grid_disp.view(1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam,3).cuda(),align_corners=False,mode='bilinear')

            # SSD data term between warped moving features and fixed features
            sampled_cost = (patch_mov_sampled-patch_features_fix).pow(2).mean(1)*n_ch
            loss = sampled_cost.mean()
            (loss+reg_loss).backward()
            optimizer.step()
        scale1 = torch.tensor([D-1,W-1,H-1]).cuda()/2  # NOTE(review): unused here

        with torch.no_grad():

            # upsample the fitted low-res field back to full resolution (voxel units)
            fitted_grid = disp_sample.detach().permute(0,4,1,2,3)
            disp_hr = F.interpolate(fitted_grid*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False)

            kernel_smooth = 3; padding_smooth = kernel_smooth//2

            # NOTE(review): the kk>0 guard means kks passes yield only kks-1
            # smoothing applications — presumably intentional, TODO confirm.
            for kk in range(kks):
                if(kk>0):
                    disp_hr = F.avg_pool3d(disp_hr,kernel_smooth,padding=padding_smooth,stride=1)
            #save displacement field
            nib.save(nib.Nifti1Image(disp_hr.permute(0,2,3,4,1).squeeze().data.cpu().numpy(),np.eye(4)),file_field)




use_mask = True
if __name__ == "__main__":
    # CLI: gpu_id configfile convex_s adam_s1 adam_s2
    gpu_id = int(sys.argv[1])
    configfile = (sys.argv[2])
    convex_s = int(sys.argv[3])
    adam_s1 = int(sys.argv[4])
    adam_s2 = int(sys.argv[5])
    main(gpu_id,configfile,convex_s,adam_s1,adam_s2)
-------------------------------------------------------------------------------- /self_configuring/main_for_l2r3_MIND.py: --------------------------------------------------------------------------------
import argparse
import json
import os
import warnings
from pathlib import Path

import nibabel as nib
import torch
from convex_adam_MIND import *
from L2R_main.evaluation import evaluation

warnings.filterwarnings("ignore")


def main(task_name,
         mind_r,
         mind_d,
         use_mask,
         lambda_weight,
         grid_sp,
         disp_hw,
         evaluate,
         data_dir,
         result_path,
         config_path):
    """Run ConvexAdam (MIND features) on a Learn2Reg validation split.

    For every registration pair of the task's 'registration_val' list this
    calls ``convex_adam`` once and saves nine displacement-field variants
    (40/60/80 Adam iterations x smoothing 0/3/5) into separate result
    folders, then optionally runs the official L2R evaluation on each.

    Args:
        task_name: L2R task folder name under ``data_dir``.
        mind_r, mind_d: MIND-SSC feature radius and dilation.
        use_mask: if True, load ROI masks from the 'masks' sibling folder.
        lambda_weight: diffusion-regularisation weight for the Adam stage.
        grid_sp: control-grid spacing of the convex stage.
        disp_hw: discrete displacement search half-width.
        evaluate: if True, run L2R evaluation on each result folder.
        data_dir, result_path, config_path: dataset root, output root, and
            folder holding the L2R evaluation configs.
    """

    task_dir = os.path.join(data_dir,task_name)
    dataset_json = os.path.join(task_dir,task_name+'_dataset.json')

    with open(dataset_json, 'r') as f:
        data = json.load(f)
    val_pairs = data['registration_val']

    # pick fixed/moving modality names (informational print-out only)
    if len(data['modality'].keys()) == 1:
        modality_fixed = data['modality']['0']
        modality_moving = data['modality']['0']

    if len(data['modality'].keys()) == 2:
        modality_fixed = data['modality']['0']
        modality_moving = data['modality']['1']

    if len(data['modality'].keys()) == 3:
        modality_fixed = data['modality']['0']
        modality_moving = data['modality']['2']

    # suffix encoding the hyper-parameters, used in metric filenames
    outstr = '_'+'MIND'+str(int(mind_r))+str(int(mind_d))+'_'+str(int(lambda_weight*100))+'lambda_'+str(grid_sp)+'gs1_'+str(disp_hw)+'disp_'+str(use_mask)+'Masks'
    print(outstr)

    print('>>> Modality fixed: ', modality_fixed)
    print('>>> Modality moving: ', modality_moving)
    print('>>> Settings: lambda_weight: {}; grid_sp: {}; disp_hw: {}'.format(lambda_weight, grid_sp, disp_hw))

    # create save directory
    # one folder per (iterations, smoothing) variant; pre-existing .nii/.nii.gz
    # results are deleted so stale fields never mix with the new run
    save_paths = ['40_smoothing0', '60_smoothing0', '80_smoothing0', '40_smoothing3', '60_smoothing3', '80_smoothing3', '40_smoothing5', '60_smoothing5', '80_smoothing5']
    for save_path in save_paths:
        new_path = os.path.join(result_path, task_name, save_path)
        isExist = os.path.exists(new_path)
        if not isExist:
            os.makedirs(new_path)
        files = os.listdir(new_path)
        for item in files:
            if item.endswith(".nii.gz"):
                os.remove(os.path.join(new_path, item))
            if item.endswith(".nii"):
                os.remove(os.path.join(new_path, item))


    case_times = torch.zeros(len(val_pairs))
    ii=0
    for _, pair in enumerate(val_pairs):
        path_fixed = os.path.join(task_dir, pair['fixed'])
        path_moving = os.path.join(task_dir, pair['moving'])
        img_fixed = torch.from_numpy(nib.load(path_fixed).get_fdata()).float()
        img_moving = torch.from_numpy(nib.load(path_moving).get_fdata()).float()
        if use_mask:
            path_fixed_mask = os.path.join(task_dir, pair['fixed'].replace('images','masks'))
            path_moving_mask = os.path.join(task_dir, pair['moving'].replace('images','masks'))
            mask_fixed = torch.from_numpy(nib.load(path_fixed_mask).get_fdata()).float()
            mask_moving = torch.from_numpy(nib.load(path_moving_mask).get_fdata()).float()
        else:
            mask_fixed = None
            mask_moving = None


        # one registration run returns all nine field variants at once
        displacements_40, displacements_60, displacements_80, displacements_40_smooth3, displacements_60_smooth3, displacements_80_smooth3, displacements_40_smooth5, displacements_60_smooth5, displacements_80_smooth5, case_time = convex_adam(img_fixed=img_fixed,
                                            img_moving=img_moving,
                                            mind_r=mind_r,
                                            mind_d=mind_d,
                                            use_mask=use_mask,
                                            mask_fixed=mask_fixed,
                                            mask_moving=mask_moving,
                                            lambda_weight=lambda_weight,
                                            grid_sp=grid_sp,
                                            disp_hw=disp_hw)

        case_times[ii] = case_time
        ii+=1


        affine = nib.load(path_fixed).affine

        # NOTE(review): pair['fixed'][-16:-12] assumes filenames shaped like
        # '..._NNNN_0000.nii.gz' — confirm against the dataset naming scheme.
        # save smoothing 0
        #disp_path = os.path.join(result_path, task_name, '40_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '40_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_40, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '60_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '60_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_60, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '80_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '80_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_80, affine)
        nib.save(disp_nii, disp_path)

        # save smoothing 3
        #disp_path = os.path.join(result_path, task_name, '40_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '40_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_40_smooth3, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '60_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '60_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_60_smooth3, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '80_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '80_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_80_smooth3, affine)
        nib.save(disp_nii, disp_path)

        # save smoothing 5
        #disp_path = os.path.join(result_path, task_name, '40_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '40_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_40_smooth5, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '60_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '60_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_60_smooth5, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '80_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '80_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_80_smooth5, affine)
        nib.save(disp_nii, disp_path)


    median_case_time = case_times.median().item()
    print('median case time: ', median_case_time)

    if evaluate:
        print('>>>EVALUATION')
        DEFAULT_GROUND_TRUTH_PATH = Path(task_dir)
        eval_config_path = config_path+task_name+'_VAL_evaluation_config.json'

        # evaluate each of the nine variant folders, then inject the median
        # runtime into the resulting metrics JSON
        for save_path in save_paths:
            DEFAULT_INPUT_PATH = Path(os.path.join(result_path, task_name,save_path))
            DEFAULT_EVALUATION_OUTPUT_FILE_PATH = Path(os.path.join(result_path, task_name,save_path)+'/metrics'+outstr+'.json')
            evaluation.evaluate_L2R(DEFAULT_INPUT_PATH, DEFAULT_GROUND_TRUTH_PATH, DEFAULT_EVALUATION_OUTPUT_FILE_PATH, eval_config_path, verbose=False)

            with open(DEFAULT_EVALUATION_OUTPUT_FILE_PATH, "r") as jsonFile:
                data = json.load(jsonFile)

            data[task_name]["aggregates"]["median_case_time"] = median_case_time

            with open(DEFAULT_EVALUATION_OUTPUT_FILE_PATH, "w") as jsonFile:
                json.dump(data, jsonFile)

            print('Path to evaluation JSON file: ', DEFAULT_EVALUATION_OUTPUT_FILE_PATH)

    else:
        print('NO EVALUATION')

if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', type=str, required=True)
    parser.add_argument('--lambda_weight', type=float, required=True)
    parser.add_argument('--grid_sp', type=int, required=True)
    parser.add_argument('--disp_hw', type=int, required=True)
    parser.add_argument('--mind_r', type=int, default=1)
    parser.add_argument('--mind_d', type=int, default=2)
    parser.add_argument('--use_mask', choices=('True','False'), default= 'False')
    parser.add_argument('--evaluate', choices=('True','False'), default= 'True')
    parser.add_argument('--data_dir', type=str, default='/share/data_zoe3/grossbroehmer/Learn2Reg2022/Learn2Reg_Dataset_v11/')
    parser.add_argument('--result_path', type=str, default='/share/data_abby2/hsiebert/code/adam_optimisation/JournalExperiments/l2r2022/results/')
    parser.add_argument('--config_path', type=str, default='/share/data_abby2/hsiebert/code/adam_optimisation/JournalExperiments/l2r2022/L2R_main/evaluation/evaluation_configs/')
    args = parser.parse_args()


    # string flags ('True'/'False') are converted to real booleans here
    if args.evaluate == 'True':
        evaluate=True
    else:
        evaluate=False

    if args.use_mask == 'True':
        use_mask=True
    else:
        use_mask=False



    task_name = args.task_name
    data_dir = args.data_dir
    mind_r = args.mind_r
    mind_d = args.mind_d
    lambda_weight = args.lambda_weight
    grid_sp = args.grid_sp
    disp_hw = args.disp_hw
    result_path = args.result_path
    config_path = args.config_path

    main(task_name,
         mind_r,
         mind_d,
         use_mask,
         lambda_weight,
         grid_sp,
         disp_hw,
         evaluate,
         data_dir,
         result_path,
         config_path)
-------------------------------------------------------------------------------- /self_configuring/main_for_l2r3_MIND_testset.py: --------------------------------------------------------------------------------
import argparse
import json
import os
import warnings

import nibabel as nib
import torch
from convex_adam_MIND_testset import *

warnings.filterwarnings("ignore")


def main(task_name,
         mind_r,
         mind_d,
         use_mask,
         lambda_weight,
         grid_sp,
         disp_hw,
         selected_niter,
         selected_smooth,
         data_dir,
         result_path):
    """Run ConvexAdam (MIND features) on a Learn2Reg *test* split.

    Unlike the validation variant, the iteration count and smoothing level
    were already selected (``selected_niter``/``selected_smooth``), so only
    one displacement field per pair is produced and no evaluation is run
    (test ground truth is not available).

    Args:
        task_name: L2R task folder name under ``data_dir``.
        mind_r, mind_d: MIND-SSC feature radius and dilation.
        use_mask: if True, load ROI masks from the 'masks' sibling folder.
        lambda_weight, grid_sp, disp_hw: registration hyper-parameters.
        selected_niter: chosen number of Adam iterations.
        selected_smooth: chosen smoothing kernel size.
        data_dir, result_path: dataset root and output root.
    """

    task_dir = os.path.join(data_dir,task_name)
    dataset_json = os.path.join(task_dir,task_name+'_dataset.json')

    with open(dataset_json, 'r') as f:
        data = json.load(f)
    val_pairs = data['registration_test']

    # create save directory
    # stale .nii/.nii.gz results are removed first
    save_paths = ['results_testset']
    for save_path in save_paths:
        new_path = os.path.join(result_path, task_name, save_path)
        isExist = os.path.exists(new_path)
        if not isExist:
            os.makedirs(new_path)
        files = os.listdir(new_path)
        for item in files:
            if item.endswith(".nii.gz"):
                os.remove(os.path.join(new_path, item))
            if item.endswith(".nii"):
                os.remove(os.path.join(new_path, item))


    case_times = torch.zeros(len(val_pairs))
    ii=0
    for _, pair in enumerate(val_pairs):
        path_fixed = os.path.join(task_dir, pair['fixed'])
        path_moving = os.path.join(task_dir, pair['moving'])
        img_fixed = torch.from_numpy(nib.load(path_fixed).get_fdata()).float()
        img_moving = torch.from_numpy(nib.load(path_moving).get_fdata()).float()
        if use_mask:
            path_fixed_mask = os.path.join(task_dir, pair['fixed'].replace('images','masks'))
            path_moving_mask = os.path.join(task_dir, pair['moving'].replace('images','masks'))
            mask_fixed = torch.from_numpy(nib.load(path_fixed_mask).get_fdata()).float()
            mask_moving = torch.from_numpy(nib.load(path_moving_mask).get_fdata()).float()
        else:
            mask_fixed = None
            mask_moving = None


        displacements, case_time = convex_adam(img_fixed=img_fixed,
                                               img_moving=img_moving,
                                               mind_r=mind_r,
                                               mind_d=mind_d,
                                               use_mask=use_mask,
                                               mask_fixed=mask_fixed,
                                               mask_moving=mask_moving,
                                               lambda_weight=lambda_weight,
                                               grid_sp=grid_sp,
                                               disp_hw=disp_hw,
                                               selected_niter=selected_niter,
                                               selected_smooth=selected_smooth)

        case_times[ii] = case_time
        ii+=1

        affine = nib.load(path_fixed).affine

        # NOTE(review): [-16:-12] slicing assumes '..._NNNN_0000.nii.gz'-style
        # filenames — confirm against the dataset naming scheme.
        disp_path = os.path.join(result_path, task_name, 'results_testset', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_nii = nib.Nifti1Image(displacements, affine)
        nib.save(disp_nii, disp_path)

    median_case_time = case_times.median().item()
    print('median case time: ', median_case_time)

if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', type=str, required=True)
    parser.add_argument('--lambda_weight', type=float, required=True)
    parser.add_argument('--grid_sp', type=int, required=True)
    parser.add_argument('--disp_hw', type=int, required=True)
    parser.add_argument('--mind_r', type=int, default=1)
    parser.add_argument('--mind_d', type=int,
                        default=2)
    parser.add_argument('--selected_niter', type=int, required=True)
    parser.add_argument('--selected_smooth', type=int, required=True)
    parser.add_argument('--use_mask', choices=('True','False'), default= 'False')
    parser.add_argument('--data_dir', type=str, default='/share/data_zoe3/grossbroehmer/Learn2Reg2022/Learn2Reg_Dataset_v11/')
    parser.add_argument('--result_path', type=str, default='/share/data_abby2/hsiebert/code/adam_optimisation/JournalExperiments/l2r2022/results/')
    args = parser.parse_args()

    # string flags ('True'/'False') are converted to real booleans here
    if args.use_mask == 'True':
        use_mask=True
    else:
        use_mask=False

    task_name = args.task_name
    data_dir = args.data_dir
    mind_r = args.mind_r
    mind_d = args.mind_d
    lambda_weight = args.lambda_weight
    grid_sp = args.grid_sp
    disp_hw = args.disp_hw
    result_path = args.result_path
    selected_niter = args.selected_niter
    selected_smooth = args.selected_smooth

    main(task_name,
         mind_r,
         mind_d,
         use_mask,
         lambda_weight,
         grid_sp,
         disp_hw,
         selected_niter,
         selected_smooth,
         data_dir,
         result_path)
-------------------------------------------------------------------------------- /self_configuring/main_for_l2r3_nnUNet.py: --------------------------------------------------------------------------------
from convex_adam_nnUNet import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
import argparse
from pathlib import Path
import json
import os
from L2R_main.evaluation import evaluation
import warnings
warnings.filterwarnings("ignore")


def main(task_name,
         mind_r,
         mind_d,
         use_mask,
         lambda_weight,
         grid_sp,
         disp_hw,
         evaluate,
         data_dir,
         result_path,
         config_path):
    """Run ConvexAdam (nnU-Net label features) on a Learn2Reg validation split.

    Mirrors the MIND variant, but feeds precomputed nnU-Net label
    predictions ('predictedlabels' folder) into ``convex_adam`` instead of
    MIND parameters. Saves nine displacement-field variants per pair
    (40/60/80 Adam iterations x smoothing 0/3/5) and optionally runs the
    official L2R evaluation on each result folder.

    Args:
        task_name: L2R task folder name under ``data_dir``.
        mind_r, mind_d: accepted for CLI parity; not passed to convex_adam here.
        use_mask: if True, load ROI masks from the 'masks' sibling folder.
        lambda_weight, grid_sp, disp_hw: registration hyper-parameters.
        evaluate: if True, run L2R evaluation on each result folder.
        data_dir, result_path, config_path: dataset root, output root, and
            folder holding the L2R evaluation configs.
    """

    task_dir = os.path.join(data_dir,task_name)
    dataset_json = os.path.join(task_dir,task_name+'_dataset.json')

    with open(dataset_json, 'r') as f:
        data = json.load(f)
    val_pairs = data['registration_val']

    # pick fixed/moving modality names (informational print-out only)
    if len(data['modality'].keys()) == 1:
        modality_fixed = data['modality']['0']
        modality_moving = data['modality']['0']

    if len(data['modality'].keys()) == 2:
        modality_fixed = data['modality']['0']
        modality_moving = data['modality']['1']

    if len(data['modality'].keys()) == 3:
        modality_fixed = data['modality']['0']
        modality_moving = data['modality']['2']

    # suffix encoding the hyper-parameters, used in metric filenames
    outstr = '_'+'nnUNet'+'_'+str(int(lambda_weight*100))+'lambda_'+str(grid_sp)+'gs1_'+str(disp_hw)+'disp_'+str(use_mask)+'Masks'
    print(outstr)

    print('>>> Modality fixed: ', modality_fixed)
    print('>>> Modality moving: ', modality_moving)
    print('>>> Settings: lambda_weight: {}; grid_sp: {}; disp_hw: {}'.format(lambda_weight, grid_sp, disp_hw))

    # create save directory
    # one folder per (iterations, smoothing) variant; stale results removed
    save_paths = ['40_smoothing0', '60_smoothing0', '80_smoothing0', '40_smoothing3', '60_smoothing3', '80_smoothing3', '40_smoothing5', '60_smoothing5', '80_smoothing5']
    for save_path in save_paths:
        new_path = os.path.join(result_path, task_name, save_path)
        isExist = os.path.exists(new_path)
        if not isExist:
            os.makedirs(new_path)
        files = os.listdir(new_path)
        for item in files:
            if item.endswith(".nii.gz"):
                os.remove(os.path.join(new_path, item))
            if item.endswith(".nii"):
                os.remove(os.path.join(new_path, item))


    case_times = torch.zeros(len(val_pairs))
    ii=0
    for _, pair in enumerate(val_pairs):
        path_fixed = os.path.join(task_dir, pair['fixed'])
        path_moving = os.path.join(task_dir, pair['moving'])
        img_fixed = torch.from_numpy(nib.load(path_fixed).get_fdata()).float()
        img_moving = torch.from_numpy(nib.load(path_moving).get_fdata()).float()

        # nnU-Net segmentations drive the feature extraction in this variant
        path_fixed_pred = os.path.join(task_dir, pair['fixed'].replace('images','predictedlabels'))
        path_moving_pred = os.path.join(task_dir, pair['moving'].replace('images','predictedlabels'))
        pred_fixed = torch.from_numpy(nib.load(path_fixed_pred).get_fdata()).float()
        pred_moving = torch.from_numpy(nib.load(path_moving_pred).get_fdata()).float()

        if use_mask:
            path_fixed_mask = os.path.join(task_dir, pair['fixed'].replace('images','masks'))
            path_moving_mask = os.path.join(task_dir, pair['moving'].replace('images','masks'))
            mask_fixed = torch.from_numpy(nib.load(path_fixed_mask).get_fdata()).float()
            mask_moving = torch.from_numpy(nib.load(path_moving_mask).get_fdata()).float()
        else:
            mask_fixed = None
            mask_moving = None


        # one registration run returns all nine field variants at once
        displacements_40, displacements_60, displacements_80, displacements_40_smooth3, displacements_60_smooth3, displacements_80_smooth3, displacements_40_smooth5, displacements_60_smooth5, displacements_80_smooth5, case_time = convex_adam(img_fixed=img_fixed,
                                            img_moving=img_moving,
                                            pred_fixed=pred_fixed,
                                            pred_moving=pred_moving,
                                            use_mask=use_mask,
                                            mask_fixed=mask_fixed,
                                            mask_moving=mask_moving,
                                            lambda_weight=lambda_weight,
                                            grid_sp=grid_sp,
                                            disp_hw=disp_hw)

        case_times[ii] = case_time
        ii+=1


        affine = nib.load(path_fixed).affine

        # NOTE(review): pair['fixed'][-16:-12] assumes filenames shaped like
        # '..._NNNN_0000.nii.gz' — confirm against the dataset naming scheme.
        # save smoothing 0
        #disp_path = os.path.join(result_path, task_name, '40_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '40_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_40, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '60_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '60_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_60, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '80_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '80_smoothing0', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_80, affine)
        nib.save(disp_nii, disp_path)

        # save smoothing 3
        #disp_path = os.path.join(result_path, task_name, '40_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '40_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_40_smooth3, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '60_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '60_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_60_smooth3, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '80_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '80_smoothing3', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_80_smooth3, affine)
        nib.save(disp_nii, disp_path)

        # save smoothing 5
        #disp_path = os.path.join(result_path, task_name, '40_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '40_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_40_smooth5, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '60_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '60_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_60_smooth5, affine)
        nib.save(disp_nii, disp_path)

        #disp_path = os.path.join(result_path, task_name, '80_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii.gz'))
        disp_path = os.path.join(result_path, task_name, '80_smoothing5', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12]+'.nii'))
        disp_nii = nib.Nifti1Image(displacements_80_smooth5, affine)
        nib.save(disp_nii, disp_path)


    median_case_time = case_times.median().item()
    print('median case time: ', median_case_time)

    if evaluate:
        print('>>>EVALUATION')
        DEFAULT_GROUND_TRUTH_PATH = Path(task_dir)
        eval_config_path = config_path+task_name+'_VAL_evaluation_config.json'

        # evaluate each of the nine variant folders, then inject the median
        # runtime into the resulting metrics JSON
        for save_path in save_paths:
            DEFAULT_INPUT_PATH = Path(os.path.join(result_path, task_name,save_path))
            DEFAULT_EVALUATION_OUTPUT_FILE_PATH = Path(os.path.join(result_path, task_name,save_path)+'/metrics'+outstr+'.json')
            evaluation.evaluate_L2R(DEFAULT_INPUT_PATH, DEFAULT_GROUND_TRUTH_PATH, DEFAULT_EVALUATION_OUTPUT_FILE_PATH, eval_config_path, verbose=False)

            with open(DEFAULT_EVALUATION_OUTPUT_FILE_PATH, "r") as jsonFile:
                data = json.load(jsonFile)

            data[task_name]["aggregates"]["median_case_time"] = median_case_time

            with open(DEFAULT_EVALUATION_OUTPUT_FILE_PATH, "w") as jsonFile:
                json.dump(data, jsonFile)

            print('Path to evaluation JSON file: ', DEFAULT_EVALUATION_OUTPUT_FILE_PATH)

    else:
        print('NO EVALUATION')

if __name__=="__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', type=str, required=True)
    parser.add_argument('--lambda_weight', type=float, required=True)
    parser.add_argument('--grid_sp', type=int, required=True)
    parser.add_argument('--disp_hw', type=int, required=True)
    parser.add_argument('--mind_r', type=int, default=1)
    parser.add_argument('--mind_d', type=int, default=2)
    parser.add_argument('--use_mask', choices=('True','False'), default= 'False')
    parser.add_argument('--evaluate', choices=('True','False'), default= 'True')
    parser.add_argument('--data_dir', type=str, default='/share/data_zoe3/grossbroehmer/Learn2Reg2022/Learn2Reg_Dataset_v11/')
    parser.add_argument('--result_path', type=str, default='/share/data_abby2/hsiebert/code/adam_optimisation/JournalExperiments/l2r2022/results/')
    parser.add_argument('--config_path', type=str, default='/share/data_abby2/hsiebert/code/adam_optimisation/JournalExperiments/l2r2022/L2R_main/evaluation/evaluation_configs/')
    args = parser.parse_args()


    # string flags ('True'/'False') are converted to real booleans here
    if args.evaluate == 'True':
        evaluate=True
    else:
        evaluate=False

    if args.use_mask == 'True':
        use_mask=True
    else:
        use_mask=False



    task_name = args.task_name
    data_dir = args.data_dir
    mind_r = args.mind_r
    mind_d = args.mind_d
    lambda_weight = args.lambda_weight
    grid_sp = args.grid_sp
    disp_hw = args.disp_hw
    result_path = args.result_path
    config_path = args.config_path

    main(task_name,
         mind_r,
         mind_d,
from convex_adam_nnUNet_testset import *
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
import argparse
import json
import os
import warnings

warnings.filterwarnings("ignore")


def main(task_name,
         mind_r,
         mind_d,
         use_mask,
         lambda_weight,
         grid_sp,
         disp_hw,
         selected_niter,
         selected_smooth,
         data_dir,
         result_path):
    """Run convexAdam (nnUNet-feature variant) on every 'registration_test' pair
    of a Learn2Reg task and save the resulting displacement fields.

    Parameters
    ----------
    task_name : str
        Learn2Reg task folder name (also the prefix of the dataset json).
    mind_r, mind_d : int
        MIND radius/dilation; accepted for interface parity with the MIND
        scripts but not forwarded to convex_adam here.
    use_mask : bool
        Whether to load and pass ROI masks alongside the images.
    lambda_weight, grid_sp, disp_hw, selected_niter, selected_smooth
        Registration hyper-parameters forwarded to convex_adam.
    data_dir, result_path : str
        Dataset root and output root directories.
    """
    task_dir = os.path.join(data_dir, task_name)
    dataset_json = os.path.join(task_dir, task_name + '_dataset.json')

    with open(dataset_json, 'r') as f:
        data = json.load(f)
    val_pairs = data['registration_test']

    # create the save directory and clear any stale results from earlier runs
    save_paths = ['results_testset']
    for save_path in save_paths:
        new_path = os.path.join(result_path, task_name, save_path)
        os.makedirs(new_path, exist_ok=True)
        for item in os.listdir(new_path):
            # '.nii.gz' does not end with '.nii', so both extensions are listed
            if item.endswith(('.nii.gz', '.nii')):
                os.remove(os.path.join(new_path, item))

    case_times = torch.zeros(len(val_pairs))
    for ii, pair in enumerate(val_pairs):
        path_fixed = os.path.join(task_dir, pair['fixed'])
        path_moving = os.path.join(task_dir, pair['moving'])
        img_fixed = torch.from_numpy(nib.load(path_fixed).get_fdata()).float()
        img_moving = torch.from_numpy(nib.load(path_moving).get_fdata()).float()

        # nnUNet predictions live next to the images under 'predictedlabels'
        path_fixed_pred = os.path.join(task_dir, pair['fixed'].replace('images', 'predictedlabels'))
        path_moving_pred = os.path.join(task_dir, pair['moving'].replace('images', 'predictedlabels'))
        pred_fixed = torch.from_numpy(nib.load(path_fixed_pred).get_fdata()).float()
        pred_moving = torch.from_numpy(nib.load(path_moving_pred).get_fdata()).float()

        if use_mask:
            path_fixed_mask = os.path.join(task_dir, pair['fixed'].replace('images', 'masks'))
            path_moving_mask = os.path.join(task_dir, pair['moving'].replace('images', 'masks'))
            mask_fixed = torch.from_numpy(nib.load(path_fixed_mask).get_fdata()).float()
            mask_moving = torch.from_numpy(nib.load(path_moving_mask).get_fdata()).float()
        else:
            mask_fixed = None
            mask_moving = None

        displacements, case_time = convex_adam(img_fixed=img_fixed,
                                               img_moving=img_moving,
                                               pred_fixed=pred_fixed,
                                               pred_moving=pred_moving,
                                               use_mask=use_mask,
                                               mask_fixed=mask_fixed,
                                               mask_moving=mask_moving,
                                               lambda_weight=lambda_weight,
                                               grid_sp=grid_sp,
                                               disp_hw=disp_hw,
                                               selected_niter=selected_niter,
                                               selected_smooth=selected_smooth)

        case_times[ii] = case_time

        affine = nib.load(path_fixed).affine

        # file name encodes the case ids sliced from the image paths
        # NOTE(review): the [-16:-12] slice assumes the L2R naming scheme — confirm for new tasks
        disp_path = os.path.join(result_path, task_name, 'results_testset', 'disp_{}_{}'.format(pair['fixed'][-16:-12], pair['moving'][-16:-12] + '.nii.gz'))
        disp_nii = nib.Nifti1Image(displacements, affine)
        nib.save(disp_nii, disp_path)

    median_case_time = case_times.median().item()
    print('median case time: ', median_case_time)
    print('displacements saved here: ', os.path.join(result_path, task_name, 'results_testset'))


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--task_name', type=str, required=True)
    parser.add_argument('--lambda_weight', type=float, required=True)
    parser.add_argument('--grid_sp', type=int, required=True)
    parser.add_argument('--disp_hw', type=int, required=True)
    parser.add_argument('--mind_r', type=int, default=1)
    parser.add_argument('--mind_d', type=int, default=2)
    parser.add_argument('--selected_niter', type=int, required=True)
    parser.add_argument('--selected_smooth', type=int, required=True)
    parser.add_argument('--use_mask', choices=('True', 'False'), default='False')
    parser.add_argument('--data_dir', type=str, default='/share/data_zoe3/grossbroehmer/Learn2Reg2022/Learn2Reg_Dataset_v11/')
    parser.add_argument('--result_path', type=str, default='/share/data_abby2/hsiebert/code/adam_optimisation/JournalExperiments/l2r2022/results/')
    args = parser.parse_args()

    if args.use_mask == 'True':
        use_mask = True
    else:
        use_mask = False

    task_name = args.task_name
    data_dir = args.data_dir
    mind_r = args.mind_r
    mind_d = args.mind_d
    lambda_weight = args.lambda_weight
    grid_sp = args.grid_sp
    disp_hw = args.disp_hw
    result_path = args.result_path
    selected_niter = args.selected_niter
    selected_smooth = args.selected_smooth

    main(task_name,
         mind_r,
         mind_d,
         use_mask,
         lambda_weight,
         grid_sp,
         disp_hw,
         selected_niter,
         selected_smooth,
         data_dir,
         result_path)
Programming Language :: Python :: 3.11 17 | 18 | [options] 19 | packages = 20 | convexAdam 21 | install_requires = 22 | nibabel 23 | numpy 24 | scikit-learn 25 | SimpleITK 26 | torch 27 | python_requires = >=3.6 28 | package_dir = 29 | =src 30 | zip_safe = no 31 | 32 | [options.extras_require] 33 | testing = 34 | pytest>=6.0 35 | pytest-cov>=2.0 36 | mypy>=0.910 37 | flake8>=3.9 38 | tox>=3.24 39 | 40 | [options.package_data] 41 | convexAdam = py.typed 42 | 43 | [flake8] 44 | max-line-length = 160 45 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import setuptools 2 | 3 | if __name__ == '__main__': 4 | with open("README.md", "r", encoding="utf-8") as fh: 5 | long_description = fh.read() 6 | 7 | setuptools.setup( 8 | version='0.2.0', 9 | author_email='heinrich@imi.uni-luebeck.de', 10 | long_description=long_description, 11 | long_description_content_type="text/markdown", 12 | url='https://github.com/multimodallearning/convexAdam', 13 | project_urls={ 14 | "Bug Tracker": "https://github.com/multimodallearning/convexAdam/issues" 15 | }, 16 | license='Apache 2.0', 17 | packages=['convexAdam'], 18 | ) 19 | -------------------------------------------------------------------------------- /src/convexAdam/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/src/convexAdam/__init__.py -------------------------------------------------------------------------------- /src/convexAdam/apply_convex.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Union 3 | 4 | import nibabel as nib 5 | import numpy as np 6 | import SimpleITK as sitk 7 | import torch 8 | from scipy.ndimage import map_coordinates 9 | 10 | from 
def apply_convex(
    disp: Union[torch.Tensor, np.ndarray, sitk.Image],
    moving: Union[torch.Tensor, np.ndarray, sitk.Image],
) -> np.ndarray:
    """Warp `moving` by the dense displacement field `disp`.

    `disp` has shape (d1, d2, d3, 3): one displacement vector (in voxels of
    the moving image grid) per voxel.  Sampling is trilinear (order=1).
    Returns the warped image as a numpy array of shape (d1, d2, d3).
    """
    # convert to numpy, if not already
    moving = validate_image(moving).numpy()
    disp = validate_image(disp).numpy()

    d1, d2, d3, _ = disp.shape
    # identity grid of voxel coordinates; adding disp yields absolute sample positions
    identity = np.meshgrid(np.arange(d1), np.arange(d2), np.arange(d3), indexing='ij')
    warped_image = map_coordinates(moving, disp.transpose(3, 0, 1, 2) + identity, order=1)
    return warped_image


def apply_convex_original_moving(
    disp: Union[torch.Tensor, np.ndarray, sitk.Image],
    moving_image_original: sitk.Image,
    fixed_image_original: sitk.Image,
    fixed_image_resampled: sitk.Image,
):
    """Apply displacement field to the moving image without resampling the moving image

    `disp` is defined on the grid of `fixed_image_resampled`; it is resampled
    into the physical space of `moving_image_original`, rotated from the fixed
    image's orientation into the moving image's orientation, rescaled to the
    moving image's voxel spacing, and then applied with `apply_convex`.
    Returns a sitk.Image aligned with `moving_image_original`.
    """
    # convert to numpy, if not already
    disp = validate_image(disp).numpy()

    # resample the displacement field to the physical space of the original moving image,
    # one vector component (channel) at a time since sitk resamples scalar images
    channels_resampled = []
    for i in range(3):
        displacement_field_channel = sitk.GetImageFromArray(disp[:, :, :, i])
        # stamp the fixed (resampled) geometry onto the channel so resampling is meaningful
        displacement_field_channel.CopyInformation(fixed_image_resampled)

        # set up the resampling filter
        resampler = sitk.ResampleImageFilter()
        resampler.SetReferenceImage(moving_image_original)
        resampler.SetInterpolator(sitk.sitkLinear)

        # apply resampling
        displacement_field_resampled = resampler.Execute(displacement_field_channel)

        # append to list of channels
        channels_resampled.append(displacement_field_resampled)

    # combine channels; JoinSeries stacks along a new leading axis, moveaxis puts it last
    displacement_field_resampled = sitk.JoinSeries(channels_resampled)
    displacement_field_resampled = np.moveaxis(sitk.GetArrayFromImage(displacement_field_resampled), 0, -1)

    # find the rotation between the direction of the moving image and the direction of the fixed image
    fixed_direction = np.array(fixed_image_original.GetDirection()).reshape(3, 3)
    moving_direction = np.array(moving_image_original.GetDirection()).reshape(3, 3)
    rotation = np.dot(np.linalg.inv(fixed_direction), moving_direction)

    # rotate the vectors in the displacement field (the z, y, x components are in the last dimension)
    displacement_field_resampled = displacement_field_resampled[..., ::-1]  # make the order x, y, z
    displacement_field_rotated = np.dot(displacement_field_resampled, rotation)
    displacement_field_rotated = displacement_field_rotated[..., ::-1]  # make the order z, y, x

    # adapt the displacement field to the original moving image, which has a different spacing
    # (per-axis spacing ratio, reversed to match the z, y, x component order)
    scaling_factor = np.array(fixed_image_resampled.GetSpacing()) / np.array(moving_image_original.GetSpacing())
    displacement_field_rescaled = displacement_field_rotated * list(scaling_factor)[::-1]

    moving_image_warped = apply_convex(
        disp=displacement_field_rescaled,
        moving=moving_image_original,
    )
    moving_image_warped = sitk.GetImageFromArray(moving_image_warped.astype(np.float32))
    moving_image_warped.CopyInformation(moving_image_original)
    return moving_image_warped
def extract_features(
    img_fixed: torch.Tensor,
    img_moving: torch.Tensor,
    mind_r: int,
    mind_d: int,
    use_mask: bool,
    mask_fixed: torch.Tensor,
    mask_moving: torch.Tensor,
    device: torch.device = torch.device("cuda"),
    dtype: torch.dtype = torch.float16,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Extract MIND and/or semantic nnUNet features"""

    # MIND features
    if use_mask:
        H,W,D = img_fixed.shape[-3:]

        #replicate masking
        # 3x3x3 box blur with replication padding, used to erode the binary mask
        avg3 = nn.Sequential(nn.ReplicationPad3d(1),nn.AvgPool3d(3,stride=1))
        avg3.to(device)

        # keep only voxels whose full 3x3x3 neighbourhood lies inside the mask (>0.9)
        mask = (avg3(mask_fixed.view(1,1,H,W,D).to(device))>0.9).float()
        # Euclidean distance transform on a half-resolution grid; idx maps each
        # outside voxel to its nearest inside voxel (inpainting source)
        _,idx = edt((mask[0,0,::2,::2,::2]==0).squeeze().cpu().numpy(),return_indices=True)
        # gather nearest-inside intensities via a flattened index, then upsample back
        # NOTE(review): the flat index idx[0]*D//2*W//2+idx[1]*D//2+idx[2] relies on
        # Python's * and // left-to-right evaluation — confirm it matches the
        # intended row-major (H, W, D) half-resolution layout
        fixed_r = F.interpolate((img_fixed[::2,::2,::2].to(device).reshape(-1)[idx[0]*D//2*W//2+idx[1]*D//2+idx[2]]).unsqueeze(0).unsqueeze(0),scale_factor=2,mode='trilinear')
        # restore the true intensities inside the mask; inpainted values remain outside
        fixed_r.view(-1)[mask.view(-1)!=0] = img_fixed.to(device).reshape(-1)[mask.view(-1)!=0]

        # same inpainting procedure for the moving image
        mask = (avg3(mask_moving.view(1,1,H,W,D).to(device))>0.9).float()
        _,idx = edt((mask[0,0,::2,::2,::2]==0).squeeze().cpu().numpy(),return_indices=True)
        moving_r = F.interpolate((img_moving[::2,::2,::2].to(device).reshape(-1)[idx[0]*D//2*W//2+idx[1]*D//2+idx[2]]).unsqueeze(0).unsqueeze(0),scale_factor=2,mode='trilinear')
        moving_r.view(-1)[mask.view(-1)!=0] = img_moving.to(device).reshape(-1)[mask.view(-1)!=0]

        features_fix = MINDSSC(fixed_r.to(device),mind_r,mind_d,device=device).to(dtype)
        features_mov = MINDSSC(moving_r.to(device),mind_r,mind_d,device=device).to(dtype)
    else:
        # no masking: feed the raw volumes (with batch/channel dims) to MIND-SSC
        img_fixed = img_fixed.unsqueeze(0).unsqueeze(0)
        img_moving = img_moving.unsqueeze(0).unsqueeze(0)
        features_fix = MINDSSC(img_fixed.to(device),mind_r,mind_d,device=device).to(dtype)
        features_mov = MINDSSC(img_moving.to(device),mind_r,mind_d,device=device).to(dtype)

    return features_fix, features_mov
def convex_adam_pt(
    img_fixed: Union[torch.Tensor, np.ndarray, sitk.Image, nib.Nifti1Image],
    img_moving: Union[torch.Tensor, np.ndarray, sitk.Image, nib.Nifti1Image],
    mind_r: int = 1,
    mind_d: int = 2,
    lambda_weight: float = 1.25,
    grid_sp: int = 6,
    disp_hw: int = 4,
    selected_niter: int = 80,
    selected_smooth: int = 0,
    grid_sp_adam: int = 2,
    ic: bool = True,
    use_mask: bool = False,
    path_fixed_mask: Optional[Union[Path, str]] = None,
    path_moving_mask: Optional[Union[Path, str]] = None,
    dtype: torch.dtype = torch.float16,
    verbose: bool = False,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
) -> np.ndarray:
    """Coupled convex optimisation with adam instance optimisation.

    Registers `img_moving` to `img_fixed` using MIND-SSC features: a discrete
    coupled-convex search (optionally made inverse-consistent), followed by an
    Adam-based instance optimisation with diffusion regularisation.

    Returns
    -------
    np.ndarray of shape (H, W, D, 3)
        Dense displacement field on the fixed-image grid (voxel units).
    """
    img_fixed = validate_image(img_fixed)
    img_moving = validate_image(img_moving)
    img_fixed = img_fixed.float()
    img_moving = img_moving.float()

    if dtype == torch.float16 and device == torch.device("cpu"):
        print("Warning: float16 is not supported on CPU, using float32 instead")
        dtype = torch.float32

    if use_mask:
        mask_fixed = torch.from_numpy(nib.load(path_fixed_mask).get_fdata()).float()
        mask_moving = torch.from_numpy(nib.load(path_moving_mask).get_fdata()).float()
    else:
        mask_fixed = None
        mask_moving = None

    H, W, D = img_fixed.shape

    t0 = time.time()

    # compute features and downsample (using average pooling)
    with torch.no_grad():
        features_fix, features_mov = extract_features(
            img_fixed=img_fixed,
            img_moving=img_moving,
            mind_r=mind_r,
            mind_d=mind_d,
            use_mask=use_mask,
            mask_fixed=mask_fixed,
            mask_moving=mask_moving,
            device=device,
            dtype=dtype,
        )

        features_fix_smooth = F.avg_pool3d(features_fix,grid_sp,stride=grid_sp)
        features_mov_smooth = F.avg_pool3d(features_mov,grid_sp,stride=grid_sp)

        n_ch = features_fix_smooth.shape[1]

    # compute correlation volume with SSD
    ssd,ssd_argmin = correlate(features_fix_smooth,features_mov_smooth,disp_hw,grid_sp,(H,W,D), n_ch)

    # provide auxiliary mesh grid (candidate displacements of the discrete search)
    disp_mesh_t = F.affine_grid(disp_hw*torch.eye(3,4).to(device).to(dtype).unsqueeze(0),(1,1,disp_hw*2+1,disp_hw*2+1,disp_hw*2+1),align_corners=True).permute(0,4,1,2,3).reshape(3,-1,1)

    # perform coupled convex optimisation
    disp_soft = coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp,(H,W,D))

    # if "ic" flag is set: make inverse consistent by also solving the reverse
    # direction and averaging forward/backward fields
    if ic:
        scale = torch.tensor([H//grid_sp-1,W//grid_sp-1,D//grid_sp-1]).view(1,3,1,1,1).to(device).to(dtype)/2

        ssd_,ssd_argmin_ = correlate(features_mov_smooth,features_fix_smooth,disp_hw,grid_sp,(H,W,D), n_ch)

        disp_soft_ = coupled_convex(ssd_,ssd_argmin_,disp_mesh_t,grid_sp,(H,W,D))
        disp_ice,_ = inverse_consistency((disp_soft/scale).flip(1),(disp_soft_/scale).flip(1),iter=15)

        disp_hr = F.interpolate(disp_ice.flip(1)*scale*grid_sp,size=(H,W,D),mode='trilinear',align_corners=False)
    else:
        disp_hr = disp_soft

    # run Adam instance optimisation
    if lambda_weight > 0:
        with torch.no_grad():
            patch_features_fix = F.avg_pool3d(features_fix,grid_sp_adam,stride=grid_sp_adam)
            patch_features_mov = F.avg_pool3d(features_mov,grid_sp_adam,stride=grid_sp_adam)

        # create optimisable displacement grid
        disp_lr = F.interpolate(disp_hr,size=(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),mode='trilinear',align_corners=False)

        # the Conv3d weight is (ab)used as the free displacement parameter tensor
        net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),bias=False))
        net[0].weight.data[:] = disp_lr.float().cpu().data/grid_sp_adam
        net.to(device)
        optimizer = torch.optim.Adam(net.parameters(), lr=1)

        grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).to(device),(1,1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),align_corners=False)

        # run Adam optimisation with diffusion regularisation and B-spline smoothing
        for iter in range(selected_niter):
            optimizer.zero_grad()

            # triple box filter approximates a B-spline smoothing of the grid
            disp_sample = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(net[0].weight,3,stride=1,padding=1),3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1)
            # diffusion regulariser: squared forward differences along each axis
            reg_loss = lambda_weight*((disp_sample[0,:,1:,:]-disp_sample[0,:,:-1,:])**2).mean()+\
                lambda_weight*((disp_sample[0,1:,:,:]-disp_sample[0,:-1,:,:])**2).mean()+\
                lambda_weight*((disp_sample[0,:,:,1:]-disp_sample[0,:,:,:-1])**2).mean()

            # map voxel displacements into grid_sample's normalised [-1, 1] space
            scale = torch.tensor([(H//grid_sp_adam-1)/2,(W//grid_sp_adam-1)/2,(D//grid_sp_adam-1)/2]).to(device).unsqueeze(0)
            grid_disp = grid0.view(-1,3).to(device).float()+((disp_sample.view(-1,3))/scale).flip(1).float()

            patch_mov_sampled = F.grid_sample(patch_features_mov.float(),grid_disp.view(1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam,3).to(device),align_corners=False,mode='bilinear')

            sampled_cost = (patch_mov_sampled-patch_features_fix).pow(2).mean(1)*12
            loss = sampled_cost.mean()
            (loss+reg_loss).backward()
            optimizer.step()

        fitted_grid = disp_sample.detach().permute(0,4,1,2,3)
        disp_hr = F.interpolate(fitted_grid*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False)

    if selected_smooth > 0:
        kernel_smooth = selected_smooth
        if kernel_smooth % 2 == 0:
            # BUG FIX: the original code computed selected_smooth+1 here but then
            # unconditionally overwrote it with selected_smooth, so an even kernel
            # was used despite the warning. An even kernel with padding k//2 and
            # stride 1 grows the volume by one voxel per avg_pool3d pass.
            kernel_smooth = selected_smooth + 1
            print('selected_smooth should be an odd number, adding 1')
        padding_smooth = kernel_smooth//2
        disp_hr = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1)

    t1 = time.time()
    case_time = t1-t0
    if verbose:
        print(f'case time: {case_time}')

    x = disp_hr[0,0,:,:,:].cpu().to(dtype).data.numpy()
    y = disp_hr[0,1,:,:,:].cpu().to(dtype).data.numpy()
    z = disp_hr[0,2,:,:,:].cpu().to(dtype).data.numpy()
    displacements = np.stack((x,y,z),3).astype(float)
    return displacements


def convex_adam(
    path_img_fixed: Union[Path, str],
    path_img_moving: Union[Path, str],
    mind_r: int = 1,
    mind_d: int = 2,
    lambda_weight: float = 1.25,
    grid_sp: int = 6,
    disp_hw: int = 4,
    selected_niter: int = 80,
    selected_smooth: int = 0,
    grid_sp_adam: int = 2,
    ic: bool = True,
    use_mask: bool = False,
    path_fixed_mask: Optional[Union[Path, str]] = None,
    path_moving_mask: Optional[Union[Path, str]] = None,
    result_path: Union[Path, str] = './',
    verbose: bool = False,
) -> None:
    """Coupled convex optimisation with adam instance optimisation.

    File-based wrapper around `convex_adam_pt`: loads fixed/moving NIfTI
    volumes, runs the registration, and saves the displacement field as
    'disp.nii.gz' in `result_path` (with the fixed image's affine).
    """
    img_fixed = torch.from_numpy(nib.load(path_img_fixed).get_fdata()).float()
    img_moving = torch.from_numpy(nib.load(path_img_moving).get_fdata()).float()

    displacements = convex_adam_pt(
        img_fixed=img_fixed,
        img_moving=img_moving,
        mind_r=mind_r,
        mind_d=mind_d,
        lambda_weight=lambda_weight,
        grid_sp=grid_sp,
        disp_hw=disp_hw,
        selected_niter=selected_niter,
        selected_smooth=selected_smooth,
        grid_sp_adam=grid_sp_adam,
        ic=ic,
        use_mask=use_mask,
        path_fixed_mask=path_fixed_mask,
        path_moving_mask=path_moving_mask,
        verbose=verbose,
    )

    affine = nib.load(path_img_fixed).affine
    disp_nii = nib.Nifti1Image(displacements, affine)
    nib.save(disp_nii, os.path.join(result_path,'disp.nii.gz'))
# process nnUNet features
def extract_features(pred_fixed: torch.Tensor,
                     pred_moving: torch.Tensor) -> tuple:
    """Turn nnUNet label maps into weighted one-hot feature volumes.

    Labels that occur in neither volume are dropped; remaining channels are
    weighted inversely to their (combined) frequency so that rare structures
    contribute comparably to the SSD cost. Returns half-precision CUDA
    tensors of shape (1, C, H, W, D) for fixed and moving.
    """
    eps=1e-32
    H,W,D = pred_fixed.shape[-3:]

    # combined label histogram over both volumes
    combined_bins = torch.bincount(pred_fixed.long().reshape(-1))+torch.bincount(pred_moving.long().reshape(-1))

    # labels that actually occur (in either volume)
    pos = torch.nonzero(combined_bins).reshape(-1)

    # one-hot encode and keep only the occurring label channels
    pred_fixed = F.one_hot(pred_fixed.cuda().view(1,H,W,D).long())[:,:,:,:,pos]
    pred_moving = F.one_hot(pred_moving.cuda().view(1,H,W,D).long())[:,:,:,:,pos]

    # inverse-frequency weighting (softened by the 0.3 exponent), normalised to mean 1
    weight = 1/((torch.bincount(pred_fixed.permute(0,4,1,2,3).argmax(1).long().reshape(-1))+torch.bincount(pred_moving.permute(0,4,1,2,3).argmax(1).long().reshape(-1)))+eps).float().pow(.3)
    weight /= weight.mean()

    # scale by 10 and cast to half precision; channels-first layout (1, C, H, W, D)
    features_fix = 10*(pred_fixed.data.float().permute(0,4,1,2,3).contiguous()*weight.view(1,-1,1,1,1).cuda()).half()
    features_mov = 10*(pred_moving.data.float().permute(0,4,1,2,3).contiguous()*weight.view(1,-1,1,1,1).cuda()).half()

    return features_fix, features_mov
# coupled convex optimisation with adam instance optimisation
def convex_adam(path_pred_fixed,
                path_pred_moving,
                lambda_weight,
                grid_sp,
                disp_hw,
                selected_niter,
                selected_smooth,
                grid_sp_adam=2,
                ic=True,
                result_path='./'):
    """Register two nnUNet label maps and save the displacement field.

    Loads the fixed/moving prediction volumes, runs the coupled-convex search
    on weighted one-hot features (optionally inverse-consistent), refines with
    Adam instance optimisation, and writes 'disp.nii.gz' to `result_path`.

    Parameters mirror the MIND variant; `selected_smooth` only takes effect
    for values 3 or 5 (kernel size of the final displacement smoothing).
    """
    pred_fixed = torch.from_numpy(nib.load(path_pred_fixed).get_fdata()).float()
    pred_moving = torch.from_numpy(nib.load(path_pred_moving).get_fdata()).float()

    H,W,D = pred_fixed.shape[-3:]

    torch.cuda.synchronize()
    t0 = time.time()

    # compute features and downsample (using average pooling)
    with torch.no_grad():

        features_fix, features_mov = extract_features(pred_fixed=pred_fixed,
                                                      pred_moving=pred_moving)

        features_fix_smooth = F.avg_pool3d(features_fix,grid_sp,stride=grid_sp)
        features_mov_smooth = F.avg_pool3d(features_mov,grid_sp,stride=grid_sp)

        n_ch = features_fix_smooth.shape[1]

    # compute correlation volume with SSD
    ssd,ssd_argmin = correlate(features_fix_smooth,features_mov_smooth,disp_hw,grid_sp,(H,W,D), n_ch)

    # provide auxiliary mesh grid (candidate displacements of the discrete search)
    disp_mesh_t = F.affine_grid(disp_hw*torch.eye(3,4).cuda().half().unsqueeze(0),(1,1,disp_hw*2+1,disp_hw*2+1,disp_hw*2+1),align_corners=True).permute(0,4,1,2,3).reshape(3,-1,1)

    # perform coupled convex optimisation
    disp_soft = coupled_convex(ssd,ssd_argmin,disp_mesh_t,grid_sp,(H,W,D))

    # if "ic" flag is set: make inverse consistent
    if ic:
        scale = torch.tensor([H//grid_sp-1,W//grid_sp-1,D//grid_sp-1]).view(1,3,1,1,1).cuda().half()/2

        ssd_,ssd_argmin_ = correlate(features_mov_smooth,features_fix_smooth,disp_hw,grid_sp,(H,W,D), n_ch)

        disp_soft_ = coupled_convex(ssd_,ssd_argmin_,disp_mesh_t,grid_sp,(H,W,D))
        disp_ice,_ = inverse_consistency((disp_soft/scale).flip(1),(disp_soft_/scale).flip(1),iter=15)

        disp_hr = F.interpolate(disp_ice.flip(1)*scale*grid_sp,size=(H,W,D),mode='trilinear',align_corners=False)
    else:
        disp_hr = disp_soft

    # run Adam instance optimisation
    if lambda_weight > 0:
        with torch.no_grad():
            patch_features_fix = F.avg_pool3d(features_fix,grid_sp_adam,stride=grid_sp_adam)
            patch_features_mov = F.avg_pool3d(features_mov,grid_sp_adam,stride=grid_sp_adam)

        # create optimisable displacement grid (Conv3d weight holds the parameters)
        disp_lr = F.interpolate(disp_hr,size=(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),mode='trilinear',align_corners=False)

        net = nn.Sequential(nn.Conv3d(3,1,(H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),bias=False))
        net[0].weight.data[:] = disp_lr.float().cpu().data/grid_sp_adam
        net.cuda()
        optimizer = torch.optim.Adam(net.parameters(), lr=1)

        grid0 = F.affine_grid(torch.eye(3,4).unsqueeze(0).cuda(),(1,1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam),align_corners=False)

        # run Adam optimisation with diffusion regularisation and B-spline smoothing
        for iter in range(selected_niter):
            optimizer.zero_grad()

            disp_sample = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(net[0].weight,3,stride=1,padding=1),3,stride=1,padding=1),3,stride=1,padding=1).permute(0,2,3,4,1)
            # diffusion regulariser: squared forward differences along each axis
            reg_loss = lambda_weight*((disp_sample[0,:,1:,:]-disp_sample[0,:,:-1,:])**2).mean()+\
                lambda_weight*((disp_sample[0,1:,:,:]-disp_sample[0,:-1,:,:])**2).mean()+\
                lambda_weight*((disp_sample[0,:,:,1:]-disp_sample[0,:,:,:-1])**2).mean()

            # map voxel displacements into grid_sample's normalised [-1, 1] space
            scale = torch.tensor([(H//grid_sp_adam-1)/2,(W//grid_sp_adam-1)/2,(D//grid_sp_adam-1)/2]).cuda().unsqueeze(0)
            grid_disp = grid0.view(-1,3).cuda().float()+((disp_sample.view(-1,3))/scale).flip(1).float()

            patch_mov_sampled = F.grid_sample(patch_features_mov.float(),grid_disp.view(1,H//grid_sp_adam,W//grid_sp_adam,D//grid_sp_adam,3).cuda(),align_corners=False,mode='bilinear')

            sampled_cost = (patch_mov_sampled-patch_features_fix).pow(2).mean(1)*12
            loss = sampled_cost.mean()
            (loss+reg_loss).backward()
            optimizer.step()

        fitted_grid = disp_sample.detach().permute(0,4,1,2,3)
        disp_hr = F.interpolate(fitted_grid*grid_sp_adam,size=(H,W,D),mode='trilinear',align_corners=False)

    # final smoothing of the displacement field: the original duplicated this
    # code for kernel sizes 5 and 3; collapsed into one branch with identical
    # behaviour (any other selected_smooth value applies no smoothing)
    if selected_smooth in (3, 5):
        kernel_smooth = selected_smooth
        padding_smooth = kernel_smooth//2
        disp_hr = F.avg_pool3d(F.avg_pool3d(F.avg_pool3d(disp_hr,kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1),kernel_smooth,padding=padding_smooth,stride=1)

    torch.cuda.synchronize()
    t1 = time.time()
    case_time = t1-t0
    print('case time: ', case_time)

    x = disp_hr[0,0,:,:,:].cpu().half().data.numpy()
    y = disp_hr[0,1,:,:,:].cpu().half().data.numpy()
    z = disp_hr[0,2,:,:,:].cpu().half().data.numpy()
    displacements = np.stack((x,y,z),3).astype(float)

    affine = nib.load(path_pred_fixed).affine
    disp_nii = nib.Nifti1Image(displacements, affine)
    nib.save(disp_nii, os.path.join(result_path,'disp.nii.gz'))
    return
def index_translation_to_world_translation(
    index_translation: Iterable[float],
    direction: Iterable[float]
) -> np.ndarray:
    """
    Map a translation expressed along the image grid axes onto world axes.

    The direction cosines are supplied as a flat, row-major sequence of
    length d*d (as returned by ``sitk.Image.GetDirection()``); the grid
    translation is rotated by that matrix.

    Args:
        index_translation (i, j, k): The translation along the grid axes (mm).
        direction: The flattened direction-cosine matrix of the image.
    Returns:
        world_translation (x, y, z): The translation in world coordinates (mm).
    """
    flat_direction = np.asarray(direction)
    dimension = int(np.sqrt(flat_direction.size))
    rotation = flat_direction.reshape(dimension, dimension)
    return rotation @ np.asarray(index_translation)
76 | """ 77 | 78 | # resample images to specified spacing and the field of view of the fixed image 79 | fixed_image_resampled = resample_img(fixed_image, spacing=(1.0, 1.0, 1.0)) 80 | moving_image_resampled = resample_moving_to_fixed(fixed_image_resampled, moving_image) 81 | 82 | # run convex adam 83 | displacementfield = convex_adam_pt( 84 | img_fixed=fixed_image_resampled, 85 | img_moving=moving_image_resampled, 86 | ) 87 | 88 | # convert displacement field to translation only 89 | if segmentation is not None: 90 | # resample segmentation to the same spacing as the displacement field 91 | segmentation = resample_moving_to_fixed(moving=segmentation, fixed=fixed_image_resampled) 92 | seg_arr = sitk.GetArrayFromImage(segmentation) 93 | seg_arr = (seg_arr > 0) # above resampling is with linear interpolation, so we need to threshold 94 | translation_zyx = np.mean(displacementfield[seg_arr], axis=0) 95 | else: 96 | translation_zyx = np.mean(displacementfield, axis=(0, 1, 2)) 97 | 98 | # transform translation into the number of pixels to move in image space 99 | spacing_zyx = np.array(list(moving_image.GetSpacing())[::-1]) 100 | translation_ijk = translation_zyx / spacing_zyx 101 | translation_ijk_voxels = np.round(translation_ijk, decimals=0) 102 | translation_ijk_mm = translation_ijk_voxels * spacing_zyx 103 | translation_xyz = tuple(list(translation_ijk_mm[::-1])) 104 | 105 | # apply translation to moving image 106 | moving_image = apply_translation(moving_image=moving_image, translation_ijk=translation_xyz) 107 | 108 | # apply translation to co-moving images 109 | if co_moving_images is not None: 110 | for i, co_moving_image in enumerate(co_moving_images): 111 | co_moving_image = apply_translation(moving_image=co_moving_image, translation_ijk=translation_xyz) 112 | co_moving_images[i] = co_moving_image 113 | 114 | return translation_xyz, moving_image, co_moving_images 115 | 116 | 117 | def convex_adam_translation_from_file( 118 | fixed_path: Path = 
Path("/input/fixed.mha"), 119 | moving_path: Path = Path("/input/moving.mha"), 120 | segmentation_path: Optional[Path] = Path("/input/segmentation.nii.gz"), 121 | moving_output_path: Path = Path("/output/moving_warped.mha"), 122 | co_moving_paths: Optional[Iterable[Path]] = None, 123 | co_moving_output_paths: Optional[Iterable[Path]] = None, 124 | ): 125 | # paths 126 | fixed_image = sitk.ReadImage(str(fixed_path)) 127 | moving_image = sitk.ReadImage(str(moving_path)) 128 | segmentation = sitk.ReadImage(str(segmentation_path)) if segmentation_path is not None else None 129 | 130 | translation_xyz, moving_image, co_moving_images = convex_adam_translation( 131 | fixed_image=fixed_image, 132 | moving_image=moving_image, 133 | segmentation=segmentation, 134 | co_moving_images=[sitk.ReadImage(str(path)) for path in co_moving_paths] if co_moving_paths is not None else None, 135 | ) 136 | 137 | # save moved image 138 | sitk.WriteImage(moving_image, str(moving_output_path)) 139 | 140 | # save co-moving images 141 | if co_moving_images is not None: 142 | for co_moving_image, co_moving_output_path in zip(co_moving_images, co_moving_output_paths): 143 | sitk.WriteImage(co_moving_image, str(co_moving_output_path)) 144 | 145 | return translation_xyz 146 | 147 | 148 | if __name__ == "__main__": 149 | # command line interface 150 | parser = argparse.ArgumentParser(description="Apply convex Adam translation to an image.") 151 | parser.add_argument("--fixed_path", type=Path, help="Path to the fixed image.") 152 | parser.add_argument("--moving_path", type=Path, help="Path to the moving image.") 153 | parser.add_argument("--segmentation_path", type=Path, help="Path to the segmentation.") 154 | parser.add_argument("--moving_output_path", type=Path, help="Path to the output moving image.") 155 | parser.add_argument("--co_moving_paths", type=Path, nargs="+", help="Paths to the co-moving images.") 156 | parser.add_argument("--co_moving_output_paths", type=Path, nargs="+", help="Paths to 
def rotate_image_around_center_affine(image: sitk.Image, angle: float) -> None:
    """
    Rotate the given image around its physical center by the specified angle, in place.

    Only the image metadata (direction cosines and origin) is changed; the
    voxel data are left untouched. The rotation axis is the third column of
    the direction-cosine matrix (the image's z-like axis).

    Parameters:
    image (sitk.Image): The image to be rotated (modified in place).
    angle (float): The angle of rotation in radians.
    """
    # Work relative to a zero origin; the true origin is re-added at the end.
    original_origin = np.array(image.GetOrigin())
    image.SetOrigin([0, 0, 0])

    # Calculate the physical center of the image
    physical_center = image.TransformContinuousIndexToPhysicalPoint(np.array(image.GetSize())/2.0)

    # Rotation axis = third column of the (row-major) direction matrix
    direction = image.GetDirection()
    axis_angle = (direction[2], direction[5], direction[8], angle)
    rotation_matrix = matrix_from_axis_angle(axis_angle)

    # Compute the new origin after rotation (rotation is about the physical center)
    new_origin = np.dot(rotation_matrix, -np.array(physical_center)) + np.array(physical_center)

    # Get the current direction of the image
    direction = np.array(image.GetDirection()).reshape((3, 3))

    # Compute the new direction cosines by multiplying the current direction by the rotation matrix
    new_direction_cosines = np.dot(rotation_matrix, direction)

    # Update the image with the new direction and origin
    image.SetDirection(new_direction_cosines.flatten())
    image.SetOrigin(new_origin + original_origin)
47 | """ 48 | scale_factor = 1.0 49 | translation = (0, 0, 0) 50 | rotation_center = image.TransformContinuousIndexToPhysicalPoint(np.array(image.GetSize())/2.0) 51 | direction = image.GetDirection() 52 | axis = (direction[2], direction[5], direction[8]) 53 | 54 | # rotate moving image 55 | similarity_transform = sitk.Similarity3DTransform( 56 | scale_factor, axis, angle, translation, rotation_center 57 | ) 58 | 59 | image = sitk.Resample(image, similarity_transform) 60 | 61 | return image 62 | 63 | 64 | # This function is from https://github.com/rock-learning/pytransform3d/blob/7589e083a50597a75b12d745ebacaa7cc056cfbd/pytransform3d/rotations.py#L302 65 | def matrix_from_axis_angle(a): 66 | """ Compute rotation matrix from axis-angle. 67 | This is called exponential map or Rodrigues' formula. 68 | Parameters 69 | ---------- 70 | a : array-like, shape (4,) 71 | Axis of rotation and rotation angle: (x, y, z, angle) 72 | Returns 73 | ------- 74 | R : array-like, shape (3, 3) 75 | Rotation matrix 76 | """ 77 | ux, uy, uz, theta = a 78 | c = np.cos(theta) 79 | s = np.sin(theta) 80 | ci = 1.0 - c 81 | R = np.array([[ci * ux * ux + c, 82 | ci * ux * uy - uz * s, 83 | ci * ux * uz + uy * s], 84 | [ci * uy * ux + uz * s, 85 | ci * uy * uy + c, 86 | ci * uy * uz - ux * s], 87 | [ci * uz * ux - uy * s, 88 | ci * uz * uy + ux * s, 89 | ci * uz * uz + c], 90 | ]) 91 | 92 | # This is equivalent to 93 | # R = (np.eye(3) * np.cos(theta) + 94 | # (1.0 - np.cos(theta)) * a[:3, np.newaxis].dot(a[np.newaxis, :3]) + 95 | # cross_product_matrix(a[:3]) * np.sin(theta)) 96 | 97 | return R 98 | 99 | 100 | # The following code has been copied/adapted from https://github.com/jinh0park/pytorch-ssim-3D 101 | # Thanks to the author for providing this resource 102 | def gaussian(window_size, sigma): 103 | x = torch.arange(window_size, dtype=torch.float32) - window_size // 2 104 | gauss = torch.exp(-x**2 / (2 * sigma**2)) 105 | return gauss / gauss.sum() 106 | 107 | def 
create_window_3D(window_size, channel): 108 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 109 | _2D_window = _1D_window.mm(_1D_window.t()) 110 | _3D_window = _1D_window.mm(_2D_window.reshape(1, -1)).reshape(window_size, window_size, window_size).float().unsqueeze(0).unsqueeze(0) 111 | window = Variable(_3D_window.expand(channel, 1, window_size, window_size, window_size).contiguous()) 112 | return window 113 | 114 | def _ssim_3D(img1, img2, window, window_size, channel, size_average = True): 115 | mu1 = F.conv3d(img1, window, padding = window_size//2, groups = channel) 116 | mu2 = F.conv3d(img2, window, padding = window_size//2, groups = channel) 117 | 118 | mu1_sq = mu1.pow(2) 119 | mu2_sq = mu2.pow(2) 120 | 121 | mu1_mu2 = mu1*mu2 122 | 123 | sigma1_sq = F.conv3d(img1*img1, window, padding = window_size//2, groups = channel) - mu1_sq 124 | sigma2_sq = F.conv3d(img2*img2, window, padding = window_size//2, groups = channel) - mu2_sq 125 | sigma12 = F.conv3d(img1*img2, window, padding = window_size//2, groups = channel) - mu1_mu2 126 | 127 | C1 = 0.01**2 128 | C2 = 0.03**2 129 | 130 | ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2)) 131 | 132 | if size_average: 133 | return ssim_map.mean() 134 | else: 135 | return ssim_map.mean(1).mean(1).mean(1) 136 | 137 | def ssim3D(img1, img2, window_size = 11, size_average = True): 138 | (_, channel, _, _, _) = img1.size() 139 | window = create_window_3D(window_size, channel) 140 | 141 | if img1.is_cuda: 142 | window = window.cuda(img1.get_device()) 143 | window = window.type_as(img1) 144 | 145 | return _ssim_3D(img1, img2, window, window_size, channel, size_average) 146 | 147 | 148 | -------------------------------------------------------------------------------- /tests/input/10000/10000_1000000_adc.mha: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/tests/input/10000/10000_1000000_adc.mha -------------------------------------------------------------------------------- /tests/input/10000/10000_1000000_hbv.mha: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/tests/input/10000/10000_1000000_hbv.mha -------------------------------------------------------------------------------- /tests/input/10000/10000_1000000_prostate_seg.nii.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/tests/input/10000/10000_1000000_prostate_seg.nii.gz -------------------------------------------------------------------------------- /tests/input/10000/10000_1000000_t2w.mha: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/tests/input/10000/10000_1000000_t2w.mha -------------------------------------------------------------------------------- /tests/output-expected/10000/10000_1000000_adc_warped.mha: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/tests/output-expected/10000/10000_1000000_adc_warped.mha -------------------------------------------------------------------------------- /tests/output/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/multimodallearning/convexAdam/ed65c35d2ee489501d23ce0a6119d2b613d6f470/tests/output/.gitkeep -------------------------------------------------------------------------------- 
def test_convex_adam_identity(
    input_dir = Path("tests/input"),
    subject_id = "10000_1000000",
):
    """Registering an image to itself should yield a (near-)zero displacement field."""
    # paths (fixed and moving are the same T2W image on purpose)
    patient_id = subject_id.split("_")[0]
    fixed_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha"))
    moving_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha"))

    # resample images to specified spacing and the field of view of the fixed image
    fixed_image_resampled = resample_img(fixed_image, spacing=(1.0, 1.0, 1.0))
    moving_image_resampled = resample_moving_to_fixed(fixed_image_resampled, moving_image)

    # run convex adam
    displacementfield = convex_adam_pt(
        img_fixed=fixed_image_resampled,
        img_moving=moving_image_resampled,
    )

    # test that the displacement field is performing identity transformation
    # (tolerance of 0.1 mm per displacement component)
    assert np.allclose(displacementfield, np.zeros_like(displacementfield), atol=0.1)
def test_convex_adam_translation(
    input_dir = Path("tests/input"),
    output_dir = Path("tests/output"),
    subject_id = "10000_1000000",
):
    """Registration should recover a known 10 mm translation applied to the moving image."""
    # paths (fixed and moving are the same T2W image; a synthetic shift is applied below)
    patient_id = subject_id.split("_")[0]
    fixed_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha"))
    moving_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha"))
    (output_dir / patient_id).mkdir(exist_ok=True, parents=True)

    # set direction to unity (this is important for the test)
    # doing this aligns the image axes with the world axes
    fixed_image.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])
    moving_image.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1])

    # resample images to specified spacing and the field of view of the fixed image
    fixed_image_resampled = resample_img(fixed_image, spacing=(1.0, 1.0, 1.0))
    moving_image_resampled = resample_moving_to_fixed(fixed_image_resampled, moving_image)

    # move moving image by +10 mm along each axis
    affine = sitk.AffineTransform(3)
    affine.SetTranslation([10, 10, 10])
    moving_image_resampled = sitk.Resample(moving_image_resampled, affine)

    # run convex adam
    displacementfield = convex_adam_pt(
        img_fixed=fixed_image_resampled,
        img_moving=moving_image_resampled,
    )

    # apply displacement field
    moving_image_resampled_warped = apply_convex(
        disp=displacementfield,
        moving=moving_image_resampled,
    )

    # convert to SimpleITK image
    moving_image_resampled_warped = sitk.GetImageFromArray(moving_image_resampled_warped.astype(np.float32))
    moving_image_resampled_warped.CopyInformation(moving_image_resampled)

    # save warped image
    output_dir.mkdir(exist_ok=True, parents=True)
    sitk.WriteImage(moving_image_resampled_warped, str(output_dir / patient_id / f"{subject_id}_t2w_translation_warped.mha"))

    # compare with reference (displacement field should be within 1 mm of the translation for at least 90% of the voxels in the center)
    # the expected displacement is -10 mm, hence the `+ 10` below; the image
    # border (outer 10% per side) is excluded because it lacks matching content
    s = displacementfield.shape[0] // 10
    displacementfield_center = displacementfield[s:-s, s:-s, s:-s]
    assert (np.abs(displacementfield_center + 10) < 1).mean() > 0.90
test_convex_adam_identity_rotated_direction( 139 | input_dir = Path("tests/input"), 140 | output_dir = Path("tests/output"), 141 | subject_id = "10000_1000000", 142 | ): 143 | # paths 144 | patient_id = subject_id.split("_")[0] 145 | fixed_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha")) 146 | moving_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha")) 147 | (output_dir / patient_id).mkdir(exist_ok=True, parents=True) 148 | 149 | # set center and direction to unity 150 | moving_image.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1]) 151 | moving_image.SetOrigin([0, 0, 0]) 152 | fixed_image.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1]) 153 | fixed_image.SetOrigin([0, 0, 0]) 154 | sitk.WriteImage(fixed_image, str(output_dir / patient_id / f"{subject_id}_fixed_unity.mha")) 155 | 156 | # rotate the moving image twice: once by updating the direction cosines and once by resampling the image 157 | angle = np.pi / 4.0 158 | moving_image = rotate_image_around_center_resample(moving_image, angle) 159 | rotate_image_around_center_affine(moving_image, angle) 160 | 161 | # resample images to specified spacing and the field of view of the fixed image 162 | fixed_image_resampled = resample_img(fixed_image, spacing=(1.0, 1.0, 1.0)) 163 | moving_image_resampled = resample_moving_to_fixed(fixed_image_resampled, moving_image) 164 | sitk.WriteImage(fixed_image_resampled, str(output_dir / patient_id / f"{subject_id}_fixed_resampled.mha")) 165 | 166 | # run convex adam 167 | displacementfield = convex_adam_pt( 168 | img_fixed=fixed_image_resampled, 169 | img_moving=moving_image_resampled, 170 | ) 171 | 172 | # apply displacement field 173 | moving_image_resampled_warped = apply_convex( 174 | disp=displacementfield, 175 | moving=moving_image_resampled, 176 | ) 177 | 178 | # convert to SimpleITK image 179 | moving_image_resampled_warped = sitk.GetImageFromArray(moving_image_resampled_warped.astype(np.float32)) 180 | 
moving_image_resampled_warped.CopyInformation(moving_image_resampled) 181 | 182 | # save warped image 183 | output_dir.mkdir(exist_ok=True, parents=True) 184 | sitk.WriteImage(moving_image_resampled_warped, str(output_dir / patient_id / f"{subject_id}_moving_rotation_warped.mha")) 185 | 186 | # test that the displacement field is performing identity transformation 187 | d1, d2, d3 = np.array(displacementfield.shape[0:3]) // 3 188 | disp_center = displacementfield[d1:-d1, d2:-d2, d3:-d3] 189 | assert np.allclose(disp_center, np.zeros_like(disp_center), atol=0.3) 190 | 191 | 192 | def test_convex_adam_identity_rotated_and_shifted( 193 | input_dir = Path("tests/input"), 194 | output_dir = Path("tests/output"), 195 | subject_id = "10000_1000000", 196 | ): 197 | # paths 198 | patient_id = subject_id.split("_")[0] 199 | fixed_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha")) 200 | moving_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha")) 201 | (output_dir / patient_id).mkdir(exist_ok=True, parents=True) 202 | 203 | # set center and direction to unity 204 | moving_image.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1]) 205 | moving_image.SetOrigin([0, 0, 0]) 206 | fixed_image.SetDirection([1, 0, 0, 0, 1, 0, 0, 0, 1]) 207 | fixed_image.SetOrigin([0, 0, 0]) 208 | sitk.WriteImage(fixed_image, str(output_dir / patient_id / f"{subject_id}_fixed_unity.mha")) 209 | 210 | # rotate the moving image twice: once by updating the direction cosines and once by resampling the image 211 | angle = np.pi / 4.0 212 | moving_image = rotate_image_around_center_resample(moving_image, angle) 213 | rotate_image_around_center_affine(moving_image, angle) 214 | 215 | # translate the moving image 216 | affine = sitk.AffineTransform(3) 217 | affine.SetTranslation([20, 0, 0]) 218 | moving_image = sitk.Resample(moving_image, affine) 219 | sitk.WriteImage(moving_image, str(output_dir / patient_id / f"{subject_id}_moving_rotated_and_shifted.mha")) 220 | # 
note: the moving image, when viewed in ITK-SNAP, is now moved 20 mm to the left (patient's right) 221 | 222 | # resample images to specified spacing and the field of view of the fixed image 223 | fixed_image_resampled = resample_img(fixed_image, spacing=(1.0, 1.0, 1.0)) 224 | moving_image_resampled = resample_moving_to_fixed(fixed_image_resampled, moving_image) 225 | sitk.WriteImage(fixed_image_resampled, str(output_dir / patient_id / f"{subject_id}_fixed_resampled.mha")) 226 | 227 | # run convex adam 228 | displacementfield = convex_adam_pt( 229 | img_fixed=fixed_image_resampled, 230 | img_moving=moving_image_resampled, 231 | ) 232 | 233 | disp = sitk.GetImageFromArray(displacementfield.astype(np.float32)) 234 | disp.CopyInformation(fixed_image_resampled) 235 | sitk.WriteImage(disp, str(output_dir / patient_id / f"{subject_id}_displacementfield.mha")) 236 | 237 | # apply displacement field 238 | moving_image_resampled_warped = apply_convex( 239 | disp=displacementfield, 240 | moving=moving_image_resampled, 241 | ) 242 | 243 | # convert to SimpleITK image 244 | moving_image_resampled_warped = sitk.GetImageFromArray(moving_image_resampled_warped.astype(np.float32)) 245 | moving_image_resampled_warped.CopyInformation(moving_image_resampled) 246 | 247 | # save warped image 248 | output_dir.mkdir(exist_ok=True, parents=True) 249 | sitk.WriteImage(moving_image_resampled_warped, str(output_dir / patient_id / f"{subject_id}_moving_rotated_and_shifted_warped.mha")) 250 | 251 | # apply displacement field to the moving image without resampling the moving image 252 | displacement_field_rescaled = rescale_displacement_field( 253 | displacement_field=displacementfield, 254 | moving_image=moving_image, 255 | fixed_image=fixed_image, 256 | fixed_image_resampled=fixed_image_resampled, 257 | ) 258 | 259 | moving_image_warped = apply_convex( 260 | disp=displacement_field_rescaled, 261 | moving=moving_image, 262 | ) 263 | moving_image_warped = 
sitk.GetImageFromArray(moving_image_warped.astype(np.float32)) 264 | moving_image_warped.CopyInformation(moving_image) 265 | sitk.WriteImage(moving_image_warped, str(output_dir / patient_id / f"{subject_id}_original_moving_rotated_and_shifted_warped.mha")) 266 | 267 | 268 | if __name__ == "__main__": 269 | test_convex_adam_identity() 270 | test_convex_adam() 271 | test_convex_adam_translation() 272 | test_convex_adam_identity_rotated_direction() 273 | test_convex_adam_identity_rotated_and_shifted() 274 | print("All tests passed") 275 | -------------------------------------------------------------------------------- /tests/test_convex_adam_mind_aniso.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import SimpleITK as sitk 5 | from helper_functions import (rotate_image_around_center_affine, 6 | rotate_image_around_center_resample) 7 | 8 | from convexAdam.apply_convex import apply_convex, apply_convex_original_moving 9 | from convexAdam.convex_adam_MIND import convex_adam_pt 10 | from convexAdam.convex_adam_utils import (resample_img, 11 | resample_moving_to_fixed, 12 | rescale_displacement_field) 13 | 14 | 15 | def test_convex_adam_rotated_and_shifted_anisotropic( 16 | input_dir = Path("tests/input"), 17 | output_dir = Path("tests/output"), 18 | subject_id = "10000_1000000", 19 | ): 20 | # paths 21 | patient_id = subject_id.split("_")[0] 22 | fixed_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha")) 23 | moving_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha")) 24 | (output_dir / patient_id).mkdir(exist_ok=True, parents=True) 25 | 26 | # translate the moving image 27 | translation = 20 28 | affine = sitk.AffineTransform(3) 29 | affine.SetTranslation([translation, 0, 0]) 30 | moving_image = sitk.Resample(moving_image, affine) 31 | 32 | # rotate the moving image twice: once by updating the direction cosines and once by 
resampling the image 33 | angle = np.pi / 4.0 34 | moving_image = rotate_image_around_center_resample(moving_image, angle) 35 | rotate_image_around_center_affine(moving_image, angle) 36 | sitk.WriteImage(moving_image, str(output_dir / patient_id / f"{subject_id}_moving_rotated_and_shifted.mha")) 37 | # note: the moving image, when viewed in ITK-SNAP, is now moved 20 mm to the left (patient's right) 38 | 39 | # resample images to specified spacing and the field of view of the fixed image 40 | fixed_image_resampled = resample_img(fixed_image, spacing=(1.0, 1.0, 1.0)) 41 | moving_image_resampled = resample_moving_to_fixed(fixed_image_resampled, moving_image) 42 | sitk.WriteImage(fixed_image_resampled, str(output_dir / patient_id / f"{subject_id}_fixed_resampled.mha")) 43 | sitk.WriteImage(moving_image_resampled, str(output_dir / patient_id / f"{subject_id}_moving_rotated_and_shifted_resampled.mha")) 44 | 45 | # run convex adam 46 | displacementfield = convex_adam_pt( 47 | img_fixed=fixed_image_resampled, 48 | img_moving=moving_image_resampled, 49 | ) 50 | 51 | disp = sitk.GetImageFromArray(displacementfield.astype(np.float32)) 52 | disp.CopyInformation(fixed_image_resampled) 53 | sitk.WriteImage(disp, str(output_dir / patient_id / f"{subject_id}_displacementfield.mha")) 54 | 55 | # apply displacement field 56 | moving_image_resampled_warped = apply_convex( 57 | disp=displacementfield, 58 | moving=moving_image_resampled, 59 | ) 60 | 61 | # convert to SimpleITK image 62 | moving_image_resampled_warped = sitk.GetImageFromArray(moving_image_resampled_warped.astype(np.float32)) 63 | moving_image_resampled_warped.CopyInformation(moving_image_resampled) 64 | 65 | # save warped image 66 | output_dir.mkdir(exist_ok=True, parents=True) 67 | sitk.WriteImage(moving_image_resampled_warped, str(output_dir / patient_id / f"{subject_id}_moving_rotated_and_shifted_resampled_warped.mha")) 68 | 69 | # apply displacement field to the moving image without resampling the moving image 70 | 
    # rescale the displacement field from the resampled fixed grid back to the
    # original moving image's grid, so the original image can be warped directly
    displacement_field_rescaled = rescale_displacement_field(
        displacement_field=displacementfield,
        moving_image=moving_image,
        fixed_image=fixed_image,
        fixed_image_resampled=fixed_image_resampled,
    )

    moving_image_warped = apply_convex(
        disp=displacement_field_rescaled,
        moving=moving_image,
    )
    moving_image_warped = sitk.GetImageFromArray(moving_image_warped.astype(np.float32))
    moving_image_warped.CopyInformation(moving_image)
    sitk.WriteImage(moving_image_warped, str(output_dir / patient_id / f"{subject_id}_moving_rotated_and_shifted_warped.mha"))


def test_convex_adam_anisotropic(
    input_dir = Path("tests/input"),
    output_dir = Path("tests/output"),
    subject_id = "10000_1000000",
):
    """
    Register an anisotropic moving image (ADC) to a fixed image (T2W) with
    convex_adam_pt, then warp the ORIGINAL (non-resampled) moving image via
    apply_convex_original_moving and write the result to output_dir.
    """
    # paths
    patient_id = subject_id.split("_")[0]
    fixed_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha"))
    moving_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_adc.mha"))
    (output_dir / patient_id).mkdir(exist_ok=True, parents=True)

    # resample images to specified spacing and the field of view of the fixed image
    fixed_image_resampled = resample_img(fixed_image, spacing=(1.0, 1.0, 1.0))
    moving_image_resampled = resample_moving_to_fixed(fixed_image_resampled, moving_image)

    # run convex adam
    displacementfield = convex_adam_pt(
        img_fixed=fixed_image_resampled,
        img_moving=moving_image_resampled,
    )

    # apply displacement field to the moving image without resampling the moving image
    moving_image_warped = apply_convex_original_moving(
        disp=displacementfield,
        moving_image_original=moving_image,
        fixed_image_original=fixed_image,
        fixed_image_resampled=fixed_image_resampled,
    )
    sitk.WriteImage(moving_image_warped, str(output_dir / patient_id / f"{subject_id}_moving_warped.mha"))


if __name__ == "__main__":
    test_convex_adam_rotated_and_shifted_anisotropic()
    test_convex_adam_anisotropic()
    print("All tests passed")
--------------------------------------------------------------------------------
/tests/test_convex_adam_mind_translation.py:
--------------------------------------------------------------------------------
from pathlib import Path
from typing import Iterable

import numpy as np
import SimpleITK as sitk

from convexAdam.convex_adam_translation import (
    apply_translation, convex_adam_translation,
    index_translation_to_world_translation)
from convexAdam.convex_adam_utils import resample_moving_to_fixed


def translate_along_image_directions(image: sitk.Image, translation: Iterable[float]):
    """
    Translate an image along its image directions (not physical directions).

    Args:
        image: The image to translate.
        translation (x, y, z): The translation in the image directions (mm).
    """
    # Convert the translation given along image (index) directions into a
    # world-space translation, using the image's direction cosine matrix.
    # (The original comment stated the opposite direction of conversion.)
    world_translation = index_translation_to_world_translation(translation, direction=image.GetDirection())

    # Create the transformation
    dimension = image.GetDimension()
    transform = sitk.TranslationTransform(dimension, world_translation)

    # Apply the transformation (linear interpolation, background value 0.0)
    resampled_image = sitk.Resample(image, transform, sitk.sitkLinear, 0.0, image.GetPixelID())

    return resampled_image


def test_translation_precision(
    input_dir = Path("tests/input"),
    output_dir = Path("tests/output"),
    subject_id = "10000_1000000",
):
    """Translate an image by a whole number of voxels, translate it back, and
    check the round trip approximately restores the original intensities."""
    # paths
    patient_id = subject_id.split("_")[0]
    fixed_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha"))
    moving_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha"))
    (output_dir / patient_id).mkdir(exist_ok=True, parents=True)

    # move moving image a multiple of the voxel size
the voxel size 46 | spacing = np.array(moving_image.GetSpacing()) 47 | nvoxels = 5 48 | translation = spacing * nvoxels 49 | moving_image = translate_along_image_directions(image=moving_image, translation=translation) 50 | sitk.WriteImage(moving_image, str(output_dir / patient_id / f"{subject_id}_t2w_translation.mha")) 51 | 52 | # move moving image back 53 | moving_image = apply_translation(moving_image=moving_image, translation_ijk=-translation) 54 | sitk.WriteImage(moving_image, str(output_dir / patient_id / f"{subject_id}_t2w_translation_back.mha")) 55 | 56 | # compare images 57 | moving_image = resample_moving_to_fixed(moving=moving_image, fixed=fixed_image) 58 | arr_fixed = sitk.GetArrayFromImage(fixed_image) 59 | arr_moving = sitk.GetArrayFromImage(moving_image) 60 | 61 | # crop (to avoid edge effects from translation) 62 | arr_fixed = arr_fixed[nvoxels:-nvoxels, nvoxels:-nvoxels, nvoxels:-nvoxels] 63 | arr_moving = arr_moving[nvoxels:-nvoxels, nvoxels:-nvoxels, nvoxels:-nvoxels] 64 | 65 | np.testing.assert_allclose( 66 | arr_fixed, 67 | arr_moving, 68 | atol=2.0 69 | ) 70 | 71 | 72 | def test_convex_adam_translation( 73 | input_dir = Path("tests/input"), 74 | subject_id = "10000_1000000", 75 | use_mask: bool = True, 76 | ): 77 | # paths 78 | patient_id = subject_id.split("_")[0] 79 | fixed_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha")) 80 | moving_image = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_t2w.mha")) 81 | if use_mask: 82 | segmentation = sitk.ReadImage(str(input_dir / patient_id / f"{subject_id}_prostate_seg.nii.gz")) 83 | 84 | # move moving image 85 | translation = [10, 10, 0] 86 | moving_image = translate_along_image_directions(moving_image, translation) 87 | 88 | # apply convex adam translation 89 | translation_xyz, moving_image, _ = convex_adam_translation( 90 | fixed_image=fixed_image, 91 | moving_image=moving_image, 92 | segmentation=segmentation, 93 | ) 94 | 95 | # check translation 96 | 
    # the recovered translation should be the inverse of the applied one,
    # to within 1 mm per axis
    np.testing.assert_allclose(
        -np.array(translation),
        translation_xyz,
        atol=1.0
    )


if __name__ == "__main__":
    test_translation_precision()
    test_convex_adam_translation()
    print("All tests passed")
--------------------------------------------------------------------------------