├── CITATION.md
├── LICENSE
├── README.md
├── eval.py
├── figs
│   ├── teaser.png
│   └── teaser_data.png
├── main.py
├── requirements.txt
├── training.py
└── utils.py
/CITATION.md:
--------------------------------------------------------------------------------
1 | ```bib
2 | @inproceedings{konz2024segguideddiffusion,
3 | title={Anatomically-Controllable Medical Image Generation with Segmentation-Guided Diffusion Models},
4 | author={Nicholas Konz and Yuwen Chen and Haoyu Dong and Maciej A. Mazurowski},
5 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
6 | year={2024}
7 | }
8 | ```
9 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Creative Commons Attribution-NonCommercial 4.0 International
2 |
3 | Creative Commons Corporation ("Creative Commons") is not a law firm and
4 | does not provide legal services or legal advice. Distribution of
5 | Creative Commons public licenses does not create a lawyer-client or
6 | other relationship. Creative Commons makes its licenses and related
7 | information available on an "as-is" basis. Creative Commons gives no
8 | warranties regarding its licenses, any material licensed under their
9 | terms and conditions, or any related information. Creative Commons
10 | disclaims all liability for damages resulting from their use to the
11 | fullest extent possible.
12 |
13 | Using Creative Commons Public Licenses
14 |
15 | Creative Commons public licenses provide a standard set of terms and
16 | conditions that creators and other rights holders may use to share
17 | original works of authorship and other material subject to copyright and
18 | certain other rights specified in the public license below. The
19 | following considerations are for informational purposes only, are not
20 | exhaustive, and do not form part of our licenses.
21 |
22 | - Considerations for licensors: Our public licenses are intended for
23 | use by those authorized to give the public permission to use
24 | material in ways otherwise restricted by copyright and certain other
25 | rights. Our licenses are irrevocable. Licensors should read and
26 | understand the terms and conditions of the license they choose
27 | before applying it. Licensors should also secure all rights
28 | necessary before applying our licenses so that the public can reuse
29 | the material as expected. Licensors should clearly mark any material
30 | not subject to the license. This includes other CC-licensed
31 | material, or material used under an exception or limitation to
32 | copyright. More considerations for licensors :
33 | wiki.creativecommons.org/Considerations_for_licensors
34 |
35 | - Considerations for the public: By using one of our public licenses,
36 | a licensor grants the public permission to use the licensed material
37 | under specified terms and conditions. If the licensor's permission
38 | is not necessary for any reason–for example, because of any
39 | applicable exception or limitation to copyright–then that use is not
40 | regulated by the license. Our licenses grant only permissions under
41 | copyright and certain other rights that a licensor has authority to
42 | grant. Use of the licensed material may still be restricted for
43 | other reasons, including because others have copyright or other
44 | rights in the material. A licensor may make special requests, such
45 | as asking that all changes be marked or described. Although not
46 | required by our licenses, you are encouraged to respect those
47 | requests where reasonable. More considerations for the public :
48 | wiki.creativecommons.org/Considerations_for_licensees
49 |
50 | Creative Commons Attribution-NonCommercial 4.0 International Public
51 | License
52 |
53 | By exercising the Licensed Rights (defined below), You accept and agree
54 | to be bound by the terms and conditions of this Creative Commons
55 | Attribution-NonCommercial 4.0 International Public License ("Public
56 | License"). To the extent this Public License may be interpreted as a
57 | contract, You are granted the Licensed Rights in consideration of Your
58 | acceptance of these terms and conditions, and the Licensor grants You
59 | such rights in consideration of benefits the Licensor receives from
60 | making the Licensed Material available under these terms and conditions.
61 |
62 | - Section 1 – Definitions.
63 |
64 | - a. Adapted Material means material subject to Copyright and
65 | Similar Rights that is derived from or based upon the Licensed
66 | Material and in which the Licensed Material is translated,
67 | altered, arranged, transformed, or otherwise modified in a
68 | manner requiring permission under the Copyright and Similar
69 | Rights held by the Licensor. For purposes of this Public
70 | License, where the Licensed Material is a musical work,
71 | performance, or sound recording, Adapted Material is always
72 | produced where the Licensed Material is synched in timed
73 | relation with a moving image.
74 | - b. Adapter's License means the license You apply to Your
75 | Copyright and Similar Rights in Your contributions to Adapted
76 | Material in accordance with the terms and conditions of this
77 | Public License.
78 | - c. Copyright and Similar Rights means copyright and/or similar
79 | rights closely related to copyright including, without
80 | limitation, performance, broadcast, sound recording, and Sui
81 | Generis Database Rights, without regard to how the rights are
82 | labeled or categorized. For purposes of this Public License, the
83 | rights specified in Section 2(b)(1)-(2) are not Copyright and
84 | Similar Rights.
85 | - d. Effective Technological Measures means those measures that,
86 | in the absence of proper authority, may not be circumvented
87 | under laws fulfilling obligations under Article 11 of the WIPO
88 | Copyright Treaty adopted on December 20, 1996, and/or similar
89 | international agreements.
90 | - e. Exceptions and Limitations means fair use, fair dealing,
91 | and/or any other exception or limitation to Copyright and
92 | Similar Rights that applies to Your use of the Licensed
93 | Material.
94 | - f. Licensed Material means the artistic or literary work,
95 | database, or other material to which the Licensor applied this
96 | Public License.
97 | - g. Licensed Rights means the rights granted to You subject to
98 | the terms and conditions of this Public License, which are
99 | limited to all Copyright and Similar Rights that apply to Your
100 | use of the Licensed Material and that the Licensor has authority
101 | to license.
102 | - h. Licensor means the individual(s) or entity(ies) granting
103 | rights under this Public License.
104 | - i. NonCommercial means not primarily intended for or directed
105 | towards commercial advantage or monetary compensation. For
106 | purposes of this Public License, the exchange of the Licensed
107 | Material for other material subject to Copyright and Similar
108 | Rights by digital file-sharing or similar means is NonCommercial
109 | provided there is no payment of monetary compensation in
110 | connection with the exchange.
111 | - j. Share means to provide material to the public by any means or
112 | process that requires permission under the Licensed Rights, such
113 | as reproduction, public display, public performance,
114 | distribution, dissemination, communication, or importation, and
115 | to make material available to the public including in ways that
116 | members of the public may access the material from a place and
117 | at a time individually chosen by them.
118 | - k. Sui Generis Database Rights means rights other than copyright
119 | resulting from Directive 96/9/EC of the European Parliament and
120 | of the Council of 11 March 1996 on the legal protection of
121 | databases, as amended and/or succeeded, as well as other
122 | essentially equivalent rights anywhere in the world.
123 | - l. You means the individual or entity exercising the Licensed
124 | Rights under this Public License. Your has a corresponding
125 | meaning.
126 |
127 | - Section 2 – Scope.
128 |
129 | - a. License grant.
130 | - 1. Subject to the terms and conditions of this Public
131 | License, the Licensor hereby grants You a worldwide,
132 | royalty-free, non-sublicensable, non-exclusive, irrevocable
133 | license to exercise the Licensed Rights in the Licensed
134 | Material to:
135 | - A. reproduce and Share the Licensed Material, in whole
136 | or in part, for NonCommercial purposes only; and
137 | - B. produce, reproduce, and Share Adapted Material for
138 | NonCommercial purposes only.
139 | - 2. Exceptions and Limitations. For the avoidance of doubt,
140 | where Exceptions and Limitations apply to Your use, this
141 | Public License does not apply, and You do not need to comply
142 | with its terms and conditions.
143 | - 3. Term. The term of this Public License is specified in
144 | Section 6(a).
145 | - 4. Media and formats; technical modifications allowed. The
146 | Licensor authorizes You to exercise the Licensed Rights in
147 | all media and formats whether now known or hereafter
148 | created, and to make technical modifications necessary to do
149 | so. The Licensor waives and/or agrees not to assert any
150 | right or authority to forbid You from making technical
151 | modifications necessary to exercise the Licensed Rights,
152 | including technical modifications necessary to circumvent
153 | Effective Technological Measures. For purposes of this
154 | Public License, simply making modifications authorized by
155 | this Section 2(a)(4) never produces Adapted Material.
156 | - 5. Downstream recipients.
157 | - A. Offer from the Licensor – Licensed Material. Every
158 | recipient of the Licensed Material automatically
159 | receives an offer from the Licensor to exercise the
160 | Licensed Rights under the terms and conditions of this
161 | Public License.
162 | - B. No downstream restrictions. You may not offer or
163 | impose any additional or different terms or conditions
164 | on, or apply any Effective Technological Measures to,
165 | the Licensed Material if doing so restricts exercise of
166 | the Licensed Rights by any recipient of the Licensed
167 | Material.
168 | - 6. No endorsement. Nothing in this Public License
169 | constitutes or may be construed as permission to assert or
170 | imply that You are, or that Your use of the Licensed
171 | Material is, connected with, or sponsored, endorsed, or
172 | granted official status by, the Licensor or others
173 | designated to receive attribution as provided in Section
174 | 3(a)(1)(A)(i).
175 | - b. Other rights.
176 | - 1. Moral rights, such as the right of integrity, are not
177 | licensed under this Public License, nor are publicity,
178 | privacy, and/or other similar personality rights; however,
179 | to the extent possible, the Licensor waives and/or agrees
180 | not to assert any such rights held by the Licensor to the
181 | limited extent necessary to allow You to exercise the
182 | Licensed Rights, but not otherwise.
183 | - 2. Patent and trademark rights are not licensed under this
184 | Public License.
185 | - 3. To the extent possible, the Licensor waives any right to
186 | collect royalties from You for the exercise of the Licensed
187 | Rights, whether directly or through a collecting society
188 | under any voluntary or waivable statutory or compulsory
189 | licensing scheme. In all other cases the Licensor expressly
190 | reserves any right to collect such royalties, including when
191 | the Licensed Material is used other than for NonCommercial
192 | purposes.
193 |
194 | - Section 3 – License Conditions.
195 |
196 | Your exercise of the Licensed Rights is expressly made subject to
197 | the following conditions.
198 |
199 | - a. Attribution.
200 | - 1. If You Share the Licensed Material (including in modified
201 | form), You must:
202 | - A. retain the following if it is supplied by the
203 | Licensor with the Licensed Material:
204 | - i. identification of the creator(s) of the Licensed
205 | Material and any others designated to receive
206 | attribution, in any reasonable manner requested by
207 | the Licensor (including by pseudonym if designated);
208 | - ii. a copyright notice;
209 | - iii. a notice that refers to this Public License;
210 | - iv. a notice that refers to the disclaimer of
211 | warranties;
212 | - v. a URI or hyperlink to the Licensed Material to
213 | the extent reasonably practicable;
214 | - B. indicate if You modified the Licensed Material and
215 | retain an indication of any previous modifications; and
216 | - C. indicate the Licensed Material is licensed under this
217 | Public License, and include the text of, or the URI or
218 | hyperlink to, this Public License.
219 | - 2. You may satisfy the conditions in Section 3(a)(1) in any
220 | reasonable manner based on the medium, means, and context in
221 | which You Share the Licensed Material. For example, it may
222 | be reasonable to satisfy the conditions by providing a URI
223 | or hyperlink to a resource that includes the required
224 | information.
225 | - 3. If requested by the Licensor, You must remove any of the
226 | information required by Section 3(a)(1)(A) to the extent
227 | reasonably practicable.
228 | - 4. If You Share Adapted Material You produce, the Adapter's
229 | License You apply must not prevent recipients of the Adapted
230 | Material from complying with this Public License.
231 |
232 | - Section 4 – Sui Generis Database Rights.
233 |
234 | Where the Licensed Rights include Sui Generis Database Rights that
235 | apply to Your use of the Licensed Material:
236 |
237 | - a. for the avoidance of doubt, Section 2(a)(1) grants You the
238 | right to extract, reuse, reproduce, and Share all or a
239 | substantial portion of the contents of the database for
240 | NonCommercial purposes only;
241 | - b. if You include all or a substantial portion of the database
242 | contents in a database in which You have Sui Generis Database
243 | Rights, then the database in which You have Sui Generis Database
244 | Rights (but not its individual contents) is Adapted Material;
245 | and
246 | - c. You must comply with the conditions in Section 3(a) if You
247 | Share all or a substantial portion of the contents of the
248 | database.
249 |
250 | For the avoidance of doubt, this Section 4 supplements and does not
251 | replace Your obligations under this Public License where the
252 | Licensed Rights include other Copyright and Similar Rights.
253 |
254 | - Section 5 – Disclaimer of Warranties and Limitation of Liability.
255 |
256 | - a. Unless otherwise separately undertaken by the Licensor, to
257 | the extent possible, the Licensor offers the Licensed Material
258 | as-is and as-available, and makes no representations or
259 | warranties of any kind concerning the Licensed Material, whether
260 | express, implied, statutory, or other. This includes, without
261 | limitation, warranties of title, merchantability, fitness for a
262 | particular purpose, non-infringement, absence of latent or other
263 | defects, accuracy, or the presence or absence of errors, whether
264 | or not known or discoverable. Where disclaimers of warranties
265 | are not allowed in full or in part, this disclaimer may not
266 | apply to You.
267 | - b. To the extent possible, in no event will the Licensor be
268 | liable to You on any legal theory (including, without
269 | limitation, negligence) or otherwise for any direct, special,
270 | indirect, incidental, consequential, punitive, exemplary, or
271 | other losses, costs, expenses, or damages arising out of this
272 | Public License or use of the Licensed Material, even if the
273 | Licensor has been advised of the possibility of such losses,
274 | costs, expenses, or damages. Where a limitation of liability is
275 | not allowed in full or in part, this limitation may not apply to
276 | You.
277 | - c. The disclaimer of warranties and limitation of liability
278 | provided above shall be interpreted in a manner that, to the
279 | extent possible, most closely approximates an absolute
280 | disclaimer and waiver of all liability.
281 |
282 | - Section 6 – Term and Termination.
283 |
284 | - a. This Public License applies for the term of the Copyright and
285 | Similar Rights licensed here. However, if You fail to comply
286 | with this Public License, then Your rights under this Public
287 | License terminate automatically.
288 | - b. Where Your right to use the Licensed Material has terminated
289 | under Section 6(a), it reinstates:
290 |
291 | - 1. automatically as of the date the violation is cured,
292 | provided it is cured within 30 days of Your discovery of the
293 | violation; or
294 | - 2. upon express reinstatement by the Licensor.
295 |
296 | For the avoidance of doubt, this Section 6(b) does not affect
297 | any right the Licensor may have to seek remedies for Your
298 | violations of this Public License.
299 |
300 | - c. For the avoidance of doubt, the Licensor may also offer the
301 | Licensed Material under separate terms or conditions or stop
302 | distributing the Licensed Material at any time; however, doing
303 | so will not terminate this Public License.
304 | - d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
305 | License.
306 |
307 | - Section 7 – Other Terms and Conditions.
308 |
309 | - a. The Licensor shall not be bound by any additional or
310 | different terms or conditions communicated by You unless
311 | expressly agreed.
312 | - b. Any arrangements, understandings, or agreements regarding the
313 | Licensed Material not stated herein are separate from and
314 | independent of the terms and conditions of this Public License.
315 |
316 | - Section 8 – Interpretation.
317 |
318 | - a. For the avoidance of doubt, this Public License does not, and
319 | shall not be interpreted to, reduce, limit, restrict, or impose
320 | conditions on any use of the Licensed Material that could
321 | lawfully be made without permission under this Public License.
322 | - b. To the extent possible, if any provision of this Public
323 | License is deemed unenforceable, it shall be automatically
324 | reformed to the minimum extent necessary to make it enforceable.
325 | If the provision cannot be reformed, it shall be severed from
326 | this Public License without affecting the enforceability of the
327 | remaining terms and conditions.
328 | - c. No term or condition of this Public License will be waived
329 | and no failure to comply consented to unless expressly agreed to
330 | by the Licensor.
331 | - d. Nothing in this Public License constitutes or may be
332 | interpreted as a limitation upon, or waiver of, any privileges
333 | and immunities that apply to the Licensor or You, including from
334 | the legal processes of any jurisdiction or authority.
335 |
336 | Creative Commons is not a party to its public licenses. Notwithstanding,
337 | Creative Commons may elect to apply one of its public licenses to
338 | material it publishes and in those instances will be considered the
339 | "Licensor." The text of the Creative Commons public licenses is
340 | dedicated to the public domain under the CC0 Public Domain Dedication.
341 | Except for the limited purpose of indicating that material is shared
342 | under a Creative Commons public license or as otherwise permitted by the
343 | Creative Commons policies published at creativecommons.org/policies,
344 | Creative Commons does not authorize the use of the trademark "Creative
345 | Commons" or any other trademark or logo of Creative Commons without its
346 | prior written consent including, without limitation, in connection with
347 | any unauthorized modifications to any of its public licenses or any
348 | other arrangements, understandings, or agreements concerning use of
349 | licensed material. For the avoidance of doubt, this paragraph does not
350 | form part of the public licenses.
351 |
352 | Creative Commons may be contacted at creativecommons.org.
353 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Easy and Precise Segmentation-Guided Diffusion Models
2 |
3 | #### By [Nicholas Konz](https://nickk124.github.io/), [Yuwen Chen](https://scholar.google.com/citations?user=61s49p0AAAAJ&hl=en), [Haoyu Dong](https://scholar.google.com/citations?user=eZVEUCIAAAAJ&hl=en) and [Maciej Mazurowski](https://sites.duke.edu/mazurowski/).
4 |
5 | arXiv paper link: [arXiv:2402.05210](https://arxiv.org/abs/2402.05210)
6 |
7 | ## NEWS: our paper was accepted to MICCAI 2024!
8 |
9 |
10 |
11 |
12 |
13 | This is the code for our paper [**Anatomically-Controllable Medical Image Generation with Segmentation-Guided Diffusion Models**](https://arxiv.org/abs/2402.05210), where we introduce a simple yet powerful training procedure for conditioning image-generating diffusion models on (possibly incomplete) multiclass segmentation masks.
14 |
15 | ### Why use our model?
16 |
17 | Our method outperforms existing segmentation-guided image generative models (like [SPADE](https://github.com/NVlabs/SPADE) and [ControlNet](https://github.com/lllyasviel/ControlNet)) in the faithfulness of generated images to input masks, on multiple medical image datasets spanning different modalities and a broad range of objects of interest, and is on par with them in anatomical realism. Our method is also simple to use and train. Its precise pixel-wise adherence to input segmentation masks comes from always operating in the native image space (it is not a latent diffusion model), which is especially helpful when conditioning on complex and detailed anatomical structures.
18 |
19 | Additionally, our optional *mask-ablated training* algorithm allows our model to be conditioned on segmentation masks with missing classes, which is useful for medical images where segmentation masks may be incomplete or noisy. This enables not only more flexible image generation but also, as we show in our paper, adjustable anatomical similarity of generated images to a given real image, by taking advantage of the latent space structure of diffusion models. We also used this feature to generate a synthetic paired breast MRI dataset, [shown below](https://github.com/mazurowski-lab/segmentation-guided-diffusion?tab=readme-ov-file#synthetic-paired-breast-mri-dataset-release).
20 |
21 | **Using this code, you can:**
22 | 1. Train a segmentation-guided (or standard unconditional) diffusion model on your own dataset, with a wide range of options, including mask-ablated training.
23 | 2. Generate images from these models (or using our provided pre-trained models).
24 |
25 | Please follow the steps outlined below to do so.
26 |
27 | Also, check out our accompanying [**Synthetic Paired Breast MRI Dataset Release**](https://github.com/mazurowski-lab/segmentation-guided-diffusion?tab=readme-ov-file#synthetic-paired-breast-mri-dataset-release) below!
28 |
29 | Thank you to Hugging Face's awesome [Diffusers](https://github.com/huggingface/diffusers) library for providing a helpful backbone for our code!
30 |
31 | ## Citation
32 |
33 | Please cite our paper if you use our code or reference our work:
34 | ```bib
35 | @inproceedings{konz2024segguideddiffusion,
36 | title={Anatomically-Controllable Medical Image Generation with Segmentation-Guided Diffusion Models},
37 | author={Nicholas Konz and Yuwen Chen and Haoyu Dong and Maciej A. Mazurowski},
38 | booktitle={International Conference on Medical Image Computing and Computer-Assisted Intervention},
39 | year={2024}
40 | }
41 | ```
42 |
43 | ## 1) Package Installation
44 | This codebase was created with Python 3.11. First, install PyTorch for your computer's CUDA version (check it by running `nvidia-smi` if you're not sure) using the command provided at https://pytorch.org/get-started/locally/; this codebase was made with `torch==2.1.2` and `torchvision==0.16.2` on CUDA 12.2. Next, run `pip3 install -r requirements.txt` to install the required packages.
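You may want to confirm which package versions ended up in your environment. The following stdlib-only sketch reports any mismatch with the versions listed above. It is our illustration and not part of the repository. Newer versions may well work, so treat a mismatch as a hint rather than an error. After installation, `python3 -c "import torch; print(torch.cuda.is_available())"` should print `True` if PyTorch can see your GPU.

```python
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Versions this README says the codebase was developed with.
EXPECTED = {"torch": "2.1.2", "torchvision": "0.16.2"}


def check_versions(expected: dict[str, str]) -> dict[str, tuple[str | None, str]]:
    """Return {package: (installed, expected)} for each package whose installed
    version differs from the expected one; installed is None if it is missing."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have = version(pkg)
        except PackageNotFoundError:
            have = None
        # Ignore local build tags, e.g. "2.1.2+cu121" counts as "2.1.2".
        if have is None or have.split("+")[0] != want:
            mismatches[pkg] = (have, want)
    return mismatches


if __name__ == "__main__":
    for pkg, (have, want) in check_versions(EXPECTED).items():
        print(f"{pkg}: installed {have}, README lists {want}")
```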
45 |
46 | ## 2a) (optional) Use Pre-Trained Models
47 |
48 | We provide pre-trained model checkpoints (`.safetensors` files) and config (`.json`) files from our paper for the [Duke Breast MRI](https://www.cancerimagingarchive.net/collection/duke-breast-cancer-mri/) and [CT Organ](https://wiki.cancerimagingarchive.net/pages/viewpage.action?pageId=61080890) datasets, [here](https://drive.google.com/drive/folders/1OaOGBLfpUFe_tmpvZGEe2Mv2gow32Y8u). These include:
49 |
50 | 1. Segmentation-Conditional Models, trained without mask ablation.
51 | 2. Segmentation-Conditional Models, trained with mask ablation.
52 | 3. Unconditional (standard) Models.
53 |
54 | Once you've downloaded the checkpoint and config file for your model of choice, please:
55 | 1. Put both files in a directory called `{NAME}/unet`, where `NAME` is the model checkpoint's filename without the `.safetensors` ending, to use it with our evaluation code.
56 | 2. Rename the checkpoint file to `diffusion_pytorch_model.safetensors` and the config file to `config.json`.
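The two steps above can be scripted. The sketch below uses a hypothetical helper and hypothetical download paths. It copies a checkpoint/config pair into `{NAME}/unet/` under the filenames listed in step 2. It copies rather than moves, so the original downloads stay untouched.

```python
from __future__ import annotations

import shutil
from pathlib import Path


def install_pretrained(checkpoint: str | Path, config: str | Path,
                       dest_root: str | Path = ".") -> Path:
    """Copy a downloaded checkpoint/config pair into {dest_root}/{NAME}/unet/
    under the filenames expected by the evaluation code; return {NAME}'s path."""
    checkpoint, config = Path(checkpoint), Path(config)
    # NAME is the checkpoint filename without its ".safetensors" ending.
    name = checkpoint.name.removesuffix(".safetensors")
    unet_dir = Path(dest_root) / name / "unet"
    unet_dir.mkdir(parents=True, exist_ok=True)
    # Copy rather than move, so the original downloads stay untouched.
    shutil.copy2(checkpoint, unet_dir / "diffusion_pytorch_model.safetensors")
    shutil.copy2(config, unet_dir / "config.json")
    return unet_dir.parent
```

For example, `install_pretrained("downloads/MODEL.safetensors", "downloads/MODEL.json")` would create `./MODEL/unet/` (both paths are placeholders).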
57 |
58 | Next, you can proceed to the [**Evaluation/Sampling**](https://github.com/mazurowski-lab/segmentation-guided-diffusion#3-evaluationsampling) section below to generate images from these models.
59 |
60 | ## 2b) Train Your Own Models
61 |
62 | ### Data Preparation
63 |
64 | Please put your training images in a dataset directory `DATA_FOLDER` (with any filenames), organized into `train`, `val` and `test` split subdirectories. Standard 1- or 3-channel images should be in a format that PIL can read (e.g. `.png`, `.jpg`, etc.); for images with other channel counts, use `.npy` NumPy array files. For example:
65 |
66 | ```
67 | DATA_FOLDER
68 | ├── train
69 | │ ├── tr_1.png
70 | │ ├── tr_2.png
71 | │ └── ...
72 | ├── val
73 | │ ├── val_1.png
74 | │ ├── val_2.png
75 | │ └── ...
76 | └── test
77 | ├── ts_1.png
78 | ├── ts_2.png
79 | └── ...
80 | ```
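Your images may start out in a single flat folder. In that case, a sketch like the one below can create the `train`/`val`/`test` layout. `split_into_folders` is a hypothetical helper, and the split fractions are assumptions. It splits at the file level. For 2D slices taken from 3D scans, you will probably want to split by patient instead, so that slices from the same patient don't end up in different splits. The returned filename lists can be reused to place each mask in the matching split of `MASK_FOLDER/all`.

```python
from __future__ import annotations

import random
import shutil
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".npy"}


def split_into_folders(src: str | Path, dst: str | Path, val_frac: float = 0.1,
                       test_frac: float = 0.1, seed: int = 0) -> dict[str, list[str]]:
    """Randomly copy the images in a flat folder `src` into dst/train, dst/val and
    dst/test; return the filenames assigned to each split."""
    files = sorted(p for p in Path(src).iterdir() if p.suffix.lower() in IMAGE_SUFFIXES)
    random.Random(seed).shuffle(files)  # seeded, so the split is reproducible
    n_val, n_test = int(len(files) * val_frac), int(len(files) * test_frac)
    splits = {
        "val": files[:n_val],
        "test": files[n_val:n_val + n_test],
        "train": files[n_val + n_test:],
    }
    for split, split_files in splits.items():
        out_dir = Path(dst) / split
        out_dir.mkdir(parents=True, exist_ok=True)
        for f in split_files:
            shutil.copy2(f, out_dir / f.name)
    return {split: [f.name for f in split_files] for split, split_files in splits.items()}
```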
81 |
82 | If you are using a segmentation-guided model, please put your segmentation masks within a similar directory structure in a separate folder `MASK_FOLDER`, with a subdirectory `all` that contains the split subfolders, as shown below. **Each segmentation mask should have the same filename as its corresponding image in `DATA_FOLDER`, and should be saved with integer values starting at zero for each object class, i.e., 0, 1, 2,...**. If you don't want to train a segmentation-guided model, you can skip this step.
83 |
84 | ```
85 | MASK_FOLDER
86 | ├── all
87 | │ ├── train
88 | │ │ ├── tr_1.png
89 | │ │ ├── tr_2.png
90 | │ │ └── ...
91 | │ ├── val
92 | │ │ ├── val_1.png
93 | │ │ ├── val_2.png
94 | │ │ └── ...
95 | │ └── test
96 | │ ├── ts_1.png
97 | │ ├── ts_2.png
98 | │ └── ...
99 | ```
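Since each mask must share its image's filename, it is worth checking the pairing before training. The sketch below is a hypothetical helper that compares filenames only. It does not inspect pixel values. It lists, per split, the images without a mask and the masks without an image.

```python
from __future__ import annotations

from pathlib import Path

SPLITS = ("train", "val", "test")


def _names(folder: Path) -> set[str]:
    """Filenames of the regular files in `folder` (empty if it does not exist)."""
    return {p.name for p in folder.iterdir() if p.is_file()} if folder.is_dir() else set()


def find_unpaired(data_folder: str | Path, mask_folder: str | Path) -> dict[str, dict[str, list[str]]]:
    """For each split, list images lacking a same-named mask and vice versa."""
    report = {}
    for split in SPLITS:
        images = _names(Path(data_folder) / split)
        masks = _names(Path(mask_folder) / "all" / split)
        report[split] = {
            "missing_mask": sorted(images - masks),
            "missing_image": sorted(masks - images),
        }
    return report
```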
100 |
101 | ### Training
102 |
103 | The basic command for training a standard unconditional diffusion model is
104 | ```bash
105 | CUDA_VISIBLE_DEVICES={DEVICES} python3 main.py \
106 | --mode train \
107 | --model_type DDIM \
108 | --img_size {IMAGE_SIZE} \
109 | --num_img_channels {NUM_IMAGE_CHANNELS} \
110 | --dataset {DATASET_NAME} \
111 | --img_dir {DATA_FOLDER} \
112 | --train_batch_size 16 \
113 | --eval_batch_size 8 \
114 | --num_epochs 400
115 | ```
116 |
117 | where:
118 | - `DEVICES` is a comma-separated list of GPU device indices to use (e.g. `0,1,2,3`).
119 | - `IMAGE_SIZE` and `NUM_IMAGE_CHANNELS` respectively specify the size of the images to train on (e.g. `256`) and the number of channels (`1` for greyscale, `3` for RGB).
120 | - `--model_type` specifies the diffusion sampling algorithm used to evaluate the model, and can be `DDIM` or `DDPM`.
121 | - `DATASET_NAME` is some name for your dataset (e.g. `breast_mri`).
122 | - `DATA_FOLDER` is the path to your dataset directory, as outlined in the previous section.
123 | - `--train_batch_size` and `--eval_batch_size` specify the batch sizes for training and evaluation, respectively. We use a training batch size of 16 on a single 48 GB A6000 GPU for an image size of 256.
124 | - `--num_epochs` specifies the number of epochs to train for (our default is 400).
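You may prefer to launch runs from a Python script, for example for a small hyperparameter sweep. In that case, the command above can be assembled as an argument list. This sketch uses only the flags documented here. `build_train_command` is a hypothetical helper, not part of the repository.

```python
from __future__ import annotations

import shlex


def build_train_command(img_size: int, num_img_channels: int, dataset: str, img_dir: str,
                        model_type: str = "DDIM", train_batch_size: int = 16,
                        eval_batch_size: int = 8, num_epochs: int = 400) -> list[str]:
    """Assemble the unconditional training command above as an argv list."""
    if model_type not in ("DDIM", "DDPM"):
        raise ValueError("model_type must be 'DDIM' or 'DDPM'")
    return [
        "python3", "main.py",
        "--mode", "train",
        "--model_type", model_type,
        "--img_size", str(img_size),
        "--num_img_channels", str(num_img_channels),
        "--dataset", dataset,
        "--img_dir", img_dir,
        "--train_batch_size", str(train_batch_size),
        "--eval_batch_size", str(eval_batch_size),
        "--num_epochs", str(num_epochs),
    ]


if __name__ == "__main__":
    # Print the shell-quoted command for a 256x256 greyscale dataset.
    print(shlex.join(build_train_command(256, 1, "breast_mri", "/data/breast_mri")))
```

To run it with a GPU selection, pass the list to `subprocess.run(cmd, env={**os.environ, "CUDA_VISIBLE_DEVICES": "0"}, check=True)`. An argument list avoids shell-quoting problems with paths that contain spaces.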
125 |
126 | #### Adding segmentation guidance, mask-ablated training, and other options
127 |
128 | To train your model with mask guidance, simply add the options:
129 | ```bash
130 | --seg_dir {MASK_FOLDER} \
131 | --segmentation_guided \
132 | --num_segmentation_classes {N_SEGMENTATION_CLASSES} \
133 | ```
134 |
135 | where:
136 | - `MASK_FOLDER` is the path to your segmentation mask directory, as outlined in the previous section.
137 | - `N_SEGMENTATION_CLASSES` is the number of classes in your segmentation masks, **including the background (0) class**.
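One way to choose `N_SEGMENTATION_CLASSES` is to scan your masks for the largest label and add one, which counts the background class 0. The sketch below uses a hypothetical helper. It assumes the masks are already loaded as integer NumPy arrays, e.g. with `np.array(Image.open(path))` using PIL's `Image`.

```python
from __future__ import annotations

from collections.abc import Iterable

import numpy as np


def count_segmentation_classes(masks: Iterable[np.ndarray]) -> int:
    """Value for --num_segmentation_classes: the largest label found plus one,
    so that the background class 0 is counted."""
    labels: set[int] = set()
    for mask in masks:
        labels.update(np.unique(np.asarray(mask)).tolist())
    if not labels or min(labels) < 0:
        raise ValueError("expected non-negative integer labels 0, 1, 2, ...")
    n_classes = int(max(labels)) + 1
    if len(labels) < n_classes:
        # Not fatal: some classes may simply be absent from the masks scanned.
        print(f"warning: labels {sorted(set(range(n_classes)) - labels)} never occur")
    return n_classes
```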
138 |
139 | To also train your model with mask ablation (randomly removing classes from the masks to teach the model to condition on masks with missing classes; see our paper for details), simply also add the option `--use_ablated_segmentations`.
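The idea of mask ablation can be illustrated with a few lines of NumPy. This is a simplified sketch, not the repository's implementation. The per-class removal probability `p`, and encoding a removed class as background 0, are assumptions made here; see the paper and `training.py` for the actual scheme.

```python
import numpy as np


def ablate_mask(mask: np.ndarray, p: float, rng: np.random.Generator) -> np.ndarray:
    """Illustrative mask ablation: each foreground class present in `mask` is
    independently relabelled as background (0) with probability p."""
    ablated = mask.copy()
    for c in np.unique(mask):
        if c != 0 and rng.random() < p:
            ablated[mask == c] = 0
    return ablated


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    mask = np.array([[0, 1, 1], [2, 2, 0], [0, 3, 3]])
    print(ablate_mask(mask, p=0.5, rng=rng))
```

Drawing a fresh ablation each time a mask is loaded exposes the model to many different subsets of classes during training.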
140 |
141 | ## 3) Evaluation/Sampling
142 |
143 | Sampling images with a trained model is run similarly to training. For example, 100 samples from an unconditional model can be generated with the command:
144 | ```bash
145 | CUDA_VISIBLE_DEVICES={DEVICES} python3 main.py \
146 | --mode eval_many \
147 | --model_type DDIM \
148 | --img_size 256 \
149 | --num_img_channels {NUM_IMAGE_CHANNELS} \
150 | --dataset {DATASET_NAME} \
151 | --eval_batch_size 8 \
152 | --eval_sample_size 100
153 | ```
154 |
155 | Note that the code will automatically use the checkpoint from the training run, and will save the generated images to a directory called `samples` in the model's output directory. To sample from a model with segmentation guidance, simply add the options:
156 | ```bash
157 | --seg_dir {MASK_FOLDER} \
158 | --segmentation_guided \
159 | --num_segmentation_classes {N_SEGMENTATION_CLASSES} \
160 | ```
161 | This will generate images conditioned on the segmentation masks in `MASK_FOLDER/all/test`. Segmentation masks should be saved as image files (e.g., `.png`) with integer values starting at zero for each object class, i.e., 0, 1, 2, ...
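A common pitfall is masks exported for viewing, where classes are stored as values like 0 and 255 rather than 0 and 1. One fix is to remap the values with a single fixed mapping for the whole dataset, as sketched below. The helper is hypothetical, and the value-to-class mapping is whatever your annotation tool used. Remapping each file separately (e.g. by ranking its own unique values) could give the same structure different indices in different masks. The result can then be written losslessly with PIL, e.g. `Image.fromarray(out).save(path)` to a `.png`. Avoid JPEG, whose lossy compression alters label values.

```python
from __future__ import annotations

import numpy as np


def remap_mask(mask: np.ndarray, value_to_class: dict[int, int]) -> np.ndarray:
    """Convert stored pixel values (e.g. 0/128/255) into class indices 0, 1, 2, ...
    using one fixed mapping, so every mask in the dataset is labelled consistently."""
    out = np.zeros(mask.shape, dtype=np.uint8)
    mapped = np.zeros(mask.shape, dtype=bool)
    for value, cls in value_to_class.items():
        hit = mask == value
        out[hit] = cls
        mapped |= hit
    if not mapped.all():
        raise ValueError(f"unmapped pixel values: {np.unique(mask[~mapped]).tolist()}")
    return out
```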
162 |
163 | ## Additional Options/Config (learning rate, etc.)
164 | Our code has further options for training and evaluation; run `python3 main.py --help` for more information. Further settings can be changed under `class TrainingConfig:` in `training.py` (some of these are exposed as command-line options for `main.py`, and some are not).
165 |
166 | ## Troubleshooting/Bugfixing
167 | - **Noisy generated images**: Sometimes your model may generate images that contain some noise; see https://github.com/mazurowski-lab/segmentation-guided-diffusion/issues/12 for an example. Our suggested fix is to either reduce the learning rate (e.g., to `2e-5`), which is set at https://github.com/mazurowski-lab/segmentation-guided-diffusion/blob/b1ef8b137eaaefab0210e52b4c49f34ff6067fa6/training.py#L29, or simply train for more epochs.
168 | - Some users have reported a [bug](https://github.com/mazurowski-lab/segmentation-guided-diffusion/issues/11) when the model attempts to save during training and they receive an error of `module 'safetensors' has no attribute 'torch'`. This appears to be an issue with the `diffusers` library itself in some environments, and may be remedied by [this proposed solution](https://github.com/mazurowski-lab/segmentation-guided-diffusion/issues/11#issuecomment-2251890600).
169 |
170 | ## Synthetic Paired Breast MRI Dataset Release
171 |
172 |
173 |
174 | We also release synthetic 2D breast MRI slice images that are paired/"pre-registered" in terms of blood vessels and fibroglandular tissue to real image counterparts, generated by our segmentation-guided model. These were created by applying our mask-ablated-trained segmentation-guided model to the existing segmentation masks of the held-out training set and the test set of our paper (30 patients total, or about 5000 2D slice images; see the paper for more information), but with the breast mask removed. Because of this, each synthetic image has blood vessels and fibroglandular tissue that are registered to/spatially match those of a real image counterpart, but different surrounding breast tissue and shape. This paired data enables potential applications such as training a breast MRI registration model, self-supervised learning, etc.
175 |
176 | The data can be downloaded [here](https://drive.google.com/file/d/1yaLLdzMhAjWUEzdkTa5FjbX3d7fKvQxM/view), and is organized as follows.
177 |
178 | ### Filename convention/dataset organization
179 | The generated images are stored in `synthetic_data` in the downloadable `.zip` file above, with the filename convention `condon_Breast_MRI_{PATIENTNUMBER}_slice_{SLICENUMBER}.png`, where `PATIENTNUMBER` is the original dataset's patient number that the image comes from, and `SLICENUMBER` is the $z$-axis index of the original 3D MRI image that the 2D slice image was taken from. The corresponding real slice image, found in `real_data`, is named `Breast_MRI_{PATIENTNUMBER}_slice_{SLICENUMBER}.png`. Finally, the corresponding segmentation masks, whose blood vessel and fibroglandular tissue labels were used to generate the images, are in `segmentations` (the breast masks are also present in these files, but were not used to generate the images). For example:
180 | ```
181 | synthetic_data
182 | │ ├── condon_Breast_MRI_002_slice_0.png
183 | │ ├── condon_Breast_MRI_002_slice_1.png
184 | │ └── ...
185 | real_data
186 | │ ├── Breast_MRI_002_slice_0.png
187 | │ ├── Breast_MRI_002_slice_1.png
188 | │ └── ...
189 | segmentations
190 | │ ├── Breast_MRI_002_slice_0.png
191 | │ ├── Breast_MRI_002_slice_1.png
192 | │ └── ...
193 | ```
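Because every synthetic image shares its filename (minus the `condon_` prefix) with its real counterpart and its segmentation mask, the three folders can be joined by name. The following is a minimal Python sketch under that assumption; `pair_files` and `parse_ids` are hypothetical helpers, not part of this repository:

```python
import re
from pathlib import Path

# Synthetic filenames are assumed to be "condon_" + the real filename, i.e.
# condon_Breast_MRI_{PATIENTNUMBER}_slice_{SLICENUMBER}.png
SYNTH_PATTERN = re.compile(r"^condon_(Breast_MRI_\d+_slice_\d+\.png)$")
REAL_PATTERN = re.compile(r"^Breast_MRI_(\d+)_slice_(\d+)\.png$")


def pair_files(root):
    """Return (synthetic, real, segmentation) path triples under `root`.

    Assumes `root` is the unzipped download containing the `synthetic_data`,
    `real_data` and `segmentations` folders.
    """
    root = Path(root)
    triples = []
    for synth in sorted((root / "synthetic_data").glob("*.png")):
        match = SYNTH_PATTERN.match(synth.name)
        if match is None:
            continue  # skip files that do not follow the naming convention
        real_name = match.group(1)  # same name without the "condon_" prefix
        real = root / "real_data" / real_name
        seg = root / "segmentations" / real_name
        if real.exists() and seg.exists():
            triples.append((synth, real, seg))
    return triples


def parse_ids(filename):
    """Return (PATIENTNUMBER as written, SLICENUMBER as int) for a real/mask filename."""
    match = REAL_PATTERN.match(filename)
    if match is None:
        raise ValueError(f"unexpected filename: {filename}")
    return match.group(1), int(match.group(2))
```

Keeping `PATIENTNUMBER` as a string preserves its zero padding (e.g. `002`), so it can be reused directly when rebuilding filenames.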
194 |
195 | ### Dataset Citation and License
196 | If you use this data, please cite both our paper (see **Citation** above) and the original breast MRI dataset (below), and follow the [CC BY-NC 4.0](https://creativecommons.org/licenses/by-nc/4.0/) license.
197 | ```bib
198 | @misc{sahadata,
199 | author={Saha, Ashirbani and Harowicz, Michael R and Grimm, Lars J and Kim, Connie E and Ghate, Sujata V and Walsh, Ruth and Mazurowski, Maciej A},
200 | title = {Dynamic contrast-enhanced magnetic resonance images of breast cancer patients with tumor locations [Data set]},
201 | year = {2021},
202 | publisher = {The Cancer Imaging Archive},
203 | howpublished = {\url{https://doi.org/10.7937/TCIA.e3sv-re93}},
204 | }
205 | ```
206 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | """
2 | model evaluation/sampling
3 | """
4 | import math
5 | import os
6 | import torch
7 | from typing import List, Optional, Tuple, Union
8 | from tqdm import tqdm
9 | from copy import deepcopy
10 | import numpy as np
11 |
12 | import diffusers
13 | from diffusers import DiffusionPipeline, ImagePipelineOutput, DDIMScheduler
14 | from diffusers.utils.torch_utils import randn_tensor
15 |
16 | from utils import make_grid
17 | from torchvision.utils import save_image
18 | import matplotlib.pyplot as plt
19 | from matplotlib.colors import ListedColormap
20 |
21 | ####################
22 | # segmentation-guided DDPM
23 | ####################
24 |
25 | def evaluate_sample_many(
26 | sample_size,
27 | config,
28 | model,
29 | noise_scheduler,
30 | eval_dataloader,
31 | device='cuda'
32 | ):
33 |
34 | # for loading segs to condition on:
35 | # setup for sampling
36 | if config.model_type == "DDPM":
37 | if config.segmentation_guided:
38 | pipeline = SegGuidedDDPMPipeline(
39 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config
40 | )
41 | else:
42 | pipeline = diffusers.DDPMPipeline(unet=model.module, scheduler=noise_scheduler)
43 | elif config.model_type == "DDIM":
44 | if config.segmentation_guided:
45 | pipeline = SegGuidedDDIMPipeline(
46 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config
47 | )
48 | else:
49 | pipeline = diffusers.DDIMPipeline(unet=model.module, scheduler=noise_scheduler)
50 |
51 |
52 | sample_dir = test_dir = os.path.join(config.output_dir, "samples_many_{}".format(sample_size))
53 | if not os.path.exists(sample_dir):
54 | os.makedirs(sample_dir)
55 |
56 | num_sampled = 0
57 | # keep sampling images until we have enough
58 | for bidx, seg_batch in tqdm(enumerate(eval_dataloader), total=len(eval_dataloader)):
59 | if num_sampled < sample_size:
60 | if config.segmentation_guided:
61 | current_batch_size = [v for k, v in seg_batch.items() if k.startswith("seg_")][0].shape[0]
62 | else:
63 | current_batch_size = config.eval_batch_size
64 |
65 | if config.segmentation_guided:
66 | images = pipeline(
67 | batch_size = current_batch_size,
68 | seg_batch=seg_batch,
69 | ).images
70 | else:
71 | images = pipeline(
72 | batch_size = current_batch_size,
73 | ).images
74 |
75 | # save each image in the list separately
76 | for i, img in enumerate(images):
77 | if config.segmentation_guided:
78 | # name base on input mask fname
79 | img_fname = "{}/condon_{}".format(sample_dir, seg_batch["image_filenames"][i])
80 | else:
81 | img_fname = f"{sample_dir}/{num_sampled + i:04d}.png"
82 | img.save(img_fname)
83 |
84 | num_sampled += len(images)
85 | print("sampled {}/{}.".format(num_sampled, sample_size))
86 |
87 |
88 |
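The loop in `evaluate_sample_many` samples whole batches, so the number of saved images can exceed `sample_size` by less than one batch, and it is also capped by the total size of the evaluation dataloader. A minimal pure-Python sketch of that stopping rule (`simulate_sample_many` is a hypothetical helper, not part of the repository):

```python
def simulate_sample_many(sample_size, batch_sizes):
    """Mimic the stopping rule of `evaluate_sample_many` (hypothetical helper).

    `batch_sizes` stands in for the size of each batch yielded by the eval
    dataloader. A batch is sampled in full whenever fewer than `sample_size`
    images exist so far; once the target is reached, later batches are skipped.
    """
    num_sampled = 0
    for current_batch_size in batch_sizes:
        if num_sampled < sample_size:
            num_sampled += current_batch_size  # whole batch, never truncated
    return num_sampled
```

For example, asking for 10 images with batches of 4 yields 12 images, while asking for 100 from a dataloader with only 11 images in total yields 11.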
89 | def evaluate_generation(
90 | config,
91 | model,
92 | noise_scheduler,
93 | eval_dataloader,
94 | class_label_cfg=None,
95 | translate=False,
96 | eval_mask_removal=False,
97 | eval_blank_mask=False,
98 | device='cuda'
99 | ):
100 | """
101 | general function to evaluate (possibly mask-guided) trained image generation model in useful ways.
102 | also has option to use CFG for class-conditioned sampling (otherwise, class-conditional models will be evaluated using naive class conditioning and sampling from both classes).
103 |
104 | can also evaluate for image translation.
105 | """
106 |
107 | # for loading segs to condition on:
108 | eval_dataloader = iter(eval_dataloader)
109 |
110 | if config.segmentation_guided:
111 | seg_batch = next(eval_dataloader)
112 | if eval_blank_mask:
113 | # use blank masks
114 | for k, v in seg_batch.items():
115 | if k.startswith("seg_"):
116 | seg_batch[k] = torch.zeros_like(v)
117 |
118 | # setup for sampling
119 | # After each epoch you optionally sample some demo images with evaluate() and save the model
120 | if config.model_type == "DDPM":
121 | if config.segmentation_guided:
122 | pipeline = SegGuidedDDPMPipeline(
123 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config
124 | )
125 | else:
126 | pipeline = diffusers.DDPMPipeline(unet=model.module, scheduler=noise_scheduler)
127 | elif config.model_type == "DDIM":
128 | if config.segmentation_guided:
129 | pipeline = SegGuidedDDIMPipeline(
130 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config
131 | )
132 | else:
133 | pipeline = diffusers.DDIMPipeline(unet=model.module, scheduler=noise_scheduler)
134 |
135 | # sample some images
136 | if config.segmentation_guided:
137 | evaluate(config, -1, pipeline, seg_batch, class_label_cfg, translate)
138 | else:
139 | if config.class_conditional:
140 | raise NotImplementedError("TODO: implement CFG and naive conditioning sampling for non-seg-guided pipelines, including for image translation")
141 | evaluate(config, -1, pipeline)
142 |
143 | # seg-guided specific visualizations
144 | if config.segmentation_guided and eval_mask_removal:
145 | plot_result_masks_multiclass = True
146 | if plot_result_masks_multiclass:
147 | pipeoutput_type = 'np'
148 | else:
149 | pipeoutput_type = 'pil'
150 |
151 | # visualize segmentation-guided sampling by seeing what happens
152 | # when segs removed
153 | num_viz = config.eval_batch_size
154 |
155 | # choose one seg to sample from; duplicate it
156 | eval_same_image = False
157 | if eval_same_image:
158 | seg_batch = {k: torch.cat(num_viz*[v[:1]]) for k, v in seg_batch.items()}
159 |
160 | result_masks = torch.Tensor()
161 | multiclass_masks = []
162 | result_imgs = []
163 | multiclass_masks_shape = (config.eval_batch_size, 1, config.image_size, config.image_size)
164 |
165 | # will plot segs + sampled images
166 | for seg_type in seg_batch.keys():
167 | if seg_type.startswith("seg_"):
168 | # move segmentation to CPU for plotting
169 | seg_batch_plt = seg_batch[seg_type].cpu()
170 | result_masks = torch.cat((result_masks, seg_batch_plt))
171 |
172 | # sample given all segs
173 | multiclass_masks.append(convert_segbatch_to_multiclass(multiclass_masks_shape, seg_batch, config, device))
174 | full_seg_imgs = pipeline(
175 | batch_size = num_viz,
176 | seg_batch=seg_batch,
177 | class_label_cfg=class_label_cfg,
178 | translate=translate,
179 | output_type=pipeoutput_type
180 | ).images
181 | if plot_result_masks_multiclass:
182 | result_imgs.append(full_seg_imgs)
183 | else:
184 | result_imgs += full_seg_imgs
185 |
186 | # only sample from masks with chosen classes removed
187 | chosen_class_combinations = None
188 | #chosen_class_combinations = [ #for example:
189 | # {"seg_all": [1, 2]}
190 | #]
191 | if chosen_class_combinations is not None:
192 | for allseg_classes in chosen_class_combinations:
193 | # remove all chosen classes
194 | seg_batch_removed = deepcopy(seg_batch)
195 | for seg_type in seg_batch_removed.keys():
196 | # some datasets have multiple tissue segs stored in multiple masks
197 | if seg_type.startswith("seg_"):
198 | classes = allseg_classes[seg_type]
199 | for mask_val in classes:
200 | if mask_val != 0:
201 | remove_mask = (seg_batch_removed[seg_type]*255).int() == mask_val
202 | seg_batch_removed[seg_type][remove_mask] = 0
203 |
204 | seg_batch_removed_plt = torch.cat([seg_batch_removed[seg_type].cpu() for seg_type in seg_batch_removed.keys() if seg_type.startswith("seg_")])
205 | result_masks = torch.cat((result_masks, seg_batch_removed_plt))
206 |
207 | multiclass_masks.append(convert_segbatch_to_multiclass(
208 | multiclass_masks_shape,
209 | seg_batch_removed, config, device))
210 | # add images conditioned on some segs but not all
211 | removed_seg_imgs = pipeline(
212 | batch_size = config.eval_batch_size,
213 | seg_batch=seg_batch_removed,
214 | class_label_cfg=class_label_cfg,
215 | translate=translate,
216 | output_type=pipeoutput_type
217 | ).images
218 |
219 | if plot_result_masks_multiclass:
220 | result_imgs.append(removed_seg_imgs)
221 | else:
222 | result_imgs += removed_seg_imgs
223 |
224 |
225 | else:
226 | for seg_type in seg_batch.keys():
227 | # some datasets have multiple tissue segs stored in multiple masks
228 | if seg_type.startswith("seg_"):
229 | seg_batch_removed = seg_batch
230 | for mask_val in seg_batch[seg_type].unique():
231 | if mask_val != 0:
232 | remove_mask = seg_batch[seg_type] == mask_val
233 | seg_batch_removed[seg_type][remove_mask] = 0
234 |
235 | seg_batch_removed_plt = torch.cat([seg_batch_removed[seg_type].cpu() for seg_type in seg_batch.keys() if seg_type.startswith("seg_")])
236 | result_masks = torch.cat((result_masks, seg_batch_removed_plt))
237 |
238 | multiclass_masks.append(convert_segbatch_to_multiclass(
239 | multiclass_masks_shape,
240 | seg_batch_removed, config, device))
241 | # add images conditioned on some segs but not all
242 | removed_seg_imgs = pipeline(
243 | batch_size = config.eval_batch_size,
244 | seg_batch=seg_batch_removed,
245 | class_label_cfg=class_label_cfg,
246 | translate=translate,
247 | output_type=pipeoutput_type
248 | ).images
249 |
250 | if plot_result_masks_multiclass:
251 | result_imgs.append(removed_seg_imgs)
252 | else:
253 | result_imgs += removed_seg_imgs
254 |
255 | if plot_result_masks_multiclass:
256 | multiclass_masks = np.squeeze(torch.cat(multiclass_masks).cpu().numpy())
257 | multiclass_masks = (multiclass_masks*255).astype(np.uint8)
258 | result_imgs = np.squeeze(np.concatenate(np.array(result_imgs), axis=0))
259 |
260 | # reverse interleave
261 | plot_imgs = np.zeros_like(result_imgs)
262 | plot_imgs[0:len(plot_imgs)//2] = result_imgs[0::2]
263 | plot_imgs[len(plot_imgs)//2:] = result_imgs[1::2]
264 |
265 | plot_masks = np.zeros_like(multiclass_masks)
266 | plot_masks[0:len(plot_masks)//2] = multiclass_masks[0::2]
267 | plot_masks[len(plot_masks)//2:] = multiclass_masks[1::2]
268 |
269 | fig, axs = plt.subplots(
270 | 2, len(plot_masks),
271 | figsize=(len(plot_masks), 2),
272 | dpi=600
273 | )
274 |
275 | for i, img in enumerate(plot_imgs):
276 | if config.dataset == 'breast_mri':
277 | colors = ['black', 'white', 'red', 'blue']
278 | elif config.dataset == 'ct_organ_large':
279 | colors = ['black', 'blue', 'green', 'red', 'yellow', 'magenta']
280 | else:
281 | raise ValueError('Unknown dataset')
282 |
283 | cmap = ListedColormap(colors)
284 | axs[0,i].imshow(plot_masks[i], cmap=cmap, vmin=0, vmax=len(colors)-1)
285 | axs[0,i].axis('off')
286 | axs[1,i].imshow(img, cmap='gray')
287 | axs[1,i].axis('off')
288 |
289 | plt.subplots_adjust(wspace=0, hspace=0)
290 | plt.savefig('ablated_samples_{}.pdf'.format(config.dataset), bbox_inches='tight')
291 | plt.show()
292 |
293 |
294 |
295 | else:
296 | # Make a grid out of the images
297 | cols = num_viz
298 | rows = math.ceil(len(result_imgs) / cols)
299 | image_grid = make_grid(result_imgs, rows=rows, cols=cols)
300 |
301 | # Save the images
302 | test_dir = os.path.join(config.output_dir, "samples")
303 | os.makedirs(test_dir, exist_ok=True)
304 | image_grid.save(f"{test_dir}/mask_removal_imgs.png")
305 |
306 | save_image(result_masks, f"{test_dir}/mask_removal_masks.png", normalize=True,
307 | nrow=cols*len(seg_batch.keys()) - 2)
308 |
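The "reverse interleave" step in `evaluate_generation` reorders `result_imgs` and `multiclass_masks` so that even-indexed entries fill the first half of the plot and odd-indexed entries the second half. A minimal NumPy sketch of just that reordering (`reverse_interleave` is a hypothetical helper; like the original slicing, it assumes an even number of entries):

```python
import numpy as np


def reverse_interleave(x):
    """Move even-indexed items to the first half and odd-indexed items to the
    second half, mirroring the `plot_imgs` / `plot_masks` assignments."""
    out = np.zeros_like(x)
    half = len(x) // 2
    out[:half] = x[0::2]  # items 0, 2, 4, ...
    out[half:] = x[1::2]  # items 1, 3, 5, ...
    return out
```

With an odd number of entries, `x[0::2]` has one more element than the first half it is assigned to, so NumPy raises a shape-mismatch `ValueError`.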
309 | def convert_segbatch_to_multiclass(shape, segmentations_batch, config, device):
310 | # NOTE: this generic function assumes that segs don't overlap
311 | # put all segs on same channel
312 | segs = torch.zeros(shape).to(device)
313 | for k, seg in segmentations_batch.items():
314 | if k.startswith("seg_"):
315 | seg = seg.to(device)
316 | segs[segs == 0] = seg[segs == 0]
317 |
318 | if config.use_ablated_segmentations:
319 | # randomly remove class labels from segs with some probability
320 | segs = ablate_masks(segs, config)
321 |
322 | return segs
323 |
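In `convert_segbatch_to_multiclass`, `segs[segs == 0] = seg[segs == 0]` means a pixel keeps the first non-zero label it receives. This is why the function notes that segmentations must not overlap: on overlaps, the iteration order of the `seg_*` keys decides the label. A NumPy sketch of that rule (`merge_masks_first_wins` is a hypothetical name; the original works on torch tensors and may additionally ablate classes):

```python
import numpy as np


def merge_masks_first_wins(masks):
    """Collapse several equally shaped single-channel masks into one.

    A pixel keeps the first non-zero label it receives; later masks only fill
    pixels that are still background (0).
    """
    merged = np.zeros_like(masks[0])
    for seg in masks:
        empty = merged == 0
        merged[empty] = seg[empty]
    return merged
```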
324 | def ablate_masks(segs, config, method="equal_weighted"):
325 | # randomly remove class label(s) from segs with some probability
326 | if method == "equal_weighted":
327 | """
328 | # give equal probability to each possible combination of removing non-background classes
329 | # NOTE: requires that each class has a value in ({0, 1, 2, ...} / 255)
330 | # which is the default if the mask file was saved with values {0, 1, 2, ...} and then normalized to [0, 1] by transforms.ToTensor()
331 | # (config.num_segmentation_classes counts all classes, including background)
332 | """
333 | class_removals = (torch.rand(config.num_segmentation_classes - 1) < 0.5).tolist()
334 | for class_idx, remove_class in enumerate(class_removals):
335 | if remove_class:
336 | segs[(255 * segs).int() == class_idx + 1] = 0
337 |
338 | elif method == "by_class":
339 | class_ablation_prob = 0.3
340 | for seg_value in segs.unique():
341 | if seg_value != 0:
342 | # remove seg with some probability
343 | if torch.rand(1).item() < class_ablation_prob:
344 | segs[segs == seg_value] = 0
345 |
346 | else:
347 | raise NotImplementedError
348 | return segs
349 |
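The `equal_weighted` branch of `ablate_masks` removes each non-background class independently with probability 0.5, so each of the 2^(num_segmentation_classes - 1) subsets of removed classes is equally likely. A NumPy sketch, assuming masks store class k as k / 255 (as the comment in `ablate_masks` requires); `ablate_equal_weighted` is hypothetical, works on a copy instead of in place, and rounds `255 * segs` instead of truncating it with `.int()`, a deliberate difference that may guard against float values slightly below an integer being read as the wrong class:

```python
import numpy as np


def ablate_equal_weighted(segs, num_segmentation_classes, rng):
    """Randomly zero out non-background classes, each with probability 0.5.

    `rng` is anything with a NumPy-style `random(n)` method, e.g.
    `np.random.default_rng()`.
    """
    segs = segs.copy()
    remove = rng.random(num_segmentation_classes - 1) < 0.5  # one flag per class 1..C-1
    labels = np.rint(segs * 255).astype(int)  # recover integer class ids
    for class_idx, remove_class in enumerate(remove):
        if remove_class:
            segs[labels == class_idx + 1] = 0
    return segs
```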
350 | def add_segmentations_to_noise(noisy_images, segmentations_batch, config, device):
351 | """
352 | concat segmentations to noisy image
353 | """
354 |
355 | if config.segmentation_channel_mode == "single":
356 | multiclass_masks_shape = (noisy_images.shape[0], 1, noisy_images.shape[2], noisy_images.shape[3])
357 | segs = convert_segbatch_to_multiclass(multiclass_masks_shape, segmentations_batch, config, device)
358 | # concat segs to noise
359 | noisy_images = torch.cat((noisy_images, segs), dim=1)
360 |
361 | elif config.segmentation_channel_mode == "multi":
362 | raise NotImplementedError
363 |
364 | return noisy_images
365 |
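In `"single"` channel mode the merged mask is appended to the noisy image as one extra channel, which is why both pipelines compute `img_channel_ct = self.unet.config.in_channels - 1` and slice `image[:, :img_channel_ct]` before each scheduler step. A NumPy sketch of the shape bookkeeping (`concat_and_strip` is hypothetical; the original uses `torch.cat(..., dim=1)` on `(B, C, H, W)` tensors):

```python
import numpy as np


def concat_and_strip(noisy_images, seg):
    """Append a 1-channel mask to a (B, C, H, W) batch along the channel axis,
    then slice the mask channel back off, as done before scheduler.step()."""
    img_channel_ct = noisy_images.shape[1]
    model_input = np.concatenate((noisy_images, seg), axis=1)  # (B, C + 1, H, W)
    image_only = model_input[:, :img_channel_ct]  # (B, C, H, W)
    return model_input, image_only
```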
366 | ####################
367 | # general DDPM
368 | ####################
369 | def evaluate(config, epoch, pipeline, seg_batch=None, class_label_cfg=None, translate=False):
370 | # Either generate or translate images,
371 | # possibly mask guided and/or class conditioned.
372 | # The default pipeline output type is `List[PIL.Image]`
373 |
374 | if config.segmentation_guided:
375 | images = pipeline(
376 | batch_size = config.eval_batch_size,
377 | seg_batch=seg_batch,
378 | class_label_cfg=class_label_cfg,
379 | translate=translate
380 | ).images
381 | else:
382 | images = pipeline(
383 | batch_size = config.eval_batch_size,
384 | # TODO: implement CFG and naive conditioning sampling for non-seg-guided pipelines (also needed for translation)
385 | ).images
386 |
387 | # Make a grid out of the images
388 | cols = 4
389 | rows = math.ceil(len(images) / cols)
390 | image_grid = make_grid(images, rows=rows, cols=cols)
391 |
392 | # Save the images
393 | test_dir = os.path.join(config.output_dir, "samples")
394 | os.makedirs(test_dir, exist_ok=True)
395 | image_grid.save(f"{test_dir}/{epoch:04d}.png")
396 |
397 | # save segmentations we conditioned the samples on
398 | if config.segmentation_guided:
399 | for seg_type in seg_batch.keys():
400 | if seg_type.startswith("seg_"):
401 | save_image(seg_batch[seg_type], f"{test_dir}/{epoch:04d}_cond_{seg_type}.png", normalize=True, nrow=cols)
402 |
403 | # as well as original images that the segs belong to
404 | img_og = seg_batch['images']
405 | save_image(img_og, f"{test_dir}/{epoch:04d}_orig.png", normalize=True, nrow=cols)
406 |
407 |
408 | # custom diffusers pipelines for sampling from segmentation-guided models
409 | class SegGuidedDDPMPipeline(DiffusionPipeline):
410 | r"""
411 | Pipeline for segmentation-guided image generation, modified from DDPMPipeline.
412 | generates both-class conditioned and unconditional images if using class-conditional model without CFG, or just generates
413 | conditional images guided by CFG.
414 |
415 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
416 | implemented for all pipelines (downloading, saving, running on a particular device, etc.).
417 |
418 | Parameters:
419 | unet ([`UNet2DModel`]):
420 | A `UNet2DModel` to denoise the encoded image latents.
421 | scheduler ([`SchedulerMixin`]):
422 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
423 | [`DDPMScheduler`], or [`DDIMScheduler`].
424 | eval_dataloader ([`torch.utils.data.DataLoader`]):
425 | Dataloader to load the evaluation dataset of images and their segmentations. Here only uses the segmentations to generate images.
426 | """
427 | model_cpu_offload_seq = "unet"
428 |
429 | def __init__(self, unet, scheduler, eval_dataloader, external_config):
430 | super().__init__()
431 | self.register_modules(unet=unet, scheduler=scheduler)
432 | self.eval_dataloader = eval_dataloader
433 | self.external_config = external_config # config is already a thing
434 |
435 | @torch.no_grad()
436 | def __call__(
437 | self,
438 | batch_size: int = 1,
439 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
440 | num_inference_steps: int = 1000,
441 | output_type: Optional[str] = "pil",
442 | return_dict: bool = True,
443 | seg_batch: Optional[dict] = None,
444 | class_label_cfg: Optional[int] = None,
445 | translate = False,
446 | ) -> Union[ImagePipelineOutput, Tuple]:
447 | r"""
448 | The call function to the pipeline for generation.
449 |
450 | Args:
451 | batch_size (`int`, *optional*, defaults to 1):
452 | The number of images to generate.
453 | generator (`torch.Generator`, *optional*):
454 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
455 | generation deterministic.
456 | num_inference_steps (`int`, *optional*, defaults to 1000):
457 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
458 | expense of slower inference.
459 | output_type (`str`, *optional*, defaults to `"pil"`):
460 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
461 | return_dict (`bool`, *optional*, defaults to `True`):
462 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
463 | seg_batch (`dict`, *optional*, defaults to None):
464 | batch of segmentations to condition generation on: a dict whose `seg_*` keys hold mask tensors (translation also uses its `images` key)
465 | class_label_cfg (`int`, *optional*, defaults to `None`):
466 | class label to condition generation on using CFG, if using class-conditional model
467 |
468 | OPTIONS FOR IMAGE TRANSLATION:
469 | translate (`bool`, *optional*, defaults to False):
470 | whether to translate images from the source domain to the target domain
471 |
472 | Returns:
473 | [`~pipelines.ImagePipelineOutput`] or `tuple`:
474 | If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
475 | returned where the first element is a list with the generated images
476 | """
477 | # Sample gaussian noise to begin loop
478 | if self.external_config.segmentation_channel_mode == "single":
479 | img_channel_ct = self.unet.config.in_channels - 1
480 | elif self.external_config.segmentation_channel_mode == "multi":
481 | img_channel_ct = self.unet.config.in_channels - len([k for k in seg_batch.keys() if k.startswith("seg_")])
482 |
483 | if isinstance(self.unet.config.sample_size, int):
484 | image_shape = (
485 | batch_size,
486 | img_channel_ct,
487 | self.unet.config.sample_size,
488 | self.unet.config.sample_size,
489 | )
490 | else:
491 | if self.external_config.segmentation_channel_mode == "single":
492 | image_shape = (batch_size, self.unet.config.in_channels - 1, *self.unet.config.sample_size)
493 | elif self.external_config.segmentation_channel_mode == "multi":
494 | image_shape = (batch_size, self.unet.config.in_channels - len([k for k in seg_batch.keys() if k.startswith("seg_")]), *self.unet.config.sample_size)
495 |
496 |
497 | # initiate latent variable to sample from
498 | if not translate:
499 | # normal sampling; start from noise
500 | if self.device.type == "mps":
501 | # randn does not work reproducibly on mps
502 | image = randn_tensor(image_shape, generator=generator)
503 | image = image.to(self.device)
504 | else:
505 | image = randn_tensor(image_shape, generator=generator, device=self.device)
506 | else:
507 | # image translation sampling: start from source-domain images, add noise up to a certain step, then begin denoising from there
508 | trans_start_t = int(self.external_config.trans_noise_level * self.scheduler.config.num_train_timesteps)
509 |
510 | trans_start_images = seg_batch["images"].to(self.device)
511 |
512 | # Sample noise to add to the images
513 | noise = torch.randn(trans_start_images.shape).to(trans_start_images.device)
514 | timesteps = torch.full(
515 | (trans_start_images.size(0),),
516 | trans_start_t,
517 | device=trans_start_images.device
518 | ).long()
519 | image = self.scheduler.add_noise(trans_start_images, noise, timesteps)
520 |
521 | # set step values
522 | self.scheduler.set_timesteps(num_inference_steps)
523 |
524 | for t in self.progress_bar(self.scheduler.timesteps):
525 | if translate:
526 | # if doing translation, start at chosen time step given partially-noised image
527 | # skip all earlier time steps (with higher t)
528 | if t >= trans_start_t:
529 | continue
530 |
531 | # 1. predict noise model_output
532 | # first, concat segmentations to noise
533 | image = add_segmentations_to_noise(image, seg_batch, self.external_config, self.device)
534 |
535 | if self.external_config.class_conditional:
536 | if class_label_cfg is not None:
537 | class_labels = torch.full([image.size(0)], class_label_cfg).long().to(self.device)
538 | model_output_cond = self.unet(image, t, class_labels=class_labels).sample
539 | if self.external_config.use_cfg_for_eval_conditioning:
540 | # use classifier-free guidance for sampling from the given class
541 |
542 | if self.external_config.cfg_maskguidance_condmodel_only:
543 | image_emptymask = torch.cat((image[:, :img_channel_ct, :, :], torch.zeros_like(image[:, img_channel_ct:, :, :])), dim=1)
544 | model_output_uncond = self.unet(image_emptymask, t,
545 | class_labels=torch.zeros_like(class_labels).long()).sample
546 | else:
547 | model_output_uncond = self.unet(image, t,
548 | class_labels=torch.zeros_like(class_labels).long()).sample
549 |
550 | # use cfg equation
551 | model_output = (1. + self.external_config.cfg_weight) * model_output_cond - self.external_config.cfg_weight * model_output_uncond
552 | else:
553 | # just use normal conditioning
554 | model_output = model_output_cond
555 |
556 | else:
557 | # or, just use basic network conditioning to sample from both classes
558 | if self.external_config.class_conditional:
559 | # if training conditionally, evaluate source domain samples
560 | class_labels = torch.ones(image.size(0)).long().to(self.device)
561 | model_output = self.unet(image, t, class_labels=class_labels).sample
562 | else:
563 | model_output = self.unet(image, t).sample
564 | # output is slightly denoised image
565 |
566 | # 2. compute previous image: x_t -> x_t-1
567 | # but first, we only denoise the image channel(s), not the seg channel,
568 | # so remove segs
569 | image = image[:, :img_channel_ct, :, :]
570 | image = self.scheduler.step(model_output, t, image, generator=generator).prev_sample
571 |
572 | # if training conditionally, also evaluate for target domain images
573 | # if not using chosen class for CFG
574 | if self.external_config.class_conditional and class_label_cfg is None:
575 | image_target_domain = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype)
576 |
577 | # set step values
578 | self.scheduler.set_timesteps(num_inference_steps)
579 |
580 | for t in self.progress_bar(self.scheduler.timesteps):
581 | # 1. predict noise model_output
582 | # first, concat segmentations to noise
583 | # no masks in target domain so just use blank masks
584 | image_target_domain = torch.cat((image_target_domain, torch.zeros_like(image_target_domain)), dim=1)
585 |
586 | if self.external_config.class_conditional:
587 | # if training conditionally, also evaluate unconditional model and target domain (no masks)
588 | class_labels = torch.cat([torch.full([image_target_domain.size(0) // 2], 2), torch.zeros(image_target_domain.size(0) // 2)]).long().to(self.device)
589 | model_output = self.unet(image_target_domain, t, class_labels=class_labels).sample
590 | else:
591 | model_output = self.unet(image_target_domain, t).sample
592 |
593 | # 2. compute previous image x_t-1 from x_t:
594 | # the scheduler's step() computes the posterior mean and adds noise for t > 0
595 | # (eta only applies to DDIM sampling, so it is not used in this DDPM pipeline)
596 | # but first, we only denoise the image channel(s), not the seg channel,
597 | # so remove segs
598 | image_target_domain = image_target_domain[:, :img_channel_ct, :, :]
599 | image_target_domain = self.scheduler.step(
600 | model_output, t, image_target_domain, generator=generator
601 | ).prev_sample
602 |
603 | image = torch.cat((image, image_target_domain), dim=0)
604 | # will output source domain images first, then target domain images
605 |
606 | image = (image / 2 + 0.5).clamp(0, 1)
607 | image = image.cpu().permute(0, 2, 3, 1).numpy()
608 | if output_type == "pil":
609 | image = self.numpy_to_pil(image)
610 |
611 | if not return_dict:
612 | return (image,)
613 |
614 | return ImagePipelineOutput(images=image)
615 |
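Two pieces of arithmetic in `SegGuidedDDPMPipeline.__call__` are easy to check in isolation: the classifier-free guidance combination `(1 + cfg_weight) * cond - cfg_weight * uncond`, and which timesteps are run during translation after the source image is noised to `int(trans_noise_level * num_train_timesteps)`. A minimal NumPy sketch (the function names are hypothetical; the actual pipeline operates on UNet outputs and scheduler timesteps):

```python
import numpy as np


def cfg_combine(eps_cond, eps_uncond, cfg_weight):
    """Classifier-free guidance as written in the pipeline:
    (1 + w) * eps_cond - w * eps_uncond, i.e. eps_cond + w * (eps_cond - eps_uncond)."""
    return (1.0 + cfg_weight) * eps_cond - cfg_weight * eps_uncond


def translation_timesteps(timesteps, trans_noise_level, num_train_timesteps):
    """Return the translation start step and the timesteps actually denoised:
    every scheduled t >= trans_start_t is skipped, as in the sampling loop."""
    trans_start_t = int(trans_noise_level * num_train_timesteps)
    kept = [t for t in timesteps if t < trans_start_t]
    return trans_start_t, kept
```

With `cfg_weight = 0` the guidance reduces to the purely conditional prediction; larger weights extrapolate further along `cond - uncond`. Likewise, a higher `trans_noise_level` starts translation from a noisier image and runs more denoising steps.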
616 | class SegGuidedDDIMPipeline(DiffusionPipeline):
617 | r"""
618 | Pipeline for segmentation-guided image generation,
619 | modified from diffusers.DDIMPipeline.
620 | generates both-class conditioned and unconditional images if using class-conditional model without CFG, or just generates
621 | conditional images guided by CFG.
622 |
623 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
624 | implemented for all pipelines (downloading, saving, running on a particular device, etc.).
625 |
626 | Parameters:
627 | unet ([`UNet2DModel`]):
628 | A `UNet2DModel` to denoise the encoded image latents.
629 | scheduler ([`SchedulerMixin`]):
630 | A scheduler to be used in combination with `unet` to denoise the encoded image. Can be one of
631 | [`DDPMScheduler`], or [`DDIMScheduler`].
632 | eval_dataloader ([`torch.utils.data.DataLoader`]):
633 | Dataloader to load the evaluation dataset of images and their segmentations. Here only uses the segmentations to generate images.
634 |
635 | """
636 | model_cpu_offload_seq = "unet"
637 |
638 | def __init__(self, unet, scheduler, eval_dataloader, external_config):
639 | super().__init__()
640 | # make sure the scheduler is a DDIMScheduler *before* registering it;
641 | # reassigning `scheduler` after register_modules() would have no effect
642 | scheduler = DDIMScheduler.from_config(scheduler.config)
643 | self.register_modules(unet=unet, scheduler=scheduler, eval_dataloader=eval_dataloader, external_config=external_config)
644 | # ^ for some reason necessary for DDIM but not DDPM.
645 |
646 | self.eval_dataloader = eval_dataloader
647 | self.external_config = external_config # config is already a thing
648 |
649 |
650 | @torch.no_grad()
651 | def __call__(
652 | self,
653 | batch_size: int = 1,
654 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
655 | eta: float = 0.0,
656 | num_inference_steps: int = 50,
657 | use_clipped_model_output: Optional[bool] = None,
658 | output_type: Optional[str] = "pil",
659 | return_dict: bool = True,
660 | seg_batch: Optional[dict] = None,
661 | class_label_cfg: Optional[int] = None,
662 | translate = False,
663 | ) -> Union[ImagePipelineOutput, Tuple]:
664 | r"""
665 | The call function to the pipeline for generation.
666 |
667 | Args:
668 | batch_size (`int`, *optional*, defaults to 1):
669 | The number of images to generate.
670 | generator (`torch.Generator`, *optional*):
671 | A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
672 | generation deterministic.
673 | eta (`float`, *optional*, defaults to 0.0):
674 | Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
675 | to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. A value of `0` corresponds to
676 | DDIM and `1` corresponds to DDPM.
677 | num_inference_steps (`int`, *optional*, defaults to 50):
678 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the
679 | expense of slower inference.
680 | use_clipped_model_output (`bool`, *optional*, defaults to `None`):
681 | If `True` or `False`, see documentation for [`DDIMScheduler.step`]. If `None`, nothing is passed
682 | downstream to the scheduler (use `None` for schedulers which don't support this argument).
683 | output_type (`str`, *optional*, defaults to `"pil"`):
684 | The output format of the generated image. Choose between `PIL.Image` or `np.array`.
685 | return_dict (`bool`, *optional*, defaults to `True`):
686 | Whether or not to return a [`~pipelines.ImagePipelineOutput`] instead of a plain tuple.
687 | seg_batch (`dict`, *optional*):
688 | batch of segmentations to condition generation on: a dict whose `seg_*` keys hold mask tensors (translation also uses its `images` key)
689 | class_label_cfg (`int`, *optional*, defaults to `None`):
690 | class label to condition generation on using CFG, if using class-conditional model
691 |
692 | OPTIONS FOR IMAGE TRANSLATION:
693 | translate (`bool`, *optional*, defaults to False):
694 | whether to translate images from the source domain to the target domain
695 |
700 | Returns:
701 | [`~pipelines.ImagePipelineOutput`] or `tuple`:
702 | If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is
703 | returned where the first element is a list with the generated images
704 | """
705 |
706 | # Determine how many UNet input channels are image channels (the rest are segmentation channels)
707 | if self.external_config.segmentation_channel_mode == "single":
708 | img_channel_ct = self.unet.config.in_channels - 1
709 | elif self.external_config.segmentation_channel_mode == "multi":
710 | img_channel_ct = self.unet.config.in_channels - len([k for k in seg_batch.keys() if k.startswith("seg_")])
711 | else:
712 | raise ValueError("segmentation_channel_mode must be \"single\" or \"multi\"")
713 |
714 | # Shape of the initial noise sample (image channels only)
715 | if isinstance(self.unet.config.sample_size, int):
716 | image_shape = (batch_size, img_channel_ct, self.unet.config.sample_size, self.unet.config.sample_size)
717 | else:
718 | image_shape = (batch_size, img_channel_ct, *self.unet.config.sample_size)
732 |
733 | if isinstance(generator, list) and len(generator) != batch_size:
734 | raise ValueError(
735 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
736 | f" size of {batch_size}. Make sure the batch size matches the length of the generators."
737 | )
738 |
739 | # initiate latent variable to sample from
740 | if not translate:
741 | # normal sampling; start from noise
742 | image = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype)
743 | else:
744 | # image translation sampling: start from the source-domain images, add noise up to a chosen timestep, then denoise from there
745 | trans_start_t = int(self.external_config.trans_noise_level * self.scheduler.config.num_train_timesteps)
746 |
747 | trans_start_images = seg_batch["images"].to(self._execution_device)
748 |
749 | # Sample noise to add to the images
750 | noise = randn_tensor(trans_start_images.shape, generator=generator, device=trans_start_images.device, dtype=trans_start_images.dtype)
751 | timesteps = torch.full(
752 | (trans_start_images.size(0),),
753 | trans_start_t,
754 | device=trans_start_images.device
755 | ).long()
756 | image = self.scheduler.add_noise(trans_start_images, noise, timesteps)
757 |
758 | # set step values
759 | self.scheduler.set_timesteps(num_inference_steps)
760 |
761 | for t in self.progress_bar(self.scheduler.timesteps):
762 | if translate:
763 | # if doing translation, start at chosen time step given partially-noised image
764 | # skip all earlier time steps (with higher t)
765 | if t >= trans_start_t:
766 | continue
767 |
768 | # 1. predict noise model_output
769 | # first, concat segmentations to noise
770 | image = add_segmentations_to_noise(image, seg_batch, self.external_config, self.device)
771 |
772 | if self.external_config.class_conditional:
773 | if class_label_cfg is not None:
774 | class_labels = torch.full([image.size(0)], class_label_cfg).long().to(self.device)
775 | model_output_cond = self.unet(image, t, class_labels=class_labels).sample
776 | if self.external_config.use_cfg_for_eval_conditioning:
777 | # use classifier-free guidance for sampling from the given class
778 | if self.external_config.cfg_maskguidance_condmodel_only:
779 | image_emptymask = torch.cat((image[:, :img_channel_ct, :, :], torch.zeros_like(image[:, img_channel_ct:, :, :])), dim=1)
780 | model_output_uncond = self.unet(image_emptymask, t,
781 | class_labels=torch.zeros_like(class_labels).long()).sample
782 | else:
783 | model_output_uncond = self.unet(image, t,
784 | class_labels=torch.zeros_like(class_labels).long()).sample
785 |
786 | # use cfg equation
787 | model_output = (1. + self.external_config.cfg_weight) * model_output_cond - self.external_config.cfg_weight * model_output_uncond
788 | else:
789 | model_output = model_output_cond
790 |
791 | else:
792 | # or, just use basic network conditioning to sample from both classes
793 | if self.external_config.class_conditional:
794 | # if training conditionally, evaluate source domain samples
795 | class_labels = torch.ones(image.size(0)).long().to(self.device)
796 | model_output = self.unet(image, t, class_labels=class_labels).sample
797 | else:
798 | model_output = self.unet(image, t).sample
799 |
800 | # 2. predict previous mean of image x_t-1 and add variance depending on eta
801 | # eta corresponds to η in paper and should be between [0, 1]
802 | # do x_t -> x_t-1
803 | # but first strip the segmentation channels, since only the image channels are denoised
804 | # (the segmentations are re-concatenated at the start of each step)
805 | image = image[:, :img_channel_ct, :, :]
806 | image = self.scheduler.step(
807 | model_output, t, image, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
808 | ).prev_sample
809 |
810 | # if training conditionally, also evaluate for target domain images
811 | # if not using chosen class for CFG
812 | if self.external_config.class_conditional and class_label_cfg is None:
813 | image_target_domain = randn_tensor(image_shape, generator=generator, device=self._execution_device, dtype=self.unet.dtype)
814 |
815 | # set step values
816 | self.scheduler.set_timesteps(num_inference_steps)
817 |
818 | for t in self.progress_bar(self.scheduler.timesteps):
819 | # 1. predict noise model_output
820 | # first, concat segmentations to noise
821 | # no masks in target domain so just use blank masks
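822 | image_target_domain = torch.cat((image_target_domain, torch.zeros((image_target_domain.size(0), self.unet.config.in_channels - img_channel_ct, *image_target_domain.shape[2:]), device=image_target_domain.device, dtype=image_target_domain.dtype)), dim=1)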
823 |
824 | if self.external_config.class_conditional:
825 | # if training conditionally, also evaluate unconditional model and target domain (no masks)
826 | class_labels = torch.cat([torch.full([image_target_domain.size(0) // 2], 2), torch.zeros(image_target_domain.size(0) - image_target_domain.size(0) // 2)]).long().to(self.device)
827 | model_output = self.unet(image_target_domain, t, class_labels=class_labels).sample
828 | else:
829 | model_output = self.unet(image_target_domain, t).sample
830 |
831 | # 2. predict previous mean of image x_t-1 and add variance depending on eta
832 | # eta corresponds to η in paper and should be between [0, 1]
833 | # do x_t -> x_t-1
834 | # but first strip the blank segmentation channels, since only the image channels are denoised
835 | # (they are re-concatenated at the start of each step)
836 | image_target_domain = image_target_domain[:, :img_channel_ct, :, :]
837 | image_target_domain = self.scheduler.step(
838 | model_output, t, image_target_domain, eta=eta, use_clipped_model_output=use_clipped_model_output, generator=generator
839 | ).prev_sample
840 |
841 | image = torch.cat((image, image_target_domain), dim=0)
842 | # will output source domain images first, then target domain images
843 |
844 | image = (image / 2 + 0.5).clamp(0, 1)
845 | image = image.cpu().permute(0, 2, 3, 1).numpy()
846 | if output_type == "pil":
847 | image = self.numpy_to_pil(image)
848 |
849 | if not return_dict:
850 | return (image,)
851 |
852 | return ImagePipelineOutput(images=image)
853 |
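In translation mode the sampling loop noises the source images to `trans_start_t = int(trans_noise_level * num_train_timesteps)` and then skips every inference timestep with `t >= trans_start_t`. The NumPy sketch below reproduces that bookkeeping so the number of denoising steps actually run can be checked. It assumes evenly spaced, descending ("leading") inference timesteps; that spacing is an assumption about the scheduler defaults rather than something this code sets, and `leading_timesteps` / `translation_timesteps` are hypothetical helpers.

```python
import numpy as np


def leading_timesteps(num_train_timesteps, num_inference_steps):
    """Evenly spaced, descending inference timesteps (assumed "leading" spacing)."""
    step_ratio = num_train_timesteps // num_inference_steps
    return (np.arange(0, num_inference_steps) * step_ratio)[::-1].astype(np.int64)


def translation_timesteps(trans_noise_level, num_train_timesteps=1000, num_inference_steps=50):
    """Return the noising timestep and the inference timesteps the translation loop keeps."""
    # source images are noised up to this timestep before denoising
    trans_start_t = int(trans_noise_level * num_train_timesteps)
    # mirror `if t >= trans_start_t: continue` in the sampling loop
    kept = [int(t) for t in leading_timesteps(num_train_timesteps, num_inference_steps) if t < trans_start_t]
    return trans_start_t, kept


start, kept = translation_timesteps(0.5)
print(start, len(kept), kept[0], kept[-1])  # 500 25 480 0
```

Under this assumed spacing, the default `trans_noise_level = 0.5` with 50 inference steps runs 25 denoising steps, and the first one is at t = 480 rather than t = 500, because the loop also skips `t == trans_start_t`.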
--------------------------------------------------------------------------------
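The class-conditional branch of the sampling loop combines the conditional and unconditional noise predictions as `(1 + w) * eps_cond - w * eps_uncond`, with `w = cfg_weight`, and, when `cfg_maskguidance_condmodel_only` is set, gives the unconditional pass an input whose segmentation channels are zeroed. A minimal NumPy sketch of both operations follows; the function names are hypothetical, and arrays stand in for tensors laid out as (batch, channels, H, W).

```python
import numpy as np


def cfg_combine(eps_cond, eps_uncond, w):
    """Classifier-free guidance: extrapolate away from the unconditional prediction."""
    # equivalent to eps_cond + w * (eps_cond - eps_uncond); w = 0 returns eps_cond
    return (1.0 + w) * eps_cond - w * eps_uncond


def blank_mask_input(x, img_channel_ct):
    """Keep the image channels and zero the segmentation channels (channel axis 1)."""
    return np.concatenate([x[:, :img_channel_ct], np.zeros_like(x[:, img_channel_ct:])], axis=1)


eps_c = np.array([1.0, 2.0])
eps_u = np.array([0.0, 1.0])
print(cfg_combine(eps_c, eps_u, w=0.3))

x = np.ones((2, 3, 4, 4))  # 1 image channel + 2 segmentation channels
print(blank_mask_input(x, img_channel_ct=1)[:, 1:].sum())  # 0.0
```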
/figs/teaser.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazurowski-lab/segmentation-guided-diffusion/9a31f0f4eef0d7b835f01b2a6301df854b8040e9/figs/teaser.png
--------------------------------------------------------------------------------
/figs/teaser_data.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/mazurowski-lab/segmentation-guided-diffusion/9a31f0f4eef0d7b835f01b2a6301df854b8040e9/figs/teaser_data.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentParser
3 |
4 | # torch imports
5 | import torch
6 | from torch import nn
7 | from torchvision import transforms
8 | import torch.nn.functional as F
9 | import numpy as np
10 |
11 | # HF imports
12 | import diffusers
13 | from diffusers.optimization import get_cosine_schedule_with_warmup
14 | import datasets
15 |
16 | # custom imports
17 | from training import TrainingConfig, train_loop
18 | from eval import evaluate_generation, evaluate_sample_many
19 |
20 | def main(
21 | mode,
22 | img_size,
23 | num_img_channels,
24 | dataset,
25 | img_dir,
26 | seg_dir,
27 | model_type,
28 | segmentation_guided,
29 | segmentation_channel_mode,
30 | num_segmentation_classes,
31 | train_batch_size,
32 | eval_batch_size,
33 | num_epochs,
34 | resume_epoch=None,
35 | use_ablated_segmentations=False,
36 | eval_shuffle_dataloader=True,
37 |
38 | # arguments only used in eval
39 | eval_mask_removal=False,
40 | eval_blank_mask=False,
41 | eval_sample_size=1000
42 | ):
43 | # use the GPU if one is available, else the CPU
44 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45 | print('running on {}'.format(device))
46 |
47 | # load config
48 | output_dir = '{}-{}-{}'.format(model_type.lower(), dataset, img_size) # the model name, used both locally and on the HF Hub
49 | if segmentation_guided:
50 | output_dir += "-segguided"
51 | assert seg_dir is not None, "must provide segmentation directory for segmentation guided training/sampling"
52 |
53 | if use_ablated_segmentations or eval_mask_removal or eval_blank_mask:
54 | output_dir += "-ablated"
55 |
56 | print("output dir: {}".format(output_dir))
57 |
58 | if mode == "train":
59 | evalset_name = "val"
60 | assert img_dir is not None, "must provide image directory for training"
61 | elif "eval" in mode:
62 | evalset_name = "test"
63 | else: raise ValueError("mode \"{}\" not supported.".format(mode))
64 | print("using evaluation set: {}".format(evalset_name))
65 |
66 | config = TrainingConfig(
67 | image_size = img_size,
68 | dataset = dataset,
69 | segmentation_guided = segmentation_guided,
70 | segmentation_channel_mode = segmentation_channel_mode,
71 | num_segmentation_classes = num_segmentation_classes,
72 | train_batch_size = train_batch_size,
73 | eval_batch_size = eval_batch_size,
74 | num_epochs = num_epochs,
75 | output_dir = output_dir,
76 | model_type=model_type,
77 | resume_epoch=resume_epoch,
78 | use_ablated_segmentations=use_ablated_segmentations
79 | )
80 |
81 | load_images_as_np_arrays = False
82 | if num_img_channels not in [1, 3]:
83 | load_images_as_np_arrays = True
84 | print("image channels not 1 or 3, attempting to load images as np arrays...")
85 |
86 | if config.segmentation_guided:
87 | seg_types = sorted(os.listdir(seg_dir)) # sorted so the segmentation channel order is deterministic
88 | seg_paths_train = {}
89 | seg_paths_eval = {}
90 |
91 | # train set
92 | if img_dir is not None:
93 | # match each image to its masks: the segmentation paths reuse the image filenames
94 | img_dir_train = os.path.join(img_dir, "train")
95 | img_paths_train = [os.path.join(img_dir_train, f) for f in os.listdir(img_dir_train)]
96 | for seg_type in seg_types:
97 | seg_paths_train[seg_type] = [os.path.join(seg_dir, seg_type, "train", f) for f in os.listdir(img_dir_train)]
98 | else:
99 | for seg_type in seg_types:
100 | seg_paths_train[seg_type] = [os.path.join(seg_dir, seg_type, "train", f) for f in os.listdir(os.path.join(seg_dir, seg_type, "train"))]
101 |
102 | # eval set
103 | if img_dir is not None:
104 | img_dir_eval = os.path.join(img_dir, evalset_name)
105 | img_paths_eval = [os.path.join(img_dir_eval, f) for f in os.listdir(img_dir_eval)]
106 | for seg_type in seg_types:
107 | seg_paths_eval[seg_type] = [os.path.join(seg_dir, seg_type, evalset_name, f) for f in os.listdir(img_dir_eval)]
108 | else:
109 | for seg_type in seg_types:
110 | seg_paths_eval[seg_type] = [os.path.join(seg_dir, seg_type, evalset_name, f) for f in os.listdir(os.path.join(seg_dir, seg_type, evalset_name))]
111 |
112 | if img_dir is not None:
113 | dset_dict_train = {
114 | **{"image": img_paths_train},
115 | **{"seg_{}".format(seg_type): seg_paths_train[seg_type] for seg_type in seg_types}
116 | }
117 |
118 | dset_dict_eval = {
119 | **{"image": img_paths_eval},
120 | **{"seg_{}".format(seg_type): seg_paths_eval[seg_type] for seg_type in seg_types}
121 | }
122 | else:
123 | dset_dict_train = {
124 | **{"seg_{}".format(seg_type): seg_paths_train[seg_type] for seg_type in seg_types}
125 | }
126 |
127 | dset_dict_eval = {
128 | **{"seg_{}".format(seg_type): seg_paths_eval[seg_type] for seg_type in seg_types}
129 | }
130 |
131 |
132 | if img_dir is not None:
133 | # add image filenames to dataset
134 | dset_dict_train["image_filename"] = [os.path.basename(f) for f in dset_dict_train["image"]]
135 | dset_dict_eval["image_filename"] = [os.path.basename(f) for f in dset_dict_eval["image"]]
136 | else:
137 | # use segmentation filenames as image filenames
138 | dset_dict_train["image_filename"] = [os.path.basename(f) for f in dset_dict_train["seg_{}".format(seg_types[0])]]
139 | dset_dict_eval["image_filename"] = [os.path.basename(f) for f in dset_dict_eval["seg_{}".format(seg_types[0])]]
140 |
141 | dataset_train = datasets.Dataset.from_dict(dset_dict_train)
142 | dataset_eval = datasets.Dataset.from_dict(dset_dict_eval)
143 |
144 | # load the images
145 | if not load_images_as_np_arrays and img_dir is not None:
146 | dataset_train = dataset_train.cast_column("image", datasets.Image())
147 | dataset_eval = dataset_eval.cast_column("image", datasets.Image())
148 |
149 | for seg_type in seg_types:
150 | dataset_train = dataset_train.cast_column("seg_{}".format(seg_type), datasets.Image())
151 |
152 | for seg_type in seg_types:
153 | dataset_eval = dataset_eval.cast_column("seg_{}".format(seg_type), datasets.Image())
154 |
155 | else:
156 | if img_dir is not None:
157 | img_dir_train = os.path.join(img_dir, "train")
158 | img_paths_train = [os.path.join(img_dir_train, f) for f in os.listdir(img_dir_train)]
159 |
160 | img_dir_eval = os.path.join(img_dir, evalset_name)
161 | img_paths_eval = [os.path.join(img_dir_eval, f) for f in os.listdir(img_dir_eval)]
162 |
163 | dset_dict_train = {
164 | **{"image": img_paths_train}
165 | }
166 |
167 | dset_dict_eval = {
168 | **{"image": img_paths_eval}
169 | }
170 |
171 | # add image filenames to dataset
172 | dset_dict_train["image_filename"] = [os.path.basename(f) for f in dset_dict_train["image"]]
173 | dset_dict_eval["image_filename"] = [os.path.basename(f) for f in dset_dict_eval["image"]]
174 |
175 | dataset_train = datasets.Dataset.from_dict(dset_dict_train)
176 | dataset_eval = datasets.Dataset.from_dict(dset_dict_eval)
177 |
178 | # load the images
179 | if not load_images_as_np_arrays:
180 | dataset_train = dataset_train.cast_column("image", datasets.Image())
181 | dataset_eval = dataset_eval.cast_column("image", datasets.Image())
182 |
183 | # training set preprocessing
184 | if not load_images_as_np_arrays:
185 | preprocess = transforms.Compose(
186 | [
187 | transforms.Resize((config.image_size, config.image_size)),
188 | # transforms.RandomHorizontalFlip(), # flipping wouldn't result in realistic images
189 | transforms.ToTensor(),
190 | transforms.Normalize(
191 | num_img_channels * [0.5],
192 | num_img_channels * [0.5]),
193 | ]
194 | )
195 | else:
196 | # resizing will be done in the transform function
197 | preprocess = transforms.Compose(
198 | [
199 | transforms.Normalize(
200 | num_img_channels * [0.5],
201 | num_img_channels * [0.5]),
202 | ]
203 | )
204 |
205 | if num_img_channels == 1:
206 | PIL_image_type = "L"
207 | elif num_img_channels == 3:
208 | PIL_image_type = "RGB"
209 | else:
210 | PIL_image_type = None
211 |
212 | if config.segmentation_guided:
213 | preprocess_segmentation = transforms.Compose(
214 | [
215 | transforms.Resize((config.image_size, config.image_size), interpolation=transforms.InterpolationMode.NEAREST),
216 | transforms.ToTensor(),
217 | ]
218 | )
219 |
220 | def transform(examples):
221 | if img_dir is not None:
222 | if not load_images_as_np_arrays:
223 | images = [preprocess(image.convert(PIL_image_type)) for image in examples["image"]]
224 | else:
225 | # load np array as torch tensor, resize, then normalize
226 | images = [
227 | preprocess(F.interpolate(torch.tensor(np.load(image)).unsqueeze(0).float(), size=(config.image_size, config.image_size)).squeeze()) for image in examples["image"]
228 | ]
229 |
230 | images_filenames = examples["image_filename"]
231 |
232 | segs = {}
233 | for seg_type in seg_types:
234 | segs["seg_{}".format(seg_type)] = [preprocess_segmentation(image.convert("L")) for image in examples["seg_{}".format(seg_type)]]
235 | # return {"images": images, "seg_breast": seg_breast, "seg_dv": seg_dv}
236 | if img_dir is not None:
237 | return {**{"images": images}, **segs, **{"image_filenames": images_filenames}}
238 | else:
239 | return {**segs, **{"image_filenames": images_filenames}}
240 |
241 | dataset_train.set_transform(transform)
242 | dataset_eval.set_transform(transform)
243 |
244 | else:
245 | if img_dir is not None:
246 | def transform(examples):
247 | if not load_images_as_np_arrays:
248 | images = [preprocess(image.convert(PIL_image_type)) for image in examples["image"]]
249 | else:
250 | images = [
251 | preprocess(F.interpolate(torch.tensor(np.load(image)).unsqueeze(0).float(), size=(config.image_size, config.image_size)).squeeze()) for image in examples["image"]
252 | ]
253 | images_filenames = examples["image_filename"]
254 | #return {"images": images, "image_filenames": images_filenames}
255 | return {"images": images, **{"image_filenames": images_filenames}}
256 |
257 | dataset_train.set_transform(transform)
258 | dataset_eval.set_transform(transform)
259 |
260 | if ((img_dir is None) and (not segmentation_guided)):
261 | train_dataloader = None
262 | # just make placeholder dataloaders to iterate through when sampling from uncond model
263 | eval_dataloader = torch.utils.data.DataLoader(
264 | torch.utils.data.TensorDataset(torch.zeros(config.eval_batch_size, num_img_channels, config.image_size, config.image_size)),
265 | batch_size=config.eval_batch_size,
266 | shuffle=eval_shuffle_dataloader
267 | )
268 | else:
269 | train_dataloader = torch.utils.data.DataLoader(
270 | dataset_train,
271 | batch_size=config.train_batch_size,
272 | shuffle=True
273 | )
274 |
275 | eval_dataloader = torch.utils.data.DataLoader(
276 | dataset_eval,
277 | batch_size=config.eval_batch_size,
278 | shuffle=eval_shuffle_dataloader
279 | )
280 |
281 | # define the model
282 | in_channels = num_img_channels
283 | if config.segmentation_guided:
284 | assert config.num_segmentation_classes is not None
285 | assert config.num_segmentation_classes > 1, "must have at least 2 segmentation classes (INCLUDING background)"
286 | if config.segmentation_channel_mode == "single":
287 | in_channels += 1
288 | elif config.segmentation_channel_mode == "multi":
289 | in_channels = len(seg_types) + in_channels
290 |
291 | model = diffusers.UNet2DModel(
292 | sample_size=config.image_size, # the target image resolution
293 | in_channels=in_channels, # image channels plus any segmentation channels
294 | out_channels=num_img_channels, # the number of output channels
295 | layers_per_block=2, # how many ResNet layers to use per UNet block
296 | block_out_channels=(128, 128, 256, 256, 512, 512), # the number of output channels for each UNet block
297 | down_block_types=(
298 | "DownBlock2D", # a regular ResNet downsampling block
299 | "DownBlock2D",
300 | "DownBlock2D",
301 | "DownBlock2D",
302 | "AttnDownBlock2D", # a ResNet downsampling block with spatial self-attention
303 | "DownBlock2D",
304 | ),
305 | up_block_types=(
306 | "UpBlock2D", # a regular ResNet upsampling block
307 | "AttnUpBlock2D", # a ResNet upsampling block with spatial self-attention
308 | "UpBlock2D",
309 | "UpBlock2D",
310 | "UpBlock2D",
311 | "UpBlock2D"
312 | ),
313 | )
314 |
315 | if (mode == "train" and resume_epoch is not None) or "eval" in mode:
316 | if mode == "train":
317 | print("resuming from model at training epoch {}".format(resume_epoch))
318 | elif "eval" in mode:
319 | print("loading saved model...")
320 | model = diffusers.UNet2DModel.from_pretrained(os.path.join(config.output_dir, 'unet'), use_safetensors=True)
321 |
322 | model = nn.DataParallel(model)
323 | model.to(device)
324 |
325 | # define noise scheduler
326 | if model_type == "DDPM":
327 | noise_scheduler = diffusers.DDPMScheduler(num_train_timesteps=1000)
328 | elif model_type == "DDIM":
329 | noise_scheduler = diffusers.DDIMScheduler(num_train_timesteps=1000)
330 | else: raise ValueError("model_type \"{}\" not supported.".format(model_type))
331 | if mode == "train":
332 | # training setup
333 | optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
334 | lr_scheduler = get_cosine_schedule_with_warmup(
335 | optimizer=optimizer,
336 | num_warmup_steps=config.lr_warmup_steps,
337 | num_training_steps=(len(train_dataloader) * config.num_epochs),
338 | )
339 |
340 | # train
341 | train_loop(
342 | config,
343 | model,
344 | noise_scheduler,
345 | optimizer,
346 | train_dataloader,
347 | eval_dataloader,
348 | lr_scheduler,
349 | device=device
350 | )
351 | elif mode == "eval":
352 | """
353 | Default eval behavior:
354 | evaluate image generation or translation, optionally conditioned on segmentation masks.
355 | For a class-conditional model, sampling uses either naive class conditioning or
356 | classifier-free guidance (CFG).
357 |
358 | Further options are controlled by the eval_* arguments.
359 | """
360 | evaluate_generation(
361 | config,
362 | model,
363 | noise_scheduler,
364 | eval_dataloader,
365 | eval_mask_removal=eval_mask_removal,
366 | eval_blank_mask=eval_blank_mask,
367 | device=device
368 | )
369 |
370 | elif mode == "eval_many":
371 | """
372 | Generate many images and save each one individually to a directory.
373 | """
374 | evaluate_sample_many(
375 | eval_sample_size,
376 | config,
377 | model,
378 | noise_scheduler,
379 | eval_dataloader,
380 | device=device
381 | )
382 |
383 | else:
384 | raise ValueError("mode \"{}\" not supported.".format(mode))
385 |
386 |
387 | if __name__ == "__main__":
388 | # parse args:
389 | parser = ArgumentParser()
390 | parser.add_argument('--mode', type=str, default='train')
391 | parser.add_argument('--img_size', type=int, default=256)
392 | parser.add_argument('--num_img_channels', type=int, default=1)
393 | parser.add_argument('--dataset', type=str, default="breast_mri")
394 | parser.add_argument('--img_dir', type=str, default=None)
395 | parser.add_argument('--seg_dir', type=str, default=None)
396 | parser.add_argument('--model_type', type=str, default="DDPM")
397 | parser.add_argument('--segmentation_guided', action='store_true', help='use segmentation-guided training/sampling')
398 | parser.add_argument('--segmentation_channel_mode', type=str, default="single", help='single == all segmentations in one channel, multi == each segmentation in its own channel')
399 | parser.add_argument('--num_segmentation_classes', type=int, default=None, help='number of segmentation classes, including background')
400 | parser.add_argument('--train_batch_size', type=int, default=32)
401 | parser.add_argument('--eval_batch_size', type=int, default=8)
402 | parser.add_argument('--num_epochs', type=int, default=200)
403 | parser.add_argument('--resume_epoch', type=int, default=None, help='resume training starting at this epoch')
404 |
405 | # novel options
406 | parser.add_argument('--use_ablated_segmentations', action='store_true', help='use mask-ablated training and evaluation: randomly remove class(es) from the mask during training and sampling')
407 |
408 | # other options
409 | parser.add_argument('--eval_noshuffle_dataloader', action='store_true', help='if true, don\'t shuffle the eval dataloader')
410 |
411 | # args only used in eval
412 | parser.add_argument('--eval_mask_removal', action='store_true', help='if true, evaluate gradually removing anatomies from mask and re-sampling')
413 | parser.add_argument('--eval_blank_mask', action='store_true', help='if true, evaluate sampling conditioned on blank (zeros) masks')
414 | parser.add_argument('--eval_sample_size', type=int, default=1000, help='number of images to sample when using eval_many mode')
415 |
416 | args = parser.parse_args()
417 |
418 | main(
419 | args.mode,
420 | args.img_size,
421 | args.num_img_channels,
422 | args.dataset,
423 | args.img_dir,
424 | args.seg_dir,
425 | args.model_type,
426 | args.segmentation_guided,
427 | args.segmentation_channel_mode,
428 | args.num_segmentation_classes,
429 | args.train_batch_size,
430 | args.eval_batch_size,
431 | args.num_epochs,
432 | args.resume_epoch,
433 | args.use_ablated_segmentations,
434 | not args.eval_noshuffle_dataloader,
435 |
436 | # args only used in eval
437 | args.eval_mask_removal,
438 | args.eval_blank_mask,
439 | args.eval_sample_size
440 | )
441 |
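`main.py` preprocesses training images with `ToTensor()` followed by `Normalize` with mean 0.5 and standard deviation 0.5 per channel, which maps pixel values from [0, 1] to [-1, 1]; the sampling pipeline undoes this with `(image / 2 + 0.5).clamp(0, 1)`. The NumPy sketch below (hypothetical helper names) checks that the two mappings are inverses on [0, 1] and that the clamp only affects out-of-range model outputs.

```python
import numpy as np


def normalize(x01, mean=0.5, std=0.5):
    # per-channel (x - mean) / std, as Normalize does with mean = std = 0.5
    return (x01 - mean) / std


def to_unit_range(x):
    # inverse mapping applied to samples before conversion to images, clipped to [0, 1]
    return np.clip(x / 2 + 0.5, 0.0, 1.0)


pixels = np.linspace(0.0, 1.0, 5)
print(normalize(pixels))
print(to_unit_range(normalize(pixels)))
```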
--------------------------------------------------------------------------------
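The UNet's `in_channels` stacks the image channels with the segmentation channels: one extra channel in `single` mode (all classes in one mask), or one channel per entry of `--seg_dir` in `multi` mode. The sampling pipeline recovers the image-channel count by subtracting the segmentation channels again. A small pure-Python sketch of that arithmetic, with hypothetical function names:

```python
def unet_in_channels(num_img_channels, segmentation_guided, segmentation_channel_mode="single", num_seg_types=0):
    """Total UNet input channels: image channels plus segmentation channels."""
    if not segmentation_guided:
        return num_img_channels
    if segmentation_channel_mode == "single":
        return num_img_channels + 1  # all classes encoded in one mask channel
    if segmentation_channel_mode == "multi":
        return num_img_channels + num_seg_types  # one channel per segmentation type
    raise ValueError(f"unknown segmentation_channel_mode: {segmentation_channel_mode!r}")


def image_channel_count(in_channels, segmentation_channel_mode, num_seg_types=0):
    """Inverse used at sampling time (segmentation-guided models): strip the segmentation channels."""
    if segmentation_channel_mode == "single":
        return in_channels - 1
    if segmentation_channel_mode == "multi":
        return in_channels - num_seg_types
    raise ValueError(f"unknown segmentation_channel_mode: {segmentation_channel_mode!r}")


print(unet_in_channels(1, True, "multi", num_seg_types=3))  # 4
```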
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy==1.26.1
2 | matplotlib==3.8.0
3 | Pillow==10.0.1
4 | diffusers==0.21.4
5 | datasets==2.14.5
6 | tqdm==4.66.1
7 | torch
8 | torchvision
9 | tensorboard
10 |
--------------------------------------------------------------------------------
/training.py:
--------------------------------------------------------------------------------
1 | """
2 | training utils
3 | """
4 | from dataclasses import dataclass
5 | import math
6 | import os
7 | from pathlib import Path
8 | from tqdm.auto import tqdm
9 | import numpy as np
10 | from datetime import timedelta
11 |
12 | import torch
13 | from torch import nn
14 | import torch.nn.functional as F
15 | from torch.utils.tensorboard import SummaryWriter
16 |
17 | import diffusers
18 |
19 | from eval import evaluate, add_segmentations_to_noise, SegGuidedDDPMPipeline, SegGuidedDDIMPipeline
20 |
21 | @dataclass
22 | class TrainingConfig:
23 | model_type: str = "DDPM"
24 | image_size: int = 256 # the generated image resolution
25 | train_batch_size: int = 32
26 | eval_batch_size: int = 8 # how many images to sample during evaluation
27 | num_epochs: int = 200
28 | gradient_accumulation_steps: int = 1
29 | learning_rate: float = 1e-4
30 | lr_warmup_steps: int = 500
31 | save_image_epochs: int = 20
32 | save_model_epochs: int = 30
33 | mixed_precision: str = 'fp16' # `no` for float32, `fp16` for automatic mixed precision
34 | output_dir: str = None
35 |
36 | push_to_hub: bool = False # whether to upload the saved model to the HF Hub
37 | hub_private_repo: bool = False
38 | overwrite_output_dir: bool = True # overwrite the old model when re-running the notebook
39 | seed: int = 0
40 |
41 | # custom options
42 | segmentation_guided: bool = False
43 | segmentation_channel_mode: str = "single"
44 | num_segmentation_classes: int = None # INCLUDING background
45 | use_ablated_segmentations: bool = False
46 | dataset: str = "breast_mri"
47 | resume_epoch: int = None
48 |
49 | # EXPERIMENTAL/UNTESTED: classifier-free class guidance and image translation
50 | class_conditional: bool = False
51 | cfg_p_uncond: float = 0.2 # p_uncond in classifier-free guidance paper
52 | cfg_weight: float = 0.3 # w in the paper
53 | trans_noise_level: float = 0.5 # fraction of the T training timesteps to noise the source images to before translation denoising starts; e.g. 0.5 gives t = 500 for the default T = 1000
54 | use_cfg_for_eval_conditioning: bool = True # whether the main sampling loop uses classifier-free guidance or just naive class conditioning
55 | cfg_maskguidance_condmodel_only: bool = True # if using both mask guidance AND CFG, only give the mask to the conditional network
56 | # ^ giving the mask to both the unconditional and conditional models makes class guidance ineffective
57 | # (see "Classifier-free guidance resolution weighting." in the ControlNet paper)
58 |
59 |
60 | def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, eval_dataloader, lr_scheduler, device='cuda'):
61 | # Plain PyTorch training loop (no accelerate `prepare` step): the model is expected
62 | # to be on `device` already and wrapped in nn.DataParallel, since `model.module`
63 | # is used below to build the evaluation pipeline.
64 |
65 | global_step = 0
66 |
67 | # logging
68 | run_name = '{}-{}-{}'.format(config.model_type.lower(), config.dataset, config.image_size)
69 | if config.segmentation_guided:
70 | run_name += "-segguided"
71 | writer = SummaryWriter(comment=run_name)
72 |
73 | # for loading segs to condition on:
74 | eval_dataloader = iter(eval_dataloader)
75 |
76 | # train the model, starting from `config.resume_epoch` if resuming
77 | start_epoch = 0
78 | if config.resume_epoch is not None:
79 | start_epoch = config.resume_epoch
80 |
81 | for epoch in range(start_epoch, config.num_epochs):
82 | progress_bar = tqdm(total=len(train_dataloader))
83 | progress_bar.set_description(f"Epoch {epoch}")
84 |
85 | model.train()
86 |
87 | for step, batch in enumerate(train_dataloader):
88 | clean_images = batch['images']
89 | clean_images = clean_images.to(device)
90 |
91 | # Sample noise to add to the images
92 | noise = torch.randn(clean_images.shape).to(clean_images.device)
93 | bs = clean_images.shape[0]
94 |
95 | # Sample a random timestep for each image
96 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device).long()
97 |
98 | # Add noise to the clean images according to the noise magnitude at each timestep
99 | # (this is the forward diffusion process)
100 | noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)
101 |
102 | if config.segmentation_guided:
103 | noisy_images = add_segmentations_to_noise(noisy_images, batch, config, device)
104 |
105 | # Predict the noise residual
106 | if config.class_conditional:
107 | class_labels = torch.ones(noisy_images.size(0)).long().to(device)
108 | # classifier-free guidance
109 | a = np.random.uniform()
110 | if a <= config.cfg_p_uncond:
111 | class_labels = torch.zeros_like(class_labels).long()
112 | noise_pred = model(noisy_images, timesteps, class_labels=class_labels, return_dict=False)[0]
113 | else:
114 | noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
115 | loss = F.mse_loss(noise_pred, noise)
116 | loss.backward()
117 |
118 | nn.utils.clip_grad_norm_(model.parameters(), 1.0)
119 | optimizer.step()
120 | lr_scheduler.step()
121 | optimizer.zero_grad()
122 |
123 | # also train on target domain images if conditional
124 | # (we don't have masks for this domain, so we can't do segmentation-guided; just use blank masks)
125 | if config.class_conditional:
126 | target_domain_images = batch['images_target']
127 | target_domain_images = target_domain_images.to(device)
128 |
129 | # Sample noise to add to the images
130 | noise = torch.randn(target_domain_images.shape).to(target_domain_images.device)
131 | bs = target_domain_images.shape[0]
132 |
133 | # Sample a random timestep for each image
134 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bs,), device=target_domain_images.device).long()
135 |
136 | # Add noise to the clean images according to the noise magnitude at each timestep
137 | # (this is the forward diffusion process)
138 | noisy_images = noise_scheduler.add_noise(target_domain_images, noise, timesteps)
139 |
140 | if config.segmentation_guided:
141 | # no masks in target domain so just use blank masks
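142 | noisy_images = torch.cat((noisy_images, torch.zeros((noisy_images.size(0), model.module.config.in_channels - noisy_images.size(1), *noisy_images.shape[2:]), device=noisy_images.device, dtype=noisy_images.dtype)), dim=1)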
143 |
144 | # Predict the noise residual
145 | class_labels = torch.full([noisy_images.size(0)], 2).long().to(device)
146 | # classifier-free guidance
147 | a = np.random.uniform()
148 | if a <= config.cfg_p_uncond:
149 | class_labels = torch.zeros_like(class_labels).long()
150 | noise_pred = model(noisy_images, timesteps, class_labels=class_labels, return_dict=False)[0]
151 | loss_target_domain = F.mse_loss(noise_pred, noise)
152 | loss_target_domain.backward()
153 |
154 | nn.utils.clip_grad_norm_(model.parameters(), 1.0)
155 | optimizer.step()
156 | lr_scheduler.step()
157 | optimizer.zero_grad()
158 |
159 | progress_bar.update(1)
160 | if config.class_conditional:
161 | logs = {"loss": loss.detach().item(), "loss_target_domain": loss_target_domain.detach().item(),
162 | "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
163 | writer.add_scalar("loss_target_domain", loss_target_domain.detach().item(), global_step)
164 | else:
165 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0], "step": global_step}
166 | writer.add_scalar("loss", loss.detach().item(), global_step)
167 |
168 | progress_bar.set_postfix(**logs)
169 | global_step += 1
170 |
171 | # After each epoch, build a sampling pipeline; periodically (and on the last epoch) sample demo images with evaluate() and save the model
172 | if config.model_type == "DDPM":
173 | if config.segmentation_guided:
174 | pipeline = SegGuidedDDPMPipeline(
175 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config
176 | )
177 | else:
178 | if config.class_conditional:
179 | raise NotImplementedError("TODO: Conditional training not implemented for non-seg-guided DDPM")
180 | else:
181 | pipeline = diffusers.DDPMPipeline(unet=model.module, scheduler=noise_scheduler)
182 | elif config.model_type == "DDIM":
183 | if config.segmentation_guided:
184 | pipeline = SegGuidedDDIMPipeline(
185 | unet=model.module, scheduler=noise_scheduler, eval_dataloader=eval_dataloader, external_config=config
186 | )
187 | else:
188 | if config.class_conditional:
189 | raise NotImplementedError("TODO: Conditional training not implemented for non-seg-guided DDIM")
190 | else:
191 | pipeline = diffusers.DDIMPipeline(unet=model.module, scheduler=noise_scheduler)
192 |
193 | model.eval()
194 |
195 | if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:
196 | if config.segmentation_guided:
197 | seg_batch = next(eval_dataloader)
198 | evaluate(config, epoch, pipeline, seg_batch)
199 | else:
200 | evaluate(config, epoch, pipeline)
201 |
202 | if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
203 | pipeline.save_pretrained(config.output_dir)
204 |
--------------------------------------------------------------------------------
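The target-domain training step and the per-epoch save schedule in `training.py` can be sketched in dependency-free Python to make the arithmetic explicit. This is a minimal illustration under stated assumptions, not the repo's code: it assumes a linear beta schedule with the diffusers `DDPMScheduler` defaults (1000 steps, betas from 1e-4 to 0.02), works on flat lists instead of tensors, and the helpers `alphas_cumprod`, `add_noise`, `drop_labels_for_cfg` and `triggers_on` are hypothetical names chosen for this sketch.

```python
import math
import random

# Assumed schedule: diffusers DDPMScheduler defaults (linear betas from 1e-4
# to 0.02 over 1000 steps). The repo's real scheduler config may differ.
NUM_TRAIN_TIMESTEPS = 1000
BETA_START, BETA_END = 1e-4, 0.02


def alphas_cumprod(num_steps=NUM_TRAIN_TIMESTEPS, beta_start=BETA_START, beta_end=BETA_END):
    """alpha_bar_t = prod_{s <= t} (1 - beta_s) for linearly spaced betas."""
    out, prod = [], 1.0
    for t in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * t / (num_steps - 1)
        prod *= 1.0 - beta
        out.append(prod)
    return out


def add_noise(x0, eps, t, alpha_bar):
    """Forward diffusion: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    a = alpha_bar[t]
    return [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]


def drop_labels_for_cfg(labels, p_uncond, rng):
    """One uniform draw per batch, as in the training loop: if it is <= p_uncond,
    every label in the batch becomes the null (unconditional) class 0."""
    if rng.random() <= p_uncond:
        return [0] * len(labels)
    return list(labels)


def triggers_on(epoch, every, num_epochs):
    """Mirrors the save_image_epochs / save_model_epochs checks (0-indexed epochs)."""
    return (epoch + 1) % every == 0 or epoch == num_epochs - 1


rng = random.Random(0)
alpha_bar = alphas_cumprod()
x0 = [0.5, -0.5, 1.0]                     # stand-in "clean image" pixels
eps = [rng.gauss(0.0, 1.0) for _ in x0]   # epsilon ~ N(0, I), the MSE regression target
t = rng.randrange(NUM_TRAIN_TIMESTEPS)    # random timestep, like torch.randint
x_t = add_noise(x0, eps, t, alpha_bar)
labels = drop_labels_for_cfg([2, 2, 2], p_uncond=0.2, rng=rng)  # 2 = target domain
print([e for e in range(10) if triggers_on(e, every=4, num_epochs=10)])  # [3, 7, 9]
```

Because `a` is drawn once per batch in `training.py`, label dropout there is all-or-nothing for the whole batch; dropping labels independently per sample is a common alternative, but it is not what the listed code does.
--------------------------------------------------------------------------------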
/utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | def make_grid(images, rows, cols):
4 |     w, h = images[0].size
5 |     grid = Image.new('RGB', size=(cols*w, rows*h))
6 |     for i, image in enumerate(images):
7 |         grid.paste(image, box=(i%cols*w, i//cols*h))
8 |     return grid
9 |
--------------------------------------------------------------------------------
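`make_grid` fills its canvas row-major: image `i` lands in column `i % cols` and row `i // cols`, so its top-left corner is `(i % cols * w, i // cols * h)`. The sketch below reproduces only that box arithmetic without PIL; `grid_boxes` is a hypothetical helper, and like `make_grid` it assumes every image has the same size as the first one. The explicit bounds check is an addition for illustration; `make_grid` itself does not perform it.

```python
def grid_boxes(n_images, rows, cols, w, h):
    """Top-left paste coordinates used by make_grid, in row-major order."""
    if n_images > rows * cols:
        # make_grid does not check this; extra images would be pasted at
        # y >= rows*h, i.e. outside the canvas.
        raise ValueError("more images than grid cells")
    return [(i % cols * w, i // cols * h) for i in range(n_images)]


print(grid_boxes(4, rows=2, cols=2, w=64, h=64))
# [(0, 0), (64, 0), (0, 64), (64, 64)]
```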