├── .dockerignore ├── .gitignore ├── .pre-commit-config.yaml ├── Dockerfile ├── LICENSE ├── README.md ├── assets └── figure_samples.png ├── cluster └── runai │ ├── preprocessing │ ├── create_datalist.sh │ └── create_png_dataset.sh │ ├── testing │ ├── controlnet_performance.sh │ ├── convert_model.sh │ ├── fid.sh │ ├── msssim.sh │ ├── msssim_reconstruction.sh │ ├── sample_flair_to_t1w.sh │ └── sample_t1w.sh │ └── training │ ├── controlnet.sh │ ├── ldm.sh │ └── stage1.sh ├── configs ├── controlnet │ └── controlnet_v0.yaml ├── ldm │ └── ldm_v0.yaml └── stage1 │ └── aekl_v0.yaml ├── create_image.sh ├── pyproject.toml ├── requirements.txt └── src ├── bash └── start_script.sh └── python ├── preprocessing ├── create_datalist.py └── create_png_dataset.py ├── testing ├── compute_controlnet_performance.py ├── compute_fid.py ├── compute_msssim_reconstruction.py ├── compute_msssim_sample.py ├── convert_mlflow_to_pytorch.py ├── sample_flair_to_t1w.py ├── sample_t1w.py └── util.py └── training ├── custom_transforms.py ├── train_aekl.py ├── train_controlnet.py ├── train_ldm.py ├── training_functions.py └── util.py /.dockerignore: -------------------------------------------------------------------------------- 1 | ### Python template 2 | # Byte-compiled / optimized / DLL files 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 99 | __pypackages__/ 100 | 101 | # Celery stuff 102 | celerybeat-schedule 103 | celerybeat.pid 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | .dmypy.json 130 | dmypy.json 131 | 132 | # Pyre type checker 133 | .pyre/ 134 | 135 | # pytype static type analyzer 136 | .pytype/ 137 | 138 | # Cython debug symbols 139 | cython_debug/ 140 | 141 | outputs/ 142 | .idea/ 143 | mlruns/ 144 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.4.0 4 | hooks: 5 | - id: end-of-file-fixer 6 | - id: trailing-whitespace 7 | - id: check-yaml 8 | - id: check-docstring-first 9 | - id: check-executables-have-shebangs 10 | - id: check-toml 11 | - id: check-case-conflict 12 | - id: check-added-large-files 13 | args: ['--maxkb=1024'] 14 | - id: detect-private-key 15 | - id: forbid-new-submodules 16 | - id: pretty-format-json 17 | args: ['--autofix', '--no-sort-keys', '--indent=4'] 18 | - id: end-of-file-fixer 19 | - id: mixed-line-ending 20 | - repo: https://github.com/psf/black 21 | rev: 23.3.0 22 | hooks: 23 | - id: black 24 | - repo: https://github.com/PyCQA/isort 25 | rev: 5.12.0 26 | hooks: 27 | - id: isort 28 | - repo: https://github.com/PyCQA/flake8 29 | rev: 3.9.2 30 | hooks: 31 | - id: flake8 32 | args: # arguments to configure flake8 33 | # these are errors that will be ignored by flake8 34 | # check out their meaning here 35 | # https://flake8.pycqa.org/en/latest/user/error-codes.html 36 | - "--ignore=E203,E266,E501,W503,E731,F541,F841" 37 | # Adding args to work with black format 38 | - "--max-line-length=120" 39 | - "--max-complexity=18" 40 | - "--per-file-ignores=__init__.py:F401" 41 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:23.04-py3 2 | 3 | ARG USER_ID 4 | ARG GROUP_ID 5 | ARG USER 6 | RUN addgroup --gid $GROUP_ID $USER 7 | RUN adduser --disabled-password --gecos '' --uid $USER_ID --gid $GROUP_ID $USER 8 | 9 | COPY requirements.txt . 10 | RUN pip3 install -r requirements.txt 11 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ControlNet for Brain T1w Image Generation from FLAIR Images using MONAI Generative Models 2 | 3 | Scripts to train a ControlNet (from [Adding Conditional Control to Text-to-Image Diffusion Models](https://arxiv.org/abs/2302.05543)) 4 | on the UK Biobank dataset to translate FLAIR images into 2D T1w images, using the [MONAI Generative Models](https://github.com/Project-MONAI/GenerativeModels) 5 | package. 6 | 7 | ![ControlNet Samples](https://github.com/Warvito/generative_brain_controlnet/blob/main/assets/figure_samples.png?raw=true) 8 | 9 | This repository accompanies a Medium post. 10 | 11 | ## Instructions 12 | ### Preprocessing 13 | After downloading the UK Biobank dataset and preprocessing it, we obtain the lists of image paths for the T1w and FLAIR 14 | images. For that, we use the following scripts: 15 | 16 | 1) `src/python/preprocessing/create_png_dataset.py` - Create PNG images from the NIfTI files. 17 | 2) `src/python/preprocessing/create_datalist.py` - Create the datalist files for training, validation and test. 18 | 19 | ### Training 20 | Once we have the paths, we can train the models using commands similar to those in the following files (note: this project was 21 | executed on a cluster with the RunAI platform; a sketch of an equivalent local run is shown after the list): 22 | 23 | 1) `cluster/runai/training/stage1.sh` - Command to launch, on the cluster, the training of the first stage of the model (the autoencoder). 24 | The main Python script for this stage is `src/python/training/train_aekl.py`. The `--volume` flags indicate how the dataset 25 | is mounted in the Docker container. 26 | 2) `cluster/runai/training/ldm.sh` - Command to launch, on the cluster, the training of the diffusion model on the latent representation. 27 | The main Python script for this stage is `src/python/training/train_ldm.py`. The `--volume` flags indicate how the dataset 28 | is mounted in the Docker container. 29 | 3) `cluster/runai/training/controlnet.sh` - Command to launch, on the cluster, the training of the ControlNet using the pretrained LDM. 30 | The main Python script for this stage is `src/python/training/train_controlnet.py`. The `--volume` flags indicate how the dataset 31 | is mounted in the Docker container.
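For reference, `src/bash/start_script.sh` simply converts each `key=value` argument into an argparse-style `--key value` flag before invoking the Python script, so the same training can be started outside RunAI. Below is a minimal local-run sketch for the first stage; the flag values are taken from `cluster/runai/training/stage1.sh`, and it assumes the repository is available under `/project/` and the PNG dataset under `/data/`, as in the cluster scripts:

```bash
# MLflow tracking location, as set in src/bash/start_script.sh
export MLFLOW_TRACKING_URI=file:/project/mlruns

# Launch first-stage (AutoencoderKL) training with the same arguments
# that stage1.sh forwards through start_script.sh
python3 /project/src/python/training/train_aekl.py \
    --seed 42 \
    --run_dir aekl_v0 \
    --training_ids /project/outputs/ids/train.tsv \
    --validation_ids /project/outputs/ids/validation.tsv \
    --config_file /project/configs/stage1/aekl_v0.yaml \
    --batch_size 256 \
    --n_epochs 100 \
    --adv_start 10 \
    --eval_freq 5 \
    --num_workers 64 \
    --experiment AEKL
```

The same pattern applies to `ldm.sh` and `controlnet.sh`; batch size and worker counts should be adjusted to the available hardware.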
32 | 33 | These `.sh` files indicate which parameters and configuration files were used for training, as well as how the host directories 34 | were mounted in the Docker container. 35 | 36 | 37 | ### Inference and evaluation 38 | Finally, we converted the MLflow models to `.pth` files (for easy loading with MONAI), sampled images from the diffusion 39 | model and the ControlNet, and evaluated the models. The scripts for inference and evaluation are executed in the following order: 40 | 41 | 1) `src/python/testing/convert_mlflow_to_pytorch.py` - Convert the MLflow models to `.pth` files. 42 | 2) `src/python/testing/sample_t1w.py` - Sample T1w images from the diffusion model without using conditioning. 43 | 3) `src/python/testing/sample_flair_to_t1w.py` - Sample T1w images from the ControlNet using the test set's FLAIR 44 | images as conditioning. 45 | 4) `src/python/testing/compute_msssim_reconstruction.py` - Measure the mean structural similarity index between images and 46 | reconstructions to measure the performance of the autoencoder. 47 | 5) `src/python/testing/compute_msssim_sample.py` - Measure the mean structural similarity index between samples in order 48 | to measure the diversity of the synthetic data. 49 | 6) `src/python/testing/compute_fid.py` - Compute the FID score between generated images and real images. 50 | 7) `src/python/testing/compute_controlnet_performance.py` - Compute the performance of the ControlNet using the MAE, PSNR and 51 | MS-SSIM metrics. 52 | -------------------------------------------------------------------------------- /assets/figure_samples.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Warvito/generative_brain_controlnet/e6a397645c74bca3f58eb6654a3c82f7c0962692/assets/figure_samples.png -------------------------------------------------------------------------------- /cluster/runai/preprocessing/create_datalist.sh: -------------------------------------------------------------------------------- 1 | runai submit \ 2 | --name create-datalist \ 3 | --image aicregistry:5000/wds20:control_brain \ 4 | --backoff-limit 0 \ 5 | --gpu 0 \ 6 | --cpu 2 \ 7 | --large-shm \ 8 | --run-as-user \ 9 | --host-ipc \ 10 | --project wds20 \ 11 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 12 | --volume /nfs/home/wds20/datasets/Biobank/derivatives/2d_controlnet/:/data/ \ 13 | --command -- bash /project/src/bash/start_script.sh \ 14 | python3 /project/src/python/preprocessing/create_datalist.py 15 | -------------------------------------------------------------------------------- /cluster/runai/preprocessing/create_png_dataset.sh: -------------------------------------------------------------------------------- 1 | for i in {0..39};do 2 | start=$((i*1000)) 3 | stop=$(((i+1)*1000)) 4 | 5 | runai submit \ 6 | --name create-dataset-${start}-${stop} \ 7 | --image aicregistry:5000/wds20:control_brain \ 8 | --backoff-limit 0 \ 9 | --gpu 0 \ 10 | --cpu 2 \ 11 | --large-shm \ 12 | --run-as-user \ 13 | --host-ipc \ 14 | --project wds20 \ 15 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 16 | --volume /nfs/project/AMIGO/Biobank/derivatives/super-res/:/source_data/ \ 17 | --volume /nfs/home/wds20/datasets/Biobank/derivatives/2d_controlnet/:/data/ \ 18 | --command -- bash /project/src/bash/start_script.sh \ 19 | python3 /project/src/python/preprocessing/create_png_dataset.py \ 20 | start=${start} \ 21 | stop=${stop} 22 | sleep 2 23 | done 24 |
-------------------------------------------------------------------------------- /cluster/runai/testing/controlnet_performance.sh: -------------------------------------------------------------------------------- 1 | seed=42 2 | samples_dir="/project/outputs/samples/flair_to_t1w/" 3 | test_ids="/project/outputs/ids/test.tsv" 4 | num_workers=8 5 | 6 | runai submit \ 7 | --name controlnet-perf-metrics \ 8 | --image aicregistry:5000/wds20:control_brain \ 9 | --backoff-limit 0 \ 10 | --gpu 1 \ 11 | --cpu 4 \ 12 | --large-shm \ 13 | --run-as-user \ 14 | --host-ipc \ 15 | --project wds20 \ 16 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 17 | --volume /nfs/home/wds20/datasets/Biobank/derivatives/2d_controlnet/:/data/ \ 18 | --command -- bash /project/src/bash/start_script.sh \ 19 | python3 /project/src/python/testing/compute_controlnet_performance.py \ 20 | seed=${seed} \ 21 | samples_dir=${samples_dir} \ 22 | test_ids=${test_ids} \ 23 | num_workers=${num_workers} 24 | -------------------------------------------------------------------------------- /cluster/runai/testing/convert_model.sh: -------------------------------------------------------------------------------- 1 | stage1_mlflow_path="/project/mlruns/837816334068618022/39336906f86c4cdc96fb6464b88c8c06/artifacts/final_model" 2 | diffusion_mlflow_path="/project/mlruns/102676348294480761/a53f700f40184ff49f5f7e27fafece97/artifacts/final_model" 3 | controlnet_mlflow_path="/project/mlruns/672765428205510835/92b0ea370a234caca38810246d4c60b7/artifacts/final_model" 4 | output_dir="/project/outputs/trained_models/" 5 | 6 | runai submit \ 7 | --name controlnet-convert-model \ 8 | --image aicregistry:5000/wds20:control_brain \ 9 | --backoff-limit 0 \ 10 | --gpu 1 \ 11 | --cpu 4 \ 12 | --large-shm \ 13 | --run-as-user \ 14 | --node-type "A100" \ 15 | --host-ipc \ 16 | --project wds20 \ 17 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 18 | --command -- bash /project/src/bash/start_script.sh \ 19 | python3 /project/src/python/testing/convert_mlflow_to_pytorch.py \ 20 | stage1_mlflow_path=${stage1_mlflow_path} \ 21 | diffusion_mlflow_path=${diffusion_mlflow_path} \ 22 | controlnet_mlflow_path=${controlnet_mlflow_path} \ 23 | output_dir=${output_dir} 24 | -------------------------------------------------------------------------------- /cluster/runai/testing/fid.sh: -------------------------------------------------------------------------------- 1 | seed=42 2 | sample_dir="/project/outputs/samples/samples_unconditioned/" 3 | test_ids="/project/outputs/ids/test.tsv" 4 | num_workers=8 5 | batch_size=16 6 | 7 | runai submit \ 8 | --name controlnet-fid \ 9 | --image aicregistry:5000/wds20:control_brain \ 10 | --backoff-limit 0 \ 11 | --gpu 1 \ 12 | --cpu 4 \ 13 | --large-shm \ 14 | --run-as-user \ 15 | --host-ipc \ 16 | --project wds20 \ 17 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 18 | --volume /nfs/home/wds20/datasets/Biobank/derivatives/2d_controlnet/:/data/ \ 19 | --command -- bash /project/src/bash/start_script.sh \ 20 | python3 /project/src/python/testing/compute_fid.py \ 21 | seed=${seed} \ 22 | sample_dir=${sample_dir} \ 23 | test_ids=${test_ids} \ 24 | batch_size=${batch_size} \ 25 | num_workers=${num_workers} 26 | -------------------------------------------------------------------------------- /cluster/runai/testing/msssim.sh: -------------------------------------------------------------------------------- 1 | seed=42 2 | 
sample_dir="/project/outputs/samples/samples_unconditioned/" 3 | num_workers=8 4 | 5 | runai submit \ 6 | --name controlnet-ssim-sample \ 7 | --image aicregistry:5000/wds20:control_brain \ 8 | --backoff-limit 0 \ 9 | --gpu 1 \ 10 | --cpu 4 \ 11 | --large-shm \ 12 | --run-as-user \ 13 | --host-ipc \ 14 | --project wds20 \ 15 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 16 | --command -- bash /project/src/bash/start_script.sh \ 17 | python3 /project/src/python/testing/compute_msssim_sample.py \ 18 | seed=${seed} \ 19 | sample_dir=${sample_dir} \ 20 | num_workers=${num_workers} 21 | -------------------------------------------------------------------------------- /cluster/runai/testing/msssim_reconstruction.sh: -------------------------------------------------------------------------------- 1 | output_dir="/project/outputs/metrics/" 2 | test_ids="/project/outputs/ids/test.tsv" 3 | config_file="/project/configs/stage1/aekl_v0.yaml" 4 | stage1_path="/project/outputs/trained_models/autoencoder.pth" 5 | seed=42 6 | batch_size=16 7 | num_workers=8 8 | 9 | runai submit \ 10 | --name controlnet-ssim \ 11 | --image aicregistry:5000/wds20:control_brain \ 12 | --backoff-limit 0 \ 13 | --gpu 1 \ 14 | --cpu 4 \ 15 | --large-shm \ 16 | --run-as-user \ 17 | --host-ipc \ 18 | --project wds20 \ 19 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 20 | --volume /nfs/home/wds20/datasets/Biobank/derivatives/2d_controlnet/:/data/ \ 21 | --command -- bash /project/src/bash/start_script.sh \ 22 | python3 /project/src/python/testing/compute_msssim_reconstruction.py \ 23 | seed=${seed} \ 24 | output_dir=${output_dir} \ 25 | test_ids=${test_ids} \ 26 | batch_size=${batch_size} \ 27 | config_file=${config_file} \ 28 | stage1_path=${stage1_path} \ 29 | num_workers=${num_workers} 30 | -------------------------------------------------------------------------------- /cluster/runai/testing/sample_flair_to_t1w.sh: -------------------------------------------------------------------------------- 1 | output_dir="/project/outputs/samples/flair_to_t1w/" 2 | stage1_path="/project/outputs/trained_models/autoencoder.pth" 3 | diffusion_path="/project/outputs/trained_models/diffusion_model.pth" 4 | controlnet_path="/project/outputs/trained_models/controlnet_model.pth" 5 | stage1_config_file_path="/project/configs/stage1/aekl_v0.yaml" 6 | diffusion_config_file_path="/project/configs/ldm/ldm_v0.yaml" 7 | controlnet_config_file_path="/project/configs/controlnet/controlnet_v0.yaml" 8 | test_ids="/project/outputs/ids/test.tsv" 9 | controlnet_scale=1.0 10 | guidance_scale=7.0 11 | x_size=20 12 | y_size=28 13 | scale_factor=0.3 14 | num_workers=8 15 | num_inference_steps=200 16 | 17 | runai submit \ 18 | --name controlnet-sampling \ 19 | --image aicregistry:5000/wds20:control_brain \ 20 | --backoff-limit 0 \ 21 | --gpu 1 \ 22 | --cpu 4 \ 23 | --large-shm \ 24 | --run-as-user \ 25 | --host-ipc \ 26 | --project wds20 \ 27 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 28 | --volume /nfs/home/wds20/datasets/Biobank/derivatives/2d_controlnet/:/data/ \ 29 | --command -- bash /project/src/bash/start_script.sh \ 30 | python3 /project/src/python/testing/sample_flair_to_t1w.py \ 31 | output_dir=${output_dir} \ 32 | stage1_path=${stage1_path} \ 33 | diffusion_path=${diffusion_path} \ 34 | controlnet_path=${controlnet_path} \ 35 | stage1_config_file_path=${stage1_config_file_path} \ 36 | diffusion_config_file_path=${diffusion_config_file_path} \ 37 | 
controlnet_config_file_path=${controlnet_config_file_path} \ 38 | test_ids=${test_ids} \ 39 | controlnet_scale=${controlnet_scale} \ 40 | guidance_scale=${guidance_scale} \ 41 | x_size=${x_size} \ 42 | y_size=${y_size} \ 43 | scale_factor=${scale_factor} \ 44 | num_workers=${num_workers} \ 45 | num_inference_steps=${num_inference_steps} 46 | -------------------------------------------------------------------------------- /cluster/runai/testing/sample_t1w.sh: -------------------------------------------------------------------------------- 1 | output_dir="/project/outputs/samples/samples_unconditioned/" 2 | stage1_path="/project/outputs/trained_models/autoencoder.pth" 3 | diffusion_path="/project/outputs/trained_models/diffusion_model.pth" 4 | stage1_config_file_path="/project/configs/stage1/aekl_v0.yaml" 5 | diffusion_config_file_path="/project/configs/ldm/ldm_v0.yaml" 6 | start_seed=0 7 | stop_seed=1000 8 | guidance_scale=0.0 9 | x_size=20 10 | y_size=28 11 | scale_factor=0.3 12 | num_inference_steps=200 13 | 14 | runai submit \ 15 | --name controlnet-sampling-unsupervised \ 16 | --image aicregistry:5000/wds20:control_brain \ 17 | --backoff-limit 0 \ 18 | --gpu 1 \ 19 | --cpu 4 \ 20 | --large-shm \ 21 | --run-as-user \ 22 | --host-ipc \ 23 | --project wds20 \ 24 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 25 | --command -- bash /project/src/bash/start_script.sh \ 26 | python3 /project/src/python/testing/sample_t1w.py \ 27 | output_dir=${output_dir} \ 28 | stage1_path=${stage1_path} \ 29 | diffusion_path=${diffusion_path} \ 30 | stage1_config_file_path=${stage1_config_file_path} \ 31 | diffusion_config_file_path=${diffusion_config_file_path} \ 32 | start_seed=${start_seed} \ 33 | stop_seed=${stop_seed} \ 34 | guidance_scale=${guidance_scale} \ 35 | x_size=${x_size} \ 36 | y_size=${y_size} \ 37 | scale_factor=${scale_factor} \ 38 | num_inference_steps=${num_inference_steps} 39 | -------------------------------------------------------------------------------- /cluster/runai/training/controlnet.sh: -------------------------------------------------------------------------------- 1 | seed=42 2 | run_dir="aekl_v0_ldm_v0_controlnet_v0" 3 | training_ids="/project/outputs/ids/train.tsv" 4 | validation_ids="/project/outputs/ids/validation.tsv" 5 | stage1_uri="/project/mlruns/837816334068618022/39336906f86c4cdc96fb6464b88c8c06/artifacts/final_model" 6 | ddpm_uri="/project/mlruns/102676348294480761/a53f700f40184ff49f5f7e27fafece97/artifacts/final_model" 7 | config_file="/project/configs/controlnet/controlnet_v0.yaml" 8 | scale_factor=0.3 9 | batch_size=384 10 | n_epochs=150 11 | eval_freq=10 12 | num_workers=64 13 | experiment="CONTROLNET" 14 | 15 | runai submit \ 16 | --name controlnet-controlnet-v0 \ 17 | --image aicregistry:5000/wds20:control_brain \ 18 | --backoff-limit 0 \ 19 | --gpu 4 \ 20 | --cpu 32 \ 21 | --memory-limit 256G \ 22 | --large-shm \ 23 | --run-as-user \ 24 | --host-ipc \ 25 | --project wds20 \ 26 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 27 | --volume /nfs/home/wds20/datasets/Biobank/derivatives/2d_controlnet/:/data/ \ 28 | --command -- bash /project/src/bash/start_script.sh \ 29 | python3 /project/src/python/training/train_controlnet.py \ 30 | seed=${seed} \ 31 | run_dir=${run_dir} \ 32 | training_ids=${training_ids} \ 33 | validation_ids=${validation_ids} \ 34 | stage1_uri=${stage1_uri} \ 35 | ddpm_uri=${ddpm_uri} \ 36 | config_file=${config_file} \ 37 | scale_factor=${scale_factor} \ 38 | 
batch_size=${batch_size} \ 39 | n_epochs=${n_epochs} \ 40 | eval_freq=${eval_freq} \ 41 | num_workers=${num_workers} \ 42 | experiment=${experiment} 43 | -------------------------------------------------------------------------------- /cluster/runai/training/ldm.sh: -------------------------------------------------------------------------------- 1 | seed=42 2 | run_dir="aekl_v0_ldm_v0" 3 | training_ids="/project/outputs/ids/train.tsv" 4 | validation_ids="/project/outputs/ids/validation.tsv" 5 | stage1_uri="/project/mlruns/837816334068618022/39336906f86c4cdc96fb6464b88c8c06/artifacts/final_model" 6 | config_file="/project/configs/ldm/ldm_v0.yaml" 7 | scale_factor=0.3 8 | batch_size=512 9 | n_epochs=150 10 | eval_freq=10 11 | num_workers=128 12 | experiment="LDM" 13 | 14 | runai submit \ 15 | --name controlnet-ldm-v0 \ 16 | --image aicregistry:5000/wds20:control_brain \ 17 | --backoff-limit 0 \ 18 | --gpu 4 \ 19 | --cpu 64 \ 20 | --memory-limit 256G \ 21 | --large-shm \ 22 | --run-as-user \ 23 | --host-ipc \ 24 | --project wds20 \ 25 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 26 | --volume /nfs/home/wds20/datasets/Biobank/derivatives/2d_controlnet/:/data/ \ 27 | --command -- bash /project/src/bash/start_script.sh \ 28 | python3 /project/src/python/training/train_ldm.py \ 29 | seed=${seed} \ 30 | run_dir=${run_dir} \ 31 | training_ids=${training_ids} \ 32 | validation_ids=${validation_ids} \ 33 | stage1_uri=${stage1_uri} \ 34 | config_file=${config_file} \ 35 | scale_factor=${scale_factor} \ 36 | batch_size=${batch_size} \ 37 | n_epochs=${n_epochs} \ 38 | eval_freq=${eval_freq} \ 39 | num_workers=${num_workers} \ 40 | experiment=${experiment} 41 | -------------------------------------------------------------------------------- /cluster/runai/training/stage1.sh: -------------------------------------------------------------------------------- 1 | seed=42 2 | run_dir="aekl_v0" 3 | training_ids="/project/outputs/ids/train.tsv" 4 | validation_ids="/project/outputs/ids/validation.tsv" 5 | config_file="/project/configs/stage1/aekl_v0.yaml" 6 | batch_size=256 7 | n_epochs=100 8 | adv_start=10 9 | eval_freq=5 10 | num_workers=64 11 | experiment="AEKL" 12 | 13 | runai submit \ 14 | --name controlnet-aekl-v0 \ 15 | --image aicregistry:5000/wds20:control_brain \ 16 | --backoff-limit 0 \ 17 | --gpu 8 \ 18 | --cpu 32 \ 19 | --memory-limit 512G \ 20 | --node-type "A100" \ 21 | --large-shm \ 22 | --run-as-user \ 23 | --host-ipc \ 24 | --project wds20 \ 25 | --volume /nfs/home/wds20/projects/generative_brain_controlnet/:/project/ \ 26 | --volume /nfs/home/wds20/datasets/Biobank/derivatives/2d_controlnet/:/data/ \ 27 | --command -- bash /project/src/bash/start_script.sh \ 28 | python3 /project/src/python/training/train_aekl.py \ 29 | seed=${seed} \ 30 | run_dir=${run_dir} \ 31 | training_ids=${training_ids} \ 32 | validation_ids=${validation_ids} \ 33 | config_file=${config_file} \ 34 | batch_size=${batch_size} \ 35 | n_epochs=${n_epochs} \ 36 | adv_start=${adv_start} \ 37 | eval_freq=${eval_freq} \ 38 | num_workers=${num_workers} \ 39 | experiment=${experiment} 40 | -------------------------------------------------------------------------------- /configs/controlnet/controlnet_v0.yaml: -------------------------------------------------------------------------------- 1 | controlnet: 2 | base_lr: 0.000025 3 | params: 4 | spatial_dims: 2 5 | in_channels: 3 6 | num_res_blocks: 2 7 | num_channels: [256, 512, 768] 8 | attention_levels: [False, True, True] 9 | with_conditioning: True 
10 | cross_attention_dim: 1024 11 | num_head_channels: [0, 512, 768] 12 | conditioning_embedding_in_channels: 1 13 | conditioning_embedding_num_channels: [64, 128, 128, 256] 14 | 15 | 16 | ldm: 17 | scheduler: 18 | schedule: "scaled_linear_beta" 19 | num_train_timesteps: 1000 20 | beta_start: 0.0015 21 | beta_end: 0.0205 22 | prediction_type: "v_prediction" 23 | -------------------------------------------------------------------------------- /configs/ldm/ldm_v0.yaml: -------------------------------------------------------------------------------- 1 | ldm: 2 | base_lr: 0.000025 3 | params: 4 | spatial_dims: 2 5 | in_channels: 3 6 | out_channels: 3 7 | num_res_blocks: 2 8 | num_channels: [256, 512, 768] 9 | attention_levels: [False, True, True] 10 | with_conditioning: True 11 | cross_attention_dim: 1024 12 | num_head_channels: [0, 512, 768] 13 | scheduler: 14 | schedule: "scaled_linear_beta" 15 | num_train_timesteps: 1000 16 | beta_start: 0.0015 17 | beta_end: 0.0205 18 | prediction_type: "v_prediction" 19 | -------------------------------------------------------------------------------- /configs/stage1/aekl_v0.yaml: -------------------------------------------------------------------------------- 1 | stage1: 2 | base_lr: 0.00001 3 | disc_lr: 0.00005 4 | perceptual_weight: 0.002 5 | adv_weight: 0.005 6 | kl_weight: 0.00000001 7 | params: 8 | spatial_dims: 2 9 | in_channels: 1 10 | out_channels: 1 11 | num_channels: [128, 256, 256, 512] 12 | latent_channels: 3 13 | num_res_blocks: 2 14 | attention_levels: [False, False, False, False] 15 | with_encoder_nonlocal_attn: False 16 | with_decoder_nonlocal_attn: False 17 | 18 | discriminator: 19 | params: 20 | spatial_dims: 2 21 | num_channels: 128 22 | num_layers_d: 3 23 | in_channels: 1 24 | out_channels: 1 25 | 26 | perceptual_network: 27 | params: 28 | spatial_dims: 2 29 | network_type: "squeeze" 30 | -------------------------------------------------------------------------------- /create_image.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # 3 | # A simple script to build the distributed Docker image. 4 | # 5 | # $ create_docker_image.sh 6 | set -ex 7 | TAG=control_brain 8 | 9 | docker build --network=host --tag "aicregistry:5000/${USER}:${TAG}" -f Dockerfile . 
\ 10 | --build-arg USER_ID=$(id -u) \ 11 | --build-arg GROUP_ID=$(id -g) \ 12 | --build-arg USER=${USER} 13 | 14 | docker push "aicregistry:5000/${USER}:${TAG}" 15 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.isort] 2 | line_length = 120 3 | profile = "black" 4 | 5 | [tool.black] 6 | line-length = 120 7 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | black 2 | gdown 3 | isort 4 | lpips==0.1.4 5 | monai-generative 6 | monai[nibabel, pillow, matplotlib, pandas, mlflow, tqdm]>=1.2.0 7 | numpy==1.22.2 8 | omegaconf==2.3.0 9 | pre-commit 10 | protobuf==3.20.3 11 | pynvml 12 | spacy 13 | tensorboardX 14 | transformers[torch]==4.30.2 15 | -------------------------------------------------------------------------------- /src/bash/start_script.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/bash 2 | #print user info 3 | echo "$(id)" 4 | 5 | # Define mlflow 6 | export MLFLOW_TRACKING_URI=file:/project/mlruns 7 | echo ${MLFLOW_TRACKING_URI} 8 | 9 | # Define place to save lpips pretrained models 10 | export TORCH_HOME=/project/outputs/torch_home 11 | export HF_HOME=/project/outputs/hf_home 12 | 13 | # parse arguments 14 | CMD="" 15 | for i in $@; do 16 | if [[ $i == *"="* ]]; then 17 | ARG=${i//=/ } 18 | CMD=$CMD"--$ARG " 19 | else 20 | CMD=$CMD"$i " 21 | fi 22 | done 23 | 24 | # execute comand 25 | echo $CMD 26 | $CMD 27 | -------------------------------------------------------------------------------- /src/python/preprocessing/create_datalist.py: -------------------------------------------------------------------------------- 1 | """ Script to create train, validation and test data lists with paths to t1w and flair images. """ 2 | from pathlib import Path 3 | 4 | import pandas as pd 5 | 6 | 7 | def create_datalist(sub_dirs): 8 | data_list = [] 9 | for sub_dir in sub_dirs: 10 | images_paths = sorted(list(sub_dir.glob("**/*T1w.png"))) 11 | for image_path in images_paths: 12 | flair_path = image_path.parent / (image_path.name.replace("T1w", "FLAIR")) 13 | data_list.append({"t1w": str(image_path), "flair": flair_path}) 14 | 15 | return pd.DataFrame(data_list) 16 | 17 | 18 | def main(): 19 | output_dir = Path("/project/outputs/ids/") 20 | output_dir.mkdir(parents=True, exist_ok=True) 21 | 22 | data_dir = Path("/data/") 23 | sub_dirs = sorted([x for x in data_dir.iterdir() if x.is_dir()]) 24 | 25 | train_sub_dirs = sub_dirs[:35000] 26 | val_sub_dirs = sub_dirs[35000:35500] 27 | test_sub_dirs = sub_dirs[35500:] 28 | 29 | data_df = create_datalist(train_sub_dirs) 30 | data_df.to_csv(output_dir / "train.tsv", index=False, sep="\t") 31 | 32 | data_df = create_datalist(val_sub_dirs) 33 | data_df.to_csv(output_dir / "validation.tsv", index=False, sep="\t") 34 | 35 | data_df = create_datalist(test_sub_dirs) 36 | data_df.to_csv(output_dir / "test.tsv", index=False, sep="\t") 37 | 38 | 39 | if __name__ == "__main__": 40 | main() 41 | -------------------------------------------------------------------------------- /src/python/preprocessing/create_png_dataset.py: -------------------------------------------------------------------------------- 1 | """ Script to create png dataset. 
""" 2 | import argparse 3 | from pathlib import Path 4 | 5 | import nibabel as nib 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | 10 | def parse_args(): 11 | parser = argparse.ArgumentParser() 12 | 13 | parser.add_argument("--start", type=int, help="Starting index") 14 | parser.add_argument("--stop", type=int, help="Stoping index") 15 | 16 | args = parser.parse_args() 17 | return args 18 | 19 | 20 | def main(args): 21 | new_data_dir = Path("/data/") 22 | data_dir = Path("/source_data/") 23 | images_paths = sorted(list(data_dir.glob("**/*T1w.nii.gz"))) 24 | images_paths = images_paths[args.start : args.stop] 25 | 26 | for image_path in tqdm(images_paths): 27 | if "unusable" not in str(image_path): 28 | flair_path = Path(str(image_path).replace("T1w", "FLAIR")) 29 | if flair_path.exists(): 30 | new_sub_dir = new_data_dir / str(image_path.parents[1])[13:] 31 | new_sub_dir.mkdir(parents=True, exist_ok=True) 32 | 33 | t1w = nib.load(image_path) 34 | flair = nib.load(flair_path) 35 | 36 | t1w = t1w.get_fdata() 37 | flair = flair.get_fdata() 38 | 39 | t1w = t1w[16:176, 16:240, 166:171] 40 | flair = flair[16:176, 16:240, 166:171] 41 | 42 | t1w = t1w.astype("float32") 43 | flair = flair.astype("float32") 44 | 45 | t1w = (t1w - t1w.min()) / (t1w.max() - t1w.min()) 46 | flair = (flair - flair.min()) / (flair.max() - flair.min()) 47 | 48 | t1w = (t1w * 255).astype("uint8") 49 | flair = (flair * 255).astype("uint8") 50 | 51 | for i in range(t1w.shape[2]): 52 | t1w_slice = Image.fromarray(t1w[:, :, i]) 53 | flair_slice = Image.fromarray(flair[:, :, i]) 54 | 55 | new_image_path = new_sub_dir / f"{image_path.stem.replace('_T1w.nii', f'_slice-{i}_T1w.png')}" 56 | new_flair_path = new_sub_dir / f"{flair_path.stem.replace('_FLAIR.nii', f'_slice-{i}_FLAIR.png')}" 57 | 58 | t1w_slice.save(new_image_path) 59 | flair_slice.save(new_flair_path) 60 | 61 | 62 | if __name__ == "__main__": 63 | args = parse_args() 64 | main(args) 65 | -------------------------------------------------------------------------------- /src/python/testing/compute_controlnet_performance.py: -------------------------------------------------------------------------------- 1 | """ Script to compute the performance of the ControlNet to convert FLAIR to T1w.""" 2 | import argparse 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | from generative.metrics import MultiScaleSSIMMetric 7 | from monai import transforms 8 | from monai.config import print_config 9 | from monai.metrics import MAEMetric, PSNRMetric 10 | from monai.utils import set_determinism 11 | from tqdm import tqdm 12 | from util import get_test_dataloader 13 | 14 | 15 | def parse_args(): 16 | parser = argparse.ArgumentParser() 17 | 18 | parser.add_argument("--seed", type=int, default=2, help="Random seed to use.") 19 | parser.add_argument("--samples_dir", help="Location of the samples to evaluate.") 20 | parser.add_argument("--test_ids", help="Location of file with test ids.") 21 | parser.add_argument("--num_workers", type=int, default=8, help="Number of loader workers") 22 | 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def main(args): 28 | set_determinism(seed=args.seed) 29 | print_config() 30 | 31 | samples_dir = Path(args.samples_dir) 32 | 33 | sample_transforms = transforms.Compose( 34 | [ 35 | transforms.LoadImaged(keys=["t1w"]), 36 | transforms.EnsureChannelFirstd(keys=["t1w"]), 37 | transforms.Rotate90d(keys=["t1w"], k=-1, spatial_axes=(0, 1)), # Fix flipped image read 38 | transforms.Flipd(keys=["t1w"], spatial_axis=1), # Fix flipped 
image read 39 | transforms.ScaleIntensityRanged(keys=["t1w"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True), 40 | transforms.ToTensord(keys=["t1w"]), 41 | ] 42 | ) 43 | 44 | # Test set 45 | test_loader = get_test_dataloader( 46 | batch_size=1, 47 | test_ids=args.test_ids, 48 | num_workers=args.num_workers, 49 | upper_limit=1000, 50 | ) 51 | 52 | psnr_metric = PSNRMetric(max_val=1.0) 53 | mae_metric = MAEMetric() 54 | mssim_metric = MultiScaleSSIMMetric(spatial_dims=2, kernel_size=7) 55 | 56 | psnr_list = [] 57 | mae_list = [] 58 | mssim_list = [] 59 | for batch in tqdm(test_loader): 60 | img = batch["t1w"] 61 | img_synthetic = sample_transforms( 62 | {"t1w": samples_dir / Path(batch["t1w_meta_dict"]["filename_or_obj"][0]).name} 63 | )["t1w"].unsqueeze(1) 64 | 65 | psnr_value = psnr_metric(img, img_synthetic) 66 | mae_value = mae_metric(img, img_synthetic) 67 | mssim_value = mssim_metric(img, img_synthetic) 68 | 69 | psnr_list.append(psnr_value.item()) 70 | mae_list.append(mae_value.item()) 71 | mssim_list.append(mssim_value.item()) 72 | 73 | print(f"PSNR: {np.mean(psnr_list)}+-{np.std(psnr_list)}") 74 | print(f"MAE: {np.mean(mae_list)}+-{np.std(mae_list)}") 75 | print(f"MSSIM: {np.mean(mssim_list)}+-{np.std(mssim_list)}") 76 | 77 | 78 | if __name__ == "__main__": 79 | args = parse_args() 80 | main(args) 81 | -------------------------------------------------------------------------------- /src/python/testing/compute_fid.py: -------------------------------------------------------------------------------- 1 | """ Script to compute the Frechet Inception Distance (FID) of the samples of the LDM. 2 | 3 | In order to measure the quality of the samples, we use the Frechet Inception Distance (FID) metric between 1200 images 4 | from the MIMIC-CXR dataset and 1000 images from the LDM. 5 | """ 6 | import argparse 7 | from pathlib import Path 8 | 9 | import torch 10 | from generative.metrics import FIDMetric 11 | from monai import transforms 12 | from monai.config import print_config 13 | from monai.data import Dataset 14 | from monai.utils import set_determinism 15 | from torch.utils.data import DataLoader 16 | from tqdm import tqdm 17 | from util import get_test_dataloader 18 | 19 | 20 | def subtract_mean(x: torch.Tensor) -> torch.Tensor: 21 | mean = [0.406, 0.456, 0.485] 22 | x[:, 0, :, :] -= mean[0] 23 | x[:, 1, :, :] -= mean[1] 24 | x[:, 2, :, :] -= mean[2] 25 | return x 26 | 27 | 28 | def spatial_average(x: torch.Tensor, keepdim: bool = True) -> torch.Tensor: 29 | return x.mean([2, 3], keepdim=keepdim) 30 | 31 | 32 | def get_features(image, radnet): 33 | # If input has just 1 channel, repeat channel to have 3 channels 34 | if image.shape[1]: 35 | image = image.repeat(1, 3, 1, 1) 36 | 37 | # Change order from 'RGB' to 'BGR' 38 | image = image[:, [2, 1, 0], ...] 
39 | 40 | # Subtract mean used during training 41 | image = subtract_mean(image) 42 | 43 | # Get model outputs 44 | with torch.no_grad(): 45 | feature_image = radnet.forward(image) 46 | # flattens the image spatially 47 | feature_image = spatial_average(feature_image, keepdim=False) 48 | 49 | return feature_image 50 | 51 | 52 | def parse_args(): 53 | parser = argparse.ArgumentParser() 54 | 55 | parser.add_argument("--seed", type=int, default=2, help="Random seed to use.") 56 | parser.add_argument("--sample_dir", help="Location of the samples to evaluate.") 57 | parser.add_argument("--test_ids", help="Location of file with test ids.") 58 | parser.add_argument("--batch_size", type=int, default=256, help="Batch size.") 59 | parser.add_argument("--num_workers", type=int, default=8, help="Number of loader workers") 60 | 61 | args = parser.parse_args() 62 | return args 63 | 64 | 65 | def main(args): 66 | set_determinism(seed=args.seed) 67 | print_config() 68 | 69 | samples_dir = Path(args.sample_dir) 70 | 71 | # Load pretrained model 72 | device = torch.device("cuda") 73 | model = torch.hub.load("Warvito/radimagenet-models", model="radimagenet_resnet50", verbose=True) 74 | model = model.to(device) 75 | model.eval() 76 | 77 | # Samples 78 | samples_datalist = [] 79 | for sample_path in sorted(list(samples_dir.glob("*.png"))): 80 | samples_datalist.append( 81 | { 82 | "t1w": str(sample_path), 83 | } 84 | ) 85 | print(f"{len(samples_datalist)} images found in {str(samples_dir)}") 86 | 87 | sample_transforms = transforms.Compose( 88 | [ 89 | transforms.LoadImaged(keys=["t1w"]), 90 | transforms.EnsureChannelFirstd(keys=["t1w"]), 91 | transforms.Rotate90d(keys=["t1w"], k=-1, spatial_axes=(0, 1)), # Fix flipped image read 92 | transforms.Flipd(keys=["t1w"], spatial_axis=1), # Fix flipped image read 93 | transforms.ScaleIntensityRanged(keys=["t1w"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True), 94 | transforms.ToTensord(keys=["t1w"]), 95 | ] 96 | ) 97 | 98 | samples_ds = Dataset( 99 | data=samples_datalist, 100 | transform=sample_transforms, 101 | ) 102 | samples_loader = DataLoader( 103 | samples_ds, 104 | batch_size=args.batch_size, 105 | shuffle=False, 106 | num_workers=8, 107 | ) 108 | 109 | samples_features = [] 110 | for batch in tqdm(samples_loader): 111 | img = batch["t1w"] 112 | with torch.no_grad(): 113 | outputs = get_features(img.to(device), radnet=model) 114 | 115 | samples_features.append(outputs.cpu()) 116 | samples_features = torch.cat(samples_features, dim=0) 117 | 118 | # Test set 119 | test_loader = get_test_dataloader( 120 | batch_size=args.batch_size, 121 | test_ids=args.test_ids, 122 | num_workers=args.num_workers, 123 | upper_limit=1000, 124 | ) 125 | 126 | test_features = [] 127 | for batch in tqdm(test_loader): 128 | img = batch["t1w"] 129 | with torch.no_grad(): 130 | outputs = get_features(img.to(device), radnet=model) 131 | 132 | test_features.append(outputs.cpu()) 133 | test_features = torch.cat(test_features, dim=0) 134 | 135 | # Compute FID 136 | metric = FIDMetric() 137 | fid = metric(samples_features, test_features) 138 | 139 | print(f"FID: {fid:.6f}") 140 | 141 | 142 | if __name__ == "__main__": 143 | args = parse_args() 144 | main(args) 145 | -------------------------------------------------------------------------------- /src/python/testing/compute_msssim_reconstruction.py: -------------------------------------------------------------------------------- 1 | """ Script to compute the MS-SSIM score of the reconstructions of the Autoencoder. 
2 | 3 | Here we compute the MS-SSIM score between the images of the test set of the MIMIC-CXR dataset and the reconstructions 4 | created byt the AutoencoderKL. 5 | """ 6 | import argparse 7 | from pathlib import Path 8 | 9 | import pandas as pd 10 | import torch 11 | from generative.metrics import MultiScaleSSIMMetric 12 | from generative.networks.nets import AutoencoderKL 13 | from monai.config import print_config 14 | from monai.utils import set_determinism 15 | from omegaconf import OmegaConf 16 | from tqdm import tqdm 17 | from util import get_test_dataloader 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | 23 | parser.add_argument("--seed", type=int, default=2, help="Random seed to use.") 24 | parser.add_argument("--output_dir", help="Location to save the output.") 25 | parser.add_argument("--test_ids", help="Location of file with test ids.") 26 | parser.add_argument("--batch_size", type=int, default=256, help="Testing batch size.") 27 | parser.add_argument("--config_file", help="Location of config file.") 28 | parser.add_argument("--stage1_path", help="Location of stage1 model.") 29 | parser.add_argument("--num_workers", type=int, default=8, help="Number of loader workers") 30 | 31 | args = parser.parse_args() 32 | return args 33 | 34 | 35 | def main(args): 36 | set_determinism(seed=args.seed) 37 | print_config() 38 | 39 | output_dir = Path(args.output_dir) 40 | output_dir.mkdir(exist_ok=True, parents=True) 41 | 42 | print("Getting data...") 43 | test_loader = get_test_dataloader( 44 | batch_size=args.batch_size, 45 | test_ids=args.test_ids, 46 | num_workers=args.num_workers, 47 | ) 48 | 49 | print("Creating model...") 50 | device = torch.device("cuda") 51 | config = OmegaConf.load(args.config_file) 52 | stage1 = AutoencoderKL(**config["stage1"]["params"]) 53 | stage1.load_state_dict(torch.load(args.stage1_path)) 54 | stage1 = stage1.to(device) 55 | stage1.eval() 56 | 57 | ms_ssim = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_size=7) 58 | 59 | print("Computing MS-SSIM...") 60 | ms_ssim_list = [] 61 | filenames = [] 62 | for batch in tqdm(test_loader): 63 | x = batch["t1w"].to(device) 64 | 65 | with torch.no_grad(): 66 | x_recon = stage1.reconstruct(x) 67 | 68 | ms_ssim_list.append(ms_ssim(x, x_recon)) 69 | filenames.extend(batch["t1w_meta_dict"]["filename_or_obj"]) 70 | 71 | ms_ssim_list = torch.cat(ms_ssim_list, dim=0) 72 | 73 | prediction_df = pd.DataFrame({"filename": filenames, "ms_ssim": ms_ssim_list.cpu()[:, 0]}) 74 | prediction_df.to_csv(output_dir / "ms_ssim_reconstruction.tsv", index=False, sep="\t") 75 | 76 | print(f"Mean MS-SSIM: {ms_ssim_list.mean():.6f}") 77 | 78 | 79 | if __name__ == "__main__": 80 | args = parse_args() 81 | main(args) 82 | -------------------------------------------------------------------------------- /src/python/testing/compute_msssim_sample.py: -------------------------------------------------------------------------------- 1 | """ Script to compute the MS-SSIM score of the samples of the LDM. 2 | 3 | In order to measure the diversity of the samples generated by the LDM, we use the Multi-Scale Structural Similarity 4 | (MS-SSIM) metric between 1000 samples. 
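Every sample is compared with every other sample (self-comparisons are skipped), i.e. roughly
N * (N - 1) ordered pairs for N samples, so the 1000 samples above amount to about 999,000 MS-SSIM
evaluations. A lower mean MS-SSIM indicates more diverse samples, since near-duplicates score close to 1.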
5 | """ 6 | import argparse 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import torch 11 | from generative.metrics import MultiScaleSSIMMetric 12 | from monai import transforms 13 | from monai.config import print_config 14 | from monai.data import CacheDataset 15 | from monai.utils import set_determinism 16 | from torch.utils.data import DataLoader 17 | from tqdm import tqdm 18 | 19 | 20 | def parse_args(): 21 | parser = argparse.ArgumentParser() 22 | 23 | parser.add_argument("--seed", type=int, default=2, help="Random seed to use.") 24 | parser.add_argument("--sample_dir", help="Location of the samples to evaluate.") 25 | parser.add_argument("--num_workers", type=int, default=8, help="Number of loader workers") 26 | 27 | args = parser.parse_args() 28 | return args 29 | 30 | 31 | def main(args): 32 | set_determinism(seed=args.seed) 33 | print_config() 34 | 35 | sample_dir = Path(args.sample_dir) 36 | sample_list = sorted(list(sample_dir.glob("*.png"))) 37 | 38 | datalist = [] 39 | for sample_path in sample_list: 40 | datalist.append( 41 | { 42 | "t1w": str(sample_path), 43 | } 44 | ) 45 | 46 | eval_transforms = transforms.Compose( 47 | [ 48 | transforms.LoadImaged(keys=["t1w"]), 49 | transforms.EnsureChannelFirstd(keys=["t1w"]), 50 | transforms.Rotate90d(keys=["t1w"], k=-1, spatial_axes=(0, 1)), # Fix flipped image read 51 | transforms.Flipd(keys=["t1w"], spatial_axis=1), # Fix flipped image read 52 | transforms.ScaleIntensityRanged(keys=["t1w"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True), 53 | transforms.ToTensord(keys=["t1w"]), 54 | ] 55 | ) 56 | 57 | eval_ds = CacheDataset( 58 | data=datalist, 59 | transform=eval_transforms, 60 | ) 61 | eval_loader = DataLoader( 62 | eval_ds, 63 | batch_size=1, 64 | shuffle=False, 65 | num_workers=args.num_workers, 66 | ) 67 | 68 | eval_ds_2 = CacheDataset( 69 | data=datalist, 70 | transform=eval_transforms, 71 | ) 72 | eval_loader_2 = DataLoader( 73 | eval_ds_2, 74 | batch_size=1, 75 | shuffle=False, 76 | num_workers=args.num_workers, 77 | ) 78 | 79 | device = torch.device("cuda") 80 | ms_ssim = MultiScaleSSIMMetric(spatial_dims=2, data_range=1.0, kernel_size=7) 81 | 82 | print("Computing MS-SSIM...") 83 | ms_ssim_list = [] 84 | pbar = tqdm(enumerate(eval_loader), total=len(eval_loader)) 85 | for step, batch in pbar: 86 | img = batch["t1w"] 87 | for batch2 in eval_loader_2: 88 | img2 = batch2["t1w"] 89 | if batch["t1w_meta_dict"]["filename_or_obj"][0] == batch2["t1w_meta_dict"]["filename_or_obj"][0]: 90 | continue 91 | ms_ssim_list.append(ms_ssim(img.to(device), img2.to(device)).item()) 92 | pbar.update() 93 | 94 | ms_ssim_list = np.array(ms_ssim_list) 95 | print(f"Mean MS-SSIM: {ms_ssim_list.mean():.6f}") 96 | 97 | 98 | if __name__ == "__main__": 99 | args = parse_args() 100 | main(args) 101 | -------------------------------------------------------------------------------- /src/python/testing/convert_mlflow_to_pytorch.py: -------------------------------------------------------------------------------- 1 | """ Script to convert the model from mlflow format to a format suitable for release (.pth). 2 | 3 | All the following scripts will use the .pth format (easly shared). 
4 | """ 5 | import argparse 6 | from pathlib import Path 7 | 8 | import mlflow.pytorch 9 | import torch 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | 15 | parser.add_argument("--stage1_mlflow_path", help="Path to the MLFlow artifact of the stage1.") 16 | parser.add_argument("--diffusion_mlflow_path", help="Path to the MLFlow artifact of the diffusion model.") 17 | parser.add_argument("--controlnet_mlflow_path", help="Path to the MLFlow artifact of the diffusion model.") 18 | parser.add_argument("--output_dir", help="Path to save the .pth file of the diffusion model.") 19 | 20 | args = parser.parse_args() 21 | return args 22 | 23 | 24 | def main(args): 25 | output_dir = Path(args.output_dir) 26 | output_dir.mkdir(exist_ok=True) 27 | 28 | stage1_model = mlflow.pytorch.load_model(args.stage1_mlflow_path) 29 | torch.save(stage1_model.state_dict(), output_dir / "autoencoder.pth") 30 | 31 | diffusion_model = mlflow.pytorch.load_model(args.diffusion_mlflow_path) 32 | torch.save(diffusion_model.state_dict(), output_dir / "diffusion_model.pth") 33 | 34 | controlnet_model = mlflow.pytorch.load_model(args.controlnet_mlflow_path) 35 | torch.save(controlnet_model.state_dict(), output_dir / "controlnet_model.pth") 36 | 37 | 38 | if __name__ == "__main__": 39 | args = parse_args() 40 | main(args) 41 | -------------------------------------------------------------------------------- /src/python/testing/sample_flair_to_t1w.py: -------------------------------------------------------------------------------- 1 | """ Script to generate sample images from the diffusion model. 2 | 3 | In the generation of the images, the script is using a DDIM scheduler. 4 | """ 5 | from __future__ import annotations 6 | 7 | import argparse 8 | from pathlib import Path 9 | 10 | import numpy as np 11 | import torch 12 | from generative.networks.nets import AutoencoderKL, ControlNet, DiffusionModelUNet 13 | from generative.networks.schedulers import DDIMScheduler 14 | from monai.config import print_config 15 | from omegaconf import OmegaConf 16 | from PIL import Image 17 | from tqdm import tqdm 18 | from transformers import CLIPTextModel, CLIPTokenizer 19 | from util import get_test_dataloader 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser() 24 | 25 | parser.add_argument("--output_dir", help="Path to save the .pth file of the diffusion model.") 26 | parser.add_argument("--stage1_path", help="Path to the .pth model from the stage1.") 27 | parser.add_argument("--diffusion_path", help="Path to the .pth model from the diffusion model.") 28 | parser.add_argument("--controlnet_path", help="Path to the .pth model from the diffusion model.") 29 | parser.add_argument("--stage1_config_file_path", help="Path to the .pth model from the stage1.") 30 | parser.add_argument("--diffusion_config_file_path", help="Path to the .pth model from the diffusion model.") 31 | parser.add_argument("--controlnet_config_file_path", help="Path to the .pth model from the diffusion model.") 32 | parser.add_argument("--test_ids", help="Location of file with test ids.") 33 | parser.add_argument("--controlnet_scale", type=float, default=1.0, help="") 34 | parser.add_argument("--guidance_scale", type=float, default=7.0, help="") 35 | parser.add_argument("--x_size", type=int, default=64, help="Latent space x size.") 36 | parser.add_argument("--y_size", type=int, default=64, help="Latent space y size.") 37 | parser.add_argument("--num_workers", type=int, help="") 38 | parser.add_argument("--scale_factor", 
type=float, help="Latent space y size.") 39 | parser.add_argument("--num_inference_steps", type=int, help="") 40 | 41 | args = parser.parse_args() 42 | return args 43 | 44 | 45 | def main(args): 46 | print_config() 47 | 48 | output_dir = Path(args.output_dir) 49 | output_dir.mkdir(exist_ok=True, parents=True) 50 | 51 | device = torch.device("cuda") 52 | 53 | config = OmegaConf.load(args.stage1_config_file_path) 54 | stage1 = AutoencoderKL(**config["stage1"]["params"]) 55 | stage1.load_state_dict(torch.load(args.stage1_path)) 56 | stage1.to(device) 57 | stage1.eval() 58 | 59 | config = OmegaConf.load(args.diffusion_config_file_path) 60 | diffusion = DiffusionModelUNet(**config["ldm"].get("params", dict())) 61 | diffusion.load_state_dict(torch.load(args.diffusion_path)) 62 | diffusion.to(device) 63 | diffusion.eval() 64 | 65 | config = OmegaConf.load(args.controlnet_config_file_path) 66 | controlnet = ControlNet(**config["controlnet"].get("params", dict())) 67 | controlnet.load_state_dict(torch.load(args.controlnet_path)) 68 | controlnet.to(device) 69 | controlnet.eval() 70 | 71 | scheduler = DDIMScheduler( 72 | num_train_timesteps=config["ldm"]["scheduler"]["num_train_timesteps"], 73 | beta_start=config["ldm"]["scheduler"]["beta_start"], 74 | beta_end=config["ldm"]["scheduler"]["beta_end"], 75 | schedule=config["ldm"]["scheduler"]["schedule"], 76 | prediction_type=config["ldm"]["scheduler"]["prediction_type"], 77 | clip_sample=False, 78 | ) 79 | scheduler.set_timesteps(args.num_inference_steps) 80 | 81 | tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer") 82 | text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder") 83 | 84 | prompt = ["", "T1-weighted image of a brain."] 85 | text_inputs = tokenizer( 86 | prompt, 87 | padding="max_length", 88 | max_length=tokenizer.model_max_length, 89 | truncation=True, 90 | return_tensors="pt", 91 | ) 92 | text_input_ids = text_inputs.input_ids 93 | 94 | prompt_embeds = text_encoder(text_input_ids.squeeze(1)) 95 | prompt_embeds = prompt_embeds[0].to(device) 96 | 97 | test_loader = get_test_dataloader( 98 | batch_size=1, 99 | test_ids=args.test_ids, 100 | num_workers=args.num_workers, 101 | upper_limit=1000, 102 | ) 103 | 104 | pbar = tqdm(enumerate(test_loader), total=len(test_loader)) 105 | for i, x in pbar: 106 | images = x["t1w"].to(device) 107 | cond = x["flair"].to(device) 108 | 109 | noise = torch.randn((1, config["controlnet"]["params"]["in_channels"], args.x_size, args.y_size)).to(device) 110 | 111 | with torch.no_grad(): 112 | progress_bar = tqdm(scheduler.timesteps) 113 | for t in progress_bar: 114 | noise_input = torch.cat([noise] * 2) 115 | cond_input = torch.cat([cond] * 2) 116 | down_block_res_samples, mid_block_res_sample = controlnet( 117 | x=noise_input, 118 | timesteps=torch.Tensor((t,)).to(noise.device).long(), 119 | context=prompt_embeds, 120 | controlnet_cond=cond_input, 121 | conditioning_scale=args.controlnet_scale, 122 | ) 123 | 124 | model_output = diffusion( 125 | x=noise_input, 126 | timesteps=torch.Tensor((t,)).to(noise.device).long(), 127 | context=prompt_embeds, 128 | down_block_additional_residuals=down_block_res_samples, 129 | mid_block_additional_residual=mid_block_res_sample, 130 | ) 131 | noise_pred_uncond, noise_pred_text = model_output.chunk(2) 132 | noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_text - noise_pred_uncond) 133 | 134 | noise, _ = scheduler.step(noise_pred, t, noise) 135 | 136 
| with torch.no_grad(): 137 | sample = stage1.decode_stage_2_outputs(noise / args.scale_factor) 138 | 139 | sample = np.clip(sample.cpu().numpy(), 0, 1) 140 | sample = (sample * 255).astype(np.uint8) 141 | im = Image.fromarray(sample[0, 0]) 142 | im.save(output_dir / Path(x["t1w_meta_dict"]["filename_or_obj"][0]).name) 143 | 144 | cond = (cond.cpu().numpy() * 255).astype(np.uint8) 145 | images = (images.cpu().numpy() * 255).astype(np.uint8) 146 | sample = np.concatenate([sample, cond, images], axis=3) 147 | im = Image.fromarray(sample[0, 0]) 148 | im.save(output_dir / f"{Path(x['t1w_meta_dict']['filename_or_obj'][0]).stem}_comparison.png") 149 | 150 | 151 | if __name__ == "__main__": 152 | args = parse_args() 153 | main(args) 154 | -------------------------------------------------------------------------------- /src/python/testing/sample_t1w.py: -------------------------------------------------------------------------------- 1 | """ Script to generate sample images from the diffusion model. 2 | 3 | In the generation of the images, the script is using a DDIM scheduler. 4 | """ 5 | 6 | import argparse 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import torch 11 | from generative.networks.nets import AutoencoderKL, DiffusionModelUNet 12 | from generative.networks.schedulers import DDIMScheduler 13 | from monai.config import print_config 14 | from monai.utils import set_determinism 15 | from omegaconf import OmegaConf 16 | from PIL import Image 17 | from tqdm import tqdm 18 | from transformers import CLIPTextModel, CLIPTokenizer 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument("--output_dir", help="Path to save the .pth file of the diffusion model.") 25 | parser.add_argument("--stage1_path", help="Path to the .pth model from the stage1.") 26 | parser.add_argument("--diffusion_path", help="Path to the .pth model from the diffusion model.") 27 | parser.add_argument("--stage1_config_file_path", help="Path to the .pth model from the stage1.") 28 | parser.add_argument("--diffusion_config_file_path", help="Path to the .pth model from the diffusion model.") 29 | parser.add_argument("--start_seed", type=int, help="Path to the MLFlow artifact of the stage1.") 30 | parser.add_argument("--stop_seed", type=int, help="Path to the MLFlow artifact of the stage1.") 31 | parser.add_argument("--guidance_scale", type=float, default=7.0, help="") 32 | parser.add_argument("--x_size", type=int, default=64, help="Latent space x size.") 33 | parser.add_argument("--y_size", type=int, default=64, help="Latent space y size.") 34 | parser.add_argument("--scale_factor", type=float, help="Latent space y size.") 35 | parser.add_argument("--num_inference_steps", type=int, help="") 36 | 37 | args = parser.parse_args() 38 | return args 39 | 40 | 41 | def main(args): 42 | print_config() 43 | 44 | output_dir = Path(args.output_dir) 45 | output_dir.mkdir(exist_ok=True, parents=True) 46 | 47 | device = torch.device("cuda") 48 | 49 | config = OmegaConf.load(args.stage1_config_file_path) 50 | stage1 = AutoencoderKL(**config["stage1"]["params"]) 51 | stage1.load_state_dict(torch.load(args.stage1_path)) 52 | stage1.to(device) 53 | stage1.eval() 54 | 55 | config = OmegaConf.load(args.diffusion_config_file_path) 56 | diffusion = DiffusionModelUNet(**config["ldm"].get("params", dict())) 57 | diffusion.load_state_dict(torch.load(args.diffusion_path)) 58 | diffusion.to(device) 59 | diffusion.eval() 60 | 61 | scheduler = DDIMScheduler( 62 | 
num_train_timesteps=config["ldm"]["scheduler"]["num_train_timesteps"], 63 | beta_start=config["ldm"]["scheduler"]["beta_start"], 64 | beta_end=config["ldm"]["scheduler"]["beta_end"], 65 | schedule=config["ldm"]["scheduler"]["schedule"], 66 | prediction_type=config["ldm"]["scheduler"]["prediction_type"], 67 | clip_sample=False, 68 | ) 69 | scheduler.set_timesteps(args.num_inference_steps) 70 | 71 | tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer") 72 | text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder") 73 | 74 | prompt = ["", "T1-weighted image of a brain."] 75 | text_inputs = tokenizer( 76 | prompt, 77 | padding="max_length", 78 | max_length=tokenizer.model_max_length, 79 | truncation=True, 80 | return_tensors="pt", 81 | ) 82 | text_input_ids = text_inputs.input_ids 83 | 84 | prompt_embeds = text_encoder(text_input_ids.squeeze(1)) 85 | prompt_embeds = prompt_embeds[0].to(device) 86 | 87 | for i in range(args.start_seed, args.stop_seed): 88 | set_determinism(seed=i) 89 | noise = torch.randn((1, config["ldm"]["params"]["in_channels"], args.x_size, args.y_size)).to(device) 90 | 91 | with torch.no_grad(): 92 | progress_bar = tqdm(scheduler.timesteps) 93 | for t in progress_bar: 94 | noise_input = torch.cat([noise] * 2) 95 | model_output = diffusion( 96 | noise_input, timesteps=torch.Tensor((t,)).to(noise.device).long(), context=prompt_embeds 97 | ) 98 | noise_pred_uncond, noise_pred_text = model_output.chunk(2) 99 | noise_pred = noise_pred_uncond + args.guidance_scale * (noise_pred_text - noise_pred_uncond) 100 | 101 | noise, _ = scheduler.step(noise_pred, t, noise) 102 | 103 | with torch.no_grad(): 104 | sample = stage1.decode_stage_2_outputs(noise / args.scale_factor) 105 | 106 | sample = np.clip(sample.cpu().numpy(), 0, 1) 107 | sample = (sample * 255).astype(np.uint8) 108 | im = Image.fromarray(sample[0, 0]) 109 | im.save(output_dir / f"sample_{i}.png") 110 | 111 | 112 | if __name__ == "__main__": 113 | args = parse_args() 114 | main(args) 115 | -------------------------------------------------------------------------------- /src/python/testing/util.py: -------------------------------------------------------------------------------- 1 | """Utility functions for testing.""" 2 | from __future__ import annotations 3 | 4 | import pandas as pd 5 | from monai import transforms 6 | from monai.data import CacheDataset 7 | from torch.utils.data import DataLoader 8 | 9 | 10 | def get_test_dataloader( 11 | batch_size: int, 12 | test_ids: str, 13 | num_workers: int = 8, 14 | upper_limit: int | None = None, 15 | ): 16 | test_transforms = transforms.Compose( 17 | [ 18 | transforms.LoadImaged(keys=["t1w", "flair"]), 19 | transforms.EnsureChannelFirstd(keys=["t1w", "flair"]), 20 | transforms.Rotate90d(keys=["t1w", "flair"], k=-1, spatial_axes=(0, 1)), # Fix flipped image read 21 | transforms.Flipd(keys=["t1w", "flair"], spatial_axis=1), # Fix flipped image read 22 | transforms.ScaleIntensityRanged( 23 | keys=["t1w", "flair"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True 24 | ), 25 | transforms.ToTensord(keys=["t1w", "flair"]), 26 | ] 27 | ) 28 | 29 | test_dicts = get_datalist(ids_path=test_ids, upper_limit=upper_limit) 30 | test_ds = CacheDataset(data=test_dicts, transform=test_transforms) 31 | test_loader = DataLoader( 32 | test_ds, 33 | batch_size=batch_size, 34 | shuffle=False, 35 | num_workers=num_workers, 36 | drop_last=False, 37 | pin_memory=False, 38 | 
persistent_workers=True, 39 | ) 40 | 41 | return test_loader 42 | 43 | 44 | def get_datalist( 45 | ids_path: str, 46 | upper_limit: int | None = None, 47 | ): 48 | """Get data dicts for data loaders.""" 49 | df = pd.read_csv(ids_path, sep="\t") 50 | 51 | if upper_limit is not None: 52 | df = df[:upper_limit] 53 | 54 | data_dicts = [] 55 | for index, row in df.iterrows(): 56 | data_dicts.append( 57 | { 58 | "t1w": f"{row['t1w']}", 59 | "flair": f"{row['flair']}", 60 | "report": "T1-weighted image of a brain.", 61 | } 62 | ) 63 | 64 | print(f"Found {len(data_dicts)} subjects.") 65 | return data_dicts 66 | -------------------------------------------------------------------------------- /src/python/training/custom_transforms.py: -------------------------------------------------------------------------------- 1 | """Custom transforms to load non-imaging data.""" 2 | from typing import Optional 3 | 4 | from monai.config import KeysCollection 5 | from monai.data.image_reader import ImageReader 6 | from monai.transforms.transform import MapTransform, Transform 7 | from transformers import CLIPTokenizer 8 | 9 | 10 | class ApplyTokenizer(Transform): 11 | """Transformation to apply the CLIP tokenizer.""" 12 | 13 | def __init__(self) -> None: 14 | self.tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="tokenizer") 15 | 16 | def __call__(self, text_input: str): 17 | tokenized_sentence = self.tokenizer( 18 | text_input, 19 | padding="max_length", 20 | max_length=self.tokenizer.model_max_length, 21 | truncation=True, 22 | return_tensors="pt", 23 | ) 24 | return tokenized_sentence.input_ids 25 | 26 | 27 | class ApplyTokenizerd(MapTransform): 28 | def __init__( 29 | self, 30 | keys: KeysCollection, 31 | allow_missing_keys: bool = False, 32 | *args, 33 | **kwargs, 34 | ) -> None: 35 | super().__init__(keys, allow_missing_keys) 36 | self._padding = ApplyTokenizer(*args, **kwargs) 37 | 38 | def __call__(self, data, reader: Optional[ImageReader] = None): 39 | d = dict(data) 40 | for key in self.key_iterator(d): 41 | data = self._padding(d[key]) 42 | d[key] = data 43 | 44 | return d 45 | -------------------------------------------------------------------------------- /src/python/training/train_aekl.py: -------------------------------------------------------------------------------- 1 | """ Training script for the autoencoder with KL regulization. 
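The autoencoder is trained with a composite generator loss whose weights come from the YAML config
(see training_functions.train_epoch_aekl):

    loss = l1_loss + kl_weight * kl_loss + perceptual_weight * p_loss + adv_weight * generator_loss

A PatchGAN discriminator with a least-squares adversarial loss provides the adversarial terms, which are
only enabled from the epoch given by --adv_start onwards. Both networks are trained in mixed precision
with separate gradient scalers.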
""" 2 | import argparse 3 | import warnings 4 | from pathlib import Path 5 | 6 | import torch 7 | import torch.optim as optim 8 | from generative.losses.perceptual import PerceptualLoss 9 | from generative.networks.nets import AutoencoderKL 10 | from generative.networks.nets.patchgan_discriminator import PatchDiscriminator 11 | from monai.config import print_config 12 | from monai.utils import set_determinism 13 | from omegaconf import OmegaConf 14 | from tensorboardX import SummaryWriter 15 | from training_functions import train_aekl 16 | from util import get_dataloader, log_mlflow 17 | 18 | warnings.filterwarnings("ignore") 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument("--seed", type=int, default=2, help="Random seed to use.") 25 | parser.add_argument("--run_dir", help="Location of model to resume.") 26 | parser.add_argument("--training_ids", help="Location of file with training ids.") 27 | parser.add_argument("--validation_ids", help="Location of file with validation ids.") 28 | parser.add_argument("--config_file", help="Location of file with validation ids.") 29 | parser.add_argument("--batch_size", type=int, default=256, help="Training batch size.") 30 | parser.add_argument("--n_epochs", type=int, default=25, help="Number of epochs to train.") 31 | parser.add_argument("--adv_start", type=int, default=25, help="Epoch when the adversarial training starts.") 32 | parser.add_argument("--eval_freq", type=int, default=10, help="Number of epochs to between evaluations.") 33 | parser.add_argument("--num_workers", type=int, default=8, help="Number of loader workers") 34 | parser.add_argument("--experiment", help="Mlflow experiment name.") 35 | 36 | args = parser.parse_args() 37 | return args 38 | 39 | 40 | def main(args): 41 | set_determinism(seed=args.seed) 42 | print_config() 43 | 44 | output_dir = Path("/project/outputs/runs/") 45 | output_dir.mkdir(exist_ok=True, parents=True) 46 | 47 | run_dir = output_dir / args.run_dir 48 | if run_dir.exists() and (run_dir / "checkpoint.pth").exists(): 49 | resume = True 50 | else: 51 | resume = False 52 | run_dir.mkdir(exist_ok=True) 53 | 54 | print(f"Run directory: {str(run_dir)}") 55 | print(f"Arguments: {str(args)}") 56 | for k, v in vars(args).items(): 57 | print(f" {k}: {v}") 58 | 59 | writer_train = SummaryWriter(log_dir=str(run_dir / "train")) 60 | writer_val = SummaryWriter(log_dir=str(run_dir / "val")) 61 | 62 | print("Getting data...") 63 | cache_dir = output_dir / "cached_data_aekl" 64 | cache_dir.mkdir(exist_ok=True) 65 | 66 | train_loader, val_loader = get_dataloader( 67 | cache_dir=cache_dir, 68 | batch_size=args.batch_size, 69 | training_ids=args.training_ids, 70 | validation_ids=args.validation_ids, 71 | num_workers=args.num_workers, 72 | model_type="autoencoder", 73 | ) 74 | 75 | print("Creating model...") 76 | config = OmegaConf.load(args.config_file) 77 | model = AutoencoderKL(**config["stage1"]["params"]) 78 | discriminator = PatchDiscriminator(**config["discriminator"]["params"]) 79 | perceptual_loss = PerceptualLoss(**config["perceptual_network"]["params"]) 80 | 81 | print(f"Let's use {torch.cuda.device_count()} GPUs!") 82 | device = torch.device("cuda") 83 | if torch.cuda.device_count() > 1: 84 | model = torch.nn.DataParallel(model) 85 | discriminator = torch.nn.DataParallel(discriminator) 86 | perceptual_loss = torch.nn.DataParallel(perceptual_loss) 87 | 88 | model = model.to(device) 89 | perceptual_loss = perceptual_loss.to(device) 90 | discriminator = 
discriminator.to(device) 91 | 92 | # Optimizers 93 | optimizer_g = optim.Adam(model.parameters(), lr=config["stage1"]["base_lr"]) 94 | optimizer_d = optim.Adam(discriminator.parameters(), lr=config["stage1"]["disc_lr"]) 95 | 96 | # Get Checkpoint 97 | best_loss = float("inf") 98 | start_epoch = 0 99 | if resume: 100 | print(f"Using checkpoint!") 101 | checkpoint = torch.load(str(run_dir / "checkpoint.pth")) 102 | model.load_state_dict(checkpoint["state_dict"]) 103 | discriminator.load_state_dict(checkpoint["discriminator"]) 104 | optimizer_g.load_state_dict(checkpoint["optimizer_g"]) 105 | optimizer_d.load_state_dict(checkpoint["optimizer_d"]) 106 | start_epoch = checkpoint["epoch"] 107 | best_loss = checkpoint["best_loss"] 108 | else: 109 | print(f"No checkpoint found.") 110 | 111 | # Train model 112 | print(f"Starting Training") 113 | val_loss = train_aekl( 114 | model=model, 115 | discriminator=discriminator, 116 | perceptual_loss=perceptual_loss, 117 | start_epoch=start_epoch, 118 | best_loss=best_loss, 119 | train_loader=train_loader, 120 | val_loader=val_loader, 121 | optimizer_g=optimizer_g, 122 | optimizer_d=optimizer_d, 123 | n_epochs=args.n_epochs, 124 | eval_freq=args.eval_freq, 125 | writer_train=writer_train, 126 | writer_val=writer_val, 127 | device=device, 128 | run_dir=run_dir, 129 | kl_weight=config["stage1"]["kl_weight"], 130 | adv_weight=config["stage1"]["adv_weight"], 131 | perceptual_weight=config["stage1"]["perceptual_weight"], 132 | adv_start=args.adv_start, 133 | ) 134 | 135 | log_mlflow( 136 | model=model, 137 | config=config, 138 | args=args, 139 | experiment=args.experiment, 140 | run_dir=run_dir, 141 | val_loss=val_loss, 142 | ) 143 | 144 | 145 | if __name__ == "__main__": 146 | args = parse_args() 147 | main(args) 148 | -------------------------------------------------------------------------------- /src/python/training/train_controlnet.py: -------------------------------------------------------------------------------- 1 | """ Training script for the controlnet model in the latent space of the pretraine AEKL model. 
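The ControlNet is initialised by copying the weights of the pretrained diffusion model, and only the
ControlNet parameters are optimised; the diffusion model is frozen and the stage 1 AEKL is used solely
to encode images into latents (see below in this script):

    controlnet.load_state_dict(diffusion.state_dict(), strict=False)
    for p in diffusion.parameters():
        p.requires_grad = False

During training the FLAIR image is passed as controlnet_cond, the ControlNet residuals are injected into
the frozen diffusion UNet, and the usual noise-prediction MSE loss is computed against the epsilon or
v-prediction target selected in the scheduler config.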
""" 2 | import argparse 3 | import warnings 4 | from pathlib import Path 5 | 6 | import mlflow.pytorch 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from generative.networks.nets import ControlNet 11 | from generative.networks.schedulers import DDPMScheduler 12 | from monai.config import print_config 13 | from monai.utils import set_determinism 14 | from omegaconf import OmegaConf 15 | from tensorboardX import SummaryWriter 16 | from training_functions import train_controlnet 17 | from transformers import CLIPTextModel 18 | from util import get_dataloader, log_mlflow 19 | 20 | warnings.filterwarnings("ignore") 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser() 25 | 26 | parser.add_argument("--seed", type=int, default=2, help="Random seed to use.") 27 | parser.add_argument("--run_dir", help="Location of model to resume.") 28 | parser.add_argument("--training_ids", help="Location of file with training ids.") 29 | parser.add_argument("--validation_ids", help="Location of file with validation ids.") 30 | parser.add_argument("--config_file", help="Location of file with validation ids.") 31 | parser.add_argument("--stage1_uri", help="Path readable by load_model.") 32 | parser.add_argument("--ddpm_uri", help="Path readable by load_model.") 33 | parser.add_argument("--scale_factor", type=float, help="Path readable by load_model.") 34 | parser.add_argument("--batch_size", type=int, default=256, help="Training batch size.") 35 | parser.add_argument("--n_epochs", type=int, default=25, help="Number of epochs to train.") 36 | parser.add_argument("--eval_freq", type=int, default=10, help="Number of epochs to between evaluations.") 37 | parser.add_argument("--num_workers", type=int, default=8, help="Number of loader workers") 38 | parser.add_argument("--experiment", help="Mlflow experiment name.") 39 | 40 | args = parser.parse_args() 41 | return args 42 | 43 | 44 | class Stage1Wrapper(nn.Module): 45 | """Wrapper for stage 1 model as a workaround for the DataParallel usage in the training loop.""" 46 | 47 | def __init__(self, model: nn.Module) -> None: 48 | super().__init__() 49 | self.model = model 50 | 51 | def forward(self, x: torch.Tensor) -> torch.Tensor: 52 | z_mu, z_sigma = self.model.encode(x) 53 | z = self.model.sampling(z_mu, z_sigma) 54 | return z 55 | 56 | 57 | def main(args): 58 | set_determinism(seed=args.seed) 59 | print_config() 60 | 61 | output_dir = Path("/project/outputs/runs/") 62 | output_dir.mkdir(exist_ok=True, parents=True) 63 | 64 | run_dir = output_dir / args.run_dir 65 | if run_dir.exists() and (run_dir / "checkpoint.pth").exists(): 66 | resume = True 67 | else: 68 | resume = False 69 | run_dir.mkdir(exist_ok=True) 70 | 71 | print(f"Run directory: {str(run_dir)}") 72 | print(f"Arguments: {str(args)}") 73 | for k, v in vars(args).items(): 74 | print(f" {k}: {v}") 75 | 76 | writer_train = SummaryWriter(log_dir=str(run_dir / "train")) 77 | writer_val = SummaryWriter(log_dir=str(run_dir / "val")) 78 | 79 | print("Getting data...") 80 | cache_dir = output_dir / "cached_data_controlnet" 81 | cache_dir.mkdir(exist_ok=True) 82 | 83 | train_loader, val_loader = get_dataloader( 84 | cache_dir=cache_dir, 85 | batch_size=args.batch_size, 86 | training_ids=args.training_ids, 87 | validation_ids=args.validation_ids, 88 | num_workers=args.num_workers, 89 | model_type="controlnet", 90 | ) 91 | 92 | print(f"Loading Stage 1 from {args.stage1_uri}") 93 | stage1 = mlflow.pytorch.load_model(args.stage1_uri) 94 | stage1 = 
Stage1Wrapper(model=stage1) 95 | stage1.eval() 96 | 97 | print(f"Loading Diffusion rom {args.ddpm_uri}") 98 | diffusion = mlflow.pytorch.load_model(args.ddpm_uri) 99 | diffusion.eval() 100 | 101 | print("Creating model...") 102 | config = OmegaConf.load(args.config_file) 103 | controlnet = ControlNet(**config["controlnet"].get("params", dict())) 104 | 105 | # Copy weights from the DM to the controlnet 106 | controlnet.load_state_dict(diffusion.state_dict(), strict=False) 107 | 108 | # Freeze the weights of the diffusion model 109 | for p in diffusion.parameters(): 110 | p.requires_grad = False 111 | 112 | scheduler = DDPMScheduler(**config["ldm"].get("scheduler", dict())) 113 | 114 | text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder") 115 | 116 | print(f"Let's use {torch.cuda.device_count()} GPUs!") 117 | device = torch.device("cuda") 118 | if torch.cuda.device_count() > 1: 119 | stage1 = torch.nn.DataParallel(stage1) 120 | diffusion = torch.nn.DataParallel(diffusion) 121 | controlnet = torch.nn.DataParallel(controlnet) 122 | text_encoder = torch.nn.DataParallel(text_encoder) 123 | 124 | stage1 = stage1.to(device) 125 | diffusion = diffusion.to(device) 126 | controlnet = controlnet.to(device) 127 | text_encoder = text_encoder.to(device) 128 | 129 | optimizer = optim.AdamW(controlnet.parameters(), lr=config["controlnet"]["base_lr"]) 130 | 131 | # Get Checkpoint 132 | best_loss = float("inf") 133 | start_epoch = 0 134 | if resume: 135 | print(f"Using checkpoint!") 136 | checkpoint = torch.load(str(run_dir / "checkpoint.pth")) 137 | controlnet.load_state_dict(checkpoint["controlnet"]) 138 | # Issue loading optimizer https://github.com/pytorch/pytorch/issues/2830 139 | optimizer.load_state_dict(checkpoint["optimizer"]) 140 | start_epoch = checkpoint["epoch"] 141 | best_loss = checkpoint["best_loss"] 142 | else: 143 | print(f"No checkpoint found.") 144 | 145 | # Train model 146 | print(f"Starting Training") 147 | val_loss = train_controlnet( 148 | controlnet=controlnet, 149 | diffusion=diffusion, 150 | stage1=stage1, 151 | scheduler=scheduler, 152 | text_encoder=text_encoder, 153 | start_epoch=start_epoch, 154 | best_loss=best_loss, 155 | train_loader=train_loader, 156 | val_loader=val_loader, 157 | optimizer=optimizer, 158 | n_epochs=args.n_epochs, 159 | eval_freq=args.eval_freq, 160 | writer_train=writer_train, 161 | writer_val=writer_val, 162 | device=device, 163 | run_dir=run_dir, 164 | scale_factor=args.scale_factor, 165 | ) 166 | 167 | log_mlflow( 168 | model=controlnet, 169 | config=config, 170 | args=args, 171 | experiment=args.experiment, 172 | run_dir=run_dir, 173 | val_loss=val_loss, 174 | ) 175 | 176 | 177 | if __name__ == "__main__": 178 | args = parse_args() 179 | main(args) 180 | -------------------------------------------------------------------------------- /src/python/training/train_ldm.py: -------------------------------------------------------------------------------- 1 | """ Training script for the diffusion model in the latent space of the pretraine AEKL model. 
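Each image is encoded by the frozen stage 1 AutoencoderKL and the latent is multiplied by --scale_factor;
noise is then added at a random timestep with the DDPM scheduler and the UNet is trained to predict the
target selected in the config (see training_functions.train_epoch_ldm):

    if scheduler.prediction_type == "v_prediction":
        target = scheduler.get_velocity(e, noise, timesteps)
    elif scheduler.prediction_type == "epsilon":
        target = noise
    loss = F.mse_loss(noise_pred.float(), target.float())

Conditioning comes from the frozen CLIP text encoder applied to the tokenised report attached to each
subject.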
""" 2 | import argparse 3 | import warnings 4 | from pathlib import Path 5 | 6 | import mlflow.pytorch 7 | import torch 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from generative.networks.nets import DiffusionModelUNet 11 | from generative.networks.schedulers import DDPMScheduler 12 | from monai.config import print_config 13 | from monai.utils import set_determinism 14 | from omegaconf import OmegaConf 15 | from tensorboardX import SummaryWriter 16 | from training_functions import train_ldm 17 | from transformers import CLIPTextModel 18 | from util import get_dataloader, log_mlflow 19 | 20 | warnings.filterwarnings("ignore") 21 | 22 | 23 | def parse_args(): 24 | parser = argparse.ArgumentParser() 25 | 26 | parser.add_argument("--seed", type=int, default=2, help="Random seed to use.") 27 | parser.add_argument("--run_dir", help="Location of model to resume.") 28 | parser.add_argument("--training_ids", help="Location of file with training ids.") 29 | parser.add_argument("--validation_ids", help="Location of file with validation ids.") 30 | parser.add_argument("--config_file", help="Location of file with validation ids.") 31 | parser.add_argument("--stage1_uri", help="Path readable by load_model.") 32 | parser.add_argument("--scale_factor", type=float, help="Path readable by load_model.") 33 | parser.add_argument("--batch_size", type=int, default=256, help="Training batch size.") 34 | parser.add_argument("--n_epochs", type=int, default=25, help="Number of epochs to train.") 35 | parser.add_argument("--eval_freq", type=int, default=10, help="Number of epochs to between evaluations.") 36 | parser.add_argument("--num_workers", type=int, default=8, help="Number of loader workers") 37 | parser.add_argument("--experiment", help="Mlflow experiment name.") 38 | 39 | args = parser.parse_args() 40 | return args 41 | 42 | 43 | class Stage1Wrapper(nn.Module): 44 | """Wrapper for stage 1 model as a workaround for the DataParallel usage in the training loop.""" 45 | 46 | def __init__(self, model: nn.Module) -> None: 47 | super().__init__() 48 | self.model = model 49 | 50 | def forward(self, x: torch.Tensor) -> torch.Tensor: 51 | z_mu, z_sigma = self.model.encode(x) 52 | z = self.model.sampling(z_mu, z_sigma) 53 | return z 54 | 55 | 56 | def main(args): 57 | set_determinism(seed=args.seed) 58 | print_config() 59 | 60 | output_dir = Path("/project/outputs/runs/") 61 | output_dir.mkdir(exist_ok=True, parents=True) 62 | 63 | run_dir = output_dir / args.run_dir 64 | if run_dir.exists() and (run_dir / "checkpoint.pth").exists(): 65 | resume = True 66 | else: 67 | resume = False 68 | run_dir.mkdir(exist_ok=True) 69 | 70 | print(f"Run directory: {str(run_dir)}") 71 | print(f"Arguments: {str(args)}") 72 | for k, v in vars(args).items(): 73 | print(f" {k}: {v}") 74 | 75 | writer_train = SummaryWriter(log_dir=str(run_dir / "train")) 76 | writer_val = SummaryWriter(log_dir=str(run_dir / "val")) 77 | 78 | print("Getting data...") 79 | cache_dir = output_dir / "cached_data_diffusion" 80 | cache_dir.mkdir(exist_ok=True) 81 | 82 | train_loader, val_loader = get_dataloader( 83 | cache_dir=cache_dir, 84 | batch_size=args.batch_size, 85 | training_ids=args.training_ids, 86 | validation_ids=args.validation_ids, 87 | num_workers=args.num_workers, 88 | model_type="diffusion", 89 | ) 90 | 91 | # Load Autoencoder to produce the latent representations 92 | print(f"Loading Stage 1 from {args.stage1_uri}") 93 | stage1 = mlflow.pytorch.load_model(args.stage1_uri) 94 | stage1 = Stage1Wrapper(model=stage1) 95 | 
stage1.eval() 96 | 97 | # Create the diffusion model 98 | print("Creating model...") 99 | config = OmegaConf.load(args.config_file) 100 | diffusion = DiffusionModelUNet(**config["ldm"].get("params", dict())) 101 | scheduler = DDPMScheduler(**config["ldm"].get("scheduler", dict())) 102 | 103 | text_encoder = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1-base", subfolder="text_encoder") 104 | 105 | print(f"Let's use {torch.cuda.device_count()} GPUs!") 106 | device = torch.device("cuda") 107 | if torch.cuda.device_count() > 1: 108 | stage1 = torch.nn.DataParallel(stage1) 109 | diffusion = torch.nn.DataParallel(diffusion) 110 | text_encoder = torch.nn.DataParallel(text_encoder) 111 | 112 | stage1 = stage1.to(device) 113 | diffusion = diffusion.to(device) 114 | text_encoder = text_encoder.to(device) 115 | 116 | optimizer = optim.AdamW(diffusion.parameters(), lr=config["ldm"]["base_lr"]) 117 | 118 | # Get Checkpoint 119 | best_loss = float("inf") 120 | start_epoch = 0 121 | if resume: 122 | print(f"Using checkpoint!") 123 | checkpoint = torch.load(str(run_dir / "checkpoint.pth")) 124 | diffusion.load_state_dict(checkpoint["diffusion"]) 125 | # Issue loading optimizer https://github.com/pytorch/pytorch/issues/2830 126 | optimizer.load_state_dict(checkpoint["optimizer"]) 127 | start_epoch = checkpoint["epoch"] 128 | best_loss = checkpoint["best_loss"] 129 | else: 130 | print(f"No checkpoint found.") 131 | 132 | # Train model 133 | print(f"Starting Training") 134 | val_loss = train_ldm( 135 | model=diffusion, 136 | stage1=stage1, 137 | scheduler=scheduler, 138 | text_encoder=text_encoder, 139 | start_epoch=start_epoch, 140 | best_loss=best_loss, 141 | train_loader=train_loader, 142 | val_loader=val_loader, 143 | optimizer=optimizer, 144 | n_epochs=args.n_epochs, 145 | eval_freq=args.eval_freq, 146 | writer_train=writer_train, 147 | writer_val=writer_val, 148 | device=device, 149 | run_dir=run_dir, 150 | scale_factor=args.scale_factor, 151 | ) 152 | 153 | log_mlflow( 154 | model=diffusion, 155 | config=config, 156 | args=args, 157 | experiment=args.experiment, 158 | run_dir=run_dir, 159 | val_loss=val_loss, 160 | ) 161 | 162 | 163 | if __name__ == "__main__": 164 | args = parse_args() 165 | main(args) 166 | -------------------------------------------------------------------------------- /src/python/training/training_functions.py: -------------------------------------------------------------------------------- 1 | """ Training functions for the different models. 
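The module provides a train/eval loop per model: train_aekl / eval_aekl for the autoencoder,
train_ldm / eval_ldm for the latent diffusion model, and train_controlnet / eval_controlnet for the
ControlNet. All loops run in mixed precision (autocast + GradScaler), log scalars to TensorBoard, save a
checkpoint.pth after each evaluation, and write the final (and, for the LDM and ControlNet, best) model
weights to the run directory.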
""" 2 | from collections import OrderedDict 3 | from pathlib import Path 4 | 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | from generative.losses.adversarial_loss import PatchAdversarialLoss 9 | from pynvml.smi import nvidia_smi 10 | from tensorboardX import SummaryWriter 11 | from torch.cuda.amp import GradScaler, autocast 12 | from tqdm import tqdm 13 | from util import log_ldm_sample_unconditioned, log_reconstructions 14 | 15 | 16 | def get_lr(optimizer): 17 | for param_group in optimizer.param_groups: 18 | return param_group["lr"] 19 | 20 | 21 | def print_gpu_memory_report(): 22 | if torch.cuda.is_available(): 23 | nvsmi = nvidia_smi.getInstance() 24 | data = nvsmi.DeviceQuery("memory.used, memory.total, utilization.gpu")["gpu"] 25 | print("Memory report") 26 | for i, data_by_rank in enumerate(data): 27 | mem_report = data_by_rank["fb_memory_usage"] 28 | print(f"gpu:{i} mem(%) {int(mem_report['used'] * 100.0 / mem_report['total'])}") 29 | 30 | 31 | # ---------------------------------------------------------------------------------------------------------------------- 32 | # AUTOENCODER KL 33 | # ---------------------------------------------------------------------------------------------------------------------- 34 | def train_aekl( 35 | model: nn.Module, 36 | discriminator: nn.Module, 37 | perceptual_loss: nn.Module, 38 | start_epoch: int, 39 | best_loss: float, 40 | train_loader: torch.utils.data.DataLoader, 41 | val_loader: torch.utils.data.DataLoader, 42 | optimizer_g: torch.optim.Optimizer, 43 | optimizer_d: torch.optim.Optimizer, 44 | n_epochs: int, 45 | eval_freq: int, 46 | writer_train: SummaryWriter, 47 | writer_val: SummaryWriter, 48 | device: torch.device, 49 | run_dir: Path, 50 | adv_weight: float, 51 | perceptual_weight: float, 52 | kl_weight: float, 53 | adv_start: int, 54 | ) -> float: 55 | scaler_g = GradScaler() 56 | scaler_d = GradScaler() 57 | 58 | raw_model = model.module if hasattr(model, "module") else model 59 | 60 | val_loss = eval_aekl( 61 | model=model, 62 | discriminator=discriminator, 63 | perceptual_loss=perceptual_loss, 64 | loader=val_loader, 65 | device=device, 66 | step=len(train_loader) * start_epoch, 67 | writer=writer_val, 68 | kl_weight=kl_weight, 69 | adv_weight=adv_weight if start_epoch >= adv_start else 0.0, 70 | perceptual_weight=perceptual_weight, 71 | ) 72 | print(f"epoch {start_epoch} val loss: {val_loss:.4f}") 73 | for epoch in range(start_epoch, n_epochs): 74 | train_epoch_aekl( 75 | model=model, 76 | discriminator=discriminator, 77 | perceptual_loss=perceptual_loss, 78 | loader=train_loader, 79 | optimizer_g=optimizer_g, 80 | optimizer_d=optimizer_d, 81 | device=device, 82 | epoch=epoch, 83 | writer=writer_train, 84 | kl_weight=kl_weight, 85 | adv_weight=adv_weight if epoch >= adv_start else 0.0, 86 | perceptual_weight=perceptual_weight, 87 | scaler_g=scaler_g, 88 | scaler_d=scaler_d, 89 | ) 90 | 91 | if (epoch + 1) % eval_freq == 0: 92 | val_loss = eval_aekl( 93 | model=model, 94 | discriminator=discriminator, 95 | perceptual_loss=perceptual_loss, 96 | loader=val_loader, 97 | device=device, 98 | step=len(train_loader) * epoch, 99 | writer=writer_val, 100 | kl_weight=kl_weight, 101 | adv_weight=adv_weight if epoch >= adv_start else 0.0, 102 | perceptual_weight=perceptual_weight, 103 | ) 104 | print(f"epoch {epoch + 1} val loss: {val_loss:.4f}") 105 | print_gpu_memory_report() 106 | 107 | # Save checkpoint 108 | checkpoint = { 109 | "epoch": epoch + 1, 110 | "state_dict": model.state_dict(), 111 | 
"discriminator": discriminator.state_dict(), 112 | "optimizer_g": optimizer_g.state_dict(), 113 | "optimizer_d": optimizer_d.state_dict(), 114 | "best_loss": best_loss, 115 | } 116 | torch.save(checkpoint, str(run_dir / "checkpoint.pth")) 117 | 118 | if val_loss <= best_loss: 119 | print(f"New best val loss {val_loss}") 120 | best_loss = val_loss 121 | 122 | print(f"Training finished!") 123 | print(f"Saving final model...") 124 | torch.save(raw_model.state_dict(), str(run_dir / "final_model.pth")) 125 | 126 | return val_loss 127 | 128 | 129 | def train_epoch_aekl( 130 | model: nn.Module, 131 | discriminator: nn.Module, 132 | perceptual_loss: nn.Module, 133 | loader: torch.utils.data.DataLoader, 134 | optimizer_g: torch.optim.Optimizer, 135 | optimizer_d: torch.optim.Optimizer, 136 | device: torch.device, 137 | epoch: int, 138 | writer: SummaryWriter, 139 | kl_weight: float, 140 | adv_weight: float, 141 | perceptual_weight: float, 142 | scaler_g: GradScaler, 143 | scaler_d: GradScaler, 144 | ) -> None: 145 | model.train() 146 | discriminator.train() 147 | 148 | adv_loss = PatchAdversarialLoss(criterion="least_squares", no_activation_leastsq=True) 149 | 150 | pbar = tqdm(enumerate(loader), total=len(loader)) 151 | for step, x in pbar: 152 | images = x["t1w"].to(device) 153 | 154 | # GENERATOR 155 | optimizer_g.zero_grad(set_to_none=True) 156 | with autocast(enabled=True): 157 | reconstruction, z_mu, z_sigma = model(x=images) 158 | l1_loss = F.l1_loss(reconstruction.float(), images.float()) 159 | p_loss = perceptual_loss(reconstruction.float(), images.float()) 160 | 161 | kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3]) 162 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 163 | 164 | if adv_weight > 0: 165 | logits_fake = discriminator(reconstruction.contiguous().float())[-1] 166 | generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False) 167 | else: 168 | generator_loss = torch.tensor([0.0]).to(device) 169 | 170 | loss = l1_loss + kl_weight * kl_loss + perceptual_weight * p_loss + adv_weight * generator_loss 171 | 172 | loss = loss.mean() 173 | l1_loss = l1_loss.mean() 174 | p_loss = p_loss.mean() 175 | kl_loss = kl_loss.mean() 176 | g_loss = generator_loss.mean() 177 | 178 | losses = OrderedDict( 179 | loss=loss, 180 | l1_loss=l1_loss, 181 | p_loss=p_loss, 182 | kl_loss=kl_loss, 183 | g_loss=g_loss, 184 | ) 185 | 186 | scaler_g.scale(losses["loss"]).backward() 187 | scaler_g.unscale_(optimizer_g) 188 | torch.nn.utils.clip_grad_norm_(model.parameters(), 1) 189 | scaler_g.step(optimizer_g) 190 | scaler_g.update() 191 | 192 | # DISCRIMINATOR 193 | if adv_weight > 0: 194 | optimizer_d.zero_grad(set_to_none=True) 195 | 196 | with autocast(enabled=True): 197 | logits_fake = discriminator(reconstruction.contiguous().detach())[-1] 198 | loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True) 199 | logits_real = discriminator(images.contiguous().detach())[-1] 200 | loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True) 201 | discriminator_loss = (loss_d_fake + loss_d_real) * 0.5 202 | 203 | d_loss = adv_weight * discriminator_loss 204 | d_loss = d_loss.mean() 205 | 206 | scaler_d.scale(d_loss).backward() 207 | scaler_d.unscale_(optimizer_d) 208 | torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1) 209 | scaler_d.step(optimizer_d) 210 | scaler_d.update() 211 | else: 212 | discriminator_loss = torch.tensor([0.0]).to(device) 213 | 214 | losses["d_loss"] = 
discriminator_loss 215 | 216 | writer.add_scalar("lr_g", get_lr(optimizer_g), epoch * len(loader) + step) 217 | writer.add_scalar("lr_d", get_lr(optimizer_d), epoch * len(loader) + step) 218 | for k, v in losses.items(): 219 | writer.add_scalar(f"{k}", v.item(), epoch * len(loader) + step) 220 | 221 | pbar.set_postfix( 222 | { 223 | "epoch": epoch, 224 | "loss": f"{losses['loss'].item():.6f}", 225 | "l1_loss": f"{losses['l1_loss'].item():.6f}", 226 | "p_loss": f"{losses['p_loss'].item():.6f}", 227 | "g_loss": f"{losses['g_loss'].item():.6f}", 228 | "d_loss": f"{losses['d_loss'].item():.6f}", 229 | "lr_g": f"{get_lr(optimizer_g):.6f}", 230 | "lr_d": f"{get_lr(optimizer_d):.6f}", 231 | }, 232 | ) 233 | 234 | 235 | @torch.no_grad() 236 | def eval_aekl( 237 | model: nn.Module, 238 | discriminator: nn.Module, 239 | perceptual_loss: nn.Module, 240 | loader: torch.utils.data.DataLoader, 241 | device: torch.device, 242 | step: int, 243 | writer: SummaryWriter, 244 | kl_weight: float, 245 | adv_weight: float, 246 | perceptual_weight: float, 247 | ) -> float: 248 | model.eval() 249 | discriminator.eval() 250 | 251 | adv_loss = PatchAdversarialLoss(criterion="least_squares", no_activation_leastsq=True) 252 | total_losses = OrderedDict() 253 | for x in loader: 254 | images = x["t1w"].to(device) 255 | 256 | with autocast(enabled=True): 257 | # GENERATOR 258 | reconstruction, z_mu, z_sigma = model(x=images) 259 | l1_loss = F.l1_loss(reconstruction.float(), images.float()) 260 | p_loss = perceptual_loss(reconstruction.float(), images.float()) 261 | kl_loss = 0.5 * torch.sum(z_mu.pow(2) + z_sigma.pow(2) - torch.log(z_sigma.pow(2)) - 1, dim=[1, 2, 3]) 262 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 263 | 264 | if adv_weight > 0: 265 | logits_fake = discriminator(reconstruction.contiguous().float())[-1] 266 | generator_loss = adv_loss(logits_fake, target_is_real=True, for_discriminator=False) 267 | else: 268 | generator_loss = torch.tensor([0.0]).to(device) 269 | 270 | # DISCRIMINATOR 271 | if adv_weight > 0: 272 | logits_fake = discriminator(reconstruction.contiguous().detach())[-1] 273 | loss_d_fake = adv_loss(logits_fake, target_is_real=False, for_discriminator=True) 274 | logits_real = discriminator(images.contiguous().detach())[-1] 275 | loss_d_real = adv_loss(logits_real, target_is_real=True, for_discriminator=True) 276 | discriminator_loss = (loss_d_fake + loss_d_real) * 0.5 277 | else: 278 | discriminator_loss = torch.tensor([0.0]).to(device) 279 | 280 | loss = l1_loss + kl_weight * kl_loss + perceptual_weight * p_loss + adv_weight * generator_loss 281 | 282 | loss = loss.mean() 283 | l1_loss = l1_loss.mean() 284 | p_loss = p_loss.mean() 285 | kl_loss = kl_loss.mean() 286 | g_loss = generator_loss.mean() 287 | d_loss = discriminator_loss.mean() 288 | 289 | losses = OrderedDict( 290 | loss=loss, 291 | l1_loss=l1_loss, 292 | p_loss=p_loss, 293 | kl_loss=kl_loss, 294 | g_loss=g_loss, 295 | d_loss=d_loss, 296 | ) 297 | 298 | for k, v in losses.items(): 299 | total_losses[k] = total_losses.get(k, 0) + v.item() * images.shape[0] 300 | 301 | for k in total_losses.keys(): 302 | total_losses[k] /= len(loader.dataset) 303 | 304 | for k, v in total_losses.items(): 305 | writer.add_scalar(f"{k}", v, step) 306 | 307 | log_reconstructions( 308 | image=images, 309 | reconstruction=reconstruction, 310 | writer=writer, 311 | step=step, 312 | ) 313 | 314 | return total_losses["l1_loss"] 315 | 316 | 317 | # 
---------------------------------------------------------------------------------------------------------------------- 318 | # Latent Diffusion Model 319 | # ---------------------------------------------------------------------------------------------------------------------- 320 | def train_ldm( 321 | model: nn.Module, 322 | stage1: nn.Module, 323 | scheduler: nn.Module, 324 | text_encoder, 325 | start_epoch: int, 326 | best_loss: float, 327 | train_loader: torch.utils.data.DataLoader, 328 | val_loader: torch.utils.data.DataLoader, 329 | optimizer: torch.optim.Optimizer, 330 | n_epochs: int, 331 | eval_freq: int, 332 | writer_train: SummaryWriter, 333 | writer_val: SummaryWriter, 334 | device: torch.device, 335 | run_dir: Path, 336 | scale_factor: float = 1.0, 337 | ) -> float: 338 | scaler = GradScaler() 339 | raw_model = model.module if hasattr(model, "module") else model 340 | 341 | val_loss = eval_ldm( 342 | model=model, 343 | stage1=stage1, 344 | scheduler=scheduler, 345 | text_encoder=text_encoder, 346 | loader=val_loader, 347 | device=device, 348 | step=len(train_loader) * start_epoch, 349 | writer=writer_val, 350 | sample=False, 351 | scale_factor=scale_factor, 352 | ) 353 | print(f"epoch {start_epoch} val loss: {val_loss:.4f}") 354 | 355 | for epoch in range(start_epoch, n_epochs): 356 | train_epoch_ldm( 357 | model=model, 358 | stage1=stage1, 359 | scheduler=scheduler, 360 | text_encoder=text_encoder, 361 | loader=train_loader, 362 | optimizer=optimizer, 363 | device=device, 364 | epoch=epoch, 365 | writer=writer_train, 366 | scaler=scaler, 367 | scale_factor=scale_factor, 368 | ) 369 | 370 | if (epoch + 1) % eval_freq == 0: 371 | val_loss = eval_ldm( 372 | model=model, 373 | stage1=stage1, 374 | scheduler=scheduler, 375 | text_encoder=text_encoder, 376 | loader=val_loader, 377 | device=device, 378 | step=len(train_loader) * epoch, 379 | writer=writer_val, 380 | sample=True if (epoch + 1) % (eval_freq * 2) == 0 else False, 381 | scale_factor=scale_factor, 382 | ) 383 | 384 | print(f"epoch {epoch + 1} val loss: {val_loss:.4f}") 385 | print_gpu_memory_report() 386 | 387 | # Save checkpoint 388 | checkpoint = { 389 | "epoch": epoch + 1, 390 | "diffusion": model.state_dict(), 391 | "optimizer": optimizer.state_dict(), 392 | "best_loss": best_loss, 393 | } 394 | torch.save(checkpoint, str(run_dir / "checkpoint.pth")) 395 | 396 | if val_loss <= best_loss: 397 | print(f"New best val loss {val_loss}") 398 | best_loss = val_loss 399 | torch.save(raw_model.state_dict(), str(run_dir / "best_model.pth")) 400 | 401 | print(f"Training finished!") 402 | print(f"Saving final model...") 403 | torch.save(raw_model.state_dict(), str(run_dir / "final_model.pth")) 404 | 405 | return val_loss 406 | 407 | 408 | def train_epoch_ldm( 409 | model: nn.Module, 410 | stage1: nn.Module, 411 | scheduler: nn.Module, 412 | text_encoder, 413 | loader: torch.utils.data.DataLoader, 414 | optimizer: torch.optim.Optimizer, 415 | device: torch.device, 416 | epoch: int, 417 | writer: SummaryWriter, 418 | scaler: GradScaler, 419 | scale_factor: float = 1.0, 420 | ) -> None: 421 | model.train() 422 | 423 | pbar = tqdm(enumerate(loader), total=len(loader)) 424 | for step, x in pbar: 425 | images = x["t1w"].to(device) 426 | reports = x["report"].to(device) 427 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],), device=device).long() 428 | 429 | optimizer.zero_grad(set_to_none=True) 430 | with autocast(enabled=True): 431 | with torch.no_grad(): 432 | e = stage1(images) * scale_factor 433 | 434 
| prompt_embeds = text_encoder(reports.squeeze(1)) 435 | prompt_embeds = prompt_embeds[0] 436 | 437 | noise = torch.randn_like(e).to(device) 438 | noisy_e = scheduler.add_noise(original_samples=e, noise=noise, timesteps=timesteps) 439 | noise_pred = model(x=noisy_e, timesteps=timesteps, context=prompt_embeds) 440 | 441 | if scheduler.prediction_type == "v_prediction": 442 | # Use v-prediction parameterization 443 | target = scheduler.get_velocity(e, noise, timesteps) 444 | elif scheduler.prediction_type == "epsilon": 445 | target = noise 446 | loss = F.mse_loss(noise_pred.float(), target.float()) 447 | 448 | losses = OrderedDict(loss=loss) 449 | 450 | scaler.scale(losses["loss"]).backward() 451 | scaler.step(optimizer) 452 | scaler.update() 453 | 454 | writer.add_scalar("lr", get_lr(optimizer), epoch * len(loader) + step) 455 | 456 | for k, v in losses.items(): 457 | writer.add_scalar(f"{k}", v.item(), epoch * len(loader) + step) 458 | 459 | pbar.set_postfix({"epoch": epoch, "loss": f"{losses['loss'].item():.5f}", "lr": f"{get_lr(optimizer):.6f}"}) 460 | 461 | 462 | @torch.no_grad() 463 | def eval_ldm( 464 | model: nn.Module, 465 | stage1: nn.Module, 466 | scheduler: nn.Module, 467 | text_encoder, 468 | loader: torch.utils.data.DataLoader, 469 | device: torch.device, 470 | step: int, 471 | writer: SummaryWriter, 472 | sample: bool = False, 473 | scale_factor: float = 1.0, 474 | ) -> float: 475 | model.eval() 476 | raw_stage1 = stage1.module if hasattr(stage1, "module") else stage1 477 | raw_model = model.module if hasattr(model, "module") else model 478 | total_losses = OrderedDict() 479 | 480 | for x in loader: 481 | images = x["t1w"].to(device) 482 | reports = x["report"].to(device) 483 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],), device=device).long() 484 | 485 | with autocast(enabled=True): 486 | e = stage1(images) * scale_factor 487 | 488 | prompt_embeds = text_encoder(reports.squeeze(1)) 489 | prompt_embeds = prompt_embeds[0] 490 | 491 | noise = torch.randn_like(e).to(device) 492 | noisy_e = scheduler.add_noise(original_samples=e, noise=noise, timesteps=timesteps) 493 | noise_pred = model(x=noisy_e, timesteps=timesteps, context=prompt_embeds) 494 | 495 | if scheduler.prediction_type == "v_prediction": 496 | # Use v-prediction parameterization 497 | target = scheduler.get_velocity(e, noise, timesteps) 498 | elif scheduler.prediction_type == "epsilon": 499 | target = noise 500 | loss = F.mse_loss(noise_pred.float(), target.float()) 501 | 502 | loss = loss.mean() 503 | losses = OrderedDict(loss=loss) 504 | 505 | for k, v in losses.items(): 506 | total_losses[k] = total_losses.get(k, 0) + v.item() * images.shape[0] 507 | 508 | for k in total_losses.keys(): 509 | total_losses[k] /= len(loader.dataset) 510 | 511 | for k, v in total_losses.items(): 512 | writer.add_scalar(f"{k}", v, step) 513 | 514 | if sample: 515 | log_ldm_sample_unconditioned( 516 | model=raw_model, 517 | stage1=raw_stage1, 518 | scheduler=scheduler, 519 | text_encoder=text_encoder, 520 | spatial_shape=tuple(e.shape[1:]), 521 | writer=writer, 522 | step=step, 523 | device=device, 524 | scale_factor=scale_factor, 525 | ) 526 | 527 | return total_losses["loss"] 528 | 529 | 530 | # ---------------------------------------------------------------------------------------------------------------------- 531 | # Controlnet 532 | # ---------------------------------------------------------------------------------------------------------------------- 533 | def train_controlnet( 534 | controlnet: 
nn.Module, 535 | diffusion: nn.Module, 536 | stage1: nn.Module, 537 | scheduler: nn.Module, 538 | text_encoder, 539 | start_epoch: int, 540 | best_loss: float, 541 | train_loader: torch.utils.data.DataLoader, 542 | val_loader: torch.utils.data.DataLoader, 543 | optimizer: torch.optim.Optimizer, 544 | n_epochs: int, 545 | eval_freq: int, 546 | writer_train: SummaryWriter, 547 | writer_val: SummaryWriter, 548 | device: torch.device, 549 | run_dir: Path, 550 | scale_factor: float = 1.0, 551 | ) -> float: 552 | scaler = GradScaler() 553 | raw_controlnet = controlnet.module if hasattr(controlnet, "module") else controlnet 554 | 555 | val_loss = eval_controlnet( 556 | controlnet=controlnet, 557 | diffusion=diffusion, 558 | stage1=stage1, 559 | scheduler=scheduler, 560 | text_encoder=text_encoder, 561 | loader=val_loader, 562 | device=device, 563 | step=len(train_loader) * start_epoch, 564 | writer=writer_val, 565 | scale_factor=scale_factor, 566 | ) 567 | print(f"epoch {start_epoch} val loss: {val_loss:.4f}") 568 | 569 | for epoch in range(start_epoch, n_epochs): 570 | train_epoch_controlnet( 571 | controlnet=controlnet, 572 | diffusion=diffusion, 573 | stage1=stage1, 574 | scheduler=scheduler, 575 | text_encoder=text_encoder, 576 | loader=train_loader, 577 | optimizer=optimizer, 578 | device=device, 579 | epoch=epoch, 580 | writer=writer_train, 581 | scaler=scaler, 582 | scale_factor=scale_factor, 583 | ) 584 | 585 | if (epoch + 1) % eval_freq == 0: 586 | val_loss = eval_controlnet( 587 | controlnet=controlnet, 588 | diffusion=diffusion, 589 | stage1=stage1, 590 | scheduler=scheduler, 591 | text_encoder=text_encoder, 592 | loader=val_loader, 593 | device=device, 594 | step=len(train_loader) * epoch, 595 | writer=writer_val, 596 | scale_factor=scale_factor, 597 | ) 598 | 599 | print(f"epoch {epoch + 1} val loss: {val_loss:.4f}") 600 | print_gpu_memory_report() 601 | 602 | # Save checkpoint 603 | checkpoint = { 604 | "epoch": epoch + 1, 605 | "controlnet": controlnet.state_dict(), 606 | "optimizer": optimizer.state_dict(), 607 | "best_loss": best_loss, 608 | } 609 | torch.save(checkpoint, str(run_dir / "checkpoint.pth")) 610 | 611 | if val_loss <= best_loss: 612 | print(f"New best val loss {val_loss}") 613 | best_loss = val_loss 614 | torch.save(raw_controlnet.state_dict(), str(run_dir / "best_model.pth")) 615 | 616 | print(f"Training finished!") 617 | print(f"Saving final model...") 618 | torch.save(raw_controlnet.state_dict(), str(run_dir / "final_model.pth")) 619 | 620 | return val_loss 621 | 622 | 623 | def train_epoch_controlnet( 624 | controlnet: nn.Module, 625 | diffusion: nn.Module, 626 | stage1: nn.Module, 627 | scheduler: nn.Module, 628 | text_encoder, 629 | loader: torch.utils.data.DataLoader, 630 | optimizer: torch.optim.Optimizer, 631 | device: torch.device, 632 | epoch: int, 633 | writer: SummaryWriter, 634 | scaler: GradScaler, 635 | scale_factor: float = 1.0, 636 | ) -> None: 637 | controlnet.train() 638 | 639 | pbar = tqdm(enumerate(loader), total=len(loader)) 640 | for step, x in pbar: 641 | images = x["t1w"].to(device) 642 | reports = x["report"].to(device) 643 | cond = x["flair"].to(device) 644 | 645 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],), device=device).long() 646 | 647 | optimizer.zero_grad(set_to_none=True) 648 | with autocast(enabled=True): 649 | with torch.no_grad(): 650 | e = stage1(images) * scale_factor 651 | 652 | prompt_embeds = text_encoder(reports.squeeze(1)) 653 | prompt_embeds = prompt_embeds[0] 654 | 655 | noise = 
torch.randn_like(e).to(device) 656 | noisy_e = scheduler.add_noise(original_samples=e, noise=noise, timesteps=timesteps) 657 | 658 | down_block_res_samples, mid_block_res_sample = controlnet( 659 | x=noisy_e, timesteps=timesteps, context=prompt_embeds, controlnet_cond=cond 660 | ) 661 | 662 | noise_pred = diffusion( 663 | x=noisy_e, 664 | timesteps=timesteps, 665 | context=prompt_embeds, 666 | down_block_additional_residuals=down_block_res_samples, 667 | mid_block_additional_residual=mid_block_res_sample, 668 | ) 669 | 670 | if scheduler.prediction_type == "v_prediction": 671 | # Use v-prediction parameterization 672 | target = scheduler.get_velocity(e, noise, timesteps) 673 | elif scheduler.prediction_type == "epsilon": 674 | target = noise 675 | loss = F.mse_loss(noise_pred.float(), target.float()) 676 | 677 | losses = OrderedDict(loss=loss) 678 | 679 | scaler.scale(losses["loss"]).backward() 680 | scaler.step(optimizer) 681 | scaler.update() 682 | 683 | writer.add_scalar("lr", get_lr(optimizer), epoch * len(loader) + step) 684 | 685 | for k, v in losses.items(): 686 | writer.add_scalar(f"{k}", v.item(), epoch * len(loader) + step) 687 | 688 | pbar.set_postfix({"epoch": epoch, "loss": f"{losses['loss'].item():.5f}", "lr": f"{get_lr(optimizer):.6f}"}) 689 | 690 | 691 | @torch.no_grad() 692 | def eval_controlnet( 693 | controlnet: nn.Module, 694 | diffusion: nn.Module, 695 | stage1: nn.Module, 696 | scheduler: nn.Module, 697 | text_encoder, 698 | loader: torch.utils.data.DataLoader, 699 | device: torch.device, 700 | step: int, 701 | writer: SummaryWriter, 702 | scale_factor: float = 1.0, 703 | ) -> float: 704 | controlnet.eval() 705 | total_losses = OrderedDict() 706 | 707 | for x in loader: 708 | images = x["t1w"].to(device) 709 | reports = x["report"].to(device) 710 | cond = x["flair"].to(device) 711 | 712 | timesteps = torch.randint(0, scheduler.num_train_timesteps, (images.shape[0],), device=device).long() 713 | 714 | with autocast(enabled=True): 715 | e = stage1(images) * scale_factor 716 | 717 | prompt_embeds = text_encoder(reports.squeeze(1)) 718 | prompt_embeds = prompt_embeds[0] 719 | 720 | noise = torch.randn_like(e).to(device) 721 | noisy_e = scheduler.add_noise(original_samples=e, noise=noise, timesteps=timesteps) 722 | 723 | down_block_res_samples, mid_block_res_sample = controlnet( 724 | x=noisy_e, timesteps=timesteps, context=prompt_embeds, controlnet_cond=cond 725 | ) 726 | 727 | noise_pred = diffusion( 728 | x=noisy_e, 729 | timesteps=timesteps, 730 | context=prompt_embeds, 731 | down_block_additional_residuals=down_block_res_samples, 732 | mid_block_additional_residual=mid_block_res_sample, 733 | ) 734 | 735 | if scheduler.prediction_type == "v_prediction": 736 | # Use v-prediction parameterization 737 | target = scheduler.get_velocity(e, noise, timesteps) 738 | elif scheduler.prediction_type == "epsilon": 739 | target = noise 740 | loss = F.mse_loss(noise_pred.float(), target.float()) 741 | 742 | loss = loss.mean() 743 | losses = OrderedDict(loss=loss) 744 | 745 | for k, v in losses.items(): 746 | total_losses[k] = total_losses.get(k, 0) + v.item() * images.shape[0] 747 | 748 | for k in total_losses.keys(): 749 | total_losses[k] /= len(loader.dataset) 750 | 751 | for k, v in total_losses.items(): 752 | writer.add_scalar(f"{k}", v, step) 753 | 754 | return total_losses["loss"] 755 | -------------------------------------------------------------------------------- /src/python/training/util.py: -------------------------------------------------------------------------------- 
1 | """Utility functions for training.""" 2 | from pathlib import Path 3 | from typing import Tuple, Union 4 | 5 | import matplotlib.pyplot as plt 6 | import mlflow.pytorch 7 | import numpy as np 8 | import pandas as pd 9 | import torch 10 | import torch.nn as nn 11 | from custom_transforms import ApplyTokenizerd 12 | from mlflow import start_run 13 | from monai import transforms 14 | from monai.data import PersistentDataset 15 | from omegaconf import OmegaConf 16 | from omegaconf.dictconfig import DictConfig 17 | from tensorboardX import SummaryWriter 18 | from torch.utils.data import DataLoader 19 | from tqdm import tqdm 20 | 21 | 22 | # ---------------------------------------------------------------------------------------------------------------------- 23 | # DATA LOADING 24 | # ---------------------------------------------------------------------------------------------------------------------- 25 | def get_datalist( 26 | ids_path: str, 27 | extended_report: bool = False, 28 | ): 29 | """Get data dicts for data loaders.""" 30 | df = pd.read_csv(ids_path, sep="\t") 31 | 32 | data_dicts = [] 33 | for index, row in df.iterrows(): 34 | data_dicts.append( 35 | { 36 | "t1w": f"{row['t1w']}", 37 | "flair": f"{row['flair']}", 38 | "report": "T1-weighted image of a brain.", 39 | } 40 | ) 41 | 42 | print(f"Found {len(data_dicts)} subjects.") 43 | return data_dicts 44 | 45 | 46 | def get_dataloader( 47 | cache_dir: Union[str, Path], 48 | batch_size: int, 49 | training_ids: str, 50 | validation_ids: str, 51 | num_workers: int = 8, 52 | model_type: str = "autoencoder", 53 | ): 54 | # Define transformations 55 | val_transforms = transforms.Compose( 56 | [ 57 | transforms.LoadImaged(keys=["t1w"]), 58 | transforms.EnsureChannelFirstd(keys=["t1w"]), 59 | transforms.Rotate90d(keys=["t1w"], k=-1, spatial_axes=(0, 1)), # Fix flipped image read 60 | transforms.Flipd(keys=["t1w"], spatial_axis=1), # Fix flipped image read 61 | transforms.ScaleIntensityRanged(keys=["t1w"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True), 62 | ApplyTokenizerd(keys=["report"]), 63 | transforms.ToTensord(keys=["t1w", "report"]), 64 | ] 65 | ) 66 | if model_type == "autoencoder": 67 | train_transforms = transforms.Compose( 68 | [ 69 | transforms.LoadImaged(keys=["t1w"]), 70 | transforms.EnsureChannelFirstd(keys=["t1w"]), 71 | transforms.Rotate90d(keys=["t1w"], k=-1, spatial_axes=(0, 1)), # Fix flipped image read 72 | transforms.Flipd(keys=["t1w"], spatial_axis=1), # Fix flipped image read 73 | transforms.ScaleIntensityRanged(keys=["t1w"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True), 74 | transforms.RandFlipd(keys=["t1w"], prob=0.5, spatial_axis=0), 75 | transforms.RandAffined( 76 | keys=["t1w"], 77 | translate_range=(-2, 2), 78 | scale_range=(-0.05, 0.05), 79 | spatial_size=[160, 224], 80 | prob=0.5, 81 | ), 82 | transforms.RandShiftIntensityd(keys=["t1w"], offsets=0.05, prob=0.1), 83 | transforms.RandAdjustContrastd(keys=["t1w"], gamma=(0.97, 1.03), prob=0.1), 84 | transforms.ThresholdIntensityd(keys=["t1w"], threshold=1, above=False, cval=1.0), 85 | transforms.ThresholdIntensityd(keys=["t1w"], threshold=0, above=True, cval=0), 86 | ] 87 | ) 88 | if model_type == "diffusion": 89 | train_transforms = transforms.Compose( 90 | [ 91 | transforms.LoadImaged(keys=["t1w"]), 92 | transforms.EnsureChannelFirstd(keys=["t1w"]), 93 | transforms.Rotate90d(keys=["t1w"], k=-1, spatial_axes=(0, 1)), # Fix flipped image read 94 | transforms.Flipd(keys=["t1w"], spatial_axis=1), # Fix flipped image read 95 | 
transforms.ScaleIntensityRanged(keys=["t1w"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True), 96 | transforms.RandFlipd(keys=["t1w"], prob=0.5, spatial_axis=0), 97 | transforms.RandAffined( 98 | keys=["t1w"], 99 | translate_range=(-2, 2), 100 | scale_range=(-0.01, 0.01), 101 | spatial_size=[160, 224], 102 | prob=0.25, 103 | ), 104 | transforms.RandShiftIntensityd(keys=["t1w"], offsets=0.05, prob=0.1), 105 | transforms.RandAdjustContrastd(keys=["t1w"], gamma=(0.97, 1.03), prob=0.1), 106 | transforms.ThresholdIntensityd(keys=["t1w"], threshold=1, above=False, cval=1.0), 107 | transforms.ThresholdIntensityd(keys=["t1w"], threshold=0, above=True, cval=0), 108 | ApplyTokenizerd(keys=["report"]), 109 | transforms.RandLambdad( 110 | keys=["report"], 111 | prob=0.10, 112 | func=lambda x: torch.cat( 113 | (49406 * torch.ones(1, 1), 49407 * torch.ones(1, x.shape[1] - 1)), 1 114 | ).long(), 115 | ), # 49406: BOS token 49407: PAD token 116 | ] 117 | ) 118 | if model_type == "controlnet": 119 | val_transforms = transforms.Compose( 120 | [ 121 | transforms.LoadImaged(keys=["t1w", "flair"]), 122 | transforms.EnsureChannelFirstd(keys=["t1w", "flair"]), 123 | transforms.Rotate90d(keys=["t1w", "flair"], k=-1, spatial_axes=(0, 1)), # Fix flipped image read 124 | transforms.Flipd(keys=["t1w", "flair"], spatial_axis=1), # Fix flipped image read 125 | transforms.ScaleIntensityRanged( 126 | keys=["t1w", "flair"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True 127 | ), 128 | ApplyTokenizerd(keys=["report"]), 129 | transforms.ToTensord(keys=["t1w", "flair", "report"]), 130 | ] 131 | ) 132 | train_transforms = transforms.Compose( 133 | [ 134 | transforms.LoadImaged(keys=["t1w", "flair"]), 135 | transforms.EnsureChannelFirstd(keys=["t1w", "flair"]), 136 | transforms.Rotate90d(keys=["t1w", "flair"], k=-1, spatial_axes=(0, 1)), # Fix flipped image read 137 | transforms.Flipd(keys=["t1w", "flair"], spatial_axis=1), # Fix flipped image read 138 | transforms.ScaleIntensityRanged( 139 | keys=["t1w", "flair"], a_min=0.0, a_max=255.0, b_min=0.0, b_max=1.0, clip=True 140 | ), 141 | transforms.RandFlipd(keys=["t1w", "flair"], prob=0.5, spatial_axis=0), 142 | transforms.RandAffined( 143 | keys=["t1w", "flair"], 144 | translate_range=(-2, 2), 145 | scale_range=(-0.01, 0.01), 146 | spatial_size=[160, 224], 147 | prob=0.25, 148 | ), 149 | transforms.RandShiftIntensityd(keys=["t1w", "flair"], offsets=0.05, prob=0.1), 150 | transforms.RandAdjustContrastd(keys=["t1w", "flair"], gamma=(0.97, 1.03), prob=0.1), 151 | transforms.ThresholdIntensityd(keys=["t1w", "flair"], threshold=1, above=False, cval=1.0), 152 | transforms.ThresholdIntensityd(keys=["t1w", "flair"], threshold=0, above=True, cval=0), 153 | ApplyTokenizerd(keys=["report"]), 154 | transforms.RandLambdad( 155 | keys=["report"], 156 | prob=0.10, 157 | func=lambda x: torch.cat( 158 | (49406 * torch.ones(1, 1), 49407 * torch.ones(1, x.shape[1] - 1)), 1 159 | ).long(), 160 | ), # 49406: BOS token 49407: PAD token 161 | ] 162 | ) 163 | 164 | train_dicts = get_datalist(ids_path=training_ids) 165 | train_ds = PersistentDataset(data=train_dicts, transform=train_transforms, cache_dir=str(cache_dir)) 166 | train_loader = DataLoader( 167 | train_ds, 168 | batch_size=batch_size, 169 | shuffle=True, 170 | num_workers=num_workers, 171 | drop_last=False, 172 | pin_memory=False, 173 | persistent_workers=True, 174 | ) 175 | 176 | val_dicts = get_datalist(ids_path=validation_ids) 177 | val_ds = PersistentDataset(data=val_dicts, transform=val_transforms, 
cache_dir=str(cache_dir)) 178 | val_loader = DataLoader( 179 | val_ds, 180 | batch_size=batch_size, 181 | num_workers=num_workers, 182 | drop_last=False, 183 | pin_memory=False, 184 | persistent_workers=True, 185 | ) 186 | 187 | return train_loader, val_loader 188 | 189 | 190 | # ---------------------------------------------------------------------------------------------------------------------- 191 | # LOGS 192 | # ---------------------------------------------------------------------------------------------------------------------- 193 | def recursive_items(dictionary, prefix=""): 194 | for key, value in dictionary.items(): 195 | if type(value) in [dict, DictConfig]: 196 | yield from recursive_items(value, prefix=str(key) if prefix == "" else f"{prefix}.{str(key)}") 197 | else: 198 | yield (str(key) if prefix == "" else f"{prefix}.{str(key)}", value) 199 | 200 | 201 | def log_mlflow( 202 | model, 203 | config, 204 | args, 205 | experiment: str, 206 | run_dir: Path, 207 | val_loss: float, 208 | ): 209 | """Log model and performance on Mlflow system""" 210 | config = {**OmegaConf.to_container(config), **vars(args)} 211 | print(f"Setting mlflow experiment: {experiment}") 212 | mlflow.set_experiment(experiment) 213 | 214 | with start_run(): 215 | print(f"MLFLOW URI: {mlflow.tracking.get_tracking_uri()}") 216 | print(f"MLFLOW ARTIFACT URI: {mlflow.get_artifact_uri()}") 217 | 218 | for key, value in recursive_items(config): 219 | mlflow.log_param(key, str(value)) 220 | 221 | mlflow.log_artifacts(str(run_dir / "train"), artifact_path="events_train") 222 | mlflow.log_artifacts(str(run_dir / "val"), artifact_path="events_val") 223 | mlflow.log_metric(f"loss", val_loss, 0) 224 | 225 | raw_model = model.module if hasattr(model, "module") else model 226 | mlflow.pytorch.log_model(raw_model, "final_model") 227 | 228 | 229 | def get_figure( 230 | img: torch.Tensor, 231 | recons: torch.Tensor, 232 | ): 233 | img_npy_0 = np.clip(a=img[0, 0, :, :].cpu().numpy(), a_min=0, a_max=1) 234 | recons_npy_0 = np.clip(a=recons[0, 0, :, :].cpu().numpy(), a_min=0, a_max=1) 235 | img_npy_1 = np.clip(a=img[1, 0, :, :].cpu().numpy(), a_min=0, a_max=1) 236 | recons_npy_1 = np.clip(a=recons[1, 0, :, :].cpu().numpy(), a_min=0, a_max=1) 237 | 238 | img_row_0 = np.concatenate( 239 | ( 240 | img_npy_0, 241 | recons_npy_0, 242 | img_npy_1, 243 | recons_npy_1, 244 | ), 245 | axis=1, 246 | ) 247 | 248 | fig = plt.figure(dpi=300) 249 | plt.imshow(img_row_0, cmap="gray") 250 | plt.axis("off") 251 | return fig 252 | 253 | 254 | def log_reconstructions( 255 | image: torch.Tensor, 256 | reconstruction: torch.Tensor, 257 | writer: SummaryWriter, 258 | step: int, 259 | title: str = "RECONSTRUCTION", 260 | ) -> None: 261 | fig = get_figure( 262 | image, 263 | reconstruction, 264 | ) 265 | writer.add_figure(title, fig, step) 266 | 267 | 268 | @torch.no_grad() 269 | def log_ldm_sample_unconditioned( 270 | model: nn.Module, 271 | stage1: nn.Module, 272 | text_encoder, 273 | scheduler: nn.Module, 274 | spatial_shape: Tuple, 275 | writer: SummaryWriter, 276 | step: int, 277 | device: torch.device, 278 | scale_factor: float = 1.0, 279 | ) -> None: 280 | latent = torch.randn((1,) + spatial_shape) 281 | latent = latent.to(device) 282 | 283 | prompt_embeds = torch.cat((49406 * torch.ones(1, 1), 49407 * torch.ones(1, 76)), 1).long() 284 | prompt_embeds = text_encoder(prompt_embeds.squeeze(1)) 285 | prompt_embeds = prompt_embeds[0] 286 | 287 | for t in tqdm(scheduler.timesteps, ncols=70): 288 | noise_pred = model(x=latent, 
timesteps=torch.asarray((t,)).to(device), context=prompt_embeds) 289 | latent, _ = scheduler.step(noise_pred, t, latent) 290 | 291 | x_hat = stage1.model.decode(latent / scale_factor) 292 | img_0 = np.clip(a=x_hat[0, 0, :, :].cpu().numpy(), a_min=0, a_max=1) 293 | fig = plt.figure(dpi=300) 294 | plt.imshow(img_0, cmap="gray") 295 | plt.axis("off") 296 | writer.add_figure("SAMPLE", fig, step) 297 | --------------------------------------------------------------------------------
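Note: the data-loading helpers above (get_datalist, get_dataloader) are imported by the training entry points under src/python/training (train_aekl.py, train_ldm.py, train_controlnet.py). The snippet below is a minimal sketch of how they might be wired up; the import path, file paths, batch size, and cache directory are placeholder assumptions for illustration and are not taken from the repository configs.

from pathlib import Path

from util import get_dataloader  # assumes the script runs from src/python/training, as the trainers do

# Hypothetical paths and values, for illustration only.
train_loader, val_loader = get_dataloader(
    cache_dir=Path("/tmp/persistent_cache"),      # cache directory used by MONAI PersistentDataset
    batch_size=8,                                 # placeholder batch size
    training_ids="outputs/ids/train.tsv",         # tab-separated file with "t1w" and "flair" columns
    validation_ids="outputs/ids/validation.tsv",  # same format as the training list
    num_workers=8,
    model_type="diffusion",                       # one of "autoencoder", "diffusion", "controlnet"
)

# Each batch is a dict keyed by the transform keys defined above.
for batch in train_loader:
    images = batch["t1w"]      # T1w slices, intensity-scaled to [0, 1]
    reports = batch["report"]  # tokenised prompt produced by ApplyTokenizerd
    break

As defined in get_datalist, the ID files are read with sep="\t" and must contain "t1w" and "flair" path columns; the "report" field is currently the fixed string "T1-weighted image of a brain." for every subject.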