├── LICENSE
├── README.md
├── datasets
├── README.md
└── lsun_bedroom.py
├── evaluations
├── README.md
├── evaluator.py
└── requirements.txt
├── figures
├── .DS_Store
└── vis.jpg
├── infer_mdt.py
├── masked_diffusion
├── __init__.py
├── diffusion_utils.py
├── dist_util.py
├── fp16_util.py
├── gaussian_diffusion.py
├── image_datasets.py
├── logger.py
├── models.py
├── nn.py
├── resample.py
├── respace.py
├── script_util.py
├── timestep_sampler.py
└── train_util.py
├── run.sh
├── run_ddp_master.sh
├── run_ddp_worker.sh
├── run_sample.sh
├── scripts
├── image_sample.py
└── image_train.py
└── setup.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Masked Diffusion Transformer V2
2 |
3 | [](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?p=masked-diffusion-transformer-is-a-strong)
4 | [](https://huggingface.co/spaces/shgao/MDT)
5 |
6 | The official codebase for [Masked Diffusion Transformer is a Strong Image Synthesizer](https://arxiv.org/abs/2303.14389).
7 |
8 | ## MDTv2: Faster Convergeence & Stronger performance
9 | **MDTv2 achieves superior image synthesis performance, e.g., a new SOTA FID score of 1.58 on the ImageNet dataset, and has more than 10× faster learning speed than the previous SOTA DiT.**
10 |
11 | MDTv2 demonstrates a 5x acceleration compared to the original MDT.
12 |
13 | [MDTv1 code](https://github.com/sail-sg/MDT/tree/mdtv1)
14 | ## Introduction
15 |
16 | Despite its success in image synthesis, we observe that diffusion probabilistic models (DPMs) often lack contextual reasoning ability to learn the relations among object parts in an image, leading to a slow learning process. To solve this issue, we propose a Masked Diffusion Transformer (MDT) that introduces a mask latent modeling scheme to explicitly enhance the DPMs’ ability to contextual relation learning among object semantic parts in an image.
17 |
18 | During training, MDT operates in the latent space to mask certain tokens. Then, an asymmetric diffusion transformer is designed to predict masked tokens from unmasked ones while maintaining the diffusion generation process. Our MDT can reconstruct the full information of an image from its incomplete contextual input, thus enabling it to learn the associated relations among image tokens. We further improve MDT with a more efficient macro network structure and training strategy, named MDTv2.
19 |
20 | Experimental results show that MDTv2 achieves superior image synthesis performance, e.g., **a new SOTA FID score of 1.58 on the ImageNet dataset, and has more than 10× faster learning speed than the previous SOTA DiT**.
21 |
22 |
23 |
24 | # Performance
25 |
26 | | Model| Dataset | Resolution | FID-50K | Inception Score |
27 | |---------|----------|-----------|---------|--------|
28 | |MDT-XL/2 | ImageNet | 256x256 | 1.79 | 283.01|
29 | |MDTv2-XL/2 | ImageNet | 256x256 | 1.58 | 314.73|
30 |
31 | [Pretrained model download](https://huggingface.co/shgao/MDT-XL2/tree/main)
32 |
33 | Model is hosted on hugglingface, you can also download it with:
34 | ```
35 | from huggingface_hub import snapshot_download
36 | models_path = snapshot_download("shgao/MDT-XL2")
37 | ckpt_model_path = os.path.join(models_path, "mdt_xl2_v1_ckpt.pt")
38 | ```
39 | A hugglingface demo is on [DEMO](https://huggingface.co/spaces/shgao/MDT).
40 |
41 | **NEW SOTA on FID.**
42 | # Setup
43 |
44 | Prepare the Pytorch >=2.0 version. Download and install this repo.
45 |
46 | ```
47 | git clone https://github.com/sail-sg/MDT
48 | cd MDT
49 | pip install -e .
50 | ```
51 | Install [Adan optimizer](https://github.com/sail-sg/Adan), Adan is a strong optimizer with faster convergence speed than AdamW. [(paper)](https://arxiv.org/abs/2208.06677)
52 | ```
53 | python -m pip install git+https://github.com/sail-sg/Adan.git
54 | ```
55 |
56 | **DATA**
57 | - For standard datasets like ImageNet and CIFAR, please refer to '[dataset](https://github.com/sail-sg/MDT/tree/main/datasets)' for preparation.
58 | - When using customized dataset, change the image file name to `ClassID_ImgID.jpg`,
59 | as the [ADM's dataloder](https://github.com/openai/guided-diffusion) gets the class ID from the file name.
60 |
61 | # Training
62 |
63 |
64 | Training on one node (`run.sh`).
65 |
66 | ```shell
67 | export OPENAI_LOGDIR=output_mdtv2_s2
68 | NUM_GPUS=8
69 |
70 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 6 --model MDTv2_S_2"
71 | DIFFUSION_FLAGS="--diffusion_steps 1000"
72 | TRAIN_FLAGS="--batch_size 32"
73 | DATA_PATH=/dataset/imagenet
74 |
75 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
76 | ```
77 |
78 |
79 |
80 |
81 | Training on multiple nodes (`run_ddp_master.sh` and `run_ddp_worker.sh`).
82 |
83 | ```shell
84 | # On master:
85 | export OPENAI_LOGDIR=output_mdtv2_xl2
86 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
87 | DIFFUSION_FLAGS="--diffusion_steps 1000"
88 | TRAIN_FLAGS="--batch_size 4"
89 | DATA_PATH=/dataset/imagenet
90 | NUM_NODE=8
91 | GPU_PRE_NODE=8
92 |
93 | python -m torch.distributed.launch --master_addr=$(hostname) --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PRE_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
94 |
95 | # On workers:
96 | export OPENAI_LOGDIR=output_mdtv2_xl2
97 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
98 | DIFFUSION_FLAGS="--diffusion_steps 1000"
99 | TRAIN_FLAGS="--batch_size 4"
100 | DATA_PATH=/dataset/imagenet
101 | NUM_NODE=8
102 | GPU_PRE_NODE=8
103 |
104 | python -m torch.distributed.launch --master_addr=$MASTER_ADDR --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PRE_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
105 |
106 |
107 | ```
108 |
109 |
110 |
111 | # Evaluation
112 |
113 | The evaluation code is obtained from [ADM's TensorFlow evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations).
114 | Please follow the instructions in the `evaluations` folder to set up the evaluation environment.
115 |
116 |
117 | Sampling and Evaluation (`run_sample.sh`):
118 |
119 | ```shell
120 | MODEL_PATH=output_mdtv2_xl2/mdt_xl2_v2_ckpt.pt
121 | export OPENAI_LOGDIR=output_mdtv2_xl2_eval
122 | NUM_GPUS=8
123 |
124 | echo 'CFG Class-conditional sampling:'
125 | MODEL_FLAGS="--image_size 256 --model MDTv2_XL_2 --decode_layer 4"
126 | DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000 --cfg_cond True"
127 | echo $MODEL_FLAGS
128 | echo $DIFFUSION_FLAGS
129 | echo $MODEL_PATH
130 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_sample.py --model_path $MODEL_PATH $MODEL_FLAGS $DIFFUSION_FLAGS
131 | echo $MODEL_FLAGS
132 | echo $DIFFUSION_FLAGS
133 | echo $MODEL_PATH
134 | python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz
135 |
136 | echo 'Class-conditional sampling:'
137 | MODEL_FLAGS="--image_size 256 --model MDTv2_XL_2 --decode_layer 4"
138 | DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000"
139 | echo $MODEL_FLAGS
140 | echo $DIFFUSION_FLAGS
141 | echo $MODEL_PATH
142 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_sample.py --model_path $MODEL_PATH $MODEL_FLAGS $DIFFUSION_FLAGS
143 | echo $MODEL_FLAGS
144 | echo $DIFFUSION_FLAGS
145 | echo $MODEL_PATH
146 | python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz
147 | ```
148 |
149 |
150 |
151 | # Visualization
152 |
153 | Run the `infer_mdt.py` to generate images.
154 |
155 | # Citation
156 |
157 | ```
158 | @misc{gao2023masked,
159 | title={Masked Diffusion Transformer is a Strong Image Synthesizer},
160 | author={Shanghua Gao and Pan Zhou and Ming-Ming Cheng and Shuicheng Yan},
161 | year={2023},
162 | eprint={2303.14389},
163 | archivePrefix={arXiv},
164 | primaryClass={cs.CV}
165 | }
166 | ```
167 |
168 | # Acknowledgement
169 |
170 | This codebase is built based on the [DiT](https://github.com/facebookresearch/dit) and [ADM](https://github.com/openai/guided-diffusion). Thanks!
171 |
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Downloading datasets
2 |
3 | This directory includes instructions and scripts for downloading ImageNet and LSUN bedrooms for use in this codebase.
4 |
5 | ## Class-conditional ImageNet
6 |
7 | For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](http://www.image-net.org/challenges/LSVRC/2012/downloads) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class.
8 |
9 | Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files. You need to extract each of these, which may be impractical to do by hand on your operating system. To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script:
10 |
11 | ```
12 | for file in *.tar; do tar xf "$file"; rm "$file"; done
13 | ```
14 |
15 | This will extract and remove each tar file in turn.
16 |
17 | Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with WNID (class ids) followed by underscores, like `n01440764_2708.JPEG`. Conveniently (but not by accident) this is how the automated data-loader expects to discover class labels.
18 |
19 | ## LSUN bedroom
20 |
21 | To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script `python3 download.py bedroom`. The result will be an "lmdb" database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so:
22 |
23 | ```
24 | python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir
25 | ```
26 |
27 | This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument.
28 |
--------------------------------------------------------------------------------
/datasets/lsun_bedroom.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert an LSUN lmdb database into a directory of images.
3 | """
4 |
5 | import argparse
6 | import io
7 | import os
8 |
9 | from PIL import Image
10 | import lmdb
11 | import numpy as np
12 |
13 |
14 | def read_images(lmdb_path, image_size):
15 | env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True)
16 | with env.begin(write=False) as transaction:
17 | cursor = transaction.cursor()
18 | for _, webp_data in cursor:
19 | img = Image.open(io.BytesIO(webp_data))
20 | width, height = img.size
21 | scale = image_size / min(width, height)
22 | img = img.resize(
23 | (int(round(scale * width)), int(round(scale * height))),
24 | resample=Image.BOX,
25 | )
26 | arr = np.array(img)
27 | h, w, _ = arr.shape
28 | h_off = (h - image_size) // 2
29 | w_off = (w - image_size) // 2
30 | arr = arr[h_off : h_off + image_size, w_off : w_off + image_size]
31 | yield arr
32 |
33 |
34 | def dump_images(out_dir, images, prefix):
35 | if not os.path.exists(out_dir):
36 | os.mkdir(out_dir)
37 | for i, img in enumerate(images):
38 | Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png"))
39 |
40 |
41 | def main():
42 | parser = argparse.ArgumentParser()
43 | parser.add_argument("--image-size", help="new image size", type=int, default=256)
44 | parser.add_argument("--prefix", help="class name", type=str, default="bedroom")
45 | parser.add_argument("lmdb_path", help="path to an LSUN lmdb database")
46 | parser.add_argument("out_dir", help="path to output directory")
47 | args = parser.parse_args()
48 |
49 | images = read_images(args.lmdb_path, args.image_size)
50 | dump_images(args.out_dir, images, args.prefix)
51 |
52 |
53 | if __name__ == "__main__":
54 | main()
55 |
--------------------------------------------------------------------------------
/evaluations/README.md:
--------------------------------------------------------------------------------
1 | # Evaluations
2 |
3 | To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files.
4 |
5 | # Download batches
6 |
7 | We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format.
8 |
9 | Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall.
10 |
11 | Here are links to download all of the sample and reference batches:
12 |
13 | * LSUN
14 | * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz)
15 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz)
16 | * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz)
17 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz)
18 | * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz)
19 | * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz)
20 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz)
21 | * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz)
22 | * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz)
23 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz)
24 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz)
25 |
26 | * ImageNet
27 | * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz)
28 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz)
29 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz)
30 | * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz)
31 | * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz)
32 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz)
33 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz)
34 | * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz)
35 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz)
36 | * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz)
37 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz)
38 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz)
39 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz)
40 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz)
41 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz)
42 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz)
43 | * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz)
44 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz)
45 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz)
46 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz)
47 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz)
48 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz)
49 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz)
50 |
51 | # Run evaluations
52 |
53 | First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the refernce batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`.
54 |
55 | Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB.
56 |
57 | The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging:
58 |
59 | ```
60 | $ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz
61 | ...
62 | computing reference batch activations...
63 | computing/reading reference batch statistics...
64 | computing sample batch activations...
65 | computing/reading sample batch statistics...
66 | Computing evaluations...
67 | Inception Score: 215.8370361328125
68 | FID: 3.9425574129223264
69 | sFID: 6.140433703346162
70 | Precision: 0.8265
71 | Recall: 0.5309
72 | ```
73 |
--------------------------------------------------------------------------------
/evaluations/evaluator.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import io
3 | import os
4 | import random
5 | import warnings
6 | import zipfile
7 | from abc import ABC, abstractmethod
8 | from contextlib import contextmanager
9 | from functools import partial
10 | from multiprocessing import cpu_count
11 | from multiprocessing.pool import ThreadPool
12 | from typing import Iterable, Optional, Tuple
13 |
14 | import numpy as np
15 | import requests
16 | import tensorflow.compat.v1 as tf
17 | from scipy import linalg
18 | from tqdm.auto import tqdm
19 |
20 | INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb"
21 | INCEPTION_V3_PATH = "classify_image_graph_def.pb"
22 |
23 | FID_POOL_NAME = "pool_3:0"
24 | FID_SPATIAL_NAME = "mixed_6/conv:0"
25 |
26 |
27 | def main():
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("ref_batch", help="path to reference batch npz file")
30 | parser.add_argument("sample_batch", help="path to sample batch npz file")
31 | args = parser.parse_args()
32 |
33 | config = tf.ConfigProto(
34 | allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph
35 | )
36 | config.gpu_options.allow_growth = True
37 | evaluator = Evaluator(tf.Session(config=config))
38 |
39 | print("warming up TensorFlow...")
40 | # This will cause TF to print a bunch of verbose stuff now rather
41 | # than after the next print(), to help prevent confusion.
42 | evaluator.warmup()
43 |
44 | print("computing reference batch activations...")
45 | ref_acts = evaluator.read_activations(args.ref_batch)
46 | print("computing/reading reference batch statistics...")
47 | ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts)
48 |
49 | print("computing sample batch activations...")
50 | sample_acts = evaluator.read_activations(args.sample_batch)
51 | print("computing/reading sample batch statistics...")
52 | sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts)
53 |
54 | print("Computing evaluations...")
55 | print("Inception Score:", evaluator.compute_inception_score(sample_acts[0]))
56 | print("FID:", sample_stats.frechet_distance(ref_stats))
57 | print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial))
58 | prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0])
59 | print("Precision:", prec)
60 | print("Recall:", recall)
61 |
62 |
63 | class InvalidFIDException(Exception):
64 | pass
65 |
66 |
67 | class FIDStatistics:
68 | def __init__(self, mu: np.ndarray, sigma: np.ndarray):
69 | self.mu = mu
70 | self.sigma = sigma
71 |
72 | def frechet_distance(self, other, eps=1e-6):
73 | """
74 | Compute the Frechet distance between two sets of statistics.
75 | """
76 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132
77 | mu1, sigma1 = self.mu, self.sigma
78 | mu2, sigma2 = other.mu, other.sigma
79 |
80 | mu1 = np.atleast_1d(mu1)
81 | mu2 = np.atleast_1d(mu2)
82 |
83 | sigma1 = np.atleast_2d(sigma1)
84 | sigma2 = np.atleast_2d(sigma2)
85 |
86 | assert (
87 | mu1.shape == mu2.shape
88 | ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}"
89 | assert (
90 | sigma1.shape == sigma2.shape
91 | ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}"
92 |
93 | diff = mu1 - mu2
94 |
95 | # product might be almost singular
96 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
97 | if not np.isfinite(covmean).all():
98 | msg = (
99 | "fid calculation produces singular product; adding %s to diagonal of cov estimates"
100 | % eps
101 | )
102 | warnings.warn(msg)
103 | offset = np.eye(sigma1.shape[0]) * eps
104 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
105 |
106 | # numerical error might give slight imaginary component
107 | if np.iscomplexobj(covmean):
108 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
109 | m = np.max(np.abs(covmean.imag))
110 | raise ValueError("Imaginary component {}".format(m))
111 | covmean = covmean.real
112 |
113 | tr_covmean = np.trace(covmean)
114 |
115 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
116 |
117 |
118 | class Evaluator:
119 | def __init__(
120 | self,
121 | session,
122 | batch_size=64,
123 | softmax_batch_size=512,
124 | ):
125 | self.sess = session
126 | self.batch_size = batch_size
127 | self.softmax_batch_size = softmax_batch_size
128 | self.manifold_estimator = ManifoldEstimator(session)
129 | with self.sess.graph.as_default():
130 | self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3])
131 | self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048])
132 | self.pool_features, self.spatial_features = _create_feature_graph(self.image_input)
133 | self.softmax = _create_softmax_graph(self.softmax_input)
134 |
135 | def warmup(self):
136 | self.compute_activations(np.zeros([1, 8, 64, 64, 3]))
137 |
138 | def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]:
139 | with open_npz_array(npz_path, "arr_0") as reader:
140 | return self.compute_activations(reader.read_batches(self.batch_size))
141 |
142 | def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
143 | """
144 | Compute image features for downstream evals.
145 |
146 | :param batches: a iterator over NHWC numpy arrays in [0, 255].
147 | :return: a tuple of numpy arrays of shape [N x X], where X is a feature
148 | dimension. The tuple is (pool_3, spatial).
149 | """
150 | preds = []
151 | spatial_preds = []
152 | for batch in tqdm(batches):
153 | batch = batch.astype(np.float32)
154 | pred, spatial_pred = self.sess.run(
155 | [self.pool_features, self.spatial_features], {self.image_input: batch}
156 | )
157 | preds.append(pred.reshape([pred.shape[0], -1]))
158 | spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1]))
159 | return (
160 | np.concatenate(preds, axis=0),
161 | np.concatenate(spatial_preds, axis=0),
162 | )
163 |
164 | def read_statistics(
165 | self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray]
166 | ) -> Tuple[FIDStatistics, FIDStatistics]:
167 | obj = np.load(npz_path)
168 | if "mu" in list(obj.keys()):
169 | return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics(
170 | obj["mu_s"], obj["sigma_s"]
171 | )
172 | return tuple(self.compute_statistics(x) for x in activations)
173 |
174 | def compute_statistics(self, activations: np.ndarray) -> FIDStatistics:
175 | mu = np.mean(activations, axis=0)
176 | sigma = np.cov(activations, rowvar=False)
177 | return FIDStatistics(mu, sigma)
178 |
179 | def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
180 | softmax_out = []
181 | for i in range(0, len(activations), self.softmax_batch_size):
182 | acts = activations[i : i + self.softmax_batch_size]
183 | softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
184 | preds = np.concatenate(softmax_out, axis=0)
185 | # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
186 | scores = []
187 | for i in range(0, len(preds), split_size):
188 | part = preds[i : i + split_size]
189 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
190 | kl = np.mean(np.sum(kl, 1))
191 | scores.append(np.exp(kl))
192 | return float(np.mean(scores))
193 |
194 | def compute_prec_recall(
195 | self, activations_ref: np.ndarray, activations_sample: np.ndarray
196 | ) -> Tuple[float, float]:
197 | radii_1 = self.manifold_estimator.manifold_radii(activations_ref)
198 | radii_2 = self.manifold_estimator.manifold_radii(activations_sample)
199 | pr = self.manifold_estimator.evaluate_pr(
200 | activations_ref, radii_1, activations_sample, radii_2
201 | )
202 | return (float(pr[0][0]), float(pr[1][0]))
203 |
204 |
205 | class ManifoldEstimator:
206 | """
207 | A helper for comparing manifolds of feature vectors.
208 |
209 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57
210 | """
211 |
212 | def __init__(
213 | self,
214 | session,
215 | row_batch_size=10000,
216 | col_batch_size=10000,
217 | nhood_sizes=(3,),
218 | clamp_to_percentile=None,
219 | eps=1e-5,
220 | ):
221 | """
222 | Estimate the manifold of given feature vectors.
223 |
224 | :param session: the TensorFlow session.
225 | :param row_batch_size: row batch size to compute pairwise distances
226 | (parameter to trade-off between memory usage and performance).
227 | :param col_batch_size: column batch size to compute pairwise distances.
228 | :param nhood_sizes: number of neighbors used to estimate the manifold.
229 | :param clamp_to_percentile: prune hyperspheres that have radius larger than
230 | the given percentile.
231 | :param eps: small number for numerical stability.
232 | """
233 | self.distance_block = DistanceBlock(session)
234 | self.row_batch_size = row_batch_size
235 | self.col_batch_size = col_batch_size
236 | self.nhood_sizes = nhood_sizes
237 | self.num_nhoods = len(nhood_sizes)
238 | self.clamp_to_percentile = clamp_to_percentile
239 | self.eps = eps
240 |
241 | def warmup(self):
242 | feats, radii = (
243 | np.zeros([1, 2048], dtype=np.float32),
244 | np.zeros([1, 1], dtype=np.float32),
245 | )
246 | self.evaluate_pr(feats, radii, feats, radii)
247 |
248 | def manifold_radii(self, features: np.ndarray) -> np.ndarray:
249 | num_images = len(features)
250 |
251 | # Estimate manifold of features by calculating distances to k-NN of each sample.
252 | radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32)
253 | distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32)
254 | seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32)
255 |
256 | for begin1 in range(0, num_images, self.row_batch_size):
257 | end1 = min(begin1 + self.row_batch_size, num_images)
258 | row_batch = features[begin1:end1]
259 |
260 | for begin2 in range(0, num_images, self.col_batch_size):
261 | end2 = min(begin2 + self.col_batch_size, num_images)
262 | col_batch = features[begin2:end2]
263 |
264 | # Compute distances between batches.
265 | distance_batch[
266 | 0 : end1 - begin1, begin2:end2
267 | ] = self.distance_block.pairwise_distances(row_batch, col_batch)
268 |
269 | # Find the k-nearest neighbor from the current batch.
270 | radii[begin1:end1, :] = np.concatenate(
271 | [
272 | x[:, self.nhood_sizes]
273 | for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1)
274 | ],
275 | axis=0,
276 | )
277 |
278 | if self.clamp_to_percentile is not None:
279 | max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0)
280 | radii[radii > max_distances] = 0
281 | return radii
282 |
283 | def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray):
284 | """
285 | Evaluate if new feature vectors are at the manifold.
286 | """
287 | num_eval_images = eval_features.shape[0]
288 | num_ref_images = radii.shape[0]
289 | distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32)
290 | batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32)
291 | max_realism_score = np.zeros([num_eval_images], dtype=np.float32)
292 | nearest_indices = np.zeros([num_eval_images], dtype=np.int32)
293 |
294 | for begin1 in range(0, num_eval_images, self.row_batch_size):
295 | end1 = min(begin1 + self.row_batch_size, num_eval_images)
296 | feature_batch = eval_features[begin1:end1]
297 |
298 | for begin2 in range(0, num_ref_images, self.col_batch_size):
299 | end2 = min(begin2 + self.col_batch_size, num_ref_images)
300 | ref_batch = features[begin2:end2]
301 |
302 | distance_batch[
303 | 0 : end1 - begin1, begin2:end2
304 | ] = self.distance_block.pairwise_distances(feature_batch, ref_batch)
305 |
306 | # From the minibatch of new feature vectors, determine if they are in the estimated manifold.
307 | # If a feature vector is inside a hypersphere of some reference sample, then
308 | # the new sample lies at the estimated manifold.
309 | # The radii of the hyperspheres are determined from distances of neighborhood size k.
310 | samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
311 | batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
312 |
313 | max_realism_score[begin1:end1] = np.max(
314 | radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
315 | )
316 | nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
317 |
318 | return {
319 | "fraction": float(np.mean(batch_predictions)),
320 | "batch_predictions": batch_predictions,
321 | "max_realisim_score": max_realism_score,
322 | "nearest_indices": nearest_indices,
323 | }
324 |
325 | def evaluate_pr(
326 | self,
327 | features_1: np.ndarray,
328 | radii_1: np.ndarray,
329 | features_2: np.ndarray,
330 | radii_2: np.ndarray,
331 | ) -> Tuple[np.ndarray, np.ndarray]:
332 | """
333 | Evaluate precision and recall efficiently.
334 |
335 | :param features_1: [N1 x D] feature vectors for reference batch.
336 | :param radii_1: [N1 x K1] radii for reference vectors.
337 | :param features_2: [N2 x D] feature vectors for the other batch.
338 | :param radii_2: [N x K2] radii for other vectors.
339 | :return: a tuple of arrays for (precision, recall):
340 | - precision: an np.ndarray of length K1
341 | - recall: an np.ndarray of length K2
342 | """
343 | features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=np.bool)
344 | features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=np.bool)
345 | for begin_1 in range(0, len(features_1), self.row_batch_size):
346 | end_1 = begin_1 + self.row_batch_size
347 | batch_1 = features_1[begin_1:end_1]
348 | for begin_2 in range(0, len(features_2), self.col_batch_size):
349 | end_2 = begin_2 + self.col_batch_size
350 | batch_2 = features_2[begin_2:end_2]
351 | batch_1_in, batch_2_in = self.distance_block.less_thans(
352 | batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
353 | )
354 | features_1_status[begin_1:end_1] |= batch_1_in
355 | features_2_status[begin_2:end_2] |= batch_2_in
356 | return (
357 | np.mean(features_2_status.astype(np.float64), axis=0),
358 | np.mean(features_1_status.astype(np.float64), axis=0),
359 | )
360 |
361 |
362 | class DistanceBlock:
363 | """
364 | Calculate pairwise distances between vectors.
365 |
366 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
367 | """
368 |
369 | def __init__(self, session):
370 | self.session = session
371 |
372 | # Initialize TF graph to calculate pairwise distances.
373 | with session.graph.as_default():
374 | self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
375 | self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
376 | distance_block_16 = _batch_pairwise_distances(
377 | tf.cast(self._features_batch1, tf.float16),
378 | tf.cast(self._features_batch2, tf.float16),
379 | )
380 | self.distance_block = tf.cond(
381 | tf.reduce_all(tf.math.is_finite(distance_block_16)),
382 | lambda: tf.cast(distance_block_16, tf.float32),
383 | lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
384 | )
385 |
386 | # Extra logic for less thans.
387 | self._radii1 = tf.placeholder(tf.float32, shape=[None, None])
388 | self._radii2 = tf.placeholder(tf.float32, shape=[None, None])
389 | dist32 = tf.cast(self.distance_block, tf.float32)[..., None]
390 | self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1)
391 | self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0)
392 |
393 | def pairwise_distances(self, U, V):
394 | """
395 | Evaluate pairwise distances between two batches of feature vectors.
396 | """
397 | return self.session.run(
398 | self.distance_block,
399 | feed_dict={self._features_batch1: U, self._features_batch2: V},
400 | )
401 |
402 | def less_thans(self, batch_1, radii_1, batch_2, radii_2):
403 | return self.session.run(
404 | [self._batch_1_in, self._batch_2_in],
405 | feed_dict={
406 | self._features_batch1: batch_1,
407 | self._features_batch2: batch_2,
408 | self._radii1: radii_1,
409 | self._radii2: radii_2,
410 | },
411 | )
412 |
413 |
414 | def _batch_pairwise_distances(U, V):
415 | """
416 | Compute pairwise distances between two batches of feature vectors.
417 | """
418 | with tf.variable_scope("pairwise_dist_block"):
419 | # Squared norms of each row in U and V.
420 | norm_u = tf.reduce_sum(tf.square(U), 1)
421 | norm_v = tf.reduce_sum(tf.square(V), 1)
422 |
423 | # norm_u as a column and norm_v as a row vectors.
424 | norm_u = tf.reshape(norm_u, [-1, 1])
425 | norm_v = tf.reshape(norm_v, [1, -1])
426 |
427 | # Pairwise squared Euclidean distances.
428 | D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0)
429 |
430 | return D
431 |
432 |
433 | class NpzArrayReader(ABC):
434 | @abstractmethod
435 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
436 | pass
437 |
438 | @abstractmethod
439 | def remaining(self) -> int:
440 | pass
441 |
442 | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]:
443 | def gen_fn():
444 | while True:
445 | batch = self.read_batch(batch_size)
446 | if batch is None:
447 | break
448 | yield batch
449 |
450 | rem = self.remaining()
451 | num_batches = rem // batch_size + int(rem % batch_size != 0)
452 | return BatchIterator(gen_fn, num_batches)
453 |
454 |
455 | class BatchIterator:
456 | def __init__(self, gen_fn, length):
457 | self.gen_fn = gen_fn
458 | self.length = length
459 |
460 | def __len__(self):
461 | return self.length
462 |
463 | def __iter__(self):
464 | return self.gen_fn()
465 |
466 |
467 | class StreamingNpzArrayReader(NpzArrayReader):
468 | def __init__(self, arr_f, shape, dtype):
469 | self.arr_f = arr_f
470 | self.shape = shape
471 | self.dtype = dtype
472 | self.idx = 0
473 |
474 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
475 | if self.idx >= self.shape[0]:
476 | return None
477 |
478 | bs = min(batch_size, self.shape[0] - self.idx)
479 | self.idx += bs
480 |
481 | if self.dtype.itemsize == 0:
482 | return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype)
483 |
484 | read_count = bs * np.prod(self.shape[1:])
485 | read_size = int(read_count * self.dtype.itemsize)
486 | data = _read_bytes(self.arr_f, read_size, "array data")
487 | return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]])
488 |
489 | def remaining(self) -> int:
490 | return max(0, self.shape[0] - self.idx)
491 |
492 |
493 | class MemoryNpzArrayReader(NpzArrayReader):
494 | def __init__(self, arr):
495 | self.arr = arr
496 | self.idx = 0
497 |
498 | @classmethod
499 | def load(cls, path: str, arr_name: str):
500 | with open(path, "rb") as f:
501 | arr = np.load(f)[arr_name]
502 | return cls(arr)
503 |
504 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]:
505 | if self.idx >= self.arr.shape[0]:
506 | return None
507 |
508 | res = self.arr[self.idx : self.idx + batch_size]
509 | self.idx += batch_size
510 | return res
511 |
512 | def remaining(self) -> int:
513 | return max(0, self.arr.shape[0] - self.idx)
514 |
515 |
516 | @contextmanager
517 | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader:
518 | with _open_npy_file(path, arr_name) as arr_f:
519 | version = np.lib.format.read_magic(arr_f)
520 | if version == (1, 0):
521 | header = np.lib.format.read_array_header_1_0(arr_f)
522 | elif version == (2, 0):
523 | header = np.lib.format.read_array_header_2_0(arr_f)
524 | else:
525 | yield MemoryNpzArrayReader.load(path, arr_name)
526 | return
527 | shape, fortran, dtype = header
528 | if fortran or dtype.hasobject:
529 | yield MemoryNpzArrayReader.load(path, arr_name)
530 | else:
531 | yield StreamingNpzArrayReader(arr_f, shape, dtype)
532 |
533 |
534 | def _read_bytes(fp, size, error_template="ran out of data"):
535 | """
536 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886
537 |
538 | Read from file-like object until size bytes are read.
539 | Raises ValueError if not EOF is encountered before size bytes are read.
540 | Non-blocking objects only supported if they derive from io objects.
541 | Required as e.g. ZipExtFile in python 2.6 can return less data than
542 | requested.
543 | """
544 | data = bytes()
545 | while True:
546 | # io files (default in python3) return None or raise on
547 | # would-block, python2 file will truncate, probably nothing can be
548 | # done about that. note that regular files can't be non-blocking
549 | try:
550 | r = fp.read(size - len(data))
551 | data += r
552 | if len(r) == 0 or len(data) == size:
553 | break
554 | except io.BlockingIOError:
555 | pass
556 | if len(data) != size:
557 | msg = "EOF: reading %s, expected %d bytes got %d"
558 | raise ValueError(msg % (error_template, size, len(data)))
559 | else:
560 | return data
561 |
562 |
563 | @contextmanager
564 | def _open_npy_file(path: str, arr_name: str):
565 | with open(path, "rb") as f:
566 | with zipfile.ZipFile(f, "r") as zip_f:
567 | if f"{arr_name}.npy" not in zip_f.namelist():
568 | raise ValueError(f"missing {arr_name} in npz file")
569 | with zip_f.open(f"{arr_name}.npy", "r") as arr_f:
570 | yield arr_f
571 |
572 |
573 | def _download_inception_model():
574 | if os.path.exists(INCEPTION_V3_PATH):
575 | return
576 | print("downloading InceptionV3 model...")
577 | with requests.get(INCEPTION_V3_URL, stream=True) as r:
578 | r.raise_for_status()
579 | tmp_path = INCEPTION_V3_PATH + ".tmp"
580 | with open(tmp_path, "wb") as f:
581 | for chunk in tqdm(r.iter_content(chunk_size=8192)):
582 | f.write(chunk)
583 | os.rename(tmp_path, INCEPTION_V3_PATH)
584 |
585 |
586 | def _create_feature_graph(input_batch):
587 | _download_inception_model()
588 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
589 | with open(INCEPTION_V3_PATH, "rb") as f:
590 | graph_def = tf.GraphDef()
591 | graph_def.ParseFromString(f.read())
592 | pool3, spatial = tf.import_graph_def(
593 | graph_def,
594 | input_map={f"ExpandDims:0": input_batch},
595 | return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME],
596 | name=prefix,
597 | )
598 | _update_shapes(pool3)
599 | spatial = spatial[..., :7]
600 | return pool3, spatial
601 |
602 |
603 | def _create_softmax_graph(input_batch):
604 | _download_inception_model()
605 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
606 | with open(INCEPTION_V3_PATH, "rb") as f:
607 | graph_def = tf.GraphDef()
608 | graph_def.ParseFromString(f.read())
609 | (matmul,) = tf.import_graph_def(
610 | graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
611 | )
612 | w = matmul.inputs[1]
613 | logits = tf.matmul(input_batch, w)
614 | return tf.nn.softmax(logits)
615 |
616 |
617 | def _update_shapes(pool3):
618 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
619 | ops = pool3.graph.get_operations()
620 | for op in ops:
621 | for o in op.outputs:
622 | shape = o.get_shape()
623 | if shape._dims is not None: # pylint: disable=protected-access
624 | # shape = [s.value for s in shape] TF 1.x
625 | shape = [s for s in shape] # TF 2.x
626 | new_shape = []
627 | for j, s in enumerate(shape):
628 | if s == 1 and j == 0:
629 | new_shape.append(None)
630 | else:
631 | new_shape.append(s)
632 | o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
633 | return pool3
634 |
635 |
636 | def _numpy_partition(arr, kth, **kwargs):
637 | num_workers = min(cpu_count(), len(arr))
638 | chunk_size = len(arr) // num_workers
639 | extra = len(arr) % num_workers
640 |
641 | start_idx = 0
642 | batches = []
643 | for i in range(num_workers):
644 | size = chunk_size + (1 if i < extra else 0)
645 | batches.append(arr[start_idx : start_idx + size])
646 | start_idx += size
647 |
648 | with ThreadPool(num_workers) as pool:
649 | return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
650 |
651 |
652 | if __name__ == "__main__":
653 | main()
654 |
--------------------------------------------------------------------------------
/evaluations/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow-gpu>=2.0
2 | scipy
3 | requests
4 | tqdm
--------------------------------------------------------------------------------
/figures/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/MDT/7d26c2162c462bd0b90f97f3a1c36cdaaac616ec/figures/.DS_Store
--------------------------------------------------------------------------------
/figures/vis.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/MDT/7d26c2162c462bd0b90f97f3a1c36cdaaac616ec/figures/vis.jpg
--------------------------------------------------------------------------------
/infer_mdt.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torchvision.utils import save_image
9 | from masked_diffusion import create_diffusion
10 | from diffusers.models import AutoencoderKL
11 | from masked_diffusion.models import MDTv2_XL_2
12 |
13 |
14 | # Setup PyTorch:
15 | torch.manual_seed(1)
16 | torch.set_grad_enabled(False)
17 | device = "cuda" if torch.cuda.is_available() else "cpu"
18 | num_sampling_steps = 250
19 | cfg_scale = 4.0
20 | pow_scale = 0.01 # large pow_scale increase the diversity, small pow_scale increase the quality.
21 | model_path = 'mdt_xl2_v2_ckpt.pt'
22 |
23 | # Load model:
24 | image_size = 256
25 | assert image_size in [256], "We provide pre-trained models for 256x256 resolutions for now."
26 | latent_size = image_size // 8
27 | model = MDTv2_XL_2(input_size=latent_size, decode_layer=4).to(device)
28 |
29 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
30 | model.load_state_dict(state_dict)
31 | model.eval()
32 | diffusion = create_diffusion(str(num_sampling_steps))
33 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
34 |
35 | # Labels to condition the model with:
36 | class_labels = [19,23,106,108,278,282]
37 |
38 | # Create sampling noise:
39 | n = len(class_labels)
40 | z = torch.randn(n, 4, latent_size, latent_size, device=device)
41 | y = torch.tensor(class_labels, device=device)
42 |
43 | # Setup classifier-free guidance:
44 | z = torch.cat([z, z], 0)
45 | y_null = torch.tensor([1000] * n, device=device)
46 | y = torch.cat([y, y_null], 0)
47 |
48 |
49 | model_kwargs = dict(y=y, cfg_scale=cfg_scale, scale_pow=pow_scale)
50 |
51 | # Sample images:
52 | samples = diffusion.p_sample_loop(
53 | model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device
54 | )
55 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples
56 | samples = vae.decode(samples / 0.18215).sample
57 |
58 | # Save and display images:
59 | save_image(samples, "sample.jpg", nrow=3, normalize=True, value_range=(-1, 1))
60 |
--------------------------------------------------------------------------------
/masked_diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from . import gaussian_diffusion as gd
7 | from .respace import SpacedDiffusion, space_timesteps
8 |
9 |
10 | def create_diffusion(
11 | timestep_respacing,
12 | noise_schedule="linear",
13 | use_kl=False,
14 | sigma_small=False,
15 | predict_xstart=False,
16 | learn_sigma=True,
17 | rescale_learned_sigmas=False,
18 | diffusion_steps=1000
19 | ):
20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21 | if use_kl:
22 | loss_type = gd.LossType.RESCALED_KL
23 | elif rescale_learned_sigmas:
24 | loss_type = gd.LossType.RESCALED_MSE
25 | else:
26 | loss_type = gd.LossType.MSE
27 | if timestep_respacing is None or timestep_respacing == "":
28 | timestep_respacing = [diffusion_steps]
29 | return SpacedDiffusion(
30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31 | betas=betas,
32 | model_mean_type=(
33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34 | ),
35 | model_var_type=(
36 | (
37 | gd.ModelVarType.FIXED_LARGE
38 | if not sigma_small
39 | else gd.ModelVarType.FIXED_SMALL
40 | )
41 | if not learn_sigma
42 | else gd.ModelVarType.LEARNED_RANGE
43 | ),
44 | loss_type=loss_type
45 | )
46 |
47 | def diffusion_defaults():
48 | """
49 | Defaults for image and classifier training.
50 | """
51 | return dict(
52 | learn_sigma=True,
53 | diffusion_steps=1000,
54 | noise_schedule="linear",
55 | timestep_respacing="",
56 | use_kl=False,
57 | predict_xstart=False,
58 | sigma_small=False,
59 | rescale_learned_sigmas=False,
60 | )
61 |
62 | def model_and_diffusion_defaults():
63 | """
64 | Defaults for image training.
65 | """
66 | res = dict(
67 | image_size=256,
68 | mask_ratio=None,
69 | decode_layer=None,
70 | class_cond=True,
71 | use_fp16=False,
72 | )
73 | res.update(diffusion_defaults())
74 | return res
--------------------------------------------------------------------------------
/masked_diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import torch as th
7 | import numpy as np
8 |
9 |
10 | def normal_kl(mean1, logvar1, mean2, logvar2):
11 | """
12 | Compute the KL divergence between two gaussians.
13 | Shapes are automatically broadcasted, so batches can be compared to
14 | scalars, among other use cases.
15 | """
16 | tensor = None
17 | for obj in (mean1, logvar1, mean2, logvar2):
18 | if isinstance(obj, th.Tensor):
19 | tensor = obj
20 | break
21 | assert tensor is not None, "at least one argument must be a Tensor"
22 |
23 | # Force variances to be Tensors. Broadcasting helps convert scalars to
24 | # Tensors, but it does not work for th.exp().
25 | logvar1, logvar2 = [
26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27 | for x in (logvar1, logvar2)
28 | ]
29 |
30 | return 0.5 * (
31 | -1.0
32 | + logvar2
33 | - logvar1
34 | + th.exp(logvar1 - logvar2)
35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36 | )
37 |
38 |
39 | def approx_standard_normal_cdf(x):
40 | """
41 | A fast approximation of the cumulative distribution function of the
42 | standard normal.
43 | """
44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45 |
46 |
47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48 | """
49 | Compute the log-likelihood of a continuous Gaussian distribution.
50 | :param x: the targets
51 | :param means: the Gaussian mean Tensor.
52 | :param log_scales: the Gaussian log stddev Tensor.
53 | :return: a tensor like x of log probabilities (in nats).
54 | """
55 | centered_x = x - means
56 | inv_stdv = th.exp(-log_scales)
57 | normalized_x = centered_x * inv_stdv
58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59 | return log_probs
60 |
61 |
62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63 | """
64 | Compute the log-likelihood of a Gaussian distribution discretizing to a
65 | given image.
66 | :param x: the target images. It is assumed that this was uint8 values,
67 | rescaled to the range [-1, 1].
68 | :param means: the Gaussian mean Tensor.
69 | :param log_scales: the Gaussian log stddev Tensor.
70 | :return: a tensor like x of log probabilities (in nats).
71 | """
72 | assert x.shape == means.shape == log_scales.shape
73 | centered_x = x - means
74 | inv_stdv = th.exp(-log_scales)
75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76 | cdf_plus = approx_standard_normal_cdf(plus_in)
77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78 | cdf_min = approx_standard_normal_cdf(min_in)
79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81 | cdf_delta = cdf_plus - cdf_min
82 | log_probs = th.where(
83 | x < -0.999,
84 | log_cdf_plus,
85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86 | )
87 | assert log_probs.shape == x.shape
88 | return log_probs
89 |
--------------------------------------------------------------------------------
/masked_diffusion/dist_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers for distributed training.
3 | """
4 |
5 | import io
6 | import os
7 | import socket
8 |
9 | import blobfile as bf
10 | from mpi4py import MPI
11 | import torch as th
12 | import torch.distributed as dist
13 | import builtins
14 | import datetime
15 |
16 | # Change this to reflect your cluster layout.
17 | # The GPU for a given rank is (rank % GPUS_PER_NODE).
18 | GPUS_PER_NODE = 8
19 |
20 | SETUP_RETRY_COUNT = 3
21 | def synchronize():
22 | if not dist.is_available():
23 | return
24 |
25 | if not dist.is_initialized():
26 | return
27 |
28 | world_size = dist.get_world_size()
29 |
30 | if world_size == 1:
31 | return
32 |
33 | dist.barrier()
34 |
35 | def is_dist_avail_and_initialized():
36 | if not dist.is_available():
37 | return False
38 | if not dist.is_initialized():
39 | return False
40 | return True
41 | def get_world_size():
42 | if not is_dist_avail_and_initialized():
43 | return 1
44 | return dist.get_world_size()
45 |
46 | def setup_for_distributed(is_master):
47 | """
48 | This function disables printing when not in master process
49 | """
50 | builtin_print = builtins.print
51 |
52 | def print(*args, **kwargs):
53 | force = kwargs.pop('force', False)
54 | force = force or (get_world_size() > 8)
55 | if is_master or force:
56 | now = datetime.datetime.now().time()
57 | builtin_print('[{}] '.format(now), end='') # print with time stamp
58 | builtin_print(*args, **kwargs)
59 |
60 | builtins.print = print
61 |
62 | def setup_dist_multinode(args):
63 | """
64 | Setup a distributed process group.
65 | """
66 | if not dist.is_available() or not dist.is_initialized():
67 | th.distributed.init_process_group(backend="nccl", init_method='env://')
68 | world_size = dist.get_world_size()
69 | local_rank = int(os.getenv('LOCAL_RANK'))
70 | print("rank",local_rank)
71 | device = local_rank
72 | th.cuda.set_device(device)
73 | setup_for_distributed(device == 0)
74 |
75 | synchronize()
76 | else:
77 | print("ddp failed!")
78 | exit()
79 |
80 | def setup_dist():
81 | """
82 | Setup a distributed process group.
83 | """
84 | if dist.is_initialized():
85 | return
86 | th.cuda.set_device(int(os.environ["LOCAL_RANK"]))
87 | th.distributed.init_process_group(backend="nccl", init_method="env://")
88 | synchronize()
89 |
90 | def dev():
91 | """
92 | Get the device to use for torch.distributed.
93 | """
94 | if th.cuda.is_available():
95 | return th.device(f"cuda")
96 | return th.device("cpu")
97 |
98 |
99 | def load_state_dict(path, **kwargs):
100 | """
101 | Load a PyTorch file without redundant fetches across MPI ranks.
102 | """
103 | chunk_size = 2 ** 30 # MPI has a relatively small size limit
104 | if MPI.COMM_WORLD.Get_rank() == 0:
105 | with bf.BlobFile(path, "rb") as f:
106 | data = f.read()
107 | num_chunks = len(data) // chunk_size
108 | if len(data) % chunk_size:
109 | num_chunks += 1
110 | MPI.COMM_WORLD.bcast(num_chunks)
111 | for i in range(0, len(data), chunk_size):
112 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size])
113 | else:
114 | num_chunks = MPI.COMM_WORLD.bcast(None)
115 | data = bytes()
116 | for _ in range(num_chunks):
117 | data += MPI.COMM_WORLD.bcast(None)
118 |
119 | return th.load(io.BytesIO(data), **kwargs)
120 |
121 |
122 | def sync_params(params):
123 | """
124 | Synchronize a sequence of Tensors across ranks from rank 0.
125 | """
126 | for p in params:
127 | with th.no_grad():
128 | dist.broadcast(p, 0)
129 |
130 |
131 | def _find_free_port():
132 | try:
133 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
134 | s.bind(("", 0))
135 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
136 | return s.getsockname()[1]
137 | finally:
138 | s.close()
139 |
--------------------------------------------------------------------------------
/masked_diffusion/fp16_util.py:
--------------------------------------------------------------------------------
1 | """
2 | Helpers to train with 16-bit precision.
3 | """
4 |
5 | import numpy as np
6 | import torch as th
7 | import torch.nn as nn
8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9 |
10 | from . import logger
11 |
12 | INITIAL_LOG_LOSS_SCALE = 20.0
13 |
14 |
15 | def convert_module_to_f16(l):
16 | """
17 | Convert primitive modules to float16.
18 | """
19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20 | l.weight.data = l.weight.data.half()
21 | if l.bias is not None:
22 | l.bias.data = l.bias.data.half()
23 |
24 |
25 | def convert_module_to_f32(l):
26 | """
27 | Convert primitive modules to float32, undoing convert_module_to_f16().
28 | """
29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30 | l.weight.data = l.weight.data.float()
31 | if l.bias is not None:
32 | l.bias.data = l.bias.data.float()
33 |
34 |
35 | def make_master_params(param_groups_and_shapes):
36 | """
37 | Copy model parameters into a (differently-shaped) list of full-precision
38 | parameters.
39 | """
40 | master_params = []
41 | for param_group, shape in param_groups_and_shapes:
42 | master_param = nn.Parameter(
43 | _flatten_dense_tensors(
44 | [param.detach().float() for (_, param) in param_group]
45 | ).view(shape)
46 | )
47 | master_param.requires_grad = True
48 | master_params.append(master_param)
49 | return master_params
50 |
51 |
52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53 | """
54 | Copy the gradients from the model parameters into the master parameters
55 | from make_master_params().
56 | """
57 | for master_param, (param_group, shape) in zip(
58 | master_params, param_groups_and_shapes
59 | ):
60 | master_param.grad = _flatten_dense_tensors(
61 | [param_grad_or_zeros(param) for (_, param) in param_group]
62 | ).view(shape)
63 |
64 |
65 | def master_params_to_model_params(param_groups_and_shapes, master_params):
66 | """
67 | Copy the master parameter data back into the model parameters.
68 | """
69 | # Without copying to a list, if a generator is passed, this will
70 | # silently not copy any parameters.
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72 | for (_, param), unflat_master_param in zip(
73 | param_group, unflatten_master_params(param_group, master_param.view(-1))
74 | ):
75 | param.detach().copy_(unflat_master_param)
76 |
77 |
78 | def unflatten_master_params(param_group, master_param):
79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80 |
81 |
82 | def get_param_groups_and_shapes(named_model_params):
83 | named_model_params = list(named_model_params)
84 | scalar_vector_named_params = (
85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86 | (-1),
87 | )
88 | matrix_named_params = (
89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90 | (1, -1),
91 | )
92 | return [scalar_vector_named_params, matrix_named_params]
93 |
94 |
95 | def master_params_to_state_dict(
96 | model, param_groups_and_shapes, master_params, use_fp16
97 | ):
98 | if use_fp16:
99 | state_dict = model.state_dict()
100 | for master_param, (param_group, _) in zip(
101 | master_params, param_groups_and_shapes
102 | ):
103 | for (name, _), unflat_master_param in zip(
104 | param_group, unflatten_master_params(param_group, master_param.view(-1))
105 | ):
106 | assert name in state_dict
107 | state_dict[name] = unflat_master_param
108 | else:
109 | state_dict = model.state_dict()
110 | for i, (name, _value) in enumerate(model.named_parameters()):
111 | assert name in state_dict
112 | state_dict[name] = master_params[i]
113 | return state_dict
114 |
115 |
116 | def state_dict_to_master_params(model, state_dict, use_fp16):
117 | if use_fp16:
118 | named_model_params = [
119 | (name, state_dict[name]) for name, _ in model.named_parameters()
120 | ]
121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122 | master_params = make_master_params(param_groups_and_shapes)
123 | else:
124 | master_params = [state_dict[name] for name, _ in model.named_parameters()]
125 | return master_params
126 |
127 |
128 | def zero_master_grads(master_params):
129 | for param in master_params:
130 | param.grad = None
131 |
132 |
133 | def zero_grad(model_params):
134 | for param in model_params:
135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136 | if param.grad is not None:
137 | param.grad.detach_()
138 | param.grad.zero_()
139 |
140 |
141 | def param_grad_or_zeros(param):
142 | if param.grad is not None:
143 | return param.grad.data.detach()
144 | else:
145 | return th.zeros_like(param)
146 |
147 |
148 | class MixedPrecisionTrainer:
149 | def __init__(
150 | self,
151 | *,
152 | model,
153 | use_fp16=False,
154 | fp16_scale_growth=1e-3,
155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156 | ):
157 | self.model = model
158 | self.use_fp16 = use_fp16
159 | self.fp16_scale_growth = fp16_scale_growth
160 |
161 | self.model_params = list(self.model.parameters())
162 | self.master_params = self.model_params
163 | self.param_groups_and_shapes = None
164 | self.lg_loss_scale = initial_lg_loss_scale
165 |
166 | if self.use_fp16:
167 | self.param_groups_and_shapes = get_param_groups_and_shapes(
168 | self.model.named_parameters()
169 | )
170 | self.master_params = make_master_params(self.param_groups_and_shapes)
171 | self.model.convert_to_fp16()
172 |
173 | def zero_grad(self):
174 | zero_grad(self.model_params)
175 |
176 | def backward(self, loss: th.Tensor):
177 | if self.use_fp16:
178 | loss_scale = 2 ** self.lg_loss_scale
179 | (loss * loss_scale).backward()
180 | else:
181 | loss.backward()
182 |
183 | def optimize(self, opt: th.optim.Optimizer):
184 | if self.use_fp16:
185 | return self._optimize_fp16(opt)
186 | else:
187 | return self._optimize_normal(opt)
188 |
189 | def _optimize_fp16(self, opt: th.optim.Optimizer):
190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193 | if check_overflow(grad_norm):
194 | self.lg_loss_scale -= 1
195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196 | zero_master_grads(self.master_params)
197 | return False
198 |
199 | logger.logkv_mean("grad_norm", grad_norm)
200 | logger.logkv_mean("param_norm", param_norm)
201 |
202 | for p in self.master_params:
203 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale))
204 | opt.step()
205 | zero_master_grads(self.master_params)
206 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
207 | self.lg_loss_scale += self.fp16_scale_growth
208 | return True
209 |
210 | def _optimize_normal(self, opt: th.optim.Optimizer):
211 | grad_norm, param_norm = self._compute_norms()
212 | logger.logkv_mean("grad_norm", grad_norm)
213 | logger.logkv_mean("param_norm", param_norm)
214 | opt.step()
215 | return True
216 |
217 | def _compute_norms(self, grad_scale=1.0):
218 | grad_norm = 0.0
219 | param_norm = 0.0
220 | for p in self.master_params:
221 | with th.no_grad():
222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
223 | if p.grad is not None:
224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
226 |
227 | def master_params_to_state_dict(self, master_params):
228 | return master_params_to_state_dict(
229 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16
230 | )
231 |
232 | def state_dict_to_master_params(self, state_dict):
233 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
234 |
235 |
236 | def check_overflow(value):
237 | return (value == float("inf")) or (value == -float("inf")) or (value != value)
238 |
--------------------------------------------------------------------------------
/masked_diffusion/gaussian_diffusion.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 |
7 | import math
8 |
9 | import numpy as np
10 | import torch as th
11 | import enum
12 |
13 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14 |
15 |
16 | def mean_flat(tensor):
17 | """
18 | Take the mean over all non-batch dimensions.
19 | """
20 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
21 |
22 |
23 | class ModelMeanType(enum.Enum):
24 | """
25 | Which type of output the model predicts.
26 | """
27 |
28 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29 | START_X = enum.auto() # the model predicts x_0
30 | EPSILON = enum.auto() # the model predicts epsilon
31 | VELOCITY = enum.auto() # the model predicts v
32 |
33 |
34 | class ModelVarType(enum.Enum):
35 | """
36 | What is used as the model's output variance.
37 | The LEARNED_RANGE option has been added to allow the model to predict
38 | values between FIXED_SMALL and FIXED_LARGE, making its job easier.
39 | """
40 |
41 | LEARNED = enum.auto()
42 | FIXED_SMALL = enum.auto()
43 | FIXED_LARGE = enum.auto()
44 | LEARNED_RANGE = enum.auto()
45 |
46 |
47 | class LossType(enum.Enum):
48 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
49 | RESCALED_MSE = (
50 | enum.auto()
51 | ) # use raw MSE loss (with RESCALED_KL when learning variances)
52 | KL = enum.auto() # use the variational lower-bound
53 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
54 |
55 | def is_vb(self):
56 | return self == LossType.KL or self == LossType.RESCALED_KL
57 |
58 |
59 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
60 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
61 | warmup_time = int(num_diffusion_timesteps * warmup_frac)
62 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
63 | return betas
64 |
65 |
66 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
67 | """
68 | This is the deprecated API for creating beta schedules.
69 | See get_named_beta_schedule() for the new library of schedules.
70 | """
71 | if beta_schedule == "quad":
72 | betas = (
73 | np.linspace(
74 | beta_start ** 0.5,
75 | beta_end ** 0.5,
76 | num_diffusion_timesteps,
77 | dtype=np.float64,
78 | )
79 | ** 2
80 | )
81 | elif beta_schedule == "linear":
82 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
83 | elif beta_schedule == "warmup10":
84 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
85 | elif beta_schedule == "warmup50":
86 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
87 | elif beta_schedule == "const":
88 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
89 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
90 | betas = 1.0 / np.linspace(
91 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
92 | )
93 | else:
94 | raise NotImplementedError(beta_schedule)
95 | assert betas.shape == (num_diffusion_timesteps,)
96 | return betas
97 |
98 |
99 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
100 | """
101 | Get a pre-defined beta schedule for the given name.
102 | The beta schedule library consists of beta schedules which remain similar
103 | in the limit of num_diffusion_timesteps.
104 | Beta schedules may be added, but should not be removed or changed once
105 | they are committed to maintain backwards compatibility.
106 | """
107 | if schedule_name == "linear":
108 | # Linear schedule from Ho et al, extended to work for any number of
109 | # diffusion steps.
110 | scale = 1000 / num_diffusion_timesteps
111 | return get_beta_schedule(
112 | "linear",
113 | beta_start=scale * 0.0001,
114 | beta_end=scale * 0.02,
115 | num_diffusion_timesteps=num_diffusion_timesteps,
116 | )
117 | elif schedule_name == "squaredcos_cap_v2":
118 | return betas_for_alpha_bar(
119 | num_diffusion_timesteps,
120 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
121 | )
122 | else:
123 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
124 |
125 |
126 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
127 | """
128 | Create a beta schedule that discretizes the given alpha_t_bar function,
129 | which defines the cumulative product of (1-beta) over time from t = [0,1].
130 | :param num_diffusion_timesteps: the number of betas to produce.
131 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
132 | produces the cumulative product of (1-beta) up to that
133 | part of the diffusion process.
134 | :param max_beta: the maximum beta to use; use values lower than 1 to
135 | prevent singularities.
136 | """
137 | betas = []
138 | for i in range(num_diffusion_timesteps):
139 | t1 = i / num_diffusion_timesteps
140 | t2 = (i + 1) / num_diffusion_timesteps
141 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
142 | return np.array(betas)
143 |
144 |
145 | class GaussianDiffusion:
146 | """
147 | Utilities for training and sampling diffusion models.
148 | Original ported from this codebase:
149 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
150 | :param betas: a 1-D numpy array of betas for each diffusion timestep,
151 | starting at T and going to 1.
152 | """
153 |
154 | def __init__(
155 | self,
156 | *,
157 | betas,
158 | model_mean_type,
159 | model_var_type,
160 | loss_type
161 | ):
162 |
163 | self.model_mean_type = model_mean_type
164 | self.model_var_type = model_var_type
165 | self.loss_type = loss_type
166 |
167 | # Use float64 for accuracy.
168 | betas = np.array(betas, dtype=np.float64)
169 | self.betas = betas
170 | assert len(betas.shape) == 1, "betas must be 1-D"
171 | assert (betas > 0).all() and (betas <= 1).all()
172 |
173 | self.num_timesteps = int(betas.shape[0])
174 |
175 | alphas = 1.0 - betas
176 | self.alphas_cumprod = np.cumprod(alphas, axis=0)
177 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
178 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
179 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
180 |
181 | # calculations for diffusion q(x_t | x_{t-1}) and others
182 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
183 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
184 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
185 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
186 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
187 |
188 | # calculations for posterior q(x_{t-1} | x_t, x_0)
189 | self.posterior_variance = (
190 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
191 | )
192 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
193 | self.posterior_log_variance_clipped = np.log(
194 | np.append(self.posterior_variance[1], self.posterior_variance[1:])
195 | ) if len(self.posterior_variance) > 1 else np.array([])
196 |
197 | self.posterior_mean_coef1 = (
198 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
199 | )
200 | self.posterior_mean_coef2 = (
201 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
202 | )
203 |
204 | def q_mean_variance(self, x_start, t):
205 | """
206 | Get the distribution q(x_t | x_0).
207 | :param x_start: the [N x C x ...] tensor of noiseless inputs.
208 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
209 | :return: A tuple (mean, variance, log_variance), all of x_start's shape.
210 | """
211 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
212 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
213 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
214 | return mean, variance, log_variance
215 |
216 | def q_sample(self, x_start, t, noise=None):
217 | """
218 | Diffuse the data for a given number of diffusion steps.
219 | In other words, sample from q(x_t | x_0).
220 | :param x_start: the initial data batch.
221 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
222 | :param noise: if specified, the split-out normal noise.
223 | :return: A noisy version of x_start.
224 | """
225 | if noise is None:
226 | noise = th.randn_like(x_start)
227 | assert noise.shape == x_start.shape
228 | return (
229 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
230 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
231 | )
232 |
233 | def q_posterior_mean_variance(self, x_start, x_t, t):
234 | """
235 | Compute the mean and variance of the diffusion posterior:
236 | q(x_{t-1} | x_t, x_0)
237 | """
238 | assert x_start.shape == x_t.shape
239 | posterior_mean = (
240 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
241 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
242 | )
243 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
244 | posterior_log_variance_clipped = _extract_into_tensor(
245 | self.posterior_log_variance_clipped, t, x_t.shape
246 | )
247 | assert (
248 | posterior_mean.shape[0]
249 | == posterior_variance.shape[0]
250 | == posterior_log_variance_clipped.shape[0]
251 | == x_start.shape[0]
252 | )
253 | return posterior_mean, posterior_variance, posterior_log_variance_clipped
254 |
255 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
256 | """
257 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
258 | the initial x, x_0.
259 | :param model: the model, which takes a signal and a batch of timesteps
260 | as input.
261 | :param x: the [N x C x ...] tensor at time t.
262 | :param t: a 1-D Tensor of timesteps.
263 | :param clip_denoised: if True, clip the denoised signal into [-1, 1].
264 | :param denoised_fn: if not None, a function which applies to the
265 | x_start prediction before it is used to sample. Applies before
266 | clip_denoised.
267 | :param model_kwargs: if not None, a dict of extra keyword arguments to
268 | pass to the model. This can be used for conditioning.
269 | :return: a dict with the following keys:
270 | - 'mean': the model mean output.
271 | - 'variance': the model variance output.
272 | - 'log_variance': the log of 'variance'.
273 | - 'pred_xstart': the prediction for x_0.
274 | """
275 | if model_kwargs is None:
276 | model_kwargs = {}
277 |
278 | B, C = x.shape[:2]
279 | assert t.shape == (B,)
280 | model_output = model(x, t, **model_kwargs)
281 | if isinstance(model_output, tuple):
282 | model_output, extra = model_output
283 | else:
284 | extra = None
285 |
286 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
287 | assert model_output.shape == (B, C * 2, *x.shape[2:])
288 | model_output, model_var_values = th.split(model_output, C, dim=1)
289 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
290 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
291 | # The model_var_values is [-1, 1] for [min_var, max_var].
292 | frac = (model_var_values + 1) / 2
293 | model_log_variance = frac * max_log + (1 - frac) * min_log
294 | model_variance = th.exp(model_log_variance)
295 | else:
296 | model_variance, model_log_variance = {
297 | # for fixedlarge, we set the initial (log-)variance like so
298 | # to get a better decoder log likelihood.
299 | ModelVarType.FIXED_LARGE: (
300 | np.append(self.posterior_variance[1], self.betas[1:]),
301 | np.log(np.append(self.posterior_variance[1], self.betas[1:])),
302 | ),
303 | ModelVarType.FIXED_SMALL: (
304 | self.posterior_variance,
305 | self.posterior_log_variance_clipped,
306 | ),
307 | }[self.model_var_type]
308 | model_variance = _extract_into_tensor(model_variance, t, x.shape)
309 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
310 |
311 | def process_xstart(x):
312 | if denoised_fn is not None:
313 | x = denoised_fn(x)
314 | if clip_denoised:
315 | return x.clamp(-1, 1)
316 | return x
317 |
318 | if self.model_mean_type == ModelMeanType.START_X:
319 | pred_xstart = process_xstart(model_output)
320 | else:
321 | pred_xstart = process_xstart(
322 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
323 | )
324 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
325 |
326 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
327 | return {
328 | "mean": model_mean,
329 | "variance": model_variance,
330 | "log_variance": model_log_variance,
331 | "pred_xstart": pred_xstart,
332 | "extra": extra,
333 | }
334 |
335 | def _predict_xstart_from_eps(self, x_t, t, eps):
336 | assert x_t.shape == eps.shape
337 | return (
338 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
339 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
340 | )
341 |
342 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
343 | return (
344 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
345 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
346 |
347 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
348 | """
349 | Compute the mean for the previous step, given a function cond_fn that
350 | computes the gradient of a conditional log probability with respect to
351 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
352 | condition on y.
353 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
354 | """
355 | gradient = cond_fn(x, t, **model_kwargs)
356 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
357 | return new_mean
358 |
359 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
360 | """
361 | Compute what the p_mean_variance output would have been, should the
362 | model's score function be conditioned by cond_fn.
363 | See condition_mean() for details on cond_fn.
364 | Unlike condition_mean(), this instead uses the conditioning strategy
365 | from Song et al (2020).
366 | """
367 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
368 |
369 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
370 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
371 |
372 | out = p_mean_var.copy()
373 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
374 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
375 | return out
376 |
377 | def p_sample(
378 | self,
379 | model,
380 | x,
381 | t,
382 | clip_denoised=True,
383 | denoised_fn=None,
384 | cond_fn=None,
385 | model_kwargs=None,
386 | ):
387 | """
388 | Sample x_{t-1} from the model at the given timestep.
389 | :param model: the model to sample from.
390 | :param x: the current tensor at x_{t-1}.
391 | :param t: the value of t, starting at 0 for the first diffusion step.
392 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
393 | :param denoised_fn: if not None, a function which applies to the
394 | x_start prediction before it is used to sample.
395 | :param cond_fn: if not None, this is a gradient function that acts
396 | similarly to the model.
397 | :param model_kwargs: if not None, a dict of extra keyword arguments to
398 | pass to the model. This can be used for conditioning.
399 | :return: a dict containing the following keys:
400 | - 'sample': a random sample from the model.
401 | - 'pred_xstart': a prediction of x_0.
402 | """
403 | out = self.p_mean_variance(
404 | model,
405 | x,
406 | t,
407 | clip_denoised=clip_denoised,
408 | denoised_fn=denoised_fn,
409 | model_kwargs=model_kwargs,
410 | )
411 | noise = th.randn_like(x)
412 | nonzero_mask = (
413 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
414 | ) # no noise when t == 0
415 | if cond_fn is not None:
416 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
417 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
418 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
419 |
420 | def p_sample_loop(
421 | self,
422 | model,
423 | shape,
424 | noise=None,
425 | clip_denoised=True,
426 | denoised_fn=None,
427 | cond_fn=None,
428 | model_kwargs=None,
429 | device=None,
430 | progress=False,
431 | ):
432 | """
433 | Generate samples from the model.
434 | :param model: the model module.
435 | :param shape: the shape of the samples, (N, C, H, W).
436 | :param noise: if specified, the noise from the encoder to sample.
437 | Should be of the same shape as `shape`.
438 | :param clip_denoised: if True, clip x_start predictions to [-1, 1].
439 | :param denoised_fn: if not None, a function which applies to the
440 | x_start prediction before it is used to sample.
441 | :param cond_fn: if not None, this is a gradient function that acts
442 | similarly to the model.
443 | :param model_kwargs: if not None, a dict of extra keyword arguments to
444 | pass to the model. This can be used for conditioning.
445 | :param device: if specified, the device to create the samples on.
446 | If not specified, use a model parameter's device.
447 | :param progress: if True, show a tqdm progress bar.
448 | :return: a non-differentiable batch of samples.
449 | """
450 | final = None
451 | for sample in self.p_sample_loop_progressive(
452 | model,
453 | shape,
454 | noise=noise,
455 | clip_denoised=clip_denoised,
456 | denoised_fn=denoised_fn,
457 | cond_fn=cond_fn,
458 | model_kwargs=model_kwargs,
459 | device=device,
460 | progress=progress,
461 | ):
462 | final = sample
463 | return final["sample"]
464 |
465 | def p_sample_loop_progressive(
466 | self,
467 | model,
468 | shape,
469 | noise=None,
470 | clip_denoised=True,
471 | denoised_fn=None,
472 | cond_fn=None,
473 | model_kwargs=None,
474 | device=None,
475 | progress=False,
476 | ):
477 | """
478 | Generate samples from the model and yield intermediate samples from
479 | each timestep of diffusion.
480 | Arguments are the same as p_sample_loop().
481 | Returns a generator over dicts, where each dict is the return value of
482 | p_sample().
483 | """
484 | if device is None:
485 | device = next(model.parameters()).device
486 | assert isinstance(shape, (tuple, list))
487 | if noise is not None:
488 | img = noise
489 | else:
490 | img = th.randn(*shape, device=device)
491 | indices = list(range(self.num_timesteps))[::-1]
492 |
493 | if progress:
494 | # Lazy import so that we don't depend on tqdm.
495 | from tqdm.auto import tqdm
496 |
497 | indices = tqdm(indices)
498 |
499 | for i in indices:
500 | t = th.tensor([i] * shape[0], device=device)
501 | with th.no_grad():
502 | out = self.p_sample(
503 | model,
504 | img,
505 | t,
506 | clip_denoised=clip_denoised,
507 | denoised_fn=denoised_fn,
508 | cond_fn=cond_fn,
509 | model_kwargs=model_kwargs,
510 | )
511 | yield out
512 | img = out["sample"]
513 |
514 | def ddim_sample(
515 | self,
516 | model,
517 | x,
518 | t,
519 | clip_denoised=True,
520 | denoised_fn=None,
521 | cond_fn=None,
522 | model_kwargs=None,
523 | eta=0.0,
524 | ):
525 | """
526 | Sample x_{t-1} from the model using DDIM.
527 | Same usage as p_sample().
528 | """
529 | out = self.p_mean_variance(
530 | model,
531 | x,
532 | t,
533 | clip_denoised=clip_denoised,
534 | denoised_fn=denoised_fn,
535 | model_kwargs=model_kwargs,
536 | )
537 | if cond_fn is not None:
538 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
539 |
540 | # Usually our model outputs epsilon, but we re-derive it
541 | # in case we used x_start or x_prev prediction.
542 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
543 |
544 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
545 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
546 | sigma = (
547 | eta
548 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
549 | * th.sqrt(1 - alpha_bar / alpha_bar_prev)
550 | )
551 | # Equation 12.
552 | noise = th.randn_like(x)
553 | mean_pred = (
554 | out["pred_xstart"] * th.sqrt(alpha_bar_prev)
555 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
556 | )
557 | nonzero_mask = (
558 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
559 | ) # no noise when t == 0
560 | sample = mean_pred + nonzero_mask * sigma * noise
561 | return {"sample": sample, "pred_xstart": out["pred_xstart"]}
562 |
563 | def ddim_reverse_sample(
564 | self,
565 | model,
566 | x,
567 | t,
568 | clip_denoised=True,
569 | denoised_fn=None,
570 | cond_fn=None,
571 | model_kwargs=None,
572 | eta=0.0,
573 | ):
574 | """
575 | Sample x_{t+1} from the model using DDIM reverse ODE.
576 | """
577 | assert eta == 0.0, "Reverse ODE only for deterministic path"
578 | out = self.p_mean_variance(
579 | model,
580 | x,
581 | t,
582 | clip_denoised=clip_denoised,
583 | denoised_fn=denoised_fn,
584 | model_kwargs=model_kwargs,
585 | )
586 | if cond_fn is not None:
587 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
588 | # Usually our model outputs epsilon, but we re-derive it
589 | # in case we used x_start or x_prev prediction.
590 | eps = (
591 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
592 | - out["pred_xstart"]
593 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
594 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
595 |
596 | # Equation 12. reversed
597 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
598 |
599 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
600 |
601 | def ddim_sample_loop(
602 | self,
603 | model,
604 | shape,
605 | noise=None,
606 | clip_denoised=True,
607 | denoised_fn=None,
608 | cond_fn=None,
609 | model_kwargs=None,
610 | device=None,
611 | progress=False,
612 | eta=0.0,
613 | ):
614 | """
615 | Generate samples from the model using DDIM.
616 | Same usage as p_sample_loop().
617 | """
618 | final = None
619 | for sample in self.ddim_sample_loop_progressive(
620 | model,
621 | shape,
622 | noise=noise,
623 | clip_denoised=clip_denoised,
624 | denoised_fn=denoised_fn,
625 | cond_fn=cond_fn,
626 | model_kwargs=model_kwargs,
627 | device=device,
628 | progress=progress,
629 | eta=eta,
630 | ):
631 | final = sample
632 | return final["sample"]
633 |
634 | def ddim_sample_loop_progressive(
635 | self,
636 | model,
637 | shape,
638 | noise=None,
639 | clip_denoised=True,
640 | denoised_fn=None,
641 | cond_fn=None,
642 | model_kwargs=None,
643 | device=None,
644 | progress=False,
645 | eta=0.0,
646 | ):
647 | """
648 | Use DDIM to sample from the model and yield intermediate samples from
649 | each timestep of DDIM.
650 | Same usage as p_sample_loop_progressive().
651 | """
652 | if device is None:
653 | device = next(model.parameters()).device
654 | assert isinstance(shape, (tuple, list))
655 | if noise is not None:
656 | img = noise
657 | else:
658 | img = th.randn(*shape, device=device)
659 | indices = list(range(self.num_timesteps))[::-1]
660 |
661 | if progress:
662 | # Lazy import so that we don't depend on tqdm.
663 | from tqdm.auto import tqdm
664 |
665 | indices = tqdm(indices)
666 |
667 | for i in indices:
668 | t = th.tensor([i] * shape[0], device=device)
669 | with th.no_grad():
670 | out = self.ddim_sample(
671 | model,
672 | img,
673 | t,
674 | clip_denoised=clip_denoised,
675 | denoised_fn=denoised_fn,
676 | cond_fn=cond_fn,
677 | model_kwargs=model_kwargs,
678 | eta=eta,
679 | )
680 | yield out
681 | img = out["sample"]
682 |
683 | def _vb_terms_bpd(
684 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
685 | ):
686 | """
687 | Get a term for the variational lower-bound.
688 | The resulting units are bits (rather than nats, as one might expect).
689 | This allows for comparison to other papers.
690 | :return: a dict with the following keys:
691 | - 'output': a shape [N] tensor of NLLs or KLs.
692 | - 'pred_xstart': the x_0 predictions.
693 | """
694 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
695 | x_start=x_start, x_t=x_t, t=t
696 | )
697 | out = self.p_mean_variance(
698 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
699 | )
700 | kl = normal_kl(
701 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
702 | )
703 | kl = mean_flat(kl) / np.log(2.0)
704 |
705 | decoder_nll = -discretized_gaussian_log_likelihood(
706 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
707 | )
708 | assert decoder_nll.shape == x_start.shape
709 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
710 |
711 | # At the first timestep return the decoder NLL,
712 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
713 | output = th.where((t == 0), decoder_nll, kl)
714 | return {"output": output, "pred_xstart": out["pred_xstart"]}
715 |
716 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
717 | """
718 | Compute training losses for a single timestep.
719 | :param model: the model to evaluate loss on.
720 | :param x_start: the [N x C x ...] tensor of inputs.
721 | :param t: a batch of timestep indices.
722 | :param model_kwargs: if not None, a dict of extra keyword arguments to
723 | pass to the model. This can be used for conditioning.
724 | :param noise: if specified, the specific Gaussian noise to try to remove.
725 | :return: a dict with the key "loss" containing a tensor of shape [N].
726 | Some mean or variance settings may also have other keys.
727 | """
728 | if model_kwargs is None:
729 | model_kwargs = {}
730 | if noise is None:
731 | noise = th.randn_like(x_start)
732 | x_t = self.q_sample(x_start, t, noise=noise)
733 |
734 | terms = {}
735 |
736 |
737 | mse_loss_weight = None
738 | alpha = _extract_into_tensor(self.sqrt_alphas_cumprod, t, t.shape)
739 | sigma = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, t.shape)
740 | snr = (alpha / sigma) ** 2
741 |
742 | velocity = (alpha[:, None, None, None] * x_t - x_start) / sigma[:, None, None, None]
743 |
744 | # get loss weight
745 | if self.model_mean_type is not ModelMeanType.START_X:
746 | mse_loss_weight = th.ones_like(t)
747 | k = 5.0
748 | # min{snr, k}
749 | mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0] / snr
750 | else:
751 | k = 5.0
752 | # min{snr, k}
753 | mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0]
754 |
755 |
756 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
757 | terms["loss"] = self._vb_terms_bpd(
758 | model=model,
759 | x_start=x_start,
760 | x_t=x_t,
761 | t=t,
762 | clip_denoised=False,
763 | model_kwargs=model_kwargs,
764 | )["output"]
765 | if self.loss_type == LossType.RESCALED_KL:
766 | terms["loss"] *= self.num_timesteps
767 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
768 | model_output = model(x_t, t, **model_kwargs)
769 |
770 | if self.model_var_type in [
771 | ModelVarType.LEARNED,
772 | ModelVarType.LEARNED_RANGE,
773 | ]:
774 | B, C = x_t.shape[:2]
775 | assert model_output.shape == (B, C * 2, *x_t.shape[2:])
776 | model_output, model_var_values = th.split(model_output, C, dim=1)
777 | # Learn the variance using the variational bound, but don't let
778 | # it affect our mean prediction.
779 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
780 | terms["vb"] = self._vb_terms_bpd(
781 | model=lambda *args, r=frozen_out: r,
782 | x_start=x_start,
783 | x_t=x_t,
784 | t=t,
785 | clip_denoised=False,
786 | )["output"]
787 | if self.loss_type == LossType.RESCALED_MSE:
788 | # Divide by 1000 for equivalence with initial implementation.
789 | # Without a factor of 1/1000, the VB term hurts the MSE term.
790 | terms["vb"] *= self.num_timesteps / 1000.0
791 |
792 | target = {
793 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
794 | x_start=x_start, x_t=x_t, t=t
795 | )[0],
796 | ModelMeanType.START_X: x_start,
797 | ModelMeanType.EPSILON: noise,
798 | ModelMeanType.VELOCITY: velocity,
799 | }[self.model_mean_type]
800 | assert model_output.shape == target.shape == x_start.shape
801 | terms["mse"] = mse_loss_weight * mean_flat((target - model_output) ** 2)
802 | if "vb" in terms:
803 | terms["loss"] = terms["mse"] + terms["vb"]
804 | else:
805 | terms["loss"] = terms["mse"]
806 | else:
807 | raise NotImplementedError(self.loss_type)
808 |
809 | return terms
810 |
811 | def _prior_bpd(self, x_start):
812 | """
813 | Get the prior KL term for the variational lower-bound, measured in
814 | bits-per-dim.
815 | This term can't be optimized, as it only depends on the encoder.
816 | :param x_start: the [N x C x ...] tensor of inputs.
817 | :return: a batch of [N] KL values (in bits), one per batch element.
818 | """
819 | batch_size = x_start.shape[0]
820 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
821 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
822 | kl_prior = normal_kl(
823 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
824 | )
825 | return mean_flat(kl_prior) / np.log(2.0)
826 |
827 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
828 | """
829 | Compute the entire variational lower-bound, measured in bits-per-dim,
830 | as well as other related quantities.
831 | :param model: the model to evaluate loss on.
832 | :param x_start: the [N x C x ...] tensor of inputs.
833 | :param clip_denoised: if True, clip denoised samples.
834 | :param model_kwargs: if not None, a dict of extra keyword arguments to
835 | pass to the model. This can be used for conditioning.
836 | :return: a dict containing the following keys:
837 | - total_bpd: the total variational lower-bound, per batch element.
838 | - prior_bpd: the prior term in the lower-bound.
839 | - vb: an [N x T] tensor of terms in the lower-bound.
840 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
841 | - mse: an [N x T] tensor of epsilon MSEs for each timestep.
842 | """
843 | device = x_start.device
844 | batch_size = x_start.shape[0]
845 |
846 | vb = []
847 | xstart_mse = []
848 | mse = []
849 | for t in list(range(self.num_timesteps))[::-1]:
850 | t_batch = th.tensor([t] * batch_size, device=device)
851 | noise = th.randn_like(x_start)
852 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
853 | # Calculate VLB term at the current timestep
854 | with th.no_grad():
855 | out = self._vb_terms_bpd(
856 | model,
857 | x_start=x_start,
858 | x_t=x_t,
859 | t=t_batch,
860 | clip_denoised=clip_denoised,
861 | model_kwargs=model_kwargs,
862 | )
863 | vb.append(out["output"])
864 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
865 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
866 | mse.append(mean_flat((eps - noise) ** 2))
867 |
868 | vb = th.stack(vb, dim=1)
869 | xstart_mse = th.stack(xstart_mse, dim=1)
870 | mse = th.stack(mse, dim=1)
871 |
872 | prior_bpd = self._prior_bpd(x_start)
873 | total_bpd = vb.sum(dim=1) + prior_bpd
874 | return {
875 | "total_bpd": total_bpd,
876 | "prior_bpd": prior_bpd,
877 | "vb": vb,
878 | "xstart_mse": xstart_mse,
879 | "mse": mse,
880 | }
881 |
882 |
883 | def _extract_into_tensor(arr, timesteps, broadcast_shape):
884 | """
885 | Extract values from a 1-D numpy array for a batch of indices.
886 | :param arr: the 1-D numpy array.
887 | :param timesteps: a tensor of indices into the array to extract.
888 | :param broadcast_shape: a larger shape of K dimensions with the batch
889 | dimension equal to the length of timesteps.
890 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
891 | """
892 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
893 | while len(res.shape) < len(broadcast_shape):
894 | res = res[..., None]
895 | return res + th.zeros(broadcast_shape, device=timesteps.device)
896 |
--------------------------------------------------------------------------------
/masked_diffusion/image_datasets.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 |
4 | from PIL import Image
5 | import blobfile as bf
6 | from mpi4py import MPI
7 | import numpy as np
8 | from torch.utils.data import DataLoader, Dataset
9 |
10 |
11 | def load_data(
12 | *,
13 | data_dir,
14 | batch_size,
15 | image_size,
16 | class_cond=False,
17 | deterministic=False,
18 | random_crop=False,
19 | random_flip=True,
20 | ):
21 | """
22 | For a dataset, create a generator over (images, kwargs) pairs.
23 |
24 | Each images is an NCHW float tensor, and the kwargs dict contains zero or
25 | more keys, each of which map to a batched Tensor of their own.
26 | The kwargs dict can be used for class labels, in which case the key is "y"
27 | and the values are integer tensors of class labels.
28 |
29 | :param data_dir: a dataset directory.
30 | :param batch_size: the batch size of each returned pair.
31 | :param image_size: the size to which images are resized.
32 | :param class_cond: if True, include a "y" key in returned dicts for class
33 | label. If classes are not available and this is true, an
34 | exception will be raised.
35 | :param deterministic: if True, yield results in a deterministic order.
36 | :param random_crop: if True, randomly crop the images for augmentation.
37 | :param random_flip: if True, randomly flip the images for augmentation.
38 | """
39 | if not data_dir:
40 | raise ValueError("unspecified data directory")
41 | all_files = _list_image_files_recursively(data_dir)
42 | classes = None
43 | if class_cond:
44 | # Assume classes are the first part of the filename,
45 | # before an underscore.
46 | class_names = [bf.basename(path).split("_")[0] for path in all_files]
47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
48 | classes = [sorted_classes[x] for x in class_names]
49 | dataset = ImageDataset(
50 | image_size,
51 | all_files,
52 | classes=classes,
53 | shard=MPI.COMM_WORLD.Get_rank(),
54 | num_shards=MPI.COMM_WORLD.Get_size(),
55 | random_crop=random_crop,
56 | random_flip=random_flip,
57 | )
58 | if deterministic:
59 | loader = DataLoader(
60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True
61 | )
62 | else:
63 | loader = DataLoader(
64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True
65 | )
66 | while True:
67 | yield from loader
68 |
69 |
70 | def _list_image_files_recursively(data_dir):
71 | results = []
72 | for entry in sorted(bf.listdir(data_dir)):
73 | full_path = bf.join(data_dir, entry)
74 | ext = entry.split(".")[-1]
75 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]:
76 | results.append(full_path)
77 | elif bf.isdir(full_path):
78 | results.extend(_list_image_files_recursively(full_path))
79 | return results
80 |
81 |
82 | class ImageDataset(Dataset):
83 | def __init__(
84 | self,
85 | resolution,
86 | image_paths,
87 | classes=None,
88 | shard=0,
89 | num_shards=1,
90 | random_crop=False,
91 | random_flip=True,
92 | ):
93 | super().__init__()
94 | self.resolution = resolution
95 | self.local_images = image_paths[shard:][::num_shards]
96 | self.local_classes = None if classes is None else classes[shard:][::num_shards]
97 | self.random_crop = random_crop
98 | self.random_flip = random_flip
99 |
100 | def __len__(self):
101 | return len(self.local_images)
102 |
103 | def __getitem__(self, idx):
104 | path = self.local_images[idx]
105 | with bf.BlobFile(path, "rb") as f:
106 | pil_image = Image.open(f)
107 | pil_image.load()
108 | pil_image = pil_image.convert("RGB")
109 |
110 | if self.random_crop:
111 | arr = random_crop_arr(pil_image, self.resolution)
112 | else:
113 | arr = center_crop_arr(pil_image, self.resolution)
114 |
115 | if self.random_flip and random.random() < 0.5:
116 | arr = arr[:, ::-1]
117 |
118 | arr = arr.astype(np.float32) / 127.5 - 1
119 |
120 | out_dict = {}
121 | if self.local_classes is not None:
122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
123 | return np.transpose(arr, [2, 0, 1]), out_dict
124 |
125 |
126 | def center_crop_arr(pil_image, image_size):
127 | # We are not on a new enough PIL to support the `reducing_gap`
128 | # argument, which uses BOX downsampling at powers of two first.
129 | # Thus, we do it by hand to improve downsample quality.
130 | while min(*pil_image.size) >= 2 * image_size:
131 | pil_image = pil_image.resize(
132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
133 | )
134 |
135 | scale = image_size / min(*pil_image.size)
136 | pil_image = pil_image.resize(
137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
138 | )
139 |
140 | arr = np.array(pil_image)
141 | crop_y = (arr.shape[0] - image_size) // 2
142 | crop_x = (arr.shape[1] - image_size) // 2
143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
144 |
145 |
146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
150 |
151 | # We are not on a new enough PIL to support the `reducing_gap`
152 | # argument, which uses BOX downsampling at powers of two first.
153 | # Thus, we do it by hand to improve downsample quality.
154 | while min(*pil_image.size) >= 2 * smaller_dim_size:
155 | pil_image = pil_image.resize(
156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
157 | )
158 |
159 | scale = smaller_dim_size / min(*pil_image.size)
160 | pil_image = pil_image.resize(
161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
162 | )
163 |
164 | arr = np.array(pil_image)
165 | crop_y = random.randrange(arr.shape[0] - image_size + 1)
166 | crop_x = random.randrange(arr.shape[1] - image_size + 1)
167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]
168 |
--------------------------------------------------------------------------------
/masked_diffusion/logger.py:
--------------------------------------------------------------------------------
1 | """
2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4 | """
5 |
6 | import os
7 | import sys
8 | import shutil
9 | import os.path as osp
10 | import json
11 | import time
12 | import datetime
13 | import tempfile
14 | import warnings
15 | from collections import defaultdict
16 | from contextlib import contextmanager
17 |
18 | DEBUG = 10
19 | INFO = 20
20 | WARN = 30
21 | ERROR = 40
22 |
23 | DISABLED = 50
24 |
25 |
26 | class KVWriter(object):
27 | def writekvs(self, kvs):
28 | raise NotImplementedError
29 |
30 |
31 | class SeqWriter(object):
32 | def writeseq(self, seq):
33 | raise NotImplementedError
34 |
35 |
36 | class HumanOutputFormat(KVWriter, SeqWriter):
37 | def __init__(self, filename_or_file):
38 | if isinstance(filename_or_file, str):
39 | self.file = open(filename_or_file, "wt")
40 | self.own_file = True
41 | else:
42 | assert hasattr(filename_or_file, "read"), (
43 | "expected file or str, got %s" % filename_or_file
44 | )
45 | self.file = filename_or_file
46 | self.own_file = False
47 |
48 | def writekvs(self, kvs):
49 | # Create strings for printing
50 | key2str = {}
51 | for (key, val) in sorted(kvs.items()):
52 | if hasattr(val, "__float__"):
53 | valstr = "%-8.3g" % val
54 | else:
55 | valstr = str(val)
56 | key2str[self._truncate(key)] = self._truncate(valstr)
57 |
58 | # Find max widths
59 | if len(key2str) == 0:
60 | print("WARNING: tried to write empty key-value dict")
61 | return
62 | else:
63 | keywidth = max(map(len, key2str.keys()))
64 | valwidth = max(map(len, key2str.values()))
65 |
66 | # Write out the data
67 | dashes = "-" * (keywidth + valwidth + 7)
68 | lines = [dashes]
69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70 | lines.append(
71 | "| %s%s | %s%s |"
72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73 | )
74 | lines.append(dashes)
75 | self.file.write("\n".join(lines) + "\n")
76 |
77 | # Flush the output to the file
78 | self.file.flush()
79 |
80 | def _truncate(self, s):
81 | maxlen = 30
82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s
83 |
84 | def writeseq(self, seq):
85 | seq = list(seq)
86 | for (i, elem) in enumerate(seq):
87 | self.file.write(elem)
88 | if i < len(seq) - 1: # add space unless this is the last one
89 | self.file.write(" ")
90 | self.file.write("\n")
91 | self.file.flush()
92 |
93 | def close(self):
94 | if self.own_file:
95 | self.file.close()
96 |
97 |
98 | class JSONOutputFormat(KVWriter):
99 | def __init__(self, filename):
100 | self.file = open(filename, "wt")
101 |
102 | def writekvs(self, kvs):
103 | for k, v in sorted(kvs.items()):
104 | if hasattr(v, "dtype"):
105 | kvs[k] = float(v)
106 | self.file.write(json.dumps(kvs) + "\n")
107 | self.file.flush()
108 |
109 | def close(self):
110 | self.file.close()
111 |
112 |
113 | class CSVOutputFormat(KVWriter):
114 | def __init__(self, filename):
115 | self.file = open(filename, "w+t")
116 | self.keys = []
117 | self.sep = ","
118 |
119 | def writekvs(self, kvs):
120 | # Add our current row to the history
121 | extra_keys = list(kvs.keys() - self.keys)
122 | extra_keys.sort()
123 | if extra_keys:
124 | self.keys.extend(extra_keys)
125 | self.file.seek(0)
126 | lines = self.file.readlines()
127 | self.file.seek(0)
128 | for (i, k) in enumerate(self.keys):
129 | if i > 0:
130 | self.file.write(",")
131 | self.file.write(k)
132 | self.file.write("\n")
133 | for line in lines[1:]:
134 | self.file.write(line[:-1])
135 | self.file.write(self.sep * len(extra_keys))
136 | self.file.write("\n")
137 | for (i, k) in enumerate(self.keys):
138 | if i > 0:
139 | self.file.write(",")
140 | v = kvs.get(k)
141 | if v is not None:
142 | self.file.write(str(v))
143 | self.file.write("\n")
144 | self.file.flush()
145 |
146 | def close(self):
147 | self.file.close()
148 |
149 |
150 | class TensorBoardOutputFormat(KVWriter):
151 | """
152 | Dumps key/value pairs into TensorBoard's numeric format.
153 | """
154 |
155 | def __init__(self, dir):
156 | os.makedirs(dir, exist_ok=True)
157 | self.dir = dir
158 | self.step = 1
159 | prefix = "events"
160 | path = osp.join(osp.abspath(dir), prefix)
161 | import tensorflow as tf
162 | from tensorflow.python import pywrap_tensorflow
163 | from tensorflow.core.util import event_pb2
164 | from tensorflow.python.util import compat
165 |
166 | self.tf = tf
167 | self.event_pb2 = event_pb2
168 | self.pywrap_tensorflow = pywrap_tensorflow
169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
170 |
171 | def writekvs(self, kvs):
172 | def summary_val(k, v):
173 | kwargs = {"tag": k, "simple_value": float(v)}
174 | return self.tf.Summary.Value(**kwargs)
175 |
176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
178 | event.step = (
179 | self.step
180 | ) # is there any reason why you'd want to specify the step?
181 | self.writer.WriteEvent(event)
182 | self.writer.Flush()
183 | self.step += 1
184 |
185 | def close(self):
186 | if self.writer:
187 | self.writer.Close()
188 | self.writer = None
189 |
190 |
191 | def make_output_format(format, ev_dir, log_suffix=""):
192 | os.makedirs(ev_dir, exist_ok=True)
193 | if format == "stdout":
194 | return HumanOutputFormat(sys.stdout)
195 | elif format == "log":
196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
197 | elif format == "json":
198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
199 | elif format == "csv":
200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
201 | elif format == "tensorboard":
202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
203 | else:
204 | raise ValueError("Unknown format specified: %s" % (format,))
205 |
206 |
207 | # ================================================================
208 | # API
209 | # ================================================================
210 |
211 |
212 | def logkv(key, val):
213 | """
214 | Log a value of some diagnostic
215 | Call this once for each diagnostic quantity, each iteration
216 | If called many times, last value will be used.
217 | """
218 | get_current().logkv(key, val)
219 |
220 |
221 | def logkv_mean(key, val):
222 | """
223 | The same as logkv(), but if called many times, values averaged.
224 | """
225 | get_current().logkv_mean(key, val)
226 |
227 |
228 | def logkvs(d):
229 | """
230 | Log a dictionary of key-value pairs
231 | """
232 | for (k, v) in d.items():
233 | logkv(k, v)
234 |
235 |
236 | def dumpkvs():
237 | """
238 | Write all of the diagnostics from the current iteration
239 | """
240 | return get_current().dumpkvs()
241 |
242 |
243 | def getkvs():
244 | return get_current().name2val
245 |
246 |
247 | def log(*args, level=INFO):
248 | """
249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file).
250 | """
251 | get_current().log(*args, level=level)
252 |
253 |
254 | def debug(*args):
255 | log(*args, level=DEBUG)
256 |
257 |
258 | def info(*args):
259 | log(*args, level=INFO)
260 |
261 |
262 | def warn(*args):
263 | log(*args, level=WARN)
264 |
265 |
266 | def error(*args):
267 | log(*args, level=ERROR)
268 |
269 |
270 | def set_level(level):
271 | """
272 | Set logging threshold on current logger.
273 | """
274 | get_current().set_level(level)
275 |
276 |
277 | def set_comm(comm):
278 | get_current().set_comm(comm)
279 |
280 |
281 | def get_dir():
282 | """
283 | Get directory that log files are being written to.
284 | will be None if there is no output directory (i.e., if you didn't call start)
285 | """
286 | return get_current().get_dir()
287 |
288 |
289 | record_tabular = logkv
290 | dump_tabular = dumpkvs
291 |
292 |
293 | @contextmanager
294 | def profile_kv(scopename):
295 | logkey = "wait_" + scopename
296 | tstart = time.time()
297 | try:
298 | yield
299 | finally:
300 | get_current().name2val[logkey] += time.time() - tstart
301 |
302 |
303 | def profile(n):
304 | """
305 | Usage:
306 | @profile("my_func")
307 | def my_func(): code
308 | """
309 |
310 | def decorator_with_name(func):
311 | def func_wrapper(*args, **kwargs):
312 | with profile_kv(n):
313 | return func(*args, **kwargs)
314 |
315 | return func_wrapper
316 |
317 | return decorator_with_name
318 |
319 |
320 | # ================================================================
321 | # Backend
322 | # ================================================================
323 |
324 |
325 | def get_current():
326 | if Logger.CURRENT is None:
327 | _configure_default_logger()
328 |
329 | return Logger.CURRENT
330 |
331 |
332 | class Logger(object):
333 | DEFAULT = None # A logger with no output files. (See right below class definition)
334 | # So that you can still log to the terminal without setting up any output files
335 | CURRENT = None # Current logger being used by the free functions above
336 |
337 | def __init__(self, dir, output_formats, comm=None):
338 | self.name2val = defaultdict(float) # values this iteration
339 | self.name2cnt = defaultdict(int)
340 | self.level = INFO
341 | self.dir = dir
342 | self.output_formats = output_formats
343 | self.comm = comm
344 |
345 | # Logging API, forwarded
346 | # ----------------------------------------
347 | def logkv(self, key, val):
348 | self.name2val[key] = val
349 |
350 | def logkv_mean(self, key, val):
351 | oldval, cnt = self.name2val[key], self.name2cnt[key]
352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
353 | self.name2cnt[key] = cnt + 1
354 |
355 | def dumpkvs(self):
356 | if self.comm is None:
357 | d = self.name2val
358 | else:
359 | d = mpi_weighted_mean(
360 | self.comm,
361 | {
362 | name: (val, self.name2cnt.get(name, 1))
363 | for (name, val) in self.name2val.items()
364 | },
365 | )
366 | if self.comm.rank != 0:
367 | d["dummy"] = 1 # so we don't get a warning about empty dict
368 | out = d.copy() # Return the dict for unit testing purposes
369 | for fmt in self.output_formats:
370 | if isinstance(fmt, KVWriter):
371 | fmt.writekvs(d)
372 | self.name2val.clear()
373 | self.name2cnt.clear()
374 | return out
375 |
376 | def log(self, *args, level=INFO):
377 | if self.level <= level:
378 | self._do_log(args)
379 |
380 | # Configuration
381 | # ----------------------------------------
382 | def set_level(self, level):
383 | self.level = level
384 |
385 | def set_comm(self, comm):
386 | self.comm = comm
387 |
388 | def get_dir(self):
389 | return self.dir
390 |
391 | def close(self):
392 | for fmt in self.output_formats:
393 | fmt.close()
394 |
395 | # Misc
396 | # ----------------------------------------
397 | def _do_log(self, args):
398 | for fmt in self.output_formats:
399 | if isinstance(fmt, SeqWriter):
400 | fmt.writeseq(map(str, args))
401 |
402 |
403 | def get_rank_without_mpi_import():
404 | # check environment variables here instead of importing mpi4py
405 | # to avoid calling MPI_Init() when this module is imported
406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
407 | if varname in os.environ:
408 | return int(os.environ[varname])
409 | return 0
410 |
411 |
412 | def mpi_weighted_mean(comm, local_name2valcount):
413 | """
414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
415 | Perform a weighted average over dicts that are each on a different node
416 | Input: local_name2valcount: dict mapping key -> (value, count)
417 | Returns: key -> mean
418 | """
419 | all_name2valcount = comm.gather(local_name2valcount)
420 | if comm.rank == 0:
421 | name2sum = defaultdict(float)
422 | name2count = defaultdict(float)
423 | for n2vc in all_name2valcount:
424 | for (name, (val, count)) in n2vc.items():
425 | try:
426 | val = float(val)
427 | except ValueError:
428 | if comm.rank == 0:
429 | warnings.warn(
430 | "WARNING: tried to compute mean on non-float {}={}".format(
431 | name, val
432 | )
433 | )
434 | else:
435 | name2sum[name] += val * count
436 | name2count[name] += count
437 | return {name: name2sum[name] / name2count[name] for name in name2sum}
438 | else:
439 | return {}
440 |
441 |
442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
443 | """
444 | If comm is provided, average all numerical stats across that comm
445 | """
446 | if dir is None:
447 | dir = os.getenv("OPENAI_LOGDIR")
448 | if dir is None:
449 | dir = osp.join(
450 | tempfile.gettempdir(),
451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
452 | )
453 | assert isinstance(dir, str)
454 | dir = os.path.expanduser(dir)
455 | os.makedirs(os.path.expanduser(dir), exist_ok=True)
456 |
457 | rank = get_rank_without_mpi_import()
458 | if rank > 0:
459 | log_suffix = log_suffix + "-rank%03i" % rank
460 |
461 | if format_strs is None:
462 | if rank == 0:
463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
464 | else:
465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
466 | format_strs = filter(None, format_strs)
467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
468 |
469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
470 | if output_formats:
471 | log("Logging to %s" % dir)
472 |
473 |
474 | def _configure_default_logger():
475 | configure()
476 | Logger.DEFAULT = Logger.CURRENT
477 |
478 |
479 | def reset():
480 | if Logger.CURRENT is not Logger.DEFAULT:
481 | Logger.CURRENT.close()
482 | Logger.CURRENT = Logger.DEFAULT
483 | log("Reset logger")
484 |
485 |
486 | @contextmanager
487 | def scoped_configure(dir=None, format_strs=None, comm=None):
488 | prevlogger = Logger.CURRENT
489 | configure(dir=dir, format_strs=format_strs, comm=comm)
490 | try:
491 | yield
492 | finally:
493 | Logger.CURRENT.close()
494 | Logger.CURRENT = prevlogger
495 |
496 |
--------------------------------------------------------------------------------
/masked_diffusion/models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | import math
5 | from timm.models.vision_transformer import PatchEmbed, Mlp
6 | from timm.models.layers import trunc_normal_
7 | import math
8 |
9 | def modulate(x, shift, scale):
10 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
11 |
12 |
13 | class Attention(nn.Module):
14 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., num_patches=None):
15 | super().__init__()
16 | self.num_heads = num_heads
17 | head_dim = dim // num_heads
18 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
19 | self.scale = qk_scale or head_dim ** -0.5
20 |
21 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
22 | self.attn_drop = nn.Dropout(attn_drop)
23 | self.proj = nn.Linear(dim, dim)
24 | self.proj_drop = nn.Dropout(proj_drop)
25 | self.rel_pos_bias = RelativePositionBias(
26 | window_size=[int(num_patches**0.5), int(num_patches**0.5)], num_heads=num_heads)
27 |
28 | def get_masked_rel_bias(self, B, ids_keep):
29 | # get masked rel_pos_bias
30 | rel_pos_bias = self.rel_pos_bias()
31 | rel_pos_bias = rel_pos_bias.unsqueeze(dim=0).repeat(B, 1, 1, 1)
32 |
33 | rel_pos_bias_masked = torch.gather(
34 | rel_pos_bias, dim=2, index=ids_keep.unsqueeze(dim=1).unsqueeze(dim=-1).repeat(1, rel_pos_bias.shape[1], 1, rel_pos_bias.shape[-1]))
35 | rel_pos_bias_masked = torch.gather(
36 | rel_pos_bias_masked, dim=3, index=ids_keep.unsqueeze(dim=1).unsqueeze(dim=2).repeat(1, rel_pos_bias.shape[1], ids_keep.shape[1], 1))
37 | return rel_pos_bias_masked
38 |
39 | def forward(self, x, ids_keep=None):
40 | B, N, C = x.shape
41 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
42 | self.num_heads).permute(2, 0, 3, 1, 4)
43 | # make torchscript happy (cannot use tensor as tuple)
44 | q, k, v = qkv[0], qkv[1], qkv[2]
45 |
46 | attn = (q @ k.transpose(-2, -1)) * self.scale
47 | if ids_keep is not None:
48 | rp_bias = self.get_masked_rel_bias(B, ids_keep)
49 | else:
50 | rp_bias = self.rel_pos_bias()
51 | attn += rp_bias
52 | attn = attn.softmax(dim=-1)
53 | attn = self.attn_drop(attn)
54 |
55 | x = (attn @ v).transpose(1, 2).reshape(B, N, C)
56 | x = self.proj(x)
57 | x = self.proj_drop(x)
58 | return x
59 |
60 |
61 | class RelativePositionBias(nn.Module):
62 | # https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py
63 | def __init__(self, window_size, num_heads):
64 | super().__init__()
65 | self.window_size = window_size
66 | self.num_relative_distance = (
67 | 2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
68 | self.relative_position_bias_table = nn.Parameter(
69 | torch.zeros(self.num_relative_distance, num_heads))
70 |
71 | # get pair-wise relative position index for each token inside the window
72 | coords_h = torch.arange(window_size[0])
73 | coords_w = torch.arange(window_size[1])
74 | coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
75 | coords_flatten = torch.flatten(coords, 1)
76 | relative_coords = coords_flatten[:, :, None] - \
77 | coords_flatten[:, None, :]
78 | relative_coords = relative_coords.permute(
79 | 1, 2, 0).contiguous()
80 | relative_coords[:, :, 0] += window_size[0] - 1
81 | relative_coords[:, :, 1] += window_size[1] - 1
82 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1
83 | relative_position_index = \
84 | torch.zeros(
85 | size=(window_size[0] * window_size[1],) * 2, dtype=relative_coords.dtype)
86 | relative_position_index = relative_coords.sum(-1)
87 |
88 | self.register_buffer("relative_position_index",
89 | relative_position_index)
90 |
91 | trunc_normal_(self.relative_position_bias_table, std=.02)
92 |
93 | def forward(self):
94 | relative_position_bias = \
95 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
96 | self.window_size[0] * self.window_size[1],
97 | self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
98 | # nH, Wh*Ww, Wh*Ww
99 | return relative_position_bias.permute(2, 0, 1).contiguous()
100 |
101 | #################################################################################
102 | # Embedding Layers for Timesteps and Class Labels #
103 | #################################################################################
104 |
105 |
106 | class TimestepEmbedder(nn.Module):
107 | """
108 | Embeds scalar timesteps into vector representations.
109 | """
110 |
111 | def __init__(self, hidden_size, frequency_embedding_size=256):
112 | super().__init__()
113 | self.mlp = nn.Sequential(
114 | nn.Linear(frequency_embedding_size, hidden_size, bias=True),
115 | nn.SiLU(),
116 | nn.Linear(hidden_size, hidden_size, bias=True),
117 | )
118 | self.frequency_embedding_size = frequency_embedding_size
119 |
120 | @staticmethod
121 | def timestep_embedding(t, dim, max_period=10000):
122 | """
123 | Create sinusoidal timestep embeddings.
124 | :param t: a 1-D Tensor of N indices, one per batch element.
125 | These may be fractional.
126 | :param dim: the dimension of the output.
127 | :param max_period: controls the minimum frequency of the embeddings.
128 | :return: an (N, D) Tensor of positional embeddings.
129 | """
130 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
131 | half = dim // 2
132 | freqs = torch.exp(
133 | -math.log(max_period) * torch.arange(start=0,
134 | end=half, dtype=torch.float32) / half
135 | ).to(device=t.device)
136 | args = t[:, None].float() * freqs[None]
137 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
138 | if dim % 2:
139 | embedding = torch.cat(
140 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
141 | return embedding
142 |
143 | def forward(self, t):
144 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
145 | t_emb = self.mlp(t_freq)
146 | return t_emb
147 |
148 |
149 | class LabelEmbedder(nn.Module):
150 | """
151 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
152 | """
153 |
154 | def __init__(self, num_classes, hidden_size, dropout_prob):
155 | super().__init__()
156 | use_cfg_embedding = dropout_prob > 0
157 | self.embedding_table = nn.Embedding(
158 | num_classes + use_cfg_embedding, hidden_size)
159 | self.num_classes = num_classes
160 | self.dropout_prob = dropout_prob
161 |
162 | def token_drop(self, labels, force_drop_ids=None):
163 | """
164 | Drops labels to enable classifier-free guidance.
165 | """
166 | if force_drop_ids is None:
167 | drop_ids = torch.rand(labels.shape[0]) < self.dropout_prob
168 | else:
169 | drop_ids = force_drop_ids == 1
170 |
171 | labels = torch.where(drop_ids.to(labels.device),
172 | self.num_classes, labels)
173 | return labels
174 |
175 | def forward(self, labels, train, force_drop_ids=None):
176 | use_dropout = self.dropout_prob > 0
177 | if (train and use_dropout) or (force_drop_ids is not None):
178 | labels = self.token_drop(labels, force_drop_ids)
179 | embeddings = self.embedding_table(labels)
180 | return embeddings
181 |
182 |
183 | #################################################################################
184 | # Core MDT Model #
185 | #################################################################################
186 |
187 | class MDTBlock(nn.Module):
188 | """
189 | A MDT block with adaptive layer norm zero (adaLN-Zero) conditioning.
190 | """
191 |
192 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, skip=False, **block_kwargs):
193 | super().__init__()
194 | self.norm1 = nn.LayerNorm(
195 | hidden_size, elementwise_affine=False, eps=1e-6)
196 | self.attn = Attention(
197 | hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
198 | self.norm2 = nn.LayerNorm(
199 | hidden_size, elementwise_affine=False, eps=1e-6)
200 | mlp_hidden_dim = int(hidden_size * mlp_ratio)
201 | def approx_gelu(): return nn.GELU(approximate="tanh")
202 | self.mlp = Mlp(in_features=hidden_size,
203 | hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
204 | self.adaLN_modulation = nn.Sequential(
205 | nn.SiLU(),
206 | nn.Linear(hidden_size, 6 * hidden_size, bias=True)
207 | )
208 | self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None
209 |
210 | def forward(self, x, c, skip=None, ids_keep=None):
211 | if self.skip_linear is not None:
212 | x = self.skip_linear(torch.cat([x, skip], dim=-1))
213 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(
214 | c).chunk(6, dim=1)
215 | x = x + gate_msa.unsqueeze(1) * self.attn(
216 | modulate(self.norm1(x), shift_msa, scale_msa), ids_keep=ids_keep)
217 | x = x + \
218 | gate_mlp.unsqueeze(
219 | 1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
220 | return x
221 |
222 |
223 | class FinalLayer(nn.Module):
224 | """
225 | The final layer of MDT.
226 | """
227 |
228 | def __init__(self, hidden_size, patch_size, out_channels):
229 | super().__init__()
230 | self.norm_final = nn.LayerNorm(
231 | hidden_size, elementwise_affine=False, eps=1e-6)
232 | self.linear = nn.Linear(
233 | hidden_size, patch_size * patch_size * out_channels, bias=True)
234 | self.adaLN_modulation = nn.Sequential(
235 | nn.SiLU(),
236 | nn.Linear(hidden_size, 2 * hidden_size, bias=True)
237 | )
238 |
239 | def forward(self, x, c):
240 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
241 | x = modulate(self.norm_final(x), shift, scale)
242 | x = self.linear(x)
243 | return x
244 |
245 |
246 | class MDTv2(nn.Module):
247 | """
248 | Masked Diffusion Transformer v2.
249 | """
250 |
251 | def __init__(
252 | self,
253 | input_size=32,
254 | patch_size=2,
255 | in_channels=4,
256 | hidden_size=1152,
257 | depth=28,
258 | num_heads=16,
259 | mlp_ratio=4.0,
260 | class_dropout_prob=0.1,
261 | num_classes=1000,
262 | learn_sigma=True,
263 | mask_ratio=None,
264 | decode_layer=4,
265 | ):
266 | super().__init__()
267 | self.learn_sigma = learn_sigma
268 | self.in_channels = in_channels
269 | self.out_channels = in_channels * 2 if learn_sigma else in_channels
270 | self.patch_size = patch_size
271 | self.num_heads = num_heads
272 | decode_layer = int(decode_layer)
273 |
274 | self.x_embedder = PatchEmbed(
275 | input_size, patch_size, in_channels, hidden_size, bias=True)
276 | self.t_embedder = TimestepEmbedder(hidden_size)
277 | self.y_embedder = LabelEmbedder(
278 | num_classes, hidden_size, class_dropout_prob)
279 | num_patches = self.x_embedder.num_patches
280 | # Will use learnbale sin-cos embedding:
281 | self.pos_embed = nn.Parameter(torch.zeros(
282 | 1, num_patches, hidden_size), requires_grad=True)
283 |
284 | half_depth = (depth - decode_layer)//2
285 | self.half_depth=half_depth
286 |
287 | self.en_inblocks = nn.ModuleList([
288 | MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches) for _ in range(half_depth)
289 | ])
290 | self.en_outblocks = nn.ModuleList([
291 | MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches, skip=True) for _ in range(half_depth)
292 | ])
293 | self.de_blocks = nn.ModuleList([
294 | MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches, skip=True) for i in range(decode_layer)
295 | ])
296 | self.sideblocks = nn.ModuleList([
297 | MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches) for _ in range(1)
298 | ])
299 | self.final_layer = FinalLayer(
300 | hidden_size, patch_size, self.out_channels)
301 |
302 | self.decoder_pos_embed = nn.Parameter(torch.zeros(
303 | 1, num_patches, hidden_size), requires_grad=True)
304 | if mask_ratio is not None:
305 | self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
306 | self.mask_ratio = float(mask_ratio)
307 | self.decode_layer = int(decode_layer)
308 | else:
309 | self.mask_token = nn.Parameter(torch.zeros(
310 | 1, 1, hidden_size), requires_grad=False)
311 | self.mask_ratio = None
312 | self.decode_layer = int(decode_layer)
313 | print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer)
314 | self.initialize_weights()
315 |
316 | def initialize_weights(self):
317 | # Initialize transformer layers:
318 | def _basic_init(module):
319 | if isinstance(module, nn.Linear):
320 | torch.nn.init.xavier_uniform_(module.weight)
321 | if module.bias is not None:
322 | nn.init.constant_(module.bias, 0)
323 | self.apply(_basic_init)
324 |
325 | # Initialize pos_embed by sin-cos embedding:
326 | pos_embed = get_2d_sincos_pos_embed(
327 | self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
328 | self.pos_embed.data.copy_(
329 | torch.from_numpy(pos_embed).float().unsqueeze(0))
330 |
331 | decoder_pos_embed = get_2d_sincos_pos_embed(
332 | self.decoder_pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
333 | self.decoder_pos_embed.data.copy_(
334 | torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
335 |
336 | # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
337 | w = self.x_embedder.proj.weight.data
338 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
339 | nn.init.constant_(self.x_embedder.proj.bias, 0)
340 |
341 | # Initialize label embedding table:
342 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
343 |
344 | # Initialize timestep embedding MLP:
345 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
346 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
347 |
348 | for block in self.en_inblocks:
349 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
350 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
351 |
352 | for block in self.en_outblocks:
353 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
354 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
355 |
356 | for block in self.de_blocks:
357 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
358 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
359 |
360 | for block in self.sideblocks:
361 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
362 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
363 |
364 | # Zero-out output layers:
365 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
366 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
367 | nn.init.constant_(self.final_layer.linear.weight, 0)
368 | nn.init.constant_(self.final_layer.linear.bias, 0)
369 |
370 | if self.mask_ratio is not None:
371 | torch.nn.init.normal_(self.mask_token, std=.02)
372 |
373 | def unpatchify(self, x):
374 | """
375 | x: (N, T, patch_size**2 * C)
376 | imgs: (N, H, W, C)
377 | """
378 | c = self.out_channels
379 | p = self.x_embedder.patch_size[0]
380 | h = w = int(x.shape[1] ** 0.5)
381 | assert h * w == x.shape[1]
382 |
383 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
384 | x = torch.einsum('nhwpqc->nchpwq', x)
385 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
386 | return imgs
387 |
388 | def random_masking(self, x, mask_ratio):
389 | """
390 | Perform per-sample random masking by per-sample shuffling.
391 | Per-sample shuffling is done by argsort random noise.
392 | x: [N, L, D], sequence
393 | """
394 | N, L, D = x.shape # batch, length, dim
395 | len_keep = int(L * (1 - mask_ratio))
396 |
397 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
398 |
399 | # sort noise for each sample
400 | # ascend: small is keep, large is remove
401 | ids_shuffle = torch.argsort(noise, dim=1)
402 | ids_restore = torch.argsort(ids_shuffle, dim=1)
403 |
404 | # keep the first subset
405 | ids_keep = ids_shuffle[:, :len_keep]
406 | x_masked = torch.gather(
407 | x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
408 |
409 | # generate the binary mask: 0 is keep, 1 is remove
410 | mask = torch.ones([N, L], device=x.device)
411 | mask[:, :len_keep] = 0
412 | # unshuffle to get the binary mask
413 | mask = torch.gather(mask, dim=1, index=ids_restore)
414 |
415 | return x_masked, mask, ids_restore, ids_keep
416 |
417 | def forward_side_interpolater(self, x, c, mask, ids_restore):
418 | # append mask tokens to sequence
419 | mask_tokens = self.mask_token.repeat(
420 | x.shape[0], ids_restore.shape[1] - x.shape[1], 1)
421 | x_ = torch.cat([x, mask_tokens], dim=1)
422 | x = torch.gather(
423 | x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle
424 |
425 | # add pos embed
426 | x = x + self.decoder_pos_embed
427 |
428 | # pass to the basic block
429 | x_before = x
430 | for sideblock in self.sideblocks:
431 | x = sideblock(x, c, ids_keep=None)
432 |
433 | # masked shortcut
434 | mask = mask.unsqueeze(dim=-1)
435 | x = x*mask + (1-mask)*x_before
436 |
437 | return x
438 |
439 | def forward(self, x, t, y, enable_mask=False):
440 | """
441 | Forward pass of MDT.
442 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
443 | t: (N,) tensor of diffusion timesteps
444 | y: (N,) tensor of class labels
445 | enable_mask: Use mask latent modeling
446 | """
447 | x = self.x_embedder(
448 | x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
449 |
450 | t = self.t_embedder(t) # (N, D)
451 | y = self.y_embedder(y, self.training) # (N, D)
452 | c = t + y # (N, D)
453 |
454 |
455 | input_skip = x
456 |
457 | masked_stage = False
458 | skips = []
459 | # masking op for training
460 | if self.mask_ratio is not None and enable_mask:
461 | # masking: length -> length * mask_ratio
462 | rand_mask_ratio = torch.rand(1, device=x.device) # noise in [0, 1]
463 | rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2
464 | # print(rand_mask_ratio)
465 | x, mask, ids_restore, ids_keep = self.random_masking(
466 | x, rand_mask_ratio)
467 | masked_stage = True
468 |
469 |
470 | for block in self.en_inblocks:
471 | if masked_stage:
472 | x = block(x, c, ids_keep=ids_keep)
473 | else:
474 | x = block(x, c, ids_keep=None)
475 | skips.append(x)
476 |
477 | for block in self.en_outblocks:
478 | if masked_stage:
479 | x = block(x, c, skip=skips.pop(), ids_keep=ids_keep)
480 | else:
481 | x = block(x, c, skip=skips.pop(), ids_keep=None)
482 |
483 | if self.mask_ratio is not None and enable_mask:
484 | x = self.forward_side_interpolater(x, c, mask, ids_restore)
485 | masked_stage = False
486 | else:
487 | # add pos embed
488 | x = x + self.decoder_pos_embed
489 |
490 | for i in range(len(self.de_blocks)):
491 | block = self.de_blocks[i]
492 | this_skip = input_skip
493 |
494 | x = block(x, c, skip=this_skip, ids_keep=None)
495 |
496 | x = self.final_layer(x, c)
497 | x = self.unpatchify(x) # (N, out_channels, H, W)
498 | return x
499 |
500 |
501 | def forward_with_cfg(self, x, t, y, cfg_scale=None, diffusion_steps=1000, scale_pow=4.0):
502 | """
503 | Forward pass of MDT, but also batches the unconditional forward pass for classifier-free guidance.
504 | """
505 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
506 | if cfg_scale is not None:
507 | half = x[: len(x) // 2]
508 | combined = torch.cat([half, half], dim=0)
509 | model_out = self.forward(combined, t, y)
510 | eps, rest = model_out[:, :3], model_out[:, 3:]
511 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
512 |
513 | scale_step = (
514 | 1-torch.cos(((1-t/diffusion_steps)**scale_pow)*math.pi))*1/2 # power-cos scaling
515 | real_cfg_scale = (cfg_scale-1)*scale_step + 1
516 | real_cfg_scale = real_cfg_scale[: len(x) // 2].view(-1, 1, 1, 1)
517 |
518 | half_eps = uncond_eps + real_cfg_scale * (cond_eps - uncond_eps)
519 | eps = torch.cat([half_eps, half_eps], dim=0)
520 | return torch.cat([eps, rest], dim=1)
521 | else:
522 | model_out = self.forward(x, t, y)
523 | eps, rest = model_out[:, :3], model_out[:, 3:]
524 | return torch.cat([eps, rest], dim=1)
525 |
526 |
527 | #################################################################################
528 | # Sine/Cosine Positional Embedding Functions #
529 | #################################################################################
530 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
531 |
532 |
533 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
534 | """
535 | grid_size: int of the grid height and width
536 | return:
537 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
538 | """
539 | grid_h = np.arange(grid_size, dtype=np.float32)
540 | grid_w = np.arange(grid_size, dtype=np.float32)
541 | grid = np.meshgrid(grid_w, grid_h) # here w goes first
542 | grid = np.stack(grid, axis=0)
543 |
544 | grid = grid.reshape([2, 1, grid_size, grid_size])
545 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
546 | if cls_token and extra_tokens > 0:
547 | pos_embed = np.concatenate(
548 | [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
549 | return pos_embed
550 |
551 |
552 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
553 | assert embed_dim % 2 == 0
554 |
555 | # use half of dimensions to encode grid_h
556 | emb_h = get_1d_sincos_pos_embed_from_grid(
557 | embed_dim // 2, grid[0]) # (H*W, D/2)
558 | emb_w = get_1d_sincos_pos_embed_from_grid(
559 | embed_dim // 2, grid[1]) # (H*W, D/2)
560 |
561 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
562 | return emb
563 |
564 |
565 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
566 | """
567 | embed_dim: output dimension for each position
568 | pos: a list of positions to be encoded: size (M,)
569 | out: (M, D)
570 | """
571 | assert embed_dim % 2 == 0
572 | omega = np.arange(embed_dim // 2, dtype=np.float64)
573 | omega /= embed_dim / 2.
574 | omega = 1. / 10000**omega # (D/2,)
575 |
576 | pos = pos.reshape(-1) # (M,)
577 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
578 |
579 | emb_sin = np.sin(out) # (M, D/2)
580 | emb_cos = np.cos(out) # (M, D/2)
581 |
582 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
583 | return emb
584 |
585 |
586 | #################################################################################
587 | # MDTv2 Configs #
588 | #################################################################################
589 |
590 | def MDTv2_XL_2(**kwargs):
591 | return MDTv2(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
592 |
593 | def MDTv2_L_2(**kwargs):
594 | return MDTv2(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
595 |
596 | def MDTv2_B_2(**kwargs):
597 | return MDTv2(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
598 |
599 | def MDTv2_S_2(**kwargs):
600 | return MDTv2(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
601 |
602 |
--------------------------------------------------------------------------------
/masked_diffusion/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | import math
6 |
7 | import torch as th
8 | import torch.nn as nn
9 |
10 |
11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
12 | class SiLU(nn.Module):
13 | def forward(self, x):
14 | return x * th.sigmoid(x)
15 |
16 |
17 | class GroupNorm32(nn.GroupNorm):
18 | def forward(self, x):
19 | return super().forward(x.float()).type(x.dtype)
20 |
21 |
22 | def conv_nd(dims, *args, **kwargs):
23 | """
24 | Create a 1D, 2D, or 3D convolution module.
25 | """
26 | if dims == 1:
27 | return nn.Conv1d(*args, **kwargs)
28 | elif dims == 2:
29 | return nn.Conv2d(*args, **kwargs)
30 | elif dims == 3:
31 | return nn.Conv3d(*args, **kwargs)
32 | raise ValueError(f"unsupported dimensions: {dims}")
33 |
34 |
35 | def linear(*args, **kwargs):
36 | """
37 | Create a linear module.
38 | """
39 | return nn.Linear(*args, **kwargs)
40 |
41 |
42 | def avg_pool_nd(dims, *args, **kwargs):
43 | """
44 | Create a 1D, 2D, or 3D average pooling module.
45 | """
46 | if dims == 1:
47 | return nn.AvgPool1d(*args, **kwargs)
48 | elif dims == 2:
49 | return nn.AvgPool2d(*args, **kwargs)
50 | elif dims == 3:
51 | return nn.AvgPool3d(*args, **kwargs)
52 | raise ValueError(f"unsupported dimensions: {dims}")
53 |
54 |
55 | def update_ema(target_params, source_params, rate=0.99):
56 | """
57 | Update target parameters to be closer to those of source parameters using
58 | an exponential moving average.
59 |
60 | :param target_params: the target parameter sequence.
61 | :param source_params: the source parameter sequence.
62 | :param rate: the EMA rate (closer to 1 means slower).
63 | """
64 | for targ, src in zip(target_params, source_params):
65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
66 |
67 |
68 | def zero_module(module):
69 | """
70 | Zero out the parameters of a module and return it.
71 | """
72 | for p in module.parameters():
73 | p.detach().zero_()
74 | return module
75 |
76 |
77 | def scale_module(module, scale):
78 | """
79 | Scale the parameters of a module and return it.
80 | """
81 | for p in module.parameters():
82 | p.detach().mul_(scale)
83 | return module
84 |
85 |
86 | def mean_flat(tensor):
87 | """
88 | Take the mean over all non-batch dimensions.
89 | """
90 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
91 |
92 |
93 | def normalization(channels):
94 | """
95 | Make a standard normalization layer.
96 |
97 | :param channels: number of input channels.
98 | :return: an nn.Module for normalization.
99 | """
100 | return GroupNorm32(32, channels)
101 |
102 |
103 | def timestep_embedding(timesteps, dim, max_period=10000):
104 | """
105 | Create sinusoidal timestep embeddings.
106 |
107 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
108 | These may be fractional.
109 | :param dim: the dimension of the output.
110 | :param max_period: controls the minimum frequency of the embeddings.
111 | :return: an [N x dim] Tensor of positional embeddings.
112 | """
113 | half = dim // 2
114 | freqs = th.exp(
115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
116 | ).to(device=timesteps.device)
117 | args = timesteps[:, None].float() * freqs[None]
118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
119 | if dim % 2:
120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
121 | return embedding
122 |
123 |
124 | def checkpoint(func, inputs, params, flag):
125 | """
126 | Evaluate a function without caching intermediate activations, allowing for
127 | reduced memory at the expense of extra compute in the backward pass.
128 |
129 | :param func: the function to evaluate.
130 | :param inputs: the argument sequence to pass to `func`.
131 | :param params: a sequence of parameters `func` depends on but does not
132 | explicitly take as arguments.
133 | :param flag: if False, disable gradient checkpointing.
134 | """
135 | if flag:
136 | args = tuple(inputs) + tuple(params)
137 | return CheckpointFunction.apply(func, len(inputs), *args)
138 | else:
139 | return func(*inputs)
140 |
141 |
142 | class CheckpointFunction(th.autograd.Function):
143 | @staticmethod
144 | def forward(ctx, run_function, length, *args):
145 | ctx.run_function = run_function
146 | ctx.input_tensors = list(args[:length])
147 | ctx.input_params = list(args[length:])
148 | with th.no_grad():
149 | output_tensors = ctx.run_function(*ctx.input_tensors)
150 | return output_tensors
151 |
152 | @staticmethod
153 | def backward(ctx, *output_grads):
154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
155 | with th.enable_grad():
156 | # Fixes a bug where the first op in run_function modifies the
157 | # Tensor storage in place, which is not allowed for detach()'d
158 | # Tensors.
159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
160 | output_tensors = ctx.run_function(*shallow_copies)
161 | input_grads = th.autograd.grad(
162 | output_tensors,
163 | ctx.input_tensors + ctx.input_params,
164 | output_grads,
165 | allow_unused=True,
166 | )
167 | del ctx.input_tensors
168 | del ctx.input_params
169 | del output_tensors
170 | return (None, None) + input_grads
171 |
--------------------------------------------------------------------------------
/masked_diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | elif name == "loss-second-moment":
18 | return LossSecondMomentResampler(diffusion)
19 | else:
20 | raise NotImplementedError(f"unknown schedule sampler: {name}")
21 |
22 |
23 | class ScheduleSampler(ABC):
24 | """
25 | A distribution over timesteps in the diffusion process, intended to reduce
26 | variance of the objective.
27 |
28 | By default, samplers perform unbiased importance sampling, in which the
29 | objective's mean is unchanged.
30 | However, subclasses may override sample() to change how the resampled
31 | terms are reweighted, allowing for actual changes in the objective.
32 | """
33 |
34 | @abstractmethod
35 | def weights(self):
36 | """
37 | Get a numpy array of weights, one per diffusion step.
38 |
39 | The weights needn't be normalized, but must be positive.
40 | """
41 |
42 | def sample(self, batch_size, device):
43 | """
44 | Importance-sample timesteps for a batch.
45 |
46 | :param batch_size: the number of timesteps.
47 | :param device: the torch device to save to.
48 | :return: a tuple (timesteps, weights):
49 | - timesteps: a tensor of timestep indices.
50 | - weights: a tensor of weights to scale the resulting losses.
51 | """
52 | w = self.weights()
53 | p = w / np.sum(w)
54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55 | indices = th.from_numpy(indices_np).long().to(device)
56 | weights_np = 1 / (len(p) * p[indices_np])
57 | weights = th.from_numpy(weights_np).float().to(device)
58 | return indices, weights
59 |
60 |
61 | class UniformSampler(ScheduleSampler):
62 | def __init__(self, diffusion):
63 | self.diffusion = diffusion
64 | self._weights = np.ones([diffusion.num_timesteps])
65 |
66 | def weights(self):
67 | return self._weights
68 |
69 |
70 | class LossAwareSampler(ScheduleSampler):
71 | def update_with_local_losses(self, local_ts, local_losses):
72 | """
73 | Update the reweighting using losses from a model.
74 |
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 |
80 | :param local_ts: an integer Tensor of timesteps.
81 | :param local_losses: a 1D Tensor of losses.
82 | """
83 | batch_sizes = [
84 | th.tensor([0], dtype=th.int32, device=local_ts.device)
85 | for _ in range(dist.get_world_size())
86 | ]
87 | dist.all_gather(
88 | batch_sizes,
89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90 | )
91 |
92 | # Pad all_gather batches to be the maximum batch size.
93 | batch_sizes = [x.item() for x in batch_sizes]
94 | max_bs = max(batch_sizes)
95 |
96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98 | dist.all_gather(timestep_batches, local_ts)
99 | dist.all_gather(loss_batches, local_losses)
100 | timesteps = [
101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102 | ]
103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104 | self.update_with_all_losses(timesteps, losses)
105 |
106 | @abstractmethod
107 | def update_with_all_losses(self, ts, losses):
108 | """
109 | Update the reweighting using losses from a model.
110 |
111 | Sub-classes should override this method to update the reweighting
112 | using losses from the model.
113 |
114 | This method directly updates the reweighting without synchronizing
115 | between workers. It is called by update_with_local_losses from all
116 | ranks with identical arguments. Thus, it should have deterministic
117 | behavior to maintain state across workers.
118 |
119 | :param ts: a list of int timesteps.
120 | :param losses: a list of float losses, one per timestep.
121 | """
122 |
123 |
124 | class LossSecondMomentResampler(LossAwareSampler):
125 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 | self.diffusion = diffusion
127 | self.history_per_term = history_per_term
128 | self.uniform_prob = uniform_prob
129 | self._loss_history = np.zeros(
130 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 | )
132 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
133 |
134 | def weights(self):
135 | if not self._warmed_up():
136 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 | weights /= np.sum(weights)
139 | weights *= 1 - self.uniform_prob
140 | weights += self.uniform_prob / len(weights)
141 | return weights
142 |
143 | def update_with_all_losses(self, ts, losses):
144 | for t, loss in zip(ts, losses):
145 | if self._loss_counts[t] == self.history_per_term:
146 | # Shift out the oldest loss term.
147 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 | self._loss_history[t, -1] = loss
149 | else:
150 | self._loss_history[t, self._loss_counts[t]] = loss
151 | self._loss_counts[t] += 1
152 |
153 | def _warmed_up(self):
154 | return (self._loss_counts == self.history_per_term).all()
155 |
--------------------------------------------------------------------------------
/masked_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
12 | def space_timesteps(num_timesteps, section_counts):
13 | """
14 | Create a list of timesteps to use from an original diffusion process,
15 | given the number of timesteps we want to take from equally-sized portions
16 | of the original process.
17 | For example, if there's 300 timesteps and the section counts are [10,15,20]
18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 | If the stride is a string starting with "ddim", then the fixed striding
21 | from the DDIM paper is used, and only one section is allowed.
22 | :param num_timesteps: the number of diffusion steps in the original
23 | process to divide up.
24 | :param section_counts: either a list of numbers, or a string containing
25 | comma-separated numbers, indicating the step count
26 | per section. As a special case, use "ddimN" where N
27 | is a number of steps to use the striding from the
28 | DDIM paper.
29 | :return: a set of diffusion steps from the original process to use.
30 | """
31 | if isinstance(section_counts, str):
32 | if section_counts.startswith("ddim"):
33 | desired_count = int(section_counts[len("ddim") :])
34 | for i in range(1, num_timesteps):
35 | if len(range(0, num_timesteps, i)) == desired_count:
36 | return set(range(0, num_timesteps, i))
37 | raise ValueError(
38 | f"cannot create exactly {num_timesteps} steps with an integer stride"
39 | )
40 | section_counts = [int(x) for x in section_counts.split(",")]
41 | size_per = num_timesteps // len(section_counts)
42 | extra = num_timesteps % len(section_counts)
43 | start_idx = 0
44 | all_steps = []
45 | for i, section_count in enumerate(section_counts):
46 | size = size_per + (1 if i < extra else 0)
47 | if size < section_count:
48 | raise ValueError(
49 | f"cannot divide section of {size} steps into {section_count}"
50 | )
51 | if section_count <= 1:
52 | frac_stride = 1
53 | else:
54 | frac_stride = (size - 1) / (section_count - 1)
55 | cur_idx = 0.0
56 | taken_steps = []
57 | for _ in range(section_count):
58 | taken_steps.append(start_idx + round(cur_idx))
59 | cur_idx += frac_stride
60 | all_steps += taken_steps
61 | start_idx += size
62 | return set(all_steps)
63 |
64 |
65 | class SpacedDiffusion(GaussianDiffusion):
66 | """
67 | A diffusion process which can skip steps in a base diffusion process.
68 | :param use_timesteps: a collection (sequence or set) of timesteps from the
69 | original diffusion process to retain.
70 | :param kwargs: the kwargs to create the base diffusion process.
71 | """
72 |
73 | def __init__(self, use_timesteps, **kwargs):
74 | self.use_timesteps = set(use_timesteps)
75 | self.timestep_map = []
76 | self.original_num_steps = len(kwargs["betas"])
77 |
78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79 | last_alpha_cumprod = 1.0
80 | new_betas = []
81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 | if i in self.use_timesteps:
83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 | last_alpha_cumprod = alpha_cumprod
85 | self.timestep_map.append(i)
86 | kwargs["betas"] = np.array(new_betas)
87 | super().__init__(**kwargs)
88 |
89 | def p_mean_variance(
90 | self, model, *args, **kwargs
91 | ): # pylint: disable=signature-differs
92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 |
94 | def training_losses(
95 | self, model, *args, **kwargs
96 | ): # pylint: disable=signature-differs
97 | return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 |
99 | def condition_mean(self, cond_fn, *args, **kwargs):
100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 |
102 | def condition_score(self, cond_fn, *args, **kwargs):
103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 |
105 | def _wrap_model(self, model):
106 | if isinstance(model, _WrappedModel):
107 | return model
108 | return _WrappedModel(
109 | model, self.timestep_map, self.original_num_steps
110 | )
111 |
112 | def _scale_timesteps(self, t):
113 | # Scaling is done by the wrapped model.
114 | return t
115 |
116 | class _WrappedModel:
117 | def __init__(self, model, timestep_map, original_num_steps):
118 | self.model = model
119 | self.timestep_map = timestep_map
120 | # self.rescale_timesteps = rescale_timesteps
121 | self.original_num_steps = original_num_steps
122 |
123 | def __call__(self, x, ts, **kwargs):
124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
125 | new_ts = map_tensor[ts]
126 | # if self.rescale_timesteps:
127 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
128 | return self.model(x, new_ts, **kwargs)
129 |
--------------------------------------------------------------------------------
/masked_diffusion/script_util.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import inspect
3 |
4 | from . import gaussian_diffusion as gd
5 | from .respace import SpacedDiffusion, space_timesteps
6 |
7 | NUM_CLASSES = 1000
8 |
9 | def add_dict_to_argparser(parser, default_dict):
10 | for k, v in default_dict.items():
11 | v_type = type(v)
12 | if v is None:
13 | v_type = str
14 | elif isinstance(v, bool):
15 | v_type = str2bool
16 | parser.add_argument(f"--{k}", default=v, type=v_type)
17 |
18 |
19 | def args_to_dict(args, keys):
20 | return {k: getattr(args, k) for k in keys}
21 |
22 |
23 | def str2bool(v):
24 | """
25 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
26 | """
27 | if isinstance(v, bool):
28 | return v
29 | if v.lower() in ("yes", "true", "t", "y", "1"):
30 | return True
31 | elif v.lower() in ("no", "false", "f", "n", "0"):
32 | return False
33 | else:
34 | raise argparse.ArgumentTypeError("boolean value expected")
35 |
--------------------------------------------------------------------------------
/masked_diffusion/timestep_sampler.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | from abc import ABC, abstractmethod
7 |
8 | import numpy as np
9 | import torch as th
10 | import torch.distributed as dist
11 |
12 |
13 | def create_named_schedule_sampler(name, diffusion):
14 | """
15 | Create a ScheduleSampler from a library of pre-defined samplers.
16 | :param name: the name of the sampler.
17 | :param diffusion: the diffusion object to sample for.
18 | """
19 | if name == "uniform":
20 | return UniformSampler(diffusion)
21 | elif name == "loss-second-moment":
22 | return LossSecondMomentResampler(diffusion)
23 | else:
24 | raise NotImplementedError(f"unknown schedule sampler: {name}")
25 |
26 |
27 | class ScheduleSampler(ABC):
28 | """
29 | A distribution over timesteps in the diffusion process, intended to reduce
30 | variance of the objective.
31 | By default, samplers perform unbiased importance sampling, in which the
32 | objective's mean is unchanged.
33 | However, subclasses may override sample() to change how the resampled
34 | terms are reweighted, allowing for actual changes in the objective.
35 | """
36 |
37 | @abstractmethod
38 | def weights(self):
39 | """
40 | Get a numpy array of weights, one per diffusion step.
41 | The weights needn't be normalized, but must be positive.
42 | """
43 |
44 | def sample(self, batch_size, device):
45 | """
46 | Importance-sample timesteps for a batch.
47 | :param batch_size: the number of timesteps.
48 | :param device: the torch device to save to.
49 | :return: a tuple (timesteps, weights):
50 | - timesteps: a tensor of timestep indices.
51 | - weights: a tensor of weights to scale the resulting losses.
52 | """
53 | w = self.weights()
54 | p = w / np.sum(w)
55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56 | indices = th.from_numpy(indices_np).long().to(device)
57 | weights_np = 1 / (len(p) * p[indices_np])
58 | weights = th.from_numpy(weights_np).float().to(device)
59 | return indices, weights
60 |
61 |
62 | class UniformSampler(ScheduleSampler):
63 | def __init__(self, diffusion):
64 | self.diffusion = diffusion
65 | self._weights = np.ones([diffusion.num_timesteps])
66 |
67 | def weights(self):
68 | return self._weights
69 |
70 |
71 | class LossAwareSampler(ScheduleSampler):
72 | def update_with_local_losses(self, local_ts, local_losses):
73 | """
74 | Update the reweighting using losses from a model.
75 | Call this method from each rank with a batch of timesteps and the
76 | corresponding losses for each of those timesteps.
77 | This method will perform synchronization to make sure all of the ranks
78 | maintain the exact same reweighting.
79 | :param local_ts: an integer Tensor of timesteps.
80 | :param local_losses: a 1D Tensor of losses.
81 | """
82 | batch_sizes = [
83 | th.tensor([0], dtype=th.int32, device=local_ts.device)
84 | for _ in range(dist.get_world_size())
85 | ]
86 | dist.all_gather(
87 | batch_sizes,
88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89 | )
90 |
91 | # Pad all_gather batches to be the maximum batch size.
92 | batch_sizes = [x.item() for x in batch_sizes]
93 | max_bs = max(batch_sizes)
94 |
95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97 | dist.all_gather(timestep_batches, local_ts)
98 | dist.all_gather(loss_batches, local_losses)
99 | timesteps = [
100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101 | ]
102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103 | self.update_with_all_losses(timesteps, losses)
104 |
105 | @abstractmethod
106 | def update_with_all_losses(self, ts, losses):
107 | """
108 | Update the reweighting using losses from a model.
109 | Sub-classes should override this method to update the reweighting
110 | using losses from the model.
111 | This method directly updates the reweighting without synchronizing
112 | between workers. It is called by update_with_local_losses from all
113 | ranks with identical arguments. Thus, it should have deterministic
114 | behavior to maintain state across workers.
115 | :param ts: a list of int timesteps.
116 | :param losses: a list of float losses, one per timestep.
117 | """
118 |
119 |
120 | class LossSecondMomentResampler(LossAwareSampler):
121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122 | self.diffusion = diffusion
123 | self.history_per_term = history_per_term
124 | self.uniform_prob = uniform_prob
125 | self._loss_history = np.zeros(
126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64
127 | )
128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
129 |
130 | def weights(self):
131 | if not self._warmed_up():
132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134 | weights /= np.sum(weights)
135 | weights *= 1 - self.uniform_prob
136 | weights += self.uniform_prob / len(weights)
137 | return weights
138 |
139 | def update_with_all_losses(self, ts, losses):
140 | for t, loss in zip(ts, losses):
141 | if self._loss_counts[t] == self.history_per_term:
142 | # Shift out the oldest loss term.
143 | self._loss_history[t, :-1] = self._loss_history[t, 1:]
144 | self._loss_history[t, -1] = loss
145 | else:
146 | self._loss_history[t, self._loss_counts[t]] = loss
147 | self._loss_counts[t] += 1
148 |
149 | def _warmed_up(self):
150 | return (self._loss_counts == self.history_per_term).all()
151 |
--------------------------------------------------------------------------------
/masked_diffusion/train_util.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import functools
3 | import os
4 |
5 | import blobfile as bf
6 | import torch as th
7 | import torch.distributed as dist
8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP
9 | from torch.optim import AdamW
10 |
11 | from . import dist_util, logger
12 | from .fp16_util import MixedPrecisionTrainer
13 | from .nn import update_ema
14 | from .resample import LossAwareSampler, UniformSampler
15 | from diffusers.models import AutoencoderKL
16 | from adan import Adan
17 | from torch.distributed.optim import ZeroRedundancyOptimizer
18 | # For ImageNet experiments, this was a good default value.
19 | # We found that the lg_loss_scale quickly climbed to
20 | # 20-21 within the first ~1K steps of training.
21 | INITIAL_LOG_LOSS_SCALE = 20.0
22 |
23 |
24 | class TrainLoop:
25 | def __init__(
26 | self,
27 | *,
28 | model,
29 | diffusion,
30 | data,
31 | batch_size,
32 | microbatch,
33 | lr,
34 | ema_rate,
35 | log_interval,
36 | save_interval,
37 | resume_checkpoint,
38 | use_fp16=False,
39 | fp16_scale_growth=1e-3,
40 | schedule_sampler=None,
41 | weight_decay=0.0,
42 | lr_anneal_steps=0,
43 | scale_factor=0.18215, # scale_factor follows DiT and stable diffusion.
44 | opt_type='adan',
45 | use_zero=True,
46 | ):
47 | self.model = model
48 | self.diffusion = diffusion
49 | self.data = data
50 | self.batch_size = batch_size
51 | self.microbatch = microbatch if microbatch > 0 else batch_size
52 | self.lr = lr
53 | self.ema_rate = (
54 | [ema_rate]
55 | if isinstance(ema_rate, float)
56 | else [float(x) for x in ema_rate.split(",")]
57 | )
58 | self.log_interval = log_interval
59 | self.save_interval = save_interval
60 | self.resume_checkpoint = resume_checkpoint
61 | self.use_fp16 = use_fp16
62 | self.fp16_scale_growth = fp16_scale_growth
63 | self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
64 | self.weight_decay = weight_decay
65 | self.lr_anneal_steps = lr_anneal_steps
66 | self.scale_factor = scale_factor
67 |
68 | self.step = 0
69 | self.resume_step = 0
70 | self.global_batch = self.batch_size * dist.get_world_size()
71 |
72 | self.sync_cuda = th.cuda.is_available()
73 |
74 | self._load_and_sync_parameters()
75 | self.mp_trainer = MixedPrecisionTrainer(
76 | model=self.model,
77 | use_fp16=self.use_fp16,
78 | fp16_scale_growth=fp16_scale_growth,
79 | )
80 |
81 | if opt_type=='adamw':
82 | if use_zero:
83 | self.opt = ZeroRedundancyOptimizer(
84 | self.mp_trainer.master_params,
85 | optimizer_class=Adam,
86 | lr=self.lr,
87 | weight_decay=self.weight_decay
88 | )
89 | else:
90 | self.opt = AdamW(
91 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
92 | )
93 | elif opt_type=='adan':
94 | if use_zero:
95 | self.opt = ZeroRedundancyOptimizer(
96 | self.mp_trainer.master_params,
97 | optimizer_class=Adan,
98 | lr=self.lr,
99 | weight_decay=self.weight_decay,
100 | max_grad_norm=1, fused=True
101 | )
102 | else:
103 | self.opt = Adan(
104 | self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay, max_grad_norm=1, fused=True)
105 |
106 | if self.resume_step:
107 | self._load_optimizer_state()
108 | # Model was resumed, either due to a restart or a checkpoint
109 | # being specified at the command line.
110 | self.ema_params = [
111 | self._load_ema_parameters(rate) for rate in self.ema_rate
112 | ]
113 | else:
114 | self.ema_params = [
115 | copy.deepcopy(self.mp_trainer.master_params)
116 | for _ in range(len(self.ema_rate))
117 | ]
118 |
119 | if th.cuda.is_available():
120 | self.use_ddp = True
121 | self.ddp_model = DDP(
122 | self.model,
123 | device_ids=[dist_util.dev()],
124 | output_device=dist_util.dev(),
125 | broadcast_buffers=False,
126 | bucket_cap_mb=128,
127 | find_unused_parameters=False,
128 | )
129 | else:
130 | if dist.get_world_size() > 1:
131 | logger.warn(
132 | "Distributed training requires CUDA. "
133 | "Gradients will not be synchronized properly!"
134 | )
135 | self.use_ddp = False
136 | self.ddp_model = self.model
137 | self.instantiate_first_stage()
138 |
139 |
140 | def instantiate_first_stage(self):
141 | model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dist_util.dev())
142 | model = th.compile(model)
143 | self.first_stage_model = model.eval()
144 | self.first_stage_model.train = False
145 | for param in self.first_stage_model.parameters():
146 | param.requires_grad = False
147 |
148 | # https://github.com/huggingface/diffusers/blob/29b2c93c9005c87f8f04b1f0835babbcea736204/src/diffusers/models/autoencoder_kl.py
149 | @th.no_grad()
150 | def get_first_stage_encoding(self, x):
151 | encoder_posterior = self.first_stage_model.encode(x, return_dict=True)[0]
152 |
153 | z = encoder_posterior.sample()
154 | return z.to(dist_util.dev()) * self.scale_factor
155 |
156 | def _load_and_sync_parameters(self):
157 | resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
158 |
159 | if resume_checkpoint:
160 | self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
161 | if dist.get_rank() == 0:
162 | logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
163 | self.model.load_state_dict(
164 | dist_util.load_state_dict(
165 | resume_checkpoint, map_location=dist_util.dev()
166 | )
167 | )
168 |
169 | dist_util.sync_params(self.model.parameters())
170 |
171 | def _load_ema_parameters(self, rate):
172 | ema_params = copy.deepcopy(self.mp_trainer.master_params)
173 |
174 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
175 | ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
176 | if ema_checkpoint:
177 | if dist.get_rank() == 0:
178 | logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
179 | state_dict = dist_util.load_state_dict(
180 | ema_checkpoint, map_location=dist_util.dev()
181 | )
182 | ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
183 |
184 | dist_util.sync_params(ema_params)
185 | return ema_params
186 |
187 | def _load_optimizer_state(self):
188 | main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
189 | opt_checkpoint = bf.join(
190 | bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
191 | )
192 | if bf.exists(opt_checkpoint):
193 | logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
194 | state_dict = dist_util.load_state_dict(
195 | opt_checkpoint, map_location=dist_util.dev()
196 | )
197 | self.opt.load_state_dict(state_dict)
198 |
199 | def run_loop(self):
200 | while (
201 | not self.lr_anneal_steps
202 | or self.step + self.resume_step < self.lr_anneal_steps
203 | ):
204 | batch, cond = next(self.data)
205 | self.run_step(batch, cond)
206 | if self.step % self.log_interval == 0:
207 | logger.dumpkvs()
208 | if self.step % self.save_interval == 0:
209 | self.opt.consolidate_state_dict()
210 | self.save()
211 | # Run for a finite amount of time in integration tests.
212 | if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
213 | return
214 | self.step += 1
215 | # Save the last checkpoint if it wasn't already saved.
216 | if (self.step - 1) % self.save_interval != 0:
217 | self.save()
218 |
219 | def run_step(self, batch, cond):
220 | self.forward_backward(batch, cond)
221 | took_step = self.mp_trainer.optimize(self.opt)
222 | if took_step:
223 | self._update_ema()
224 | self._anneal_lr()
225 | self.log_step()
226 |
227 | def forward_backward(self, batch, cond):
228 | self.mp_trainer.zero_grad()
229 | for i in range(0, batch.shape[0], self.microbatch):
230 |
231 | micro = batch[i : i + self.microbatch].to(dist_util.dev())
232 | micro = self.get_first_stage_encoding(micro).detach()
233 | micro_cond = {
234 | k: v[i : i + self.microbatch].to(dist_util.dev())
235 | for k, v in cond.items()
236 | }
237 |
238 | last_batch = (i + self.microbatch) >= batch.shape[0]
239 | t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
240 |
241 | compute_losses = functools.partial(
242 | self.diffusion.training_losses,
243 | self.ddp_model,
244 | micro,
245 | t,
246 | model_kwargs=micro_cond,
247 | )
248 | micro_cond_mask = micro_cond.copy()
249 | micro_cond_mask['enable_mask']=True
250 | compute_losses_mask = functools.partial(
251 | self.diffusion.training_losses,
252 | self.ddp_model,
253 | micro,
254 | t,
255 | model_kwargs=micro_cond_mask,
256 | )
257 |
258 | if last_batch or not self.use_ddp:
259 | losses = compute_losses()
260 | losses_mask = compute_losses_mask()
261 | else:
262 | with self.ddp_model.no_sync():
263 | losses = compute_losses()
264 | losses_mask = compute_losses_mask()
265 |
266 | if isinstance(self.schedule_sampler, LossAwareSampler):
267 | self.schedule_sampler.update_with_local_losses(
268 | t, losses["loss"].detach() + losses_mask["loss"].detach()
269 | )
270 |
271 | loss = (losses["loss"] * weights).mean() + (losses_mask["loss"] * weights).mean()
272 | log_loss_dict(
273 | self.diffusion, t, {k: v * weights for k, v in losses.items()}
274 | )
275 | log_loss_dict(
276 | self.diffusion, t, {'m_'+k: v * weights for k, v in losses_mask.items()}
277 | )
278 | self.mp_trainer.backward(loss)
279 |
280 | def _update_ema(self):
281 | for rate, params in zip(self.ema_rate, self.ema_params):
282 | update_ema(params, self.mp_trainer.master_params, rate=rate)
283 |
284 | def _anneal_lr(self):
285 | if not self.lr_anneal_steps:
286 | return
287 | frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
288 | lr = self.lr * (1 - frac_done)
289 | for param_group in self.opt.param_groups:
290 | param_group["lr"] = lr
291 |
292 | def log_step(self):
293 | logger.logkv("step", self.step + self.resume_step)
294 | logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
295 |
296 | def save(self):
297 | def save_checkpoint(rate, params):
298 | state_dict = self.mp_trainer.master_params_to_state_dict(params)
299 | if dist.get_rank() == 0:
300 | logger.log(f"saving model {rate}...")
301 | if not rate:
302 | filename = f"model{(self.step+self.resume_step):06d}.pt"
303 | else:
304 | filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
305 | with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
306 | th.save(state_dict, f)
307 |
308 | save_checkpoint(0, self.mp_trainer.master_params)
309 | for rate, params in zip(self.ema_rate, self.ema_params):
310 | save_checkpoint(rate, params)
311 |
312 | if dist.get_rank() == 0:
313 | with bf.BlobFile(
314 | bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
315 | "wb",
316 | ) as f:
317 | th.save(self.opt.state_dict(), f)
318 |
319 | dist.barrier()
320 |
321 |
322 | def parse_resume_step_from_filename(filename):
323 | """
324 | Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
325 | checkpoint's number of steps.
326 | """
327 | split = filename.split("model")
328 | if len(split) < 2:
329 | return 0
330 | split1 = split[-1].split(".")[0]
331 | try:
332 | return int(split1)
333 | except ValueError:
334 | return 0
335 |
336 |
337 | def get_blob_logdir():
338 | # You can change this to be a separate path to save checkpoints to
339 | # a blobstore or some external drive.
340 | return logger.get_dir()
341 |
342 |
343 | def find_resume_checkpoint():
344 | # On your infrastructure, you may want to override this to automatically
345 | # discover the latest checkpoint on your blob storage, etc.
346 | return None
347 |
348 |
349 | def find_ema_checkpoint(main_checkpoint, step, rate):
350 | if main_checkpoint is None:
351 | return None
352 | filename = f"ema_{rate}_{(step):06d}.pt"
353 | path = bf.join(bf.dirname(main_checkpoint), filename)
354 | if bf.exists(path):
355 | return path
356 | return None
357 |
358 |
359 | def log_loss_dict(diffusion, ts, losses):
360 | for key, values in losses.items():
361 | logger.logkv_mean(key, values.mean().item())
362 | # Log the quantiles (four quartiles, in particular).
363 | for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
364 | quartile = int(4 * sub_t / diffusion.num_timesteps)
365 | logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
366 |
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | pip3 install torch==2.0 torchvision torchaudio
2 | pip install -e .
3 | python -m pip install git+https://github.com/sail-sg/Adan.git
4 | export OPENAI_LOGDIR=output_mdtv2_s2
5 | NUM_GPUS=8
6 |
7 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 6 --model MDTv2_S_2"
8 | DIFFUSION_FLAGS="--diffusion_steps 1000"
9 | TRAIN_FLAGS="--batch_size 32 --lr 5e-4"
10 | DATA_PATH=/dataset/imagenet-raw/train
11 |
12 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
13 |
--------------------------------------------------------------------------------
/run_ddp_master.sh:
--------------------------------------------------------------------------------
1 | pip3 install torch==2.0 torchvision torchaudio
2 | python -m pip install git+https://github.com/sail-sg/Adan.git
3 | pip install -e .
4 | export OPENAI_LOGDIR=output_mdtv2_xl2
5 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
6 | DIFFUSION_FLAGS="--diffusion_steps 1000"
7 | TRAIN_FLAGS="--batch_size 8"
8 | DATA_PATH=/dataset/imagenet-raw/train
9 | NUM_NODE=4
10 | GPU_PRE_NODE=8
11 |
12 | python -m torch.distributed.launch --master_addr=$(hostname) --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PRE_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
13 |
--------------------------------------------------------------------------------
/run_ddp_worker.sh:
--------------------------------------------------------------------------------
1 | pip3 install torch==2.0 torchvision torchaudio
2 | python -m pip install git+https://github.com/sail-sg/Adan.git
3 | pip install -e .
4 | export OPENAI_LOGDIR=output_mdtv2_xl2
5 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
6 | DIFFUSION_FLAGS="--diffusion_steps 1000"
7 | TRAIN_FLAGS="--batch_size 8"
8 | DATA_PATH=/dataset/imagenet-raw/train
9 | NUM_NODE=4
10 | GPU_PRE_NODE=8
11 |
12 | python -m torch.distributed.launch --master_addr=$MASTER_ADDR --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PRE_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
13 |
--------------------------------------------------------------------------------
/run_sample.sh:
--------------------------------------------------------------------------------
1 |
2 | pip3 install torch==2.0 torchvision torchaudio
3 | pip install -e .
4 | # pip uninstall -y timm
5 | # pip install mpi4py timm diffusers
6 | MODEL_PATH=output_mdt_xl2/mdt_xl2_v2_ckpt.pt
7 | export OPENAI_LOGDIR=output_mdt_xl2_eval
8 | NUM_GPUS=8
9 |
10 | echo 'CFG Class-conditional sampling:'
11 | MODEL_FLAGS="--image_size 256 --model MDT_XL_2 --decode_layer 4"
12 | DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000 --cfg_cond True"
13 | echo $MODEL_FLAGS
14 | echo $DIFFUSION_FLAGS
15 | echo $MODEL_PATH
16 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_sample.py --model_path $MODEL_PATH $MODEL_FLAGS $DIFFUSION_FLAGS
17 | echo $MODEL_FLAGS
18 | echo $DIFFUSION_FLAGS
19 | echo $MODEL_PATH
20 | python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz
21 |
22 | echo 'Class-conditional sampling:'
23 | MODEL_FLAGS="--image_size 256 --model MDT_XL_2 --decode_layer 4"
24 | DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000"
25 | echo $MODEL_FLAGS
26 | echo $DIFFUSION_FLAGS
27 | echo $MODEL_PATH
28 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_sample.py --model_path $MODEL_PATH $MODEL_FLAGS $DIFFUSION_FLAGS
29 | echo $MODEL_FLAGS
30 | echo $DIFFUSION_FLAGS
31 | echo $MODEL_PATH
32 | python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz
33 |
34 |
--------------------------------------------------------------------------------
/scripts/image_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of image samples from a model and save them as a large
3 | numpy array. This can be used to produce samples for FID evaluation.
4 | """
5 |
6 | import argparse
7 | import os
8 |
9 | import numpy as np
10 | import torch as th
11 | import torch.distributed as dist
12 | from masked_diffusion.script_util import (
13 | NUM_CLASSES,
14 | add_dict_to_argparser,
15 | args_to_dict,
16 | )
17 |
18 | from masked_diffusion import (
19 | create_diffusion,
20 | model_and_diffusion_defaults,
21 | diffusion_defaults,
22 | dist_util,
23 | logger,
24 | )
25 |
26 | import masked_diffusion.models as models_mdt
27 | from diffusers.models import AutoencoderKL
28 |
29 | def main():
30 | th.backends.cuda.matmul.allow_tf32 = True
31 | args = create_argparser().parse_args()
32 |
33 | dist_util.setup_dist()
34 | logger.configure()
35 |
36 | logger.log("creating model and diffusion...")
37 |
38 | configs = args_to_dict(args, model_and_diffusion_defaults().keys())
39 | print(configs)
40 | image_size = configs['image_size']
41 | latent_size = image_size // 8
42 | model = models_mdt.__dict__[args.model](input_size=latent_size, decode_layer=args.decode_layer)
43 | msg = model.load_state_dict(
44 | dist_util.load_state_dict(args.model_path, map_location="cpu")
45 | )
46 | print(msg)
47 | config_diffusion = args_to_dict(args, diffusion_defaults().keys())
48 | config_diffusion['timestep_respacing']= str(args.num_sampling_steps)
49 | print(config_diffusion)
50 | diffusion = create_diffusion(**config_diffusion)
51 | model.to(dist_util.dev())
52 | model = th.compile(model)
53 | if args.use_fp16:
54 | model.convert_to_fp16()
55 | model.eval()
56 | th.set_grad_enabled(False)
57 |
58 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-"+str(args.vae_decoder)).to(dist_util.dev())
59 |
60 | logger.log("sampling...")
61 | all_images = []
62 | all_labels = []
63 | while len(all_images) * args.batch_size < args.num_samples:
64 | model_kwargs = {}
65 | if args.cfg_cond:
66 | classes = th.randint(
67 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
68 | )
69 | z = th.randn(args.batch_size, 4, latent_size, latent_size, device=dist_util.dev())
70 | # Setup classifier-free guidance:
71 | z = th.cat([z, z], 0)
72 | classes_null = th.tensor([NUM_CLASSES] * args.batch_size, device=dist_util.dev())
73 | classes_all = th.cat([classes, classes_null], 0)
74 | model_kwargs["y"] = classes_all
75 | model_kwargs["cfg_scale"] = args.cfg_scale
76 | model_kwargs["diffusion_steps"] = config_diffusion['diffusion_steps']
77 | model_kwargs["scale_pow"] = args.scale_pow
78 | else:
79 | z = th.randn(args.batch_size, 4, latent_size, latent_size, device=dist_util.dev())
80 | classes = th.randint(
81 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev()
82 | )
83 | model_kwargs["y"] = classes
84 |
85 |
86 | sample_fn = (
87 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop
88 | )
89 | sample = sample_fn(
90 | model.forward_with_cfg,
91 | z.shape,
92 | z,
93 | clip_denoised=args.clip_denoised,
94 | progress=True,
95 | model_kwargs=model_kwargs,
96 | device=dist_util.dev()
97 | )
98 | if args.cfg_cond:
99 | sample, _ = sample.chunk(2, dim=0) # Remove null class samples
100 | # latent to image
101 | sample = vae.decode(sample / 0.18215).sample
102 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) # clip in range -1,1
103 | sample = sample.permute(0, 2, 3, 1)
104 | sample = sample.contiguous()
105 |
106 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())]
107 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL
108 | all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
109 | if args.class_cond:
110 | gathered_labels = [
111 | th.zeros_like(classes) for _ in range(dist.get_world_size())
112 | ]
113 | dist.all_gather(gathered_labels, classes)
114 | all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
115 | # logger.log(f"created {len(all_images) * args.batch_size} samples")
116 |
117 | arr = np.concatenate(all_images, axis=0)
118 | arr = arr[: args.num_samples]
119 | if args.class_cond:
120 | label_arr = np.concatenate(all_labels, axis=0)
121 | label_arr = label_arr[: args.num_samples]
122 | if dist.get_rank() == 0:
123 | shape_str = "x".join([str(x) for x in arr.shape])
124 | out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
125 | logger.log(f"saving to {out_path}")
126 | if args.class_cond:
127 | np.savez(out_path, arr, label_arr)
128 | else:
129 | np.savez(out_path, arr)
130 |
131 | dist.barrier()
132 | logger.log("sampling complete")
133 |
134 |
135 | def create_argparser():
136 | defaults = dict(
137 | num_sampling_steps=250,
138 | clip_denoised=False,
139 | num_samples=5000,
140 | batch_size=16,
141 | use_ddim=False,
142 | model_path="",
143 | model="MDT_S_2",
144 | class_cond=True,
145 | cfg_scale=3.8,
146 | decode_layer=None,
147 | cfg_cond=False,
148 | )
149 | defaults.update(model_and_diffusion_defaults())
150 | parser = argparse.ArgumentParser()
151 | parser.add_argument('--scale_pow', default=4, type=float)
152 | parser.add_argument('--vae_decoder', type=str, default='ema') # ema or mse
153 | parser.add_argument('--world_size', default=1, type=int,
154 | help='number of distributed processes')
155 | parser.add_argument('--local_rank', default=-1, type=int)
156 | parser.add_argument('--local-rank', default=-1, type=int)
157 | parser.add_argument('--dist_on_itp', action='store_true')
158 | parser.add_argument('--dist_url', default='env://',
159 | help='url used to set up distributed training')
160 | parser.add_argument(
161 | "--rank", default=0, type=int, help="""rank for distrbuted training."""
162 | )
163 | add_dict_to_argparser(parser, defaults)
164 | return parser
165 |
166 |
167 | if __name__ == "__main__":
168 | main()
169 |
--------------------------------------------------------------------------------
/scripts/image_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on images.
3 | """
4 |
5 | import argparse
6 |
7 | from masked_diffusion import dist_util, logger
8 | from masked_diffusion.image_datasets import load_data
9 | from masked_diffusion.resample import create_named_schedule_sampler
10 | from masked_diffusion.script_util import (
11 | args_to_dict,
12 | add_dict_to_argparser,
13 | )
14 | from masked_diffusion.train_util import TrainLoop
15 | from masked_diffusion import create_diffusion, model_and_diffusion_defaults, diffusion_defaults
16 | import masked_diffusion.models as models_mdt
17 |
18 | def main():
19 | args = create_argparser().parse_args()
20 |
21 | dist_util.setup_dist_multinode(args)
22 | logger.configure()
23 |
24 | logger.log("creating model and diffusion...")
25 | configs = args_to_dict(args, model_and_diffusion_defaults().keys())
26 | print(configs)
27 | print(args)
28 | image_size = configs['image_size']
29 | latent_size = image_size // 8
30 | model = models_mdt.__dict__[args.model](input_size=latent_size, mask_ratio=args.mask_ratio, decode_layer=args.decode_layer)
31 | print(model)
32 | diffusion = create_diffusion(**args_to_dict(args, diffusion_defaults().keys()))
33 | model.to(dist_util.dev())
34 |
35 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)
36 |
37 | logger.log("creating data loader...")
38 | data = load_data(
39 | data_dir=args.data_dir,
40 | batch_size=args.batch_size,
41 | image_size=args.image_size,
42 | class_cond=args.class_cond,
43 | )
44 |
45 | logger.log("training...")
46 | TrainLoop(
47 | model=model,
48 | diffusion=diffusion,
49 | data=data,
50 | batch_size=args.batch_size,
51 | microbatch=args.microbatch,
52 | lr=args.lr,
53 | ema_rate=args.ema_rate,
54 | log_interval=args.log_interval,
55 | save_interval=args.save_interval,
56 | resume_checkpoint=args.resume_checkpoint,
57 | use_fp16=args.use_fp16,
58 | fp16_scale_growth=args.fp16_scale_growth,
59 | schedule_sampler=schedule_sampler,
60 | weight_decay=args.weight_decay,
61 | lr_anneal_steps=args.lr_anneal_steps,
62 | ).run_loop()
63 |
64 |
65 | def create_argparser():
66 | defaults = dict(
67 | data_dir="",
68 | schedule_sampler="uniform",
69 | lr=3e-4,
70 | weight_decay=0.0,
71 | lr_anneal_steps=0,
72 | batch_size=1,
73 | microbatch=-1, # -1 disables microbatches
74 | ema_rate="0.9999", # comma-separated list of EMA values
75 | log_interval=500,
76 | save_interval=10000,
77 | resume_checkpoint="",
78 | use_fp16=False,
79 | fp16_scale_growth=1e-3,
80 | model="MDTv2_S_2",
81 | mask_ratio=None,
82 | decode_layer=4,
83 | )
84 | defaults.update(model_and_diffusion_defaults())
85 | parser = argparse.ArgumentParser()
86 | parser.add_argument('--world_size', default=1, type=int,
87 | help='number of distributed processes')
88 | parser.add_argument('--local_rank', default=-1, type=int)
89 | parser.add_argument('--local-rank', default=-1, type=int)
90 | parser.add_argument('--dist_on_itp', action='store_true')
91 | parser.add_argument('--dist_url', default='env://',
92 | help='url used to set up distributed training')
93 | parser.add_argument(
94 | "--rank", default=0, type=int, help="""rank for distrbuted training."""
95 | )
96 | add_dict_to_argparser(parser, defaults)
97 | return parser
98 |
99 |
100 | if __name__ == "__main__":
101 | main()
102 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 |
3 | setup(
4 | name="masked-diffusion",
5 | py_modules=["masked_diffusion"],
6 | install_requires=["blobfile>=1.0.5", "torch", "tqdm"],
7 | )
8 |
--------------------------------------------------------------------------------