├── .gitignore
├── .vscode
│   └── settings.json
├── LICENSE
├── README.md
├── README.md.backup
├── align.py
├── autoencoding.ipynb
├── choices.py
├── cog.yaml
├── config.py
├── config_base.py
├── data_resize_bedroom.py
├── data_resize_celeba.py
├── data_resize_celebahq.py
├── data_resize_ffhq.py
├── data_resize_horse.py
├── dataset.py
├── dataset_util.py
├── datasets
│   └── celeba_anno
│       ├── CelebAMask-HQ-attribute-anno.txt
│       ├── CelebAMask-HQ-pose-anno.txt
│       └── list_attr_celeba.txt
├── diffusion
│   ├── __init__.py
│   ├── base.py
│   ├── diffusion.py
│   └── resample.py
├── dist_utils.py
├── evals
│   ├── ffhq128_autoenc_130M.txt
│   └── ffhq128_autoenc_latent.txt
├── experiment.py
├── experiment_classifier.py
├── imgs
│   └── sandy.JPG
├── imgs_align
│   └── sandy.png
├── imgs_interpolate
│   ├── 1_a.png
│   └── 1_b.png
├── imgs_manipulated
│   ├── compare.png
│   ├── output.png
│   └── sandy-wavyhair.png
├── install_requirements_for_colab.sh
├── interpolate.ipynb
├── lmdb_writer.py
├── manipulate.ipynb
├── manipulate_note.ipynb
├── metrics.py
├── model
│   ├── __init__.py
│   ├── blocks.py
│   ├── latentnet.py
│   ├── nn.py
│   ├── unet.py
│   └── unet_autoenc.py
├── predict.py
├── renderer.py
├── requirement_for_colab.txt
├── requirements.txt
├── run_bedroom128.py
├── run_bedroom128_ddim.py
├── run_celeba64.py
├── run_ffhq128.py
├── run_ffhq128_cls.py
├── run_ffhq128_ddim.py
├── run_ffhq256.py
├── run_ffhq256.sh
├── run_ffhq256_cls.py
├── run_ffhq256_latent.py
├── run_horse128.py
├── run_horse128_ddim.py
├── sample.ipynb
├── ssim.py
├── templates.py
├── templates_cls.py
└── templates_latent.py
/.gitignore:
--------------------------------------------------------------------------------
1 | temp
2 | __pycache__
3 | generated
4 | latent_infer
5 | datasets/bedroom256.lmdb
6 | datasets/horse256.lmdb
7 | datasets/celebahq
8 | datasets/celebahq256.lmdb
9 | datasets/ffhq
10 | datasets/ffhq256.lmdb
11 | datasets/celeba.lmdb
12 | datasets/celeba
13 | checkpoints
14 |
15 | .ipynb_checkpoints/
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "python.formatting.provider": "yapf"
3 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 VISTEC - Vidyasirimedhi Institute of Science and Technology
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Official implementation of Diffusion Autoencoders
2 |
3 | A CVPR 2022 (ORAL) paper ([paper](https://openaccess.thecvf.com/content/CVPR2022/html/Preechakul_Diffusion_Autoencoders_Toward_a_Meaningful_and_Decodable_Representation_CVPR_2022_paper.html), [site](https://diff-ae.github.io/), [5-min video](https://youtu.be/i3rjEsiHoUU)):
4 |
5 | ```
6 | @inproceedings{preechakul2021diffusion,
7 | title={Diffusion Autoencoders: Toward a Meaningful and Decodable Representation},
8 | author={Preechakul, Konpat and Chatthee, Nattanat and Wizadwongsa, Suttisak and Suwajanakorn, Supasorn},
9 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
10 | year={2022},
11 | }
12 | ```
13 |
14 | ## Usage
15 |
16 | ⚙️ [Try a Colab walkthrough](https://drive.google.com/file/d/1OTfwkklN-IEd4hFk4LnweOleyDtS4XTh/view?usp=sharing)
17 |
18 | 🤗 [Try a web demo](https://replicate.com/cjwbw/diffae)
19 |
20 | Note: Since we expect a lot of changes to the codebase, please fork the repo before using it.
21 |
22 | ### Prerequisites
23 |
24 | See `requirements.txt`
25 |
26 | ```
27 | pip install -r requirements.txt
28 | ```
29 |
30 | ### Quick start
31 |
32 | Jupyter notebooks are provided for each task:
33 |
34 | For unconditional generation: `sample.ipynb`
35 |
36 | For manipulation: `manipulate.ipynb`
37 |
38 | For interpolation: `interpolate.ipynb`
39 |
40 | For autoencoding: `autoencoding.ipynb`
41 |
42 | Aligning your own images:
43 |
44 | 1. Put images into the `imgs` directory
45 | 2. Run `align.py` (requires `pip install dlib requests`)
46 | 3. Aligned images will be written to the `imgs_align` directory (see the example command below)
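
By default `align.py` reads from `imgs` and writes to `imgs_align`; the directories can be overridden with its `-i`/`-o` flags:

```
python align.py -i imgs -o imgs_align
```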
47 |
48 |
49 |
50 |
51 | *(Example images: the original photo in the `imgs` directory, the aligned result from `align.py`, and the manipulated result from `manipulate.ipynb`.)*
59 |
60 |
61 |
62 |
63 | ### Checkpoints
64 |
65 | We provide checkpoints for the following models:
66 |
67 | 1. DDIM: **FFHQ128** ([72M](https://drive.google.com/drive/folders/1-fa46UPSgy9ximKngBflgSj3u87-DLrw), [130M](https://drive.google.com/drive/folders/1-Sqes07fs1y9sAYXuYWSoDE_xxTtH4yx)), [**Bedroom128**](https://drive.google.com/drive/folders/1-_8LZd5inoAOBT-hO5f7RYivt95FbYT1), [**Horse128**](https://drive.google.com/drive/folders/10Hq3zIlJs9ZSiXDQVYuVJVf0cX4a_nDB)
68 | 2. DiffAE (autoencoding only): [**FFHQ256**](https://drive.google.com/drive/folders/1-5zfxT6Gl-GjxM7z9ZO2AHlB70tfmF6V), **FFHQ128** ([72M](https://drive.google.com/drive/folders/10bmB6WhLkgxybkhso5g3JmIFPAnmZMQO), [130M](https://drive.google.com/drive/folders/10UNtFNfxbHBPkoIh003JkSPto5s-VbeN)), [**Bedroom128**](https://drive.google.com/drive/folders/12EdjbIKnvP5RngKsR0UU-4kgpPAaYtlp), [**Horse128**](https://drive.google.com/drive/folders/12EtTRXzQc5uPHscpjIcci-Rg-OGa_N30)
69 | 3. DiffAE (with latent DPM, can sample): [**FFHQ256**](https://drive.google.com/drive/folders/1-H8WzKc65dEONN-DQ87TnXc23nTXDTYb), [**FFHQ128**](https://drive.google.com/drive/folders/11pdjMQ6NS8GFFiGOq3fziNJxzXU1Mw3l), [**Bedroom128**](https://drive.google.com/drive/folders/11mdxv2lVX5Em8TuhNJt-Wt2XKt25y8zU), [**Horse128**](https://drive.google.com/drive/folders/11k8XNDK3ENxiRnPSUdJ4rnagJYo4uKEo)
70 | 4. DiffAE's classifiers (for manipulation): [**FFHQ256's latent on CelebAHQ**](https://drive.google.com/drive/folders/117Wv7RZs_gumgrCOIhDEWgsNy6BRJorg), [**FFHQ128's latent on CelebAHQ**](https://drive.google.com/drive/folders/11EYIyuK6IX44C8MqreUyMgPCNiEnwhmI)
71 |
72 | Checkpoints should be kept in a separate `checkpoints` directory.
73 | Download them and arrange the directory like this:
74 |
75 | ```
76 | checkpoints/
77 | - bedroom128_autoenc
78 | - last.ckpt # diffae checkpoint
79 | - latent.ckpt # predicted z_sem on the dataset
80 | - bedroom128_autoenc_latent
81 | - last.ckpt # diffae + latent DPM checkpoint
82 | - bedroom128_ddpm
83 | - ...
84 | ```
85 |
86 |
87 | ### LMDB Datasets
88 |
89 | We do not own any of the following datasets. We provide ready-to-use LMDB versions for convenience.
90 |
91 | - [FFHQ](https://1drv.ms/f/s!Ar2O0vx8sW70uLV1Ivk2pTjam1A8VA)
92 | - [CelebAHQ](https://1drv.ms/f/s!Ar2O0vx8sW70uL4GMeWEciHkHdH6vQ)
93 |
94 | **Broken links**
95 |
96 | Note: I'm trying to recover the following links.
97 |
98 | - [CelebA](https://drive.google.com/drive/folders/1HJAhK2hLYcT_n0gWlCu5XxdZj-bPekZ0?usp=sharing)
99 | - [LSUN Bedroom](https://drive.google.com/drive/folders/1O_3aT3LtY1YDE2pOQCp6MFpCk7Pcpkhb?usp=sharing)
100 | - [LSUN Horse](https://drive.google.com/drive/folders/1ooHW7VivZUs4i5CarPaWxakCwfeqAK8l?usp=sharing)
101 |
102 | The directory tree should be:
103 |
104 | ```
105 | datasets/
106 | - bedroom256.lmdb
107 | - celebahq256.lmdb
108 | - celeba.lmdb
109 | - ffhq256.lmdb
110 | - horse256.lmdb
111 | ```
112 |
113 | You can also download the datasets from their original sources and use our provided scripts to package them as LMDB files.
114 | Original sources for each dataset are as follows:
115 |
116 | - FFHQ (https://github.com/NVlabs/ffhq-dataset)
117 | - CelebAHQ (https://github.com/switchablenorms/CelebAMask-HQ)
118 | - CelebA (https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
119 | - LSUN (https://github.com/fyu/lsun)
120 |
121 | The conversion scripts are provided as:
122 |
123 | ```
124 | data_resize_bedroom.py
125 | data_resize_celebahq.py
126 | data_resize_celeba.py
127 | data_resize_ffhq.py
128 | data_resize_horse.py
129 | ```
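
Each script hard-codes its input and output paths near the top of its `__main__` block (e.g. `datasets/celeba` → `datasets/celeba.lmdb`). For example, with the raw CelebA jpg images already placed in `datasets/celeba`:

```
python data_resize_celeba.py
```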
130 |
131 | Google drive: https://drive.google.com/drive/folders/1abNP4QKGbNnymjn8607BF0cwxX2L23jh?usp=sharing
132 |
133 |
134 | ## Training
135 |
136 | We provide scripts for training & evaluating DDIM and DiffAE (including the latent DPM) on the following datasets: FFHQ128, FFHQ256, Bedroom128, Horse128, Celeba64 (D2C's crop).
137 | Usually, the evaluation results (FID) will be available in the `evals` directory.
138 |
139 | Note: Most experiments require at least 4x V100s to train the DPM model, while the accompanying latent DPM can be trained on a single 2080Ti.
140 |
141 |
142 |
143 | **FFHQ128**
144 | ```
145 | # diffae
146 | python run_ffhq128.py
147 | # ddim
148 | python run_ffhq128_ddim.py
149 | ```
150 |
151 | A classifier (for manipulation) can be trained using:
152 | ```
153 | python run_ffhq128_cls.py
154 | ```
155 |
156 | **FFHQ256**
157 |
158 | We trained only the DiffAE due to the high computation cost.
159 | This requires 8x V100s.
160 | ```
161 | sbatch run_ffhq256.py
162 | ```
163 |
164 | After that job finishes, train the latent DPM (requires only 1x 2080Ti):
165 | ```
166 | python run_ffhq256_latent.py
167 | ```
168 |
169 | A classifier (for manipulation) can be trained using:
170 | ```
171 | python run_ffhq256_cls.py
172 | ```
173 |
174 | **Bedroom128**
175 |
176 | ```
177 | # diffae
178 | python run_bedroom128.py
179 | # ddim
180 | python run_bedroom128_ddim.py
181 | ```
182 |
183 | **Horse128**
184 |
185 | ```
186 | # diffae
187 | python run_horse128.py
188 | # ddim
189 | python run_horse128_ddim.py
190 | ```
191 |
192 | **Celeba64**
193 |
194 | This experiment can be run on 2080Ti's.
195 |
196 | ```
197 | # diffae
198 | python run_celeba64.py
199 | ```
200 |
--------------------------------------------------------------------------------
/README.md.backup:
--------------------------------------------------------------------------------
1 | # Official implementation of Diffusion Autoencoders
2 |
3 | A CVPR 2022 paper:
4 |
5 | > Preechakul, Konpat, Nattanat Chatthee, Suttisak Wizadwongsa, and Supasorn Suwajanakorn. 2021. “Diffusion Autoencoders: Toward a Meaningful and Decodable Representation.” arXiv [cs.CV]. arXiv. http://arxiv.org/abs/2111.15640.
6 |
7 | ## Usage
8 |
9 | Note: Since we expect a lot of changes to the codebase, please fork the repo before using it.
10 |
11 | ### Prerequisites
12 |
13 | See `requirements.txt`
14 |
15 | ```
16 | pip install -r requirements.txt
17 | ```
18 |
19 | ### Quick start
20 |
21 | Jupyter notebooks are provided for each task:
22 |
23 | For unconditional generation: `sample.ipynb`
24 |
25 | For manipulation: `manipulate.ipynb`
26 |
27 | Aligning your own images:
28 |
29 | 1. Put images into the `imgs` directory
30 | 2. Run `align.py` (need to `pip install dlib requests`)
31 | 3. Result images will be available in `imgs_align` directory
32 |
33 |
34 |
39 |
40 | |  |  |  |
41 | |---|---|---|
42 |
43 |
44 | ### Checkpoints
45 |
46 | We provide checkpoints for the following models:
47 |
48 | 1. DDIM: **FFHQ128** ([72M](https://drive.google.com/drive/folders/1-J8FPNZOQxSqpfTpwRXawLi2KKGL1qlK?usp=sharing), [130M](https://drive.google.com/drive/folders/17T5YJXpYdgE6cWltN8gZFxRsJzpVxnLh?usp=sharing)), [**Bedroom128**](https://drive.google.com/drive/folders/19s-lAiK7fGD5Meo5obNV5o0L3MfqU0Sk?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1PiC5JWLcd8mZW9cghDCR0V4Hx0QCXOor?usp=sharing)
49 | 2. DiffAE (autoencoding only): [**FFHQ256**](https://drive.google.com/drive/folders/1hTP9QbYXwv_Nl5sgcZNH0yKprJx7ivC5?usp=sharing), **FFHQ128** ([72M](https://drive.google.com/drive/folders/15QHmZP1G5jEMh80R1Nbtdb4ZKb6VvfII?usp=sharing), [130M](https://drive.google.com/drive/folders/1UlwLwgv16cEqxTn7g-V2ykIyopmY_fVz?usp=sharing)), [**Bedroom128**](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
50 | 3. DiffAE (with latent DPM, can sample): [**FFHQ256**](https://drive.google.com/drive/folders/1MonJKYwVLzvCFYuVhp-l9mChq5V2XI6w?usp=sharing), [**FFHQ128**](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing), [**Bedroom128**](https://drive.google.com/drive/folders/1okhCb1RezlWmDbdEAGWMHMkUBRRXmey0?usp=sharing), [**Horse128**](https://drive.google.com/drive/folders/1Ujmv3ajeiJLOT6lF2zrQb4FimfDkMhcP?usp=sharing)
51 | 4. DiffAE's classifiers (for manipulation): [**FFHQ256's latent on CelebAHQ**](https://drive.google.com/drive/folders/1QGkTfvNhgi_TbbV8GbX1Emrp0lStsqLj?usp=sharing), [**FFHQ128's latent on CelebAHQ**](https://drive.google.com/drive/folders/1E3Ew1p9h42h7UA1DJNK7jnb2ERybg9ji?usp=sharing)
52 |
53 | Checkpoints ought to be put into a separate directory `checkpoints`.
54 | Download the checkpoints and put them into `checkpoints` directory. It should look like this:
55 |
56 | ```
57 | checkpoints/
58 | - bedroom128_autoenc
59 | - last.ckpt # diffae checkpoint
60 | - latent.ckpt # predicted z_sem on the dataset
61 | - bedroom128_autoenc_latent
62 | - last.ckpt # diffae + latent DPM checkpoint
63 | - bedroom128_ddpm
64 | - ...
65 | ```
66 |
67 |
68 | ### LMDB Datasets
69 |
70 | We do not own any of the following datasets. We provide the LMDB ready-to-use dataset for the sake of convenience.
71 |
72 | - [FFHQ](https://drive.google.com/drive/folders/1ww7itaSo53NDMa0q-wn-3HWZ3HHqK1IK?usp=sharing)
73 | - [CelebAHQ](https://drive.google.com/drive/folders/1SX3JuVHjYA8sA28EGxr_IoHJ63s4Btbl?usp=sharing)
74 | - [CelebA](https://drive.google.com/drive/folders/1HJAhK2hLYcT_n0gWlCu5XxdZj-bPekZ0?usp=sharing)
75 | - [LSUN Bedroom](https://drive.google.com/drive/folders/1O_3aT3LtY1YDE2pOQCp6MFpCk7Pcpkhb?usp=sharing)
76 | - [LSUN Horse](https://drive.google.com/drive/folders/1ooHW7VivZUs4i5CarPaWxakCwfeqAK8l?usp=sharing)
77 |
78 | The directory tree should be:
79 |
80 | ```
81 | datasets/
82 | - bedroom256.lmdb
83 | - celebahq256.lmdb
84 | - celeba.lmdb
85 | - ffhq256.lmdb
86 | - horse256.lmdb
87 | ```
88 |
89 | You can also download from the original sources, and use our provided codes to package them as LMDB files.
90 | Original sources for each dataset are as follows:
91 |
92 | - FFHQ (https://github.com/NVlabs/ffhq-dataset)
93 | - CelebAHQ (https://github.com/switchablenorms/CelebAMask-HQ)
94 | - CelebA (https://mmlab.ie.cuhk.edu.hk/projects/CelebA.html)
95 | - LSUN (https://github.com/fyu/lsun)
96 |
97 | The conversion codes are provided as:
98 |
99 | ```
100 | data_resize_bedroom.py
101 | data_resize_celebahq.py
102 | data_resize_celeba.py
103 | data_resize_ffhq.py
104 | data_resize_horse.py
105 | ```
106 |
107 | Google drive: https://drive.google.com/drive/folders/1abNP4QKGbNnymjn8607BF0cwxX2L23jh?usp=sharing
108 |
109 |
110 | ## Training
111 |
112 | We provide scripts for training & evaluate DDIM and DiffAE (including latent DPM) on the following datasets: FFHQ128, FFHQ256, Bedroom128, Horse128, Celeba64 (D2C's crop).
113 | Usually, the evaluation results (FID's) will be available in `eval` directory.
114 |
115 | Note: Most experiment requires at least 4x V100s during training the DPM models while requiring 1x 2080Ti during training the accompanying latent DPM.
116 |
117 |
118 |
119 | **FFHQ128**
120 | ```
121 | # diffae
122 | python run_ffhq128.py
123 | # ddim
124 | python run_ffhq128_ddim.py
125 | ```
126 |
127 | A classifier (for manipulation) can be trained using:
128 | ```
129 | python run_ffhq128_cls.py
130 | ```
131 |
132 | **FFHQ256**
133 |
134 | We only trained the DiffAE due to high computation cost.
135 | This requires 8x V100s.
136 | ```
137 | sbatch run_ffhq256.py
138 | ```
139 |
140 | After the task is done, you need to train the latent DPM (requiring only 1x 2080Ti)
141 | ```
142 | python run_ffhq256_latent.py
143 | ```
144 |
145 | A classifier (for manipulation) can be trained using:
146 | ```
147 | python run_ffhq256_cls.py
148 | ```
149 |
150 | **Bedroom128**
151 |
152 | ```
153 | # diffae
154 | python run_bedroom128.py
155 | # ddim
156 | python run_bedroom128_ddim.py
157 | ```
158 |
159 | **Horse128**
160 |
161 | ```
162 | # diffae
163 | python run_horse128.py
164 | # ddim
165 | python run_horse128_ddim.py
166 | ```
167 |
168 | **Celeba64**
169 |
170 | This experiment can be run on 2080Ti's.
171 |
172 | ```
173 | # diffae
174 | python run_celeba64.py
175 | ```
--------------------------------------------------------------------------------
/align.py:
--------------------------------------------------------------------------------
1 | import bz2
2 | import os
3 | import os.path as osp
4 | import sys
5 | from multiprocessing import Pool
6 |
7 | import dlib
8 | import numpy as np
9 | import PIL.Image
10 | import requests
11 | import scipy.ndimage
12 | from tqdm import tqdm
13 | from argparse import ArgumentParser
14 |
15 | LANDMARKS_MODEL_URL = 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2'
16 |
17 |
18 | def image_align(src_file,
19 | dst_file,
20 | face_landmarks,
21 | output_size=1024,
22 | transform_size=4096,
23 | enable_padding=True):
24 | # Align function from FFHQ dataset pre-processing step
25 | # https://github.com/NVlabs/ffhq-dataset/blob/master/download_ffhq.py
26 |
27 | lm = np.array(face_landmarks)
28 | lm_chin = lm[0:17] # left-right
29 | lm_eyebrow_left = lm[17:22] # left-right
30 | lm_eyebrow_right = lm[22:27] # left-right
31 | lm_nose = lm[27:31] # top-down
32 | lm_nostrils = lm[31:36] # top-down
33 | lm_eye_left = lm[36:42] # left-clockwise
34 | lm_eye_right = lm[42:48] # left-clockwise
35 | lm_mouth_outer = lm[48:60] # left-clockwise
36 | lm_mouth_inner = lm[60:68] # left-clockwise
37 |
38 | # Calculate auxiliary vectors.
39 | eye_left = np.mean(lm_eye_left, axis=0)
40 | eye_right = np.mean(lm_eye_right, axis=0)
41 | eye_avg = (eye_left + eye_right) * 0.5
42 | eye_to_eye = eye_right - eye_left
43 | mouth_left = lm_mouth_outer[0]
44 | mouth_right = lm_mouth_outer[6]
45 | mouth_avg = (mouth_left + mouth_right) * 0.5
46 | eye_to_mouth = mouth_avg - eye_avg
47 |
48 | # Choose oriented crop rectangle.
49 | x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]
50 | x /= np.hypot(*x)
51 | x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)
52 | y = np.flipud(x) * [-1, 1]
53 | c = eye_avg + eye_to_mouth * 0.1
54 | quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])
55 | qsize = np.hypot(*x) * 2
56 |
57 | # Load in-the-wild image.
58 | if not os.path.isfile(src_file):
59 | print(
60 | '\nCannot find source image. Please run "--wilds" before "--align".'
61 | )
62 | return
63 | img = PIL.Image.open(src_file)
64 | img = img.convert('RGB')
65 |
66 | # Shrink.
67 | shrink = int(np.floor(qsize / output_size * 0.5))
68 | if shrink > 1:
69 | rsize = (int(np.rint(float(img.size[0]) / shrink)),
70 | int(np.rint(float(img.size[1]) / shrink)))
71 | img = img.resize(rsize, PIL.Image.ANTIALIAS)
72 | quad /= shrink
73 | qsize /= shrink
74 |
75 | # Crop.
76 | border = max(int(np.rint(qsize * 0.1)), 3)
77 | crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
78 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
79 | crop = (max(crop[0] - border, 0), max(crop[1] - border, 0),
80 | min(crop[2] + border,
81 | img.size[0]), min(crop[3] + border, img.size[1]))
82 | if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
83 | img = img.crop(crop)
84 | quad -= crop[0:2]
85 |
86 | # Pad.
87 | pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))),
88 | int(np.ceil(max(quad[:, 0]))), int(np.ceil(max(quad[:, 1]))))
89 | pad = (max(-pad[0] + border,
90 | 0), max(-pad[1] + border,
91 | 0), max(pad[2] - img.size[0] + border,
92 | 0), max(pad[3] - img.size[1] + border, 0))
93 | if enable_padding and max(pad) > border - 4:
94 | pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
95 | img = np.pad(np.float32(img),
96 | ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
97 | h, w, _ = img.shape
98 | y, x, _ = np.ogrid[:h, :w, :1]
99 | mask = np.maximum(
100 | 1.0 -
101 | np.minimum(np.float32(x) / pad[0],
102 | np.float32(w - 1 - x) / pad[2]), 1.0 -
103 | np.minimum(np.float32(y) / pad[1],
104 | np.float32(h - 1 - y) / pad[3]))
105 | blur = qsize * 0.02
106 | img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) -
107 | img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
108 | img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
109 | img = PIL.Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)),
110 | 'RGB')
111 | quad += pad[:2]
112 |
113 | # Transform.
114 | img = img.transform((transform_size, transform_size), PIL.Image.QUAD,
115 | (quad + 0.5).flatten(), PIL.Image.BILINEAR)
116 | if output_size < transform_size:
117 | img = img.resize((output_size, output_size), PIL.Image.ANTIALIAS)
118 |
119 | # Save aligned image.
120 | img.save(dst_file, 'PNG')
121 |
122 |
123 | class LandmarksDetector:
124 | def __init__(self, predictor_model_path):
125 | """
126 | :param predictor_model_path: path to shape_predictor_68_face_landmarks.dat file
127 | """
128 | self.detector = dlib.get_frontal_face_detector(
129 | ) # cnn_face_detection_model_v1 also can be used
130 | self.shape_predictor = dlib.shape_predictor(predictor_model_path)
131 |
132 | def get_landmarks(self, image):
133 | img = dlib.load_rgb_image(image)
134 | dets = self.detector(img, 1)
135 |
136 | for detection in dets:
137 | face_landmarks = [
138 | (item.x, item.y)
139 | for item in self.shape_predictor(img, detection).parts()
140 | ]
141 | yield face_landmarks
142 |
143 |
144 | def unpack_bz2(src_path):
145 | dst_path = src_path[:-4]
146 | if os.path.exists(dst_path):
147 | print('cached')
148 | return dst_path
149 | data = bz2.BZ2File(src_path).read()
150 | with open(dst_path, 'wb') as fp:
151 | fp.write(data)
152 | return dst_path
153 |
154 |
155 | def work_landmark(raw_img_path, img_name, face_landmarks):
156 | face_img_name = '%s.png' % (os.path.splitext(img_name)[0], )
157 | aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
158 | if os.path.exists(aligned_face_path):
159 | return
160 | image_align(raw_img_path,
161 | aligned_face_path,
162 | face_landmarks,
163 | output_size=256)
164 |
165 |
166 | def get_file(src, tgt):
167 | if os.path.exists(tgt):
168 | print('cached')
169 | return tgt
170 | tgt_dir = os.path.dirname(tgt)
171 | if not os.path.exists(tgt_dir):
172 | os.makedirs(tgt_dir)
173 | file = requests.get(src)
174 | open(tgt, 'wb').write(file.content)
175 | return tgt
176 |
177 |
178 | if __name__ == "__main__":
179 | """
180 |     Extracts and aligns all faces from images using dlib and a function from the original FFHQ dataset preparation step.
181 |     Usage: python align.py -i <input_imgs_path> -o <output_imgs_path>
182 | """
183 | parser = ArgumentParser()
184 | parser.add_argument("-i",
185 | "--input_imgs_path",
186 | type=str,
187 | default="imgs",
188 | help="input images directory path")
189 | parser.add_argument("-o",
190 | "--output_imgs_path",
191 | type=str,
192 | default="imgs_align",
193 | help="output images directory path")
194 |
195 | args = parser.parse_args()
196 |
197 | # takes very long time ...
198 | landmarks_model_path = unpack_bz2(
199 | get_file(
200 | 'http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2',
201 | 'temp/shape_predictor_68_face_landmarks.dat.bz2'))
202 |
203 | # RAW_IMAGES_DIR = sys.argv[1]
204 | # ALIGNED_IMAGES_DIR = sys.argv[2]
205 | RAW_IMAGES_DIR = args.input_imgs_path
206 | ALIGNED_IMAGES_DIR = args.output_imgs_path
207 |
208 | if not osp.exists(ALIGNED_IMAGES_DIR): os.makedirs(ALIGNED_IMAGES_DIR)
209 |
210 | files = os.listdir(RAW_IMAGES_DIR)
211 | print(f'total img files {len(files)}')
212 | with tqdm(total=len(files)) as progress:
213 |
214 | def cb(*args):
215 | # print('update')
216 | progress.update()
217 |
218 | def err_cb(e):
219 | print('error:', e)
220 |
221 | with Pool(8) as pool:
222 | res = []
223 | landmarks_detector = LandmarksDetector(landmarks_model_path)
224 | for img_name in files:
225 | raw_img_path = os.path.join(RAW_IMAGES_DIR, img_name)
226 | # print('img_name:', img_name)
227 | for i, face_landmarks in enumerate(
228 | landmarks_detector.get_landmarks(raw_img_path),
229 | start=1):
230 | # assert i == 1, f'{i}'
231 | # print(i, face_landmarks)
232 | # face_img_name = '%s_%02d.png' % (os.path.splitext(img_name)[0], i)
233 | # aligned_face_path = os.path.join(ALIGNED_IMAGES_DIR, face_img_name)
234 | # image_align(raw_img_path, aligned_face_path, face_landmarks, output_size=256)
235 |
236 | work_landmark(raw_img_path, img_name, face_landmarks)
237 | progress.update()
238 |
239 | # job = pool.apply_async(
240 | # work_landmark,
241 | # (raw_img_path, img_name, face_landmarks),
242 | # callback=cb,
243 | # error_callback=err_cb,
244 | # )
245 | # res.append(job)
246 |
247 | # pool.close()
248 | # pool.join()
249 | print(f"output aligned images at: {ALIGNED_IMAGES_DIR}")
250 |
--------------------------------------------------------------------------------
/choices.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 | from torch import nn
3 |
4 |
5 | class TrainMode(Enum):
6 | # manipulate mode = training the classifier
7 | manipulate = 'manipulate'
8 |     # default training mode!
9 | diffusion = 'diffusion'
10 | # default latent training mode!
11 |     # fitting a DDPM to a given latent
12 | latent_diffusion = 'latentdiffusion'
13 |
14 | def is_manipulate(self):
15 | return self in [
16 | TrainMode.manipulate,
17 | ]
18 |
19 | def is_diffusion(self):
20 | return self in [
21 | TrainMode.diffusion,
22 | TrainMode.latent_diffusion,
23 | ]
24 |
25 | def is_autoenc(self):
26 | # the network possibly does autoencoding
27 | return self in [
28 | TrainMode.diffusion,
29 | ]
30 |
31 | def is_latent_diffusion(self):
32 | return self in [
33 | TrainMode.latent_diffusion,
34 | ]
35 |
36 | def use_latent_net(self):
37 | return self.is_latent_diffusion()
38 |
39 | def require_dataset_infer(self):
40 | """
41 | whether training in this mode requires the latent variables to be available?
42 | """
43 | # this will precalculate all the latents before hand
44 | # and the dataset will be all the predicted latents
45 | return self in [
46 | TrainMode.latent_diffusion,
47 | TrainMode.manipulate,
48 | ]
49 |
50 |
51 | class ManipulateMode(Enum):
52 | """
53 | how to train the classifier to manipulate
54 | """
55 | # train on whole celeba attr dataset
56 | celebahq_all = 'celebahq_all'
57 | # celeba with D2C's crop
58 | d2c_fewshot = 'd2cfewshot'
59 | d2c_fewshot_allneg = 'd2cfewshotallneg'
60 |
61 | def is_celeba_attr(self):
62 | return self in [
63 | ManipulateMode.d2c_fewshot,
64 | ManipulateMode.d2c_fewshot_allneg,
65 | ManipulateMode.celebahq_all,
66 | ]
67 |
68 | def is_single_class(self):
69 | return self in [
70 | ManipulateMode.d2c_fewshot,
71 | ManipulateMode.d2c_fewshot_allneg,
72 | ]
73 |
74 | def is_fewshot(self):
75 | return self in [
76 | ManipulateMode.d2c_fewshot,
77 | ManipulateMode.d2c_fewshot_allneg,
78 | ]
79 |
80 | def is_fewshot_allneg(self):
81 | return self in [
82 | ManipulateMode.d2c_fewshot_allneg,
83 | ]
84 |
85 |
86 | class ModelType(Enum):
87 | """
88 | Kinds of the backbone models
89 | """
90 |
91 | # unconditional ddpm
92 | ddpm = 'ddpm'
93 | # autoencoding ddpm cannot do unconditional generation
94 | autoencoder = 'autoencoder'
95 |
96 | def has_autoenc(self):
97 | return self in [
98 | ModelType.autoencoder,
99 | ]
100 |
101 | def can_sample(self):
102 | return self in [ModelType.ddpm]
103 |
104 |
105 | class ModelName(Enum):
106 | """
107 | List of all supported model classes
108 | """
109 |
110 | beatgans_ddpm = 'beatgans_ddpm'
111 | beatgans_autoenc = 'beatgans_autoenc'
112 |
113 |
114 | class ModelMeanType(Enum):
115 | """
116 | Which type of output the model predicts.
117 | """
118 |
119 | eps = 'eps' # the model predicts epsilon
120 |
121 |
122 | class ModelVarType(Enum):
123 | """
124 | What is used as the model's output variance.
125 |
126 |     Only fixed (non-learned) variance options are kept here.
128 | """
129 |
130 | # posterior beta_t
131 | fixed_small = 'fixed_small'
132 | # beta_t
133 | fixed_large = 'fixed_large'
134 |
135 |
136 | class LossType(Enum):
137 | mse = 'mse' # use raw MSE loss (and KL when learning variances)
138 | l1 = 'l1'
139 |
140 |
141 | class GenerativeType(Enum):
142 | """
143 |     How a sample is generated
144 | """
145 |
146 | ddpm = 'ddpm'
147 | ddim = 'ddim'
148 |
149 |
150 | class OptimizerType(Enum):
151 | adam = 'adam'
152 | adamw = 'adamw'
153 |
154 |
155 | class Activation(Enum):
156 | none = 'none'
157 | relu = 'relu'
158 | lrelu = 'lrelu'
159 | silu = 'silu'
160 | tanh = 'tanh'
161 |
162 | def get_act(self):
163 | if self == Activation.none:
164 | return nn.Identity()
165 | elif self == Activation.relu:
166 | return nn.ReLU()
167 | elif self == Activation.lrelu:
168 | return nn.LeakyReLU(negative_slope=0.2)
169 | elif self == Activation.silu:
170 | return nn.SiLU()
171 | elif self == Activation.tanh:
172 | return nn.Tanh()
173 | else:
174 | raise NotImplementedError()
175 |
176 |
177 | class ManipulateLossType(Enum):
178 | bce = 'bce'
179 | mse = 'mse'
--------------------------------------------------------------------------------
/cog.yaml:
--------------------------------------------------------------------------------
1 | build:
2 | cuda: "10.2"
3 | gpu: true
4 | python_version: "3.8"
5 | system_packages:
6 | - "libgl1-mesa-glx"
7 | - "libglib2.0-0"
8 | python_packages:
9 | - "numpy==1.21.5"
10 | - "cmake==3.23.3"
11 | - "ipython==7.21.0"
12 | - "opencv-python==4.5.4.58"
13 | - "pandas==1.1.5"
14 | - "lmdb==1.2.1"
15 | - "lpips==0.1.4"
16 | - "pytorch-fid==0.2.0"
17 | - "ftfy==6.1.1"
18 | - "scipy==1.5.4"
19 | - "torch==1.9.1"
20 | - "torchvision==0.10.1"
21 | - "tqdm==4.62.3"
22 | - "regex==2022.7.25"
23 | - "Pillow==9.2.0"
24 | - "pytorch_lightning==1.7.0"
25 |
26 | run:
27 | - pip install dlib
28 |
29 | predict: "predict.py:Predictor"
30 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | from model.unet import ScaleAt
2 | from model.latentnet import *
3 | from diffusion.resample import UniformSampler
4 | from diffusion.diffusion import space_timesteps
5 | from typing import Tuple
6 |
7 | from torch.utils.data import DataLoader
8 |
9 | from config_base import BaseConfig
10 | from dataset import *
11 | from diffusion import *
12 | from diffusion.base import GenerativeType, LossType, ModelMeanType, ModelVarType, get_named_beta_schedule
13 | from model import *
14 | from choices import *
15 | from multiprocessing import get_context
16 | import os
17 | from dataset_util import *
18 | from torch.utils.data.distributed import DistributedSampler
19 |
20 | data_paths = {
21 | 'ffhqlmdb256':
22 | os.path.expanduser('datasets/ffhq256.lmdb'),
23 | # used for training a classifier
24 | 'celeba':
25 | os.path.expanduser('datasets/celeba'),
26 | # used for training DPM models
27 | 'celebalmdb':
28 | os.path.expanduser('datasets/celeba.lmdb'),
29 | 'celebahq':
30 | os.path.expanduser('datasets/celebahq256.lmdb'),
31 | 'horse256':
32 | os.path.expanduser('datasets/horse256.lmdb'),
33 | 'bedroom256':
34 | os.path.expanduser('datasets/bedroom256.lmdb'),
35 | 'celeba_anno':
36 | os.path.expanduser('datasets/celeba_anno/list_attr_celeba.txt'),
37 | 'celebahq_anno':
38 | os.path.expanduser(
39 | 'datasets/celeba_anno/CelebAMask-HQ-attribute-anno.txt'),
40 | 'celeba_relight':
41 | os.path.expanduser('datasets/celeba_hq_light/celeba_light.txt'),
42 | }
43 |
44 |
45 | @dataclass
46 | class PretrainConfig(BaseConfig):
47 | name: str
48 | path: str
49 |
50 |
51 | @dataclass
52 | class TrainConfig(BaseConfig):
53 | # random seed
54 | seed: int = 0
55 | train_mode: TrainMode = TrainMode.diffusion
56 | train_cond0_prob: float = 0
57 | train_pred_xstart_detach: bool = True
58 | train_interpolate_prob: float = 0
59 | train_interpolate_img: bool = False
60 | manipulate_mode: ManipulateMode = ManipulateMode.celebahq_all
61 | manipulate_cls: str = None
62 | manipulate_shots: int = None
63 | manipulate_loss: ManipulateLossType = ManipulateLossType.bce
64 | manipulate_znormalize: bool = False
65 | manipulate_seed: int = 0
66 | accum_batches: int = 1
67 | autoenc_mid_attn: bool = True
68 | batch_size: int = 16
69 | batch_size_eval: int = None
70 | beatgans_gen_type: GenerativeType = GenerativeType.ddim
71 | beatgans_loss_type: LossType = LossType.mse
72 | beatgans_model_mean_type: ModelMeanType = ModelMeanType.eps
73 | beatgans_model_var_type: ModelVarType = ModelVarType.fixed_large
74 | beatgans_rescale_timesteps: bool = False
75 | latent_infer_path: str = None
76 | latent_znormalize: bool = False
77 | latent_gen_type: GenerativeType = GenerativeType.ddim
78 | latent_loss_type: LossType = LossType.mse
79 | latent_model_mean_type: ModelMeanType = ModelMeanType.eps
80 | latent_model_var_type: ModelVarType = ModelVarType.fixed_large
81 | latent_rescale_timesteps: bool = False
82 | latent_T_eval: int = 1_000
83 | latent_clip_sample: bool = False
84 | latent_beta_scheduler: str = 'linear'
85 | beta_scheduler: str = 'linear'
86 | data_name: str = ''
87 | data_val_name: str = None
88 | diffusion_type: str = None
89 | dropout: float = 0.1
90 | ema_decay: float = 0.9999
91 | eval_num_images: int = 5_000
92 | eval_every_samples: int = 200_000
93 | eval_ema_every_samples: int = 200_000
94 | fid_use_torch: bool = True
95 | fp16: bool = False
96 | grad_clip: float = 1
97 | img_size: int = 64
98 | lr: float = 0.0001
99 | optimizer: OptimizerType = OptimizerType.adam
100 | weight_decay: float = 0
101 | model_conf: ModelConfig = None
102 | model_name: ModelName = None
103 | model_type: ModelType = None
104 | net_attn: Tuple[int] = None
105 | net_beatgans_attn_head: int = 1
106 |     # not necessarily the same as the number of style channels
107 | net_beatgans_embed_channels: int = 512
108 | net_resblock_updown: bool = True
109 | net_enc_use_time: bool = False
110 | net_enc_pool: str = 'adaptivenonzero'
111 | net_beatgans_gradient_checkpoint: bool = False
112 | net_beatgans_resnet_two_cond: bool = False
113 | net_beatgans_resnet_use_zero_module: bool = True
114 | net_beatgans_resnet_scale_at: ScaleAt = ScaleAt.after_norm
115 | net_beatgans_resnet_cond_channels: int = None
116 | net_ch_mult: Tuple[int] = None
117 | net_ch: int = 64
118 | net_enc_attn: Tuple[int] = None
119 | net_enc_k: int = None
120 | # number of resblocks for the encoder (half-unet)
121 | net_enc_num_res_blocks: int = 2
122 | net_enc_channel_mult: Tuple[int] = None
123 | net_enc_grad_checkpoint: bool = False
124 | net_autoenc_stochastic: bool = False
125 | net_latent_activation: Activation = Activation.silu
126 | net_latent_channel_mult: Tuple[int] = (1, 2, 4)
127 | net_latent_condition_bias: float = 0
128 | net_latent_dropout: float = 0
129 | net_latent_layers: int = None
130 | net_latent_net_last_act: Activation = Activation.none
131 | net_latent_net_type: LatentNetType = LatentNetType.none
132 | net_latent_num_hid_channels: int = 1024
133 | net_latent_num_time_layers: int = 2
134 | net_latent_skip_layers: Tuple[int] = None
135 | net_latent_time_emb_channels: int = 64
136 | net_latent_use_norm: bool = False
137 | net_latent_time_last_act: bool = False
138 | net_num_res_blocks: int = 2
139 | # number of resblocks for the UNET
140 | net_num_input_res_blocks: int = None
141 | net_enc_num_cls: int = None
142 | num_workers: int = 4
143 | parallel: bool = False
144 | postfix: str = ''
145 | sample_size: int = 64
146 | sample_every_samples: int = 20_000
147 | save_every_samples: int = 100_000
148 | style_ch: int = 512
149 | T_eval: int = 1_000
150 | T_sampler: str = 'uniform'
151 | T: int = 1_000
152 | total_samples: int = 10_000_000
153 | warmup: int = 0
154 | pretrain: PretrainConfig = None
155 | continue_from: PretrainConfig = None
156 | eval_programs: Tuple[str] = None
157 | # if present load the checkpoint from this path instead
158 | eval_path: str = None
159 | base_dir: str = 'checkpoints'
160 | use_cache_dataset: bool = False
161 | data_cache_dir: str = os.path.expanduser('~/cache')
162 | work_cache_dir: str = os.path.expanduser('~/mycache')
163 | # to be overridden
164 | name: str = ''
165 |
166 | def __post_init__(self):
167 | self.batch_size_eval = self.batch_size_eval or self.batch_size
168 | self.data_val_name = self.data_val_name or self.data_name
169 |
170 | def scale_up_gpus(self, num_gpus, num_nodes=1):
171 | self.eval_ema_every_samples *= num_gpus * num_nodes
172 | self.eval_every_samples *= num_gpus * num_nodes
173 | self.sample_every_samples *= num_gpus * num_nodes
174 | self.batch_size *= num_gpus * num_nodes
175 | self.batch_size_eval *= num_gpus * num_nodes
176 | return self
177 |
178 | @property
179 | def batch_size_effective(self):
180 | return self.batch_size * self.accum_batches
181 |
182 | @property
183 | def fid_cache(self):
184 | # we try to use the local dirs to reduce the load over network drives
185 | # hopefully, this would reduce the disconnection problems with sshfs
186 | return f'{self.work_cache_dir}/eval_images/{self.data_name}_size{self.img_size}_{self.eval_num_images}'
187 |
188 | @property
189 | def data_path(self):
190 | # may use the cache dir
191 | path = data_paths[self.data_name]
192 | if self.use_cache_dataset and path is not None:
193 | path = use_cached_dataset_path(
194 | path, f'{self.data_cache_dir}/{self.data_name}')
195 | return path
196 |
197 | @property
198 | def logdir(self):
199 | return f'{self.base_dir}/{self.name}'
200 |
201 | @property
202 | def generate_dir(self):
203 | # we try to use the local dirs to reduce the load over network drives
204 | # hopefully, this would reduce the disconnection problems with sshfs
205 | return f'{self.work_cache_dir}/gen_images/{self.name}'
206 |
207 | def _make_diffusion_conf(self, T=None):
208 | if self.diffusion_type == 'beatgans':
209 | # can use T < self.T for evaluation
210 | # follows the guided-diffusion repo conventions
211 | # t's are evenly spaced
212 | if self.beatgans_gen_type == GenerativeType.ddpm:
213 | section_counts = [T]
214 | elif self.beatgans_gen_type == GenerativeType.ddim:
215 | section_counts = f'ddim{T}'
216 | else:
217 | raise NotImplementedError()
218 |
219 | return SpacedDiffusionBeatGansConfig(
220 | gen_type=self.beatgans_gen_type,
221 | model_type=self.model_type,
222 | betas=get_named_beta_schedule(self.beta_scheduler, self.T),
223 | model_mean_type=self.beatgans_model_mean_type,
224 | model_var_type=self.beatgans_model_var_type,
225 | loss_type=self.beatgans_loss_type,
226 | rescale_timesteps=self.beatgans_rescale_timesteps,
227 | use_timesteps=space_timesteps(num_timesteps=self.T,
228 | section_counts=section_counts),
229 | fp16=self.fp16,
230 | )
231 | else:
232 | raise NotImplementedError()
233 |
234 | def _make_latent_diffusion_conf(self, T=None):
235 | # can use T < self.T for evaluation
236 | # follows the guided-diffusion repo conventions
237 | # t's are evenly spaced
238 | if self.latent_gen_type == GenerativeType.ddpm:
239 | section_counts = [T]
240 | elif self.latent_gen_type == GenerativeType.ddim:
241 | section_counts = f'ddim{T}'
242 | else:
243 | raise NotImplementedError()
244 |
245 | return SpacedDiffusionBeatGansConfig(
246 | train_pred_xstart_detach=self.train_pred_xstart_detach,
247 | gen_type=self.latent_gen_type,
248 | # latent's model is always ddpm
249 | model_type=ModelType.ddpm,
250 | # latent shares the beta scheduler and full T
251 | betas=get_named_beta_schedule(self.latent_beta_scheduler, self.T),
252 | model_mean_type=self.latent_model_mean_type,
253 | model_var_type=self.latent_model_var_type,
254 | loss_type=self.latent_loss_type,
255 | rescale_timesteps=self.latent_rescale_timesteps,
256 | use_timesteps=space_timesteps(num_timesteps=self.T,
257 | section_counts=section_counts),
258 | fp16=self.fp16,
259 | )
260 |
261 | @property
262 | def model_out_channels(self):
263 | return 3
264 |
265 | def make_T_sampler(self):
266 | if self.T_sampler == 'uniform':
267 | return UniformSampler(self.T)
268 | else:
269 | raise NotImplementedError()
270 |
271 | def make_diffusion_conf(self):
272 | return self._make_diffusion_conf(self.T)
273 |
274 | def make_eval_diffusion_conf(self):
275 | return self._make_diffusion_conf(T=self.T_eval)
276 |
277 | def make_latent_diffusion_conf(self):
278 | return self._make_latent_diffusion_conf(T=self.T)
279 |
280 | def make_latent_eval_diffusion_conf(self):
281 | # latent can have different eval T
282 | return self._make_latent_diffusion_conf(T=self.latent_T_eval)
283 |
284 | def make_dataset(self, path=None, **kwargs):
285 | if self.data_name == 'ffhqlmdb256':
286 | return FFHQlmdb(path=path or self.data_path,
287 | image_size=self.img_size,
288 | **kwargs)
289 | elif self.data_name == 'horse256':
290 | return Horse_lmdb(path=path or self.data_path,
291 | image_size=self.img_size,
292 | **kwargs)
293 | elif self.data_name == 'bedroom256':
294 | return Horse_lmdb(path=path or self.data_path,
295 | image_size=self.img_size,
296 | **kwargs)
297 | elif self.data_name == 'celebalmdb':
298 | # always use d2c crop
299 | return CelebAlmdb(path=path or self.data_path,
300 | image_size=self.img_size,
301 | original_resolution=None,
302 | crop_d2c=True,
303 | **kwargs)
304 | else:
305 | raise NotImplementedError()
306 |
307 | def make_loader(self,
308 | dataset,
309 | shuffle: bool,
310 | num_worker: bool = None,
311 | drop_last: bool = True,
312 | batch_size: int = None,
313 | parallel: bool = False):
314 | if parallel and distributed.is_initialized():
315 |             # drop last to make sure that there are no added special indexes
316 | sampler = DistributedSampler(dataset,
317 | shuffle=shuffle,
318 | drop_last=True)
319 | else:
320 | sampler = None
321 | return DataLoader(
322 | dataset,
323 | batch_size=batch_size or self.batch_size,
324 | sampler=sampler,
325 |             # with a sampler, shuffling is handled by the sampler rather than this option
326 | shuffle=False if sampler else shuffle,
327 | num_workers=num_worker or self.num_workers,
328 | pin_memory=True,
329 | drop_last=drop_last,
330 | multiprocessing_context=get_context('fork'),
331 | )
332 |
333 | def make_model_conf(self):
334 | if self.model_name == ModelName.beatgans_ddpm:
335 | self.model_type = ModelType.ddpm
336 | self.model_conf = BeatGANsUNetConfig(
337 | attention_resolutions=self.net_attn,
338 | channel_mult=self.net_ch_mult,
339 | conv_resample=True,
340 | dims=2,
341 | dropout=self.dropout,
342 | embed_channels=self.net_beatgans_embed_channels,
343 | image_size=self.img_size,
344 | in_channels=3,
345 | model_channels=self.net_ch,
346 | num_classes=None,
347 | num_head_channels=-1,
348 | num_heads_upsample=-1,
349 | num_heads=self.net_beatgans_attn_head,
350 | num_res_blocks=self.net_num_res_blocks,
351 | num_input_res_blocks=self.net_num_input_res_blocks,
352 | out_channels=self.model_out_channels,
353 | resblock_updown=self.net_resblock_updown,
354 | use_checkpoint=self.net_beatgans_gradient_checkpoint,
355 | use_new_attention_order=False,
356 | resnet_two_cond=self.net_beatgans_resnet_two_cond,
357 | resnet_use_zero_module=self.
358 | net_beatgans_resnet_use_zero_module,
359 | )
360 | elif self.model_name in [
361 | ModelName.beatgans_autoenc,
362 | ]:
363 | cls = BeatGANsAutoencConfig
364 | # supports both autoenc and vaeddpm
365 | if self.model_name == ModelName.beatgans_autoenc:
366 | self.model_type = ModelType.autoencoder
367 | else:
368 | raise NotImplementedError()
369 |
370 | if self.net_latent_net_type == LatentNetType.none:
371 | latent_net_conf = None
372 | elif self.net_latent_net_type == LatentNetType.skip:
373 | latent_net_conf = MLPSkipNetConfig(
374 | num_channels=self.style_ch,
375 | skip_layers=self.net_latent_skip_layers,
376 | num_hid_channels=self.net_latent_num_hid_channels,
377 | num_layers=self.net_latent_layers,
378 | num_time_emb_channels=self.net_latent_time_emb_channels,
379 | activation=self.net_latent_activation,
380 | use_norm=self.net_latent_use_norm,
381 | condition_bias=self.net_latent_condition_bias,
382 | dropout=self.net_latent_dropout,
383 | last_act=self.net_latent_net_last_act,
384 | num_time_layers=self.net_latent_num_time_layers,
385 | time_last_act=self.net_latent_time_last_act,
386 | )
387 | else:
388 | raise NotImplementedError()
389 |
390 | self.model_conf = cls(
391 | attention_resolutions=self.net_attn,
392 | channel_mult=self.net_ch_mult,
393 | conv_resample=True,
394 | dims=2,
395 | dropout=self.dropout,
396 | embed_channels=self.net_beatgans_embed_channels,
397 | enc_out_channels=self.style_ch,
398 | enc_pool=self.net_enc_pool,
399 | enc_num_res_block=self.net_enc_num_res_blocks,
400 | enc_channel_mult=self.net_enc_channel_mult,
401 | enc_grad_checkpoint=self.net_enc_grad_checkpoint,
402 | enc_attn_resolutions=self.net_enc_attn,
403 | image_size=self.img_size,
404 | in_channels=3,
405 | model_channels=self.net_ch,
406 | num_classes=None,
407 | num_head_channels=-1,
408 | num_heads_upsample=-1,
409 | num_heads=self.net_beatgans_attn_head,
410 | num_res_blocks=self.net_num_res_blocks,
411 | num_input_res_blocks=self.net_num_input_res_blocks,
412 | out_channels=self.model_out_channels,
413 | resblock_updown=self.net_resblock_updown,
414 | use_checkpoint=self.net_beatgans_gradient_checkpoint,
415 | use_new_attention_order=False,
416 | resnet_two_cond=self.net_beatgans_resnet_two_cond,
417 | resnet_use_zero_module=self.
418 | net_beatgans_resnet_use_zero_module,
419 | latent_net_conf=latent_net_conf,
420 | resnet_cond_channels=self.net_beatgans_resnet_cond_channels,
421 | )
422 | else:
423 | raise NotImplementedError(self.model_name)
424 |
425 | return self.model_conf
426 |
--------------------------------------------------------------------------------
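
A minimal sketch of how `TrainConfig.scale_up_gpus` and `batch_size_effective` interact, using only the default fields defined in `config.py` above:

```python
from config import TrainConfig

# a minimal sketch, assuming only the TrainConfig defaults shown above
conf = TrainConfig(batch_size=16, accum_batches=2)
conf.scale_up_gpus(num_gpus=4)  # scales batch sizes and eval/sample intervals by 4

assert conf.batch_size == 64
assert conf.batch_size_eval == 64        # set to batch_size in __post_init__, then scaled
assert conf.batch_size_effective == 128  # batch_size * accum_batches
```
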
/config_base.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from copy import deepcopy
4 | from dataclasses import dataclass
5 |
6 |
7 | @dataclass
8 | class BaseConfig:
9 | def clone(self):
10 | return deepcopy(self)
11 |
12 | def inherit(self, another):
13 | """inherit common keys from a given config"""
14 | common_keys = set(self.__dict__.keys()) & set(another.__dict__.keys())
15 | for k in common_keys:
16 | setattr(self, k, getattr(another, k))
17 |
18 | def propagate(self):
19 | """push down the configuration to all members"""
20 | for k, v in self.__dict__.items():
21 | if isinstance(v, BaseConfig):
22 | v.inherit(self)
23 | v.propagate()
24 |
25 | def save(self, save_path):
26 | """save config to json file"""
27 | dirname = os.path.dirname(save_path)
28 | if not os.path.exists(dirname):
29 | os.makedirs(dirname)
30 | conf = self.as_dict_jsonable()
31 | with open(save_path, 'w') as f:
32 | json.dump(conf, f)
33 |
34 | def load(self, load_path):
35 | """load json config"""
36 | with open(load_path) as f:
37 | conf = json.load(f)
38 | self.from_dict(conf)
39 |
40 | def from_dict(self, dict, strict=False):
41 | for k, v in dict.items():
42 | if not hasattr(self, k):
43 | if strict:
44 | raise ValueError(f"loading extra '{k}'")
45 | else:
46 | print(f"loading extra '{k}'")
47 | continue
48 | if isinstance(self.__dict__[k], BaseConfig):
49 | self.__dict__[k].from_dict(v)
50 | else:
51 | self.__dict__[k] = v
52 |
53 | def as_dict_jsonable(self):
54 | conf = {}
55 | for k, v in self.__dict__.items():
56 | if isinstance(v, BaseConfig):
57 | conf[k] = v.as_dict_jsonable()
58 | else:
59 | if jsonable(v):
60 | conf[k] = v
61 | else:
62 | # ignore not jsonable
63 | pass
64 | return conf
65 |
66 |
67 | def jsonable(x):
68 | try:
69 | json.dumps(x)
70 | return True
71 | except TypeError:
72 | return False
73 |
--------------------------------------------------------------------------------
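
A minimal usage sketch of `BaseConfig.save`/`load` round-tripping, with a hypothetical `MyConfig` dataclass used only for illustration:

```python
from dataclasses import dataclass

from config_base import BaseConfig


# hypothetical config, only to illustrate save/load round-tripping
@dataclass
class MyConfig(BaseConfig):
    lr: float = 1e-4
    name: str = 'demo'


conf = MyConfig(lr=3e-4)
conf.save('checkpoints/demo/config.json')  # only jsonable fields are written

restored = MyConfig()
restored.load('checkpoints/demo/config.json')
assert restored.lr == 3e-4
```
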
/data_resize_bedroom.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing
3 | import os
4 | from os.path import join, exists
5 | from functools import partial
6 | from io import BytesIO
7 | import shutil
8 |
9 | import lmdb
10 | from PIL import Image
11 | from torchvision.datasets import LSUNClass
12 | from torchvision.transforms import functional as trans_fn
13 | from tqdm import tqdm
14 |
15 | from multiprocessing import Process, Queue
16 |
17 |
18 | def resize_and_convert(img, size, resample, quality=100):
19 | img = trans_fn.resize(img, size, resample)
20 | img = trans_fn.center_crop(img, size)
21 | buffer = BytesIO()
22 | img.save(buffer, format="webp", quality=quality)
23 | val = buffer.getvalue()
24 |
25 | return val
26 |
27 |
28 | def resize_multiple(img,
29 | sizes=(128, 256, 512, 1024),
30 | resample=Image.LANCZOS,
31 | quality=100):
32 | imgs = []
33 |
34 | for size in sizes:
35 | imgs.append(resize_and_convert(img, size, resample, quality))
36 |
37 | return imgs
38 |
39 |
40 | def resize_worker(idx, img, sizes, resample):
41 | img = img.convert("RGB")
42 | out = resize_multiple(img, sizes=sizes, resample=resample)
43 | return idx, out
44 |
45 |
46 | from torch.utils.data import Dataset, DataLoader
47 |
48 |
49 | class ConvertDataset(Dataset):
50 | def __init__(self, data) -> None:
51 | self.data = data
52 |
53 | def __len__(self):
54 | return len(self.data)
55 |
56 | def __getitem__(self, index):
57 | img, _ = self.data[index]
58 | bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90)
59 | return bytes
60 |
61 |
62 | if __name__ == "__main__":
63 | """
64 |     converting LSUN's original lmdb to our lmdb format, which is somewhat more performant.
65 | """
66 | from tqdm import tqdm
67 |
68 | # path to the original lsun's lmdb
69 | src_path = 'datasets/bedroom_train_lmdb'
70 | out_path = 'datasets/bedroom256.lmdb'
71 |
72 | dataset = LSUNClass(root=os.path.expanduser(src_path))
73 | dataset = ConvertDataset(dataset)
74 | loader = DataLoader(dataset,
75 | batch_size=50,
76 | num_workers=12,
77 | collate_fn=lambda x: x,
78 | shuffle=False)
79 |
80 | target = os.path.expanduser(out_path)
81 | if os.path.exists(target):
82 | shutil.rmtree(target)
83 |
84 | with lmdb.open(target, map_size=1024**4, readahead=False) as env:
85 | with tqdm(total=len(dataset)) as progress:
86 | i = 0
87 | for batch in loader:
88 | with env.begin(write=True) as txn:
89 | for img in batch:
90 | key = f"{256}-{str(i).zfill(7)}".encode("utf-8")
91 | # print(key)
92 | txn.put(key, img)
93 | i += 1
94 | progress.update()
95 | # if i == 1000:
96 | # break
97 | # if total == len(imgset):
98 | # break
99 |
100 | with env.begin(write=True) as txn:
101 | txn.put("length".encode("utf-8"), str(i).encode("utf-8"))
102 |
--------------------------------------------------------------------------------
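
A minimal sketch of reading one entry back from the resulting LMDB, assuming the key layout written above (`"256-<7-digit index>"` plus a `"length"` entry):

```python
import lmdb
from io import BytesIO
from PIL import Image

# read back the first image and the stored length from bedroom256.lmdb
with lmdb.open('datasets/bedroom256.lmdb', readonly=True, lock=False) as env:
    with env.begin(write=False) as txn:
        length = int(txn.get('length'.encode('utf-8')).decode('utf-8'))
        img_bytes = txn.get(f'256-{str(0).zfill(7)}'.encode('utf-8'))

img = Image.open(BytesIO(img_bytes))
print(length, img.size)
```
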
/data_resize_celeba.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing
3 | import os
4 | import shutil
5 | from functools import partial
6 | from io import BytesIO
7 | from multiprocessing import Process, Queue
8 | from os.path import exists, join
9 | from pathlib import Path
10 |
11 | import lmdb
12 | from PIL import Image
13 | from torch.utils.data import DataLoader, Dataset
14 | from torchvision.datasets import LSUNClass
15 | from torchvision.transforms import functional as trans_fn
16 | from tqdm import tqdm
17 |
18 |
19 | def resize_and_convert(img, size, resample, quality=100):
20 | if size is not None:
21 | img = trans_fn.resize(img, size, resample)
22 | img = trans_fn.center_crop(img, size)
23 |
24 | buffer = BytesIO()
25 | img.save(buffer, format="webp", quality=quality)
26 | val = buffer.getvalue()
27 |
28 | return val
29 |
30 |
31 | def resize_multiple(img,
32 | sizes=(128, 256, 512, 1024),
33 | resample=Image.LANCZOS,
34 | quality=100):
35 | imgs = []
36 |
37 | for size in sizes:
38 | imgs.append(resize_and_convert(img, size, resample, quality))
39 |
40 | return imgs
41 |
42 |
43 | def resize_worker(idx, img, sizes, resample):
44 | img = img.convert("RGB")
45 | out = resize_multiple(img, sizes=sizes, resample=resample)
46 | return idx, out
47 |
48 |
49 | class ConvertDataset(Dataset):
50 | def __init__(self, data, size) -> None:
51 | self.data = data
52 | self.size = size
53 |
54 | def __len__(self):
55 | return len(self.data)
56 |
57 | def __getitem__(self, index):
58 | img = self.data[index]
59 | bytes = resize_and_convert(img, self.size, Image.LANCZOS, quality=100)
60 | return bytes
61 |
62 |
63 | class ImageFolder(Dataset):
64 | def __init__(self, folder, ext='jpg'):
65 | super().__init__()
66 | paths = sorted([p for p in Path(f'{folder}').glob(f'*.{ext}')])
67 | self.paths = paths
68 |
69 | def __len__(self):
70 | return len(self.paths)
71 |
72 | def __getitem__(self, index):
73 | path = os.path.join(self.paths[index])
74 | img = Image.open(path)
75 | return img
76 |
77 |
78 | if __name__ == "__main__":
79 | from tqdm import tqdm
80 |
81 | out_path = 'datasets/celeba.lmdb'
82 | in_path = 'datasets/celeba'
83 | ext = 'jpg'
84 | size = None
85 |
86 | dataset = ImageFolder(in_path, ext)
87 | print('len:', len(dataset))
88 | dataset = ConvertDataset(dataset, size)
89 | loader = DataLoader(dataset,
90 | batch_size=50,
91 | num_workers=12,
92 | collate_fn=lambda x: x,
93 | shuffle=False)
94 |
95 | target = os.path.expanduser(out_path)
96 | if os.path.exists(target):
97 | shutil.rmtree(target)
98 |
99 | with lmdb.open(target, map_size=1024**4, readahead=False) as env:
100 | with tqdm(total=len(dataset)) as progress:
101 | i = 0
102 | for batch in loader:
103 | with env.begin(write=True) as txn:
104 | for img in batch:
105 | key = f"{size}-{str(i).zfill(7)}".encode("utf-8")
106 | # print(key)
107 | txn.put(key, img)
108 | i += 1
109 | progress.update()
110 | # if i == 1000:
111 | # break
112 | # if total == len(imgset):
113 | # break
114 |
115 | with env.begin(write=True) as txn:
116 | txn.put("length".encode("utf-8"), str(i).encode("utf-8"))
117 |
--------------------------------------------------------------------------------
/data_resize_celebahq.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing
3 | from functools import partial
4 | from io import BytesIO
5 | from pathlib import Path
6 |
7 | import lmdb
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import functional as trans_fn
11 | from tqdm import tqdm
12 | import os
13 |
14 |
15 | def resize_and_convert(img, size, resample, quality=100):
16 | img = trans_fn.resize(img, size, resample)
17 | img = trans_fn.center_crop(img, size)
18 | buffer = BytesIO()
19 | img.save(buffer, format="jpeg", quality=quality)
20 | val = buffer.getvalue()
21 |
22 | return val
23 |
24 |
25 | def resize_multiple(img,
26 | sizes=(128, 256, 512, 1024),
27 | resample=Image.LANCZOS,
28 | quality=100):
29 | imgs = []
30 |
31 | for size in sizes:
32 | imgs.append(resize_and_convert(img, size, resample, quality))
33 |
34 | return imgs
35 |
36 |
37 | def resize_worker(img_file, sizes, resample):
38 | i, (file, idx) = img_file
39 | img = Image.open(file)
40 | img = img.convert("RGB")
41 | out = resize_multiple(img, sizes=sizes, resample=resample)
42 |
43 | return i, idx, out
44 |
45 |
46 | def prepare(env,
47 | paths,
48 | n_worker,
49 | sizes=(128, 256, 512, 1024),
50 | resample=Image.LANCZOS):
51 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
52 |
53 | # index = filename in int
54 | indexs = []
55 | for each in paths:
56 | file = os.path.basename(each)
57 | name, ext = file.split('.')
58 | idx = int(name)
59 | indexs.append(idx)
60 |
61 | # sort by file index
62 | files = sorted(zip(paths, indexs), key=lambda x: x[1])
63 | files = list(enumerate(files))
64 | total = 0
65 |
66 | with multiprocessing.Pool(n_worker) as pool:
67 | for i, idx, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
68 | for size, img in zip(sizes, imgs):
69 | key = f"{size}-{str(idx).zfill(5)}".encode("utf-8")
70 |
71 | with env.begin(write=True) as txn:
72 | txn.put(key, img)
73 |
74 | total += 1
75 |
76 | with env.begin(write=True) as txn:
77 | txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
78 |
79 |
80 | class ImageFolder(Dataset):
81 | def __init__(self, folder, exts=['jpg']):
82 | super().__init__()
83 | self.paths = [
84 | p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')
85 | ]
86 |
87 | def __len__(self):
88 | return len(self.paths)
89 |
90 | def __getitem__(self, index):
91 |         path = self.paths[index]  # glob already yields full paths; self.folder is never set
92 | img = Image.open(path)
93 | return img
94 |
95 |
96 | if __name__ == "__main__":
97 | """
98 | converting celebahq images to lmdb
99 | """
100 | num_workers = 16
101 | in_path = 'datasets/celebahq'
102 | out_path = 'datasets/celebahq256.lmdb'
103 |
104 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
105 | resample = resample_map['lanczos']
106 |
107 | sizes = [256]
108 |
109 | print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
110 |
111 | # imgset = datasets.ImageFolder(in_path)
112 | # imgset = ImageFolder(in_path)
113 | exts = ['jpg']
114 | paths = [p for ext in exts for p in Path(f'{in_path}').glob(f'**/*.{ext}')]
115 |
116 | with lmdb.open(out_path, map_size=1024**4, readahead=False) as env:
117 | prepare(env, paths, num_workers, sizes=sizes, resample=resample)
118 |
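For reference, a sketch (not part of the script) of the key scheme used by prepare() above: each record is keyed by image size and the integer parsed from the filename, zero-padded to 5 digits, plus a global "length" counter.

    size, idx = 256, 3
    key = f"{size}-{str(idx).zfill(5)}".encode("utf-8")
    assert key == b"256-00003"  # e.g. the 256px copy of 00003.jpg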
--------------------------------------------------------------------------------
/data_resize_ffhq.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing
3 | from functools import partial
4 | from io import BytesIO
5 | from pathlib import Path
6 |
7 | import lmdb
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 | from torchvision.transforms import functional as trans_fn
11 | from tqdm import tqdm
12 | import os
13 |
14 |
15 | def resize_and_convert(img, size, resample, quality=100):
16 | img = trans_fn.resize(img, size, resample)
17 | img = trans_fn.center_crop(img, size)
18 | buffer = BytesIO()
19 | img.save(buffer, format="jpeg", quality=quality)
20 | val = buffer.getvalue()
21 |
22 | return val
23 |
24 |
25 | def resize_multiple(img,
26 | sizes=(128, 256, 512, 1024),
27 | resample=Image.LANCZOS,
28 | quality=100):
29 | imgs = []
30 |
31 | for size in sizes:
32 | imgs.append(resize_and_convert(img, size, resample, quality))
33 |
34 | return imgs
35 |
36 |
37 | def resize_worker(img_file, sizes, resample):
38 | i, (file, idx) = img_file
39 | img = Image.open(file)
40 | img = img.convert("RGB")
41 | out = resize_multiple(img, sizes=sizes, resample=resample)
42 |
43 | return i, idx, out
44 |
45 |
46 | def prepare(env,
47 | paths,
48 | n_worker,
49 | sizes=(128, 256, 512, 1024),
50 | resample=Image.LANCZOS):
51 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample)
52 |
53 |     # index = filename parsed as an int
54 | indexs = []
55 | for each in paths:
56 | file = os.path.basename(each)
57 | name, ext = file.split('.')
58 | idx = int(name)
59 | indexs.append(idx)
60 |
61 | # sort by file index
62 | files = sorted(zip(paths, indexs), key=lambda x: x[1])
63 | files = list(enumerate(files))
64 | total = 0
65 |
66 | with multiprocessing.Pool(n_worker) as pool:
67 | for i, idx, imgs in tqdm(pool.imap_unordered(resize_fn, files)):
68 | for size, img in zip(sizes, imgs):
69 | key = f"{size}-{str(idx).zfill(5)}".encode("utf-8")
70 |
71 | with env.begin(write=True) as txn:
72 | txn.put(key, img)
73 |
74 | total += 1
75 |
76 | with env.begin(write=True) as txn:
77 | txn.put("length".encode("utf-8"), str(total).encode("utf-8"))
78 |
79 |
80 | class ImageFolder(Dataset):
81 | def __init__(self, folder, exts=['jpg']):
82 | super().__init__()
83 | self.paths = [
84 | p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')
85 | ]
86 |
87 | def __len__(self):
88 | return len(self.paths)
89 |
90 | def __getitem__(self, index):
91 |         path = self.paths[index]  # glob already yields full paths; self.folder is never set
92 | img = Image.open(path)
93 | return img
94 |
95 |
96 | if __name__ == "__main__":
97 | """
98 | converting ffhq images to lmdb
99 | """
100 | num_workers = 16
101 | # original ffhq data path
102 | in_path = 'datasets/ffhq'
103 | # target output path
104 | out_path = 'datasets/ffhq.lmdb'
105 |
106 | if not os.path.exists(out_path):
107 | os.makedirs(out_path)
108 |
109 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR}
110 | resample = resample_map['lanczos']
111 |
112 | sizes = [256]
113 |
114 | print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes))
115 |
116 | # imgset = datasets.ImageFolder(in_path)
117 | # imgset = ImageFolder(in_path)
118 | exts = ['jpg']
119 | paths = [p for ext in exts for p in Path(f'{in_path}').glob(f'**/*.{ext}')]
120 | # print(paths[:10])
121 |
122 | with lmdb.open(out_path, map_size=1024**4, readahead=False) as env:
123 | prepare(env, paths, num_workers, sizes=sizes, resample=resample)
124 |
--------------------------------------------------------------------------------
/data_resize_horse.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing
3 | import os
4 | import shutil
5 | from functools import partial
6 | from io import BytesIO
7 | from multiprocessing import Process, Queue
8 | from os.path import exists, join
9 |
10 | import lmdb
11 | from PIL import Image
12 | from torch.utils.data import DataLoader, Dataset
13 | from torchvision.datasets import LSUNClass
14 | from torchvision.transforms import functional as trans_fn
15 | from tqdm import tqdm
16 |
17 |
18 | def resize_and_convert(img, size, resample, quality=100):
19 | img = trans_fn.resize(img, size, resample)
20 | img = trans_fn.center_crop(img, size)
21 | buffer = BytesIO()
22 | img.save(buffer, format="webp", quality=quality)
23 | val = buffer.getvalue()
24 |
25 | return val
26 |
27 |
28 | def resize_multiple(img,
29 | sizes=(128, 256, 512, 1024),
30 | resample=Image.LANCZOS,
31 | quality=100):
32 | imgs = []
33 |
34 | for size in sizes:
35 | imgs.append(resize_and_convert(img, size, resample, quality))
36 |
37 | return imgs
38 |
39 |
40 | def resize_worker(idx, img, sizes, resample):
41 | img = img.convert("RGB")
42 | out = resize_multiple(img, sizes=sizes, resample=resample)
43 | return idx, out
44 |
45 |
46 | class ConvertDataset(Dataset):
47 | def __init__(self, data) -> None:
48 | self.data = data
49 |
50 | def __len__(self):
51 | return len(self.data)
52 |
53 | def __getitem__(self, index):
54 | img, _ = self.data[index]
55 | bytes = resize_and_convert(img, 256, Image.LANCZOS, quality=90)
56 | return bytes
57 |
58 |
59 | if __name__ == "__main__":
60 | """
61 |     converting LSUN's original LMDB to our own LMDB format, which is more performant in practice.
62 | """
63 | from tqdm import tqdm
64 |
65 | # path to the original lsun's lmdb
66 | src_path = 'datasets/horse_train_lmdb'
67 | out_path = 'datasets/horse256.lmdb'
68 |
69 | dataset = LSUNClass(root=os.path.expanduser(src_path))
70 | dataset = ConvertDataset(dataset)
71 | loader = DataLoader(dataset,
72 | batch_size=50,
73 | num_workers=16,
74 | collate_fn=lambda x: x)
75 |
76 | target = os.path.expanduser(out_path)
77 | if os.path.exists(target):
78 | shutil.rmtree(target)
79 |
80 | with lmdb.open(target, map_size=1024**4, readahead=False) as env:
81 | with tqdm(total=len(dataset)) as progress:
82 | i = 0
83 | for batch in loader:
84 | with env.begin(write=True) as txn:
85 | for img in batch:
86 | key = f"{256}-{str(i).zfill(7)}".encode("utf-8")
87 | # print(key)
88 | txn.put(key, img)
89 | i += 1
90 | progress.update()
91 | # if i == 1000:
92 | # break
93 | # if total == len(imgset):
94 | # break
95 |
96 | with env.begin(write=True) as txn:
97 | txn.put("length".encode("utf-8"), str(i).encode("utf-8"))
98 |
--------------------------------------------------------------------------------
/dataset_util.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | import os
3 | from dist_utils import *
4 |
5 |
6 | def use_cached_dataset_path(source_path, cache_path):
7 | if get_rank() == 0:
8 | if not os.path.exists(cache_path):
9 | # shutil.rmtree(cache_path)
10 | print(f'copying the data: {source_path} to {cache_path}')
11 | shutil.copytree(source_path, cache_path)
12 | barrier()
13 | return cache_path
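A usage sketch (not part of the file) with hypothetical paths: rank 0 copies the dataset from shared storage to a faster local cache while the other ranks wait at the barrier, and every rank then reads from the cached path.

    from dataset_util import use_cached_dataset_path

    # both paths are hypothetical examples
    local_path = use_cached_dataset_path('datasets/ffhq256.lmdb',
                                         '/scratch/cache/ffhq256.lmdb')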
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 |
3 | from .diffusion import SpacedDiffusionBeatGans, SpacedDiffusionBeatGansConfig
4 |
5 | Sampler = Union[SpacedDiffusionBeatGans]
6 | SamplerConfig = Union[SpacedDiffusionBeatGansConfig]
7 |
--------------------------------------------------------------------------------
/diffusion/diffusion.py:
--------------------------------------------------------------------------------
1 | from .base import *
2 | from dataclasses import dataclass
3 |
4 |
5 | def space_timesteps(num_timesteps, section_counts):
6 | """
7 | Create a list of timesteps to use from an original diffusion process,
8 | given the number of timesteps we want to take from equally-sized portions
9 | of the original process.
10 |
11 |     For example, if there are 300 timesteps and the section counts are [10,15,20]
12 | then the first 100 timesteps are strided to be 10 timesteps, the second 100
13 | are strided to be 15 timesteps, and the final 100 are strided to be 20.
14 |
15 | If the stride is a string starting with "ddim", then the fixed striding
16 | from the DDIM paper is used, and only one section is allowed.
17 |
18 | :param num_timesteps: the number of diffusion steps in the original
19 | process to divide up.
20 | :param section_counts: either a list of numbers, or a string containing
21 | comma-separated numbers, indicating the step count
22 | per section. As a special case, use "ddimN" where N
23 | is a number of steps to use the striding from the
24 | DDIM paper.
25 | :return: a set of diffusion steps from the original process to use.
26 | """
27 | if isinstance(section_counts, str):
28 | if section_counts.startswith("ddim"):
29 | desired_count = int(section_counts[len("ddim"):])
30 | for i in range(1, num_timesteps):
31 | if len(range(0, num_timesteps, i)) == desired_count:
32 | return set(range(0, num_timesteps, i))
33 | raise ValueError(
34 | f"cannot create exactly {num_timesteps} steps with an integer stride"
35 | )
36 | section_counts = [int(x) for x in section_counts.split(",")]
37 | size_per = num_timesteps // len(section_counts)
38 | extra = num_timesteps % len(section_counts)
39 | start_idx = 0
40 | all_steps = []
41 | for i, section_count in enumerate(section_counts):
42 | size = size_per + (1 if i < extra else 0)
43 | if size < section_count:
44 | raise ValueError(
45 | f"cannot divide section of {size} steps into {section_count}")
46 | if section_count <= 1:
47 | frac_stride = 1
48 | else:
49 | frac_stride = (size - 1) / (section_count - 1)
50 | cur_idx = 0.0
51 | taken_steps = []
52 | for _ in range(section_count):
53 | taken_steps.append(start_idx + round(cur_idx))
54 | cur_idx += frac_stride
55 | all_steps += taken_steps
56 | start_idx += size
57 | return set(all_steps)
58 |
59 |
60 | @dataclass
61 | class SpacedDiffusionBeatGansConfig(GaussianDiffusionBeatGansConfig):
62 | use_timesteps: Tuple[int] = None
63 |
64 | def make_sampler(self):
65 | return SpacedDiffusionBeatGans(self)
66 |
67 |
68 | class SpacedDiffusionBeatGans(GaussianDiffusionBeatGans):
69 | """
70 | A diffusion process which can skip steps in a base diffusion process.
71 |
72 | :param use_timesteps: a collection (sequence or set) of timesteps from the
73 | original diffusion process to retain.
74 | :param kwargs: the kwargs to create the base diffusion process.
75 | """
76 | def __init__(self, conf: SpacedDiffusionBeatGansConfig):
77 | self.conf = conf
78 | self.use_timesteps = set(conf.use_timesteps)
79 |         # how the new t's map to the old t's
80 | self.timestep_map = []
81 | self.original_num_steps = len(conf.betas)
82 |
83 | base_diffusion = GaussianDiffusionBeatGans(conf) # pylint: disable=missing-kwoa
84 | last_alpha_cumprod = 1.0
85 | new_betas = []
86 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
87 | if i in self.use_timesteps:
88 | # getting the new betas of the new timesteps
89 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
90 | last_alpha_cumprod = alpha_cumprod
91 | self.timestep_map.append(i)
92 | conf.betas = np.array(new_betas)
93 | super().__init__(conf)
94 |
95 | def p_mean_variance(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
96 | return super().p_mean_variance(self._wrap_model(model), *args,
97 | **kwargs)
98 |
99 | def training_losses(self, model: Model, *args, **kwargs): # pylint: disable=signature-differs
100 | return super().training_losses(self._wrap_model(model), *args,
101 | **kwargs)
102 |
103 | def condition_mean(self, cond_fn, *args, **kwargs):
104 | return super().condition_mean(self._wrap_model(cond_fn), *args,
105 | **kwargs)
106 |
107 | def condition_score(self, cond_fn, *args, **kwargs):
108 | return super().condition_score(self._wrap_model(cond_fn), *args,
109 | **kwargs)
110 |
111 | def _wrap_model(self, model: Model):
112 | if isinstance(model, _WrappedModel):
113 | return model
114 | return _WrappedModel(model, self.timestep_map, self.rescale_timesteps,
115 | self.original_num_steps)
116 |
117 | def _scale_timesteps(self, t):
118 | # Scaling is done by the wrapped model.
119 | return t
120 |
121 |
122 | class _WrappedModel:
123 | """
124 |     converts the supplied t's back to the original t scale.
125 | """
126 | def __init__(self, model, timestep_map, rescale_timesteps,
127 | original_num_steps):
128 | self.model = model
129 | self.timestep_map = timestep_map
130 | self.rescale_timesteps = rescale_timesteps
131 | self.original_num_steps = original_num_steps
132 |
133 | def forward(self, x, t, t_cond=None, **kwargs):
134 | """
135 | Args:
136 |             t: t's with different ranges (can be << T due to a smaller eval T); they need to be converted to the original t's
137 | t_cond: the same as t but can be of different values
138 | """
139 | map_tensor = th.tensor(self.timestep_map,
140 | device=t.device,
141 | dtype=t.dtype)
142 |
143 | def do(t):
144 | new_ts = map_tensor[t]
145 | if self.rescale_timesteps:
146 | new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
147 | return new_ts
148 |
149 | if t_cond is not None:
150 | # support t_cond
151 | t_cond = do(t_cond)
152 |
153 | return self.model(x=x, t=do(t), t_cond=t_cond, **kwargs)
154 |
155 | def __getattr__(self, name):
156 | # allow for calling the model's methods
157 | if hasattr(self.model, name):
158 | func = getattr(self.model, name)
159 | return func
160 | raise AttributeError(name)
161 |
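A small worked example (not part of the file) of space_timesteps above:

    from diffusion.diffusion import space_timesteps

    # DDIM-style striding: stride 20 yields exactly 50 evenly spaced steps
    assert space_timesteps(1000, "ddim50") == set(range(0, 1000, 20))
    # a single section of 10 steps spread evenly over 100 original timesteps
    assert space_timesteps(100, [10]) == {0, 11, 22, 33, 44, 55, 66, 77, 88, 99}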
--------------------------------------------------------------------------------
/diffusion/resample.py:
--------------------------------------------------------------------------------
1 | from abc import ABC, abstractmethod
2 |
3 | import numpy as np
4 | import torch as th
5 | import torch.distributed as dist
6 |
7 |
8 | def create_named_schedule_sampler(name, diffusion):
9 | """
10 | Create a ScheduleSampler from a library of pre-defined samplers.
11 |
12 | :param name: the name of the sampler.
13 | :param diffusion: the diffusion object to sample for.
14 | """
15 | if name == "uniform":
16 | return UniformSampler(diffusion)
17 | else:
18 | raise NotImplementedError(f"unknown schedule sampler: {name}")
19 |
20 |
21 | class ScheduleSampler(ABC):
22 | """
23 | A distribution over timesteps in the diffusion process, intended to reduce
24 | variance of the objective.
25 |
26 | By default, samplers perform unbiased importance sampling, in which the
27 | objective's mean is unchanged.
28 | However, subclasses may override sample() to change how the resampled
29 | terms are reweighted, allowing for actual changes in the objective.
30 | """
31 | @abstractmethod
32 | def weights(self):
33 | """
34 | Get a numpy array of weights, one per diffusion step.
35 |
36 | The weights needn't be normalized, but must be positive.
37 | """
38 |
39 | def sample(self, batch_size, device):
40 | """
41 | Importance-sample timesteps for a batch.
42 |
43 | :param batch_size: the number of timesteps.
44 | :param device: the torch device to save to.
45 | :return: a tuple (timesteps, weights):
46 | - timesteps: a tensor of timestep indices.
47 | - weights: a tensor of weights to scale the resulting losses.
48 | """
49 | w = self.weights()
50 | p = w / np.sum(w)
51 | indices_np = np.random.choice(len(p), size=(batch_size, ), p=p)
52 | indices = th.from_numpy(indices_np).long().to(device)
53 | weights_np = 1 / (len(p) * p[indices_np])
54 | weights = th.from_numpy(weights_np).float().to(device)
55 | return indices, weights
56 |
57 |
58 | class UniformSampler(ScheduleSampler):
59 | def __init__(self, num_timesteps):
60 | self._weights = np.ones([num_timesteps])
61 |
62 | def weights(self):
63 | return self._weights
64 |
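A minimal usage sketch (not part of the file); note that UniformSampler takes the number of timesteps directly:

    import torch as th
    from diffusion.resample import UniformSampler

    sampler = UniformSampler(num_timesteps=1000)
    t, weights = sampler.sample(batch_size=8, device=th.device('cpu'))
    # t: 8 timestep indices drawn uniformly from [0, 1000)
    # weights: all ones, since the uniform weights are already normalized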
--------------------------------------------------------------------------------
/dist_utils.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from torch import distributed
3 |
4 |
5 | def barrier():
6 | if distributed.is_initialized():
7 | distributed.barrier()
8 | else:
9 | pass
10 |
11 |
12 | def broadcast(data, src):
13 | if distributed.is_initialized():
14 | distributed.broadcast(data, src)
15 | else:
16 | pass
17 |
18 |
19 | def all_gather(data: List, src):
20 | if distributed.is_initialized():
21 | distributed.all_gather(data, src)
22 | else:
23 | data[0] = src
24 |
25 |
26 | def get_rank():
27 | if distributed.is_initialized():
28 | return distributed.get_rank()
29 | else:
30 | return 0
31 |
32 |
33 | def get_world_size():
34 | if distributed.is_initialized():
35 | return distributed.get_world_size()
36 | else:
37 | return 1
38 |
39 |
40 | def chunk_size(size, rank, world_size):
41 | extra = rank < size % world_size
42 | return size // world_size + extra
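A quick worked example (not part of the file) of chunk_size above: the first size % world_size ranks get one extra item.

    from dist_utils import chunk_size

    assert [chunk_size(10, rank, 3) for rank in range(3)] == [4, 3, 3]
    assert sum(chunk_size(10, rank, 3) for rank in range(3)) == 10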
--------------------------------------------------------------------------------
/evals/ffhq128_autoenc_130M.txt:
--------------------------------------------------------------------------------
1 | {}
2 |
--------------------------------------------------------------------------------
/evals/ffhq128_autoenc_latent.txt:
--------------------------------------------------------------------------------
1 | {"fid_ema_T10_Tlatent10": 20.634624481201172}
2 |
--------------------------------------------------------------------------------
/experiment_classifier.py:
--------------------------------------------------------------------------------
1 | from config import *
2 | from dataset import *
3 | import pandas as pd
4 | import json
5 | import os
6 | import copy
7 |
8 | import numpy as np
9 | import pytorch_lightning as pl
10 | from pytorch_lightning import loggers as pl_loggers
11 | from pytorch_lightning.callbacks import *
12 | import torch
13 |
14 |
15 | class ZipLoader:
16 | def __init__(self, loaders):
17 | self.loaders = loaders
18 |
19 | def __len__(self):
20 | return len(self.loaders[0])
21 |
22 | def __iter__(self):
23 | for each in zip(*self.loaders):
24 | yield each
25 |
26 |
27 | class ClsModel(pl.LightningModule):
28 | def __init__(self, conf: TrainConfig):
29 | super().__init__()
30 | assert conf.train_mode.is_manipulate()
31 | if conf.seed is not None:
32 | pl.seed_everything(conf.seed)
33 |
34 | self.save_hyperparameters(conf.as_dict_jsonable())
35 | self.conf = conf
36 |
37 | # preparations
38 | if conf.train_mode == TrainMode.manipulate:
39 | # this is only important for training!
40 | # the latent is freshly inferred to make sure it matches the image
41 |             # manipulating latents requires the base model
42 | self.model = conf.make_model_conf().make_model()
43 | self.ema_model = copy.deepcopy(self.model)
44 | self.model.requires_grad_(False)
45 | self.ema_model.requires_grad_(False)
46 | self.ema_model.eval()
47 |
48 | if conf.pretrain is not None:
49 | print(f'loading pretrain ... {conf.pretrain.name}')
50 | state = torch.load(conf.pretrain.path, map_location='cpu')
51 | print('step:', state['global_step'])
52 | self.load_state_dict(state['state_dict'], strict=False)
53 |
54 | # load the latent stats
55 | if conf.manipulate_znormalize:
56 | print('loading latent stats ...')
57 | state = torch.load(conf.latent_infer_path)
58 | self.conds = state['conds']
59 | self.register_buffer('conds_mean',
60 | state['conds_mean'][None, :])
61 | self.register_buffer('conds_std', state['conds_std'][None, :])
62 | else:
63 | self.conds_mean = None
64 | self.conds_std = None
65 |
66 | if conf.manipulate_mode in [ManipulateMode.celebahq_all]:
67 | num_cls = len(CelebAttrDataset.id_to_cls)
68 | elif conf.manipulate_mode.is_single_class():
69 | num_cls = 1
70 | else:
71 | raise NotImplementedError()
72 |
73 | # classifier
74 | if conf.train_mode == TrainMode.manipulate:
75 |             # latent manipulation requires only a linear classifier
76 | self.classifier = nn.Linear(conf.style_ch, num_cls)
77 | else:
78 | raise NotImplementedError()
79 |
80 | self.ema_classifier = copy.deepcopy(self.classifier)
81 |
82 | def state_dict(self, *args, **kwargs):
83 | # don't save the base model
84 | out = {}
85 | for k, v in super().state_dict(*args, **kwargs).items():
86 | if k.startswith('model.'):
87 | pass
88 | elif k.startswith('ema_model.'):
89 | pass
90 | else:
91 | out[k] = v
92 | return out
93 |
94 | def load_state_dict(self, state_dict, strict: bool = None):
95 | if self.conf.train_mode == TrainMode.manipulate:
96 | # change the default strict => False
97 | if strict is None:
98 | strict = False
99 | else:
100 | if strict is None:
101 | strict = True
102 | return super().load_state_dict(state_dict, strict=strict)
103 |
104 | def normalize(self, cond):
105 | cond = (cond - self.conds_mean.to(self.device)) / self.conds_std.to(
106 | self.device)
107 | return cond
108 |
109 | def denormalize(self, cond):
110 | cond = (cond * self.conds_std.to(self.device)) + self.conds_mean.to(
111 | self.device)
112 | return cond
113 |
114 | def load_dataset(self):
115 | if self.conf.manipulate_mode == ManipulateMode.d2c_fewshot:
116 | return CelebD2CAttrFewshotDataset(
117 | cls_name=self.conf.manipulate_cls,
118 | K=self.conf.manipulate_shots,
119 | img_folder=data_paths['celeba'],
120 | img_size=self.conf.img_size,
121 | seed=self.conf.manipulate_seed,
122 | all_neg=False,
123 | do_augment=True,
124 | )
125 | elif self.conf.manipulate_mode == ManipulateMode.d2c_fewshot_allneg:
126 | # positive-unlabeled classifier needs to keep the class ratio 1:1
127 |             # we use two dataloaders, one for each class, to stabilize the training
128 | img_folder = data_paths['celeba']
129 |
130 | return [
131 | CelebD2CAttrFewshotDataset(
132 | cls_name=self.conf.manipulate_cls,
133 | K=self.conf.manipulate_shots,
134 | img_folder=img_folder,
135 | img_size=self.conf.img_size,
136 | only_cls_name=self.conf.manipulate_cls,
137 | only_cls_value=1,
138 | seed=self.conf.manipulate_seed,
139 | all_neg=True,
140 | do_augment=True),
141 | CelebD2CAttrFewshotDataset(
142 | cls_name=self.conf.manipulate_cls,
143 | K=self.conf.manipulate_shots,
144 | img_folder=img_folder,
145 | img_size=self.conf.img_size,
146 | only_cls_name=self.conf.manipulate_cls,
147 | only_cls_value=-1,
148 | seed=self.conf.manipulate_seed,
149 | all_neg=True,
150 | do_augment=True),
151 | ]
152 | elif self.conf.manipulate_mode == ManipulateMode.celebahq_all:
153 | return CelebHQAttrDataset(data_paths['celebahq'],
154 | self.conf.img_size,
155 | data_paths['celebahq_anno'],
156 | do_augment=True)
157 | else:
158 | raise NotImplementedError()
159 |
160 | def setup(self, stage=None) -> None:
161 | ##############################################
162 | # NEED TO SET THE SEED SEPARATELY HERE
163 | if self.conf.seed is not None:
164 | seed = self.conf.seed * get_world_size() + self.global_rank
165 | np.random.seed(seed)
166 | torch.manual_seed(seed)
167 | torch.cuda.manual_seed(seed)
168 | print('local seed:', seed)
169 | ##############################################
170 |
171 | self.train_data = self.load_dataset()
172 | if self.conf.manipulate_mode.is_fewshot():
173 | # repeat the dataset to be larger (speed up the training)
174 | if isinstance(self.train_data, list):
175 | # fewshot-allneg has two datasets
176 | # we resize them to be of equal sizes
177 | a, b = self.train_data
178 | self.train_data = [
179 | Repeat(a, max(len(a), len(b))),
180 | Repeat(b, max(len(a), len(b))),
181 | ]
182 | else:
183 | self.train_data = Repeat(self.train_data, 100_000)
184 |
185 | def train_dataloader(self):
186 |         # make sure to use the per-GPU fraction of the batch size
187 |         # (conf.batch_size is the global batch size!)
188 | conf = self.conf.clone()
189 | conf.batch_size = self.batch_size
190 | if isinstance(self.train_data, list):
191 | dataloader = []
192 | for each in self.train_data:
193 | dataloader.append(
194 | conf.make_loader(each, shuffle=True, drop_last=True))
195 | dataloader = ZipLoader(dataloader)
196 | else:
197 | dataloader = conf.make_loader(self.train_data,
198 | shuffle=True,
199 | drop_last=True)
200 | return dataloader
201 |
202 | @property
203 | def batch_size(self):
204 | ws = get_world_size()
205 | assert self.conf.batch_size % ws == 0
206 | return self.conf.batch_size // ws
207 |
208 | def training_step(self, batch, batch_idx):
209 | self.ema_model: BeatGANsAutoencModel
210 | if isinstance(batch, tuple):
211 | a, b = batch
212 | imgs = torch.cat([a['img'], b['img']])
213 | labels = torch.cat([a['labels'], b['labels']])
214 | else:
215 | imgs = batch['img']
216 | # print(f'({self.global_rank}) imgs:', imgs.shape)
217 | labels = batch['labels']
218 |
219 | if self.conf.train_mode == TrainMode.manipulate:
220 | self.ema_model.eval()
221 | with torch.no_grad():
222 | # (n, c)
223 | cond = self.ema_model.encoder(imgs)
224 |
225 | if self.conf.manipulate_znormalize:
226 | cond = self.normalize(cond)
227 |
228 | # (n, cls)
229 | pred = self.classifier.forward(cond)
230 | pred_ema = self.ema_classifier.forward(cond)
231 | elif self.conf.train_mode == TrainMode.manipulate_img:
232 | # (n, cls)
233 | pred = self.classifier.forward(imgs)
234 | pred_ema = None
235 | elif self.conf.train_mode == TrainMode.manipulate_imgt:
236 | t, weight = self.T_sampler.sample(len(imgs), imgs.device)
237 | imgs_t = self.sampler.q_sample(imgs, t)
238 | pred = self.classifier.forward(imgs_t, t=t)
239 | pred_ema = None
240 | print('pred:', pred.shape)
241 | else:
242 | raise NotImplementedError()
243 |
244 | if self.conf.manipulate_mode.is_celeba_attr():
245 | gt = torch.where(labels > 0,
246 | torch.ones_like(labels).float(),
247 | torch.zeros_like(labels).float())
248 | elif self.conf.manipulate_mode == ManipulateMode.relighting:
249 | gt = labels
250 | else:
251 | raise NotImplementedError()
252 |
253 |         if self.conf.manipulate_loss == ManipulateLossType.bce:
254 |             loss = F.binary_cross_entropy_with_logits(pred, gt)
255 |             loss_ema = (F.binary_cross_entropy_with_logits(pred_ema, gt)
256 |                         if pred_ema is not None else None)
257 |         elif self.conf.manipulate_loss == ManipulateLossType.mse:
258 |             loss = F.mse_loss(pred, gt)
259 |             loss_ema = F.mse_loss(pred_ema, gt) if pred_ema is not None else None
260 |         else:
261 |             raise NotImplementedError()
262 | 
263 |         self.log('loss', loss)
264 |         if loss_ema is not None:
265 |             self.log('loss_ema', loss_ema)
266 |         return loss
267 |
268 | def on_train_batch_end(self, outputs, batch, batch_idx: int,
269 | dataloader_idx: int) -> None:
270 | ema(self.classifier, self.ema_classifier, self.conf.ema_decay)
271 |
272 | def configure_optimizers(self):
273 | optim = torch.optim.Adam(self.classifier.parameters(),
274 | lr=self.conf.lr,
275 | weight_decay=self.conf.weight_decay)
276 | return optim
277 |
278 |
279 | def ema(source, target, decay):
280 | source_dict = source.state_dict()
281 | target_dict = target.state_dict()
282 | for key in source_dict.keys():
283 | target_dict[key].data.copy_(target_dict[key].data * decay +
284 | source_dict[key].data * (1 - decay))
285 |
286 |
287 | def train_cls(conf: TrainConfig, gpus):
288 | print('conf:', conf.name)
289 | model = ClsModel(conf)
290 |
291 | if not os.path.exists(conf.logdir):
292 | os.makedirs(conf.logdir)
293 | checkpoint = ModelCheckpoint(
294 | dirpath=f'{conf.logdir}',
295 | save_last=True,
296 | save_top_k=1,
297 | # every_n_train_steps=conf.save_every_samples //
298 | # conf.batch_size_effective,
299 | )
300 | checkpoint_path = f'{conf.logdir}/last.ckpt'
301 | if os.path.exists(checkpoint_path):
302 | resume = checkpoint_path
303 | else:
304 | if conf.continue_from is not None:
305 | # continue from a checkpoint
306 | resume = conf.continue_from.path
307 | else:
308 | resume = None
309 |
310 | tb_logger = pl_loggers.TensorBoardLogger(save_dir=conf.logdir,
311 | name=None,
312 | version='')
313 |
314 | # from pytorch_lightning.
315 |
316 | plugins = []
317 | if len(gpus) == 1:
318 | accelerator = None
319 | else:
320 | accelerator = 'ddp'
321 | from pytorch_lightning.plugins import DDPPlugin
322 | # important for working with gradient checkpoint
323 | plugins.append(DDPPlugin(find_unused_parameters=False))
324 |
325 | trainer = pl.Trainer(
326 | max_steps=conf.total_samples // conf.batch_size_effective,
327 | resume_from_checkpoint=resume,
328 | gpus=gpus,
329 | accelerator=accelerator,
330 | precision=16 if conf.fp16 else 32,
331 | callbacks=[
332 | checkpoint,
333 | ],
334 | replace_sampler_ddp=True,
335 | logger=tb_logger,
336 | accumulate_grad_batches=conf.accum_batches,
337 | plugins=plugins,
338 | )
339 | trainer.fit(model)
340 |
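A small numeric sketch (not part of the file) of the ema() update above: with decay close to 1, the EMA weights move only a tiny fraction toward the current classifier weights each step.

    import torch
    from torch import nn
    from experiment_classifier import ema

    src, tgt = nn.Linear(4, 1), nn.Linear(4, 1)
    nn.init.ones_(src.weight)
    nn.init.zeros_(tgt.weight)
    ema(src, tgt, decay=0.999)
    # tgt.weight is now 0.999 * 0 + 0.001 * 1 = 0.001 everywhere
    assert torch.allclose(tgt.weight, torch.full_like(tgt.weight, 0.001))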
--------------------------------------------------------------------------------
/imgs/sandy.JPG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs/sandy.JPG
--------------------------------------------------------------------------------
/imgs_align/sandy.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_align/sandy.png
--------------------------------------------------------------------------------
/imgs_interpolate/1_a.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_interpolate/1_a.png
--------------------------------------------------------------------------------
/imgs_interpolate/1_b.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_interpolate/1_b.png
--------------------------------------------------------------------------------
/imgs_manipulated/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_manipulated/compare.png
--------------------------------------------------------------------------------
/imgs_manipulated/output.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_manipulated/output.png
--------------------------------------------------------------------------------
/imgs_manipulated/sandy-wavyhair.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/phizaz/diffae/00c57d3f626f28bf9ed8aff58d90baab25de3af4/imgs_manipulated/sandy-wavyhair.png
--------------------------------------------------------------------------------
/install_requirements_for_colab.sh:
--------------------------------------------------------------------------------
1 | !pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 pytorch-lightning==1.2.2 torchtext==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
2 | !pip install scipy==1.5.4
3 | !pip install numpy==1.19.5
4 | !pip install tqdm
5 | !pip install pytorch-fid==0.2.0
6 | !pip install pandas==1.1.5
7 | !pip install lpips==0.1.4
8 | !pip install lmdb==1.2.1
9 | !pip install ftfy
10 | !pip install regex
11 | !pip install dlib requests
--------------------------------------------------------------------------------
/lmdb_writer.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 |
3 | import lmdb
4 | from PIL import Image
5 |
6 | import torch
7 |
8 | from contextlib import contextmanager
9 | from torch.utils.data import Dataset
10 | from multiprocessing import Process, Queue
11 | import os
12 | import shutil
13 |
14 |
15 | def convert(x, format, quality=100):
16 | # to prevent locking!
17 | torch.set_num_threads(1)
18 |
19 | buffer = BytesIO()
20 | x = x.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0)
21 | x = x.to(torch.uint8)
22 | x = x.numpy()
23 | img = Image.fromarray(x)
24 | img.save(buffer, format=format, quality=quality)
25 | val = buffer.getvalue()
26 | return val
27 |
28 |
29 | @contextmanager
30 | def nullcontext():
31 | yield
32 |
33 |
34 | class _WriterWorker(Process):
35 | def __init__(self, path, format, quality, zfill, q):
36 | super().__init__()
37 | if os.path.exists(path):
38 | shutil.rmtree(path)
39 |
40 | self.path = path
41 | self.format = format
42 | self.quality = quality
43 | self.zfill = zfill
44 | self.q = q
45 | self.i = 0
46 |
47 | def run(self):
48 | if not os.path.exists(self.path):
49 | os.makedirs(self.path)
50 |
51 | with lmdb.open(self.path, map_size=1024**4, readahead=False) as env:
52 | while True:
53 | job = self.q.get()
54 | if job is None:
55 | break
56 | with env.begin(write=True) as txn:
57 | for x in job:
58 | key = f"{str(self.i).zfill(self.zfill)}".encode(
59 | "utf-8")
60 | x = convert(x, self.format, self.quality)
61 | txn.put(key, x)
62 | self.i += 1
63 |
64 | with env.begin(write=True) as txn:
65 | txn.put("length".encode("utf-8"), str(self.i).encode("utf-8"))
66 |
67 |
68 | class LMDBImageWriter:
69 | def __init__(self, path, format='webp', quality=100, zfill=7) -> None:
70 | self.path = path
71 | self.format = format
72 | self.quality = quality
73 | self.zfill = zfill
74 | self.queue = None
75 | self.worker = None
76 |
77 | def __enter__(self):
78 | self.queue = Queue(maxsize=3)
79 |         self.worker = _WriterWorker(self.path, self.format, self.quality,
80 | self.zfill, self.queue)
81 | self.worker.start()
82 |
83 | def put_images(self, tensor):
84 | """
85 | Args:
86 | tensor: (n, c, h, w) [0-1] tensor
87 | """
88 | self.queue.put(tensor.cpu())
89 | # with self.env.begin(write=True) as txn:
90 | # for x in tensor:
91 | # key = f"{str(self.i).zfill(self.zfill)}".encode("utf-8")
92 | # x = convert(x, self.format, self.quality)
93 | # txn.put(key, x)
94 | # self.i += 1
95 |
96 | def __exit__(self, *args, **kwargs):
97 | self.queue.put(None)
98 | self.queue.close()
99 | self.worker.join()
100 |
101 |
102 | class LMDBImageReader(Dataset):
103 | def __init__(self, path, zfill: int = 7):
104 | self.zfill = zfill
105 | self.env = lmdb.open(
106 | path,
107 | max_readers=32,
108 | readonly=True,
109 | lock=False,
110 | readahead=False,
111 | meminit=False,
112 | )
113 |
114 | if not self.env:
115 | raise IOError('Cannot open lmdb dataset', path)
116 |
117 | with self.env.begin(write=False) as txn:
118 | self.length = int(
119 | txn.get('length'.encode('utf-8')).decode('utf-8'))
120 |
121 | def __len__(self):
122 | return self.length
123 |
124 | def __getitem__(self, index):
125 | with self.env.begin(write=False) as txn:
126 | key = f'{str(index).zfill(self.zfill)}'.encode('utf-8')
127 | img_bytes = txn.get(key)
128 |
129 | buffer = BytesIO(img_bytes)
130 | img = Image.open(buffer)
131 | return img
132 |
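A usage sketch (not part of the file) with a hypothetical path. Since __enter__ does not return self, the writer is created first and then used as a context manager:

    import torch
    from lmdb_writer import LMDBImageWriter, LMDBImageReader

    writer = LMDBImageWriter('datasets/demo.lmdb', format='webp', quality=100)
    with writer:
        writer.put_images(torch.rand(4, 3, 64, 64))  # (n, c, h, w) in [0, 1]

    reader = LMDBImageReader('datasets/demo.lmdb')
    print(len(reader), reader[0].size)  # 4 (64, 64)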
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | import torch
5 | import torchvision
6 | from pytorch_fid import fid_score
7 | from torch import distributed
8 | from torch.utils.data import DataLoader
9 | from torch.utils.data.distributed import DistributedSampler
10 | from tqdm.autonotebook import tqdm, trange
11 |
12 | from renderer import *
13 | from config import *
14 | from diffusion import Sampler
15 | from dist_utils import *
16 | import lpips
17 | from ssim import ssim
18 |
19 |
20 | def make_subset_loader(conf: TrainConfig,
21 | dataset: Dataset,
22 | batch_size: int,
23 | shuffle: bool,
24 | parallel: bool,
25 | drop_last=True):
26 | dataset = SubsetDataset(dataset, size=conf.eval_num_images)
27 | if parallel and distributed.is_initialized():
28 | sampler = DistributedSampler(dataset, shuffle=shuffle)
29 | else:
30 | sampler = None
31 | return DataLoader(
32 | dataset,
33 | batch_size=batch_size,
34 | sampler=sampler,
35 |         # with a sampler, the sampler controls shuffling instead of this option
36 | shuffle=False if sampler else shuffle,
37 | num_workers=conf.num_workers,
38 | pin_memory=True,
39 | drop_last=drop_last,
40 | multiprocessing_context=get_context('fork'),
41 | )
42 |
43 |
44 | def evaluate_lpips(
45 | sampler: Sampler,
46 | model: Model,
47 | conf: TrainConfig,
48 | device,
49 | val_data: Dataset,
50 | latent_sampler: Sampler = None,
51 | use_inverted_noise: bool = False,
52 | ):
53 | """
54 | compare the generated images from autoencoder on validation dataset
55 |
56 | Args:
57 |         use_inverted_noise: whether the noise is obtained by DDIM inversion (instead of random noise)
58 | """
59 | lpips_fn = lpips.LPIPS(net='alex').to(device)
60 | val_loader = make_subset_loader(conf,
61 | dataset=val_data,
62 | batch_size=conf.batch_size_eval,
63 | shuffle=False,
64 | parallel=True)
65 |
66 | model.eval()
67 | with torch.no_grad():
68 | scores = {
69 | 'lpips': [],
70 | 'mse': [],
71 | 'ssim': [],
72 | 'psnr': [],
73 | }
74 | for batch in tqdm(val_loader, desc='lpips'):
75 | imgs = batch['img'].to(device)
76 |
77 | if use_inverted_noise:
78 | # inverse the noise
79 | # with condition from the encoder
80 | model_kwargs = {}
81 | if conf.model_type.has_autoenc():
82 | with torch.no_grad():
83 | model_kwargs = model.encode(imgs)
84 | x_T = sampler.ddim_reverse_sample_loop(
85 | model=model,
86 | x=imgs,
87 | clip_denoised=True,
88 | model_kwargs=model_kwargs)
89 | x_T = x_T['sample']
90 | else:
91 | x_T = torch.randn((len(imgs), 3, conf.img_size, conf.img_size),
92 | device=device)
93 |
94 | if conf.model_type == ModelType.ddpm:
95 | # the case where you want to calculate the inversion capability of the DDIM model
96 | assert use_inverted_noise
97 | pred_imgs = render_uncondition(
98 | conf=conf,
99 | model=model,
100 | x_T=x_T,
101 | sampler=sampler,
102 | latent_sampler=latent_sampler,
103 | )
104 | else:
105 | pred_imgs = render_condition(conf=conf,
106 | model=model,
107 | x_T=x_T,
108 | x_start=imgs,
109 | cond=None,
110 | sampler=sampler)
111 | # # returns {'cond', 'cond2'}
112 | # conds = model.encode(imgs)
113 | # pred_imgs = sampler.sample(model=model,
114 | # noise=x_T,
115 | # model_kwargs=conds)
116 |
117 | # (n, 1, 1, 1) => (n, )
118 | scores['lpips'].append(lpips_fn.forward(imgs, pred_imgs).view(-1))
119 |
120 | # need to normalize into [0, 1]
121 | norm_imgs = (imgs + 1) / 2
122 | norm_pred_imgs = (pred_imgs + 1) / 2
123 | # (n, )
124 | scores['ssim'].append(
125 | ssim(norm_imgs, norm_pred_imgs, size_average=False))
126 | # (n, )
127 | scores['mse'].append(
128 | (norm_imgs - norm_pred_imgs).pow(2).mean(dim=[1, 2, 3]))
129 | # (n, )
130 | scores['psnr'].append(psnr(norm_imgs, norm_pred_imgs))
131 | # (N, )
132 | for key in scores.keys():
133 | scores[key] = torch.cat(scores[key]).float()
134 | model.train()
135 |
136 | barrier()
137 |
138 | # support multi-gpu
139 | outs = {
140 | key: [
141 | torch.zeros(len(scores[key]), device=device)
142 | for i in range(get_world_size())
143 | ]
144 | for key in scores.keys()
145 | }
146 | for key in scores.keys():
147 | all_gather(outs[key], scores[key])
148 |
149 | # final scores
150 | for key in scores.keys():
151 | scores[key] = torch.cat(outs[key]).mean().item()
152 |
153 | # {'lpips', 'mse', 'ssim'}
154 | return scores
155 |
156 |
157 | def psnr(img1, img2):
158 | """
159 | Args:
160 | img1: (n, c, h, w)
161 | """
162 | v_max = 1.
163 | # (n,)
164 | mse = torch.mean((img1 - img2)**2, dim=[1, 2, 3])
165 | return 20 * torch.log10(v_max / torch.sqrt(mse))
166 |
167 |
168 | def evaluate_fid(
169 | sampler: Sampler,
170 | model: Model,
171 | conf: TrainConfig,
172 | device,
173 | train_data: Dataset,
174 | val_data: Dataset,
175 | latent_sampler: Sampler = None,
176 | conds_mean=None,
177 | conds_std=None,
178 | remove_cache: bool = True,
179 | clip_latent_noise: bool = False,
180 | ):
181 | assert conf.fid_cache is not None
182 | if get_rank() == 0:
183 | # no parallel
184 |         # validation data used as the FID reference set
185 | val_loader = make_subset_loader(conf,
186 | dataset=val_data,
187 | batch_size=conf.batch_size_eval,
188 | shuffle=False,
189 | parallel=False)
190 |
191 | # put the val images to a directory
192 | cache_dir = f'{conf.fid_cache}_{conf.eval_num_images}'
193 | if (os.path.exists(cache_dir)
194 | and len(os.listdir(cache_dir)) < conf.eval_num_images):
195 | shutil.rmtree(cache_dir)
196 |
197 | if not os.path.exists(cache_dir):
198 | # write files to the cache
199 | # the images are normalized, hence need to denormalize first
200 | loader_to_path(val_loader, cache_dir, denormalize=True)
201 |
202 | # create the generate dir
203 | if os.path.exists(conf.generate_dir):
204 | shutil.rmtree(conf.generate_dir)
205 | os.makedirs(conf.generate_dir)
206 |
207 | barrier()
208 |
209 | world_size = get_world_size()
210 | rank = get_rank()
211 | batch_size = chunk_size(conf.batch_size_eval, rank, world_size)
212 |
213 | def filename(idx):
214 | return world_size * idx + rank
215 |
216 | model.eval()
217 | with torch.no_grad():
218 | if conf.model_type.can_sample():
219 | eval_num_images = chunk_size(conf.eval_num_images, rank,
220 | world_size)
221 | desc = "generating images"
222 | for i in trange(0, eval_num_images, batch_size, desc=desc):
223 | batch_size = min(batch_size, eval_num_images - i)
224 | x_T = torch.randn(
225 | (batch_size, 3, conf.img_size, conf.img_size),
226 | device=device)
227 | batch_images = render_uncondition(
228 | conf=conf,
229 | model=model,
230 | x_T=x_T,
231 | sampler=sampler,
232 | latent_sampler=latent_sampler,
233 | conds_mean=conds_mean,
234 | conds_std=conds_std).cpu()
235 |
236 | batch_images = (batch_images + 1) / 2
237 | # keep the generated images
238 | for j in range(len(batch_images)):
239 | img_name = filename(i + j)
240 | torchvision.utils.save_image(
241 | batch_images[j],
242 | os.path.join(conf.generate_dir, f'{img_name}.png'))
243 | elif conf.model_type == ModelType.autoencoder:
244 | if conf.train_mode.is_latent_diffusion():
245 | # evaluate autoencoder + latent diffusion (doesn't give the images)
246 | model: BeatGANsAutoencModel
247 | eval_num_images = chunk_size(conf.eval_num_images, rank,
248 | world_size)
249 | desc = "generating images"
250 | for i in trange(0, eval_num_images, batch_size, desc=desc):
251 | batch_size = min(batch_size, eval_num_images - i)
252 | x_T = torch.randn(
253 | (batch_size, 3, conf.img_size, conf.img_size),
254 | device=device)
255 | batch_images = render_uncondition(
256 | conf=conf,
257 | model=model,
258 | x_T=x_T,
259 | sampler=sampler,
260 | latent_sampler=latent_sampler,
261 | conds_mean=conds_mean,
262 | conds_std=conds_std,
263 | clip_latent_noise=clip_latent_noise,
264 | ).cpu()
265 | batch_images = (batch_images + 1) / 2
266 | # keep the generated images
267 | for j in range(len(batch_images)):
268 | img_name = filename(i + j)
269 | torchvision.utils.save_image(
270 | batch_images[j],
271 | os.path.join(conf.generate_dir, f'{img_name}.png'))
272 | else:
273 |                 # evaluate autoencoder (given the images)
274 | # to make the FID fair, autoencoder must not see the validation dataset
275 | # also shuffle to make it closer to unconditional generation
276 | train_loader = make_subset_loader(conf,
277 | dataset=train_data,
278 | batch_size=batch_size,
279 | shuffle=True,
280 | parallel=True)
281 |
282 | i = 0
283 | for batch in tqdm(train_loader, desc='generating images'):
284 | imgs = batch['img'].to(device)
285 | x_T = torch.randn(
286 | (len(imgs), 3, conf.img_size, conf.img_size),
287 | device=device)
288 | batch_images = render_condition(
289 | conf=conf,
290 | model=model,
291 | x_T=x_T,
292 | x_start=imgs,
293 | cond=None,
294 | sampler=sampler,
295 | latent_sampler=latent_sampler).cpu()
296 | # model: BeatGANsAutoencModel
297 | # # returns {'cond', 'cond2'}
298 | # conds = model.encode(imgs)
299 | # batch_images = sampler.sample(model=model,
300 | # noise=x_T,
301 | # model_kwargs=conds).cpu()
302 | # denormalize the images
303 | batch_images = (batch_images + 1) / 2
304 | # keep the generated images
305 | for j in range(len(batch_images)):
306 | img_name = filename(i + j)
307 | torchvision.utils.save_image(
308 | batch_images[j],
309 | os.path.join(conf.generate_dir, f'{img_name}.png'))
310 | i += len(imgs)
311 | else:
312 | raise NotImplementedError()
313 | model.train()
314 |
315 | barrier()
316 |
317 | if get_rank() == 0:
318 | fid = fid_score.calculate_fid_given_paths(
319 | [cache_dir, conf.generate_dir],
320 | batch_size,
321 | device=device,
322 | dims=2048)
323 |
324 |     # remove the generated images directory
325 | if remove_cache and os.path.exists(conf.generate_dir):
326 | shutil.rmtree(conf.generate_dir)
327 |
328 | barrier()
329 |
330 | if get_rank() == 0:
331 |         # need to float it! otherwise the broadcast value is wrong
332 | fid = torch.tensor(float(fid), device=device)
333 | broadcast(fid, 0)
334 | else:
335 | fid = torch.tensor(0., device=device)
336 | broadcast(fid, 0)
337 | fid = fid.item()
338 | print(f'fid ({get_rank()}):', fid)
339 |
340 | return fid
341 |
342 |
343 | def loader_to_path(loader: DataLoader, path: str, denormalize: bool):
344 | # not process safe!
345 |
346 | if not os.path.exists(path):
347 | os.makedirs(path)
348 |
349 | # write the loader to files
350 | i = 0
351 | for batch in tqdm(loader, desc='copy images'):
352 | imgs = batch['img']
353 | if denormalize:
354 | imgs = (imgs + 1) / 2
355 | for j in range(len(imgs)):
356 | torchvision.utils.save_image(imgs[j],
357 | os.path.join(path, f'{i+j}.png'))
358 | i += len(imgs)
359 |
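A quick sanity check (not part of the file) for psnr() above, which expects inputs normalized to [0, 1]: a uniform per-pixel error of 0.1 gives an MSE of 0.01 and hence 20 * log10(1 / 0.1) = 20 dB.

    import torch
    from metrics import psnr

    a = torch.zeros(1, 3, 8, 8)
    b = torch.full_like(a, 0.1)
    assert torch.allclose(psnr(a, b), torch.tensor([20.0]))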
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | from typing import Union
2 | from .unet import BeatGANsUNetModel, BeatGANsUNetConfig
3 | from .unet_autoenc import BeatGANsAutoencConfig, BeatGANsAutoencModel
4 |
5 | Model = Union[BeatGANsUNetModel, BeatGANsAutoencModel]
6 | ModelConfig = Union[BeatGANsUNetConfig, BeatGANsAutoencConfig]
7 |
--------------------------------------------------------------------------------
/model/blocks.py:
--------------------------------------------------------------------------------
1 | import math
2 | from abc import abstractmethod
3 | from dataclasses import dataclass
4 | from numbers import Number
5 |
6 | import torch as th
7 | import torch.nn.functional as F
8 | from choices import *
9 | from config_base import BaseConfig
10 | from torch import nn
11 |
12 | from .nn import (avg_pool_nd, conv_nd, linear, normalization,
13 | timestep_embedding, torch_checkpoint, zero_module)
14 |
15 |
16 | class ScaleAt(Enum):
17 | after_norm = 'afternorm'
18 |
19 |
20 | class TimestepBlock(nn.Module):
21 | """
22 | Any module where forward() takes timestep embeddings as a second argument.
23 | """
24 | @abstractmethod
25 | def forward(self, x, emb=None, cond=None, lateral=None):
26 | """
27 | Apply the module to `x` given `emb` timestep embeddings.
28 | """
29 |
30 |
31 | class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
32 | """
33 | A sequential module that passes timestep embeddings to the children that
34 | support it as an extra input.
35 | """
36 | def forward(self, x, emb=None, cond=None, lateral=None):
37 | for layer in self:
38 | if isinstance(layer, TimestepBlock):
39 | x = layer(x, emb=emb, cond=cond, lateral=lateral)
40 | else:
41 | x = layer(x)
42 | return x
43 |
44 |
45 | @dataclass
46 | class ResBlockConfig(BaseConfig):
47 | channels: int
48 | emb_channels: int
49 | dropout: float
50 | out_channels: int = None
51 | # condition the resblock with time (and encoder's output)
52 | use_condition: bool = True
53 | # whether to use 3x3 conv for skip path when the channels aren't matched
54 | use_conv: bool = False
55 | # dimension of conv (always 2 = 2d)
56 | dims: int = 2
57 | # gradient checkpoint
58 | use_checkpoint: bool = False
59 | up: bool = False
60 | down: bool = False
61 | # whether to condition with both time & encoder's output
62 | two_cond: bool = False
63 | # number of encoders' output channels
64 | cond_emb_channels: int = None
65 |     # whether the block takes a lateral (encoder skip) connection concatenated to its input (suggested: False)
66 | has_lateral: bool = False
67 | lateral_channels: int = None
68 | # whether to init the convolution with zero weights
69 | # this is default from BeatGANs and seems to help learning
70 | use_zero_module: bool = True
71 |
72 | def __post_init__(self):
73 | self.out_channels = self.out_channels or self.channels
74 | self.cond_emb_channels = self.cond_emb_channels or self.emb_channels
75 |
76 | def make_model(self):
77 | return ResBlock(self)
78 |
79 |
80 | class ResBlock(TimestepBlock):
81 | """
82 | A residual block that can optionally change the number of channels.
83 |
84 | total layers:
85 | in_layers
86 | - norm
87 | - act
88 | - conv
89 | out_layers
90 | - norm
91 | - (modulation)
92 | - act
93 | - conv
94 | """
95 | def __init__(self, conf: ResBlockConfig):
96 | super().__init__()
97 | self.conf = conf
98 |
99 | #############################
100 | # IN LAYERS
101 | #############################
102 | assert conf.lateral_channels is None
103 | layers = [
104 | normalization(conf.channels),
105 | nn.SiLU(),
106 | conv_nd(conf.dims, conf.channels, conf.out_channels, 3, padding=1)
107 | ]
108 | self.in_layers = nn.Sequential(*layers)
109 |
110 | self.updown = conf.up or conf.down
111 |
112 | if conf.up:
113 | self.h_upd = Upsample(conf.channels, False, conf.dims)
114 | self.x_upd = Upsample(conf.channels, False, conf.dims)
115 | elif conf.down:
116 | self.h_upd = Downsample(conf.channels, False, conf.dims)
117 | self.x_upd = Downsample(conf.channels, False, conf.dims)
118 | else:
119 | self.h_upd = self.x_upd = nn.Identity()
120 |
121 | #############################
122 | # OUT LAYERS CONDITIONS
123 | #############################
124 | if conf.use_condition:
125 | # condition layers for the out_layers
126 | self.emb_layers = nn.Sequential(
127 | nn.SiLU(),
128 | linear(conf.emb_channels, 2 * conf.out_channels),
129 | )
130 |
131 | if conf.two_cond:
132 | self.cond_emb_layers = nn.Sequential(
133 | nn.SiLU(),
134 | linear(conf.cond_emb_channels, conf.out_channels),
135 | )
136 | #############################
137 | # OUT LAYERS (ignored when there is no condition)
138 | #############################
139 | # original version
140 | conv = conv_nd(conf.dims,
141 | conf.out_channels,
142 | conf.out_channels,
143 | 3,
144 | padding=1)
145 | if conf.use_zero_module:
146 |             # zero out the weights
147 | # it seems to help training
148 | conv = zero_module(conv)
149 |
150 | # construct the layers
151 | # - norm
152 | # - (modulation)
153 | # - act
154 | # - dropout
155 | # - conv
156 | layers = []
157 | layers += [
158 | normalization(conf.out_channels),
159 | nn.SiLU(),
160 | nn.Dropout(p=conf.dropout),
161 | conv,
162 | ]
163 | self.out_layers = nn.Sequential(*layers)
164 |
165 | #############################
166 | # SKIP LAYERS
167 | #############################
168 | if conf.out_channels == conf.channels:
169 |             # cannot be used with gatedconv; also, gatedconv is always used as the first block
170 | self.skip_connection = nn.Identity()
171 | else:
172 | if conf.use_conv:
173 | kernel_size = 3
174 | padding = 1
175 | else:
176 | kernel_size = 1
177 | padding = 0
178 |
179 | self.skip_connection = conv_nd(conf.dims,
180 | conf.channels,
181 | conf.out_channels,
182 | kernel_size,
183 | padding=padding)
184 |
185 | def forward(self, x, emb=None, cond=None, lateral=None):
186 | """
187 | Apply the block to a Tensor, conditioned on a timestep embedding.
188 |
189 | Args:
190 | x: input
191 | lateral: lateral connection from the encoder
192 | """
193 | return torch_checkpoint(self._forward, (x, emb, cond, lateral),
194 | self.conf.use_checkpoint)
195 |
196 | def _forward(
197 | self,
198 | x,
199 | emb=None,
200 | cond=None,
201 | lateral=None,
202 | ):
203 | """
204 | Args:
205 |             lateral: required if "has_lateral" and non-gated; with gated blocks it may be supplied optionally
206 | """
207 | if self.conf.has_lateral:
208 |             # a lateral may be supplied even when it isn't required;
209 |             # the block uses it only if "has_lateral" is set
210 | assert lateral is not None
211 | x = th.cat([x, lateral], dim=1)
212 |
213 | if self.updown:
214 | in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
215 | h = in_rest(x)
216 | h = self.h_upd(h)
217 | x = self.x_upd(x)
218 | h = in_conv(h)
219 | else:
220 | h = self.in_layers(x)
221 |
222 | if self.conf.use_condition:
223 |             # it's possible that the network may not receive the time emb
224 | # this happens with autoenc and setting the time_at
225 | if emb is not None:
226 | emb_out = self.emb_layers(emb).type(h.dtype)
227 | else:
228 | emb_out = None
229 |
230 | if self.conf.two_cond:
231 | # it's possible that the network is two_cond
232 | # but it doesn't get the second condition
233 | # in which case, we ignore the second condition
234 | # and treat as if the network has one condition
235 | if cond is None:
236 | cond_out = None
237 | else:
238 | cond_out = self.cond_emb_layers(cond).type(h.dtype)
239 |
240 | if cond_out is not None:
241 | while len(cond_out.shape) < len(h.shape):
242 | cond_out = cond_out[..., None]
243 | else:
244 | cond_out = None
245 |
246 | # this is the new refactored code
247 | h = apply_conditions(
248 | h=h,
249 | emb=emb_out,
250 | cond=cond_out,
251 | layers=self.out_layers,
252 | scale_bias=1,
253 | in_channels=self.conf.out_channels,
254 | up_down_layer=None,
255 | )
256 |
257 | return self.skip_connection(x) + h
258 |
259 |
260 | def apply_conditions(
261 | h,
262 | emb=None,
263 | cond=None,
264 | layers: nn.Sequential = None,
265 | scale_bias: float = 1,
266 | in_channels: int = 512,
267 | up_down_layer: nn.Module = None,
268 | ):
269 | """
270 | apply conditions on the feature maps
271 |
272 | Args:
273 | emb: time conditional (ready to scale + shift)
274 |         cond: encoder's conditional (ready to scale + shift)
275 | """
276 | two_cond = emb is not None and cond is not None
277 |
278 | if emb is not None:
279 | # adjusting shapes
280 | while len(emb.shape) < len(h.shape):
281 | emb = emb[..., None]
282 |
283 | if two_cond:
284 | # adjusting shapes
285 | while len(cond.shape) < len(h.shape):
286 | cond = cond[..., None]
287 | # time first
288 | scale_shifts = [emb, cond]
289 | else:
290 | # "cond" is not used with single cond mode
291 | scale_shifts = [emb]
292 |
293 | # support scale, shift or shift only
294 | for i, each in enumerate(scale_shifts):
295 | if each is None:
296 | # special case: the condition is not provided
297 | a = None
298 | b = None
299 | else:
300 | if each.shape[1] == in_channels * 2:
301 | a, b = th.chunk(each, 2, dim=1)
302 | else:
303 | a = each
304 | b = None
305 | scale_shifts[i] = (a, b)
306 |
307 | # condition scale bias could be a list
308 | if isinstance(scale_bias, Number):
309 | biases = [scale_bias] * len(scale_shifts)
310 | else:
311 | # a list
312 | biases = scale_bias
313 |
314 | # default, the scale & shift are applied after the group norm but BEFORE SiLU
315 | pre_layers, post_layers = layers[0], layers[1:]
316 |
317 |     # split the post layer to be able to scale up or down before conv
318 | # post layers will contain only the conv
319 | mid_layers, post_layers = post_layers[:-2], post_layers[-2:]
320 |
321 | h = pre_layers(h)
322 | # scale and shift for each condition
323 | for i, (scale, shift) in enumerate(scale_shifts):
324 | # if scale is None, it indicates that the condition is not provided
325 | if scale is not None:
326 | h = h * (biases[i] + scale)
327 | if shift is not None:
328 | h = h + shift
329 | h = mid_layers(h)
330 |
331 | # upscale or downscale if any just before the last conv
332 | if up_down_layer is not None:
333 | h = up_down_layer(h)
334 | h = post_layers(h)
335 | return h
336 |
337 |
338 | class Upsample(nn.Module):
339 | """
340 | An upsampling layer with an optional convolution.
341 |
342 | :param channels: channels in the inputs and outputs.
343 | :param use_conv: a bool determining if a convolution is applied.
344 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
345 | upsampling occurs in the inner-two dimensions.
346 | """
347 | def __init__(self, channels, use_conv, dims=2, out_channels=None):
348 | super().__init__()
349 | self.channels = channels
350 | self.out_channels = out_channels or channels
351 | self.use_conv = use_conv
352 | self.dims = dims
353 | if use_conv:
354 | self.conv = conv_nd(dims,
355 | self.channels,
356 | self.out_channels,
357 | 3,
358 | padding=1)
359 |
360 | def forward(self, x):
361 | assert x.shape[1] == self.channels
362 | if self.dims == 3:
363 | x = F.interpolate(x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
364 | mode="nearest")
365 | else:
366 | x = F.interpolate(x, scale_factor=2, mode="nearest")
367 | if self.use_conv:
368 | x = self.conv(x)
369 | return x
370 |
371 |
372 | class Downsample(nn.Module):
373 | """
374 | A downsampling layer with an optional convolution.
375 |
376 | :param channels: channels in the inputs and outputs.
377 | :param use_conv: a bool determining if a convolution is applied.
378 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
379 | downsampling occurs in the inner-two dimensions.
380 | """
381 | def __init__(self, channels, use_conv, dims=2, out_channels=None):
382 | super().__init__()
383 | self.channels = channels
384 | self.out_channels = out_channels or channels
385 | self.use_conv = use_conv
386 | self.dims = dims
387 | stride = 2 if dims != 3 else (1, 2, 2)
388 | if use_conv:
389 | self.op = conv_nd(dims,
390 | self.channels,
391 | self.out_channels,
392 | 3,
393 | stride=stride,
394 | padding=1)
395 | else:
396 | assert self.channels == self.out_channels
397 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
398 |
399 | def forward(self, x):
400 | assert x.shape[1] == self.channels
401 | return self.op(x)
402 |
403 |
404 | class AttentionBlock(nn.Module):
405 | """
406 | An attention block that allows spatial positions to attend to each other.
407 |
408 | Originally ported from here, but adapted to the N-d case.
409 | https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
410 | """
411 | def __init__(
412 | self,
413 | channels,
414 | num_heads=1,
415 | num_head_channels=-1,
416 | use_checkpoint=False,
417 | use_new_attention_order=False,
418 | ):
419 | super().__init__()
420 | self.channels = channels
421 | if num_head_channels == -1:
422 | self.num_heads = num_heads
423 | else:
424 | assert (
425 | channels % num_head_channels == 0
426 | ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
427 | self.num_heads = channels // num_head_channels
428 | self.use_checkpoint = use_checkpoint
429 | self.norm = normalization(channels)
430 | self.qkv = conv_nd(1, channels, channels * 3, 1)
431 | if use_new_attention_order:
432 | # split qkv before split heads
433 | self.attention = QKVAttention(self.num_heads)
434 | else:
435 | # split heads before split qkv
436 | self.attention = QKVAttentionLegacy(self.num_heads)
437 |
438 | self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
439 |
440 | def forward(self, x):
441 | return torch_checkpoint(self._forward, (x, ), self.use_checkpoint)
442 |
443 | def _forward(self, x):
444 | b, c, *spatial = x.shape
445 | x = x.reshape(b, c, -1)
446 | qkv = self.qkv(self.norm(x))
447 | h = self.attention(qkv)
448 | h = self.proj_out(h)
449 | return (x + h).reshape(b, c, *spatial)
450 |
451 |
452 | def count_flops_attn(model, _x, y):
453 | """
454 | A counter for the `thop` package to count the operations in an
455 | attention operation.
456 | Meant to be used like:
457 | macs, params = thop.profile(
458 | model,
459 | inputs=(inputs, timestamps),
460 | custom_ops={QKVAttention: QKVAttention.count_flops},
461 | )
462 | """
463 | b, c, *spatial = y[0].shape
464 | num_spatial = int(np.prod(spatial))
465 | # We perform two matmuls with the same number of ops.
466 | # The first computes the weight matrix, the second computes
467 | # the combination of the value vectors.
468 | matmul_ops = 2 * b * (num_spatial**2) * c
469 | model.total_ops += th.DoubleTensor([matmul_ops])
470 |
471 |
472 | class QKVAttentionLegacy(nn.Module):
473 | """
474 | A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
475 | """
476 | def __init__(self, n_heads):
477 | super().__init__()
478 | self.n_heads = n_heads
479 |
480 | def forward(self, qkv):
481 | """
482 | Apply QKV attention.
483 |
484 | :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
485 | :return: an [N x (H * C) x T] tensor after attention.
486 | """
487 | bs, width, length = qkv.shape
488 | assert width % (3 * self.n_heads) == 0
489 | ch = width // (3 * self.n_heads)
490 | q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch,
491 | dim=1)
492 | scale = 1 / math.sqrt(math.sqrt(ch))
493 | weight = th.einsum(
494 | "bct,bcs->bts", q * scale,
495 | k * scale) # More stable with f16 than dividing afterwards
496 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
497 | a = th.einsum("bts,bcs->bct", weight, v)
498 | return a.reshape(bs, -1, length)
499 |
500 | @staticmethod
501 | def count_flops(model, _x, y):
502 | return count_flops_attn(model, _x, y)
503 |
504 |
505 | class QKVAttention(nn.Module):
506 | """
507 | A module which performs QKV attention and splits in a different order.
508 | """
509 | def __init__(self, n_heads):
510 | super().__init__()
511 | self.n_heads = n_heads
512 |
513 | def forward(self, qkv):
514 | """
515 | Apply QKV attention.
516 |
517 | :param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
518 | :return: an [N x (H * C) x T] tensor after attention.
519 | """
520 | bs, width, length = qkv.shape
521 | assert width % (3 * self.n_heads) == 0
522 | ch = width // (3 * self.n_heads)
523 | q, k, v = qkv.chunk(3, dim=1)
524 | scale = 1 / math.sqrt(math.sqrt(ch))
525 | weight = th.einsum(
526 | "bct,bcs->bts",
527 | (q * scale).view(bs * self.n_heads, ch, length),
528 | (k * scale).view(bs * self.n_heads, ch, length),
529 | ) # More stable with f16 than dividing afterwards
530 | weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
531 | a = th.einsum("bts,bcs->bct", weight,
532 | v.reshape(bs * self.n_heads, ch, length))
533 | return a.reshape(bs, -1, length)
534 |
535 | @staticmethod
536 | def count_flops(model, _x, y):
537 | return count_flops_attn(model, _x, y)
538 |
539 |
540 | class AttentionPool2d(nn.Module):
541 | """
542 | Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
543 | """
544 | def __init__(
545 | self,
546 | spacial_dim: int,
547 | embed_dim: int,
548 | num_heads_channels: int,
549 | output_dim: int = None,
550 | ):
551 | super().__init__()
552 | self.positional_embedding = nn.Parameter(
553 | th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
554 | self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
555 | self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
556 | self.num_heads = embed_dim // num_heads_channels
557 | self.attention = QKVAttention(self.num_heads)
558 |
559 | def forward(self, x):
560 | b, c, *_spatial = x.shape
561 | x = x.reshape(b, c, -1) # NC(HW)
562 | x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
563 | x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
564 | x = self.qkv_proj(x)
565 | x = self.attention(x)
566 | x = self.c_proj(x)
567 | return x[:, :, 0]
568 |
--------------------------------------------------------------------------------
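
The scale-shift conditioning implemented by `apply_conditions` above is the mechanism by which the time embedding and the encoder's latent modulate the UNet feature maps. Below is a minimal standalone sketch of that modulation (not part of the repository; tensor sizes are made up), kept outside the repo's classes:

```python
# Illustrative sketch of the scale-shift modulation in apply_conditions:
# each condition vector is broadcast over the spatial dims and applied as
# h = h * (bias + scale) [+ shift] after the first normalization layer.
import torch

h = torch.randn(2, 512, 8, 8)    # feature map (N, C, H, W)
emb = torch.randn(2, 512)        # time embedding: scale-only case
cond = torch.randn(2, 1024)      # encoder cond: scale + shift case (2 * C)

for each in (emb, cond):         # time first, then the encoder condition
    each = each[..., None, None]             # broadcast to (N, C or 2C, 1, 1)
    if each.shape[1] == h.shape[1] * 2:
        scale, shift = each.chunk(2, dim=1)  # channels carry both scale and shift
    else:
        scale, shift = each, None            # channels carry only the scale
    h = h * (1 + scale)                      # scale_bias = 1 by default
    if shift is not None:
        h = h + shift
```
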
/model/latentnet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from enum import Enum
4 | from typing import NamedTuple, Tuple
5 |
6 | import torch
7 | from choices import *
8 | from config_base import BaseConfig
9 | from torch import nn
10 | from torch.nn import init
11 |
12 | from .blocks import *
13 | from .nn import timestep_embedding
14 | from .unet import *
15 |
16 |
17 | class LatentNetType(Enum):
18 | none = 'none'
19 | # injecting inputs into the hidden layers
20 | skip = 'skip'
21 |
22 |
23 | class LatentNetReturn(NamedTuple):
24 | pred: torch.Tensor = None
25 |
26 |
27 | @dataclass
28 | class MLPSkipNetConfig(BaseConfig):
29 | """
30 | default MLP for the latent DPM in the paper!
31 | """
32 | num_channels: int
33 | skip_layers: Tuple[int]
34 | num_hid_channels: int
35 | num_layers: int
36 | num_time_emb_channels: int = 64
37 | activation: Activation = Activation.silu
38 | use_norm: bool = True
39 | condition_bias: float = 1
40 | dropout: float = 0
41 | last_act: Activation = Activation.none
42 | num_time_layers: int = 2
43 | time_last_act: bool = False
44 |
45 | def make_model(self):
46 | return MLPSkipNet(self)
47 |
48 |
49 | class MLPSkipNet(nn.Module):
50 | """
51 | concat x to hidden layers
52 |
53 | default MLP for the latent DPM in the paper!
54 | """
55 | def __init__(self, conf: MLPSkipNetConfig):
56 | super().__init__()
57 | self.conf = conf
58 |
59 | layers = []
60 | for i in range(conf.num_time_layers):
61 | if i == 0:
62 | a = conf.num_time_emb_channels
63 | b = conf.num_channels
64 | else:
65 | a = conf.num_channels
66 | b = conf.num_channels
67 | layers.append(nn.Linear(a, b))
68 | if i < conf.num_time_layers - 1 or conf.time_last_act:
69 | layers.append(conf.activation.get_act())
70 | self.time_embed = nn.Sequential(*layers)
71 |
72 | self.layers = nn.ModuleList([])
73 | for i in range(conf.num_layers):
74 | if i == 0:
75 | act = conf.activation
76 | norm = conf.use_norm
77 | cond = True
78 | a, b = conf.num_channels, conf.num_hid_channels
79 | dropout = conf.dropout
80 | elif i == conf.num_layers - 1:
81 | act = Activation.none
82 | norm = False
83 | cond = False
84 | a, b = conf.num_hid_channels, conf.num_channels
85 | dropout = 0
86 | else:
87 | act = conf.activation
88 | norm = conf.use_norm
89 | cond = True
90 | a, b = conf.num_hid_channels, conf.num_hid_channels
91 | dropout = conf.dropout
92 |
93 | if i in conf.skip_layers:
94 | a += conf.num_channels
95 |
96 | self.layers.append(
97 | MLPLNAct(
98 | a,
99 | b,
100 | norm=norm,
101 | activation=act,
102 | cond_channels=conf.num_channels,
103 | use_cond=cond,
104 | condition_bias=conf.condition_bias,
105 | dropout=dropout,
106 | ))
107 | self.last_act = conf.last_act.get_act()
108 |
109 | def forward(self, x, t, **kwargs):
110 | t = timestep_embedding(t, self.conf.num_time_emb_channels)
111 | cond = self.time_embed(t)
112 | h = x
113 | for i in range(len(self.layers)):
114 | if i in self.conf.skip_layers:
115 | # injecting input into the hidden layers
116 | h = torch.cat([h, x], dim=1)
117 | h = self.layers[i].forward(x=h, cond=cond)
118 | h = self.last_act(h)
119 | return LatentNetReturn(h)
120 |
121 |
122 | class MLPLNAct(nn.Module):
123 | def __init__(
124 | self,
125 | in_channels: int,
126 | out_channels: int,
127 | norm: bool,
128 | use_cond: bool,
129 | activation: Activation,
130 | cond_channels: int,
131 | condition_bias: float = 0,
132 | dropout: float = 0,
133 | ):
134 | super().__init__()
135 | self.activation = activation
136 | self.condition_bias = condition_bias
137 | self.use_cond = use_cond
138 |
139 | self.linear = nn.Linear(in_channels, out_channels)
140 | self.act = activation.get_act()
141 | if self.use_cond:
142 | self.linear_emb = nn.Linear(cond_channels, out_channels)
143 | self.cond_layers = nn.Sequential(self.act, self.linear_emb)
144 | if norm:
145 | self.norm = nn.LayerNorm(out_channels)
146 | else:
147 | self.norm = nn.Identity()
148 |
149 | if dropout > 0:
150 | self.dropout = nn.Dropout(p=dropout)
151 | else:
152 | self.dropout = nn.Identity()
153 |
154 | self.init_weights()
155 |
156 | def init_weights(self):
157 | for module in self.modules():
158 | if isinstance(module, nn.Linear):
159 | if self.activation == Activation.relu:
160 | init.kaiming_normal_(module.weight,
161 | a=0,
162 | nonlinearity='relu')
163 | elif self.activation == Activation.lrelu:
164 | init.kaiming_normal_(module.weight,
165 | a=0.2,
166 | nonlinearity='leaky_relu')
167 | elif self.activation == Activation.silu:
168 | init.kaiming_normal_(module.weight,
169 | a=0,
170 | nonlinearity='relu')
171 | else:
172 | # leave it as default
173 | pass
174 |
175 | def forward(self, x, cond=None):
176 | x = self.linear(x)
177 | if self.use_cond:
178 | # (n, c) or (n, c * 2)
179 | cond = self.cond_layers(cond)
180 | cond = (cond, None)
181 |
182 | # scale shift first
183 | x = x * (self.condition_bias + cond[0])
184 | if cond[1] is not None:
185 | x = x + cond[1]
186 | # then norm
187 | x = self.norm(x)
188 | else:
189 | # no condition
190 | x = self.norm(x)
191 | x = self.act(x)
192 | x = self.dropout(x)
193 | return x
--------------------------------------------------------------------------------
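
A hypothetical usage sketch of the latent DPM network defined above, assuming the repository root is on PYTHONPATH; all hyperparameter values here are made up (the real ones come from the template configs):

```python
import torch
from model.latentnet import MLPSkipNetConfig

conf = MLPSkipNetConfig(
    num_channels=512,        # dimensionality of the semantic latent
    skip_layers=(1, 2, 3),   # hidden layers that re-receive the input
    num_hid_channels=1024,
    num_layers=5,
)
net = conf.make_model()      # MLPSkipNet

z = torch.randn(4, 512)                  # noisy latents
t = torch.randint(0, 1000, (4,))         # diffusion timesteps
out = net(z, t)                          # LatentNetReturn(pred=...)
print(out.pred.shape)                    # torch.Size([4, 512])
```
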
/model/nn.py:
--------------------------------------------------------------------------------
1 | """
2 | Various utilities for neural networks.
3 | """
4 |
5 | from enum import Enum
6 | import math
7 | from typing import Optional
8 |
9 | import torch as th
10 | import torch.nn as nn
11 | import torch.utils.checkpoint
12 |
13 | import torch.nn.functional as F
14 |
15 |
16 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
17 | class SiLU(nn.Module):
18 | # @th.jit.script
19 | def forward(self, x):
20 | return x * th.sigmoid(x)
21 |
22 |
23 | class GroupNorm32(nn.GroupNorm):
24 | def forward(self, x):
25 | return super().forward(x.float()).type(x.dtype)
26 |
27 |
28 | def conv_nd(dims, *args, **kwargs):
29 | """
30 | Create a 1D, 2D, or 3D convolution module.
31 | """
32 | if dims == 1:
33 | return nn.Conv1d(*args, **kwargs)
34 | elif dims == 2:
35 | return nn.Conv2d(*args, **kwargs)
36 | elif dims == 3:
37 | return nn.Conv3d(*args, **kwargs)
38 | raise ValueError(f"unsupported dimensions: {dims}")
39 |
40 |
41 | def linear(*args, **kwargs):
42 | """
43 | Create a linear module.
44 | """
45 | return nn.Linear(*args, **kwargs)
46 |
47 |
48 | def avg_pool_nd(dims, *args, **kwargs):
49 | """
50 | Create a 1D, 2D, or 3D average pooling module.
51 | """
52 | if dims == 1:
53 | return nn.AvgPool1d(*args, **kwargs)
54 | elif dims == 2:
55 | return nn.AvgPool2d(*args, **kwargs)
56 | elif dims == 3:
57 | return nn.AvgPool3d(*args, **kwargs)
58 | raise ValueError(f"unsupported dimensions: {dims}")
59 |
60 |
61 | def update_ema(target_params, source_params, rate=0.99):
62 | """
63 | Update target parameters to be closer to those of source parameters using
64 | an exponential moving average.
65 |
66 | :param target_params: the target parameter sequence.
67 | :param source_params: the source parameter sequence.
68 | :param rate: the EMA rate (closer to 1 means slower).
69 | """
70 | for targ, src in zip(target_params, source_params):
71 | targ.detach().mul_(rate).add_(src, alpha=1 - rate)
72 |
73 |
74 | def zero_module(module):
75 | """
76 | Zero out the parameters of a module and return it.
77 | """
78 | for p in module.parameters():
79 | p.detach().zero_()
80 | return module
81 |
82 |
83 | def scale_module(module, scale):
84 | """
85 | Scale the parameters of a module and return it.
86 | """
87 | for p in module.parameters():
88 | p.detach().mul_(scale)
89 | return module
90 |
91 |
92 | def mean_flat(tensor):
93 | """
94 | Take the mean over all non-batch dimensions.
95 | """
96 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
97 |
98 |
99 | def normalization(channels):
100 | """
101 | Make a standard normalization layer.
102 |
103 | :param channels: number of input channels.
104 | :return: an nn.Module for normalization.
105 | """
106 | return GroupNorm32(min(32, channels), channels)
107 |
108 |
109 | def timestep_embedding(timesteps, dim, max_period=10000):
110 | """
111 | Create sinusoidal timestep embeddings.
112 |
113 | :param timesteps: a 1-D Tensor of N indices, one per batch element.
114 | These may be fractional.
115 | :param dim: the dimension of the output.
116 | :param max_period: controls the minimum frequency of the embeddings.
117 | :return: an [N x dim] Tensor of positional embeddings.
118 | """
119 | half = dim // 2
120 | freqs = th.exp(-math.log(max_period) *
121 | th.arange(start=0, end=half, dtype=th.float32) /
122 | half).to(device=timesteps.device)
123 | args = timesteps[:, None].float() * freqs[None]
124 | embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
125 | if dim % 2:
126 | embedding = th.cat(
127 | [embedding, th.zeros_like(embedding[:, :1])], dim=-1)
128 | return embedding
129 |
130 |
131 | def torch_checkpoint(func, args, flag, preserve_rng_state=False):
132 | # torch's gradient checkpoint works with automatic mixed precision, given torch >= 1.8
133 | if flag:
134 | return torch.utils.checkpoint.checkpoint(
135 | func, *args, preserve_rng_state=preserve_rng_state)
136 | else:
137 | return func(*args)
138 |
--------------------------------------------------------------------------------
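
A small illustrative sketch of the sinusoidal embedding and EMA helpers above (assuming the repository root is on PYTHONPATH):

```python
import torch
from model.nn import timestep_embedding, update_ema

t = torch.tensor([0, 10, 500])
emb = timestep_embedding(t, dim=128)     # -> (3, 128), concatenated [cos | sin]
print(emb.shape)

ema_model = torch.nn.Linear(8, 8)
model = torch.nn.Linear(8, 8)
# pull the EMA weights 1% of the way toward the current weights
update_ema(ema_model.parameters(), model.parameters(), rate=0.99)
```
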
/model/unet.py:
--------------------------------------------------------------------------------
1 | import math
2 | from dataclasses import dataclass
3 | from numbers import Number
4 | from typing import NamedTuple, Tuple, Union
5 |
6 | import numpy as np
7 | import torch as th
8 | from torch import nn
9 | import torch.nn.functional as F
10 | from choices import *
11 | from config_base import BaseConfig
12 | from .blocks import *
13 |
14 | from .nn import (conv_nd, linear, normalization, timestep_embedding,
15 | torch_checkpoint, zero_module)
16 |
17 |
18 | @dataclass
19 | class BeatGANsUNetConfig(BaseConfig):
20 | image_size: int = 64
21 | in_channels: int = 3
22 | # base channels, will be multiplied
23 | model_channels: int = 64
24 | # output of the unet
25 | # suggest: 3
26 | # you only need 6 if you also model the variance of the noise prediction (usually we use an analytical variance hence 3)
27 | out_channels: int = 3
28 | # how many repeating resblocks per resolution
29 | # the decoding side would have "one more" resblock
30 | # default: 2
31 | num_res_blocks: int = 2
32 | # you can also set the number of resblocks specifically for the input blocks
33 | # default: None = above
34 | num_input_res_blocks: int = None
35 | # number of time embed channels and style channels
36 | embed_channels: int = 512
37 | # at what resolutions you want to do self-attention of the feature maps
38 | # attentions generally improve performance
39 | # default: [16]
40 | # beatgans: [32, 16, 8]
41 | attention_resolutions: Tuple[int] = (16, )
42 | # number of time embed channels
43 | time_embed_channels: int = None
44 | # dropout applies to the resblocks (on feature maps)
45 | dropout: float = 0.1
46 | channel_mult: Tuple[int] = (1, 2, 4, 8)
47 | input_channel_mult: Tuple[int] = None
48 | conv_resample: bool = True
49 | # always 2 = 2d conv
50 | dims: int = 2
51 | # don't use this, legacy from BeatGANs
52 | num_classes: int = None
53 | use_checkpoint: bool = False
54 | # number of attention heads
55 | num_heads: int = 1
56 | # or specify the number of channels per attention head
57 | num_head_channels: int = -1
58 | # number of attention heads for the decoder (upsampling) blocks; -1 = same as num_heads
59 | num_heads_upsample: int = -1
60 | # use resblock for upscale/downscale blocks (expensive)
61 | # default: True (BeatGANs)
62 | resblock_updown: bool = True
63 | # never tried
64 | use_new_attention_order: bool = False
65 | resnet_two_cond: bool = False
66 | resnet_cond_channels: int = None
67 | # init the decoding conv layers with zero weights, this speeds up training
68 | # default: True (BeatGANs)
69 | resnet_use_zero_module: bool = True
70 | # gradient checkpoint the attention operation
71 | attn_checkpoint: bool = False
72 |
73 | def make_model(self):
74 | return BeatGANsUNetModel(self)
75 |
76 |
77 | class BeatGANsUNetModel(nn.Module):
78 | def __init__(self, conf: BeatGANsUNetConfig):
79 | super().__init__()
80 | self.conf = conf
81 |
82 | if conf.num_heads_upsample == -1:
83 | self.num_heads_upsample = conf.num_heads
84 |
85 | self.dtype = th.float32
86 |
87 | self.time_emb_channels = conf.time_embed_channels or conf.model_channels
88 | self.time_embed = nn.Sequential(
89 | linear(self.time_emb_channels, conf.embed_channels),
90 | nn.SiLU(),
91 | linear(conf.embed_channels, conf.embed_channels),
92 | )
93 |
94 | if conf.num_classes is not None:
95 | self.label_emb = nn.Embedding(conf.num_classes,
96 | conf.embed_channels)
97 |
98 | ch = input_ch = int(conf.channel_mult[0] * conf.model_channels)
99 | self.input_blocks = nn.ModuleList([
100 | TimestepEmbedSequential(
101 | conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
102 | ])
103 |
104 | kwargs = dict(
105 | use_condition=True,
106 | two_cond=conf.resnet_two_cond,
107 | use_zero_module=conf.resnet_use_zero_module,
108 | # style channels for the resnet block
109 | cond_emb_channels=conf.resnet_cond_channels,
110 | )
111 |
112 | self._feature_size = ch
113 |
114 | # input_block_chans = [ch]
115 | input_block_chans = [[] for _ in range(len(conf.channel_mult))]
116 | input_block_chans[0].append(ch)
117 |
118 | # number of blocks at each resolution
119 | self.input_num_blocks = [0 for _ in range(len(conf.channel_mult))]
120 | self.input_num_blocks[0] = 1
121 | self.output_num_blocks = [0 for _ in range(len(conf.channel_mult))]
122 |
123 | ds = 1
124 | resolution = conf.image_size
125 | for level, mult in enumerate(conf.input_channel_mult
126 | or conf.channel_mult):
127 | for _ in range(conf.num_input_res_blocks or conf.num_res_blocks):
128 | layers = [
129 | ResBlockConfig(
130 | ch,
131 | conf.embed_channels,
132 | conf.dropout,
133 | out_channels=int(mult * conf.model_channels),
134 | dims=conf.dims,
135 | use_checkpoint=conf.use_checkpoint,
136 | **kwargs,
137 | ).make_model()
138 | ]
139 | ch = int(mult * conf.model_channels)
140 | if resolution in conf.attention_resolutions:
141 | layers.append(
142 | AttentionBlock(
143 | ch,
144 | use_checkpoint=conf.use_checkpoint
145 | or conf.attn_checkpoint,
146 | num_heads=conf.num_heads,
147 | num_head_channels=conf.num_head_channels,
148 | use_new_attention_order=conf.
149 | use_new_attention_order,
150 | ))
151 | self.input_blocks.append(TimestepEmbedSequential(*layers))
152 | self._feature_size += ch
153 | # input_block_chans.append(ch)
154 | input_block_chans[level].append(ch)
155 | self.input_num_blocks[level] += 1
156 | # print(input_block_chans)
157 | if level != len(conf.channel_mult) - 1:
158 | resolution //= 2
159 | out_ch = ch
160 | self.input_blocks.append(
161 | TimestepEmbedSequential(
162 | ResBlockConfig(
163 | ch,
164 | conf.embed_channels,
165 | conf.dropout,
166 | out_channels=out_ch,
167 | dims=conf.dims,
168 | use_checkpoint=conf.use_checkpoint,
169 | down=True,
170 | **kwargs,
171 | ).make_model() if conf.
172 | resblock_updown else Downsample(ch,
173 | conf.conv_resample,
174 | dims=conf.dims,
175 | out_channels=out_ch)))
176 | ch = out_ch
177 | # input_block_chans.append(ch)
178 | input_block_chans[level + 1].append(ch)
179 | self.input_num_blocks[level + 1] += 1
180 | ds *= 2
181 | self._feature_size += ch
182 |
183 | self.middle_block = TimestepEmbedSequential(
184 | ResBlockConfig(
185 | ch,
186 | conf.embed_channels,
187 | conf.dropout,
188 | dims=conf.dims,
189 | use_checkpoint=conf.use_checkpoint,
190 | **kwargs,
191 | ).make_model(),
192 | AttentionBlock(
193 | ch,
194 | use_checkpoint=conf.use_checkpoint or conf.attn_checkpoint,
195 | num_heads=conf.num_heads,
196 | num_head_channels=conf.num_head_channels,
197 | use_new_attention_order=conf.use_new_attention_order,
198 | ),
199 | ResBlockConfig(
200 | ch,
201 | conf.embed_channels,
202 | conf.dropout,
203 | dims=conf.dims,
204 | use_checkpoint=conf.use_checkpoint,
205 | **kwargs,
206 | ).make_model(),
207 | )
208 | self._feature_size += ch
209 |
210 | self.output_blocks = nn.ModuleList([])
211 | for level, mult in list(enumerate(conf.channel_mult))[::-1]:
212 | for i in range(conf.num_res_blocks + 1):
213 | # print(input_block_chans)
214 | # ich = input_block_chans.pop()
215 | try:
216 | ich = input_block_chans[level].pop()
217 | except IndexError:
218 | # this happens only when num_res_block > num_enc_res_block
219 | # we will not have enough lateral (skip) connections for all decoder blocks
220 | ich = 0
221 | # print('pop:', ich)
222 | layers = [
223 | ResBlockConfig(
224 | # only direct channels when gated
225 | channels=ch + ich,
226 | emb_channels=conf.embed_channels,
227 | dropout=conf.dropout,
228 | out_channels=int(conf.model_channels * mult),
229 | dims=conf.dims,
230 | use_checkpoint=conf.use_checkpoint,
231 | # lateral channels are described here when gated
232 | has_lateral=True if ich > 0 else False,
233 | lateral_channels=None,
234 | **kwargs,
235 | ).make_model()
236 | ]
237 | ch = int(conf.model_channels * mult)
238 | if resolution in conf.attention_resolutions:
239 | layers.append(
240 | AttentionBlock(
241 | ch,
242 | use_checkpoint=conf.use_checkpoint
243 | or conf.attn_checkpoint,
244 | num_heads=self.num_heads_upsample,
245 | num_head_channels=conf.num_head_channels,
246 | use_new_attention_order=conf.
247 | use_new_attention_order,
248 | ))
249 | if level and i == conf.num_res_blocks:
250 | resolution *= 2
251 | out_ch = ch
252 | layers.append(
253 | ResBlockConfig(
254 | ch,
255 | conf.embed_channels,
256 | conf.dropout,
257 | out_channels=out_ch,
258 | dims=conf.dims,
259 | use_checkpoint=conf.use_checkpoint,
260 | up=True,
261 | **kwargs,
262 | ).make_model() if (
263 | conf.resblock_updown
264 | ) else Upsample(ch,
265 | conf.conv_resample,
266 | dims=conf.dims,
267 | out_channels=out_ch))
268 | ds //= 2
269 | self.output_blocks.append(TimestepEmbedSequential(*layers))
270 | self.output_num_blocks[level] += 1
271 | self._feature_size += ch
272 |
273 | # print(input_block_chans)
274 | # print('inputs:', self.input_num_blocks)
275 | # print('outputs:', self.output_num_blocks)
276 |
277 | if conf.resnet_use_zero_module:
278 | self.out = nn.Sequential(
279 | normalization(ch),
280 | nn.SiLU(),
281 | zero_module(
282 | conv_nd(conf.dims,
283 | input_ch,
284 | conf.out_channels,
285 | 3,
286 | padding=1)),
287 | )
288 | else:
289 | self.out = nn.Sequential(
290 | normalization(ch),
291 | nn.SiLU(),
292 | conv_nd(conf.dims, input_ch, conf.out_channels, 3, padding=1),
293 | )
294 |
295 | def forward(self, x, t, y=None, **kwargs):
296 | """
297 | Apply the model to an input batch.
298 |
299 | :param x: an [N x C x ...] Tensor of inputs.
300 | :param t: a 1-D batch of timesteps.
301 | :param y: an [N] Tensor of labels, if class-conditional.
302 | :return: an [N x C x ...] Tensor of outputs.
303 | """
304 | assert (y is not None) == (
305 | self.conf.num_classes is not None
306 | ), "must specify y if and only if the model is class-conditional"
307 |
308 | # hs = []
309 | hs = [[] for _ in range(len(self.conf.channel_mult))]
310 | emb = self.time_embed(timestep_embedding(t, self.time_emb_channels))
311 |
312 | if self.conf.num_classes is not None:
313 | raise NotImplementedError()
314 | # assert y.shape == (x.shape[0], )
315 | # emb = emb + self.label_emb(y)
316 |
317 | # new code supports input_num_blocks != output_num_blocks
318 | h = x.type(self.dtype)
319 | k = 0
320 | for i in range(len(self.input_num_blocks)):
321 | for j in range(self.input_num_blocks[i]):
322 | h = self.input_blocks[k](h, emb=emb)
323 | # print(i, j, h.shape)
324 | hs[i].append(h)
325 | k += 1
326 | assert k == len(self.input_blocks)
327 |
328 | h = self.middle_block(h, emb=emb)
329 | k = 0
330 | for i in range(len(self.output_num_blocks)):
331 | for j in range(self.output_num_blocks[i]):
332 | # take the lateral connection from the same layer (in reverse order)
333 | # until there is none left, then use None
334 | try:
335 | lateral = hs[-i - 1].pop()
336 | # print(i, j, lateral.shape)
337 | except IndexError:
338 | lateral = None
339 | # print(i, j, lateral)
340 | h = self.output_blocks[k](h, emb=emb, lateral=lateral)
341 | k += 1
342 |
343 | h = h.type(x.dtype)
344 | pred = self.out(h)
345 | return Return(pred=pred)
346 |
347 |
348 | class Return(NamedTuple):
349 | pred: th.Tensor
350 |
351 |
352 | @dataclass
353 | class BeatGANsEncoderConfig(BaseConfig):
354 | image_size: int
355 | in_channels: int
356 | model_channels: int
357 | out_hid_channels: int
358 | out_channels: int
359 | num_res_blocks: int
360 | attention_resolutions: Tuple[int]
361 | dropout: float = 0
362 | channel_mult: Tuple[int] = (1, 2, 4, 8)
363 | use_time_condition: bool = True
364 | conv_resample: bool = True
365 | dims: int = 2
366 | use_checkpoint: bool = False
367 | num_heads: int = 1
368 | num_head_channels: int = -1
369 | resblock_updown: bool = False
370 | use_new_attention_order: bool = False
371 | pool: str = 'adaptivenonzero'
372 |
373 | def make_model(self):
374 | return BeatGANsEncoderModel(self)
375 |
376 |
377 | class BeatGANsEncoderModel(nn.Module):
378 | """
379 | The half UNet model with attention and timestep embedding.
380 |
381 | For usage, see UNet.
382 | """
383 | def __init__(self, conf: BeatGANsEncoderConfig):
384 | super().__init__()
385 | self.conf = conf
386 | self.dtype = th.float32
387 |
388 | if conf.use_time_condition:
389 | time_embed_dim = conf.model_channels * 4
390 | self.time_embed = nn.Sequential(
391 | linear(conf.model_channels, time_embed_dim),
392 | nn.SiLU(),
393 | linear(time_embed_dim, time_embed_dim),
394 | )
395 | else:
396 | time_embed_dim = None
397 |
398 | ch = int(conf.channel_mult[0] * conf.model_channels)
399 | self.input_blocks = nn.ModuleList([
400 | TimestepEmbedSequential(
401 | conv_nd(conf.dims, conf.in_channels, ch, 3, padding=1))
402 | ])
403 | self._feature_size = ch
404 | input_block_chans = [ch]
405 | ds = 1
406 | resolution = conf.image_size
407 | for level, mult in enumerate(conf.channel_mult):
408 | for _ in range(conf.num_res_blocks):
409 | layers = [
410 | ResBlockConfig(
411 | ch,
412 | time_embed_dim,
413 | conf.dropout,
414 | out_channels=int(mult * conf.model_channels),
415 | dims=conf.dims,
416 | use_condition=conf.use_time_condition,
417 | use_checkpoint=conf.use_checkpoint,
418 | ).make_model()
419 | ]
420 | ch = int(mult * conf.model_channels)
421 | if resolution in conf.attention_resolutions:
422 | layers.append(
423 | AttentionBlock(
424 | ch,
425 | use_checkpoint=conf.use_checkpoint,
426 | num_heads=conf.num_heads,
427 | num_head_channels=conf.num_head_channels,
428 | use_new_attention_order=conf.
429 | use_new_attention_order,
430 | ))
431 | self.input_blocks.append(TimestepEmbedSequential(*layers))
432 | self._feature_size += ch
433 | input_block_chans.append(ch)
434 | if level != len(conf.channel_mult) - 1:
435 | resolution //= 2
436 | out_ch = ch
437 | self.input_blocks.append(
438 | TimestepEmbedSequential(
439 | ResBlockConfig(
440 | ch,
441 | time_embed_dim,
442 | conf.dropout,
443 | out_channels=out_ch,
444 | dims=conf.dims,
445 | use_condition=conf.use_time_condition,
446 | use_checkpoint=conf.use_checkpoint,
447 | down=True,
448 | ).make_model() if (
449 | conf.resblock_updown
450 | ) else Downsample(ch,
451 | conf.conv_resample,
452 | dims=conf.dims,
453 | out_channels=out_ch)))
454 | ch = out_ch
455 | input_block_chans.append(ch)
456 | ds *= 2
457 | self._feature_size += ch
458 |
459 | self.middle_block = TimestepEmbedSequential(
460 | ResBlockConfig(
461 | ch,
462 | time_embed_dim,
463 | conf.dropout,
464 | dims=conf.dims,
465 | use_condition=conf.use_time_condition,
466 | use_checkpoint=conf.use_checkpoint,
467 | ).make_model(),
468 | AttentionBlock(
469 | ch,
470 | use_checkpoint=conf.use_checkpoint,
471 | num_heads=conf.num_heads,
472 | num_head_channels=conf.num_head_channels,
473 | use_new_attention_order=conf.use_new_attention_order,
474 | ),
475 | ResBlockConfig(
476 | ch,
477 | time_embed_dim,
478 | conf.dropout,
479 | dims=conf.dims,
480 | use_condition=conf.use_time_condition,
481 | use_checkpoint=conf.use_checkpoint,
482 | ).make_model(),
483 | )
484 | self._feature_size += ch
485 | if conf.pool == "adaptivenonzero":
486 | self.out = nn.Sequential(
487 | normalization(ch),
488 | nn.SiLU(),
489 | nn.AdaptiveAvgPool2d((1, 1)),
490 | conv_nd(conf.dims, ch, conf.out_channels, 1),
491 | nn.Flatten(),
492 | )
493 | else:
494 | raise NotImplementedError(f"Unexpected {conf.pool} pooling")
495 |
496 | def forward(self, x, t=None, return_2d_feature=False):
497 | """
498 | Apply the model to an input batch.
499 |
500 | :param x: an [N x C x ...] Tensor of inputs.
501 | :param t: a 1-D batch of timesteps.
502 | :return: an [N x K] Tensor of outputs.
503 | """
504 | if self.conf.use_time_condition:
505 | emb = self.time_embed(timestep_embedding(t, self.conf.model_channels))
506 | else:
507 | emb = None
508 |
509 | results = []
510 | h = x.type(self.dtype)
511 | for module in self.input_blocks:
512 | h = module(h, emb=emb)
513 | if self.conf.pool.startswith("spatial"):
514 | results.append(h.type(x.dtype).mean(dim=(2, 3)))
515 | h = self.middle_block(h, emb=emb)
516 | if self.conf.pool.startswith("spatial"):
517 | results.append(h.type(x.dtype).mean(dim=(2, 3)))
518 | h = th.cat(results, axis=-1)
519 | else:
520 | h = h.type(x.dtype)
521 |
522 | h_2d = h
523 | h = self.out(h)
524 |
525 | if return_2d_feature:
526 | return h, h_2d
527 | else:
528 | return h
529 |
530 | def forward_flatten(self, x):
531 | """
532 | transform the last 2d feature into a flatten vector
533 | """
534 | h = self.out(x)
535 | return h
536 |
537 |
538 | class SuperResModel(BeatGANsUNetModel):
539 | """
540 | A UNetModel that performs super-resolution.
541 |
542 | Expects an extra kwarg `low_res` to condition on a low-resolution image.
543 | """
544 | def __init__(self, image_size, in_channels, *args, **kwargs):
545 | super().__init__(image_size, in_channels * 2, *args, **kwargs)
546 |
547 | def forward(self, x, timesteps, low_res=None, **kwargs):
548 | _, _, new_height, new_width = x.shape
549 | upsampled = F.interpolate(low_res, (new_height, new_width),
550 | mode="bilinear")
551 | x = th.cat([x, upsampled], dim=1)
552 | return super().forward(x, timesteps, **kwargs)
553 |
--------------------------------------------------------------------------------
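
A rough usage sketch of the plain (non-autoencoder) UNet defined above, using mostly default hyperparameters; real training configurations are built in templates.py, and the repository root is assumed to be on PYTHONPATH:

```python
import torch
from model.unet import BeatGANsUNetConfig

conf = BeatGANsUNetConfig(image_size=64, in_channels=3, out_channels=3)
model = conf.make_model()                 # BeatGANsUNetModel

x_t = torch.randn(2, 3, 64, 64)           # noisy images
t = torch.randint(0, 1000, (2,))          # diffusion timesteps
eps_pred = model(x_t, t).pred             # noise prediction, same shape as x_t
```
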
/model/unet_autoenc.py:
--------------------------------------------------------------------------------
1 | from enum import Enum
2 |
3 | import torch
4 | from torch import Tensor
5 | from torch.nn.functional import silu
6 |
7 | from .latentnet import *
8 | from .unet import *
9 | from choices import *
10 |
11 |
12 | @dataclass
13 | class BeatGANsAutoencConfig(BeatGANsUNetConfig):
14 | # number of style channels
15 | enc_out_channels: int = 512
16 | enc_attn_resolutions: Tuple[int] = None
17 | enc_pool: str = 'depthconv'
18 | enc_num_res_block: int = 2
19 | enc_channel_mult: Tuple[int] = None
20 | enc_grad_checkpoint: bool = False
21 | latent_net_conf: MLPSkipNetConfig = None
22 |
23 | def make_model(self):
24 | return BeatGANsAutoencModel(self)
25 |
26 |
27 | class BeatGANsAutoencModel(BeatGANsUNetModel):
28 | def __init__(self, conf: BeatGANsAutoencConfig):
29 | super().__init__(conf)
30 | self.conf = conf
31 |
32 | # having only time, cond
33 | self.time_embed = TimeStyleSeperateEmbed(
34 | time_channels=conf.model_channels,
35 | time_out_channels=conf.embed_channels,
36 | )
37 |
38 | self.encoder = BeatGANsEncoderConfig(
39 | image_size=conf.image_size,
40 | in_channels=conf.in_channels,
41 | model_channels=conf.model_channels,
42 | out_hid_channels=conf.enc_out_channels,
43 | out_channels=conf.enc_out_channels,
44 | num_res_blocks=conf.enc_num_res_block,
45 | attention_resolutions=(conf.enc_attn_resolutions
46 | or conf.attention_resolutions),
47 | dropout=conf.dropout,
48 | channel_mult=conf.enc_channel_mult or conf.channel_mult,
49 | use_time_condition=False,
50 | conv_resample=conf.conv_resample,
51 | dims=conf.dims,
52 | use_checkpoint=conf.use_checkpoint or conf.enc_grad_checkpoint,
53 | num_heads=conf.num_heads,
54 | num_head_channels=conf.num_head_channels,
55 | resblock_updown=conf.resblock_updown,
56 | use_new_attention_order=conf.use_new_attention_order,
57 | pool=conf.enc_pool,
58 | ).make_model()
59 |
60 | if conf.latent_net_conf is not None:
61 | self.latent_net = conf.latent_net_conf.make_model()
62 |
63 | def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
64 | """
65 | Reparameterization trick to sample from N(mu, var) from
66 | N(0,1).
67 | :param mu: (Tensor) Mean of the latent Gaussian [B x D]
68 | :param logvar: (Tensor) Log variance of the latent Gaussian [B x D]
69 | :return: (Tensor) [B x D]
70 | """
71 | assert self.conf.is_stochastic
72 | std = torch.exp(0.5 * logvar)
73 | eps = torch.randn_like(std)
74 | return eps * std + mu
75 |
76 | def sample_z(self, n: int, device):
77 | assert self.conf.is_stochastic
78 | return torch.randn(n, self.conf.enc_out_channels, device=device)
79 |
80 | def noise_to_cond(self, noise: Tensor):
81 | raise NotImplementedError()
82 | assert self.conf.noise_net_conf is not None
83 | return self.noise_net.forward(noise)
84 |
85 | def encode(self, x):
86 | cond = self.encoder.forward(x)
87 | return {'cond': cond}
88 |
89 | @property
90 | def stylespace_sizes(self):
91 | modules = list(self.input_blocks.modules()) + list(
92 | self.middle_block.modules()) + list(self.output_blocks.modules())
93 | sizes = []
94 | for module in modules:
95 | if isinstance(module, ResBlock):
96 | linear = module.cond_emb_layers[-1]
97 | sizes.append(linear.weight.shape[0])
98 | return sizes
99 |
100 | def encode_stylespace(self, x, return_vector: bool = True):
101 | """
102 | encode to style space
103 | """
104 | modules = list(self.input_blocks.modules()) + list(
105 | self.middle_block.modules()) + list(self.output_blocks.modules())
106 | # (n, c)
107 | cond = self.encoder.forward(x)
108 | S = []
109 | for module in modules:
110 | if isinstance(module, ResBlock):
111 | # (n, c')
112 | s = module.cond_emb_layers.forward(cond)
113 | S.append(s)
114 |
115 | if return_vector:
116 | # (n, sum_c)
117 | return torch.cat(S, dim=1)
118 | else:
119 | return S
120 |
121 | def forward(self,
122 | x,
123 | t,
124 | y=None,
125 | x_start=None,
126 | cond=None,
127 | style=None,
128 | noise=None,
129 | t_cond=None,
130 | **kwargs):
131 | """
132 | Apply the model to an input batch.
133 |
134 | Args:
135 | x_start: the original image to encode
136 | cond: output of the encoder
137 | noise: random noise (to predict the cond)
138 | """
139 |
140 | if t_cond is None:
141 | t_cond = t
142 |
143 | if noise is not None:
144 | # if the noise is given, we predict the cond from noise
145 | cond = self.noise_to_cond(noise)
146 |
147 | if cond is None:
148 | if x is not None:
149 | assert len(x) == len(x_start), f'{len(x)} != {len(x_start)}'
150 |
151 | tmp = self.encode(x_start)
152 | cond = tmp['cond']
153 |
154 | if t is not None:
155 | _t_emb = timestep_embedding(t, self.conf.model_channels)
156 | _t_cond_emb = timestep_embedding(t_cond, self.conf.model_channels)
157 | else:
158 | # this happens when training only autoenc
159 | _t_emb = None
160 | _t_cond_emb = None
161 |
162 | if self.conf.resnet_two_cond:
163 | res = self.time_embed.forward(
164 | time_emb=_t_emb,
165 | cond=cond,
166 | time_cond_emb=_t_cond_emb,
167 | )
168 | else:
169 | raise NotImplementedError()
170 |
171 | if self.conf.resnet_two_cond:
172 | # two cond: first = time emb, second = cond_emb
173 | emb = res.time_emb
174 | cond_emb = res.emb
175 | else:
176 | # one cond = combined of both time and cond
177 | emb = res.emb
178 | cond_emb = None
179 |
180 | # override the style if given
181 | style = style or res.style
182 |
183 | assert (y is not None) == (
184 | self.conf.num_classes is not None
185 | ), "must specify y if and only if the model is class-conditional"
186 |
187 | if self.conf.num_classes is not None:
188 | raise NotImplementedError()
189 | # assert y.shape == (x.shape[0], )
190 | # emb = emb + self.label_emb(y)
191 |
192 | # where in the model to supply time conditions
193 | enc_time_emb = emb
194 | mid_time_emb = emb
195 | dec_time_emb = emb
196 | # where in the model to supply style conditions
197 | enc_cond_emb = cond_emb
198 | mid_cond_emb = cond_emb
199 | dec_cond_emb = cond_emb
200 |
201 | # hs = []
202 | hs = [[] for _ in range(len(self.conf.channel_mult))]
203 |
204 | if x is not None:
205 | h = x.type(self.dtype)
206 |
207 | # input blocks
208 | k = 0
209 | for i in range(len(self.input_num_blocks)):
210 | for j in range(self.input_num_blocks[i]):
211 | h = self.input_blocks[k](h,
212 | emb=enc_time_emb,
213 | cond=enc_cond_emb)
214 |
215 | # print(i, j, h.shape)
216 | hs[i].append(h)
217 | k += 1
218 | assert k == len(self.input_blocks)
219 |
220 | # middle blocks
221 | h = self.middle_block(h, emb=mid_time_emb, cond=mid_cond_emb)
222 | else:
223 | # no lateral connections
224 | # happens when training only the autoencoder
225 | h = None
226 | hs = [[] for _ in range(len(self.conf.channel_mult))]
227 |
228 | # output blocks
229 | k = 0
230 | for i in range(len(self.output_num_blocks)):
231 | for j in range(self.output_num_blocks[i]):
232 | # take the lateral connection from the same layer (in reverse order)
233 | # until there is none left, then use None
234 | try:
235 | lateral = hs[-i - 1].pop()
236 | # print(i, j, lateral.shape)
237 | except IndexError:
238 | lateral = None
239 | # print(i, j, lateral)
240 |
241 | h = self.output_blocks[k](h,
242 | emb=dec_time_emb,
243 | cond=dec_cond_emb,
244 | lateral=lateral)
245 | k += 1
246 |
247 | pred = self.out(h)
248 | return AutoencReturn(pred=pred, cond=cond)
249 |
250 |
251 | class AutoencReturn(NamedTuple):
252 | pred: Tensor
253 | cond: Tensor = None
254 |
255 |
256 | class EmbedReturn(NamedTuple):
257 | # style and time
258 | emb: Tensor = None
259 | # time only
260 | time_emb: Tensor = None
261 | # style only (but could depend on time)
262 | style: Tensor = None
263 |
264 |
265 | class TimeStyleSeperateEmbed(nn.Module):
266 | # embed only style
267 | def __init__(self, time_channels, time_out_channels):
268 | super().__init__()
269 | self.time_embed = nn.Sequential(
270 | linear(time_channels, time_out_channels),
271 | nn.SiLU(),
272 | linear(time_out_channels, time_out_channels),
273 | )
274 | self.style = nn.Identity()
275 |
276 | def forward(self, time_emb=None, cond=None, **kwargs):
277 | if time_emb is None:
278 | # happens with autoenc training mode
279 | time_emb = None
280 | else:
281 | time_emb = self.time_embed(time_emb)
282 | style = self.style(cond)
283 | return EmbedReturn(emb=style, time_emb=time_emb, style=style)
284 |
--------------------------------------------------------------------------------
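
A conceptual sketch of driving the autoencoder UNet above directly. The hyperparameters below are made up (real configurations come from templates.py); `resnet_two_cond` and `enc_pool` are set to the values the forward pass above expects, and the repository root is assumed to be on PYTHONPATH:

```python
import torch
from model.unet_autoenc import BeatGANsAutoencConfig

conf = BeatGANsAutoencConfig(
    image_size=64,
    in_channels=3,
    out_channels=3,
    embed_channels=512,
    enc_out_channels=512,      # size of the semantic latent from the encoder
    resnet_two_cond=True,      # time emb and semantic latent injected separately
    enc_pool='adaptivenonzero',
)
model = conf.make_model()      # BeatGANsAutoencModel

x0 = torch.randn(2, 3, 64, 64)            # clean images to encode
x_t = torch.randn(2, 3, 64, 64)           # noisy images at timestep t
t = torch.randint(0, 1000, (2,))

cond = model.encode(x0)['cond']           # semantic latent (2, 512)
out = model(x_t, t, x_start=x0)           # or pass cond=cond directly
eps_pred, z_sem = out.pred, out.cond      # noise prediction and the latent used
```
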
/predict.py:
--------------------------------------------------------------------------------
1 | # pre-download the weights for 256 resolution model to checkpoints/ffhq256_autoenc and checkpoints/ffhq256_autoenc_cls
2 | # wget http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
3 | # bunzip2 shape_predictor_68_face_landmarks.dat.bz2
4 |
5 | import os
6 | import torch
7 | from torchvision.utils import save_image
8 | import tempfile
9 | from templates import *
10 | from templates_cls import *
11 | from experiment_classifier import ClsModel
12 | from align import LandmarksDetector, image_align
13 | from cog import BasePredictor, Path, Input, BaseModel
14 |
15 |
16 | class ModelOutput(BaseModel):
17 | image: Path
18 |
19 |
20 | class Predictor(BasePredictor):
21 | def setup(self):
22 | self.aligned_dir = "aligned"
23 | os.makedirs(self.aligned_dir, exist_ok=True)
24 | self.device = "cuda:0"
25 |
26 | # Model Initialization
27 | model_config = ffhq256_autoenc()
28 | self.model = LitModel(model_config)
29 | state = torch.load("checkpoints/ffhq256_autoenc/last.ckpt", map_location="cpu")
30 | self.model.load_state_dict(state["state_dict"], strict=False)
31 | self.model.ema_model.eval()
32 | self.model.ema_model.to(self.device)
33 |
34 | # Classifier Initialization
35 | classifier_config = ffhq256_autoenc_cls()
36 | classifier_config.pretrain = None # a bit faster
37 | self.classifier = ClsModel(classifier_config)
38 | state_class = torch.load(
39 | "checkpoints/ffhq256_autoenc_cls/last.ckpt", map_location="cpu"
40 | )
41 | print("latent step:", state_class["global_step"])
42 | self.classifier.load_state_dict(state_class["state_dict"], strict=False)
43 | self.classifier.to(self.device)
44 |
45 | self.landmarks_detector = LandmarksDetector(
46 | "shape_predictor_68_face_landmarks.dat"
47 | )
48 |
49 | def predict(
50 | self,
51 | image: Path = Input(
52 | description="Input image for face manipulation. Image will be aligned and cropped, "
53 | "output aligned and manipulated images.",
54 | ),
55 | target_class: str = Input(
56 | default="Bangs",
57 | choices=[
58 | "5_o_Clock_Shadow",
59 | "Arched_Eyebrows",
60 | "Attractive",
61 | "Bags_Under_Eyes",
62 | "Bald",
63 | "Bangs",
64 | "Big_Lips",
65 | "Big_Nose",
66 | "Black_Hair",
67 | "Blond_Hair",
68 | "Blurry",
69 | "Brown_Hair",
70 | "Bushy_Eyebrows",
71 | "Chubby",
72 | "Double_Chin",
73 | "Eyeglasses",
74 | "Goatee",
75 | "Gray_Hair",
76 | "Heavy_Makeup",
77 | "High_Cheekbones",
78 | "Male",
79 | "Mouth_Slightly_Open",
80 | "Mustache",
81 | "Narrow_Eyes",
82 | "Beard",
83 | "Oval_Face",
84 | "Pale_Skin",
85 | "Pointy_Nose",
86 | "Receding_Hairline",
87 | "Rosy_Cheeks",
88 | "Sideburns",
89 | "Smiling",
90 | "Straight_Hair",
91 | "Wavy_Hair",
92 | "Wearing_Earrings",
93 | "Wearing_Hat",
94 | "Wearing_Lipstick",
95 | "Wearing_Necklace",
96 | "Wearing_Necktie",
97 | "Young",
98 | ],
99 | description="Choose manipulation direction.",
100 | ),
101 | manipulation_amplitude: float = Input(
102 | default=0.3,
103 | ge=-0.5,
104 | le=0.5,
105 | description="When set too strong it would result in artifact as it could dominate the original image information.",
106 | ),
107 | T_step: int = Input(
108 | default=100,
109 | choices=[50, 100, 125, 200, 250, 500],
110 | description="Number of step for generation.",
111 | ),
112 | T_inv: int = Input(default=200, choices=[50, 100, 125, 200, 250, 500]),
113 | ) -> List[ModelOutput]:
114 |
115 | img_size = 256
116 | print("Aligning image...")
117 | for i, face_landmarks in enumerate(
118 | self.landmarks_detector.get_landmarks(str(image)), start=1
119 | ):
120 | image_align(str(image), f"{self.aligned_dir}/aligned.png", face_landmarks)
121 |
122 | data = ImageDataset(
123 | self.aligned_dir,
124 | image_size=img_size,
125 | exts=["jpg", "jpeg", "JPG", "png"],
126 | do_augment=False,
127 | )
128 |
129 | print("Encoding and Manipulating the aligned image...")
130 | cls_manipulation_amplitude = manipulation_amplitude
131 | interpreted_target_class = target_class
132 | if (
133 | target_class not in CelebAttrDataset.id_to_cls
134 | and f"No_{target_class}" in CelebAttrDataset.id_to_cls
135 | ):
136 | cls_manipulation_amplitude = -manipulation_amplitude
137 | interpreted_target_class = f"No_{target_class}"
138 |
139 | batch = data[0]["img"][None]
140 |
141 | semantic_latent = self.model.encode(batch.to(self.device))
142 | stochastic_latent = self.model.encode_stochastic(
143 | batch.to(self.device), semantic_latent, T=T_inv
144 | )
145 |
146 | cls_id = CelebAttrDataset.cls_to_id[interpreted_target_class]
147 | class_direction = self.classifier.classifier.weight[cls_id]
148 | normalized_class_direction = F.normalize(class_direction[None, :], dim=1)
149 |
150 | normalized_semantic_latent = self.classifier.normalize(semantic_latent)
151 | normalized_manipulation_amp = cls_manipulation_amplitude * math.sqrt(512)
152 | normalized_manipulated_semantic_latent = (
153 | normalized_semantic_latent
154 | + normalized_manipulation_amp * normalized_class_direction
155 | )
156 |
157 | manipulated_semantic_latent = self.classifier.denormalize(
158 | normalized_manipulated_semantic_latent
159 | )
160 |
161 | # Render Manipulated image
162 | manipulated_img = self.model.render(
163 | stochastic_latent, manipulated_semantic_latent, T=T_step
164 | )[0]
165 | original_img = data[0]["img"]
166 |
167 | model_output = []
168 | out_path = Path(tempfile.mkdtemp()) / "original_aligned.png"
169 | save_image(convert2rgb(original_img), str(out_path))
170 | model_output.append(ModelOutput(image=out_path))
171 |
172 | out_path = Path(tempfile.mkdtemp()) / "manipulated_img.png"
173 | save_image(convert2rgb(manipulated_img, adjust_scale=False), str(out_path))
174 | model_output.append(ModelOutput(image=out_path))
175 | return model_output
176 |
177 |
178 | def convert2rgb(img, adjust_scale=True):
179 | convert_img = torch.tensor(img)
180 | if adjust_scale:
181 | convert_img = (convert_img + 1) / 2
182 | return convert_img.cpu()
183 |
--------------------------------------------------------------------------------
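
A standalone sketch of the manipulation step inside `predict()` above: the normalized semantic latent is moved along the classifier's weight row for the target attribute, with the amplitude scaled by sqrt(512). The tensors here are placeholders:

```python
import math
import torch
import torch.nn.functional as F

z = torch.randn(1, 512)                   # normalized semantic latent (placeholder)
w = torch.randn(512)                      # classifier weight row for the target class (placeholder)

direction = F.normalize(w[None, :], dim=1)
amplitude = 0.3 * math.sqrt(512)          # manipulation_amplitude scaled by sqrt(D)
z_manipulated = z + amplitude * direction # denormalize before rendering, as in predict()
```
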
/renderer.py:
--------------------------------------------------------------------------------
1 | from config import *
2 |
3 | from torch.cuda import amp
4 |
5 |
6 | def render_uncondition(conf: TrainConfig,
7 | model: BeatGANsAutoencModel,
8 | x_T,
9 | sampler: Sampler,
10 | latent_sampler: Sampler,
11 | conds_mean=None,
12 | conds_std=None,
13 | clip_latent_noise: bool = False):
14 | device = x_T.device
15 | if conf.train_mode == TrainMode.diffusion:
16 | assert conf.model_type.can_sample()
17 | return sampler.sample(model=model, noise=x_T)
18 | elif conf.train_mode.is_latent_diffusion():
19 | model: BeatGANsAutoencModel
20 | if conf.train_mode == TrainMode.latent_diffusion:
21 | latent_noise = torch.randn(len(x_T), conf.style_ch, device=device)
22 | else:
23 | raise NotImplementedError()
24 |
25 | if clip_latent_noise:
26 | latent_noise = latent_noise.clip(-1, 1)
27 |
28 | cond = latent_sampler.sample(
29 | model=model.latent_net,
30 | noise=latent_noise,
31 | clip_denoised=conf.latent_clip_sample,
32 | )
33 |
34 | if conf.latent_znormalize:
35 | cond = cond * conds_std.to(device) + conds_mean.to(device)
36 |
37 | # run the image diffusion sampler conditioned on the sampled latent
38 | return sampler.sample(model=model, noise=x_T, cond=cond)
39 | else:
40 | raise NotImplementedError()
41 |
42 |
43 | def render_condition(
44 | conf: TrainConfig,
45 | model: BeatGANsAutoencModel,
46 | x_T,
47 | sampler: Sampler,
48 | x_start=None,
49 | cond=None,
50 | ):
51 | if conf.train_mode == TrainMode.diffusion:
52 | assert conf.model_type.has_autoenc()
53 | # returns {'cond', 'cond2'}
54 | if cond is None:
55 | cond = model.encode(x_start)
56 | return sampler.sample(model=model,
57 | noise=x_T,
58 | model_kwargs={'cond': cond})
59 | else:
60 | raise NotImplementedError()
61 |
--------------------------------------------------------------------------------
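
A standalone sketch of the latent z-normalization step used by `render_uncondition` when `conf.latent_znormalize` is enabled: latents are modeled in whitened space and mapped back with the dataset statistics before decoding. The statistics below are placeholders:

```python
import torch

conds_mean = torch.zeros(512)             # per-dim mean of the inferred latents (placeholder)
conds_std = torch.ones(512)               # per-dim std of the inferred latents (placeholder)

z_white = torch.randn(8, 512)             # sample from the latent DPM (whitened space)
cond = z_white * conds_std + conds_mean   # back to the encoder's latent space before decoding
```
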
/requirement_for_colab.txt:
--------------------------------------------------------------------------------
1 | torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 pytorch-lightning==1.2.2 torchtext==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
2 | scipy==1.5.4
3 | numpy==1.19.5
4 | tqdm
5 | pytorch-fid==0.2.0
6 | pandas==1.1.5
7 | lpips==0.1.4
8 | lmdb==1.2.1
9 | ftfy
10 | regex
11 | dlib requests
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.4.5
2 | torchmetrics==0.5.0
3 | torch==1.8.1
4 | torchvision
5 | scipy==1.5.4
6 | numpy==1.19.5
7 | tqdm
8 | pytorch-fid==0.2.0
9 | pandas==1.1.5
10 | lpips==0.1.4
11 | lmdb==1.2.1
12 | ftfy
13 | regex
--------------------------------------------------------------------------------
/run_bedroom128.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | from templates_latent import *
3 |
4 | if __name__ == '__main__':
5 | # train the autoenc model
6 | # this requires V100s.
7 | gpus = [0, 1, 2, 3]
8 | conf = bedroom128_autoenc()
9 | train(conf, gpus=gpus)
10 |
11 | # infer the latents for training the latent DPM
12 | # NOTE: not gpu heavy, but more gpus can be of use!
13 | gpus = [0, 1, 2, 3]
14 | conf.eval_programs = ['infer']
15 | train(conf, gpus=gpus, mode='eval')
16 |
17 | # train the latent DPM
18 | # NOTE: only need a single gpu
19 | gpus = [0]
20 | conf = bedroom128_autoenc_latent()
21 | train(conf, gpus=gpus)
22 |
23 | # unconditional sampling score
24 | # NOTE: a lot of gpus can speed up this process
25 | gpus = [0, 1, 2, 3]
26 | conf.eval_programs = ['fid(10,10)']
27 | train(conf, gpus=gpus, mode='eval')
--------------------------------------------------------------------------------
/run_bedroom128_ddim.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | from templates_latent import *
3 |
4 | if __name__ == '__main__':
5 | gpus = [0, 1, 2, 3]
6 | conf = bedroom128_ddpm()
7 | train(conf, gpus=gpus)
8 |
9 | gpus = [0, 1, 2, 3]
10 | conf.eval_programs = ['fid10']
11 | train(conf, gpus=gpus, mode='eval')
--------------------------------------------------------------------------------
/run_celeba64.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | from templates_latent import *
3 |
4 | if __name__ == '__main__':
5 | # train the autoenc model
6 | # this can be run on 2080Ti's.
7 | gpus = [0, 1, 2, 3]
8 | conf = celeba64d2c_autoenc()
9 | train(conf, gpus=gpus)
10 |
11 | # infer the latents for training the latent DPM
12 | # NOTE: not gpu heavy, but more gpus can be of use!
13 | gpus = [0, 1, 2, 3]
14 | conf.eval_programs = ['infer']
15 | train(conf, gpus=gpus, mode='eval')
16 |
17 | # train the latent DPM
18 | # NOTE: only need a single gpu
19 | gpus = [0]
20 | conf = celeba64d2c_autoenc_latent()
21 | train(conf, gpus=gpus)
22 |
23 | # unconditional sampling score
24 | # NOTE: a lot of gpus can speed up this process
25 | gpus = [0, 1, 2, 3]
26 | conf.eval_programs = ['fid(10,10)']
27 | train(conf, gpus=gpus, mode='eval')
--------------------------------------------------------------------------------
/run_ffhq128.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | from templates_latent import *
3 |
4 | if __name__ == '__main__':
5 | # train the autoenc model
6 | # this requires V100s.
7 | gpus = [0, 1, 2, 3]
8 | conf = ffhq128_autoenc_130M()
9 | train(conf, gpus=gpus)
10 |
11 | # infer the latents for training the latent DPM
12 | # NOTE: not gpu heavy, but more gpus can be of use!
13 | gpus = [0, 1, 2, 3]
14 | conf.eval_programs = ['infer']
15 | train(conf, gpus=gpus, mode='eval')
16 |
17 | # train the latent DPM
18 | # NOTE: only need a single gpu
19 | gpus = [0]
20 | conf = ffhq128_autoenc_latent()
21 | train(conf, gpus=gpus)
22 |
23 | # unconditional sampling score
24 | # NOTE: a lot of gpus can speed up this process
25 | gpus = [0, 1, 2, 3]
26 | conf.eval_programs = ['fid(10,10)']
27 | train(conf, gpus=gpus, mode='eval')
--------------------------------------------------------------------------------
/run_ffhq128_cls.py:
--------------------------------------------------------------------------------
1 | from templates_cls import *
2 | from experiment_classifier import *
3 |
4 | if __name__ == '__main__':
5 | # need to first train the diffae autoencoding model & infer the latents
6 | # this requires only a single GPU.
7 | gpus = [0]
8 | conf = ffhq128_autoenc_cls()
9 | train_cls(conf, gpus=gpus)
10 |
11 | # after this you can do the manipulation!
12 |
--------------------------------------------------------------------------------
/run_ffhq128_ddim.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | from templates_latent import *
3 |
4 | if __name__ == '__main__':
5 | gpus = [0, 1, 2, 3]
6 | conf = ffhq128_ddpm_130M()
7 | train(conf, gpus=gpus)
8 |
9 | gpus = [0, 1, 2, 3]
10 | conf.eval_programs = ['fid10']
11 | train(conf, gpus=gpus, mode='eval')
--------------------------------------------------------------------------------
/run_ffhq256.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | from templates_latent import *
3 |
4 | if __name__ == '__main__':
5 | # 256 requires 8x v100s, in our case, on two nodes.
6 | # do not run this directly, use `sbatch run_ffhq256.sh` to spawn the srun properly.
7 | gpus = [0, 1, 2, 3]
8 | nodes = 2
9 | conf = ffhq256_autoenc()
10 | train(conf, gpus=gpus, nodes=nodes)
--------------------------------------------------------------------------------
/run_ffhq256.sh:
--------------------------------------------------------------------------------
1 | #!/bin/sh
2 | #SBATCH --gres=gpu:4
3 | #SBATCH --cpus-per-gpu=8
4 | #SBATCH --mem-per-gpu=32GB
5 | #SBATCH --nodes=2
6 | #SBATCH --ntasks=8
7 | #SBATCH --partition=gpu-cluster
8 | #SBATCH --time=72:00:00
9 |
10 | export NCCL_DEBUG=INFO
11 | export PYTHONFAULTHANDLER=1
12 |
13 | srun python run_ffhq256.py
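14 | 
15 | # Note: --ntasks=8 corresponds to 2 nodes x 4 GPUs per node, matching nodes=2 and
16 | # gpus=[0, 1, 2, 3] in run_ffhq256.py; partition, time, and memory values are
17 | # site-specific and will likely need adjusting for other clusters.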
--------------------------------------------------------------------------------
/run_ffhq256_cls.py:
--------------------------------------------------------------------------------
1 | from templates_cls import *
2 | from experiment_classifier import *
3 |
4 | if __name__ == '__main__':
5 | # need to first train the diffae autoencoding model & infer the latents
6 | # this requires only a single GPU.
7 | gpus = [0]
8 | conf = ffhq256_autoenc_cls()
9 | train_cls(conf, gpus=gpus)
10 |
11 | # after this you can do the manipulation!
12 |
--------------------------------------------------------------------------------
/run_ffhq256_latent.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | from templates_latent import *
3 |
4 | if __name__ == '__main__':
5 |     # run run_ffhq256.py first, before using this file to train the latent DPM
6 |
7 | # infer the latents for training the latent DPM
8 |     # NOTE: not GPU-heavy, but more GPUs still help!
9 | gpus = [0, 1, 2, 3]
10 | conf = ffhq256_autoenc()
11 | conf.eval_programs = ['infer']
12 | train(conf, gpus=gpus, mode='eval')
13 |
14 | # train the latent DPM
15 | # NOTE: only need a single gpu
16 | gpus = [0]
17 | conf = ffhq256_autoenc_latent()
18 | train(conf, gpus=gpus)
19 |
--------------------------------------------------------------------------------
/run_horse128.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | from templates_latent import *
3 |
4 | if __name__ == '__main__':
5 |     # train the autoencoder model
6 | # this requires V100s.
7 | gpus = [0, 1, 2, 3]
8 | conf = horse128_autoenc()
9 | train(conf, gpus=gpus)
10 |
11 | # infer the latents for training the latent DPM
12 |     # NOTE: not GPU-heavy, but more GPUs still help!
13 | gpus = [0, 1, 2, 3]
14 | conf.eval_programs = ['infer']
15 | train(conf, gpus=gpus, mode='eval')
16 |
17 | # train the latent DPM
18 | # NOTE: only need a single gpu
19 | gpus = [0]
20 | conf = horse128_autoenc_latent()
21 | train(conf, gpus=gpus)
22 |
23 | # unconditional sampling score
24 | # NOTE: a lot of gpus can speed up this process
25 | gpus = [0, 1, 2, 3]
26 | conf.eval_programs = ['fid(10,10)']
27 | train(conf, gpus=gpus, mode='eval')
--------------------------------------------------------------------------------
/run_horse128_ddim.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 | from templates_latent import *
3 |
4 | if __name__ == '__main__':
5 | gpus = [0, 1, 2, 3]
6 | conf = horse128_ddpm()
7 | train(conf, gpus=gpus)
8 |
9 | gpus = [0, 1, 2, 3]
10 | conf.eval_programs = ['fid10']
11 | train(conf, gpus=gpus, mode='eval')
--------------------------------------------------------------------------------
/ssim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | from math import exp
6 |
7 |
8 | def gaussian(window_size, sigma):
9 | gauss = torch.Tensor([
10 | exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
11 | for x in range(window_size)
12 | ])
13 | return gauss / gauss.sum()
14 |
15 |
16 | def create_window(window_size, channel):
17 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
18 | _2D_window = _1D_window.mm(
19 | _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
20 | window = Variable(
21 | _2D_window.expand(channel, 1, window_size, window_size).contiguous())
22 | return window
23 |
24 |
25 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
26 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
27 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
28 |
29 | mu1_sq = mu1.pow(2)
30 | mu2_sq = mu2.pow(2)
31 | mu1_mu2 = mu1 * mu2
32 |
33 | sigma1_sq = F.conv2d(
34 | img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
35 | sigma2_sq = F.conv2d(
36 | img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
37 | sigma12 = F.conv2d(
38 | img1 * img2, window, padding=window_size // 2,
39 | groups=channel) - mu1_mu2
40 |
41 | C1 = 0.01**2
42 | C2 = 0.03**2
43 |
44 | ssim_map = ((2 * mu1_mu2 + C1) *
45 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
46 | (sigma1_sq + sigma2_sq + C2))
47 |
48 | if size_average:
49 | return ssim_map.mean()
50 | else:
51 | return ssim_map.mean(1).mean(1).mean(1)
52 |
53 |
54 | class SSIM(torch.nn.Module):
55 | def __init__(self, window_size=11, size_average=True):
56 | super(SSIM, self).__init__()
57 | self.window_size = window_size
58 | self.size_average = size_average
59 | self.channel = 1
60 | self.window = create_window(window_size, self.channel)
61 |
62 | def forward(self, img1, img2):
63 | (_, channel, _, _) = img1.size()
64 |
65 | if channel == self.channel and self.window.data.type(
66 | ) == img1.data.type():
67 | window = self.window
68 | else:
69 | window = create_window(self.window_size, channel)
70 |
71 | if img1.is_cuda:
72 | window = window.cuda(img1.get_device())
73 | window = window.type_as(img1)
74 |
75 | self.window = window
76 | self.channel = channel
77 |
78 | return _ssim(img1, img2, window, self.window_size, channel,
79 | self.size_average)
80 |
81 |
82 | def ssim(img1, img2, window_size=11, size_average=True):
83 | (_, channel, _, _) = img1.size()
84 | window = create_window(window_size, channel)
85 |
86 | if img1.is_cuda:
87 | window = window.cuda(img1.get_device())
88 | window = window.type_as(img1)
89 |
90 | return _ssim(img1, img2, window, window_size, channel, size_average)
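91 | 
92 | 
93 | # Minimal usage sketch (assuming two image batches of identical shape NCHW on the
94 | # same device):
95 | #
96 | #   img1 = torch.rand(4, 3, 64, 64)
97 | #   img2 = torch.rand(4, 3, 64, 64)
98 | #   score = ssim(img1, img2)                          # scalar (size_average=True)
99 | #   per_image = ssim(img1, img2, size_average=False)  # one value per image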
--------------------------------------------------------------------------------
/templates.py:
--------------------------------------------------------------------------------
1 | from experiment import *
2 |
3 |
4 | def ddpm():
5 | """
6 | base configuration for all DDIM-based models.
7 | """
8 | conf = TrainConfig()
9 | conf.batch_size = 32
10 | conf.beatgans_gen_type = GenerativeType.ddim
11 | conf.beta_scheduler = 'linear'
12 | conf.data_name = 'ffhq'
13 | conf.diffusion_type = 'beatgans'
14 | conf.eval_ema_every_samples = 200_000
15 | conf.eval_every_samples = 200_000
16 | conf.fp16 = True
17 | conf.lr = 1e-4
18 | conf.model_name = ModelName.beatgans_ddpm
19 | conf.net_attn = (16, )
20 | conf.net_beatgans_attn_head = 1
21 | conf.net_beatgans_embed_channels = 512
22 | conf.net_ch_mult = (1, 2, 4, 8)
23 | conf.net_ch = 64
24 | conf.sample_size = 32
25 | conf.T_eval = 20
26 | conf.T = 1000
27 | conf.make_model_conf()
28 | return conf
29 |
30 |
31 | def autoenc_base():
32 | """
33 | base configuration for all Diff-AE models.
34 | """
35 | conf = TrainConfig()
36 | conf.batch_size = 32
37 | conf.beatgans_gen_type = GenerativeType.ddim
38 | conf.beta_scheduler = 'linear'
39 | conf.data_name = 'ffhq'
40 | conf.diffusion_type = 'beatgans'
41 | conf.eval_ema_every_samples = 200_000
42 | conf.eval_every_samples = 200_000
43 | conf.fp16 = True
44 | conf.lr = 1e-4
45 | conf.model_name = ModelName.beatgans_autoenc
46 | conf.net_attn = (16, )
47 | conf.net_beatgans_attn_head = 1
48 | conf.net_beatgans_embed_channels = 512
49 | conf.net_beatgans_resnet_two_cond = True
50 | conf.net_ch_mult = (1, 2, 4, 8)
51 | conf.net_ch = 64
52 | conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
53 | conf.net_enc_pool = 'adaptivenonzero'
54 | conf.sample_size = 32
55 | conf.T_eval = 20
56 | conf.T = 1000
57 | conf.make_model_conf()
58 | return conf
59 |
60 |
61 | def ffhq64_ddpm():
62 | conf = ddpm()
63 | conf.data_name = 'ffhqlmdb256'
64 | conf.warmup = 0
65 | conf.total_samples = 72_000_000
66 | conf.scale_up_gpus(4)
67 | return conf
68 |
69 |
70 | def ffhq64_autoenc():
71 | conf = autoenc_base()
72 | conf.data_name = 'ffhqlmdb256'
73 | conf.warmup = 0
74 | conf.total_samples = 72_000_000
75 | conf.net_ch_mult = (1, 2, 4, 8)
76 | conf.net_enc_channel_mult = (1, 2, 4, 8, 8)
77 | conf.eval_every_samples = 1_000_000
78 | conf.eval_ema_every_samples = 1_000_000
79 | conf.scale_up_gpus(4)
80 | conf.make_model_conf()
81 | return conf
82 |
83 |
84 | def celeba64d2c_ddpm():
85 | conf = ffhq128_ddpm()
86 | conf.data_name = 'celebalmdb'
87 | conf.eval_every_samples = 10_000_000
88 | conf.eval_ema_every_samples = 10_000_000
89 | conf.total_samples = 72_000_000
90 | conf.name = 'celeba64d2c_ddpm'
91 | return conf
92 |
93 |
94 | def celeba64d2c_autoenc():
95 | conf = ffhq64_autoenc()
96 | conf.data_name = 'celebalmdb'
97 | conf.eval_every_samples = 10_000_000
98 | conf.eval_ema_every_samples = 10_000_000
99 | conf.total_samples = 72_000_000
100 | conf.name = 'celeba64d2c_autoenc'
101 | return conf
102 |
103 |
104 | def ffhq128_ddpm():
105 | conf = ddpm()
106 | conf.data_name = 'ffhqlmdb256'
107 | conf.warmup = 0
108 | conf.total_samples = 48_000_000
109 | conf.img_size = 128
110 | conf.net_ch = 128
111 | # channels:
112 | # 3 => 128 * 1 => 128 * 1 => 128 * 2 => 128 * 3 => 128 * 4
113 | # sizes:
114 | # 128 => 128 => 64 => 32 => 16 => 8
115 | conf.net_ch_mult = (1, 1, 2, 3, 4)
116 | conf.eval_every_samples = 1_000_000
117 | conf.eval_ema_every_samples = 1_000_000
118 | conf.scale_up_gpus(4)
119 | conf.eval_ema_every_samples = 10_000_000
120 | conf.eval_every_samples = 10_000_000
121 | conf.make_model_conf()
122 | return conf
123 |
124 |
125 | def ffhq128_autoenc_base():
126 | conf = autoenc_base()
127 | conf.data_name = 'ffhqlmdb256'
128 | conf.scale_up_gpus(4)
129 | conf.img_size = 128
130 | conf.net_ch = 128
131 | # final resolution = 8x8
132 | conf.net_ch_mult = (1, 1, 2, 3, 4)
133 | # final resolution = 4x4
134 | conf.net_enc_channel_mult = (1, 1, 2, 3, 4, 4)
135 | conf.eval_ema_every_samples = 10_000_000
136 | conf.eval_every_samples = 10_000_000
137 | conf.make_model_conf()
138 | return conf
139 |
140 |
141 | def ffhq256_autoenc():
142 | conf = ffhq128_autoenc_base()
143 | conf.img_size = 256
144 | conf.net_ch = 128
145 | conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
146 | conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
147 | conf.eval_every_samples = 10_000_000
148 | conf.eval_ema_every_samples = 10_000_000
149 | conf.total_samples = 200_000_000
150 | conf.batch_size = 64
151 | conf.make_model_conf()
152 | conf.name = 'ffhq256_autoenc'
153 | return conf
154 |
155 |
156 | def ffhq256_autoenc_eco():
157 | conf = ffhq128_autoenc_base()
158 | conf.img_size = 256
159 | conf.net_ch = 128
160 | conf.net_ch_mult = (1, 1, 2, 2, 4, 4)
161 | conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4)
162 | conf.eval_every_samples = 10_000_000
163 | conf.eval_ema_every_samples = 10_000_000
164 | conf.total_samples = 200_000_000
165 | conf.batch_size = 64
166 | conf.make_model_conf()
167 | conf.name = 'ffhq256_autoenc_eco'
168 | return conf
169 |
170 |
171 | def ffhq128_ddpm_72M():
172 | conf = ffhq128_ddpm()
173 | conf.total_samples = 72_000_000
174 | conf.name = 'ffhq128_ddpm_72M'
175 | return conf
176 |
177 |
178 | def ffhq128_autoenc_72M():
179 | conf = ffhq128_autoenc_base()
180 | conf.total_samples = 72_000_000
181 | conf.name = 'ffhq128_autoenc_72M'
182 | return conf
183 |
184 |
185 | def ffhq128_ddpm_130M():
186 | conf = ffhq128_ddpm()
187 | conf.total_samples = 130_000_000
188 | conf.eval_ema_every_samples = 10_000_000
189 | conf.eval_every_samples = 10_000_000
190 | conf.name = 'ffhq128_ddpm_130M'
191 | return conf
192 |
193 |
194 | def ffhq128_autoenc_130M():
195 | conf = ffhq128_autoenc_base()
196 | conf.total_samples = 130_000_000
197 | conf.eval_ema_every_samples = 10_000_000
198 | conf.eval_every_samples = 10_000_000
199 | conf.name = 'ffhq128_autoenc_130M'
200 | return conf
201 |
202 |
203 | def horse128_ddpm():
204 | conf = ffhq128_ddpm()
205 | conf.data_name = 'horse256'
206 | conf.total_samples = 130_000_000
207 | conf.eval_ema_every_samples = 10_000_000
208 | conf.eval_every_samples = 10_000_000
209 | conf.name = 'horse128_ddpm'
210 | return conf
211 |
212 |
213 | def horse128_autoenc():
214 | conf = ffhq128_autoenc_base()
215 | conf.data_name = 'horse256'
216 | conf.total_samples = 130_000_000
217 | conf.eval_ema_every_samples = 10_000_000
218 | conf.eval_every_samples = 10_000_000
219 | conf.name = 'horse128_autoenc'
220 | return conf
221 |
222 |
223 | def bedroom128_ddpm():
224 | conf = ffhq128_ddpm()
225 | conf.data_name = 'bedroom256'
226 | conf.eval_ema_every_samples = 10_000_000
227 | conf.eval_every_samples = 10_000_000
228 | conf.total_samples = 120_000_000
229 | conf.name = 'bedroom128_ddpm'
230 | return conf
231 |
232 |
233 | def bedroom128_autoenc():
234 | conf = ffhq128_autoenc_base()
235 | conf.data_name = 'bedroom256'
236 | conf.eval_ema_every_samples = 10_000_000
237 | conf.eval_every_samples = 10_000_000
238 | conf.total_samples = 120_000_000
239 | conf.name = 'bedroom128_autoenc'
240 | return conf
241 |
242 |
243 | def pretrain_celeba64d2c_72M():
244 | conf = celeba64d2c_autoenc()
245 | conf.pretrain = PretrainConfig(
246 | name='72M',
247 | path=f'checkpoints/{celeba64d2c_autoenc().name}/last.ckpt',
248 | )
249 | conf.latent_infer_path = f'checkpoints/{celeba64d2c_autoenc().name}/latent.pkl'
250 | return conf
251 |
252 |
253 | def pretrain_ffhq128_autoenc72M():
254 | conf = ffhq128_autoenc_base()
255 | conf.postfix = ''
256 | conf.pretrain = PretrainConfig(
257 | name='72M',
258 | path=f'checkpoints/{ffhq128_autoenc_72M().name}/last.ckpt',
259 | )
260 | conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_72M().name}/latent.pkl'
261 | return conf
262 |
263 |
264 | def pretrain_ffhq128_autoenc130M():
265 | conf = ffhq128_autoenc_base()
266 | conf.pretrain = PretrainConfig(
267 | name='130M',
268 | path=f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
269 | )
270 | conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
271 | return conf
272 |
273 |
274 | def pretrain_ffhq256_autoenc():
275 | conf = ffhq256_autoenc()
276 | conf.pretrain = PretrainConfig(
277 | name='90M',
278 | path=f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
279 | )
280 | conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl'
281 | return conf
282 |
283 |
284 | def pretrain_horse128():
285 | conf = horse128_autoenc()
286 | conf.pretrain = PretrainConfig(
287 | name='82M',
288 | path=f'checkpoints/{horse128_autoenc().name}/last.ckpt',
289 | )
290 | conf.latent_infer_path = f'checkpoints/{horse128_autoenc().name}/latent.pkl'
291 | return conf
292 |
293 |
294 | def pretrain_bedroom128():
295 | conf = bedroom128_autoenc()
296 | conf.pretrain = PretrainConfig(
297 | name='120M',
298 | path=f'checkpoints/{bedroom128_autoenc().name}/last.ckpt',
299 | )
300 | conf.latent_infer_path = f'checkpoints/{bedroom128_autoenc().name}/latent.pkl'
301 | return conf
302 |
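303 | 
304 | # Minimal usage sketch (mirroring run_ffhq128.py; assumes train() as imported
305 | # from experiment above):
306 | #
307 | #   conf = ffhq128_autoenc_130M()
308 | #   train(conf, gpus=[0, 1, 2, 3])               # train the autoencoder
309 | #   conf.eval_programs = ['infer']
310 | #   train(conf, gpus=[0, 1, 2, 3], mode='eval')  # infer latents for the latent DPM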
--------------------------------------------------------------------------------
/templates_cls.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 |
3 |
4 | def ffhq128_autoenc_cls():
5 | conf = ffhq128_autoenc_130M()
6 | conf.train_mode = TrainMode.manipulate
7 | conf.manipulate_mode = ManipulateMode.celebahq_all
8 | conf.manipulate_znormalize = True
9 | conf.latent_infer_path = f'checkpoints/{ffhq128_autoenc_130M().name}/latent.pkl'
10 | conf.batch_size = 32
11 | conf.lr = 1e-3
12 | conf.total_samples = 300_000
13 |     # use the pretraining trick instead of the continuing trick
14 | conf.pretrain = PretrainConfig(
15 | '130M',
16 | f'checkpoints/{ffhq128_autoenc_130M().name}/last.ckpt',
17 | )
18 | conf.name = 'ffhq128_autoenc_cls'
19 | return conf
20 |
21 |
22 | def ffhq256_autoenc_cls():
23 |     '''We first train the encoder on the FFHQ dataset, then use it as a pretrained model to train a linear classifier on the CelebA dataset with attribute labels.'''
24 | conf = ffhq256_autoenc()
25 | conf.train_mode = TrainMode.manipulate
26 | conf.manipulate_mode = ManipulateMode.celebahq_all
27 | conf.manipulate_znormalize = True
28 |     conf.latent_infer_path = f'checkpoints/{ffhq256_autoenc().name}/latent.pkl' # we train on the CelebA dataset, not FFHQ
29 | conf.batch_size = 32
30 | conf.lr = 1e-3
31 | conf.total_samples = 300_000
32 |     # use the pretraining trick instead of the continuing trick
33 | conf.pretrain = PretrainConfig(
34 | '130M',
35 | f'checkpoints/{ffhq256_autoenc().name}/last.ckpt',
36 | )
37 | conf.name = 'ffhq256_autoenc_cls'
38 | return conf
39 |
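40 | 
41 | # Minimal usage sketch (mirroring run_ffhq128_cls.py; assumes the autoencoder
42 | # checkpoint and inferred latents already exist):
43 | #
44 | #   conf = ffhq128_autoenc_cls()
45 | #   train_cls(conf, gpus=[0])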
--------------------------------------------------------------------------------
/templates_latent.py:
--------------------------------------------------------------------------------
1 | from templates import *
2 |
3 |
4 | def latent_diffusion_config(conf: TrainConfig):
5 | conf.batch_size = 128
6 | conf.train_mode = TrainMode.latent_diffusion
7 | conf.latent_gen_type = GenerativeType.ddim
8 | conf.latent_loss_type = LossType.mse
9 | conf.latent_model_mean_type = ModelMeanType.eps
10 | conf.latent_model_var_type = ModelVarType.fixed_large
11 | conf.latent_rescale_timesteps = False
12 | conf.latent_clip_sample = False
13 | conf.latent_T_eval = 20
14 | conf.latent_znormalize = True
15 | conf.total_samples = 96_000_000
16 | conf.sample_every_samples = 400_000
17 | conf.eval_every_samples = 20_000_000
18 | conf.eval_ema_every_samples = 20_000_000
19 | conf.save_every_samples = 2_000_000
20 | return conf
21 |
22 |
23 | def latent_diffusion128_config(conf: TrainConfig):
24 | conf = latent_diffusion_config(conf)
25 | conf.batch_size_eval = 32
26 | return conf
27 |
28 |
29 | def latent_mlp_2048_norm_10layers(conf: TrainConfig):
30 | conf.net_latent_net_type = LatentNetType.skip
31 | conf.net_latent_layers = 10
32 | conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
33 | conf.net_latent_activation = Activation.silu
34 | conf.net_latent_num_hid_channels = 2048
35 | conf.net_latent_use_norm = True
36 | conf.net_latent_condition_bias = 1
37 | return conf
38 |
39 |
40 | def latent_mlp_2048_norm_20layers(conf: TrainConfig):
41 | conf = latent_mlp_2048_norm_10layers(conf)
42 | conf.net_latent_layers = 20
43 | conf.net_latent_skip_layers = list(range(1, conf.net_latent_layers))
44 | return conf
45 |
46 |
47 | def latent_256_batch_size(conf: TrainConfig):
48 | conf.batch_size = 256
49 | conf.eval_ema_every_samples = 100_000_000
50 | conf.eval_every_samples = 100_000_000
51 | conf.sample_every_samples = 1_000_000
52 | conf.save_every_samples = 2_000_000
53 | conf.total_samples = 301_000_000
54 | return conf
55 |
56 |
57 | def latent_512_batch_size(conf: TrainConfig):
58 | conf.batch_size = 512
59 | conf.eval_ema_every_samples = 100_000_000
60 | conf.eval_every_samples = 100_000_000
61 | conf.sample_every_samples = 1_000_000
62 | conf.save_every_samples = 5_000_000
63 | conf.total_samples = 501_000_000
64 | return conf
65 |
66 |
67 | def latent_2048_batch_size(conf: TrainConfig):
68 | conf.batch_size = 2048
69 | conf.eval_ema_every_samples = 200_000_000
70 | conf.eval_every_samples = 200_000_000
71 | conf.sample_every_samples = 4_000_000
72 | conf.save_every_samples = 20_000_000
73 | conf.total_samples = 1_501_000_000
74 | return conf
75 |
76 |
77 | def adamw_weight_decay(conf: TrainConfig):
78 | conf.optimizer = OptimizerType.adamw
79 | conf.weight_decay = 0.01
80 | return conf
81 |
82 |
83 | def ffhq128_autoenc_latent():
84 | conf = pretrain_ffhq128_autoenc130M()
85 | conf = latent_diffusion128_config(conf)
86 | conf = latent_mlp_2048_norm_10layers(conf)
87 | conf = latent_256_batch_size(conf)
88 | conf = adamw_weight_decay(conf)
89 | conf.total_samples = 101_000_000
90 | conf.latent_loss_type = LossType.l1
91 | conf.latent_beta_scheduler = 'const0.008'
92 | conf.name = 'ffhq128_autoenc_latent'
93 | return conf
94 |
95 |
96 | def ffhq256_autoenc_latent():
97 | conf = pretrain_ffhq256_autoenc()
98 | conf = latent_diffusion128_config(conf)
99 | conf = latent_mlp_2048_norm_10layers(conf)
100 | conf = latent_256_batch_size(conf)
101 | conf = adamw_weight_decay(conf)
102 | conf.total_samples = 101_000_000
103 | conf.latent_loss_type = LossType.l1
104 | conf.latent_beta_scheduler = 'const0.008'
105 | conf.eval_ema_every_samples = 200_000_000
106 | conf.eval_every_samples = 200_000_000
107 | conf.sample_every_samples = 4_000_000
108 | conf.name = 'ffhq256_autoenc_latent'
109 | return conf
110 |
111 |
112 | def horse128_autoenc_latent():
113 | conf = pretrain_horse128()
114 | conf = latent_diffusion128_config(conf)
115 | conf = latent_2048_batch_size(conf)
116 | conf = latent_mlp_2048_norm_20layers(conf)
117 | conf.total_samples = 2_001_000_000
118 | conf.latent_beta_scheduler = 'const0.008'
119 | conf.latent_loss_type = LossType.l1
120 | conf.name = 'horse128_autoenc_latent'
121 | return conf
122 |
123 |
124 | def bedroom128_autoenc_latent():
125 | conf = pretrain_bedroom128()
126 | conf = latent_diffusion128_config(conf)
127 | conf = latent_2048_batch_size(conf)
128 | conf = latent_mlp_2048_norm_20layers(conf)
129 | conf.total_samples = 2_001_000_000
130 | conf.latent_beta_scheduler = 'const0.008'
131 | conf.latent_loss_type = LossType.l1
132 | conf.name = 'bedroom128_autoenc_latent'
133 | return conf
134 |
135 |
136 | def celeba64d2c_autoenc_latent():
137 | conf = pretrain_celeba64d2c_72M()
138 | conf = latent_diffusion_config(conf)
139 | conf = latent_512_batch_size(conf)
140 | conf = latent_mlp_2048_norm_10layers(conf)
141 | conf = adamw_weight_decay(conf)
142 | # just for the name
143 | conf.continue_from = PretrainConfig('200M',
144 | f'log-latent/{conf.name}/last.ckpt')
145 | conf.postfix = '_300M'
146 | conf.total_samples = 301_000_000
147 | conf.latent_beta_scheduler = 'const0.008'
148 | conf.latent_loss_type = LossType.l1
149 | conf.name = 'celeba64d2c_autoenc_latent'
150 | return conf
151 |
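152 | 
153 | # Minimal usage sketch (mirroring run_ffhq128.py; the latent DPM trains on a
154 | # single GPU):
155 | #
156 | #   conf = ffhq128_autoenc_latent()
157 | #   train(conf, gpus=[0])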
--------------------------------------------------------------------------------