├── LICENSE
├── README.md
├── datasets
│   ├── README.md
│   └── lsun_bedroom.py
├── evaluations
│   ├── README.md
│   ├── evaluator.py
│   └── requirements.txt
├── figures
│   ├── .DS_Store
│   └── vis.jpg
├── infer_mdt.py
├── masked_diffusion
│   ├── __init__.py
│   ├── diffusion_utils.py
│   ├── dist_util.py
│   ├── fp16_util.py
│   ├── gaussian_diffusion.py
│   ├── image_datasets.py
│   ├── logger.py
│   ├── models.py
│   ├── nn.py
│   ├── resample.py
│   ├── respace.py
│   ├── script_util.py
│   ├── timestep_sampler.py
│   └── train_util.py
├── run.sh
├── run_ddp_master.sh
├── run_ddp_worker.sh
├── run_sample.sh
├── scripts
│   ├── image_sample.py
│   └── image_train.py
└── setup.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner.
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Masked Diffusion Transformer V2
2 |
3 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/masked-diffusion-transformer-is-a-strong/image-generation-on-imagenet-256x256)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?p=masked-diffusion-transformer-is-a-strong)
4 | [![HuggingFace space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/shgao/MDT)
5 |
6 | The official codebase for [Masked Diffusion Transformer is a Strong Image Synthesizer](https://arxiv.org/abs/2303.14389).
7 |
8 | ## MDTv2: Faster Convergence & Stronger Performance
9 | **MDTv2 achieves superior image synthesis performance, e.g., a new SOTA FID score of 1.58 on the ImageNet dataset, and has a more than 10× faster learning speed than the previous SOTA, DiT.**
10 |
11 | MDTv2 also demonstrates a 5× acceleration compared to the original MDT.
12 |
13 | [MDTv1 code](https://github.com/sail-sg/MDT/tree/mdtv1)
14 | ## Introduction
15 |
16 | Despite their success in image synthesis, we observe that diffusion probabilistic models (DPMs) often lack the contextual reasoning ability to learn the relations among object parts in an image, leading to a slow learning process. To solve this issue, we propose a Masked Diffusion Transformer (MDT) that introduces a mask latent modeling scheme to explicitly enhance the DPMs' ability to learn contextual relations among object semantic parts in an image.
17 |
18 | During training, MDT operates in the latent space and masks certain tokens. An asymmetric diffusion transformer is then designed to predict the masked tokens from the unmasked ones while maintaining the diffusion generation process. Our MDT can thus reconstruct the full information of an image from its incomplete contextual input, enabling it to learn the associated relations among image tokens. We further improve MDT with a more efficient macro network structure and training strategy, named MDTv2.
19 |
20 | Experimental results show that MDTv2 achieves superior image synthesis performance, e.g., **a new SOTA FID score of 1.58 on the ImageNet dataset, with a more than 10× faster learning speed than the previous SOTA DiT**.
21 |
22 | ![Visualization of MDT samples](figures/vis.jpg)
23 |
24 | # Performance
25 |
26 | | Model | Dataset | Resolution | FID-50K | Inception Score |
27 | |---------|----------|-----------|---------|--------|
28 | | MDT-XL/2 | ImageNet | 256x256 | 1.79 | 283.01 |
29 | | MDTv2-XL/2 | ImageNet | 256x256 | 1.58 | 314.73 |
30 |
31 | [Pretrained model download](https://huggingface.co/shgao/MDT-XL2/tree/main)
32 |
33 | The model is hosted on Hugging Face; you can also download it with:
34 | ```
35 | import os; from huggingface_hub import snapshot_download
36 | models_path = snapshot_download("shgao/MDT-XL2")
37 | ckpt_model_path = os.path.join(models_path, "mdt_xl2_v1_ckpt.pt")
38 | ```
39 | A Hugging Face demo is available at [DEMO](https://huggingface.co/spaces/shgao/MDT).
40 |
41 | **NEW SOTA on FID.**
42 | # Setup
43 |
44 | Prepare a PyTorch >= 2.0 environment, then download and install this repo:
45 |
46 | ```
47 | git clone https://github.com/sail-sg/MDT
48 | cd MDT
49 | pip install -e .
50 | ```
51 | Install the [Adan optimizer](https://github.com/sail-sg/Adan); Adan is a strong optimizer with a faster convergence speed than AdamW.
[(paper)](https://arxiv.org/abs/2208.06677)
52 | ```
53 | python -m pip install git+https://github.com/sail-sg/Adan.git
54 | ```
55 |
56 | **DATA**
57 | - For standard datasets like ImageNet and CIFAR, please refer to '[dataset](https://github.com/sail-sg/MDT/tree/main/datasets)' for preparation.
58 | - When using a customized dataset, rename each image file to `ClassID_ImgID.jpg`,
59 | as [ADM's dataloader](https://github.com/openai/guided-diffusion) reads the class ID from the file name (see the renaming sketch below).
60 |
61 | # Training
62 |
63 |
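For the `ClassID_ImgID.jpg` convention mentioned above, here is a minimal renaming sketch (my own illustration, not a script from this repo; it assumes a hypothetical `my_dataset/` folder with one sub-directory per class, and class folder names without underscores, since the dataloader takes everything before the first `_` as the class ID):

```python
import os
import shutil

# Assumed input layout:  my_dataset/<class_name>/<image>.jpg
# Produced layout:       my_dataset/<class_name>_<img_id>.jpg
root = "my_dataset"  # hypothetical path; adapt to your data
for class_name in sorted(os.listdir(root)):
    class_dir = os.path.join(root, class_name)
    if not os.path.isdir(class_dir):
        continue
    for img_id, fname in enumerate(sorted(os.listdir(class_dir))):
        ext = os.path.splitext(fname)[1]
        shutil.move(os.path.join(class_dir, fname),
                    os.path.join(root, f"{class_name}_{img_id:07d}{ext}"))
```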
64 | Training on one node (`run.sh`). 65 | 66 | ```shell 67 | export OPENAI_LOGDIR=output_mdtv2_s2 68 | NUM_GPUS=8 69 | 70 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 6 --model MDTv2_S_2" 71 | DIFFUSION_FLAGS="--diffusion_steps 1000" 72 | TRAIN_FLAGS="--batch_size 32" 73 | DATA_PATH=/dataset/imagenet 74 | 75 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS 76 | ``` 77 | 78 |
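Note that `torch.distributed.launch` is deprecated in recent PyTorch releases; with PyTorch >= 2.0, the equivalent `torchrun` invocation below should also work (an untested substitution on my part, not taken from the repo scripts):

```shell
torchrun --nproc_per_node=$NUM_GPUS scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```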
79 | 80 |
81 | Training on multiple nodes (`run_ddp_master.sh` and `run_ddp_worker.sh`).
82 |
83 | ```shell
84 | # On master:
85 | export OPENAI_LOGDIR=output_mdtv2_xl2
86 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
87 | DIFFUSION_FLAGS="--diffusion_steps 1000"
88 | TRAIN_FLAGS="--batch_size 4"
89 | DATA_PATH=/dataset/imagenet
90 | NUM_NODE=8
91 | GPU_PER_NODE=8
92 |
93 | python -m torch.distributed.launch --master_addr=$(hostname) --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PER_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
94 |
95 | # On workers:
96 | export OPENAI_LOGDIR=output_mdtv2_xl2
97 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
98 | DIFFUSION_FLAGS="--diffusion_steps 1000"
99 | TRAIN_FLAGS="--batch_size 4"
100 | DATA_PATH=/dataset/imagenet
101 | NUM_NODE=8
102 | GPU_PER_NODE=8
103 |
104 | python -m torch.distributed.launch --master_addr=$MASTER_ADDR --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PER_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
105 |
106 |
107 | ```
108 |
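Both scripts read `RANK`, `MASTER_PORT`, and (on workers) `MASTER_ADDR` from the environment without defining them, so they must be exported on each node before launching. A hypothetical setup, with illustrative values only:

```shell
export MASTER_ADDR=10.0.0.1   # IP or hostname of the master node (workers only)
export MASTER_PORT=29500      # any free port, identical across all nodes
export RANK=0                 # this node's index, from 0 to NUM_NODE-1
```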
110 | 111 | # Evaluation 112 | 113 | The evaluation code is obtained from [ADM's TensorFlow evaluation suite](https://github.com/openai/guided-diffusion/tree/main/evaluations). 114 | Please follow the instructions in the `evaluations` folder to set up the evaluation environment. 115 | 116 |
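In practice, setting up the evaluation environment mostly amounts to installing the evaluator's pinned dependencies (a sketch; a separate virtual environment is advisable, since the evaluator runs on TensorFlow rather than PyTorch):

```shell
pip install -r evaluations/requirements.txt
```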
117 | Sampling and Evaluation (`run_sample.sh`): 118 | 119 | ```shell 120 | MODEL_PATH=output_mdtv2_xl2/mdt_xl2_v2_ckpt.pt 121 | export OPENAI_LOGDIR=output_mdtv2_xl2_eval 122 | NUM_GPUS=8 123 | 124 | echo 'CFG Class-conditional sampling:' 125 | MODEL_FLAGS="--image_size 256 --model MDTv2_XL_2 --decode_layer 4" 126 | DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000 --cfg_cond True" 127 | echo $MODEL_FLAGS 128 | echo $DIFFUSION_FLAGS 129 | echo $MODEL_PATH 130 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_sample.py --model_path $MODEL_PATH $MODEL_FLAGS $DIFFUSION_FLAGS 131 | echo $MODEL_FLAGS 132 | echo $DIFFUSION_FLAGS 133 | echo $MODEL_PATH 134 | python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz 135 | 136 | echo 'Class-conditional sampling:' 137 | MODEL_FLAGS="--image_size 256 --model MDTv2_XL_2 --decode_layer 4" 138 | DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000" 139 | echo $MODEL_FLAGS 140 | echo $DIFFUSION_FLAGS 141 | echo $MODEL_PATH 142 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_sample.py --model_path $MODEL_PATH $MODEL_FLAGS $DIFFUSION_FLAGS 143 | echo $MODEL_FLAGS 144 | echo $DIFFUSION_FLAGS 145 | echo $MODEL_PATH 146 | python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz 147 | ``` 148 | 149 |
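The file name `samples_50000x256x256x3.npz` encodes the sample count and resolution; assuming the sampling script follows ADM's convention of storing images under the `arr_0` key (which is what `evaluator.py` reads), a quick sanity check might look like:

```python
import numpy as np

batch = np.load("output_mdtv2_xl2_eval/samples_50000x256x256x3.npz")
print(batch["arr_0"].shape, batch["arr_0"].dtype)  # expected: (50000, 256, 256, 3) uint8
```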
150 |
151 | # Visualization
152 |
153 | Run `infer_mdt.py` to generate images.
154 |
155 | # Citation
156 |
157 | ```
158 | @misc{gao2023masked,
159 |     title={Masked Diffusion Transformer is a Strong Image Synthesizer},
160 |     author={Shanghua Gao and Pan Zhou and Ming-Ming Cheng and Shuicheng Yan},
161 |     year={2023},
162 |     eprint={2303.14389},
163 |     archivePrefix={arXiv},
164 |     primaryClass={cs.CV}
165 | }
166 | ```
167 |
168 | # Acknowledgement
169 |
170 | This codebase is built on [DiT](https://github.com/facebookresearch/dit) and [ADM](https://github.com/openai/guided-diffusion). Thanks!
171 |
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | # Downloading datasets
2 |
3 | This directory includes instructions and scripts for downloading ImageNet and LSUN bedrooms for use in this codebase.
4 |
5 | ## Class-conditional ImageNet
6 |
7 | For our class-conditional models, we use the official ILSVRC2012 dataset with manual center cropping and downsampling. To obtain this dataset, navigate to [this page on image-net.org](http://www.image-net.org/challenges/LSVRC/2012/downloads) and sign in (or create an account if you do not already have one). Then click on the link reading "Training images (Task 1 & 2)". This is a 138GB tar file containing 1000 sub-tar files, one per class.
8 |
9 | Once the file is downloaded, extract it and look inside. You should see 1000 `.tar` files. You need to extract each of these, which may be impractical to do by hand on your operating system. To automate the process on a Unix-based system, you can `cd` into the directory and run this short shell script:
10 |
11 | ```
12 | for file in *.tar; do tar xf "$file"; rm "$file"; done
13 | ```
14 |
15 | This will extract and remove each tar file in turn.
16 |
17 | Once all of the images have been extracted, the resulting directory should be usable as a data directory (the `--data_dir` argument for the training script). The filenames should all start with a WNID (class ID) followed by an underscore, like `n01440764_2708.JPEG`. Conveniently (but not by accident), this is how the automated data-loader expects to discover class labels.
18 |
19 | ## LSUN bedroom
20 |
21 | To download and pre-process LSUN bedroom, clone [fyu/lsun](https://github.com/fyu/lsun) on GitHub and run their download script `python3 download.py bedroom`. The result will be an "lmdb" database named like `bedroom_train_lmdb`. You can pass this to our [lsun_bedroom.py](lsun_bedroom.py) script like so:
22 |
23 | ```
24 | python lsun_bedroom.py bedroom_train_lmdb lsun_train_output_dir
25 | ```
26 |
27 | This creates a directory called `lsun_train_output_dir`. This directory can be passed to the training scripts via the `--data_dir` argument.
28 |
--------------------------------------------------------------------------------
/datasets/lsun_bedroom.py:
--------------------------------------------------------------------------------
1 | """
2 | Convert an LSUN lmdb database into a directory of images.
3 | """ 4 | 5 | import argparse 6 | import io 7 | import os 8 | 9 | from PIL import Image 10 | import lmdb 11 | import numpy as np 12 | 13 | 14 | def read_images(lmdb_path, image_size): 15 | env = lmdb.open(lmdb_path, map_size=1099511627776, max_readers=100, readonly=True) 16 | with env.begin(write=False) as transaction: 17 | cursor = transaction.cursor() 18 | for _, webp_data in cursor: 19 | img = Image.open(io.BytesIO(webp_data)) 20 | width, height = img.size 21 | scale = image_size / min(width, height) 22 | img = img.resize( 23 | (int(round(scale * width)), int(round(scale * height))), 24 | resample=Image.BOX, 25 | ) 26 | arr = np.array(img) 27 | h, w, _ = arr.shape 28 | h_off = (h - image_size) // 2 29 | w_off = (w - image_size) // 2 30 | arr = arr[h_off : h_off + image_size, w_off : w_off + image_size] 31 | yield arr 32 | 33 | 34 | def dump_images(out_dir, images, prefix): 35 | if not os.path.exists(out_dir): 36 | os.mkdir(out_dir) 37 | for i, img in enumerate(images): 38 | Image.fromarray(img).save(os.path.join(out_dir, f"{prefix}_{i:07d}.png")) 39 | 40 | 41 | def main(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--image-size", help="new image size", type=int, default=256) 44 | parser.add_argument("--prefix", help="class name", type=str, default="bedroom") 45 | parser.add_argument("lmdb_path", help="path to an LSUN lmdb database") 46 | parser.add_argument("out_dir", help="path to output directory") 47 | args = parser.parse_args() 48 | 49 | images = read_images(args.lmdb_path, args.image_size) 50 | dump_images(args.out_dir, images, args.prefix) 51 | 52 | 53 | if __name__ == "__main__": 54 | main() 55 | -------------------------------------------------------------------------------- /evaluations/README.md: -------------------------------------------------------------------------------- 1 | # Evaluations 2 | 3 | To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files. 4 | 5 | # Download batches 6 | 7 | We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format. 8 | 9 | Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall. 
10 | 11 | Here are links to download all of the sample and reference batches: 12 | 13 | * LSUN 14 | * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz) 15 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz) 16 | * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz) 17 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz) 18 | * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz) 19 | * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz) 20 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz) 21 | * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz) 22 | * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz) 23 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz) 24 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz) 25 | 26 | * ImageNet 27 | * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz) 28 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz) 29 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz) 30 | * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz) 31 | * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz) 32 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz) 33 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz) 34 | * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz) 35 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz) 36 | * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) 37 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz) 38 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz) 39 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz) 40 | * [ADM-G + 
ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz)
41 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz)
42 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz)
43 | * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz)
44 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz)
45 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz)
46 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz)
47 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz)
48 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz)
49 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz)
50 |
51 | # Run evaluations
52 |
53 | First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the reference batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`.
54 |
55 | Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB.
56 |
57 | The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging:
58 |
59 | ```
60 | $ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz
61 | ...
62 | computing reference batch activations...
63 | computing/reading reference batch statistics...
64 | computing sample batch activations...
65 | computing/reading sample batch statistics...
66 | Computing evaluations...
67 | Inception Score: 215.8370361328125 68 | FID: 3.9425574129223264 69 | sFID: 6.140433703346162 70 | Precision: 0.8265 71 | Recall: 0.5309 72 | ``` 73 | -------------------------------------------------------------------------------- /evaluations/evaluator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import io 3 | import os 4 | import random 5 | import warnings 6 | import zipfile 7 | from abc import ABC, abstractmethod 8 | from contextlib import contextmanager 9 | from functools import partial 10 | from multiprocessing import cpu_count 11 | from multiprocessing.pool import ThreadPool 12 | from typing import Iterable, Optional, Tuple 13 | 14 | import numpy as np 15 | import requests 16 | import tensorflow.compat.v1 as tf 17 | from scipy import linalg 18 | from tqdm.auto import tqdm 19 | 20 | INCEPTION_V3_URL = "https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/classify_image_graph_def.pb" 21 | INCEPTION_V3_PATH = "classify_image_graph_def.pb" 22 | 23 | FID_POOL_NAME = "pool_3:0" 24 | FID_SPATIAL_NAME = "mixed_6/conv:0" 25 | 26 | 27 | def main(): 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("ref_batch", help="path to reference batch npz file") 30 | parser.add_argument("sample_batch", help="path to sample batch npz file") 31 | args = parser.parse_args() 32 | 33 | config = tf.ConfigProto( 34 | allow_soft_placement=True # allows DecodeJpeg to run on CPU in Inception graph 35 | ) 36 | config.gpu_options.allow_growth = True 37 | evaluator = Evaluator(tf.Session(config=config)) 38 | 39 | print("warming up TensorFlow...") 40 | # This will cause TF to print a bunch of verbose stuff now rather 41 | # than after the next print(), to help prevent confusion. 42 | evaluator.warmup() 43 | 44 | print("computing reference batch activations...") 45 | ref_acts = evaluator.read_activations(args.ref_batch) 46 | print("computing/reading reference batch statistics...") 47 | ref_stats, ref_stats_spatial = evaluator.read_statistics(args.ref_batch, ref_acts) 48 | 49 | print("computing sample batch activations...") 50 | sample_acts = evaluator.read_activations(args.sample_batch) 51 | print("computing/reading sample batch statistics...") 52 | sample_stats, sample_stats_spatial = evaluator.read_statistics(args.sample_batch, sample_acts) 53 | 54 | print("Computing evaluations...") 55 | print("Inception Score:", evaluator.compute_inception_score(sample_acts[0])) 56 | print("FID:", sample_stats.frechet_distance(ref_stats)) 57 | print("sFID:", sample_stats_spatial.frechet_distance(ref_stats_spatial)) 58 | prec, recall = evaluator.compute_prec_recall(ref_acts[0], sample_acts[0]) 59 | print("Precision:", prec) 60 | print("Recall:", recall) 61 | 62 | 63 | class InvalidFIDException(Exception): 64 | pass 65 | 66 | 67 | class FIDStatistics: 68 | def __init__(self, mu: np.ndarray, sigma: np.ndarray): 69 | self.mu = mu 70 | self.sigma = sigma 71 | 72 | def frechet_distance(self, other, eps=1e-6): 73 | """ 74 | Compute the Frechet distance between two sets of statistics. 
75 | """ 76 | # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L132 77 | mu1, sigma1 = self.mu, self.sigma 78 | mu2, sigma2 = other.mu, other.sigma 79 | 80 | mu1 = np.atleast_1d(mu1) 81 | mu2 = np.atleast_1d(mu2) 82 | 83 | sigma1 = np.atleast_2d(sigma1) 84 | sigma2 = np.atleast_2d(sigma2) 85 | 86 | assert ( 87 | mu1.shape == mu2.shape 88 | ), f"Training and test mean vectors have different lengths: {mu1.shape}, {mu2.shape}" 89 | assert ( 90 | sigma1.shape == sigma2.shape 91 | ), f"Training and test covariances have different dimensions: {sigma1.shape}, {sigma2.shape}" 92 | 93 | diff = mu1 - mu2 94 | 95 | # product might be almost singular 96 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 97 | if not np.isfinite(covmean).all(): 98 | msg = ( 99 | "fid calculation produces singular product; adding %s to diagonal of cov estimates" 100 | % eps 101 | ) 102 | warnings.warn(msg) 103 | offset = np.eye(sigma1.shape[0]) * eps 104 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 105 | 106 | # numerical error might give slight imaginary component 107 | if np.iscomplexobj(covmean): 108 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 109 | m = np.max(np.abs(covmean.imag)) 110 | raise ValueError("Imaginary component {}".format(m)) 111 | covmean = covmean.real 112 | 113 | tr_covmean = np.trace(covmean) 114 | 115 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 116 | 117 | 118 | class Evaluator: 119 | def __init__( 120 | self, 121 | session, 122 | batch_size=64, 123 | softmax_batch_size=512, 124 | ): 125 | self.sess = session 126 | self.batch_size = batch_size 127 | self.softmax_batch_size = softmax_batch_size 128 | self.manifold_estimator = ManifoldEstimator(session) 129 | with self.sess.graph.as_default(): 130 | self.image_input = tf.placeholder(tf.float32, shape=[None, None, None, 3]) 131 | self.softmax_input = tf.placeholder(tf.float32, shape=[None, 2048]) 132 | self.pool_features, self.spatial_features = _create_feature_graph(self.image_input) 133 | self.softmax = _create_softmax_graph(self.softmax_input) 134 | 135 | def warmup(self): 136 | self.compute_activations(np.zeros([1, 8, 64, 64, 3])) 137 | 138 | def read_activations(self, npz_path: str) -> Tuple[np.ndarray, np.ndarray]: 139 | with open_npz_array(npz_path, "arr_0") as reader: 140 | return self.compute_activations(reader.read_batches(self.batch_size)) 141 | 142 | def compute_activations(self, batches: Iterable[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]: 143 | """ 144 | Compute image features for downstream evals. 145 | 146 | :param batches: a iterator over NHWC numpy arrays in [0, 255]. 147 | :return: a tuple of numpy arrays of shape [N x X], where X is a feature 148 | dimension. The tuple is (pool_3, spatial). 
149 | """ 150 | preds = [] 151 | spatial_preds = [] 152 | for batch in tqdm(batches): 153 | batch = batch.astype(np.float32) 154 | pred, spatial_pred = self.sess.run( 155 | [self.pool_features, self.spatial_features], {self.image_input: batch} 156 | ) 157 | preds.append(pred.reshape([pred.shape[0], -1])) 158 | spatial_preds.append(spatial_pred.reshape([spatial_pred.shape[0], -1])) 159 | return ( 160 | np.concatenate(preds, axis=0), 161 | np.concatenate(spatial_preds, axis=0), 162 | ) 163 | 164 | def read_statistics( 165 | self, npz_path: str, activations: Tuple[np.ndarray, np.ndarray] 166 | ) -> Tuple[FIDStatistics, FIDStatistics]: 167 | obj = np.load(npz_path) 168 | if "mu" in list(obj.keys()): 169 | return FIDStatistics(obj["mu"], obj["sigma"]), FIDStatistics( 170 | obj["mu_s"], obj["sigma_s"] 171 | ) 172 | return tuple(self.compute_statistics(x) for x in activations) 173 | 174 | def compute_statistics(self, activations: np.ndarray) -> FIDStatistics: 175 | mu = np.mean(activations, axis=0) 176 | sigma = np.cov(activations, rowvar=False) 177 | return FIDStatistics(mu, sigma) 178 | 179 | def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float: 180 | softmax_out = [] 181 | for i in range(0, len(activations), self.softmax_batch_size): 182 | acts = activations[i : i + self.softmax_batch_size] 183 | softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts})) 184 | preds = np.concatenate(softmax_out, axis=0) 185 | # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46 186 | scores = [] 187 | for i in range(0, len(preds), split_size): 188 | part = preds[i : i + split_size] 189 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 190 | kl = np.mean(np.sum(kl, 1)) 191 | scores.append(np.exp(kl)) 192 | return float(np.mean(scores)) 193 | 194 | def compute_prec_recall( 195 | self, activations_ref: np.ndarray, activations_sample: np.ndarray 196 | ) -> Tuple[float, float]: 197 | radii_1 = self.manifold_estimator.manifold_radii(activations_ref) 198 | radii_2 = self.manifold_estimator.manifold_radii(activations_sample) 199 | pr = self.manifold_estimator.evaluate_pr( 200 | activations_ref, radii_1, activations_sample, radii_2 201 | ) 202 | return (float(pr[0][0]), float(pr[1][0])) 203 | 204 | 205 | class ManifoldEstimator: 206 | """ 207 | A helper for comparing manifolds of feature vectors. 208 | 209 | Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L57 210 | """ 211 | 212 | def __init__( 213 | self, 214 | session, 215 | row_batch_size=10000, 216 | col_batch_size=10000, 217 | nhood_sizes=(3,), 218 | clamp_to_percentile=None, 219 | eps=1e-5, 220 | ): 221 | """ 222 | Estimate the manifold of given feature vectors. 223 | 224 | :param session: the TensorFlow session. 225 | :param row_batch_size: row batch size to compute pairwise distances 226 | (parameter to trade-off between memory usage and performance). 227 | :param col_batch_size: column batch size to compute pairwise distances. 228 | :param nhood_sizes: number of neighbors used to estimate the manifold. 229 | :param clamp_to_percentile: prune hyperspheres that have radius larger than 230 | the given percentile. 231 | :param eps: small number for numerical stability. 
232 | """ 233 | self.distance_block = DistanceBlock(session) 234 | self.row_batch_size = row_batch_size 235 | self.col_batch_size = col_batch_size 236 | self.nhood_sizes = nhood_sizes 237 | self.num_nhoods = len(nhood_sizes) 238 | self.clamp_to_percentile = clamp_to_percentile 239 | self.eps = eps 240 | 241 | def warmup(self): 242 | feats, radii = ( 243 | np.zeros([1, 2048], dtype=np.float32), 244 | np.zeros([1, 1], dtype=np.float32), 245 | ) 246 | self.evaluate_pr(feats, radii, feats, radii) 247 | 248 | def manifold_radii(self, features: np.ndarray) -> np.ndarray: 249 | num_images = len(features) 250 | 251 | # Estimate manifold of features by calculating distances to k-NN of each sample. 252 | radii = np.zeros([num_images, self.num_nhoods], dtype=np.float32) 253 | distance_batch = np.zeros([self.row_batch_size, num_images], dtype=np.float32) 254 | seq = np.arange(max(self.nhood_sizes) + 1, dtype=np.int32) 255 | 256 | for begin1 in range(0, num_images, self.row_batch_size): 257 | end1 = min(begin1 + self.row_batch_size, num_images) 258 | row_batch = features[begin1:end1] 259 | 260 | for begin2 in range(0, num_images, self.col_batch_size): 261 | end2 = min(begin2 + self.col_batch_size, num_images) 262 | col_batch = features[begin2:end2] 263 | 264 | # Compute distances between batches. 265 | distance_batch[ 266 | 0 : end1 - begin1, begin2:end2 267 | ] = self.distance_block.pairwise_distances(row_batch, col_batch) 268 | 269 | # Find the k-nearest neighbor from the current batch. 270 | radii[begin1:end1, :] = np.concatenate( 271 | [ 272 | x[:, self.nhood_sizes] 273 | for x in _numpy_partition(distance_batch[0 : end1 - begin1, :], seq, axis=1) 274 | ], 275 | axis=0, 276 | ) 277 | 278 | if self.clamp_to_percentile is not None: 279 | max_distances = np.percentile(radii, self.clamp_to_percentile, axis=0) 280 | radii[radii > max_distances] = 0 281 | return radii 282 | 283 | def evaluate(self, features: np.ndarray, radii: np.ndarray, eval_features: np.ndarray): 284 | """ 285 | Evaluate if new feature vectors are at the manifold. 286 | """ 287 | num_eval_images = eval_features.shape[0] 288 | num_ref_images = radii.shape[0] 289 | distance_batch = np.zeros([self.row_batch_size, num_ref_images], dtype=np.float32) 290 | batch_predictions = np.zeros([num_eval_images, self.num_nhoods], dtype=np.int32) 291 | max_realism_score = np.zeros([num_eval_images], dtype=np.float32) 292 | nearest_indices = np.zeros([num_eval_images], dtype=np.int32) 293 | 294 | for begin1 in range(0, num_eval_images, self.row_batch_size): 295 | end1 = min(begin1 + self.row_batch_size, num_eval_images) 296 | feature_batch = eval_features[begin1:end1] 297 | 298 | for begin2 in range(0, num_ref_images, self.col_batch_size): 299 | end2 = min(begin2 + self.col_batch_size, num_ref_images) 300 | ref_batch = features[begin2:end2] 301 | 302 | distance_batch[ 303 | 0 : end1 - begin1, begin2:end2 304 | ] = self.distance_block.pairwise_distances(feature_batch, ref_batch) 305 | 306 | # From the minibatch of new feature vectors, determine if they are in the estimated manifold. 307 | # If a feature vector is inside a hypersphere of some reference sample, then 308 | # the new sample lies at the estimated manifold. 309 | # The radii of the hyperspheres are determined from distances of neighborhood size k. 
310 |             samples_in_manifold = distance_batch[0 : end1 - begin1, :, None] <= radii
311 |             batch_predictions[begin1:end1] = np.any(samples_in_manifold, axis=1).astype(np.int32)
312 |
313 |             max_realism_score[begin1:end1] = np.max(
314 |                 radii[:, 0] / (distance_batch[0 : end1 - begin1, :] + self.eps), axis=1
315 |             )
316 |             nearest_indices[begin1:end1] = np.argmin(distance_batch[0 : end1 - begin1, :], axis=1)
317 |
318 |         return {
319 |             "fraction": float(np.mean(batch_predictions)),
320 |             "batch_predictions": batch_predictions,
321 |             "max_realism_score": max_realism_score,
322 |             "nearest_indices": nearest_indices,
323 |         }
324 |
325 |     def evaluate_pr(
326 |         self,
327 |         features_1: np.ndarray,
328 |         radii_1: np.ndarray,
329 |         features_2: np.ndarray,
330 |         radii_2: np.ndarray,
331 |     ) -> Tuple[np.ndarray, np.ndarray]:
332 |         """
333 |         Evaluate precision and recall efficiently.
334 |
335 |         :param features_1: [N1 x D] feature vectors for reference batch.
336 |         :param radii_1: [N1 x K1] radii for reference vectors.
337 |         :param features_2: [N2 x D] feature vectors for the other batch.
338 |         :param radii_2: [N2 x K2] radii for other vectors.
339 |         :return: a tuple of arrays for (precision, recall):
340 |             - precision: an np.ndarray of length K1
341 |             - recall: an np.ndarray of length K2
342 |         """
343 |         features_1_status = np.zeros([len(features_1), radii_2.shape[1]], dtype=bool)  # np.bool was removed in NumPy >= 1.24
344 |         features_2_status = np.zeros([len(features_2), radii_1.shape[1]], dtype=bool)
345 |         for begin_1 in range(0, len(features_1), self.row_batch_size):
346 |             end_1 = begin_1 + self.row_batch_size
347 |             batch_1 = features_1[begin_1:end_1]
348 |             for begin_2 in range(0, len(features_2), self.col_batch_size):
349 |                 end_2 = begin_2 + self.col_batch_size
350 |                 batch_2 = features_2[begin_2:end_2]
351 |                 batch_1_in, batch_2_in = self.distance_block.less_thans(
352 |                     batch_1, radii_1[begin_1:end_1], batch_2, radii_2[begin_2:end_2]
353 |                 )
354 |                 features_1_status[begin_1:end_1] |= batch_1_in
355 |                 features_2_status[begin_2:end_2] |= batch_2_in
356 |         return (
357 |             np.mean(features_2_status.astype(np.float64), axis=0),
358 |             np.mean(features_1_status.astype(np.float64), axis=0),
359 |         )
360 |
361 |
362 | class DistanceBlock:
363 |     """
364 |     Calculate pairwise distances between vectors.
365 |
366 |     Adapted from https://github.com/kynkaat/improved-precision-and-recall-metric/blob/f60f25e5ad933a79135c783fcda53de30f42c9b9/precision_recall.py#L34
367 |     """
368 |
369 |     def __init__(self, session):
370 |         self.session = session
371 |
372 |         # Initialize TF graph to calculate pairwise distances.
373 |         with session.graph.as_default():
374 |             self._features_batch1 = tf.placeholder(tf.float32, shape=[None, None])
375 |             self._features_batch2 = tf.placeholder(tf.float32, shape=[None, None])
376 |             distance_block_16 = _batch_pairwise_distances(
377 |                 tf.cast(self._features_batch1, tf.float16),
378 |                 tf.cast(self._features_batch2, tf.float16),
379 |             )
380 |             self.distance_block = tf.cond(
381 |                 tf.reduce_all(tf.math.is_finite(distance_block_16)),
382 |                 lambda: tf.cast(distance_block_16, tf.float32),
383 |                 lambda: _batch_pairwise_distances(self._features_batch1, self._features_batch2),
384 |             )
385 |
386 |             # Extra logic for less thans.
387 | self._radii1 = tf.placeholder(tf.float32, shape=[None, None]) 388 | self._radii2 = tf.placeholder(tf.float32, shape=[None, None]) 389 | dist32 = tf.cast(self.distance_block, tf.float32)[..., None] 390 | self._batch_1_in = tf.math.reduce_any(dist32 <= self._radii2, axis=1) 391 | self._batch_2_in = tf.math.reduce_any(dist32 <= self._radii1[:, None], axis=0) 392 | 393 | def pairwise_distances(self, U, V): 394 | """ 395 | Evaluate pairwise distances between two batches of feature vectors. 396 | """ 397 | return self.session.run( 398 | self.distance_block, 399 | feed_dict={self._features_batch1: U, self._features_batch2: V}, 400 | ) 401 | 402 | def less_thans(self, batch_1, radii_1, batch_2, radii_2): 403 | return self.session.run( 404 | [self._batch_1_in, self._batch_2_in], 405 | feed_dict={ 406 | self._features_batch1: batch_1, 407 | self._features_batch2: batch_2, 408 | self._radii1: radii_1, 409 | self._radii2: radii_2, 410 | }, 411 | ) 412 | 413 | 414 | def _batch_pairwise_distances(U, V): 415 | """ 416 | Compute pairwise distances between two batches of feature vectors. 417 | """ 418 | with tf.variable_scope("pairwise_dist_block"): 419 | # Squared norms of each row in U and V. 420 | norm_u = tf.reduce_sum(tf.square(U), 1) 421 | norm_v = tf.reduce_sum(tf.square(V), 1) 422 | 423 | # norm_u as a column and norm_v as a row vectors. 424 | norm_u = tf.reshape(norm_u, [-1, 1]) 425 | norm_v = tf.reshape(norm_v, [1, -1]) 426 | 427 | # Pairwise squared Euclidean distances. 428 | D = tf.maximum(norm_u - 2 * tf.matmul(U, V, False, True) + norm_v, 0.0) 429 | 430 | return D 431 | 432 | 433 | class NpzArrayReader(ABC): 434 | @abstractmethod 435 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 436 | pass 437 | 438 | @abstractmethod 439 | def remaining(self) -> int: 440 | pass 441 | 442 | def read_batches(self, batch_size: int) -> Iterable[np.ndarray]: 443 | def gen_fn(): 444 | while True: 445 | batch = self.read_batch(batch_size) 446 | if batch is None: 447 | break 448 | yield batch 449 | 450 | rem = self.remaining() 451 | num_batches = rem // batch_size + int(rem % batch_size != 0) 452 | return BatchIterator(gen_fn, num_batches) 453 | 454 | 455 | class BatchIterator: 456 | def __init__(self, gen_fn, length): 457 | self.gen_fn = gen_fn 458 | self.length = length 459 | 460 | def __len__(self): 461 | return self.length 462 | 463 | def __iter__(self): 464 | return self.gen_fn() 465 | 466 | 467 | class StreamingNpzArrayReader(NpzArrayReader): 468 | def __init__(self, arr_f, shape, dtype): 469 | self.arr_f = arr_f 470 | self.shape = shape 471 | self.dtype = dtype 472 | self.idx = 0 473 | 474 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 475 | if self.idx >= self.shape[0]: 476 | return None 477 | 478 | bs = min(batch_size, self.shape[0] - self.idx) 479 | self.idx += bs 480 | 481 | if self.dtype.itemsize == 0: 482 | return np.ndarray([bs, *self.shape[1:]], dtype=self.dtype) 483 | 484 | read_count = bs * np.prod(self.shape[1:]) 485 | read_size = int(read_count * self.dtype.itemsize) 486 | data = _read_bytes(self.arr_f, read_size, "array data") 487 | return np.frombuffer(data, dtype=self.dtype).reshape([bs, *self.shape[1:]]) 488 | 489 | def remaining(self) -> int: 490 | return max(0, self.shape[0] - self.idx) 491 | 492 | 493 | class MemoryNpzArrayReader(NpzArrayReader): 494 | def __init__(self, arr): 495 | self.arr = arr 496 | self.idx = 0 497 | 498 | @classmethod 499 | def load(cls, path: str, arr_name: str): 500 | with open(path, "rb") as f: 501 | arr = 
np.load(f)[arr_name] 502 | return cls(arr) 503 | 504 | def read_batch(self, batch_size: int) -> Optional[np.ndarray]: 505 | if self.idx >= self.arr.shape[0]: 506 | return None 507 | 508 | res = self.arr[self.idx : self.idx + batch_size] 509 | self.idx += batch_size 510 | return res 511 | 512 | def remaining(self) -> int: 513 | return max(0, self.arr.shape[0] - self.idx) 514 | 515 | 516 | @contextmanager 517 | def open_npz_array(path: str, arr_name: str) -> NpzArrayReader: 518 | with _open_npy_file(path, arr_name) as arr_f: 519 | version = np.lib.format.read_magic(arr_f) 520 | if version == (1, 0): 521 | header = np.lib.format.read_array_header_1_0(arr_f) 522 | elif version == (2, 0): 523 | header = np.lib.format.read_array_header_2_0(arr_f) 524 | else: 525 | yield MemoryNpzArrayReader.load(path, arr_name) 526 | return 527 | shape, fortran, dtype = header 528 | if fortran or dtype.hasobject: 529 | yield MemoryNpzArrayReader.load(path, arr_name) 530 | else: 531 | yield StreamingNpzArrayReader(arr_f, shape, dtype) 532 | 533 | 534 | def _read_bytes(fp, size, error_template="ran out of data"): 535 | """ 536 | Copied from: https://github.com/numpy/numpy/blob/fb215c76967739268de71aa4bda55dd1b062bc2e/numpy/lib/format.py#L788-L886 537 | 538 | Read from file-like object until size bytes are read. 539 | Raises ValueError if not EOF is encountered before size bytes are read. 540 | Non-blocking objects only supported if they derive from io objects. 541 | Required as e.g. ZipExtFile in python 2.6 can return less data than 542 | requested. 543 | """ 544 | data = bytes() 545 | while True: 546 | # io files (default in python3) return None or raise on 547 | # would-block, python2 file will truncate, probably nothing can be 548 | # done about that. note that regular files can't be non-blocking 549 | try: 550 | r = fp.read(size - len(data)) 551 | data += r 552 | if len(r) == 0 or len(data) == size: 553 | break 554 | except io.BlockingIOError: 555 | pass 556 | if len(data) != size: 557 | msg = "EOF: reading %s, expected %d bytes got %d" 558 | raise ValueError(msg % (error_template, size, len(data))) 559 | else: 560 | return data 561 | 562 | 563 | @contextmanager 564 | def _open_npy_file(path: str, arr_name: str): 565 | with open(path, "rb") as f: 566 | with zipfile.ZipFile(f, "r") as zip_f: 567 | if f"{arr_name}.npy" not in zip_f.namelist(): 568 | raise ValueError(f"missing {arr_name} in npz file") 569 | with zip_f.open(f"{arr_name}.npy", "r") as arr_f: 570 | yield arr_f 571 | 572 | 573 | def _download_inception_model(): 574 | if os.path.exists(INCEPTION_V3_PATH): 575 | return 576 | print("downloading InceptionV3 model...") 577 | with requests.get(INCEPTION_V3_URL, stream=True) as r: 578 | r.raise_for_status() 579 | tmp_path = INCEPTION_V3_PATH + ".tmp" 580 | with open(tmp_path, "wb") as f: 581 | for chunk in tqdm(r.iter_content(chunk_size=8192)): 582 | f.write(chunk) 583 | os.rename(tmp_path, INCEPTION_V3_PATH) 584 | 585 | 586 | def _create_feature_graph(input_batch): 587 | _download_inception_model() 588 | prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}" 589 | with open(INCEPTION_V3_PATH, "rb") as f: 590 | graph_def = tf.GraphDef() 591 | graph_def.ParseFromString(f.read()) 592 | pool3, spatial = tf.import_graph_def( 593 | graph_def, 594 | input_map={f"ExpandDims:0": input_batch}, 595 | return_elements=[FID_POOL_NAME, FID_SPATIAL_NAME], 596 | name=prefix, 597 | ) 598 | _update_shapes(pool3) 599 | spatial = spatial[..., :7] 600 | return pool3, spatial 601 | 602 | 603 | def 
_create_softmax_graph(input_batch):
604 |     _download_inception_model()
605 |     prefix = f"{random.randrange(2**32)}_{random.randrange(2**32)}"
606 |     with open(INCEPTION_V3_PATH, "rb") as f:
607 |         graph_def = tf.GraphDef()
608 |         graph_def.ParseFromString(f.read())
609 |     (matmul,) = tf.import_graph_def(
610 |         graph_def, return_elements=[f"softmax/logits/MatMul"], name=prefix
611 |     )
612 |     w = matmul.inputs[1]
613 |     logits = tf.matmul(input_batch, w)
614 |     return tf.nn.softmax(logits)
615 |
616 |
617 | def _update_shapes(pool3):
618 |     # https://github.com/bioinf-jku/TTUR/blob/73ab375cdf952a12686d9aa7978567771084da42/fid.py#L50-L63
619 |     ops = pool3.graph.get_operations()
620 |     for op in ops:
621 |         for o in op.outputs:
622 |             shape = o.get_shape()
623 |             if shape._dims is not None:  # pylint: disable=protected-access
624 |                 # shape = [s.value for s in shape]  TF 1.x
625 |                 shape = [s for s in shape]  # TF 2.x
626 |                 new_shape = []
627 |                 for j, s in enumerate(shape):
628 |                     if s == 1 and j == 0:
629 |                         new_shape.append(None)
630 |                     else:
631 |                         new_shape.append(s)
632 |                 o.__dict__["_shape_val"] = tf.TensorShape(new_shape)
633 |     return pool3
634 |
635 |
636 | def _numpy_partition(arr, kth, **kwargs):
637 |     num_workers = min(cpu_count(), len(arr))
638 |     chunk_size = len(arr) // num_workers
639 |     extra = len(arr) % num_workers
640 |
641 |     start_idx = 0
642 |     batches = []
643 |     for i in range(num_workers):
644 |         size = chunk_size + (1 if i < extra else 0)
645 |         batches.append(arr[start_idx : start_idx + size])
646 |         start_idx += size
647 |
648 |     with ThreadPool(num_workers) as pool:
649 |         return list(pool.map(partial(np.partition, kth=kth, **kwargs), batches))
650 |
651 |
652 | if __name__ == "__main__":
653 |     main()
654 |
--------------------------------------------------------------------------------
/evaluations/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow-gpu>=2.0
2 | scipy
3 | requests
4 | tqdm
--------------------------------------------------------------------------------
/figures/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/MDT/7d26c2162c462bd0b90f97f3a1c36cdaaac616ec/figures/.DS_Store
--------------------------------------------------------------------------------
/figures/vis.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sail-sg/MDT/7d26c2162c462bd0b90f97f3a1c36cdaaac616ec/figures/vis.jpg
--------------------------------------------------------------------------------
/infer_mdt.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 |
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import torch
8 | from torchvision.utils import save_image
9 | from masked_diffusion import create_diffusion
10 | from diffusers.models import AutoencoderKL
11 | from masked_diffusion.models import MDTv2_XL_2
12 |
13 |
14 | # Setup PyTorch:
15 | torch.manual_seed(1)
16 | torch.set_grad_enabled(False)
17 | device = "cuda" if torch.cuda.is_available() else "cpu"
18 | num_sampling_steps = 250
19 | cfg_scale = 4.0
20 | pow_scale = 0.01  # a large pow_scale increases diversity; a small pow_scale increases quality.
21 | model_path = 'mdt_xl2_v2_ckpt.pt' 22 | 23 | # Load model: 24 | image_size = 256 25 | assert image_size in [256], "We provide pre-trained models for 256x256 resolutions for now." 26 | latent_size = image_size // 8 27 | model = MDTv2_XL_2(input_size=latent_size, decode_layer=4).to(device) 28 | 29 | state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) 30 | model.load_state_dict(state_dict) 31 | model.eval() 32 | diffusion = create_diffusion(str(num_sampling_steps)) 33 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device) 34 | 35 | # Labels to condition the model with: 36 | class_labels = [19,23,106,108,278,282] 37 | 38 | # Create sampling noise: 39 | n = len(class_labels) 40 | z = torch.randn(n, 4, latent_size, latent_size, device=device) 41 | y = torch.tensor(class_labels, device=device) 42 | 43 | # Setup classifier-free guidance: 44 | z = torch.cat([z, z], 0) 45 | y_null = torch.tensor([1000] * n, device=device) 46 | y = torch.cat([y, y_null], 0) 47 | 48 | 49 | model_kwargs = dict(y=y, cfg_scale=cfg_scale, scale_pow=pow_scale) 50 | 51 | # Sample images: 52 | samples = diffusion.p_sample_loop( 53 | model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, progress=True, device=device 54 | ) 55 | samples, _ = samples.chunk(2, dim=0) # Remove null class samples 56 | samples = vae.decode(samples / 0.18215).sample 57 | 58 | # Save and display images: 59 | save_image(samples, "sample.jpg", nrow=3, normalize=True, value_range=(-1, 1)) 60 | -------------------------------------------------------------------------------- /masked_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from . import gaussian_diffusion as gd 7 | from .respace import SpacedDiffusion, space_timesteps 8 | 9 | 10 | def create_diffusion( 11 | timestep_respacing, 12 | noise_schedule="linear", 13 | use_kl=False, 14 | sigma_small=False, 15 | predict_xstart=False, 16 | learn_sigma=True, 17 | rescale_learned_sigmas=False, 18 | diffusion_steps=1000 19 | ): 20 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 21 | if use_kl: 22 | loss_type = gd.LossType.RESCALED_KL 23 | elif rescale_learned_sigmas: 24 | loss_type = gd.LossType.RESCALED_MSE 25 | else: 26 | loss_type = gd.LossType.MSE 27 | if timestep_respacing is None or timestep_respacing == "": 28 | timestep_respacing = [diffusion_steps] 29 | return SpacedDiffusion( 30 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 31 | betas=betas, 32 | model_mean_type=( 33 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 34 | ), 35 | model_var_type=( 36 | ( 37 | gd.ModelVarType.FIXED_LARGE 38 | if not sigma_small 39 | else gd.ModelVarType.FIXED_SMALL 40 | ) 41 | if not learn_sigma 42 | else gd.ModelVarType.LEARNED_RANGE 43 | ), 44 | loss_type=loss_type 45 | ) 46 | 47 | def diffusion_defaults(): 48 | """ 49 | Defaults for image and classifier training. 
50 | """ 51 | return dict( 52 | learn_sigma=True, 53 | diffusion_steps=1000, 54 | noise_schedule="linear", 55 | timestep_respacing="", 56 | use_kl=False, 57 | predict_xstart=False, 58 | sigma_small=False, 59 | rescale_learned_sigmas=False, 60 | ) 61 | 62 | def model_and_diffusion_defaults(): 63 | """ 64 | Defaults for image training. 65 | """ 66 | res = dict( 67 | image_size=256, 68 | mask_ratio=None, 69 | decode_layer=None, 70 | class_cond=True, 71 | use_fp16=False, 72 | ) 73 | res.update(diffusion_defaults()) 74 | return res -------------------------------------------------------------------------------- /masked_diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def continuous_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a continuous Gaussian distribution. 50 | :param x: the targets 51 | :param means: the Gaussian mean Tensor. 52 | :param log_scales: the Gaussian log stddev Tensor. 53 | :return: a tensor like x of log probabilities (in nats). 54 | """ 55 | centered_x = x - means 56 | inv_stdv = th.exp(-log_scales) 57 | normalized_x = centered_x * inv_stdv 58 | log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) 59 | return log_probs 60 | 61 | 62 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 63 | """ 64 | Compute the log-likelihood of a Gaussian distribution discretizing to a 65 | given image. 66 | :param x: the target images. It is assumed that this was uint8 values, 67 | rescaled to the range [-1, 1]. 68 | :param means: the Gaussian mean Tensor. 69 | :param log_scales: the Gaussian log stddev Tensor. 70 | :return: a tensor like x of log probabilities (in nats). 
71 | """ 72 | assert x.shape == means.shape == log_scales.shape 73 | centered_x = x - means 74 | inv_stdv = th.exp(-log_scales) 75 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 76 | cdf_plus = approx_standard_normal_cdf(plus_in) 77 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 78 | cdf_min = approx_standard_normal_cdf(min_in) 79 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 80 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 81 | cdf_delta = cdf_plus - cdf_min 82 | log_probs = th.where( 83 | x < -0.999, 84 | log_cdf_plus, 85 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 86 | ) 87 | assert log_probs.shape == x.shape 88 | return log_probs 89 | -------------------------------------------------------------------------------- /masked_diffusion/dist_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers for distributed training. 3 | """ 4 | 5 | import io 6 | import os 7 | import socket 8 | 9 | import blobfile as bf 10 | from mpi4py import MPI 11 | import torch as th 12 | import torch.distributed as dist 13 | import builtins 14 | import datetime 15 | 16 | # Change this to reflect your cluster layout. 17 | # The GPU for a given rank is (rank % GPUS_PER_NODE). 18 | GPUS_PER_NODE = 8 19 | 20 | SETUP_RETRY_COUNT = 3 21 | def synchronize(): 22 | if not dist.is_available(): 23 | return 24 | 25 | if not dist.is_initialized(): 26 | return 27 | 28 | world_size = dist.get_world_size() 29 | 30 | if world_size == 1: 31 | return 32 | 33 | dist.barrier() 34 | 35 | def is_dist_avail_and_initialized(): 36 | if not dist.is_available(): 37 | return False 38 | if not dist.is_initialized(): 39 | return False 40 | return True 41 | def get_world_size(): 42 | if not is_dist_avail_and_initialized(): 43 | return 1 44 | return dist.get_world_size() 45 | 46 | def setup_for_distributed(is_master): 47 | """ 48 | This function disables printing when not in master process 49 | """ 50 | builtin_print = builtins.print 51 | 52 | def print(*args, **kwargs): 53 | force = kwargs.pop('force', False) 54 | force = force or (get_world_size() > 8) 55 | if is_master or force: 56 | now = datetime.datetime.now().time() 57 | builtin_print('[{}] '.format(now), end='') # print with time stamp 58 | builtin_print(*args, **kwargs) 59 | 60 | builtins.print = print 61 | 62 | def setup_dist_multinode(args): 63 | """ 64 | Setup a distributed process group. 65 | """ 66 | if not dist.is_available() or not dist.is_initialized(): 67 | th.distributed.init_process_group(backend="nccl", init_method='env://') 68 | world_size = dist.get_world_size() 69 | local_rank = int(os.getenv('LOCAL_RANK')) 70 | print("rank",local_rank) 71 | device = local_rank 72 | th.cuda.set_device(device) 73 | setup_for_distributed(device == 0) 74 | 75 | synchronize() 76 | else: 77 | print("ddp failed!") 78 | exit() 79 | 80 | def setup_dist(): 81 | """ 82 | Setup a distributed process group. 83 | """ 84 | if dist.is_initialized(): 85 | return 86 | th.cuda.set_device(int(os.environ["LOCAL_RANK"])) 87 | th.distributed.init_process_group(backend="nccl", init_method="env://") 88 | synchronize() 89 | 90 | def dev(): 91 | """ 92 | Get the device to use for torch.distributed. 93 | """ 94 | if th.cuda.is_available(): 95 | return th.device(f"cuda") 96 | return th.device("cpu") 97 | 98 | 99 | def load_state_dict(path, **kwargs): 100 | """ 101 | Load a PyTorch file without redundant fetches across MPI ranks. 
102 | """ 103 | chunk_size = 2 ** 30 # MPI has a relatively small size limit 104 | if MPI.COMM_WORLD.Get_rank() == 0: 105 | with bf.BlobFile(path, "rb") as f: 106 | data = f.read() 107 | num_chunks = len(data) // chunk_size 108 | if len(data) % chunk_size: 109 | num_chunks += 1 110 | MPI.COMM_WORLD.bcast(num_chunks) 111 | for i in range(0, len(data), chunk_size): 112 | MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) 113 | else: 114 | num_chunks = MPI.COMM_WORLD.bcast(None) 115 | data = bytes() 116 | for _ in range(num_chunks): 117 | data += MPI.COMM_WORLD.bcast(None) 118 | 119 | return th.load(io.BytesIO(data), **kwargs) 120 | 121 | 122 | def sync_params(params): 123 | """ 124 | Synchronize a sequence of Tensors across ranks from rank 0. 125 | """ 126 | for p in params: 127 | with th.no_grad(): 128 | dist.broadcast(p, 0) 129 | 130 | 131 | def _find_free_port(): 132 | try: 133 | s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 134 | s.bind(("", 0)) 135 | s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) 136 | return s.getsockname()[1] 137 | finally: 138 | s.close() 139 | -------------------------------------------------------------------------------- /masked_diffusion/fp16_util.py: -------------------------------------------------------------------------------- 1 | """ 2 | Helpers to train with 16-bit precision. 3 | """ 4 | 5 | import numpy as np 6 | import torch as th 7 | import torch.nn as nn 8 | from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors 9 | 10 | from . import logger 11 | 12 | INITIAL_LOG_LOSS_SCALE = 20.0 13 | 14 | 15 | def convert_module_to_f16(l): 16 | """ 17 | Convert primitive modules to float16. 18 | """ 19 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 20 | l.weight.data = l.weight.data.half() 21 | if l.bias is not None: 22 | l.bias.data = l.bias.data.half() 23 | 24 | 25 | def convert_module_to_f32(l): 26 | """ 27 | Convert primitive modules to float32, undoing convert_module_to_f16(). 28 | """ 29 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): 30 | l.weight.data = l.weight.data.float() 31 | if l.bias is not None: 32 | l.bias.data = l.bias.data.float() 33 | 34 | 35 | def make_master_params(param_groups_and_shapes): 36 | """ 37 | Copy model parameters into a (differently-shaped) list of full-precision 38 | parameters. 39 | """ 40 | master_params = [] 41 | for param_group, shape in param_groups_and_shapes: 42 | master_param = nn.Parameter( 43 | _flatten_dense_tensors( 44 | [param.detach().float() for (_, param) in param_group] 45 | ).view(shape) 46 | ) 47 | master_param.requires_grad = True 48 | master_params.append(master_param) 49 | return master_params 50 | 51 | 52 | def model_grads_to_master_grads(param_groups_and_shapes, master_params): 53 | """ 54 | Copy the gradients from the model parameters into the master parameters 55 | from make_master_params(). 56 | """ 57 | for master_param, (param_group, shape) in zip( 58 | master_params, param_groups_and_shapes 59 | ): 60 | master_param.grad = _flatten_dense_tensors( 61 | [param_grad_or_zeros(param) for (_, param) in param_group] 62 | ).view(shape) 63 | 64 | 65 | def master_params_to_model_params(param_groups_and_shapes, master_params): 66 | """ 67 | Copy the master parameter data back into the model parameters. 68 | """ 69 | # Without copying to a list, if a generator is passed, this will 70 | # silently not copy any parameters. 
71 | for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes): 72 | for (_, param), unflat_master_param in zip( 73 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 74 | ): 75 | param.detach().copy_(unflat_master_param) 76 | 77 | 78 | def unflatten_master_params(param_group, master_param): 79 | return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group]) 80 | 81 | 82 | def get_param_groups_and_shapes(named_model_params): 83 | named_model_params = list(named_model_params) 84 | scalar_vector_named_params = ( 85 | [(n, p) for (n, p) in named_model_params if p.ndim <= 1], 86 | (-1), 87 | ) 88 | matrix_named_params = ( 89 | [(n, p) for (n, p) in named_model_params if p.ndim > 1], 90 | (1, -1), 91 | ) 92 | return [scalar_vector_named_params, matrix_named_params] 93 | 94 | 95 | def master_params_to_state_dict( 96 | model, param_groups_and_shapes, master_params, use_fp16 97 | ): 98 | if use_fp16: 99 | state_dict = model.state_dict() 100 | for master_param, (param_group, _) in zip( 101 | master_params, param_groups_and_shapes 102 | ): 103 | for (name, _), unflat_master_param in zip( 104 | param_group, unflatten_master_params(param_group, master_param.view(-1)) 105 | ): 106 | assert name in state_dict 107 | state_dict[name] = unflat_master_param 108 | else: 109 | state_dict = model.state_dict() 110 | for i, (name, _value) in enumerate(model.named_parameters()): 111 | assert name in state_dict 112 | state_dict[name] = master_params[i] 113 | return state_dict 114 | 115 | 116 | def state_dict_to_master_params(model, state_dict, use_fp16): 117 | if use_fp16: 118 | named_model_params = [ 119 | (name, state_dict[name]) for name, _ in model.named_parameters() 120 | ] 121 | param_groups_and_shapes = get_param_groups_and_shapes(named_model_params) 122 | master_params = make_master_params(param_groups_and_shapes) 123 | else: 124 | master_params = [state_dict[name] for name, _ in model.named_parameters()] 125 | return master_params 126 | 127 | 128 | def zero_master_grads(master_params): 129 | for param in master_params: 130 | param.grad = None 131 | 132 | 133 | def zero_grad(model_params): 134 | for param in model_params: 135 | # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group 136 | if param.grad is not None: 137 | param.grad.detach_() 138 | param.grad.zero_() 139 | 140 | 141 | def param_grad_or_zeros(param): 142 | if param.grad is not None: 143 | return param.grad.data.detach() 144 | else: 145 | return th.zeros_like(param) 146 | 147 | 148 | class MixedPrecisionTrainer: 149 | def __init__( 150 | self, 151 | *, 152 | model, 153 | use_fp16=False, 154 | fp16_scale_growth=1e-3, 155 | initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE, 156 | ): 157 | self.model = model 158 | self.use_fp16 = use_fp16 159 | self.fp16_scale_growth = fp16_scale_growth 160 | 161 | self.model_params = list(self.model.parameters()) 162 | self.master_params = self.model_params 163 | self.param_groups_and_shapes = None 164 | self.lg_loss_scale = initial_lg_loss_scale 165 | 166 | if self.use_fp16: 167 | self.param_groups_and_shapes = get_param_groups_and_shapes( 168 | self.model.named_parameters() 169 | ) 170 | self.master_params = make_master_params(self.param_groups_and_shapes) 171 | self.model.convert_to_fp16() 172 | 173 | def zero_grad(self): 174 | zero_grad(self.model_params) 175 | 176 | def backward(self, loss: th.Tensor): 177 | if self.use_fp16: 178 | loss_scale = 2 ** self.lg_loss_scale 179 | (loss * 
loss_scale).backward() 180 | else: 181 | loss.backward() 182 | 183 | def optimize(self, opt: th.optim.Optimizer): 184 | if self.use_fp16: 185 | return self._optimize_fp16(opt) 186 | else: 187 | return self._optimize_normal(opt) 188 | 189 | def _optimize_fp16(self, opt: th.optim.Optimizer): 190 | logger.logkv_mean("lg_loss_scale", self.lg_loss_scale) 191 | model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params) 192 | grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale) 193 | if check_overflow(grad_norm): 194 | self.lg_loss_scale -= 1 195 | logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}") 196 | zero_master_grads(self.master_params) 197 | return False 198 | 199 | logger.logkv_mean("grad_norm", grad_norm) 200 | logger.logkv_mean("param_norm", param_norm) 201 | 202 | for p in self.master_params: 203 | p.grad.mul_(1.0 / (2 ** self.lg_loss_scale)) 204 | opt.step() 205 | zero_master_grads(self.master_params) 206 | master_params_to_model_params(self.param_groups_and_shapes, self.master_params) 207 | self.lg_loss_scale += self.fp16_scale_growth 208 | return True 209 | 210 | def _optimize_normal(self, opt: th.optim.Optimizer): 211 | grad_norm, param_norm = self._compute_norms() 212 | logger.logkv_mean("grad_norm", grad_norm) 213 | logger.logkv_mean("param_norm", param_norm) 214 | opt.step() 215 | return True 216 | 217 | def _compute_norms(self, grad_scale=1.0): 218 | grad_norm = 0.0 219 | param_norm = 0.0 220 | for p in self.master_params: 221 | with th.no_grad(): 222 | param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2 223 | if p.grad is not None: 224 | grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2 225 | return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm) 226 | 227 | def master_params_to_state_dict(self, master_params): 228 | return master_params_to_state_dict( 229 | self.model, self.param_groups_and_shapes, master_params, self.use_fp16 230 | ) 231 | 232 | def state_dict_to_master_params(self, state_dict): 233 | return state_dict_to_master_params(self.model, state_dict, self.use_fp16) 234 | 235 | 236 | def check_overflow(value): 237 | return (value == float("inf")) or (value == -float("inf")) or (value != value) 238 | -------------------------------------------------------------------------------- /masked_diffusion/gaussian_diffusion.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | 7 | import math 8 | 9 | import numpy as np 10 | import torch as th 11 | import enum 12 | 13 | from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl 14 | 15 | 16 | def mean_flat(tensor): 17 | """ 18 | Take the mean over all non-batch dimensions. 19 | """ 20 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 21 | 22 | 23 | class ModelMeanType(enum.Enum): 24 | """ 25 | Which type of output the model predicts. 
26 | """ 27 | 28 | PREVIOUS_X = enum.auto() # the model predicts x_{t-1} 29 | START_X = enum.auto() # the model predicts x_0 30 | EPSILON = enum.auto() # the model predicts epsilon 31 | VELOCITY = enum.auto() # the model predicts v 32 | 33 | 34 | class ModelVarType(enum.Enum): 35 | """ 36 | What is used as the model's output variance. 37 | The LEARNED_RANGE option has been added to allow the model to predict 38 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 39 | """ 40 | 41 | LEARNED = enum.auto() 42 | FIXED_SMALL = enum.auto() 43 | FIXED_LARGE = enum.auto() 44 | LEARNED_RANGE = enum.auto() 45 | 46 | 47 | class LossType(enum.Enum): 48 | MSE = enum.auto() # use raw MSE loss (and KL when learning variances) 49 | RESCALED_MSE = ( 50 | enum.auto() 51 | ) # use raw MSE loss (with RESCALED_KL when learning variances) 52 | KL = enum.auto() # use the variational lower-bound 53 | RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB 54 | 55 | def is_vb(self): 56 | return self == LossType.KL or self == LossType.RESCALED_KL 57 | 58 | 59 | def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): 60 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 61 | warmup_time = int(num_diffusion_timesteps * warmup_frac) 62 | betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) 63 | return betas 64 | 65 | 66 | def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): 67 | """ 68 | This is the deprecated API for creating beta schedules. 69 | See get_named_beta_schedule() for the new library of schedules. 70 | """ 71 | if beta_schedule == "quad": 72 | betas = ( 73 | np.linspace( 74 | beta_start ** 0.5, 75 | beta_end ** 0.5, 76 | num_diffusion_timesteps, 77 | dtype=np.float64, 78 | ) 79 | ** 2 80 | ) 81 | elif beta_schedule == "linear": 82 | betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) 83 | elif beta_schedule == "warmup10": 84 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) 85 | elif beta_schedule == "warmup50": 86 | betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) 87 | elif beta_schedule == "const": 88 | betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) 89 | elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 90 | betas = 1.0 / np.linspace( 91 | num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 92 | ) 93 | else: 94 | raise NotImplementedError(beta_schedule) 95 | assert betas.shape == (num_diffusion_timesteps,) 96 | return betas 97 | 98 | 99 | def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): 100 | """ 101 | Get a pre-defined beta schedule for the given name. 102 | The beta schedule library consists of beta schedules which remain similar 103 | in the limit of num_diffusion_timesteps. 104 | Beta schedules may be added, but should not be removed or changed once 105 | they are committed to maintain backwards compatibility. 106 | """ 107 | if schedule_name == "linear": 108 | # Linear schedule from Ho et al, extended to work for any number of 109 | # diffusion steps. 
110 | scale = 1000 / num_diffusion_timesteps 111 | return get_beta_schedule( 112 | "linear", 113 | beta_start=scale * 0.0001, 114 | beta_end=scale * 0.02, 115 | num_diffusion_timesteps=num_diffusion_timesteps, 116 | ) 117 | elif schedule_name == "squaredcos_cap_v2": 118 | return betas_for_alpha_bar( 119 | num_diffusion_timesteps, 120 | lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, 121 | ) 122 | else: 123 | raise NotImplementedError(f"unknown beta schedule: {schedule_name}") 124 | 125 | 126 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 127 | """ 128 | Create a beta schedule that discretizes the given alpha_t_bar function, 129 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 130 | :param num_diffusion_timesteps: the number of betas to produce. 131 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 132 | produces the cumulative product of (1-beta) up to that 133 | part of the diffusion process. 134 | :param max_beta: the maximum beta to use; use values lower than 1 to 135 | prevent singularities. 136 | """ 137 | betas = [] 138 | for i in range(num_diffusion_timesteps): 139 | t1 = i / num_diffusion_timesteps 140 | t2 = (i + 1) / num_diffusion_timesteps 141 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 142 | return np.array(betas) 143 | 144 | 145 | class GaussianDiffusion: 146 | """ 147 | Utilities for training and sampling diffusion models. 148 | Original ported from this codebase: 149 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 150 | :param betas: a 1-D numpy array of betas for each diffusion timestep, 151 | starting at T and going to 1. 152 | """ 153 | 154 | def __init__( 155 | self, 156 | *, 157 | betas, 158 | model_mean_type, 159 | model_var_type, 160 | loss_type 161 | ): 162 | 163 | self.model_mean_type = model_mean_type 164 | self.model_var_type = model_var_type 165 | self.loss_type = loss_type 166 | 167 | # Use float64 for accuracy. 
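# float64 matters here because alphas_cumprod (computed below) is a running
# product of up to num_timesteps terms; float32 rounding would compound across
# the chain and distort the near-zero tail of the schedule.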
168 | betas = np.array(betas, dtype=np.float64) 169 | self.betas = betas 170 | assert len(betas.shape) == 1, "betas must be 1-D" 171 | assert (betas > 0).all() and (betas <= 1).all() 172 | 173 | self.num_timesteps = int(betas.shape[0]) 174 | 175 | alphas = 1.0 - betas 176 | self.alphas_cumprod = np.cumprod(alphas, axis=0) 177 | self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) 178 | self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) 179 | assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) 180 | 181 | # calculations for diffusion q(x_t | x_{t-1}) and others 182 | self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) 183 | self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) 184 | self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) 185 | self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) 186 | self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) 187 | 188 | # calculations for posterior q(x_{t-1} | x_t, x_0) 189 | self.posterior_variance = ( 190 | betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 191 | ) 192 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 193 | self.posterior_log_variance_clipped = np.log( 194 | np.append(self.posterior_variance[1], self.posterior_variance[1:]) 195 | ) if len(self.posterior_variance) > 1 else np.array([]) 196 | 197 | self.posterior_mean_coef1 = ( 198 | betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 199 | ) 200 | self.posterior_mean_coef2 = ( 201 | (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) 202 | ) 203 | 204 | def q_mean_variance(self, x_start, t): 205 | """ 206 | Get the distribution q(x_t | x_0). 207 | :param x_start: the [N x C x ...] tensor of noiseless inputs. 208 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 209 | :return: A tuple (mean, variance, log_variance), all of x_start's shape. 210 | """ 211 | mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 212 | variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) 213 | log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) 214 | return mean, variance, log_variance 215 | 216 | def q_sample(self, x_start, t, noise=None): 217 | """ 218 | Diffuse the data for a given number of diffusion steps. 219 | In other words, sample from q(x_t | x_0). 220 | :param x_start: the initial data batch. 221 | :param t: the number of diffusion steps (minus 1). Here, 0 means one step. 222 | :param noise: if specified, the split-out normal noise. 223 | :return: A noisy version of x_start. 
224 | """ 225 | if noise is None: 226 | noise = th.randn_like(x_start) 227 | assert noise.shape == x_start.shape 228 | return ( 229 | _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start 230 | + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise 231 | ) 232 | 233 | def q_posterior_mean_variance(self, x_start, x_t, t): 234 | """ 235 | Compute the mean and variance of the diffusion posterior: 236 | q(x_{t-1} | x_t, x_0) 237 | """ 238 | assert x_start.shape == x_t.shape 239 | posterior_mean = ( 240 | _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start 241 | + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t 242 | ) 243 | posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) 244 | posterior_log_variance_clipped = _extract_into_tensor( 245 | self.posterior_log_variance_clipped, t, x_t.shape 246 | ) 247 | assert ( 248 | posterior_mean.shape[0] 249 | == posterior_variance.shape[0] 250 | == posterior_log_variance_clipped.shape[0] 251 | == x_start.shape[0] 252 | ) 253 | return posterior_mean, posterior_variance, posterior_log_variance_clipped 254 | 255 | def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): 256 | """ 257 | Apply the model to get p(x_{t-1} | x_t), as well as a prediction of 258 | the initial x, x_0. 259 | :param model: the model, which takes a signal and a batch of timesteps 260 | as input. 261 | :param x: the [N x C x ...] tensor at time t. 262 | :param t: a 1-D Tensor of timesteps. 263 | :param clip_denoised: if True, clip the denoised signal into [-1, 1]. 264 | :param denoised_fn: if not None, a function which applies to the 265 | x_start prediction before it is used to sample. Applies before 266 | clip_denoised. 267 | :param model_kwargs: if not None, a dict of extra keyword arguments to 268 | pass to the model. This can be used for conditioning. 269 | :return: a dict with the following keys: 270 | - 'mean': the model mean output. 271 | - 'variance': the model variance output. 272 | - 'log_variance': the log of 'variance'. 273 | - 'pred_xstart': the prediction for x_0. 274 | """ 275 | if model_kwargs is None: 276 | model_kwargs = {} 277 | 278 | B, C = x.shape[:2] 279 | assert t.shape == (B,) 280 | model_output = model(x, t, **model_kwargs) 281 | if isinstance(model_output, tuple): 282 | model_output, extra = model_output 283 | else: 284 | extra = None 285 | 286 | if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: 287 | assert model_output.shape == (B, C * 2, *x.shape[2:]) 288 | model_output, model_var_values = th.split(model_output, C, dim=1) 289 | min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) 290 | max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) 291 | # The model_var_values is [-1, 1] for [min_var, max_var]. 292 | frac = (model_var_values + 1) / 2 293 | model_log_variance = frac * max_log + (1 - frac) * min_log 294 | model_variance = th.exp(model_log_variance) 295 | else: 296 | model_variance, model_log_variance = { 297 | # for fixedlarge, we set the initial (log-)variance like so 298 | # to get a better decoder log likelihood. 
299 | ModelVarType.FIXED_LARGE: ( 300 | np.append(self.posterior_variance[1], self.betas[1:]), 301 | np.log(np.append(self.posterior_variance[1], self.betas[1:])), 302 | ), 303 | ModelVarType.FIXED_SMALL: ( 304 | self.posterior_variance, 305 | self.posterior_log_variance_clipped, 306 | ), 307 | }[self.model_var_type] 308 | model_variance = _extract_into_tensor(model_variance, t, x.shape) 309 | model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) 310 | 311 | def process_xstart(x): 312 | if denoised_fn is not None: 313 | x = denoised_fn(x) 314 | if clip_denoised: 315 | return x.clamp(-1, 1) 316 | return x 317 | 318 | if self.model_mean_type == ModelMeanType.START_X: 319 | pred_xstart = process_xstart(model_output) 320 | else: 321 | pred_xstart = process_xstart( 322 | self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) 323 | ) 324 | model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) 325 | 326 | assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape 327 | return { 328 | "mean": model_mean, 329 | "variance": model_variance, 330 | "log_variance": model_log_variance, 331 | "pred_xstart": pred_xstart, 332 | "extra": extra, 333 | } 334 | 335 | def _predict_xstart_from_eps(self, x_t, t, eps): 336 | assert x_t.shape == eps.shape 337 | return ( 338 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t 339 | - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps 340 | ) 341 | 342 | def _predict_eps_from_xstart(self, x_t, t, pred_xstart): 343 | return ( 344 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart 345 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) 346 | 347 | def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 348 | """ 349 | Compute the mean for the previous step, given a function cond_fn that 350 | computes the gradient of a conditional log probability with respect to 351 | x. In particular, cond_fn computes grad(log(p(y|x))), and we want to 352 | condition on y. 353 | This uses the conditioning strategy from Sohl-Dickstein et al. (2015). 354 | """ 355 | gradient = cond_fn(x, t, **model_kwargs) 356 | new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() 357 | return new_mean 358 | 359 | def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): 360 | """ 361 | Compute what the p_mean_variance output would have been, should the 362 | model's score function be conditioned by cond_fn. 363 | See condition_mean() for details on cond_fn. 364 | Unlike condition_mean(), this instead uses the conditioning strategy 365 | from Song et al (2020). 366 | """ 367 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 368 | 369 | eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) 370 | eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) 371 | 372 | out = p_mean_var.copy() 373 | out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) 374 | out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) 375 | return out 376 | 377 | def p_sample( 378 | self, 379 | model, 380 | x, 381 | t, 382 | clip_denoised=True, 383 | denoised_fn=None, 384 | cond_fn=None, 385 | model_kwargs=None, 386 | ): 387 | """ 388 | Sample x_{t-1} from the model at the given timestep. 389 | :param model: the model to sample from. 390 | :param x: the current tensor at x_{t-1}. 
391 | :param t: the value of t, starting at 0 for the first diffusion step. 392 | :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. 393 | :param denoised_fn: if not None, a function which applies to the 394 | x_start prediction before it is used to sample. 395 | :param cond_fn: if not None, this is a gradient function that acts 396 | similarly to the model. 397 | :param model_kwargs: if not None, a dict of extra keyword arguments to 398 | pass to the model. This can be used for conditioning. 399 | :return: a dict containing the following keys: 400 | - 'sample': a random sample from the model. 401 | - 'pred_xstart': a prediction of x_0. 402 | """ 403 | out = self.p_mean_variance( 404 | model, 405 | x, 406 | t, 407 | clip_denoised=clip_denoised, 408 | denoised_fn=denoised_fn, 409 | model_kwargs=model_kwargs, 410 | ) 411 | noise = th.randn_like(x) 412 | nonzero_mask = ( 413 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 414 | ) # no noise when t == 0 415 | if cond_fn is not None: 416 | out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) 417 | sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise 418 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 419 | 420 | def p_sample_loop( 421 | self, 422 | model, 423 | shape, 424 | noise=None, 425 | clip_denoised=True, 426 | denoised_fn=None, 427 | cond_fn=None, 428 | model_kwargs=None, 429 | device=None, 430 | progress=False, 431 | ): 432 | """ 433 | Generate samples from the model. 434 | :param model: the model module. 435 | :param shape: the shape of the samples, (N, C, H, W). 436 | :param noise: if specified, the noise from the encoder to sample. 437 | Should be of the same shape as `shape`. 438 | :param clip_denoised: if True, clip x_start predictions to [-1, 1]. 439 | :param denoised_fn: if not None, a function which applies to the 440 | x_start prediction before it is used to sample. 441 | :param cond_fn: if not None, this is a gradient function that acts 442 | similarly to the model. 443 | :param model_kwargs: if not None, a dict of extra keyword arguments to 444 | pass to the model. This can be used for conditioning. 445 | :param device: if specified, the device to create the samples on. 446 | If not specified, use a model parameter's device. 447 | :param progress: if True, show a tqdm progress bar. 448 | :return: a non-differentiable batch of samples. 449 | """ 450 | final = None 451 | for sample in self.p_sample_loop_progressive( 452 | model, 453 | shape, 454 | noise=noise, 455 | clip_denoised=clip_denoised, 456 | denoised_fn=denoised_fn, 457 | cond_fn=cond_fn, 458 | model_kwargs=model_kwargs, 459 | device=device, 460 | progress=progress, 461 | ): 462 | final = sample 463 | return final["sample"] 464 | 465 | def p_sample_loop_progressive( 466 | self, 467 | model, 468 | shape, 469 | noise=None, 470 | clip_denoised=True, 471 | denoised_fn=None, 472 | cond_fn=None, 473 | model_kwargs=None, 474 | device=None, 475 | progress=False, 476 | ): 477 | """ 478 | Generate samples from the model and yield intermediate samples from 479 | each timestep of diffusion. 480 | Arguments are the same as p_sample_loop(). 481 | Returns a generator over dicts, where each dict is the return value of 482 | p_sample(). 
483 | """ 484 | if device is None: 485 | device = next(model.parameters()).device 486 | assert isinstance(shape, (tuple, list)) 487 | if noise is not None: 488 | img = noise 489 | else: 490 | img = th.randn(*shape, device=device) 491 | indices = list(range(self.num_timesteps))[::-1] 492 | 493 | if progress: 494 | # Lazy import so that we don't depend on tqdm. 495 | from tqdm.auto import tqdm 496 | 497 | indices = tqdm(indices) 498 | 499 | for i in indices: 500 | t = th.tensor([i] * shape[0], device=device) 501 | with th.no_grad(): 502 | out = self.p_sample( 503 | model, 504 | img, 505 | t, 506 | clip_denoised=clip_denoised, 507 | denoised_fn=denoised_fn, 508 | cond_fn=cond_fn, 509 | model_kwargs=model_kwargs, 510 | ) 511 | yield out 512 | img = out["sample"] 513 | 514 | def ddim_sample( 515 | self, 516 | model, 517 | x, 518 | t, 519 | clip_denoised=True, 520 | denoised_fn=None, 521 | cond_fn=None, 522 | model_kwargs=None, 523 | eta=0.0, 524 | ): 525 | """ 526 | Sample x_{t-1} from the model using DDIM. 527 | Same usage as p_sample(). 528 | """ 529 | out = self.p_mean_variance( 530 | model, 531 | x, 532 | t, 533 | clip_denoised=clip_denoised, 534 | denoised_fn=denoised_fn, 535 | model_kwargs=model_kwargs, 536 | ) 537 | if cond_fn is not None: 538 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 539 | 540 | # Usually our model outputs epsilon, but we re-derive it 541 | # in case we used x_start or x_prev prediction. 542 | eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) 543 | 544 | alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) 545 | alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) 546 | sigma = ( 547 | eta 548 | * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) 549 | * th.sqrt(1 - alpha_bar / alpha_bar_prev) 550 | ) 551 | # Equation 12. 552 | noise = th.randn_like(x) 553 | mean_pred = ( 554 | out["pred_xstart"] * th.sqrt(alpha_bar_prev) 555 | + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps 556 | ) 557 | nonzero_mask = ( 558 | (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) 559 | ) # no noise when t == 0 560 | sample = mean_pred + nonzero_mask * sigma * noise 561 | return {"sample": sample, "pred_xstart": out["pred_xstart"]} 562 | 563 | def ddim_reverse_sample( 564 | self, 565 | model, 566 | x, 567 | t, 568 | clip_denoised=True, 569 | denoised_fn=None, 570 | cond_fn=None, 571 | model_kwargs=None, 572 | eta=0.0, 573 | ): 574 | """ 575 | Sample x_{t+1} from the model using DDIM reverse ODE. 576 | """ 577 | assert eta == 0.0, "Reverse ODE only for deterministic path" 578 | out = self.p_mean_variance( 579 | model, 580 | x, 581 | t, 582 | clip_denoised=clip_denoised, 583 | denoised_fn=denoised_fn, 584 | model_kwargs=model_kwargs, 585 | ) 586 | if cond_fn is not None: 587 | out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) 588 | # Usually our model outputs epsilon, but we re-derive it 589 | # in case we used x_start or x_prev prediction. 590 | eps = ( 591 | _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x 592 | - out["pred_xstart"] 593 | ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) 594 | alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) 595 | 596 | # Equation 12. 
reversed 597 | mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps 598 | 599 | return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} 600 | 601 | def ddim_sample_loop( 602 | self, 603 | model, 604 | shape, 605 | noise=None, 606 | clip_denoised=True, 607 | denoised_fn=None, 608 | cond_fn=None, 609 | model_kwargs=None, 610 | device=None, 611 | progress=False, 612 | eta=0.0, 613 | ): 614 | """ 615 | Generate samples from the model using DDIM. 616 | Same usage as p_sample_loop(). 617 | """ 618 | final = None 619 | for sample in self.ddim_sample_loop_progressive( 620 | model, 621 | shape, 622 | noise=noise, 623 | clip_denoised=clip_denoised, 624 | denoised_fn=denoised_fn, 625 | cond_fn=cond_fn, 626 | model_kwargs=model_kwargs, 627 | device=device, 628 | progress=progress, 629 | eta=eta, 630 | ): 631 | final = sample 632 | return final["sample"] 633 | 634 | def ddim_sample_loop_progressive( 635 | self, 636 | model, 637 | shape, 638 | noise=None, 639 | clip_denoised=True, 640 | denoised_fn=None, 641 | cond_fn=None, 642 | model_kwargs=None, 643 | device=None, 644 | progress=False, 645 | eta=0.0, 646 | ): 647 | """ 648 | Use DDIM to sample from the model and yield intermediate samples from 649 | each timestep of DDIM. 650 | Same usage as p_sample_loop_progressive(). 651 | """ 652 | if device is None: 653 | device = next(model.parameters()).device 654 | assert isinstance(shape, (tuple, list)) 655 | if noise is not None: 656 | img = noise 657 | else: 658 | img = th.randn(*shape, device=device) 659 | indices = list(range(self.num_timesteps))[::-1] 660 | 661 | if progress: 662 | # Lazy import so that we don't depend on tqdm. 663 | from tqdm.auto import tqdm 664 | 665 | indices = tqdm(indices) 666 | 667 | for i in indices: 668 | t = th.tensor([i] * shape[0], device=device) 669 | with th.no_grad(): 670 | out = self.ddim_sample( 671 | model, 672 | img, 673 | t, 674 | clip_denoised=clip_denoised, 675 | denoised_fn=denoised_fn, 676 | cond_fn=cond_fn, 677 | model_kwargs=model_kwargs, 678 | eta=eta, 679 | ) 680 | yield out 681 | img = out["sample"] 682 | 683 | def _vb_terms_bpd( 684 | self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None 685 | ): 686 | """ 687 | Get a term for the variational lower-bound. 688 | The resulting units are bits (rather than nats, as one might expect). 689 | This allows for comparison to other papers. 690 | :return: a dict with the following keys: 691 | - 'output': a shape [N] tensor of NLLs or KLs. 692 | - 'pred_xstart': the x_0 predictions. 
693 | """ 694 | true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( 695 | x_start=x_start, x_t=x_t, t=t 696 | ) 697 | out = self.p_mean_variance( 698 | model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs 699 | ) 700 | kl = normal_kl( 701 | true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] 702 | ) 703 | kl = mean_flat(kl) / np.log(2.0) 704 | 705 | decoder_nll = -discretized_gaussian_log_likelihood( 706 | x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] 707 | ) 708 | assert decoder_nll.shape == x_start.shape 709 | decoder_nll = mean_flat(decoder_nll) / np.log(2.0) 710 | 711 | # At the first timestep return the decoder NLL, 712 | # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) 713 | output = th.where((t == 0), decoder_nll, kl) 714 | return {"output": output, "pred_xstart": out["pred_xstart"]} 715 | 716 | def training_losses(self, model, x_start, t, model_kwargs=None, noise=None): 717 | """ 718 | Compute training losses for a single timestep. 719 | :param model: the model to evaluate loss on. 720 | :param x_start: the [N x C x ...] tensor of inputs. 721 | :param t: a batch of timestep indices. 722 | :param model_kwargs: if not None, a dict of extra keyword arguments to 723 | pass to the model. This can be used for conditioning. 724 | :param noise: if specified, the specific Gaussian noise to try to remove. 725 | :return: a dict with the key "loss" containing a tensor of shape [N]. 726 | Some mean or variance settings may also have other keys. 727 | """ 728 | if model_kwargs is None: 729 | model_kwargs = {} 730 | if noise is None: 731 | noise = th.randn_like(x_start) 732 | x_t = self.q_sample(x_start, t, noise=noise) 733 | 734 | terms = {} 735 | 736 | 737 | mse_loss_weight = None 738 | alpha = _extract_into_tensor(self.sqrt_alphas_cumprod, t, t.shape) 739 | sigma = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, t.shape) 740 | snr = (alpha / sigma) ** 2 741 | 742 | velocity = (alpha[:, None, None, None] * x_t - x_start) / sigma[:, None, None, None] 743 | 744 | # get loss weight 745 | if self.model_mean_type is not ModelMeanType.START_X: 746 | mse_loss_weight = th.ones_like(t) 747 | k = 5.0 748 | # min{snr, k} 749 | mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0] / snr 750 | else: 751 | k = 5.0 752 | # min{snr, k} 753 | mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0] 754 | 755 | 756 | if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: 757 | terms["loss"] = self._vb_terms_bpd( 758 | model=model, 759 | x_start=x_start, 760 | x_t=x_t, 761 | t=t, 762 | clip_denoised=False, 763 | model_kwargs=model_kwargs, 764 | )["output"] 765 | if self.loss_type == LossType.RESCALED_KL: 766 | terms["loss"] *= self.num_timesteps 767 | elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: 768 | model_output = model(x_t, t, **model_kwargs) 769 | 770 | if self.model_var_type in [ 771 | ModelVarType.LEARNED, 772 | ModelVarType.LEARNED_RANGE, 773 | ]: 774 | B, C = x_t.shape[:2] 775 | assert model_output.shape == (B, C * 2, *x_t.shape[2:]) 776 | model_output, model_var_values = th.split(model_output, C, dim=1) 777 | # Learn the variance using the variational bound, but don't let 778 | # it affect our mean prediction. 
779 | frozen_out = th.cat([model_output.detach(), model_var_values], dim=1) 780 | terms["vb"] = self._vb_terms_bpd( 781 | model=lambda *args, r=frozen_out: r, 782 | x_start=x_start, 783 | x_t=x_t, 784 | t=t, 785 | clip_denoised=False, 786 | )["output"] 787 | if self.loss_type == LossType.RESCALED_MSE: 788 | # Divide by 1000 for equivalence with initial implementation. 789 | # Without a factor of 1/1000, the VB term hurts the MSE term. 790 | terms["vb"] *= self.num_timesteps / 1000.0 791 | 792 | target = { 793 | ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( 794 | x_start=x_start, x_t=x_t, t=t 795 | )[0], 796 | ModelMeanType.START_X: x_start, 797 | ModelMeanType.EPSILON: noise, 798 | ModelMeanType.VELOCITY: velocity, 799 | }[self.model_mean_type] 800 | assert model_output.shape == target.shape == x_start.shape 801 | terms["mse"] = mse_loss_weight * mean_flat((target - model_output) ** 2) 802 | if "vb" in terms: 803 | terms["loss"] = terms["mse"] + terms["vb"] 804 | else: 805 | terms["loss"] = terms["mse"] 806 | else: 807 | raise NotImplementedError(self.loss_type) 808 | 809 | return terms 810 | 811 | def _prior_bpd(self, x_start): 812 | """ 813 | Get the prior KL term for the variational lower-bound, measured in 814 | bits-per-dim. 815 | This term can't be optimized, as it only depends on the encoder. 816 | :param x_start: the [N x C x ...] tensor of inputs. 817 | :return: a batch of [N] KL values (in bits), one per batch element. 818 | """ 819 | batch_size = x_start.shape[0] 820 | t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) 821 | qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) 822 | kl_prior = normal_kl( 823 | mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 824 | ) 825 | return mean_flat(kl_prior) / np.log(2.0) 826 | 827 | def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): 828 | """ 829 | Compute the entire variational lower-bound, measured in bits-per-dim, 830 | as well as other related quantities. 831 | :param model: the model to evaluate loss on. 832 | :param x_start: the [N x C x ...] tensor of inputs. 833 | :param clip_denoised: if True, clip denoised samples. 834 | :param model_kwargs: if not None, a dict of extra keyword arguments to 835 | pass to the model. This can be used for conditioning. 836 | :return: a dict containing the following keys: 837 | - total_bpd: the total variational lower-bound, per batch element. 838 | - prior_bpd: the prior term in the lower-bound. 839 | - vb: an [N x T] tensor of terms in the lower-bound. 840 | - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. 841 | - mse: an [N x T] tensor of epsilon MSEs for each timestep. 
842 | """ 843 | device = x_start.device 844 | batch_size = x_start.shape[0] 845 | 846 | vb = [] 847 | xstart_mse = [] 848 | mse = [] 849 | for t in list(range(self.num_timesteps))[::-1]: 850 | t_batch = th.tensor([t] * batch_size, device=device) 851 | noise = th.randn_like(x_start) 852 | x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) 853 | # Calculate VLB term at the current timestep 854 | with th.no_grad(): 855 | out = self._vb_terms_bpd( 856 | model, 857 | x_start=x_start, 858 | x_t=x_t, 859 | t=t_batch, 860 | clip_denoised=clip_denoised, 861 | model_kwargs=model_kwargs, 862 | ) 863 | vb.append(out["output"]) 864 | xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) 865 | eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) 866 | mse.append(mean_flat((eps - noise) ** 2)) 867 | 868 | vb = th.stack(vb, dim=1) 869 | xstart_mse = th.stack(xstart_mse, dim=1) 870 | mse = th.stack(mse, dim=1) 871 | 872 | prior_bpd = self._prior_bpd(x_start) 873 | total_bpd = vb.sum(dim=1) + prior_bpd 874 | return { 875 | "total_bpd": total_bpd, 876 | "prior_bpd": prior_bpd, 877 | "vb": vb, 878 | "xstart_mse": xstart_mse, 879 | "mse": mse, 880 | } 881 | 882 | 883 | def _extract_into_tensor(arr, timesteps, broadcast_shape): 884 | """ 885 | Extract values from a 1-D numpy array for a batch of indices. 886 | :param arr: the 1-D numpy array. 887 | :param timesteps: a tensor of indices into the array to extract. 888 | :param broadcast_shape: a larger shape of K dimensions with the batch 889 | dimension equal to the length of timesteps. 890 | :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. 891 | """ 892 | res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() 893 | while len(res.shape) < len(broadcast_shape): 894 | res = res[..., None] 895 | return res + th.zeros(broadcast_shape, device=timesteps.device) 896 | -------------------------------------------------------------------------------- /masked_diffusion/image_datasets.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | 4 | from PIL import Image 5 | import blobfile as bf 6 | from mpi4py import MPI 7 | import numpy as np 8 | from torch.utils.data import DataLoader, Dataset 9 | 10 | 11 | def load_data( 12 | *, 13 | data_dir, 14 | batch_size, 15 | image_size, 16 | class_cond=False, 17 | deterministic=False, 18 | random_crop=False, 19 | random_flip=True, 20 | ): 21 | """ 22 | For a dataset, create a generator over (images, kwargs) pairs. 23 | 24 | Each images is an NCHW float tensor, and the kwargs dict contains zero or 25 | more keys, each of which map to a batched Tensor of their own. 26 | The kwargs dict can be used for class labels, in which case the key is "y" 27 | and the values are integer tensors of class labels. 28 | 29 | :param data_dir: a dataset directory. 30 | :param batch_size: the batch size of each returned pair. 31 | :param image_size: the size to which images are resized. 32 | :param class_cond: if True, include a "y" key in returned dicts for class 33 | label. If classes are not available and this is true, an 34 | exception will be raised. 35 | :param deterministic: if True, yield results in a deterministic order. 36 | :param random_crop: if True, randomly crop the images for augmentation. 37 | :param random_flip: if True, randomly flip the images for augmentation. 
38 | """ 39 | if not data_dir: 40 | raise ValueError("unspecified data directory") 41 | all_files = _list_image_files_recursively(data_dir) 42 | classes = None 43 | if class_cond: 44 | # Assume classes are the first part of the filename, 45 | # before an underscore. 46 | class_names = [bf.basename(path).split("_")[0] for path in all_files] 47 | sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))} 48 | classes = [sorted_classes[x] for x in class_names] 49 | dataset = ImageDataset( 50 | image_size, 51 | all_files, 52 | classes=classes, 53 | shard=MPI.COMM_WORLD.Get_rank(), 54 | num_shards=MPI.COMM_WORLD.Get_size(), 55 | random_crop=random_crop, 56 | random_flip=random_flip, 57 | ) 58 | if deterministic: 59 | loader = DataLoader( 60 | dataset, batch_size=batch_size, shuffle=False, num_workers=1, drop_last=True 61 | ) 62 | else: 63 | loader = DataLoader( 64 | dataset, batch_size=batch_size, shuffle=True, num_workers=1, drop_last=True 65 | ) 66 | while True: 67 | yield from loader 68 | 69 | 70 | def _list_image_files_recursively(data_dir): 71 | results = [] 72 | for entry in sorted(bf.listdir(data_dir)): 73 | full_path = bf.join(data_dir, entry) 74 | ext = entry.split(".")[-1] 75 | if "." in entry and ext.lower() in ["jpg", "jpeg", "png", "gif"]: 76 | results.append(full_path) 77 | elif bf.isdir(full_path): 78 | results.extend(_list_image_files_recursively(full_path)) 79 | return results 80 | 81 | 82 | class ImageDataset(Dataset): 83 | def __init__( 84 | self, 85 | resolution, 86 | image_paths, 87 | classes=None, 88 | shard=0, 89 | num_shards=1, 90 | random_crop=False, 91 | random_flip=True, 92 | ): 93 | super().__init__() 94 | self.resolution = resolution 95 | self.local_images = image_paths[shard:][::num_shards] 96 | self.local_classes = None if classes is None else classes[shard:][::num_shards] 97 | self.random_crop = random_crop 98 | self.random_flip = random_flip 99 | 100 | def __len__(self): 101 | return len(self.local_images) 102 | 103 | def __getitem__(self, idx): 104 | path = self.local_images[idx] 105 | with bf.BlobFile(path, "rb") as f: 106 | pil_image = Image.open(f) 107 | pil_image.load() 108 | pil_image = pil_image.convert("RGB") 109 | 110 | if self.random_crop: 111 | arr = random_crop_arr(pil_image, self.resolution) 112 | else: 113 | arr = center_crop_arr(pil_image, self.resolution) 114 | 115 | if self.random_flip and random.random() < 0.5: 116 | arr = arr[:, ::-1] 117 | 118 | arr = arr.astype(np.float32) / 127.5 - 1 119 | 120 | out_dict = {} 121 | if self.local_classes is not None: 122 | out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) 123 | return np.transpose(arr, [2, 0, 1]), out_dict 124 | 125 | 126 | def center_crop_arr(pil_image, image_size): 127 | # We are not on a new enough PIL to support the `reducing_gap` 128 | # argument, which uses BOX downsampling at powers of two first. 129 | # Thus, we do it by hand to improve downsample quality. 
130 | while min(*pil_image.size) >= 2 * image_size: 131 | pil_image = pil_image.resize( 132 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 133 | ) 134 | 135 | scale = image_size / min(*pil_image.size) 136 | pil_image = pil_image.resize( 137 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 138 | ) 139 | 140 | arr = np.array(pil_image) 141 | crop_y = (arr.shape[0] - image_size) // 2 142 | crop_x = (arr.shape[1] - image_size) // 2 143 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 144 | 145 | 146 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 147 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 148 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 149 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 150 | 151 | # We are not on a new enough PIL to support the `reducing_gap` 152 | # argument, which uses BOX downsampling at powers of two first. 153 | # Thus, we do it by hand to improve downsample quality. 154 | while min(*pil_image.size) >= 2 * smaller_dim_size: 155 | pil_image = pil_image.resize( 156 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 157 | ) 158 | 159 | scale = smaller_dim_size / min(*pil_image.size) 160 | pil_image = pil_image.resize( 161 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 162 | ) 163 | 164 | arr = np.array(pil_image) 165 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 166 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 167 | return arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size] 168 | -------------------------------------------------------------------------------- /masked_diffusion/logger.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logger copied from OpenAI baselines to avoid extra RL-based dependencies: 3 | https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py 4 | """ 5 | 6 | import os 7 | import sys 8 | import shutil 9 | import os.path as osp 10 | import json 11 | import time 12 | import datetime 13 | import tempfile 14 | import warnings 15 | from collections import defaultdict 16 | from contextlib import contextmanager 17 | 18 | DEBUG = 10 19 | INFO = 20 20 | WARN = 30 21 | ERROR = 40 22 | 23 | DISABLED = 50 24 | 25 | 26 | class KVWriter(object): 27 | def writekvs(self, kvs): 28 | raise NotImplementedError 29 | 30 | 31 | class SeqWriter(object): 32 | def writeseq(self, seq): 33 | raise NotImplementedError 34 | 35 | 36 | class HumanOutputFormat(KVWriter, SeqWriter): 37 | def __init__(self, filename_or_file): 38 | if isinstance(filename_or_file, str): 39 | self.file = open(filename_or_file, "wt") 40 | self.own_file = True 41 | else: 42 | assert hasattr(filename_or_file, "read"), ( 43 | "expected file or str, got %s" % filename_or_file 44 | ) 45 | self.file = filename_or_file 46 | self.own_file = False 47 | 48 | def writekvs(self, kvs): 49 | # Create strings for printing 50 | key2str = {} 51 | for (key, val) in sorted(kvs.items()): 52 | if hasattr(val, "__float__"): 53 | valstr = "%-8.3g" % val 54 | else: 55 | valstr = str(val) 56 | key2str[self._truncate(key)] = self._truncate(valstr) 57 | 58 | # Find max widths 59 | if len(key2str) == 0: 60 | print("WARNING: tried to write empty key-value dict") 61 | return 62 | else: 63 | keywidth = max(map(len, key2str.keys())) 64 | valwidth = max(map(len, key2str.values())) 65 | 66 | # Write 
out the data 67 | dashes = "-" * (keywidth + valwidth + 7) 68 | lines = [dashes] 69 | for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()): 70 | lines.append( 71 | "| %s%s | %s%s |" 72 | % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val))) 73 | ) 74 | lines.append(dashes) 75 | self.file.write("\n".join(lines) + "\n") 76 | 77 | # Flush the output to the file 78 | self.file.flush() 79 | 80 | def _truncate(self, s): 81 | maxlen = 30 82 | return s[: maxlen - 3] + "..." if len(s) > maxlen else s 83 | 84 | def writeseq(self, seq): 85 | seq = list(seq) 86 | for (i, elem) in enumerate(seq): 87 | self.file.write(elem) 88 | if i < len(seq) - 1: # add space unless this is the last one 89 | self.file.write(" ") 90 | self.file.write("\n") 91 | self.file.flush() 92 | 93 | def close(self): 94 | if self.own_file: 95 | self.file.close() 96 | 97 | 98 | class JSONOutputFormat(KVWriter): 99 | def __init__(self, filename): 100 | self.file = open(filename, "wt") 101 | 102 | def writekvs(self, kvs): 103 | for k, v in sorted(kvs.items()): 104 | if hasattr(v, "dtype"): 105 | kvs[k] = float(v) 106 | self.file.write(json.dumps(kvs) + "\n") 107 | self.file.flush() 108 | 109 | def close(self): 110 | self.file.close() 111 | 112 | 113 | class CSVOutputFormat(KVWriter): 114 | def __init__(self, filename): 115 | self.file = open(filename, "w+t") 116 | self.keys = [] 117 | self.sep = "," 118 | 119 | def writekvs(self, kvs): 120 | # Add our current row to the history 121 | extra_keys = list(kvs.keys() - self.keys) 122 | extra_keys.sort() 123 | if extra_keys: 124 | self.keys.extend(extra_keys) 125 | self.file.seek(0) 126 | lines = self.file.readlines() 127 | self.file.seek(0) 128 | for (i, k) in enumerate(self.keys): 129 | if i > 0: 130 | self.file.write(",") 131 | self.file.write(k) 132 | self.file.write("\n") 133 | for line in lines[1:]: 134 | self.file.write(line[:-1]) 135 | self.file.write(self.sep * len(extra_keys)) 136 | self.file.write("\n") 137 | for (i, k) in enumerate(self.keys): 138 | if i > 0: 139 | self.file.write(",") 140 | v = kvs.get(k) 141 | if v is not None: 142 | self.file.write(str(v)) 143 | self.file.write("\n") 144 | self.file.flush() 145 | 146 | def close(self): 147 | self.file.close() 148 | 149 | 150 | class TensorBoardOutputFormat(KVWriter): 151 | """ 152 | Dumps key/value pairs into TensorBoard's numeric format. 153 | """ 154 | 155 | def __init__(self, dir): 156 | os.makedirs(dir, exist_ok=True) 157 | self.dir = dir 158 | self.step = 1 159 | prefix = "events" 160 | path = osp.join(osp.abspath(dir), prefix) 161 | import tensorflow as tf 162 | from tensorflow.python import pywrap_tensorflow 163 | from tensorflow.core.util import event_pb2 164 | from tensorflow.python.util import compat 165 | 166 | self.tf = tf 167 | self.event_pb2 = event_pb2 168 | self.pywrap_tensorflow = pywrap_tensorflow 169 | self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path)) 170 | 171 | def writekvs(self, kvs): 172 | def summary_val(k, v): 173 | kwargs = {"tag": k, "simple_value": float(v)} 174 | return self.tf.Summary.Value(**kwargs) 175 | 176 | summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()]) 177 | event = self.event_pb2.Event(wall_time=time.time(), summary=summary) 178 | event.step = ( 179 | self.step 180 | ) # is there any reason why you'd want to specify the step? 
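        # TensorBoard treats `step` as the x-axis for scalar plots; it is
        # simply incremented once per writekvs() call below.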
181 | self.writer.WriteEvent(event) 182 | self.writer.Flush() 183 | self.step += 1 184 | 185 | def close(self): 186 | if self.writer: 187 | self.writer.Close() 188 | self.writer = None 189 | 190 | 191 | def make_output_format(format, ev_dir, log_suffix=""): 192 | os.makedirs(ev_dir, exist_ok=True) 193 | if format == "stdout": 194 | return HumanOutputFormat(sys.stdout) 195 | elif format == "log": 196 | return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix)) 197 | elif format == "json": 198 | return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix)) 199 | elif format == "csv": 200 | return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix)) 201 | elif format == "tensorboard": 202 | return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix)) 203 | else: 204 | raise ValueError("Unknown format specified: %s" % (format,)) 205 | 206 | 207 | # ================================================================ 208 | # API 209 | # ================================================================ 210 | 211 | 212 | def logkv(key, val): 213 | """ 214 | Log a value of some diagnostic 215 | Call this once for each diagnostic quantity, each iteration 216 | If called many times, last value will be used. 217 | """ 218 | get_current().logkv(key, val) 219 | 220 | 221 | def logkv_mean(key, val): 222 | """ 223 | The same as logkv(), but if called many times, values averaged. 224 | """ 225 | get_current().logkv_mean(key, val) 226 | 227 | 228 | def logkvs(d): 229 | """ 230 | Log a dictionary of key-value pairs 231 | """ 232 | for (k, v) in d.items(): 233 | logkv(k, v) 234 | 235 | 236 | def dumpkvs(): 237 | """ 238 | Write all of the diagnostics from the current iteration 239 | """ 240 | return get_current().dumpkvs() 241 | 242 | 243 | def getkvs(): 244 | return get_current().name2val 245 | 246 | 247 | def log(*args, level=INFO): 248 | """ 249 | Write the sequence of args, with no separators, to the console and output files (if you've configured an output file). 250 | """ 251 | get_current().log(*args, level=level) 252 | 253 | 254 | def debug(*args): 255 | log(*args, level=DEBUG) 256 | 257 | 258 | def info(*args): 259 | log(*args, level=INFO) 260 | 261 | 262 | def warn(*args): 263 | log(*args, level=WARN) 264 | 265 | 266 | def error(*args): 267 | log(*args, level=ERROR) 268 | 269 | 270 | def set_level(level): 271 | """ 272 | Set logging threshold on current logger. 273 | """ 274 | get_current().set_level(level) 275 | 276 | 277 | def set_comm(comm): 278 | get_current().set_comm(comm) 279 | 280 | 281 | def get_dir(): 282 | """ 283 | Get directory that log files are being written to. 
284 | will be None if there is no output directory (i.e., if you didn't call start) 285 | """ 286 | return get_current().get_dir() 287 | 288 | 289 | record_tabular = logkv 290 | dump_tabular = dumpkvs 291 | 292 | 293 | @contextmanager 294 | def profile_kv(scopename): 295 | logkey = "wait_" + scopename 296 | tstart = time.time() 297 | try: 298 | yield 299 | finally: 300 | get_current().name2val[logkey] += time.time() - tstart 301 | 302 | 303 | def profile(n): 304 | """ 305 | Usage: 306 | @profile("my_func") 307 | def my_func(): code 308 | """ 309 | 310 | def decorator_with_name(func): 311 | def func_wrapper(*args, **kwargs): 312 | with profile_kv(n): 313 | return func(*args, **kwargs) 314 | 315 | return func_wrapper 316 | 317 | return decorator_with_name 318 | 319 | 320 | # ================================================================ 321 | # Backend 322 | # ================================================================ 323 | 324 | 325 | def get_current(): 326 | if Logger.CURRENT is None: 327 | _configure_default_logger() 328 | 329 | return Logger.CURRENT 330 | 331 | 332 | class Logger(object): 333 | DEFAULT = None # A logger with no output files. (See right below class definition) 334 | # So that you can still log to the terminal without setting up any output files 335 | CURRENT = None # Current logger being used by the free functions above 336 | 337 | def __init__(self, dir, output_formats, comm=None): 338 | self.name2val = defaultdict(float) # values this iteration 339 | self.name2cnt = defaultdict(int) 340 | self.level = INFO 341 | self.dir = dir 342 | self.output_formats = output_formats 343 | self.comm = comm 344 | 345 | # Logging API, forwarded 346 | # ---------------------------------------- 347 | def logkv(self, key, val): 348 | self.name2val[key] = val 349 | 350 | def logkv_mean(self, key, val): 351 | oldval, cnt = self.name2val[key], self.name2cnt[key] 352 | self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1) 353 | self.name2cnt[key] = cnt + 1 354 | 355 | def dumpkvs(self): 356 | if self.comm is None: 357 | d = self.name2val 358 | else: 359 | d = mpi_weighted_mean( 360 | self.comm, 361 | { 362 | name: (val, self.name2cnt.get(name, 1)) 363 | for (name, val) in self.name2val.items() 364 | }, 365 | ) 366 | if self.comm.rank != 0: 367 | d["dummy"] = 1 # so we don't get a warning about empty dict 368 | out = d.copy() # Return the dict for unit testing purposes 369 | for fmt in self.output_formats: 370 | if isinstance(fmt, KVWriter): 371 | fmt.writekvs(d) 372 | self.name2val.clear() 373 | self.name2cnt.clear() 374 | return out 375 | 376 | def log(self, *args, level=INFO): 377 | if self.level <= level: 378 | self._do_log(args) 379 | 380 | # Configuration 381 | # ---------------------------------------- 382 | def set_level(self, level): 383 | self.level = level 384 | 385 | def set_comm(self, comm): 386 | self.comm = comm 387 | 388 | def get_dir(self): 389 | return self.dir 390 | 391 | def close(self): 392 | for fmt in self.output_formats: 393 | fmt.close() 394 | 395 | # Misc 396 | # ---------------------------------------- 397 | def _do_log(self, args): 398 | for fmt in self.output_formats: 399 | if isinstance(fmt, SeqWriter): 400 | fmt.writeseq(map(str, args)) 401 | 402 | 403 | def get_rank_without_mpi_import(): 404 | # check environment variables here instead of importing mpi4py 405 | # to avoid calling MPI_Init() when this module is imported 406 | for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]: 407 | if varname in os.environ: 408 | return int(os.environ[varname]) 
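    # No MPI launcher environment detected; treat this process as rank 0.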
409 | return 0 410 | 411 | 412 | def mpi_weighted_mean(comm, local_name2valcount): 413 | """ 414 | Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110 415 | Perform a weighted average over dicts that are each on a different node 416 | Input: local_name2valcount: dict mapping key -> (value, count) 417 | Returns: key -> mean 418 | """ 419 | all_name2valcount = comm.gather(local_name2valcount) 420 | if comm.rank == 0: 421 | name2sum = defaultdict(float) 422 | name2count = defaultdict(float) 423 | for n2vc in all_name2valcount: 424 | for (name, (val, count)) in n2vc.items(): 425 | try: 426 | val = float(val) 427 | except ValueError: 428 | if comm.rank == 0: 429 | warnings.warn( 430 | "WARNING: tried to compute mean on non-float {}={}".format( 431 | name, val 432 | ) 433 | ) 434 | else: 435 | name2sum[name] += val * count 436 | name2count[name] += count 437 | return {name: name2sum[name] / name2count[name] for name in name2sum} 438 | else: 439 | return {} 440 | 441 | 442 | def configure(dir=None, format_strs=None, comm=None, log_suffix=""): 443 | """ 444 | If comm is provided, average all numerical stats across that comm 445 | """ 446 | if dir is None: 447 | dir = os.getenv("OPENAI_LOGDIR") 448 | if dir is None: 449 | dir = osp.join( 450 | tempfile.gettempdir(), 451 | datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"), 452 | ) 453 | assert isinstance(dir, str) 454 | dir = os.path.expanduser(dir) 455 | os.makedirs(os.path.expanduser(dir), exist_ok=True) 456 | 457 | rank = get_rank_without_mpi_import() 458 | if rank > 0: 459 | log_suffix = log_suffix + "-rank%03i" % rank 460 | 461 | if format_strs is None: 462 | if rank == 0: 463 | format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",") 464 | else: 465 | format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",") 466 | format_strs = filter(None, format_strs) 467 | output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs] 468 | 469 | Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm) 470 | if output_formats: 471 | log("Logging to %s" % dir) 472 | 473 | 474 | def _configure_default_logger(): 475 | configure() 476 | Logger.DEFAULT = Logger.CURRENT 477 | 478 | 479 | def reset(): 480 | if Logger.CURRENT is not Logger.DEFAULT: 481 | Logger.CURRENT.close() 482 | Logger.CURRENT = Logger.DEFAULT 483 | log("Reset logger") 484 | 485 | 486 | @contextmanager 487 | def scoped_configure(dir=None, format_strs=None, comm=None): 488 | prevlogger = Logger.CURRENT 489 | configure(dir=dir, format_strs=format_strs, comm=comm) 490 | try: 491 | yield 492 | finally: 493 | Logger.CURRENT.close() 494 | Logger.CURRENT = prevlogger 495 | 496 | -------------------------------------------------------------------------------- /masked_diffusion/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import math 5 | from timm.models.vision_transformer import PatchEmbed, Mlp 6 | from timm.models.layers import trunc_normal_ 7 | import math 8 | 9 | def modulate(x, shift, scale): 10 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 11 | 12 | 13 | class Attention(nn.Module): 14 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., num_patches=None): 15 | super().__init__() 16 | self.num_heads = num_heads 17 | head_dim = dim // num_heads 18 | # NOTE scale factor was wrong 
in my original version, can set manually to be compat with prev weights 19 | self.scale = qk_scale or head_dim ** -0.5 20 | 21 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 22 | self.attn_drop = nn.Dropout(attn_drop) 23 | self.proj = nn.Linear(dim, dim) 24 | self.proj_drop = nn.Dropout(proj_drop) 25 | self.rel_pos_bias = RelativePositionBias( 26 | window_size=[int(num_patches**0.5), int(num_patches**0.5)], num_heads=num_heads) 27 | 28 | def get_masked_rel_bias(self, B, ids_keep): 29 | # get masked rel_pos_bias 30 | rel_pos_bias = self.rel_pos_bias() 31 | rel_pos_bias = rel_pos_bias.unsqueeze(dim=0).repeat(B, 1, 1, 1) 32 | 33 | rel_pos_bias_masked = torch.gather( 34 | rel_pos_bias, dim=2, index=ids_keep.unsqueeze(dim=1).unsqueeze(dim=-1).repeat(1, rel_pos_bias.shape[1], 1, rel_pos_bias.shape[-1])) 35 | rel_pos_bias_masked = torch.gather( 36 | rel_pos_bias_masked, dim=3, index=ids_keep.unsqueeze(dim=1).unsqueeze(dim=2).repeat(1, rel_pos_bias.shape[1], ids_keep.shape[1], 1)) 37 | return rel_pos_bias_masked 38 | 39 | def forward(self, x, ids_keep=None): 40 | B, N, C = x.shape 41 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // 42 | self.num_heads).permute(2, 0, 3, 1, 4) 43 | # make torchscript happy (cannot use tensor as tuple) 44 | q, k, v = qkv[0], qkv[1], qkv[2] 45 | 46 | attn = (q @ k.transpose(-2, -1)) * self.scale 47 | if ids_keep is not None: 48 | rp_bias = self.get_masked_rel_bias(B, ids_keep) 49 | else: 50 | rp_bias = self.rel_pos_bias() 51 | attn += rp_bias 52 | attn = attn.softmax(dim=-1) 53 | attn = self.attn_drop(attn) 54 | 55 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 56 | x = self.proj(x) 57 | x = self.proj_drop(x) 58 | return x 59 | 60 | 61 | class RelativePositionBias(nn.Module): 62 | # https://github.com/microsoft/unilm/blob/master/beit/modeling_finetune.py 63 | def __init__(self, window_size, num_heads): 64 | super().__init__() 65 | self.window_size = window_size 66 | self.num_relative_distance = ( 67 | 2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3 68 | self.relative_position_bias_table = nn.Parameter( 69 | torch.zeros(self.num_relative_distance, num_heads)) 70 | 71 | # get pair-wise relative position index for each token inside the window 72 | coords_h = torch.arange(window_size[0]) 73 | coords_w = torch.arange(window_size[1]) 74 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) 75 | coords_flatten = torch.flatten(coords, 1) 76 | relative_coords = coords_flatten[:, :, None] - \ 77 | coords_flatten[:, None, :] 78 | relative_coords = relative_coords.permute( 79 | 1, 2, 0).contiguous() 80 | relative_coords[:, :, 0] += window_size[0] - 1 81 | relative_coords[:, :, 1] += window_size[1] - 1 82 | relative_coords[:, :, 0] *= 2 * window_size[1] - 1 83 | relative_position_index = \ 84 | torch.zeros( 85 | size=(window_size[0] * window_size[1],) * 2, dtype=relative_coords.dtype) 86 | relative_position_index = relative_coords.sum(-1) 87 | 88 | self.register_buffer("relative_position_index", 89 | relative_position_index) 90 | 91 | trunc_normal_(self.relative_position_bias_table, std=.02) 92 | 93 | def forward(self): 94 | relative_position_bias = \ 95 | self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 96 | self.window_size[0] * self.window_size[1], 97 | self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH 98 | # nH, Wh*Ww, Wh*Ww 99 | return relative_position_bias.permute(2, 0, 1).contiguous() 100 | 101 | ################################################################################# 102 | # Embedding 
Layers for Timesteps and Class Labels # 103 | ################################################################################# 104 | 105 | 106 | class TimestepEmbedder(nn.Module): 107 | """ 108 | Embeds scalar timesteps into vector representations. 109 | """ 110 | 111 | def __init__(self, hidden_size, frequency_embedding_size=256): 112 | super().__init__() 113 | self.mlp = nn.Sequential( 114 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 115 | nn.SiLU(), 116 | nn.Linear(hidden_size, hidden_size, bias=True), 117 | ) 118 | self.frequency_embedding_size = frequency_embedding_size 119 | 120 | @staticmethod 121 | def timestep_embedding(t, dim, max_period=10000): 122 | """ 123 | Create sinusoidal timestep embeddings. 124 | :param t: a 1-D Tensor of N indices, one per batch element. 125 | These may be fractional. 126 | :param dim: the dimension of the output. 127 | :param max_period: controls the minimum frequency of the embeddings. 128 | :return: an (N, D) Tensor of positional embeddings. 129 | """ 130 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 131 | half = dim // 2 132 | freqs = torch.exp( 133 | -math.log(max_period) * torch.arange(start=0, 134 | end=half, dtype=torch.float32) / half 135 | ).to(device=t.device) 136 | args = t[:, None].float() * freqs[None] 137 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 138 | if dim % 2: 139 | embedding = torch.cat( 140 | [embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 141 | return embedding 142 | 143 | def forward(self, t): 144 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 145 | t_emb = self.mlp(t_freq) 146 | return t_emb 147 | 148 | 149 | class LabelEmbedder(nn.Module): 150 | """ 151 | Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance. 152 | """ 153 | 154 | def __init__(self, num_classes, hidden_size, dropout_prob): 155 | super().__init__() 156 | use_cfg_embedding = dropout_prob > 0 157 | self.embedding_table = nn.Embedding( 158 | num_classes + use_cfg_embedding, hidden_size) 159 | self.num_classes = num_classes 160 | self.dropout_prob = dropout_prob 161 | 162 | def token_drop(self, labels, force_drop_ids=None): 163 | """ 164 | Drops labels to enable classifier-free guidance. 165 | """ 166 | if force_drop_ids is None: 167 | drop_ids = torch.rand(labels.shape[0]) < self.dropout_prob 168 | else: 169 | drop_ids = force_drop_ids == 1 170 | 171 | labels = torch.where(drop_ids.to(labels.device), 172 | self.num_classes, labels) 173 | return labels 174 | 175 | def forward(self, labels, train, force_drop_ids=None): 176 | use_dropout = self.dropout_prob > 0 177 | if (train and use_dropout) or (force_drop_ids is not None): 178 | labels = self.token_drop(labels, force_drop_ids) 179 | embeddings = self.embedding_table(labels) 180 | return embeddings 181 | 182 | 183 | ################################################################################# 184 | # Core MDT Model # 185 | ################################################################################# 186 | 187 | class MDTBlock(nn.Module): 188 | """ 189 | A MDT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
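    The conditioning vector c is projected to six per-block modulation signals
    (shift/scale/gate for both the attention and MLP branches). The gates are
    zero-initialized in initialize_weights(), so each residual branch starts
    out contributing nothing.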
190 | """ 191 | 192 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, skip=False, **block_kwargs): 193 | super().__init__() 194 | self.norm1 = nn.LayerNorm( 195 | hidden_size, elementwise_affine=False, eps=1e-6) 196 | self.attn = Attention( 197 | hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs) 198 | self.norm2 = nn.LayerNorm( 199 | hidden_size, elementwise_affine=False, eps=1e-6) 200 | mlp_hidden_dim = int(hidden_size * mlp_ratio) 201 | def approx_gelu(): return nn.GELU(approximate="tanh") 202 | self.mlp = Mlp(in_features=hidden_size, 203 | hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) 204 | self.adaLN_modulation = nn.Sequential( 205 | nn.SiLU(), 206 | nn.Linear(hidden_size, 6 * hidden_size, bias=True) 207 | ) 208 | self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) if skip else None 209 | 210 | def forward(self, x, c, skip=None, ids_keep=None): 211 | if self.skip_linear is not None: 212 | x = self.skip_linear(torch.cat([x, skip], dim=-1)) 213 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation( 214 | c).chunk(6, dim=1) 215 | x = x + gate_msa.unsqueeze(1) * self.attn( 216 | modulate(self.norm1(x), shift_msa, scale_msa), ids_keep=ids_keep) 217 | x = x + \ 218 | gate_mlp.unsqueeze( 219 | 1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) 220 | return x 221 | 222 | 223 | class FinalLayer(nn.Module): 224 | """ 225 | The final layer of MDT. 226 | """ 227 | 228 | def __init__(self, hidden_size, patch_size, out_channels): 229 | super().__init__() 230 | self.norm_final = nn.LayerNorm( 231 | hidden_size, elementwise_affine=False, eps=1e-6) 232 | self.linear = nn.Linear( 233 | hidden_size, patch_size * patch_size * out_channels, bias=True) 234 | self.adaLN_modulation = nn.Sequential( 235 | nn.SiLU(), 236 | nn.Linear(hidden_size, 2 * hidden_size, bias=True) 237 | ) 238 | 239 | def forward(self, x, c): 240 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) 241 | x = modulate(self.norm_final(x), shift, scale) 242 | x = self.linear(x) 243 | return x 244 | 245 | 246 | class MDTv2(nn.Module): 247 | """ 248 | Masked Diffusion Transformer v2. 
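    The encoder (en_inblocks plus en_outblocks, joined by U-Net-style long
    skips) can run on a randomly masked subset of tokens during training; a
    side interpolater then fills in the masked positions before the decoder
    blocks (de_blocks) process the full sequence.

    Shape sketch (illustrative only; values assume the default config):
        model = MDTv2_S_2(mask_ratio=0.30)
        x = torch.randn(2, 4, 32, 32)     # latent input
        t = torch.randint(0, 1000, (2,))  # diffusion timesteps
        y = torch.randint(0, 1000, (2,))  # class labels
        out = model(x, t, y)              # (2, 8, 32, 32): eps and sigma channels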
249 |     """
250 | 
251 |     def __init__(
252 |         self,
253 |         input_size=32,
254 |         patch_size=2,
255 |         in_channels=4,
256 |         hidden_size=1152,
257 |         depth=28,
258 |         num_heads=16,
259 |         mlp_ratio=4.0,
260 |         class_dropout_prob=0.1,
261 |         num_classes=1000,
262 |         learn_sigma=True,
263 |         mask_ratio=None,
264 |         decode_layer=4,
265 |     ):
266 |         super().__init__()
267 |         self.learn_sigma = learn_sigma
268 |         self.in_channels = in_channels
269 |         self.out_channels = in_channels * 2 if learn_sigma else in_channels
270 |         self.patch_size = patch_size
271 |         self.num_heads = num_heads
272 |         decode_layer = int(decode_layer)
273 | 
274 |         self.x_embedder = PatchEmbed(
275 |             input_size, patch_size, in_channels, hidden_size, bias=True)
276 |         self.t_embedder = TimestepEmbedder(hidden_size)
277 |         self.y_embedder = LabelEmbedder(
278 |             num_classes, hidden_size, class_dropout_prob)
279 |         num_patches = self.x_embedder.num_patches
280 |         # Will use learnable sin-cos embedding:
281 |         self.pos_embed = nn.Parameter(torch.zeros(
282 |             1, num_patches, hidden_size), requires_grad=True)
283 | 
284 |         half_depth = (depth - decode_layer) // 2
285 |         self.half_depth = half_depth
286 | 
287 |         self.en_inblocks = nn.ModuleList([
288 |             MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches) for _ in range(half_depth)
289 |         ])
290 |         self.en_outblocks = nn.ModuleList([
291 |             MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches, skip=True) for _ in range(half_depth)
292 |         ])
293 |         self.de_blocks = nn.ModuleList([
294 |             MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches, skip=True) for i in range(decode_layer)
295 |         ])
296 |         self.sideblocks = nn.ModuleList([
297 |             MDTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, num_patches=num_patches) for _ in range(1)
298 |         ])
299 |         self.final_layer = FinalLayer(
300 |             hidden_size, patch_size, self.out_channels)
301 | 
302 |         self.decoder_pos_embed = nn.Parameter(torch.zeros(
303 |             1, num_patches, hidden_size), requires_grad=True)
304 |         if mask_ratio is not None:
305 |             self.mask_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
306 |             self.mask_ratio = float(mask_ratio)
307 |             self.decode_layer = int(decode_layer)
308 |         else:
309 |             self.mask_token = nn.Parameter(torch.zeros(
310 |                 1, 1, hidden_size), requires_grad=False)
311 |             self.mask_ratio = None
312 |             self.decode_layer = int(decode_layer)
313 |         print("mask ratio:", self.mask_ratio, "decode_layer:", self.decode_layer)
314 |         self.initialize_weights()
315 | 
316 |     def initialize_weights(self):
317 |         # Initialize transformer layers:
318 |         def _basic_init(module):
319 |             if isinstance(module, nn.Linear):
320 |                 torch.nn.init.xavier_uniform_(module.weight)
321 |                 if module.bias is not None:
322 |                     nn.init.constant_(module.bias, 0)
323 |         self.apply(_basic_init)
324 | 
325 |         # Initialize pos_embed by sin-cos embedding:
326 |         pos_embed = get_2d_sincos_pos_embed(
327 |             self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
328 |         self.pos_embed.data.copy_(
329 |             torch.from_numpy(pos_embed).float().unsqueeze(0))
330 | 
331 |         decoder_pos_embed = get_2d_sincos_pos_embed(
332 |             self.decoder_pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
333 |         self.decoder_pos_embed.data.copy_(
334 |             torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
335 | 
336 |         # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
337 |         w = self.x_embedder.proj.weight.data
338 |         nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
339 |         nn.init.constant_(self.x_embedder.proj.bias, 0)
340 | 
341 |         #
Initialize label embedding table: 342 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 343 | 344 | # Initialize timestep embedding MLP: 345 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 346 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 347 | 348 | for block in self.en_inblocks: 349 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 350 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 351 | 352 | for block in self.en_outblocks: 353 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 354 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 355 | 356 | for block in self.de_blocks: 357 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 358 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 359 | 360 | for block in self.sideblocks: 361 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 362 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 363 | 364 | # Zero-out output layers: 365 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 366 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 367 | nn.init.constant_(self.final_layer.linear.weight, 0) 368 | nn.init.constant_(self.final_layer.linear.bias, 0) 369 | 370 | if self.mask_ratio is not None: 371 | torch.nn.init.normal_(self.mask_token, std=.02) 372 | 373 | def unpatchify(self, x): 374 | """ 375 | x: (N, T, patch_size**2 * C) 376 | imgs: (N, H, W, C) 377 | """ 378 | c = self.out_channels 379 | p = self.x_embedder.patch_size[0] 380 | h = w = int(x.shape[1] ** 0.5) 381 | assert h * w == x.shape[1] 382 | 383 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 384 | x = torch.einsum('nhwpqc->nchpwq', x) 385 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 386 | return imgs 387 | 388 | def random_masking(self, x, mask_ratio): 389 | """ 390 | Perform per-sample random masking by per-sample shuffling. 391 | Per-sample shuffling is done by argsort random noise. 
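        The returned mask uses 0 for "keep" and 1 for "remove", listed in the
        original (unshuffled) token order; ids_restore undoes the shuffle and
        ids_keep indexes the kept tokens.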
392 | x: [N, L, D], sequence 393 | """ 394 | N, L, D = x.shape # batch, length, dim 395 | len_keep = int(L * (1 - mask_ratio)) 396 | 397 | noise = torch.rand(N, L, device=x.device) # noise in [0, 1] 398 | 399 | # sort noise for each sample 400 | # ascend: small is keep, large is remove 401 | ids_shuffle = torch.argsort(noise, dim=1) 402 | ids_restore = torch.argsort(ids_shuffle, dim=1) 403 | 404 | # keep the first subset 405 | ids_keep = ids_shuffle[:, :len_keep] 406 | x_masked = torch.gather( 407 | x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)) 408 | 409 | # generate the binary mask: 0 is keep, 1 is remove 410 | mask = torch.ones([N, L], device=x.device) 411 | mask[:, :len_keep] = 0 412 | # unshuffle to get the binary mask 413 | mask = torch.gather(mask, dim=1, index=ids_restore) 414 | 415 | return x_masked, mask, ids_restore, ids_keep 416 | 417 | def forward_side_interpolater(self, x, c, mask, ids_restore): 418 | # append mask tokens to sequence 419 | mask_tokens = self.mask_token.repeat( 420 | x.shape[0], ids_restore.shape[1] - x.shape[1], 1) 421 | x_ = torch.cat([x, mask_tokens], dim=1) 422 | x = torch.gather( 423 | x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])) # unshuffle 424 | 425 | # add pos embed 426 | x = x + self.decoder_pos_embed 427 | 428 | # pass to the basic block 429 | x_before = x 430 | for sideblock in self.sideblocks: 431 | x = sideblock(x, c, ids_keep=None) 432 | 433 | # masked shortcut 434 | mask = mask.unsqueeze(dim=-1) 435 | x = x*mask + (1-mask)*x_before 436 | 437 | return x 438 | 439 | def forward(self, x, t, y, enable_mask=False): 440 | """ 441 | Forward pass of MDT. 442 | x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) 443 | t: (N,) tensor of diffusion timesteps 444 | y: (N,) tensor of class labels 445 | enable_mask: Use mask latent modeling 446 | """ 447 | x = self.x_embedder( 448 | x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 449 | 450 | t = self.t_embedder(t) # (N, D) 451 | y = self.y_embedder(y, self.training) # (N, D) 452 | c = t + y # (N, D) 453 | 454 | 455 | input_skip = x 456 | 457 | masked_stage = False 458 | skips = [] 459 | # masking op for training 460 | if self.mask_ratio is not None and enable_mask: 461 | # masking: length -> length * mask_ratio 462 | rand_mask_ratio = torch.rand(1, device=x.device) # noise in [0, 1] 463 | rand_mask_ratio = rand_mask_ratio * 0.2 + self.mask_ratio # mask_ratio, mask_ratio + 0.2 464 | # print(rand_mask_ratio) 465 | x, mask, ids_restore, ids_keep = self.random_masking( 466 | x, rand_mask_ratio) 467 | masked_stage = True 468 | 469 | 470 | for block in self.en_inblocks: 471 | if masked_stage: 472 | x = block(x, c, ids_keep=ids_keep) 473 | else: 474 | x = block(x, c, ids_keep=None) 475 | skips.append(x) 476 | 477 | for block in self.en_outblocks: 478 | if masked_stage: 479 | x = block(x, c, skip=skips.pop(), ids_keep=ids_keep) 480 | else: 481 | x = block(x, c, skip=skips.pop(), ids_keep=None) 482 | 483 | if self.mask_ratio is not None and enable_mask: 484 | x = self.forward_side_interpolater(x, c, mask, ids_restore) 485 | masked_stage = False 486 | else: 487 | # add pos embed 488 | x = x + self.decoder_pos_embed 489 | 490 | for i in range(len(self.de_blocks)): 491 | block = self.de_blocks[i] 492 | this_skip = input_skip 493 | 494 | x = block(x, c, skip=this_skip, ids_keep=None) 495 | 496 | x = self.final_layer(x, c) 497 | x = self.unpatchify(x) # (N, out_channels, H, W) 498 | return x 499 | 500 | 501 | def forward_with_cfg(self, x, 
t, y, cfg_scale=None, diffusion_steps=1000, scale_pow=4.0): 502 | """ 503 | Forward pass of MDT, but also batches the unconditional forward pass for classifier-free guidance. 504 | """ 505 | # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb 506 | if cfg_scale is not None: 507 | half = x[: len(x) // 2] 508 | combined = torch.cat([half, half], dim=0) 509 | model_out = self.forward(combined, t, y) 510 | eps, rest = model_out[:, :3], model_out[:, 3:] 511 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 512 | 513 | scale_step = ( 514 | 1-torch.cos(((1-t/diffusion_steps)**scale_pow)*math.pi))*1/2 # power-cos scaling 515 | real_cfg_scale = (cfg_scale-1)*scale_step + 1 516 | real_cfg_scale = real_cfg_scale[: len(x) // 2].view(-1, 1, 1, 1) 517 | 518 | half_eps = uncond_eps + real_cfg_scale * (cond_eps - uncond_eps) 519 | eps = torch.cat([half_eps, half_eps], dim=0) 520 | return torch.cat([eps, rest], dim=1) 521 | else: 522 | model_out = self.forward(x, t, y) 523 | eps, rest = model_out[:, :3], model_out[:, 3:] 524 | return torch.cat([eps, rest], dim=1) 525 | 526 | 527 | ################################################################################# 528 | # Sine/Cosine Positional Embedding Functions # 529 | ################################################################################# 530 | # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py 531 | 532 | 533 | def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): 534 | """ 535 | grid_size: int of the grid height and width 536 | return: 537 | pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) 538 | """ 539 | grid_h = np.arange(grid_size, dtype=np.float32) 540 | grid_w = np.arange(grid_size, dtype=np.float32) 541 | grid = np.meshgrid(grid_w, grid_h) # here w goes first 542 | grid = np.stack(grid, axis=0) 543 | 544 | grid = grid.reshape([2, 1, grid_size, grid_size]) 545 | pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) 546 | if cls_token and extra_tokens > 0: 547 | pos_embed = np.concatenate( 548 | [np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) 549 | return pos_embed 550 | 551 | 552 | def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): 553 | assert embed_dim % 2 == 0 554 | 555 | # use half of dimensions to encode grid_h 556 | emb_h = get_1d_sincos_pos_embed_from_grid( 557 | embed_dim // 2, grid[0]) # (H*W, D/2) 558 | emb_w = get_1d_sincos_pos_embed_from_grid( 559 | embed_dim // 2, grid[1]) # (H*W, D/2) 560 | 561 | emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) 562 | return emb 563 | 564 | 565 | def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): 566 | """ 567 | embed_dim: output dimension for each position 568 | pos: a list of positions to be encoded: size (M,) 569 | out: (M, D) 570 | """ 571 | assert embed_dim % 2 == 0 572 | omega = np.arange(embed_dim // 2, dtype=np.float64) 573 | omega /= embed_dim / 2. 574 | omega = 1. 
/ 10000**omega # (D/2,) 575 | 576 | pos = pos.reshape(-1) # (M,) 577 | out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product 578 | 579 | emb_sin = np.sin(out) # (M, D/2) 580 | emb_cos = np.cos(out) # (M, D/2) 581 | 582 | emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) 583 | return emb 584 | 585 | 586 | ################################################################################# 587 | # MDTv2 Configs # 588 | ################################################################################# 589 | 590 | def MDTv2_XL_2(**kwargs): 591 | return MDTv2(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs) 592 | 593 | def MDTv2_L_2(**kwargs): 594 | return MDTv2(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs) 595 | 596 | def MDTv2_B_2(**kwargs): 597 | return MDTv2(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs) 598 | 599 | def MDTv2_S_2(**kwargs): 600 | return MDTv2(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs) 601 | 602 | -------------------------------------------------------------------------------- /masked_diffusion/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | import math 6 | 7 | import torch as th 8 | import torch.nn as nn 9 | 10 | 11 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 12 | class SiLU(nn.Module): 13 | def forward(self, x): 14 | return x * th.sigmoid(x) 15 | 16 | 17 | class GroupNorm32(nn.GroupNorm): 18 | def forward(self, x): 19 | return super().forward(x.float()).type(x.dtype) 20 | 21 | 22 | def conv_nd(dims, *args, **kwargs): 23 | """ 24 | Create a 1D, 2D, or 3D convolution module. 25 | """ 26 | if dims == 1: 27 | return nn.Conv1d(*args, **kwargs) 28 | elif dims == 2: 29 | return nn.Conv2d(*args, **kwargs) 30 | elif dims == 3: 31 | return nn.Conv3d(*args, **kwargs) 32 | raise ValueError(f"unsupported dimensions: {dims}") 33 | 34 | 35 | def linear(*args, **kwargs): 36 | """ 37 | Create a linear module. 38 | """ 39 | return nn.Linear(*args, **kwargs) 40 | 41 | 42 | def avg_pool_nd(dims, *args, **kwargs): 43 | """ 44 | Create a 1D, 2D, or 3D average pooling module. 45 | """ 46 | if dims == 1: 47 | return nn.AvgPool1d(*args, **kwargs) 48 | elif dims == 2: 49 | return nn.AvgPool2d(*args, **kwargs) 50 | elif dims == 3: 51 | return nn.AvgPool3d(*args, **kwargs) 52 | raise ValueError(f"unsupported dimensions: {dims}") 53 | 54 | 55 | def update_ema(target_params, source_params, rate=0.99): 56 | """ 57 | Update target parameters to be closer to those of source parameters using 58 | an exponential moving average. 59 | 60 | :param target_params: the target parameter sequence. 61 | :param source_params: the source parameter sequence. 62 | :param rate: the EMA rate (closer to 1 means slower). 63 | """ 64 | for targ, src in zip(target_params, source_params): 65 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 66 | 67 | 68 | def zero_module(module): 69 | """ 70 | Zero out the parameters of a module and return it. 71 | """ 72 | for p in module.parameters(): 73 | p.detach().zero_() 74 | return module 75 | 76 | 77 | def scale_module(module, scale): 78 | """ 79 | Scale the parameters of a module and return it. 80 | """ 81 | for p in module.parameters(): 82 | p.detach().mul_(scale) 83 | return module 84 | 85 | 86 | def mean_flat(tensor): 87 | """ 88 | Take the mean over all non-batch dimensions. 
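    For example, an (N, C, H, W) tensor becomes an (N,) tensor of per-sample means.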
89 | """ 90 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 91 | 92 | 93 | def normalization(channels): 94 | """ 95 | Make a standard normalization layer. 96 | 97 | :param channels: number of input channels. 98 | :return: an nn.Module for normalization. 99 | """ 100 | return GroupNorm32(32, channels) 101 | 102 | 103 | def timestep_embedding(timesteps, dim, max_period=10000): 104 | """ 105 | Create sinusoidal timestep embeddings. 106 | 107 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 108 | These may be fractional. 109 | :param dim: the dimension of the output. 110 | :param max_period: controls the minimum frequency of the embeddings. 111 | :return: an [N x dim] Tensor of positional embeddings. 112 | """ 113 | half = dim // 2 114 | freqs = th.exp( 115 | -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half 116 | ).to(device=timesteps.device) 117 | args = timesteps[:, None].float() * freqs[None] 118 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 119 | if dim % 2: 120 | embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1) 121 | return embedding 122 | 123 | 124 | def checkpoint(func, inputs, params, flag): 125 | """ 126 | Evaluate a function without caching intermediate activations, allowing for 127 | reduced memory at the expense of extra compute in the backward pass. 128 | 129 | :param func: the function to evaluate. 130 | :param inputs: the argument sequence to pass to `func`. 131 | :param params: a sequence of parameters `func` depends on but does not 132 | explicitly take as arguments. 133 | :param flag: if False, disable gradient checkpointing. 134 | """ 135 | if flag: 136 | args = tuple(inputs) + tuple(params) 137 | return CheckpointFunction.apply(func, len(inputs), *args) 138 | else: 139 | return func(*inputs) 140 | 141 | 142 | class CheckpointFunction(th.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, run_function, length, *args): 145 | ctx.run_function = run_function 146 | ctx.input_tensors = list(args[:length]) 147 | ctx.input_params = list(args[length:]) 148 | with th.no_grad(): 149 | output_tensors = ctx.run_function(*ctx.input_tensors) 150 | return output_tensors 151 | 152 | @staticmethod 153 | def backward(ctx, *output_grads): 154 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 155 | with th.enable_grad(): 156 | # Fixes a bug where the first op in run_function modifies the 157 | # Tensor storage in place, which is not allowed for detach()'d 158 | # Tensors. 159 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 160 | output_tensors = ctx.run_function(*shallow_copies) 161 | input_grads = th.autograd.grad( 162 | output_tensors, 163 | ctx.input_tensors + ctx.input_params, 164 | output_grads, 165 | allow_unused=True, 166 | ) 167 | del ctx.input_tensors 168 | del ctx.input_params 169 | del output_tensors 170 | return (None, None) + input_grads 171 | -------------------------------------------------------------------------------- /masked_diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 
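    Known names are "uniform" and "loss-second-moment".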
14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | elif name == "loss-second-moment": 18 | return LossSecondMomentResampler(diffusion) 19 | else: 20 | raise NotImplementedError(f"unknown schedule sampler: {name}") 21 | 22 | 23 | class ScheduleSampler(ABC): 24 | """ 25 | A distribution over timesteps in the diffusion process, intended to reduce 26 | variance of the objective. 27 | 28 | By default, samplers perform unbiased importance sampling, in which the 29 | objective's mean is unchanged. 30 | However, subclasses may override sample() to change how the resampled 31 | terms are reweighted, allowing for actual changes in the objective. 32 | """ 33 | 34 | @abstractmethod 35 | def weights(self): 36 | """ 37 | Get a numpy array of weights, one per diffusion step. 38 | 39 | The weights needn't be normalized, but must be positive. 40 | """ 41 | 42 | def sample(self, batch_size, device): 43 | """ 44 | Importance-sample timesteps for a batch. 45 | 46 | :param batch_size: the number of timesteps. 47 | :param device: the torch device to save to. 48 | :return: a tuple (timesteps, weights): 49 | - timesteps: a tensor of timestep indices. 50 | - weights: a tensor of weights to scale the resulting losses. 51 | """ 52 | w = self.weights() 53 | p = w / np.sum(w) 54 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 55 | indices = th.from_numpy(indices_np).long().to(device) 56 | weights_np = 1 / (len(p) * p[indices_np]) 57 | weights = th.from_numpy(weights_np).float().to(device) 58 | return indices, weights 59 | 60 | 61 | class UniformSampler(ScheduleSampler): 62 | def __init__(self, diffusion): 63 | self.diffusion = diffusion 64 | self._weights = np.ones([diffusion.num_timesteps]) 65 | 66 | def weights(self): 67 | return self._weights 68 | 69 | 70 | class LossAwareSampler(ScheduleSampler): 71 | def update_with_local_losses(self, local_ts, local_losses): 72 | """ 73 | Update the reweighting using losses from a model. 74 | 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | 80 | :param local_ts: an integer Tensor of timesteps. 81 | :param local_losses: a 1D Tensor of losses. 82 | """ 83 | batch_sizes = [ 84 | th.tensor([0], dtype=th.int32, device=local_ts.device) 85 | for _ in range(dist.get_world_size()) 86 | ] 87 | dist.all_gather( 88 | batch_sizes, 89 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 90 | ) 91 | 92 | # Pad all_gather batches to be the maximum batch size. 93 | batch_sizes = [x.item() for x in batch_sizes] 94 | max_bs = max(batch_sizes) 95 | 96 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 97 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 98 | dist.all_gather(timestep_batches, local_ts) 99 | dist.all_gather(loss_batches, local_losses) 100 | timesteps = [ 101 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 102 | ] 103 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 104 | self.update_with_all_losses(timesteps, losses) 105 | 106 | @abstractmethod 107 | def update_with_all_losses(self, ts, losses): 108 | """ 109 | Update the reweighting using losses from a model. 110 | 111 | Sub-classes should override this method to update the reweighting 112 | using losses from the model. 
113 | 
114 |         This method directly updates the reweighting without synchronizing
115 |         between workers. It is called by update_with_local_losses from all
116 |         ranks with identical arguments. Thus, it should have deterministic
117 |         behavior to maintain state across workers.
118 | 
119 |         :param ts: a list of int timesteps.
120 |         :param losses: a list of float losses, one per timestep.
121 |         """
122 | 
123 | 
124 | class LossSecondMomentResampler(LossAwareSampler):
125 |     def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126 |         self.diffusion = diffusion
127 |         self.history_per_term = history_per_term
128 |         self.uniform_prob = uniform_prob
129 |         self._loss_history = np.zeros(
130 |             [diffusion.num_timesteps, history_per_term], dtype=np.float64
131 |         )
132 |         self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24
133 | 
134 |     def weights(self):
135 |         if not self._warmed_up():
136 |             return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137 |         weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138 |         weights /= np.sum(weights)
139 |         weights *= 1 - self.uniform_prob
140 |         weights += self.uniform_prob / len(weights)
141 |         return weights
142 | 
143 |     def update_with_all_losses(self, ts, losses):
144 |         for t, loss in zip(ts, losses):
145 |             if self._loss_counts[t] == self.history_per_term:
146 |                 # Shift out the oldest loss term.
147 |                 self._loss_history[t, :-1] = self._loss_history[t, 1:]
148 |                 self._loss_history[t, -1] = loss
149 |             else:
150 |                 self._loss_history[t, self._loss_counts[t]] = loss
151 |                 self._loss_counts[t] += 1
152 | 
153 |     def _warmed_up(self):
154 |         return (self._loss_counts == self.history_per_term).all()
155 | 
--------------------------------------------------------------------------------
/masked_diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 | 
6 | import numpy as np
7 | import torch as th
8 | 
9 | from .gaussian_diffusion import GaussianDiffusion
10 | 
11 | 
12 | def space_timesteps(num_timesteps, section_counts):
13 |     """
14 |     Create a list of timesteps to use from an original diffusion process,
15 |     given the number of timesteps we want to take from equally-sized portions
16 |     of the original process.
17 |     For example, if there are 300 timesteps and the section counts are [10,15,20]
18 |     then the first 100 timesteps are strided to be 10 timesteps, the second 100
19 |     are strided to be 15 timesteps, and the final 100 are strided to be 20.
20 |     If the stride is a string starting with "ddim", then the fixed striding
21 |     from the DDIM paper is used, and only one section is allowed.
22 |     :param num_timesteps: the number of diffusion steps in the original
23 |                           process to divide up.
24 |     :param section_counts: either a list of numbers, or a string containing
25 |                            comma-separated numbers, indicating the step count
26 |                            per section. As a special case, use "ddimN" where N
27 |                            is a number of steps to use the striding from the
28 |                            DDIM paper.
29 |     :return: a set of diffusion steps from the original process to use.
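    Example (illustrative): with num_timesteps=1000, section_counts="ddim50"
    finds stride 20 and returns {0, 20, 40, ..., 980}.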
30 |     """
31 |     if isinstance(section_counts, str):
32 |         if section_counts.startswith("ddim"):
33 |             desired_count = int(section_counts[len("ddim") :])
34 |             for i in range(1, num_timesteps):
35 |                 if len(range(0, num_timesteps, i)) == desired_count:
36 |                     return set(range(0, num_timesteps, i))
37 |             raise ValueError(
38 |                 f"cannot create exactly {desired_count} steps with an integer stride"
39 |             )
40 |         section_counts = [int(x) for x in section_counts.split(",")]
41 |     size_per = num_timesteps // len(section_counts)
42 |     extra = num_timesteps % len(section_counts)
43 |     start_idx = 0
44 |     all_steps = []
45 |     for i, section_count in enumerate(section_counts):
46 |         size = size_per + (1 if i < extra else 0)
47 |         if size < section_count:
48 |             raise ValueError(
49 |                 f"cannot divide section of {size} steps into {section_count}"
50 |             )
51 |         if section_count <= 1:
52 |             frac_stride = 1
53 |         else:
54 |             frac_stride = (size - 1) / (section_count - 1)
55 |         cur_idx = 0.0
56 |         taken_steps = []
57 |         for _ in range(section_count):
58 |             taken_steps.append(start_idx + round(cur_idx))
59 |             cur_idx += frac_stride
60 |         all_steps += taken_steps
61 |         start_idx += size
62 |     return set(all_steps)
63 | 
64 | 
65 | class SpacedDiffusion(GaussianDiffusion):
66 |     """
67 |     A diffusion process which can skip steps in a base diffusion process.
68 |     :param use_timesteps: a collection (sequence or set) of timesteps from the
69 |                           original diffusion process to retain.
70 |     :param kwargs: the kwargs to create the base diffusion process.
71 |     """
72 | 
73 |     def __init__(self, use_timesteps, **kwargs):
74 |         self.use_timesteps = set(use_timesteps)
75 |         self.timestep_map = []
76 |         self.original_num_steps = len(kwargs["betas"])
77 | 
78 |         base_diffusion = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
79 |         last_alpha_cumprod = 1.0
80 |         new_betas = []
81 |         for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82 |             if i in self.use_timesteps:
83 |                 new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84 |                 last_alpha_cumprod = alpha_cumprod
85 |                 self.timestep_map.append(i)
86 |         kwargs["betas"] = np.array(new_betas)
87 |         super().__init__(**kwargs)
88 | 
89 |     def p_mean_variance(
90 |         self, model, *args, **kwargs
91 |     ):  # pylint: disable=signature-differs
92 |         return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93 | 
94 |     def training_losses(
95 |         self, model, *args, **kwargs
96 |     ):  # pylint: disable=signature-differs
97 |         return super().training_losses(self._wrap_model(model), *args, **kwargs)
98 | 
99 |     def condition_mean(self, cond_fn, *args, **kwargs):
100 |         return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101 | 
102 |     def condition_score(self, cond_fn, *args, **kwargs):
103 |         return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104 | 
105 |     def _wrap_model(self, model):
106 |         if isinstance(model, _WrappedModel):
107 |             return model
108 |         return _WrappedModel(
109 |             model, self.timestep_map, self.original_num_steps
110 |         )
111 | 
112 |     def _scale_timesteps(self, t):
113 |         # Scaling is done by the wrapped model.
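        # _WrappedModel maps the spaced indices back to original timestep
        # values via timestep_map, so no extra scaling is needed here.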
114 | return t 115 | 116 | class _WrappedModel: 117 | def __init__(self, model, timestep_map, original_num_steps): 118 | self.model = model 119 | self.timestep_map = timestep_map 120 | # self.rescale_timesteps = rescale_timesteps 121 | self.original_num_steps = original_num_steps 122 | 123 | def __call__(self, x, ts, **kwargs): 124 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 125 | new_ts = map_tensor[ts] 126 | # if self.rescale_timesteps: 127 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 128 | return self.model(x, new_ts, **kwargs) 129 | -------------------------------------------------------------------------------- /masked_diffusion/script_util.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import inspect 3 | 4 | from . import gaussian_diffusion as gd 5 | from .respace import SpacedDiffusion, space_timesteps 6 | 7 | NUM_CLASSES = 1000 8 | 9 | def add_dict_to_argparser(parser, default_dict): 10 | for k, v in default_dict.items(): 11 | v_type = type(v) 12 | if v is None: 13 | v_type = str 14 | elif isinstance(v, bool): 15 | v_type = str2bool 16 | parser.add_argument(f"--{k}", default=v, type=v_type) 17 | 18 | 19 | def args_to_dict(args, keys): 20 | return {k: getattr(args, k) for k in keys} 21 | 22 | 23 | def str2bool(v): 24 | """ 25 | https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse 26 | """ 27 | if isinstance(v, bool): 28 | return v 29 | if v.lower() in ("yes", "true", "t", "y", "1"): 30 | return True 31 | elif v.lower() in ("no", "false", "f", "n", "0"): 32 | return False 33 | else: 34 | raise argparse.ArgumentTypeError("boolean value expected") 35 | -------------------------------------------------------------------------------- /masked_diffusion/timestep_sampler.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | from abc import ABC, abstractmethod 7 | 8 | import numpy as np 9 | import torch as th 10 | import torch.distributed as dist 11 | 12 | 13 | def create_named_schedule_sampler(name, diffusion): 14 | """ 15 | Create a ScheduleSampler from a library of pre-defined samplers. 16 | :param name: the name of the sampler. 17 | :param diffusion: the diffusion object to sample for. 18 | """ 19 | if name == "uniform": 20 | return UniformSampler(diffusion) 21 | elif name == "loss-second-moment": 22 | return LossSecondMomentResampler(diffusion) 23 | else: 24 | raise NotImplementedError(f"unknown schedule sampler: {name}") 25 | 26 | 27 | class ScheduleSampler(ABC): 28 | """ 29 | A distribution over timesteps in the diffusion process, intended to reduce 30 | variance of the objective. 31 | By default, samplers perform unbiased importance sampling, in which the 32 | objective's mean is unchanged. 33 | However, subclasses may override sample() to change how the resampled 34 | terms are reweighted, allowing for actual changes in the objective. 35 | """ 36 | 37 | @abstractmethod 38 | def weights(self): 39 | """ 40 | Get a numpy array of weights, one per diffusion step. 41 | The weights needn't be normalized, but must be positive. 
42 | """ 43 | 44 | def sample(self, batch_size, device): 45 | """ 46 | Importance-sample timesteps for a batch. 47 | :param batch_size: the number of timesteps. 48 | :param device: the torch device to save to. 49 | :return: a tuple (timesteps, weights): 50 | - timesteps: a tensor of timestep indices. 51 | - weights: a tensor of weights to scale the resulting losses. 52 | """ 53 | w = self.weights() 54 | p = w / np.sum(w) 55 | indices_np = np.random.choice(len(p), size=(batch_size,), p=p) 56 | indices = th.from_numpy(indices_np).long().to(device) 57 | weights_np = 1 / (len(p) * p[indices_np]) 58 | weights = th.from_numpy(weights_np).float().to(device) 59 | return indices, weights 60 | 61 | 62 | class UniformSampler(ScheduleSampler): 63 | def __init__(self, diffusion): 64 | self.diffusion = diffusion 65 | self._weights = np.ones([diffusion.num_timesteps]) 66 | 67 | def weights(self): 68 | return self._weights 69 | 70 | 71 | class LossAwareSampler(ScheduleSampler): 72 | def update_with_local_losses(self, local_ts, local_losses): 73 | """ 74 | Update the reweighting using losses from a model. 75 | Call this method from each rank with a batch of timesteps and the 76 | corresponding losses for each of those timesteps. 77 | This method will perform synchronization to make sure all of the ranks 78 | maintain the exact same reweighting. 79 | :param local_ts: an integer Tensor of timesteps. 80 | :param local_losses: a 1D Tensor of losses. 81 | """ 82 | batch_sizes = [ 83 | th.tensor([0], dtype=th.int32, device=local_ts.device) 84 | for _ in range(dist.get_world_size()) 85 | ] 86 | dist.all_gather( 87 | batch_sizes, 88 | th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), 89 | ) 90 | 91 | # Pad all_gather batches to be the maximum batch size. 92 | batch_sizes = [x.item() for x in batch_sizes] 93 | max_bs = max(batch_sizes) 94 | 95 | timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] 96 | loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] 97 | dist.all_gather(timestep_batches, local_ts) 98 | dist.all_gather(loss_batches, local_losses) 99 | timesteps = [ 100 | x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] 101 | ] 102 | losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] 103 | self.update_with_all_losses(timesteps, losses) 104 | 105 | @abstractmethod 106 | def update_with_all_losses(self, ts, losses): 107 | """ 108 | Update the reweighting using losses from a model. 109 | Sub-classes should override this method to update the reweighting 110 | using losses from the model. 111 | This method directly updates the reweighting without synchronizing 112 | between workers. It is called by update_with_local_losses from all 113 | ranks with identical arguments. Thus, it should have deterministic 114 | behavior to maintain state across workers. 115 | :param ts: a list of int timesteps. 116 | :param losses: a list of float losses, one per timestep. 
117 | """ 118 | 119 | 120 | class LossSecondMomentResampler(LossAwareSampler): 121 | def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): 122 | self.diffusion = diffusion 123 | self.history_per_term = history_per_term 124 | self.uniform_prob = uniform_prob 125 | self._loss_history = np.zeros( 126 | [diffusion.num_timesteps, history_per_term], dtype=np.float64 127 | ) 128 | self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) 129 | 130 | def weights(self): 131 | if not self._warmed_up(): 132 | return np.ones([self.diffusion.num_timesteps], dtype=np.float64) 133 | weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) 134 | weights /= np.sum(weights) 135 | weights *= 1 - self.uniform_prob 136 | weights += self.uniform_prob / len(weights) 137 | return weights 138 | 139 | def update_with_all_losses(self, ts, losses): 140 | for t, loss in zip(ts, losses): 141 | if self._loss_counts[t] == self.history_per_term: 142 | # Shift out the oldest loss term. 143 | self._loss_history[t, :-1] = self._loss_history[t, 1:] 144 | self._loss_history[t, -1] = loss 145 | else: 146 | self._loss_history[t, self._loss_counts[t]] = loss 147 | self._loss_counts[t] += 1 148 | 149 | def _warmed_up(self): 150 | return (self._loss_counts == self.history_per_term).all() 151 | -------------------------------------------------------------------------------- /masked_diffusion/train_util.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import functools 3 | import os 4 | 5 | import blobfile as bf 6 | import torch as th 7 | import torch.distributed as dist 8 | from torch.nn.parallel.distributed import DistributedDataParallel as DDP 9 | from torch.optim import AdamW 10 | 11 | from . import dist_util, logger 12 | from .fp16_util import MixedPrecisionTrainer 13 | from .nn import update_ema 14 | from .resample import LossAwareSampler, UniformSampler 15 | from diffusers.models import AutoencoderKL 16 | from adan import Adan 17 | from torch.distributed.optim import ZeroRedundancyOptimizer 18 | # For ImageNet experiments, this was a good default value. 19 | # We found that the lg_loss_scale quickly climbed to 20 | # 20-21 within the first ~1K steps of training. 21 | INITIAL_LOG_LOSS_SCALE = 20.0 22 | 23 | 24 | class TrainLoop: 25 | def __init__( 26 | self, 27 | *, 28 | model, 29 | diffusion, 30 | data, 31 | batch_size, 32 | microbatch, 33 | lr, 34 | ema_rate, 35 | log_interval, 36 | save_interval, 37 | resume_checkpoint, 38 | use_fp16=False, 39 | fp16_scale_growth=1e-3, 40 | schedule_sampler=None, 41 | weight_decay=0.0, 42 | lr_anneal_steps=0, 43 | scale_factor=0.18215, # scale_factor follows DiT and stable diffusion. 
44 |         opt_type='adan',  # 'adan' (default) or 'adamw'
45 |         use_zero=True,
46 |     ):
47 |         self.model = model
48 |         self.diffusion = diffusion
49 |         self.data = data
50 |         self.batch_size = batch_size
51 |         self.microbatch = microbatch if microbatch > 0 else batch_size
52 |         self.lr = lr
53 |         self.ema_rate = (
54 |             [ema_rate]
55 |             if isinstance(ema_rate, float)
56 |             else [float(x) for x in ema_rate.split(",")]
57 |         )
58 |         self.log_interval = log_interval
59 |         self.save_interval = save_interval
60 |         self.resume_checkpoint = resume_checkpoint
61 |         self.use_fp16 = use_fp16
62 |         self.fp16_scale_growth = fp16_scale_growth
63 |         self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)
64 |         self.weight_decay = weight_decay
65 |         self.lr_anneal_steps = lr_anneal_steps
66 |         self.scale_factor = scale_factor
67 | 
68 |         self.step = 0
69 |         self.resume_step = 0
70 |         self.global_batch = self.batch_size * dist.get_world_size()
71 | 
72 |         self.sync_cuda = th.cuda.is_available()
73 | 
74 |         self._load_and_sync_parameters()
75 |         self.mp_trainer = MixedPrecisionTrainer(
76 |             model=self.model,
77 |             use_fp16=self.use_fp16,
78 |             fp16_scale_growth=fp16_scale_growth,
79 |         )
80 | 
81 |         if opt_type == 'adamw':
82 |             if use_zero:
83 |                 self.opt = ZeroRedundancyOptimizer(
84 |                     self.mp_trainer.master_params,
85 |                     optimizer_class=AdamW,  # was `Adam`, which is never imported; AdamW matches the non-ZeRO branch
86 |                     lr=self.lr,
87 |                     weight_decay=self.weight_decay
88 |                 )
89 |             else:
90 |                 self.opt = AdamW(
91 |                     self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay
92 |                 )
93 |         elif opt_type == 'adan':
94 |             if use_zero:
95 |                 self.opt = ZeroRedundancyOptimizer(
96 |                     self.mp_trainer.master_params,
97 |                     optimizer_class=Adan,
98 |                     lr=self.lr,
99 |                     weight_decay=self.weight_decay,
100 |                     max_grad_norm=1, fused=True
101 |                 )
102 |             else:
103 |                 self.opt = Adan(
104 |                     self.mp_trainer.master_params, lr=self.lr, weight_decay=self.weight_decay, max_grad_norm=1, fused=True)
105 | 
106 |         if self.resume_step:
107 |             self._load_optimizer_state()
108 |             # Model was resumed, either due to a restart or a checkpoint
109 |             # being specified at the command line.
110 |             self.ema_params = [
111 |                 self._load_ema_parameters(rate) for rate in self.ema_rate
112 |             ]
113 |         else:
114 |             self.ema_params = [
115 |                 copy.deepcopy(self.mp_trainer.master_params)
116 |                 for _ in range(len(self.ema_rate))
117 |             ]
118 | 
119 |         if th.cuda.is_available():
120 |             self.use_ddp = True
121 |             self.ddp_model = DDP(
122 |                 self.model,
123 |                 device_ids=[dist_util.dev()],
124 |                 output_device=dist_util.dev(),
125 |                 broadcast_buffers=False,
126 |                 bucket_cap_mb=128,
127 |                 find_unused_parameters=False,
128 |             )
129 |         else:
130 |             if dist.get_world_size() > 1:
131 |                 logger.warn(
132 |                     "Distributed training requires CUDA. "
133 |                     "Gradients will not be synchronized properly!"
134 |                 )
135 |             self.use_ddp = False
136 |             self.ddp_model = self.model
137 |         self.instantiate_first_stage()
138 | 
139 | 
140 |     def instantiate_first_stage(self):
141 |         model = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dist_util.dev())
142 |         model = th.compile(model)
143 |         self.first_stage_model = model.eval()
144 |         self.first_stage_model.train = False  # shadow .train() so the frozen VAE cannot be switched back to training mode
145 |         for param in self.first_stage_model.parameters():
146 |             param.requires_grad = False
147 | 
148 |     # https://github.com/huggingface/diffusers/blob/29b2c93c9005c87f8f04b1f0835babbcea736204/src/diffusers/models/autoencoder_kl.py
149 |     @th.no_grad()
150 |     def get_first_stage_encoding(self, x):
151 |         encoder_posterior = self.first_stage_model.encode(x, return_dict=True)[0]
152 | 
153 |         z = encoder_posterior.sample()
154 |         return z.to(dist_util.dev()) * self.scale_factor
155 | 
156 |     def _load_and_sync_parameters(self):
157 |         resume_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
158 | 
159 |         if resume_checkpoint:
160 |             self.resume_step = parse_resume_step_from_filename(resume_checkpoint)
161 |             if dist.get_rank() == 0:
162 |                 logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
163 |                 self.model.load_state_dict(
164 |                     dist_util.load_state_dict(
165 |                         resume_checkpoint, map_location=dist_util.dev()
166 |                     )
167 |                 )
168 | 
169 |         dist_util.sync_params(self.model.parameters())
170 | 
171 |     def _load_ema_parameters(self, rate):
172 |         ema_params = copy.deepcopy(self.mp_trainer.master_params)
173 | 
174 |         main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
175 |         ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)
176 |         if ema_checkpoint:
177 |             if dist.get_rank() == 0:
178 |                 logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
179 |                 state_dict = dist_util.load_state_dict(
180 |                     ema_checkpoint, map_location=dist_util.dev()
181 |                 )
182 |                 ema_params = self.mp_trainer.state_dict_to_master_params(state_dict)
183 | 
184 |         dist_util.sync_params(ema_params)
185 |         return ema_params
186 | 
187 |     def _load_optimizer_state(self):
188 |         main_checkpoint = find_resume_checkpoint() or self.resume_checkpoint
189 |         opt_checkpoint = bf.join(
190 |             bf.dirname(main_checkpoint), f"opt{self.resume_step:06}.pt"
191 |         )
192 |         if bf.exists(opt_checkpoint):
193 |             logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
194 |             state_dict = dist_util.load_state_dict(
195 |                 opt_checkpoint, map_location=dist_util.dev()
196 |             )
197 |             self.opt.load_state_dict(state_dict)
198 | 
199 |     def run_loop(self):
200 |         while (
201 |             not self.lr_anneal_steps
202 |             or self.step + self.resume_step < self.lr_anneal_steps
203 |         ):
204 |             batch, cond = next(self.data)
205 |             self.run_step(batch, cond)
206 |             if self.step % self.log_interval == 0:
207 |                 logger.dumpkvs()
208 |             if self.step % self.save_interval == 0:
209 |                 if isinstance(self.opt, ZeroRedundancyOptimizer): self.opt.consolidate_state_dict()  # ZeRO shards optimizer state; gather it before saving
210 |                 self.save()
211 |                 # Run for a finite amount of time in integration tests.
212 |                 if os.environ.get("DIFFUSION_TRAINING_TEST", "") and self.step > 0:
213 |                     return
214 |             self.step += 1
215 |         # Save the last checkpoint if it wasn't already saved.
216 |         if (self.step - 1) % self.save_interval != 0:
217 |             self.save()
218 | 
219 |     def run_step(self, batch, cond):
220 |         self.forward_backward(batch, cond)
221 |         took_step = self.mp_trainer.optimize(self.opt)
222 |         if took_step:
223 |             self._update_ema()
224 |         self._anneal_lr()
225 |         self.log_step()
226 | 
227 |     def forward_backward(self, batch, cond):
228 |         self.mp_trainer.zero_grad()
229 |         for i in range(0, batch.shape[0], self.microbatch):
230 | 
231 |             micro = batch[i : i + self.microbatch].to(dist_util.dev())
232 |             micro = self.get_first_stage_encoding(micro).detach()
233 |             micro_cond = {
234 |                 k: v[i : i + self.microbatch].to(dist_util.dev())
235 |                 for k, v in cond.items()
236 |             }
237 | 
238 |             last_batch = (i + self.microbatch) >= batch.shape[0]
239 |             t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())
240 | 
241 |             compute_losses = functools.partial(
242 |                 self.diffusion.training_losses,
243 |                 self.ddp_model,
244 |                 micro,
245 |                 t,
246 |                 model_kwargs=micro_cond,
247 |             )
248 |             micro_cond_mask = micro_cond.copy()  # second forward pass with latent masking enabled
249 |             micro_cond_mask['enable_mask'] = True
250 |             compute_losses_mask = functools.partial(
251 |                 self.diffusion.training_losses,
252 |                 self.ddp_model,
253 |                 micro,
254 |                 t,
255 |                 model_kwargs=micro_cond_mask,
256 |             )
257 | 
258 |             if last_batch or not self.use_ddp:
259 |                 losses = compute_losses()
260 |                 losses_mask = compute_losses_mask()
261 |             else:
262 |                 with self.ddp_model.no_sync():
263 |                     losses = compute_losses()
264 |                     losses_mask = compute_losses_mask()
265 | 
266 |             if isinstance(self.schedule_sampler, LossAwareSampler):
267 |                 self.schedule_sampler.update_with_local_losses(
268 |                     t, losses["loss"].detach() + losses_mask["loss"].detach()
269 |                 )
270 | 
271 |             loss = (losses["loss"] * weights).mean() + (losses_mask["loss"] * weights).mean()
272 |             log_loss_dict(
273 |                 self.diffusion, t, {k: v * weights for k, v in losses.items()}
274 |             )
275 |             log_loss_dict(
276 |                 self.diffusion, t, {'m_' + k: v * weights for k, v in losses_mask.items()}
277 |             )
278 |             self.mp_trainer.backward(loss)
279 | 
280 |     def _update_ema(self):
281 |         for rate, params in zip(self.ema_rate, self.ema_params):
282 |             update_ema(params, self.mp_trainer.master_params, rate=rate)
283 | 
284 |     def _anneal_lr(self):
285 |         if not self.lr_anneal_steps:
286 |             return
287 |         frac_done = (self.step + self.resume_step) / self.lr_anneal_steps
288 |         lr = self.lr * (1 - frac_done)
289 |         for param_group in self.opt.param_groups:
290 |             param_group["lr"] = lr
291 | 
292 |     def log_step(self):
293 |         logger.logkv("step", self.step + self.resume_step)
294 |         logger.logkv("samples", (self.step + self.resume_step + 1) * self.global_batch)
295 | 
296 |     def save(self):
297 |         def save_checkpoint(rate, params):
298 |             state_dict = self.mp_trainer.master_params_to_state_dict(params)
299 |             if dist.get_rank() == 0:
300 |                 logger.log(f"saving model {rate}...")
301 |                 if not rate:
302 |                     filename = f"model{(self.step+self.resume_step):06d}.pt"
303 |                 else:
304 |                     filename = f"ema_{rate}_{(self.step+self.resume_step):06d}.pt"
305 |                 with bf.BlobFile(bf.join(get_blob_logdir(), filename), "wb") as f:
306 |                     th.save(state_dict, f)
307 | 
308 |         save_checkpoint(0, self.mp_trainer.master_params)
309 |         for rate, params in zip(self.ema_rate, self.ema_params):
310 |             save_checkpoint(rate, params)
311 | 
312 |         if dist.get_rank() == 0:
313 |             with bf.BlobFile(
314 |                 bf.join(get_blob_logdir(), f"opt{(self.step+self.resume_step):06d}.pt"),
315 |                 "wb",
316 |             ) as f:
317 |                 th.save(self.opt.state_dict(), f)
318 | 
319 |         dist.barrier()
320 | 
321 | 
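322 | # Checkpoints are written via get_blob_logdir() (the logger directory, i.e.
323 | # OPENAI_LOGDIR) as model{step:06d}.pt, ema_{rate}_{step:06d}.pt and
324 | # opt{step:06d}.pt; the helpers below locate and parse them on resume.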
325 | def parse_resume_step_from_filename(filename):
326 |     """
327 |     Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
328 |     checkpoint's number of steps.
329 |     """
330 |     split = filename.split("model")
331 |     if len(split) < 2:
332 |         return 0
333 |     split1 = split[-1].split(".")[0]
334 |     try:
335 |         return int(split1)
336 |     except ValueError:
337 |         return 0
338 | 
339 | 
340 | def get_blob_logdir():
341 |     # You can change this to be a separate path to save checkpoints to
342 |     # a blobstore or some external drive.
343 |     return logger.get_dir()
344 | 
345 | 
346 | def find_resume_checkpoint():
347 |     # On your infrastructure, you may want to override this to automatically
348 |     # discover the latest checkpoint on your blob storage, etc.
349 |     return None
350 | 
351 | 
352 | def find_ema_checkpoint(main_checkpoint, step, rate):
353 |     if main_checkpoint is None:
354 |         return None
355 |     filename = f"ema_{rate}_{(step):06d}.pt"
356 |     path = bf.join(bf.dirname(main_checkpoint), filename)
357 |     if bf.exists(path):
358 |         return path
359 |     return None
360 | 
361 | 
362 | def log_loss_dict(diffusion, ts, losses):
363 |     for key, values in losses.items():
364 |         logger.logkv_mean(key, values.mean().item())
365 |         # Log the quantiles (four quartiles, in particular).
366 |         for sub_t, sub_loss in zip(ts.cpu().numpy(), values.detach().cpu().numpy()):
367 |             quartile = int(4 * sub_t / diffusion.num_timesteps)
368 |             logger.logkv_mean(f"{key}_q{quartile}", sub_loss)
369 | 
--------------------------------------------------------------------------------
/run.sh:
--------------------------------------------------------------------------------
1 | pip3 install torch==2.0 torchvision torchaudio
2 | pip install -e .
3 | python -m pip install git+https://github.com/sail-sg/Adan.git
4 | export OPENAI_LOGDIR=output_mdtv2_s2
5 | NUM_GPUS=8
6 | 
7 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 6 --model MDTv2_S_2"
8 | DIFFUSION_FLAGS="--diffusion_steps 1000"
9 | TRAIN_FLAGS="--batch_size 32 --lr 5e-4"
10 | DATA_PATH=/dataset/imagenet-raw/train
11 | 
12 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
--------------------------------------------------------------------------------
/run_ddp_master.sh:
--------------------------------------------------------------------------------
1 | pip3 install torch==2.0 torchvision torchaudio
2 | python -m pip install git+https://github.com/sail-sg/Adan.git
3 | pip install -e .
4 | export OPENAI_LOGDIR=output_mdtv2_xl2
5 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
6 | DIFFUSION_FLAGS="--diffusion_steps 1000"
7 | TRAIN_FLAGS="--batch_size 8"
8 | DATA_PATH=/dataset/imagenet-raw/train
9 | NUM_NODE=4
10 | GPU_PER_NODE=8
11 | 
12 | python -m torch.distributed.launch --master_addr=$(hostname) --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PER_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
--------------------------------------------------------------------------------
/run_ddp_worker.sh:
--------------------------------------------------------------------------------
1 | pip3 install torch==2.0 torchvision torchaudio
2 | python -m pip install git+https://github.com/sail-sg/Adan.git
3 | pip install -e .
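4 | # MASTER_ADDR, MASTER_PORT and RANK are assumed to be exported by the job
5 | # launcher / scheduler environment before this worker script runs.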
6 | export OPENAI_LOGDIR=output_mdtv2_xl2
7 | MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
8 | DIFFUSION_FLAGS="--diffusion_steps 1000"
9 | TRAIN_FLAGS="--batch_size 8"
10 | DATA_PATH=/dataset/imagenet-raw/train
11 | NUM_NODE=4
12 | GPU_PER_NODE=8
13 | 
14 | python -m torch.distributed.launch --master_addr=$MASTER_ADDR --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PER_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
--------------------------------------------------------------------------------
/run_sample.sh:
--------------------------------------------------------------------------------
1 | 
2 | pip3 install torch==2.0 torchvision torchaudio
3 | pip install -e .
4 | # pip uninstall -y timm
5 | # pip install mpi4py timm diffusers
6 | MODEL_PATH=output_mdt_xl2/mdt_xl2_v2_ckpt.pt
7 | export OPENAI_LOGDIR=output_mdt_xl2_eval
8 | NUM_GPUS=8
9 | 
10 | echo 'CFG Class-conditional sampling:'
11 | MODEL_FLAGS="--image_size 256 --model MDT_XL_2 --decode_layer 4"
12 | DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000 --cfg_cond True"
13 | echo $MODEL_FLAGS
14 | echo $DIFFUSION_FLAGS
15 | echo $MODEL_PATH
16 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_sample.py --model_path $MODEL_PATH $MODEL_FLAGS $DIFFUSION_FLAGS
17 | echo $MODEL_FLAGS
18 | echo $DIFFUSION_FLAGS
19 | echo $MODEL_PATH
20 | python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz
21 | 
22 | echo 'Class-conditional sampling:'
23 | MODEL_FLAGS="--image_size 256 --model MDT_XL_2 --decode_layer 4"
24 | DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000"
25 | echo $MODEL_FLAGS
26 | echo $DIFFUSION_FLAGS
27 | echo $MODEL_PATH
28 | python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_sample.py --model_path $MODEL_PATH $MODEL_FLAGS $DIFFUSION_FLAGS
29 | echo $MODEL_FLAGS
30 | echo $DIFFUSION_FLAGS
31 | echo $MODEL_PATH
32 | python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz
33 | 
34 | 
--------------------------------------------------------------------------------
/scripts/image_sample.py:
--------------------------------------------------------------------------------
1 | """
2 | Generate a large batch of image samples from a model and save them as a large
3 | numpy array. This can be used to produce samples for FID evaluation.
4 | """ 5 | 6 | import argparse 7 | import os 8 | 9 | import numpy as np 10 | import torch as th 11 | import torch.distributed as dist 12 | from masked_diffusion.script_util import ( 13 | NUM_CLASSES, 14 | add_dict_to_argparser, 15 | args_to_dict, 16 | ) 17 | 18 | from masked_diffusion import ( 19 | create_diffusion, 20 | model_and_diffusion_defaults, 21 | diffusion_defaults, 22 | dist_util, 23 | logger, 24 | ) 25 | 26 | import masked_diffusion.models as models_mdt 27 | from diffusers.models import AutoencoderKL 28 | 29 | def main(): 30 | th.backends.cuda.matmul.allow_tf32 = True 31 | args = create_argparser().parse_args() 32 | 33 | dist_util.setup_dist() 34 | logger.configure() 35 | 36 | logger.log("creating model and diffusion...") 37 | 38 | configs = args_to_dict(args, model_and_diffusion_defaults().keys()) 39 | print(configs) 40 | image_size = configs['image_size'] 41 | latent_size = image_size // 8 42 | model = models_mdt.__dict__[args.model](input_size=latent_size, decode_layer=args.decode_layer) 43 | msg = model.load_state_dict( 44 | dist_util.load_state_dict(args.model_path, map_location="cpu") 45 | ) 46 | print(msg) 47 | config_diffusion = args_to_dict(args, diffusion_defaults().keys()) 48 | config_diffusion['timestep_respacing']= str(args.num_sampling_steps) 49 | print(config_diffusion) 50 | diffusion = create_diffusion(**config_diffusion) 51 | model.to(dist_util.dev()) 52 | model = th.compile(model) 53 | if args.use_fp16: 54 | model.convert_to_fp16() 55 | model.eval() 56 | th.set_grad_enabled(False) 57 | 58 | vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-"+str(args.vae_decoder)).to(dist_util.dev()) 59 | 60 | logger.log("sampling...") 61 | all_images = [] 62 | all_labels = [] 63 | while len(all_images) * args.batch_size < args.num_samples: 64 | model_kwargs = {} 65 | if args.cfg_cond: 66 | classes = th.randint( 67 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 68 | ) 69 | z = th.randn(args.batch_size, 4, latent_size, latent_size, device=dist_util.dev()) 70 | # Setup classifier-free guidance: 71 | z = th.cat([z, z], 0) 72 | classes_null = th.tensor([NUM_CLASSES] * args.batch_size, device=dist_util.dev()) 73 | classes_all = th.cat([classes, classes_null], 0) 74 | model_kwargs["y"] = classes_all 75 | model_kwargs["cfg_scale"] = args.cfg_scale 76 | model_kwargs["diffusion_steps"] = config_diffusion['diffusion_steps'] 77 | model_kwargs["scale_pow"] = args.scale_pow 78 | else: 79 | z = th.randn(args.batch_size, 4, latent_size, latent_size, device=dist_util.dev()) 80 | classes = th.randint( 81 | low=0, high=NUM_CLASSES, size=(args.batch_size,), device=dist_util.dev() 82 | ) 83 | model_kwargs["y"] = classes 84 | 85 | 86 | sample_fn = ( 87 | diffusion.p_sample_loop if not args.use_ddim else diffusion.ddim_sample_loop 88 | ) 89 | sample = sample_fn( 90 | model.forward_with_cfg, 91 | z.shape, 92 | z, 93 | clip_denoised=args.clip_denoised, 94 | progress=True, 95 | model_kwargs=model_kwargs, 96 | device=dist_util.dev() 97 | ) 98 | if args.cfg_cond: 99 | sample, _ = sample.chunk(2, dim=0) # Remove null class samples 100 | # latent to image 101 | sample = vae.decode(sample / 0.18215).sample 102 | sample = ((sample + 1) * 127.5).clamp(0, 255).to(th.uint8) # clip in range -1,1 103 | sample = sample.permute(0, 2, 3, 1) 104 | sample = sample.contiguous() 105 | 106 | gathered_samples = [th.zeros_like(sample) for _ in range(dist.get_world_size())] 107 | dist.all_gather(gathered_samples, sample) # gather not supported with NCCL 108 | 
108 |         all_images.extend([sample.cpu().numpy() for sample in gathered_samples])
109 |         if args.class_cond:
110 |             gathered_labels = [
111 |                 th.zeros_like(classes) for _ in range(dist.get_world_size())
112 |             ]
113 |             dist.all_gather(gathered_labels, classes)
114 |             all_labels.extend([labels.cpu().numpy() for labels in gathered_labels])
115 |         # logger.log(f"created {len(all_images) * args.batch_size} samples")
116 | 
117 |     arr = np.concatenate(all_images, axis=0)
118 |     arr = arr[: args.num_samples]
119 |     if args.class_cond:
120 |         label_arr = np.concatenate(all_labels, axis=0)
121 |         label_arr = label_arr[: args.num_samples]
122 |     if dist.get_rank() == 0:
123 |         shape_str = "x".join([str(x) for x in arr.shape])
124 |         out_path = os.path.join(logger.get_dir(), f"samples_{shape_str}.npz")
125 |         logger.log(f"saving to {out_path}")
126 |         if args.class_cond:
127 |             np.savez(out_path, arr, label_arr)
128 |         else:
129 |             np.savez(out_path, arr)
130 | 
131 |     dist.barrier()
132 |     logger.log("sampling complete")
133 | 
134 | 
135 | def create_argparser():
136 |     defaults = dict(
137 |         num_sampling_steps=250,
138 |         clip_denoised=False,
139 |         num_samples=5000,
140 |         batch_size=16,
141 |         use_ddim=False,
142 |         model_path="",
143 |         model="MDT_S_2",
144 |         class_cond=True,
145 |         cfg_scale=3.8,
146 |         decode_layer=None,
147 |         cfg_cond=False,
148 |     )
149 |     defaults.update(model_and_diffusion_defaults())
150 |     parser = argparse.ArgumentParser()
151 |     parser.add_argument('--scale_pow', default=4, type=float)
152 |     parser.add_argument('--vae_decoder', type=str, default='ema')  # ema or mse
153 |     parser.add_argument('--world_size', default=1, type=int,
154 |                         help='number of distributed processes')
155 |     parser.add_argument('--local_rank', default=-1, type=int)
156 |     parser.add_argument('--local-rank', default=-1, type=int)  # newer torch.distributed.launch versions pass the hyphenated form
157 |     parser.add_argument('--dist_on_itp', action='store_true')
158 |     parser.add_argument('--dist_url', default='env://',
159 |                         help='url used to set up distributed training')
160 |     parser.add_argument(
161 |         "--rank", default=0, type=int, help="""rank for distributed training."""
162 |     )
163 |     add_dict_to_argparser(parser, defaults)
164 |     return parser
165 | 
166 | 
167 | if __name__ == "__main__":
168 |     main()
169 | 
--------------------------------------------------------------------------------
/scripts/image_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Train a diffusion model on images.
3 | """ 4 | 5 | import argparse 6 | 7 | from masked_diffusion import dist_util, logger 8 | from masked_diffusion.image_datasets import load_data 9 | from masked_diffusion.resample import create_named_schedule_sampler 10 | from masked_diffusion.script_util import ( 11 | args_to_dict, 12 | add_dict_to_argparser, 13 | ) 14 | from masked_diffusion.train_util import TrainLoop 15 | from masked_diffusion import create_diffusion, model_and_diffusion_defaults, diffusion_defaults 16 | import masked_diffusion.models as models_mdt 17 | 18 | def main(): 19 | args = create_argparser().parse_args() 20 | 21 | dist_util.setup_dist_multinode(args) 22 | logger.configure() 23 | 24 | logger.log("creating model and diffusion...") 25 | configs = args_to_dict(args, model_and_diffusion_defaults().keys()) 26 | print(configs) 27 | print(args) 28 | image_size = configs['image_size'] 29 | latent_size = image_size // 8 30 | model = models_mdt.__dict__[args.model](input_size=latent_size, mask_ratio=args.mask_ratio, decode_layer=args.decode_layer) 31 | print(model) 32 | diffusion = create_diffusion(**args_to_dict(args, diffusion_defaults().keys())) 33 | model.to(dist_util.dev()) 34 | 35 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion) 36 | 37 | logger.log("creating data loader...") 38 | data = load_data( 39 | data_dir=args.data_dir, 40 | batch_size=args.batch_size, 41 | image_size=args.image_size, 42 | class_cond=args.class_cond, 43 | ) 44 | 45 | logger.log("training...") 46 | TrainLoop( 47 | model=model, 48 | diffusion=diffusion, 49 | data=data, 50 | batch_size=args.batch_size, 51 | microbatch=args.microbatch, 52 | lr=args.lr, 53 | ema_rate=args.ema_rate, 54 | log_interval=args.log_interval, 55 | save_interval=args.save_interval, 56 | resume_checkpoint=args.resume_checkpoint, 57 | use_fp16=args.use_fp16, 58 | fp16_scale_growth=args.fp16_scale_growth, 59 | schedule_sampler=schedule_sampler, 60 | weight_decay=args.weight_decay, 61 | lr_anneal_steps=args.lr_anneal_steps, 62 | ).run_loop() 63 | 64 | 65 | def create_argparser(): 66 | defaults = dict( 67 | data_dir="", 68 | schedule_sampler="uniform", 69 | lr=3e-4, 70 | weight_decay=0.0, 71 | lr_anneal_steps=0, 72 | batch_size=1, 73 | microbatch=-1, # -1 disables microbatches 74 | ema_rate="0.9999", # comma-separated list of EMA values 75 | log_interval=500, 76 | save_interval=10000, 77 | resume_checkpoint="", 78 | use_fp16=False, 79 | fp16_scale_growth=1e-3, 80 | model="MDTv2_S_2", 81 | mask_ratio=None, 82 | decode_layer=4, 83 | ) 84 | defaults.update(model_and_diffusion_defaults()) 85 | parser = argparse.ArgumentParser() 86 | parser.add_argument('--world_size', default=1, type=int, 87 | help='number of distributed processes') 88 | parser.add_argument('--local_rank', default=-1, type=int) 89 | parser.add_argument('--local-rank', default=-1, type=int) 90 | parser.add_argument('--dist_on_itp', action='store_true') 91 | parser.add_argument('--dist_url', default='env://', 92 | help='url used to set up distributed training') 93 | parser.add_argument( 94 | "--rank", default=0, type=int, help="""rank for distrbuted training.""" 95 | ) 96 | add_dict_to_argparser(parser, defaults) 97 | return parser 98 | 99 | 100 | if __name__ == "__main__": 101 | main() 102 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name="masked-diffusion", 5 | py_modules=["masked_diffusion"], 6 | 
6 |     install_requires=["blobfile>=1.0.5", "torch", "tqdm"],
7 | )
8 | 
--------------------------------------------------------------------------------
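A minimal usage sketch of the schedule sampler defined in
masked_diffusion/timestep_sampler.py (an illustration, not a repository file;
the ToyDiffusion stand-in and the fake losses are assumptions for
demonstration, since the samplers only read `num_timesteps` from the
diffusion object):

import numpy as np

from masked_diffusion.timestep_sampler import LossSecondMomentResampler


class ToyDiffusion:
    # Stand-in for a real diffusion object; only `num_timesteps` is used.
    num_timesteps = 1000


sampler = LossSecondMomentResampler(ToyDiffusion())
rng = np.random.default_rng(0)

# weights() stays uniform until every timestep has history_per_term losses.
for _ in range(sampler.history_per_term):
    ts = list(range(ToyDiffusion.num_timesteps))
    losses = list(rng.random(ToyDiffusion.num_timesteps) * (1.0 + np.linspace(0.0, 1.0, ToyDiffusion.num_timesteps)))
    sampler.update_with_all_losses(ts, losses)

# After warm-up the weights follow sqrt(E[loss^2]) per timestep, mixed with a
# small uniform floor; sample() returns timesteps plus importance weights
# 1 / (T * p(t)) that keep the training objective unbiased.
t, w = sampler.sample(batch_size=8, device="cpu")
print(t.shape, w.shape)  # torch.Size([8]) torch.Size([8])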