├── .gitignore ├── LICENSE.md ├── README.md ├── demo ├── imagenet_example.py ├── safety_checker.py └── text_to_image_sdxl.py ├── dnnlib ├── __init__.py └── util.py ├── docs └── teaser.jpg ├── experiments ├── imagenet │ ├── README.md │ ├── imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch.sh │ ├── imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume.sh │ └── imagenet_lr2e-6_scratch.sh ├── sdv1.5 │ ├── README.md │ ├── laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch.sh │ └── laion6.25_sd_baseline_8node_guidance1.75_lr5e-7_seed10_dfake10_diffusion1000_gan1e-3_noode_resume_fixdata.sh └── sdxl │ ├── README.md │ ├── sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh │ ├── sdxl_cond999_8node_lr5e-5_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch_lora.sh │ ├── sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch.sh │ └── sdxl_lr1e-5_8node_ode_pretraining_10k_cond399.sh ├── figures └── README.md ├── fsdp_configs └── fsdp_1node_debug.yaml ├── main ├── coco_eval │ ├── captions_coco14_test.txt │ ├── cleanfid │ │ ├── __init__.py │ │ ├── clip_features.py │ │ ├── downloads_helper.py │ │ ├── features.py │ │ ├── fid.py │ │ ├── inception_pytorch.py │ │ ├── inception_torchscript.py │ │ ├── leaderboard.py │ │ ├── resize.py │ │ ├── utils.py │ │ └── wrappers.py │ └── coco_evaluator.py ├── data │ ├── create_imagenet_lmdb.py │ ├── create_lmdb_iterative.py │ └── lmdb_dataset.py ├── edm │ ├── edm_guidance.py │ ├── edm_network.py │ ├── edm_unified_model.py │ ├── test_folder_edm.py │ └── train_edm.py ├── sd_guidance.py ├── sd_image_dataset.py ├── sd_unet_forward.py ├── sd_unified_model.py ├── sdxl │ ├── create_sdxl_fsdp_configs.py │ ├── data_process.py │ ├── extract_lora_module.py │ ├── generate_noise_image_pairs_laion_sdxl.py │ ├── generate_vae_latents.py │ ├── sdxl_ode_dataset.py │ ├── sdxl_text_encoder.py │ └── test_folder_sdxl.py ├── test_folder_sd.py ├── train_sd.py ├── train_sd_ode.py └── utils.py ├── requirements.txt ├── scripts ├── download_hf_checkpoint.sh ├── download_imagenet.sh ├── download_sdv15.sh ├── download_sdxl.sh ├── download_sdxl_1step_ode_pairs_ckpt.sh └── download_sdxl_ode_pair_10k_lmdb.sh ├── setup.py ├── third_party └── edm │ ├── Dockerfile │ ├── LICENSE.txt │ ├── README.md │ ├── dataset_tool.py │ ├── dnnlib │ ├── __init__.py │ └── util.py │ ├── docs │ ├── afhqv2-64x64.png │ ├── cifar10-32x32.png │ ├── dataset-tool-help.txt │ ├── ffhq-64x64.png │ ├── fid-help.txt │ ├── generate-help.txt │ ├── imagenet-64x64.png │ ├── teaser-1280x640.jpg │ ├── teaser-1920x640.jpg │ ├── teaser-640x480.jpg │ └── train-help.txt │ ├── environment.yml │ ├── example.py │ ├── fid.py │ ├── generate.py │ ├── torch_utils │ ├── __init__.py │ ├── distributed.py │ ├── misc.py │ ├── persistence.py │ └── training_stats.py │ ├── train.py │ └── training │ ├── __init__.py │ ├── augment.py │ ├── dataset.py │ ├── loss.py │ ├── networks.py │ └── training_loop.py └── torch_utils ├── __init__.py ├── distributed.py ├── misc.py ├── persistence.py └── training_stats.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | model/ 3 | __pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | 7 | # C extensions 8 | *.so 9 | 10 | # Distribution / packaging 11 | .Python 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | 
wheels/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | cover/ 54 | 55 | # Translations 56 | *.mo 57 | *.pot 58 | 59 | # Django stuff: 60 | *.log 61 | local_settings.py 62 | db.sqlite3 63 | db.sqlite3-journal 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | .pybuilder/ 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | # For a library or package, you might want to ignore these files since the code is 88 | # intended to run in multiple environments; otherwise, check them in: 89 | # .python-version 90 | 91 | # pipenv 92 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 93 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 94 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 95 | # install all needed dependencies. 96 | #Pipfile.lock 97 | 98 | # poetry 99 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 100 | # This is especially recommended for binary packages to ensure reproducibility, and is more 101 | # commonly ignored for libraries. 102 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 103 | #poetry.lock 104 | 105 | # pdm 106 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 107 | #pdm.lock 108 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 109 | # in version control. 110 | # https://pdm.fming.dev/#use-with-ide 111 | .pdm.toml 112 | 113 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 114 | __pypackages__/ 115 | 116 | # Celery stuff 117 | celerybeat-schedule 118 | celerybeat.pid 119 | 120 | # SageMath parsed files 121 | *.sage.py 122 | 123 | # Environments 124 | .env 125 | .venv 126 | env/ 127 | venv/ 128 | ENV/ 129 | env.bak/ 130 | venv.bak/ 131 | 132 | # Spyder project settings 133 | .spyderproject 134 | .spyproject 135 | 136 | # Rope project settings 137 | .ropeproject 138 | 139 | # mkdocs documentation 140 | /site 141 | 142 | # mypy 143 | .mypy_cache/ 144 | .dmypy.json 145 | dmypy.json 146 | 147 | # Pyre type checker 148 | .pyre/ 149 | 150 | # pytype static type analyzer 151 | .pytype/ 152 | 153 | # Cython debug symbols 154 | cython_debug/ 155 | 156 | # PyCharm 157 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 158 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 159 | # and can be added to the global gitignore or merged into this file. For a more nuclear 160 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
161 | #.idea/ 162 | 163 | *.swp 164 | 165 | 166 | diffusers 167 | 168 | */*.png 169 | 170 | tests/data 171 | 172 | wandb 173 | 174 | .vscode 175 | *.zip 176 | *.npz 177 | *.tar.gz 178 | *.pyc 179 | *.pb 180 | *.pkl 181 | *.npz 182 | ignored_cache 183 | 184 | *.jpg 185 | *.png 186 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Improved Distribution Matching Distillation for Fast Image Synthesis [[Huggingface Repo](https://huggingface.co/tianweiy/DMD2)][[ComfyUI](https://gist.github.com/comfyanonymous/fcce4ced378f74f4c46026b134faf27a)][[Colab](https://colab.research.google.com/drive/1iGk7IW2WosophOVYpdW_KZGIfYpATOm7?usp=sharing)] 2 | 3 | Few-step Text-to-Image Generation. 4 | 5 | ![image/jpeg](docs/teaser.jpg) 6 | 7 | > [**Improved Distribution Matching Distillation for Fast Image Synthesis**](https://tianweiy.github.io/dmd2/), 8 | > Tianwei Yin, Michaël Gharbi, Taesung Park, Richard Zhang, Eli Shechtman, Frédo Durand, William T. Freeman 9 | > *NeurIPS 2024 ([arXiv 2405.14867](https://arxiv.org/abs/2405.14867))* 10 | 11 | ## Contact 12 | 13 | Feel free to contact us if you have any questions about the paper! 14 | 15 | Tianwei Yin [tianweiy@mit.edu](mailto:tianweiy@mit.edu) 16 | 17 | ## Abstract 18 | 19 | Recent approaches have shown promises distilling diffusion models into 20 | efficient one-step generators. Among them, Distribution Matching Distillation 21 | (DMD) produces one-step generators that match their teacher in distribution, 22 | without enforcing a one-to-one correspondence with the sampling trajectories of 23 | their teachers. However, to ensure stable training, DMD requires an additional 24 | regression loss computed using a large set of noise-image pairs generated by 25 | the teacher with many steps of a deterministic sampler. This is costly for 26 | large-scale text-to-image synthesis and limits the student's quality, tying it 27 | too closely to the teacher's original sampling paths. We introduce DMD2, a set 28 | of techniques that lift this limitation and improve DMD training. First, we 29 | eliminate the regression loss and the need for expensive dataset construction. 30 | We show that the resulting instability is due to the fake critic not estimating 31 | the distribution of generated samples accurately and propose a two time-scale 32 | update rule as a remedy. Second, we integrate a GAN loss into the distillation 33 | procedure, discriminating between generated samples and real images. This lets 34 | us train the student model on real data, mitigating the imperfect real score 35 | estimation from the teacher model, and enhancing quality. Lastly, we modify the 36 | training procedure to enable multi-step sampling. We identify and address the 37 | training-inference input mismatch problem in this setting, by simulating 38 | inference-time generator samples during training time. Taken together, our 39 | improvements set new benchmarks in one-step image generation, with FID scores 40 | of 1.28 on ImageNet-64x64 and 8.35 on zero-shot COCO 2014, surpassing the 41 | original teacher despite a 500X reduction in inference cost. Further, we show 42 | our approach can generate megapixel images by distilling SDXL, demonstrating 43 | exceptional visual quality among few-step methods. 
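To make the recipe above concrete, here is a schematic, toy sketch of the update schedule only (the fake critic and GAN discriminator are updated every iteration, the generator once every few iterations, with real images entering through the GAN term). Every module and loss in it is an illustrative stand-in rather than this repository's implementation; see [main/train_sd.py](main/train_sd.py) and [main/sd_guidance.py](main/sd_guidance.py) for the actual code.

```python
# Toy sketch of the DMD2 update schedule (two time-scale rule + GAN loss) described
# in the abstract. All modules/losses are random-tensor stand-ins, NOT this repo's
# implementation; see main/train_sd.py and main/sd_guidance.py for the real code.
import torch
import torch.nn as nn
import torch.nn.functional as F

generator = nn.Linear(16, 16)                         # stand-in for the few-step generator
fake_score = nn.Linear(16, 16)                        # stand-in for the trainable "fake" critic
real_score = nn.Linear(16, 16).requires_grad_(False)  # stand-in for the frozen teacher
discriminator = nn.Linear(16, 1)                      # stand-in for the GAN classifier head

opt_generator = torch.optim.AdamW(generator.parameters(), lr=1e-5)
opt_critic = torch.optim.AdamW(
    list(fake_score.parameters()) + list(discriminator.parameters()), lr=1e-5
)
dfake_gen_update_ratio = 5  # critic takes 5 steps for every generator step

for step in range(20):
    noise, real = torch.randn(4, 16), torch.randn(4, 16)
    fake = generator(noise)

    # Critic update (every iteration): fit the fake critic to the current generator
    # samples and train the discriminator to separate real data from fakes.
    critic_loss = F.mse_loss(fake_score(fake.detach()), fake.detach())
    critic_loss = critic_loss \
        + F.binary_cross_entropy_with_logits(discriminator(real), torch.ones(4, 1)) \
        + F.binary_cross_entropy_with_logits(discriminator(fake.detach()), torch.zeros(4, 1))
    opt_critic.zero_grad()
    critic_loss.backward()
    opt_critic.step()

    # Generator update (once every dfake_gen_update_ratio iterations): move samples
    # along the (real - fake) score direction and try to fool the discriminator.
    if step % dfake_gen_update_ratio == 0:
        fake = generator(noise)
        with torch.no_grad():
            target = fake + (real_score(fake) - fake_score(fake))
        loss_g = F.mse_loss(fake, target) \
            + 1e-3 * F.binary_cross_entropy_with_logits(discriminator(fake), torch.ones(4, 1))
        opt_generator.zero_grad()
        loss_g.backward()
        opt_generator.step()
```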
44 | 45 | ## Environment Setup 46 | 47 | ```.bash 48 | # In conda env 49 | conda create -n dmd2 python=3.8 -y 50 | conda activate dmd2 51 | 52 | pip install --upgrade anyio 53 | pip install -r requirements.txt 54 | python setup.py develop 55 | ``` 56 | 57 | ## Inference Example 58 | 59 | #### ImageNet 60 | 61 | ```.bash 62 | python -m demo.imagenet_example --checkpoint_path IMAGENET_CKPT_PATH 63 | ``` 64 | 65 | #### Text-to-Image 66 | 67 | ```.bash 68 | # Note: on the demo page, click ``Use Tiny VAE for faster decoding'' to enable much faster speed and lower memory consumption using a Tiny VAE from [madebyollin](https://huggingface.co/madebyollin/taesdxl) 69 | 70 | # 4 step (much higher quality than 1 step) 71 | python -m demo.text_to_image_sdxl --checkpoint_path SDXL_CKPT_PATH --precision float16 72 | 73 | # 1 step 74 | python -m demo.text_to_image_sdxl --num_step 1 --checkpoint_path SDXL_CKPT_PATH --precision float16 --conditioning_timestep 399 75 | ``` 76 | 77 | We can also use the standard diffuser pipeline: 78 | 79 | #### 4-step UNet generation 80 | 81 | ```python 82 | import torch 83 | from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler 84 | from huggingface_hub import hf_hub_download 85 | from safetensors.torch import load_file 86 | base_model_id = "stabilityai/stable-diffusion-xl-base-1.0" 87 | repo_name = "tianweiy/DMD2" 88 | ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin" 89 | 90 | # Load model. 91 | with torch.device("meta"): 92 | unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to(torch.float16) 93 | state_dict_path = hf_hub_download(repo_name, ckpt_name) 94 | unet.load_state_dict(torch.load(state_dict_path), assign=True) 95 | unet.to("cuda") 96 | 97 | pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda") 98 | pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) 99 | prompt="a photo of a cat" 100 | 101 | # LCMScheduler's default timesteps are different from the one we used for training 102 | image=pipe(prompt=prompt, num_inference_steps=4, guidance_scale=0, timesteps=[999, 749, 499, 249]).images[0] 103 | ``` 104 | 105 | #### 4-step LoRA generation 106 | 107 | ```python 108 | import torch 109 | from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler 110 | from huggingface_hub import hf_hub_download 111 | from safetensors.torch import load_file 112 | base_model_id = "stabilityai/stable-diffusion-xl-base-1.0" 113 | repo_name = "tianweiy/DMD2" 114 | ckpt_name = "dmd2_sdxl_4step_lora_fp16.safetensors" 115 | # Load model. 
116 | pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16, variant="fp16").to("cuda")
117 | pipe.load_lora_weights(hf_hub_download(repo_name, ckpt_name))
118 | pipe.fuse_lora(lora_scale=1.0) # we might want to make the scale smaller for community models
119 |
120 | pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
121 | prompt="a photo of a cat"
122 |
123 | # LCMScheduler's default timesteps are different from the one we used for training
124 | image=pipe(prompt=prompt, num_inference_steps=4, guidance_scale=0, timesteps=[999, 749, 499, 249]).images[0]
125 | ```
126 |
127 | #### 1-step UNet generation
128 |
129 | ```python
130 | import torch
131 | from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler
132 | from huggingface_hub import hf_hub_download
133 | from safetensors.torch import load_file
134 | base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
135 | repo_name = "tianweiy/DMD2"
136 | ckpt_name = "dmd2_sdxl_1step_unet_fp16.bin"
137 | # Load model.
138 | unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
139 | unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda"))
140 | pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
141 | pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
142 | prompt="a photo of a cat"
143 | image=pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0, timesteps=[399]).images[0]
144 | ```
145 |
146 | #### 4-step T2I Adapter
147 |
148 | ```python
149 | from diffusers import StableDiffusionXLAdapterPipeline, T2IAdapter, AutoencoderKL, UNet2DConditionModel, LCMScheduler
150 | from diffusers.utils import load_image, make_image_grid
151 | from controlnet_aux.canny import CannyDetector
152 | from huggingface_hub import hf_hub_download
153 | import torch
154 |
155 | # load adapter
156 | adapter = T2IAdapter.from_pretrained("TencentARC/t2i-adapter-canny-sdxl-1.0", torch_dtype=torch.float16, variant="fp16").to("cuda")
157 |
158 | vae=AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
159 |
160 | base_model_id = "stabilityai/stable-diffusion-xl-base-1.0"
161 | repo_name = "tianweiy/DMD2"
162 | ckpt_name = "dmd2_sdxl_4step_unet_fp16.bin"
163 | # Load model.
164 | unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
165 | unet.load_state_dict(torch.load(hf_hub_download(repo_name, ckpt_name), map_location="cuda"))
166 |
167 | pipe = StableDiffusionXLAdapterPipeline.from_pretrained(
168 | base_model_id, unet=unet, vae=vae, adapter=adapter, torch_dtype=torch.float16, variant="fp16",
169 | ).to("cuda")
170 | pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
171 | pipe.enable_xformers_memory_efficient_attention()
172 |
173 | canny_detector = CannyDetector()
174 |
175 | url = "https://huggingface.co/Adapter/t2iadapter/resolve/main/figs_SDXLV1.0/org_canny.jpg"
176 | image = load_image(url)
177 |
178 | # Detect the canny map in low resolution to avoid high-frequency details
179 | image = canny_detector(image, detect_resolution=384, image_resolution=1024)#.resize((1024, 1024))
180 |
181 | prompt = "Mystical fairy in real, magic, 4k picture, high quality"
182 |
183 | gen_images = pipe(
184 | prompt=prompt,
185 | image=image,
186 | num_inference_steps=4,
187 | guidance_scale=0,
188 | adapter_conditioning_scale=0.8,
189 | adapter_conditioning_factor=0.5,
190 | timesteps=[999, 749, 499, 249]
191 | ).images[0]
192 | gen_images.save('out_canny.png')
193 | ```
194 |
195 | Pretrained models can be found in [ImageNet](experiments/imagenet/README.md) and [SDXL](experiments/sdxl/README.md).
196 |
197 | ## Training and Evaluation
198 |
199 | ### ImageNet-64x64
200 |
201 | Please refer to [ImageNet-64x64](experiments/imagenet/README.md) for details.
202 |
203 | ### SDXL
204 |
205 | Please refer to [SDXL](experiments/sdxl/README.md) for details.
206 |
207 | ### SDv1.5
208 |
209 | Please refer to [SDv1.5](experiments/sdv1.5/README.md) for details.
210 |
211 | ## License
212 |
213 | Improved Distribution Matching Distillation is released under the [Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License](LICENSE.md).
214 |
215 | ## Known Issues
216 |
217 | - [ ] Current FSDP for SDXL training is really slow; help is greatly appreciated!
218 | - [ ] Current LoRA training is actually slower than full finetuning and takes the same amount of memory; help is greatly appreciated!
219 |
220 |
221 | ## Citation
222 |
223 | If you find DMD2 useful or relevant to your research, please kindly cite our papers:
224 |
225 | ```bib
226 | @inproceedings{yin2024improved,
227 | title={Improved Distribution Matching Distillation for Fast Image Synthesis},
228 | author={Yin, Tianwei and Gharbi, Micha{\"e}l and Park, Taesung and Zhang, Richard and Shechtman, Eli and Durand, Fredo and Freeman, William T},
229 | booktitle={NeurIPS},
230 | year={2024}
231 | }
232 |
233 | @inproceedings{yin2024onestep,
234 | title={One-step Diffusion with Distribution Matching Distillation},
235 | author={Yin, Tianwei and Gharbi, Micha{\"e}l and Zhang, Richard and Shechtman, Eli and Durand, Fr{\'e}do and Freeman, William T and Park, Taesung},
236 | booktitle={CVPR},
237 | year={2024}
238 | }
239 | ```
240 |
241 | ## Third-party Code
242 |
243 | We use code from [EDM](https://github.com/NVlabs/edm/tree/main) for the [dnnlib](dnnlib), [torch_utils](torch_utils) and [edm](third_party/edm) folders.
244 |
245 | ## Acknowledgments
246 |
247 | This work was done while Tianwei Yin was a full-time student at MIT. It was developed based on our reimplementation of the original DMD paper.
This work was supported by the National Science Foundation under Cooperative Agreement PHY-2019786 (The NSF AI Institute for Artificial Intelligence and Fundamental Interactions, http://iaifi.org/), by NSF Grant 2105819, by NSF CISE award 1955864, and by funding from Google, GIST, Amazon, and Quanta Computer. 248 | 249 | -------------------------------------------------------------------------------- /demo/imagenet_example.py: -------------------------------------------------------------------------------- 1 | from third_party.edm.training.networks import EDMPrecond 2 | from main.edm.edm_network import get_imagenet_edm_config 3 | from accelerate.utils import set_seed 4 | from PIL import Image 5 | from tqdm import tqdm 6 | import numpy as np 7 | import argparse 8 | import wandb 9 | import torch 10 | import scipy 11 | import time 12 | 13 | def get_imagenet_config(): 14 | base_config = { 15 | "img_resolution": 64, 16 | "img_channels": 3, 17 | "label_dim": 1000, 18 | "use_fp16": False, 19 | "sigma_min": 0, 20 | "sigma_max": float("inf"), 21 | "sigma_data": 0.5, 22 | "model_type": "DhariwalUNet" 23 | } 24 | base_config.update(get_imagenet_edm_config()) 25 | return base_config 26 | 27 | 28 | def create_generator(checkpoint_path, base_model=None): 29 | if base_model is None: 30 | base_config = get_imagenet_config() 31 | generator = EDMPrecond(**base_config) 32 | del generator.model.map_augment 33 | generator.model.map_augment = None 34 | else: 35 | generator = base_model 36 | 37 | while True: 38 | try: 39 | state_dict = torch.load(checkpoint_path, map_location="cpu") 40 | break 41 | except: 42 | print(f"fail to load checkpoint {checkpoint_path}") 43 | time.sleep(1) 44 | 45 | print(generator.load_state_dict(state_dict, strict=True)) 46 | 47 | return generator 48 | 49 | 50 | @torch.no_grad() 51 | def sample(accelerator, current_model, args, model_index): 52 | timesteps = torch.ones(args.eval_batch_size, device=accelerator.device, dtype=torch.long) 53 | current_model.eval() 54 | all_images = [] 55 | all_images_tensor = [] 56 | 57 | current_index = 0 58 | 59 | all_labels = torch.arange(0, args.total_eval_samples*2, 60 | device=accelerator.device, dtype=torch.long) % args.label_dim 61 | 62 | set_seed(args.seed+accelerator.process_index) 63 | 64 | while len(all_images_tensor) * args.eval_batch_size * accelerator.num_processes < args.total_eval_samples: 65 | noise = torch.randn(args.eval_batch_size, 3, 66 | args.resolution, args.resolution, device=accelerator.device 67 | ) 68 | 69 | random_labels = all_labels[current_index:current_index+args.eval_batch_size] 70 | one_hot_labels = torch.eye(args.label_dim, device=accelerator.device)[ 71 | random_labels 72 | ] 73 | 74 | current_index += args.eval_batch_size 75 | 76 | eval_images = current_model(noise * args.conditioning_sigma, timesteps * args.conditioning_sigma, one_hot_labels) 77 | eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1) 78 | eval_images = eval_images.contiguous() 79 | 80 | gathered_images = accelerator.gather(eval_images) 81 | 82 | all_images.append(gathered_images.cpu().numpy()) 83 | all_images_tensor.append(gathered_images.cpu()) 84 | 85 | if accelerator.is_main_process: 86 | print("all_images len ", len(torch.cat(all_images_tensor, dim=0))) 87 | 88 | all_images = np.concatenate(all_images, axis=0)[:args.total_eval_samples] 89 | all_images_tensor = torch.cat(all_images_tensor, dim=0)[:args.total_eval_samples] 90 | 91 | if accelerator.is_main_process: 92 | # Uncomment if you need to save the images 
93 | # np.savez(os.path.join(args.folder, f"eval_image_model_{model_index:06d}.npz"), all_images) 94 | # raise 95 | grid_size = int(args.test_visual_batch_size**(1/2)) 96 | eval_images_grid = all_images[:grid_size*grid_size].reshape(grid_size, grid_size, args.resolution, args.resolution, 3) 97 | eval_images_grid = np.swapaxes(eval_images_grid, 1, 2).reshape(grid_size*args.resolution, grid_size*args.resolution, 3) 98 | 99 | data_dict = { 100 | "generated_image_grid": wandb.Image(eval_images_grid) 101 | } 102 | 103 | data_dict['image_mean'] = all_images_tensor.float().mean().item() 104 | data_dict['image_std'] = all_images_tensor.float().std().item() 105 | 106 | wandb.log( 107 | data_dict, 108 | step=model_index 109 | ) 110 | 111 | accelerator.wait_for_everyone() 112 | return all_images_tensor 113 | 114 | @torch.no_grad() 115 | def calculate_inception_stats(all_images_tensor, evaluator, accelerator, evaluator_kwargs, feature_dim, max_batch_size): 116 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=accelerator.device) 117 | sigma = torch.ones([feature_dim, feature_dim], dtype=torch.float64, device=accelerator.device) 118 | num_batches = ((len(all_images_tensor) - 1) // (max_batch_size * accelerator.num_processes ) + 1) * accelerator.num_processes 119 | all_batches = torch.arange(len(all_images_tensor)).tensor_split(num_batches) 120 | rank_batches = all_batches[accelerator.process_index :: accelerator.num_processes] 121 | 122 | for i in tqdm(range(num_batches//accelerator.num_processes), unit='batch', disable=not accelerator.is_main_process): 123 | images = all_images_tensor[rank_batches[i]] 124 | features = evaluator(images.permute(0, 3, 1, 2).to(accelerator.device), **evaluator_kwargs).to(torch.float64) 125 | mu += features.sum(0) 126 | sigma += features.T @ features 127 | 128 | # Calculate grand totals. 
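    # At this point each process holds partial statistics for its shard of images:
    # `mu` has accumulated the per-feature sums and `sigma` the feature outer
    # products (features.T @ features). The reductions below aggregate them across
    # processes; dividing by the number of samples gives the feature mean, and
    # subtracting N * outer(mu, mu) before dividing by N - 1 turns the accumulated
    # second moment into the sample covariance compared against the reference
    # statistics when computing FID.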
129 | mu = accelerator.reduce(mu) 130 | sigma = accelerator.reduce(sigma) 131 | mu /= len(all_images_tensor) 132 | sigma -= mu.ger(mu) * len(all_images_tensor) 133 | sigma /= len(all_images_tensor) - 1 134 | return mu.cpu().numpy(), sigma.cpu().numpy() 135 | 136 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref): 137 | m = np.square(mu - mu_ref).sum() 138 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False) 139 | fid = m + np.trace(sigma + sigma_ref - s * 2) 140 | return float(np.real(fid)) 141 | 142 | @torch.no_grad() 143 | def main(): 144 | parser = argparse.ArgumentParser() 145 | parser.add_argument("--checkpoint_path", type=str, required=True) 146 | parser.add_argument("--eval_batch_size", type=int, default=64) 147 | parser.add_argument("--resolution", type=int, default=64) 148 | parser.add_argument("--label_dim", type=int, default=1000) 149 | parser.add_argument("--sigma_max", type=float, default=80.0) 150 | parser.add_argument("--sigma_min", type=float, default=0.002) 151 | parser.add_argument("--seed", type=int, default=10) 152 | parser.add_argument("--dataset_name", type=str, default="imagenet") 153 | parser.add_argument("--conditioning_sigma", type=float, default=80.0) 154 | 155 | args = parser.parse_args() 156 | 157 | device = torch.device('cuda') 158 | 159 | set_seed(args.seed) 160 | 161 | print(f"Loading models from {args.checkpoint_path}") 162 | 163 | generator = create_generator( 164 | args.checkpoint_path 165 | ).to(device) 166 | 167 | print(f"Generating {args.eval_batch_size} images") 168 | 169 | random_labels = torch.randint(0, args.label_dim, (args.eval_batch_size, ), device=device) 170 | one_hot_labels = torch.eye(args.label_dim, device=device)[ 171 | random_labels 172 | ] 173 | 174 | noise = torch.randn(args.eval_batch_size, 3, 175 | args.resolution, args.resolution, device=device 176 | ) 177 | 178 | eval_images = generator(noise * args.conditioning_sigma, torch.ones(args.eval_batch_size, device=device) * args.conditioning_sigma, one_hot_labels) 179 | eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1) 180 | 181 | print("Saving images") 182 | eval_images = eval_images.cpu().numpy() 183 | 184 | grid_size = int(args.eval_batch_size**(1/2)) 185 | eval_images_grid = eval_images[:grid_size*grid_size].reshape(grid_size, grid_size, args.resolution, args.resolution, 3) 186 | eval_images_grid = np.swapaxes(eval_images_grid, 1, 2).reshape(grid_size*args.resolution, grid_size*args.resolution, 3) 187 | 188 | Image.fromarray(eval_images_grid).save("imagenet_grid.jpg") 189 | 190 | if __name__ == "__main__": 191 | main() -------------------------------------------------------------------------------- /demo/safety_checker.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import numpy as np 16 | import torch 17 | import torch.nn as nn 18 | from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel 19 | 20 | 21 | def cosine_distance(image_embeds, text_embeds): 22 | normalized_image_embeds = nn.functional.normalize(image_embeds) 23 | normalized_text_embeds = nn.functional.normalize(text_embeds) 24 | return torch.mm(normalized_image_embeds, normalized_text_embeds.t()) 25 | 26 | 27 | class StableDiffusionSafetyChecker(PreTrainedModel): 28 | config_class = CLIPConfig 29 | 30 | _no_split_modules = ["CLIPEncoderLayer"] 31 | 32 | def __init__(self, config: CLIPConfig): 33 | super().__init__(config) 34 | 35 | self.vision_model = CLIPVisionModel(config.vision_config) 36 | self.visual_projection = nn.Linear( 37 | config.vision_config.hidden_size, config.projection_dim, bias=False 38 | ) 39 | 40 | self.concept_embeds = nn.Parameter( 41 | torch.ones(17, config.projection_dim), requires_grad=False 42 | ) 43 | self.special_care_embeds = nn.Parameter( 44 | torch.ones(3, config.projection_dim), requires_grad=False 45 | ) 46 | 47 | self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) 48 | self.special_care_embeds_weights = nn.Parameter( 49 | torch.ones(3), requires_grad=False 50 | ) 51 | 52 | @torch.no_grad() 53 | def forward(self, clip_input, images): 54 | pooled_output = self.vision_model(clip_input)[1] # pooled_output 55 | image_embeds = self.visual_projection(pooled_output) 56 | 57 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 58 | special_cos_dist = ( 59 | cosine_distance(image_embeds, self.special_care_embeds) 60 | .cpu() 61 | .float() 62 | .numpy() 63 | ) 64 | cos_dist = ( 65 | cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy() 66 | ) 67 | 68 | result = [] 69 | batch_size = image_embeds.shape[0] 70 | for i in range(batch_size): 71 | result_img = { 72 | "special_scores": {}, 73 | "special_care": [], 74 | "concept_scores": {}, 75 | "bad_concepts": [], 76 | } 77 | 78 | # increase this value to create a stronger `nfsw` filter 79 | # at the cost of increasing the possibility of filtering benign images 80 | adjustment = 0.0 81 | 82 | for concept_idx in range(len(special_cos_dist[0])): 83 | concept_cos = special_cos_dist[i][concept_idx] 84 | concept_threshold = self.special_care_embeds_weights[concept_idx].item() 85 | result_img["special_scores"][concept_idx] = round( 86 | concept_cos - concept_threshold + adjustment, 3 87 | ) 88 | if result_img["special_scores"][concept_idx] > 0: 89 | result_img["special_care"].append( 90 | {concept_idx, result_img["special_scores"][concept_idx]} 91 | ) 92 | adjustment = 0.01 93 | 94 | for concept_idx in range(len(cos_dist[0])): 95 | concept_cos = cos_dist[i][concept_idx] 96 | concept_threshold = self.concept_embeds_weights[concept_idx].item() 97 | result_img["concept_scores"][concept_idx] = round( 98 | concept_cos - concept_threshold + adjustment, 3 99 | ) 100 | if result_img["concept_scores"][concept_idx] > 0: 101 | result_img["bad_concepts"].append(concept_idx) 102 | 103 | result.append(result_img) 104 | 105 | has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] 106 | 107 | return has_nsfw_concepts 108 | 109 | @torch.no_grad() 110 | def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor): 111 | pooled_output = self.vision_model(clip_input)[1] # pooled_output 112 | image_embeds = self.visual_projection(pooled_output) 113 | 114 | special_cos_dist = 
cosine_distance(image_embeds, self.special_care_embeds) 115 | cos_dist = cosine_distance(image_embeds, self.concept_embeds) 116 | 117 | # increase this value to create a stronger `nsfw` filter 118 | # at the cost of increasing the possibility of filtering benign images 119 | adjustment = 0.0 120 | 121 | special_scores = ( 122 | special_cos_dist - self.special_care_embeds_weights + adjustment 123 | ) 124 | # special_scores = special_scores.round(decimals=3) 125 | special_care = torch.any(special_scores > 0, dim=1) 126 | special_adjustment = special_care * 0.01 127 | special_adjustment = special_adjustment.unsqueeze(1).expand( 128 | -1, cos_dist.shape[1] 129 | ) 130 | 131 | concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment 132 | # concept_scores = concept_scores.round(decimals=3) 133 | has_nsfw_concepts = torch.any(concept_scores > 0, dim=1) 134 | 135 | images[has_nsfw_concepts] = 0.0 # black image 136 | 137 | return images, has_nsfw_concepts -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /docs/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/DMD2/8d8fa55633d47cfb81bbc7a892e7248f9518763f/docs/teaser.jpg -------------------------------------------------------------------------------- /experiments/imagenet/README.md: -------------------------------------------------------------------------------- 1 | ## Getting Started with DMD2 on ImageNet-64x64 2 | 3 | We trained ImageNet using mixed-precision in BF16 format, adapting the EDM's code to accommodate BF16 training (see [LINK](../../third_party/edm/training/networks.py)). We noticed that the training diverges if we use FP16. FP16 might work with some fancy loss scaling; help is greatly appreciated. 
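The snippet below is a generic PyTorch sketch of the two options discussed above (it is not this repository's training loop): BF16 autocast, which is what we use, and FP16 autocast paired with a `GradScaler`, which is the kind of loss scaling that may be needed to keep FP16 training stable.

```python
# Generic mixed-precision sketch (placeholder model/data, not this repo's trainer).
import torch
import torch.nn as nn

model = nn.Linear(64, 64).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-6)
x = torch.randn(16, 64, device="cuda")

# BF16 autocast: no loss scaling needed; parameters and gradients stay in FP32.
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    loss = model(x).square().mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()

# FP16 autocast: use dynamic loss scaling (GradScaler) to avoid divergence.
scaler = torch.cuda.amp.GradScaler()
with torch.autocast(device_type="cuda", dtype=torch.float16):
    loss = model(x).square().mean()
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
```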
4 | 5 | ### Model Zoo 6 | 7 | | Config Name | FID | Link | Iters | Hours | 8 | | ----------- | --- | ---- | ----- | ----- | 9 | | [imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch](./imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch.sh) | 1.51 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/imagenet/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch_fid1.51_checkpoint_model_193500/) | 200k | 53 | 10 | | [imagenet_lr2e-6_scratch](./imagenet_lr2e-6_scratch.sh) | 2.61 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/imagenet/imagenet_lr2e-6_scratch_fid2.61_checkpoint_model_405500/) | 410k | 70 | 11 | | [imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume*](./imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume.sh) | 1.28 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/imagenet/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume_fid1.28_checkpoint_model_548000/) | 140K | 38 | 12 | 13 | *The final model was resumed from the best checkpoint of the **imagenet_lr2e-6_scratch** run and trained for an additional 140,000 iterations. 14 | 15 | For inference with our models, you only need to download the pytorch_model.bin file from the provided link. For fine-tuning, you will need to download the entire folder. 16 | You can use the following script for that: 17 | 18 | ```bash 19 | export CHECKPOINT_NAME="imagenet/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch_fid1.51_checkpoint_model_193500" # note that the imagenet/ is necessary 20 | export OUTPUT_PATH="path/to/your/output/folder" 21 | 22 | bash scripts/download_hf_checkpoint.sh $CHECKPOINT_NAME $OUTPUT_PATH 23 | ``` 24 | 25 | ### Download Base Diffusion Models and Training Data 26 | 27 | ```.bash 28 | export CHECKPOINT_PATH="" # change this to your own checkpoint folder 29 | export WANDB_ENTITY="" # change this to your own wandb entity 30 | export WANDB_PROJECT="" # change this to your own wandb project 31 | 32 | mkdir $CHECKPOINT_PATH 33 | 34 | bash scripts/download_imagenet.sh $CHECKPOINT_PATH 35 | ``` 36 | 37 | You can also add these few export to the bashrc file so that you don't need to run them every time you open a new terminal. 38 | 39 | ### Sample Training/Testing Commands 40 | ```.bash 41 | # start a training with 7 gpu 42 | bash experiments/imagenet/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch.sh $CHECKPOINT_PATH $WANDB_ENTITY $WANDB_PROJECT 43 | 44 | # on the same node, start a testing process that continually reads from the checkpoint folder and evaluate the FID 45 | # Change TIMESTAMP_TBD to the real one 46 | python main/edm/test_folder_edm.py \ 47 | --folder $CHECKPOINT_PATH/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch/TIMESTAMP_TBD \ 48 | --wandb_name test_imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch \ 49 | --wandb_entity $WANDB_ENTITY \ 50 | --wandb_project $WANDB_PROJECT \ 51 | --resolution 64 --label_dim 1000 \ 52 | --ref_path $CHECKPOINT_PATH/imagenet_fid_refs_edm.npz \ 53 | --detector_url $CHECKPOINT_PATH/inception-2015-12-05.pkl 54 | ``` 55 | 56 | Please refer to [train_edm.py](../../main/edm/train_edm.py) for various training options. Notably, if the `--delete_ckpts` flag is set to `True`, all checkpoints except the latest one will be deleted during training. Additionally, you can use the `--cache_dir` flag to specify a location with larger storage capacity. 
The number of checkpoints stored in `cache_dir` is controlled by the `max_checkpoint` argument. 57 | -------------------------------------------------------------------------------- /experiments/imagenet/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch.sh: -------------------------------------------------------------------------------- 1 | export CHECKPOINT_PATH=$1 2 | export WANDB_ENTITY=$2 3 | export WANDB_PROJECT=$3 4 | 5 | CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 torchrun --nproc_per_node 7 --nnodes 1 main/edm/train_edm.py \ 6 | --generator_lr 2e-6 \ 7 | --guidance_lr 2e-6 \ 8 | --train_iters 10000000 \ 9 | --output_path $CHECKPOINT_PATH/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch \ 10 | --batch_size 40 \ 11 | --initialie_generator --log_iters 500 \ 12 | --resolution 64 \ 13 | --label_dim 1000 \ 14 | --dataset_name "imagenet" \ 15 | --seed 10 \ 16 | --model_id $CHECKPOINT_PATH/edm-imagenet-64x64-cond-adm.pkl \ 17 | --wandb_iters 100 \ 18 | --wandb_entity $WANDB_ENTITY \ 19 | --wandb_project $WANDB_PROJECT \ 20 | --use_fp16 \ 21 | --wandb_name "imagenet_gan_classifier_genloss3e-3_diffusion1000_lr2e-6_scratch" \ 22 | --real_image_path $CHECKPOINT_PATH/imagenet-64x64_lmdb \ 23 | --dfake_gen_update_ratio 5 \ 24 | --cls_loss_weight 1e-2 \ 25 | --gan_classifier \ 26 | --gen_cls_loss_weight 3e-3 \ 27 | --diffusion_gan \ 28 | --diffusion_gan_max_timestep 1000 \ 29 | --delete_ckpts \ 30 | --max_checkpoint 500 31 | 32 | -------------------------------------------------------------------------------- /experiments/imagenet/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume.sh: -------------------------------------------------------------------------------- 1 | export CHECKPOINT_PATH=$1 2 | export WANDB_ENTITY=$2 3 | export WANDB_PROJECT=$3 4 | 5 | CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 torchrun --nproc_per_node 7 --nnodes 1 main/edm/train_edm.py \ 6 | --generator_lr 5e-7 \ 7 | --guidance_lr 5e-7 \ 8 | --train_iters 10000000 \ 9 | --output_path $CHECKPOINT_PATH/imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume \ 10 | --batch_size 40 \ 11 | --initialie_generator --log_iters 500 \ 12 | --resolution 64 \ 13 | --label_dim 1000 \ 14 | --dataset_name "imagenet" \ 15 | --seed 10 \ 16 | --model_id $CHECKPOINT_PATH/edm-imagenet-64x64-cond-adm.pkl \ 17 | --wandb_iters 100 \ 18 | --wandb_entity $WANDB_ENTITY \ 19 | --wandb_project $WANDB_PROJECT \ 20 | --use_fp16 \ 21 | --wandb_name "imagenet_gan_classifier_genloss3e-3_diffusion1000_lr5e-7_resume" \ 22 | --real_image_path $CHECKPOINT_PATH/imagenet-64x64_lmdb \ 23 | --dfake_gen_update_ratio 5 \ 24 | --cls_loss_weight 1e-2 \ 25 | --gan_classifier \ 26 | --gen_cls_loss_weight 3e-3 \ 27 | --diffusion_gan \ 28 | --diffusion_gan_max_timestep 1000 \ 29 | --delete_ckpts \ 30 | --max_checkpoint 200 \ 31 | --ckpt_only_path $CHECKPOINT_PATH/imagenet_lr2e-6_scratch_fid2.61_checkpoint_model_405500/ 32 | 33 | -------------------------------------------------------------------------------- /experiments/imagenet/imagenet_lr2e-6_scratch.sh: -------------------------------------------------------------------------------- 1 | export CHECKPOINT_PATH=$1 2 | export WANDB_ENTITY=$2 3 | export WANDB_PROJECT=$3 4 | 5 | CUDA_VISIBLE_DEVICES=1,2,3,4,5,6,7 torchrun --nproc_per_node 7 --nnodes 1 main/train_edm.py \ 6 | --generator_lr 2e-6 \ 7 | --guidance_lr 2e-6 \ 8 | --train_iters 10000000 \ 9 | --output_path $CHECKPOINT_PATH/imagenet_lr2e-6_scratch \ 10 | --batch_size 40 \ 11 | --initialie_generator --log_iters 500 \ 12 | 
--resolution 64 \ 13 | --label_dim 1000 \ 14 | --dataset_name "imagenet" \ 15 | --seed 10 \ 16 | --model_id $CHECKPOINT_PATH/edm-imagenet-64x64-cond-adm.pkl \ 17 | --wandb_iters 100 \ 18 | --wandb_entity $WANDB_ENTITY \ 19 | --wandb_project $WANDB_PROJECT \ 20 | --use_fp16 \ 21 | --wandb_name "imagenet_lr2e-6_scratch" \ 22 | --real_image_path $CHECKPOINT_PATH/imagenet-64x64_lmdb \ 23 | --dfake_gen_update_ratio 5 \ 24 | --delete_ckpts \ 25 | --max_checkpoint 200 26 | 27 | -------------------------------------------------------------------------------- /experiments/sdv1.5/README.md: -------------------------------------------------------------------------------- 1 | ## Getting Started with DMD2 on SDv1.5 2 | 3 | ### Model Zoo 4 | 5 | | Config Name | FID | Link | Iters | Hours | 6 | | ----------- | --- | ---- | ----- | ----- | 7 | | [laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch](./laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch.sh) | 9.28 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/sdv1.5/laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch_fid9.28_checkpoint_model_039000) | 39k | 25 | 8 | | [laion6.25_sd_baseline_8node_guidance1.75_lr5e-7_seed10_dfake10_diffusion1000_gan1e-3_resume*](./laion6.25_sd_baseline_8node_guidance1.75_lr5e-7_seed10_dfake10_diffusion1000_gan1e-3_resume.sh) | 8.35 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/sdv1.5/laion6.25_sd_baseline_8node_guidance1.75_lr5e-7_seed10_dfake10_diffusion1000_gan1e-3_resume_fid8.35_checkpoint_model_041000/) | 2k | 2 | 9 | 10 | *The final model was resumed from the best checkpoint of the **laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch** run and trained for an additional 2,000 iterations. 11 | 12 | For inference with our models, you only need to download the pytorch_model.bin file from the provided link. For fine-tuning, you will need to download the entire folder. 13 | You can use the following script for that: 14 | 15 | ```bash 16 | export CHECKPOINT_NAME="sdv1.5/laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch_fid9.28_checkpoint_model_039000" # note that the sdv1.5/ is necessary 17 | export OUTPUT_PATH="path/to/your/output/folder" 18 | 19 | bash scripts/download_hf_checkpoint.sh $CHECKPOINT_NAME $OUTPUT_PATH 20 | ``` 21 | 22 | Note: We only experimented with a small guidance scale of 1.75 for our SDv1.5 experiments. While this setting generally produces diverse images with good FID scores, the image quality is low. For higher quality visual results, we recommend using our [SDXL training configurations](../sdxl/README.md) or adjusting the real_guidance_scale to a larger value. 23 | 24 | 25 | ### Download Base Diffusion Models and Training Data 26 | ```bash 27 | export CHECKPOINT_PATH="" # change this to your own checkpoint folder 28 | export WANDB_ENTITY="" # change this to your own wandb entity 29 | export WANDB_PROJECT="" # change this to your own wandb project 30 | export MASTER_IP="" # change this to your own master ip 31 | 32 | # Not sure why but we found the following line necessary to work with the accelerate package in our system. 
33 | # Change YOUR_MASTER_IP/YOUR_MASTER_NODE_NAME to the correct value
34 | echo "YOUR_MASTER_IP YOUR_MASTER_NODE_NAME" | sudo tee -a /etc/hosts
35 |
36 | mkdir $CHECKPOINT_PATH
37 |
38 | bash scripts/download_sdv15.sh $CHECKPOINT_PATH
39 | ```
40 |
41 | You can also add these exports to your bashrc file so that you don't need to run them every time you open a new terminal.
42 |
43 | ### Sample Training/Testing Commands
44 |
45 | ```bash
46 | # start training with 64 GPUs. We need to run this script on all 8 nodes.
47 | bash experiments/sdv1.5/laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch.sh $CHECKPOINT_PATH $WANDB_ENTITY $WANDB_PROJECT $MASTER_IP
48 |
49 | # on some other machine, start a testing process that continually reads from the checkpoint folder and evaluates the FID
50 | # Change TIMESTAMP_TBD to the real one
51 | python main/test_folder_sd.py --folder $CHECKPOINT_PATH/laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch/TIMESTAMP_TBD \
52 | --wandb_name test_laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch \
53 | --wandb_entity $WANDB_ENTITY \
54 | --wandb_project $WANDB_PROJECT \
55 | --image_resolution 512 \
56 | --latent_resolution 64 \
57 | --num_train_timesteps 1000 \
58 | --test_visual_batch_size 64 \
59 | --per_image_object 16 \
60 | --seed 10 \
61 | --anno_path $CHECKPOINT_PATH/captions_coco14_test.pkl \
62 | --eval_res 256 \
63 | --ref_dir $CHECKPOINT_PATH/val2014 \
64 | --total_eval_samples 30000 \
65 | --model_id "runwayml/stable-diffusion-v1-5" \
66 | --pred_eps
67 | ```
68 |
69 | Please refer to [train_sd.py](../../main/train_sd.py) for various training options. Notably, if the `--delete_ckpts` flag is set to `True`, all checkpoints except the latest one will be deleted during training. Additionally, you can use the `--cache_dir` flag to specify a location with larger storage capacity. The number of checkpoints stored in `cache_dir` is controlled by the `max_checkpoint` argument.
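For a quick visual check of a downloaded SDv1.5 checkpoint with the plain diffusers API, a rough sketch in the spirit of the SDXL snippets in the top-level README could look like the following. The checkpoint path is a placeholder and the single `[999]` timestep is an assumption rather than a verified setting; the evaluation script above remains the reference way to run this model.

```python
# Hypothetical one-step SDv1.5 inference sketch (mirrors the SDXL examples in the
# top-level README). CKPT is a placeholder path to the downloaded pytorch_model.bin;
# the [999] timestep is an assumption and may need adjusting.
import torch
from diffusers import DiffusionPipeline, UNet2DConditionModel, LCMScheduler

base_model_id = "runwayml/stable-diffusion-v1-5"
CKPT = "path/to/pytorch_model.bin"  # from scripts/download_hf_checkpoint.sh

unet = UNet2DConditionModel.from_config(base_model_id, subfolder="unet").to("cuda", torch.float16)
unet.load_state_dict(torch.load(CKPT, map_location="cuda"))

pipe = DiffusionPipeline.from_pretrained(base_model_id, unet=unet, torch_dtype=torch.float16, safety_checker=None).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)
image = pipe(prompt="a photo of a cat", num_inference_steps=1, guidance_scale=0, timesteps=[999]).images[0]
image.save("sd15_one_step.png")
```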
70 | -------------------------------------------------------------------------------- /experiments/sdv1.5/laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch.sh: -------------------------------------------------------------------------------- 1 | export CHECKPOINT_PATH=$1 2 | export WANDB_ENTITY=$2 3 | export WANDB_PROJECT=$3 4 | export MASTER_IP=$4 5 | 6 | torchrun --nnodes 8 --nproc_per_node=8 --rdzv_id=2345 \ 7 | --rdzv_backend=c10d \ 8 | --rdzv_endpoint=$MASTER_IP main/train_sd.py \ 9 | --generator_lr 1e-5 \ 10 | --guidance_lr 1e-5 \ 11 | --train_iters 100000000 \ 12 | --output_path $CHECKPOINT_PATH/laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch \ 13 | --batch_size 32 \ 14 | --grid_size 2 \ 15 | --initialie_generator --log_iters 1000 \ 16 | --resolution 512 \ 17 | --latent_resolution 64 \ 18 | --seed 10 \ 19 | --real_guidance_scale 1.75 \ 20 | --fake_guidance_scale 1.0 \ 21 | --max_grad_norm 10.0 \ 22 | --model_id "runwayml/stable-diffusion-v1-5" \ 23 | --train_prompt_path $CHECKPOINT_PATH/captions_laion_score6.25.pkl \ 24 | --wandb_iters 50 \ 25 | --wandb_entity $WANDB_ENTITY \ 26 | --wandb_name "laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch" \ 27 | --wandb_project $WANDB_PROJECT \ 28 | --use_fp16 \ 29 | --log_loss \ 30 | --dfake_gen_update_ratio 10 \ 31 | --gradient_checkpointing 32 | 33 | 34 | -------------------------------------------------------------------------------- /experiments/sdv1.5/laion6.25_sd_baseline_8node_guidance1.75_lr5e-7_seed10_dfake10_diffusion1000_gan1e-3_noode_resume_fixdata.sh: -------------------------------------------------------------------------------- 1 | export CHECKPOINT_PATH=$1 2 | export WANDB_ENTITY=$2 3 | export WANDB_PROJECT=$3 4 | export MASTER_IP=$4 5 | 6 | torchrun --nnodes 8 --nproc_per_node=8 --rdzv_id=2345 \ 7 | --rdzv_backend=c10d \ 8 | --rdzv_endpoint=$MASTER_IP main/train_sd.py \ 9 | --generator_lr 5e-7 \ 10 | --guidance_lr 5e-7 \ 11 | --train_iters 100000000 \ 12 | --output_path $CHECKPOINT_PATH/laion6.25_sd_baseline_8node_guidance1.75_lr5e-7_seed10_dfake10_diffusion1000_gan1e-3_resume \ 13 | --batch_size 32 \ 14 | --grid_size 2 \ 15 | --initialie_generator --log_iters 1000 \ 16 | --resolution 512 \ 17 | --latent_resolution 64 \ 18 | --seed 10 \ 19 | --real_guidance_scale 1.75 \ 20 | --fake_guidance_scale 1.0 \ 21 | --max_grad_norm 10.0 \ 22 | --model_id "runwayml/stable-diffusion-v1-5" \ 23 | --train_prompt_path $CHECKPOINT_PATH/captions_laion_score6.25.pkl \ 24 | --wandb_iters 50 \ 25 | --wandb_entity $WANDB_ENTITY \ 26 | --wandb_project $WANDB_PROJECT \ 27 | --wandb_name "laion6.25_sd_baseline_8node_guidance1.75_lr5e-7_seed10_dfake10_diffusion1000_gan1e-3_resume" \ 28 | --use_fp16 \ 29 | --log_loss \ 30 | --dfake_gen_update_ratio 10 \ 31 | --gradient_checkpointing \ 32 | --cls_on_clean_image \ 33 | --gen_cls_loss \ 34 | --gen_cls_loss_weight 1e-3 \ 35 | --guidance_cls_loss_weight 1e-2 \ 36 | --diffusion_gan \ 37 | --diffusion_gan_max_timestep 1000 \ 38 | --ckpt_only_path $CHECKPOINT_PATH/laion6.25_sd_baseline_8node_guidance1.75_lr1e-5_seed10_dfake10_from_scratch_fid9.28_checkpoint_model_039000" \ 39 | --real_image_path $CHECKPOINT_PATH/sd_vae_latents_laion_500k_lmdb/ 40 | -------------------------------------------------------------------------------- /experiments/sdxl/README.md: -------------------------------------------------------------------------------- 1 | ## Getting Started with DMD2 on SDXL 2 | 3 | ### Model Zoo 4 | 5 | | Config Name | FID | 
Link | Iters | Hours | 6 | | ----------- | --- | ---- | ----- | ----- | 7 | | [sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch](./sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch.sh) | 19.32 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/sdxl/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch_checkpoint_model_019000) | 19k | 57 | 8 | | [sdxl_cond999_8node_lr5e-5_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch_lora](./sdxl_cond999_8node_lr5e-5_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch_lora.sh) | 19.68 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/sdxl/sdxl_cond999_8node_lr5e-5_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch_lora_checkpoint_model_016000) | 16k | 63 | 9 | | [sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode](./sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh) | 19.01 | [link](https://huggingface.co/tianweiy/DMD2/tree/main/model/sdxl/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode_checkpoint_model_024000) | 24k | 57 | 10 | 11 | 12 | For inference with our models, you only need to download the pytorch_model.bin file from the provided link. For fine-tuning, you will need to download the entire folder. 13 | You can use the following script for that: 14 | 15 | ```bash 16 | export CHECKPOINT_NAME="sdxl/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch_checkpoint_model_019000" # note that the sdxl/ is necessary 17 | export OUTPUT_PATH="path/to/your/output/folder" 18 | 19 | bash scripts/download_hf_checkpoint.sh $CHECKPOINT_NAME $OUTPUT_PATH 20 | ``` 21 | 22 | 23 | ### Download Base Diffusion Models and Training Data 24 | ```bash 25 | export CHECKPOINT_PATH="" # change this to your own checkpoint folder (this should be a central directory shared across nodes) 26 | export WANDB_ENTITY="" # change this to your own wandb entity 27 | export WANDB_PROJECT="" # change this to your own wandb project 28 | export MASTER_IP="" # change this to your own master ip 29 | 30 | # Not sure why but we found the following line necessary to work with the accelerate package in our system. 31 | # Change YOUR_MASTER_IP/YOUR_MASTER_NODE_NAME to the correct value 32 | echo "YOUR_MASTER_IP YOUR_MASTER_NODE_NAME" | sudo tee -a /etc/hosts 33 | 34 | # create a fsdp configs for accelerate launch. change the EXP_NAME to your own experiment name 35 | python main/sdxl/create_sdxl_fsdp_configs.py --folder fsdp_configs/EXP_NAME --master_ip $MASTER_IP --num_machines 8 --sharding_strategy 4 36 | mkdir $CHECKPOINT_PATH 37 | 38 | bash scripts/download_sdxl.sh $CHECKPOINT_PATH 39 | ``` 40 | 41 | You can also add these few export to the bashrc file so that you don't need to run them every time you open a new terminal. 42 | 43 | ### 4-step Sample Training/Testing Commands 44 | 45 | ```bash 46 | # start a training with 64 gpu. we need to run this script on all 8 nodes. Please change the EXP_NAME and NODE_RANK_ID accordingly. 
47 | bash experiments/sdxl/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch.sh $CHECKPOINT_PATH $WANDB_ENTITY $WANDB_PROJECT fsdp_configs/EXP_NAME NODE_RANK_ID 48 | 49 | # on some other machine, start a testing process that continually reads from the checkpoint folder and evaluate the FID 50 | # Change TIMESTAMP_TBD to the real one 51 | python main/sdxl/test_folder_sdxl.py \ 52 | --folder $CHECKPOINT_PATH/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch/TIMESTAMP_TBD/ \ 53 | --conditioning_timestep 999 --num_step 4 --wandb_entity $WANDB_ENTITY \ 54 | --wandb_project $WANDB_PROJECT --num_train_timesteps 1000 \ 55 | --seed 10 --eval_res 512 --ref_dir $CHECKPOINT_PATH/coco10k/subset \ 56 | --anno_path $CHECKPOINT_PATH/coco10k/all_prompts.pkl \ 57 | --total_eval_samples 10000 --clip_score \ 58 | --wandb_name test_sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch 59 | ``` 60 | 61 | ### 1-step Sample Training/Testing Commands [Work In Progress] 62 | 63 | For 1-step model, we need an extra regression loss pretraining. 64 | 65 | First, download the 10K noise-image pairs 66 | 67 | ```bash 68 | bash scripts/download_sdxl_ode_pair_10k_lmdb.sh $CHECKPOINT_PATH 69 | ``` 70 | 71 | These pairs can be generated using [generate_noise_image_pairs_laion_sdxl.py](../../main/sdxl/generate_noise_image_pairs_laion_sdxl.py) 72 | 73 | Second, Pretrain the model with regression loss 74 | 75 | ```bash 76 | bash experiments/sdxl/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399.sh $CHECKPOINT_PATH $WANDB_ENTITY $WANDB_PROJECT $MASTER_IP 77 | ``` 78 | 79 | Alternatively, you can skip the previous two steps and directly download the regression loss pretrained checkpoint 80 | 81 | ```bash 82 | bash scripts/download_sdxl_1step_ode_pairs_ckpt.sh $CHECKPOINT_PATH 83 | ``` 84 | 85 | Start the real training 86 | 87 | ```bash 88 | # start a training with 64 gpu. we need to run this script on all 8 nodes. Please change the EXP_NAME and NODE_RANK_ID accordingly. 89 | bash experiments/sdxl/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh $CHECKPOINT_PATH $WANDB_ENTITY $WANDB_PROJECT fsdp_configs/EXP_NAME NODE_RANK_ID 90 | 91 | # on some other machine, start a testing process that continually reads from the checkpoint folder and evaluate the FID 92 | # Change TIMESTAMP_TBD to the real one 93 | python main/sdxl/test_folder_sdxl.py \ 94 | --folder $CHECKPOINT_PATH/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode/TIMESTAMP_TBD/ \ 95 | --conditioning_timestep 399 --num_step 1 --wandb_entity $WANDB_ENTITY \ 96 | --wandb_project $WANDB_PROJECT --num_train_timesteps 1000 \ 97 | --seed 10 --eval_res 512 --ref_dir $CHECKPOINT_PATH/coco10k/subset \ 98 | --anno_path $CHECKPOINT_PATH/coco10k/all_prompts.pkl \ 99 | --total_eval_samples 10000 --clip_score \ 100 | --wandb_name test_sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode 101 | ``` 102 | 103 | Please refer to [train_sd.py](../../main/train_sd.py) for various training options. Notably, if the `--delete_ckpts` flag is set to `True`, all checkpoints except the latest one will be deleted during training. Additionally, you can use the `--cache_dir` flag to specify a location with larger storage capacity. The number of checkpoints stored in `cache_dir` is controlled by the `max_checkpoint` argument. 
104 | 105 | For LORA training, add the `--generator_lora` flag to the training command. The final checkpoint can be converted to a LORA model using the [extract_lora_module.py](../../main/sdxl/extract_lora_module.py) script. -------------------------------------------------------------------------------- /experiments/sdxl/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode.sh: -------------------------------------------------------------------------------- 1 | export CHECKPOINT_PATH=$1 2 | export WANDB_ENTITY=$2 3 | export WANDB_PROJECT=$3 4 | export FSDP_DIR=$4 5 | export RANK=$5 6 | 7 | # accelerate launch --config_file fsdp_configs/fsdp_1node_debug.yaml main/train_sd.py \ 8 | accelerate launch --config_file $FSDP_DIR/config_rank$RANK.yaml main/train_sd.py \ 9 | --generator_lr 5e-7 \ 10 | --guidance_lr 5e-7 \ 11 | --train_iters 100000000 \ 12 | --output_path $CHECKPOINT_PATH/sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode \ 13 | --batch_size 2 \ 14 | --grid_size 2 \ 15 | --initialie_generator --log_iters 1000 \ 16 | --resolution 1024 \ 17 | --latent_resolution 128 \ 18 | --seed 10 \ 19 | --real_guidance_scale 8 \ 20 | --fake_guidance_scale 1.0 \ 21 | --max_grad_norm 10.0 \ 22 | --model_id "stabilityai/stable-diffusion-xl-base-1.0" \ 23 | --wandb_iters 100 \ 24 | --wandb_entity $WANDB_ENTITY \ 25 | --wandb_name "sdxl_cond399_8node_lr5e-7_1step_diffusion1000_gan5e-3_guidance8_noinit_noode" \ 26 | --log_loss \ 27 | --dfake_gen_update_ratio 5 \ 28 | --fsdp \ 29 | --sdxl \ 30 | --use_fp16 \ 31 | --max_step_percent 0.98 \ 32 | --cls_on_clean_image \ 33 | --gen_cls_loss \ 34 | --gen_cls_loss_weight 5e-3 \ 35 | --guidance_cls_loss_weight 1e-2 \ 36 | --diffusion_gan \ 37 | --diffusion_gan_max_timestep 1000 \ 38 | --conditioning_timestep 399 \ 39 | --train_prompt_path $CHECKPOINT_PATH/captions_laion_score6.25.pkl \ 40 | --real_image_path $CHECKPOINT_PATH/sdxl_vae_latents_laion_500k_lmdb/ \ 41 | --generator_ckpt_path $CHECKPOINT_PATH/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399_checkpoint_model_002000.bin 42 | -------------------------------------------------------------------------------- /experiments/sdxl/sdxl_cond999_8node_lr5e-5_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch_lora.sh: -------------------------------------------------------------------------------- 1 | export CHECKPOINT_PATH=$1 2 | export WANDB_ENTITY=$2 3 | export WANDB_PROJECT=$3 4 | export FSDP_DIR=$4 5 | export RANK=$5 6 | 7 | # accelerate launch --config_file fsdp_configs/fsdp_1node_debug.yaml main/train_sd.py \ 8 | accelerate launch --config_file $FSDP_DIR/config_rank$RANK.yaml main/train_sd.py \ 9 | --generator_lr 5e-5 \ 10 | --guidance_lr 5e-5 \ 11 | --train_iters 100000000 \ 12 | --output_path $CHECKPOINT_PATH/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch \ 13 | --batch_size 2 \ 14 | --grid_size 2 \ 15 | --initialie_generator --log_iters 1000 \ 16 | --resolution 1024 \ 17 | --latent_resolution 128 \ 18 | --seed 10 \ 19 | --real_guidance_scale 8 \ 20 | --fake_guidance_scale 1.0 \ 21 | --max_grad_norm 10.0 \ 22 | --model_id "stabilityai/stable-diffusion-xl-base-1.0" \ 23 | --wandb_iters 100 \ 24 | --wandb_entity $WANDB_ENTITY \ 25 | --wandb_project $WANDB_PROJECT \ 26 | --wandb_name "sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch" \ 27 | --log_loss \ 28 | --dfake_gen_update_ratio 5 \ 29 | --fsdp \ 30 | --sdxl \ 31 | --use_fp16 \ 32 
| --max_step_percent 0.98 \ 33 | --cls_on_clean_image \ 34 | --gen_cls_loss \ 35 | --gen_cls_loss_weight 5e-3 \ 36 | --guidance_cls_loss_weight 1e-2 \ 37 | --diffusion_gan \ 38 | --diffusion_gan_max_timestep 1000 \ 39 | --denoising \ 40 | --num_denoising_step 4 \ 41 | --denoising_timestep 1000 \ 42 | --backward_simulation \ 43 | --train_prompt_path $CHECKPOINT_PATH/captions_laion_score6.25.pkl \ 44 | --real_image_path $CHECKPOINT_PATH/sdxl_vae_latents_laion_500k_lmdb/ \ 45 | --generator_lora -------------------------------------------------------------------------------- /experiments/sdxl/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch.sh: -------------------------------------------------------------------------------- 1 | export CHECKPOINT_PATH=$1 2 | export WANDB_ENTITY=$2 3 | export WANDB_PROJECT=$3 4 | export FSDP_DIR=$4 5 | export RANK=$5 6 | 7 | # accelerate launch --config_file fsdp_configs/fsdp_1node_debug.yaml main/train_sd.py \ 8 | accelerate launch --config_file $FSDP_DIR/config_rank$RANK.yaml main/train_sd.py \ 9 | --generator_lr 5e-7 \ 10 | --guidance_lr 5e-7 \ 11 | --train_iters 100000000 \ 12 | --output_path $CHECKPOINT_PATH/sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch \ 13 | --batch_size 2 \ 14 | --grid_size 2 \ 15 | --initialie_generator --log_iters 1000 \ 16 | --resolution 1024 \ 17 | --latent_resolution 128 \ 18 | --seed 10 \ 19 | --real_guidance_scale 8 \ 20 | --fake_guidance_scale 1.0 \ 21 | --max_grad_norm 10.0 \ 22 | --model_id "stabilityai/stable-diffusion-xl-base-1.0" \ 23 | --wandb_iters 100 \ 24 | --wandb_entity $WANDB_ENTITY \ 25 | --wandb_project $WANDB_PROJECT \ 26 | --wandb_name "sdxl_cond999_8node_lr5e-7_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch" \ 27 | --log_loss \ 28 | --dfake_gen_update_ratio 5 \ 29 | --fsdp \ 30 | --sdxl \ 31 | --use_fp16 \ 32 | --max_step_percent 0.98 \ 33 | --cls_on_clean_image \ 34 | --gen_cls_loss \ 35 | --gen_cls_loss_weight 5e-3 \ 36 | --guidance_cls_loss_weight 1e-2 \ 37 | --diffusion_gan \ 38 | --diffusion_gan_max_timestep 1000 \ 39 | --denoising \ 40 | --num_denoising_step 4 \ 41 | --denoising_timestep 1000 \ 42 | --backward_simulation \ 43 | --train_prompt_path $CHECKPOINT_PATH/captions_laion_score6.25.pkl \ 44 | --real_image_path $CHECKPOINT_PATH/sdxl_vae_latents_laion_500k_lmdb/ -------------------------------------------------------------------------------- /experiments/sdxl/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399.sh: -------------------------------------------------------------------------------- 1 | export CHECKPOINT_PATH=$1 2 | export WANDB_ENTITY=$2 3 | export WANDB_PROJECT=$3 4 | export MASTER_IP=$4 5 | 6 | torchrun --nnodes 8 --nproc_per_node=8 --rdzv_id=2345 \ 7 | --rdzv_backend=c10d \ 8 | --rdzv_endpoint=$MASTER_IP main/train_sd_ode.py \ 9 | --generator_lr 1e-5 \ 10 | --train_iters 100000000 \ 11 | --output_path $CHECKPOINT_PATH/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399 \ 12 | --grid_size 1 \ 13 | --log_iters 1000 \ 14 | --resolution 1024 \ 15 | --seed 10 \ 16 | --max_grad_norm 10.0 \ 17 | --model_id "stabilityai/stable-diffusion-xl-base-1.0" \ 18 | --wandb_iters 250 \ 19 | --wandb_entity tyin \ 20 | --wandb_name "sdxl_lr1e-5_8node_ode_pretraining_10k_cond399" \ 21 | --sdxl \ 22 | --num_ode_pairs 10000 \ 23 | --ode_pair_path $CHECKPOINT_PATH/sdxl_ode_pair_10k_lmdb/ \ 24 | --ode_batch_size 4 \ 25 | --conditioning_timestep 399 \ 26 | --tiny_vae \ 27 | --use_fp16 28 | 
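Note on running the SDXL experiment scripts above: each script reads its configuration from positional arguments (CHECKPOINT_PATH, WANDB_ENTITY, WANDB_PROJECT, then FSDP_DIR and RANK for the accelerate-based scripts, or MASTER_IP for the torchrun ODE-pretraining script). A minimal, hypothetical invocation of the LoRA variant referenced in experiments/sdxl/README.md is sketched below; the paths, entity, and project names are placeholders, and the per-rank FSDP config files (config_rank<RANK>.yaml) are assumed to have been generated beforehand, e.g. with main/sdxl/create_sdxl_fsdp_configs.py.

# hypothetical invocation; replace the placeholder paths and wandb names with your own
bash experiments/sdxl/sdxl_cond999_8node_lr5e-5_denoising4step_diffusion1000_gan5e-3_guidance8_noinit_noode_backsim_scratch_lora.sh \
    /path/to/checkpoints my_wandb_entity my_wandb_project /path/to/fsdp_configs 0

After training, the README above notes that the resulting checkpoint can be converted into a standalone LoRA module with main/sdxl/extract_lora_module.py.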
-------------------------------------------------------------------------------- /figures/README.md: -------------------------------------------------------------------------------- 1 | This folder will include scripts to reproduce our visual results. -------------------------------------------------------------------------------- /fsdp_configs/fsdp_1node_debug.yaml: -------------------------------------------------------------------------------- 1 | 2 | compute_environment: LOCAL_MACHINE 3 | debug: true 4 | distributed_type: FSDP 5 | downcast_bf16: 'no' 6 | fsdp_config: 7 | fsdp_auto_wrap_policy: SIZE_BASED_WRAP 8 | fsdp_backward_prefetch_policy: BACKWARD_PRE 9 | fsdp_forward_prefetch: false 10 | fsdp_min_num_params: 50000000 11 | fsdp_offload_params: false 12 | fsdp_sharding_strategy: 1 13 | fsdp_state_dict_type: SHARDED_STATE_DICT 14 | fsdp_sync_module_states: true 15 | fsdp_use_orig_params: false 16 | machine_rank: 0 17 | main_training_function: main 18 | mixed_precision: 'no' 19 | num_machines: 1 20 | num_processes: 8 21 | rdzv_backend: static 22 | same_network: true 23 | tpu_env: [] 24 | tpu_use_cluster: false 25 | tpu_use_sudo: false 26 | use_cpu: false 27 | -------------------------------------------------------------------------------- /main/coco_eval/cleanfid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/DMD2/8d8fa55633d47cfb81bbc7a892e7248f9518763f/main/coco_eval/cleanfid/__init__.py -------------------------------------------------------------------------------- /main/coco_eval/cleanfid/clip_features.py: -------------------------------------------------------------------------------- 1 | # pip install git+https://github.com/openai/CLIP.git 2 | import pdb 3 | from PIL import Image 4 | import numpy as np 5 | import torch 6 | import torchvision.transforms as transforms 7 | import clip 8 | from cleanfid.fid import compute_fid 9 | 10 | 11 | def img_preprocess_clip(img_np): 12 | x = Image.fromarray(img_np.astype(np.uint8)).convert("RGB") 13 | T = transforms.Compose([ 14 | transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), 15 | transforms.CenterCrop(224), 16 | ]) 17 | return np.asarray(T(x)).clip(0, 255).astype(np.uint8) 18 | 19 | 20 | class CLIP_fx(): 21 | def __init__(self, name="ViT-B/32", device="cuda"): 22 | self.model, _ = clip.load(name, device=device) 23 | self.model.eval() 24 | self.name = "clip_"+name.lower().replace("-","_").replace("/","_") 25 | 26 | def __call__(self, img_t): 27 | img_x = img_t/255.0 28 | T_norm = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 29 | img_x = T_norm(img_x) 30 | assert torch.is_tensor(img_x) 31 | if len(img_x.shape)==3: 32 | img_x = img_x.unsqueeze(0) 33 | B,C,H,W = img_x.shape 34 | with torch.no_grad(): 35 | z = self.model.encode_image(img_x) 36 | return z 37 | -------------------------------------------------------------------------------- /main/coco_eval/cleanfid/downloads_helper.py: -------------------------------------------------------------------------------- 1 | import os 2 | import urllib.request 3 | import requests 4 | import shutil 5 | 6 | 7 | inception_url = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt" 8 | 9 | 10 | """ 11 | Download the pretrined inception weights if it does not exists 12 | ARGS: 13 | fpath - output folder path 14 | """ 15 | def check_download_inception(fpath="./"): 16 | inception_path = 
os.path.join(fpath, "inception-2015-12-05.pt") 17 | if not os.path.exists(inception_path): 18 | # download the file 19 | with urllib.request.urlopen(inception_url) as response, open(inception_path, 'wb') as f: 20 | shutil.copyfileobj(response, f) 21 | return inception_path 22 | 23 | 24 | """ 25 | Download any url if it does not exist 26 | ARGS: 27 | local_folder - output folder path 28 | url - the weburl to download 29 | """ 30 | def check_download_url(local_folder, url): 31 | name = os.path.basename(url) 32 | local_path = os.path.join(local_folder, name) 33 | if not os.path.exists(local_path): 34 | os.makedirs(local_folder, exist_ok=True) 35 | print(f"downloading statistics to {local_path}") 36 | with urllib.request.urlopen(url) as response, open(local_path, 'wb') as f: 37 | shutil.copyfileobj(response, f) 38 | return local_path 39 | 40 | 41 | """ 42 | Download a file from google drive 43 | ARGS: 44 | file_id - id of the google drive file 45 | out_path - output folder path 46 | """ 47 | def download_google_drive(file_id, out_path): 48 | def get_confirm_token(response): 49 | for key, value in response.cookies.items(): 50 | if key.startswith('download_warning'): 51 | return value 52 | return None 53 | 54 | URL = "https://drive.google.com/uc?export=download" 55 | session = requests.Session() 56 | response = session.get(URL, params={'id': file_id}, stream=True) 57 | token = get_confirm_token(response) 58 | 59 | if token: 60 | params = {'id': file_id, 'confirm': token} 61 | response = session.get(URL, params=params, stream=True) 62 | 63 | CHUNK_SIZE = 32768 64 | with open(out_path, "wb") as f: 65 | for chunk in response.iter_content(CHUNK_SIZE): 66 | if chunk: 67 | f.write(chunk) 68 | -------------------------------------------------------------------------------- /main/coco_eval/cleanfid/features.py: -------------------------------------------------------------------------------- 1 | """ 2 | helpers for extracting features from image 3 | """ 4 | import os 5 | import platform 6 | import numpy as np 7 | import torch 8 | import cleanfid 9 | from cleanfid.downloads_helper import check_download_url 10 | from cleanfid.inception_pytorch import InceptionV3 11 | from cleanfid.inception_torchscript import InceptionV3W 12 | 13 | 14 | """ 15 | returns a functions that takes an image in range [0,255] 16 | and outputs a feature embedding vector 17 | """ 18 | def feature_extractor(name="torchscript_inception", device=torch.device("cuda"), resize_inside=False, use_dataparallel=True): 19 | if name == "torchscript_inception": 20 | path = "./" if platform.system() == "Windows" else "/tmp" 21 | model = InceptionV3W(path, download=True, resize_inside=resize_inside).to(device) 22 | model.eval() 23 | if use_dataparallel: 24 | model = torch.nn.DataParallel(model) 25 | def model_fn(x): return model(x) 26 | elif name == "pytorch_inception": 27 | model = InceptionV3(output_blocks=[3], resize_input=False).to(device) 28 | model.eval() 29 | if use_dataparallel: 30 | model = torch.nn.DataParallel(model) 31 | def model_fn(x): return model(x/255)[0].squeeze(-1).squeeze(-1) 32 | else: 33 | raise ValueError(f"{name} feature extractor not implemented") 34 | return model_fn 35 | 36 | 37 | """ 38 | Build a feature extractor for each of the modes 39 | """ 40 | def build_feature_extractor(mode, device=torch.device("cuda"), use_dataparallel=True): 41 | if mode == "legacy_pytorch": 42 | feat_model = feature_extractor(name="pytorch_inception", resize_inside=False, device=device, use_dataparallel=use_dataparallel) 43 | elif mode == 
"legacy_tensorflow": 44 | feat_model = feature_extractor(name="torchscript_inception", resize_inside=True, device=device, use_dataparallel=use_dataparallel) 45 | elif mode == "clean": 46 | feat_model = feature_extractor(name="torchscript_inception", resize_inside=False, device=device, use_dataparallel=use_dataparallel) 47 | return feat_model 48 | 49 | 50 | """ 51 | Load precomputed reference statistics for commonly used datasets 52 | """ 53 | def get_reference_statistics(name, res, mode="clean", model_name="inception_v3", seed=0, split="test", metric="FID"): 54 | base_url = "https://www.cs.cmu.edu/~clean-fid/stats/" 55 | if split == "custom": 56 | res = "na" 57 | if model_name=="inception_v3": 58 | model_modifier = "" 59 | else: 60 | model_modifier = "_"+model_name 61 | if metric == "FID": 62 | rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}.npz").lower() 63 | url = f"{base_url}/{rel_path}" 64 | mod_path = os.path.dirname(cleanfid.__file__) 65 | stats_folder = os.path.join(mod_path, "stats") 66 | fpath = check_download_url(local_folder=stats_folder, url=url) 67 | stats = np.load(fpath) 68 | mu, sigma = stats["mu"], stats["sigma"] 69 | return mu, sigma 70 | elif metric == "KID": 71 | rel_path = (f"{name}_{mode}{model_modifier}_{split}_{res}_kid.npz").lower() 72 | url = f"{base_url}/{rel_path}" 73 | mod_path = os.path.dirname(cleanfid.__file__) 74 | stats_folder = os.path.join(mod_path, "stats") 75 | fpath = check_download_url(local_folder=stats_folder, url=url) 76 | stats = np.load(fpath) 77 | return stats["feats"] 78 | -------------------------------------------------------------------------------- /main/coco_eval/cleanfid/inception_torchscript.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | from cleanfid.downloads_helper import * 5 | import contextlib 6 | 7 | 8 | @contextlib.contextmanager 9 | def disable_gpu_fuser_on_pt19(): 10 | # On PyTorch 1.9 a CUDA fuser bug prevents the Inception JIT model to run. 
See 11 | # https://github.com/GaParmar/clean-fid/issues/5 12 | # https://github.com/pytorch/pytorch/issues/64062 13 | if torch.__version__.startswith('1.9.'): 14 | old_val = torch._C._jit_can_fuse_on_gpu() 15 | torch._C._jit_override_can_fuse_on_gpu(False) 16 | yield 17 | if torch.__version__.startswith('1.9.'): 18 | torch._C._jit_override_can_fuse_on_gpu(old_val) 19 | 20 | 21 | class InceptionV3W(nn.Module): 22 | """ 23 | Wrapper around Inception V3 torchscript model provided here 24 | https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/inception-2015-12-05.pt 25 | 26 | path: locally saved inception weights 27 | """ 28 | def __init__(self, path, download=True, resize_inside=False): 29 | super(InceptionV3W, self).__init__() 30 | # download the network if it is not present at the given directory 31 | # use the current directory by default 32 | if download: 33 | check_download_inception(fpath=path) 34 | path = os.path.join(path, "inception-2015-12-05.pt") 35 | self.base = torch.jit.load(path).eval() 36 | self.layers = self.base.layers 37 | self.resize_inside = resize_inside 38 | 39 | """ 40 | Get the inception features without resizing 41 | x: Image with values in range [0,255] 42 | """ 43 | def forward(self, x): 44 | with disable_gpu_fuser_on_pt19(): 45 | bs = x.shape[0] 46 | if self.resize_inside: 47 | features = self.base(x, return_features=True).view((bs, 2048)) 48 | else: 49 | # make sure it is resized already 50 | assert (x.shape[2] == 299) and (x.shape[3] == 299) 51 | # apply normalization 52 | x1 = x - 128 53 | x2 = x1 / 128 54 | features = self.layers.forward(x2, ).view((bs, 2048)) 55 | return features 56 | -------------------------------------------------------------------------------- /main/coco_eval/cleanfid/leaderboard.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import shutil 4 | import urllib.request 5 | 6 | 7 | def get_score(model_name=None, dataset_name=None, 8 | dataset_res=None, dataset_split=None, task_name=None): 9 | # download the csv file from server 10 | url = "https://www.cs.cmu.edu/~clean-fid/files/leaderboard.csv" 11 | local_path = "/tmp/leaderboard.csv" 12 | with urllib.request.urlopen(url) as response, open(local_path, 'wb') as f: 13 | shutil.copyfileobj(response, f) 14 | 15 | d_field2idx = {} 16 | l_matches = [] 17 | with open(local_path, 'r') as f: 18 | csvreader = csv.reader(f) 19 | l_fields = next(csvreader) 20 | for idx, val in enumerate(l_fields): 21 | d_field2idx[val.strip()] = idx 22 | # iterate through all rows 23 | for row in csvreader: 24 | # skip empty rows 25 | if len(row) == 0: 26 | continue 27 | # skip if the filter doesn't match 28 | if model_name is not None and (row[d_field2idx["model_name"]].strip() != model_name): 29 | continue 30 | if dataset_name is not None and (row[d_field2idx["dataset_name"]].strip() != dataset_name): 31 | continue 32 | if dataset_res is not None and (row[d_field2idx["dataset_res"]].strip() != dataset_res): 33 | continue 34 | if dataset_split is not None and (row[d_field2idx["dataset_split"]].strip() != dataset_split): 35 | continue 36 | if task_name is not None and (row[d_field2idx["task_name"]].strip() != task_name): 37 | continue 38 | curr = {} 39 | for f in l_fields: 40 | curr[f.strip()] = row[d_field2idx[f.strip()]].strip() 41 | l_matches.append(curr) 42 | os.remove(local_path) 43 | return l_matches 44 | -------------------------------------------------------------------------------- /main/coco_eval/cleanfid/resize.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Helpers for resizing with multiple CPU cores 3 | """ 4 | import os 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | import torch.nn.functional as F 9 | 10 | 11 | def build_resizer(mode): 12 | if mode == "clean": 13 | return make_resizer("PIL", False, "bicubic", (299,299)) 14 | # if using legacy tensorflow, do not manually resize outside the network 15 | elif mode == "legacy_tensorflow": 16 | return lambda x: x 17 | elif mode == "legacy_pytorch": 18 | return make_resizer("PyTorch", False, "bilinear", (299, 299)) 19 | else: 20 | raise ValueError(f"Invalid mode {mode} specified") 21 | 22 | 23 | """ 24 | Construct a function that resizes a numpy image based on the 25 | flags passed in. 26 | """ 27 | def make_resizer(library, quantize_after, filter, output_size): 28 | if library == "PIL" and quantize_after: 29 | name_to_filter = { 30 | "bicubic": Image.BICUBIC, 31 | "bilinear": Image.BILINEAR, 32 | "nearest": Image.NEAREST, 33 | "lanczos": Image.LANCZOS, 34 | "box": Image.BOX 35 | } 36 | def func(x): 37 | x = Image.fromarray(x) 38 | x = x.resize(output_size, resample=name_to_filter[filter]) 39 | x = np.asarray(x).clip(0, 255).astype(np.uint8) 40 | return x 41 | elif library == "PIL" and not quantize_after: 42 | name_to_filter = { 43 | "bicubic": Image.BICUBIC, 44 | "bilinear": Image.BILINEAR, 45 | "nearest": Image.NEAREST, 46 | "lanczos": Image.LANCZOS, 47 | "box": Image.BOX 48 | } 49 | s1, s2 = output_size 50 | def resize_single_channel(x_np): 51 | img = Image.fromarray(x_np.astype(np.float32), mode='F') 52 | img = img.resize(output_size, resample=name_to_filter[filter]) 53 | return np.asarray(img).clip(0, 255).reshape(s2, s1, 1) 54 | def func(x): 55 | x = [resize_single_channel(x[:, :, idx]) for idx in range(3)] 56 | x = np.concatenate(x, axis=2).astype(np.float32) 57 | return x 58 | elif library == "PyTorch": 59 | import warnings 60 | # ignore the numpy warnings 61 | warnings.filterwarnings("ignore") 62 | def func(x): 63 | x = torch.Tensor(x.transpose((2, 0, 1)))[None, ...] 64 | x = F.interpolate(x, size=output_size, mode=filter, align_corners=False) 65 | x = x[0, ...].cpu().data.numpy().transpose((1, 2, 0)).clip(0, 255) 66 | if quantize_after: 67 | x = x.astype(np.uint8) 68 | return x 69 | elif library == "TensorFlow": 70 | import warnings 71 | # ignore the numpy warnings 72 | warnings.filterwarnings("ignore") 73 | import tensorflow as tf 74 | def func(x): 75 | x = tf.constant(x)[tf.newaxis, ...] 
76 | x = tf.image.resize(x, output_size, method=filter) 77 | x = x[0, ...].numpy().clip(0, 255) 78 | if quantize_after: 79 | x = x.astype(np.uint8) 80 | return x 81 | elif library == "OpenCV": 82 | import cv2 83 | name_to_filter = { 84 | "bilinear": cv2.INTER_LINEAR, 85 | "bicubic": cv2.INTER_CUBIC, 86 | "lanczos": cv2.INTER_LANCZOS4, 87 | "nearest": cv2.INTER_NEAREST, 88 | "area": cv2.INTER_AREA 89 | } 90 | def func(x): 91 | x = cv2.resize(x, output_size, interpolation=name_to_filter[filter]) 92 | x = x.clip(0, 255) 93 | if quantize_after: 94 | x = x.astype(np.uint8) 95 | return x 96 | else: 97 | raise NotImplementedError('library [%s] is not include' % library) 98 | return func 99 | 100 | 101 | class FolderResizer(torch.utils.data.Dataset): 102 | def __init__(self, files, outpath, fn_resize, output_ext=".png"): 103 | self.files = files 104 | self.outpath = outpath 105 | self.output_ext = output_ext 106 | self.fn_resize = fn_resize 107 | 108 | def __len__(self): 109 | return len(self.files) 110 | 111 | def __getitem__(self, i): 112 | path = str(self.files[i]) 113 | img_np = np.asarray(Image.open(path)) 114 | img_resize_np = self.fn_resize(img_np) 115 | # swap the output extension 116 | basename = os.path.basename(path).split(".")[0] + self.output_ext 117 | outname = os.path.join(self.outpath, basename) 118 | if self.output_ext == ".npy": 119 | np.save(outname, img_resize_np) 120 | elif self.output_ext == ".png": 121 | img_resized_pil = Image.fromarray(img_resize_np) 122 | img_resized_pil.save(outname) 123 | else: 124 | raise ValueError("invalid output extension") 125 | return 0 126 | -------------------------------------------------------------------------------- /main/coco_eval/cleanfid/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torchvision 4 | from PIL import Image 5 | from main.coco_eval.cleanfid.resize import build_resizer 6 | import zipfile 7 | 8 | 9 | class ResizeDataset(torch.utils.data.Dataset): 10 | """ 11 | A placeholder Dataset that enables parallelizing the resize operation 12 | using multiple CPU cores 13 | 14 | files: list of all files in the folder 15 | fn_resize: function that takes an np_array as input [0,255] 16 | """ 17 | 18 | def __init__(self, files, mode, size=(299, 299), fdir=None): 19 | self.files = files 20 | self.fdir = fdir 21 | self.transforms = torchvision.transforms.ToTensor() 22 | self.size = size 23 | self.fn_resize = build_resizer(mode) 24 | self.custom_image_tranform = lambda x: x 25 | self._zipfile = None 26 | 27 | def _get_zipfile(self): 28 | assert self.fdir is not None and '.zip' in self.fdir 29 | if self._zipfile is None: 30 | self._zipfile = zipfile.ZipFile(self.fdir) 31 | return self._zipfile 32 | 33 | def __len__(self): 34 | return len(self.files) 35 | 36 | def __getitem__(self, i): 37 | path = str(self.files[i]) 38 | if self.fdir is not None and '.zip' in self.fdir: 39 | with self._get_zipfile().open(path, 'r') as f: 40 | img_np = np.array(Image.open(f).convert('RGB')) 41 | elif ".npy" in path: 42 | img_np = np.load(path) 43 | else: 44 | img_pil = Image.open(path).convert('RGB') 45 | img_np = np.array(img_pil) 46 | 47 | # apply a custom image transform before resizing the image to 299x299 48 | img_np = self.custom_image_tranform(img_np) 49 | # fn_resize expects a np array and returns a np array 50 | img_resized = self.fn_resize(img_np) 51 | 52 | # ToTensor() converts to [0,1] only if input in uint8 53 | if img_resized.dtype == "uint8": 54 | img_t = 
self.transforms(np.array(img_resized))*255 55 | elif img_resized.dtype == "float32": 56 | img_t = self.transforms(img_resized) 57 | 58 | return img_t 59 | 60 | 61 | EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 62 | 'tif', 'tiff', 'webp', 'npy', 'JPEG', 'JPG', 'PNG'} 63 | 64 | class ResizeArrayDataset(torch.utils.data.Dataset): 65 | """ 66 | A placeholder Dataset that enables parallelizing the resize operation 67 | using multiple CPU cores 68 | 69 | files: list of all files in the folder 70 | fn_resize: function that takes an np_array as input [0,255] 71 | """ 72 | 73 | def __init__(self, array, mode, size=(299, 299)): 74 | self.array = array 75 | self.transforms = torchvision.transforms.ToTensor() 76 | self.size = size 77 | self.fn_resize = build_resizer(mode) 78 | self.custom_image_tranform = lambda x: x 79 | 80 | def __len__(self): 81 | return len(self.array) 82 | 83 | def __getitem__(self, i): 84 | img_np = self.array[i] 85 | 86 | # apply a custom image transform before resizing the image to 299x299 87 | img_np = self.custom_image_tranform(img_np) 88 | # fn_resize expects a np array and returns a np array 89 | img_resized = self.fn_resize(img_np) 90 | 91 | # ToTensor() converts to [0,1] only if input in uint8 92 | if img_resized.dtype == "uint8": 93 | img_t = self.transforms(np.array(img_resized))*255 94 | elif img_resized.dtype == "float32": 95 | img_t = self.transforms(img_resized) 96 | 97 | return img_t 98 | -------------------------------------------------------------------------------- /main/coco_eval/cleanfid/wrappers.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import numpy as np 3 | import torch 4 | from cleanfid.features import build_feature_extractor, get_reference_statistics 5 | from cleanfid.fid import get_batch_features, fid_from_feats 6 | from cleanfid.resize import build_resizer 7 | 8 | 9 | """ 10 | A helper class that allowing adding the images one batch at a time. 11 | """ 12 | class CleanFID(): 13 | def __init__(self, mode="clean", model_name="inception_v3", device="cuda"): 14 | self.real_features = [] 15 | self.gen_features = [] 16 | self.mode = mode 17 | self.device = device 18 | if model_name=="inception_v3": 19 | self.feat_model = build_feature_extractor(mode, device) 20 | self.fn_resize = build_resizer(mode) 21 | elif model_name=="clip_vit_b_32": 22 | from cleanfid.clip_features import CLIP_fx, img_preprocess_clip 23 | clip_fx = CLIP_fx("ViT-B/32") 24 | self.feat_model = clip_fx 25 | self.fn_resize = img_preprocess_clip 26 | 27 | """ 28 | Funtion that takes an image (PIL.Image or np.array or torch.tensor) 29 | and returns the corresponding feature embedding vector. 
30 | The image x is expected to be in range [0, 255] 31 | """ 32 | def compute_features(self, x): 33 | # if x is a PIL Image 34 | if isinstance(x, Image.Image): 35 | x_np = np.array(x) 36 | x_np_resized = self.fn_resize(x_np) 37 | x_t = torch.tensor(x_np_resized.transpose((2, 0, 1))).unsqueeze(0) 38 | x_feat = get_batch_features(x_t, self.feat_model, self.device) 39 | elif isinstance(x, np.ndarray): 40 | x_np_resized = self.fn_resize(x) 41 | x_t = torch.tensor(x_np_resized.transpose((2, 0, 1))).unsqueeze(0).to(self.device) 42 | # normalization happens inside the self.feat_model, expected image range here is [0,255] 43 | x_feat = get_batch_features(x_t, self.feat_model, self.device) 44 | elif isinstance(x, torch.Tensor): 45 | # pdb.set_trace() 46 | # add the batch dimension if x is passed in as C,H,W 47 | if len(x.shape)==3: 48 | x = x.unsqueeze(0) 49 | b,c,h,w = x.shape 50 | # convert back to np array and resize 51 | l_x_np_resized = [] 52 | for _ in range(b): 53 | x_np = x[_].cpu().numpy().transpose((1, 2, 0)) 54 | l_x_np_resized.append(self.fn_resize(x_np)[None,]) 55 | x_np_resized = np.concatenate(l_x_np_resized) 56 | x_t = torch.tensor(x_np_resized.transpose((0,3,1,2))).to(self.device) 57 | # normalization happens inside the self.feat_model, expected image range here is [0,255] 58 | x_feat = get_batch_features(x_t, self.feat_model, self.device) 59 | else: 60 | raise ValueError("image type could not be inferred") 61 | return x_feat 62 | 63 | """ 64 | Extract the faetures from x and add to the list of reference real images 65 | """ 66 | def add_real_images(self, x): 67 | x_feat = self.compute_features(x) 68 | self.real_features.append(x_feat) 69 | 70 | """ 71 | Extract the faetures from x and add to the list of generated images 72 | """ 73 | def add_gen_images(self, x): 74 | x_feat = self.compute_features(x) 75 | self.gen_features.append(x_feat) 76 | 77 | """ 78 | Compute FID between the real and generated images added so far 79 | """ 80 | def calculate_fid(self, verbose=True): 81 | feats1 = np.concatenate(self.real_features) 82 | feats2 = np.concatenate(self.gen_features) 83 | if verbose: 84 | print(f"# real images = {feats1.shape[0]}") 85 | print(f"# generated images = {feats2.shape[0]}") 86 | return fid_from_feats(feats1, feats2) 87 | 88 | """ 89 | Remove the real image features added so far 90 | """ 91 | def reset_real_features(self): 92 | self.real_features = [] 93 | 94 | """ 95 | Remove the generated image features added so far 96 | """ 97 | def reset_gen_features(self): 98 | self.gen_features = [] 99 | -------------------------------------------------------------------------------- /main/coco_eval/coco_evaluator.py: -------------------------------------------------------------------------------- 1 | # Part of this code is modified from GigaGAN: https://github.com/mingukkang/GigaGAN 2 | # The MIT License (MIT) 3 | from torchvision.transforms import InterpolationMode 4 | import torchvision.transforms as transforms 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data import Dataset 7 | from PIL import Image 8 | import numpy as np 9 | import shutil 10 | import torch 11 | import time 12 | import os 13 | 14 | resizer_collection = {"nearest": InterpolationMode.NEAREST, 15 | "box": InterpolationMode.BOX, 16 | "bilinear": InterpolationMode.BILINEAR, 17 | "hamming": InterpolationMode.HAMMING, 18 | "bicubic": InterpolationMode.BICUBIC, 19 | "lanczos": InterpolationMode.LANCZOS} 20 | 21 | 22 | class CenterCropLongEdge(object): 23 | """ 24 | this code is borrowed from 
https://github.com/ajbrock/BigGAN-PyTorch 25 | MIT License 26 | Copyright (c) 2019 Andy Brock 27 | """ 28 | def __call__(self, img): 29 | return transforms.functional.center_crop(img, min(img.size)) 30 | 31 | def __repr__(self): 32 | return self.__class__.__name__ 33 | 34 | 35 | @torch.no_grad() 36 | def compute_fid(fake_arr, gt_dir, device, 37 | resize_size=None, feature_extractor="inception", 38 | patch_fid=False): 39 | from main.coco_eval.cleanfid import fid 40 | center_crop_trsf = CenterCropLongEdge() 41 | def resize_and_center_crop(image_np): 42 | image_pil = Image.fromarray(image_np) 43 | if patch_fid: 44 | # if image_pil.size[0] != 1024 and image_pil.size[1] != 1024: 45 | # image_pil = image_pil.resize([1024, 1024]) 46 | 47 | # directly crop to the 299 x 299 patch expected by the inception network 48 | if image_pil.size[0] >= 299 and image_pil.size[1] >= 299: 49 | image_pil = transforms.functional.center_crop(image_pil, 299) 50 | # else: 51 | # raise ValueError("Image is too small to crop to 299 x 299") 52 | else: 53 | image_pil = center_crop_trsf(image_pil) 54 | 55 | if resize_size is not None: 56 | image_pil = image_pil.resize((resize_size, resize_size), 57 | Image.LANCZOS) 58 | return np.array(image_pil) 59 | 60 | if feature_extractor == "inception": 61 | model_name = "inception_v3" 62 | elif feature_extractor == "clip": 63 | model_name = "clip_vit_b_32" 64 | else: 65 | raise ValueError( 66 | "Unrecognized feature extractor [%s]" % feature_extractor) 67 | # fid, fake_feats, real_feats = fid.compute_fid( 68 | fid = fid.compute_fid( 69 | None, 70 | gt_dir, 71 | model_name=model_name, 72 | custom_image_tranform=resize_and_center_crop, 73 | use_dataparallel=False, 74 | device=device, 75 | pred_arr=fake_arr 76 | ) 77 | # return fid, fake_feats, real_feats 78 | return fid 79 | 80 | def evaluate_model(args, device, all_images, patch_fid=False): 81 | fid = compute_fid( 82 | fake_arr=all_images, 83 | gt_dir=args.ref_dir, 84 | device=device, 85 | resize_size=args.eval_res, 86 | feature_extractor="inception", 87 | patch_fid=patch_fid 88 | ) 89 | 90 | return fid 91 | 92 | 93 | def tensor2pil(image: torch.Tensor): 94 | ''' output image : tensor to PIL 95 | ''' 96 | if isinstance(image, list) or image.ndim == 4: 97 | return [tensor2pil(im) for im in image] 98 | 99 | assert image.ndim == 3 100 | output_image = Image.fromarray(((image + 1.0) * 127.5).clamp( 101 | 0.0, 255.0).to(torch.uint8).permute(1, 2, 0).detach().cpu().numpy()) 102 | return output_image 103 | 104 | class CLIPScoreDataset(Dataset): 105 | def __init__(self, images, captions, transform, preprocessor) -> None: 106 | super().__init__() 107 | self.images = images 108 | self.captions = captions 109 | self.transform = transform 110 | self.preprocessor = preprocessor 111 | 112 | def __len__(self): 113 | return len(self.images) 114 | 115 | def __getitem__(self, index): 116 | image = self.images[index] 117 | image_pil = self.transform(image) 118 | image_pil = self.preprocessor(image_pil) 119 | caption = self.captions[index] 120 | return image_pil, caption 121 | 122 | 123 | @torch.no_grad() 124 | def compute_clip_score( 125 | images, captions, clip_model="ViT-B/32", device="cuda", how_many=30000): 126 | print("Computing CLIP score") 127 | import clip as openai_clip 128 | if clip_model == "ViT-B/32": 129 | clip, clip_preprocessor = openai_clip.load("ViT-B/32", device=device) 130 | clip = clip.eval() 131 | elif clip_model == "ViT-G/14": 132 | import open_clip 133 | clip, _, clip_preprocessor = 
open_clip.create_model_and_transforms("ViT-g-14", pretrained="laion2b_s12b_b42k") 134 | clip = clip.to(device) 135 | clip = clip.eval() 136 | clip = clip.float() 137 | else: 138 | raise NotImplementedError 139 | 140 | def resize_and_center_crop(image_np, resize_size=256): 141 | image_pil = Image.fromarray(image_np) 142 | image_pil = CenterCropLongEdge()(image_pil) 143 | 144 | if resize_size is not None: 145 | image_pil = image_pil.resize((resize_size, resize_size), 146 | Image.LANCZOS) 147 | return image_pil 148 | 149 | def simple_collate(batch): 150 | images, captions = [], [] 151 | for img, cap in batch: 152 | images.append(img) 153 | captions.append(cap) 154 | return images, captions 155 | 156 | 157 | dataset = CLIPScoreDataset( 158 | images, captions, transform=resize_and_center_crop, 159 | preprocessor=clip_preprocessor 160 | ) 161 | dataloader = DataLoader( 162 | dataset, batch_size=64, 163 | shuffle=False, num_workers=8, 164 | collate_fn=simple_collate 165 | 166 | ) 167 | 168 | cos_sims = [] 169 | count = 0 170 | # for imgs, txts in zip(images, captions): 171 | for index, (imgs_pil, txts) in enumerate(dataloader): 172 | # imgs_pil = [resize_and_center_crop(imgs)] 173 | # txts = [txts] 174 | # imgs_pil = [clip_preprocessor(img) for img in imgs] 175 | imgs = torch.stack(imgs_pil, dim=0).to(device) 176 | tokens = openai_clip.tokenize(txts, truncate=True).to(device) 177 | # Prepending text prompts with "A photo depicts " 178 | # https://arxiv.org/abs/2104.08718 179 | prepend_text = "A photo depicts " 180 | prepend_text_token = openai_clip.tokenize(prepend_text)[:, 1:4].to(device) 181 | prepend_text_tokens = prepend_text_token.expand(tokens.shape[0], -1) 182 | 183 | start_tokens = tokens[:, :1] 184 | new_text_tokens = torch.cat( 185 | [start_tokens, prepend_text_tokens, tokens[:, 1:]], dim=1)[:, :77] 186 | last_cols = new_text_tokens[:, 77 - 1:77] 187 | last_cols[last_cols > 0] = 49407 # eot token 188 | new_text_tokens = torch.cat([new_text_tokens[:, :76], last_cols], dim=1) 189 | 190 | img_embs = clip.encode_image(imgs) 191 | text_embs = clip.encode_text(new_text_tokens) 192 | 193 | similarities = torch.nn.functional.cosine_similarity(img_embs, text_embs, dim=1) 194 | cos_sims.append(similarities) 195 | count += similarities.shape[0] 196 | if count >= how_many: 197 | break 198 | 199 | clip_score = torch.cat(cos_sims, dim=0)[:how_many].mean() 200 | clip_score = clip_score.detach().cpu().numpy() 201 | return clip_score 202 | 203 | @torch.no_grad() 204 | def compute_image_reward( 205 | images, captions, device 206 | ): 207 | import ImageReward as RM 208 | from tqdm import tqdm 209 | model = RM.load("ImageReward-v1.0", device=device) 210 | rewards = [] 211 | for image, prompt in tqdm(zip(images, captions)): 212 | reward = model.score(prompt, Image.fromarray(image)) 213 | rewards.append(reward) 214 | return np.mean(np.array(rewards)) 215 | 216 | @torch.no_grad() 217 | def compute_diversity_score( 218 | lpips_loss_func, images, device 219 | ): 220 | # resize all image to 512 and convert to tensor 221 | images = [Image.fromarray(image) for image in images] 222 | images = [image.resize((512, 512), Image.LANCZOS) for image in images] 223 | images = np.stack([np.array(image) for image in images], axis=0) 224 | images = torch.tensor(images).to(device).float() / 255.0 225 | images = images.permute(0, 3, 1, 2) 226 | 227 | num_images = images.shape[0] 228 | loss_list = [] 229 | 230 | for i in range(num_images): 231 | for j in range(i+1, num_images): 232 | image1 = images[i].unsqueeze(0) 233 | image2 
= images[j].unsqueeze(0) 234 | loss = lpips_loss_func(image1, image2) 235 | 236 | loss_list.append(loss.item()) 237 | return np.mean(loss_list) 238 | -------------------------------------------------------------------------------- /main/data/create_imagenet_lmdb.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from PIL import Image 3 | import numpy as np 4 | import argparse 5 | import torch 6 | import lmdb 7 | import glob 8 | import json 9 | import os 10 | 11 | def store_arrays_to_lmdb(env, arrays_dict, start_index=0): 12 | """ 13 | Store rows of multiple numpy arrays in a single LMDB. 14 | Each row is stored separately with a naming convention. 15 | """ 16 | with env.begin(write=True) as txn: 17 | for array_name, array in arrays_dict.items(): 18 | for i, row in enumerate(array): 19 | # Convert row to bytes 20 | row_bytes = row.tobytes() 21 | data_key = f'{array_name}_{start_index+i}_data'.encode() 22 | txn.put(data_key, row_bytes) 23 | 24 | def get_array_shape_from_lmdb(lmdb_path, array_name): 25 | with lmdb.open(lmdb_path) as env: 26 | with env.begin() as txn: 27 | image_shape = txn.get(f"{array_name}_shape".encode()).decode() 28 | image_shape = tuple(map(int, image_shape.split())) 29 | 30 | return image_shape 31 | 32 | def load_ode_file(ode_file): 33 | ode_dict = torch.load(ode_file) 34 | 35 | ode_dict.pop('prompt_list', None) # Remove 'prompt_list' if exists 36 | ode_dict.pop('batch_index', None) # Remove 'batch_index' if exists 37 | 38 | return ode_dict 39 | 40 | # Example usage: 41 | def main(): 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--data_path", type=str, required=True, help="path to imagenet") 44 | parser.add_argument("--lmdb_path", type=str, required=True, help="path to lmdb") 45 | 46 | args = parser.parse_args() 47 | 48 | # figure out the maximum map size needed 49 | total_array_size = 1000000000000 # adapt to your need, set to 1TB by default 50 | env = lmdb.open(args.lmdb_path, map_size=total_array_size * 2) 51 | 52 | # Load all images and labels 53 | all_image_folders = sorted(glob.glob(os.path.join(args.data_path, "*"))) 54 | 55 | label_path = os.path.join(args.data_path, "dataset.json") 56 | 57 | labels = json.load(open(label_path))["labels"] 58 | labels = dict(labels) 59 | 60 | counter = 0 61 | 62 | # dump to lmdb 63 | for image_folder in tqdm(all_image_folders): 64 | image_files = sorted(glob.glob(os.path.join(image_folder, "*.png"))) 65 | 66 | if not os.path.isdir(image_folder) or len(image_files) == 0: 67 | continue 68 | 69 | image_list = [] 70 | label_list = [] 71 | for image_file in image_files: 72 | image = np.array(Image.open(image_file)) 73 | image = image.transpose(2, 0, 1) 74 | image_list.append(image) 75 | 76 | label_key = os.path.join(*image_file.split("/")[-2:]) 77 | label = labels[label_key] 78 | label_list.append(label) 79 | 80 | image_list = np.stack(image_list, axis=0) 81 | label_list = np.array(label_list) 82 | 83 | data_dict = { 84 | 'images': image_list, 85 | 'labels': label_list 86 | } 87 | 88 | store_arrays_to_lmdb(env, data_dict, start_index=counter) 89 | counter += len(image_list) 90 | 91 | # save each entry's shape to lmdb 92 | with env.begin(write=True) as txn: 93 | for key, val in data_dict.items(): 94 | print(key, val) 95 | array_shape = np.array(val.shape) 96 | array_shape[0] = counter 97 | 98 | shape_key = f"{key}_shape".encode() 99 | shape_str = " ".join(map(str, array_shape)) 100 | txn.put(shape_key, shape_str.encode()) 101 | 102 | if __name__ == 
"__main__": 103 | main() 104 | -------------------------------------------------------------------------------- /main/data/create_lmdb_iterative.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import argparse 4 | import torch 5 | import lmdb 6 | import glob 7 | import os 8 | 9 | def store_arrays_to_lmdb(env, arrays_dict, start_index=0): 10 | """ 11 | Store rows of multiple numpy arrays in a single LMDB. 12 | Each row is stored separately with a naming convention. 13 | """ 14 | with env.begin(write=True) as txn: 15 | for array_name, array in tqdm(arrays_dict.items()): 16 | for i, row in enumerate(array): 17 | # Convert row to bytes 18 | if isinstance(row, str): 19 | row_bytes = row.encode() 20 | else: 21 | row_bytes = row.tobytes() 22 | data_key = f'{array_name}_{start_index+i}_data'.encode() 23 | txn.put(data_key, row_bytes) 24 | 25 | def get_array_shape_from_lmdb(lmdb_path, array_name): 26 | with lmdb.open(lmdb_path) as env: 27 | with env.begin() as txn: 28 | image_shape = txn.get(f"{array_name}_shape".encode()).decode() 29 | image_shape = tuple(map(int, image_shape.split())) 30 | 31 | return image_shape 32 | 33 | def load_ode_file(ode_file): 34 | ode_dict = torch.load(ode_file) 35 | 36 | ode_dict.pop('prompt_list', None) # Remove 'prompt_list' if exists 37 | ode_dict.pop('batch_index', None) # Remove 'batch_index' if exists 38 | 39 | return ode_dict 40 | 41 | # Example usage: 42 | def main(): 43 | parser = argparse.ArgumentParser() 44 | parser.add_argument("--data_path", type=str, required=True, help="path to ode pairs") 45 | parser.add_argument("--lmdb_path", type=str, required=True, help="path to lmdb") 46 | 47 | args = parser.parse_args() 48 | 49 | all_files = sorted(glob.glob(os.path.join(args.data_path, "*.pt"))) 50 | 51 | # figure out the maximum map size needed 52 | total_array_size = 5000000000000 # adapt to your need, set to 5TB by default 53 | 54 | env = lmdb.open(args.lmdb_path, map_size=total_array_size * 2) 55 | 56 | counter = 0 57 | 58 | for index, file in tqdm(enumerate(all_files)): 59 | # read from disk 60 | data_dict = load_ode_file(file) 61 | 62 | # write to lmdb file 63 | store_arrays_to_lmdb(env, data_dict, start_index=counter) 64 | counter += len(data_dict['latents']) 65 | 66 | # save each entry's shape to lmdb 67 | with env.begin(write=True) as txn: 68 | for key, val in data_dict.items(): 69 | print(key, val) 70 | array_shape = np.array(val.shape) 71 | array_shape[0] = counter 72 | 73 | shape_key = f"{key}_shape".encode() 74 | shape_str = " ".join(map(str, array_shape)) 75 | txn.put(shape_key, shape_str.encode()) 76 | 77 | 78 | if __name__ == "__main__": 79 | main() 80 | -------------------------------------------------------------------------------- /main/data/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | from main.utils import retrieve_row_from_lmdb, get_array_shape_from_lmdb 2 | from torch.utils.data import Dataset 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | import lmdb 7 | import glob 8 | import os 9 | 10 | 11 | 12 | class LMDBDataset(Dataset): 13 | # LMDB version of an ImageDataset. It is suitable for large datasets. 
14 | def __init__(self, dataset_path): 15 | # for supporting new datasets, please adapt the data type according to the one used in "main/data/create_imagenet_lmdb.py" 16 | self.KEY_TO_TYPE = { 17 | 'labels': np.int64, 18 | 'images': np.uint8, 19 | } 20 | 21 | self.dataset_path = dataset_path 22 | 23 | self.env = lmdb.open(dataset_path, readonly=True, lock=False, readahead=False, meminit=False) 24 | 25 | self.image_shape = get_array_shape_from_lmdb(self.env, 'images') 26 | self.label_shape = get_array_shape_from_lmdb(self.env, 'labels') 27 | 28 | def __len__(self): 29 | return self.image_shape[0] 30 | 31 | def __getitem__(self, idx): 32 | # final ground truth rgb image 33 | image = retrieve_row_from_lmdb( 34 | self.env, 35 | "images", self.KEY_TO_TYPE['images'], self.image_shape[1:], idx 36 | ) 37 | image = torch.tensor(image, dtype=torch.float32) 38 | 39 | label = retrieve_row_from_lmdb( 40 | self.env, 41 | "labels", self.KEY_TO_TYPE['labels'], self.label_shape[1:], idx 42 | ) 43 | 44 | label = torch.tensor(label, dtype=torch.long) 45 | image = (image / 255.0) 46 | 47 | 48 | output_dict = { 49 | 'images': image, 50 | 'class_labels': label 51 | } 52 | 53 | return output_dict 54 | -------------------------------------------------------------------------------- /main/edm/edm_guidance.py: -------------------------------------------------------------------------------- 1 | from main.edm.edm_network import get_edm_network 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import dnnlib 5 | import pickle 6 | import torch 7 | import copy 8 | 9 | def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0): 10 | # from https://github.com/crowsonkb/k-diffusion 11 | ramp = torch.linspace(0, 1, n) 12 | min_inv_rho = sigma_min ** (1 / rho) 13 | max_inv_rho = sigma_max ** (1 / rho) 14 | sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho 15 | return sigmas 16 | 17 | class EDMGuidance(nn.Module): 18 | def __init__(self, args, accelerator): 19 | super().__init__() 20 | self.args = args 21 | self.accelerator = accelerator 22 | 23 | with dnnlib.util.open_url(args.model_id) as f: 24 | temp_edm = pickle.load(f)['ema'] 25 | 26 | # initialize the real unet 27 | self.real_unet = get_edm_network(args) 28 | self.real_unet.load_state_dict(temp_edm.state_dict(), strict=True) 29 | self.real_unet.requires_grad_(False) 30 | del self.real_unet.model.map_augment 31 | self.real_unet.model.map_augment = None 32 | 33 | # initialize the fake unet 34 | self.fake_unet = copy.deepcopy(self.real_unet) 35 | self.fake_unet.requires_grad_(True) 36 | 37 | # some training hyper-parameters 38 | self.sigma_data = args.sigma_data 39 | self.sigma_max = args.sigma_max 40 | self.sigma_min = args.sigma_min 41 | self.rho = args.rho 42 | 43 | self.gan_classifier = args.gan_classifier 44 | self.diffusion_gan = args.diffusion_gan 45 | self.diffusion_gan_max_timestep = args.diffusion_gan_max_timestep 46 | 47 | if self.gan_classifier: 48 | self.cls_pred_branch = nn.Sequential( 49 | nn.Conv2d(kernel_size=4, in_channels=768, out_channels=768, stride=2, padding=1), # 8x8 -> 4x4 50 | nn.GroupNorm(num_groups=32, num_channels=768), 51 | nn.SiLU(), 52 | nn.Conv2d(kernel_size=4, in_channels=768, out_channels=768, stride=4, padding=0), # 4x4 -> 1x1 53 | nn.GroupNorm(num_groups=32, num_channels=768), 54 | nn.SiLU(), 55 | nn.Conv2d(kernel_size=1, in_channels=768, out_channels=1, stride=1, padding=0), # 1x1 -> 1x1 56 | ) 57 | self.cls_pred_branch.requires_grad_(True) 58 | 59 | self.num_train_timesteps = args.num_train_timesteps 60 | 
# small sigma first, large sigma later 61 | karras_sigmas = torch.flip( 62 | get_sigmas_karras(self.num_train_timesteps, sigma_max=self.sigma_max, sigma_min=self.sigma_min, 63 | rho=self.rho 64 | ), 65 | dims=[0] 66 | ) 67 | self.register_buffer("karras_sigmas", karras_sigmas) 68 | 69 | self.min_step = int(args.min_step_percent * self.num_train_timesteps) 70 | self.max_step = int(args.max_step_percent * self.num_train_timesteps) 71 | del temp_edm 72 | 73 | def compute_distribution_matching_loss( 74 | self, 75 | latents, 76 | labels 77 | ): 78 | original_latents = latents 79 | batch_size = latents.shape[0] 80 | 81 | with torch.no_grad(): 82 | timesteps = torch.randint( 83 | self.min_step, 84 | min(self.max_step+1, self.num_train_timesteps), 85 | [batch_size, 1, 1, 1], 86 | device=latents.device, 87 | dtype=torch.long 88 | ) 89 | 90 | noise = torch.randn_like(latents) 91 | 92 | timestep_sigma = self.karras_sigmas[timesteps] 93 | 94 | noisy_latents = latents + timestep_sigma.reshape(-1, 1, 1, 1) * noise 95 | 96 | pred_real_image = self.real_unet(noisy_latents, timestep_sigma, labels) 97 | 98 | pred_fake_image = self.fake_unet( 99 | noisy_latents, timestep_sigma, labels 100 | ) 101 | 102 | p_real = (latents - pred_real_image) 103 | p_fake = (latents - pred_fake_image) 104 | 105 | weight_factor = torch.abs(p_real).mean(dim=[1, 2, 3], keepdim=True) 106 | grad = (p_real - p_fake) / weight_factor 107 | 108 | grad = torch.nan_to_num(grad) 109 | 110 | # this loss gives the grad as gradient through autodiff, following https://github.com/ashawkey/stable-dreamfusion 111 | loss = 0.5 * F.mse_loss(original_latents, (original_latents-grad).detach(), reduction="mean") 112 | 113 | loss_dict = { 114 | "loss_dm": loss 115 | } 116 | 117 | dm_log_dict = { 118 | "dmtrain_noisy_latents": noisy_latents.detach(), 119 | "dmtrain_pred_real_image": pred_real_image.detach(), 120 | "dmtrain_pred_fake_image": pred_fake_image.detach(), 121 | "dmtrain_grad": grad.detach(), 122 | "dmtrain_gradient_norm": torch.norm(grad).item(), 123 | "dmtrain_timesteps": timesteps.detach(), 124 | } 125 | return loss_dict, dm_log_dict 126 | 127 | def compute_loss_fake( 128 | self, 129 | latents, 130 | labels, 131 | ): 132 | batch_size = latents.shape[0] 133 | 134 | latents = latents.detach() # no gradient to generator 135 | 136 | noise = torch.randn_like(latents) 137 | 138 | timesteps = torch.randint( 139 | 0, 140 | self.num_train_timesteps, 141 | [batch_size, 1, 1, 1], 142 | device=latents.device, 143 | dtype=torch.long 144 | ) 145 | timestep_sigma = self.karras_sigmas[timesteps] 146 | noisy_latents = latents + timestep_sigma.reshape(-1, 1, 1, 1) * noise 147 | 148 | fake_x0_pred = self.fake_unet( 149 | noisy_latents, timestep_sigma, labels 150 | ) 151 | 152 | snrs = timestep_sigma**-2 153 | 154 | # weight_schedule karras 155 | weights = snrs + 1.0 / self.sigma_data**2 156 | 157 | target = latents 158 | 159 | loss_fake = torch.mean( 160 | weights * (fake_x0_pred - target)**2 161 | ) 162 | 163 | loss_dict = { 164 | "loss_fake_mean": loss_fake 165 | } 166 | 167 | fake_log_dict = { 168 | "faketrain_latents": latents.detach(), 169 | "faketrain_noisy_latents": noisy_latents.detach(), 170 | "faketrain_x0_pred": fake_x0_pred.detach() 171 | } 172 | return loss_dict, fake_log_dict 173 | 174 | def compute_cls_logits(self, image, label): 175 | if self.diffusion_gan: 176 | timesteps = torch.randint( 177 | 0, self.diffusion_gan_max_timestep, [image.shape[0]], device=image.device, dtype=torch.long 178 | ) 179 | timestep_sigma = 
self.karras_sigmas[timesteps] 180 | image = image + timestep_sigma.reshape(-1, 1, 1, 1) * torch.randn_like(image) 181 | else: 182 | timesteps = torch.zeros([image.shape[0]], dtype=torch.long, device=image.device) 183 | timestep_sigma = self.karras_sigmas[timesteps] 184 | 185 | rep = self.fake_unet( 186 | image, timestep_sigma, label, return_bottleneck=True 187 | ).float() 188 | 189 | logits = self.cls_pred_branch(rep).squeeze(dim=[2, 3]) 190 | return logits 191 | 192 | def compute_generator_clean_cls_loss(self, fake_image, fake_labels): 193 | loss_dict = {} 194 | 195 | pred_realism_on_fake_with_grad = self.compute_cls_logits( 196 | image=fake_image, 197 | label=fake_labels 198 | ) 199 | loss_dict["gen_cls_loss"] = F.softplus(-pred_realism_on_fake_with_grad).mean() 200 | return loss_dict 201 | 202 | def compute_guidance_clean_cls_loss(self, real_image, fake_image, real_label, fake_label): 203 | pred_realism_on_real = self.compute_cls_logits( 204 | real_image.detach(), real_label, 205 | ) 206 | pred_realism_on_fake = self.compute_cls_logits( 207 | fake_image.detach(), fake_label, 208 | ) 209 | classification_loss = F.softplus(pred_realism_on_fake) + F.softplus(-pred_realism_on_real) 210 | 211 | log_dict = { 212 | "pred_realism_on_real": torch.sigmoid(pred_realism_on_real).squeeze(dim=1).detach(), 213 | "pred_realism_on_fake": torch.sigmoid(pred_realism_on_fake).squeeze(dim=1).detach() 214 | } 215 | 216 | loss_dict = { 217 | "guidance_cls_loss": classification_loss.mean() 218 | } 219 | return loss_dict, log_dict 220 | 221 | def generator_forward( 222 | self, 223 | image, 224 | labels 225 | ): 226 | loss_dict = {} 227 | log_dict = {} 228 | 229 | # image.requires_grad_(True) 230 | dm_dict, dm_log_dict = self.compute_distribution_matching_loss(image, labels) 231 | 232 | loss_dict.update(dm_dict) 233 | log_dict.update(dm_log_dict) 234 | 235 | if self.gan_classifier: 236 | clean_cls_loss_dict = self.compute_generator_clean_cls_loss(image, labels) 237 | loss_dict.update(clean_cls_loss_dict) 238 | 239 | # loss_dm = loss_dict["loss_dm"] 240 | # gen_cls_loss = loss_dict["gen_cls_loss"] 241 | 242 | # grad_dm = torch.autograd.grad(loss_dm, image, retain_graph=True)[0] 243 | # grad_cls = torch.autograd.grad(gen_cls_loss, image, retain_graph=True)[0] 244 | 245 | # print(f"dm {grad_dm.abs().mean()} cls {grad_cls.abs().mean()}") 246 | 247 | return loss_dict, log_dict 248 | 249 | def guidance_forward( 250 | self, 251 | image, 252 | labels, 253 | real_train_dict=None 254 | ): 255 | fake_dict, fake_log_dict = self.compute_loss_fake( 256 | image, labels 257 | ) 258 | 259 | loss_dict = fake_dict 260 | log_dict = fake_log_dict 261 | 262 | if self.gan_classifier: 263 | clean_cls_loss_dict, clean_cls_log_dict = self.compute_guidance_clean_cls_loss( 264 | real_image=real_train_dict['real_image'], 265 | fake_image=image, 266 | real_label=real_train_dict['real_label'], 267 | fake_label=labels 268 | ) 269 | loss_dict.update(clean_cls_loss_dict) 270 | log_dict.update(clean_cls_log_dict) 271 | return loss_dict, log_dict 272 | 273 | def forward( 274 | self, 275 | generator_turn=False, 276 | guidance_turn=False, 277 | generator_data_dict=None, 278 | guidance_data_dict=None 279 | ): 280 | if generator_turn: 281 | loss_dict, log_dict = self.generator_forward( 282 | image=generator_data_dict['image'], 283 | labels=generator_data_dict['label'] 284 | ) 285 | elif guidance_turn: 286 | loss_dict, log_dict = self.guidance_forward( 287 | image=guidance_data_dict['image'], 288 | labels=guidance_data_dict['label'], 289 | 
real_train_dict=guidance_data_dict['real_train_dict'] 290 | ) 291 | else: 292 | raise NotImplementedError 293 | 294 | return loss_dict, log_dict -------------------------------------------------------------------------------- /main/edm/edm_network.py: -------------------------------------------------------------------------------- 1 | from third_party.edm.training.networks import EDMPrecond 2 | 3 | def get_imagenet_edm_config(): 4 | return dict( 5 | augment_dim=0, 6 | model_channels=192, 7 | channel_mult=[1, 2, 3, 4], 8 | channel_mult_emb=4, 9 | num_blocks=3, 10 | attn_resolutions=[32,16,8], 11 | dropout=0.0, 12 | label_dropout=0 13 | ) 14 | 15 | def get_edm_network(args): 16 | if args.dataset_name == "imagenet": 17 | unet = EDMPrecond( 18 | img_resolution=args.resolution, 19 | img_channels=3, 20 | label_dim=args.label_dim, 21 | use_fp16=args.use_fp16, 22 | sigma_min=0, 23 | sigma_max=float("inf"), 24 | sigma_data=args.sigma_data, 25 | model_type="DhariwalUNet", 26 | **get_imagenet_edm_config() 27 | ) 28 | else: 29 | raise NotImplementedError 30 | 31 | return unet -------------------------------------------------------------------------------- /main/edm/edm_unified_model.py: -------------------------------------------------------------------------------- 1 | # A single unified model that wraps both the generator and discriminator 2 | from main.edm.edm_guidance import EDMGuidance 3 | from torch import nn 4 | import torch 5 | import copy 6 | 7 | class EDMUniModel(nn.Module): 8 | def __init__(self, args, accelerator): 9 | super().__init__() 10 | 11 | self.guidance_model = EDMGuidance(args, accelerator) 12 | 13 | self.guidance_min_step = self.guidance_model.min_step 14 | self.guidance_max_step = self.guidance_model.max_step 15 | 16 | if args.initialie_generator: 17 | self.feedforward_model = copy.deepcopy(self.guidance_model.fake_unet) 18 | else: 19 | raise NotImplementedError("Only support initializing generator from guidance model.") 20 | 21 | self.feedforward_model.requires_grad_(True) 22 | 23 | self.accelerator = accelerator 24 | self.num_train_timesteps = args.num_train_timesteps 25 | 26 | def forward(self, scaled_noisy_image, 27 | timestep_sigma, labels, 28 | real_train_dict=None, 29 | compute_generator_gradient=False, 30 | generator_turn=False, 31 | guidance_turn=False, 32 | guidance_data_dict=None 33 | ): 34 | assert (generator_turn and not guidance_turn) or (guidance_turn and not generator_turn) 35 | 36 | if generator_turn: 37 | if not compute_generator_gradient: 38 | with torch.no_grad(): 39 | generated_image = self.feedforward_model(scaled_noisy_image, timestep_sigma, labels) 40 | else: 41 | generated_image = self.feedforward_model(scaled_noisy_image, timestep_sigma, labels) 42 | 43 | if compute_generator_gradient: 44 | generator_data_dict = { 45 | "image": generated_image, 46 | "label": labels, 47 | "real_train_dict": real_train_dict 48 | } 49 | 50 | # as we don't need to compute gradient for guidance model 51 | # we disable gradient to avoid side effects (in GAN Loss computation) 52 | self.guidance_model.requires_grad_(False) 53 | loss_dict, log_dict = self.guidance_model( 54 | generator_turn=True, 55 | guidance_turn=False, 56 | generator_data_dict=generator_data_dict 57 | ) 58 | self.guidance_model.requires_grad_(True) 59 | else: 60 | loss_dict = {} 61 | log_dict = {} 62 | 63 | log_dict['generated_image'] = generated_image.detach() 64 | 65 | log_dict['guidance_data_dict'] = { 66 | "image": generated_image.detach(), 67 | "label": labels.detach(), 68 | "real_train_dict": 
real_train_dict 69 | } 70 | 71 | elif guidance_turn: 72 | assert guidance_data_dict is not None 73 | loss_dict, log_dict = self.guidance_model( 74 | generator_turn=False, 75 | guidance_turn=True, 76 | guidance_data_dict=guidance_data_dict 77 | ) 78 | 79 | return loss_dict, log_dict 80 | 81 | 82 | 83 | 84 | -------------------------------------------------------------------------------- /main/edm/test_folder_edm.py: -------------------------------------------------------------------------------- 1 | from third_party.edm.training.networks import EDMPrecond 2 | from main.edm.edm_network import get_imagenet_edm_config 3 | from accelerate.utils import ProjectConfiguration 4 | from accelerate.utils import set_seed 5 | from accelerate import Accelerator 6 | from tqdm import tqdm 7 | import numpy as np 8 | import argparse 9 | import dnnlib 10 | import pickle 11 | import wandb 12 | import torch 13 | import scipy 14 | import glob 15 | import json 16 | import time 17 | import os 18 | 19 | def get_imagenet_config(): 20 | base_config = { 21 | "img_resolution": 64, 22 | "img_channels": 3, 23 | "label_dim": 1000, 24 | "use_fp16": False, 25 | "sigma_min": 0, 26 | "sigma_max": float("inf"), 27 | "sigma_data": 0.5, 28 | "model_type": "DhariwalUNet" 29 | } 30 | base_config.update(get_imagenet_edm_config()) 31 | return base_config 32 | 33 | 34 | def create_generator(checkpoint_path, base_model=None): 35 | if base_model is None: 36 | base_config = get_imagenet_config() 37 | generator = EDMPrecond(**base_config) 38 | del generator.model.map_augment 39 | generator.model.map_augment = None 40 | else: 41 | generator = base_model 42 | 43 | while True: 44 | try: 45 | state_dict = torch.load(checkpoint_path, map_location="cpu") 46 | break 47 | except: 48 | print(f"fail to load checkpoint {checkpoint_path}") 49 | time.sleep(1) 50 | 51 | print(generator.load_state_dict(state_dict, strict=True)) 52 | 53 | return generator 54 | 55 | def create_evaluator(detector_url): 56 | detector_kwargs = dict(return_features=True) 57 | feature_dim = 2048 58 | with dnnlib.util.open_url(detector_url, verbose=False) as f: 59 | detector_net = pickle.load(f) 60 | 61 | detector_net.eval() 62 | return detector_net, detector_kwargs, feature_dim 63 | 64 | @torch.no_grad() 65 | def sample(accelerator, current_model, args, model_index): 66 | timesteps = torch.ones(args.eval_batch_size, device=accelerator.device, dtype=torch.long) 67 | current_model.eval() 68 | all_images = [] 69 | all_images_tensor = [] 70 | 71 | current_index = 0 72 | 73 | all_labels = torch.arange(0, args.total_eval_samples*2, 74 | device=accelerator.device, dtype=torch.long) % args.label_dim 75 | 76 | set_seed(args.seed+accelerator.process_index) 77 | 78 | while len(all_images_tensor) * args.eval_batch_size * accelerator.num_processes < args.total_eval_samples: 79 | noise = torch.randn(args.eval_batch_size, 3, 80 | args.resolution, args.resolution, device=accelerator.device 81 | ) 82 | 83 | random_labels = all_labels[current_index:current_index+args.eval_batch_size] 84 | one_hot_labels = torch.eye(args.label_dim, device=accelerator.device)[ 85 | random_labels 86 | ] 87 | 88 | current_index += args.eval_batch_size 89 | 90 | eval_images = current_model(noise * args.conditioning_sigma, timesteps * args.conditioning_sigma, one_hot_labels) 91 | eval_images = ((eval_images + 1.0) * 127.5).clamp(0, 255).to(torch.uint8).permute(0, 2, 3, 1) 92 | eval_images = eval_images.contiguous() 93 | 94 | gathered_images = accelerator.gather(eval_images) 95 | 96 | 
all_images.append(gathered_images.cpu().numpy()) 97 | all_images_tensor.append(gathered_images.cpu()) 98 | 99 | if accelerator.is_main_process: 100 | print("all_images len ", len(torch.cat(all_images_tensor, dim=0))) 101 | 102 | all_images = np.concatenate(all_images, axis=0)[:args.total_eval_samples] 103 | all_images_tensor = torch.cat(all_images_tensor, dim=0)[:args.total_eval_samples] 104 | 105 | if accelerator.is_main_process: 106 | # Uncomment if you need to save the images 107 | # np.savez(os.path.join(args.folder, f"eval_image_model_{model_index:06d}.npz"), all_images) 108 | # raise 109 | grid_size = int(args.test_visual_batch_size**(1/2)) 110 | eval_images_grid = all_images[:grid_size*grid_size].reshape(grid_size, grid_size, args.resolution, args.resolution, 3) 111 | eval_images_grid = np.swapaxes(eval_images_grid, 1, 2).reshape(grid_size*args.resolution, grid_size*args.resolution, 3) 112 | 113 | data_dict = { 114 | "generated_image_grid": wandb.Image(eval_images_grid) 115 | } 116 | 117 | data_dict['image_mean'] = all_images_tensor.float().mean().item() 118 | data_dict['image_std'] = all_images_tensor.float().std().item() 119 | 120 | wandb.log( 121 | data_dict, 122 | step=model_index 123 | ) 124 | 125 | accelerator.wait_for_everyone() 126 | return all_images_tensor 127 | 128 | @torch.no_grad() 129 | def calculate_inception_stats(all_images_tensor, evaluator, accelerator, evaluator_kwargs, feature_dim, max_batch_size): 130 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=accelerator.device) 131 | sigma = torch.ones([feature_dim, feature_dim], dtype=torch.float64, device=accelerator.device) 132 | num_batches = ((len(all_images_tensor) - 1) // (max_batch_size * accelerator.num_processes ) + 1) * accelerator.num_processes 133 | all_batches = torch.arange(len(all_images_tensor)).tensor_split(num_batches) 134 | rank_batches = all_batches[accelerator.process_index :: accelerator.num_processes] 135 | 136 | for i in tqdm(range(num_batches//accelerator.num_processes), unit='batch', disable=not accelerator.is_main_process): 137 | images = all_images_tensor[rank_batches[i]] 138 | features = evaluator(images.permute(0, 3, 1, 2).to(accelerator.device), **evaluator_kwargs).to(torch.float64) 139 | mu += features.sum(0) 140 | sigma += features.T @ features 141 | 142 | # Calculate grand totals. 
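    # At this point mu holds the per-rank sum of Inception features and sigma the per-rank
    # sum of outer products; reduce() adds them across processes before they are normalized
    # into the sample mean and (unbiased) covariance used for FID.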
143 | mu = accelerator.reduce(mu) 144 | sigma = accelerator.reduce(sigma) 145 | mu /= len(all_images_tensor) 146 | sigma -= mu.ger(mu) * len(all_images_tensor) 147 | sigma /= len(all_images_tensor) - 1 148 | return mu.cpu().numpy(), sigma.cpu().numpy() 149 | 150 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref): 151 | m = np.square(mu - mu_ref).sum() 152 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False) 153 | fid = m + np.trace(sigma + sigma_ref - s * 2) 154 | return float(np.real(fid)) 155 | 156 | @torch.no_grad() 157 | def evaluate(): 158 | parser = argparse.ArgumentParser() 159 | parser.add_argument("--folder", type=str, required=True, help="pass to folder") 160 | parser.add_argument("--wandb_entity", type=str) 161 | parser.add_argument("--wandb_project", type=str) 162 | parser.add_argument("--wandb_name", type=str) 163 | parser.add_argument("--eval_batch_size", type=int, default=128) 164 | parser.add_argument("--resolution", type=int, default=32) 165 | parser.add_argument("--total_eval_samples", type=int, default=50000) 166 | parser.add_argument("--label_dim", type=int, default=10) 167 | parser.add_argument("--sigma_max", type=float, default=80.0) 168 | parser.add_argument("--sigma_min", type=float, default=0.002) 169 | parser.add_argument("--test_visual_batch_size", type=int, default=100) 170 | parser.add_argument("--max_batch_size", type=int, default=128) 171 | parser.add_argument("--ref_path", type=str, help="reference fid statistics") 172 | parser.add_argument("--detector_url", type=str) 173 | parser.add_argument("--seed", type=int, default=10) 174 | parser.add_argument("--dataset_name", type=str, default="imagenet") 175 | parser.add_argument("--no_resume", action="store_true") 176 | parser.add_argument("--conditioning_sigma", type=float, default=80.0) 177 | 178 | args = parser.parse_args() 179 | 180 | folder = args.folder 181 | evaluated_checkpoints = set() 182 | overall_stats = {} 183 | 184 | # initialize accelerator 185 | accelerator_project_config = ProjectConfiguration(logging_dir=args.folder) 186 | accelerator = Accelerator( 187 | gradient_accumulation_steps=1, 188 | mixed_precision="no", 189 | log_with="wandb", 190 | project_config=accelerator_project_config 191 | ) 192 | print(accelerator.state) 193 | 194 | # assert accelerator.num_processes == 1, "currently multi-gpu inference generates images with biased class distribution and leads to much worse FID" 195 | 196 | # load previous stats 197 | info_path = os.path.join(folder, "stats.json") 198 | if os.path.isfile(info_path) and not args.no_resume: 199 | with open(info_path, "r") as f: 200 | overall_stats = json.load(f) 201 | evaluated_checkpoints = set(overall_stats.keys()) 202 | if accelerator.is_main_process: 203 | print(f"folder to evaluate: {folder}") 204 | 205 | # initialize wandb 206 | if accelerator.is_main_process: 207 | run = wandb.init(config=args, dir=args.folder, **{"mode": "online", "entity": args.wandb_entity, "project": args.wandb_project}) 208 | wandb.run.name = args.wandb_name 209 | print(f"wandb run dir: {run.dir}") 210 | 211 | # initialie evaluator 212 | evaluator, evaluator_kwargs, feature_dim = create_evaluator(args.detector_url) 213 | evaluator = accelerator.prepare(evaluator) 214 | generator = None 215 | 216 | # initialize reference statistics 217 | with dnnlib.util.open_url(args.ref_path) as f: 218 | ref_dict = dict(np.load(f)) 219 | 220 | while True: 221 | new_checkpoints = sorted(glob.glob(os.path.join(folder, "*checkpoint_*"))) 222 | new_checkpoints = 
set(new_checkpoints) - evaluated_checkpoints 223 | new_checkpoints = sorted(list(new_checkpoints)) 224 | 225 | if len(new_checkpoints) == 0: 226 | continue 227 | 228 | for checkpoint in new_checkpoints: 229 | if accelerator.is_main_process: 230 | print(f"Evaluating {folder} {checkpoint}") 231 | model_index = int(checkpoint.replace("/", "").split("_")[-1]) 232 | 233 | generator = create_generator( 234 | os.path.join(checkpoint, "pytorch_model.bin"), 235 | base_model=generator 236 | ) 237 | generator = generator.to(accelerator.device) 238 | 239 | all_images_tensor = sample( 240 | accelerator, 241 | generator, 242 | args, 243 | model_index 244 | ) 245 | 246 | stats = {} 247 | 248 | pred_mu, pred_sigma = calculate_inception_stats(all_images_tensor, evaluator, 249 | accelerator, evaluator_kwargs, feature_dim, args.max_batch_size, 250 | ) 251 | 252 | if accelerator.is_main_process: 253 | fid = calculate_fid_from_inception_stats( 254 | pred_mu, pred_sigma, ref_dict['mu'], ref_dict['sigma'] 255 | ) 256 | stats["fid"] = fid 257 | 258 | print(f"checkpoint {checkpoint} fid {fid}") 259 | overall_stats[checkpoint] = stats 260 | 261 | wandb.log( 262 | stats, 263 | step=model_index 264 | ) 265 | 266 | torch.cuda.empty_cache() 267 | 268 | evaluated_checkpoints.update(new_checkpoints) 269 | 270 | if accelerator.is_main_process: 271 | # dump stats to folder 272 | with open(os.path.join(folder, "stats.json"), "w") as f: 273 | json.dump(overall_stats, f, indent=2) 274 | 275 | 276 | if __name__ == "__main__": 277 | evaluate() -------------------------------------------------------------------------------- /main/sd_image_dataset.py: -------------------------------------------------------------------------------- 1 | from main.utils import retrieve_row_from_lmdb, get_array_shape_from_lmdb 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import torch 5 | import lmdb 6 | 7 | 8 | class SDImageDatasetLMDB(Dataset): 9 | def __init__(self, dataset_path, tokenizer_one, is_sdxl=False, tokenizer_two=None): 10 | self.KEY_TO_TYPE = { 11 | 'latents': np.float16 12 | } 13 | self.is_sdxl = is_sdxl # sdxl uses two tokenizers 14 | self.dataset_path = dataset_path 15 | self.tokenizer_one = tokenizer_one 16 | self.tokenizer_two = tokenizer_two 17 | 18 | self.env = lmdb.open(dataset_path, readonly=True, lock=False, readahead=False, meminit=False) 19 | self.latent_shape = get_array_shape_from_lmdb(self.env, "latents") 20 | 21 | self.length = self.latent_shape[0] 22 | 23 | print(f"Dataset length: {self.length}") 24 | 25 | def __len__(self): 26 | return self.length 27 | 28 | def __getitem__(self, idx): 29 | image = retrieve_row_from_lmdb( 30 | self.env, 31 | "latents", self.KEY_TO_TYPE['latents'], self.latent_shape[1:], idx 32 | ) 33 | image = torch.tensor(image, dtype=torch.float32) 34 | 35 | with self.env.begin() as txn: 36 | prompt = txn.get(f'prompts_{idx}_data'.encode()).decode() 37 | 38 | text_input_ids_one = self.tokenizer_one( 39 | [prompt], 40 | padding="max_length", 41 | max_length=self.tokenizer_one.model_max_length, 42 | truncation=True, 43 | return_tensors="pt", 44 | ).input_ids 45 | 46 | output_dict = { 47 | 'images': image, 48 | 'text_input_ids_one': text_input_ids_one, 49 | } 50 | 51 | if self.is_sdxl: 52 | text_input_ids_two = self.tokenizer_two( 53 | [prompt], 54 | padding="max_length", 55 | max_length=self.tokenizer_two.model_max_length, 56 | truncation=True, 57 | return_tensors="pt", 58 | ).input_ids 59 | output_dict['text_input_ids_two'] = text_input_ids_two 60 | 61 | return output_dict 
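A minimal usage sketch for SDImageDatasetLMDB (the LMDB path and tokenizer checkpoint below are illustrative placeholders; the LMDB itself can be built with main/data/create_lmdb_iterative.py or downloaded via scripts/download_sdv15.sh):

from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from main.sd_image_dataset import SDImageDatasetLMDB

# assumed paths; substitute your own checkpoint folder
tokenizer = AutoTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer", use_fast=False
)
dataset = SDImageDatasetLMDB(
    "checkpoints/sd_vae_latents_laion_500k_lmdb", tokenizer_one=tokenizer
)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

batch = next(iter(loader))
# 'images' are the fp16 VAE latents stored in the LMDB, promoted to float32 by the dataset
print(batch["images"].shape, batch["text_input_ids_one"].shape)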
-------------------------------------------------------------------------------- /main/sdxl/create_sdxl_fsdp_configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import yaml 3 | import os 4 | 5 | def create_yaml_config(filename, rank, master_ip, args): 6 | # Define the configuration data 7 | config_data = { 8 | "compute_environment": "LOCAL_MACHINE", 9 | "debug": False, 10 | "distributed_type": "FSDP", 11 | "downcast_bf16": "no", 12 | "fsdp_config": { 13 | "fsdp_auto_wrap_policy": "SIZE_BASED_WRAP", 14 | "fsdp_backward_prefetch_policy": "BACKWARD_PRE", 15 | "fsdp_forward_prefetch": False, 16 | "fsdp_min_num_params": 50000000, 17 | "fsdp_offload_params": False, 18 | "fsdp_sharding_strategy": args.sharding_strategy, 19 | "fsdp_state_dict_type": "SHARDED_STATE_DICT", 20 | "fsdp_sync_module_states": True, 21 | "fsdp_use_orig_params": False 22 | }, 23 | "machine_rank": rank, 24 | "main_process_ip": master_ip, 25 | "main_process_port": 2345, 26 | "main_training_function": "main", 27 | "mixed_precision": "no", 28 | "num_machines": args.num_machines, 29 | "num_processes": 8*args.num_machines, 30 | "rdzv_backend": "static", 31 | "same_network": True, 32 | "tpu_env": [], 33 | "tpu_use_cluster": False, 34 | "tpu_use_sudo": False, 35 | "use_cpu": False 36 | } 37 | 38 | # Write the configuration data to a YAML file 39 | with open(filename, 'w') as file: 40 | yaml.dump(config_data, file, default_flow_style=False) 41 | 42 | def main(): 43 | parser = argparse.ArgumentParser(description="Create a YAML configuration file") 44 | parser.add_argument("--folder", type=str, help="The name of the YAML configuration file to create") 45 | parser.add_argument("--master_ip", type=str) 46 | parser.add_argument("--num_machines", type=int, default=8) 47 | parser.add_argument("--sharding_strategy", type=str, help="sharding strategy. 
1-5 FULL_SHARD / SHARD_GRAD_OP / NO_SHARD / HYBRID_SHARD / HYBRID_SHARD_ZERO2") 48 | args = parser.parse_args() 49 | 50 | os.makedirs(args.folder, exist_ok=True) 51 | 52 | for i in range(args.num_machines): 53 | filename = os.path.join(args.folder, f"config_rank{i}.yaml") 54 | create_yaml_config(filename, i, args.master_ip, args) 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /main/sdxl/data_process.py: -------------------------------------------------------------------------------- 1 | # # bash commands to download the data 2 | # !wget http://images.cocodataset.org/zips/val2014.zip 3 | # !wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip 4 | # !unzip annotations_trainval2014.zip -d coco/ 5 | # !unzip val2014.zip -d coco/ 6 | 7 | 8 | 9 | # load data 10 | import json 11 | import numpy as np 12 | import pickle 13 | 14 | 15 | dir_path = './coco/' 16 | data_file = dir_path + 'annotations/captions_val2014.json' 17 | data = json.load(open(data_file)) 18 | 19 | np.random.seed(123) 20 | 21 | # merge images and annotations 22 | import pandas as pd 23 | images = data['images'] 24 | annotations = data['annotations'] 25 | df = pd.DataFrame(images) 26 | df_annotations = pd.DataFrame(annotations) 27 | df = df.merge(pd.DataFrame(annotations), how='left', left_on='id', right_on='image_id') 28 | 29 | 30 | # keep only the relevant columns 31 | df = df[['file_name', 'caption']] 32 | print(df) 33 | print("length:", len(df['file_name'])) 34 | # shuffle the dataset 35 | df = df.sample(frac=1) 36 | 37 | 38 | # remove duplicate images 39 | df = df.drop_duplicates(subset='file_name') 40 | 41 | # create a random subset of n_samples 42 | n_samples = 10000 43 | df_sample = df.sample(n_samples) 44 | # print(df_sample) 45 | 46 | all_prompts = list(df_sample['caption']) 47 | 48 | with open(dir_path + 'all_prompts.pkl', 'wb') as f: 49 | pickle.dump(all_prompts, f) 50 | 51 | # save the sample to a parquet file 52 | df_sample.to_csv(dir_path + 'subset.csv') 53 | 54 | # copy the images to reference folder 55 | from pathlib import Path 56 | import shutil 57 | subset_path = Path(dir_path + 'subset') 58 | subset_path.mkdir(exist_ok=True) 59 | counter = 0 60 | 61 | for i, row in df_sample.iterrows(): 62 | path = dir_path + 'val2014/' + row['file_name'] 63 | shutil.copy(path, dir_path + 'subset/') 64 | 65 | counter += 1 66 | print(counter, path, dir_path + 'subset/') 67 | -------------------------------------------------------------------------------- /main/sdxl/extract_lora_module.py: -------------------------------------------------------------------------------- 1 | from peft import LoraConfig, get_peft_model_state_dict 2 | from diffusers import UNet2DConditionModel 3 | from safetensors.torch import save_file 4 | import argparse 5 | import torch 6 | 7 | def main(): 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--original_model_path", type=str, required=True) 10 | parser.add_argument("--output_model_path", type=str, required=True) 11 | parser.add_argument("--lora_rank", type=int, default=64) 12 | parser.add_argument("--lora_alpha", type=float, default=8) 13 | parser.add_argument("--lora_dropout", type=float, default=0.0) 14 | parser.add_argument("--fp16", action="store_true") 15 | 16 | args = parser.parse_args() 17 | 18 | generator = UNet2DConditionModel.from_pretrained( 19 | "stabilityai/stable-diffusion-xl-base-1.0", 20 | subfolder="unet" 21 | ).float() 22 | 23 | lora_target_modules = [ 24 | 
"to_q", 25 | "to_k", 26 | "to_v", 27 | "to_out.0", 28 | "proj_in", 29 | "proj_out", 30 | "ff.net.0.proj", 31 | "ff.net.2", 32 | "conv1", 33 | "conv2", 34 | "conv_shortcut", 35 | "downsamplers.0.conv", 36 | "upsamplers.0.conv", 37 | "time_emb_proj", 38 | ] 39 | lora_config = LoraConfig( 40 | r=args.lora_rank, 41 | target_modules=lora_target_modules, 42 | lora_alpha=args.lora_alpha, 43 | lora_dropout=args.lora_dropout 44 | ) 45 | generator.add_adapter(lora_config) 46 | 47 | generator.load_state_dict(torch.load(args.original_model_path)) 48 | 49 | if args.fp16: 50 | generator = generator.half() 51 | 52 | unet_lora_state_dict = get_peft_model_state_dict(generator) 53 | 54 | new_state_dict = {} 55 | 56 | for k, v in unet_lora_state_dict.items(): 57 | 58 | if "lora_A" in k: 59 | k = k.replace("lora_A", "lora_down") 60 | elif "lora_B" in k: 61 | k = k.replace("lora_B", "lora_up") 62 | 63 | k = "lora_unet_" + "_".join(k.split(".")[:-2]) + "." + ".".join(k.split(".")[-2:]) 64 | 65 | new_state_dict[k] = v 66 | 67 | alpha_key = k[:k.find(".")]+".alpha" 68 | 69 | new_state_dict[alpha_key] = torch.tensor(args.lora_alpha, dtype=torch.float16 if args.fp16 else torch.float32) 70 | 71 | save_file(new_state_dict, args.output_model_path) 72 | 73 | if __name__ == "__main__": 74 | main() -------------------------------------------------------------------------------- /main/sdxl/generate_noise_image_pairs_laion_sdxl.py: -------------------------------------------------------------------------------- 1 | from diffusers import StableDiffusionXLPipeline, EulerDiscreteScheduler 2 | from main.sdxl.sdxl_text_encoder import SDXLTextEncoder 3 | from accelerate.utils import ProjectConfiguration 4 | from main.utils import SDTextDataset 5 | from transformers import AutoTokenizer 6 | from accelerate.utils import set_seed 7 | from accelerate import Accelerator 8 | from tqdm import tqdm 9 | import numpy as np 10 | import argparse 11 | import torch 12 | import os 13 | 14 | @torch.no_grad() 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--seed", type=int, default=10) 18 | parser.add_argument("--folder", type=str, required=True, help="path to folder") 19 | parser.add_argument("--batch_size", type=int, default=8) 20 | parser.add_argument("--num_batches", type=int, default=1250) 21 | parser.add_argument("--guidance_scale", type=float, default=8) 22 | parser.add_argument("--prompt_path", type=str) 23 | parser.add_argument("--model_id", type=str, default="stabilityai/stable-diffusion-xl-base-1.0") 24 | parser.add_argument("--latent_resolution", type=int, default=128) 25 | parser.add_argument("--latent_channel", type=int, default=4) 26 | parser.add_argument("--revision", type=str) 27 | 28 | args = parser.parse_args() 29 | 30 | os.makedirs(args.folder, exist_ok=True) 31 | 32 | # initialize accelerator 33 | accelerator_project_config = ProjectConfiguration(logging_dir=args.folder) 34 | accelerator = Accelerator( 35 | gradient_accumulation_steps=1, 36 | mixed_precision="no", 37 | log_with="wandb", 38 | project_config=accelerator_project_config 39 | ) 40 | 41 | # make sure that different processes don't have the same seed, otherwise they will generate the same images 42 | set_seed(args.seed + accelerator.process_index) 43 | print(accelerator.state) 44 | 45 | # use TF32 for faster training on Ampere GPUs 46 | # disable for older GPUs 47 | torch.backends.cuda.matmul.allow_tf32 = True 48 | torch.backends.cudnn.allow_tf32 = True 49 | 50 | pipeline = StableDiffusionXLPipeline.from_pretrained( 51 | 
args.model_id, torch_dtype=torch.float16, variant="fp16", use_safetensors=True 52 | ) 53 | pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config, timestep_spacing="trailing") 54 | 55 | pipeline.to(accelerator.device) 56 | pipeline.set_progress_bar_config(disable=True) 57 | pipeline.safety_checker = None 58 | 59 | text_encoder = SDXLTextEncoder(args, accelerator).to(accelerator.device) 60 | 61 | tokenizer_one = AutoTokenizer.from_pretrained( 62 | args.model_id, subfolder="tokenizer", revision=args.revision, use_fast=False 63 | ) 64 | 65 | tokenizer_two = AutoTokenizer.from_pretrained( 66 | args.model_id, subfolder="tokenizer_2", revision=args.revision, use_fast=False 67 | ) 68 | 69 | caption_dataset = SDTextDataset( 70 | args.prompt_path, 71 | tokenizer_one=tokenizer_one, 72 | tokenizer_two=tokenizer_two, 73 | is_sdxl=True 74 | ) 75 | 76 | # split the dataset across gpus 77 | # NOTE: current code doesn't handle node failures 78 | subset_start_index = accelerator.process_index / accelerator.num_processes * len(caption_dataset) 79 | subset_end_index = (accelerator.process_index + 1) / accelerator.num_processes * len(caption_dataset) 80 | 81 | print(f"Process {accelerator.process_index} has indices {subset_start_index} to {subset_end_index}") 82 | 83 | caption_dataset = torch.utils.data.Subset( 84 | caption_dataset, 85 | list(range(int(subset_start_index), int(subset_end_index))) 86 | ) 87 | caption_dataloader = torch.utils.data.DataLoader( 88 | caption_dataset, 89 | batch_size=args.batch_size, 90 | shuffle=True, 91 | drop_last=True 92 | ) # we do shuffle in case we need a randomized subset of the data 93 | 94 | latents_list, images_list, prompt_embeds_list, pooled_prompt_embeds_list = [], [], [], [] 95 | 96 | for batch_index, data in tqdm(enumerate(caption_dataloader), disable=not accelerator.is_main_process, total=args.num_batches): 97 | prompt_embed, pooled_prompt_embed = text_encoder(data) 98 | uncond_prompt_embed, uncond_pooled_prompt_embed = ( 99 | torch.zeros_like(prompt_embed), torch.zeros_like(pooled_prompt_embed) 100 | ) 101 | 102 | input_latents = torch.randn( 103 | len(prompt_embed), 104 | args.latent_channel, 105 | args.latent_resolution, 106 | args.latent_resolution, 107 | device=accelerator.device, 108 | dtype=torch.float32 109 | ).half() 110 | 111 | output_images = pipeline( 112 | prompt_embeds=prompt_embed, 113 | pooled_prompt_embeds=pooled_prompt_embed, 114 | negative_prompt_embeds=uncond_prompt_embed, 115 | negative_pooled_prompt_embeds=uncond_pooled_prompt_embed, 116 | latents=input_latents, 117 | output_type="latent", 118 | guidance_scale=args.guidance_scale 119 | ) 120 | 121 | # save as fp16 to save space 122 | input_latents = input_latents.cpu().half().numpy() 123 | output_images = output_images.cpu().half().numpy() 124 | 125 | prompt_embeds = prompt_embed.cpu().half().numpy() 126 | pooled_prompt_embeds = pooled_prompt_embed.cpu().half().numpy() 127 | 128 | latents_list.append(input_latents) 129 | images_list.append(output_images) 130 | prompt_embeds_list.append(prompt_embeds) 131 | pooled_prompt_embeds_list.append(pooled_prompt_embeds) 132 | 133 | if batch_index >= args.num_batches: # early stop 134 | break 135 | 136 | if batch_index % 250 == 0: 137 | data_dict = { 138 | "latents": np.concatenate(latents_list, axis=0), 139 | "images": np.concatenate(images_list, axis=0), 140 | "prompt_embeds_list": np.concatenate(prompt_embeds_list, axis=0), 141 | "pooled_prompt_embeds": np.concatenate(pooled_prompt_embeds_list, axis=0) 142 | } 143 | 
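            # Periodic checkpoint: save everything accumulated so far, then delete the partial
            # file written 250 batches earlier so at most one intermediate .pt per process is kept.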
output_path = os.path.join(args.folder, f"BATCH_{batch_index}_noise_image_pairs_{accelerator.process_index:03d}.pt") 144 | torch.save( 145 | data_dict, output_path, pickle_protocol=5 146 | ) 147 | 148 | if os.path.exists( 149 | os.path.join(args.folder, f"BATCH_{batch_index-250}_noise_image_pairs_{accelerator.process_index:03d}.pt") 150 | ): 151 | os.remove( 152 | os.path.join(args.folder, f"BATCH_{batch_index-250}_noise_image_pairs_{accelerator.process_index:03d}.pt") 153 | ) 154 | 155 | data_dict = { 156 | "latents": np.concatenate(latents_list, axis=0), 157 | "images": np.concatenate(images_list, axis=0), 158 | "prompt_embeds_list": np.concatenate(prompt_embeds_list, axis=0), 159 | "pooled_prompt_embeds": np.concatenate(pooled_prompt_embeds_list, axis=0) 160 | } 161 | output_path = os.path.join(args.folder, f"noise_image_pairs_{accelerator.process_index:03d}.pt") 162 | torch.save( 163 | data_dict, output_path, pickle_protocol=5 164 | ) 165 | accelerator.wait_for_everyone() 166 | 167 | if __name__ == "__main__": 168 | main() -------------------------------------------------------------------------------- /main/sdxl/generate_vae_latents.py: -------------------------------------------------------------------------------- 1 | from diffusers import AutoencoderKL 2 | from PIL import Image 3 | from tqdm import tqdm 4 | import numpy as np 5 | import accelerate 6 | import argparse 7 | import torch 8 | import time 9 | import glob 10 | import os 11 | 12 | torch.set_grad_enabled(False) 13 | 14 | class TempDataset(torch.utils.data.Dataset): 15 | def __init__(self, image_array, prompt_array, image_size, resize=False): 16 | self.image_array = image_array 17 | self.prompt_array = prompt_array 18 | self.image_size = image_size 19 | self.resize = resize 20 | 21 | def __len__(self): 22 | return len(self.image_array) 23 | 24 | def __getitem__(self, idx): 25 | prompt = self.prompt_array[idx] 26 | image = self.image_array[idx % len(self.image_array)].permute(1, 2, 0) 27 | image = image.numpy() 28 | if self.resize: 29 | image = Image.fromarray(image).resize((self.image_size, self.image_size), Image.LANCZOS) 30 | return { 31 | "images": torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1), 32 | "prompts": prompt 33 | } 34 | 35 | 36 | def main(): 37 | parser = argparse.ArgumentParser() 38 | parser.add_argument("--folder", type=str, required=True, help="path to folder") 39 | parser.add_argument("--output_folder", type=str) 40 | parser.add_argument("--model_name", type=str, choices=['sdxl', 'sd'], default='sdxl') 41 | parser.add_argument("--image_size", type=int, default=0) 42 | parser.add_argument("--resize", action="store_true") 43 | 44 | args = parser.parse_args() 45 | 46 | os.makedirs(args.output_folder, exist_ok=True) 47 | 48 | IS_SDXL = args.model_name == 'sdxl' 49 | SDXL_MODEL_NAME = "stabilityai/stable-diffusion-xl-base-1.0" 50 | SD_MODEL_NAME = "runwayml/stable-diffusion-v1-5" 51 | 52 | accelerator = accelerate.Accelerator() 53 | torch.backends.cuda.matmul.allow_tf32 = True 54 | torch.backends.cudnn.allow_tf32 = True 55 | 56 | vae = AutoencoderKL.from_pretrained( 57 | SDXL_MODEL_NAME if IS_SDXL else SD_MODEL_NAME, 58 | subfolder="vae" 59 | ).to(accelerator.device).float() 60 | 61 | image_file = sorted(glob.glob(os.path.join(args.folder, "*.pt"))) 62 | 63 | print(f"process {accelerator.process_index}, file {image_file}") 64 | 65 | if accelerator.process_index >= len(image_file): 66 | time.sleep(100000) 67 | 68 | image_file = image_file[accelerator.process_index] 69 | 70 | print(f"process 
{accelerator.process_index} start loading data...") 71 | data = torch.load(image_file) 72 | 73 | print( f"process {accelerator.process_index}done loading data...") 74 | 75 | image_dataset = TempDataset(data['images'], data["prompts"], args.image_size, resize=args.resize) 76 | image_dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=16, num_workers=8) 77 | 78 | latent_list = [] 79 | prompt_list = [] 80 | 81 | for i, data in tqdm(enumerate(image_dataloader), disable=accelerator.local_process_index != 0): 82 | batch_images = ((data["images"] / 255.0) * 2.0 - 1.0).to(accelerator.device) 83 | batch_prompts = data["prompts"] 84 | 85 | with torch.no_grad(): 86 | latents = vae.encode(batch_images).latent_dist.sample() * vae.config.scaling_factor 87 | 88 | latent_list.append(latents.half().cpu().numpy()) 89 | prompt_list.extend(batch_prompts) 90 | 91 | data_dict = { 92 | "latents": np.concatenate(latent_list, axis=0), 93 | "prompts": np.array(prompt_list) 94 | } 95 | output_path = os.path.join(args.output_folder, f"vae_latents_{accelerator.process_index:03d}.pt") 96 | torch.save( 97 | data_dict, output_path, pickle_protocol=5 98 | ) 99 | 100 | print(f"process {accelerator.process_index} done!") 101 | accelerator.wait_for_everyone() 102 | 103 | 104 | if __name__ == "__main__": 105 | main() -------------------------------------------------------------------------------- /main/sdxl/sdxl_ode_dataset.py: -------------------------------------------------------------------------------- 1 | from main.utils import retrieve_row_from_lmdb, get_array_shape_from_lmdb 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import torch 5 | import lmdb 6 | 7 | 8 | class SDXLODEDatasetLMDB(Dataset): 9 | def __init__(self, ode_pair_path, num_ode_pairs=0, return_first=True): 10 | self.KEY_TO_TYPE = { 11 | 'latents': np.float16, 12 | 'images': np.float16, 13 | 'prompt_embeds_list': np.float16, 14 | 'pooled_prompt_embeds': np.float16 15 | } 16 | 17 | self.ode_pair_path = ode_pair_path 18 | 19 | self.env = lmdb.open(ode_pair_path, readonly=True, lock=False, readahead=False, meminit=False) 20 | self.image_shape = get_array_shape_from_lmdb(self.env, "images") 21 | self.latent_shape = get_array_shape_from_lmdb(self.env, "latents") 22 | self.prompt_embeds_list_shape = get_array_shape_from_lmdb(self.env, "prompt_embeds_list") 23 | self.pooled_prompt_embeds_shape = get_array_shape_from_lmdb(self.env, "pooled_prompt_embeds") 24 | 25 | self.length = self.image_shape[0] 26 | 27 | if num_ode_pairs > 0: 28 | self.length = min(num_ode_pairs, self.length) 29 | 30 | print(f"Dataset length: {self.length}") 31 | 32 | # if we store the whole trajectory, by default we only output the starting point (initial noise) 33 | self.return_first = return_first 34 | 35 | def __len__(self): 36 | return self.length 37 | 38 | def __getitem__(self, idx): 39 | image = retrieve_row_from_lmdb( 40 | self.env, 41 | "images", self.KEY_TO_TYPE['images'], self.image_shape[1:], idx 42 | ) 43 | 44 | noise = retrieve_row_from_lmdb( 45 | self.env, 46 | "latents", self.KEY_TO_TYPE['latents'], self.latent_shape[1:], idx 47 | ) 48 | 49 | if noise.ndim == 5 and self.return_first: 50 | # select the starting point (initial noise) 51 | noise = noise[:, 0] 52 | 53 | prompt_embed = retrieve_row_from_lmdb( 54 | self.env, 55 | "prompt_embeds_list", self.KEY_TO_TYPE['prompt_embeds_list'], self.prompt_embeds_list_shape[1:], idx 56 | ) 57 | 58 | pooled_prompt_embed = retrieve_row_from_lmdb( 59 | self.env, 60 | "pooled_prompt_embeds", 
self.KEY_TO_TYPE['pooled_prompt_embeds'], self.pooled_prompt_embeds_shape[1:], idx 61 | ) 62 | 63 | embed_dict = { 64 | "prompt_embed": torch.tensor(prompt_embed, dtype=torch.float32), 65 | "pooled_prompt_embed": torch.tensor(pooled_prompt_embed, dtype=torch.float32) 66 | } 67 | 68 | image = torch.tensor(image, dtype=torch.float32) 69 | noise = torch.tensor(noise, dtype=torch.float32) 70 | 71 | output_dict = { 72 | 'images': image, 73 | 'latents': noise, 74 | 'embed_dict': embed_dict, 75 | } 76 | return output_dict -------------------------------------------------------------------------------- /main/sdxl/sdxl_text_encoder.py: -------------------------------------------------------------------------------- 1 | from transformers import CLIPTextModel, CLIPTextModelWithProjection 2 | import torch 3 | 4 | class SDXLTextEncoder(torch.nn.Module): 5 | def __init__(self, args, accelerator, dtype=torch.float32) -> None: 6 | super().__init__() 7 | 8 | self.text_encoder_one = CLIPTextModel.from_pretrained( 9 | args.model_id, subfolder="text_encoder", revision=args.revision 10 | ).to(accelerator.device).to(dtype=dtype) 11 | 12 | self.text_encoder_two = CLIPTextModelWithProjection.from_pretrained( 13 | args.model_id, subfolder="text_encoder_2", revision=args.revision 14 | ).to(accelerator.device).to(dtype=dtype) 15 | 16 | self.accelerator = accelerator 17 | 18 | def forward(self, batch): 19 | text_input_ids_one = batch['text_input_ids_one'].to(self.accelerator.device).squeeze(1) 20 | text_input_ids_two = batch['text_input_ids_two'].to(self.accelerator.device).squeeze(1) 21 | prompt_embeds_list = [] 22 | 23 | for text_input_ids, text_encoder in zip([text_input_ids_one, text_input_ids_two], [self.text_encoder_one, self.text_encoder_two]): 24 | prompt_embeds = text_encoder( 25 | text_input_ids.to(text_encoder.device), 26 | output_hidden_states=True, 27 | ) 28 | 29 | # We are only ALWAYS interested in the pooled output of the final text encoder 30 | pooled_prompt_embeds = prompt_embeds[0] 31 | 32 | prompt_embeds = prompt_embeds.hidden_states[-2] 33 | bs_embed, seq_len, _ = prompt_embeds.shape 34 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) 35 | prompt_embeds_list.append(prompt_embeds) 36 | 37 | prompt_embeds = torch.cat(prompt_embeds_list, dim=-1) 38 | pooled_prompt_embeds = pooled_prompt_embeds.view(len(text_input_ids_one), -1) # use the second text encoder's pooled prompt embeds (overwrite in for loop) 39 | 40 | return prompt_embeds, pooled_prompt_embeds 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anyio 2 | torch==2.0.1 3 | torchvision==0.15.2 4 | git+https://github.com/openai/CLIP.git 5 | open_clip_torch 6 | image-reward 7 | diffusers 8 | peft 9 | wandb 10 | lmdb 11 | transformers 12 | accelerate==0.23.0 13 | lmdb 14 | datasets 15 | evaluate 16 | scipy 17 | opencv-python 18 | matplotlib 19 | imageio 20 | piq==0.7.0 21 | safetensors 22 | gradio 23 | huggingface-hub==0.22.0 24 | clean-fid -------------------------------------------------------------------------------- /scripts/download_hf_checkpoint.sh: -------------------------------------------------------------------------------- 1 | CHECKPOINT_NAME=$1 2 | OUTPUT_PATH=$2 3 | 4 | mkdir $OUTPUT_PATH 5 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/model/$CHECKPOINT_NAME/optimizer.bin?download=true -O $OUTPUT_PATH/optimizer.bin 6 | wget 
https://huggingface.co/tianweiy/DMD2/resolve/main/model/$CHECKPOINT_NAME/optimizer_1.bin?download=true -O $OUTPUT_PATH/optimizer_1.bin 7 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/model/$CHECKPOINT_NAME/pytorch_model.bin?download=true -O $OUTPUT_PATH/pytorch_model.bin 8 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/model/$CHECKPOINT_NAME/pytorch_model_1.bin?download=true -O $OUTPUT_PATH/pytorch_model_1.bin 9 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/model/$CHECKPOINT_NAME/scheduler.bin?download=true -O $OUTPUT_PATH/scheduler.bin 10 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/model/$CHECKPOINT_NAME/scheduler_1.bin?download=true -O $OUTPUT_PATH/scheduler_1.bin 11 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/model/$CHECKPOINT_NAME/random_states_0.pkl?download=true -O $OUTPUT_PATH/random_states_0.pkl 12 | -------------------------------------------------------------------------------- /scripts/download_imagenet.sh: -------------------------------------------------------------------------------- 1 | CHECKPOINT_PATH=$1 2 | 3 | wget https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-imagenet-64x64-cond-adm.pkl -O $CHECKPOINT_PATH/edm-imagenet-64x64-cond-adm.pkl 4 | 5 | wget https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl -O $CHECKPOINT_PATH/inception-2015-12-05.pkl 6 | 7 | wget https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/imagenet-64x64.npz -O $CHECKPOINT_PATH/imagenet_fid_refs_edm.npz 8 | 9 | # download the imagenet-64x64 lmdb 10 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/data/imagenet/imagenet-64x64_lmdb.zip?download=true -O $CHECKPOINT_PATH/imagenet-64x64_lmdb.zip 11 | unzip $CHECKPOINT_PATH/imagenet-64x64_lmdb.zip -d $CHECKPOINT_PATH 12 | -------------------------------------------------------------------------------- /scripts/download_sdv15.sh: -------------------------------------------------------------------------------- 1 | CHECKPOINT_PATH=$1 2 | 3 | # training prompts 4 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/data/laion/captions_laion_score6.25.pkl?download=true -O $CHECKPOINT_PATH/captions_laion_score6.25.pkl 5 | 6 | # evaluation prompts 7 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/data/coco/captions_coco14_test.pkl?download=true -O $CHECKPOINT_PATH/captions_coco14_test.pkl 8 | 9 | # real dataset 10 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/data/laion_vae_latents/sd_vae_latents_laion_500k_lmdb.zip?download=true -O $CHECKPOINT_PATH/sd_vae_latents_laion_500k_lmdb.zip 11 | unzip $CHECKPOINT_PATH/sd_vae_latents_laion_500k_lmdb.zip -d $CHECKPOINT_PATH 12 | 13 | # evaluation images 14 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/data/coco/val2014.zip?download=true -O $CHECKPOINT_PATH/val2014.zip 15 | unzip $CHECKPOINT_PATH/val2014.zip -d $CHECKPOINT_PATH -------------------------------------------------------------------------------- /scripts/download_sdxl.sh: -------------------------------------------------------------------------------- 1 | CHECKPOINT_PATH=$1 2 | 3 | # training prompts 4 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/data/laion/captions_laion_score6.25.pkl?download=true -O $CHECKPOINT_PATH/captions_laion_score6.25.pkl 5 | 6 | # evaluation prompts 7 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/data/coco/captions_coco14_test.pkl?download=true -O $CHECKPOINT_PATH/captions_coco14_test.pkl 8 | 9 | 10 | 11 | mkdir 
$CHECKPOINT_PATH/sdxl_vae_latents_laion_500k 12 | # real dataset 13 | for INDEX in {0..59} 14 | do 15 | # Format the index to be zero-padded to three digits 16 | INDEX_PADDED=$(printf "%03d" $INDEX) 17 | 18 | # Download the file 19 | wget "https://huggingface.co/tianweiy/DMD2/resolve/main/data/laion_vae_latents/sdxl_vae_latents_laion_500k/vae_latents_${INDEX_PADDED}.pt?download=true" -O "${CHECKPOINT_PATH}/sdxl_vae_latents_laion_500k/vae_latents_${INDEX_PADDED}.pt" 20 | done 21 | 22 | # generate the lmdb database from the downloaded files 23 | python main/data/create_lmdb_iterative.py --data_path $CHECKPOINT_PATH/sdxl_vae_latents_laion_500k/ --lmdb_path $CHECKPOINT_PATH/sdxl_vae_latents_laion_500k_lmdb 24 | 25 | # evaluation images 26 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/data/coco/coco10k.zip?download=true -O $CHECKPOINT_PATH/coco10k.zip 27 | unzip $CHECKPOINT_PATH/coco10k.zip -d $CHECKPOINT_PATH -------------------------------------------------------------------------------- /scripts/download_sdxl_1step_ode_pairs_ckpt.sh: -------------------------------------------------------------------------------- 1 | CHECKPOINT_PATH=$1 2 | 3 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/model/sdxl/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399_checkpoint_model_002000.bin?download=true -O $CHECKPOINT_PATH/sdxl_lr1e-5_8node_ode_pretraining_10k_cond399_checkpoint_model_002000.bin 4 | -------------------------------------------------------------------------------- /scripts/download_sdxl_ode_pair_10k_lmdb.sh: -------------------------------------------------------------------------------- 1 | CHECKPOINT_PATH=$1 2 | 3 | wget https://huggingface.co/tianweiy/DMD2/resolve/main/data/laion/sdxl_ode_pair_10k_lmdb.zip?download=true -O $CHECKPOINT_PATH/sdxl_ode_pair_10k_lmdb.zip 4 | 5 | unzip $CHECKPOINT_PATH/sdxl_ode_pair_10k_lmdb.zip -d $CHECKPOINT_PATH -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | setup( 3 | name="DMD2", 4 | version="0.0.1", 5 | packages=find_packages(), 6 | ) -------------------------------------------------------------------------------- /third_party/edm/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | FROM nvcr.io/nvidia/pytorch:22.10-py3 9 | 10 | ENV PYTHONDONTWRITEBYTECODE 1 11 | ENV PYTHONUNBUFFERED 1 12 | 13 | RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0 14 | 15 | WORKDIR /workspace 16 | 17 | RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 18 | ENTRYPOINT ["/entry.sh"] 19 | -------------------------------------------------------------------------------- /third_party/edm/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 
5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | from .util import EasyDict, make_cache_dir_path 9 | -------------------------------------------------------------------------------- /third_party/edm/docs/afhqv2-64x64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/DMD2/8d8fa55633d47cfb81bbc7a892e7248f9518763f/third_party/edm/docs/afhqv2-64x64.png -------------------------------------------------------------------------------- /third_party/edm/docs/cifar10-32x32.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/DMD2/8d8fa55633d47cfb81bbc7a892e7248f9518763f/third_party/edm/docs/cifar10-32x32.png -------------------------------------------------------------------------------- /third_party/edm/docs/dataset-tool-help.txt: -------------------------------------------------------------------------------- 1 | Usage: dataset_tool.py [OPTIONS] 2 | 3 | Convert an image dataset into a dataset archive usable with StyleGAN2 ADA 4 | PyTorch. 5 | 6 | The input dataset format is guessed from the --source argument: 7 | 8 | --source *_lmdb/ Load LSUN dataset 9 | --source cifar-10-python.tar.gz Load CIFAR-10 dataset 10 | --source train-images-idx3-ubyte.gz Load MNIST dataset 11 | --source path/ Recursively load all images from path/ 12 | --source dataset.zip Recursively load all images from dataset.zip 13 | 14 | Specifying the output format and path: 15 | 16 | --dest /path/to/dir Save output files under /path/to/dir 17 | --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip 18 | 19 | The output dataset format can be either an image folder or an uncompressed 20 | zip archive. Zip archives makes it easier to move datasets around file 21 | servers and clusters, and may offer better training performance on network 22 | file systems. 23 | 24 | Images within the dataset archive will be stored as uncompressed PNG. 25 | Uncompresed PNGs can be efficiently decoded in the training loop. 26 | 27 | Class labels are stored in a file called 'dataset.json' that is stored at 28 | the dataset root folder. This file has the following structure: 29 | 30 | { 31 | "labels": [ 32 | ["00000/img00000000.png",6], 33 | ["00000/img00000001.png",9], 34 | ... repeated for every image in the datase 35 | ["00049/img00049999.png",1] 36 | ] 37 | } 38 | 39 | If the 'dataset.json' file cannot be found, class labels are determined from 40 | top-level directory names. 41 | 42 | Image scale/crop and resolution requirements: 43 | 44 | Output images must be square-shaped and they must all have the same power- 45 | of-two dimensions. 46 | 47 | To scale arbitrary input image size to a specific width and height, use the 48 | --resolution option. Output resolution will be either the original input 49 | resolution (if resolution was not specified) or the one specified with 50 | --resolution option. 51 | 52 | Use the --transform=center-crop or --transform=center-crop-wide options to 53 | apply a center crop transform on the input image. These options should be 54 | used with the --resolution option. 
For example: 55 | 56 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \ 57 | --transform=center-crop-wide --resolution=512x384 58 | 59 | Options: 60 | --source PATH Input directory or archive name [required] 61 | --dest PATH Output directory or archive name [required] 62 | --max-images INT Maximum number of images to output 63 | --transform MODE Input crop/resize mode 64 | --resolution WxH Output resolution (e.g., 512x512) 65 | --help Show this message and exit. 66 | -------------------------------------------------------------------------------- /third_party/edm/docs/ffhq-64x64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/DMD2/8d8fa55633d47cfb81bbc7a892e7248f9518763f/third_party/edm/docs/ffhq-64x64.png -------------------------------------------------------------------------------- /third_party/edm/docs/fid-help.txt: -------------------------------------------------------------------------------- 1 | Usage: fid.py [OPTIONS] COMMAND [ARGS]... 2 | 3 | Calculate Frechet Inception Distance (FID). 4 | 5 | Examples: 6 | 7 | # Generate 50000 images and save them as fid-tmp/*/*.png 8 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \ 9 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 10 | 11 | # Calculate FID 12 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \ 13 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 14 | 15 | # Compute dataset reference statistics 16 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz 17 | 18 | Options: 19 | --help Show this message and exit. 20 | 21 | Commands: 22 | calc Calculate FID for a given set of images. 23 | ref Calculate dataset reference statistics needed by 'calc'. 24 | 25 | 26 | Usage: fid.py calc [OPTIONS] 27 | 28 | Calculate FID for a given set of images. 29 | 30 | Options: 31 | --images PATH|ZIP Path to the images [required] 32 | --ref NPZ|URL Dataset reference statistics [required] 33 | --num INT Number of images to use [default: 50000; x>=2] 34 | --seed INT Random seed for selecting the images [default: 0] 35 | --batch INT Maximum batch size [default: 64; x>=1] 36 | --help Show this message and exit. 37 | 38 | 39 | Usage: fid.py ref [OPTIONS] 40 | 41 | Calculate dataset reference statistics needed by 'calc'. 42 | 43 | Options: 44 | --data PATH|ZIP Path to the dataset [required] 45 | --dest NPZ Destination .npz file [required] 46 | --batch INT Maximum batch size [default: 64; x>=1] 47 | --help Show this message and exit. 48 | -------------------------------------------------------------------------------- /third_party/edm/docs/generate-help.txt: -------------------------------------------------------------------------------- 1 | Usage: generate.py [OPTIONS] 2 | 3 | Generate random images using the techniques described in the paper 4 | "Elucidating the Design Space of Diffusion-Based Generative Models". 
5 | 6 | Examples: 7 | 8 | # Generate 64 images and save them as out/*.png 9 | python generate.py --outdir=out --seeds=0-63 --batch=64 \ 10 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 11 | 12 | # Generate 1024 images using 2 GPUs 13 | torchrun --standalone --nproc_per_node=2 generate.py --outdir=out --seeds=0-999 --batch=64 \ 14 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 15 | 16 | Options: 17 | --network PATH|URL Network pickle filename [required] 18 | --outdir DIR Where to save the output images [required] 19 | --seeds LIST Random seeds (e.g. 1,2,5-10) [default: 0-63] 20 | --subdirs Create subdirectory for every 1000 seeds 21 | --class INT Class label [default: random] [x>=0] 22 | --batch INT Maximum batch size [default: 64; x>=1] 23 | --steps INT Number of sampling steps [default: 18; x>=1] 24 | --sigma_min FLOAT Lowest noise level [default: varies] [x>0] 25 | --sigma_max FLOAT Highest noise level [default: varies] [x>0] 26 | --rho FLOAT Time step exponent [default: 7; x>0] 27 | --S_churn FLOAT Stochasticity strength [default: 0; x>=0] 28 | --S_min FLOAT Stoch. min noise level [default: 0; x>=0] 29 | --S_max FLOAT Stoch. max noise level [default: inf; x>=0] 30 | --S_noise FLOAT Stoch. noise inflation [default: 1] 31 | --solver euler|heun Ablate ODE solver 32 | --disc vp|ve|iddpm|edm Ablate time step discretization {t_i} 33 | --schedule vp|ve|linear Ablate noise schedule sigma(t) 34 | --scaling vp|none Ablate signal scaling s(t) 35 | --help Show this message and exit. 36 | -------------------------------------------------------------------------------- /third_party/edm/docs/imagenet-64x64.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/DMD2/8d8fa55633d47cfb81bbc7a892e7248f9518763f/third_party/edm/docs/imagenet-64x64.png -------------------------------------------------------------------------------- /third_party/edm/docs/teaser-1280x640.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/DMD2/8d8fa55633d47cfb81bbc7a892e7248f9518763f/third_party/edm/docs/teaser-1280x640.jpg -------------------------------------------------------------------------------- /third_party/edm/docs/teaser-1920x640.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/DMD2/8d8fa55633d47cfb81bbc7a892e7248f9518763f/third_party/edm/docs/teaser-1920x640.jpg -------------------------------------------------------------------------------- /third_party/edm/docs/teaser-640x480.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tianweiy/DMD2/8d8fa55633d47cfb81bbc7a892e7248f9518763f/third_party/edm/docs/teaser-640x480.jpg -------------------------------------------------------------------------------- /third_party/edm/docs/train-help.txt: -------------------------------------------------------------------------------- 1 | Usage: train.py [OPTIONS] 2 | 3 | Train diffusion-based generative model using the techniques described in the 4 | paper "Elucidating the Design Space of Diffusion-Based Generative Models". 
5 | 6 | Examples: 7 | 8 | # Train DDPM++ model for class-conditional CIFAR-10 using 8 GPUs 9 | torchrun --standalone --nproc_per_node=8 train.py --outdir=training-runs \ 10 | --data=datasets/cifar10-32x32.zip --cond=1 --arch=ddpmpp 11 | 12 | Options: 13 | --outdir DIR Where to save the results [required] 14 | --data ZIP|DIR Path to the dataset [required] 15 | --cond BOOL Train class-conditional model [default: False] 16 | --arch ddpmpp|ncsnpp|adm Network architecture [default: ddpmpp] 17 | --precond vp|ve|edm Preconditioning & loss function [default: edm] 18 | --duration MIMG Training duration [default: 200; x>0] 19 | --batch INT Total batch size [default: 512; x>=1] 20 | --batch-gpu INT Limit batch size per GPU [x>=1] 21 | --cbase INT Channel multiplier [default: varies] 22 | --cres LIST Channels per resolution [default: varies] 23 | --lr FLOAT Learning rate [default: 0.001; x>0] 24 | --ema MIMG EMA half-life [default: 0.5; x>=0] 25 | --dropout FLOAT Dropout probability [default: 0.13; 0<=x<=1] 26 | --augment FLOAT Augment probability [default: 0.12; 0<=x<=1] 27 | --xflip BOOL Enable dataset x-flips [default: False] 28 | --fp16 BOOL Enable mixed-precision training [default: False] 29 | --ls FLOAT Loss scaling [default: 1; x>0] 30 | --bench BOOL Enable cuDNN benchmarking [default: True] 31 | --cache BOOL Cache dataset in CPU memory [default: True] 32 | --workers INT DataLoader worker processes [default: 1; x>=1] 33 | --desc STR String to include in result dir name 34 | --nosubdir Do not create a subdirectory for results 35 | --tick KIMG How often to print progress [default: 50; x>=1] 36 | --snap TICKS How often to save snapshots [default: 50; x>=1] 37 | --dump TICKS How often to dump state [default: 500; x>=1] 38 | --seed INT Random seed [default: random] 39 | --transfer PKL|URL Transfer learning from network pickle 40 | --resume PT Resume from previous training state 41 | -n, --dry-run Print training options and exit 42 | --help Show this message and exit. 43 | -------------------------------------------------------------------------------- /third_party/edm/environment.yml: -------------------------------------------------------------------------------- 1 | name: edm 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python>=3.8, < 3.10 # package build failures on 3.10 7 | - pip 8 | - numpy>=1.20 9 | - click>=8.0 10 | - pillow>=8.3.1 11 | - scipy>=1.7.1 12 | - pytorch=1.12.1 13 | - psutil 14 | - requests 15 | - tqdm 16 | - imageio 17 | - pip: 18 | - imageio-ffmpeg>=0.4.3 19 | - pyspng 20 | -------------------------------------------------------------------------------- /third_party/edm/example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Minimal standalone example to reproduce the main results from the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import tqdm 12 | import pickle 13 | import numpy as np 14 | import torch 15 | import PIL.Image 16 | import dnnlib 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def generate_image_grid( 21 | network_pkl, dest_path, 22 | seed=0, gridw=8, gridh=8, device=torch.device('cuda'), 23 | num_steps=18, sigma_min=0.002, sigma_max=80, rho=7, 24 | S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, 25 | ): 26 | batch_size = gridw * gridh 27 | torch.manual_seed(seed) 28 | 29 | # Load network. 30 | print(f'Loading network from "{network_pkl}"...') 31 | with dnnlib.util.open_url(network_pkl) as f: 32 | net = pickle.load(f)['ema'].to(device) 33 | 34 | # Pick latents and labels. 35 | print(f'Generating {batch_size} images...') 36 | latents = torch.randn([batch_size, net.img_channels, net.img_resolution, net.img_resolution], device=device) 37 | class_labels = None 38 | if net.label_dim: 39 | class_labels = torch.eye(net.label_dim, device=device)[torch.randint(net.label_dim, size=[batch_size], device=device)] 40 | 41 | # Adjust noise levels based on what's supported by the network. 42 | sigma_min = max(sigma_min, net.sigma_min) 43 | sigma_max = min(sigma_max, net.sigma_max) 44 | 45 | # Time step discretization. 46 | step_indices = torch.arange(num_steps, dtype=torch.float64, device=device) 47 | t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho 48 | t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])]) # t_N = 0 49 | 50 | # Main sampling loop. 51 | x_next = latents.to(torch.float64) * t_steps[0] 52 | for i, (t_cur, t_next) in tqdm.tqdm(list(enumerate(zip(t_steps[:-1], t_steps[1:]))), unit='step'): # 0, ..., N-1 53 | x_cur = x_next 54 | 55 | # Increase noise temporarily. 56 | gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0 57 | t_hat = net.round_sigma(t_cur + gamma * t_cur) 58 | x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur) 59 | 60 | # Euler step. 61 | denoised = net(x_hat, t_hat, class_labels).to(torch.float64) 62 | d_cur = (x_hat - denoised) / t_hat 63 | x_next = x_hat + (t_next - t_hat) * d_cur 64 | 65 | # Apply 2nd order correction. 66 | if i < num_steps - 1: 67 | denoised = net(x_next, t_next, class_labels).to(torch.float64) 68 | d_prime = (x_next - denoised) / t_next 69 | x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime) 70 | 71 | # Save image grid. 
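    # x_next is the final denoised sample in [-1, 1]; the lines below rescale it to uint8
    # [0, 255], rearrange the batch into a gridh x gridw mosaic, and write it out with PIL.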
72 | print(f'Saving image grid to "{dest_path}"...') 73 | image = (x_next * 127.5 + 128).clip(0, 255).to(torch.uint8) 74 | image = image.reshape(gridh, gridw, *image.shape[1:]).permute(0, 3, 1, 4, 2) 75 | image = image.reshape(gridh * net.img_resolution, gridw * net.img_resolution, net.img_channels) 76 | image = image.cpu().numpy() 77 | PIL.Image.fromarray(image, 'RGB').save(dest_path) 78 | print('Done.') 79 | 80 | #---------------------------------------------------------------------------- 81 | 82 | def main(): 83 | model_root = 'https://nvlabs-fi-cdn.nvidia.com/edm/pretrained' 84 | generate_image_grid(f'{model_root}/edm-cifar10-32x32-cond-vp.pkl', 'cifar10-32x32.png', num_steps=18) # FID = 1.79, NFE = 35 85 | generate_image_grid(f'{model_root}/edm-ffhq-64x64-uncond-vp.pkl', 'ffhq-64x64.png', num_steps=40) # FID = 2.39, NFE = 79 86 | generate_image_grid(f'{model_root}/edm-afhqv2-64x64-uncond-vp.pkl', 'afhqv2-64x64.png', num_steps=40) # FID = 1.96, NFE = 79 87 | generate_image_grid(f'{model_root}/edm-imagenet-64x64-cond-adm.pkl', 'imagenet-64x64.png', num_steps=256, S_churn=40, S_min=0.05, S_max=50, S_noise=1.003) # FID = 1.36, NFE = 511 88 | 89 | #---------------------------------------------------------------------------- 90 | 91 | if __name__ == "__main__": 92 | main() 93 | 94 | #---------------------------------------------------------------------------- 95 | -------------------------------------------------------------------------------- /third_party/edm/fid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Script for calculating Frechet Inception Distance (FID).""" 9 | 10 | import os 11 | import click 12 | import tqdm 13 | import pickle 14 | import numpy as np 15 | import scipy.linalg 16 | import torch 17 | import dnnlib 18 | from torch_utils import distributed as dist 19 | from training import dataset 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | def calculate_inception_stats( 24 | image_path, num_expected=None, seed=0, max_batch_size=64, 25 | num_workers=3, prefetch_factor=2, device=torch.device('cuda'), 26 | ): 27 | # Rank 0 goes first. 28 | if dist.get_rank() != 0: 29 | torch.distributed.barrier() 30 | 31 | # Load Inception-v3 model. 32 | # This is a direct PyTorch translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 33 | dist.print0('Loading Inception-v3 model...') 34 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 35 | detector_kwargs = dict(return_features=True) 36 | feature_dim = 2048 37 | with dnnlib.util.open_url(detector_url, verbose=(dist.get_rank() == 0)) as f: 38 | detector_net = pickle.load(f).to(device) 39 | 40 | # List images. 
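    # ImageFolderDataset recursively collects the generated images; the checks below verify
    # that enough images were found (at least num_expected, and never fewer than 2).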
41 | dist.print0(f'Loading images from "{image_path}"...') 42 | dataset_obj = dataset.ImageFolderDataset(path=image_path, max_size=num_expected, random_seed=seed) 43 | if num_expected is not None and len(dataset_obj) < num_expected: 44 | raise click.ClickException(f'Found {len(dataset_obj)} images, but expected at least {num_expected}') 45 | if len(dataset_obj) < 2: 46 | raise click.ClickException(f'Found {len(dataset_obj)} images, but need at least 2 to compute statistics') 47 | 48 | # Other ranks follow. 49 | if dist.get_rank() == 0: 50 | torch.distributed.barrier() 51 | 52 | # Divide images into batches. 53 | num_batches = ((len(dataset_obj) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size() 54 | all_batches = torch.arange(len(dataset_obj)).tensor_split(num_batches) 55 | rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()] 56 | data_loader = torch.utils.data.DataLoader(dataset_obj, batch_sampler=rank_batches, num_workers=num_workers, prefetch_factor=prefetch_factor) 57 | 58 | # Accumulate statistics. 59 | dist.print0(f'Calculating statistics for {len(dataset_obj)} images...') 60 | mu = torch.zeros([feature_dim], dtype=torch.float64, device=device) 61 | sigma = torch.zeros([feature_dim, feature_dim], dtype=torch.float64, device=device) 62 | for images, _labels in tqdm.tqdm(data_loader, unit='batch', disable=(dist.get_rank() != 0)): 63 | torch.distributed.barrier() 64 | if images.shape[0] == 0: 65 | continue 66 | if images.shape[1] == 1: 67 | images = images.repeat([1, 3, 1, 1]) 68 | features = detector_net(images.to(device), **detector_kwargs).to(torch.float64) 69 | mu += features.sum(0) 70 | sigma += features.T @ features 71 | 72 | # Calculate grand totals. 73 | torch.distributed.all_reduce(mu) 74 | torch.distributed.all_reduce(sigma) 75 | mu /= len(dataset_obj) 76 | sigma -= mu.ger(mu) * len(dataset_obj) 77 | sigma /= len(dataset_obj) - 1 78 | return mu.cpu().numpy(), sigma.cpu().numpy() 79 | 80 | #---------------------------------------------------------------------------- 81 | 82 | def calculate_fid_from_inception_stats(mu, sigma, mu_ref, sigma_ref): 83 | m = np.square(mu - mu_ref).sum() 84 | s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False) 85 | fid = m + np.trace(sigma + sigma_ref - s * 2) 86 | return float(np.real(fid)) 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @click.group() 91 | def main(): 92 | """Calculate Frechet Inception Distance (FID). 
93 | 94 | Examples: 95 | 96 | \b 97 | # Generate 50000 images and save them as fid-tmp/*/*.png 98 | torchrun --standalone --nproc_per_node=1 generate.py --outdir=fid-tmp --seeds=0-49999 --subdirs \\ 99 | --network=https://nvlabs-fi-cdn.nvidia.com/edm/pretrained/edm-cifar10-32x32-cond-vp.pkl 100 | 101 | \b 102 | # Calculate FID 103 | torchrun --standalone --nproc_per_node=1 fid.py calc --images=fid-tmp \\ 104 | --ref=https://nvlabs-fi-cdn.nvidia.com/edm/fid-refs/cifar10-32x32.npz 105 | 106 | \b 107 | # Compute dataset reference statistics 108 | python fid.py ref --data=datasets/my-dataset.zip --dest=fid-refs/my-dataset.npz 109 | """ 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | @main.command() 114 | @click.option('--images', 'image_path', help='Path to the images', metavar='PATH|ZIP', type=str, required=True) 115 | @click.option('--ref', 'ref_path', help='Dataset reference statistics ', metavar='NPZ|URL', type=str, required=True) 116 | @click.option('--num', 'num_expected', help='Number of images to use', metavar='INT', type=click.IntRange(min=2), default=50000, show_default=True) 117 | @click.option('--seed', help='Random seed for selecting the images', metavar='INT', type=int, default=0, show_default=True) 118 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 119 | 120 | def calc(image_path, ref_path, num_expected, seed, batch): 121 | """Calculate FID for a given set of images.""" 122 | torch.multiprocessing.set_start_method('spawn') 123 | dist.init() 124 | 125 | dist.print0(f'Loading dataset reference statistics from "{ref_path}"...') 126 | ref = None 127 | if dist.get_rank() == 0: 128 | with dnnlib.util.open_url(ref_path) as f: 129 | ref = dict(np.load(f)) 130 | 131 | mu, sigma = calculate_inception_stats(image_path=image_path, num_expected=num_expected, seed=seed, max_batch_size=batch) 132 | dist.print0('Calculating FID...') 133 | if dist.get_rank() == 0: 134 | fid = calculate_fid_from_inception_stats(mu, sigma, ref['mu'], ref['sigma']) 135 | print(f'{fid:g}') 136 | torch.distributed.barrier() 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | @main.command() 141 | @click.option('--data', 'dataset_path', help='Path to the dataset', metavar='PATH|ZIP', type=str, required=True) 142 | @click.option('--dest', 'dest_path', help='Destination .npz file', metavar='NPZ', type=str, required=True) 143 | @click.option('--batch', help='Maximum batch size', metavar='INT', type=click.IntRange(min=1), default=64, show_default=True) 144 | 145 | def ref(dataset_path, dest_path, batch): 146 | """Calculate dataset reference statistics needed by 'calc'.""" 147 | torch.multiprocessing.set_start_method('spawn') 148 | dist.init() 149 | 150 | mu, sigma = calculate_inception_stats(image_path=dataset_path, max_batch_size=batch) 151 | dist.print0(f'Saving dataset reference statistics to "{dest_path}"...') 152 | if dist.get_rank() == 0: 153 | if os.path.dirname(dest_path): 154 | os.makedirs(os.path.dirname(dest_path), exist_ok=True) 155 | np.savez(dest_path, mu=mu, sigma=sigma) 156 | 157 | torch.distributed.barrier() 158 | dist.print0('Done.') 159 | 160 | #---------------------------------------------------------------------------- 161 | 162 | if __name__ == "__main__": 163 | main() 164 | 165 | #---------------------------------------------------------------------------- 166 | 
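The `calc` and `ref` commands above both funnel into `calculate_inception_stats`, and the distance itself is the closed-form Frechet distance between two Gaussians fitted to Inception features: FID = ||mu - mu_ref||^2 + trace(sigma + sigma_ref - 2*sqrtm(sigma @ sigma_ref)). The stand-alone sketch below mirrors `calculate_fid_from_inception_stats` for statistics that have already been saved to disk; the two .npz paths are illustrative only (the reference file is the kind produced by `python fid.py ref`, and the other is assumed to hold `mu`/`sigma` arrays for the generated images).

# Stand-alone sketch: turn two sets of precomputed Inception statistics into an FID value.
# Both .npz paths are hypothetical; each file is assumed to contain 'mu' (2048,) and 'sigma' (2048, 2048).
import numpy as np
import scipy.linalg

def fid_from_stats(mu, sigma, mu_ref, sigma_ref):
    m = np.square(mu - mu_ref).sum()                                  # squared distance between feature means
    s, _ = scipy.linalg.sqrtm(np.dot(sigma, sigma_ref), disp=False)   # matrix square root of the covariance product
    return float(np.real(m + np.trace(sigma + sigma_ref - s * 2)))

gen = dict(np.load('fid-gen-stats.npz'))           # statistics of the generated images (hypothetical file)
ref = dict(np.load('fid-refs/cifar10-32x32.npz'))  # reference statistics, e.g. produced by 'python fid.py ref'
print(f"{fid_from_stats(gen['mu'], gen['sigma'], ref['mu'], ref['sigma']):g}")

Keeping the distance computation separate from feature extraction is what lets fid.py compute the expensive reference statistics once and reuse the resulting .npz across many evaluation runs.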
-------------------------------------------------------------------------------- /third_party/edm/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /third_party/edm/torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /third_party/edm/torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. 
If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for pickling Python code alongside other data. 9 | 10 | The pickled code is automatically imported into a separate Python module 11 | during unpickling. This way, any previously exported pickles will remain 12 | usable even if the original code is no longer available, or if the current 13 | version of the code is not consistent with what was originally pickled.""" 14 | 15 | import sys 16 | import pickle 17 | import io 18 | import inspect 19 | import copy 20 | import uuid 21 | import types 22 | import dnnlib 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | _version = 6 # internal version number 27 | _decorators = set() # {decorator_class, ...} 28 | _import_hooks = [] # [hook_function, ...] 29 | _module_to_src_dict = dict() # {module: src, ...} 30 | _src_to_module_dict = dict() # {src: module, ...} 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def persistent_class(orig_class): 35 | r"""Class decorator that extends a given class to save its source code 36 | when pickled. 37 | 38 | Example: 39 | 40 | from torch_utils import persistence 41 | 42 | @persistence.persistent_class 43 | class MyNetwork(torch.nn.Module): 44 | def __init__(self, num_inputs, num_outputs): 45 | super().__init__() 46 | self.fc = MyLayer(num_inputs, num_outputs) 47 | ... 48 | 49 | @persistence.persistent_class 50 | class MyLayer(torch.nn.Module): 51 | ... 52 | 53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 54 | source code alongside other internal state (e.g., parameters, buffers, 55 | and submodules). This way, any previously exported pickle will remain 56 | usable even if the class definitions have been modified or are no 57 | longer available. 58 | 59 | The decorator saves the source code of the entire Python module 60 | containing the decorated class. It does *not* save the source code of 61 | any imported modules. Thus, the imported modules must be available 62 | during unpickling, also including `torch_utils.persistence` itself. 63 | 64 | It is ok to call functions defined in the same module from the 65 | decorated class. However, if the decorated class depends on other 66 | classes defined in the same module, they must be decorated as well. 67 | This is illustrated in the above example in the case of `MyLayer`. 68 | 69 | It is also possible to employ the decorator just-in-time before 70 | calling the constructor. For example: 71 | 72 | cls = MyLayer 73 | if want_to_make_it_persistent: 74 | cls = persistence.persistent_class(cls) 75 | layer = cls(num_inputs, num_outputs) 76 | 77 | As an additional feature, the decorator also keeps track of the 78 | arguments that were used to construct each instance of the decorated 79 | class. The arguments can be queried via `obj.init_args` and 80 | `obj.init_kwargs`, and they are automatically pickled alongside other 81 | object state. This feature can be disabled on a per-instance basis 82 | by setting `self._record_init_args = False` in the constructor. 
83 | 84 | A typical use case is to first unpickle a previous instance of a 85 | persistent class, and then upgrade it to use the latest version of 86 | the source code: 87 | 88 | with open('old_pickle.pkl', 'rb') as f: 89 | old_net = pickle.load(f) 90 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) 91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 92 | """ 93 | assert isinstance(orig_class, type) 94 | if is_persistent(orig_class): 95 | return orig_class 96 | 97 | assert orig_class.__module__ in sys.modules 98 | orig_module = sys.modules[orig_class.__module__] 99 | orig_module_src = _module_to_src(orig_module) 100 | 101 | class Decorator(orig_class): 102 | _orig_module_src = orig_module_src 103 | _orig_class_name = orig_class.__name__ 104 | 105 | def __init__(self, *args, **kwargs): 106 | super().__init__(*args, **kwargs) 107 | record_init_args = getattr(self, '_record_init_args', True) 108 | self._init_args = copy.deepcopy(args) if record_init_args else None 109 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None 110 | assert orig_class.__name__ in orig_module.__dict__ 111 | _check_pickleable(self.__reduce__()) 112 | 113 | @property 114 | def init_args(self): 115 | assert self._init_args is not None 116 | return copy.deepcopy(self._init_args) 117 | 118 | @property 119 | def init_kwargs(self): 120 | assert self._init_kwargs is not None 121 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 122 | 123 | def __reduce__(self): 124 | fields = list(super().__reduce__()) 125 | fields += [None] * max(3 - len(fields), 0) 126 | if fields[0] is not _reconstruct_persistent_obj: 127 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 128 | fields[0] = _reconstruct_persistent_obj # reconstruct func 129 | fields[1] = (meta,) # reconstruct args 130 | fields[2] = None # state dict 131 | return tuple(fields) 132 | 133 | Decorator.__name__ = orig_class.__name__ 134 | Decorator.__module__ = orig_class.__module__ 135 | _decorators.add(Decorator) 136 | return Decorator 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | def is_persistent(obj): 141 | r"""Test whether the given object or class is persistent, i.e., 142 | whether it will save its source code when pickled. 143 | """ 144 | try: 145 | if obj in _decorators: 146 | return True 147 | except TypeError: 148 | pass 149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 150 | 151 | #---------------------------------------------------------------------------- 152 | 153 | def import_hook(hook): 154 | r"""Register an import hook that is called whenever a persistent object 155 | is being unpickled. A typical use case is to patch the pickled source 156 | code to avoid errors and inconsistencies when the API of some imported 157 | module has changed. 158 | 159 | The hook should have the following signature: 160 | 161 | hook(meta) -> modified meta 162 | 163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 164 | 165 | type: Type of the persistent object, e.g. `'class'`. 166 | version: Internal version number of `torch_utils.persistence`. 167 | module_src: Original source code of the Python module. 168 | class_name: Class name in the original Python module. 169 | state: Internal state of the object.
170 | 171 | Example: 172 | 173 | @persistence.import_hook 174 | def wreck_my_network(meta): 175 | if meta.class_name == 'MyNetwork': 176 | print('MyNetwork is being imported. I will wreck it!') 177 | meta.module_src = meta.module_src.replace("True", "False") 178 | return meta 179 | """ 180 | assert callable(hook) 181 | _import_hooks.append(hook) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | def _reconstruct_persistent_obj(meta): 186 | r"""Hook that is called internally by the `pickle` module to unpickle 187 | a persistent object. 188 | """ 189 | meta = dnnlib.EasyDict(meta) 190 | meta.state = dnnlib.EasyDict(meta.state) 191 | for hook in _import_hooks: 192 | meta = hook(meta) 193 | assert meta is not None 194 | 195 | assert meta.version == _version 196 | module = _src_to_module(meta.module_src) 197 | 198 | assert meta.type == 'class' 199 | orig_class = module.__dict__[meta.class_name] 200 | decorator_class = persistent_class(orig_class) 201 | obj = decorator_class.__new__(decorator_class) 202 | 203 | setstate = getattr(obj, '__setstate__', None) 204 | if callable(setstate): 205 | setstate(meta.state) # pylint: disable=not-callable 206 | else: 207 | obj.__dict__.update(meta.state) 208 | return obj 209 | 210 | #---------------------------------------------------------------------------- 211 | 212 | def _module_to_src(module): 213 | r"""Query the source code of a given Python module. 214 | """ 215 | src = _module_to_src_dict.get(module, None) 216 | if src is None: 217 | src = inspect.getsource(module) 218 | _module_to_src_dict[module] = src 219 | _src_to_module_dict[src] = module 220 | return src 221 | 222 | def _src_to_module(src): 223 | r"""Get or create a Python module for the given source code. 224 | """ 225 | module = _src_to_module_dict.get(src, None) 226 | if module is None: 227 | module_name = "_imported_module_" + uuid.uuid4().hex 228 | module = types.ModuleType(module_name) 229 | sys.modules[module_name] = module 230 | _module_to_src_dict[module] = src 231 | _src_to_module_dict[src] = module 232 | exec(src, module.__dict__) # pylint: disable=exec-used 233 | return module 234 | 235 | #---------------------------------------------------------------------------- 236 | 237 | def _check_pickleable(obj): 238 | r"""Check that the given object is pickleable, raising an exception if 239 | it is not. This function is expected to be considerably more efficient 240 | than actually pickling the object. 241 | """ 242 | def recurse(obj): 243 | if isinstance(obj, (list, tuple, set)): 244 | return [recurse(x) for x in obj] 245 | if isinstance(obj, dict): 246 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 248 | return None # Python primitive types are pickleable. 249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 250 | return None # NumPy arrays and PyTorch tensors are pickleable. 251 | if is_persistent(obj): 252 | return None # Persistent objects are pickleable, by virtue of the constructor check. 
253 | return obj 254 | with io.BytesIO() as f: 255 | pickle.dump(recurse(obj), f) 256 | 257 | #---------------------------------------------------------------------------- 258 | -------------------------------------------------------------------------------- /third_party/edm/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /third_party/edm/training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Streaming images and labels from datasets created with dataset_tool.py.""" 9 | 10 | import os 11 | import numpy as np 12 | import zipfile 13 | import PIL.Image 14 | import json 15 | import torch 16 | import dnnlib 17 | 18 | try: 19 | import pyspng 20 | except ImportError: 21 | pyspng = None 22 | 23 | #---------------------------------------------------------------------------- 24 | # Abstract base class for datasets. 25 | 26 | class Dataset(torch.utils.data.Dataset): 27 | def __init__(self, 28 | name, # Name of the dataset. 29 | raw_shape, # Shape of the raw image data (NCHW). 30 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 31 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 32 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 33 | random_seed = 0, # Random seed to use when applying max_size. 34 | cache = False, # Cache images in CPU memory? 35 | ): 36 | self._name = name 37 | self._raw_shape = list(raw_shape) 38 | self._use_labels = use_labels 39 | self._cache = cache 40 | self._cached_images = dict() # {raw_idx: np.ndarray, ...} 41 | self._raw_labels = None 42 | self._label_shape = None 43 | 44 | # Apply max_size. 45 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 46 | if (max_size is not None) and (self._raw_idx.size > max_size): 47 | np.random.RandomState(random_seed % (1 << 31)).shuffle(self._raw_idx) 48 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 49 | 50 | # Apply xflip. 
51 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 52 | if xflip: 53 | self._raw_idx = np.tile(self._raw_idx, 2) 54 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 55 | 56 | def _get_raw_labels(self): 57 | if self._raw_labels is None: 58 | self._raw_labels = self._load_raw_labels() if self._use_labels else None 59 | if self._raw_labels is None: 60 | self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32) 61 | assert isinstance(self._raw_labels, np.ndarray) 62 | assert self._raw_labels.shape[0] == self._raw_shape[0] 63 | assert self._raw_labels.dtype in [np.float32, np.int64] 64 | if self._raw_labels.dtype == np.int64: 65 | assert self._raw_labels.ndim == 1 66 | assert np.all(self._raw_labels >= 0) 67 | return self._raw_labels 68 | 69 | def close(self): # to be overridden by subclass 70 | pass 71 | 72 | def _load_raw_image(self, raw_idx): # to be overridden by subclass 73 | raise NotImplementedError 74 | 75 | def _load_raw_labels(self): # to be overridden by subclass 76 | raise NotImplementedError 77 | 78 | def __getstate__(self): 79 | return dict(self.__dict__, _raw_labels=None) 80 | 81 | def __del__(self): 82 | try: 83 | self.close() 84 | except: 85 | pass 86 | 87 | def __len__(self): 88 | return self._raw_idx.size 89 | 90 | def __getitem__(self, idx): 91 | raw_idx = self._raw_idx[idx] 92 | image = self._cached_images.get(raw_idx, None) 93 | if image is None: 94 | image = self._load_raw_image(raw_idx) 95 | if self._cache: 96 | self._cached_images[raw_idx] = image 97 | assert isinstance(image, np.ndarray) 98 | assert list(image.shape) == self.image_shape 99 | assert image.dtype == np.uint8 100 | if self._xflip[idx]: 101 | assert image.ndim == 3 # CHW 102 | image = image[:, :, ::-1] 103 | return image.copy(), self.get_label(idx) 104 | 105 | def get_label(self, idx): 106 | label = self._get_raw_labels()[self._raw_idx[idx]] 107 | if label.dtype == np.int64: 108 | onehot = np.zeros(self.label_shape, dtype=np.float32) 109 | onehot[label] = 1 110 | label = onehot 111 | return label.copy() 112 | 113 | def get_details(self, idx): 114 | d = dnnlib.EasyDict() 115 | d.raw_idx = int(self._raw_idx[idx]) 116 | d.xflip = (int(self._xflip[idx]) != 0) 117 | d.raw_label = self._get_raw_labels()[d.raw_idx].copy() 118 | return d 119 | 120 | @property 121 | def name(self): 122 | return self._name 123 | 124 | @property 125 | def image_shape(self): 126 | return list(self._raw_shape[1:]) 127 | 128 | @property 129 | def num_channels(self): 130 | assert len(self.image_shape) == 3 # CHW 131 | return self.image_shape[0] 132 | 133 | @property 134 | def resolution(self): 135 | assert len(self.image_shape) == 3 # CHW 136 | assert self.image_shape[1] == self.image_shape[2] 137 | return self.image_shape[1] 138 | 139 | @property 140 | def label_shape(self): 141 | if self._label_shape is None: 142 | raw_labels = self._get_raw_labels() 143 | if raw_labels.dtype == np.int64: 144 | self._label_shape = [int(np.max(raw_labels)) + 1] 145 | else: 146 | self._label_shape = raw_labels.shape[1:] 147 | return list(self._label_shape) 148 | 149 | @property 150 | def label_dim(self): 151 | assert len(self.label_shape) == 1 152 | return self.label_shape[0] 153 | 154 | @property 155 | def has_labels(self): 156 | return any(x != 0 for x in self.label_shape) 157 | 158 | @property 159 | def has_onehot_labels(self): 160 | return self._get_raw_labels().dtype == np.int64 161 | 162 | #---------------------------------------------------------------------------- 163 | # Dataset subclass that 
loads images recursively from the specified directory 164 | # or ZIP file. 165 | 166 | class ImageFolderDataset(Dataset): 167 | def __init__(self, 168 | path, # Path to directory or zip. 169 | resolution = None, # Ensure specific resolution, None = highest available. 170 | use_pyspng = True, # Use pyspng if available? 171 | **super_kwargs, # Additional arguments for the Dataset base class. 172 | ): 173 | self._path = path 174 | self._use_pyspng = use_pyspng 175 | self._zipfile = None 176 | 177 | if os.path.isdir(self._path): 178 | self._type = 'dir' 179 | self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files} 180 | elif self._file_ext(self._path) == '.zip': 181 | self._type = 'zip' 182 | self._all_fnames = set(self._get_zipfile().namelist()) 183 | else: 184 | raise IOError('Path must point to a directory or zip') 185 | 186 | PIL.Image.init() 187 | self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION) 188 | if len(self._image_fnames) == 0: 189 | raise IOError('No image files found in the specified path') 190 | 191 | name = os.path.splitext(os.path.basename(self._path))[0] 192 | raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape) 193 | if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution): 194 | raise IOError('Image files do not match the specified resolution') 195 | super().__init__(name=name, raw_shape=raw_shape, **super_kwargs) 196 | 197 | @staticmethod 198 | def _file_ext(fname): 199 | return os.path.splitext(fname)[1].lower() 200 | 201 | def _get_zipfile(self): 202 | assert self._type == 'zip' 203 | if self._zipfile is None: 204 | self._zipfile = zipfile.ZipFile(self._path) 205 | return self._zipfile 206 | 207 | def _open_file(self, fname): 208 | if self._type == 'dir': 209 | return open(os.path.join(self._path, fname), 'rb') 210 | if self._type == 'zip': 211 | return self._get_zipfile().open(fname, 'r') 212 | return None 213 | 214 | def close(self): 215 | try: 216 | if self._zipfile is not None: 217 | self._zipfile.close() 218 | finally: 219 | self._zipfile = None 220 | 221 | def __getstate__(self): 222 | return dict(super().__getstate__(), _zipfile=None) 223 | 224 | def _load_raw_image(self, raw_idx): 225 | fname = self._image_fnames[raw_idx] 226 | with self._open_file(fname) as f: 227 | if self._use_pyspng and pyspng is not None and self._file_ext(fname) == '.png': 228 | image = pyspng.load(f.read()) 229 | else: 230 | image = np.array(PIL.Image.open(f)) 231 | if image.ndim == 2: 232 | image = image[:, :, np.newaxis] # HW => HWC 233 | image = image.transpose(2, 0, 1) # HWC => CHW 234 | return image 235 | 236 | def _load_raw_labels(self): 237 | fname = 'dataset.json' 238 | if fname not in self._all_fnames: 239 | return None 240 | with self._open_file(fname) as f: 241 | labels = json.load(f)['labels'] 242 | if labels is None: 243 | return None 244 | labels = dict(labels) 245 | labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames] 246 | labels = np.array(labels) 247 | labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim]) 248 | return labels 249 | 250 | #---------------------------------------------------------------------------- 251 | -------------------------------------------------------------------------------- /third_party/edm/training/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Loss functions used in the paper 9 | "Elucidating the Design Space of Diffusion-Based Generative Models".""" 10 | 11 | import torch 12 | from torch_utils import persistence 13 | 14 | #---------------------------------------------------------------------------- 15 | # Loss function corresponding to the variance preserving (VP) formulation 16 | # from the paper "Score-Based Generative Modeling through Stochastic 17 | # Differential Equations". 18 | 19 | @persistence.persistent_class 20 | class VPLoss: 21 | def __init__(self, beta_d=19.9, beta_min=0.1, epsilon_t=1e-5): 22 | self.beta_d = beta_d 23 | self.beta_min = beta_min 24 | self.epsilon_t = epsilon_t 25 | 26 | def __call__(self, net, images, labels, augment_pipe=None): 27 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 28 | sigma = self.sigma(1 + rnd_uniform * (self.epsilon_t - 1)) 29 | weight = 1 / sigma ** 2 30 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 31 | n = torch.randn_like(y) * sigma 32 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 33 | loss = weight * ((D_yn - y) ** 2) 34 | return loss 35 | 36 | def sigma(self, t): 37 | t = torch.as_tensor(t) 38 | return ((0.5 * self.beta_d * (t ** 2) + self.beta_min * t).exp() - 1).sqrt() 39 | 40 | #---------------------------------------------------------------------------- 41 | # Loss function corresponding to the variance exploding (VE) formulation 42 | # from the paper "Score-Based Generative Modeling through Stochastic 43 | # Differential Equations". 44 | 45 | @persistence.persistent_class 46 | class VELoss: 47 | def __init__(self, sigma_min=0.02, sigma_max=100): 48 | self.sigma_min = sigma_min 49 | self.sigma_max = sigma_max 50 | 51 | def __call__(self, net, images, labels, augment_pipe=None): 52 | rnd_uniform = torch.rand([images.shape[0], 1, 1, 1], device=images.device) 53 | sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rnd_uniform) 54 | weight = 1 / sigma ** 2 55 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 56 | n = torch.randn_like(y) * sigma 57 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 58 | loss = weight * ((D_yn - y) ** 2) 59 | return loss 60 | 61 | #---------------------------------------------------------------------------- 62 | # Improved loss function proposed in the paper "Elucidating the Design Space 63 | # of Diffusion-Based Generative Models" (EDM). 
64 | 65 | @persistence.persistent_class 66 | class EDMLoss: 67 | def __init__(self, P_mean=-1.2, P_std=1.2, sigma_data=0.5): 68 | self.P_mean = P_mean 69 | self.P_std = P_std 70 | self.sigma_data = sigma_data 71 | 72 | def __call__(self, net, images, labels=None, augment_pipe=None): 73 | rnd_normal = torch.randn([images.shape[0], 1, 1, 1], device=images.device) 74 | sigma = (rnd_normal * self.P_std + self.P_mean).exp() 75 | weight = (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 76 | y, augment_labels = augment_pipe(images) if augment_pipe is not None else (images, None) 77 | n = torch.randn_like(y) * sigma 78 | D_yn = net(y + n, sigma, labels, augment_labels=augment_labels) 79 | loss = weight * ((D_yn - y) ** 2) 80 | return loss 81 | 82 | #---------------------------------------------------------------------------- 83 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | # empty 9 | -------------------------------------------------------------------------------- /torch_utils/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | import os 9 | import torch 10 | from . 
import training_stats 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def init(): 15 | if 'MASTER_ADDR' not in os.environ: 16 | os.environ['MASTER_ADDR'] = 'localhost' 17 | if 'MASTER_PORT' not in os.environ: 18 | os.environ['MASTER_PORT'] = '29500' 19 | if 'RANK' not in os.environ: 20 | os.environ['RANK'] = '0' 21 | if 'LOCAL_RANK' not in os.environ: 22 | os.environ['LOCAL_RANK'] = '0' 23 | if 'WORLD_SIZE' not in os.environ: 24 | os.environ['WORLD_SIZE'] = '1' 25 | 26 | backend = 'gloo' if os.name == 'nt' else 'nccl' 27 | torch.distributed.init_process_group(backend=backend, init_method='env://') 28 | torch.cuda.set_device(int(os.environ.get('LOCAL_RANK', '0'))) 29 | 30 | sync_device = torch.device('cuda') if get_world_size() > 1 else None 31 | training_stats.init_multiprocessing(rank=get_rank(), sync_device=sync_device) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def get_rank(): 36 | return torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | def get_world_size(): 41 | return torch.distributed.get_world_size() if torch.distributed.is_initialized() else 1 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | def should_stop(): 46 | return False 47 | 48 | #---------------------------------------------------------------------------- 49 | 50 | def update_progress(cur, total): 51 | _ = cur, total 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | def print0(*args, **kwargs): 56 | if get_rank() == 0: 57 | print(*args, **kwargs) 58 | 59 | #---------------------------------------------------------------------------- 60 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # This work is licensed under a Creative Commons 4 | # Attribution-NonCommercial-ShareAlike 4.0 International License. 5 | # You should have received a copy of the license along with this 6 | # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/ 7 | 8 | """Facilities for pickling Python code alongside other data. 9 | 10 | The pickled code is automatically imported into a separate Python module 11 | during unpickling. This way, any previously exported pickles will remain 12 | usable even if the original code is no longer available, or if the current 13 | version of the code is not consistent with what was originally pickled.""" 14 | 15 | import sys 16 | import pickle 17 | import io 18 | import inspect 19 | import copy 20 | import uuid 21 | import types 22 | import dnnlib 23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | _version = 6 # internal version number 27 | _decorators = set() # {decorator_class, ...} 28 | _import_hooks = [] # [hook_function, ...] 29 | _module_to_src_dict = dict() # {module: src, ...} 30 | _src_to_module_dict = dict() # {src: module, ...} 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def persistent_class(orig_class): 35 | r"""Class decorator that extends a given class to save its source code 36 | when pickled. 
37 | 38 | Example: 39 | 40 | from torch_utils import persistence 41 | 42 | @persistence.persistent_class 43 | class MyNetwork(torch.nn.Module): 44 | def __init__(self, num_inputs, num_outputs): 45 | super().__init__() 46 | self.fc = MyLayer(num_inputs, num_outputs) 47 | ... 48 | 49 | @persistence.persistent_class 50 | class MyLayer(torch.nn.Module): 51 | ... 52 | 53 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 54 | source code alongside other internal state (e.g., parameters, buffers, 55 | and submodules). This way, any previously exported pickle will remain 56 | usable even if the class definitions have been modified or are no 57 | longer available. 58 | 59 | The decorator saves the source code of the entire Python module 60 | containing the decorated class. It does *not* save the source code of 61 | any imported modules. Thus, the imported modules must be available 62 | during unpickling, also including `torch_utils.persistence` itself. 63 | 64 | It is ok to call functions defined in the same module from the 65 | decorated class. However, if the decorated class depends on other 66 | classes defined in the same module, they must be decorated as well. 67 | This is illustrated in the above example in the case of `MyLayer`. 68 | 69 | It is also possible to employ the decorator just-in-time before 70 | calling the constructor. For example: 71 | 72 | cls = MyLayer 73 | if want_to_make_it_persistent: 74 | cls = persistence.persistent_class(cls) 75 | layer = cls(num_inputs, num_outputs) 76 | 77 | As an additional feature, the decorator also keeps track of the 78 | arguments that were used to construct each instance of the decorated 79 | class. The arguments can be queried via `obj.init_args` and 80 | `obj.init_kwargs`, and they are automatically pickled alongside other 81 | object state. This feature can be disabled on a per-instance basis 82 | by setting `self._record_init_args = False` in the constructor. 
83 | 84 | A typical use case is to first unpickle a previous instance of a 85 | persistent class, and then upgrade it to use the latest version of 86 | the source code: 87 | 88 | with open('old_pickle.pkl', 'rb') as f: 89 | old_net = pickle.load(f) 90 | new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs) 91 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 92 | """ 93 | assert isinstance(orig_class, type) 94 | if is_persistent(orig_class): 95 | return orig_class 96 | 97 | assert orig_class.__module__ in sys.modules 98 | orig_module = sys.modules[orig_class.__module__] 99 | orig_module_src = _module_to_src(orig_module) 100 | 101 | class Decorator(orig_class): 102 | _orig_module_src = orig_module_src 103 | _orig_class_name = orig_class.__name__ 104 | 105 | def __init__(self, *args, **kwargs): 106 | super().__init__(*args, **kwargs) 107 | record_init_args = getattr(self, '_record_init_args', True) 108 | self._init_args = copy.deepcopy(args) if record_init_args else None 109 | self._init_kwargs = copy.deepcopy(kwargs) if record_init_args else None 110 | assert orig_class.__name__ in orig_module.__dict__ 111 | _check_pickleable(self.__reduce__()) 112 | 113 | @property 114 | def init_args(self): 115 | assert self._init_args is not None 116 | return copy.deepcopy(self._init_args) 117 | 118 | @property 119 | def init_kwargs(self): 120 | assert self._init_kwargs is not None 121 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 122 | 123 | def __reduce__(self): 124 | fields = list(super().__reduce__()) 125 | fields += [None] * max(3 - len(fields), 0) 126 | if fields[0] is not _reconstruct_persistent_obj: 127 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 128 | fields[0] = _reconstruct_persistent_obj # reconstruct func 129 | fields[1] = (meta,) # reconstruct args 130 | fields[2] = None # state dict 131 | return tuple(fields) 132 | 133 | Decorator.__name__ = orig_class.__name__ 134 | Decorator.__module__ = orig_class.__module__ 135 | _decorators.add(Decorator) 136 | return Decorator 137 | 138 | #---------------------------------------------------------------------------- 139 | 140 | def is_persistent(obj): 141 | r"""Test whether the given object or class is persistent, i.e., 142 | whether it will save its source code when pickled. 143 | """ 144 | try: 145 | if obj in _decorators: 146 | return True 147 | except TypeError: 148 | pass 149 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 150 | 151 | #---------------------------------------------------------------------------- 152 | 153 | def import_hook(hook): 154 | r"""Register an import hook that is called whenever a persistent object 155 | is being unpickled. A typical use case is to patch the pickled source 156 | code to avoid errors and inconsistencies when the API of some imported 157 | module has changed. 158 | 159 | The hook should have the following signature: 160 | 161 | hook(meta) -> modified meta 162 | 163 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 164 | 165 | type: Type of the persistent object, e.g. `'class'`. 166 | version: Internal version number of `torch_utils.persistence`. 167 | module_src: Original source code of the Python module. 168 | class_name: Class name in the original Python module. 169 | state: Internal state of the object.
170 | 171 | Example: 172 | 173 | @persistence.import_hook 174 | def wreck_my_network(meta): 175 | if meta.class_name == 'MyNetwork': 176 | print('MyNetwork is being imported. I will wreck it!') 177 | meta.module_src = meta.module_src.replace("True", "False") 178 | return meta 179 | """ 180 | assert callable(hook) 181 | _import_hooks.append(hook) 182 | 183 | #---------------------------------------------------------------------------- 184 | 185 | def _reconstruct_persistent_obj(meta): 186 | r"""Hook that is called internally by the `pickle` module to unpickle 187 | a persistent object. 188 | """ 189 | meta = dnnlib.EasyDict(meta) 190 | meta.state = dnnlib.EasyDict(meta.state) 191 | for hook in _import_hooks: 192 | meta = hook(meta) 193 | assert meta is not None 194 | 195 | assert meta.version == _version 196 | module = _src_to_module(meta.module_src) 197 | 198 | assert meta.type == 'class' 199 | orig_class = module.__dict__[meta.class_name] 200 | decorator_class = persistent_class(orig_class) 201 | obj = decorator_class.__new__(decorator_class) 202 | 203 | setstate = getattr(obj, '__setstate__', None) 204 | if callable(setstate): 205 | setstate(meta.state) # pylint: disable=not-callable 206 | else: 207 | obj.__dict__.update(meta.state) 208 | return obj 209 | 210 | #---------------------------------------------------------------------------- 211 | 212 | def _module_to_src(module): 213 | r"""Query the source code of a given Python module. 214 | """ 215 | src = _module_to_src_dict.get(module, None) 216 | if src is None: 217 | src = inspect.getsource(module) 218 | _module_to_src_dict[module] = src 219 | _src_to_module_dict[src] = module 220 | return src 221 | 222 | def _src_to_module(src): 223 | r"""Get or create a Python module for the given source code. 224 | """ 225 | module = _src_to_module_dict.get(src, None) 226 | if module is None: 227 | module_name = "_imported_module_" + uuid.uuid4().hex 228 | module = types.ModuleType(module_name) 229 | sys.modules[module_name] = module 230 | _module_to_src_dict[module] = src 231 | _src_to_module_dict[src] = module 232 | exec(src, module.__dict__) # pylint: disable=exec-used 233 | return module 234 | 235 | #---------------------------------------------------------------------------- 236 | 237 | def _check_pickleable(obj): 238 | r"""Check that the given object is pickleable, raising an exception if 239 | it is not. This function is expected to be considerably more efficient 240 | than actually pickling the object. 241 | """ 242 | def recurse(obj): 243 | if isinstance(obj, (list, tuple, set)): 244 | return [recurse(x) for x in obj] 245 | if isinstance(obj, dict): 246 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 247 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 248 | return None # Python primitive types are pickleable. 249 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 250 | return None # NumPy arrays and PyTorch tensors are pickleable. 251 | if is_persistent(obj): 252 | return None # Persistent objects are pickleable, by virtue of the constructor check. 253 | return obj 254 | with io.BytesIO() as f: 255 | pickle.dump(recurse(obj), f) 256 | 257 | #---------------------------------------------------------------------------- 258 | --------------------------------------------------------------------------------
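A compact way to exercise the persistence machinery above is to decorate a small module, pickle an instance, and load it back within the same process. The sketch below assumes it is run as a script from the repository root, so that `torch_utils` and `dnnlib` are importable and `inspect.getsource` can read the defining module; the class name `TinyNet` and the file name `tiny_net.pkl` are illustrative only.

# Minimal sketch of torch_utils.persistence in use; names and paths are illustrative.
import pickle
import torch
from torch_utils import persistence

@persistence.persistent_class
class TinyNet(torch.nn.Module):
    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        self.fc = torch.nn.Linear(num_inputs, num_outputs)

if __name__ == '__main__':  # guard so the demo is not re-run if the embedded source is ever re-executed
    net = TinyNet(4, 2)
    with open('tiny_net.pkl', 'wb') as f:
        pickle.dump(net, f)              # the source of this module travels inside the pickle
    with open('tiny_net.pkl', 'rb') as f:
        restored = pickle.load(f)        # rebuilt via _reconstruct_persistent_obj
    print(restored.init_args)            # (4, 2), recorded automatically by the decorator

Because the decorator embeds the source of the entire defining module into the pickle, decorated classes are best kept in small importable modules, as is done for the EDM losses and networks under training/.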