├── .gitignore
├── GETTING_STARTED.md
├── LICENSE
├── README.md
├── app.py
├── assets
│   └── teaser.jpg
├── autoregressive
│   ├── models
│   │   ├── generate.py
│   │   ├── gpt.py
│   │   └── gpt_hf.py
│   ├── sample
│   │   ├── sample_c2i.py
│   │   ├── sample_c2i_ddp.py
│   │   ├── sample_t2i.py
│   │   └── sample_t2i_ddp.py
│   ├── serve
│   │   ├── README.md
│   │   ├── fake_json
│   │   │   ├── GPT-3B.json
│   │   │   ├── GPT-B.json
│   │   │   ├── GPT-L.json
│   │   │   ├── GPT-XL.json
│   │   │   └── GPT-XXL.json
│   │   ├── gpt_model.py
│   │   ├── gpu_executor.py
│   │   ├── llm.py
│   │   ├── llm_engine.py
│   │   ├── model_runner.py
│   │   ├── sample_c2i.py
│   │   ├── sampler.py
│   │   └── worker.py
│   └── train
│       ├── extract_codes_c2i.py
│       ├── extract_codes_t2i.py
│       ├── train_c2i.py
│       ├── train_c2i_fsdp.py
│       └── train_t2i.py
├── dataset
│   ├── augmentation.py
│   ├── build.py
│   ├── coco.py
│   ├── imagenet.py
│   ├── openimage.py
│   ├── pexels.py
│   └── t2i.py
├── evaluations
│   ├── c2i
│   │   ├── README.md
│   │   └── evaluator.py
│   └── t2i
│       ├── PartiPrompts.tsv
│       ├── README.md
│       ├── coco_captions.csv
│       └── evaluation.py
├── language
│   ├── README.md
│   ├── extract_t5_feature.py
│   └── t5.py
├── requirements.txt
├── scripts
│   ├── autoregressive
│   │   ├── extract_codes_c2i.sh
│   │   ├── sample_c2i.sh
│   │   ├── sample_t2i_coco.sh
│   │   ├── sample_t2i_parti.sh
│   │   ├── train_c2i.sh
│   │   ├── train_c2i_fsdp.sh
│   │   ├── train_t2i_stage1.sh
│   │   └── train_t2i_stage2.sh
│   ├── language
│   │   ├── extract_flan_t5_feat_laion_coco_stage1.sh
│   │   ├── extract_flan_t5_feat_stage2.sh
│   │   └── extract_flan_t5_feat_trunc_stage2.sh
│   └── tokenizer
│       ├── reconstruction_consistency_decoder.sh
│       ├── reconstruction_vae.sh
│       ├── reconstruction_vq.sh
│       ├── reconstruction_vqgan.sh
│       ├── train_vq.sh
│       ├── train_vq_finetune.sh
│       ├── train_vq_finetune_continue.sh
│       └── val.sh
├── tokenizer
│   ├── consistencydecoder
│   │   ├── README.md
│   │   ├── cd_demo.py
│   │   └── reconstruction_cd_ddp.py
│   ├── tokenizer_image
│   │   ├── cache
│   │   │   └── vgg.pth
│   │   ├── discriminator.py
│   │   ├── discriminator_patchgan.py
│   │   ├── discriminator_stylegan.py
│   │   ├── lpips.py
│   │   ├── reconstruction_vq_ddp.py
│   │   ├── vq_demo.py
│   │   ├── vq_loss.py
│   │   ├── vq_model.py
│   │   ├── vq_model_hf.py
│   │   └── vq_train.py
│   ├── vae
│   │   ├── README.md
│   │   ├── reconstruction_vae_ddp.py
│   │   └── sd_vae_demo.py
│   ├── validation
│   │   └── val_ddp.py
│   └── vqgan
│       ├── README.md
│       ├── configs
│       │   ├── vqgan_imagenet_f16_1024.yaml
│       │   ├── vqgan_imagenet_f16_16384.yaml
│       │   ├── vqgan_openimage_f8_16384.yaml
│       │   └── vqgan_openimage_f8_256.yaml
│       ├── layer.py
│       ├── model.py
│       ├── quantize.py
│       ├── reconstruction_vqgan_ddp.py
│       └── taming_vqgan_demo.py
├── tools
│   ├── check_image_codes.py
│   ├── convert_pytorch_lightning_to_torch.py
│   ├── draw_figure.py
│   ├── imagenet_en_cn.py
│   ├── openimage_json.py
│   ├── push_gpt_to_hf.py
│   └── push_vae_to_hf.py
└── utils
    ├── data.py
    ├── deepspeed.py
    ├── distributed.py
    ├── drop_path.py
    ├── ema.py
    ├── logger.py
    └── video.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110 | .pdm.toml
111 | .pdm-python
112 | .pdm-build/
113 |
114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115 | __pypackages__/
116 |
117 | # Celery stuff
118 | celerybeat-schedule
119 | celerybeat.pid
120 |
121 | # SageMath parsed files
122 | *.sage.py
123 |
124 | # Environments
125 | .env
126 | .venv
127 | env/
128 | venv/
129 | ENV/
130 | env.bak/
131 | venv.bak/
132 |
133 | # Spyder project settings
134 | .spyderproject
135 | .spyproject
136 |
137 | # Rope project settings
138 | .ropeproject
139 |
140 | # mkdocs documentation
141 | /site
142 |
143 | # mypy
144 | .mypy_cache/
145 | .dmypy.json
146 | dmypy.json
147 |
148 | # Pyre type checker
149 | .pyre/
150 |
151 | # pytype static type analyzer
152 | .pytype/
153 |
154 | # Cython debug symbols
155 | cython_debug/
156 |
157 | # PyCharm
158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160 | # and can be added to the global gitignore or merged into this file. For a more nuclear
161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162 | #.idea/
163 |
--------------------------------------------------------------------------------
/GETTING_STARTED.md:
--------------------------------------------------------------------------------
1 | ## Getting Started
2 | ### Requirements
3 | - Linux with Python ≥ 3.7
4 | - PyTorch ≥ 2.1
5 | - A100 GPUs
6 |
7 | ### Train VQVAE models
8 | ```
9 | bash scripts/tokenizer/train_vq.sh --cloud-save-path /path/to/cloud_disk --data-path /path/to/imagenet/train --image-size 256 --vq-model VQ-16
10 | ```
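For reference, the `VQ-16` tokenizer trained by this script is the ~72M-parameter model listed in the README. A quick sketch for instantiating it and checking the parameter count (constructor arguments mirror those used in `autoregressive/sample/sample_c2i.py`):

```python
from tokenizer.tokenizer_image.vq_model import VQ_models

# VQ-16: downsample ratio 16, 16384-entry codebook with 8-dim code embeddings
vq_model = VQ_models["VQ-16"](codebook_size=16384, codebook_embed_dim=8)
num_params = sum(p.numel() for p in vq_model.parameters())
print(f"VQ-16: {num_params / 1e6:.0f}M parameters")  # roughly 72M per the README table
```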
11 |
12 |
13 | ### Pre-extract discrete codes of training images
14 | ```
15 | bash scripts/autoregressive/extract_codes_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --data-path /path/to/imagenet/train --code-path /path/to/imagenet_code_c2i_flip_ten_crop --ten-crop --crop-range 1.1 --image-size 384
16 | ```
17 | and/or
18 | ```
19 | bash scripts/autoregressive/extract_codes_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --data-path /path/to/imagenet/train --code-path /path/to/imagenet_code_c2i_flip_ten_crop_105 --ten-crop --crop-range 1.05 --image-size 384
20 | ```
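The extracted codes can be sanity-checked before training. A minimal sketch, assuming each sample ends up as a `.npy` array of token indices (the path below is a placeholder; see `autoregressive/train/extract_codes_c2i.py` and `tools/check_image_codes.py` for the exact on-disk layout):

```python
import numpy as np

# Placeholder path to one code file produced by extract_codes_c2i.sh
codes = np.load("/path/to/imagenet_code_c2i_flip_ten_crop/some_sample.npy")

# For --image-size 384 with the ds16 tokenizer, each crop/flip view should hold
# (384 // 16) ** 2 = 576 token indices, each in [0, 16384).
print(codes.shape, codes.min(), codes.max())
```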
21 |
22 |
23 | ### Train AR models with DDP
24 | Before running, please change `nnodes, nproc_per_node, node_rank, master_addr, master_port` in the `.sh` script.
25 | ```
26 | bash scripts/autoregressive/train_c2i.sh --cloud-save-path /path/to/cloud_disk --code-path /path/to/imagenet_code_c2i_flip_ten_crop --image-size 384 --gpt-model GPT-B
27 |
28 | bash scripts/autoregressive/train_c2i.sh --cloud-save-path /path/to/cloud_disk --code-path /path/to/imagenet_code_c2i_flip_ten_crop --image-size 384 --gpt-model GPT-L
29 |
30 | bash scripts/autoregressive/train_c2i.sh --cloud-save-path /path/to/cloud_disk --code-path /path/to/imagenet_code_c2i_flip_ten_crop --image-size 384 --gpt-model GPT-XL
31 | ```
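The `--gpt-model` flag selects one of the configurations in `autoregressive/models/gpt.py`. A quick sketch for checking a configuration's parameter count before launching a run (constructor arguments mirror `autoregressive/sample/sample_c2i.py`):

```python
from autoregressive.models.gpt import GPT_models

# GPT-B for 384x384 class-conditional generation: 24x24 = 576 image tokens per sample
gpt_model = GPT_models["GPT-B"](
    vocab_size=16384,    # VQ codebook size
    block_size=24 * 24,  # number of image tokens
    num_classes=1000,    # ImageNet classes
    cls_token_num=1,     # single class token for c2i
    model_type="c2i",
)
num_params = sum(p.numel() for p in gpt_model.parameters())
print(f"GPT-B: {num_params / 1e6:.0f}M parameters")  # roughly 111M
```

The same check applies to the larger FSDP-trained configurations (GPT-XXL, GPT-3B) below.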
32 |
33 |
34 | ### Train AR models with FSDP
35 | Before running, please change `nnodes, nproc_per_node, node_rank, master_addr, master_port` in the `.sh` script.
36 | ```
37 | bash scripts/autoregressive/train_c2i_fsdp.sh --cloud-save-path /path/to/cloud_disk --code-path /path/to/imagenet_code_c2i_flip_ten_crop --image-size 384 --gpt-model GPT-XXL
38 |
39 | bash scripts/autoregressive/train_c2i_fsdp.sh --cloud-save-path /path/to/cloud_disk --code-path /path/to/imagenet_code_c2i_flip_ten_crop --image-size 384 --gpt-model GPT-3B
40 | ```
41 |
42 |
43 | ### Sampling
44 | ```
45 | bash scripts/autoregressive/sample_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_B.pt --gpt-model GPT-B --image-size 384 --image-size-eval 256 --cfg-scale 2.0
46 |
47 | bash scripts/autoregressive/sample_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_L.pt --gpt-model GPT-L --image-size 384 --image-size-eval 256 --cfg-scale 2.0
48 |
49 | bash scripts/autoregressive/sample_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XL.pt --gpt-model GPT-XL --image-size 384 --image-size-eval 256 --cfg-scale 1.75
50 |
51 | bash scripts/autoregressive/sample_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL.pt --gpt-model GPT-XXL --from-fsdp --image-size 384 --image-size-eval 256 --cfg-scale 1.75
52 |
53 | bash scripts/autoregressive/sample_c2i.sh --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_3B.pt --gpt-model GPT-3B --from-fsdp --image-size 384 --image-size-eval 256 --cfg-scale 1.65
54 | ```
55 |
56 |
57 | ### Evaluation
58 | Before evaluation, please refer to the [evaluation readme](evaluations/c2i/README.md) to install the required packages.
59 | ```
60 | python3 evaluations/c2i/evaluator.py VIRTUAL_imagenet256_labeled.npz samples/GPT-B-c2i_B-size-384-size-256-VQ-16-topk-0-topp-1.0-temperature-1.0-cfg-2.0-seed-0.npz
61 | ```
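The `.npz` argument is the file written by the DDP sampling script (`autoregressive/sample/sample_c2i_ddp.py`). A minimal sketch for inspecting it, assuming the standard ADM-evaluator layout of a single uint8 image array stored under `arr_0`:

```python
import numpy as np

data = np.load("samples/GPT-B-c2i_B-size-384-size-256-VQ-16-topk-0-topp-1.0-temperature-1.0-cfg-2.0-seed-0.npz")
images = data["arr_0"]
print(images.shape, images.dtype)  # e.g. (50000, 256, 256, 3) uint8 for a standard 50k-sample FID run
```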
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 FoundationVision
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Autoregressive Model Beats Diffusion: 🦙 Llama for Scalable Image Generation
2 |
3 |
4 |
5 |
6 | [demo](https://huggingface.co/spaces/FoundationVision/LlamaGen)
7 | [arXiv](https://arxiv.org/abs/2406.06525)
8 | [project page](https://peizesun.github.io/llamagen/)
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | This repo contains pre-trained model weights and training/sampling PyTorch (torch>=2.1.0) code used in
20 |
21 | > [**Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation**](https://arxiv.org/abs/2406.06525)
22 | > [Peize Sun](https://peizesun.github.io/), [Yi Jiang](https://enjoyyi.github.io/), [Shoufa Chen](https://www.shoufachen.com/), [Shilong Zhang](https://jshilong.github.io/), [Bingyue Peng](), [Ping Luo](http://luoping.me/), [Zehuan Yuan](https://shallowyuan.github.io/)
23 | > HKU, ByteDance
24 |
25 | You can find more visualizations on the [project page](https://peizesun.github.io/llamagen/).
26 |
27 | ## 🔥 Update
28 | - [2024.06.28] Image tokenizers and AR models for text-conditional image generation are released! Try it!
29 | - [2024.06.15] All models ranging from 100M to 3B parameters are supported by vLLM!
30 | - [2024.06.11] Image tokenizers and AR models for class-conditional image generation are released!
31 | - [2024.06.11] Code and demo are released!
32 |
33 | ## 🌿 Introduction
34 | We introduce LlamaGen, a new family of image generation models that apply the original ``next-token prediction`` paradigm of large language models to the visual generation domain. It is an affirmative answer to whether vanilla autoregressive models, e.g., Llama, ``without inductive biases`` on visual signals, can achieve state-of-the-art image generation performance when scaled properly. We reexamine the design space of image tokenizers, the scalability properties of image generation models, and their training data quality.
35 |
36 | In this repo, we release:
37 | * Two image tokenizers with downsample ratios of 16 and 8.
38 | * Seven class-conditional generation models ranging from 100M to 3B parameters.
39 | * Two text-conditional generation models of 700M parameters.
40 | * Online demos on [Hugging Face Spaces](https://huggingface.co/spaces/FoundationVision/LlamaGen) for running pre-trained models.
41 | * Support for the vLLM serving framework, enabling a 300%-400% speedup.
42 |
43 | ## 🦄 Class-conditional image generation on ImageNet
44 | ### VQ-VAE models
45 | Method | params | tokens | rFID (256x256) | weight
46 | --- |:---:|:---:|:---:|:---:
47 | vq_ds16_c2i | 72M | 16x16 | 2.19 | [vq_ds16_c2i.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds16_c2i.pt)
48 | vq_ds16_c2i | 72M | 24x24 | 0.94 | above
49 | vq_ds16_c2i | 72M | 32x32 | 0.70 | above
50 | vq_ds8_c2i | 70M | 32x32 | 0.59 | [vq_ds8_c2i.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/vq_ds8_c2i.pt)
51 |
52 | ### AR models
53 | Method | params | training | tokens | FID (256x256) | weight
54 | --- |:---:|:---:|:---:|:---:|:---:|
55 | LlamaGen-B | 111M | DDP | 16x16 | 5.46 | [c2i_B_256.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_B_256.pt)
56 | LlamaGen-B | 111M | DDP | 24x24 | 6.09 | [c2i_B_384.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_B_384.pt)
57 | LlamaGen-L | 343M | DDP | 16x16 | 3.80 | [c2i_L_256.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_L_256.pt)
58 | LlamaGen-L | 343M | DDP | 24x24 | 3.07 | [c2i_L_384.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_L_384.pt)
59 | LlamaGen-XL | 775M | DDP | 24x24 | 2.62 | [c2i_XL_384.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_XL_384.pt)
60 | LlamaGen-XXL | 1.4B | FSDP | 24x24 | 2.34 | [c2i_XXL_384.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_XXL_384.pt)
61 | LlamaGen-3B | 3.1B | FSDP | 24x24 | 2.18 | [c2i_3B_384.pt](https://huggingface.co/FoundationVision/LlamaGen/resolve/main/c2i_3B_384.pt)
62 |
63 |
64 | ### Demo
65 | Please download models, put them in the folder `./pretrained_models`, and run
66 | ```
67 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_L_384.pt --gpt-model GPT-L --image-size 384
68 | # or
69 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL_384.pt --gpt-model GPT-XXL --from-fsdp --image-size 384
70 | ```
71 | The generated images will be saved to `sample_c2i.png`.
72 |
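Instead of downloading the checkpoints by hand, they can be fetched programmatically from the Hub, mirroring what `app.py` does (a minimal sketch; the filenames are the ones listed in the tables above):

```python
from huggingface_hub import hf_hub_download

# Fetch the VQ tokenizer and a class-conditional GPT checkpoint into ./pretrained_models
hf_hub_download(repo_id="FoundationVision/LlamaGen", filename="vq_ds16_c2i.pt", local_dir="./pretrained_models")
hf_hub_download(repo_id="FoundationVision/LlamaGen", filename="c2i_L_384.pt", local_dir="./pretrained_models")
```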
73 | ### Gradio Demo
74 |
75 | You can use our online Gradio demo on [Hugging Face Spaces](https://huggingface.co/spaces/FoundationVision/LlamaGen) or run Gradio locally:
76 | ```bash
77 | python app.py
78 | ```
79 |
80 |
81 | ## 🚀 Text-conditional image generation
82 | ### VQ-VAE models
83 | Method | params | tokens | data | weight
84 | --- |:---:|:---:|:---:|:---:
85 | vq_ds16_t2i | 72M | 16x16 | LAION COCO (50M) + internal data (10M) | [vq_ds16_t2i.pt](https://huggingface.co/peizesun/llamagen_t2i/resolve/main/vq_ds16_t2i.pt)
86 |
87 | ### AR models
88 | Method | params | tokens | data | weight
89 | --- |:---:|:---:|:---:|:---:
90 | LlamaGen-XL | 775M | 16x16 | LAION COCO (50M) | [t2i_XL_stage1_256.pt](https://huggingface.co/peizesun/llamagen_t2i/resolve/main/t2i_XL_stage1_256.pt)
91 | LlamaGen-XL | 775M | 32x32 | internal data (10M) | [t2i_XL_stage2_512.pt](https://huggingface.co/peizesun/llamagen_t2i/resolve/main/t2i_XL_stage2_512.pt)
92 |
93 | ### Demo
94 | Before running the demo, please refer to the [language readme](language/README.md) to install the required packages and language models.
95 |
96 | Please download models, put them in the folder `./pretrained_models`, and run
97 | ```
98 | python3 autoregressive/sample/sample_t2i.py --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt --gpt-ckpt ./pretrained_models/t2i_XL_stage1_256.pt --gpt-model GPT-XL --image-size 256
99 | # or
100 | python3 autoregressive/sample/sample_t2i.py --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt --gpt-ckpt ./pretrained_models/t2i_XL_stage2_512.pt --gpt-model GPT-XL --image-size 512
101 | ```
102 | The generated images will be saved to `sample_t2i.png`.
103 |
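Under the hood, `sample_t2i.py` first encodes the prompts with a frozen Flan-T5 encoder. A sketch of that step, mirroring the `T5Embedder` usage in `autoregressive/sample/sample_t2i.py` (the cache path is a placeholder, and the T5 weights must already be downloaded per the language readme):

```python
import torch
from language.t5 import T5Embedder

t5_model = T5Embedder(
    device="cuda",
    local_cache=True,
    cache_dir="pretrained_models/t5-ckpt",  # placeholder path
    dir_or_name="flan-t5-xl",
    torch_dtype=torch.bfloat16,
    model_max_length=120,
)
caption_embs, emb_masks = t5_model.get_text_embeddings(
    ["A blue Porsche 356 parked in front of a yellow brick wall."]
)
print(caption_embs.shape)  # expected: (1, 120, 2048), i.e. 120 tokens of 2048-dim T5-XL features
```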
104 | ### Local Gradio Demo
105 |
106 |
107 |
108 | ## ⚡ Serving
109 | We use the serving framework [vLLM](https://github.com/vllm-project/vllm) to achieve higher throughput. Please refer to the [serving readme](autoregressive/serve/README.md) to install the required packages.
110 | ```
111 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL_384.pt --gpt-model GPT-XXL --from-fsdp --image-size 384
112 | ```
113 | The generated images will be saved to `sample_c2i_vllm.png`.
114 |
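Internally, the serving path wraps the GPT in vLLM's engine and draws image tokens with a standard `SamplingParams` object. A sketch of the sampling call, assuming `llm` was constructed as in `app.py` (via `autoregressive.serve.llm.LLM`):

```python
from vllm import SamplingParams

# 24x24 = 576 image tokens for a 384x384 image with the ds16 tokenizer
sampling_params = SamplingParams(temperature=1.0, top_p=1.0, top_k=4000, max_tokens=24 * 24)

# One class index per requested image; when cfg_scale > 1, app.py additionally appends
# the null-class prompt [1000] for each request to form the unconditional branch.
outputs = llm.generate(prompt_token_ids=[[207], [360]], sampling_params=sampling_params, use_tqdm=False)
```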
115 |
116 | ## Getting Started
117 | See [Getting Started](GETTING_STARTED.md) for installation, training and evaluation.
118 |
119 |
120 | ## License
121 | The majority of this project is licensed under the MIT License. Portions of the project are available under the separate licenses of the referenced projects, as detailed in the corresponding files.
122 |
123 |
124 | ## BibTeX
125 | ```bibtex
126 | @article{sun2024autoregressive,
127 | title={Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation},
128 | author={Sun, Peize and Jiang, Yi and Chen, Shoufa and Zhang, Shilong and Peng, Bingyue and Luo, Ping and Yuan, Zehuan},
129 | journal={arXiv preprint arXiv:2406.06525},
130 | year={2024}
131 | }
132 | ```
133 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import gradio as gr
3 | from tools.imagenet_en_cn import IMAGENET_1K_CLASSES
4 | from huggingface_hub import hf_hub_download
5 | import torch
6 | torch.backends.cuda.matmul.allow_tf32 = True
7 | torch.backends.cudnn.allow_tf32 = True
8 | torch.set_float32_matmul_precision('high')
9 | setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
10 | setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
11 | from vllm import SamplingParams
12 | import time
13 | import argparse
14 | from tokenizer.tokenizer_image.vq_model import VQ_models
15 | from autoregressive.serve.llm import LLM
16 | from autoregressive.serve.sampler import Sampler
17 |
18 | device = "cuda"
19 |
20 | model2ckpt = {
21 | "GPT-XL": ("vq_ds16_c2i.pt", "c2i_XL_384.pt", 384),
22 | "GPT-B": ("vq_ds16_c2i.pt", "c2i_B_256.pt", 256),
23 | }
24 |
25 | def load_model(args):
26 | ckpt_folder = "./"
27 | vq_ckpt, gpt_ckpt, image_size = model2ckpt[args.gpt_model]
28 | hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=vq_ckpt, local_dir=ckpt_folder)
29 | hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=gpt_ckpt, local_dir=ckpt_folder)
30 | # create and load model
31 | vq_model = VQ_models[args.vq_model](
32 | codebook_size=args.codebook_size,
33 | codebook_embed_dim=args.codebook_embed_dim)
34 | vq_model.to(device)
35 | vq_model.eval()
36 | checkpoint = torch.load(f"{ckpt_folder}{vq_ckpt}", map_location="cpu")
37 | vq_model.load_state_dict(checkpoint["model"])
38 | del checkpoint
39 | print(f"image tokenizer is loaded")
40 |
41 | # Create an LLM.
42 | args.image_size = image_size
43 | args.gpt_ckpt = f"{ckpt_folder}{gpt_ckpt}"
44 | llm = LLM(
45 | args=args,
46 | model='serve/fake_json/{}.json'.format(args.gpt_model),
47 | gpu_memory_utilization=0.6,
48 | skip_tokenizer_init=True)
49 | print(f"gpt model is loaded")
50 | return vq_model, llm, image_size
51 |
52 |
53 | def infer(cfg_scale, top_k, top_p, temperature, class_label, seed):
54 | llm.llm_engine.model_executor.driver_worker.model_runner.model.sampler = Sampler(cfg_scale)
55 | args.cfg_scale = cfg_scale
56 | n = 4
57 | latent_size = image_size // args.downsample_size
58 | # Labels to condition the model with (feel free to change):
59 | class_labels = [class_label for _ in range(n)]
60 | qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
61 |
62 | prompt_token_ids = [[cind] for cind in class_labels]
63 | if cfg_scale > 1.0:
64 | prompt_token_ids.extend([[args.num_classes] for _ in range(len(prompt_token_ids))])
65 |
66 | # Create a sampling params object.
67 | sampling_params = SamplingParams(
68 | temperature=temperature, top_p=top_p, top_k=top_k,
69 | max_tokens=latent_size ** 2)
70 |
71 | t1 = time.time()
72 | torch.manual_seed(seed)
73 | outputs = llm.generate(
74 | prompt_token_ids=prompt_token_ids,
75 | sampling_params=sampling_params,
76 | use_tqdm=False)
77 | sampling_time = time.time() - t1
78 | print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
79 |
80 | index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
81 | if cfg_scale > 1.0:
82 | index_sample = index_sample[:len(class_labels)]
83 | t2 = time.time()
84 | samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
85 | decoder_time = time.time() - t2
86 | print(f"decoder takes about {decoder_time:.2f} seconds.")
87 | # Convert to PIL.Image format:
88 | samples = samples.mul(127.5).add_(128.0).clamp_(0, 255).permute(0, 2, 3, 1).to("cpu", torch.uint8).numpy()
89 | samples = [Image.fromarray(sample) for sample in samples]
90 | return samples
91 |
92 |
93 | parser = argparse.ArgumentParser()
94 | parser.add_argument("--gpt-model", type=str, default="GPT-XL")
95 | parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
96 | parser.add_argument("--from-fsdp", action='store_true')
97 | parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
98 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
99 | parser.add_argument("--compile", action='store_true', default=False)
100 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
101 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
102 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
103 | parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
104 | parser.add_argument("--num-classes", type=int, default=1000)
105 | parser.add_argument("--cfg-scale", type=float, default=4.0)
106 | parser.add_argument("--cfg-interval", type=float, default=-1)
107 | parser.add_argument("--seed", type=int, default=0)
108 | parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with")
109 | parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
110 | parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
111 | args = parser.parse_args()
112 |
113 | vq_model, llm, image_size = load_model(args)
114 |
115 | with gr.Blocks() as demo:
116 | gr.Markdown("Autoregressive Model Beats Diffusion: Llama for Scalable Image Generation")
117 |
118 | with gr.Tabs():
119 | with gr.TabItem('Generate'):
120 | with gr.Row():
121 | with gr.Column():
122 | with gr.Row():
123 | i1k_class = gr.Dropdown(
124 | list(IMAGENET_1K_CLASSES.values()),
125 | value='llama [羊驼]',
126 | type="index", label='ImageNet-1K Class'
127 | )
128 | cfg_scale = gr.Slider(minimum=1, maximum=25, step=0.1, value=4.0, label='Classifier-free Guidance Scale')
129 | top_k = gr.Slider(minimum=1, maximum=16384, step=1, value=4000, label='Top-K')
130 | top_p = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label="Top-P")
131 | temperature = gr.Slider(minimum=0., maximum=1.0, step=0.1, value=1.0, label='Temperature')
132 | seed = gr.Slider(minimum=0, maximum=1000, step=1, value=42, label='Seed')
133 | # seed = gr.Number(value=0, label='Seed')
134 | button = gr.Button("Generate", variant="primary")
135 | with gr.Column():
136 | output = gr.Gallery(label='Generated Images', height=700)
137 | button.click(infer, inputs=[cfg_scale, top_k, top_p, temperature, i1k_class, seed], outputs=[output])
138 | demo.queue()
139 | demo.launch(debug=True)
140 |
--------------------------------------------------------------------------------
/assets/teaser.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/LlamaGen/ce98ec41803a74a90ce68c40ababa9eaeffeb4ec/assets/teaser.jpg
--------------------------------------------------------------------------------
/autoregressive/models/generate.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
3 | # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4 | import torch
5 | import torch.nn as nn
6 | from torch.nn import functional as F
7 | import torch._dynamo.config
8 | import torch._inductor.config
9 | import copy
10 | # torch._inductor.config.coordinate_descent_tuning = True
11 | # torch._inductor.config.triton.unique_kernel_names = True
12 | # torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
13 |
14 |
15 | ### from https://huggingface.co/transformers/v3.2.0/_modules/transformers/generation_utils.html
16 | def top_k_top_p_filtering(
17 | logits,
18 | top_k: int = 0,
19 | top_p: float = 1.0,
20 | filter_value: float = -float("Inf"),
21 | min_tokens_to_keep: int = 1,
22 | ):
23 | """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
24 | Args:
25 | logits: logits distribution shape (batch size, vocabulary size)
26 | if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
27 | if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
28 | Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
29 | Make sure we keep at least min_tokens_to_keep per batch example in the output
30 | From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
31 | """
32 | if top_k > 0:
33 | top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1)) # Safety check
34 | # Remove all tokens with a probability less than the last token of the top-k
35 | indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
36 | logits[indices_to_remove] = filter_value
37 |
38 | if top_p < 1.0:
39 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
40 | cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
41 |
42 | # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
43 | sorted_indices_to_remove = cumulative_probs > top_p
44 | if min_tokens_to_keep > 1:
45 | # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
46 | sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
47 | # Shift the indices to the right to keep also the first token above the threshold
48 | sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
49 | sorted_indices_to_remove[..., 0] = 0
50 |
51 | # scatter sorted tensors to original indexing
52 | indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
53 | logits[indices_to_remove] = filter_value
54 | return logits
55 |
56 |
57 | def sample(logits, temperature: float=1.0, top_k: int=0, top_p: float=1.0, sample_logits=True):
58 | logits = logits[:, -1, :] / max(temperature, 1e-5)
59 | if top_k > 0 or top_p < 1.0:
60 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
61 | probs = F.softmax(logits, dim=-1)
62 | if sample_logits:
63 | idx = torch.multinomial(probs, num_samples=1)
64 | else:
65 | _, idx = torch.topk(probs, k=1, dim=-1)
66 | return idx, probs
67 |
68 |
69 | def logits_to_probs(logits, temperature: float = 1.0, top_p: float = 1.0, top_k: int = 0, **kwargs):
70 | logits = logits / max(temperature, 1e-5)
71 | if top_k > 0 or top_p < 1.0:
72 | logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
73 | probs = torch.nn.functional.softmax(logits, dim=-1)
74 | return probs
75 |
76 |
77 | def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, **sampling_kwargs):
78 | if cfg_scale > 1.0:
79 | logits, _ = model(None, cond_idx, input_pos)
80 | logits_combined = logits
81 | cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
82 | logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
83 | else:
84 | logits, _ = model(None, cond_idx, input_pos)
85 |
86 | return sample(logits, **sampling_kwargs)[0]
87 |
88 |
89 | def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, cfg_flag: bool, **sampling_kwargs):
90 | assert input_pos.shape[-1] == 1
91 | if cfg_scale > 1.0:
92 | x_combined = torch.cat([x, x])
93 | logits, _ = model(x_combined, cond_idx=None, input_pos=input_pos)
94 | logits_combined = logits
95 | cond_logits, uncond_logits = torch.split(logits_combined, len(logits_combined) // 2, dim=0)
96 | if cfg_flag:
97 | logits = uncond_logits + (cond_logits - uncond_logits) * cfg_scale
98 | else:
99 | logits = cond_logits
100 | else:
101 | logits, _ = model(x, cond_idx=None, input_pos=input_pos)
102 | return sample(logits, **sampling_kwargs)
103 |
104 |
105 | def decode_n_tokens(
106 | model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
107 | cfg_scale: float, cfg_interval: int,
108 | **sampling_kwargs):
109 | new_tokens, new_probs = [], []
110 | cfg_flag = True
111 | for i in range(num_new_tokens):
112 | with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
113 | if cfg_interval > -1 and i > cfg_interval:
114 | cfg_flag = False
115 | next_token, next_prob = decode_one_token(
116 | model, cur_token, input_pos, cfg_scale, cfg_flag, **sampling_kwargs
117 | )
118 | input_pos += 1
119 | new_tokens.append(next_token.clone())
120 | new_probs.append(next_prob.clone())
121 | cur_token = next_token.view(-1, 1)
122 |
123 | return new_tokens, new_probs
124 |
125 |
126 | @torch.no_grad()
127 | def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_interval=-1, **sampling_kwargs):
128 | if model.model_type == 'c2i':
129 | if cfg_scale > 1.0:
130 | cond_null = torch.ones_like(cond) * model.num_classes
131 | cond_combined = torch.cat([cond, cond_null])
132 | else:
133 | cond_combined = cond
134 | T = 1
135 | elif model.model_type == 't2i':
136 | if cfg_scale > 1.0:
137 | cond_null = torch.zeros_like(cond) + model.cls_embedding.uncond_embedding
138 | cond_combined = torch.cat([cond, cond_null])
139 | else:
140 | cond_combined = cond
141 | T = cond.shape[1]
142 | else:
143 | raise Exception("please check model type")
144 |
145 | T_new = T + max_new_tokens
146 | max_seq_length = T_new
147 | max_batch_size = cond.shape[0]
148 |
149 | device = cond.device
150 | with torch.device(device):
151 | max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
152 | model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=model.tok_embeddings.weight.dtype)
153 |
154 | if emb_masks is not None:
155 | assert emb_masks.shape[0] == max_batch_size
156 | assert emb_masks.shape[-1] == T
157 | if cfg_scale > 1.0:
158 | model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
159 | else:
160 | model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)
161 |
162 | eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
163 | model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix
164 |
165 | # create an empty tensor of the expected final shape and fill in the current tokens
166 | seq = torch.empty((max_batch_size, T_new), dtype=torch.int, device=device)
167 |
168 | input_pos = torch.arange(0, T, device=device)
169 | next_token = prefill(model, cond_combined, input_pos, cfg_scale, **sampling_kwargs)
170 | seq[:, T:T+1] = next_token
171 |
172 | input_pos = torch.tensor([T], device=device, dtype=torch.int)
173 | generated_tokens, _ = decode_n_tokens(model, next_token, input_pos, max_new_tokens-1, cfg_scale, cfg_interval, **sampling_kwargs)
174 | seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
175 |
176 | return seq[:, T:]
177 |
--------------------------------------------------------------------------------
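A quick sanity check of the `top_k_top_p_filtering` helper defined in `autoregressive/models/generate.py` above (the logits values are made up purely for illustration):

```python
import torch
from autoregressive.models.generate import top_k_top_p_filtering

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0, -3.0]])
filtered = top_k_top_p_filtering(logits.clone(), top_k=2, top_p=0.9)
print(filtered)                         # everything outside the top-2 is set to -inf
print(torch.softmax(filtered, dim=-1))  # probability mass is concentrated on the two best tokens
```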
/autoregressive/models/gpt_hf.py:
--------------------------------------------------------------------------------
1 | from autoregressive.models.gpt import ModelArgs, Transformer
2 | from huggingface_hub import PyTorchModelHubMixin
3 |
4 |
5 | class TransformerHF(Transformer, PyTorchModelHubMixin, repo_url="https://github.com/FoundationVision/LlamaGen", license="mit", tags=["llamagen", "text-to-image"]):
6 | pass
7 |
8 |
9 | #################################################################################
10 | # GPT Configs #
11 | #################################################################################
12 | ### text-conditional
13 | def GPT_7B(**kwargs):
14 | return TransformerHF(ModelArgs(n_layer=32, n_head=32, dim=4096, **kwargs)) # 6.6B
15 |
16 | def GPT_3B(**kwargs):
17 | return TransformerHF(ModelArgs(n_layer=24, n_head=32, dim=3200, **kwargs)) # 3.1B
18 |
19 | def GPT_1B(**kwargs):
20 | return TransformerHF(ModelArgs(n_layer=22, n_head=32, dim=2048, **kwargs)) # 1.2B
21 |
22 | ### class-conditional
23 | def GPT_XXXL(**kwargs):
24 | return TransformerHF(ModelArgs(n_layer=48, n_head=40, dim=2560, **kwargs)) # 3.9B
25 |
26 | def GPT_XXL(**kwargs):
27 | return TransformerHF(ModelArgs(n_layer=48, n_head=24, dim=1536, **kwargs)) # 1.4B
28 |
29 | def GPT_XL(**kwargs):
30 | return TransformerHF(ModelArgs(n_layer=36, n_head=20, dim=1280, **kwargs)) # 775M
31 |
32 | def GPT_L(**kwargs):
33 | return TransformerHF(ModelArgs(n_layer=24, n_head=16, dim=1024, **kwargs)) # 343M
34 |
35 | def GPT_B(**kwargs):
36 | return TransformerHF(ModelArgs(n_layer=12, n_head=12, dim=768, **kwargs)) # 111M
37 |
38 |
39 | GPT_models_HF = {
40 | 'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
41 | 'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
42 | }
43 |
--------------------------------------------------------------------------------
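Because `TransformerHF` above inherits from `PyTorchModelHubMixin`, checkpoints can be saved in (and uploaded to) the Hugging Face Hub format. A hedged sketch; the repo's own upload script is `tools/push_gpt_to_hf.py`, and the directory/repo names below are placeholders:

```python
from autoregressive.models.gpt_hf import GPT_models_HF, TransformerHF

# Build a GPT-B in the Hub-mixin flavor (same keyword arguments as the plain GPT_models factory).
model = GPT_models_HF["GPT-B"](
    vocab_size=16384, block_size=24 * 24, num_classes=1000, cls_token_num=1, model_type="c2i"
)
model.save_pretrained("./llamagen-c2i-B-hf")         # placeholder local folder for weights (+ config)
# model.push_to_hub("your-username/llamagen-c2i-B")  # optional upload; requires `huggingface-cli login`
# Reloading follows the same mixin API, e.g. TransformerHF.from_pretrained("./llamagen-c2i-B-hf").
```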
/autoregressive/sample/sample_c2i.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # DiT: https://github.com/facebookresearch/DiT/blob/main/sample.py
3 | import torch
4 | torch.backends.cuda.matmul.allow_tf32 = True
5 | torch.backends.cudnn.allow_tf32 = True
6 | torch.set_float32_matmul_precision('high')
7 | setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)
8 | setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)
9 | from torchvision.utils import save_image
10 |
11 | import time
12 | import argparse
13 | from tokenizer.tokenizer_image.vq_model import VQ_models
14 | from autoregressive.models.gpt import GPT_models
15 | from autoregressive.models.generate import generate
16 |
17 |
18 | def main(args):
19 | # Setup PyTorch:
20 | torch.manual_seed(args.seed)
21 | torch.backends.cudnn.deterministic = True
22 | torch.backends.cudnn.benchmark = False
23 | torch.set_grad_enabled(False)
24 | device = "cuda" if torch.cuda.is_available() else "cpu"
25 |
26 | # create and load model
27 | vq_model = VQ_models[args.vq_model](
28 | codebook_size=args.codebook_size,
29 | codebook_embed_dim=args.codebook_embed_dim)
30 | vq_model.to(device)
31 | vq_model.eval()
32 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
33 | vq_model.load_state_dict(checkpoint["model"])
34 | del checkpoint
35 | print(f"image tokenizer is loaded")
36 |
37 | # create and load gpt model
38 | precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
39 | latent_size = args.image_size // args.downsample_size
40 | gpt_model = GPT_models[args.gpt_model](
41 | vocab_size=args.codebook_size,
42 | block_size=latent_size ** 2,
43 | num_classes=args.num_classes,
44 | cls_token_num=args.cls_token_num,
45 | model_type=args.gpt_type,
46 | ).to(device=device, dtype=precision)
47 |
48 | checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
49 | if args.from_fsdp: # fsdp
50 | model_weight = checkpoint
51 | elif "model" in checkpoint: # ddp
52 | model_weight = checkpoint["model"]
53 | elif "module" in checkpoint: # deepspeed
54 | model_weight = checkpoint["module"]
55 | elif "state_dict" in checkpoint:
56 | model_weight = checkpoint["state_dict"]
57 | else:
58 | raise Exception("please check model weight, maybe add --from-fsdp to run command")
59 | # if 'freqs_cis' in model_weight:
60 | # model_weight.pop('freqs_cis')
61 | gpt_model.load_state_dict(model_weight, strict=False)
62 | gpt_model.eval()
63 | del checkpoint
64 | print(f"gpt model is loaded")
65 |
66 | if args.compile:
67 | print(f"compiling the model...")
68 | gpt_model = torch.compile(
69 | gpt_model,
70 | mode="reduce-overhead",
71 | fullgraph=True
72 | ) # requires PyTorch 2.0 (optional)
73 | else:
74 | print(f"no need to compile model in demo")
75 |
76 | # Labels to condition the model with (feel free to change):
77 | class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
78 | c_indices = torch.tensor(class_labels, device=device)
79 | qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
80 |
81 | t1 = time.time()
82 | index_sample = generate(
83 | gpt_model, c_indices, latent_size ** 2,
84 | cfg_scale=args.cfg_scale, cfg_interval=args.cfg_interval,
85 | temperature=args.temperature, top_k=args.top_k,
86 | top_p=args.top_p, sample_logits=True,
87 | )
88 | sampling_time = time.time() - t1
89 | print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
90 |
91 | t2 = time.time()
92 | samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
93 | decoder_time = time.time() - t2
94 | print(f"decoder takes about {decoder_time:.2f} seconds.")
95 |
96 | # Save and display images:
97 | save_image(samples, "sample_{}.png".format(args.gpt_type), nrow=4, normalize=True, value_range=(-1, 1))
98 | print(f"image is saved to sample_{args.gpt_type}.png")
99 |
100 |
101 | if __name__ == "__main__":
102 | parser = argparse.ArgumentParser()
103 | parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
104 | parser.add_argument("--gpt-ckpt", type=str, default=None)
105 | parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
106 | parser.add_argument("--from-fsdp", action='store_true')
107 | parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
108 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
109 | parser.add_argument("--compile", action='store_true', default=False)
110 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
111 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
112 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
113 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
114 | parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
115 | parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
116 | parser.add_argument("--num-classes", type=int, default=1000)
117 | parser.add_argument("--cfg-scale", type=float, default=4.0)
118 | parser.add_argument("--cfg-interval", type=float, default=-1)
119 | parser.add_argument("--seed", type=int, default=0)
120 | parser.add_argument("--top-k", type=int, default=2000,help="top-k value to sample with")
121 | parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
122 | parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
123 | args = parser.parse_args()
124 | main(args)
--------------------------------------------------------------------------------
/autoregressive/sample/sample_t2i.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.backends.cuda.matmul.allow_tf32 = True
3 | torch.backends.cudnn.allow_tf32 = True
4 | torch.set_float32_matmul_precision('high')
5 | setattr(torch.nn.Linear, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
6 | setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None) # disable default parameter init for faster speed
7 | from torchvision.utils import save_image
8 |
9 | import os
10 | import time
11 | import argparse
12 | from tokenizer.tokenizer_image.vq_model import VQ_models
13 | from language.t5 import T5Embedder
14 | from autoregressive.models.gpt import GPT_models
15 | from autoregressive.models.generate import generate
16 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
17 |
18 |
19 |
20 | def main(args):
21 | # Setup PyTorch:
22 | torch.manual_seed(args.seed)
23 | torch.backends.cudnn.deterministic = True
24 | torch.backends.cudnn.benchmark = False
25 | torch.set_grad_enabled(False)
26 | device = "cuda" if torch.cuda.is_available() else "cpu"
27 |
28 | # create and load model
29 | vq_model = VQ_models[args.vq_model](
30 | codebook_size=args.codebook_size,
31 | codebook_embed_dim=args.codebook_embed_dim)
32 | vq_model.to(device)
33 | vq_model.eval()
34 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
35 | vq_model.load_state_dict(checkpoint["model"])
36 | del checkpoint
37 | print(f"image tokenizer is loaded")
38 |
39 | # create and load gpt model
40 | precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
41 | latent_size = args.image_size // args.downsample_size
42 | gpt_model = GPT_models[args.gpt_model](
43 | block_size=latent_size ** 2,
44 | cls_token_num=args.cls_token_num,
45 | model_type=args.gpt_type,
46 | ).to(device=device, dtype=precision)
47 |
48 | checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
49 |
50 | if "model" in checkpoint: # ddp
51 | model_weight = checkpoint["model"]
52 | elif "module" in checkpoint: # deepspeed
53 | model_weight = checkpoint["module"]
54 | elif "state_dict" in checkpoint:
55 | model_weight = checkpoint["state_dict"]
56 | else:
57 | raise Exception("please check model weight")
58 | gpt_model.load_state_dict(model_weight, strict=False)
59 | gpt_model.eval()
60 | del checkpoint
61 | print(f"gpt model is loaded")
62 |
63 | if args.compile:
64 | print(f"compiling the model...")
65 | gpt_model = torch.compile(
66 | gpt_model,
67 | mode="reduce-overhead",
68 | fullgraph=True
69 | ) # requires PyTorch 2.0 (optional)
70 | else:
71 | print(f"no need to compile model in demo")
72 |
73 | assert os.path.exists(args.t5_path)
74 | t5_model = T5Embedder(
75 | device=device,
76 | local_cache=True,
77 | cache_dir=args.t5_path,
78 | dir_or_name=args.t5_model_type,
79 | torch_dtype=precision,
80 | model_max_length=args.t5_feature_max_len,
81 | )
82 | prompts = [
83 | "A portrait photo of a kangaroo wearing an orange hoodie and blue sunglasses standing on the grassin front of the Sydney Opera House holding a sign on the chest that says Welcome Friends!",
84 | "A blue Porsche 356 parked in front of a yellow brick wall.",
85 | "A photo of an astronaut riding a horse in the forest. There is a river in front of them with water lilies.",
86 | "A map of the United States made out of sushi. It is on a table next to a glass of red wine."
87 | ]
88 |
89 | caption_embs, emb_masks = t5_model.get_text_embeddings(prompts)
90 |
91 |
92 | if not args.no_left_padding:
93 | print(f"processing left-padding...")
94 | # a naive way to implement left-padding
95 | new_emb_masks = torch.flip(emb_masks, dims=[-1])
96 | new_caption_embs = []
97 | for idx, (caption_emb, emb_mask) in enumerate(zip(caption_embs, emb_masks)):
98 | valid_num = int(emb_mask.sum().item())
99 | print(f' prompt {idx} token len: {valid_num}')
100 | new_caption_emb = torch.cat([caption_emb[valid_num:], caption_emb[:valid_num]])
101 | new_caption_embs.append(new_caption_emb)
102 | new_caption_embs = torch.stack(new_caption_embs)
103 | else:
104 | new_caption_embs, new_emb_masks = caption_embs, emb_masks
105 | c_indices = new_caption_embs * new_emb_masks[:,:, None]
106 | c_emb_masks = new_emb_masks
107 |
108 | qzshape = [len(c_indices), args.codebook_embed_dim, latent_size, latent_size]
109 | t1 = time.time()
110 | index_sample = generate(
111 | gpt_model, c_indices, latent_size ** 2,
112 | c_emb_masks,
113 | cfg_scale=args.cfg_scale,
114 | temperature=args.temperature, top_k=args.top_k,
115 | top_p=args.top_p, sample_logits=True,
116 | )
117 | sampling_time = time.time() - t1
118 | print(f"Full sampling takes about {sampling_time:.2f} seconds.")
119 |
120 | t2 = time.time()
121 | samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
122 | decoder_time = time.time() - t2
123 | print(f"decoder takes about {decoder_time:.2f} seconds.")
124 |
125 | save_image(samples, "sample_{}.png".format(args.gpt_type), nrow=4, normalize=True, value_range=(-1, 1))
126 | print(f"image is saved to sample_{args.gpt_type}.png")
127 |
128 |
129 |
130 | if __name__ == "__main__":
131 | parser = argparse.ArgumentParser()
132 | parser.add_argument("--t5-path", type=str, default='pretrained_models/t5-ckpt')
133 | parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
134 | parser.add_argument("--t5-feature-max-len", type=int, default=120)
135 | parser.add_argument("--t5-feature-dim", type=int, default=2048)
136 | parser.add_argument("--no-left-padding", action='store_true', default=False)
137 | parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-XL")
138 | parser.add_argument("--gpt-ckpt", type=str, default=None)
139 | parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="t2i", help="class->image or text->image")
140 | parser.add_argument("--cls-token-num", type=int, default=120, help="max token number of condition input")
141 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
142 | parser.add_argument("--compile", action='store_true', default=False)
143 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
144 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
145 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
146 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
147 | parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=512)
148 | parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
149 | parser.add_argument("--num-classes", type=int, default=1000)
150 | parser.add_argument("--cfg-scale", type=float, default=7.5)
151 | parser.add_argument("--seed", type=int, default=0)
152 | parser.add_argument("--top-k", type=int, default=1000, help="top-k value to sample with")
153 | parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
154 | parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
155 | args = parser.parse_args()
156 | main(args)
157 |
--------------------------------------------------------------------------------
/autoregressive/serve/README.md:
--------------------------------------------------------------------------------
1 | ## Serving by vLLM
2 |
3 | ### Install
4 | ```
5 | pip install vllm==0.4.1
6 | ```
7 |
8 | ### Comparison (A100)
9 |
10 | Method | params | baseline(s) | vllm(s) | speed-up ratio
11 | --- |:---:|:---:|:---:|:---:
12 | [GPT-B](./fake_json/GPT-B.json) | 111M | 7.80 | 2.39 | 326 %
13 | [GPT-L](./fake_json/GPT-L.json) | 343M | 13.72 | 3.48 | 380 %
14 | [GPT-XL](./fake_json/GPT-XL.json) | 775M | 19.76 | 4.84 | 408 %
15 | [GPT-XXL](./fake_json/GPT-XXL.json)| 1.4B | 26.38 | 6.36 | 414 %
16 | [GPT-3B](./fake_json/GPT-3B.json) | 3.1B | 14.73 | 6.26 | 235 %
17 |
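The speed-up ratio column is simply the baseline latency divided by the vLLM latency, e.g. for the GPT-B row:

```python
# GPT-B: 7.80 s baseline vs. 2.39 s with vLLM
print(f"{7.80 / 2.39:.0%}")  # ~326%
```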
18 | ```
19 | ### GPT-B
20 | # 7.80 seconds
21 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_B_384.pt --image-size 384
22 |
23 | # 2.39 seconds
24 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_B_384.pt --image-size 384
25 |
26 |
27 | ### GPT-L
28 | # 13.72 seconds
29 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_L_384.pt --gpt-model GPT-L --image-size 384
30 |
31 | # 3.48 seconds
32 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_L_384.pt --gpt-model GPT-L --image-size 384
33 |
34 |
35 | ### GPT-XL
36 | # 19.76 seconds
37 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XL_384.pt --gpt-model GPT-XL --image-size 384
38 |
39 | # 4.84 seconds
40 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XL_384.pt --gpt-model GPT-XL --image-size 384
41 |
42 |
43 | ### GPT-XXL
44 | # 26.38 seconds
45 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL_384.pt --from-fsdp --gpt-model GPT-XXL --image-size 384
46 |
47 | # 6.36 seconds
48 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_XXL_384.pt --from-fsdp --gpt-model GPT-XXL --image-size 384
49 |
50 |
51 | ### GPT-3B
52 | # 14.73 seconds
53 | python3 autoregressive/sample/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_3B_384.pt --from-fsdp --gpt-model GPT-3B --image-size 384
54 |
55 | # 6.26 seconds
56 | python3 autoregressive/serve/sample_c2i.py --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt --gpt-ckpt ./pretrained_models/c2i_3B_384.pt --from-fsdp --gpt-model GPT-3B --image-size 384
57 |
58 | ```
59 |
--------------------------------------------------------------------------------
/autoregressive/serve/fake_json/GPT-3B.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "facebook/opt-125m",
3 | "activation_dropout": 0.0,
4 | "activation_function": "relu",
5 | "architectures": [
6 | "OPTForCausalLM"
7 | ],
8 | "attention_dropout": 0.0,
9 | "bos_token_id": 2,
10 | "do_layer_norm_before": true,
11 | "dropout": 0.1,
12 | "eos_token_id": 2,
13 | "ffn_dim": 3072,
14 | "hidden_size": 3584,
15 | "init_std": 0.02,
16 | "layerdrop": 0.0,
17 | "max_position_embeddings": 2048,
18 | "model_type": "opt",
19 | "num_attention_heads": 32,
20 | "num_hidden_layers": 24,
21 | "pad_token_id": 1,
22 | "prefix": "",
23 | "torch_dtype": "bfloat16",
24 | "transformers_version": "4.21.0.dev0",
25 | "use_cache": true,
26 | "vocab_size": 16384,
27 | "word_embed_proj_dim": 768
28 | }
29 |
--------------------------------------------------------------------------------
/autoregressive/serve/fake_json/GPT-B.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "facebook/opt-125m",
3 | "activation_dropout": 0.0,
4 | "activation_function": "relu",
5 | "architectures": [
6 | "OPTForCausalLM"
7 | ],
8 | "attention_dropout": 0.0,
9 | "bos_token_id": 2,
10 | "do_layer_norm_before": true,
11 | "dropout": 0.1,
12 | "eos_token_id": 2,
13 | "ffn_dim": 3072,
14 | "hidden_size": 768,
15 | "init_std": 0.02,
16 | "layerdrop": 0.0,
17 | "max_position_embeddings": 2048,
18 | "model_type": "opt",
19 | "num_attention_heads": 12,
20 | "num_hidden_layers": 12,
21 | "pad_token_id": 1,
22 | "prefix": "",
23 | "torch_dtype": "bfloat16",
24 | "transformers_version": "4.21.0.dev0",
25 | "use_cache": true,
26 | "vocab_size": 16384,
27 | "word_embed_proj_dim": 768
28 | }
29 |
--------------------------------------------------------------------------------
/autoregressive/serve/fake_json/GPT-L.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "facebook/opt-125m",
3 | "activation_dropout": 0.0,
4 | "activation_function": "relu",
5 | "architectures": [
6 | "OPTForCausalLM"
7 | ],
8 | "attention_dropout": 0.0,
9 | "bos_token_id": 2,
10 | "do_layer_norm_before": true,
11 | "dropout": 0.1,
12 | "eos_token_id": 2,
13 | "ffn_dim": 3072,
14 | "hidden_size": 1024,
15 | "init_std": 0.02,
16 | "layerdrop": 0.0,
17 | "max_position_embeddings": 2048,
18 | "model_type": "opt",
19 | "num_attention_heads": 16,
20 | "num_hidden_layers": 24,
21 | "pad_token_id": 1,
22 | "prefix": "",
23 | "torch_dtype": "bfloat16",
24 | "transformers_version": "4.21.0.dev0",
25 | "use_cache": true,
26 | "vocab_size": 16384,
27 | "word_embed_proj_dim": 768
28 | }
29 |
--------------------------------------------------------------------------------
/autoregressive/serve/fake_json/GPT-XL.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "facebook/opt-125m",
3 | "activation_dropout": 0.0,
4 | "activation_function": "relu",
5 | "architectures": [
6 | "OPTForCausalLM"
7 | ],
8 | "attention_dropout": 0.0,
9 | "bos_token_id": 2,
10 | "do_layer_norm_before": true,
11 | "dropout": 0.1,
12 | "eos_token_id": 2,
13 | "ffn_dim": 3072,
14 | "hidden_size": 1280,
15 | "init_std": 0.02,
16 | "layerdrop": 0.0,
17 | "max_position_embeddings": 2048,
18 | "model_type": "opt",
19 | "num_attention_heads": 20,
20 | "num_hidden_layers": 36,
21 | "pad_token_id": 1,
22 | "prefix": "",
23 | "torch_dtype": "bfloat16",
24 | "transformers_version": "4.21.0.dev0",
25 | "use_cache": true,
26 | "vocab_size": 16384,
27 | "word_embed_proj_dim": 768
28 | }
29 |
--------------------------------------------------------------------------------
/autoregressive/serve/fake_json/GPT-XXL.json:
--------------------------------------------------------------------------------
1 | {
2 | "_name_or_path": "facebook/opt-125m",
3 | "activation_dropout": 0.0,
4 | "activation_function": "relu",
5 | "architectures": [
6 | "OPTForCausalLM"
7 | ],
8 | "attention_dropout": 0.0,
9 | "bos_token_id": 2,
10 | "do_layer_norm_before": true,
11 | "dropout": 0.1,
12 | "eos_token_id": 2,
13 | "ffn_dim": 3072,
14 | "hidden_size": 1536,
15 | "init_std": 0.02,
16 | "layerdrop": 0.0,
17 | "max_position_embeddings": 2048,
18 | "model_type": "opt",
19 | "num_attention_heads": 24,
20 | "num_hidden_layers": 48,
21 | "pad_token_id": 1,
22 | "prefix": "",
23 | "torch_dtype": "bfloat16",
24 | "transformers_version": "4.21.0.dev0",
25 | "use_cache": true,
26 | "vocab_size": 16384,
27 | "word_embed_proj_dim": 768
28 | }
29 |
--------------------------------------------------------------------------------
/autoregressive/serve/gpu_executor.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, List, Optional, Set, Tuple
2 | import argparse
3 |
4 | from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
5 | ModelConfig, ParallelConfig, SchedulerConfig,
6 | SpeculativeConfig, VisionLanguageConfig)
7 | from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase
8 | from vllm.logger import init_logger
9 | from vllm.lora.request import LoRARequest
10 | from vllm.sequence import SamplerOutput, SequenceGroupMetadata
11 | from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
12 | make_async)
13 |
14 | logger = init_logger(__name__)
15 |
16 |
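# GPUExecutor adapted from vllm: it builds the custom Worker from
# autoregressive/serve/worker.py and passes the command-line args to load_model().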
17 | class GPUExecutor(ExecutorBase):
18 | def __init__(
19 | self,
20 | args: argparse.ArgumentParser,
21 | model_config: ModelConfig,
22 | cache_config: CacheConfig,
23 | parallel_config: ParallelConfig,
24 | scheduler_config: SchedulerConfig,
25 | device_config: DeviceConfig,
26 | load_config: LoadConfig,
27 | lora_config: Optional[LoRAConfig],
28 | vision_language_config: Optional[VisionLanguageConfig],
29 | speculative_config: Optional[SpeculativeConfig],
30 | ) -> None:
31 | self.args = args
32 | self.model_config = model_config
33 | self.cache_config = cache_config
34 | self.lora_config = lora_config
35 | self.load_config = load_config
36 | self.parallel_config = parallel_config
37 | self.scheduler_config = scheduler_config
38 | self.device_config = device_config
39 | self.vision_language_config = vision_language_config
40 | self.speculative_config = speculative_config
41 |
42 | self._init_executor()
43 |
44 | def _init_executor(self) -> None:
45 | """Initialize the worker and load the model.
46 |
47 | If speculative decoding is enabled, we instead create the speculative
48 | worker.
49 | """
50 | if self.speculative_config is None:
51 | self._init_non_spec_worker()
52 | else:
53 | self._init_spec_worker()
54 |
55 | def _init_non_spec_worker(self):
56 | # Lazy import the Worker to avoid importing torch.cuda/xformers
57 | # before CUDA_VISIBLE_DEVICES is set in the Worker
58 | # from vllm.worker.worker import Worker
59 | from autoregressive.serve.worker import Worker
60 |
61 | assert self.parallel_config.world_size == 1, (
62 | "GPUExecutor only supports single GPU.")
63 |
64 | distributed_init_method = get_distributed_init_method(
65 | get_ip(), get_open_port())
66 | self.driver_worker = Worker(
67 | model_config=self.model_config,
68 | parallel_config=self.parallel_config,
69 | scheduler_config=self.scheduler_config,
70 | device_config=self.device_config,
71 | cache_config=self.cache_config,
72 | load_config=self.load_config,
73 | local_rank=0,
74 | rank=0,
75 | distributed_init_method=distributed_init_method,
76 | lora_config=self.lora_config,
77 | vision_language_config=self.vision_language_config,
78 | is_driver_worker=True,
79 | )
80 | self.driver_worker.init_device()
81 | self.driver_worker.load_model(self.args)
82 |
83 | def _init_spec_worker(self):
84 | """Initialize a SpecDecodeWorker, using a draft model for proposals.
85 | """
86 | assert self.speculative_config is not None
87 |
88 | from vllm.spec_decode.multi_step_worker import MultiStepWorker
89 | from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
90 | from vllm.worker.worker import Worker
91 |
92 | distributed_init_method = get_distributed_init_method(
93 | get_ip(), get_open_port())
94 |
95 | target_worker = Worker(
96 | model_config=self.model_config,
97 | parallel_config=self.parallel_config,
98 | scheduler_config=self.scheduler_config,
99 | device_config=self.device_config,
100 | cache_config=self.cache_config,
101 | load_config=self.load_config,
102 | local_rank=0,
103 | rank=0,
104 | distributed_init_method=distributed_init_method,
105 | lora_config=self.lora_config,
106 | vision_language_config=self.vision_language_config,
107 | is_driver_worker=True,
108 | )
109 |
110 | draft_worker = MultiStepWorker(
111 | model_config=self.speculative_config.draft_model_config,
112 | parallel_config=self.speculative_config.draft_parallel_config,
113 | scheduler_config=self.scheduler_config,
114 | device_config=self.device_config,
115 | cache_config=self.cache_config,
116 | load_config=self.load_config,
117 | local_rank=0,
118 | rank=0,
119 | distributed_init_method=distributed_init_method,
120 | lora_config=self.lora_config,
121 | vision_language_config=self.vision_language_config,
122 | is_driver_worker=True,
123 | )
124 |
125 | spec_decode_worker = SpecDecodeWorker.from_workers(
126 | proposer_worker=draft_worker, scorer_worker=target_worker)
127 |
128 | assert self.parallel_config.world_size == 1, (
129 | "GPUExecutor only supports single GPU.")
130 |
131 | self.driver_worker = spec_decode_worker
132 |
133 | # Load model handled in spec decode worker.
134 | self.driver_worker.init_device()
135 |
136 | def determine_num_available_blocks(self) -> Tuple[int, int]:
137 | """Determine the number of available KV blocks by invoking the
138 | underlying worker.
139 | """
140 | return self.driver_worker.determine_num_available_blocks()
141 |
142 |     def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None:
143 | """Initialize the KV cache by invoking the underlying worker.
144 | """
145 | # NOTE: This is logged in the executor because there can be >1 worker
146 | # with other executors. We could log in the engine level, but work
147 | # remains to abstract away the device for non-GPU configurations.
148 | logger.info(f"# GPU blocks: {num_gpu_blocks}, "
149 | f"# CPU blocks: {num_cpu_blocks}")
150 |
151 | self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks)
152 |
153 | def execute_model(
154 | self,
155 | seq_group_metadata_list: List[SequenceGroupMetadata],
156 | blocks_to_swap_in: Dict[int, int],
157 | blocks_to_swap_out: Dict[int, int],
158 | blocks_to_copy: Dict[int, List[int]],
159 | num_lookahead_slots: int,
160 | ) -> List[SamplerOutput]:
161 | output = self.driver_worker.execute_model(
162 | seq_group_metadata_list=seq_group_metadata_list,
163 | blocks_to_swap_in=blocks_to_swap_in,
164 | blocks_to_swap_out=blocks_to_swap_out,
165 | blocks_to_copy=blocks_to_copy,
166 | num_lookahead_slots=num_lookahead_slots,
167 | )
168 | return output
169 |
170 | def add_lora(self, lora_request: LoRARequest) -> bool:
171 | assert lora_request.lora_int_id > 0, "lora_id must be greater than 0."
172 | return self.driver_worker.add_lora(lora_request)
173 |
174 | def remove_lora(self, lora_id: int) -> bool:
175 | assert lora_id > 0, "lora_id must be greater than 0."
176 | return self.driver_worker.remove_lora(lora_id)
177 |
178 | def list_loras(self) -> Set[int]:
179 | return self.driver_worker.list_loras()
180 |
181 | def check_health(self) -> None:
182 | # GPUExecutor will always be healthy as long as
183 | # it's running.
184 | return
185 |
186 |
187 | class GPUExecutorAsync(GPUExecutor, ExecutorAsyncBase):
188 |
189 | async def execute_model_async(
190 | self,
191 | seq_group_metadata_list: List[SequenceGroupMetadata],
192 | blocks_to_swap_in: Dict[int, int],
193 | blocks_to_swap_out: Dict[int, int],
194 | blocks_to_copy: Dict[int, List[int]],
195 | ) -> SamplerOutput:
196 | output = await make_async(self.driver_worker.execute_model)(
197 | seq_group_metadata_list=seq_group_metadata_list,
198 | blocks_to_swap_in=blocks_to_swap_in,
199 | blocks_to_swap_out=blocks_to_swap_out,
200 | blocks_to_copy=blocks_to_copy)
201 | return output
--------------------------------------------------------------------------------
/autoregressive/serve/sample_c2i.py:
--------------------------------------------------------------------------------
1 | import time
2 | import argparse
3 | import torch
4 | from torchvision.utils import save_image
5 |
6 | from tokenizer.tokenizer_image.vq_model import VQ_models
7 | from autoregressive.serve.gpt_model import GPT_models
8 | from autoregressive.serve.llm import LLM
9 | from vllm import SamplingParams
10 |
11 |
12 | def main(args):
13 | # Setup PyTorch:
14 | torch.manual_seed(args.seed)
15 | torch.backends.cudnn.deterministic = True
16 | torch.backends.cudnn.benchmark = False
17 | torch.set_grad_enabled(False)
18 | device = "cuda" if torch.cuda.is_available() else "cpu"
19 |
20 | # create and load model
21 | vq_model = VQ_models[args.vq_model](
22 | codebook_size=args.codebook_size,
23 | codebook_embed_dim=args.codebook_embed_dim)
24 | vq_model.to(device)
25 | vq_model.eval()
26 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
27 | vq_model.load_state_dict(checkpoint["model"])
28 | del checkpoint
29 | print(f"image tokenizer is loaded")
30 |
31 | # Labels to condition the model with (feel free to change):
32 | class_labels = [207, 360, 387, 974, 88, 979, 417, 279]
33 | latent_size = args.image_size // args.downsample_size
34 | qzshape = [len(class_labels), args.codebook_embed_dim, latent_size, latent_size]
35 | prompt_token_ids = [[cind] for cind in class_labels]
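    # Classifier-free guidance: append one unconditional prompt per class prompt using
    # the extra "null" class id (args.num_classes); only the conditional half of the
    # generated batch is decoded to images below.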
36 | if args.cfg_scale > 1.0:
37 | prompt_token_ids.extend([[args.num_classes] for _ in range(len(prompt_token_ids))])
38 | # Create an LLM.
39 | llm = LLM(
40 | args=args,
41 | model='autoregressive/serve/fake_json/{}.json'.format(args.gpt_model),
42 | gpu_memory_utilization=0.9,
43 | skip_tokenizer_init=True)
44 | print(f"gpt model is loaded")
45 |
46 | # Create a sampling params object.
47 | sampling_params = SamplingParams(
48 | temperature=args.temperature, top_p=args.top_p, top_k=args.top_k,
49 | max_tokens=latent_size ** 2)
50 |
51 |     # Generate image token ids from the prompts. The output is a list of RequestOutput
52 |     # objects that contain the prompt, generated token ids, and other information.
53 | t1 = time.time()
54 | outputs = llm.generate(
55 | prompt_token_ids=prompt_token_ids,
56 | sampling_params=sampling_params,
57 | use_tqdm=False)
58 | sampling_time = time.time() - t1
59 | print(f"gpt sampling takes about {sampling_time:.2f} seconds.")
60 |
61 | # decode to image
62 | index_sample = torch.tensor([output.outputs[0].token_ids for output in outputs], device=device)
63 | if args.cfg_scale > 1.0:
64 | index_sample = index_sample[:len(class_labels)]
65 | t2 = time.time()
66 | samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
67 | decoder_time = time.time() - t2
68 | print(f"decoder takes about {decoder_time:.2f} seconds.")
69 |
70 | # Save and display images:
71 | save_image(samples, "sample_{}_vllm.png".format(args.gpt_type), nrow=4, normalize=True, value_range=(-1, 1))
72 | print(f"image is saved to sample_{args.gpt_type}_vllm.png")
73 |
74 |
75 | if __name__ == '__main__':
76 | parser = argparse.ArgumentParser()
77 | parser.add_argument("--gpt-model", type=str, choices=list(GPT_models.keys()), default="GPT-B")
78 | parser.add_argument("--gpt-ckpt", type=str, required=True, help="ckpt path for gpt model")
79 | parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
80 | parser.add_argument("--from-fsdp", action='store_true')
81 | parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
82 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
83 | parser.add_argument("--compile", action='store_true', default=False)
84 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
85 | parser.add_argument("--vq-ckpt", type=str, required=True, help="ckpt path for vq model")
86 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
87 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
88 | parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
89 | parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
90 | parser.add_argument("--num-classes", type=int, default=1000)
91 | parser.add_argument("--cfg-scale", type=float, default=4.0)
92 | parser.add_argument("--seed", type=int, default=0)
93 |     parser.add_argument("--top-k", type=int, default=2000, help="top-k value to sample with")
94 | parser.add_argument("--temperature", type=float, default=1.0, help="temperature value to sample with")
95 | parser.add_argument("--top-p", type=float, default=1.0, help="top-p value to sample with")
96 | args = parser.parse_args()
97 | main(args)
98 |
--------------------------------------------------------------------------------
/autoregressive/train/extract_codes_c2i.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/extract_features.py
3 | import torch
4 | torch.backends.cuda.matmul.allow_tf32 = True
5 | torch.backends.cudnn.allow_tf32 = True
6 | import torch.distributed as dist
7 | from torch.utils.data import DataLoader
8 | from torch.utils.data.distributed import DistributedSampler
9 | from torchvision import transforms
10 | import numpy as np
11 | import argparse
12 | import os
13 |
14 | from utils.distributed import init_distributed_mode
15 | from dataset.augmentation import center_crop_arr
16 | from dataset.build import build_dataset
17 | from tokenizer.tokenizer_image.vq_model import VQ_models
18 |
19 |
20 | #################################################################################
21 | # Training Loop #
22 | #################################################################################
23 | def main(args):
24 | assert torch.cuda.is_available(), "Training currently requires at least one GPU."
25 | # Setup DDP:
26 | if not args.debug:
27 | init_distributed_mode(args)
28 | rank = dist.get_rank()
29 | device = rank % torch.cuda.device_count()
30 | seed = args.global_seed * dist.get_world_size() + rank
31 | torch.manual_seed(seed)
32 | torch.cuda.set_device(device)
33 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
34 | else:
35 | device = 'cuda'
36 | rank = 0
37 |
38 | # Setup a feature folder:
39 | if args.debug or rank == 0:
40 | os.makedirs(args.code_path, exist_ok=True)
41 | os.makedirs(os.path.join(args.code_path, f'{args.dataset}{args.image_size}_codes'), exist_ok=True)
42 | os.makedirs(os.path.join(args.code_path, f'{args.dataset}{args.image_size}_labels'), exist_ok=True)
43 |
44 | # create and load model
45 | vq_model = VQ_models[args.vq_model](
46 | codebook_size=args.codebook_size,
47 | codebook_embed_dim=args.codebook_embed_dim)
48 | vq_model.to(device)
49 | vq_model.eval()
50 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
51 | vq_model.load_state_dict(checkpoint["model"])
52 | del checkpoint
53 |
54 | # Setup data:
55 | if args.ten_crop:
56 | crop_size = int(args.image_size * args.crop_range)
57 | transform = transforms.Compose([
58 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, crop_size)),
59 | transforms.TenCrop(args.image_size), # this is a tuple of PIL Images
60 | transforms.Lambda(lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])), # returns a 4D tensor
61 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
62 | ])
63 | else:
64 | crop_size = args.image_size
65 | transform = transforms.Compose([
66 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, crop_size)),
67 | transforms.ToTensor(),
68 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
69 | ])
70 | dataset = build_dataset(args, transform=transform)
71 | if not args.debug:
72 | sampler = DistributedSampler(
73 | dataset,
74 | num_replicas=dist.get_world_size(),
75 | rank=rank,
76 | shuffle=False,
77 | seed=args.global_seed
78 | )
79 | else:
80 | sampler = None
81 | loader = DataLoader(
82 | dataset,
83 | batch_size=1, # important!
84 | shuffle=False,
85 | sampler=sampler,
86 | num_workers=args.num_workers,
87 | pin_memory=True,
88 | drop_last=False
89 | )
90 |
91 | total = 0
92 | for x, y in loader:
93 | x = x.to(device)
94 | if args.ten_crop:
95 | x_all = x.flatten(0, 1)
96 | num_aug = 10
97 | else:
98 | x_flip = torch.flip(x, dims=[-1])
99 | x_all = torch.cat([x, x_flip])
100 | num_aug = 2
101 | y = y.to(device)
102 | with torch.no_grad():
103 | _, _, [_, _, indices] = vq_model.encode(x_all)
104 | codes = indices.reshape(x.shape[0], num_aug, -1)
105 |
106 | x = codes.detach().cpu().numpy() # (1, num_aug, args.image_size//16 * args.image_size//16)
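        # With shuffle=False, DistributedSampler assigns dataset indices
        # rank, rank + world_size, rank + 2 * world_size, ... to this rank,
        # so rank + total recovers the global sample index used as the file name.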
107 | train_steps = rank + total
108 | np.save(f'{args.code_path}/{args.dataset}{args.image_size}_codes/{train_steps}.npy', x)
109 |
110 | y = y.detach().cpu().numpy() # (1,)
111 | np.save(f'{args.code_path}/{args.dataset}{args.image_size}_labels/{train_steps}.npy', y)
112 | if not args.debug:
113 | total += dist.get_world_size()
114 | else:
115 | total += 1
116 | print(total)
117 |
118 |     if not args.debug: dist.destroy_process_group()  # no process group is created in --debug mode
119 |
120 |
121 | if __name__ == "__main__":
122 | parser = argparse.ArgumentParser()
123 | parser.add_argument("--data-path", type=str, required=True)
124 | parser.add_argument("--code-path", type=str, required=True)
125 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
126 | parser.add_argument("--vq-ckpt", type=str, required=True, help="ckpt path for vq model")
127 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
128 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
129 | parser.add_argument("--dataset", type=str, default='imagenet')
130 | parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512], default=256)
131 |     parser.add_argument("--ten-crop", action='store_true', help="whether to use ten-crop augmentation")
132 | parser.add_argument("--crop-range", type=float, default=1.1, help="expanding range of center crop")
133 | parser.add_argument("--global-seed", type=int, default=0)
134 | parser.add_argument("--num-workers", type=int, default=24)
135 | parser.add_argument("--debug", action='store_true')
136 | args = parser.parse_args()
137 | main(args)
138 |
--------------------------------------------------------------------------------
/autoregressive/train/extract_codes_t2i.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # fast-DiT: https://github.com/chuanyangjin/fast-DiT/blob/main/extract_features.py
3 | import torch
4 | torch.backends.cuda.matmul.allow_tf32 = True
5 | torch.backends.cudnn.allow_tf32 = True
6 | import torch.distributed as dist
7 | from torch.utils.data import Dataset, DataLoader
8 | from torch.utils.data.distributed import DistributedSampler
9 | from torchvision import transforms
10 | import numpy as np
11 | from PIL import Image
12 | import glob
13 | import argparse
14 | import os
15 | import json
16 |
17 | from utils.distributed import init_distributed_mode
18 | from dataset.augmentation import center_crop_arr
19 | from tokenizer.tokenizer_image.vq_model import VQ_models
20 |
21 |
22 | #################################################################################
23 | # Training Helper Functions #
24 | #################################################################################
25 | class CustomDataset(Dataset):
26 | def __init__(self, lst_dir, start, end, transform):
27 | img_path_list = []
28 | for lst_name in sorted(os.listdir(lst_dir))[start: end+1]:
29 | if not lst_name.endswith('.jsonl'):
30 | continue
31 | file_path = os.path.join(lst_dir, lst_name)
32 | with open(file_path, 'r') as file:
33 | for line_idx, line in enumerate(file):
34 | data = json.loads(line)
35 | img_path = data['image_path']
36 | code_dir = file_path.split('/')[-1].split('.')[0]
37 | img_path_list.append((img_path, code_dir, line_idx))
38 | self.img_path_list = img_path_list
39 | self.transform = transform
40 |
41 | def __len__(self):
42 | return len(self.img_path_list)
43 |
44 | def __getitem__(self, index):
45 | img_path, code_dir, code_name = self.img_path_list[index]
46 | img = Image.open(img_path).convert("RGB")
47 | if self.transform is not None:
48 | img = self.transform(img)
49 | return img, code_dir, code_name
50 |
51 |
52 |
53 | #################################################################################
54 | # Training Loop #
55 | #################################################################################
56 | def main(args):
57 |     """
58 |     Extracts VQ codes from text-to-image training data.
59 |     """
60 | assert torch.cuda.is_available(), "Training currently requires at least one GPU."
61 |
62 | # Setup DDP:
63 | # dist.init_process_group("nccl")
64 | init_distributed_mode(args)
65 | rank = dist.get_rank()
66 | device = rank % torch.cuda.device_count()
67 | seed = args.global_seed * dist.get_world_size() + rank
68 | torch.manual_seed(seed)
69 | torch.cuda.set_device(device)
70 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
71 |
72 | # Setup a feature folder:
73 | if rank == 0:
74 | os.makedirs(args.code_path, exist_ok=True)
75 |
76 |
77 | # create and load model
78 | vq_model = VQ_models[args.vq_model](
79 | codebook_size=args.codebook_size,
80 | codebook_embed_dim=args.codebook_embed_dim)
81 | vq_model.to(device)
82 | vq_model.eval()
83 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
84 | vq_model.load_state_dict(checkpoint["model"])
85 | del checkpoint
86 |
87 |
88 | # Setup data:
89 | transform = transforms.Compose([
90 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
91 | transforms.ToTensor(),
92 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
93 | ])
94 |     print("Preparing dataset...")
95 | dataset = CustomDataset(args.data_path, args.data_start, args.data_end, transform=transform)
96 | sampler = DistributedSampler(
97 | dataset,
98 | num_replicas=dist.get_world_size(),
99 | rank=rank,
100 | shuffle=False,
101 | seed=args.global_seed
102 | )
103 | loader = DataLoader(
104 | dataset,
105 | batch_size=1, # important!
106 | shuffle=False,
107 | sampler=sampler,
108 | num_workers=args.num_workers,
109 | pin_memory=True,
110 | drop_last=False
111 | )
112 | print(f"Dataset contains {len(dataset):,} images")
113 |
114 | # total = 0
115 | for img, code_dir, code_name in loader:
116 | img = img.to(device)
117 | with torch.no_grad():
118 | _, _, [_, _, indices] = vq_model.encode(img)
119 | codes = indices.reshape(img.shape[0], -1)
120 | x = codes.detach().cpu().numpy() # (1, args.image_size//16 * args.image_size//16)
121 | os.makedirs(os.path.join(args.code_path, code_dir[0]), exist_ok=True)
122 | np.save(os.path.join(args.code_path, code_dir[0], '{}.npy'.format(code_name.item())), x)
123 |
124 | # total += dist.get_world_size()
125 | print(code_name.item())
126 |
127 | dist.destroy_process_group()
128 |
129 |
130 | if __name__ == "__main__":
131 | parser = argparse.ArgumentParser()
132 | parser.add_argument("--data-path", type=str, required=True)
133 | parser.add_argument("--code-path", type=str, required=True)
134 | parser.add_argument("--data-start", type=int, required=True)
135 | parser.add_argument("--data-end", type=int, required=True)
136 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
137 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
138 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
139 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
140 | parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512], default=512)
141 | parser.add_argument("--global-seed", type=int, default=0)
142 | parser.add_argument("--num-workers", type=int, default=24)
143 | args = parser.parse_args()
144 | main(args)
145 |
--------------------------------------------------------------------------------
/dataset/augmentation.py:
--------------------------------------------------------------------------------
1 | # from https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py
2 | import math
3 | import random
4 | import numpy as np
5 | from PIL import Image
6 |
7 |
8 | def center_crop_arr(pil_image, image_size):
9 | """
10 | Center cropping implementation from ADM.
11 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
12 | """
13 | while min(*pil_image.size) >= 2 * image_size:
14 | pil_image = pil_image.resize(
15 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
16 | )
17 |
18 | scale = image_size / min(*pil_image.size)
19 | pil_image = pil_image.resize(
20 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
21 | )
22 |
23 | arr = np.array(pil_image)
24 | crop_y = (arr.shape[0] - image_size) // 2
25 | crop_x = (arr.shape[1] - image_size) // 2
26 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
27 |
28 |
29 | def random_crop_arr(pil_image, image_size, min_crop_frac=0.8, max_crop_frac=1.0):
30 | min_smaller_dim_size = math.ceil(image_size / max_crop_frac)
31 | max_smaller_dim_size = math.ceil(image_size / min_crop_frac)
32 | smaller_dim_size = random.randrange(min_smaller_dim_size, max_smaller_dim_size + 1)
33 |
34 | # We are not on a new enough PIL to support the `reducing_gap`
35 | # argument, which uses BOX downsampling at powers of two first.
36 | # Thus, we do it by hand to improve downsample quality.
37 | while min(*pil_image.size) >= 2 * smaller_dim_size:
38 | pil_image = pil_image.resize(
39 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
40 | )
41 |
42 | scale = smaller_dim_size / min(*pil_image.size)
43 | pil_image = pil_image.resize(
44 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
45 | )
46 |
47 | arr = np.array(pil_image)
48 | crop_y = random.randrange(arr.shape[0] - image_size + 1)
49 | crop_x = random.randrange(arr.shape[1] - image_size + 1)
50 | return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
51 |
52 |
--------------------------------------------------------------------------------
/dataset/build.py:
--------------------------------------------------------------------------------
1 | from dataset.imagenet import build_imagenet, build_imagenet_code
2 | from dataset.coco import build_coco
3 | from dataset.openimage import build_openimage
4 | from dataset.pexels import build_pexels
5 | from dataset.t2i import build_t2i, build_t2i_code, build_t2i_image
6 |
7 |
8 | def build_dataset(args, **kwargs):
9 | # images
10 | if args.dataset == 'imagenet':
11 | return build_imagenet(args, **kwargs)
12 | if args.dataset == 'imagenet_code':
13 | return build_imagenet_code(args, **kwargs)
14 | if args.dataset == 'coco':
15 | return build_coco(args, **kwargs)
16 | if args.dataset == 'openimage':
17 | return build_openimage(args, **kwargs)
18 | if args.dataset == 'pexels':
19 | return build_pexels(args, **kwargs)
20 | if args.dataset == 't2i_image':
21 | return build_t2i_image(args, **kwargs)
22 | if args.dataset == 't2i':
23 | return build_t2i(args, **kwargs)
24 | if args.dataset == 't2i_code':
25 | return build_t2i_code(args, **kwargs)
26 |
27 | raise ValueError(f'dataset {args.dataset} is not supported')
--------------------------------------------------------------------------------
/dataset/coco.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from torch.utils.data import Dataset
4 | from PIL import Image
5 |
6 |
7 | class SingleFolderDataset(Dataset):
8 | def __init__(self, directory, transform=None):
9 | super().__init__()
10 | self.directory = directory
11 | self.transform = transform
12 | self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory)
13 | if os.path.isfile(os.path.join(directory, file_name))]
14 |
15 | def __len__(self):
16 | return len(self.image_paths)
17 |
18 | def __getitem__(self, idx):
19 | image_path = self.image_paths[idx]
20 | image = Image.open(image_path).convert('RGB')
21 | if self.transform:
22 | image = self.transform(image)
23 | return image, torch.tensor(0)
24 |
25 |
26 | def build_coco(args, transform):
27 | return SingleFolderDataset(args.data_path, transform=transform)
--------------------------------------------------------------------------------
/dataset/imagenet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | from torch.utils.data import Dataset
5 | from torchvision.datasets import ImageFolder
6 |
7 |
8 | class CustomDataset(Dataset):
9 | def __init__(self, feature_dir, label_dir):
10 | self.feature_dir = feature_dir
11 | self.label_dir = label_dir
12 | self.flip = 'flip' in self.feature_dir
13 |
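        # Optional second set of pre-extracted codes/labels (the "ten_crop_105"
        # directories); when present, __getitem__ samples from it with 50% probability.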
14 | aug_feature_dir = feature_dir.replace('ten_crop/', 'ten_crop_105/')
15 | aug_label_dir = label_dir.replace('ten_crop/', 'ten_crop_105/')
16 | if os.path.exists(aug_feature_dir) and os.path.exists(aug_label_dir):
17 | self.aug_feature_dir = aug_feature_dir
18 | self.aug_label_dir = aug_label_dir
19 | else:
20 | self.aug_feature_dir = None
21 | self.aug_label_dir = None
22 |
23 | # self.feature_files = sorted(os.listdir(feature_dir))
24 | # self.label_files = sorted(os.listdir(label_dir))
25 | # TODO: make it configurable
26 | self.feature_files = [f"{i}.npy" for i in range(1281167)]
27 | self.label_files = [f"{i}.npy" for i in range(1281167)]
28 |
29 | def __len__(self):
30 | assert len(self.feature_files) == len(self.label_files), \
31 |             "Number of feature files and label files should be the same"
32 | return len(self.feature_files)
33 |
34 | def __getitem__(self, idx):
35 | if self.aug_feature_dir is not None and torch.rand(1) < 0.5:
36 | feature_dir = self.aug_feature_dir
37 | label_dir = self.aug_label_dir
38 | else:
39 | feature_dir = self.feature_dir
40 | label_dir = self.label_dir
41 |
42 | feature_file = self.feature_files[idx]
43 | label_file = self.label_files[idx]
44 |
45 | features = np.load(os.path.join(feature_dir, feature_file))
46 | if self.flip:
47 | aug_idx = torch.randint(low=0, high=features.shape[1], size=(1,)).item()
48 | features = features[:, aug_idx]
49 | labels = np.load(os.path.join(label_dir, label_file))
50 | return torch.from_numpy(features), torch.from_numpy(labels)
51 |
52 |
53 | def build_imagenet(args, transform):
54 | return ImageFolder(args.data_path, transform=transform)
55 |
56 | def build_imagenet_code(args):
57 | feature_dir = f"{args.code_path}/imagenet{args.image_size}_codes"
58 | label_dir = f"{args.code_path}/imagenet{args.image_size}_labels"
59 | assert os.path.exists(feature_dir) and os.path.exists(label_dir), \
60 | f"please first run: bash scripts/autoregressive/extract_codes_c2i.sh ..."
61 | return CustomDataset(feature_dir, label_dir)
--------------------------------------------------------------------------------
/dataset/openimage.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 | from PIL import Image
5 |
6 | import torch
7 | from torch.utils.data import Dataset
8 |
9 |
10 | class DatasetJson(Dataset):
11 | def __init__(self, data_path, transform=None):
12 | super().__init__()
13 | self.data_path = data_path
14 | self.transform = transform
15 | json_path = os.path.join(data_path, 'image_paths.json')
16 | assert os.path.exists(json_path), f"please first run: python3 tools/openimage_json.py"
17 | with open(json_path, 'r') as f:
18 | self.image_paths = json.load(f)
19 |
20 | def __len__(self):
21 | return len(self.image_paths)
22 |
23 | def __getitem__(self, idx):
24 | for _ in range(20):
25 | try:
26 | return self.getdata(idx)
27 | except Exception as e:
28 | print(f"Error details: {str(e)}")
29 | idx = np.random.randint(len(self))
30 |         raise RuntimeError('Too many bad data samples.')
31 |
32 | def getdata(self, idx):
33 | image_path = self.image_paths[idx]
34 | image_path_full = os.path.join(self.data_path, image_path)
35 | image = Image.open(image_path_full).convert('RGB')
36 | if self.transform:
37 | image = self.transform(image)
38 | return image, torch.tensor(0)
39 |
40 |
41 | def build_openimage(args, transform):
42 | return DatasetJson(args.data_path, transform=transform)
43 |
--------------------------------------------------------------------------------
/dataset/pexels.py:
--------------------------------------------------------------------------------
1 | from torchvision.datasets import ImageFolder
2 |
3 | def build_pexels(args, transform):
4 | return ImageFolder(args.data_path, transform=transform)
--------------------------------------------------------------------------------
/dataset/t2i.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 |
5 | import torch
6 | from torch.utils.data import Dataset
7 | from PIL import Image
8 |
9 |
10 | class Text2ImgDatasetImg(Dataset):
11 | def __init__(self, lst_dir, face_lst_dir, transform):
12 | img_path_list = []
13 | valid_file_path = []
14 | # collect valid jsonl
15 | for lst_name in sorted(os.listdir(lst_dir)):
16 | if not lst_name.endswith('.jsonl'):
17 | continue
18 | file_path = os.path.join(lst_dir, lst_name)
19 | valid_file_path.append(file_path)
20 |
21 | # collect valid jsonl for face
22 | if face_lst_dir is not None:
23 | for lst_name in sorted(os.listdir(face_lst_dir)):
24 | if not lst_name.endswith('_face.jsonl'):
25 | continue
26 | file_path = os.path.join(face_lst_dir, lst_name)
27 | valid_file_path.append(file_path)
28 |
29 | for file_path in valid_file_path:
30 | with open(file_path, 'r') as file:
31 | for line_idx, line in enumerate(file):
32 | data = json.loads(line)
33 | img_path = data['image_path']
34 | code_dir = file_path.split('/')[-1].split('.')[0]
35 | img_path_list.append((img_path, code_dir, line_idx))
36 | self.img_path_list = img_path_list
37 | self.transform = transform
38 |
39 | def __len__(self):
40 | return len(self.img_path_list)
41 |
42 | def __getitem__(self, index):
43 | img_path, code_dir, code_name = self.img_path_list[index]
44 | img = Image.open(img_path).convert("RGB")
45 | if self.transform is not None:
46 | img = self.transform(img)
47 | return img, code_name
48 |
49 |
50 | class Text2ImgDataset(Dataset):
51 | def __init__(self, args, transform):
52 | img_path_list = []
53 | valid_file_path = []
54 | # collect valid jsonl file path
55 | for lst_name in sorted(os.listdir(args.data_path)):
56 | if not lst_name.endswith('.jsonl'):
57 | continue
58 | file_path = os.path.join(args.data_path, lst_name)
59 | valid_file_path.append(file_path)
60 |
61 | for file_path in valid_file_path:
62 | with open(file_path, 'r') as file:
63 | for line_idx, line in enumerate(file):
64 | data = json.loads(line)
65 | img_path = data['image_path']
66 | code_dir = file_path.split('/')[-1].split('.')[0]
67 | img_path_list.append((img_path, code_dir, line_idx))
68 | self.img_path_list = img_path_list
69 | self.transform = transform
70 |
71 | self.t5_feat_path = args.t5_feat_path
72 | self.short_t5_feat_path = args.short_t5_feat_path
73 | self.t5_feat_path_base = self.t5_feat_path.split('/')[-1]
74 | if self.short_t5_feat_path is not None:
75 | self.short_t5_feat_path_base = self.short_t5_feat_path.split('/')[-1]
76 | else:
77 | self.short_t5_feat_path_base = self.t5_feat_path_base
78 | self.image_size = args.image_size
79 | latent_size = args.image_size // args.downsample_size
80 | self.code_len = latent_size ** 2
81 | self.t5_feature_max_len = 120
82 | self.t5_feature_dim = 2048
83 | self.max_seq_length = self.t5_feature_max_len + self.code_len
84 |
85 | def __len__(self):
86 | return len(self.img_path_list)
87 |
88 | def dummy_data(self):
89 | img = torch.zeros((3, self.image_size, self.image_size), dtype=torch.float32)
90 | t5_feat_padding = torch.zeros((1, self.t5_feature_max_len, self.t5_feature_dim))
91 | attn_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)).unsqueeze(0)
92 | valid = 0
93 | return img, t5_feat_padding, attn_mask, valid
94 |
95 | def __getitem__(self, index):
96 | img_path, code_dir, code_name = self.img_path_list[index]
97 | try:
98 | img = Image.open(img_path).convert("RGB")
99 |         except Exception:
100 | img, t5_feat_padding, attn_mask, valid = self.dummy_data()
101 | return img, t5_feat_padding, attn_mask, torch.tensor(valid)
102 |
103 | if min(img.size) < self.image_size:
104 | img, t5_feat_padding, attn_mask, valid = self.dummy_data()
105 | return img, t5_feat_padding, attn_mask, torch.tensor(valid)
106 |
107 | if self.transform is not None:
108 | img = self.transform(img)
109 |
110 | t5_file = os.path.join(self.t5_feat_path, code_dir, f"{code_name}.npy")
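        # With probability 0.3, point to the truncated-caption T5 features instead (when
        # --short-t5-feat-path is set; see scripts/language/extract_flan_t5_feat_trunc_stage2.sh).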
111 | if torch.rand(1) < 0.3:
112 | t5_file = t5_file.replace(self.t5_feat_path_base, self.short_t5_feat_path_base)
113 |
114 | t5_feat_padding = torch.zeros((1, self.t5_feature_max_len, self.t5_feature_dim))
115 | if os.path.isfile(t5_file):
116 | try:
117 | t5_feat = torch.from_numpy(np.load(t5_file))
118 | t5_feat_len = t5_feat.shape[1]
119 | feat_len = min(self.t5_feature_max_len, t5_feat_len)
120 | t5_feat_padding[:, -feat_len:] = t5_feat[:, :feat_len]
121 | emb_mask = torch.zeros((self.t5_feature_max_len,))
122 | emb_mask[-feat_len:] = 1
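                # Build a causal mask over [T5 tokens + image tokens], zero out the
                # columns of padded T5 positions, then restore the diagonal so every
                # position can still attend to itself.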
123 | attn_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length))
124 | T = self.t5_feature_max_len
125 | attn_mask[:, :T] = attn_mask[:, :T] * emb_mask.unsqueeze(0)
126 | eye_matrix = torch.eye(self.max_seq_length, self.max_seq_length)
127 | attn_mask = attn_mask * (1 - eye_matrix) + eye_matrix
128 | attn_mask = attn_mask.unsqueeze(0).to(torch.bool)
129 | valid = 1
130 |             except Exception:
131 | img, t5_feat_padding, attn_mask, valid = self.dummy_data()
132 | else:
133 | img, t5_feat_padding, attn_mask, valid = self.dummy_data()
134 |
135 | return img, t5_feat_padding, attn_mask, torch.tensor(valid)
136 |
137 |
138 | class Text2ImgDatasetCode(Dataset):
139 | def __init__(self, args):
140 | pass
141 |
142 |
143 |
144 |
145 | def build_t2i_image(args, transform):
146 | return Text2ImgDatasetImg(args.data_path, args.data_face_path, transform)
147 |
148 | def build_t2i(args, transform):
149 | return Text2ImgDataset(args, transform)
150 |
151 | def build_t2i_code(args):
152 | return Text2ImgDatasetCode(args)
--------------------------------------------------------------------------------
/evaluations/c2i/README.md:
--------------------------------------------------------------------------------
1 | # Evaluations from [OpenAI](https://github.com/openai/guided-diffusion/tree/main/evaluations)
2 |
3 | To compare different generative models, we use FID, sFID, Precision, Recall, and Inception Score. These metrics can all be calculated using batches of samples, which we store in `.npz` (numpy) files.
4 |
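A sample batch is just a NumPy archive of generated images. As a rough sketch (assuming the evaluator reads the first array in the archive as uint8 images of shape `[N, H, W, 3]`; the `samples_imagenet256/` folder and the output file name are illustrative), a batch of generated PNGs can be packed like this:

```
import glob

import numpy as np
from PIL import Image

# Stack the generated images into one uint8 array of shape [N, H, W, 3]
# and store it in an .npz archive (saved under the default key "arr_0").
paths = sorted(glob.glob("samples_imagenet256/*.png"))
images = np.stack([np.asarray(Image.open(p).convert("RGB"), dtype=np.uint8) for p in paths])
np.savez("my_samples_imagenet256.npz", images)
```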
5 | # Installation
6 | ### cuda version 11.7
7 | ```
8 | pip install tensorflow-gpu==2.5.0
9 | pip install numpy==1.22.0
10 | pip install scipy
11 | pip install pydantic
12 | ```
13 | An error like `tensorflow.python.framework.errors_impl.NotFoundError: /usr/local/lib/python3.9/dist-packages/tensorflow/core/kernels/libtfkernel_sobol_op.so: undefined symbol: _ZN10tensorflow...` may occur; deleting `/usr/local/lib/python3.9/dist-packages/tensorflow/core/kernels/libtfkernel_sobol_op.so` fixes it.
14 |
15 | ### cuda version 12.1
16 | ```
17 | pip install tensorflow
18 | pip install numpy==1.23.5
19 | pip install scipy
20 | ```
21 |
22 | ### H100, cuda version 12.2
23 | ```
24 | pip install tensorflow
25 | pip install numpy==1.26.2
26 | pip install scipy
27 | ```
28 |
29 | # Download batches
30 |
31 | We provide pre-computed sample batches for the reference datasets, our diffusion models, and several baselines we compare against. These are all stored in `.npz` format.
32 |
33 | Reference dataset batches contain pre-computed statistics over the whole dataset, as well as 10,000 images for computing Precision and Recall. All other batches contain 50,000 images which can be used to compute statistics and Precision/Recall.
34 |
35 | Here are links to download all of the sample and reference batches:
36 |
37 | * LSUN
38 | * LSUN bedroom: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/VIRTUAL_lsun_bedroom256.npz)
39 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/admnet_dropout_lsun_bedroom.npz)
40 | * [DDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/ddpm_lsun_bedroom.npz)
41 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/iddpm_lsun_bedroom.npz)
42 | * [StyleGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/bedroom/stylegan_lsun_bedroom.npz)
43 | * LSUN cat: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/VIRTUAL_lsun_cat256.npz)
44 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/admnet_dropout_lsun_cat.npz)
45 | * [StyleGAN2](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/cat/stylegan2_lsun_cat.npz)
46 | * LSUN horse: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/VIRTUAL_lsun_horse256.npz)
47 | * [ADM (dropout)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_dropout_lsun_horse.npz)
48 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/lsun/horse/admnet_lsun_horse.npz)
49 |
50 | * ImageNet
51 | * ImageNet 64x64: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/VIRTUAL_imagenet64_labeled.npz)
52 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/admnet_imagenet64.npz)
53 | * [IDDPM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/iddpm_imagenet64.npz)
54 | * [BigGAN](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/64/biggan_deep_imagenet64.npz)
55 | * ImageNet 128x128: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/VIRTUAL_imagenet128_labeled.npz)
56 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_imagenet128.npz)
57 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_imagenet128.npz)
58 | * [ADM-G, 25 steps](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/admnet_guided_25step_imagenet128.npz)
59 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/128/biggan_deep_trunc1_imagenet128.npz)
60 | * ImageNet 256x256: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/VIRTUAL_imagenet256_labeled.npz)
61 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_imagenet256.npz)
62 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_imagenet256.npz)
63 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_25step_imagenet256.npz)
64 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_guided_upsampled_imagenet256.npz)
65 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/admnet_upsampled_imagenet256.npz)
66 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/256/biggan_deep_trunc1_imagenet256.npz)
67 | * ImageNet 512x512: [reference batch](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/VIRTUAL_imagenet512.npz)
68 | * [ADM](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_imagenet512.npz)
69 | * [ADM-G](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_imagenet512.npz)
70 | * [ADM-G, 25 step](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_25step_imagenet512.npz)
71 | * [ADM-G + ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_guided_upsampled_imagenet512.npz)
72 | * [ADM-U](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/admnet_upsampled_imagenet512.npz)
73 | * [BigGAN-deep (trunc=1.0)](https://openaipublic.blob.core.windows.net/diffusion/jul-2021/ref_batches/imagenet/512/biggan_deep_trunc1_imagenet512.npz)
74 |
75 | # Run evaluations
76 |
77 | First, generate or download a batch of samples and download the corresponding reference batch for the given dataset. For this example, we'll use ImageNet 256x256, so the reference batch is `VIRTUAL_imagenet256_labeled.npz` and we can use the sample batch `admnet_guided_upsampled_imagenet256.npz`.
78 |
79 | Next, run the `evaluator.py` script. The requirements of this script can be found in [requirements.txt](requirements.txt). Pass two arguments to the script: the reference batch and the sample batch. The script will download the InceptionV3 model used for evaluations into the current working directory (if it is not already present). This file is roughly 100MB.
80 |
81 | The output of the script will look something like this, where the first `...` is a bunch of verbose TensorFlow logging:
82 |
83 | ```
84 | $ python evaluator.py VIRTUAL_imagenet256_labeled.npz admnet_guided_upsampled_imagenet256.npz
85 | ...
86 | computing reference batch activations...
87 | computing/reading reference batch statistics...
88 | computing sample batch activations...
89 | computing/reading sample batch statistics...
90 | Computing evaluations...
91 | Inception Score: 215.8370361328125
92 | FID: 3.9425574129223264
93 | sFID: 6.140433703346162
94 | Precision: 0.8265
95 | Recall: 0.5309
96 | ```
97 |
--------------------------------------------------------------------------------
/evaluations/t2i/README.md:
--------------------------------------------------------------------------------
1 | # Evaluations from [GigaGAN](https://github.com/mingukkang/GigaGAN/tree/main/evaluation)
2 |
3 | ```
4 | pip install git+https://github.com/openai/CLIP.git
5 | pip install open_clip_torch
6 | pip install clean_fid
7 | ```
8 |
9 | ```
10 | python3 evaluations/t2i/evaluation.py \
11 | --eval_res 256 \
12 | --batch_size 256 \
13 | --how_many 30000 \
14 | --ref_data "coco2014" \
15 | --ref_type "val2014" \
18 | --ref_dir "/path/to/coco" \
19 | --fake_dir "/path/to/generation" \
20 | $@
21 | ```
22 |
--------------------------------------------------------------------------------
/language/README.md:
--------------------------------------------------------------------------------
1 | ## Language models for text-conditional image generation
2 |
3 | ### Requirements
4 | ```
5 | pip install ftfy
6 | pip install transformers
7 | pip install accelerate
8 | pip install sentencepiece
9 | pip install pandas
10 | pip install bs4
11 | ```
12 |
13 | ### Language Models
14 | Download the flan-t5-xl model from [flan-t5-xl](https://huggingface.co/google/flan-t5-xl) and put it into the folder `./pretrained_models/t5-ckpt/`, for example as shown below.
15 |
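For example, the checkpoint can be fetched with `huggingface_hub` (a sketch; adjust `local_dir` if `T5Embedder` expects a different subfolder layout under `--t5-model-path`):

```
from huggingface_hub import snapshot_download

# Download google/flan-t5-xl into the folder used by language/extract_t5_feature.py
# (--t5-model-path ./pretrained_models/t5-ckpt with --t5-model-type flan-t5-xl).
snapshot_download(
    repo_id="google/flan-t5-xl",
    local_dir="./pretrained_models/t5-ckpt/flan-t5-xl",
)
```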
--------------------------------------------------------------------------------
/language/extract_t5_feature.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.backends.cuda.matmul.allow_tf32 = True
3 | torch.backends.cudnn.allow_tf32 = True
4 | import torch.distributed as dist
5 | from torch.utils.data import Dataset, DataLoader
6 | from torch.utils.data.distributed import DistributedSampler
7 | import numpy as np
8 | import argparse
9 | import os
10 | import json
11 |
12 | from utils.distributed import init_distributed_mode
13 | from language.t5 import T5Embedder
14 |
15 | CAPTION_KEY = {
16 | 'blip': 0,
17 | 'llava': 1,
18 | 'llava_first': 2,
19 | }
20 | #################################################################################
21 | # Training Helper Functions #
22 | #################################################################################
23 | class CustomDataset(Dataset):
24 | def __init__(self, lst_dir, start, end, caption_key, trunc_caption=False):
25 | img_path_list = []
26 | for lst_name in sorted(os.listdir(lst_dir))[start: end+1]:
27 | if not lst_name.endswith('.jsonl'):
28 | continue
29 | file_path = os.path.join(lst_dir, lst_name)
30 | with open(file_path, 'r') as file:
31 | for line_idx, line in enumerate(file):
32 | data = json.loads(line)
33 | # caption = data[caption_key]
34 | caption = data['text'][CAPTION_KEY[caption_key]]
35 | code_dir = file_path.split('/')[-1].split('.')[0]
36 | if trunc_caption:
37 | caption = caption.split('.')[0]
38 | img_path_list.append((caption, code_dir, line_idx))
39 | self.img_path_list = img_path_list
40 |
41 | def __len__(self):
42 | return len(self.img_path_list)
43 |
44 | def __getitem__(self, index):
45 | caption, code_dir, code_name = self.img_path_list[index]
46 | return caption, code_dir, code_name
47 |
48 |
49 |
50 | #################################################################################
51 | # Training Loop #
52 | #################################################################################
53 | def main(args):
54 |     """
55 |     Extracts flan-t5 text features from captions.
56 |     """
57 | assert torch.cuda.is_available(), "Training currently requires at least one GPU."
58 |
59 | # Setup DDP:
60 | # dist.init_process_group("nccl")
61 | init_distributed_mode(args)
62 | rank = dist.get_rank()
63 | device = rank % torch.cuda.device_count()
64 | seed = args.global_seed * dist.get_world_size() + rank
65 | torch.manual_seed(seed)
66 | torch.cuda.set_device(device)
67 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
68 |
69 | # Setup a feature folder:
70 | if rank == 0:
71 | os.makedirs(args.t5_path, exist_ok=True)
72 |
73 | # Setup data:
74 |     print("Preparing dataset...")
75 | dataset = CustomDataset(args.data_path, args.data_start, args.data_end, args.caption_key, args.trunc_caption)
76 | sampler = DistributedSampler(
77 | dataset,
78 | num_replicas=dist.get_world_size(),
79 | rank=rank,
80 | shuffle=False,
81 | seed=args.global_seed
82 | )
83 | loader = DataLoader(
84 | dataset,
85 | batch_size=1, # important!
86 | shuffle=False,
87 | sampler=sampler,
88 | num_workers=args.num_workers,
89 | pin_memory=True,
90 | drop_last=False
91 | )
92 |     print(f"Dataset contains {len(dataset):,} captions")
93 |
94 | precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
95 | assert os.path.exists(args.t5_model_path)
96 | t5_xxl = T5Embedder(
97 | device=device,
98 | local_cache=True,
99 | cache_dir=args.t5_model_path,
100 | dir_or_name=args.t5_model_type,
101 | torch_dtype=precision
102 | )
103 |
104 | for caption, code_dir, code_name in loader:
105 | caption_embs, emb_masks = t5_xxl.get_text_embeddings(caption)
106 | valid_caption_embs = caption_embs[:, :emb_masks.sum()]
107 | x = valid_caption_embs.to(torch.float32).detach().cpu().numpy()
108 | os.makedirs(os.path.join(args.t5_path, code_dir[0]), exist_ok=True)
109 | np.save(os.path.join(args.t5_path, code_dir[0], '{}.npy'.format(code_name.item())), x)
110 | print(code_name.item())
111 |
112 | dist.destroy_process_group()
113 |
114 |
115 | if __name__ == "__main__":
116 | parser = argparse.ArgumentParser()
117 | parser.add_argument("--data-path", type=str, required=True)
118 | parser.add_argument("--t5-path", type=str, required=True)
119 | parser.add_argument("--data-start", type=int, required=True)
120 | parser.add_argument("--data-end", type=int, required=True)
121 | parser.add_argument("--caption-key", type=str, default='blip', choices=list(CAPTION_KEY.keys()))
122 | parser.add_argument("--trunc-caption", action='store_true', default=False)
123 | parser.add_argument("--t5-model-path", type=str, default='./pretrained_models/t5-ckpt')
124 | parser.add_argument("--t5-model-type", type=str, default='flan-t5-xl')
125 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
126 | parser.add_argument("--global-seed", type=int, default=0)
127 | parser.add_argument("--num-workers", type=int, default=24)
128 | args = parser.parse_args()
129 | main(args)
130 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=2.1.0
2 |
--------------------------------------------------------------------------------
/scripts/autoregressive/extract_codes_c2i.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12335 \
7 | autoregressive/train/extract_codes_c2i.py "$@"
8 |
--------------------------------------------------------------------------------
/scripts/autoregressive/sample_c2i.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12345 \
7 | autoregressive/sample/sample_c2i_ddp.py \
8 | --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt \
9 | "$@"
10 |
--------------------------------------------------------------------------------
/scripts/autoregressive/sample_t2i_coco.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12346 \
7 | autoregressive/sample/sample_t2i_ddp.py \
8 | --prompt-csv evaluations/t2i/coco_captions.csv \
9 | --sample-dir samples_coco \
10 | --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt \
11 | "$@"
12 |
--------------------------------------------------------------------------------
/scripts/autoregressive/sample_t2i_parti.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12347 \
7 | autoregressive/sample/sample_t2i_ddp.py \
8 | --prompt-csv evaluations/t2i/PartiPrompts.tsv \
9 | --sample-dir samples_parti \
10 | --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt \
11 | "$@"
12 |
--------------------------------------------------------------------------------
/scripts/autoregressive/train_c2i.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \
6 | --master_addr=$master_addr --master_port=$master_port \
7 | autoregressive/train/train_c2i.py "$@"
8 |
--------------------------------------------------------------------------------
/scripts/autoregressive/train_c2i_fsdp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \
6 | --master_addr=$master_addr --master_port=$master_port \
7 | autoregressive/train/train_c2i_fsdp.py "$@"
8 |
--------------------------------------------------------------------------------
/scripts/autoregressive/train_t2i_stage1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \
6 | --master_addr=$master_addr --master_port=$master_port \
7 | autoregressive/train/train_t2i.py \
8 | --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt \
9 | --data-path /path/to/laion_coco50M \
10 | --t5-feat-path /path/to/laion_coco50M_flan_t5_xl \
11 | --dataset t2i \
12 | --image-size 256 \
13 | "$@"
14 |
--------------------------------------------------------------------------------
/scripts/autoregressive/train_t2i_stage2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \
6 | --master_addr=$master_addr --master_port=$master_port \
7 | autoregressive/train/train_t2i.py \
8 | --vq-ckpt ./pretrained_models/vq_ds16_t2i.pt \
9 | --data-path /path/to/high_aesthetic_10M \
10 | --t5-feat-path /path/to/high_aesthetic_10M_flan_t5_xl \
11 | --short-t5-feat-path /path/to/high_aesthetic_10M_trunc_flan_t5_xl \
12 | --dataset t2i \
13 | --image-size 512 \
14 | "$@"
15 |
--------------------------------------------------------------------------------
/scripts/language/extract_flan_t5_feat_laion_coco_stage1.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12337 \
7 | language/extract_t5_feature.py \
8 | --data-path /path/to/laion_coco50M \
9 | --t5-path /path/to/laion_coco50M_flan_t5_xl \
10 | --caption-key blip \
11 | "$@"
12 |
--------------------------------------------------------------------------------
/scripts/language/extract_flan_t5_feat_stage2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12337 \
7 | language/extract_t5_feature.py \
8 | --data-path /path/to/high_aesthetic_10M \
9 | --t5-path /path/to/high_aesthetic_10M_flan_t5_xl \
10 | "$@"
11 |
--------------------------------------------------------------------------------
/scripts/language/extract_flan_t5_feat_trunc_stage2.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12337 \
7 | language/extract_t5_feature.py \
8 | --data-path /path/to/high_aesthetic_10M \
9 | --t5-path /path/to/high_aesthetic_10M_trunc_flan_t5_xl \
10 | --trunc-caption \
11 | "$@"
12 |
--------------------------------------------------------------------------------
/scripts/tokenizer/reconstruction_consistency_decoder.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12344 \
7 | tokenizer/consistencydecoder/reconstruction_cd_ddp.py \
8 | "$@"
--------------------------------------------------------------------------------
/scripts/tokenizer/reconstruction_vae.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12344 \
7 | tokenizer/vae/reconstruction_vae_ddp.py \
8 | "$@"
--------------------------------------------------------------------------------
/scripts/tokenizer/reconstruction_vq.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12344 \
7 | tokenizer/tokenizer_image/reconstruction_vq_ddp.py \
8 | "$@"
--------------------------------------------------------------------------------
/scripts/tokenizer/reconstruction_vqgan.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12344 \
7 | tokenizer/vqgan/reconstruction_vqgan_ddp.py \
8 | "$@"
--------------------------------------------------------------------------------
/scripts/tokenizer/train_vq.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \
6 | --master_addr=$master_addr --master_port=$master_port \
7 | tokenizer/tokenizer_image/vq_train.py "$@"
--------------------------------------------------------------------------------
/scripts/tokenizer/train_vq_finetune.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \
6 | --master_addr=$master_addr --master_port=$master_port \
7 | tokenizer/tokenizer_image/vq_train.py \
8 | --finetune \
9 | --disc-start 0 \
10 | --vq-ckpt ./pretrained_models/vq_ds16_c2i.pt \
11 | --dataset t2i_image \
12 | --data-path /path/to/high_aesthetic_10M \
13 | --data-face-path /path/to/face_2M \
14 | --cloud-save-path /path/to/cloud_disk \
15 | "$@"
--------------------------------------------------------------------------------
/scripts/tokenizer/train_vq_finetune_continue.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=$nnodes --nproc_per_node=$nproc_per_node --node_rank=$node_rank \
6 | --master_addr=$master_addr --master_port=$master_port \
7 | tokenizer/tokenizer_image/vq_train.py \
8 | --disc-start 0 \
9 | --dataset t2i_image \
10 | --data-path /path/to/high_aesthetic_10M \
11 | --data-face-path /path/to/face_2M \
12 | --cloud-save-path /path/to/cloud_disk \
13 | "$@"
14 |
15 | # To continue training from a checkpoint, append: --vq-ckpt /path/to/checkpoint.pt
--------------------------------------------------------------------------------
/scripts/tokenizer/val.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -x
3 |
4 | torchrun \
5 | --nnodes=1 --nproc_per_node=8 --node_rank=0 \
6 | --master_port=12343 \
7 | tokenizer/validation/val_ddp.py \
8 | "$@"
--------------------------------------------------------------------------------
/tokenizer/consistencydecoder/README.md:
--------------------------------------------------------------------------------
1 | ## Consistency Decoder from OpenAI
2 |
3 | ### install
4 | ```
5 | pip install diffusers
6 | pip install accelerate
7 | ```
8 |
9 | ### demo
10 | ```
11 | cd ${THIS_REPO_ROOT}
12 | python3 tokenizer/consistencydecoder/cd_demo.py
13 | ```
14 |
15 |
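16 | ### reconstruction
17 | A minimal sketch of evaluating reconstruction (PSNR / SSIM and an `.npz` for FID) over a dataset with the DDP wrapper; the paths below are placeholders:
18 | ```
19 | bash scripts/tokenizer/reconstruction_consistency_decoder.sh \
20 |     --data-path /path/to/imagenet/val --dataset imagenet
21 | ```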
--------------------------------------------------------------------------------
/tokenizer/consistencydecoder/cd_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from PIL import Image
6 | from diffusers import ConsistencyDecoderVAE
7 |
8 |
9 | def main(args):
10 | # Setup PyTorch:
11 | torch.manual_seed(args.seed)
12 | torch.set_grad_enabled(False)
13 | device = "cuda" if torch.cuda.is_available() else "cpu"
14 |
15 | # create and load model
16 | vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to(device)
17 |
18 | # load image
19 | img_path = args.image_path
20 | out_path = args.image_path.replace('.jpg', '_cd.jpg').replace('.jpeg', '_cd.jpeg').replace('.png', '_cd.png')
21 | input_size = args.image_size
22 | img = Image.open(img_path).convert("RGB")
23 |
24 | # preprocess
25 | size_org = img.size
26 | img = img.resize((input_size, input_size))
27 | img = np.array(img) / 255.
28 | x = 2.0 * img - 1.0 # x value is between [-1, 1]
29 | x = torch.tensor(x)
30 | x = x.unsqueeze(dim=0)
31 | x = torch.einsum('nhwc->nchw', x)
32 | x_input = x.half().to(device)
33 |
34 | # inference
35 | with torch.no_grad():
36 | # Map input images to latent space + normalize latents:
37 | latent = vae.encode(x_input).latent_dist.sample().mul_(0.18215)
38 | # reconstruct:
39 | output = vae.decode(latent / 0.18215).sample # output value is between [-1, 1]
40 |
41 | # postprocess
42 | output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0]
43 | sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
44 |
45 | # save
46 | Image.fromarray(sample).save(out_path)
47 | print("Reconstructed image is saved to {}".format(out_path))
48 |
49 |
50 |
51 | if __name__ == "__main__":
52 | parser = argparse.ArgumentParser()
53 | parser.add_argument("--image-path", type=str, default="assets/example.jpg")
54 | parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512)
55 | parser.add_argument("--seed", type=int, default=0)
56 | args = parser.parse_args()
57 | main(args)
58 |
--------------------------------------------------------------------------------
/tokenizer/consistencydecoder/reconstruction_cd_ddp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.backends.cuda.matmul.allow_tf32 = True
3 | torch.backends.cudnn.allow_tf32 = True
4 | import torch.distributed as dist
5 | from torch.utils.data import Dataset, DataLoader
6 | from torch.utils.data.distributed import DistributedSampler
7 | from torchvision.datasets import ImageFolder
8 | from torchvision import transforms
9 | from tqdm import tqdm
10 | import os
11 | import itertools
12 | from PIL import Image
13 | import numpy as np
14 | import argparse
15 | import random
16 |
17 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss
18 | from skimage.metrics import structural_similarity as ssim_loss
19 | from diffusers.models import ConsistencyDecoderVAE
20 |
21 |
22 | class SingleFolderDataset(Dataset):
23 | def __init__(self, directory, transform=None):
24 | super().__init__()
25 | self.directory = directory
26 | self.transform = transform
27 | self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory)
28 | if os.path.isfile(os.path.join(directory, file_name))]
29 |
30 | def __len__(self):
31 | return len(self.image_paths)
32 |
33 | def __getitem__(self, idx):
34 | image_path = self.image_paths[idx]
35 | image = Image.open(image_path).convert('RGB')
36 | if self.transform:
37 | image = self.transform(image)
38 | return image, torch.tensor(0)
39 |
40 |
41 | def create_npz_from_sample_folder(sample_dir, num=50_000):
42 | """
43 | Builds a single .npz file from a folder of .png samples.
44 | """
45 | samples = []
46 | for i in tqdm(range(num), desc="Building .npz file from samples"):
47 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
48 | sample_np = np.asarray(sample_pil).astype(np.uint8)
49 | samples.append(sample_np)
50 |
51 | random.shuffle(samples)  # shuffling is important for a reliable Inception Score (IS)
52 | samples = np.stack(samples)
53 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
54 | npz_path = f"{sample_dir}.npz"
55 | np.savez(npz_path, arr_0=samples)
56 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
57 | return npz_path
58 |
59 |
60 | def center_crop_arr(pil_image, image_size):
61 | """
62 | Center cropping implementation from ADM.
63 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
64 | """
65 | while min(*pil_image.size) >= 2 * image_size:
66 | pil_image = pil_image.resize(
67 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
68 | )
69 |
70 | scale = image_size / min(*pil_image.size)
71 | pil_image = pil_image.resize(
72 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
73 | )
74 |
75 | arr = np.array(pil_image)
76 | crop_y = (arr.shape[0] - image_size) // 2
77 | crop_x = (arr.shape[1] - image_size) // 2
78 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
79 |
80 |
81 | def main(args):
82 | # Setup PyTorch:
83 | assert torch.cuda.is_available(), "This script uses DDP (NCCL) and requires at least one GPU."
84 | torch.set_grad_enabled(False)
85 |
86 | # Setup env
87 | dist.init_process_group("nccl")
88 | rank = dist.get_rank()
89 | device = rank % torch.cuda.device_count()
90 | seed = args.global_seed * dist.get_world_size() + rank
91 | torch.manual_seed(seed)
92 | torch.cuda.set_device(device)
93 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
94 |
95 | # create and load model
96 | vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=torch.float16).to("cuda:{}".format(device))
97 |
98 | # Create folder to save samples:
99 | folder_name = f"openai-consistencydecoder-{args.dataset}-size-{args.image_size}-seed-{args.global_seed}"
100 | sample_folder_dir = f"{args.sample_dir}/{folder_name}"
101 | if rank == 0:
102 | os.makedirs(sample_folder_dir, exist_ok=True)
103 | print(f"Saving .png samples at {sample_folder_dir}")
104 | dist.barrier()
105 |
106 | # Setup data:
107 | transform = transforms.Compose([
108 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
109 | transforms.ToTensor(),
110 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
111 | ])
112 | if args.dataset == 'imagenet':
113 | dataset = ImageFolder(args.data_path, transform=transform)
114 | num_fid_samples = 50000
115 | elif args.dataset == 'coco':
116 | dataset = SingleFolderDataset(args.data_path, transform=transform)
117 | num_fid_samples = 5000
118 | else:
119 | raise Exception("please check dataset")
120 | sampler = DistributedSampler(
121 | dataset,
122 | num_replicas=dist.get_world_size(),
123 | rank=rank,
124 | shuffle=False,
125 | seed=args.global_seed
126 | )
127 | loader = DataLoader(
128 | dataset,
129 | batch_size=args.per_proc_batch_size,
130 | shuffle=False,
131 | sampler=sampler,
132 | num_workers=args.num_workers,
133 | pin_memory=True,
134 | drop_last=False
135 | )
136 |
137 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
138 | n = args.per_proc_batch_size
139 | global_batch_size = n * dist.get_world_size()
140 | psnr_val_rgb = []
141 | ssim_val_rgb = []
142 |
143 | loader = tqdm(loader) if rank == 0 else loader
144 | total = 0
145 | for x, _ in loader:
146 | rgb_gts = x
147 | rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1]
148 | x = x.half().to("cuda:{}".format(device))
149 | with torch.no_grad():
150 | # Map input images to latent space + normalize latents:
151 | latent = vae.encode(x).latent_dist.sample().mul_(0.18215)
152 | # reconstruct:
153 | samples = vae.decode(latent / 0.18215).sample # output value is between [-1, 1]
154 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
155 |
156 | # Save samples to disk as individual .png files
157 | for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
158 | index = i * dist.get_world_size() + rank + total
159 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
160 | # metric
161 | rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1]
162 | psnr = psnr_loss(rgb_restored, rgb_gt)
163 | ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1)
164 | psnr_val_rgb.append(psnr)
165 | ssim_val_rgb.append(ssim)
166 | total += global_batch_size
167 |
168 | # ------------------------------------
169 | # Summary
170 | # ------------------------------------
171 | # Make sure all processes have finished saving their samples
172 | dist.barrier()
173 | world_size = dist.get_world_size()
174 | gather_psnr_val = [None for _ in range(world_size)]
175 | gather_ssim_val = [None for _ in range(world_size)]
176 | dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
177 | dist.all_gather_object(gather_ssim_val, ssim_val_rgb)
178 |
179 | if rank == 0:
180 | gather_psnr_val = list(itertools.chain(*gather_psnr_val))
181 | gather_ssim_val = list(itertools.chain(*gather_ssim_val))
182 | psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val)
183 | ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val)
184 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb))
185 |
186 | result_file = f"{sample_folder_dir}_results.txt"
187 | print("writing results to {}".format(result_file))
188 | with open(result_file, 'w') as f:
189 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f)
190 |
191 | create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
192 | print("Done.")
193 |
194 | dist.barrier()
195 | dist.destroy_process_group()
196 |
197 |
198 | if __name__ == "__main__":
199 | parser = argparse.ArgumentParser()
200 | parser.add_argument("--data-path", type=str, required=True)
201 | parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet')
202 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
203 | parser.add_argument("--sample-dir", type=str, default="reconstructions")
204 | parser.add_argument("--per-proc-batch-size", type=int, default=32)
205 | parser.add_argument("--global-seed", type=int, default=0)
206 | parser.add_argument("--num-workers", type=int, default=4)
207 | args = parser.parse_args()
208 | main(args)
--------------------------------------------------------------------------------
/tokenizer/tokenizer_image/cache/vgg.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FoundationVision/LlamaGen/ce98ec41803a74a90ce68c40ababa9eaeffeb4ec/tokenizer/tokenizer_image/cache/vgg.pth
--------------------------------------------------------------------------------
/tokenizer/tokenizer_image/discriminator_patchgan.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # taming-transformers: https://github.com/CompVis/taming-transformers
3 | import functools
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | class NLayerDiscriminator(nn.Module):
9 | """Defines a PatchGAN discriminator as in Pix2Pix
10 | --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
11 | """
12 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
13 | """Construct a PatchGAN discriminator
14 | Parameters:
15 | input_nc (int) -- the number of channels in input images
16 | ndf (int) -- the number of filters in the last conv layer
17 | n_layers (int) -- the number of conv layers in the discriminator
18 | norm_layer -- normalization layer
19 | """
20 | super(NLayerDiscriminator, self).__init__()
21 | if not use_actnorm:
22 | norm_layer = nn.BatchNorm2d
23 | else:
24 | norm_layer = ActNorm
25 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
26 | use_bias = norm_layer.func != nn.BatchNorm2d
27 | else:
28 | use_bias = norm_layer != nn.BatchNorm2d
29 |
30 | kw = 4
31 | padw = 1
32 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
33 | nf_mult = 1
34 | nf_mult_prev = 1
35 | for n in range(1, n_layers): # gradually increase the number of filters
36 | nf_mult_prev = nf_mult
37 | nf_mult = min(2 ** n, 8)
38 | sequence += [
39 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
40 | norm_layer(ndf * nf_mult),
41 | nn.LeakyReLU(0.2, True)
42 | ]
43 |
44 | nf_mult_prev = nf_mult
45 | nf_mult = min(2 ** n_layers, 8)
46 | sequence += [
47 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
48 | norm_layer(ndf * nf_mult),
49 | nn.LeakyReLU(0.2, True)
50 | ]
51 |
52 | sequence += [
53 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
54 | self.main = nn.Sequential(*sequence)
55 |
56 | self.apply(self._init_weights)
57 |
58 | def _init_weights(self, module):
59 | if isinstance(module, nn.Conv2d):
60 | nn.init.normal_(module.weight.data, 0.0, 0.02)
61 | elif isinstance(module, nn.BatchNorm2d):
62 | nn.init.normal_(module.weight.data, 1.0, 0.02)
63 | nn.init.constant_(module.bias.data, 0)
64 |
65 | def forward(self, input):
66 | """Standard forward."""
67 | return self.main(input)
68 |
69 |
70 | class ActNorm(nn.Module):
71 | def __init__(self, num_features, logdet=False, affine=True,
72 | allow_reverse_init=False):
73 | assert affine
74 | super().__init__()
75 | self.logdet = logdet
76 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
77 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
78 | self.allow_reverse_init = allow_reverse_init
79 |
80 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
81 |
82 | def initialize(self, input):
83 | with torch.no_grad():
84 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
85 | mean = (
86 | flatten.mean(1)
87 | .unsqueeze(1)
88 | .unsqueeze(2)
89 | .unsqueeze(3)
90 | .permute(1, 0, 2, 3)
91 | )
92 | std = (
93 | flatten.std(1)
94 | .unsqueeze(1)
95 | .unsqueeze(2)
96 | .unsqueeze(3)
97 | .permute(1, 0, 2, 3)
98 | )
99 |
100 | self.loc.data.copy_(-mean)
101 | self.scale.data.copy_(1 / (std + 1e-6))
102 |
103 | def forward(self, input, reverse=False):
104 | if reverse:
105 | return self.reverse(input)
106 | if len(input.shape) == 2:
107 | input = input[:,:,None,None]
108 | squeeze = True
109 | else:
110 | squeeze = False
111 |
112 | _, _, height, width = input.shape
113 |
114 | if self.training and self.initialized.item() == 0:
115 | self.initialize(input)
116 | self.initialized.fill_(1)
117 |
118 | h = self.scale * (input + self.loc)
119 |
120 | if squeeze:
121 | h = h.squeeze(-1).squeeze(-1)
122 |
123 | if self.logdet:
124 | log_abs = torch.log(torch.abs(self.scale))
125 | logdet = height*width*torch.sum(log_abs)
126 | logdet = logdet * torch.ones(input.shape[0]).to(input)
127 | return h, logdet
128 |
129 | return h
130 |
131 | def reverse(self, output):
132 | if self.training and self.initialized.item() == 0:
133 | if not self.allow_reverse_init:
134 | raise RuntimeError(
135 | "Initializing ActNorm in reverse direction is "
136 | "disabled by default. Use allow_reverse_init=True to enable."
137 | )
138 | else:
139 | self.initialize(output)
140 | self.initialized.fill_(1)
141 |
142 | if len(output.shape) == 2:
143 | output = output[:,:,None,None]
144 | squeeze = True
145 | else:
146 | squeeze = False
147 |
148 | h = output / self.scale - self.loc
149 |
150 | if squeeze:
151 | h = h.squeeze(-1).squeeze(-1)
152 | return h
--------------------------------------------------------------------------------
/tokenizer/tokenizer_image/discriminator_stylegan.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # stylegan2-pytorch: https://github.com/lucidrains/stylegan2-pytorch/blob/master/stylegan2_pytorch/stylegan2_pytorch.py
3 | # stylegan2-pytorch: https://github.com/rosinality/stylegan2-pytorch/blob/master/model.py
4 | # maskgit: https://github.com/google-research/maskgit/blob/main/maskgit/nets/discriminator.py
5 | import math
6 | import torch
7 | import torch.nn as nn
8 | try:
9 | from kornia.filters import filter2d
10 | except ImportError:  # kornia is optional; only the Blur layer below needs it
11 | pass
12 |
13 | class Discriminator(nn.Module):
14 | def __init__(self, input_nc=3, ndf=64, n_layers=3, channel_multiplier=1, image_size=256):
15 | super().__init__()
16 | channels = {
17 | 4: 512,
18 | 8: 512,
19 | 16: 512,
20 | 32: 512,
21 | 64: 256 * channel_multiplier,
22 | 128: 128 * channel_multiplier,
23 | 256: 64 * channel_multiplier,
24 | 512: 32 * channel_multiplier,
25 | 1024: 16 * channel_multiplier,
26 | }
27 |
28 | log_size = int(math.log(image_size, 2))
29 | in_channel = channels[image_size]
30 |
31 | blocks = [nn.Conv2d(input_nc, in_channel, 3, padding=1), leaky_relu()]
32 | for i in range(log_size, 2, -1):
33 | out_channel = channels[2 ** (i - 1)]
34 | blocks.append(DiscriminatorBlock(in_channel, out_channel))
35 | in_channel = out_channel
36 | self.blocks = nn.ModuleList(blocks)
37 |
38 | self.final_conv = nn.Sequential(
39 | nn.Conv2d(in_channel, channels[4], 3, padding=1),
40 | leaky_relu(),
41 | )
42 | self.final_linear = nn.Sequential(
43 | nn.Linear(channels[4] * 4 * 4, channels[4]),
44 | leaky_relu(),
45 | nn.Linear(channels[4], 1)
46 | )
47 |
48 | def forward(self, x):
49 | for block in self.blocks:
50 | x = block(x)
51 | x = self.final_conv(x)
52 | x = x.view(x.shape[0], -1)
53 | x = self.final_linear(x)
54 | return x
55 |
56 |
57 | class DiscriminatorBlock(nn.Module):
58 | def __init__(self, input_channels, filters, downsample=True):
59 | super().__init__()
60 | self.conv_res = nn.Conv2d(input_channels, filters, 1, stride = (2 if downsample else 1))
61 |
62 | self.net = nn.Sequential(
63 | nn.Conv2d(input_channels, filters, 3, padding=1),
64 | leaky_relu(),
65 | nn.Conv2d(filters, filters, 3, padding=1),
66 | leaky_relu()
67 | )
68 |
69 | self.downsample = nn.Sequential(
70 | Blur(),
71 | nn.Conv2d(filters, filters, 3, padding = 1, stride = 2)
72 | ) if downsample else None
73 |
74 | def forward(self, x):
75 | res = self.conv_res(x)
76 | x = self.net(x)
77 | if exists(self.downsample):
78 | x = self.downsample(x)
79 | x = (x + res) * (1 / math.sqrt(2))
80 | return x
81 |
82 |
83 |
84 | class Blur(nn.Module):
85 | def __init__(self):
86 | super().__init__()
87 | f = torch.Tensor([1, 2, 1])
88 | self.register_buffer('f', f)
89 |
90 | def forward(self, x):
91 | f = self.f
92 | f = f[None, None, :] * f[None, :, None]
93 | return filter2d(x, f, normalized=True)
94 |
95 |
96 | def leaky_relu(p=0.2):
97 | return nn.LeakyReLU(p, inplace=True)
98 |
99 |
100 | def exists(val):
101 | return val is not None
102 |
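103 | # Note: image_size must be one of the keys of `channels` (a power of two up to 1024);
104 | # each DiscriminatorBlock halves the resolution, so features reach 4x4 before the final linear head.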
--------------------------------------------------------------------------------
/tokenizer/tokenizer_image/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import os, hashlib
4 | import requests
5 | from tqdm import tqdm
6 |
7 | import torch
8 | import torch.nn as nn
9 | from torchvision import models
10 | from collections import namedtuple
11 |
12 | URL_MAP = {
13 | "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
14 | }
15 |
16 | CKPT_MAP = {
17 | "vgg_lpips": "vgg.pth"
18 | }
19 |
20 | MD5_MAP = {
21 | "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
22 | }
23 |
24 | def download(url, local_path, chunk_size=1024):
25 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
26 | with requests.get(url, stream=True) as r:
27 | total_size = int(r.headers.get("content-length", 0))
28 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
29 | with open(local_path, "wb") as f:
30 | for data in r.iter_content(chunk_size=chunk_size):
31 | if data:
32 | f.write(data)
33 | pbar.update(len(data))
34 |
35 |
36 | def md5_hash(path):
37 | with open(path, "rb") as f:
38 | content = f.read()
39 | return hashlib.md5(content).hexdigest()
40 |
41 |
42 | def get_ckpt_path(name, root, check=False):
43 | assert name in URL_MAP
44 | path = os.path.join(root, CKPT_MAP[name])
45 | if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
46 | print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
47 | download(URL_MAP[name], path)
48 | md5 = md5_hash(path)
49 | assert md5 == MD5_MAP[name], md5
50 | return path
51 |
52 |
53 | class LPIPS(nn.Module):
54 | # Learned perceptual metric
55 | def __init__(self, use_dropout=True):
56 | super().__init__()
57 | self.scaling_layer = ScalingLayer()
58 | self.chns = [64, 128, 256, 512, 512]  # vgg16 feature channels
59 | self.net = vgg16(pretrained=True, requires_grad=False)
60 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
61 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
62 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
63 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
64 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
65 | self.load_from_pretrained()
66 | for param in self.parameters():
67 | param.requires_grad = False
68 |
69 | def load_from_pretrained(self, name="vgg_lpips"):
70 | ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
71 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
72 | print("loaded pretrained LPIPS loss from {}".format(ckpt))
73 |
74 | @classmethod
75 | def from_pretrained(cls, name="vgg_lpips"):
76 | if name != "vgg_lpips":
77 | raise NotImplementedError
78 | model = cls()
79 | ckpt = get_ckpt_path(name, os.path.join(os.path.dirname(os.path.abspath(__file__)), "cache"))
80 | model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
81 | return model
82 |
83 | def forward(self, input, target):
84 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
85 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
86 | feats0, feats1, diffs = {}, {}, {}
87 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
88 | for kk in range(len(self.chns)):
89 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
90 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
91 |
92 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
93 | val = res[0]
94 | for l in range(1, len(self.chns)):
95 | val += res[l]
96 | return val
97 |
98 |
99 | class ScalingLayer(nn.Module):
100 | def __init__(self):
101 | super(ScalingLayer, self).__init__()
102 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
103 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
104 |
105 | def forward(self, inp):
106 | return (inp - self.shift) / self.scale
107 |
108 |
109 | class NetLinLayer(nn.Module):
110 | """ A single linear layer which does a 1x1 conv """
111 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
112 | super(NetLinLayer, self).__init__()
113 | layers = [nn.Dropout(), ] if (use_dropout) else []
114 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
115 | self.model = nn.Sequential(*layers)
116 |
117 |
118 | class vgg16(torch.nn.Module):
119 | def __init__(self, requires_grad=False, pretrained=True):
120 | super(vgg16, self).__init__()
121 | vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
122 | self.slice1 = torch.nn.Sequential()
123 | self.slice2 = torch.nn.Sequential()
124 | self.slice3 = torch.nn.Sequential()
125 | self.slice4 = torch.nn.Sequential()
126 | self.slice5 = torch.nn.Sequential()
127 | self.N_slices = 5
128 | for x in range(4):
129 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
130 | for x in range(4, 9):
131 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
132 | for x in range(9, 16):
133 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
134 | for x in range(16, 23):
135 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
136 | for x in range(23, 30):
137 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
138 | if not requires_grad:
139 | for param in self.parameters():
140 | param.requires_grad = False
141 |
142 | def forward(self, X):
143 | h = self.slice1(X)
144 | h_relu1_2 = h
145 | h = self.slice2(h)
146 | h_relu2_2 = h
147 | h = self.slice3(h)
148 | h_relu3_3 = h
149 | h = self.slice4(h)
150 | h_relu4_3 = h
151 | h = self.slice5(h)
152 | h_relu5_3 = h
153 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
154 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
155 | return out
156 |
157 |
158 | def normalize_tensor(x,eps=1e-10):
159 | norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
160 | return x/(norm_factor+eps)
161 |
162 |
163 | def spatial_average(x, keepdim=True):
164 | return x.mean([2,3],keepdim=keepdim)
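165 |
166 | # Usage sketch (inputs are NCHW tensors scaled to [-1, 1]; variable names are illustrative):
167 | #   lpips = LPIPS().eval()
168 | #   d = lpips(img_a, img_b)  # shape (N, 1, 1, 1); lower means perceptually closer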
--------------------------------------------------------------------------------
/tokenizer/tokenizer_image/reconstruction_vq_ddp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.backends.cuda.matmul.allow_tf32 = True
3 | torch.backends.cudnn.allow_tf32 = True
4 | import torch.nn.functional as F
5 | import torch.distributed as dist
6 | from torch.utils.data import DataLoader
7 | from torch.utils.data.distributed import DistributedSampler
8 | from torchvision import transforms
9 | from tqdm import tqdm
10 | import os
11 | from PIL import Image
12 | import numpy as np
13 | import argparse
14 | import itertools
15 |
16 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss
17 | from skimage.metrics import structural_similarity as ssim_loss
18 |
19 | from dataset.augmentation import center_crop_arr
20 | from dataset.build import build_dataset
21 | from tokenizer.tokenizer_image.vq_model import VQ_models
22 |
23 |
24 |
25 | def create_npz_from_sample_folder(sample_dir, num=50000):
26 | """
27 | Builds a single .npz file from a folder of .png samples.
28 | """
29 | samples = []
30 | for i in tqdm(range(num), desc="Building .npz file from samples"):
31 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
32 | sample_np = np.asarray(sample_pil).astype(np.uint8)
33 | samples.append(sample_np)
34 | samples = np.stack(samples)
35 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
36 | npz_path = f"{sample_dir}.npz"
37 | np.savez(npz_path, arr_0=samples)
38 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
39 | return npz_path
40 |
41 |
42 |
43 | def main(args):
44 | # Setup PyTorch:
45 | assert torch.cuda.is_available(), "This script uses DDP (NCCL) and requires at least one GPU."
46 | torch.set_grad_enabled(False)
47 |
48 | # Setup DDP:
49 | dist.init_process_group("nccl")
50 | rank = dist.get_rank()
51 | device = rank % torch.cuda.device_count()
52 | seed = args.global_seed * dist.get_world_size() + rank
53 | torch.manual_seed(seed)
54 | torch.cuda.set_device(device)
55 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
56 |
57 | # create and load model
58 | vq_model = VQ_models[args.vq_model](
59 | codebook_size=args.codebook_size,
60 | codebook_embed_dim=args.codebook_embed_dim)
61 | vq_model.to(device)
62 | vq_model.eval()
63 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
64 | if "ema" in checkpoint: # ema
65 | model_weight = checkpoint["ema"]
66 | elif "model" in checkpoint: # ddp
67 | model_weight = checkpoint["model"]
68 | elif "state_dict" in checkpoint:
69 | model_weight = checkpoint["state_dict"]
70 | else:
71 | raise Exception("please check model weight")
72 | vq_model.load_state_dict(model_weight)
73 | del checkpoint
74 |
75 | # Create folder to save samples:
76 | folder_name = (f"{args.vq_model}-{args.dataset}-size-{args.image_size}-size-{args.image_size_eval}"
77 | f"-codebook-size-{args.codebook_size}-dim-{args.codebook_embed_dim}-seed-{args.global_seed}")
78 | sample_folder_dir = f"{args.sample_dir}/{folder_name}"
79 | if rank == 0:
80 | os.makedirs(sample_folder_dir, exist_ok=True)
81 | print(f"Saving .png samples at {sample_folder_dir}")
82 | dist.barrier()
83 |
84 | # Setup data:
85 | transform = transforms.Compose([
86 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
87 | transforms.ToTensor(),
88 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
89 | ])
90 |
91 | if args.dataset == 'imagenet':
92 | dataset = build_dataset(args, transform=transform)
93 | num_fid_samples = 50000
94 | elif args.dataset == 'coco':
95 | dataset = build_dataset(args, transform=transform)
96 | num_fid_samples = 5000
97 | else:
98 | raise Exception("please check dataset")
99 |
100 | sampler = DistributedSampler(
101 | dataset,
102 | num_replicas=dist.get_world_size(),
103 | rank=rank,
104 | shuffle=False,
105 | seed=args.global_seed
106 | )
107 | loader = DataLoader(
108 | dataset,
109 | batch_size=args.per_proc_batch_size,
110 | shuffle=False,
111 | sampler=sampler,
112 | num_workers=args.num_workers,
113 | pin_memory=True,
114 | drop_last=False
115 | )
116 |
117 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
118 | n = args.per_proc_batch_size
119 | global_batch_size = n * dist.get_world_size()
120 |
121 | psnr_val_rgb = []
122 | ssim_val_rgb = []
123 | loader = tqdm(loader) if rank == 0 else loader
124 | total = 0
125 | for x, _ in loader:
126 | if args.image_size_eval != args.image_size:
127 | rgb_gts = F.interpolate(x, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
128 | else:
129 | rgb_gts = x
130 | rgb_gts = (rgb_gts.permute(0, 2, 3, 1).to("cpu").numpy() + 1.0) / 2.0 # rgb_gt value is between [0, 1]
131 | x = x.to(device, non_blocking=True)
132 | with torch.no_grad():
133 | latent, _, [_, _, indices] = vq_model.encode(x)
134 | samples = vq_model.decode_code(indices, latent.shape) # output value is between [-1, 1]
135 | if args.image_size_eval != args.image_size:
136 | samples = F.interpolate(samples, size=(args.image_size_eval, args.image_size_eval), mode='bicubic')
137 | samples = torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
138 |
139 | # Save samples to disk as individual .png files
140 | for i, (sample, rgb_gt) in enumerate(zip(samples, rgb_gts)):
141 | index = i * dist.get_world_size() + rank + total
142 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
143 | # metric
144 | rgb_restored = sample.astype(np.float32) / 255. # rgb_restored value is between [0, 1]
145 | psnr = psnr_loss(rgb_restored, rgb_gt)
146 | ssim = ssim_loss(rgb_restored, rgb_gt, multichannel=True, data_range=2.0, channel_axis=-1)
147 | psnr_val_rgb.append(psnr)
148 | ssim_val_rgb.append(ssim)
149 |
150 | total += global_batch_size
151 |
152 | # ------------------------------------
153 | # Summary
154 | # ------------------------------------
155 | # Make sure all processes have finished saving their samples
156 | dist.barrier()
157 | world_size = dist.get_world_size()
158 | gather_psnr_val = [None for _ in range(world_size)]
159 | gather_ssim_val = [None for _ in range(world_size)]
160 | dist.all_gather_object(gather_psnr_val, psnr_val_rgb)
161 | dist.all_gather_object(gather_ssim_val, ssim_val_rgb)
162 |
163 | if rank == 0:
164 | gather_psnr_val = list(itertools.chain(*gather_psnr_val))
165 | gather_ssim_val = list(itertools.chain(*gather_ssim_val))
166 | psnr_val_rgb = sum(gather_psnr_val) / len(gather_psnr_val)
167 | ssim_val_rgb = sum(gather_ssim_val) / len(gather_ssim_val)
168 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb))
169 |
170 | result_file = f"{sample_folder_dir}_results.txt"
171 | print("writing results to {}".format(result_file))
172 | with open(result_file, 'w') as f:
173 | print("PSNR: %f, SSIM: %f " % (psnr_val_rgb, ssim_val_rgb), file=f)
174 |
175 | create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
176 | print("Done.")
177 |
178 | dist.barrier()
179 | dist.destroy_process_group()
180 |
181 |
182 | if __name__ == "__main__":
183 | parser = argparse.ArgumentParser()
184 | parser.add_argument("--data-path", type=str, required=True)
185 | parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet')
186 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
187 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
188 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
189 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
190 | parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=256)
191 | parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
192 | parser.add_argument("--sample-dir", type=str, default="reconstructions")
193 | parser.add_argument("--per-proc-batch-size", type=int, default=32)
194 | parser.add_argument("--global-seed", type=int, default=0)
195 | parser.add_argument("--num-workers", type=int, default=4)
196 | args = parser.parse_args()
197 | main(args)
--------------------------------------------------------------------------------
/tokenizer/tokenizer_image/vq_demo.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import os
5 | import argparse
6 | import numpy as np
7 | from PIL import Image
8 |
9 | from tokenizer.tokenizer_image.vq_model import VQ_models
10 | from dataset.augmentation import center_crop_arr
11 |
12 |
13 | def main(args):
14 | # Setup PyTorch:
15 | torch.manual_seed(args.seed)
16 | torch.set_grad_enabled(False)
17 | device = "cuda" if torch.cuda.is_available() else "cpu"
18 |
19 | # create and load model
20 | model = VQ_models[args.vq_model](
21 | codebook_size=args.codebook_size,
22 | codebook_embed_dim=args.codebook_embed_dim)
23 | model.to(device)
24 | model.eval()
25 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
26 | if "ema" in checkpoint: # ema
27 | model_weight = checkpoint["ema"]
28 | elif "model" in checkpoint: # ddp
29 | model_weight = checkpoint["model"]
30 | elif "state_dict" in checkpoint:
31 | model_weight = checkpoint["state_dict"]
32 | else:
33 | raise Exception("please check model weight")
34 | model.load_state_dict(model_weight)
35 | del checkpoint
36 |
37 | # output dir
38 | os.makedirs(args.output_dir, exist_ok=True)
39 | out_path = args.image_path.replace('.jpg', '_{}.jpg'.format(args.suffix))
40 | out_path = out_path.replace('.jpeg', '_{}.jpeg'.format(args.suffix))
41 | out_path = out_path.replace('.png', '_{}.png'.format(args.suffix))
42 | out_filename = out_path.split('/')[-1]
43 | out_path = os.path.join(args.output_dir, out_filename)
44 |
45 | # load image
46 | pil_image = Image.open(args.image_path).convert("RGB")
47 | img = center_crop_arr(pil_image, args.image_size)
48 | # # preprocess
49 | # size_org = img.size
50 | # img = img.resize((input_size, input_size))
51 | img = np.array(img) / 255.
52 | x = 2.0 * img - 1.0 # x value is between [-1, 1]
53 | x = torch.tensor(x)
54 | x = x.unsqueeze(dim=0)
55 | x = torch.einsum('nhwc->nchw', x)
56 | x_input = x.float().to("cuda")
57 |
58 | # inference
59 | with torch.no_grad():
60 | latent, _, [_, _, indices] = model.encode(x_input)
61 | output = model.decode_code(indices, latent.shape) # output value is between [-1, 1]
62 |
63 | # postprocess
64 | output = F.interpolate(output, size=[args.image_size, args.image_size], mode='bicubic').permute(0, 2, 3, 1)[0]
65 | sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
66 |
67 | # save
68 | Image.fromarray(sample).save(out_path)
69 | print("Reconstructed image is saved to {}".format(out_path))
70 |
71 |
72 | if __name__ == "__main__":
73 | parser = argparse.ArgumentParser()
74 | parser.add_argument("--image-path", type=str, default="assets/example.jpg")
75 | parser.add_argument("--output-dir", type=str, default="output_vq_demo")
76 | parser.add_argument("--suffix", type=str, default="tokenizer_image")
77 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
78 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
79 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
80 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
81 | parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512, 1024], default=512)
82 | parser.add_argument("--seed", type=int, default=0)
83 | args = parser.parse_args()
84 | main(args)
--------------------------------------------------------------------------------
/tokenizer/tokenizer_image/vq_loss.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # taming-transformers: https://github.com/CompVis/taming-transformers
3 | # muse-maskgit-pytorch: https://github.com/lucidrains/muse-maskgit-pytorch/blob/main/muse_maskgit_pytorch/vqgan_vae.py
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from tokenizer.tokenizer_image.lpips import LPIPS
9 | from tokenizer.tokenizer_image.discriminator_patchgan import NLayerDiscriminator as PatchGANDiscriminator
10 | from tokenizer.tokenizer_image.discriminator_stylegan import Discriminator as StyleGANDiscriminator
11 |
12 |
13 |
14 | def hinge_d_loss(logits_real, logits_fake):
15 | loss_real = torch.mean(F.relu(1. - logits_real))
16 | loss_fake = torch.mean(F.relu(1. + logits_fake))
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 |
21 | def vanilla_d_loss(logits_real, logits_fake):
22 | loss_real = torch.mean(F.softplus(-logits_real))
23 | loss_fake = torch.mean(F.softplus(logits_fake))
24 | d_loss = 0.5 * (loss_real + loss_fake)
25 | return d_loss
26 |
27 |
28 | def non_saturating_d_loss(logits_real, logits_fake):
29 | loss_real = torch.mean(F.binary_cross_entropy_with_logits(logits_real, torch.ones_like(logits_real)))
30 | loss_fake = torch.mean(F.binary_cross_entropy_with_logits(logits_fake, torch.zeros_like(logits_fake)))
31 | d_loss = 0.5 * (loss_real + loss_fake)
32 | return d_loss
33 |
34 |
35 | def hinge_gen_loss(logit_fake):
36 | return -torch.mean(logit_fake)
37 |
38 |
39 | def non_saturating_gen_loss(logit_fake):
40 | return torch.mean(F.binary_cross_entropy_with_logits(logit_fake, torch.ones_like(logit_fake)))
41 |
42 |
43 | def adopt_weight(weight, global_step, threshold=0, value=0.):
44 | if global_step < threshold:
45 | weight = value
46 | return weight
47 |
48 |
49 | class VQLoss(nn.Module):
50 | def __init__(self, disc_start, disc_loss="hinge", disc_dim=64, disc_type='patchgan', image_size=256,
51 | disc_num_layers=3, disc_in_channels=3, disc_weight=1.0, disc_adaptive_weight = False,
52 | gen_adv_loss='hinge', reconstruction_loss='l2', reconstruction_weight=1.0,
53 | codebook_weight=1.0, perceptual_weight=1.0,
54 | ):
55 | super().__init__()
56 | # discriminator loss
57 | assert disc_type in ["patchgan", "stylegan"]
58 | assert disc_loss in ["hinge", "vanilla", "non-saturating"]
59 | if disc_type == "patchgan":
60 | self.discriminator = PatchGANDiscriminator(
61 | input_nc=disc_in_channels,
62 | n_layers=disc_num_layers,
63 | ndf=disc_dim,
64 | )
65 | elif disc_type == "stylegan":
66 | self.discriminator = StyleGANDiscriminator(
67 | input_nc=disc_in_channels,
68 | image_size=image_size,
69 | )
70 | else:
71 | raise ValueError(f"Unknown GAN discriminator type '{disc_type}'.")
72 | if disc_loss == "hinge":
73 | self.disc_loss = hinge_d_loss
74 | elif disc_loss == "vanilla":
75 | self.disc_loss = vanilla_d_loss
76 | elif disc_loss == "non-saturating":
77 | self.disc_loss = non_saturating_d_loss
78 | else:
79 | raise ValueError(f"Unknown GAN discriminator loss '{disc_loss}'.")
80 | self.discriminator_iter_start = disc_start
81 | self.disc_weight = disc_weight
82 | self.disc_adaptive_weight = disc_adaptive_weight
83 |
84 | assert gen_adv_loss in ["hinge", "non-saturating"]
85 | # gen_adv_loss
86 | if gen_adv_loss == "hinge":
87 | self.gen_adv_loss = hinge_gen_loss
88 | elif gen_adv_loss == "non-saturating":
89 | self.gen_adv_loss = non_saturating_gen_loss
90 | else:
91 | raise ValueError(f"Unknown GAN generator loss '{gen_adv_loss}'.")
92 |
93 | # perceptual loss
94 | self.perceptual_loss = LPIPS().eval()
95 | self.perceptual_weight = perceptual_weight
96 |
97 | # reconstruction loss
98 | if reconstruction_loss == "l1":
99 | self.rec_loss = F.l1_loss
100 | elif reconstruction_loss == "l2":
101 | self.rec_loss = F.mse_loss
102 | else:
103 | raise ValueError(f"Unknown rec loss '{reconstruction_loss}'.")
104 | self.rec_weight = reconstruction_weight
105 |
106 | # codebook loss
107 | self.codebook_weight = codebook_weight
108 |
109 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer):
110 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
111 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
112 |
113 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
114 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
115 | return d_weight.detach()
116 |
117 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, global_step, last_layer=None,
118 | logger=None, log_every=100):
119 | # generator update
120 | if optimizer_idx == 0:
121 | # reconstruction loss
122 | rec_loss = self.rec_loss(inputs.contiguous(), reconstructions.contiguous())
123 |
124 | # perceptual loss
125 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
126 | p_loss = torch.mean(p_loss)
127 |
128 | # discriminator loss
129 | logits_fake = self.discriminator(reconstructions.contiguous())
130 | generator_adv_loss = self.gen_adv_loss(logits_fake)
131 |
132 | if self.disc_adaptive_weight:
133 | nll_loss = self.rec_weight * rec_loss + self.perceptual_weight * p_loss
134 | disc_adaptive_weight = self.calculate_adaptive_weight(nll_loss, generator_adv_loss, last_layer=last_layer)
135 | else:
136 | disc_adaptive_weight = 1
137 | disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
138 |
139 | loss = self.rec_weight * rec_loss + \
140 | self.perceptual_weight * p_loss + \
141 | disc_adaptive_weight * disc_weight * generator_adv_loss + \
142 | codebook_loss[0] + codebook_loss[1] + codebook_loss[2]
143 |
144 | if global_step % log_every == 0:
145 | rec_loss = self.rec_weight * rec_loss
146 | p_loss = self.perceptual_weight * p_loss
147 | generator_adv_loss = disc_adaptive_weight * disc_weight * generator_adv_loss
148 | logger.info(f"(Generator) rec_loss: {rec_loss:.4f}, perceptual_loss: {p_loss:.4f}, "
149 | f"vq_loss: {codebook_loss[0]:.4f}, commit_loss: {codebook_loss[1]:.4f}, entropy_loss: {codebook_loss[2]:.4f}, "
150 | f"codebook_usage: {codebook_loss[3]:.4f}, generator_adv_loss: {generator_adv_loss:.4f}, "
151 | f"disc_adaptive_weight: {disc_adaptive_weight:.4f}, disc_weight: {disc_weight:.4f}")
152 | return loss
153 |
154 | # discriminator update
155 | if optimizer_idx == 1:
156 | logits_real = self.discriminator(inputs.contiguous().detach())
157 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
158 |
159 | disc_weight = adopt_weight(self.disc_weight, global_step, threshold=self.discriminator_iter_start)
160 | d_adversarial_loss = disc_weight * self.disc_loss(logits_real, logits_fake)
161 |
162 | if global_step % log_every == 0:
163 | logits_real = logits_real.detach().mean()
164 | logits_fake = logits_fake.detach().mean()
165 | logger.info(f"(Discriminator) "
166 | f"discriminator_adv_loss: {d_adversarial_loss:.4f}, disc_weight: {disc_weight:.4f}, "
167 | f"logits_real: {logits_real:.4f}, logits_fake: {logits_fake:.4f}")
168 | return d_adversarial_loss
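169 |
170 | # Usage sketch (generator pass then discriminator pass each step; `step`, `images`, `recons`
171 | # and `last_decoder_layer` are illustrative names, not defined in this file):
172 | #   gen_loss = vq_loss(codebook_loss, images, recons, optimizer_idx=0, global_step=step,
173 | #                      last_layer=last_decoder_layer, logger=logger)
174 | #   disc_loss = vq_loss(codebook_loss, images, recons, optimizer_idx=1, global_step=step, logger=logger)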
--------------------------------------------------------------------------------
/tokenizer/tokenizer_image/vq_model_hf.py:
--------------------------------------------------------------------------------
1 | from huggingface_hub import PyTorchModelHubMixin
2 |
3 | from tokenizer.tokenizer_image.vq_model import ModelArgs, VQModel
4 |
5 | class VQModelHF(VQModel, PyTorchModelHubMixin, repo_url="https://github.com/FoundationVision/LlamaGen", license="mit", tags=["llamagen", "text-to-image"]):
6 | pass
7 |
8 | #################################################################################
9 | # VQ Model Configs #
10 | #################################################################################
11 | def VQ_8(**kwargs):
12 | return VQModelHF(ModelArgs(encoder_ch_mult=[1, 2, 2, 4], decoder_ch_mult=[1, 2, 2, 4], **kwargs))
13 |
14 | def VQ_16(**kwargs):
15 | return VQModelHF(ModelArgs(encoder_ch_mult=[1, 1, 2, 2, 4], decoder_ch_mult=[1, 1, 2, 2, 4], **kwargs))
16 |
17 | VQ_models_HF = {'VQ-16': VQ_16, 'VQ-8': VQ_8}
18 |
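19 | # Usage sketch (PyTorchModelHubMixin provides from_pretrained / save_pretrained / push_to_hub;
20 | # the repo id below is a placeholder):
21 | #   model = VQModelHF.from_pretrained("<hub-repo-id>")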
--------------------------------------------------------------------------------
/tokenizer/vae/README.md:
--------------------------------------------------------------------------------
1 | ## VAE Models from Stable Diffusion
2 |
3 | ### install
4 | ```
5 | pip install diffusers
6 | pip install accelerate
7 | ```
8 |
9 | ### demo
10 | ```
11 | cd ${THIS_REPO_ROOT}
12 | python3 tokenizer/vae/sd_vae_demo.py
13 | ```
14 |
15 |
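16 | ### reconstruction
17 | A minimal sketch of computing reconstruction metrics with the DDP wrapper; the flags are assumed to mirror the consistency-decoder script (`--data-path`, `--dataset`) and the paths are placeholders:
18 | ```
19 | bash scripts/tokenizer/reconstruction_vae.sh \
20 |     --data-path /path/to/imagenet/val --dataset imagenet
21 | ```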
--------------------------------------------------------------------------------
/tokenizer/vae/sd_vae_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from PIL import Image
6 | from diffusers.models import AutoencoderKL
7 |
8 |
9 | def main(args):
10 | # Setup PyTorch:
11 | torch.manual_seed(args.seed)
12 | torch.set_grad_enabled(False)
13 | device = "cuda" if torch.cuda.is_available() else "cpu"
14 |
15 | # create and load model
16 | vae = AutoencoderKL.from_pretrained(f"stabilityai/{args.vae}").to(device)
17 |
18 | # load image
19 | img_path = args.image_path
20 | out_path = args.image_path.replace('.jpg', '_vae.jpg').replace('.jpeg', '_vae.jpeg').replace('.png', '_vae.png')
21 | input_size = args.image_size
22 | img = Image.open(img_path).convert("RGB")
23 |
24 | # preprocess
25 | size_org = img.size
26 | img = img.resize((input_size, input_size))
27 | img = np.array(img) / 255.
28 | x = 2.0 * img - 1.0 # x value is between [-1, 1]
29 | x = torch.tensor(x)
30 | x = x.unsqueeze(dim=0)
31 | x = torch.einsum('nhwc->nchw', x)
32 | x_input = x.float().to("cuda")
33 |
34 | # inference
35 | with torch.no_grad():
36 | # Map input images to latent space + normalize latents:
37 | latent = vae.encode(x_input).latent_dist.sample().mul_(0.18215)
38 | # reconstruct:
39 | output = vae.decode(latent / 0.18215).sample # output value is between [-1, 1]
40 |
41 | # postprocess
42 | output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0]
43 | sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
44 |
45 | # save
46 | Image.fromarray(sample).save(out_path)
47 | print("Reconstructed image is saved to {}".format(out_path))
48 |
49 |
50 | if __name__ == "__main__":
51 | parser = argparse.ArgumentParser()
52 | parser.add_argument("--image-path", type=str, default="assets/example.jpg")
53 | parser.add_argument("--vae", type=str, choices=["sdxl-vae", "sd-vae-ft-mse"], default="sd-vae-ft-mse")
54 | parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512)
55 | parser.add_argument("--seed", type=int, default=0)
56 | args = parser.parse_args()
57 | main(args)
--------------------------------------------------------------------------------
/tokenizer/validation/val_ddp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | torch.backends.cuda.matmul.allow_tf32 = True
3 | torch.backends.cudnn.allow_tf32 = True
4 | import torch.distributed as dist
5 | from torch.utils.data import Dataset, DataLoader
6 | from torch.utils.data.distributed import DistributedSampler
7 | from torchvision.datasets import ImageFolder
8 | from torchvision import transforms
9 | from tqdm import tqdm
10 | import os
11 | from PIL import Image
12 | import numpy as np
13 | import argparse
14 | import random
15 |
16 |
17 | class SingleFolderDataset(Dataset):
18 | def __init__(self, directory, transform=None):
19 | super().__init__()
20 | self.directory = directory
21 | self.transform = transform
22 | self.image_paths = [os.path.join(directory, file_name) for file_name in os.listdir(directory)
23 | if os.path.isfile(os.path.join(directory, file_name))]
24 |
25 | def __len__(self):
26 | return len(self.image_paths)
27 |
28 | def __getitem__(self, idx):
29 | image_path = self.image_paths[idx]
30 | image = Image.open(image_path).convert('RGB')
31 | if self.transform:
32 | image = self.transform(image)
33 | return image, torch.tensor(0)
34 |
35 |
36 | def create_npz_from_sample_folder(sample_dir, num=50_000):
37 | """
38 | Builds a single .npz file from a folder of .png samples.
39 | """
40 | samples = []
41 | for i in tqdm(range(num), desc="Building .npz file from samples"):
42 | sample_pil = Image.open(f"{sample_dir}/{i:06d}.png")
43 | sample_np = np.asarray(sample_pil).astype(np.uint8)
44 | samples.append(sample_np)
45 |
46 |     random.shuffle(samples)  # shuffling is very important for IS (Inception Score)!
47 | samples = np.stack(samples)
48 | assert samples.shape == (num, samples.shape[1], samples.shape[2], 3)
49 | npz_path = f"{sample_dir}.npz"
50 | np.savez(npz_path, arr_0=samples)
51 | print(f"Saved .npz file to {npz_path} [shape={samples.shape}].")
52 | return npz_path
53 |
54 |
55 | def center_crop_arr(pil_image, image_size):
56 | """
57 | Center cropping implementation from ADM.
58 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
59 | """
60 | while min(*pil_image.size) >= 2 * image_size:
61 | pil_image = pil_image.resize(
62 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
63 | )
64 |
65 | scale = image_size / min(*pil_image.size)
66 | pil_image = pil_image.resize(
67 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
68 | )
69 |
70 | arr = np.array(pil_image)
71 | crop_y = (arr.shape[0] - image_size) // 2
72 | crop_x = (arr.shape[1] - image_size) // 2
73 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
74 |
75 |
76 | def main(args):
77 | # Setup PyTorch:
78 |     assert torch.cuda.is_available(), "Reconstruction with DDP requires at least one GPU."
79 | torch.set_grad_enabled(False)
80 |
81 | # Setup env
82 | dist.init_process_group("nccl")
83 | rank = dist.get_rank()
84 | device = rank % torch.cuda.device_count()
85 | seed = args.global_seed * dist.get_world_size() + rank
86 | torch.manual_seed(seed)
87 | torch.cuda.set_device(device)
88 | print(f"Starting rank={rank}, seed={seed}, world_size={dist.get_world_size()}.")
89 |
90 | # Create folder to save samples:
91 | folder_name = f"val_{args.dataset}"
92 | sample_folder_dir = f"{args.sample_dir}/{folder_name}"
93 | if rank == 0:
94 | os.makedirs(sample_folder_dir, exist_ok=True)
95 | print(f"Saving .png samples at {sample_folder_dir}")
96 | dist.barrier()
97 |
98 | # Setup data:
99 | transform = transforms.Compose([
100 | transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
101 | transforms.ToTensor(),
102 | transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
103 | ])
104 |
105 | if args.dataset == 'imagenet':
106 | dataset = ImageFolder(args.data_path, transform=transform)
107 | num_fid_samples = 50000
108 | elif args.dataset == 'coco':
109 | dataset = SingleFolderDataset(args.data_path, transform=transform)
110 | num_fid_samples = 5000
111 | else:
112 | raise Exception("please check dataset")
113 |
114 | sampler = DistributedSampler(
115 | dataset,
116 | num_replicas=dist.get_world_size(),
117 | rank=rank,
118 | shuffle=False,
119 | seed=args.global_seed
120 | )
121 | loader = DataLoader(
122 | dataset,
123 | batch_size=args.per_proc_batch_size,
124 | shuffle=False,
125 | sampler=sampler,
126 | num_workers=args.num_workers,
127 | pin_memory=True,
128 | drop_last=False
129 | )
130 |
131 | # Figure out how many samples we need to generate on each GPU and how many iterations we need to run:
132 | n = args.per_proc_batch_size
133 | global_batch_size = n * dist.get_world_size()
134 |
135 | loader = tqdm(loader) if rank == 0 else loader
136 | total = 0
137 | for x, _ in loader:
138 | samples = torch.clamp(127.5 * x + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()
139 | # Save samples to disk as individual .png files
140 | for i, sample in enumerate(samples):
141 | index = i * dist.get_world_size() + rank + total
142 | Image.fromarray(sample).save(f"{sample_folder_dir}/{index:06d}.png")
143 |
144 | total += global_batch_size
145 |
146 | # Make sure all processes have finished saving their samples before attempting to convert to .npz
147 | dist.barrier()
148 | if rank == 0:
149 | create_npz_from_sample_folder(sample_folder_dir, num_fid_samples)
150 | print("Done.")
151 | dist.barrier()
152 | dist.destroy_process_group()
153 |
154 |
155 | if __name__ == "__main__":
156 | parser = argparse.ArgumentParser()
157 | parser.add_argument("--data-path", type=str, required=True)
158 | parser.add_argument("--dataset", type=str, choices=['imagenet', 'coco'], default='imagenet')
159 | parser.add_argument("--image-size", type=int, choices=[256, 512], default=256)
160 | parser.add_argument("--sample-dir", type=str, default="reconstructions")
161 | parser.add_argument("--per-proc-batch-size", type=int, default=32)
162 | parser.add_argument("--global-seed", type=int, default=0)
163 | parser.add_argument("--num-workers", type=int, default=4)
164 | args = parser.parse_args()
165 | main(args)
--------------------------------------------------------------------------------
/tokenizer/vqgan/README.md:
--------------------------------------------------------------------------------
1 | ## Pretrained VQGAN Models
2 |
3 | ### install
4 | ```
5 | pip install omegaconf
6 | pip install einops
7 | ```
8 | * download the required checkpoints from https://github.com/CompVis/taming-transformers and put them in `pretrained_models/`
9 | * `pip install pytorch_lightning`
10 | * `python3 tools/convert_pytorch_lightning_to_torch.py` (converts the Lightning checkpoints to plain PyTorch `.pth` files)
11 | * `pip uninstall pytorch_lightning`
12 |
13 | ### demo
14 | ```
15 | cd ${THIS_REPO_ROOT}
16 | python3 tokenizer/vqgan/taming_vqgan_demo.py
17 | ```
18 |
19 | ### acknowledgement
20 | Code in this folder is modified from https://github.com/CompVis/taming-transformers
21 |
22 |
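23 | ### load in python (sketch)
24 | A minimal loading sketch, mirroring `tokenizer/vqgan/taming_vqgan_demo.py`; run it from the repo root after the conversion step above:
25 | ```
26 | from omegaconf import OmegaConf
27 | from tokenizer.vqgan.model import VQModel, VQGAN_FROM_TAMING
28 |
29 | cfg, ckpt = VQGAN_FROM_TAMING["vqgan_openimage_f8_16384"]
30 | config = OmegaConf.load(cfg)
31 | model = VQModel(**config.model.get("params", dict()))
32 | model.init_from_ckpt(ckpt)
33 | model.eval()
34 | ```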
--------------------------------------------------------------------------------
/tokenizer/vqgan/configs/vqgan_imagenet_f16_1024.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: taming.models.vqgan.VQModel
4 | params:
5 | embed_dim: 256
6 | n_embed: 1024
7 | ddconfig:
8 | double_z: false
9 | z_channels: 256
10 | resolution: 256
11 | in_channels: 3
12 | out_ch: 3
13 | ch: 128
14 | ch_mult:
15 | - 1
16 | - 1
17 | - 2
18 | - 2
19 | - 4
20 | num_res_blocks: 2
21 | attn_resolutions:
22 | - 16
23 | dropout: 0.0
24 | lossconfig:
25 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
26 | params:
27 | disc_conditional: false
28 | disc_in_channels: 3
29 | disc_start: 0
30 | disc_weight: 0.8
31 | codebook_weight: 1.0
32 |
33 |
--------------------------------------------------------------------------------
/tokenizer/vqgan/configs/vqgan_imagenet_f16_16384.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | base_learning_rate: 4.5e-06
3 | target: taming.models.vqgan.VQModel
4 | params:
5 | embed_dim: 256
6 | n_embed: 16384
7 | monitor: val/rec_loss
8 | ddconfig:
9 | double_z: false
10 | z_channels: 256
11 | resolution: 256
12 | in_channels: 3
13 | out_ch: 3
14 | ch: 128
15 | ch_mult:
16 | - 1
17 | - 1
18 | - 2
19 | - 2
20 | - 4
21 | num_res_blocks: 2
22 | attn_resolutions:
23 | - 16
24 | dropout: 0.0
25 | lossconfig:
26 | target: taming.modules.losses.vqperceptual.VQLPIPSWithDiscriminator
27 | params:
28 | disc_conditional: false
29 | disc_in_channels: 3
30 | disc_start: 0
31 | disc_weight: 0.75
32 | disc_num_layers: 2
33 | codebook_weight: 1.0
34 |
35 |
--------------------------------------------------------------------------------
/tokenizer/vqgan/configs/vqgan_openimage_f8_16384.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | params:
3 | embed_dim: 4
4 | n_embed: 16384
5 | ddconfig:
6 | double_z: false
7 | z_channels: 4
8 | resolution: 256
9 | in_channels: 3
10 | out_ch: 3
11 | ch: 128
12 | ch_mult:
13 | - 1
14 | - 2
15 | - 2
16 | - 4
17 | num_res_blocks: 2
18 | attn_resolutions:
19 | - 32
20 | dropout: 0.0
--------------------------------------------------------------------------------
/tokenizer/vqgan/configs/vqgan_openimage_f8_256.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | params:
3 | embed_dim: 4
4 | n_embed: 256
5 | ddconfig:
6 | double_z: false
7 | z_channels: 4
8 | resolution: 256
9 | in_channels: 3
10 | out_ch: 3
11 | ch: 128
12 | ch_mult:
13 | - 1
14 | - 2
15 | - 2
16 | - 4
17 | num_res_blocks: 2
18 | attn_resolutions:
19 | - 32
20 | dropout: 0.0
--------------------------------------------------------------------------------
/tokenizer/vqgan/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from tokenizer.vqgan.layer import Encoder, Decoder
6 | from tokenizer.vqgan.quantize import VectorQuantizer2 as VectorQuantizer
7 |
8 |
9 | VQGAN_FROM_TAMING = {
10 | 'vqgan_imagenet_f16_1024': (
11 | 'tokenizer/vqgan/configs/vqgan_imagenet_f16_1024.yaml',
12 | 'pretrained_models/vqgan_imagenet_f16_1024/ckpts/last.pth'),
13 | 'vqgan_imagenet_f16_16384': (
14 | 'tokenizer/vqgan/configs/vqgan_imagenet_f16_16384.yaml',
15 | 'pretrained_models/vqgan_imagenet_f16_16384/ckpts/last.pth'),
16 | 'vqgan_openimage_f8_256': (
17 | 'tokenizer/vqgan/configs/vqgan_openimage_f8_256.yaml',
18 | 'pretrained_models/vq-f8-n256/model.pth'),
19 | 'vqgan_openimage_f8_16384': (
20 | 'tokenizer/vqgan/configs/vqgan_openimage_f8_16384.yaml',
21 | 'pretrained_models/vq-f8/model.pth'),
22 | }
23 |
24 | class VQModel(nn.Module):
25 | def __init__(self,
26 | ddconfig,
27 | n_embed,
28 | embed_dim,
29 | ckpt_path=None,
30 | ignore_keys=[],
31 | image_key="image",
32 | colorize_nlabels=None,
33 | monitor=None,
34 | remap=None,
35 | sane_index_shape=False, # tell vector quantizer to return indices as bhw
36 | **kwargs,
37 | ):
38 | super().__init__()
39 | self.image_key = image_key
40 | self.encoder = Encoder(**ddconfig)
41 | self.decoder = Decoder(**ddconfig)
42 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
43 | remap=remap, sane_index_shape=sane_index_shape)
44 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
45 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
46 | if ckpt_path is not None:
47 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
48 | self.image_key = image_key
49 | if colorize_nlabels is not None:
50 | assert type(colorize_nlabels)==int
51 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
52 | if monitor is not None:
53 | self.monitor = monitor
54 |
55 | def init_from_ckpt(self, path, ignore_keys=list(), logging=True):
56 | model_weight = torch.load(path, map_location="cpu")["state_dict"]
57 | keys = list(model_weight.keys())
58 | for k in keys:
59 | for ik in ignore_keys:
60 | if k.startswith(ik):
61 | print("Deleting key {} from state_dict.".format(k))
62 | del model_weight[k]
63 | missing, unexpected = self.load_state_dict(model_weight, strict=False)
64 | if logging:
65 | print(f"Restored from {path}")
66 | print(f"Missing Keys in State Dict: {missing}")
67 | print(f"Unexpected Keys in State Dict: {unexpected}")
68 |
69 | def encode(self, x):
70 | h = self.encoder(x)
71 | h = self.quant_conv(h)
72 | quant, emb_loss, info = self.quantize(h)
73 | return quant, emb_loss, info
74 |
75 | def decode(self, quant):
76 | quant = self.post_quant_conv(quant)
77 | dec = self.decoder(quant)
78 | return dec
79 |
80 | def decode_code(self, code_b, shape, channel_first=True):
81 | quant_b = self.quantize.get_codebook_entry(code_b, shape, channel_first)
82 | dec = self.decode(quant_b)
83 | return dec
84 |
85 | def forward(self, input):
86 | quant, diff, _ = self.encode(input)
87 | dec = self.decode(quant)
88 | return dec, diff
89 |
--------------------------------------------------------------------------------
/tokenizer/vqgan/taming_vqgan_demo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import torch.nn.functional as F
4 | import numpy as np
5 | from PIL import Image
6 | from omegaconf import OmegaConf
7 | from tokenizer.vqgan.model import VQModel
8 | from tokenizer.vqgan.model import VQGAN_FROM_TAMING
9 |
10 | # before running demo, make sure to:
11 | # (1) download all needed models from https://github.com/CompVis/taming-transformers and put in pretrained_models/
12 | # (2) pip install pytorch_lightning
13 | # (3) python3 tools/convert_pytorch_lightning_to_torch.py
14 | # (4) pip uninstall pytorch_lightning
15 |
16 |
17 | def main(args):
18 | # Setup PyTorch:
19 | torch.manual_seed(args.seed)
20 | torch.set_grad_enabled(False)
21 | device = "cuda" if torch.cuda.is_available() else "cpu"
22 |
23 | # create and load model
24 | cfg, ckpt = VQGAN_FROM_TAMING[args.vqgan]
25 | config = OmegaConf.load(cfg)
26 | model = VQModel(**config.model.get("params", dict()))
27 | model.init_from_ckpt(ckpt)
28 | model.to(device)
29 | model.eval()
30 |
31 | # load image
32 | img_path = args.image_path
33 | out_path = args.image_path.replace('.jpg', '_vqgan.jpg').replace('.jpeg', '_vqgan.jpeg').replace('.png', '_vqgan.png')
34 | input_size = args.image_size
35 | img = Image.open(img_path).convert("RGB")
36 |
37 | # preprocess
38 | size_org = img.size
39 | img = img.resize((input_size, input_size))
40 | img = np.array(img) / 255.
41 | x = 2.0 * img - 1.0 # x value is between [-1, 1]
42 | x = torch.tensor(x)
43 | x = x.unsqueeze(dim=0)
44 | x = torch.einsum('nhwc->nchw', x)
45 |     x_input = x.float().to(device)  # use the device selected above instead of hard-coding "cuda"
46 |
47 | # inference
48 | with torch.no_grad():
49 | latent, _, [_, _, indices] = model.encode(x_input)
50 | output = model.decode_code(indices, latent.shape) # output value is between [-1, 1]
51 |
52 | # postprocess
53 | output = F.interpolate(output, size=[size_org[1], size_org[0]], mode='bilinear').permute(0, 2, 3, 1)[0]
54 | sample = torch.clamp(127.5 * output + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
55 |
56 | # save
57 | Image.fromarray(sample).save(out_path)
58 | print("Reconstructed image is saved to {}".format(out_path))
59 |
60 |
61 | if __name__ == "__main__":
62 | parser = argparse.ArgumentParser()
63 | parser.add_argument("--image-path", type=str, default="assets/example.jpg")
64 | parser.add_argument("--vqgan", type=str, choices=list(VQGAN_FROM_TAMING.keys()), default="vqgan_openimage_f8_16384")
65 | parser.add_argument("--image-size", type=int, choices=[256, 512, 1024], default=512)
66 | parser.add_argument("--seed", type=int, default=0)
67 | args = parser.parse_args()
68 | main(args)
69 |
--------------------------------------------------------------------------------
/tools/check_image_codes.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import torch
3 | import numpy as np
4 |
5 | from tokenizer.tokenizer_image.vq_model import VQ_models
6 | from torchvision.utils import save_image
7 |
8 |
9 | def main(args):
10 | # Setup PyTorch:
11 | torch.manual_seed(args.seed)
12 | torch.set_grad_enabled(False)
13 | device = "cuda" if torch.cuda.is_available() else "cpu"
14 |
15 | # create and load model
16 | vq_model = VQ_models[args.vq_model](
17 | codebook_size=args.codebook_size,
18 | codebook_embed_dim=args.codebook_embed_dim)
19 | vq_model.to(device)
20 | vq_model.eval()
21 | checkpoint = torch.load(args.vq_ckpt, map_location="cpu")
22 | vq_model.load_state_dict(checkpoint["model"])
23 | del checkpoint
24 |
25 | # load image code
26 | latent_dim = args.codebook_embed_dim
27 | latent_size = args.image_size // args.downsample_size
28 | codes = torch.from_numpy(np.load(args.code_path)).to(device)
29 | if codes.ndim == 3: # flip augmentation
30 | qzshape = (codes.shape[1], latent_dim, latent_size, latent_size)
31 | else:
32 | qzshape = (1, latent_dim, latent_size, latent_size)
33 | index_sample = codes.reshape(-1)
34 | samples = vq_model.decode_code(index_sample, qzshape) # output value is between [-1, 1]
35 |
36 | # save
37 | out_path = "sample_image_code.png"
38 |     nrow = max(4, int(codes.shape[1] // 2)) if codes.ndim == 3 else 4  # grid width follows the number of decoded samples
39 | save_image(samples, out_path, nrow=nrow, normalize=True, value_range=(-1, 1))
40 | print("Reconstructed image is saved to {}".format(out_path))
41 |
42 |
43 |
44 | if __name__ == "__main__":
45 | parser = argparse.ArgumentParser()
46 | parser.add_argument("--code-path", type=str, required=True)
47 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
48 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
49 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
50 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
51 | parser.add_argument("--image-size", type=int, choices=[256, 384, 448, 512], default=256)
52 | parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
53 | parser.add_argument("--seed", type=int, default=0)
54 | args = parser.parse_args()
55 | main(args)
--------------------------------------------------------------------------------
/tools/convert_pytorch_lightning_to_torch.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | MODEL_PATH = 'pretrained_models'
5 | pt_lightnings = [
6 | 'vqgan_imagenet_f16_1024/ckpts/last.ckpt',
7 | 'vqgan_imagenet_f16_16384/ckpts/last.ckpt',
8 | 'vq-f8-n256/model.ckpt',
9 | 'vq-f8/model.ckpt',
10 | ]
11 | pts = [
12 | 'vqgan_imagenet_f16_1024/ckpts/last.pth',
13 | 'vqgan_imagenet_f16_16384/ckpts/last.pth',
14 | 'vq-f8-n256/model.pth',
15 | 'vq-f8/model.pth',
16 | ]
17 |
18 | for pt_l, pt in zip(pt_lightnings, pts):
19 | pt_l_weight = torch.load(os.path.join(MODEL_PATH, pt_l), map_location='cpu')
20 | pt_weight = {
21 | 'state_dict': pt_l_weight['state_dict']
22 | }
23 | pt_path = os.path.join(MODEL_PATH, pt)
24 | torch.save(pt_weight, pt_path)
25 | print(f'saving to {pt_path}')
26 |
--------------------------------------------------------------------------------
/tools/draw_figure.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 |
4 | font_size = 14
5 |
6 | def fid_scaling_law_no_cfg():
7 | # data
8 | steps = np.array([50, 100, 200, 300,])
9 | loss_b = np.array([41.025, 33.442, 32.105, 32.196])
10 | loss_l = np.array([25.889, 24.654, 19.742, 19.070])
11 | loss_xl = np.array([19.820, 18.037, 14.772, 15.549])
12 |
13 | steps_ = np.array([50, 200, 300,])
14 | loss_xxl = np.array([17.195, 13.997, 14.648])
15 | loss_3b = np.array([16.431, 9.949, 9.380])
16 | # Plot
17 | plt.figure(figsize=(6, 4))
18 |
19 | plt.plot(steps, loss_b, 'o-', label='B', color='red')
20 | plt.plot(steps, loss_l, 'o-', label='L', color='orange')
21 | plt.plot(steps, loss_xl, 'o-', label='XL', color='green')
22 | plt.plot(steps_, loss_xxl, 'o-', label='XXL', color='blue')
23 | plt.plot(steps_, loss_3b, 'o-', label='3B', color='purple')
24 |
25 | plt.xlabel('Training Epochs', fontsize=font_size)
26 | plt.ylabel('FID', fontsize=font_size)
27 | # plt.grid(True)
28 | # plt.yscale('log')
29 |
30 | # Customize the plot to match the appearance of the provided figure
31 | plt.legend(loc='upper right', framealpha=0.5, fontsize=font_size, facecolor='white')
32 |
33 | # Customizing the x and y axis ticks (to match the example's steps)
34 | # plt.xticks(np.linspace(0, 800000, 5), ['0', '200K', '400K', '600K', '800K'])
35 | plt.yticks(np.arange(5, 50, step=5))
36 |
37 | # Show plot
38 | plt.tight_layout()
39 | plt.savefig('fid_scaling_law_no_cfg.png', dpi=600)
40 |
41 |
42 |
43 | def fid_scaling_law_cfg():
44 | # data
45 | steps = np.array([50, 100, 200, 300,])
46 | loss_b_cfg = np.array([8.309, 7.256, 6.542, 6.249])
47 | loss_l_cfg = np.array([4.240, 3.705, 3.220, 3.075])
48 | loss_xl_cfg = np.array([3.420, 3.089, 2.617, 2.629])
49 |
50 | steps_ = np.array([50, 200, 300,])
51 | loss_xxl_cfg = np.array([2.893, 2.331, 2.340])
52 | loss_3b_cfg = np.array([2.611, 2.381, 2.329])
53 | # Plot
54 | plt.figure(figsize=(6, 4))
55 |
56 | plt.plot(steps, loss_b_cfg, 'o-', label='B', color='red')
57 | plt.plot(steps, loss_l_cfg, 'o-', label='L', color='orange')
58 | plt.plot(steps, loss_xl_cfg, 'o-', label='XL', color='green')
59 | plt.plot(steps_, loss_xxl_cfg, 'o-', label='XXL', color='blue')
60 | plt.plot(steps_, loss_3b_cfg, 'o-', label='3B', color='purple')
61 |
62 | plt.xlabel('Training Epochs', fontsize=font_size)
63 | plt.ylabel('FID', fontsize=font_size)
64 | # plt.grid(True)
65 | # plt.yscale('log')
66 |
67 | # Customize the plot to match the appearance of the provided figure
68 | plt.legend(loc='upper right', framealpha=0.5, fontsize=font_size, facecolor='white')
69 |
70 | # Customizing the x and y axis ticks (to match the example's steps)
71 | # plt.xticks(np.linspace(0, 800000, 5), ['0', '200K', '400K', '600K', '800K'])
72 | plt.yticks(np.arange(2, 9, step=1))
73 |
74 | # Show plot
75 | plt.tight_layout()
76 | plt.savefig('fid_scaling_law_cfg.png', dpi=600)
77 |
78 |
79 |
80 | def sample_topk():
81 | # Data
82 | top_k = np.array([16384, 10000, 8000, 6000, 4000, 2000, 1000])
83 | fid_values = np.array([3.075, 3.369, 3.643, 3.969, 4.635, 5.998, 7.428])
84 | inception_scores = np.array([256.067, 265.222, 268.237, 270.159, 271.455, 267.278, 251.268])
85 |
86 | fig, ax1 = plt.subplots()
87 | # Create first y-axis
88 | ax1.set_xlabel('top-k', fontsize=font_size)
89 | ax1.set_ylabel('FID', color='teal', fontsize=font_size)
90 | ax1.plot(top_k, fid_values, 'o-', color='teal', label="FID")
91 | ax1.tick_params(axis='y', labelcolor='teal')
92 | ax1.tick_params(axis='x')
93 |
94 | # Create second y-axis
95 | ax2 = ax1.twinx()
96 | ax2.set_ylabel('Inception Score', color='brown', fontsize=font_size)
97 | ax2.plot(top_k, inception_scores, 'o-', color='brown', label="Inception Score")
98 | ax2.tick_params(axis='y', labelcolor='brown')
99 |
100 | # Adding a legend
101 | fig.legend(loc='upper right', bbox_to_anchor=(1.0, 1.0), bbox_transform=ax1.transAxes, fontsize=font_size)
102 |
103 | fig.tight_layout() # Adjust layout to prevent overlap
104 | plt.savefig('effect_topk.png', dpi=600)
105 |
106 |
107 |
108 | def sample_cfg():
109 | # Data
110 | cfg = np.array([1.5, 1.75, 2.00, 2.25])
111 | fid_values = np.array([4.743, 3.151, 3.075, 3.620])
112 | inception_scores = np.array([165.381, 214.152, 256.067, 291.695])
113 |
114 |     # pass figsize directly to subplots; a separate plt.figure() call here would be ignored
115 |     fig, ax1 = plt.subplots(figsize=(10, 4))
116 | # Create first y-axis
117 | ax1.set_xlabel('cfg', fontsize=font_size)
118 | ax1.set_ylabel('FID', color='teal', fontsize=font_size)
119 | ax1.plot(cfg, fid_values, 'o-', color='teal', label="FID")
120 | ax1.tick_params(axis='y', labelcolor='teal')
121 | ax1.tick_params(axis='x')
122 |
123 | # Create second y-axis
124 | ax2 = ax1.twinx()
125 | ax2.set_ylabel('Inception Score', color='brown', fontsize=font_size)
126 | ax2.plot(cfg, inception_scores, 'o-', color='brown', label="Inception Score")
127 | ax2.tick_params(axis='y', labelcolor='brown')
128 |
129 | # Adding a legend
130 | fig.legend(loc='upper right', bbox_to_anchor=(1.0, 1.0), bbox_transform=ax1.transAxes, fontsize=font_size)
131 |
132 | fig.tight_layout() # Adjust layout to prevent overlap
133 | plt.savefig('effect_cfg.png', dpi=600)
134 |
135 |
136 |
137 | if __name__ == "__main__":
138 | fid_scaling_law_no_cfg()
139 | fid_scaling_law_cfg()
140 | sample_cfg()
141 | sample_topk()
142 |
--------------------------------------------------------------------------------
/tools/openimage_json.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import json
4 | from PIL import Image
5 | import multiprocessing as mp
6 |
7 | import warnings
8 | warnings.filterwarnings('ignore')
9 |
10 |
11 | def check_image(image_path):
12 | try:
13 | Image.open(image_path)
14 | return True
15 | except Exception as e:
16 | print(f"Error details: {str(e)}")
17 | return False
18 |
19 |
20 | def check_image_path(image_info):
21 | data_path, image_path_list = image_info # Unpack the info
22 | valid_image_paths = []
23 | for image_path in image_path_list:
24 | if check_image(os.path.join(data_path, image_path)):
25 | valid_image_paths.append(image_path)
26 | return valid_image_paths
27 |
28 |
29 | def load_image_path(image_info):
30 | folder_name, data_path, image_extensions = image_info # Unpack the info
31 | print(folder_name)
32 |
33 | folder_path = os.path.join(data_path, folder_name)
34 | local_image_paths = []
35 | for image_path in os.listdir(folder_path):
36 | _, file_extension = os.path.splitext(image_path)
37 | if file_extension.lower() in image_extensions:
38 | image_path_full = os.path.join(folder_name, image_path)
39 | local_image_paths.append(image_path_full)
40 | return local_image_paths
41 |
42 |
43 |
44 | def main(args):
45 | data_path = args.data_path
46 | image_extensions = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp']
47 |
48 | num_processes = 47
49 | work_list = [('openimages_{:0>4}'.format(idx), data_path, image_extensions) for idx in range(1, 48)]
50 | with mp.Pool(processes=num_processes) as pool:
51 | results = pool.map(load_image_path, work_list)
52 | image_paths = [image_path for sublist in results for image_path in sublist]
53 | print('image_paths is loaded')
54 |
55 |
56 | num_processes = max(mp.cpu_count() // 2, 4)
57 | unit = len(image_paths) // num_processes
58 |     work_list = [(data_path, image_paths[idx*unit: (idx+1)*unit if idx < num_processes - 1 else len(image_paths)]) for idx in range(num_processes)]  # last chunk takes the remainder
59 | with mp.Pool(processes=num_processes) as pool:
60 | results = pool.map(check_image_path, work_list)
61 | valid_image_paths = [image_path for sublist in results for image_path in sublist]
62 | print('image_paths is checked')
63 |
64 |
65 | output_json_file_path = os.path.join(data_path, 'image_paths.json')
66 | with open(output_json_file_path, 'w') as outfile:
67 | json.dump(valid_image_paths, outfile, indent=4)
68 | print(f"Image paths have been saved to {output_json_file_path}")
69 |
70 |
71 | if __name__ == "__main__":
72 | parser = argparse.ArgumentParser()
73 | parser.add_argument("--data-path", type=str, required=True)
74 | args = parser.parse_args()
75 | main(args)
--------------------------------------------------------------------------------
/tools/push_gpt_to_hf.py:
--------------------------------------------------------------------------------
1 | # Modified from:
2 | # DiT: https://github.com/facebookresearch/DiT/blob/main/sample_ddp.py
3 | import torch
4 | torch.backends.cuda.matmul.allow_tf32 = True
5 | torch.backends.cudnn.allow_tf32 = True
6 | import argparse
7 |
8 | from tokenizer.tokenizer_image.vq_model import VQ_models
9 | from autoregressive.models.gpt_hf import GPT_models_HF, TransformerHF
10 |
11 | device = "cuda" if torch.cuda.is_available() else "cpu"
12 |
13 | def main(args):
14 | # Setup PyTorch:
15 |     assert torch.cuda.is_available(), "This script requires at least one GPU."
16 | torch.set_grad_enabled(False)
17 |
18 | # create and load gpt model
19 | precision = {'none': torch.float32, 'bf16': torch.bfloat16, 'fp16': torch.float16}[args.precision]
20 | latent_size = args.image_size // args.downsample_size
21 | gpt_model = GPT_models_HF[args.gpt_model](
22 | vocab_size=args.codebook_size,
23 | block_size=latent_size ** 2,
24 | num_classes=args.num_classes,
25 | cls_token_num=args.cls_token_num,
26 | model_type=args.gpt_type,
27 | ).to(device=device, dtype=precision)
28 | checkpoint = torch.load(args.gpt_ckpt, map_location="cpu")
29 | if args.from_fsdp: # fsdp
30 | model_weight = checkpoint
31 | elif "model" in checkpoint: # ddp
32 | model_weight = checkpoint["model"]
33 | elif "module" in checkpoint: # deepspeed
34 | model_weight = checkpoint["module"]
35 | elif "state_dict" in checkpoint:
36 | model_weight = checkpoint["state_dict"]
37 | else:
38 | raise Exception("please check model weight, maybe add --from-fsdp to run command")
39 |
40 | # load weights
41 | gpt_model.load_state_dict(model_weight, strict=False)
42 | gpt_model.eval()
43 | del checkpoint
44 |
45 | # push to hub
46 | repo_id = f"FoundationVision/{args.gpt_model}-{args.image_size}"
47 | gpt_model.push_to_hub(repo_id)
48 |
49 | # reload
50 | model = TransformerHF.from_pretrained(repo_id)
51 |
52 |
53 | if __name__ == "__main__":
54 | parser = argparse.ArgumentParser()
55 |     parser.add_argument("--gpt-model", type=str, choices=list(GPT_models_HF.keys()), default="GPT-B")
56 | parser.add_argument("--gpt-ckpt", type=str, default=None)
57 | parser.add_argument("--gpt-type", type=str, choices=['c2i', 't2i'], default="c2i", help="class-conditional or text-conditional")
58 | parser.add_argument("--from-fsdp", action='store_true')
59 | parser.add_argument("--cls-token-num", type=int, default=1, help="max token number of condition input")
60 | parser.add_argument("--precision", type=str, default='bf16', choices=["none", "fp16", "bf16"])
61 | parser.add_argument("--compile", action='store_true', default=True)
62 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models.keys()), default="VQ-16")
63 | parser.add_argument("--vq-ckpt", type=str, default=None, help="ckpt path for vq model")
64 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
65 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
66 | parser.add_argument("--image-size", type=int, choices=[256, 384, 512], default=384)
67 | parser.add_argument("--image-size-eval", type=int, choices=[256, 384, 512], default=256)
68 | parser.add_argument("--downsample-size", type=int, choices=[8, 16], default=16)
69 | parser.add_argument("--num-classes", type=int, default=1000)
70 | args = parser.parse_args()
71 | main(args)
--------------------------------------------------------------------------------
/tools/push_vae_to_hf.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to push and load custom PyTorch models to/from the Hugging Face Hub.
3 | """
4 |
5 | import argparse
6 | import torch
7 | from tokenizer.tokenizer_image.vq_model_hf import VQ_models_HF, VQModelHF
8 |
9 | from huggingface_hub import hf_hub_download
10 |
11 |
12 | model2ckpt = {
13 | "GPT-XL": ("vq_ds16_c2i.pt", "c2i_XL_384.pt", 384),
14 | "GPT-B": ("vq_ds16_c2i.pt", "c2i_B_256.pt", 256),
15 | }
16 |
17 | def load_model(args):
18 | ckpt_folder = "./"
19 | vq_ckpt, gpt_ckpt, _ = model2ckpt[args.gpt_model]
20 | hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=vq_ckpt, local_dir=ckpt_folder)
21 | hf_hub_download(repo_id="FoundationVision/LlamaGen", filename=gpt_ckpt, local_dir=ckpt_folder)
22 | # create and load model
23 | vq_model = VQ_models_HF[args.vq_model](
24 | codebook_size=args.codebook_size,
25 | codebook_embed_dim=args.codebook_embed_dim)
26 | vq_model.eval()
27 | checkpoint = torch.load(f"{ckpt_folder}{vq_ckpt}", map_location="cpu")
28 | vq_model.load_state_dict(checkpoint["model"])
29 | del checkpoint
30 |     print("image tokenizer is loaded")
31 | return vq_model
32 |
33 |
34 | parser = argparse.ArgumentParser()
35 | parser.add_argument("--gpt-model", type=str, default="GPT-XL")
36 | parser.add_argument("--vq-model", type=str, choices=list(VQ_models_HF.keys()), default="VQ-16")
37 | parser.add_argument("--codebook-size", type=int, default=16384, help="codebook size for vector quantization")
38 | parser.add_argument("--codebook-embed-dim", type=int, default=8, help="codebook dimension for vector quantization")
39 | args = parser.parse_args()
40 |
41 | # load weights
42 | vq_model = load_model(args)
43 |
44 | # push to hub
45 | vq_model.push_to_hub("FoundationVision/vq-ds16-c2i")
46 |
47 | # reload
48 | model = VQModelHF.from_pretrained("FoundationVision/vq-ds16-c2i")
--------------------------------------------------------------------------------
/utils/data.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 |
4 | def center_crop_arr(pil_image, image_size):
5 | """
6 | Center cropping implementation from ADM.
7 | https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
8 | """
9 | while min(*pil_image.size) >= 2 * image_size:
10 | pil_image = pil_image.resize(
11 | tuple(x // 2 for x in pil_image.size), resample=Image.BOX
12 | )
13 |
14 | scale = image_size / min(*pil_image.size)
15 | pil_image = pil_image.resize(
16 | tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
17 | )
18 |
19 | arr = np.array(pil_image)
20 | crop_y = (arr.shape[0] - image_size) // 2
21 | crop_x = (arr.shape[1] - image_size) // 2
22 | return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
--------------------------------------------------------------------------------
/utils/deepspeed.py:
--------------------------------------------------------------------------------
1 | def create_deepspeed_config(args):
2 | ds_config = {
3 | "steps_per_print": 1000,
4 | "train_batch_size": args.global_batch_size,
5 | "gradient_accumulation_steps": args.gradient_accumulation_steps,
6 | # "train_micro_batch_size_per_gpu": args.batch_size, # determined by (train_batch_size, gradient_accumulation_steps)
7 | "optimizer": {
8 | "type": "Adam",
9 | "adam_w_mode": True,
10 | "params": {
11 | "lr": args.lr,
12 | "weight_decay": args.weight_decay,
13 | "bias_correction": True,
14 | "betas": [
15 | args.beta1,
16 | args.beta2
17 | ],
18 | }
19 | },
20 | "fp16": {
21 | "enabled": args.mixed_precision == 'fp16',
22 | "loss_scale": 0,
23 | "initial_scale_power": 16,
24 | "loss_scale_window": 1000,
25 | "hysteresis": 2,
26 | "min_loss_scale": 1
27 | },
28 | "bf16": {
29 | "enabled": args.mixed_precision == 'bf16',
30 | },
31 | # "flops_profiler": {
32 | # "enabled": True,
33 | # "profile_step": -1,
34 | # "module_depth": -1,
35 | # "top_modules": 1,
36 | # "detailed": True,
37 | # },
38 | "zero_allow_untested_optimizer": True
39 | }
40 |
41 | if args.clip_grad is not None:
42 | ds_config.update({'gradient_clipping': args.clip_grad})
43 |
44 | if args.zero_stage == 0:
45 | ds_config.update({"zero_optimization":
46 | {
47 | "stage": args.zero_stage,
48 | "contiguous_gradients": True,
49 | "overlap_comm": True,
50 | }
51 | })
52 | elif args.zero_stage == 1:
53 | ds_config.update({"zero_optimization":
54 | {
55 | "stage": args.zero_stage,
56 | "contiguous_gradients": True,
57 | "overlap_comm": True,
58 | "reduce_bucket_size": 5e8,
59 | }
60 | })
61 | elif args.zero_stage == 2:
62 | ds_config.update({"zero_optimization":
63 | {
64 | "stage": args.zero_stage,
65 | "contiguous_gradients": True,
66 | "overlap_comm": True,
67 | "reduce_scatter": True,
68 | "reduce_bucket_size": 5e8,
69 | "allgather_bucket_size": 5e8,
70 | }
71 | })
72 | elif args.zero_stage == 3:
73 | ds_config.update({"zero_optimization":
74 | {
75 | "stage": args.zero_stage,
76 | "contiguous_gradients": True,
77 | "overlap_comm": True,
78 | "reduce_bucket_size": 5e8,
79 | "stage3_prefetch_bucket_size": 5e8,
80 | "stage3_param_persistence_threshold": 1e6,
81 | "stage3_max_live_parameters": 1e9,
82 | "stage3_max_reuse_distance": 1e9,
83 | "stage3_gather_16bit_weights_on_model_save": True
84 | }
85 | })
86 |
87 | return ds_config
88 |
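89 |
90 | # Hedged usage sketch (assumption, not part of the original training scripts): the returned
91 | # dict is intended to be passed to deepspeed.initialize as its `config` argument. The helper
92 | # name below is hypothetical and only illustrates the wiring.
93 | def _example_deepspeed_setup(args, model):
94 |     import deepspeed  # local import keeps this sketch from affecting normal module imports
95 |     ds_config = create_deepspeed_config(args)
96 |     model_engine, optimizer, _, _ = deepspeed.initialize(
97 |         model=model, model_parameters=model.parameters(), config=ds_config)
98 |     return model_engine, optimizer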
--------------------------------------------------------------------------------
/utils/distributed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import subprocess
4 |
5 |
6 | def setup_for_distributed(is_master):
7 | """
8 | This function disables printing when not in master process
9 | """
10 | import builtins as __builtin__
11 | builtin_print = __builtin__.print
12 |
13 | def print(*args, **kwargs):
14 | force = kwargs.pop('force', False)
15 | if is_master or force:
16 | builtin_print(*args, **kwargs)
17 |
18 | __builtin__.print = print
19 |
20 | def init_distributed_mode(args):
21 | if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
22 | args.rank = int(os.environ["RANK"])
23 | args.world_size = int(os.environ['WORLD_SIZE'])
24 | args.gpu = int(os.environ['LOCAL_RANK'])
25 | args.dist_url = 'env://'
26 | os.environ['LOCAL_SIZE'] = str(torch.cuda.device_count())
27 | elif 'SLURM_PROCID' in os.environ:
28 | proc_id = int(os.environ['SLURM_PROCID'])
29 | ntasks = int(os.environ['SLURM_NTASKS'])
30 | node_list = os.environ['SLURM_NODELIST']
31 | num_gpus = torch.cuda.device_count()
32 | addr = subprocess.getoutput(
33 | 'scontrol show hostname {} | head -n1'.format(node_list))
34 | os.environ['MASTER_PORT'] = os.environ.get('MASTER_PORT', '29500')
35 | os.environ['MASTER_ADDR'] = addr
36 | os.environ['WORLD_SIZE'] = str(ntasks)
37 | os.environ['RANK'] = str(proc_id)
38 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
39 | os.environ['LOCAL_SIZE'] = str(num_gpus)
40 | args.dist_url = 'env://'
41 | args.world_size = ntasks
42 | args.rank = proc_id
43 | args.gpu = proc_id % num_gpus
44 | else:
45 | print('Not using distributed mode')
46 | args.distributed = False
47 | return
48 |
49 | args.distributed = True
50 |
51 | torch.cuda.set_device(args.gpu)
52 | args.dist_backend = 'nccl'
53 | print('| distributed init (rank {}): {}'.format(
54 | args.rank, args.dist_url), flush=True)
55 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
56 | world_size=args.world_size, rank=args.rank)
57 | torch.distributed.barrier()
58 | setup_for_distributed(args.rank == 0)
59 |
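60 |
61 | # Launch note (assumption, not part of the original file): the env-var branch above expects
62 | # RANK / WORLD_SIZE / LOCAL_RANK, which torchrun sets automatically, e.g.
63 | #   torchrun --nproc_per_node=8 <training_script.py> ...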
--------------------------------------------------------------------------------
/utils/drop_path.py:
--------------------------------------------------------------------------------
1 | # from timm.models.layers import DropPath
2 | import torch
3 |
4 | def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
5 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
6 |
7 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
8 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
9 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
10 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
11 | 'survival rate' as the argument.
12 |
13 | """
14 | if drop_prob == 0. or not training:
15 | return x
16 | keep_prob = 1 - drop_prob
17 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
18 | random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
19 | if keep_prob > 0.0 and scale_by_keep:
20 | random_tensor.div_(keep_prob)
21 | return x * random_tensor
22 |
23 |
24 | class DropPath(torch.nn.Module):
25 | """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
26 | """
27 | def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
28 | super(DropPath, self).__init__()
29 | self.drop_prob = drop_prob
30 | self.scale_by_keep = scale_by_keep
31 |
32 | def forward(self, x):
33 | return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
34 |
35 | def extra_repr(self):
36 | return f'drop_prob={round(self.drop_prob,3):0.3f}'
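37 |
38 |
39 | # Hedged usage sketch (assumption, not part of the original file): DropPath is applied to the
40 | # residual branch so that, per sample, the whole branch is sometimes skipped. The block name
41 | # below is hypothetical and only illustrates where DropPath sits.
42 | class _ExampleResidualBlock(torch.nn.Module):
43 |     def __init__(self, dim: int, drop_prob: float = 0.1):
44 |         super().__init__()
45 |         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim), torch.nn.GELU(), torch.nn.Linear(dim, dim))
46 |         self.drop_path = DropPath(drop_prob)
47 |
48 |     def forward(self, x):
49 |         return x + self.drop_path(self.mlp(x))  # identity path is kept; the branch is stochastically dropped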
--------------------------------------------------------------------------------
/utils/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 |
4 | @torch.no_grad()
5 | def update_ema(ema_model, model, decay=0.9999):
6 | """
7 | Step the EMA model towards the current model.
8 | """
9 | ema_params = OrderedDict(ema_model.named_parameters())
10 | model_params = OrderedDict(model.named_parameters())
11 |
12 | for name, param in model_params.items():
13 | # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
14 | ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
15 |
16 |
17 | def requires_grad(model, flag=True):
18 | """
19 | Set requires_grad flag for all parameters in a model.
20 | """
21 | for p in model.parameters():
22 | p.requires_grad = flag
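23 |
24 |
25 | # Hedged usage sketch (assumption, not part of the original file): the EMA model is a frozen
26 | # deep copy of the online model, updated once per optimizer step. The helper names below are
27 | # hypothetical and only illustrate how update_ema / requires_grad are meant to be used.
28 | def _example_ema_setup(model):
29 |     from copy import deepcopy
30 |     ema_model = deepcopy(model)       # created once, before training starts
31 |     requires_grad(ema_model, False)   # EMA weights are never trained directly
32 |     return ema_model
33 |
34 | def _example_train_step(model, ema_model, optimizer, loss):
35 |     loss.backward()
36 |     optimizer.step()
37 |     update_ema(ema_model, model, decay=0.9999)  # after each step, move EMA toward the current weights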
--------------------------------------------------------------------------------
/utils/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch.distributed as dist
3 |
4 | def create_logger(logging_dir):
5 | """
6 | Create a logger that writes to a log file and stdout.
7 | """
8 | if dist.get_rank() == 0: # real logger
9 | logging.basicConfig(
10 | level=logging.INFO,
11 | format='[\033[34m%(asctime)s\033[0m] %(message)s',
12 | datefmt='%Y-%m-%d %H:%M:%S',
13 | handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
14 | )
15 | logger = logging.getLogger(__name__)
16 | else: # dummy logger (does nothing)
17 | logger = logging.getLogger(__name__)
18 | logger.addHandler(logging.NullHandler())
19 | return logger
--------------------------------------------------------------------------------
/utils/video.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import skvideo.io
4 | from PIL import Image
5 |
6 | # Shifts src_tf dim to dest dim
7 | # i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c)
8 | def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
9 | n_dims = len(x.shape)
10 | if src_dim < 0:
11 | src_dim = n_dims + src_dim
12 | if dest_dim < 0:
13 | dest_dim = n_dims + dest_dim
14 |
15 | assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
16 |
17 | dims = list(range(n_dims))
18 | del dims[src_dim]
19 |
20 | permutation = []
21 | ctr = 0
22 | for i in range(n_dims):
23 | if i == dest_dim:
24 | permutation.append(src_dim)
25 | else:
26 | permutation.append(dims[ctr])
27 | ctr += 1
28 | x = x.permute(permutation)
29 | if make_contiguous:
30 | x = x.contiguous()
31 | return x
32 |
33 | # reshapes tensor start from dim i (inclusive)
34 | # to dim j (exclusive) to the desired shape
35 | # e.g. if x.shape = (b, thw, c) then
36 | # view_range(x, 1, 2, (t, h, w)) returns
37 | # x of shape (b, t, h, w, c)
38 | def view_range(x, i, j, shape):
39 | shape = tuple(shape)
40 |
41 | n_dims = len(x.shape)
42 | if i < 0:
43 | i = n_dims + i
44 |
45 | if j is None:
46 | j = n_dims
47 | elif j < 0:
48 | j = n_dims + j
49 |
50 | assert 0 <= i < j <= n_dims
51 |
52 | x_shape = x.shape
53 | target_shape = x_shape[:i] + shape + x_shape[j:]
54 | return x.view(target_shape)
55 |
56 |
57 | def tensor_slice(x, begin, size):
58 | assert all([b >= 0 for b in begin])
59 | size = [l - b if s == -1 else s
60 | for s, b, l in zip(size, begin, x.shape)]
61 | assert all([s >= 0 for s in size])
62 |
63 | slices = [slice(b, b + s) for b, s in zip(begin, size)]
64 |     return x[tuple(slices)]  # index with a tuple of slices (indexing with a list is deprecated)
65 |
66 |
67 | def save_video_grid(video, fname, nrow=None, fps=5):
68 | b, c, t, h, w = video.shape
69 | video = video.permute(0, 2, 3, 4, 1)
70 | video = (video.cpu().numpy() * 255).astype('uint8')
71 |
72 | if nrow is None:
73 | nrow = math.ceil(math.sqrt(b))
74 | ncol = math.ceil(b / nrow)
75 | padding = 1
76 | video_grid = np.zeros((t, (padding + h) * nrow + padding,
77 | (padding + w) * ncol + padding, c), dtype='uint8')
78 | for i in range(b):
79 | r = i // ncol
80 | c = i % ncol
81 |
82 | start_r = (padding + h) * r
83 | start_c = (padding + w) * c
84 | video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
85 |
86 | skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '{}'.format(fps)})
87 |
88 |
89 | def save_gif_grid(video, file_name, nrow=None, fps=5):
90 | b, c, t, h, w = video.shape
91 | video = video.permute(0, 2, 3, 4, 1)
92 | video = (video.cpu().numpy() * 255).astype('uint8')
93 |
94 | if nrow is None:
95 | nrow = math.ceil(math.sqrt(b))
96 | ncol = math.ceil(b / nrow)
97 | padding = 1
98 | video_grid = np.zeros((t, (padding + h) * nrow + padding,
99 | (padding + w) * ncol + padding, c), dtype='uint8')
100 | for i in range(b):
101 | r = i // ncol
102 | c = i % ncol
103 |
104 | start_r = (padding + h) * r
105 | start_c = (padding + w) * c
106 | video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
107 |
108 | images = []
109 | for frame in video_grid:
110 | images.append(Image.fromarray(frame))
111 |
112 | # Save the first image and append the rest of the images as frames in the GIF
113 | images[0].save(file_name, save_all=True, append_images=images[1:], optimize=False, duration=int(1000/fps), loop=0)
114 |
115 | # The 'duration' parameter defines the display time for each frame in milliseconds
116 | # The 'loop' parameter defines the number of loops the GIF should make (0 for infinite loop)
117 |
--------------------------------------------------------------------------------