├── .gitignore ├── .vscode └── settings.json ├── LICENSE ├── README.md ├── assets ├── demo.png └── method.png ├── demo.ipynb ├── diffusion ├── __init__.py ├── diffusion_utils.py ├── gaussian_diffusion.py └── respace.py ├── main_diffusion.py ├── main_reconstruction.py ├── models ├── __init__.py ├── ardiff.py ├── autoencoder.py ├── detok.py ├── diffloss.py ├── dit.py ├── ema.py ├── layers.py ├── lightningdit.py ├── mar.py ├── model_utils.py └── sit.py ├── pyrightconfig.json ├── requirements.txt ├── transport ├── __init__.py ├── integrators.py ├── path.py └── transport.py └── utils ├── builders.py ├── distributed.py ├── download.py ├── loader.py ├── logger.py ├── losses.py ├── misc.py └── train_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | data 3 | work_dirs/ 4 | *.mp4 5 | *draft* 6 | pretrained_models/ 7 | code.txt 8 | 9 | 10 | # media 11 | *.mp4 12 | *.png 13 | *.jpg 14 | 15 | 16 | # wandb 17 | wandb/ 18 | 19 | # work in progress 20 | *wip* 21 | 22 | *results* 23 | *debug* 24 | *tmp* 25 | 26 | # caches 27 | *.pyc 28 | *.swp 29 | 30 | 31 | __pycache__ 32 | 33 | # checkpoints 34 | *.pt 35 | *.pth 36 | *.npz 37 | src/ 38 | 39 | /guided-diffusion/ 40 | 41 | 42 | .clangd 43 | compile_commands.json 44 | 45 | # Visual Studio Code configs. 46 | .vscode/ 47 | 48 | # Byte-compiled / optimized / DLL files 49 | *.py[cod] 50 | *$py.class 51 | 52 | # C extensions 53 | *.so 54 | 55 | # Distribution / packaging 56 | .Python 57 | build/ 58 | develop-eggs/ 59 | dist/ 60 | downloads/ 61 | eggs/ 62 | .eggs/ 63 | # lib/ 64 | lib64/ 65 | parts/ 66 | sdist/ 67 | var/ 68 | wheels/ 69 | *.egg-info/ 70 | .installed.cfg 71 | *.egg 72 | MANIFEST 73 | 74 | # PyInstaller 75 | # Usually these files are written by a python script from a template 76 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
77 | *.manifest 78 | *.spec 79 | 80 | # Installer logs 81 | pip-log.txt 82 | pip-delete-this-directory.txt 83 | 84 | # Unit test / coverage reports 85 | htmlcov/ 86 | .tox/ 87 | .coverage 88 | .coverage.* 89 | .cache 90 | nosetests.xml 91 | coverage.xml 92 | *.cover 93 | .hypothesis/ 94 | .pytest_cache/ 95 | 96 | # Translations 97 | *.mo 98 | *.pot 99 | 100 | # Django stuff: 101 | *.log 102 | local_settings.py 103 | db.sqlite3 104 | 105 | # Flask stuff: 106 | instance/ 107 | .webassets-cache 108 | 109 | # Scrapy stuff: 110 | .scrapy 111 | 112 | # Sphinx documentation 113 | docs/_build/ 114 | 115 | # PyBuilder 116 | target/ 117 | 118 | # Jupyter Notebook 119 | .ipynb_checkpoints 120 | 121 | # pyenv 122 | .python-version 123 | 124 | # celery beat schedule file 125 | celerybeat-schedule 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | 152 | .DS_Store 153 | 154 | # Direnv config. 155 | .envrc 156 | 157 | # line_profiler 158 | *.lprof 159 | 160 | *build 161 | compile_commands.json 162 | *.dump 163 | 164 | # modelling/ 165 | legacy/ 166 | released_model/ 167 | scripts/ -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | // Automatically format using Black on save. 3 | "editor.formatOnSave": true, 4 | // Draw a ruler at Black's column width. 5 | "editor.rulers": [ 6 | 110, 7 | ], 8 | // Hide non-code files. 
9 | "files.exclude": { 10 | "**/.git": true, 11 | "**/.svn": true, 12 | "**/.hg": true, 13 | "**/CVS": true, 14 | "**/.DS_Store": true, 15 | "**/Thumbs.db": true, 16 | "**/__pycache__": true, 17 | "**/venv": true 18 | }, 19 | "[python]": { 20 | "editor.defaultFormatter": "ms-python.black-formatter", 21 | }, 22 | "debug.focusWindowOnBreak": false, 23 | "files.watcherExclude": { 24 | "**/.git/**": true, 25 | "**/checkpoints/**": true, 26 | "**/data/**": true, 27 | "**/work_dirs/**": true, 28 | "**/lightning_logs/**": true, 29 | "**/outputs/**": true, 30 | "**/dataset_cache/**": true, 31 | "**/.ruff_cache/**": true, 32 | "**/venv/**": true, 33 | "**/data": true 34 | }, 35 | "editor.mouseWheelZoom": true, 36 | "terminal.integrated.mouseWheelZoom": true 37 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Jiawei Yang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeTok: Latent Denoising Makes Good Visual Tokenizers
Official PyTorch Implementation 2 | 3 | [![arXiv](https://img.shields.io/badge/arXiv-2507.15856-b31b1b.svg)](https://arxiv.org/abs/2507.15856)  4 | [![huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20HuggingFace-DeTok-yellow)](https://huggingface.co/jjiaweiyang/l-DeTok)  5 | 6 |

7 | 8 |

9 | 10 | This is a PyTorch/GPU implementation of the paper **Latent Denoising Makes Good Visual Tokenizers**: 11 | 12 | ``` 13 | @article{yang2025detok, 14 | title={Latent Denoising Makes Good Visual Tokenizers}, 15 | author={Jiawei Yang and Tianhong Li and Lijie Fan and Yonglong Tian and Yue Wang}, 16 | journal={arXiv preprint arXiv:2507.15856}, 17 | year={2025} 18 | } 19 | ``` 20 | 21 | This repo contains: 22 | 23 | * 🪐 A simple PyTorch implementation of [l-DeTok](models/detok.py) tokenizer and various generative models ([MAR](models/mar.py), [RandARDiff](models/ardiff.py), [RasterARDiff](models/ardiff.py), [DiT](models/dit.py), [SiT](models/sit.py), and [LightningDiT](models/lightningdit.py)) 24 | * ⚡️ Pre-trained DeTok tokenizers and MAR models trained on ImageNet 256x256 25 | * 🛸 Training and evaluation scripts for tokenizer and generative models 26 | * 🎉 Hugging Face for easy access to pre-trained models 27 | 28 | ## Preparation 29 | 30 | 31 | ### Installation 32 | 33 | Download the code: 34 | ```bash 35 | git clone https://github.com/Jiawei-Yang/detok.git 36 | cd detok 37 | ``` 38 | 39 | Create and activate conda environment: 40 | ```bash 41 | conda create -n detok python=3.10 -y && conda activate detok 42 | pip install -r requirements.txt 43 | ``` 44 | 45 | ### Dataset 46 | Create `data/` folder by `mkdir data/` then download [ImageNet](http://image-net.org/download) dataset. You can either: 47 | 1. Download it directly to `data/imagenet/` 48 | 2. 
Create a symbolic link: `ln -s /path/to/your/imagenet data/imagenet` 49 | 50 | 51 | ### Download Required Files 52 | 53 | Create the data directory (if it does not already exist) and download the required files: 54 | ```bash 55 | mkdir -p data/ 56 | # Download everything from Hugging Face 57 | huggingface-cli download jjiaweiyang/l-DeTok --local-dir released_model 58 | mv released_model/train.txt data/ 59 | mv released_model/val.txt data/ 60 | mv released_model/fid_stats data/ 61 | mv released_model/imagenet-val-prc.zip ./data/ 62 | 63 | # Unzip: imagenet-val-prc.zip for precision & recall evaluation 64 | python -m zipfile -e ./data/imagenet-val-prc.zip ./data/ 65 | ``` 66 | 67 | ### Data Organization 68 | 69 | Your data directory should be organized as follows: 70 | ``` 71 | data/ 72 | ├── fid_stats/ # FID statistics 73 | │ ├── adm_in256_stats.npz 74 | │ └── val_fid_statistics_file.npz 75 | ├── imagenet/ # ImageNet dataset (or symlink) 76 | │ ├── train/ 77 | │ └── val/ 78 | ├── imagenet-val-prc/ # Precision-recall data 79 | ├── train.txt # Training file list 80 | └── val.txt # Validation file list 81 | ``` 82 | 83 | ## Models 84 | 85 | For convenience, our pre-trained models are available on Hugging Face: 86 | 87 | | Model | Type | Params | Hugging Face | 88 | |---------------------|-----------|-------|-------------| 89 | | DeTok-BB | Tokenizer | 172M | [🤗 detok-bb](https://huggingface.co/jjiaweiyang/l-DeTok/resolve/main/detok-BB-gamm3.0-m0.7.pth) | 90 | | DeTok-BB-decoder_ft | Tokenizer | 172M | [🤗 detok-bb-decoder_ft](https://huggingface.co/jjiaweiyang/l-DeTok/resolve/main/detok-BB-gamm3.0-m0.7-decoder_tuned.pth) | 91 | | MAR-Base | Generator | 208M | [🤗 mar-base](https://huggingface.co/jjiaweiyang/l-DeTok/resolve/main/mar_base.pth) | 92 | | MAR-Large | Generator | 479M | [🤗 mar-large](https://huggingface.co/jjiaweiyang/l-DeTok/resolve/main/mar_large.pth) | 93 | 94 | FID-50k with CFG: 95 | |cfg| MAR Model | FID-50K | Inception Score | 96 | 
|---|-----------------------------------|---------|-----------------| 97 | |3.9| MAR-Base + DeTok-BB | 1.61 | 289.7 | 98 | |3.9| MAR-Base + DeTok-BB-decoder_ft | 1.55 | 291.0 | 99 | |3.4| MAR-Large + DeTok-BB | 1.43 | 303.5 | 100 | |3.4| MAR-Large + DeTok-BB-decoder_ft | 1.32 | 304.1 | 101 | 102 | ## Usage 103 | 104 | ### Demo 105 | Run our demo using notebook at [demo.ipynb](demo.ipynb) 106 | 107 | 108 | ## Training 109 | 110 | ### 1. Tokenizer Training 111 | 112 | Train DeTok tokenizer with denoising: 113 | ```bash 114 | project=tokenizer_training 115 | exp_name=detokBB-g3.0-m0.7-200ep 116 | batch_size=32 # global batch size = batch_size x num_nodes x 8 = 1024 117 | num_nodes=4 # adjust for your multi-node setup 118 | YOUR_WANDB_ENTITY="" # change to your wandb entity 119 | 120 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \ 121 | main_reconstruction.py \ 122 | --project $project --exp_name $exp_name --auto_resume \ 123 | --batch_size $batch_size --model detok_BB \ 124 | --gamma 3.0 --mask_ratio 0.7 \ 125 | --online_eval \ 126 | --epochs 200 --discriminator_start_epoch 100 \ 127 | --data_path ./data/imagenet/train \ 128 | --entity $YOUR_WANDB_ENTITY --enable_wandb 129 | ``` 130 | 131 | Decoder fine-tuning: 132 | ```bash 133 | project=tokenizer_training 134 | exp_name=detokBB-g3.0-m0.7-200ep-decoder_ft-100ep 135 | batch_size=32 136 | num_nodes=4 137 | pretrained_tok=work_dirs/tokenizer_training/detokBB-g3.0-m0.7-200ep/checkpoints/latest.pth 138 | YOUR_WANDB_ENTITY="" 139 | 140 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \ 141 | main_reconstruction.py \ 142 | --project $project --exp_name $exp_name --auto_resume \ 143 | --batch_size $batch_size --model detok_BB \ 144 | --load_from $pretrained_tok \ 145 | --online_eval --train_decoder_only \ 146 | --perceptual_weight 0.1 \ 147 | --gamma 0.0 
--mask_ratio 0.0 \ 148 | --blr 5e-5 --warmup_rate 0.05 \ 149 | --epochs 100 --discriminator_start_epoch 0 \ 150 | --data_path ./data/imagenet/train \ 151 | --entity $YOUR_WANDB_ENTITY --enable_wandb 152 | ``` 153 | 154 | ### 2. Generative Model Training 155 | 156 | Train MAR-base (100 epochs): 157 | ```bash 158 | tokenizer_project=tokenizer_training 159 | tokenizer_exp_name=detokBB-g3.0-m0.7-200ep-decoder_ft-100ep 160 | project=gen_model_training 161 | exp_name=mar_base-${tokenizer_exp_name} 162 | batch_size=32 # global batch size = batch_size x num_nodes x 8 = 1024 163 | num_nodes=4 164 | epochs=100 165 | 166 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \ 167 | main_diffusion.py \ 168 | --project $project --exp_name $exp_name --auto_resume \ 169 | --batch_size $batch_size --epochs $epochs --use_aligned_schedule \ 170 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \ 171 | --stats_key $tokenizer_exp_name --stats_cache_path work_dirs/stats.pkl \ 172 | --load_tokenizer_from work_dirs/$tokenizer_project/$tokenizer_exp_name/checkpoints/latest.pth \ 173 | --model MAR_base --no_dropout_in_mlp \ 174 | --diffloss_d 6 --diffloss_w 1024 \ 175 | --num_sampling_steps 100 --cfg 4.0 \ 176 | --cfg_list 3.0 3.5 3.7 3.8 3.9 4.0 4.1 4.3 4.5 \ 177 | --vis_freq 50 --eval_bsz 256 \ 178 | --data_path ./data/imagenet/train \ 179 | --entity $YOUR_WANDB_ENTITY --enable_wandb 180 | ``` 181 | 182 | Train SiT-base (100 epochs): 183 | ```bash 184 | tokenizer_project=tokenizer_training 185 | tokenizer_exp_name=detokBB-g3.0-m0.7-200ep-decoder_ft-100ep 186 | project=gen_model_training 187 | exp_name=sit_base-${tokenizer_exp_name} 188 | batch_size=32 189 | num_nodes=4 190 | epochs=100 191 | 192 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \ 193 | main_diffusion.py \ 194 | --project $project --exp_name 
$exp_name --auto_resume \ 195 | --batch_size $batch_size --epochs $epochs --use_aligned_schedule \ 196 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \ 197 | --stats_key $tokenizer_exp_name --stats_cache_path work_dirs/stats.pkl \ 198 | --load_tokenizer_from work_dirs/$tokenizer_project/$tokenizer_exp_name/checkpoints/latest.pth \ 199 | --model SiT_base \ 200 | --num_sampling_steps 250 --cfg 1.6 \ 201 | --cfg_list 1.3 1.4 1.5 1.6 1.7 1.8 1.9 2.0 \ 202 | --vis_freq 50 --eval_bsz 256 \ 203 | --data_path ./data/imagenet/train \ 204 | --entity $YOUR_WANDB_ENTITY --enable_wandb 205 | ``` 206 | 207 | ### 3. Training with the released DeTok 208 | 209 | Train MAR-base for 800 epochs using released tokenizer: 210 | ```bash 211 | project=gen_model_training 212 | exp_name=mar_base_800ep-detok-BB-gamm3.0-m0.7-decoder_tuned 213 | batch_size=16 # global batch size = batch_size x num_nodes x 8 = 1024 214 | num_nodes=8 215 | epochs=800 216 | 217 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \ 218 | main_diffusion.py \ 219 | --project $project --exp_name $exp_name --auto_resume \ 220 | --batch_size $batch_size --epochs $epochs --use_aligned_schedule \ 221 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \ 222 | --stats_key detok-BB-gamm3.0-m0.7 --stats_cache_path released_model/stats.pkl \ 223 | --load_tokenizer_from released_model/detok-BB-gamm3.0-m0.7-decoder_tuned.pth \ 224 | --model MAR_base --no_dropout_in_mlp \ 225 | --diffloss_d 6 --diffloss_w 1024 \ 226 | --num_sampling_steps 100 --cfg 3.9 \ 227 | --cfg_list 3.0 3.5 3.7 3.8 3.9 4.0 4.1 4.3 4.5 \ 228 | --online_eval --vis_freq 80 --eval_bsz 256 \ 229 | --data_path ./data/imagenet/train \ 230 | --entity $YOUR_WANDB_ENTITY --enable_wandb 231 | ``` 232 | 233 | Train MAR-large for 800 epochs: 234 | ```bash 235 | project=gen_model_training 236 | 
exp_name=mar_large_800ep-detok-BB-gamm3.0-m0.7-decoder_tuned 237 | batch_size=16 238 | num_nodes=8 239 | epochs=800 240 | 241 | torchrun --nproc_per_node=8 --nnodes=$num_nodes --node_rank=${NODE_RANK} --master_addr=${MASTER_ADDR} --master_port=${MASTER_PORT} \ 242 | main_diffusion.py \ 243 | --project $project --exp_name $exp_name --auto_resume \ 244 | --batch_size $batch_size --epochs $epochs --use_aligned_schedule \ 245 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \ 246 | --stats_key detok-BB-gamm3.0-m0.7 --stats_cache_path released_model/stats.pkl \ 247 | --load_tokenizer_from released_model/detok-BB-gamm3.0-m0.7-decoder_tuned.pth \ 248 | --model MAR_large --no_dropout_in_mlp \ 249 | --diffloss_d 8 --diffloss_w 1280 \ 250 | --num_sampling_steps 100 --cfg 3.4 \ 251 | --cfg_list 3.0 3.2 3.3 3.4 3.5 3.6 3.7 3.8 3.9 \ 252 | --online_eval --vis_freq 80 --eval_bsz 256 \ 253 | --data_path ./data/imagenet/train \ 254 | --entity $YOUR_WANDB_ENTITY --enable_wandb 255 | ``` 256 | 257 | ## Evaluation (ImageNet 256x256) 258 | 259 | ### Evaluate pretrained MAR-Base 260 | ```bash 261 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \ 262 | main_diffusion.py \ 263 | --project mar_eval --exp_name mar_base_ep800 --auto_resume \ 264 | --batch_size 64 --epochs 800 --use_aligned_schedule \ 265 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \ 266 | --stats_key detok-BB-gamm3.0-m0.7 --stats_cache_path released_model/stats.pkl \ 267 | --load_tokenizer_from released_model/detok-BB-gamm3.0-m0.7-decoder_tuned.pth \ 268 | --model MAR_base --no_dropout_in_mlp \ 269 | --diffloss_d 6 --diffloss_w 1024 \ 270 | --load_from released_model/mar_base.pth \ 271 | --num_sampling_steps 100 --eval_bsz 256 --num_images 50000 --num_iter 256 --evaluate \ 272 | --cfg 3.9 \ 273 | --data_path ./data/imagenet/train 274 | ``` 275 | 276 | ### Evaluate pretrained MAR-Large 277 | ```bash 278 | torchrun --nproc_per_node=8 --nnodes=1 --node_rank=0 \ 279 | 
main_diffusion.py \ 280 | --project mar_eval --exp_name mar_large_ep800 --auto_resume \ 281 | --batch_size 64 --epochs 800 --use_aligned_schedule \ 282 | --tokenizer detok_BB --use_ema_tokenizer --collect_tokenizer_stats \ 283 | --stats_key detok-BB-gamm3.0-m0.7 --stats_cache_path released_model/stats.pkl \ 284 | --load_tokenizer_from released_model/detok-BB-gamm3.0-m0.7-decoder_tuned.pth \ 285 | --model MAR_large --no_dropout_in_mlp \ 286 | --diffloss_d 8 --diffloss_w 1280 \ 287 | --load_from released_model/mar_large.pth \ 288 | --num_sampling_steps 100 --eval_bsz 256 --num_images 50000 --num_iter 256 --evaluate \ 289 | --cfg 3.4 \ 290 | --data_path ./data/imagenet/train 291 | ``` 292 | 293 | FID-50k with CFG: 294 | |cfg| MAR Model | FID-50K | Inception Score | #params | 295 | |---|------------------------------|---------|-----------------|---------| 296 | |3.9| MAR-B + l-DeTok | 1.61 | 289.7 | 208M | 297 | |3.9| MAR-B + l-DeTok (decoder_ft) | 1.55 | 291.0 | 208M | 298 | |3.4| MAR-L + l-DeTok | 1.43 | 303.5 | 479M | 299 | |3.4| MAR-L + l-DeTok (decoder_ft) | 1.32 | 304.1 | 479M | 300 | 301 | 302 | **Note:** 303 | - Set `--cfg 1.0 --temperature 0.95` to evaluate without CFG for MAR-base and `--cfg 1.0 --temperature 0.97` for MAR-large 304 | - Generation speed can be significantly increased by reducing the number of autoregressive iterations (e.g., `--num_iter 64`) 305 | 306 | ## Acknowledgements 307 | We thank the authors of [MAE](https://github.com/facebookresearch/mae), [MAGE](https://github.com/LTH14/mage), [DiT](https://github.com/facebookresearch/DiT), [LightningDiT](https://github.com/hustvl/LightningDiT), [MAETok](https://github.com/Hhhhhhao/continuous_tokenizer) and [MAR](https://github.com/LTH14/mar) for their foundational work. 308 | 309 | Our codebase builds upon several excellent open-source projects, including [MAR](https://github.com/LTH14/mar) and [1d-tokenizer](https://github.com/bytedance/1d-tokenizer). 
We are grateful to the communities behind them. 310 | 311 | ## Contact 312 | This codebase has been cleaned up but has not undergone extensive testing. If you encounter any issues or have questions, please open a GitHub issue. We appreciate your feedback! -------------------------------------------------------------------------------- /assets/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiawei-Yang/DeTok/be99bae6f3539c34d8711addd16f2926295500d6/assets/demo.png -------------------------------------------------------------------------------- /assets/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiawei-Yang/DeTok/be99bae6f3539c34d8711addd16f2926295500d6/assets/method.png -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Adopted from DiT, which is modified from OpenAI's diffusion repos 2 | # DiT: https://github.com/facebookresearch/DiT/diffusion 3 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 4 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 5 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 6 | 7 | import logging 8 | 9 | from . 
import gaussian_diffusion as gd 10 | from .respace import SpacedDiffusion, space_timesteps 11 | 12 | logger = logging.getLogger("DeTok") 13 | 14 | 15 | def create_diffusion(  # Build a SpacedDiffusion (a GaussianDiffusion restricted to a subset of the base timesteps) from DiT/ADM-style flags. 16 | timestep_respacing,  # None or "" keeps all diffusion_steps; otherwise forwarded to space_timesteps (int, "N,M,..." string, "ddimN", or list of counts) 17 | noise_schedule="linear",  # schedule name passed to gd.get_named_beta_schedule 18 | use_kl=False,  # selects LossType.RESCALED_KL; checked before rescale_learned_sigmas 19 | sigma_small=False,  # only consulted when learn_sigma is False (FIXED_SMALL vs FIXED_LARGE) 20 | predict_xstart=False,  # ModelMeanType.START_X if True, else ModelMeanType.EPSILON 21 | learn_sigma=True,  # ModelVarType.LEARNED_RANGE if True 22 | rescale_learned_sigmas=False,  # selects LossType.RESCALED_MSE (ignored when use_kl is True); plain MSE if both are False 23 | diffusion_steps=1000,  # number of steps in the base beta schedule 24 | channel_last=False,  # forwarded unchanged to SpacedDiffusion / GaussianDiffusion 25 | ) -> SpacedDiffusion: 26 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 27 | if use_kl: 28 | loss_type = gd.LossType.RESCALED_KL 29 | elif rescale_learned_sigmas: 30 | loss_type = gd.LossType.RESCALED_MSE 31 | else: 32 | loss_type = gd.LossType.MSE 33 | if timestep_respacing is None or timestep_respacing == "":  # no respacing requested: one section containing every base step 34 | timestep_respacing = [diffusion_steps] 35 | if predict_xstart: 36 | model_mean_type = gd.ModelMeanType.START_X 37 | else: 38 | model_mean_type = gd.ModelMeanType.EPSILON 39 | 40 | if learn_sigma: 41 | model_var_type = gd.ModelVarType.LEARNED_RANGE 42 | else: 43 | if sigma_small: 44 | model_var_type = gd.ModelVarType.FIXED_SMALL 45 | else: 46 | model_var_type = gd.ModelVarType.FIXED_LARGE 47 | 48 | diffusion = SpacedDiffusion( 49 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),  # set of retained base-timestep indices 50 | betas=betas, 51 | model_mean_type=model_mean_type, 52 | model_var_type=model_var_type, 53 | loss_type=loss_type, 54 | channel_last=channel_last, 55 | ) 56 | logger.info(f"Created diffusion with timestep respacing {timestep_respacing}") 57 | return diffusion 58 | -------------------------------------------------------------------------------- /diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: 
https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2) 27 | ] 28 | 29 | return 0.5 * ( 30 | -1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 31 | ) 32 | 33 | 34 | def approx_standard_normal_cdf(x): 35 | """ 36 | A fast approximation of the cumulative distribution function of the 37 | standard normal. 38 | """ 39 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 40 | 41 | 42 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 43 | """ 44 | Compute the log-likelihood of a Gaussian distribution discretizing to a 45 | given image. 46 | :param x: the target images. It is assumed that this was uint8 values, 47 | rescaled to the range [-1, 1]. 48 | :param means: the Gaussian mean Tensor. 49 | :param log_scales: the Gaussian log stddev Tensor. 50 | :return: a tensor like x of log probabilities (in nats). 
51 | """ 52 | assert x.shape == means.shape == log_scales.shape 53 | centered_x = x - means 54 | inv_stdv = th.exp(-log_scales) 55 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 56 | cdf_plus = approx_standard_normal_cdf(plus_in) 57 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 58 | cdf_min = approx_standard_normal_cdf(min_in) 59 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 60 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 61 | cdf_delta = cdf_plus - cdf_min 62 | log_probs = th.where( 63 | x < -0.999, 64 | log_cdf_plus, 65 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 66 | ) 67 | assert log_probs.shape == x.shape 68 | return log_probs 69 | -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 
22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, int): 32 | section_counts = str(section_counts) 33 | if isinstance(section_counts, str): 34 | if section_counts.startswith("ddim"): 35 | desired_count = int(section_counts[len("ddim") :]) 36 | for i in range(1, num_timesteps): 37 | if len(range(0, num_timesteps, i)) == desired_count: 38 | return set(range(0, num_timesteps, i)) 39 | raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride") 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError(f"cannot divide section of {size} steps into {section_count}") 49 | if section_count <= 1: 50 | frac_stride = 1 51 | else: 52 | frac_stride = (size - 1) / (section_count - 1) 53 | cur_idx = 0.0 54 | taken_steps = [] 55 | for _ in range(section_count): 56 | taken_steps.append(start_idx + round(cur_idx)) 57 | cur_idx += frac_stride 58 | all_steps += taken_steps 59 | start_idx += size 60 | return set(all_steps) 61 | 62 | 63 | class SpacedDiffusion(GaussianDiffusion): 64 | """ 65 | A diffusion process which can skip steps in a base diffusion process. 66 | :param use_timesteps: a collection (sequence or set) of timesteps from the 67 | original diffusion process to retain. 
68 | :param kwargs: the kwargs to create the base diffusion process. 69 | """ 70 | 71 | def __init__(self, use_timesteps, **kwargs): 72 | self.use_timesteps = set(use_timesteps) 73 | self.timestep_map = [] 74 | self.original_num_steps = len(kwargs["betas"]) 75 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 76 | last_alpha_cumprod = 1.0 77 | new_betas = [] 78 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 79 | if i in self.use_timesteps: 80 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 81 | last_alpha_cumprod = alpha_cumprod 82 | self.timestep_map.append(i) 83 | kwargs["betas"] = np.array(new_betas) 84 | super().__init__(**kwargs) 85 | 86 | def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs 87 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 88 | 89 | def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs 90 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 91 | 92 | def condition_mean(self, cond_fn, *args, **kwargs): 93 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 94 | 95 | def condition_score(self, cond_fn, *args, **kwargs): 96 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 97 | 98 | def _wrap_model(self, model): 99 | if isinstance(model, _WrappedModel): 100 | return model 101 | return _WrappedModel(model, self.timestep_map, self.original_num_steps) 102 | 103 | def _scale_timesteps(self, t): 104 | # Scaling is done by the wrapped model. 
105 | return t 106 | 107 | 108 | class _WrappedModel: 109 | def __init__(self, model, timestep_map, original_num_steps): 110 | self.model = model 111 | self.timestep_map = timestep_map 112 | # self.rescale_timesteps = rescale_timesteps 113 | self.original_num_steps = original_num_steps 114 | 115 | def __call__(self, x, ts, **kwargs): 116 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 117 | new_ts = map_tensor[ts] 118 | # if self.rescale_timesteps: 119 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 120 | return self.model(x, new_ts, **kwargs) 121 | -------------------------------------------------------------------------------- /main_diffusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeTok: Generation model training script. 3 | """ 4 | 5 | import argparse 6 | import datetime 7 | import logging 8 | import sys 9 | import time 10 | 11 | import torch 12 | import torch.distributed 13 | 14 | import models 15 | import utils.distributed as distributed 16 | from utils.builders import create_generation_model, create_optimizer_and_scaler, create_train_dataloader 17 | from utils.misc import ckpt_resume, save_checkpoint 18 | from utils.train_utils import ( 19 | collect_tokenizer_stats, 20 | evaluate_generator, 21 | setup, 22 | train_one_epoch_generator, 23 | visualize_generator, 24 | visualize_tokenizer, 25 | ) 26 | 27 | # performance optimizations 28 | torch.backends.cuda.matmul.allow_tf32 = True 29 | torch.backends.cudnn.allow_tf32 = True 30 | torch.backends.cudnn.benchmark = True 31 | torch.backends.cudnn.deterministic = False 32 | 33 | logger = logging.getLogger("DeTok") 34 | 35 | 36 | def main(args: argparse.Namespace) -> int: 37 | global logger 38 | wandb_logger = setup(args) 39 | data_loader_train = create_train_dataloader(args) 40 | 41 | # initialize models 42 | model, tokenizer, ema_model = create_generation_model(args) 43 | optimizer, loss_scaler = 
create_optimizer_and_scaler(args, model) 44 | model_wo_ddp = model 45 | 46 | # handle token caching or tokenizer statistics collection 47 | if args.collect_tokenizer_stats: 48 | tmp_data_loader = create_train_dataloader( 49 | args, should_flip=False, batch_size=args.tokenizer_bsz, 50 | return_path=True, drop_last=False 51 | ) 52 | # (B, C, H, W) for chan_dim=2 or (B, seq_len, C) for chan_dim=1 53 | chan_dim = 2 if args.tokenizer in models.DeTok_models else 1 54 | 55 | # collect stats 56 | result_dict = collect_tokenizer_stats( 57 | tokenizer, tmp_data_loader, chan_dim=chan_dim, 58 | stats_dict_key=args.stats_key, 59 | stats_dict_path=args.stats_cache_path, 60 | overwrite_stats=args.overwrite_stats, 61 | ) 62 | # update tokenizer with computed statistics 63 | mean, std = result_dict["channel"] 64 | if mean.ndim > 0: 65 | n_chans = len(mean) // 2 66 | mean, std = mean[:n_chans], std[:n_chans] 67 | tokenizer.reset_stats(mean, std) 68 | 69 | del tmp_data_loader 70 | data_dict = next(iter(data_loader_train)) 71 | visualize_tokenizer(args, tokenizer, ema_model=None, data_dict=data_dict) 72 | 73 | # setup distributed training 74 | if distributed.is_enabled(): 75 | model = torch.nn.parallel.DistributedDataParallel(model) 76 | model_wo_ddp = model.module 77 | 78 | # resume from checkpoint if needed 79 | logger.info("Auto-resume enabled") 80 | ckpt_resume(args, model_wo_ddp, optimizer, loss_scaler, ema_model) 81 | 82 | # evaluation-only mode 83 | if args.evaluate: 84 | torch.cuda.empty_cache() 85 | cfg_list = args.cfg_list if args.cfg_list is not None else [args.cfg] 86 | for cfg in cfg_list: 87 | evaluate_generator( 88 | args, 89 | model_wo_ddp, 90 | ema_model, 91 | tokenizer, 92 | epoch=args.start_epoch, 93 | wandb_logger=wandb_logger, 94 | cfg=cfg, 95 | use_ema=True, # always use ema model for evaluation 96 | num_images=args.num_images, 97 | ) 98 | return 0 99 | 100 | # training loop 101 | logger.info(f"Start training from {args.start_epoch} to {args.epochs}") 102 | 
start_time = time.time() 103 | 104 | for epoch in range(args.start_epoch, args.epochs): 105 | train_one_epoch_generator( 106 | args, model, data_loader_train, optimizer, loss_scaler, wandb_logger, 107 | epoch, ema_model, tokenizer 108 | ) 109 | 110 | # progress logging 111 | elapsed_t = time.time() - start_time + args.last_elapsed_time 112 | eta = elapsed_t / (epoch + 1) * (args.epochs - epoch - 1) 113 | logger.info( 114 | f"[{epoch}/{args.epochs}] " 115 | f"Accumulated elapsed time: {str(datetime.timedelta(seconds=int(elapsed_t)))}, " 116 | f"ETA: {str(datetime.timedelta(seconds=int(eta)))}" 117 | ) 118 | 119 | # checkpointing 120 | should_save = ( 121 | (epoch + 1) % args.save_freq == 0 # save every n epochs 122 | or (epoch + 1) == args.epochs # save at the end of training 123 | ) 124 | 125 | if should_save: 126 | save_checkpoint(args, epoch, model_wo_ddp, optimizer, loss_scaler, ema_model, elapsed_t) 127 | torch.distributed.barrier() 128 | 129 | # periodic visualization 130 | if (epoch + 1) % args.vis_freq == 0: 131 | visualize_generator(args, model_wo_ddp, ema_model, tokenizer, epoch + 1) 132 | 133 | # online evaluation 134 | if args.online_eval and (epoch + 1) % args.eval_freq == 0: 135 | torch.cuda.empty_cache() 136 | evaluate_generator( 137 | args, model_wo_ddp, ema_model, tokenizer, epoch + 1, wandb_logger, 138 | use_ema=True, num_images=args.num_images_for_eval_and_search, cfg=args.cfg 139 | ) 140 | 141 | # final evaluation 142 | total_time = int(time.time() - start_time + args.last_elapsed_time) 143 | logger.info(f"Training time {str(datetime.timedelta(seconds=total_time))}") 144 | 145 | 146 | # determine cfg values for evaluation 147 | cfg_list = args.cfg_list or [args.cfg] # use the cfg from the args if not provided 148 | best_cfg = cfg_list[0] 149 | 150 | if len(cfg_list) > 1: 151 | # search the best cfg value using 10k images 152 | fid_dict = {} 153 | for cfg in cfg_list: 154 | fid_dict[cfg] = evaluate_generator( 155 | args, model_wo_ddp, ema_model, 
tokenizer, args.epochs + 1, wandb_logger, 156 | use_ema=True, cfg=cfg, num_images=args.num_images_for_eval_and_search 157 | ) 158 | # find best cfg value and broadcast to all ranks 159 | if distributed.is_main_process(): 160 | best_fid = 100000 161 | for cfg in cfg_list: 162 | if fid_dict[cfg]["fid"] < best_fid: 163 | best_fid = fid_dict[cfg]["fid"] 164 | best_cfg = cfg 165 | logger.info(f"Best FID: {best_fid}, Best cfg: {best_cfg}") 166 | 167 | # broadcast best_cfg from rank 0 to all ranks 168 | if distributed.is_enabled(): 169 | best_cfg_tensor = torch.tensor([best_cfg], dtype=torch.float32, device="cuda") 170 | torch.distributed.broadcast(best_cfg_tensor, src=0) 171 | best_cfg = best_cfg_tensor.item() 172 | torch.distributed.barrier() 173 | 174 | # final comprehensive evaluation with best cfg 175 | args.num_iter = 128 if args.tokenizer == "maetok-b-128" else 256 176 | evaluate_generator( 177 | args, model_wo_ddp, ema_model, tokenizer, args.epochs + 1, wandb_logger, 178 | use_ema=True, cfg=best_cfg, num_images=args.num_images 179 | ) 180 | 181 | # additional evaluation with cfg=1.0 182 | evaluate_generator( 183 | args, model_wo_ddp, ema_model, tokenizer, args.epochs + 1, wandb_logger, 184 | use_ema=True, cfg=1.0, num_images=args.num_images 185 | ) 186 | 187 | return 0 188 | 189 | 190 | def get_args_parser(): 191 | parser = argparse.ArgumentParser("Generation model training", add_help=False) 192 | 193 | # basic training parameters 194 | parser.add_argument("--start_epoch", default=0, type=int) 195 | parser.add_argument("--epochs", default=400, type=int) 196 | parser.add_argument("--batch_size", default=64, type=int, help="Batch size per GPU for training") 197 | 198 | # model parameters 199 | parser.add_argument("--model", default="MAR_base", type=str) 200 | parser.add_argument("--order", default="raster", type=str) 201 | parser.add_argument("--patch_size", default=1, type=int) 202 | parser.add_argument("--no_dropout_in_mlp", action="store_true") 203 | 
parser.add_argument("--qk_norm", action="store_true") 204 | parser.add_argument("--force_one_d_seq", type=int, default=0, help="1d tokens, e.g., 128 for MAETok") 205 | parser.add_argument("--legacy_mode", action="store_true") 206 | 207 | # tokenizer parameters 208 | parser.add_argument("--img_size", default=256, type=int) 209 | parser.add_argument("--tokenizer", default=None, type=str) 210 | parser.add_argument("--token_channels", default=16, type=int) 211 | parser.add_argument("--tokenizer_patch_size", default=16, type=int) 212 | parser.add_argument("--use_ema_tokenizer", action="store_true") 213 | 214 | # tokenizer cache parameters 215 | parser.add_argument("--collect_tokenizer_stats", action="store_true") 216 | parser.add_argument("--tokenizer_bsz", default=256, type=int) 217 | parser.add_argument("--cached_path", type=str, default="data/imagenet_tokens/") 218 | parser.add_argument("--stats_key", type=str, default=None) 219 | parser.add_argument("--overwrite_stats", action="store_true") 220 | parser.add_argument("--stats_cache_path", type=str, default="work_dirs/stats.pkl") 221 | 222 | # logging parameters 223 | parser.add_argument("--output_dir", default="./work_dirs") 224 | parser.add_argument("--print_freq", type=int, default=100) 225 | parser.add_argument("--eval_freq", type=int, default=40) 226 | parser.add_argument("--vis_freq", type=int, default=10) 227 | parser.add_argument("--save_freq", type=int, default=1) 228 | parser.add_argument("--last_elapsed_time", type=float, default=0.0) 229 | 230 | # checkpoint parameters 231 | parser.add_argument("--auto_resume", action="store_true") 232 | parser.add_argument("--resume_from", default=None, help="resume model weights and optimizer state") 233 | parser.add_argument("--load_from", type=str, default=None, help="load from pretrained model") 234 | parser.add_argument("--load_tokenizer_from", type=str, default=None, help="load from pretrained tokenizer") 235 | parser.add_argument("--keep_n_ckpts", default=1, 
type=int, help="keep the last n checkpoints") 236 | parser.add_argument("--milestone_interval", default=100, type=int, help="keep checkpoints every n epochs") 237 | 238 | # evaluation parameters 239 | parser.add_argument("--num_images_for_eval_and_search", default=10000, type=int) 240 | parser.add_argument("--num_images", default=50000, type=int) 241 | parser.add_argument("--online_eval", action="store_true") 242 | parser.add_argument("--fid_stats_path", type=str, default="data/fid_stats/adm_in256_stats.npz") 243 | parser.add_argument("--keep_eval_folder", action="store_true") 244 | parser.add_argument("--evaluate", action="store_true") 245 | parser.add_argument("--eval_bsz", type=int, default=256) 246 | 247 | # optimization parameters 248 | parser.add_argument("--lr", type=float, default=None) 249 | parser.add_argument("--blr", type=float, default=1e-4) 250 | parser.add_argument("--min_lr", type=float, default=1e-6) 251 | parser.add_argument("--lr_sched", type=str, default="constant", choices=["constant", "cosine"]) 252 | parser.add_argument("--warmup_rate", type=float, default=0.25, help="warmup_ep = warmup_rate * total_ep") 253 | parser.add_argument("--ema_rate", default=0.9999, type=float) 254 | parser.add_argument("--weight_decay", type=float, default=0.02) 255 | parser.add_argument("--grad_clip", type=float, default=3.0) 256 | parser.add_argument("--grad_checkpointing", action="store_true") 257 | parser.add_argument("--beta1", type=float, default=0.9) 258 | parser.add_argument("--beta2", type=float, default=0.95) 259 | parser.add_argument("--use_aligned_schedule", action="store_true") 260 | 261 | # generation parameters 262 | parser.add_argument("--num_iter", default=64, type=int, help="number of autoregressive steps for MAR") 263 | parser.add_argument("--noise_schedule", type=str, default="cosine", help="noise schedule for diffusion") 264 | parser.add_argument("--cfg", default=4.0, type=float, help="cfg value for diffusion") 265 | 
parser.add_argument("--cfg_schedule", default="linear", type=str, help="cfg schedule for diffusion") 266 | parser.add_argument("--cfg_list", default=None, type=float, nargs="+", help="cfg list for search") 267 | 268 | # mar parameters 269 | parser.add_argument("--label_drop_prob", default=0.1, type=float) 270 | parser.add_argument("--mask_ratio_min", type=float, default=0.7) 271 | parser.add_argument("--attn_dropout", type=float, default=0.1) 272 | parser.add_argument("--proj_dropout", type=float, default=0.1) 273 | parser.add_argument("--buffer_size", type=int, default=64) 274 | 275 | # diffusion loss parameters 276 | parser.add_argument("--diffloss_d", type=int, default=3) 277 | parser.add_argument("--diffloss_w", type=int, default=1024) 278 | parser.add_argument("--num_sampling_steps", type=str, default="100") 279 | parser.add_argument("--diffusion_batch_mul", type=int, default=4) 280 | parser.add_argument("--temperature", default=1.0, type=float) 281 | 282 | # dataset parameters 283 | parser.add_argument("--use_cached_tokens", action="store_true") 284 | parser.add_argument("--data_path", default="./data/imagenet/train", type=str) 285 | parser.add_argument("--num_classes", default=1000, type=int) 286 | parser.add_argument("--class_of_interest", default=[207, 360, 387, 974, 88, 979, 417, 279], type=int, nargs="+") 287 | parser.add_argument("--force_class_of_interest", action="store_true", 288 | help="generate images of only the class of interest for args.num_images images") 289 | parser.add_argument("--num_workers", default=10, type=int) 290 | parser.add_argument("--pin_mem", action="store_true") 291 | parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem") 292 | parser.set_defaults(pin_mem=True) 293 | 294 | # system parameters 295 | parser.add_argument("--seed", default=1, type=int) 296 | 297 | # wandb parameters 298 | parser.add_argument("--project", default="lDeTok", type=str) 299 | parser.add_argument("--entity", default="YOUR_WANDB_ENTITY", 
type=str) 300 | parser.add_argument("--exp_name", default=None, type=str) 301 | parser.add_argument("--enable_wandb", action="store_true") 302 | 303 | return parser 304 | 305 | 306 | if __name__ == "__main__": 307 | args = get_args_parser().parse_args() 308 | exit_code = main(args) 309 | sys.exit(exit_code) 310 | -------------------------------------------------------------------------------- /main_reconstruction.py: -------------------------------------------------------------------------------- 1 | """ 2 | DeTok: Reconstruction model training script. 3 | """ 4 | 5 | import argparse 6 | import datetime 7 | import logging 8 | import sys 9 | import time 10 | 11 | import torch 12 | import torch.distributed 13 | 14 | import utils.distributed as distributed 15 | from utils.builders import ( 16 | create_loss_module, 17 | create_optimizer_and_scaler, 18 | create_reconstruction_model, 19 | create_train_dataloader, 20 | create_val_dataloader, 21 | create_vis_dataloader, 22 | ) 23 | from utils.misc import ckpt_resume, save_checkpoint 24 | from utils.train_utils import evaluate_tokenizer, setup, train_one_epoch_tokenizer, visualize_tokenizer 25 | 26 | # performance optimizations 27 | torch.backends.cuda.matmul.allow_tf32 = True 28 | torch.backends.cudnn.allow_tf32 = True 29 | torch.backends.cudnn.benchmark = True 30 | torch.backends.cudnn.deterministic = False 31 | 32 | logger = logging.getLogger("DeTok") 33 | 34 | 35 | def main(args: argparse.Namespace) -> int: 36 | global logger 37 | wandb_logger = setup(args) 38 | 39 | # initialize data loaders 40 | data_loader_train = create_train_dataloader(args) 41 | data_loader_val = create_val_dataloader(args) 42 | data_loader_vis = create_vis_dataloader(args) 43 | vis_iterator = iter(data_loader_vis) 44 | 45 | # initialize models and optimizers 46 | model, ema_model = create_reconstruction_model(args) 47 | if args.train_decoder_only and hasattr(model, "freeze_everything_but_decoder"): 48 | model.freeze_everything_but_decoder() 49 | 
50 | optimizer, loss_scaler = create_optimizer_and_scaler(args, model) 51 | loss_fn = create_loss_module(args) 52 | discriminator_optimizer, discriminator_loss_scaler = create_optimizer_and_scaler(args, loss_fn) 53 | 54 | # setup distributed training 55 | if distributed.is_enabled(): 56 | model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True) 57 | loss_fn = torch.nn.parallel.DistributedDataParallel(loss_fn, find_unused_parameters=True) 58 | 59 | # get models without DDP wrapper 60 | model_wo_ddp = model.module if hasattr(model, "module") else model 61 | loss_module_wo_ddp = loss_fn.module if hasattr(loss_fn, "module") else loss_fn 62 | 63 | # resume from checkpoint if needed 64 | ckpt_resume( 65 | args, model_wo_ddp, optimizer, loss_scaler, ema_model, 66 | loss_module_wo_ddp, discriminator_optimizer, discriminator_loss_scaler 67 | ) 68 | 69 | # initial visualization 70 | visualize_tokenizer(args, model_wo_ddp, ema_model, next(vis_iterator), args.start_epoch) 71 | 72 | if args.vis_only: 73 | return 0 74 | 75 | # evaluation-only mode 76 | if args.evaluate: 77 | torch.cuda.empty_cache() 78 | for use_ema in [False, True]: 79 | evaluate_tokenizer( 80 | args, model_wo_ddp, ema_model, data_loader_val, args.start_epoch, wandb_logger, use_ema 81 | ) 82 | return 0 83 | 84 | # training loop 85 | logger.info(f"Start training from {args.start_epoch} to {args.epochs}") 86 | start_time = time.time() 87 | 88 | for epoch in range(args.start_epoch, args.epochs): 89 | train_one_epoch_tokenizer( 90 | args, model, data_loader_train, optimizer, loss_scaler, wandb_logger, epoch, 91 | ema_model, loss_fn, discriminator_optimizer, discriminator_loss_scaler 92 | ) 93 | 94 | # progress logging 95 | elapsed_t = time.time() - start_time + args.last_elapsed_time 96 | eta = elapsed_t / (epoch + 1) * (args.epochs - epoch - 1) 97 | logger.info( 98 | f"[{epoch}/{args.epochs}] " 99 | f"Accumulated elapsed time: {str(datetime.timedelta(seconds=int(elapsed_t)))}, " 100 | 
f"ETA: {str(datetime.timedelta(seconds=int(eta)))}" 101 | ) 102 | 103 | # checkpointing 104 | should_save = ( 105 | (epoch + 1) % args.save_freq == 0 # save every n epochs 106 | or (epoch + 1) == args.epochs # save at the end of training 107 | ) 108 | 109 | if should_save: 110 | save_checkpoint( 111 | args, epoch, model_wo_ddp, optimizer, loss_scaler, ema_model, elapsed_t, 112 | loss_module_wo_ddp, discriminator_optimizer, discriminator_loss_scaler 113 | ) 114 | torch.distributed.barrier() 115 | 116 | # periodic visualization 117 | if (epoch + 1) % args.vis_freq == 0: 118 | visualize_tokenizer(args, model_wo_ddp, ema_model, next(vis_iterator), epoch) 119 | 120 | # online evaluation 121 | if (args.online_eval and (epoch + 1) % args.eval_freq == 0 and (epoch + 1) != args.epochs): 122 | torch.cuda.empty_cache() 123 | for use_ema in [False, True]: 124 | evaluate_tokenizer( 125 | args, model_wo_ddp, ema_model, data_loader_val, epoch + 1, wandb_logger, use_ema 126 | ) 127 | 128 | # final evaluation 129 | total_time = int(time.time() - start_time + args.last_elapsed_time) 130 | logger.info(f"Training time {str(datetime.timedelta(seconds=total_time))}") 131 | 132 | for use_ema in [False, True]: 133 | evaluate_tokenizer(args, model_wo_ddp, ema_model, data_loader_val, args.epochs, wandb_logger, use_ema) 134 | 135 | return 0 136 | 137 | 138 | def get_args_parser(): 139 | parser = argparse.ArgumentParser("Reconstruction model training", add_help=False) 140 | 141 | # basic training parameters 142 | parser.add_argument("--start_epoch", default=0, type=int) 143 | parser.add_argument("--epochs", default=200, type=int) 144 | parser.add_argument("--batch_size", default=64, type=int, help="Batch size per GPU for training") 145 | 146 | # model parameters 147 | parser.add_argument("--model", default="detok_BB", type=str) 148 | parser.add_argument("--token_channels", default=16, type=int) 149 | parser.add_argument("--img_size", default=256, type=int) 150 | 
parser.add_argument("--patch_size", default=16, type=int) 151 | 152 | parser.add_argument("--mask_ratio", default=0.0, type=float) 153 | parser.add_argument("--gamma", default=0.0, type=float, help="noise standard deviation for training") 154 | parser.add_argument("--use_additive_noise", action="store_true") 155 | 156 | parser.add_argument("--no_load_ckpt", action="store_true") 157 | parser.add_argument("--train_decoder_only", action="store_true") 158 | parser.add_argument("--vis_only", action="store_true") 159 | 160 | # loss parameters 161 | parser.add_argument("--perceptual_loss", type=str, default="lpips-convnext_s-1.0-0.1") 162 | parser.add_argument("--perceptual_weight", default=1.0, type=float) 163 | parser.add_argument("--discriminator_start_epoch", default=20, type=int) 164 | parser.add_argument("--discriminator_weight", default=0.5, type=float) 165 | parser.add_argument("--kl_loss_weight", default=1e-6, type=float) 166 | 167 | # logging parameters 168 | parser.add_argument("--output_dir", default="./work_dirs") 169 | parser.add_argument("--print_freq", type=int, default=100) 170 | parser.add_argument("--eval_freq", type=int, default=10) 171 | parser.add_argument("--vis_freq", type=int, default=5) 172 | parser.add_argument("--save_freq", type=int, default=1) 173 | parser.add_argument("--last_elapsed_time", type=float, default=0.0) 174 | 175 | # checkpoint parameters 176 | parser.add_argument("--auto_resume", action="store_true") 177 | parser.add_argument("--resume_from", default=None, help="resume model weights and optimizer state") 178 | parser.add_argument("--load_from", type=str, default=None, help="load from pretrained model") 179 | parser.add_argument("--keep_n_ckpts", default=1, type=int, help="keep the last n checkpoints") 180 | parser.add_argument("--milestone_interval", default=100, type=int, help="keep checkpoints every n epochs") 181 | 182 | 183 | # evaluation parameters 184 | parser.add_argument("--num_images", default=50000, type=int, 
help="Number of images to evaluate on") 185 | parser.add_argument("--online_eval", action="store_true") 186 | parser.add_argument("--fid_stats_path", type=str, default="data/fid_stats/val_fid_statistics_file.npz") 187 | parser.add_argument("--keep_eval_folder", action="store_true") 188 | parser.add_argument("--evaluate", action="store_true") 189 | parser.add_argument("--eval_bsz", type=int, default=256) 190 | 191 | # optimization parameters 192 | parser.add_argument("--lr", type=float, default=None) 193 | parser.add_argument("--blr", type=float, default=1e-4) 194 | parser.add_argument("--min_lr", type=float, default=0.0) 195 | parser.add_argument("--lr_sched", type=str, default="cosine", choices=["constant", "cosine"]) 196 | parser.add_argument("--warmup_rate", type=float, default=0.25) 197 | parser.add_argument("--ema_rate", default=0.999, type=float) 198 | parser.add_argument("--weight_decay", type=float, default=1e-4) 199 | parser.add_argument("--grad_clip", type=float, default=3.0) 200 | parser.add_argument("--grad_checkpointing", action="store_true", help="Use gradient checkpointing") 201 | parser.add_argument("--beta1", type=float, default=0.9, help="Beta1 for AdamW optimizer") 202 | parser.add_argument("--beta2", type=float, default=0.95, help="Beta2 for AdamW optimizer") 203 | 204 | # dataset parameters 205 | parser.add_argument("--use_cached_tokens", action="store_true") 206 | parser.add_argument("--data_path", default="./data/imagenet/train", type=str) 207 | parser.add_argument("--num_classes", default=1000, type=int) 208 | parser.add_argument("--class_of_interest", default=[207, 360, 387, 974, 88, 979, 417, 279], type=int, nargs="+") 209 | parser.add_argument("--num_workers", default=10, type=int) 210 | parser.add_argument("--pin_mem", action="store_true") 211 | parser.add_argument("--no_pin_mem", action="store_false", dest="pin_mem") 212 | parser.set_defaults(pin_mem=True) 213 | 214 | # system parameters 215 | parser.add_argument("--seed", default=1, 
type=int) 216 | 217 | # wandb parameters 218 | parser.add_argument("--project", default="lDeTok", type=str) 219 | parser.add_argument("--entity", default="YOUR_WANDB_ENTITY", type=str) 220 | parser.add_argument("--exp_name", default=None, type=str) 221 | parser.add_argument("--enable_wandb", action="store_true") 222 | 223 | return parser 224 | 225 | 226 | if __name__ == "__main__": 227 | args = get_args_parser().parse_args() 228 | exit_code = main(args) 229 | sys.exit(exit_code) 230 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from .ardiff import ARDiff_models 2 | from .autoencoder import AutoencoderKL, VAE_models 3 | from .detok import DeTok_models 4 | from .dit import DiT_models 5 | from .ema import SimpleEMAModel 6 | from .lightningdit import LightningDiT_models 7 | from .mar import MAR_models 8 | from .sit import SiT_models 9 | -------------------------------------------------------------------------------- /models/ardiff.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.checkpoint import checkpoint 7 | from tqdm import tqdm 8 | 9 | from .diffloss import DiffLoss 10 | from .layers import Block, Transformer, modulate 11 | from .model_utils import SIZE_DICT 12 | 13 | logger = logging.getLogger("DeTok") 14 | 15 | 16 | class FinalLayer(nn.Module): 17 | """final layer with adaptive layer normalization.""" 18 | 19 | def __init__(self, in_features) -> None: 20 | super().__init__() 21 | self.norm = nn.LayerNorm(in_features, elementwise_affine=False, eps=1e-6) 22 | self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(in_features, 2 * in_features)) 23 | 24 | def forward(self, x, condition): 25 | shift, scale = self.adaLN_modulation(condition).chunk(2, dim=-1) 26 | x = 
modulate(self.norm(x), shift, scale) 27 | return x 28 | 29 | 30 | class ARDiff(nn.Module): 31 | """decoder-only autoregressive diffusion model.""" 32 | 33 | def __init__( 34 | self, 35 | img_size=256, 36 | patch_size=1, 37 | model_size="base", 38 | tokenizer_patch_size=16, 39 | token_channels=16, 40 | label_drop_prob=0.1, 41 | num_classes=1000, 42 | # diffloss parameters 43 | noise_schedule="cosine", 44 | diffloss_d=3, 45 | diffloss_w=1024, 46 | diffusion_batch_mul=4, 47 | # sampling parameters 48 | num_sampling_steps=100, 49 | grad_checkpointing=False, 50 | force_one_d_seq=False, 51 | order="raster", 52 | ): 53 | super().__init__() 54 | 55 | # -------------------------------------------------------------------------- 56 | # basic configuration 57 | self.img_size = img_size 58 | self.patch_size = patch_size 59 | self.token_channels = token_channels 60 | self.num_classes = num_classes 61 | self.label_drop_prob = label_drop_prob 62 | self.grad_checkpointing = grad_checkpointing 63 | self.force_one_d_seq = force_one_d_seq 64 | self.order = order 65 | self.diffusion_batch_mul = diffusion_batch_mul 66 | 67 | # sequence dimensions 68 | self.seq_h = self.seq_w = img_size // tokenizer_patch_size // patch_size 69 | self.seq_len = self.seq_h * self.seq_w + 1 # +1 for BOS token 70 | self.token_embed_dim = token_channels * patch_size**2 71 | 72 | if force_one_d_seq: 73 | self.seq_len = force_one_d_seq + 1 74 | 75 | # model architecture configuration 76 | size_dict = SIZE_DICT[model_size] 77 | num_layers, num_heads, width = size_dict["layers"], size_dict["heads"], size_dict["width"] 78 | 79 | self.label_drop_prob = label_drop_prob 80 | 81 | scale = width**-0.5 82 | 83 | # class and null token embeddings 84 | self.class_emb = nn.Embedding(self.num_classes, width) 85 | self.fake_latent = nn.Parameter(scale * torch.randn(1, width)) 86 | self.bos_token = nn.Parameter(torch.zeros(1, 1, width)) 87 | 88 | # input and positional embeddings 89 | self.x_embedder = 
nn.Linear(self.token_embed_dim, width) 90 | self.pos_embed = nn.Parameter(scale * torch.randn((1, self.seq_len, width))) 91 | self.target_pos_embed = nn.Parameter(scale * torch.randn((1, self.seq_len - 1, width))) 92 | self.timesteps_embeddings = nn.Parameter(scale * torch.randn((1, self.seq_len, width))) 93 | 94 | # training mask for causal attention 95 | self.train_mask = torch.tril(torch.ones(self.seq_len, self.seq_len, dtype=torch.bool)).cuda() 96 | 97 | # -------------------------------------------------------------------------- 98 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 99 | 100 | self.ln_pre = norm_layer(width) 101 | self.transformer = Transformer( 102 | width, 103 | num_layers, 104 | num_heads, 105 | block_fn=partial(Block, use_modulation=True), 106 | norm_layer=norm_layer, 107 | force_causal=True, 108 | grad_checkpointing=self.grad_checkpointing, 109 | ) 110 | self.final_layer = FinalLayer(width) 111 | self.initialize_weights() 112 | 113 | # -------------------------------------------------------------------------- 114 | # Diffusion Loss 115 | self.diffloss = DiffLoss( 116 | target_channels=self.token_embed_dim, 117 | z_channels=width, 118 | width=diffloss_w, 119 | depth=diffloss_d, 120 | num_sampling_steps=num_sampling_steps, 121 | grad_checkpointing=grad_checkpointing, 122 | noise_schedule=noise_schedule, 123 | ) 124 | self.diffusion_batch_mul = diffusion_batch_mul 125 | params_M = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 126 | logger.info(f"[ARDiff] params: {params_M:.2f}M, {model_size}-{num_layers}-{width}") 127 | 128 | def initialize_weights(self): 129 | """initialize model weights.""" 130 | # parameter initialization 131 | torch.nn.init.normal_(self.pos_embed, std=0.02) 132 | torch.nn.init.normal_(self.bos_token, std=0.02) 133 | torch.nn.init.normal_(self.target_pos_embed, std=0.02) 134 | torch.nn.init.normal_(self.timesteps_embeddings, std=0.02) 135 | torch.nn.init.normal_(self.class_emb.weight, std=0.02) 136 | 
torch.nn.init.normal_(self.fake_latent, std=0.02) 137 | 138 | # apply standard initialization 139 | self.apply(self._init_weights) 140 | 141 | def _init_weights(self, m): 142 | """standard weight initialization for layers.""" 143 | if isinstance(m, nn.Linear): 144 | torch.nn.init.xavier_uniform_(m.weight) 145 | if m.bias is not None: 146 | nn.init.constant_(m.bias, 0) 147 | elif isinstance(m, nn.LayerNorm): 148 | if m.bias is not None: 149 | nn.init.constant_(m.bias, 0) 150 | if m.weight is not None: 151 | nn.init.constant_(m.weight, 1.0) 152 | 153 | # zero-out adaptive modulation layers 154 | for block in self.transformer.blocks: 155 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 156 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 157 | 158 | # zero-out final layer modulation 159 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 160 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 161 | 162 | def patchify(self, x): 163 | """convert image tensor to patch tokens.""" 164 | bsz, c, h, w = x.shape 165 | p = self.patch_size 166 | h_, w_ = h // p, w // p 167 | 168 | x = x.reshape(bsz, c, h_, p, w_, p) 169 | x = torch.einsum("nchpwq->nhwcpq", x) 170 | x = x.reshape(bsz, h_ * w_, c * p**2) 171 | return x # [batch, seq_len, token_dim] 172 | 173 | def unpatchify(self, x): 174 | """convert patch tokens back to image tensor.""" 175 | bsz = x.shape[0] 176 | p = self.patch_size 177 | c = self.token_channels 178 | h_, w_ = self.seq_h, self.seq_w 179 | 180 | x = x.reshape(bsz, h_, w_, c, p, p) 181 | x = torch.einsum("nhwcpq->nchpwq", x) 182 | x = x.reshape(bsz, c, h_ * p, w_ * p) 183 | return x # [batch, channels, height, width] 184 | 185 | def enable_kv_cache(self): 186 | for block in self.transformer.blocks: 187 | block.attn.kv_cache = True 188 | block.attn.reset_kv_cache() 189 | logger.info("Enable kv_cache for Transformer blocks") 190 | 191 | def disable_kv_cache(self): 192 | for block in self.transformer.blocks: 193 | 
block.attn.kv_cache = False 194 | block.attn.reset_kv_cache() 195 | logger.info("Disable kv_cache for Transformer blocks") 196 | 197 | def get_random_orders(self, x): 198 | """generate random token ordering.""" 199 | batch_size = x.shape[0] 200 | random_noise = torch.randn(batch_size, self.seq_len - 1, device=x.device) 201 | shuffled_orders = torch.argsort(random_noise, dim=1) 202 | return shuffled_orders 203 | 204 | def get_raster_orders(self, x): 205 | """generate raster (sequential) token ordering.""" 206 | batch_size = x.shape[0] 207 | raster_orders = torch.arange(self.seq_len - 1, device=x.device) 208 | shuffled_orders = torch.stack([raster_orders for _ in range(batch_size)]) 209 | return shuffled_orders 210 | 211 | def shuffle(self, x, orders): 212 | """shuffle tokens according to given orders.""" 213 | batch_size, seq_len = x.shape[:2] 214 | batch_indices = torch.arange(batch_size).unsqueeze(1).expand(-1, seq_len) 215 | shuffled_x = x[batch_indices, orders] 216 | return shuffled_x 217 | 218 | def unshuffle(self, shuffled_x, orders): 219 | """unshuffle tokens to restore original ordering.""" 220 | batch_size, seq_len = shuffled_x.shape[:2] 221 | batch_indices = torch.arange(batch_size).unsqueeze(1).expand(-1, seq_len) 222 | unshuffled_x = torch.zeros_like(shuffled_x) 223 | unshuffled_x[batch_indices, orders] = shuffled_x 224 | return unshuffled_x 225 | 226 | def forward_transformer(self, x, class_embedding, orders=None): 227 | """forward pass through the transformer.""" 228 | x = self.x_embedder(x) 229 | bsz = x.shape[0] 230 | 231 | # add BOS token 232 | bos_token = self.bos_token.expand(bsz, 1, -1) 233 | x = torch.cat([bos_token, x], dim=1) 234 | current_seq_len = x.shape[1] 235 | 236 | # add positional embeddings 237 | pos_embed = self.pos_embed.expand(bsz, -1, -1) 238 | if orders is not None: 239 | pos_embed = torch.cat([pos_embed[:, :1], self.shuffle(pos_embed[:, 1:], orders)], dim=1) 240 | x = x + pos_embed[:, :current_seq_len] 241 | 242 | # add target 
positional embeddings 243 | target_pos_embed = self.target_pos_embed.expand(bsz, -1, -1) 244 | embed_dim = target_pos_embed.shape[-1] 245 | if orders is not None: 246 | target_pos_embed = self.shuffle(target_pos_embed, orders) 247 | target_pos_embed = torch.cat([target_pos_embed, torch.zeros(bsz, 1, embed_dim).to(x.device)], dim=1) 248 | x = x + target_pos_embed[:, :current_seq_len] 249 | 250 | x = self.ln_pre(x) 251 | 252 | # prepare condition tokens 253 | condition_token = class_embedding.repeat(1, current_seq_len, 1) 254 | timestep_embed = self.timesteps_embeddings.expand(bsz, -1, -1) 255 | condition_token = condition_token + timestep_embed[:, :current_seq_len] 256 | 257 | # handle kv cache for inference 258 | if self.transformer.blocks[0].attn.kv_cache: 259 | x = x[:, -1:] 260 | condition_token = condition_token[:, -1:] 261 | 262 | # transformer forward pass 263 | for block in self.transformer.blocks: 264 | if self.grad_checkpointing and self.training: 265 | x = checkpoint(block, x, None, None, condition_token) 266 | else: 267 | x = block(x, condition=condition_token) 268 | 269 | x = self.final_layer(x, condition=class_embedding) 270 | return x 271 | 272 | def forward_loss(self, z, target): 273 | """compute diffusion loss.""" 274 | bsz, seq_len, _ = target.shape 275 | target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1) 276 | z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1) 277 | return self.diffloss(z=z, target=target) 278 | 279 | def forward(self, x, labels): 280 | """forward pass for training.""" 281 | # get token ordering 282 | if self.order == "raster": 283 | orders = self.get_raster_orders(x) 284 | elif self.order == "random": 285 | orders = self.get_random_orders(x) 286 | else: 287 | raise NotImplementedError(f"Order '{self.order}' not implemented") 288 | 289 | # prepare class embeddings 290 | class_embedding = self.class_emb(labels) 291 | if self.training: 292 | # randomly drop class embedding during 
training 293 | drop_mask = torch.rand(x.shape[0]) < self.label_drop_prob 294 | drop_mask = drop_mask.unsqueeze(-1).cuda().to(x.dtype) 295 | class_embedding = drop_mask * self.fake_latent + (1 - drop_mask) * class_embedding 296 | class_embedding = class_embedding.unsqueeze(1) 297 | 298 | # prepare input tokens 299 | x = self.patchify(x) if not self.force_one_d_seq else x 300 | x = self.shuffle(x, orders) 301 | gt_latents = x.clone().detach() 302 | 303 | # forward pass and loss computation 304 | z = self.forward_transformer(x[:, :-1], class_embedding, orders=orders) 305 | return self.forward_loss(z=z, target=gt_latents) 306 | 307 | def sample_tokens( 308 | self, 309 | bsz, 310 | cfg=1.0, 311 | cfg_schedule="linear", 312 | labels=None, 313 | temperature=1.0, 314 | progress=False, 315 | kv_cache=False, 316 | ): 317 | """sample tokens autoregressively.""" 318 | tokens = torch.zeros(bsz, 0, self.token_embed_dim).cuda() 319 | indices = list(range(self.seq_len - 1)) 320 | 321 | # setup kv cache if requested 322 | if kv_cache: 323 | self.enable_kv_cache() 324 | 325 | if progress: 326 | indices = tqdm(indices) 327 | 328 | # get token ordering 329 | if self.order == "raster": 330 | orders = self.get_raster_orders(torch.zeros(bsz, self.seq_len - 1, self.token_embed_dim).cuda()) 331 | elif self.order == "random": 332 | orders = self.get_random_orders(torch.zeros(bsz, self.seq_len - 1, self.token_embed_dim).cuda()) 333 | else: 334 | raise NotImplementedError(f"Order '{self.order}' not implemented") 335 | 336 | # prepare for classifier-free guidance 337 | if cfg != 1.0: 338 | orders = torch.cat([orders, orders], dim=0) 339 | 340 | # generate tokens step by step 341 | for step in indices: 342 | cur_tokens = tokens.clone() 343 | 344 | # prepare class embeddings and CFG 345 | cls_embd = self.fake_latent.repeat(bsz, 1) if labels is None else self.class_emb(labels) 346 | 347 | if cfg != 1.0: 348 | tokens = torch.cat([tokens, tokens], dim=0) 349 | cls_embd = torch.cat([cls_embd, 
self.fake_latent.repeat(bsz, 1)], dim=0) 350 | cls_embd = cls_embd.unsqueeze(1) 351 | z = self.forward_transformer(tokens, cls_embd, orders=orders)[:, -1] 352 | 353 | # apply CFG schedule 354 | if cfg_schedule == "linear": 355 | cfg_iter = 1 + (cfg - 1) * step / len(indices) 356 | elif cfg_schedule == "constant": 357 | cfg_iter = cfg 358 | else: 359 | raise NotImplementedError(f"CFG schedule '{cfg_schedule}' not implemented") 360 | 361 | # sample next token 362 | sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter) 363 | 364 | if cfg != 1.0: 365 | sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) 366 | 367 | cur_tokens = torch.cat([cur_tokens, sampled_token_latent.unsqueeze(1)], dim=1) 368 | tokens = cur_tokens.clone() 369 | 370 | # cleanup 371 | if kv_cache: 372 | self.disable_kv_cache() 373 | 374 | if cfg != 1.0: 375 | orders, _ = orders.chunk(2, dim=0) 376 | 377 | # restore original ordering and convert back to image format 378 | tokens = self.unshuffle(tokens, orders) 379 | if not self.force_one_d_seq: 380 | tokens = self.unpatchify(tokens) 381 | 382 | return tokens 383 | 384 | def generate(self, n_samples, cfg, labels, args): 385 | """generate samples using the model.""" 386 | return self.sample_tokens( 387 | n_samples, 388 | cfg=cfg, 389 | labels=labels, 390 | cfg_schedule=args.cfg_schedule, 391 | temperature=args.temperature, 392 | progress=True, 393 | kv_cache=False, 394 | ) 395 | 396 | 397 | # model size variants 398 | def ARDiff_base(**kwargs): 399 | return ARDiff(model_size="base", **kwargs) 400 | 401 | 402 | def ARDiff_large(**kwargs): 403 | return ARDiff(model_size="large", **kwargs) 404 | 405 | 406 | def ARDiff_xl(**kwargs): 407 | return ARDiff(model_size="xl", **kwargs) 408 | 409 | 410 | def ARDiff_huge(**kwargs): 411 | return ARDiff(model_size="huge", **kwargs) 412 | 413 | 414 | ARDiff_models = { 415 | "ARDiff_base": ARDiff_base, 416 | "ARDiff_large": ARDiff_large, 417 | "ARDiff_huge": ARDiff_huge, 418 | "ARDiff_xl": 
ARDiff_xl, 419 | } 420 | -------------------------------------------------------------------------------- /models/diffloss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/LTH14/mar/blob/main/models/diffloss.py 3 | """ 4 | import torch 5 | import torch.nn as nn 6 | from torch.utils.checkpoint import checkpoint 7 | 8 | from diffusion import create_diffusion 9 | from transport import Sampler, create_transport 10 | 11 | from .layers import ModulatedLinear, TimestepEmbedder, modulate 12 | 13 | 14 | class DiffLoss(nn.Module): 15 | """diffusion loss module for training.""" 16 | 17 | def __init__( 18 | self, 19 | target_channels, 20 | z_channels, 21 | depth, 22 | width, 23 | num_sampling_steps, 24 | grad_checkpointing=False, 25 | noise_schedule="cosine", 26 | use_transport=False, 27 | timestep_shift=0.3, 28 | learn_sigma=True, 29 | sampling_method="euler", 30 | ): 31 | super(DiffLoss, self).__init__() 32 | 33 | # -------------------------------------------------------------------------- 34 | # basic configuration 35 | self.in_channels = target_channels 36 | self.noise_schedule = noise_schedule 37 | self.use_transport = use_transport 38 | 39 | # -------------------------------------------------------------------------- 40 | # network architecture 41 | self.net = SimpleMLPAdaLN( 42 | in_channels=target_channels, 43 | model_channels=width, 44 | out_channels=target_channels * 2 if learn_sigma else target_channels, 45 | z_channels=z_channels, 46 | num_res_blocks=depth, 47 | grad_checkpointing=grad_checkpointing, 48 | use_transport=use_transport, 49 | ) 50 | 51 | # -------------------------------------------------------------------------- 52 | # diffusion/transport setup 53 | if self.use_transport: 54 | self.transport = create_transport(use_cosine_loss=True, use_lognorm=True) 55 | self.sampler = Sampler(self.transport) 56 | self.sample_fn = self.sampler.sample_ode( 57 | 
sampling_method=sampling_method, 58 | num_steps=int(num_sampling_steps), 59 | timestep_shift=timestep_shift, 60 | ) 61 | else: 62 | self.train_diffusion = create_diffusion("", noise_schedule=noise_schedule) 63 | self.gen_diffusion = create_diffusion(num_sampling_steps, noise_schedule=noise_schedule) 64 | 65 | def forward(self, target, z, mask=None): 66 | """forward pass for training.""" 67 | if self.use_transport: 68 | model_kwargs = dict(c=z) 69 | loss_dict = self.transport.training_losses(self.net, target, model_kwargs) 70 | else: 71 | t = torch.randint( 72 | 0, 73 | self.train_diffusion.num_timesteps, 74 | (target.shape[0],), 75 | device=target.device, 76 | ) 77 | model_kwargs = dict(c=z) 78 | loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs) 79 | 80 | loss = loss_dict["loss"] 81 | if mask is not None: 82 | loss = (loss * mask).sum() / mask.sum() 83 | return loss.mean() 84 | 85 | def sample(self, z, temperature=1.0, cfg=1.0): 86 | """sample from the diffusion model.""" 87 | if cfg != 1.0: 88 | noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda() 89 | noise = torch.cat([noise, noise], dim=0) 90 | if self.use_transport: 91 | model_kwargs = dict(c=z, cfg_scale=cfg, cfg_interval=True, cfg_interval_start=0.10) 92 | else: 93 | model_kwargs = dict(c=z, cfg_scale=cfg) 94 | sample_fn = self.net.forward_with_cfg 95 | else: 96 | noise = torch.randn(z.shape[0], self.in_channels).cuda() 97 | model_kwargs = dict(c=z) 98 | sample_fn = self.net.forward 99 | 100 | if self.use_transport: 101 | sampled_token_latent = self.sample_fn(noise, sample_fn, **model_kwargs)[-1] 102 | else: 103 | sampled_token_latent = self.gen_diffusion.p_sample_loop( 104 | sample_fn, 105 | noise.shape, 106 | noise, 107 | clip_denoised=False, 108 | model_kwargs=model_kwargs, 109 | progress=False, 110 | temperature=temperature, 111 | ) 112 | return sampled_token_latent 113 | 114 | 115 | class ResBlock(nn.Module): 116 | """residual block with adaptive layer 
normalization.""" 117 | 118 | def __init__(self, channels): 119 | super().__init__() 120 | self.channels = channels 121 | self.in_ln = nn.LayerNorm(channels, eps=1e-6) 122 | self.mlp = nn.Sequential( 123 | nn.Linear(channels, channels, bias=True), 124 | nn.SiLU(), 125 | nn.Linear(channels, channels, bias=True), 126 | ) 127 | self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)) 128 | 129 | def forward(self, x, y): 130 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) 131 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp) 132 | h = self.mlp(h) 133 | return x + gate_mlp * h 134 | 135 | 136 | class SimpleMLPAdaLN(nn.Module): 137 | """simple MLP with adaptive layer normalization for diffusion loss.""" 138 | 139 | def __init__( 140 | self, 141 | in_channels, 142 | model_channels, 143 | out_channels, 144 | z_channels, 145 | num_res_blocks, 146 | grad_checkpointing=False, 147 | use_transport=False, 148 | ): 149 | super().__init__() 150 | 151 | # -------------------------------------------------------------------------- 152 | # basic configuration 153 | self.in_channels = in_channels 154 | self.model_channels = model_channels 155 | self.out_channels = out_channels 156 | self.num_res_blocks = num_res_blocks 157 | self.grad_checkpointing = grad_checkpointing 158 | self.use_transport = use_transport 159 | 160 | # -------------------------------------------------------------------------- 161 | # network layers 162 | self.time_embed = TimestepEmbedder(model_channels) 163 | self.cond_embed = nn.Linear(z_channels, model_channels) 164 | self.input_proj = nn.Linear(in_channels, model_channels) 165 | self.res_blocks = nn.ModuleList([ResBlock(model_channels) for _ in range(num_res_blocks)]) 166 | self.final_layer = ModulatedLinear(model_channels, out_channels) 167 | 168 | self.initialize_weights() 169 | 170 | def initialize_weights(self): 171 | """initialize model weights.""" 172 | 173 | def _basic_init(module): 174 
| if isinstance(module, nn.Linear): 175 | torch.nn.init.xavier_uniform_(module.weight) 176 | if module.bias is not None: 177 | nn.init.constant_(module.bias, 0) 178 | 179 | self.apply(_basic_init) 180 | 181 | # initialize timestep embedding MLP 182 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) 183 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) 184 | 185 | # zero-out adaLN modulation layers 186 | for block in self.res_blocks: 187 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 188 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 189 | 190 | # zero-out output layers 191 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 192 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 193 | nn.init.constant_(self.final_layer.linear.weight, 0) 194 | nn.init.constant_(self.final_layer.linear.bias, 0) 195 | 196 | def forward(self, x, t=None, c=None): 197 | """apply the model to an input batch.""" 198 | x = self.input_proj(x) 199 | t = self.time_embed(t) 200 | c = self.cond_embed(c) 201 | y = t + c 202 | 203 | for block in self.res_blocks: 204 | if self.grad_checkpointing and self.training: 205 | x = checkpoint(block, x, y) 206 | else: 207 | x = block(x, y) 208 | return self.final_layer(x, y) 209 | 210 | def forward_with_cfg(self, x, t, c, cfg_scale, cfg_interval=None, cfg_interval_start=None): 211 | """forward pass with classifier-free guidance.""" 212 | half = x[: len(x) // 2] 213 | combined = torch.cat([half, half], dim=0) 214 | model_out = self.forward(combined, t, c) 215 | eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :] 216 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 217 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 218 | if cfg_interval is True: 219 | timestep = t[0] 220 | if timestep < cfg_interval_start: 221 | half_eps = cond_eps 222 | 223 | eps = torch.cat([half_eps, half_eps], dim=0) 224 | return torch.cat([eps, rest], dim=1) 225 | 
# --------------------------------------------------------------------------------
# /models/dit.py:
# --------------------------------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/DiT/blob/main/models.py
- add support for 1D sequence
- include samplers inside the model
"""

import logging
from functools import partial

import torch
import torch.nn as nn

from diffusion import create_diffusion

from .layers import (
    Block,
    LabelEmbedder,
    ModulatedLinear,
    PatchEmbed,
    TimestepEmbedder,
    Transformer,
    get_2d_sincos_pos_embed,
)
from .model_utils import SIZE_DICT

logger = logging.getLogger("DeTok")


class DiT(nn.Module):
    """Diffusion model with a transformer backbone.

    Works either on 2D latents (patchified with ``PatchEmbed``, fixed sin-cos positions)
    or, when ``force_one_d_seq`` is a positive length, on a 1D token sequence with a
    learnable positional embedding.
    """

    def __init__(
        self,
        img_size=256,
        patch_size=1,
        model_size="base",
        tokenizer_patch_size=16,
        token_channels=16,
        label_drop_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
        noise_schedule="linear",
        num_sampling_steps=250,
        grad_checkpointing=False,
        force_one_d_seq=0,
        legacy_mode=False,
    ):
        super().__init__()

        # basic configuration
        self.learn_sigma = learn_sigma
        self.token_channels = token_channels
        self.out_channels = token_channels * 2 if learn_sigma else token_channels
        self.input_size = img_size // tokenizer_patch_size
        self.patch_size = patch_size
        self.num_classes = num_classes
        self.force_one_d_seq = force_one_d_seq
        self.grad_checkpointing = grad_checkpointing
        self.legacy_mode = legacy_mode

        # model architecture configuration
        size_dict = SIZE_DICT[model_size]
        num_layers, num_heads, width = size_dict["layers"], size_dict["heads"], size_dict["width"]

        # embedding layers
        if self.force_one_d_seq:
            self.x_embedder = nn.Linear(token_channels, width)
            self.pos_embed = nn.Parameter(torch.randn(1, self.force_one_d_seq, width) * 0.02)
        else:
            self.x_embedder = PatchEmbed(self.input_size, patch_size, token_channels, width)
            num_patches = self.x_embedder.num_patches
            # frozen; filled with sin-cos values in initialize_weights
            self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, width), requires_grad=False)

        self.t_embedder = TimestepEmbedder(width)
        self.y_embedder = LabelEmbedder(num_classes, width, label_drop_prob)

        # transformer architecture
        self.transformer = Transformer(
            width,
            num_layers,
            num_heads,
            block_fn=partial(Block, use_modulation=True),
            norm_layer=partial(nn.LayerNorm, elementwise_affine=False, eps=1e-6),
            grad_checkpointing=grad_checkpointing,
        )
        self.final_layer = ModulatedLinear(width, patch_size * patch_size * self.out_channels)

        # diffusion setup
        self.train_diffusion = create_diffusion("", noise_schedule=noise_schedule)
        self.gen_diffusion = create_diffusion(num_sampling_steps, noise_schedule=noise_schedule)
        self.initialize_weights()

        num_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
        logger.info(
            f"[DiT] params: {num_trainable_params:.2f}M size: {model_size}, num_layers: {num_layers}, width: {width}"
        )

    def initialize_weights(self):
        """Initialize weights following DiT: xavier linears, zeroed adaLN and output layers."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                # nn.Linear always has a ``bias`` attribute (None when bias=False), so test
                # the value rather than hasattr
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # 2D mode only: x_embedder is an nn.Linear without ``.proj`` in 1D mode
        if not self.force_one_d_seq:
            # initialize the (frozen) pos_embed with sin-cos values
            pos_embed = get_2d_sincos_pos_embed(
                self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5)
            )
            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

            # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
            w = self.x_embedder.proj.weight.data
            nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
            nn.init.constant_(self.x_embedder.proj.bias, 0)

        # label embedding table
        nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)

        # timestep embedding MLP
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # zero-out adaLN modulation layers in DiT blocks
        for block in self.transformer.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def unpatchify(self, x):
        """Convert (N, h*w, p*p*C) patch tokens back to an (N, C, h*p, w*p) tensor (h == w)."""
        c, p = self.out_channels, self.patch_size
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum("nhwpqc->nchpwq", x)
        return x.reshape(shape=(x.shape[0], c, h * p, w * p))

    def net(self, x, t, y):
        """Core denoiser: predict the diffusion model output for ``x`` at timesteps ``t``, labels ``y``."""
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), T = H * W / patch_size ** 2
        c = self.t_embedder(t) + self.y_embedder(y, self.training)  # (N, D)
        x = self.transformer(x, condition=c)  # (N, T, D)
        x = self.final_layer(x, c)  # (N, T, patch_size ** 2 * out_channels)
        if not self.force_one_d_seq:
            x = self.unpatchify(x)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """Classifier-free-guided forward; the batch is [conditional; unconditional] halves.

        Guidance is applied to the first ``token_channels`` channels along dim 1 (the first
        3 in ``legacy_mode``); the remaining channels are passed through unchanged.
        """
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.net(combined, t, y)

        if self.legacy_mode:
            eps, rest = model_out[:, :3], model_out[:, 3:]
        else:
            eps, rest = model_out[:, : self.token_channels], model_out[:, self.token_channels :]

        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

    def forward(self, x, y):
        """Training step: sample random timesteps and return the mean diffusion loss."""
        t = torch.randint(0, self.train_diffusion.num_timesteps, (x.shape[0],), device=x.device)
        loss_dict = self.train_diffusion.training_losses(self.net, x, t, dict(y=y))
        return loss_dict["loss"].mean()

    @torch.inference_mode()
    def generate(self, n_samples, labels, cfg=1.0, args=None):
        """Generate ``n_samples`` latents for ``labels``; CFG is used when ``cfg > 1.0``."""
        device = labels.device

        # noise drawn on the CPU generator, then moved (keeps seeded outputs stable)
        if self.force_one_d_seq:
            z = torch.randn(n_samples, self.force_one_d_seq, self.token_channels)
        else:
            z = torch.randn(n_samples, self.token_channels, self.input_size, self.input_size)
        z = z.to(device)

        if cfg > 1.0:
            # second half uses the null class index (num_classes)
            z = torch.cat([z, z], 0)
            labels = torch.cat([labels, torch.full_like(labels, self.num_classes)], 0)
            model_kwargs = dict(y=labels, cfg_scale=cfg)
            sample_fn = self.forward_with_cfg
        else:
            model_kwargs = dict(y=labels)
            sample_fn = self.net

        samples = self.gen_diffusion.p_sample_loop(
            sample_fn,
            z.shape,
            z,
            clip_denoised=False,
            model_kwargs=model_kwargs,
            progress=True,
            device=device,
        )

        if cfg > 1.0:
            samples, _ = samples.chunk(2, dim=0)  # remove null class samples
        return samples


# model size variants
def DiT_base(**kwargs):
    return DiT(model_size="base", **kwargs)


def DiT_large(**kwargs):
    return DiT(model_size="large", **kwargs)


def DiT_xl(**kwargs):
    return DiT(model_size="xl", **kwargs)


def DiT_huge(**kwargs):
    return DiT(model_size="huge", **kwargs)


DiT_models = {"DiT_base": DiT_base, "DiT_large": DiT_large, "DiT_xl": DiT_xl, "DiT_huge": DiT_huge}

# --------------------------------------------------------------------------------
# /models/ema.py:
# --------------------------------------------------------------------------------
import torch


class SimpleEMAModel:
    """Simple exponential moving average of a model's parameters.

    Trainable parameters are tracked as detached copies; frozen parameters are stored by
    reference when the EMA is created.
    """

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.ema_params = {}
        self.temp_stored_params = {}
        self.decay = decay

        for name, param in model.named_parameters():
            if param.requires_grad:
                self.ema_params[name] = param.clone().detach()
            else:
                self.ema_params[name] = param

    @torch.inference_mode()
    def step(self, model: torch.nn.Module):
        """Update EMA values: ``ema = decay * ema + (1 - decay) * param`` (frozen params are copied)."""
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model = model.module

        for name, param in model.named_parameters():
            if param.requires_grad:
                self.ema_params[name].mul_(self.decay).add_(param, alpha=1 - self.decay)
            else:
                self.ema_params[name].copy_(param)

    def copy_to(self, model: torch.nn.Module) -> None:
        """Copy the current averaged parameters into ``model``."""
        for name, param in model.named_parameters():
            param.data.copy_(self.ema_params[name].to(param.device).data)

    def to(self, device=None, dtype=None) -> None:
        """Move EMA tensors to ``device``; only floating-point tensors are cast to ``dtype``."""
        # .to() on the tensors handles None correctly; reassigning existing keys while
        # iterating is safe because the dict size does not change
        for name, tensor in self.ema_params.items():
            if tensor.is_floating_point():
                self.ema_params[name] = tensor.to(device=device, dtype=dtype)
            else:
                self.ema_params[name] = tensor.to(device=device)

    def store(self, model: torch.nn.Module) -> None:
        """Temporarily save a CPU copy of ``model``'s current parameters."""
        for name, param in model.named_parameters():
            self.temp_stored_params[name] = param.detach().cpu().clone()

    def restore(self, model: torch.nn.Module) -> None:
        """Restore parameters saved by :meth:`store` and clear the stash.

        Raises:
            RuntimeError: if nothing has been stored (the stash is an empty dict, never None).
        """
        if not self.temp_stored_params:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")

        for name, param in model.named_parameters():
            assert name in self.temp_stored_params, f"{name} not found in temp_stored_params"
            param.data.copy_(self.temp_stored_params[name].data)
        self.temp_stored_params = {}

    def load_state_dict(self, state_dict: dict | list) -> None:
        """Load EMA values from a name-keyed dict or from a list in ``ema_params`` order."""
        if isinstance(state_dict, dict):
            for name, param in self.ema_params.items():
                param.data.copy_(state_dict[name].to(param.device).data)
        elif isinstance(state_dict, list):
            for i, param in enumerate(self.ema_params.values()):
                param.data.copy_(state_dict[i].to(param.device).data)
        else:
            raise ValueError("state_dict must be a dict or list")

    def state_dict(self) -> dict:
        """Return the EMA parameter dict (the internal dict itself, not a copy)."""
        return self.ema_params

# --------------------------------------------------------------------------------
/models/lightningdit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/hustvl/LightningDiT/blob/main/models/lightningdit.py 3 | 4 | - add support for 1D sequence 5 | - include samplers inside the model 6 | - slightly different cfg conditioning (conditioned on **all channels**) 7 | """ 8 | 9 | import logging 10 | from functools import partial 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | from transport import Sampler, create_transport 16 | 17 | from .layers import ( 18 | Block, 19 | LabelEmbedder, 20 | ModulatedLinear, 21 | PatchEmbed, 22 | TimestepEmbedder, 23 | Transformer, 24 | VisionRotaryEmbeddingFast, 25 | get_2d_sincos_pos_embed, 26 | ) 27 | from .model_utils import SIZE_DICT 28 | 29 | logger = logging.getLogger("DeTok") 30 | 31 | 32 | class LightningDiT(nn.Module): 33 | """lightning diffusion transformer model.""" 34 | 35 | def __init__( 36 | self, 37 | img_size=256, 38 | patch_size=1, 39 | model_size="base", 40 | tokenizer_patch_size=16, 41 | token_channels=16, 42 | label_drop_prob=0.1, 43 | num_classes=1000, 44 | num_sampling_steps=250, 45 | sampling_method="euler", 46 | grad_checkpointing=False, 47 | force_one_d_seq=0, 48 | learn_sigma=False, # no learn_sigma in SiT 49 | legacy_mode=False, 50 | qk_norm=False, 51 | ): 52 | super().__init__() 53 | 54 | # -------------------------------------------------------------------------- 55 | # basic configuration 56 | self.token_channels = self.out_channels = token_channels 57 | self.input_size = img_size // tokenizer_patch_size 58 | self.patch_size = patch_size 59 | self.num_classes = num_classes 60 | self.force_one_d_seq = force_one_d_seq 61 | self.grad_checkpointing = grad_checkpointing 62 | self.learn_sigma = learn_sigma 63 | self.legacy_mode = legacy_mode 64 | 65 | # model architecture configuration 66 | size_dict = SIZE_DICT[model_size] 67 | num_layers, num_heads, width = size_dict["layers"], size_dict["heads"], 
size_dict["width"] 68 | 69 | # -------------------------------------------------------------------------- 70 | # embedding layers 71 | if self.force_one_d_seq > 0: 72 | self.x_embedder = nn.Linear(token_channels, width) 73 | # we use learnable positional embeddings for 1D sequence without rope 74 | self.pos_embed = nn.Parameter(torch.randn(1, self.force_one_d_seq, width) * 0.02) 75 | self.seq_len = self.force_one_d_seq 76 | else: 77 | self.x_embedder = PatchEmbed(self.input_size, patch_size, token_channels, width) 78 | # use rotary position encoding + abe, borrow from EVA 79 | num_patches = self.x_embedder.num_patches 80 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, width)) 81 | self.rope = VisionRotaryEmbeddingFast(width // num_heads // 2, self.input_size // patch_size) 82 | self.seq_len = num_patches 83 | 84 | self.t_embedder = TimestepEmbedder(width) 85 | self.y_embedder = LabelEmbedder(num_classes, width, label_drop_prob) 86 | 87 | # -------------------------------------------------------------------------- 88 | # transformer architecture 89 | self.transformer = Transformer( 90 | width, 91 | num_layers, 92 | num_heads, 93 | block_fn=partial(Block, use_modulation=True), 94 | norm_layer=nn.RMSNorm, 95 | grad_checkpointing=grad_checkpointing, 96 | use_swiglu=True, 97 | qk_norm=qk_norm, 98 | ) 99 | self.final_layer = ModulatedLinear(width, patch_size**2 * token_channels, use_rmsnorm=True) 100 | 101 | # -------------------------------------------------------------------------- 102 | # transport and sampling setup 103 | self.transport = create_transport(use_cosine_loss=True, use_lognorm=True) 104 | self.sampler = Sampler(self.transport) 105 | self.sample_fn = self.sampler.sample_ode( 106 | sampling_method=sampling_method, 107 | num_steps=int(num_sampling_steps), 108 | timestep_shift=0.3, 109 | ) 110 | 111 | self.initialize_weights() 112 | 113 | # log model info 114 | num_trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad) / 
1e6 115 | logger.info( 116 | f"[LightningDiT] params: {num_trainable_params:.2f}M size: {model_size}, num_layers: {num_layers}, width: {width}" 117 | ) 118 | 119 | def initialize_weights(self): 120 | """initialize model weights.""" 121 | 122 | def _basic_init(module): 123 | if isinstance(module, nn.Linear): 124 | torch.nn.init.xavier_uniform_(module.weight) 125 | if module.bias is not None: 126 | nn.init.constant_(module.bias, 0) 127 | 128 | self.apply(_basic_init) 129 | 130 | # initialize (and freeze) pos_embed by sin-cos embedding 131 | if not self.force_one_d_seq: 132 | pos_embed = get_2d_sincos_pos_embed( 133 | self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5) 134 | ) 135 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 136 | 137 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 138 | w = self.x_embedder.proj.weight.data 139 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 140 | nn.init.constant_(self.x_embedder.proj.bias, 0) 141 | 142 | # initialize label embedding table 143 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 144 | 145 | # initialize timestep embedding MLP 146 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 147 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 148 | 149 | # zero-out adaLN modulation layers in LightningDiT blocks 150 | for block in self.transformer.blocks: 151 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 152 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 153 | 154 | # zero-out output layers 155 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 156 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 157 | nn.init.constant_(self.final_layer.linear.weight, 0) 158 | nn.init.constant_(self.final_layer.linear.bias, 0) 159 | 160 | def unpatchify(self, x): 161 | """convert patch tokens back to image tensor.""" 162 | c, p = self.out_channels, self.patch_size 163 | h = w = int(x.shape[1] 
** 0.5) 164 | assert h * w == x.shape[1] 165 | 166 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 167 | x = torch.einsum("nhwpqc->nchpwq", x) 168 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 169 | return imgs 170 | 171 | def net(self, x, t=None, y=None): 172 | """core network forward pass.""" 173 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 174 | c = self.t_embedder(t) + self.y_embedder(y, self.training) # (N, D) 175 | 176 | # check if self.pos_embed requires grad 177 | if not self.force_one_d_seq: 178 | x = self.transformer(x, condition=c, rope=self.rope) # (N, T, D) 179 | else: 180 | x = self.transformer(x, condition=c) 181 | 182 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 183 | if not self.force_one_d_seq: 184 | x = self.unpatchify(x) 185 | return x 186 | 187 | def forward_with_cfg(self, x, t, y, cfg_scale, cfg_interval=None, cfg_interval_start=None): 188 | """forward pass with classifier-free guidance.""" 189 | half = x[: len(x) // 2] 190 | combined = torch.cat([half, half], dim=0) 191 | model_out = self.net(combined, t, y) 192 | 193 | if self.legacy_mode: 194 | eps, rest = model_out[:, :3], model_out[:, 3:] 195 | else: 196 | eps, rest = model_out[:, : self.token_channels], model_out[:, self.token_channels :] 197 | 198 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 199 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 200 | 201 | if cfg_interval is True: 202 | timestep = t[0] 203 | if timestep < cfg_interval_start: 204 | half_eps = cond_eps 205 | 206 | eps = torch.cat([half_eps, half_eps], dim=0) 207 | return torch.cat([eps, rest], dim=1) 208 | 209 | def forward(self, x, y): 210 | """forward pass for training.""" 211 | loss_dict = self.transport.training_losses(self.net, x, dict(y=y)) 212 | return loss_dict["loss"].mean() 213 | 214 | @torch.inference_mode() 215 | def generate(self, n_samples, labels, cfg=1.0, args=None): 216 | """generate samples 
using the model.""" 217 | device = labels.device 218 | 219 | # prepare noise tensor 220 | if self.force_one_d_seq: 221 | z = torch.randn(n_samples, self.force_one_d_seq, self.token_channels) 222 | else: 223 | z = torch.randn(n_samples, self.token_channels, self.input_size, self.input_size) 224 | z = z.to(device) 225 | 226 | # setup classifier-free guidance 227 | if cfg > 1.0: 228 | z = torch.cat([z, z], 0) 229 | y_null = torch.tensor([self.num_classes] * n_samples, device=device) 230 | labels = torch.cat([labels, y_null], 0) 231 | model_kwargs = dict(y=labels, cfg_scale=cfg, cfg_interval=True, cfg_interval_start=0.10) 232 | model_fn = self.forward_with_cfg 233 | else: 234 | model_kwargs = dict(y=labels) 235 | model_fn = self.net 236 | 237 | # generate samples 238 | samples = self.sample_fn(z, model_fn, **model_kwargs)[-1] 239 | if cfg > 1.0: 240 | samples, _ = samples.chunk(2, dim=0) # remove null class samples 241 | return samples 242 | 243 | 244 | # model size variants 245 | def LightningDiT_base(**kwargs) -> LightningDiT: 246 | return LightningDiT(model_size="base", **kwargs) 247 | 248 | 249 | def LightningDiT_large(**kwargs) -> LightningDiT: 250 | return LightningDiT(model_size="large", **kwargs) 251 | 252 | 253 | def LightningDiT_xl(**kwargs) -> LightningDiT: 254 | return LightningDiT(model_size="xl", **kwargs) 255 | 256 | 257 | def LightningDiT_huge(**kwargs) -> LightningDiT: 258 | return LightningDiT(model_size="huge", **kwargs) 259 | 260 | 261 | LightningDiT_models = { 262 | "LightningDiT_base": LightningDiT_base, 263 | "LightningDiT_large": LightningDiT_large, 264 | "LightningDiT_xl": LightningDiT_xl, 265 | "LightningDiT_huge": LightningDiT_huge, 266 | } 267 | -------------------------------------------------------------------------------- /models/mar.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/LTH14/mar/blob/main/models/mar.py 3 | - add support for 1D sequence 4 | - 
include samplers inside the model 5 | - add support for removing dropout in MLPs 6 | """ 7 | 8 | import logging 9 | import math 10 | from functools import partial 11 | 12 | import numpy as np 13 | import scipy.stats as stats 14 | import torch 15 | import torch.nn as nn 16 | from torch.utils.checkpoint import checkpoint 17 | from tqdm import tqdm 18 | 19 | from .diffloss import DiffLoss 20 | from .layers import Block 21 | 22 | logger = logging.getLogger("DeTok") 23 | 24 | MAR_SIZE_DICT = { 25 | "base": {"width": 768, "layers": 12, "heads": 12}, 26 | "large": {"width": 1024, "layers": 16, "heads": 16}, 27 | "huge": {"width": 1280, "layers": 20, "heads": 16}, 28 | } 29 | 30 | 31 | def mask_by_order(mask_len, order, bsz, seq_len): 32 | """create masking tensor based on given order and length.""" 33 | masking = torch.zeros(bsz, seq_len).cuda() 34 | masking = torch.scatter( 35 | masking, 36 | dim=-1, 37 | index=order[:, : mask_len.long()], 38 | src=torch.ones(bsz, seq_len).cuda(), 39 | ).bool() 40 | return masking 41 | 42 | 43 | class MAR(nn.Module): 44 | def __init__( 45 | self, 46 | img_size=256, 47 | patch_size=1, 48 | model_size="base", 49 | tokenizer_patch_size=16, 50 | token_channels=16, 51 | mask_ratio_min=0.7, 52 | label_drop_prob=0.1, 53 | num_classes=1000, 54 | attn_dropout=0.1, 55 | proj_dropout=0.1, 56 | buffer_size=64, 57 | diffloss_d=3, 58 | diffloss_w=1024, 59 | num_sampling_steps="100", 60 | noise_schedule="cosine", 61 | diffusion_batch_mul=4, 62 | force_one_d_seq=0, 63 | grad_checkpointing=False, 64 | no_dropout_in_mlp=False, 65 | ): 66 | super().__init__() 67 | 68 | # -------------------------------------------------------------------------- 69 | # VAE and patchify specifics 70 | self.token_channels = token_channels 71 | self.img_size = img_size 72 | self.patch_size = patch_size 73 | self.seq_h = self.seq_w = img_size // tokenizer_patch_size // patch_size 74 | self.seq_len = self.seq_h * self.seq_w 75 | self.token_embed_dim = token_channels * 
patch_size**2 76 | self.grad_checkpointing = grad_checkpointing 77 | self.model_size = model_size 78 | self.force_one_d_seq = force_one_d_seq 79 | if force_one_d_seq: 80 | self.seq_len = force_one_d_seq 81 | 82 | size_dict = MAR_SIZE_DICT[self.model_size] 83 | num_layers, num_heads, width = size_dict["layers"], size_dict["heads"], size_dict["width"] 84 | 85 | # -------------------------------------------------------------------------- 86 | # Class Embedding 87 | self.num_classes = num_classes 88 | self.class_emb = nn.Embedding(num_classes, width) 89 | self.label_drop_prob = label_drop_prob 90 | # Fake class embedding for CFG's unconditional generation 91 | self.fake_latent = nn.Parameter(torch.zeros(1, width)) 92 | 93 | # -------------------------------------------------------------------------- 94 | # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25 95 | self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25) 96 | 97 | # -------------------------------------------------------------------------- 98 | # MAR encoder specifics 99 | self.z_proj = nn.Linear(self.token_embed_dim, width, bias=True) 100 | self.z_proj_ln = nn.LayerNorm(width, eps=1e-6) 101 | self.buffer_size = buffer_size 102 | self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, width)) 103 | 104 | norm_layer = partial(nn.LayerNorm, eps=1e-6) 105 | self.encoder_blocks = nn.ModuleList( 106 | [ 107 | Block( 108 | width, 109 | num_heads, 110 | norm_layer=norm_layer, 111 | qkv_bias=True, 112 | proj_drop=proj_dropout, 113 | attn_drop=attn_dropout, 114 | no_dropout_in_mlp=no_dropout_in_mlp, 115 | ) 116 | for _ in range(num_layers) 117 | ] 118 | ) 119 | self.encoder_norm = norm_layer(width) 120 | 121 | # -------------------------------------------------------------------------- 122 | # MAR decoder specifics 123 | self.decoder_embed = nn.Linear(width, width, bias=True) 124 | 
self.mask_token = nn.Parameter(torch.zeros(1, 1, width)) 125 | self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, width)) 126 | 127 | self.decoder_blocks = nn.ModuleList( 128 | [ 129 | Block( 130 | width, 131 | num_heads, 132 | qkv_bias=True, 133 | norm_layer=norm_layer, 134 | proj_drop=proj_dropout, 135 | attn_drop=attn_dropout, 136 | no_dropout_in_mlp=no_dropout_in_mlp, 137 | ) 138 | for _ in range(num_layers) 139 | ] 140 | ) 141 | 142 | self.decoder_norm = norm_layer(width) 143 | self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, width)) 144 | 145 | self.initialize_weights() 146 | 147 | # -------------------------------------------------------------------------- 148 | # Diffusion Loss 149 | self.diffloss = DiffLoss( 150 | target_channels=self.token_embed_dim, 151 | z_channels=width, 152 | width=diffloss_w, 153 | depth=diffloss_d, 154 | num_sampling_steps=num_sampling_steps, 155 | noise_schedule=noise_schedule, 156 | grad_checkpointing=grad_checkpointing, 157 | ) 158 | self.diffusion_batch_mul = diffusion_batch_mul 159 | 160 | params_M = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 161 | logger.info(f"[MAR] params: {params_M:.2f}M, {model_size}-{num_layers}-{width}") 162 | logger.info(f"[MAR] seq_len: {self.seq_len}, buffer_size: {self.buffer_size}") 163 | 164 | def initialize_weights(self): 165 | # parameters 166 | torch.nn.init.normal_(self.class_emb.weight, std=0.02) 167 | torch.nn.init.normal_(self.fake_latent, std=0.02) 168 | torch.nn.init.normal_(self.mask_token, std=0.02) 169 | torch.nn.init.normal_(self.encoder_pos_embed_learned, std=0.02) 170 | torch.nn.init.normal_(self.decoder_pos_embed_learned, std=0.02) 171 | torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=0.02) 172 | 173 | # initialize nn.Linear and nn.LayerNorm 174 | self.apply(self._init_weights) 175 | 176 | def _init_weights(self, m): 177 | if isinstance(m, nn.Linear): 178 | # we use 
xavier_uniform following official JAX ViT: 179 | torch.nn.init.xavier_uniform_(m.weight) 180 | if m.bias is not None: 181 | nn.init.constant_(m.bias, 0) 182 | elif isinstance(m, nn.LayerNorm): 183 | if m.bias is not None: 184 | nn.init.constant_(m.bias, 0) 185 | if m.weight is not None: 186 | nn.init.constant_(m.weight, 1.0) 187 | 188 | def patchify(self, x): 189 | bsz, c, h, w = x.shape 190 | p = self.patch_size 191 | h_, w_ = h // p, w // p 192 | 193 | x = x.reshape(bsz, c, h_, p, w_, p) 194 | x = torch.einsum("nchpwq->nhwcpq", x) 195 | x = x.reshape(bsz, h_ * w_, c * p**2) 196 | return x # [n, l, d] 197 | 198 | def unpatchify(self, x): 199 | bsz = x.shape[0] 200 | p = self.patch_size 201 | c = self.token_channels 202 | h_, w_ = self.seq_h, self.seq_w 203 | 204 | x = x.reshape(bsz, h_, w_, c, p, p) 205 | x = torch.einsum("nhwcpq->nchpwq", x) 206 | x = x.reshape(bsz, c, h_ * p, w_ * p) 207 | return x # [n, c, h, w] 208 | 209 | def sample_orders(self, bsz): 210 | # generate a batch of random generation orders 211 | orders = [] 212 | for _ in range(bsz): 213 | order = np.array(list(range(self.seq_len))) 214 | np.random.shuffle(order) 215 | orders.append(order) 216 | orders = torch.Tensor(np.array(orders)).cuda().long() 217 | return orders 218 | 219 | def random_masking(self, x, orders): 220 | # generate token mask 221 | bsz, seq_len, _ = x.shape 222 | mask_rate = self.mask_ratio_generator.rvs(1)[0] 223 | num_masked_tokens = int(np.ceil(seq_len * mask_rate)) 224 | mask = torch.zeros(bsz, seq_len, device=x.device) 225 | mask = torch.scatter( 226 | mask, 227 | dim=-1, 228 | index=orders[:, :num_masked_tokens], 229 | src=torch.ones(bsz, seq_len, device=x.device), 230 | ) 231 | return mask 232 | 233 | def forward_mae_encoder(self, x, mask, class_embedding): 234 | x = self.z_proj(x) 235 | bsz, _, embed_dim = x.shape 236 | 237 | # concat buffer 238 | x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=x.device), x], dim=1) 239 | mask_with_buffer = 
torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1) 240 | 241 | # random drop class embedding during training 242 | if self.training: 243 | drop_latent_mask = torch.rand(bsz) < self.label_drop_prob 244 | drop_latent_mask = drop_latent_mask.unsqueeze(-1).cuda().to(x.dtype) 245 | class_embedding = drop_latent_mask * self.fake_latent + (1 - drop_latent_mask) * class_embedding 246 | 247 | x[:, : self.buffer_size] = class_embedding.unsqueeze(1) 248 | 249 | # encoder position embedding 250 | x = x + self.encoder_pos_embed_learned 251 | x = self.z_proj_ln(x) 252 | 253 | # dropping 254 | x = x[(1 - mask_with_buffer).nonzero(as_tuple=True)].reshape(bsz, -1, embed_dim) 255 | 256 | # apply Transformer blocks 257 | if self.grad_checkpointing and self.training: 258 | for i, block in enumerate(self.encoder_blocks): 259 | x = checkpoint(block, x) 260 | else: 261 | for block in self.encoder_blocks: 262 | x = block(x) 263 | x = self.encoder_norm(x) 264 | return x 265 | 266 | def forward_mae_decoder(self, x, mask): 267 | 268 | x = self.decoder_embed(x) 269 | mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1) 270 | 271 | # pad mask tokens 272 | mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to( 273 | x.dtype 274 | ) 275 | x_after_pad = mask_tokens.clone() 276 | x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape( 277 | x.shape[0] * x.shape[1], x.shape[2] 278 | ) 279 | 280 | # decoder position embedding 281 | x = x_after_pad + self.decoder_pos_embed_learned 282 | 283 | # apply Transformer blocks 284 | if self.grad_checkpointing and self.training: 285 | for block in self.decoder_blocks: 286 | x = checkpoint(block, x) 287 | else: 288 | for block in self.decoder_blocks: 289 | x = block(x) 290 | x = self.decoder_norm(x) 291 | 292 | x = x[:, self.buffer_size :] 293 | x = x + self.diffusion_pos_embed_learned 294 | return x 295 | 296 | def 
forward_loss(self, z, target, mask): 297 | bsz, seq_len, _ = target.shape 298 | target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1) 299 | z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1) 300 | mask = mask.reshape(bsz * seq_len).repeat(self.diffusion_batch_mul) 301 | loss = self.diffloss(z=z, target=target, mask=mask) 302 | return loss 303 | 304 | def forward(self, imgs, labels): 305 | 306 | # class embed 307 | class_embedding = self.class_emb(labels) 308 | 309 | # patchify and mask (drop) tokens 310 | x = self.patchify(imgs) if not self.force_one_d_seq else imgs 311 | gt_latents = x.clone().detach() 312 | orders = self.sample_orders(bsz=x.size(0)) 313 | mask = self.random_masking(x, orders) 314 | 315 | x = self.forward_mae_encoder(x, mask, class_embedding) 316 | z = self.forward_mae_decoder(x, mask) 317 | loss = self.forward_loss(z=z, target=gt_latents, mask=mask) 318 | return loss 319 | 320 | def sample_tokens( 321 | self, 322 | bsz, 323 | num_iter=64, 324 | cfg=1.0, 325 | cfg_schedule="linear", 326 | labels=None, 327 | temperature=1.0, 328 | progress=False, 329 | ): 330 | 331 | # init and sample generation orders 332 | mask = torch.ones(bsz, self.seq_len).cuda() 333 | tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda() 334 | orders = self.sample_orders(bsz) 335 | 336 | indices = list(range(num_iter)) 337 | if progress: 338 | indices = tqdm(indices) 339 | # generate latents 340 | for step in indices: 341 | cur_tokens = tokens.clone() 342 | 343 | # class embedding and CFG 344 | if labels is not None: 345 | class_embedding = self.class_emb(labels) 346 | else: 347 | class_embedding = self.fake_latent.repeat(bsz, 1) 348 | if cfg != 1.0: 349 | tokens = torch.cat([tokens, tokens], dim=0) 350 | class_embedding = torch.cat([class_embedding, self.fake_latent.repeat(bsz, 1)], dim=0) 351 | mask = torch.cat([mask, mask], dim=0) 352 | 353 | # mae encoder 354 | x = self.forward_mae_encoder(tokens, mask, 
class_embedding) 355 | 356 | # mae decoder 357 | z = self.forward_mae_decoder(x, mask) 358 | 359 | # mask ratio for the next round, following MaskGIT and MAGE. 360 | mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter) 361 | mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda() 362 | 363 | # masks out at least one for the next iteration 364 | mask_len = torch.maximum( 365 | torch.Tensor([1]).cuda(), 366 | torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len), 367 | ) 368 | 369 | # get masking for next iteration and locations to be predicted in this iteration 370 | mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len) 371 | if step >= num_iter - 1: 372 | mask_to_pred = mask[:bsz].bool() 373 | else: 374 | mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool()) 375 | mask = mask_next 376 | if cfg != 1.0: 377 | mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0) 378 | 379 | # sample token latents for this step 380 | z = z[mask_to_pred.nonzero(as_tuple=True)] 381 | # cfg schedule follow Muse 382 | if cfg_schedule == "linear": 383 | cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len 384 | elif cfg_schedule == "constant": 385 | cfg_iter = cfg 386 | else: 387 | raise NotImplementedError 388 | sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter) 389 | if cfg != 1.0: 390 | sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples 391 | mask_to_pred, _ = mask_to_pred.chunk(2, dim=0) 392 | 393 | cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent 394 | tokens = cur_tokens.clone() 395 | 396 | # unpatchify 397 | if not self.force_one_d_seq: 398 | tokens = self.unpatchify(tokens) 399 | return tokens 400 | 401 | @torch.inference_mode() 402 | def generate(self, n_samples, cfg, labels, args): 403 | return self.sample_tokens( 404 | n_samples, 405 | num_iter=args.num_iter, 406 | cfg=cfg, 407 | labels=labels, 408 | 
cfg_schedule=args.cfg_schedule, 409 | temperature=args.temperature, 410 | progress=True, 411 | ) 412 | 413 | 414 | def mar_base(**kwargs) -> MAR: 415 | return MAR(model_size="base", **kwargs) 416 | 417 | 418 | def mar_large(**kwargs): 419 | return MAR(model_size="large", **kwargs) 420 | 421 | 422 | def mar_huge(**kwargs): 423 | return MAR(model_size="huge", **kwargs) 424 | 425 | 426 | MAR_models = { 427 | "MAR_base": mar_base, 428 | "MAR_large": mar_large, 429 | "MAR_huge": mar_huge, 430 | } 431 | -------------------------------------------------------------------------------- /models/model_utils.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------- 2 | # model size configurations 3 | SIZE_DICT = { 4 | "small": {"width": 512, "layers": 8, "heads": 8}, 5 | "base": {"width": 768, "layers": 12, "heads": 12}, 6 | "large": {"width": 1024, "layers": 24, "heads": 16}, 7 | "xl": {"width": 1152, "layers": 28, "heads": 16}, 8 | "huge": {"width": 1280, "layers": 32, "heads": 16}, 9 | } 10 | -------------------------------------------------------------------------------- /models/sit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from https://github.com/willisma/SiT/blob/main/models.py 3 | 4 | - add support for 1D sequence 5 | - include samplers inside the model 6 | - slightly different cfg conditioning (conditioned on **all channels**) 7 | """ 8 | 9 | import logging 10 | from functools import partial 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.utils.checkpoint import checkpoint 15 | 16 | from transport import Sampler, create_transport 17 | 18 | from .layers import ( 19 | Block, 20 | LabelEmbedder, 21 | ModulatedLinear, 22 | PatchEmbed, 23 | TimestepEmbedder, 24 | Transformer, 25 | get_2d_sincos_pos_embed, 26 | ) 27 | from .model_utils import SIZE_DICT 28 | 29 | logger = 
logging.getLogger("DeTok") 30 | 31 | 32 | class SiT(nn.Module): 33 | """scalable interpolant transformer model.""" 34 | 35 | def __init__( 36 | self, 37 | img_size=256, 38 | patch_size=1, 39 | model_size="base", 40 | tokenizer_patch_size=16, 41 | token_channels=16, 42 | label_drop_prob=0.1, 43 | num_classes=1000, 44 | num_sampling_steps=250, 45 | sampling_method="dopri5", 46 | grad_checkpointing=False, 47 | force_one_d_seq=0, 48 | learn_sigma=False, # no learn_sigma in SiT 49 | legacy_mode=False, 50 | qk_norm=False, 51 | ): 52 | super().__init__() 53 | 54 | # -------------------------------------------------------------------------- 55 | # basic configuration 56 | self.token_channels = token_channels 57 | self.out_channels = token_channels * 2 if learn_sigma else token_channels 58 | self.input_size = img_size // tokenizer_patch_size 59 | self.patch_size = patch_size 60 | self.num_classes = num_classes 61 | self.force_one_d_seq = force_one_d_seq 62 | self.grad_checkpointing = grad_checkpointing 63 | self.learn_sigma = learn_sigma 64 | self.legacy_mode = legacy_mode 65 | 66 | # model architecture configuration 67 | size_dict = SIZE_DICT[model_size] 68 | num_layers, num_heads, width = size_dict["layers"], size_dict["heads"], size_dict["width"] 69 | 70 | # -------------------------------------------------------------------------- 71 | # embedding layers 72 | if self.force_one_d_seq: 73 | self.x_embedder = nn.Linear(token_channels, width) 74 | self.pos_embed = nn.Parameter(torch.randn(1, self.force_one_d_seq, width) * 0.02) 75 | self.seq_len = self.force_one_d_seq 76 | else: 77 | self.x_embedder = PatchEmbed(self.input_size, patch_size, token_channels, width) 78 | num_patches = self.x_embedder.num_patches 79 | self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, width), requires_grad=False) 80 | self.seq_len = num_patches 81 | 82 | self.t_embedder = TimestepEmbedder(width) 83 | self.y_embedder = LabelEmbedder(num_classes, width, label_drop_prob) 84 | 85 | # 
-------------------------------------------------------------------------- 86 | # transformer architecture 87 | self.transformer = Transformer( 88 | width, 89 | num_layers, 90 | num_heads, 91 | block_fn=partial(Block, use_modulation=True), 92 | norm_layer=partial(nn.LayerNorm, elementwise_affine=False, eps=1e-6), 93 | qk_norm=qk_norm, 94 | grad_checkpointing=grad_checkpointing, 95 | ) 96 | self.final_layer = ModulatedLinear(width, patch_size * patch_size * self.out_channels) 97 | 98 | # -------------------------------------------------------------------------- 99 | # transport and sampling setup 100 | self.transport = create_transport() 101 | self.sampler = Sampler(self.transport) 102 | self.sample_fn = self.sampler.sample_ode( 103 | sampling_method=sampling_method, 104 | num_steps=int(num_sampling_steps), 105 | ) 106 | 107 | self.initialize_weights() 108 | 109 | # log model info 110 | params_M = sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 111 | logger.info(f"[SiT] params: {params_M:.2f}M, {model_size}-{num_layers}-{width}") 112 | logger.info(f"[SiT] seq_len: {self.seq_len}") 113 | 114 | def initialize_weights(self): 115 | """initialize model weights.""" 116 | 117 | def _basic_init(module): 118 | if isinstance(module, nn.Linear): 119 | torch.nn.init.xavier_uniform_(module.weight) 120 | if module.bias is not None: 121 | nn.init.constant_(module.bias, 0) 122 | 123 | self.apply(_basic_init) 124 | 125 | # initialize (and freeze) pos_embed by sin-cos embedding 126 | if not self.force_one_d_seq: 127 | pos_embed = get_2d_sincos_pos_embed( 128 | self.pos_embed.shape[-1], int(self.x_embedder.num_patches**0.5) 129 | ) 130 | self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) 131 | 132 | # initialize patch_embed like nn.Linear (instead of nn.Conv2d) 133 | w = self.x_embedder.proj.weight.data 134 | nn.init.xavier_uniform_(w.view([w.shape[0], -1])) 135 | nn.init.constant_(self.x_embedder.proj.bias, 0) 136 | 137 | # initialize label 
embedding table 138 | nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) 139 | 140 | # initialize timestep embedding MLP 141 | nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) 142 | nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) 143 | 144 | # zero-out adaLN modulation layers in SiT blocks 145 | for block in self.transformer.blocks: 146 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 147 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 148 | 149 | # zero-out output layers 150 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 151 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 152 | nn.init.constant_(self.final_layer.linear.weight, 0) 153 | nn.init.constant_(self.final_layer.linear.bias, 0) 154 | 155 | def unpatchify(self, x): 156 | """convert patch tokens back to image tensor.""" 157 | c, p = self.out_channels, self.patch_size 158 | h = w = int(x.shape[1] ** 0.5) 159 | assert h * w == x.shape[1] 160 | 161 | x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) 162 | x = torch.einsum("nhwpqc->nchpwq", x) 163 | imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p)) 164 | return imgs 165 | 166 | def net(self, x, t, y): 167 | """core network forward pass.""" 168 | x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2 169 | c = self.t_embedder(t) + self.y_embedder(y, self.training) # (N, D) 170 | 171 | # transformer forward pass 172 | for block in self.transformer.blocks: 173 | if self.grad_checkpointing and self.training: 174 | x = checkpoint(block, x, condition=c) 175 | else: 176 | x = block(x, condition=c) 177 | 178 | x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) 179 | 180 | if not self.force_one_d_seq: 181 | x = self.unpatchify(x) # (N, out_channels, H, W) 182 | if self.learn_sigma: 183 | x, _ = x.chunk(2, dim=1) 184 | return x 185 | 186 | def forward_with_cfg(self, x, t, y, cfg_scale): 187 | """forward pass with classifier-free 
guidance.""" 188 | half = x[: len(x) // 2] 189 | combined = torch.cat([half, half], dim=0) 190 | model_out = self.net(combined, t, y) 191 | 192 | if self.legacy_mode: 193 | eps, rest = model_out[:, :3], model_out[:, 3:] 194 | else: 195 | eps, rest = model_out[:, : self.token_channels], model_out[:, self.token_channels :] 196 | 197 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 198 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 199 | eps = torch.cat([half_eps, half_eps], dim=0) 200 | return torch.cat([eps, rest], dim=1) 201 | 202 | def forward(self, x, y): 203 | """forward pass for training.""" 204 | loss_dict = self.transport.training_losses(self.net, x, dict(y=y)) 205 | return loss_dict["loss"].mean() 206 | 207 | @torch.inference_mode() 208 | def generate(self, n_samples, labels, cfg=1.0, args=None): 209 | """generate samples using the model.""" 210 | device = labels.device 211 | 212 | # prepare noise tensor 213 | if self.force_one_d_seq: 214 | z = torch.randn(n_samples, self.force_one_d_seq, self.token_channels) 215 | else: 216 | z = torch.randn(n_samples, self.token_channels, self.input_size, self.input_size) 217 | z = z.to(device) 218 | 219 | # setup classifier-free guidance 220 | if cfg > 1.0: 221 | z = torch.cat([z, z], 0) 222 | y_null = torch.tensor([self.num_classes] * n_samples, device=device) 223 | labels = torch.cat([labels, y_null], 0) 224 | model_kwargs = dict(y=labels, cfg_scale=cfg) 225 | model_fn = self.forward_with_cfg 226 | else: 227 | model_kwargs = dict(y=labels) 228 | model_fn = self.net 229 | 230 | # generate samples 231 | samples = self.sample_fn(z, model_fn, **model_kwargs)[-1] 232 | if cfg > 1.0: 233 | samples, _ = samples.chunk(2, dim=0) # remove null class samples 234 | return samples 235 | 236 | 237 | # model size variants 238 | def SiT_XL(**kwargs) -> SiT: 239 | return SiT(model_size="xl", **kwargs) 240 | 241 | 242 | def SiT_L(**kwargs) -> SiT: 243 | return SiT(model_size="large", **kwargs) 244 | 245 | 
246 | def SiT_B(**kwargs) -> SiT: 247 | return SiT(model_size="base", **kwargs) 248 | 249 | 250 | def SiT_S(**kwargs) -> SiT: 251 | return SiT(model_size="small", **kwargs) 252 | 253 | 254 | SiT_models = {"SiT_base": SiT_B, "SiT_large": SiT_L, "SiT_xl": SiT_XL, "SiT_small": SiT_S} 255 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "exclude": [ 3 | "**/node_modules", 4 | "**/__pycache__", 5 | "**/.git", 6 | "data", 7 | "legacy_code", 8 | "work_dirs" 9 | ], 10 | "reportUnknownMemberType": false, 11 | "reportUnannotatedClassAttribute": false, 12 | "reportUnusedCallResult": false, 13 | "reportUnknownParameterType": false, 14 | "reportUnknownVariableType": false, 15 | "reportUnknownArgumentType": false, 16 | "reportAny": false, 17 | "reportExplicitAny": false, 18 | "reportMissingParameterType": false, 19 | "reportArgumentType": false, 20 | "reportMissingTypeStubs": false, 21 | "reportOptionalMemberAccess": false, 22 | "reportImplicitOverride": false, 23 | "reportUntypedFunctionDecorator": false, 24 | "reportOperatorIssue": false, 25 | "reportIndexIssue": false, 26 | "reportUntypedNamedTuple": false, 27 | "reportImplicitStringConcatenation": false, 28 | "reportUnknownLambdaType": false, 29 | "reportPrivateLocalImportUsage": false, 30 | "reportOptionalOperand": false, 31 | "reportReturnType": false, 32 | "reportMissingTypeArgument": false, 33 | "reportCallIssue": false, 34 | "reportCallInDefaultInitializer": false, 35 | "reportIncompatibleMethodOverride": false, 36 | "reportAttributeAccessIssue": false, 37 | "reportPossiblyUnboundVariable": false, 38 | "reportMissingSuperCall": false, 39 | "reportOptionalSubscript": false, 40 | "reportImportCycles": false 41 | } -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu124 2 | torch==2.6.0+cu124 3 | timm 4 | scipy 5 | einops 6 | wandb 7 | matplotlib 8 | diffusers 9 | torchtnt 10 | rich 11 | git+https://github.com/LTH14/torch-fidelity.git@master#egg=torch-fidelity 12 | huggingface_hub 13 | 14 | black 15 | pylint 16 | flake8 17 | isort 18 | jaxtyping 19 | torchdiffeq 20 | gdown 21 | ipykernel -------------------------------------------------------------------------------- /transport/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/hustvl/LightningDiT/blob/main/transport/__init__.py 3 | """ 4 | 5 | from .transport import ModelType, PathType, Transport, WeightType, Sampler 6 | 7 | 8 | 9 | def create_transport( 10 | path_type="Linear", 11 | prediction="velocity", 12 | loss_weight=None, 13 | train_eps=None, 14 | sample_eps=None, 15 | use_cosine_loss=False, 16 | use_lognorm=None, 17 | partitial_train=None, 18 | partial_ratio=1.0, 19 | shift_lg=False, 20 | ): 21 | """function for creating Transport object 22 | **Note**: model prediction defaults to velocity 23 | Args: 24 | - path_type: type of path to use; default to linear 25 | - learn_score: set model prediction to score 26 | - learn_noise: set model prediction to noise 27 | - velocity_weighted: weight loss by velocity weight 28 | - likelihood_weighted: weight loss by likelihood weight 29 | - train_eps: small epsilon for avoiding instability during training 30 | - sample_eps: small epsilon for avoiding instability during sampling 31 | """ 32 | if prediction == "noise": 33 | model_type = ModelType.NOISE 34 | elif prediction == "score": 35 | model_type = ModelType.SCORE 36 | else: 37 | model_type = ModelType.VELOCITY 38 | 39 | if loss_weight == "velocity": 40 | loss_type = WeightType.VELOCITY 41 | elif loss_weight == "likelihood": 42 | loss_type = 
WeightType.LIKELIHOOD 43 | else: 44 | loss_type = WeightType.NONE 45 | 46 | path_choice = { 47 | "Linear": PathType.LINEAR, 48 | "GVP": PathType.GVP, 49 | "VP": PathType.VP, 50 | } 51 | 52 | path_type = path_choice[path_type] 53 | 54 | if path_type in [PathType.VP]: 55 | train_eps = 1e-5 if train_eps is None else train_eps 56 | sample_eps = 1e-3 if train_eps is None else sample_eps 57 | elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY: 58 | train_eps = 1e-3 if train_eps is None else train_eps 59 | sample_eps = 1e-3 if train_eps is None else sample_eps 60 | else: # velocity & [GVP, LINEAR] is stable everywhere 61 | train_eps = 0 62 | sample_eps = 0 63 | 64 | # create flow state 65 | state = Transport( 66 | model_type=model_type, 67 | path_type=path_type, 68 | loss_type=loss_type, 69 | train_eps=train_eps, 70 | sample_eps=sample_eps, 71 | use_cosine_loss=use_cosine_loss, 72 | use_lognorm=use_lognorm, 73 | partitial_train=partitial_train, 74 | partial_ratio=partial_ratio, 75 | shift_lg=shift_lg, 76 | ) 77 | 78 | return state 79 | -------------------------------------------------------------------------------- /transport/integrators.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/hustvl/LightningDiT/blob/main/transport/integrators.py 3 | """ 4 | 5 | import torch as th 6 | from torchdiffeq import odeint 7 | 8 | 9 | class sde: 10 | """SDE solver class""" 11 | 12 | def __init__( 13 | self, 14 | drift, 15 | diffusion, 16 | *, 17 | t0, 18 | t1, 19 | num_steps, 20 | sampler_type, 21 | ): 22 | assert t0 < t1, "SDE sampler has to be in forward time" 23 | 24 | self.num_timesteps = num_steps 25 | self.t = th.linspace(t0, t1, num_steps) 26 | self.dt = self.t[1] - self.t[0] 27 | self.drift = drift 28 | self.diffusion = diffusion 29 | self.sampler_type = sampler_type 30 | 31 | def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs): 32 | w_cur = 
th.randn(x.size()).to(x) 33 | t = th.ones(x.size(0)).to(x) * t 34 | dw = w_cur * th.sqrt(self.dt) 35 | drift = self.drift(x, t, model, **model_kwargs) 36 | diffusion = self.diffusion(x, t) 37 | mean_x = x + drift * self.dt 38 | x = mean_x + th.sqrt(2 * diffusion) * dw 39 | return x, mean_x 40 | 41 | def __Heun_step(self, x, _, t, model, **model_kwargs): 42 | w_cur = th.randn(x.size()).to(x) 43 | dw = w_cur * th.sqrt(self.dt) 44 | t_cur = th.ones(x.size(0)).to(x) * t 45 | diffusion = self.diffusion(x, t_cur) 46 | xhat = x + th.sqrt(2 * diffusion) * dw 47 | K1 = self.drift(xhat, t_cur, model, **model_kwargs) 48 | xp = xhat + self.dt * K1 49 | K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs) 50 | return ( 51 | xhat + 0.5 * self.dt * (K1 + K2), 52 | xhat, 53 | ) # at last time point we do not perform the heun step 54 | 55 | def __forward_fn(self): 56 | """TODO: generalize here by adding all private functions ending with steps to it""" 57 | sampler_dict = { 58 | "Euler": self.__Euler_Maruyama_step, 59 | "Heun": self.__Heun_step, 60 | } 61 | 62 | try: 63 | sampler = sampler_dict[self.sampler_type] 64 | except: 65 | raise NotImplementedError("Smapler type not implemented.") 66 | 67 | return sampler 68 | 69 | def sample(self, init, model, **model_kwargs): 70 | """forward loop of sde""" 71 | x = init 72 | mean_x = init 73 | samples = [] 74 | sampler = self.__forward_fn() 75 | for ti in self.t[:-1]: 76 | with th.no_grad(): 77 | x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs) 78 | samples.append(x) 79 | 80 | return samples 81 | 82 | 83 | class ode: 84 | """ODE solver class""" 85 | 86 | def __init__( 87 | self, 88 | drift, 89 | *, 90 | t0, 91 | t1, 92 | sampler_type, 93 | num_steps, 94 | atol, 95 | rtol, 96 | timestep_shift, 97 | ): 98 | assert t0 < t1, "ODE sampler has to be in forward time" 99 | 100 | self.drift = drift 101 | self.t = th.linspace(t0, t1, num_steps) 102 | 103 | if timestep_shift > 0: 104 | 105 | def compute_tm(t_n, timestep_shift): 106 | 
numerator = timestep_shift * t_n 107 | denominator = 1 + (timestep_shift - 1) * t_n 108 | return numerator / denominator 109 | 110 | self.t = th.tensor([compute_tm(t_n, timestep_shift) for t_n in self.t]) 111 | 112 | self.atol = atol 113 | self.rtol = rtol 114 | self.sampler_type = sampler_type 115 | 116 | def sample(self, x, model, **model_kwargs): 117 | 118 | device = x[0].device if isinstance(x, tuple) else x.device 119 | 120 | def _fn(t, x): 121 | t = ( 122 | th.ones(x[0].size(0)).to(device) * t 123 | if isinstance(x, tuple) 124 | else th.ones(x.size(0)).to(device) * t 125 | ) 126 | model_output = self.drift(x, t, model, **model_kwargs) 127 | return model_output 128 | 129 | t = self.t.to(device) 130 | atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol] 131 | rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol] 132 | samples = odeint(_fn, x, t, method=self.sampler_type, atol=atol, rtol=rtol) 133 | return samples 134 | -------------------------------------------------------------------------------- /transport/path.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/hustvl/LightningDiT/blob/main/transport/path.py 3 | """ 4 | 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | 10 | def expand_t_like_x(t, x): 11 | """Function to reshape time t to broadcastable dimension of x 12 | Args: 13 | t: [batch_dim,], time vector 14 | x: [batch_dim,...], data point 15 | """ 16 | dims = [1] * (len(x.size()) - 1) 17 | t = t.view(t.size(0), *dims) 18 | return t 19 | 20 | 21 | class ICPlan: 22 | """Linear Coupling Plan""" 23 | 24 | def __init__(self, sigma=0.0): 25 | self.sigma = sigma 26 | 27 | def compute_alpha_t(self, t): 28 | """Compute the data coefficient along the path""" 29 | return t, 1 30 | 31 | def compute_sigma_t(self, t): 32 | """Compute the noise coefficient along the path""" 33 | return 1 - t, -1 34 | 35 | def compute_d_alpha_alpha_ratio_t(self, t): 36 | 
"""Compute the ratio between d_alpha and alpha""" 37 | return 1 / t 38 | 39 | def compute_drift(self, x, t): 40 | """We always output sde according to score parametrization;""" 41 | t = expand_t_like_x(t, x) 42 | alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t) 43 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 44 | drift = alpha_ratio * x 45 | diffusion = alpha_ratio * (sigma_t**2) - sigma_t * d_sigma_t 46 | 47 | return -drift, diffusion 48 | 49 | def compute_diffusion(self, x, t, form="constant", norm=1.0): 50 | """Compute the diffusion term of the SDE 51 | Args: 52 | x: [batch_dim, ...], data point 53 | t: [batch_dim,], time vector 54 | form: str, form of the diffusion term 55 | norm: float, norm of the diffusion term 56 | """ 57 | t = expand_t_like_x(t, x) 58 | choices = { 59 | "constant": norm, 60 | "SBDM": norm * self.compute_drift(x, t)[1], 61 | "sigma": norm * self.compute_sigma_t(t)[0], 62 | "linear": norm * (1 - t), 63 | "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2, 64 | "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2, 65 | } 66 | 67 | try: 68 | diffusion = choices[form] 69 | except KeyError: 70 | raise NotImplementedError(f"Diffusion form {form} not implemented") 71 | 72 | return diffusion 73 | 74 | def get_score_from_velocity(self, velocity, x, t): 75 | """Wrapper function: transfrom velocity prediction model to score 76 | Args: 77 | velocity: [batch_dim, ...] shaped tensor; velocity model output 78 | x: [batch_dim, ...] 
shaped tensor; x_t data point 79 | t: [batch_dim,] time tensor 80 | """ 81 | t = expand_t_like_x(t, x) 82 | alpha_t, d_alpha_t = self.compute_alpha_t(t) 83 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 84 | mean = x 85 | reverse_alpha_ratio = alpha_t / d_alpha_t 86 | var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t 87 | score = (reverse_alpha_ratio * velocity - mean) / var 88 | return score 89 | 90 | def get_noise_from_velocity(self, velocity, x, t): 91 | """Wrapper function: transfrom velocity prediction model to denoiser 92 | Args: 93 | velocity: [batch_dim, ...] shaped tensor; velocity model output 94 | x: [batch_dim, ...] shaped tensor; x_t data point 95 | t: [batch_dim,] time tensor 96 | """ 97 | t = expand_t_like_x(t, x) 98 | alpha_t, d_alpha_t = self.compute_alpha_t(t) 99 | sigma_t, d_sigma_t = self.compute_sigma_t(t) 100 | mean = x 101 | reverse_alpha_ratio = alpha_t / d_alpha_t 102 | var = reverse_alpha_ratio * d_sigma_t - sigma_t 103 | noise = (reverse_alpha_ratio * velocity - mean) / var 104 | return noise 105 | 106 | def get_velocity_from_score(self, score, x, t): 107 | """Wrapper function: transfrom score prediction model to velocity 108 | Args: 109 | score: [batch_dim, ...] shaped tensor; score model output 110 | x: [batch_dim, ...] 
shaped tensor; x_t data point 111 | t: [batch_dim,] time tensor 112 | """ 113 | t = expand_t_like_x(t, x) 114 | drift, var = self.compute_drift(x, t) 115 | velocity = var * score - drift 116 | return velocity 117 | 118 | def compute_mu_t(self, t, x0, x1): 119 | """Compute the mean of time-dependent density p_t""" 120 | t = expand_t_like_x(t, x1) 121 | alpha_t, _ = self.compute_alpha_t(t) 122 | sigma_t, _ = self.compute_sigma_t(t) 123 | return alpha_t * x1 + sigma_t * x0 124 | 125 | def compute_xt(self, t, x0, x1): 126 | """Sample xt from time-dependent density p_t; rng is required""" 127 | xt = self.compute_mu_t(t, x0, x1) 128 | return xt 129 | 130 | def compute_ut(self, t, x0, x1, xt): 131 | """Compute the vector field corresponding to p_t""" 132 | t = expand_t_like_x(t, x1) 133 | _, d_alpha_t = self.compute_alpha_t(t) 134 | _, d_sigma_t = self.compute_sigma_t(t) 135 | return d_alpha_t * x1 + d_sigma_t * x0 136 | 137 | def plan(self, t, x0, x1): 138 | xt = self.compute_xt(t, x0, x1) 139 | ut = self.compute_ut(t, x0, x1, xt) 140 | return t, xt, ut 141 | 142 | 143 | class VPCPlan(ICPlan): 144 | """class for VP path flow matching""" 145 | 146 | def __init__(self, sigma_min=0.1, sigma_max=20.0): 147 | self.sigma_min = sigma_min 148 | self.sigma_max = sigma_max 149 | self.log_mean_coeff = ( 150 | lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) 151 | - 0.5 * (1 - t) * self.sigma_min 152 | ) 153 | self.d_log_mean_coeff = ( 154 | lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min 155 | ) 156 | 157 | def compute_alpha_t(self, t): 158 | """Compute coefficient of x1""" 159 | alpha_t = self.log_mean_coeff(t) 160 | alpha_t = th.exp(alpha_t) 161 | d_alpha_t = alpha_t * self.d_log_mean_coeff(t) 162 | return alpha_t, d_alpha_t 163 | 164 | def compute_sigma_t(self, t): 165 | """Compute coefficient of x0""" 166 | p_sigma_t = 2 * self.log_mean_coeff(t) 167 | sigma_t = th.sqrt(1 - th.exp(p_sigma_t)) 168 | d_sigma_t = 
th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t) 169 | return sigma_t, d_sigma_t 170 | 171 | def compute_d_alpha_alpha_ratio_t(self, t): 172 | """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" 173 | return self.d_log_mean_coeff(t) 174 | 175 | def compute_drift(self, x, t): 176 | """Compute the drift term of the SDE""" 177 | t = expand_t_like_x(t, x) 178 | beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min) 179 | return -0.5 * beta_t * x, beta_t / 2 180 | 181 | 182 | class GVPCPlan(ICPlan): 183 | def __init__(self, sigma=0.0): 184 | super().__init__(sigma) 185 | 186 | def compute_alpha_t(self, t): 187 | """Compute coefficient of x1""" 188 | alpha_t = th.sin(t * np.pi / 2) 189 | d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2) 190 | return alpha_t, d_alpha_t 191 | 192 | def compute_sigma_t(self, t): 193 | """Compute coefficient of x0""" 194 | sigma_t = th.cos(t * np.pi / 2) 195 | d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2) 196 | return sigma_t, d_sigma_t 197 | 198 | def compute_d_alpha_alpha_ratio_t(self, t): 199 | """Special purposed function for computing numerical stabled d_alpha_t / alpha_t""" 200 | return np.pi / (2 * th.tan(t * np.pi / 2)) 201 | -------------------------------------------------------------------------------- /transport/transport.py: -------------------------------------------------------------------------------- 1 | """ 2 | Reference: https://github.com/hustvl/LightningDiT/blob/main/transport/transport.py 3 | """ 4 | 5 | import enum 6 | 7 | import numpy as np 8 | import torch as th 9 | from scipy.stats import norm 10 | 11 | from . import path 12 | from .integrators import ode, sde 13 | 14 | 15 | def mean_flat(x): 16 | """ 17 | Take the mean over all non-batch dimensions. 18 | """ 19 | return th.mean(x, dim=list(range(1, len(x.size())))) 20 | 21 | 22 | class ModelType(enum.Enum): 23 | """ 24 | Which type of output the model predicts. 
25 | """ 26 | 27 | NOISE = enum.auto() # the model predicts epsilon 28 | SCORE = enum.auto() # the model predicts \nabla \log p(x) 29 | VELOCITY = enum.auto() # the model predicts v(x) 30 | 31 | 32 | class PathType(enum.Enum): 33 | """ 34 | Which type of path to use. 35 | """ 36 | 37 | LINEAR = enum.auto() 38 | GVP = enum.auto() 39 | VP = enum.auto() 40 | 41 | 42 | class WeightType(enum.Enum): 43 | """ 44 | Which type of weighting to use. 45 | """ 46 | 47 | NONE = enum.auto() 48 | VELOCITY = enum.auto() 49 | LIKELIHOOD = enum.auto() 50 | 51 | 52 | class Transport: 53 | def __init__( 54 | self, 55 | *, 56 | model_type, 57 | path_type, 58 | loss_type, 59 | train_eps, 60 | sample_eps, 61 | use_cosine_loss=False, 62 | use_lognorm=False, 63 | partitial_train=None, 64 | partial_ratio=1.0, 65 | shift_lg=False, 66 | ): 67 | path_options = { 68 | PathType.LINEAR: path.ICPlan, 69 | PathType.GVP: path.GVPCPlan, 70 | PathType.VP: path.VPCPlan, 71 | } 72 | 73 | self.loss_type = loss_type 74 | self.model_type = model_type 75 | self.path_sampler = path_options[path_type]() 76 | self.train_eps = train_eps 77 | self.sample_eps = sample_eps 78 | self.use_cosine_loss = use_cosine_loss 79 | self.use_lognorm = use_lognorm 80 | self.partitial_train = partitial_train 81 | self.partial_ratio = partial_ratio 82 | self.shift_lg = shift_lg 83 | 84 | def prior_logp(self, z): 85 | """ 86 | Standard multivariate normal prior 87 | Assume z is batched 88 | """ 89 | shape = th.tensor(z.size()) 90 | N = th.prod(shape[1:]) 91 | _fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0 92 | return th.vmap(_fn)(z) 93 | 94 | def check_interval( 95 | self, 96 | train_eps, 97 | sample_eps, 98 | *, 99 | diffusion_form="SBDM", 100 | sde=False, 101 | reverse=False, 102 | eval=False, 103 | last_step_size=0.0, 104 | ): 105 | t0 = 0 106 | t1 = 1 107 | eps = train_eps if not eval else sample_eps 108 | if type(self.path_sampler) in [path.VPCPlan]: 109 | 110 | t1 = 1 - eps if (not sde or last_step_size == 
0) else 1 - last_step_size 111 | 112 | elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and ( 113 | self.model_type != ModelType.VELOCITY or sde 114 | ): # avoid numerical issue by taking a first semi-implicit step 115 | 116 | t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0 117 | t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size 118 | 119 | if reverse: 120 | t0, t1 = 1 - t0, 1 - t1 121 | 122 | return t0, t1 123 | 124 | def sample_logit_normal(self, mu, sigma, size=1): 125 | # Generate samples from the normal distribution 126 | samples = norm.rvs(loc=mu, scale=sigma, size=size) 127 | 128 | # Transform samples to be in the range (0, 1) using the logistic function 129 | samples = 1 / (1 + np.exp(-samples)) 130 | 131 | # Numpy to Tensor 132 | samples = th.tensor(samples, dtype=th.float32) 133 | 134 | return samples 135 | 136 | def sample_in_range(self, mu, sigma, target_size, range_min=0, range_max=0.5): 137 | samples = [] 138 | while len(samples) < target_size: 139 | generated_samples = self.sample_logit_normal(mu, sigma, size=target_size) 140 | filtered_samples = generated_samples[ 141 | (generated_samples >= range_min) & (generated_samples <= range_max) 142 | ] 143 | samples.extend(filtered_samples) 144 | 145 | # If we have more than the target size, truncate the list 146 | samples = samples[:target_size] 147 | return th.tensor(samples) 148 | 149 | def sample(self, x1, sp_timesteps=None, shifted_mu=0): 150 | """Sampling x0 & t based on shape of x1 (if needed) 151 | Args: 152 | x1 - data point; [batch, *dim] 153 | """ 154 | 155 | x0 = th.randn_like(x1) 156 | t0, t1 = self.check_interval(self.train_eps, self.sample_eps) 157 | if not self.use_lognorm: 158 | if self.partitial_train is not None and th.rand(1) < self.partial_ratio: 159 | t = ( 160 | th.rand((x1.shape[0],)) * (self.partitial_train[1] - self.partitial_train[0]) 161 | + self.partitial_train[0] 162 | ) 163 | else: 164 | t = 
th.rand((x1.shape[0],)) * (t1 - t0) + t0 165 | else: 166 | # random < partial_ratio, then sample from the partial range 167 | if not self.shift_lg: 168 | if self.partitial_train is not None and th.rand(1) < self.partial_ratio: 169 | t = self.sample_in_range( 170 | 0, 171 | 1, 172 | x1.shape[0], 173 | range_min=self.partitial_train[0], 174 | range_max=self.partitial_train[1], 175 | ) 176 | else: 177 | t = self.sample_logit_normal(0, 1, size=x1.shape[0]) * (t1 - t0) + t0 178 | else: 179 | assert ( 180 | self.partitial_train is None 181 | ), "Shifted lognormal distribution is not compatible with partial training" 182 | t = self.sample_logit_normal(shifted_mu, 1, size=x1.shape[0]) * (t1 - t0) + t0 183 | 184 | # overwrite t if sp_timesteps is provided (for validation) 185 | if sp_timesteps is not None: 186 | # uniform sampling between self.sp_timesteps[0] and self.sp_timesteps[1] 187 | t = th.rand((x1.shape[0],)) * (sp_timesteps[1] - sp_timesteps[0]) + sp_timesteps[0] 188 | 189 | t = t.to(x1) 190 | return t, x0, x1 191 | 192 | def training_losses( 193 | self, 194 | model, 195 | x1, 196 | model_kwargs=None, 197 | sp_timesteps=None, 198 | shifted_mu=0, 199 | return_model_output=False, 200 | ): 201 | """Loss for training the score model 202 | Args: 203 | - model: backbone model; could be score, noise, or velocity 204 | - x1: datapoint 205 | - model_kwargs: additional arguments for the model 206 | """ 207 | if model_kwargs == None: 208 | model_kwargs = {} 209 | 210 | t, x0, x1 = self.sample(x1, sp_timesteps, shifted_mu) 211 | t, xt, ut = self.path_sampler.plan(t, x0, x1) 212 | raw_model_output = model(xt, t, **model_kwargs) 213 | if isinstance(raw_model_output, tuple): 214 | model_output = raw_model_output[0] 215 | else: 216 | model_output = raw_model_output 217 | B, *_, C = xt.shape 218 | # the channel dim is the one with least size 219 | channel_dim = min([i for i in range(1, len(xt.shape))], key=lambda x: xt.size(x)) 220 | assert model_output.size() == (B, 
*xt.size()[1:-1], C) 221 | 222 | terms = {} 223 | terms["pred"] = model_output 224 | if self.model_type == ModelType.VELOCITY: 225 | terms["loss"] = mean_flat(((model_output - ut) ** 2)) 226 | if self.use_cosine_loss: 227 | terms["cos_loss"] = mean_flat( 228 | 1 - th.nn.functional.cosine_similarity(model_output, ut, dim=channel_dim) 229 | ) 230 | else: 231 | _, drift_var = self.path_sampler.compute_drift(xt, t) 232 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt)) 233 | if self.loss_type in [WeightType.VELOCITY]: 234 | weight = (drift_var / sigma_t) ** 2 235 | elif self.loss_type in [WeightType.LIKELIHOOD]: 236 | weight = drift_var / (sigma_t**2) 237 | elif self.loss_type in [WeightType.NONE]: 238 | weight = 1 239 | else: 240 | raise NotImplementedError() 241 | 242 | if self.model_type == ModelType.NOISE: 243 | terms["loss"] = mean_flat(weight * ((model_output - x0) ** 2)) 244 | else: 245 | terms["loss"] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2)) 246 | if return_model_output: 247 | return terms, raw_model_output 248 | return terms 249 | 250 | def get_drift(self): 251 | """member function for obtaining the drift of the probability flow ODE""" 252 | 253 | def score_ode(x, t, model, **model_kwargs): 254 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t) 255 | model_output = model(x, t, **model_kwargs) 256 | return -drift_mean + drift_var * model_output # by change of variable 257 | 258 | def noise_ode(x, t, model, **model_kwargs): 259 | drift_mean, drift_var = self.path_sampler.compute_drift(x, t) 260 | sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x)) 261 | model_output = model(x, t, **model_kwargs) 262 | score = model_output / -sigma_t 263 | return -drift_mean + drift_var * score 264 | 265 | def velocity_ode(x, t, model, **model_kwargs): 266 | model_output = model(x, t, **model_kwargs) 267 | return model_output 268 | 269 | if self.model_type == ModelType.NOISE: 270 | drift_fn = 
noise_ode 271 | elif self.model_type == ModelType.SCORE: 272 | drift_fn = score_ode 273 | else: 274 | drift_fn = velocity_ode 275 | 276 | def body_fn(x, t, model, **model_kwargs): 277 | model_output = drift_fn(x, t, model, **model_kwargs) 278 | assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape" 279 | return model_output 280 | 281 | return body_fn 282 | 283 | def get_score( 284 | self, 285 | ): 286 | """member function for obtaining score of 287 | x_t = alpha_t * x + sigma_t * eps""" 288 | if self.model_type == ModelType.NOISE: 289 | score_fn = ( 290 | lambda x, t, model, **kwargs: model(x, t, **kwargs) 291 | / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0] 292 | ) 293 | elif self.model_type == ModelType.SCORE: 294 | score_fn = lambda x, t, model, **kwagrs: model(x, t, **kwagrs) 295 | elif self.model_type == ModelType.VELOCITY: 296 | score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity( 297 | model(x, t, **kwargs), x, t 298 | ) 299 | else: 300 | raise NotImplementedError() 301 | 302 | return score_fn 303 | 304 | 305 | class Sampler: 306 | """Sampler class for the transport model""" 307 | 308 | def __init__( 309 | self, 310 | transport, 311 | ): 312 | """Constructor for a general sampler; supporting different sampling methods 313 | Args: 314 | - transport: an tranport object specify model prediction & interpolant type 315 | """ 316 | 317 | self.transport = transport 318 | self.drift = self.transport.get_drift() 319 | self.score = self.transport.get_score() 320 | 321 | def __get_sde_diffusion_and_drift( 322 | self, 323 | *, 324 | diffusion_form="SBDM", 325 | diffusion_norm=1.0, 326 | ): 327 | def diffusion_fn(x, t): 328 | diffusion = self.transport.path_sampler.compute_diffusion( 329 | x, t, form=diffusion_form, norm=diffusion_norm 330 | ) 331 | return diffusion 332 | 333 | sde_drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs) + diffusion_fn( 334 | x, t 335 
| ) * self.score(x, t, model, **kwargs) 336 | 337 | sde_diffusion = diffusion_fn 338 | 339 | return sde_drift, sde_diffusion 340 | 341 | def __get_last_step( 342 | self, 343 | sde_drift, 344 | *, 345 | last_step, 346 | last_step_size, 347 | ): 348 | """Get the last step function of the SDE solver""" 349 | 350 | if last_step is None: 351 | last_step_fn = lambda x, t, model, **model_kwargs: x 352 | elif last_step == "Mean": 353 | last_step_fn = ( 354 | lambda x, t, model, **model_kwargs: x 355 | + sde_drift(x, t, model, **model_kwargs) * last_step_size 356 | ) 357 | elif last_step == "Tweedie": 358 | alpha = ( 359 | self.transport.path_sampler.compute_alpha_t 360 | ) # simple aliasing; the original name was too long 361 | sigma = self.transport.path_sampler.compute_sigma_t 362 | last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + ( 363 | sigma(t)[0][0] ** 2 364 | ) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs) 365 | elif last_step == "Euler": 366 | last_step_fn = ( 367 | lambda x, t, model, **model_kwargs: x 368 | + self.drift(x, t, model, **model_kwargs) * last_step_size 369 | ) 370 | else: 371 | raise NotImplementedError() 372 | 373 | return last_step_fn 374 | 375 | def sample_sde( 376 | self, 377 | *, 378 | sampling_method="Euler", 379 | diffusion_form="SBDM", 380 | diffusion_norm=1.0, 381 | last_step="Mean", 382 | last_step_size=0.04, 383 | num_steps=250, 384 | ): 385 | """returns a sampling function with given SDE settings 386 | Args: 387 | - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama 388 | - diffusion_form: function form of diffusion coefficient; default to be matching SBDM 389 | - diffusion_norm: function magnitude of diffusion coefficient; default to 1 390 | - last_step: type of the last step; default to identity 391 | - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1] 392 | - num_steps: total integration step of SDE 393 | """ 394 | 395 | if 
last_step is None: 396 | last_step_size = 0.0 397 | 398 | sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift( 399 | diffusion_form=diffusion_form, 400 | diffusion_norm=diffusion_norm, 401 | ) 402 | 403 | t0, t1 = self.transport.check_interval( 404 | self.transport.train_eps, 405 | self.transport.sample_eps, 406 | diffusion_form=diffusion_form, 407 | sde=True, 408 | eval=True, 409 | reverse=False, 410 | last_step_size=last_step_size, 411 | ) 412 | 413 | _sde = sde( 414 | sde_drift, 415 | sde_diffusion, 416 | t0=t0, 417 | t1=t1, 418 | num_steps=num_steps, 419 | sampler_type=sampling_method, 420 | ) 421 | 422 | last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size) 423 | 424 | def _sample(init, model, **model_kwargs): 425 | xs = _sde.sample(init, model, **model_kwargs) 426 | ts = th.ones(init.size(0), device=init.device) * t1 427 | x = last_step_fn(xs[-1], ts, model, **model_kwargs) 428 | xs.append(x) 429 | 430 | assert len(xs) == num_steps, "Samples does not match the number of steps" 431 | 432 | return xs 433 | 434 | return _sample 435 | 436 | def sample_ode( 437 | self, 438 | *, 439 | sampling_method="dopri5", 440 | num_steps=50, 441 | atol=1e-6, 442 | rtol=1e-3, 443 | reverse=False, 444 | timestep_shift=0.0, 445 | ): 446 | """returns a sampling function with given ODE settings 447 | Args: 448 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 449 | - num_steps: 450 | - fixed solver (Euler, Heun): the actual number of integration steps performed 451 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation 452 | - atol: absolute error tolerance for the solver 453 | - rtol: relative error tolerance for the solver 454 | - reverse: whether solving the ODE in reverse (data to noise); default to False 455 | """ 456 | if reverse: 457 | drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs) 458 | 
else: 459 | drift = self.drift 460 | 461 | t0, t1 = self.transport.check_interval( 462 | self.transport.train_eps, 463 | self.transport.sample_eps, 464 | sde=False, 465 | eval=True, 466 | reverse=reverse, 467 | last_step_size=0.0, 468 | ) 469 | 470 | _ode = ode( 471 | drift=drift, 472 | t0=t0, 473 | t1=t1, 474 | sampler_type=sampling_method, 475 | num_steps=num_steps, 476 | atol=atol, 477 | rtol=rtol, 478 | timestep_shift=timestep_shift, 479 | ) 480 | 481 | return _ode.sample 482 | 483 | def sample_ode_likelihood( 484 | self, 485 | *, 486 | sampling_method="dopri5", 487 | num_steps=50, 488 | atol=1e-6, 489 | rtol=1e-3, 490 | ): 491 | """returns a sampling function for calculating likelihood with given ODE settings 492 | Args: 493 | - sampling_method: type of sampler used in solving the ODE; default to be Dopri5 494 | - num_steps: 495 | - fixed solver (Euler, Heun): the actual number of integration steps performed 496 | - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation 497 | - atol: absolute error tolerance for the solver 498 | - rtol: relative error tolerance for the solver 499 | """ 500 | 501 | def _likelihood_drift(x, t, model, **model_kwargs): 502 | x, _ = x 503 | eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1 504 | t = th.ones_like(t) * (1 - t) 505 | with th.enable_grad(): 506 | x.requires_grad = True 507 | grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0] 508 | logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size())))) 509 | drift = self.drift(x, t, model, **model_kwargs) 510 | return (-drift, logp_grad) 511 | 512 | t0, t1 = self.transport.check_interval( 513 | self.transport.train_eps, 514 | self.transport.sample_eps, 515 | sde=False, 516 | eval=True, 517 | reverse=False, 518 | last_step_size=0.0, 519 | ) 520 | 521 | _ode = ode( 522 | drift=_likelihood_drift, 523 | t0=t0, 524 | t1=t1, 525 | sampler_type=sampling_method, 526 | 
num_steps=num_steps, 527 | atol=atol, 528 | rtol=rtol, 529 | ) 530 | 531 | def _sample_fn(x, model, **model_kwargs): 532 | init_logp = th.zeros(x.size(0)).to(x) 533 | input = (x, init_logp) 534 | drift, delta_logp = _ode.sample(input, model, **model_kwargs) 535 | drift, delta_logp = drift[-1], delta_logp[-1] 536 | prior_logp = self.transport.prior_logp(drift) 537 | logp = prior_logp - delta_logp 538 | return logp, drift 539 | 540 | return _sample_fn 541 | -------------------------------------------------------------------------------- /utils/builders.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch.utils.data 4 | import torchvision.transforms as transforms 5 | 6 | import models 7 | import utils.distributed as distributed 8 | import utils.losses as losses 9 | from utils.loader import ListDataset, center_crop_arr 10 | from utils.misc import NativeScalerWithGradNormCount 11 | 12 | logger = logging.getLogger("DeTok") 13 | 14 | 15 | def create_train_dataloader(args, should_flip=True, batch_size=-1, return_path=False, drop_last=True): 16 | transform_train = transforms.Compose( 17 | [ 18 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)), 19 | transforms.ToTensor(), 20 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 21 | ] 22 | ) 23 | input_transform = transform_train if not args.use_cached_tokens else None 24 | dataset_train = ListDataset( 25 | args.data_path, 26 | data_list="data/train.txt", 27 | transform=input_transform, 28 | loader_name="img_loader" if not args.use_cached_tokens else "npz_loader", 29 | return_label=True, 30 | return_path=return_path, 31 | should_flip=should_flip, 32 | ) 33 | logger.info(f"Train dataset size: {len(dataset_train)}") 34 | 35 | sampler_train = torch.utils.data.DistributedSampler( 36 | dataset_train, 37 | num_replicas=distributed.get_world_size(), 38 | rank=distributed.get_global_rank(), 39 | shuffle=True, 40 | ) 41 | 
data_loader_train = torch.utils.data.DataLoader( 42 | dataset_train, 43 | sampler=sampler_train, 44 | batch_size=args.batch_size if batch_size < 0 else batch_size, 45 | num_workers=args.num_workers, 46 | pin_memory=args.pin_mem, 47 | drop_last=drop_last, 48 | ) 49 | return data_loader_train 50 | 51 | 52 | def create_val_dataloader(args): 53 | transform_val = transforms.Compose( 54 | [ 55 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)), 56 | transforms.ToTensor(), 57 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 58 | ] 59 | ) 60 | dataset_val = ListDataset( 61 | args.data_path.replace("train", "val"), 62 | data_list="data/val.txt", 63 | transform=transform_val, 64 | loader_name="img_loader", 65 | return_label=False, 66 | return_index=True, 67 | should_flip=False, 68 | ) 69 | sampler_val = torch.utils.data.DistributedSampler( 70 | dataset_val, 71 | num_replicas=distributed.get_world_size(), 72 | rank=distributed.get_global_rank(), 73 | shuffle=False, 74 | ) 75 | 76 | logger.info(f"Val dataset size: {len(dataset_val)}") 77 | 78 | data_loader_val = torch.utils.data.DataLoader( 79 | dataset_val, 80 | sampler=sampler_val, 81 | batch_size=args.eval_bsz, 82 | num_workers=args.num_workers, 83 | pin_memory=args.pin_mem, 84 | drop_last=False, 85 | ) 86 | return data_loader_val 87 | 88 | 89 | def create_vis_dataloader(args): 90 | transform_val = transforms.Compose( 91 | [ 92 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.img_size)), 93 | transforms.ToTensor(), 94 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 95 | ] 96 | ) 97 | dataset_vis = ListDataset( 98 | args.data_path, 99 | data_list="data/train.txt", 100 | transform=transform_val, 101 | loader_name="img_loader", 102 | return_label=False, 103 | return_index=True, 104 | class_of_interest=args.class_of_interest, 105 | ) 106 | sampler_vis = torch.utils.data.DistributedSampler( 107 | dataset_vis, 108 | num_replicas=distributed.get_world_size(), 
109 | rank=distributed.get_global_rank(), 110 | shuffle=True, 111 | ) 112 | 113 | logger.info(f"Vis dataset size: {len(dataset_vis)}") 114 | 115 | data_loader_vis = torch.utils.data.DataLoader( 116 | dataset_vis, 117 | sampler=sampler_vis, 118 | batch_size=8, 119 | num_workers=args.num_workers, 120 | pin_memory=args.pin_mem, 121 | drop_last=False, 122 | ) 123 | return data_loader_vis 124 | 125 | 126 | def create_generation_model(args): 127 | logger.info("Creating generation models.") 128 | if args.tokenizer is not None: 129 | if args.tokenizer in models.VAE_models: 130 | tokenizer = models.VAE_models[args.tokenizer]() 131 | elif args.tokenizer in models.DeTok_models: 132 | tokenizer = models.DeTok_models[args.tokenizer]( 133 | img_size=args.img_size, 134 | patch_size=args.tokenizer_patch_size, 135 | token_channels=args.token_channels, 136 | mask_ratio=0.0, 137 | ) 138 | else: 139 | raise ValueError(f"Unsupported tokenizer {args.tokenizer}") 140 | if args.load_tokenizer_from is not None: 141 | logger.info(f"[Tokenizer] Loading tokenizer from: {args.load_tokenizer_from}") 142 | weights = torch.load(args.load_tokenizer_from, weights_only=False, map_location="cpu") 143 | if args.use_ema_tokenizer and "model_ema" in weights: 144 | weights = weights["model_ema"] 145 | msg = tokenizer.load_state_dict(weights, strict=False) 146 | logger.info(f"[Tokenizer] Missing keys: {msg.missing_keys}") 147 | logger.info(f"[Tokenizer] Unexpected keys: {msg.unexpected_keys}") 148 | logger.info("[Tokenizer] Loaded EMA tokenizer.") 149 | else: 150 | if args.use_ema_tokenizer: 151 | logger.warning("EMA tokenizer is not in the checkpoint, using the model weights") 152 | weights = weights["model"] if "model" in weights else weights 153 | msg = tokenizer.load_state_dict(weights, strict=True) 154 | logger.info(f"[Tokenizer] Missing keys: {msg.missing_keys}") 155 | logger.info(f"[Tokenizer] Unexpected keys: {msg.unexpected_keys}") 156 | tokenizer.cuda().eval().requires_grad_(False) 157 | 
logger.info("====Tokenizer=====") 158 | logger.info(tokenizer) 159 | else: 160 | tokenizer = None 161 | 162 | if args.model in models.DiT_models: 163 | model = models.DiT_models[args.model]( 164 | img_size=args.img_size, 165 | patch_size=args.patch_size, 166 | tokenizer_patch_size=args.tokenizer_patch_size, 167 | token_channels=args.token_channels, 168 | label_drop_prob=args.label_drop_prob, 169 | num_classes=args.num_classes, 170 | num_sampling_steps=args.num_sampling_steps, 171 | force_one_d_seq=args.force_one_d_seq, 172 | grad_checkpointing=args.grad_checkpointing, 173 | legacy_mode=args.legacy_mode, # legacy mode: cfg on the first three channels only 174 | ) 175 | elif args.model in models.SiT_models: 176 | model = models.SiT_models[args.model]( 177 | img_size=args.img_size, 178 | patch_size=args.patch_size, 179 | tokenizer_patch_size=args.tokenizer_patch_size, 180 | token_channels=args.token_channels, 181 | label_drop_prob=args.label_drop_prob, 182 | num_classes=args.num_classes, 183 | num_sampling_steps=args.num_sampling_steps, 184 | grad_checkpointing=args.grad_checkpointing, 185 | force_one_d_seq=args.force_one_d_seq, 186 | legacy_mode=args.legacy_mode, # legacy mode: cfg on the first three channels only 187 | qk_norm=args.qk_norm, 188 | ) 189 | elif args.model in models.LightningDiT_models: 190 | model = models.LightningDiT_models[args.model]( 191 | img_size=args.img_size, 192 | patch_size=args.patch_size, 193 | tokenizer_patch_size=args.tokenizer_patch_size, 194 | token_channels=args.token_channels, 195 | label_drop_prob=args.label_drop_prob, 196 | num_classes=args.num_classes, 197 | num_sampling_steps=args.num_sampling_steps, 198 | force_one_d_seq=args.force_one_d_seq, 199 | grad_checkpointing=args.grad_checkpointing, 200 | legacy_mode=args.legacy_mode, # legacy mode: cfg on the first three channels only 201 | qk_norm=args.qk_norm, 202 | ) 203 | elif args.model in models.ARDiff_models: 204 | model = models.ARDiff_models[args.model]( 205 | 
img_size=args.img_size, 206 | patch_size=args.patch_size, 207 | tokenizer_patch_size=args.tokenizer_patch_size, 208 | token_channels=args.token_channels, 209 | label_drop_prob=args.label_drop_prob, 210 | num_classes=args.num_classes, 211 | num_sampling_steps=args.num_sampling_steps, 212 | diffloss_d=args.diffloss_d, 213 | diffloss_w=args.diffloss_w, 214 | diffusion_batch_mul=args.diffusion_batch_mul, 215 | noise_schedule=args.noise_schedule, 216 | force_one_d_seq=args.force_one_d_seq, 217 | grad_checkpointing=args.grad_checkpointing, 218 | order=args.order, 219 | ) 220 | elif args.model in models.MAR_models: 221 | model = models.MAR_models[args.model]( 222 | img_size=args.img_size, 223 | patch_size=args.patch_size, 224 | tokenizer_patch_size=args.tokenizer_patch_size, 225 | token_channels=args.token_channels, 226 | label_drop_prob=args.label_drop_prob, 227 | num_classes=args.num_classes, 228 | num_sampling_steps=args.num_sampling_steps, 229 | diffloss_d=args.diffloss_d, 230 | diffloss_w=args.diffloss_w, 231 | diffusion_batch_mul=args.diffusion_batch_mul, 232 | noise_schedule=args.noise_schedule, 233 | attn_dropout=args.attn_dropout, 234 | proj_dropout=args.proj_dropout, 235 | buffer_size=args.buffer_size, 236 | mask_ratio_min=args.mask_ratio_min, 237 | grad_checkpointing=args.grad_checkpointing, 238 | force_one_d_seq=args.force_one_d_seq, 239 | no_dropout_in_mlp=args.no_dropout_in_mlp, 240 | ) 241 | else: 242 | raise ValueError(f"Unsupported model {args.model}") 243 | 244 | model.cuda() 245 | logger.info("====Model=====") 246 | logger.info(model) 247 | n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 248 | logger.info(f"{args.model} Parameters: {n_params / 1e6:.2f}M ({n_params:,})") 249 | 250 | # ema model 251 | ema = models.SimpleEMAModel(model, decay=args.ema_rate) 252 | return model, tokenizer, ema 253 | 254 | 255 | def create_reconstruction_model(args): 256 | logger.info("Creating reconstruction models.") 257 | if args.model in 
models.VAE_models: 258 | model = models.VAE_models[args.model]( 259 | load_ckpt=not getattr(args, "no_load_ckpt", False), 260 | gamma=args.gamma, 261 | ) 262 | elif args.model in models.DeTok_models: 263 | model = models.DeTok_models[args.model]( 264 | img_size=args.img_size, 265 | patch_size=args.patch_size, 266 | token_channels=args.token_channels, 267 | mask_ratio=args.mask_ratio, 268 | gamma=args.gamma, 269 | ) 270 | else: 271 | raise ValueError(f"Unsupported model {args.model}") 272 | 273 | model.cuda() 274 | logger.info("====Model=====") 275 | logger.info(model) 276 | n_params = sum(p.numel() for p in model.parameters() if p.requires_grad) 277 | logger.info(f"{args.model} Trainable Parameters: {n_params / 1e6:.2f}M ({n_params:,})") 278 | ema = models.SimpleEMAModel(model, decay=args.ema_rate) 279 | return model, ema 280 | 281 | 282 | def create_optimizer_and_scaler(args, model, print_trainable_params=False): 283 | logger.info("creating optimizers") 284 | 285 | # exclude parameters from weight decay 286 | exclude = lambda name, p: ( 287 | p.ndim < 2 or any(keyword in name for keyword in 288 | ["ln", "bias", "embedding", "norm", "gamma", "embed", "token", "diffloss"]) 289 | ) 290 | 291 | named_parameters = list(model.named_parameters()) 292 | no_decay_list = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] 293 | rest_params = [p for n, p in named_parameters if not exclude(n, p) and p.requires_grad] 294 | eff_batch_size = args.batch_size * args.world_size 295 | 296 | if args.lr is None: 297 | args.lr = args.blr * eff_batch_size / 256 298 | 299 | logger.info(f"base lr: {args.lr * 256 / eff_batch_size:.6e}") 300 | logger.info(f"actual lr: {args.lr:.6e}") 301 | logger.info(f"effective batch size: {eff_batch_size}") 302 | logger.info(f"training with {args.world_size} gpus") 303 | logger.info(f"weight_decay: {args.weight_decay} on {len(rest_params)} weight tensors") 304 | logger.info(f"no_decay: {len(no_decay_list)} weight tensors") 305 | 306 | 
optimizer = torch.optim.AdamW( 307 | [ 308 | {"params": no_decay_list, "weight_decay": 0.0}, 309 | {"params": rest_params, "weight_decay": args.weight_decay}, 310 | ], 311 | lr=args.lr, 312 | betas=(args.beta1, args.beta2), 313 | ) 314 | logger.info(f"Optimizer = {str(optimizer)}") 315 | if print_trainable_params: 316 | logger.info("trainable parameters:") 317 | for name, param in model.named_parameters(): 318 | if param.requires_grad: 319 | logger.info(f"\t{name}") 320 | 321 | loss_scaler = NativeScalerWithGradNormCount() 322 | logger.info(f"Loss Scaler = {str(loss_scaler)}") 323 | return optimizer, loss_scaler 324 | 325 | 326 | def create_loss_module(args): 327 | loss_module = losses.ReconstructionLoss( 328 | discriminator_start_epoch=getattr(args, "discriminator_start_epoch", 20), 329 | perceptual_loss=getattr(args, "perceptual_loss", "lpips-convnext_s-1.0-0.1"), 330 | perceptual_weight=getattr(args, "perceptual_weight", 1.1), 331 | kl_weight=args.kl_loss_weight, 332 | ) 333 | loss_module.cuda() 334 | logger.info("====Loss Module=====") 335 | # logger.info(loss_module) 336 | return loss_module 337 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | import random 4 | import re 5 | import socket 6 | import sys 7 | 8 | import torch 9 | import torch.distributed as dist 10 | import torch.distributed.nn as dist_nn 11 | 12 | _LOCAL_RANK = -1 13 | _LOCAL_WORLD_SIZE = -1 14 | 15 | _TORCH_DISTRIBUTED_ENV_VARS = ( 16 | "MASTER_ADDR", 17 | "MASTER_PORT", 18 | "RANK", 19 | "WORLD_SIZE", 20 | "LOCAL_RANK", 21 | "LOCAL_WORLD_SIZE", 22 | ) 23 | 24 | 25 | def all_reduce_mean(x): 26 | world_size = get_world_size() 27 | if world_size > 1: 28 | if isinstance(x, torch.Tensor): 29 | x_reduce = x.clone().detach().cuda() 30 | else: 31 | x_reduce = torch.tensor(x).cuda() 32 | dist.all_reduce(x_reduce) 33 | x_reduce 
= x_reduce.float() / world_size 34 | return x_reduce.item() 35 | return x 36 | 37 | 38 | def concat_all_gather(tensor, gather_dim=0) -> torch.Tensor: 39 | if dist.get_world_size() == 1: 40 | return tensor 41 | output = dist_nn.functional.all_gather(tensor) 42 | return torch.cat(output, dim=gather_dim) 43 | 44 | 45 | def is_enabled() -> bool: 46 | return dist.is_available() and dist.is_initialized() 47 | 48 | 49 | def get_global_rank() -> int: 50 | return dist.get_rank() if is_enabled() else 0 51 | 52 | 53 | def get_world_size(): 54 | return dist.get_world_size() if is_enabled() else 1 55 | 56 | 57 | def is_main_process() -> bool: 58 | return get_global_rank() == 0 59 | 60 | 61 | def _is_slurm_job_process() -> bool: 62 | return "SLURM_JOB_ID" in os.environ and not os.isatty(sys.stdout.fileno()) 63 | 64 | 65 | def _parse_slurm_node_list(s: str) -> list[str]: 66 | nodes = [] 67 | p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?") 68 | for m in p.finditer(s): 69 | prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)] 70 | for suffix in suffixes.split(","): 71 | span = suffix.split("-") 72 | if len(span) == 1: 73 | nodes.append(prefix + suffix) 74 | else: 75 | width = len(span[0]) 76 | start, end = int(span[0]), int(span[1]) + 1 77 | nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)]) 78 | return nodes 79 | 80 | 81 | def _get_available_port() -> int: 82 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: 83 | s.bind(("", 0)) 84 | return s.getsockname()[1] 85 | 86 | 87 | @functools.lru_cache 88 | def enable_distributed(): 89 | if _is_slurm_job_process(): 90 | os.environ["MASTER_ADDR"] = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"])[0] 91 | os.environ["MASTER_PORT"] = str(random.Random(os.environ["SLURM_JOB_ID"]).randint(20_000, 60_000)) 92 | os.environ["RANK"] = os.environ["SLURM_PROCID"] 93 | os.environ["WORLD_SIZE"] = os.environ["SLURM_NTASKS"] 94 | os.environ["LOCAL_RANK"] = os.environ["SLURM_LOCALID"] 95 | 
os.environ["LOCAL_WORLD_SIZE"] = str( 96 | int(os.environ["WORLD_SIZE"]) // int(os.environ["SLURM_JOB_NUM_NODES"]) 97 | ) 98 | elif "MASTER_ADDR" not in os.environ: 99 | os.environ["MASTER_ADDR"] = "127.0.0.1" 100 | os.environ["MASTER_PORT"] = str(_get_available_port()) 101 | os.environ["RANK"] = "0" 102 | os.environ["WORLD_SIZE"] = "1" 103 | os.environ["LOCAL_RANK"] = "0" 104 | os.environ["LOCAL_WORLD_SIZE"] = "1" 105 | torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) 106 | dist.init_process_group(backend="nccl") 107 | dist.barrier(device_ids=[int(os.environ["LOCAL_RANK"])]) 108 | -------------------------------------------------------------------------------- /utils/download.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import requests 4 | from tqdm import tqdm 5 | 6 | 7 | def download_pretrained_vae(overwrite=False): 8 | download_path = "pretrained_models/vae/kl16.ckpt" 9 | if not os.path.exists(download_path) or overwrite: 10 | headers = {"user-agent": "Wget/1.16 (linux-gnu)"} 11 | os.makedirs("pretrained_models/vae", exist_ok=True) 12 | r = requests.get( 13 | "https://www.dropbox.com/scl/fi/hhmuvaiacrarfg28qxhwz/kl16.ckpt?rlkey=l44xipsezc8atcffdp4q7mwmh&dl=0", 14 | stream=True, 15 | headers=headers, 16 | ) 17 | print("Downloading KL-16 VAE...") 18 | with open(download_path, "wb") as f: 19 | for chunk in tqdm(r.iter_content(chunk_size=1024 * 1024), unit="MB", total=254): 20 | if chunk: 21 | f.write(chunk) 22 | 23 | 24 | def download_pretrained_marb(overwrite=False): 25 | download_path = "pretrained_models/mar/mar_base/checkpoint-last.pth" 26 | if not os.path.exists(download_path) or overwrite: 27 | headers = {"user-agent": "Wget/1.16 (linux-gnu)"} 28 | os.makedirs("pretrained_models/mar/mar_base", exist_ok=True) 29 | r = requests.get( 30 | "https://www.dropbox.com/scl/fi/f6dpuyjb7fudzxcyhvrhk/checkpoint-last.pth?rlkey=a6i4bo71vhfo4anp33n9ukujb&dl=0", 31 | stream=True, 32 | headers=headers, 33 | ) 
34 | print("Downloading MAR-B...") 35 | with open(download_path, "wb") as f: 36 | for chunk in tqdm(r.iter_content(chunk_size=1024 * 1024), unit="MB", total=1587): 37 | if chunk: 38 | f.write(chunk) 39 | 40 | 41 | def download_pretrained_marl(overwrite=False): 42 | download_path = "pretrained_models/mar/mar_large/checkpoint-last.pth" 43 | if not os.path.exists(download_path) or overwrite: 44 | headers = {"user-agent": "Wget/1.16 (linux-gnu)"} 45 | os.makedirs("pretrained_models/mar/mar_large", exist_ok=True) 46 | r = requests.get( 47 | "https://www.dropbox.com/scl/fi/pxacc5b2mrt3ifw4cah6k/checkpoint-last.pth?rlkey=m48ovo6g7ivcbosrbdaz0ehqt&dl=0", 48 | stream=True, 49 | headers=headers, 50 | ) 51 | print("Downloading MAR-L...") 52 | with open(download_path, "wb") as f: 53 | for chunk in tqdm(r.iter_content(chunk_size=1024 * 1024), unit="MB", total=3650): 54 | if chunk: 55 | f.write(chunk) 56 | 57 | 58 | def download_pretrained_marh(overwrite=False): 59 | download_path = "pretrained_models/mar/mar_huge/checkpoint-last.pth" 60 | if not os.path.exists(download_path) or overwrite: 61 | headers = {"user-agent": "Wget/1.16 (linux-gnu)"} 62 | os.makedirs("pretrained_models/mar/mar_huge", exist_ok=True) 63 | r = requests.get( 64 | "https://www.dropbox.com/scl/fi/1qmfx6fpy3k7j9vcjjs3s/checkpoint-last.pth?rlkey=4lae281yzxb406atp32vzc83o&dl=0", 65 | stream=True, 66 | headers=headers, 67 | ) 68 | print("Downloading MAR-H...") 69 | with open(download_path, "wb") as f: 70 | for chunk in tqdm(r.iter_content(chunk_size=1024 * 1024), unit="MB", total=7191): 71 | if chunk: 72 | f.write(chunk) 73 | 74 | 75 | if __name__ == "__main__": 76 | download_pretrained_vae() 77 | download_pretrained_marb() 78 | download_pretrained_marl() 79 | download_pretrained_marh() 80 | -------------------------------------------------------------------------------- /utils/loader.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from 
functools import partial 4 | from typing import Any, Callable 5 | 6 | import numpy as np 7 | from PIL import Image 8 | from torchvision.datasets.folder import default_loader 9 | 10 | logger = logging.getLogger("DeTok") 11 | 12 | CONSTANTS = { 13 | # vavae latent statistics from https://huggingface.co/hustvl/vavae-imagenet256-f16d32-dinov2/blob/main/latents_stats.pt 14 | "vavae_mean": np.array([ 15 | 0.5984623, -0.49917176, 0.6440029, -0.0970839, -1.190963, -1.4331622, 16 | 0.46853292, 0.6259252, 0.63195026, -0.4896733, -0.74451625, 1.1595623, 17 | 0.8456217, 0.5008238, 0.22926894, 0.47535565, -0.43787342, 0.8316961, 18 | -0.0750857, 0.30632293, 0.46645293, -0.09140775, -0.82710165, 0.07807512, 19 | 1.4150785, 1.3792385, 0.2695843, -0.7573224, 0.28129938, -0.30919993, 20 | 0.07785388, 0.34966648, 21 | ]), 22 | "vavae_std": np.array([ 23 | 3.846138, 4.2699146, 3.5768437, 3.5911105, 3.6230576, 3.481018, 24 | 3.3074617, 3.5092657, 3.5540583, 3.6067245, 3.70579, 3.6314075, 25 | 3.6295316, 3.620502, 3.2590282, 3.186753, 3.8258142, 3.599939, 26 | 3.2966352, 3.226129, 3.2191944, 3.1054573, 3.580496, 4.356914, 27 | 3.308541, 3.2075875, 4.515047, 3.4869924, 3.0415804, 3.4868848, 28 | 4.4310327, 4.0881157, 29 | ]), 30 | } 31 | 32 | 33 | def center_crop_arr(pil_image: Image.Image, image_size: int) -> Image.Image: 34 | """center cropping implementation from adm. 
35 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 36 | """ 37 | while min(*pil_image.size) >= 2 * image_size: 38 | pil_image = pil_image.resize( 39 | tuple(x // 2 for x in pil_image.size), 40 | resample=Image.Resampling.BOX 41 | ) 42 | 43 | scale = image_size / min(*pil_image.size) 44 | pil_image = pil_image.resize( 45 | tuple(round(x * scale) for x in pil_image.size), 46 | resample=Image.Resampling.BICUBIC 47 | ) 48 | 49 | arr = np.array(pil_image) 50 | crop_y = (arr.shape[0] - image_size) // 2 51 | crop_x = (arr.shape[1] - image_size) // 2 52 | return Image.fromarray(arr[crop_y:crop_y + image_size, crop_x:crop_x + image_size]) 53 | 54 | 55 | def default_np_loader(path: str) -> np.ndarray[Any, np.dtype[Any]]: 56 | return np.load(path, allow_pickle=True) 57 | 58 | 59 | class ListDataset: 60 | def __init__( 61 | self, 62 | data_root: str, 63 | data_list: str, 64 | transform: Callable[[Any], Any] | None = None, 65 | loader_name: str = "npz_loader", 66 | return_path: bool = False, 67 | return_label: bool = True, 68 | return_index: bool = False, 69 | should_flip: bool = True, 70 | class_of_interest: list[int] | None = None, 71 | ): 72 | self.data_root = data_root 73 | self.transform = transform 74 | self.return_path = return_path 75 | self.return_label = return_label 76 | self.return_index = return_index 77 | self.should_flip = should_flip 78 | self.class_of_interest = class_of_interest 79 | 80 | # loader function mapping 81 | loader_functions = { 82 | "img_loader": default_loader, 83 | "npz_loader": partial(np.load, allow_pickle=True), 84 | } 85 | 86 | if loader_name not in loader_functions: 87 | raise ValueError(f"Loader '{loader_name}' not supported") 88 | 89 | self.loader = loader_functions[loader_name] 90 | self.load_vae_latents = loader_name == "npz_loader" 91 | self.samples = self._load_samples(data_list, loader_name) 92 | self.targets = [label for _, label in self.samples] 93 | 
94 | def _load_samples(self, data_list: str, loader_name: str) -> list[tuple[str, int | None]]: 95 | samples = [] 96 | with open(data_list, "r") as f: 97 | for line in f: 98 | splits = line.strip().split(" ") 99 | if len(splits) == 2: 100 | file_path, label = splits 101 | label = int(label) 102 | else: 103 | file_path = line.strip() 104 | label = None 105 | 106 | if self.class_of_interest and label not in self.class_of_interest: 107 | continue 108 | 109 | # adjust file extensions based on loader 110 | if loader_name == "npz_loader": 111 | file_path = file_path.replace(".JPEG", ".JPEG.npz") 112 | 113 | samples.append((file_path, label)) 114 | return samples 115 | 116 | def __getitem__(self, index: int) -> dict[str, Any]: 117 | return self._get_item_with_retry(index, 0) 118 | 119 | def _get_item_with_retry(self, index: int, retry_count: int) -> dict[str, Any]: 120 | if retry_count >= 100: 121 | raise RuntimeError(f"Failed to load data after 100 retries, last index: {index}") 122 | 123 | img_pth, label = self.samples[index] 124 | img_path_full = os.path.join(self.data_root, img_pth) 125 | should_flip = np.random.rand() < 0.5 if self.should_flip else False 126 | to_return = {} 127 | 128 | try: 129 | img = self.loader(img_path_full) 130 | if self.load_vae_latents: 131 | img_data = img # type: ignore 132 | img = img_data["moments_flip"] if should_flip else img_data["moments"] 133 | to_return = {"token": img} 134 | except Exception as e: 135 | logger.error(f"Error loading '{img_pth}': {e}") 136 | return self._get_item_with_retry((index + 1) % len(self.samples), retry_count + 1) 137 | 138 | if self.transform is not None: 139 | if "token" in to_return: 140 | # load original image when we have vae latents 141 | img_path_relative = img_path_full.split("/")[3:] 142 | img_path_relative = os.path.join(*img_path_relative) 143 | img_path_relative = img_path_relative.replace(".npz", "") 144 | img_path_full = os.path.join(self.data_root, img_path_relative) 145 | img = 
default_loader(img_path_full) 146 | 147 | img = self.transform(img) 148 | if should_flip: 149 | img = img.flip(dims=[2]) 150 | 151 | if len(to_return) > 0: 152 | to_return["img"] = img 153 | else: 154 | to_return = {"img": img} 155 | 156 | if self.return_index: 157 | to_return["index"] = index 158 | if self.return_label: 159 | to_return["label"] = label 160 | if self.return_path: 161 | to_return["img_pth"] = img_pth 162 | 163 | return to_return 164 | 165 | def __len__(self) -> int: 166 | return len(self.samples) 167 | -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import datetime 3 | import json 4 | import logging 5 | import os 6 | import sys 7 | import time 8 | from collections import defaultdict, deque 9 | from typing import Any 10 | 11 | import torch 12 | import torch.distributed 13 | import wandb 14 | from rich.logging import RichHandler 15 | from typing_extensions import override 16 | 17 | from .distributed import get_global_rank, is_enabled, is_main_process 18 | 19 | logger = logging.getLogger("DeTok") 20 | 21 | 22 | def move_to_device(obj: Any, device: torch.device) -> Any: 23 | """recursively moves tensors in obj to the specified device.""" 24 | if isinstance(obj, torch.Tensor): 25 | return obj.to(device, non_blocking=True) 26 | elif isinstance(obj, dict): 27 | return {key: move_to_device(value, device) for key, value in obj.items()} 28 | elif isinstance(obj, (list, tuple)): 29 | return type(obj)(move_to_device(o, device) for o in obj) 30 | else: 31 | return obj 32 | 33 | 34 | class SmoothedValue: 35 | """track a series of values and provide access to smoothed values over a window or the global series average.""" 36 | 37 | def __init__(self, window_size: int = 20, fmt: str | None = None): 38 | if fmt is None: 39 | fmt = "{median:.4f} ({global_avg:.4f})" 40 | self.deque = deque(maxlen=window_size) 41 | 
self.total = 0.0 42 | self.count = 0 43 | self.fmt = fmt 44 | 45 | def update(self, value: float, num: int = 1) -> None: 46 | self.deque.append(value) 47 | self.count += num 48 | self.total += value * num 49 | 50 | def synchronize_between_processes(self) -> None: 51 | """distributed synchronization of the metric. warning: does not synchronize the deque!""" 52 | if not is_enabled(): 53 | return 54 | t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") 55 | torch.distributed.barrier() 56 | torch.distributed.all_reduce(t) 57 | t = t.tolist() 58 | self.count = int(t[0]) 59 | self.total = t[1] 60 | 61 | @property 62 | def median(self) -> float: 63 | d = torch.tensor(list(self.deque)) 64 | return d.median().item() 65 | 66 | @property 67 | def avg(self) -> float: 68 | d = torch.tensor(list(self.deque), dtype=torch.float32) 69 | return d.mean().item() 70 | 71 | @property 72 | def global_avg(self) -> float: 73 | return self.total / self.count 74 | 75 | @property 76 | def max(self) -> float: 77 | return max(self.deque) 78 | 79 | @property 80 | def value(self) -> float: 81 | return self.deque[-1] 82 | 83 | @override 84 | def __str__(self) -> str: 85 | return self.fmt.format( 86 | median=self.median, 87 | avg=self.avg, 88 | global_avg=self.global_avg, 89 | max=self.max, 90 | value=self.value, 91 | ) 92 | 93 | 94 | class MetricLogger: 95 | def __init__(self, delimiter: str = "\t", output_file: str | None = None, prefetch: bool = False): 96 | self.meters = defaultdict(SmoothedValue) 97 | self.delimiter = delimiter 98 | self.output_file = output_file 99 | self.prefetch = prefetch 100 | logger.info(f"MetricLogger: output_file={output_file}, prefetch={prefetch}") 101 | 102 | def update(self, **kwargs) -> None: 103 | for k, v in kwargs.items(): 104 | if v is None: 105 | continue 106 | if isinstance(v, torch.Tensor): 107 | v = v.item() 108 | assert isinstance(v, (float, int)) 109 | self.meters[k].update(v) 110 | 111 | def __getattr__(self, attr: str): 112 | 
if attr in self.meters: 113 | return self.meters[attr] 114 | if attr in self.__dict__: 115 | return self.__dict__[attr] 116 | raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'") 117 | 118 | @override 119 | def __str__(self) -> str: 120 | loss_str = [] 121 | for name, meter in self.meters.items(): 122 | loss_str.append(f"{name}: {str(meter)}") 123 | return self.delimiter.join(loss_str) 124 | 125 | def synchronize_between_processes(self) -> None: 126 | for meter in self.meters.values(): 127 | meter.synchronize_between_processes() 128 | 129 | def add_meter(self, name: str, meter: SmoothedValue) -> None: 130 | self.meters[name] = meter 131 | 132 | def dump_in_output_file(self, iteration: int, iter_time: float, data_time: float) -> None: 133 | if self.output_file is None or not is_main_process(): 134 | return 135 | dict_to_dump = dict( 136 | iteration=iteration, 137 | iter_time=iter_time, 138 | data_time=data_time, 139 | ) 140 | dict_to_dump.update({k: v.median for k, v in self.meters.items()}) 141 | with open(self.output_file, "a") as f: 142 | f.write(json.dumps(dict_to_dump) + "\n") 143 | 144 | def log_every( 145 | self, 146 | iterable, 147 | print_freq: int, 148 | header: str | None = None, 149 | n_iterations: int | None = None, 150 | start_iteration: int = 0, 151 | ): 152 | i = start_iteration 153 | if not header: 154 | header = "" 155 | start_time = time.time() 156 | end = time.time() 157 | iter_time = SmoothedValue(fmt="{avg:.4f}") 158 | data_time = SmoothedValue(fmt="{avg:.4f}") 159 | 160 | if n_iterations is None: 161 | try: 162 | n_iterations = len(iterable) 163 | except TypeError: 164 | # iterable doesn't have len, use a default or require user to provide 165 | raise ValueError("n_iterations must be provided for iterables without __len__") 166 | 167 | space_fmt = ":" + str(len(str(n_iterations))) + "d" 168 | 169 | log_list = [ 170 | header, 171 | "[{0" + space_fmt + "}/{1}]", 172 | "eta: {eta}", 173 | "elapsed: 
{elapsed_time_str}", 174 | "{meters}", 175 | "time: {time}", 176 | "data: {data}", 177 | ] 178 | if torch.cuda.is_available(): 179 | log_list += ["max mem: {memory:.0f}"] 180 | 181 | log_msg = self.delimiter.join(log_list) 182 | MB = 1024.0 * 1024.0 183 | for obj in iterable: 184 | if self.prefetch: 185 | obj = move_to_device(obj, torch.device("cuda")) 186 | data_time.update(time.time() - end) 187 | yield obj 188 | iter_time.update(time.time() - end) 189 | if i % print_freq == 0 or i == n_iterations - 1: 190 | self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg) 191 | eta_seconds = iter_time.global_avg * (n_iterations - i) 192 | eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) 193 | elapsed_time = time.time() - start_time 194 | elapsed_time_str = str(datetime.timedelta(seconds=int(elapsed_time))) 195 | 196 | if torch.cuda.is_available(): 197 | logger.info( 198 | log_msg.format( 199 | i, 200 | n_iterations, 201 | eta=eta_string, 202 | elapsed_time_str=elapsed_time_str, 203 | meters=str(self), 204 | time=str(iter_time), 205 | data=str(data_time), 206 | memory=torch.cuda.max_memory_allocated() / MB, 207 | ) 208 | ) 209 | else: 210 | logger.info( 211 | log_msg.format( 212 | i, 213 | n_iterations, 214 | eta=eta_string, 215 | meters=str(self), 216 | time=str(iter_time), 217 | data=str(data_time), 218 | ) 219 | ) 220 | i += 1 221 | end = time.time() 222 | if i >= n_iterations: 223 | break 224 | total_time = time.time() - start_time 225 | total_time_str = str(datetime.timedelta(seconds=int(total_time))) 226 | logger.info(f"{header} Total time: {total_time_str} ({total_time / n_iterations:.6f} s / it)") 227 | 228 | 229 | class WandbLogger: 230 | def __init__( 231 | self, 232 | config, 233 | entity: str, 234 | project: str, 235 | name: str, 236 | log_dir: str, 237 | run_id: str | None = None, 238 | ): 239 | self.run = wandb.init( 240 | config=config, 241 | entity=entity, 242 | project=project, 243 | name=name, 244 | 
dir=log_dir, 245 | resume="allow", 246 | id=run_id, 247 | ) 248 | self.run_id = self.run.id 249 | self.step = 0 250 | self.run.log_code(".") 251 | 252 | def update(self, metrics, step: int | None = None) -> None: 253 | log_dict = { 254 | k: v.item() if isinstance(v, torch.Tensor) else v for k, v in metrics.items() if v is not None 255 | } 256 | try: 257 | wandb.log(log_dict, step=step or self.step) 258 | except Exception as e: 259 | logger.error(f"wandb logging failed: {e}") 260 | if step is not None: 261 | self.step = step 262 | 263 | def finish(self) -> None: 264 | try: 265 | wandb.finish() 266 | except Exception as e: 267 | logger.error(f"wandb failed to finish: {e}") 268 | 269 | 270 | def setup_logging(output: str, name: str = "DeTok", rank0_log_only: bool = True) -> None: 271 | """setup logging.""" 272 | logging.captureWarnings(True) 273 | 274 | logger = logging.getLogger(name) 275 | logger.setLevel(logging.INFO) 276 | logger.propagate = False 277 | 278 | # google glog format: [IWEF]yyyymmdd hh:mm:ss logger filename:line] msg 279 | fmt_prefix = "%(levelname).1s%(asctime)s %(name)s %(filename)s:%(lineno)s] " 280 | fmt_message = "%(message)s" 281 | fmt = fmt_prefix + fmt_message 282 | datefmt = "%Y%m%d %H:%M:%S" 283 | formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) 284 | 285 | # stdout logging for main worker only 286 | if is_main_process(): 287 | if sys.stdout.isatty(): 288 | handler = RichHandler(markup=True, show_time=False, show_level=False) 289 | else: 290 | handler = logging.StreamHandler(stream=sys.stdout) 291 | handler.setLevel(logging.DEBUG) 292 | handler.setFormatter(formatter) 293 | logger.addHandler(handler) 294 | 295 | # file logging 296 | if output: 297 | if os.path.splitext(output)[-1] in (".txt", ".log"): 298 | # if output is a file, use it directly 299 | filename = output 300 | else: 301 | # if output is a directory, use the directory/log.txt 302 | filename = os.path.join(output, "log.txt") 303 | 304 | # if not rank 0 but 
rank0_log_only=false, append the rank id 305 | if not is_main_process() and not rank0_log_only: 306 | global_rank = get_global_rank() 307 | if filename.endswith(".txt"): 308 | filename = filename.replace(".txt", f".rank{global_rank}.txt") 309 | else: 310 | filename = f"{filename}.rank{global_rank}" 311 | 312 | os.makedirs(os.path.dirname(filename), exist_ok=True) 313 | 314 | handler = logging.StreamHandler(open(filename, "a")) 315 | handler.setLevel(logging.DEBUG) 316 | handler.setFormatter(formatter) 317 | logger.addHandler(handler) 318 | 319 | 320 | def setup_wandb(args: argparse.Namespace, entity: str, project: str, name: str, log_dir: str) -> WandbLogger: 321 | """Setup Weights & Biases logging with resume capability.""" 322 | run_id_path = os.path.join(log_dir, "wandb_run_id.txt") 323 | run_id = None 324 | 325 | # resume from wandb run id if it exists 326 | if os.path.exists(run_id_path): 327 | with open(run_id_path, "r") as f: 328 | run_id = f.readlines()[-1].strip() 329 | 330 | wandb_logger = WandbLogger( 331 | config=args, 332 | entity=entity, 333 | project=project, 334 | name=name, 335 | log_dir=log_dir, 336 | run_id=run_id, 337 | ) 338 | 339 | # if no run id, save the run id to the log directory 340 | if run_id is None: 341 | with open(run_id_path, "a") as f: 342 | f.write(wandb_logger.run.id + "\n") 343 | 344 | return wandb_logger 345 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import math 2 | import random 3 | import argparse 4 | import os 5 | import copy 6 | import datetime 7 | from glob import glob 8 | import logging 9 | 10 | import numpy as np 11 | import torch 12 | import torch.nn.utils 13 | from torch import inf 14 | 15 | import utils.distributed as dist 16 | 17 | logger = logging.getLogger("DeTok") 18 | 19 | 20 | def fix_random_seeds(seed=31): 21 | torch.manual_seed(seed) 22 | torch.cuda.manual_seed_all(seed) 
23 | np.random.seed(seed) 24 | random.seed(seed) 25 | 26 | 27 | def adjust_learning_rate(optimizer, epoch, args): 28 | """Decay the learning rate with half-cycle cosine after warmup""" 29 | if epoch < args.warmup_epochs: 30 | lr = args.lr * epoch / args.warmup_epochs 31 | else: 32 | if args.lr_sched == "constant": 33 | lr = args.lr 34 | elif args.lr_sched == "cosine": 35 | progress = (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs) 36 | lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress)) 37 | else: 38 | raise NotImplementedError 39 | for param_group in optimizer.param_groups: 40 | if "lr_scale" in param_group: 41 | param_group["lr"] = lr * param_group["lr_scale"] 42 | else: 43 | param_group["lr"] = lr 44 | return lr 45 | 46 | 47 | def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor: 48 | if isinstance(parameters, torch.Tensor): 49 | parameters = [parameters] 50 | parameters = [p for p in parameters if p.grad is not None] 51 | norm_type = float(norm_type) 52 | if len(parameters) == 0: 53 | return torch.tensor(0.0) 54 | device = parameters[0].grad.device 55 | if norm_type == inf: 56 | total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters) 57 | else: 58 | total_norm = torch.norm( 59 | torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), 60 | norm_type, 61 | ) 62 | return total_norm 63 | 64 | 65 | class NativeScalerWithGradNormCount: 66 | state_dict_key = "amp_scaler" 67 | 68 | def __init__(self, enabled: bool = True): 69 | self._scaler = torch.GradScaler(device="cuda", enabled=enabled) 70 | 71 | def __call__( 72 | self, 73 | loss, 74 | optimizer, 75 | clip_grad=None, 76 | parameters=None, 77 | create_graph=False, 78 | update_grad=True, 79 | ): 80 | self._scaler.scale(loss).backward(create_graph=create_graph) 81 | if update_grad: 82 | if clip_grad is not None and clip_grad > 0.0: 83 | assert parameters is not None 84 | 
self._scaler.unscale_(optimizer) 85 | norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad) 86 | else: 87 | self._scaler.unscale_(optimizer) 88 | norm = get_grad_norm_(parameters) 89 | self._scaler.step(optimizer) 90 | self._scaler.update() 91 | else: 92 | norm = None 93 | return norm 94 | 95 | def state_dict(self): 96 | return self._scaler.state_dict() 97 | 98 | def load_state_dict(self, state_dict): 99 | self._scaler.load_state_dict(state_dict) 100 | 101 | 102 | def ckpt_resume( 103 | args: argparse.Namespace, 104 | model: torch.nn.Module, 105 | optimizer: torch.optim.Optimizer | None = None, 106 | loss_scaler: NativeScalerWithGradNormCount | None = None, 107 | model_ema: torch.nn.Module | None = None, 108 | loss_module: torch.nn.Module | None = None, 109 | discriminator_optimizer: torch.optim.Optimizer | None = None, 110 | discriminator_loss_scaler: NativeScalerWithGradNormCount | None = None, 111 | ): 112 | if args.resume_from or args.auto_resume: 113 | if args.resume_from is None: 114 | # find the latest checkpoint 115 | checkpoints = [ckpt for ckpt in glob(f"{args.ckpt_dir}/*.pth") if "latest" not in ckpt] 116 | checkpoints = sorted(checkpoints, key=os.path.getmtime) 117 | if len(checkpoints) > 0: 118 | args.resume_from = checkpoints[-1] 119 | 120 | if args.resume_from and os.path.exists(args.resume_from): 121 | # load the checkpoint 122 | logger.info(f"[Model-resume] Resuming from: {args.resume_from}") 123 | checkpoint = torch.load(args.resume_from, map_location="cpu", weights_only=False) 124 | msg = model.load_state_dict(checkpoint["model"]) 125 | logger.info(f"[Model-resume] Loaded model: {msg}") 126 | 127 | if "model_ema" in checkpoint: 128 | # load the EMA state dict if it exists 129 | ema_state_dict = checkpoint["model_ema"] 130 | logger.info(f"[Model-resume] Loaded EMA") 131 | else: 132 | # if no EMA state dict, use the model state dict to initialize the EMA state dict 133 | model_state_dict = model.state_dict() 134 | param_keys = [k for k, _ in 
model.named_parameters()] 135 | ema_state_dict = {k: model_state_dict[k] for k in param_keys} 136 | logger.info(f"[Model-resume] Loaded EMA with model state dict") 137 | 138 | # load the EMA state dict if it exists 139 | if model_ema is not None: 140 | model_ema.load_state_dict(ema_state_dict) 141 | model_ema.to("cuda") # move the EMA model to the GPU 142 | 143 | # load the optimizer state dict if it exists 144 | if "optimizer" in checkpoint and "epoch" in checkpoint and optimizer is not None: 145 | optimizer.load_state_dict(checkpoint["optimizer"]) 146 | args.start_epoch = checkpoint["epoch"] + 1 147 | # load the loss scaler state dict if it exists 148 | if "loss_scaler" in checkpoint and loss_scaler is not None: 149 | loss_scaler.load_state_dict(checkpoint["loss_scaler"]) 150 | 151 | # load the last elapsed time if it exists 152 | if "last_elapsed_time" in checkpoint: 153 | args.last_elapsed_time = float(checkpoint["last_elapsed_time"]) 154 | elapsed_time_str = str(datetime.timedelta(seconds=int(args.last_elapsed_time))) 155 | logger.info(f"Loaded elapsed_time: {elapsed_time_str}") 156 | 157 | # load the loss module state dict if it exists 158 | if "loss_module" in checkpoint and loss_module is not None: 159 | msg = loss_module.load_state_dict(checkpoint["loss_module"]) 160 | logger.info(f"[Model-resume] Loaded loss_module: {msg}") 161 | 162 | if "discriminator_optimizer" in checkpoint and discriminator_optimizer is not None: 163 | msg = discriminator_optimizer.load_state_dict(checkpoint["discriminator_optimizer"]) 164 | logger.info(f"[Model-resume] Loaded discriminator_optimizer: {msg}") 165 | 166 | if "discriminator_loss_scaler" in checkpoint and discriminator_loss_scaler is not None: 167 | msg = discriminator_loss_scaler.load_state_dict(checkpoint["discriminator_loss_scaler"]) 168 | logger.info(f"[Model-resume] Loaded discriminator_loss_scaler: {msg}") 169 | 170 | # delete the checkpoint to save memory 171 | del checkpoint 172 | else: 173 | 
logger.info(f"[Model-resume] Could not find checkpoint at {args.resume_from}.") 174 | else: 175 | logger.info(f"[Model-resume] Could not find checkpoint at {args.resume_from}.") 176 | 177 | if args.load_from and not args.resume_from: 178 | # if no checkpoint is provided, load the checkpoint from the load_from path instead 179 | if os.path.exists(args.load_from): 180 | logger.info(f"[Model-load] Loading checkpoint from: {args.load_from}") 181 | checkpoint = torch.load(args.load_from, map_location="cpu", weights_only=False) 182 | # load the model state dict if it exists 183 | state_dict = checkpoint["model"] if "model" in checkpoint else checkpoint 184 | msg = model.load_state_dict(state_dict, strict=False) 185 | # assert unexpected keys can only start with "loss." 186 | for key in msg.unexpected_keys: 187 | assert key.startswith("loss."), f"unexpected key {key} doesn't start with 'loss.'" 188 | logger.info(f"[Model-load] Loaded model: {msg}") 189 | if "model_ema" in checkpoint: 190 | logger.info(f"[Model-load] Loaded EMA") 191 | ema_state_dict = checkpoint["model_ema"] 192 | else: 193 | logger.info(f"[Model-load] Loaded EMA with model state dict") 194 | ema_state_dict = copy.deepcopy(model.state_dict()) 195 | if model_ema is not None: 196 | model_ema.load_state_dict(ema_state_dict) 197 | model_ema.to(device="cuda") # move the EMA model to the GPU 198 | del checkpoint # delete the checkpoint to save memory 199 | else: 200 | raise FileNotFoundError(f"Could not find checkpoint at {args.load_from}") 201 | 202 | 203 | def cleanup_checkpoints(ckpt_dir: str, keep_num: int = 5, milestone_interval: int = 5): 204 | """ 205 | Clean up older checkpoint files in `ckpt_dir` while keeping the latest `keep_num` checkpoints by epoch number. 206 | 207 | Parameters 208 | ---------- 209 | ckpt_dir : str 210 | The directory where checkpoint .pth files are stored. 211 | keep_num : int, optional 212 | The number of most recent checkpoints to keep (default=5). 
213 | milestone_interval : int, optional 214 | The interval used to decide if a checkpoint is a "milestone." 215 | If (epoch_num + 1) % milestone_interval == 0, it is kept (default=50). 216 | """ 217 | ckpts = glob(os.path.join(ckpt_dir, "*.pth")) 218 | ckpts = [ckpt for ckpt in ckpts if "latest" not in ckpt and "best" not in ckpt] 219 | 220 | def get_ckpt_num(path): 221 | """Extract the epoch number from a checkpoint filename.""" 222 | filename = os.path.basename(path) 223 | # expecting something like 'epoch_049.pth' 224 | # we'll parse out the part after the last underscore and before '.pth' 225 | try: 226 | return int(filename.rsplit("_", 1)[-1].split(".")[0]) 227 | except ValueError: 228 | return None 229 | 230 | # sort checkpoints by epoch number 231 | ckpts.sort(key=lambda x: (get_ckpt_num(x) is None, get_ckpt_num(x))) 232 | 233 | # filter out any that failed to parse an integer epoch (get_ckpt_num == None) 234 | ckpts = [ckpt for ckpt in ckpts if get_ckpt_num(ckpt) is not None] 235 | 236 | if not ckpts: 237 | # if no checkpoints remain, nothing to do 238 | return 239 | 240 | # determine which checkpoints to keep: 241 | # 1. the newest `keep_num` by epoch number. 242 | # 2. any milestone checkpoints. 
243 | # (epoch_num + 1) % milestone_interval == 0 244 | newest_keep = set(ckpts[-keep_num:]) # handle if keep_num > number of ckpts 245 | milestone_keep = set(ckpt for ckpt in ckpts if ((get_ckpt_num(ckpt) + 1) % milestone_interval == 0)) 246 | 247 | # union of both sets 248 | keep_set = newest_keep.union(milestone_keep) 249 | 250 | # remove anything not in keep_set 251 | for ckpt in ckpts: 252 | if ckpt not in keep_set: 253 | os.remove(ckpt) 254 | logger.info(f"Removed checkpoint: {ckpt}") 255 | 256 | # recreate the 'latest.pth' symlink to the newest checkpoint 257 | if keep_set: 258 | # we need the absolute newest based on epoch number 259 | # sort again from keep_set only 260 | remaining_ckpts_sorted = sorted(keep_set, key=lambda x: (get_ckpt_num(x) is None, get_ckpt_num(x))) 261 | newest_ckpt = os.path.abspath(remaining_ckpts_sorted[-1]) 262 | latest_symlink = os.path.join(ckpt_dir, "latest.pth") 263 | 264 | # remove the old symlink if it exists 265 | try: 266 | os.remove(latest_symlink) 267 | logger.info(f"Removed old symlink: {latest_symlink}") 268 | except FileNotFoundError: 269 | pass 270 | 271 | # create a new symlink 272 | os.symlink(newest_ckpt, latest_symlink) 273 | logger.info(f"Created symlink: {latest_symlink} -> {newest_ckpt}") 274 | 275 | 276 | def save_checkpoint( 277 | args, 278 | epoch, 279 | model, 280 | optimizer, 281 | loss_scaler, 282 | model_ema, 283 | elapsed_time=0.0, 284 | loss_module=None, 285 | discriminator_optimizer=None, 286 | discriminator_loss_scaler=None, 287 | ): 288 | if not dist.is_main_process(): 289 | return 290 | checkpoint = { 291 | "model": model.state_dict(), 292 | "model_ema": model_ema.state_dict() if model_ema is not None else None, 293 | "optimizer": optimizer.state_dict(), 294 | "loss_scaler": loss_scaler.state_dict(), 295 | "epoch": epoch, 296 | "last_elapsed_time": elapsed_time, 297 | } 298 | if loss_module is not None and isinstance(loss_module, torch.nn.Module): 299 | checkpoint["loss_module"] = 
loss_module.state_dict() 300 | if discriminator_optimizer is not None: 301 | checkpoint["discriminator_optimizer"] = discriminator_optimizer.state_dict() 302 | if discriminator_loss_scaler is not None: 303 | checkpoint["discriminator_loss_scaler"] = discriminator_loss_scaler.state_dict() 304 | checkpoint_path = os.path.join(args.ckpt_dir, f"epoch_{epoch:04d}.pth") 305 | torch.save(checkpoint, checkpoint_path) 306 | logger.info(f"Saved checkpoint: {checkpoint_path}") 307 | cleanup_checkpoints(args.ckpt_dir, args.keep_n_ckpts, args.milestone_interval) 308 | --------------------------------------------------------------------------------