├── .gitignore
├── .vscode
└── settings.json
├── LICENSE
├── README.md
├── assets
├── demo.png
└── method.png
├── demo.ipynb
├── diffusion
├── __init__.py
├── diffusion_utils.py
├── gaussian_diffusion.py
└── respace.py
├── main_diffusion.py
├── main_reconstruction.py
├── models
├── __init__.py
├── ardiff.py
├── autoencoder.py
├── detok.py
├── diffloss.py
├── dit.py
├── ema.py
├── layers.py
├── lightningdit.py
├── mar.py
├── model_utils.py
└── sit.py
├── pyrightconfig.json
├── requirements.txt
├── transport
├── __init__.py
├── integrators.py
├── path.py
└── transport.py
└── utils
├── builders.py
├── distributed.py
├── download.py
├── loader.py
├── logger.py
├── losses.py
├── misc.py
└── train_utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | data
3 | work_dirs/
4 | *.mp4
5 | *draft*
6 | pretrained_models/
7 | code.txt
8 |
9 |
10 | # media
11 | *.mp4
12 | *.png
13 | *.jpg
14 |
15 |
16 | # wandb
17 | wandb/
18 |
19 | # work in progress
20 | *wip*
21 |
22 | *results*
23 | *debug*
24 | *tmp*
25 |
26 | # caches
27 | *.pyc
28 | *.swp
29 |
30 |
31 | __pycache__
32 |
33 | # checkpoints
34 | *.pt
35 | *.pth
36 | *.npz
37 | src/
38 |
39 | /guided-diffusion/
40 |
41 |
42 | .clangd
43 | compile_commands.json
44 |
45 | # Visual Studio Code configs.
46 | .vscode/
47 |
48 | # Byte-compiled / optimized / DLL files
49 | *.py[cod]
50 | *$py.class
51 |
52 | # C extensions
53 | *.so
54 |
55 | # Distribution / packaging
56 | .Python
57 | build/
58 | develop-eggs/
59 | dist/
60 | downloads/
61 | eggs/
62 | .eggs/
63 | # lib/
64 | lib64/
65 | parts/
66 | sdist/
67 | var/
68 | wheels/
69 | *.egg-info/
70 | .installed.cfg
71 | *.egg
72 | MANIFEST
73 |
74 | # PyInstaller
75 | # Usually these files are written by a python script from a template
76 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
77 | *.manifest
78 | *.spec
79 |
80 | # Installer logs
81 | pip-log.txt
82 | pip-delete-this-directory.txt
83 |
84 | # Unit test / coverage reports
85 | htmlcov/
86 | .tox/
87 | .coverage
88 | .coverage.*
89 | .cache
90 | nosetests.xml
91 | coverage.xml
92 | *.cover
93 | .hypothesis/
94 | .pytest_cache/
95 |
96 | # Translations
97 | *.mo
98 | *.pot
99 |
100 | # Django stuff:
101 | *.log
102 | local_settings.py
103 | db.sqlite3
104 |
105 | # Flask stuff:
106 | instance/
107 | .webassets-cache
108 |
109 | # Scrapy stuff:
110 | .scrapy
111 |
112 | # Sphinx documentation
113 | docs/_build/
114 |
115 | # PyBuilder
116 | target/
117 |
118 | # Jupyter Notebook
119 | .ipynb_checkpoints
120 |
121 | # pyenv
122 | .python-version
123 |
124 | # celery beat schedule file
125 | celerybeat-schedule
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 |
152 | .DS_Store
153 |
154 | # Direnv config.
155 | .envrc
156 |
157 | # line_profiler
158 | *.lprof
159 |
160 | *build
161 | compile_commands.json
162 | *.dump
163 |
164 | # modelling/
165 | legacy/
166 | released_model/
167 | scripts/
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | // Automatically format using Black on save.
3 | "editor.formatOnSave": true,
4 | // Draw a ruler at Black's column width.
5 | "editor.rulers": [
6 | 110,
7 | ],
8 | // Hide non-code files.
9 | "files.exclude": {
10 | "**/.git": true,
11 | "**/.svn": true,
12 | "**/.hg": true,
13 | "**/CVS": true,
14 | "**/.DS_Store": true,
15 | "**/Thumbs.db": true,
16 | "**/__pycache__": true,
17 | "**/venv": true
18 | },
19 | "[python]": {
20 | "editor.defaultFormatter": "ms-python.black-formatter",
21 | },
22 | "debug.focusWindowOnBreak": false,
23 | "files.watcherExclude": {
24 | "**/.git/**": true,
25 | "**/checkpoints/**": true,
26 | "**/data/**": true,
27 | "**/work_dirs/**": true,
28 | "**/lightning_logs/**": true,
29 | "**/outputs/**": true,
30 | "**/dataset_cache/**": true,
31 | "**/.ruff_cache/**": true,
32 | "**/venv/**": true,
33 | "**/data": true
34 | },
35 | "editor.mouseWheelZoom": true,
36 | "terminal.integrated.mouseWheelZoom": true
37 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Jiawei Yang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeTok: Latent Denoising Makes Good Visual Tokenizers
Official PyTorch Implementation
2 |
3 | [arXiv:2507.15856](https://arxiv.org/abs/2507.15856)
4 | [Hugging Face: l-DeTok](https://huggingface.co/jjiaweiyang/l-DeTok)
5 |
6 |
7 |
8 |
9 |
10 | This is a PyTorch/GPU implementation of the paper **Latent Denoising Makes Good Visual Tokenizers**:
11 |
12 | ```
13 | @article{yang2025detok,
14 | title={Latent Denoising Makes Good Visual Tokenizers},
15 | author={Jiawei Yang and Tianhong Li and Lijie Fan and Yonglong Tian and Yue Wang},
16 | journal={arXiv preprint arXiv:2507.15856},
17 | year={2025}
18 | }
19 | ```
20 |
21 | This repo contains:
22 |
23 | * 🪐 A simple PyTorch implementation of [l-DeTok](models/detok.py) tokenizer and various generative models ([MAR](models/mar.py), [RandARDiff](models/ardiff.py), [RasterARDiff](models/ardiff.py), [DiT](models/dit.py), [SiT](models/sit.py), and [LightningDiT](models/lightningdit.py))
24 | * ⚡️ Pre-trained DeTok tokenizers and MAR models trained on ImageNet 256x256
25 | * 🛸 Training and evaluation scripts for tokenizer and generative models
26 | * 🎉 Hugging Face for easy access to pre-trained models
27 |
28 | ## Preparation
29 |
30 |
31 | ### Installation
32 |
33 | Download the code:
34 | ```bash
35 | git clone https://github.com/Jiawei-Yang/detok.git
36 | cd detok
37 | ```
38 |
39 | Create and activate conda environment:
40 | ```bash
41 | conda create -n detok python=3.10 -y && conda activate detok
42 | pip install -r requirements.txt
43 | ```
44 |
45 | ### Dataset
46 | Create `data/` folder by `mkdir data/` then download [ImageNet](http://image-net.org/download) dataset. You can either:
47 | 1. Download it directly to `data/imagenet/`
48 | 2. Create a symbolic link: `ln -s /path/to/your/imagenet data/imagenet`
49 |
50 |
51 | ### Download Required Files
52 |
53 | Create data directory and download required files:
54 | ```bash
55 | mkdir -p data/
56 | # Download everything from huggingface
57 | huggingface-cli download jjiaweiyang/l-DeTok --local-dir released_model
58 | mv released_model/train.txt data/
59 | mv released_model/val.txt data/
60 | mv released_model/fid_stats data/
61 | mv released_model/imagenet-val-prc.zip ./data/
62 |
63 | # Unzip: imagenet-val-prc.zip for precision & recall evaluation
64 | python -m zipfile -e ./data/imagenet-val-prc.zip ./data/
65 | ```
66 |
67 | ### Data Organization
68 |
69 | Your data directory should be organized as follows:
70 | ```
71 | data/
72 | ├── fid_stats/ # FID statistics
73 | │ ├── adm_in256_stats.npz
74 | │ └── val_fid_statistics_file.npz
75 | ├── imagenet/ # ImageNet dataset (or symlink)
76 | │ ├── train/
77 | │ └── val/
78 | ├── imagenet-val-prc/ # Precision-recall data
79 | ├── train.txt # Training file list
80 | └── val.txt # Validation file list
81 | ```
82 |
83 | ## Models
84 |
85 | For convenience, our pre-trained models are available on Hugging Face:
86 |
87 | | Model | Type | Params | Hugging Face |
88 | |---------------------|-----------|-------|-------------|
89 | | DeTok-BB | Tokenizer | 172M | [🤗 detok-bb](https://huggingface.co/jjiaweiyang/l-DeTok/resolve/main/detok-BB-gamm3.0-m0.7.pth) |
90 | | DeTok-BB-decoder_ft | Tokenizer | 172M | [🤗 detok-bb-decoder_ft](https://huggingface.co/jjiaweiyang/l-DeTok/resolve/main/detok-BB-gamm3.0-m0.7-decoder_tuned.pth) |
91 | | MAR-Base | Generator | 208M | [🤗 mar-base](https://huggingface.co/jjiaweiyang/l-DeTok/resolve/main/mar_base.pth) |
92 | | MAR-Large | Generator | 479M | [🤗 mar-large](https://huggingface.co/jjiaweiyang/l-DeTok/resolve/main/mar_large.pth) |
93 |
94 | FID-50k with CFG:
95 | |cfg| MAR Model | FID-50K | Inception Score |
96 | |---|-----------------------------------|---------|-----------------|
97 | |3.9| MAR-Base + DeTok-BB | 1.61 | 289.7 |
98 | |3.9| MAR-Base + DeTok-BB-decoder_ft | 1.55 | 291.0 |
99 | |3.4| MAR-Large + DeTok-BB | 1.43 | 303.5 |
100 | |3.4| MAR-Large + DeTok-BB-decoder_ft | 1.32 | 304.1 |
101 |
102 | ## Usage
103 |
104 | ### Demo
105 | Run our demo with the notebook at [demo.ipynb](demo.ipynb).
106 |
107 |
108 | ## Training
109 |
110 | ### 1. Tokenizer Training
111 |
112 | Train DeTok tokenizer with denoising:
113 | ```bash
114 | project=tokenizer_training
115 | exp_name=detokBB-g3.0-m0.7-200ep
116 | batch_size=32 # global batch size = batch_size x num_nodes x 8 = 1024
117 | num_nodes=4 # adjust for your multi-node setup
118 | YOUR_WANDB_ENTITY="" # change to your wandb entity
119 |
120 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
121 | main_reconstruction.py \
122 | --project $project --exp_name $exp_name --auto_resume \
123 | --batch_size $batch_size --model detok_BB \
124 | --gamma 3.0 --mask_ratio 0.7 \
125 | --online_eval \
126 | --epochs 200 --discriminator_start_epoch 100 \
127 | --data_path ./data/imagenet/train \
128 | --entity $YOUR_WANDB_ENTITY --enable_wandb
129 | ```
130 |
131 | Decoder fine-tuning:
132 | ```bash
133 | project=tokenizer_training
134 | exp_name=detokBB-g3.0-m0.7-200ep-decoder_ft-100ep
135 | batch_size=32
136 | num_nodes=4
137 | pretrained_tok=work_dirs/tokenizer_training/detokBB-g3.0-m0.7-200ep/checkpoints/latest.pth
138 | YOUR_WANDB_ENTITY=""
139 |
140 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
141 | main_reconstruction.py \
142 | --project $project --exp_name $exp_name --auto_resume \
143 | --batch_size $batch_size --model detok_BB \
144 | --load_from $pretrained_tok \
145 | --online_eval --train_decoder_only \
146 | --perceptual_weight 0.1 \
147 | --gamma 0.0 --mask_ratio 0.0 \
148 | --blr 5e-5 --warmup_rate 0.05 \
149 | --epochs 100 --discriminator_start_epoch 0 \
150 | --data_path ./data/imagenet/train \
151 | --entity $YOUR_WANDB_ENTITY --enable_wandb
152 | ```
153 |
154 | ### 2. Generative Model Training
155 |
156 | Train MAR-base (100 epochs):
157 | ```bash
158 | tokenizer_project=tokenizer_training
159 | tokenizer_exp_name=detokBB-g3.0-m0.7-200ep-decoder_ft-100ep
160 | project=gen_model_training
161 | exp_name=mar_base-${tokenizer_exp_name}
162 | batch_size=32 # global batch size = batch_size x num_nodes x 8 = 1024
163 | num_nodes=4
164 | epochs=100
165 |
166 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
167 | main_diffusion.py \
168 | --project $project --exp_name $exp_name --auto_resume \
169 | --batch_size $batch_size --epochs $epochs --use_aligned_schedule \
170 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \
171 | --stats_key $tokenizer_exp_name --stats_cache_path work_dirs/stats.pkl \
172 | --load_tokenizer_from work_dirs/$tokenizer_project/$tokenizer_exp_name/checkpoints/latest.pth \
173 | --model MAR_base --no_dropout_in_mlp \
174 | --diffloss_d 6 --diffloss_w 1024 \
175 | --num_sampling_steps 100 --cfg 4.0 \
176 | --cfg_list 3.0 3.5 3.7 3.8 3.9 4.0 4.1 4.3 4.5 \
177 | --vis_freq 50 --eval_bsz 256 \
178 | --data_path ./data/imagenet/train \
179 | --entity $YOUR_WANDB_ENTITY --enable_wandb
180 | ```
181 |
182 | Train SiT-base (100 epochs):
183 | ```bash
184 | tokenizer_project=tokenizer_training
185 | tokenizer_exp_name=detokBB-g3.0-m0.7-200ep-decoder_ft-100ep
186 | project=gen_model_training
187 | exp_name=sit_base-${tokenizer_exp_name}
188 | batch_size=32
189 | num_nodes=4
190 | epochs=100
191 |
192 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
193 | main_diffusion.py \
194 | --project $project --exp_name $exp_name --auto_resume \
195 | --batch_size $batch_size --epochs $epochs --use_aligned_schedule \
196 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \
197 | --stats_key $tokenizer_exp_name --stats_cache_path work_dirs/stats.pkl \
198 | --load_tokenizer_from work_dirs/$tokenizer_project/$tokenizer_exp_name/checkpoints/latest.pth \
199 | --model SiT_base \
200 | --num_sampling_steps 250 --cfg 1.6 \
201 | --cfg_list 1.3 1.4 1.5 1.6 1.7 1.8 1.9 2.0 \
202 | --vis_freq 50 --eval_bsz 256 \
203 | --data_path ./data/imagenet/train \
204 | --entity $YOUR_WANDB_ENTITY --enable_wandb
205 | ```
206 |
207 | ### 3. Training with the released DeTok
208 |
209 | Train MAR-base for 800 epochs using released tokenizer:
210 | ```bash
211 | project=gen_model_training
212 | exp_name=mar_base_800ep-detok-BB-gamm3.0-m0.7-decoder_tuned
213 | batch_size=16 # global batch size = batch_size x num_nodes x 8 = 1024
214 | num_nodes=8
215 | epochs=800
216 |
217 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
218 | main_diffusion.py \
219 | --project $project --exp_name $exp_name --auto_resume \
220 | --batch_size $batch_size --epochs $epochs --use_aligned_schedule \
221 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \
222 | --stats_key detok-BB-gamm3.0-m0.7 --stats_cache_path released_model/stats.pkl \
223 | --load_tokenizer_from released_model/detok-BB-gamm3.0-m0.7-decoder_tuned.pth \
224 | --model MAR_base --no_dropout_in_mlp \
225 | --diffloss_d 6 --diffloss_w 1024 \
226 | --num_sampling_steps 100 --cfg 3.9 \
227 | --cfg_list 3.0 3.5 3.7 3.8 3.9 4.0 4.1 4.3 4.5 \
228 | --online_eval --vis_freq 80 --eval_bsz 256 \
229 | --data_path ./data/imagenet/train \
230 | --entity $YOUR_WANDB_ENTITY --enable_wandb
231 | ```
232 |
233 | Train MAR-large for 800 epochs:
234 | ```bash
235 | project=gen_model_training
236 | exp_name=mar_large_800ep-detok-BB-gamm3.0-m0.7-decoder_tuned
237 | batch_size=16
238 | num_nodes=8
239 | epochs=800
240 |
241 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \
242 | main_diffusion.py \
243 | --project $project --exp_name $exp_name --auto_resume \
244 | --batch_size $batch_size --epochs $epochs --use_aligned_schedule \
245 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \
246 | --stats_key detok-BB-gamm3.0-m0.7 --stats_cache_path released_model/stats.pkl \
247 | --load_tokenizer_from released_model/detok-BB-gamm3.0-m0.7-decoder_tuned.pth \
248 | --model MAR_large --no_dropout_in_mlp \
249 | --diffloss_d 8 --diffloss_w 1280 \
250 | --num_sampling_steps 100 --cfg 3.4 \
251 | --cfg_list 3.0 3.2 3.3 3.4 3.5 3.6 3.7 3.8 3.9 \
252 | --online_eval --vis_freq 80 --eval_bsz 256 \
253 | --data_path ./data/imagenet/train \
254 | --entity $YOUR_WANDB_ENTITY --enable_wandb
255 | ```
256 |
257 | ## Evaluation (ImageNet 256x256)
258 |
259 | ### Evaluate pretrained MAR-Base
260 | ```bash
261 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
262 | main_diffusion.py \
263 | --project mar_eval --exp_name mar_base_ep800 --auto_resume \
264 | --batch_size 64 --epochs 800 --use_aligned_schedule \
265 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \
266 | --stats_key detok-BB-gamm3.0-m0.7 --stats_cache_path released_model/stats.pkl \
267 | --load_tokenizer_from released_model/detok-BB-gamm3.0-m0.7-decoder_tuned.pth \
268 | --model MAR_base --no_dropout_in_mlp \
269 | --diffloss_d 6 --diffloss_w 1024 \
270 | --load_from released_model/mar_base.pth \
271 | --num_sampling_steps 100 --eval_bsz 256 --num_images 50000 --num_iter 256 --evaluate \
272 | --cfg 3.9 \
273 | --data_path ./data/imagenet/train
274 | ```
275 |
276 | ### Evaluate pretrained MAR-Large
277 | ```bash
278 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \
279 | main_diffusion.py \
280 | --project mar_eval --exp_name mar_large_ep800 --auto_resume \
281 | --batch_size 64 --epochs 800 --use_aligned_schedule \
282 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \
283 | --stats_key detok-BB-gamm3.0-m0.7 --stats_cache_path released_model/stats.pkl \
284 | --load_tokenizer_from released_model/detok-BB-gamm3.0-m0.7-decoder_tuned.pth \
285 | --model MAR_large --no_dropout_in_mlp \
286 | --diffloss_d 8 --diffloss_w 1280 \
287 | --load_from released_model/mar_large.pth \
288 | --num_sampling_steps 100 --eval_bsz 256 --num_images 50000 --num_iter 256 --evaluate \
289 | --cfg 3.4 \
290 | --data_path ./data/imagenet/train
291 | ```
292 |
293 | FID-50k with CFG:
294 | |cfg| MAR Model | FID-50K | Inception Score | #params |
295 | |---|------------------------------|---------|-----------------|---------|
296 | |3.9| MAR-B + l-DeTok | 1.61 | 289.7 | 208M |
297 | |3.9| MAR-B + l-DeTok (decoder_ft) | 1.55 | 291.0 | 208M |
298 | |3.4| MAR-L + l-DeTok | 1.43 | 303.5 | 479M |
299 | |3.4| MAR-L + l-DeTok (decoder_ft) | 1.32 | 304.1 | 479M |
300 |
301 |
302 | **Note:**
303 | - Set `--cfg 1.0 --temperature 0.95` to evaluate without CFG for MAR-base and `--cfg 1.0 --temperature 0.97` for MAR-large
304 | - Generation speed can be significantly increased by reducing the number of autoregressive iterations (e.g., `--num_iter 64`)
305 |
306 | ## Acknowledgements
307 | We thank the authors of [MAE](https://github.com/facebookresearch/mae), [MAGE](https://github.com/LTH14/mage), [DiT](https://github.com/facebookresearch/DiT), [LightningDiT](https://github.com/hustvl/LightningDiT), [MAETok](https://github.com/Hhhhhhao/continuous_tokenizer) and [MAR](https://github.com/LTH14/mar) for their foundational work.
308 |
309 | Our codebase builds upon several excellent open-source projects, including [MAR](https://github.com/LTH14/mar) and [1d-tokenizer](https://github.com/bytedance/1d-tokenizer). We are grateful to the communities behind them.
310 |
311 | ## Contact
312 | This codebase has been cleaned up but has not undergone extensive testing. If you encounter any issues or have questions, please open a GitHub issue. We appreciate your feedback!
--------------------------------------------------------------------------------
/assets/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiawei-Yang/DeTok/be99bae6f3539c34d8711addd16f2926295500d6/assets/demo.png
--------------------------------------------------------------------------------
/assets/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jiawei-Yang/DeTok/be99bae6f3539c34d8711addd16f2926295500d6/assets/method.png
--------------------------------------------------------------------------------
/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | # Adopted from DiT, which is modified from OpenAI's diffusion repos
2 | # DiT: https://github.com/facebookresearch/DiT/diffusion
3 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
4 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
5 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
6 |
7 | import logging
8 |
9 | from . import gaussian_diffusion as gd
10 | from .respace import SpacedDiffusion, space_timesteps
11 |
12 | logger = logging.getLogger("DeTok")
13 |
14 |
def create_diffusion(
    timestep_respacing,
    noise_schedule="linear",
    use_kl=False,
    sigma_small=False,
    predict_xstart=False,
    learn_sigma=True,
    rescale_learned_sigmas=False,
    diffusion_steps=1000,
    channel_last=False,
) -> SpacedDiffusion:
    """
    Build a ``SpacedDiffusion`` from a handful of high-level switches.

    :param timestep_respacing: respacing spec forwarded to ``space_timesteps``;
                               ``None`` or ``""`` falls back to
                               ``[diffusion_steps]``.
    :param noise_schedule: schedule name passed to ``get_named_beta_schedule``.
    :param use_kl: select the ``RESCALED_KL`` loss (takes priority over
                   ``rescale_learned_sigmas``).
    :param sigma_small: with ``learn_sigma=False``, pick ``FIXED_SMALL``
                        instead of ``FIXED_LARGE`` variance.
    :param predict_xstart: use ``START_X`` as the model mean type instead of
                           ``EPSILON``.
    :param learn_sigma: use the ``LEARNED_RANGE`` variance type.
    :param rescale_learned_sigmas: select ``RESCALED_MSE`` when ``use_kl`` is
                                   off; otherwise plain ``MSE`` is used.
    :param diffusion_steps: number of steps of the base schedule.
    :param channel_last: forwarded unchanged to ``SpacedDiffusion``.
    :return: the configured ``SpacedDiffusion`` instance.
    """
    betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)

    # No respacing requested: keep a single section spanning every step.
    if timestep_respacing is None or timestep_respacing == "":
        timestep_respacing = [diffusion_steps]

    # Loss: KL wins over rescaled MSE, which wins over plain MSE.
    if use_kl:
        loss_type = gd.LossType.RESCALED_KL
    else:
        loss_type = gd.LossType.RESCALED_MSE if rescale_learned_sigmas else gd.LossType.MSE

    model_mean_type = gd.ModelMeanType.START_X if predict_xstart else gd.ModelMeanType.EPSILON

    # Variance: learned range, otherwise one of the two fixed variants.
    if learn_sigma:
        model_var_type = gd.ModelVarType.LEARNED_RANGE
    elif sigma_small:
        model_var_type = gd.ModelVarType.FIXED_SMALL
    else:
        model_var_type = gd.ModelVarType.FIXED_LARGE

    spaced = SpacedDiffusion(
        use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
        betas=betas,
        model_mean_type=model_mean_type,
        model_var_type=model_var_type,
        loss_type=loss_type,
        channel_last=channel_last,
    )
    logger.info(f"Created diffusion with timestep respacing {timestep_respacing}")
    return spaced
58 |
--------------------------------------------------------------------------------
/diffusion/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 |
def normal_kl(mean1, logvar1, mean2, logvar2):
    """
    Compute the KL divergence between two gaussians.

    Inputs broadcast against each other, so batches may be compared with
    scalars. At least one of the four arguments has to be a Tensor; it is
    used as the reference when converting scalar log-variances.

    :param mean1: mean of the first gaussian.
    :param logvar1: log-variance of the first gaussian.
    :param mean2: mean of the second gaussian.
    :param logvar2: log-variance of the second gaussian.
    :return: the element-wise KL divergence.
    """
    reference = next(
        (arg for arg in (mean1, logvar1, mean2, logvar2) if isinstance(arg, th.Tensor)),
        None,
    )
    assert reference is not None, "at least one argument must be a Tensor"

    # th.exp() needs Tensor inputs, so promote scalar log-variances using the
    # reference tensor as the conversion target.
    if not isinstance(logvar1, th.Tensor):
        logvar1 = th.tensor(logvar1).to(reference)
    if not isinstance(logvar2, th.Tensor):
        logvar2 = th.tensor(logvar2).to(reference)

    variance_ratio = th.exp(logvar1 - logvar2)
    scaled_sq_diff = ((mean1 - mean2) ** 2) * th.exp(-logvar2)
    return 0.5 * (-1.0 + logvar2 - logvar1 + variance_ratio + scaled_sq_diff)
32 |
33 |
def approx_standard_normal_cdf(x):
    """
    Fast tanh-based approximation of the standard normal cumulative
    distribution function, evaluated element-wise on ``x``.
    """
    coeff = np.sqrt(2.0 / np.pi)
    cubic = x + 0.044715 * th.pow(x, 3)
    return 0.5 * (1.0 + th.tanh(coeff * cubic))
40 |
41 |
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
    """
    Log-likelihood of a Gaussian distribution discretized onto a given image.

    :param x: the target images. It is assumed that this was uint8 values,
              rescaled to the range [-1, 1].
    :param means: the Gaussian mean Tensor.
    :param log_scales: the Gaussian log stddev Tensor.
    :return: a tensor like x of log probabilities (in nats).
    """
    assert x.shape == means.shape == log_scales.shape

    residual = x - means
    inverse_std = th.exp(-log_scales)

    # CDF evaluated at the upper and lower edges of the bin around x.
    upper_cdf = approx_standard_normal_cdf(inverse_std * (residual + 1.0 / 255.0))
    lower_cdf = approx_standard_normal_cdf(inverse_std * (residual - 1.0 / 255.0))

    # Clamp before the log so vanishing probabilities do not produce -inf.
    log_below_upper = th.log(upper_cdf.clamp(min=1e-12))
    log_above_lower = th.log((1.0 - lower_cdf).clamp(min=1e-12))
    log_inside_bin = th.log((upper_cdf - lower_cdf).clamp(min=1e-12))

    # Edge values use the open-ended tail mass; everything else uses the bin mass.
    log_probs = th.where(
        x < -0.999,
        log_below_upper,
        th.where(x > 0.999, log_above_lower, log_inside_bin),
    )
    assert log_probs.shape == x.shape
    return log_probs
69 |
--------------------------------------------------------------------------------
/diffusion/respace.py:
--------------------------------------------------------------------------------
1 | # Modified from OpenAI's diffusion repos
2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 | from .gaussian_diffusion import GaussianDiffusion
10 |
11 |
def space_timesteps(num_timesteps, section_counts):
    """
    Create a list of timesteps to use from an original diffusion process,
    given the number of timesteps we want to take from equally-sized portions
    of the original process.
    For example, if there's 300 timesteps and the section counts are [10,15,20]
    then the first 100 timesteps are strided to be 10 timesteps, the second 100
    are strided to be 15 timesteps, and the final 100 are strided to be 20.
    If the stride is a string starting with "ddim", then the fixed striding
    from the DDIM paper is used, and only one section is allowed.
    :param num_timesteps: the number of diffusion steps in the original
                          process to divide up.
    :param section_counts: either a list of numbers, or a string containing
                           comma-separated numbers, indicating the step count
                           per section. A single int is treated like its
                           string form. As a special case, use "ddimN" where N
                           is a number of steps to use the striding from the
                           DDIM paper.
    :return: a set of diffusion steps from the original process to use.
    :raises ValueError: if no integer stride yields exactly N steps for
                        "ddimN", or if a section is asked for more steps than
                        it contains.
    """
    if isinstance(section_counts, int):
        section_counts = str(section_counts)
    if isinstance(section_counts, str):
        if section_counts.startswith("ddim"):
            desired_count = int(section_counts[len("ddim") :])
            # Look for an integer stride that produces exactly desired_count steps.
            for stride in range(1, num_timesteps):
                strided = range(0, num_timesteps, stride)
                if len(strided) == desired_count:
                    return set(strided)
            # Report the step count that was requested, not the schedule length.
            raise ValueError(f"cannot create exactly {desired_count} steps with an integer stride")
        section_counts = [int(x) for x in section_counts.split(",")]

    # Split the schedule into near-equal sections; the first `extra` sections
    # absorb the remainder, one step each.
    size_per = num_timesteps // len(section_counts)
    extra = num_timesteps % len(section_counts)
    start_idx = 0
    all_steps = []
    for i, section_count in enumerate(section_counts):
        size = size_per + (1 if i < extra else 0)
        if size < section_count:
            raise ValueError(f"cannot divide section of {size} steps into {section_count}")
        if section_count <= 1:
            frac_stride = 1
        else:
            # Fractional stride so the first and last step of the section are both hit.
            frac_stride = (size - 1) / (section_count - 1)
        cur_idx = 0.0
        for _ in range(section_count):
            all_steps.append(start_idx + round(cur_idx))
            cur_idx += frac_stride
        start_idx += size
    return set(all_steps)
61 |
62 |
class SpacedDiffusion(GaussianDiffusion):
    """
    A diffusion process which can skip steps in a base diffusion process.
    :param use_timesteps: a collection (sequence or set) of timesteps from the
                          original diffusion process to retain.
    :param kwargs: the kwargs to create the base diffusion process.
    """

    def __init__(self, use_timesteps, **kwargs):
        self.use_timesteps = set(use_timesteps)
        self.timestep_map = []
        self.original_num_steps = len(kwargs["betas"])

        # Build the full-length process once to read its cumulative alphas,
        # then derive betas for the retained steps only.
        reference = GaussianDiffusion(**kwargs)  # pylint: disable=missing-kwoa
        prev_cumprod = 1.0
        respaced_betas = []
        for step, cumprod in enumerate(reference.alphas_cumprod):
            if step not in self.use_timesteps:
                continue
            respaced_betas.append(1 - cumprod / prev_cumprod)
            prev_cumprod = cumprod
            self.timestep_map.append(step)
        kwargs["betas"] = np.array(respaced_betas)
        super().__init__(**kwargs)

    def p_mean_variance(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        wrapped = self._wrap_model(model)
        return super().p_mean_variance(wrapped, *args, **kwargs)

    def training_losses(self, model, *args, **kwargs):  # pylint: disable=signature-differs
        wrapped = self._wrap_model(model)
        return super().training_losses(wrapped, *args, **kwargs)

    def condition_mean(self, cond_fn, *args, **kwargs):
        wrapped = self._wrap_model(cond_fn)
        return super().condition_mean(wrapped, *args, **kwargs)

    def condition_score(self, cond_fn, *args, **kwargs):
        wrapped = self._wrap_model(cond_fn)
        return super().condition_score(wrapped, *args, **kwargs)

    def _wrap_model(self, model):
        # Wrap at most once: an already wrapped callable is passed through.
        if not isinstance(model, _WrappedModel):
            model = _WrappedModel(model, self.timestep_map, self.original_num_steps)
        return model

    def _scale_timesteps(self, t):
        # Timesteps are returned unchanged; the index remapping to original
        # steps happens inside _WrappedModel.
        return t
106 |
107 |
108 | class _WrappedModel:
109 | def __init__(self, model, timestep_map, original_num_steps):
110 | self.model = model
111 | self.timestep_map = timestep_map
112 | # self.rescale_timesteps = rescale_timesteps
113 | self.original_num_steps = original_num_steps
114 |
115 | def __call__(self, x, ts, **kwargs):
116 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
117 | new_ts = map_tensor[ts]
118 | # if self.rescale_timesteps:
119 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
120 | return self.model(x, new_ts, **kwargs)
121 |
--------------------------------------------------------------------------------
/main_diffusion.py:
--------------------------------------------------------------------------------
1 | """
2 | DeTok: Generation model training script.
3 | """
4 |
5 | import argparse
6 | import datetime
7 | import logging
8 | import sys
9 | import time
10 |
11 | import torch
12 | import torch.distributed
13 |
14 | import models
15 | import utils.distributed as distributed
16 | from utils.builders import create_generation_model, create_optimizer_and_scaler, create_train_dataloader
17 | from utils.misc import ckpt_resume, save_checkpoint
18 | from utils.train_utils import (
19 | collect_tokenizer_stats,
20 | evaluate_generator,
21 | setup,
22 | train_one_epoch_generator,
23 | visualize_generator,
24 | visualize_tokenizer,
25 | )
26 |
# performance optimizations
# Applied at import time: allow TF32 for CUDA matmuls and cuDNN ops, turn on
# cuDNN benchmark mode, and leave cuDNN's deterministic mode off.
# NOTE(review): these settings favour speed over bit-exact reproducibility —
# confirm that is also intended for the evaluation-only path.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
32 |
33 | logger = logging.getLogger("DeTok")
34 |
35 |
def main(args: argparse.Namespace) -> int:
    """Train a generation model, or only evaluate it when ``args.evaluate`` is set.

    The function builds the data loader, models and optimizer, optionally
    collects tokenizer statistics, resumes from a checkpoint, and then either
    evaluates once per requested cfg value or runs the epoch loop followed by a
    cfg search and the final evaluations.

    Args:
        args: parsed command-line options (see ``get_args_parser``).

    Returns:
        ``0`` on completion.
    """
    wandb_logger = setup(args)
    data_loader_train = create_train_dataloader(args)

    # initialize models
    model, tokenizer, ema_model = create_generation_model(args)
    optimizer, loss_scaler = create_optimizer_and_scaler(args, model)
    model_wo_ddp = model

    # handle token caching or tokenizer statistics collection
    if args.collect_tokenizer_stats:
        tmp_data_loader = create_train_dataloader(
            args, should_flip=False, batch_size=args.tokenizer_bsz,
            return_path=True, drop_last=False
        )
        # (B, C, H, W) for chan_dim=2 or (B, seq_len, C) for chan_dim=1
        chan_dim = 2 if args.tokenizer in models.DeTok_models else 1

        # collect stats
        result_dict = collect_tokenizer_stats(
            tokenizer, tmp_data_loader, chan_dim=chan_dim,
            stats_dict_key=args.stats_key,
            stats_dict_path=args.stats_cache_path,
            overwrite_stats=args.overwrite_stats,
        )
        # update tokenizer with computed statistics
        mean, std = result_dict["channel"]
        if mean.ndim > 0:
            # only the first half of the channel statistics is kept
            n_chans = len(mean) // 2
            mean, std = mean[:n_chans], std[:n_chans]
        tokenizer.reset_stats(mean, std)

        del tmp_data_loader
        data_dict = next(iter(data_loader_train))
        visualize_tokenizer(args, tokenizer, ema_model=None, data_dict=data_dict)

    # setup distributed training
    if distributed.is_enabled():
        model = torch.nn.parallel.DistributedDataParallel(model)
        model_wo_ddp = model.module

    # resume from checkpoint if needed
    logger.info("Auto-resume enabled")
    ckpt_resume(args, model_wo_ddp, optimizer, loss_scaler, ema_model)

    # evaluation-only mode
    if args.evaluate:
        torch.cuda.empty_cache()
        cfg_list = args.cfg_list if args.cfg_list is not None else [args.cfg]
        for cfg in cfg_list:
            evaluate_generator(
                args,
                model_wo_ddp,
                ema_model,
                tokenizer,
                epoch=args.start_epoch,
                wandb_logger=wandb_logger,
                cfg=cfg,
                use_ema=True,  # always use ema model for evaluation
                num_images=args.num_images,
            )
        return 0

    # training loop
    logger.info(f"Start training from {args.start_epoch} to {args.epochs}")
    start_time = time.time()

    for epoch in range(args.start_epoch, args.epochs):
        train_one_epoch_generator(
            args, model, data_loader_train, optimizer, loss_scaler, wandb_logger,
            epoch, ema_model, tokenizer
        )

        # progress logging
        elapsed_t = time.time() - start_time + args.last_elapsed_time
        eta = elapsed_t / (epoch + 1) * (args.epochs - epoch - 1)
        logger.info(
            f"[{epoch}/{args.epochs}] "
            f"Accumulated elapsed time: {str(datetime.timedelta(seconds=int(elapsed_t)))}, "
            f"ETA: {str(datetime.timedelta(seconds=int(eta)))}"
        )

        # checkpointing
        should_save = (
            (epoch + 1) % args.save_freq == 0  # save every n epochs
            or (epoch + 1) == args.epochs  # save at the end of training
        )

        if should_save:
            save_checkpoint(args, epoch, model_wo_ddp, optimizer, loss_scaler, ema_model, elapsed_t)
            # a barrier needs an initialized process group; skip it for
            # single-process runs, matching the DDP guard above
            if distributed.is_enabled():
                torch.distributed.barrier()

        # periodic visualization
        if (epoch + 1) % args.vis_freq == 0:
            visualize_generator(args, model_wo_ddp, ema_model, tokenizer, epoch + 1)

        # online evaluation
        if args.online_eval and (epoch + 1) % args.eval_freq == 0:
            torch.cuda.empty_cache()
            evaluate_generator(
                args, model_wo_ddp, ema_model, tokenizer, epoch + 1, wandb_logger,
                use_ema=True, num_images=args.num_images_for_eval_and_search, cfg=args.cfg
            )

    # final evaluation
    total_time = int(time.time() - start_time + args.last_elapsed_time)
    logger.info(f"Training time {str(datetime.timedelta(seconds=total_time))}")

    # determine cfg values for evaluation
    cfg_list = args.cfg_list or [args.cfg]  # use the cfg from the args if not provided
    best_cfg = cfg_list[0]

    if len(cfg_list) > 1:
        # search the best cfg value using a reduced number of images
        fid_dict = {}
        for cfg in cfg_list:
            fid_dict[cfg] = evaluate_generator(
                args, model_wo_ddp, ema_model, tokenizer, args.epochs + 1, wandb_logger,
                use_ema=True, cfg=cfg, num_images=args.num_images_for_eval_and_search
            )
        # find best cfg value (lowest FID; the first candidate wins ties)
        if distributed.is_main_process():
            best_cfg = min(cfg_list, key=lambda candidate: fid_dict[candidate]["fid"])
            best_fid = fid_dict[best_cfg]["fid"]
            logger.info(f"Best FID: {best_fid}, Best cfg: {best_cfg}")

        # broadcast best_cfg from rank 0 to all ranks
        if distributed.is_enabled():
            # float64 keeps the Python float exact across the round trip
            best_cfg_tensor = torch.tensor([best_cfg], dtype=torch.float64, device="cuda")
            torch.distributed.broadcast(best_cfg_tensor, src=0)
            best_cfg = best_cfg_tensor.item()
            torch.distributed.barrier()

    # final comprehensive evaluation with best cfg
    args.num_iter = 128 if args.tokenizer == "maetok-b-128" else 256
    evaluate_generator(
        args, model_wo_ddp, ema_model, tokenizer, args.epochs + 1, wandb_logger,
        use_ema=True, cfg=best_cfg, num_images=args.num_images
    )

    # additional evaluation with cfg=1.0
    evaluate_generator(
        args, model_wo_ddp, ema_model, tokenizer, args.epochs + 1, wandb_logger,
        use_ema=True, cfg=1.0, num_images=args.num_images
    )

    return 0
188 |
189 |
def get_args_parser():
    """Build the command-line parser for generation model training.

    Returns:
        argparse.ArgumentParser: parser created with ``add_help=False``.
    """
    parser = argparse.ArgumentParser("Generation model training", add_help=False)
    add = parser.add_argument

    # basic training parameters
    add("--start_epoch", type=int, default=0)
    add("--epochs", type=int, default=400)
    add("--batch_size", type=int, default=64, help="Batch size per GPU for training")

    # model parameters
    add("--model", type=str, default="MAR_base")
    add("--order", type=str, default="raster")
    add("--patch_size", type=int, default=1)
    add("--no_dropout_in_mlp", action="store_true")
    add("--qk_norm", action="store_true")
    add("--force_one_d_seq", type=int, default=0, help="1d tokens, e.g., 128 for MAETok")
    add("--legacy_mode", action="store_true")

    # tokenizer parameters
    add("--img_size", type=int, default=256)
    add("--tokenizer", type=str, default=None)
    add("--token_channels", type=int, default=16)
    add("--tokenizer_patch_size", type=int, default=16)
    add("--use_ema_tokenizer", action="store_true")

    # tokenizer cache parameters
    add("--collect_tokenizer_stats", action="store_true")
    add("--tokenizer_bsz", type=int, default=256)
    add("--cached_path", type=str, default="data/imagenet_tokens/")
    add("--stats_key", type=str, default=None)
    add("--overwrite_stats", action="store_true")
    add("--stats_cache_path", type=str, default="work_dirs/stats.pkl")

    # logging parameters
    add("--output_dir", default="./work_dirs")
    add("--print_freq", type=int, default=100)
    add("--eval_freq", type=int, default=40)
    add("--vis_freq", type=int, default=10)
    add("--save_freq", type=int, default=1)
    add("--last_elapsed_time", type=float, default=0.0)

    # checkpoint parameters
    add("--auto_resume", action="store_true")
    add("--resume_from", default=None, help="resume model weights and optimizer state")
    add("--load_from", type=str, default=None, help="load from pretrained model")
    add("--load_tokenizer_from", type=str, default=None, help="load from pretrained tokenizer")
    add("--keep_n_ckpts", type=int, default=1, help="keep the last n checkpoints")
    add("--milestone_interval", type=int, default=100, help="keep checkpoints every n epochs")

    # evaluation parameters
    add("--num_images_for_eval_and_search", type=int, default=10000)
    add("--num_images", type=int, default=50000)
    add("--online_eval", action="store_true")
    add("--fid_stats_path", type=str, default="data/fid_stats/adm_in256_stats.npz")
    add("--keep_eval_folder", action="store_true")
    add("--evaluate", action="store_true")
    add("--eval_bsz", type=int, default=256)

    # optimization parameters
    add("--lr", type=float, default=None)
    add("--blr", type=float, default=1e-4)
    add("--min_lr", type=float, default=1e-6)
    add("--lr_sched", type=str, default="constant", choices=["constant", "cosine"])
    add("--warmup_rate", type=float, default=0.25, help="warmup_ep = warmup_rate * total_ep")
    add("--ema_rate", type=float, default=0.9999)
    add("--weight_decay", type=float, default=0.02)
    add("--grad_clip", type=float, default=3.0)
    add("--grad_checkpointing", action="store_true")
    add("--beta1", type=float, default=0.9)
    add("--beta2", type=float, default=0.95)
    add("--use_aligned_schedule", action="store_true")

    # generation parameters
    add("--num_iter", type=int, default=64, help="number of autoregressive steps for MAR")
    add("--noise_schedule", type=str, default="cosine", help="noise schedule for diffusion")
    add("--cfg", type=float, default=4.0, help="cfg value for diffusion")
    add("--cfg_schedule", type=str, default="linear", help="cfg schedule for diffusion")
    add("--cfg_list", type=float, default=None, nargs="+", help="cfg list for search")

    # mar parameters
    add("--label_drop_prob", type=float, default=0.1)
    add("--mask_ratio_min", type=float, default=0.7)
    add("--attn_dropout", type=float, default=0.1)
    add("--proj_dropout", type=float, default=0.1)
    add("--buffer_size", type=int, default=64)

    # diffusion loss parameters
    add("--diffloss_d", type=int, default=3)
    add("--diffloss_w", type=int, default=1024)
    add("--num_sampling_steps", type=str, default="100")
    add("--diffusion_batch_mul", type=int, default=4)
    add("--temperature", type=float, default=1.0)

    # dataset parameters
    add("--use_cached_tokens", action="store_true")
    add("--data_path", type=str, default="./data/imagenet/train")
    add("--num_classes", type=int, default=1000)
    add("--class_of_interest", type=int, nargs="+", default=[207, 360, 387, 974, 88, 979, 417, 279])
    add(
        "--force_class_of_interest",
        action="store_true",
        help="generate images of only the class of interest for args.num_images images",
    )
    add("--num_workers", type=int, default=10)
    # --pin_mem / --no_pin_mem share one destination that defaults to True
    add("--pin_mem", action="store_true")
    add("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)

    # system parameters
    add("--seed", type=int, default=1)

    # wandb parameters
    add("--project", type=str, default="lDeTok")
    add("--entity", type=str, default="YOUR_WANDB_ENTITY")
    add("--exp_name", type=str, default=None)
    add("--enable_wandb", action="store_true")

    return parser
304 |
305 |
if __name__ == "__main__":
    # parse CLI options, run, and propagate main()'s return value as exit status
    args = get_args_parser().parse_args()
    sys.exit(main(args))
310 |
--------------------------------------------------------------------------------
/main_reconstruction.py:
--------------------------------------------------------------------------------
1 | """
2 | DeTok: Reconstruction model training script.
3 | """
4 |
5 | import argparse
6 | import datetime
7 | import logging
8 | import sys
9 | import time
10 |
11 | import torch
12 | import torch.distributed
13 |
14 | import utils.distributed as distributed
15 | from utils.builders import (
16 | create_loss_module,
17 | create_optimizer_and_scaler,
18 | create_reconstruction_model,
19 | create_train_dataloader,
20 | create_val_dataloader,
21 | create_vis_dataloader,
22 | )
23 | from utils.misc import ckpt_resume, save_checkpoint
24 | from utils.train_utils import evaluate_tokenizer, setup, train_one_epoch_tokenizer, visualize_tokenizer
25 |
# performance optimizations: non-deterministic, autotuned cuDNN kernels ...
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True
# ... and TF32 for both cuDNN and CUDA matmul
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True

# logger shared under the "DeTok" name
logger = logging.getLogger("DeTok")
33 |
34 |
def main(args: argparse.Namespace) -> int:
    """Train a reconstruction (tokenizer) model, or only visualize / evaluate it.

    The function builds the data loaders, model, loss module and both
    optimizers, resumes from a checkpoint, renders an initial visualization and
    then, depending on ``args.vis_only`` / ``args.evaluate``, returns early or
    runs the epoch loop followed by a final evaluation with and without EMA.

    Args:
        args: parsed command-line options (see ``get_args_parser``).

    Returns:
        ``0`` on completion.
    """
    wandb_logger = setup(args)

    # initialize data loaders
    data_loader_train = create_train_dataloader(args)
    data_loader_val = create_val_dataloader(args)
    data_loader_vis = create_vis_dataloader(args)
    vis_iterator = iter(data_loader_vis)

    def next_vis_batch():
        """Return the next visualization batch, restarting the loader when exhausted."""
        nonlocal vis_iterator
        try:
            return next(vis_iterator)
        except StopIteration:
            # the iterator is consumed once per visualization; start over
            # instead of aborting training when the loader runs dry
            vis_iterator = iter(data_loader_vis)
            return next(vis_iterator)

    # initialize models and optimizers
    model, ema_model = create_reconstruction_model(args)
    if args.train_decoder_only and hasattr(model, "freeze_everything_but_decoder"):
        model.freeze_everything_but_decoder()

    optimizer, loss_scaler = create_optimizer_and_scaler(args, model)
    loss_fn = create_loss_module(args)
    discriminator_optimizer, discriminator_loss_scaler = create_optimizer_and_scaler(args, loss_fn)

    # setup distributed training
    if distributed.is_enabled():
        model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
        loss_fn = torch.nn.parallel.DistributedDataParallel(loss_fn, find_unused_parameters=True)

    # get models without DDP wrapper
    model_wo_ddp = model.module if hasattr(model, "module") else model
    loss_module_wo_ddp = loss_fn.module if hasattr(loss_fn, "module") else loss_fn

    # resume from checkpoint if needed
    ckpt_resume(
        args, model_wo_ddp, optimizer, loss_scaler, ema_model,
        loss_module_wo_ddp, discriminator_optimizer, discriminator_loss_scaler
    )

    # initial visualization
    visualize_tokenizer(args, model_wo_ddp, ema_model, next_vis_batch(), args.start_epoch)

    if args.vis_only:
        return 0

    # evaluation-only mode
    if args.evaluate:
        torch.cuda.empty_cache()
        for use_ema in [False, True]:
            evaluate_tokenizer(
                args, model_wo_ddp, ema_model, data_loader_val, args.start_epoch, wandb_logger, use_ema
            )
        return 0

    # training loop
    logger.info(f"Start training from {args.start_epoch} to {args.epochs}")
    start_time = time.time()

    for epoch in range(args.start_epoch, args.epochs):
        train_one_epoch_tokenizer(
            args, model, data_loader_train, optimizer, loss_scaler, wandb_logger, epoch,
            ema_model, loss_fn, discriminator_optimizer, discriminator_loss_scaler
        )

        # progress logging
        elapsed_t = time.time() - start_time + args.last_elapsed_time
        eta = elapsed_t / (epoch + 1) * (args.epochs - epoch - 1)
        logger.info(
            f"[{epoch}/{args.epochs}] "
            f"Accumulated elapsed time: {str(datetime.timedelta(seconds=int(elapsed_t)))}, "
            f"ETA: {str(datetime.timedelta(seconds=int(eta)))}"
        )

        # checkpointing
        should_save = (
            (epoch + 1) % args.save_freq == 0  # save every n epochs
            or (epoch + 1) == args.epochs  # save at the end of training
        )

        if should_save:
            save_checkpoint(
                args, epoch, model_wo_ddp, optimizer, loss_scaler, ema_model, elapsed_t,
                loss_module_wo_ddp, discriminator_optimizer, discriminator_loss_scaler
            )
            # a barrier needs an initialized process group; skip it for
            # single-process runs, matching the DDP guard above
            if distributed.is_enabled():
                torch.distributed.barrier()

        # periodic visualization
        if (epoch + 1) % args.vis_freq == 0:
            visualize_tokenizer(args, model_wo_ddp, ema_model, next_vis_batch(), epoch)

        # online evaluation (skipped on the last epoch; the final evaluation covers it)
        if (args.online_eval and (epoch + 1) % args.eval_freq == 0 and (epoch + 1) != args.epochs):
            torch.cuda.empty_cache()
            for use_ema in [False, True]:
                evaluate_tokenizer(
                    args, model_wo_ddp, ema_model, data_loader_val, epoch + 1, wandb_logger, use_ema
                )

    # final evaluation
    total_time = int(time.time() - start_time + args.last_elapsed_time)
    logger.info(f"Training time {str(datetime.timedelta(seconds=total_time))}")

    for use_ema in [False, True]:
        evaluate_tokenizer(args, model_wo_ddp, ema_model, data_loader_val, args.epochs, wandb_logger, use_ema)

    return 0
136 |
137 |
def get_args_parser():
    """Build the command-line parser for reconstruction model training.

    Returns:
        argparse.ArgumentParser: parser created with ``add_help=False``.
    """
    parser = argparse.ArgumentParser("Reconstruction model training", add_help=False)
    add = parser.add_argument

    # basic training parameters
    add("--start_epoch", type=int, default=0)
    add("--epochs", type=int, default=200)
    add("--batch_size", type=int, default=64, help="Batch size per GPU for training")

    # model parameters
    add("--model", type=str, default="detok_BB")
    add("--token_channels", type=int, default=16)
    add("--img_size", type=int, default=256)
    add("--patch_size", type=int, default=16)

    add("--mask_ratio", type=float, default=0.0)
    add("--gamma", type=float, default=0.0, help="noise standard deviation for training")
    add("--use_additive_noise", action="store_true")

    add("--no_load_ckpt", action="store_true")
    add("--train_decoder_only", action="store_true")
    add("--vis_only", action="store_true")

    # loss parameters
    add("--perceptual_loss", type=str, default="lpips-convnext_s-1.0-0.1")
    add("--perceptual_weight", type=float, default=1.0)
    add("--discriminator_start_epoch", type=int, default=20)
    add("--discriminator_weight", type=float, default=0.5)
    add("--kl_loss_weight", type=float, default=1e-6)

    # logging parameters
    add("--output_dir", default="./work_dirs")
    add("--print_freq", type=int, default=100)
    add("--eval_freq", type=int, default=10)
    add("--vis_freq", type=int, default=5)
    add("--save_freq", type=int, default=1)
    add("--last_elapsed_time", type=float, default=0.0)

    # checkpoint parameters
    add("--auto_resume", action="store_true")
    add("--resume_from", default=None, help="resume model weights and optimizer state")
    add("--load_from", type=str, default=None, help="load from pretrained model")
    add("--keep_n_ckpts", type=int, default=1, help="keep the last n checkpoints")
    add("--milestone_interval", type=int, default=100, help="keep checkpoints every n epochs")

    # evaluation parameters
    add("--num_images", type=int, default=50000, help="Number of images to evaluate on")
    add("--online_eval", action="store_true")
    add("--fid_stats_path", type=str, default="data/fid_stats/val_fid_statistics_file.npz")
    add("--keep_eval_folder", action="store_true")
    add("--evaluate", action="store_true")
    add("--eval_bsz", type=int, default=256)

    # optimization parameters
    add("--lr", type=float, default=None)
    add("--blr", type=float, default=1e-4)
    add("--min_lr", type=float, default=0.0)
    add("--lr_sched", type=str, default="cosine", choices=["constant", "cosine"])
    add("--warmup_rate", type=float, default=0.25)
    add("--ema_rate", type=float, default=0.999)
    add("--weight_decay", type=float, default=1e-4)
    add("--grad_clip", type=float, default=3.0)
    add("--grad_checkpointing", action="store_true", help="Use gradient checkpointing")
    add("--beta1", type=float, default=0.9, help="Beta1 for AdamW optimizer")
    add("--beta2", type=float, default=0.95, help="Beta2 for AdamW optimizer")

    # dataset parameters
    add("--use_cached_tokens", action="store_true")
    add("--data_path", type=str, default="./data/imagenet/train")
    add("--num_classes", type=int, default=1000)
    add("--class_of_interest", type=int, nargs="+", default=[207, 360, 387, 974, 88, 979, 417, 279])
    add("--num_workers", type=int, default=10)
    # --pin_mem / --no_pin_mem share one destination that defaults to True
    add("--pin_mem", action="store_true")
    add("--no_pin_mem", action="store_false", dest="pin_mem")
    parser.set_defaults(pin_mem=True)

    # system parameters
    add("--seed", type=int, default=1)

    # wandb parameters
    add("--project", type=str, default="lDeTok")
    add("--entity", type=str, default="YOUR_WANDB_ENTITY")
    add("--exp_name", type=str, default=None)
    add("--enable_wandb", action="store_true")

    return parser
224 |
225 |
if __name__ == "__main__":
    # parse CLI options, run, and propagate main()'s return value as exit status
    args = get_args_parser().parse_args()
    sys.exit(main(args))
230 |
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .ardiff import ARDiff_models
2 | from .autoencoder import AutoencoderKL, VAE_models
3 | from .detok import DeTok_models
4 | from .dit import DiT_models
5 | from .ema import SimpleEMAModel
6 | from .lightningdit import LightningDiT_models
7 | from .mar import MAR_models
8 | from .sit import SiT_models
9 |
--------------------------------------------------------------------------------
/models/ardiff.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from functools import partial
3 |
4 | import torch
5 | import torch.nn as nn
6 | from torch.utils.checkpoint import checkpoint
7 | from tqdm import tqdm
8 |
9 | from .diffloss import DiffLoss
10 | from .layers import Block, Transformer, modulate
11 | from .model_utils import SIZE_DICT
12 |
13 | logger = logging.getLogger("DeTok")
14 |
15 |
class FinalLayer(nn.Module):
    """Output layer: affine-free LayerNorm followed by condition-driven modulation."""

    def __init__(self, in_features) -> None:
        super().__init__()
        # the norm carries no learnable affine; shift / scale come from the condition
        self.norm = nn.LayerNorm(in_features, eps=1e-6, elementwise_affine=False)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(in_features, 2 * in_features),
        )

    def forward(self, x, condition):
        # one projection yields both modulation terms, split along the last dim
        modulation = self.adaLN_modulation(condition)
        shift, scale = modulation.chunk(2, dim=-1)
        normalized = self.norm(x)
        return modulate(normalized, shift, scale)
28 |
29 |
30 | class ARDiff(nn.Module):
31 | """decoder-only autoregressive diffusion model."""
32 |
def __init__(
    self,
    img_size=256,
    patch_size=1,
    model_size="base",
    tokenizer_patch_size=16,
    token_channels=16,
    label_drop_prob=0.1,
    num_classes=1000,
    # diffloss parameters
    noise_schedule="cosine",
    diffloss_d=3,
    diffloss_w=1024,
    diffusion_batch_mul=4,
    # sampling parameters
    num_sampling_steps=100,
    grad_checkpointing=False,
    force_one_d_seq=False,
    order="raster",
):
    """Build the embeddings, causal transformer and diffusion loss head.

    Args:
        img_size: image side length; with ``tokenizer_patch_size`` and
            ``patch_size`` it determines the token grid side.
        patch_size: side of the token patch grouped into one sequence element.
        model_size: key into ``SIZE_DICT`` selecting layers / heads / width.
        tokenizer_patch_size: divisor applied to ``img_size`` before ``patch_size``.
        token_channels: channels per token; the embedded token dimension is
            ``token_channels * patch_size**2``.
        label_drop_prob: probability used in ``forward`` to drop the class embedding.
        num_classes: number of rows in the class embedding table.
        noise_schedule, diffloss_d, diffloss_w, num_sampling_steps: forwarded to
            ``DiffLoss`` (as ``noise_schedule``, ``depth``, ``width``,
            ``num_sampling_steps``).
        diffusion_batch_mul: repetition factor applied in ``forward_loss``.
        grad_checkpointing: forwarded to the transformer and ``DiffLoss``.
        force_one_d_seq: when truthy, the sequence length becomes
            ``force_one_d_seq + 1`` instead of the grid-derived value.
        order: token ordering name stored for ``forward``.
    """
    super().__init__()

    # --------------------------------------------------------------------------
    # basic configuration
    self.img_size = img_size
    self.patch_size = patch_size
    self.token_channels = token_channels
    self.num_classes = num_classes
    self.label_drop_prob = label_drop_prob
    self.grad_checkpointing = grad_checkpointing
    self.force_one_d_seq = force_one_d_seq
    self.order = order
    self.diffusion_batch_mul = diffusion_batch_mul

    # sequence dimensions
    self.seq_h = self.seq_w = img_size // tokenizer_patch_size // patch_size
    self.seq_len = self.seq_h * self.seq_w + 1  # +1 for BOS token
    self.token_embed_dim = token_channels * patch_size**2

    if force_one_d_seq:
        self.seq_len = force_one_d_seq + 1

    # model architecture configuration
    size_dict = SIZE_DICT[model_size]
    num_layers, num_heads, width = size_dict["layers"], size_dict["heads"], size_dict["width"]

    scale = width**-0.5

    # class and null token embeddings
    self.class_emb = nn.Embedding(self.num_classes, width)
    self.fake_latent = nn.Parameter(scale * torch.randn(1, width))
    self.bos_token = nn.Parameter(torch.zeros(1, 1, width))

    # input and positional embeddings
    self.x_embedder = nn.Linear(self.token_embed_dim, width)
    self.pos_embed = nn.Parameter(scale * torch.randn((1, self.seq_len, width)))
    self.target_pos_embed = nn.Parameter(scale * torch.randn((1, self.seq_len - 1, width)))
    self.timesteps_embeddings = nn.Parameter(scale * torch.randn((1, self.seq_len, width)))

    # training mask for causal attention; kept as a non-persistent buffer so
    # it follows the module across devices (no hard-coded .cuda() at
    # construction time) and stays out of the state dict
    self.register_buffer(
        "train_mask",
        torch.tril(torch.ones(self.seq_len, self.seq_len, dtype=torch.bool)),
        persistent=False,
    )

    # --------------------------------------------------------------------------
    norm_layer = partial(nn.LayerNorm, eps=1e-6)

    self.ln_pre = norm_layer(width)
    self.transformer = Transformer(
        width,
        num_layers,
        num_heads,
        block_fn=partial(Block, use_modulation=True),
        norm_layer=norm_layer,
        force_causal=True,
        grad_checkpointing=self.grad_checkpointing,
    )
    self.final_layer = FinalLayer(width)
    # runs before the diffusion loss head is created, so that head is not re-initialized here
    self.initialize_weights()

    # --------------------------------------------------------------------------
    # Diffusion Loss
    self.diffloss = DiffLoss(
        target_channels=self.token_embed_dim,
        z_channels=width,
        width=diffloss_w,
        depth=diffloss_d,
        num_sampling_steps=num_sampling_steps,
        grad_checkpointing=grad_checkpointing,
        noise_schedule=noise_schedule,
    )
    params_M = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
    logger.info(f"[ARDiff] params: {params_M:.2f}M, {model_size}-{num_layers}-{width}")
127 |
def initialize_weights(self):
    """Initialize embeddings with N(0, 0.02), then apply the per-module init."""
    # draw order is kept fixed so seeded runs produce the same values
    embedding_tensors = (
        self.pos_embed,
        self.bos_token,
        self.target_pos_embed,
        self.timesteps_embeddings,
        self.class_emb.weight,
        self.fake_latent,
    )
    for tensor in embedding_tensors:
        torch.nn.init.normal_(tensor, std=0.02)

    # standard per-module initialization
    self.apply(self._init_weights)
140 |
def _init_weights(self, m):
    """Per-module initializer intended for ``self.apply``.

    Linear layers get Xavier-uniform weights and zero biases; LayerNorm layers
    get unit weights and zero biases. When the visited module is the model
    itself, the last layer of every adaptive modulation branch is zeroed.

    Args:
        m: the module currently visited by ``nn.Module.apply``.
    """
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            nn.init.constant_(m.weight, 1.0)

    # ``apply`` visits children first and the root last, so zeroing once on the
    # root visit leaves the same final weights as zeroing on every visit,
    # without re-filling every block's modulation layer per visited module.
    if m is not self:
        return

    # zero-out adaptive modulation layers
    for block in self.transformer.blocks:
        nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    # zero-out final layer modulation
    nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
    nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
161 |
def patchify(self, x):
    """Rearrange a ``[batch, channels, height, width]`` tensor into patch tokens.

    Returns a ``[batch, num_patches, channels * patch_size**2]`` tensor.
    """
    patch = self.patch_size
    batch, channels, height, width = x.shape
    rows, cols = height // patch, width // patch

    tiles = x.reshape(batch, channels, rows, patch, cols, patch)
    # (n, c, h, p, w, q) -> (n, h, w, c, p, q)
    tiles = tiles.permute(0, 2, 4, 1, 3, 5)
    return tiles.reshape(batch, rows * cols, channels * patch**2)
172 |
def unpatchify(self, x):
    """Rearrange patch tokens back into a ``[batch, channels, height, width]`` tensor."""
    patch, channels = self.patch_size, self.token_channels
    rows, cols = self.seq_h, self.seq_w
    batch = x.shape[0]

    tiles = x.reshape(batch, rows, cols, channels, patch, patch)
    # (n, h, w, c, p, q) -> (n, c, h, p, w, q)
    tiles = tiles.permute(0, 3, 1, 4, 2, 5)
    return tiles.reshape(batch, channels, rows * patch, cols * patch)
184 |
def enable_kv_cache(self):
    """Turn on key/value caching in every transformer block and reset each cache."""
    for layer in self.transformer.blocks:
        attention = layer.attn
        attention.kv_cache = True
        attention.reset_kv_cache()
    logger.info("Enable kv_cache for Transformer blocks")
190 |
def disable_kv_cache(self):
    """Turn off key/value caching in every transformer block and reset each cache."""
    for layer in self.transformer.blocks:
        attention = layer.attn
        attention.kv_cache = False
        attention.reset_kv_cache()
    logger.info("Disable kv_cache for Transformer blocks")
196 |
def get_random_orders(self, x):
    """Draw one random permutation of the non-BOS positions per batch element."""
    num_tokens = self.seq_len - 1
    # argsort of i.i.d. noise gives a random permutation for each row
    scores = torch.randn(x.shape[0], num_tokens, device=x.device)
    return torch.argsort(scores, dim=1)
203 |
def get_raster_orders(self, x):
    """Return the raster (sequential) ordering ``0..seq_len-2`` for every batch element.

    Args:
        x: tensor whose first dimension is the batch size and whose device is
            used for the result.

    Returns:
        Integer tensor of shape ``[batch, seq_len - 1]``.
    """
    batch_size = x.shape[0]
    raster_orders = torch.arange(self.seq_len - 1, device=x.device)
    # a single vectorized repeat replaces stacking one copy per sample and
    # also works for an empty batch
    return raster_orders.unsqueeze(0).repeat(batch_size, 1)
210 |
def shuffle(self, x, orders):
    """Reorder tokens along dim 1 according to ``orders``.

    Args:
        x: tensor of shape ``[batch, seq_len, ...]``.
        orders: integer tensor of shape ``[batch, seq_len]``; row ``b`` lists
            the source position for each output position.

    Returns:
        Tensor with ``out[b, i] = x[b, orders[b, i]]``.
    """
    batch_size, seq_len = x.shape[:2]
    # build the batch index on the data's device so indexing does not mix
    # host and device index tensors
    batch_indices = torch.arange(batch_size, device=x.device).unsqueeze(1).expand(-1, seq_len)
    return x[batch_indices, orders]
217 |
def unshuffle(self, shuffled_x, orders):
    """Invert ``shuffle``: scatter tokens back to their original positions.

    Args:
        shuffled_x: tensor of shape ``[batch, seq_len, ...]``.
        orders: integer tensor of shape ``[batch, seq_len]`` that was used to shuffle.

    Returns:
        Tensor with ``out[b, orders[b, i]] = shuffled_x[b, i]``.
    """
    batch_size, seq_len = shuffled_x.shape[:2]
    # build the batch index on the data's device so indexing does not mix
    # host and device index tensors
    batch_indices = (
        torch.arange(batch_size, device=shuffled_x.device).unsqueeze(1).expand(-1, seq_len)
    )
    unshuffled_x = torch.zeros_like(shuffled_x)
    unshuffled_x[batch_indices, orders] = shuffled_x
    return unshuffled_x
225 |
def forward_transformer(self, x, class_embedding, orders=None):
    """Embed tokens, add positional / condition signals and run the transformer.

    Args:
        x: input tokens fed to ``x_embedder``; a BOS token is prepended along dim 1.
        class_embedding: condition tensor; it is repeated along dim 1 to the
            current sequence length and also passed to the final layer.
        orders: optional per-sample token ordering; when given, the positional
            and target-positional embeddings are shuffled with it.

    Returns:
        Output of the final layer. With the kv cache enabled only the last
        position is processed.
    """
    x = self.x_embedder(x)
    bsz = x.shape[0]

    # add BOS token
    bos_token = self.bos_token.expand(bsz, 1, -1)
    x = torch.cat([bos_token, x], dim=1)
    current_seq_len = x.shape[1]

    # add positional embeddings (the BOS slot keeps its own, unshuffled embedding)
    pos_embed = self.pos_embed.expand(bsz, -1, -1)
    if orders is not None:
        pos_embed = torch.cat([pos_embed[:, :1], self.shuffle(pos_embed[:, 1:], orders)], dim=1)
    x = x + pos_embed[:, :current_seq_len]

    # add target positional embeddings, zero-padded by one slot at the end
    target_pos_embed = self.target_pos_embed.expand(bsz, -1, -1)
    embed_dim = target_pos_embed.shape[-1]
    if orders is not None:
        target_pos_embed = self.shuffle(target_pos_embed, orders)
    # allocate the padding directly on the target device instead of creating
    # it on the host and copying it over
    padding = torch.zeros(bsz, 1, embed_dim, device=x.device)
    target_pos_embed = torch.cat([target_pos_embed, padding], dim=1)
    x = x + target_pos_embed[:, :current_seq_len]

    x = self.ln_pre(x)

    # prepare condition tokens
    condition_token = class_embedding.repeat(1, current_seq_len, 1)
    timestep_embed = self.timesteps_embeddings.expand(bsz, -1, -1)
    condition_token = condition_token + timestep_embed[:, :current_seq_len]

    # handle kv cache for inference: only the newest position is fed forward
    if self.transformer.blocks[0].attn.kv_cache:
        x = x[:, -1:]
        condition_token = condition_token[:, -1:]

    # transformer forward pass
    for block in self.transformer.blocks:
        if self.grad_checkpointing and self.training:
            x = checkpoint(block, x, None, None, condition_token)
        else:
            x = block(x, condition=condition_token)

    x = self.final_layer(x, condition=class_embedding)
    return x
271 |
def forward_loss(self, z, target):
    """compute the diffusion loss over all tokens.

    Both tensors are flattened to one row per token and tiled
    self.diffusion_batch_mul times along dim 0 before being handed to self.diffloss.

    Args:
        z: conditioning vectors, one per token; reshaped to (batch * seq_len, -1).
        target: 3-D tensor of ground-truth tokens (batch, seq_len, channels).

    Returns:
        whatever self.diffloss(z=..., target=...) returns.
    """
    bsz, seq_len, _ = target.shape
    num_tokens = bsz * seq_len
    repeats = self.diffusion_batch_mul
    flat_target = target.reshape(num_tokens, -1).repeat(repeats, 1)
    flat_z = z.reshape(num_tokens, -1).repeat(repeats, 1)
    return self.diffloss(z=flat_z, target=flat_target)
278 |
def forward(self, x, labels):
    """forward pass for training.

    Args:
        x: input latents; patchified with self.patchify unless self.force_one_d_seq is set.
        labels: class labels fed to self.class_emb.

    Returns:
        the loss produced by self.forward_loss.

    Raises:
        NotImplementedError: if self.order is neither "raster" nor "random".
    """
    # get token ordering
    if self.order == "raster":
        orders = self.get_raster_orders(x)
    elif self.order == "random":
        orders = self.get_random_orders(x)
    else:
        raise NotImplementedError(f"Order '{self.order}' not implemented")

    # prepare class embeddings
    class_embedding = self.class_emb(labels)
    if self.training:
        # randomly drop class embedding during training; the mask is drawn on CPU
        # (same RNG stream as before) and then moved to wherever x lives, rather
        # than to a hard-coded CUDA device
        drop_mask = torch.rand(x.shape[0]) < self.label_drop_prob
        drop_mask = drop_mask.unsqueeze(-1).to(device=x.device, dtype=x.dtype)
        class_embedding = drop_mask * self.fake_latent + (1 - drop_mask) * class_embedding
    class_embedding = class_embedding.unsqueeze(1)

    # prepare input tokens
    x = self.patchify(x) if not self.force_one_d_seq else x
    x = self.shuffle(x, orders)
    gt_latents = x.clone().detach()

    # forward pass and loss computation: the last token is never an input,
    # since a BOS token is prepended inside forward_transformer
    z = self.forward_transformer(x[:, :-1], class_embedding, orders=orders)
    return self.forward_loss(z=z, target=gt_latents)
306 |
def sample_tokens(
    self,
    bsz,
    cfg=1.0,
    cfg_schedule="linear",
    labels=None,
    temperature=1.0,
    progress=False,
    kv_cache=False,
):
    """sample tokens autoregressively.

    Args:
        bsz: number of samples to generate.
        cfg: classifier-free guidance scale; 1.0 disables the unconditional branch.
        cfg_schedule: "linear" ramps the scale from 1 towards cfg over the steps,
            "constant" uses cfg at every step.
        labels: class labels for self.class_emb; None uses self.fake_latent instead.
        temperature: forwarded to self.diffloss.sample.
        progress: wrap the step loop in tqdm.
        kv_cache: enable the attention KV cache for the duration of sampling.

    Returns:
        generated tokens restored to their original ordering, unpatchified unless
        self.force_one_d_seq is set.

    Raises:
        NotImplementedError: for an unknown self.order or cfg_schedule.
    """
    use_cfg = cfg != 1.0
    # follow the model's own device instead of assuming CUDA
    device = self.fake_latent.device
    num_steps = self.seq_len - 1
    tokens = torch.zeros(bsz, 0, self.token_embed_dim, device=device)

    indices = list(range(num_steps))
    if progress:
        indices = tqdm(indices)

    # get token ordering (the helpers only read batch size and device from their argument)
    if self.order == "raster":
        orders = self.get_raster_orders(tokens)
    elif self.order == "random":
        orders = self.get_random_orders(tokens)
    else:
        raise NotImplementedError(f"Order '{self.order}' not implemented")

    # class embeddings do not change between steps, so build them once
    uncond_embd = self.fake_latent.repeat(bsz, 1)
    cls_embd = uncond_embd if labels is None else self.class_emb(labels)

    # prepare for classifier-free guidance: conditional half first, unconditional half second
    model_orders = orders
    if use_cfg:
        model_orders = torch.cat([orders, orders], dim=0)
        cls_embd = torch.cat([cls_embd, uncond_embd], dim=0)
    cls_embd = cls_embd.unsqueeze(1)

    # setup kv cache if requested
    if kv_cache:
        self.enable_kv_cache()

    try:
        # generate tokens step by step
        for step in indices:
            model_tokens = torch.cat([tokens, tokens], dim=0) if use_cfg else tokens
            z = self.forward_transformer(model_tokens, cls_embd, orders=model_orders)[:, -1]

            # apply CFG schedule
            if cfg_schedule == "linear":
                cfg_iter = 1 + (cfg - 1) * step / num_steps
            elif cfg_schedule == "constant":
                cfg_iter = cfg
            else:
                raise NotImplementedError(f"CFG schedule '{cfg_schedule}' not implemented")

            # sample next token
            sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)

            if use_cfg:
                # keep only the conditional half
                sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)

            tokens = torch.cat([tokens, sampled_token_latent.unsqueeze(1)], dim=1)
    finally:
        # cleanup even if sampling raised, so the model is not left in cached mode
        if kv_cache:
            self.disable_kv_cache()

    # restore original ordering and convert back to image format
    tokens = self.unshuffle(tokens, orders)
    if not self.force_one_d_seq:
        tokens = self.unpatchify(tokens)

    return tokens
383 |
def generate(self, n_samples, cfg, labels, args):
    """generate samples using the model.

    Thin wrapper around self.sample_tokens that reads the CFG schedule and the
    temperature from args, always shows a progress bar and keeps the KV cache off.

    Args:
        n_samples: number of samples to draw.
        cfg: classifier-free guidance scale.
        labels: class labels (or None).
        args: object exposing cfg_schedule and temperature attributes.
    """
    sampling_options = dict(
        cfg=cfg,
        labels=labels,
        cfg_schedule=args.cfg_schedule,
        temperature=args.temperature,
        progress=True,
        kv_cache=False,
    )
    return self.sample_tokens(n_samples, **sampling_options)
395 |
396 |
397 | # model size variants
def ARDiff_base(**kwargs):
    """build an ARDiff with model_size="base"; kwargs are forwarded unchanged."""
    return ARDiff(model_size="base", **kwargs)


def ARDiff_large(**kwargs):
    """build an ARDiff with model_size="large"; kwargs are forwarded unchanged."""
    return ARDiff(model_size="large", **kwargs)


def ARDiff_xl(**kwargs):
    """build an ARDiff with model_size="xl"; kwargs are forwarded unchanged."""
    return ARDiff(model_size="xl", **kwargs)


def ARDiff_huge(**kwargs):
    """build an ARDiff with model_size="huge"; kwargs are forwarded unchanged."""
    return ARDiff(model_size="huge", **kwargs)


# registry mapping a factory's name to the factory itself
ARDiff_models = {
    factory.__name__: factory
    for factory in (ARDiff_base, ARDiff_large, ARDiff_huge, ARDiff_xl)
}
420 |
--------------------------------------------------------------------------------
/models/diffloss.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/LTH14/mar/blob/main/models/diffloss.py
3 | """
4 | import torch
5 | import torch.nn as nn
6 | from torch.utils.checkpoint import checkpoint
7 |
8 | from diffusion import create_diffusion
9 | from transport import Sampler, create_transport
10 |
11 | from .layers import ModulatedLinear, TimestepEmbedder, modulate
12 |
13 |
class DiffLoss(nn.Module):
    """diffusion loss module for training.

    Wraps a small conditional MLP (SimpleMLPAdaLN) together with either a
    transport-based objective/sampler (use_transport=True) or a gaussian
    diffusion objective/sampler.
    """

    def __init__(
        self,
        target_channels,
        z_channels,
        depth,
        width,
        num_sampling_steps,
        grad_checkpointing=False,
        noise_schedule="cosine",
        use_transport=False,
        timestep_shift=0.3,
        learn_sigma=True,
        sampling_method="euler",
    ):
        super().__init__()

        # --------------------------------------------------------------------------
        # basic configuration
        self.in_channels = target_channels
        self.noise_schedule = noise_schedule
        self.use_transport = use_transport

        # --------------------------------------------------------------------------
        # network architecture (output width doubles when sigma is learned)
        self.net = SimpleMLPAdaLN(
            in_channels=target_channels,
            model_channels=width,
            out_channels=target_channels * 2 if learn_sigma else target_channels,
            z_channels=z_channels,
            num_res_blocks=depth,
            grad_checkpointing=grad_checkpointing,
            use_transport=use_transport,
        )

        # --------------------------------------------------------------------------
        # diffusion/transport setup
        if self.use_transport:
            self.transport = create_transport(use_cosine_loss=True, use_lognorm=True)
            self.sampler = Sampler(self.transport)
            self.sample_fn = self.sampler.sample_ode(
                sampling_method=sampling_method,
                num_steps=int(num_sampling_steps),
                timestep_shift=timestep_shift,
            )
        else:
            self.train_diffusion = create_diffusion("", noise_schedule=noise_schedule)
            self.gen_diffusion = create_diffusion(num_sampling_steps, noise_schedule=noise_schedule)

    def forward(self, target, z, mask=None):
        """forward pass for training.

        Args:
            target: tensor the objective is computed on; its first dimension sets
                how many timesteps are drawn in the diffusion branch.
            z: conditioning passed to the network as keyword ``c``.
            mask: optional weights; when given, the loss is the mask-weighted
                average (sum(loss * mask) / sum(mask)).

        Returns:
            scalar loss.
        """
        model_kwargs = dict(c=z)
        if self.use_transport:
            loss_dict = self.transport.training_losses(self.net, target, model_kwargs)
        else:
            t = torch.randint(
                0,
                self.train_diffusion.num_timesteps,
                (target.shape[0],),
                device=target.device,
            )
            loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)

        loss = loss_dict["loss"]
        if mask is not None:
            loss = (loss * mask).sum() / mask.sum()
        return loss.mean()

    def sample(self, z, temperature=1.0, cfg=1.0):
        """sample from the diffusion model.

        Args:
            z: conditioning, one row per sample. With cfg != 1.0 the first half of
                the rows and the second half are given the same starting noise.
            temperature: forwarded to the gaussian-diffusion sampler only.
            cfg: classifier-free guidance scale; 1.0 uses the plain network forward.

        Returns:
            sampled latents with one row per row of z.
        """
        # noise is drawn on CPU (unchanged RNG stream) and moved to z's device,
        # rather than to a hard-coded CUDA device
        device = z.device
        if cfg != 1.0:
            noise = torch.randn(z.shape[0] // 2, self.in_channels).to(device)
            # both halves must start from the same noise
            noise = torch.cat([noise, noise], dim=0)
            if self.use_transport:
                model_kwargs = dict(c=z, cfg_scale=cfg, cfg_interval=True, cfg_interval_start=0.10)
            else:
                model_kwargs = dict(c=z, cfg_scale=cfg)
            sample_fn = self.net.forward_with_cfg
        else:
            noise = torch.randn(z.shape[0], self.in_channels).to(device)
            model_kwargs = dict(c=z)
            sample_fn = self.net.forward

        if self.use_transport:
            # the sampler's result is indexed with [-1] to take its last entry
            sampled_token_latent = self.sample_fn(noise, sample_fn, **model_kwargs)[-1]
        else:
            sampled_token_latent = self.gen_diffusion.p_sample_loop(
                sample_fn,
                noise.shape,
                noise,
                clip_denoised=False,
                model_kwargs=model_kwargs,
                progress=False,
                temperature=temperature,
            )
        return sampled_token_latent
113 |
114 |
class ResBlock(nn.Module):
    """gated residual MLP block whose layer norm is modulated by a condition vector."""

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.in_ln = nn.LayerNorm(channels, eps=1e-6)

        # two-layer MLP applied to the modulated, normalized input
        fc_in = nn.Linear(channels, channels, bias=True)
        fc_out = nn.Linear(channels, channels, bias=True)
        self.mlp = nn.Sequential(fc_in, nn.SiLU(), fc_out)

        # projects the condition to (shift, scale, gate), each of width `channels`
        modulation_proj = nn.Linear(channels, 3 * channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), modulation_proj)

    def forward(self, x, y):
        """return x plus the gated MLP branch; y is the conditioning input."""
        shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
        branch = self.mlp(modulate(self.in_ln(x), shift, scale))
        return x + gate * branch
134 |
135 |
class SimpleMLPAdaLN(nn.Module):
    """small residual MLP conditioned on a timestep and a condition vector.

    Used as the denoising network of the diffusion loss: the input is projected
    to model_channels, passed through num_res_blocks ResBlocks modulated by
    (timestep embedding + condition embedding) and mapped to out_channels.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        grad_checkpointing=False,
        use_transport=False,
    ):
        super().__init__()

        # --------------------------------------------------------------------------
        # basic configuration
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing
        self.use_transport = use_transport

        # --------------------------------------------------------------------------
        # network layers
        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)
        self.input_proj = nn.Linear(in_channels, model_channels)
        blocks = [ResBlock(model_channels) for _ in range(num_res_blocks)]
        self.res_blocks = nn.ModuleList(blocks)
        self.final_layer = ModulatedLinear(model_channels, out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """xavier-init all linears, then apply the special-cased initializations."""

        def _basic_init(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # timestep embedding MLP: normal init for both linear layers
        for layer_idx in (0, 2):
            nn.init.normal_(self.time_embed.mlp[layer_idx].weight, std=0.02)

        # layers zeroed at init: every adaLN modulation projection and the output head
        zero_init_layers = [block.adaLN_modulation[-1] for block in self.res_blocks]
        zero_init_layers.append(self.final_layer.adaLN_modulation[-1])
        zero_init_layers.append(self.final_layer.linear)
        for layer in zero_init_layers:
            nn.init.constant_(layer.weight, 0)
            nn.init.constant_(layer.bias, 0)

    def forward(self, x, t=None, c=None):
        """run the network on input x with timestep t and condition c."""
        hidden = self.input_proj(x)
        cond = self.time_embed(t) + self.cond_embed(c)

        use_checkpoint = self.grad_checkpointing and self.training
        for block in self.res_blocks:
            hidden = checkpoint(block, hidden, cond) if use_checkpoint else block(hidden, cond)
        return self.final_layer(hidden, cond)

    def forward_with_cfg(self, x, t, c, cfg_scale, cfg_interval=None, cfg_interval_start=None):
        """forward pass with classifier-free guidance.

        The first half of x is duplicated and run through the network, so the first
        half of the output rows is treated as conditional and the second half as
        unconditional. Guidance is applied to the first in_channels output channels
        only; the remaining channels are passed through. With cfg_interval=True,
        guidance is skipped (conditional prediction used) while t[0] is below
        cfg_interval_start.
        """
        first_half = x[: len(x) // 2]
        doubled = torch.cat([first_half, first_half], dim=0)
        out = self.forward(doubled, t, c)

        split_at = self.in_channels
        eps = out[:, :split_at]
        rest = out[:, split_at:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)

        if cfg_interval is True and t[0] < cfg_interval_start:
            guided = cond_eps
        else:
            guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)

        eps = torch.cat([guided, guided], dim=0)
        return torch.cat([eps, rest], dim=1)
225 |
--------------------------------------------------------------------------------
/models/dit.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/facebookresearch/DiT/blob/main/models.py
3 | - add support for 1D sequence
4 | - include samplers inside the model
5 | """
6 |
7 | import logging
8 | from functools import partial
9 |
10 | import torch
11 | import torch.nn as nn
12 |
13 | from diffusion import create_diffusion
14 |
15 | from .layers import (
16 | Block,
17 | LabelEmbedder,
18 | ModulatedLinear,
19 | PatchEmbed,
20 | TimestepEmbedder,
21 | Transformer,
22 | get_2d_sincos_pos_embed,
23 | )
24 | from .model_utils import SIZE_DICT
25 |
26 | logger = logging.getLogger("DeTok")
27 |
28 |
class DiT(nn.Module):
    """diffusion model with a transformer backbone.

    Works either on a 2D latent grid (patch-embedded, fixed sin-cos positional
    embedding) or, when force_one_d_seq > 0, on a 1D token sequence of that
    length (linear embedding, learnable positional embedding).
    """

    def __init__(
        self,
        img_size=256,
        patch_size=1,
        model_size="base",
        tokenizer_patch_size=16,
        token_channels=16,
        label_drop_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        noise_schedule="linear",
        num_sampling_steps=250,
        grad_checkpointing=False,
        force_one_d_seq=0,
        legacy_mode=False,
    ):
        super().__init__()

        # --------------------------------------------------------------------------
        # basic configuration
        self.learn_sigma = learn_sigma
        self.token_channels = token_channels
        # the network predicts twice the channels when sigma is learned
        self.out_channels = token_channels * 2 if learn_sigma else token_channels
        self.input_size = img_size // tokenizer_patch_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.force_one_d_seq = force_one_d_seq
        self.grad_checkpointing = grad_checkpointing
        self.legacy_mode = legacy_mode

        # model architecture configuration
        size_dict = SIZE_DICT[model_size]
        num_layers, num_heads, width = size_dict["layers"], size_dict["heads"], size_dict["width"]

        # --------------------------------------------------------------------------
        # embedding layers
        if self.force_one_d_seq:
            self.x_embedder = nn.Linear(token_channels, width)
            self.pos_embed = nn.Parameter(torch.randn(1, self.force_one_d_seq, width) * 0.02)
        else:
            self.x_embedder = PatchEmbed(self.input_size, patch_size, token_channels, width)
            num_patches = self.x_embedder.num_patches
            # frozen; filled with sin-cos values in initialize_weights
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, width), requires_grad=False)

        self.t_embedder = TimestepEmbedder(width)
        self.y_embedder = LabelEmbedder(num_classes, width, label_drop_prob)

        # --------------------------------------------------------------------------
        # transformer architecture
        self.transformer = Transformer(
            width,
            num_layers,
            num_heads,
            block_fn=partial(Block, use_modulation=True),
            norm_layer=partial(nn.LayerNorm, elementwise_affine=False, eps=1e-6),
            grad_checkpointing=grad_checkpointing,
        )
        self.final_layer = ModulatedLinear(width, patch_size * patch_size * self.out_channels)

        # --------------------------------------------------------------------------
        # diffusion setup
        self.train_diffusion = create_diffusion("", noise_schedule=noise_schedule)
        self.gen_diffusion = create_diffusion(num_sampling_steps, noise_schedule=noise_schedule)
        self.initialize_weights()

        # log model info
        num_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
        logger.info(
            f"[DiT] params: {num_trainable_params:.2f}M size: {model_size}, num_layers: {num_layers}, width: {width}"
        )

    def initialize_weights(self):
        """initialize model weights."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                # nn.Linear always has a `bias` attribute; it is None when bias=False
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # initialize (and freeze) pos_embed by sin-cos embedding
        if not self.force_one_d_seq:
            pos_embed = get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)
            )
            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

            # initialize patch_embed like nn.Linear (instead of nn.Conv2d);
            # only the patch embedder has a `.proj`, the 1D path uses a plain nn.Linear
            w = self.x_embedder.proj.weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.constant_(self.x_embedder.proj.bias, 0)

        # initialize label embedding table
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # initialize timestep embedding MLP
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # zero-out adaLN modulation layers in DiT blocks
        for block in self.transformer.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """convert patch tokens back to image tensor.

        Args:
            x: (N, T, patch_size**2 * out_channels) with T a perfect square.

        Returns:
            (N, out_channels, H, W) tensor with H = W = sqrt(T) * patch_size.
        """
        c, p = self.out_channels, self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def net(self, x, t, y):
        """core network forward pass.

        Args:
            x: noisy latents.
            t: diffusion timesteps, embedded by self.t_embedder.
            y: class labels, embedded by self.y_embedder (which also receives self.training).

        Returns:
            network prediction; unpatchified back to image layout on the 2D path.
        """
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
        c = self.t_embedder(t) + self.y_embedder(y, self.training)  # (N, D)
        x = self.transformer(x, condition=c)  # (N, T, D)
        x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
        if not self.force_one_d_seq:
            x = self.unpatchify(x)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """forward pass with classifier-free guidance.

        The first half of x is duplicated so both halves share the same input; the
        first half of the output rows is treated as conditional, the second as
        unconditional. Guidance is applied along dim 1 to the first token_channels
        entries (first 3 in legacy mode); the rest is passed through.
        """
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.net(combined, t, y)

        if self.legacy_mode:
            eps, rest = model_out[:, :3], model_out[:, 3:]
        else:
            eps, rest = model_out[:, : self.token_channels], model_out[:, self.token_channels :]

        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    def forward(self, x, y):
        """forward pass for training: draw one random timestep per sample and return the mean loss."""
        t = torch.randint(0, self.train_diffusion.num_timesteps, (x.shape[0],), device=x.device)
        loss_dict = self.train_diffusion.training_losses(self.net, x, t, dict(y=y))
        return loss_dict["loss"].mean()

    @torch.inference_mode()
    def generate(self, n_samples, labels, cfg=1.0, args=None):
        """generate samples using the model.

        Args:
            n_samples: number of samples to draw.
            labels: class labels; their device decides where sampling runs.
            cfg: classifier-free guidance scale; guidance is used only when cfg > 1.0.
            args: unused here.

        Returns:
            sampled latents (the null-class half is dropped when guidance is used).
        """
        device = labels.device

        # prepare noise tensor
        if self.force_one_d_seq:
            z = torch.randn(n_samples, self.force_one_d_seq, self.token_channels)
        else:
            z = torch.randn(n_samples, self.token_channels, self.input_size, self.input_size)
        z = z.to(device)

        # setup classifier-free guidance
        if cfg > 1.0:
            z = torch.cat([z, z], 0)
            # index num_classes is used as the null (unconditional) label
            labels = torch.cat([labels, torch.full_like(labels, self.num_classes)], 0)
            model_kwargs = dict(y=labels, cfg_scale=cfg)
            sample_fn = self.forward_with_cfg
        else:
            model_kwargs = dict(y=labels)
            sample_fn = self.net

        # generate samples
        samples = self.gen_diffusion.p_sample_loop(
            sample_fn,
            z.shape,
            z,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            progress=True,
            device=device,
        )

        if cfg > 1.0:
            samples, _ = samples.chunk(2, dim=0)  # remove null class samples
        return samples
222 |
223 |
224 | # model size variants
def DiT_base(**kwargs):
    """build a DiT with model_size="base"; kwargs are forwarded unchanged."""
    return DiT(model_size="base", **kwargs)


def DiT_large(**kwargs):
    """build a DiT with model_size="large"; kwargs are forwarded unchanged."""
    return DiT(model_size="large", **kwargs)


def DiT_xl(**kwargs):
    """build a DiT with model_size="xl"; kwargs are forwarded unchanged."""
    return DiT(model_size="xl", **kwargs)


def DiT_huge(**kwargs):
    """build a DiT with model_size="huge"; kwargs are forwarded unchanged."""
    return DiT(model_size="huge", **kwargs)


# registry mapping a factory's name to the factory itself
DiT_models = {factory.__name__: factory for factory in (DiT_base, DiT_large, DiT_xl, DiT_huge)}
242 |
--------------------------------------------------------------------------------
/models/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class SimpleEMAModel:
5 | """simple exponential moving average model"""
6 |
7 | def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
8 | self.ema_params = {}
9 | self.temp_stored_params = {}
10 | self.decay = decay
11 |
12 | # initialize EMA parameters
13 | for name, param in model.named_parameters():
14 | if param.requires_grad:
15 | self.ema_params[name] = param.clone().detach()
16 | else:
17 | self.ema_params[name] = param
18 |
19 | @torch.inference_mode()
20 | def step(self, model: torch.nn.Module):
21 | """update EMA parameters with current model parameters."""
22 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
23 | model = model.module
24 |
25 | for name, param in model.named_parameters():
26 | if param.requires_grad:
27 | self.ema_params[name].mul_(self.decay).add_(param, alpha=1 - self.decay)
28 | else:
29 | self.ema_params[name].copy_(param)
30 |
31 | def copy_to(self, model: torch.nn.Module) -> None:
32 | """copy current averaged parameters into given model."""
33 | for name, param in model.named_parameters():
34 | param.data.copy_(self.ema_params[name].to(param.device).data)
35 |
36 | def to(self, device=None, dtype=None) -> None:
37 | """move internal buffers to specified device."""
38 | # .to() on the tensors handles None correctly
39 | for name, param in self.ema_params.items():
40 | self.ema_params[name] = (
41 | self.ema_params[name].to(device=device, dtype=dtype)
42 | if self.ema_params[name].is_floating_point()
43 | else self.ema_params[name].to(device=device)
44 | )
45 |
46 | def store(self, model: torch.nn.Module) -> None:
47 | """store current model parameters temporarily."""
48 | for name, param in model.named_parameters():
49 | self.temp_stored_params[name] = param.detach().cpu().clone()
50 |
51 | def restore(self, model: torch.nn.Module) -> None:
52 | """restore parameters stored with the store method."""
53 | if self.temp_stored_params is None:
54 | raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
55 |
56 | for name, param in model.named_parameters():
57 | assert name in self.temp_stored_params, f"{name} not found in temp_stored_params"
58 | param.data.copy_(self.temp_stored_params[name].data)
59 | self.temp_stored_params = {}
60 |
61 | def load_state_dict(self, state_dict: dict | list) -> None:
62 | """load EMA state from state dict."""
63 | if isinstance(state_dict, dict):
64 | for name, param in self.ema_params.items():
65 | param.data.copy_(state_dict[name].to(param.device).data)
66 | elif isinstance(state_dict, list):
67 | i = 0
68 | for name, param in self.ema_params.items():
69 | param.data.copy_(state_dict[i].to(param.device).data)
70 | i += 1
71 | else:
72 | raise ValueError("state_dict must be a dict or list")
73 |
74 | def state_dict(self) -> dict:
75 | """return EMA parameters as state dict."""
76 | return self.ema_params
77 |
--------------------------------------------------------------------------------
/models/lightningdit.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/hustvl/LightningDiT/blob/main/models/lightningdit.py
3 |
4 | - add support for 1D sequence
5 | - include samplers inside the model
6 | - slightly different cfg conditioning (conditioned on **all channels**)
7 | """
8 |
9 | import logging
10 | from functools import partial
11 |
12 | import torch
13 | import torch.nn as nn
14 |
15 | from transport import Sampler, create_transport
16 |
17 | from .layers import (
18 | Block,
19 | LabelEmbedder,
20 | ModulatedLinear,
21 | PatchEmbed,
22 | TimestepEmbedder,
23 | Transformer,
24 | VisionRotaryEmbeddingFast,
25 | get_2d_sincos_pos_embed,
26 | )
27 | from .model_utils import SIZE_DICT
28 |
29 | logger = logging.getLogger("DeTok")
30 |
31 |
32 | class LightningDiT(nn.Module):
33 | """lightning diffusion transformer model."""
34 |
35 | def __init__(
36 | self,
37 | img_size=256,
38 | patch_size=1,
39 | model_size="base",
40 | tokenizer_patch_size=16,
41 | token_channels=16,
42 | label_drop_prob=0.1,
43 | num_classes=1000,
44 | num_sampling_steps=250,
45 | sampling_method="euler",
46 | grad_checkpointing=False,
47 | force_one_d_seq=0,
48 | learn_sigma=False, # no learn_sigma in SiT
49 | legacy_mode=False,
50 | qk_norm=False,
51 | ):
52 | super().__init__()
53 |
54 | # --------------------------------------------------------------------------
55 | # basic configuration
56 | self.token_channels = self.out_channels = token_channels
57 | self.input_size = img_size // tokenizer_patch_size
58 | self.patch_size = patch_size
59 | self.num_classes = num_classes
60 | self.force_one_d_seq = force_one_d_seq
61 | self.grad_checkpointing = grad_checkpointing
62 | self.learn_sigma = learn_sigma
63 | self.legacy_mode = legacy_mode
64 |
65 | # model architecture configuration
66 | size_dict = SIZE_DICT[model_size]
67 | num_layers, num_heads, width = size_dict["layers"], size_dict["heads"], size_dict["width"]
68 |
69 | # --------------------------------------------------------------------------
70 | # embedding layers
71 | if self.force_one_d_seq > 0:
72 | self.x_embedder = nn.Linear(token_channels, width)
73 | # we use learnable positional embeddings for 1D sequence without rope
74 | self.pos_embed = nn.Parameter(torch.randn(1, self.force_one_d_seq, width) * 0.02)
75 | self.seq_len = self.force_one_d_seq
76 | else:
77 | self.x_embedder = PatchEmbed(self.input_size, patch_size, token_channels, width)
78 | # use rotary position encoding + abe, borrow from EVA
79 | num_patches = self.x_embedder.num_patches
80 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, width))
81 | self.rope = VisionRotaryEmbeddingFast(width // num_heads // 2, self.input_size // patch_size)
82 | self.seq_len = num_patches
83 |
84 | self.t_embedder = TimestepEmbedder(width)
85 | self.y_embedder = LabelEmbedder(num_classes, width, label_drop_prob)
86 |
87 | # --------------------------------------------------------------------------
88 | # transformer architecture
89 | self.transformer = Transformer(
90 | width,
91 | num_layers,
92 | num_heads,
93 | block_fn=partial(Block, use_modulation=True),
94 | norm_layer=nn.RMSNorm,
95 | grad_checkpointing=grad_checkpointing,
96 | use_swiglu=True,
97 | qk_norm=qk_norm,
98 | )
99 | self.final_layer = ModulatedLinear(width, patch_size**2 * token_channels, use_rmsnorm=True)
100 |
101 | # --------------------------------------------------------------------------
102 | # transport and sampling setup
103 | self.transport = create_transport(use_cosine_loss=True, use_lognorm=True)
104 | self.sampler = Sampler(self.transport)
105 | self.sample_fn = self.sampler.sample_ode(
106 | sampling_method=sampling_method,
107 | num_steps=int(num_sampling_steps),
108 | timestep_shift=0.3,
109 | )
110 |
111 | self.initialize_weights()
112 |
113 | # log model info
114 | num_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
115 | logger.info(
116 | f"[LightningDiT] params: {num_trainable_params:.2f}M size: {model_size}, num_layers: {num_layers}, width: {width}"
117 | )
118 |
119 | def initialize_weights(self):
120 | """initialize model weights."""
121 |
122 | def _basic_init(module):
123 | if isinstance(module, nn.Linear):
124 | torch.nn.init.xavier_uniform_(module.weight)
125 | if module.bias is not None:
126 | nn.init.constant_(module.bias, 0)
127 |
128 | self.apply(_basic_init)
129 |
130 | # initialize (and freeze) pos_embed by sin-cos embedding
131 | if not self.force_one_d_seq:
132 | pos_embed = get_2d_sincos_pos_embed(
133 | self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)
134 | )
135 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
136 |
137 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
138 | w = self.x_embedder.proj.weight.data
139 | nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
140 | nn.init.constant_(self.x_embedder.proj.bias, 0)
141 |
142 | # initialize label embedding table
143 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
144 |
145 | # initialize timestep embedding MLP
146 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
147 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
148 |
149 | # zero-out adaLN modulation layers in LightningDiT blocks
150 | for block in self.transformer.blocks:
151 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
152 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
153 |
154 | # zero-out output layers
155 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
156 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
157 | nn.init.constant_(self.final_layer.linear.weight, 0)
158 | nn.init.constant_(self.final_layer.linear.bias, 0)
159 |
def unpatchify(self, x):
    """Fold a sequence of flattened patches back into a 2D map.

    Args:
        x: tensor of shape (N, T, patch_size**2 * out_channels). T must be a
            perfect square because the token grid is taken to be square.

    Returns:
        Tensor of shape (N, out_channels, H, W) with H = W = sqrt(T) * patch_size.

    Raises:
        AssertionError: if T is not a perfect square.
    """
    batch, num_tokens = x.shape[0], x.shape[1]
    channels = self.out_channels
    patch = self.patch_size

    grid = int(num_tokens**0.5)
    assert grid * grid == num_tokens

    # (N, T, p*q*C) -> (N, h, w, p, q, C): tokens become a grid, each token a patch
    patches = x.reshape(batch, grid, grid, patch, patch, channels)
    # channels first, then interleave grid and in-patch axes: (N, C, h, p, w, q)
    patches = patches.permute(0, 5, 1, 3, 2, 4)
    side = grid * patch
    return patches.reshape(batch, channels, side, side)
170 |
def net(self, x, t=None, y=None):
    """Run the conditioned transformer backbone once.

    Args:
        x: latents handed to ``self.x_embedder``.
        t: timesteps handed to ``self.t_embedder``.
        y: class labels handed to ``self.y_embedder``.

    Returns:
        The output of ``self.final_layer``; for 2D inputs it is folded back
        into a map by ``self.unpatchify``, for 1D sequences it is returned as is.
    """
    is_two_d = not self.force_one_d_seq

    tokens = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
    cond = self.t_embedder(t) + self.y_embedder(y, self.training)  # (N, D)

    # the rotary embedding is only handed to the transformer for 2D token grids
    extra_kwargs = {"rope": self.rope} if is_two_d else {}
    tokens = self.transformer(tokens, condition=cond, **extra_kwargs)  # (N, T, D)

    out = self.final_layer(tokens, cond)  # (N, T, patch_size ** 2 * out_channels)
    return self.unpatchify(out) if is_two_d else out
186 |
def forward_with_cfg(self, x, t, y, cfg_scale, cfg_interval=None, cfg_interval_start=None):
    """Run the network with classifier-free guidance.

    The batch is expected in the layout built by ``generate``: the first half
    of ``y`` holds the conditional labels and the second half the null labels.
    Only the first half of ``x`` is used; it is duplicated so that both halves
    denoise the same latents.

    Args:
        x: latents, doubled along the batch axis.
        t: timesteps for the doubled batch.
        y: labels for the doubled batch (conditional half, then null half).
        cfg_scale: guidance scale applied to ``cond - uncond``.
        cfg_interval: when exactly ``True``, guidance is switched off (the
            conditional prediction is used as is) while ``t[0]`` is below
            ``cfg_interval_start``.
        cfg_interval_start: threshold used when ``cfg_interval`` is ``True``.

    Returns:
        Tensor shaped like the network output with the guided prediction
        repeated in both batch halves.
    """
    half = x[: len(x) // 2]
    combined = torch.cat([half, half], dim=0)
    model_out = self.net(combined, t, y)

    # Channels sit on dim 1 for 2D latents (N, C, H, W) but on the last dim for
    # 1D token sequences (N, T, C). Slicing dim 1 in the 1D case would guide
    # only the first few tokens and leave every other token unguided.
    channel_dim = model_out.dim() - 1 if self.force_one_d_seq else 1
    total = model_out.shape[channel_dim]
    num_guided = min(3 if self.legacy_mode else self.token_channels, total)
    eps, rest = torch.split(model_out, [num_guided, total - num_guided], dim=channel_dim)

    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)

    if cfg_interval is True:
        timestep = t[0]
        if timestep < cfg_interval_start:
            half_eps = cond_eps

    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=channel_dim)
208 |
def forward(self, x, y):
    """Compute the scalar training loss for one batch.

    Delegates to ``self.transport.training_losses`` with the labels passed as
    model keyword arguments and averages the returned ``"loss"`` entry.
    """
    model_kwargs = {"y": y}
    losses = self.transport.training_losses(self.net, x, model_kwargs)
    per_sample_loss = losses["loss"]
    return per_sample_loss.mean()
213 |
@torch.inference_mode()
def generate(self, n_samples, labels, cfg=1.0, args=None):
    """Sample latents for the given class labels.

    Args:
        n_samples: number of samples to draw.
        labels: class-label tensor; its device decides where sampling runs.
        cfg: guidance scale; classifier-free guidance is used only when > 1.0.
        args: unused, kept for a uniform ``generate`` signature.

    Returns:
        The last state returned by ``self.sample_fn`` with the null-label half
        removed when guidance was used.
    """
    device = labels.device
    use_cfg = cfg > 1.0

    # noise is drawn on the CPU generator and then moved to the label device
    if self.force_one_d_seq:
        noise_shape = (n_samples, self.force_one_d_seq, self.token_channels)
    else:
        noise_shape = (n_samples, self.token_channels, self.input_size, self.input_size)
    z = torch.randn(*noise_shape).to(device)

    if use_cfg:
        # duplicate the noise and pair the copy with the null class index
        z = torch.cat([z, z], 0)
        y_null = torch.tensor([self.num_classes] * n_samples, device=device)
        model_kwargs = {
            "y": torch.cat([labels, y_null], 0),
            "cfg_scale": cfg,
            "cfg_interval": True,
            "cfg_interval_start": 0.10,
        }
        model_fn = self.forward_with_cfg
    else:
        model_kwargs = {"y": labels}
        model_fn = self.net

    trajectory = self.sample_fn(z, model_fn, **model_kwargs)
    samples = trajectory[-1]
    if use_cfg:
        samples, _ = samples.chunk(2, dim=0)  # remove null class samples
    return samples
242 |
243 |
244 | # model size variants
def LightningDiT_base(**kwargs) -> LightningDiT:
    """Build a LightningDiT with ``model_size="base"``; other kwargs are forwarded unchanged."""
    return LightningDiT(model_size="base", **kwargs)


def LightningDiT_large(**kwargs) -> LightningDiT:
    """Build a LightningDiT with ``model_size="large"``; other kwargs are forwarded unchanged."""
    return LightningDiT(model_size="large", **kwargs)


def LightningDiT_xl(**kwargs) -> LightningDiT:
    """Build a LightningDiT with ``model_size="xl"``; other kwargs are forwarded unchanged."""
    return LightningDiT(model_size="xl", **kwargs)


def LightningDiT_huge(**kwargs) -> LightningDiT:
    """Build a LightningDiT with ``model_size="huge"``; other kwargs are forwarded unchanged."""
    return LightningDiT(model_size="huge", **kwargs)


# registry mapping a model name string to the factory that builds it
LightningDiT_models = {
    "LightningDiT_base": LightningDiT_base,
    "LightningDiT_large": LightningDiT_large,
    "LightningDiT_xl": LightningDiT_xl,
    "LightningDiT_huge": LightningDiT_huge,
}
267 |
--------------------------------------------------------------------------------
/models/mar.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/LTH14/mar/blob/main/models/mar.py
3 | - add support for 1D sequence
4 | - include samplers inside the model
5 | - add support for removing dropout in MLPs
6 | """
7 |
8 | import logging
9 | import math
10 | from functools import partial
11 |
12 | import numpy as np
13 | import scipy.stats as stats
14 | import torch
15 | import torch.nn as nn
16 | from torch.utils.checkpoint import checkpoint
17 | from tqdm import tqdm
18 |
19 | from .diffloss import DiffLoss
20 | from .layers import Block
21 |
22 | logger = logging.getLogger("DeTok")
23 |
# transformer width / number of layers / attention heads per MAR model size;
# MAR.__init__ looks its configuration up here by ``model_size``
MAR_SIZE_DICT = {
    "base": {"width": 768, "layers": 12, "heads": 12},
    "large": {"width": 1024, "layers": 16, "heads": 16},
    "huge": {"width": 1280, "layers": 20, "heads": 16},
}
29 |
30 |
def mask_by_order(mask_len, order, bsz, seq_len):
    """Build a boolean mask marking the first ``mask_len`` entries of each order.

    Args:
        mask_len: number of positions to mask; a Python number or a
            single-element tensor (any fractional part is truncated).
        order: integer tensor of shape (bsz, seq_len) holding, per sample, the
            token indices in generation order.
        bsz: batch size.
        seq_len: sequence length.

    Returns:
        Bool tensor of shape (bsz, seq_len), on the same device as ``order``,
        that is True at the indices ``order[:, :mask_len]``.
    """
    num_masked = int(mask_len)
    # allocate next to ``order`` instead of assuming a CUDA device is present
    masking = torch.zeros(bsz, seq_len, device=order.device)
    masking = masking.scatter(-1, order[:, :num_masked], 1.0)
    return masking.bool()
41 |
42 |
43 | class MAR(nn.Module):
def __init__(
    self,
    img_size=256,
    patch_size=1,
    model_size="base",
    tokenizer_patch_size=16,
    token_channels=16,
    mask_ratio_min=0.7,
    label_drop_prob=0.1,
    num_classes=1000,
    attn_dropout=0.1,
    proj_dropout=0.1,
    buffer_size=64,
    diffloss_d=3,
    diffloss_w=1024,
    num_sampling_steps="100",
    noise_schedule="cosine",
    diffusion_batch_mul=4,
    force_one_d_seq=0,
    grad_checkpointing=False,
    no_dropout_in_mlp=False,
):
    """Build the MAR encoder, decoder and per-token diffusion loss head.

    Args:
        img_size: input image side; together with ``tokenizer_patch_size`` and
            ``patch_size`` it fixes the token grid
            (``img_size // tokenizer_patch_size // patch_size`` per side).
        patch_size: side of the latent patch folded into one token.
        model_size: key into ``MAR_SIZE_DICT`` (width, layers, heads).
        tokenizer_patch_size: downsampling factor used to derive the grid size.
        token_channels: channels per latent position; a token has
            ``token_channels * patch_size**2`` features.
        mask_ratio_min: lower bound of the truncated-Gaussian masking ratio.
        label_drop_prob: probability of swapping the class embedding for
            ``fake_latent`` during training.
        num_classes: size of the class embedding table.
        attn_dropout: forwarded to every ``Block`` as ``attn_drop``.
        proj_dropout: forwarded to every ``Block`` as ``proj_drop``.
        buffer_size: number of buffer positions prepended to the sequence.
        diffloss_d: forwarded to ``DiffLoss`` as ``depth``.
        diffloss_w: forwarded to ``DiffLoss`` as ``width``.
        num_sampling_steps: forwarded to ``DiffLoss``.
        noise_schedule: forwarded to ``DiffLoss``.
        diffusion_batch_mul: how many times each token is repeated when the
            diffusion loss is computed.
        force_one_d_seq: when non-zero, overrides the sequence length and the
            model works on 1D token sequences.
        grad_checkpointing: checkpoint transformer blocks during training;
            also forwarded to ``DiffLoss``.
        no_dropout_in_mlp: forwarded to every ``Block``.
    """
    super().__init__()

    # --------------------------------------------------------------------------
    # VAE and patchify specifics
    self.token_channels = token_channels
    self.img_size = img_size
    self.patch_size = patch_size
    self.seq_h = self.seq_w = img_size // tokenizer_patch_size // patch_size
    self.seq_len = self.seq_h * self.seq_w
    self.token_embed_dim = token_channels * patch_size**2
    self.grad_checkpointing = grad_checkpointing
    self.model_size = model_size
    self.force_one_d_seq = force_one_d_seq
    # a non-zero force_one_d_seq replaces the grid-derived sequence length
    if force_one_d_seq:
        self.seq_len = force_one_d_seq

    size_dict = MAR_SIZE_DICT[self.model_size]
    num_layers, num_heads, width = size_dict["layers"], size_dict["heads"], size_dict["width"]

    # --------------------------------------------------------------------------
    # Class Embedding
    self.num_classes = num_classes
    self.class_emb = nn.Embedding(num_classes, width)
    self.label_drop_prob = label_drop_prob
    # Fake class embedding for CFG's unconditional generation
    self.fake_latent = nn.Parameter(torch.zeros(1, width))

    # --------------------------------------------------------------------------
    # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25
    self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

    # --------------------------------------------------------------------------
    # MAR encoder specifics
    self.z_proj = nn.Linear(self.token_embed_dim, width, bias=True)
    self.z_proj_ln = nn.LayerNorm(width, eps=1e-6)
    self.buffer_size = buffer_size
    # positions cover the buffer slots followed by the token sequence
    self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, width))

    norm_layer = partial(nn.LayerNorm, eps=1e-6)
    self.encoder_blocks = nn.ModuleList(
        [
            Block(
                width,
                num_heads,
                norm_layer=norm_layer,
                qkv_bias=True,
                proj_drop=proj_dropout,
                attn_drop=attn_dropout,
                no_dropout_in_mlp=no_dropout_in_mlp,
            )
            for _ in range(num_layers)
        ]
    )
    self.encoder_norm = norm_layer(width)

    # --------------------------------------------------------------------------
    # MAR decoder specifics
    self.decoder_embed = nn.Linear(width, width, bias=True)
    self.mask_token = nn.Parameter(torch.zeros(1, 1, width))
    self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, width))

    self.decoder_blocks = nn.ModuleList(
        [
            Block(
                width,
                num_heads,
                qkv_bias=True,
                norm_layer=norm_layer,
                proj_drop=proj_dropout,
                attn_drop=attn_dropout,
                no_dropout_in_mlp=no_dropout_in_mlp,
            )
            for _ in range(num_layers)
        ]
    )

    self.decoder_norm = norm_layer(width)
    self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, width))

    # runs before the diffusion head exists, so DiffLoss is not touched by it
    self.initialize_weights()

    # --------------------------------------------------------------------------
    # Diffusion Loss
    self.diffloss = DiffLoss(
        target_channels=self.token_embed_dim,
        z_channels=width,
        width=diffloss_w,
        depth=diffloss_d,
        num_sampling_steps=num_sampling_steps,
        noise_schedule=noise_schedule,
        grad_checkpointing=grad_checkpointing,
    )
    self.diffusion_batch_mul = diffusion_batch_mul

    params_M = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
    logger.info(f"[MAR] params: {params_M:.2f}M, {model_size}-{num_layers}-{width}")
    logger.info(f"[MAR] seq_len: {self.seq_len}, buffer_size: {self.buffer_size}")
163 |
def initialize_weights(self):
    """Initialise the embedding-style parameters, then every Linear/LayerNorm.

    The class table, the fake (unconditional) latent, the mask token and the
    three learned position embeddings are drawn from N(0, 0.02**2); all
    sub-modules are then visited with ``self._init_weights``.
    """
    normal_params = (
        self.class_emb.weight,
        self.fake_latent,
        self.mask_token,
        self.encoder_pos_embed_learned,
        self.decoder_pos_embed_learned,
        self.diffusion_pos_embed_learned,
    )
    # order is kept so that the random stream is consumed exactly as before
    for param in normal_params:
        torch.nn.init.normal_(param, std=0.02)

    # initialize nn.Linear and nn.LayerNorm
    self.apply(self._init_weights)
175 |
176 | def _init_weights(self, m):
177 | if isinstance(m, nn.Linear):
178 | # we use xavier_uniform following official JAX ViT:
179 | torch.nn.init.xavier_uniform_(m.weight)
180 | if m.bias is not None:
181 | nn.init.constant_(m.bias, 0)
182 | elif isinstance(m, nn.LayerNorm):
183 | if m.bias is not None:
184 | nn.init.constant_(m.bias, 0)
185 | if m.weight is not None:
186 | nn.init.constant_(m.weight, 1.0)
187 |
def patchify(self, x):
    """Cut a latent map into non-overlapping patches and flatten each one.

    Args:
        x: tensor of shape (N, C, H, W); H and W are split into
            ``patch_size``-sized tiles.

    Returns:
        Tensor of shape (N, (H // p) * (W // p), C * p**2); tokens are in
        row-major grid order and each token is laid out channel-major.
    """
    batch, channels, height, width = x.shape
    patch = self.patch_size
    rows, cols = height // patch, width // patch

    tiles = x.reshape(batch, channels, rows, patch, cols, patch)
    # (N, C, h, p, w, q) -> (N, h, w, C, p, q)
    tiles = tiles.permute(0, 2, 4, 1, 3, 5)
    return tiles.reshape(batch, rows * cols, channels * patch**2)  # [n, l, d]
197 |
def unpatchify(self, x):
    """Reassemble flattened patch tokens into a latent map.

    Args:
        x: tensor of shape (N, seq_h * seq_w, token_channels * p**2), with
            tokens in row-major grid order and each token laid out
            channel-major.

    Returns:
        Tensor of shape (N, token_channels, seq_h * p, seq_w * p).
    """
    batch = x.shape[0]
    patch = self.patch_size
    channels = self.token_channels
    rows, cols = self.seq_h, self.seq_w

    tiles = x.reshape(batch, rows, cols, channels, patch, patch)
    # (N, h, w, C, p, q) -> (N, C, h, p, w, q)
    tiles = tiles.permute(0, 3, 1, 4, 2, 5)
    return tiles.reshape(batch, channels, rows * patch, cols * patch)  # [n, c, h, w]
208 |
def sample_orders(self, bsz):
    """Draw one random generation order per sample.

    Orders are shuffled with NumPy's global random state, one row at a time.

    Args:
        bsz: number of orders to draw.

    Returns:
        Long tensor of shape (bsz, seq_len); every row is a permutation of
        ``range(seq_len)``. It is placed on the device of the model
        parameters instead of assuming CUDA.
    """
    orders = []
    for _ in range(bsz):
        order = np.arange(self.seq_len)
        np.random.shuffle(order)
        orders.append(order)
    # build the index tensor directly as int64 on the model's device
    return torch.as_tensor(np.array(orders), dtype=torch.long, device=self.mask_token.device)
218 |
def random_masking(self, x, orders):
    """Mask the leading part of every generation order.

    One masking ratio is drawn from ``self.mask_ratio_generator`` and shared
    by the whole batch.

    Args:
        x: token tensor of shape (bsz, seq_len, dim); only its shape and
            device are used.
        orders: index tensor of shape (bsz, seq_len) with per-sample orders.

    Returns:
        Float tensor of shape (bsz, seq_len) with 1 at masked positions and 0
        elsewhere.
    """
    bsz, seq_len, _ = x.shape
    mask_rate = self.mask_ratio_generator.rvs(1)[0]
    num_masked_tokens = int(np.ceil(seq_len * mask_rate))

    masked_positions = orders[:, :num_masked_tokens]
    mask = torch.zeros(bsz, seq_len, device=x.device)
    return mask.scatter(-1, masked_positions, 1.0)
232 |
def forward_mae_encoder(self, x, mask, class_embedding):
    """Encode the visible tokens together with the class buffer.

    Args:
        x: token tensor of shape (bsz, seq_len, token_embed_dim).
        mask: (bsz, seq_len) tensor, non-zero where a token is masked. Every
            sample must mask the same number of tokens, because the kept
            tokens are reshaped back into a dense batch.
        class_embedding: (bsz, width) class embeddings; during training each
            row is replaced by ``fake_latent`` with probability
            ``label_drop_prob``.

    Returns:
        Encoded tensor of shape (bsz, buffer_size + num_visible, width).
    """
    x = self.z_proj(x)
    bsz, _, embed_dim = x.shape
    device = x.device

    # concat buffer
    x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=device), x], dim=1)
    mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=device), mask], dim=1)

    # random drop class embedding during training
    if self.training:
        # drawn on the CPU generator as before, then moved next to ``x``
        # rather than to whatever the current CUDA device happens to be
        drop_latent_mask = torch.rand(bsz) < self.label_drop_prob
        drop_latent_mask = drop_latent_mask.unsqueeze(-1).to(device=device, dtype=x.dtype)
        class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding

    # every buffer slot carries the (possibly dropped) class embedding
    x[:, : self.buffer_size] = class_embedding.unsqueeze(1)

    # encoder position embedding
    x = x + self.encoder_pos_embed_learned
    x = self.z_proj_ln(x)

    # dropping: keep buffer slots and unmasked tokens only
    x = x[(1 - mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim)

    # apply Transformer blocks
    use_checkpoint = self.grad_checkpointing and self.training
    for block in self.encoder_blocks:
        x = checkpoint(block, x) if use_checkpoint else block(x)
    return self.encoder_norm(x)
265 |
def forward_mae_decoder(self, x, mask):
    """Decode the encoder output into one conditioning vector per token.

    Args:
        x: encoder output of shape (bsz, buffer_size + num_visible, width).
        mask: (bsz, seq_len) tensor, non-zero where a token was masked.

    Returns:
        Tensor of shape (bsz, seq_len, width): decoder features without the
        buffer slots, plus the learned diffusion position embedding.
    """
    x = self.decoder_embed(x)

    buffer_mask = torch.zeros(x.size(0), self.buffer_size, device=x.device)
    mask_with_buffer = torch.cat([buffer_mask, mask], dim=1)
    visible_idx = (1 - mask_with_buffer).nonzero(as_tuple=True)

    # pad mask tokens: start from an all-mask-token sequence and write the
    # encoded buffer/visible tokens back into their original slots
    full_batch, full_len = mask_with_buffer.shape[0], mask_with_buffer.shape[1]
    mask_tokens = self.mask_token.repeat(full_batch, full_len, 1).to(x.dtype)
    x_after_pad = mask_tokens.clone()
    x_after_pad[visible_idx] = x.reshape(x.shape[0] * x.shape[1], x.shape[2])

    # decoder position embedding
    x = x_after_pad + self.decoder_pos_embed_learned

    # apply Transformer blocks
    use_checkpoint = self.grad_checkpointing and self.training
    for block in self.decoder_blocks:
        x = checkpoint(block, x) if use_checkpoint else block(x)
    x = self.decoder_norm(x)

    # drop the buffer slots before adding the diffusion position embedding
    return x[:, self.buffer_size :] + self.diffusion_pos_embed_learned
295 |
def forward_loss(self, z, target, mask):
    """Compute the diffusion loss over all tokens.

    Tokens are flattened across batch and sequence, then the whole set is
    repeated ``diffusion_batch_mul`` times before being handed to
    ``self.diffloss``.

    Args:
        z: (bsz, seq_len, width) conditioning vectors.
        target: (bsz, seq_len, token_embed_dim) ground-truth tokens.
        mask: (bsz, seq_len) mask passed through to the loss.

    Returns:
        Whatever ``self.diffloss`` returns.
    """
    bsz, seq_len, _ = target.shape
    num_tokens = bsz * seq_len
    repeats = self.diffusion_batch_mul

    flat_target = target.reshape(num_tokens, -1).repeat(repeats, 1)
    flat_z = z.reshape(num_tokens, -1).repeat(repeats, 1)
    flat_mask = mask.reshape(num_tokens).repeat(repeats)
    return self.diffloss(z=flat_z, target=flat_target, mask=flat_mask)
303 |
def forward(self, imgs, labels):
    """Run one training step and return the diffusion loss.

    Args:
        imgs: latent maps, or token sequences when ``force_one_d_seq`` is set
            (then they are used without patchifying).
        labels: class indices looked up in ``self.class_emb``.

    Returns:
        The loss produced by ``self.forward_loss``.
    """
    # class embed
    class_embedding = self.class_emb(labels)

    # patchify and mask (drop) tokens
    tokens = imgs if self.force_one_d_seq else self.patchify(imgs)
    gt_latents = tokens.clone().detach()
    orders = self.sample_orders(bsz=tokens.size(0))
    mask = self.random_masking(tokens, orders)

    encoded = self.forward_mae_encoder(tokens, mask, class_embedding)
    z = self.forward_mae_decoder(encoded, mask)
    return self.forward_loss(z=z, target=gt_latents, mask=mask)
319 |
def sample_tokens(
    self,
    bsz,
    num_iter=64,
    cfg=1.0,
    cfg_schedule="linear",
    labels=None,
    temperature=1.0,
    progress=False,
):
    """Generate token latents by iterative masked prediction.

    Starting from a fully masked sequence, every iteration encodes the tokens
    known so far, decodes conditioning vectors, and samples the tokens that
    leave the mask in that iteration with the diffusion head. The number of
    tokens still masked follows a cosine schedule.

    Args:
        bsz: number of samples to generate.
        num_iter: number of unmasking iterations.
        cfg: classifier-free guidance scale; ``1.0`` disables guidance.
        cfg_schedule: ``"linear"`` (scale grows as tokens are revealed) or
            ``"constant"``.
        labels: class indices, or ``None`` to condition on ``fake_latent``.
        temperature: forwarded to ``self.diffloss.sample``.
        progress: wrap the iteration in a ``tqdm`` progress bar.

    Returns:
        Tokens of shape (bsz, seq_len, token_embed_dim), or the unpatchified
        map when ``force_one_d_seq`` is not set.

    Raises:
        NotImplementedError: for an unknown ``cfg_schedule``.
    """
    # allocate working tensors next to the model instead of assuming CUDA
    device = self.mask_token.device

    # init and sample generation orders
    mask = torch.ones(bsz, self.seq_len, device=device)
    tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim, device=device)
    orders = self.sample_orders(bsz)

    indices = list(range(num_iter))
    if progress:
        indices = tqdm(indices)
    # generate latents
    for step in indices:
        cur_tokens = tokens.clone()

        # class embedding and CFG
        if labels is not None:
            class_embedding = self.class_emb(labels)
        else:
            class_embedding = self.fake_latent.repeat(bsz, 1)
        if cfg != 1.0:
            # second half of the batch is the unconditional branch
            tokens = torch.cat([tokens, tokens], dim=0)
            class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0)
            mask = torch.cat([mask, mask], dim=0)

        # mae encoder
        x = self.forward_mae_encoder(tokens, mask, class_embedding)

        # mae decoder
        z = self.forward_mae_decoder(x, mask)

        # mask ratio for the next round, following MaskGIT and MAGE.
        mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter)
        mask_len = torch.tensor(
            [np.floor(self.seq_len * mask_ratio)], dtype=torch.float32, device=device
        )

        # masks out at least one for the next iteration
        mask_len = torch.maximum(
            torch.ones(1, device=device),
            torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len),
        )

        # get masking for next iteration and locations to be predicted in this iteration
        mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len)
        if step >= num_iter - 1:
            # last iteration: predict everything that is still masked
            mask_to_pred = mask[:bsz].bool()
        else:
            mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool())
        mask = mask_next
        if cfg != 1.0:
            mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0)

        # sample token latents for this step
        z = z[mask_to_pred.nonzero(as_tuple=True)]
        # cfg schedule follow Muse
        if cfg_schedule == "linear":
            cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len
        elif cfg_schedule == "constant":
            cfg_iter = cfg
        else:
            raise NotImplementedError
        sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter)
        if cfg != 1.0:
            sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)  # Remove null class samples
            mask_to_pred, _ = mask_to_pred.chunk(2, dim=0)

        cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent
        tokens = cur_tokens.clone()

    # unpatchify
    if not self.force_one_d_seq:
        tokens = self.unpatchify(tokens)
    return tokens
400 |
@torch.inference_mode()
def generate(self, n_samples, cfg, labels, args):
    """Sample ``n_samples`` token sets under inference mode.

    Thin wrapper around ``sample_tokens``: ``num_iter``, ``cfg_schedule`` and
    ``temperature`` are read from ``args`` and the progress bar is always on.
    """
    sampling_kwargs = {
        "num_iter": args.num_iter,
        "cfg": cfg,
        "labels": labels,
        "cfg_schedule": args.cfg_schedule,
        "temperature": args.temperature,
        "progress": True,
    }
    return self.sample_tokens(n_samples, **sampling_kwargs)
412 |
413 |
def mar_base(**kwargs) -> MAR:
    """Build a MAR with ``model_size="base"``; other kwargs are forwarded unchanged."""
    return MAR(model_size="base", **kwargs)


def mar_large(**kwargs) -> MAR:
    """Build a MAR with ``model_size="large"``; other kwargs are forwarded unchanged."""
    return MAR(model_size="large", **kwargs)


def mar_huge(**kwargs) -> MAR:
    """Build a MAR with ``model_size="huge"``; other kwargs are forwarded unchanged."""
    return MAR(model_size="huge", **kwargs)


# registry mapping a model name string to the factory that builds it
MAR_models = {
    "MAR_base": mar_base,
    "MAR_large": mar_large,
    "MAR_huge": mar_huge,
}
431 |
--------------------------------------------------------------------------------
/models/model_utils.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------------------------
2 | # model size configurations
# width / number of layers / attention heads for each named model size;
# models look their configuration up with ``SIZE_DICT[model_size]``
SIZE_DICT = {
    "small": {"width": 512, "layers": 8, "heads": 8},
    "base": {"width": 768, "layers": 12, "heads": 12},
    "large": {"width": 1024, "layers": 24, "heads": 16},
    "xl": {"width": 1152, "layers": 28, "heads": 16},
    "huge": {"width": 1280, "layers": 32, "heads": 16},
}
10 |
--------------------------------------------------------------------------------
/models/sit.py:
--------------------------------------------------------------------------------
1 | """
2 | Modified from https://github.com/willisma/SiT/blob/main/models.py
3 |
4 | - add support for 1D sequence
5 | - include samplers inside the model
6 | - slightly different cfg conditioning (conditioned on **all channels**)
7 | """
8 |
9 | import logging
10 | from functools import partial
11 |
12 | import torch
13 | import torch.nn as nn
14 | from torch.utils.checkpoint import checkpoint
15 |
16 | from transport import Sampler, create_transport
17 |
18 | from .layers import (
19 | Block,
20 | LabelEmbedder,
21 | ModulatedLinear,
22 | PatchEmbed,
23 | TimestepEmbedder,
24 | Transformer,
25 | get_2d_sincos_pos_embed,
26 | )
27 | from .model_utils import SIZE_DICT
28 |
29 | logger = logging.getLogger("DeTok")
30 |
31 |
32 | class SiT(nn.Module):
33 | """scalable interpolant transformer model."""
34 |
def __init__(
    self,
    img_size=256,
    patch_size=1,
    model_size="base",
    tokenizer_patch_size=16,
    token_channels=16,
    label_drop_prob=0.1,
    num_classes=1000,
    num_sampling_steps=250,
    sampling_method="dopri5",
    grad_checkpointing=False,
    force_one_d_seq=0,
    learn_sigma=False,  # no learn_sigma in SiT
    legacy_mode=False,
    qk_norm=False,
):
    """Build the SiT backbone together with its transport and ODE sampler.

    Args:
        img_size: input image side; the latent side is
            ``img_size // tokenizer_patch_size``.
        patch_size: side of the latent patch embedded as one token.
        model_size: key into ``SIZE_DICT`` (width, layers, heads).
        tokenizer_patch_size: downsampling factor used to derive the latent side.
        token_channels: number of latent channels.
        label_drop_prob: forwarded to ``LabelEmbedder``.
        num_classes: forwarded to ``LabelEmbedder``; also used as the null
            class index when sampling with guidance.
        num_sampling_steps: forwarded (as int) to ``Sampler.sample_ode``.
        sampling_method: forwarded to ``Sampler.sample_ode``.
        grad_checkpointing: checkpoint transformer blocks during training.
        force_one_d_seq: when non-zero, the model works on 1D sequences of
            this length with a linear embedder and a learned position embedding.
        learn_sigma: doubles the output channels; the extra half is dropped
            again in ``net``.
        legacy_mode: guide only the first 3 channels in ``forward_with_cfg``.
        qk_norm: forwarded to ``Transformer``.
    """
    super().__init__()

    # --------------------------------------------------------------------------
    # basic configuration
    self.token_channels = token_channels
    self.out_channels = token_channels * 2 if learn_sigma else token_channels
    self.input_size = img_size // tokenizer_patch_size
    self.patch_size = patch_size
    self.num_classes = num_classes
    self.force_one_d_seq = force_one_d_seq
    self.grad_checkpointing = grad_checkpointing
    self.learn_sigma = learn_sigma
    self.legacy_mode = legacy_mode

    # model architecture configuration
    size_dict = SIZE_DICT[model_size]
    num_layers, num_heads, width = size_dict["layers"], size_dict["heads"], size_dict["width"]

    # --------------------------------------------------------------------------
    # embedding layers
    if self.force_one_d_seq:
        # 1D sequences: per-token linear embedding and a trainable position table
        self.x_embedder = nn.Linear(token_channels, width)
        self.pos_embed = nn.Parameter(torch.randn(1, self.force_one_d_seq, width) * 0.02)
        self.seq_len = self.force_one_d_seq
    else:
        # 2D latents: patch embedding and a frozen position table filled in
        # by initialize_weights
        self.x_embedder = PatchEmbed(self.input_size, patch_size, token_channels, width)
        num_patches = self.x_embedder.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, width), requires_grad=False)
        self.seq_len = num_patches

    self.t_embedder = TimestepEmbedder(width)
    self.y_embedder = LabelEmbedder(num_classes, width, label_drop_prob)

    # --------------------------------------------------------------------------
    # transformer architecture
    self.transformer = Transformer(
        width,
        num_layers,
        num_heads,
        block_fn=partial(Block, use_modulation=True),
        norm_layer=partial(nn.LayerNorm, elementwise_affine=False, eps=1e-6),
        qk_norm=qk_norm,
        grad_checkpointing=grad_checkpointing,
    )
    self.final_layer = ModulatedLinear(width, patch_size * patch_size * self.out_channels)

    # --------------------------------------------------------------------------
    # transport and sampling setup
    self.transport = create_transport()
    self.sampler = Sampler(self.transport)
    self.sample_fn = self.sampler.sample_ode(
        sampling_method=sampling_method,
        num_steps=int(num_sampling_steps),
    )

    self.initialize_weights()

    # log model info
    params_M = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
    logger.info(f"[SiT] params: {params_M:.2f}M, {model_size}-{num_layers}-{width}")
    logger.info(f"[SiT] seq_len: {self.seq_len}")
113 |
def initialize_weights(self):
    """Initialise all weights of the model.

    Every ``nn.Linear`` first gets Xavier-uniform weights and a zero bias.
    For 2D inputs the position table is filled with a sin-cos embedding and
    the patch projection is re-initialised like a linear layer; in 1D mode the
    embedder is a plain ``nn.Linear`` already covered by the generic pass.
    Label and timestep embeddings are drawn from N(0, 0.02**2), and all adaLN
    modulation outputs plus the final projection are zeroed.
    """

    def _init_linear(module):
        if not isinstance(module, nn.Linear):
            return
        torch.nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)

    def _zero_out(layer):
        nn.init.constant_(layer.weight, 0)
        nn.init.constant_(layer.bias, 0)

    self.apply(_init_linear)

    if not self.force_one_d_seq:
        # initialize (and freeze) pos_embed by sin-cos embedding
        grid_size = int(self.x_embedder.num_patches**0.5)
        sincos = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], grid_size)
        self.pos_embed.data.copy_(torch.from_numpy(sincos).float().unsqueeze(0))

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        proj = self.x_embedder.proj
        proj_weight = proj.weight.data
        nn.init.xavier_uniform_(proj_weight.view([proj_weight.shape[0], -1]))
        nn.init.constant_(proj.bias, 0)

    # initialize label embedding table
    nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

    # initialize timestep embedding MLP (its two linear layers)
    for layer_idx in (0, 2):
        nn.init.normal_(self.t_embedder.mlp[layer_idx].weight, std=0.02)

    # zero-out adaLN modulation layers in SiT blocks
    for block in self.transformer.blocks:
        _zero_out(block.adaLN_modulation[-1])

    # zero-out output layers
    _zero_out(self.final_layer.adaLN_modulation[-1])
    _zero_out(self.final_layer.linear)
154 |
def unpatchify(self, x):
    """Turn patch tokens back into a 2D map.

    Args:
        x: tensor of shape (N, T, patch_size**2 * out_channels); the token
            grid is taken to be square, so T must be a perfect square.

    Returns:
        Tensor of shape (N, out_channels, H, W) with H = W = sqrt(T) * patch_size.

    Raises:
        AssertionError: if T is not a perfect square.
    """
    n, num_tokens = x.shape[0], x.shape[1]
    c, p = self.out_channels, self.patch_size

    side_tokens = int(num_tokens**0.5)
    assert side_tokens * side_tokens == num_tokens

    # split tokens into a grid and each token into (p, q, C)
    grid = x.reshape(n, side_tokens, side_tokens, p, p, c)
    # (N, h, w, p, q, C) -> (N, C, h, p, w, q)
    grid = grid.permute(0, 5, 1, 3, 2, 4)
    side = side_tokens * p
    return grid.reshape(n, c, side, side)
165 |
def net(self, x, t, y):
    """Run the conditioned transformer backbone once.

    Args:
        x: latents handed to ``self.x_embedder``.
        t: timesteps handed to ``self.t_embedder``.
        y: class labels handed to ``self.y_embedder``.

    Returns:
        The prediction: an unpatchified map for 2D inputs, the token sequence
        for 1D inputs. With ``learn_sigma`` only the first half of the
        channels is returned.
    """
    x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2
    c = self.t_embedder(t) + self.y_embedder(y, self.training)  # (N, D)

    # transformer forward pass
    use_checkpoint = self.grad_checkpointing and self.training
    for block in self.transformer.blocks:
        if use_checkpoint:
            # the reentrant checkpoint implementation rejects keyword
            # arguments for the wrapped function, so ``condition=c`` needs
            # the non-reentrant variant
            x = checkpoint(block, x, condition=c, use_reentrant=False)
        else:
            x = block(x, condition=c)

    x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)

    if not self.force_one_d_seq:
        x = self.unpatchify(x)  # (N, out_channels, H, W)
    if self.learn_sigma:
        # channels are on dim 1 for 2D maps and on the last dim for 1D sequences
        x, _ = x.chunk(2, dim=-1 if self.force_one_d_seq else 1)
    return x
185 |
def forward_with_cfg(self, x, t, y, cfg_scale):
    """Run the network with classifier-free guidance.

    The batch is expected in the layout built by ``generate``: the first half
    of ``y`` holds the conditional labels and the second half the null labels.
    Only the first half of ``x`` is used; it is duplicated so that both halves
    denoise the same latents.

    Args:
        x: latents, doubled along the batch axis.
        t: timesteps for the doubled batch.
        y: labels for the doubled batch (conditional half, then null half).
        cfg_scale: guidance scale applied to ``cond - uncond``.

    Returns:
        Tensor shaped like the network output with the guided prediction
        repeated in both batch halves.
    """
    half = x[: len(x) // 2]
    combined = torch.cat([half, half], dim=0)
    model_out = self.net(combined, t, y)

    # Channels sit on dim 1 for 2D latents (N, C, H, W) but on the last dim for
    # 1D token sequences (N, T, C). Slicing dim 1 in the 1D case would guide
    # only the first few tokens and leave every other token unguided.
    channel_dim = model_out.dim() - 1 if self.force_one_d_seq else 1
    total = model_out.shape[channel_dim]
    num_guided = min(3 if self.legacy_mode else self.token_channels, total)
    eps, rest = torch.split(model_out, [num_guided, total - num_guided], dim=channel_dim)

    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=channel_dim)
201 |
def forward(self, x, y):
    """Compute the scalar training loss for one batch.

    Delegates to ``self.transport.training_losses`` with the labels passed as
    model keyword arguments and averages the returned ``"loss"`` entry.
    """
    model_kwargs = {"y": y}
    losses = self.transport.training_losses(self.net, x, model_kwargs)
    per_sample_loss = losses["loss"]
    return per_sample_loss.mean()
206 |
@torch.inference_mode()
def generate(self, n_samples, labels, cfg=1.0, args=None):
    """Sample latents for the given class labels.

    Args:
        n_samples: number of samples to draw.
        labels: class-label tensor; its device decides where sampling runs.
        cfg: guidance scale; classifier-free guidance is used only when > 1.0.
        args: unused, kept for a uniform ``generate`` signature.

    Returns:
        The last state returned by ``self.sample_fn`` with the null-label half
        removed when guidance was used.
    """
    device = labels.device
    use_cfg = cfg > 1.0

    # noise is drawn on the CPU generator and then moved to the label device
    if self.force_one_d_seq:
        noise_shape = (n_samples, self.force_one_d_seq, self.token_channels)
    else:
        noise_shape = (n_samples, self.token_channels, self.input_size, self.input_size)
    z = torch.randn(*noise_shape).to(device)

    if use_cfg:
        # duplicate the noise and pair the copy with the null class index
        z = torch.cat([z, z], 0)
        y_null = torch.tensor([self.num_classes] * n_samples, device=device)
        model_kwargs = {"y": torch.cat([labels, y_null], 0), "cfg_scale": cfg}
        model_fn = self.forward_with_cfg
    else:
        model_kwargs = {"y": labels}
        model_fn = self.net

    trajectory = self.sample_fn(z, model_fn, **model_kwargs)
    samples = trajectory[-1]
    if use_cfg:
        samples, _ = samples.chunk(2, dim=0)  # remove null class samples
    return samples
235 |
236 |
237 | # model size variants
def SiT_XL(**kwargs) -> SiT:
    """Build a SiT with ``model_size="xl"``; other kwargs are forwarded unchanged."""
    return SiT(model_size="xl", **kwargs)


def SiT_L(**kwargs) -> SiT:
    """Build a SiT with ``model_size="large"``; other kwargs are forwarded unchanged."""
    return SiT(model_size="large", **kwargs)


def SiT_B(**kwargs) -> SiT:
    """Build a SiT with ``model_size="base"``; other kwargs are forwarded unchanged."""
    return SiT(model_size="base", **kwargs)


def SiT_S(**kwargs) -> SiT:
    """Build a SiT with ``model_size="small"``; other kwargs are forwarded unchanged."""
    return SiT(model_size="small", **kwargs)


# registry mapping a model name string to the factory that builds it
SiT_models = {"SiT_base": SiT_B, "SiT_large": SiT_L, "SiT_xl": SiT_XL, "SiT_small": SiT_S}
255 |
--------------------------------------------------------------------------------
/pyrightconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "exclude": [
3 | "**/node_modules",
4 | "**/__pycache__",
5 | "**/.git",
6 | "data",
7 | "legacy_code",
8 | "work_dirs"
9 | ],
10 | "reportUnknownMemberType": false,
11 | "reportUnannotatedClassAttribute": false,
12 | "reportUnusedCallResult": false,
13 | "reportUnknownParameterType": false,
14 | "reportUnknownVariableType": false,
15 | "reportUnknownArgumentType": false,
16 | "reportAny": false,
17 | "reportExplicitAny": false,
18 | "reportMissingParameterType": false,
19 | "reportArgumentType": false,
20 | "reportMissingTypeStubs": false,
21 | "reportOptionalMemberAccess": false,
22 | "reportImplicitOverride": false,
23 | "reportUntypedFunctionDecorator": false,
24 | "reportOperatorIssue": false,
25 | "reportIndexIssue": false,
26 | "reportUntypedNamedTuple": false,
27 | "reportImplicitStringConcatenation": false,
28 | "reportUnknownLambdaType": false,
29 | "reportPrivateLocalImportUsage": false,
30 | "reportOptionalOperand": false,
31 | "reportReturnType": false,
32 | "reportMissingTypeArgument": false,
33 | "reportCallIssue": false,
34 | "reportCallInDefaultInitializer": false,
35 | "reportIncompatibleMethodOverride": false,
36 | "reportAttributeAccessIssue": false,
37 | "reportPossiblyUnboundVariable": false,
38 | "reportMissingSuperCall": false,
39 | "reportOptionalSubscript": false,
40 | "reportImportCycles": false
41 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --extra-index-url https://download.pytorch.org/whl/cu124
2 | torch==2.6.0+cu124
3 | timm
4 | scipy
5 | einops
6 | wandb
7 | matplotlib
8 | diffusers
9 | torchtnt
10 | rich
11 | git+https://github.com/LTH14/torch-fidelity.git@master#egg=torch-fidelity
12 | huggingface_hub
13 |
14 | black
15 | pylint
16 | flake8
17 | isort
18 | jaxtyping
19 | torchdiffeq
20 | gdown
21 | ipykernel
--------------------------------------------------------------------------------
/transport/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/hustvl/LightningDiT/blob/main/transport/__init__.py
3 | """
4 |
5 | from .transport import ModelType, PathType, Transport, WeightType, Sampler
6 |
7 |
8 |
def create_transport(
    path_type="Linear",
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
    use_cosine_loss=False,
    use_lognorm=None,
    partitial_train=None,
    partial_ratio=1.0,
    shift_lg=False,
):
    """Create a ``Transport`` object.

    **Note**: model prediction defaults to velocity.

    Args:
        path_type: interpolation path, one of ``"Linear"``, ``"GVP"``, ``"VP"``.
        prediction: ``"noise"`` or ``"score"``; any other value selects
            velocity prediction.
        loss_weight: ``"velocity"`` or ``"likelihood"``; any other value
            selects the unweighted loss.
        train_eps: small epsilon for avoiding instability during training.
        sample_eps: small epsilon for avoiding instability during sampling.
        use_cosine_loss: forwarded unchanged to ``Transport``.
        use_lognorm: forwarded unchanged to ``Transport``.
        partitial_train: forwarded unchanged to ``Transport``.
        partial_ratio: forwarded unchanged to ``Transport``.
        shift_lg: forwarded unchanged to ``Transport``.

    Both epsilons are forced to 0 for velocity prediction on the Linear and
    GVP paths; otherwise a value left as ``None`` gets a path-dependent
    default.

    Returns:
        The configured ``Transport`` instance.

    Raises:
        KeyError: if ``path_type`` is not one of the supported names.
    """
    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }

    path_type = path_choice[path_type]

    if path_type in [PathType.VP]:
        train_eps = 1e-5 if train_eps is None else train_eps
        # the default must be keyed on sample_eps itself: train_eps has just
        # been filled in, so testing it would leave sample_eps as None
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
        use_cosine_loss=use_cosine_loss,
        use_lognorm=use_lognorm,
        partitial_train=partitial_train,
        partial_ratio=partial_ratio,
        shift_lg=shift_lg,
    )

    return state
79 |
--------------------------------------------------------------------------------
/transport/integrators.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/hustvl/LightningDiT/blob/main/transport/integrators.py
3 | """
4 |
5 | import torch as th
6 | from torchdiffeq import odeint
7 |
8 |
class sde:
    """Fixed-step SDE solver on a uniform time grid.

    Args:
        drift: callable ``drift(x, t, model, **model_kwargs)`` returning the drift term.
        diffusion: callable ``diffusion(x, t)`` returning the diffusion
            coefficient, either as a tensor or as a plain Python number.
        t0: start time; must be strictly smaller than ``t1``.
        t1: end time.
        num_steps: number of grid points between ``t0`` and ``t1`` (at least 2).
        sampler_type: "Euler" (Euler-Maruyama) or "Heun".
    """

    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
    ):
        assert t0 < t1, "SDE sampler has to be in forward time"
        if num_steps < 2:
            # The step size below is read from the first two grid points.
            raise ValueError(f"num_steps must be at least 2, got {num_steps}")

        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        self.dt = self.t[1] - self.t[0]
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type

    @staticmethod
    def _noise_scale(diffusion, x):
        """Return ``sqrt(2 * diffusion)``, accepting tensor or scalar coefficients."""
        if not th.is_tensor(diffusion):
            # th.sqrt only accepts tensors; a constant diffusion form may be a float.
            diffusion = th.as_tensor(diffusion, dtype=x.dtype, device=x.device)
        return th.sqrt(2 * diffusion)

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        """One Euler-Maruyama step; returns ``(noisy_x, mean_x)``."""
        w_cur = th.randn(x.size()).to(x)
        t = th.ones(x.size(0)).to(x) * t
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t, model, **model_kwargs)
        diffusion = self.diffusion(x, t)
        mean_x = x + drift * self.dt
        x = mean_x + self._noise_scale(diffusion, x) * dw
        return x, mean_x

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        """One Heun step: perturb with noise first, then average two drift evaluations."""
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt)
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        xhat = x + self._noise_scale(diffusion, x) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        return (
            xhat + 0.5 * self.dt * (K1 + K2),
            xhat,
        )  # at last time point we do not perform the heun step

    def __forward_fn(self):
        """Return the step function selected by ``sampler_type``.

        TODO: generalize here by adding all private functions ending with steps to it

        Raises:
            NotImplementedError: if ``sampler_type`` is not a known sampler.
        """
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
        }

        try:
            sampler = sampler_dict[self.sampler_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable values, which the lookup also rejects.
            raise NotImplementedError(
                f"Sampler type {self.sampler_type!r} not implemented."
            ) from None

        return sampler

    def sample(self, init, model, **model_kwargs):
        """Run the forward loop of the SDE.

        Args:
            init: initial state tensor.
            model: model passed through to the drift callable.
            **model_kwargs: extra keyword arguments passed to the drift callable.

        Returns:
            List with one state per step taken, i.e. ``num_steps - 1`` entries.
        """
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        # The last grid point is excluded: each step advances from t_i to t_i + dt.
        for ti in self.t[:-1]:
            with th.no_grad():
                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
                samples.append(x)

        return samples
81 |
82 |
class ode:
    """ODE solver wrapper around ``odeint``.

    Args:
        drift: callable ``drift(x, t, model, **model_kwargs)`` giving dx/dt.
        t0: start time; must be strictly smaller than ``t1``.
        t1: end time.
        sampler_type: solver name forwarded to ``odeint`` as ``method``.
        num_steps: number of time points in the integration grid.
        atol: absolute tolerance forwarded to ``odeint``.
        rtol: relative tolerance forwarded to ``odeint``.
        timestep_shift: when positive, each grid point ``t`` is remapped to
            ``shift * t / (1 + (shift - 1) * t)``. Defaults to 0.0 (no
            remapping) so callers that do not use shifting can omit it.
    """

    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
        timestep_shift=0.0,
    ):
        assert t0 < t1, "ODE sampler has to be in forward time"

        self.drift = drift
        self.t = th.linspace(t0, t1, num_steps)

        if timestep_shift > 0:
            # Vectorised form of the per-point remapping.
            self.t = (timestep_shift * self.t) / (1 + (timestep_shift - 1) * self.t)

        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type

    def sample(self, x, model, **model_kwargs):
        """Integrate the ODE starting from ``x``.

        Args:
            x: initial state; a tensor, or a tuple of tensors whose first
                element provides the batch size and device.
            model: model passed through to the drift callable.
            **model_kwargs: extra keyword arguments passed to the drift callable.

        Returns:
            Whatever ``odeint`` returns for the time grid ``self.t``.
        """
        is_tuple = isinstance(x, tuple)
        device = x[0].device if is_tuple else x.device

        def _fn(t, x):
            # Broadcast the scalar solver time to one entry per batch element.
            batch = x[0].size(0) if isinstance(x, tuple) else x.size(0)
            t = th.ones(batch).to(device) * t
            return self.drift(x, t, model, **model_kwargs)

        t = self.t.to(device)
        # One tolerance per state component.
        n_states = len(x) if is_tuple else 1
        atol = [self.atol] * n_states
        rtol = [self.rtol] * n_states
        samples = odeint(_fn, x, t, method=self.sampler_type, atol=atol, rtol=rtol)
        return samples
134 |
--------------------------------------------------------------------------------
/transport/path.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/hustvl/LightningDiT/blob/main/transport/path.py
3 | """
4 |
5 |
6 | import numpy as np
7 | import torch as th
8 |
9 |
def expand_t_like_x(t, x):
    """Reshape the time vector so it broadcasts against ``x``.

    Args:
        t: [batch_dim,], time vector
        x: [batch_dim,...], data point

    Returns:
        A view of ``t`` with shape ``[batch_dim, 1, ..., 1]``, with one
        singleton axis for every non-batch axis of ``x``.
    """
    trailing_axes = x.dim() - 1
    singleton_shape = [1] * trailing_axes
    return t.view(t.size(0), *singleton_shape)
19 |
20 |
class ICPlan:
    """Linear Coupling Plan.

    The path is ``x_t = alpha_t * x1 + sigma_t * x0`` with ``alpha_t = t`` and
    ``sigma_t = 1 - t``. Subclasses override the coefficient methods to
    define other paths.
    """

    def __init__(self, sigma=0.0):
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Compute the data coefficient along the path; returns ``(alpha_t, d_alpha_t)``."""
        return t, 1

    def compute_sigma_t(self, t):
        """Compute the noise coefficient along the path; returns ``(sigma_t, d_sigma_t)``."""
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio between d_alpha and alpha"""
        return 1 / t

    def compute_drift(self, x, t):
        """We always output sde according to score parametrization.

        Returns:
            Tuple ``(-alpha_ratio * x, alpha_ratio * sigma_t**2 - sigma_t * d_sigma_t)``.
        """
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t**2) - sigma_t * d_sigma_t

        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE

        Args:
            x: [batch_dim, ...], data point
            t: [batch_dim,], time vector
            form: str, form of the diffusion term; one of "constant", "SBDM",
                "sigma", "linear", "decreasing", "inccreasing-decreasing"
                (historical spelling) or its alias "increasing-decreasing"
            norm: float, norm of the diffusion term

        Returns:
            ``norm`` itself for "constant", otherwise a tensor broadcastable to ``x``.

        Raises:
            NotImplementedError: if ``form`` is not a known diffusion form.
        """
        t = expand_t_like_x(t, x)
        # Each form is wrapped in a thunk so only the requested one is
        # evaluated; the others would just be wasted work on every step.
        choices = {
            "constant": lambda: norm,
            "SBDM": lambda: norm * self.compute_drift(x, t)[1],
            "sigma": lambda: norm * self.compute_sigma_t(t)[0],
            "linear": lambda: norm * (1 - t),
            "decreasing": lambda: 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            "inccreasing-decreasing": lambda: norm * th.sin(np.pi * t) ** 2,
        }
        # Accept the correctly spelled name as well as the historical key.
        choices["increasing-decreasing"] = choices["inccreasing-decreasing"]

        try:
            build = choices[form]
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented") from None

        return build()

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transform velocity prediction model to score

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transform velocity prediction model to denoiser

        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transform score prediction model to velocity

        Args:
            score: [batch_dim, ...] shaped tensor; score model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of time-dependent density p_t"""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Return x_t for the pair (x0, x1); this is exactly ``compute_mu_t`` (no extra noise is drawn)."""
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t (``xt`` is accepted but unused)."""
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        """Return ``(t, x_t, u_t)`` for the given time and endpoint pair."""
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut
141 |
142 |
class VPCPlan(ICPlan):
    """Variance-preserving (VP) path for flow matching.

    ``sigma_min`` and ``sigma_max`` parameterise the log of the data
    coefficient; both are read at call time from the instance attributes.
    """

    def __init__(self, sigma_min=0.1, sigma_max=20.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

        def _log_mean_coeff(t):
            # log(alpha_t)
            spread = self.sigma_max - self.sigma_min
            return -0.25 * ((1 - t) ** 2) * spread - 0.5 * (1 - t) * self.sigma_min

        def _d_log_mean_coeff(t):
            # d/dt log(alpha_t)
            spread = self.sigma_max - self.sigma_min
            return 0.5 * (1 - t) * spread + 0.5 * self.sigma_min

        self.log_mean_coeff = _log_mean_coeff
        self.d_log_mean_coeff = _d_log_mean_coeff

    def compute_alpha_t(self, t):
        """Return the coefficient of x1 and its time derivative."""
        alpha_t = th.exp(self.log_mean_coeff(t))
        return alpha_t, alpha_t * self.d_log_mean_coeff(t)

    def compute_sigma_t(self, t):
        """Return the coefficient of x0 and its time derivative."""
        log_alpha_sq = 2 * self.log_mean_coeff(t)
        sigma_t = th.sqrt(1 - th.exp(log_alpha_sq))
        d_sigma_t = th.exp(log_alpha_sq) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
        return sigma_t, d_sigma_t

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Return d_alpha_t / alpha_t, computed directly from the log coefficient."""
        return self.d_log_mean_coeff(t)

    def compute_drift(self, x, t):
        """Return the SDE drift ``-0.5 * beta_t * x`` and the term ``beta_t / 2``."""
        t = expand_t_like_x(t, x)
        beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
        drift = -0.5 * beta_t * x
        return drift, beta_t / 2
180 |
181 |
class GVPCPlan(ICPlan):
    """Trigonometric path with ``alpha_t = sin(pi t / 2)`` and ``sigma_t = cos(pi t / 2)``."""

    def __init__(self, sigma=0.0):
        ICPlan.__init__(self, sigma)

    def compute_alpha_t(self, t):
        """Return the coefficient of x1 and its time derivative."""
        angle = t * np.pi / 2
        return th.sin(angle), np.pi / 2 * th.cos(angle)

    def compute_sigma_t(self, t):
        """Return the coefficient of x0 and its time derivative."""
        angle = t * np.pi / 2
        return th.cos(angle), -np.pi / 2 * th.sin(angle)

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Return d_alpha_t / alpha_t, written as ``pi / (2 tan(pi t / 2))``."""
        tangent = th.tan(t * np.pi / 2)
        return np.pi / (2 * tangent)
201 |
--------------------------------------------------------------------------------
/transport/transport.py:
--------------------------------------------------------------------------------
1 | """
2 | Reference: https://github.com/hustvl/LightningDiT/blob/main/transport/transport.py
3 | """
4 |
5 | import enum
6 |
7 | import numpy as np
8 | import torch as th
9 | from scipy.stats import norm
10 |
11 | from . import path
12 | from .integrators import ode, sde
13 |
14 |
def mean_flat(x):
    """
    Average ``x`` over every axis except the first (batch) axis.
    """
    non_batch_axes = list(range(1, x.dim()))
    return x.mean(dim=non_batch_axes)
20 |
21 |
class ModelType(enum.Enum):
    """
    Kind of quantity the model outputs.
    """

    # Values are the ones enum.auto() assigns: consecutive integers from 1.
    NOISE = 1  # the model predicts epsilon
    SCORE = 2  # the model predicts \nabla \log p(x)
    VELOCITY = 3  # the model predicts v(x)
30 |
31 |
class PathType(enum.Enum):
    """
    Kind of interpolation path.
    """

    # Values are the ones enum.auto() assigns: consecutive integers from 1.
    LINEAR = 1
    GVP = 2
    VP = 3
40 |
41 |
class WeightType(enum.Enum):
    """
    Kind of loss weighting.
    """

    # Values are the ones enum.auto() assigns: consecutive integers from 1.
    NONE = 1
    VELOCITY = 2
    LIKELIHOOD = 3
50 |
51 |
class Transport:
    """Interpolation path plus training objective between noise (x0) and data (x1).

    Args:
        model_type: ``ModelType`` member naming what the model predicts.
        path_type: ``PathType`` member selecting the path plan class.
        loss_type: ``WeightType`` member selecting the loss weighting
            (only used for noise/score prediction).
        train_eps: epsilon used to shrink the time interval during training.
        sample_eps: epsilon used to shrink the time interval during sampling.
        use_cosine_loss: additionally report a cosine-similarity loss
            (velocity prediction only).
        use_lognorm: draw training times from a logit-normal instead of uniform.
        partitial_train: optional ``(low, high)`` pair; with probability
            ``partial_ratio`` training times are restricted to this range.
        partial_ratio: probability of using ``partitial_train`` for a batch.
        shift_lg: with ``use_lognorm``, use ``shifted_mu`` as the logit-normal mean.
    """

    def __init__(
        self,
        *,
        model_type,
        path_type,
        loss_type,
        train_eps,
        sample_eps,
        use_cosine_loss=False,
        use_lognorm=False,
        partitial_train=None,
        partial_ratio=1.0,
        shift_lg=False,
    ):
        path_options = {
            PathType.LINEAR: path.ICPlan,
            PathType.GVP: path.GVPCPlan,
            PathType.VP: path.VPCPlan,
        }

        self.loss_type = loss_type
        self.model_type = model_type
        self.path_sampler = path_options[path_type]()
        self.train_eps = train_eps
        self.sample_eps = sample_eps
        self.use_cosine_loss = use_cosine_loss
        self.use_lognorm = use_lognorm
        self.partitial_train = partitial_train
        self.partial_ratio = partial_ratio
        self.shift_lg = shift_lg

    def prior_logp(self, z):
        """
        Standard multivariate normal prior
        Assume z is batched
        """
        shape = th.tensor(z.size())
        N = th.prod(shape[1:])  # number of elements per sample

        def _fn(x):
            return -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0

        return th.vmap(_fn)(z)

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        diffusion_form="SBDM",
        sde=False,
        reverse=False,
        eval=False,
        last_step_size=0.0,
    ):
        """Return the time interval ``(t0, t1)`` to train or sample on.

        Args:
            train_eps: epsilon used when ``eval`` is False.
            sample_eps: epsilon used when ``eval`` is True.
            diffusion_form: diffusion form of the SDE; "SBDM" moves ``t0`` away from 0.
            sde: whether the interval is for SDE sampling.
            reverse: if True, return ``(1 - t0, 1 - t1)``.
            eval: selects ``sample_eps`` instead of ``train_eps``.
            last_step_size: for SDE sampling, non-zero values set ``t1 = 1 - last_step_size``.
        """
        t0 = 0
        t1 = 1
        eps = train_eps if not eval else sample_eps
        if type(self.path_sampler) in [path.VPCPlan]:

            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and (
            self.model_type != ModelType.VELOCITY or sde
        ):  # avoid numerical issue by taking a first semi-implicit step

            t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        if reverse:
            t0, t1 = 1 - t0, 1 - t1

        return t0, t1

    def sample_logit_normal(self, mu, sigma, size=1):
        """Draw ``size`` logit-normal samples in (0, 1) as a float32 tensor."""
        # Generate samples from the normal distribution
        samples = norm.rvs(loc=mu, scale=sigma, size=size)

        # Transform samples to be in the range (0, 1) using the logistic function
        samples = 1 / (1 + np.exp(-samples))

        # Numpy to Tensor
        samples = th.tensor(samples, dtype=th.float32)

        return samples

    def sample_in_range(self, mu, sigma, target_size, range_min=0, range_max=0.5):
        """Rejection-sample ``target_size`` logit-normal values inside ``[range_min, range_max]``.

        Raises:
            ValueError: if the range is empty or lies entirely outside (0, 1),
                where no logit-normal sample can fall and the rejection loop
                would never terminate.
        """
        if range_min >= range_max or range_max <= 0 or range_min >= 1:
            raise ValueError(
                f"Invalid sampling range [{range_min}, {range_max}]; "
                "it must be non-empty and overlap (0, 1)"
            )

        accepted = []
        n_accepted = 0
        while n_accepted < target_size:
            generated_samples = self.sample_logit_normal(mu, sigma, size=target_size)
            filtered_samples = generated_samples[
                (generated_samples >= range_min) & (generated_samples <= range_max)
            ]
            accepted.append(filtered_samples)
            n_accepted += filtered_samples.numel()

        if not accepted:
            # target_size <= 0: nothing was drawn.
            return th.empty(0, dtype=th.float32)

        # If we have more than the target size, truncate
        return th.cat(accepted)[:target_size]

    def sample(self, x1, sp_timesteps=None, shifted_mu=0):
        """Sampling x0 & t based on shape of x1 (if needed)

        Args:
            x1 - data point; [batch, *dim]
            sp_timesteps - optional ``(low, high)`` pair; when given, t is
                redrawn uniformly from this range (for validation)
            shifted_mu - mean of the logit-normal when ``shift_lg`` is set

        Returns:
            Tuple ``(t, x0, x1)`` with ``t`` moved to the device/dtype of ``x1``.
        """

        x0 = th.randn_like(x1)
        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
        batch = x1.shape[0]
        if not self.use_lognorm:
            if self.partitial_train is not None and th.rand(1) < self.partial_ratio:
                t = (
                    th.rand((batch,)) * (self.partitial_train[1] - self.partitial_train[0])
                    + self.partitial_train[0]
                )
            else:
                t = th.rand((batch,)) * (t1 - t0) + t0
        else:
            # random < partial_ratio, then sample from the partial range
            if not self.shift_lg:
                if self.partitial_train is not None and th.rand(1) < self.partial_ratio:
                    t = self.sample_in_range(
                        0,
                        1,
                        batch,
                        range_min=self.partitial_train[0],
                        range_max=self.partitial_train[1],
                    )
                else:
                    t = self.sample_logit_normal(0, 1, size=batch) * (t1 - t0) + t0
            else:
                assert (
                    self.partitial_train is None
                ), "Shifted lognormal distribution is not compatible with partial training"
                t = self.sample_logit_normal(shifted_mu, 1, size=batch) * (t1 - t0) + t0

        # overwrite t if sp_timesteps is provided (for validation)
        if sp_timesteps is not None:
            # uniform sampling between sp_timesteps[0] and sp_timesteps[1]
            t = th.rand((batch,)) * (sp_timesteps[1] - sp_timesteps[0]) + sp_timesteps[0]

        t = t.to(x1)
        return t, x0, x1

    def training_losses(
        self,
        model,
        x1,
        model_kwargs=None,
        sp_timesteps=None,
        shifted_mu=0,
        return_model_output=False,
    ):
        """Loss for training the score model

        Args:
            - model: backbone model; could be score, noise, or velocity
            - x1: datapoint
            - model_kwargs: additional arguments for the model
            - sp_timesteps: optional ``(low, high)`` range forwarded to ``sample``
            - shifted_mu: logit-normal mean forwarded to ``sample``
            - return_model_output: also return the raw model output

        Returns:
            Dict with "pred" and "loss" (plus "cos_loss" when enabled); when
            ``return_model_output`` is True, a tuple ``(terms, raw_model_output)``.
        """
        if model_kwargs is None:
            model_kwargs = {}

        t, x0, x1 = self.sample(x1, sp_timesteps, shifted_mu)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)
        raw_model_output = model(xt, t, **model_kwargs)
        # Models may return extra outputs; the first element is the prediction.
        if isinstance(raw_model_output, tuple):
            model_output = raw_model_output[0]
        else:
            model_output = raw_model_output
        B, *_, C = xt.shape
        # the channel dim is the one with least size
        channel_dim = min([i for i in range(1, len(xt.shape))], key=lambda x: xt.size(x))
        assert model_output.size() == (B, *xt.size()[1:-1], C)

        terms = {}
        terms["pred"] = model_output
        if self.model_type == ModelType.VELOCITY:
            terms["loss"] = mean_flat(((model_output - ut) ** 2))
            if self.use_cosine_loss:
                terms["cos_loss"] = mean_flat(
                    1 - th.nn.functional.cosine_similarity(model_output, ut, dim=channel_dim)
                )
        else:
            _, drift_var = self.path_sampler.compute_drift(xt, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
            if self.loss_type in [WeightType.VELOCITY]:
                weight = (drift_var / sigma_t) ** 2
            elif self.loss_type in [WeightType.LIKELIHOOD]:
                weight = drift_var / (sigma_t**2)
            elif self.loss_type in [WeightType.NONE]:
                weight = 1
            else:
                raise NotImplementedError()

            if self.model_type == ModelType.NOISE:
                terms["loss"] = mean_flat(weight * ((model_output - x0) ** 2))
            else:
                terms["loss"] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
        if return_model_output:
            return terms, raw_model_output
        return terms

    def get_drift(self):
        """member function for obtaining the drift of the probability flow ODE"""

        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return -drift_mean + drift_var * model_output  # by change of variable

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            score = model_output / -sigma_t
            return -drift_mean + drift_var * score

        def velocity_ode(x, t, model, **model_kwargs):
            model_output = model(x, t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn

    def get_score(
        self,
    ):
        """member function for obtaining score of
        x_t = alpha_t * x + sigma_t * eps"""
        if self.model_type == ModelType.NOISE:

            def score_fn(x, t, model, **kwargs):
                sigma_t = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
                return model(x, t, **kwargs) / -sigma_t

        elif self.model_type == ModelType.SCORE:

            def score_fn(x, t, model, **kwargs):
                return model(x, t, **kwargs)

        elif self.model_type == ModelType.VELOCITY:

            def score_fn(x, t, model, **kwargs):
                return self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)

        else:
            raise NotImplementedError()

        return score_fn
303 |
304 |
class Sampler:
    """Sampler class for the transport model"""

    def __init__(
        self,
        transport,
    ):
        """Constructor for a general sampler; supporting different sampling methods

        Args:
            - transport: a transport object specifying model prediction & interpolant type
        """

        self.transport = transport
        self.drift = self.transport.get_drift()
        self.score = self.transport.get_score()

    def __get_sde_diffusion_and_drift(
        self,
        *,
        diffusion_form="SBDM",
        diffusion_norm=1.0,
    ):
        """Build the SDE drift and diffusion callables.

        The SDE drift is the ODE drift plus ``diffusion * score``.
        """

        def diffusion_fn(x, t):
            diffusion = self.transport.path_sampler.compute_diffusion(
                x, t, form=diffusion_form, norm=diffusion_norm
            )
            return diffusion

        def sde_drift(x, t, model, **kwargs):
            return self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(
                x, t, model, **kwargs
            )

        sde_diffusion = diffusion_fn

        return sde_drift, sde_diffusion

    def __get_last_step(
        self,
        sde_drift,
        *,
        last_step,
        last_step_size,
    ):
        """Get the last step function of the SDE solver

        Args:
            - sde_drift: SDE drift callable used by the "Mean" step
            - last_step: None (identity), "Mean", "Tweedie" or "Euler"
            - last_step_size: step size used by "Mean" and "Euler"

        Raises:
            NotImplementedError: for any other ``last_step`` value.
        """

        if last_step is None:
            last_step_fn = lambda x, t, model, **model_kwargs: x
        elif last_step == "Mean":
            last_step_fn = (
                lambda x, t, model, **model_kwargs: x
                + sde_drift(x, t, model, **model_kwargs) * last_step_size
            )
        elif last_step == "Tweedie":
            alpha = (
                self.transport.path_sampler.compute_alpha_t
            )  # simple aliasing; the original name was too long
            sigma = self.transport.path_sampler.compute_sigma_t
            # Coefficients are read from the first batch entry of t.
            last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + (
                sigma(t)[0][0] ** 2
            ) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
        elif last_step == "Euler":
            last_step_fn = (
                lambda x, t, model, **model_kwargs: x
                + self.drift(x, t, model, **model_kwargs) * last_step_size
            )
        else:
            raise NotImplementedError()

        return last_step_fn

    def sample_sde(
        self,
        *,
        sampling_method="Euler",
        diffusion_form="SBDM",
        diffusion_norm=1.0,
        last_step="Mean",
        last_step_size=0.04,
        num_steps=250,
    ):
        """returns a sampling function with given SDE settings

        Args:
            - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
            - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
            - diffusion_norm: function magnitude of diffusion coefficient; default to 1
            - last_step: type of the last step; default to "Mean" (None means identity)
            - last_step_size: size of the last step; default to 0.04
            - num_steps: total integration step of SDE

        Returns:
            ``_sample(init, model, **model_kwargs)`` returning the list of
            ``num_steps`` states (solver states plus the final last-step state).
        """

        if last_step is None:
            last_step_size = 0.0

        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
            diffusion_form=diffusion_form,
            diffusion_norm=diffusion_norm,
        )

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            diffusion_form=diffusion_form,
            sde=True,
            eval=True,
            reverse=False,
            last_step_size=last_step_size,
        )

        _sde = sde(
            sde_drift,
            sde_diffusion,
            t0=t0,
            t1=t1,
            num_steps=num_steps,
            sampler_type=sampling_method,
        )

        last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)

        def _sample(init, model, **model_kwargs):
            xs = _sde.sample(init, model, **model_kwargs)
            ts = th.ones(init.size(0), device=init.device) * t1
            x = last_step_fn(xs[-1], ts, model, **model_kwargs)
            xs.append(x)

            assert len(xs) == num_steps, "Samples does not match the number of steps"

            return xs

        return _sample

    def sample_ode(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
        reverse=False,
        timestep_shift=0.0,
    ):
        """returns a sampling function with given ODE settings

        Args:
            - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
            - num_steps:
                - fixed solver (Euler, Heun): the actual number of integration steps performed
                - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
            - atol: absolute error tolerance for the solver
            - rtol: relative error tolerance for the solver
            - reverse: whether solving the ODE in reverse (data to noise); default to False
            - timestep_shift: forwarded to the ODE solver; positive values remap the time grid
        """
        # NOTE(review): with reverse=True, check_interval returns t0 > t1, which
        # the `ode` constructor's forward-time assertion appears to reject.
        # Confirm the intended reverse-time usage before relying on it.
        if reverse:
            drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
        else:
            drift = self.drift

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=reverse,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
            timestep_shift=timestep_shift,
        )

        return _ode.sample

    def sample_ode_likelihood(
        self,
        *,
        sampling_method="dopri5",
        num_steps=50,
        atol=1e-6,
        rtol=1e-3,
    ):
        """returns a sampling function for calculating likelihood with given ODE settings

        Args:
            - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
            - num_steps:
                - fixed solver (Euler, Heun): the actual number of integration steps performed
                - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
            - atol: absolute error tolerance for the solver
            - rtol: relative error tolerance for the solver

        Returns:
            ``_sample_fn(x, model, **model_kwargs)`` returning ``(logp, final_state)``.
        """

        def _likelihood_drift(x, t, model, **model_kwargs):
            x, _ = x
            # Random +/-1 probe vector for the divergence estimate.
            eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
            t = th.ones_like(t) * (1 - t)
            with th.enable_grad():
                x.requires_grad = True
                grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
                logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
                drift = self.drift(x, t, model, **model_kwargs)
            return (-drift, logp_grad)

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            sde=False,
            eval=True,
            reverse=False,
            last_step_size=0.0,
        )

        _ode = ode(
            drift=_likelihood_drift,
            t0=t0,
            t1=t1,
            sampler_type=sampling_method,
            num_steps=num_steps,
            atol=atol,
            rtol=rtol,
            # The solver requires this argument; likelihood evaluation uses
            # the unshifted time grid.
            timestep_shift=0.0,
        )

        def _sample_fn(x, model, **model_kwargs):
            init_logp = th.zeros(x.size(0)).to(x)
            state = (x, init_logp)
            drift, delta_logp = _ode.sample(state, model, **model_kwargs)
            # Keep only the final time point of each trajectory.
            drift, delta_logp = drift[-1], delta_logp[-1]
            prior_logp = self.transport.prior_logp(drift)
            logp = prior_logp - delta_logp
            return logp, drift

        return _sample_fn
541 |
--------------------------------------------------------------------------------
/utils/builders.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch.utils.data
4 | import torchvision.transforms as transforms
5 |
6 | import models
7 | import utils.distributed as distributed
8 | import utils.losses as losses
9 | from utils.loader import ListDataset, center_crop_arr
10 | from utils.misc import NativeScalerWithGradNormCount
11 |
12 | logger = logging.getLogger("DeTok")
13 |
14 |
def create_train_dataloader(args, should_flip=True, batch_size=-1, return_path=False, drop_last=True):
    """Build the shuffled, distributed training data loader.

    Args:
        args: options object; reads ``img_size``, ``use_cached_tokens``,
            ``data_path``, ``batch_size``, ``num_workers`` and ``pin_mem``.
        should_flip: forwarded to ``ListDataset``.
        batch_size: per-loader batch size; negative values fall back to ``args.batch_size``.
        return_path: forwarded to ``ListDataset``.
        drop_last: forwarded to the ``DataLoader``.

    Returns:
        The training ``DataLoader``.
    """
    if args.use_cached_tokens:
        # Cached tokens are loaded from .npz files and need no image transform.
        input_transform = None
        loader_name = "npz_loader"
    else:
        input_transform = transforms.Compose(
            [
                transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )
        loader_name = "img_loader"

    dataset_train = ListDataset(
        args.data_path,
        data_list="data/train.txt",
        transform=input_transform,
        loader_name=loader_name,
        return_label=True,
        return_path=return_path,
        should_flip=should_flip,
    )
    logger.info(f"Train dataset size: {len(dataset_train)}")

    sampler = torch.utils.data.DistributedSampler(
        dataset_train,
        num_replicas=distributed.get_world_size(),
        rank=distributed.get_global_rank(),
        shuffle=True,
    )

    if batch_size < 0:
        effective_batch_size = args.batch_size
    else:
        effective_batch_size = batch_size

    return torch.utils.data.DataLoader(
        dataset_train,
        sampler=sampler,
        batch_size=effective_batch_size,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=drop_last,
    )
50 |
51 |
def create_val_dataloader(args):
    """Build the unshuffled, distributed validation data loader.

    Args:
        args: options object; reads ``img_size``, ``data_path``, ``eval_bsz``,
            ``num_workers`` and ``pin_mem``.

    Returns:
        The validation ``DataLoader`` (``drop_last=False``).
    """
    preprocessing = [
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    # The validation root is derived from the training path by substring replacement.
    val_root = args.data_path.replace("train", "val")
    dataset_val = ListDataset(
        val_root,
        data_list="data/val.txt",
        transform=transforms.Compose(preprocessing),
        loader_name="img_loader",
        return_label=False,
        return_index=True,
        should_flip=False,
    )
    sampler = torch.utils.data.DistributedSampler(
        dataset_val,
        num_replicas=distributed.get_world_size(),
        rank=distributed.get_global_rank(),
        shuffle=False,
    )

    logger.info(f"Val dataset size: {len(dataset_val)}")

    return torch.utils.data.DataLoader(
        dataset_val,
        sampler=sampler,
        batch_size=args.eval_bsz,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )
87 |
88 |
def create_vis_dataloader(args):
    """Build the shuffled, distributed data loader used for visualisation.

    Reads training images restricted by ``args.class_of_interest`` and
    yields fixed batches of 8.

    Args:
        args: options object; reads ``img_size``, ``data_path``,
            ``class_of_interest``, ``num_workers`` and ``pin_mem``.

    Returns:
        The visualisation ``DataLoader`` (``drop_last=False``).
    """
    preprocessing = [
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    dataset_vis = ListDataset(
        args.data_path,
        data_list="data/train.txt",
        transform=transforms.Compose(preprocessing),
        loader_name="img_loader",
        return_label=False,
        return_index=True,
        class_of_interest=args.class_of_interest,
    )
    sampler = torch.utils.data.DistributedSampler(
        dataset_vis,
        num_replicas=distributed.get_world_size(),
        rank=distributed.get_global_rank(),
        shuffle=True,
    )

    logger.info(f"Vis dataset size: {len(dataset_vis)}")

    return torch.utils.data.DataLoader(
        dataset_vis,
        sampler=sampler,
        batch_size=8,
        num_workers=args.num_workers,
        pin_memory=args.pin_mem,
        drop_last=False,
    )
124 |
125 |
def create_generation_model(args):
    """Instantiate the generative model, its optional frozen tokenizer, and an EMA wrapper.

    The tokenizer (``args.tokenizer``) may be a VAE or a DeTok model; when
    ``args.load_tokenizer_from`` is set its weights are loaded from that checkpoint,
    preferring the ``model_ema`` entry if ``args.use_ema_tokenizer`` is set and the
    entry exists. The tokenizer is moved to CUDA, put in eval mode and frozen.

    The generator (``args.model``) is looked up in the DiT, SiT, LightningDiT, ARDiff
    and MAR registries, in that order.

    Returns:
        ``(model, tokenizer, ema)``; ``tokenizer`` is ``None`` when ``args.tokenizer`` is ``None``.

    Raises:
        ValueError: if the tokenizer or model name is not in any registry.
    """
    logger.info("Creating generation models.")

    if args.tokenizer is None:
        tokenizer = None
    else:
        if args.tokenizer in models.VAE_models:
            tokenizer = models.VAE_models[args.tokenizer]()
        elif args.tokenizer in models.DeTok_models:
            tokenizer = models.DeTok_models[args.tokenizer](
                img_size=args.img_size,
                patch_size=args.tokenizer_patch_size,
                token_channels=args.token_channels,
                mask_ratio=0.0,
            )
        else:
            raise ValueError(f"Unsupported tokenizer {args.tokenizer}")

        if args.load_tokenizer_from is not None:
            logger.info(f"[Tokenizer] Loading tokenizer from: {args.load_tokenizer_from}")
            state = torch.load(args.load_tokenizer_from, weights_only=False, map_location="cpu")
            use_ema_weights = args.use_ema_tokenizer and "model_ema" in state
            if use_ema_weights:
                state = state["model_ema"]
            else:
                if args.use_ema_tokenizer:
                    logger.warning("EMA tokenizer is not in the checkpoint, using the model weights")
                if "model" in state:
                    state = state["model"]
            # EMA weights are loaded leniently, raw model weights strictly
            msg = tokenizer.load_state_dict(state, strict=not use_ema_weights)
            logger.info(f"[Tokenizer] Missing keys: {msg.missing_keys}")
            logger.info(f"[Tokenizer] Unexpected keys: {msg.unexpected_keys}")
            if use_ema_weights:
                logger.info("[Tokenizer] Loaded EMA tokenizer.")

        tokenizer.cuda().eval().requires_grad_(False)
        logger.info("====Tokenizer=====")
        logger.info(tokenizer)

    def _shared_kwargs():
        # constructor arguments accepted by every supported model family
        return dict(
            img_size=args.img_size,
            patch_size=args.patch_size,
            tokenizer_patch_size=args.tokenizer_patch_size,
            token_channels=args.token_channels,
            label_drop_prob=args.label_drop_prob,
            num_classes=args.num_classes,
            num_sampling_steps=args.num_sampling_steps,
            force_one_d_seq=args.force_one_d_seq,
            grad_checkpointing=args.grad_checkpointing,
        )

    if args.model in models.DiT_models:
        model = models.DiT_models[args.model](
            legacy_mode=args.legacy_mode,  # legacy mode: cfg on the first three channels only
            **_shared_kwargs(),
        )
    elif args.model in models.SiT_models:
        model = models.SiT_models[args.model](
            legacy_mode=args.legacy_mode,  # legacy mode: cfg on the first three channels only
            qk_norm=args.qk_norm,
            **_shared_kwargs(),
        )
    elif args.model in models.LightningDiT_models:
        model = models.LightningDiT_models[args.model](
            legacy_mode=args.legacy_mode,  # legacy mode: cfg on the first three channels only
            qk_norm=args.qk_norm,
            **_shared_kwargs(),
        )
    elif args.model in models.ARDiff_models:
        model = models.ARDiff_models[args.model](
            diffloss_d=args.diffloss_d,
            diffloss_w=args.diffloss_w,
            diffusion_batch_mul=args.diffusion_batch_mul,
            noise_schedule=args.noise_schedule,
            order=args.order,
            **_shared_kwargs(),
        )
    elif args.model in models.MAR_models:
        model = models.MAR_models[args.model](
            diffloss_d=args.diffloss_d,
            diffloss_w=args.diffloss_w,
            diffusion_batch_mul=args.diffusion_batch_mul,
            noise_schedule=args.noise_schedule,
            attn_dropout=args.attn_dropout,
            proj_dropout=args.proj_dropout,
            buffer_size=args.buffer_size,
            mask_ratio_min=args.mask_ratio_min,
            no_dropout_in_mlp=args.no_dropout_in_mlp,
            **_shared_kwargs(),
        )
    else:
        raise ValueError(f"Unsupported model {args.model}")

    model.cuda()
    logger.info("====Model=====")
    logger.info(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"{args.model} Parameters: {n_params / 1e6:.2f}M ({n_params:,})")

    # exponential moving average of the generator weights
    ema = models.SimpleEMAModel(model, decay=args.ema_rate)
    return model, tokenizer, ema
253 |
254 |
def create_reconstruction_model(args):
    """Instantiate the reconstruction (tokenizer) model and its EMA wrapper.

    ``args.model`` is looked up first among the VAE models, then among the DeTok
    models. The model is moved to CUDA and its trainable parameter count is logged.

    Returns:
        ``(model, ema)``.

    Raises:
        ValueError: if ``args.model`` is in neither registry.
    """
    logger.info("Creating reconstruction models.")

    name = args.model
    if name in models.VAE_models:
        # checkpoint loading is on unless args carries a truthy ``no_load_ckpt``
        skip_ckpt = getattr(args, "no_load_ckpt", False)
        model = models.VAE_models[name](load_ckpt=not skip_ckpt, gamma=args.gamma)
    elif name in models.DeTok_models:
        model = models.DeTok_models[name](
            img_size=args.img_size,
            patch_size=args.patch_size,
            token_channels=args.token_channels,
            mask_ratio=args.mask_ratio,
            gamma=args.gamma,
        )
    else:
        raise ValueError(f"Unsupported model {args.model}")

    model.cuda()
    logger.info("====Model=====")
    logger.info(model)

    n_params = 0
    for param in model.parameters():
        if param.requires_grad:
            n_params += param.numel()
    logger.info(f"{args.model} Trainable Parameters: {n_params / 1e6:.2f}M ({n_params:,})")

    ema = models.SimpleEMAModel(model, decay=args.ema_rate)
    return model, ema
280 |
281 |
def create_optimizer_and_scaler(args, model, print_trainable_params=False):
    """Create an AdamW optimizer with two parameter groups, plus the loss scaler.

    Trainable parameters with fewer than two dimensions, or whose name contains one
    of the no-decay keywords, get zero weight decay; the rest use ``args.weight_decay``.

    Side effect: when ``args.lr`` is ``None`` it is set to
    ``args.blr * (args.batch_size * args.world_size) / 256``.

    Returns:
        ``(optimizer, loss_scaler)``.
    """
    logger.info("creating optimizers")

    no_decay_keywords = ("ln", "bias", "embedding", "norm", "gamma", "embed", "token", "diffloss")

    def _skip_weight_decay(name, param):
        if param.ndim < 2:
            return True
        return any(keyword in name for keyword in no_decay_keywords)

    # partition the trainable parameters in a single pass
    no_decay_list, rest_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if _skip_weight_decay(name, param):
            no_decay_list.append(param)
        else:
            rest_params.append(param)

    eff_batch_size = args.batch_size * args.world_size
    if args.lr is None:
        # linear scaling rule relative to a reference batch size of 256
        args.lr = args.blr * eff_batch_size / 256

    logger.info(f"base lr: {args.lr * 256 / eff_batch_size:.6e}")
    logger.info(f"actual lr: {args.lr:.6e}")
    logger.info(f"effective batch size: {eff_batch_size}")
    logger.info(f"training with {args.world_size} gpus")
    logger.info(f"weight_decay: {args.weight_decay} on {len(rest_params)} weight tensors")
    logger.info(f"no_decay: {len(no_decay_list)} weight tensors")

    param_groups = [
        {"params": no_decay_list, "weight_decay": 0.0},
        {"params": rest_params, "weight_decay": args.weight_decay},
    ]
    optimizer = torch.optim.AdamW(param_groups, lr=args.lr, betas=(args.beta1, args.beta2))
    logger.info(f"Optimizer = {str(optimizer)}")

    if print_trainable_params:
        logger.info("trainable parameters:")
        trainable_names = (n for n, p in model.named_parameters() if p.requires_grad)
        for name in trainable_names:
            logger.info(f"\t{name}")

    loss_scaler = NativeScalerWithGradNormCount()
    logger.info(f"Loss Scaler = {str(loss_scaler)}")
    return optimizer, loss_scaler
324 |
325 |
def create_loss_module(args):
    """Build the reconstruction loss and move it to CUDA.

    Optional settings fall back to defaults when absent from ``args``;
    ``args.kl_loss_weight`` is required.
    """
    optional_settings = {
        "discriminator_start_epoch": 20,
        "perceptual_loss": "lpips-convnext_s-1.0-0.1",
        "perceptual_weight": 1.1,
    }
    loss_kwargs = {key: getattr(args, key, default) for key, default in optional_settings.items()}
    loss_module = losses.ReconstructionLoss(kl_weight=args.kl_loss_weight, **loss_kwargs)
    loss_module.cuda()
    logger.info("====Loss Module=====")
    return loss_module
337 |
--------------------------------------------------------------------------------
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import os
3 | import random
4 | import re
5 | import socket
6 | import sys
7 |
8 | import torch
9 | import torch.distributed as dist
10 | import torch.distributed.nn as dist_nn
11 |
12 | _LOCAL_RANK = -1
13 | _LOCAL_WORLD_SIZE = -1
14 |
15 | _TORCH_DISTRIBUTED_ENV_VARS = (
16 | "MASTER_ADDR",
17 | "MASTER_PORT",
18 | "RANK",
19 | "WORLD_SIZE",
20 | "LOCAL_RANK",
21 | "LOCAL_WORLD_SIZE",
22 | )
23 |
24 |
def all_reduce_mean(x):
    """Average a scalar across all ranks.

    With a single process the input is returned untouched. Otherwise the value is
    copied to CUDA, sum-reduced over the process group, divided by the world size
    and returned as a Python number via ``.item()``.
    """
    world_size = get_world_size()
    if world_size <= 1:
        return x

    if isinstance(x, torch.Tensor):
        # work on a detached copy so the caller's tensor is not modified in place
        buffer = x.clone().detach().cuda()
    else:
        buffer = torch.tensor(x).cuda()
    dist.all_reduce(buffer)
    mean = buffer.float() / world_size
    return mean.item()
36 |
37 |
def concat_all_gather(tensor, gather_dim=0) -> torch.Tensor:
    """Gather ``tensor`` from every rank and concatenate along ``gather_dim``.

    Uses the autograd-aware ``torch.distributed.nn.functional.all_gather``.
    When running as a single process (including when no process group has been
    initialised) the input tensor is returned unchanged.
    """
    # Use the module helper rather than dist.get_world_size(): the latter raises
    # when the default process group has not been initialised.
    if get_world_size() == 1:
        return tensor
    gathered = dist_nn.functional.all_gather(tensor)
    return torch.cat(gathered, dim=gather_dim)
43 |
44 |
def is_enabled() -> bool:
    """Return True when torch.distributed is available and a process group is initialised."""
    if not dist.is_available():
        return False
    return dist.is_initialized()
47 |
48 |
def get_global_rank() -> int:
    """Return this process's global rank, or 0 when distributed mode is not enabled."""
    if is_enabled():
        return dist.get_rank()
    return 0
51 |
52 |
def get_world_size():
    """Return the number of processes in the group, or 1 when distributed mode is not enabled."""
    if is_enabled():
        return dist.get_world_size()
    return 1
55 |
56 |
def is_main_process() -> bool:
    """Return True on the process whose global rank is 0."""
    rank = get_global_rank()
    return rank == 0
59 |
60 |
def _is_slurm_job_process() -> bool:
    """Return True when running inside a SLURM job with a non-interactive stdout.

    ``sys.stdout.isatty()`` is used instead of ``os.isatty(sys.stdout.fileno())``
    because replaced or captured stdout streams (notebooks, in-memory buffers) may
    not expose a file descriptor; such streams are simply reported as "not a tty".
    """
    if "SLURM_JOB_ID" not in os.environ:
        return False
    return not sys.stdout.isatty()
63 |
64 |
def _parse_slurm_node_list(s: str) -> list[str]:
    """Expand a SLURM compressed node list into individual host names.

    Handles comma-separated entries of the form ``host`` or ``prefix[a-b,c,...]``,
    e.g. ``"login1,gpu[01-03,07]"`` -> ``["login1", "gpu01", "gpu02", "gpu03", "gpu07"]``.
    Numeric ranges keep the zero-padding width of their lower bound.
    """
    nodes = []
    # The prefix must stop at "," as well as "[": otherwise plain comma-separated
    # hosts ("a,b") are swallowed into a single bogus prefix.
    pattern = re.compile(r"([^\[,]+)(?:\[([^\]]+)\])?,?")
    for match in pattern.finditer(s):
        prefix = match.group(1)
        # no bracket group means a single host with an empty suffix
        suffixes = match.group(2) or ""
        for suffix in suffixes.split(","):
            span = suffix.split("-")
            if len(span) == 1:
                nodes.append(prefix + suffix)
            else:
                width = len(span[0])
                start, end = int(span[0]), int(span[1]) + 1
                nodes.extend(f"{prefix}{i:0{width}}" for i in range(start, end))
    return nodes
79 |
80 |
def _get_available_port() -> int:
    """Ask the OS for a currently free TCP port and return its number.

    The probing socket is closed before returning, so the port is only known to
    have been free at the time of the call.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with probe:
        # port 0 lets the OS choose any unused port
        probe.bind(("", 0))
        _, port = probe.getsockname()
    return port
85 |
86 |
@functools.lru_cache
def enable_distributed():
    """Initialise the NCCL process group, filling in rendezvous variables if needed.

    * Inside a non-interactive SLURM job the torch.distributed environment
      variables are derived from the SLURM ones; the master port is drawn from an
      RNG seeded with the job id so every task picks the same port.
    * Otherwise, when ``MASTER_ADDR`` is unset, a single-process local setup is used.
    * Otherwise the pre-existing environment is left as is.

    The ``lru_cache`` decorator makes repeated calls no-ops.
    """
    env = os.environ
    if _is_slurm_job_process():
        env["MASTER_ADDR"] = _parse_slurm_node_list(env["SLURM_JOB_NODELIST"])[0]
        job_rng = random.Random(env["SLURM_JOB_ID"])
        env["MASTER_PORT"] = str(job_rng.randint(20_000, 60_000))
        env["RANK"] = env["SLURM_PROCID"]
        env["WORLD_SIZE"] = env["SLURM_NTASKS"]
        env["LOCAL_RANK"] = env["SLURM_LOCALID"]
        tasks_per_node = int(env["WORLD_SIZE"]) // int(env["SLURM_JOB_NUM_NODES"])
        env["LOCAL_WORLD_SIZE"] = str(tasks_per_node)
    elif "MASTER_ADDR" not in env:
        env["MASTER_ADDR"] = "127.0.0.1"
        env["MASTER_PORT"] = str(_get_available_port())
        env["RANK"] = "0"
        env["WORLD_SIZE"] = "1"
        env["LOCAL_RANK"] = "0"
        env["LOCAL_WORLD_SIZE"] = "1"

    local_rank = int(env["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl")
    dist.barrier(device_ids=[local_rank])
108 |
--------------------------------------------------------------------------------
/utils/download.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import requests
4 | from tqdm import tqdm
5 |
6 |
def _download_checkpoint(url, download_path, description, total_mb, overwrite=False):
    """Stream ``url`` to ``download_path`` unless the file already exists.

    Args:
        url: HTTP(S) location of the checkpoint.
        download_path: destination file; parent directories are created.
        description: message printed before the transfer starts.
        total_mb: expected size in MiB chunks, used only for the progress bar.
        overwrite: re-download even when ``download_path`` exists.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if os.path.exists(download_path) and not overwrite:
        return

    os.makedirs(os.path.dirname(download_path), exist_ok=True)
    headers = {"user-agent": "Wget/1.16 (linux-gnu)"}
    # Write to a side file first so an interrupted download is never mistaken
    # for a complete checkpoint on the next run.
    partial_path = download_path + ".part"
    with requests.get(url, stream=True, headers=headers) as r:
        # fail loudly instead of saving an HTML error page as the checkpoint
        r.raise_for_status()
        print(description)
        with open(partial_path, "wb") as f:
            for chunk in tqdm(r.iter_content(chunk_size=1024 * 1024), unit="MB", total=total_mb):
                if chunk:
                    f.write(chunk)
    os.replace(partial_path, download_path)


def download_pretrained_vae(overwrite=False):
    """Download the KL-16 VAE checkpoint to ``pretrained_models/vae/kl16.ckpt``."""
    _download_checkpoint(
        "https://www.dropbox.com/scl/fi/hhmuvaiacrarfg28qxhwz/kl16.ckpt?rlkey=l44xipsezc8atcffdp4q7mwmh&dl=0",
        "pretrained_models/vae/kl16.ckpt",
        "Downloading KL-16 VAE...",
        total_mb=254,
        overwrite=overwrite,
    )


def download_pretrained_marb(overwrite=False):
    """Download the MAR-B checkpoint to ``pretrained_models/mar/mar_base/checkpoint-last.pth``."""
    _download_checkpoint(
        "https://www.dropbox.com/scl/fi/f6dpuyjb7fudzxcyhvrhk/checkpoint-last.pth?rlkey=a6i4bo71vhfo4anp33n9ukujb&dl=0",
        "pretrained_models/mar/mar_base/checkpoint-last.pth",
        "Downloading MAR-B...",
        total_mb=1587,
        overwrite=overwrite,
    )


def download_pretrained_marl(overwrite=False):
    """Download the MAR-L checkpoint to ``pretrained_models/mar/mar_large/checkpoint-last.pth``."""
    _download_checkpoint(
        "https://www.dropbox.com/scl/fi/pxacc5b2mrt3ifw4cah6k/checkpoint-last.pth?rlkey=m48ovo6g7ivcbosrbdaz0ehqt&dl=0",
        "pretrained_models/mar/mar_large/checkpoint-last.pth",
        "Downloading MAR-L...",
        total_mb=3650,
        overwrite=overwrite,
    )


def download_pretrained_marh(overwrite=False):
    """Download the MAR-H checkpoint to ``pretrained_models/mar/mar_huge/checkpoint-last.pth``."""
    _download_checkpoint(
        "https://www.dropbox.com/scl/fi/1qmfx6fpy3k7j9vcjjs3s/checkpoint-last.pth?rlkey=4lae281yzxb406atp32vzc83o&dl=0",
        "pretrained_models/mar/mar_huge/checkpoint-last.pth",
        "Downloading MAR-H...",
        total_mb=7191,
        overwrite=overwrite,
    )
73 |
74 |
if __name__ == "__main__":
    # fetch every published checkpoint, one after another
    for fetch in (
        download_pretrained_vae,
        download_pretrained_marb,
        download_pretrained_marl,
        download_pretrained_marh,
    ):
        fetch()
80 |
--------------------------------------------------------------------------------
/utils/loader.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from functools import partial
4 | from typing import Any, Callable
5 |
6 | import numpy as np
7 | from PIL import Image
8 | from torchvision.datasets.folder import default_loader
9 |
10 | logger = logging.getLogger("DeTok")
11 |
12 | CONSTANTS = {
13 | # vavae latent statistics from https://huggingface.co/hustvl/vavae-imagenet256-f16d32-dinov2/blob/main/latents_stats.pt
14 | "vavae_mean": np.array([
15 | 0.5984623, -0.49917176, 0.6440029, -0.0970839, -1.190963, -1.4331622,
16 | 0.46853292, 0.6259252, 0.63195026, -0.4896733, -0.74451625, 1.1595623,
17 | 0.8456217, 0.5008238, 0.22926894, 0.47535565, -0.43787342, 0.8316961,
18 | -0.0750857, 0.30632293, 0.46645293, -0.09140775, -0.82710165, 0.07807512,
19 | 1.4150785, 1.3792385, 0.2695843, -0.7573224, 0.28129938, -0.30919993,
20 | 0.07785388, 0.34966648,
21 | ]),
22 | "vavae_std": np.array([
23 | 3.846138, 4.2699146, 3.5768437, 3.5911105, 3.6230576, 3.481018,
24 | 3.3074617, 3.5092657, 3.5540583, 3.6067245, 3.70579, 3.6314075,
25 | 3.6295316, 3.620502, 3.2590282, 3.186753, 3.8258142, 3.599939,
26 | 3.2966352, 3.226129, 3.2191944, 3.1054573, 3.580496, 4.356914,
27 | 3.308541, 3.2075875, 4.515047, 3.4869924, 3.0415804, 3.4868848,
28 | 4.4310327, 4.0881157,
29 | ]),
30 | }
31 |
32 |
def center_crop_arr(pil_image: Image.Image, image_size: int) -> Image.Image:
    """Resize so the short side equals ``image_size``, then take the central square.

    Center cropping implementation from ADM:
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # repeatedly halve with a box filter while the short side is at least twice the target
    while min(pil_image.size) >= 2 * image_size:
        halved = tuple(side // 2 for side in pil_image.size)
        pil_image = pil_image.resize(halved, resample=Image.Resampling.BOX)

    # final bicubic resize brings the short side to image_size
    factor = image_size / min(pil_image.size)
    target = tuple(round(side * factor) for side in pil_image.size)
    pil_image = pil_image.resize(target, resample=Image.Resampling.BICUBIC)

    pixels = np.array(pil_image)
    top = (pixels.shape[0] - image_size) // 2
    left = (pixels.shape[1] - image_size) // 2
    window = pixels[top:top + image_size, left:left + image_size]
    return Image.fromarray(window)
53 |
54 |
def default_np_loader(path: str) -> np.ndarray[Any, np.dtype[Any]]:
    """Load a NumPy file from ``path`` with pickled objects allowed.

    NOTE: ``allow_pickle=True`` can execute arbitrary code from a crafted file;
    only use this on trusted data.
    """
    loaded = np.load(path, allow_pickle=True)
    return loaded
57 |
58 |
59 | class ListDataset:
60 | def __init__(
61 | self,
62 | data_root: str,
63 | data_list: str,
64 | transform: Callable[[Any], Any] | None = None,
65 | loader_name: str = "npz_loader",
66 | return_path: bool = False,
67 | return_label: bool = True,
68 | return_index: bool = False,
69 | should_flip: bool = True,
70 | class_of_interest: list[int] | None = None,
71 | ):
72 | self.data_root = data_root
73 | self.transform = transform
74 | self.return_path = return_path
75 | self.return_label = return_label
76 | self.return_index = return_index
77 | self.should_flip = should_flip
78 | self.class_of_interest = class_of_interest
79 |
80 | # loader function mapping
81 | loader_functions = {
82 | "img_loader": default_loader,
83 | "npz_loader": partial(np.load, allow_pickle=True),
84 | }
85 |
86 | if loader_name not in loader_functions:
87 | raise ValueError(f"Loader '{loader_name}' not supported")
88 |
89 | self.loader = loader_functions[loader_name]
90 | self.load_vae_latents = loader_name == "npz_loader"
91 | self.samples = self._load_samples(data_list, loader_name)
92 | self.targets = [label for _, label in self.samples]
93 |
94 | def _load_samples(self, data_list: str, loader_name: str) -> list[tuple[str, int | None]]:
95 | samples = []
96 | with open(data_list, "r") as f:
97 | for line in f:
98 | splits = line.strip().split(" ")
99 | if len(splits) == 2:
100 | file_path, label = splits
101 | label = int(label)
102 | else:
103 | file_path = line.strip()
104 | label = None
105 |
106 | if self.class_of_interest and label not in self.class_of_interest:
107 | continue
108 |
109 | # adjust file extensions based on loader
110 | if loader_name == "npz_loader":
111 | file_path = file_path.replace(".JPEG", ".JPEG.npz")
112 |
113 | samples.append((file_path, label))
114 | return samples
115 |
116 | def __getitem__(self, index: int) -> dict[str, Any]:
117 | return self._get_item_with_retry(index, 0)
118 |
119 | def _get_item_with_retry(self, index: int, retry_count: int) -> dict[str, Any]:
120 | if retry_count >= 100:
121 | raise RuntimeError(f"Failed to load data after 100 retries, last index: {index}")
122 |
123 | img_pth, label = self.samples[index]
124 | img_path_full = os.path.join(self.data_root, img_pth)
125 | should_flip = np.random.rand() < 0.5 if self.should_flip else False
126 | to_return = {}
127 |
128 | try:
129 | img = self.loader(img_path_full)
130 | if self.load_vae_latents:
131 | img_data = img # type: ignore
132 | img = img_data["moments_flip"] if should_flip else img_data["moments"]
133 | to_return = {"token": img}
134 | except Exception as e:
135 | logger.error(f"Error loading '{img_pth}': {e}")
136 | return self._get_item_with_retry((index + 1) % len(self.samples), retry_count + 1)
137 |
138 | if self.transform is not None:
139 | if "token" in to_return:
140 | # load original image when we have vae latents
141 | img_path_relative = img_path_full.split("/")[3:]
142 | img_path_relative = os.path.join(*img_path_relative)
143 | img_path_relative = img_path_relative.replace(".npz", "")
144 | img_path_full = os.path.join(self.data_root, img_path_relative)
145 | img = default_loader(img_path_full)
146 |
147 | img = self.transform(img)
148 | if should_flip:
149 | img = img.flip(dims=[2])
150 |
151 | if len(to_return) > 0:
152 | to_return["img"] = img
153 | else:
154 | to_return = {"img": img}
155 |
156 | if self.return_index:
157 | to_return["index"] = index
158 | if self.return_label:
159 | to_return["label"] = label
160 | if self.return_path:
161 | to_return["img_pth"] = img_pth
162 |
163 | return to_return
164 |
165 | def __len__(self) -> int:
166 | return len(self.samples)
167 |
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import datetime
3 | import json
4 | import logging
5 | import os
6 | import sys
7 | import time
8 | from collections import defaultdict, deque
9 | from typing import Any
10 |
11 | import torch
12 | import torch.distributed
13 | import wandb
14 | from rich.logging import RichHandler
15 | from typing_extensions import override
16 |
17 | from .distributed import get_global_rank, is_enabled, is_main_process
18 |
19 | logger = logging.getLogger("DeTok")
20 |
21 |
def move_to_device(obj: Any, device: torch.device) -> Any:
    """Return a copy of ``obj`` with every tensor moved to ``device``.

    Dicts, lists and tuples are traversed recursively and rebuilt with the same
    container type; tensors are transferred with ``non_blocking=True``; any other
    object is returned as is.
    """
    if isinstance(obj, torch.Tensor):
        return obj.to(device, non_blocking=True)

    if isinstance(obj, dict):
        moved = {}
        for key, value in obj.items():
            moved[key] = move_to_device(value, device)
        return moved

    if isinstance(obj, (list, tuple)):
        container = type(obj)
        return container(move_to_device(item, device) for item in obj)

    return obj
32 |
33 |
34 | class SmoothedValue:
35 | """track a series of values and provide access to smoothed values over a window or the global series average."""
36 |
37 | def __init__(self, window_size: int = 20, fmt: str | None = None):
38 | if fmt is None:
39 | fmt = "{median:.4f} ({global_avg:.4f})"
40 | self.deque = deque(maxlen=window_size)
41 | self.total = 0.0
42 | self.count = 0
43 | self.fmt = fmt
44 |
45 | def update(self, value: float, num: int = 1) -> None:
46 | self.deque.append(value)
47 | self.count += num
48 | self.total += value * num
49 |
50 | def synchronize_between_processes(self) -> None:
51 | """distributed synchronization of the metric. warning: does not synchronize the deque!"""
52 | if not is_enabled():
53 | return
54 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
55 | torch.distributed.barrier()
56 | torch.distributed.all_reduce(t)
57 | t = t.tolist()
58 | self.count = int(t[0])
59 | self.total = t[1]
60 |
61 | @property
62 | def median(self) -> float:
63 | d = torch.tensor(list(self.deque))
64 | return d.median().item()
65 |
66 | @property
67 | def avg(self) -> float:
68 | d = torch.tensor(list(self.deque), dtype=torch.float32)
69 | return d.mean().item()
70 |
71 | @property
72 | def global_avg(self) -> float:
73 | return self.total / self.count
74 |
75 | @property
76 | def max(self) -> float:
77 | return max(self.deque)
78 |
79 | @property
80 | def value(self) -> float:
81 | return self.deque[-1]
82 |
83 | @override
84 | def __str__(self) -> str:
85 | return self.fmt.format(
86 | median=self.median,
87 | avg=self.avg,
88 | global_avg=self.global_avg,
89 | max=self.max,
90 | value=self.value,
91 | )
92 |
93 |
class MetricLogger:
    """Collects named ``SmoothedValue`` meters and logs progress while iterating.

    Args:
        delimiter: separator between the fields of a log line.
        output_file: optional JSON-lines file written by the main process only.
        prefetch: when True, ``log_every`` moves each yielded object to CUDA.
    """

    def __init__(self, delimiter: str = "\t", output_file: str | None = None, prefetch: bool = False):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.output_file = output_file
        self.prefetch = prefetch
        logger.info(f"MetricLogger: output_file={output_file}, prefetch={prefetch}")

    def update(self, **kwargs) -> None:
        """Record one value per keyword; ``None`` values are ignored."""
        for k, v in kwargs.items():
            if v is None:
                continue
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr: str):
        # Go through __dict__ so that a missing ``meters`` attribute (e.g. while the
        # object is being copied or unpickled) cannot recurse back into __getattr__.
        meters = self.__dict__.get("meters")
        if meters is not None and attr in meters:
            return meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")

    @override
    def __str__(self) -> str:
        return self.delimiter.join(f"{name}: {str(meter)}" for name, meter in self.meters.items())

    def synchronize_between_processes(self) -> None:
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name: str, meter: SmoothedValue) -> None:
        self.meters[name] = meter

    def dump_in_output_file(self, iteration: int, iter_time: float, data_time: float) -> None:
        """Append one JSON line with timings and meter medians (main process only)."""
        if self.output_file is None or not is_main_process():
            return
        dict_to_dump = dict(
            iteration=iteration,
            iter_time=iter_time,
            data_time=data_time,
        )
        dict_to_dump.update({k: v.median for k, v in self.meters.items()})
        with open(self.output_file, "a") as f:
            f.write(json.dumps(dict_to_dump) + "\n")

    def log_every(
        self,
        iterable,
        print_freq: int,
        header: str | None = None,
        n_iterations: int | None = None,
        start_iteration: int = 0,
    ):
        """Yield the items of ``iterable`` while periodically logging progress.

        A line is logged every ``print_freq`` iterations and on the last one.
        Iteration stops once the counter (starting at ``start_iteration``) reaches
        ``n_iterations``, which defaults to ``len(iterable)``.

        Raises:
            ValueError: if ``n_iterations`` is omitted and ``iterable`` has no length.
        """
        i = start_iteration
        if not header:
            header = ""
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt="{avg:.4f}")
        data_time = SmoothedValue(fmt="{avg:.4f}")

        if n_iterations is None:
            try:
                n_iterations = len(iterable)
            except TypeError as err:
                raise ValueError("n_iterations must be provided for iterables without __len__") from err

        space_fmt = ":" + str(len(str(n_iterations))) + "d"
        has_cuda = torch.cuda.is_available()

        log_list = [
            header,
            "[{0" + space_fmt + "}/{1}]",
            "eta: {eta}",
            "elapsed: {elapsed_time_str}",
            "{meters}",
            "time: {time}",
            "data: {data}",
        ]
        if has_cuda:
            log_list += ["max mem: {memory:.0f}"]

        log_msg = self.delimiter.join(log_list)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            if self.prefetch:
                obj = move_to_device(obj, torch.device("cuda"))
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == n_iterations - 1:
                self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg)
                eta_seconds = iter_time.global_avg * (n_iterations - i)
                elapsed_time = time.time() - start_time

                # One field set for both paths: the template always contains
                # {elapsed_time_str}, so it must be supplied without CUDA as well.
                fields = dict(
                    eta=str(datetime.timedelta(seconds=int(eta_seconds))),
                    elapsed_time_str=str(datetime.timedelta(seconds=int(elapsed_time))),
                    meters=str(self),
                    time=str(iter_time),
                    data=str(data_time),
                )
                if has_cuda:
                    fields["memory"] = torch.cuda.max_memory_allocated() / MB
                logger.info(log_msg.format(i, n_iterations, **fields))
            i += 1
            end = time.time()
            if i >= n_iterations:
                break
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # guard the per-iteration figure against an empty run
        per_iteration = total_time / max(n_iterations, 1)
        logger.info(f"{header} Total time: {total_time_str} ({per_iteration:.6f} s / it)")
227 |
228 |
class WandbLogger:
    """Thin wrapper around a Weights & Biases run that tolerates logging failures.

    The run is created with ``resume="allow"`` so that passing a previous
    ``run_id`` continues that run. The working directory is uploaded via
    ``run.log_code(".")``.
    """

    def __init__(
        self,
        config,
        entity: str,
        project: str,
        name: str,
        log_dir: str,
        run_id: str | None = None,
    ):
        self.run = wandb.init(
            config=config,
            entity=entity,
            project=project,
            name=name,
            dir=log_dir,
            resume="allow",
            id=run_id,
        )
        self.run_id = self.run.id
        # last step explicitly passed to update(); reused when update() gets no step
        self.step = 0
        self.run.log_code(".")

    def update(self, metrics, step: int | None = None) -> None:
        """Log ``metrics`` (``None`` values dropped, tensors converted via ``.item()``).

        When ``step`` is ``None`` the last explicitly given step is reused.
        Failures inside wandb are logged and swallowed.
        """
        log_dict = {
            k: v.item() if isinstance(v, torch.Tensor) else v for k, v in metrics.items() if v is not None
        }
        # Compare against None explicitly: ``step or self.step`` would treat an
        # explicit step of 0 as "not given".
        effective_step = self.step if step is None else step
        try:
            wandb.log(log_dict, step=effective_step)
        except Exception as e:
            logger.error(f"wandb logging failed: {e}")
        if step is not None:
            self.step = step

    def finish(self) -> None:
        """Close the run; failures are logged and swallowed."""
        try:
            wandb.finish()
        except Exception as e:
            logger.error(f"wandb failed to finish: {e}")
268 |
269 |
def setup_logging(output: str, name: str = "DeTok", rank0_log_only: bool = True) -> None:
    """Configure the named logger with a console handler and an optional file handler.

    Args:
        output: a ``.txt`` / ``.log`` file path, or a directory in which ``log.txt``
            is created; falsy disables file logging.
        name: name of the logger to configure.
        rank0_log_only: when False, non-main ranks write to a rank-suffixed file.

    The console handler is attached on the main process only. The logger does not
    propagate to the root logger.
    """
    logging.captureWarnings(True)

    log = logging.getLogger(name)
    log.setLevel(logging.INFO)
    log.propagate = False

    # google glog format: [IWEF]yyyymmdd hh:mm:ss logger filename:line] msg
    fmt_prefix = "%(levelname).1s%(asctime)s %(name)s %(filename)s:%(lineno)s] "
    fmt_message = "%(message)s"
    formatter = logging.Formatter(fmt=fmt_prefix + fmt_message, datefmt="%Y%m%d %H:%M:%S")

    # stdout logging for main worker only
    if is_main_process():
        if sys.stdout.isatty():
            handler = RichHandler(markup=True, show_time=False, show_level=False)
        else:
            handler = logging.StreamHandler(stream=sys.stdout)
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        log.addHandler(handler)

    # file logging
    if output:
        if os.path.splitext(output)[-1] in (".txt", ".log"):
            # if output is a file, use it directly
            filename = output
        else:
            # if output is a directory, use the directory/log.txt
            filename = os.path.join(output, "log.txt")

        # if not rank 0 but rank0_log_only=false, append the rank id
        if not is_main_process() and not rank0_log_only:
            global_rank = get_global_rank()
            root, ext = os.path.splitext(filename)
            if ext == ".txt":
                # insert the rank before the extension only, leaving any ".txt"
                # elsewhere in the path untouched
                filename = f"{root}.rank{global_rank}{ext}"
            else:
                filename = f"{filename}.rank{global_rank}"

        # a bare file name has no directory part; makedirs("") would fail
        log_dir = os.path.dirname(filename)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)

        # FileHandler owns the file and closes it on logging shutdown
        handler = logging.FileHandler(filename, mode="a")
        handler.setLevel(logging.DEBUG)
        handler.setFormatter(formatter)
        log.addHandler(handler)
318 |
319 |
def setup_wandb(args: argparse.Namespace, entity: str, project: str, name: str, log_dir: str) -> WandbLogger:
    """Setup Weights & Biases logging with resume capability.

    The run id is persisted in ``<log_dir>/wandb_run_id.txt``. If that file holds at
    least one non-blank line, the last one is used to resume the run; otherwise a
    new run is started and its id is appended to the file.
    """
    run_id_path = os.path.join(log_dir, "wandb_run_id.txt")
    run_id = None

    # resume from wandb run id if it exists
    if os.path.exists(run_id_path):
        with open(run_id_path, "r") as f:
            recorded_ids = [line.strip() for line in f if line.strip()]
        # an empty or whitespace-only file means there is nothing to resume
        if recorded_ids:
            run_id = recorded_ids[-1]

    wandb_logger = WandbLogger(
        config=args,
        entity=entity,
        project=project,
        name=name,
        log_dir=log_dir,
        run_id=run_id,
    )

    # if no run id, save the run id to the log directory
    if run_id is None:
        with open(run_id_path, "a") as f:
            f.write(wandb_logger.run.id + "\n")

    return wandb_logger
345 |
--------------------------------------------------------------------------------
/utils/misc.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import argparse
4 | import os
5 | import copy
6 | import datetime
7 | from glob import glob
8 | import logging
9 |
10 | import numpy as np
11 | import torch
12 | import torch.nn.utils
13 | from torch import inf
14 |
15 | import utils.distributed as dist
16 |
17 | logger = logging.getLogger("DeTok")
18 |
19 |
def fix_random_seeds(seed=31):
    """Seed torch (CPU and all CUDA devices), NumPy and ``random`` with ``seed``."""
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)
25 |
26 |
def adjust_learning_rate(optimizer, epoch, args):
    """Set the learning rate of every param group for ``epoch`` and return the base value.

    While ``epoch < args.warmup_epochs`` the rate ramps linearly from 0 to
    ``args.lr``. Afterwards ``args.lr_sched`` selects the schedule:
    ``"constant"`` holds ``args.lr``; ``"cosine"`` follows a half-cycle cosine
    from ``args.lr`` towards ``args.min_lr`` over the epochs remaining until
    ``args.epochs``. Any other value raises ``NotImplementedError``.

    Param groups carrying an ``"lr_scale"`` entry get the base rate multiplied
    by that factor.
    """
    warmup = args.warmup_epochs
    if epoch < warmup:
        lr = args.lr * epoch / warmup
    elif args.lr_sched == "constant":
        lr = args.lr
    elif args.lr_sched == "cosine":
        progress = (epoch - warmup) / (args.epochs - warmup)
        lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))
    else:
        raise NotImplementedError

    for group in optimizer.param_groups:
        group["lr"] = lr * group["lr_scale"] if "lr_scale" in group else lr
    return lr
45 |
46 |
def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
    """Return the ``norm_type`` norm over the gradients of ``parameters``.

    ``parameters`` may be a single tensor or an iterable of tensors; entries
    whose ``.grad`` is ``None`` are ignored. With no gradients at all a zero
    tensor is returned. For the infinity norm the result is the largest
    absolute gradient entry; otherwise it is the norm of the per-tensor norms.
    The result lives on the device of the first gradient.
    """
    candidates = [parameters] if isinstance(parameters, torch.Tensor) else parameters
    grads = [p.grad for p in candidates if p.grad is not None]
    norm_type = float(norm_type)

    if not grads:
        return torch.tensor(0.0)

    device = grads[0].device
    if norm_type == inf:
        return max(g.detach().abs().max().to(device) for g in grads)

    per_tensor = [torch.norm(g.detach(), norm_type).to(device) for g in grads]
    return torch.norm(torch.stack(per_tensor), norm_type)
63 |
64 |
65 | class NativeScalerWithGradNormCount:
66 | state_dict_key = "amp_scaler"
67 |
68 | def __init__(self, enabled: bool = True):
69 | self._scaler = torch.GradScaler(device="cuda", enabled=enabled)
70 |
71 | def __call__(
72 | self,
73 | loss,
74 | optimizer,
75 | clip_grad=None,
76 | parameters=None,
77 | create_graph=False,
78 | update_grad=True,
79 | ):
80 | self._scaler.scale(loss).backward(create_graph=create_graph)
81 | if update_grad:
82 | if clip_grad is not None and clip_grad > 0.0:
83 | assert parameters is not None
84 | self._scaler.unscale_(optimizer)
85 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
86 | else:
87 | self._scaler.unscale_(optimizer)
88 | norm = get_grad_norm_(parameters)
89 | self._scaler.step(optimizer)
90 | self._scaler.update()
91 | else:
92 | norm = None
93 | return norm
94 |
95 | def state_dict(self):
96 | return self._scaler.state_dict()
97 |
98 | def load_state_dict(self, state_dict):
99 | self._scaler.load_state_dict(state_dict)
100 |
101 |
def ckpt_resume(
    args: argparse.Namespace,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer | None = None,
    loss_scaler: NativeScalerWithGradNormCount | None = None,
    model_ema: torch.nn.Module | None = None,
    loss_module: torch.nn.Module | None = None,
    discriminator_optimizer: torch.optim.Optimizer | None = None,
    discriminator_loss_scaler: NativeScalerWithGradNormCount | None = None,
):
    """Restore training state from a checkpoint, or initialise weights from ``args.load_from``.

    Resume path (``args.resume_from`` or ``args.auto_resume``): with no explicit
    path, the most recently modified ``*.pth`` in ``args.ckpt_dir`` whose file
    name does not contain "latest" is chosen and stored in ``args.resume_from``.
    The model is loaded strictly; EMA, optimizer (which also sets
    ``args.start_epoch``), loss scaler, elapsed time, loss module and the
    discriminator optimizer/scaler are restored when both the checkpoint entry
    and the matching argument are present.

    Load path (``args.load_from`` set and nothing to resume): only the model
    (non-strict) and EMA weights are loaded.

    Raises:
        FileNotFoundError: If ``args.load_from`` is used but does not exist.
    """
    if args.resume_from or args.auto_resume:
        if args.resume_from is None:
            # find the latest checkpoint; test the file name only so that a
            # directory called e.g. "latest_run" does not hide every checkpoint
            checkpoints = [
                ckpt for ckpt in glob(f"{args.ckpt_dir}/*.pth") if "latest" not in os.path.basename(ckpt)
            ]
            checkpoints = sorted(checkpoints, key=os.path.getmtime)
            if len(checkpoints) > 0:
                args.resume_from = checkpoints[-1]

        if args.resume_from and os.path.exists(args.resume_from):
            # load the checkpoint
            logger.info(f"[Model-resume] Resuming from: {args.resume_from}")
            # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
            checkpoint = torch.load(args.resume_from, map_location="cpu", weights_only=False)
            msg = model.load_state_dict(checkpoint["model"])
            logger.info(f"[Model-resume] Loaded model: {msg}")

            # save_checkpoint may have stored "model_ema": None, so the key
            # being present is not enough
            if "model_ema" in checkpoint and checkpoint["model_ema"] is not None:
                ema_state_dict = checkpoint["model_ema"]
                logger.info("[Model-resume] Loaded EMA")
            else:
                # no EMA weights stored: seed the EMA from the model's parameters
                model_state_dict = model.state_dict()
                param_keys = [k for k, _ in model.named_parameters()]
                ema_state_dict = {k: model_state_dict[k] for k in param_keys}
                logger.info("[Model-resume] Loaded EMA with model state dict")

            if model_ema is not None:
                model_ema.load_state_dict(ema_state_dict)
                model_ema.to("cuda")  # move the EMA model to the GPU

            # optimizer state and epoch counter are restored together
            if "optimizer" in checkpoint and "epoch" in checkpoint and optimizer is not None:
                optimizer.load_state_dict(checkpoint["optimizer"])
                args.start_epoch = checkpoint["epoch"] + 1
                # load the loss scaler state dict if it exists
                if "loss_scaler" in checkpoint and loss_scaler is not None:
                    loss_scaler.load_state_dict(checkpoint["loss_scaler"])

            # load the last elapsed time if it exists
            if "last_elapsed_time" in checkpoint:
                args.last_elapsed_time = float(checkpoint["last_elapsed_time"])
                elapsed_time_str = str(datetime.timedelta(seconds=int(args.last_elapsed_time)))
                logger.info(f"Loaded elapsed_time: {elapsed_time_str}")

            # load the loss module state dict if it exists
            if "loss_module" in checkpoint and loss_module is not None:
                msg = loss_module.load_state_dict(checkpoint["loss_module"])
                logger.info(f"[Model-resume] Loaded loss_module: {msg}")

            if "discriminator_optimizer" in checkpoint and discriminator_optimizer is not None:
                msg = discriminator_optimizer.load_state_dict(checkpoint["discriminator_optimizer"])
                logger.info(f"[Model-resume] Loaded discriminator_optimizer: {msg}")

            if "discriminator_loss_scaler" in checkpoint and discriminator_loss_scaler is not None:
                msg = discriminator_loss_scaler.load_state_dict(checkpoint["discriminator_loss_scaler"])
                logger.info(f"[Model-resume] Loaded discriminator_loss_scaler: {msg}")

            # delete the checkpoint to save memory
            del checkpoint
        else:
            logger.info(f"[Model-resume] Could not find checkpoint at {args.resume_from}.")
    else:
        logger.info(f"[Model-resume] Could not find checkpoint at {args.resume_from}.")

    if args.load_from and not args.resume_from:
        # nothing to resume: initialise weights from the load_from path instead
        if not os.path.exists(args.load_from):
            raise FileNotFoundError(f"Could not find checkpoint at {args.load_from}")

        logger.info(f"[Model-load] Loading checkpoint from: {args.load_from}")
        # NOTE: weights_only=False unpickles arbitrary objects; only load trusted files.
        checkpoint = torch.load(args.load_from, map_location="cpu", weights_only=False)
        # the file is either a full checkpoint or a bare model state dict
        state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint
        msg = model.load_state_dict(state_dict, strict=False)
        # unexpected keys are tolerated only when they start with "loss."
        for key in msg.unexpected_keys:
            assert key.startswith("loss."), f"unexpected key {key} doesn't start with 'loss.'"
        logger.info(f"[Model-load] Loaded model: {msg}")

        if "model_ema" in checkpoint and checkpoint["model_ema"] is not None:
            logger.info("[Model-load] Loaded EMA")
            ema_state_dict = checkpoint["model_ema"]
        else:
            logger.info("[Model-load] Loaded EMA with model state dict")
            ema_state_dict = copy.deepcopy(model.state_dict())
        if model_ema is not None:
            model_ema.load_state_dict(ema_state_dict)
            model_ema.to(device="cuda")  # move the EMA model to the GPU
        del checkpoint  # delete the checkpoint to save memory
201 |
202 |
def cleanup_checkpoints(ckpt_dir: str, keep_num: int = 5, milestone_interval: int = 5):
    """
    Clean up older checkpoint files in `ckpt_dir` while keeping the latest `keep_num` checkpoints by epoch number.

    Files whose name contains "latest" or "best", and files whose name does not end in
    `_<integer>.pth`, are never touched. Afterwards `latest.pth` is re-created as a symlink
    to the newest kept checkpoint.

    Parameters
    ----------
    ckpt_dir : str
        The directory where checkpoint .pth files are stored.
    keep_num : int, optional
        The number of most recent checkpoints to keep (default=5).
        Because the newest ones are taken with `ckpts[-keep_num:]`, a value of 0 keeps all of them.
    milestone_interval : int, optional
        The interval used to decide if a checkpoint is a "milestone."
        If (epoch_num + 1) % milestone_interval == 0, it is kept (default=5).
        A value of 0 disables milestone retention.
    """

    def get_ckpt_num(path):
        """Extract the epoch number from a checkpoint filename, or None if it has none."""
        filename = os.path.basename(path)
        # expecting something like 'epoch_049.pth':
        # take the part after the last underscore and before the first '.'
        try:
            return int(filename.rsplit("_", 1)[-1].split(".")[0])
        except ValueError:
            return None

    # collect (epoch, path) pairs, parsing each file name only once
    numbered = []
    for path in glob(os.path.join(ckpt_dir, "*.pth")):
        filename = os.path.basename(path)
        # test the file name only: a directory called e.g. "best_run" must not
        # exclude every checkpoint inside it
        if "latest" in filename or "best" in filename:
            continue
        epoch_num = get_ckpt_num(path)
        if epoch_num is not None:
            numbered.append((epoch_num, path))

    if not numbered:
        # if no checkpoints remain, nothing to do
        return

    # sort checkpoints by epoch number
    numbered.sort(key=lambda item: item[0])
    ckpts = [path for _, path in numbered]

    # determine which checkpoints to keep:
    #   1. the newest `keep_num` by epoch number
    #   2. any milestone checkpoint, i.e. (epoch_num + 1) % milestone_interval == 0
    newest_keep = set(ckpts[-keep_num:])  # also fine when keep_num > number of ckpts
    if milestone_interval:
        milestone_keep = {path for epoch_num, path in numbered if (epoch_num + 1) % milestone_interval == 0}
    else:
        # a zero interval would divide by zero; treat it as "no milestones"
        milestone_keep = set()
    keep_set = newest_keep | milestone_keep

    # remove anything not in keep_set
    for ckpt in ckpts:
        if ckpt not in keep_set:
            os.remove(ckpt)
            logger.info(f"Removed checkpoint: {ckpt}")

    # recreate the 'latest.pth' symlink to the newest kept checkpoint
    kept_in_order = [ckpt for ckpt in ckpts if ckpt in keep_set]
    if kept_in_order:
        newest_ckpt = os.path.abspath(kept_in_order[-1])
        latest_symlink = os.path.join(ckpt_dir, "latest.pth")

        # remove the old symlink if it exists
        try:
            os.remove(latest_symlink)
            logger.info(f"Removed old symlink: {latest_symlink}")
        except FileNotFoundError:
            pass

        # create a new symlink
        os.symlink(newest_ckpt, latest_symlink)
        logger.info(f"Created symlink: {latest_symlink} -> {newest_ckpt}")
274 |
275 |
def save_checkpoint(
    args,
    epoch,
    model,
    optimizer,
    loss_scaler,
    model_ema,
    elapsed_time=0.0,
    loss_module=None,
    discriminator_optimizer=None,
    discriminator_loss_scaler=None,
):
    """Write ``epoch_<epoch>.pth`` into ``args.ckpt_dir`` and prune old checkpoints.

    Does nothing on non-main processes. The checkpoint always holds the model,
    optimizer and loss-scaler state dicts, the epoch and the elapsed time.
    EMA, loss-module and discriminator entries are added only when the
    corresponding object is given. After saving, ``cleanup_checkpoints`` runs
    with ``args.keep_n_ckpts`` and ``args.milestone_interval``.
    """
    if not dist.is_main_process():
        return

    checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "loss_scaler": loss_scaler.state_dict(),
        "epoch": epoch,
        "last_elapsed_time": elapsed_time,
    }
    # optional entries are stored only when present, so that a
    # `"key" in checkpoint` test at load time implies usable content
    if model_ema is not None:
        checkpoint["model_ema"] = model_ema.state_dict()
    if loss_module is not None and isinstance(loss_module, torch.nn.Module):
        checkpoint["loss_module"] = loss_module.state_dict()
    if discriminator_optimizer is not None:
        checkpoint["discriminator_optimizer"] = discriminator_optimizer.state_dict()
    if discriminator_loss_scaler is not None:
        checkpoint["discriminator_loss_scaler"] = discriminator_loss_scaler.state_dict()

    checkpoint_path = os.path.join(args.ckpt_dir, f"epoch_{epoch:04d}.pth")
    # write to a temporary name and rename, so an interrupted save never leaves
    # a truncated *.pth behind; the ".tmp" suffix keeps it out of "*.pth" globs
    tmp_path = checkpoint_path + ".tmp"
    try:
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, checkpoint_path)
    finally:
        # only still present if saving or renaming failed
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    logger.info(f"Saved checkpoint: {checkpoint_path}")
    cleanup_checkpoints(args.ckpt_dir, args.keep_n_ckpts, args.milestone_interval)
308 |
--------------------------------------------------------------------------------