├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── Makefile ├── README.md ├── configs ├── docker.json ├── local.json └── models │ ├── vqgan_coco_f16_8192.json │ ├── vqgan_custom.json │ ├── vqgan_custom_docker.json │ ├── vqgan_faceshq_f16_1024.json │ ├── vqgan_imagenet_f16_1024.json │ └── vqgan_imagenet_f16_16384.json ├── core ├── clip │ ├── README.md │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── clip.py │ ├── model.py │ └── simple_tokenizer.py ├── optimizer │ ├── __init__.py │ ├── adamp.py │ ├── diffgrad.py │ └── radam.py ├── schemas │ ├── __init__.py │ ├── config.py │ └── train_config.py ├── taming │ ├── README.md │ ├── models │ │ ├── __init__.py │ │ └── vqgan.py │ ├── modules │ │ ├── diffusion │ │ │ ├── __init__.py │ │ │ ├── attn_block.py │ │ │ ├── decoder.py │ │ │ ├── downsample.py │ │ │ ├── encoder.py │ │ │ ├── resnet_block.py │ │ │ └── upsample.py │ │ ├── discriminator │ │ │ ├── __init__.py │ │ │ ├── act_norm.py │ │ │ └── discriminator.py │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── lpips.py │ │ │ └── vqperceptual.py │ │ └── vqvae │ │ │ ├── __init__.py │ │ │ └── vector_quantizer.py │ └── utils │ │ ├── __init__.py │ │ ├── diffusion_utils.py │ │ ├── discriminator_utils.py │ │ └── losses_utils.py └── utils │ ├── __init__.py │ ├── gradients.py │ ├── helpers.py │ ├── loader.py │ ├── make_cutouts.py │ ├── noises.py │ ├── normalize.py │ └── prompt.py ├── data └── .gitignore ├── docker-compose.yml ├── models └── .gitignore ├── outputs └── .gitignore ├── requirements.txt ├── samples ├── forest.png ├── ghost_pokemon.png ├── gundam.png ├── landscape.png ├── sailor_moon.png └── waterfall.png └── scripts ├── generate.py └── train.py /.dockerignore: -------------------------------------------------------------------------------- 1 | ./models 2 | ./data 3 | ./samples 4 | ./outputs 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime 2 | 3 | WORKDIR /app 4 | 5 | COPY ./requirements.txt /requirements.txt 6 | RUN python -m pip install -r /requirements.txt 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Kevin Costa 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
 22 | 
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
  1 | COMPOSE=docker-compose -f docker-compose.yml
  2 | 
  3 | all: build
  4 | 
  5 | build:
  6 | 	$(COMPOSE) build
  7 | 
  8 | generate:
  9 | 	$(COMPOSE) run generate
 10 | 
 11 | generate-cpu:
 12 | 	$(COMPOSE) run -e DEVICE='cpu' generate
 13 | 
 14 | train:
 15 | 	$(COMPOSE) run train
 16 | 
 17 | train-cpu:
 18 | 	$(COMPOSE) run -e DEVICE='cpu' train
 19 | 
 20 | 
 21 | .PHONY: all build generate generate-cpu train train-cpu
 22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
  1 | # VQGAN-CLIP-Docker
  2 | 
  3 | - [Setup](#Setup)
  4 | - [Usage](#Usage)
  5 | - [Inference](#Inference)
  6 | - [Training](#Training)
  7 | - [Acknowledgments](#Acknowledgments)
  8 | - [Citations](#Citations)
  9 | 
 10 | ## About
 11 | 
 12 | > Zero-Shot Text-to-Image Generation VQGAN+CLIP Dockerized
 13 | 
 14 | This is a stripped-down, minimal-dependency repository for running VQGAN+CLIP locally or in production.
 15 | 
 16 | For a Google Colab notebook [see the original repository](#Acknowledgments).
 17 | 
 18 | ## Samples
 19 | 
 20 | 
 21 | <!-- Sample images: ./samples/forest.png, ./samples/ghost_pokemon.png, ./samples/gundam.png, ./samples/landscape.png, ./samples/sailor_moon.png, ./samples/waterfall.png -->
 22 | 
 23 | 
 24 | 
 25 | 
 26 | 
 27 | 
28 | 29 | 30 | # Setup 31 | 32 | Clone this repository and `cd` inside. 33 | 34 | ```sh 35 | git clone https://github.com/kcosta42/VQGAN-CLIP-Docker.git 36 | cd VQGAN-CLIP-Docker 37 | ``` 38 | 39 | You can download a pretrained VQGAN model and put it in the `./models` folder. 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 |
DatasetLinkConfig
ImageNet (f=16), 16384vqgan_imagenet_f16_16384.ckpt./configs/models/vqgan_imagenet_f16_16384.json
ImageNet (f=16), 1024vqgan_imagenet_f16_1024.ckpt./configs/models/vqgan_imagenet_f16_1024.json
FacesHQ (f=16)vqgan_faceshq_f16_1024.ckpt./configs/models/vqgan_faceshq_f16_1024.json
COCO-Stuff (f=16)vqgan_coco_f16_8192.ckpt./configs/models/vqgan_coco_f16_8192.json
 72 | 
 73 | 
 74 | For GPU support, make sure you have CUDA installed on your system (tested with CUDA 11.1+).
 75 | 
 76 | - 6 GB of VRAM is required to generate 256x256 images.
 77 | - 11 GB of VRAM is required to generate 512x512 images.
 78 | - 24 GB of VRAM is required to generate 1024x1024 images. (Untested)
 79 | 
 80 | ## Local
 81 | 
 82 | Install the Python requirements:
 83 | 
 84 | ```sh
 85 | python3 -m pip install -r requirements.txt
 86 | ```
 87 | 
 88 | To check whether you can run this on your GPU, the following command must return `True`.
 89 | ```sh
 90 | python3 -c "import torch; print(torch.cuda.is_available());"
 91 | ```
 92 | 
 93 | ## Docker
 94 | 
 95 | > Make sure you have `docker` and `docker-compose` v1.28.0+ installed. `nvidia-docker` is needed if you want to run this on your GPU through Docker.
 96 | 
 97 | A Makefile is provided for ease of use.
 98 | 
 99 | ```sh
100 | make build  # Build the docker image
101 | ```
102 | 
103 | # Usage
104 | 
105 | ## Inference
106 | 
107 | Two configuration files are provided, `./configs/local.json` and `./configs/docker.json`. They are ready to go, but you may want to edit them to meet your needs. Check the [Configuration section](#Configuration) to understand each field; a scripted example follows the table.
108 | 
109 | By default, the resulting generations can be found in the `./outputs` folder.
110 | 
111 | ### GPU
112 | 
113 | To run locally:
114 | 
115 | ```sh
116 | python3 -m scripts.generate -c ./configs/local.json
117 | ```
118 | 
119 | To run with Docker:
120 | 
121 | ```sh
122 | make generate
123 | ```
124 | 
125 | ### CPU
126 | 
127 | To run locally:
128 | 
129 | ```sh
130 | DEVICE=cpu python3 -m scripts.generate -c ./configs/local.json
131 | ```
132 | 
133 | To run with Docker:
134 | 
135 | ```sh
136 | make generate-cpu
137 | ```
138 | 
139 | ### Configuration
140 | 
141 | | Argument               | Type           | Description                                                                     |
142 | |------------------------|----------------|---------------------------------------------------------------------------------|
143 | | `prompts`              | List[str]      | Text prompts                                                                    |
144 | | `image_prompts`        | List[FilePath] | Image prompts / target image path                                               |
145 | | `max_iterations`       | int            | Number of iterations                                                            |
146 | | `save_freq`            | int            | Save an image every `save_freq` iterations                                      |
147 | | `size`                 | [int, int]     | Image size (width, height)                                                      |
148 | | `pixelart`             | [int, int]     | Pixelart image size (width, height) (Optional, remove option to disable)        |
149 | | `init_image`           | FilePath       | Initial image                                                                   |
150 | | `init_noise`           | str            | Initial noise image ["gradient","pixels","fractal"]                             |
151 | | `init_weight`          | float          | Initial weight                                                                  |
152 | | `mse_decay_rate`       | int            | Slowly decrease the MSE loss every `mse_decay_rate` iterations until it reaches about 0 |
153 | | `output_dir`           | FilePath       | Path to the output directory                                                    |
154 | | `models_dir`           | FilePath       | Path to the models cache directory                                              |
155 | | `clip_model`           | FilePath       | CLIP model path or name                                                         |
156 | | `vqgan_checkpoint`     | FilePath       | VQGAN checkpoint path                                                           |
157 | | `vqgan_config`         | FilePath       | VQGAN config path                                                               |
158 | | `noise_prompt_seeds`   | List[int]      | Noise prompt seeds                                                              |
159 | | `noise_prompt_weights` | List[float]    | Noise prompt weights                                                            |
160 | | `step_size`            | float          | Learning rate                                                                   |
161 | | `cutn`                 | int            | Number of cuts                                                                  |
162 | | `cut_pow`              | float          | Cut power                                                                       |
163 | | `seed`                 | int            | Seed (-1 for a random seed)                                                     |
164 | | `optimizer`            | str            | Optimiser ["Adam","AdamW","Adagrad","Adamax","DiffGrad","AdamP","RAdam"]        |
165 | | `nwarm_restarts`       | int            | Number of times the learning rate is reset (-1 to disable LR decay)             |
166 | | `augments`             | List[str]      | Enabled augments ["Ji","Sh","Gn","Pe","Ro","Af","Et","Ts","Cr","Er","Re","Hf"]  |
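If you prefer scripting over hand-editing JSON, the sketch below shows one possible way to derive a custom configuration from `./configs/local.json` and launch a run. It only relies on what is documented above (the `-c` flag of `scripts.generate`); the prompt text, size, and the `my_run.json` filename are arbitrary placeholders.

```python
# Hedged sketch: derive a custom config from ./configs/local.json and run inference.
# The prompt, size, and output filename below are arbitrary examples.
import json
import subprocess

with open("./configs/local.json") as f:
    cfg = json.load(f)

cfg["prompts"] = ["a watercolor painting of a lighthouse at dusk"]
cfg["size"] = [256, 256]        # ~6 GB of VRAM; raise only if your GPU allows it
cfg["max_iterations"] = 300

with open("./configs/my_run.json", "w") as f:
    json.dump(cfg, f, indent=2)

# Equivalent to running: python3 -m scripts.generate -c ./configs/my_run.json
subprocess.run(["python3", "-m", "scripts.generate", "-c", "./configs/my_run.json"], check=True)
```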
167 | 
168 | ## Training
169 | 
170 | > These are instructions to train a new VQGAN model. You can also fine-tune the pretrained models, but you may need to tweak the training script.
171 | 
172 | Two model configuration files are provided, `./configs/models/vqgan_custom.json` and `./configs/models/vqgan_custom_docker.json`. They are ready to go, but you may want to edit them to meet your needs. Check the [Model Configuration section](#Model-Configuration) to understand each field.
173 | 
174 | By default, the models are saved in the `./models/checkpoints` folder.
175 | 
176 | ### Dataset
177 | 
178 | Put your images in folders (one per class) inside the data directory (`./data` by default).
179 | 
180 | The dataset must be structured as follows:
181 | 
182 | ```sh
183 | ./data/
184 | ├── class_x/
185 | │   ├── xxx.png
186 | │   ├── xxy.jpg
187 | │   ├── ...
188 | │   └── xxz.ppm
189 | └── class_y/
190 |     ├── 123.bmp
191 |     ├── nsdf3.tif
192 |     ├── ...
193 |     └── asd932_.webp
194 | ```
195 | 
196 | ### GPU
197 | 
198 | To run locally:
199 | 
200 | ```sh
201 | python3 -m scripts.train -c ./configs/models/vqgan_custom.json
202 | ```
203 | 
204 | To run with Docker:
205 | 
206 | ```sh
207 | make train
208 | ```
209 | 
210 | ### CPU
211 | 
212 | To run locally:
213 | 
214 | ```sh
215 | DEVICE=cpu python3 -m scripts.train -c ./configs/models/vqgan_custom.json
216 | ```
217 | 
218 | To run with Docker:
219 | 
220 | ```sh
221 | make train-cpu
222 | ```
223 | 
224 | ### Model Configuration
225 | 
226 | | Argument             | Type     | Description                                             |
227 | |----------------------|----------|---------------------------------------------------------|
228 | | `base_learning_rate` | float    | Initial learning rate                                   |
229 | | `batch_size`         | int      | Batch size (adjust based on your GPU capability)        |
230 | | `epochs`             | int      | Maximum number of epochs                                |
231 | | `output_dir`         | FilePath | Path to the directory where training images are saved   |
232 | | `models_dir`         | FilePath | Path to the directory where the model is saved          |
233 | | `data_dir`           | FilePath | Path to the data directory                              |
234 | | `seed`               | int      | Seed (-1 for a random seed)                             |
235 | | `resume_checkpoint`  | FilePath | Path to a pretrained model                              |
236 | 
237 | ### Infos
238 | 
239 | - Let the Generator train without the Discriminator for a few epochs (~3-5 epochs for ImageNet), then enable the Discriminator.<br>The variable `lossconfig.params.disc_start` corresponds to the number of global steps (i.e. batch iterations) before the Discriminator is enabled (see the worked example below).
240 | - Once enabled, the Discriminator loss will stagnate around ~1.0; __this is normal behaviour__. The loss will decrease in later epochs. (It can take a _very_ long time).
241 | - If you enable the Discriminator too soon, the Generator will take a lot more time to train.
242 | - Basically, there are no rules for the number of epochs. If your dataset is large enough, there is no risk of overfitting, so the more you train, the better.
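To make the `disc_start` advice concrete, here is a small worked example. The dataset size and warm-up length are hypothetical placeholders; only `batch_size` and `lossconfig.params.disc_start` refer to fields that actually appear in the model configs above.

```python
# Worked example (hypothetical numbers): pick lossconfig.params.disc_start so that the
# Discriminator is enabled after roughly 4 warm-up epochs of Generator-only training.
import math

num_images = 20_000     # images in your dataset (placeholder)
batch_size = 4          # "batch_size" from the model config
warmup_epochs = 4       # Generator-only epochs before enabling the Discriminator

steps_per_epoch = math.ceil(num_images / batch_size)   # 5000 global steps per epoch
disc_start = warmup_epochs * steps_per_epoch           # 20000 -> value for lossconfig.params.disc_start
print(disc_start)
```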
243 | 
244 | 
245 | # Acknowledgments
246 | 
247 | [VQGAN+CLIP](https://github.com/nerdyrodent/VQGAN-CLIP)
248 | 
249 | [Taming Transformers](https://github.com/CompVis/taming-transformers)
250 | 
251 | [CLIP](https://github.com/openai/CLIP)
252 | 
253 | [DALLE-PyTorch](https://github.com/lucidrains/DALLE-pytorch)
254 | 
255 | # Citations
256 | 
257 | ```bibtex
258 | @misc{unpublished2021clip,
259 |     title  = {CLIP: Connecting Text and Images},
260 |     author = {Alec Radford and Ilya Sutskever and Jong Wook Kim and Gretchen Krueger and Sandhini Agarwal},
261 |     year   = {2021}
262 | }
263 | ```
264 | 
265 | ```bibtex
266 | @misc{esser2020taming,
267 |     title={Taming Transformers for High-Resolution Image Synthesis},
268 |     author={Patrick Esser and Robin Rombach and Björn Ommer},
269 |     year={2020},
270 |     eprint={2012.09841},
271 |     archivePrefix={arXiv},
272 |     primaryClass={cs.CV}
273 | }
274 | ```
275 | 
276 | ```bibtex
277 | @misc{ramesh2021zeroshot,
278 |     title  = {Zero-Shot Text-to-Image Generation},
279 |     author = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
280 |     year   = {2021},
281 |     eprint = {2102.12092},
282 |     archivePrefix = {arXiv},
283 |     primaryClass = {cs.CV}
284 | }
285 | ```
286 | 
--------------------------------------------------------------------------------
/configs/docker.json:
--------------------------------------------------------------------------------
  1 | {
  2 |     "prompts": ["a painting of a potato"],
  3 |     "image_prompts": [],
  4 |     "max_iterations": 250,
  5 |     "save_freq": 50,
  6 |     "size": [256, 256],
  7 |     "init_image": "",
  8 |     "init_noise": "",
  9 |     "init_weight": 0.0,
 10 |     "mse_decay_rate": 0,
 11 |     "output_dir": "/outputs",
 12 |     "models_dir": "/models",
 13 |     "clip_model": "ViT-B/16",
 14 |     "vqgan_checkpoint": "/models/vqgan_imagenet_f16_16384.ckpt",
 15 |     "vqgan_config": "/configs/models/vqgan_imagenet_f16_16384.json",
 16 |     "noise_prompt_seeds": [],
 17 |     "noise_prompt_weights": [],
 18 |     "step_size": 0.1,
 19 |     "cutn": 32,
 20 |     "cut_pow": 1.0,
 21 |     "seed": -1,
 22 |     "optimizer": "Adam",
 23 |     "nwarm_restarts": -1,
 24 |     "augments": ["Af", "Pe", "Ji", "Er"]
 25 | }
 26 | 
--------------------------------------------------------------------------------
/configs/local.json:
--------------------------------------------------------------------------------
  1 | {
  2 |     "prompts": ["a painting of a potato"],
  3 |     "image_prompts": [],
  4 |     "max_iterations": 250,
  5 |     "save_freq": 50,
  6 |     "size": [256, 256],
  7 |     "init_image": "",
  8 |     "init_noise": "",
  9 |     "init_weight": 0.0,
 10 |     "mse_decay_rate": 0,
 11 |     "output_dir": "./outputs",
 12 |     "models_dir": "./models",
 13 |     "clip_model": "ViT-B/16",
 14 |     "vqgan_checkpoint": "./models/vqgan_imagenet_f16_16384.ckpt",
 15 |     "vqgan_config": "./configs/models/vqgan_imagenet_f16_16384.json",
 16 |     "noise_prompt_seeds": [],
 17 |     "noise_prompt_weights": [],
 18 |     "step_size": 0.1,
 19 |     "cutn": 32,
 20 |     "cut_pow": 1.0,
 21 |     "seed": -1,
22 | "optimizer": "Adam", 23 | "nwarm_restarts": -1, 24 | "augments": ["Af", "Pe", "Ji", "Er"] 25 | } 26 | -------------------------------------------------------------------------------- /configs/models/vqgan_coco_f16_8192.json: -------------------------------------------------------------------------------- 1 | { 2 | "params": { 3 | "embed_dim": 256, 4 | "n_embed": 8192, 5 | "ddconfig": { 6 | "double_z": false, 7 | "z_channels": 256, 8 | "resolution": 256, 9 | "in_channels": 3, 10 | "out_ch": 3, 11 | "ch": 128, 12 | "ch_mult": [1, 1, 2, 2, 4], 13 | "num_res_blocks": 2, 14 | "attn_resolutions": [16], 15 | "dropout": 0.0 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /configs/models/vqgan_custom.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_learning_rate": 4.5e-6, 3 | "batch_size": 4, 4 | "epochs": 1000, 5 | "output_dir": "./outputs", 6 | "models_dir": "./models", 7 | "data_dir": "./data", 8 | "seed": -1, 9 | "resume_checkpoint": "", 10 | "params": { 11 | "embed_dim": 256, 12 | "n_embed": 1024, 13 | "ddconfig": { 14 | "double_z": false, 15 | "z_channels": 256, 16 | "resolution": 256, 17 | "in_channels": 3, 18 | "out_ch": 3, 19 | "ch": 128, 20 | "ch_mult": [1, 1, 2, 2, 4], 21 | "num_res_blocks": 2, 22 | "attn_resolutions": [16], 23 | "dropout": 0.0 24 | }, 25 | "lossconfig": { 26 | "params": { 27 | "disc_conditional": false, 28 | "disc_in_channels": 3, 29 | "disc_start": 25000, 30 | "disc_weight": 0.8, 31 | "codebook_weight": 1.0 32 | } 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /configs/models/vqgan_custom_docker.json: -------------------------------------------------------------------------------- 1 | { 2 | "base_learning_rate": 4.5e-6, 3 | "batch_size": 4, 4 | "epochs": 1000, 5 | "output_dir": "/outputs", 6 | "models_dir": "/models", 7 | "data_dir": "/data", 8 | "seed": -1, 9 | "resume_checkpoint": "", 10 | "params": { 11 | "embed_dim": 256, 12 | "n_embed": 1024, 13 | "ddconfig": { 14 | "double_z": false, 15 | "z_channels": 256, 16 | "resolution": 256, 17 | "in_channels": 3, 18 | "out_ch": 3, 19 | "ch": 128, 20 | "ch_mult": [1, 1, 2, 2, 4], 21 | "num_res_blocks": 2, 22 | "attn_resolutions": [16], 23 | "dropout": 0.0 24 | }, 25 | "lossconfig": { 26 | "params": { 27 | "disc_conditional": false, 28 | "disc_in_channels": 3, 29 | "disc_start": 25000, 30 | "disc_weight": 0.8, 31 | "codebook_weight": 1.0 32 | } 33 | } 34 | } 35 | } 36 | -------------------------------------------------------------------------------- /configs/models/vqgan_faceshq_f16_1024.json: -------------------------------------------------------------------------------- 1 | { 2 | "params": { 3 | "embed_dim": 256, 4 | "n_embed": 1024, 5 | "ddconfig": { 6 | "double_z": false, 7 | "z_channels": 256, 8 | "resolution": 256, 9 | "in_channels": 3, 10 | "out_ch": 3, 11 | "ch": 128, 12 | "ch_mult": [1, 1, 2, 2, 4], 13 | "num_res_blocks": 2, 14 | "attn_resolutions": [16], 15 | "dropout": 0.0 16 | } 17 | } 18 | } 19 | -------------------------------------------------------------------------------- /configs/models/vqgan_imagenet_f16_1024.json: -------------------------------------------------------------------------------- 1 | { 2 | "params": { 3 | "embed_dim": 256, 4 | "n_embed": 1024, 5 | "ddconfig": { 6 | "double_z": false, 7 | "z_channels": 256, 8 | "resolution": 256, 9 | "in_channels": 3, 10 | "out_ch": 3, 11 | "ch": 128, 12 | "ch_mult": [1, 1, 2, 2, 4], 
13 | "num_res_blocks": 2, 14 | "attn_resolutions": [16], 15 | "dropout": 0.0 16 | }, 17 | "lossconfig": { 18 | "params": { 19 | "disc_conditional": false, 20 | "disc_in_channels": 3, 21 | "disc_start": 0, 22 | "disc_weight": 0.8, 23 | "codebook_weight": 1.0 24 | } 25 | } 26 | } 27 | } 28 | -------------------------------------------------------------------------------- /configs/models/vqgan_imagenet_f16_16384.json: -------------------------------------------------------------------------------- 1 | { 2 | "params": { 3 | "embed_dim": 256, 4 | "n_embed": 16384, 5 | "ddconfig": { 6 | "double_z": false, 7 | "z_channels": 256, 8 | "resolution": 256, 9 | "in_channels": 3, 10 | "out_ch": 3, 11 | "ch": 128, 12 | "ch_mult": [1, 1, 2, 2, 4], 13 | "num_res_blocks": 2, 14 | "attn_resolutions": [16], 15 | "dropout": 0.0 16 | }, 17 | "lossconfig": { 18 | "params": { 19 | "disc_conditional": false, 20 | "disc_in_channels": 3, 21 | "disc_start": 0, 22 | "disc_weight": 0.75, 23 | "disc_num_layers": 2, 24 | "codebook_weight": 1.0 25 | } 26 | } 27 | } 28 | } 29 | -------------------------------------------------------------------------------- /core/clip/README.md: -------------------------------------------------------------------------------- 1 | # CLIP 2 | 3 | [[Original]](https://github.com/openai/CLIP) 4 | 5 | ## About 6 | 7 | A stripped & minimalist version of the original project. 8 | -------------------------------------------------------------------------------- /core/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from core.clip import * 2 | -------------------------------------------------------------------------------- /core/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/core/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /core/clip/clip.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import sys 4 | import urllib 5 | import warnings 6 | from typing import Any, Union, List 7 | 8 | import torch 9 | from PIL import Image 10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 11 | from tqdm import tqdm 12 | 13 | from core.clip.model import build_model 14 | from core.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer 15 | 16 | try: 17 | from torchvision.transforms import InterpolationMode 18 | BICUBIC = InterpolationMode.BICUBIC 19 | except ImportError: 20 | BICUBIC = Image.BICUBIC 21 | 22 | 23 | if torch.__version__.split(".") < ["1", "7", "1"]: 24 | warnings.warn("PyTorch version 1.7.1 or higher is recommended") 25 | 26 | 27 | __all__ = ["available_models", "load", "tokenize"] 28 | _tokenizer = _Tokenizer() 29 | 30 | _MODELS = { 31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", 35 | "ViT-B/32": 
"https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 36 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", 37 | } 38 | 39 | 40 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 41 | os.makedirs(root, exist_ok=True) 42 | filename = os.path.basename(url) 43 | 44 | expected_sha256 = url.split("/")[-2] 45 | download_target = os.path.join(root, filename) 46 | 47 | if os.path.exists(download_target) and not os.path.isfile(download_target): 48 | raise RuntimeError(f"{download_target} exists and is not a regular file") 49 | 50 | if os.path.isfile(download_target): 51 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 52 | return download_target 53 | else: 54 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 55 | 56 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, download_target)) 57 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 58 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: 59 | while True: 60 | buffer = source.read(8192) 61 | if not buffer: 62 | break 63 | 64 | output.write(buffer) 65 | loop.update(len(buffer)) 66 | 67 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 68 | raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match") 69 | 70 | return download_target 71 | 72 | 73 | def _transform(n_px): 74 | return Compose([ 75 | Resize(n_px, interpolation=BICUBIC), 76 | CenterCrop(n_px), 77 | lambda image: image.convert("RGB"), 78 | ToTensor(), 79 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 80 | ]) 81 | 82 | 83 | def available_models() -> List[str]: 84 | """Returns the names of available CLIP models""" 85 | return list(_MODELS.keys()) 86 | 87 | 88 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, **kwargs: Any): 89 | """Load a CLIP model 90 | 91 | Parameters 92 | ---------- 93 | name : str 94 | A model name listed by `clip.available_models()`, or the path to a model checkpoint 95 | containing the state_dict 96 | 97 | device : Union[str, torch.device] 98 | The device to put the loaded model 99 | 100 | jit : bool 101 | Whether to load the optimized JIT model or more hackable non-JIT model (default). 102 | 103 | **kwargs (optional): Any 104 | The corresponding kwargs for _download function 105 | 106 | Returns 107 | ------- 108 | model : torch.nn.Module 109 | The CLIP model 110 | 111 | preprocess : Callable[[PIL.Image], torch.Tensor] 112 | A torchvision transform that converts a PIL image into a tensor that the returned model can 113 | take as its input 114 | """ 115 | if name in _MODELS: 116 | model_path = _download(_MODELS[name], **kwargs) 117 | elif os.path.isfile(name): 118 | model_path = name 119 | else: 120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 121 | 122 | try: 123 | # loading JIT archive 124 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 125 | state_dict = None 126 | except RuntimeError: 127 | # loading saved state dict 128 | if jit: 129 | warnings.warn(f"File {model_path} is not a JIT archive. 
Loading as a state dict instead") 130 | jit = False 131 | state_dict = torch.load(model_path, map_location="cpu") 132 | 133 | if not jit: 134 | model = build_model(state_dict or model.state_dict()).to(device) 135 | if str(device) == "cpu": 136 | model.float() 137 | return model, _transform(model.visual.input_resolution) 138 | 139 | # patch the device names 140 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 141 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 142 | 143 | def patch_device(module): 144 | try: 145 | graphs = [module.graph] if hasattr(module, "graph") else [] 146 | except RuntimeError: 147 | graphs = [] 148 | 149 | if hasattr(module, "forward1"): 150 | graphs.append(module.forward1.graph) 151 | 152 | for graph in graphs: 153 | for node in graph.findAllNodes("prim::Constant"): 154 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 155 | node.copyAttributes(device_node) 156 | 157 | model.apply(patch_device) 158 | patch_device(model.encode_image) 159 | patch_device(model.encode_text) 160 | 161 | # patch dtype to float32 on CPU 162 | if str(device) == "cpu": 163 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 164 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 165 | float_node = float_input.node() 166 | 167 | def patch_float(module): 168 | try: 169 | graphs = [module.graph] if hasattr(module, "graph") else [] 170 | except RuntimeError: 171 | graphs = [] 172 | 173 | if hasattr(module, "forward1"): 174 | graphs.append(module.forward1.graph) 175 | 176 | for graph in graphs: 177 | for node in graph.findAllNodes("aten::to"): 178 | inputs = list(node.inputs()) 179 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 180 | if inputs[i].node()["value"] == 5: 181 | inputs[i].node().copyAttributes(float_node) 182 | 183 | model.apply(patch_float) 184 | patch_float(model.encode_image) 185 | patch_float(model.encode_text) 186 | 187 | model.float() 188 | 189 | return model, _transform(model.input_resolution.item()) 190 | 191 | 192 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor: 193 | """ 194 | Returns the tokenized representation of given input string(s) 195 | 196 | Parameters 197 | ---------- 198 | texts : Union[str, List[str]] 199 | An input string or a list of input strings to tokenize 200 | 201 | context_length : int 202 | The context length to use; all CLIP models use 77 as the context length 203 | 204 | truncate: bool 205 | Whether to truncate the text in case its encoding is longer than the context length 206 | 207 | Returns 208 | ------- 209 | A two-dimensional tensor containing the resulting tokens, 210 | shape = [number of input strings, context_length] 211 | """ 212 | if isinstance(texts, str): 213 | texts = [texts] 214 | 215 | sot_token = _tokenizer.encoder["<|startoftext|>"] 216 | eot_token = _tokenizer.encoder["<|endoftext|>"] 217 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 218 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 219 | 220 | for i, tokens in enumerate(all_tokens): 221 | if len(tokens) > context_length: 222 | if truncate: 223 | tokens = tokens[:context_length] 224 | tokens[-1] = eot_token 225 | else: 226 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 227 | result[i, 
:len(tokens)] = torch.tensor(tokens) 228 | 229 | return result 230 | -------------------------------------------------------------------------------- /core/clip/model.py: -------------------------------------------------------------------------------- 1 | from collections import OrderedDict 2 | from typing import Tuple, Union 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | 9 | 10 | class Bottleneck(nn.Module): 11 | expansion = 4 12 | 13 | def __init__(self, inplanes, planes, stride=1): 14 | super().__init__() 15 | 16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) 18 | self.bn1 = nn.BatchNorm2d(planes) 19 | 20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) 21 | self.bn2 = nn.BatchNorm2d(planes) 22 | 23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() 24 | 25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) 26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 27 | 28 | self.relu = nn.ReLU(inplace=True) 29 | self.downsample = None 30 | self.stride = stride 31 | 32 | if stride > 1 or inplanes != planes * Bottleneck.expansion: 33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 34 | self.downsample = nn.Sequential(OrderedDict([ 35 | ("-1", nn.AvgPool2d(stride)), 36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), 37 | ("1", nn.BatchNorm2d(planes * self.expansion)) 38 | ])) 39 | 40 | def forward(self, x: torch.Tensor): 41 | identity = x 42 | 43 | out = self.relu(self.bn1(self.conv1(x))) 44 | out = self.relu(self.bn2(self.conv2(out))) 45 | out = self.avgpool(out) 46 | out = self.bn3(self.conv3(out)) 47 | 48 | if self.downsample is not None: 49 | identity = self.downsample(x) 50 | 51 | out += identity 52 | out = self.relu(out) 53 | return out 54 | 55 | 56 | class AttentionPool2d(nn.Module): 57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): 58 | super().__init__() 59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) 60 | self.k_proj = nn.Linear(embed_dim, embed_dim) 61 | self.q_proj = nn.Linear(embed_dim, embed_dim) 62 | self.v_proj = nn.Linear(embed_dim, embed_dim) 63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) 64 | self.num_heads = num_heads 65 | 66 | def forward(self, x): 67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC 68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC 69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC 70 | x, _ = F.multi_head_attention_forward( 71 | query=x, key=x, value=x, 72 | embed_dim_to_check=x.shape[-1], 73 | num_heads=self.num_heads, 74 | q_proj_weight=self.q_proj.weight, 75 | k_proj_weight=self.k_proj.weight, 76 | v_proj_weight=self.v_proj.weight, 77 | in_proj_weight=None, 78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), 79 | bias_k=None, 80 | bias_v=None, 81 | add_zero_attn=False, 82 | dropout_p=0, 83 | out_proj_weight=self.c_proj.weight, 84 | out_proj_bias=self.c_proj.bias, 85 | use_separate_proj_weight=True, 86 | training=self.training, 87 | need_weights=False 88 | ) 89 | 90 | return x[0] 91 | 92 | 93 | class ModifiedResNet(nn.Module): 94 | """ 95 | A ResNet class that is similar to 
torchvision's but contains the following changes: 96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. 97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 98 | - The final pooling layer is a QKV attention instead of an average pool 99 | """ 100 | 101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): 102 | super().__init__() 103 | self.output_dim = output_dim 104 | self.input_resolution = input_resolution 105 | 106 | # the 3-layer stem 107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) 108 | self.bn1 = nn.BatchNorm2d(width // 2) 109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) 110 | self.bn2 = nn.BatchNorm2d(width // 2) 111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) 112 | self.bn3 = nn.BatchNorm2d(width) 113 | self.avgpool = nn.AvgPool2d(2) 114 | self.relu = nn.ReLU(inplace=True) 115 | 116 | # residual layers 117 | self._inplanes = width # this is a *mutable* variable used during construction 118 | self.layer1 = self._make_layer(width, layers[0]) 119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2) 120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2) 121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2) 122 | 123 | embed_dim = width * 32 # the ResNet feature dimension 124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) 125 | 126 | def _make_layer(self, planes, blocks, stride=1): 127 | layers = [Bottleneck(self._inplanes, planes, stride)] 128 | 129 | self._inplanes = planes * Bottleneck.expansion 130 | for _ in range(1, blocks): 131 | layers.append(Bottleneck(self._inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | def stem(x): 137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: 138 | x = self.relu(bn(conv(x))) 139 | x = self.avgpool(x) 140 | return x 141 | 142 | x = x.type(self.conv1.weight.dtype) 143 | x = stem(x) 144 | x = self.layer1(x) 145 | x = self.layer2(x) 146 | x = self.layer3(x) 147 | x = self.layer4(x) 148 | x = self.attnpool(x) 149 | 150 | return x 151 | 152 | 153 | class LayerNorm(nn.LayerNorm): 154 | """Subclass torch's LayerNorm to handle fp16.""" 155 | 156 | def forward(self, x: torch.Tensor): 157 | orig_type = x.dtype 158 | ret = super().forward(x.type(torch.float32)) 159 | return ret.type(orig_type) 160 | 161 | 162 | class QuickGELU(nn.Module): 163 | def forward(self, x: torch.Tensor): 164 | return x * torch.sigmoid(1.702 * x) 165 | 166 | 167 | class ResidualAttentionBlock(nn.Module): 168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): 169 | super().__init__() 170 | 171 | self.attn = nn.MultiheadAttention(d_model, n_head) 172 | self.ln_1 = LayerNorm(d_model) 173 | self.mlp = nn.Sequential(OrderedDict([ 174 | ("c_fc", nn.Linear(d_model, d_model * 4)), 175 | ("gelu", QuickGELU()), 176 | ("c_proj", nn.Linear(d_model * 4, d_model)) 177 | ])) 178 | self.ln_2 = LayerNorm(d_model) 179 | self.attn_mask = attn_mask 180 | 181 | def attention(self, x: torch.Tensor): 182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None 183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] 184 | 185 | def forward(self, x: torch.Tensor): 186 | x = 
x + self.attention(self.ln_1(x)) 187 | x = x + self.mlp(self.ln_2(x)) 188 | return x 189 | 190 | 191 | class Transformer(nn.Module): 192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): 193 | super().__init__() 194 | self.width = width 195 | self.layers = layers 196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) 197 | 198 | def forward(self, x: torch.Tensor): 199 | return self.resblocks(x) 200 | 201 | 202 | class VisionTransformer(nn.Module): 203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): 204 | super().__init__() 205 | self.input_resolution = input_resolution 206 | self.output_dim = output_dim 207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) 208 | 209 | scale = width ** -0.5 210 | self.class_embedding = nn.Parameter(scale * torch.randn(width)) 211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) 212 | self.ln_pre = LayerNorm(width) 213 | 214 | self.transformer = Transformer(width, layers, heads) 215 | 216 | self.ln_post = LayerNorm(width) 217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) 218 | 219 | def forward(self, x: torch.Tensor): 220 | x = self.conv1(x) # shape = [*, width, grid, grid] 221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] 224 | x = x + self.positional_embedding.to(x.dtype) 225 | x = self.ln_pre(x) 226 | 227 | x = x.permute(1, 0, 2) # NLD -> LND 228 | x = self.transformer(x) 229 | x = x.permute(1, 0, 2) # LND -> NLD 230 | 231 | x = self.ln_post(x[:, 0, :]) 232 | 233 | if self.proj is not None: 234 | x = x @ self.proj 235 | 236 | return x 237 | 238 | 239 | class CLIP(nn.Module): 240 | def __init__(self, 241 | embed_dim: int, 242 | # vision 243 | image_resolution: int, 244 | vision_layers: Union[Tuple[int, int, int, int], int], 245 | vision_width: int, 246 | vision_patch_size: int, 247 | # text 248 | context_length: int, 249 | vocab_size: int, 250 | transformer_width: int, 251 | transformer_heads: int, 252 | transformer_layers: int 253 | ): 254 | super().__init__() 255 | 256 | self.context_length = context_length 257 | 258 | if isinstance(vision_layers, (tuple, list)): 259 | vision_heads = vision_width * 32 // 64 260 | self.visual = ModifiedResNet( 261 | layers=vision_layers, 262 | output_dim=embed_dim, 263 | heads=vision_heads, 264 | input_resolution=image_resolution, 265 | width=vision_width 266 | ) 267 | else: 268 | vision_heads = vision_width // 64 269 | self.visual = VisionTransformer( 270 | input_resolution=image_resolution, 271 | patch_size=vision_patch_size, 272 | width=vision_width, 273 | layers=vision_layers, 274 | heads=vision_heads, 275 | output_dim=embed_dim 276 | ) 277 | 278 | self.transformer = Transformer( 279 | width=transformer_width, 280 | layers=transformer_layers, 281 | heads=transformer_heads, 282 | attn_mask=self.build_attention_mask() 283 | ) 284 | 285 | self.vocab_size = vocab_size 286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width) 287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) 288 | 
self.ln_final = LayerNorm(transformer_width) 289 | 290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) 291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 292 | 293 | self.initialize_parameters() 294 | 295 | def initialize_parameters(self): 296 | nn.init.normal_(self.token_embedding.weight, std=0.02) 297 | nn.init.normal_(self.positional_embedding, std=0.01) 298 | 299 | if isinstance(self.visual, ModifiedResNet): 300 | if self.visual.attnpool is not None: 301 | std = self.visual.attnpool.c_proj.in_features ** -0.5 302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) 303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) 304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) 305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) 306 | 307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: 308 | for name, param in resnet_block.named_parameters(): 309 | if name.endswith("bn3.weight"): 310 | nn.init.zeros_(param) 311 | 312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) 313 | attn_std = self.transformer.width ** -0.5 314 | fc_std = (2 * self.transformer.width) ** -0.5 315 | for block in self.transformer.resblocks: 316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std) 317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std) 318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) 319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) 320 | 321 | if self.text_projection is not None: 322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) 323 | 324 | def build_attention_mask(self): 325 | # lazily create causal attention mask, with full attention between the vision tokens 326 | # pytorch uses additive attention mask; fill with -inf 327 | mask = torch.empty(self.context_length, self.context_length) 328 | mask.fill_(float("-inf")) 329 | mask.triu_(1) # zero out the lower diagonal 330 | return mask 331 | 332 | @property 333 | def dtype(self): 334 | return self.visual.conv1.weight.dtype 335 | 336 | def encode_image(self, image): 337 | return self.visual(image.type(self.dtype)) 338 | 339 | def encode_text(self, text): 340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] 341 | 342 | x = x + self.positional_embedding.type(self.dtype) 343 | x = x.permute(1, 0, 2) # NLD -> LND 344 | x = self.transformer(x) 345 | x = x.permute(1, 0, 2) # LND -> NLD 346 | x = self.ln_final(x).type(self.dtype) 347 | 348 | # x.shape = [batch_size, n_ctx, transformer.width] 349 | # take features from the eot embedding (eot_token is the highest number in each sequence) 350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection 351 | 352 | return x 353 | 354 | def forward(self, image, text): 355 | image_features = self.encode_image(image) 356 | text_features = self.encode_text(text) 357 | 358 | # normalized features 359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 361 | 362 | # cosine similarity as logits 363 | logit_scale = self.logit_scale.exp() 364 | logits_per_image = logit_scale * image_features @ text_features.t() 365 | logits_per_text = logit_scale * text_features @ image_features.t() 366 | 367 | # shape = [global_batch_size, global_batch_size] 368 | return logits_per_image, logits_per_text 369 | 370 | 371 | def 
convert_weights(model: nn.Module): 372 | """Convert applicable model parameters to fp16""" 373 | 374 | def _convert_weights_to_fp16(l): 375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 376 | l.weight.data = l.weight.data.half() 377 | if l.bias is not None: 378 | l.bias.data = l.bias.data.half() 379 | 380 | if isinstance(l, nn.MultiheadAttention): 381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 382 | tensor = getattr(l, attr) 383 | if tensor is not None: 384 | tensor.data = tensor.data.half() 385 | 386 | for name in ["text_projection", "proj"]: 387 | if hasattr(l, name): 388 | attr = getattr(l, name) 389 | if attr is not None: 390 | attr.data = attr.data.half() 391 | 392 | model.apply(_convert_weights_to_fp16) 393 | 394 | 395 | def build_model(state_dict: dict): 396 | vit = "visual.proj" in state_dict 397 | 398 | if vit: 399 | vision_width = state_dict["visual.conv1.weight"].shape[0] 400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) 401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] 402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) 403 | image_resolution = vision_patch_size * grid_size 404 | else: 405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] 406 | vision_layers = tuple(counts) 407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] 408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) 409 | vision_patch_size = None 410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] 411 | image_resolution = output_width * 32 412 | 413 | embed_dim = state_dict["text_projection"].shape[1] 414 | context_length = state_dict["positional_embedding"].shape[0] 415 | vocab_size = state_dict["token_embedding.weight"].shape[0] 416 | transformer_width = state_dict["ln_final.weight"].shape[0] 417 | transformer_heads = transformer_width // 64 418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks"))) 419 | 420 | model = CLIP( 421 | embed_dim, 422 | image_resolution, vision_layers, vision_width, vision_patch_size, 423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers 424 | ) 425 | 426 | for key in ["input_resolution", "context_length", "vocab_size"]: 427 | if key in state_dict: 428 | del state_dict[key] 429 | 430 | convert_weights(model) 431 | model.load_state_dict(state_dict) 432 | return model.eval() 433 | -------------------------------------------------------------------------------- /core/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8 + n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1: 49152 - 256 - 2 + 1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v + '' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + (token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token + '' 88 | 89 | while True: 90 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except Exception: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second: 106 | new_word.append(first + second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' '.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for 
bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /core/optimizer/__init__.py: -------------------------------------------------------------------------------- 1 | from core.optimizer.adamp import AdamP 2 | from core.optimizer.diffgrad import DiffGrad 3 | from core.optimizer.radam import RAdam 4 | 5 | __all__ = [ 6 | AdamP, 7 | DiffGrad, 8 | RAdam, 9 | ] 10 | -------------------------------------------------------------------------------- /core/optimizer/adamp.py: -------------------------------------------------------------------------------- 1 | # https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adamp.py 2 | import math 3 | 4 | import torch 5 | from torch.optim.optimizer import Optimizer 6 | 7 | 8 | class AdamP(Optimizer): 9 | r"""Implements AdamP algorithm. 10 | 11 | It has been proposed in `Slowing Down the Weight Norm Increase in 12 | Momentum-based Optimizers`__ 13 | 14 | Arguments: 15 | params: iterable of parameters to optimize or dicts defining 16 | parameter groups 17 | lr: learning rate (default: 1e-3) 18 | betas: coefficients used for computing 19 | running averages of gradient and its square (default: (0.9, 0.999)) 20 | eps: term added to the denominator to improve 21 | numerical stability (default: 1e-8) 22 | weight_decay: weight decay (L2 penalty) (default: 0) 23 | delta: threhold that determines whether a set of parameters is scale 24 | invariant or not (default: 0.1) 25 | wd_ratio: relative weight decay applied on scale-invariant parameters 26 | compared to that applied on scale-variant parameters (default: 0.1) 27 | nesterov: enables Nesterov momentum (default: False) 28 | 29 | 30 | Example: 31 | >>> import torch_optimizer as optim 32 | >>> optimizer = optim.AdamP(model.parameters(), lr=0.1) 33 | >>> optimizer.zero_grad() 34 | >>> loss_fn(model(input), target).backward() 35 | >>> optimizer.step() 36 | 37 | __ https://arxiv.org/abs/2006.08217 38 | 39 | Note: 40 | Reference code: https://github.com/clovaai/AdamP 41 | """ 42 | 43 | def __init__( 44 | self, 45 | params, 46 | lr: float = 1e-3, 47 | betas=(0.9, 0.999), 48 | eps: float = 1e-8, 49 | weight_decay: float = 0, 50 | delta: float = 0.1, 51 | wd_ratio: float = 0.1, 52 | nesterov: bool = False, 53 | ) -> None: 54 | if lr <= 0.0: 55 | raise ValueError('Invalid learning rate: {}'.format(lr)) 56 | if eps < 0.0: 57 | raise ValueError('Invalid epsilon value: {}'.format(eps)) 58 | if not 0.0 <= betas[0] < 1.0: 59 | raise ValueError( 60 | 'Invalid beta parameter at index 0: {}'.format(betas[0]) 61 | ) 62 | if not 0.0 <= betas[1] < 1.0: 63 | raise ValueError( 64 | 'Invalid beta parameter at index 1: {}'.format(betas[1]) 65 | ) 66 | if weight_decay < 0: 67 | raise ValueError( 68 | 'Invalid weight_decay value: {}'.format(weight_decay) 69 | ) 70 | if delta < 0: 71 | raise ValueError('Invalid delta value: {}'.format(delta)) 72 | if wd_ratio < 0: 73 | raise ValueError('Invalid wd_ratio value: {}'.format(wd_ratio)) 74 | 75 | defaults = dict( 76 | lr=lr, 77 | betas=betas, 78 | eps=eps, 79 | weight_decay=weight_decay, 80 | delta=delta, 81 | wd_ratio=wd_ratio, 82 | nesterov=nesterov, 83 | ) 84 | super(AdamP, self).__init__(params, defaults) 85 | 86 | @staticmethod 87 | def 
_channel_view(x): 88 | return x.view(x.size(0), -1) 89 | 90 | @staticmethod 91 | def _layer_view(x): 92 | return x.view(1, -1) 93 | 94 | @staticmethod 95 | def _cosine_similarity(x, y, eps, view_func): 96 | x = view_func(x) 97 | y = view_func(y) 98 | 99 | x_norm = x.norm(dim=1).add_(eps) 100 | y_norm = y.norm(dim=1).add_(eps) 101 | dot = (x * y).sum(dim=1) 102 | 103 | return dot.abs() / x_norm / y_norm 104 | 105 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps): 106 | wd = 1 107 | expand_size = [-1] + [1] * (len(p.shape) - 1) 108 | for view_func in [self._channel_view, self._layer_view]: 109 | 110 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func) 111 | 112 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)): 113 | p_n = p.data / view_func(p.data).norm(dim=1).view( 114 | expand_size 115 | ).add_(eps) 116 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view( 117 | expand_size 118 | ) 119 | wd = wd_ratio 120 | 121 | return perturb, wd 122 | 123 | return perturb, wd 124 | 125 | def step(self, closure=None): 126 | r"""Performs a single optimization step. 127 | 128 | Arguments: 129 | closure: A closure that reevaluates the model and returns the loss. 130 | """ 131 | loss = None 132 | if closure is not None: 133 | loss = closure() 134 | 135 | for group in self.param_groups: 136 | for p in group['params']: 137 | if p.grad is None: 138 | continue 139 | 140 | grad = p.grad.data 141 | beta1, beta2 = group['betas'] 142 | nesterov = group['nesterov'] 143 | 144 | state = self.state[p] 145 | 146 | # State initialization 147 | if len(state) == 0: 148 | state['step'] = 0 149 | state['exp_avg'] = torch.zeros_like( 150 | p.data, memory_format=torch.preserve_format 151 | ) 152 | state['exp_avg_sq'] = torch.zeros_like( 153 | p.data, memory_format=torch.preserve_format 154 | ) 155 | 156 | # Adam 157 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 158 | 159 | state['step'] += 1 160 | bias_correction1 = 1 - beta1 ** state['step'] 161 | bias_correction2 = 1 - beta2 ** state['step'] 162 | 163 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 164 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 165 | 166 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_( 167 | group['eps'] 168 | ) 169 | step_size = group['lr'] / bias_correction1 170 | 171 | if nesterov: 172 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom 173 | else: 174 | perturb = exp_avg / denom 175 | 176 | # Projection 177 | wd_ratio = 1 178 | if len(p.shape) > 1: 179 | perturb, wd_ratio = self._projection( 180 | p, 181 | grad, 182 | perturb, 183 | group['delta'], 184 | group['wd_ratio'], 185 | group['eps'], 186 | ) 187 | 188 | # Weight decay 189 | if group['weight_decay'] > 0: 190 | p.data.mul_( 191 | 1 - group['lr'] * group['weight_decay'] * wd_ratio 192 | ) 193 | 194 | # Step 195 | p.data.add_(perturb, alpha=-step_size) 196 | 197 | return loss 198 | -------------------------------------------------------------------------------- /core/optimizer/diffgrad.py: -------------------------------------------------------------------------------- 1 | # https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/diffgrad.py 2 | import math 3 | 4 | import torch 5 | from torch.optim.optimizer import Optimizer 6 | 7 | 8 | class DiffGrad(Optimizer): 9 | r"""Implements DiffGrad algorithm. 10 | 11 | It has been proposed in `DiffGrad: An Optimization Method for 12 | Convolutional Neural Networks`__. 
13 | 14 | Arguments: 15 | params: iterable of parameters to optimize or dicts defining 16 | parameter groups 17 | lr: learning rate (default: 1e-3) 18 | betas: coefficients used for computing 19 | running averages of gradient and its square (default: (0.9, 0.999)) 20 | eps: term added to the denominator to improve 21 | numerical stability (default: 1e-8) 22 | weight_decay: weight decay (L2 penalty) (default: 0) 23 | 24 | Example: 25 | >>> import torch_optimizer as optim 26 | >>> optimizer = optim.DiffGrad(model.parameters(), lr=0.1) 27 | >>> optimizer.zero_grad() 28 | >>> loss_fn(model(input), target).backward() 29 | >>> optimizer.step() 30 | 31 | __ https://arxiv.org/abs/1909.11015 32 | 33 | Note: 34 | Reference code: https://github.com/shivram1987/diffGrad 35 | """ 36 | 37 | def __init__( 38 | self, 39 | params, 40 | lr: float = 1e-3, 41 | betas=(0.9, 0.999), 42 | eps: float = 1e-8, 43 | weight_decay: float = 0.0, 44 | ) -> None: 45 | if lr <= 0.0: 46 | raise ValueError('Invalid learning rate: {}'.format(lr)) 47 | if eps < 0.0: 48 | raise ValueError('Invalid epsilon value: {}'.format(eps)) 49 | if not 0.0 <= betas[0] < 1.0: 50 | raise ValueError( 51 | 'Invalid beta parameter at index 0: {}'.format(betas[0]) 52 | ) 53 | if not 0.0 <= betas[1] < 1.0: 54 | raise ValueError( 55 | 'Invalid beta parameter at index 1: {}'.format(betas[1]) 56 | ) 57 | if weight_decay < 0.0: 58 | raise ValueError( 59 | 'Invalid weight_decay value: {}'.format(weight_decay) 60 | ) 61 | 62 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay) 63 | super(DiffGrad, self).__init__(params, defaults) 64 | 65 | def step(self, closure=None): 66 | r"""Performs a single optimization step. 67 | 68 | Arguments: 69 | closure: A closure that reevaluates the model and returns the loss. 
70 | """ 71 | loss = None 72 | if closure is not None: 73 | loss = closure() 74 | 75 | for group in self.param_groups: 76 | beta1, beta2 = group['betas'] 77 | 78 | for p in group['params']: 79 | if p.grad is None: 80 | continue 81 | grad = p.grad.data 82 | if grad.is_sparse: 83 | msg = ( 84 | 'DiffGrad does not support sparse gradients, ' 85 | 'please consider SparseAdam instead' 86 | ) 87 | raise RuntimeError(msg) 88 | 89 | state = self.state[p] 90 | 91 | # State initialization 92 | if len(state) == 0: 93 | state['step'] = 0 94 | # Exponential moving average of gradient values 95 | state['exp_avg'] = torch.zeros_like( 96 | p, memory_format=torch.preserve_format 97 | ) 98 | # Exponential moving average of squared gradient values 99 | state['exp_avg_sq'] = torch.zeros_like( 100 | p, memory_format=torch.preserve_format 101 | ) 102 | # Previous gradient 103 | state['previous_grad'] = torch.zeros_like( 104 | p, memory_format=torch.preserve_format 105 | ) 106 | 107 | exp_avg, exp_avg_sq, previous_grad = ( 108 | state['exp_avg'], 109 | state['exp_avg_sq'], 110 | state['previous_grad'], 111 | ) 112 | 113 | state['step'] += 1 114 | 115 | if group['weight_decay'] != 0: 116 | grad.add_(p.data, alpha=group['weight_decay']) 117 | 118 | # Decay the first and second moment running average coefficient 119 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 120 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 121 | denom = exp_avg_sq.sqrt().add_(group['eps']) 122 | 123 | bias_correction1 = 1 - beta1 ** state['step'] 124 | bias_correction2 = 1 - beta2 ** state['step'] 125 | 126 | # compute diffgrad coefficient (dfc) 127 | diff = torch.abs(previous_grad - grad) 128 | dfc = torch.div(1.0, (1.0 + torch.exp(-diff))) 129 | state['previous_grad'] = grad.clone() 130 | 131 | # update momentum with dfc 132 | exp_avg1 = exp_avg * dfc 133 | 134 | step_size = ( 135 | group['lr'] 136 | * math.sqrt(bias_correction2) 137 | / bias_correction1 138 | ) 139 | 140 | p.data.addcdiv_(exp_avg1, denom, value=-step_size) 141 | 142 | return loss 143 | -------------------------------------------------------------------------------- /core/optimizer/radam.py: -------------------------------------------------------------------------------- 1 | # https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/radam.py 2 | import math 3 | 4 | import torch 5 | from torch.optim.optimizer import Optimizer 6 | 7 | 8 | class RAdam(Optimizer): 9 | r"""Implements RAdam optimization algorithm. 10 | 11 | It has been proposed in `On the Variance of the Adaptive Learning 12 | Rate and Beyond`__. 
13 | 14 | Arguments: 15 | params: iterable of parameters to optimize or dicts defining 16 | parameter groups 17 | lr: learning rate (default: 1e-3) 18 | betas: coefficients used for computing 19 | running averages of gradient and its square (default: (0.9, 0.999)) 20 | eps: term added to the denominator to improve 21 | numerical stability (default: 1e-8) 22 | weight_decay: weight decay (L2 penalty) (default: 0) 23 | 24 | Example: 25 | >>> import torch_optimizer as optim 26 | >>> optimizer = optim.RAdam(model.parameters(), lr=0.1) 27 | >>> optimizer.zero_grad() 28 | >>> loss_fn(model(input), target).backward() 29 | >>> optimizer.step() 30 | 31 | __ https://arxiv.org/abs/1908.03265 32 | 33 | Note: 34 | Reference code: https://github.com/LiyuanLucasLiu/RAdam 35 | """ 36 | 37 | def __init__( 38 | self, 39 | params, 40 | lr: float = 1e-3, 41 | betas=(0.9, 0.999), 42 | eps: float = 1e-8, 43 | weight_decay: float = 0, 44 | ) -> None: 45 | if lr <= 0.0: 46 | raise ValueError('Invalid learning rate: {}'.format(lr)) 47 | if eps < 0.0: 48 | raise ValueError('Invalid epsilon value: {}'.format(eps)) 49 | if not 0.0 <= betas[0] < 1.0: 50 | raise ValueError( 51 | 'Invalid beta parameter at index 0: {}'.format(betas[0]) 52 | ) 53 | if not 0.0 <= betas[1] < 1.0: 54 | raise ValueError( 55 | 'Invalid beta parameter at index 1: {}'.format(betas[1]) 56 | ) 57 | if weight_decay < 0: 58 | raise ValueError( 59 | 'Invalid weight_decay value: {}'.format(weight_decay) 60 | ) 61 | 62 | if ( 63 | isinstance(params, (list, tuple)) 64 | and len(params) > 0 65 | and isinstance(params[0], dict) 66 | ): 67 | for param in params: 68 | if 'betas' in param and ( 69 | param['betas'][0] != betas[0] 70 | or param['betas'][1] != betas[1] 71 | ): 72 | param['buffer'] = [[None, None, None] for _ in range(10)] 73 | 74 | defaults = dict( 75 | lr=lr, 76 | betas=betas, 77 | eps=eps, 78 | weight_decay=weight_decay, 79 | buffer=[[None, None, None] for _ in range(10)], 80 | ) 81 | super(RAdam, self).__init__(params, defaults) 82 | 83 | def __setstate__(self, state): 84 | super(RAdam, self).__setstate__(state) 85 | 86 | def step(self, closure=None): 87 | r"""Performs a single optimization step. 88 | 89 | Arguments: 90 | closure: A closure that reevaluates the model and returns the loss. 
91 | """ 92 | 93 | loss = None 94 | if closure is not None: 95 | loss = closure() 96 | 97 | for group in self.param_groups: 98 | lr = group['lr'] 99 | weight_decay = group['weight_decay'] 100 | beta1, beta2 = group['betas'] 101 | eps = group['eps'] 102 | 103 | for p in group['params']: 104 | if p.grad is None: 105 | continue 106 | grad = p.grad.data.float() 107 | if grad.is_sparse: 108 | msg = ( 109 | 'RAdam does not support sparse gradients, ' 110 | 'please consider SparseAdam instead' 111 | ) 112 | raise RuntimeError(msg) 113 | 114 | p_data_fp32 = p.data.float() 115 | 116 | state = self.state[p] 117 | 118 | if len(state) == 0: 119 | state['step'] = 0 120 | state['exp_avg'] = torch.zeros_like( 121 | p_data_fp32, memory_format=torch.preserve_format 122 | ) 123 | state['exp_avg_sq'] = torch.zeros_like( 124 | p_data_fp32, memory_format=torch.preserve_format 125 | ) 126 | else: 127 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) 128 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as( 129 | p_data_fp32 130 | ) 131 | 132 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] 133 | 134 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) 135 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) 136 | 137 | state['step'] += 1 138 | buffered = group['buffer'][int(state['step'] % 10)] 139 | if state['step'] == buffered[0]: 140 | N_sma, step_size = buffered[1], buffered[2] 141 | else: 142 | buffered[0] = state['step'] 143 | beta2_t = beta2 ** state['step'] 144 | N_sma_max = 2 / (1 - beta2) - 1 145 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / ( 146 | 1 - beta2_t 147 | ) 148 | buffered[1] = N_sma 149 | 150 | # more conservative since it's an approximated value 151 | if N_sma >= 5: 152 | step_size = ( 153 | lr 154 | * math.sqrt( 155 | (1 - beta2_t) 156 | * (N_sma - 4) 157 | / (N_sma_max - 4) 158 | * (N_sma - 2) 159 | / N_sma 160 | * N_sma_max 161 | / (N_sma_max - 2) 162 | ) 163 | / (1 - beta1 ** state['step']) 164 | ) 165 | else: 166 | step_size = lr / (1 - beta1 ** state['step']) 167 | buffered[2] = step_size 168 | 169 | if weight_decay != 0: 170 | p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr) 171 | 172 | # more conservative since it's an approximated value 173 | if N_sma >= 5: 174 | denom = exp_avg_sq.sqrt().add_(eps) 175 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) 176 | else: 177 | p_data_fp32.add_(exp_avg, alpha=-step_size) 178 | 179 | p.data.copy_(p_data_fp32) 180 | 181 | return loss 182 | -------------------------------------------------------------------------------- /core/schemas/__init__.py: -------------------------------------------------------------------------------- 1 | from core.schemas.config import Config 2 | from core.schemas.train_config import TrainConfig 3 | 4 | __all__ = [ 5 | Config, 6 | TrainConfig, 7 | ] 8 | -------------------------------------------------------------------------------- /core/schemas/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from core.clip.clip import available_models 4 | 5 | from typing import List 6 | from dataclasses import dataclass, field 7 | 8 | 9 | INIT_NOISES = ['', 'gradient', 'pixels', 'fractal'] 10 | OPTIMIZERS = ['Adam', 'AdamW', 'Adagrad', 'Adamax', 'DiffGrad', 'AdamP', 'RAdam'] 11 | AUGMENTS = ['Ji', 'Sh', 'Gn', 'Pe', 'Ro', 'Af', 'Et', 'Ts', 'Cr', 'Er', 'Re', 'Hf'] 12 | 13 | 14 | @dataclass 15 | class Config: 16 | prompts: List[str] = field(default_factory=lambda: []) 17 | image_prompts: List[str] = 
field(default_factory=lambda: []) 18 | max_iterations: int = 500 19 | save_freq: int = 50 20 | size: List[int] = field(default_factory=lambda: [256, 256]) 21 | pixelart: List[int] = None 22 | init_image: str = "" 23 | init_noise: str = "gradient" 24 | init_weight: float = 0.0 25 | mse_decay_rate: float = 0.0 26 | output_dir: str = "./outputs" 27 | models_dir: str = "./models" 28 | clip_model: str = 'ViT-B/16' 29 | vqgan_checkpoint: str = './models/vqgan_imagenet_f16_16384.ckpt' 30 | vqgan_config: str = './configs/models/vqgan_imagenet_f16_16384.json' 31 | noise_prompt_seeds: List[int] = field(default_factory=lambda: []) 32 | noise_prompt_weights: List[float] = field(default_factory=lambda: []) 33 | step_size: float = 0.1 34 | cutn: int = 32 35 | cut_pow: float = 1.0 36 | seed: int = -1 37 | optimizer: str = 'Adam' 38 | nwarm_restarts: int = -1 39 | augments: List[str] = field(default_factory=lambda: ['Af', 'Pe', 'Ji', 'Er']) 40 | 41 | def __post_init__(self): 42 | if self.init_noise not in INIT_NOISES: 43 | exit(f"ERROR: \"init_noise\": {self.init_noise}, <-- Noise algorithm not found.\n" 44 | f"Currently only the following values are supported: {INIT_NOISES}.") 45 | 46 | if self.optimizer not in OPTIMIZERS: 47 | exit(f"ERROR: \"optimizer\": {self.optimizer}, <-- Optimizer not found.\n" 48 | f"Currently only the following values are supported: {OPTIMIZERS}.") 49 | 50 | os.makedirs(self.models_dir, exist_ok=True) 51 | os.makedirs(self.output_dir, exist_ok=True) 52 | os.makedirs(f"{self.output_dir}/steps", exist_ok=True) 53 | print(f"Saving outputs in '{self.output_dir}'") 54 | 55 | models = available_models() 56 | if not os.path.exists(self.clip_model) and self.clip_model not in models: 57 | exit(f"ERROR: \"clip_model\": {self.clip_model}, <-- Model not found.\n" 58 | f"Make sure it is a valid path to a downloaded model or match one of {models}.") 59 | 60 | if not os.path.exists(self.vqgan_config): 61 | exit(f"ERROR: \"vqgan_config\": {self.vqgan_config}, <-- Configuration file not found.\n" 62 | f"Make sure the path is correct (Multiple config files are available in the `./configs/models` directory).") 63 | 64 | if not os.path.exists(self.vqgan_checkpoint): 65 | exit(f"ERROR: \"vqgan_checkpoint\": {self.vqgan_checkpoint}, <-- Model not found.\n" 66 | f"Make sure the path is correct and that you have downloaded the model (Refer to the README).") 67 | 68 | if self.pixelart: 69 | print("Enabling PixelArt mode. 
It is recommended to add 'pixelart' to your prompt.") 70 | 71 | 72 | def __str__(self): 73 | _str = ( 74 | f"Config:\n" 75 | f" - prompts: {self.prompts}\n" 76 | f" - image_prompts: {self.image_prompts}\n" 77 | f" - max_iterations: {self.max_iterations}\n" 78 | f" - save_freq: {self.save_freq}\n" 79 | f" - size: {self.size}\n" 80 | f" - pixelart: {self.pixelart}\n" 81 | f" - init_image: {self.init_image}\n" 82 | f" - init_noise: {self.init_noise}\n" 83 | f" - init_weight: {self.init_weight}\n" 84 | f" - mse_decay_rate: {self.mse_decay_rate}\n" 85 | f" - output_dir: {self.output_dir}\n" 86 | f" - models_dir: {self.models_dir}\n" 87 | f" - clip_model: {self.clip_model}\n" 88 | f" - vqgan_checkpoint: {self.vqgan_checkpoint}\n" 89 | f" - vqgan_config: {self.vqgan_config}\n" 90 | f" - noise_prompt_seeds: {self.noise_prompt_seeds}\n" 91 | f" - noise_prompt_weights: {self.noise_prompt_weights}\n" 92 | f" - step_size: {self.step_size}\n" 93 | f" - cutn: {self.cutn}\n" 94 | f" - cut_pow: {self.cut_pow}\n" 95 | f" - seed: {self.seed}\n" 96 | f" - optimizer: {self.optimizer}\n" 97 | f" - nwarm_restarts: {self.nwarm_restarts}\n" 98 | f" - augments: {self.augments}\n" 99 | ) 100 | return _str 101 | -------------------------------------------------------------------------------- /core/schemas/train_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dataclasses import dataclass 4 | 5 | 6 | @dataclass 7 | class TrainConfig: 8 | base_learning_rate: float = 4.5e-6 9 | batch_size: int = 1 10 | epochs: int = 1000 11 | data_dir: str = "./data" 12 | output_dir: str = "./outputs" 13 | models_dir: str = "./models" 14 | resume_checkpoint: str = "" 15 | seed: int = -1 16 | params: dict = None 17 | 18 | def __post_init__(self): 19 | if not os.path.exists(self.data_dir): 20 | exit(f"ERROR: \"data_dir\": {self.data_dir}, <-- Data directory not found.\n" 21 | f"Make sure the path is correct (Follow instructions in the README).") 22 | 23 | ckpt_dir = os.path.join(self.models_dir, "checkpoints") 24 | os.makedirs(ckpt_dir, exist_ok=True) 25 | print(f"Checkpoints will be saved in {ckpt_dir}") 26 | 27 | train_dir = os.path.join(self.output_dir, "training") 28 | os.makedirs(train_dir, exist_ok=True) 29 | print(f"Training outputs will be saved in {train_dir}") 30 | 31 | if self.resume_checkpoint and not os.path.exists(self.resume_checkpoint): 32 | exit(f"ERROR: \"resume_checkpoint\": {self.resume_checkpoint}, <-- Model not found.\n" 33 | f"Make sure the path is correct (Follow instructions in the README).") 34 | 35 | def __str__(self): 36 | _str = ( 37 | f"Config:\n" 38 | f" - base_learning_rate: {self.base_learning_rate}\n" 39 | f" - batch_size: {self.batch_size}\n" 40 | f" - epochs: {self.epochs}\n" 41 | f" - data_dir: {self.data_dir}\n" 42 | f" - output_dir: {self.output_dir}\n" 43 | f" - models_dir: {self.models_dir}\n" 44 | f" - resume_checkpoint: {self.resume_checkpoint}\n" 45 | f" - seed: {self.seed}\n" 46 | f" - params: {self.params}\n" 47 | ) 48 | return _str 49 | -------------------------------------------------------------------------------- /core/taming/README.md: -------------------------------------------------------------------------------- 1 | # Taming Transformers for High-Resolution Image Synthesis 2 | 3 | [[Original]](https://github.com/CompVis/taming-transformers) 4 | 5 | ## About 6 | 7 | A stripped & minimalist version of the original project. 
8 | -------------------------------------------------------------------------------- /core/taming/models/__init__.py: -------------------------------------------------------------------------------- 1 | from core.taming.models.vqgan import VQModel 2 | 3 | 4 | __all__ = [ 5 | VQModel 6 | ] 7 | -------------------------------------------------------------------------------- /core/taming/models/vqgan.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from core.taming.modules.diffusion import Encoder, Decoder 5 | from core.taming.modules.vqvae import VectorQuantizer 6 | from core.taming.modules.losses import VQLPIPSWithDiscriminator, DummyLoss 7 | 8 | from core.utils.loader import safe_load 9 | 10 | 11 | class VQModel(nn.Module): 12 | def __init__(self, 13 | ddconfig, 14 | n_embed, 15 | embed_dim, 16 | lossconfig=None, 17 | ckpt_path=None, 18 | model_dir=None, 19 | ignore_keys=[], 20 | image_key="image", 21 | colorize_nlabels=None, 22 | monitor=None, 23 | remap=None, 24 | sane_index_shape=False, # tell vector quantizer to return indices as bhw 25 | ): 26 | super().__init__() 27 | self.image_key = image_key 28 | 29 | self.encoder = Encoder(**ddconfig) 30 | self.decoder = Decoder(**ddconfig) 31 | 32 | self.loss = DummyLoss() 33 | if lossconfig is not None: 34 | self.loss = VQLPIPSWithDiscriminator(model_dir=model_dir, **lossconfig["params"]) 35 | 36 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, 37 | remap=remap, sane_index_shape=sane_index_shape) 38 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) 39 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) 40 | 41 | if ckpt_path is not None: 42 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) 43 | 44 | self.image_key = image_key 45 | 46 | if colorize_nlabels is not None: 47 | assert type(colorize_nlabels) == int 48 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) 49 | if monitor is not None: 50 | self.monitor = monitor 51 | 52 | def init_from_ckpt(self, path, ignore_keys=list()): 53 | try: 54 | sd = torch.load(path, map_location="cpu")["state_dict"] 55 | except Exception: 56 | sd = safe_load(path, map_location="cpu")["state_dict"] 57 | 58 | keys = list(sd.keys()) 59 | for k in keys: 60 | for ik in ignore_keys: 61 | if k.startswith(ik): 62 | print("Deleting key {} from state_dict.".format(k)) 63 | del sd[k] 64 | 65 | if "first_stage_model.encoder.conv_in.weight" in sd: 66 | stripped_state_dict = {} 67 | for key in sd: 68 | if key.startswith("first_stage_model."): 69 | stripped_state_dict[key[18:]] = sd[key] 70 | sd = stripped_state_dict 71 | 72 | self.load_state_dict(sd, strict=False) 73 | print(f"Restored from {path}") 74 | 75 | def encode(self, x): 76 | h = self.encoder(x) 77 | h = self.quant_conv(h) 78 | quant, emb_loss, info = self.quantize(h) 79 | return quant, emb_loss, info 80 | 81 | def decode(self, quant): 82 | quant = self.post_quant_conv(quant) 83 | dec = self.decoder(quant) 84 | return dec 85 | 86 | def decode_code(self, code_b): 87 | quant_b = self.quantize.embed_code(code_b) 88 | dec = self.decode(quant_b) 89 | return dec 90 | 91 | def forward(self, input): 92 | quant, diff, _ = self.encode(input) 93 | dec = self.decode(quant) 94 | return dec, diff 95 | 96 | def get_input(self, batch, device): 97 | x = batch 98 | if len(x.shape) == 3: 99 | x = x[..., None] 100 | x = x.to(device, memory_format=torch.contiguous_format) 101 | return x.float() 102 | 103 
| def training_step(self, batch, batch_idx, optimizer_idx, device='cpu'): 104 | x = self.get_input(batch, device) 105 | xrec, qloss = self(x) 106 | 107 | if optimizer_idx == 0: 108 | # autoencode 109 | aeloss = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train") 110 | return aeloss 111 | 112 | if optimizer_idx == 1: 113 | # discriminator 114 | discloss = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train") 115 | return discloss 116 | 117 | def configure_optimizers(self): 118 | lr = self.learning_rate 119 | opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + 120 | list(self.decoder.parameters()) + 121 | list(self.quantize.parameters()) + 122 | list(self.quant_conv.parameters()) + 123 | list(self.post_quant_conv.parameters()), 124 | lr=lr, betas=(0.5, 0.9)) 125 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), 126 | lr=lr, betas=(0.5, 0.9)) 127 | return [opt_ae, opt_disc], [] 128 | 129 | def get_last_layer(self): 130 | return self.decoder.conv_out.weight 131 | -------------------------------------------------------------------------------- /core/taming/modules/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | from core.taming.modules.diffusion.attn_block import AttnBlock 2 | from core.taming.modules.diffusion.resnet_block import ResnetBlock 3 | 4 | from core.taming.modules.diffusion.downsample import Downsample 5 | from core.taming.modules.diffusion.upsample import Upsample 6 | 7 | from core.taming.modules.diffusion.encoder import Encoder 8 | from core.taming.modules.diffusion.decoder import Decoder 9 | 10 | 11 | __all__ = [ 12 | AttnBlock, 13 | ResnetBlock, 14 | Downsample, 15 | Upsample, 16 | Encoder, 17 | Decoder, 18 | ] 19 | -------------------------------------------------------------------------------- /core/taming/modules/diffusion/attn_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from core.taming.utils import Normalize 5 | 6 | 7 | class AttnBlock(nn.Module): 8 | def __init__(self, in_channels): 9 | super().__init__() 10 | self.in_channels = in_channels 11 | 12 | self.norm = Normalize(in_channels) 13 | self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 14 | self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 15 | self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 16 | self.proj_out = torch.nn.Conv2d( 17 | in_channels, in_channels, kernel_size=1, stride=1, padding=0 18 | ) 19 | 20 | def forward(self, x): 21 | h_ = x 22 | h_ = self.norm(h_) 23 | q = self.q(h_) 24 | k = self.k(h_) 25 | v = self.v(h_) 26 | 27 | # compute attention 28 | b, c, h, w = q.shape 29 | q = q.reshape(b, c, h * w) 30 | q = q.permute(0, 2, 1) # b, hw, c 31 | k = k.reshape(b, c, h * w) # b, c, hw 32 | w_ = torch.bmm(q, k) # b, hw, hw w[b, i, j]=sum_c q[b, i, c]k[b, c, j] 33 | w_ = w_ * (int(c)**(-0.5)) 34 | w_ = torch.nn.functional.softmax(w_, dim=2) 35 | 36 | # attend to values 37 | v = v.reshape(b, c, h * w) 38 | w_ = w_.permute(0, 2, 1) # b, hw, hw (first hw of k, second of q) 39 | h_ = torch.bmm(v, w_) # b, c, hw (hw of q) h_[b, c, j] = sum_i v[b, c, i] w_[b, i, j] 40 | h_ = h_.reshape(b, c, h, w) 41 | 42 | h_ = self.proj_out(h_) 43 | 44 | return x + h_ 45 | 
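For orientation, `AttnBlock` above is shape-preserving single-head self-attention over the spatial grid: the feature map is flattened into h*w positions, attended, and added back to the input as a residual. A minimal smoke test is sketched below; the channel count and spatial size are arbitrary illustration values, not taken from this repository's configs.

```python
import torch

from core.taming.modules.diffusion import AttnBlock

attn = AttnBlock(in_channels=64)   # Normalize() is GroupNorm(32, ...), so channels must be divisible by 32
x = torch.randn(2, 64, 16, 16)     # (batch, channels, height, width)
y = attn(x)                        # attention over the 16 * 16 spatial positions
assert y.shape == x.shape          # the residual connection keeps the shape unchanged
```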
-------------------------------------------------------------------------------- /core/taming/modules/diffusion/decoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | 6 | from core.taming.utils import Normalize, nonlinearity 7 | 8 | from core.taming.modules.diffusion import AttnBlock, ResnetBlock, Upsample 9 | 10 | 11 | class Decoder(nn.Module): 12 | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, 13 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 14 | resolution, z_channels, give_pre_end=False, **ignorekwargs): 15 | super().__init__() 16 | self.ch = ch 17 | self.temb_ch = 0 18 | self.num_resolutions = len(ch_mult) 19 | self.num_res_blocks = num_res_blocks 20 | self.resolution = resolution 21 | self.in_channels = in_channels 22 | self.give_pre_end = give_pre_end 23 | 24 | # compute in_ch_mult, block_in and curr_res at lowest res 25 | # in_ch_mult = (1,)+tuple(ch_mult) 26 | block_in = ch * ch_mult[self.num_resolutions - 1] 27 | curr_res = resolution // 2**(self.num_resolutions - 1) 28 | self.z_shape = (1, z_channels, curr_res, curr_res) 29 | print("Working with z of shape {} = {} dimensions.".format( 30 | self.z_shape, np.prod(self.z_shape))) 31 | 32 | # z to block_in 33 | self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1) 34 | 35 | # middle 36 | self.mid = nn.Module() 37 | self.mid.block_1 = ResnetBlock( 38 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout 39 | ) 40 | self.mid.attn_1 = AttnBlock(block_in) 41 | self.mid.block_2 = ResnetBlock( 42 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout 43 | ) 44 | 45 | # upsampling 46 | self.up = nn.ModuleList() 47 | for i_level in reversed(range(self.num_resolutions)): 48 | block = nn.ModuleList() 49 | attn = nn.ModuleList() 50 | block_out = ch * ch_mult[i_level] 51 | for i_block in range(self.num_res_blocks + 1): 52 | block.append(ResnetBlock(in_channels=block_in, 53 | out_channels=block_out, 54 | temb_channels=self.temb_ch, 55 | dropout=dropout)) 56 | block_in = block_out 57 | if curr_res in attn_resolutions: 58 | attn.append(AttnBlock(block_in)) 59 | up = nn.Module() 60 | up.block = block 61 | up.attn = attn 62 | if i_level != 0: 63 | up.upsample = Upsample(block_in, resamp_with_conv) 64 | curr_res = curr_res * 2 65 | self.up.insert(0, up) # prepend to get consistent order 66 | 67 | # end 68 | self.norm_out = Normalize(block_in) 69 | self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) 70 | 71 | def forward(self, z): 72 | # assert z.shape[1:] == self.z_shape[1:] 73 | self.last_z_shape = z.shape 74 | 75 | # timestep embedding 76 | temb = None 77 | 78 | # z to block_in 79 | h = self.conv_in(z) 80 | 81 | # middle 82 | h = self.mid.block_1(h, temb) 83 | h = self.mid.attn_1(h) 84 | h = self.mid.block_2(h, temb) 85 | 86 | # upsampling 87 | for i_level in reversed(range(self.num_resolutions)): 88 | for i_block in range(self.num_res_blocks + 1): 89 | h = self.up[i_level].block[i_block](h, temb) 90 | if len(self.up[i_level].attn) > 0: 91 | h = self.up[i_level].attn[i_block](h) 92 | if i_level != 0: 93 | h = self.up[i_level].upsample(h) 94 | 95 | # end 96 | if self.give_pre_end: 97 | return h 98 | 99 | h = self.norm_out(h) 100 | h = nonlinearity(h) 101 | h = self.conv_out(h) 102 | return h 103 | 
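`Decoder` mirrors the encoder: it projects the quantized latent up to the widest channel count, runs the mid block, then upsamples `len(ch_mult) - 1` times back to pixel space. The sketch below uses deliberately tiny, made-up hyperparameters (the real values live in `configs/models/*.json`), purely to show the shape contract.

```python
import torch

from core.taming.modules.diffusion import Decoder

# Two resolution levels -> a single 2x upsample; 64 and 128 channels keep GroupNorm(32, ...) happy.
decoder = Decoder(ch=64, out_ch=3, ch_mult=(1, 2), num_res_blocks=1,
                  attn_resolutions=[16], in_channels=3, resolution=32, z_channels=32)

z = torch.randn(1, 32, 16, 16)   # latent at resolution // 2**(len(ch_mult) - 1)
img = decoder(z)
print(img.shape)                 # torch.Size([1, 3, 32, 32])
```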
-------------------------------------------------------------------------------- /core/taming/modules/diffusion/downsample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Downsample(nn.Module): 6 | def __init__(self, in_channels, with_conv): 7 | super().__init__() 8 | self.with_conv = with_conv 9 | if self.with_conv: 10 | # no asymmetric padding in torch conv, must do it ourselves 11 | self.conv = torch.nn.Conv2d( 12 | in_channels, in_channels, kernel_size=3, stride=2, padding=0 13 | ) 14 | 15 | def forward(self, x): 16 | if self.with_conv: 17 | pad = (0, 1, 0, 1) 18 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0) 19 | x = self.conv(x) 20 | else: 21 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) 22 | return x 23 | -------------------------------------------------------------------------------- /core/taming/modules/diffusion/encoder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from core.taming.utils import Normalize, nonlinearity 5 | 6 | from core.taming.modules.diffusion import AttnBlock, ResnetBlock, Downsample 7 | 8 | 9 | class Encoder(nn.Module): 10 | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks, 11 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, 12 | resolution, z_channels, double_z=True, **ignore_kwargs): 13 | super().__init__() 14 | self.ch = ch 15 | self.temb_ch = 0 16 | self.num_resolutions = len(ch_mult) 17 | self.num_res_blocks = num_res_blocks 18 | self.resolution = resolution 19 | self.in_channels = in_channels 20 | 21 | # downsampling 22 | self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1) 23 | 24 | curr_res = resolution 25 | in_ch_mult = (1,) + tuple(ch_mult) 26 | self.down = nn.ModuleList() 27 | for i_level in range(self.num_resolutions): 28 | block = nn.ModuleList() 29 | attn = nn.ModuleList() 30 | block_in = ch * in_ch_mult[i_level] 31 | block_out = ch * ch_mult[i_level] 32 | for i_block in range(self.num_res_blocks): 33 | block.append(ResnetBlock(in_channels=block_in, 34 | out_channels=block_out, 35 | temb_channels=self.temb_ch, 36 | dropout=dropout)) 37 | block_in = block_out 38 | if curr_res in attn_resolutions: 39 | attn.append(AttnBlock(block_in)) 40 | down = nn.Module() 41 | down.block = block 42 | down.attn = attn 43 | if i_level != self.num_resolutions - 1: 44 | down.downsample = Downsample(block_in, resamp_with_conv) 45 | curr_res = curr_res // 2 46 | self.down.append(down) 47 | 48 | # middle 49 | self.mid = nn.Module() 50 | self.mid.block_1 = ResnetBlock( 51 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout 52 | ) 53 | self.mid.attn_1 = AttnBlock(block_in) 54 | self.mid.block_2 = ResnetBlock( 55 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout 56 | ) 57 | 58 | # end 59 | self.norm_out = Normalize(block_in) 60 | self.conv_out = torch.nn.Conv2d( 61 | block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1 62 | ) 63 | 64 | def forward(self, x): 65 | # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format( 66 | # x.shape[2], x.shape[3], self.resolution 67 | # ) 68 | 69 | # timestep embedding 70 | temb = None 71 | 72 | # downsampling 73 | hs = [self.conv_in(x)] 74 | for i_level in range(self.num_resolutions): 75 | for i_block in 
range(self.num_res_blocks): 76 | h = self.down[i_level].block[i_block](hs[-1], temb) 77 | if len(self.down[i_level].attn) > 0: 78 | h = self.down[i_level].attn[i_block](h) 79 | hs.append(h) 80 | if i_level != self.num_resolutions - 1: 81 | hs.append(self.down[i_level].downsample(hs[-1])) 82 | 83 | # middle 84 | h = hs[-1] 85 | h = self.mid.block_1(h, temb) 86 | h = self.mid.attn_1(h) 87 | h = self.mid.block_2(h, temb) 88 | 89 | # end 90 | h = self.norm_out(h) 91 | h = nonlinearity(h) 92 | h = self.conv_out(h) 93 | return h 94 | -------------------------------------------------------------------------------- /core/taming/modules/diffusion/resnet_block.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from core.taming.utils import Normalize, nonlinearity 5 | 6 | 7 | class ResnetBlock(nn.Module): 8 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, 9 | dropout, temb_channels=512): 10 | super().__init__() 11 | self.in_channels = in_channels 12 | out_channels = in_channels if out_channels is None else out_channels 13 | self.out_channels = out_channels 14 | self.use_conv_shortcut = conv_shortcut 15 | 16 | self.norm1 = Normalize(in_channels) 17 | self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) 18 | if temb_channels > 0: 19 | self.temb_proj = torch.nn.Linear(temb_channels, 20 | out_channels) 21 | self.norm2 = Normalize(out_channels) 22 | self.dropout = torch.nn.Dropout(dropout) 23 | self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) 24 | if self.in_channels != self.out_channels: 25 | if self.use_conv_shortcut: 26 | self.conv_shortcut = torch.nn.Conv2d( 27 | in_channels, out_channels, kernel_size=3, stride=1, padding=1 28 | ) 29 | else: 30 | self.nin_shortcut = torch.nn.Conv2d( 31 | in_channels, out_channels, kernel_size=1, stride=1, padding=0 32 | ) 33 | 34 | def forward(self, x, temb): 35 | h = x 36 | h = self.norm1(h) 37 | h = nonlinearity(h) 38 | h = self.conv1(h) 39 | 40 | if temb is not None: 41 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None] 42 | 43 | h = self.norm2(h) 44 | h = nonlinearity(h) 45 | h = self.dropout(h) 46 | h = self.conv2(h) 47 | 48 | if self.in_channels != self.out_channels: 49 | if self.use_conv_shortcut: 50 | x = self.conv_shortcut(x) 51 | else: 52 | x = self.nin_shortcut(x) 53 | 54 | return x + h 55 | -------------------------------------------------------------------------------- /core/taming/modules/diffusion/upsample.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class Upsample(nn.Module): 6 | def __init__(self, in_channels, with_conv): 7 | super().__init__() 8 | self.with_conv = with_conv 9 | if self.with_conv: 10 | self.conv = torch.nn.Conv2d( 11 | in_channels, in_channels, kernel_size=3, stride=1, padding=1 12 | ) 13 | 14 | def forward(self, x): 15 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") 16 | if self.with_conv: 17 | x = self.conv(x) 18 | return x 19 | -------------------------------------------------------------------------------- /core/taming/modules/discriminator/__init__.py: -------------------------------------------------------------------------------- 1 | from core.taming.modules.discriminator.act_norm import ActNorm 2 | from core.taming.modules.discriminator.discriminator import NLayerDiscriminator 3 | 4 | __all__ = [ 5 | 
ActNorm, 6 | NLayerDiscriminator 7 | ] 8 | -------------------------------------------------------------------------------- /core/taming/modules/discriminator/act_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ActNorm(nn.Module): 6 | def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False): 7 | assert affine 8 | super().__init__() 9 | self.logdet = logdet 10 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) 11 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) 12 | self.allow_reverse_init = allow_reverse_init 13 | 14 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) 15 | 16 | def initialize(self, input): 17 | with torch.no_grad(): 18 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) 19 | mean = ( 20 | flatten.mean(1) 21 | .unsqueeze(1) 22 | .unsqueeze(2) 23 | .unsqueeze(3) 24 | .permute(1, 0, 2, 3) 25 | ) 26 | std = ( 27 | flatten.std(1) 28 | .unsqueeze(1) 29 | .unsqueeze(2) 30 | .unsqueeze(3) 31 | .permute(1, 0, 2, 3) 32 | ) 33 | 34 | self.loc.data.copy_(-mean) 35 | self.scale.data.copy_(1 / (std + 1e-6)) 36 | 37 | def forward(self, input, reverse=False): 38 | if reverse: 39 | return self.reverse(input) 40 | if len(input.shape) == 2: 41 | input = input[:,:,None,None] 42 | squeeze = True 43 | else: 44 | squeeze = False 45 | 46 | _, _, height, width = input.shape 47 | 48 | if self.training and self.initialized.item() == 0: 49 | self.initialize(input) 50 | self.initialized.fill_(1) 51 | 52 | h = self.scale * (input + self.loc) 53 | 54 | if squeeze: 55 | h = h.squeeze(-1).squeeze(-1) 56 | 57 | if self.logdet: 58 | log_abs = torch.log(torch.abs(self.scale)) 59 | logdet = height*width*torch.sum(log_abs) 60 | logdet = logdet * torch.ones(input.shape[0]).to(input) 61 | return h, logdet 62 | 63 | return h 64 | 65 | def reverse(self, output): 66 | if self.training and self.initialized.item() == 0: 67 | if not self.allow_reverse_init: 68 | raise RuntimeError( 69 | "Initializing ActNorm in reverse direction is " 70 | "disabled by default. Use allow_reverse_init=True to enable." 
71 | ) 72 | else: 73 | self.initialize(output) 74 | self.initialized.fill_(1) 75 | 76 | if len(output.shape) == 2: 77 | output = output[:,:,None,None] 78 | squeeze = True 79 | else: 80 | squeeze = False 81 | 82 | h = output / self.scale - self.loc 83 | 84 | if squeeze: 85 | h = h.squeeze(-1).squeeze(-1) 86 | return h 87 | -------------------------------------------------------------------------------- /core/taming/modules/discriminator/discriminator.py: -------------------------------------------------------------------------------- 1 | # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py 2 | 3 | import functools 4 | import torch.nn as nn 5 | 6 | from core.taming.modules.discriminator import ActNorm 7 | 8 | 9 | class NLayerDiscriminator(nn.Module): 10 | """Defines a PatchGAN discriminator as in Pix2Pix""" 11 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): 12 | """Construct a PatchGAN discriminator 13 | Parameters: 14 | input_nc (int) -- the number of channels in input images 15 | ndf (int) -- the number of filters in the last conv layer 16 | n_layers (int) -- the number of conv layers in the discriminator 17 | norm_layer -- normalization layer 18 | """ 19 | super(NLayerDiscriminator, self).__init__() 20 | if not use_actnorm: 21 | norm_layer = nn.BatchNorm2d 22 | else: 23 | norm_layer = ActNorm 24 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 25 | use_bias = norm_layer.func != nn.BatchNorm2d 26 | else: 27 | use_bias = norm_layer != nn.BatchNorm2d 28 | 29 | kw = 4 30 | padw = 1 31 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 32 | nf_mult = 1 33 | nf_mult_prev = 1 34 | for n in range(1, n_layers): # gradually increase the number of filters 35 | nf_mult_prev = nf_mult 36 | nf_mult = min(2 ** n, 8) 37 | sequence += [ 38 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 39 | norm_layer(ndf * nf_mult), 40 | nn.LeakyReLU(0.2, True) 41 | ] 42 | 43 | nf_mult_prev = nf_mult 44 | nf_mult = min(2 ** n_layers, 8) 45 | sequence += [ 46 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 47 | norm_layer(ndf * nf_mult), 48 | nn.LeakyReLU(0.2, True) 49 | ] 50 | 51 | sequence += [ 52 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 53 | self.main = nn.Sequential(*sequence) 54 | 55 | def forward(self, input): 56 | """Standard forward.""" 57 | return self.main(input) 58 | -------------------------------------------------------------------------------- /core/taming/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from core.taming.modules.losses.lpips import LPIPS 2 | from core.taming.modules.losses.vqperceptual import VQLPIPSWithDiscriminator, DummyLoss 3 | 4 | __all__ = [ 5 | LPIPS, 6 | DummyLoss, 7 | VQLPIPSWithDiscriminator 8 | ] 9 | -------------------------------------------------------------------------------- /core/taming/modules/losses/lpips.py: -------------------------------------------------------------------------------- 1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" 2 | 3 | import os 4 | 5 | import torch 6 | import torch.nn as nn 7 | from collections import namedtuple 8 | 9 | from core.utils.loader import download 10 | from core.taming.utils import 
normalize_tensor, spatial_average, load_vgg 11 | 12 | 13 | class LPIPS(nn.Module): 14 | # Learned perceptual metric 15 | def __init__(self, model_dir="/models", use_dropout=True): 16 | super().__init__() 17 | self.scaling_layer = ScalingLayer() 18 | self.chns = [64, 128, 256, 512, 512] # vg16 features 19 | self.net = VGG16(model_dir=model_dir, pretrained=True, requires_grad=False) 20 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 21 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 22 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 23 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 24 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 25 | self.load_from_pretrained(model_dir) 26 | for param in self.parameters(): 27 | param.requires_grad = False 28 | 29 | def load_from_pretrained(self, model_dir="/models"): 30 | ckpt = f"{model_dir}/vgg.pth" 31 | if not os.path.exists(ckpt): 32 | download("https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1", ckpt) 33 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) 34 | print(f"Loaded pretrained LPIPS loss from '{ckpt}'") 35 | 36 | def forward(self, input, target): 37 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) 38 | outs0, outs1 = self.net(in0_input), self.net(in1_input) 39 | feats0, feats1, diffs = {}, {}, {} 40 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 41 | for kk in range(len(self.chns)): 42 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) 43 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 44 | 45 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] 46 | val = res[0] 47 | for l in range(1, len(self.chns)): 48 | val += res[l] 49 | return val 50 | 51 | 52 | class ScalingLayer(nn.Module): 53 | def __init__(self): 54 | super(ScalingLayer, self).__init__() 55 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) 56 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) 57 | 58 | def forward(self, inp): 59 | return (inp - self.shift) / self.scale 60 | 61 | 62 | class NetLinLayer(nn.Module): 63 | """ A single linear layer which does a 1x1 conv """ 64 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 65 | super(NetLinLayer, self).__init__() 66 | layers = [nn.Dropout(), ] if (use_dropout) else [] 67 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] 68 | self.model = nn.Sequential(*layers) 69 | 70 | 71 | class VGG16(torch.nn.Module): 72 | def __init__(self, model_dir="/models", requires_grad=False, pretrained=True): 73 | super(VGG16, self).__init__() 74 | vgg_pretrained_features = load_vgg(model_dir=model_dir, pretrained=pretrained).features 75 | self.slice1 = torch.nn.Sequential() 76 | self.slice2 = torch.nn.Sequential() 77 | self.slice3 = torch.nn.Sequential() 78 | self.slice4 = torch.nn.Sequential() 79 | self.slice5 = torch.nn.Sequential() 80 | self.N_slices = 5 81 | for x in range(4): 82 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 83 | for x in range(4, 9): 84 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 85 | for x in range(9, 16): 86 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 87 | for x in range(16, 23): 88 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 89 | for x in range(23, 30): 90 | self.slice5.add_module(str(x), 
vgg_pretrained_features[x]) 91 | if not requires_grad: 92 | for param in self.parameters(): 93 | param.requires_grad = False 94 | 95 | def forward(self, X): 96 | h = self.slice1(X) 97 | h_relu1_2 = h 98 | h = self.slice2(h) 99 | h_relu2_2 = h 100 | h = self.slice3(h) 101 | h_relu3_3 = h 102 | h = self.slice4(h) 103 | h_relu4_3 = h 104 | h = self.slice5(h) 105 | h_relu5_3 = h 106 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 107 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 108 | return out 109 | -------------------------------------------------------------------------------- /core/taming/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from core.taming.utils import hinge_d_loss, vanilla_d_loss, adopt_weight, weights_init 5 | 6 | from core.taming.modules.discriminator import NLayerDiscriminator 7 | 8 | from core.taming.modules.losses import LPIPS 9 | 10 | 11 | class DummyLoss(nn.Module): 12 | def __init__(self): 13 | super().__init__() 14 | 15 | 16 | class VQLPIPSWithDiscriminator(nn.Module): 17 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 18 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 19 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 20 | disc_ndf=64, disc_loss="hinge", model_dir=None): 21 | super().__init__() 22 | assert disc_loss in ["hinge", "vanilla"] 23 | self.codebook_weight = codebook_weight 24 | self.pixel_weight = pixelloss_weight 25 | self.perceptual_loss = LPIPS(model_dir=model_dir).eval() 26 | self.perceptual_weight = perceptual_weight 27 | 28 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 29 | n_layers=disc_num_layers, 30 | use_actnorm=use_actnorm, 31 | ndf=disc_ndf 32 | ).apply(weights_init) 33 | self.discriminator_iter_start = disc_start 34 | if disc_loss == "hinge": 35 | self.disc_loss = hinge_d_loss 36 | elif disc_loss == "vanilla": 37 | self.disc_loss = vanilla_d_loss 38 | else: 39 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 40 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 41 | self.disc_factor = disc_factor 42 | self.discriminator_weight = disc_weight 43 | self.disc_conditional = disc_conditional 44 | 45 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 46 | if last_layer is not None: 47 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 48 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 49 | else: 50 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 51 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 52 | 53 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 54 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 55 | d_weight = d_weight * self.discriminator_weight 56 | return d_weight 57 | 58 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 59 | global_step, last_layer=None, cond=None, split="train"): 60 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 61 | if self.perceptual_weight > 0: 62 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 63 | rec_loss = rec_loss + self.perceptual_weight * p_loss 64 | else: 65 | p_loss = torch.tensor([0.0]) 66 | 67 | nll_loss = rec_loss 68 | nll_loss = 
torch.mean(nll_loss) 69 | 70 | # now the GAN part 71 | if optimizer_idx == 0: 72 | # generator update 73 | if cond is None: 74 | assert not self.disc_conditional 75 | logits_fake = self.discriminator(reconstructions.contiguous()) 76 | else: 77 | assert self.disc_conditional 78 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 79 | g_loss = -torch.mean(logits_fake) 80 | 81 | try: 82 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 83 | except RuntimeError: 84 | assert not self.training 85 | d_weight = torch.tensor(0.0) 86 | 87 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 88 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 89 | 90 | return loss 91 | 92 | if optimizer_idx == 1: 93 | # second pass for discriminator update 94 | if cond is None: 95 | logits_real = self.discriminator(inputs.contiguous().detach()) 96 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 97 | else: 98 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 99 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 100 | 101 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 102 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 103 | 104 | return d_loss 105 | -------------------------------------------------------------------------------- /core/taming/modules/vqvae/__init__.py: -------------------------------------------------------------------------------- 1 | from core.taming.modules.vqvae.vector_quantizer import VectorQuantizer 2 | 3 | 4 | __all__ = [ 5 | VectorQuantizer 6 | ] 7 | -------------------------------------------------------------------------------- /core/taming/modules/vqvae/vector_quantizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import numpy as np 5 | 6 | from einops import rearrange 7 | 8 | 9 | class VectorQuantizer(nn.Module): 10 | """ 11 | Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly 12 | avoids costly matrix multiplications and allows for post-hoc remapping of indices. 13 | """ 14 | # NOTE: due to a bug the beta term was applied to the wrong term. for 15 | # backwards compatibility we use the buggy version by default, but you can 16 | # specify legacy=False to fix it. 17 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", 18 | sane_index_shape=False, legacy=True): 19 | super().__init__() 20 | self.n_e = n_e 21 | self.e_dim = e_dim 22 | self.beta = beta 23 | self.legacy = legacy 24 | 25 | self.embedding = nn.Embedding(self.n_e, self.e_dim) 26 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) 27 | 28 | self.remap = remap 29 | if self.remap is not None: 30 | self.register_buffer("used", torch.tensor(np.load(self.remap))) 31 | self.re_embed = self.used.shape[0] 32 | self.unknown_index = unknown_index # "random" or "extra" or integer 33 | if self.unknown_index == "extra": 34 | self.unknown_index = self.re_embed 35 | self.re_embed = self.re_embed + 1 36 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. 
" 37 | f"Using {self.unknown_index} for unknown indices.") 38 | else: 39 | self.re_embed = n_e 40 | 41 | self.sane_index_shape = sane_index_shape 42 | 43 | def remap_to_used(self, inds): 44 | ishape = inds.shape 45 | assert len(ishape) > 1 46 | inds = inds.reshape(ishape[0], -1) 47 | used = self.used.to(inds) 48 | match = (inds[:, :, None] == used[None, None, ...]).long() 49 | new = match.argmax(-1) 50 | unknown = match.sum(2) < 1 51 | if self.unknown_index == "random": 52 | new[unknown] = \ 53 | torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device) 54 | else: 55 | new[unknown] = self.unknown_index 56 | return new.reshape(ishape) 57 | 58 | def unmap_to_all(self, inds): 59 | ishape = inds.shape 60 | assert len(ishape) > 1 61 | inds = inds.reshape(ishape[0], -1) 62 | used = self.used.to(inds) 63 | if self.re_embed > self.used.shape[0]: # extra token 64 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero 65 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds) 66 | return back.reshape(ishape) 67 | 68 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False): 69 | assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel" 70 | assert rescale_logits is False, "Only for interface compatible with Gumbel" 71 | assert return_logits is False, "Only for interface compatible with Gumbel" 72 | 73 | # reshape z -> (batch, height, width, channel) and flatten 74 | z = rearrange(z, 'b c h w -> b h w c').contiguous() 75 | z_flattened = z.view(-1, self.e_dim) 76 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z 77 | 78 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ 79 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \ 80 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) 81 | 82 | min_encoding_indices = torch.argmin(d, dim=1) 83 | z_q = self.embedding(min_encoding_indices).view(z.shape) 84 | perplexity = None 85 | min_encodings = None 86 | 87 | # compute loss for embedding 88 | if not self.legacy: 89 | loss = self.beta * torch.mean((z_q.detach() - z)**2) + \ 90 | torch.mean((z_q - z.detach()) ** 2) 91 | else: 92 | loss = torch.mean((z_q.detach() - z)**2) + self.beta * \ 93 | torch.mean((z_q - z.detach()) ** 2) 94 | 95 | # preserve gradients 96 | z_q = z + (z_q - z).detach() 97 | 98 | # reshape back to match original input shape 99 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() 100 | 101 | if self.remap is not None: 102 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis 103 | min_encoding_indices = self.remap_to_used(min_encoding_indices) 104 | min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten 105 | 106 | if self.sane_index_shape: 107 | min_encoding_indices = min_encoding_indices.reshape( 108 | z_q.shape[0], z_q.shape[2], z_q.shape[3]) 109 | 110 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices) 111 | 112 | def get_codebook_entry(self, indices, shape): 113 | # shape specifying (batch, height, width, channel) 114 | if self.remap is not None: 115 | indices = indices.reshape(shape[0], -1) # add batch axis 116 | indices = self.unmap_to_all(indices) 117 | indices = indices.reshape(-1) # flatten again 118 | 119 | # get quantized latent vectors 120 | z_q = self.embedding(indices) 121 | 122 | if shape is not None: 123 | z_q = z_q.view(shape) 124 | # reshape back to match original input shape 125 | z_q = z_q.permute(0, 3, 1, 2).contiguous() 126 | 127 | return z_q 128 | 
-------------------------------------------------------------------------------- /core/taming/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from core.taming.utils.diffusion_utils import Normalize, nonlinearity 2 | from core.taming.utils.discriminator_utils import weights_init 3 | from core.taming.utils.losses_utils import ( 4 | adopt_weight, hinge_d_loss, vanilla_d_loss, normalize_tensor, spatial_average, load_vgg 5 | ) 6 | 7 | __all__ = [ 8 | Normalize, 9 | nonlinearity, 10 | weights_init, 11 | adopt_weight, 12 | hinge_d_loss, 13 | vanilla_d_loss, 14 | normalize_tensor, 15 | spatial_average, 16 | load_vgg, 17 | ] 18 | -------------------------------------------------------------------------------- /core/taming/utils/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def nonlinearity(x): 5 | # swish 6 | return x * torch.sigmoid(x) 7 | 8 | 9 | def Normalize(in_channels): 10 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 11 | -------------------------------------------------------------------------------- /core/taming/utils/discriminator_utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | def weights_init(m): 5 | classname = m.__class__.__name__ 6 | if classname.find('Conv') != -1: 7 | nn.init.normal_(m.weight.data, 0.0, 0.02) 8 | elif classname.find('BatchNorm') != -1: 9 | nn.init.normal_(m.weight.data, 1.0, 0.02) 10 | nn.init.constant_(m.bias.data, 0) 11 | -------------------------------------------------------------------------------- /core/taming/utils/losses_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from torchvision.models import VGG 6 | from torchvision.models.vgg import load_state_dict_from_url 7 | 8 | from typing import List, Union, cast 9 | 10 | 11 | def adopt_weight(weight, global_step, threshold=0, value=0.): 12 | if global_step < threshold: 13 | weight = value 14 | return weight 15 | 16 | 17 | def hinge_d_loss(logits_real, logits_fake): 18 | loss_real = torch.mean(F.relu(1. - logits_real)) 19 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake)) 20 | d_loss = 0.5 * (loss_real + loss_fake) 21 | return d_loss 22 | 23 | 24 | def vanilla_d_loss(logits_real, logits_fake): 25 | d_loss = 0.5 * ( 26 | torch.mean(torch.nn.functional.softplus(-logits_real)) + 27 | torch.mean(torch.nn.functional.softplus(logits_fake))) 28 | return d_loss 29 | 30 | 31 | def normalize_tensor(x, eps=1e-10): 32 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) 33 | return x / (norm_factor + eps) 34 | 35 | 36 | def spatial_average(x, keepdim=True): 37 | return x.mean([2, 3], keepdim=keepdim) 38 | 39 | 40 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: 41 | layers: List[nn.Module] = [] 42 | in_channels = 3 43 | for v in cfg: 44 | if v == 'M': 45 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)] 46 | else: 47 | v = cast(int, v) 48 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) 49 | if batch_norm: 50 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] 51 | else: 52 | layers += [conv2d, nn.ReLU(inplace=True)] 53 | in_channels = v 54 | return nn.Sequential(*layers) 55 | 56 | 57 | def load_vgg(model_dir: str, pretrained: bool = False, **kwargs): 58 | if pretrained: 59 | kwargs['init_weights'] = False 60 | 61 | cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'] 62 | model = VGG(make_layers(cfg, batch_norm=False), **kwargs) 63 | 64 | if pretrained: 65 | state_dict = load_state_dict_from_url('https://download.pytorch.org/models/vgg16-397923af.pth', 66 | model_dir=model_dir, 67 | file_name="vgg16-397923af.pth", 68 | progress=True) 69 | model.load_state_dict(state_dict) 70 | print(f"Loaded pretrained VGG16 model from '{model_dir}/vgg16-397923af.pth'") 71 | 72 | return model 73 | -------------------------------------------------------------------------------- /core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from core.utils.make_cutouts import MakeCutouts 2 | from core.utils.normalize import Normalize 3 | from core.utils.helpers import resize_image, get_optimizer, get_scheduler, load_vqgan_model, global_seed 4 | 5 | __all__ = [ 6 | MakeCutouts, 7 | Normalize, 8 | resize_image, 9 | get_optimizer, 10 | get_scheduler, 11 | load_vqgan_model, 12 | global_seed 13 | ] 14 | -------------------------------------------------------------------------------- /core/utils/gradients.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | class ReplaceGrad(torch.autograd.Function): 6 | @staticmethod 7 | def forward(ctx, x_forward, x_backward): 8 | ctx.shape = x_backward.shape 9 | return x_forward 10 | 11 | @staticmethod 12 | def backward(ctx, grad_in): 13 | return None, grad_in.sum_to_size(ctx.shape) 14 | 15 | 16 | class ClampWithGrad(torch.autograd.Function): 17 | @staticmethod 18 | def forward(ctx, input, min, max): 19 | ctx.min = min 20 | ctx.max = max 21 | ctx.save_for_backward(input) 22 | return input.clamp(min, max) 23 | 24 | @staticmethod 25 | def backward(ctx, grad_in): 26 | input, = ctx.saved_tensors 27 | return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None 28 | 29 | 30 | def vector_quantize(x, codebook): 31 | d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T 32 | indices = d.argmin(-1) 33 | x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook 34 | return ReplaceGrad.apply(x_q, x) 35 | 
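These two autograd Functions are the usual VQGAN+CLIP straight-through helpers: `vector_quantize` picks the nearest codebook entry per position in the forward pass while `ReplaceGrad` routes the incoming gradient back to the continuous latent, and `ClampWithGrad` clamps values without zeroing useful gradients. A quick sketch with invented shapes (channel-last layout, which is what the distance computation in `vector_quantize` expects):

```python
import torch

from core.utils.gradients import ClampWithGrad, vector_quantize

codebook = torch.randn(16, 4)                     # 16 hypothetical codebook entries of dimension 4
x = torch.randn(2, 8, 8, 4, requires_grad=True)   # latent in channel-last layout

x_q = vector_quantize(x, codebook)                # nearest entry per spatial position
x_q.sum().backward()                              # ReplaceGrad sends this gradient to x, not the codebook
print(x.grad.shape)                               # torch.Size([2, 8, 8, 4])

clamped = ClampWithGrad.apply(torch.linspace(-2, 2, 5), -1.0, 1.0)
print(clamped)                                    # tensor([-1., -1.,  0.,  1.,  1.])
```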
-------------------------------------------------------------------------------- /core/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | 4 | import numpy as np 5 | 6 | import torch 7 | import torch.optim as optim 8 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts 9 | 10 | from PIL import Image 11 | 12 | from core.taming.models import vqgan 13 | from core.optimizer import DiffGrad, AdamP, RAdam 14 | 15 | 16 | def resize_image(image, out_size): 17 | ratio = image.size[0] / image.size[1] 18 | area = min(image.size[0] * image.size[1], out_size[0] * out_size[1]) 19 | size = round((area * ratio)**0.5), round((area / ratio)**0.5) 20 | return image.resize(size, Image.LANCZOS) 21 | 22 | 23 | def get_optimizer(z, optimizer="Adam", step_size=0.1): 24 | if optimizer == "Adam": 25 | opt = optim.Adam([z], lr=step_size) # LR=0.1 (Default) 26 | elif optimizer == "AdamW": 27 | opt = optim.AdamW([z], lr=step_size) # LR=0.2 28 | elif optimizer == "Adagrad": 29 | opt = optim.Adagrad([z], lr=step_size) # LR=0.5+ 30 | elif optimizer == "Adamax": 31 | opt = optim.Adamax([z], lr=step_size) # LR=0.5+? 32 | elif optimizer == "DiffGrad": 33 | opt = DiffGrad([z], lr=step_size) # LR=2+? 34 | elif optimizer == "AdamP": 35 | opt = AdamP([z], lr=step_size) # LR=2+? 36 | elif optimizer == "RAdam": 37 | opt = RAdam([z], lr=step_size) # LR=2+? 38 | return opt 39 | 40 | 41 | def get_scheduler(optimizer, max_iterations, nwarm_restarts=-1): 42 | if nwarm_restarts == -1: 43 | return None 44 | 45 | T_0 = max_iterations 46 | if nwarm_restarts > 0: 47 | T_0 = int(np.ceil(max_iterations / nwarm_restarts)) 48 | 49 | return CosineAnnealingWarmRestarts(optimizer, T_0=T_0) 50 | 51 | 52 | def load_vqgan_model(config_path, checkpoint_path, model_dir=None): 53 | with open(config_path, 'r') as f: 54 | config = json.load(f) 55 | 56 | model = vqgan.VQModel(model_dir=model_dir, **config["params"]) 57 | model.eval().requires_grad_(False) 58 | model.init_from_ckpt(checkpoint_path) 59 | 60 | del model.loss 61 | return model 62 | 63 | 64 | def global_seed(seed: int): 65 | seed = seed if seed != -1 else torch.seed() 66 | if seed > 2**32 - 1: 67 | seed = seed >> 32 68 | 69 | random.seed(seed) 70 | np.random.seed(seed) 71 | torch.manual_seed(seed) 72 | torch.cuda.manual_seed_all(seed) 73 | print(f"Global seed set to {seed}.") 74 | -------------------------------------------------------------------------------- /core/utils/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import io 3 | import pickle 4 | 5 | import requests 6 | 7 | import torch 8 | from torch.serialization import ( 9 | _get_restore_location, _maybe_decode_ascii, _open_file_like, _open_zipfile_reader 10 | ) 11 | 12 | from tqdm import tqdm 13 | 14 | 15 | def safe_load(f, map_location=None, pickle_module=pickle, pickle_file='data.pkl', **pickle_load_args): 16 | with _open_file_like(f, 'rb') as opened_file: 17 | with _open_zipfile_reader(opened_file) as zip_file: 18 | restore_location = _get_restore_location(map_location) 19 | 20 | loaded_storages = {} 21 | 22 | def load_tensor(data_type, size, key, location): 23 | name = f'data/{key}' 24 | dtype = data_type(0).dtype 25 | 26 | storage = zip_file.get_storage_from_record(name, size, dtype).storage() 27 | loaded_storages[key] = restore_location(storage, location) 28 | 29 | def persistent_load(saved_id): 30 | assert isinstance(saved_id, tuple) 31 | typename = 
_maybe_decode_ascii(saved_id[0]) 32 | data = saved_id[1:] 33 | 34 | assert typename == 'storage', \ 35 | f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'" 36 | data_type, key, location, size = data 37 | if key not in loaded_storages: 38 | load_tensor(data_type, size, key, _maybe_decode_ascii(location)) 39 | storage = loaded_storages[key] 40 | return storage 41 | 42 | load_module_mapping = { 43 | 'torch.tensor': 'torch._tensor' 44 | } 45 | 46 | class UnpicklerWrapper(pickle_module.Unpickler): 47 | def find_class(self, mod_name, name): 48 | try: 49 | mod_name = load_module_mapping.get(mod_name, mod_name) 50 | return super().find_class(mod_name, name) 51 | except Exception: 52 | pass 53 | 54 | # Load the data (which may in turn use `persistent_load` to load tensors) 55 | data_file = io.BytesIO(zip_file.get_record(pickle_file)) 56 | 57 | unpickler = UnpicklerWrapper(data_file, **pickle_load_args) 58 | unpickler.persistent_load = persistent_load 59 | result = unpickler.load() 60 | 61 | torch._utils._validate_loaded_sparse_tensors() 62 | 63 | return result 64 | 65 | 66 | def download(url, local_path, chunk_size=1024): 67 | os.makedirs(os.path.split(local_path)[0], exist_ok=True) 68 | with requests.get(url, stream=True) as r: 69 | total_size = int(r.headers.get("content-length", 0)) 70 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 71 | with open(local_path, "wb") as f: 72 | for data in r.iter_content(chunk_size=chunk_size): 73 | if data: 74 | f.write(data) 75 | pbar.update(chunk_size) 76 | -------------------------------------------------------------------------------- /core/utils/make_cutouts.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | import kornia.augmentation as K 5 | 6 | CUTOUTS = { 7 | 'Ji': K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.5), 8 | 'Sh': K.RandomSharpness(sharpness=0.5, p=0.5), 9 | 'Gn': K.RandomGaussianNoise(mean=0.0, std=1.0, p=0.5), 10 | 'Pe': K.RandomPerspective(distortion_scale=0.5, p=0.5), 11 | 'Ro': K.RandomRotation(degrees=15, p=0.5), 12 | 'Af': K.RandomAffine(degrees=15, translate=0.1, shear=15, padding_mode='border', keepdim=True, p=0.5), 13 | 'Et': K.RandomElasticTransform(p=0.5), 14 | 'Hf': K.RandomHorizontalFlip(p=0.5), 15 | 'Ts': K.RandomThinPlateSpline(scale=0.2, same_on_batch=False, p=0.5), 16 | 'Er': K.RandomErasing(scale=(0.02, 0.33), ratio=(0.3, 3.3), same_on_batch=False, p=0.5), 17 | } 18 | 19 | 20 | class MakeCutouts(nn.Module): 21 | def __init__(self, augments, cut_size, cutn, cut_pow=1.): 22 | super().__init__() 23 | self.cut_size = cut_size 24 | self.cutn = cutn 25 | self.cut_pow = cut_pow 26 | 27 | augment_list = [] 28 | for item in augments: 29 | if item == 'Cr': 30 | aug = K.RandomCrop(size=(self.cut_size, self.cut_size), p=0.5) 31 | elif item == 'Re': 32 | aug = K.RandomResizedCrop(size=(self.cut_size, self.cut_size), cropping_mode='resample', p=0.5) 33 | else: 34 | aug = CUTOUTS[item] 35 | augment_list.append(aug) 36 | 37 | print(f"Augmentations: {augment_list}") 38 | self.augs = nn.Sequential(*augment_list) 39 | 40 | self.noise_fac = 0.1 41 | 42 | # Pooling 43 | self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size)) 44 | self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size)) 45 | 46 | def forward(self, input): 47 | cutouts = [] 48 | 49 | for _ in range(self.cutn): 50 | # Use Pooling 51 | cutout = (self.av_pool(input) + self.max_pool(input)) / 2 52 | 
cutouts.append(cutout) 53 | 54 | batch = self.augs(torch.cat(cutouts, dim=0)) 55 | 56 | if self.noise_fac: 57 | facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac) 58 | batch = batch + facs * torch.randn_like(batch) 59 | return batch 60 | -------------------------------------------------------------------------------- /core/utils/noises.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from PIL import Image 4 | 5 | 6 | def perlin_noise_2d(shape, res): 7 | def interpolant(t): 8 | return t*t*t*(t*(t*6 - 15) + 10) 9 | 10 | delta = (res[0] / shape[0], res[1] / shape[1]) 11 | d = (shape[0] // res[0], shape[1] // res[1]) 12 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1 13 | 14 | # Gradients 15 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1) 16 | gradients = np.dstack((np.cos(angles), np.sin(angles))) 17 | gradients = gradients.repeat(d[0], 0).repeat(d[1], 1) 18 | g00 = gradients[ :-d[0], :-d[1]] 19 | g10 = gradients[d[0]: , :-d[1]] 20 | g01 = gradients[ :-d[0],d[1]: ] 21 | g11 = gradients[d[0]: ,d[1]: ] 22 | 23 | # Ramps 24 | n00 = np.sum(np.dstack((grid[:, :, 0] , grid[:, :, 1] )) * g00, 2) 25 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] )) * g10, 2) 26 | n01 = np.sum(np.dstack((grid[:, :, 0] , grid[:, :, 1]-1)) * g01, 2) 27 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1]-1)) * g11, 2) 28 | 29 | # Interpolation 30 | t = interpolant(grid) 31 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10 32 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11 33 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1) 34 | 35 | 36 | def fractal_noise_2d(shape, res, octaves=1, persistence=0.5, lacunarity=2): 37 | noise = np.zeros(shape) 38 | frequency = 1 39 | amplitude = 1 40 | 41 | for _ in range(octaves): 42 | noise += amplitude * perlin_noise_2d(shape, (frequency * res[0], frequency * res[1])) 43 | frequency *= lacunarity 44 | amplitude *= persistence 45 | return (noise - np.min(noise)) / (np.max(noise) - np.min(noise)) 46 | 47 | 48 | def random_fractal_image(width, height): 49 | _pow = int(np.ceil(np.log(max(width, height)) / np.log(2))) 50 | octaves = _pow - 4 51 | size = 2 ** _pow 52 | r = fractal_noise_2d((size, size), (32, 32), octaves=octaves) 53 | g = fractal_noise_2d((size, size), (32, 32), octaves=octaves) 54 | b = fractal_noise_2d((size, size), (32, 32), octaves=octaves) 55 | 56 | tile = np.dstack((r, g, b))[:height, :width, :] 57 | return Image.fromarray((255.9 * tile).astype('uint8')) 58 | 59 | 60 | def random_noise_image(width, height): 61 | return Image.fromarray( 62 | np.random.randint(0, 255, (width, height, 3), dtype=np.dtype('uint8')) 63 | ) 64 | 65 | 66 | def gradient_2d(start, stop, width, height, is_horizontal): 67 | if is_horizontal: 68 | return np.tile(np.linspace(start, stop, width), (height, 1)) 69 | else: 70 | return np.tile(np.linspace(start, stop, height), (width, 1)).T 71 | 72 | 73 | def gradient_3d(width, height, starts, stops, is_horizontal_list): 74 | result = np.zeros((height, width, len(starts)), dtype=float) 75 | 76 | for i, (start, stop, is_horizontal) in enumerate(zip(starts, stops, is_horizontal_list)): 77 | result[:, :, i] = gradient_2d(start, stop, width, height, is_horizontal) 78 | 79 | return result 80 | 81 | 82 | def random_gradient_image(width, height): 83 | array = gradient_3d( 84 | width, 85 | height, 86 | (0, 0, np.random.randint(0, 255)), 87 | (np.random.randint(1, 255), np.random.randint(2, 255), 
np.random.randint(3, 128)), 88 | (True, False, False) 89 | ) 90 | random_image = Image.fromarray(np.uint8(array)) 91 | return random_image 92 | -------------------------------------------------------------------------------- /core/utils/normalize.py: -------------------------------------------------------------------------------- 1 | # https://github.com/pratogab/batch-transforms 2 | 3 | import torch 4 | 5 | 6 | class Normalize: 7 | """Applies the :class:`~torchvision.transforms.Normalize` transform to a batch of images. 8 | 9 | .. note:: 10 | This transform acts out of place by default, i.e., it does not mutate the input tensor. 11 | 12 | Args: 13 | mean (sequence): 14 | Sequence of means for each channel. 15 | std (sequence): 16 | Sequence of standard deviations for each channel. 17 | inplace(bool,optional): 18 | Bool to make this operation in-place. 19 | dtype (torch.dtype,optional): 20 | The data type of tensors to which the transform will be applied. 21 | device (torch.device,optional): 22 | The device of tensors to which the transform will be applied. 23 | """ 24 | def __init__(self, mean, std, inplace=False, dtype=torch.float, device='cpu'): 25 | self.mean = torch.as_tensor(mean, dtype=dtype, device=device)[None, :, None, None] 26 | self.std = torch.as_tensor(std, dtype=dtype, device=device)[None, :, None, None] 27 | self.inplace = inplace 28 | 29 | def __call__(self, tensor): 30 | """ 31 | Args: 32 | tensor (Tensor): Tensor of size (N, C, H, W) to be normalized. 33 | 34 | Returns: 35 | Tensor: Normalized Tensor. 36 | """ 37 | if not self.inplace: 38 | tensor = tensor.clone() 39 | 40 | tensor.sub_(self.mean).div_(self.std) 41 | return tensor 42 | -------------------------------------------------------------------------------- /core/utils/prompt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from core.utils.gradients import ReplaceGrad 6 | 7 | 8 | class Prompt(nn.Module): 9 | def __init__(self, embed, weight=1., stop=float('-inf')): 10 | super().__init__() 11 | self.register_buffer('embed', embed) 12 | self.register_buffer('weight', torch.as_tensor(weight)) 13 | self.register_buffer('stop', torch.as_tensor(stop)) 14 | 15 | def forward(self, input): 16 | input_normed = F.normalize(input.unsqueeze(1), dim=2) 17 | embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2) 18 | dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2) 19 | dists = dists * self.weight.sign() 20 | return self.weight.abs() * ReplaceGrad.apply(dists, torch.maximum(dists, self.stop)).mean() 21 | 22 | 23 | def parse_prompt(prompt): 24 | vals = prompt.rsplit(':', 2) 25 | vals = vals + ['', '1', '-inf'][len(vals):] 26 | return vals[0], float(vals[1]), float(vals[2]) 27 | -------------------------------------------------------------------------------- /data/.gitignore: -------------------------------------------------------------------------------- 1 | *.txt 2 | 3 | *.jpg 4 | *.jpeg 5 | *.png 6 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3" 2 | 3 | services: 4 | generate: 5 | build: ./ 6 | command: python -m scripts.generate -c /configs/docker.json 7 | volumes: 8 | - ./models:/models 9 | - ./configs:/configs 10 | - ./core:/app/core 11 | - ./scripts:/app/scripts 12 | - ./outputs:/outputs 13 | environment: 14 | - 
DEVICE=cuda 15 | deploy: 16 | resources: 17 | reservations: 18 | devices: 19 | - capabilities: [gpu] 20 | 21 | train: 22 | build: ./ 23 | command: python -m scripts.train -c /configs/models/vqgan_custom_docker.json 24 | volumes: 25 | - ./models:/models 26 | - ./configs:/configs 27 | - ./core:/app/core 28 | - ./scripts:/app/scripts 29 | - ./outputs:/outputs 30 | environment: 31 | - DEVICE=cuda 32 | deploy: 33 | resources: 34 | reservations: 35 | devices: 36 | - capabilities: [gpu] 37 | -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | *.pt 2 | *.pth 3 | *.ckpt 4 | *.bin 5 | *.pkl 6 | -------------------------------------------------------------------------------- /outputs/.gitignore: -------------------------------------------------------------------------------- 1 | *.png 2 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.9.0 2 | torchvision==0.10.0 3 | 4 | einops==0.3.0 5 | kornia==0.5.7 6 | 7 | Pillow==8.3.2 8 | numpy==1.20.2 9 | 10 | requests==2.24.0 11 | tqdm==4.51.0 12 | 13 | regex==2021.4.4 14 | ftfy==6.0.3 15 | -------------------------------------------------------------------------------- /samples/forest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/forest.png -------------------------------------------------------------------------------- /samples/ghost_pokemon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/ghost_pokemon.png -------------------------------------------------------------------------------- /samples/gundam.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/gundam.png -------------------------------------------------------------------------------- /samples/landscape.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/landscape.png -------------------------------------------------------------------------------- /samples/sailor_moon.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/sailor_moon.png -------------------------------------------------------------------------------- /samples/waterfall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/waterfall.png -------------------------------------------------------------------------------- /scripts/generate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | import torchvision.transforms.functional as TF 9 | 10 | import numpy as np 11 
| 12 | from PIL import Image 13 | 14 | from tqdm import tqdm 15 | 16 | from core.schemas import Config 17 | from core.clip import clip 18 | 19 | from core.utils import MakeCutouts, Normalize, resize_image, get_optimizer, get_scheduler, load_vqgan_model, global_seed 20 | from core.utils.noises import random_noise_image, random_fractal_image, random_gradient_image 21 | from core.utils.prompt import Prompt, parse_prompt 22 | from core.utils.gradients import ClampWithGrad, vector_quantize 23 | 24 | 25 | PARAMS: Config = None 26 | DEVICE = torch.device(os.environ.get("DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')) 27 | NORMALIZE = Normalize(mean=[0.48145466, 0.4578275, 0.40821073], 28 | std=[0.26862954, 0.26130258, 0.27577711], device=DEVICE) 29 | 30 | 31 | def parse_args(): 32 | parser = argparse.ArgumentParser() 33 | parser.add_argument("-c", "--config", type=str, required=True, help="Path to configuration file.") 34 | return parser.parse_args() 35 | 36 | 37 | def initialize_image(model): 38 | f = 2**(model.decoder.num_resolutions - 1) 39 | toksX, toksY = PARAMS.size[0] // f, PARAMS.size[1] // f 40 | sideX, sideY = toksX * f, toksY * f 41 | 42 | def encode(img): 43 | pil_image = img.convert('RGB').resize((sideX, sideY), Image.LANCZOS) 44 | pil_tensor = TF.to_tensor(pil_image) 45 | z, *_ = model.encode(pil_tensor.to(DEVICE).unsqueeze(0) * 2 - 1) 46 | return z 47 | 48 | if PARAMS.init_image and os.path.exists(PARAMS.init_image): 49 | z = encode(Image.open(PARAMS.init_image)) 50 | elif PARAMS.init_noise == 'pixels': 51 | z = encode(random_noise_image(PARAMS.size[0], PARAMS.size[1])) 52 | elif PARAMS.init_noise == 'fractal': 53 | z = encode(random_fractal_image(PARAMS.size[0], PARAMS.size[1])) 54 | elif PARAMS.init_noise == 'gradient': 55 | z = encode(random_gradient_image(PARAMS.size[0], PARAMS.size[1])) 56 | else: 57 | e_dim = model.quantize.e_dim 58 | n_toks = model.quantize.n_e 59 | 60 | one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=DEVICE), n_toks).float() 61 | z = one_hot @ model.quantize.embedding.weight 62 | z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2) 63 | 64 | return z 65 | 66 | 67 | def tokenize(model, perceptor, make_cutouts): 68 | f = 2**(model.decoder.num_resolutions - 1) 69 | toksX, toksY = PARAMS.size[0] // f, PARAMS.size[1] // f 70 | sideX, sideY = toksX * f, toksY * f 71 | 72 | prompts = [] 73 | for prompt in PARAMS.prompts: 74 | txt, weight, stop = parse_prompt(prompt) 75 | embed = perceptor.encode_text(clip.tokenize(txt).to(DEVICE)).float() 76 | prompts.append(Prompt(embed, weight, stop).to(DEVICE)) 77 | 78 | for prompt in PARAMS.image_prompts: 79 | path, weight, stop = parse_prompt(prompt) 80 | img = Image.open(path) 81 | pil_image = img.convert('RGB') 82 | img = resize_image(pil_image, (sideX, sideY)) 83 | batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(DEVICE)) 84 | embed = perceptor.encode_image(NORMALIZE(batch)).float() 85 | prompts.append(Prompt(embed, weight, stop).to(DEVICE)) 86 | 87 | for seed, weight in zip(PARAMS.noise_prompt_seeds, PARAMS.noise_prompt_weights): 88 | gen = torch.Generator().manual_seed(seed) 89 | embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen) 90 | prompts.append(Prompt(embed, weight).to(DEVICE)) 91 | 92 | return prompts 93 | 94 | 95 | def synth(z, *, model): 96 | z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1) 97 | z_q = ClampWithGrad.apply(model.decode(z_q).add(1).div(2), 0, 1) 98 | 99 | if PARAMS.pixelart: 100 | z_q = 
F.avg_pool2d(z_q, tuple(np.ceil(np.divide(PARAMS.size, PARAMS.pixelart)).astype('uint8'))) 101 | 102 | return z_q 103 | 104 | 105 | @torch.no_grad() 106 | def checkin(z, losses, **kwargs): 107 | losses_str = ', '.join(f'{loss.item():g}' for loss in losses) 108 | tqdm.write(f"step: {kwargs['step']}, loss: {sum(losses).item():g}, losses: {losses_str}") 109 | out = synth(z, model=kwargs['model']) 110 | 111 | filename = "output" 112 | if len(PARAMS.prompts): 113 | filename = '_'.join(PARAMS.prompts).replace(' ', '_') 114 | 115 | path = f"{PARAMS.output_dir}/{filename}.png" 116 | TF.to_pil_image(out[0].cpu()).save(path) 117 | 118 | 119 | def ascend_txt(z, **kwargs): 120 | out = synth(z, model=kwargs['model']) 121 | cutouts = kwargs['make_cutouts'](out) 122 | iii = kwargs['perceptor'].encode_image(NORMALIZE(cutouts)).float() 123 | 124 | step = kwargs['step'] 125 | result = [] 126 | if PARAMS.init_weight: 127 | mse_weight = kwargs['mse_weight'] 128 | result.append(F.mse_loss(z, kwargs['z_orig']) * mse_weight / 2) 129 | 130 | mse_decay = PARAMS.init_weight / (PARAMS.max_iterations / PARAMS.mse_decay_rate) 131 | with torch.no_grad(): 132 | if step > 0 and step % PARAMS.mse_decay_rate == 0: 133 | kwargs['mse_weight'] = max(mse_weight - mse_decay, 0) 134 | 135 | for prompt in kwargs['prompts']: 136 | result.append(prompt(iii)) 137 | 138 | TF.to_pil_image(out[0].cpu()).save(f"{PARAMS.output_dir}/steps/{step}.png") 139 | return result 140 | 141 | 142 | def train(z, **kwargs): 143 | kwargs['optimizer'].zero_grad(set_to_none=True) 144 | lossAll = ascend_txt(z, **kwargs) 145 | 146 | if kwargs['step'] % PARAMS.save_freq == 0 or kwargs['step'] == PARAMS.max_iterations: 147 | checkin(z, lossAll, **kwargs) 148 | 149 | loss = sum(lossAll) 150 | loss.backward() 151 | kwargs['optimizer'].step() 152 | 153 | if kwargs['scheduler'] is not None: 154 | kwargs['scheduler'].step() 155 | 156 | with torch.no_grad(): 157 | z.copy_(z.maximum(kwargs['z_min']).minimum(kwargs['z_max'])) 158 | 159 | 160 | def main(): 161 | model = load_vqgan_model(PARAMS.vqgan_config, PARAMS.vqgan_checkpoint, PARAMS.models_dir).to(DEVICE) 162 | perceptor = clip.load(PARAMS.clip_model, device=DEVICE, root=PARAMS.models_dir)[0].eval().requires_grad_(False).to(DEVICE) 163 | 164 | cut_size = perceptor.visual.input_resolution 165 | make_cutouts = MakeCutouts(PARAMS.augments, cut_size, PARAMS.cutn, cut_pow=PARAMS.cut_pow) 166 | 167 | z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None] 168 | z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None] 169 | z = initialize_image(model) 170 | z_orig = torch.zeros_like(z) 171 | z.requires_grad_(True) 172 | 173 | prompts = tokenize(model, perceptor, make_cutouts) 174 | optimizer = get_optimizer(z, PARAMS.optimizer, PARAMS.step_size) 175 | scheduler = get_scheduler(optimizer, PARAMS.max_iterations, PARAMS.nwarm_restarts) 176 | 177 | kwargs = { 178 | 'model': model, 179 | 'perceptor': perceptor, 180 | 'optimizer': optimizer, 181 | 'scheduler': scheduler, 182 | 'prompts': prompts, 183 | 'make_cutouts': make_cutouts, 184 | 'z_orig': z_orig, 185 | 'z_min': z_min, 186 | 'z_max': z_max, 187 | 'mse_weight': PARAMS.init_weight, 188 | } 189 | try: 190 | for step in tqdm(range(PARAMS.max_iterations)): 191 | kwargs['step'] = step + 1 192 | train(z, **kwargs) 193 | except KeyboardInterrupt: 194 | pass 195 | 196 | 197 | if __name__ == "__main__": 198 | args = parse_args() 199 | 200 | if not os.path.exists(args.config): 201 | exit(f"ERROR: {args.config} not found.") 202 | 
203 | print(f"Loading configuration from '{args.config}'") 204 | with open(args.config, 'r') as f: 205 | PARAMS = Config(**json.load(f)) 206 | 207 | print(f"Running on {DEVICE}.") 208 | print(PARAMS) 209 | 210 | global_seed(PARAMS.seed) 211 | 212 | main() 213 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | 5 | import torch 6 | from torch.utils.data import DataLoader 7 | 8 | import torchvision.transforms.functional as TF 9 | from torchvision import transforms as T 10 | from torchvision.datasets import ImageFolder 11 | 12 | from tqdm import tqdm 13 | 14 | from core.schemas import TrainConfig 15 | from core.utils import global_seed 16 | from core.utils.loader import safe_load 17 | from core.taming.models import vqgan 18 | 19 | 20 | PARAMS: TrainConfig = None 21 | DEVICE = torch.device(os.environ.get("DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu')) 22 | 23 | 24 | def parse_args(): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument("-c", "--config", type=str, required=True, help="Path to configuration file.") 27 | return parser.parse_args() 28 | 29 | 30 | def save_model(model, optimizers, epoch, path): 31 | save_dict = { 32 | "epoch": epoch, 33 | "global_step": model.global_step, 34 | "state_dict": model.state_dict(), 35 | "optimizer_states": [ 36 | optimizers[0].state_dict(), 37 | optimizers[1].state_dict(), 38 | ] 39 | } 40 | torch.save(save_dict, path) 41 | tqdm.write(f"Checkpoint saved in {path}") 42 | 43 | 44 | def main(): 45 | dataset = ImageFolder(PARAMS.data_dir, T.Compose( 46 | [ 47 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), 48 | T.Resize(PARAMS.params["embed_dim"]), 49 | T.CenterCrop(PARAMS.params["embed_dim"]), 50 | T.ToTensor() 51 | ] 52 | )) 53 | loader = DataLoader(dataset, PARAMS.batch_size, shuffle=True) 54 | 55 | PARAMS.params["model_dir"] = PARAMS.models_dir 56 | model = vqgan.VQModel(**PARAMS.params).to(DEVICE) 57 | model.learning_rate = PARAMS.batch_size * PARAMS.base_learning_rate 58 | model.global_step = 0 59 | 60 | optimizers, _ = model.configure_optimizers() 61 | epoch = 0 62 | 63 | if PARAMS.resume_checkpoint: 64 | save_dict = safe_load(PARAMS.resume_checkpoint, map_location='cpu') 65 | epoch = save_dict["epoch"] 66 | model.global_step = save_dict["global_step"] 67 | model.load_state_dict(save_dict["state_dict"]) 68 | optimizers[0].load_state_dict(save_dict["optimizer_states"][0]) 69 | optimizers[1].load_state_dict(save_dict["optimizer_states"][1]) 70 | print(f"Restored model from {PARAMS.resume_checkpoint}") 71 | 72 | while epoch < PARAMS.epochs: 73 | for i, (images, _) in tqdm(enumerate(loader), total=len(loader)): 74 | images = images.to(DEVICE) 75 | 76 | losses = [] 77 | for j, opt in enumerate(optimizers): 78 | loss = model.training_step(images, i, j, device=DEVICE) 79 | losses.append(loss.item()) 80 | 81 | opt.zero_grad() 82 | loss.backward() 83 | 84 | opt.step() 85 | 86 | tqdm.write(f"Epoch: {epoch} | Batch: {i} | losses: {losses}") 87 | 88 | if i % 1000 == 0: 89 | save_model(model, optimizers, epoch, f"{PARAMS.models_dir}/checkpoints/last.ckpt") 90 | 91 | with torch.no_grad(): 92 | dec, _ = model(model.get_input(images, device=DEVICE)) 93 | TF.to_pil_image(dec[0].cpu()).save(f"{PARAMS.output_dir}/training/{epoch}_{i}.png") 94 | 95 | model.global_step += 1 96 | epoch += 1 97 | 98 | save_model(model, optimizers, epoch, 
f"{PARAMS.models_dir}/checkpoints/final.ckpt") 99 | 100 | 101 | if __name__ == "__main__": 102 | args = parse_args() 103 | 104 | if not os.path.exists(args.config): 105 | exit(f"ERROR: {args.config} not found.") 106 | 107 | print(f"Loading configuration from '{args.config}'") 108 | with open(args.config, 'r') as f: 109 | PARAMS = TrainConfig(**json.load(f)) 110 | 111 | print(f"Running on {DEVICE}.") 112 | print(PARAMS) 113 | 114 | global_seed(PARAMS.seed) 115 | 116 | main() 117 | --------------------------------------------------------------------------------