├── .dockerignore
├── .gitignore
├── Dockerfile
├── LICENSE
├── Makefile
├── README.md
├── configs
├── docker.json
├── local.json
└── models
│ ├── vqgan_coco_f16_8192.json
│ ├── vqgan_custom.json
│ ├── vqgan_custom_docker.json
│ ├── vqgan_faceshq_f16_1024.json
│ ├── vqgan_imagenet_f16_1024.json
│ └── vqgan_imagenet_f16_16384.json
├── core
├── clip
│ ├── README.md
│ ├── __init__.py
│ ├── bpe_simple_vocab_16e6.txt.gz
│ ├── clip.py
│ ├── model.py
│ └── simple_tokenizer.py
├── optimizer
│ ├── __init__.py
│ ├── adamp.py
│ ├── diffgrad.py
│ └── radam.py
├── schemas
│ ├── __init__.py
│ ├── config.py
│ └── train_config.py
├── taming
│ ├── README.md
│ ├── models
│ │ ├── __init__.py
│ │ └── vqgan.py
│ ├── modules
│ │ ├── diffusion
│ │ │ ├── __init__.py
│ │ │ ├── attn_block.py
│ │ │ ├── decoder.py
│ │ │ ├── downsample.py
│ │ │ ├── encoder.py
│ │ │ ├── resnet_block.py
│ │ │ └── upsample.py
│ │ ├── discriminator
│ │ │ ├── __init__.py
│ │ │ ├── act_norm.py
│ │ │ └── discriminator.py
│ │ ├── losses
│ │ │ ├── __init__.py
│ │ │ ├── lpips.py
│ │ │ └── vqperceptual.py
│ │ └── vqvae
│ │ │ ├── __init__.py
│ │ │ └── vector_quantizer.py
│ └── utils
│ │ ├── __init__.py
│ │ ├── diffusion_utils.py
│ │ ├── discriminator_utils.py
│ │ └── losses_utils.py
└── utils
│ ├── __init__.py
│ ├── gradients.py
│ ├── helpers.py
│ ├── loader.py
│ ├── make_cutouts.py
│ ├── noises.py
│ ├── normalize.py
│ └── prompt.py
├── data
└── .gitignore
├── docker-compose.yml
├── models
└── .gitignore
├── outputs
└── .gitignore
├── requirements.txt
├── samples
├── forest.png
├── ghost_pokemon.png
├── gundam.png
├── landscape.png
├── sailor_moon.png
└── waterfall.png
└── scripts
├── generate.py
└── train.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | ./models
2 | ./data
3 | ./samples
4 | ./outputs
5 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime
2 |
3 | WORKDIR /app
4 |
5 | COPY ./requirements.txt /requirements.txt
6 | RUN python -m pip install -r /requirements.txt
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Kevin Costa
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | COMPOSE=docker-compose -f docker-compose.yml
2 |
3 | all: build
4 |
5 | build:
6 | $(COMPOSE) build
7 |
8 | generate:
9 | $(COMPOSE) run generate
10 |
11 | generate-cpu:
12 | $(COMPOSE) run -e DEVICE='cpu' generate
13 |
14 | train:
15 | $(COMPOSE) run train
16 |
17 | train-cpu:
18 | $(COMPOSE) run -e DEVICE='cpu' train
19 |
20 |
21 | .PHONY: all build generate generate-cpu train train-cpu
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VQGAN-CLIP-Docker
2 |
3 | - [Setup](#Setup)
4 | - [Usage](#Usage)
5 | - [Inference](#Inference)
6 | - [Training](#Training)
7 | - [Acknowledgments](#Acknowledgments)
8 | - [Citations](#Citations)
9 |
10 | ## About
11 |
12 | > Zero-Shot Text-to-Image Generation VQGAN+CLIP Dockerized
13 |
14 | This is a stripped-down, minimal-dependency repository for running VQGAN+CLIP locally or in production.
15 |
16 | For a Google Colab notebook [see the original repository](#Acknowledgments).
17 |
18 | ## Samples
19 |
20 |
21 | 
22 | 
23 | 
24 | 
25 | 
26 | 
27 |
28 |
29 |
30 | # Setup
31 |
32 | Clone this repository and `cd` inside.
33 |
34 | ```sh
35 | git clone https://github.com/kcosta42/VQGAN-CLIP-Docker.git
36 | cd VQGAN-CLIP-Docker
37 | ```
38 |
39 | You can download a pretrained VQGAN model and put it in the `./models` folder.
40 |
41 |
72 |
73 |
74 | For GPU support, make sure you have CUDA installed on your system (tested with CUDA 11.1+).
75 |
76 | - 6 GB of VRAM is required to generate 256x256 images.
77 | - 11 GB of VRAM is required to generate 512x512 images.
78 | - 24 GB of VRAM is required to generate 1024x1024 images. (Untested)
79 |
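To check how much VRAM your GPU has, you can run a small PyTorch snippet like the one below (this is only a quick capability check):

```py
import torch

if torch.cuda.is_available():
    props = torch.cuda.get_device_properties(0)
    print(f"{props.name}: {props.total_memory / 1024**3:.1f} GiB of VRAM")
else:
    print("No CUDA device detected")
```
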
80 | ## Local
81 |
82 | Install the Python requirements
83 |
84 | ```sh
85 | python3 -m pip install -r requirements.txt
86 | ```
87 |
88 | To check whether you can run this on your GPU, the following command must return `True`.
89 | ```sh
90 | python3 -c "import torch; print(torch.cuda.is_available());"
91 | ```
92 |
93 | ## Docker
94 |
95 | > Make sure you have `docker` and `docker-compose` v1.28.0+ installed. `nvidia-docker` is needed if you want to run this on your GPU through Docker.
96 |
97 | A Makefile is provided for ease of use.
98 |
99 | ```sh
100 | make build # Build the docker image
101 | ```
102 |
103 | # Usage
104 |
105 | ## Inference
106 |
107 | Two configuration files are provided: `./configs/local.json` and `./configs/docker.json`. They are ready to use, but you may want to edit them to meet your needs. Check the [Configuration section](#Configuration) to understand each field.
108 |
109 | By default, the resulting generations can be found in the `./outputs` folder.
110 |
111 | ### GPU
112 |
113 | To run locally:
114 |
115 | ```sh
116 | python3 -m scripts.generate -c ./configs/local.json
117 | ```
118 |
119 | To run on docker:
120 |
121 | ```sh
122 | make generate
123 | ```
124 |
125 | ### CPU
126 |
127 | To run locally:
128 |
129 | ```sh
130 | DEVICE=cpu python3 -m scripts.generate -c ./configs/local.json
131 | ```
132 |
133 | To run on docker:
134 |
135 | ```sh
136 | make generate-cpu
137 | ```
138 |
139 | ### Configuration
140 |
141 | | Argument | Type | Descriptions |
142 | |------------------------|----------------|--------------------------------------------------------------------------------|
143 | | `prompts` | List[str] | Text prompts |
144 | | `image_prompts` | List[FilePath] | Image prompts / target image path |
145 | | `max_iterations` | int | Number of iterations |
146 | | `save_freq`             | int            | Save an image every `save_freq` iterations |
147 | | `size` | [int, int] | Image size (width height) |
148 | | `pixelart` | [int, int] | Pixelart image size (width height) (Optional, remove option to disable) |
149 | | `init_image` | FilePath | Initial image |
150 | | `init_noise` | str | Initial noise image ["gradient","pixels","fractal"] |
151 | | `init_weight` | float | Initial weight |
152 | | `mse_decay_rate`        | int            | Slowly decay the MSE loss every `mse_decay_rate` iterations until it reaches about 0 |
153 | | `output_dir` | FilePath | Path to output directory |
154 | | `models_dir` | FilePath | Path to models cache directory |
155 | | `clip_model` | FilePath | CLIP model path or name |
156 | | `vqgan_checkpoint` | FilePath | VQGAN checkpoint path |
157 | | `vqgan_config` | FilePath | VQGAN config path |
158 | | `noise_prompt_seeds` | List[int] | Noise prompt seeds |
159 | | `noise_prompt_weights` | List[float] | Noise prompt weights |
160 | | `step_size` | float | Learning rate |
161 | | `cutn` | int | Number of cuts |
162 | | `cut_pow` | float | Cut power |
163 | | `seed` | int | Seed (-1 for random seed) |
164 | | `optimizer`             | str            | Optimizer ["Adam","AdamW","Adagrad","Adamax","DiffGrad","AdamP","RAdam"] |
165 | | `nwarm_restarts`        | int            | Number of times the learning rate is reset (-1 to disable LR decay) |
166 | | `augments` | List[str] | Enabled augments ["Ji","Sh","Gn","Pe","Ro","Af","Et","Ts","Cr","Er","Re","Hf"] |
167 |
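As a concrete example, the sketch below loads `./configs/local.json`, overrides a few of the fields listed above, and writes a new configuration that can then be passed to the generate script with `-c`. The field names come from the table; the validation performed by `core/schemas` is not covered here.

```py
import json

with open("./configs/local.json") as f:
    config = json.load(f)

# Override a few fields from the table above.
config["prompts"] = ["a watercolor painting of a lighthouse"]
config["size"] = [512, 512]       # needs roughly 11 GB of VRAM
config["max_iterations"] = 500
config["seed"] = 42               # fixed seed for reproducible runs

with open("./configs/my_run.json", "w") as f:
    json.dump(config, f, indent=2)

# Then: python3 -m scripts.generate -c ./configs/my_run.json
```
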
168 | ## Training
169 |
170 | > These are instructions for training a new VQGAN model. You can also fine-tune the pretrained models, but you may need to tweak the training script.
171 |
172 | Two model configuration files are provided: `./configs/models/vqgan_custom.json` and `./configs/models/vqgan_custom_docker.json`. They are ready to use, but you may want to edit them to meet your needs. Check the [Model Configuration](#Model-Configuration) section to understand each field.
173 |
174 | By default, the models are saved in the `./models/checkpoints` folder.
175 |
176 | ### Dataset
177 |
178 | Put your images in a folder inside the data directory (`./data` by default).
179 |
180 | The dataset must be structured as follows:
181 |
182 | ```sh
183 | ./data/
184 | ├── class_x/
185 | │ ├── xxx.png
186 | │ ├── xxy.jpg
187 | │ └── ...
188 | │ └── xxz.ppm
189 | └── class_y/
190 | ├── 123.bmp
191 | ├── nsdf3.tif
192 | └── ...
193 | └── asd932_.webp
194 | ```
195 |
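This layout is the one expected by `torchvision.datasets.ImageFolder`, so a quick way to sanity-check your dataset before launching a training run is a sketch like the following (whether `scripts/train.py` itself loads the data this way is not shown here):

```py
from torchvision import datasets, transforms

dataset = datasets.ImageFolder(
    "./data",
    transform=transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(256),
        transforms.ToTensor(),
    ]),
)
print(f"{len(dataset)} images in {len(dataset.classes)} classes: {dataset.classes}")
```
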
196 | ### GPU
197 |
198 | To run locally:
199 |
200 | ```sh
201 | python3 -m scripts.train -c ./configs/models/vqgan_custom.json
202 | ```
203 |
204 | To run on docker:
205 |
206 | ```sh
207 | make train
208 | ```
209 |
210 | ### CPU
211 |
212 | To run locally:
213 |
214 | ```sh
215 | DEVICE=cpu python3 -m scripts.train -c ./configs/models/vqgan_custom.json
216 | ```
217 |
218 | To run on docker:
219 |
220 | ```sh
221 | make train-cpu
222 | ```
223 |
224 | ### Model Configuration
225 |
226 | | Argument | Type | Descriptions |
227 | |------------------------|----------------|---------------------------------------------------------------------------|
228 | | `base_learning_rate` | float | Initial Learning rate |
229 | | `batch_size` | int | Batch size (Adjust based on your GPU capability) |
230 | | `epochs`               | int            | Maximum number of epochs |
231 | | `output_dir` | FilePath | Path to directory where to save training images |
232 | | `models_dir` | FilePath | Path to directory where to save the model |
233 | | `data_dir` | FilePath | Path to data directory |
234 | | `seed` | int | Seed (-1 for random seed) |
235 | | `resume_checkpoint` | FilePath | Path to pretrained model |
236 |
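A note on `base_learning_rate`: in the original Taming Transformers code base, the learning rate actually applied is the base learning rate scaled by the batch size. Assuming the same convention holds here (check `scripts/train.py` to confirm), the effective learning rate for the default configuration would be:

```py
base_learning_rate = 4.5e-6  # from vqgan_custom.json
batch_size = 4               # from vqgan_custom.json

# Taming Transformers-style scaling (an assumption; verify against scripts/train.py).
effective_lr = base_learning_rate * batch_size
print(effective_lr)  # 1.8e-05
```
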
237 | ### Notes
238 |
239 | - Let the Generator train without the Discriminator for a few epochs (~3-5 epochs for ImageNet), then enable the Discriminator.
The variable `lossconfig.params.disc_start` corresponds to the number of global steps (i.e. batch iterations) to run before the Discriminator is enabled; see the sketch below this list for one way to pick it.
240 | - Once enabled, the Discriminator loss will stagnate around ~1.0, __this is normal behaviour__. The loss will decrease in later epochs (it can take a _very_ long time).
241 | - If you enable the Discriminator too soon, the Generator will take much longer to train.
242 | - There is no hard rule for the number of epochs: if your dataset is large enough, there is little risk of overfitting, so the longer you train, the better.
243 |
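For instance, to enable the Discriminator after roughly 4 epochs, `disc_start` can be derived from the dataset size and the batch size (the dataset size below is only an example value):

```py
import math

dataset_size = 25_000  # number of training images (example value)
batch_size = 4         # from the model configuration
warmup_epochs = 4      # let the Generator train alone for ~3-5 epochs

steps_per_epoch = math.ceil(dataset_size / batch_size)
disc_start = warmup_epochs * steps_per_epoch
print(disc_start)  # 25000, which happens to be the default in vqgan_custom.json
```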
244 |
245 | # Acknowledgments
246 |
247 | [VQGAN+CLIP](https://github.com/nerdyrodent/VQGAN-CLIP)
248 |
249 | [Taming Transformers](https://github.com/CompVis/taming-transformers)
250 |
251 | [CLIP](https://github.com/openai/CLIP)
252 |
253 | [DALLE-PyTorch](https://github.com/lucidrains/DALLE-pytorch)
254 |
255 | # Citations
256 |
257 | ```bibtex
258 | @misc{unpublished2021clip,
259 | title = {CLIP: Connecting Text and Images},
260 |     author = {Alec Radford and Ilya Sutskever and Jong Wook Kim and Gretchen Krueger and Sandhini Agarwal},
261 | year = {2021}
262 | }
263 | ```
264 |
265 | ```bibtex
266 | @misc{esser2020taming,
267 | title={Taming Transformers for High-Resolution Image Synthesis},
268 | author={Patrick Esser and Robin Rombach and Björn Ommer},
269 | year={2020},
270 | eprint={2012.09841},
271 | archivePrefix={arXiv},
272 | primaryClass={cs.CV}
273 | }
274 | ```
275 |
276 | ```bibtex
277 | @misc{ramesh2021zeroshot,
278 | title = {Zero-Shot Text-to-Image Generation},
279 | author = {Aditya Ramesh and Mikhail Pavlov and Gabriel Goh and Scott Gray and Chelsea Voss and Alec Radford and Mark Chen and Ilya Sutskever},
280 | year = {2021},
281 | eprint = {2102.12092},
282 | archivePrefix = {arXiv},
283 | primaryClass = {cs.CV}
284 | }
285 | ```
286 |
--------------------------------------------------------------------------------
/configs/docker.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompts": ["a painting of a potato"],
3 | "image_prompts": [],
4 | "max_iterations": 250,
5 | "save_freq": 50,
6 | "size": [256, 256],
7 | "init_image": "",
8 | "init_noise": "",
9 | "init_weight": 0.0,
10 | "mse_decay_rate": 0,
11 | "output_dir": "/outputs",
12 | "models_dir": "/models",
13 | "clip_model": "ViT-B/16",
14 | "vqgan_checkpoint": "/models/vqgan_imagenet_f16_16384.ckpt",
15 | "vqgan_config": "/configs/models/vqgan_imagenet_f16_16384.json",
16 | "noise_prompt_seeds": [],
17 | "noise_prompt_weights": [],
18 | "step_size": 0.1,
19 | "cutn": 32,
20 | "cut_pow": 1.0,
21 | "seed": -1,
22 | "optimizer": "Adam",
23 | "nwarm_restarts": -1,
24 | "augments": ["Af", "Pe", "Ji", "Er"]
25 | }
26 |
--------------------------------------------------------------------------------
/configs/local.json:
--------------------------------------------------------------------------------
1 | {
2 | "prompts": ["a painting of a potato"],
3 | "image_prompts": [],
4 | "max_iterations": 250,
5 | "save_freq": 50,
6 | "size": [256, 256],
7 | "init_image": "",
8 | "init_noise": "",
9 | "init_weight": 0.0,
10 | "mse_decay_rate": 0,
11 | "output_dir": "./outputs",
12 | "models_dir": "./models",
13 | "clip_model": "ViT-B/16",
14 | "vqgan_checkpoint": "./models/vqgan_imagenet_f16_16384.ckpt",
15 | "vqgan_config": "./configs/models/vqgan_imagenet_f16_16384.json",
16 | "noise_prompt_seeds": [],
17 | "noise_prompt_weights": [],
18 | "step_size": 0.1,
19 | "cutn": 32,
20 | "cut_pow": 1.0,
21 | "seed": -1,
22 | "optimizer": "Adam",
23 | "nwarm_restarts": -1,
24 | "augments": ["Af", "Pe", "Ji", "Er"]
25 | }
26 |
--------------------------------------------------------------------------------
/configs/models/vqgan_coco_f16_8192.json:
--------------------------------------------------------------------------------
1 | {
2 | "params": {
3 | "embed_dim": 256,
4 | "n_embed": 8192,
5 | "ddconfig": {
6 | "double_z": false,
7 | "z_channels": 256,
8 | "resolution": 256,
9 | "in_channels": 3,
10 | "out_ch": 3,
11 | "ch": 128,
12 | "ch_mult": [1, 1, 2, 2, 4],
13 | "num_res_blocks": 2,
14 | "attn_resolutions": [16],
15 | "dropout": 0.0
16 | }
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/configs/models/vqgan_custom.json:
--------------------------------------------------------------------------------
1 | {
2 | "base_learning_rate": 4.5e-6,
3 | "batch_size": 4,
4 | "epochs": 1000,
5 | "output_dir": "./outputs",
6 | "models_dir": "./models",
7 | "data_dir": "./data",
8 | "seed": -1,
9 | "resume_checkpoint": "",
10 | "params": {
11 | "embed_dim": 256,
12 | "n_embed": 1024,
13 | "ddconfig": {
14 | "double_z": false,
15 | "z_channels": 256,
16 | "resolution": 256,
17 | "in_channels": 3,
18 | "out_ch": 3,
19 | "ch": 128,
20 | "ch_mult": [1, 1, 2, 2, 4],
21 | "num_res_blocks": 2,
22 | "attn_resolutions": [16],
23 | "dropout": 0.0
24 | },
25 | "lossconfig": {
26 | "params": {
27 | "disc_conditional": false,
28 | "disc_in_channels": 3,
29 | "disc_start": 25000,
30 | "disc_weight": 0.8,
31 | "codebook_weight": 1.0
32 | }
33 | }
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/configs/models/vqgan_custom_docker.json:
--------------------------------------------------------------------------------
1 | {
2 | "base_learning_rate": 4.5e-6,
3 | "batch_size": 4,
4 | "epochs": 1000,
5 | "output_dir": "/outputs",
6 | "models_dir": "/models",
7 | "data_dir": "/data",
8 | "seed": -1,
9 | "resume_checkpoint": "",
10 | "params": {
11 | "embed_dim": 256,
12 | "n_embed": 1024,
13 | "ddconfig": {
14 | "double_z": false,
15 | "z_channels": 256,
16 | "resolution": 256,
17 | "in_channels": 3,
18 | "out_ch": 3,
19 | "ch": 128,
20 | "ch_mult": [1, 1, 2, 2, 4],
21 | "num_res_blocks": 2,
22 | "attn_resolutions": [16],
23 | "dropout": 0.0
24 | },
25 | "lossconfig": {
26 | "params": {
27 | "disc_conditional": false,
28 | "disc_in_channels": 3,
29 | "disc_start": 25000,
30 | "disc_weight": 0.8,
31 | "codebook_weight": 1.0
32 | }
33 | }
34 | }
35 | }
36 |
--------------------------------------------------------------------------------
/configs/models/vqgan_faceshq_f16_1024.json:
--------------------------------------------------------------------------------
1 | {
2 | "params": {
3 | "embed_dim": 256,
4 | "n_embed": 1024,
5 | "ddconfig": {
6 | "double_z": false,
7 | "z_channels": 256,
8 | "resolution": 256,
9 | "in_channels": 3,
10 | "out_ch": 3,
11 | "ch": 128,
12 | "ch_mult": [1, 1, 2, 2, 4],
13 | "num_res_blocks": 2,
14 | "attn_resolutions": [16],
15 | "dropout": 0.0
16 | }
17 | }
18 | }
19 |
--------------------------------------------------------------------------------
/configs/models/vqgan_imagenet_f16_1024.json:
--------------------------------------------------------------------------------
1 | {
2 | "params": {
3 | "embed_dim": 256,
4 | "n_embed": 1024,
5 | "ddconfig": {
6 | "double_z": false,
7 | "z_channels": 256,
8 | "resolution": 256,
9 | "in_channels": 3,
10 | "out_ch": 3,
11 | "ch": 128,
12 | "ch_mult": [1, 1, 2, 2, 4],
13 | "num_res_blocks": 2,
14 | "attn_resolutions": [16],
15 | "dropout": 0.0
16 | },
17 | "lossconfig": {
18 | "params": {
19 | "disc_conditional": false,
20 | "disc_in_channels": 3,
21 | "disc_start": 0,
22 | "disc_weight": 0.8,
23 | "codebook_weight": 1.0
24 | }
25 | }
26 | }
27 | }
28 |
--------------------------------------------------------------------------------
/configs/models/vqgan_imagenet_f16_16384.json:
--------------------------------------------------------------------------------
1 | {
2 | "params": {
3 | "embed_dim": 256,
4 | "n_embed": 16384,
5 | "ddconfig": {
6 | "double_z": false,
7 | "z_channels": 256,
8 | "resolution": 256,
9 | "in_channels": 3,
10 | "out_ch": 3,
11 | "ch": 128,
12 | "ch_mult": [1, 1, 2, 2, 4],
13 | "num_res_blocks": 2,
14 | "attn_resolutions": [16],
15 | "dropout": 0.0
16 | },
17 | "lossconfig": {
18 | "params": {
19 | "disc_conditional": false,
20 | "disc_in_channels": 3,
21 | "disc_start": 0,
22 | "disc_weight": 0.75,
23 | "disc_num_layers": 2,
24 | "codebook_weight": 1.0
25 | }
26 | }
27 | }
28 | }
29 |
--------------------------------------------------------------------------------
/core/clip/README.md:
--------------------------------------------------------------------------------
1 | # CLIP
2 |
3 | [[Original]](https://github.com/openai/CLIP)
4 |
5 | ## About
6 |
7 | A stripped & minimalist version of the original project.
8 |
--------------------------------------------------------------------------------
/core/clip/__init__.py:
--------------------------------------------------------------------------------
1 | from core.clip import *
2 |
--------------------------------------------------------------------------------
/core/clip/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/core/clip/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/core/clip/clip.py:
--------------------------------------------------------------------------------
1 | import hashlib
2 | import os
3 | import sys
4 | import urllib
5 | import warnings
6 | from typing import Any, Union, List
7 |
8 | import torch
9 | from PIL import Image
10 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
11 | from tqdm import tqdm
12 |
13 | from core.clip.model import build_model
14 | from core.clip.simple_tokenizer import SimpleTokenizer as _Tokenizer
15 |
16 | try:
17 | from torchvision.transforms import InterpolationMode
18 | BICUBIC = InterpolationMode.BICUBIC
19 | except ImportError:
20 | BICUBIC = Image.BICUBIC
21 |
22 |
23 | if tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2]) < (1, 7):  # numeric compare, not lexicographic
24 |     warnings.warn("PyTorch version 1.7.1 or higher is recommended")
25 |
26 |
27 | __all__ = ["available_models", "load", "tokenize"]
28 | _tokenizer = _Tokenizer()
29 |
30 | _MODELS = {
31 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
32 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
33 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
34 | "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
35 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
36 | "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
37 | }
38 |
39 |
40 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
41 | os.makedirs(root, exist_ok=True)
42 | filename = os.path.basename(url)
43 |
44 | expected_sha256 = url.split("/")[-2]
45 | download_target = os.path.join(root, filename)
46 |
47 | if os.path.exists(download_target) and not os.path.isfile(download_target):
48 | raise RuntimeError(f"{download_target} exists and is not a regular file")
49 |
50 | if os.path.isfile(download_target):
51 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
52 | return download_target
53 | else:
54 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
55 |
56 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, download_target))
57 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
58 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
59 | while True:
60 | buffer = source.read(8192)
61 | if not buffer:
62 | break
63 |
64 | output.write(buffer)
65 | loop.update(len(buffer))
66 |
67 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
68 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
69 |
70 | return download_target
71 |
72 |
73 | def _transform(n_px):
74 | return Compose([
75 | Resize(n_px, interpolation=BICUBIC),
76 | CenterCrop(n_px),
77 | lambda image: image.convert("RGB"),
78 | ToTensor(),
79 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
80 | ])
81 |
82 |
83 | def available_models() -> List[str]:
84 | """Returns the names of available CLIP models"""
85 | return list(_MODELS.keys())
86 |
87 |
88 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, **kwargs: Any):
89 | """Load a CLIP model
90 |
91 | Parameters
92 | ----------
93 | name : str
94 | A model name listed by `clip.available_models()`, or the path to a model checkpoint
95 | containing the state_dict
96 |
97 | device : Union[str, torch.device]
98 | The device to put the loaded model
99 |
100 | jit : bool
101 | Whether to load the optimized JIT model or more hackable non-JIT model (default).
102 |
103 | **kwargs (optional): Any
104 | The corresponding kwargs for _download function
105 |
106 | Returns
107 | -------
108 | model : torch.nn.Module
109 | The CLIP model
110 |
111 | preprocess : Callable[[PIL.Image], torch.Tensor]
112 | A torchvision transform that converts a PIL image into a tensor that the returned model can
113 | take as its input
114 | """
115 | if name in _MODELS:
116 | model_path = _download(_MODELS[name], **kwargs)
117 | elif os.path.isfile(name):
118 | model_path = name
119 | else:
120 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
121 |
122 | try:
123 | # loading JIT archive
124 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
125 | state_dict = None
126 | except RuntimeError:
127 | # loading saved state dict
128 | if jit:
129 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
130 | jit = False
131 | state_dict = torch.load(model_path, map_location="cpu")
132 |
133 | if not jit:
134 | model = build_model(state_dict or model.state_dict()).to(device)
135 | if str(device) == "cpu":
136 | model.float()
137 | return model, _transform(model.visual.input_resolution)
138 |
139 | # patch the device names
140 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
141 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
142 |
143 | def patch_device(module):
144 | try:
145 | graphs = [module.graph] if hasattr(module, "graph") else []
146 | except RuntimeError:
147 | graphs = []
148 |
149 | if hasattr(module, "forward1"):
150 | graphs.append(module.forward1.graph)
151 |
152 | for graph in graphs:
153 | for node in graph.findAllNodes("prim::Constant"):
154 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
155 | node.copyAttributes(device_node)
156 |
157 | model.apply(patch_device)
158 | patch_device(model.encode_image)
159 | patch_device(model.encode_text)
160 |
161 | # patch dtype to float32 on CPU
162 | if str(device) == "cpu":
163 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
164 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
165 | float_node = float_input.node()
166 |
167 | def patch_float(module):
168 | try:
169 | graphs = [module.graph] if hasattr(module, "graph") else []
170 | except RuntimeError:
171 | graphs = []
172 |
173 | if hasattr(module, "forward1"):
174 | graphs.append(module.forward1.graph)
175 |
176 | for graph in graphs:
177 | for node in graph.findAllNodes("aten::to"):
178 | inputs = list(node.inputs())
179 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
180 | if inputs[i].node()["value"] == 5:
181 | inputs[i].node().copyAttributes(float_node)
182 |
183 | model.apply(patch_float)
184 | patch_float(model.encode_image)
185 | patch_float(model.encode_text)
186 |
187 | model.float()
188 |
189 | return model, _transform(model.input_resolution.item())
190 |
191 |
192 | def tokenize(texts: Union[str, List[str]], context_length: int = 77, truncate: bool = False) -> torch.LongTensor:
193 | """
194 | Returns the tokenized representation of given input string(s)
195 |
196 | Parameters
197 | ----------
198 | texts : Union[str, List[str]]
199 | An input string or a list of input strings to tokenize
200 |
201 | context_length : int
202 | The context length to use; all CLIP models use 77 as the context length
203 |
204 | truncate: bool
205 | Whether to truncate the text in case its encoding is longer than the context length
206 |
207 | Returns
208 | -------
209 | A two-dimensional tensor containing the resulting tokens,
210 | shape = [number of input strings, context_length]
211 | """
212 | if isinstance(texts, str):
213 | texts = [texts]
214 |
215 | sot_token = _tokenizer.encoder["<|startoftext|>"]
216 | eot_token = _tokenizer.encoder["<|endoftext|>"]
217 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
218 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
219 |
220 | for i, tokens in enumerate(all_tokens):
221 | if len(tokens) > context_length:
222 | if truncate:
223 | tokens = tokens[:context_length]
224 | tokens[-1] = eot_token
225 | else:
226 | raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
227 | result[i, :len(tokens)] = torch.tensor(tokens)
228 |
229 | return result
230 |
--------------------------------------------------------------------------------
/core/clip/model.py:
--------------------------------------------------------------------------------
1 | from collections import OrderedDict
2 | from typing import Tuple, Union
3 |
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | from torch import nn
8 |
9 |
10 | class Bottleneck(nn.Module):
11 | expansion = 4
12 |
13 | def __init__(self, inplanes, planes, stride=1):
14 | super().__init__()
15 |
16 | # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
17 | self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
18 | self.bn1 = nn.BatchNorm2d(planes)
19 |
20 | self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
21 | self.bn2 = nn.BatchNorm2d(planes)
22 |
23 | self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
24 |
25 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
26 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
27 |
28 | self.relu = nn.ReLU(inplace=True)
29 | self.downsample = None
30 | self.stride = stride
31 |
32 | if stride > 1 or inplanes != planes * Bottleneck.expansion:
33 | # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
34 | self.downsample = nn.Sequential(OrderedDict([
35 | ("-1", nn.AvgPool2d(stride)),
36 | ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
37 | ("1", nn.BatchNorm2d(planes * self.expansion))
38 | ]))
39 |
40 | def forward(self, x: torch.Tensor):
41 | identity = x
42 |
43 | out = self.relu(self.bn1(self.conv1(x)))
44 | out = self.relu(self.bn2(self.conv2(out)))
45 | out = self.avgpool(out)
46 | out = self.bn3(self.conv3(out))
47 |
48 | if self.downsample is not None:
49 | identity = self.downsample(x)
50 |
51 | out += identity
52 | out = self.relu(out)
53 | return out
54 |
55 |
56 | class AttentionPool2d(nn.Module):
57 | def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
58 | super().__init__()
59 | self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
60 | self.k_proj = nn.Linear(embed_dim, embed_dim)
61 | self.q_proj = nn.Linear(embed_dim, embed_dim)
62 | self.v_proj = nn.Linear(embed_dim, embed_dim)
63 | self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
64 | self.num_heads = num_heads
65 |
66 | def forward(self, x):
67 | x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
68 | x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
69 | x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
70 | x, _ = F.multi_head_attention_forward(
71 | query=x, key=x, value=x,
72 | embed_dim_to_check=x.shape[-1],
73 | num_heads=self.num_heads,
74 | q_proj_weight=self.q_proj.weight,
75 | k_proj_weight=self.k_proj.weight,
76 | v_proj_weight=self.v_proj.weight,
77 | in_proj_weight=None,
78 | in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
79 | bias_k=None,
80 | bias_v=None,
81 | add_zero_attn=False,
82 | dropout_p=0,
83 | out_proj_weight=self.c_proj.weight,
84 | out_proj_bias=self.c_proj.bias,
85 | use_separate_proj_weight=True,
86 | training=self.training,
87 | need_weights=False
88 | )
89 |
90 | return x[0]
91 |
92 |
93 | class ModifiedResNet(nn.Module):
94 | """
95 | A ResNet class that is similar to torchvision's but contains the following changes:
96 | - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
97 | - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
98 | - The final pooling layer is a QKV attention instead of an average pool
99 | """
100 |
101 | def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
102 | super().__init__()
103 | self.output_dim = output_dim
104 | self.input_resolution = input_resolution
105 |
106 | # the 3-layer stem
107 | self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
108 | self.bn1 = nn.BatchNorm2d(width // 2)
109 | self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
110 | self.bn2 = nn.BatchNorm2d(width // 2)
111 | self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
112 | self.bn3 = nn.BatchNorm2d(width)
113 | self.avgpool = nn.AvgPool2d(2)
114 | self.relu = nn.ReLU(inplace=True)
115 |
116 | # residual layers
117 | self._inplanes = width # this is a *mutable* variable used during construction
118 | self.layer1 = self._make_layer(width, layers[0])
119 | self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
120 | self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
121 | self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
122 |
123 | embed_dim = width * 32 # the ResNet feature dimension
124 | self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
125 |
126 | def _make_layer(self, planes, blocks, stride=1):
127 | layers = [Bottleneck(self._inplanes, planes, stride)]
128 |
129 | self._inplanes = planes * Bottleneck.expansion
130 | for _ in range(1, blocks):
131 | layers.append(Bottleneck(self._inplanes, planes))
132 |
133 | return nn.Sequential(*layers)
134 |
135 | def forward(self, x):
136 | def stem(x):
137 | for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
138 | x = self.relu(bn(conv(x)))
139 | x = self.avgpool(x)
140 | return x
141 |
142 | x = x.type(self.conv1.weight.dtype)
143 | x = stem(x)
144 | x = self.layer1(x)
145 | x = self.layer2(x)
146 | x = self.layer3(x)
147 | x = self.layer4(x)
148 | x = self.attnpool(x)
149 |
150 | return x
151 |
152 |
153 | class LayerNorm(nn.LayerNorm):
154 | """Subclass torch's LayerNorm to handle fp16."""
155 |
156 | def forward(self, x: torch.Tensor):
157 | orig_type = x.dtype
158 | ret = super().forward(x.type(torch.float32))
159 | return ret.type(orig_type)
160 |
161 |
162 | class QuickGELU(nn.Module):
163 | def forward(self, x: torch.Tensor):
164 | return x * torch.sigmoid(1.702 * x)
165 |
166 |
167 | class ResidualAttentionBlock(nn.Module):
168 | def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
169 | super().__init__()
170 |
171 | self.attn = nn.MultiheadAttention(d_model, n_head)
172 | self.ln_1 = LayerNorm(d_model)
173 | self.mlp = nn.Sequential(OrderedDict([
174 | ("c_fc", nn.Linear(d_model, d_model * 4)),
175 | ("gelu", QuickGELU()),
176 | ("c_proj", nn.Linear(d_model * 4, d_model))
177 | ]))
178 | self.ln_2 = LayerNorm(d_model)
179 | self.attn_mask = attn_mask
180 |
181 | def attention(self, x: torch.Tensor):
182 | self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
183 | return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
184 |
185 | def forward(self, x: torch.Tensor):
186 | x = x + self.attention(self.ln_1(x))
187 | x = x + self.mlp(self.ln_2(x))
188 | return x
189 |
190 |
191 | class Transformer(nn.Module):
192 | def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
193 | super().__init__()
194 | self.width = width
195 | self.layers = layers
196 | self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
197 |
198 | def forward(self, x: torch.Tensor):
199 | return self.resblocks(x)
200 |
201 |
202 | class VisionTransformer(nn.Module):
203 | def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
204 | super().__init__()
205 | self.input_resolution = input_resolution
206 | self.output_dim = output_dim
207 | self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
208 |
209 | scale = width ** -0.5
210 | self.class_embedding = nn.Parameter(scale * torch.randn(width))
211 | self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
212 | self.ln_pre = LayerNorm(width)
213 |
214 | self.transformer = Transformer(width, layers, heads)
215 |
216 | self.ln_post = LayerNorm(width)
217 | self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
218 |
219 | def forward(self, x: torch.Tensor):
220 | x = self.conv1(x) # shape = [*, width, grid, grid]
221 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
222 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
223 | x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
224 | x = x + self.positional_embedding.to(x.dtype)
225 | x = self.ln_pre(x)
226 |
227 | x = x.permute(1, 0, 2) # NLD -> LND
228 | x = self.transformer(x)
229 | x = x.permute(1, 0, 2) # LND -> NLD
230 |
231 | x = self.ln_post(x[:, 0, :])
232 |
233 | if self.proj is not None:
234 | x = x @ self.proj
235 |
236 | return x
237 |
238 |
239 | class CLIP(nn.Module):
240 | def __init__(self,
241 | embed_dim: int,
242 | # vision
243 | image_resolution: int,
244 | vision_layers: Union[Tuple[int, int, int, int], int],
245 | vision_width: int,
246 | vision_patch_size: int,
247 | # text
248 | context_length: int,
249 | vocab_size: int,
250 | transformer_width: int,
251 | transformer_heads: int,
252 | transformer_layers: int
253 | ):
254 | super().__init__()
255 |
256 | self.context_length = context_length
257 |
258 | if isinstance(vision_layers, (tuple, list)):
259 | vision_heads = vision_width * 32 // 64
260 | self.visual = ModifiedResNet(
261 | layers=vision_layers,
262 | output_dim=embed_dim,
263 | heads=vision_heads,
264 | input_resolution=image_resolution,
265 | width=vision_width
266 | )
267 | else:
268 | vision_heads = vision_width // 64
269 | self.visual = VisionTransformer(
270 | input_resolution=image_resolution,
271 | patch_size=vision_patch_size,
272 | width=vision_width,
273 | layers=vision_layers,
274 | heads=vision_heads,
275 | output_dim=embed_dim
276 | )
277 |
278 | self.transformer = Transformer(
279 | width=transformer_width,
280 | layers=transformer_layers,
281 | heads=transformer_heads,
282 | attn_mask=self.build_attention_mask()
283 | )
284 |
285 | self.vocab_size = vocab_size
286 | self.token_embedding = nn.Embedding(vocab_size, transformer_width)
287 | self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
288 | self.ln_final = LayerNorm(transformer_width)
289 |
290 | self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
291 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
292 |
293 | self.initialize_parameters()
294 |
295 | def initialize_parameters(self):
296 | nn.init.normal_(self.token_embedding.weight, std=0.02)
297 | nn.init.normal_(self.positional_embedding, std=0.01)
298 |
299 | if isinstance(self.visual, ModifiedResNet):
300 | if self.visual.attnpool is not None:
301 | std = self.visual.attnpool.c_proj.in_features ** -0.5
302 | nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
303 | nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
304 | nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
305 | nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
306 |
307 | for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
308 | for name, param in resnet_block.named_parameters():
309 | if name.endswith("bn3.weight"):
310 | nn.init.zeros_(param)
311 |
312 | proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
313 | attn_std = self.transformer.width ** -0.5
314 | fc_std = (2 * self.transformer.width) ** -0.5
315 | for block in self.transformer.resblocks:
316 | nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
317 | nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
318 | nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
319 | nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
320 |
321 | if self.text_projection is not None:
322 | nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
323 |
324 | def build_attention_mask(self):
325 | # lazily create causal attention mask, with full attention between the vision tokens
326 | # pytorch uses additive attention mask; fill with -inf
327 | mask = torch.empty(self.context_length, self.context_length)
328 | mask.fill_(float("-inf"))
329 | mask.triu_(1) # zero out the lower diagonal
330 | return mask
331 |
332 | @property
333 | def dtype(self):
334 | return self.visual.conv1.weight.dtype
335 |
336 | def encode_image(self, image):
337 | return self.visual(image.type(self.dtype))
338 |
339 | def encode_text(self, text):
340 | x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
341 |
342 | x = x + self.positional_embedding.type(self.dtype)
343 | x = x.permute(1, 0, 2) # NLD -> LND
344 | x = self.transformer(x)
345 | x = x.permute(1, 0, 2) # LND -> NLD
346 | x = self.ln_final(x).type(self.dtype)
347 |
348 | # x.shape = [batch_size, n_ctx, transformer.width]
349 | # take features from the eot embedding (eot_token is the highest number in each sequence)
350 | x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
351 |
352 | return x
353 |
354 | def forward(self, image, text):
355 | image_features = self.encode_image(image)
356 | text_features = self.encode_text(text)
357 |
358 | # normalized features
359 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
360 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
361 |
362 | # cosine similarity as logits
363 | logit_scale = self.logit_scale.exp()
364 | logits_per_image = logit_scale * image_features @ text_features.t()
365 | logits_per_text = logit_scale * text_features @ image_features.t()
366 |
367 | # shape = [global_batch_size, global_batch_size]
368 | return logits_per_image, logits_per_text
369 |
370 |
371 | def convert_weights(model: nn.Module):
372 | """Convert applicable model parameters to fp16"""
373 |
374 | def _convert_weights_to_fp16(l):
375 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
376 | l.weight.data = l.weight.data.half()
377 | if l.bias is not None:
378 | l.bias.data = l.bias.data.half()
379 |
380 | if isinstance(l, nn.MultiheadAttention):
381 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
382 | tensor = getattr(l, attr)
383 | if tensor is not None:
384 | tensor.data = tensor.data.half()
385 |
386 | for name in ["text_projection", "proj"]:
387 | if hasattr(l, name):
388 | attr = getattr(l, name)
389 | if attr is not None:
390 | attr.data = attr.data.half()
391 |
392 | model.apply(_convert_weights_to_fp16)
393 |
394 |
395 | def build_model(state_dict: dict):
396 | vit = "visual.proj" in state_dict
397 |
398 | if vit:
399 | vision_width = state_dict["visual.conv1.weight"].shape[0]
400 | vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
401 | vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
402 | grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
403 | image_resolution = vision_patch_size * grid_size
404 | else:
405 | counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
406 | vision_layers = tuple(counts)
407 | vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
408 | output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
409 | vision_patch_size = None
410 | assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
411 | image_resolution = output_width * 32
412 |
413 | embed_dim = state_dict["text_projection"].shape[1]
414 | context_length = state_dict["positional_embedding"].shape[0]
415 | vocab_size = state_dict["token_embedding.weight"].shape[0]
416 | transformer_width = state_dict["ln_final.weight"].shape[0]
417 | transformer_heads = transformer_width // 64
418 | transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
419 |
420 | model = CLIP(
421 | embed_dim,
422 | image_resolution, vision_layers, vision_width, vision_patch_size,
423 | context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
424 | )
425 |
426 | for key in ["input_resolution", "context_length", "vocab_size"]:
427 | if key in state_dict:
428 | del state_dict[key]
429 |
430 | convert_weights(model)
431 | model.load_state_dict(state_dict)
432 | return model.eval()
433 |
--------------------------------------------------------------------------------
/core/clip/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import html
3 | import os
4 | from functools import lru_cache
5 |
6 | import ftfy
7 | import regex as re
8 |
9 |
10 | @lru_cache()
11 | def default_bpe():
12 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
13 |
14 |
15 | @lru_cache()
16 | def bytes_to_unicode():
17 | """
18 | Returns list of utf-8 byte and a corresponding list of unicode strings.
19 | The reversible bpe codes work on unicode strings.
20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
22 | This is a signficant percentage of your normal, say, 32K bpe vocab.
23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
24 | And avoids mapping to whitespace/control characters the bpe code barfs on.
25 | """
26 | bs = list(range(ord("!"), ord("~") + 1)) + list(range(ord("¡"), ord("¬") + 1)) + list(range(ord("®"), ord("ÿ") + 1))
27 | cs = bs[:]
28 | n = 0
29 | for b in range(2**8):
30 | if b not in bs:
31 | bs.append(b)
32 | cs.append(2**8 + n)
33 | n += 1
34 | cs = [chr(n) for n in cs]
35 | return dict(zip(bs, cs))
36 |
37 |
38 | def get_pairs(word):
39 | """Return set of symbol pairs in a word.
40 | Word is represented as tuple of symbols (symbols being variable-length strings).
41 | """
42 | pairs = set()
43 | prev_char = word[0]
44 | for char in word[1:]:
45 | pairs.add((prev_char, char))
46 | prev_char = char
47 | return pairs
48 |
49 |
50 | def basic_clean(text):
51 | text = ftfy.fix_text(text)
52 | text = html.unescape(html.unescape(text))
53 | return text.strip()
54 |
55 |
56 | def whitespace_clean(text):
57 | text = re.sub(r'\s+', ' ', text)
58 | text = text.strip()
59 | return text
60 |
61 |
62 | class SimpleTokenizer(object):
63 | def __init__(self, bpe_path: str = default_bpe()):
64 | self.byte_encoder = bytes_to_unicode()
65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
67 | merges = merges[1: 49152 - 256 - 2 + 1]
68 | merges = [tuple(merge.split()) for merge in merges]
69 | vocab = list(bytes_to_unicode().values())
70 |         vocab = vocab + [v + '</w>' for v in vocab]
71 | for merge in merges:
72 | vocab.append(''.join(merge))
73 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
74 | self.encoder = dict(zip(vocab, range(len(vocab))))
75 | self.decoder = {v: k for k, v in self.encoder.items()}
76 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
79 |
80 | def bpe(self, token):
81 | if token in self.cache:
82 | return self.cache[token]
83 |         word = tuple(token[:-1]) + (token[-1] + '</w>',)
84 | pairs = get_pairs(word)
85 |
86 | if not pairs:
87 |             return token + '</w>'
88 |
89 | while True:
90 | bigram = min(pairs, key=lambda pair: self.bpe_ranks.get(pair, float('inf')))
91 | if bigram not in self.bpe_ranks:
92 | break
93 | first, second = bigram
94 | new_word = []
95 | i = 0
96 | while i < len(word):
97 | try:
98 | j = word.index(first, i)
99 | new_word.extend(word[i:j])
100 | i = j
101 | except Exception:
102 | new_word.extend(word[i:])
103 | break
104 |
105 | if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
106 | new_word.append(first + second)
107 | i += 2
108 | else:
109 | new_word.append(word[i])
110 | i += 1
111 | new_word = tuple(new_word)
112 | word = new_word
113 | if len(word) == 1:
114 | break
115 | else:
116 | pairs = get_pairs(word)
117 | word = ' '.join(word)
118 | self.cache[token] = word
119 | return word
120 |
121 | def encode(self, text):
122 | bpe_tokens = []
123 | text = whitespace_clean(basic_clean(text)).lower()
124 | for token in re.findall(self.pat, text):
125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
127 | return bpe_tokens
128 |
129 | def decode(self, tokens):
130 | text = ''.join([self.decoder[token] for token in tokens])
131 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
132 | return text
133 |
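# Illustrative usage (assumes the default BPE vocab file shipped next to this module):
#   tokenizer = SimpleTokenizer()
#   ids = tokenizer.encode("a painting of a potato")  # list of BPE token ids
#   text = tokenizer.decode(ids)                      # roughly recovers the input text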
--------------------------------------------------------------------------------
/core/optimizer/__init__.py:
--------------------------------------------------------------------------------
1 | from core.optimizer.adamp import AdamP
2 | from core.optimizer.diffgrad import DiffGrad
3 | from core.optimizer.radam import RAdam
4 |
5 | __all__ = [
6 |     "AdamP",
7 |     "DiffGrad",
8 |     "RAdam",
9 | ]
10 |
--------------------------------------------------------------------------------
/core/optimizer/adamp.py:
--------------------------------------------------------------------------------
1 | # https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/adamp.py
2 | import math
3 |
4 | import torch
5 | from torch.optim.optimizer import Optimizer
6 |
7 |
8 | class AdamP(Optimizer):
9 | r"""Implements AdamP algorithm.
10 |
11 | It has been proposed in `Slowing Down the Weight Norm Increase in
12 | Momentum-based Optimizers`__
13 |
14 | Arguments:
15 | params: iterable of parameters to optimize or dicts defining
16 | parameter groups
17 | lr: learning rate (default: 1e-3)
18 | betas: coefficients used for computing
19 | running averages of gradient and its square (default: (0.9, 0.999))
20 | eps: term added to the denominator to improve
21 | numerical stability (default: 1e-8)
22 | weight_decay: weight decay (L2 penalty) (default: 0)
23 |         delta: threshold that determines whether a set of parameters is scale
24 | invariant or not (default: 0.1)
25 | wd_ratio: relative weight decay applied on scale-invariant parameters
26 | compared to that applied on scale-variant parameters (default: 0.1)
27 | nesterov: enables Nesterov momentum (default: False)
28 |
29 |
30 | Example:
31 |         >>> from core.optimizer import AdamP
32 |         >>> optimizer = AdamP(model.parameters(), lr=0.1)
33 | >>> optimizer.zero_grad()
34 | >>> loss_fn(model(input), target).backward()
35 | >>> optimizer.step()
36 |
37 | __ https://arxiv.org/abs/2006.08217
38 |
39 | Note:
40 | Reference code: https://github.com/clovaai/AdamP
41 | """
42 |
43 | def __init__(
44 | self,
45 | params,
46 | lr: float = 1e-3,
47 | betas=(0.9, 0.999),
48 | eps: float = 1e-8,
49 | weight_decay: float = 0,
50 | delta: float = 0.1,
51 | wd_ratio: float = 0.1,
52 | nesterov: bool = False,
53 | ) -> None:
54 | if lr <= 0.0:
55 | raise ValueError('Invalid learning rate: {}'.format(lr))
56 | if eps < 0.0:
57 | raise ValueError('Invalid epsilon value: {}'.format(eps))
58 | if not 0.0 <= betas[0] < 1.0:
59 | raise ValueError(
60 | 'Invalid beta parameter at index 0: {}'.format(betas[0])
61 | )
62 | if not 0.0 <= betas[1] < 1.0:
63 | raise ValueError(
64 | 'Invalid beta parameter at index 1: {}'.format(betas[1])
65 | )
66 | if weight_decay < 0:
67 | raise ValueError(
68 | 'Invalid weight_decay value: {}'.format(weight_decay)
69 | )
70 | if delta < 0:
71 | raise ValueError('Invalid delta value: {}'.format(delta))
72 | if wd_ratio < 0:
73 | raise ValueError('Invalid wd_ratio value: {}'.format(wd_ratio))
74 |
75 | defaults = dict(
76 | lr=lr,
77 | betas=betas,
78 | eps=eps,
79 | weight_decay=weight_decay,
80 | delta=delta,
81 | wd_ratio=wd_ratio,
82 | nesterov=nesterov,
83 | )
84 | super(AdamP, self).__init__(params, defaults)
85 |
86 | @staticmethod
87 | def _channel_view(x):
88 | return x.view(x.size(0), -1)
89 |
90 | @staticmethod
91 | def _layer_view(x):
92 | return x.view(1, -1)
93 |
94 | @staticmethod
95 | def _cosine_similarity(x, y, eps, view_func):
96 | x = view_func(x)
97 | y = view_func(y)
98 |
99 | x_norm = x.norm(dim=1).add_(eps)
100 | y_norm = y.norm(dim=1).add_(eps)
101 | dot = (x * y).sum(dim=1)
102 |
103 | return dot.abs() / x_norm / y_norm
104 |
105 | def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
106 | wd = 1
107 | expand_size = [-1] + [1] * (len(p.shape) - 1)
108 | for view_func in [self._channel_view, self._layer_view]:
109 |
110 | cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
111 |
112 | if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
113 | p_n = p.data / view_func(p.data).norm(dim=1).view(
114 | expand_size
115 | ).add_(eps)
116 | perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(
117 | expand_size
118 | )
119 | wd = wd_ratio
120 |
121 | return perturb, wd
122 |
123 | return perturb, wd
124 |
125 | def step(self, closure=None):
126 | r"""Performs a single optimization step.
127 |
128 | Arguments:
129 | closure: A closure that reevaluates the model and returns the loss.
130 | """
131 | loss = None
132 | if closure is not None:
133 | loss = closure()
134 |
135 | for group in self.param_groups:
136 | for p in group['params']:
137 | if p.grad is None:
138 | continue
139 |
140 | grad = p.grad.data
141 | beta1, beta2 = group['betas']
142 | nesterov = group['nesterov']
143 |
144 | state = self.state[p]
145 |
146 | # State initialization
147 | if len(state) == 0:
148 | state['step'] = 0
149 | state['exp_avg'] = torch.zeros_like(
150 | p.data, memory_format=torch.preserve_format
151 | )
152 | state['exp_avg_sq'] = torch.zeros_like(
153 | p.data, memory_format=torch.preserve_format
154 | )
155 |
156 | # Adam
157 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
158 |
159 | state['step'] += 1
160 | bias_correction1 = 1 - beta1 ** state['step']
161 | bias_correction2 = 1 - beta2 ** state['step']
162 |
163 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
164 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
165 |
166 | denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
167 | group['eps']
168 | )
169 | step_size = group['lr'] / bias_correction1
170 |
171 | if nesterov:
172 | perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
173 | else:
174 | perturb = exp_avg / denom
175 |
176 | # Projection
177 | wd_ratio = 1
178 | if len(p.shape) > 1:
179 | perturb, wd_ratio = self._projection(
180 | p,
181 | grad,
182 | perturb,
183 | group['delta'],
184 | group['wd_ratio'],
185 | group['eps'],
186 | )
187 |
188 | # Weight decay
189 | if group['weight_decay'] > 0:
190 | p.data.mul_(
191 | 1 - group['lr'] * group['weight_decay'] * wd_ratio
192 | )
193 |
194 | # Step
195 | p.data.add_(perturb, alpha=-step_size)
196 |
197 | return loss
198 |
--------------------------------------------------------------------------------
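A minimal usage sketch (illustrative, not part of the repository) that drives this AdamP implementation directly, assuming the repository root is on PYTHONPATH so `core` is importable:

import torch
from core.optimizer.adamp import AdamP

# Fit a tiny linear model; the projection branch only applies to parameters
# with more than one dimension (here, the Linear weight matrix).
model = torch.nn.Linear(8, 1)
optimizer = AdamP(model.parameters(), lr=1e-3, weight_decay=1e-2, nesterov=True)

x, y = torch.randn(32, 8), torch.randn(32, 1)
for _ in range(10):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()

--------------------------------------------------------------------------------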
/core/optimizer/diffgrad.py:
--------------------------------------------------------------------------------
1 | # https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/diffgrad.py
2 | import math
3 |
4 | import torch
5 | from torch.optim.optimizer import Optimizer
6 |
7 |
8 | class DiffGrad(Optimizer):
9 | r"""Implements DiffGrad algorithm.
10 |
11 | It has been proposed in `DiffGrad: An Optimization Method for
12 | Convolutional Neural Networks`__.
13 |
14 | Arguments:
15 | params: iterable of parameters to optimize or dicts defining
16 | parameter groups
17 | lr: learning rate (default: 1e-3)
18 | betas: coefficients used for computing
19 | running averages of gradient and its square (default: (0.9, 0.999))
20 | eps: term added to the denominator to improve
21 | numerical stability (default: 1e-8)
22 | weight_decay: weight decay (L2 penalty) (default: 0)
23 |
24 | Example:
25 | >>> import torch_optimizer as optim
26 | >>> optimizer = optim.DiffGrad(model.parameters(), lr=0.1)
27 | >>> optimizer.zero_grad()
28 | >>> loss_fn(model(input), target).backward()
29 | >>> optimizer.step()
30 |
31 | __ https://arxiv.org/abs/1909.11015
32 |
33 | Note:
34 | Reference code: https://github.com/shivram1987/diffGrad
35 | """
36 |
37 | def __init__(
38 | self,
39 | params,
40 | lr: float = 1e-3,
41 | betas=(0.9, 0.999),
42 | eps: float = 1e-8,
43 | weight_decay: float = 0.0,
44 | ) -> None:
45 | if lr <= 0.0:
46 | raise ValueError('Invalid learning rate: {}'.format(lr))
47 | if eps < 0.0:
48 | raise ValueError('Invalid epsilon value: {}'.format(eps))
49 | if not 0.0 <= betas[0] < 1.0:
50 | raise ValueError(
51 | 'Invalid beta parameter at index 0: {}'.format(betas[0])
52 | )
53 | if not 0.0 <= betas[1] < 1.0:
54 | raise ValueError(
55 | 'Invalid beta parameter at index 1: {}'.format(betas[1])
56 | )
57 | if weight_decay < 0.0:
58 | raise ValueError(
59 | 'Invalid weight_decay value: {}'.format(weight_decay)
60 | )
61 |
62 | defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
63 | super(DiffGrad, self).__init__(params, defaults)
64 |
65 | def step(self, closure=None):
66 | r"""Performs a single optimization step.
67 |
68 | Arguments:
69 | closure: A closure that reevaluates the model and returns the loss.
70 | """
71 | loss = None
72 | if closure is not None:
73 | loss = closure()
74 |
75 | for group in self.param_groups:
76 | beta1, beta2 = group['betas']
77 |
78 | for p in group['params']:
79 | if p.grad is None:
80 | continue
81 | grad = p.grad.data
82 | if grad.is_sparse:
83 | msg = (
84 | 'DiffGrad does not support sparse gradients, '
85 | 'please consider SparseAdam instead'
86 | )
87 | raise RuntimeError(msg)
88 |
89 | state = self.state[p]
90 |
91 | # State initialization
92 | if len(state) == 0:
93 | state['step'] = 0
94 | # Exponential moving average of gradient values
95 | state['exp_avg'] = torch.zeros_like(
96 | p, memory_format=torch.preserve_format
97 | )
98 | # Exponential moving average of squared gradient values
99 | state['exp_avg_sq'] = torch.zeros_like(
100 | p, memory_format=torch.preserve_format
101 | )
102 | # Previous gradient
103 | state['previous_grad'] = torch.zeros_like(
104 | p, memory_format=torch.preserve_format
105 | )
106 |
107 | exp_avg, exp_avg_sq, previous_grad = (
108 | state['exp_avg'],
109 | state['exp_avg_sq'],
110 | state['previous_grad'],
111 | )
112 |
113 | state['step'] += 1
114 |
115 | if group['weight_decay'] != 0:
116 | grad.add_(p.data, alpha=group['weight_decay'])
117 |
118 | # Decay the first and second moment running average coefficient
119 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
120 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
121 | denom = exp_avg_sq.sqrt().add_(group['eps'])
122 |
123 | bias_correction1 = 1 - beta1 ** state['step']
124 | bias_correction2 = 1 - beta2 ** state['step']
125 |
126 | # compute diffgrad coefficient (dfc)
127 | diff = torch.abs(previous_grad - grad)
128 | dfc = torch.div(1.0, (1.0 + torch.exp(-diff)))
129 | state['previous_grad'] = grad.clone()
130 |
131 | # update momentum with dfc
132 | exp_avg1 = exp_avg * dfc
133 |
134 | step_size = (
135 | group['lr']
136 | * math.sqrt(bias_correction2)
137 | / bias_correction1
138 | )
139 |
140 | p.data.addcdiv_(exp_avg1, denom, value=-step_size)
141 |
142 | return loss
143 |
--------------------------------------------------------------------------------
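An illustrative sketch (not part of the repository) of the friction coefficient computed in step() above: dfc = sigmoid(|g_prev - g|) stays near 0.5 when consecutive gradients barely change and approaches 1.0 when they change sharply, which scales the first-moment update accordingly.

import torch

def dfc(previous_grad: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # Same expression as in DiffGrad.step(): 1 / (1 + exp(-|g_prev - g|))
    diff = torch.abs(previous_grad - grad)
    return torch.div(1.0, (1.0 + torch.exp(-diff)))

print(dfc(torch.tensor(0.10), torch.tensor(0.10)).item())  # 0.5  -> gradient unchanged, damped step
print(dfc(torch.tensor(5.00), torch.tensor(-5.0)).item())  # ~1.0 -> gradient flipped, near-full step

--------------------------------------------------------------------------------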
/core/optimizer/radam.py:
--------------------------------------------------------------------------------
1 | # https://github.com/jettify/pytorch-optimizer/blob/master/torch_optimizer/radam.py
2 | import math
3 |
4 | import torch
5 | from torch.optim.optimizer import Optimizer
6 |
7 |
8 | class RAdam(Optimizer):
9 | r"""Implements RAdam optimization algorithm.
10 |
11 | It has been proposed in `On the Variance of the Adaptive Learning
12 | Rate and Beyond`__.
13 |
14 | Arguments:
15 | params: iterable of parameters to optimize or dicts defining
16 | parameter groups
17 | lr: learning rate (default: 1e-3)
18 | betas: coefficients used for computing
19 | running averages of gradient and its square (default: (0.9, 0.999))
20 | eps: term added to the denominator to improve
21 | numerical stability (default: 1e-8)
22 | weight_decay: weight decay (L2 penalty) (default: 0)
23 |
24 | Example:
25 | >>> import torch_optimizer as optim
26 | >>> optimizer = optim.RAdam(model.parameters(), lr=0.1)
27 | >>> optimizer.zero_grad()
28 | >>> loss_fn(model(input), target).backward()
29 | >>> optimizer.step()
30 |
31 | __ https://arxiv.org/abs/1908.03265
32 |
33 | Note:
34 | Reference code: https://github.com/LiyuanLucasLiu/RAdam
35 | """
36 |
37 | def __init__(
38 | self,
39 | params,
40 | lr: float = 1e-3,
41 | betas=(0.9, 0.999),
42 | eps: float = 1e-8,
43 | weight_decay: float = 0,
44 | ) -> None:
45 | if lr <= 0.0:
46 | raise ValueError('Invalid learning rate: {}'.format(lr))
47 | if eps < 0.0:
48 | raise ValueError('Invalid epsilon value: {}'.format(eps))
49 | if not 0.0 <= betas[0] < 1.0:
50 | raise ValueError(
51 | 'Invalid beta parameter at index 0: {}'.format(betas[0])
52 | )
53 | if not 0.0 <= betas[1] < 1.0:
54 | raise ValueError(
55 | 'Invalid beta parameter at index 1: {}'.format(betas[1])
56 | )
57 | if weight_decay < 0:
58 | raise ValueError(
59 | 'Invalid weight_decay value: {}'.format(weight_decay)
60 | )
61 |
62 | if (
63 | isinstance(params, (list, tuple))
64 | and len(params) > 0
65 | and isinstance(params[0], dict)
66 | ):
67 | for param in params:
68 | if 'betas' in param and (
69 | param['betas'][0] != betas[0]
70 | or param['betas'][1] != betas[1]
71 | ):
72 | param['buffer'] = [[None, None, None] for _ in range(10)]
73 |
74 | defaults = dict(
75 | lr=lr,
76 | betas=betas,
77 | eps=eps,
78 | weight_decay=weight_decay,
79 | buffer=[[None, None, None] for _ in range(10)],
80 | )
81 | super(RAdam, self).__init__(params, defaults)
82 |
83 | def __setstate__(self, state):
84 | super(RAdam, self).__setstate__(state)
85 |
86 | def step(self, closure=None):
87 | r"""Performs a single optimization step.
88 |
89 | Arguments:
90 | closure: A closure that reevaluates the model and returns the loss.
91 | """
92 |
93 | loss = None
94 | if closure is not None:
95 | loss = closure()
96 |
97 | for group in self.param_groups:
98 | lr = group['lr']
99 | weight_decay = group['weight_decay']
100 | beta1, beta2 = group['betas']
101 | eps = group['eps']
102 |
103 | for p in group['params']:
104 | if p.grad is None:
105 | continue
106 | grad = p.grad.data.float()
107 | if grad.is_sparse:
108 | msg = (
109 | 'RAdam does not support sparse gradients, '
110 | 'please consider SparseAdam instead'
111 | )
112 | raise RuntimeError(msg)
113 |
114 | p_data_fp32 = p.data.float()
115 |
116 | state = self.state[p]
117 |
118 | if len(state) == 0:
119 | state['step'] = 0
120 | state['exp_avg'] = torch.zeros_like(
121 | p_data_fp32, memory_format=torch.preserve_format
122 | )
123 | state['exp_avg_sq'] = torch.zeros_like(
124 | p_data_fp32, memory_format=torch.preserve_format
125 | )
126 | else:
127 | state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
128 | state['exp_avg_sq'] = state['exp_avg_sq'].type_as(
129 | p_data_fp32
130 | )
131 |
132 | exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
133 |
134 | exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
135 | exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
136 |
137 | state['step'] += 1
138 | buffered = group['buffer'][int(state['step'] % 10)]
139 | if state['step'] == buffered[0]:
140 | N_sma, step_size = buffered[1], buffered[2]
141 | else:
142 | buffered[0] = state['step']
143 | beta2_t = beta2 ** state['step']
144 | N_sma_max = 2 / (1 - beta2) - 1
145 | N_sma = N_sma_max - 2 * state['step'] * beta2_t / (
146 | 1 - beta2_t
147 | )
148 | buffered[1] = N_sma
149 |
150 | # more conservative since it's an approximated value
151 | if N_sma >= 5:
152 | step_size = (
153 | lr
154 | * math.sqrt(
155 | (1 - beta2_t)
156 | * (N_sma - 4)
157 | / (N_sma_max - 4)
158 | * (N_sma - 2)
159 | / N_sma
160 | * N_sma_max
161 | / (N_sma_max - 2)
162 | )
163 | / (1 - beta1 ** state['step'])
164 | )
165 | else:
166 | step_size = lr / (1 - beta1 ** state['step'])
167 | buffered[2] = step_size
168 |
169 | if weight_decay != 0:
170 | p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr)
171 |
172 | # more conservative since it's an approximated value
173 | if N_sma >= 5:
174 | denom = exp_avg_sq.sqrt().add_(eps)
175 | p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
176 | else:
177 | p_data_fp32.add_(exp_avg, alpha=-step_size)
178 |
179 | p.data.copy_(p_data_fp32)
180 |
181 | return loss
182 |
--------------------------------------------------------------------------------
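A small sketch (illustrative, not part of the repository) of the rectification logic in step() above: with the default beta2 = 0.999 the approximated SMA length N_sma stays below 5 for the first few steps, so RAdam takes un-adapted momentum steps before switching to the rectified adaptive update.

import math

beta2 = 0.999
N_sma_max = 2 / (1 - beta2) - 1  # 1999 for beta2 = 0.999

for step in range(1, 9):
    beta2_t = beta2 ** step
    N_sma = N_sma_max - 2 * step * beta2_t / (1 - beta2_t)
    # Same threshold as in RAdam.step(): the adaptive branch is used only once N_sma >= 5
    print(step, round(N_sma, 3), "rectified adaptive" if N_sma >= 5 else "plain momentum")

--------------------------------------------------------------------------------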
/core/schemas/__init__.py:
--------------------------------------------------------------------------------
1 | from core.schemas.config import Config
2 | from core.schemas.train_config import TrainConfig
3 |
4 | __all__ = [
5 |     "Config",
6 |     "TrainConfig",
7 | ]
8 |
--------------------------------------------------------------------------------
/core/schemas/config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from core.clip.clip import available_models
4 |
5 | from typing import List
6 | from dataclasses import dataclass, field
7 |
8 |
9 | INIT_NOISES = ['', 'gradient', 'pixels', 'fractal']
10 | OPTIMIZERS = ['Adam', 'AdamW', 'Adagrad', 'Adamax', 'DiffGrad', 'AdamP', 'RAdam']
11 | AUGMENTS = ['Ji', 'Sh', 'Gn', 'Pe', 'Ro', 'Af', 'Et', 'Ts', 'Cr', 'Er', 'Re', 'Hf']
12 |
13 |
14 | @dataclass
15 | class Config:
16 | prompts: List[str] = field(default_factory=lambda: [])
17 | image_prompts: List[str] = field(default_factory=lambda: [])
18 | max_iterations: int = 500
19 | save_freq: int = 50
20 | size: List[int] = field(default_factory=lambda: [256, 256])
21 | pixelart: List[int] = None
22 | init_image: str = ""
23 | init_noise: str = "gradient"
24 | init_weight: float = 0.0
25 | mse_decay_rate: float = 0.0
26 | output_dir: str = "./outputs"
27 | models_dir: str = "./models"
28 | clip_model: str = 'ViT-B/16'
29 | vqgan_checkpoint: str = './models/vqgan_imagenet_f16_16384.ckpt'
30 | vqgan_config: str = './configs/models/vqgan_imagenet_f16_16384.json'
31 | noise_prompt_seeds: List[int] = field(default_factory=lambda: [])
32 | noise_prompt_weights: List[float] = field(default_factory=lambda: [])
33 | step_size: float = 0.1
34 | cutn: int = 32
35 | cut_pow: float = 1.0
36 | seed: int = -1
37 | optimizer: str = 'Adam'
38 | nwarm_restarts: int = -1
39 | augments: List[str] = field(default_factory=lambda: ['Af', 'Pe', 'Ji', 'Er'])
40 |
41 | def __post_init__(self):
42 | if self.init_noise not in INIT_NOISES:
43 | exit(f"ERROR: \"init_noise\": {self.init_noise}, <-- Noise algorithm not found.\n"
44 | f"Currently only the following values are supported: {INIT_NOISES}.")
45 |
46 | if self.optimizer not in OPTIMIZERS:
47 | exit(f"ERROR: \"optimizer\": {self.optimizer}, <-- Optimizer not found.\n"
48 | f"Currently only the following values are supported: {OPTIMIZERS}.")
49 |
50 | os.makedirs(self.models_dir, exist_ok=True)
51 | os.makedirs(self.output_dir, exist_ok=True)
52 | os.makedirs(f"{self.output_dir}/steps", exist_ok=True)
53 | print(f"Saving outputs in '{self.output_dir}'")
54 |
55 | models = available_models()
56 | if not os.path.exists(self.clip_model) and self.clip_model not in models:
57 | exit(f"ERROR: \"clip_model\": {self.clip_model}, <-- Model not found.\n"
58 |                  f"Make sure it is a valid path to a downloaded model or that it matches one of {models}.")
59 |
60 | if not os.path.exists(self.vqgan_config):
61 | exit(f"ERROR: \"vqgan_config\": {self.vqgan_config}, <-- Configuration file not found.\n"
62 | f"Make sure the path is correct (Multiple config files are available in the `./configs/models` directory).")
63 |
64 | if not os.path.exists(self.vqgan_checkpoint):
65 | exit(f"ERROR: \"vqgan_checkpoint\": {self.vqgan_checkpoint}, <-- Model not found.\n"
66 | f"Make sure the path is correct and that you have downloaded the model (Refer to the README).")
67 |
68 | if self.pixelart:
69 | print("Enabling PixelArt mode. It is recommended to add 'pixelart' to your prompt.")
70 |
71 |
72 | def __str__(self):
73 | _str = (
74 | f"Config:\n"
75 | f" - prompts: {self.prompts}\n"
76 | f" - image_prompts: {self.image_prompts}\n"
77 | f" - max_iterations: {self.max_iterations}\n"
78 | f" - save_freq: {self.save_freq}\n"
79 | f" - size: {self.size}\n"
80 | f" - pixelart: {self.pixelart}\n"
81 | f" - init_image: {self.init_image}\n"
82 | f" - init_noise: {self.init_noise}\n"
83 | f" - init_weight: {self.init_weight}\n"
84 | f" - mse_decay_rate: {self.mse_decay_rate}\n"
85 | f" - output_dir: {self.output_dir}\n"
86 | f" - models_dir: {self.models_dir}\n"
87 | f" - clip_model: {self.clip_model}\n"
88 | f" - vqgan_checkpoint: {self.vqgan_checkpoint}\n"
89 | f" - vqgan_config: {self.vqgan_config}\n"
90 | f" - noise_prompt_seeds: {self.noise_prompt_seeds}\n"
91 | f" - noise_prompt_weights: {self.noise_prompt_weights}\n"
92 | f" - step_size: {self.step_size}\n"
93 | f" - cutn: {self.cutn}\n"
94 | f" - cut_pow: {self.cut_pow}\n"
95 | f" - seed: {self.seed}\n"
96 | f" - optimizer: {self.optimizer}\n"
97 | f" - nwarm_restarts: {self.nwarm_restarts}\n"
98 | f" - augments: {self.augments}\n"
99 | )
100 | return _str
101 |
--------------------------------------------------------------------------------
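An illustrative way (not part of the repository) to build this Config in code rather than from a JSON file; __post_init__ validates clip_model, vqgan_config and vqgan_checkpoint, so this assumes the models have already been downloaded as described in the README:

from core.schemas import Config

# Overrides a few fields and keeps the defaults declared above.
config = Config(
    prompts=["a watercolor landscape"],
    size=[384, 384],
    max_iterations=300,
    optimizer="AdamP",   # must be one of OPTIMIZERS
    step_size=0.05,
)
print(config)

--------------------------------------------------------------------------------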
/core/schemas/train_config.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dataclasses import dataclass
4 |
5 |
6 | @dataclass
7 | class TrainConfig:
8 | base_learning_rate: float = 4.5e-6
9 | batch_size: int = 1
10 | epochs: int = 1000
11 | data_dir: str = "./data"
12 | output_dir: str = "./outputs"
13 | models_dir: str = "./models"
14 | resume_checkpoint: str = ""
15 | seed: int = -1
16 | params: dict = None
17 |
18 | def __post_init__(self):
19 | if not os.path.exists(self.data_dir):
20 |             exit(f"ERROR: \"data_dir\": {self.data_dir}, <-- Data directory not found.\n"
21 | f"Make sure the path is correct (Follow instructions in the README).")
22 |
23 | ckpt_dir = os.path.join(self.models_dir, "checkpoints")
24 | os.makedirs(ckpt_dir, exist_ok=True)
25 | print(f"Checkpoints will be saved in {ckpt_dir}")
26 |
27 | train_dir = os.path.join(self.output_dir, "training")
28 | os.makedirs(train_dir, exist_ok=True)
29 | print(f"Training outputs will be saved in {train_dir}")
30 |
31 | if self.resume_checkpoint and not os.path.exists(self.resume_checkpoint):
32 | exit(f"ERROR: \"resume_checkpoint\": {self.resume_checkpoint}, <-- Model not found.\n"
33 | f"Make sure the path is correct (Follow instructions in the README).")
34 |
35 | def __str__(self):
36 | _str = (
37 | f"Config:\n"
38 | f" - base_learning_rate: {self.base_learning_rate}\n"
39 | f" - batch_size: {self.batch_size}\n"
40 | f" - epochs: {self.epochs}\n"
41 | f" - data_dir: {self.data_dir}\n"
42 | f" - output_dir: {self.output_dir}\n"
43 | f" - models_dir: {self.models_dir}\n"
44 | f" - resume_checkpoint: {self.resume_checkpoint}\n"
45 | f" - seed: {self.seed}\n"
46 | f" - params: {self.params}\n"
47 | )
48 | return _str
49 |
--------------------------------------------------------------------------------
/core/taming/README.md:
--------------------------------------------------------------------------------
1 | # Taming Transformers for High-Resolution Image Synthesis
2 |
3 | [[Original]](https://github.com/CompVis/taming-transformers)
4 |
5 | ## About
6 |
7 | A stripped-down, minimalist version of the original project.
8 |
--------------------------------------------------------------------------------
/core/taming/models/__init__.py:
--------------------------------------------------------------------------------
1 | from core.taming.models.vqgan import VQModel
2 |
3 |
4 | __all__ = [
5 |     "VQModel"
6 | ]
7 |
--------------------------------------------------------------------------------
/core/taming/models/vqgan.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from core.taming.modules.diffusion import Encoder, Decoder
5 | from core.taming.modules.vqvae import VectorQuantizer
6 | from core.taming.modules.losses import VQLPIPSWithDiscriminator, DummyLoss
7 |
8 | from core.utils.loader import safe_load
9 |
10 |
11 | class VQModel(nn.Module):
12 | def __init__(self,
13 | ddconfig,
14 | n_embed,
15 | embed_dim,
16 | lossconfig=None,
17 | ckpt_path=None,
18 | model_dir=None,
19 | ignore_keys=[],
20 | image_key="image",
21 | colorize_nlabels=None,
22 | monitor=None,
23 | remap=None,
24 | sane_index_shape=False, # tell vector quantizer to return indices as bhw
25 | ):
26 | super().__init__()
27 | self.image_key = image_key
28 |
29 | self.encoder = Encoder(**ddconfig)
30 | self.decoder = Decoder(**ddconfig)
31 |
32 | self.loss = DummyLoss()
33 | if lossconfig is not None:
34 | self.loss = VQLPIPSWithDiscriminator(model_dir=model_dir, **lossconfig["params"])
35 |
36 | self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
37 | remap=remap, sane_index_shape=sane_index_shape)
38 | self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
39 | self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
40 |
41 | if ckpt_path is not None:
42 | self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
43 |
44 | self.image_key = image_key
45 |
46 | if colorize_nlabels is not None:
47 | assert type(colorize_nlabels) == int
48 | self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
49 | if monitor is not None:
50 | self.monitor = monitor
51 |
52 | def init_from_ckpt(self, path, ignore_keys=list()):
53 | try:
54 | sd = torch.load(path, map_location="cpu")["state_dict"]
55 | except Exception:
56 | sd = safe_load(path, map_location="cpu")["state_dict"]
57 |
58 | keys = list(sd.keys())
59 | for k in keys:
60 | for ik in ignore_keys:
61 | if k.startswith(ik):
62 | print("Deleting key {} from state_dict.".format(k))
63 | del sd[k]
64 |
65 | if "first_stage_model.encoder.conv_in.weight" in sd:
66 | stripped_state_dict = {}
67 | for key in sd:
68 | if key.startswith("first_stage_model."):
69 | stripped_state_dict[key[18:]] = sd[key]
70 | sd = stripped_state_dict
71 |
72 | self.load_state_dict(sd, strict=False)
73 | print(f"Restored from {path}")
74 |
75 | def encode(self, x):
76 | h = self.encoder(x)
77 | h = self.quant_conv(h)
78 | quant, emb_loss, info = self.quantize(h)
79 | return quant, emb_loss, info
80 |
81 | def decode(self, quant):
82 | quant = self.post_quant_conv(quant)
83 | dec = self.decoder(quant)
84 | return dec
85 |
86 | def decode_code(self, code_b):
87 | quant_b = self.quantize.embed_code(code_b)
88 | dec = self.decode(quant_b)
89 | return dec
90 |
91 | def forward(self, input):
92 | quant, diff, _ = self.encode(input)
93 | dec = self.decode(quant)
94 | return dec, diff
95 |
96 | def get_input(self, batch, device):
97 | x = batch
98 | if len(x.shape) == 3:
99 | x = x[..., None]
100 | x = x.to(device, memory_format=torch.contiguous_format)
101 | return x.float()
102 |
103 | def training_step(self, batch, batch_idx, optimizer_idx, device='cpu'):
104 | x = self.get_input(batch, device)
105 | xrec, qloss = self(x)
106 |
107 | if optimizer_idx == 0:
108 | # autoencode
109 | aeloss = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train")
110 | return aeloss
111 |
112 | if optimizer_idx == 1:
113 | # discriminator
114 | discloss = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, last_layer=self.get_last_layer(), split="train")
115 | return discloss
116 |
117 | def configure_optimizers(self):
118 | lr = self.learning_rate
119 | opt_ae = torch.optim.Adam(list(self.encoder.parameters()) +
120 | list(self.decoder.parameters()) +
121 | list(self.quantize.parameters()) +
122 | list(self.quant_conv.parameters()) +
123 | list(self.post_quant_conv.parameters()),
124 | lr=lr, betas=(0.5, 0.9))
125 | opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
126 | lr=lr, betas=(0.5, 0.9))
127 | return [opt_ae, opt_disc], []
128 |
129 | def get_last_layer(self):
130 | return self.decoder.conv_out.weight
131 |
--------------------------------------------------------------------------------
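A round-trip sketch (hypothetical hyper-parameters, not taken from the shipped configs in configs/models/*.json) through an untrained VQModel; double_z must be False here because quant_conv expects z_channels input channels:

import torch
from core.taming.models.vqgan import VQModel

ddconfig = dict(
    ch=64, out_ch=3, ch_mult=(1, 1, 2, 2), num_res_blocks=1,
    attn_resolutions=[16], dropout=0.0, in_channels=3,
    resolution=128, z_channels=128, double_z=False,
)
model = VQModel(ddconfig, n_embed=512, embed_dim=128).eval()

x = torch.randn(1, 3, 128, 128)
with torch.no_grad():
    quant, emb_loss, info = model.encode(x)  # quantized latents
    xrec = model.decode(quant)               # reconstruction
print(quant.shape, xrec.shape)  # latents are downsampled by 2 ** (len(ch_mult) - 1)

--------------------------------------------------------------------------------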
/core/taming/modules/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 | from core.taming.modules.diffusion.attn_block import AttnBlock
2 | from core.taming.modules.diffusion.resnet_block import ResnetBlock
3 |
4 | from core.taming.modules.diffusion.downsample import Downsample
5 | from core.taming.modules.diffusion.upsample import Upsample
6 |
7 | from core.taming.modules.diffusion.encoder import Encoder
8 | from core.taming.modules.diffusion.decoder import Decoder
9 |
10 |
11 | __all__ = [
12 |     "AttnBlock",
13 |     "ResnetBlock",
14 |     "Downsample",
15 |     "Upsample",
16 |     "Encoder",
17 |     "Decoder",
18 | ]
19 |
--------------------------------------------------------------------------------
/core/taming/modules/diffusion/attn_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from core.taming.utils import Normalize
5 |
6 |
7 | class AttnBlock(nn.Module):
8 | def __init__(self, in_channels):
9 | super().__init__()
10 | self.in_channels = in_channels
11 |
12 | self.norm = Normalize(in_channels)
13 | self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
14 | self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
15 | self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
16 | self.proj_out = torch.nn.Conv2d(
17 | in_channels, in_channels, kernel_size=1, stride=1, padding=0
18 | )
19 |
20 | def forward(self, x):
21 | h_ = x
22 | h_ = self.norm(h_)
23 | q = self.q(h_)
24 | k = self.k(h_)
25 | v = self.v(h_)
26 |
27 | # compute attention
28 | b, c, h, w = q.shape
29 | q = q.reshape(b, c, h * w)
30 | q = q.permute(0, 2, 1) # b, hw, c
31 | k = k.reshape(b, c, h * w) # b, c, hw
32 | w_ = torch.bmm(q, k) # b, hw, hw w[b, i, j]=sum_c q[b, i, c]k[b, c, j]
33 | w_ = w_ * (int(c)**(-0.5))
34 | w_ = torch.nn.functional.softmax(w_, dim=2)
35 |
36 | # attend to values
37 | v = v.reshape(b, c, h * w)
38 | w_ = w_.permute(0, 2, 1) # b, hw, hw (first hw of k, second of q)
39 | h_ = torch.bmm(v, w_) # b, c, hw (hw of q) h_[b, c, j] = sum_i v[b, c, i] w_[b, i, j]
40 | h_ = h_.reshape(b, c, h, w)
41 |
42 | h_ = self.proj_out(h_)
43 |
44 | return x + h_
45 |
--------------------------------------------------------------------------------
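A quick sketch (illustrative, not part of the repository) showing that this single-head self-attention block is shape-preserving: it is built from 1x1 convolutions plus a residual connection, and the channel count must be divisible by 32 because of the GroupNorm inside Normalize.

import torch
from core.taming.modules.diffusion import AttnBlock

attn = AttnBlock(in_channels=64)  # 64 is divisible by 32 (GroupNorm constraint)
x = torch.randn(2, 64, 16, 16)
with torch.no_grad():
    y = attn(x)
print(y.shape)  # torch.Size([2, 64, 16, 16]) -- x + proj_out(attention)

--------------------------------------------------------------------------------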
/core/taming/modules/diffusion/decoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import numpy as np
5 |
6 | from core.taming.utils import Normalize, nonlinearity
7 |
8 | from core.taming.modules.diffusion import AttnBlock, ResnetBlock, Upsample
9 |
10 |
11 | class Decoder(nn.Module):
12 | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
13 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
14 | resolution, z_channels, give_pre_end=False, **ignorekwargs):
15 | super().__init__()
16 | self.ch = ch
17 | self.temb_ch = 0
18 | self.num_resolutions = len(ch_mult)
19 | self.num_res_blocks = num_res_blocks
20 | self.resolution = resolution
21 | self.in_channels = in_channels
22 | self.give_pre_end = give_pre_end
23 |
24 | # compute in_ch_mult, block_in and curr_res at lowest res
25 | # in_ch_mult = (1,)+tuple(ch_mult)
26 | block_in = ch * ch_mult[self.num_resolutions - 1]
27 | curr_res = resolution // 2**(self.num_resolutions - 1)
28 | self.z_shape = (1, z_channels, curr_res, curr_res)
29 | print("Working with z of shape {} = {} dimensions.".format(
30 | self.z_shape, np.prod(self.z_shape)))
31 |
32 | # z to block_in
33 | self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
34 |
35 | # middle
36 | self.mid = nn.Module()
37 | self.mid.block_1 = ResnetBlock(
38 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
39 | )
40 | self.mid.attn_1 = AttnBlock(block_in)
41 | self.mid.block_2 = ResnetBlock(
42 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
43 | )
44 |
45 | # upsampling
46 | self.up = nn.ModuleList()
47 | for i_level in reversed(range(self.num_resolutions)):
48 | block = nn.ModuleList()
49 | attn = nn.ModuleList()
50 | block_out = ch * ch_mult[i_level]
51 | for i_block in range(self.num_res_blocks + 1):
52 | block.append(ResnetBlock(in_channels=block_in,
53 | out_channels=block_out,
54 | temb_channels=self.temb_ch,
55 | dropout=dropout))
56 | block_in = block_out
57 | if curr_res in attn_resolutions:
58 | attn.append(AttnBlock(block_in))
59 | up = nn.Module()
60 | up.block = block
61 | up.attn = attn
62 | if i_level != 0:
63 | up.upsample = Upsample(block_in, resamp_with_conv)
64 | curr_res = curr_res * 2
65 | self.up.insert(0, up) # prepend to get consistent order
66 |
67 | # end
68 | self.norm_out = Normalize(block_in)
69 | self.conv_out = torch.nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
70 |
71 | def forward(self, z):
72 | # assert z.shape[1:] == self.z_shape[1:]
73 | self.last_z_shape = z.shape
74 |
75 | # timestep embedding
76 | temb = None
77 |
78 | # z to block_in
79 | h = self.conv_in(z)
80 |
81 | # middle
82 | h = self.mid.block_1(h, temb)
83 | h = self.mid.attn_1(h)
84 | h = self.mid.block_2(h, temb)
85 |
86 | # upsampling
87 | for i_level in reversed(range(self.num_resolutions)):
88 | for i_block in range(self.num_res_blocks + 1):
89 | h = self.up[i_level].block[i_block](h, temb)
90 | if len(self.up[i_level].attn) > 0:
91 | h = self.up[i_level].attn[i_block](h)
92 | if i_level != 0:
93 | h = self.up[i_level].upsample(h)
94 |
95 | # end
96 | if self.give_pre_end:
97 | return h
98 |
99 | h = self.norm_out(h)
100 | h = nonlinearity(h)
101 | h = self.conv_out(h)
102 | return h
103 |
--------------------------------------------------------------------------------
/core/taming/modules/diffusion/downsample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Downsample(nn.Module):
6 | def __init__(self, in_channels, with_conv):
7 | super().__init__()
8 | self.with_conv = with_conv
9 | if self.with_conv:
10 | # no asymmetric padding in torch conv, must do it ourselves
11 | self.conv = torch.nn.Conv2d(
12 | in_channels, in_channels, kernel_size=3, stride=2, padding=0
13 | )
14 |
15 | def forward(self, x):
16 | if self.with_conv:
17 | pad = (0, 1, 0, 1)
18 | x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
19 | x = self.conv(x)
20 | else:
21 | x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
22 | return x
23 |
--------------------------------------------------------------------------------
/core/taming/modules/diffusion/encoder.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from core.taming.utils import Normalize, nonlinearity
5 |
6 | from core.taming.modules.diffusion import AttnBlock, ResnetBlock, Downsample
7 |
8 |
9 | class Encoder(nn.Module):
10 | def __init__(self, *, ch, out_ch, ch_mult=(1, 2, 4, 8), num_res_blocks,
11 | attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
12 | resolution, z_channels, double_z=True, **ignore_kwargs):
13 | super().__init__()
14 | self.ch = ch
15 | self.temb_ch = 0
16 | self.num_resolutions = len(ch_mult)
17 | self.num_res_blocks = num_res_blocks
18 | self.resolution = resolution
19 | self.in_channels = in_channels
20 |
21 | # downsampling
22 | self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
23 |
24 | curr_res = resolution
25 | in_ch_mult = (1,) + tuple(ch_mult)
26 | self.down = nn.ModuleList()
27 | for i_level in range(self.num_resolutions):
28 | block = nn.ModuleList()
29 | attn = nn.ModuleList()
30 | block_in = ch * in_ch_mult[i_level]
31 | block_out = ch * ch_mult[i_level]
32 | for i_block in range(self.num_res_blocks):
33 | block.append(ResnetBlock(in_channels=block_in,
34 | out_channels=block_out,
35 | temb_channels=self.temb_ch,
36 | dropout=dropout))
37 | block_in = block_out
38 | if curr_res in attn_resolutions:
39 | attn.append(AttnBlock(block_in))
40 | down = nn.Module()
41 | down.block = block
42 | down.attn = attn
43 | if i_level != self.num_resolutions - 1:
44 | down.downsample = Downsample(block_in, resamp_with_conv)
45 | curr_res = curr_res // 2
46 | self.down.append(down)
47 |
48 | # middle
49 | self.mid = nn.Module()
50 | self.mid.block_1 = ResnetBlock(
51 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
52 | )
53 | self.mid.attn_1 = AttnBlock(block_in)
54 | self.mid.block_2 = ResnetBlock(
55 | in_channels=block_in, out_channels=block_in, temb_channels=self.temb_ch, dropout=dropout
56 | )
57 |
58 | # end
59 | self.norm_out = Normalize(block_in)
60 | self.conv_out = torch.nn.Conv2d(
61 | block_in, 2 * z_channels if double_z else z_channels, kernel_size=3, stride=1, padding=1
62 | )
63 |
64 | def forward(self, x):
65 | # assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(
66 | # x.shape[2], x.shape[3], self.resolution
67 | # )
68 |
69 | # timestep embedding
70 | temb = None
71 |
72 | # downsampling
73 | hs = [self.conv_in(x)]
74 | for i_level in range(self.num_resolutions):
75 | for i_block in range(self.num_res_blocks):
76 | h = self.down[i_level].block[i_block](hs[-1], temb)
77 | if len(self.down[i_level].attn) > 0:
78 | h = self.down[i_level].attn[i_block](h)
79 | hs.append(h)
80 | if i_level != self.num_resolutions - 1:
81 | hs.append(self.down[i_level].downsample(hs[-1]))
82 |
83 | # middle
84 | h = hs[-1]
85 | h = self.mid.block_1(h, temb)
86 | h = self.mid.attn_1(h)
87 | h = self.mid.block_2(h, temb)
88 |
89 | # end
90 | h = self.norm_out(h)
91 | h = nonlinearity(h)
92 | h = self.conv_out(h)
93 | return h
94 |
--------------------------------------------------------------------------------
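A small sketch (hypothetical hyper-parameters) of the Encoder on its own: the spatial size is divided by 2 ** (len(ch_mult) - 1), and double_z=True doubles the channel count of the output relative to z_channels.

import torch
from core.taming.modules.diffusion import Encoder

enc = Encoder(ch=64, out_ch=3, ch_mult=(1, 2, 2), num_res_blocks=1,
              attn_resolutions=[], dropout=0.0, in_channels=3,
              resolution=64, z_channels=32, double_z=True)

x = torch.randn(1, 3, 64, 64)
with torch.no_grad():
    h = enc(x)
print(h.shape)  # (1, 64, 16, 16): 2 * z_channels channels, 64 / 2**2 spatial

--------------------------------------------------------------------------------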
/core/taming/modules/diffusion/resnet_block.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from core.taming.utils import Normalize, nonlinearity
5 |
6 |
7 | class ResnetBlock(nn.Module):
8 | def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
9 | dropout, temb_channels=512):
10 | super().__init__()
11 | self.in_channels = in_channels
12 | out_channels = in_channels if out_channels is None else out_channels
13 | self.out_channels = out_channels
14 | self.use_conv_shortcut = conv_shortcut
15 |
16 | self.norm1 = Normalize(in_channels)
17 | self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
18 | if temb_channels > 0:
19 | self.temb_proj = torch.nn.Linear(temb_channels,
20 | out_channels)
21 | self.norm2 = Normalize(out_channels)
22 | self.dropout = torch.nn.Dropout(dropout)
23 | self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
24 | if self.in_channels != self.out_channels:
25 | if self.use_conv_shortcut:
26 | self.conv_shortcut = torch.nn.Conv2d(
27 | in_channels, out_channels, kernel_size=3, stride=1, padding=1
28 | )
29 | else:
30 | self.nin_shortcut = torch.nn.Conv2d(
31 | in_channels, out_channels, kernel_size=1, stride=1, padding=0
32 | )
33 |
34 | def forward(self, x, temb):
35 | h = x
36 | h = self.norm1(h)
37 | h = nonlinearity(h)
38 | h = self.conv1(h)
39 |
40 | if temb is not None:
41 | h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
42 |
43 | h = self.norm2(h)
44 | h = nonlinearity(h)
45 | h = self.dropout(h)
46 | h = self.conv2(h)
47 |
48 | if self.in_channels != self.out_channels:
49 | if self.use_conv_shortcut:
50 | x = self.conv_shortcut(x)
51 | else:
52 | x = self.nin_shortcut(x)
53 |
54 | return x + h
55 |
--------------------------------------------------------------------------------
/core/taming/modules/diffusion/upsample.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class Upsample(nn.Module):
6 | def __init__(self, in_channels, with_conv):
7 | super().__init__()
8 | self.with_conv = with_conv
9 | if self.with_conv:
10 | self.conv = torch.nn.Conv2d(
11 | in_channels, in_channels, kernel_size=3, stride=1, padding=1
12 | )
13 |
14 | def forward(self, x):
15 | x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
16 | if self.with_conv:
17 | x = self.conv(x)
18 | return x
19 |
--------------------------------------------------------------------------------
/core/taming/modules/discriminator/__init__.py:
--------------------------------------------------------------------------------
1 | from core.taming.modules.discriminator.act_norm import ActNorm
2 | from core.taming.modules.discriminator.discriminator import NLayerDiscriminator
3 |
4 | __all__ = [
5 |     "ActNorm",
6 |     "NLayerDiscriminator"
7 | ]
8 |
--------------------------------------------------------------------------------
/core/taming/modules/discriminator/act_norm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ActNorm(nn.Module):
6 | def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
7 | assert affine
8 | super().__init__()
9 | self.logdet = logdet
10 | self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
11 | self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
12 | self.allow_reverse_init = allow_reverse_init
13 |
14 | self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8))
15 |
16 | def initialize(self, input):
17 | with torch.no_grad():
18 | flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
19 | mean = (
20 | flatten.mean(1)
21 | .unsqueeze(1)
22 | .unsqueeze(2)
23 | .unsqueeze(3)
24 | .permute(1, 0, 2, 3)
25 | )
26 | std = (
27 | flatten.std(1)
28 | .unsqueeze(1)
29 | .unsqueeze(2)
30 | .unsqueeze(3)
31 | .permute(1, 0, 2, 3)
32 | )
33 |
34 | self.loc.data.copy_(-mean)
35 | self.scale.data.copy_(1 / (std + 1e-6))
36 |
37 | def forward(self, input, reverse=False):
38 | if reverse:
39 | return self.reverse(input)
40 | if len(input.shape) == 2:
41 | input = input[:,:,None,None]
42 | squeeze = True
43 | else:
44 | squeeze = False
45 |
46 | _, _, height, width = input.shape
47 |
48 | if self.training and self.initialized.item() == 0:
49 | self.initialize(input)
50 | self.initialized.fill_(1)
51 |
52 | h = self.scale * (input + self.loc)
53 |
54 | if squeeze:
55 | h = h.squeeze(-1).squeeze(-1)
56 |
57 | if self.logdet:
58 | log_abs = torch.log(torch.abs(self.scale))
59 | logdet = height*width*torch.sum(log_abs)
60 | logdet = logdet * torch.ones(input.shape[0]).to(input)
61 | return h, logdet
62 |
63 | return h
64 |
65 | def reverse(self, output):
66 | if self.training and self.initialized.item() == 0:
67 | if not self.allow_reverse_init:
68 | raise RuntimeError(
69 | "Initializing ActNorm in reverse direction is "
70 | "disabled by default. Use allow_reverse_init=True to enable."
71 | )
72 | else:
73 | self.initialize(output)
74 | self.initialized.fill_(1)
75 |
76 | if len(output.shape) == 2:
77 | output = output[:,:,None,None]
78 | squeeze = True
79 | else:
80 | squeeze = False
81 |
82 | h = output / self.scale - self.loc
83 |
84 | if squeeze:
85 | h = h.squeeze(-1).squeeze(-1)
86 | return h
87 |
--------------------------------------------------------------------------------
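A brief sketch (illustrative) of the data-dependent initialisation above: on the first forward pass in training mode, loc and scale are set from the batch statistics, so the output comes out roughly zero-mean and unit-variance per channel.

import torch
from core.taming.modules.discriminator import ActNorm

actnorm = ActNorm(num_features=8)
x = torch.randn(4, 8, 16, 16) * 3.0 + 1.5  # deliberately shifted and scaled input
actnorm.train()
y = actnorm(x)                             # first call initialises loc/scale from x
print(y.mean().item(), y.std().item())     # approximately 0 and 1

--------------------------------------------------------------------------------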
/core/taming/modules/discriminator/discriminator.py:
--------------------------------------------------------------------------------
1 | # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
2 |
3 | import functools
4 | import torch.nn as nn
5 |
6 | from core.taming.modules.discriminator import ActNorm
7 |
8 |
9 | class NLayerDiscriminator(nn.Module):
10 | """Defines a PatchGAN discriminator as in Pix2Pix"""
11 | def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
12 | """Construct a PatchGAN discriminator
13 | Parameters:
14 | input_nc (int) -- the number of channels in input images
15 | ndf (int) -- the number of filters in the last conv layer
16 | n_layers (int) -- the number of conv layers in the discriminator
17 |             use_actnorm (bool)  -- use ActNorm instead of BatchNorm as the normalization layer
18 | """
19 | super(NLayerDiscriminator, self).__init__()
20 | if not use_actnorm:
21 | norm_layer = nn.BatchNorm2d
22 | else:
23 | norm_layer = ActNorm
24 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
25 | use_bias = norm_layer.func != nn.BatchNorm2d
26 | else:
27 | use_bias = norm_layer != nn.BatchNorm2d
28 |
29 | kw = 4
30 | padw = 1
31 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
32 | nf_mult = 1
33 | nf_mult_prev = 1
34 | for n in range(1, n_layers): # gradually increase the number of filters
35 | nf_mult_prev = nf_mult
36 | nf_mult = min(2 ** n, 8)
37 | sequence += [
38 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
39 | norm_layer(ndf * nf_mult),
40 | nn.LeakyReLU(0.2, True)
41 | ]
42 |
43 | nf_mult_prev = nf_mult
44 | nf_mult = min(2 ** n_layers, 8)
45 | sequence += [
46 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
47 | norm_layer(ndf * nf_mult),
48 | nn.LeakyReLU(0.2, True)
49 | ]
50 |
51 | sequence += [
52 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
53 | self.main = nn.Sequential(*sequence)
54 |
55 | def forward(self, input):
56 | """Standard forward."""
57 | return self.main(input)
58 |
--------------------------------------------------------------------------------
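A sketch (illustrative) of the PatchGAN output: each spatial position of the single-channel map scores one receptive-field patch of the input, so the discriminator returns a grid of real/fake logits rather than a single scalar.

import torch
from core.taming.modules.discriminator import NLayerDiscriminator

disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3).eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    logits = disc(x)
print(logits.shape)  # (1, 1, 30, 30) patch map of logits for a 256x256 input

--------------------------------------------------------------------------------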
/core/taming/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from core.taming.modules.losses.lpips import LPIPS
2 | from core.taming.modules.losses.vqperceptual import VQLPIPSWithDiscriminator, DummyLoss
3 |
4 | __all__ = [
5 |     "LPIPS",
6 |     "DummyLoss",
7 |     "VQLPIPSWithDiscriminator"
8 | ]
9 |
--------------------------------------------------------------------------------
/core/taming/modules/losses/lpips.py:
--------------------------------------------------------------------------------
1 | """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2 |
3 | import os
4 |
5 | import torch
6 | import torch.nn as nn
7 | from collections import namedtuple
8 |
9 | from core.utils.loader import download
10 | from core.taming.utils import normalize_tensor, spatial_average, load_vgg
11 |
12 |
13 | class LPIPS(nn.Module):
14 | # Learned perceptual metric
15 | def __init__(self, model_dir="/models", use_dropout=True):
16 | super().__init__()
17 | self.scaling_layer = ScalingLayer()
18 |         self.chns = [64, 128, 256, 512, 512]  # vgg16 features
19 | self.net = VGG16(model_dir=model_dir, pretrained=True, requires_grad=False)
20 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
21 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
22 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
23 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
24 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
25 | self.load_from_pretrained(model_dir)
26 | for param in self.parameters():
27 | param.requires_grad = False
28 |
29 | def load_from_pretrained(self, model_dir="/models"):
30 | ckpt = f"{model_dir}/vgg.pth"
31 | if not os.path.exists(ckpt):
32 | download("https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1", ckpt)
33 | self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
34 | print(f"Loaded pretrained LPIPS loss from '{ckpt}'")
35 |
36 | def forward(self, input, target):
37 | in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
38 | outs0, outs1 = self.net(in0_input), self.net(in1_input)
39 | feats0, feats1, diffs = {}, {}, {}
40 | lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
41 | for kk in range(len(self.chns)):
42 | feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
43 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
44 |
45 | res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
46 | val = res[0]
47 | for l in range(1, len(self.chns)):
48 | val += res[l]
49 | return val
50 |
51 |
52 | class ScalingLayer(nn.Module):
53 | def __init__(self):
54 | super(ScalingLayer, self).__init__()
55 | self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
56 | self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
57 |
58 | def forward(self, inp):
59 | return (inp - self.shift) / self.scale
60 |
61 |
62 | class NetLinLayer(nn.Module):
63 | """ A single linear layer which does a 1x1 conv """
64 | def __init__(self, chn_in, chn_out=1, use_dropout=False):
65 | super(NetLinLayer, self).__init__()
66 | layers = [nn.Dropout(), ] if (use_dropout) else []
67 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
68 | self.model = nn.Sequential(*layers)
69 |
70 |
71 | class VGG16(torch.nn.Module):
72 | def __init__(self, model_dir="/models", requires_grad=False, pretrained=True):
73 | super(VGG16, self).__init__()
74 | vgg_pretrained_features = load_vgg(model_dir=model_dir, pretrained=pretrained).features
75 | self.slice1 = torch.nn.Sequential()
76 | self.slice2 = torch.nn.Sequential()
77 | self.slice3 = torch.nn.Sequential()
78 | self.slice4 = torch.nn.Sequential()
79 | self.slice5 = torch.nn.Sequential()
80 | self.N_slices = 5
81 | for x in range(4):
82 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
83 | for x in range(4, 9):
84 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
85 | for x in range(9, 16):
86 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
87 | for x in range(16, 23):
88 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
89 | for x in range(23, 30):
90 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
91 | if not requires_grad:
92 | for param in self.parameters():
93 | param.requires_grad = False
94 |
95 | def forward(self, X):
96 | h = self.slice1(X)
97 | h_relu1_2 = h
98 | h = self.slice2(h)
99 | h_relu2_2 = h
100 | h = self.slice3(h)
101 | h_relu3_3 = h
102 | h = self.slice4(h)
103 | h_relu4_3 = h
104 | h = self.slice5(h)
105 | h_relu5_3 = h
106 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
107 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
108 | return out
109 |
--------------------------------------------------------------------------------
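A sketch (illustrative) of the perceptual metric in isolation; note that constructing LPIPS downloads both the torchvision VGG16 weights and the vgg.pth linear-head checkpoint into model_dir on first use.

import torch
from core.taming.modules.losses import LPIPS

lpips = LPIPS(model_dir="./models").eval()

a = torch.rand(1, 3, 64, 64) * 2 - 1  # inputs roughly in [-1, 1]
b = a + 0.1 * torch.randn_like(a)
with torch.no_grad():
    print(lpips(a, a).item())  # 0.0 for identical inputs
    print(lpips(a, b).item())  # > 0 for perturbed inputs

--------------------------------------------------------------------------------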
/core/taming/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from core.taming.utils import hinge_d_loss, vanilla_d_loss, adopt_weight, weights_init
5 |
6 | from core.taming.modules.discriminator import NLayerDiscriminator
7 |
8 | from core.taming.modules.losses import LPIPS
9 |
10 |
11 | class DummyLoss(nn.Module):
12 | def __init__(self):
13 | super().__init__()
14 |
15 |
16 | class VQLPIPSWithDiscriminator(nn.Module):
17 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
18 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
19 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
20 | disc_ndf=64, disc_loss="hinge", model_dir=None):
21 | super().__init__()
22 | assert disc_loss in ["hinge", "vanilla"]
23 | self.codebook_weight = codebook_weight
24 | self.pixel_weight = pixelloss_weight
25 | self.perceptual_loss = LPIPS(model_dir=model_dir).eval()
26 | self.perceptual_weight = perceptual_weight
27 |
28 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
29 | n_layers=disc_num_layers,
30 | use_actnorm=use_actnorm,
31 | ndf=disc_ndf
32 | ).apply(weights_init)
33 | self.discriminator_iter_start = disc_start
34 | if disc_loss == "hinge":
35 | self.disc_loss = hinge_d_loss
36 | elif disc_loss == "vanilla":
37 | self.disc_loss = vanilla_d_loss
38 | else:
39 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
40 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
41 | self.disc_factor = disc_factor
42 | self.discriminator_weight = disc_weight
43 | self.disc_conditional = disc_conditional
44 |
45 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
46 | if last_layer is not None:
47 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
48 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
49 | else:
50 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
51 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
52 |
53 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
54 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
55 | d_weight = d_weight * self.discriminator_weight
56 | return d_weight
57 |
58 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
59 | global_step, last_layer=None, cond=None, split="train"):
60 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
61 | if self.perceptual_weight > 0:
62 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
63 | rec_loss = rec_loss + self.perceptual_weight * p_loss
64 | else:
65 | p_loss = torch.tensor([0.0])
66 |
67 | nll_loss = rec_loss
68 | nll_loss = torch.mean(nll_loss)
69 |
70 | # now the GAN part
71 | if optimizer_idx == 0:
72 | # generator update
73 | if cond is None:
74 | assert not self.disc_conditional
75 | logits_fake = self.discriminator(reconstructions.contiguous())
76 | else:
77 | assert self.disc_conditional
78 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
79 | g_loss = -torch.mean(logits_fake)
80 |
81 | try:
82 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
83 | except RuntimeError:
84 | assert not self.training
85 | d_weight = torch.tensor(0.0)
86 |
87 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
88 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
89 |
90 | return loss
91 |
92 | if optimizer_idx == 1:
93 | # second pass for discriminator update
94 | if cond is None:
95 | logits_real = self.discriminator(inputs.contiguous().detach())
96 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
97 | else:
98 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
99 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
100 |
101 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
102 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
103 |
104 | return d_loss
105 |
--------------------------------------------------------------------------------
/core/taming/modules/vqvae/__init__.py:
--------------------------------------------------------------------------------
1 | from core.taming.modules.vqvae.vector_quantizer import VectorQuantizer
2 |
3 |
4 | __all__ = [
5 |     "VectorQuantizer"
6 | ]
7 |
--------------------------------------------------------------------------------
/core/taming/modules/vqvae/vector_quantizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import numpy as np
5 |
6 | from einops import rearrange
7 |
8 |
9 | class VectorQuantizer(nn.Module):
10 | """
11 |     Improved version of the original VectorQuantizer that can be used as a drop-in replacement. Mostly
12 | avoids costly matrix multiplications and allows for post-hoc remapping of indices.
13 | """
14 |     # NOTE: due to a bug, the beta term was applied to the wrong term. For
15 | # backwards compatibility we use the buggy version by default, but you can
16 | # specify legacy=False to fix it.
17 | def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
18 | sane_index_shape=False, legacy=True):
19 | super().__init__()
20 | self.n_e = n_e
21 | self.e_dim = e_dim
22 | self.beta = beta
23 | self.legacy = legacy
24 |
25 | self.embedding = nn.Embedding(self.n_e, self.e_dim)
26 | self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
27 |
28 | self.remap = remap
29 | if self.remap is not None:
30 | self.register_buffer("used", torch.tensor(np.load(self.remap)))
31 | self.re_embed = self.used.shape[0]
32 | self.unknown_index = unknown_index # "random" or "extra" or integer
33 | if self.unknown_index == "extra":
34 | self.unknown_index = self.re_embed
35 | self.re_embed = self.re_embed + 1
36 | print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
37 | f"Using {self.unknown_index} for unknown indices.")
38 | else:
39 | self.re_embed = n_e
40 |
41 | self.sane_index_shape = sane_index_shape
42 |
43 | def remap_to_used(self, inds):
44 | ishape = inds.shape
45 | assert len(ishape) > 1
46 | inds = inds.reshape(ishape[0], -1)
47 | used = self.used.to(inds)
48 | match = (inds[:, :, None] == used[None, None, ...]).long()
49 | new = match.argmax(-1)
50 | unknown = match.sum(2) < 1
51 | if self.unknown_index == "random":
52 | new[unknown] = \
53 | torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
54 | else:
55 | new[unknown] = self.unknown_index
56 | return new.reshape(ishape)
57 |
58 | def unmap_to_all(self, inds):
59 | ishape = inds.shape
60 | assert len(ishape) > 1
61 | inds = inds.reshape(ishape[0], -1)
62 | used = self.used.to(inds)
63 | if self.re_embed > self.used.shape[0]: # extra token
64 | inds[inds >= self.used.shape[0]] = 0 # simply set to zero
65 | back = torch.gather(used[None, :][inds.shape[0] * [0], :], 1, inds)
66 | return back.reshape(ishape)
67 |
68 | def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
69 | assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
70 | assert rescale_logits is False, "Only for interface compatible with Gumbel"
71 | assert return_logits is False, "Only for interface compatible with Gumbel"
72 |
73 | # reshape z -> (batch, height, width, channel) and flatten
74 | z = rearrange(z, 'b c h w -> b h w c').contiguous()
75 | z_flattened = z.view(-1, self.e_dim)
76 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
77 |
78 | d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
79 | torch.sum(self.embedding.weight**2, dim=1) - 2 * \
80 | torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
81 |
82 | min_encoding_indices = torch.argmin(d, dim=1)
83 | z_q = self.embedding(min_encoding_indices).view(z.shape)
84 | perplexity = None
85 | min_encodings = None
86 |
87 | # compute loss for embedding
88 | if not self.legacy:
89 | loss = self.beta * torch.mean((z_q.detach() - z)**2) + \
90 | torch.mean((z_q - z.detach()) ** 2)
91 | else:
92 | loss = torch.mean((z_q.detach() - z)**2) + self.beta * \
93 | torch.mean((z_q - z.detach()) ** 2)
94 |
95 | # preserve gradients
96 | z_q = z + (z_q - z).detach()
97 |
98 | # reshape back to match original input shape
99 | z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
100 |
101 | if self.remap is not None:
102 | min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1) # add batch axis
103 | min_encoding_indices = self.remap_to_used(min_encoding_indices)
104 | min_encoding_indices = min_encoding_indices.reshape(-1, 1) # flatten
105 |
106 | if self.sane_index_shape:
107 | min_encoding_indices = min_encoding_indices.reshape(
108 | z_q.shape[0], z_q.shape[2], z_q.shape[3])
109 |
110 | return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
111 |
112 | def get_codebook_entry(self, indices, shape):
113 | # shape specifying (batch, height, width, channel)
114 | if self.remap is not None:
115 | indices = indices.reshape(shape[0], -1) # add batch axis
116 | indices = self.unmap_to_all(indices)
117 | indices = indices.reshape(-1) # flatten again
118 |
119 | # get quantized latent vectors
120 | z_q = self.embedding(indices)
121 |
122 | if shape is not None:
123 | z_q = z_q.view(shape)
124 | # reshape back to match original input shape
125 | z_q = z_q.permute(0, 3, 1, 2).contiguous()
126 |
127 | return z_q
128 |
--------------------------------------------------------------------------------
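A sketch (illustrative) of the quantizer on its own: each spatial position of the latent is snapped to its nearest codebook entry, and the straight-through assignment z + (z_q - z).detach() lets gradients flow back through the otherwise non-differentiable argmin.

import torch
from core.taming.modules.vqvae import VectorQuantizer

quantizer = VectorQuantizer(n_e=512, e_dim=64, beta=0.25, sane_index_shape=True)
z = torch.randn(1, 64, 8, 8, requires_grad=True)  # (b, c, h, w) latents

z_q, codebook_loss, (_, _, indices) = quantizer(z)
print(z_q.shape, indices.shape)  # (1, 64, 8, 8) and (1, 8, 8) codebook indices

z_q.sum().backward()       # straight-through: d(z_q)/dz behaves like the identity
print(z.grad is not None)  # True

--------------------------------------------------------------------------------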
/core/taming/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from core.taming.utils.diffusion_utils import Normalize, nonlinearity
2 | from core.taming.utils.discriminator_utils import weights_init
3 | from core.taming.utils.losses_utils import (
4 | adopt_weight, hinge_d_loss, vanilla_d_loss, normalize_tensor, spatial_average, load_vgg
5 | )
6 |
7 | __all__ = [
8 |     "Normalize",
9 |     "nonlinearity",
10 |     "weights_init",
11 |     "adopt_weight",
12 |     "hinge_d_loss",
13 |     "vanilla_d_loss",
14 |     "normalize_tensor",
15 |     "spatial_average",
16 |     "load_vgg",
17 | ]
18 |
--------------------------------------------------------------------------------
/core/taming/utils/diffusion_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def nonlinearity(x):
5 | # swish
6 | return x * torch.sigmoid(x)
7 |
8 |
9 | def Normalize(in_channels):
10 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
11 |
--------------------------------------------------------------------------------
/core/taming/utils/discriminator_utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | def weights_init(m):
5 | classname = m.__class__.__name__
6 | if classname.find('Conv') != -1:
7 | nn.init.normal_(m.weight.data, 0.0, 0.02)
8 | elif classname.find('BatchNorm') != -1:
9 | nn.init.normal_(m.weight.data, 1.0, 0.02)
10 | nn.init.constant_(m.bias.data, 0)
11 |
--------------------------------------------------------------------------------
/core/taming/utils/losses_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from torchvision.models import VGG
6 | from torchvision.models.vgg import load_state_dict_from_url
7 |
8 | from typing import List, Union, cast
9 |
10 |
11 | def adopt_weight(weight, global_step, threshold=0, value=0.):
12 | if global_step < threshold:
13 | weight = value
14 | return weight
15 |
16 |
17 | def hinge_d_loss(logits_real, logits_fake):
18 | loss_real = torch.mean(F.relu(1. - logits_real))
19 | loss_fake = torch.mean(F.relu(1. + logits_fake))
20 | d_loss = 0.5 * (loss_real + loss_fake)
21 | return d_loss
22 |
23 |
24 | def vanilla_d_loss(logits_real, logits_fake):
25 | d_loss = 0.5 * (
26 | torch.mean(torch.nn.functional.softplus(-logits_real)) +
27 | torch.mean(torch.nn.functional.softplus(logits_fake)))
28 | return d_loss
29 |
30 |
31 | def normalize_tensor(x, eps=1e-10):
32 | norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True))
33 | return x / (norm_factor + eps)
34 |
35 |
36 | def spatial_average(x, keepdim=True):
37 | return x.mean([2, 3], keepdim=keepdim)
38 |
39 |
40 | def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential:
41 | layers: List[nn.Module] = []
42 | in_channels = 3
43 | for v in cfg:
44 | if v == 'M':
45 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
46 | else:
47 | v = cast(int, v)
48 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
49 | if batch_norm:
50 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
51 | else:
52 | layers += [conv2d, nn.ReLU(inplace=True)]
53 | in_channels = v
54 | return nn.Sequential(*layers)
55 |
56 |
57 | def load_vgg(model_dir: str, pretrained: bool = False, **kwargs):
58 | if pretrained:
59 | kwargs['init_weights'] = False
60 |
61 | cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M']
62 | model = VGG(make_layers(cfg, batch_norm=False), **kwargs)
63 |
64 | if pretrained:
65 | state_dict = load_state_dict_from_url('https://download.pytorch.org/models/vgg16-397923af.pth',
66 | model_dir=model_dir,
67 | file_name="vgg16-397923af.pth",
68 | progress=True)
69 | model.load_state_dict(state_dict)
70 | print(f"Loaded pretrained VGG16 model from '{model_dir}/vgg16-397923af.pth'")
71 |
72 | return model
73 |
--------------------------------------------------------------------------------
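
Usage sketch for the loss utilities above: adopt_weight zeroes a loss weight until a warm-up step is reached, which is how a discriminator term is typically gated in during VQGAN training; hinge_d_loss and vanilla_d_loss are the two discriminator objectives. Values here are illustrative:

    import torch
    from core.taming.utils.losses_utils import adopt_weight, hinge_d_loss

    logits_real = torch.randn(4, 1, 30, 30)   # stand-in patch-discriminator outputs
    logits_fake = torch.randn(4, 1, 30, 30)

    disc_factor = adopt_weight(weight=1.0, global_step=500, threshold=1000)  # 0.0 before step 1000
    d_loss = disc_factor * hinge_d_loss(logits_real, logits_fake)
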
/core/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from core.utils.make_cutouts import MakeCutouts
2 | from core.utils.normalize import Normalize
3 | from core.utils.helpers import resize_image, get_optimizer, get_scheduler, load_vqgan_model, global_seed
4 |
5 | __all__ = [
6 | "MakeCutouts",
7 | "Normalize",
8 | "resize_image",
9 | "get_optimizer",
10 | "get_scheduler",
11 | "load_vqgan_model",
12 | "global_seed"
13 | ]
14 |
--------------------------------------------------------------------------------
/core/utils/gradients.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | class ReplaceGrad(torch.autograd.Function):
6 | @staticmethod
7 | def forward(ctx, x_forward, x_backward):
8 | ctx.shape = x_backward.shape
9 | return x_forward
10 |
11 | @staticmethod
12 | def backward(ctx, grad_in):
13 | return None, grad_in.sum_to_size(ctx.shape)
14 |
15 |
16 | class ClampWithGrad(torch.autograd.Function):
17 | @staticmethod
18 | def forward(ctx, input, min, max):
19 | ctx.min = min
20 | ctx.max = max
21 | ctx.save_for_backward(input)
22 | return input.clamp(min, max)
23 |
24 | @staticmethod
25 | def backward(ctx, grad_in):
26 | input, = ctx.saved_tensors
27 | return grad_in * (grad_in * (input - input.clamp(ctx.min, ctx.max)) >= 0), None, None
28 |
29 |
30 | def vector_quantize(x, codebook):
31 | d = x.pow(2).sum(dim=-1, keepdim=True) + codebook.pow(2).sum(dim=1) - 2 * x @ codebook.T
32 | indices = d.argmin(-1)
33 | x_q = F.one_hot(indices, codebook.shape[0]).to(d.dtype) @ codebook
34 | return ReplaceGrad.apply(x_q, x)
35 |
--------------------------------------------------------------------------------
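
Usage sketch for the gradient helpers above: vector_quantize snaps a channels-last latent to its nearest codebook entries while ReplaceGrad keeps the backward pass attached to the unquantized latent, and ClampWithGrad clamps values to a range while only zeroing gradients that would push further out of range. This mirrors how scripts/generate.py uses them; shapes are illustrative:

    import torch
    from core.utils.gradients import ClampWithGrad, vector_quantize

    codebook = torch.randn(1024, 256)                     # (n_codes, code_dim)
    z = torch.randn(1, 16, 16, 256, requires_grad=True)   # channels-last latent
    z_q = vector_quantize(z, codebook)                    # straight-through nearest-code lookup

    decoded = torch.rand(1, 3, 64, 64, requires_grad=True)
    image = ClampWithGrad.apply(decoded, 0, 1)            # clamp to [0, 1], gradients preserved
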
/core/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 |
4 | import numpy as np
5 |
6 | import torch
7 | import torch.optim as optim
8 | from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
9 |
10 | from PIL import Image
11 |
12 | from core.taming.models import vqgan
13 | from core.optimizer import DiffGrad, AdamP, RAdam
14 |
15 |
16 | def resize_image(image, out_size):
17 | ratio = image.size[0] / image.size[1]
18 | area = min(image.size[0] * image.size[1], out_size[0] * out_size[1])
19 | size = round((area * ratio)**0.5), round((area / ratio)**0.5)
20 | return image.resize(size, Image.LANCZOS)
21 |
22 |
23 | def get_optimizer(z, optimizer="Adam", step_size=0.1):
24 | if optimizer == "Adam":
25 | opt = optim.Adam([z], lr=step_size) # LR=0.1 (Default)
26 | elif optimizer == "AdamW":
27 | opt = optim.AdamW([z], lr=step_size) # LR=0.2
28 | elif optimizer == "Adagrad":
29 | opt = optim.Adagrad([z], lr=step_size) # LR=0.5+
30 | elif optimizer == "Adamax":
31 | opt = optim.Adamax([z], lr=step_size) # LR=0.5+?
32 | elif optimizer == "DiffGrad":
33 | opt = DiffGrad([z], lr=step_size) # LR=2+?
34 | elif optimizer == "AdamP":
35 | opt = AdamP([z], lr=step_size) # LR=2+?
36 | elif optimizer == "RAdam":
37 | opt = RAdam([z], lr=step_size) # LR=2+?
38 | return opt
39 |
40 |
41 | def get_scheduler(optimizer, max_iterations, nwarm_restarts=-1):
42 | if nwarm_restarts == -1:
43 | return None
44 |
45 | T_0 = max_iterations
46 | if nwarm_restarts > 0:
47 | T_0 = int(np.ceil(max_iterations / nwarm_restarts))
48 |
49 | return CosineAnnealingWarmRestarts(optimizer, T_0=T_0)
50 |
51 |
52 | def load_vqgan_model(config_path, checkpoint_path, model_dir=None):
53 | with open(config_path, 'r') as f:
54 | config = json.load(f)
55 |
56 | model = vqgan.VQModel(model_dir=model_dir, **config["params"])
57 | model.eval().requires_grad_(False)
58 | model.init_from_ckpt(checkpoint_path)
59 |
60 | del model.loss
61 | return model
62 |
63 |
64 | def global_seed(seed: int):
65 | seed = seed if seed != -1 else torch.seed()
66 | if seed > 2**32 - 1:
67 | seed = seed >> 32
68 |
69 | random.seed(seed)
70 | np.random.seed(seed)
71 | torch.manual_seed(seed)
72 | torch.cuda.manual_seed_all(seed)
73 | print(f"Global seed set to {seed}.")
74 |
--------------------------------------------------------------------------------
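
Usage sketch for the helpers above: get_optimizer attaches one of the supported optimizers to a latent tensor, and get_scheduler optionally wraps it in cosine warm restarts (it returns None when nwarm_restarts is -1). Values are illustrative:

    import torch
    from core.utils.helpers import get_optimizer, get_scheduler

    z = torch.zeros(1, 256, 16, 16, requires_grad=True)
    opt = get_optimizer(z, optimizer="AdamW", step_size=0.2)
    sched = get_scheduler(opt, max_iterations=500, nwarm_restarts=2)

    for _ in range(3):
        opt.zero_grad(set_to_none=True)
        loss = z.pow(2).mean()      # placeholder loss
        loss.backward()
        opt.step()
        if sched is not None:
            sched.step()
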
/core/utils/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import io
3 | import pickle
4 |
5 | import requests
6 |
7 | import torch
8 | from torch.serialization import (
9 | _get_restore_location, _maybe_decode_ascii, _open_file_like, _open_zipfile_reader
10 | )
11 |
12 | from tqdm import tqdm
13 |
14 |
15 | def safe_load(f, map_location=None, pickle_module=pickle, pickle_file='data.pkl', **pickle_load_args):
16 | with _open_file_like(f, 'rb') as opened_file:
17 | with _open_zipfile_reader(opened_file) as zip_file:
18 | restore_location = _get_restore_location(map_location)
19 |
20 | loaded_storages = {}
21 |
22 | def load_tensor(data_type, size, key, location):
23 | name = f'data/{key}'
24 | dtype = data_type(0).dtype
25 |
26 | storage = zip_file.get_storage_from_record(name, size, dtype).storage()
27 | loaded_storages[key] = restore_location(storage, location)
28 |
29 | def persistent_load(saved_id):
30 | assert isinstance(saved_id, tuple)
31 | typename = _maybe_decode_ascii(saved_id[0])
32 | data = saved_id[1:]
33 |
34 | assert typename == 'storage', \
35 | f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
36 | data_type, key, location, size = data
37 | if key not in loaded_storages:
38 | load_tensor(data_type, size, key, _maybe_decode_ascii(location))
39 | storage = loaded_storages[key]
40 | return storage
41 |
42 | load_module_mapping = {
43 | 'torch.tensor': 'torch._tensor'
44 | }
45 |
46 | class UnpicklerWrapper(pickle_module.Unpickler):
47 | def find_class(self, mod_name, name):
48 | try:
49 | mod_name = load_module_mapping.get(mod_name, mod_name)
50 | return super().find_class(mod_name, name)
51 | except Exception:
52 | pass
53 |
54 | # Load the data (which may in turn use `persistent_load` to load tensors)
55 | data_file = io.BytesIO(zip_file.get_record(pickle_file))
56 |
57 | unpickler = UnpicklerWrapper(data_file, **pickle_load_args)
58 | unpickler.persistent_load = persistent_load
59 | result = unpickler.load()
60 |
61 | torch._utils._validate_loaded_sparse_tensors()
62 |
63 | return result
64 |
65 |
66 | def download(url, local_path, chunk_size=1024):
67 | os.makedirs(os.path.split(local_path)[0], exist_ok=True)
68 | with requests.get(url, stream=True) as r:
69 | total_size = int(r.headers.get("content-length", 0))
70 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
71 | with open(local_path, "wb") as f:
72 | for data in r.iter_content(chunk_size=chunk_size):
73 | if data:
74 | f.write(data)
75 | pbar.update(len(data))
76 |
--------------------------------------------------------------------------------
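
Usage sketch for the loader above: download streams a file to disk with a progress bar, and safe_load reads a checkpoint saved in torch's zip-archive format without going through torch.load. The URL and path below are placeholders, not values from this repository:

    from core.utils.loader import download, safe_load

    url = "https://example.com/vqgan.ckpt"      # placeholder
    path = "./models/vqgan.ckpt"                # placeholder
    download(url, path)
    checkpoint = safe_load(path, map_location="cpu")
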
/core/utils/make_cutouts.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | import kornia.augmentation as K
5 |
6 | CUTOUTS = {
7 | 'Ji': K.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1, p=0.5),
8 | 'Sh': K.RandomSharpness(sharpness=0.5, p=0.5),
9 | 'Gn': K.RandomGaussianNoise(mean=0.0, std=1.0, p=0.5),
10 | 'Pe': K.RandomPerspective(distortion_scale=0.5, p=0.5),
11 | 'Ro': K.RandomRotation(degrees=15, p=0.5),
12 | 'Af': K.RandomAffine(degrees=15, translate=0.1, shear=15, padding_mode='border', keepdim=True, p=0.5),
13 | 'Et': K.RandomElasticTransform(p=0.5),
14 | 'Hf': K.RandomHorizontalFlip(p=0.5),
15 | 'Ts': K.RandomThinPlateSpline(scale=0.2, same_on_batch=False, p=0.5),
16 | 'Er': K.RandomErasing(scale=(0.02, 0.33), ratio=(0.3, 3.3), same_on_batch=False, p=0.5),
17 | }
18 |
19 |
20 | class MakeCutouts(nn.Module):
21 | def __init__(self, augments, cut_size, cutn, cut_pow=1.):
22 | super().__init__()
23 | self.cut_size = cut_size
24 | self.cutn = cutn
25 | self.cut_pow = cut_pow
26 |
27 | augment_list = []
28 | for item in augments:
29 | if item == 'Cr':
30 | aug = K.RandomCrop(size=(self.cut_size, self.cut_size), p=0.5)
31 | elif item == 'Re':
32 | aug = K.RandomResizedCrop(size=(self.cut_size, self.cut_size), cropping_mode='resample', p=0.5)
33 | else:
34 | aug = CUTOUTS[item]
35 | augment_list.append(aug)
36 |
37 | print(f"Augmentations: {augment_list}")
38 | self.augs = nn.Sequential(*augment_list)
39 |
40 | self.noise_fac = 0.1
41 |
42 | # Pooling
43 | self.av_pool = nn.AdaptiveAvgPool2d((self.cut_size, self.cut_size))
44 | self.max_pool = nn.AdaptiveMaxPool2d((self.cut_size, self.cut_size))
45 |
46 | def forward(self, input):
47 | cutouts = []
48 |
49 | for _ in range(self.cutn):
50 | # Use Pooling
51 | cutout = (self.av_pool(input) + self.max_pool(input)) / 2
52 | cutouts.append(cutout)
53 |
54 | batch = self.augs(torch.cat(cutouts, dim=0))
55 |
56 | if self.noise_fac:
57 | facs = batch.new_empty([self.cutn, 1, 1, 1]).uniform_(0, self.noise_fac)
58 | batch = batch + facs * torch.randn_like(batch)
59 | return batch
60 |
--------------------------------------------------------------------------------
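
Usage sketch for MakeCutouts above: it pools the input image to cut_size, stacks cutn copies, runs the kornia augmentations selected by the two-letter codes in CUTOUTS (plus 'Cr' and 'Re'), and optionally adds noise. Illustrative call:

    import torch
    from core.utils.make_cutouts import MakeCutouts

    make_cutouts = MakeCutouts(augments=['Af', 'Pe', 'Ji'], cut_size=224, cutn=8)
    image = torch.rand(1, 3, 256, 256)      # single RGB image in [0, 1]
    batch = make_cutouts(image)             # -> (cutn, 3, cut_size, cut_size)
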
/core/utils/noises.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from PIL import Image
4 |
5 |
6 | def perlin_noise_2d(shape, res):
7 | def interpolant(t):
8 | return t*t*t*(t*(t*6 - 15) + 10)
9 |
10 | delta = (res[0] / shape[0], res[1] / shape[1])
11 | d = (shape[0] // res[0], shape[1] // res[1])
12 | grid = np.mgrid[0:res[0]:delta[0], 0:res[1]:delta[1]].transpose(1, 2, 0) % 1
13 |
14 | # Gradients
15 | angles = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1)
16 | gradients = np.dstack((np.cos(angles), np.sin(angles)))
17 | gradients = gradients.repeat(d[0], 0).repeat(d[1], 1)
18 | g00 = gradients[ :-d[0], :-d[1]]
19 | g10 = gradients[d[0]: , :-d[1]]
20 | g01 = gradients[ :-d[0],d[1]: ]
21 | g11 = gradients[d[0]: ,d[1]: ]
22 |
23 | # Ramps
24 | n00 = np.sum(np.dstack((grid[:, :, 0] , grid[:, :, 1] )) * g00, 2)
25 | n10 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1] )) * g10, 2)
26 | n01 = np.sum(np.dstack((grid[:, :, 0] , grid[:, :, 1]-1)) * g01, 2)
27 | n11 = np.sum(np.dstack((grid[:, :, 0] - 1, grid[:, :, 1]-1)) * g11, 2)
28 |
29 | # Interpolation
30 | t = interpolant(grid)
31 | n0 = n00 * (1 - t[:, :, 0]) + t[:, :, 0] * n10
32 | n1 = n01 * (1 - t[:, :, 0]) + t[:, :, 0] * n11
33 | return np.sqrt(2) * ((1 - t[:, :, 1]) * n0 + t[:, :, 1] * n1)
34 |
35 |
36 | def fractal_noise_2d(shape, res, octaves=1, persistence=0.5, lacunarity=2):
37 | noise = np.zeros(shape)
38 | frequency = 1
39 | amplitude = 1
40 |
41 | for _ in range(octaves):
42 | noise += amplitude * perlin_noise_2d(shape, (frequency * res[0], frequency * res[1]))
43 | frequency *= lacunarity
44 | amplitude *= persistence
45 | return (noise - np.min(noise)) / (np.max(noise) - np.min(noise))
46 |
47 |
48 | def random_fractal_image(width, height):
49 | _pow = int(np.ceil(np.log(max(width, height)) / np.log(2)))
50 | octaves = _pow - 4
51 | size = 2 ** _pow
52 | r = fractal_noise_2d((size, size), (32, 32), octaves=octaves)
53 | g = fractal_noise_2d((size, size), (32, 32), octaves=octaves)
54 | b = fractal_noise_2d((size, size), (32, 32), octaves=octaves)
55 |
56 | tile = np.dstack((r, g, b))[:height, :width, :]
57 | return Image.fromarray((255.9 * tile).astype('uint8'))
58 |
59 |
60 | def random_noise_image(width, height):
61 | return Image.fromarray(
62 | np.random.randint(0, 255, (height, width, 3), dtype=np.dtype('uint8'))
63 | )
64 |
65 |
66 | def gradient_2d(start, stop, width, height, is_horizontal):
67 | if is_horizontal:
68 | return np.tile(np.linspace(start, stop, width), (height, 1))
69 | else:
70 | return np.tile(np.linspace(start, stop, height), (width, 1)).T
71 |
72 |
73 | def gradient_3d(width, height, starts, stops, is_horizontal_list):
74 | result = np.zeros((height, width, len(starts)), dtype=float)
75 |
76 | for i, (start, stop, is_horizontal) in enumerate(zip(starts, stops, is_horizontal_list)):
77 | result[:, :, i] = gradient_2d(start, stop, width, height, is_horizontal)
78 |
79 | return result
80 |
81 |
82 | def random_gradient_image(width, height):
83 | array = gradient_3d(
84 | width,
85 | height,
86 | (0, 0, np.random.randint(0, 255)),
87 | (np.random.randint(1, 255), np.random.randint(2, 255), np.random.randint(3, 128)),
88 | (True, False, False)
89 | )
90 | random_image = Image.fromarray(np.uint8(array))
91 | return random_image
92 |
--------------------------------------------------------------------------------
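
Usage sketch for the noise generators above: these produce the 'pixels', 'fractal' and 'gradient' init images that scripts/generate.py can start from, each returned as a PIL image:

    from core.utils.noises import random_noise_image, random_fractal_image, random_gradient_image

    noise = random_noise_image(256, 256)
    fractal = random_fractal_image(256, 256)
    gradient = random_gradient_image(256, 256)
    fractal.save("fractal_preview.png")     # hypothetical output path
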
/core/utils/normalize.py:
--------------------------------------------------------------------------------
1 | # https://github.com/pratogab/batch-transforms
2 |
3 | import torch
4 |
5 |
6 | class Normalize:
7 | """Applies the :class:`~torchvision.transforms.Normalize` transform to a batch of images.
8 |
9 | .. note::
10 | This transform acts out of place by default, i.e., it does not mutate the input tensor.
11 |
12 | Args:
13 | mean (sequence):
14 | Sequence of means for each channel.
15 | std (sequence):
16 | Sequence of standard deviations for each channel.
17 | inplace(bool,optional):
18 | Bool to make this operation in-place.
19 | dtype (torch.dtype,optional):
20 | The data type of tensors to which the transform will be applied.
21 | device (torch.device,optional):
22 | The device of tensors to which the transform will be applied.
23 | """
24 | def __init__(self, mean, std, inplace=False, dtype=torch.float, device='cpu'):
25 | self.mean = torch.as_tensor(mean, dtype=dtype, device=device)[None, :, None, None]
26 | self.std = torch.as_tensor(std, dtype=dtype, device=device)[None, :, None, None]
27 | self.inplace = inplace
28 |
29 | def __call__(self, tensor):
30 | """
31 | Args:
32 | tensor (Tensor): Tensor of size (N, C, H, W) to be normalized.
33 |
34 | Returns:
35 | Tensor: Normalized Tensor.
36 | """
37 | if not self.inplace:
38 | tensor = tensor.clone()
39 |
40 | tensor.sub_(self.mean).div_(self.std)
41 | return tensor
42 |
--------------------------------------------------------------------------------
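
Usage sketch for Normalize above, with the CLIP statistics that scripts/generate.py passes in:

    import torch
    from core.utils.normalize import Normalize

    normalize = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                          std=[0.26862954, 0.26130258, 0.27577711])
    batch = torch.rand(8, 3, 224, 224)
    batch = normalize(batch)                # out of place by default
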
/core/utils/prompt.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from core.utils.gradients import ReplaceGrad
6 |
7 |
8 | class Prompt(nn.Module):
9 | def __init__(self, embed, weight=1., stop=float('-inf')):
10 | super().__init__()
11 | self.register_buffer('embed', embed)
12 | self.register_buffer('weight', torch.as_tensor(weight))
13 | self.register_buffer('stop', torch.as_tensor(stop))
14 |
15 | def forward(self, input):
16 | input_normed = F.normalize(input.unsqueeze(1), dim=2)
17 | embed_normed = F.normalize(self.embed.unsqueeze(0), dim=2)
18 | dists = input_normed.sub(embed_normed).norm(dim=2).div(2).arcsin().pow(2).mul(2)
19 | dists = dists * self.weight.sign()
20 | return self.weight.abs() * ReplaceGrad.apply(dists, torch.maximum(dists, self.stop)).mean()
21 |
22 |
23 | def parse_prompt(prompt):
24 | vals = prompt.rsplit(':', 2)
25 | vals = vals + ['', '1', '-inf'][len(vals):]
26 | return vals[0], float(vals[1]), float(vals[2])
27 |
--------------------------------------------------------------------------------
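
Usage sketch for parse_prompt above: prompts use an optional "text:weight:stop" suffix, with weight defaulting to 1 and stop to -inf:

    from core.utils.prompt import parse_prompt

    parse_prompt("a misty forest")            # ('a misty forest', 1.0, -inf)
    parse_prompt("a misty forest:0.8")        # ('a misty forest', 0.8, -inf)
    parse_prompt("a misty forest:0.8:-0.5")   # ('a misty forest', 0.8, -0.5)
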
/data/.gitignore:
--------------------------------------------------------------------------------
1 | *.txt
2 |
3 | *.jpg
4 | *.jpeg
5 | *.png
6 |
--------------------------------------------------------------------------------
/docker-compose.yml:
--------------------------------------------------------------------------------
1 | version: "3"
2 |
3 | services:
4 | generate:
5 | build: ./
6 | command: python -m scripts.generate -c /configs/docker.json
7 | volumes:
8 | - ./models:/models
9 | - ./configs:/configs
10 | - ./core:/app/core
11 | - ./scripts:/app/scripts
12 | - ./outputs:/outputs
13 | environment:
14 | - DEVICE=cuda
15 | deploy:
16 | resources:
17 | reservations:
18 | devices:
19 | - capabilities: [gpu]
20 |
21 | train:
22 | build: ./
23 | command: python -m scripts.train -c /configs/models/vqgan_custom_docker.json
24 | volumes:
25 | - ./models:/models
26 | - ./configs:/configs
27 | - ./core:/app/core
28 | - ./scripts:/app/scripts
29 | - ./outputs:/outputs
30 | environment:
31 | - DEVICE=cuda
32 | deploy:
33 | resources:
34 | reservations:
35 | devices:
36 | - capabilities: [gpu]
37 |
--------------------------------------------------------------------------------
/models/.gitignore:
--------------------------------------------------------------------------------
1 | *.pt
2 | *.pth
3 | *.ckpt
4 | *.bin
5 | *.pkl
6 |
--------------------------------------------------------------------------------
/outputs/.gitignore:
--------------------------------------------------------------------------------
1 | *.png
2 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.9.0
2 | torchvision==0.10.0
3 |
4 | einops==0.3.0
5 | kornia==0.5.7
6 |
7 | Pillow==8.3.2
8 | numpy==1.20.2
9 |
10 | requests==2.24.0
11 | tqdm==4.51.0
12 |
13 | regex==2021.4.4
14 | ftfy==6.0.3
15 |
--------------------------------------------------------------------------------
/samples/forest.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/forest.png
--------------------------------------------------------------------------------
/samples/ghost_pokemon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/ghost_pokemon.png
--------------------------------------------------------------------------------
/samples/gundam.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/gundam.png
--------------------------------------------------------------------------------
/samples/landscape.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/landscape.png
--------------------------------------------------------------------------------
/samples/sailor_moon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/sailor_moon.png
--------------------------------------------------------------------------------
/samples/waterfall.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/kcosta42/VQGAN-CLIP-Docker/73bdc5ed8581e9710a3a390db5389f0827ae1696/samples/waterfall.png
--------------------------------------------------------------------------------
/scripts/generate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | import torchvision.transforms.functional as TF
9 |
10 | import numpy as np
11 |
12 | from PIL import Image
13 |
14 | from tqdm import tqdm
15 |
16 | from core.schemas import Config
17 | from core.clip import clip
18 |
19 | from core.utils import MakeCutouts, Normalize, resize_image, get_optimizer, get_scheduler, load_vqgan_model, global_seed
20 | from core.utils.noises import random_noise_image, random_fractal_image, random_gradient_image
21 | from core.utils.prompt import Prompt, parse_prompt
22 | from core.utils.gradients import ClampWithGrad, vector_quantize
23 |
24 |
25 | PARAMS: Config = None
26 | DEVICE = torch.device(os.environ.get("DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu'))
27 | NORMALIZE = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
28 | std=[0.26862954, 0.26130258, 0.27577711], device=DEVICE)
29 |
30 |
31 | def parse_args():
32 | parser = argparse.ArgumentParser()
33 | parser.add_argument("-c", "--config", type=str, required=True, help="Path to configuration file.")
34 | return parser.parse_args()
35 |
36 |
37 | def initialize_image(model):
38 | f = 2**(model.decoder.num_resolutions - 1)
39 | toksX, toksY = PARAMS.size[0] // f, PARAMS.size[1] // f
40 | sideX, sideY = toksX * f, toksY * f
41 |
42 | def encode(img):
43 | pil_image = img.convert('RGB').resize((sideX, sideY), Image.LANCZOS)
44 | pil_tensor = TF.to_tensor(pil_image)
45 | z, *_ = model.encode(pil_tensor.to(DEVICE).unsqueeze(0) * 2 - 1)
46 | return z
47 |
48 | if PARAMS.init_image and os.path.exists(PARAMS.init_image):
49 | z = encode(Image.open(PARAMS.init_image))
50 | elif PARAMS.init_noise == 'pixels':
51 | z = encode(random_noise_image(PARAMS.size[0], PARAMS.size[1]))
52 | elif PARAMS.init_noise == 'fractal':
53 | z = encode(random_fractal_image(PARAMS.size[0], PARAMS.size[1]))
54 | elif PARAMS.init_noise == 'gradient':
55 | z = encode(random_gradient_image(PARAMS.size[0], PARAMS.size[1]))
56 | else:
57 | e_dim = model.quantize.e_dim
58 | n_toks = model.quantize.n_e
59 |
60 | one_hot = F.one_hot(torch.randint(n_toks, [toksY * toksX], device=DEVICE), n_toks).float()
61 | z = one_hot @ model.quantize.embedding.weight
62 | z = z.view([-1, toksY, toksX, e_dim]).permute(0, 3, 1, 2)
63 |
64 | return z
65 |
66 |
67 | def tokenize(model, perceptor, make_cutouts):
68 | f = 2**(model.decoder.num_resolutions - 1)
69 | toksX, toksY = PARAMS.size[0] // f, PARAMS.size[1] // f
70 | sideX, sideY = toksX * f, toksY * f
71 |
72 | prompts = []
73 | for prompt in PARAMS.prompts:
74 | txt, weight, stop = parse_prompt(prompt)
75 | embed = perceptor.encode_text(clip.tokenize(txt).to(DEVICE)).float()
76 | prompts.append(Prompt(embed, weight, stop).to(DEVICE))
77 |
78 | for prompt in PARAMS.image_prompts:
79 | path, weight, stop = parse_prompt(prompt)
80 | img = Image.open(path)
81 | pil_image = img.convert('RGB')
82 | img = resize_image(pil_image, (sideX, sideY))
83 | batch = make_cutouts(TF.to_tensor(img).unsqueeze(0).to(DEVICE))
84 | embed = perceptor.encode_image(NORMALIZE(batch)).float()
85 | prompts.append(Prompt(embed, weight, stop).to(DEVICE))
86 |
87 | for seed, weight in zip(PARAMS.noise_prompt_seeds, PARAMS.noise_prompt_weights):
88 | gen = torch.Generator().manual_seed(seed)
89 | embed = torch.empty([1, perceptor.visual.output_dim]).normal_(generator=gen)
90 | prompts.append(Prompt(embed, weight).to(DEVICE))
91 |
92 | return prompts
93 |
94 |
95 | def synth(z, *, model):
96 | z_q = vector_quantize(z.movedim(1, 3), model.quantize.embedding.weight).movedim(3, 1)
97 | z_q = ClampWithGrad.apply(model.decode(z_q).add(1).div(2), 0, 1)
98 |
99 | if PARAMS.pixelart:
100 | z_q = F.avg_pool2d(z_q, tuple(np.ceil(np.divide(PARAMS.size, PARAMS.pixelart)).astype('uint8')))
101 |
102 | return z_q
103 |
104 |
105 | @torch.no_grad()
106 | def checkin(z, losses, **kwargs):
107 | losses_str = ', '.join(f'{loss.item():g}' for loss in losses)
108 | tqdm.write(f"step: {kwargs['step']}, loss: {sum(losses).item():g}, losses: {losses_str}")
109 | out = synth(z, model=kwargs['model'])
110 |
111 | filename = "output"
112 | if len(PARAMS.prompts):
113 | filename = '_'.join(PARAMS.prompts).replace(' ', '_')
114 |
115 | path = f"{PARAMS.output_dir}/{filename}.png"
116 | TF.to_pil_image(out[0].cpu()).save(path)
117 |
118 |
119 | def ascend_txt(z, **kwargs):
120 | out = synth(z, model=kwargs['model'])
121 | cutouts = kwargs['make_cutouts'](out)
122 | iii = kwargs['perceptor'].encode_image(NORMALIZE(cutouts)).float()
123 |
124 | step = kwargs['step']
125 | result = []
126 | if PARAMS.init_weight:
127 | mse_weight = kwargs['mse_weight']
128 | result.append(F.mse_loss(z, kwargs['z_orig']) * mse_weight / 2)
129 |
130 | mse_decay = PARAMS.init_weight / (PARAMS.max_iterations / PARAMS.mse_decay_rate)
131 | with torch.no_grad():
132 | if step > 0 and step % PARAMS.mse_decay_rate == 0:
133 | kwargs['mse_weight'] = max(mse_weight - mse_decay, 0)
134 |
135 | for prompt in kwargs['prompts']:
136 | result.append(prompt(iii))
137 |
138 | TF.to_pil_image(out[0].cpu()).save(f"{PARAMS.output_dir}/steps/{step}.png")
139 | return result
140 |
141 |
142 | def train(z, **kwargs):
143 | kwargs['optimizer'].zero_grad(set_to_none=True)
144 | lossAll = ascend_txt(z, **kwargs)
145 |
146 | if kwargs['step'] % PARAMS.save_freq == 0 or kwargs['step'] == PARAMS.max_iterations:
147 | checkin(z, lossAll, **kwargs)
148 |
149 | loss = sum(lossAll)
150 | loss.backward()
151 | kwargs['optimizer'].step()
152 |
153 | if kwargs['scheduler'] is not None:
154 | kwargs['scheduler'].step()
155 |
156 | with torch.no_grad():
157 | z.copy_(z.maximum(kwargs['z_min']).minimum(kwargs['z_max']))
158 |
159 |
160 | def main():
161 | model = load_vqgan_model(PARAMS.vqgan_config, PARAMS.vqgan_checkpoint, PARAMS.models_dir).to(DEVICE)
162 | perceptor = clip.load(PARAMS.clip_model, device=DEVICE, root=PARAMS.models_dir)[0].eval().requires_grad_(False).to(DEVICE)
163 |
164 | cut_size = perceptor.visual.input_resolution
165 | make_cutouts = MakeCutouts(PARAMS.augments, cut_size, PARAMS.cutn, cut_pow=PARAMS.cut_pow)
166 |
167 | z_min = model.quantize.embedding.weight.min(dim=0).values[None, :, None, None]
168 | z_max = model.quantize.embedding.weight.max(dim=0).values[None, :, None, None]
169 | z = initialize_image(model)
170 | z_orig = torch.zeros_like(z)
171 | z.requires_grad_(True)
172 |
173 | prompts = tokenize(model, perceptor, make_cutouts)
174 | optimizer = get_optimizer(z, PARAMS.optimizer, PARAMS.step_size)
175 | scheduler = get_scheduler(optimizer, PARAMS.max_iterations, PARAMS.nwarm_restarts)
176 |
177 | kwargs = {
178 | 'model': model,
179 | 'perceptor': perceptor,
180 | 'optimizer': optimizer,
181 | 'scheduler': scheduler,
182 | 'prompts': prompts,
183 | 'make_cutouts': make_cutouts,
184 | 'z_orig': z_orig,
185 | 'z_min': z_min,
186 | 'z_max': z_max,
187 | 'mse_weight': PARAMS.init_weight,
188 | }
189 | try:
190 | for step in tqdm(range(PARAMS.max_iterations)):
191 | kwargs['step'] = step + 1
192 | train(z, **kwargs)
193 | except KeyboardInterrupt:
194 | pass
195 |
196 |
197 | if __name__ == "__main__":
198 | args = parse_args()
199 |
200 | if not os.path.exists(args.config):
201 | exit(f"ERROR: {args.config} not found.")
202 |
203 | print(f"Loading configuration from '{args.config}'")
204 | with open(args.config, 'r') as f:
205 | PARAMS = Config(**json.load(f))
206 |
207 | print(f"Running on {DEVICE}.")
208 | print(PARAMS)
209 |
210 | global_seed(PARAMS.seed)
211 |
212 | main()
213 |
--------------------------------------------------------------------------------
/scripts/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 |
5 | import torch
6 | from torch.utils.data import DataLoader
7 |
8 | import torchvision.transforms.functional as TF
9 | from torchvision import transforms as T
10 | from torchvision.datasets import ImageFolder
11 |
12 | from tqdm import tqdm
13 |
14 | from core.schemas import TrainConfig
15 | from core.utils import global_seed
16 | from core.utils.loader import safe_load
17 | from core.taming.models import vqgan
18 |
19 |
20 | PARAMS: TrainConfig = None
21 | DEVICE = torch.device(os.environ.get("DEVICE", 'cuda' if torch.cuda.is_available() else 'cpu'))
22 |
23 |
24 | def parse_args():
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument("-c", "--config", type=str, required=True, help="Path to configuration file.")
27 | return parser.parse_args()
28 |
29 |
30 | def save_model(model, optimizers, epoch, path):
31 | save_dict = {
32 | "epoch": epoch,
33 | "global_step": model.global_step,
34 | "state_dict": model.state_dict(),
35 | "optimizer_states": [
36 | optimizers[0].state_dict(),
37 | optimizers[1].state_dict(),
38 | ]
39 | }
40 | torch.save(save_dict, path)
41 | tqdm.write(f"Checkpoint saved in {path}")
42 |
43 |
44 | def main():
45 | dataset = ImageFolder(PARAMS.data_dir, T.Compose(
46 | [
47 | T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
48 | T.Resize(PARAMS.params["embed_dim"]),
49 | T.CenterCrop(PARAMS.params["embed_dim"]),
50 | T.ToTensor()
51 | ]
52 | ))
53 | loader = DataLoader(dataset, PARAMS.batch_size, shuffle=True)
54 |
55 | PARAMS.params["model_dir"] = PARAMS.models_dir
56 | model = vqgan.VQModel(**PARAMS.params).to(DEVICE)
57 | model.learning_rate = PARAMS.batch_size * PARAMS.base_learning_rate
58 | model.global_step = 0
59 |
60 | optimizers, _ = model.configure_optimizers()
61 | epoch = 0
62 |
63 | if PARAMS.resume_checkpoint:
64 | save_dict = safe_load(PARAMS.resume_checkpoint, map_location='cpu')
65 | epoch = save_dict["epoch"]
66 | model.global_step = save_dict["global_step"]
67 | model.load_state_dict(save_dict["state_dict"])
68 | optimizers[0].load_state_dict(save_dict["optimizer_states"][0])
69 | optimizers[1].load_state_dict(save_dict["optimizer_states"][1])
70 | print(f"Restored model from {PARAMS.resume_checkpoint}")
71 |
72 | while epoch < PARAMS.epochs:
73 | for i, (images, _) in tqdm(enumerate(loader), total=len(loader)):
74 | images = images.to(DEVICE)
75 |
76 | losses = []
77 | for j, opt in enumerate(optimizers):
78 | loss = model.training_step(images, i, j, device=DEVICE)
79 | losses.append(loss.item())
80 |
81 | opt.zero_grad()
82 | loss.backward()
83 |
84 | opt.step()
85 |
86 | tqdm.write(f"Epoch: {epoch} | Batch: {i} | losses: {losses}")
87 |
88 | if i % 1000 == 0:
89 | save_model(model, optimizers, epoch, f"{PARAMS.models_dir}/checkpoints/last.ckpt")
90 |
91 | with torch.no_grad():
92 | dec, _ = model(model.get_input(images, device=DEVICE))
93 | TF.to_pil_image(dec[0].cpu()).save(f"{PARAMS.output_dir}/training/{epoch}_{i}.png")
94 |
95 | model.global_step += 1
96 | epoch += 1
97 |
98 | save_model(model, optimizers, epoch, f"{PARAMS.models_dir}/checkpoints/final.ckpt")
99 |
100 |
101 | if __name__ == "__main__":
102 | args = parse_args()
103 |
104 | if not os.path.exists(args.config):
105 | exit(f"ERROR: {args.config} not found.")
106 |
107 | print(f"Loading configuration from '{args.config}'")
108 | with open(args.config, 'r') as f:
109 | PARAMS = TrainConfig(**json.load(f))
110 |
111 | print(f"Running on {DEVICE}.")
112 | print(PARAMS)
113 |
114 | global_seed(PARAMS.seed)
115 |
116 | main()
117 |
--------------------------------------------------------------------------------