├── .github
│   ├── FUNDING.yml
│   └── workflows
│       ├── pre-commit.yml
│       ├── python-publish.yml
│       └── sem-version-release.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
│   ├── extensions.json
│   ├── launch.json
│   └── settings.json
├── LICENSE
├── README.md
├── infer_vae.py
├── muse.png
├── muse_maskgit_pytorch
│   ├── __init__.py
│   ├── attn
│   │   ├── __init__.py
│   │   ├── attn_test.py
│   │   ├── ein_attn.py
│   │   ├── sdp_attn.py
│   │   └── xformers_attn.py
│   ├── dataset.py
│   ├── distributed_utils.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── attention.py
│   │   └── mlp.py
│   ├── muse_maskgit_pytorch.py
│   ├── t5.py
│   ├── trainers
│   │   ├── __init__.py
│   │   ├── base_accelerated_trainer.py
│   │   ├── maskgit_trainer.py
│   │   └── vqvae_trainers.py
│   ├── vqgan_vae.py
│   ├── vqgan_vae_taming.py
│   └── vqvae
│       ├── __init__.py
│       ├── config.py
│       ├── discriminator.py
│       ├── layers.py
│       ├── quantize.py
│       └── vqvae.py
├── pyproject.toml
├── scripts
│   └── vqvae_test.py
├── setup.py
├── tpu-vm.env
├── train_muse_maskgit.py
└── train_muse_vae.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | github: [ZeroCool940711]
2 | patreon: zerocool94
3 | ko_fi: zerocool94
4 | open_collective: sygil_dev
5 | custom: ["https://paypal.me/zerocool94"]
6 |
--------------------------------------------------------------------------------
/.github/workflows/pre-commit.yml:
--------------------------------------------------------------------------------
1 | name: pre-commit
2 |
3 | on:
4 | pull_request:
5 | push:
6 | branches:
7 | - main
8 | - dev
9 |
10 | jobs:
11 | pre-commit:
12 | strategy:
13 | fail-fast: false
14 | matrix:
15 | os: [ubuntu-latest]
16 | python-version: ["3.8", "3.10"]
17 |
18 | runs-on: ${{ matrix.os }}
19 | steps:
20 | - name: Checkout
21 | id: checkout
22 | uses: actions/checkout@v3
23 | with:
24 | submodules: "recursive"
25 |
26 | - name: Set up Python
27 | id: setup-python
28 | uses: actions/setup-python@v3
29 | with:
30 | python-version: ${{ matrix.python-version }}
31 |
32 | - name: Install package
33 | id: install-package
34 | run: |
35 | python -m pip install --upgrade pip setuptools wheel
36 | pip install -e '.[dev]'
37 |
38 | - name: Run pre-commit
39 | uses: pre-commit/action@v3.0.0
40 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.github/workflows/sem-version-release.yml:
--------------------------------------------------------------------------------
1 | name: Bump version
2 | on:
3 | push:
4 | branches:
5 | - master
6 | jobs:
7 | build:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - uses: actions/checkout@master
11 | - name: Bump version and push tag
12 | uses: hennejg/github-tag-action@v4.1.jh1
13 | with:
14 | github_token: ${{ secrets.GITHUB_TOKEN }}
15 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.ipynb
2 | wandb
3 | results
4 | models
5 | dataset
6 | taming
7 | ~
8 | input.png
9 | output.png
10 | muse_maskgit_pytorch/wt.py
11 |
12 | # Byte-compiled / optimized / DLL files
13 | __pycache__/
14 | *.py[cod]
15 | *$py.class
16 |
17 | # C extensions
18 | *.so
19 |
20 | # Distribution / packaging
21 | .Python
22 | build/
23 | develop-eggs/
24 | dist/
25 | downloads/
26 | eggs/
27 | .eggs/
28 | lib/
29 | lib64/
30 | parts/
31 | sdist/
32 | var/
33 | wheels/
34 | pip-wheel-metadata/
35 | share/python-wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 | MANIFEST
40 |
41 | # PyInstaller
42 | # Usually these files are written by a python script from a template
43 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
44 | *.manifest
45 | *.spec
46 |
47 | # Installer logs
48 | pip-log.txt
49 | pip-delete-this-directory.txt
50 |
51 | # Unit test / coverage reports
52 | htmlcov/
53 | .tox/
54 | .nox/
55 | .coverage
56 | .coverage.*
57 | .cache
58 | nosetests.xml
59 | coverage.xml
60 | *.cover
61 | *.py,cover
62 | .hypothesis/
63 | .pytest_cache/
64 |
65 | # Translations
66 | *.mo
67 | *.pot
68 |
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | db.sqlite3
73 | db.sqlite3-journal
74 |
75 | # Flask stuff:
76 | instance/
77 | .webassets-cache
78 |
79 | # Scrapy stuff:
80 | .scrapy
81 |
82 | # Sphinx documentation
83 | docs/_build/
84 |
85 | # PyBuilder
86 | target/
87 |
88 | # Jupyter Notebook
89 | .ipynb_checkpoints
90 |
91 | # IPython
92 | profile_default/
93 | ipython_config.py
94 |
95 | # pyenv
96 | .python-version
97 |
98 | # pipenv
99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
102 | # install all needed dependencies.
103 | #Pipfile.lock
104 |
105 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106 | __pypackages__/
107 |
108 | # Celery stuff
109 | celerybeat-schedule
110 | celerybeat.pid
111 |
112 | # SageMath parsed files
113 | *.sage.py
114 |
115 | # Environments
116 | .env
117 | .venv
118 | env/
119 | venv/
120 | ENV/
121 | env.bak/
122 | venv.bak/
123 |
124 | # Spyder project settings
125 | .spyderproject
126 | .spyproject
127 |
128 | # Rope project settings
129 | .ropeproject
130 |
131 | # mkdocs documentation
132 | /site
133 |
134 | # mypy
135 | .mypy_cache/
136 | .dmypy.json
137 | dmypy.json
138 |
139 | # Pyre type checker
140 | .pyre/
141 |
142 | # setuptools-scm version file
143 | muse_maskgit_pytorch/_version.py
144 |
145 | # wandb dir
146 | /wandb/
147 |
148 | # data, output
149 | /data/
150 | /output/
151 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | ci:
4 | autofix_prs: true
5 | autoupdate_branch: 'dev'
6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate'
7 | autoupdate_schedule: weekly
8 |
9 | repos:
10 | - repo: https://github.com/pre-commit/pre-commit-hooks
11 | rev: v4.4.0
12 | hooks:
13 | - id: trailing-whitespace
14 | - id: end-of-file-fixer
15 | - id: check-yaml
16 | - id: check-added-large-files
17 |
18 | - repo: https://github.com/astral-sh/ruff-pre-commit
19 | rev: "v0.0.278"
20 | hooks:
21 | - id: ruff
22 | args: [--fix, --exit-non-zero-on-fix]
23 |
24 | - repo: https://github.com/psf/black
25 | rev: 23.7.0
26 | hooks:
27 | - id: black
28 |
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": [
3 | "ms-python.python",
4 | "charliermarsh.ruff",
5 | "redhat.vscode-yaml",
6 | "codezombiech.gitignore",
7 | "ms-python.black-formatter"
8 | ]
9 | }
10 |
--------------------------------------------------------------------------------
/.vscode/launch.json:
--------------------------------------------------------------------------------
1 | {
2 | "configurations": [
3 | {
4 | "name": "Python: Verify attn impl equivalence",
5 | "type": "python",
6 | "request": "launch",
7 | "module": "attn_test",
8 | "justMyCode": false
9 | }
10 | ]
11 | }
12 |
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "editor.formatOnSaveMode": "file",
3 |
4 | "files.associations": {
5 | ".config": "shellscript",
6 | ".gitignore": "gitignore",
7 | ".vscode/*.json": "jsonc",
8 | "*.txt": "plaintext",
9 | "requirements*.txt": "pip-requirements",
10 | "setup.cfg": "ini",
11 | },
12 |
13 | "[json]": {
14 | "editor.codeActionsOnSave": {
15 | "source.fixAll.sortJSON": false
16 | },
17 | "editor.defaultFormatter": "vscode.json-language-features",
18 | "editor.formatOnSave": true,
19 | "editor.tabSize": 4
20 | },
21 | "[jsonc]": {
22 | "editor.codeActionsOnSave": {
23 | "source.fixAll.sortJSON": false
24 | },
25 | "editor.defaultFormatter": "vscode.json-language-features",
26 | "editor.formatOnSave": true,
27 | "editor.tabSize": 4
28 | },
29 | "json.format.keepLines": true,
30 |
31 | "[python]": {
32 | "editor.formatOnSave": true,
33 | "editor.defaultFormatter": "ms-python.black-formatter",
34 | "editor.codeActionsOnSave": {
35 | "source.organizeImports": true
36 | }
37 | },
38 | "python.formatting.provider": "none",
39 | "ruff.organizeImports": true,
40 | "ruff.args": [ "--line-length=110", "--extend-ignore=F401,F841" ],
41 | "black-formatter.args": [ "--line-length", "110" ],
42 | "python.linting.flake8Enabled": false,
43 | "python.linting.mypyEnabled": false,
44 | }
45 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Muse - Pytorch
4 |
5 | ### Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch originally made by [Lucidrains](https://github.com/lucidrains/muse-maskgit-pytorch).
6 |
7 | ##
8 | We have added additional code to allow anyone to train their own model, and we have optimized the code for low-end hardware.
9 |
10 | ## Join us at Sygil.Dev's [Discord Server](https://discord.gg/ttM8Tm6wge)
11 |
12 | ## Install
13 | To install the code, you have two options:
14 |
15 | 1 - You can install it directly from the repo with pip:
16 | ```bash
17 | $ pip install git+https://github.com/Sygil-Dev/muse-maskgit-pytorch
18 | ```
19 | 2 - Or you can clone it and then install from source:
20 | ```bash
21 | $ git clone https://github.com/Sygil-Dev/muse-maskgit-pytorch
22 | $ cd muse-maskgit-pytorch
23 | $ pip install .
24 | ```
25 |
26 | ## Usage
27 |
28 | First train your VAE - `VQGanVAE`
29 |
30 | ```python
31 | import torch
32 | from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer
33 |
34 | vae = VQGanVAE(
35 | dim = 256,
36 | vq_codebook_size = 512
37 | )
38 |
39 | # train on folder of images, as many images as possible
40 |
41 | trainer = VQGanVAETrainer(
42 | vae = vae,
43 | image_size = 128, # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it
44 | folder = '/path/to/images',
45 | batch_size = 4,
46 | grad_accum_every = 8,
47 | num_train_steps = 50000
48 | ).cuda()
49 |
50 | trainer.train()
51 | ```
52 |
53 | Then pass the trained `VQGanVAE` and a `Transformer` to `MaskGit`
54 |
55 | ```python
56 | import torch
57 | from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer
58 |
59 | # first instantiate your vae
60 |
61 | vae = VQGanVAE(
62 | dim = 256,
63 | vq_codebook_size = 512
64 | ).cuda()
65 |
66 | vae.load('/path/to/vae.pt') # you will want to load the exponentially moving averaged VAE
67 |
68 | # then you plug the vae and transformer into your MaskGit as so
69 |
70 | # (1) create your transformer / attention network
71 |
72 | transformer = MaskGitTransformer(
73 | num_tokens = 512, # must be same as codebook size above
74 | seq_len = 256, # must be equivalent to fmap_size ** 2 in vae
75 | dim = 512, # model dimension
76 | depth = 8, # depth
77 | dim_head = 64, # attention head dimension
78 | heads = 8, # attention heads,
79 | ff_mult = 4, # feedforward expansion factor
80 | t5_name = 't5-small', # name of your T5
81 | )
82 |
83 | # (2) pass your trained VAE and the base transformer to MaskGit
84 |
85 | base_maskgit = MaskGit(
86 | vae = vae, # vqgan vae
87 | transformer = transformer, # transformer
88 | image_size = 256, # image size
89 | cond_drop_prob = 0.25, # conditional dropout, for classifier free guidance
90 | ).cuda()
91 |
92 | # ready your training text and images
93 |
94 | texts = [
95 | 'a child screaming at finding a worm within a half-eaten apple',
96 | 'lizard running across the desert on two feet',
97 | 'waking up to a psychedelic landscape',
98 | 'seashells sparkling in the shallow waters'
99 | ]
100 |
101 | images = torch.randn(4, 3, 256, 256).cuda()
102 |
103 | # feed it into your maskgit instance, with return_loss set to True
104 |
105 | loss = base_maskgit(
106 | images,
107 | texts = texts
108 | )
109 |
110 | loss.backward()
111 |
112 | # do this for a long time on much data
113 | # then...
114 |
115 | images = base_maskgit.generate(texts = [
116 | 'a whale breaching from afar',
117 | 'young girl blowing out candles on her birthday cake',
118 | 'fireworks with blue and green sparkles'
119 | ], cond_scale = 3.) # conditioning scale for classifier free guidance
120 |
121 | images.shape # (3, 3, 256, 256)
122 | ```
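
`generate` returns a float image tensor. A minimal way to write the samples to disk, assuming the values are already in `[0, 1]`, is torchvision's `save_image` (the filenames below are just an example; pass `normalize=True` if your tensors are not already in that range):

```python
from torchvision.utils import save_image

# one file per generated sample; assumes values are already in [0, 1]
for i, image in enumerate(images):
    save_image(image, f"sample_{i}.png")
```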
123 |
124 | Training the super-resolution MaskGit only requires changing one field on `MaskGit` instantiation: you now also pass in `cond_image_size`, the resolution of the lower-resolution image being conditioned on.
125 |
126 | Optionally, you can pass in a different `VAE` as `cond_vae` for the conditioning low-resolution image. By default it will use the `vae` for tokenizing both the high- and low-resolution images; a minimal sketch of passing a separate `cond_vae` follows the full example below.
127 |
128 | ```python
129 | import torch
130 | import torch.nn.functional as F
131 | from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer
132 |
133 | # first instantiate your ViT VQGan VAE
134 | # a VQGan VAE made of transformers
135 |
136 | vae = VQGanVAE(
137 | dim = 256,
138 | vq_codebook_size = 512
139 | ).cuda()
140 |
141 | vae.load('./path/to/vae.pt') # you will want to load the exponentially moving averaged VAE
142 |
143 | # then you plug the VqGan VAE into your MaskGit as so
144 |
145 | # (1) create your transformer / attention network
146 |
147 | transformer = MaskGitTransformer(
148 | num_tokens = 512, # must be same as codebook size above
149 | seq_len = 1024, # must be equivalent to fmap_size ** 2 in vae
150 | dim = 512, # model dimension
151 | depth = 2, # depth
152 | dim_head = 64, # attention head dimension
153 | heads = 8, # attention heads,
154 | ff_mult = 4, # feedforward expansion factor
155 | t5_name = 't5-small', # name of your T5
156 | )
157 |
158 | # (2) pass your trained VAE and the base transformer to MaskGit
159 |
160 | superres_maskgit = MaskGit(
161 | vae = vae,
162 | transformer = transformer,
163 | cond_drop_prob = 0.25,
164 | image_size = 512, # larger image size
165 | cond_image_size = 256, # conditioning image size <- this must be set
166 | ).cuda()
167 |
168 | # ready your training text and images
169 |
170 | texts = [
171 | 'a child screaming at finding a worm within a half-eaten apple',
172 | 'lizard running across the desert on two feet',
173 | 'waking up to a psychedelic landscape',
174 | 'seashells sparkling in the shallow waters'
175 | ]
176 |
177 | images = torch.randn(4, 3, 512, 512).cuda()
178 |
179 | # feed it into your maskgit instance, with return_loss set to True
180 |
181 | loss = superres_maskgit(
182 | images,
183 | texts = texts
184 | )
185 |
186 | loss.backward()
187 |
188 | # do this for a long time on much data
189 | # then...
190 |
191 | images = superres_maskgit.generate(
192 | texts = [
193 | 'a whale breaching from afar',
194 | 'young girl blowing out candles on her birthday cake',
195 | 'fireworks with blue and green sparkles',
196 | 'waking up to a psychedelic landscape'
197 | ],
198 | cond_images = F.interpolate(images, 256), # conditioning images must be passed in for generating from superres
199 | cond_scale = 3.
200 | )
201 |
202 | images.shape # (4, 3, 512, 512)
203 | ```
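
If you have trained a separate VAE at the conditioning resolution, a minimal sketch of passing it in as `cond_vae` (the checkpoint path here is hypothetical) looks like this:

```python
# hypothetical second VAE, used only to tokenize the low-resolution conditioning images
cond_vae = VQGanVAE(
    dim = 256,
    vq_codebook_size = 512
).cuda()

cond_vae.load('./path/to/cond_vae.pt')

superres_maskgit = MaskGit(
    vae = vae,                  # tokenizes the 512px target images
    cond_vae = cond_vae,        # tokenizes the 256px conditioning images
    transformer = transformer,
    cond_drop_prob = 0.25,
    image_size = 512,
    cond_image_size = 256
).cuda()
```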
204 |
205 | All together now
206 |
207 | ```python
208 | from muse_maskgit_pytorch import Muse
209 |
210 | base_maskgit.load('./path/to/base.pt')
211 |
212 | superres_maskgit.load('./path/to/superres.pt')
213 |
214 | # pass in the trained base_maskgit and superres_maskgit from above
215 |
216 | muse = Muse(
217 | base = base_maskgit,
218 | superres = superres_maskgit
219 | )
220 |
221 | images = muse([
222 | 'a whale breaching from afar',
223 | 'young girl blowing out candles on her birthday cake',
224 | 'fireworks with blue and green sparkles',
225 | 'waking up to a psychedelic landscape'
226 | ])
227 |
228 | images # List[PIL.Image.Image]
229 | ```
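
Because `muse(...)` returns plain PIL images, saving the results is straightforward, for example:

```python
# filenames are illustrative
for i, image in enumerate(images):
    image.save(f"muse_sample_{i}.png")
```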
230 |
231 | ## Training
232 |
233 | Training should be done in 4 stages.
234 |
235 | 1. Train the base VAE (swap out `dataset_name` with your HuggingFace dataset):
236 |
237 | ```
238 | accelerate launch train_muse_vae.py --dataset_name="Isamu136/big-animal-dataset"
239 | ```
240 | 2. Once you have trained the base VAE enough, move your latest checkpoint to a new location. Then run:
241 |
242 | ```
243 | accelerate launch train_muse_maskgit.py --dataset_name="Isamu136/big-animal-dataset" --vae_path=path_to_vae_checkpoint
244 | ```
245 |
246 | Alternatively, if you want to use a pretrained autoencoder, download one from [here](https://github.com/CompVis/taming-transformers) and extract it. In the command below we use vqgan_imagenet_f16_1024; change the paths accordingly.
247 |
248 | ```
249 | accelerate launch train_muse_maskgit.py --dataset_name="Isamu136/big-animal-dataset" --taming_model_path="models/image_net_f16/ckpts/last.ckpt" --taming_config_path="models/image_net_f16/configs/model.yaml" --validation_prompt="elephant"
250 | ```
251 |
252 | Or, if you want to train on cifar10, try:
253 |
254 | ```
255 | accelerate launch train_muse_maskgit.py --dataset_name="cifar10" --taming_model_path="models/image_net_f16/ckpts/last.ckpt" --taming_config_path="models/image_net_f16/configs/model.yaml" --validation_prompt="0" --image_column="img" --caption_column="label"
256 | ```
257 |
258 | ## Checkpoints and Pretrained Models
259 | We currently do not have a usable pretrained model for Muse, but we are training one with the resources we have available. For more information, check the [Sygil Muse](https://huggingface.co/Sygil/Sygil-Muse) repository on HuggingFace, where we upload checkpoints from our different test runs and where we will publish the final weights once we have something everyone can use.
260 |
261 | ## Appreciation
262 | - [Lucidrains](https://github.com/lucidrains/muse-maskgit-pytorch) for the original Muse-Maskgit-Pytorch implementation.
263 | - The [ShoukanLabs](https://github.com/ShoukanLabs) team for contributing so much to improving the code and adding new features.
264 | - StabilityAI for the sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence.
265 |
266 | - 🤗 Huggingface for the transformers and accelerate libraries, both of which are wonderful.
267 |
268 | ## Todo
269 |
270 | - [x] test end-to-end
271 |
272 | - [x] separate cond_images_or_ids, it is not done right
273 |
274 | - [x] add training code for vae
275 |
276 | - [x] add optional self-conditioning on embeddings
277 |
278 | - [x] combine with token critic paper, already implemented at Phenaki
279 |
280 | - [x] hook up accelerate training code for maskgit
281 |
282 | - [ ] train a base model
283 |
284 | ## Citations
285 |
286 | ```bibtex
287 | @inproceedings{Chang2023MuseTG,
288 | title = {Muse: Text-To-Image Generation via Masked Generative Transformers},
289 | author = {Huiwen Chang and Han Zhang and Jarred Barber and AJ Maschinot and Jos{\'e} Lezama and Lu Jiang and Ming-Hsuan Yang and Kevin P. Murphy and William T. Freeman and Michael Rubinstein and Yuanzhen Li and Dilip Krishnan},
290 | year = {2023}
291 | }
292 | ```
293 |
294 | ```bibtex
295 | @article{Chen2022AnalogBG,
296 | title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
297 |     author = {Ting Chen and Ruixiang Zhang and Geoffrey E. Hinton},
298 | journal = {ArXiv},
299 | year = {2022},
300 | volume = {abs/2208.04202}
301 | }
302 | ```
303 |
304 | ```bibtex
305 | @misc{jabri2022scalable,
306 | title = {Scalable Adaptive Computation for Iterative Generation},
307 | author = {Allan Jabri and David Fleet and Ting Chen},
308 | year = {2022},
309 | eprint = {2212.11972},
310 | archivePrefix = {arXiv},
311 | primaryClass = {cs.LG}
312 | }
313 | ```
314 |
315 | ```bibtex
316 | @article{Lezama2022ImprovedMI,
317 | title = {Improved Masked Image Generation with Token-Critic},
318 | author = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa},
319 | journal = {ArXiv},
320 | year = {2022},
321 | volume = {abs/2209.04439}
322 | }
323 | ```
324 |
325 | ```bibtex
326 | @inproceedings{Nijkamp2021SCRIPTSP,
327 | title = {SCRIPT: Self-Critic PreTraining of Transformers},
328 | author = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong},
329 | booktitle = {North American Chapter of the Association for Computational Linguistics},
330 | year = {2021}
331 | }
332 | ```
333 |
334 | ```bibtex
335 | @misc{gilmer2023intriguing,
336 | title = {Intriguing Properties of Transformer Training Instabilities},
337 |     author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
338 | year = {2023},
339 | status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
340 | }
341 | ```
342 |
--------------------------------------------------------------------------------
/infer_vae.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import hashlib
4 | import os
5 | import random
6 | import re
7 | from dataclasses import dataclass
8 | from datetime import datetime
9 | from typing import Optional
10 |
11 | import accelerate
12 | import PIL
13 | import torch
14 | from accelerate.utils import ProjectConfiguration
15 | from datasets import Dataset, Image, load_dataset
16 | from torchvision.utils import save_image
17 | from tqdm import tqdm
18 |
19 | from muse_maskgit_pytorch import (
20 | VQGanVAE,
21 | VQGanVAETaming,
22 | get_accelerator,
23 | )
24 | from muse_maskgit_pytorch.dataset import (
25 | ImageDataset,
26 | get_dataset_from_dataroot,
27 | )
28 | from muse_maskgit_pytorch.vqvae import VQVAE
29 |
30 | # Create the parser
31 | parser = argparse.ArgumentParser()
32 | parser.add_argument(
33 | "--no_center_crop",
34 | action="store_true",
35 | help="Don't do center crop.",
36 | )
37 | parser.add_argument(
38 | "--random_crop",
39 | action="store_true",
40 | help="Crop the images at random locations instead of cropping from the center.",
41 | )
42 | parser.add_argument(
43 | "--no_flip",
44 | action="store_true",
45 | help="Don't flip image.",
46 | )
47 | parser.add_argument(
48 | "--random_image",
49 | action="store_true",
50 | help="Get a random image from the dataset to use for the reconstruction.",
51 | )
52 | parser.add_argument(
53 | "--dataset_save_path",
54 | type=str,
55 | default="dataset",
56 | help="Path to save the dataset if you are making one from a directory",
57 | )
58 | parser.add_argument(
59 | "--seed",
60 | type=int,
61 | default=42,
62 | help="Seed for reproducibility. If set to -1 a random seed will be generated.",
63 | )
64 | parser.add_argument("--valid_frac", type=float, default=0.05, help="validation fraction.")
65 | parser.add_argument(
66 | "--image_column",
67 | type=str,
68 | default="image",
69 | help="The column of the dataset containing an image.",
70 | )
71 | parser.add_argument(
72 | "--mixed_precision",
73 | type=str,
74 | default="no",
75 | choices=["no", "fp16", "bf16"],
76 | help="Precision to train on.",
77 | )
78 | parser.add_argument(
79 | "--results_dir",
80 | type=str,
81 | default="results",
82 | help="Path to save the training samples and checkpoints",
83 | )
84 | parser.add_argument(
85 | "--logging_dir",
86 | type=str,
87 | default=None,
88 | help="Path to log the losses and LR",
89 | )
90 |
91 | # vae_trainer args
92 | parser.add_argument(
93 | "--vae_path",
94 | type=str,
95 | default=None,
96 | help="Path to the vae model. eg. 'results/vae.steps.pt'",
97 | )
98 | parser.add_argument(
99 | "--dataset_name",
100 | type=str,
101 | default=None,
102 | help="Name of the huggingface dataset used.",
103 | )
104 | parser.add_argument(
105 | "--train_data_dir",
106 | type=str,
107 | default=None,
108 | help="Dataset folder where your input images for training are.",
109 | )
110 | parser.add_argument("--dim", type=int, default=128, help="Model dimension.")
111 | parser.add_argument("--batch_size", type=int, default=512, help="Batch Size.")
112 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.")
113 | parser.add_argument("--vq_codebook_size", type=int, default=256, help="Codebook size.")
114 | parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
115 | parser.add_argument(
116 | "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
117 | )
118 | parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
119 | parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
120 | parser.add_argument(
121 | "--image_size",
122 | type=int,
123 | default=256,
124 | help="Image size. You may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it",
125 | )
126 | parser.add_argument(
127 | "--chunk_size",
128 | type=int,
129 | default=256,
130 | help="This is used to split big images into smaller chunks so we can still reconstruct them no matter the size.",
131 | )
132 | parser.add_argument(
133 | "--min_chunk_size",
134 | type=int,
135 | default=8,
136 | help="We use a minimum chunk size to ensure that the image is always reconstructed correctly.",
137 | )
138 | parser.add_argument(
139 | "--overlap_size",
140 | type=int,
141 | default=256,
142 |     help="The overlap size used with --chunk_size to overlap the chunks and make sure the whole image is reconstructed, as well as to remove artifacts caused by doing the reconstruction in chunks.",
143 | )
144 | parser.add_argument(
145 | "--min_overlap_size",
146 | type=int,
147 | default=1,
148 | help="We use a minimum overlap size to ensure that the image is always reconstructed correctly.",
149 | )
150 | parser.add_argument(
151 | "--taming_model_path",
152 | type=str,
153 | default=None,
154 | help="path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)",
155 | )
156 |
157 | parser.add_argument(
158 | "--taming_config_path",
159 | type=str,
160 | default=None,
161 | help="path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)",
162 | )
163 | parser.add_argument(
164 | "--input_image",
165 | type=str,
166 | default=None,
167 | help="Path to an image to use as input for reconstruction instead of using one from the dataset.",
168 | )
169 | parser.add_argument(
170 | "--input_folder",
171 | type=str,
172 | default=None,
173 |     help="Path to a folder with images to use as input for creating a dataset for reconstructing all the images in it instead of just one image.",
174 | )
175 | parser.add_argument(
176 | "--exclude_folders",
177 | type=str,
178 | default=None,
179 | help="List of folders we want to exclude when doing reconstructions from an input folder.",
180 | )
181 | parser.add_argument(
182 | "--gpu",
183 | type=int,
184 | default=0,
185 | help="GPU to use in case we want to use a specific GPU for inference.",
186 | )
187 | parser.add_argument(
188 | "--max_retries",
189 | type=int,
190 | default=30,
191 | help="Max number of times to retry in case the reconstruction fails due to OOM or any other error.",
192 | )
193 | parser.add_argument(
194 | "--latest_checkpoint",
195 | action="store_true",
196 | help="Use the latest checkpoint using the vae_path folder instead of using just a specific vae_path.",
197 | )
198 | parser.add_argument(
199 | "--use_paintmind",
200 | action="store_true",
201 |     help="Use PaintMind VAE.",
202 | )
203 |
204 |
205 | @dataclass
206 | class Arguments:
207 | only_save_last_checkpoint: bool = False
208 | validation_image_scale: float = 1.0
209 | no_center_crop: bool = False
210 | no_flip: bool = False
211 | random_crop: bool = False
212 | random_image: bool = False
213 | dataset_save_path: Optional[str] = None
214 | clear_previous_experiments: bool = False
215 | max_grad_norm: Optional[float] = None
216 | discr_max_grad_norm: Optional[float] = None
217 | num_tokens: int = 256
218 | seq_len: int = 1024
219 | seed: int = 42
220 | valid_frac: float = 0.05
221 | use_ema: bool = False
222 | ema_beta: float = 0.995
223 | ema_update_after_step: int = 1
224 | ema_update_every: int = 1
225 | apply_grad_penalty_every: int = 4
226 | image_column: str = "image"
227 | caption_column: str = "caption"
228 | log_with: str = "wandb"
229 | mixed_precision: str = "no"
230 | use_8bit_adam: bool = False
231 | results_dir: str = "results"
232 | logging_dir: Optional[str] = None
233 | resume_path: Optional[str] = None
234 | dataset_name: Optional[str] = None
235 | streaming: bool = False
236 | train_data_dir: Optional[str] = None
237 | num_train_steps: int = -1
238 | num_epochs: int = 5
239 | dim: int = 128
240 | batch_size: int = 512
241 | lr: float = 1e-5
242 | gradient_accumulation_steps: int = 1
243 | save_results_every: int = 100
244 | save_model_every: int = 500
245 | vq_codebook_size: int = 256
246 | vq_codebook_dim: int = 256
247 | cond_drop_prob: float = 0.5
248 | image_size: int = 256
249 | lr_scheduler: str = "constant"
250 | scheduler_power: float = 1.0
251 | lr_warmup_steps: int = 0
252 | num_cycles: int = 1
253 | taming_model_path: Optional[str] = None
254 | taming_config_path: Optional[str] = None
255 | optimizer: str = "Lion"
256 | weight_decay: float = 0.0
257 | cache_path: Optional[str] = None
258 | no_cache: bool = False
259 | latest_checkpoint: bool = False
260 | do_not_save_config: bool = False
261 | use_l2_recon_loss: bool = False
262 | debug: bool = False
263 | config_path: Optional[str] = None
264 | generate_config: bool = False
265 |
266 |
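# Resolve a --seed argument: ints pass through, blank/None values become random seeds, comma-separated values become a list, and any other string is mapped deterministically to a 32-bit integer.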
267 | def seed_to_int(s):
268 | if type(s) is int:
269 | return s
270 | if s is None or s == "":
271 | return random.randint(0, 2**32 - 1)
272 |
273 | if "," in s:
274 | s = s.split(",")
275 |
276 | if type(s) is list:
277 | seed_list = []
278 | for seed in s:
279 | if seed is None or seed == "":
280 | seed_list.append(random.randint(0, 2**32 - 1))
281 | else:
282 | seed_list = s
283 |
284 | return seed_list
285 |
286 | n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1))
287 | while n >= 2**32:
288 | n = n >> 32
289 | return n
290 |
291 |
292 | def main():
293 | args = parser.parse_args(namespace=Arguments())
294 |
295 | project_config = ProjectConfiguration(
296 | project_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
297 | automatic_checkpoint_naming=True,
298 | )
299 |
300 | accelerator: accelerate.Accelerator = get_accelerator(
301 | log_with=args.log_with,
302 | gradient_accumulation_steps=args.gradient_accumulation_steps,
303 | mixed_precision=args.mixed_precision,
304 | project_config=project_config,
305 | even_batches=True,
306 | )
307 |
308 | # set pytorch seed for reproducibility
309 | torch.manual_seed(seed_to_int(args.seed))
310 |
311 | if args.train_data_dir and not args.input_image and not args.input_folder:
312 | dataset = get_dataset_from_dataroot(
313 | args.train_data_dir,
314 | image_column=args.image_column,
315 | save_path=args.dataset_save_path,
316 | )
317 | elif args.dataset_name and not args.input_image and not args.input_folder:
318 | dataset = load_dataset(args.dataset_name)["train"]
319 |
320 | elif args.input_image and not args.input_folder:
321 | # Create dataset from single input image
322 | dataset = Dataset.from_dict({"image": [args.input_image]}).cast_column("image", Image())
323 |
324 | if args.input_folder:
325 | # Create dataset from input folder
326 | extensions = ["jpg", "jpeg", "png", "webp"]
327 | exclude_folders = args.exclude_folders.split(",") if args.exclude_folders else []
328 |
329 | filepaths = []
330 | for root, dirs, files in os.walk(args.input_folder, followlinks=True):
331 | # Resolve symbolic link to actual path and exclude based on actual path
332 | resolved_root = os.path.realpath(root)
333 | for exclude_folder in exclude_folders:
334 | if exclude_folder in resolved_root:
335 | dirs[:] = []
336 | break
337 | for file in files:
338 | if file.lower().endswith(tuple(extensions)):
339 | filepaths.append(os.path.join(root, file))
340 |
341 | if not filepaths:
342 | print(f"No images with extensions {extensions} found in {args.input_folder}.")
343 | exit(1)
344 |
345 | dataset = Dataset.from_dict({"image": filepaths}).cast_column("image", Image())
346 |
347 | if args.vae_path and args.taming_model_path:
348 | raise Exception("You can't pass vae_path and taming args at the same time.")
349 |
350 | if args.vae_path and not args.use_paintmind:
351 | accelerator.print("Loading Muse VQGanVAE")
352 | vae = VQGanVAE(
353 | dim=args.dim,
354 | vq_codebook_size=args.vq_codebook_size,
355 | vq_codebook_dim=args.vq_codebook_dim,
356 | channels=args.channels,
357 | layers=args.layers,
358 | discr_layers=args.discr_layers,
359 | ).to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
360 |
361 | if args.latest_checkpoint:
362 | accelerator.print("Finding latest checkpoint...")
363 | orig_vae_path = args.vae_path
364 |
365 | if os.path.isfile(args.vae_path) or ".pt" in args.vae_path:
366 | # If args.vae_path is a file, split it into directory and filename
367 | args.vae_path, _ = os.path.split(args.vae_path)
368 |
369 | checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt"))
370 | if checkpoint_files:
371 | latest_checkpoint_file = max(
372 | checkpoint_files,
373 | key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1))
374 | if not x.endswith("ema.pt")
375 | else -1,
376 | )
377 |
378 | # Check if latest checkpoint is empty or unreadable
379 | if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
380 | latest_checkpoint_file, os.R_OK
381 | ):
382 | accelerator.print(
383 | f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable."
384 | )
385 | if len(checkpoint_files) > 1:
386 | # Use the second last checkpoint as a fallback
387 | latest_checkpoint_file = max(
388 | checkpoint_files[:-1],
389 | key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1))
390 | if not x.endswith("ema.pt")
391 | else -1,
392 | )
393 | accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
394 | else:
395 | accelerator.print("No usable checkpoint found.")
396 | elif latest_checkpoint_file != orig_vae_path:
397 | accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file)
398 | else:
399 | accelerator.print("Using checkpoint specified in vae_path: ", orig_vae_path)
400 |
401 | args.vae_path = latest_checkpoint_file
402 | else:
403 | accelerator.print("No checkpoints found in directory: ", args.vae_path)
404 | else:
405 | accelerator.print("Resuming VAE from: ", args.vae_path)
406 |
407 | vae.load(args.vae_path)
408 |
409 | if args.use_paintmind:
410 | # load VAE
411 | accelerator.print("Loading VQVAE from 'neggles/vaedump/vit-s-vqgan-f4' ...")
412 | vae: VQVAE = VQVAE.from_pretrained("neggles/vaedump", subfolder="vit-s-vqgan-f4")
413 |
414 | elif args.taming_model_path:
415 | print("Loading Taming VQGanVAE")
416 | vae = VQGanVAETaming(
417 | vqgan_model_path=args.taming_model_path,
418 | vqgan_config_path=args.taming_config_path,
419 | )
420 | args.num_tokens = vae.codebook_size
421 | args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2
422 |
423 | # move vae to device
424 | vae = vae.to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
425 |
426 | # Use the parameters() method to get an iterator over all the learnable parameters of the model
427 | total_params = sum(p.numel() for p in vae.parameters())
428 |
429 | print(f"Total number of parameters: {format(total_params, ',d')}")
430 |
431 | # then you plug the vae and transformer into your MaskGit as so
432 |
433 | dataset = ImageDataset(
434 | dataset,
435 | args.image_size,
436 | image_column=args.image_column,
437 | center_crop=True if not args.no_center_crop and not args.random_crop else False,
438 | flip=not args.no_flip,
439 | random_crop=args.random_crop if args.random_crop else False,
440 | alpha_channel=False if args.channels == 3 else True,
441 | )
442 |
443 | if args.input_image and not args.input_folder:
444 | image_id = 0 if not args.random_image else random.randint(0, len(dataset))
445 |
446 | os.makedirs(f"{args.results_dir}/outputs", exist_ok=True)
447 |
448 | save_image(
449 | dataset[image_id],
450 | f"{args.results_dir}/outputs/input.{str(args.input_image).split('.')[-1]}",
451 | format="PNG",
452 | )
453 |
454 | _, ids, _ = vae.encode(
455 | dataset[image_id][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
456 | )
457 | recon = vae.decode_from_ids(ids)
458 | save_image(recon, f"{args.results_dir}/outputs/output.{str(args.input_image).split('.')[-1]}")
459 |
460 | if not args.input_image and not args.input_folder:
461 | image_id = 0 if not args.random_image else random.randint(0, len(dataset))
462 |
463 | os.makedirs(f"{args.results_dir}/outputs", exist_ok=True)
464 |
465 | save_image(dataset[image_id], f"{args.results_dir}/outputs/input.png")
466 |
467 | _, ids, _ = vae.encode(
468 | dataset[image_id][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
469 | )
470 | recon = vae.decode_from_ids(ids)
471 | save_image(recon, f"{args.results_dir}/outputs/output.png")
472 |
473 | if args.input_folder:
474 | # Create output directory and save input images and reconstructions as grids
475 | output_dir = os.path.join(args.results_dir, "outputs", os.path.basename(args.input_folder))
476 | os.makedirs(output_dir, exist_ok=True)
477 |
478 | for i in tqdm(range(len(dataset))):
479 | retries = 0
480 | while True:
481 | try:
482 | save_image(dataset[i], f"{output_dir}/input.png")
483 |
484 | if not args.use_paintmind:
485 | # encode
486 | _, ids, _ = vae.encode(
487 | dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
488 | )
489 | # decode
490 | recon = vae.decode_from_ids(ids)
491 | # print (recon.shape) # torch.Size([1, 3, 512, 1136])
492 | save_image(recon, f"{output_dir}/output.png")
493 | else:
494 | # encode
495 | encoded, _, _ = vae.encode(
496 | dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}")
497 | )
498 |
499 | # decode
500 | recon = vae.decode(encoded).squeeze(0)
501 | recon = torch.clamp(recon, -1.0, 1.0)
502 | save_image(recon, f"{output_dir}/output.png")
503 |
504 | # Load input and output images
505 | input_image = PIL.Image.open(f"{output_dir}/input.png")
506 | output_image = PIL.Image.open(f"{output_dir}/output.png")
507 |
508 | # Create horizontal grid with input and output images
509 | grid_image = PIL.Image.new(
510 | "RGB" if args.channels == 3 else "RGBA",
511 | (input_image.width + output_image.width, input_image.height),
512 | )
513 | grid_image.paste(input_image, (0, 0))
514 | grid_image.paste(output_image, (input_image.width, 0))
515 |
516 | # Save grid
517 | now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S")
518 | hash = hashlib.sha1(input_image.tobytes()).hexdigest()
519 |
520 | filename = f"{hash}_{now}-{os.path.basename(args.vae_path)}.png"
521 | grid_image.save(f"{output_dir}/{filename}", format="PNG")
522 |
523 | # Remove input and output images after the grid was made.
524 | os.remove(f"{output_dir}/input.png")
525 | os.remove(f"{output_dir}/output.png")
526 |
527 | del _
528 | del ids
529 | del recon
530 |
531 | torch.cuda.empty_cache()
532 | torch.cuda.ipc_collect()
533 |
534 | break # Exit the retry loop if there were no errors
535 |
536 | except RuntimeError as e:
537 | if "out of memory" in str(e) and retries < args.max_retries:
538 | retries += 1
539 | # print(f"Out of Memory. Retry #{retries}")
540 | torch.cuda.empty_cache()
541 | torch.cuda.ipc_collect()
542 | continue # Retry the loop
543 |
544 | else:
545 | if "out of memory" not in str(e):
546 | print(e)
547 | else:
548 | print(f"Skipping image {i} after {retries} retries due to out of memory error")
549 | break # Exit the retry loop after too many retries
550 |
551 |
552 | if __name__ == "__main__":
553 | main()
554 |
--------------------------------------------------------------------------------
/muse.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sygil-Dev/muse-maskgit-pytorch/7d7c00c39e29af0585a1c616e4626528e13b61da/muse.png
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from .muse_maskgit_pytorch import MaskGit, MaskGitTransformer, Muse, TokenCritic, Transformer
2 | from .trainers import MaskGitTrainer, VQGanVAETrainer, get_accelerator
3 | from .vqgan_vae import VQGanVAE
4 | from .vqgan_vae_taming import VQGanVAETaming
5 |
6 | __all__ = [
7 | "VQGanVAE",
8 | "VQGanVAETaming",
9 | "Transformer",
10 | "MaskGit",
11 | "Muse",
12 | "MaskGitTransformer",
13 | "TokenCritic",
14 | "VQGanVAETrainer",
15 | "MaskGitTrainer",
16 | "get_accelerator",
17 | ]
18 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/attn/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Sygil-Dev/muse-maskgit-pytorch/7d7c00c39e29af0585a1c616e4626528e13b61da/muse_maskgit_pytorch/attn/__init__.py
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/attn/attn_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import BoolTensor, FloatTensor, allclose, arange, manual_seed, no_grad, randn
3 | from torch.nn.functional import pad
4 |
5 | from muse_maskgit_pytorch.attn.ein_attn import Attention as EinAttn
6 | from muse_maskgit_pytorch.attn.xformers_attn import Attention as XformersAttn
7 |
8 | device = torch.device("cuda")
9 | dtype = torch.float32
10 | seed = 42
11 |
12 | # realistically this would be 320 in stable-diffusion, but I'm going smaller during testing
13 | vision_dim = 64
14 |
15 | attn_init_params = {
16 | "dim": vision_dim,
17 | "dim_head": 64,
18 | # realistically this would be at least 5
19 | "heads": 2,
20 | "cross_attend": True,
21 | "scale": 8,
22 | }
23 |
24 | with no_grad():
25 | # seed RNG before we initialize any layers, so that both will end up with same params
26 | manual_seed(seed)
27 | ein_attn = EinAttn(**attn_init_params).to(device, dtype).eval()
28 | # commented-out scaled dot product attention because it didn't support flash attn, so we'll try with xformers instead.
29 | # manual_seed(seed)
30 | # sdp_attn = SDPAttn(**attn_init_params).to(device, dtype).eval()
31 | manual_seed(seed)
32 | xfo_attn = XformersAttn(**attn_init_params).to(device, dtype).eval()
33 |
34 | batch_size = 2
35 |
36 | # realistically this would be 64**2 in stable-diffusion
37 | vision_tokens = 32**2 # 1024
38 |
39 | # generate rand on-CPU for cross-platform determinism of results
40 | x: FloatTensor = randn(batch_size, vision_tokens, vision_dim, dtype=dtype).to(device)
41 |
42 | # I've said text here simply as an example of something you could cross-attend to
43 | text_tokens = 16 # CLIP would be 77
44 | # for a *general* cross-attention Module:
45 | # kv_in_dim could differ from q_in_dim, but this attention Module requires x and context to have same dim.
46 | text_dim = vision_dim
47 | context: FloatTensor = randn(batch_size, text_tokens, text_dim, dtype=dtype).to(device)
48 |
49 | # attend to just the first two tokens in each text condition (e.g. if both were uncond, so [BOS, EOS] followed by PAD tokens)
50 | context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1).contiguous()
51 |
52 | # for xformers cutlassF kernel: masks are only supported for keys whose lengths are multiples of 8:
53 | # https://gist.github.com/Birch-san/0c36d228e1d4b881a06d1c6e5289d569
54 | # so, we add whatever we feel like to the end of the key to extend it to a multiple of 8,
55 | # and add "discard" tokens to the mask to get rid of the excess
56 | # note: muse will add an extra "null" token to our context, so we'll account for that in advance
57 | mask_length = context_mask.shape[-1] + 1
58 | extra_tokens_needed = 8 - (mask_length % 8)
59 | # 0-pad mask to multiple of 8 tokens
60 | xfo_context_mask = pad(context_mask, (0, extra_tokens_needed))
61 | # replicate-pad embedding to multiple of 8 tokens (mask will hide the extra tokens)
62 | xfo_context = pad(context, (0, 0, 0, extra_tokens_needed), "replicate")
63 |
64 | ein_result: FloatTensor = ein_attn.forward(x, context, context_mask)
65 | # sdp attn works, but only supports flash attn when context_mask is None.
66 | # with sdp_kernel(enable_math=False):
67 | # sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask)
68 | xfo_result: FloatTensor = xfo_attn.forward(x, xfo_context, xfo_context_mask)
69 |
70 | # default rtol
71 | rtol = 1e-5
72 | # atol would normally be 1e-8
73 | atol = 5e-7
74 | # assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
75 | if not allclose(ein_result, xfo_result, rtol=rtol, atol=atol):
76 | raise RuntimeError(
77 | f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}"
78 | )
79 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/attn/ein_attn.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from einops import rearrange, repeat
4 | from torch import einsum, nn
5 |
6 |
7 | # helpers
8 | def exists(val):
9 | return val is not None
10 |
11 |
12 | def l2norm(t):
13 | return F.normalize(t, dim=-1)
14 |
15 |
16 | class LayerNorm(nn.Module):
17 | def __init__(self, dim):
18 | super().__init__()
19 | self.gamma = nn.Parameter(torch.ones(dim))
20 | self.register_buffer("beta", torch.zeros(dim))
21 |
22 | def forward(self, x):
23 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
24 |
25 |
26 | class Attention(nn.Module):
27 | def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8):
28 | super().__init__()
29 | self.scale = scale
30 | self.heads = heads
31 | inner_dim = dim_head * heads
32 |
33 | self.cross_attend = cross_attend
34 | self.norm = LayerNorm(dim)
35 |
36 | self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head))
37 |
38 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
39 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
40 |
41 | self.q_scale = nn.Parameter(torch.ones(dim_head))
42 | self.k_scale = nn.Parameter(torch.ones(dim_head))
43 |
44 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
45 |
46 | def forward(self, x, context=None, context_mask=None):
47 | assert not (exists(context) ^ self.cross_attend)
48 |
49 | h = self.heads
50 | x = self.norm(x)
51 |
52 | kv_input = context if self.cross_attend else x
53 |
54 | q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1))
55 |
56 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
57 |
58 | nk, nv = self.null_kv
59 | nk, nv = map(lambda t: repeat(t, "h 1 d -> b h 1 d", b=x.shape[0]), (nk, nv))
60 |
61 | k = torch.cat((nk, k), dim=-2)
62 | v = torch.cat((nv, v), dim=-2)
63 |
64 | q, k = map(l2norm, (q, k))
65 | q = q * self.q_scale
66 | k = k * self.k_scale
67 |
68 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale
69 |
70 | if exists(context_mask):
71 | context_mask = rearrange(context_mask, "b j -> b 1 1 j")
72 | context_mask = F.pad(context_mask, (1, 0), value=True)
73 |
74 | mask_value = -torch.finfo(sim.dtype).max
75 | sim = sim.masked_fill(~context_mask, mask_value)
76 |
77 | attn = sim.softmax(dim=-1)
78 | out = einsum("b h i j, b h j d -> b h i d", attn, v)
79 |
80 | out = rearrange(out, "b h n d -> b n (h d)")
81 | return self.to_out(out)
82 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/attn/sdp_attn.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from einops import rearrange, repeat
6 | from torch import BoolTensor, FloatTensor, nn
7 | from torch.nn.functional import scaled_dot_product_attention
8 |
9 |
10 | def l2norm(t):
11 | return F.normalize(t, dim=-1)
12 |
13 |
14 | class LayerNorm(nn.Module):
15 | def __init__(self, dim):
16 | super().__init__()
17 | self.gamma = nn.Parameter(torch.ones(dim))
18 | self.register_buffer("beta", torch.zeros(dim))
19 |
20 | def forward(self, x):
21 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
22 |
23 |
24 | class Attention(nn.Module):
25 | def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8):
26 | super().__init__()
27 | self.heads = heads
28 | inner_dim = dim_head * heads
29 |
30 | self.cross_attend = cross_attend
31 | self.norm = LayerNorm(dim)
32 |
33 | self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head))
34 |
35 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
37 |
38 | typical_scale = dim_head**-0.5
39 | scale_ratio = scale / typical_scale
40 | self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio))
41 | self.k_scale = nn.Parameter(torch.ones(dim_head))
42 |
43 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
44 |
45 | def forward(
46 | self, x: FloatTensor, context: Optional[FloatTensor] = None, context_mask: Optional[BoolTensor] = None
47 | ):
48 | assert (context is None) != self.cross_attend
49 |
50 | h = self.heads
51 | # TODO: you could fuse this layernorm with the linear that follows it, e.g. via TransformerEngine
52 | x = self.norm(x)
53 |
54 | kv_input = context if self.cross_attend else x
55 |
56 | # TODO: to_q and to_kvs could be combined into one to_qkv
57 | q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1))
58 |
59 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
60 |
61 | nk, nv = self.null_kv
62 | nk, nv = map(lambda t: repeat(t, "h 1 d -> b h 1 d", b=x.shape[0]), (nk, nv))
63 |
64 | k = torch.cat((nk, k), dim=-2)
65 | v = torch.cat((nv, v), dim=-2)
66 |
67 | q, k = map(l2norm, (q, k))
68 | q = q * self.q_scale
69 | k = k * self.k_scale
70 |
71 | if context_mask is not None:
72 | context_mask = rearrange(context_mask, "b j -> b 1 1 j")
73 | context_mask = F.pad(context_mask, (1, 0), value=True)
74 |
75 | out: FloatTensor = scaled_dot_product_attention(q, k, v, context_mask)
76 |
77 | out = rearrange(out, "b h n d -> b n (h d)")
78 | return self.to_out(out)
79 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/attn/xformers_attn.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from einops import rearrange, repeat
6 | from torch import BoolTensor, FloatTensor, nn
7 | from xformers.ops import memory_efficient_attention
8 |
9 |
10 | def l2norm(t):
11 | return F.normalize(t, dim=-1)
12 |
13 |
14 | class LayerNorm(nn.Module):
15 | def __init__(self, dim):
16 | super().__init__()
17 | self.gamma = nn.Parameter(torch.ones(dim))
18 | self.register_buffer("beta", torch.zeros(dim))
19 |
20 | def forward(self, x):
21 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
22 |
23 |
24 | class Attention(nn.Module):
25 | def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8):
26 | super().__init__()
27 | self.heads = heads
28 | inner_dim = dim_head * heads
29 |
30 | self.cross_attend = cross_attend
31 | self.norm = LayerNorm(dim)
32 |
33 | self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head))
34 |
35 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
37 |
38 | typical_scale = dim_head**-0.5
39 | scale_ratio = scale / typical_scale
40 | self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio))
41 | self.k_scale = nn.Parameter(torch.ones(dim_head))
42 |
43 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
44 |
45 | def forward(
46 | self, x: FloatTensor, context: Optional[FloatTensor] = None, context_mask: Optional[BoolTensor] = None
47 | ):
48 | assert (context is None) != self.cross_attend
49 |
50 | h = self.heads
51 | # TODO: you could fuse this layernorm with the linear that follows it, e.g. via TransformerEngine
52 | x = self.norm(x)
53 |
54 | kv_input = context if self.cross_attend else x
55 |
56 | # TODO: to_q and to_kvs could be combined into one to_qkv
57 | q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1))
58 |
59 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q, k, v))
60 |
61 | nk, nv = self.null_kv
62 | nk, nv = map(lambda t: repeat(t, "h 1 d -> b 1 h d", b=x.shape[0]), (nk, nv))
63 |
64 | k = torch.cat((nk, k), dim=-3)
65 | v = torch.cat((nv, v), dim=-3)
66 |
67 | q, k = map(l2norm, (q, k))
68 | q = q * self.q_scale
69 | k = k * self.k_scale
70 |
71 | if context_mask is None:
72 | attn_bias = None
73 | else:
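            # prepend True to the mask so the learned null key/value token is always attended to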
74 | context_mask = F.pad(context_mask, (1, 0), value=True)
75 | context_mask = rearrange(context_mask, "b j -> b 1 1 j")
76 |             attn_bias = torch.where(context_mask, 0.0, -10000.0)
77 | attn_bias = attn_bias.expand(-1, h, q.size(1), -1)
78 |
79 | out: FloatTensor = memory_efficient_attention(q, k, v, attn_bias)
80 |
81 | out = rearrange(out, "b n h d -> b n (h d)")
82 | return self.to_out(out)
83 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import shutil
4 | import sys
5 | import time
6 | from pathlib import Path
7 | from threading import Thread
8 |
9 | import datasets
10 | import torch
11 | from datasets import Image, load_from_disk
12 | from PIL import (
13 | Image as pImage,
14 | ImageFile,
15 | )
16 | from torch.utils.data import DataLoader, Dataset, random_split
17 | from torchvision import transforms as T
18 |
19 | try:
20 | import torch_xla
21 | import torch_xla.core.xla_model as xm
22 | from tqdm_loggable.auto import tqdm
23 | except ImportError:
24 | from tqdm import tqdm
25 |
26 | from io import BytesIO
27 |
28 | import requests
29 | from transformers import T5Tokenizer
30 |
31 | from muse_maskgit_pytorch.t5 import MAX_LENGTH
32 |
33 | ImageFile.LOAD_TRUNCATED_IMAGES = True
34 | pImage.MAX_IMAGE_PIXELS = None
35 |
36 |
37 | class ImageDataset(Dataset):
38 | def __init__(
39 | self,
40 | dataset,
41 | image_size,
42 | image_column="image",
43 | flip=True,
44 | center_crop=True,
45 | stream=False,
46 | using_taming=False,
47 | random_crop=False,
48 | alpha_channel=True,
49 | ):
50 | super().__init__()
51 | self.dataset = dataset
52 | self.image_column = image_column
53 | self.stream = stream
54 | transform_list = [
55 | T.Resize(image_size),
56 | ]
57 |
58 | if flip:
59 | transform_list.append(T.RandomHorizontalFlip())
60 | if center_crop and not random_crop:
61 | transform_list.append(T.CenterCrop(image_size))
62 | if random_crop:
63 | transform_list.append(T.RandomCrop(image_size, pad_if_needed=True))
64 | if alpha_channel:
65 | transform_list.append(T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" else img))
66 | else:
67 | transform_list.append(T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img))
68 |
69 | transform_list.append(T.ToTensor())
70 | self.transform = T.Compose(transform_list)
71 | self.using_taming = using_taming
72 |
73 | def __len__(self):
74 | if not self.stream:
75 | return len(self.dataset)
76 | else:
77 | raise AssertionError("Streaming doesn't support querying the dataset length")
78 |
79 | def __getitem__(self, index):
80 | image = self.dataset[index][self.image_column]
81 | if self.using_taming:
82 | return self.transform(image) - 0.5
83 | else:
84 | return self.transform(image)
85 |
86 |
87 | class ImageTextDataset(ImageDataset):
88 | def __init__(
89 | self,
90 | dataset,
91 | image_size: int,
92 | tokenizer: T5Tokenizer,
93 | image_column="image",
94 | caption_column="caption",
95 | flip=True,
96 | center_crop=True,
97 | stream=False,
98 | using_taming=False,
99 | random_crop=False,
100 | ):
101 | super().__init__(
102 | dataset,
103 | image_size=image_size,
104 | image_column=image_column,
105 | flip=flip,
106 | stream=stream,
107 | center_crop=center_crop,
108 | using_taming=using_taming,
109 | random_crop=random_crop,
110 | )
111 | self.caption_column: str = caption_column
112 | self.tokenizer: T5Tokenizer = tokenizer
113 |
114 | def __getitem__(self, index):
115 | image = self.dataset[index][self.image_column]
116 | descriptions = self.dataset[index][self.caption_column]
117 | if self.caption_column is None or descriptions is None:
118 | text = ""
119 | elif isinstance(descriptions, list):
120 | if len(descriptions) == 0:
121 | text = ""
122 | else:
123 | text = random.choice(descriptions)
124 | else:
125 | text = descriptions
126 | # max length from the paper
127 | encoded = self.tokenizer.batch_encode_plus(
128 | [str(text)],
129 | return_tensors="pt",
130 | padding="max_length",
131 | max_length=MAX_LENGTH,
132 | truncation=True,
133 | )
134 |
135 | input_ids = encoded.input_ids
136 | attn_mask = encoded.attention_mask
137 |
138 | if self.using_taming:
139 | return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
140 | else:
141 | return self.transform(image), input_ids[0], attn_mask[0]
142 |
143 |
144 | class URLTextDataset(ImageDataset):
145 | def __init__(
146 | self,
147 | dataset,
148 | image_size: int,
149 | tokenizer: T5Tokenizer,
150 | image_column="image",
151 | caption_column="caption",
152 | flip=True,
153 | center_crop=True,
154 | using_taming=True,
155 | ):
156 | super().__init__(
157 | dataset,
158 | image_size=image_size,
159 | image_column=image_column,
160 | flip=flip,
161 | center_crop=center_crop,
162 | using_taming=using_taming,
163 | )
164 | self.caption_column: str = caption_column
165 | self.tokenizer: T5Tokenizer = tokenizer
166 |
167 | def __getitem__(self, index):
168 | try:
169 | image = pImage.open(BytesIO(requests.get(self.dataset[index][self.image_column]).content))
170 | except ConnectionError:
171 | try:
172 | print("Image request failure, attempting next image")
173 | index += 1
174 |
175 | image = pImage.open(BytesIO(requests.get(self.dataset[index][self.image_column]).content))
176 | except ConnectionError:
177 | raise ConnectionError("Unable to request image from the Dataset")
178 |
179 | descriptions = self.dataset[index][self.caption_column]
180 | if self.caption_column is None or descriptions is None:
181 | text = ""
182 | elif isinstance(descriptions, list):
183 | if len(descriptions) == 0:
184 | text = ""
185 | else:
186 | text = random.choice(descriptions)
187 | else:
188 | text = descriptions
189 | # max length from the paper
190 | encoded = self.tokenizer.batch_encode_plus(
191 | [str(text)],
192 | return_tensors="pt",
193 | padding="max_length",
194 | max_length=MAX_LENGTH,
195 | truncation=True,
196 | )
197 |
198 | input_ids = encoded.input_ids
199 | attn_mask = encoded.attention_mask
200 | if self.using_taming:
201 | return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
202 | else:
203 | return self.transform(image), input_ids[0], attn_mask[0]
204 |
205 |
206 | class LocalTextImageDataset(Dataset):
207 | def __init__(
208 | self,
209 | path,
210 | image_size,
211 | tokenizer,
212 | flip=True,
213 | center_crop=True,
214 | using_taming=False,
215 | random_crop=False,
216 | alpha_channel=False,
217 | ):
218 | super().__init__()
219 | self.tokenizer = tokenizer
220 | self.using_taming = using_taming
221 |
222 | print("Building dataset...")
223 |
224 | extensions = ["jpg", "jpeg", "png", "webp"]
225 | self.image_paths = []
226 | self.caption_pair = []
227 | self.images = []
228 |
229 | for ext in extensions:
230 | self.image_paths.extend(list(Path(path).rglob(f"*.{ext}")))
231 |
232 | random.shuffle(self.image_paths)
233 | for image_path in tqdm(self.image_paths):
234 | # check the file size and skip zero-byte images.
235 | if os.path.getsize(image_path) == 0:
236 | continue
237 | caption_path = image_path.with_suffix(".txt")
238 | if os.path.exists(str(caption_path)):
239 | captions = str(caption_path)
240 | else:
241 | captions = ""
242 | self.images.append(image_path)
243 | self.caption_pair.append(captions)
244 |
245 | transform_list = [
246 | T.Resize(image_size),
247 | ]
248 | if flip:
249 | transform_list.append(T.RandomHorizontalFlip())
250 | if center_crop and not random_crop:
251 | transform_list.append(T.CenterCrop(image_size))
252 | if random_crop:
253 | transform_list.append(T.RandomCrop(image_size, pad_if_needed=True))
254 | if alpha_channel:
255 | transform_list.append(T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" else img))
256 | else:
257 | transform_list.append(T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img))
258 | transform_list.append(T.ToTensor())
259 | self.transform = T.Compose(transform_list)
260 |
261 | def __len__(self):
262 | return len(self.caption_pair)
263 |
264 | def __getitem__(self, index):
265 | image = self.images[index]
266 | image = pImage.open(image)
267 | descriptions = self.caption_pair[index]
268 | if descriptions is None or descriptions == "":
269 | text = ""
270 | else:
271 | text = random.choice(Path(descriptions).read_text(encoding="utf-8").splitlines() or [""])
272 |
273 | # max length from the paper
274 | encoded = self.tokenizer.batch_encode_plus(
275 | [str(text)],
276 | return_tensors="pt",
277 | padding="max_length",
278 | max_length=MAX_LENGTH,
279 | truncation=True,
280 | )
281 |
282 | input_ids = encoded.input_ids
283 | attn_mask = encoded.attention_mask
284 | if self.using_taming:
285 | return self.transform(image) - 0.5, input_ids[0], attn_mask[0]
286 | else:
287 | return self.transform(image), input_ids[0], attn_mask[0]
288 |
289 |
290 | def get_directory_size(path):
291 | total_size = 0
292 | for dirpath, dirnames, filenames in os.walk(path):
293 | for f in filenames:
294 | fp = os.path.join(dirpath, f)
295 | total_size += os.path.getsize(fp)
296 | return total_size
297 |
298 |
299 | def save_dataset_with_progress(dataset, save_path):
300 | # Estimate the total size of the dataset in bytes
301 | total_size = sys.getsizeof(dataset)
302 |
303 | # Start saving the dataset in a separate thread
304 | save_thread = Thread(target=dataset.save_to_disk, args=(save_path,))
305 | save_thread.start()
306 |
307 | # Create a tqdm progress bar and update it periodically
308 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
309 | while save_thread.is_alive():
310 | if os.path.exists(save_path):
311 | size = get_directory_size(save_path)
312 | # Update the progress bar based on the current size of the saved file
313 | pbar.update(size - pbar.n) # Update by the difference between current and previous size
314 | time.sleep(1)
315 |
316 |
317 | def get_dataset_from_dataroot(
318 | data_root,
319 | image_column="image",
320 | caption_column="caption",
321 | save_path="dataset",
322 | save=True,
323 | ):
324 | # Check if data_root is a symlink and resolve it to its target location if it is
325 | if os.path.islink(data_root):
326 | data_root = os.path.realpath(data_root)
327 |
328 | if os.path.exists(save_path):
329 | # Get the modified time of save_path
330 | save_path_mtime = os.stat(save_path).st_mtime
331 |
332 | if save:
333 | # Traverse the directory tree of data_root and get the modified time of all files and subdirectories
334 | print("Checking modified date of all the files and subdirectories in the dataset folder.")
335 | data_root_mtime = max(
336 | os.stat(os.path.join(root, f)).st_mtime
337 | for root, dirs, files in os.walk(data_root)
338 | for f in files + dirs
339 | )
340 |
341 | # Check if data_root is newer than save_path
342 | if data_root_mtime > save_path_mtime:
343 | print(
344 | "The data_root folder has being updated recently. Removing previously saved dataset and updating it."
345 | )
346 | shutil.rmtree(save_path, ignore_errors=True)
347 | else:
348 | print("The dataset is up-to-date. Loading...")
349 | # Load the dataset from save_path if it is up-to-date
350 | return load_from_disk(save_path)
351 |
352 | extensions = ["jpg", "jpeg", "png", "webp"]
353 | image_paths = []
354 |
355 | for ext in extensions:
356 | image_paths.extend(list(Path(data_root).rglob(f"*.{ext}")))
357 |
358 | random.shuffle(image_paths)
359 | data_dict = {image_column: [], caption_column: []}
360 | for image_path in tqdm(image_paths):
361 | # check the file size and skip zero-byte images.
362 | if os.path.getsize(image_path) == 0:
363 | continue
364 | caption_path = image_path.with_suffix(".txt")
365 | if os.path.exists(str(caption_path)):
366 | captions = caption_path.read_text(encoding="utf-8").split("\n")
367 | captions = list(filter(lambda t: len(t) > 0, captions))
368 | else:
369 | captions = []
370 | image_path = str(image_path)
371 | data_dict[image_column].append(image_path)
372 | data_dict[caption_column].append(captions)
373 | dataset = datasets.Dataset.from_dict(data_dict)
374 | dataset = dataset.cast_column(image_column, Image())
375 |
376 | if save:
377 | save_dataset_with_progress(dataset, save_path)
378 |
379 | return dataset
380 |
381 |
382 | def split_dataset_into_dataloaders(dataset, valid_frac=0.05, seed=42, batch_size=1):
383 | print(f"Dataset length: {len(dataset)} samples")
384 | if valid_frac > 0:
385 | train_size = int((1 - valid_frac) * len(dataset))
386 | valid_size = len(dataset) - train_size
387 | print(f"Splitting dataset into {train_size} training samples and {valid_size} validation samples")
388 | split_generator = torch.Generator().manual_seed(seed)
389 | train_dataset, validation_dataset = random_split(
390 | dataset,
391 | [train_size, valid_size],
392 | generator=split_generator,
393 | )
394 | else:
395 | print("Using shared dataset for training and validation")
396 | train_dataset = dataset
397 | validation_dataset = dataset
398 |
399 | dataloader = DataLoader(
400 | train_dataset,
401 | batch_size=batch_size,
402 | shuffle=True,
403 | )
404 |
405 | validation_dataloader = DataLoader(
406 | validation_dataset,
407 | batch_size=batch_size,
408 | shuffle=False,
409 | )
410 | return dataloader, validation_dataloader
411 |
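A hedged end-to-end sketch of how the pieces in this file fit together; the folder layout (images with optional same-name `.txt` captions), the image size, and the tokenizer name are illustrative assumptions:

    from transformers import T5Tokenizer

    from muse_maskgit_pytorch.dataset import (
        ImageTextDataset,
        get_dataset_from_dataroot,
        split_dataset_into_dataloaders,
    )

    raw = get_dataset_from_dataroot("data/my_images", save_path="dataset")
    tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-base")
    ds = ImageTextDataset(raw, image_size=256, tokenizer=tokenizer)
    train_dl, valid_dl = split_dataset_into_dataloaders(ds, valid_frac=0.05, batch_size=4)

    image, input_ids, attn_mask = ds[0]  # image tensor, token ids, attention mask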
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/distributed_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions for optional distributed execution.
3 |
4 | To use,
5 | 1. set the `BACKENDS` to the ones you want to make available,
6 | 2. in the script, wrap the argument parser with `wrap_arg_parser`,
7 | 3. in the script, set and use the backend by calling
8 | `set_backend_from_args`.
9 |
10 | You can check whether a backend is in use with the `using_backend`
11 | function.
12 | """
13 |
14 |
15 | is_distributed = None
16 | """Whether we are distributed."""
17 | backend = None
18 | """Backend in usage."""
19 |
20 |
21 | def require_set_backend():
22 | """Raise an `AssertionError` when the backend has not been set."""
23 | assert backend is not None, (
24 | "distributed backend is not set. Please call "
25 | "`distributed_utils.set_backend_from_args` at the start of your script"
26 | )
27 |
28 |
29 | def using_backend(test_backend):
30 | """Return whether the backend is set to `test_backend`.
31 |
32 | `test_backend` may be a string of the name of the backend or
33 | its class.
34 | """
35 | require_set_backend()
36 | if isinstance(test_backend, str):
37 | return backend.BACKEND_NAME == test_backend
38 | return isinstance(backend, test_backend)
39 |
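A small sketch of the intended call pattern; the backend name below is only an example, and the module-level `backend` must first be populated by the `set_backend_from_args` helper referenced in the docstring (not shown in this excerpt):

    from muse_maskgit_pytorch import distributed_utils

    # guard against the unset case so `using_backend` does not assert
    if distributed_utils.backend is not None and distributed_utils.using_backend("deepspeed"):
        print("running with the deepspeed backend")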
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .attention import CrossAttention, MemoryEfficientCrossAttention
2 | from .mlp import SwiGLU, SwiGLUFFN, SwiGLUFFNFused
3 |
4 | __all__ = [
5 | "SwiGLU",
6 | "SwiGLUFFN",
7 | "SwiGLUFFNFused",
8 | "CrossAttention",
9 | "MemoryEfficientCrossAttention",
10 | ]
11 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | from typing import Any, Callable, Optional
3 |
4 | from einops import rearrange
5 | from torch import nn
6 |
7 | try:
8 | from xformers.ops import memory_efficient_attention
9 | except ImportError:
10 | memory_efficient_attention = None
11 |
12 |
13 | def exists(x):
14 | return x is not None
15 |
16 |
17 | def default(val, d):
18 | if exists(val):
19 | return val
20 | return d() if isfunction(d) else d
21 |
22 |
23 | class CrossAttention(nn.Module):
24 | def __init__(
25 | self,
26 | query_dim,
27 | context_dim=None,
28 | heads=8,
29 | dim_head=64,
30 | dropout=0.0,
31 | ):
32 | super().__init__()
33 | inner_dim = dim_head * heads
34 | context_dim = default(context_dim, query_dim)
35 |
36 | self.scale = dim_head**-0.5
37 | self.heads = heads
38 |
39 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
40 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
41 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
42 |
43 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
44 |
45 | def forward(self, x, context=None):
46 | h = self.heads
47 |
48 | q = self.to_q(x)
49 | context = default(context, x)
50 | k = self.to_k(context)
51 | v = self.to_v(context)
52 |
53 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
54 | q = q * self.scale
55 |
56 | sim = q @ k.transpose(-2, -1)
57 | sim = sim.softmax(dim=-1)
58 |
59 | out = sim @ v
60 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
61 | return self.to_out(out)
62 |
63 |
64 | class MemoryEfficientCrossAttention(nn.Module):
65 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
66 | def __init__(
67 | self,
68 | query_dim,
69 | context_dim=None,
70 | heads=8,
71 | dim_head=64,
72 | dropout=0.0,
73 | ):
74 | super().__init__()
75 | inner_dim = dim_head * heads
76 | context_dim = default(context_dim, query_dim)
77 |
78 | self.heads = heads
79 | self.dim_head = dim_head
80 |
81 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
82 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
83 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
84 |
85 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
86 | self.attention_op: Optional[Callable] = None
87 |
88 | def forward(self, x, context=None):
89 | q = self.to_q(x)
90 | context = default(context, x)
91 | k = self.to_k(context)
92 | v = self.to_v(context)
93 |
94 | b, _, _ = q.shape
95 | q, k, v = map(
96 | lambda t: t.unsqueeze(3)
97 | .reshape(b, t.shape[1], self.heads, self.dim_head)
98 | .permute(0, 2, 1, 3)
99 | .reshape(b * self.heads, t.shape[1], self.dim_head)
100 | .contiguous(),
101 | (q, k, v),
102 | )
103 |
104 | out = memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op)
105 |
106 | out = (
107 | out.unsqueeze(0)
108 | .reshape(b, self.heads, out.shape[1], self.dim_head)
109 | .permute(0, 2, 1, 3)
110 | .reshape(b, out.shape[1], self.heads * self.dim_head)
111 | )
112 | return self.to_out(out)
113 |
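A short usage sketch for the two cross-attention variants above; the dimensions and tensors are illustrative, and `MemoryEfficientCrossAttention` additionally requires xformers to be installed:

    import torch

    from muse_maskgit_pytorch.modules import CrossAttention, MemoryEfficientCrossAttention

    x = torch.randn(2, 16, 320)    # (batch, query tokens, query_dim)
    ctx = torch.randn(2, 77, 768)  # (batch, context tokens, context_dim)

    attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
    out = attn(x, context=ctx)     # -> (2, 16, 320); omitting context gives self-attention

    # drop-in replacement backed by xformers' memory_efficient_attention:
    # attn = MemoryEfficientCrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)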
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/modules/mlp.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from typing import Optional
8 |
9 | import torch.nn.functional as F
10 | from torch import Tensor, nn
11 |
12 |
13 | class SwiGLUFFN(nn.Module):
14 | def __init__(
15 | self,
16 | in_features: int,
17 | hidden_features: Optional[int] = None,
18 | out_features: Optional[int] = None,
19 | bias: bool = True,
20 | ) -> None:
21 | super().__init__()
22 | out_features = out_features or in_features
23 | hidden_features = hidden_features or in_features
24 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
25 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
26 |
27 | def forward(self, x: Tensor) -> Tensor:
28 | x12 = self.w12(x)
29 | x1, x2 = x12.chunk(2, dim=-1)
30 | hidden = F.silu(x1) * x2
31 | return self.w3(hidden)
32 |
33 |
34 | try:
35 | from xformers.ops import SwiGLU
36 | except ImportError:
37 | SwiGLU = SwiGLUFFN
38 |
39 |
40 | class SwiGLUFFNFused(SwiGLU):
41 | def __init__(
42 | self,
43 | in_features: int,
44 | hidden_features: Optional[int] = None,
45 | out_features: Optional[int] = None,
46 | bias: bool = True,
47 | ) -> None:
48 | out_features = out_features or in_features
49 | hidden_features = hidden_features or in_features
50 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
51 | super().__init__(
52 | in_features=in_features,
53 | hidden_features=hidden_features,
54 | out_features=out_features,
55 | bias=bias,
56 | )
57 |
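A brief sketch of the feed-forward blocks defined above; the sizes are arbitrary examples:

    import torch

    from muse_maskgit_pytorch.modules import SwiGLUFFN, SwiGLUFFNFused

    x = torch.randn(2, 16, 512)
    ffn = SwiGLUFFN(in_features=512, hidden_features=2048)
    y = ffn(x)  # -> (2, 16, 512)

    # SwiGLUFFNFused shrinks hidden_features to roughly 2/3 (rounded to a multiple of 8)
    # and uses xformers' fused SwiGLU when available, falling back to SwiGLUFFN otherwise.
    fused = SwiGLUFFNFused(in_features=512, hidden_features=2048)
    y = fused(x)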
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/t5.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from dataclasses import dataclass, field
3 | from functools import cached_property
4 | from os import PathLike
5 | from typing import Dict, List, Optional, Tuple, Union
6 |
7 | import torch
8 | from beartype import beartype
9 | from torch import Tensor
10 | from transformers import T5Config, T5EncoderModel, T5Tokenizer
11 |
12 | # disable t5 warnings and a few others to keep the console clean and nice.
13 | warnings.filterwarnings("ignore")
14 |
15 |
16 | # dataclass for T5 model info
17 | @dataclass
18 | class T5ModelInfo:
19 | name: str
20 | cache_dir: Optional[PathLike] = None
21 | dtype: Optional[torch.dtype] = torch.float32
22 | config: T5Config = field(init=False)
23 |
24 | def __post_init__(self):
25 | self.config = T5Config.from_pretrained(self.name, cache_dir=self.cache_dir)
26 | self._model = None
27 | self._tokenizer = None
28 |
29 | # Using cached_property to avoid loading the model/tokenizer until needed
30 | @cached_property
31 | def model(self) -> T5EncoderModel:
32 | if not self._model:
33 | self._model = T5EncoderModel.from_pretrained(
34 | self.name, cache_dir=self.cache_dir, torch_dtype=self.dtype
35 | )
36 | return self._model
37 |
38 | @cached_property
39 | def tokenizer(self) -> T5Tokenizer:
40 | if not self._tokenizer:
41 | self._tokenizer = T5Tokenizer.from_pretrained(
42 | self.name, cache_dir=self.cache_dir, torch_dtype=self.dtype
43 | )
44 | return self._tokenizer
45 |
46 |
47 | # config
48 | MAX_LENGTH = 512
49 | DEFAULT_T5_NAME = "google/t5-v1_1-base"
50 | T5_OBJECTS: Dict[str, T5ModelInfo] = {}
51 |
52 |
53 | def get_model_and_tokenizer(
54 | name: str, cache_path: Optional[PathLike] = None, dtype: torch.dtype = torch.float32
55 | ) -> Tuple[T5EncoderModel, T5Tokenizer]:
56 | global T5_OBJECTS
57 | if name not in T5_OBJECTS.keys():
58 | T5_OBJECTS[name] = T5ModelInfo(name=name, cache_dir=cache_path, dtype=dtype)
59 | return T5_OBJECTS[name].model, T5_OBJECTS[name].tokenizer
60 |
61 |
62 | def get_encoded_dim(
63 | name: str, cache_path: Optional[PathLike] = None, dtype: torch.dtype = torch.float32
64 | ) -> int:
65 | global T5_OBJECTS
66 | if name not in T5_OBJECTS.keys():
67 | T5_OBJECTS[name] = T5ModelInfo(name=name, cache_dir=cache_path, dtype=dtype)
68 | return T5_OBJECTS[name].config.d_model
69 |
70 |
71 | # encoding text
72 | @beartype
73 | def t5_encode_text_from_encoded(
74 | input_ids: Tensor,
75 | attn_mask: Tensor,
76 | t5: T5EncoderModel,
77 | output_device: Optional[Union[torch.device, str]] = None,
78 | ) -> Tensor:
79 | device = t5.device
80 | input_ids, attn_mask = input_ids.to(device), attn_mask.to(device)
81 | with torch.no_grad():
82 | output = t5(input_ids=input_ids, attention_mask=attn_mask)
83 | encoded_text = output.last_hidden_state.detach()
84 |
85 | attn_mask = attn_mask.bool()
86 | encoded_text: Tensor = encoded_text.masked_fill(~attn_mask[..., None], 0.0)
87 | return encoded_text if output_device is None else encoded_text.to(output_device)
88 |
89 |
90 | @beartype
91 | def t5_encode_text(
92 | texts: Union[str, List[str]],
93 | tokenizer: T5Tokenizer,
94 | t5: T5EncoderModel,
95 | output_device: Optional[Union[torch.device, str]] = None,
96 | ) -> Tensor:
97 | if isinstance(texts, str):
98 | texts = [texts]
99 |
100 | encoded = tokenizer.batch_encode_plus(
101 | texts,
102 | return_tensors="pt",
103 | padding="max_length",
104 | max_length=MAX_LENGTH,
105 | truncation=True,
106 | )
107 | return t5_encode_text_from_encoded(encoded["input_ids"], encoded["attention_mask"], t5, output_device)
108 |
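A minimal sketch of the text-encoding path; the prompt is an arbitrary example and the default model (`google/t5-v1_1-base`) is downloaded on first use:

    from muse_maskgit_pytorch.t5 import (
        DEFAULT_T5_NAME,
        get_encoded_dim,
        get_model_and_tokenizer,
        t5_encode_text,
    )

    t5, tokenizer = get_model_and_tokenizer(DEFAULT_T5_NAME)
    dim = get_encoded_dim(DEFAULT_T5_NAME)  # encoder hidden size, 768 for t5-v1_1-base
    embeds = t5_encode_text(["a photo of a dog"], tokenizer=tokenizer, t5=t5, output_device="cpu")
    # embeds: (batch, MAX_LENGTH, dim), with padded positions zeroed out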
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_accelerated_trainer import get_accelerator
2 | from .maskgit_trainer import MaskGitTrainer
3 | from .vqvae_trainers import VQGanVAETrainer
4 |
5 | __all__ = [
6 | "VQGanVAETrainer",
7 | "MaskGitTrainer",
8 | "get_accelerator",
9 | ]
10 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/trainers/base_accelerated_trainer.py:
--------------------------------------------------------------------------------
1 | from os import PathLike
2 | from pathlib import Path
3 | from shutil import rmtree
4 | from typing import Optional, Union
5 |
6 | import accelerate
7 | import numpy as np
8 | import torch
9 | from accelerate import Accelerator, DistributedDataParallelKwargs, DistributedType
10 | from beartype import beartype
11 | from datasets import Dataset
12 | from lion_pytorch import Lion
13 | from PIL import Image
14 | from torch import nn
15 | from torch.optim import Adam, AdamW, Optimizer
16 | from torch.utils.data import DataLoader, random_split
17 | from torch_optimizer import (
18 | PID,
19 | QHM,
20 | SGDP,
21 | SGDW,
22 | SWATS,
23 | AccSGD,
24 | AdaBound,
25 | AdaMod,
26 | AdamP,
27 | AggMo,
28 | DiffGrad,
29 | Lamb,
30 | NovoGrad,
31 | QHAdam,
32 | RAdam,
33 | Shampoo,
34 | Yogi,
35 | )
36 | from transformers.optimization import Adafactor
37 |
38 | try:
39 | from accelerate.data_loader import MpDeviceLoaderWrapper
40 | except ImportError:
41 | MpDeviceLoaderWrapper = DataLoader
42 |
43 |
44 | try:
45 | from bitsandbytes.optim import Adam8bit, AdamW8bit, Lion8bit
46 | except ImportError:
47 | Adam8bit = AdamW8bit = Lion8bit = None
48 |
49 | try:
50 | import wandb
51 | except ImportError:
52 | wandb = None
53 |
54 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
55 |
56 |
57 | def noop(*args, **kwargs):
58 | pass
59 |
60 |
61 | # helper functions
62 |
63 |
64 | def identity(t, *args, **kwargs):
65 | return t
66 |
67 |
68 | def cast_tuple(t):
69 | return t if isinstance(t, (tuple, list)) else (t,)
70 |
71 |
72 | def yes_or_no(question):
73 | answer = input(f"{question} (y/n) ")
74 | return answer.lower() in ("yes", "y")
75 |
76 |
77 | def pair(val):
78 | return val if isinstance(val, tuple) else (val, val)
79 |
80 |
81 | def convert_image_to_fn(img_type, image):
82 | if image.mode != img_type:
83 | return image.convert(img_type)
84 | return image
85 |
86 |
87 | # image related helper functions and dataset
88 |
89 |
90 | def get_accelerator(*args, **kwargs):
91 | kwargs_handlers = kwargs.get("kwargs_handlers", [])
92 | if ddp_kwargs not in kwargs_handlers:
93 | kwargs_handlers.append(ddp_kwargs)
94 | kwargs.update(kwargs_handlers=kwargs_handlers)
95 | accelerator = Accelerator(*args, **kwargs)
96 | return accelerator
97 |
98 |
99 | def split_dataset(dataset: Dataset, valid_frac: float, accelerator: Accelerator, seed: int = 42):
100 | if valid_frac > 0:
101 | train_size = int((1 - valid_frac) * len(dataset))
102 | valid_size = len(dataset) - train_size
103 | ds, valid_ds = random_split(
104 | dataset,
105 | [train_size, valid_size],
106 | generator=torch.Generator().manual_seed(seed),
107 | )
108 | accelerator.print(
109 | f"training with dataset of {len(ds)} samples and validating with randomly splitted {len(valid_ds)} samples"
110 | )
111 | else:
112 | ds = valid_ds = dataset
113 | accelerator.print(f"training with shared training and valid dataset of {len(ds)} samples")
114 | return ds, valid_ds
115 |
116 |
117 | # main trainer class
118 |
119 |
120 | def get_optimizer(
121 | use_8bit_adam: bool,
122 | optimizer: str,
123 | parameters: dict,
124 | lr: float,
125 | weight_decay: float,
126 | optimizer_kwargs: dict = {},
127 | ):
128 | if use_8bit_adam is True and Adam8bit is None:
129 | print(
130 | "Please install bitsandbytes to use 8-bit optimizers. You can do so by running `pip install "
131 | "bitsandbytes` | Defaulting to non 8-bit equivalent..."
132 | )
133 |
134 | bnb_supported_optims = ["Adam", "AdamW", "Lion"]
135 | if use_8bit_adam and optimizer not in bnb_supported_optims:
136 | print(f"8bit is not supported by the {optimizer} optimizer, Using standard {optimizer} instead.")
137 |
138 | # optimizers
139 | if optimizer == "Adam":
140 | return (
141 | Adam8bit(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs)
142 | if use_8bit_adam and Adam8bit is not None
143 | else Adam(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs)
144 | )
145 | elif optimizer == "AdamW":
146 | return (
147 | AdamW8bit(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs)
148 | if use_8bit_adam and AdamW8bit is not None
149 | else AdamW(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs)
150 | )
151 | elif optimizer == "Lion":
152 | # Reckless reuse of the use_8bit_adam flag
153 | return (
154 | Lion8bit(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs)
155 | if use_8bit_adam and Lion8bit is not None
156 | else Lion(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs)
157 | )
158 | elif optimizer == "Adafactor":
159 | return Adafactor(
160 | parameters,
161 | lr=lr,
162 | weight_decay=weight_decay,
163 | relative_step=False,
164 | scale_parameter=False,
165 | **optimizer_kwargs,
166 | )
167 | elif optimizer == "AccSGD":
168 | return AccSGD(parameters, lr=lr, weight_decay=weight_decay)
169 | elif optimizer == "AdaBound":
170 | return AdaBound(parameters, lr=lr, weight_decay=weight_decay)
171 | elif optimizer == "AdaMod":
172 | return AdaMod(parameters, lr=lr, weight_decay=weight_decay)
173 | elif optimizer == "AdamP":
174 | return AdamP(parameters, lr=lr, weight_decay=weight_decay)
175 | elif optimizer == "AggMo":
176 | return AggMo(parameters, lr=lr, weight_decay=weight_decay)
177 | elif optimizer == "DiffGrad":
178 | return DiffGrad(parameters, lr=lr, weight_decay=weight_decay)
179 | elif optimizer == "Lamb":
180 | return Lamb(parameters, lr=lr, weight_decay=weight_decay)
181 | elif optimizer == "NovoGrad":
182 | return NovoGrad(parameters, lr=lr, weight_decay=weight_decay)
183 | elif optimizer == "PID":
184 | return PID(parameters, lr=lr, weight_decay=weight_decay)
185 | elif optimizer == "QHAdam":
186 | return QHAdam(parameters, lr=lr, weight_decay=weight_decay)
187 | elif optimizer == "QHM":
188 | return QHM(parameters, lr=lr, weight_decay=weight_decay)
189 | elif optimizer == "RAdam":
190 | return RAdam(parameters, lr=lr, weight_decay=weight_decay)
191 | elif optimizer == "SGDP":
192 | return SGDP(parameters, lr=lr, weight_decay=weight_decay)
193 | elif optimizer == "SGDW":
194 | return SGDW(parameters, lr=lr, weight_decay=weight_decay)
195 | elif optimizer == "Shampoo":
196 | return Shampoo(parameters, lr=lr, weight_decay=weight_decay)
197 | elif optimizer == "SWATS":
198 | return SWATS(parameters, lr=lr, weight_decay=weight_decay)
199 | elif optimizer == "Yogi":
200 | return Yogi(parameters, lr=lr, weight_decay=weight_decay)
201 | else:
202 | raise NotImplementedError(f"{optimizer} optimizer not supported yet.")
203 |
204 |
205 | @beartype
206 | class BaseAcceleratedTrainer(nn.Module):
207 | def __init__(
208 | self,
209 | dataloader: Union[DataLoader, MpDeviceLoaderWrapper],
210 | valid_dataloader: Union[DataLoader, MpDeviceLoaderWrapper],
211 | accelerator: Accelerator,
212 | *,
213 | current_step: int,
214 | num_train_steps: int,
215 | num_epochs: int = 5,
216 | max_grad_norm: Optional[int] = None,
217 | save_results_every: int = 100,
218 | save_model_every: int = 1000,
219 | results_dir: Union[str, PathLike] = Path.cwd().joinpath("results"),
220 | logging_dir: Union[str, PathLike] = Path.cwd().joinpath("results/logs"),
221 | apply_grad_penalty_every: int = 4,
222 | gradient_accumulation_steps: int = 1,
223 | clear_previous_experiments: bool = False,
224 | validation_image_scale: Union[int, float] = 1.0,
225 | only_save_last_checkpoint: bool = False,
226 | ):
227 | super().__init__()
228 | self.model: nn.Module = None
229 | # instantiate accelerator
230 | self.gradient_accumulation_steps: int = gradient_accumulation_steps
231 | self.accelerator: Accelerator = accelerator
232 | self.logging_dir: Path = Path(logging_dir) if not isinstance(logging_dir, Path) else logging_dir
233 | self.results_dir: Path = Path(results_dir) if not isinstance(results_dir, Path) else results_dir
234 |
235 | # training params
236 | self.only_save_last_checkpoint: bool = only_save_last_checkpoint
237 | self.validation_image_scale: Union[int, float] = validation_image_scale
238 | self.register_buffer("steps", torch.Tensor([current_step]))
239 | self.num_train_steps: int = num_train_steps
240 | self.num_epochs = num_epochs
241 | self.max_grad_norm: Optional[Union[int, float]] = max_grad_norm
242 |
243 | self.dl = dataloader
244 | self.valid_dl = valid_dataloader
245 | self.dl_iter = iter(self.dl)
246 | self.valid_dl_iter = iter(self.valid_dl)
247 |
248 | self.save_model_every: int = save_model_every
249 | self.save_results_every: int = save_results_every
250 | self.apply_grad_penalty_every: int = apply_grad_penalty_every
251 |
252 | # Clear previous experiment data if requested
253 | if clear_previous_experiments is True and self.accelerator.is_local_main_process:
254 | if self.results_dir.exists():
255 | rmtree(self.results_dir, ignore_errors=True)
256 | # Make sure logging and results directories exist
257 | self.logging_dir.mkdir(parents=True, exist_ok=True)
258 | self.results_dir.mkdir(parents=True, exist_ok=True)
259 |
260 | self.optim: Optimizer = None
261 |
262 | self.print = self.accelerator.print
263 | self.log = self.accelerator.log
264 |
265 | self.on_tpu = self.accelerator.distributed_type == accelerate.DistributedType.TPU
266 |
267 | def save(self, path):
268 | if not self.accelerator.is_main_process:
269 | return
270 |
271 | pkg = dict(
272 | model=self.accelerator.get_state_dict(self.model),
273 | optim=self.optim.state_dict(),
274 | )
275 | self.accelerator.save(pkg, path)
276 |
277 | def load(self, path: Union[str, PathLike]):
278 | if not isinstance(path, Path):
279 | path = Path(path)
280 |
281 | if not path.exists():
282 | raise FileNotFoundError(f"Checkpoint file {path} does not exist.")
283 |
284 | pkg = torch.load(path, map_location="cpu")
285 | model = self.accelerator.unwrap_model(self.model)
286 | model.load_state_dict(pkg["model"])
287 |
288 | self.optim.load_state_dict(pkg["optim"])
289 | return pkg
290 |
291 | def log_validation_images(self, images, step, prompts=None):
292 | if self.validation_image_scale != 1:
293 | # Calculate the new height based on the scale factor
294 | new_height = int(np.array(images[0]).shape[0] * self.validation_image_scale)
295 |
296 | # Calculate the aspect ratio of the original image
297 | aspect_ratio = np.array(images[0]).shape[1] / np.array(images[0]).shape[0]
298 |
299 | # Calculate the new width based on the new height and aspect ratio
300 | new_width = int(new_height * aspect_ratio)
301 |
302 | # Resize the images using the new width and height
303 | output_size = (new_width, new_height)
304 | images_pil = [Image.fromarray(np.array(image)) for image in images]
305 | images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil]
306 | images = [np.array(image_pil) for image_pil in images_pil_resized]
307 |
308 | for tracker in self.accelerator.trackers:
309 | if tracker.name == "tensorboard":
310 | np_images = np.stack([np.asarray(img) for img in images])
311 | tracker.writer.add_images("validation", np_images, step, dataformats="NHWC")
312 | if tracker.name == "wandb":
313 | tracker.log(
314 | {
315 | "validation": [
316 | wandb.Image(image, caption="" if not prompts else prompts[i])
317 | for i, image in enumerate(images)
318 | ]
319 | }
320 | )
321 |
322 | @property
323 | def device(self):
324 | return self.accelerator.device
325 |
326 | @property
327 | def is_distributed(self):
328 | return (
329 | False
330 | if self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1
331 | else True
332 | )
333 |
334 | @property
335 | def is_main_process(self):
336 | return self.accelerator.is_main_process
337 |
338 | @property
339 | def is_local_main_process(self):
340 | return self.accelerator.is_local_main_process
341 |
342 | def train_step(self):
343 | raise NotImplementedError("You are calling train_step on the base trainer with no models")
344 |
345 | def train(self, log_fn=noop):
346 | self.model.train()
347 | while self.steps < self.num_train_steps:
348 | with self.accelerator.autocast():
349 | logs = self.train_step()
350 | log_fn(logs)
351 | self.print("training complete")
352 |
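A hedged sketch of the two standalone helpers in this file, `get_accelerator` and `get_optimizer`; the Accelerator keyword arguments accepted depend on the installed accelerate version, and the linear layer only stands in for a real model:

    import torch

    from muse_maskgit_pytorch.trainers import get_accelerator
    from muse_maskgit_pytorch.trainers.base_accelerated_trainer import get_optimizer

    accelerator = get_accelerator(
        mixed_precision="no",
        gradient_accumulation_steps=1,
    )

    model = torch.nn.Linear(8, 8)  # stand-in for a VQGanVAE or MaskGit
    optim = get_optimizer(
        use_8bit_adam=False,
        optimizer="AdamW",
        parameters=model.parameters(),
        lr=1e-4,
        weight_decay=0.0,
    )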
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/trainers/maskgit_trainer.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 |
3 | import torch # noqa: F401
4 | import torch.nn.functional as F
5 | from accelerate import Accelerator
6 | from diffusers.optimization import SchedulerType
7 | from ema_pytorch import EMA
8 | from omegaconf import OmegaConf
9 | from PIL import Image
10 | from torch.optim import Optimizer
11 | from torch.utils.data import DataLoader
12 | from torchvision.utils import save_image
13 |
14 | from muse_maskgit_pytorch.muse_maskgit_pytorch import MaskGit
15 | from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded
16 | from muse_maskgit_pytorch.trainers.base_accelerated_trainer import BaseAcceleratedTrainer
17 |
18 | try:
19 | import torch_xla
20 | import torch_xla.core.xla_model as xm
21 | import torch_xla.debug.metrics as met
22 | except ImportError:
23 | torch_xla = None
24 | xm = None
25 | met = None
26 |
27 | from tqdm import tqdm
28 |
29 |
30 | class MaskGitTrainer(BaseAcceleratedTrainer):
31 | def __init__(
32 | self,
33 | maskgit: MaskGit,
34 | dataloader: DataLoader,
35 | valid_dataloader: DataLoader,
36 | accelerator: Accelerator,
37 | optimizer: Optimizer,
38 | scheduler: SchedulerType,
39 | *,
40 | current_step: int,
41 | num_train_steps: int,
42 | num_epochs: int = 5,
43 | batch_size: int,
44 | gradient_accumulation_steps: int = 1,
45 | max_grad_norm: float = None,
46 | save_results_every: int = 100,
47 | save_model_every: int = 1000,
48 | log_metrics_every: int = 10,
49 | results_dir="./results",
50 | logging_dir="./results/logs",
51 | apply_grad_penalty_every=4,
52 | use_ema=True,
53 | ema_update_after_step=0,
54 | ema_update_every=1,
55 | validation_prompts=["a photo of a dog"],
56 | timesteps=18,
57 | clear_previous_experiments=False,
58 | validation_image_scale: float = 1.0,
59 | only_save_last_checkpoint=False,
60 | args=None,
61 | ):
62 | super().__init__(
63 | dataloader=dataloader,
64 | valid_dataloader=valid_dataloader,
65 | accelerator=accelerator,
66 | current_step=current_step,
67 | num_train_steps=num_train_steps,
68 | num_epochs=num_epochs,
69 | gradient_accumulation_steps=gradient_accumulation_steps,
70 | max_grad_norm=max_grad_norm,
71 | save_results_every=save_results_every,
72 | save_model_every=save_model_every,
73 | results_dir=results_dir,
74 | logging_dir=logging_dir,
75 | apply_grad_penalty_every=apply_grad_penalty_every,
76 | clear_previous_experiments=clear_previous_experiments,
77 | validation_image_scale=validation_image_scale,
78 | only_save_last_checkpoint=only_save_last_checkpoint,
79 | )
80 | self.save_results_every = save_results_every
81 | self.log_metrics_every = log_metrics_every
82 | self.batch_size = batch_size
83 | self.current_step = current_step
84 | self.timesteps = timesteps
85 |
86 | # arguments used for the training script,
87 | # we are going to use them later to save them to a config file.
88 | self.args = args
89 |
90 | # maskgit
91 | maskgit.vae.requires_grad_(False)
92 | maskgit.transformer.t5.requires_grad_(False)
93 | self.model: MaskGit = maskgit
94 |
95 | self.optim: Optimizer = optimizer
96 | self.lr_scheduler: SchedulerType = scheduler
97 |
98 | self.use_ema = use_ema
99 | self.validation_prompts: List[str] = validation_prompts
100 | if use_ema:
101 | ema_model = EMA(
102 | self.model,
103 | update_after_step=ema_update_after_step,
104 | update_every=ema_update_every,
105 | )
106 | self.ema_model = ema_model
107 | else:
108 | self.ema_model = None
109 |
110 | if not self.on_tpu:
111 | if self.num_train_steps <= 0:
112 | self.training_bar = tqdm(initial=int(self.steps.item()), total=len(self.dl) * self.num_epochs)
113 | else:
114 | self.training_bar = tqdm(initial=int(self.steps.item()), total=self.num_train_steps)
115 |
116 | self.info_bar = tqdm(total=0, bar_format="{desc}")
117 |
118 | def save_validation_images(
119 | self, validation_prompts, step: int, cond_image=None, cond_scale=3, temperature=1, timesteps=18
120 | ):
121 | # moved the print to the top of the function so it shows before the progress bar for readability.
122 | if validation_prompts:
123 | self.accelerator.print(
124 | f"\nStep: {step} | Logging with prompts: {[' | '.join(validation_prompts)]}"
125 | )
126 |
127 | images = self.model.generate(
128 | validation_prompts,
129 | cond_images=cond_image,
130 | cond_scale=cond_scale,
131 | temperature=temperature,
132 | timesteps=timesteps,
133 | ).to(self.accelerator.device)
134 |
135 | save_dir = self.results_dir.joinpath("MaskGit")
136 | save_dir.mkdir(exist_ok=True, parents=True)
137 | save_file = save_dir.joinpath(f"maskgit_{step}.png")
138 |
139 | if self.accelerator.is_main_process:
140 | save_image(images, save_file, "png")
141 | self.log_validation_images([Image.open(save_file)], step, ["|".join(validation_prompts)])
142 | return save_file
143 |
144 | def train(self):
145 | self.steps = self.steps + 1
146 | self.model.train()
147 |
148 | if self.accelerator.is_main_process:
149 | proc_label = f"[P{self.accelerator.process_index}][Master]"
150 | else:
151 | proc_label = f"[P{self.accelerator.process_index}][Worker]"
152 |
153 | # logs
154 | for epoch in range(self.current_step // len(self.dl), self.num_epochs):
155 | for imgs, input_ids, attn_mask in iter(self.dl):
156 | train_loss = 0.0
157 | steps = int(self.steps.item())
158 |
159 | with torch.no_grad():
160 | text_embeds = t5_encode_text_from_encoded(
161 | input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device
162 | )
163 |
164 | with self.accelerator.accumulate(self.model), self.accelerator.autocast():
165 | loss = self.model(imgs, text_embeds=text_embeds)
166 | self.accelerator.backward(loss)
167 | if self.max_grad_norm is not None and self.accelerator.sync_gradients:
168 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
169 | self.optim.step()
170 | self.lr_scheduler.step()
171 | self.optim.zero_grad()
172 |
173 | if self.use_ema:
174 | self.ema_model.update()
175 |
176 | gathered_loss = self.accelerator.gather_for_metrics(loss)
177 | train_loss = gathered_loss.mean() / self.gradient_accumulation_steps
178 |
179 | logs = {"loss": train_loss, "lr": self.lr_scheduler.get_last_lr()[0]}
180 |
181 | if self.on_tpu:
182 | self.accelerator.print(
183 | f"\n[E{epoch + 1}][{steps}]{proc_label}: "
184 | f"maskgit loss: {logs['loss']} - lr: {logs['lr']}"
185 | )
186 | else:
187 | self.training_bar.update()
188 | self.info_bar.set_description_str(
189 | f"[E{epoch + 1}]{proc_label}: " f"maskgit loss: {logs['loss']} - lr: {logs['lr']}"
190 | )
191 |
192 | self.accelerator.log(logs, step=steps)
193 |
194 | if not (steps % self.save_model_every):
195 | self.accelerator.print(
196 | f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"saving model to {self.results_dir}"
197 | )
198 |
199 | state_dict = self.accelerator.unwrap_model(self.model).state_dict()
200 | maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit"
201 | file_name = (
202 | f"{maskgit_save_name}.{steps}.pt"
203 | if not self.only_save_last_checkpoint
204 | else f"{maskgit_save_name}.pt"
205 | )
206 |
207 | model_path = self.results_dir.joinpath(file_name)
208 | self.accelerator.wait_for_everyone()
209 | self.accelerator.save(state_dict, model_path)
210 |
211 | if self.args and not self.args.do_not_save_config:
212 | # save config file next to the model file.
213 | conf = OmegaConf.create(vars(self.args))
214 | OmegaConf.save(conf, f"{model_path}.yaml")
215 |
216 | if self.use_ema:
217 | self.accelerator.print(
218 | f"\n[E{epoch + 1}][{steps}]{proc_label}: "
219 | f"saving EMA model to {self.results_dir}"
220 | )
221 |
222 | ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict()
223 | file_name = (
224 | f"{maskgit_save_name}.{steps}.ema.pt"
225 | if not self.only_save_last_checkpoint
226 | else f"{maskgit_save_name}.ema.pt"
227 | )
228 | model_path = str(self.results_dir / file_name)
229 | self.accelerator.wait_for_everyone()
230 | self.accelerator.save(ema_state_dict, model_path)
231 |
232 | if self.args and not self.args.do_not_save_config:
233 | # save config file next to the model file.
234 | conf = OmegaConf.create(vars(self.args))
235 | OmegaConf.save(conf, f"{model_path}.yaml")
236 |
237 | if not (steps % self.save_results_every):
238 | cond_image = None
239 | if self.model.cond_image_size:
240 | cond_image = F.interpolate(imgs, self.model.cond_image_size, mode="nearest")
241 | self.validation_prompts = [""] * self.batch_size
242 |
243 | if self.on_tpu:
244 | self.accelerator.print(f"\n[E{epoch + 1}]{proc_label}: " f"Logging validation images")
245 | else:
246 | self.info_bar.set_description_str(
247 | f"[E{epoch + 1}]{proc_label}: " f"Logging validation images"
248 | )
249 |
250 | saved_image = self.save_validation_images(
251 | self.validation_prompts,
252 | steps,
253 | cond_image=cond_image,
254 | timesteps=self.timesteps,
255 | )
256 | if self.on_tpu:
257 | self.accelerator.print(
258 | f"\n[E{epoch + 1}][{steps}]{proc_label}: saved to {saved_image}"
259 | )
260 | else:
261 | self.info_bar.set_description_str(
262 | f"[E{epoch + 1}]{proc_label}: " f"saved to {saved_image}"
263 | )
264 |
265 | if met is not None and not (steps % self.log_metrics_every):
266 | if self.on_tpu:
267 | self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: metrics:")
268 | else:
269 | self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: metrics:")
270 |
271 | self.steps += 1
272 |
273 | # if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps:
274 | # if self.on_tpu:
275 | # self.accelerator.print(
276 | # f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}"
277 | # f"[STOP EARLY]: Stopping training early..."
278 | # )
279 | # else:
280 | # self.info_bar.set_description_str(
281 | # f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early..."
282 | # )
283 | # break
284 |
285 | # loop complete, save final model
286 | self.accelerator.print(
287 | f"\n[E{epoch + 1}][{steps}]{proc_label}[FINAL]: saving model to {self.results_dir}"
288 | )
289 | state_dict = self.accelerator.unwrap_model(self.model).state_dict()
290 | maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit"
291 | file_name = (
292 | f"{maskgit_save_name}.{steps}.pt"
293 | if not self.only_save_last_checkpoint
294 | else f"{maskgit_save_name}.pt"
295 | )
296 |
297 | model_path = self.results_dir.joinpath(file_name)
298 | self.accelerator.wait_for_everyone()
299 | self.accelerator.save(state_dict, model_path)
300 |
301 | if self.args and not self.args.do_not_save_config:
302 | # save config file next to the model file.
303 | conf = OmegaConf.create(vars(self.args))
304 | OmegaConf.save(conf, f"{model_path}.yaml")
305 |
306 | if self.use_ema:
307 | self.accelerator.print(f"\n[{steps}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}")
308 | ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict()
309 | file_name = (
310 | f"{maskgit_save_name}.{steps}.ema.pt"
311 | if not self.only_save_last_checkpoint
312 | else f"{maskgit_save_name}.ema.pt"
313 | )
314 | model_path = str(self.results_dir / file_name)
315 | self.accelerator.wait_for_everyone()
316 | self.accelerator.save(ema_state_dict, model_path)
317 |
318 | if self.args and not self.args.do_not_save_config:
319 | # save config file next to the model file.
320 | conf = OmegaConf.create(vars(self.args))
321 | OmegaConf.save(conf, f"{model_path}.yaml")
322 |
323 | cond_image = None
324 | if self.model.cond_image_size:
325 | self.accelerator.print(
326 | "With conditional image training, we recommend keeping the validation prompts to empty strings"
327 | )
328 | cond_image = F.interpolate(imgs, self.model.cond_image_size, mode="nearest")
329 |
330 | steps = int(self.steps.item()) + 1 # get the final step count, plus one
331 | self.accelerator.print(f"\n[{steps}]{proc_label}: Logging validation images")
332 | saved_image = self.save_validation_images(self.validation_prompts, steps, cond_image=cond_image)
333 | self.accelerator.print(f"\n[{steps}]{proc_label}: saved to {saved_image}")
334 |
335 | if met is not None and not (steps % self.log_metrics_every):
336 | self.accelerator.print(f"\n[{steps}]{proc_label}: metrics:")
337 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/trainers/vqvae_trainers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from accelerate import Accelerator
3 | from diffusers.optimization import get_scheduler
4 | from einops import rearrange
5 | from ema_pytorch import EMA
6 | from omegaconf import OmegaConf
7 | from PIL import Image
8 | from torch.optim.lr_scheduler import LRScheduler
9 | from torch.utils.data import DataLoader
10 | from torchvision.utils import make_grid, save_image
11 | from tqdm import tqdm
12 |
13 | from muse_maskgit_pytorch.trainers.base_accelerated_trainer import (
14 | BaseAcceleratedTrainer,
15 | get_optimizer,
16 | )
17 | from muse_maskgit_pytorch.vqgan_vae import VQGanVAE
18 |
19 |
20 | def noop(*args, **kwargs):
21 | pass
22 |
23 |
24 | def accum_log(log, new_logs):
25 | for key, new_value in new_logs.items():
26 | old_value = log.get(key, 0.0)
27 | log[key] = old_value + new_value
28 | return log
29 |
30 |
31 | def exists(val):
32 | return val is not None
33 |
34 |
35 | class VQGanVAETrainer(BaseAcceleratedTrainer):
36 | def __init__(
37 | self,
38 | vae: VQGanVAE,
39 | dataloader: DataLoader,
40 | valid_dataloader: DataLoader,
41 | accelerator: Accelerator,
42 | *,
43 | current_step,
44 | num_train_steps,
45 | num_epochs: int = 5,
46 | gradient_accumulation_steps=1,
47 | max_grad_norm=None,
48 | save_results_every=100,
49 | save_model_every=1000,
50 | results_dir="./results",
51 | logging_dir="./results/logs",
52 | apply_grad_penalty_every=4,
53 | lr=3e-4,
54 | lr_scheduler_type="constant",
55 | lr_warmup_steps=500,
56 | discr_max_grad_norm=None,
57 | use_ema=True,
58 | ema_beta=0.995,
59 | ema_update_after_step=0,
60 | ema_update_every=1,
61 | clear_previous_experiments=False,
62 | validation_image_scale: float = 1.0,
63 | only_save_last_checkpoint=False,
64 | optimizer="Adam",
65 | weight_decay=0.0,
66 | use_8bit_adam=False,
67 | num_cycles=1,
68 | scheduler_power=1.0,
69 | args=None,
70 | ):
71 | super().__init__(
72 | dataloader,
73 | valid_dataloader,
74 | accelerator,
75 | current_step=current_step,
76 | num_train_steps=num_train_steps,
77 | num_epochs=num_epochs,
78 | gradient_accumulation_steps=gradient_accumulation_steps,
79 | max_grad_norm=max_grad_norm,
80 | save_results_every=save_results_every,
81 | save_model_every=save_model_every,
82 | results_dir=results_dir,
83 | logging_dir=logging_dir,
84 | apply_grad_penalty_every=apply_grad_penalty_every,
85 | clear_previous_experiments=clear_previous_experiments,
86 | validation_image_scale=validation_image_scale,
87 | only_save_last_checkpoint=only_save_last_checkpoint,
88 | )
89 |
90 | # arguments used for the training script,
91 | # we are going to use them later to save them to a config file.
92 | self.args = args
93 |
94 | self.current_step = current_step
95 |
96 | # vae
97 | self.model = vae
98 |
99 | all_parameters = set(vae.parameters())
100 | discr_parameters = set(vae.discr.parameters())
101 | vae_parameters = all_parameters - discr_parameters
102 |
103 | # optimizers
104 | self.optim = get_optimizer(use_8bit_adam, optimizer, vae_parameters, lr, weight_decay)
105 | self.discr_optim = get_optimizer(use_8bit_adam, optimizer, discr_parameters, lr, weight_decay)
106 |
107 | if self.num_train_steps > 0:
108 | self.num_lr_steps = self.num_train_steps * self.gradient_accumulation_steps
109 | else:
110 | self.num_lr_steps = self.num_epochs * len(self.dl)
111 |
112 | self.lr_scheduler: LRScheduler = get_scheduler(
113 | lr_scheduler_type,
114 | optimizer=self.optim,
115 | num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps,
116 | num_training_steps=self.num_lr_steps,
117 | num_cycles=num_cycles,
118 | power=scheduler_power,
119 | )
120 |
121 | self.lr_scheduler_discr: LRScheduler = get_scheduler(
122 | lr_scheduler_type,
123 | optimizer=self.discr_optim,
124 | num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps,
125 | num_training_steps=self.num_lr_steps,
126 | num_cycles=num_cycles,
127 | power=scheduler_power,
128 | )
129 |
130 | self.discr_max_grad_norm = discr_max_grad_norm
131 |
132 | # prepare with accelerator
133 |
134 | (
135 | self.model,
136 | self.optim,
137 | self.discr_optim,
138 | self.dl,
139 | self.valid_dl,
140 | self.lr_scheduler,
141 | self.lr_scheduler_discr,
142 | ) = accelerator.prepare(
143 | self.model,
144 | self.optim,
145 | self.discr_optim,
146 | self.dl,
147 | self.valid_dl,
148 | self.lr_scheduler,
149 | self.lr_scheduler_discr,
150 | )
151 | self.model.train()
152 |
153 | self.use_ema = use_ema
154 |
155 | if use_ema:
156 | self.ema_model = EMA(
157 | vae,
158 | beta=ema_beta, update_after_step=ema_update_after_step,
159 | update_every=ema_update_every,
160 | )
161 | self.ema_model = accelerator.prepare(self.ema_model)
162 |
163 | if not self.on_tpu:
164 | if self.num_train_steps <= 0:
165 | self.training_bar = tqdm(initial=int(self.steps.item()), total=len(self.dl) * self.num_epochs)
166 | else:
167 | self.training_bar = tqdm(initial=int(self.steps.item()), total=self.num_train_steps)
168 |
169 | self.info_bar = tqdm(total=0, bar_format="{desc}")
170 |
171 | def load(self, path):
172 | pkg = super().load(path)
173 | self.discr_optim.load_state_dict(pkg["discr_optim"])
174 |
175 | def save(self, path):
176 | if not self.is_local_main_process:
177 | return
178 |
179 | pkg = dict(
180 | model=self.accelerator.get_state_dict(self.model),
181 | optim=self.optim.state_dict(),
182 | discr_optim=self.discr_optim.state_dict(),
183 | )
184 | self.accelerator.save(pkg, path)
185 |
186 | def log_validation_images(self, logs, steps):
187 | log_imgs = []
188 | self.model.eval()
189 |
190 | try:
191 | valid_data = next(self.valid_dl_iter)
192 | except StopIteration:
193 | self.valid_dl_iter = iter(self.valid_dl)
194 | valid_data = next(self.valid_dl_iter)
195 |
196 | valid_data = valid_data.to(self.device)
197 |
198 | recons = self.model(valid_data, return_recons=True)
199 |
200 | # else save a grid of images
201 |
202 | imgs_and_recons = torch.stack((valid_data, recons), dim=0)
203 | imgs_and_recons = rearrange(imgs_and_recons, "r b ... -> (b r) ...")
204 |
205 | imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0.0, 1.0)
206 | grid = make_grid(imgs_and_recons, nrow=2, normalize=True, value_range=(0, 1))
207 |
208 | logs["reconstructions"] = grid
209 | save_file = str(self.results_dir / f"{steps}.png")
210 | save_image(grid, save_file)
211 | log_imgs.append(Image.open(save_file))
212 | super().log_validation_images(log_imgs, steps, prompts=["vae"])
213 | self.model.train()
214 |
215 | def train(self):
216 | self.steps = self.steps + 1
217 | device = self.device
218 | self.model.train()
219 |
220 | if self.accelerator.is_main_process:
221 | proc_label = f"[P{self.accelerator.process_index:03d}][Master]"
222 | else:
223 | proc_label = f"[P{self.accelerator.process_index:03d}][Worker]"
224 |
225 | for epoch in range(self.current_step // len(self.dl), self.num_epochs):
226 | for img in self.dl:
227 | loss = 0.0
228 | steps = int(self.steps.item())
229 |
230 | apply_grad_penalty = (steps % self.apply_grad_penalty_every) == 0
231 |
232 | discr = self.model.module.discr if self.is_distributed else self.model.discr
233 | if self.use_ema:
234 | ema_model = self.ema_model.module if self.is_distributed else self.ema_model
235 |
236 | # logs
237 |
238 | logs = {}
239 |
240 | # update vae (generator)
241 |
242 | img = img.to(device)
243 |
244 | with self.accelerator.autocast():
245 | loss = self.model(img, add_gradient_penalty=apply_grad_penalty, return_loss=True)
246 |
247 | self.accelerator.backward(loss / self.gradient_accumulation_steps)
248 | if self.max_grad_norm is not None and self.accelerator.sync_gradients:
249 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
250 |
251 | accum_log(logs, {"Train/vae_loss": loss.item() / self.gradient_accumulation_steps})
252 |
253 | self.lr_scheduler.step()
254 | self.lr_scheduler_discr.step()
255 | self.optim.step()
256 | self.optim.zero_grad()
257 |
258 | loss = 0.0
259 |
260 | # update discriminator
261 |
262 | if exists(discr):
263 | self.discr_optim.zero_grad()
264 |
265 |                     with self.accelerator.autocast():
266 | loss = self.model(img, return_discr_loss=True)
267 |
268 | self.accelerator.backward(loss / self.gradient_accumulation_steps)
269 | if self.discr_max_grad_norm is not None and self.accelerator.sync_gradients:
270 |                         self.accelerator.clip_grad_norm_(discr.parameters(), self.discr_max_grad_norm)
271 |
272 | accum_log(
273 | logs,
274 | {"Train/discr_loss": loss.item() / self.gradient_accumulation_steps},
275 | )
276 |
277 | self.discr_optim.step()
278 |
279 | # log
280 | if self.on_tpu:
281 | self.accelerator.print(
282 | f"[E{epoch + 1}][{steps:05d}]{proc_label}: "
283 | f"vae loss: {logs['Train/vae_loss']} - "
284 | f"discr loss: {logs['Train/discr_loss']} - "
285 | f"lr: {self.lr_scheduler.get_last_lr()[0]}"
286 | )
287 | else:
288 | self.training_bar.update()
289 | # Note: we had to remove {proc_label} from the description
290 |                     # to shorten it so it doesn't go beyond one line on each step.
291 | self.info_bar.set_description_str(
292 | f"[E{epoch + 1}][{steps:05d}]: "
293 | f"vae loss: {logs['Train/vae_loss']} - "
294 | f"discr loss: {logs['Train/discr_loss']} - "
295 | f"lr: {self.lr_scheduler.get_last_lr()[0]}"
296 | )
297 |
298 | logs["lr"] = self.lr_scheduler.get_last_lr()[0]
299 | self.accelerator.log(logs, step=steps)
300 |
301 | # update exponential moving averaged generator
302 |
303 | if self.use_ema:
304 | ema_model.update()
305 |
306 | # sample results every so often
307 |
308 | if (steps % self.save_results_every) == 0:
309 | self.accelerator.print(
310 | f"\n[E{epoch + 1}][{steps}] | Logging validation images to {str(self.results_dir)}"
311 | )
312 |
313 | self.log_validation_images(logs, steps)
314 |
315 | # save model every so often
316 | self.accelerator.wait_for_everyone()
317 | if self.is_main_process and (steps % self.save_model_every) == 0:
318 | self.accelerator.print(f"\nStep: {steps} | Saving model to {str(self.results_dir)}")
319 |
320 | state_dict = self.accelerator.unwrap_model(self.model).state_dict()
321 | file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt"
322 | model_path = str(self.results_dir / file_name)
323 | self.accelerator.save(state_dict, model_path)
324 |
325 | if self.args and not self.args.do_not_save_config:
326 | # save config file next to the model file.
327 | conf = OmegaConf.create(vars(self.args))
328 | OmegaConf.save(conf, f"{model_path}.yaml")
329 |
330 | if self.use_ema:
331 | ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict()
332 | file_name = (
333 | f"vae.{steps}.ema.pt" if not self.only_save_last_checkpoint else "vae.ema.pt"
334 | )
335 | model_path = str(self.results_dir / file_name)
336 | self.accelerator.save(ema_state_dict, model_path)
337 |
338 | if self.args and not self.args.do_not_save_config:
339 | # save config file next to the model file.
340 | conf = OmegaConf.create(vars(self.args))
341 | OmegaConf.save(conf, f"{model_path}.yaml")
342 |
343 | self.steps += 1
344 |
345 | # if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps:
346 | # self.accelerator.print(
347 | # f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..."
348 | # )
349 | # break
350 |
351 | # Loop finished, save model
352 | self.accelerator.wait_for_everyone()
353 | if self.is_main_process:
354 | self.accelerator.print(
355 | f"[E{self.num_epochs}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}"
356 | )
357 |
358 | state_dict = self.accelerator.unwrap_model(self.model).state_dict()
359 | file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt"
360 | model_path = str(self.results_dir / file_name)
361 | self.accelerator.save(state_dict, model_path)
362 |
363 | if self.args and not self.args.do_not_save_config:
364 | # save config file next to the model file.
365 | conf = OmegaConf.create(vars(self.args))
366 | OmegaConf.save(conf, f"{model_path}.yaml")
367 |
368 | if self.use_ema:
369 | ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict()
370 | file_name = f"vae.{steps}.ema.pt" if not self.only_save_last_checkpoint else "vae.ema.pt"
371 | model_path = str(self.results_dir / file_name)
372 | self.accelerator.save(ema_state_dict, model_path)
373 |
374 | if self.args and not self.args.do_not_save_config:
375 | # save config file next to the model file.
376 | conf = OmegaConf.create(vars(self.args))
377 | OmegaConf.save(conf, f"{model_path}.yaml")
378 |
--------------------------------------------------------------------------------
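The trainer above wraps the VAE in an `EMA` model from `ema-pytorch` and calls `ema_model.update()` once per optimizer step. Below is a minimal sketch of that pattern; the `nn.Linear` model and the toy loop are hypothetical stand-ins, and only the `EMA(...)` arguments mirror the trainer's `ema_update_after_step` / `ema_update_every` usage.

# Minimal sketch of the EMA update pattern used by the trainer above.
# The nn.Linear model is a hypothetical stand-in for the VQGanVAE.
import torch
from torch import nn
from ema_pytorch import EMA

model = nn.Linear(8, 8)
ema_model = EMA(model, update_after_step=1, update_every=1)
optim = torch.optim.Adam(model.parameters(), lr=1e-4)

for step in range(10):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    loss.backward()
    optim.step()
    optim.zero_grad()
    ema_model.update()  # same call the training loop makes after each step

# ema_model.ema_model now holds the exponentially averaged weights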
/muse_maskgit_pytorch/vqgan_vae.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from functools import partial, wraps
3 | from pathlib import Path
4 |
5 | import timm
6 | import torch
7 | import torch.nn.functional as F
8 | import torchvision
9 | from accelerate import Accelerator
10 | from beartype import beartype
11 | from einops import rearrange, repeat
12 | from torch import nn
13 | from torch.autograd import grad as torch_grad
14 | from vector_quantize_pytorch import VectorQuantize as VQ
15 |
16 | # constants
17 |
18 | MList = nn.ModuleList
19 |
20 |
21 | # helper functions
22 | def exists(val):
23 | return val is not None
24 |
25 |
26 | def default(val, d):
27 | return val if val is not None else d
28 |
29 |
30 | # decorators
31 | def eval_decorator(fn):
32 | def inner(model, *args, **kwargs):
33 | was_training = model.training
34 | model.eval()
35 | out = fn(model, *args, **kwargs)
36 | model.train(was_training)
37 | return out
38 |
39 | return inner
40 |
41 |
42 | def remove_vgg(fn):
43 | @wraps(fn)
44 | def inner(self, *args, **kwargs):
45 | has_vgg = hasattr(self, "_vgg")
46 | if has_vgg:
47 | vgg = self._vgg
48 | delattr(self, "_vgg")
49 |
50 | out = fn(self, *args, **kwargs)
51 |
52 | if has_vgg:
53 | self._vgg = vgg
54 |
55 | return out
56 |
57 | return inner
58 |
59 |
60 | # keyword argument helpers
61 | def pick_and_pop(keys, d):
62 | values = list(map(lambda key: d.pop(key), keys))
63 | return dict(zip(keys, values))
64 |
65 |
66 | def group_dict_by_key(cond, d):
67 | return_val = [dict(), dict()]
68 | for key in d.keys():
69 | match = bool(cond(key))
70 | ind = int(not match)
71 | return_val[ind][key] = d[key]
72 | return (*return_val,)
73 |
74 |
75 | def string_begins_with(prefix, string_input):
76 | return string_input.startswith(prefix)
77 |
78 |
79 | def group_by_key_prefix(prefix, d):
80 | return group_dict_by_key(partial(string_begins_with, prefix), d)
81 |
82 |
83 | def groupby_prefix_and_trim(prefix, d):
84 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
85 | kwargs_without_prefix = dict(
86 | map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items()))
87 | )
88 | return kwargs_without_prefix, kwargs
89 |
90 |
91 | # tensor helper functions
92 | def log(t, eps=1e-10):
93 | return torch.log(t + eps)
94 |
95 |
96 | def gradient_penalty(images, output, weight=10):
97 | gradients = torch_grad(
98 | outputs=output,
99 | inputs=images,
100 | grad_outputs=torch.ones(output.size(), device=images.device),
101 | create_graph=True,
102 | retain_graph=True,
103 | only_inputs=True,
104 | )[0]
105 |
106 | gradients = rearrange(gradients, "b ... -> b (...)")
107 | return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()
108 |
109 |
110 | def leaky_relu(p: float = 0.1):
111 | return nn.LeakyReLU(p)
112 |
113 |
114 | def safe_div(numer, denom, eps=1e-8):
115 | return numer / denom.clamp(min=eps)
116 |
117 |
118 | # gan losses
119 | def hinge_discr_loss(fake, real):
120 | return (F.relu(1 + fake) + F.relu(1 - real)).mean()
121 |
122 |
123 | def hinge_gen_loss(fake):
124 | return -fake.mean()
125 |
126 |
127 | def bce_discr_loss(fake, real):
128 | return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean()
129 |
130 |
131 | def bce_gen_loss(fake):
132 | return -log(torch.sigmoid(fake)).mean()
133 |
134 |
135 | def grad_layer_wrt_loss(loss, layer):
136 | return torch_grad(
137 | outputs=loss,
138 | inputs=layer,
139 | grad_outputs=torch.ones_like(loss),
140 | retain_graph=True,
141 | )[0].detach()
142 |
143 |
144 | # vqgan vae
145 | class LayerNormChan(nn.Module):
146 | def __init__(self, dim, eps=1e-5):
147 | super().__init__()
148 | self.eps = eps
149 | self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1))
150 |
151 | def forward(self, x):
152 | var = torch.var(x, dim=1, unbiased=False, keepdim=True)
153 | mean = torch.mean(x, dim=1, keepdim=True)
154 | return (x - mean) * var.clamp(min=self.eps).rsqrt() * self.gamma
155 |
156 |
157 | # discriminator
158 | class Discriminator(nn.Module):
159 | def __init__(self, dims, channels=4, groups=16, init_kernel_size=5):
160 | super().__init__()
161 | dim_pairs = zip(dims[:-1], dims[1:])
162 |
163 | self.layers = MList(
164 | [
165 | nn.Sequential(
166 | nn.Conv2d(channels, dims[0], init_kernel_size, padding=init_kernel_size // 2),
167 | leaky_relu(),
168 | )
169 | ]
170 | )
171 |
172 | for dim_in, dim_out in dim_pairs:
173 | self.layers.append(
174 | nn.Sequential(
175 | nn.Conv2d(dim_in, dim_out, 4, stride=2, padding=1),
176 | nn.GroupNorm(groups, dim_out),
177 | leaky_relu(),
178 | )
179 | )
180 |
181 | dim = dims[-1]
182 | # return 5 x 5, for PatchGAN-esque training
183 | self.to_logits = nn.Sequential(nn.Conv2d(dim, dim, 1), leaky_relu(), nn.Conv2d(dim, 1, 4))
184 |
185 | def forward(self, x):
186 | for net in self.layers:
187 | x = net(x)
188 | return self.to_logits(x)
189 |
190 |
191 | # resnet encoder / decoder
192 | class ResnetEncDec(nn.Module):
193 | def __init__(
194 | self,
195 | dim: int,
196 | *,
197 | channels=4,
198 | layers=4,
199 | layer_mults=None,
200 | num_resnet_blocks=1,
201 | resnet_groups=16,
202 | first_conv_kernel_size=5,
203 | ):
204 | super().__init__()
205 | assert (
206 | dim % resnet_groups == 0
207 | ), f"dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)"
208 |
209 | self.layers = layers
210 | self.encoders = MList([])
211 | self.decoders = MList([])
212 |
213 | layer_mults = default(layer_mults, [2**x for x in range(layers)])
214 | if len(layer_mults) != layers:
215 | raise ValueError("layer multipliers must be equal to designated number of layers")
216 |
217 | layer_dims = [dim * mult for mult in layer_mults]
218 | dims = (dim, *layer_dims)
219 |
220 | self.encoded_dim = dims[-1]
221 | dim_pairs = zip(dims[:-1], dims[1:])
222 |
223 | if not isinstance(num_resnet_blocks, tuple):
224 | num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks)
225 | if len(num_resnet_blocks) != layers:
226 | raise ValueError("number of resnet blocks must be equal to number of layers")
227 |
228 | for _, (dim_in, dim_out), layer_num_resnet_blocks in zip(range(layers), dim_pairs, num_resnet_blocks):
229 | self.encoders.append(
230 | nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride=2, padding=1), leaky_relu())
231 | )
232 | self.decoders.insert(0, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu()))
233 | for _ in range(layer_num_resnet_blocks):
234 | self.encoders.append(ResBlock(dim_out, groups=resnet_groups))
235 | self.decoders.insert(0, GLUResBlock(dim_out, groups=resnet_groups))
236 |
237 | self.encoders.insert(
238 | 0, nn.Conv2d(channels, dim, first_conv_kernel_size, padding=first_conv_kernel_size // 2)
239 | )
240 | self.decoders.append(nn.Conv2d(dim, channels, 1))
241 |
242 | def get_encoded_fmap_size(self, image_size: int):
243 | return image_size // (2**self.layers)
244 |
245 | @property
246 | def last_dec_layer(self):
247 | return self.decoders[-1].weight
248 |
249 | def encode(self, x):
250 | for enc in self.encoders:
251 | x = enc(x)
252 | return x
253 |
254 | def decode(self, x):
255 | for dec in self.decoders:
256 | x = dec(x)
257 | return x
258 |
259 |
260 | class GLUResBlock(nn.Module):
261 | def __init__(self, chan, groups=16):
262 | super().__init__()
263 | self.net = nn.Sequential(
264 | nn.Conv2d(chan, chan * 2, 3, padding=1),
265 | nn.GLU(dim=1),
266 | nn.GroupNorm(groups, chan),
267 | nn.Conv2d(chan, chan * 2, 3, padding=1),
268 | nn.GLU(dim=1),
269 | nn.GroupNorm(groups, chan),
270 | nn.Conv2d(chan, chan, 1),
271 | )
272 |
273 | def forward(self, x):
274 | return self.net(x) + x
275 |
276 |
277 | class ResBlock(nn.Module):
278 | def __init__(self, chan, groups=16):
279 | super().__init__()
280 | self.net = nn.Sequential(
281 | nn.Conv2d(chan, chan, 3, padding=1),
282 | nn.GroupNorm(groups, chan),
283 | leaky_relu(),
284 | nn.Conv2d(chan, chan, 3, padding=1),
285 | nn.GroupNorm(groups, chan),
286 | leaky_relu(),
287 | nn.Conv2d(chan, chan, 1),
288 | )
289 |
290 | def forward(self, x):
291 | return self.net(x) + x
292 |
293 |
294 | class TimmFeatureEncDec(nn.Module):
295 |     def __init__(self, backbone="convnext_base", out_indices=None):
296 |         super().__init__()
297 |         self.timm_model = timm.create_model(
298 |             backbone,
299 |             pretrained=True,
300 |             features_only=True,
301 |             exportable=True,
302 |             out_indices=out_indices,
303 |         )
304 |
305 | def encode(self, x):
306 | for enc in self.encoders:
307 | x = enc(x)
308 | return x
309 |
310 | def decode(self, x):
311 | for dec in self.decoders:
312 | x = dec(x)
313 | return x
314 |
315 |
316 | class HuggingfaceEncDec(nn.Module):
317 | def __init__(self):
318 |         super().__init__()
319 |
320 | def forward(self):
321 | return
322 |
323 |
324 | class WaveletTransformerEncDec(nn.Module):
325 | def __init__(self):
326 |         super().__init__()
327 |
328 | def forward(self):
329 | return
330 |
331 |
332 | # main vqgan-vae classes
333 | @beartype
334 | class VQGanVAE(nn.Module):
335 | def __init__(
336 | self,
337 | *,
338 | dim: int,
339 | accelerator: Accelerator = None,
340 | channels=4,
341 | layers=4,
342 | l2_recon_loss=False,
343 | use_hinge_loss=True,
344 | vgg=None,
345 | vq_codebook_dim=256,
346 | vq_codebook_size=512,
347 | vq_decay=0.8,
348 | vq_commitment_weight=1.0,
349 | vq_kmeans_init=True,
350 | vq_use_cosine_sim=True,
351 | use_vgg_and_gan=True,
352 | discr_layers=4,
353 | **kwargs,
354 | ):
355 | super().__init__()
356 | vq_kwargs, kwargs = groupby_prefix_and_trim("vq_", kwargs)
357 | encdec_kwargs, kwargs = groupby_prefix_and_trim("encdec_", kwargs)
358 |
359 | self.accelerator = accelerator
360 | self.channels = channels
361 | self.codebook_size = vq_codebook_size
362 | self.dim_divisor = 2**layers
363 |
364 | self.enc_dec = ResnetEncDec(dim=dim, channels=channels, layers=layers, **encdec_kwargs)
365 |
366 | self.vq = VQ(
367 | dim=self.enc_dec.encoded_dim,
368 | codebook_dim=vq_codebook_dim,
369 | codebook_size=vq_codebook_size,
370 | decay=vq_decay,
371 | commitment_weight=vq_commitment_weight,
372 | accept_image_fmap=True,
373 | kmeans_init=vq_kmeans_init,
374 | use_cosine_sim=vq_use_cosine_sim,
375 | **vq_kwargs,
376 | )
377 |
378 | # reconstruction loss
379 | self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss
380 |
381 | # turn off GAN and perceptual loss if grayscale
382 | self._vgg = None
383 | self.discr = None
384 | self.use_vgg_and_gan = use_vgg_and_gan
385 | if not use_vgg_and_gan:
386 | return
387 |
388 |         # perceptual loss
389 | if exists(vgg):
390 | self._vgg = vgg
391 |
392 | # gan related losses
393 | layer_mults = list(map(lambda t: 2**t, range(discr_layers)))
394 | layer_dims = [dim * mult for mult in layer_mults]
395 | dims = (dim, *layer_dims)
396 |
397 | self.discr = Discriminator(dims=dims, channels=channels)
398 | self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
399 | self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
400 |
401 | @property
402 | def device(self):
403 | return self.accelerator.device if self.accelerator else next(self.parameters()).device
404 |
405 | @property
406 | def vgg(self):
407 | if exists(self._vgg):
408 | return self._vgg
409 |
410 | vgg = torchvision.models.vgg16(pretrained=True)
411 | vgg.features[0] = nn.Conv2d(self.channels, 64, kernel_size=3, stride=1, padding=1)
412 | vgg.classifier = nn.Sequential(*vgg.classifier[:-2])
413 | self._vgg = vgg.to(self.device)
414 | return self._vgg
415 |
416 | @property
417 | def encoded_dim(self):
418 | return self.enc_dec.encoded_dim
419 |
420 | def get_encoded_fmap_size(self, image_size):
421 | return self.enc_dec.get_encoded_fmap_size(image_size)
422 |
423 | def copy_for_eval(self):
424 | device = next(self.parameters()).device
425 | vae_copy = copy.deepcopy(self.cpu())
426 |
427 | if vae_copy.use_vgg_and_gan:
428 | del vae_copy.discr
429 | del vae_copy._vgg
430 |
431 | vae_copy.eval()
432 | return vae_copy.to(device)
433 |
434 | @remove_vgg
435 | def state_dict(self, *args, **kwargs):
436 | return super().state_dict(*args, **kwargs)
437 |
438 | @remove_vgg
439 | def load_state_dict(self, *args, **kwargs):
440 | return super().load_state_dict(*args, **kwargs)
441 |
442 | def save(self, path):
443 | if self.accelerator is not None:
444 | self.accelerator.save(self.state_dict(), path)
445 | else:
446 | torch.save(self.state_dict(), path)
447 |
448 | def load(self, path, map=None):
449 | path = Path(path)
450 | assert path.exists()
451 | state_dict = torch.load(str(path), map_location=map)
452 | self.load_state_dict(state_dict)
453 |
454 | @property
455 | def codebook(self):
456 | return self.vq.codebook
457 |
458 | def encode(self, fmap):
459 | fmap = self.enc_dec.encode(fmap)
460 | fmap, indices, commit_loss = self.vq(fmap)
461 | return fmap, indices, commit_loss
462 |
463 | def decode_from_ids(self, ids):
464 | codes = self.codebook[ids]
465 | fmap = self.vq.project_out(codes)
466 | fmap = rearrange(fmap, "b h w c -> b c h w")
467 | return self.decode(fmap)
468 |
469 | def decode(self, fmap):
470 | return self.enc_dec.decode(fmap)
471 |
472 | def forward(
473 | self,
474 | img,
475 | return_loss=False,
476 | return_discr_loss=False,
477 | return_recons=False,
478 | add_gradient_penalty=True,
479 | relu_loss=True,
480 | ):
481 | batch, channels, height, width, device = *img.shape, img.device
482 |
483 | for dim_name, size in (("height", height), ("width", width)):
484 | assert (size % self.dim_divisor) == 0, f"{dim_name} must be divisible by {self.dim_divisor}"
485 |
486 | assert (
487 | channels == self.channels
488 | ), "number of channels on image or sketch is not equal to the channels set on this VQGanVAE"
489 |
490 | fmap, indices, commit_loss = self.encode(img)
491 |
492 | fmap = self.decode(fmap)
493 |
494 | if not return_loss and not return_discr_loss:
495 | return fmap
496 |
497 | assert (
498 | return_loss ^ return_discr_loss
499 | ), "you should either return autoencoder loss or discriminator loss, but not both"
500 |
501 | # whether to return discriminator loss
502 |
503 | if return_discr_loss:
504 | assert exists(self.discr), "discriminator must exist to train it"
505 |
506 | fmap.detach_()
507 | img.requires_grad_()
508 |
509 | fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
510 |
511 | discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits)
512 |
513 |             # fall back to the bare discriminator loss when the penalty is skipped
514 |             gp = gradient_penalty(img, img_discr_logits) if add_gradient_penalty else 0.0
515 |             loss = discr_loss + gp
516 |
517 | if return_recons:
518 | if relu_loss:
519 | return F.relu(loss), fmap
520 | else:
521 | return loss, fmap
522 |
523 | if relu_loss:
524 | return F.relu(loss)
525 | else:
526 | return loss
527 |
528 | # reconstruction loss
529 |
530 | recon_loss = self.recon_loss_fn(fmap, img)
531 |
532 | # early return if training on grayscale
533 |
534 | if not self.use_vgg_and_gan:
535 | if return_recons:
536 | if relu_loss:
537 | return F.relu(recon_loss), fmap
538 | else:
539 | return recon_loss, fmap
540 |
541 | if relu_loss:
542 | return F.relu(recon_loss)
543 | else:
544 | return recon_loss
545 |
546 | # perceptual loss
547 |
548 | img_vgg_input = img
549 | fmap_vgg_input = fmap
550 |
551 | if img.shape[1] == 1:
552 | # handle grayscale for vgg
553 | img_vgg_input, fmap_vgg_input = map(
554 | lambda t: repeat(t, "b 1 ... -> b c ...", c=3),
555 | (img_vgg_input, fmap_vgg_input),
556 | )
557 |
558 | img_vgg_feats = self.vgg(img_vgg_input)
559 | recon_vgg_feats = self.vgg(fmap_vgg_input)
560 | perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats)
561 | if relu_loss:
562 | perceptual_loss = F.relu(perceptual_loss)
563 |
564 | # generator loss
565 |
566 | gen_loss = self.gen_loss(self.discr(fmap))
567 | if relu_loss:
568 | gen_loss = F.relu(gen_loss)
569 |
570 | # calculate adaptive weight
571 |
572 | last_dec_layer = self.enc_dec.last_dec_layer
573 |
574 | norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2)
575 | norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2)
576 |
577 | adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss)
578 | adaptive_weight.clamp_(max=1e4)
579 |
580 | # combine losses
581 |         # recon loss is the pixel-space reconstruction loss (l1 or mse)
582 |         # perceptual loss is the mse between vgg features of the image and reconstruction
583 |         # commit loss is the vq commitment loss from quantizing the encoder output
584 |         # gen loss is the adversarial generator loss, scaled by the adaptive weight
585 | if relu_loss:
586 | loss = (
587 | F.relu(recon_loss)
588 | + F.relu(perceptual_loss)
589 | + F.relu(commit_loss)
590 | + F.relu(adaptive_weight) * F.relu(gen_loss)
591 | )
592 | else:
593 | loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
594 |
595 | if return_recons:
596 | if relu_loss:
597 | return F.relu(loss), fmap
598 | else:
599 | return loss, fmap
600 |
601 | if relu_loss:
602 | return F.relu(loss)
603 | else:
604 | return loss
605 |
--------------------------------------------------------------------------------
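A minimal usage sketch for the `VQGanVAE` defined above, assuming the package and its dependencies (timm, vector-quantize-pytorch) are installed. The dimensions are toy values chosen so the spatial size is divisible by 2**layers, and the GAN/VGG losses are disabled to keep the example light; this is not a training recipe.

# Sketch: exercising VQGanVAE with toy sizes.
import torch
from muse_maskgit_pytorch.vqgan_vae import VQGanVAE

vae = VQGanVAE(dim=64, channels=4, layers=2, use_vgg_and_gan=False)

img = torch.rand(1, 4, 64, 64)  # height/width must be divisible by 2**layers == 4
recon = vae(img)  # plain forward returns the reconstruction
loss, recon = vae(img, return_loss=True, return_recons=True)

fmap, indices, commit_loss = vae.encode(img)  # quantized feature map + codebook indices
decoded = vae.decode(fmap)
print(recon.shape, decoded.shape, float(loss))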
/muse_maskgit_pytorch/vqgan_vae_taming.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import importlib
3 | from math import log, sqrt
4 | from pathlib import Path
5 | from urllib.parse import urlparse
6 |
7 | import requests
8 | import torch
9 | import torch.nn.functional as F
10 | from accelerate import Accelerator
11 | from einops import rearrange
12 | from omegaconf import DictConfig, OmegaConf
13 | from taming.models.vqgan import VQModel
14 | from torch import nn
15 | from tqdm_loggable.auto import tqdm
16 |
17 | # constants
18 | CACHE_PATH = Path.home().joinpath(".cache/taming")
19 |
20 | VQGAN_VAE_PATH = "https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1"
21 | VQGAN_VAE_CONFIG_PATH = "https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1"
22 |
23 | # helpers methods
24 |
25 |
26 | def exists(val):
27 | return val is not None
28 |
29 |
30 | def default(val, d):
31 | return val if exists(val) else d
32 |
33 |
34 | def download(url, filename=None, root=CACHE_PATH, chunk_size=1024):
35 | filename = default(filename, urlparse(url).path.split("/")[-1])
36 | root_dir = Path(root)
37 |
38 | target_path = root_dir.joinpath(filename)
39 | if target_path.exists():
40 |         if target_path.is_file():
41 | return str(target_path)
42 | raise RuntimeError(f"{target_path} exists and is not a regular file")
43 |
44 | target_tmp = target_path.with_name(f".{target_path.name}.tmp")
45 | resp = requests.get(url, stream=True)
46 | resp.raise_for_status()
47 |
48 | filesize = int(resp.headers.get("content-length", 0))
49 | with target_tmp.open("wb") as f:
50 | for data in tqdm(
51 | resp.iter_content(chunk_size=chunk_size),
52 | desc=filename,
53 | total=filesize,
54 | unit="iB",
55 | unit_scale=True,
56 | unit_divisor=1024,
57 | ):
58 | f.write(data)
59 | target_tmp.rename(target_path)
60 |     return str(target_path)
61 |
62 |
63 | # VQGAN from Taming Transformers paper
64 | # https://arxiv.org/abs/2012.09841
65 |
66 |
67 | def get_obj_from_str(string, reload=False):
68 | module, cls = string.rsplit(".", 1)
69 | if reload:
70 | module_imp = importlib.import_module(module)
71 | importlib.reload(module_imp)
72 | return getattr(importlib.import_module(module, package=None), cls)
73 |
74 |
75 | def instantiate_from_config(config):
76 | if "target" not in config:
77 | raise KeyError("Expected key `target` to instantiate.")
78 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
79 |
80 |
81 | class VQGanVAETaming(nn.Module):
82 | def __init__(self, vqgan_model_path=None, vqgan_config_path=None, accelerator: Accelerator = None):
83 | super().__init__()
84 | if accelerator is None:
85 | accelerator = Accelerator()
86 |
87 | # Download model if needed
88 | if vqgan_model_path is None:
89 | CACHE_PATH.mkdir(parents=True, exist_ok=True)
90 | model_filename = "vqgan.1024.model.ckpt"
91 | config_filename = "vqgan.1024.config.yml"
92 | with accelerator.local_main_process_first():
93 | config_path = download(VQGAN_VAE_CONFIG_PATH, config_filename)
94 | model_path = download(VQGAN_VAE_PATH, model_filename)
95 | else:
96 | config_path = Path(vqgan_config_path)
97 | model_path = Path(vqgan_model_path)
98 |
99 | with accelerator.local_main_process_first():
100 | config: DictConfig = OmegaConf.load(config_path)
101 | model: VQModel = instantiate_from_config(config["model"])
102 | state = torch.load(model_path, map_location="cpu")["state_dict"]
103 | model.load_state_dict(state, strict=False)
104 |
105 | print(f"Loaded VQGAN from {model_path} and {config_path}")
106 | self.model: VQModel = model
107 |
108 | # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models
109 | f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0]
110 | self.num_layers = int(log(f) / log(2))
111 | self.channels = 3
112 | self.image_size = 256
113 | self.num_tokens = config.model.params.n_embed
114 | self.is_gumbel = False # isinstance(self.model, GumbelVQ)
115 | self.codebook_size = config["model"]["params"]["n_embed"]
116 |
117 | @torch.no_grad()
118 | def get_codebook_indices(self, img):
119 | b = img.shape[0]
120 | img = (2 * img) - 1
121 | _, _, [_, _, indices] = self.model.encode(img)
122 | if self.is_gumbel:
123 | return rearrange(indices, "b h w -> b (h w)", b=b)
124 | return rearrange(indices, "(b n) -> b n", b=b)
125 |
126 | def get_encoded_fmap_size(self, image_size):
127 | return image_size // (2**self.num_layers)
128 |
129 | def decode_from_ids(self, img_seq):
130 | img_seq = rearrange(img_seq, "b h w -> b (h w)")
131 | b, n = img_seq.shape
132 | one_hot_indices = F.one_hot(img_seq, num_classes=self.num_tokens).float()
133 | z = (
134 | one_hot_indices @ self.model.quantize.embed.weight
135 | if self.is_gumbel
136 | else (one_hot_indices @ self.model.quantize.embedding.weight)
137 | )
138 |
139 | z = rearrange(z, "b (h w) c -> b c h w", h=int(sqrt(n)))
140 | img = self.model.decode(z)
141 |
142 | img = (img.clamp(-1.0, 1.0) + 1) * 0.5
143 | return img
144 |
145 | def encode(self, im_seq):
146 | # encode output
147 | # fmap, loss, (perplexity, min_encodings, min_encodings_indices) = self.model.encode(im_seq)
148 | fmap, loss, (_, _, min_encodings_indices) = self.model.encode(im_seq)
149 |
150 | b, _, h, w = fmap.shape
151 | min_encodings_indices = rearrange(min_encodings_indices, "(b h w) -> b h w", h=h, w=w, b=b)
152 | return fmap, min_encodings_indices, loss
153 |
154 | def decode_ids(self, ids):
155 | return self.model.decode_code(ids)
156 |
157 | def copy_for_eval(self):
158 | device = next(self.parameters()).device
159 | vae_copy = copy.deepcopy(self.cpu())
160 |
161 | vae_copy.eval()
162 | return vae_copy.to(device)
163 |
164 | def forward(self, img):
165 | raise NotImplementedError("Forward not implemented for Taming VAE")
166 |
--------------------------------------------------------------------------------
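A hedged sketch of the intended call pattern for `VQGanVAETaming`. The default constructor downloads the pretrained checkpoint and config into `~/.cache/taming` on first use, and the exact index shapes depend on the taming-transformers fork pinned in setup.py; pass `vqgan_model_path` / `vqgan_config_path` to use local weights instead.

# Sketch: tokenize a 256x256 RGB image with the pretrained Taming VQGAN and decode it back.
import torch
from muse_maskgit_pytorch.vqgan_vae_taming import VQGanVAETaming

vae = VQGanVAETaming()  # downloads weights on first run; or pass local paths

img = torch.rand(1, 3, 256, 256)         # values in [0, 1]
indices = vae.get_codebook_indices(img)  # (b, n) codebook token ids
side = vae.get_encoded_fmap_size(256)    # spatial side of the token grid
recon = vae.decode_from_ids(indices.view(1, side, side))
print(indices.shape, recon.shape)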
/muse_maskgit_pytorch/vqvae/__init__.py:
--------------------------------------------------------------------------------
1 | from .config import VQVAEConfig
2 | from .vqvae import VQVAE
3 |
4 | __all__ = [
5 | "VQVAE",
6 | "VQVAEConfig",
7 | ]
8 |
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/vqvae/config.py:
--------------------------------------------------------------------------------
1 | from pydantic import BaseModel, Field
2 |
3 |
4 | class EncoderConfig(BaseModel):
5 | image_size: int = Field(...)
6 | patch_size: int = Field(...)
7 | dim: int = Field(...)
8 | depth: int = Field(...)
9 | num_head: int = Field(...)
10 | mlp_dim: int = Field(...)
11 | in_channels: int = Field(...)
12 | dim_head: int = Field(...)
13 | dropout: float = Field(...)
14 |
15 |
16 | class DecoderConfig(BaseModel):
17 | image_size: int = Field(...)
18 | patch_size: int = Field(...)
19 | dim: int = Field(...)
20 | depth: int = Field(...)
21 | num_head: int = Field(...)
22 | mlp_dim: int = Field(...)
23 | out_channels: int = Field(...)
24 | dim_head: int = Field(...)
25 | dropout: float = Field(...)
26 |
27 |
28 | class VQVAEConfig(BaseModel):
29 | n_embed: int = Field(...)
30 | embed_dim: int = Field(...)
31 | beta: float = Field(...)
32 | enc: EncoderConfig = Field(...)
33 | dec: DecoderConfig = Field(...)
34 |
35 |
36 | VIT_S_CONFIG = VQVAEConfig(
37 | n_embed=8192,
38 | embed_dim=32,
39 | beta=0.25,
40 | enc=EncoderConfig(
41 | image_size=256,
42 | patch_size=8,
43 | dim=512,
44 | depth=8,
45 | num_head=8,
46 | mlp_dim=2048,
47 | in_channels=3,
48 | dim_head=64,
49 | dropout=0.0,
50 | ),
51 | dec=DecoderConfig(
52 | image_size=256,
53 | patch_size=8,
54 | dim=512,
55 | depth=8,
56 | num_head=8,
57 | mlp_dim=2048,
58 | out_channels=3,
59 | dim_head=64,
60 | dropout=0.0,
61 | ),
62 | )
63 |
--------------------------------------------------------------------------------
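The configs above are plain pydantic models, so presets such as `VIT_S_CONFIG` can be serialized or tweaked before building a model. A small sketch follows; it uses the pydantic v1-style `.dict()` (on pydantic v2 the equivalent is `model_dump()`), and the shallower `depth=4` variant is purely hypothetical.

# Sketch: inspecting and deriving configs from the ViT-S preset above.
from muse_maskgit_pytorch.vqvae.config import VIT_S_CONFIG, VQVAEConfig

cfg = VIT_S_CONFIG.dict()
print(cfg["n_embed"], cfg["enc"]["patch_size"])  # 8192, 8

# hypothetical shallower variant, overriding only encoder/decoder depth
small = VQVAEConfig(
    **{
        **cfg,
        "enc": {**cfg["enc"], "depth": 4},
        "dec": {**cfg["dec"], "depth": 4},
    }
)
print(small.enc.depth, small.dec.depth)  # 4, 4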
/muse_maskgit_pytorch/vqvae/discriminator.py:
--------------------------------------------------------------------------------
1 | import functools
2 |
3 | import torch.nn as nn
4 |
5 |
6 | def weights_init(m):
7 | classname = m.__class__.__name__
8 | if classname.find("Conv") != -1:
9 | nn.init.normal_(m.weight.data, 0.0, 0.02)
10 | elif classname.find("BatchNorm") != -1:
11 | nn.init.normal_(m.weight.data, 1.0, 0.02)
12 | nn.init.constant_(m.bias.data, 0)
13 |
14 |
15 | class NLayerDiscriminator(nn.Module):
16 | """Defines a PatchGAN discriminator"""
17 |
18 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
19 | """Construct a PatchGAN discriminator
20 | Parameters:
21 | input_nc (int) -- the number of channels in input images
22 | ndf (int) -- the number of filters in the last conv layer
23 | n_layers (int) -- the number of conv layers in the discriminator
24 | norm_layer -- normalization layer
25 | """
26 | super().__init__()
27 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
28 | use_bias = norm_layer.func == nn.InstanceNorm2d
29 | else:
30 | use_bias = norm_layer == nn.InstanceNorm2d
31 |
32 | kw = 4
33 | padw = 1
34 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
35 | nf_mult = 1
36 | nf_mult_prev = 1
37 | for n in range(1, n_layers): # gradually increase the number of filters
38 | nf_mult_prev = nf_mult
39 | nf_mult = min(2**n, 8)
40 | sequence += [
41 | nn.Conv2d(
42 | ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias
43 | ),
44 | norm_layer(ndf * nf_mult),
45 | nn.LeakyReLU(0.2, True),
46 | ]
47 |
48 | nf_mult_prev = nf_mult
49 | nf_mult = min(2**n_layers, 8)
50 | sequence += [
51 | nn.Conv2d(
52 | ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias
53 | ),
54 | norm_layer(ndf * nf_mult),
55 | nn.LeakyReLU(0.2, True),
56 | ]
57 |
58 | sequence += [
59 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
60 | ] # output 1 channel prediction map
61 | self.model = nn.Sequential(*sequence)
62 |
63 | self.apply(self.init_func)
64 |
65 | def forward(self, input):
66 | """Standard forward."""
67 | return self.model(input)
68 |
69 | def init_func(self, m): # define the initialization function
70 | init_gain = 0.02
71 | classname = m.__class__.__name__
72 | if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1):
73 | nn.init.normal_(m.weight.data, 0.0, init_gain)
74 | if hasattr(m, "bias") and m.bias is not None:
75 | nn.init.constant_(m.bias.data, 0.0)
76 | elif (
77 | classname.find("BatchNorm2d") != -1
78 | ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies.
79 | nn.init.normal_(m.weight.data, 1.0, init_gain)
80 | nn.init.constant_(m.bias.data, 0.0)
81 |
--------------------------------------------------------------------------------
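A quick sketch of the `NLayerDiscriminator` output: it is a PatchGAN, so the forward pass returns one logit per image patch rather than a single scalar. The shapes below assume the defaults (`ndf=64`, `n_layers=3`) and a 256x256 RGB input.

# Sketch: the PatchGAN discriminator maps an image to a grid of per-patch logits.
import torch
from muse_maskgit_pytorch.vqvae.discriminator import NLayerDiscriminator

discr = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3)
logits = discr(torch.randn(1, 3, 256, 256))
print(logits.shape)  # (1, 1, 30, 30) -- one logit per receptive-field patch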
/muse_maskgit_pytorch/vqvae/layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from diffusers.utils import is_xformers_available
3 | from einops import rearrange
4 | from einops.layers.torch import Rearrange
5 | from torch import nn
6 |
7 | from muse_maskgit_pytorch.modules import CrossAttention, MemoryEfficientCrossAttention, SwiGLUFFNFused
8 |
9 |
10 | def pair(t):
11 | return t if isinstance(t, tuple) else (t, t)
12 |
13 |
14 | class FeedForward(nn.Module):
15 | def __init__(self, dim, mlp_dim, dropout=0.0):
16 | super().__init__()
17 | self.w_1 = nn.Linear(dim, mlp_dim)
18 | self.act = nn.GELU()
19 | self.dropout = nn.Dropout(p=dropout)
20 | self.w_2 = nn.Linear(mlp_dim, dim)
21 |
22 | def forward(self, x):
23 | x = self.w_1(x)
24 | x = self.act(x)
25 | x = self.dropout(x)
26 | x = self.w_2(x)
27 |
28 | return x
29 |
30 |
31 | class LayerScale(nn.Module):
32 | def __init__(self, dim, init_values=1e-5, inplace=False):
33 | super().__init__()
34 | self.inplace = inplace
35 | self.gamma = nn.Parameter(init_values * torch.ones(dim))
36 |
37 | def forward(self, x):
38 | return x.mul_(self.gamma) if self.inplace else x * self.gamma
39 |
40 |
41 | class Layer(nn.Module):
42 | ATTENTION_MODES = {"vanilla": CrossAttention, "xformer": MemoryEfficientCrossAttention}
43 |
44 | def __init__(self, dim, dim_head, mlp_dim, num_head=8, dropout=0.0):
45 | super().__init__()
46 | attn_mode = "xformer" if is_xformers_available() else "vanilla"
47 | attn_cls = self.ATTENTION_MODES[attn_mode]
48 | self.norm1 = nn.LayerNorm(dim)
49 | self.attn1 = attn_cls(query_dim=dim, heads=num_head, dim_head=dim_head, dropout=dropout)
50 | self.norm2 = nn.LayerNorm(dim)
51 | self.ffnet = SwiGLUFFNFused(in_features=dim, hidden_features=mlp_dim)
52 |
53 | def forward(self, x):
54 | x = self.attn1(self.norm1(x)) + x
55 | x = self.ffnet(self.norm2(x)) + x
56 |
57 | return x
58 |
59 |
60 | class Transformer(nn.Module):
61 | def __init__(self, dim, depth, num_head, dim_head, mlp_dim, dropout=0.0):
62 | super().__init__()
63 | self.layers = nn.Sequential(*[Layer(dim, dim_head, mlp_dim, num_head, dropout) for i in range(depth)])
64 |
65 | def forward(self, x):
66 | x = self.layers(x)
67 |
68 | return x
69 |
70 |
71 | class Encoder(nn.Module):
72 | def __init__(
73 | self,
74 | image_size,
75 | patch_size,
76 | dim,
77 | depth,
78 | num_head,
79 | mlp_dim,
80 | in_channels=3,
81 | dim_head=64,
82 | dropout=0.0,
83 | ):
84 | super().__init__()
85 |
86 | self.image_size = image_size
87 | self.patch_size = patch_size
88 |
89 | assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
90 |
91 | self.to_patch_embedding = nn.Sequential(
92 | nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=False),
93 | Rearrange("b c h w -> b (h w) c"),
94 | )
95 |
96 | scale = dim**-0.5
97 | num_patches = (image_size // patch_size) ** 2
98 | self.position_embedding = nn.Parameter(torch.randn(1, num_patches, dim) * scale)
99 | self.norm_pre = nn.LayerNorm(dim)
100 | self.transformer = Transformer(dim, depth, num_head, dim_head, mlp_dim, dropout)
101 |
102 | self.initialize_weights()
103 |
104 | def initialize_weights(self):
105 | self.apply(self._init_weights)
106 |
107 | def _init_weights(self, m):
108 | if isinstance(m, nn.Linear):
109 | torch.nn.init.xavier_uniform_(m.weight)
110 | if isinstance(m, nn.Linear) and m.bias is not None:
111 | nn.init.constant_(m.bias, 0)
112 | elif isinstance(m, nn.LayerNorm):
113 | nn.init.constant_(m.bias, 0)
114 | nn.init.constant_(m.weight, 1.0)
115 |
116 | def forward(self, x):
117 | x = self.to_patch_embedding(x)
118 | x = x + self.position_embedding
119 | x = self.norm_pre(x)
120 | x = self.transformer(x)
121 |
122 | return x
123 |
124 |
125 | class Decoder(nn.Module):
126 | def __init__(
127 | self,
128 | image_size,
129 | patch_size,
130 | dim,
131 | depth,
132 | num_head,
133 | mlp_dim,
134 | out_channels=3,
135 | dim_head=64,
136 | dropout=0.0,
137 | ):
138 | super().__init__()
139 |
140 | self.image_size = image_size
141 | self.patch_size = patch_size
142 |
143 | assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size."
144 |
145 | scale = dim**-0.5
146 | num_patches = (image_size // patch_size) ** 2
147 | self.position_embedding = nn.Parameter(torch.randn(1, num_patches, dim) * scale)
148 | self.transformer = Transformer(dim, depth, num_head, dim_head, mlp_dim, dropout)
149 | self.norm = nn.LayerNorm(dim)
150 | self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True)
151 |
152 | self.initialize_weights()
153 |
154 | def initialize_weights(self):
155 | self.apply(self._init_weights)
156 |
157 | def _init_weights(self, m):
158 | if isinstance(m, nn.Linear):
159 | torch.nn.init.xavier_uniform_(m.weight)
160 | if isinstance(m, nn.Linear) and m.bias is not None:
161 | nn.init.constant_(m.bias, 0)
162 | elif isinstance(m, nn.LayerNorm):
163 | nn.init.constant_(m.bias, 0)
164 | nn.init.constant_(m.weight, 1.0)
165 |
166 | def forward(self, x):
167 | x = x + self.position_embedding
168 | x = self.transformer(x)
169 | x = self.norm(x)
170 | x = self.proj(x)
171 | x = rearrange(
172 | x,
173 | "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
174 | h=self.image_size // self.patch_size,
175 | p1=self.patch_size,
176 | p2=self.patch_size,
177 | )
178 |
179 | return x
180 |
--------------------------------------------------------------------------------
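A shape-check sketch for the ViT `Encoder` / `Decoder` above. The hyperparameters follow the ViT-S preset except for `depth=2`, reduced here only to keep the example light; with `patch_size=8` a 256x256 image becomes a sequence of 1024 patch tokens.

# Sketch: patch-token shapes through the ViT encoder/decoder defined above.
import torch
from muse_maskgit_pytorch.vqvae.layers import Decoder, Encoder

enc = Encoder(image_size=256, patch_size=8, dim=512, depth=2, num_head=8, mlp_dim=2048)
dec = Decoder(image_size=256, patch_size=8, dim=512, depth=2, num_head=8, mlp_dim=2048)

img = torch.randn(1, 3, 256, 256)
tokens = enc(img)    # (1, 1024, 512): (256 / 8) ** 2 patch tokens of width `dim`
recon = dec(tokens)  # (1, 3, 256, 256)
print(tokens.shape, recon.shape)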
/muse_maskgit_pytorch/vqvae/quantize.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class VectorQuantize(nn.Module):
7 | def __init__(self, n_e, vq_embed_dim, beta=0.25):
8 | super().__init__()
9 | self.n_e = n_e
10 | self.vq_embed_dim = vq_embed_dim
11 | self.beta = beta
12 |
13 | self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
14 | self.embedding.weight.data.normal_()
15 |
16 | def forward(self, z):
17 | z = F.normalize(z, p=2, dim=-1)
18 | z_flattened = z.view(-1, self.vq_embed_dim)
19 | embed_norm = F.normalize(self.embedding.weight, p=2, dim=-1)
20 |
21 | # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
22 | d = (
23 | torch.sum(z_flattened**2, dim=1, keepdim=True)
24 | + torch.sum(embed_norm**2, dim=1)
25 | - 2 * torch.einsum("bd,nd->bn", z_flattened, embed_norm)
26 | )
27 |
28 | encoding_indices = torch.argmin(d, dim=1).view(*z.shape[:-1])
29 | z_q = self.embedding(encoding_indices).view(z.shape)
30 | z_q = F.normalize(z_q, p=2, dim=-1)
31 |
32 | # compute loss for embedding
33 | loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2)
34 |
35 | # preserve gradients
36 | z_q = z + (z_q - z).detach()
37 |
38 | return z_q, loss, encoding_indices
39 |
40 | def decode_ids(self, indices):
41 | z_q = self.embedding(indices)
42 | z_q = F.normalize(z_q, p=2, dim=-1)
43 |
44 | return z_q
45 |
--------------------------------------------------------------------------------
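A small sketch of the l2-normalized `VectorQuantize` above: the forward pass snaps each token embedding to its nearest codebook entry, returns the commitment/codebook loss, and uses a straight-through estimator so gradients still reach the encoder output. Sizes match the ViT-S preset (8192 codes of dimension 32).

# Sketch: quantizing a batch of token embeddings with the module defined above.
import torch
from muse_maskgit_pytorch.vqvae.quantize import VectorQuantize

vq = VectorQuantize(n_e=8192, vq_embed_dim=32, beta=0.25)

z = torch.randn(1, 1024, 32, requires_grad=True)  # (batch, tokens, embed_dim)
z_q, loss, indices = vq(z)
print(z_q.shape, indices.shape, float(loss))      # (1, 1024, 32), (1, 1024), scalar loss

# straight-through estimator: z_q carries gradients back to z
assert z_q.requires_grad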
/muse_maskgit_pytorch/vqvae/vqvae.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | import torch
4 | import torch.nn as nn
5 | from diffusers import ConfigMixin, ModelMixin
6 | from diffusers.configuration_utils import register_to_config
7 |
8 | from .layers import Decoder, Encoder
9 | from .quantize import VectorQuantize
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class VQVAE(ModelMixin, ConfigMixin):
15 | @register_to_config
16 | def __init__(self, n_embed, embed_dim, beta, enc, dec, **kwargs):
17 | super().__init__()
18 | self.encoder = Encoder(**enc)
19 | self.decoder = Decoder(**dec)
20 |
21 | self.prev_quant = nn.Linear(enc["dim"], embed_dim)
22 | self.quantize = VectorQuantize(n_embed, embed_dim, beta)
23 | self.post_quant = nn.Linear(embed_dim, dec["dim"])
24 |
25 | def freeze(self):
26 | self.eval()
27 | self.requires_grad_(False)
28 |
29 | def encode(self, x):
30 | x = self.encoder(x)
31 | x = self.prev_quant(x)
32 | x, loss, indices = self.quantize(x)
33 | return x, loss, indices
34 |
35 | def decode(self, x):
36 | x = self.post_quant(x)
37 | x = self.decoder(x)
38 | return x.clamp(-1.0, 1.0)
39 |
40 | def forward(self, inputs: torch.FloatTensor):
41 | z, loss, _ = self.encode(inputs)
42 | rec = self.decode(z)
43 | return rec, loss
44 |
45 | def encode_to_ids(self, inputs):
46 | _, _, indices = self.encode(inputs)
47 | return indices
48 |
49 | def decode_from_ids(self, indice):
50 | z_q = self.quantize.decode_ids(indice)
51 | img = self.decode(z_q)
52 | return img
53 |
54 | def __call__(self, inputs: torch.FloatTensor):
55 | return self.forward(inputs)
56 |
--------------------------------------------------------------------------------
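A round-trip sketch for the full `VQVAE` above, built directly from the `VIT_S_CONFIG` preset (pydantic v1-style `.dict()`). It instantiates a randomly initialized model, so the reconstruction itself is meaningless; the point is only to demonstrate the `encode_to_ids` / `decode_from_ids` API and the expected shapes.

# Sketch: ids-in / image-out round trip with a freshly initialized VQVAE.
import torch
from muse_maskgit_pytorch.vqvae import VQVAE
from muse_maskgit_pytorch.vqvae.config import VIT_S_CONFIG

vae = VQVAE(**VIT_S_CONFIG.dict())
vae.freeze()  # eval mode, no gradients

img = torch.rand(1, 3, 256, 256) * 2 - 1  # decoder output is clamped to [-1, 1]
ids = vae.encode_to_ids(img)              # (1, 1024) codebook indices
recon = vae.decode_from_ids(ids)          # (1, 3, 256, 256)
print(ids.shape, recon.shape)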
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | build-backend = "setuptools.build_meta"
3 | requires = ["setuptools>=61.0.0", "wheel", "setuptools_scm[toml]>=6.2"]
4 |
5 | [tool.setuptools_scm]
6 | write_to = "muse_maskgit_pytorch/_version.py"
7 |
8 | [tool.black]
9 | line-length = 110
10 | target-version = ['py38', 'py39', 'py310']
11 |
12 | [tool.ruff]
13 | line-length = 110
14 | target-version = 'py38'
15 | format = "grouped"
16 | ignore-init-module-imports = true
17 | select = ["E", "F", "I"]
18 | ignore = ['F841', 'F401', 'E501']
19 |
20 | [tool.ruff.isort]
21 | combine-as-imports = true
22 | force-wrap-aliases = true
23 | known-local-folder = ["muse_maskgit_pytorch"]
24 | known-first-party = ["muse_maskgit_pytorch"]
25 |
--------------------------------------------------------------------------------
/scripts/vqvae_test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from pathlib import Path
3 |
4 | import torch
5 | from huggingface_hub import hf_hub_download
6 | from PIL import Image
7 | from torchvision import transforms as T
8 | from torchvision.utils import save_image
9 |
10 | from muse_maskgit_pytorch.vqvae import VQVAE
11 |
12 | logging.basicConfig(level=logging.INFO)
13 | logger = logging.getLogger(__name__)
14 |
15 | # where to find the model and the test images
16 | model_repo = "neggles/vaedump"
17 | model_subdir = "vit-s-vqgan-f4"
18 | test_images = ["testimg_1.png", "testimg_2.png"]
19 |
20 | # where to save the preprocessed and reconstructed images
21 | image_dir = Path.cwd().joinpath("temp")
22 | image_dir.mkdir(exist_ok=True, parents=True)
23 |
24 | # image transforms for the VQVAE
25 | transform_enc = T.Compose([T.Resize(512), T.RandomCrop(256), T.ToTensor()])
26 | transform_dec = T.Compose([T.ConvertImageDtype(torch.uint8), T.ToPILImage()])
27 |
28 |
29 | def get_save_path(path: Path, append: str) -> Path:
30 | # append a string to the filename before the extension
31 | # n.b. only keeps the final suffix, e.g. "foo.xyz.png" -> "foo-prepro.png"
32 | return path.with_name(f"{path.stem}-{append}{path.suffix}")
33 |
34 |
35 | def main():
36 | torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
37 | dtype = torch.float32
38 |
39 | # load VAE
40 | logger.info(f"Loading VQVAE from {model_repo}/{model_subdir}...")
41 | vae: VQVAE = VQVAE.from_pretrained(model_repo, subfolder=model_subdir, torch_dtype=dtype)
42 | vae = vae.to(torch_device)
43 | logger.info(f"Loaded VQVAE from {model_repo} to {vae.device} with dtype {vae.dtype}")
44 |
45 | # download and process images
46 | for image in test_images:
47 | image_path = hf_hub_download(model_repo, subfolder="images", filename=image, local_dir=image_dir)
48 | image_path = Path(image_path)
49 | logger.info(f"Downloaded {image_path}, size {image_path.stat().st_size} bytes")
50 |
51 | # preprocess
52 | image_obj = Image.open(image_path).convert("RGB")
53 | image_tensor: torch.Tensor = transform_enc(image_obj)
54 | save_path = get_save_path(image_path, "prepro")
55 |         save_image(image_tensor, save_path, normalize=True, value_range=(-1.0, 1.0))
56 | logger.info(f"Saved preprocessed image to {save_path}")
57 |
58 | # encode
59 | encoded, _, _ = vae.encode(image_tensor.unsqueeze(0).to(vae.device))
60 |
61 | # decode
62 | reconstructed = vae.decode(encoded).squeeze(0)
63 | reconstructed = torch.clamp(reconstructed, -1.0, 1.0)
64 |
65 | # save
66 | save_path = get_save_path(image_path, "recon")
67 |         save_image(reconstructed, save_path, normalize=True, value_range=(-1.0, 1.0))
68 | logger.info(f"Saved reconstructed image to {save_path}")
69 |
70 | # compare
71 | image_prepro = transform_dec(image_tensor)
72 | image_recon = transform_dec(reconstructed)
73 | canvas = Image.new("RGB", (512 + 12, 256 + 8), (248, 248, 242))
74 | canvas.paste(image_prepro, (4, 4))
75 | canvas.paste(image_recon, (256 + 8, 4))
76 | save_path = get_save_path(image_path, "compare")
77 | canvas.save(save_path)
78 | logger.info(f"Saved comparison image to {save_path}")
79 |
80 | logger.info("Done!")
81 |
82 |
83 | if __name__ == "__main__":
84 | main()
85 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import find_packages, setup
2 |
3 | setup(
4 | name="muse-maskgit-pytorch",
5 | packages=find_packages(exclude=[]),
6 | version="0.1.6",
7 | license="MIT",
8 | description="MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch",
9 | author="Phil Wang",
10 | author_email="lucidrains@gmail.com",
11 | long_description_content_type="text/markdown",
12 | url="https://github.com/lucidrains/muse-maskgit-pytorch",
13 | keywords=[
14 | "artificial intelligence",
15 | "deep learning",
16 | "transformers",
17 | "attention mechanism",
18 | "text-to-image",
19 | ],
20 | extras_require={
21 | "dev": [
22 | "pre-commit>=3.3.2",
23 | "black>=23.3.0",
24 | "ruff>=0.0.272",
25 | ]
26 | },
27 | install_requires=[
28 | "accelerate",
29 | "diffusers",
30 | "datasets",
31 | "beartype",
32 | "einops>=0.6",
33 | "ema-pytorch",
34 | "omegaconf>=2.3.0",
35 | "pillow",
36 | "sentencepiece",
37 | "torch>=2.0",
38 | "torchmetrics<0.8.0",
39 | "pytorch-lightning>=2.0.0",
40 | "taming-transformers @ git+https://github.com/neggles/taming-transformers.git@v0.0.2",
41 | "transformers",
42 | "torchvision",
43 | "torch_optimizer",
44 | "tqdm",
45 | "timm",
46 | "tqdm-loggable",
47 | "vector-quantize-pytorch>=0.10.14",
48 | "lion-pytorch",
49 | "omegaconf",
50 | "xformers>=0.0.20",
51 | ],
52 | classifiers=[
53 | "Development Status :: 4 - Beta",
54 | "Intended Audience :: Developers",
55 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
56 | "License :: OSI Approved :: MIT License",
57 |         "Programming Language :: Python :: 3.8",
58 | ],
59 | )
60 |
--------------------------------------------------------------------------------
/tpu-vm.env:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | # dot-source this, ok?
3 |
4 | export PYTHONUNBUFFERED='1'
5 |
6 | ## General log level opts for Accelerate/Transformers
7 | #export ACCELERATE_LOG_LEVEL='INFO'
8 | #export TRANSFORMERS_LOG_LEVEL='INFO'
9 |
10 | # tcmalloc breaks things and Google enables it by default, so that's gotta go
11 | unset LD_PRELOAD
12 |
13 | # add the dir where `libtpu-nightly` puts the library to LD_LIBRARY_PATH
14 | export LD_LIBRARY_PATH="/usr/local/lib/python3.8/dist-packages/libtpu/:${LD_LIBRARY_PATH}"
15 |
16 | # PJRT doesn't work with Accelerate yet so we deconfigure it and go back to old XRT
17 | unset PJRT_DEVICE
18 | export XRT_TPU_CONFIG='localservice;0;localhost:51011'
19 | export MASTER_ADDR='localhost'
20 | export MASTER_PORT='12355'
21 |
22 | ## see https://github.com/pytorch/xla/issues/4914
23 | export XLA_IR_SHAPE_CACHE_SIZE=12288
24 |
25 | ## useful options for debug
26 | #export PT_XLA_DEBUG=1
27 | # Enables capturing the Python stack trace when creating IR nodes, allowing you to see which PyTorch operation was responsible for generating the IR.
28 | #export XLA_IR_DEBUG=1
29 | # Path to save the IR graphs generated during execution.
30 | #export XLA_SAVE_TENSORS_FILE=''
31 | # File format for the above; can be text, dot (GraphViz), or hlo (native)
32 | #export XLA_SAVE_TENSORS_FMT='text'
33 | # Path to save metrics after every op
34 | #export XLA_METRICS_FILE=
35 | # In case of compilation/execution error, the offending HLO graph will be saved here.
36 | #export XLA_SAVE_HLO_FILE=
37 |
38 | # Enable OpByOp dispatch for "get tensors"
39 | #export XLA_GET_TENSORS_OPBYOP=1
40 | # Enable OpByOp dispatch for "sync tensors"
41 | #export XLA_SYNC_TENSORS_OPBYOP=1
42 | # Force XLA tensor sync before moving to next step
43 | #export XLA_SYNC_WAIT=1
44 |
45 | # Force downcasting of fp32 to bf16
46 | #export XLA_USE_BF16=1
47 | # Force downcasting of fp32 to fp16
48 | #export XLA_USE_F16=1
49 | # Force downcasting of fp64 to fp32
50 | #export XLA_USE_32BIT_LONG=1
51 |
52 | ## TPU runtime / compilation debug logging
53 | # All XLA log messages are INFO level so this is required
54 | #export TF_CPP_MIN_LOG_LEVEL=0
55 | # Print the thread ID in log messages
56 | #export TF_CPP_LOG_THREAD_ID=1
57 | # What modules to print from at what level
58 | #export TF_CPP_VMODULE='tensor=4,computation_client=5,xrt_computation_client=5,aten_xla_type=5'
59 |
60 | ## Limit to single TPU chip/core, can be useful for testing
61 | # export TPU_PROCESS_BOUNDS='1,1,1'
62 | # export TPU_VISIBLE_CHIPS=0
63 |
--------------------------------------------------------------------------------
/train_muse_vae.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import glob
3 | import os
4 | import re
5 | from dataclasses import dataclass
6 | from typing import Optional, Union
7 |
8 | import wandb
9 | from accelerate.utils import ProjectConfiguration
10 | from datasets import load_dataset
11 | from omegaconf import OmegaConf
12 |
13 | from muse_maskgit_pytorch import (
14 | VQGanVAE,
15 | VQGanVAETaming,
16 | VQGanVAETrainer,
17 | get_accelerator,
18 | )
19 | from muse_maskgit_pytorch.dataset import (
20 | ImageDataset,
21 | get_dataset_from_dataroot,
22 | split_dataset_into_dataloaders,
23 | )
24 |
25 | # disable bitsandbytes welcome message.
26 | os.environ["BITSANDBYTES_NOWELCOME"] = "1"
27 |
28 | parser = argparse.ArgumentParser()
29 | parser.add_argument("--webdataset", type=str, default=None, help="Path to webdataset if using one.")
30 | parser.add_argument(
31 | "--only_save_last_checkpoint",
32 | action="store_true",
33 | help="Only save last checkpoint.",
34 | )
35 | parser.add_argument(
36 | "--validation_image_scale",
37 | default=1,
38 | type=float,
39 | help="Factor by which to scale the validation images.",
40 | )
41 | parser.add_argument(
42 | "--no_center_crop",
43 | action="store_true",
44 | help="Don't do center crop.",
45 | )
46 | parser.add_argument(
47 | "--no_flip",
48 | action="store_true",
49 | help="Don't flip image.",
50 | )
51 | parser.add_argument(
52 | "--random_crop",
53 | action="store_true",
54 | help="Crop the images at random locations instead of cropping from the center.",
55 | )
56 | parser.add_argument(
57 | "--dataset_save_path",
58 | type=str,
59 | default="dataset",
60 | help="Path to save the dataset if you are making one from a directory",
61 | )
62 | parser.add_argument(
63 | "--clear_previous_experiments",
64 | action="store_true",
65 | help="Whether to clear previous experiments.",
66 | )
67 | parser.add_argument("--max_grad_norm", type=float, default=None, help="Max gradient norm.")
68 | parser.add_argument(
69 | "--discr_max_grad_norm",
70 | type=float,
71 | default=None,
72 | help="Max gradient norm for discriminator.",
73 | )
74 | parser.add_argument("--seed", type=int, default=42, help="Seed.")
75 | parser.add_argument("--valid_frac", type=float, default=0.05, help="validation fraction.")
76 | parser.add_argument("--use_ema", action="store_true", help="Whether to use ema.")
77 | parser.add_argument("--ema_beta", type=float, default=0.995, help="Ema beta.")
78 | parser.add_argument("--ema_update_after_step", type=int, default=1, help="Ema update after step.")
79 | parser.add_argument(
80 | "--ema_update_every",
81 | type=int,
82 | default=1,
83 | help="Ema update every this number of steps.",
84 | )
85 | parser.add_argument(
86 | "--apply_grad_penalty_every",
87 | type=int,
88 | default=4,
89 | help="Apply gradient penalty every this number of steps.",
90 | )
91 | parser.add_argument(
92 | "--image_column",
93 | type=str,
94 | default="image",
95 | help="The column of the dataset containing an image.",
96 | )
97 | parser.add_argument(
98 | "--caption_column",
99 | type=str,
100 | default="caption",
101 | help="The column of the dataset containing a caption or a list of captions.",
102 | )
103 | parser.add_argument(
104 | "--log_with",
105 | type=str,
106 | default="wandb",
107 | help=(
108 |         'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
109 |         ' `"wandb"` (default) and `"comet_ml"`. Use `"all"` to report to all integrations.'
110 | ),
111 | )
112 | parser.add_argument(
113 | "--project_name",
114 | type=str,
115 | default="muse_vae",
116 | help=("Name to use for the project to identify it when saved to a tracker such as wandb or tensorboard."),
117 | )
118 | parser.add_argument(
119 | "--run_name",
120 | type=str,
121 | default=None,
122 | help=(
123 | "Name to use for the run to identify it when saved to a tracker such"
124 | " as wandb or tensorboard. If not specified a random one will be generated."
125 | ),
126 | )
127 | parser.add_argument(
128 | "--wandb_user",
129 | type=str,
130 | default=None,
131 | help=(
132 |         "Specify the name of the user or organization under which the project will be saved when using wandb."
133 | ),
134 | )
135 | parser.add_argument(
136 | "--mixed_precision",
137 | type=str,
138 | default="no",
139 | choices=["no", "fp8", "fp16", "bf16"],
140 | help="Precision to train on.",
141 | )
142 | parser.add_argument(
143 | "--use_8bit_adam",
144 | action="store_true",
145 | help="Whether to use the 8bit adam optimiser",
146 | )
147 | parser.add_argument(
148 | "--results_dir",
149 | type=str,
150 | default="results",
151 | help="Path to save the training samples and checkpoints",
152 | )
153 | parser.add_argument(
154 | "--logging_dir",
155 | type=str,
156 | default=None,
157 | help="Path to log the losses and LR",
158 | )
159 |
160 | # vae_trainer args
161 | parser.add_argument(
162 | "--dataset_name",
163 | type=str,
164 | default=None,
165 | help="Name of the huggingface dataset used.",
166 | )
167 | parser.add_argument(
168 | "--hf_split_name",
169 | type=str,
170 | default="train",
171 |     help="Subset or split to use from the dataset when using a dataset from HuggingFace.",
172 | )
173 | parser.add_argument(
174 | "--streaming",
175 | action="store_true",
176 | help="Whether to stream the huggingface dataset",
177 | )
178 | parser.add_argument(
179 | "--train_data_dir",
180 | type=str,
181 | default=None,
182 | help="Dataset folder where your input images for training are.",
183 | )
184 | parser.add_argument(
185 | "--num_train_steps",
186 | type=int,
187 | default=-1,
188 |     help="Total number of steps to train for, e.g. 50000. Use only if you want to stop training early.",
189 | )
190 | parser.add_argument(
191 | "--num_epochs",
192 | type=int,
193 | default=5,
194 |     help="Total number of epochs to train for, e.g. 5.",
195 | )
196 | parser.add_argument("--dim", type=int, default=128, help="Model dimension.")
197 | parser.add_argument("--batch_size", type=int, default=1, help="Batch Size.")
198 | parser.add_argument("--lr", type=float, default=1e-5, help="Learning Rate.")
199 | parser.add_argument(
200 | "--gradient_accumulation_steps",
201 | type=int,
202 | default=1,
203 | help="Gradient Accumulation.",
204 | )
205 | parser.add_argument(
206 | "--save_results_every",
207 | type=int,
208 | default=100,
209 | help="Save results every this number of steps.",
210 | )
211 | parser.add_argument(
212 | "--save_model_every",
213 | type=int,
214 | default=500,
215 | help="Save the model every this number of steps.",
216 | )
217 | parser.add_argument(
218 | "--checkpoint_limit",
219 | type=int,
220 | default=None,
221 | help="Keep only X number of checkpoints and delete the older ones.",
222 | )
223 | parser.add_argument("--vq_codebook_size", type=int, default=256, help="VQ codebook size.")
224 | parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.")
225 | parser.add_argument(
226 | "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA."
227 | )
228 | parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.")
229 | parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.")
230 | parser.add_argument(
231 | "--image_size",
232 | type=int,
233 | default=256,
234 |         help="Image size. You may want to start with small images and then curriculum-learn to larger ones; because the VAE is fully convolutional, it should generalize to 512 (as in the paper) without being trained on it.",
235 | )
236 | parser.add_argument(
237 | "--lr_scheduler",
238 | type=str,
239 | default="constant",
240 | help='The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]',
241 | )
242 | parser.add_argument(
243 | "--scheduler_power",
244 | type=float,
245 | default=1.0,
246 |         help="Controls the power of the polynomial decay schedule. "
247 | "It determines the rate at which the learning rate decreases during the schedule.",
248 | )
249 | parser.add_argument(
250 | "--lr_warmup_steps",
251 | type=int,
252 | default=0,
253 | help="Number of steps for the warmup in the lr scheduler.",
254 | )
255 | parser.add_argument(
256 | "--num_cycles",
257 | type=int,
258 | default=1,
259 | help="Number of cycles for the lr scheduler.",
260 | )
261 | parser.add_argument(
262 | "--resume_path",
263 | type=str,
264 | default=None,
265 |         help="Path to the last saved checkpoint, e.g. 'results/vae.steps.pt'.",
266 | )
267 | parser.add_argument(
268 | "--weight_decay",
269 | type=float,
270 | default=0.0,
271 | help="Optimizer weight_decay to use. Default: 0.0",
272 | )
273 | parser.add_argument(
274 | "--taming_model_path",
275 | type=str,
276 | default=None,
277 |         help="Path to your trained VQGAN weights. This should be a .ckpt file (only valid when the taming option is enabled).",
278 | )
279 | parser.add_argument(
280 | "--taming_config_path",
281 | type=str,
282 | default=None,
283 |         help="Path to your trained VQGAN config. This should be a .yaml file (only valid when the taming option is enabled).",
284 | )
285 | parser.add_argument(
286 | "--optimizer",
287 | type=str,
288 | default="Lion",
289 |         help="Optimizer to use. Choose between: ['Adam', 'AdamW', 'Lion']. Default: Lion",
290 | )
291 | parser.add_argument(
292 | "--cache_path",
293 | type=str,
294 | default=None,
295 | help="The path to cache huggingface models",
296 | )
297 | parser.add_argument(
298 | "--no_cache",
299 | action="store_true",
300 | help="Do not save the dataset pyarrow cache/files to disk to save disk space and reduce the time it takes to launch the training.",
301 | )
302 | parser.add_argument(
303 | "--latest_checkpoint",
304 | action="store_true",
305 | help="Whether to use the latest checkpoint",
306 | )
307 | parser.add_argument(
308 | "--do_not_save_config",
309 | action="store_true",
310 | default=False,
311 |         help="Do not save an example YAML configuration file for the run.",
312 | )
313 | parser.add_argument(
314 | "--use_l2_recon_loss",
315 | action="store_true",
316 | help="Use F.mse_loss instead of F.l1_loss.",
317 | )
318 |
319 |
320 | @dataclass
321 | class Arguments:
322 | total_params: Optional[int] = None
323 | only_save_last_checkpoint: bool = False
324 | validation_image_scale: float = 1.0
325 | no_center_crop: bool = False
326 | no_flip: bool = False
327 | random_crop: bool = False
328 | dataset_save_path: Optional[str] = None
329 | clear_previous_experiments: bool = False
330 | max_grad_norm: Optional[float] = None
331 | discr_max_grad_norm: Optional[float] = None
332 | num_tokens: int = 256
333 | seq_len: int = 1024
334 | seed: int = 42
335 | valid_frac: float = 0.05
336 | use_ema: bool = False
337 | ema_beta: float = 0.995
338 | ema_update_after_step: int = 1
339 | ema_update_every: int = 1
340 | apply_grad_penalty_every: int = 4
341 | image_column: str = "image"
342 | caption_column: str = "caption"
343 | log_with: str = "wandb"
344 | mixed_precision: str = "no"
345 | use_8bit_adam: bool = False
346 | results_dir: str = "results"
347 | logging_dir: Optional[str] = None
348 | resume_path: Optional[str] = None
349 | dataset_name: Optional[str] = None
350 | streaming: bool = False
351 | train_data_dir: Optional[str] = None
352 | num_train_steps: int = -1
353 | num_epochs: int = 5
354 | dim: int = 128
355 | batch_size: int = 512
356 | lr: float = 1e-5
357 | gradient_accumulation_steps: int = 1
358 | save_results_every: int = 100
359 | save_model_every: int = 500
360 |     checkpoint_limit: Optional[Union[int, str]] = None
361 | vq_codebook_size: int = 256
362 | vq_codebook_dim: int = 256
363 | cond_drop_prob: float = 0.5
364 | image_size: int = 256
365 | lr_scheduler: str = "constant"
366 | scheduler_power: float = 1.0
367 | lr_warmup_steps: int = 0
368 | num_cycles: int = 1
369 | taming_model_path: Optional[str] = None
370 | taming_config_path: Optional[str] = None
371 | optimizer: str = "Lion"
372 | weight_decay: float = 0.0
373 | cache_path: Optional[str] = None
374 | no_cache: bool = False
375 | latest_checkpoint: bool = False
376 | do_not_save_config: bool = False
377 | use_l2_recon_loss: bool = False
378 | debug: bool = False
379 | config_path: Optional[str] = None
380 |
381 |
382 | def preprocess_webdataset(args, image):
383 | return {args.image_column: image}
384 |
385 |
386 | def main():
387 | args = parser.parse_args(namespace=Arguments())
388 |
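    # If a config file was given via args.config_path, its values overwrite the matching parsed CLI arguments
    # (vars(args) shares the namespace's __dict__, so assigning into it updates args directly).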
389 | if args.config_path:
390 | print("Using config file and ignoring CLI args")
391 |
392 | try:
393 | conf = OmegaConf.load(args.config_path)
394 | conf_keys = conf.keys()
395 | args_to_convert = vars(args)
396 |
397 | for key in conf_keys:
398 | try:
399 | args_to_convert[key] = conf[key]
400 | except KeyError:
401 | print(f"Error parsing config - {key}: {conf[key]} | Using default or parsed")
402 |
403 | except FileNotFoundError:
404 | print("Could not find config, using default and parsed values...")
405 |
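    # Accelerate project configuration: state goes under --logging_dir, or results_dir/logs by default;
    # total_limit enforces --checkpoint_limit on retained checkpoints.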
406 | project_config = ProjectConfiguration(
407 | project_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
408 | total_limit=args.checkpoint_limit,
409 | automatic_checkpoint_naming=True,
410 | )
411 |
412 | accelerator = get_accelerator(
413 | log_with=args.log_with,
414 | gradient_accumulation_steps=args.gradient_accumulation_steps,
415 | mixed_precision=args.mixed_precision,
416 | project_config=project_config,
417 | even_batches=True,
418 | )
419 | if accelerator.is_main_process:
420 | accelerator.init_trackers(
421 | args.project_name,
422 | config=vars(args),
423 | init_kwargs={
424 | "wandb": {
425 | "entity": f"{args.wandb_user or wandb.api.default_entity}",
426 | "name": args.run_name,
427 | },
428 | },
429 | )
430 |
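    # The dataset source is chosen in priority order: a WebDataset (args.webdataset),
    # a local image folder (args.train_data_dir), or a HuggingFace dataset by name (args.dataset_name).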
431 | if args.webdataset is not None:
432 | import webdataset as wds
433 |
434 | dataset = wds.WebDataset(args.webdataset).shuffle(1000).decode("rgb").to_tuple("png")
435 | dataset = dataset.map(lambda image: preprocess_webdataset(args, image))
436 | elif args.train_data_dir:
437 | dataset = get_dataset_from_dataroot(
438 | args.train_data_dir,
439 | image_column=args.image_column,
440 | caption_column=args.caption_column,
441 | save_path=args.dataset_save_path,
442 | save=not args.no_cache,
443 | )
444 | elif args.dataset_name:
445 |         if args.cache_path:
446 |             dataset = load_dataset(args.dataset_name, streaming=args.streaming, cache_dir=args.cache_path)[
447 |                 "train"
448 |             ]
449 |         else:
450 |             dataset = load_dataset(args.dataset_name, streaming=args.streaming)[
451 |                 "train"
452 |             ]
453 | if args.streaming:
454 | if dataset.info.dataset_size is None:
455 | print("Dataset doesn't support streaming, disabling streaming")
456 | args.streaming = False
457 | if args.cache_path:
458 | dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name]
459 | else:
460 | dataset = load_dataset(args.dataset_name)[args.hf_split_name]
461 |
462 | if args.resume_path is not None and len(args.resume_path) > 1:
463 |         load, current_step = True, 0  # current_step is refined below once a step number is parsed
464 | accelerator.print(f"Using Muse VQGanVAE, loading from {args.resume_path}")
465 | vae = VQGanVAE(
466 | dim=args.dim,
467 | vq_codebook_dim=args.vq_codebook_dim,
468 | vq_codebook_size=args.vq_codebook_size,
469 | l2_recon_loss=args.use_l2_recon_loss,
470 | channels=args.channels,
471 | layers=args.layers,
472 | discr_layers=args.discr_layers,
473 | accelerator=accelerator,
474 | )
475 |
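        # --latest_checkpoint: scan the resume directory for "vae.<step>.pt" files and pick the one with the highest step.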
476 | if args.latest_checkpoint:
477 | accelerator.print("Finding latest checkpoint...")
478 | orig_vae_path = args.resume_path
479 |
480 | if os.path.isfile(args.resume_path) or ".pt" in args.resume_path:
481 |                 # If args.resume_path is a file, split it into directory and filename
482 | args.resume_path, _ = os.path.split(args.resume_path)
483 |
484 | checkpoint_files = glob.glob(os.path.join(args.resume_path, "vae.*.pt"))
485 | if checkpoint_files:
486 | latest_checkpoint_file = max(
487 | checkpoint_files, key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1))
488 | )
489 |
490 | # Check if latest checkpoint is empty or unreadable
491 | if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
492 | latest_checkpoint_file, os.R_OK
493 | ):
494 | accelerator.print(
495 | f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable."
496 | )
497 | if len(checkpoint_files) > 1:
498 | # Use the second last checkpoint as a fallback
499 | latest_checkpoint_file = max(
500 | checkpoint_files[:-1], key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1))
501 | )
502 | accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
503 | else:
504 | accelerator.print("No usable checkpoint found.")
505 | load = False
506 | elif latest_checkpoint_file != orig_vae_path:
507 | accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file)
508 | else:
509 |                     accelerator.print("Using checkpoint specified in resume_path: ", orig_vae_path)
510 |
511 | args.resume_path = latest_checkpoint_file
512 | else:
513 | accelerator.print("No checkpoints found in directory: ", args.resume_path)
514 | load = False
515 | else:
516 | accelerator.print("Resuming VAE from: ", args.resume_path)
517 |
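        # Only restore weights if a usable checkpoint was resolved above (map="cpu" loads them onto CPU first).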
518 | if load:
519 | vae.load(args.resume_path, map="cpu")
520 |
521 |             resume_from_parts = args.resume_path.split(".")
522 |             # Walk the filename parts from the end to find the step number, e.g. "vae.1000.pt" -> 1000.
523 |             for i in range(len(resume_from_parts) - 1, -1, -1):
524 |                 if resume_from_parts[i].isdigit():
525 |                     current_step = int(resume_from_parts[i])
526 |                     accelerator.print(f"Found step {current_step} for the VAE model.")
527 |                     break
528 |             else:
529 |                 # The for-else runs only when no numeric part was found in the checkpoint name.
530 |                 accelerator.print("No step found for the VAE model.")
531 |                 current_step = 0
532 |
533 | elif args.taming_model_path is not None and args.taming_config_path is not None:
534 | print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}")
535 | vae = VQGanVAETaming(
536 | vqgan_model_path=args.taming_model_path,
537 | vqgan_config_path=args.taming_config_path,
538 | accelerator=accelerator,
539 | )
540 | args.num_tokens = vae.codebook_size
541 | args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2
542 | else:
543 | print("Initialising empty VAE")
544 | vae = VQGanVAE(
545 | dim=args.dim,
546 | vq_codebook_dim=args.vq_codebook_dim,
547 | vq_codebook_size=args.vq_codebook_size,
548 | channels=args.channels,
549 | layers=args.layers,
550 | discr_layers=args.discr_layers,
551 | accelerator=accelerator,
552 | )
553 |
554 | current_step = 0
555 |
556 | # Use the parameters() method to get an iterator over all the learnable parameters of the model
557 | total_params = sum(p.numel() for p in vae.parameters())
558 | args.total_params = total_params
559 |
560 | print(f"Total number of parameters: {format(total_params, ',d')}")
561 |
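    # Wrap the chosen dataset with image preprocessing (resize, optional crop/flip, alpha handling)
    # before splitting it into training and validation dataloaders.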
562 | dataset = ImageDataset(
563 | dataset,
564 | args.image_size,
565 | image_column=args.image_column,
566 | center_crop=not args.no_center_crop,
567 | flip=not args.no_flip,
568 | stream=args.streaming,
569 | random_crop=args.random_crop,
570 |         alpha_channel=(args.channels != 3),
571 | )
572 | # dataloader
573 |
574 | dataloader, validation_dataloader = split_dataset_into_dataloaders(
575 | dataset, args.valid_frac, args.seed, args.batch_size
576 | )
577 | trainer = VQGanVAETrainer(
578 | vae,
579 | dataloader,
580 | validation_dataloader,
581 | accelerator,
582 | current_step=current_step + 1 if current_step != 0 else current_step,
583 | num_train_steps=args.num_train_steps,
584 | lr=args.lr,
585 | lr_scheduler_type=args.lr_scheduler,
586 | lr_warmup_steps=args.lr_warmup_steps,
587 | max_grad_norm=args.max_grad_norm,
588 | discr_max_grad_norm=args.discr_max_grad_norm,
589 | save_results_every=args.save_results_every,
590 | save_model_every=args.save_model_every,
591 | results_dir=args.results_dir,
592 | logging_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
593 | use_ema=args.use_ema,
594 | ema_beta=args.ema_beta,
595 | ema_update_after_step=args.ema_update_after_step,
596 | ema_update_every=args.ema_update_every,
597 | apply_grad_penalty_every=args.apply_grad_penalty_every,
598 | gradient_accumulation_steps=args.gradient_accumulation_steps,
599 | clear_previous_experiments=args.clear_previous_experiments,
600 | validation_image_scale=args.validation_image_scale,
601 | only_save_last_checkpoint=args.only_save_last_checkpoint,
602 | optimizer=args.optimizer,
603 | use_8bit_adam=args.use_8bit_adam,
604 | num_cycles=args.num_cycles,
605 | scheduler_power=args.scheduler_power,
606 | num_epochs=args.num_epochs,
607 | args=args,
608 | )
609 |
610 | trainer.train()
611 |
612 |
613 | if __name__ == "__main__":
614 | main()
615 |
--------------------------------------------------------------------------------