├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── README.md.backup ├── align.py ├── autoencoding.ipynb ├── choices.py ├── cog.yaml ├── config.py ├── config_base.py ├── data_resize_bedroom.py ├── data_resize_celeba.py ├── data_resize_celebahq.py ├── data_resize_ffhq.py ├── data_resize_horse.py ├── dataset.py ├── dataset_util.py ├── datasets └── celeba_anno │ ├── CelebAMask-HQ-attribute-anno.txt │ ├── CelebAMask-HQ-pose-anno.txt │ └── list_attr_celeba.txt ├── diffusion ├── __init__.py ├── base.py ├── diffusion.py └── resample.py ├── dist_utils.py ├── evals ├── ffhq128_autoenc_130M.txt └── ffhq128_autoenc_latent.txt ├── experiment.py ├── experiment_classifier.py ├── imgs └── sandy.JPG ├── imgs_align └── sandy.png ├── imgs_interpolate ├── 1_a.png └── 1_b.png ├── imgs_manipulated ├── compare.png ├── output.png └── sandy-wavyhair.png ├── install_requirements_for_colab.sh ├── interpolate.ipynb ├── lmdb_writer.py ├── manipulate.ipynb ├── manipulate_note.ipynb ├── metrics.py ├── model ├── __init__.py ├── blocks.py ├── latentnet.py ├── nn.py ├── unet.py └── unet_autoenc.py ├── predict.py ├── renderer.py ├── requirement_for_colab.txt ├── requirements.txt ├── run_bedroom128.py ├── run_bedroom128_ddim.py ├── run_celeba64.py ├── run_ffhq128.py ├── run_ffhq128_cls.py ├── run_ffhq128_ddim.py ├── run_ffhq256.py ├── run_ffhq256.sh ├── run_ffhq256_cls.py ├── run_ffhq256_latent.py ├── run_horse128.py ├── run_horse128_ddim.py ├── sample.ipynb ├── ssim.py ├── templates.py ├── templates_cls.py └── templates_latent.py /.gitignore: -------------------------------------------------------------------------------- 1 | temp 2 | __pycache__ 3 | generated 4 | latent_infer 5 | datasets/bedroom256.lmdb 6 | datasets/horse256.lmdb 7 | datasets/celebahq 8 | datasets/celebahq256.lmdb 9 | datasets/ffhq 10 | datasets/ffhq256.lmdb 11 | datasets/celeba.lmdb 12 | datasets/celeba 13 | checkpoints 14 | 15 | .ipynb_checkpoints/ -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.formatting.provider": "yapf" 3 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 VISTEC - Vidyasirimedhi Institute of Science and Technology 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Official implementation of Diffusion Autoencoders 2 | 3 | A CVPR 2022 (ORAL) paper ([paper](https://openaccess.thecvf.com/content/CVPR2022/html/Preechakul_Diffusion_Autoencoders_Toward_a_Meaningful_and_Decodable_Representation_CVPR_2022_paper.html), [site](https://diff-ae.github.io/), [5-min video](https://youtu.be/i3rjEsiHoUU)): 4 | 5 | ``` 6 | @inproceedings{preechakul2021diffusion, 7 | title={Diffusion Autoencoders: Toward a Meaningful and Decodable Representation}, 8 | author={Preechakul, Konpat and Chatthee, Nattanat and Wizadwongsa, Suttisak and Suwajanakorn, Supasorn}, 9 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 10 | year={2022}, 11 | } 12 | ``` 13 | 14 | ## Usage 15 | 16 | ⚙️ Try a Colab walkthrough: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1OTfwkklN-IEd4hFk4LnweOleyDtS4XTh/view?usp=sharing) 17 | 18 | 🤗 Try a web demo: [![Replicate](https://replicate.com/cjwbw/diffae/badge)](https://replicate.com/cjwbw/diffae) 19 | 20 | Note: Since we expect a lot of changes on the codebase, please fork the repo before using. 21 | 22 | ### Prerequisites 23 | 24 | See `requirements.txt` 25 | 26 | ``` 27 | pip install -r requirements.txt 28 | ``` 29 | 30 | ### Quick start 31 | 32 | A jupyter notebook. 33 | 34 | For unconditional generation: `sample.ipynb` 35 | 36 | For manipulation: `manipulate.ipynb` 37 | 38 | For interpolation: `interpolate.ipynb` 39 | 40 | For autoencoding: `autoencoding.ipynb` 41 | 42 | Aligning your own images: 43 | 44 | 1. Put images into the `imgs` directory 45 | 2. Run `align.py` (need to `pip install dlib requests`) 46 | 3. Result images will be available in `imgs_align` directory 47 | 48 | 49 | 50 | 53 | 56 | 59 | 60 |
51 | | ![](imgs/sandy.JPG) | ![](imgs_align/sandy.png) | ![](imgs_manipulated/sandy-wavyhair.png) |
52 | |---|---|---|
54 | | Original in `imgs` directory | Aligned with `align.py` | Using `manipulate.ipynb` |
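
Based on the argument parser in `align.py` (shown later in this repo), a typical invocation with the default input/output directories would be:

```
python align.py -i imgs -o imgs_align
```

The first run downloads dlib's 68-point landmark model into `temp/`, which can take a while.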
61 | 62 | 63 | ### Checkpoints 64 | 65 | We provide checkpoints for the following models: 66 | 67 | 1. DDIM: **FFHQ128** ([72M](https://drive.google.com/drive/folders/1-fa46UPSgy9ximKngBflgSj3u87-DLrw), [130M](https://drive.google.com/drive/folders/1-Sqes07fs1y9sAYXuYWSoDE_xxTtH4yx)), [**Bedroom128**](https://drive.google.com/drive/folders/1-_8LZd5inoAOBT-hO5f7RYivt95FbYT1), [**Horse128**](https://drive.google.com/drive/folders/10Hq3zIlJs9ZSiXDQVYuVJVf0cX4a_nDB) 68 | 2. DiffAE (autoencoding only): [**FFHQ256**](https://drive.google.com/drive/folders/1-5zfxT6Gl-GjxM7z9ZO2AHlB70tfmF6V), **FFHQ128** ([72M](https://drive.google.com/drive/folders/10bmB6WhLkgxybkhso5g3JmIFPAnmZMQO), [130M](https://drive.google.com/drive/folders/10UNtFNfxbHBPkoIh003JkSPto5s-VbeN)), [**Bedroom128**](https://drive.google.com/drive/folders/12EdjbIKnvP5RngKsR0UU-4kgpPAaYtlp), [**Horse128**](https://drive.google.com/drive/folders/12EtTRXzQc5uPHscpjIcci-Rg-OGa_N30) 69 | 3. DiffAE (with latent DPM, can sample): [**FFHQ256**](https://drive.google.com/drive/folders/1-H8WzKc65dEONN-DQ87TnXc23nTXDTYb), [**FFHQ128**](https://drive.google.com/drive/folders/11pdjMQ6NS8GFFiGOq3fziNJxzXU1Mw3l), [**Bedroom128**](https://drive.google.com/drive/folders/11mdxv2lVX5Em8TuhNJt-Wt2XKt25y8zU), [**Horse128**](https://drive.google.com/drive/folders/11k8XNDK3ENxiRnPSUdJ4rnagJYo4uKEo) 70 | 4. DiffAE's classifiers (for manipulation): [**FFHQ256's latent on CelebAHQ**](https://drive.google.com/drive/folders/117Wv7RZs_gumgrCOIhDEWgsNy6BRJorg), [**FFHQ128's latent on CelebAHQ**](https://drive.google.com/drive/folders/11EYIyuK6IX44C8MqreUyMgPCNiEnwhmI) 71 | 72 | Checkpoints ought to be put into a separate directory `checkpoints`. 73 | Download the checkpoints and put them into `checkpoints` directory. It should look like this: 74 | 75 | ``` 76 | checkpoints/ 77 | - bedroom128_autoenc 78 | - last.ckpt # diffae checkpoint 79 | - latent.ckpt # predicted z_sem on the dataset 80 | - bedroom128_autoenc_latent 81 | - last.ckpt # diffae + latent DPM checkpoint 82 | - bedroom128_ddpm 83 | - ... 84 | ``` 85 | 86 | 87 | ### LMDB Datasets 88 | 89 | We do not own any of the following datasets. We provide the LMDB ready-to-use dataset for the sake of convenience. 90 | 91 | - [FFHQ](https://1drv.ms/f/s!Ar2O0vx8sW70uLV1Ivk2pTjam1A8VA) 92 | - [CelebAHQ](https://1drv.ms/f/s!Ar2O0vx8sW70uL4GMeWEciHkHdH6vQ) 93 | 94 | **Broken links** 95 | 96 | Note: I'm trying to recover the following links. 97 | 98 | - [CelebA](https://drive.google.com/drive/folders/1HJAhK2hLYcT_n0gWlCu5XxdZj-bPekZ0?usp=sharing) 99 | - [LSUN Bedroom](https://drive.google.com/drive/folders/1O_3aT3LtY1YDE2pOQCp6MFpCk7Pcpkhb?usp=sharing) 100 | - [LSUN Horse](https://drive.google.com/drive/folders/1ooHW7VivZUs4i5CarPaWxakCwfeqAK8l?usp=sharing) 101 | 102 | The directory tree should be: 103 | 104 | ``` 105 | datasets/ 106 | - bedroom256.lmdb 107 | - celebahq256.lmdb 108 | - celeba.lmdb 109 | - ffhq256.lmdb 110 | - horse256.lmdb 111 | ``` 112 | 113 | You can also download from the original sources, and use our provided codes to package them as LMDB files. 
114 | Original sources for each dataset is as follows: 115 | 116 | - FFHQ (https://github.com/NVlabs/ffhq-dataset) 117 | - CelebAHQ (https://github.com/switchablenorms/CelebAMask-HQ) 118 | - CelebA (https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) 119 | - LSUN (https://github.com/fyu/lsun) 120 | 121 | The conversion codes are provided as: 122 | 123 | ``` 124 | data_resize_bedroom.py 125 | data_resize_celebhq.py 126 | data_resize_celeba.py 127 | data_resize_ffhq.py 128 | data_resize_horse.py 129 | ``` 130 | 131 | Google drive: https://drive.google.com/drive/folders/1abNP4QKGbNnymjn8607BF0cwxX2L23jh?usp=sharing 132 | 133 | 134 | ## Training 135 | 136 | We provide scripts for training & evaluate DDIM and DiffAE (including latent DPM) on the following datasets: FFHQ128, FFHQ256, Bedroom128, Horse128, Celeba64 (D2C's crop). 137 | Usually, the evaluation results (FID's) will be available in `eval` directory. 138 | 139 | Note: Most experiment requires at least 4x V100s during training the DPM models while requiring 1x 2080Ti during training the accompanying latent DPM. 140 | 141 | 142 | 143 | **FFHQ128** 144 | ``` 145 | # diffae 146 | python run_ffhq128.py 147 | # ddim 148 | python run_ffhq128_ddim.py 149 | ``` 150 | 151 | A classifier (for manipulation) can be trained using: 152 | ``` 153 | python run_ffhq128_cls.py 154 | ``` 155 | 156 | **FFHQ256** 157 | 158 | We only trained the DiffAE due to high computation cost. 159 | This requires 8x V100s. 160 | ``` 161 | sbatch run_ffhq256.py 162 | ``` 163 | 164 | After the task is done, you need to train the latent DPM (requiring only 1x 2080Ti) 165 | ``` 166 | python run_ffhq256_latent.py 167 | ``` 168 | 169 | A classifier (for manipulation) can be trained using: 170 | ``` 171 | python run_ffhq256_cls.py 172 | ``` 173 | 174 | **Bedroom128** 175 | 176 | ``` 177 | # diffae 178 | python run_bedroom128.py 179 | # ddim 180 | python run_bedroom128_ddim.py 181 | ``` 182 | 183 | **Horse128** 184 | 185 | ``` 186 | # diffae 187 | python run_horse128.py 188 | # ddim 189 | python run_horse128_ddim.py 190 | ``` 191 | 192 | **Celeba64** 193 | 194 | This experiment can be run on 2080Ti's. 195 | 196 | ``` 197 | # diffae 198 | python run_celeba64.py 199 | ``` 200 | -------------------------------------------------------------------------------- /README.md.backup: -------------------------------------------------------------------------------- 1 | # Official implementation of Diffusion Autoencoders 2 | 3 | A CVPR 2022 paper: 4 | 5 | > Preechakul, Konpat, Nattanat Chatthee, Suttisak Wizadwongsa, and Supasorn Suwajanakorn. 2021. “Diffusion Autoencoders: Toward a Meaningful and Decodable Representation.” arXiv [cs.CV]. arXiv. http://arxiv.org/abs/2111.15640. 6 | 7 | ## Usage 8 | 9 | Note: Since we expect a lot of changes on the codebase, please fork the repo before using. 10 | 11 | ### Prerequisites 12 | 13 | See `requirements.txt` 14 | 15 | ``` 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ### Quick start 20 | 21 | A jupyter notebook. 22 | 23 | For unconditional generation: `sample.ipynb` 24 | 25 | For manipulation: `manipulate.ipynb` 26 | 27 | Aligning your own images: 28 | 29 | 1. Put images into the `imgs` directory 30 | 2. Run `align.py` (need to `pip install dlib requests`) 31 | 3. 
Result images will be available in `imgs_align` directory 32 | 33 | 34 | 39 | 40 | | ![](imgs/sandy.JPG) | ![](imgs_align/sandy.png) | ![](imgs_manipulated/sandy-wavyhair.png) | 41 | |---|---|---| 42 | 43 | 44 | ### Checkpoints 45 | 46 | We provide checkpoints for the following models: 47 | 48 | 1. DDIM: **FFHQ128** ([72M](https://drive.google.com/drive/folders/1-J8FPNZOQxSqpfTpwRXawLi2KKGL1qlK?usp=sharing), [130M](https://drive.google.com/drive/folders/17T5YJXpYdgE6cWltN8gZFxRsJzpVxnLh?usp=sharing)), [**Bedroom128**](https://drive.google.com/drive/folders/19s-lAiK7fGD5Meo5obNV5o0L3MfqU0Sk?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1PiC5JWLcd8mZW9cghDCR0V4Hx0QCXOor?usp=sharing) 49 | 2. DiffAE (autoencoding only): [**FFHQ256**](https://drive.google.com/drive/folders/1hTP9QbYXwv_Nl5sgcZNH0yKprJx7ivC5?usp=sharing), **FFHQ128** ([72M](https://drive.google.com/drive/folders/15QHmZP1G5jEMh80R1Nbtdb4ZKb6VvfII?usp=sharing), [130M](https://drive.google.com/drive/folders/1UlwLwgv16cEqxTn7g-V2ykIyopmY_fVz?usp=sharing)), [**Bedroom128**](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing) 50 | 3. DiffAE (with latent DPM, can sample): [**FFHQ256**](https://drive.google.com/drive/folders/1MonJKYwVLzvCFYuVhp-l9mChq5V2XI6w?usp=sharing), [**FFHQ128**](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing), [**Bedroom128**](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing) 51 | 4. DiffAE's classifiers (for manipulation): [**FFHQ256's latent on CelebAHQ**](https://drive.google.com/drive/folders/1QGkTfvNhgi_TbbV8GbX1Emrp0lStsqLj?usp=sharing), [**FFHQ128's latent on CelebAHQ**](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing) 52 | 53 | Checkpoints ought to be put into a separate directory `checkpoints`. 54 | Download the checkpoints and put them into `checkpoints` directory. It should look like this: 55 | 56 | ``` 57 | checkpoints/ 58 | - bedroom128_autoenc 59 | - last.ckpt # diffae checkpoint 60 | - latent.ckpt # predicted z_sem on the dataset 61 | - bedroom128_autoenc_latent 62 | - last.ckpt # diffae + latent DPM checkpoint 63 | - bedroom128_ddpm 64 | - ... 65 | ``` 66 | 67 | 68 | ### LMDB Datasets 69 | 70 | We do not own any of the following datasets. We provide the LMDB ready-to-use dataset for the sake of convenience. 71 | 72 | - [FFHQ](https://drive.google.com/drive/folders/1ww7itaSo53NDMa0q-wn-3HWZ3HHqK1IK?usp=sharing) 73 | - [CelebAHQ](https://drive.google.com/drive/folders/1SX3JuVHjYA8sA28EGxr_IoHJ63s4Btbl?usp=sharing) 74 | - [CelebA](https://drive.google.com/drive/folders/1HJAhK2hLYcT_n0gWlCu5XxdZj-bPekZ0?usp=sharing) 75 | - [LSUN Bedroom](https://drive.google.com/drive/folders/1O_3aT3LtY1YDE2pOQCp6MFpCk7Pcpkhb?usp=sharing) 76 | - [LSUN Horse](https://drive.google.com/drive/folders/1ooHW7VivZUs4i5CarPaWxakCwfeqAK8l?usp=sharing) 77 | 78 | The directory tree should be: 79 | 80 | ``` 81 | datasets/ 82 | - bedroom256.lmdb 83 | - celebahq256.lmdb 84 | - celeba.lmdb 85 | - ffhq256.lmdb 86 | - horse256.lmdb 87 | ``` 88 | 89 | You can also download from the original sources, and use our provided codes to package them as LMDB files. 
90 | Original sources for each dataset is as follows: 91 | 92 | - FFHQ (https://github.com/NVlabs/ffhq-dataset) 93 | - CelebAHQ (https://github.com/switchablenorms/CelebAMask-HQ) 94 | - CelebA (https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) 95 | - LSUN (https://github.com/fyu/lsun) 96 | 97 | The conversion codes are provided as: 98 | 99 | ``` 100 | data_resize_bedroom.py 101 | data_resize_celebhq.py 102 | data_resize_celeba.py 103 | data_resize_ffhq.py 104 | data_resize_horse.py 105 | ``` 106 | 107 | Google drive: https://drive.google.com/drive/folders/1abNP4QKGbNnymjn8607BF0cwxX2L23jh?usp=sharing 108 | 109 | 110 | ## Training 111 | 112 | We provide scripts for training & evaluate DDIM and DiffAE (including latent DPM) on the following datasets: FFHQ128, FFHQ256, Bedroom128, Horse128, Celeba64 (D2C's crop). 113 | Usually, the evaluation results (FID's) will be available in `eval` directory. 114 | 115 | Note: Most experiment requires at least 4x V100s during training the DPM models while requiring 1x 2080Ti during training the accompanying latent DPM. 116 | 117 | 118 | 119 | **FFHQ128** 120 | ``` 121 | # diffae 122 | python run_ffhq128.py 123 | # ddim 124 | python run_ffhq128_ddim.py 125 | ``` 126 | 127 | A classifier (for manipulation) can be trained using: 128 | ``` 129 | python run_ffhq128_cls.py 130 | ``` 131 | 132 | **FFHQ256** 133 | 134 | We only trained the DiffAE due to high computation cost. 135 | This requires 8x V100s. 136 | ``` 137 | sbatch run_ffhq256.py 138 | ``` 139 | 140 | After the task is done, you need to train the latent DPM (requiring only 1x 2080Ti) 141 | ``` 142 | python run_ffhq256_latent.py 143 | ``` 144 | 145 | A classifier (for manipulation) can be trained using: 146 | ``` 147 | python run_ffhq256_cls.py 148 | ``` 149 | 150 | **Bedroom128** 151 | 152 | ``` 153 | # diffae 154 | python run_bedroom128.py 155 | # ddim 156 | python run_bedroom128_ddim.py 157 | ``` 158 | 159 | **Horse128** 160 | 161 | ``` 162 | # diffae 163 | python run_horse128.py 164 | # ddim 165 | python run_horse128_ddim.py 166 | ``` 167 | 168 | **Celeba64** 169 | 170 | This experiment can be run on 2080Ti's. 171 | 172 | ``` 173 | # diffae 174 | python run_celeba64.py 175 | ``` -------------------------------------------------------------------------------- /align.py: -------------------------------------------------------------------------------- 1 | import bz2 2 | import os 3 | import os.path as osp 4 | import sys 5 | from multiprocessing import Pool 6 | 7 | import dlib 8 | import numpy as np 9 | import PIL.Image 10 | import requests 11 | import scipy.ndimage 12 | from tqdm import tqdm 13 | from argparse import ArgumentParser 14 | 15 | LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2' 16 | 17 | 18 | def image_align(src_file, 19 | dst_file, 20 | face_landmarks, 21 | output_size=1024, 22 | transform_size=4096, 23 | enable_padding=True): 24 | # Align function from FFHQ dataset pre-processing step 25 | # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py 26 | 27 | lm = np.array(face_landmarks) 28 | lm_chin = lm[0:17] # left-right 29 | lm_eyebrow_left = lm[17:22] # left-right 30 | lm_eyebrow_right = lm[22:27] # left-right 31 | lm_nose = lm[27:31] # top-down 32 | lm_nostrils = lm[31:36] # top-down 33 | lm_eye_left = lm[36:42] # left-clockwise 34 | lm_eye_right = lm[42:48] # left-clockwise 35 | lm_mouth_outer = lm[48:60] # left-clockwise 36 | lm_mouth_inner = lm[60:68] # left-clockwise 37 | 38 | # Calculate auxiliary vectors. 
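    # The eye centers and mouth corners computed below define two axes,
    # eye_to_eye (roughly horizontal) and eye_to_mouth (roughly vertical).
    # From these, an oriented square `quad` is built around a center placed
    # slightly below the eye midpoint, with side length `qsize`; the later
    # steps shrink, crop, pad, and finally warp this quad to the output square.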
39 | eye_left = np.mean(lm_eye_left, axis=0) 40 | eye_right = np.mean(lm_eye_right, axis=0) 41 | eye_avg = (eye_left + eye_right) * 0.5 42 | eye_to_eye = eye_right - eye_left 43 | mouth_left = lm_mouth_outer[0] 44 | mouth_right = lm_mouth_outer[6] 45 | mouth_avg = (mouth_left + mouth_right) * 0.5 46 | eye_to_mouth = mouth_avg - eye_avg 47 | 48 | # Choose oriented crop rectangle. 49 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1] 50 | x /= np.hypot(*x) 51 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8) 52 | y = np.flipud(x) * [-1, 1] 53 | c = eye_avg + eye_to_mouth * 0.1 54 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y]) 55 | qsize = np.hypot(*x) * 2 56 | 57 | # Load in-the-wild image. 58 | if not os.path.isfile(src_file): 59 | print( 60 | '\nCannot find source image. Please run "--wilds" before "--align".' 61 | ) 62 | return 63 | img = PIL.Image.open(src_file) 64 | img = img.convert('RGB') 65 | 66 | # Shrink. 67 | shrink = int(np.floor(qsize / output_size * 0.5)) 68 | if shrink > 1: 69 | rsize = (int(np.rint(float(img.size[0]) / shrink)), 70 | int(np.rint(float(img.size[1]) / shrink))) 71 | img = img.resize(rsize, PIL.Image.ANTIALIAS) 72 | quad /= shrink 73 | qsize /= shrink 74 | 75 | # Crop. 76 | border = max(int(np.rint(qsize * 0.1)), 3) 77 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), 78 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) 79 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), 80 | min(crop[2] + border, 81 | img.size[0]), min(crop[3] + border, img.size[1])) 82 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]: 83 | img = img.crop(crop) 84 | quad -= crop[0:2] 85 | 86 | # Pad. 87 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), 88 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1])))) 89 | pad = (max(-pad[0] + border, 90 | 0), max(-pad[1] + border, 91 | 0), max(pad[2] - img.size[0] + border, 92 | 0), max(pad[3] - img.size[1] + border, 0)) 93 | if enable_padding and max(pad) > border - 4: 94 | pad = np.maximum(pad, int(np.rint(qsize * 0.3))) 95 | img = np.pad(np.float32(img), 96 | ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect') 97 | h, w, _ = img.shape 98 | y, x, _ = np.ogrid[:h, :w, :1] 99 | mask = np.maximum( 100 | 1.0 - 101 | np.minimum(np.float32(x) / pad[0], 102 | np.float32(w - 1 - x) / pad[2]), 1.0 - 103 | np.minimum(np.float32(y) / pad[1], 104 | np.float32(h - 1 - y) / pad[3])) 105 | blur = qsize * 0.02 106 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - 107 | img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0) 108 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0) 109 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 110 | 'RGB') 111 | quad += pad[:2] 112 | 113 | # Transform. 114 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD, 115 | (quad + 0.5).flatten(), PIL.Image.BILINEAR) 116 | if output_size < transform_size: 117 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS) 118 | 119 | # Save aligned image. 
120 | img.save(dst_file, 'PNG') 121 | 122 | 123 | class LandmarksDetector: 124 | def __init__(self, predictor_model_path): 125 | """ 126 | :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file 127 | """ 128 | self.detector = dlib.get_frontal_face_detector( 129 | ) # cnn_face_detection_model_v1 also can be used 130 | self.shape_predictor = dlib.shape_predictor(predictor_model_path) 131 | 132 | def get_landmarks(self, image): 133 | img = dlib.load_rgb_image(image) 134 | dets = self.detector(img, 1) 135 | 136 | for detection in dets: 137 | face_landmarks = [ 138 | (item.x, item.y) 139 | for item in self.shape_predictor(img, detection).parts() 140 | ] 141 | yield face_landmarks 142 | 143 | 144 | def unpack_bz2(src_path): 145 | dst_path = src_path[:-4] 146 | if os.path.exists(dst_path): 147 | print('cached') 148 | return dst_path 149 | data = bz2.BZ2File(src_path).read() 150 | with open(dst_path, 'wb') as fp: 151 | fp.write(data) 152 | return dst_path 153 | 154 | 155 | def work_landmark(raw_img_path, img_name, face_landmarks): 156 | face_img_name = '%s.png' % (os.path.splitext(img_name)[0], ) 157 | aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name) 158 | if os.path.exists(aligned_face_path): 159 | return 160 | image_align(raw_img_path, 161 | aligned_face_path, 162 | face_landmarks, 163 | output_size=256) 164 | 165 | 166 | def get_file(src, tgt): 167 | if os.path.exists(tgt): 168 | print('cached') 169 | return tgt 170 | tgt_dir = os.path.dirname(tgt) 171 | if not os.path.exists(tgt_dir): 172 | os.makedirs(tgt_dir) 173 | file = requests.get(src) 174 | open(tgt, 'wb').write(file.content) 175 | return tgt 176 | 177 | 178 | if __name__ == "__main__": 179 | """ 180 | Extracts and aligns all faces from images using DLib and a function from original FFHQ dataset preparation step 181 | python align_images.py /raw_images /aligned_images 182 | """ 183 | parser = ArgumentParser() 184 | parser.add_argument("-i", 185 | "--input_imgs_path", 186 | type=str, 187 | default="imgs", 188 | help="input images directory path") 189 | parser.add_argument("-o", 190 | "--output_imgs_path", 191 | type=str, 192 | default="imgs_align", 193 | help="output images directory path") 194 | 195 | args = parser.parse_args() 196 | 197 | # takes very long time ... 
198 | landmarks_model_path = unpack_bz2( 199 | get_file( 200 | 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2', 201 | 'temp/shape_predictor_68_face_landmarks.dat.bz2')) 202 | 203 | # RAW_IMAGES_DIR = sys.argv[1] 204 | # ALIGNED_IMAGES_DIR = sys.argv[2] 205 | RAW_IMAGES_DIR = args.input_imgs_path 206 | ALIGNED_IMAGES_DIR = args.output_imgs_path 207 | 208 | if not osp.exists(ALIGNED_IMAGES_DIR): os.makedirs(ALIGNED_IMAGES_DIR) 209 | 210 | files = os.listdir(RAW_IMAGES_DIR) 211 | print(f'total img files {len(files)}') 212 | with tqdm(total=len(files)) as progress: 213 | 214 | def cb(*args): 215 | # print('update') 216 | progress.update() 217 | 218 | def err_cb(e): 219 | print('error:', e) 220 | 221 | with Pool(8) as pool: 222 | res = [] 223 | landmarks_detector = LandmarksDetector(landmarks_model_path) 224 | for img_name in files: 225 | raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name) 226 | # print('img_name:', img_name) 227 | for i, face_landmarks in enumerate( 228 | landmarks_detector.get_landmarks(raw_img_path), 229 | start=1): 230 | # assert i == 1, f'{i}' 231 | # print(i, face_landmarks) 232 | # face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i) 233 | # aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name) 234 | # image_align(raw_img_path, aligned_face_path, face_landmarks, output_size=256) 235 | 236 | work_landmark(raw_img_path, img_name, face_landmarks) 237 | progress.update() 238 | 239 | # job = pool.apply_async( 240 | # work_landmark, 241 | # (raw_img_path, img_name, face_landmarks), 242 | # callback=cb, 243 | # error_callback=err_cb, 244 | # ) 245 | # res.append(job) 246 | 247 | # pool.close() 248 | # pool.join() 249 | print(f"output aligned images at: {ALIGNED_IMAGES_DIR}") 250 | -------------------------------------------------------------------------------- /choices.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from torch import nn 3 | 4 | 5 | class TrainMode(Enum): 6 | # manipulate mode = training the classifier 7 | manipulate = 'manipulate' 8 | # default trainin mode! 9 | diffusion = 'diffusion' 10 | # default latent training mode! 11 | # fitting the a DDPM to a given latent 12 | latent_diffusion = 'latentdiffusion' 13 | 14 | def is_manipulate(self): 15 | return self in [ 16 | TrainMode.manipulate, 17 | ] 18 | 19 | def is_diffusion(self): 20 | return self in [ 21 | TrainMode.diffusion, 22 | TrainMode.latent_diffusion, 23 | ] 24 | 25 | def is_autoenc(self): 26 | # the network possibly does autoencoding 27 | return self in [ 28 | TrainMode.diffusion, 29 | ] 30 | 31 | def is_latent_diffusion(self): 32 | return self in [ 33 | TrainMode.latent_diffusion, 34 | ] 35 | 36 | def use_latent_net(self): 37 | return self.is_latent_diffusion() 38 | 39 | def require_dataset_infer(self): 40 | """ 41 | whether training in this mode requires the latent variables to be available? 
42 | """ 43 | # this will precalculate all the latents before hand 44 | # and the dataset will be all the predicted latents 45 | return self in [ 46 | TrainMode.latent_diffusion, 47 | TrainMode.manipulate, 48 | ] 49 | 50 | 51 | class ManipulateMode(Enum): 52 | """ 53 | how to train the classifier to manipulate 54 | """ 55 | # train on whole celeba attr dataset 56 | celebahq_all = 'celebahq_all' 57 | # celeba with D2C's crop 58 | d2c_fewshot = 'd2cfewshot' 59 | d2c_fewshot_allneg = 'd2cfewshotallneg' 60 | 61 | def is_celeba_attr(self): 62 | return self in [ 63 | ManipulateMode.d2c_fewshot, 64 | ManipulateMode.d2c_fewshot_allneg, 65 | ManipulateMode.celebahq_all, 66 | ] 67 | 68 | def is_single_class(self): 69 | return self in [ 70 | ManipulateMode.d2c_fewshot, 71 | ManipulateMode.d2c_fewshot_allneg, 72 | ] 73 | 74 | def is_fewshot(self): 75 | return self in [ 76 | ManipulateMode.d2c_fewshot, 77 | ManipulateMode.d2c_fewshot_allneg, 78 | ] 79 | 80 | def is_fewshot_allneg(self): 81 | return self in [ 82 | ManipulateMode.d2c_fewshot_allneg, 83 | ] 84 | 85 | 86 | class ModelType(Enum): 87 | """ 88 | Kinds of the backbone models 89 | """ 90 | 91 | # unconditional ddpm 92 | ddpm = 'ddpm' 93 | # autoencoding ddpm cannot do unconditional generation 94 | autoencoder = 'autoencoder' 95 | 96 | def has_autoenc(self): 97 | return self in [ 98 | ModelType.autoencoder, 99 | ] 100 | 101 | def can_sample(self): 102 | return self in [ModelType.ddpm] 103 | 104 | 105 | class ModelName(Enum): 106 | """ 107 | List of all supported model classes 108 | """ 109 | 110 | beatgans_ddpm = 'beatgans_ddpm' 111 | beatgans_autoenc = 'beatgans_autoenc' 112 | 113 | 114 | class ModelMeanType(Enum): 115 | """ 116 | Which type of output the model predicts. 117 | """ 118 | 119 | eps = 'eps' # the model predicts epsilon 120 | 121 | 122 | class ModelVarType(Enum): 123 | """ 124 | What is used as the model's output variance. 125 | 126 | The LEARNED_RANGE option has been added to allow the model to predict 127 | values between FIXED_SMALL and FIXED_LARGE, making its job easier. 
128 | """ 129 | 130 | # posterior beta_t 131 | fixed_small = 'fixed_small' 132 | # beta_t 133 | fixed_large = 'fixed_large' 134 | 135 | 136 | class LossType(Enum): 137 | mse = 'mse' # use raw MSE loss (and KL when learning variances) 138 | l1 = 'l1' 139 | 140 | 141 | class GenerativeType(Enum): 142 | """ 143 | How's a sample generated 144 | """ 145 | 146 | ddpm = 'ddpm' 147 | ddim = 'ddim' 148 | 149 | 150 | class OptimizerType(Enum): 151 | adam = 'adam' 152 | adamw = 'adamw' 153 | 154 | 155 | class Activation(Enum): 156 | none = 'none' 157 | relu = 'relu' 158 | lrelu = 'lrelu' 159 | silu = 'silu' 160 | tanh = 'tanh' 161 | 162 | def get_act(self): 163 | if self == Activation.none: 164 | return nn.Identity() 165 | elif self == Activation.relu: 166 | return nn.ReLU() 167 | elif self == Activation.lrelu: 168 | return nn.LeakyReLU(negative_slope=0.2) 169 | elif self == Activation.silu: 170 | return nn.SiLU() 171 | elif self == Activation.tanh: 172 | return nn.Tanh() 173 | else: 174 | raise NotImplementedError() 175 | 176 | 177 | class ManipulateLossType(Enum): 178 | bce = 'bce' 179 | mse = 'mse' -------------------------------------------------------------------------------- /cog.yaml: -------------------------------------------------------------------------------- 1 | build: 2 | cuda: "10.2" 3 | gpu: true 4 | python_version: "3.8" 5 | system_packages: 6 | - "libgl1-mesa-glx" 7 | - "libglib2.0-0" 8 | python_packages: 9 | - "numpy==1.21.5" 10 | - "cmake==3.23.3" 11 | - "ipython==7.21.0" 12 | - "opencv-python==4.5.4.58" 13 | - "pandas==1.1.5" 14 | - "lmdb==1.2.1" 15 | - "lpips==0.1.4" 16 | - "pytorch-fid==0.2.0" 17 | - "ftfy==6.1.1" 18 | - "scipy==1.5.4" 19 | - "torch==1.9.1" 20 | - "torchvision==0.10.1" 21 | - "tqdm==4.62.3" 22 | - "regex==2022.7.25" 23 | - "Pillow==9.2.0" 24 | - "pytorch_lightning==1.7.0" 25 | 26 | run: 27 | - pip install dlib 28 | 29 | predict: "predict.py:Predictor" 30 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | from model.unet import ScaleAt 2 | from model.latentnet import * 3 | from diffusion.resample import UniformSampler 4 | from diffusion.diffusion import space_timesteps 5 | from typing import Tuple 6 | 7 | from torch.utils.data import DataLoader 8 | 9 | from config_base import BaseConfig 10 | from dataset import * 11 | from diffusion import * 12 | from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule 13 | from model import * 14 | from choices import * 15 | from multiprocessing import get_context 16 | import os 17 | from dataset_util import * 18 | from torch.utils.data.distributed import DistributedSampler 19 | 20 | data_paths = { 21 | 'ffhqlmdb256': 22 | os.path.expanduser('datasets/ffhq256.lmdb'), 23 | # used for training a classifier 24 | 'celeba': 25 | os.path.expanduser('datasets/celeba'), 26 | # used for training DPM models 27 | 'celebalmdb': 28 | os.path.expanduser('datasets/celeba.lmdb'), 29 | 'celebahq': 30 | os.path.expanduser('datasets/celebahq256.lmdb'), 31 | 'horse256': 32 | os.path.expanduser('datasets/horse256.lmdb'), 33 | 'bedroom256': 34 | os.path.expanduser('datasets/bedroom256.lmdb'), 35 | 'celeba_anno': 36 | os.path.expanduser('datasets/celeba_anno/list_attr_celeba.txt'), 37 | 'celebahq_anno': 38 | os.path.expanduser( 39 | 'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'), 40 | 'celeba_relight': 41 | 
os.path.expanduser('datasets/celeba_hq_light/celeba_light.txt'), 42 | } 43 | 44 | 45 | @dataclass 46 | class PretrainConfig(BaseConfig): 47 | name: str 48 | path: str 49 | 50 | 51 | @dataclass 52 | class TrainConfig(BaseConfig): 53 | # random seed 54 | seed: int = 0 55 | train_mode: TrainMode = TrainMode.diffusion 56 | train_cond0_prob: float = 0 57 | train_pred_xstart_detach: bool = True 58 | train_interpolate_prob: float = 0 59 | train_interpolate_img: bool = False 60 | manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all 61 | manipulate_cls: str = None 62 | manipulate_shots: int = None 63 | manipulate_loss: ManipulateLossType = ManipulateLossType.bce 64 | manipulate_znormalize: bool = False 65 | manipulate_seed: int = 0 66 | accum_batches: int = 1 67 | autoenc_mid_attn: bool = True 68 | batch_size: int = 16 69 | batch_size_eval: int = None 70 | beatgans_gen_type: GenerativeType = GenerativeType.ddim 71 | beatgans_loss_type: LossType = LossType.mse 72 | beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps 73 | beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large 74 | beatgans_rescale_timesteps: bool = False 75 | latent_infer_path: str = None 76 | latent_znormalize: bool = False 77 | latent_gen_type: GenerativeType = GenerativeType.ddim 78 | latent_loss_type: LossType = LossType.mse 79 | latent_model_mean_type: ModelMeanType = ModelMeanType.eps 80 | latent_model_var_type: ModelVarType = ModelVarType.fixed_large 81 | latent_rescale_timesteps: bool = False 82 | latent_T_eval: int = 1_000 83 | latent_clip_sample: bool = False 84 | latent_beta_scheduler: str = 'linear' 85 | beta_scheduler: str = 'linear' 86 | data_name: str = '' 87 | data_val_name: str = None 88 | diffusion_type: str = None 89 | dropout: float = 0.1 90 | ema_decay: float = 0.9999 91 | eval_num_images: int = 5_000 92 | eval_every_samples: int = 200_000 93 | eval_ema_every_samples: int = 200_000 94 | fid_use_torch: bool = True 95 | fp16: bool = False 96 | grad_clip: float = 1 97 | img_size: int = 64 98 | lr: float = 0.0001 99 | optimizer: OptimizerType = OptimizerType.adam 100 | weight_decay: float = 0 101 | model_conf: ModelConfig = None 102 | model_name: ModelName = None 103 | model_type: ModelType = None 104 | net_attn: Tuple[int] = None 105 | net_beatgans_attn_head: int = 1 106 | # not necessarily the same as the the number of style channels 107 | net_beatgans_embed_channels: int = 512 108 | net_resblock_updown: bool = True 109 | net_enc_use_time: bool = False 110 | net_enc_pool: str = 'adaptivenonzero' 111 | net_beatgans_gradient_checkpoint: bool = False 112 | net_beatgans_resnet_two_cond: bool = False 113 | net_beatgans_resnet_use_zero_module: bool = True 114 | net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm 115 | net_beatgans_resnet_cond_channels: int = None 116 | net_ch_mult: Tuple[int] = None 117 | net_ch: int = 64 118 | net_enc_attn: Tuple[int] = None 119 | net_enc_k: int = None 120 | # number of resblocks for the encoder (half-unet) 121 | net_enc_num_res_blocks: int = 2 122 | net_enc_channel_mult: Tuple[int] = None 123 | net_enc_grad_checkpoint: bool = False 124 | net_autoenc_stochastic: bool = False 125 | net_latent_activation: Activation = Activation.silu 126 | net_latent_channel_mult: Tuple[int] = (1, 2, 4) 127 | net_latent_condition_bias: float = 0 128 | net_latent_dropout: float = 0 129 | net_latent_layers: int = None 130 | net_latent_net_last_act: Activation = Activation.none 131 | net_latent_net_type: LatentNetType = LatentNetType.none 132 | net_latent_num_hid_channels: 
int = 1024 133 | net_latent_num_time_layers: int = 2 134 | net_latent_skip_layers: Tuple[int] = None 135 | net_latent_time_emb_channels: int = 64 136 | net_latent_use_norm: bool = False 137 | net_latent_time_last_act: bool = False 138 | net_num_res_blocks: int = 2 139 | # number of resblocks for the UNET 140 | net_num_input_res_blocks: int = None 141 | net_enc_num_cls: int = None 142 | num_workers: int = 4 143 | parallel: bool = False 144 | postfix: str = '' 145 | sample_size: int = 64 146 | sample_every_samples: int = 20_000 147 | save_every_samples: int = 100_000 148 | style_ch: int = 512 149 | T_eval: int = 1_000 150 | T_sampler: str = 'uniform' 151 | T: int = 1_000 152 | total_samples: int = 10_000_000 153 | warmup: int = 0 154 | pretrain: PretrainConfig = None 155 | continue_from: PretrainConfig = None 156 | eval_programs: Tuple[str] = None 157 | # if present load the checkpoint from this path instead 158 | eval_path: str = None 159 | base_dir: str = 'checkpoints' 160 | use_cache_dataset: bool = False 161 | data_cache_dir: str = os.path.expanduser('~/cache') 162 | work_cache_dir: str = os.path.expanduser('~/mycache') 163 | # to be overridden 164 | name: str = '' 165 | 166 | def __post_init__(self): 167 | self.batch_size_eval = self.batch_size_eval or self.batch_size 168 | self.data_val_name = self.data_val_name or self.data_name 169 | 170 | def scale_up_gpus(self, num_gpus, num_nodes=1): 171 | self.eval_ema_every_samples *= num_gpus * num_nodes 172 | self.eval_every_samples *= num_gpus * num_nodes 173 | self.sample_every_samples *= num_gpus * num_nodes 174 | self.batch_size *= num_gpus * num_nodes 175 | self.batch_size_eval *= num_gpus * num_nodes 176 | return self 177 | 178 | @property 179 | def batch_size_effective(self): 180 | return self.batch_size * self.accum_batches 181 | 182 | @property 183 | def fid_cache(self): 184 | # we try to use the local dirs to reduce the load over network drives 185 | # hopefully, this would reduce the disconnection problems with sshfs 186 | return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}' 187 | 188 | @property 189 | def data_path(self): 190 | # may use the cache dir 191 | path = data_paths[self.data_name] 192 | if self.use_cache_dataset and path is not None: 193 | path = use_cached_dataset_path( 194 | path, f'{self.data_cache_dir}/{self.data_name}') 195 | return path 196 | 197 | @property 198 | def logdir(self): 199 | return f'{self.base_dir}/{self.name}' 200 | 201 | @property 202 | def generate_dir(self): 203 | # we try to use the local dirs to reduce the load over network drives 204 | # hopefully, this would reduce the disconnection problems with sshfs 205 | return f'{self.work_cache_dir}/gen_images/{self.name}' 206 | 207 | def _make_diffusion_conf(self, T=None): 208 | if self.diffusion_type == 'beatgans': 209 | # can use T < self.T for evaluation 210 | # follows the guided-diffusion repo conventions 211 | # t's are evenly spaced 212 | if self.beatgans_gen_type == GenerativeType.ddpm: 213 | section_counts = [T] 214 | elif self.beatgans_gen_type == GenerativeType.ddim: 215 | section_counts = f'ddim{T}' 216 | else: 217 | raise NotImplementedError() 218 | 219 | return SpacedDiffusionBeatGansConfig( 220 | gen_type=self.beatgans_gen_type, 221 | model_type=self.model_type, 222 | betas=get_named_beta_schedule(self.beta_scheduler, self.T), 223 | model_mean_type=self.beatgans_model_mean_type, 224 | model_var_type=self.beatgans_model_var_type, 225 | loss_type=self.beatgans_loss_type, 226 | 
rescale_timesteps=self.beatgans_rescale_timesteps, 227 | use_timesteps=space_timesteps(num_timesteps=self.T, 228 | section_counts=section_counts), 229 | fp16=self.fp16, 230 | ) 231 | else: 232 | raise NotImplementedError() 233 | 234 | def _make_latent_diffusion_conf(self, T=None): 235 | # can use T < self.T for evaluation 236 | # follows the guided-diffusion repo conventions 237 | # t's are evenly spaced 238 | if self.latent_gen_type == GenerativeType.ddpm: 239 | section_counts = [T] 240 | elif self.latent_gen_type == GenerativeType.ddim: 241 | section_counts = f'ddim{T}' 242 | else: 243 | raise NotImplementedError() 244 | 245 | return SpacedDiffusionBeatGansConfig( 246 | train_pred_xstart_detach=self.train_pred_xstart_detach, 247 | gen_type=self.latent_gen_type, 248 | # latent's model is always ddpm 249 | model_type=ModelType.ddpm, 250 | # latent shares the beta scheduler and full T 251 | betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T), 252 | model_mean_type=self.latent_model_mean_type, 253 | model_var_type=self.latent_model_var_type, 254 | loss_type=self.latent_loss_type, 255 | rescale_timesteps=self.latent_rescale_timesteps, 256 | use_timesteps=space_timesteps(num_timesteps=self.T, 257 | section_counts=section_counts), 258 | fp16=self.fp16, 259 | ) 260 | 261 | @property 262 | def model_out_channels(self): 263 | return 3 264 | 265 | def make_T_sampler(self): 266 | if self.T_sampler == 'uniform': 267 | return UniformSampler(self.T) 268 | else: 269 | raise NotImplementedError() 270 | 271 | def make_diffusion_conf(self): 272 | return self._make_diffusion_conf(self.T) 273 | 274 | def make_eval_diffusion_conf(self): 275 | return self._make_diffusion_conf(T=self.T_eval) 276 | 277 | def make_latent_diffusion_conf(self): 278 | return self._make_latent_diffusion_conf(T=self.T) 279 | 280 | def make_latent_eval_diffusion_conf(self): 281 | # latent can have different eval T 282 | return self._make_latent_diffusion_conf(T=self.latent_T_eval) 283 | 284 | def make_dataset(self, path=None, **kwargs): 285 | if self.data_name == 'ffhqlmdb256': 286 | return FFHQlmdb(path=path or self.data_path, 287 | image_size=self.img_size, 288 | **kwargs) 289 | elif self.data_name == 'horse256': 290 | return Horse_lmdb(path=path or self.data_path, 291 | image_size=self.img_size, 292 | **kwargs) 293 | elif self.data_name == 'bedroom256': 294 | return Horse_lmdb(path=path or self.data_path, 295 | image_size=self.img_size, 296 | **kwargs) 297 | elif self.data_name == 'celebalmdb': 298 | # always use d2c crop 299 | return CelebAlmdb(path=path or self.data_path, 300 | image_size=self.img_size, 301 | original_resolution=None, 302 | crop_d2c=True, 303 | **kwargs) 304 | else: 305 | raise NotImplementedError() 306 | 307 | def make_loader(self, 308 | dataset, 309 | shuffle: bool, 310 | num_worker: bool = None, 311 | drop_last: bool = True, 312 | batch_size: int = None, 313 | parallel: bool = False): 314 | if parallel and distributed.is_initialized(): 315 | # drop last to make sure that there is no added special indexes 316 | sampler = DistributedSampler(dataset, 317 | shuffle=shuffle, 318 | drop_last=True) 319 | else: 320 | sampler = None 321 | return DataLoader( 322 | dataset, 323 | batch_size=batch_size or self.batch_size, 324 | sampler=sampler, 325 | # with sampler, use the sample instead of this option 326 | shuffle=False if sampler else shuffle, 327 | num_workers=num_worker or self.num_workers, 328 | pin_memory=True, 329 | drop_last=drop_last, 330 | multiprocessing_context=get_context('fork'), 331 | ) 332 | 
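    # A minimal usage sketch of the factory methods above, assuming a fully
    # populated TrainConfig instance `conf` (e.g. from a preset in
    # templates.py); `make_model()` on the returned model config is an
    # assumed helper, not defined in this file:
    #
    #     train_set = conf.make_dataset()
    #     loader = conf.make_loader(train_set, shuffle=True)
    #     model = conf.make_model_conf().make_model()
    #     eval_diffusion_conf = conf.make_eval_diffusion_conf()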
333 | def make_model_conf(self): 334 | if self.model_name == ModelName.beatgans_ddpm: 335 | self.model_type = ModelType.ddpm 336 | self.model_conf = BeatGANsUNetConfig( 337 | attention_resolutions=self.net_attn, 338 | channel_mult=self.net_ch_mult, 339 | conv_resample=True, 340 | dims=2, 341 | dropout=self.dropout, 342 | embed_channels=self.net_beatgans_embed_channels, 343 | image_size=self.img_size, 344 | in_channels=3, 345 | model_channels=self.net_ch, 346 | num_classes=None, 347 | num_head_channels=-1, 348 | num_heads_upsample=-1, 349 | num_heads=self.net_beatgans_attn_head, 350 | num_res_blocks=self.net_num_res_blocks, 351 | num_input_res_blocks=self.net_num_input_res_blocks, 352 | out_channels=self.model_out_channels, 353 | resblock_updown=self.net_resblock_updown, 354 | use_checkpoint=self.net_beatgans_gradient_checkpoint, 355 | use_new_attention_order=False, 356 | resnet_two_cond=self.net_beatgans_resnet_two_cond, 357 | resnet_use_zero_module=self. 358 | net_beatgans_resnet_use_zero_module, 359 | ) 360 | elif self.model_name in [ 361 | ModelName.beatgans_autoenc, 362 | ]: 363 | cls = BeatGANsAutoencConfig 364 | # supports both autoenc and vaeddpm 365 | if self.model_name == ModelName.beatgans_autoenc: 366 | self.model_type = ModelType.autoencoder 367 | else: 368 | raise NotImplementedError() 369 | 370 | if self.net_latent_net_type == LatentNetType.none: 371 | latent_net_conf = None 372 | elif self.net_latent_net_type == LatentNetType.skip: 373 | latent_net_conf = MLPSkipNetConfig( 374 | num_channels=self.style_ch, 375 | skip_layers=self.net_latent_skip_layers, 376 | num_hid_channels=self.net_latent_num_hid_channels, 377 | num_layers=self.net_latent_layers, 378 | num_time_emb_channels=self.net_latent_time_emb_channels, 379 | activation=self.net_latent_activation, 380 | use_norm=self.net_latent_use_norm, 381 | condition_bias=self.net_latent_condition_bias, 382 | dropout=self.net_latent_dropout, 383 | last_act=self.net_latent_net_last_act, 384 | num_time_layers=self.net_latent_num_time_layers, 385 | time_last_act=self.net_latent_time_last_act, 386 | ) 387 | else: 388 | raise NotImplementedError() 389 | 390 | self.model_conf = cls( 391 | attention_resolutions=self.net_attn, 392 | channel_mult=self.net_ch_mult, 393 | conv_resample=True, 394 | dims=2, 395 | dropout=self.dropout, 396 | embed_channels=self.net_beatgans_embed_channels, 397 | enc_out_channels=self.style_ch, 398 | enc_pool=self.net_enc_pool, 399 | enc_num_res_block=self.net_enc_num_res_blocks, 400 | enc_channel_mult=self.net_enc_channel_mult, 401 | enc_grad_checkpoint=self.net_enc_grad_checkpoint, 402 | enc_attn_resolutions=self.net_enc_attn, 403 | image_size=self.img_size, 404 | in_channels=3, 405 | model_channels=self.net_ch, 406 | num_classes=None, 407 | num_head_channels=-1, 408 | num_heads_upsample=-1, 409 | num_heads=self.net_beatgans_attn_head, 410 | num_res_blocks=self.net_num_res_blocks, 411 | num_input_res_blocks=self.net_num_input_res_blocks, 412 | out_channels=self.model_out_channels, 413 | resblock_updown=self.net_resblock_updown, 414 | use_checkpoint=self.net_beatgans_gradient_checkpoint, 415 | use_new_attention_order=False, 416 | resnet_two_cond=self.net_beatgans_resnet_two_cond, 417 | resnet_use_zero_module=self. 
418 | net_beatgans_resnet_use_zero_module, 419 | latent_net_conf=latent_net_conf, 420 | resnet_cond_channels=self.net_beatgans_resnet_cond_channels, 421 | ) 422 | else: 423 | raise NotImplementedError(self.model_name) 424 | 425 | return self.model_conf 426 | -------------------------------------------------------------------------------- /config_base.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from copy import deepcopy 4 | from dataclasses import dataclass 5 | 6 | 7 | @dataclass 8 | class BaseConfig: 9 | def clone(self): 10 | return deepcopy(self) 11 | 12 | def inherit(self, another): 13 | """inherit common keys from a given config""" 14 | common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys()) 15 | for k in common_keys: 16 | setattr(self, k, getattr(another, k)) 17 | 18 | def propagate(self): 19 | """push down the configuration to all members""" 20 | for k, v in self.__dict__.items(): 21 | if isinstance(v, BaseConfig): 22 | v.inherit(self) 23 | v.propagate() 24 | 25 | def save(self, save_path): 26 | """save config to json file""" 27 | dirname = os.path.dirname(save_path) 28 | if not os.path.exists(dirname): 29 | os.makedirs(dirname) 30 | conf = self.as_dict_jsonable() 31 | with open(save_path, 'w') as f: 32 | json.dump(conf, f) 33 | 34 | def load(self, load_path): 35 | """load json config""" 36 | with open(load_path) as f: 37 | conf = json.load(f) 38 | self.from_dict(conf) 39 | 40 | def from_dict(self, dict, strict=False): 41 | for k, v in dict.items(): 42 | if not hasattr(self, k): 43 | if strict: 44 | raise ValueError(f"loading extra '{k}'") 45 | else: 46 | print(f"loading extra '{k}'") 47 | continue 48 | if isinstance(self.__dict__[k], BaseConfig): 49 | self.__dict__[k].from_dict(v) 50 | else: 51 | self.__dict__[k] = v 52 | 53 | def as_dict_jsonable(self): 54 | conf = {} 55 | for k, v in self.__dict__.items(): 56 | if isinstance(v, BaseConfig): 57 | conf[k] = v.as_dict_jsonable() 58 | else: 59 | if jsonable(v): 60 | conf[k] = v 61 | else: 62 | # ignore not jsonable 63 | pass 64 | return conf 65 | 66 | 67 | def jsonable(x): 68 | try: 69 | json.dumps(x) 70 | return True 71 | except TypeError: 72 | return False 73 | -------------------------------------------------------------------------------- /data_resize_bedroom.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | import os 4 | from os.path import join, exists 5 | from functools import partial 6 | from io import BytesIO 7 | import shutil 8 | 9 | import lmdb 10 | from PIL import Image 11 | from torchvision.datasets import LSUNClass 12 | from torchvision.transforms import functional as trans_fn 13 | from tqdm import tqdm 14 | 15 | from multiprocessing import Process, Queue 16 | 17 | 18 | def resize_and_convert(img, size, resample, quality=100): 19 | img = trans_fn.resize(img, size, resample) 20 | img = trans_fn.center_crop(img, size) 21 | buffer = BytesIO() 22 | img.save(buffer, format="webp", quality=quality) 23 | val = buffer.getvalue() 24 | 25 | return val 26 | 27 | 28 | def resize_multiple(img, 29 | sizes=(128, 256, 512, 1024), 30 | resample=Image.LANCZOS, 31 | quality=100): 32 | imgs = [] 33 | 34 | for size in sizes: 35 | imgs.append(resize_and_convert(img, size, resample, quality)) 36 | 37 | return imgs 38 | 39 | 40 | def resize_worker(idx, img, sizes, resample): 41 | img = img.convert("RGB") 42 | out = resize_multiple(img, sizes=sizes, resample=resample) 43 | 
return idx, out 44 | 45 | 46 | from torch.utils.data import Dataset, DataLoader 47 | 48 | 49 | class ConvertDataset(Dataset): 50 | def __init__(self, data) -> None: 51 | self.data = data 52 | 53 | def __len__(self): 54 | return len(self.data) 55 | 56 | def __getitem__(self, index): 57 | img, _ = self.data[index] 58 | bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90) 59 | return bytes 60 | 61 | 62 | if __name__ == "__main__": 63 | """ 64 | converting lsun' original lmdb to our lmdb, which is somehow more performant. 65 | """ 66 | from tqdm import tqdm 67 | 68 | # path to the original lsun's lmdb 69 | src_path = 'datasets/bedroom_train_lmdb' 70 | out_path = 'datasets/bedroom256.lmdb' 71 | 72 | dataset = LSUNClass(root=os.path.expanduser(src_path)) 73 | dataset = ConvertDataset(dataset) 74 | loader = DataLoader(dataset, 75 | batch_size=50, 76 | num_workers=12, 77 | collate_fn=lambda x: x, 78 | shuffle=False) 79 | 80 | target = os.path.expanduser(out_path) 81 | if os.path.exists(target): 82 | shutil.rmtree(target) 83 | 84 | with lmdb.open(target, map_size=1024**4, readahead=False) as env: 85 | with tqdm(total=len(dataset)) as progress: 86 | i = 0 87 | for batch in loader: 88 | with env.begin(write=True) as txn: 89 | for img in batch: 90 | key = f"{256}-{str(i).zfill(7)}".encode("utf-8") 91 | # print(key) 92 | txn.put(key, img) 93 | i += 1 94 | progress.update() 95 | # if i == 1000: 96 | # break 97 | # if total == len(imgset): 98 | # break 99 | 100 | with env.begin(write=True) as txn: 101 | txn.put("length".encode("utf-8"), str(i).encode("utf-8")) 102 | -------------------------------------------------------------------------------- /data_resize_celeba.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | import os 4 | import shutil 5 | from functools import partial 6 | from io import BytesIO 7 | from multiprocessing import Process, Queue 8 | from os.path import exists, join 9 | from pathlib import Path 10 | 11 | import lmdb 12 | from PIL import Image 13 | from torch.utils.data import DataLoader, Dataset 14 | from torchvision.datasets import LSUNClass 15 | from torchvision.transforms import functional as trans_fn 16 | from tqdm import tqdm 17 | 18 | 19 | def resize_and_convert(img, size, resample, quality=100): 20 | if size is not None: 21 | img = trans_fn.resize(img, size, resample) 22 | img = trans_fn.center_crop(img, size) 23 | 24 | buffer = BytesIO() 25 | img.save(buffer, format="webp", quality=quality) 26 | val = buffer.getvalue() 27 | 28 | return val 29 | 30 | 31 | def resize_multiple(img, 32 | sizes=(128, 256, 512, 1024), 33 | resample=Image.LANCZOS, 34 | quality=100): 35 | imgs = [] 36 | 37 | for size in sizes: 38 | imgs.append(resize_and_convert(img, size, resample, quality)) 39 | 40 | return imgs 41 | 42 | 43 | def resize_worker(idx, img, sizes, resample): 44 | img = img.convert("RGB") 45 | out = resize_multiple(img, sizes=sizes, resample=resample) 46 | return idx, out 47 | 48 | 49 | class ConvertDataset(Dataset): 50 | def __init__(self, data, size) -> None: 51 | self.data = data 52 | self.size = size 53 | 54 | def __len__(self): 55 | return len(self.data) 56 | 57 | def __getitem__(self, index): 58 | img = self.data[index] 59 | bytes = resize_and_convert(img, self.size, Image.LANCZOS, quality=100) 60 | return bytes 61 | 62 | 63 | class ImageFolder(Dataset): 64 | def __init__(self, folder, ext='jpg'): 65 | super().__init__() 66 | paths = sorted([p for p in 
Path(f'{folder}').glob(f'*.{ext}')]) 67 | self.paths = paths 68 | 69 | def __len__(self): 70 | return len(self.paths) 71 | 72 | def __getitem__(self, index): 73 | path = os.path.join(self.paths[index]) 74 | img = Image.open(path) 75 | return img 76 | 77 | 78 | if __name__ == "__main__": 79 | from tqdm import tqdm 80 | 81 | out_path = 'datasets/celeba.lmdb' 82 | in_path = 'datasets/celeba' 83 | ext = 'jpg' 84 | size = None 85 | 86 | dataset = ImageFolder(in_path, ext) 87 | print('len:', len(dataset)) 88 | dataset = ConvertDataset(dataset, size) 89 | loader = DataLoader(dataset, 90 | batch_size=50, 91 | num_workers=12, 92 | collate_fn=lambda x: x, 93 | shuffle=False) 94 | 95 | target = os.path.expanduser(out_path) 96 | if os.path.exists(target): 97 | shutil.rmtree(target) 98 | 99 | with lmdb.open(target, map_size=1024**4, readahead=False) as env: 100 | with tqdm(total=len(dataset)) as progress: 101 | i = 0 102 | for batch in loader: 103 | with env.begin(write=True) as txn: 104 | for img in batch: 105 | key = f"{size}-{str(i).zfill(7)}".encode("utf-8") 106 | # print(key) 107 | txn.put(key, img) 108 | i += 1 109 | progress.update() 110 | # if i == 1000: 111 | # break 112 | # if total == len(imgset): 113 | # break 114 | 115 | with env.begin(write=True) as txn: 116 | txn.put("length".encode("utf-8"), str(i).encode("utf-8")) 117 | -------------------------------------------------------------------------------- /data_resize_celebahq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | from functools import partial 4 | from io import BytesIO 5 | from pathlib import Path 6 | 7 | import lmdb 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | from torchvision.transforms import functional as trans_fn 11 | from tqdm import tqdm 12 | import os 13 | 14 | 15 | def resize_and_convert(img, size, resample, quality=100): 16 | img = trans_fn.resize(img, size, resample) 17 | img = trans_fn.center_crop(img, size) 18 | buffer = BytesIO() 19 | img.save(buffer, format="jpeg", quality=quality) 20 | val = buffer.getvalue() 21 | 22 | return val 23 | 24 | 25 | def resize_multiple(img, 26 | sizes=(128, 256, 512, 1024), 27 | resample=Image.LANCZOS, 28 | quality=100): 29 | imgs = [] 30 | 31 | for size in sizes: 32 | imgs.append(resize_and_convert(img, size, resample, quality)) 33 | 34 | return imgs 35 | 36 | 37 | def resize_worker(img_file, sizes, resample): 38 | i, (file, idx) = img_file 39 | img = Image.open(file) 40 | img = img.convert("RGB") 41 | out = resize_multiple(img, sizes=sizes, resample=resample) 42 | 43 | return i, idx, out 44 | 45 | 46 | def prepare(env, 47 | paths, 48 | n_worker, 49 | sizes=(128, 256, 512, 1024), 50 | resample=Image.LANCZOS): 51 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample) 52 | 53 | # index = filename in int 54 | indexs = [] 55 | for each in paths: 56 | file = os.path.basename(each) 57 | name, ext = file.split('.') 58 | idx = int(name) 59 | indexs.append(idx) 60 | 61 | # sort by file index 62 | files = sorted(zip(paths, indexs), key=lambda x: x[1]) 63 | files = list(enumerate(files)) 64 | total = 0 65 | 66 | with multiprocessing.Pool(n_worker) as pool: 67 | for i, idx, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 68 | for size, img in zip(sizes, imgs): 69 | key = f"{size}-{str(idx).zfill(5)}".encode("utf-8") 70 | 71 | with env.begin(write=True) as txn: 72 | txn.put(key, img) 73 | 74 | total += 1 75 | 76 | with env.begin(write=True) as txn: 77 | 
txn.put("length".encode("utf-8"), str(total).encode("utf-8")) 78 | 79 | 80 | class ImageFolder(Dataset): 81 | def __init__(self, folder, exts=['jpg']): 82 | super().__init__() 83 | self.paths = [ 84 | p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}') 85 | ] 86 | 87 | def __len__(self): 88 | return len(self.paths) 89 | 90 | def __getitem__(self, index): 91 | path = os.path.join(self.folder, self.paths[index]) 92 | img = Image.open(path) 93 | return img 94 | 95 | 96 | if __name__ == "__main__": 97 | """ 98 | converting celebahq images to lmdb 99 | """ 100 | num_workers = 16 101 | in_path = 'datasets/celebahq' 102 | out_path = 'datasets/celebahq256.lmdb' 103 | 104 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} 105 | resample = resample_map['lanczos'] 106 | 107 | sizes = [256] 108 | 109 | print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) 110 | 111 | # imgset = datasets.ImageFolder(in_path) 112 | # imgset = ImageFolder(in_path) 113 | exts = ['jpg'] 114 | paths = [p for ext in exts for p in Path(f'{in_path}').glob(f'**/*.{ext}')] 115 | 116 | with lmdb.open(out_path, map_size=1024**4, readahead=False) as env: 117 | prepare(env, paths, num_workers, sizes=sizes, resample=resample) 118 | -------------------------------------------------------------------------------- /data_resize_ffhq.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | from functools import partial 4 | from io import BytesIO 5 | from pathlib import Path 6 | 7 | import lmdb 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | from torchvision.transforms import functional as trans_fn 11 | from tqdm import tqdm 12 | import os 13 | 14 | 15 | def resize_and_convert(img, size, resample, quality=100): 16 | img = trans_fn.resize(img, size, resample) 17 | img = trans_fn.center_crop(img, size) 18 | buffer = BytesIO() 19 | img.save(buffer, format="jpeg", quality=quality) 20 | val = buffer.getvalue() 21 | 22 | return val 23 | 24 | 25 | def resize_multiple(img, 26 | sizes=(128, 256, 512, 1024), 27 | resample=Image.LANCZOS, 28 | quality=100): 29 | imgs = [] 30 | 31 | for size in sizes: 32 | imgs.append(resize_and_convert(img, size, resample, quality)) 33 | 34 | return imgs 35 | 36 | 37 | def resize_worker(img_file, sizes, resample): 38 | i, (file, idx) = img_file 39 | img = Image.open(file) 40 | img = img.convert("RGB") 41 | out = resize_multiple(img, sizes=sizes, resample=resample) 42 | 43 | return i, idx, out 44 | 45 | 46 | def prepare(env, 47 | paths, 48 | n_worker, 49 | sizes=(128, 256, 512, 1024), 50 | resample=Image.LANCZOS): 51 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample) 52 | 53 | # index = filename in int 54 | indexs = [] 55 | for each in paths: 56 | file = os.path.basename(each) 57 | name, ext = file.split('.') 58 | idx = int(name) 59 | indexs.append(idx) 60 | 61 | # sort by file index 62 | files = sorted(zip(paths, indexs), key=lambda x: x[1]) 63 | files = list(enumerate(files)) 64 | total = 0 65 | 66 | with multiprocessing.Pool(n_worker) as pool: 67 | for i, idx, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 68 | for size, img in zip(sizes, imgs): 69 | key = f"{size}-{str(idx).zfill(5)}".encode("utf-8") 70 | 71 | with env.begin(write=True) as txn: 72 | txn.put(key, img) 73 | 74 | total += 1 75 | 76 | with env.begin(write=True) as txn: 77 | txn.put("length".encode("utf-8"), str(total).encode("utf-8")) 78 | 79 | 80 | class 
ImageFolder(Dataset): 81 | def __init__(self, folder, exts=['jpg']): 82 | super().__init__() 83 | self.paths = [ 84 | p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}') 85 | ] 86 | 87 | def __len__(self): 88 | return len(self.paths) 89 | 90 | def __getitem__(self, index): 91 | path = os.path.join(self.folder, self.paths[index]) 92 | img = Image.open(path) 93 | return img 94 | 95 | 96 | if __name__ == "__main__": 97 | """ 98 | converting ffhq images to lmdb 99 | """ 100 | num_workers = 16 101 | # original ffhq data path 102 | in_path = 'datasets/ffhq' 103 | # target output path 104 | out_path = 'datasets/ffhq.lmdb' 105 | 106 | if not os.path.exists(out_path): 107 | os.makedirs(out_path) 108 | 109 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} 110 | resample = resample_map['lanczos'] 111 | 112 | sizes = [256] 113 | 114 | print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) 115 | 116 | # imgset = datasets.ImageFolder(in_path) 117 | # imgset = ImageFolder(in_path) 118 | exts = ['jpg'] 119 | paths = [p for ext in exts for p in Path(f'{in_path}').glob(f'**/*.{ext}')] 120 | # print(paths[:10]) 121 | 122 | with lmdb.open(out_path, map_size=1024**4, readahead=False) as env: 123 | prepare(env, paths, num_workers, sizes=sizes, resample=resample) 124 | -------------------------------------------------------------------------------- /data_resize_horse.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import multiprocessing 3 | import os 4 | import shutil 5 | from functools import partial 6 | from io import BytesIO 7 | from multiprocessing import Process, Queue 8 | from os.path import exists, join 9 | 10 | import lmdb 11 | from PIL import Image 12 | from torch.utils.data import DataLoader, Dataset 13 | from torchvision.datasets import LSUNClass 14 | from torchvision.transforms import functional as trans_fn 15 | from tqdm import tqdm 16 | 17 | 18 | def resize_and_convert(img, size, resample, quality=100): 19 | img = trans_fn.resize(img, size, resample) 20 | img = trans_fn.center_crop(img, size) 21 | buffer = BytesIO() 22 | img.save(buffer, format="webp", quality=quality) 23 | val = buffer.getvalue() 24 | 25 | return val 26 | 27 | 28 | def resize_multiple(img, 29 | sizes=(128, 256, 512, 1024), 30 | resample=Image.LANCZOS, 31 | quality=100): 32 | imgs = [] 33 | 34 | for size in sizes: 35 | imgs.append(resize_and_convert(img, size, resample, quality)) 36 | 37 | return imgs 38 | 39 | 40 | def resize_worker(idx, img, sizes, resample): 41 | img = img.convert("RGB") 42 | out = resize_multiple(img, sizes=sizes, resample=resample) 43 | return idx, out 44 | 45 | 46 | class ConvertDataset(Dataset): 47 | def __init__(self, data) -> None: 48 | self.data = data 49 | 50 | def __len__(self): 51 | return len(self.data) 52 | 53 | def __getitem__(self, index): 54 | img, _ = self.data[index] 55 | bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90) 56 | return bytes 57 | 58 | 59 | if __name__ == "__main__": 60 | """ 61 | converting lsun' original lmdb to our lmdb, which is somehow more performant. 
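Editor's note: an illustrative sketch of how `prepare()` above derives LMDB keys; filenames such as `00012.jpg` are parsed to an integer index and zero-padded into the key. `key_for` is a hypothetical helper, not repo code.

```python
import os


def key_for(path: str, size: int = 256, zfill: int = 5) -> bytes:
    # "00012.jpg" -> idx 12 -> b"256-00012", mirroring the loop in prepare()
    name, _ext = os.path.basename(path).split('.')
    return f"{size}-{str(int(name)).zfill(zfill)}".encode("utf-8")


print(key_for('datasets/ffhq/00012.jpg'))  # b'256-00012'
```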
62 | """ 63 | from tqdm import tqdm 64 | 65 | # path to the original lsun's lmdb 66 | src_path = 'datasets/horse_train_lmdb' 67 | out_path = 'datasets/horse256.lmdb' 68 | 69 | dataset = LSUNClass(root=os.path.expanduser(src_path)) 70 | dataset = ConvertDataset(dataset) 71 | loader = DataLoader(dataset, 72 | batch_size=50, 73 | num_workers=16, 74 | collate_fn=lambda x: x) 75 | 76 | target = os.path.expanduser(out_path) 77 | if os.path.exists(target): 78 | shutil.rmtree(target) 79 | 80 | with lmdb.open(target, map_size=1024**4, readahead=False) as env: 81 | with tqdm(total=len(dataset)) as progress: 82 | i = 0 83 | for batch in loader: 84 | with env.begin(write=True) as txn: 85 | for img in batch: 86 | key = f"{256}-{str(i).zfill(7)}".encode("utf-8") 87 | # print(key) 88 | txn.put(key, img) 89 | i += 1 90 | progress.update() 91 | # if i == 1000: 92 | # break 93 | # if total == len(imgset): 94 | # break 95 | 96 | with env.begin(write=True) as txn: 97 | txn.put("length".encode("utf-8"), str(i).encode("utf-8")) 98 | -------------------------------------------------------------------------------- /dataset_util.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | from dist_utils import * 4 | 5 | 6 | def use_cached_dataset_path(source_path, cache_path): 7 | if get_rank() == 0: 8 | if not os.path.exists(cache_path): 9 | # shutil.rmtree(cache_path) 10 | print(f'copying the data: {source_path} to {cache_path}') 11 | shutil.copytree(source_path, cache_path) 12 | barrier() 13 | return cache_path -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig 4 | 5 | Sampler = Union[SpacedDiffusionBeatGans] 6 | SamplerConfig = Union[SpacedDiffusionBeatGansConfig] 7 | -------------------------------------------------------------------------------- /diffusion/diffusion.py: -------------------------------------------------------------------------------- 1 | from .base import * 2 | from dataclasses import dataclass 3 | 4 | 5 | def space_timesteps(num_timesteps, section_counts): 6 | """ 7 | Create a list of timesteps to use from an original diffusion process, 8 | given the number of timesteps we want to take from equally-sized portions 9 | of the original process. 10 | 11 | For example, if there's 300 timesteps and the section counts are [10,15,20] 12 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 13 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 14 | 15 | If the stride is a string starting with "ddim", then the fixed striding 16 | from the DDIM paper is used, and only one section is allowed. 17 | 18 | :param num_timesteps: the number of diffusion steps in the original 19 | process to divide up. 20 | :param section_counts: either a list of numbers, or a string containing 21 | comma-separated numbers, indicating the step count 22 | per section. As a special case, use "ddimN" where N 23 | is a number of steps to use the striding from the 24 | DDIM paper. 25 | :return: a set of diffusion steps from the original process to use. 
26 | """ 27 | if isinstance(section_counts, str): 28 | if section_counts.startswith("ddim"): 29 | desired_count = int(section_counts[len("ddim"):]) 30 | for i in range(1, num_timesteps): 31 | if len(range(0, num_timesteps, i)) == desired_count: 32 | return set(range(0, num_timesteps, i)) 33 | raise ValueError( 34 | f"cannot create exactly {num_timesteps} steps with an integer stride" 35 | ) 36 | section_counts = [int(x) for x in section_counts.split(",")] 37 | size_per = num_timesteps // len(section_counts) 38 | extra = num_timesteps % len(section_counts) 39 | start_idx = 0 40 | all_steps = [] 41 | for i, section_count in enumerate(section_counts): 42 | size = size_per + (1 if i < extra else 0) 43 | if size < section_count: 44 | raise ValueError( 45 | f"cannot divide section of {size} steps into {section_count}") 46 | if section_count <= 1: 47 | frac_stride = 1 48 | else: 49 | frac_stride = (size - 1) / (section_count - 1) 50 | cur_idx = 0.0 51 | taken_steps = [] 52 | for _ in range(section_count): 53 | taken_steps.append(start_idx + round(cur_idx)) 54 | cur_idx += frac_stride 55 | all_steps += taken_steps 56 | start_idx += size 57 | return set(all_steps) 58 | 59 | 60 | @dataclass 61 | class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig): 62 | use_timesteps: Tuple[int] = None 63 | 64 | def make_sampler(self): 65 | return SpacedDiffusionBeatGans(self) 66 | 67 | 68 | class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans): 69 | """ 70 | A diffusion process which can skip steps in a base diffusion process. 71 | 72 | :param use_timesteps: a collection (sequence or set) of timesteps from the 73 | original diffusion process to retain. 74 | :param kwargs: the kwargs to create the base diffusion process. 75 | """ 76 | def __init__(self, conf: SpacedDiffusionBeatGansConfig): 77 | self.conf = conf 78 | self.use_timesteps = set(conf.use_timesteps) 79 | # how the new t's mapped to the old t's 80 | self.timestep_map = [] 81 | self.original_num_steps = len(conf.betas) 82 | 83 | base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa 84 | last_alpha_cumprod = 1.0 85 | new_betas = [] 86 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 87 | if i in self.use_timesteps: 88 | # getting the new betas of the new timesteps 89 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 90 | last_alpha_cumprod = alpha_cumprod 91 | self.timestep_map.append(i) 92 | conf.betas = np.array(new_betas) 93 | super().__init__(conf) 94 | 95 | def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs 96 | return super().p_mean_variance(self._wrap_model(model), *args, 97 | **kwargs) 98 | 99 | def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs 100 | return super().training_losses(self._wrap_model(model), *args, 101 | **kwargs) 102 | 103 | def condition_mean(self, cond_fn, *args, **kwargs): 104 | return super().condition_mean(self._wrap_model(cond_fn), *args, 105 | **kwargs) 106 | 107 | def condition_score(self, cond_fn, *args, **kwargs): 108 | return super().condition_score(self._wrap_model(cond_fn), *args, 109 | **kwargs) 110 | 111 | def _wrap_model(self, model: Model): 112 | if isinstance(model, _WrappedModel): 113 | return model 114 | return _WrappedModel(model, self.timestep_map, self.rescale_timesteps, 115 | self.original_num_steps) 116 | 117 | def _scale_timesteps(self, t): 118 | # Scaling is done by the wrapped model. 
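Editor's note: a usage sketch for `space_timesteps()` above, assuming this module is importable as `diffusion.diffusion`; the second call reproduces the docstring's [10, 15, 20] example.

```python
from diffusion.diffusion import space_timesteps

# DDIM-style striding: 50 evenly spaced steps out of 1000
steps = space_timesteps(num_timesteps=1000, section_counts="ddim50")
print(len(steps), min(steps), max(steps))   # 50 0 980

# three sections of 10, 15 and 20 steps over 300 original timesteps
steps = space_timesteps(num_timesteps=300, section_counts=[10, 15, 20])
print(sorted(steps)[:5])                    # [0, 11, 22, 33, 44]
```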
119 | return t 120 | 121 | 122 | class _WrappedModel: 123 | """ 124 | converting the supplied t's to the old t's scales. 125 | """ 126 | def __init__(self, model, timestep_map, rescale_timesteps, 127 | original_num_steps): 128 | self.model = model 129 | self.timestep_map = timestep_map 130 | self.rescale_timesteps = rescale_timesteps 131 | self.original_num_steps = original_num_steps 132 | 133 | def forward(self, x, t, t_cond=None, **kwargs): 134 | """ 135 | Args: 136 | t: t's with differrent ranges (can be << T due to smaller eval T) need to be converted to the original t's 137 | t_cond: the same as t but can be of different values 138 | """ 139 | map_tensor = th.tensor(self.timestep_map, 140 | device=t.device, 141 | dtype=t.dtype) 142 | 143 | def do(t): 144 | new_ts = map_tensor[t] 145 | if self.rescale_timesteps: 146 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 147 | return new_ts 148 | 149 | if t_cond is not None: 150 | # support t_cond 151 | t_cond = do(t_cond) 152 | 153 | return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs) 154 | 155 | def __getattr__(self, name): 156 | # allow for calling the model's methods 157 | if hasattr(self.model, name): 158 | func = getattr(self.model, name) 159 | return func 160 | raise AttributeError(name) 161 | -------------------------------------------------------------------------------- /diffusion/resample.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import numpy as np 4 | import torch as th 5 | import torch.distributed as dist 6 | 7 | 8 | def create_named_schedule_sampler(name, diffusion): 9 | """ 10 | Create a ScheduleSampler from a library of pre-defined samplers. 11 | 12 | :param name: the name of the sampler. 13 | :param diffusion: the diffusion object to sample for. 14 | """ 15 | if name == "uniform": 16 | return UniformSampler(diffusion) 17 | else: 18 | raise NotImplementedError(f"unknown schedule sampler: {name}") 19 | 20 | 21 | class ScheduleSampler(ABC): 22 | """ 23 | A distribution over timesteps in the diffusion process, intended to reduce 24 | variance of the objective. 25 | 26 | By default, samplers perform unbiased importance sampling, in which the 27 | objective's mean is unchanged. 28 | However, subclasses may override sample() to change how the resampled 29 | terms are reweighted, allowing for actual changes in the objective. 30 | """ 31 | @abstractmethod 32 | def weights(self): 33 | """ 34 | Get a numpy array of weights, one per diffusion step. 35 | 36 | The weights needn't be normalized, but must be positive. 37 | """ 38 | 39 | def sample(self, batch_size, device): 40 | """ 41 | Importance-sample timesteps for a batch. 42 | 43 | :param batch_size: the number of timesteps. 44 | :param device: the torch device to save to. 45 | :return: a tuple (timesteps, weights): 46 | - timesteps: a tensor of timestep indices. 47 | - weights: a tensor of weights to scale the resulting losses. 
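Editor's note: a small numeric sketch of the timestep remapping that `_WrappedModel.forward` above performs; the `timestep_map` values are illustrative.

```python
import torch as th

timestep_map = [0, 250, 500, 750]        # original t's retained by the spaced sampler (illustrative)
map_tensor = th.tensor(timestep_map)

t = th.tensor([0, 1, 3])                 # t's in the reduced (evaluation) range
print(map_tensor[t])                     # tensor([  0, 250, 750]) -> t's of the original process
```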
48 | """ 49 | w = self.weights() 50 | p = w / np.sum(w) 51 | indices_np = np.random.choice(len(p), size=(batch_size, ), p=p) 52 | indices = th.from_numpy(indices_np).long().to(device) 53 | weights_np = 1 / (len(p) * p[indices_np]) 54 | weights = th.from_numpy(weights_np).float().to(device) 55 | return indices, weights 56 | 57 | 58 | class UniformSampler(ScheduleSampler): 59 | def __init__(self, num_timesteps): 60 | self._weights = np.ones([num_timesteps]) 61 | 62 | def weights(self): 63 | return self._weights 64 | -------------------------------------------------------------------------------- /dist_utils.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from torch import distributed 3 | 4 | 5 | def barrier(): 6 | if distributed.is_initialized(): 7 | distributed.barrier() 8 | else: 9 | pass 10 | 11 | 12 | def broadcast(data, src): 13 | if distributed.is_initialized(): 14 | distributed.broadcast(data, src) 15 | else: 16 | pass 17 | 18 | 19 | def all_gather(data: List, src): 20 | if distributed.is_initialized(): 21 | distributed.all_gather(data, src) 22 | else: 23 | data[0] = src 24 | 25 | 26 | def get_rank(): 27 | if distributed.is_initialized(): 28 | return distributed.get_rank() 29 | else: 30 | return 0 31 | 32 | 33 | def get_world_size(): 34 | if distributed.is_initialized(): 35 | return distributed.get_world_size() 36 | else: 37 | return 1 38 | 39 | 40 | def chunk_size(size, rank, world_size): 41 | extra = rank < size % world_size 42 | return size // world_size + extra -------------------------------------------------------------------------------- /evals/ffhq128_autoenc_130M.txt: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /evals/ffhq128_autoenc_latent.txt: -------------------------------------------------------------------------------- 1 | {"fid_ema_T10_Tlatent10": 20.634624481201172} 2 | -------------------------------------------------------------------------------- /experiment_classifier.py: -------------------------------------------------------------------------------- 1 | from config import * 2 | from dataset import * 3 | import pandas as pd 4 | import json 5 | import os 6 | import copy 7 | 8 | import numpy as np 9 | import pytorch_lightning as pl 10 | from pytorch_lightning import loggers as pl_loggers 11 | from pytorch_lightning.callbacks import * 12 | import torch 13 | 14 | 15 | class ZipLoader: 16 | def __init__(self, loaders): 17 | self.loaders = loaders 18 | 19 | def __len__(self): 20 | return len(self.loaders[0]) 21 | 22 | def __iter__(self): 23 | for each in zip(*self.loaders): 24 | yield each 25 | 26 | 27 | class ClsModel(pl.LightningModule): 28 | def __init__(self, conf: TrainConfig): 29 | super().__init__() 30 | assert conf.train_mode.is_manipulate() 31 | if conf.seed is not None: 32 | pl.seed_everything(conf.seed) 33 | 34 | self.save_hyperparameters(conf.as_dict_jsonable()) 35 | self.conf = conf 36 | 37 | # preparations 38 | if conf.train_mode == TrainMode.manipulate: 39 | # this is only important for training! 
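Editor's note: a quick sketch of the weights computed in `ScheduleSampler.sample()` above; with `p` proportional to `weights()`, each loss term is rescaled by `1 / (N * p[t])`, which keeps the objective unbiased. For `UniformSampler` every weight is exactly 1.

```python
import numpy as np

w = np.ones(1000)                              # UniformSampler.weights()
p = w / np.sum(w)                              # sampling probabilities
indices = np.random.choice(len(p), size=4, p=p)
weights = 1.0 / (len(p) * p[indices])          # importance weights
print(weights)                                 # [1. 1. 1. 1.]
```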
40 | # the latent is freshly inferred to make sure it matches the image 41 | # manipulating latents require the base model 42 | self.model = conf.make_model_conf().make_model() 43 | self.ema_model = copy.deepcopy(self.model) 44 | self.model.requires_grad_(False) 45 | self.ema_model.requires_grad_(False) 46 | self.ema_model.eval() 47 | 48 | if conf.pretrain is not None: 49 | print(f'loading pretrain ... {conf.pretrain.name}') 50 | state = torch.load(conf.pretrain.path, map_location='cpu') 51 | print('step:', state['global_step']) 52 | self.load_state_dict(state['state_dict'], strict=False) 53 | 54 | # load the latent stats 55 | if conf.manipulate_znormalize: 56 | print('loading latent stats ...') 57 | state = torch.load(conf.latent_infer_path) 58 | self.conds = state['conds'] 59 | self.register_buffer('conds_mean', 60 | state['conds_mean'][None, :]) 61 | self.register_buffer('conds_std', state['conds_std'][None, :]) 62 | else: 63 | self.conds_mean = None 64 | self.conds_std = None 65 | 66 | if conf.manipulate_mode in [ManipulateMode.celebahq_all]: 67 | num_cls = len(CelebAttrDataset.id_to_cls) 68 | elif conf.manipulate_mode.is_single_class(): 69 | num_cls = 1 70 | else: 71 | raise NotImplementedError() 72 | 73 | # classifier 74 | if conf.train_mode == TrainMode.manipulate: 75 | # latent manipluation requires only a linear classifier 76 | self.classifier = nn.Linear(conf.style_ch, num_cls) 77 | else: 78 | raise NotImplementedError() 79 | 80 | self.ema_classifier = copy.deepcopy(self.classifier) 81 | 82 | def state_dict(self, *args, **kwargs): 83 | # don't save the base model 84 | out = {} 85 | for k, v in super().state_dict(*args, **kwargs).items(): 86 | if k.startswith('model.'): 87 | pass 88 | elif k.startswith('ema_model.'): 89 | pass 90 | else: 91 | out[k] = v 92 | return out 93 | 94 | def load_state_dict(self, state_dict, strict: bool = None): 95 | if self.conf.train_mode == TrainMode.manipulate: 96 | # change the default strict => False 97 | if strict is None: 98 | strict = False 99 | else: 100 | if strict is None: 101 | strict = True 102 | return super().load_state_dict(state_dict, strict=strict) 103 | 104 | def normalize(self, cond): 105 | cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to( 106 | self.device) 107 | return cond 108 | 109 | def denormalize(self, cond): 110 | cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to( 111 | self.device) 112 | return cond 113 | 114 | def load_dataset(self): 115 | if self.conf.manipulate_mode == ManipulateMode.d2c_fewshot: 116 | return CelebD2CAttrFewshotDataset( 117 | cls_name=self.conf.manipulate_cls, 118 | K=self.conf.manipulate_shots, 119 | img_folder=data_paths['celeba'], 120 | img_size=self.conf.img_size, 121 | seed=self.conf.manipulate_seed, 122 | all_neg=False, 123 | do_augment=True, 124 | ) 125 | elif self.conf.manipulate_mode == ManipulateMode.d2c_fewshot_allneg: 126 | # positive-unlabeled classifier needs to keep the class ratio 1:1 127 | # we use two dataloaders, one for each class, to stabiliize the training 128 | img_folder = data_paths['celeba'] 129 | 130 | return [ 131 | CelebD2CAttrFewshotDataset( 132 | cls_name=self.conf.manipulate_cls, 133 | K=self.conf.manipulate_shots, 134 | img_folder=img_folder, 135 | img_size=self.conf.img_size, 136 | only_cls_name=self.conf.manipulate_cls, 137 | only_cls_value=1, 138 | seed=self.conf.manipulate_seed, 139 | all_neg=True, 140 | do_augment=True), 141 | CelebD2CAttrFewshotDataset( 142 | cls_name=self.conf.manipulate_cls, 143 | K=self.conf.manipulate_shots, 
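Editor's note: a minimal sketch (values are assumptions, not the repo's checkpoints) of the latent z-normalization and linear classifier used by `ClsModel` above: encoder codes are standardized with precomputed statistics before the linear head sees them.

```python
import torch
from torch import nn

style_ch, num_cls = 512, 40                 # 512-d semantic code, 40 CelebA attributes (typical values)
conds_mean = torch.zeros(1, style_ch)       # placeholders; the repo loads these from latent_infer_path
conds_std = torch.ones(1, style_ch)

classifier = nn.Linear(style_ch, num_cls)
cond = torch.randn(8, style_ch)             # stand-in for ema_model.encoder(imgs)
cond = (cond - conds_mean) / conds_std      # ClsModel.normalize()
logits = classifier(cond)                   # (8, 40) attribute logits
print(logits.shape)
```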
144 | img_folder=img_folder, 145 | img_size=self.conf.img_size, 146 | only_cls_name=self.conf.manipulate_cls, 147 | only_cls_value=-1, 148 | seed=self.conf.manipulate_seed, 149 | all_neg=True, 150 | do_augment=True), 151 | ] 152 | elif self.conf.manipulate_mode == ManipulateMode.celebahq_all: 153 | return CelebHQAttrDataset(data_paths['celebahq'], 154 | self.conf.img_size, 155 | data_paths['celebahq_anno'], 156 | do_augment=True) 157 | else: 158 | raise NotImplementedError() 159 | 160 | def setup(self, stage=None) -> None: 161 | ############################################## 162 | # NEED TO SET THE SEED SEPARATELY HERE 163 | if self.conf.seed is not None: 164 | seed = self.conf.seed * get_world_size() + self.global_rank 165 | np.random.seed(seed) 166 | torch.manual_seed(seed) 167 | torch.cuda.manual_seed(seed) 168 | print('local seed:', seed) 169 | ############################################## 170 | 171 | self.train_data = self.load_dataset() 172 | if self.conf.manipulate_mode.is_fewshot(): 173 | # repeat the dataset to be larger (speed up the training) 174 | if isinstance(self.train_data, list): 175 | # fewshot-allneg has two datasets 176 | # we resize them to be of equal sizes 177 | a, b = self.train_data 178 | self.train_data = [ 179 | Repeat(a, max(len(a), len(b))), 180 | Repeat(b, max(len(a), len(b))), 181 | ] 182 | else: 183 | self.train_data = Repeat(self.train_data, 100_000) 184 | 185 | def train_dataloader(self): 186 | # make sure to use the fraction of batch size 187 | # the batch size is global! 188 | conf = self.conf.clone() 189 | conf.batch_size = self.batch_size 190 | if isinstance(self.train_data, list): 191 | dataloader = [] 192 | for each in self.train_data: 193 | dataloader.append( 194 | conf.make_loader(each, shuffle=True, drop_last=True)) 195 | dataloader = ZipLoader(dataloader) 196 | else: 197 | dataloader = conf.make_loader(self.train_data, 198 | shuffle=True, 199 | drop_last=True) 200 | return dataloader 201 | 202 | @property 203 | def batch_size(self): 204 | ws = get_world_size() 205 | assert self.conf.batch_size % ws == 0 206 | return self.conf.batch_size // ws 207 | 208 | def training_step(self, batch, batch_idx): 209 | self.ema_model: BeatGANsAutoencModel 210 | if isinstance(batch, tuple): 211 | a, b = batch 212 | imgs = torch.cat([a['img'], b['img']]) 213 | labels = torch.cat([a['labels'], b['labels']]) 214 | else: 215 | imgs = batch['img'] 216 | # print(f'({self.global_rank}) imgs:', imgs.shape) 217 | labels = batch['labels'] 218 | 219 | if self.conf.train_mode == TrainMode.manipulate: 220 | self.ema_model.eval() 221 | with torch.no_grad(): 222 | # (n, c) 223 | cond = self.ema_model.encoder(imgs) 224 | 225 | if self.conf.manipulate_znormalize: 226 | cond = self.normalize(cond) 227 | 228 | # (n, cls) 229 | pred = self.classifier.forward(cond) 230 | pred_ema = self.ema_classifier.forward(cond) 231 | elif self.conf.train_mode == TrainMode.manipulate_img: 232 | # (n, cls) 233 | pred = self.classifier.forward(imgs) 234 | pred_ema = None 235 | elif self.conf.train_mode == TrainMode.manipulate_imgt: 236 | t, weight = self.T_sampler.sample(len(imgs), imgs.device) 237 | imgs_t = self.sampler.q_sample(imgs, t) 238 | pred = self.classifier.forward(imgs_t, t=t) 239 | pred_ema = None 240 | print('pred:', pred.shape) 241 | else: 242 | raise NotImplementedError() 243 | 244 | if self.conf.manipulate_mode.is_celeba_attr(): 245 | gt = torch.where(labels > 0, 246 | torch.ones_like(labels).float(), 247 | torch.zeros_like(labels).float()) 248 | elif self.conf.manipulate_mode == 
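Editor's note: a sketch of the target construction in `training_step` above; CelebA attribute labels arrive as {-1, +1} and are mapped to {0, 1} before BCE-with-logits.

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([[1., -1., 1.], [-1., -1., 1.]])      # raw {-1, +1} annotations
gt = torch.where(labels > 0,
                 torch.ones_like(labels),
                 torch.zeros_like(labels))                  # {0, 1} targets
pred = torch.zeros_like(gt)                                 # stand-in classifier logits
print(gt)
print(F.binary_cross_entropy_with_logits(pred, gt).item())  # ~0.6931 for zero logits
```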
ManipulateMode.relighting: 249 | gt = labels 250 | else: 251 | raise NotImplementedError() 252 | 253 | if self.conf.manipulate_loss == ManipulateLossType.bce: 254 | loss = F.binary_cross_entropy_with_logits(pred, gt) 255 | if pred_ema is not None: 256 | loss_ema = F.binary_cross_entropy_with_logits(pred_ema, gt) 257 | elif self.conf.manipulate_loss == ManipulateLossType.mse: 258 | loss = F.mse_loss(pred, gt) 259 | if pred_ema is not None: 260 | loss_ema = F.mse_loss(pred_ema, gt) 261 | else: 262 | raise NotImplementedError() 263 | 264 | self.log('loss', loss) 265 | self.log('loss_ema', loss_ema) 266 | return loss 267 | 268 | def on_train_batch_end(self, outputs, batch, batch_idx: int, 269 | dataloader_idx: int) -> None: 270 | ema(self.classifier, self.ema_classifier, self.conf.ema_decay) 271 | 272 | def configure_optimizers(self): 273 | optim = torch.optim.Adam(self.classifier.parameters(), 274 | lr=self.conf.lr, 275 | weight_decay=self.conf.weight_decay) 276 | return optim 277 | 278 | 279 | def ema(source, target, decay): 280 | source_dict = source.state_dict() 281 | target_dict = target.state_dict() 282 | for key in source_dict.keys(): 283 | target_dict[key].data.copy_(target_dict[key].data * decay + 284 | source_dict[key].data * (1 - decay)) 285 | 286 | 287 | def train_cls(conf: TrainConfig, gpus): 288 | print('conf:', conf.name) 289 | model = ClsModel(conf) 290 | 291 | if not os.path.exists(conf.logdir): 292 | os.makedirs(conf.logdir) 293 | checkpoint = ModelCheckpoint( 294 | dirpath=f'{conf.logdir}', 295 | save_last=True, 296 | save_top_k=1, 297 | # every_n_train_steps=conf.save_every_samples // 298 | # conf.batch_size_effective, 299 | ) 300 | checkpoint_path = f'{conf.logdir}/last.ckpt' 301 | if os.path.exists(checkpoint_path): 302 | resume = checkpoint_path 303 | else: 304 | if conf.continue_from is not None: 305 | # continue from a checkpoint 306 | resume = conf.continue_from.path 307 | else: 308 | resume = None 309 | 310 | tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir, 311 | name=None, 312 | version='') 313 | 314 | # from pytorch_lightning. 
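Editor's note: a numeric usage sketch for the `ema()` helper defined above; the EMA classifier tracks the raw classifier via `target <- decay * target + (1 - decay) * source`.

```python
import torch
from torch import nn

src, tgt = nn.Linear(2, 1), nn.Linear(2, 1)
with torch.no_grad():
    src.weight.fill_(1.0)
    tgt.weight.fill_(0.0)

ema(src, tgt, decay=0.99)      # ema() as defined above
print(tgt.weight)              # ~0.01 everywhere after a single update
```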
315 | 316 | plugins = [] 317 | if len(gpus) == 1: 318 | accelerator = None 319 | else: 320 | accelerator = 'ddp' 321 | from pytorch_lightning.plugins import DDPPlugin 322 | # important for working with gradient checkpoint 323 | plugins.append(DDPPlugin(find_unused_parameters=False)) 324 | 325 | trainer = pl.Trainer( 326 | max_steps=conf.total_samples // conf.batch_size_effective, 327 | resume_from_checkpoint=resume, 328 | gpus=gpus, 329 | accelerator=accelerator, 330 | precision=16 if conf.fp16 else 32, 331 | callbacks=[ 332 | checkpoint, 333 | ], 334 | replace_sampler_ddp=True, 335 | logger=tb_logger, 336 | accumulate_grad_batches=conf.accum_batches, 337 | plugins=plugins, 338 | ) 339 | trainer.fit(model) 340 | -------------------------------------------------------------------------------- /imgs/sandy.JPG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs/sandy.JPG -------------------------------------------------------------------------------- /imgs_align/sandy.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_align/sandy.png -------------------------------------------------------------------------------- /imgs_interpolate/1_a.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_interpolate/1_a.png -------------------------------------------------------------------------------- /imgs_interpolate/1_b.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_interpolate/1_b.png -------------------------------------------------------------------------------- /imgs_manipulated/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_manipulated/compare.png -------------------------------------------------------------------------------- /imgs_manipulated/output.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_manipulated/output.png -------------------------------------------------------------------------------- /imgs_manipulated/sandy-wavyhair.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_manipulated/sandy-wavyhair.png -------------------------------------------------------------------------------- /install_requirements_for_colab.sh: -------------------------------------------------------------------------------- 1 | !pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 pytorch-lightning==1.2.2 torchtext==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 2 | !pip install scipy==1.5.4 3 | !pip install numpy==1.19.5 4 | !pip install tqdm 5 | !pip install pytorch-fid==0.2.0 6 | !pip install pandas==1.1.5 7 | !pip install lpips==0.1.4 8 | !pip install lmdb==1.2.1 9 | !pip install ftfy 10 | !pip install regex 11 | !pip install dlib requests 
-------------------------------------------------------------------------------- /lmdb_writer.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | from PIL import Image 5 | 6 | import torch 7 | 8 | from contextlib import contextmanager 9 | from torch.utils.data import Dataset 10 | from multiprocessing import Process, Queue 11 | import os 12 | import shutil 13 | 14 | 15 | def convert(x, format, quality=100): 16 | # to prevent locking! 17 | torch.set_num_threads(1) 18 | 19 | buffer = BytesIO() 20 | x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0) 21 | x = x.to(torch.uint8) 22 | x = x.numpy() 23 | img = Image.fromarray(x) 24 | img.save(buffer, format=format, quality=quality) 25 | val = buffer.getvalue() 26 | return val 27 | 28 | 29 | @contextmanager 30 | def nullcontext(): 31 | yield 32 | 33 | 34 | class _WriterWroker(Process): 35 | def __init__(self, path, format, quality, zfill, q): 36 | super().__init__() 37 | if os.path.exists(path): 38 | shutil.rmtree(path) 39 | 40 | self.path = path 41 | self.format = format 42 | self.quality = quality 43 | self.zfill = zfill 44 | self.q = q 45 | self.i = 0 46 | 47 | def run(self): 48 | if not os.path.exists(self.path): 49 | os.makedirs(self.path) 50 | 51 | with lmdb.open(self.path, map_size=1024**4, readahead=False) as env: 52 | while True: 53 | job = self.q.get() 54 | if job is None: 55 | break 56 | with env.begin(write=True) as txn: 57 | for x in job: 58 | key = f"{str(self.i).zfill(self.zfill)}".encode( 59 | "utf-8") 60 | x = convert(x, self.format, self.quality) 61 | txn.put(key, x) 62 | self.i += 1 63 | 64 | with env.begin(write=True) as txn: 65 | txn.put("length".encode("utf-8"), str(self.i).encode("utf-8")) 66 | 67 | 68 | class LMDBImageWriter: 69 | def __init__(self, path, format='webp', quality=100, zfill=7) -> None: 70 | self.path = path 71 | self.format = format 72 | self.quality = quality 73 | self.zfill = zfill 74 | self.queue = None 75 | self.worker = None 76 | 77 | def __enter__(self): 78 | self.queue = Queue(maxsize=3) 79 | self.worker = _WriterWroker(self.path, self.format, self.quality, 80 | self.zfill, self.queue) 81 | self.worker.start() 82 | 83 | def put_images(self, tensor): 84 | """ 85 | Args: 86 | tensor: (n, c, h, w) [0-1] tensor 87 | """ 88 | self.queue.put(tensor.cpu()) 89 | # with self.env.begin(write=True) as txn: 90 | # for x in tensor: 91 | # key = f"{str(self.i).zfill(self.zfill)}".encode("utf-8") 92 | # x = convert(x, self.format, self.quality) 93 | # txn.put(key, x) 94 | # self.i += 1 95 | 96 | def __exit__(self, *args, **kwargs): 97 | self.queue.put(None) 98 | self.queue.close() 99 | self.worker.join() 100 | 101 | 102 | class LMDBImageReader(Dataset): 103 | def __init__(self, path, zfill: int = 7): 104 | self.zfill = zfill 105 | self.env = lmdb.open( 106 | path, 107 | max_readers=32, 108 | readonly=True, 109 | lock=False, 110 | readahead=False, 111 | meminit=False, 112 | ) 113 | 114 | if not self.env: 115 | raise IOError('Cannot open lmdb dataset', path) 116 | 117 | with self.env.begin(write=False) as txn: 118 | self.length = int( 119 | txn.get('length'.encode('utf-8')).decode('utf-8')) 120 | 121 | def __len__(self): 122 | return self.length 123 | 124 | def __getitem__(self, index): 125 | with self.env.begin(write=False) as txn: 126 | key = f'{str(index).zfill(self.zfill)}'.encode('utf-8') 127 | img_bytes = txn.get(key) 128 | 129 | buffer = BytesIO(img_bytes) 130 | img = Image.open(buffer) 131 | return img 132 | 
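Editor's note: a usage sketch for `LMDBImageWriter` / `LMDBImageReader` above. The writer runs a background process, so `put_images()` must be called inside the `with` block; on spawn-based platforms wrap this in an `if __name__ == '__main__':` guard. The path is an example.

```python
import torch

out_path = 'datasets/example.lmdb'                 # illustrative path

writer = LMDBImageWriter(out_path, format='webp', quality=100)
with writer:
    fake_batch = torch.rand(4, 3, 64, 64)          # (n, c, h, w) tensors in [0, 1]
    writer.put_images(fake_batch)                  # queued to the background writer

reader = LMDBImageReader(out_path)
print(len(reader), reader[0].size)                 # 4 (64, 64)
```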
-------------------------------------------------------------------------------- /metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | import torch 5 | import torchvision 6 | from pytorch_fid import fid_score 7 | from torch import distributed 8 | from torch.utils.data import DataLoader 9 | from torch.utils.data.distributed import DistributedSampler 10 | from tqdm.autonotebook import tqdm, trange 11 | 12 | from renderer import * 13 | from config import * 14 | from diffusion import Sampler 15 | from dist_utils import * 16 | import lpips 17 | from ssim import ssim 18 | 19 | 20 | def make_subset_loader(conf: TrainConfig, 21 | dataset: Dataset, 22 | batch_size: int, 23 | shuffle: bool, 24 | parallel: bool, 25 | drop_last=True): 26 | dataset = SubsetDataset(dataset, size=conf.eval_num_images) 27 | if parallel and distributed.is_initialized(): 28 | sampler = DistributedSampler(dataset, shuffle=shuffle) 29 | else: 30 | sampler = None 31 | return DataLoader( 32 | dataset, 33 | batch_size=batch_size, 34 | sampler=sampler, 35 | # with sampler, use the sample instead of this option 36 | shuffle=False if sampler else shuffle, 37 | num_workers=conf.num_workers, 38 | pin_memory=True, 39 | drop_last=drop_last, 40 | multiprocessing_context=get_context('fork'), 41 | ) 42 | 43 | 44 | def evaluate_lpips( 45 | sampler: Sampler, 46 | model: Model, 47 | conf: TrainConfig, 48 | device, 49 | val_data: Dataset, 50 | latent_sampler: Sampler = None, 51 | use_inverted_noise: bool = False, 52 | ): 53 | """ 54 | compare the generated images from autoencoder on validation dataset 55 | 56 | Args: 57 | use_inversed_noise: the noise is also inverted from DDIM 58 | """ 59 | lpips_fn = lpips.LPIPS(net='alex').to(device) 60 | val_loader = make_subset_loader(conf, 61 | dataset=val_data, 62 | batch_size=conf.batch_size_eval, 63 | shuffle=False, 64 | parallel=True) 65 | 66 | model.eval() 67 | with torch.no_grad(): 68 | scores = { 69 | 'lpips': [], 70 | 'mse': [], 71 | 'ssim': [], 72 | 'psnr': [], 73 | } 74 | for batch in tqdm(val_loader, desc='lpips'): 75 | imgs = batch['img'].to(device) 76 | 77 | if use_inverted_noise: 78 | # inverse the noise 79 | # with condition from the encoder 80 | model_kwargs = {} 81 | if conf.model_type.has_autoenc(): 82 | with torch.no_grad(): 83 | model_kwargs = model.encode(imgs) 84 | x_T = sampler.ddim_reverse_sample_loop( 85 | model=model, 86 | x=imgs, 87 | clip_denoised=True, 88 | model_kwargs=model_kwargs) 89 | x_T = x_T['sample'] 90 | else: 91 | x_T = torch.randn((len(imgs), 3, conf.img_size, conf.img_size), 92 | device=device) 93 | 94 | if conf.model_type == ModelType.ddpm: 95 | # the case where you want to calculate the inversion capability of the DDIM model 96 | assert use_inverted_noise 97 | pred_imgs = render_uncondition( 98 | conf=conf, 99 | model=model, 100 | x_T=x_T, 101 | sampler=sampler, 102 | latent_sampler=latent_sampler, 103 | ) 104 | else: 105 | pred_imgs = render_condition(conf=conf, 106 | model=model, 107 | x_T=x_T, 108 | x_start=imgs, 109 | cond=None, 110 | sampler=sampler) 111 | # # returns {'cond', 'cond2'} 112 | # conds = model.encode(imgs) 113 | # pred_imgs = sampler.sample(model=model, 114 | # noise=x_T, 115 | # model_kwargs=conds) 116 | 117 | # (n, 1, 1, 1) => (n, ) 118 | scores['lpips'].append(lpips_fn.forward(imgs, pred_imgs).view(-1)) 119 | 120 | # need to normalize into [0, 1] 121 | norm_imgs = (imgs + 1) / 2 122 | norm_pred_imgs = (pred_imgs + 1) / 2 123 | # (n, ) 124 | 
scores['ssim'].append( 125 | ssim(norm_imgs, norm_pred_imgs, size_average=False)) 126 | # (n, ) 127 | scores['mse'].append( 128 | (norm_imgs - norm_pred_imgs).pow(2).mean(dim=[1, 2, 3])) 129 | # (n, ) 130 | scores['psnr'].append(psnr(norm_imgs, norm_pred_imgs)) 131 | # (N, ) 132 | for key in scores.keys(): 133 | scores[key] = torch.cat(scores[key]).float() 134 | model.train() 135 | 136 | barrier() 137 | 138 | # support multi-gpu 139 | outs = { 140 | key: [ 141 | torch.zeros(len(scores[key]), device=device) 142 | for i in range(get_world_size()) 143 | ] 144 | for key in scores.keys() 145 | } 146 | for key in scores.keys(): 147 | all_gather(outs[key], scores[key]) 148 | 149 | # final scores 150 | for key in scores.keys(): 151 | scores[key] = torch.cat(outs[key]).mean().item() 152 | 153 | # {'lpips', 'mse', 'ssim'} 154 | return scores 155 | 156 | 157 | def psnr(img1, img2): 158 | """ 159 | Args: 160 | img1: (n, c, h, w) 161 | """ 162 | v_max = 1. 163 | # (n,) 164 | mse = torch.mean((img1 - img2)**2, dim=[1, 2, 3]) 165 | return 20 * torch.log10(v_max / torch.sqrt(mse)) 166 | 167 | 168 | def evaluate_fid( 169 | sampler: Sampler, 170 | model: Model, 171 | conf: TrainConfig, 172 | device, 173 | train_data: Dataset, 174 | val_data: Dataset, 175 | latent_sampler: Sampler = None, 176 | conds_mean=None, 177 | conds_std=None, 178 | remove_cache: bool = True, 179 | clip_latent_noise: bool = False, 180 | ): 181 | assert conf.fid_cache is not None 182 | if get_rank() == 0: 183 | # no parallel 184 | # validation data for a comparing FID 185 | val_loader = make_subset_loader(conf, 186 | dataset=val_data, 187 | batch_size=conf.batch_size_eval, 188 | shuffle=False, 189 | parallel=False) 190 | 191 | # put the val images to a directory 192 | cache_dir = f'{conf.fid_cache}_{conf.eval_num_images}' 193 | if (os.path.exists(cache_dir) 194 | and len(os.listdir(cache_dir)) < conf.eval_num_images): 195 | shutil.rmtree(cache_dir) 196 | 197 | if not os.path.exists(cache_dir): 198 | # write files to the cache 199 | # the images are normalized, hence need to denormalize first 200 | loader_to_path(val_loader, cache_dir, denormalize=True) 201 | 202 | # create the generate dir 203 | if os.path.exists(conf.generate_dir): 204 | shutil.rmtree(conf.generate_dir) 205 | os.makedirs(conf.generate_dir) 206 | 207 | barrier() 208 | 209 | world_size = get_world_size() 210 | rank = get_rank() 211 | batch_size = chunk_size(conf.batch_size_eval, rank, world_size) 212 | 213 | def filename(idx): 214 | return world_size * idx + rank 215 | 216 | model.eval() 217 | with torch.no_grad(): 218 | if conf.model_type.can_sample(): 219 | eval_num_images = chunk_size(conf.eval_num_images, rank, 220 | world_size) 221 | desc = "generating images" 222 | for i in trange(0, eval_num_images, batch_size, desc=desc): 223 | batch_size = min(batch_size, eval_num_images - i) 224 | x_T = torch.randn( 225 | (batch_size, 3, conf.img_size, conf.img_size), 226 | device=device) 227 | batch_images = render_uncondition( 228 | conf=conf, 229 | model=model, 230 | x_T=x_T, 231 | sampler=sampler, 232 | latent_sampler=latent_sampler, 233 | conds_mean=conds_mean, 234 | conds_std=conds_std).cpu() 235 | 236 | batch_images = (batch_images + 1) / 2 237 | # keep the generated images 238 | for j in range(len(batch_images)): 239 | img_name = filename(i + j) 240 | torchvision.utils.save_image( 241 | batch_images[j], 242 | os.path.join(conf.generate_dir, f'{img_name}.png')) 243 | elif conf.model_type == ModelType.autoencoder: 244 | if conf.train_mode.is_latent_diffusion(): 245 | # 
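Editor's note: a quick numeric check of the `psnr()` helper above; for images in [0, 1], PSNR = 20 * log10(1 / sqrt(MSE)) per image.

```python
import torch

img1 = torch.zeros(2, 3, 8, 8)
img2 = torch.full((2, 3, 8, 8), 0.1)   # constant error of 0.1 -> per-image MSE = 0.01
print(psnr(img1, img2))                # ~tensor([20., 20.]) dB
```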
evaluate autoencoder + latent diffusion (doesn't give the images) 246 | model: BeatGANsAutoencModel 247 | eval_num_images = chunk_size(conf.eval_num_images, rank, 248 | world_size) 249 | desc = "generating images" 250 | for i in trange(0, eval_num_images, batch_size, desc=desc): 251 | batch_size = min(batch_size, eval_num_images - i) 252 | x_T = torch.randn( 253 | (batch_size, 3, conf.img_size, conf.img_size), 254 | device=device) 255 | batch_images = render_uncondition( 256 | conf=conf, 257 | model=model, 258 | x_T=x_T, 259 | sampler=sampler, 260 | latent_sampler=latent_sampler, 261 | conds_mean=conds_mean, 262 | conds_std=conds_std, 263 | clip_latent_noise=clip_latent_noise, 264 | ).cpu() 265 | batch_images = (batch_images + 1) / 2 266 | # keep the generated images 267 | for j in range(len(batch_images)): 268 | img_name = filename(i + j) 269 | torchvision.utils.save_image( 270 | batch_images[j], 271 | os.path.join(conf.generate_dir, f'{img_name}.png')) 272 | else: 273 | # evaulate autoencoder (given the images) 274 | # to make the FID fair, autoencoder must not see the validation dataset 275 | # also shuffle to make it closer to unconditional generation 276 | train_loader = make_subset_loader(conf, 277 | dataset=train_data, 278 | batch_size=batch_size, 279 | shuffle=True, 280 | parallel=True) 281 | 282 | i = 0 283 | for batch in tqdm(train_loader, desc='generating images'): 284 | imgs = batch['img'].to(device) 285 | x_T = torch.randn( 286 | (len(imgs), 3, conf.img_size, conf.img_size), 287 | device=device) 288 | batch_images = render_condition( 289 | conf=conf, 290 | model=model, 291 | x_T=x_T, 292 | x_start=imgs, 293 | cond=None, 294 | sampler=sampler, 295 | latent_sampler=latent_sampler).cpu() 296 | # model: BeatGANsAutoencModel 297 | # # returns {'cond', 'cond2'} 298 | # conds = model.encode(imgs) 299 | # batch_images = sampler.sample(model=model, 300 | # noise=x_T, 301 | # model_kwargs=conds).cpu() 302 | # denormalize the images 303 | batch_images = (batch_images + 1) / 2 304 | # keep the generated images 305 | for j in range(len(batch_images)): 306 | img_name = filename(i + j) 307 | torchvision.utils.save_image( 308 | batch_images[j], 309 | os.path.join(conf.generate_dir, f'{img_name}.png')) 310 | i += len(imgs) 311 | else: 312 | raise NotImplementedError() 313 | model.train() 314 | 315 | barrier() 316 | 317 | if get_rank() == 0: 318 | fid = fid_score.calculate_fid_given_paths( 319 | [cache_dir, conf.generate_dir], 320 | batch_size, 321 | device=device, 322 | dims=2048) 323 | 324 | # remove the cache 325 | if remove_cache and os.path.exists(conf.generate_dir): 326 | shutil.rmtree(conf.generate_dir) 327 | 328 | barrier() 329 | 330 | if get_rank() == 0: 331 | # need to float it! unless the broadcasted value is wrong 332 | fid = torch.tensor(float(fid), device=device) 333 | broadcast(fid, 0) 334 | else: 335 | fid = torch.tensor(0., device=device) 336 | broadcast(fid, 0) 337 | fid = fid.item() 338 | print(f'fid ({get_rank()}):', fid) 339 | 340 | return fid 341 | 342 | 343 | def loader_to_path(loader: DataLoader, path: str, denormalize: bool): 344 | # not process safe! 
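Editor's note: a sketch of how `evaluate_fid()` above splits generation across ranks; each rank produces `chunk_size(N, rank, world_size)` images and names them `world_size * idx + rank`, so filenames from all ranks interleave without collisions. The numbers are illustrative.

```python
world_size, N = 4, 10
for rank in range(world_size):
    n = N // world_size + (1 if rank < N % world_size else 0)   # chunk_size()
    names = [world_size * idx + rank for idx in range(n)]       # filename()
    print(rank, names)
# rank 0 -> [0, 4, 8], rank 1 -> [1, 5, 9], rank 2 -> [2, 6], rank 3 -> [3, 7]
```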
345 | 346 | if not os.path.exists(path): 347 | os.makedirs(path) 348 | 349 | # write the loader to files 350 | i = 0 351 | for batch in tqdm(loader, desc='copy images'): 352 | imgs = batch['img'] 353 | if denormalize: 354 | imgs = (imgs + 1) / 2 355 | for j in range(len(imgs)): 356 | torchvision.utils.save_image(imgs[j], 357 | os.path.join(path, f'{i+j}.png')) 358 | i += len(imgs) 359 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | from .unet import BeatGANsUNetModel, BeatGANsUNetConfig 3 | from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel 4 | 5 | Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel] 6 | ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig] 7 | -------------------------------------------------------------------------------- /model/blocks.py: -------------------------------------------------------------------------------- 1 | import math 2 | from abc import abstractmethod 3 | from dataclasses import dataclass 4 | from numbers import Number 5 | 6 | import torch as th 7 | import torch.nn.functional as F 8 | from choices import * 9 | from config_base import BaseConfig 10 | from torch import nn 11 | 12 | from .nn import (avg_pool_nd, conv_nd, linear, normalization, 13 | timestep_embedding, torch_checkpoint, zero_module) 14 | 15 | 16 | class ScaleAt(Enum): 17 | after_norm = 'afternorm' 18 | 19 | 20 | class TimestepBlock(nn.Module): 21 | """ 22 | Any module where forward() takes timestep embeddings as a second argument. 23 | """ 24 | @abstractmethod 25 | def forward(self, x, emb=None, cond=None, lateral=None): 26 | """ 27 | Apply the module to `x` given `emb` timestep embeddings. 28 | """ 29 | 30 | 31 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock): 32 | """ 33 | A sequential module that passes timestep embeddings to the children that 34 | support it as an extra input. 35 | """ 36 | def forward(self, x, emb=None, cond=None, lateral=None): 37 | for layer in self: 38 | if isinstance(layer, TimestepBlock): 39 | x = layer(x, emb=emb, cond=cond, lateral=lateral) 40 | else: 41 | x = layer(x) 42 | return x 43 | 44 | 45 | @dataclass 46 | class ResBlockConfig(BaseConfig): 47 | channels: int 48 | emb_channels: int 49 | dropout: float 50 | out_channels: int = None 51 | # condition the resblock with time (and encoder's output) 52 | use_condition: bool = True 53 | # whether to use 3x3 conv for skip path when the channels aren't matched 54 | use_conv: bool = False 55 | # dimension of conv (always 2 = 2d) 56 | dims: int = 2 57 | # gradient checkpoint 58 | use_checkpoint: bool = False 59 | up: bool = False 60 | down: bool = False 61 | # whether to condition with both time & encoder's output 62 | two_cond: bool = False 63 | # number of encoders' output channels 64 | cond_emb_channels: int = None 65 | # suggest: False 66 | has_lateral: bool = False 67 | lateral_channels: int = None 68 | # whether to init the convolution with zero weights 69 | # this is default from BeatGANs and seems to help learning 70 | use_zero_module: bool = True 71 | 72 | def __post_init__(self): 73 | self.out_channels = self.out_channels or self.channels 74 | self.cond_emb_channels = self.cond_emb_channels or self.emb_channels 75 | 76 | def make_model(self): 77 | return ResBlock(self) 78 | 79 | 80 | class ResBlock(TimestepBlock): 81 | """ 82 | A residual block that can optionally change the number of channels. 
83 | 84 | total layers: 85 | in_layers 86 | - norm 87 | - act 88 | - conv 89 | out_layers 90 | - norm 91 | - (modulation) 92 | - act 93 | - conv 94 | """ 95 | def __init__(self, conf: ResBlockConfig): 96 | super().__init__() 97 | self.conf = conf 98 | 99 | ############################# 100 | # IN LAYERS 101 | ############################# 102 | assert conf.lateral_channels is None 103 | layers = [ 104 | normalization(conf.channels), 105 | nn.SiLU(), 106 | conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1) 107 | ] 108 | self.in_layers = nn.Sequential(*layers) 109 | 110 | self.updown = conf.up or conf.down 111 | 112 | if conf.up: 113 | self.h_upd = Upsample(conf.channels, False, conf.dims) 114 | self.x_upd = Upsample(conf.channels, False, conf.dims) 115 | elif conf.down: 116 | self.h_upd = Downsample(conf.channels, False, conf.dims) 117 | self.x_upd = Downsample(conf.channels, False, conf.dims) 118 | else: 119 | self.h_upd = self.x_upd = nn.Identity() 120 | 121 | ############################# 122 | # OUT LAYERS CONDITIONS 123 | ############################# 124 | if conf.use_condition: 125 | # condition layers for the out_layers 126 | self.emb_layers = nn.Sequential( 127 | nn.SiLU(), 128 | linear(conf.emb_channels, 2 * conf.out_channels), 129 | ) 130 | 131 | if conf.two_cond: 132 | self.cond_emb_layers = nn.Sequential( 133 | nn.SiLU(), 134 | linear(conf.cond_emb_channels, conf.out_channels), 135 | ) 136 | ############################# 137 | # OUT LAYERS (ignored when there is no condition) 138 | ############################# 139 | # original version 140 | conv = conv_nd(conf.dims, 141 | conf.out_channels, 142 | conf.out_channels, 143 | 3, 144 | padding=1) 145 | if conf.use_zero_module: 146 | # zere out the weights 147 | # it seems to help training 148 | conv = zero_module(conv) 149 | 150 | # construct the layers 151 | # - norm 152 | # - (modulation) 153 | # - act 154 | # - dropout 155 | # - conv 156 | layers = [] 157 | layers += [ 158 | normalization(conf.out_channels), 159 | nn.SiLU(), 160 | nn.Dropout(p=conf.dropout), 161 | conv, 162 | ] 163 | self.out_layers = nn.Sequential(*layers) 164 | 165 | ############################# 166 | # SKIP LAYERS 167 | ############################# 168 | if conf.out_channels == conf.channels: 169 | # cannot be used with gatedconv, also gatedconv is alsways used as the first block 170 | self.skip_connection = nn.Identity() 171 | else: 172 | if conf.use_conv: 173 | kernel_size = 3 174 | padding = 1 175 | else: 176 | kernel_size = 1 177 | padding = 0 178 | 179 | self.skip_connection = conv_nd(conf.dims, 180 | conf.channels, 181 | conf.out_channels, 182 | kernel_size, 183 | padding=padding) 184 | 185 | def forward(self, x, emb=None, cond=None, lateral=None): 186 | """ 187 | Apply the block to a Tensor, conditioned on a timestep embedding. 
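Editor's note: a hypothetical usage sketch for `ResBlockConfig` / `ResBlock` above (channel counts are illustrative, not the repo's defaults); a block that takes 64 channels, is conditioned on a 512-d time embedding, and outputs 128 channels.

```python
import torch

conf = ResBlockConfig(
    channels=64,
    emb_channels=512,
    dropout=0.0,
    out_channels=128,
    use_condition=True,
)
block = conf.make_model()

x = torch.randn(2, 64, 32, 32)
emb = torch.randn(2, 512)                 # timestep embedding
print(block(x, emb=emb).shape)            # torch.Size([2, 128, 32, 32])
```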
188 | 189 | Args: 190 | x: input 191 | lateral: lateral connection from the encoder 192 | """ 193 | return torch_checkpoint(self._forward, (x, emb, cond, lateral), 194 | self.conf.use_checkpoint) 195 | 196 | def _forward( 197 | self, 198 | x, 199 | emb=None, 200 | cond=None, 201 | lateral=None, 202 | ): 203 | """ 204 | Args: 205 | lateral: required if "has_lateral" and non-gated, with gated, it can be supplied optionally 206 | """ 207 | if self.conf.has_lateral: 208 | # lateral may be supplied even if it doesn't require 209 | # the model will take the lateral only if "has_lateral" 210 | assert lateral is not None 211 | x = th.cat([x, lateral], dim=1) 212 | 213 | if self.updown: 214 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] 215 | h = in_rest(x) 216 | h = self.h_upd(h) 217 | x = self.x_upd(x) 218 | h = in_conv(h) 219 | else: 220 | h = self.in_layers(x) 221 | 222 | if self.conf.use_condition: 223 | # it's possible that the network may not receieve the time emb 224 | # this happens with autoenc and setting the time_at 225 | if emb is not None: 226 | emb_out = self.emb_layers(emb).type(h.dtype) 227 | else: 228 | emb_out = None 229 | 230 | if self.conf.two_cond: 231 | # it's possible that the network is two_cond 232 | # but it doesn't get the second condition 233 | # in which case, we ignore the second condition 234 | # and treat as if the network has one condition 235 | if cond is None: 236 | cond_out = None 237 | else: 238 | cond_out = self.cond_emb_layers(cond).type(h.dtype) 239 | 240 | if cond_out is not None: 241 | while len(cond_out.shape) < len(h.shape): 242 | cond_out = cond_out[..., None] 243 | else: 244 | cond_out = None 245 | 246 | # this is the new refactored code 247 | h = apply_conditions( 248 | h=h, 249 | emb=emb_out, 250 | cond=cond_out, 251 | layers=self.out_layers, 252 | scale_bias=1, 253 | in_channels=self.conf.out_channels, 254 | up_down_layer=None, 255 | ) 256 | 257 | return self.skip_connection(x) + h 258 | 259 | 260 | def apply_conditions( 261 | h, 262 | emb=None, 263 | cond=None, 264 | layers: nn.Sequential = None, 265 | scale_bias: float = 1, 266 | in_channels: int = 512, 267 | up_down_layer: nn.Module = None, 268 | ): 269 | """ 270 | apply conditions on the feature maps 271 | 272 | Args: 273 | emb: time conditional (ready to scale + shift) 274 | cond: encoder's conditional (read to scale + shift) 275 | """ 276 | two_cond = emb is not None and cond is not None 277 | 278 | if emb is not None: 279 | # adjusting shapes 280 | while len(emb.shape) < len(h.shape): 281 | emb = emb[..., None] 282 | 283 | if two_cond: 284 | # adjusting shapes 285 | while len(cond.shape) < len(h.shape): 286 | cond = cond[..., None] 287 | # time first 288 | scale_shifts = [emb, cond] 289 | else: 290 | # "cond" is not used with single cond mode 291 | scale_shifts = [emb] 292 | 293 | # support scale, shift or shift only 294 | for i, each in enumerate(scale_shifts): 295 | if each is None: 296 | # special case: the condition is not provided 297 | a = None 298 | b = None 299 | else: 300 | if each.shape[1] == in_channels * 2: 301 | a, b = th.chunk(each, 2, dim=1) 302 | else: 303 | a = each 304 | b = None 305 | scale_shifts[i] = (a, b) 306 | 307 | # condition scale bias could be a list 308 | if isinstance(scale_bias, Number): 309 | biases = [scale_bias] * len(scale_shifts) 310 | else: 311 | # a list 312 | biases = scale_bias 313 | 314 | # default, the scale & shift are applied after the group norm but BEFORE SiLU 315 | pre_layers, post_layers = layers[0], layers[1:] 316 | 317 | # spilt 
the post layer to be able to scale up or down before conv 318 | # post layers will contain only the conv 319 | mid_layers, post_layers = post_layers[:-2], post_layers[-2:] 320 | 321 | h = pre_layers(h) 322 | # scale and shift for each condition 323 | for i, (scale, shift) in enumerate(scale_shifts): 324 | # if scale is None, it indicates that the condition is not provided 325 | if scale is not None: 326 | h = h * (biases[i] + scale) 327 | if shift is not None: 328 | h = h + shift 329 | h = mid_layers(h) 330 | 331 | # upscale or downscale if any just before the last conv 332 | if up_down_layer is not None: 333 | h = up_down_layer(h) 334 | h = post_layers(h) 335 | return h 336 | 337 | 338 | class Upsample(nn.Module): 339 | """ 340 | An upsampling layer with an optional convolution. 341 | 342 | :param channels: channels in the inputs and outputs. 343 | :param use_conv: a bool determining if a convolution is applied. 344 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 345 | upsampling occurs in the inner-two dimensions. 346 | """ 347 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 348 | super().__init__() 349 | self.channels = channels 350 | self.out_channels = out_channels or channels 351 | self.use_conv = use_conv 352 | self.dims = dims 353 | if use_conv: 354 | self.conv = conv_nd(dims, 355 | self.channels, 356 | self.out_channels, 357 | 3, 358 | padding=1) 359 | 360 | def forward(self, x): 361 | assert x.shape[1] == self.channels 362 | if self.dims == 3: 363 | x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), 364 | mode="nearest") 365 | else: 366 | x = F.interpolate(x, scale_factor=2, mode="nearest") 367 | if self.use_conv: 368 | x = self.conv(x) 369 | return x 370 | 371 | 372 | class Downsample(nn.Module): 373 | """ 374 | A downsampling layer with an optional convolution. 375 | 376 | :param channels: channels in the inputs and outputs. 377 | :param use_conv: a bool determining if a convolution is applied. 378 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 379 | downsampling occurs in the inner-two dimensions. 380 | """ 381 | def __init__(self, channels, use_conv, dims=2, out_channels=None): 382 | super().__init__() 383 | self.channels = channels 384 | self.out_channels = out_channels or channels 385 | self.use_conv = use_conv 386 | self.dims = dims 387 | stride = 2 if dims != 3 else (1, 2, 2) 388 | if use_conv: 389 | self.op = conv_nd(dims, 390 | self.channels, 391 | self.out_channels, 392 | 3, 393 | stride=stride, 394 | padding=1) 395 | else: 396 | assert self.channels == self.out_channels 397 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 398 | 399 | def forward(self, x): 400 | assert x.shape[1] == self.channels 401 | return self.op(x) 402 | 403 | 404 | class AttentionBlock(nn.Module): 405 | """ 406 | An attention block that allows spatial positions to attend to each other. 407 | 408 | Originally ported from here, but adapted to the N-d case. 409 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. 
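Editor's note: a minimal standalone sketch of the scale/shift conditioning performed by `apply_conditions()` above; the time embedding is projected to 2*C channels, split into (scale, shift), and applied after group norm with a bias of 1, i.e. `h <- h * (1 + scale) + shift`. Shapes are illustrative.

```python
import torch
from torch import nn

C = 64
h = torch.randn(2, C, 16, 16)                     # feature map after in_layers
emb = torch.randn(2, 2 * C)                       # output of emb_layers (SiLU + linear)

scale, shift = torch.chunk(emb[..., None, None], 2, dim=1)
h = nn.GroupNorm(32, C)(h)                        # stand-in for normalization()
h = h * (1 + scale) + shift                       # scale_bias = 1 as in ResBlock
print(h.shape)                                    # torch.Size([2, 64, 16, 16])
```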
410 | """ 411 | def __init__( 412 | self, 413 | channels, 414 | num_heads=1, 415 | num_head_channels=-1, 416 | use_checkpoint=False, 417 | use_new_attention_order=False, 418 | ): 419 | super().__init__() 420 | self.channels = channels 421 | if num_head_channels == -1: 422 | self.num_heads = num_heads 423 | else: 424 | assert ( 425 | channels % num_head_channels == 0 426 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" 427 | self.num_heads = channels // num_head_channels 428 | self.use_checkpoint = use_checkpoint 429 | self.norm = normalization(channels) 430 | self.qkv = conv_nd(1, channels, channels * 3, 1) 431 | if use_new_attention_order: 432 | # split qkv before split heads 433 | self.attention = QKVAttention(self.num_heads) 434 | else: 435 | # split heads before split qkv 436 | self.attention = QKVAttentionLegacy(self.num_heads) 437 | 438 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) 439 | 440 | def forward(self, x): 441 | return torch_checkpoint(self._forward, (x, ), self.use_checkpoint) 442 | 443 | def _forward(self, x): 444 | b, c, *spatial = x.shape 445 | x = x.reshape(b, c, -1) 446 | qkv = self.qkv(self.norm(x)) 447 | h = self.attention(qkv) 448 | h = self.proj_out(h) 449 | return (x + h).reshape(b, c, *spatial) 450 | 451 | 452 | def count_flops_attn(model, _x, y): 453 | """ 454 | A counter for the `thop` package to count the operations in an 455 | attention operation. 456 | Meant to be used like: 457 | macs, params = thop.profile( 458 | model, 459 | inputs=(inputs, timestamps), 460 | custom_ops={QKVAttention: QKVAttention.count_flops}, 461 | ) 462 | """ 463 | b, c, *spatial = y[0].shape 464 | num_spatial = int(np.prod(spatial)) 465 | # We perform two matmuls with the same number of ops. 466 | # The first computes the weight matrix, the second computes 467 | # the combination of the value vectors. 468 | matmul_ops = 2 * b * (num_spatial**2) * c 469 | model.total_ops += th.DoubleTensor([matmul_ops]) 470 | 471 | 472 | class QKVAttentionLegacy(nn.Module): 473 | """ 474 | A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping 475 | """ 476 | def __init__(self, n_heads): 477 | super().__init__() 478 | self.n_heads = n_heads 479 | 480 | def forward(self, qkv): 481 | """ 482 | Apply QKV attention. 483 | 484 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs. 485 | :return: an [N x (H * C) x T] tensor after attention. 486 | """ 487 | bs, width, length = qkv.shape 488 | assert width % (3 * self.n_heads) == 0 489 | ch = width // (3 * self.n_heads) 490 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, 491 | dim=1) 492 | scale = 1 / math.sqrt(math.sqrt(ch)) 493 | weight = th.einsum( 494 | "bct,bcs->bts", q * scale, 495 | k * scale) # More stable with f16 than dividing afterwards 496 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 497 | a = th.einsum("bts,bcs->bct", weight, v) 498 | return a.reshape(bs, -1, length) 499 | 500 | @staticmethod 501 | def count_flops(model, _x, y): 502 | return count_flops_attn(model, _x, y) 503 | 504 | 505 | class QKVAttention(nn.Module): 506 | """ 507 | A module which performs QKV attention and splits in a different order. 508 | """ 509 | def __init__(self, n_heads): 510 | super().__init__() 511 | self.n_heads = n_heads 512 | 513 | def forward(self, qkv): 514 | """ 515 | Apply QKV attention. 516 | 517 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs. 
518 | :return: an [N x (H * C) x T] tensor after attention. 519 | """ 520 | bs, width, length = qkv.shape 521 | assert width % (3 * self.n_heads) == 0 522 | ch = width // (3 * self.n_heads) 523 | q, k, v = qkv.chunk(3, dim=1) 524 | scale = 1 / math.sqrt(math.sqrt(ch)) 525 | weight = th.einsum( 526 | "bct,bcs->bts", 527 | (q * scale).view(bs * self.n_heads, ch, length), 528 | (k * scale).view(bs * self.n_heads, ch, length), 529 | ) # More stable with f16 than dividing afterwards 530 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype) 531 | a = th.einsum("bts,bcs->bct", weight, 532 | v.reshape(bs * self.n_heads, ch, length)) 533 | return a.reshape(bs, -1, length) 534 | 535 | @staticmethod 536 | def count_flops(model, _x, y): 537 | return count_flops_attn(model, _x, y) 538 | 539 | 540 | class AttentionPool2d(nn.Module): 541 | """ 542 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py 543 | """ 544 | def __init__( 545 | self, 546 | spacial_dim: int, 547 | embed_dim: int, 548 | num_heads_channels: int, 549 | output_dim: int = None, 550 | ): 551 | super().__init__() 552 | self.positional_embedding = nn.Parameter( 553 | th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5) 554 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1) 555 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1) 556 | self.num_heads = embed_dim // num_heads_channels 557 | self.attention = QKVAttention(self.num_heads) 558 | 559 | def forward(self, x): 560 | b, c, *_spatial = x.shape 561 | x = x.reshape(b, c, -1) # NC(HW) 562 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1) 563 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1) 564 | x = self.qkv_proj(x) 565 | x = self.attention(x) 566 | x = self.c_proj(x) 567 | return x[:, :, 0] 568 | -------------------------------------------------------------------------------- /model/latentnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from enum import Enum 4 | from typing import NamedTuple, Tuple 5 | 6 | import torch 7 | from choices import * 8 | from config_base import BaseConfig 9 | from torch import nn 10 | from torch.nn import init 11 | 12 | from .blocks import * 13 | from .nn import timestep_embedding 14 | from .unet import * 15 | 16 | 17 | class LatentNetType(Enum): 18 | none = 'none' 19 | # injecting inputs into the hidden layers 20 | skip = 'skip' 21 | 22 | 23 | class LatentNetReturn(NamedTuple): 24 | pred: torch.Tensor = None 25 | 26 | 27 | @dataclass 28 | class MLPSkipNetConfig(BaseConfig): 29 | """ 30 | default MLP for the latent DPM in the paper! 31 | """ 32 | num_channels: int 33 | skip_layers: Tuple[int] 34 | num_hid_channels: int 35 | num_layers: int 36 | num_time_emb_channels: int = 64 37 | activation: Activation = Activation.silu 38 | use_norm: bool = True 39 | condition_bias: float = 1 40 | dropout: float = 0 41 | last_act: Activation = Activation.none 42 | num_time_layers: int = 2 43 | time_last_act: bool = False 44 | 45 | def make_model(self): 46 | return MLPSkipNet(self) 47 | 48 | 49 | class MLPSkipNet(nn.Module): 50 | """ 51 | concat x to hidden layers 52 | 53 | default MLP for the latent DPM in the paper! 
54 | """ 55 | def __init__(self, conf: MLPSkipNetConfig): 56 | super().__init__() 57 | self.conf = conf 58 | 59 | layers = [] 60 | for i in range(conf.num_time_layers): 61 | if i == 0: 62 | a = conf.num_time_emb_channels 63 | b = conf.num_channels 64 | else: 65 | a = conf.num_channels 66 | b = conf.num_channels 67 | layers.append(nn.Linear(a, b)) 68 | if i < conf.num_time_layers - 1 or conf.time_last_act: 69 | layers.append(conf.activation.get_act()) 70 | self.time_embed = nn.Sequential(*layers) 71 | 72 | self.layers = nn.ModuleList([]) 73 | for i in range(conf.num_layers): 74 | if i == 0: 75 | act = conf.activation 76 | norm = conf.use_norm 77 | cond = True 78 | a, b = conf.num_channels, conf.num_hid_channels 79 | dropout = conf.dropout 80 | elif i == conf.num_layers - 1: 81 | act = Activation.none 82 | norm = False 83 | cond = False 84 | a, b = conf.num_hid_channels, conf.num_channels 85 | dropout = 0 86 | else: 87 | act = conf.activation 88 | norm = conf.use_norm 89 | cond = True 90 | a, b = conf.num_hid_channels, conf.num_hid_channels 91 | dropout = conf.dropout 92 | 93 | if i in conf.skip_layers: 94 | a += conf.num_channels 95 | 96 | self.layers.append( 97 | MLPLNAct( 98 | a, 99 | b, 100 | norm=norm, 101 | activation=act, 102 | cond_channels=conf.num_channels, 103 | use_cond=cond, 104 | condition_bias=conf.condition_bias, 105 | dropout=dropout, 106 | )) 107 | self.last_act = conf.last_act.get_act() 108 | 109 | def forward(self, x, t, **kwargs): 110 | t = timestep_embedding(t, self.conf.num_time_emb_channels) 111 | cond = self.time_embed(t) 112 | h = x 113 | for i in range(len(self.layers)): 114 | if i in self.conf.skip_layers: 115 | # injecting input into the hidden layers 116 | h = torch.cat([h, x], dim=1) 117 | h = self.layers[i].forward(x=h, cond=cond) 118 | h = self.last_act(h) 119 | return LatentNetReturn(h) 120 | 121 | 122 | class MLPLNAct(nn.Module): 123 | def __init__( 124 | self, 125 | in_channels: int, 126 | out_channels: int, 127 | norm: bool, 128 | use_cond: bool, 129 | activation: Activation, 130 | cond_channels: int, 131 | condition_bias: float = 0, 132 | dropout: float = 0, 133 | ): 134 | super().__init__() 135 | self.activation = activation 136 | self.condition_bias = condition_bias 137 | self.use_cond = use_cond 138 | 139 | self.linear = nn.Linear(in_channels, out_channels) 140 | self.act = activation.get_act() 141 | if self.use_cond: 142 | self.linear_emb = nn.Linear(cond_channels, out_channels) 143 | self.cond_layers = nn.Sequential(self.act, self.linear_emb) 144 | if norm: 145 | self.norm = nn.LayerNorm(out_channels) 146 | else: 147 | self.norm = nn.Identity() 148 | 149 | if dropout > 0: 150 | self.dropout = nn.Dropout(p=dropout) 151 | else: 152 | self.dropout = nn.Identity() 153 | 154 | self.init_weights() 155 | 156 | def init_weights(self): 157 | for module in self.modules(): 158 | if isinstance(module, nn.Linear): 159 | if self.activation == Activation.relu: 160 | init.kaiming_normal_(module.weight, 161 | a=0, 162 | nonlinearity='relu') 163 | elif self.activation == Activation.lrelu: 164 | init.kaiming_normal_(module.weight, 165 | a=0.2, 166 | nonlinearity='leaky_relu') 167 | elif self.activation == Activation.silu: 168 | init.kaiming_normal_(module.weight, 169 | a=0, 170 | nonlinearity='relu') 171 | else: 172 | # leave it as default 173 | pass 174 | 175 | def forward(self, x, cond=None): 176 | x = self.linear(x) 177 | if self.use_cond: 178 | # (n, c) or (n, c * 2) 179 | cond = self.cond_layers(cond) 180 | cond = (cond, None) 181 | 182 | # scale shift first 183 
| x = x * (self.condition_bias + cond[0]) 184 | if cond[1] is not None: 185 | x = x + cond[1] 186 | # then norm 187 | x = self.norm(x) 188 | else: 189 | # no condition 190 | x = self.norm(x) 191 | x = self.act(x) 192 | x = self.dropout(x) 193 | return x -------------------------------------------------------------------------------- /model/nn.py: -------------------------------------------------------------------------------- 1 | """ 2 | Various utilities for neural networks. 3 | """ 4 | 5 | from enum import Enum 6 | import math 7 | from typing import Optional 8 | 9 | import torch as th 10 | import torch.nn as nn 11 | import torch.utils.checkpoint 12 | 13 | import torch.nn.functional as F 14 | 15 | 16 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 17 | class SiLU(nn.Module): 18 | # @th.jit.script 19 | def forward(self, x): 20 | return x * th.sigmoid(x) 21 | 22 | 23 | class GroupNorm32(nn.GroupNorm): 24 | def forward(self, x): 25 | return super().forward(x.float()).type(x.dtype) 26 | 27 | 28 | def conv_nd(dims, *args, **kwargs): 29 | """ 30 | Create a 1D, 2D, or 3D convolution module. 31 | """ 32 | if dims == 1: 33 | return nn.Conv1d(*args, **kwargs) 34 | elif dims == 2: 35 | return nn.Conv2d(*args, **kwargs) 36 | elif dims == 3: 37 | return nn.Conv3d(*args, **kwargs) 38 | raise ValueError(f"unsupported dimensions: {dims}") 39 | 40 | 41 | def linear(*args, **kwargs): 42 | """ 43 | Create a linear module. 44 | """ 45 | return nn.Linear(*args, **kwargs) 46 | 47 | 48 | def avg_pool_nd(dims, *args, **kwargs): 49 | """ 50 | Create a 1D, 2D, or 3D average pooling module. 51 | """ 52 | if dims == 1: 53 | return nn.AvgPool1d(*args, **kwargs) 54 | elif dims == 2: 55 | return nn.AvgPool2d(*args, **kwargs) 56 | elif dims == 3: 57 | return nn.AvgPool3d(*args, **kwargs) 58 | raise ValueError(f"unsupported dimensions: {dims}") 59 | 60 | 61 | def update_ema(target_params, source_params, rate=0.99): 62 | """ 63 | Update target parameters to be closer to those of source parameters using 64 | an exponential moving average. 65 | 66 | :param target_params: the target parameter sequence. 67 | :param source_params: the source parameter sequence. 68 | :param rate: the EMA rate (closer to 1 means slower). 69 | """ 70 | for targ, src in zip(target_params, source_params): 71 | targ.detach().mul_(rate).add_(src, alpha=1 - rate) 72 | 73 | 74 | def zero_module(module): 75 | """ 76 | Zero out the parameters of a module and return it. 77 | """ 78 | for p in module.parameters(): 79 | p.detach().zero_() 80 | return module 81 | 82 | 83 | def scale_module(module, scale): 84 | """ 85 | Scale the parameters of a module and return it. 86 | """ 87 | for p in module.parameters(): 88 | p.detach().mul_(scale) 89 | return module 90 | 91 | 92 | def mean_flat(tensor): 93 | """ 94 | Take the mean over all non-batch dimensions. 95 | """ 96 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 97 | 98 | 99 | def normalization(channels): 100 | """ 101 | Make a standard normalization layer. 102 | 103 | :param channels: number of input channels. 104 | :return: an nn.Module for normalization. 105 | """ 106 | return GroupNorm32(min(32, channels), channels) 107 | 108 | 109 | def timestep_embedding(timesteps, dim, max_period=10000): 110 | """ 111 | Create sinusoidal timestep embeddings. 112 | 113 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 114 | These may be fractional. 115 | :param dim: the dimension of the output. 116 | :param max_period: controls the minimum frequency of the embeddings. 
117 | :return: an [N x dim] Tensor of positional embeddings. 118 | """ 119 | half = dim // 2 120 | freqs = th.exp(-math.log(max_period) * 121 | th.arange(start=0, end=half, dtype=th.float32) / 122 | half).to(device=timesteps.device) 123 | args = timesteps[:, None].float() * freqs[None] 124 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1) 125 | if dim % 2: 126 | embedding = th.cat( 127 | [embedding, th.zeros_like(embedding[:, :1])], dim=-1) 128 | return embedding 129 | 130 | 131 | def torch_checkpoint(func, args, flag, preserve_rng_state=False): 132 | # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8 133 | if flag: 134 | return torch.utils.checkpoint.checkpoint( 135 | func, *args, preserve_rng_state=preserve_rng_state) 136 | else: 137 | return func(*args) 138 | -------------------------------------------------------------------------------- /model/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from numbers import Number 4 | from typing import NamedTuple, Tuple, Union 5 | 6 | import numpy as np 7 | import torch as th 8 | from torch import nn 9 | import torch.nn.functional as F 10 | from choices import * 11 | from config_base import BaseConfig 12 | from .blocks import * 13 | 14 | from .nn import (conv_nd, linear, normalization, timestep_embedding, 15 | torch_checkpoint, zero_module) 16 | 17 | 18 | @dataclass 19 | class BeatGANsUNetConfig(BaseConfig): 20 | image_size: int = 64 21 | in_channels: int = 3 22 | # base channels, will be multiplied 23 | model_channels: int = 64 24 | # output of the unet 25 | # suggest: 3 26 | # you only need 6 if you also model the variance of the noise prediction (usually we use an analytical variance hence 3) 27 | out_channels: int = 3 28 | # how many repeating resblocks per resolution 29 | # the decoding side would have "one more" resblock 30 | # default: 2 31 | num_res_blocks: int = 2 32 | # you can also set the number of resblocks specifically for the input blocks 33 | # default: None = above 34 | num_input_res_blocks: int = None 35 | # number of time embed channels and style channels 36 | embed_channels: int = 512 37 | # at what resolutions you want to do self-attention of the feature maps 38 | # attentions generally improve performance 39 | # default: [16] 40 | # beatgans: [32, 16, 8] 41 | attention_resolutions: Tuple[int] = (16, ) 42 | # number of time embed channels 43 | time_embed_channels: int = None 44 | # dropout applies to the resblocks (on feature maps) 45 | dropout: float = 0.1 46 | channel_mult: Tuple[int] = (1, 2, 4, 8) 47 | input_channel_mult: Tuple[int] = None 48 | conv_resample: bool = True 49 | # always 2 = 2d conv 50 | dims: int = 2 51 | # don't use this, legacy from BeatGANs 52 | num_classes: int = None 53 | use_checkpoint: bool = False 54 | # number of attention heads 55 | num_heads: int = 1 56 | # or specify the number of channels per attention head 57 | num_head_channels: int = -1 58 | # what's this? 
59 | num_heads_upsample: int = -1 60 | # use resblock for upscale/downscale blocks (expensive) 61 | # default: True (BeatGANs) 62 | resblock_updown: bool = True 63 | # never tried 64 | use_new_attention_order: bool = False 65 | resnet_two_cond: bool = False 66 | resnet_cond_channels: int = None 67 | # init the decoding conv layers with zero weights, this speeds up training 68 | # default: True (BeattGANs) 69 | resnet_use_zero_module: bool = True 70 | # gradient checkpoint the attention operation 71 | attn_checkpoint: bool = False 72 | 73 | def make_model(self): 74 | return BeatGANsUNetModel(self) 75 | 76 | 77 | class BeatGANsUNetModel(nn.Module): 78 | def __init__(self, conf: BeatGANsUNetConfig): 79 | super().__init__() 80 | self.conf = conf 81 | 82 | if conf.num_heads_upsample == -1: 83 | self.num_heads_upsample = conf.num_heads 84 | 85 | self.dtype = th.float32 86 | 87 | self.time_emb_channels = conf.time_embed_channels or conf.model_channels 88 | self.time_embed = nn.Sequential( 89 | linear(self.time_emb_channels, conf.embed_channels), 90 | nn.SiLU(), 91 | linear(conf.embed_channels, conf.embed_channels), 92 | ) 93 | 94 | if conf.num_classes is not None: 95 | self.label_emb = nn.Embedding(conf.num_classes, 96 | conf.embed_channels) 97 | 98 | ch = input_ch = int(conf.channel_mult[0] * conf.model_channels) 99 | self.input_blocks = nn.ModuleList([ 100 | TimestepEmbedSequential( 101 | conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)) 102 | ]) 103 | 104 | kwargs = dict( 105 | use_condition=True, 106 | two_cond=conf.resnet_two_cond, 107 | use_zero_module=conf.resnet_use_zero_module, 108 | # style channels for the resnet block 109 | cond_emb_channels=conf.resnet_cond_channels, 110 | ) 111 | 112 | self._feature_size = ch 113 | 114 | # input_block_chans = [ch] 115 | input_block_chans = [[] for _ in range(len(conf.channel_mult))] 116 | input_block_chans[0].append(ch) 117 | 118 | # number of blocks at each resolution 119 | self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))] 120 | self.input_num_blocks[0] = 1 121 | self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))] 122 | 123 | ds = 1 124 | resolution = conf.image_size 125 | for level, mult in enumerate(conf.input_channel_mult 126 | or conf.channel_mult): 127 | for _ in range(conf.num_input_res_blocks or conf.num_res_blocks): 128 | layers = [ 129 | ResBlockConfig( 130 | ch, 131 | conf.embed_channels, 132 | conf.dropout, 133 | out_channels=int(mult * conf.model_channels), 134 | dims=conf.dims, 135 | use_checkpoint=conf.use_checkpoint, 136 | **kwargs, 137 | ).make_model() 138 | ] 139 | ch = int(mult * conf.model_channels) 140 | if resolution in conf.attention_resolutions: 141 | layers.append( 142 | AttentionBlock( 143 | ch, 144 | use_checkpoint=conf.use_checkpoint 145 | or conf.attn_checkpoint, 146 | num_heads=conf.num_heads, 147 | num_head_channels=conf.num_head_channels, 148 | use_new_attention_order=conf. 
149 | use_new_attention_order, 150 | )) 151 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 152 | self._feature_size += ch 153 | # input_block_chans.append(ch) 154 | input_block_chans[level].append(ch) 155 | self.input_num_blocks[level] += 1 156 | # print(input_block_chans) 157 | if level != len(conf.channel_mult) - 1: 158 | resolution //= 2 159 | out_ch = ch 160 | self.input_blocks.append( 161 | TimestepEmbedSequential( 162 | ResBlockConfig( 163 | ch, 164 | conf.embed_channels, 165 | conf.dropout, 166 | out_channels=out_ch, 167 | dims=conf.dims, 168 | use_checkpoint=conf.use_checkpoint, 169 | down=True, 170 | **kwargs, 171 | ).make_model() if conf. 172 | resblock_updown else Downsample(ch, 173 | conf.conv_resample, 174 | dims=conf.dims, 175 | out_channels=out_ch))) 176 | ch = out_ch 177 | # input_block_chans.append(ch) 178 | input_block_chans[level + 1].append(ch) 179 | self.input_num_blocks[level + 1] += 1 180 | ds *= 2 181 | self._feature_size += ch 182 | 183 | self.middle_block = TimestepEmbedSequential( 184 | ResBlockConfig( 185 | ch, 186 | conf.embed_channels, 187 | conf.dropout, 188 | dims=conf.dims, 189 | use_checkpoint=conf.use_checkpoint, 190 | **kwargs, 191 | ).make_model(), 192 | AttentionBlock( 193 | ch, 194 | use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint, 195 | num_heads=conf.num_heads, 196 | num_head_channels=conf.num_head_channels, 197 | use_new_attention_order=conf.use_new_attention_order, 198 | ), 199 | ResBlockConfig( 200 | ch, 201 | conf.embed_channels, 202 | conf.dropout, 203 | dims=conf.dims, 204 | use_checkpoint=conf.use_checkpoint, 205 | **kwargs, 206 | ).make_model(), 207 | ) 208 | self._feature_size += ch 209 | 210 | self.output_blocks = nn.ModuleList([]) 211 | for level, mult in list(enumerate(conf.channel_mult))[::-1]: 212 | for i in range(conf.num_res_blocks + 1): 213 | # print(input_block_chans) 214 | # ich = input_block_chans.pop() 215 | try: 216 | ich = input_block_chans[level].pop() 217 | except IndexError: 218 | # this happens only when num_res_block > num_enc_res_block 219 | # we will not have enough lateral (skip) connecions for all decoder blocks 220 | ich = 0 221 | # print('pop:', ich) 222 | layers = [ 223 | ResBlockConfig( 224 | # only direct channels when gated 225 | channels=ch + ich, 226 | emb_channels=conf.embed_channels, 227 | dropout=conf.dropout, 228 | out_channels=int(conf.model_channels * mult), 229 | dims=conf.dims, 230 | use_checkpoint=conf.use_checkpoint, 231 | # lateral channels are described here when gated 232 | has_lateral=True if ich > 0 else False, 233 | lateral_channels=None, 234 | **kwargs, 235 | ).make_model() 236 | ] 237 | ch = int(conf.model_channels * mult) 238 | if resolution in conf.attention_resolutions: 239 | layers.append( 240 | AttentionBlock( 241 | ch, 242 | use_checkpoint=conf.use_checkpoint 243 | or conf.attn_checkpoint, 244 | num_heads=self.num_heads_upsample, 245 | num_head_channels=conf.num_head_channels, 246 | use_new_attention_order=conf. 
247 | use_new_attention_order, 248 | )) 249 | if level and i == conf.num_res_blocks: 250 | resolution *= 2 251 | out_ch = ch 252 | layers.append( 253 | ResBlockConfig( 254 | ch, 255 | conf.embed_channels, 256 | conf.dropout, 257 | out_channels=out_ch, 258 | dims=conf.dims, 259 | use_checkpoint=conf.use_checkpoint, 260 | up=True, 261 | **kwargs, 262 | ).make_model() if ( 263 | conf.resblock_updown 264 | ) else Upsample(ch, 265 | conf.conv_resample, 266 | dims=conf.dims, 267 | out_channels=out_ch)) 268 | ds //= 2 269 | self.output_blocks.append(TimestepEmbedSequential(*layers)) 270 | self.output_num_blocks[level] += 1 271 | self._feature_size += ch 272 | 273 | # print(input_block_chans) 274 | # print('inputs:', self.input_num_blocks) 275 | # print('outputs:', self.output_num_blocks) 276 | 277 | if conf.resnet_use_zero_module: 278 | self.out = nn.Sequential( 279 | normalization(ch), 280 | nn.SiLU(), 281 | zero_module( 282 | conv_nd(conf.dims, 283 | input_ch, 284 | conf.out_channels, 285 | 3, 286 | padding=1)), 287 | ) 288 | else: 289 | self.out = nn.Sequential( 290 | normalization(ch), 291 | nn.SiLU(), 292 | conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1), 293 | ) 294 | 295 | def forward(self, x, t, y=None, **kwargs): 296 | """ 297 | Apply the model to an input batch. 298 | 299 | :param x: an [N x C x ...] Tensor of inputs. 300 | :param timesteps: a 1-D batch of timesteps. 301 | :param y: an [N] Tensor of labels, if class-conditional. 302 | :return: an [N x C x ...] Tensor of outputs. 303 | """ 304 | assert (y is not None) == ( 305 | self.conf.num_classes is not None 306 | ), "must specify y if and only if the model is class-conditional" 307 | 308 | # hs = [] 309 | hs = [[] for _ in range(len(self.conf.channel_mult))] 310 | emb = self.time_embed(timestep_embedding(t, self.time_emb_channels)) 311 | 312 | if self.conf.num_classes is not None: 313 | raise NotImplementedError() 314 | # assert y.shape == (x.shape[0], ) 315 | # emb = emb + self.label_emb(y) 316 | 317 | # new code supports input_num_blocks != output_num_blocks 318 | h = x.type(self.dtype) 319 | k = 0 320 | for i in range(len(self.input_num_blocks)): 321 | for j in range(self.input_num_blocks[i]): 322 | h = self.input_blocks[k](h, emb=emb) 323 | # print(i, j, h.shape) 324 | hs[i].append(h) 325 | k += 1 326 | assert k == len(self.input_blocks) 327 | 328 | h = self.middle_block(h, emb=emb) 329 | k = 0 330 | for i in range(len(self.output_num_blocks)): 331 | for j in range(self.output_num_blocks[i]): 332 | # take the lateral connection from the same layer (in reserve) 333 | # until there is no more, use None 334 | try: 335 | lateral = hs[-i - 1].pop() 336 | # print(i, j, lateral.shape) 337 | except IndexError: 338 | lateral = None 339 | # print(i, j, lateral) 340 | h = self.output_blocks[k](h, emb=emb, lateral=lateral) 341 | k += 1 342 | 343 | h = h.type(x.dtype) 344 | pred = self.out(h) 345 | return Return(pred=pred) 346 | 347 | 348 | class Return(NamedTuple): 349 | pred: th.Tensor 350 | 351 | 352 | @dataclass 353 | class BeatGANsEncoderConfig(BaseConfig): 354 | image_size: int 355 | in_channels: int 356 | model_channels: int 357 | out_hid_channels: int 358 | out_channels: int 359 | num_res_blocks: int 360 | attention_resolutions: Tuple[int] 361 | dropout: float = 0 362 | channel_mult: Tuple[int] = (1, 2, 4, 8) 363 | use_time_condition: bool = True 364 | conv_resample: bool = True 365 | dims: int = 2 366 | use_checkpoint: bool = False 367 | num_heads: int = 1 368 | num_head_channels: int = -1 369 | resblock_updown: bool = 
False 370 | use_new_attention_order: bool = False 371 | pool: str = 'adaptivenonzero' 372 | 373 | def make_model(self): 374 | return BeatGANsEncoderModel(self) 375 | 376 | 377 | class BeatGANsEncoderModel(nn.Module): 378 | """ 379 | The half UNet model with attention and timestep embedding. 380 | 381 | For usage, see UNet. 382 | """ 383 | def __init__(self, conf: BeatGANsEncoderConfig): 384 | super().__init__() 385 | self.conf = conf 386 | self.dtype = th.float32 387 | 388 | if conf.use_time_condition: 389 | time_embed_dim = conf.model_channels * 4 390 | self.time_embed = nn.Sequential( 391 | linear(conf.model_channels, time_embed_dim), 392 | nn.SiLU(), 393 | linear(time_embed_dim, time_embed_dim), 394 | ) 395 | else: 396 | time_embed_dim = None 397 | 398 | ch = int(conf.channel_mult[0] * conf.model_channels) 399 | self.input_blocks = nn.ModuleList([ 400 | TimestepEmbedSequential( 401 | conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1)) 402 | ]) 403 | self._feature_size = ch 404 | input_block_chans = [ch] 405 | ds = 1 406 | resolution = conf.image_size 407 | for level, mult in enumerate(conf.channel_mult): 408 | for _ in range(conf.num_res_blocks): 409 | layers = [ 410 | ResBlockConfig( 411 | ch, 412 | time_embed_dim, 413 | conf.dropout, 414 | out_channels=int(mult * conf.model_channels), 415 | dims=conf.dims, 416 | use_condition=conf.use_time_condition, 417 | use_checkpoint=conf.use_checkpoint, 418 | ).make_model() 419 | ] 420 | ch = int(mult * conf.model_channels) 421 | if resolution in conf.attention_resolutions: 422 | layers.append( 423 | AttentionBlock( 424 | ch, 425 | use_checkpoint=conf.use_checkpoint, 426 | num_heads=conf.num_heads, 427 | num_head_channels=conf.num_head_channels, 428 | use_new_attention_order=conf. 429 | use_new_attention_order, 430 | )) 431 | self.input_blocks.append(TimestepEmbedSequential(*layers)) 432 | self._feature_size += ch 433 | input_block_chans.append(ch) 434 | if level != len(conf.channel_mult) - 1: 435 | resolution //= 2 436 | out_ch = ch 437 | self.input_blocks.append( 438 | TimestepEmbedSequential( 439 | ResBlockConfig( 440 | ch, 441 | time_embed_dim, 442 | conf.dropout, 443 | out_channels=out_ch, 444 | dims=conf.dims, 445 | use_condition=conf.use_time_condition, 446 | use_checkpoint=conf.use_checkpoint, 447 | down=True, 448 | ).make_model() if ( 449 | conf.resblock_updown 450 | ) else Downsample(ch, 451 | conf.conv_resample, 452 | dims=conf.dims, 453 | out_channels=out_ch))) 454 | ch = out_ch 455 | input_block_chans.append(ch) 456 | ds *= 2 457 | self._feature_size += ch 458 | 459 | self.middle_block = TimestepEmbedSequential( 460 | ResBlockConfig( 461 | ch, 462 | time_embed_dim, 463 | conf.dropout, 464 | dims=conf.dims, 465 | use_condition=conf.use_time_condition, 466 | use_checkpoint=conf.use_checkpoint, 467 | ).make_model(), 468 | AttentionBlock( 469 | ch, 470 | use_checkpoint=conf.use_checkpoint, 471 | num_heads=conf.num_heads, 472 | num_head_channels=conf.num_head_channels, 473 | use_new_attention_order=conf.use_new_attention_order, 474 | ), 475 | ResBlockConfig( 476 | ch, 477 | time_embed_dim, 478 | conf.dropout, 479 | dims=conf.dims, 480 | use_condition=conf.use_time_condition, 481 | use_checkpoint=conf.use_checkpoint, 482 | ).make_model(), 483 | ) 484 | self._feature_size += ch 485 | if conf.pool == "adaptivenonzero": 486 | self.out = nn.Sequential( 487 | normalization(ch), 488 | nn.SiLU(), 489 | nn.AdaptiveAvgPool2d((1, 1)), 490 | conv_nd(conf.dims, ch, conf.out_channels, 1), 491 | nn.Flatten(), 492 | ) 493 | else: 494 | raise 
NotImplementedError(f"Unexpected {conf.pool} pooling") 495 | 496 | def forward(self, x, t=None, return_2d_feature=False): 497 | """ 498 | Apply the model to an input batch. 499 | 500 | :param x: an [N x C x ...] Tensor of inputs. 501 | :param timesteps: a 1-D batch of timesteps. 502 | :return: an [N x K] Tensor of outputs. 503 | """ 504 | if self.conf.use_time_condition: 505 | emb = self.time_embed(timestep_embedding(t, self.model_channels)) 506 | else: 507 | emb = None 508 | 509 | results = [] 510 | h = x.type(self.dtype) 511 | for module in self.input_blocks: 512 | h = module(h, emb=emb) 513 | if self.conf.pool.startswith("spatial"): 514 | results.append(h.type(x.dtype).mean(dim=(2, 3))) 515 | h = self.middle_block(h, emb=emb) 516 | if self.conf.pool.startswith("spatial"): 517 | results.append(h.type(x.dtype).mean(dim=(2, 3))) 518 | h = th.cat(results, axis=-1) 519 | else: 520 | h = h.type(x.dtype) 521 | 522 | h_2d = h 523 | h = self.out(h) 524 | 525 | if return_2d_feature: 526 | return h, h_2d 527 | else: 528 | return h 529 | 530 | def forward_flatten(self, x): 531 | """ 532 | transform the last 2d feature into a flatten vector 533 | """ 534 | h = self.out(x) 535 | return h 536 | 537 | 538 | class SuperResModel(BeatGANsUNetModel): 539 | """ 540 | A UNetModel that performs super-resolution. 541 | 542 | Expects an extra kwarg `low_res` to condition on a low-resolution image. 543 | """ 544 | def __init__(self, image_size, in_channels, *args, **kwargs): 545 | super().__init__(image_size, in_channels * 2, *args, **kwargs) 546 | 547 | def forward(self, x, timesteps, low_res=None, **kwargs): 548 | _, _, new_height, new_width = x.shape 549 | upsampled = F.interpolate(low_res, (new_height, new_width), 550 | mode="bilinear") 551 | x = th.cat([x, upsampled], dim=1) 552 | return super().forward(x, timesteps, **kwargs) 553 | -------------------------------------------------------------------------------- /model/unet_autoenc.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn.functional import silu 6 | 7 | from .latentnet import * 8 | from .unet import * 9 | from choices import * 10 | 11 | 12 | @dataclass 13 | class BeatGANsAutoencConfig(BeatGANsUNetConfig): 14 | # number of style channels 15 | enc_out_channels: int = 512 16 | enc_attn_resolutions: Tuple[int] = None 17 | enc_pool: str = 'depthconv' 18 | enc_num_res_block: int = 2 19 | enc_channel_mult: Tuple[int] = None 20 | enc_grad_checkpoint: bool = False 21 | latent_net_conf: MLPSkipNetConfig = None 22 | 23 | def make_model(self): 24 | return BeatGANsAutoencModel(self) 25 | 26 | 27 | class BeatGANsAutoencModel(BeatGANsUNetModel): 28 | def __init__(self, conf: BeatGANsAutoencConfig): 29 | super().__init__(conf) 30 | self.conf = conf 31 | 32 | # having only time, cond 33 | self.time_embed = TimeStyleSeperateEmbed( 34 | time_channels=conf.model_channels, 35 | time_out_channels=conf.embed_channels, 36 | ) 37 | 38 | self.encoder = BeatGANsEncoderConfig( 39 | image_size=conf.image_size, 40 | in_channels=conf.in_channels, 41 | model_channels=conf.model_channels, 42 | out_hid_channels=conf.enc_out_channels, 43 | out_channels=conf.enc_out_channels, 44 | num_res_blocks=conf.enc_num_res_block, 45 | attention_resolutions=(conf.enc_attn_resolutions 46 | or conf.attention_resolutions), 47 | dropout=conf.dropout, 48 | channel_mult=conf.enc_channel_mult or conf.channel_mult, 49 | use_time_condition=False, 50 | 
conv_resample=conf.conv_resample, 51 | dims=conf.dims, 52 | use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint, 53 | num_heads=conf.num_heads, 54 | num_head_channels=conf.num_head_channels, 55 | resblock_updown=conf.resblock_updown, 56 | use_new_attention_order=conf.use_new_attention_order, 57 | pool=conf.enc_pool, 58 | ).make_model() 59 | 60 | if conf.latent_net_conf is not None: 61 | self.latent_net = conf.latent_net_conf.make_model() 62 | 63 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor: 64 | """ 65 | Reparameterization trick to sample from N(mu, var) using 66 | N(0,1). 67 | :param mu: (Tensor) Mean of the latent Gaussian [B x D] 68 | :param logvar: (Tensor) Log variance of the latent Gaussian [B x D] 69 | :return: (Tensor) [B x D] 70 | """ 71 | assert self.conf.is_stochastic 72 | std = torch.exp(0.5 * logvar) 73 | eps = torch.randn_like(std) 74 | return eps * std + mu 75 | 76 | def sample_z(self, n: int, device): 77 | assert self.conf.is_stochastic 78 | return torch.randn(n, self.conf.enc_out_channels, device=device) 79 | 80 | def noise_to_cond(self, noise: Tensor): 81 | raise NotImplementedError() 82 | assert self.conf.noise_net_conf is not None 83 | return self.noise_net.forward(noise) 84 | 85 | def encode(self, x): 86 | cond = self.encoder.forward(x) 87 | return {'cond': cond} 88 | 89 | @property 90 | def stylespace_sizes(self): 91 | modules = list(self.input_blocks.modules()) + list( 92 | self.middle_block.modules()) + list(self.output_blocks.modules()) 93 | sizes = [] 94 | for module in modules: 95 | if isinstance(module, ResBlock): 96 | linear = module.cond_emb_layers[-1] 97 | sizes.append(linear.weight.shape[0]) 98 | return sizes 99 | 100 | def encode_stylespace(self, x, return_vector: bool = True): 101 | """ 102 | encode to style space 103 | """ 104 | modules = list(self.input_blocks.modules()) + list( 105 | self.middle_block.modules()) + list(self.output_blocks.modules()) 106 | # (n, c) 107 | cond = self.encoder.forward(x) 108 | S = [] 109 | for module in modules: 110 | if isinstance(module, ResBlock): 111 | # (n, c') 112 | s = module.cond_emb_layers.forward(cond) 113 | S.append(s) 114 | 115 | if return_vector: 116 | # (n, sum_c) 117 | return torch.cat(S, dim=1) 118 | else: 119 | return S 120 | 121 | def forward(self, 122 | x, 123 | t, 124 | y=None, 125 | x_start=None, 126 | cond=None, 127 | style=None, 128 | noise=None, 129 | t_cond=None, 130 | **kwargs): 131 | """ 132 | Apply the model to an input batch. 
133 | 134 | Args: 135 | x_start: the original image to encode 136 | cond: output of the encoder 137 | noise: random noise (to predict the cond) 138 | """ 139 | 140 | if t_cond is None: 141 | t_cond = t 142 | 143 | if noise is not None: 144 | # if the noise is given, we predict the cond from noise 145 | cond = self.noise_to_cond(noise) 146 | 147 | if cond is None: 148 | if x is not None: 149 | assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}' 150 | 151 | tmp = self.encode(x_start) 152 | cond = tmp['cond'] 153 | 154 | if t is not None: 155 | _t_emb = timestep_embedding(t, self.conf.model_channels) 156 | _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels) 157 | else: 158 | # this happens when training only autoenc 159 | _t_emb = None 160 | _t_cond_emb = None 161 | 162 | if self.conf.resnet_two_cond: 163 | res = self.time_embed.forward( 164 | time_emb=_t_emb, 165 | cond=cond, 166 | time_cond_emb=_t_cond_emb, 167 | ) 168 | else: 169 | raise NotImplementedError() 170 | 171 | if self.conf.resnet_two_cond: 172 | # two cond: first = time emb, second = cond_emb 173 | emb = res.time_emb 174 | cond_emb = res.emb 175 | else: 176 | # one cond = combined of both time and cond 177 | emb = res.emb 178 | cond_emb = None 179 | 180 | # override the style if given 181 | style = style or res.style 182 | 183 | assert (y is not None) == ( 184 | self.conf.num_classes is not None 185 | ), "must specify y if and only if the model is class-conditional" 186 | 187 | if self.conf.num_classes is not None: 188 | raise NotImplementedError() 189 | # assert y.shape == (x.shape[0], ) 190 | # emb = emb + self.label_emb(y) 191 | 192 | # where in the model to supply time conditions 193 | enc_time_emb = emb 194 | mid_time_emb = emb 195 | dec_time_emb = emb 196 | # where in the model to supply style conditions 197 | enc_cond_emb = cond_emb 198 | mid_cond_emb = cond_emb 199 | dec_cond_emb = cond_emb 200 | 201 | # hs = [] 202 | hs = [[] for _ in range(len(self.conf.channel_mult))] 203 | 204 | if x is not None: 205 | h = x.type(self.dtype) 206 | 207 | # input blocks 208 | k = 0 209 | for i in range(len(self.input_num_blocks)): 210 | for j in range(self.input_num_blocks[i]): 211 | h = self.input_blocks[k](h, 212 | emb=enc_time_emb, 213 | cond=enc_cond_emb) 214 | 215 | # print(i, j, h.shape) 216 | hs[i].append(h) 217 | k += 1 218 | assert k == len(self.input_blocks) 219 | 220 | # middle blocks 221 | h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb) 222 | else: 223 | # no lateral connections 224 | # happens when training only the autoencoder 225 | h = None 226 | hs = [[] for _ in range(len(self.conf.channel_mult))] 227 | 228 | # output blocks 229 | k = 0 230 | for i in range(len(self.output_num_blocks)): 231 | for j in range(self.output_num_blocks[i]): 232 | # take the lateral connection from the same layer (in reverse) 233 | # until there is no more, use None 234 | try: 235 | lateral = hs[-i - 1].pop() 236 | # print(i, j, lateral.shape) 237 | except IndexError: 238 | lateral = None 239 | # print(i, j, lateral) 240 | 241 | h = self.output_blocks[k](h, 242 | emb=dec_time_emb, 243 | cond=dec_cond_emb, 244 | lateral=lateral) 245 | k += 1 246 | 247 | pred = self.out(h) 248 | return AutoencReturn(pred=pred, cond=cond) 249 | 250 | 251 | class AutoencReturn(NamedTuple): 252 | pred: Tensor 253 | cond: Tensor = None 254 | 255 | 256 | class EmbedReturn(NamedTuple): 257 | # style and time 258 | emb: Tensor = None 259 | # time only 260 | time_emb: Tensor = None 261 | # style only (but could depend on time) 262 
| style: Tensor = None 263 | 264 | 265 | class TimeStyleSeperateEmbed(nn.Module): 266 | # embed only style 267 | def __init__(self, time_channels, time_out_channels): 268 | super().__init__() 269 | self.time_embed = nn.Sequential( 270 | linear(time_channels, time_out_channels), 271 | nn.SiLU(), 272 | linear(time_out_channels, time_out_channels), 273 | ) 274 | self.style = nn.Identity() 275 | 276 | def forward(self, time_emb=None, cond=None, **kwargs): 277 | if time_emb is None: 278 | # happens with autoenc training mode 279 | time_emb = None 280 | else: 281 | time_emb = self.time_embed(time_emb) 282 | style = self.style(cond) 283 | return EmbedReturn(emb=style, time_emb=time_emb, style=style) 284 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | # pre-download the weights for 256 resolution model to checkpoints/ffhq256_autoenc and checkpoints/ffhq256_autoenc_cls 2 | # wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2 3 | # bunzip2 shape_predictor_68_face_landmarks.dat.bz2 4 | 5 | import os 6 | import torch 7 | from torchvision.utils import save_image 8 | import tempfile 9 | from templates import * 10 | from templates_cls import * 11 | from experiment_classifier import ClsModel 12 | from align import LandmarksDetector, image_align 13 | from cog import BasePredictor, Path, Input, BaseModel 14 | 15 | 16 | class ModelOutput(BaseModel): 17 | image: Path 18 | 19 | 20 | class Predictor(BasePredictor): 21 | def setup(self): 22 | self.aligned_dir = "aligned" 23 | os.makedirs(self.aligned_dir, exist_ok=True) 24 | self.device = "cuda:0" 25 | 26 | # Model Initialization 27 | model_config = ffhq256_autoenc() 28 | self.model = LitModel(model_config) 29 | state = torch.load("checkpoints/ffhq256_autoenc/last.ckpt", map_location="cpu") 30 | self.model.load_state_dict(state["state_dict"], strict=False) 31 | self.model.ema_model.eval() 32 | self.model.ema_model.to(self.device) 33 | 34 | # Classifier Initialization 35 | classifier_config = ffhq256_autoenc_cls() 36 | classifier_config.pretrain = None # a bit faster 37 | self.classifier = ClsModel(classifier_config) 38 | state_class = torch.load( 39 | "checkpoints/ffhq256_autoenc_cls/last.ckpt", map_location="cpu" 40 | ) 41 | print("latent step:", state_class["global_step"]) 42 | self.classifier.load_state_dict(state_class["state_dict"], strict=False) 43 | self.classifier.to(self.device) 44 | 45 | self.landmarks_detector = LandmarksDetector( 46 | "shape_predictor_68_face_landmarks.dat" 47 | ) 48 | 49 | def predict( 50 | self, 51 | image: Path = Input( 52 | description="Input image for face manipulation. 
The image will be aligned and cropped; " 53 | "the outputs are the aligned and the manipulated images.", 54 | ), 55 | target_class: str = Input( 56 | default="Bangs", 57 | choices=[ 58 | "5_o_Clock_Shadow", 59 | "Arched_Eyebrows", 60 | "Attractive", 61 | "Bags_Under_Eyes", 62 | "Bald", 63 | "Bangs", 64 | "Big_Lips", 65 | "Big_Nose", 66 | "Black_Hair", 67 | "Blond_Hair", 68 | "Blurry", 69 | "Brown_Hair", 70 | "Bushy_Eyebrows", 71 | "Chubby", 72 | "Double_Chin", 73 | "Eyeglasses", 74 | "Goatee", 75 | "Gray_Hair", 76 | "Heavy_Makeup", 77 | "High_Cheekbones", 78 | "Male", 79 | "Mouth_Slightly_Open", 80 | "Mustache", 81 | "Narrow_Eyes", 82 | "Beard", 83 | "Oval_Face", 84 | "Pale_Skin", 85 | "Pointy_Nose", 86 | "Receding_Hairline", 87 | "Rosy_Cheeks", 88 | "Sideburns", 89 | "Smiling", 90 | "Straight_Hair", 91 | "Wavy_Hair", 92 | "Wearing_Earrings", 93 | "Wearing_Hat", 94 | "Wearing_Lipstick", 95 | "Wearing_Necklace", 96 | "Wearing_Necktie", 97 | "Young", 98 | ], 99 | description="Choose manipulation direction.", 100 | ), 101 | manipulation_amplitude: float = Input( 102 | default=0.3, 103 | ge=-0.5, 104 | le=0.5, 105 | description="If set too high, it can produce artifacts because the manipulation starts to dominate the original image information.", 106 | ), 107 | T_step: int = Input( 108 | default=100, 109 | choices=[50, 100, 125, 200, 250, 500], 110 | description="Number of steps for generation.", 111 | ), 112 | T_inv: int = Input(default=200, choices=[50, 100, 125, 200, 250, 500]), 113 | ) -> List[ModelOutput]: 114 | 115 | img_size = 256 116 | print("Aligning image...") 117 | for i, face_landmarks in enumerate( 118 | self.landmarks_detector.get_landmarks(str(image)), start=1 119 | ): 120 | image_align(str(image), f"{self.aligned_dir}/aligned.png", face_landmarks) 121 | 122 | data = ImageDataset( 123 | self.aligned_dir, 124 | image_size=img_size, 125 | exts=["jpg", "jpeg", "JPG", "png"], 126 | do_augment=False, 127 | ) 128 | 129 | print("Encoding and Manipulating the aligned image...") 130 | cls_manipulation_amplitude = manipulation_amplitude 131 | interpreted_target_class = target_class 132 | if ( 133 | target_class not in CelebAttrDataset.id_to_cls 134 | and f"No_{target_class}" in CelebAttrDataset.id_to_cls 135 | ): 136 | cls_manipulation_amplitude = -manipulation_amplitude 137 | interpreted_target_class = f"No_{target_class}" 138 | 139 | batch = data[0]["img"][None] 140 | 141 | semantic_latent = self.model.encode(batch.to(self.device)) 142 | stochastic_latent = self.model.encode_stochastic( 143 | batch.to(self.device), semantic_latent, T=T_inv 144 | ) 145 | 146 | cls_id = CelebAttrDataset.cls_to_id[interpreted_target_class] 147 | class_direction = self.classifier.classifier.weight[cls_id] 148 | normalized_class_direction = F.normalize(class_direction[None, :], dim=1) 149 | 150 | normalized_semantic_latent = self.classifier.normalize(semantic_latent) 151 | normalized_manipulation_amp = cls_manipulation_amplitude * math.sqrt(512) 152 | normalized_manipulated_semantic_latent = ( 153 | normalized_semantic_latent 154 | + normalized_manipulation_amp * normalized_class_direction 155 | ) 156 | 157 | manipulated_semantic_latent = self.classifier.denormalize( 158 | normalized_manipulated_semantic_latent 159 | ) 160 | 161 | # Render Manipulated image 162 | manipulated_img = self.model.render( 163 | stochastic_latent, manipulated_semantic_latent, T=T_step 164 | )[0] 165 | original_img = data[0]["img"] 166 | 167 | model_output = [] 168 | out_path = Path(tempfile.mkdtemp()) / "original_aligned.png" 169 | save_image(convert2rgb(original_img), 
str(out_path)) 170 | model_output.append(ModelOutput(image=out_path)) 171 | 172 | out_path = Path(tempfile.mkdtemp()) / "manipulated_img.png" 173 | save_image(convert2rgb(manipulated_img, adjust_scale=False), str(out_path)) 174 | model_output.append(ModelOutput(image=out_path)) 175 | return model_output 176 | 177 | 178 | def convert2rgb(img, adjust_scale=True): 179 | convert_img = torch.tensor(img) 180 | if adjust_scale: 181 | convert_img = (convert_img + 1) / 2 182 | return convert_img.cpu() 183 | -------------------------------------------------------------------------------- /renderer.py: -------------------------------------------------------------------------------- 1 | from config import * 2 | 3 | from torch.cuda import amp 4 | 5 | 6 | def render_uncondition(conf: TrainConfig, 7 | model: BeatGANsAutoencModel, 8 | x_T, 9 | sampler: Sampler, 10 | latent_sampler: Sampler, 11 | conds_mean=None, 12 | conds_std=None, 13 | clip_latent_noise: bool = False): 14 | device = x_T.device 15 | if conf.train_mode == TrainMode.diffusion: 16 | assert conf.model_type.can_sample() 17 | return sampler.sample(model=model, noise=x_T) 18 | elif conf.train_mode.is_latent_diffusion(): 19 | model: BeatGANsAutoencModel 20 | if conf.train_mode == TrainMode.latent_diffusion: 21 | latent_noise = torch.randn(len(x_T), conf.style_ch, device=device) 22 | else: 23 | raise NotImplementedError() 24 | 25 | if clip_latent_noise: 26 | latent_noise = latent_noise.clip(-1, 1) 27 | 28 | cond = latent_sampler.sample( 29 | model=model.latent_net, 30 | noise=latent_noise, 31 | clip_denoised=conf.latent_clip_sample, 32 | ) 33 | 34 | if conf.latent_znormalize: 35 | cond = cond * conds_std.to(device) + conds_mean.to(device) 36 | 37 | # the diffusion on the model 38 | return sampler.sample(model=model, noise=x_T, cond=cond) 39 | else: 40 | raise NotImplementedError() 41 | 42 | 43 | def render_condition( 44 | conf: TrainConfig, 45 | model: BeatGANsAutoencModel, 46 | x_T, 47 | sampler: Sampler, 48 | x_start=None, 49 | cond=None, 50 | ): 51 | if conf.train_mode == TrainMode.diffusion: 52 | assert conf.model_type.has_autoenc() 53 | # returns {'cond', 'cond2'} 54 | if cond is None: 55 | cond = model.encode(x_start) 56 | return sampler.sample(model=model, 57 | noise=x_T, 58 | model_kwargs={'cond': cond}) 59 | else: 60 | raise NotImplementedError() 61 | -------------------------------------------------------------------------------- /requirement_for_colab.txt: -------------------------------------------------------------------------------- 1 | torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 pytorch-lightning==1.2.2 torchtext==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html 2 | scipy==1.5.4 3 | numpy==1.19.5 4 | tqdm 5 | pytorch-fid==0.2.0 6 | pandas==1.1.5 7 | lpips==0.1.4 8 | lmdb==1.2.1 9 | ftfy 10 | regex 11 | dlib requests -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==1.4.5 2 | torchmetrics==0.5.0 3 | torch==1.8.1 4 | torchvision 5 | scipy==1.5.4 6 | numpy==1.19.5 7 | tqdm 8 | pytorch-fid==0.2.0 9 | pandas==1.1.5 10 | lpips==0.1.4 11 | lmdb==1.2.1 12 | ftfy 13 | regex -------------------------------------------------------------------------------- /run_bedroom128.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from templates_latent import * 3 | 4 | if __name__ == '__main__': 5 | # train 
the autoenc model 6 | # this requires V100s. 7 | gpus = [0, 1, 2, 3] 8 | conf = bedroom128_autoenc() 9 | train(conf, gpus=gpus) 10 | 11 | # infer the latents for training the latent DPM 12 | # NOTE: not gpu heavy, but more gpus can be of use! 13 | gpus = [0, 1, 2, 3] 14 | conf.eval_programs = ['infer'] 15 | train(conf, gpus=gpus, mode='eval') 16 | 17 | # train the latent DPM 18 | # NOTE: only need a single gpu 19 | gpus = [0] 20 | conf = bedroom128_autoenc_latent() 21 | train(conf, gpus=gpus) 22 | 23 | # unconditional sampling score 24 | # NOTE: a lot of gpus can speed up this process 25 | gpus = [0, 1, 2, 3] 26 | conf.eval_programs = ['fid(10,10)'] 27 | train(conf, gpus=gpus, mode='eval') -------------------------------------------------------------------------------- /run_bedroom128_ddim.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from templates_latent import * 3 | 4 | if __name__ == '__main__': 5 | gpus = [0, 1, 2, 3] 6 | conf = bedroom128_ddpm() 7 | train(conf, gpus=gpus) 8 | 9 | gpus = [0, 1, 2, 3] 10 | conf.eval_programs = ['fid10'] 11 | train(conf, gpus=gpus, mode='eval') -------------------------------------------------------------------------------- /run_celeba64.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from templates_latent import * 3 | 4 | if __name__ == '__main__': 5 | # train the autoenc model 6 | # this can be run on 2080Ti's. 7 | gpus = [0, 1, 2, 3] 8 | conf = celeba64d2c_autoenc() 9 | train(conf, gpus=gpus) 10 | 11 | # infer the latents for training the latent DPM 12 | # NOTE: not gpu heavy, but more gpus can be of use! 13 | gpus = [0, 1, 2, 3] 14 | conf.eval_programs = ['infer'] 15 | train(conf, gpus=gpus, mode='eval') 16 | 17 | # train the latent DPM 18 | # NOTE: only need a single gpu 19 | gpus = [0] 20 | conf = celeba64d2c_autoenc_latent() 21 | train(conf, gpus=gpus) 22 | 23 | # unconditional sampling score 24 | # NOTE: a lot of gpus can speed up this process 25 | gpus = [0, 1, 2, 3] 26 | conf.eval_programs = ['fid(10,10)'] 27 | train(conf, gpus=gpus, mode='eval') -------------------------------------------------------------------------------- /run_ffhq128.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from templates_latent import * 3 | 4 | if __name__ == '__main__': 5 | # train the autoenc model 6 | # this requires V100s. 7 | gpus = [0, 1, 2, 3] 8 | conf = ffhq128_autoenc_130M() 9 | train(conf, gpus=gpus) 10 | 11 | # infer the latents for training the latent DPM 12 | # NOTE: not gpu heavy, but more gpus can be of use! 
13 | gpus = [0, 1, 2, 3] 14 | conf.eval_programs = ['infer'] 15 | train(conf, gpus=gpus, mode='eval') 16 | 17 | # train the latent DPM 18 | # NOTE: only need a single gpu 19 | gpus = [0] 20 | conf = ffhq128_autoenc_latent() 21 | train(conf, gpus=gpus) 22 | 23 | # unconditional sampling score 24 | # NOTE: a lot of gpus can speed up this process 25 | gpus = [0, 1, 2, 3] 26 | conf.eval_programs = ['fid(10,10)'] 27 | train(conf, gpus=gpus, mode='eval') -------------------------------------------------------------------------------- /run_ffhq128_cls.py: -------------------------------------------------------------------------------- 1 | from templates_cls import * 2 | from experiment_classifier import * 3 | 4 | if __name__ == '__main__': 5 | # need to first train the diffae autoencoding model & infer the latents 6 | # this requires only a single GPU. 7 | gpus = [0] 8 | conf = ffhq128_autoenc_cls() 9 | train_cls(conf, gpus=gpus) 10 | 11 | # after this you can do the manipulation! 12 | -------------------------------------------------------------------------------- /run_ffhq128_ddim.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from templates_latent import * 3 | 4 | if __name__ == '__main__': 5 | gpus = [0, 1, 2, 3] 6 | conf = ffhq128_ddpm_130M() 7 | train(conf, gpus=gpus) 8 | 9 | gpus = [0, 1, 2, 3] 10 | conf.eval_programs = ['fid10'] 11 | train(conf, gpus=gpus, mode='eval') -------------------------------------------------------------------------------- /run_ffhq256.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from templates_latent import * 3 | 4 | if __name__ == '__main__': 5 | # 256 requires 8x v100s, in our case, on two nodes. 6 | # do not run this directly, use `sbatch run_ffhq256.sh` to spawn the srun properly. 7 | gpus = [0, 1, 2, 3] 8 | nodes = 2 9 | conf = ffhq256_autoenc() 10 | train(conf, gpus=gpus, nodes=nodes) -------------------------------------------------------------------------------- /run_ffhq256.sh: -------------------------------------------------------------------------------- 1 | #!/bin/sh 2 | #SBATCH --gres=gpu:4 3 | #SBATCH --cpus-per-gpu=8 4 | #SBATCH --mem-per-gpu=32GB 5 | #SBATCH --nodes=2 6 | #SBATCH --ntasks=8 7 | #SBATCH --partition=gpu-cluster 8 | #SBATCH --time=72:00:00 9 | 10 | export NCCL_DEBUG=INFO 11 | export PYTHONFAULTHANDLER=1 12 | 13 | srun python run_ffhq256.py -------------------------------------------------------------------------------- /run_ffhq256_cls.py: -------------------------------------------------------------------------------- 1 | from templates_cls import * 2 | from experiment_classifier import * 3 | 4 | if __name__ == '__main__': 5 | # need to first train the diffae autoencoding model & infer the latents 6 | # this requires only a single GPU. 7 | gpus = [0] 8 | conf = ffhq256_autoenc_cls() 9 | train_cls(conf, gpus=gpus) 10 | 11 | # after this you can do the manipulation! 12 | -------------------------------------------------------------------------------- /run_ffhq256_latent.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from templates_latent import * 3 | 4 | if __name__ == '__main__': 5 | # do run the run_ffhq256 before using the file to train the latent DPM 6 | 7 | # infer the latents for training the latent DPM 8 | # NOTE: not gpu heavy, but more gpus can be of use! 
9 | gpus = [0, 1, 2, 3] 10 | conf = ffhq256_autoenc() 11 | conf.eval_programs = ['infer'] 12 | train(conf, gpus=gpus, mode='eval') 13 | 14 | # train the latent DPM 15 | # NOTE: only need a single gpu 16 | gpus = [0] 17 | conf = ffhq256_autoenc_latent() 18 | train(conf, gpus=gpus) 19 | -------------------------------------------------------------------------------- /run_horse128.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from templates_latent import * 3 | 4 | if __name__ == '__main__': 5 | # train the autoenc moodel 6 | # this requires V100s. 7 | gpus = [0, 1, 2, 3] 8 | conf = horse128_autoenc() 9 | train(conf, gpus=gpus) 10 | 11 | # infer the latents for training the latent DPM 12 | # NOTE: not gpu heavy, but more gpus can be of use! 13 | gpus = [0, 1, 2, 3] 14 | conf.eval_programs = ['infer'] 15 | train(conf, gpus=gpus, mode='eval') 16 | 17 | # train the latent DPM 18 | # NOTE: only need a single gpu 19 | gpus = [0] 20 | conf = horse128_autoenc_latent() 21 | train(conf, gpus=gpus) 22 | 23 | # unconditional sampling score 24 | # NOTE: a lot of gpus can speed up this process 25 | gpus = [0, 1, 2, 3] 26 | conf.eval_programs = ['fid(10,10)'] 27 | train(conf, gpus=gpus, mode='eval') -------------------------------------------------------------------------------- /run_horse128_ddim.py: -------------------------------------------------------------------------------- 1 | from templates import * 2 | from templates_latent import * 3 | 4 | if __name__ == '__main__': 5 | gpus = [0, 1, 2, 3] 6 | conf = horse128_ddpm() 7 | train(conf, gpus=gpus) 8 | 9 | gpus = [0, 1, 2, 3] 10 | conf.eval_programs = ['fid10'] 11 | train(conf, gpus=gpus, mode='eval') -------------------------------------------------------------------------------- /ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch.autograd import Variable 4 | import numpy as np 5 | from math import exp 6 | 7 | 8 | def gaussian(window_size, sigma): 9 | gauss = torch.Tensor([ 10 | exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) 11 | for x in range(window_size) 12 | ]) 13 | return gauss / gauss.sum() 14 | 15 | 16 | def create_window(window_size, channel): 17 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 18 | _2D_window = _1D_window.mm( 19 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0) 20 | window = Variable( 21 | _2D_window.expand(channel, 1, window_size, window_size).contiguous()) 22 | return window 23 | 24 | 25 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 26 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 27 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 28 | 29 | mu1_sq = mu1.pow(2) 30 | mu2_sq = mu2.pow(2) 31 | mu1_mu2 = mu1 * mu2 32 | 33 | sigma1_sq = F.conv2d( 34 | img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 35 | sigma2_sq = F.conv2d( 36 | img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 37 | sigma12 = F.conv2d( 38 | img1 * img2, window, padding=window_size // 2, 39 | groups=channel) - mu1_mu2 40 | 41 | C1 = 0.01**2 42 | C2 = 0.03**2 43 | 44 | ssim_map = ((2 * mu1_mu2 + C1) * 45 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 46 | (sigma1_sq + sigma2_sq + C2)) 47 | 48 | if size_average: 49 | return ssim_map.mean() 50 | else: 51 | return ssim_map.mean(1).mean(1).mean(1) 52 | 53 | 54 | class SSIM(torch.nn.Module): 55 | def 
55 |     def __init__(self, window_size=11, size_average=True):
56 |         super(SSIM, self).__init__()
57 |         self.window_size = window_size
58 |         self.size_average = size_average
59 |         self.channel = 1
60 |         self.window = create_window(window_size, self.channel)
61 | 
62 |     def forward(self, img1, img2):
63 |         (_, channel, _, _) = img1.size()
64 | 
65 |         if channel == self.channel and self.window.data.type(
66 |         ) == img1.data.type():
67 |             window = self.window
68 |         else:
69 |             window = create_window(self.window_size, channel)
70 | 
71 |             if img1.is_cuda:
72 |                 window = window.cuda(img1.get_device())
73 |             window = window.type_as(img1)
74 | 
75 |             self.window = window
76 |             self.channel = channel
77 | 
78 |         return _ssim(img1, img2, window, self.window_size, channel,
79 |                      self.size_average)
80 | 
81 | 
82 | def ssim(img1, img2, window_size=11, size_average=True):
83 |     (_, channel, _, _) = img1.size()
84 |     window = create_window(window_size, channel)
85 | 
86 |     if img1.is_cuda:
87 |         window = window.cuda(img1.get_device())
88 |     window = window.type_as(img1)
89 | 
90 |     return _ssim(img1, img2, window, window_size, channel, size_average)
--------------------------------------------------------------------------------
/templates.py:
--------------------------------------------------------------------------------
1 | from experiment import *
2 | 
3 | 
4 | def ddpm():
5 |     """
6 |     base configuration for all DDIM-based models.
7 |     """
8 |     conf = TrainConfig()
9 |     conf.batch_size = 32
10 |     conf.beatgans_gen_type = GenerativeType.ddim
11 |     conf.beta_scheduler = 'linear'
12 |     conf.data_name = 'ffhq'
13 |     conf.diffusion_type = 'beatgans'
14 |     conf.eval_ema_every_samples = 200_000
15 |     conf.eval_every_samples = 200_000
16 |     conf.fp16 = True
17 |     conf.lr = 1e-4
18 |     conf.model_name = ModelName.beatgans_ddpm
19 |     conf.net_attn = (16, )
20 |     conf.net_beatgans_attn_head = 1
21 |     conf.net_beatgans_embed_channels = 512
22 |     conf.net_ch_mult = (1, 2, 4, 8)
23 |     conf.net_ch = 64
24 |     conf.sample_size = 32
25 |     conf.T_eval = 20
26 |     conf.T = 1000
27 |     conf.make_model_conf()
28 |     return conf
29 | 
30 | 
31 | def autoenc_base():
32 |     """
33 |     base configuration for all Diff-AE models.
34 |     """
35 |     conf = TrainConfig()
36 |     conf.batch_size = 32
37 |     conf.beatgans_gen_type = GenerativeType.ddim
38 |     conf.beta_scheduler = 'linear'
39 |     conf.data_name = 'ffhq'
40 |     conf.diffusion_type = 'beatgans'
41 |     conf.eval_ema_every_samples = 200_000
42 |     conf.eval_every_samples = 200_000
43 |     conf.fp16 = True
44 |     conf.lr = 1e-4
45 |     conf.model_name = ModelName.beatgans_autoenc
46 |     conf.net_attn = (16, )
47 |     conf.net_beatgans_attn_head = 1
48 |     conf.net_beatgans_embed_channels = 512
49 |     conf.net_beatgans_resnet_two_cond = True
50 |     conf.net_ch_mult = (1, 2, 4, 8)
51 |     conf.net_ch = 64
52 |     conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
53 |     conf.net_enc_pool = 'adaptivenonzero'
54 |     conf.sample_size = 32
55 |     conf.T_eval = 20
56 |     conf.T = 1000
57 |     conf.make_model_conf()
58 |     return conf
59 | 
60 | 
61 | def ffhq64_ddpm():
62 |     conf = ddpm()
63 |     conf.data_name = 'ffhqlmdb256'
64 |     conf.warmup = 0
65 |     conf.total_samples = 72_000_000
66 |     conf.scale_up_gpus(4)
67 |     return conf
68 | 
69 | 
70 | def ffhq64_autoenc():
71 |     conf = autoenc_base()
72 |     conf.data_name = 'ffhqlmdb256'
73 |     conf.warmup = 0
74 |     conf.total_samples = 72_000_000
75 |     conf.net_ch_mult = (1, 2, 4, 8)
76 |     conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
77 |     conf.eval_every_samples = 1_000_000
78 |     conf.eval_ema_every_samples = 1_000_000
79 |     conf.scale_up_gpus(4)
80 |     conf.make_model_conf()
81 |     return conf
82 | 
83 | 
84 | def celeba64d2c_ddpm():
85 |     conf = ffhq128_ddpm()
86 |     conf.data_name = 'celebalmdb'
87 |     conf.eval_every_samples = 10_000_000
88 |     conf.eval_ema_every_samples = 10_000_000
89 |     conf.total_samples = 72_000_000
90 |     conf.name = 'celeba64d2c_ddpm'
91 |     return conf
92 | 
93 | 
94 | def celeba64d2c_autoenc():
95 |     conf = ffhq64_autoenc()
96 |     conf.data_name = 'celebalmdb'
97 |     conf.eval_every_samples = 10_000_000
98 |     conf.eval_ema_every_samples = 10_000_000
99 |     conf.total_samples = 72_000_000
100 |     conf.name = 'celeba64d2c_autoenc'
101 |     return conf
102 | 
103 | 
104 | def ffhq128_ddpm():
105 |     conf = ddpm()
106 |     conf.data_name = 'ffhqlmdb256'
107 |     conf.warmup = 0
108 |     conf.total_samples = 48_000_000
109 |     conf.img_size = 128
110 |     conf.net_ch = 128
111 |     # channels:
112 |     # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4
113 |     # sizes:
114 |     # 128 => 128 => 64 => 32 => 16 => 8
115 |     conf.net_ch_mult = (1, 1, 2, 3, 4)
116 |     conf.eval_every_samples = 1_000_000
117 |     conf.eval_ema_every_samples = 1_000_000
118 |     conf.scale_up_gpus(4)
119 |     conf.eval_ema_every_samples = 10_000_000
120 |     conf.eval_every_samples = 10_000_000
121 |     conf.make_model_conf()
122 |     return conf
123 | 
124 | 
125 | def ffhq128_autoenc_base():
126 |     conf = autoenc_base()
127 |     conf.data_name = 'ffhqlmdb256'
128 |     conf.scale_up_gpus(4)
129 |     conf.img_size = 128
130 |     conf.net_ch = 128
131 |     # final resolution = 8x8
132 |     conf.net_ch_mult = (1, 1, 2, 3, 4)
133 |     # final resolution = 4x4
134 |     conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)
135 |     conf.eval_ema_every_samples = 10_000_000
136 |     conf.eval_every_samples = 10_000_000
137 |     conf.make_model_conf()
138 |     return conf
139 | 
140 | 
141 | def ffhq256_autoenc():
142 |     conf = ffhq128_autoenc_base()
143 |     conf.img_size = 256
144 |     conf.net_ch = 128
145 |     conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
146 |     conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
147 |     conf.eval_every_samples = 10_000_000
148 |     conf.eval_ema_every_samples = 10_000_000
149 |     conf.total_samples = 200_000_000
150 |     conf.batch_size = 64
151 |     conf.make_model_conf()
152 |     conf.name = 'ffhq256_autoenc'
153 |     return conf
154 | 
155 | 
156 | def ffhq256_autoenc_eco():
157 |     conf = ffhq128_autoenc_base()
158 |     conf.img_size = 256
159 |     conf.net_ch = 128
160 |     conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
161 |     conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
162 |     conf.eval_every_samples = 10_000_000
163 |     conf.eval_ema_every_samples = 10_000_000
164 |     conf.total_samples = 200_000_000
165 |     conf.batch_size = 64
166 |     conf.make_model_conf()
167 |     conf.name = 'ffhq256_autoenc_eco'
168 |     return conf
169 | 
170 | 
171 | def ffhq128_ddpm_72M():
172 |     conf = ffhq128_ddpm()
173 |     conf.total_samples = 72_000_000
174 |     conf.name = 'ffhq128_ddpm_72M'
175 |     return conf
176 | 
177 | 
178 | def ffhq128_autoenc_72M():
179 |     conf = ffhq128_autoenc_base()
180 |     conf.total_samples = 72_000_000
181 |     conf.name = 'ffhq128_autoenc_72M'
182 |     return conf
183 | 
184 | 
185 | def ffhq128_ddpm_130M():
186 |     conf = ffhq128_ddpm()
187 |     conf.total_samples = 130_000_000
188 |     conf.eval_ema_every_samples = 10_000_000
189 |     conf.eval_every_samples = 10_000_000
190 |     conf.name = 'ffhq128_ddpm_130M'
191 |     return conf
192 | 
193 | 
194 | def ffhq128_autoenc_130M():
195 |     conf = ffhq128_autoenc_base()
196 |     conf.total_samples = 130_000_000
197 |     conf.eval_ema_every_samples = 10_000_000
198 |     conf.eval_every_samples = 10_000_000
199 |     conf.name = 'ffhq128_autoenc_130M'
200 |     return conf
201 | 
202 | 
203 | def horse128_ddpm():
204 |     conf = ffhq128_ddpm()
205 |     conf.data_name = 'horse256'
206 |     conf.total_samples = 130_000_000
207 |     conf.eval_ema_every_samples = 10_000_000
208 |     conf.eval_every_samples = 10_000_000
209 |     conf.name = 'horse128_ddpm'
210 |     return conf
211 | 
212 | 
213 | def horse128_autoenc():
214 |     conf = ffhq128_autoenc_base()
215 |     conf.data_name = 'horse256'
216 |     conf.total_samples = 130_000_000
217 |     conf.eval_ema_every_samples = 10_000_000
218 |     conf.eval_every_samples = 10_000_000
219 |     conf.name = 'horse128_autoenc'
220 |     return conf
221 | 
222 | 
223 | def bedroom128_ddpm():
224 |     conf = ffhq128_ddpm()
225 |     conf.data_name = 'bedroom256'
226 |     conf.eval_ema_every_samples = 10_000_000
227 |     conf.eval_every_samples = 10_000_000
228 |     conf.total_samples = 120_000_000
229 |     conf.name = 'bedroom128_ddpm'
230 |     return conf
231 | 
232 | 
233 | def bedroom128_autoenc():
234 |     conf = ffhq128_autoenc_base()
235 |     conf.data_name = 'bedroom256'
236 |     conf.eval_ema_every_samples = 10_000_000
237 |     conf.eval_every_samples = 10_000_000
238 |     conf.total_samples = 120_000_000
239 |     conf.name = 'bedroom128_autoenc'
240 |     return conf
241 | 
242 | 
243 | def pretrain_celeba64d2c_72M():
244 |     conf = celeba64d2c_autoenc()
245 |     conf.pretrain = PretrainConfig(
246 |         name='72M',
247 |         path=f'checkpoints/{celeba64d2c_autoenc().name}/last.ckpt',
248 |     )
249 |     conf.latent_infer_path = f'checkpoints/{celeba64d2c_autoenc().name}/latent.pkl'
250 |     return conf
251 | 
252 | 
253 | def pretrain_ffhq128_autoenc72M():
254 |     conf = ffhq128_autoenc_base()
255 |     conf.postfix = ''
256 |     conf.pretrain = PretrainConfig(
257 |         name='72M',
258 |         path=f'checkpoints/{ffhq128_autoenc_72M().name}/last.ckpt',
259 |     )
260 |     conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_72M().name}/latent.pkl'
261 |     return conf
262 | 
263 | 
264 | def pretrain_ffhq128_autoenc130M():
265 |     conf = ffhq128_autoenc_base()
266 |     conf.pretrain = PretrainConfig(
267 |         name='130M',
268 |         path=f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
269 |     )
270 |     conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
271 |     return conf
272 | 
273 | 
274 | def pretrain_ffhq256_autoenc():
275 |     conf = ffhq256_autoenc()
276 |     conf.pretrain = PretrainConfig(
277 |         name='90M',
278 |         path=f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
279 |     )
280 |     conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl'
281 |     return conf
282 | 
283 | 
284 | def pretrain_horse128():
285 |     conf = horse128_autoenc()
286 |     conf.pretrain = PretrainConfig(
287 |         name='82M',
288 |         path=f'checkpoints/{horse128_autoenc().name}/last.ckpt',
289 |     )
290 |     conf.latent_infer_path = f'checkpoints/{horse128_autoenc().name}/latent.pkl'
291 |     return conf
292 | 
293 | 
294 | def pretrain_bedroom128():
295 |     conf = bedroom128_autoenc()
296 |     conf.pretrain = PretrainConfig(
297 |         name='120M',
298 |         path=f'checkpoints/{bedroom128_autoenc().name}/last.ckpt',
299 |     )
300 |     conf.latent_infer_path = f'checkpoints/{bedroom128_autoenc().name}/latent.pkl'
301 |     return conf
302 | 
--------------------------------------------------------------------------------
/templates_cls.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | 
3 | 
4 | def ffhq128_autoenc_cls():
5 |     conf = ffhq128_autoenc_130M()
6 |     conf.train_mode = TrainMode.manipulate
7 |     conf.manipulate_mode = ManipulateMode.celebahq_all
8 |     conf.manipulate_znormalize = True
9 |     conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
10 |     conf.batch_size = 32
11 |     conf.lr = 1e-3
12 |     conf.total_samples = 300_000
13 |     # use the pretraining trick instead of the continuing trick
14 |     conf.pretrain = PretrainConfig(
15 |         '130M',
16 |         f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
17 |     )
18 |     conf.name = 'ffhq128_autoenc_cls'
19 |     return conf
20 | 
21 | 
22 | def ffhq256_autoenc_cls():
23 |     '''We first train the encoder on the FFHQ dataset, then use it as a pretrained model to train a linear classifier on the CelebA dataset with attribute labels.'''
24 |     conf = ffhq256_autoenc()
25 |     conf.train_mode = TrainMode.manipulate
26 |     conf.manipulate_mode = ManipulateMode.celebahq_all
27 |     conf.manipulate_znormalize = True
28 |     conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl'  # we train on the CelebA dataset, not FFHQ
29 |     conf.batch_size = 32
30 |     conf.lr = 1e-3
31 |     conf.total_samples = 300_000
32 |     # use the pretraining trick instead of the continuing trick
33 |     conf.pretrain = PretrainConfig(
34 |         '130M',
35 |         f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
36 |     )
37 |     conf.name = 'ffhq256_autoenc_cls'
38 |     return conf
39 | 
--------------------------------------------------------------------------------
/templates_latent.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | 
3 | 
4 | def latent_diffusion_config(conf: TrainConfig):
5 |     conf.batch_size = 128
6 |     conf.train_mode = TrainMode.latent_diffusion
7 |     conf.latent_gen_type = GenerativeType.ddim
8 |     conf.latent_loss_type = LossType.mse
9 |     conf.latent_model_mean_type = ModelMeanType.eps
10 |     conf.latent_model_var_type = ModelVarType.fixed_large
11 |     conf.latent_rescale_timesteps = False
12 |     conf.latent_clip_sample = False
13 |     conf.latent_T_eval = 20
14 |     conf.latent_znormalize = True
15 |     conf.total_samples = 96_000_000
16 |     conf.sample_every_samples = 400_000
17 |     conf.eval_every_samples = 20_000_000
18 |     conf.eval_ema_every_samples = 20_000_000
19 |     conf.save_every_samples = 2_000_000
20 |     return conf
21 | 
22 | 
23 | def latent_diffusion128_config(conf: TrainConfig):
24 |     conf = latent_diffusion_config(conf)
25 |     conf.batch_size_eval = 32
26 |     return conf
27 | 
28 | 
29 | def latent_mlp_2048_norm_10layers(conf: TrainConfig):
30 |     conf.net_latent_net_type = LatentNetType.skip
31 |     conf.net_latent_layers = 10
32 |     conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
33 |     conf.net_latent_activation = Activation.silu
34 |     conf.net_latent_num_hid_channels = 2048
35 |     conf.net_latent_use_norm = True
36 |     conf.net_latent_condition_bias = 1
37 |     return conf
38 | 
39 | 
40 | def latent_mlp_2048_norm_20layers(conf: TrainConfig):
41 |     conf = latent_mlp_2048_norm_10layers(conf)
42 |     conf.net_latent_layers = 20
43 |     conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
44 |     return conf
45 | 
46 | 
47 | def latent_256_batch_size(conf: TrainConfig):
48 |     conf.batch_size = 256
49 |     conf.eval_ema_every_samples = 100_000_000
50 |     conf.eval_every_samples = 100_000_000
51 |     conf.sample_every_samples = 1_000_000
52 |     conf.save_every_samples = 2_000_000
53 |     conf.total_samples = 301_000_000
54 |     return conf
55 | 
56 | 
57 | def latent_512_batch_size(conf: TrainConfig):
58 |     conf.batch_size = 512
59 |     conf.eval_ema_every_samples = 100_000_000
60 |     conf.eval_every_samples = 100_000_000
61 |     conf.sample_every_samples = 1_000_000
62 |     conf.save_every_samples = 5_000_000
63 |     conf.total_samples = 501_000_000
64 |     return conf
65 | 
66 | 
67 | def latent_2048_batch_size(conf: TrainConfig):
68 |     conf.batch_size = 2048
69 |     conf.eval_ema_every_samples = 200_000_000
70 |     conf.eval_every_samples = 200_000_000
71 |     conf.sample_every_samples = 4_000_000
72 |     conf.save_every_samples = 20_000_000
73 |     conf.total_samples = 1_501_000_000
74 |     return conf
75 | 
76 | 
77 | def adamw_weight_decay(conf: TrainConfig):
78 |     conf.optimizer = OptimizerType.adamw
79 |     conf.weight_decay = 0.01
80 |     return conf
81 | 
82 | 
83 | def ffhq128_autoenc_latent():
84 |     conf = pretrain_ffhq128_autoenc130M()
85 |     conf = latent_diffusion128_config(conf)
86 |     conf = latent_mlp_2048_norm_10layers(conf)
87 |     conf = latent_256_batch_size(conf)
88 |     conf = adamw_weight_decay(conf)
89 |     conf.total_samples = 101_000_000
90 |     conf.latent_loss_type = LossType.l1
91 |     conf.latent_beta_scheduler = 'const0.008'
92 |     conf.name = 'ffhq128_autoenc_latent'
93 |     return conf
94 | 
95 | 
96 | def ffhq256_autoenc_latent():
97 |     conf = pretrain_ffhq256_autoenc()
98 |     conf = latent_diffusion128_config(conf)
99 |     conf = latent_mlp_2048_norm_10layers(conf)
100 |     conf = latent_256_batch_size(conf)
101 |     conf = adamw_weight_decay(conf)
102 |     conf.total_samples = 101_000_000
103 |     conf.latent_loss_type = LossType.l1
104 |     conf.latent_beta_scheduler = 'const0.008'
105 |     conf.eval_ema_every_samples = 200_000_000
106 |     conf.eval_every_samples = 200_000_000
107 |     conf.sample_every_samples = 4_000_000
108 |     conf.name = 'ffhq256_autoenc_latent'
109 |     return conf
110 | 
111 | 
112 | def horse128_autoenc_latent():
113 |     conf = pretrain_horse128()
114 |     conf = latent_diffusion128_config(conf)
115 |     conf = latent_2048_batch_size(conf)
116 |     conf = latent_mlp_2048_norm_20layers(conf)
117 |     conf.total_samples = 2_001_000_000
118 |     conf.latent_beta_scheduler = 'const0.008'
119 |     conf.latent_loss_type = LossType.l1
120 |     conf.name = 'horse128_autoenc_latent'
121 |     return conf
122 | 
123 | 
124 | def bedroom128_autoenc_latent():
125 |     conf = pretrain_bedroom128()
126 |     conf = latent_diffusion128_config(conf)
127 |     conf = latent_2048_batch_size(conf)
128 |     conf = latent_mlp_2048_norm_20layers(conf)
129 |     conf.total_samples = 2_001_000_000
130 |     conf.latent_beta_scheduler = 'const0.008'
131 |     conf.latent_loss_type = LossType.l1
132 |     conf.name = 'bedroom128_autoenc_latent'
133 |     return conf
134 | 
135 | 
136 | def celeba64d2c_autoenc_latent():
137 |     conf = pretrain_celeba64d2c_72M()
138 |     conf = latent_diffusion_config(conf)
139 |     conf = latent_512_batch_size(conf)
140 |     conf = latent_mlp_2048_norm_10layers(conf)
141 |     conf = adamw_weight_decay(conf)
142 |     # just for the name
143 |     conf.continue_from = PretrainConfig('200M',
144 |                                         f'log-latent/{conf.name}/last.ckpt')
145 |     conf.postfix = '_300M'
146 |     conf.total_samples = 301_000_000
147 |     conf.latent_beta_scheduler = 'const0.008'
148 |     conf.latent_loss_type = LossType.l1
149 |     conf.name = 'celeba64d2c_autoenc_latent'
150 |     return conf
151 | 
--------------------------------------------------------------------------------