├── CITATION.md ├── LICENSE ├── README.md ├── eval.py ├── figs ├── teaser.png └── teaser_data.png ├── main.py ├── requirements.txt ├── training.py └── utils.py /CITATION.md: -------------------------------------------------------------------------------- 1 | ```bib 2 | @inproceedings{konz2022intrinsic, 3 | title={Anatomically-Controllable Medical Image Generation with Segmentation-Guided Diffusion Models}, 4 | author={Nicholas Konz and Yuwen Chen and Haoyu Dong and Maciej A. Mazurowski}, 5 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 6 | year={2024} 7 | } 8 | ``` 9 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Attribution-NonCommercial 4.0 International 2 | 3 | Creative Commons Corporation ("Creative Commons") is not a law firm and 4 | does not provide legal services or legal advice. Distribution of 5 | Creative Commons public licenses does not create a lawyer-client or 6 | other relationship. Creative Commons makes its licenses and related 7 | information available on an "as-is" basis. Creative Commons gives no 8 | warranties regarding its licenses, any material licensed under their 9 | terms and conditions, or any related information. Creative Commons 10 | disclaims all liability for damages resulting from their use to the 11 | fullest extent possible. 12 | 13 | Using Creative Commons Public Licenses 14 | 15 | Creative Commons public licenses provide a standard set of terms and 16 | conditions that creators and other rights holders may use to share 17 | original works of authorship and other material subject to copyright and 18 | certain other rights specified in the public license below. The 19 | following considerations are for informational purposes only, are not 20 | exhaustive, and do not form part of our licenses. 21 | 22 | - Considerations for licensors: Our public licenses are intended for 23 | use by those authorized to give the public permission to use 24 | material in ways otherwise restricted by copyright and certain other 25 | rights. Our licenses are irrevocable. Licensors should read and 26 | understand the terms and conditions of the license they choose 27 | before applying it. Licensors should also secure all rights 28 | necessary before applying our licenses so that the public can reuse 29 | the material as expected. Licensors should clearly mark any material 30 | not subject to the license. This includes other CC-licensed 31 | material, or material used under an exception or limitation to 32 | copyright. More considerations for licensors : 33 | wiki.creativecommons.org/Considerations_for_licensors 34 | 35 | - Considerations for the public: By using one of our public licenses, 36 | a licensor grants the public permission to use the licensed material 37 | under specified terms and conditions. If the licensor's permission 38 | is not necessary for any reason–for example, because of any 39 | applicable exception or limitation to copyright–then that use is not 40 | regulated by the license. Our licenses grant only permissions under 41 | copyright and certain other rights that a licensor has authority to 42 | grant. Use of the licensed material may still be restricted for 43 | other reasons, including because others have copyright or other 44 | rights in the material. A licensor may make special requests, such 45 | as asking that all changes be marked or described. 
Although not 46 | required by our licenses, you are encouraged to respect those 47 | requests where reasonable. More considerations for the public : 48 | wiki.creativecommons.org/Considerations_for_licensees 49 | 50 | Creative Commons Attribution-NonCommercial 4.0 International Public 51 | License 52 | 53 | By exercising the Licensed Rights (defined below), You accept and agree 54 | to be bound by the terms and conditions of this Creative Commons 55 | Attribution-NonCommercial 4.0 International Public License ("Public 56 | License"). To the extent this Public License may be interpreted as a 57 | contract, You are granted the Licensed Rights in consideration of Your 58 | acceptance of these terms and conditions, and the Licensor grants You 59 | such rights in consideration of benefits the Licensor receives from 60 | making the Licensed Material available under these terms and conditions. 61 | 62 | - Section 1 – Definitions. 63 | 64 | - a. Adapted Material means material subject to Copyright and 65 | Similar Rights that is derived from or based upon the Licensed 66 | Material and in which the Licensed Material is translated, 67 | altered, arranged, transformed, or otherwise modified in a 68 | manner requiring permission under the Copyright and Similar 69 | Rights held by the Licensor. For purposes of this Public 70 | License, where the Licensed Material is a musical work, 71 | performance, or sound recording, Adapted Material is always 72 | produced where the Licensed Material is synched in timed 73 | relation with a moving image. 74 | - b. Adapter's License means the license You apply to Your 75 | Copyright and Similar Rights in Your contributions to Adapted 76 | Material in accordance with the terms and conditions of this 77 | Public License. 78 | - c. Copyright and Similar Rights means copyright and/or similar 79 | rights closely related to copyright including, without 80 | limitation, performance, broadcast, sound recording, and Sui 81 | Generis Database Rights, without regard to how the rights are 82 | labeled or categorized. For purposes of this Public License, the 83 | rights specified in Section 2(b)(1)-(2) are not Copyright and 84 | Similar Rights. 85 | - d. Effective Technological Measures means those measures that, 86 | in the absence of proper authority, may not be circumvented 87 | under laws fulfilling obligations under Article 11 of the WIPO 88 | Copyright Treaty adopted on December 20, 1996, and/or similar 89 | international agreements. 90 | - e. Exceptions and Limitations means fair use, fair dealing, 91 | and/or any other exception or limitation to Copyright and 92 | Similar Rights that applies to Your use of the Licensed 93 | Material. 94 | - f. Licensed Material means the artistic or literary work, 95 | database, or other material to which the Licensor applied this 96 | Public License. 97 | - g. Licensed Rights means the rights granted to You subject to 98 | the terms and conditions of this Public License, which are 99 | limited to all Copyright and Similar Rights that apply to Your 100 | use of the Licensed Material and that the Licensor has authority 101 | to license. 102 | - h. Licensor means the individual(s) or entity(ies) granting 103 | rights under this Public License. 104 | - i. NonCommercial means not primarily intended for or directed 105 | towards commercial advantage or monetary compensation. 
For 106 | purposes of this Public License, the exchange of the Licensed 107 | Material for other material subject to Copyright and Similar 108 | Rights by digital file-sharing or similar means is NonCommercial 109 | provided there is no payment of monetary compensation in 110 | connection with the exchange. 111 | - j. Share means to provide material to the public by any means or 112 | process that requires permission under the Licensed Rights, such 113 | as reproduction, public display, public performance, 114 | distribution, dissemination, communication, or importation, and 115 | to make material available to the public including in ways that 116 | members of the public may access the material from a place and 117 | at a time individually chosen by them. 118 | - k. Sui Generis Database Rights means rights other than copyright 119 | resulting from Directive 96/9/EC of the European Parliament and 120 | of the Council of 11 March 1996 on the legal protection of 121 | databases, as amended and/or succeeded, as well as other 122 | essentially equivalent rights anywhere in the world. 123 | - l. You means the individual or entity exercising the Licensed 124 | Rights under this Public License. Your has a corresponding 125 | meaning. 126 | 127 | - Section 2 – Scope. 128 | 129 | - a. License grant. 130 | - 1. Subject to the terms and conditions of this Public 131 | License, the Licensor hereby grants You a worldwide, 132 | royalty-free, non-sublicensable, non-exclusive, irrevocable 133 | license to exercise the Licensed Rights in the Licensed 134 | Material to: 135 | - A. reproduce and Share the Licensed Material, in whole 136 | or in part, for NonCommercial purposes only; and 137 | - B. produce, reproduce, and Share Adapted Material for 138 | NonCommercial purposes only. 139 | - 2. Exceptions and Limitations. For the avoidance of doubt, 140 | where Exceptions and Limitations apply to Your use, this 141 | Public License does not apply, and You do not need to comply 142 | with its terms and conditions. 143 | - 3. Term. The term of this Public License is specified in 144 | Section 6(a). 145 | - 4. Media and formats; technical modifications allowed. The 146 | Licensor authorizes You to exercise the Licensed Rights in 147 | all media and formats whether now known or hereafter 148 | created, and to make technical modifications necessary to do 149 | so. The Licensor waives and/or agrees not to assert any 150 | right or authority to forbid You from making technical 151 | modifications necessary to exercise the Licensed Rights, 152 | including technical modifications necessary to circumvent 153 | Effective Technological Measures. For purposes of this 154 | Public License, simply making modifications authorized by 155 | this Section 2(a)(4) never produces Adapted Material. 156 | - 5. Downstream recipients. 157 | - A. Offer from the Licensor – Licensed Material. Every 158 | recipient of the Licensed Material automatically 159 | receives an offer from the Licensor to exercise the 160 | Licensed Rights under the terms and conditions of this 161 | Public License. 162 | - B. No downstream restrictions. You may not offer or 163 | impose any additional or different terms or conditions 164 | on, or apply any Effective Technological Measures to, 165 | the Licensed Material if doing so restricts exercise of 166 | the Licensed Rights by any recipient of the Licensed 167 | Material. 168 | - 6. No endorsement. 
Nothing in this Public License 169 | constitutes or may be construed as permission to assert or 170 | imply that You are, or that Your use of the Licensed 171 | Material is, connected with, or sponsored, endorsed, or 172 | granted official status by, the Licensor or others 173 | designated to receive attribution as provided in Section 174 | 3(a)(1)(A)(i). 175 | - b. Other rights. 176 | - 1. Moral rights, such as the right of integrity, are not 177 | licensed under this Public License, nor are publicity, 178 | privacy, and/or other similar personality rights; however, 179 | to the extent possible, the Licensor waives and/or agrees 180 | not to assert any such rights held by the Licensor to the 181 | limited extent necessary to allow You to exercise the 182 | Licensed Rights, but not otherwise. 183 | - 2. Patent and trademark rights are not licensed under this 184 | Public License. 185 | - 3. To the extent possible, the Licensor waives any right to 186 | collect royalties from You for the exercise of the Licensed 187 | Rights, whether directly or through a collecting society 188 | under any voluntary or waivable statutory or compulsory 189 | licensing scheme. In all other cases the Licensor expressly 190 | reserves any right to collect such royalties, including when 191 | the Licensed Material is used other than for NonCommercial 192 | purposes. 193 | 194 | - Section 3 – License Conditions. 195 | 196 | Your exercise of the Licensed Rights is expressly made subject to 197 | the following conditions. 198 | 199 | - a. Attribution. 200 | - 1. If You Share the Licensed Material (including in modified 201 | form), You must: 202 | - A. retain the following if it is supplied by the 203 | Licensor with the Licensed Material: 204 | - i. identification of the creator(s) of the Licensed 205 | Material and any others designated to receive 206 | attribution, in any reasonable manner requested by 207 | the Licensor (including by pseudonym if designated); 208 | - ii. a copyright notice; 209 | - iii. a notice that refers to this Public License; 210 | - iv. a notice that refers to the disclaimer of 211 | warranties; 212 | - v. a URI or hyperlink to the Licensed Material to 213 | the extent reasonably practicable; 214 | - B. indicate if You modified the Licensed Material and 215 | retain an indication of any previous modifications; and 216 | - C. indicate the Licensed Material is licensed under this 217 | Public License, and include the text of, or the URI or 218 | hyperlink to, this Public License. 219 | - 2. You may satisfy the conditions in Section 3(a)(1) in any 220 | reasonable manner based on the medium, means, and context in 221 | which You Share the Licensed Material. For example, it may 222 | be reasonable to satisfy the conditions by providing a URI 223 | or hyperlink to a resource that includes the required 224 | information. 225 | - 3. If requested by the Licensor, You must remove any of the 226 | information required by Section 3(a)(1)(A) to the extent 227 | reasonably practicable. 228 | - 4. If You Share Adapted Material You produce, the Adapter's 229 | License You apply must not prevent recipients of the Adapted 230 | Material from complying with this Public License. 231 | 232 | - Section 4 – Sui Generis Database Rights. 233 | 234 | Where the Licensed Rights include Sui Generis Database Rights that 235 | apply to Your use of the Licensed Material: 236 | 237 | - a. 
for the avoidance of doubt, Section 2(a)(1) grants You the 238 | right to extract, reuse, reproduce, and Share all or a 239 | substantial portion of the contents of the database for 240 | NonCommercial purposes only; 241 | - b. if You include all or a substantial portion of the database 242 | contents in a database in which You have Sui Generis Database 243 | Rights, then the database in which You have Sui Generis Database 244 | Rights (but not its individual contents) is Adapted Material; 245 | and 246 | - c. You must comply with the conditions in Section 3(a) if You 247 | Share all or a substantial portion of the contents of the 248 | database. 249 | 250 | For the avoidance of doubt, this Section 4 supplements and does not 251 | replace Your obligations under this Public License where the 252 | Licensed Rights include other Copyright and Similar Rights. 253 | 254 | - Section 5 – Disclaimer of Warranties and Limitation of Liability. 255 | 256 | - a. Unless otherwise separately undertaken by the Licensor, to 257 | the extent possible, the Licensor offers the Licensed Material 258 | as-is and as-available, and makes no representations or 259 | warranties of any kind concerning the Licensed Material, whether 260 | express, implied, statutory, or other. This includes, without 261 | limitation, warranties of title, merchantability, fitness for a 262 | particular purpose, non-infringement, absence of latent or other 263 | defects, accuracy, or the presence or absence of errors, whether 264 | or not known or discoverable. Where disclaimers of warranties 265 | are not allowed in full or in part, this disclaimer may not 266 | apply to You. 267 | - b. To the extent possible, in no event will the Licensor be 268 | liable to You on any legal theory (including, without 269 | limitation, negligence) or otherwise for any direct, special, 270 | indirect, incidental, consequential, punitive, exemplary, or 271 | other losses, costs, expenses, or damages arising out of this 272 | Public License or use of the Licensed Material, even if the 273 | Licensor has been advised of the possibility of such losses, 274 | costs, expenses, or damages. Where a limitation of liability is 275 | not allowed in full or in part, this limitation may not apply to 276 | You. 277 | - c. The disclaimer of warranties and limitation of liability 278 | provided above shall be interpreted in a manner that, to the 279 | extent possible, most closely approximates an absolute 280 | disclaimer and waiver of all liability. 281 | 282 | - Section 6 – Term and Termination. 283 | 284 | - a. This Public License applies for the term of the Copyright and 285 | Similar Rights licensed here. However, if You fail to comply 286 | with this Public License, then Your rights under this Public 287 | License terminate automatically. 288 | - b. Where Your right to use the Licensed Material has terminated 289 | under Section 6(a), it reinstates: 290 | 291 | - 1. automatically as of the date the violation is cured, 292 | provided it is cured within 30 days of Your discovery of the 293 | violation; or 294 | - 2. upon express reinstatement by the Licensor. 295 | 296 | For the avoidance of doubt, this Section 6(b) does not affect 297 | any right the Licensor may have to seek remedies for Your 298 | violations of this Public License. 299 | 300 | - c. 
For the avoidance of doubt, the Licensor may also offer the 301 | Licensed Material under separate terms or conditions or stop 302 | distributing the Licensed Material at any time; however, doing 303 | so will not terminate this Public License. 304 | - d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 305 | License. 306 | 307 | - Section 7 – Other Terms and Conditions. 308 | 309 | - a. The Licensor shall not be bound by any additional or 310 | different terms or conditions communicated by You unless 311 | expressly agreed. 312 | - b. Any arrangements, understandings, or agreements regarding the 313 | Licensed Material not stated herein are separate from and 314 | independent of the terms and conditions of this Public License. 315 | 316 | - Section 8 – Interpretation. 317 | 318 | - a. For the avoidance of doubt, this Public License does not, and 319 | shall not be interpreted to, reduce, limit, restrict, or impose 320 | conditions on any use of the Licensed Material that could 321 | lawfully be made without permission under this Public License. 322 | - b. To the extent possible, if any provision of this Public 323 | License is deemed unenforceable, it shall be automatically 324 | reformed to the minimum extent necessary to make it enforceable. 325 | If the provision cannot be reformed, it shall be severed from 326 | this Public License without affecting the enforceability of the 327 | remaining terms and conditions. 328 | - c. No term or condition of this Public License will be waived 329 | and no failure to comply consented to unless expressly agreed to 330 | by the Licensor. 331 | - d. Nothing in this Public License constitutes or may be 332 | interpreted as a limitation upon, or waiver of, any privileges 333 | and immunities that apply to the Licensor or You, including from 334 | the legal processes of any jurisdiction or authority. 335 | 336 | Creative Commons is not a party to its public licenses. Notwithstanding, 337 | Creative Commons may elect to apply one of its public licenses to 338 | material it publishes and in those instances will be considered the 339 | "Licensor." The text of the Creative Commons public licenses is 340 | dedicated to the public domain under the CC0 Public Domain Dedication. 341 | Except for the limited purpose of indicating that material is shared 342 | under a Creative Commons public license or as otherwise permitted by the 343 | Creative Commons policies published at creativecommons.org/policies, 344 | Creative Commons does not authorize the use of the trademark "Creative 345 | Commons" or any other trademark or logo of Creative Commons without its 346 | prior written consent including, without limitation, in connection with 347 | any unauthorized modifications to any of its public licenses or any 348 | other arrangements, understandings, or agreements concerning use of 349 | licensed material. For the avoidance of doubt, this paragraph does not 350 | form part of the public licenses. 351 | 352 | Creative Commons may be contacted at creativecommons.org. 
353 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Easy and Precise Segmentation-Guided Diffusion Models 2 | 3 | #### By [Nicholas Konz](https://nickk124.github.io/), [Yuwen Chen](https://scholar.google.com/citations?user=61s49p0AAAAJ&hl=en), [Haoyu Dong](https://scholar.google.com/citations?user=eZVEUCIAAAAJ&hl=en) and [Maciej Mazurowski](https://sites.duke.edu/mazurowski/). 4 | 5 | arXiv paper link: [![arXiv Paper](https://img.shields.io/badge/arXiv-2402.05210-orange.svg?style=flat)](https://arxiv.org/abs/2402.05210) 6 | 7 | ## NEWS: our paper was accepted to MICCAI 2024! 8 | 9 |

10 | *(figure placeholder: likely `figs/teaser.png`)* 11 | 
12 | 13 | This is the code for our paper [**Anatomically-Controllable Medical Image Generation with Segmentation-Guided Diffusion Models**](https://arxiv.org/abs/2402.05210), where we introduce a simple yet powerful training procedure for conditioning image-generating diffusion models on (possibly incomplete) multiclass segmentation masks. 14 | 15 | ### Why use our model? 16 | 17 | Our method outperforms existing segmentation-guided image generative models (like [SPADE](https://github.com/NVlabs/SPADE) and [ControlNet](https://github.com/lllyasviel/ControlNet)) in terms of the faithfulness of generated images to input masks on multiple, multi-modality medical image datasets with a broad range of objects of interest, and is on par with them in anatomical realism. Our method is also simple to use and train, and its precise pixel-wise obedience to input segmentation masks comes from its always operating in the native image space (it's not a latent diffusion model), which is especially helpful when conditioning on complex and detailed anatomical structures. 18 | 19 | Additionally, our optional *mask-ablated training* algorithm allows our model to be conditioned on segmentation masks with missing classes, which is useful for medical images where segmentation masks may be incomplete or noisy. This allows not only for more flexible image generation but also, as we show in our paper, for adjustable anatomical similarity of generated images to a given real image, by taking advantage of the latent space structure of diffusion models. We also used this feature to generate a synthetic paired breast MRI dataset, [shown below](https://github.com/mazurowski-lab/segmentation-guided-diffusion?tab=readme-ov-file#synthetic-paired-breast-mri-dataset-release). 20 | 21 | **Using this code, you can:** 22 | 1. Train a segmentation-guided (or standard unconditional) diffusion model on your own dataset, with a wide range of options, including mask-ablated training. 23 | 2. Generate images from these models (or from our provided pre-trained models). 24 | 25 | Please follow the steps outlined below to do so. 26 | 27 | Also, check out our accompanying [**Synthetic Paired Breast MRI Dataset Release**](https://github.com/mazurowski-lab/segmentation-guided-diffusion?tab=readme-ov-file#synthetic-paired-breast-mri-dataset-release) below! 28 | 29 | Thank you to Hugging Face's awesome [Diffusers](https://github.com/huggingface/diffusers) library for providing a helpful backbone for our code! 31 | ## Citation 32 | 33 | Please cite our paper if you use our code or reference our work: 34 | ```bib 35 | @inproceedings{konz2024segguideddiffusion, 36 | title={Anatomically-Controllable Medical Image Generation with Segmentation-Guided Diffusion Models}, 37 | author={Nicholas Konz and Yuwen Chen and Haoyu Dong and Maciej A. Mazurowski}, 38 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention}, 39 | year={2024} 40 | } 41 | ``` 42 | 43 | ## 1) Package Installation 44 | This codebase was created with Python 3.11. First, install PyTorch for your computer's CUDA version (check it by running `nvidia-smi` if you're not sure) according to the provided command at https://pytorch.org/get-started/locally/; this codebase was made with `torch==2.1.2` and `torchvision==0.16.2` on CUDA 12.2. Next, run `pip3 install -r requirements.txt` to install the required packages.
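As a quick post-install sanity check, you can confirm that the main dependencies import and that PyTorch sees your GPU. The snippet below is a minimal sketch (not part of the repository); it only assumes that `torch`, `torchvision`, and `diffusers` were installed by the steps above:

```python
# Optional sanity check: verify that the core dependencies import
# and report whether a CUDA-capable GPU is visible to PyTorch.
import torch
import torchvision
import diffusers

print("torch:", torch.__version__)
print("torchvision:", torchvision.__version__)
print("diffusers:", diffusers.__version__)

if torch.cuda.is_available():
    # The training/sampling commands below select GPUs via CUDA_VISIBLE_DEVICES.
    print("CUDA available, using:", torch.cuda.get_device_name(0))
else:
    print("CUDA not available; training on CPU will be very slow.")
```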
45 | 46 | ## 2a) (optional) Use Pre-Trained Models 47 | 48 | We provide pre-trained model checkpoints (`.safetensor` files) and config (`.json`) files from our paper for the [Duke Breast MRI](https://www.cancerimagingarchive.net/collection/duke-breast-cancer-mri/) and [CT Organ](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=61080890) datasets, [here](https://drive.google.com/drive/folders/1OaOGBLfpUFe_tmpvZGEe2Mv2gow32Y8u). These include: 49 | 50 | 1. Segmentation-Conditional Models, trained without mask ablation. 51 | 2. Segmentation-Conditional Models, trained with mask ablation. 52 | 3. Unconditional (standard) Models. 53 | 54 | Once you've downloaded the checkpoint and config file for your model of choice, please: 55 | 1. Put both files in a directory called `{NAME}/unet`, where `NAME` is the model checkpoint's filename without the `.safetensors` ending, to use it with our evaluation code. 56 | 2. Rename the checkpoint file to `diffusion_pytorch_model.safetensors` and the config file to `config.json`. 57 | 58 | Next, you can proceed to the [**Evaluation/Sampling**](https://github.com/mazurowski-lab/segmentation-guided-diffusion#3-evaluationsampling) section below to generate images from these models. 59 | 60 | ## 2b) Train Your Own Models 61 | 62 | ### Data Preparation 63 | 64 | Please put your training images in some dataset directory `DATA_FOLDER` (with any filenames), organized into train, validation and test split subdirectories. The images should be in a format that PIL can read (e.g. `.png`, `.jpg`, etc.) if they are standard 1- or 3-channel images, and for images with other channel counts use `.npy` NumPy array files. For example: 65 | 66 | ``` 67 | DATA_FOLDER 68 | ├── train 69 | │ ├── tr_1.png 70 | │ ├── tr_2.png 71 | │ └── ... 72 | ├── val 73 | │ ├── val_1.png 74 | │ ├── val_2.png 75 | │ └── ... 76 | └── test 77 | ├── ts_1.png 78 | ├── ts_2.png 79 | └── ... 80 | ``` 81 | 82 | If you are using a segmentation-guided model, please put your segmentation masks within a similar directory structure in a separate folder `MASK_FOLDER`, with a subdirectory `all` that contains the split subfolders, as shown below. **Each segmentation mask should have the same filename as its corresponding image in `DATA_FOLDER`, and should be saved with integer values starting at zero for each object class, i.e., 0, 1, 2,...**. If you don't want to train a segmentation-guided model, you can skip this step. 83 | 84 | ``` 85 | MASK_FOLDER 86 | ├── all 87 | │ ├── train 88 | │ │ ├── tr_1.png 89 | │ │ ├── tr_2.png 90 | │ │ └── ... 91 | │ ├── val 92 | │ │ ├── val_1.png 93 | │ │ ├── val_2.png 94 | │ │ └── ... 95 | │ └── test 96 | │ ├── ts_1.png 97 | │ ├── ts_2.png 98 | │ └── ... 99 | ``` 100 | 101 | ### Training 102 | 103 | The basic command for training a standard unconditional diffusion model is 104 | ```bash 105 | CUDA_VISIBLE_DEVICES={DEVICES} python3 main.py \ 106 | --mode train \ 107 | --model_type DDIM \ 108 | --img_size {IMAGE_SIZE} \ 109 | --num_img_channels {NUM_IMAGE_CHANNELS} \ 110 | --dataset {DATASET_NAME} \ 111 | --img_dir {DATA_FOLDER} \ 112 | --train_batch_size 16 \ 113 | --eval_batch_size 8 \ 114 | --num_epochs 400 115 | ``` 116 | 117 | where: 118 | - `DEVICES` is a comma-separated list of GPU device indices to use (e.g. `0,1,2,3`). 119 | - `IMAGE_SIZE` and `NUM_IMAGE_CHANNELS` respectively specify the size of the images to train on (e.g. `256`) and the number of channels (`1` for greyscale, `3` for RGB). 
120 | - `model_type` specifies the type of diffusion model sampling algorithm used to evaluate the model, and can be `DDIM` or `DDPM`. 121 | - `DATASET_NAME` is some name for your dataset (e.g. `breast_mri`). 122 | - `DATA_FOLDER` is the path to your dataset directory, as outlined in the previous section. 123 | - `--train_batch_size` and `--eval_batch_size` specify the batch sizes for training and evaluation, respectively. We use a train batch size of 16 on one 48 GB A6000 GPU for an image size of 256. 124 | - `--num_epochs` specifies the number of epochs to train for (our default is 400). 125 | 126 | #### Adding segmentation guidance, mask-ablated training, and other options 127 | 128 | To train your model with mask guidance, simply add the options: 129 | ```bash 130 | --seg_dir {MASK_FOLDER} \ 131 | --segmentation_guided \ 132 | --num_segmentation_classes {N_SEGMENTATION_CLASSES} \ 133 | ``` 134 | 135 | where: 136 | - `MASK_FOLDER` is the path to your segmentation mask directory, as outlined in the previous section. 137 | - `N_SEGMENTATION_CLASSES` is the number of classes in your segmentation masks, **including the background (0) class**. 138 | 139 | To also train your model with mask ablation (randomly removing classes from the masks to teach the model to condition on masks with missing classes; see our paper for details), simply also add the option `--use_ablated_segmentations`. 140 | 141 | ## 3) Evaluation/Sampling 142 | 143 | Sampling images from a trained model works similarly to training. For example, 100 samples from an unconditional model can be generated with the command: 144 | ```bash 145 | CUDA_VISIBLE_DEVICES={DEVICES} python3 main.py \ 146 | --mode eval_many \ 147 | --model_type DDIM \ 148 | --img_size 256 \ 149 | --num_img_channels {NUM_IMAGE_CHANNELS} \ 150 | --dataset {DATASET_NAME} \ 151 | --eval_batch_size 8 \ 152 | --eval_sample_size 100 153 | ``` 154 | 155 | Note that the code will automatically use the checkpoint from the training run, and will save the generated images to a directory called `samples` in the model's output directory. To sample from a model with segmentation guidance, simply add the options: 156 | ```bash 157 | --seg_dir {MASK_FOLDER} \ 158 | --segmentation_guided \ 159 | --num_segmentation_classes {N_SEGMENTATION_CLASSES} \ 160 | ``` 161 | This will generate images conditioned on the segmentation masks in `MASK_FOLDER/all/test`. Segmentation masks should be saved as image files (e.g., `.png`) with integer values starting at zero for each object class, i.e., 0, 1, 2,.... 162 | 163 | ## Additional Options/Config (learning rate, etc.) 164 | Our code has further options for training and evaluation; run `python3 main.py --help` for more information. Further settings can also be changed under `class TrainingConfig:` in `training.py` (some of which are exposed as command-line options for `main.py`, and some of which are not). 165 | 166 | ## Troubleshooting/Bugfixing 167 | - **Noisy generated images**: Sometimes your model may generate images that contain noise; see https://github.com/mazurowski-lab/segmentation-guided-diffusion/issues/12 for an example. Our suggested fix is to either reduce the learning rate (e.g., to `2e-5`) at https://github.com/mazurowski-lab/segmentation-guided-diffusion/blob/b1ef8b137eaaefab0210e52b4c49f34ff6067fa6/training.py#L29, or simply try training for more epochs. 
168 | - Some users have reported a [bug](https://github.com/mazurowski-lab/segmentation-guided-diffusion/issues/11) where, when the model attempts to save during training, they receive the error `module 'safetensors' has no attribute 'torch'`. This appears to be an issue with the `diffusers` library itself in some environments, and may be remedied by [this proposed solution](https://github.com/mazurowski-lab/segmentation-guided-diffusion/issues/11#issuecomment-2251890600). 169 | 170 | ## Synthetic Paired Breast MRI Dataset Release 171 | 172 | 173 | 174 | We also release synthetic 2D breast MRI slice images that are paired/"pre-registered" to real image counterparts in terms of blood vessels and fibroglandular tissue, generated by our segmentation-guided model. These were created by applying our mask-ablated-trained segmentation-guided model to the existing segmentation masks of the held-out training and test sets of our paper (30 patients total, or about 5000 2D slice images; see the paper for more information), but with the breast mask removed. Because of this, each of these synthetic images has blood vessels and fibroglandular tissue that are registered to/spatially match those of a real image counterpart, but has different surrounding breast tissue and shape. This paired data enables potential applications such as training a breast MRI registration model, self-supervised learning, etc. 175 | 176 | The data can be downloaded [here](https://drive.google.com/file/d/1yaLLdzMhAjWUEzdkTa5FjbX3d7fKvQxM/view), and is organized as follows. 177 | 178 | ### Filename convention/dataset organization 179 | The generated images are stored in `synthetic_data` in the downloadable `.zip` file above, with the filename convention `condon_Breast_MRI_{PATIENTNUMBER}_slice_{SLICENUMBER}.png`, where `PATIENTNUMBER` is the original dataset's patient number that the image comes from, and `SLICENUMBER` is the $z$-axis index of the original 3D MRI image that the 2D slice image was taken from. The corresponding real slice image, found in `real_data`, is then named `Breast_MRI_{PATIENTNUMBER}_slice_{SLICENUMBER}.png`. Finally, the corresponding segmentation masks, whose blood vessel and fibroglandular tissue masks were used to generate the images, are in `segmentations` (the breast masks are also present, but were not used to generate the images). For example: 180 | ``` 181 | synthetic_data 182 | │ ├── condon_Breast_MRI_002_slice_0.png 183 | │ ├── condon_Breast_MRI_002_slice_1.png 184 | │ └── ... 185 | real_data 186 | │ ├── Breast_MRI_002_slice_0.png 187 | │ ├── Breast_MRI_002_slice_1.png 188 | │ └── ... 189 | segmentations 190 | │ ├── Breast_MRI_002_slice_0.png 191 | │ ├── Breast_MRI_002_slice_1.png 192 | │ └── ... 193 | ``` 194 | 195 | ### Dataset Citation and License 196 | If you use this data, please cite both our paper (see **Citation** above) and the original breast MRI dataset (below), and follow the [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license. 
197 | ```bib 198 | @misc{sahadata, 199 | author={Saha, Ashirbani and Harowicz, Michael R and Grimm, Lars J and Kim, Connie E and Ghate, Sujata V and Walsh, Ruth and Mazurowski, Maciej A}, 200 | title = {Dynamic contrast-enhanced magnetic resonance images of breast cancer patients with tumor locations [Data set]}, 201 | year = {2021}, 202 | publisher = {The Cancer Imaging Archive}, 203 | howpublished = {\url{https://doi.org/10.7937/TCIA.e3sv-re93}}, 204 | } 205 | ``` 206 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """ 2 | model evaluation/sampling 3 | """ 4 | import math 5 | import os 6 | import torch 7 | from typing import List, Optional, Tuple, Union 8 | from tqdm import tqdm 9 | from copy import deepcopy 10 | import numpy as np 11 | 12 | import diffusers 13 | from diffusers import DiffusionPipeline, ImagePipelineOutput, DDIMScheduler 14 | from diffusers.utils.torch_utils import randn_tensor 15 | 16 | from utils import make_grid 17 | from torchvision.utils import save_image 18 | import matplotlib.pyplot as plt 19 | from matplotlib.colors import ListedColormap 20 | 21 | #################### 22 | # segmentation-guided DDPM 23 | #################### 24 | 25 | def evaluate_sample_many( 26 | sample_size, 27 | config, 28 | model, 29 | noise_scheduler, 30 | eval_dataloader, 31 | device='cuda' 32 | ): 33 | 34 | # for loading segs to condition on: 35 | # setup for sampling 36 | if config.model_type == "DDPM": 37 | if config.segmentation_guided: 38 | pipeline = SegGuidedDDPMPipeline( 39 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config 40 | ) 41 | else: 42 | pipeline = diffusers.DDPMPipeline(unet=model.module, scheduler=noise_scheduler) 43 | elif config.model_type == "DDIM": 44 | if config.segmentation_guided: 45 | pipeline = SegGuidedDDIMPipeline( 46 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config 47 | ) 48 | else: 49 | pipeline = diffusers.DDIMPipeline(unet=model.module, scheduler=noise_scheduler) 50 | 51 | 52 | sample_dir = test_dir = os.path.join(config.output_dir, "samples_many_{}".format(sample_size)) 53 | if not os.path.exists(sample_dir): 54 | os.makedirs(sample_dir) 55 | 56 | num_sampled = 0 57 | # keep sampling images until we have enough 58 | for bidx, seg_batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader)): 59 | if num_sampled < sample_size: 60 | if config.segmentation_guided: 61 | current_batch_size = [v for k, v in seg_batch.items() if k.startswith("seg_")][0].shape[0] 62 | else: 63 | current_batch_size = config.eval_batch_size 64 | 65 | if config.segmentation_guided: 66 | images = pipeline( 67 | batch_size = current_batch_size, 68 | seg_batch=seg_batch, 69 | ).images 70 | else: 71 | images = pipeline( 72 | batch_size = current_batch_size, 73 | ).images 74 | 75 | # save each image in the list separately 76 | for i, img in enumerate(images): 77 | if config.segmentation_guided: 78 | # name base on input mask fname 79 | img_fname = "{}/condon_{}".format(sample_dir, seg_batch["image_filenames"][i]) 80 | else: 81 | img_fname = f"{sample_dir}/{num_sampled + i:04d}.png" 82 | img.save(img_fname) 83 | 84 | num_sampled += len(images) 85 | print("sampled {}/{}.".format(num_sampled, sample_size)) 86 | 87 | 88 | 89 | def evaluate_generation( 90 | config, 91 | model, 92 | noise_scheduler, 93 | eval_dataloader, 94 | class_label_cfg=None, 95 | 
translate=False, 96 | eval_mask_removal=False, 97 | eval_blank_mask=False, 98 | device='cuda' 99 | ): 100 | """ 101 | general function to evaluate (possibly mask-guided) trained image generation model in useful ways. 102 | also has option to use CFG for class-conditioned sampling (otherwise, class-conditional models will be evaluated using naive class conditioning and sampling from both classes). 103 | 104 | can also evaluate for image translation. 105 | """ 106 | 107 | # for loading segs to condition on: 108 | eval_dataloader = iter(eval_dataloader) 109 | 110 | if config.segmentation_guided: 111 | seg_batch = next(eval_dataloader) 112 | if eval_blank_mask: 113 | # use blank masks 114 | for k, v in seg_batch.items(): 115 | if k.startswith("seg_"): 116 | seg_batch[k] = torch.zeros_like(v) 117 | 118 | # setup for sampling 119 | # After each epoch you optionally sample some demo images with evaluate() and save the model 120 | if config.model_type == "DDPM": 121 | if config.segmentation_guided: 122 | pipeline = SegGuidedDDPMPipeline( 123 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config 124 | ) 125 | else: 126 | pipeline = diffusers.DDPMPipeline(unet=model.module, scheduler=noise_scheduler) 127 | elif config.model_type == "DDIM": 128 | if config.segmentation_guided: 129 | pipeline = SegGuidedDDIMPipeline( 130 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config 131 | ) 132 | else: 133 | pipeline = diffusers.DDIMPipeline(unet=model.module, scheduler=noise_scheduler) 134 | 135 | # sample some images 136 | if config.segmentation_guided: 137 | evaluate(config, -1, pipeline, seg_batch, class_label_cfg, translate) 138 | else: 139 | if config.class_conditional: 140 | raise NotImplementedError("TODO: implement CFG and naive conditioning sampling for non-seg-guided pipelines, including for image translation") 141 | evaluate(config, -1, pipeline) 142 | 143 | # seg-guided specific visualizations 144 | if config.segmentation_guided and eval_mask_removal: 145 | plot_result_masks_multiclass = True 146 | if plot_result_masks_multiclass: 147 | pipeoutput_type = 'np' 148 | else: 149 | pipeoutput_type = 'pil' 150 | 151 | # visualize segmentation-guided sampling by seeing what happens 152 | # when segs removed 153 | num_viz = config.eval_batch_size 154 | 155 | # choose one seg to sample from; duplicate it 156 | eval_same_image = False 157 | if eval_same_image: 158 | seg_batch = {k: torch.cat(num_viz*[v[:1]]) for k, v in seg_batch.items()} 159 | 160 | result_masks = torch.Tensor() 161 | multiclass_masks = [] 162 | result_imgs = [] 163 | multiclass_masks_shape = (config.eval_batch_size, 1, config.image_size, config.image_size) 164 | 165 | # will plot segs + sampled images 166 | for seg_type in seg_batch.keys(): 167 | if seg_type.startswith("seg_"): 168 | #convert from tensor to PIL 169 | seg_batch_plt = seg_batch[seg_type].cpu() 170 | result_masks = torch.cat((result_masks, seg_batch_plt)) 171 | 172 | # sample given all segs 173 | multiclass_masks.append(convert_segbatch_to_multiclass(multiclass_masks_shape, seg_batch, config, device)) 174 | full_seg_imgs = pipeline( 175 | batch_size = num_viz, 176 | seg_batch=seg_batch, 177 | class_label_cfg=class_label_cfg, 178 | translate=translate, 179 | output_type=pipeoutput_type 180 | ).images 181 | if plot_result_masks_multiclass: 182 | result_imgs.append(full_seg_imgs) 183 | else: 184 | result_imgs += full_seg_imgs 185 | 186 | # only sample from masks with chosen 
classes removed 187 | chosen_class_combinations = None 188 | #chosen_class_combinations = [ #for example: 189 | # {"seg_all": [1, 2]} 190 | #] 191 | if chosen_class_combinations is not None: 192 | for allseg_classes in chosen_class_combinations: 193 | # remove all chosen classes 194 | seg_batch_removed = deepcopy(seg_batch) 195 | for seg_type in seg_batch_removed.keys(): 196 | # some datasets have multiple tissue segs stored in multiple masks 197 | if seg_type.startswith("seg_"): 198 | classes = allseg_classes[seg_type] 199 | for mask_val in classes: 200 | if mask_val != 0: 201 | remove_mask = (seg_batch_removed[seg_type]*255).int() == mask_val 202 | seg_batch_removed[seg_type][remove_mask] = 0 203 | 204 | seg_batch_removed_plt = torch.cat([seg_batch_removed[seg_type].cpu() for seg_type in seg_batch_removed.keys() if seg_type.startswith("seg_")]) 205 | result_masks = torch.cat((result_masks, seg_batch_removed_plt)) 206 | 207 | multiclass_masks.append(convert_segbatch_to_multiclass( 208 | multiclass_masks_shape, 209 | seg_batch_removed, config, device)) 210 | # add images conditioned on some segs but not all 211 | removed_seg_imgs = pipeline( 212 | batch_size = config.eval_batch_size, 213 | seg_batch=seg_batch_removed, 214 | class_label_cfg=class_label_cfg, 215 | translate=translate, 216 | output_type=pipeoutput_type 217 | ).images 218 | 219 | if plot_result_masks_multiclass: 220 | result_imgs.append(removed_seg_imgs) 221 | else: 222 | result_imgs += removed_seg_imgs 223 | 224 | 225 | else: 226 | for seg_type in seg_batch.keys(): 227 | # some datasets have multiple tissue segs stored in multiple masks 228 | if seg_type.startswith("seg_"): 229 | seg_batch_removed = seg_batch 230 | for mask_val in seg_batch[seg_type].unique(): 231 | if mask_val != 0: 232 | remove_mask = seg_batch[seg_type] == mask_val 233 | seg_batch_removed[seg_type][remove_mask] = 0 234 | 235 | seg_batch_removed_plt = torch.cat([seg_batch_removed[seg_type].cpu() for seg_type in seg_batch.keys() if seg_type.startswith("seg_")]) 236 | result_masks = torch.cat((result_masks, seg_batch_removed_plt)) 237 | 238 | multiclass_masks.append(convert_segbatch_to_multiclass( 239 | multiclass_masks_shape, 240 | seg_batch_removed, config, device)) 241 | # add images conditioned on some segs but not all 242 | removed_seg_imgs = pipeline( 243 | batch_size = config.eval_batch_size, 244 | seg_batch=seg_batch_removed, 245 | class_label_cfg=class_label_cfg, 246 | translate=translate, 247 | output_type=pipeoutput_type 248 | ).images 249 | 250 | if plot_result_masks_multiclass: 251 | result_imgs.append(removed_seg_imgs) 252 | else: 253 | result_imgs += removed_seg_imgs 254 | 255 | if plot_result_masks_multiclass: 256 | multiclass_masks = np.squeeze(torch.cat(multiclass_masks).cpu().numpy()) 257 | multiclass_masks = (multiclass_masks*255).astype(np.uint8) 258 | result_imgs = np.squeeze(np.concatenate(np.array(result_imgs), axis=0)) 259 | 260 | # reverse interleave 261 | plot_imgs = np.zeros_like(result_imgs) 262 | plot_imgs[0:len(plot_imgs)//2] = result_imgs[0::2] 263 | plot_imgs[len(plot_imgs)//2:] = result_imgs[1::2] 264 | 265 | plot_masks = np.zeros_like(multiclass_masks) 266 | plot_masks[0:len(plot_masks)//2] = multiclass_masks[0::2] 267 | plot_masks[len(plot_masks)//2:] = multiclass_masks[1::2] 268 | 269 | fig, axs = plt.subplots( 270 | 2, len(plot_masks), 271 | figsize=(len(plot_masks), 2), 272 | dpi=600 273 | ) 274 | 275 | for i, img in enumerate(plot_imgs): 276 | if config.dataset == 'breast_mri': 277 | colors = ['black', 'white', 'red', 
'blue'] 278 | elif config.dataset == 'ct_organ_large': 279 | colors = ['black', 'blue', 'green', 'red', 'yellow', 'magenta'] 280 | else: 281 | raise ValueError('Unknown dataset') 282 | 283 | cmap = ListedColormap(colors) 284 | axs[0,i].imshow(plot_masks[i], cmap=cmap, vmin=0, vmax=len(colors)-1) 285 | axs[0,i].axis('off') 286 | axs[1,i].imshow(img, cmap='gray') 287 | axs[1,i].axis('off') 288 | 289 | plt.subplots_adjust(wspace=0, hspace=0) 290 | plt.savefig('ablated_samples_{}.pdf'.format(config.dataset), bbox_inches='tight') 291 | plt.show() 292 | 293 | 294 | 295 | else: 296 | # Make a grid out of the images 297 | cols = num_viz 298 | rows = math.ceil(len(result_imgs) / cols) 299 | image_grid = make_grid(result_imgs, rows=rows, cols=cols) 300 | 301 | # Save the images 302 | test_dir = os.path.join(config.output_dir, "samples") 303 | os.makedirs(test_dir, exist_ok=True) 304 | image_grid.save(f"{test_dir}/mask_removal_imgs.png") 305 | 306 | save_image(result_masks, f"{test_dir}/mask_removal_masks.png", normalize=True, 307 | nrow=cols*len(seg_batch.keys()) - 2) 308 | 309 | def convert_segbatch_to_multiclass(shape, segmentations_batch, config, device): 310 | # NOTE: this generic function assumes that segs don't overlap 311 | # put all segs on same channel 312 | segs = torch.zeros(shape).to(device) 313 | for k, seg in segmentations_batch.items(): 314 | if k.startswith("seg_"): 315 | seg = seg.to(device) 316 | segs[segs == 0] = seg[segs == 0] 317 | 318 | if config.use_ablated_segmentations: 319 | # randomly remove class labels from segs with some probability 320 | segs = ablate_masks(segs, config) 321 | 322 | return segs 323 | 324 | def ablate_masks(segs, config, method="equal_weighted"): 325 | # randomly remove class label(s) from segs with some probability 326 | if method == "equal_weighted": 327 | """ 328 | # give equal probability to each possible combination of removing non-background classes 329 | # NOTE: requires that each class has a value in ({0, 1, 2, ...} / 255) 330 | # which is by default if the mask file was saved as {0, 1, 2 ,...} and then normalized by default to [0, 1] by transforms.ToTensor() 331 | # num_segmentation_classes 332 | """ 333 | class_removals = (torch.rand(config.num_segmentation_classes - 1) < 0.5).int().bool().tolist() 334 | for class_idx, remove_class in enumerate(class_removals): 335 | if remove_class: 336 | segs[(255 * segs).int() == class_idx + 1] = 0 337 | 338 | elif method == "by_class": 339 | class_ablation_prob = 0.3 340 | for seg_value in segs.unique(): 341 | if seg_value != 0: 342 | # remove seg with some probability 343 | if torch.rand(1).item() < class_ablation_prob: 344 | segs[segs == seg_value] = 0 345 | 346 | else: 347 | raise NotImplementedError 348 | return segs 349 | 350 | def add_segmentations_to_noise(noisy_images, segmentations_batch, config, device): 351 | """ 352 | concat segmentations to noisy image 353 | """ 354 | 355 | if config.segmentation_channel_mode == "single": 356 | multiclass_masks_shape = (noisy_images.shape[0], 1, noisy_images.shape[2], noisy_images.shape[3]) 357 | segs = convert_segbatch_to_multiclass(multiclass_masks_shape, segmentations_batch, config, device) 358 | # concat segs to noise 359 | noisy_images = torch.cat((noisy_images, segs), dim=1) 360 | 361 | elif config.segmentation_channel_mode == "multi": 362 | raise NotImplementedError 363 | 364 | return noisy_images 365 | 366 | #################### 367 | # general DDPM 368 | #################### 369 | def evaluate(config, epoch, pipeline, seg_batch=None, 
class_label_cfg=None, translate=False): 370 | # Either generate or translate images, 371 | # possibly mask guided and/or class conditioned. 372 | # The default pipeline output type is `List[PIL.Image]` 373 | 374 | if config.segmentation_guided: 375 | images = pipeline( 376 | batch_size = config.eval_batch_size, 377 | seg_batch=seg_batch, 378 | class_label_cfg=class_label_cfg, 379 | translate=translate 380 | ).images 381 | else: 382 | images = pipeline( 383 | batch_size = config.eval_batch_size, 384 | # TODO: implement CFG and naive conditioning sampling for non-seg-guided pipelines (also needed for translation) 385 | ).images 386 | 387 | # Make a grid out of the images 388 | cols = 4 389 | rows = math.ceil(len(images) / cols) 390 | image_grid = make_grid(images, rows=rows, cols=cols) 391 | 392 | # Save the images 393 | test_dir = os.path.join(config.output_dir, "samples") 394 | os.makedirs(test_dir, exist_ok=True) 395 | image_grid.save(f"{test_dir}/{epoch:04d}.png") 396 | 397 | # save segmentations we conditioned the samples on 398 | if config.segmentation_guided: 399 | for seg_type in seg_batch.keys(): 400 | if seg_type.startswith("seg_"): 401 | save_image(seg_batch[seg_type], f"{test_dir}/{epoch:04d}_cond_{seg_type}.png", normalize=True, nrow=cols) 402 | 403 | # as well as original images that the segs belong to 404 | img_og = seg_batch['images'] 405 | save_image(img_og, f"{test_dir}/{epoch:04d}_orig.png", normalize=True, nrow=cols) 406 | 407 | 408 | # custom diffusers pipelines for sampling from segmentation-guided models 409 | class SegGuidedDDPMPipeline(DiffusionPipeline): 410 | r""" 411 | Pipeline for segmentation-guided image generation, modified from DDPMPipeline. 412 | generates both-class conditioned and unconditional images if using class-conditional model without CFG, or just generates 413 | conditional images guided by CFG. 414 | 415 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 416 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 417 | 418 | Parameters: 419 | unet ([`UNet2DModel`]): 420 | A `UNet2DModel` to denoise the encoded image latents. 421 | scheduler ([`SchedulerMixin`]): 422 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 423 | [`DDPMScheduler`], or [`DDIMScheduler`]. 424 | eval_dataloader ([`torch.utils.data.DataLoader`]): 425 | Dataloader to load the evaluation dataset of images and their segmentations. Here only uses the segmentations to generate images. 426 | """ 427 | model_cpu_offload_seq = "unet" 428 | 429 | def __init__(self, unet, scheduler, eval_dataloader, external_config): 430 | super().__init__() 431 | self.register_modules(unet=unet, scheduler=scheduler) 432 | self.eval_dataloader = eval_dataloader 433 | self.external_config = external_config # config is already a thing 434 | 435 | @torch.no_grad() 436 | def __call__( 437 | self, 438 | batch_size: int = 1, 439 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 440 | num_inference_steps: int = 1000, 441 | output_type: Optional[str] = "pil", 442 | return_dict: bool = True, 443 | seg_batch: Optional[torch.Tensor] = None, 444 | class_label_cfg: Optional[int] = None, 445 | translate = False, 446 | ) -> Union[ImagePipelineOutput, Tuple]: 447 | r""" 448 | The call function to the pipeline for generation. 449 | 450 | Args: 451 | batch_size (`int`, *optional*, defaults to 1): 452 | The number of images to generate. 
453 | generator (`torch.Generator`, *optional*): 454 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 455 | generation deterministic. 456 | num_inference_steps (`int`, *optional*, defaults to 1000): 457 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 458 | expense of slower inference. 459 | output_type (`str`, *optional*, defaults to `"pil"`): 460 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 461 | return_dict (`bool`, *optional*, defaults to `True`): 462 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 463 | seg_batch (`torch.Tensor`, *optional*, defaults to None): 464 | batch of segmentations to condition generation on 465 | class_label_cfg (`int`, *optional*, defaults to `None`): 466 | class label to condition generation on using CFG, if using class-conditional model 467 | 468 | OPTIONS FOR IMAGE TRANSLATION: 469 | translate (`bool`, *optional*, defaults to False): 470 | whether to translate images from the source domain to the target domain 471 | 472 | Returns: 473 | [`~pipelines.ImagePipelineOutput`] or `tuple`: 474 | If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is 475 | returned where the first element is a list with the generated images 476 | """ 477 | # Sample gaussian noise to begin loop 478 | if self.external_config.segmentation_channel_mode == "single": 479 | img_channel_ct = self.unet.config.in_channels - 1 480 | elif self.external_config.segmentation_channel_mode == "multi": 481 | img_channel_ct = self.unet.config.in_channels - len([k for k in seg_batch.keys() if k.startswith("seg_")]) 482 | 483 | if isinstance(self.unet.config.sample_size, int): 484 | image_shape = ( 485 | batch_size, 486 | img_channel_ct, 487 | self.unet.config.sample_size, 488 | self.unet.config.sample_size, 489 | ) 490 | else: 491 | if self.external_config.segmentation_channel_mode == "single": 492 | image_shape = (batch_size, self.unet.config.in_channels - 1, *self.unet.config.sample_size) 493 | elif self.external_config.segmentation_channel_mode == "multi": 494 | image_shape = (batch_size, self.unet.config.in_channels - len([k for k in seg_batch.keys() if k.startswith("seg_")]), *self.unet.config.sample_size) 495 | 496 | 497 | # initiate latent variable to sample from 498 | if not translate: 499 | # normal sampling; start from noise 500 | if self.device.type == "mps": 501 | # randn does not work reproducibly on mps 502 | image = randn_tensor(image_shape, generator=generator) 503 | image = image.to(self.device) 504 | else: 505 | image = randn_tensor(image_shape, generator=generator, device=self.device) 506 | else: 507 | # image translation sampling; start from source domain images, add noise up to certain step, then being there for denoising 508 | trans_start_t = int(self.external_config.trans_noise_level * self.scheduler.config.num_train_timesteps) 509 | 510 | trans_start_images = seg_batch["images"] 511 | 512 | # Sample noise to add to the images 513 | noise = torch.randn(trans_start_images.shape).to(trans_start_images.device) 514 | timesteps = torch.full( 515 | (trans_start_images.size(0),), 516 | trans_start_t, 517 | device=trans_start_images.device 518 | ).long() 519 | image = self.scheduler.add_noise(trans_start_images, noise, timesteps) 520 | 521 | # set step values 522 | self.scheduler.set_timesteps(num_inference_steps) 523 | 524 | for t in 
self.progress_bar(self.scheduler.timesteps): 525 | if translate: 526 | # if doing translation, start at chosen time step given partially-noised image 527 | # skip all earlier time steps (with higher t) 528 | if t >= trans_start_t: 529 | continue 530 | 531 | # 1. predict noise model_output 532 | # first, concat segmentations to noise 533 | image = add_segmentations_to_noise(image, seg_batch, self.external_config, self.device) 534 | 535 | if self.external_config.class_conditional: 536 | if class_label_cfg is not None: 537 | class_labels = torch.full([image.size(0)], class_label_cfg).long().to(self.device) 538 | model_output_cond = self.unet(image, t, class_labels=class_labels).sample 539 | if self.external_config.use_cfg_for_eval_conditioning: 540 | # use classifier-free guidance for sampling from the given class 541 | 542 | if self.external_config.cfg_maskguidance_condmodel_only: 543 | image_emptymask = torch.cat((image[:, :img_channel_ct, :, :], torch.zeros_like(image[:, img_channel_ct:, :, :])), dim=1) 544 | model_output_uncond = self.unet(image_emptymask, t, 545 | class_labels=torch.zeros_like(class_labels).long()).sample 546 | else: 547 | model_output_uncond = self.unet(image, t, 548 | class_labels=torch.zeros_like(class_labels).long()).sample 549 | 550 | # use cfg equation 551 | model_output = (1. + self.external_config.cfg_weight) * model_output_cond - self.external_config.cfg_weight * model_output_uncond 552 | else: 553 | # just use normal conditioning 554 | model_output = model_output_cond 555 | 556 | else: 557 | # or, just use basic network conditioning to sample from both classes 558 | if self.external_config.class_conditional: 559 | # if training conditionally, evaluate source domain samples 560 | class_labels = torch.ones(image.size(0)).long().to(self.device) 561 | model_output = self.unet(image, t, class_labels=class_labels).sample 562 | else: 563 | model_output = self.unet(image, t).sample 564 | # output is slightly denoised image 565 | 566 | # 2. compute previous image: x_t -> x_t-1 567 | # but first, we're only adding denoising the image channel (not seg channel), 568 | # so remove segs 569 | image = image[:, :img_channel_ct, :, :] 570 | image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample 571 | 572 | # if training conditionally, also evaluate for target domain images 573 | # if not using chosen class for CFG 574 | if self.external_config.class_conditional and class_label_cfg is None: 575 | image_target_domain = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) 576 | 577 | # set step values 578 | self.scheduler.set_timesteps(num_inference_steps) 579 | 580 | for t in self.progress_bar(self.scheduler.timesteps): 581 | # 1. predict noise model_output 582 | # first, concat segmentations to noise 583 | # no masks in target domain so just use blank masks 584 | image_target_domain = torch.cat((image_target_domain, torch.zeros_like(image_target_domain)), dim=1) 585 | 586 | if self.external_config.class_conditional: 587 | # if training conditionally, also evaluate unconditional model and target domain (no masks) 588 | class_labels = torch.cat([torch.full([image_target_domain.size(0) // 2], 2), torch.zeros(image_target_domain.size(0)) // 2]).long().to(self.device) 589 | model_output = self.unet(image_target_domain, t, class_labels=class_labels).sample 590 | else: 591 | model_output = self.unet(image_target_domain, t).sample 592 | 593 | # 2. 
compute the previous image from the model output 594 | # (this DDPM pipeline has no eta parameter; eta only applies to the DDIM pipeline below) 595 | # do x_t -> x_t-1 596 | # but first, we only denoise the image channels (not the seg channels), 597 | # so remove segs 598 | image_target_domain = image_target_domain[:, :img_channel_ct, :, :] 599 | image_target_domain = self.scheduler.step( 600 | model_output, t, image_target_domain, generator=generator 601 | ).prev_sample 602 | 603 | image = torch.cat((image, image_target_domain), dim=0) 604 | # will output source domain images first, then target domain images 605 | 606 | image = (image / 2 + 0.5).clamp(0, 1) 607 | image = image.cpu().permute(0, 2, 3, 1).numpy() 608 | if output_type == "pil": 609 | image = self.numpy_to_pil(image) 610 | 611 | if not return_dict: 612 | return (image,) 613 | 614 | return ImagePipelineOutput(images=image) 615 | 616 | class SegGuidedDDIMPipeline(DiffusionPipeline): 617 | r""" 618 | Pipeline for image generation, modified for segmentation-guided image generation. 619 | Modified from diffusers.DDIMPipeline. 620 | If using a class-conditional model without CFG, generates both class-conditioned and unconditional images; 621 | otherwise, generates images conditioned on the chosen class via CFG. 622 | 623 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods 624 | implemented for all pipelines (downloading, saving, running on a particular device, etc.). 625 | 626 | Parameters: 627 | unet ([`UNet2DModel`]): 628 | A `UNet2DModel` to denoise the encoded image latents. 629 | scheduler ([`SchedulerMixin`]): 630 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of 631 | [`DDPMScheduler`], or [`DDIMScheduler`]. 632 | eval_dataloader ([`torch.utils.data.DataLoader`]): 633 | Dataloader to load the evaluation dataset of images and their segmentations. Only the segmentations are used here to generate images. 634 | 635 | """ 636 | model_cpu_offload_seq = "unet" 637 | 638 | def __init__(self, unet, scheduler, eval_dataloader, external_config): 639 | super().__init__() 640 | self.register_modules(unet=unet, scheduler=scheduler, eval_dataloader=eval_dataloader, external_config=external_config) 641 | # ^ for some reason this is necessary for DDIM but not DDPM. 642 | 643 | self.eval_dataloader = eval_dataloader 644 | self.external_config = external_config # named "external_config" because the pipeline already has a `config` attribute 645 | 646 | # make sure the scheduler can always be converted to DDIM 647 | self.scheduler = DDIMScheduler.from_config(scheduler.config) 648 | 649 | 650 | @torch.no_grad() 651 | def __call__( 652 | self, 653 | batch_size: int = 1, 654 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 655 | eta: float = 0.0, 656 | num_inference_steps: int = 50, 657 | use_clipped_model_output: Optional[bool] = None, 658 | output_type: Optional[str] = "pil", 659 | return_dict: bool = True, 660 | seg_batch: Optional[torch.Tensor] = None, 661 | class_label_cfg: Optional[int] = None, 662 | translate = False, 663 | ) -> Union[ImagePipelineOutput, Tuple]: 664 | r""" 665 | The call function to the pipeline for generation. 666 | 667 | Args: 668 | batch_size (`int`, *optional*, defaults to 1): 669 | The number of images to generate. 670 | generator (`torch.Generator`, *optional*): 671 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make 672 | generation deterministic.
673 | eta (`float`, *optional*, defaults to 0.0): 674 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies 675 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0` corresponds to 676 | DDIM and `1` corresponds to DDPM. 677 | num_inference_steps (`int`, *optional*, defaults to 50): 678 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 679 | expense of slower inference. 680 | use_clipped_model_output (`bool`, *optional*, defaults to `None`): 681 | If `True` or `False`, see documentation for [`DDIMScheduler.step`]. If `None`, nothing is passed 682 | downstream to the scheduler (use `None` for schedulers which don't support this argument). 683 | output_type (`str`, *optional*, defaults to `"pil"`): 684 | The output format of the generated image. Choose between `PIL.Image` or `np.array`. 685 | return_dict (`bool`, *optional*, defaults to `True`): 686 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple. 687 | seg_batch (`torch.Tensor`, *optional*): 688 | Batch of segmentations to condition generation on. 689 | class_label_cfg (`int`, *optional*, defaults to `None`): 690 | Class label to condition generation on using CFG, if using a class-conditional model. 691 | 692 | OPTIONS FOR IMAGE TRANSLATION: 693 | translate (`bool`, *optional*, defaults to `False`): 694 | Whether to translate images from the source domain to the target domain. 695 | 696 | 697 | 698 | 699 | 700 | Returns: 701 | [`~pipelines.ImagePipelineOutput`] or `tuple`: 702 | If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is 703 | returned where the first element is a list with the generated images. 704 | """ 705 | 706 | # Sample gaussian noise to begin loop 707 | if self.external_config.segmentation_channel_mode == "single": 708 | img_channel_ct = self.unet.config.in_channels - 1 709 | elif self.external_config.segmentation_channel_mode == "multi": 710 | img_channel_ct = self.unet.config.in_channels - len([k for k in seg_batch.keys() if k.startswith("seg_")]) 711 | 712 | if isinstance(self.unet.config.sample_size, int): 713 | if self.external_config.segmentation_channel_mode == "single": 714 | image_shape = ( 715 | batch_size, 716 | self.unet.config.in_channels - 1, 717 | self.unet.config.sample_size, 718 | self.unet.config.sample_size, 719 | ) 720 | elif self.external_config.segmentation_channel_mode == "multi": 721 | image_shape = ( 722 | batch_size, 723 | self.unet.config.in_channels - len([k for k in seg_batch.keys() if k.startswith("seg_")]), 724 | self.unet.config.sample_size, 725 | self.unet.config.sample_size, 726 | ) 727 | else: 728 | if self.external_config.segmentation_channel_mode == "single": 729 | image_shape = (batch_size, self.unet.config.in_channels - 1, *self.unet.config.sample_size) 730 | elif self.external_config.segmentation_channel_mode == "multi": 731 | image_shape = (batch_size, self.unet.config.in_channels - len([k for k in seg_batch.keys() if k.startswith("seg_")]), *self.unet.config.sample_size) 732 | 733 | if isinstance(generator, list) and len(generator) != batch_size: 734 | raise ValueError( 735 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 736 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
737 | ) 738 | 739 | # initialize the latent variable to sample from 740 | if not translate: 741 | # normal sampling; start from noise 742 | image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) 743 | else: 744 | # image translation sampling; start from source domain images, add noise up to a certain step, then begin denoising from there 745 | trans_start_t = int(self.external_config.trans_noise_level * self.scheduler.config.num_train_timesteps) 746 | 747 | trans_start_images = seg_batch["images"].to(self._execution_device) 748 | 749 | # Sample noise to add to the images 750 | noise = torch.randn(trans_start_images.shape).to(trans_start_images.device) 751 | timesteps = torch.full( 752 | (trans_start_images.size(0),), 753 | trans_start_t, 754 | device=trans_start_images.device 755 | ).long() 756 | image = self.scheduler.add_noise(trans_start_images, noise, timesteps) 757 | 758 | # set step values 759 | self.scheduler.set_timesteps(num_inference_steps) 760 | 761 | for t in self.progress_bar(self.scheduler.timesteps): 762 | if translate: 763 | # if doing translation, start at chosen time step given partially-noised image 764 | # skip all earlier time steps (with higher t) 765 | if t >= trans_start_t: 766 | continue 767 | 768 | # 1. predict noise model_output 769 | # first, concat segmentations to noise 770 | image = add_segmentations_to_noise(image, seg_batch, self.external_config, self.device) 771 | 772 | if self.external_config.class_conditional: 773 | if class_label_cfg is not None: 774 | class_labels = torch.full([image.size(0)], class_label_cfg).long().to(self.device) 775 | model_output_cond = self.unet(image, t, class_labels=class_labels).sample 776 | if self.external_config.use_cfg_for_eval_conditioning: 777 | # use classifier-free guidance for sampling from the given class 778 | if self.external_config.cfg_maskguidance_condmodel_only: 779 | image_emptymask = torch.cat((image[:, :img_channel_ct, :, :], torch.zeros_like(image[:, img_channel_ct:, :, :])), dim=1) 780 | model_output_uncond = self.unet(image_emptymask, t, 781 | class_labels=torch.zeros_like(class_labels).long()).sample 782 | else: 783 | model_output_uncond = self.unet(image, t, 784 | class_labels=torch.zeros_like(class_labels).long()).sample 785 | 786 | # use cfg equation 787 | model_output = (1. + self.external_config.cfg_weight) * model_output_cond - self.external_config.cfg_weight * model_output_uncond 788 | else: 789 | model_output = model_output_cond 790 | 791 | else: 792 | # or, just use basic network conditioning to sample from both classes 793 | if self.external_config.class_conditional: 794 | # if training conditionally, evaluate source domain samples 795 | class_labels = torch.ones(image.size(0)).long().to(self.device) 796 | model_output = self.unet(image, t, class_labels=class_labels).sample 797 | else: 798 | model_output = self.unet(image, t).sample 799 | 800 | # 2.
predict previous mean of image x_t-1 and add variance depending on eta 801 | # eta corresponds to η in paper and should be between [0, 1] 802 | # do x_t -> x_t-1 803 | # but first, we're only adding denoising the image channel (not seg channel), 804 | # so remove segs 805 | image = image[:, :img_channel_ct, :, :] 806 | image = self.scheduler.step( 807 | model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator 808 | ).prev_sample 809 | 810 | # if training conditionally, also evaluate for target domain images 811 | # if not using chosen class for CFG 812 | if self.external_config.class_conditional and class_label_cfg is None: 813 | image_target_domain = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype) 814 | 815 | # set step values 816 | self.scheduler.set_timesteps(num_inference_steps) 817 | 818 | for t in self.progress_bar(self.scheduler.timesteps): 819 | # 1. predict noise model_output 820 | # first, concat segmentations to noise 821 | # no masks in target domain so just use blank masks 822 | image_target_domain = torch.cat((image_target_domain, torch.zeros_like(image_target_domain)), dim=1) 823 | 824 | if self.external_config.class_conditional: 825 | # if training conditionally, also evaluate unconditional model and target domain (no masks) 826 | class_labels = torch.cat([torch.full([image_target_domain.size(0) // 2], 2), torch.zeros(image_target_domain.size(0) // 2)]).long().to(self.device) 827 | model_output = self.unet(image_target_domain, t, class_labels=class_labels).sample 828 | else: 829 | model_output = self.unet(image_target_domain, t).sample 830 | 831 | # 2. predict previous mean of image x_t-1 and add variance depending on eta 832 | # eta corresponds to η in paper and should be between [0, 1] 833 | # do x_t -> x_t-1 834 | # but first, we're only adding denoising the image channel (not seg channel), 835 | # so remove segs 836 | image_target_domain = image_target_domain[:, :img_channel_ct, :, :] 837 | image_target_domain = self.scheduler.step( 838 | model_output, t, image_target_domain, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator 839 | ).prev_sample 840 | 841 | image = torch.cat((image, image_target_domain), dim=0) 842 | # will output source domain images first, then target domain images 843 | 844 | image = (image / 2 + 0.5).clamp(0, 1) 845 | image = image.cpu().permute(0, 2, 3, 1).numpy() 846 | if output_type == "pil": 847 | image = self.numpy_to_pil(image) 848 | 849 | if not return_dict: 850 | return (image,) 851 | 852 | return ImagePipelineOutput(images=image) 853 | -------------------------------------------------------------------------------- /figs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazurowski-lab/segmentation-guided-diffusion/9a31f0f4eef0d7b835f01b2a6301df854b8040e9/figs/teaser.png -------------------------------------------------------------------------------- /figs/teaser_data.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/mazurowski-lab/segmentation-guided-diffusion/9a31f0f4eef0d7b835f01b2a6301df854b8040e9/figs/teaser_data.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | # torch imports 5 | 
import torch 6 | from torch import nn 7 | from torchvision import transforms 8 | import torch.nn.functional as F 9 | import numpy as np 10 | 11 | # HF imports 12 | import diffusers 13 | from diffusers.optimization import get_cosine_schedule_with_warmup 14 | import datasets 15 | 16 | # custom imports 17 | from training import TrainingConfig, train_loop 18 | from eval import evaluate_generation, evaluate_sample_many 19 | 20 | def main( 21 | mode, 22 | img_size, 23 | num_img_channels, 24 | dataset, 25 | img_dir, 26 | seg_dir, 27 | model_type, 28 | segmentation_guided, 29 | segmentation_channel_mode, 30 | num_segmentation_classes, 31 | train_batch_size, 32 | eval_batch_size, 33 | num_epochs, 34 | resume_epoch=None, 35 | use_ablated_segmentations=False, 36 | eval_shuffle_dataloader=True, 37 | 38 | # arguments only used in eval 39 | eval_mask_removal=False, 40 | eval_blank_mask=False, 41 | eval_sample_size=1000 42 | ): 43 | #GPUs 44 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 45 | print('running on {}'.format(device)) 46 | 47 | # load config 48 | output_dir = '{}-{}-{}'.format(model_type.lower(), dataset, img_size) # the model namy locally and on the HF Hub 49 | if segmentation_guided: 50 | output_dir += "-segguided" 51 | assert seg_dir is not None, "must provide segmentation directory for segmentation guided training/sampling" 52 | 53 | if use_ablated_segmentations or eval_mask_removal or eval_blank_mask: 54 | output_dir += "-ablated" 55 | 56 | print("output dir: {}".format(output_dir)) 57 | 58 | if mode == "train": 59 | evalset_name = "val" 60 | assert img_dir is not None, "must provide image directory for training" 61 | elif "eval" in mode: 62 | evalset_name = "test" 63 | 64 | print("using evaluation set: {}".format(evalset_name)) 65 | 66 | config = TrainingConfig( 67 | image_size = img_size, 68 | dataset = dataset, 69 | segmentation_guided = segmentation_guided, 70 | segmentation_channel_mode = segmentation_channel_mode, 71 | num_segmentation_classes = num_segmentation_classes, 72 | train_batch_size = train_batch_size, 73 | eval_batch_size = eval_batch_size, 74 | num_epochs = num_epochs, 75 | output_dir = output_dir, 76 | model_type=model_type, 77 | resume_epoch=resume_epoch, 78 | use_ablated_segmentations=use_ablated_segmentations 79 | ) 80 | 81 | load_images_as_np_arrays = False 82 | if num_img_channels not in [1, 3]: 83 | load_images_as_np_arrays = True 84 | print("image channels not 1 or 3, attempting to load images as np arrays...") 85 | 86 | if config.segmentation_guided: 87 | seg_types = os.listdir(seg_dir) 88 | seg_paths_train = {} 89 | seg_paths_eval = {} 90 | 91 | # train set 92 | if img_dir is not None: 93 | # make sure the images are matched to the segmentation masks 94 | img_dir_train = os.path.join(img_dir, "train") 95 | img_paths_train = [os.path.join(img_dir_train, f) for f in os.listdir(img_dir_train)] 96 | for seg_type in seg_types: 97 | seg_paths_train[seg_type] = [os.path.join(seg_dir, seg_type, "train", f) for f in os.listdir(img_dir_train)] 98 | else: 99 | for seg_type in seg_types: 100 | seg_paths_train[seg_type] = [os.path.join(seg_dir, seg_type, "train", f) for f in os.listdir(os.path.join(seg_dir, seg_type, "train"))] 101 | 102 | # eval set 103 | if img_dir is not None: 104 | img_dir_eval = os.path.join(img_dir, evalset_name) 105 | img_paths_eval = [os.path.join(img_dir_eval, f) for f in os.listdir(img_dir_eval)] 106 | for seg_type in seg_types: 107 | seg_paths_eval[seg_type] = [os.path.join(seg_dir, seg_type, evalset_name, f) for f in 
os.listdir(img_dir_eval)] 108 | else: 109 | for seg_type in seg_types: 110 | seg_paths_eval[seg_type] = [os.path.join(seg_dir, seg_type, evalset_name, f) for f in os.listdir(os.path.join(seg_dir, seg_type, evalset_name))] 111 | 112 | if img_dir is not None: 113 | dset_dict_train = { 114 | **{"image": img_paths_train}, 115 | **{"seg_{}".format(seg_type): seg_paths_train[seg_type] for seg_type in seg_types} 116 | } 117 | 118 | dset_dict_eval = { 119 | **{"image": img_paths_eval}, 120 | **{"seg_{}".format(seg_type): seg_paths_eval[seg_type] for seg_type in seg_types} 121 | } 122 | else: 123 | dset_dict_train = { 124 | **{"seg_{}".format(seg_type): seg_paths_train[seg_type] for seg_type in seg_types} 125 | } 126 | 127 | dset_dict_eval = { 128 | **{"seg_{}".format(seg_type): seg_paths_eval[seg_type] for seg_type in seg_types} 129 | } 130 | 131 | 132 | if img_dir is not None: 133 | # add image filenames to dataset 134 | dset_dict_train["image_filename"] = [os.path.basename(f) for f in dset_dict_train["image"]] 135 | dset_dict_eval["image_filename"] = [os.path.basename(f) for f in dset_dict_eval["image"]] 136 | else: 137 | # use segmentation filenames as image filenames 138 | dset_dict_train["image_filename"] = [os.path.basename(f) for f in dset_dict_train["seg_{}".format(seg_types[0])]] 139 | dset_dict_eval["image_filename"] = [os.path.basename(f) for f in dset_dict_eval["seg_{}".format(seg_types[0])]] 140 | 141 | dataset_train = datasets.Dataset.from_dict(dset_dict_train) 142 | dataset_eval = datasets.Dataset.from_dict(dset_dict_eval) 143 | 144 | # load the images 145 | if not load_images_as_np_arrays and img_dir is not None: 146 | dataset_train = dataset_train.cast_column("image", datasets.Image()) 147 | dataset_eval = dataset_eval.cast_column("image", datasets.Image()) 148 | 149 | for seg_type in seg_types: 150 | dataset_train = dataset_train.cast_column("seg_{}".format(seg_type), datasets.Image()) 151 | 152 | for seg_type in seg_types: 153 | dataset_eval = dataset_eval.cast_column("seg_{}".format(seg_type), datasets.Image()) 154 | 155 | else: 156 | if img_dir is not None: 157 | img_dir_train = os.path.join(img_dir, "train") 158 | img_paths_train = [os.path.join(img_dir_train, f) for f in os.listdir(img_dir_train)] 159 | 160 | img_dir_eval = os.path.join(img_dir, evalset_name) 161 | img_paths_eval = [os.path.join(img_dir_eval, f) for f in os.listdir(img_dir_eval)] 162 | 163 | dset_dict_train = { 164 | **{"image": img_paths_train} 165 | } 166 | 167 | dset_dict_eval = { 168 | **{"image": img_paths_eval} 169 | } 170 | 171 | # add image filenames to dataset 172 | dset_dict_train["image_filename"] = [os.path.basename(f) for f in dset_dict_train["image"]] 173 | dset_dict_eval["image_filename"] = [os.path.basename(f) for f in dset_dict_eval["image"]] 174 | 175 | dataset_train = datasets.Dataset.from_dict(dset_dict_train) 176 | dataset_eval = datasets.Dataset.from_dict(dset_dict_eval) 177 | 178 | # load the images 179 | if not load_images_as_np_arrays: 180 | dataset_train = dataset_train.cast_column("image", datasets.Image()) 181 | dataset_eval = dataset_eval.cast_column("image", datasets.Image()) 182 | 183 | # training set preprocessing 184 | if not load_images_as_np_arrays: 185 | preprocess = transforms.Compose( 186 | [ 187 | transforms.Resize((config.image_size, config.image_size)), 188 | # transforms.RandomHorizontalFlip(), # flipping wouldn't result in realistic images 189 | transforms.ToTensor(), 190 | transforms.Normalize( 191 | num_img_channels * [0.5], 192 | num_img_channels * [0.5]), 193 | 
] 194 | ) 195 | else: 196 | # resizing will be done in the transform function 197 | preprocess = transforms.Compose( 198 | [ 199 | transforms.Normalize( 200 | num_img_channels * [0.5], 201 | num_img_channels * [0.5]), 202 | ] 203 | ) 204 | 205 | if num_img_channels == 1: 206 | PIL_image_type = "L" 207 | elif num_img_channels == 3: 208 | PIL_image_type = "RGB" 209 | else: 210 | PIL_image_type = None 211 | 212 | if config.segmentation_guided: 213 | preprocess_segmentation = transforms.Compose( 214 | [ 215 | transforms.Resize((config.image_size, config.image_size), interpolation=transforms.InterpolationMode.NEAREST), 216 | transforms.ToTensor(), 217 | ] 218 | ) 219 | 220 | def transform(examples): 221 | if img_dir is not None: 222 | if not load_images_as_np_arrays: 223 | images = [preprocess(image.convert(PIL_image_type)) for image in examples["image"]] 224 | else: 225 | # load np array as torch tensor, resize, then normalize 226 | images = [ 227 | preprocess(F.interpolate(torch.tensor(np.load(image)).unsqueeze(0).float(), size=(config.image_size, config.image_size)).squeeze()) for image in examples["image"] 228 | ] 229 | 230 | images_filenames = examples["image_filename"] 231 | 232 | segs = {} 233 | for seg_type in seg_types: 234 | segs["seg_{}".format(seg_type)] = [preprocess_segmentation(image.convert("L")) for image in examples["seg_{}".format(seg_type)]] 235 | # return {"images": images, "seg_breast": seg_breast, "seg_dv": seg_dv} 236 | if img_dir is not None: 237 | return {**{"images": images}, **segs, **{"image_filenames": images_filenames}} 238 | else: 239 | return {**segs, **{"image_filenames": images_filenames}} 240 | 241 | dataset_train.set_transform(transform) 242 | dataset_eval.set_transform(transform) 243 | 244 | else: 245 | if img_dir is not None: 246 | def transform(examples): 247 | if not load_images_as_np_arrays: 248 | images = [preprocess(image.convert(PIL_image_type)) for image in examples["image"]] 249 | else: 250 | images = [ 251 | preprocess(F.interpolate(torch.tensor(np.load(image)).unsqueeze(0).float(), size=(config.image_size, config.image_size)).squeeze()) for image in examples["image"] 252 | ] 253 | images_filenames = examples["image_filename"] 254 | #return {"images": images, "image_filenames": images_filenames} 255 | return {"images": images, **{"image_filenames": images_filenames}} 256 | 257 | dataset_train.set_transform(transform) 258 | dataset_eval.set_transform(transform) 259 | 260 | if ((img_dir is None) and (not segmentation_guided)): 261 | train_dataloader = None 262 | # just make placeholder dataloaders to iterate through when sampling from uncond model 263 | eval_dataloader = torch.utils.data.DataLoader( 264 | torch.utils.data.TensorDataset(torch.zeros(config.eval_batch_size, num_img_channels, config.image_size, config.image_size)), 265 | batch_size=config.eval_batch_size, 266 | shuffle=eval_shuffle_dataloader 267 | ) 268 | else: 269 | train_dataloader = torch.utils.data.DataLoader( 270 | dataset_train, 271 | batch_size=config.train_batch_size, 272 | shuffle=True 273 | ) 274 | 275 | eval_dataloader = torch.utils.data.DataLoader( 276 | dataset_eval, 277 | batch_size=config.eval_batch_size, 278 | shuffle=eval_shuffle_dataloader 279 | ) 280 | 281 | # define the model 282 | in_channels = num_img_channels 283 | if config.segmentation_guided: 284 | assert config.num_segmentation_classes is not None 285 | assert config.num_segmentation_classes > 1, "must have at least 2 segmentation classes (INCLUDING background)" 286 | if config.segmentation_channel_mode == 
"single": 287 | in_channels += 1 288 | elif config.segmentation_channel_mode == "multi": 289 | in_channels = len(seg_types) + in_channels 290 | 291 | model = diffusers.UNet2DModel( 292 | sample_size=config.image_size, # the target image resolution 293 | in_channels=in_channels, # the number of input channels, 3 for RGB images 294 | out_channels=num_img_channels, # the number of output channels 295 | layers_per_block=2, # how many ResNet layers to use per UNet block 296 | block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channes for each UNet block 297 | down_block_types=( 298 | "DownBlock2D", # a regular ResNet downsampling block 299 | "DownBlock2D", 300 | "DownBlock2D", 301 | "DownBlock2D", 302 | "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention 303 | "DownBlock2D", 304 | ), 305 | up_block_types=( 306 | "UpBlock2D", # a regular ResNet upsampling block 307 | "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention 308 | "UpBlock2D", 309 | "UpBlock2D", 310 | "UpBlock2D", 311 | "UpBlock2D" 312 | ), 313 | ) 314 | 315 | if (mode == "train" and resume_epoch is not None) or "eval" in mode: 316 | if mode == "train": 317 | print("resuming from model at training epoch {}".format(resume_epoch)) 318 | elif "eval" in mode: 319 | print("loading saved model...") 320 | model = model.from_pretrained(os.path.join(config.output_dir, 'unet'), use_safetensors=True) 321 | 322 | model = nn.DataParallel(model) 323 | model.to(device) 324 | 325 | # define noise scheduler 326 | if model_type == "DDPM": 327 | noise_scheduler = diffusers.DDPMScheduler(num_train_timesteps=1000) 328 | elif model_type == "DDIM": 329 | noise_scheduler = diffusers.DDIMScheduler(num_train_timesteps=1000) 330 | 331 | if mode == "train": 332 | # training setup 333 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate) 334 | lr_scheduler = get_cosine_schedule_with_warmup( 335 | optimizer=optimizer, 336 | num_warmup_steps=config.lr_warmup_steps, 337 | num_training_steps=(len(train_dataloader) * config.num_epochs), 338 | ) 339 | 340 | # train 341 | train_loop( 342 | config, 343 | model, 344 | noise_scheduler, 345 | optimizer, 346 | train_dataloader, 347 | eval_dataloader, 348 | lr_scheduler, 349 | device=device 350 | ) 351 | elif mode == "eval": 352 | """ 353 | default eval behavior: 354 | evaluate image generation or translation (if for conditional model, either evaluate naive class conditioning but not CFG, 355 | or with CFG), 356 | possibly conditioned on masks. 357 | 358 | has various options. 
359 | """ 360 | evaluate_generation( 361 | config, 362 | model, 363 | noise_scheduler, 364 | eval_dataloader, 365 | eval_mask_removal=eval_mask_removal, 366 | eval_blank_mask=eval_blank_mask, 367 | device=device 368 | ) 369 | 370 | elif mode == "eval_many": 371 | """ 372 | generate many images and save them to a directory, saved individually 373 | """ 374 | evaluate_sample_many( 375 | eval_sample_size, 376 | config, 377 | model, 378 | noise_scheduler, 379 | eval_dataloader, 380 | device=device 381 | ) 382 | 383 | else: 384 | raise ValueError("mode \"{}\" not supported.".format(mode)) 385 | 386 | 387 | if __name__ == "__main__": 388 | # parse args: 389 | parser = ArgumentParser() 390 | parser.add_argument('--mode', type=str, default='train') 391 | parser.add_argument('--img_size', type=int, default=256) 392 | parser.add_argument('--num_img_channels', type=int, default=1) 393 | parser.add_argument('--dataset', type=str, default="breast_mri") 394 | parser.add_argument('--img_dir', type=str, default=None) 395 | parser.add_argument('--seg_dir', type=str, default=None) 396 | parser.add_argument('--model_type', type=str, default="DDPM") 397 | parser.add_argument('--segmentation_guided', action='store_true', help='use segmentation guided training/sampling?') 398 | parser.add_argument('--segmentation_channel_mode', type=str, default="single", help='single == all segmentations in one channel, multi == each segmentation in its own channel') 399 | parser.add_argument('--num_segmentation_classes', type=int, default=None, help='number of segmentation classes, including background') 400 | parser.add_argument('--train_batch_size', type=int, default=32) 401 | parser.add_argument('--eval_batch_size', type=int, default=8) 402 | parser.add_argument('--num_epochs', type=int, default=200) 403 | parser.add_argument('--resume_epoch', type=int, default=None, help='resume training starting at this epoch') 404 | 405 | # novel options 406 | parser.add_argument('--use_ablated_segmentations', action='store_true', help='use mask ablated training and any evaluation? 
sometimes randomly remove class(es) from mask during training and sampling.') 407 | 408 | # other options 409 | parser.add_argument('--eval_noshuffle_dataloader', action='store_true', help='if true, don\'t shuffle the eval dataloader') 410 | 411 | # args only used in eval 412 | parser.add_argument('--eval_mask_removal', action='store_true', help='if true, evaluate gradually removing anatomies from mask and re-sampling') 413 | parser.add_argument('--eval_blank_mask', action='store_true', help='if true, evaluate sampling conditioned on blank (zeros) masks') 414 | parser.add_argument('--eval_sample_size', type=int, default=1000, help='number of images to sample when using eval_many mode') 415 | 416 | args = parser.parse_args() 417 | 418 | main( 419 | args.mode, 420 | args.img_size, 421 | args.num_img_channels, 422 | args.dataset, 423 | args.img_dir, 424 | args.seg_dir, 425 | args.model_type, 426 | args.segmentation_guided, 427 | args.segmentation_channel_mode, 428 | args.num_segmentation_classes, 429 | args.train_batch_size, 430 | args.eval_batch_size, 431 | args.num_epochs, 432 | args.resume_epoch, 433 | args.use_ablated_segmentations, 434 | not args.eval_noshuffle_dataloader, 435 | 436 | # args only used in eval 437 | args.eval_mask_removal, 438 | args.eval_blank_mask, 439 | args.eval_sample_size 440 | ) 441 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy==1.26.1 2 | matplotlib==3.8.0 3 | Pillow==10.0.1 4 | diffusers==0.21.4 5 | datasets==2.14.5 6 | tqdm==4.66.1 7 | -------------------------------------------------------------------------------- /training.py: -------------------------------------------------------------------------------- 1 | """ 2 | training utils 3 | """ 4 | from dataclasses import dataclass 5 | import math 6 | import os 7 | from pathlib import Path 8 | from tqdm.auto import tqdm 9 | import numpy as np 10 | from datetime import timedelta 11 | 12 | import torch 13 | from torch import nn 14 | import torch.nn.functional as F 15 | from torch.utils.tensorboard import SummaryWriter 16 | 17 | import diffusers 18 | 19 | from eval import evaluate, add_segmentations_to_noise, SegGuidedDDPMPipeline, SegGuidedDDIMPipeline 20 | 21 | @dataclass 22 | class TrainingConfig: 23 | model_type: str = "DDPM" 24 | image_size: int = 256 # the generated image resolution 25 | train_batch_size: int = 32 26 | eval_batch_size: int = 8 # how many images to sample during evaluation 27 | num_epochs: int = 200 28 | gradient_accumulation_steps: int = 1 29 | learning_rate: float = 1e-4 30 | lr_warmup_steps: int = 500 31 | save_image_epochs: int = 20 32 | save_model_epochs: int = 30 33 | mixed_precision: str = 'fp16' # `no` for float32, `fp16` for automatic mixed precision 34 | output_dir: str = None 35 | 36 | push_to_hub: bool = False # whether to upload the saved model to the HF Hub 37 | hub_private_repo: bool = False 38 | overwrite_output_dir: bool = True # overwrite the old model when re-running the notebook 39 | seed: int = 0 40 | 41 | # custom options 42 | segmentation_guided: bool = False 43 | segmentation_channel_mode: str = "single" 44 | num_segmentation_classes: int = None # INCLUDING background 45 | use_ablated_segmentations: bool = False 46 | dataset: str = "breast_mri" 47 | resume_epoch: int = None 48 | 49 | # EXPERIMENTAL/UNTESTED: classifier-free class guidance and image translation 50 | class_conditional: bool = False 51 | cfg_p_uncond: float = 0.2 
# p_uncond in classifier-free guidance paper 52 | cfg_weight: float = 0.3 # w in the paper 53 | trans_noise_level: float = 0.5 # ratio of time step t to noise trans_start_images to total T before denoising in translation. e.g. value of 0.5 means t = 500 for default T = 1000. 54 | use_cfg_for_eval_conditioning: bool = True # whether to use classifier-free guidance for or just naive class conditioning for main sampling loop 55 | cfg_maskguidance_condmodel_only: bool = True # if using mask guidance AND cfg, only give mask to conditional network 56 | # ^ this is because giving mask to both uncond and cond model make class guidance not work 57 | # (see "Classifier-free guidance resolution weighting." in ControlNet paper) 58 | 59 | 60 | def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, eval_dataloader, lr_scheduler, device='cuda'): 61 | # Prepare everything 62 | # There is no specific order to remember, you just need to unpack the 63 | # objects in the same order you gave them to the prepare method. 64 | 65 | global_step = 0 66 | 67 | # logging 68 | run_name = '{}-{}-{}'.format(config.model_type.lower(), config.dataset, config.image_size) 69 | if config.segmentation_guided: 70 | run_name += "-segguided" 71 | writer = SummaryWriter(comment=run_name) 72 | 73 | # for loading segs to condition on: 74 | eval_dataloader = iter(eval_dataloader) 75 | 76 | # Now you train the model 77 | start_epoch = 0 78 | if config.resume_epoch is not None: 79 | start_epoch = config.resume_epoch 80 | 81 | for epoch in range(start_epoch, config.num_epochs): 82 | progress_bar = tqdm(total=len(train_dataloader)) 83 | progress_bar.set_description(f"Epoch {epoch}") 84 | 85 | model.train() 86 | 87 | for step, batch in enumerate(train_dataloader): 88 | clean_images = batch['images'] 89 | clean_images = clean_images.to(device) 90 | 91 | # Sample noise to add to the images 92 | noise = torch.randn(clean_images.shape).to(clean_images.device) 93 | bs = clean_images.shape[0] 94 | 95 | # Sample a random timestep for each image 96 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device).long() 97 | 98 | # Add noise to the clean images according to the noise magnitude at each timestep 99 | # (this is the forward diffusion process) 100 | noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps) 101 | 102 | if config.segmentation_guided: 103 | noisy_images = add_segmentations_to_noise(noisy_images, batch, config, device) 104 | 105 | # Predict the noise residual 106 | if config.class_conditional: 107 | class_labels = torch.ones(noisy_images.size(0)).long().to(device) 108 | # classifier-free guidance 109 | a = np.random.uniform() 110 | if a <= config.cfg_p_uncond: 111 | class_labels = torch.zeros_like(class_labels).long() 112 | noise_pred = model(noisy_images, timesteps, class_labels=class_labels, return_dict=False)[0] 113 | else: 114 | noise_pred = model(noisy_images, timesteps, return_dict=False)[0] 115 | loss = F.mse_loss(noise_pred, noise) 116 | loss.backward() 117 | 118 | nn.utils.clip_grad_norm_(model.parameters(), 1.0) 119 | optimizer.step() 120 | lr_scheduler.step() 121 | optimizer.zero_grad() 122 | 123 | # also train on target domain images if conditional 124 | # (we don't have masks for this domain, so we can't do segmentation-guided; just use blank masks) 125 | if config.class_conditional: 126 | target_domain_images = batch['images_target'] 127 | target_domain_images = target_domain_images.to(device) 128 | 129 | # Sample noise to add 
to the images 130 | noise = torch.randn(target_domain_images.shape).to(target_domain_images.device) 131 | bs = target_domain_images.shape[0] 132 | 133 | # Sample a random timestep for each image 134 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=target_domain_images.device).long() 135 | 136 | # Add noise to the clean images according to the noise magnitude at each timestep 137 | # (this is the forward diffusion process) 138 | noisy_images = noise_scheduler.add_noise(target_domain_images, noise, timesteps) 139 | 140 | if config.segmentation_guided: 141 | # no masks in target domain so just use blank masks 142 | noisy_images = torch.cat((noisy_images, torch.zeros_like(noisy_images)), dim=1) 143 | 144 | # Predict the noise residual 145 | class_labels = torch.full([noisy_images.size(0)], 2).long().to(device) 146 | # classifier-free guidance 147 | a = np.random.uniform() 148 | if a <= config.cfg_p_uncond: 149 | class_labels = torch.zeros_like(class_labels).long() 150 | noise_pred = model(noisy_images, timesteps, class_labels=class_labels, return_dict=False)[0] 151 | loss_target_domain = F.mse_loss(noise_pred, noise) 152 | loss_target_domain.backward() 153 | 154 | nn.utils.clip_grad_norm_(model.parameters(), 1.0) 155 | optimizer.step() 156 | lr_scheduler.step() 157 | optimizer.zero_grad() 158 | 159 | progress_bar.update(1) 160 | if config.class_conditional: 161 | logs = {"loss": loss.detach().item(), "loss_target_domain": loss_target_domain.detach().item(), 162 | "lr": lr_scheduler.get_last_lr()[0], "step": global_step} 163 | writer.add_scalar("loss_target_domain", loss_target_domain.detach().item(), global_step) 164 | else: 165 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step} 166 | writer.add_scalar("loss", loss.detach().item(), global_step) 167 | 168 | progress_bar.set_postfix(**logs) 169 | global_step += 1 170 | 171 | # After each epoch you optionally sample some demo images with evaluate() and save the model 172 | if config.model_type == "DDPM": 173 | if config.segmentation_guided: 174 | pipeline = SegGuidedDDPMPipeline( 175 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config 176 | ) 177 | else: 178 | if config.class_conditional: 179 | raise NotImplementedError("TODO: Conditional training not implemented for non-seg-guided DDPM") 180 | else: 181 | pipeline = diffusers.DDPMPipeline(unet=model.module, scheduler=noise_scheduler) 182 | elif config.model_type == "DDIM": 183 | if config.segmentation_guided: 184 | pipeline = SegGuidedDDIMPipeline( 185 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config 186 | ) 187 | else: 188 | if config.class_conditional: 189 | raise NotImplementedError("TODO: Conditional training not implemented for non-seg-guided DDIM") 190 | else: 191 | pipeline = diffusers.DDIMPipeline(unet=model.module, scheduler=noise_scheduler) 192 | 193 | model.eval() 194 | 195 | if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1: 196 | if config.segmentation_guided: 197 | seg_batch = next(eval_dataloader) 198 | evaluate(config, epoch, pipeline, seg_batch) 199 | else: 200 | evaluate(config, epoch, pipeline) 201 | 202 | if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1: 203 | pipeline.save_pretrained(config.output_dir) 204 | -------------------------------------------------------------------------------- /utils.py:
-------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | def make_grid(images, rows, cols): 4 | w, h = images[0].size 5 | grid = Image.new('RGB', size=(cols*w, rows*h)) 6 | for i, image in enumerate(images): 7 | grid.paste(image, box=(i%cols*w, i//cols*h)) 8 | return grid 9 | --------------------------------------------------------------------------------
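
A minimal usage sketch (not a file from the repository): the segmentation-guided pipelines above are normally constructed inside training.py and eval.py, but the sketch below shows how one might load a trained checkpoint and sample images conditioned on a batch of masks, using only the constructor and `__call__` signatures documented above. The checkpoint directory, dataset name, number of segmentation classes, and the placeholder `seg_batch` are assumptions for illustration; the exact mask encoding expected by `add_segmentations_to_noise` is defined in eval.py and is not shown in this listing.

```py
# Hedged sketch: load an (assumed) trained segmentation-guided DDIM checkpoint saved by
# training.py under "ddim-breast_mri-256-segguided/unet", e.g. from a run like:
#   python main.py --mode train --model_type DDIM --dataset breast_mri \
#       --img_dir data/breast_mri --seg_dir data/breast_mri_segs \
#       --segmentation_guided --segmentation_channel_mode single --num_segmentation_classes 3
import torch
import diffusers

from training import TrainingConfig
from eval import SegGuidedDDIMPipeline

config = TrainingConfig(
    model_type="DDIM",
    image_size=256,
    dataset="breast_mri",                        # assumed dataset name
    segmentation_guided=True,
    segmentation_channel_mode="single",
    num_segmentation_classes=3,                  # assumed; includes background
    output_dir="ddim-breast_mri-256-segguided",  # assumed checkpoint directory
)

unet = diffusers.UNet2DModel.from_pretrained(config.output_dir + "/unet", use_safetensors=True)
scheduler = diffusers.DDIMScheduler(num_train_timesteps=1000)

# Placeholder conditioning batch with the keys the pipeline reads ("images" and "seg_*");
# in the real code this comes from the eval dataloader built in main.py, and the mask
# values are whatever add_segmentations_to_noise expects (blank masks are a stand-in here).
seg_batch = {
    "images": torch.zeros(4, 1, config.image_size, config.image_size),
    "seg_breast": torch.zeros(4, 1, config.image_size, config.image_size),
}

# The pipeline only stores the eval dataloader in this code path, so a placeholder
# dataloader (mirroring the one main.py builds for unconditional sampling) is enough.
eval_dataloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.zeros(1)), batch_size=1
)

pipe = SegGuidedDDIMPipeline(
    unet=unet, scheduler=scheduler, eval_dataloader=eval_dataloader, external_config=config
)
# pipe.to("cuda")  # optionally move to GPU

output = pipe(batch_size=4, num_inference_steps=50, seg_batch=seg_batch)
images = output.images  # list of PIL images
```

A `SegGuidedDDPMPipeline` checkpoint can be sampled the same way, since it takes the same four constructor arguments; its `__call__` defaults to 1000 inference steps instead of 50.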