├── .gitignore ├── GETTING_STARTED.md ├── LICENSE ├── README.md ├── app.py ├── assets └── teaser.jpg ├── autoregressive ├── models │ ├── generate.py │ ├── gpt.py │ └── gpt_hf.py ├── sample │ ├── sample_c2i.py │ ├── sample_c2i_ddp.py │ ├── sample_t2i.py │ └── sample_t2i_ddp.py ├── serve │ ├── README.md │ ├── fake_json │ │ ├── GPT-3B.json │ │ ├── GPT-B.json │ │ ├── GPT-L.json │ │ ├── GPT-XL.json │ │ └── GPT-XXL.json │ ├── gpt_model.py │ ├── gpu_executor.py │ ├── llm.py │ ├── llm_engine.py │ ├── model_runner.py │ ├── sample_c2i.py │ ├── sampler.py │ └── worker.py └── train │ ├── extract_codes_c2i.py │ ├── extract_codes_t2i.py │ ├── train_c2i.py │ ├── train_c2i_fsdp.py │ └── train_t2i.py ├── dataset ├── augmentation.py ├── build.py ├── coco.py ├── imagenet.py ├── openimage.py ├── pexels.py └── t2i.py ├── evaluations ├── c2i │ ├── README.md │ └── evaluator.py └── t2i │ ├── PartiPrompts.tsv │ ├── README.md │ ├── coco_captions.csv │ └── evaluation.py ├── language ├── README.md ├── extract_t5_feature.py └── t5.py ├── requirements.txt ├── scripts ├── autoregressive │ ├── extract_codes_c2i.sh │ ├── sample_c2i.sh │ ├── sample_t2i_coco.sh │ ├── sample_t2i_parti.sh │ ├── train_c2i.sh │ ├── train_c2i_fsdp.sh │ ├── train_t2i_stage1.sh │ └── train_t2i_stage2.sh ├── language │ ├── extract_flan_t5_feat_laion_coco_stage1.sh │ ├── extract_flan_t5_feat_stage2.sh │ └── extract_flan_t5_feat_trunc_stage2.sh └── tokenizer │ ├── reconstruction_consistency_decoder.sh │ ├── reconstruction_vae.sh │ ├── reconstruction_vq.sh │ ├── reconstruction_vqgan.sh │ ├── train_vq.sh │ ├── train_vq_finetune.sh │ ├── train_vq_finetune_continue.sh │ └── val.sh ├── tokenizer ├── consistencydecoder │ ├── README.md │ ├── cd_demo.py │ └── reconstruction_cd_ddp.py ├── tokenizer_image │ ├── cache │ │ └── vgg.pth │ ├── discriminator.py │ ├── discriminator_patchgan.py │ ├── discriminator_stylegan.py │ ├── lpips.py │ ├── reconstruction_vq_ddp.py │ ├── vq_demo.py │ ├── vq_loss.py │ ├── vq_model.py │ ├── vq_model_hf.py │ └── vq_train.py ├── vae │ ├── README.md │ ├── reconstruction_vae_ddp.py │ └── sd_vae_demo.py ├── validation │ └── val_ddp.py └── vqgan │ ├── README.md │ ├── configs │ ├── vqgan_imagenet_f16_1024.yaml │ ├── vqgan_imagenet_f16_16384.yaml │ ├── vqgan_openimage_f8_16384.yaml │ └── vqgan_openimage_f8_256.yaml │ ├── layer.py │ ├── model.py │ ├── quantize.py │ ├── reconstruction_vqgan_ddp.py │ └── taming_vqgan_demo.py ├── tools ├── check_image_codes.py ├── convert_pytorch_lightning_to_torch.py ├── draw_figure.py ├── imagenet_en_cn.py ├── openimage_json.py ├── push_gpt_to_hf.py └── push_vae_to_hf.py └── utils ├── data.py ├── deepspeed.py ├── distributed.py ├── drop_path.py ├── ema.py ├── logger.py └── video.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
162 | #.idea/ 163 | -------------------------------------------------------------------------------- /GETTING_STARTED.md: -------------------------------------------------------------------------------- 1 | ## Getting Started 2 | ### Requirements 3 | - Linux with Python ≥ 3.7 4 | - PyTorch ≥ 2.1 5 | - A100 GPUs 6 | 7 | ### Train VQVAE models 8 | ``` 9 | bash scripts/tokenizer/train_vq.sh --cloud-save-path /path/to/cloud_disk --data-path /path/to/imagenet/train --image-size 256 --vq-model VQ-16 10 | ``` 11 | 12 | 13 | ### Pre-extract discrete codes of training images 14 | ``` 15 | bash scripts/autoregressive/extract_codes_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --data-path /path/to/imagenet/train --code-path /path/to/imagenet_code_c2i_flip_ten_crop --ten-crop --crop-range 1.1 --image-size 384 16 | ``` 17 | and/or 18 | ``` 19 | bash scripts/autoregressive/extract_codes_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --data-path /path/to/imagenet/train --code-path /path/to/imagenet_code_c2i_flip_ten_crop_105 --ten-crop --crop-range 1.05 --image-size 384 20 | ``` 21 | 22 | 23 | ### Train AR models with DDP 24 | Before running, please change `nnodes, nproc_per_node, node_rank, master_addr, master_port` in `.sh` 25 | ``` 26 | bash scripts/autoregressive/train_c2i.sh --cloud-save-path /path/to/cloud_disk --code-path /path/to/imagenet_code_c2i_flip_ten_crop --image-size 384 --gpt-model GPT-B 27 | 28 | bash scripts/autoregressive/train_c2i.sh --cloud-save-path /path/to/cloud_disk --code-path /path/to/imagenet_code_c2i_flip_ten_crop --image-size 384 --gpt-model GPT-L 29 | 30 | bash scripts/autoregressive/train_c2i.sh --cloud-save-path /path/to/cloud_disk --code-path /path/to/imagenet_code_c2i_flip_ten_crop --image-size 384 --gpt-model GPT-XL 31 | ``` 32 | 33 | 34 | ### Train AR models with FSDP 35 | Before running, please change `nnodes, nproc_per_node, node_rank, master_addr, master_port` in `.sh` 36 | ``` 37 | bash scripts/autoregressive/train_c2i_fsdp.sh --cloud-save-path /path/to/cloud_disk --code-path /path/to/imagenet_code_c2i_flip_ten_crop --image-size 384 --gpt-model GPT-XXL 38 | 39 | bash scripts/autoregressive/train_c2i_fsdp.sh --cloud-save-path /path/to/cloud_disk --code-path /path/to/imagenet_code_c2i_flip_ten_crop --image-size 384 --gpt-model GPT-3B 40 | ``` 41 | 42 | 43 | ### Sampling 44 | ``` 45 | bash scripts/autoregressive/sample_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_B.pt --gpt-model GPT-B --image-size 384 --image-size-eval 256 --cfg-scale 2.0 46 | 47 | bash scripts/autoregressive/sample_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_L.pt --gpt-model GPT-L --image-size 384 --image-size-eval 256 --cfg-scale 2.0 48 | 49 | bash scripts/autoregressive/sample_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XL.pt --gpt-model GPT-XL --image-size 384 --image-size-eval 256 --cfg-scale 1.75 50 | 51 | bash scripts/autoregressive/sample_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL.pt --gpt-model GPT-XXL --from-fsdp --image-size 384 --image-size-eval 256 --cfg-scale 1.75 52 | 53 | bash scripts/autoregressive/sample_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_3B.pt --gpt-model GPT-3B --from-fsdp --image-size 384 --image-size-eval 256 --cfg-scale 1.65 54 | ``` 55 | 56 | 57 | ### Evaluation 58 | Before evaluation, please refer [evaluation 
readme](evaluations/c2i/README.md) to install required packages. 59 | ``` 60 | python3 evaluations/c2i/evaluator.py VIRTUAL_imagenet256_labeled.npz samples/GPT-B-c2i_B-size-384-size-256-VQ-16-topk-0-topp-1.0-temperature-1.0-cfg-2.0-seed-0.npz 61 | ``` -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 FoundationVision 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Autoregressive Model Beats Diffusion: 🦙 Llama for Scalable Image Generation 2 | 3 | 4 |
5 | 6 | [![demo](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Online_Demo-blue)](https://huggingface.co/spaces/FoundationVision/LlamaGen)  7 | [![arXiv](https://img.shields.io/badge/arXiv%20paper-2406.06525-b31b1b.svg)](https://arxiv.org/abs/2406.06525)  8 | [![project page](https://img.shields.io/badge/Project_page-More_visualizations-green)](https://peizesun.github.io/llamagen/)  9 | 10 |
11 | 12 | 13 | 14 | ![LlamaGen teaser](assets/teaser.jpg) 15 | 
16 | 17 | 18 | 19 | This repo contains pre-trained model weights and training/sampling PyTorch(torch>=2.1.0) codes used in 20 | 21 | > [**Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation**](https://arxiv.org/abs/2406.06525)
22 | > [Peize Sun](https://peizesun.github.io/), [Yi Jiang](https://enjoyyi.github.io/), [Shoufa Chen](https://www.shoufachen.com/), [Shilong Zhang](https://jshilong.github.io/), [Bingyue Peng](), [Ping Luo](http://luoping.me/), [Zehuan Yuan](https://shallowyuan.github.io/) 23 | >
HKU, ByteDance
24 | 25 | You can find more visualizations on [![project page](https://img.shields.io/badge/Project_page-More_visualizations-green)](https://peizesun.github.io/llamagen/) 26 | 27 | ## 🔥 Update 28 | - [2024.06.28] Image tokenizers and AR models for text-conditional image generation are released ! Try it ! 29 | - [2024.06.15] All models ranging from 100M to 3B parameters are supported by vLLM ! 30 | - [2024.06.11] Image tokenizers and AR models for class-conditional image generation are released ! 31 | - [2024.06.11] Code and Demo are released ! 32 | 33 | ## 🌿 Introduction 34 | We introduce LlamaGen, a new family of image generation models that apply original ``next-token prediction`` paradigm of large language models to visual generation domain. It is an affirmative answer to whether vanilla autoregressive models, e.g., Llama, ``without inductive biases`` on visual signals can achieve state-of-the-art image generation performance if scaling properly. We reexamine design spaces of image tokenizers, scalability properties of image generation models, and their training data quality. 35 | 36 | In this repo, we release: 37 | * Two image tokenizers of downsample ratio 16 and 8. 38 | * Seven class-conditional generation models ranging from 100M to 3B parameters. 39 | * Two text-conditional generation models of 700M parameters. 40 | * Online demos in [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/FoundationVision/LlamaGen) for running pre-trained models. 41 | * Supported vLLM serving framework to enable 300% - 400% speedup. 42 | 43 | ## 🦄 Class-conditional image generation on ImageNet 44 | ### VQ-VAE models 45 | Method | params | tokens | rFID (256x256) | weight 46 | --- |:---:|:---:|:---:|:---: 47 | vq_ds16_c2i | 72M | 16x16 | 2.19 | [vq_ds16_c2i.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt) 48 | vq_ds16_c2i | 72M | 24x24 | 0.94 | above 49 | vq_ds16_c2i | 72M | 32x32 | 0.70 | above 50 | vq_ds8_c2i | 70M | 32x32 | 0.59 | [vq_ds8_c2i.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds8_c2i.pt) 51 | 52 | ### AR models 53 | Method | params | training | tokens | FID (256x256) | weight 54 | --- |:---:|:---:|:---:|:---:|:---:| 55 | LlamaGen-B | 111M | DDP | 16x16 | 5.46 | [c2i_B_256.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_B_256.pt) 56 | LlamaGen-B | 111M | DDP | 24x24 | 6.09 | [c2i_B_384.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_B_384.pt) 57 | LlamaGen-L | 343M | DDP | 16x16 | 3.80 | [c2i_L_256.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_L_256.pt) 58 | LlamaGen-L | 343M | DDP | 24x24 | 3.07 | [c2i_L_384.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_L_384.pt) 59 | LlamaGen-XL | 775M | DDP | 24x24 | 2.62 | [c2i_X_384L.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_XL_384.pt) 60 | LlamaGen-XXL | 1.4B | FSDP | 24x24 | 2.34 | [c2i_XXL_384.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_XXL_384.pt) 61 | LlamaGen-3B | 3.1B | FSDP | 24x24 | 2.18 | [c2i_3B_384.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_3B_384.pt) 62 | 63 | 64 | ### Demo 65 | Please download models, put them in the folder `./pretrained_models`, and run 66 | ``` 67 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_L_384.pt --gpt-model GPT-L --image-size 384 68 | # 
or 69 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL_384.pt --gpt-model GPT-XXL --from-fsdp --image-size 384 70 | ``` 71 | The generated images will be saved to `sample_c2i.png`. 72 | 73 | ### Gradio Demo 74 | 75 | You can use our online gradio demo [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/FoundationVision/LlamaGen) or run gradio locally: 76 | ```bash 77 | python app.py 78 | ``` 79 | 80 | 81 | ## 🚀 Text-conditional image generation 82 | ### VQ-VAE models 83 | Method | params | tokens | data | weight 84 | --- |:---:|:---:|:---:|:---: 85 | vq_ds16_t2i | 72M | 16x16 | LAION COCO (50M) + internal data (10M) | [vq_ds16_t2i.pt](https://huggingface.co/peizesun/llamagen_t2i/resolve/main/vq_ds16_t2i.pt) 86 | 87 | ### AR models 88 | Method | params | tokens | data | weight 89 | --- |:---:|:---:|:---:|:---: 90 | LlamaGen-XL | 775M | 16x16 | LAION COCO (50M) | [t2i_XL_stage1_256.pt](https://huggingface.co/peizesun/llamagen_t2i/resolve/main/t2i_XL_stage1_256.pt) 91 | LlamaGen-XL | 775M | 32x32 | internal data (10M) | [t2i_XL_stage2_512.pt](https://huggingface.co/peizesun/llamagen_t2i/resolve/main/t2i_XL_stage2_512.pt) 92 | 93 | ### Demo 94 | Before running demo, please refer to [language readme](language/README.md) to install the required packages and language models. 95 | 96 | Please download models, put them in the folder `./pretrained_models`, and run 97 | ``` 98 | python3 autoregressive/sample/sample_t2i.py --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt --gpt-ckpt ./pretrained_models/t2i_XL_stage1_256.pt --gpt-model GPT-XL --image-size 256 99 | # or 100 | python3 autoregressive/sample/sample_t2i.py --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt --gpt-ckpt ./pretrained_models/t2i_XL_stage2_512.pt --gpt-model GPT-XL --image-size 512 101 | ``` 102 | The generated images will be saved to `sample_t2i.png`. 103 | 104 | ### Local Gradio Demo 105 | 106 | 107 | 108 | ## ⚡ Serving 109 | We use serving framework [vLLM](https://github.com/vllm-project/vllm) to enable higher throughput. Please refer to [serving readme](autoregressive/serve/README.md) to install the required packages. 110 | ``` 111 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL_384.pt --gpt-model GPT-XXL --from-fsdp --image-size 384 112 | ``` 113 | The generated images will be saved to `sample_c2i_vllm.png`. 114 | 115 | 116 | ## Getting Started 117 | See [Getting Started](GETTING_STARTED.md) for installation, training and evaluation. 118 | 119 | 120 | ## License 121 | The majority of this project is licensed under MIT License. Portions of the project are available under separate license of referred projects, detailed in corresponding files. 
122 | 123 | 124 | ## BibTeX 125 | ```bibtex 126 | @article{sun2024autoregressive, 127 | title={Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation}, 128 | author={Sun, Peize and Jiang, Yi and Chen, Shoufa and Zhang, Shilong and Peng, Bingyue and Luo, Ping and Yuan, Zehuan}, 129 | journal={arXiv preprint arXiv:2406.06525}, 130 | year={2024} 131 | } 132 | ``` 133 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import gradio as gr 3 | from tools.imagenet_en_cn import IMAGENET_1K_CLASSES 4 | from huggingface_hub import hf_hub_download 5 | import torch 6 | torch.backends.cuda.matmul.allow_tf32 = True 7 | torch.backends.cudnn.allow_tf32 = True 8 | torch.set_float32_matmul_precision('high') 9 | setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) 10 | setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) 11 | from vllm import SamplingParams 12 | import time 13 | import argparse 14 | from tokenizer.tokenizer_image.vq_model import VQ_models 15 | from autoregressive.serve.llm import LLM 16 | from autoregressive.serve.sampler import Sampler 17 | 18 | device = "cuda" 19 | 20 | model2ckpt = { 21 | "GPT-XL": ("vq_ds16_c2i.pt", "c2i_XL_384.pt", 384), 22 | "GPT-B": ("vq_ds16_c2i.pt", "c2i_B_256.pt", 256), 23 | } 24 | 25 | def load_model(args): 26 | ckpt_folder = "./" 27 | vq_ckpt, gpt_ckpt, image_size = model2ckpt[args.gpt_model] 28 | hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=vq_ckpt, local_dir=ckpt_folder) 29 | hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=gpt_ckpt, local_dir=ckpt_folder) 30 | # create and load model 31 | vq_model = VQ_models[args.vq_model]( 32 | codebook_size=args.codebook_size, 33 | codebook_embed_dim=args.codebook_embed_dim) 34 | vq_model.to(device) 35 | vq_model.eval() 36 | checkpoint = torch.load(f"{ckpt_folder}{vq_ckpt}", map_location="cpu") 37 | vq_model.load_state_dict(checkpoint["model"]) 38 | del checkpoint 39 | print(f"image tokenizer is loaded") 40 | 41 | # Create an LLM. 42 | args.image_size = image_size 43 | args.gpt_ckpt = f"{ckpt_folder}{gpt_ckpt}" 44 | llm = LLM( 45 | args=args, 46 | model='serve/fake_json/{}.json'.format(args.gpt_model), 47 | gpu_memory_utilization=0.6, 48 | skip_tokenizer_init=True) 49 | print(f"gpt model is loaded") 50 | return vq_model, llm, image_size 51 | 52 | 53 | def infer(cfg_scale, top_k, top_p, temperature, class_label, seed): 54 | llm.llm_engine.model_executor.driver_worker.model_runner.model.sampler = Sampler(cfg_scale) 55 | args.cfg_scale = cfg_scale 56 | n = 4 57 | latent_size = image_size // args.downsample_size 58 | # Labels to condition the model with (feel free to change): 59 | class_labels = [class_label for _ in range(n)] 60 | qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size] 61 | 62 | prompt_token_ids = [[cind] for cind in class_labels] 63 | if cfg_scale > 1.0: 64 | prompt_token_ids.extend([[args.num_classes] for _ in range(len(prompt_token_ids))]) 65 | 66 | # Create a sampling params object. 
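    # When cfg_scale > 1.0, prompt_token_ids above holds the class ids followed by an equal
    # number of "null" prompts ([args.num_classes]), so vLLM generates the conditional and
    # unconditional rows in a single batch; the Sampler(cfg_scale) installed at the top of
    # infer() uses that pair for classifier-free guidance, and only the first
    # len(class_labels) outputs are decoded further below. max_tokens = latent_size ** 2
    # produces exactly one code per position of the latent_size x latent_size token grid.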
67 | sampling_params = SamplingParams( 68 | temperature=temperature, top_p=top_p, top_k=top_k, 69 | max_tokens=latent_size ** 2) 70 | 71 | t1 = time.time() 72 | torch.manual_seed(seed) 73 | outputs = llm.generate( 74 | prompt_token_ids=prompt_token_ids, 75 | sampling_params=sampling_params, 76 | use_tqdm=False) 77 | sampling_time = time.time() - t1 78 | print(f"gpt sampling takes about {sampling_time:.2f} seconds.") 79 | 80 | index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device) 81 | if cfg_scale > 1.0: 82 | index_sample = index_sample[:len(class_labels)] 83 | t2 = time.time() 84 | samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1] 85 | decoder_time = time.time() - t2 86 | print(f"decoder takes about {decoder_time:.2f} seconds.") 87 | # Convert to PIL.Image format: 88 | samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy() 89 | samples = [Image.fromarray(sample) for sample in samples] 90 | return samples 91 | 92 | 93 | parser = argparse.ArgumentParser() 94 | parser.add_argument("--gpt-model", type=str, default="GPT-XL") 95 | parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional") 96 | parser.add_argument("--from-fsdp", action='store_true') 97 | parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input") 98 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 99 | parser.add_argument("--compile", action='store_true', default=False) 100 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") 101 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") 102 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") 103 | parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) 104 | parser.add_argument("--num-classes", type=int, default=1000) 105 | parser.add_argument("--cfg-scale", type=float, default=4.0) 106 | parser.add_argument("--cfg-interval", type=float, default=-1) 107 | parser.add_argument("--seed", type=int, default=0) 108 | parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with") 109 | parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with") 110 | parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with") 111 | args = parser.parse_args() 112 | 113 | vq_model, llm, image_size = load_model(args) 114 | 115 | with gr.Blocks() as demo: 116 | gr.Markdown("

Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation
") 117 | 118 | with gr.Tabs(): 119 | with gr.TabItem('Generate'): 120 | with gr.Row(): 121 | with gr.Column(): 122 | with gr.Row(): 123 | i1k_class = gr.Dropdown( 124 | list(IMAGENET_1K_CLASSES.values()), 125 | value='llama [羊驼]', 126 | type="index", label='ImageNet-1K Class' 127 | ) 128 | cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale') 129 | top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=4000, label='Top-K') 130 | top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P") 131 | temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature') 132 | seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed') 133 | # seed = gr.Number(value=0, label='Seed') 134 | button = gr.Button("Generate", variant="primary") 135 | with gr.Column(): 136 | output = gr.Gallery(label='Generated Images', height=700) 137 | button.click(infer, inputs=[cfg_scale, top_k, top_p, temperature, i1k_class, seed], outputs=[output]) 138 | demo.queue() 139 | demo.launch(debug=True) 140 | -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/LlamaGen/ce98ec41803a74a90ce68c40ababa9eaeffeb4ec/assets/teaser.jpg -------------------------------------------------------------------------------- /autoregressive/models/generate.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py 3 | # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn import functional as F 7 | import torch._dynamo.config 8 | import torch._inductor.config 9 | import copy 10 | # torch._inductor.config.coordinate_descent_tuning = True 11 | # torch._inductor.config.triton.unique_kernel_names = True 12 | # torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future 13 | 14 | 15 | ### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html 16 | def top_k_top_p_filtering( 17 | logits, 18 | top_k: int = 0, 19 | top_p: float = 1.0, 20 | filter_value: float = -float("Inf"), 21 | min_tokens_to_keep: int = 1, 22 | ): 23 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering 24 | Args: 25 | logits: logits distribution shape (batch size, vocabulary size) 26 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering). 27 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). 28 | Nucleus filtering is described in Holtzman et al. 
(http://arxiv.org/abs/1904.09751) 29 | Make sure we keep at least min_tokens_to_keep per batch example in the output 30 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317 31 | """ 32 | if top_k > 0: 33 | top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check 34 | # Remove all tokens with a probability less than the last token of the top-k 35 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] 36 | logits[indices_to_remove] = filter_value 37 | 38 | if top_p < 1.0: 39 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 40 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 41 | 42 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept) 43 | sorted_indices_to_remove = cumulative_probs > top_p 44 | if min_tokens_to_keep > 1: 45 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below) 46 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0 47 | # Shift the indices to the right to keep also the first token above the threshold 48 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() 49 | sorted_indices_to_remove[..., 0] = 0 50 | 51 | # scatter sorted tensors to original indexing 52 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) 53 | logits[indices_to_remove] = filter_value 54 | return logits 55 | 56 | 57 | def sample(logits, temperature: float=1.0, top_k: int=0, top_p: float=1.0, sample_logits=True): 58 | logits = logits[:, -1, :] / max(temperature, 1e-5) 59 | if top_k > 0 or top_p < 1.0: 60 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 61 | probs = F.softmax(logits, dim=-1) 62 | if sample_logits: 63 | idx = torch.multinomial(probs, num_samples=1) 64 | else: 65 | _, idx = torch.topk(probs, k=1, dim=-1) 66 | return idx, probs 67 | 68 | 69 | def logits_to_probs(logits, temperature: float = 1.0, top_p: float=1.0, top_k: int = None, **kwargs): 70 | logits = logits / max(temperature, 1e-5) 71 | if top_k > 0 or top_p < 1.0: 72 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p) 73 | probs = torch.nn.functional.softmax(logits, dim=-1) 74 | return probs 75 | 76 | 77 | def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, **sampling_kwargs): 78 | if cfg_scale > 1.0: 79 | logits, _ = model(None, cond_idx, input_pos) 80 | logits_combined = logits 81 | cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0) 82 | logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale 83 | else: 84 | logits, _ = model(None, cond_idx, input_pos) 85 | 86 | return sample(logits, **sampling_kwargs)[0] 87 | 88 | 89 | def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, **sampling_kwargs): 90 | assert input_pos.shape[-1] == 1 91 | if cfg_scale > 1.0: 92 | x_combined = torch.cat([x, x]) 93 | logits, _ = model(x_combined, cond_idx=None, input_pos=input_pos) 94 | logits_combined = logits 95 | cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0) 96 | if cfg_flag: 97 | logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale 98 | else: 99 | logits = cond_logits 100 | else: 101 | logits, _ = model(x, cond_idx=None, input_pos=input_pos) 102 | return sample(logits, **sampling_kwargs) 103 | 104 | 105 | def decode_n_tokens( 106 | model, cur_token: torch.Tensor, 
input_pos: torch.Tensor, num_new_tokens: int, 107 | cfg_scale: float, cfg_interval: int, 108 | **sampling_kwargs): 109 | new_tokens, new_probs = [], [] 110 | cfg_flag = True 111 | for i in range(num_new_tokens): 112 | with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here 113 | if cfg_interval > -1 and i > cfg_interval: 114 | cfg_flag = False 115 | next_token, next_prob = decode_one_token( 116 | model, cur_token, input_pos, cfg_scale, cfg_flag, **sampling_kwargs 117 | ) 118 | input_pos += 1 119 | new_tokens.append(next_token.clone()) 120 | new_probs.append(next_prob.clone()) 121 | cur_token = next_token.view(-1, 1) 122 | 123 | return new_tokens, new_probs 124 | 125 | 126 | @torch.no_grad() 127 | def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, **sampling_kwargs): 128 | if model.model_type == 'c2i': 129 | if cfg_scale > 1.0: 130 | cond_null = torch.ones_like(cond) * model.num_classes 131 | cond_combined = torch.cat([cond, cond_null]) 132 | else: 133 | cond_combined = cond 134 | T = 1 135 | elif model.model_type == 't2i': 136 | if cfg_scale > 1.0: 137 | cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding 138 | cond_combined = torch.cat([cond, cond_null]) 139 | else: 140 | cond_combined = cond 141 | T = cond.shape[1] 142 | else: 143 | raise Exception("please check model type") 144 | 145 | T_new = T + max_new_tokens 146 | max_seq_length = T_new 147 | max_batch_size = cond.shape[0] 148 | 149 | device = cond.device 150 | with torch.device(device): 151 | max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size 152 | model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype) 153 | 154 | if emb_masks is not None: 155 | assert emb_masks.shape[0] == max_batch_size 156 | assert emb_masks.shape[-1] == T 157 | if cfg_scale > 1.0: 158 | model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1) 159 | else: 160 | model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1) 161 | 162 | eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device) 163 | model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix 164 | 165 | # create an empty tensor of the expected final shape and fill in the current tokens 166 | seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device) 167 | 168 | input_pos = torch.arange(0, T, device=device) 169 | next_token = prefill(model, cond_combined, input_pos, cfg_scale, **sampling_kwargs) 170 | seq[:, T:T+1] = next_token 171 | 172 | input_pos = torch.tensor([T], device=device, dtype=torch.int) 173 | generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, **sampling_kwargs) 174 | seq[:, T+1:] = torch.cat(generated_tokens, dim=1) 175 | 176 | return seq[:, T:] 177 | -------------------------------------------------------------------------------- /autoregressive/models/gpt_hf.py: -------------------------------------------------------------------------------- 1 | from autoregressive.models.gpt import ModelArgs, Transformer 2 | from huggingface_hub import PyTorchModelHubMixin 3 | 4 | 5 | class TransformerHF(Transformer, PyTorchModelHubMixin, repo_url="https://github.com/FoundationVision/LlamaGen", license="mit", tags=["llamagen", "text-to-image"]): 6 | 
pass 7 | 8 | 9 | ################################################################################# 10 | # GPT Configs # 11 | ################################################################################# 12 | ### text-conditional 13 | def GPT_7B(**kwargs): 14 | return TransformerHF(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) # 6.6B 15 | 16 | def GPT_3B(**kwargs): 17 | return TransformerHF(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) # 3.1B 18 | 19 | def GPT_1B(**kwargs): 20 | return TransformerHF(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) # 1.2B 21 | 22 | ### class-conditional 23 | def GPT_XXXL(**kwargs): 24 | return TransformerHF(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B 25 | 26 | def GPT_XXL(**kwargs): 27 | return TransformerHF(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B 28 | 29 | def GPT_XL(**kwargs): 30 | return TransformerHF(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M 31 | 32 | def GPT_L(**kwargs): 33 | return TransformerHF(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M 34 | 35 | def GPT_B(**kwargs): 36 | return TransformerHF(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M 37 | 38 | 39 | GPT_models_HF = { 40 | 'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL, 41 | 'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B, 42 | } 43 | -------------------------------------------------------------------------------- /autoregressive/sample/sample_c2i.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # DiT: https://github.com/facebookresearch/DiT/blob/main/sample.py 3 | import torch 4 | torch.backends.cuda.matmul.allow_tf32 = True 5 | torch.backends.cudnn.allow_tf32 = True 6 | torch.set_float32_matmul_precision('high') 7 | setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) 8 | setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) 9 | from torchvision.utils import save_image 10 | 11 | import time 12 | import argparse 13 | from tokenizer.tokenizer_image.vq_model import VQ_models 14 | from autoregressive.models.gpt import GPT_models 15 | from autoregressive.models.generate import generate 16 | 17 | 18 | def main(args): 19 | # Setup PyTorch: 20 | torch.manual_seed(args.seed) 21 | torch.backends.cudnn.deterministic = True 22 | torch.backends.cudnn.benchmark = False 23 | torch.set_grad_enabled(False) 24 | device = "cuda" if torch.cuda.is_available() else "cpu" 25 | 26 | # create and load model 27 | vq_model = VQ_models[args.vq_model]( 28 | codebook_size=args.codebook_size, 29 | codebook_embed_dim=args.codebook_embed_dim) 30 | vq_model.to(device) 31 | vq_model.eval() 32 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu") 33 | vq_model.load_state_dict(checkpoint["model"]) 34 | del checkpoint 35 | print(f"image tokenizer is loaded") 36 | 37 | # create and load gpt model 38 | precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] 39 | latent_size = args.image_size // args.downsample_size 40 | gpt_model = GPT_models[args.gpt_model]( 41 | vocab_size=args.codebook_size, 42 | block_size=latent_size ** 2, 43 | num_classes=args.num_classes, 44 | cls_token_num=args.cls_token_num, 45 | model_type=args.gpt_type, 46 | ).to(device=device, dtype=precision) 47 | 48 | checkpoint = torch.load(args.gpt_ckpt, map_location="cpu") 49 | if args.from_fsdp: # fspd 50 | model_weight = checkpoint 51 | elif "model" in checkpoint: # ddp 52 | 
model_weight = checkpoint["model"] 53 | elif "module" in checkpoint: # deepspeed 54 | model_weight = checkpoint["module"] 55 | elif "state_dict" in checkpoint: 56 | model_weight = checkpoint["state_dict"] 57 | else: 58 | raise Exception("please check model weight, maybe add --from-fsdp to run command") 59 | # if 'freqs_cis' in model_weight: 60 | # model_weight.pop('freqs_cis') 61 | gpt_model.load_state_dict(model_weight, strict=False) 62 | gpt_model.eval() 63 | del checkpoint 64 | print(f"gpt model is loaded") 65 | 66 | if args.compile: 67 | print(f"compiling the model...") 68 | gpt_model = torch.compile( 69 | gpt_model, 70 | mode="reduce-overhead", 71 | fullgraph=True 72 | ) # requires PyTorch 2.0 (optional) 73 | else: 74 | print(f"no need to compile model in demo") 75 | 76 | # Labels to condition the model with (feel free to change): 77 | class_labels = [207, 360, 387, 974, 88, 979, 417, 279] 78 | c_indices = torch.tensor(class_labels, device=device) 79 | qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size] 80 | 81 | t1 = time.time() 82 | index_sample = generate( 83 | gpt_model, c_indices, latent_size ** 2, 84 | cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval, 85 | temperature=args.temperature, top_k=args.top_k, 86 | top_p=args.top_p, sample_logits=True, 87 | ) 88 | sampling_time = time.time() - t1 89 | print(f"gpt sampling takes about {sampling_time:.2f} seconds.") 90 | 91 | t2 = time.time() 92 | samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1] 93 | decoder_time = time.time() - t2 94 | print(f"decoder takes about {decoder_time:.2f} seconds.") 95 | 96 | # Save and display images: 97 | save_image(samples, "sample_{}.png".format(args.gpt_type), nrow=4, normalize=True, value_range=(-1, 1)) 98 | print(f"image is saved to sample_{args.gpt_type}.png") 99 | 100 | 101 | if __name__ == "__main__": 102 | parser = argparse.ArgumentParser() 103 | parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B") 104 | parser.add_argument("--gpt-ckpt", type=str, default=None) 105 | parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional") 106 | parser.add_argument("--from-fsdp", action='store_true') 107 | parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input") 108 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 109 | parser.add_argument("--compile", action='store_true', default=False) 110 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") 111 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") 112 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") 113 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") 114 | parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384) 115 | parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) 116 | parser.add_argument("--num-classes", type=int, default=1000) 117 | parser.add_argument("--cfg-scale", type=float, default=4.0) 118 | parser.add_argument("--cfg-interval", type=float, default=-1) 119 | parser.add_argument("--seed", type=int, default=0) 120 | parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with") 121 | 
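    # top-k / top-p / temperature are forwarded to generate() and applied at every decoding
    # step in autoregressive/models/generate.py; cfg-scale > 1.0 enables classifier-free
    # guidance, and cfg-interval > -1 stops applying it after that many generated tokens
    # (see decode_n_tokens).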
parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with") 122 | parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with") 123 | args = parser.parse_args() 124 | main(args) -------------------------------------------------------------------------------- /autoregressive/sample/sample_t2i.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | torch.backends.cudnn.allow_tf32 = True 4 | torch.set_float32_matmul_precision('high') 5 | setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed 6 | setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed 7 | from torchvision.utils import save_image 8 | 9 | import os 10 | import time 11 | import argparse 12 | from tokenizer.tokenizer_image.vq_model import VQ_models 13 | from language.t5 import T5Embedder 14 | from autoregressive.models.gpt import GPT_models 15 | from autoregressive.models.generate import generate 16 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 17 | 18 | 19 | 20 | def main(args): 21 | # Setup PyTorch: 22 | torch.manual_seed(args.seed) 23 | torch.backends.cudnn.deterministic = True 24 | torch.backends.cudnn.benchmark = False 25 | torch.set_grad_enabled(False) 26 | device = "cuda" if torch.cuda.is_available() else "cpu" 27 | 28 | # create and load model 29 | vq_model = VQ_models[args.vq_model]( 30 | codebook_size=args.codebook_size, 31 | codebook_embed_dim=args.codebook_embed_dim) 32 | vq_model.to(device) 33 | vq_model.eval() 34 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu") 35 | vq_model.load_state_dict(checkpoint["model"]) 36 | del checkpoint 37 | print(f"image tokenizer is loaded") 38 | 39 | # create and load gpt model 40 | precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] 41 | latent_size = args.image_size // args.downsample_size 42 | gpt_model = GPT_models[args.gpt_model]( 43 | block_size=latent_size ** 2, 44 | cls_token_num=args.cls_token_num, 45 | model_type=args.gpt_type, 46 | ).to(device=device, dtype=precision) 47 | 48 | checkpoint = torch.load(args.gpt_ckpt, map_location="cpu") 49 | 50 | if "model" in checkpoint: # ddp 51 | model_weight = checkpoint["model"] 52 | elif "module" in checkpoint: # deepspeed 53 | model_weight = checkpoint["module"] 54 | elif "state_dict" in checkpoint: 55 | model_weight = checkpoint["state_dict"] 56 | else: 57 | raise Exception("please check model weight") 58 | gpt_model.load_state_dict(model_weight, strict=False) 59 | gpt_model.eval() 60 | del checkpoint 61 | print(f"gpt model is loaded") 62 | 63 | if args.compile: 64 | print(f"compiling the model...") 65 | gpt_model = torch.compile( 66 | gpt_model, 67 | mode="reduce-overhead", 68 | fullgraph=True 69 | ) # requires PyTorch 2.0 (optional) 70 | else: 71 | print(f"no need to compile model in demo") 72 | 73 | assert os.path.exists(args.t5_path) 74 | t5_model = T5Embedder( 75 | device=device, 76 | local_cache=True, 77 | cache_dir=args.t5_path, 78 | dir_or_name=args.t5_model_type, 79 | torch_dtype=precision, 80 | model_max_length=args.t5_feature_max_len, 81 | ) 82 | prompts = [ 83 | "A portrait photo of a kangaroo wearing an orange hoodie and blue sunglasses standing on the grassin front of the Sydney Opera House holding a sign on the chest that says Welcome Friends!", 84 | "A blue Porsche 356 
parked in front of a yellow brick wall.", 85 | "A photo of an astronaut riding a horse in the forest. There is a river in front of them with water lilies.", 86 | "A map of the United States made out of sushi. It is on a table next to a glass of red wine." 87 | ] 88 | 89 | caption_embs, emb_masks = t5_model.get_text_embeddings(prompts) 90 | 91 | 92 | if not args.no_left_padding: 93 | print(f"processing left-padding...") 94 | # a naive way to implement left-padding 95 | new_emb_masks = torch.flip(emb_masks, dims=[-1]) 96 | new_caption_embs = [] 97 | for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)): 98 | valid_num = int(emb_mask.sum().item()) 99 | print(f' prompt {idx} token len: {valid_num}') 100 | new_caption_emb = torch.cat([caption_emb[valid_num:], caption_emb[:valid_num]]) 101 | new_caption_embs.append(new_caption_emb) 102 | new_caption_embs = torch.stack(new_caption_embs) 103 | else: 104 | new_caption_embs, new_emb_masks = caption_embs, emb_masks 105 | c_indices = new_caption_embs * new_emb_masks[:,:, None] 106 | c_emb_masks = new_emb_masks 107 | 108 | qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size] 109 | t1 = time.time() 110 | index_sample = generate( 111 | gpt_model, c_indices, latent_size ** 2, 112 | c_emb_masks, 113 | cfg_scale=args.cfg_scale, 114 | temperature=args.temperature, top_k=args.top_k, 115 | top_p=args.top_p, sample_logits=True, 116 | ) 117 | sampling_time = time.time() - t1 118 | print(f"Full sampling takes about {sampling_time:.2f} seconds.") 119 | 120 | t2 = time.time() 121 | samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1] 122 | decoder_time = time.time() - t2 123 | print(f"decoder takes about {decoder_time:.2f} seconds.") 124 | 125 | save_image(samples, "sample_{}.png".format(args.gpt_type), nrow=4, normalize=True, value_range=(-1, 1)) 126 | print(f"image is saved to sample_{args.gpt_type}.png") 127 | 128 | 129 | 130 | if __name__ == "__main__": 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument("--t5-path", type=str, default='pretrained_models/t5-ckpt') 133 | parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl') 134 | parser.add_argument("--t5-feature-max-len", type=int, default=120) 135 | parser.add_argument("--t5-feature-dim", type=int, default=2048) 136 | parser.add_argument("--no-left-padding", action='store_true', default=False) 137 | parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL") 138 | parser.add_argument("--gpt-ckpt", type=str, default=None) 139 | parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image") 140 | parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input") 141 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 142 | parser.add_argument("--compile", action='store_true', default=False) 143 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") 144 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") 145 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") 146 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") 147 | parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=512) 148 | 
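    # the token grid per side is image-size // downsample-size, so the defaults (512, 16)
    # give a 32x32 grid, i.e. 1024 image tokens per generated sample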
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) 149 | parser.add_argument("--num-classes", type=int, default=1000) 150 | parser.add_argument("--cfg-scale", type=float, default=7.5) 151 | parser.add_argument("--seed", type=int, default=0) 152 | parser.add_argument("--top-k", type=int, default=1000, help="top-k value to sample with") 153 | parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with") 154 | parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with") 155 | args = parser.parse_args() 156 | main(args) 157 | -------------------------------------------------------------------------------- /autoregressive/serve/README.md: -------------------------------------------------------------------------------- 1 | ## serving by vLLM 2 | 3 | ### Install 4 | ``` 5 | pip install vllm==0.4.1 6 | ``` 7 | 8 | ### Comparison (A100) 9 | 10 | Method | params | baseline(s) | vllm(s) | speed-up ratio 11 | --- |:---:|:---:|:---:|:---: 12 | [GPT-B](./fake_json/GPT-B.json) | 111M | 7.80 | 2.39 | 326 % 13 | [GPT-L](./fake_json/GPT-L.json) | 343M | 13.72 | 3.48 | 380 % 14 | [GPT-XL](./fake_json/GPT-XL.json) | 775M | 19.76 | 4.84 | 408 % 15 | [GPT-XXL](./fake_json/GPT-XXL.json)| 1.4B | 26.38 | 6.36 | 414 % 16 | [GPT-3B](./fake_json/GPT-3B.json) | 3.1B | 14.73 | 6.26 | 235 % 17 | 18 | ``` 19 | ### GPT-B 20 | # 7.80 seconds 21 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_B_384.pt --image-size 384 22 | 23 | # 2.39 seconds 24 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_B_384.pt --image-size 384 25 | 26 | 27 | ### GPT-L 28 | # 13.72 seconds 29 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_L_384.pt --gpt-model GPT-L --image-size 384 30 | 31 | # 3.48 seconds 32 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_L_384.pt --gpt-model GPT-L --image-size 384 33 | 34 | 35 | ### GPT-XL 36 | # 19.76 seconds 37 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XL_384.pt --gpt-model GPT-XL --image-size 384 38 | 39 | # 4.84 seconds 40 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XL_384.pt --gpt-model GPT-XL --image-size 384 41 | 42 | 43 | ### GPT-XXL 44 | # 26.38 seconds 45 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL_384.pt --from-fsdp --gpt-model GPT-XXL --image-size 384 46 | 47 | # 6.36 seconds 48 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL_384.pt --from-fsdp --gpt-model GPT-XXL --image-size 384 49 | 50 | 51 | ### GPT-3B 52 | # 14.73 seconds 53 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_3B_384.pt --from-fsdp --gpt-model GPT-3B --image-size 384 54 | 55 | # 6.26 seconds 56 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_3B_384.pt --from-fsdp --gpt-model GPT-3B --image-size 384 57 | 58 | ``` 59 | 
-------------------------------------------------------------------------------- /autoregressive/serve/fake_json/GPT-3B.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "facebook/opt-125m", 3 | "activation_dropout": 0.0, 4 | "activation_function": "relu", 5 | "architectures": [ 6 | "OPTForCausalLM" 7 | ], 8 | "attention_dropout": 0.0, 9 | "bos_token_id": 2, 10 | "do_layer_norm_before": true, 11 | "dropout": 0.1, 12 | "eos_token_id": 2, 13 | "ffn_dim": 3072, 14 | "hidden_size": 3584, 15 | "init_std": 0.02, 16 | "layerdrop": 0.0, 17 | "max_position_embeddings": 2048, 18 | "model_type": "opt", 19 | "num_attention_heads": 32, 20 | "num_hidden_layers": 24, 21 | "pad_token_id": 1, 22 | "prefix": "", 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.21.0.dev0", 25 | "use_cache": true, 26 | "vocab_size": 16384, 27 | "word_embed_proj_dim": 768 28 | } 29 | -------------------------------------------------------------------------------- /autoregressive/serve/fake_json/GPT-B.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "facebook/opt-125m", 3 | "activation_dropout": 0.0, 4 | "activation_function": "relu", 5 | "architectures": [ 6 | "OPTForCausalLM" 7 | ], 8 | "attention_dropout": 0.0, 9 | "bos_token_id": 2, 10 | "do_layer_norm_before": true, 11 | "dropout": 0.1, 12 | "eos_token_id": 2, 13 | "ffn_dim": 3072, 14 | "hidden_size": 768, 15 | "init_std": 0.02, 16 | "layerdrop": 0.0, 17 | "max_position_embeddings": 2048, 18 | "model_type": "opt", 19 | "num_attention_heads": 12, 20 | "num_hidden_layers": 12, 21 | "pad_token_id": 1, 22 | "prefix": "", 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.21.0.dev0", 25 | "use_cache": true, 26 | "vocab_size": 16384, 27 | "word_embed_proj_dim": 768 28 | } 29 | -------------------------------------------------------------------------------- /autoregressive/serve/fake_json/GPT-L.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "facebook/opt-125m", 3 | "activation_dropout": 0.0, 4 | "activation_function": "relu", 5 | "architectures": [ 6 | "OPTForCausalLM" 7 | ], 8 | "attention_dropout": 0.0, 9 | "bos_token_id": 2, 10 | "do_layer_norm_before": true, 11 | "dropout": 0.1, 12 | "eos_token_id": 2, 13 | "ffn_dim": 3072, 14 | "hidden_size": 1024, 15 | "init_std": 0.02, 16 | "layerdrop": 0.0, 17 | "max_position_embeddings": 2048, 18 | "model_type": "opt", 19 | "num_attention_heads": 16, 20 | "num_hidden_layers": 24, 21 | "pad_token_id": 1, 22 | "prefix": "", 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.21.0.dev0", 25 | "use_cache": true, 26 | "vocab_size": 16384, 27 | "word_embed_proj_dim": 768 28 | } 29 | -------------------------------------------------------------------------------- /autoregressive/serve/fake_json/GPT-XL.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "facebook/opt-125m", 3 | "activation_dropout": 0.0, 4 | "activation_function": "relu", 5 | "architectures": [ 6 | "OPTForCausalLM" 7 | ], 8 | "attention_dropout": 0.0, 9 | "bos_token_id": 2, 10 | "do_layer_norm_before": true, 11 | "dropout": 0.1, 12 | "eos_token_id": 2, 13 | "ffn_dim": 3072, 14 | "hidden_size": 1280, 15 | "init_std": 0.02, 16 | "layerdrop": 0.0, 17 | "max_position_embeddings": 2048, 18 | "model_type": "opt", 19 | "num_attention_heads": 20, 20 | "num_hidden_layers": 36, 21 | 
"pad_token_id": 1, 22 | "prefix": "", 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.21.0.dev0", 25 | "use_cache": true, 26 | "vocab_size": 16384, 27 | "word_embed_proj_dim": 768 28 | } 29 | -------------------------------------------------------------------------------- /autoregressive/serve/fake_json/GPT-XXL.json: -------------------------------------------------------------------------------- 1 | { 2 | "_name_or_path": "facebook/opt-125m", 3 | "activation_dropout": 0.0, 4 | "activation_function": "relu", 5 | "architectures": [ 6 | "OPTForCausalLM" 7 | ], 8 | "attention_dropout": 0.0, 9 | "bos_token_id": 2, 10 | "do_layer_norm_before": true, 11 | "dropout": 0.1, 12 | "eos_token_id": 2, 13 | "ffn_dim": 3072, 14 | "hidden_size": 1536, 15 | "init_std": 0.02, 16 | "layerdrop": 0.0, 17 | "max_position_embeddings": 2048, 18 | "model_type": "opt", 19 | "num_attention_heads": 24, 20 | "num_hidden_layers": 48, 21 | "pad_token_id": 1, 22 | "prefix": "", 23 | "torch_dtype": "bfloat16", 24 | "transformers_version": "4.21.0.dev0", 25 | "use_cache": true, 26 | "vocab_size": 16384, 27 | "word_embed_proj_dim": 768 28 | } 29 | -------------------------------------------------------------------------------- /autoregressive/serve/gpu_executor.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Set, Tuple, Optional, Set 2 | import argparse 3 | 4 | from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, 5 | ModelConfig, ParallelConfig, SchedulerConfig, 6 | SpeculativeConfig, VisionLanguageConfig) 7 | from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase 8 | from vllm.logger import init_logger 9 | from vllm.lora.request import LoRARequest 10 | from vllm.sequence import SamplerOutput, SequenceGroupMetadata 11 | from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, 12 | make_async) 13 | 14 | logger = init_logger(__name__) 15 | 16 | 17 | class GPUExecutor(ExecutorBase): 18 | def __init__( 19 | self, 20 | args: argparse.ArgumentParser, 21 | model_config: ModelConfig, 22 | cache_config: CacheConfig, 23 | parallel_config: ParallelConfig, 24 | scheduler_config: SchedulerConfig, 25 | device_config: DeviceConfig, 26 | load_config: LoadConfig, 27 | lora_config: Optional[LoRAConfig], 28 | vision_language_config: Optional[VisionLanguageConfig], 29 | speculative_config: Optional[SpeculativeConfig], 30 | ) -> None: 31 | self.args = args 32 | self.model_config = model_config 33 | self.cache_config = cache_config 34 | self.lora_config = lora_config 35 | self.load_config = load_config 36 | self.parallel_config = parallel_config 37 | self.scheduler_config = scheduler_config 38 | self.device_config = device_config 39 | self.vision_language_config = vision_language_config 40 | self.speculative_config = speculative_config 41 | 42 | self._init_executor() 43 | 44 | def _init_executor(self) -> None: 45 | """Initialize the worker and load the model. 46 | 47 | If speculative decoding is enabled, we instead create the speculative 48 | worker. 
49 | """ 50 | if self.speculative_config is None: 51 | self._init_non_spec_worker() 52 | else: 53 | self._init_spec_worker() 54 | 55 | def _init_non_spec_worker(self): 56 | # Lazy import the Worker to avoid importing torch.cuda/xformers 57 | # before CUDA_VISIBLE_DEVICES is set in the Worker 58 | # from vllm.worker.worker import Worker 59 | from autoregressive.serve.worker import Worker 60 | 61 | assert self.parallel_config.world_size == 1, ( 62 | "GPUExecutor only supports single GPU.") 63 | 64 | distributed_init_method = get_distributed_init_method( 65 | get_ip(), get_open_port()) 66 | self.driver_worker = Worker( 67 | model_config=self.model_config, 68 | parallel_config=self.parallel_config, 69 | scheduler_config=self.scheduler_config, 70 | device_config=self.device_config, 71 | cache_config=self.cache_config, 72 | load_config=self.load_config, 73 | local_rank=0, 74 | rank=0, 75 | distributed_init_method=distributed_init_method, 76 | lora_config=self.lora_config, 77 | vision_language_config=self.vision_language_config, 78 | is_driver_worker=True, 79 | ) 80 | self.driver_worker.init_device() 81 | self.driver_worker.load_model(self.args) 82 | 83 | def _init_spec_worker(self): 84 | """Initialize a SpecDecodeWorker, using a draft model for proposals. 85 | """ 86 | assert self.speculative_config is not None 87 | 88 | from vllm.spec_decode.multi_step_worker import MultiStepWorker 89 | from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker 90 | from vllm.worker.worker import Worker 91 | 92 | distributed_init_method = get_distributed_init_method( 93 | get_ip(), get_open_port()) 94 | 95 | target_worker = Worker( 96 | model_config=self.model_config, 97 | parallel_config=self.parallel_config, 98 | scheduler_config=self.scheduler_config, 99 | device_config=self.device_config, 100 | cache_config=self.cache_config, 101 | load_config=self.load_config, 102 | local_rank=0, 103 | rank=0, 104 | distributed_init_method=distributed_init_method, 105 | lora_config=self.lora_config, 106 | vision_language_config=self.vision_language_config, 107 | is_driver_worker=True, 108 | ) 109 | 110 | draft_worker = MultiStepWorker( 111 | model_config=self.speculative_config.draft_model_config, 112 | parallel_config=self.speculative_config.draft_parallel_config, 113 | scheduler_config=self.scheduler_config, 114 | device_config=self.device_config, 115 | cache_config=self.cache_config, 116 | load_config=self.load_config, 117 | local_rank=0, 118 | rank=0, 119 | distributed_init_method=distributed_init_method, 120 | lora_config=self.lora_config, 121 | vision_language_config=self.vision_language_config, 122 | is_driver_worker=True, 123 | ) 124 | 125 | spec_decode_worker = SpecDecodeWorker.from_workers( 126 | proposer_worker=draft_worker, scorer_worker=target_worker) 127 | 128 | assert self.parallel_config.world_size == 1, ( 129 | "GPUExecutor only supports single GPU.") 130 | 131 | self.driver_worker = spec_decode_worker 132 | 133 | # Load model handled in spec decode worker. 134 | self.driver_worker.init_device() 135 | 136 | def determine_num_available_blocks(self) -> Tuple[int, int]: 137 | """Determine the number of available KV blocks by invoking the 138 | underlying worker. 139 | """ 140 | return self.driver_worker.determine_num_available_blocks() 141 | 142 | def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks) -> None: 143 | """Initialize the KV cache by invoking the underlying worker. 144 | """ 145 | # NOTE: This is logged in the executor because there can be >1 worker 146 | # with other executors. 
We could log in the engine level, but work 147 | # remains to abstract away the device for non-GPU configurations. 148 | logger.info(f"# GPU blocks: {num_gpu_blocks}, " 149 | f"# CPU blocks: {num_cpu_blocks}") 150 | 151 | self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) 152 | 153 | def execute_model( 154 | self, 155 | seq_group_metadata_list: List[SequenceGroupMetadata], 156 | blocks_to_swap_in: Dict[int, int], 157 | blocks_to_swap_out: Dict[int, int], 158 | blocks_to_copy: Dict[int, List[int]], 159 | num_lookahead_slots: int, 160 | ) -> List[SamplerOutput]: 161 | output = self.driver_worker.execute_model( 162 | seq_group_metadata_list=seq_group_metadata_list, 163 | blocks_to_swap_in=blocks_to_swap_in, 164 | blocks_to_swap_out=blocks_to_swap_out, 165 | blocks_to_copy=blocks_to_copy, 166 | num_lookahead_slots=num_lookahead_slots, 167 | ) 168 | return output 169 | 170 | def add_lora(self, lora_request: LoRARequest) -> bool: 171 | assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." 172 | return self.driver_worker.add_lora(lora_request) 173 | 174 | def remove_lora(self, lora_id: int) -> bool: 175 | assert lora_id > 0, "lora_id must be greater than 0." 176 | return self.driver_worker.remove_lora(lora_id) 177 | 178 | def list_loras(self) -> Set[int]: 179 | return self.driver_worker.list_loras() 180 | 181 | def check_health(self) -> None: 182 | # GPUExecutor will always be healthy as long as 183 | # it's running. 184 | return 185 | 186 | 187 | class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase): 188 | 189 | async def execute_model_async( 190 | self, 191 | seq_group_metadata_list: List[SequenceGroupMetadata], 192 | blocks_to_swap_in: Dict[int, int], 193 | blocks_to_swap_out: Dict[int, int], 194 | blocks_to_copy: Dict[int, List[int]], 195 | ) -> SamplerOutput: 196 | output = await make_async(self.driver_worker.execute_model)( 197 | seq_group_metadata_list=seq_group_metadata_list, 198 | blocks_to_swap_in=blocks_to_swap_in, 199 | blocks_to_swap_out=blocks_to_swap_out, 200 | blocks_to_copy=blocks_to_copy) 201 | return output -------------------------------------------------------------------------------- /autoregressive/serve/sample_c2i.py: -------------------------------------------------------------------------------- 1 | import time 2 | import argparse 3 | import torch 4 | from torchvision.utils import save_image 5 | 6 | from tokenizer.tokenizer_image.vq_model import VQ_models 7 | from autoregressive.serve.gpt_model import GPT_models 8 | from autoregressive.serve.llm import LLM 9 | from vllm import SamplingParams 10 | 11 | 12 | def main(args): 13 | # Setup PyTorch: 14 | torch.manual_seed(args.seed) 15 | torch.backends.cudnn.deterministic = True 16 | torch.backends.cudnn.benchmark = False 17 | torch.set_grad_enabled(False) 18 | device = "cuda" if torch.cuda.is_available() else "cpu" 19 | 20 | # create and load model 21 | vq_model = VQ_models[args.vq_model]( 22 | codebook_size=args.codebook_size, 23 | codebook_embed_dim=args.codebook_embed_dim) 24 | vq_model.to(device) 25 | vq_model.eval() 26 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu") 27 | vq_model.load_state_dict(checkpoint["model"]) 28 | del checkpoint 29 | print(f"image tokenizer is loaded") 30 | 31 | # Labels to condition the model with (feel free to change): 32 | class_labels = [207, 360, 387, 974, 88, 979, 417, 279] 33 | latent_size = args.image_size // args.downsample_size 34 | qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size] 35 | prompt_token_ids = 
[[cind] for cind in class_labels] 36 | if args.cfg_scale > 1.0: 37 | prompt_token_ids.extend([[args.num_classes] for _ in range(len(prompt_token_ids))]) 38 | # Create an LLM. 39 | llm = LLM( 40 | args=args, 41 | model='autoregressive/serve/fake_json/{}.json'.format(args.gpt_model), 42 | gpu_memory_utilization=0.9, 43 | skip_tokenizer_init=True) 44 | print(f"gpt model is loaded") 45 | 46 | # Create a sampling params object. 47 | sampling_params = SamplingParams( 48 | temperature=args.temperature, top_p=args.top_p, top_k=args.top_k, 49 | max_tokens=latent_size ** 2) 50 | 51 | # Generate texts from the prompts. The output is a list of RequestOutput objects 52 | # that contain the prompt, generated text, and other information. 53 | t1 = time.time() 54 | outputs = llm.generate( 55 | prompt_token_ids=prompt_token_ids, 56 | sampling_params=sampling_params, 57 | use_tqdm=False) 58 | sampling_time = time.time() - t1 59 | print(f"gpt sampling takes about {sampling_time:.2f} seconds.") 60 | 61 | # decode to image 62 | index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device) 63 | if args.cfg_scale > 1.0: 64 | index_sample = index_sample[:len(class_labels)] 65 | t2 = time.time() 66 | samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1] 67 | decoder_time = time.time() - t2 68 | print(f"decoder takes about {decoder_time:.2f} seconds.") 69 | 70 | # Save and display images: 71 | save_image(samples, "sample_{}_vllm.png".format(args.gpt_type), nrow=4, normalize=True, value_range=(-1, 1)) 72 | print(f"image is saved to sample_{args.gpt_type}_vllm.png") 73 | 74 | 75 | if __name__ == '__main__': 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B") 78 | parser.add_argument("--gpt-ckpt", type=str, required=True, help="ckpt path for gpt model") 79 | parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional") 80 | parser.add_argument("--from-fsdp", action='store_true') 81 | parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input") 82 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 83 | parser.add_argument("--compile", action='store_true', default=False) 84 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") 85 | parser.add_argument("--vq-ckpt", type=str, required=True, help="ckpt path for vq model") 86 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") 87 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") 88 | parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384) 89 | parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) 90 | parser.add_argument("--num-classes", type=int, default=1000) 91 | parser.add_argument("--cfg-scale", type=float, default=4.0) 92 | parser.add_argument("--seed", type=int, default=0) 93 | parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with") 94 | parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with") 95 | parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with") 96 | args = parser.parse_args() 97 | main(args) 98 | 
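A small illustration (not part of the repo) of how `autoregressive/serve/sample_c2i.py` above lays out its prompt batch for classifier-free guidance: every class label becomes a one-token prompt, and when `--cfg-scale > 1.0` an equal number of unconditional prompts using the `--num-classes` id is appended; after generation only the conditional half is decoded into images. The label values below are arbitrary examples.

```
# Sketch of the CFG prompt layout used in sample_c2i.py (illustrative values only).
class_labels = [207, 360]                      # two conditional samples
num_classes = 1000                             # default --num-classes, reused as the "empty" class id
cfg_scale = 4.0

prompt_token_ids = [[c] for c in class_labels]                                   # [[207], [360]]
if cfg_scale > 1.0:
    prompt_token_ids.extend([[num_classes] for _ in range(len(prompt_token_ids))])
print(prompt_token_ids)                        # [[207], [360], [1000], [1000]]
# After llm.generate(...), sample_c2i.py keeps only the first len(class_labels)
# sequences (the conditional ones) before calling vq_model.decode_code(...).
```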
-------------------------------------------------------------------------------- /autoregressive/train/extract_codes_c2i.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/extract_features.py 3 | import torch 4 | torch.backends.cuda.matmul.allow_tf32 = True 5 | torch.backends.cudnn.allow_tf32 = True 6 | import torch.distributed as dist 7 | from torch.utils.data import DataLoader 8 | from torch.utils.data.distributed import DistributedSampler 9 | from torchvision import transforms 10 | import numpy as np 11 | import argparse 12 | import os 13 | 14 | from utils.distributed import init_distributed_mode 15 | from dataset.augmentation import center_crop_arr 16 | from dataset.build import build_dataset 17 | from tokenizer.tokenizer_image.vq_model import VQ_models 18 | 19 | 20 | ################################################################################# 21 | # Training Loop # 22 | ################################################################################# 23 | def main(args): 24 | assert torch.cuda.is_available(), "Training currently requires at least one GPU." 25 | # Setup DDP: 26 | if not args.debug: 27 | init_distributed_mode(args) 28 | rank = dist.get_rank() 29 | device = rank % torch.cuda.device_count() 30 | seed = args.global_seed * dist.get_world_size() + rank 31 | torch.manual_seed(seed) 32 | torch.cuda.set_device(device) 33 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 34 | else: 35 | device = 'cuda' 36 | rank = 0 37 | 38 | # Setup a feature folder: 39 | if args.debug or rank == 0: 40 | os.makedirs(args.code_path, exist_ok=True) 41 | os.makedirs(os.path.join(args.code_path, f'{args.dataset}{args.image_size}_codes'), exist_ok=True) 42 | os.makedirs(os.path.join(args.code_path, f'{args.dataset}{args.image_size}_labels'), exist_ok=True) 43 | 44 | # create and load model 45 | vq_model = VQ_models[args.vq_model]( 46 | codebook_size=args.codebook_size, 47 | codebook_embed_dim=args.codebook_embed_dim) 48 | vq_model.to(device) 49 | vq_model.eval() 50 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu") 51 | vq_model.load_state_dict(checkpoint["model"]) 52 | del checkpoint 53 | 54 | # Setup data: 55 | if args.ten_crop: 56 | crop_size = int(args.image_size * args.crop_range) 57 | transform = transforms.Compose([ 58 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, crop_size)), 59 | transforms.TenCrop(args.image_size), # this is a tuple of PIL Images 60 | transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), # returns a 4D tensor 61 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 62 | ]) 63 | else: 64 | crop_size = args.image_size 65 | transform = transforms.Compose([ 66 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, crop_size)), 67 | transforms.ToTensor(), 68 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 69 | ]) 70 | dataset = build_dataset(args, transform=transform) 71 | if not args.debug: 72 | sampler = DistributedSampler( 73 | dataset, 74 | num_replicas=dist.get_world_size(), 75 | rank=rank, 76 | shuffle=False, 77 | seed=args.global_seed 78 | ) 79 | else: 80 | sampler = None 81 | loader = DataLoader( 82 | dataset, 83 | batch_size=1, # important! 
84 | shuffle=False, 85 | sampler=sampler, 86 | num_workers=args.num_workers, 87 | pin_memory=True, 88 | drop_last=False 89 | ) 90 | 91 | total = 0 92 | for x, y in loader: 93 | x = x.to(device) 94 | if args.ten_crop: 95 | x_all = x.flatten(0, 1) 96 | num_aug = 10 97 | else: 98 | x_flip = torch.flip(x, dims=[-1]) 99 | x_all = torch.cat([x, x_flip]) 100 | num_aug = 2 101 | y = y.to(device) 102 | with torch.no_grad(): 103 | _, _, [_, _, indices] = vq_model.encode(x_all) 104 | codes = indices.reshape(x.shape[0], num_aug, -1) 105 | 106 | x = codes.detach().cpu().numpy() # (1, num_aug, args.image_size//16 * args.image_size//16) 107 | train_steps = rank + total 108 | np.save(f'{args.code_path}/{args.dataset}{args.image_size}_codes/{train_steps}.npy', x) 109 | 110 | y = y.detach().cpu().numpy() # (1,) 111 | np.save(f'{args.code_path}/{args.dataset}{args.image_size}_labels/{train_steps}.npy', y) 112 | if not args.debug: 113 | total += dist.get_world_size() 114 | else: 115 | total += 1 116 | print(total) 117 | 118 | dist.destroy_process_group() 119 | 120 | 121 | if __name__ == "__main__": 122 | parser = argparse.ArgumentParser() 123 | parser.add_argument("--data-path", type=str, required=True) 124 | parser.add_argument("--code-path", type=str, required=True) 125 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") 126 | parser.add_argument("--vq-ckpt", type=str, required=True, help="ckpt path for vq model") 127 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") 128 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") 129 | parser.add_argument("--dataset", type=str, default='imagenet') 130 | parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512], default=256) 131 | parser.add_argument("--ten-crop", action='store_true', help="whether using random crop") 132 | parser.add_argument("--crop-range", type=float, default=1.1, help="expanding range of center crop") 133 | parser.add_argument("--global-seed", type=int, default=0) 134 | parser.add_argument("--num-workers", type=int, default=24) 135 | parser.add_argument("--debug", action='store_true') 136 | args = parser.parse_args() 137 | main(args) 138 | -------------------------------------------------------------------------------- /autoregressive/train/extract_codes_t2i.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/extract_features.py 3 | import torch 4 | torch.backends.cuda.matmul.allow_tf32 = True 5 | torch.backends.cudnn.allow_tf32 = True 6 | import torch.distributed as dist 7 | from torch.utils.data import Dataset, DataLoader 8 | from torch.utils.data.distributed import DistributedSampler 9 | from torchvision import transforms 10 | import numpy as np 11 | from PIL import Image 12 | import glob 13 | import argparse 14 | import os 15 | import json 16 | 17 | from utils.distributed import init_distributed_mode 18 | from dataset.augmentation import center_crop_arr 19 | from tokenizer.tokenizer_image.vq_model import VQ_models 20 | 21 | 22 | ################################################################################# 23 | # Training Helper Functions # 24 | ################################################################################# 25 | class CustomDataset(Dataset): 26 | def __init__(self, lst_dir, start, end, transform): 27 | img_path_list 
= [] 28 | for lst_name in sorted(os.listdir(lst_dir))[start: end+1]: 29 | if not lst_name.endswith('.jsonl'): 30 | continue 31 | file_path = os.path.join(lst_dir, lst_name) 32 | with open(file_path, 'r') as file: 33 | for line_idx, line in enumerate(file): 34 | data = json.loads(line) 35 | img_path = data['image_path'] 36 | code_dir = file_path.split('/')[-1].split('.')[0] 37 | img_path_list.append((img_path, code_dir, line_idx)) 38 | self.img_path_list = img_path_list 39 | self.transform = transform 40 | 41 | def __len__(self): 42 | return len(self.img_path_list) 43 | 44 | def __getitem__(self, index): 45 | img_path, code_dir, code_name = self.img_path_list[index] 46 | img = Image.open(img_path).convert("RGB") 47 | if self.transform is not None: 48 | img = self.transform(img) 49 | return img, code_dir, code_name 50 | 51 | 52 | 53 | ################################################################################# 54 | # Training Loop # 55 | ################################################################################# 56 | def main(args): 57 | """ 58 | Extracts VQ image codes from t2i training images. 59 | """ 60 | assert torch.cuda.is_available(), "This script currently requires at least one GPU." 61 | 62 | # Setup DDP: 63 | # dist.init_process_group("nccl") 64 | init_distributed_mode(args) 65 | rank = dist.get_rank() 66 | device = rank % torch.cuda.device_count() 67 | seed = args.global_seed * dist.get_world_size() + rank 68 | torch.manual_seed(seed) 69 | torch.cuda.set_device(device) 70 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 71 | 72 | # Setup a feature folder: 73 | if rank == 0: 74 | os.makedirs(args.code_path, exist_ok=True) 75 | 76 | 77 | # create and load model 78 | vq_model = VQ_models[args.vq_model]( 79 | codebook_size=args.codebook_size, 80 | codebook_embed_dim=args.codebook_embed_dim) 81 | vq_model.to(device) 82 | vq_model.eval() 83 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu") 84 | vq_model.load_state_dict(checkpoint["model"]) 85 | del checkpoint 86 | 87 | 88 | # Setup data: 89 | transform = transforms.Compose([ 90 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), 91 | transforms.ToTensor(), 92 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 93 | ]) 94 | print("Preparing dataset...") 95 | dataset = CustomDataset(args.data_path, args.data_start, args.data_end, transform=transform) 96 | sampler = DistributedSampler( 97 | dataset, 98 | num_replicas=dist.get_world_size(), 99 | rank=rank, 100 | shuffle=False, 101 | seed=args.global_seed 102 | ) 103 | loader = DataLoader( 104 | dataset, 105 | batch_size=1, # important! 
106 | shuffle=False, 107 | sampler=sampler, 108 | num_workers=args.num_workers, 109 | pin_memory=True, 110 | drop_last=False 111 | ) 112 | print(f"Dataset contains {len(dataset):,} images") 113 | 114 | # total = 0 115 | for img, code_dir, code_name in loader: 116 | img = img.to(device) 117 | with torch.no_grad(): 118 | _, _, [_, _, indices] = vq_model.encode(img) 119 | codes = indices.reshape(img.shape[0], -1) 120 | x = codes.detach().cpu().numpy() # (1, args.image_size//16 * args.image_size//16) 121 | os.makedirs(os.path.join(args.code_path, code_dir[0]), exist_ok=True) 122 | np.save(os.path.join(args.code_path, code_dir[0], '{}.npy'.format(code_name.item())), x) 123 | 124 | # total += dist.get_world_size() 125 | print(code_name.item()) 126 | 127 | dist.destroy_process_group() 128 | 129 | 130 | if __name__ == "__main__": 131 | parser = argparse.ArgumentParser() 132 | parser.add_argument("--data-path", type=str, required=True) 133 | parser.add_argument("--code-path", type=str, required=True) 134 | parser.add_argument("--data-start", type=int, required=True) 135 | parser.add_argument("--data-end", type=int, required=True) 136 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") 137 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") 138 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") 139 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") 140 | parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512], default=512) 141 | parser.add_argument("--global-seed", type=int, default=0) 142 | parser.add_argument("--num-workers", type=int, default=24) 143 | args = parser.parse_args() 144 | main(args) 145 | -------------------------------------------------------------------------------- /dataset/augmentation.py: -------------------------------------------------------------------------------- 1 | # from https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py 2 | import math 3 | import random 4 | import numpy as np 5 | from PIL import Image 6 | 7 | 8 | def center_crop_arr(pil_image, image_size): 9 | """ 10 | Center cropping implementation from ADM. 11 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 12 | """ 13 | while min(*pil_image.size) >= 2 * image_size: 14 | pil_image = pil_image.resize( 15 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 16 | ) 17 | 18 | scale = image_size / min(*pil_image.size) 19 | pil_image = pil_image.resize( 20 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 21 | ) 22 | 23 | arr = np.array(pil_image) 24 | crop_y = (arr.shape[0] - image_size) // 2 25 | crop_x = (arr.shape[1] - image_size) // 2 26 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 27 | 28 | 29 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0): 30 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac) 31 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac) 32 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1) 33 | 34 | # We are not on a new enough PIL to support the `reducing_gap` 35 | # argument, which uses BOX downsampling at powers of two first. 
36 | # Thus, we do it by hand to improve downsample quality. 37 | while min(*pil_image.size) >= 2 * smaller_dim_size: 38 | pil_image = pil_image.resize( 39 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 40 | ) 41 | 42 | scale = smaller_dim_size / min(*pil_image.size) 43 | pil_image = pil_image.resize( 44 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 45 | ) 46 | 47 | arr = np.array(pil_image) 48 | crop_y = random.randrange(arr.shape[0] - image_size + 1) 49 | crop_x = random.randrange(arr.shape[1] - image_size + 1) 50 | return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size]) 51 | 52 | -------------------------------------------------------------------------------- /dataset/build.py: -------------------------------------------------------------------------------- 1 | from dataset.imagenet import build_imagenet, build_imagenet_code 2 | from dataset.coco import build_coco 3 | from dataset.openimage import build_openimage 4 | from dataset.pexels import build_pexels 5 | from dataset.t2i import build_t2i, build_t2i_code, build_t2i_image 6 | 7 | 8 | def build_dataset(args, **kwargs): 9 | # images 10 | if args.dataset == 'imagenet': 11 | return build_imagenet(args, **kwargs) 12 | if args.dataset == 'imagenet_code': 13 | return build_imagenet_code(args, **kwargs) 14 | if args.dataset == 'coco': 15 | return build_coco(args, **kwargs) 16 | if args.dataset == 'openimage': 17 | return build_openimage(args, **kwargs) 18 | if args.dataset == 'pexels': 19 | return build_pexels(args, **kwargs) 20 | if args.dataset == 't2i_image': 21 | return build_t2i_image(args, **kwargs) 22 | if args.dataset == 't2i': 23 | return build_t2i(args, **kwargs) 24 | if args.dataset == 't2i_code': 25 | return build_t2i_code(args, **kwargs) 26 | 27 | raise ValueError(f'dataset {args.dataset} is not supported') -------------------------------------------------------------------------------- /dataset/coco.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.utils.data import Dataset 4 | from PIL import Image 5 | 6 | 7 | class SingleFolderDataset(Dataset): 8 | def __init__(self, directory, transform=None): 9 | super().__init__() 10 | self.directory = directory 11 | self.transform = transform 12 | self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory) 13 | if os.path.isfile(os.path.join(directory, file_name))] 14 | 15 | def __len__(self): 16 | return len(self.image_paths) 17 | 18 | def __getitem__(self, idx): 19 | image_path = self.image_paths[idx] 20 | image = Image.open(image_path).convert('RGB') 21 | if self.transform: 22 | image = self.transform(image) 23 | return image, torch.tensor(0) 24 | 25 | 26 | def build_coco(args, transform): 27 | return SingleFolderDataset(args.data_path, transform=transform) -------------------------------------------------------------------------------- /dataset/imagenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import os 4 | from torch.utils.data import Dataset 5 | from torchvision.datasets import ImageFolder 6 | 7 | 8 | class CustomDataset(Dataset): 9 | def __init__(self, feature_dir, label_dir): 10 | self.feature_dir = feature_dir 11 | self.label_dir = label_dir 12 | self.flip = 'flip' in self.feature_dir 13 | 14 | aug_feature_dir = feature_dir.replace('ten_crop/', 'ten_crop_105/') 15 | aug_label_dir = 
label_dir.replace('ten_crop/', 'ten_crop_105/') 16 | if os.path.exists(aug_feature_dir) and os.path.exists(aug_label_dir): 17 | self.aug_feature_dir = aug_feature_dir 18 | self.aug_label_dir = aug_label_dir 19 | else: 20 | self.aug_feature_dir = None 21 | self.aug_label_dir = None 22 | 23 | # self.feature_files = sorted(os.listdir(feature_dir)) 24 | # self.label_files = sorted(os.listdir(label_dir)) 25 | # TODO: make it configurable 26 | self.feature_files = [f"{i}.npy" for i in range(1281167)] 27 | self.label_files = [f"{i}.npy" for i in range(1281167)] 28 | 29 | def __len__(self): 30 | assert len(self.feature_files) == len(self.label_files), \ 31 | "Number of feature files and label files should be same" 32 | return len(self.feature_files) 33 | 34 | def __getitem__(self, idx): 35 | if self.aug_feature_dir is not None and torch.rand(1) < 0.5: 36 | feature_dir = self.aug_feature_dir 37 | label_dir = self.aug_label_dir 38 | else: 39 | feature_dir = self.feature_dir 40 | label_dir = self.label_dir 41 | 42 | feature_file = self.feature_files[idx] 43 | label_file = self.label_files[idx] 44 | 45 | features = np.load(os.path.join(feature_dir, feature_file)) 46 | if self.flip: 47 | aug_idx = torch.randint(low=0, high=features.shape[1], size=(1,)).item() 48 | features = features[:, aug_idx] 49 | labels = np.load(os.path.join(label_dir, label_file)) 50 | return torch.from_numpy(features), torch.from_numpy(labels) 51 | 52 | 53 | def build_imagenet(args, transform): 54 | return ImageFolder(args.data_path, transform=transform) 55 | 56 | def build_imagenet_code(args): 57 | feature_dir = f"{args.code_path}/imagenet{args.image_size}_codes" 58 | label_dir = f"{args.code_path}/imagenet{args.image_size}_labels" 59 | assert os.path.exists(feature_dir) and os.path.exists(label_dir), \ 60 | f"please first run: bash scripts/autoregressive/extract_codes_c2i.sh ..." 
61 | return CustomDataset(feature_dir, label_dir) -------------------------------------------------------------------------------- /dataset/openimage.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | from PIL import Image 5 | 6 | import torch 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class DatasetJson(Dataset): 11 | def __init__(self, data_path, transform=None): 12 | super().__init__() 13 | self.data_path = data_path 14 | self.transform = transform 15 | json_path = os.path.join(data_path, 'image_paths.json') 16 | assert os.path.exists(json_path), f"please first run: python3 tools/openimage_json.py" 17 | with open(json_path, 'r') as f: 18 | self.image_paths = json.load(f) 19 | 20 | def __len__(self): 21 | return len(self.image_paths) 22 | 23 | def __getitem__(self, idx): 24 | for _ in range(20): 25 | try: 26 | return self.getdata(idx) 27 | except Exception as e: 28 | print(f"Error details: {str(e)}") 29 | idx = np.random.randint(len(self)) 30 | raise RuntimeError('Too many bad data.') 31 | 32 | def getdata(self, idx): 33 | image_path = self.image_paths[idx] 34 | image_path_full = os.path.join(self.data_path, image_path) 35 | image = Image.open(image_path_full).convert('RGB') 36 | if self.transform: 37 | image = self.transform(image) 38 | return image, torch.tensor(0) 39 | 40 | 41 | def build_openimage(args, transform): 42 | return DatasetJson(args.data_path, transform=transform) 43 | -------------------------------------------------------------------------------- /dataset/pexels.py: -------------------------------------------------------------------------------- 1 | from torchvision.datasets import ImageFolder 2 | 3 | def build_pexels(args, transform): 4 | return ImageFolder(args.data_path, transform=transform) -------------------------------------------------------------------------------- /dataset/t2i.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | 5 | import torch 6 | from torch.utils.data import Dataset 7 | from PIL import Image 8 | 9 | 10 | class Text2ImgDatasetImg(Dataset): 11 | def __init__(self, lst_dir, face_lst_dir, transform): 12 | img_path_list = [] 13 | valid_file_path = [] 14 | # collect valid jsonl 15 | for lst_name in sorted(os.listdir(lst_dir)): 16 | if not lst_name.endswith('.jsonl'): 17 | continue 18 | file_path = os.path.join(lst_dir, lst_name) 19 | valid_file_path.append(file_path) 20 | 21 | # collect valid jsonl for face 22 | if face_lst_dir is not None: 23 | for lst_name in sorted(os.listdir(face_lst_dir)): 24 | if not lst_name.endswith('_face.jsonl'): 25 | continue 26 | file_path = os.path.join(face_lst_dir, lst_name) 27 | valid_file_path.append(file_path) 28 | 29 | for file_path in valid_file_path: 30 | with open(file_path, 'r') as file: 31 | for line_idx, line in enumerate(file): 32 | data = json.loads(line) 33 | img_path = data['image_path'] 34 | code_dir = file_path.split('/')[-1].split('.')[0] 35 | img_path_list.append((img_path, code_dir, line_idx)) 36 | self.img_path_list = img_path_list 37 | self.transform = transform 38 | 39 | def __len__(self): 40 | return len(self.img_path_list) 41 | 42 | def __getitem__(self, index): 43 | img_path, code_dir, code_name = self.img_path_list[index] 44 | img = Image.open(img_path).convert("RGB") 45 | if self.transform is not None: 46 | img = self.transform(img) 47 | return img, code_name 48 | 49 | 50 | class 
Text2ImgDataset(Dataset): 51 | def __init__(self, args, transform): 52 | img_path_list = [] 53 | valid_file_path = [] 54 | # collect valid jsonl file path 55 | for lst_name in sorted(os.listdir(args.data_path)): 56 | if not lst_name.endswith('.jsonl'): 57 | continue 58 | file_path = os.path.join(args.data_path, lst_name) 59 | valid_file_path.append(file_path) 60 | 61 | for file_path in valid_file_path: 62 | with open(file_path, 'r') as file: 63 | for line_idx, line in enumerate(file): 64 | data = json.loads(line) 65 | img_path = data['image_path'] 66 | code_dir = file_path.split('/')[-1].split('.')[0] 67 | img_path_list.append((img_path, code_dir, line_idx)) 68 | self.img_path_list = img_path_list 69 | self.transform = transform 70 | 71 | self.t5_feat_path = args.t5_feat_path 72 | self.short_t5_feat_path = args.short_t5_feat_path 73 | self.t5_feat_path_base = self.t5_feat_path.split('/')[-1] 74 | if self.short_t5_feat_path is not None: 75 | self.short_t5_feat_path_base = self.short_t5_feat_path.split('/')[-1] 76 | else: 77 | self.short_t5_feat_path_base = self.t5_feat_path_base 78 | self.image_size = args.image_size 79 | latent_size = args.image_size // args.downsample_size 80 | self.code_len = latent_size ** 2 81 | self.t5_feature_max_len = 120 82 | self.t5_feature_dim = 2048 83 | self.max_seq_length = self.t5_feature_max_len + self.code_len 84 | 85 | def __len__(self): 86 | return len(self.img_path_list) 87 | 88 | def dummy_data(self): 89 | img = torch.zeros((3, self.image_size, self.image_size), dtype=torch.float32) 90 | t5_feat_padding = torch.zeros((1, self.t5_feature_max_len, self.t5_feature_dim)) 91 | attn_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).unsqueeze(0) 92 | valid = 0 93 | return img, t5_feat_padding, attn_mask, valid 94 | 95 | def __getitem__(self, index): 96 | img_path, code_dir, code_name = self.img_path_list[index] 97 | try: 98 | img = Image.open(img_path).convert("RGB") 99 | except: 100 | img, t5_feat_padding, attn_mask, valid = self.dummy_data() 101 | return img, t5_feat_padding, attn_mask, torch.tensor(valid) 102 | 103 | if min(img.size) < self.image_size: 104 | img, t5_feat_padding, attn_mask, valid = self.dummy_data() 105 | return img, t5_feat_padding, attn_mask, torch.tensor(valid) 106 | 107 | if self.transform is not None: 108 | img = self.transform(img) 109 | 110 | t5_file = os.path.join(self.t5_feat_path, code_dir, f"{code_name}.npy") 111 | if torch.rand(1) < 0.3: 112 | t5_file = t5_file.replace(self.t5_feat_path_base, self.short_t5_feat_path_base) 113 | 114 | t5_feat_padding = torch.zeros((1, self.t5_feature_max_len, self.t5_feature_dim)) 115 | if os.path.isfile(t5_file): 116 | try: 117 | t5_feat = torch.from_numpy(np.load(t5_file)) 118 | t5_feat_len = t5_feat.shape[1] 119 | feat_len = min(self.t5_feature_max_len, t5_feat_len) 120 | t5_feat_padding[:, -feat_len:] = t5_feat[:, :feat_len] 121 | emb_mask = torch.zeros((self.t5_feature_max_len,)) 122 | emb_mask[-feat_len:] = 1 123 | attn_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length)) 124 | T = self.t5_feature_max_len 125 | attn_mask[:, :T] = attn_mask[:, :T] * emb_mask.unsqueeze(0) 126 | eye_matrix = torch.eye(self.max_seq_length, self.max_seq_length) 127 | attn_mask = attn_mask * (1 - eye_matrix) + eye_matrix 128 | attn_mask = attn_mask.unsqueeze(0).to(torch.bool) 129 | valid = 1 130 | except: 131 | img, t5_feat_padding, attn_mask, valid = self.dummy_data() 132 | else: 133 | img, t5_feat_padding, attn_mask, valid = self.dummy_data() 134 | 135 
| return img, t5_feat_padding, attn_mask, torch.tensor(valid) 136 | 137 | 138 | class Text2ImgDatasetCode(Dataset): 139 | def __init__(self, args): 140 | pass 141 | 142 | 143 | 144 | 145 | def build_t2i_image(args, transform): 146 | return Text2ImgDatasetImg(args.data_path, args.data_face_path, transform) 147 | 148 | def build_t2i(args, transform): 149 | return Text2ImgDataset(args, transform) 150 | 151 | def build_t2i_code(args): 152 | return Text2ImgDatasetCode(args) -------------------------------------------------------------------------------- /evaluations/c2i/README.md: -------------------------------------------------------------------------------- 1 | # Evaluations from [OpenAI](https://github.com/openai/guided-diffusion/tree/main/evaluations) 2 | 3 | To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files. 4 | 5 | # Installation 6 | ### CUDA version 11.7 7 | ``` 8 | pip install tensorflow-gpu==2.5.0 9 | pip install numpy==1.22.0 10 | pip install scipy 11 | pip install pydantic 12 | ``` 13 | An error like `tensorflow.python.framework.errors_impl.NotFoundError: /usr/local/lib/python3.9/dist-packages/tensorflow/core/kernels/libtfkernel_sobol_op.so: undefined symbol: _ZN10tensorflow...` may occur; deleting `/usr/local/lib/python3.9/dist-packages/tensorflow/core/kernels/libtfkernel_sobol_op.so` fixes it. 14 | 15 | ### CUDA version 12.1 16 | ``` 17 | pip install tensorflow 18 | pip install numpy==1.23.5 19 | pip install scipy 20 | ``` 21 | 22 | ### H100, CUDA version 12.2 23 | ``` 24 | pip install tensorflow 25 | pip install numpy==1.26.2 26 | pip install scipy 27 | ``` 28 | 29 | # Download batches 30 | 31 | We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format. 32 | 33 | Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall. 
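As a quick sanity check after downloading, a batch can be inspected with NumPy before running any evaluation. The snippet below is only an illustrative sketch: the file name is one of the reference batches linked below, and the exact array names stored in each `.npz` differ between reference and sample batches.

```
# Hypothetical inspection of a downloaded batch; array names vary per file.
import numpy as np

batch = np.load("VIRTUAL_imagenet256_labeled.npz")
for name in batch.files:                      # list every array stored in the archive
    print(name, batch[name].shape, batch[name].dtype)
```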
34 | 35 | Here are links to download all of the sample and reference batches: 36 | 37 | * LSUN 38 | * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz) 39 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz) 40 | * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz) 41 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz) 42 | * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz) 43 | * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz) 44 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz) 45 | * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz) 46 | * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz) 47 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz) 48 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz) 49 | 50 | * ImageNet 51 | * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz) 52 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz) 53 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz) 54 | * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz) 55 | * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz) 56 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz) 57 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz) 58 | * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz) 59 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz) 60 | * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz) 61 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz) 62 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz) 63 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz) 64 | * [ADM-G + 
ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz) 65 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz) 66 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz) 67 | * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz) 68 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz) 69 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz) 70 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz) 71 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz) 72 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz) 73 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz) 74 | 75 | # Run evaluations 76 | 77 | First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the reference batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`. 78 | 79 | Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB. 80 | 81 | The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging: 82 | 83 | ``` 84 | $ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz 85 | ... 86 | computing reference batch activations... 87 | computing/reading reference batch statistics... 88 | computing sample batch activations... 89 | computing/reading sample batch statistics... 90 | Computing evaluations... 
91 | Inception Score: 215.8370361328125 92 | FID: 3.9425574129223264 93 | sFID: 6.140433703346162 94 | Precision: 0.8265 95 | Recall: 0.5309 96 | ``` 97 | -------------------------------------------------------------------------------- /evaluations/t2i/README.md: -------------------------------------------------------------------------------- 1 | # Evaluations from [GigaGAN](https://github.com/mingukkang/GigaGAN/tree/main/evaluation) 2 | 3 | ``` 4 | pip install git+https://github.com/openai/CLIP.git 5 | pip install open_clip_torch 6 | pip install clean_fid 7 | ``` 8 | 9 | ``` 10 | python3 evaluations/t2i/evaluation.py \ 11 | --eval_res 256 \ 12 | --batch_size 256 \ 13 | --how_many 30000 \ 14 | --ref_data "coco2014" \ 15 | --ref_type "val2014" \ 16 | --eval_res 256 \ 17 | --batch_size 256 \ 18 | --ref_dir "/path/to/coco" \ 19 | --fake_dir "/path/to/generation" \ 20 | $@ 21 | ``` 22 | -------------------------------------------------------------------------------- /language/README.md: -------------------------------------------------------------------------------- 1 | ## Language models for text-conditional image generation 2 | 3 | ### Requirements 4 | ``` 5 | pip install ftfy 6 | pip install transformers 7 | pip install accelerate 8 | pip install sentencepiece 9 | pip install pandas 10 | pip install bs4 11 | ``` 12 | 13 | ### Language Models 14 | Download flan-t5-xl models from [flan-t5-xl](https://huggingface.co/google/flan-t5-xl) and put into the folder of `./pretrained_models/t5-ckpt/` 15 | -------------------------------------------------------------------------------- /language/extract_t5_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | torch.backends.cudnn.allow_tf32 = True 4 | import torch.distributed as dist 5 | from torch.utils.data import Dataset, DataLoader 6 | from torch.utils.data.distributed import DistributedSampler 7 | import numpy as np 8 | import argparse 9 | import os 10 | import json 11 | 12 | from utils.distributed import init_distributed_mode 13 | from language.t5 import T5Embedder 14 | 15 | CAPTION_KEY = { 16 | 'blip': 0, 17 | 'llava': 1, 18 | 'llava_first': 2, 19 | } 20 | ################################################################################# 21 | # Training Helper Functions # 22 | ################################################################################# 23 | class CustomDataset(Dataset): 24 | def __init__(self, lst_dir, start, end, caption_key, trunc_caption=False): 25 | img_path_list = [] 26 | for lst_name in sorted(os.listdir(lst_dir))[start: end+1]: 27 | if not lst_name.endswith('.jsonl'): 28 | continue 29 | file_path = os.path.join(lst_dir, lst_name) 30 | with open(file_path, 'r') as file: 31 | for line_idx, line in enumerate(file): 32 | data = json.loads(line) 33 | # caption = data[caption_key] 34 | caption = data['text'][CAPTION_KEY[caption_key]] 35 | code_dir = file_path.split('/')[-1].split('.')[0] 36 | if trunc_caption: 37 | caption = caption.split('.')[0] 38 | img_path_list.append((caption, code_dir, line_idx)) 39 | self.img_path_list = img_path_list 40 | 41 | def __len__(self): 42 | return len(self.img_path_list) 43 | 44 | def __getitem__(self, index): 45 | caption, code_dir, code_name = self.img_path_list[index] 46 | return caption, code_dir, code_name 47 | 48 | 49 | 50 | ################################################################################# 51 | # Training Loop # 52 | 
################################################################################# 53 | def main(args): 54 | """ 55 | Extracts Flan-T5 text features for t2i training. 56 | """ 57 | assert torch.cuda.is_available(), "This script currently requires at least one GPU." 58 | 59 | # Setup DDP: 60 | # dist.init_process_group("nccl") 61 | init_distributed_mode(args) 62 | rank = dist.get_rank() 63 | device = rank % torch.cuda.device_count() 64 | seed = args.global_seed * dist.get_world_size() + rank 65 | torch.manual_seed(seed) 66 | torch.cuda.set_device(device) 67 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 68 | 69 | # Setup a feature folder: 70 | if rank == 0: 71 | os.makedirs(args.t5_path, exist_ok=True) 72 | 73 | # Setup data: 74 | print("Preparing dataset...") 75 | dataset = CustomDataset(args.data_path, args.data_start, args.data_end, args.caption_key, args.trunc_caption) 76 | sampler = DistributedSampler( 77 | dataset, 78 | num_replicas=dist.get_world_size(), 79 | rank=rank, 80 | shuffle=False, 81 | seed=args.global_seed 82 | ) 83 | loader = DataLoader( 84 | dataset, 85 | batch_size=1, # important! 86 | shuffle=False, 87 | sampler=sampler, 88 | num_workers=args.num_workers, 89 | pin_memory=True, 90 | drop_last=False 91 | ) 92 | print(f"Dataset contains {len(dataset):,} images") 93 | 94 | precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] 95 | assert os.path.exists(args.t5_model_path) 96 | t5_xxl = T5Embedder( 97 | device=device, 98 | local_cache=True, 99 | cache_dir=args.t5_model_path, 100 | dir_or_name=args.t5_model_type, 101 | torch_dtype=precision 102 | ) 103 | 104 | for caption, code_dir, code_name in loader: 105 | caption_embs, emb_masks = t5_xxl.get_text_embeddings(caption) 106 | valid_caption_embs = caption_embs[:, :emb_masks.sum()] 107 | x = valid_caption_embs.to(torch.float32).detach().cpu().numpy() 108 | os.makedirs(os.path.join(args.t5_path, code_dir[0]), exist_ok=True) 109 | np.save(os.path.join(args.t5_path, code_dir[0], '{}.npy'.format(code_name.item())), x) 110 | print(code_name.item()) 111 | 112 | dist.destroy_process_group() 113 | 114 | 115 | if __name__ == "__main__": 116 | parser = argparse.ArgumentParser() 117 | parser.add_argument("--data-path", type=str, required=True) 118 | parser.add_argument("--t5-path", type=str, required=True) 119 | parser.add_argument("--data-start", type=int, required=True) 120 | parser.add_argument("--data-end", type=int, required=True) 121 | parser.add_argument("--caption-key", type=str, default='blip', choices=list(CAPTION_KEY.keys())) 122 | parser.add_argument("--trunc-caption", action='store_true', default=False) 123 | parser.add_argument("--t5-model-path", type=str, default='./pretrained_models/t5-ckpt') 124 | parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl') 125 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 126 | parser.add_argument("--global-seed", type=int, default=0) 127 | parser.add_argument("--num-workers", type=int, default=24) 128 | args = parser.parse_args() 129 | main(args) 130 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=2.1.0 2 | -------------------------------------------------------------------------------- /scripts/autoregressive/extract_codes_c2i.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 
2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12335 \ 7 | autoregressive/train/extract_codes_c2i.py "$@" 8 | -------------------------------------------------------------------------------- /scripts/autoregressive/sample_c2i.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12345 \ 7 | autoregressive/sample/sample_c2i_ddp.py \ 8 | --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt \ 9 | "$@" 10 | -------------------------------------------------------------------------------- /scripts/autoregressive/sample_t2i_coco.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12346 \ 7 | autoregressive/sample/sample_t2i_ddp.py \ 8 | --prompt-csv evaluations/t2i/coco_captions.csv \ 9 | --sample-dir samples_coco \ 10 | --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt \ 11 | "$@" 12 | -------------------------------------------------------------------------------- /scripts/autoregressive/sample_t2i_parti.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12347 \ 7 | autoregressive/sample/sample_t2i_ddp.py \ 8 | --prompt-csv evaluations/t2i/PartiPrompts.tsv \ 9 | --sample-dir samples_parti \ 10 | --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt \ 11 | "$@" 12 | -------------------------------------------------------------------------------- /scripts/autoregressive/train_c2i.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \ 6 | --master_addr=$master_addr --master_port=$master_port \ 7 | autoregressive/train/train_c2i.py "$@" 8 | -------------------------------------------------------------------------------- /scripts/autoregressive/train_c2i_fsdp.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \ 6 | --master_addr=$master_addr --master_port=$master_port \ 7 | autoregressive/train/train_c2i_fsdp.py "$@" 8 | -------------------------------------------------------------------------------- /scripts/autoregressive/train_t2i_stage1.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \ 6 | --master_addr=$master_addr --master_port=$master_port \ 7 | autoregressive/train/train_t2i.py \ 8 | --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt \ 9 | --data-path /path/to/laion_coco50M \ 10 | --t5-feat-path /path/to/laion_coco50M_flan_t5_xl \ 11 | --dataset t2i \ 12 | --image-size 256 \ 13 | "$@" 14 | -------------------------------------------------------------------------------- /scripts/autoregressive/train_t2i_stage2.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \ 6 | --master_addr=$master_addr 
--master_port=$master_port \ 7 | autoregressive/train/train_t2i.py \ 8 | --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt \ 9 | --data-path /path/to/high_aesthetic_10M \ 10 | --t5-feat-path /path/to/high_aesthetic_10M_flan_t5_xl \ 11 | --short-t5-feat-path /path/to/high_aesthetic_10M_trunc_flan_t5_xl \ 12 | --dataset t2i \ 13 | --image-size 512 \ 14 | "$@" 15 | -------------------------------------------------------------------------------- /scripts/language/extract_flan_t5_feat_laion_coco_stage1.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12337 \ 7 | language/extract_t5_feature.py \ 8 | --data-path /path/to/laion_coco50M \ 9 | --t5-path /path/to/laion_coco50M_flan_t5_xl \ 10 | --caption-key blip \ 11 | "$@" 12 | -------------------------------------------------------------------------------- /scripts/language/extract_flan_t5_feat_stage2.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12337 \ 7 | language/extract_t5_feature.py \ 8 | --data-path /path/to/high_aesthetic_10M \ 9 | --t5-path /path/to/high_aesthetic_10M_flan_t5_xl \ 10 | "$@" 11 | -------------------------------------------------------------------------------- /scripts/language/extract_flan_t5_feat_trunc_stage2.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12337 \ 7 | language/extract_t5_feature.py \ 8 | --data-path /path/to/high_aesthetic_10M \ 9 | --t5-path /path/to/high_aesthetic_10M_trunc_flan_t5_xl \ 10 | --trunc-caption \ 11 | "$@" 12 | -------------------------------------------------------------------------------- /scripts/tokenizer/reconstruction_consistency_decoder.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12344 \ 7 | tokenizer/consistencydecoder/reconstruction_cd_ddp.py \ 8 | "$@" -------------------------------------------------------------------------------- /scripts/tokenizer/reconstruction_vae.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12344 \ 7 | tokenizer/vae/reconstruction_vae_ddp.py \ 8 | "$@" -------------------------------------------------------------------------------- /scripts/tokenizer/reconstruction_vq.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12344 \ 7 | tokenizer/tokenizer_image/reconstruction_vq_ddp.py \ 8 | "$@" -------------------------------------------------------------------------------- /scripts/tokenizer/reconstruction_vqgan.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12344 \ 7 | tokenizer/vqgan/reconstruction_vqgan_ddp.py \ 8 | "$@" -------------------------------------------------------------------------------- 
/scripts/tokenizer/train_vq.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \ 6 | --master_addr=$master_addr --master_port=$master_port \ 7 | tokenizer/tokenizer_image/vq_train.py "$@" -------------------------------------------------------------------------------- /scripts/tokenizer/train_vq_finetune.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \ 6 | --master_addr=$master_addr --master_port=$master_port \ 7 | tokenizer/tokenizer_image/vq_train.py \ 8 | --finetune \ 9 | --disc-start 0 \ 10 | --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt \ 11 | --dataset t2i_image \ 12 | --data-path /path/to/high_aesthetic_10M \ 13 | --data-face-path /path/to/face_2M \ 14 | --cloud-save-path /path/to/cloud_disk \ 15 | "$@" -------------------------------------------------------------------------------- /scripts/tokenizer/train_vq_finetune_continue.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \ 6 | --master_addr=$master_addr --master_port=$master_port \ 7 | tokenizer/tokenizer_image/vq_train.py \ 8 | --disc-start 0 \ 9 | --dataset t2i_image \ 10 | --data-path /path/to/high_aesthetic_10M \ 11 | --data-face-path /path/to/face_2M \ 12 | --cloud-save-path /path/to/cloud_disk \ 13 | "$@" 14 | 15 | # --vq-ckpt xxx.pt -------------------------------------------------------------------------------- /scripts/tokenizer/val.sh: -------------------------------------------------------------------------------- 1 | # !/bin/bash 2 | set -x 3 | 4 | torchrun \ 5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \ 6 | --master_port=12343 \ 7 | tokenizer/validation/val_ddp.py \ 8 | "$@" -------------------------------------------------------------------------------- /tokenizer/consistencydecoder/README.md: -------------------------------------------------------------------------------- 1 | ## Consistency Decoder from OpenAI 2 | 3 | ### install 4 | ``` 5 | pip install diffusers 6 | pip install accelerate 7 | ``` 8 | 9 | ### demo 10 | ``` 11 | cd ${THIS_REPO_ROOT} 12 | python3 tokenizer/consistencydecoder/cd_demo.py 13 | ``` 14 | 15 | -------------------------------------------------------------------------------- /tokenizer/consistencydecoder/cd_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from PIL import Image 6 | from diffusers import ConsistencyDecoderVAE 7 | 8 | 9 | def main(args): 10 | # Setup PyTorch: 11 | torch.manual_seed(args.seed) 12 | torch.set_grad_enabled(False) 13 | device = "cuda" if torch.cuda.is_available() else "cpu" 14 | 15 | # create and load model 16 | vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to(device) 17 | 18 | # load image 19 | img_path = args.image_path 20 | out_path = args.image_path.replace('.jpg', '_cd.jpg').replace('.jpeg', '_cd.jpeg').replace('.png', '_cd.png') 21 | input_size = args.image_size 22 | img = Image.open(img_path).convert("RGB") 23 | 24 | # preprocess 25 | size_org = img.size 26 | img = img.resize((input_size, 
input_size)) 27 | img = np.array(img) / 255. 28 | x = 2.0 * img - 1.0 # x value is between [-1, 1] 29 | x = torch.tensor(x) 30 | x = x.unsqueeze(dim=0) 31 | x = torch.einsum('nhwc->nchw', x) 32 | x_input = x.half().to(device) 33 | 34 | # inference 35 | with torch.no_grad(): 36 | # Map input images to latent space + normalize latents: 37 | latent = vae.encode(x_input).latent_dist.sample().mul_(0.18215) 38 | # reconstruct: 39 | output = vae.decode(latent / 0.18215).sample # output value is between [-1, 1] 40 | 41 | # postprocess 42 | output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0] 43 | sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy() 44 | 45 | # save 46 | Image.fromarray(sample).save(out_path) 47 | print("Reconstructed image is saved to {}".format(out_path)) 48 | 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | parser.add_argument("--image-path", type=str, default="assets/example.jpg") 54 | parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512) 55 | parser.add_argument("--seed", type=int, default=0) 56 | args = parser.parse_args() 57 | main(args) 58 | -------------------------------------------------------------------------------- /tokenizer/consistencydecoder/reconstruction_cd_ddp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | torch.backends.cudnn.allow_tf32 = True 4 | import torch.distributed as dist 5 | from torch.utils.data import Dataset, DataLoader 6 | from torch.utils.data.distributed import DistributedSampler 7 | from torchvision.datasets import ImageFolder 8 | from torchvision import transforms 9 | from tqdm import tqdm 10 | import os 11 | import itertools 12 | from PIL import Image 13 | import numpy as np 14 | import argparse 15 | import random 16 | 17 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 18 | from skimage.metrics import structural_similarity as ssim_loss 19 | from diffusers.models import ConsistencyDecoderVAE 20 | 21 | 22 | class SingleFolderDataset(Dataset): 23 | def __init__(self, directory, transform=None): 24 | super().__init__() 25 | self.directory = directory 26 | self.transform = transform 27 | self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory) 28 | if os.path.isfile(os.path.join(directory, file_name))] 29 | 30 | def __len__(self): 31 | return len(self.image_paths) 32 | 33 | def __getitem__(self, idx): 34 | image_path = self.image_paths[idx] 35 | image = Image.open(image_path).convert('RGB') 36 | if self.transform: 37 | image = self.transform(image) 38 | return image, torch.tensor(0) 39 | 40 | 41 | def create_npz_from_sample_folder(sample_dir, num=50_000): 42 | """ 43 | Builds a single .npz file from a folder of .png samples. 44 | """ 45 | samples = [] 46 | for i in tqdm(range(num), desc="Building .npz file from samples"): 47 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") 48 | sample_np = np.asarray(sample_pil).astype(np.uint8) 49 | samples.append(sample_np) 50 | 51 | random.shuffle(samples) # This is very important for IS(Inception Score) !!! 
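    # Rationale: Inception Score is typically estimated over consecutive splits of
    # the sample array, and ImageFolder yields images roughly grouped by class, so
    # without this shuffle each split would cover only a few classes and IS would
    # be biased low. FID, by contrast, is insensitive to sample order.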
52 | samples = np.stack(samples) 53 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) 54 | npz_path = f"{sample_dir}.npz" 55 | np.savez(npz_path, arr_0=samples) 56 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") 57 | return npz_path 58 | 59 | 60 | def center_crop_arr(pil_image, image_size): 61 | """ 62 | Center cropping implementation from ADM. 63 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 64 | """ 65 | while min(*pil_image.size) >= 2 * image_size: 66 | pil_image = pil_image.resize( 67 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 68 | ) 69 | 70 | scale = image_size / min(*pil_image.size) 71 | pil_image = pil_image.resize( 72 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 73 | ) 74 | 75 | arr = np.array(pil_image) 76 | crop_y = (arr.shape[0] - image_size) // 2 77 | crop_x = (arr.shape[1] - image_size) // 2 78 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 79 | 80 | 81 | def main(args): 82 | # Setup PyTorch: 83 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" 84 | torch.set_grad_enabled(False) 85 | 86 | # Setup env 87 | dist.init_process_group("nccl") 88 | rank = dist.get_rank() 89 | device = rank % torch.cuda.device_count() 90 | seed = args.global_seed * dist.get_world_size() + rank 91 | torch.manual_seed(seed) 92 | torch.cuda.set_device(device) 93 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 94 | 95 | # create and load model 96 | vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to("cuda:{}".format(device)) 97 | 98 | # Create folder to save samples: 99 | folder_name = f"openai-consistencydecoder-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}" 100 | sample_folder_dir = f"{args.sample_dir}/{folder_name}" 101 | if rank == 0: 102 | os.makedirs(sample_folder_dir, exist_ok=True) 103 | print(f"Saving .png samples at {sample_folder_dir}") 104 | dist.barrier() 105 | 106 | # Setup data: 107 | transform = transforms.Compose([ 108 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), 109 | transforms.ToTensor(), 110 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 111 | ]) 112 | if args.dataset == 'imagenet': 113 | dataset = ImageFolder(args.data_path, transform=transform) 114 | num_fid_samples = 50000 115 | elif args.dataset == 'coco': 116 | dataset = SingleFolderDataset(args.data_path, transform=transform) 117 | num_fid_samples = 5000 118 | else: 119 | raise Exception("please check dataset") 120 | sampler = DistributedSampler( 121 | dataset, 122 | num_replicas=dist.get_world_size(), 123 | rank=rank, 124 | shuffle=False, 125 | seed=args.global_seed 126 | ) 127 | loader = DataLoader( 128 | dataset, 129 | batch_size=args.per_proc_batch_size, 130 | shuffle=False, 131 | sampler=sampler, 132 | num_workers=args.num_workers, 133 | pin_memory=True, 134 | drop_last=False 135 | ) 136 | 137 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: 138 | n = args.per_proc_batch_size 139 | global_batch_size = n * dist.get_world_size() 140 | psnr_val_rgb = [] 141 | ssim_val_rgb = [] 142 | 143 | loader = tqdm(loader) if rank == 0 else loader 144 | total = 0 145 | for x, _ in loader: 146 | rgb_gts = x 147 | rgb_gts = 
(rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1] 148 | x = x.half().to("cuda:{}".format(device)) 149 | with torch.no_grad(): 150 | # Map input images to latent space + normalize latents: 151 | latent = vae.encode(x).latent_dist.sample().mul_(0.18215) 152 | # reconstruct: 153 | samples = vae.decode(latent / 0.18215).sample # output value is between [-1, 1] 154 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 155 | 156 | # Save samples to disk as individual .png files 157 | for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)): 158 | index = i * dist.get_world_size() + rank + total 159 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") 160 | # metric 161 | rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1] 162 | psnr = psnr_loss(rgb_restored, rgb_gt) 163 | ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1) 164 | psnr_val_rgb.append(psnr) 165 | ssim_val_rgb.append(ssim) 166 | total += global_batch_size 167 | 168 | # ------------------------------------ 169 | # Summary 170 | # ------------------------------------ 171 | # Make sure all processes have finished saving their samples 172 | dist.barrier() 173 | world_size = dist.get_world_size() 174 | gather_psnr_val = [None for _ in range(world_size)] 175 | gather_ssim_val = [None for _ in range(world_size)] 176 | dist.all_gather_object(gather_psnr_val, psnr_val_rgb) 177 | dist.all_gather_object(gather_ssim_val, ssim_val_rgb) 178 | 179 | if rank == 0: 180 | gather_psnr_val = list(itertools.chain(*gather_psnr_val)) 181 | gather_ssim_val = list(itertools.chain(*gather_ssim_val)) 182 | psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val) 183 | ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val) 184 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) 185 | 186 | result_file = f"{sample_folder_dir}_results.txt" 187 | print("writing results to {}".format(result_file)) 188 | with open(result_file, 'w') as f: 189 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f) 190 | 191 | create_npz_from_sample_folder(sample_folder_dir, num_fid_samples) 192 | print("Done.") 193 | 194 | dist.barrier() 195 | dist.destroy_process_group() 196 | 197 | 198 | if __name__ == "__main__": 199 | parser = argparse.ArgumentParser() 200 | parser.add_argument("--data-path", type=str, required=True) 201 | parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet') 202 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 203 | parser.add_argument("--sample-dir", type=str, default="reconstructions") 204 | parser.add_argument("--per-proc-batch-size", type=int, default=32) 205 | parser.add_argument("--global-seed", type=int, default=0) 206 | parser.add_argument("--num-workers", type=int, default=4) 207 | args = parser.parse_args() 208 | main(args) -------------------------------------------------------------------------------- /tokenizer/tokenizer_image/cache/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FoundationVision/LlamaGen/ce98ec41803a74a90ce68c40ababa9eaeffeb4ec/tokenizer/tokenizer_image/cache/vgg.pth -------------------------------------------------------------------------------- /tokenizer/tokenizer_image/discriminator_patchgan.py: 
-------------------------------------------------------------------------------- 1 | # Modified from: 2 | # taming-transformers: https://github.com/CompVis/taming-transformers 3 | import functools 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | class NLayerDiscriminator(nn.Module): 9 | """Defines a PatchGAN discriminator as in Pix2Pix 10 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 11 | """ 12 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 13 | """Construct a PatchGAN discriminator 14 | Parameters: 15 | input_nc (int) -- the number of channels in input images 16 | ndf (int) -- the number of filters in the last conv layer 17 | n_layers (int) -- the number of conv layers in the discriminator 18 | norm_layer -- normalization layer 19 | """ 20 | super(NLayerDiscriminator, self).__init__() 21 | if not use_actnorm: 22 | norm_layer = nn.BatchNorm2d 23 | else: 24 | norm_layer = ActNorm 25 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 26 | use_bias = norm_layer.func != nn.BatchNorm2d 27 | else: 28 | use_bias = norm_layer != nn.BatchNorm2d 29 | 30 | kw = 4 31 | padw = 1 32 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 33 | nf_mult = 1 34 | nf_mult_prev = 1 35 | for n in range(1, n_layers): # gradually increase the number of filters 36 | nf_mult_prev = nf_mult 37 | nf_mult = min(2 ** n, 8) 38 | sequence += [ 39 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 40 | norm_layer(ndf * nf_mult), 41 | nn.LeakyReLU(0.2, True) 42 | ] 43 | 44 | nf_mult_prev = nf_mult 45 | nf_mult = min(2 ** n_layers, 8) 46 | sequence += [ 47 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 48 | norm_layer(ndf * nf_mult), 49 | nn.LeakyReLU(0.2, True) 50 | ] 51 | 52 | sequence += [ 53 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 54 | self.main = nn.Sequential(*sequence) 55 | 56 | self.apply(self._init_weights) 57 | 58 | def _init_weights(self, module): 59 | if isinstance(module, nn.Conv2d): 60 | nn.init.normal_(module.weight.data, 0.0, 0.02) 61 | elif isinstance(module, nn.BatchNorm2d): 62 | nn.init.normal_(module.weight.data, 1.0, 0.02) 63 | nn.init.constant_(module.bias.data, 0) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.main(input) 68 | 69 | 70 | class ActNorm(nn.Module): 71 | def __init__(self, num_features, logdet=False, affine=True, 72 | allow_reverse_init=False): 73 | assert affine 74 | super().__init__() 75 | self.logdet = logdet 76 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 77 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 78 | self.allow_reverse_init = allow_reverse_init 79 | 80 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 81 | 82 | def initialize(self, input): 83 | with torch.no_grad(): 84 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 85 | mean = ( 86 | flatten.mean(1) 87 | .unsqueeze(1) 88 | .unsqueeze(2) 89 | .unsqueeze(3) 90 | .permute(1, 0, 2, 3) 91 | ) 92 | std = ( 93 | flatten.std(1) 94 | .unsqueeze(1) 95 | .unsqueeze(2) 96 | .unsqueeze(3) 97 | .permute(1, 0, 2, 3) 98 | ) 99 | 100 | self.loc.data.copy_(-mean) 101 | self.scale.data.copy_(1 / (std + 1e-6)) 102 | 103 | def forward(self, input, reverse=False): 104 | if 
reverse: 105 | return self.reverse(input) 106 | if len(input.shape) == 2: 107 | input = input[:,:,None,None] 108 | squeeze = True 109 | else: 110 | squeeze = False 111 | 112 | _, _, height, width = input.shape 113 | 114 | if self.training and self.initialized.item() == 0: 115 | self.initialize(input) 116 | self.initialized.fill_(1) 117 | 118 | h = self.scale * (input + self.loc) 119 | 120 | if squeeze: 121 | h = h.squeeze(-1).squeeze(-1) 122 | 123 | if self.logdet: 124 | log_abs = torch.log(torch.abs(self.scale)) 125 | logdet = height*width*torch.sum(log_abs) 126 | logdet = logdet * torch.ones(input.shape[0]).to(input) 127 | return h, logdet 128 | 129 | return h 130 | 131 | def reverse(self, output): 132 | if self.training and self.initialized.item() == 0: 133 | if not self.allow_reverse_init: 134 | raise RuntimeError( 135 | "Initializing ActNorm in reverse direction is " 136 | "disabled by default. Use allow_reverse_init=True to enable." 137 | ) 138 | else: 139 | self.initialize(output) 140 | self.initialized.fill_(1) 141 | 142 | if len(output.shape) == 2: 143 | output = output[:,:,None,None] 144 | squeeze = True 145 | else: 146 | squeeze = False 147 | 148 | h = output / self.scale - self.loc 149 | 150 | if squeeze: 151 | h = h.squeeze(-1).squeeze(-1) 152 | return h -------------------------------------------------------------------------------- /tokenizer/tokenizer_image/discriminator_stylegan.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/stylegan2_pytorch.py 3 | # stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py 4 | # maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py 5 | import math 6 | import torch 7 | import torch.nn as nn 8 | try: 9 | from kornia.filters import filter2d 10 | except: 11 | pass 12 | 13 | class Discriminator(nn.Module): 14 | def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256): 15 | super().__init__() 16 | channels = { 17 | 4: 512, 18 | 8: 512, 19 | 16: 512, 20 | 32: 512, 21 | 64: 256 * channel_multiplier, 22 | 128: 128 * channel_multiplier, 23 | 256: 64 * channel_multiplier, 24 | 512: 32 * channel_multiplier, 25 | 1024: 16 * channel_multiplier, 26 | } 27 | 28 | log_size = int(math.log(image_size, 2)) 29 | in_channel = channels[image_size] 30 | 31 | blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()] 32 | for i in range(log_size, 2, -1): 33 | out_channel = channels[2 ** (i - 1)] 34 | blocks.append(DiscriminatorBlock(in_channel, out_channel)) 35 | in_channel = out_channel 36 | self.blocks = nn.ModuleList(blocks) 37 | 38 | self.final_conv = nn.Sequential( 39 | nn.Conv2d(in_channel, channels[4], 3, padding=1), 40 | leaky_relu(), 41 | ) 42 | self.final_linear = nn.Sequential( 43 | nn.Linear(channels[4] * 4 * 4, channels[4]), 44 | leaky_relu(), 45 | nn.Linear(channels[4], 1) 46 | ) 47 | 48 | def forward(self, x): 49 | for block in self.blocks: 50 | x = block(x) 51 | x = self.final_conv(x) 52 | x = x.view(x.shape[0], -1) 53 | x = self.final_linear(x) 54 | return x 55 | 56 | 57 | class DiscriminatorBlock(nn.Module): 58 | def __init__(self, input_channels, filters, downsample=True): 59 | super().__init__() 60 | self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1)) 61 | 62 | self.net = nn.Sequential( 63 | nn.Conv2d(input_channels, filters, 3, 
padding=1), 64 | leaky_relu(), 65 | nn.Conv2d(filters, filters, 3, padding=1), 66 | leaky_relu() 67 | ) 68 | 69 | self.downsample = nn.Sequential( 70 | Blur(), 71 | nn.Conv2d(filters, filters, 3, padding = 1, stride = 2) 72 | ) if downsample else None 73 | 74 | def forward(self, x): 75 | res = self.conv_res(x) 76 | x = self.net(x) 77 | if exists(self.downsample): 78 | x = self.downsample(x) 79 | x = (x + res) * (1 / math.sqrt(2)) 80 | return x 81 | 82 | 83 | 84 | class Blur(nn.Module): 85 | def __init__(self): 86 | super().__init__() 87 | f = torch.Tensor([1, 2, 1]) 88 | self.register_buffer('f', f) 89 | 90 | def forward(self, x): 91 | f = self.f 92 | f = f[None, None, :] * f [None, :, None] 93 | return filter2d(x, f, normalized=True) 94 | 95 | 96 | def leaky_relu(p=0.2): 97 | return nn.LeakyReLU(p, inplace=True) 98 | 99 | 100 | def exists(val): 101 | return val is not None 102 | -------------------------------------------------------------------------------- /tokenizer/tokenizer_image/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import os, hashlib 4 | import requests 5 | from tqdm import tqdm 6 | 7 | import torch 8 | import torch.nn as nn 9 | from torchvision import models 10 | from collections import namedtuple 11 | 12 | URL_MAP = { 13 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" 14 | } 15 | 16 | CKPT_MAP = { 17 | "vgg_lpips": "vgg.pth" 18 | } 19 | 20 | MD5_MAP = { 21 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" 22 | } 23 | 24 | def download(url, local_path, chunk_size=1024): 25 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 26 | with requests.get(url, stream=True) as r: 27 | total_size = int(r.headers.get("content-length", 0)) 28 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 29 | with open(local_path, "wb") as f: 30 | for data in r.iter_content(chunk_size=chunk_size): 31 | if data: 32 | f.write(data) 33 | pbar.update(chunk_size) 34 | 35 | 36 | def md5_hash(path): 37 | with open(path, "rb") as f: 38 | content = f.read() 39 | return hashlib.md5(content).hexdigest() 40 | 41 | 42 | def get_ckpt_path(name, root, check=False): 43 | assert name in URL_MAP 44 | path = os.path.join(root, CKPT_MAP[name]) 45 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): 46 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) 47 | download(URL_MAP[name], path) 48 | md5 = md5_hash(path) 49 | assert md5 == MD5_MAP[name], md5 50 | return path 51 | 52 | 53 | class LPIPS(nn.Module): 54 | # Learned perceptual metric 55 | def __init__(self, use_dropout=True): 56 | super().__init__() 57 | self.scaling_layer = ScalingLayer() 58 | self.chns = [64, 128, 256, 512, 512] # vg16 features 59 | self.net = vgg16(pretrained=True, requires_grad=False) 60 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 61 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 62 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 63 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 64 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 65 | self.load_from_pretrained() 66 | for param in self.parameters(): 67 | param.requires_grad = False 68 | 69 | def load_from_pretrained(self, name="vgg_lpips"): 70 | ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")) 71 | 
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 72 | print("loaded pretrained LPIPS loss from {}".format(ckpt)) 73 | 74 | @classmethod 75 | def from_pretrained(cls, name="vgg_lpips"): 76 | if name != "vgg_lpips": 77 | raise NotImplementedError 78 | model = cls() 79 | ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache")) 80 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 81 | return model 82 | 83 | def forward(self, input, target): 84 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 85 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 86 | feats0, feats1, diffs = {}, {}, {} 87 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 88 | for kk in range(len(self.chns)): 89 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 90 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 91 | 92 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 93 | val = res[0] 94 | for l in range(1, len(self.chns)): 95 | val += res[l] 96 | return val 97 | 98 | 99 | class ScalingLayer(nn.Module): 100 | def __init__(self): 101 | super(ScalingLayer, self).__init__() 102 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 103 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 104 | 105 | def forward(self, inp): 106 | return (inp - self.shift) / self.scale 107 | 108 | 109 | class NetLinLayer(nn.Module): 110 | """ A single linear layer which does a 1x1 conv """ 111 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 112 | super(NetLinLayer, self).__init__() 113 | layers = [nn.Dropout(), ] if (use_dropout) else [] 114 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 115 | self.model = nn.Sequential(*layers) 116 | 117 | 118 | class vgg16(torch.nn.Module): 119 | def __init__(self, requires_grad=False, pretrained=True): 120 | super(vgg16, self).__init__() 121 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features 122 | self.slice1 = torch.nn.Sequential() 123 | self.slice2 = torch.nn.Sequential() 124 | self.slice3 = torch.nn.Sequential() 125 | self.slice4 = torch.nn.Sequential() 126 | self.slice5 = torch.nn.Sequential() 127 | self.N_slices = 5 128 | for x in range(4): 129 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 130 | for x in range(4, 9): 131 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 132 | for x in range(9, 16): 133 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 134 | for x in range(16, 23): 135 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 136 | for x in range(23, 30): 137 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 138 | if not requires_grad: 139 | for param in self.parameters(): 140 | param.requires_grad = False 141 | 142 | def forward(self, X): 143 | h = self.slice1(X) 144 | h_relu1_2 = h 145 | h = self.slice2(h) 146 | h_relu2_2 = h 147 | h = self.slice3(h) 148 | h_relu3_3 = h 149 | h = self.slice4(h) 150 | h_relu4_3 = h 151 | h = self.slice5(h) 152 | h_relu5_3 = h 153 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 154 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 155 | return out 156 | 157 | 158 | def normalize_tensor(x,eps=1e-10): 159 | norm_factor = 
torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) 160 | return x/(norm_factor+eps) 161 | 162 | 163 | def spatial_average(x, keepdim=True): 164 | return x.mean([2,3],keepdim=keepdim) -------------------------------------------------------------------------------- /tokenizer/tokenizer_image/reconstruction_vq_ddp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | torch.backends.cudnn.allow_tf32 = True 4 | import torch.nn.functional as F 5 | import torch.distributed as dist 6 | from torch.utils.data import DataLoader 7 | from torch.utils.data.distributed import DistributedSampler 8 | from torchvision import transforms 9 | from tqdm import tqdm 10 | import os 11 | from PIL import Image 12 | import numpy as np 13 | import argparse 14 | import itertools 15 | 16 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 17 | from skimage.metrics import structural_similarity as ssim_loss 18 | 19 | from dataset.augmentation import center_crop_arr 20 | from dataset.build import build_dataset 21 | from tokenizer.tokenizer_image.vq_model import VQ_models 22 | 23 | 24 | 25 | def create_npz_from_sample_folder(sample_dir, num=50000): 26 | """ 27 | Builds a single .npz file from a folder of .png samples. 28 | """ 29 | samples = [] 30 | for i in tqdm(range(num), desc="Building .npz file from samples"): 31 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") 32 | sample_np = np.asarray(sample_pil).astype(np.uint8) 33 | samples.append(sample_np) 34 | samples = np.stack(samples) 35 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) 36 | npz_path = f"{sample_dir}.npz" 37 | np.savez(npz_path, arr_0=samples) 38 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") 39 | return npz_path 40 | 41 | 42 | 43 | def main(args): 44 | # Setup PyTorch: 45 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. 
sample.py supports CPU-only usage" 46 | torch.set_grad_enabled(False) 47 | 48 | # Setup DDP: 49 | dist.init_process_group("nccl") 50 | rank = dist.get_rank() 51 | device = rank % torch.cuda.device_count() 52 | seed = args.global_seed * dist.get_world_size() + rank 53 | torch.manual_seed(seed) 54 | torch.cuda.set_device(device) 55 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 56 | 57 | # create and load model 58 | vq_model = VQ_models[args.vq_model]( 59 | codebook_size=args.codebook_size, 60 | codebook_embed_dim=args.codebook_embed_dim) 61 | vq_model.to(device) 62 | vq_model.eval() 63 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu") 64 | if "ema" in checkpoint: # ema 65 | model_weight = checkpoint["ema"] 66 | elif "model" in checkpoint: # ddp 67 | model_weight = checkpoint["model"] 68 | elif "state_dict" in checkpoint: 69 | model_weight = checkpoint["state_dict"] 70 | else: 71 | raise Exception("please check model weight") 72 | vq_model.load_state_dict(model_weight) 73 | del checkpoint 74 | 75 | # Create folder to save samples: 76 | folder_name = (f"{args.vq_model}-{args.dataset}-size-{args.image_size}-size-{args.image_size_eval}" 77 | f"-codebook-size-{args.codebook_size}-dim-{args.codebook_embed_dim}-seed-{args.global_seed}") 78 | sample_folder_dir = f"{args.sample_dir}/{folder_name}" 79 | if rank == 0: 80 | os.makedirs(sample_folder_dir, exist_ok=True) 81 | print(f"Saving .png samples at {sample_folder_dir}") 82 | dist.barrier() 83 | 84 | # Setup data: 85 | transform = transforms.Compose([ 86 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), 87 | transforms.ToTensor(), 88 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 89 | ]) 90 | 91 | if args.dataset == 'imagenet': 92 | dataset = build_dataset(args, transform=transform) 93 | num_fid_samples = 50000 94 | elif args.dataset == 'coco': 95 | dataset = build_dataset(args, transform=transform) 96 | num_fid_samples = 5000 97 | else: 98 | raise Exception("please check dataset") 99 | 100 | sampler = DistributedSampler( 101 | dataset, 102 | num_replicas=dist.get_world_size(), 103 | rank=rank, 104 | shuffle=False, 105 | seed=args.global_seed 106 | ) 107 | loader = DataLoader( 108 | dataset, 109 | batch_size=args.per_proc_batch_size, 110 | shuffle=False, 111 | sampler=sampler, 112 | num_workers=args.num_workers, 113 | pin_memory=True, 114 | drop_last=False 115 | ) 116 | 117 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: 118 | n = args.per_proc_batch_size 119 | global_batch_size = n * dist.get_world_size() 120 | 121 | psnr_val_rgb = [] 122 | ssim_val_rgb = [] 123 | loader = tqdm(loader) if rank == 0 else loader 124 | total = 0 125 | for x, _ in loader: 126 | if args.image_size_eval != args.image_size: 127 | rgb_gts = F.interpolate(x, size=(args.image_size_eval, args.image_size_eval), mode='bicubic') 128 | else: 129 | rgb_gts = x 130 | rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1] 131 | x = x.to(device, non_blocking=True) 132 | with torch.no_grad(): 133 | latent, _, [_, _, indices] = vq_model.encode(x) 134 | samples = vq_model.decode_code(indices, latent.shape) # output value is between [-1, 1] 135 | if args.image_size_eval != args.image_size: 136 | samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic') 137 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 
1).to("cpu", dtype=torch.uint8).numpy() 138 | 139 | # Save samples to disk as individual .png files 140 | for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)): 141 | index = i * dist.get_world_size() + rank + total 142 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") 143 | # metric 144 | rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1] 145 | psnr = psnr_loss(rgb_restored, rgb_gt) 146 | ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1) 147 | psnr_val_rgb.append(psnr) 148 | ssim_val_rgb.append(ssim) 149 | 150 | total += global_batch_size 151 | 152 | # ------------------------------------ 153 | # Summary 154 | # ------------------------------------ 155 | # Make sure all processes have finished saving their samples 156 | dist.barrier() 157 | world_size = dist.get_world_size() 158 | gather_psnr_val = [None for _ in range(world_size)] 159 | gather_ssim_val = [None for _ in range(world_size)] 160 | dist.all_gather_object(gather_psnr_val, psnr_val_rgb) 161 | dist.all_gather_object(gather_ssim_val, ssim_val_rgb) 162 | 163 | if rank == 0: 164 | gather_psnr_val = list(itertools.chain(*gather_psnr_val)) 165 | gather_ssim_val = list(itertools.chain(*gather_ssim_val)) 166 | psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val) 167 | ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val) 168 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb)) 169 | 170 | result_file = f"{sample_folder_dir}_results.txt" 171 | print("writing results to {}".format(result_file)) 172 | with open(result_file, 'w') as f: 173 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f) 174 | 175 | create_npz_from_sample_folder(sample_folder_dir, num_fid_samples) 176 | print("Done.") 177 | 178 | dist.barrier() 179 | dist.destroy_process_group() 180 | 181 | 182 | if __name__ == "__main__": 183 | parser = argparse.ArgumentParser() 184 | parser.add_argument("--data-path", type=str, required=True) 185 | parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet') 186 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") 187 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") 188 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") 189 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") 190 | parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256) 191 | parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256) 192 | parser.add_argument("--sample-dir", type=str, default="reconstructions") 193 | parser.add_argument("--per-proc-batch-size", type=int, default=32) 194 | parser.add_argument("--global-seed", type=int, default=0) 195 | parser.add_argument("--num-workers", type=int, default=4) 196 | args = parser.parse_args() 197 | main(args) -------------------------------------------------------------------------------- /tokenizer/tokenizer_image/vq_demo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import os 5 | import argparse 6 | import numpy as np 7 | from PIL import Image 8 | 9 | from tokenizer.tokenizer_image.vq_model import VQ_models 10 | from dataset.augmentation import center_crop_arr 11 | 12 | 13 | def 
main(args): 14 | # Setup PyTorch: 15 | torch.manual_seed(args.seed) 16 | torch.set_grad_enabled(False) 17 | device = "cuda" if torch.cuda.is_available() else "cpu" 18 | 19 | # create and load model 20 | model = VQ_models[args.vq_model]( 21 | codebook_size=args.codebook_size, 22 | codebook_embed_dim=args.codebook_embed_dim) 23 | model.to(device) 24 | model.eval() 25 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu") 26 | if "ema" in checkpoint: # ema 27 | model_weight = checkpoint["ema"] 28 | elif "model" in checkpoint: # ddp 29 | model_weight = checkpoint["model"] 30 | elif "state_dict" in checkpoint: 31 | model_weight = checkpoint["state_dict"] 32 | else: 33 | raise Exception("please check model weight") 34 | model.load_state_dict(model_weight) 35 | del checkpoint 36 | 37 | # output dir 38 | os.makedirs(args.output_dir, exist_ok=True) 39 | out_path = args.image_path.replace('.jpg', '_{}.jpg'.format(args.suffix)) 40 | out_path = out_path.replace('.jpeg', '_{}.jpeg'.format(args.suffix)) 41 | out_path = out_path.replace('.png', '_{}.png'.format(args.suffix)) 42 | out_filename = out_path.split('/')[-1] 43 | out_path = os.path.join(args.output_dir, out_filename) 44 | 45 | # load image 46 | pil_image = Image.open(args.image_path).convert("RGB") 47 | img = center_crop_arr(pil_image, args.image_size) 48 | # # preprocess 49 | # size_org = img.size 50 | # img = img.resize((input_size, input_size)) 51 | img = np.array(img) / 255. 52 | x = 2.0 * img - 1.0 # x value is between [-1, 1] 53 | x = torch.tensor(x) 54 | x = x.unsqueeze(dim=0) 55 | x = torch.einsum('nhwc->nchw', x) 56 | x_input = x.float().to("cuda") 57 | 58 | # inference 59 | with torch.no_grad(): 60 | latent, _, [_, _, indices] = model.encode(x_input) 61 | output = model.decode_code(indices, latent.shape) # output value is between [-1, 1] 62 | 63 | # postprocess 64 | output = F.interpolate(output, size=[args.image_size, args.image_size], mode='bicubic').permute(0, 2, 3, 1)[0] 65 | sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy() 66 | 67 | # save 68 | Image.fromarray(sample).save(out_path) 69 | print("Reconstructed image is saved to {}".format(out_path)) 70 | 71 | 72 | if __name__ == "__main__": 73 | parser = argparse.ArgumentParser() 74 | parser.add_argument("--image-path", type=str, default="assets/example.jpg") 75 | parser.add_argument("--output-dir", type=str, default="output_vq_demo") 76 | parser.add_argument("--suffix", type=str, default="tokenizer_image") 77 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") 78 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") 79 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") 80 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") 81 | parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512, 1024], default=512) 82 | parser.add_argument("--seed", type=int, default=0) 83 | args = parser.parse_args() 84 | main(args) -------------------------------------------------------------------------------- /tokenizer/tokenizer_image/vq_loss.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # taming-transformers: https://github.com/CompVis/taming-transformers 3 | # muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py 
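# Loss overview: VQLoss combines a pixel reconstruction term (L1 or L2), an LPIPS
# perceptual term, a GAN adversarial term (PatchGAN or StyleGAN discriminator with
# hinge / vanilla / non-saturating variants), and the codebook / commitment / entropy
# losses returned by the quantizer; adopt_weight() keeps the adversarial terms at
# zero until global_step reaches disc_start.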
4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from tokenizer.tokenizer_image.lpips import LPIPS 9 | from tokenizer.tokenizer_image.discriminator_patchgan import NLayerDiscriminator as PatchGANDiscriminator 10 | from tokenizer.tokenizer_image.discriminator_stylegan import Discriminator as StyleGANDiscriminator 11 | 12 | 13 | 14 | def hinge_d_loss(logits_real, logits_fake): 15 | loss_real = torch.mean(F.relu(1. - logits_real)) 16 | loss_fake = torch.mean(F.relu(1. + logits_fake)) 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | 21 | def vanilla_d_loss(logits_real, logits_fake): 22 | loss_real = torch.mean(F.softplus(-logits_real)) 23 | loss_fake = torch.mean(F.softplus(logits_fake)) 24 | d_loss = 0.5 * (loss_real + loss_fake) 25 | return d_loss 26 | 27 | 28 | def non_saturating_d_loss(logits_real, logits_fake): 29 | loss_real = torch.mean(F.binary_cross_entropy_with_logits(torch.ones_like(logits_real), logits_real)) 30 | loss_fake = torch.mean(F.binary_cross_entropy_with_logits(torch.zeros_like(logits_fake), logits_fake)) 31 | d_loss = 0.5 * (loss_real + loss_fake) 32 | return d_loss 33 | 34 | 35 | def hinge_gen_loss(logit_fake): 36 | return -torch.mean(logit_fake) 37 | 38 | 39 | def non_saturating_gen_loss(logit_fake): 40 | return torch.mean(F.binary_cross_entropy_with_logits(torch.ones_like(logit_fake), logit_fake)) 41 | 42 | 43 | def adopt_weight(weight, global_step, threshold=0, value=0.): 44 | if global_step < threshold: 45 | weight = value 46 | return weight 47 | 48 | 49 | class VQLoss(nn.Module): 50 | def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type='patchgan', image_size=256, 51 | disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight = False, 52 | gen_adv_loss='hinge', reconstruction_loss='l2', reconstruction_weight=1.0, 53 | codebook_weight=1.0, perceptual_weight=1.0, 54 | ): 55 | super().__init__() 56 | # discriminator loss 57 | assert disc_type in ["patchgan", "stylegan"] 58 | assert disc_loss in ["hinge", "vanilla", "non-saturating"] 59 | if disc_type == "patchgan": 60 | self.discriminator = PatchGANDiscriminator( 61 | input_nc=disc_in_channels, 62 | n_layers=disc_num_layers, 63 | ndf=disc_dim, 64 | ) 65 | elif disc_type == "stylegan": 66 | self.discriminator = StyleGANDiscriminator( 67 | input_nc=disc_in_channels, 68 | image_size=image_size, 69 | ) 70 | else: 71 | raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.") 72 | if disc_loss == "hinge": 73 | self.disc_loss = hinge_d_loss 74 | elif disc_loss == "vanilla": 75 | self.disc_loss = vanilla_d_loss 76 | elif disc_loss == "non-saturating": 77 | self.disc_loss = non_saturating_d_loss 78 | else: 79 | raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.") 80 | self.discriminator_iter_start = disc_start 81 | self.disc_weight = disc_weight 82 | self.disc_adaptive_weight = disc_adaptive_weight 83 | 84 | assert gen_adv_loss in ["hinge", "non-saturating"] 85 | # gen_adv_loss 86 | if gen_adv_loss == "hinge": 87 | self.gen_adv_loss = hinge_gen_loss 88 | elif gen_adv_loss == "non-saturating": 89 | self.gen_adv_loss = non_saturating_gen_loss 90 | else: 91 | raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.") 92 | 93 | # perceptual loss 94 | self.perceptual_loss = LPIPS().eval() 95 | self.perceptual_weight = perceptual_weight 96 | 97 | # reconstruction loss 98 | if reconstruction_loss == "l1": 99 | self.rec_loss = F.l1_loss 100 | elif reconstruction_loss == "l2": 101 | self.rec_loss = 
F.mse_loss 102 | else: 103 | raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.") 104 | self.rec_weight = reconstruction_weight 105 | 106 | # codebook loss 107 | self.codebook_weight = codebook_weight 108 | 109 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer): 110 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 111 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 112 | 113 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 114 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 115 | return d_weight.detach() 116 | 117 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None, 118 | logger=None, log_every=100): 119 | # generator update 120 | if optimizer_idx == 0: 121 | # reconstruction loss 122 | rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous()) 123 | 124 | # perceptual loss 125 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 126 | p_loss = torch.mean(p_loss) 127 | 128 | # discriminator loss 129 | logits_fake = self.discriminator(reconstructions.contiguous()) 130 | generator_adv_loss = self.gen_adv_loss(logits_fake) 131 | 132 | if self.disc_adaptive_weight: 133 | null_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss 134 | disc_adaptive_weight = self.calculate_adaptive_weight(null_loss, generator_adv_loss, last_layer=last_layer) 135 | else: 136 | disc_adaptive_weight = 1 137 | disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start) 138 | 139 | loss = self.rec_weight * rec_loss + \ 140 | self.perceptual_weight * p_loss + \ 141 | disc_adaptive_weight * disc_weight * generator_adv_loss + \ 142 | codebook_loss[0] + codebook_loss[1] + codebook_loss[2] 143 | 144 | if global_step % log_every == 0: 145 | rec_loss = self.rec_weight * rec_loss 146 | p_loss = self.perceptual_weight * p_loss 147 | generator_adv_loss = disc_adaptive_weight * disc_weight * generator_adv_loss 148 | logger.info(f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, " 149 | f"vq_loss: {codebook_loss[0]:.4f}, commit_loss: {codebook_loss[1]:.4f}, entropy_loss: {codebook_loss[2]:.4f}, " 150 | f"codebook_usage: {codebook_loss[3]:.4f}, generator_adv_loss: {generator_adv_loss:.4f}, " 151 | f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}") 152 | return loss 153 | 154 | # discriminator update 155 | if optimizer_idx == 1: 156 | logits_real = self.discriminator(inputs.contiguous().detach()) 157 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 158 | 159 | disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start) 160 | d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake) 161 | 162 | if global_step % log_every == 0: 163 | logits_real = logits_real.detach().mean() 164 | logits_fake = logits_fake.detach().mean() 165 | logger.info(f"(Discriminator) " 166 | f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, " 167 | f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}") 168 | return d_adversarial_loss -------------------------------------------------------------------------------- /tokenizer/tokenizer_image/vq_model_hf.py: -------------------------------------------------------------------------------- 1 | from huggingface_hub import PyTorchModelHubMixin 2 | 3 | from 
tokenizer.tokenizer_image.vq_model import ModelArgs, VQModel 4 | 5 | class VQModelHF(VQModel, PyTorchModelHubMixin, repo_url="https://github.com/FoundationVision/LlamaGen", license="mit", tags=["llamagen", "text-to-image"]): 6 | pass 7 | 8 | ################################################################################# 9 | # VQ Model Configs # 10 | ################################################################################# 11 | def VQ_8(**kwargs): 12 | return VQModelHF(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs)) 13 | 14 | def VQ_16(**kwargs): 15 | return VQModelHF(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs)) 16 | 17 | VQ_models_HF = {'VQ-16': VQ_16, 'VQ-8': VQ_8} 18 | -------------------------------------------------------------------------------- /tokenizer/vae/README.md: -------------------------------------------------------------------------------- 1 | ## VAE Models from Stable Diffusion 2 | 3 | ### install 4 | ``` 5 | pip install diffusers 6 | pip install accelerate 7 | ``` 8 | 9 | ### demo 10 | ``` 11 | cd ${THIS_REPO_ROOT} 12 | python3 tokenizer/vae/sd_vae_demo.py 13 | ``` 14 | 15 | -------------------------------------------------------------------------------- /tokenizer/vae/sd_vae_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from PIL import Image 6 | from diffusers.models import AutoencoderKL 7 | 8 | 9 | def main(args): 10 | # Setup PyTorch: 11 | torch.manual_seed(args.seed) 12 | torch.set_grad_enabled(False) 13 | device = "cuda" if torch.cuda.is_available() else "cpu" 14 | 15 | # create and load model 16 | vae = AutoencoderKL.from_pretrained(f"stabilityai/{args.vae}").to(device) 17 | 18 | # load image 19 | img_path = args.image_path 20 | out_path = args.image_path.replace('.jpg', '_vae.jpg').replace('.jpeg', '_vae.jpeg').replace('.png', '_vae.png') 21 | input_size = args.image_size 22 | img = Image.open(img_path).convert("RGB") 23 | 24 | # preprocess 25 | size_org = img.size 26 | img = img.resize((input_size, input_size)) 27 | img = np.array(img) / 255. 
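    # Note: resizing to a fixed (input_size, input_size) square changes the aspect
    # ratio; the postprocess step below restores the original resolution by bilinear
    # interpolation back to size_org.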
28 | x = 2.0 * img - 1.0 # x value is between [-1, 1] 29 | x = torch.tensor(x) 30 | x = x.unsqueeze(dim=0) 31 | x = torch.einsum('nhwc->nchw', x) 32 | x_input = x.float().to("cuda") 33 | 34 | # inference 35 | with torch.no_grad(): 36 | # Map input images to latent space + normalize latents: 37 | latent = vae.encode(x_input).latent_dist.sample().mul_(0.18215) 38 | # reconstruct: 39 | output = vae.decode(latent / 0.18215).sample # output value is between [-1, 1] 40 | 41 | # postprocess 42 | output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0] 43 | sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy() 44 | 45 | # save 46 | Image.fromarray(sample).save(out_path) 47 | print("Reconstructed image is saved to {}".format(out_path)) 48 | 49 | 50 | if __name__ == "__main__": 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument("--image-path", type=str, default="assets/example.jpg") 53 | parser.add_argument("--vae", type=str, choices=["sdxl-vae", "sd-vae-ft-mse"], default="sd-vae-ft-mse") 54 | parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512) 55 | parser.add_argument("--seed", type=int, default=0) 56 | args = parser.parse_args() 57 | main(args) -------------------------------------------------------------------------------- /tokenizer/validation/val_ddp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | torch.backends.cuda.matmul.allow_tf32 = True 3 | torch.backends.cudnn.allow_tf32 = True 4 | import torch.distributed as dist 5 | from torch.utils.data import Dataset, DataLoader 6 | from torch.utils.data.distributed import DistributedSampler 7 | from torchvision.datasets import ImageFolder 8 | from torchvision import transforms 9 | from tqdm import tqdm 10 | import os 11 | from PIL import Image 12 | import numpy as np 13 | import argparse 14 | import random 15 | 16 | 17 | class SingleFolderDataset(Dataset): 18 | def __init__(self, directory, transform=None): 19 | super().__init__() 20 | self.directory = directory 21 | self.transform = transform 22 | self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory) 23 | if os.path.isfile(os.path.join(directory, file_name))] 24 | 25 | def __len__(self): 26 | return len(self.image_paths) 27 | 28 | def __getitem__(self, idx): 29 | image_path = self.image_paths[idx] 30 | image = Image.open(image_path).convert('RGB') 31 | if self.transform: 32 | image = self.transform(image) 33 | return image, torch.tensor(0) 34 | 35 | 36 | def create_npz_from_sample_folder(sample_dir, num=50_000): 37 | """ 38 | Builds a single .npz file from a folder of .png samples. 39 | """ 40 | samples = [] 41 | for i in tqdm(range(num), desc="Building .npz file from samples"): 42 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png") 43 | sample_np = np.asarray(sample_pil).astype(np.uint8) 44 | samples.append(sample_np) 45 | 46 | random.shuffle(samples) # This is very important for IS(Inception Score) !!! 47 | samples = np.stack(samples) 48 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3) 49 | npz_path = f"{sample_dir}.npz" 50 | np.savez(npz_path, arr_0=samples) 51 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].") 52 | return npz_path 53 | 54 | 55 | def center_crop_arr(pil_image, image_size): 56 | """ 57 | Center cropping implementation from ADM. 
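    Repeatedly halves the image with BOX resampling while its short side is at least
    2 * image_size, then rescales with BICUBIC so the short side equals image_size,
    and finally center-crops to image_size x image_size.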
58 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 59 | """ 60 | while min(*pil_image.size) >= 2 * image_size: 61 | pil_image = pil_image.resize( 62 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 63 | ) 64 | 65 | scale = image_size / min(*pil_image.size) 66 | pil_image = pil_image.resize( 67 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 68 | ) 69 | 70 | arr = np.array(pil_image) 71 | crop_y = (arr.shape[0] - image_size) // 2 72 | crop_x = (arr.shape[1] - image_size) // 2 73 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) 74 | 75 | 76 | def main(args): 77 | # Setup PyTorch: 78 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. sample.py supports CPU-only usage" 79 | torch.set_grad_enabled(False) 80 | 81 | # Setup env 82 | dist.init_process_group("nccl") 83 | rank = dist.get_rank() 84 | device = rank % torch.cuda.device_count() 85 | seed = args.global_seed * dist.get_world_size() + rank 86 | torch.manual_seed(seed) 87 | torch.cuda.set_device(device) 88 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.") 89 | 90 | # Create folder to save samples: 91 | folder_name = f"val_{args.dataset}" 92 | sample_folder_dir = f"{args.sample_dir}/{folder_name}" 93 | if rank == 0: 94 | os.makedirs(sample_folder_dir, exist_ok=True) 95 | print(f"Saving .png samples at {sample_folder_dir}") 96 | dist.barrier() 97 | 98 | # Setup data: 99 | transform = transforms.Compose([ 100 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)), 101 | transforms.ToTensor(), 102 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) 103 | ]) 104 | 105 | if args.dataset == 'imagenet': 106 | dataset = ImageFolder(args.data_path, transform=transform) 107 | num_fid_samples = 50000 108 | elif args.dataset == 'coco': 109 | dataset = SingleFolderDataset(args.data_path, transform=transform) 110 | num_fid_samples = 5000 111 | else: 112 | raise Exception("please check dataset") 113 | 114 | sampler = DistributedSampler( 115 | dataset, 116 | num_replicas=dist.get_world_size(), 117 | rank=rank, 118 | shuffle=False, 119 | seed=args.global_seed 120 | ) 121 | loader = DataLoader( 122 | dataset, 123 | batch_size=args.per_proc_batch_size, 124 | shuffle=False, 125 | sampler=sampler, 126 | num_workers=args.num_workers, 127 | pin_memory=True, 128 | drop_last=False 129 | ) 130 | 131 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run: 132 | n = args.per_proc_batch_size 133 | global_batch_size = n * dist.get_world_size() 134 | 135 | loader = tqdm(loader) if rank == 0 else loader 136 | total = 0 137 | for x, _ in loader: 138 | samples = torch.clamp(127.5 * x + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy() 139 | # Save samples to disk as individual .png files 140 | for i, sample in enumerate(samples): 141 | index = i * dist.get_world_size() + rank + total 142 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png") 143 | 144 | total += global_batch_size 145 | 146 | # Make sure all processes have finished saving their samples before attempting to convert to .npz 147 | dist.barrier() 148 | if rank == 0: 149 | create_npz_from_sample_folder(sample_folder_dir, num_fid_samples) 150 | print("Done.") 151 | dist.barrier() 152 | dist.destroy_process_group() 153 | 154 | 155 | if __name__ == 
"__main__": 156 | parser = argparse.ArgumentParser() 157 | parser.add_argument("--data-path", type=str, required=True) 158 | parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet') 159 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256) 160 | parser.add_argument("--sample-dir", type=str, default="reconstructions") 161 | parser.add_argument("--per-proc-batch-size", type=int, default=32) 162 | parser.add_argument("--global-seed", type=int, default=0) 163 | parser.add_argument("--num-workers", type=int, default=4) 164 | args = parser.parse_args() 165 | main(args) -------------------------------------------------------------------------------- /tokenizer/vqgan/README.md: -------------------------------------------------------------------------------- 1 | ## Pretrained VQVAE Models 2 | 3 | ### install 4 | ``` 5 | pip install omegaconf 6 | pip install einops 7 | ``` 8 | * download all needed models from https://github.com/CompVis/taming-transformers and put in pretrained_models/ 9 | * pip install pytorch_lightning 10 | * python3 tools/convert_pytorch_lightning_to_torch.py 11 | * pip uninstall pytorch_lightning 12 | 13 | ### demo 14 | ``` 15 | cd ${THIS_REPO_ROOT} 16 | python3 tokenizer/vqgan/taming_vqgan_demo.py 17 | ``` 18 | 19 | ### acknowledge 20 | Codes in this folder are modified from from https://github.com/CompVis/taming-transformers 21 | 22 | -------------------------------------------------------------------------------- /tokenizer/vqgan/configs/vqgan_imagenet_f16_1024.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 1024 7 | ddconfig: 8 | double_z: false 9 | z_channels: 256 10 | resolution: 256 11 | in_channels: 3 12 | out_ch: 3 13 | ch: 128 14 | ch_mult: 15 | - 1 16 | - 1 17 | - 2 18 | - 2 19 | - 4 20 | num_res_blocks: 2 21 | attn_resolutions: 22 | - 16 23 | dropout: 0.0 24 | lossconfig: 25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 26 | params: 27 | disc_conditional: false 28 | disc_in_channels: 3 29 | disc_start: 0 30 | disc_weight: 0.8 31 | codebook_weight: 1.0 32 | 33 | -------------------------------------------------------------------------------- /tokenizer/vqgan/configs/vqgan_imagenet_f16_16384.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 4.5e-06 3 | target: taming.models.vqgan.VQModel 4 | params: 5 | embed_dim: 256 6 | n_embed: 16384 7 | monitor: val/rec_loss 8 | ddconfig: 9 | double_z: false 10 | z_channels: 256 11 | resolution: 256 12 | in_channels: 3 13 | out_ch: 3 14 | ch: 128 15 | ch_mult: 16 | - 1 17 | - 1 18 | - 2 19 | - 2 20 | - 4 21 | num_res_blocks: 2 22 | attn_resolutions: 23 | - 16 24 | dropout: 0.0 25 | lossconfig: 26 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator 27 | params: 28 | disc_conditional: false 29 | disc_in_channels: 3 30 | disc_start: 0 31 | disc_weight: 0.75 32 | disc_num_layers: 2 33 | codebook_weight: 1.0 34 | 35 | -------------------------------------------------------------------------------- /tokenizer/vqgan/configs/vqgan_openimage_f8_16384.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | embed_dim: 4 4 | n_embed: 16384 5 | ddconfig: 6 | double_z: false 7 | z_channels: 4 8 | resolution: 256 9 | in_channels: 3 10 | out_ch: 3 11 | 
ch: 128 12 | ch_mult: 13 | - 1 14 | - 2 15 | - 2 16 | - 4 17 | num_res_blocks: 2 18 | attn_resolutions: 19 | - 32 20 | dropout: 0.0 -------------------------------------------------------------------------------- /tokenizer/vqgan/configs/vqgan_openimage_f8_256.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | params: 3 | embed_dim: 4 4 | n_embed: 256 5 | ddconfig: 6 | double_z: false 7 | z_channels: 4 8 | resolution: 256 9 | in_channels: 3 10 | out_ch: 3 11 | ch: 128 12 | ch_mult: 13 | - 1 14 | - 2 15 | - 2 16 | - 4 17 | num_res_blocks: 2 18 | attn_resolutions: 19 | - 32 20 | dropout: 0.0 -------------------------------------------------------------------------------- /tokenizer/vqgan/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from tokenizer.vqgan.layer import Encoder, Decoder 6 | from tokenizer.vqgan.quantize import VectorQuantizer2 as VectorQuantizer 7 | 8 | 9 | VQGAN_FROM_TAMING = { 10 | 'vqgan_imagenet_f16_1024': ( 11 | 'tokenizer/vqgan/configs/vqgan_imagenet_f16_1024.yaml', 12 | 'pretrained_models/vqgan_imagenet_f16_1024/ckpts/last.pth'), 13 | 'vqgan_imagenet_f16_16384': ( 14 | 'tokenizer/vqgan/configs/vqgan_imagenet_f16_16384.yaml', 15 | 'pretrained_models/vqgan_imagenet_f16_16384/ckpts/last.pth'), 16 | 'vqgan_openimage_f8_256': ( 17 | 'tokenizer/vqgan/configs/vqgan_openimage_f8_256.yaml', 18 | 'pretrained_models/vq-f8-n256/model.pth'), 19 | 'vqgan_openimage_f8_16384': ( 20 | 'tokenizer/vqgan/configs/vqgan_openimage_f8_16384.yaml', 21 | 'pretrained_models/vq-f8/model.pth'), 22 | } 23 | 24 | class VQModel(nn.Module): 25 | def __init__(self, 26 | ddconfig, 27 | n_embed, 28 | embed_dim, 29 | ckpt_path=None, 30 | ignore_keys=[], 31 | image_key="image", 32 | colorize_nlabels=None, 33 | monitor=None, 34 | remap=None, 35 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 36 | **kwargs, 37 | ): 38 | super().__init__() 39 | self.image_key = image_key 40 | self.encoder = Encoder(**ddconfig) 41 | self.decoder = Decoder(**ddconfig) 42 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, 43 | remap=remap, sane_index_shape=sane_index_shape) 44 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 45 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 46 | if ckpt_path is not None: 47 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 48 | self.image_key = image_key 49 | if colorize_nlabels is not None: 50 | assert type(colorize_nlabels)==int 51 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 52 | if monitor is not None: 53 | self.monitor = monitor 54 | 55 | def init_from_ckpt(self, path, ignore_keys=list(), logging=True): 56 | model_weight = torch.load(path, map_location="cpu")["state_dict"] 57 | keys = list(model_weight.keys()) 58 | for k in keys: 59 | for ik in ignore_keys: 60 | if k.startswith(ik): 61 | print("Deleting key {} from state_dict.".format(k)) 62 | del model_weight[k] 63 | missing, unexpected = self.load_state_dict(model_weight, strict=False) 64 | if logging: 65 | print(f"Restored from {path}") 66 | print(f"Missing Keys in State Dict: {missing}") 67 | print(f"Unexpected Keys in State Dict: {unexpected}") 68 | 69 | def encode(self, x): 70 | h = self.encoder(x) 71 | h = self.quant_conv(h) 72 | quant, emb_loss, info = self.quantize(h) 73 | return quant, emb_loss, info 74 | 75 | def 
decode(self, quant): 76 | quant = self.post_quant_conv(quant) 77 | dec = self.decoder(quant) 78 | return dec 79 | 80 | def decode_code(self, code_b, shape, channel_first=True): 81 | quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first) 82 | dec = self.decode(quant_b) 83 | return dec 84 | 85 | def forward(self, input): 86 | quant, diff, _ = self.encode(input) 87 | dec = self.decode(quant) 88 | return dec, diff 89 | -------------------------------------------------------------------------------- /tokenizer/vqgan/taming_vqgan_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from PIL import Image 6 | from omegaconf import OmegaConf 7 | from tokenizer.vqgan.model import VQModel 8 | from tokenizer.vqgan.model import VQGAN_FROM_TAMING 9 | 10 | # before running demo, make sure to: 11 | # (1) download all needed models from https://github.com/CompVis/taming-transformers and put in pretrained_models/ 12 | # (2) pip install pytorch_lightning 13 | # (3) python3 tools/convert_pytorch_lightning_to_torch.py 14 | # (4) pip uninstall pytorch_lightning 15 | 16 | 17 | def main(args): 18 | # Setup PyTorch: 19 | torch.manual_seed(args.seed) 20 | torch.set_grad_enabled(False) 21 | device = "cuda" if torch.cuda.is_available() else "cpu" 22 | 23 | # create and load model 24 | cfg, ckpt = VQGAN_FROM_TAMING[args.vqgan] 25 | config = OmegaConf.load(cfg) 26 | model = VQModel(**config.model.get("params", dict())) 27 | model.init_from_ckpt(ckpt) 28 | model.to(device) 29 | model.eval() 30 | 31 | # load image 32 | img_path = args.image_path 33 | out_path = args.image_path.replace('.jpg', '_vqgan.jpg').replace('.jpeg', '_vqgan.jpeg').replace('.png', '_vqgan.png') 34 | input_size = args.image_size 35 | img = Image.open(img_path).convert("RGB") 36 | 37 | # preprocess 38 | size_org = img.size 39 | img = img.resize((input_size, input_size)) 40 | img = np.array(img) / 255. 
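# (Descriptive note on the steps that follow, which mirror the repo's other tokenizer demos:
# pixels are rescaled from [0, 1] to [-1, 1], a batch dimension is added, the layout is
# changed from NHWC to NCHW, the image is encoded into discrete codebook indices, and those
# indices are decoded back to pixels, so the saved file shows pure VQ reconstruction quality.
# The index grid holds (image_size/8)^2 tokens for the f8 configs and (image_size/16)^2 for
# the f16 configs.)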
41 | x = 2.0 * img - 1.0 # x value is between [-1, 1] 42 | x = torch.tensor(x) 43 | x = x.unsqueeze(dim=0) 44 | x = torch.einsum('nhwc->nchw', x) 45 | x_input = x.float().to("cuda") 46 | 47 | # inference 48 | with torch.no_grad(): 49 | latent, _, [_, _, indices] = model.encode(x_input) 50 | output = model.decode_code(indices, latent.shape) # output value is between [-1, 1] 51 | 52 | # postprocess 53 | output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0] 54 | sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy() 55 | 56 | # save 57 | Image.fromarray(sample).save(out_path) 58 | print("Reconstructed image is saved to {}".format(out_path)) 59 | 60 | 61 | if __name__ == "__main__": 62 | parser = argparse.ArgumentParser() 63 | parser.add_argument("--image-path", type=str, default="assets/example.jpg") 64 | parser.add_argument("--vqgan", type=str, choices=list(VQGAN_FROM_TAMING.keys()), default="vqgan_openimage_f8_16384") 65 | parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512) 66 | parser.add_argument("--seed", type=int, default=0) 67 | args = parser.parse_args() 68 | main(args) 69 | -------------------------------------------------------------------------------- /tools/check_image_codes.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | import numpy as np 4 | 5 | from tokenizer.tokenizer_image.vq_model import VQ_models 6 | from torchvision.utils import save_image 7 | 8 | 9 | def main(args): 10 | # Setup PyTorch: 11 | torch.manual_seed(args.seed) 12 | torch.set_grad_enabled(False) 13 | device = "cuda" if torch.cuda.is_available() else "cpu" 14 | 15 | # create and load model 16 | vq_model = VQ_models[args.vq_model]( 17 | codebook_size=args.codebook_size, 18 | codebook_embed_dim=args.codebook_embed_dim) 19 | vq_model.to(device) 20 | vq_model.eval() 21 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu") 22 | vq_model.load_state_dict(checkpoint["model"]) 23 | del checkpoint 24 | 25 | # load image code 26 | latent_dim = args.codebook_embed_dim 27 | latent_size = args.image_size // args.downsample_size 28 | codes = torch.from_numpy(np.load(args.code_path)).to(device) 29 | if codes.ndim == 3: # flip augmentation 30 | qzshape = (codes.shape[1], latent_dim, latent_size, latent_size) 31 | else: 32 | qzshape = (1, latent_dim, latent_size, latent_size) 33 | index_sample = codes.reshape(-1) 34 | samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1] 35 | 36 | # save 37 | out_path = "sample_image_code.png" 38 | nrow = max(4, int(codes.shape[1]//2)) 39 | save_image(samples, out_path, nrow=nrow, normalize=True, value_range=(-1, 1)) 40 | print("Reconstructed image is saved to {}".format(out_path)) 41 | 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument("--code-path", type=str, required=True) 47 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") 48 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") 49 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") 50 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") 51 | parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512], default=256) 52 | 
parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) 53 | parser.add_argument("--seed", type=int, default=0) 54 | args = parser.parse_args() 55 | main(args) -------------------------------------------------------------------------------- /tools/convert_pytorch_lightning_to_torch.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | MODEL_PATH = 'pretrained_models' 5 | pt_lightnings = [ 6 | 'vqgan_imagenet_f16_1024/ckpts/last.ckpt', 7 | 'vqgan_imagenet_f16_16384/ckpts/last.ckpt', 8 | 'vq-f8-n256/model.ckpt', 9 | 'vq-f8/model.ckpt', 10 | ] 11 | pts = [ 12 | 'vqgan_imagenet_f16_1024/ckpts/last.pth', 13 | 'vqgan_imagenet_f16_16384/ckpts/last.pth', 14 | 'vq-f8-n256/model.pth', 15 | 'vq-f8/model.pth', 16 | ] 17 | 18 | for pt_l, pt in zip(pt_lightnings, pts): 19 | pt_l_weight = torch.load(os.path.join(MODEL_PATH, pt_l), map_location='cpu') 20 | pt_weight = { 21 | 'state_dict': pt_l_weight['state_dict'] 22 | } 23 | pt_path = os.path.join(MODEL_PATH, pt) 24 | torch.save(pt_weight, pt_path) 25 | print(f'saving to {pt_path}') 26 | -------------------------------------------------------------------------------- /tools/draw_figure.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | 4 | font_size = 14 5 | 6 | def fid_scaling_law_no_cfg(): 7 | # data 8 | steps = np.array([50, 100, 200, 300,]) 9 | loss_b = np.array([41.025, 33.442, 32.105, 32.196]) 10 | loss_l = np.array([25.889, 24.654, 19.742, 19.070]) 11 | loss_xl = np.array([19.820, 18.037, 14.772, 15.549]) 12 | 13 | steps_ = np.array([50, 200, 300,]) 14 | loss_xxl = np.array([17.195, 13.997, 14.648]) 15 | loss_3b = np.array([16.431, 9.949, 9.380]) 16 | # Plot 17 | plt.figure(figsize=(6, 4)) 18 | 19 | plt.plot(steps, loss_b, 'o-', label='B', color='red') 20 | plt.plot(steps, loss_l, 'o-', label='L', color='orange') 21 | plt.plot(steps, loss_xl, 'o-', label='XL', color='green') 22 | plt.plot(steps_, loss_xxl, 'o-', label='XXL', color='blue') 23 | plt.plot(steps_, loss_3b, 'o-', label='3B', color='purple') 24 | 25 | plt.xlabel('Training Epochs', fontsize=font_size) 26 | plt.ylabel('FID', fontsize=font_size) 27 | # plt.grid(True) 28 | # plt.yscale('log') 29 | 30 | # Customize the plot to match the appearance of the provided figure 31 | plt.legend(loc='upper right', framealpha=0.5, fontsize=font_size, facecolor='white') 32 | 33 | # Customizing the x and y axis ticks (to match the example's steps) 34 | # plt.xticks(np.linspace(0, 800000, 5), ['0', '200K', '400K', '600K', '800K']) 35 | plt.yticks(np.arange(5, 50, step=5)) 36 | 37 | # Show plot 38 | plt.tight_layout() 39 | plt.savefig('fid_scaling_law_no_cfg.png', dpi=600) 40 | 41 | 42 | 43 | def fid_scaling_law_cfg(): 44 | # data 45 | steps = np.array([50, 100, 200, 300,]) 46 | loss_b_cfg = np.array([8.309, 7.256, 6.542, 6.249]) 47 | loss_l_cfg = np.array([4.240, 3.705, 3.220, 3.075]) 48 | loss_xl_cfg = np.array([3.420, 3.089, 2.617, 2.629]) 49 | 50 | steps_ = np.array([50, 200, 300,]) 51 | loss_xxl_cfg = np.array([2.893, 2.331, 2.340]) 52 | loss_3b_cfg = np.array([2.611, 2.381, 2.329]) 53 | # Plot 54 | plt.figure(figsize=(6, 4)) 55 | 56 | plt.plot(steps, loss_b_cfg, 'o-', label='B', color='red') 57 | plt.plot(steps, loss_l_cfg, 'o-', label='L', color='orange') 58 | plt.plot(steps, loss_xl_cfg, 'o-', label='XL', color='green') 59 | plt.plot(steps_, loss_xxl_cfg, 'o-', label='XXL', color='blue') 60 | plt.plot(steps_, 
loss_3b_cfg, 'o-', label='3B', color='purple') 61 | 62 | plt.xlabel('Training Epochs', fontsize=font_size) 63 | plt.ylabel('FID', fontsize=font_size) 64 | # plt.grid(True) 65 | # plt.yscale('log') 66 | 67 | # Customize the plot to match the appearance of the provided figure 68 | plt.legend(loc='upper right', framealpha=0.5, fontsize=font_size, facecolor='white') 69 | 70 | # Customizing the x and y axis ticks (to match the example's steps) 71 | # plt.xticks(np.linspace(0, 800000, 5), ['0', '200K', '400K', '600K', '800K']) 72 | plt.yticks(np.arange(2, 9, step=1)) 73 | 74 | # Show plot 75 | plt.tight_layout() 76 | plt.savefig('fid_scaling_law_cfg.png', dpi=600) 77 | 78 | 79 | 80 | def sample_topk(): 81 | # Data 82 | top_k = np.array([16384, 10000, 8000, 6000, 4000, 2000, 1000]) 83 | fid_values = np.array([3.075, 3.369, 3.643, 3.969, 4.635, 5.998, 7.428]) 84 | inception_scores = np.array([256.067, 265.222, 268.237, 270.159, 271.455, 267.278, 251.268]) 85 | 86 | fig, ax1 = plt.subplots() 87 | # Create first y-axis 88 | ax1.set_xlabel('top-k', fontsize=font_size) 89 | ax1.set_ylabel('FID', color='teal', fontsize=font_size) 90 | ax1.plot(top_k, fid_values, 'o-', color='teal', label="FID") 91 | ax1.tick_params(axis='y', labelcolor='teal') 92 | ax1.tick_params(axis='x') 93 | 94 | # Create second y-axis 95 | ax2 = ax1.twinx() 96 | ax2.set_ylabel('Inception Score', color='brown', fontsize=font_size) 97 | ax2.plot(top_k, inception_scores, 'o-', color='brown', label="Inception Score") 98 | ax2.tick_params(axis='y', labelcolor='brown') 99 | 100 | # Adding a legend 101 | fig.legend(loc='upper right', bbox_to_anchor=(1.0, 1.0), bbox_transform=ax1.transAxes, fontsize=font_size) 102 | 103 | fig.tight_layout() # Adjust layout to prevent overlap 104 | plt.savefig('effect_topk.png', dpi=600) 105 | 106 | 107 | 108 | def sample_cfg(): 109 | # Data 110 | cfg = np.array([1.5, 1.75, 2.00, 2.25]) 111 | fid_values = np.array([4.743, 3.151, 3.075, 3.620]) 112 | inception_scores = np.array([165.381, 214.152, 256.067, 291.695]) 113 | 114 | plt.figure(figsize=(10, 4)) 115 | fig, ax1 = plt.subplots() 116 | # Create first y-axis 117 | ax1.set_xlabel('cfg', fontsize=font_size) 118 | ax1.set_ylabel('FID', color='teal', fontsize=font_size) 119 | ax1.plot(cfg, fid_values, 'o-', color='teal', label="FID") 120 | ax1.tick_params(axis='y', labelcolor='teal') 121 | ax1.tick_params(axis='x') 122 | 123 | # Create second y-axis 124 | ax2 = ax1.twinx() 125 | ax2.set_ylabel('Inception Score', color='brown', fontsize=font_size) 126 | ax2.plot(cfg, inception_scores, 'o-', color='brown', label="Inception Score") 127 | ax2.tick_params(axis='y', labelcolor='brown') 128 | 129 | # Adding a legend 130 | fig.legend(loc='upper right', bbox_to_anchor=(1.0, 1.0), bbox_transform=ax1.transAxes, fontsize=font_size) 131 | 132 | fig.tight_layout() # Adjust layout to prevent overlap 133 | plt.savefig('effect_cfg.png', dpi=600) 134 | 135 | 136 | 137 | if __name__ == "__main__": 138 | fid_scaling_law_no_cfg() 139 | fid_scaling_law_cfg() 140 | sample_cfg() 141 | sample_topk() 142 | -------------------------------------------------------------------------------- /tools/openimage_json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | from PIL import Image 5 | import multiprocessing as mp 6 | 7 | import warnings 8 | warnings.filterwarnings('ignore') 9 | 10 | 11 | def check_image(image_path): 12 | try: 13 | Image.open(image_path) 14 | return True 15 | except Exception as e: 
16 | print(f"Error details: {str(e)}") 17 | return False 18 | 19 | 20 | def check_image_path(image_info): 21 | data_path, image_path_list = image_info # Unpack the info 22 | valid_image_paths = [] 23 | for image_path in image_path_list: 24 | if check_image(os.path.join(data_path, image_path)): 25 | valid_image_paths.append(image_path) 26 | return valid_image_paths 27 | 28 | 29 | def load_image_path(image_info): 30 | folder_name, data_path, image_extensions = image_info # Unpack the info 31 | print(folder_name) 32 | 33 | folder_path = os.path.join(data_path, folder_name) 34 | local_image_paths = [] 35 | for image_path in os.listdir(folder_path): 36 | _, file_extension = os.path.splitext(image_path) 37 | if file_extension.lower() in image_extensions: 38 | image_path_full = os.path.join(folder_name, image_path) 39 | local_image_paths.append(image_path_full) 40 | return local_image_paths 41 | 42 | 43 | 44 | def main(args): 45 | data_path = args.data_path 46 | image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp'] 47 | 48 | num_processes = 47 49 | work_list = [('openimages_{:0>4}'.format(idx), data_path, image_extensions) for idx in range(1, 48)] 50 | with mp.Pool(processes=num_processes) as pool: 51 | results = pool.map(load_image_path, work_list) 52 | image_paths = [image_path for sublist in results for image_path in sublist] 53 | print('image_paths is loaded') 54 | 55 | 56 | num_processes = max(mp.cpu_count() // 2, 4) 57 | unit = len(image_paths) // num_processes 58 | work_list = [(data_path, image_paths[idx*unit:(idx+1)*unit]) for idx in range(num_processes)] 59 | with mp.Pool(processes=num_processes) as pool: 60 | results = pool.map(check_image_path, work_list) 61 | valid_image_paths = [image_path for sublist in results for image_path in sublist] 62 | print('image_paths is checked') 63 | 64 | 65 | output_json_file_path = os.path.join(data_path, 'image_paths.json') 66 | with open(output_json_file_path, 'w') as outfile: 67 | json.dump(valid_image_paths, outfile, indent=4) 68 | print(f"Image paths have been saved to {output_json_file_path}") 69 | 70 | 71 | if __name__ == "__main__": 72 | parser = argparse.ArgumentParser() 73 | parser.add_argument("--data-path", type=str, required=True) 74 | args = parser.parse_args() 75 | main(args) -------------------------------------------------------------------------------- /tools/push_gpt_to_hf.py: -------------------------------------------------------------------------------- 1 | # Modified from: 2 | # DiT: https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py 3 | import torch 4 | torch.backends.cuda.matmul.allow_tf32 = True 5 | torch.backends.cudnn.allow_tf32 = True 6 | import argparse 7 | 8 | from tokenizer.tokenizer_image.vq_model import VQ_models 9 | from autoregressive.models.gpt_hf import GPT_models_HF, TransformerHF 10 | 11 | device = "cuda" if torch.cuda_is_available() else "cpu" 12 | 13 | def main(args): 14 | # Setup PyTorch: 15 | assert torch.cuda.is_available(), "Sampling with DDP requires at least one GPU. 
sample.py supports CPU-only usage" 16 | torch.set_grad_enabled(False) 17 | 18 | # create and load gpt model 19 | precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision] 20 | latent_size = args.image_size // args.downsample_size 21 | gpt_model = GPT_models_HF[args.gpt_model]( 22 | vocab_size=args.codebook_size, 23 | block_size=latent_size ** 2, 24 | num_classes=args.num_classes, 25 | cls_token_num=args.cls_token_num, 26 | model_type=args.gpt_type, 27 | ).to(device=device, dtype=precision) 28 | checkpoint = torch.load(args.gpt_ckpt, map_location="cpu") 29 | if args.from_fsdp: # fsdp 30 | model_weight = checkpoint 31 | elif "model" in checkpoint: # ddp 32 | model_weight = checkpoint["model"] 33 | elif "module" in checkpoint: # deepspeed 34 | model_weight = checkpoint["module"] 35 | elif "state_dict" in checkpoint: 36 | model_weight = checkpoint["state_dict"] 37 | else: 38 | raise Exception("please check model weight, maybe add --from-fsdp to run command") 39 | 40 | # load weights 41 | gpt_model.load_state_dict(model_weight, strict=False) 42 | gpt_model.eval() 43 | del checkpoint 44 | 45 | # push to hub 46 | repo_id = f"FoundationVision/{args.gpt_model}-{args.image_size}" 47 | gpt_model.push_to_hub(repo_id) 48 | 49 | # reload 50 | model = TransformerHF.from_pretrained(repo_id) 51 | 52 | 53 | if __name__ == "__main__": 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument("--gpt-model", type=str, choices=list(GPT_models_HF.keys()), default="GPT-B") 56 | parser.add_argument("--gpt-ckpt", type=str, default=None) 57 | parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional") 58 | parser.add_argument("--from-fsdp", action='store_true') 59 | parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input") 60 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"]) 61 | parser.add_argument("--compile", action='store_true', default=True) 62 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16") 63 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model") 64 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") 65 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") 66 | parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384) 67 | parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256) 68 | parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16) 69 | parser.add_argument("--num-classes", type=int, default=1000) 70 | args = parser.parse_args() 71 | main(args) -------------------------------------------------------------------------------- /tools/push_vae_to_hf.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to push and load custom PyTorch models to/from the Hugging Face Hub.
3 | """ 4 | 5 | import argparse 6 | import torch 7 | from tokenizer.tokenizer_image.vq_model_hf import VQ_models_HF, VQModelHF 8 | 9 | from huggingface_hub import hf_hub_download 10 | 11 | 12 | model2ckpt = { 13 | "GPT-XL": ("vq_ds16_c2i.pt", "c2i_XL_384.pt", 384), 14 | "GPT-B": ("vq_ds16_c2i.pt", "c2i_B_256.pt", 256), 15 | } 16 | 17 | def load_model(args): 18 | ckpt_folder = "./" 19 | vq_ckpt, gpt_ckpt, _ = model2ckpt[args.gpt_model] 20 | hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=vq_ckpt, local_dir=ckpt_folder) 21 | hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=gpt_ckpt, local_dir=ckpt_folder) 22 | # create and load model 23 | vq_model = VQ_models_HF[args.vq_model]( 24 | codebook_size=args.codebook_size, 25 | codebook_embed_dim=args.codebook_embed_dim) 26 | vq_model.eval() 27 | checkpoint = torch.load(f"{ckpt_folder}{vq_ckpt}", map_location="cpu") 28 | vq_model.load_state_dict(checkpoint["model"]) 29 | del checkpoint 30 | print(f"image tokenizer is loaded") 31 | return vq_model 32 | 33 | 34 | parser = argparse.ArgumentParser() 35 | parser.add_argument("--gpt-model", type=str, default="GPT-XL") 36 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models_HF.keys()), default="VQ-16") 37 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization") 38 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization") 39 | args = parser.parse_args() 40 | 41 | # load weights 42 | vq_model = load_model(args) 43 | 44 | # push to hub 45 | vq_model.push_to_hub("FoundationVision/vq-ds16-c2i") 46 | 47 | # reload 48 | model = VQModelHF.from_pretrained("FoundationVision/vq-ds16-c2i") -------------------------------------------------------------------------------- /utils/data.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | 4 | def center_crop_arr(pil_image, image_size): 5 | """ 6 | Center cropping implementation from ADM. 
7 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126 8 | """ 9 | while min(*pil_image.size) >= 2 * image_size: 10 | pil_image = pil_image.resize( 11 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX 12 | ) 13 | 14 | scale = image_size / min(*pil_image.size) 15 | pil_image = pil_image.resize( 16 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC 17 | ) 18 | 19 | arr = np.array(pil_image) 20 | crop_y = (arr.shape[0] - image_size) // 2 21 | crop_x = (arr.shape[1] - image_size) // 2 22 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]) -------------------------------------------------------------------------------- /utils/deepspeed.py: -------------------------------------------------------------------------------- 1 | def create_deepspeed_config(args): 2 | ds_config = { 3 | "steps_per_print": 1000, 4 | "train_batch_size": args.global_batch_size, 5 | "gradient_accumulation_steps": args.gradient_accumulation_steps, 6 | # "train_micro_batch_size_per_gpu": args.batch_size, # determined by (train_batch_size, gradient_accumulation_steps) 7 | "optimizer": { 8 | "type": "Adam", 9 | "adam_w_mode": True, 10 | "params": { 11 | "lr": args.lr, 12 | "weight_decay": args.weight_decay, 13 | "bias_correction": True, 14 | "betas": [ 15 | args.beta1, 16 | args.beta2 17 | ], 18 | } 19 | }, 20 | "fp16": { 21 | "enabled": args.mixed_precision == 'fp16', 22 | "loss_scale": 0, 23 | "initial_scale_power": 16, 24 | "loss_scale_window": 1000, 25 | "hysteresis": 2, 26 | "min_loss_scale": 1 27 | }, 28 | "bf16": { 29 | "enabled": args.mixed_precision == 'bf16', 30 | }, 31 | # "flops_profiler": { 32 | # "enabled": True, 33 | # "profile_step": -1, 34 | # "module_depth": -1, 35 | # "top_modules": 1, 36 | # "detailed": True, 37 | # }, 38 | "zero_allow_untested_optimizer": True 39 | } 40 | 41 | if args.clip_grad is not None: 42 | ds_config.update({'gradient_clipping': args.clip_grad}) 43 | 44 | if args.zero_stage == 0: 45 | ds_config.update({"zero_optimization": 46 | { 47 | "stage": args.zero_stage, 48 | "contiguous_gradients": True, 49 | "overlap_comm": True, 50 | } 51 | }) 52 | elif args.zero_stage == 1: 53 | ds_config.update({"zero_optimization": 54 | { 55 | "stage": args.zero_stage, 56 | "contiguous_gradients": True, 57 | "overlap_comm": True, 58 | "reduce_bucket_size": 5e8, 59 | } 60 | }) 61 | elif args.zero_stage == 2: 62 | ds_config.update({"zero_optimization": 63 | { 64 | "stage": args.zero_stage, 65 | "contiguous_gradients": True, 66 | "overlap_comm": True, 67 | "reduce_scatter": True, 68 | "reduce_bucket_size": 5e8, 69 | "allgather_bucket_size": 5e8, 70 | } 71 | }) 72 | elif args.zero_stage == 3: 73 | ds_config.update({"zero_optimization": 74 | { 75 | "stage": args.zero_stage, 76 | "contiguous_gradients": True, 77 | "overlap_comm": True, 78 | "reduce_bucket_size": 5e8, 79 | "stage3_prefetch_bucket_size": 5e8, 80 | "stage3_param_persistence_threshold": 1e6, 81 | "stage3_max_live_parameters": 1e9, 82 | "stage3_max_reuse_distance": 1e9, 83 | "stage3_gather_16bit_weights_on_model_save": True 84 | } 85 | }) 86 | 87 | return ds_config 88 | -------------------------------------------------------------------------------- /utils/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import subprocess 4 | 5 | 6 | def setup_for_distributed(is_master): 7 | """ 8 | This function disables printing when not 
in master process 9 | """ 10 | import builtins as __builtin__ 11 | builtin_print = __builtin__.print 12 | 13 | def print(*args, **kwargs): 14 | force = kwargs.pop('force', False) 15 | if is_master or force: 16 | builtin_print(*args, **kwargs) 17 | 18 | __builtin__.print = print 19 | 20 | def init_distributed_mode(args): 21 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 22 | args.rank = int(os.environ["RANK"]) 23 | args.world_size = int(os.environ['WORLD_SIZE']) 24 | args.gpu = int(os.environ['LOCAL_RANK']) 25 | args.dist_url = 'env://' 26 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count()) 27 | elif 'SLURM_PROCID' in os.environ: 28 | proc_id = int(os.environ['SLURM_PROCID']) 29 | ntasks = int(os.environ['SLURM_NTASKS']) 30 | node_list = os.environ['SLURM_NODELIST'] 31 | num_gpus = torch.cuda.device_count() 32 | addr = subprocess.getoutput( 33 | 'scontrol show hostname {} | head -n1'.format(node_list)) 34 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500') 35 | os.environ['MASTER_ADDR'] = addr 36 | os.environ['WORLD_SIZE'] = str(ntasks) 37 | os.environ['RANK'] = str(proc_id) 38 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 39 | os.environ['LOCAL_SIZE'] = str(num_gpus) 40 | args.dist_url = 'env://' 41 | args.world_size = ntasks 42 | args.rank = proc_id 43 | args.gpu = proc_id % num_gpus 44 | else: 45 | print('Not using distributed mode') 46 | args.distributed = False 47 | return 48 | 49 | args.distributed = True 50 | 51 | torch.cuda.set_device(args.gpu) 52 | args.dist_backend = 'nccl' 53 | print('| distributed init (rank {}): {}'.format( 54 | args.rank, args.dist_url), flush=True) 55 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 56 | world_size=args.world_size, rank=args.rank) 57 | torch.distributed.barrier() 58 | setup_for_distributed(args.rank == 0) 59 | -------------------------------------------------------------------------------- /utils/drop_path.py: -------------------------------------------------------------------------------- 1 | # from timm.models.layers import DropPath 2 | import torch 3 | 4 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True): 5 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 6 | 7 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 8 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 9 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 10 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 11 | 'survival rate' as the argument. 12 | 13 | """ 14 | if drop_prob == 0. or not training: 15 | return x 16 | keep_prob = 1 - drop_prob 17 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets 18 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob) 19 | if keep_prob > 0.0 and scale_by_keep: 20 | random_tensor.div_(keep_prob) 21 | return x * random_tensor 22 | 23 | 24 | class DropPath(torch.nn.Module): 25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
26 | """ 27 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True): 28 | super(DropPath, self).__init__() 29 | self.drop_prob = drop_prob 30 | self.scale_by_keep = scale_by_keep 31 | 32 | def forward(self, x): 33 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) 34 | 35 | def extra_repr(self): 36 | return f'drop_prob={round(self.drop_prob,3):0.3f}' -------------------------------------------------------------------------------- /utils/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | 4 | @torch.no_grad() 5 | def update_ema(ema_model, model, decay=0.9999): 6 | """ 7 | Step the EMA model towards the current model. 8 | """ 9 | ema_params = OrderedDict(ema_model.named_parameters()) 10 | model_params = OrderedDict(model.named_parameters()) 11 | 12 | for name, param in model_params.items(): 13 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed 14 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) 15 | 16 | 17 | def requires_grad(model, flag=True): 18 | """ 19 | Set requires_grad flag for all parameters in a model. 20 | """ 21 | for p in model.parameters(): 22 | p.requires_grad = flag -------------------------------------------------------------------------------- /utils/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | def create_logger(logging_dir): 5 | """ 6 | Create a logger that writes to a log file and stdout. 7 | """ 8 | if dist.get_rank() == 0: # real logger 9 | logging.basicConfig( 10 | level=logging.INFO, 11 | format='[\033[34m%(asctime)s\033[0m] %(message)s', 12 | datefmt='%Y-%m-%d %H:%M:%S', 13 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] 14 | ) 15 | logger = logging.getLogger(__name__) 16 | else: # dummy logger (does nothing) 17 | logger = logging.getLogger(__name__) 18 | logger.addHandler(logging.NullHandler()) 19 | return logger -------------------------------------------------------------------------------- /utils/video.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import skvideo.io 4 | from PIL import Image 5 | 6 | # Shifts src_tf dim to dest dim 7 | # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) 8 | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): 9 | n_dims = len(x.shape) 10 | if src_dim < 0: 11 | src_dim = n_dims + src_dim 12 | if dest_dim < 0: 13 | dest_dim = n_dims + dest_dim 14 | 15 | assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims 16 | 17 | dims = list(range(n_dims)) 18 | del dims[src_dim] 19 | 20 | permutation = [] 21 | ctr = 0 22 | for i in range(n_dims): 23 | if i == dest_dim: 24 | permutation.append(src_dim) 25 | else: 26 | permutation.append(dims[ctr]) 27 | ctr += 1 28 | x = x.permute(permutation) 29 | if make_contiguous: 30 | x = x.contiguous() 31 | return x 32 | 33 | # reshapes tensor start from dim i (inclusive) 34 | # to dim j (exclusive) to the desired shape 35 | # e.g. 
if x.shape = (b, thw, c) then 36 | # view_range(x, 1, 2, (t, h, w)) returns 37 | # x of shape (b, t, h, w, c) 38 | def view_range(x, i, j, shape): 39 | shape = tuple(shape) 40 | 41 | n_dims = len(x.shape) 42 | if i < 0: 43 | i = n_dims + i 44 | 45 | if j is None: 46 | j = n_dims 47 | elif j < 0: 48 | j = n_dims + j 49 | 50 | assert 0 <= i < j <= n_dims 51 | 52 | x_shape = x.shape 53 | target_shape = x_shape[:i] + shape + x_shape[j:] 54 | return x.view(target_shape) 55 | 56 | 57 | def tensor_slice(x, begin, size): 58 | assert all([b >= 0 for b in begin]) 59 | size = [l - b if s == -1 else s 60 | for s, b, l in zip(size, begin, x.shape)] 61 | assert all([s >= 0 for s in size]) 62 | 63 | slices = [slice(b, b + s) for b, s in zip(begin, size)] 64 | return x[slices] 65 | 66 | 67 | def save_video_grid(video, fname, nrow=None, fps=5): 68 | b, c, t, h, w = video.shape 69 | video = video.permute(0, 2, 3, 4, 1) 70 | video = (video.cpu().numpy() * 255).astype('uint8') 71 | 72 | if nrow is None: 73 | nrow = math.ceil(math.sqrt(b)) 74 | ncol = math.ceil(b / nrow) 75 | padding = 1 76 | video_grid = np.zeros((t, (padding + h) * nrow + padding, 77 | (padding + w) * ncol + padding, c), dtype='uint8') 78 | for i in range(b): 79 | r = i // ncol 80 | c = i % ncol 81 | 82 | start_r = (padding + h) * r 83 | start_c = (padding + w) * c 84 | video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] 85 | 86 | skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '{}'.format(fps)}) 87 | 88 | 89 | def save_gif_grid(video, file_name, nrow=None, fps=5): 90 | b, c, t, h, w = video.shape 91 | video = video.permute(0, 2, 3, 4, 1) 92 | video = (video.cpu().numpy() * 255).astype('uint8') 93 | 94 | if nrow is None: 95 | nrow = math.ceil(math.sqrt(b)) 96 | ncol = math.ceil(b / nrow) 97 | padding = 1 98 | video_grid = np.zeros((t, (padding + h) * nrow + padding, 99 | (padding + w) * ncol + padding, c), dtype='uint8') 100 | for i in range(b): 101 | r = i // ncol 102 | c = i % ncol 103 | 104 | start_r = (padding + h) * r 105 | start_c = (padding + w) * c 106 | video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] 107 | 108 | images = [] 109 | for frame in video_grid: 110 | images.append(Image.fromarray(frame)) 111 | 112 | # Save the first image and append the rest of the images as frames in the GIF 113 | images[0].save(file_name, save_all=True, append_images=images[1:], optimize=False, duration=int(1000/fps), loop=0) 114 | 115 | # The 'duration' parameter defines the display time for each frame in milliseconds 116 | # The 'loop' parameter defines the number of loops the GIF should make (0 for infinite loop) 117 | --------------------------------------------------------------------------------
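A minimal usage sketch for the video-grid helpers in utils/video.py, assuming the repository root is on PYTHONPATH; the tensor shape and output filename below are illustrative only. save_gif_grid needs only PIL and NumPy, whereas save_video_grid additionally relies on scikit-video (skvideo), which expects an FFmpeg backend.

import torch
from utils.video import save_gif_grid

# 4 clips x 3 channels x 8 frames x 64x64 pixels, with values in [0, 1]
# (save_gif_grid multiplies by 255 and casts to uint8 internally).
video = torch.rand(4, 3, 8, 64, 64)
save_gif_grid(video, "demo_grid.gif", nrow=2, fps=5)  # writes a 2x2 animated grid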