├── .github ├── FUNDING.yml └── workflows │ ├── pre-commit.yml │ ├── python-publish.yml │ └── sem-version-release.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .vscode ├── extensions.json ├── launch.json └── settings.json ├── LICENSE ├── README.md ├── infer_vae.py ├── muse.png ├── muse_maskgit_pytorch ├── __init__.py ├── attn │ ├── __init__.py │ ├── attn_test.py │ ├── ein_attn.py │ ├── sdp_attn.py │ └── xformers_attn.py ├── dataset.py ├── distributed_utils.py ├── modules │ ├── __init__.py │ ├── attention.py │ └── mlp.py ├── muse_maskgit_pytorch.py ├── t5.py ├── trainers │ ├── __init__.py │ ├── base_accelerated_trainer.py │ ├── maskgit_trainer.py │ └── vqvae_trainers.py ├── vqgan_vae.py ├── vqgan_vae_taming.py └── vqvae │ ├── __init__.py │ ├── config.py │ ├── discriminator.py │ ├── layers.py │ ├── quantize.py │ └── vqvae.py ├── pyproject.toml ├── scripts └── vqvae_test.py ├── setup.py ├── tpu-vm.env ├── train_muse_maskgit.py └── train_muse_vae.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | github: [ZeroCool940711] 2 | patreon: zerocool94 3 | ko_fi: zerocool94 4 | open_collective: sygil_dev 5 | custom: ["https://paypal.me/zerocool94"] 6 | -------------------------------------------------------------------------------- /.github/workflows/pre-commit.yml: -------------------------------------------------------------------------------- 1 | name: pre-commit 2 | 3 | on: 4 | pull_request: 5 | push: 6 | branches: 7 | - main 8 | - dev 9 | 10 | jobs: 11 | pre-commit: 12 | strategy: 13 | fail-fast: false 14 | matrix: 15 | os: [ubuntu-latest] 16 | python-version: ["3.8", "3.10"] 17 | 18 | runs-on: ${{ matrix.os }} 19 | steps: 20 | - name: Checkout 21 | id: checkout 22 | uses: actions/checkout@v3 23 | with: 24 | submodules: "recursive" 25 | 26 | - name: Set up Python 27 | id: setup-python 28 | uses: actions/setup-python@v3 29 | with: 30 | python-version: ${{ matrix.python-version }} 31 | 32 | - name: Install package 33 | id: install-package 34 | run: | 35 | python -m pip install --upgrade pip setuptools wheel 36 | pip install -e '.[dev]' 37 | 38 | - name: Run pre-commit 39 | uses: pre-commit/action@v3.0.0 40 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 
10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.github/workflows/sem-version-release.yml: -------------------------------------------------------------------------------- 1 | name: Bump version 2 | on: 3 | push: 4 | branches: 5 | - master 6 | jobs: 7 | build: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - uses: actions/checkout@master 11 | - name: Bump version and push tag 12 | uses: hennejg/github-tag-action@v4.1.jh1 13 | with: 14 | github_token: ${{ secrets.GITHUB_TOKEN }} 15 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.ipynb 2 | wandb 3 | results 4 | models 5 | dataset 6 | taming 7 | ~ 8 | input.png 9 | output.png 10 | muse_maskgit_pytorch/wt.py 11 | 12 | # Byte-compiled / optimized / DLL files 13 | __pycache__/ 14 | *.py[cod] 15 | *$py.class 16 | 17 | # C extensions 18 | *.so 19 | 20 | # Distribution / packaging 21 | .Python 22 | build/ 23 | develop-eggs/ 24 | dist/ 25 | downloads/ 26 | eggs/ 27 | .eggs/ 28 | lib/ 29 | lib64/ 30 | parts/ 31 | sdist/ 32 | var/ 33 | wheels/ 34 | pip-wheel-metadata/ 35 | share/python-wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | MANIFEST 40 | 41 | # PyInstaller 42 | # Usually these files are written by a python script from a template 43 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 44 | *.manifest 45 | *.spec 46 | 47 | # Installer logs 48 | pip-log.txt 49 | pip-delete-this-directory.txt 50 | 51 | # Unit test / coverage reports 52 | htmlcov/ 53 | .tox/ 54 | .nox/ 55 | .coverage 56 | .coverage.* 57 | .cache 58 | nosetests.xml 59 | coverage.xml 60 | *.cover 61 | *.py,cover 62 | .hypothesis/ 63 | .pytest_cache/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | db.sqlite3 73 | db.sqlite3-journal 74 | 75 | # Flask stuff: 76 | instance/ 77 | .webassets-cache 78 | 79 | # Scrapy stuff: 80 | .scrapy 81 | 82 | # Sphinx documentation 83 | docs/_build/ 84 | 85 | # PyBuilder 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | .python-version 97 | 98 | # pipenv 99 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 100 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 101 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 102 | # install all needed dependencies. 103 | #Pipfile.lock 104 | 105 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 106 | __pypackages__/ 107 | 108 | # Celery stuff 109 | celerybeat-schedule 110 | celerybeat.pid 111 | 112 | # SageMath parsed files 113 | *.sage.py 114 | 115 | # Environments 116 | .env 117 | .venv 118 | env/ 119 | venv/ 120 | ENV/ 121 | env.bak/ 122 | venv.bak/ 123 | 124 | # Spyder project settings 125 | .spyderproject 126 | .spyproject 127 | 128 | # Rope project settings 129 | .ropeproject 130 | 131 | # mkdocs documentation 132 | /site 133 | 134 | # mypy 135 | .mypy_cache/ 136 | .dmypy.json 137 | dmypy.json 138 | 139 | # Pyre type checker 140 | .pyre/ 141 | 142 | # setuptools-scm version file 143 | muse_maskgit_pytorch/_version.py 144 | 145 | # wandb dir 146 | /wandb/ 147 | 148 | # data, output 149 | /data/ 150 | /output/ 151 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | ci: 4 | autofix_prs: true 5 | autoupdate_branch: 'dev' 6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit autoupdate' 7 | autoupdate_schedule: weekly 8 | 9 | repos: 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.4.0 12 | hooks: 13 | - id: trailing-whitespace 14 | - id: end-of-file-fixer 15 | - id: check-yaml 16 | - id: check-added-large-files 17 | 18 | - repo: https://github.com/astral-sh/ruff-pre-commit 19 | rev: "v0.0.278" 20 | hooks: 21 | - id: ruff 22 | args: [--fix, --exit-non-zero-on-fix] 23 | 24 | - repo: https://github.com/psf/black 25 | rev: 23.7.0 26 | hooks: 27 | - id: black 28 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.python", 4 | "charliermarsh.ruff", 5 | "redhat.vscode-yaml", 6 | "codezombiech.gitignore", 7 | "ms-python.black-formatter" 8 | ] 9 | } 10 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | "configurations": [ 3 | { 4 | "name": "Python: Verify attn impl equivalence", 5 | "type": "python", 6 | "request": "launch", 7 | "module": "attn_test", 8 | "justMyCode": false 9 | } 10 | ] 11 | } 12 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSaveMode": "file", 3 | 4 | "files.associations": { 5 | ".config": "shellscript", 6 | ".gitignore": "gitignore", 7 | ".vscode/*.json": "jsonc", 8 | "*.txt": "plaintext", 9 | "requirements*.txt": "pip-requirements", 10 | "setup.cfg": "ini", 11 | }, 12 | 13 | "[json]": { 14 | "editor.codeActionsOnSave": { 15 | "source.fixAll.sortJSON": false 16 | }, 17 | "editor.defaultFormatter": "vscode.json-language-features", 18 | "editor.formatOnSave": true, 19 | "editor.tabSize": 4 20 | }, 21 | "[jsonc]": { 22 | "editor.codeActionsOnSave": { 23 | "source.fixAll.sortJSON": false 24 | }, 25 | "editor.defaultFormatter": "vscode.json-language-features", 26 | "editor.formatOnSave": true, 27 | "editor.tabSize": 4 28 | }, 29 | "json.format.keepLines": true, 30 | 31 | "[python]": { 32 | "editor.formatOnSave": true, 33 | "editor.defaultFormatter": "ms-python.black-formatter", 34 
| "editor.codeActionsOnSave": { 35 | "source.organizeImports": true 36 | } 37 | }, 38 | "python.formatting.provider": "none", 39 | "ruff.organizeImports": true, 40 | "ruff.args": [ "--line-length=110", "--extend-ignore=F401,F841" ], 41 | "black-formatter.args": [ "--line-length", "110" ], 42 | "python.linting.flake8Enabled": false, 43 | "python.linting.mypyEnabled": false, 44 | } 45 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Muse - Pytorch 4 | 5 | ### Implementation of Muse: Text-to-Image Generation via Masked Generative Transformers, in Pytorch originally made by [Lucidrains](https://github.com/lucidrains/muse-maskgit-pytorch). 6 | 7 | ## 8 | We have added additional code to allow anyone to train their own model and we have optimized the code for low end hardware. 9 | 10 | ## Join us at Sygil.Dev's Discord Server [![Generic badge](https://flat.badgen.net/discord/members/ttM8Tm6wge?icon=discord)](https://discord.gg/ttM8Tm6wge) 11 | 12 | ## Install 13 | For installing the code you have two options: 14 | 15 | 1 - You can install it directly from the repo with pip: 16 | ```bash 17 | $ pip install git+https://github.com/Sygil-Dev/muse-maskgit-pytorch 18 | ``` 19 | 2 - or you can clone it and then install from source: 20 | ```bash 21 | $ git clone https://github.com/Sygil-Dev/muse-maskgit-pytorch 22 | $ cd muse-maskgit-pytorch 23 | $ pip install . 
24 | ``` 25 | 26 | ## Usage 27 | 28 | First train your VAE - `VQGanVAE` 29 | 30 | ```python 31 | import torch 32 | from muse_maskgit_pytorch import VQGanVAE, VQGanVAETrainer 33 | 34 | vae = VQGanVAE( 35 | dim = 256, 36 | vq_codebook_size = 512 37 | ) 38 | 39 | # train on folder of images, as many images as possible 40 | 41 | trainer = VQGanVAETrainer( 42 | vae = vae, 43 | image_size = 128, # you may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it 44 | folder = '/path/to/images', 45 | batch_size = 4, 46 | grad_accum_every = 8, 47 | num_train_steps = 50000 48 | ).cuda() 49 | 50 | trainer.train() 51 | ``` 52 | 53 | Then pass the trained `VQGanVAE` and a `Transformer` to `MaskGit` 54 | 55 | ```python 56 | import torch 57 | from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer 58 | 59 | # first instantiate your vae 60 | 61 | vae = VQGanVAE( 62 | dim = 256, 63 | vq_codebook_size = 512 64 | ).cuda() 65 | 66 | vae.load('/path/to/vae.pt') # you will want to load the exponentially moving averaged VAE 67 | 68 | # then you plug the vae and transformer into your MaskGit as so 69 | 70 | # (1) create your transformer / attention network 71 | 72 | transformer = MaskGitTransformer( 73 | num_tokens = 512, # must be same as codebook size above 74 | seq_len = 256, # must be equivalent to fmap_size ** 2 in vae 75 | dim = 512, # model dimension 76 | depth = 8, # depth 77 | dim_head = 64, # attention head dimension 78 | heads = 8, # attention heads, 79 | ff_mult = 4, # feedforward expansion factor 80 | t5_name = 't5-small', # name of your T5 81 | ) 82 | 83 | # (2) pass your trained VAE and the base transformer to MaskGit 84 | 85 | base_maskgit = MaskGit( 86 | vae = vae, # vqgan vae 87 | transformer = transformer, # transformer 88 | image_size = 256, # image size 89 | cond_drop_prob = 0.25, # conditional dropout, for classifier free guidance 90 | ).cuda() 91 | 92 | # ready your training text and images 93 | 94 | texts = [ 95 | 'a child screaming at finding a worm within a half-eaten apple', 96 | 'lizard running across the desert on two feet', 97 | 'waking up to a psychedelic landscape', 98 | 'seashells sparkling in the shallow waters' 99 | ] 100 | 101 | images = torch.randn(4, 3, 256, 256).cuda() 102 | 103 | # feed it into your maskgit instance, with return_loss set to True 104 | 105 | loss = base_maskgit( 106 | images, 107 | texts = texts 108 | ) 109 | 110 | loss.backward() 111 | 112 | # do this for a long time on much data 113 | # then... 114 | 115 | images = base_maskgit.generate(texts = [ 116 | 'a whale breaching from afar', 117 | 'young girl blowing out candles on her birthday cake', 118 | 'fireworks with blue and green sparkles' 119 | ], cond_scale = 3.) # conditioning scale for classifier free guidance 120 | 121 | images.shape # (3, 3, 256, 256) 122 | ``` 123 | 124 | To train the super-resolution maskgit requires you to change 1 field on `MaskGit` instantiation (you will need to now pass in the `cond_image_size`, as the previous image size being conditioned on) 125 | 126 | Optionally, you can pass in a different `VAE` as `cond_vae` for the conditioning low-resolution image. By default it will use the `vae` for both tokenizing the super and low resoluted images. 
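For example, if you have trained a separate low-resolution `VQGanVAE` (called `low_res_vae` below purely for illustration), it would be plugged in roughly like this; the full super-resolution walkthrough follows.

```python
superres_maskgit = MaskGit(
    vae = vae,                # tokenizes the high-resolution (512px) target images
    cond_vae = low_res_vae,   # optional: tokenizes the low-resolution (256px) conditioning images
    transformer = transformer,
    cond_drop_prob = 0.25,
    image_size = 512,
    cond_image_size = 256,    # conditioning image size <- this must be set
).cuda()
```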
127 | 128 | ```python 129 | import torch 130 | import torch.nn.functional as F 131 | from muse_maskgit_pytorch import VQGanVAE, MaskGit, MaskGitTransformer 132 | 133 | # first instantiate your ViT VQGan VAE 134 | # a VQGan VAE made of transformers 135 | 136 | vae = VQGanVAE( 137 | dim = 256, 138 | vq_codebook_size = 512 139 | ).cuda() 140 | 141 | vae.load('./path/to/vae.pt') # you will want to load the exponentially moving averaged VAE 142 | 143 | # then you plug the VqGan VAE into your MaskGit as so 144 | 145 | # (1) create your transformer / attention network 146 | 147 | transformer = MaskGitTransformer( 148 | num_tokens = 512, # must be same as codebook size above 149 | seq_len = 1024, # must be equivalent to fmap_size ** 2 in vae 150 | dim = 512, # model dimension 151 | depth = 2, # depth 152 | dim_head = 64, # attention head dimension 153 | heads = 8, # attention heads, 154 | ff_mult = 4, # feedforward expansion factor 155 | t5_name = 't5-small', # name of your T5 156 | ) 157 | 158 | # (2) pass your trained VAE and the base transformer to MaskGit 159 | 160 | superres_maskgit = MaskGit( 161 | vae = vae, 162 | transformer = transformer, 163 | cond_drop_prob = 0.25, 164 | image_size = 512, # larger image size 165 | cond_image_size = 256, # conditioning image size <- this must be set 166 | ).cuda() 167 | 168 | # ready your training text and images 169 | 170 | texts = [ 171 | 'a child screaming at finding a worm within a half-eaten apple', 172 | 'lizard running across the desert on two feet', 173 | 'waking up to a psychedelic landscape', 174 | 'seashells sparkling in the shallow waters' 175 | ] 176 | 177 | images = torch.randn(4, 3, 512, 512).cuda() 178 | 179 | # feed it into your maskgit instance, with return_loss set to True 180 | 181 | loss = superres_maskgit( 182 | images, 183 | texts = texts 184 | ) 185 | 186 | loss.backward() 187 | 188 | # do this for a long time on much data 189 | # then... 190 | 191 | images = superres_maskgit.generate( 192 | texts = [ 193 | 'a whale breaching from afar', 194 | 'young girl blowing out candles on her birthday cake', 195 | 'fireworks with blue and green sparkles', 196 | 'waking up to a psychedelic landscape' 197 | ], 198 | cond_images = F.interpolate(images, 256), # conditioning images must be passed in for generating from superres 199 | cond_scale = 3. 200 | ) 201 | 202 | images.shape # (4, 3, 512, 512) 203 | ``` 204 | 205 | All together now 206 | 207 | ```python 208 | from muse_maskgit_pytorch import Muse 209 | 210 | base_maskgit.load('./path/to/base.pt') 211 | 212 | superres_maskgit.load('./path/to/superres.pt') 213 | 214 | # pass in the trained base_maskgit and superres_maskgit from above 215 | 216 | muse = Muse( 217 | base = base_maskgit, 218 | superres = superres_maskgit 219 | ) 220 | 221 | images = muse([ 222 | 'a whale breaching from afar', 223 | 'young girl blowing out candles on her birthday cake', 224 | 'fireworks with blue and green sparkles', 225 | 'waking up to a psychedelic landscape' 226 | ]) 227 | 228 | images # List[PIL.Image.Image] 229 | ``` 230 | 231 | ## Training 232 | 233 | Training should be done in 4 stages. 234 | 235 | 1. Training base VAE(swap out the dataset_name with your huggingface dataset) 236 | 237 | ``` 238 | accelerate launch train_muse_vae.py --dataset_name="Isamu136/big-animal-dataset" 239 | ``` 240 | 2. Once you trained enough in the base VAE, move the checkpoint of your latest version to a new location. 
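(Concretely, with the default settings this means copying the newest `vae.<step>.pt` file out of the `results` folder somewhere else; the step number and destination below are only an example.)

```
cp results/vae.50000.pt checkpoints/
```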
Then, do 241 | 242 | ``` 243 | accelerate launch train_muse_maskgit.py --dataset_name="Isamu136/big-animal-dataset" --vae_path=path_to_vae_checkpoint 244 | ``` 245 | 246 | Alternatively, if you want to use a pretrained autoencoder, download one from [here](https://github.com/CompVis/taming-transformers) and then extract it. In the code below we are using vqgan_imagenet_f16_1024; change the paths accordingly. 247 | 248 | ``` 249 | accelerate launch train_muse_maskgit.py --dataset_name="Isamu136/big-animal-dataset" --taming_model_path="models/image_net_f16/ckpts/last.ckpt" --taming_config_path="models/image_net_f16/configs/model.yaml" --validation_prompt="elephant" 250 | ``` 251 | 252 | Or, if you want to train on cifar10, try 253 | 254 | ``` 255 | accelerate launch train_muse_maskgit.py --dataset_name="cifar10" --taming_model_path="models/image_net_f16/ckpts/last.ckpt" --taming_config_path="models/image_net_f16/configs/model.yaml" --validation_prompt="0" --image_column="img" --caption_column="label" 256 | ``` 257 | 258 | ## Checkpoints and Pretrained Models 259 | We currently do not have any usable pretrained model for Muse, but we are trying to train one with whatever resources we have available. For more information, check the [Sygil Muse](https://huggingface.co/Sygil/Sygil-Muse) repository on HuggingFace, where we are uploading checkpoints from the different tests we have performed and where we will upload the final weights once we have something everyone can use. 260 | 261 | ## Appreciation 262 | - [Lucidrains](https://github.com/lucidrains/muse-maskgit-pytorch) for the original Muse-Maskgit-Pytorch implementation. 263 | - The [ShoukanLabs](https://github.com/ShoukanLabs) team for contributing so much to improving the code and adding new features. 264 | - StabilityAI for the sponsorship, as well as my other sponsors, for affording me the independence to open source artificial intelligence. 265 | 266 | - 🤗 Huggingface for the transformers and accelerate libraries, both of which are wonderful 267 | 268 | ## Todo 269 | 270 | - [x] test end-to-end 271 | 272 | - [x] separate cond_images_or_ids, it is not done right 273 | 274 | - [x] add training code for vae 275 | 276 | - [x] add optional self-conditioning on embeddings 277 | 278 | - [x] combine with token critic paper, already implemented at Phenaki 279 | 280 | - [x] hook up accelerate training code for maskgit 281 | 282 | - [ ] train a base model 283 | 284 | ## Citations 285 | 286 | ```bibtex 287 | @inproceedings{Chang2023MuseTG, 288 | title = {Muse: Text-To-Image Generation via Masked Generative Transformers}, 289 | author = {Huiwen Chang and Han Zhang and Jarred Barber and AJ Maschinot and Jos{\'e} Lezama and Lu Jiang and Ming-Hsuan Yang and Kevin P. Murphy and William T. Freeman and Michael Rubinstein and Yuanzhen Li and Dilip Krishnan}, 290 | year = {2023} 291 | } 292 | ``` 293 | 294 | ```bibtex 295 | @article{Chen2022AnalogBG, 296 | title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning}, 297 | author = {Ting Chen and Ruixiang Zhang and Geoffrey E. 
Hinton}, 298 | journal = {ArXiv}, 299 | year = {2022}, 300 | volume = {abs/2208.04202} 301 | } 302 | ``` 303 | 304 | ```bibtex 305 | @misc{jabri2022scalable, 306 | title = {Scalable Adaptive Computation for Iterative Generation}, 307 | author = {Allan Jabri and David Fleet and Ting Chen}, 308 | year = {2022}, 309 | eprint = {2212.11972}, 310 | archivePrefix = {arXiv}, 311 | primaryClass = {cs.LG} 312 | } 313 | ``` 314 | 315 | ```bibtex 316 | @article{Lezama2022ImprovedMI, 317 | title = {Improved Masked Image Generation with Token-Critic}, 318 | author = {Jos{\'e} Lezama and Huiwen Chang and Lu Jiang and Irfan Essa}, 319 | journal = {ArXiv}, 320 | year = {2022}, 321 | volume = {abs/2209.04439} 322 | } 323 | ``` 324 | 325 | ```bibtex 326 | @inproceedings{Nijkamp2021SCRIPTSP, 327 | title = {SCRIPT: Self-Critic PreTraining of Transformers}, 328 | author = {Erik Nijkamp and Bo Pang and Ying Nian Wu and Caiming Xiong}, 329 | booktitle = {North American Chapter of the Association for Computational Linguistics}, 330 | year = {2021} 331 | } 332 | ``` 333 | 334 | ```bibtex 335 | @misc{gilmer2023intriguing 336 | title = {Intriguing Properties of Transformer Training Instabilities}, 337 | author = {Justin Gilmer, Andrea Schioppa, and Jeremy Cohen}, 338 | year = {2023}, 339 | status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams} 340 | } 341 | ``` 342 | -------------------------------------------------------------------------------- /infer_vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import hashlib 4 | import os 5 | import random 6 | import re 7 | from dataclasses import dataclass 8 | from datetime import datetime 9 | from typing import Optional 10 | 11 | import accelerate 12 | import PIL 13 | import torch 14 | from accelerate.utils import ProjectConfiguration 15 | from datasets import Dataset, Image, load_dataset 16 | from torchvision.utils import save_image 17 | from tqdm import tqdm 18 | 19 | from muse_maskgit_pytorch import ( 20 | VQGanVAE, 21 | VQGanVAETaming, 22 | get_accelerator, 23 | ) 24 | from muse_maskgit_pytorch.dataset import ( 25 | ImageDataset, 26 | get_dataset_from_dataroot, 27 | ) 28 | from muse_maskgit_pytorch.vqvae import VQVAE 29 | 30 | # Create the parser 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument( 33 | "--no_center_crop", 34 | action="store_true", 35 | help="Don't do center crop.", 36 | ) 37 | parser.add_argument( 38 | "--random_crop", 39 | action="store_true", 40 | help="Crop the images at random locations instead of cropping from the center.", 41 | ) 42 | parser.add_argument( 43 | "--no_flip", 44 | action="store_true", 45 | help="Don't flip image.", 46 | ) 47 | parser.add_argument( 48 | "--random_image", 49 | action="store_true", 50 | help="Get a random image from the dataset to use for the reconstruction.", 51 | ) 52 | parser.add_argument( 53 | "--dataset_save_path", 54 | type=str, 55 | default="dataset", 56 | help="Path to save the dataset if you are making one from a directory", 57 | ) 58 | parser.add_argument( 59 | "--seed", 60 | type=int, 61 | default=42, 62 | help="Seed for reproducibility. 
If set to -1 a random seed will be generated.", 63 | ) 64 | parser.add_argument("--valid_frac", type=float, default=0.05, help="validation fraction.") 65 | parser.add_argument( 66 | "--image_column", 67 | type=str, 68 | default="image", 69 | help="The column of the dataset containing an image.", 70 | ) 71 | parser.add_argument( 72 | "--mixed_precision", 73 | type=str, 74 | default="no", 75 | choices=["no", "fp16", "bf16"], 76 | help="Precision to train on.", 77 | ) 78 | parser.add_argument( 79 | "--results_dir", 80 | type=str, 81 | default="results", 82 | help="Path to save the training samples and checkpoints", 83 | ) 84 | parser.add_argument( 85 | "--logging_dir", 86 | type=str, 87 | default=None, 88 | help="Path to log the losses and LR", 89 | ) 90 | 91 | # vae_trainer args 92 | parser.add_argument( 93 | "--vae_path", 94 | type=str, 95 | default=None, 96 | help="Path to the vae model. eg. 'results/vae.steps.pt'", 97 | ) 98 | parser.add_argument( 99 | "--dataset_name", 100 | type=str, 101 | default=None, 102 | help="Name of the huggingface dataset used.", 103 | ) 104 | parser.add_argument( 105 | "--train_data_dir", 106 | type=str, 107 | default=None, 108 | help="Dataset folder where your input images for training are.", 109 | ) 110 | parser.add_argument("--dim", type=int, default=128, help="Model dimension.") 111 | parser.add_argument("--batch_size", type=int, default=512, help="Batch Size.") 112 | parser.add_argument("--lr", type=float, default=1e-4, help="Learning Rate.") 113 | parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.") 114 | parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.") 115 | parser.add_argument( 116 | "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA." 117 | ) 118 | parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.") 119 | parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.") 120 | parser.add_argument( 121 | "--image_size", 122 | type=int, 123 | default=256, 124 | help="Image size. You may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it", 125 | ) 126 | parser.add_argument( 127 | "--chunk_size", 128 | type=int, 129 | default=256, 130 | help="This is used to split big images into smaller chunks so we can still reconstruct them no matter the size.", 131 | ) 132 | parser.add_argument( 133 | "--min_chunk_size", 134 | type=int, 135 | default=8, 136 | help="We use a minimum chunk size to ensure that the image is always reconstructed correctly.", 137 | ) 138 | parser.add_argument( 139 | "--overlap_size", 140 | type=int, 141 | default=256, 142 | help="The overlap size used with --chunk_size to overlap the chunks and make sure the whole image is reconstructe as well as make sure we remove artifacts caused by doing the reconstrucion in chunks.", 143 | ) 144 | parser.add_argument( 145 | "--min_overlap_size", 146 | type=int, 147 | default=1, 148 | help="We use a minimum overlap size to ensure that the image is always reconstructed correctly.", 149 | ) 150 | parser.add_argument( 151 | "--taming_model_path", 152 | type=str, 153 | default=None, 154 | help="path to your trained VQGAN weights. This should be a .ckpt file. 
(only valid when taming option is enabled)", 155 | ) 156 | 157 | parser.add_argument( 158 | "--taming_config_path", 159 | type=str, 160 | default=None, 161 | help="path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)", 162 | ) 163 | parser.add_argument( 164 | "--input_image", 165 | type=str, 166 | default=None, 167 | help="Path to an image to use as input for reconstruction instead of using one from the dataset.", 168 | ) 169 | parser.add_argument( 170 | "--input_folder", 171 | type=str, 172 | default=None, 173 | help="Path to a folder with images to use as input for creating a dataset for reconstructing all the imgaes in it instead of just one image.", 174 | ) 175 | parser.add_argument( 176 | "--exclude_folders", 177 | type=str, 178 | default=None, 179 | help="List of folders we want to exclude when doing reconstructions from an input folder.", 180 | ) 181 | parser.add_argument( 182 | "--gpu", 183 | type=int, 184 | default=0, 185 | help="GPU to use in case we want to use a specific GPU for inference.", 186 | ) 187 | parser.add_argument( 188 | "--max_retries", 189 | type=int, 190 | default=30, 191 | help="Max number of times to retry in case the reconstruction fails due to OOM or any other error.", 192 | ) 193 | parser.add_argument( 194 | "--latest_checkpoint", 195 | action="store_true", 196 | help="Use the latest checkpoint using the vae_path folder instead of using just a specific vae_path.", 197 | ) 198 | parser.add_argument( 199 | "--use_paintmind", 200 | action="store_true", 201 | help="Use PaintMind VAE..", 202 | ) 203 | 204 | 205 | @dataclass 206 | class Arguments: 207 | only_save_last_checkpoint: bool = False 208 | validation_image_scale: float = 1.0 209 | no_center_crop: bool = False 210 | no_flip: bool = False 211 | random_crop: bool = False 212 | random_image: bool = False 213 | dataset_save_path: Optional[str] = None 214 | clear_previous_experiments: bool = False 215 | max_grad_norm: Optional[float] = None 216 | discr_max_grad_norm: Optional[float] = None 217 | num_tokens: int = 256 218 | seq_len: int = 1024 219 | seed: int = 42 220 | valid_frac: float = 0.05 221 | use_ema: bool = False 222 | ema_beta: float = 0.995 223 | ema_update_after_step: int = 1 224 | ema_update_every: int = 1 225 | apply_grad_penalty_every: int = 4 226 | image_column: str = "image" 227 | caption_column: str = "caption" 228 | log_with: str = "wandb" 229 | mixed_precision: str = "no" 230 | use_8bit_adam: bool = False 231 | results_dir: str = "results" 232 | logging_dir: Optional[str] = None 233 | resume_path: Optional[str] = None 234 | dataset_name: Optional[str] = None 235 | streaming: bool = False 236 | train_data_dir: Optional[str] = None 237 | num_train_steps: int = -1 238 | num_epochs: int = 5 239 | dim: int = 128 240 | batch_size: int = 512 241 | lr: float = 1e-5 242 | gradient_accumulation_steps: int = 1 243 | save_results_every: int = 100 244 | save_model_every: int = 500 245 | vq_codebook_size: int = 256 246 | vq_codebook_dim: int = 256 247 | cond_drop_prob: float = 0.5 248 | image_size: int = 256 249 | lr_scheduler: str = "constant" 250 | scheduler_power: float = 1.0 251 | lr_warmup_steps: int = 0 252 | num_cycles: int = 1 253 | taming_model_path: Optional[str] = None 254 | taming_config_path: Optional[str] = None 255 | optimizer: str = "Lion" 256 | weight_decay: float = 0.0 257 | cache_path: Optional[str] = None 258 | no_cache: bool = False 259 | latest_checkpoint: bool = False 260 | do_not_save_config: bool = False 261 | use_l2_recon_loss: 
bool = False 262 | debug: bool = False 263 | config_path: Optional[str] = None 264 | generate_config: bool = False 265 | 266 | 267 | def seed_to_int(s): 268 | if type(s) is int: 269 | return s 270 | if s is None or s == "": 271 | return random.randint(0, 2**32 - 1) 272 | 273 | if "," in s: 274 | s = s.split(",") 275 | 276 | if type(s) is list: 277 | seed_list = [] 278 | for seed in s: 279 | if seed is None or seed == "": 280 | seed_list.append(random.randint(0, 2**32 - 1)) 281 | else: 282 | seed_list = s 283 | 284 | return seed_list 285 | 286 | n = abs(int(s) if s.isdigit() else random.Random(s).randint(0, 2**32 - 1)) 287 | while n >= 2**32: 288 | n = n >> 32 289 | return n 290 | 291 | 292 | def main(): 293 | args = parser.parse_args(namespace=Arguments()) 294 | 295 | project_config = ProjectConfiguration( 296 | project_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"), 297 | automatic_checkpoint_naming=True, 298 | ) 299 | 300 | accelerator: accelerate.Accelerator = get_accelerator( 301 | log_with=args.log_with, 302 | gradient_accumulation_steps=args.gradient_accumulation_steps, 303 | mixed_precision=args.mixed_precision, 304 | project_config=project_config, 305 | even_batches=True, 306 | ) 307 | 308 | # set pytorch seed for reproducibility 309 | torch.manual_seed(seed_to_int(args.seed)) 310 | 311 | if args.train_data_dir and not args.input_image and not args.input_folder: 312 | dataset = get_dataset_from_dataroot( 313 | args.train_data_dir, 314 | image_column=args.image_column, 315 | save_path=args.dataset_save_path, 316 | ) 317 | elif args.dataset_name and not args.input_image and not args.input_folder: 318 | dataset = load_dataset(args.dataset_name)["train"] 319 | 320 | elif args.input_image and not args.input_folder: 321 | # Create dataset from single input image 322 | dataset = Dataset.from_dict({"image": [args.input_image]}).cast_column("image", Image()) 323 | 324 | if args.input_folder: 325 | # Create dataset from input folder 326 | extensions = ["jpg", "jpeg", "png", "webp"] 327 | exclude_folders = args.exclude_folders.split(",") if args.exclude_folders else [] 328 | 329 | filepaths = [] 330 | for root, dirs, files in os.walk(args.input_folder, followlinks=True): 331 | # Resolve symbolic link to actual path and exclude based on actual path 332 | resolved_root = os.path.realpath(root) 333 | for exclude_folder in exclude_folders: 334 | if exclude_folder in resolved_root: 335 | dirs[:] = [] 336 | break 337 | for file in files: 338 | if file.lower().endswith(tuple(extensions)): 339 | filepaths.append(os.path.join(root, file)) 340 | 341 | if not filepaths: 342 | print(f"No images with extensions {extensions} found in {args.input_folder}.") 343 | exit(1) 344 | 345 | dataset = Dataset.from_dict({"image": filepaths}).cast_column("image", Image()) 346 | 347 | if args.vae_path and args.taming_model_path: 348 | raise Exception("You can't pass vae_path and taming args at the same time.") 349 | 350 | if args.vae_path and not args.use_paintmind: 351 | accelerator.print("Loading Muse VQGanVAE") 352 | vae = VQGanVAE( 353 | dim=args.dim, 354 | vq_codebook_size=args.vq_codebook_size, 355 | vq_codebook_dim=args.vq_codebook_dim, 356 | channels=args.channels, 357 | layers=args.layers, 358 | discr_layers=args.discr_layers, 359 | ).to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") 360 | 361 | if args.latest_checkpoint: 362 | accelerator.print("Finding latest checkpoint...") 363 | orig_vae_path = args.vae_path 364 | 365 | if os.path.isfile(args.vae_path) or 
".pt" in args.vae_path: 366 | # If args.vae_path is a file, split it into directory and filename 367 | args.vae_path, _ = os.path.split(args.vae_path) 368 | 369 | checkpoint_files = glob.glob(os.path.join(args.vae_path, "vae.*.pt")) 370 | if checkpoint_files: 371 | latest_checkpoint_file = max( 372 | checkpoint_files, 373 | key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1)) 374 | if not x.endswith("ema.pt") 375 | else -1, 376 | ) 377 | 378 | # Check if latest checkpoint is empty or unreadable 379 | if os.path.getsize(latest_checkpoint_file) == 0 or not os.access( 380 | latest_checkpoint_file, os.R_OK 381 | ): 382 | accelerator.print( 383 | f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable." 384 | ) 385 | if len(checkpoint_files) > 1: 386 | # Use the second last checkpoint as a fallback 387 | latest_checkpoint_file = max( 388 | checkpoint_files[:-1], 389 | key=lambda x: int(re.search(r"vae\.(\d+)\.pt$", x).group(1)) 390 | if not x.endswith("ema.pt") 391 | else -1, 392 | ) 393 | accelerator.print("Using second last checkpoint: ", latest_checkpoint_file) 394 | else: 395 | accelerator.print("No usable checkpoint found.") 396 | elif latest_checkpoint_file != orig_vae_path: 397 | accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file) 398 | else: 399 | accelerator.print("Using checkpoint specified in vae_path: ", orig_vae_path) 400 | 401 | args.vae_path = latest_checkpoint_file 402 | else: 403 | accelerator.print("No checkpoints found in directory: ", args.vae_path) 404 | else: 405 | accelerator.print("Resuming VAE from: ", args.vae_path) 406 | 407 | vae.load(args.vae_path) 408 | 409 | if args.use_paintmind: 410 | # load VAE 411 | accelerator.print("Loading VQVAE from 'neggles/vaedump/vit-s-vqgan-f4' ...") 412 | vae: VQVAE = VQVAE.from_pretrained("neggles/vaedump", subfolder="vit-s-vqgan-f4") 413 | 414 | elif args.taming_model_path: 415 | print("Loading Taming VQGanVAE") 416 | vae = VQGanVAETaming( 417 | vqgan_model_path=args.taming_model_path, 418 | vqgan_config_path=args.taming_config_path, 419 | ) 420 | args.num_tokens = vae.codebook_size 421 | args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2 422 | 423 | # move vae to device 424 | vae = vae.to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") 425 | 426 | # Use the parameters() method to get an iterator over all the learnable parameters of the model 427 | total_params = sum(p.numel() for p in vae.parameters()) 428 | 429 | print(f"Total number of parameters: {format(total_params, ',d')}") 430 | 431 | # then you plug the vae and transformer into your MaskGit as so 432 | 433 | dataset = ImageDataset( 434 | dataset, 435 | args.image_size, 436 | image_column=args.image_column, 437 | center_crop=True if not args.no_center_crop and not args.random_crop else False, 438 | flip=not args.no_flip, 439 | random_crop=args.random_crop if args.random_crop else False, 440 | alpha_channel=False if args.channels == 3 else True, 441 | ) 442 | 443 | if args.input_image and not args.input_folder: 444 | image_id = 0 if not args.random_image else random.randint(0, len(dataset)) 445 | 446 | os.makedirs(f"{args.results_dir}/outputs", exist_ok=True) 447 | 448 | save_image( 449 | dataset[image_id], 450 | f"{args.results_dir}/outputs/input.{str(args.input_image).split('.')[-1]}", 451 | format="PNG", 452 | ) 453 | 454 | _, ids, _ = vae.encode( 455 | dataset[image_id][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") 456 | ) 457 | recon = vae.decode_from_ids(ids) 458 
| save_image(recon, f"{args.results_dir}/outputs/output.{str(args.input_image).split('.')[-1]}") 459 | 460 | if not args.input_image and not args.input_folder: 461 | image_id = 0 if not args.random_image else random.randint(0, len(dataset)) 462 | 463 | os.makedirs(f"{args.results_dir}/outputs", exist_ok=True) 464 | 465 | save_image(dataset[image_id], f"{args.results_dir}/outputs/input.png") 466 | 467 | _, ids, _ = vae.encode( 468 | dataset[image_id][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") 469 | ) 470 | recon = vae.decode_from_ids(ids) 471 | save_image(recon, f"{args.results_dir}/outputs/output.png") 472 | 473 | if args.input_folder: 474 | # Create output directory and save input images and reconstructions as grids 475 | output_dir = os.path.join(args.results_dir, "outputs", os.path.basename(args.input_folder)) 476 | os.makedirs(output_dir, exist_ok=True) 477 | 478 | for i in tqdm(range(len(dataset))): 479 | retries = 0 480 | while True: 481 | try: 482 | save_image(dataset[i], f"{output_dir}/input.png") 483 | 484 | if not args.use_paintmind: 485 | # encode 486 | _, ids, _ = vae.encode( 487 | dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") 488 | ) 489 | # decode 490 | recon = vae.decode_from_ids(ids) 491 | # print (recon.shape) # torch.Size([1, 3, 512, 1136]) 492 | save_image(recon, f"{output_dir}/output.png") 493 | else: 494 | # encode 495 | encoded, _, _ = vae.encode( 496 | dataset[i][None].to(accelerator.device if args.gpu == 0 else f"cuda:{args.gpu}") 497 | ) 498 | 499 | # decode 500 | recon = vae.decode(encoded).squeeze(0) 501 | recon = torch.clamp(recon, -1.0, 1.0) 502 | save_image(recon, f"{output_dir}/output.png") 503 | 504 | # Load input and output images 505 | input_image = PIL.Image.open(f"{output_dir}/input.png") 506 | output_image = PIL.Image.open(f"{output_dir}/output.png") 507 | 508 | # Create horizontal grid with input and output images 509 | grid_image = PIL.Image.new( 510 | "RGB" if args.channels == 3 else "RGBA", 511 | (input_image.width + output_image.width, input_image.height), 512 | ) 513 | grid_image.paste(input_image, (0, 0)) 514 | grid_image.paste(output_image, (input_image.width, 0)) 515 | 516 | # Save grid 517 | now = datetime.now().strftime("%m-%d-%Y_%H-%M-%S") 518 | hash = hashlib.sha1(input_image.tobytes()).hexdigest() 519 | 520 | filename = f"{hash}_{now}-{os.path.basename(args.vae_path)}.png" 521 | grid_image.save(f"{output_dir}/{filename}", format="PNG") 522 | 523 | # Remove input and output images after the grid was made. 524 | os.remove(f"{output_dir}/input.png") 525 | os.remove(f"{output_dir}/output.png") 526 | 527 | del _ 528 | del ids 529 | del recon 530 | 531 | torch.cuda.empty_cache() 532 | torch.cuda.ipc_collect() 533 | 534 | break # Exit the retry loop if there were no errors 535 | 536 | except RuntimeError as e: 537 | if "out of memory" in str(e) and retries < args.max_retries: 538 | retries += 1 539 | # print(f"Out of Memory. 
Retry #{retries}") 540 | torch.cuda.empty_cache() 541 | torch.cuda.ipc_collect() 542 | continue # Retry the loop 543 | 544 | else: 545 | if "out of memory" not in str(e): 546 | print(e) 547 | else: 548 | print(f"Skipping image {i} after {retries} retries due to out of memory error") 549 | break # Exit the retry loop after too many retries 550 | 551 | 552 | if __name__ == "__main__": 553 | main() 554 | -------------------------------------------------------------------------------- /muse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sygil-Dev/muse-maskgit-pytorch/7d7c00c39e29af0585a1c616e4626528e13b61da/muse.png -------------------------------------------------------------------------------- /muse_maskgit_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from .muse_maskgit_pytorch import MaskGit, MaskGitTransformer, Muse, TokenCritic, Transformer 2 | from .trainers import MaskGitTrainer, VQGanVAETrainer, get_accelerator 3 | from .vqgan_vae import VQGanVAE 4 | from .vqgan_vae_taming import VQGanVAETaming 5 | 6 | __all__ = [ 7 | "VQGanVAE", 8 | "VQGanVAETaming", 9 | "Transformer", 10 | "MaskGit", 11 | "Muse", 12 | "MaskGitTransformer", 13 | "TokenCritic", 14 | "VQGanVAETrainer", 15 | "MaskGitTrainer", 16 | "get_accelerator", 17 | ] 18 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/attn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Sygil-Dev/muse-maskgit-pytorch/7d7c00c39e29af0585a1c616e4626528e13b61da/muse_maskgit_pytorch/attn/__init__.py -------------------------------------------------------------------------------- /muse_maskgit_pytorch/attn/attn_test.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import BoolTensor, FloatTensor, allclose, arange, manual_seed, no_grad, randn 3 | from torch.nn.functional import pad 4 | 5 | from muse_maskgit_pytorch.attn.ein_attn import Attention as EinAttn 6 | from muse_maskgit_pytorch.attn.xformers_attn import Attention as XformersAttn 7 | 8 | device = torch.device("cuda") 9 | dtype = torch.float32 10 | seed = 42 11 | 12 | # realistically this would be 320 in stable-diffusion, but I'm going smaller during testing 13 | vision_dim = 64 14 | 15 | attn_init_params = { 16 | "dim": vision_dim, 17 | "dim_head": 64, 18 | # realistically this would be at least 5 19 | "heads": 2, 20 | "cross_attend": True, 21 | "scale": 8, 22 | } 23 | 24 | with no_grad(): 25 | # seed RNG before we initialize any layers, so that both will end up with same params 26 | manual_seed(seed) 27 | ein_attn = EinAttn(**attn_init_params).to(device, dtype).eval() 28 | # commented-out scaled dot product attention because it didn't support flash attn, so we'll try with xformers instead. 
29 | # manual_seed(seed) 30 | # sdp_attn = SDPAttn(**attn_init_params).to(device, dtype).eval() 31 | manual_seed(seed) 32 | xfo_attn = XformersAttn(**attn_init_params).to(device, dtype).eval() 33 | 34 | batch_size = 2 35 | 36 | # realistically this would be 64**2 in stable-diffusion 37 | vision_tokens = 32**2 # 1024 38 | 39 | # generate rand on-CPU for cross-platform determinism of results 40 | x: FloatTensor = randn(batch_size, vision_tokens, vision_dim, dtype=dtype).to(device) 41 | 42 | # I've said text here simply as an example of something you could cross-attend to 43 | text_tokens = 16 # CLIP would be 77 44 | # for a *general* cross-attention Module: 45 | # kv_in_dim could differ from q_in_dim, but this attention Module requires x and context to have same dim. 46 | text_dim = vision_dim 47 | context: FloatTensor = randn(batch_size, text_tokens, text_dim, dtype=dtype).to(device) 48 | 49 | # attend to just the first two tokens in each text condition (e.g. if both were uncond, so [BOS, EOS] followed by PAD tokens) 50 | context_mask: BoolTensor = (arange(text_tokens, device=device) < 2).expand(batch_size, -1).contiguous() 51 | 52 | # for xformers cutlassF kernel: masks are only supported for keys whose lengths are multiples of 8: 53 | # https://gist.github.com/Birch-san/0c36d228e1d4b881a06d1c6e5289d569 54 | # so, we add whatever we feel like to the end of the key to extend it to a multiple of 8, 55 | # and add "discard" tokens to the mask to get rid of the excess 56 | # note: muse will add an extra "null" token to our context, so we'll account for that in advance 57 | mask_length = context_mask.shape[-1] + 1 58 | extra_tokens_needed = 8 - (mask_length % 8) 59 | # 0-pad mask to multiple of 8 tokens 60 | xfo_context_mask = pad(context_mask, (0, extra_tokens_needed)) 61 | # replicate-pad embedding to multiple of 8 tokens (mask will hide the extra tokens) 62 | xfo_context = pad(context, (0, 0, 0, extra_tokens_needed), "replicate") 63 | 64 | ein_result: FloatTensor = ein_attn.forward(x, context, context_mask) 65 | # sdp attn works, but only supports flash attn when context_mask is None. 
66 | # with sdp_kernel(enable_math=False): 67 | # sdp_result: FloatTensor = sdp_attn.forward(x, context, context_mask) 68 | xfo_attn: FloatTensor = xfo_attn.forward(x, xfo_context, xfo_context_mask) 69 | 70 | # default rtol 71 | rtol = 1e-5 72 | # atol would normally be 1e-8 73 | atol = 5e-7 74 | # assert allclose(ein_result, sdp_result, rtol=rtol, atol=atol), f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" 75 | if not allclose(ein_result, xfo_attn, rtol=rtol, atol=atol): 76 | raise RuntimeError( 77 | f"looks like attention implementations weren't equivalent, to tolerance rtol={rtol}, atol={atol}" 78 | ) 79 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/attn/ein_attn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from einops import rearrange, repeat 4 | from torch import einsum, nn 5 | 6 | 7 | # helpers 8 | def exists(val): 9 | return val is not None 10 | 11 | 12 | def l2norm(t): 13 | return F.normalize(t, dim=-1) 14 | 15 | 16 | class LayerNorm(nn.Module): 17 | def __init__(self, dim): 18 | super().__init__() 19 | self.gamma = nn.Parameter(torch.ones(dim)) 20 | self.register_buffer("beta", torch.zeros(dim)) 21 | 22 | def forward(self, x): 23 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 24 | 25 | 26 | class Attention(nn.Module): 27 | def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): 28 | super().__init__() 29 | self.scale = scale 30 | self.heads = heads 31 | inner_dim = dim_head * heads 32 | 33 | self.cross_attend = cross_attend 34 | self.norm = LayerNorm(dim) 35 | 36 | self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head)) 37 | 38 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 39 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 40 | 41 | self.q_scale = nn.Parameter(torch.ones(dim_head)) 42 | self.k_scale = nn.Parameter(torch.ones(dim_head)) 43 | 44 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 45 | 46 | def forward(self, x, context=None, context_mask=None): 47 | assert not (exists(context) ^ self.cross_attend) 48 | 49 | h = self.heads 50 | x = self.norm(x) 51 | 52 | kv_input = context if self.cross_attend else x 53 | 54 | q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)) 55 | 56 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) 57 | 58 | nk, nv = self.null_kv 59 | nk, nv = map(lambda t: repeat(t, "h 1 d -> b h 1 d", b=x.shape[0]), (nk, nv)) 60 | 61 | k = torch.cat((nk, k), dim=-2) 62 | v = torch.cat((nv, v), dim=-2) 63 | 64 | q, k = map(l2norm, (q, k)) 65 | q = q * self.q_scale 66 | k = k * self.k_scale 67 | 68 | sim = einsum("b h i d, b h j d -> b h i j", q, k) * self.scale 69 | 70 | if exists(context_mask): 71 | context_mask = rearrange(context_mask, "b j -> b 1 1 j") 72 | context_mask = F.pad(context_mask, (1, 0), value=True) 73 | 74 | mask_value = -torch.finfo(sim.dtype).max 75 | sim = sim.masked_fill(~context_mask, mask_value) 76 | 77 | attn = sim.softmax(dim=-1) 78 | out = einsum("b h i j, b h j d -> b h i d", attn, v) 79 | 80 | out = rearrange(out, "b h n d -> b n (h d)") 81 | return self.to_out(out) 82 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/attn/sdp_attn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import 
torch.nn.functional as F 5 | from einops import rearrange, repeat 6 | from torch import BoolTensor, FloatTensor, nn 7 | from torch.nn.functional import scaled_dot_product_attention 8 | 9 | 10 | def l2norm(t): 11 | return F.normalize(t, dim=-1) 12 | 13 | 14 | class LayerNorm(nn.Module): 15 | def __init__(self, dim): 16 | super().__init__() 17 | self.gamma = nn.Parameter(torch.ones(dim)) 18 | self.register_buffer("beta", torch.zeros(dim)) 19 | 20 | def forward(self, x): 21 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 22 | 23 | 24 | class Attention(nn.Module): 25 | def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): 26 | super().__init__() 27 | self.heads = heads 28 | inner_dim = dim_head * heads 29 | 30 | self.cross_attend = cross_attend 31 | self.norm = LayerNorm(dim) 32 | 33 | self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head)) 34 | 35 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 37 | 38 | typical_scale = dim_head**-0.5 39 | scale_ratio = scale / typical_scale 40 | self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio)) 41 | self.k_scale = nn.Parameter(torch.ones(dim_head)) 42 | 43 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 44 | 45 | def forward( 46 | self, x: FloatTensor, context: Optional[FloatTensor] = None, context_mask: Optional[BoolTensor] = None 47 | ): 48 | assert (context is None) != self.cross_attend 49 | 50 | h = self.heads 51 | # TODO: you could fuse this layernorm with the linear that follows it, e.g. via TransformerEngine 52 | x = self.norm(x) 53 | 54 | kv_input = context if self.cross_attend else x 55 | 56 | # TODO: to_q and to_kvs could be combined into one to_qkv 57 | q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)) 58 | 59 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) 60 | 61 | nk, nv = self.null_kv 62 | nk, nv = map(lambda t: repeat(t, "h 1 d -> b h 1 d", b=x.shape[0]), (nk, nv)) 63 | 64 | k = torch.cat((nk, k), dim=-2) 65 | v = torch.cat((nv, v), dim=-2) 66 | 67 | q, k = map(l2norm, (q, k)) 68 | q = q * self.q_scale 69 | k = k * self.k_scale 70 | 71 | if context_mask is not None: 72 | context_mask = rearrange(context_mask, "b j -> b 1 1 j") 73 | context_mask = F.pad(context_mask, (1, 0), value=True) 74 | 75 | out: FloatTensor = scaled_dot_product_attention(q, k, v, context_mask) 76 | 77 | out = rearrange(out, "b h n d -> b n (h d)") 78 | return self.to_out(out) 79 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/attn/xformers_attn.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from einops import rearrange, repeat 6 | from torch import BoolTensor, FloatTensor, nn 7 | from xformers.ops import memory_efficient_attention 8 | 9 | 10 | def l2norm(t): 11 | return F.normalize(t, dim=-1) 12 | 13 | 14 | class LayerNorm(nn.Module): 15 | def __init__(self, dim): 16 | super().__init__() 17 | self.gamma = nn.Parameter(torch.ones(dim)) 18 | self.register_buffer("beta", torch.zeros(dim)) 19 | 20 | def forward(self, x): 21 | return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta) 22 | 23 | 24 | class Attention(nn.Module): 25 | def __init__(self, dim, dim_head=64, heads=8, cross_attend=False, scale=8): 26 | super().__init__() 27 | self.heads = heads 28 | inner_dim = dim_head * heads 29 | 30 | self.cross_attend = 
cross_attend 31 | self.norm = LayerNorm(dim) 32 | 33 | self.null_kv = nn.Parameter(torch.randn(2, heads, 1, dim_head)) 34 | 35 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 37 | 38 | typical_scale = dim_head**-0.5 39 | scale_ratio = scale / typical_scale 40 | self.q_scale = nn.Parameter(torch.full((dim_head,), scale_ratio)) 41 | self.k_scale = nn.Parameter(torch.ones(dim_head)) 42 | 43 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 44 | 45 | def forward( 46 | self, x: FloatTensor, context: Optional[FloatTensor] = None, context_mask: Optional[BoolTensor] = None 47 | ): 48 | assert (context is None) != self.cross_attend 49 | 50 | h = self.heads 51 | # TODO: you could fuse this layernorm with the linear that follows it, e.g. via TransformerEngine 52 | x = self.norm(x) 53 | 54 | kv_input = context if self.cross_attend else x 55 | 56 | # TODO: to_q and to_kvs could be combined into one to_qkv 57 | q, k, v = (self.to_q(x), *self.to_kv(kv_input).chunk(2, dim=-1)) 58 | 59 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q, k, v)) 60 | 61 | nk, nv = self.null_kv 62 | nk, nv = map(lambda t: repeat(t, "h 1 d -> b 1 h d", b=x.shape[0]), (nk, nv)) 63 | 64 | k = torch.cat((nk, k), dim=-3) 65 | v = torch.cat((nv, v), dim=-3) 66 | 67 | q, k = map(l2norm, (q, k)) 68 | q = q * self.q_scale 69 | k = k * self.k_scale 70 | 71 | if context_mask is None: 72 | attn_bias = None 73 | else: 74 | context_mask = F.pad(context_mask, (1, 0), value=True) 75 | context_mask = rearrange(context_mask, "b j -> b 1 1 j") 76 | attn_bias = torch.where(context_mask is True, 0.0, -10000.0) 77 | attn_bias = attn_bias.expand(-1, h, q.size(1), -1) 78 | 79 | out: FloatTensor = memory_efficient_attention(q, k, v, attn_bias) 80 | 81 | out = rearrange(out, "b n h d -> b n (h d)") 82 | return self.to_out(out) 83 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | import sys 5 | import time 6 | from pathlib import Path 7 | from threading import Thread 8 | 9 | import datasets 10 | import torch 11 | from datasets import Image, load_from_disk 12 | from PIL import ( 13 | Image as pImage, 14 | ImageFile, 15 | ) 16 | from torch.utils.data import DataLoader, Dataset, random_split 17 | from torchvision import transforms as T 18 | 19 | try: 20 | import torch_xla 21 | import torch_xla.core.xla_model as xm 22 | from tqdm_loggable.auto import tqdm 23 | except ImportError: 24 | from tqdm import tqdm 25 | 26 | from io import BytesIO 27 | 28 | import requests 29 | from transformers import T5Tokenizer 30 | 31 | from muse_maskgit_pytorch.t5 import MAX_LENGTH 32 | 33 | ImageFile.LOAD_TRUNCATED_IMAGES = True 34 | pImage.MAX_IMAGE_PIXELS = None 35 | 36 | 37 | class ImageDataset(Dataset): 38 | def __init__( 39 | self, 40 | dataset, 41 | image_size, 42 | image_column="image", 43 | flip=True, 44 | center_crop=True, 45 | stream=False, 46 | using_taming=False, 47 | random_crop=False, 48 | alpha_channel=True, 49 | ): 50 | super().__init__() 51 | self.dataset = dataset 52 | self.image_column = image_column 53 | self.stream = stream 54 | transform_list = [ 55 | T.Resize(image_size), 56 | ] 57 | 58 | if flip: 59 | transform_list.append(T.RandomHorizontalFlip()) 60 | if center_crop and not random_crop: 61 | transform_list.append(T.CenterCrop(image_size)) 62 | if random_crop: 63 | 
transform_list.append(T.RandomCrop(image_size, pad_if_needed=True)) 64 | if alpha_channel: 65 | transform_list.append(T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" else img)) 66 | else: 67 | transform_list.append(T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img)) 68 | 69 | transform_list.append(T.ToTensor()) 70 | self.transform = T.Compose(transform_list) 71 | self.using_taming = using_taming 72 | 73 | def __len__(self): 74 | if not self.stream: 75 | return len(self.dataset) 76 | else: 77 | raise AssertionError("Streaming doesnt support grabbing dataset length") 78 | 79 | def __getitem__(self, index): 80 | image = self.dataset[index][self.image_column] 81 | if self.using_taming: 82 | return self.transform(image) - 0.5 83 | else: 84 | return self.transform(image) 85 | 86 | 87 | class ImageTextDataset(ImageDataset): 88 | def __init__( 89 | self, 90 | dataset, 91 | image_size: int, 92 | tokenizer: T5Tokenizer, 93 | image_column="image", 94 | caption_column="caption", 95 | flip=True, 96 | center_crop=True, 97 | stream=False, 98 | using_taming=False, 99 | random_crop=False, 100 | ): 101 | super().__init__( 102 | dataset, 103 | image_size=image_size, 104 | image_column=image_column, 105 | flip=flip, 106 | stream=stream, 107 | center_crop=center_crop, 108 | using_taming=using_taming, 109 | random_crop=random_crop, 110 | ) 111 | self.caption_column: str = caption_column 112 | self.tokenizer: T5Tokenizer = tokenizer 113 | 114 | def __getitem__(self, index): 115 | image = self.dataset[index][self.image_column] 116 | descriptions = self.dataset[index][self.caption_column] 117 | if self.caption_column is None or descriptions is None: 118 | text = "" 119 | elif isinstance(descriptions, list): 120 | if len(descriptions) == 0: 121 | text = "" 122 | else: 123 | text = random.choice(descriptions) 124 | else: 125 | text = descriptions 126 | # max length from the paper 127 | encoded = self.tokenizer.batch_encode_plus( 128 | [str(text)], 129 | return_tensors="pt", 130 | padding="max_length", 131 | max_length=MAX_LENGTH, 132 | truncation=True, 133 | ) 134 | 135 | input_ids = encoded.input_ids 136 | attn_mask = encoded.attention_mask 137 | 138 | if self.using_taming: 139 | return self.transform(image) - 0.5, input_ids[0], attn_mask[0] 140 | else: 141 | return self.transform(image), input_ids[0], attn_mask[0] 142 | 143 | 144 | class URLTextDataset(ImageDataset): 145 | def __init__( 146 | self, 147 | dataset, 148 | image_size: int, 149 | tokenizer: T5Tokenizer, 150 | image_column="image", 151 | caption_column="caption", 152 | flip=True, 153 | center_crop=True, 154 | using_taming=True, 155 | ): 156 | super().__init__( 157 | dataset, 158 | image_size=image_size, 159 | image_column=image_column, 160 | flip=flip, 161 | center_crop=center_crop, 162 | using_taming=using_taming, 163 | ) 164 | self.caption_column: str = caption_column 165 | self.tokenizer: T5Tokenizer = tokenizer 166 | 167 | def __getitem__(self, index): 168 | try: 169 | image = pImage.open(BytesIO(requests.get(self.dataset[index][self.image_column]).content)) 170 | except ConnectionError: 171 | try: 172 | print("Image request failure, attempting next image") 173 | index += 1 174 | 175 | image = pImage.open(BytesIO(requests.get(self.dataset[index][self.image_column]).content)) 176 | except ConnectionError: 177 | raise ConnectionError("Unable to request image from the Dataset") 178 | 179 | descriptions = self.dataset[index][self.caption_column] 180 | if self.caption_column is None or descriptions is None: 181 | text = 
"" 182 | elif isinstance(descriptions, list): 183 | if len(descriptions) == 0: 184 | text = "" 185 | else: 186 | text = random.choice(descriptions) 187 | else: 188 | text = descriptions 189 | # max length from the paper 190 | encoded = self.tokenizer.batch_encode_plus( 191 | [str(text)], 192 | return_tensors="pt", 193 | padding="max_length", 194 | max_length=MAX_LENGTH, 195 | truncation=True, 196 | ) 197 | 198 | input_ids = encoded.input_ids 199 | attn_mask = encoded.attention_mask 200 | if self.using_taming: 201 | return self.transform(image) - 0.5, input_ids[0], attn_mask[0] 202 | else: 203 | return self.transform(image), input_ids[0], attn_mask[0] 204 | 205 | 206 | class LocalTextImageDataset(Dataset): 207 | def __init__( 208 | self, 209 | path, 210 | image_size, 211 | tokenizer, 212 | flip=True, 213 | center_crop=True, 214 | using_taming=False, 215 | random_crop=False, 216 | alpha_channel=False, 217 | ): 218 | super().__init__() 219 | self.tokenizer = tokenizer 220 | self.using_taming = using_taming 221 | 222 | print("Building dataset...") 223 | 224 | extensions = ["jpg", "jpeg", "png", "webp"] 225 | self.image_paths = [] 226 | self.caption_pair = [] 227 | self.images = [] 228 | 229 | for ext in extensions: 230 | self.image_paths.extend(list(Path(path).rglob(f"*.{ext}"))) 231 | 232 | random.shuffle(self.image_paths) 233 | for image_path in tqdm(self.image_paths): 234 | # check image size and ignore images with 0 byte. 235 | if os.path.getsize(image_path) == 0: 236 | continue 237 | caption_path = image_path.with_suffix(".txt") 238 | if os.path.exists(str(caption_path)): 239 | captions = str(caption_path) 240 | else: 241 | captions = "" 242 | self.images.append(image_path) 243 | self.caption_pair.append(captions) 244 | 245 | transform_list = [ 246 | T.Resize(image_size), 247 | ] 248 | if flip: 249 | transform_list.append(T.RandomHorizontalFlip()) 250 | if center_crop and not random_crop: 251 | transform_list.append(T.CenterCrop(image_size)) 252 | if random_crop: 253 | transform_list.append(T.RandomCrop(image_size, pad_if_needed=True)) 254 | if alpha_channel: 255 | transform_list.append(T.Lambda(lambda img: img.convert("RGBA") if img.mode != "RGBA" else img)) 256 | else: 257 | transform_list.append(T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img)) 258 | transform_list.append(T.ToTensor()) 259 | self.transform = T.Compose(transform_list) 260 | 261 | def __len__(self): 262 | return len(self.caption_pair) 263 | 264 | def __getitem__(self, index): 265 | image = self.images[index] 266 | image = pImage.open(image) 267 | descriptions = self.caption_pair[index] 268 | if descriptions is None or descriptions == "": 269 | text = "" 270 | else: 271 | text = Path(descriptions).read_text(encoding="utf-8").split("\n") 272 | 273 | # max length from the paper 274 | encoded = self.tokenizer.batch_encode_plus( 275 | [str(text)], 276 | return_tensors="pt", 277 | padding="max_length", 278 | max_length=MAX_LENGTH, 279 | truncation=True, 280 | ) 281 | 282 | input_ids = encoded.input_ids 283 | attn_mask = encoded.attention_mask 284 | if self.using_taming: 285 | return self.transform(image) - 0.5, input_ids[0], attn_mask[0] 286 | else: 287 | return self.transform(image), input_ids[0], attn_mask[0] 288 | 289 | 290 | def get_directory_size(path): 291 | total_size = 0 292 | for dirpath, dirnames, filenames in os.walk(path): 293 | for f in filenames: 294 | fp = os.path.join(dirpath, f) 295 | total_size += os.path.getsize(fp) 296 | return total_size 297 | 298 | 299 | def 
save_dataset_with_progress(dataset, save_path): 300 | # Estimate the total size of the dataset in bytes 301 | total_size = sys.getsizeof(dataset) 302 | 303 | # Start saving the dataset in a separate thread 304 | save_thread = Thread(target=dataset.save_to_disk, args=(save_path,)) 305 | save_thread.start() 306 | 307 | # Create a tqdm progress bar and update it periodically 308 | with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: 309 | while save_thread.is_alive(): 310 | if os.path.exists(save_path): 311 | size = get_directory_size(save_path) 312 | # Update the progress bar based on the current size of the saved file 313 | pbar.update(size - pbar.n) # Update by the difference between current and previous size 314 | time.sleep(1) 315 | 316 | 317 | def get_dataset_from_dataroot( 318 | data_root, 319 | image_column="image", 320 | caption_column="caption", 321 | save_path="dataset", 322 | save=True, 323 | ): 324 | # Check if data_root is a symlink and resolve it to its target location if it is 325 | if os.path.islink(data_root): 326 | data_root = os.path.realpath(data_root) 327 | 328 | if os.path.exists(save_path): 329 | # Get the modified time of save_path 330 | save_path_mtime = os.stat(save_path).st_mtime 331 | 332 | if save: 333 | # Traverse the directory tree of data_root and get the modified time of all files and subdirectories 334 | print("Checking modified date of all the files and subdirectories in the dataset folder.") 335 | data_root_mtime = max( 336 | os.stat(os.path.join(root, f)).st_mtime 337 | for root, dirs, files in os.walk(data_root) 338 | for f in files + dirs 339 | ) 340 | 341 | # Check if data_root is newer than save_path 342 | if data_root_mtime > save_path_mtime: 343 | print( 344 | "The data_root folder has being updated recently. Removing previously saved dataset and updating it." 345 | ) 346 | shutil.rmtree(save_path, ignore_errors=True) 347 | else: 348 | print("The dataset is up-to-date. Loading...") 349 | # Load the dataset from save_path if it is up-to-date 350 | return load_from_disk(save_path) 351 | 352 | extensions = ["jpg", "jpeg", "png", "webp"] 353 | image_paths = [] 354 | 355 | for ext in extensions: 356 | image_paths.extend(list(Path(data_root).rglob(f"*.{ext}"))) 357 | 358 | random.shuffle(image_paths) 359 | data_dict = {image_column: [], caption_column: []} 360 | for image_path in tqdm(image_paths): 361 | # check image size and ignore images with 0 byte. 
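        # captions, when present, come from a sidecar .txt file with the same stem as the
        # image (one caption per line); images without a caption file get an empty caption list.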
362 | if os.path.getsize(image_path) == 0: 363 | continue 364 | caption_path = image_path.with_suffix(".txt") 365 | if os.path.exists(str(caption_path)): 366 | captions = caption_path.read_text(encoding="utf-8").split("\n") 367 | captions = list(filter(lambda t: len(t) > 0, captions)) 368 | else: 369 | captions = [] 370 | image_path = str(image_path) 371 | data_dict[image_column].append(image_path) 372 | data_dict[caption_column].append(captions) 373 | dataset = datasets.Dataset.from_dict(data_dict) 374 | dataset = dataset.cast_column(image_column, Image()) 375 | 376 | if save: 377 | save_dataset_with_progress(dataset, save_path) 378 | 379 | return dataset 380 | 381 | 382 | def split_dataset_into_dataloaders(dataset, valid_frac=0.05, seed=42, batch_size=1): 383 | print(f"Dataset length: {len(dataset)} samples") 384 | if valid_frac > 0: 385 | train_size = int((1 - valid_frac) * len(dataset)) 386 | valid_size = len(dataset) - train_size 387 | print(f"Splitting dataset into {train_size} training samples and {valid_size} validation samples") 388 | split_generator = torch.Generator().manual_seed(seed) 389 | train_dataset, validation_dataset = random_split( 390 | dataset, 391 | [train_size, valid_size], 392 | generator=split_generator, 393 | ) 394 | else: 395 | print("Using shared dataset for training and validation") 396 | train_dataset = dataset 397 | validation_dataset = dataset 398 | 399 | dataloader = DataLoader( 400 | train_dataset, 401 | batch_size=batch_size, 402 | shuffle=True, 403 | ) 404 | 405 | validation_dataloader = DataLoader( 406 | validation_dataset, 407 | batch_size=batch_size, 408 | shuffle=False, 409 | ) 410 | return dataloader, validation_dataloader 411 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/distributed_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions for optional distributed execution. 3 | 4 | To use, 5 | 1. set the `BACKENDS` to the ones you want to make available, 6 | 2. in the script, wrap the argument parser with `wrap_arg_parser`, 7 | 3. in the script, set and use the backend by calling 8 | `set_backend_from_args`. 9 | 10 | You can check whether a backend is in use with the `using_backend` 11 | function. 12 | """ 13 | 14 | 15 | is_distributed = None 16 | """Whether we are distributed.""" 17 | backend = None 18 | """Backend in usage.""" 19 | 20 | 21 | def require_set_backend(): 22 | """Raise an `AssertionError` when the backend has not been set.""" 23 | assert backend is not None, ( 24 | "distributed backend is not set. Please call " 25 | "`distributed_utils.set_backend_from_args` at the start of your script" 26 | ) 27 | 28 | 29 | def using_backend(test_backend): 30 | """Return whether the backend is set to `test_backend`. 31 | 32 | `test_backend` may be a string of the name of the backend or 33 | its class. 
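
    Example (with a hypothetical backend name, purely for illustration):

        if using_backend("horovod"):
            ...  # take the backend-specific code path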
34 | """ 35 | require_set_backend() 36 | if isinstance(test_backend, str): 37 | return backend.BACKEND_NAME == test_backend 38 | return isinstance(backend, test_backend) 39 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .attention import CrossAttention, MemoryEfficientCrossAttention 2 | from .mlp import SwiGLU, SwiGLUFFN, SwiGLUFFNFused 3 | 4 | __all__ = [ 5 | "SwiGLU", 6 | "SwiGLUFFN", 7 | "SwiGLUFFNFused", 8 | "CrossAttention", 9 | "MemoryEfficientCrossAttention", 10 | ] 11 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | from typing import Any, Callable, Optional 3 | 4 | from einops import rearrange 5 | from torch import nn 6 | 7 | try: 8 | from xformers.ops import memory_efficient_attention 9 | except ImportError: 10 | memory_efficient_attention = None 11 | 12 | 13 | def exists(x): 14 | return x is not None 15 | 16 | 17 | def default(val, d): 18 | if exists(val): 19 | return val 20 | return d() if isfunction(d) else d 21 | 22 | 23 | class CrossAttention(nn.Module): 24 | def __init__( 25 | self, 26 | query_dim, 27 | context_dim=None, 28 | heads=8, 29 | dim_head=64, 30 | dropout=0.0, 31 | ): 32 | super().__init__() 33 | inner_dim = dim_head * heads 34 | context_dim = default(context_dim, query_dim) 35 | 36 | self.scale = dim_head**-0.5 37 | self.heads = heads 38 | 39 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 40 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 41 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 42 | 43 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 44 | 45 | def forward(self, x, context=None): 46 | h = self.heads 47 | 48 | q = self.to_q(x) 49 | context = default(context, x) 50 | k = self.to_k(context) 51 | v = self.to_v(context) 52 | 53 | q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) 54 | q = q * self.scale 55 | 56 | sim = q @ k.transpose(-2, -1) 57 | sim = sim.softmax(dim=-1) 58 | 59 | out = sim @ v 60 | out = rearrange(out, "(b h) n d -> b n (h d)", h=h) 61 | return self.to_out(out) 62 | 63 | 64 | class MemoryEfficientCrossAttention(nn.Module): 65 | # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 66 | def __init__( 67 | self, 68 | query_dim, 69 | context_dim=None, 70 | heads=8, 71 | dim_head=64, 72 | dropout=0.0, 73 | ): 74 | super().__init__() 75 | inner_dim = dim_head * heads 76 | context_dim = default(context_dim, query_dim) 77 | 78 | self.heads = heads 79 | self.dim_head = dim_head 80 | 81 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 82 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 83 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 84 | 85 | self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) 86 | self.attention_op: Optional[Callable] = None 87 | 88 | def forward(self, x, context=None): 89 | q = self.to_q(x) 90 | context = default(context, x) 91 | k = self.to_k(context) 92 | v = self.to_v(context) 93 | 94 | b, _, _ = q.shape 95 | q, k, v = map( 96 | lambda t: t.unsqueeze(3) 97 | .reshape(b, t.shape[1], self.heads, self.dim_head) 98 | .permute(0, 2, 1, 
3) 99 | .reshape(b * self.heads, t.shape[1], self.dim_head) 100 | .contiguous(), 101 | (q, k, v), 102 | ) 103 | 104 | out = memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) 105 | 106 | out = ( 107 | out.unsqueeze(0) 108 | .reshape(b, self.heads, out.shape[1], self.dim_head) 109 | .permute(0, 2, 1, 3) 110 | .reshape(b, out.shape[1], self.heads * self.dim_head) 111 | ) 112 | return self.to_out(out) 113 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/modules/mlp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from typing import Optional 8 | 9 | import torch.nn.functional as F 10 | from torch import Tensor, nn 11 | 12 | 13 | class SwiGLUFFN(nn.Module): 14 | def __init__( 15 | self, 16 | in_features: int, 17 | hidden_features: Optional[int] = None, 18 | out_features: Optional[int] = None, 19 | bias: bool = True, 20 | ) -> None: 21 | super().__init__() 22 | out_features = out_features or in_features 23 | hidden_features = hidden_features or in_features 24 | self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) 25 | self.w3 = nn.Linear(hidden_features, out_features, bias=bias) 26 | 27 | def forward(self, x: Tensor) -> Tensor: 28 | x12 = self.w12(x) 29 | x1, x2 = x12.chunk(2, dim=-1) 30 | hidden = F.silu(x1) * x2 31 | return self.w3(hidden) 32 | 33 | 34 | try: 35 | from xformers.ops import SwiGLU 36 | except ImportError: 37 | SwiGLU = SwiGLUFFN 38 | 39 | 40 | class SwiGLUFFNFused(SwiGLU): 41 | def __init__( 42 | self, 43 | in_features: int, 44 | hidden_features: Optional[int] = None, 45 | out_features: Optional[int] = None, 46 | bias: bool = True, 47 | ) -> None: 48 | out_features = out_features or in_features 49 | hidden_features = hidden_features or in_features 50 | hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 51 | super().__init__( 52 | in_features=in_features, 53 | hidden_features=hidden_features, 54 | out_features=out_features, 55 | bias=bias, 56 | ) 57 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/t5.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from dataclasses import dataclass, field 3 | from functools import cached_property 4 | from os import PathLike 5 | from typing import Dict, List, Optional, Tuple, Union 6 | 7 | import torch 8 | from beartype import beartype 9 | from torch import Tensor 10 | from transformers import T5Config, T5EncoderModel, T5Tokenizer 11 | 12 | # disable t5 warnings and a few others to keep the console clean and nice. 
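# NOTE: the blanket warnings.filterwarnings("ignore") below silences *every* warning
# process-wide, not only the T5 ones; a narrower (untested) alternative would be something
# like warnings.filterwarnings("ignore", category=UserWarning, module="transformers").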
13 | warnings.filterwarnings("ignore")
14 | 
15 | 
16 | # dataclass for T5 model info
17 | @dataclass
18 | class T5ModelInfo:
19 |     name: str
20 |     cache_dir: Optional[PathLike] = None
21 |     dtype: Optional[torch.dtype] = torch.float32
22 |     config: T5Config = field(init=False)
23 | 
24 |     def __post_init__(self):
25 |         self.config = T5Config.from_pretrained(self.name, cache_dir=self.cache_dir)
26 |         self._model = None
27 |         self._tokenizer = None
28 | 
29 |     # Using cached_property to avoid loading the model/tokenizer until needed
30 |     @cached_property
31 |     def model(self) -> T5EncoderModel:
32 |         if not self._model:
33 |             self._model = T5EncoderModel.from_pretrained(
34 |                 self.name, cache_dir=self.cache_dir, torch_dtype=self.dtype
35 |             )
36 |         return self._model
37 | 
38 |     @cached_property
39 |     def tokenizer(self) -> T5Tokenizer:
40 |         if not self._tokenizer:
41 |             self._tokenizer = T5Tokenizer.from_pretrained(
42 |                 self.name, cache_dir=self.cache_dir, torch_dtype=self.dtype
43 |             )
44 |         return self._tokenizer
45 | 
46 | 
47 | # config
48 | MAX_LENGTH = 512
49 | DEFAULT_T5_NAME = "google/t5-v1_1-base"
50 | T5_OBJECTS: Dict[str, T5ModelInfo] = {}
51 | 
52 | 
53 | def get_model_and_tokenizer(
54 |     name: str, cache_path: Optional[PathLike] = None, dtype: torch.dtype = torch.float32
55 | ) -> Tuple[T5EncoderModel, T5Tokenizer]:
56 |     global T5_OBJECTS
57 |     if name not in T5_OBJECTS.keys():
58 |         T5_OBJECTS[name] = T5ModelInfo(name=name, cache_dir=cache_path, dtype=dtype)
59 |     return T5_OBJECTS[name].model, T5_OBJECTS[name].tokenizer
60 | 
61 | 
62 | def get_encoded_dim(
63 |     name: str, cache_path: Optional[PathLike] = None, dtype: torch.dtype = torch.float32
64 | ) -> int:
65 |     global T5_OBJECTS
66 |     if name not in T5_OBJECTS.keys():
67 |         T5_OBJECTS[name] = T5ModelInfo(name=name, cache_dir=cache_path, dtype=dtype)
68 |     return T5_OBJECTS[name].config.d_model
69 | 
70 | 
71 | # encoding text
72 | @beartype
73 | def t5_encode_text_from_encoded(
74 |     input_ids: Tensor,
75 |     attn_mask: Tensor,
76 |     t5: T5EncoderModel,
77 |     output_device: Optional[Union[torch.device, str]] = None,
78 | ) -> Tensor:
79 |     device = t5.device
80 |     input_ids, attn_mask = input_ids.to(device), attn_mask.to(device)
81 |     with torch.no_grad():
82 |         output = t5(input_ids=input_ids, attention_mask=attn_mask)
83 |         encoded_text = output.last_hidden_state.detach()
84 | 
85 |     attn_mask = attn_mask.bool()
86 |     encoded_text: Tensor = encoded_text.masked_fill(~attn_mask[..., None], 0.0)
87 |     return encoded_text if output_device is None else encoded_text.to(output_device)
88 | 
89 | 
90 | @beartype
91 | def t5_encode_text(
92 |     texts: Union[str, List[str]],
93 |     tokenizer: T5Tokenizer,
94 |     t5: T5EncoderModel,
95 |     output_device: Optional[Union[torch.device, str]] = None,
96 | ) -> Tensor:
97 |     if isinstance(texts, str):
98 |         texts = [texts]
99 | 
100 |     encoded = tokenizer.batch_encode_plus(
101 |         texts,
102 |         return_tensors="pt",
103 |         padding="max_length",
104 |         max_length=MAX_LENGTH,
105 |         truncation=True,
106 |     )
107 |     return t5_encode_text_from_encoded(encoded["input_ids"], encoded["attention_mask"], t5, output_device)
108 | 
--------------------------------------------------------------------------------
/muse_maskgit_pytorch/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | from .base_accelerated_trainer import get_accelerator
2 | from .maskgit_trainer import MaskGitTrainer
3 | from .vqvae_trainers import VQGanVAETrainer
4 | 
5 | __all__ = [
6 |     "VQGanVAETrainer",
7 |     "MaskGitTrainer",
8 |
"get_accelerator", 9 | ] 10 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/trainers/base_accelerated_trainer.py: -------------------------------------------------------------------------------- 1 | from os import PathLike 2 | from pathlib import Path 3 | from shutil import rmtree 4 | from typing import Optional, Union 5 | 6 | import accelerate 7 | import numpy as np 8 | import torch 9 | from accelerate import Accelerator, DistributedDataParallelKwargs, DistributedType 10 | from beartype import beartype 11 | from datasets import Dataset 12 | from lion_pytorch import Lion 13 | from PIL import Image 14 | from torch import nn 15 | from torch.optim import Adam, AdamW, Optimizer 16 | from torch.utils.data import DataLoader, random_split 17 | from torch_optimizer import ( 18 | PID, 19 | QHM, 20 | SGDP, 21 | SGDW, 22 | SWATS, 23 | AccSGD, 24 | AdaBound, 25 | AdaMod, 26 | AdamP, 27 | AggMo, 28 | DiffGrad, 29 | Lamb, 30 | NovoGrad, 31 | QHAdam, 32 | RAdam, 33 | Shampoo, 34 | Yogi, 35 | ) 36 | from transformers.optimization import Adafactor 37 | 38 | try: 39 | from accelerate.data_loader import MpDeviceLoaderWrapper 40 | except ImportError: 41 | MpDeviceLoaderWrapper = DataLoader 42 | pass 43 | 44 | try: 45 | from bitsandbytes.optim import Adam8bit, AdamW8bit, Lion8bit 46 | except ImportError: 47 | Adam8bit = AdamW8bit = Lion8bit = None 48 | 49 | try: 50 | import wandb 51 | except ImportError: 52 | wandb = None 53 | 54 | ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) 55 | 56 | 57 | def noop(*args, **kwargs): 58 | pass 59 | 60 | 61 | # helper functions 62 | 63 | 64 | def identity(t, *args, **kwargs): 65 | return t 66 | 67 | 68 | def cast_tuple(t): 69 | return t if isinstance(t, (tuple, list)) else (t,) 70 | 71 | 72 | def yes_or_no(question): 73 | answer = input(f"{question} (y/n) ") 74 | return answer.lower() in ("yes", "y") 75 | 76 | 77 | def pair(val): 78 | return val if isinstance(val, tuple) else (val, val) 79 | 80 | 81 | def convert_image_to_fn(img_type, image): 82 | if image.mode != img_type: 83 | return image.convert(img_type) 84 | return image 85 | 86 | 87 | # image related helpers fnuctions and dataset 88 | 89 | 90 | def get_accelerator(*args, **kwargs): 91 | kwargs_handlers = kwargs.get("kwargs_handlers", []) 92 | if ddp_kwargs not in kwargs_handlers: 93 | kwargs_handlers.append(ddp_kwargs) 94 | kwargs.update(kwargs_handlers=kwargs_handlers) 95 | accelerator = Accelerator(*args, **kwargs) 96 | return accelerator 97 | 98 | 99 | def split_dataset(dataset: Dataset, valid_frac: float, accelerator: Accelerator, seed: int = 42): 100 | if valid_frac > 0: 101 | train_size = int((1 - valid_frac) * len(dataset)) 102 | valid_size = len(dataset) - train_size 103 | ds, valid_ds = random_split( 104 | dataset, 105 | [train_size, valid_size], 106 | generator=torch.Generator().manual_seed(seed), 107 | ) 108 | accelerator.print( 109 | f"training with dataset of {len(ds)} samples and validating with randomly splitted {len(valid_ds)} samples" 110 | ) 111 | else: 112 | valid_ds = ds 113 | accelerator.print(f"training with shared training and valid dataset of {len(ds)} samples") 114 | return ds, valid_ds 115 | 116 | 117 | # main trainer class 118 | 119 | 120 | def get_optimizer( 121 | use_8bit_adam: bool, 122 | optimizer: str, 123 | parameters: dict, 124 | lr: float, 125 | weight_decay: float, 126 | optimizer_kwargs: dict = {}, 127 | ): 128 | if use_8bit_adam is True and Adam8bit is None: 129 | print( 130 | "Please install 
bitsandbytes to use 8-bit optimizers. You can do so by running `pip install " 131 | "bitsandbytes` | Defaulting to non 8-bit equivalent..." 132 | ) 133 | 134 | bnb_supported_optims = ["Adam", "AdamW", "Lion"] 135 | if use_8bit_adam and optimizer not in bnb_supported_optims: 136 | print(f"8bit is not supported by the {optimizer} optimizer, Using standard {optimizer} instead.") 137 | 138 | # optimizers 139 | if optimizer == "Adam": 140 | return ( 141 | Adam8bit(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs) 142 | if use_8bit_adam and Adam8bit is not None 143 | else Adam(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs) 144 | ) 145 | elif optimizer == "AdamW": 146 | return ( 147 | AdamW8bit(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs) 148 | if use_8bit_adam and AdamW8bit is not None 149 | else AdamW(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs) 150 | ) 151 | elif optimizer == "Lion": 152 | # Reckless reuse of the use_8bit_adam flag 153 | return ( 154 | Lion8bit(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs) 155 | if use_8bit_adam and Lion8bit is not None 156 | else Lion(parameters, lr=lr, weight_decay=weight_decay, **optimizer_kwargs) 157 | ) 158 | elif optimizer == "Adafactor": 159 | return Adafactor( 160 | parameters, 161 | lr=lr, 162 | weight_decay=weight_decay, 163 | relative_step=False, 164 | scale_parameter=False, 165 | **optimizer_kwargs, 166 | ) 167 | elif optimizer == "AccSGD": 168 | return AccSGD(parameters, lr=lr, weight_decay=weight_decay) 169 | elif optimizer == "AdaBound": 170 | return AdaBound(parameters, lr=lr, weight_decay=weight_decay) 171 | elif optimizer == "AdaMod": 172 | return AdaMod(parameters, lr=lr, weight_decay=weight_decay) 173 | elif optimizer == "AdamP": 174 | return AdamP(parameters, lr=lr, weight_decay=weight_decay) 175 | elif optimizer == "AggMo": 176 | return AggMo(parameters, lr=lr, weight_decay=weight_decay) 177 | elif optimizer == "DiffGrad": 178 | return DiffGrad(parameters, lr=lr, weight_decay=weight_decay) 179 | elif optimizer == "Lamb": 180 | return Lamb(parameters, lr=lr, weight_decay=weight_decay) 181 | elif optimizer == "NovoGrad": 182 | return NovoGrad(parameters, lr=lr, weight_decay=weight_decay) 183 | elif optimizer == "PID": 184 | return PID(parameters, lr=lr, weight_decay=weight_decay) 185 | elif optimizer == "QHAdam": 186 | return QHAdam(parameters, lr=lr, weight_decay=weight_decay) 187 | elif optimizer == "QHM": 188 | return QHM(parameters, lr=lr, weight_decay=weight_decay) 189 | elif optimizer == "RAdam": 190 | return RAdam(parameters, lr=lr, weight_decay=weight_decay) 191 | elif optimizer == "SGDP": 192 | return SGDP(parameters, lr=lr, weight_decay=weight_decay) 193 | elif optimizer == "SGDW": 194 | return SGDW(parameters, lr=lr, weight_decay=weight_decay) 195 | elif optimizer == "Shampoo": 196 | return Shampoo(parameters, lr=lr, weight_decay=weight_decay) 197 | elif optimizer == "SWATS": 198 | return SWATS(parameters, lr=lr, weight_decay=weight_decay) 199 | elif optimizer == "Yogi": 200 | return Yogi(parameters, lr=lr, weight_decay=weight_decay) 201 | else: 202 | raise NotImplementedError(f"{optimizer} optimizer not supported yet.") 203 | 204 | 205 | @beartype 206 | class BaseAcceleratedTrainer(nn.Module): 207 | def __init__( 208 | self, 209 | dataloader: Union[DataLoader, MpDeviceLoaderWrapper], 210 | valid_dataloader: Union[DataLoader, MpDeviceLoaderWrapper], 211 | accelerator: Accelerator, 212 | *, 213 | current_step: int, 214 | 
num_train_steps: int, 215 | num_epochs: int = 5, 216 | max_grad_norm: Optional[int] = None, 217 | save_results_every: int = 100, 218 | save_model_every: int = 1000, 219 | results_dir: Union[str, PathLike] = Path.cwd().joinpath("results"), 220 | logging_dir: Union[str, PathLike] = Path.cwd().joinpath("results/logs"), 221 | apply_grad_penalty_every: int = 4, 222 | gradient_accumulation_steps: int = 1, 223 | clear_previous_experiments: bool = False, 224 | validation_image_scale: Union[int, float] = 1.0, 225 | only_save_last_checkpoint: bool = False, 226 | ): 227 | super().__init__() 228 | self.model: nn.Module = None 229 | # instantiate accelerator 230 | self.gradient_accumulation_steps: int = gradient_accumulation_steps 231 | self.accelerator: Accelerator = accelerator 232 | self.logging_dir: Path = Path(logging_dir) if not isinstance(logging_dir, Path) else logging_dir 233 | self.results_dir: Path = Path(results_dir) if not isinstance(results_dir, Path) else results_dir 234 | 235 | # training params 236 | self.only_save_last_checkpoint: bool = only_save_last_checkpoint 237 | self.validation_image_scale: Union[int, float] = validation_image_scale 238 | self.register_buffer("steps", torch.Tensor([current_step])) 239 | self.num_train_steps: int = num_train_steps 240 | self.num_epochs = num_epochs 241 | self.max_grad_norm: Optional[Union[int, float]] = max_grad_norm 242 | 243 | self.dl = dataloader 244 | self.valid_dl = valid_dataloader 245 | self.dl_iter = iter(self.dl) 246 | self.valid_dl_iter = iter(self.valid_dl) 247 | 248 | self.save_model_every: int = save_model_every 249 | self.save_results_every: int = save_results_every 250 | self.apply_grad_penalty_every: int = apply_grad_penalty_every 251 | 252 | # Clear previous experiment data if requested 253 | if clear_previous_experiments is True and self.accelerator.is_local_main_process: 254 | if self.results_dir.exists(): 255 | rmtree(self.results_dir, ignore_errors=True) 256 | # Make sure logging and results directories exist 257 | self.logging_dir.mkdir(parents=True, exist_ok=True) 258 | self.results_dir.mkdir(parents=True, exist_ok=True) 259 | 260 | self.optim: Optimizer = None 261 | 262 | self.print = self.accelerator.print 263 | self.log = self.accelerator.log 264 | 265 | self.on_tpu = self.accelerator.distributed_type == accelerate.DistributedType.TPU 266 | 267 | def save(self, path): 268 | if not self.accelerator.is_main_process: 269 | return 270 | 271 | pkg = dict( 272 | model=self.accelerator.get_state_dict(self.model), 273 | optim=self.optim.state_dict(), 274 | ) 275 | self.accelerator.save(pkg, path) 276 | 277 | def load(self, path: Union[str, PathLike]): 278 | if not isinstance(path, Path): 279 | path = Path(path) 280 | 281 | if not path.exists(): 282 | raise FileNotFoundError(f"Checkpoint file {path} does not exist.") 283 | 284 | pkg = torch.load(path, map_location="cpu") 285 | model = self.accelerator.unwrap_model(self.model) 286 | model.load_state_dict(pkg["model"]) 287 | 288 | self.optim.load_state_dict(pkg["optim"]) 289 | return pkg 290 | 291 | def log_validation_images(self, images, step, prompts=None): 292 | if self.validation_image_scale != 1: 293 | # Calculate the new height based on the scale factor 294 | new_height = int(np.array(images[0]).shape[0] * self.validation_image_scale) 295 | 296 | # Calculate the aspect ratio of the original image 297 | aspect_ratio = np.array(images[0]).shape[1] / np.array(images[0]).shape[0] 298 | 299 | # Calculate the new width based on the new height and aspect ratio 300 | new_width = 
int(new_height * aspect_ratio) 301 | 302 | # Resize the images using the new width and height 303 | output_size = (new_width, new_height) 304 | images_pil = [Image.fromarray(np.array(image)) for image in images] 305 | images_pil_resized = [image_pil.resize(output_size) for image_pil in images_pil] 306 | images = [np.array(image_pil) for image_pil in images_pil_resized] 307 | 308 | for tracker in self.accelerator.trackers: 309 | if tracker.name == "tensorboard": 310 | np_images = np.stack([np.asarray(img) for img in images]) 311 | tracker.writer.add_images("validation", np_images, step, dataformats="NHWC") 312 | if tracker.name == "wandb": 313 | tracker.log( 314 | { 315 | "validation": [ 316 | wandb.Image(image, caption="" if not prompts else prompts[i]) 317 | for i, image in enumerate(images) 318 | ] 319 | } 320 | ) 321 | 322 | @property 323 | def device(self): 324 | return self.accelerator.device 325 | 326 | @property 327 | def is_distributed(self): 328 | return ( 329 | False 330 | if self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1 331 | else True 332 | ) 333 | 334 | @property 335 | def is_main_process(self): 336 | return self.accelerator.is_main_process 337 | 338 | @property 339 | def is_local_main_process(self): 340 | return self.accelerator.is_local_main_process 341 | 342 | def train_step(self): 343 | raise NotImplementedError("You are calling train_step on the base trainer with no models") 344 | 345 | def train(self, log_fn=noop): 346 | self.model.train() 347 | while self.steps < self.num_train_steps: 348 | with self.accelerator.autocast(): 349 | logs = self.train_step() 350 | log_fn(logs) 351 | self.print("training complete") 352 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/trainers/maskgit_trainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch # noqa: F401 4 | import torch.nn.functional as F 5 | from accelerate import Accelerator 6 | from diffusers.optimization import SchedulerType 7 | from ema_pytorch import EMA 8 | from omegaconf import OmegaConf 9 | from PIL import Image 10 | from torch.optim import Optimizer 11 | from torch.utils.data import DataLoader 12 | from torchvision.utils import save_image 13 | 14 | from muse_maskgit_pytorch.muse_maskgit_pytorch import MaskGit 15 | from muse_maskgit_pytorch.t5 import t5_encode_text_from_encoded 16 | from muse_maskgit_pytorch.trainers.base_accelerated_trainer import BaseAcceleratedTrainer 17 | 18 | try: 19 | import torch_xla 20 | import torch_xla.core.xla_model as xm 21 | import torch_xla.debug.metrics as met 22 | except ImportError: 23 | torch_xla = None 24 | xm = None 25 | met = None 26 | 27 | from tqdm import tqdm 28 | 29 | 30 | class MaskGitTrainer(BaseAcceleratedTrainer): 31 | def __init__( 32 | self, 33 | maskgit: MaskGit, 34 | dataloader: DataLoader, 35 | valid_dataloader: DataLoader, 36 | accelerator: Accelerator, 37 | optimizer: Optimizer, 38 | scheduler: SchedulerType, 39 | *, 40 | current_step: int, 41 | num_train_steps: int, 42 | num_epochs: int = 5, 43 | batch_size: int, 44 | gradient_accumulation_steps: int = 1, 45 | max_grad_norm: float = None, 46 | save_results_every: int = 100, 47 | save_model_every: int = 1000, 48 | log_metrics_every: int = 10, 49 | results_dir="./results", 50 | logging_dir="./results/logs", 51 | apply_grad_penalty_every=4, 52 | use_ema=True, 53 | ema_update_after_step=0, 54 | ema_update_every=1, 55 | 
validation_prompts=["a photo of a dog"], 56 | timesteps=18, 57 | clear_previous_experiments=False, 58 | validation_image_scale: float = 1.0, 59 | only_save_last_checkpoint=False, 60 | args=None, 61 | ): 62 | super().__init__( 63 | dataloader=dataloader, 64 | valid_dataloader=valid_dataloader, 65 | accelerator=accelerator, 66 | current_step=current_step, 67 | num_train_steps=num_train_steps, 68 | num_epochs=num_epochs, 69 | gradient_accumulation_steps=gradient_accumulation_steps, 70 | max_grad_norm=max_grad_norm, 71 | save_results_every=save_results_every, 72 | save_model_every=save_model_every, 73 | results_dir=results_dir, 74 | logging_dir=logging_dir, 75 | apply_grad_penalty_every=apply_grad_penalty_every, 76 | clear_previous_experiments=clear_previous_experiments, 77 | validation_image_scale=validation_image_scale, 78 | only_save_last_checkpoint=only_save_last_checkpoint, 79 | ) 80 | self.save_results_every = save_results_every 81 | self.log_metrics_every = log_metrics_every 82 | self.batch_size = batch_size 83 | self.current_step = current_step 84 | self.timesteps = timesteps 85 | 86 | # arguments used for the training script, 87 | # we are going to use them later to save them to a config file. 88 | self.args = args 89 | 90 | # maskgit 91 | maskgit.vae.requires_grad_(False) 92 | maskgit.transformer.t5.requires_grad_(False) 93 | self.model: MaskGit = maskgit 94 | 95 | self.optim: Optimizer = optimizer 96 | self.lr_scheduler: SchedulerType = scheduler 97 | 98 | self.use_ema = use_ema 99 | self.validation_prompts: List[str] = validation_prompts 100 | if use_ema: 101 | ema_model = EMA( 102 | self.model, 103 | update_after_step=ema_update_after_step, 104 | update_every=ema_update_every, 105 | ) 106 | self.ema_model = ema_model 107 | else: 108 | self.ema_model = None 109 | 110 | if not self.on_tpu: 111 | if self.num_train_steps <= 0: 112 | self.training_bar = tqdm(initial=int(self.steps.item()), total=len(self.dl) * self.num_epochs) 113 | else: 114 | self.training_bar = tqdm(initial=int(self.steps.item()), total=self.num_train_steps) 115 | 116 | self.info_bar = tqdm(total=0, bar_format="{desc}") 117 | 118 | def save_validation_images( 119 | self, validation_prompts, step: int, cond_image=None, cond_scale=3, temperature=1, timesteps=18 120 | ): 121 | # moved the print to the top of the function so it shows before the progress bar for reability. 
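        # cond_scale and timesteps are simply forwarded to MaskGit.generate below: timesteps is
        # the number of iterative decoding passes, and cond_scale (presumably the classifier-free
        # guidance strength) controls how strongly the text conditioning is applied.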
122 | if validation_prompts: 123 | self.accelerator.print( 124 | f"\nStep: {step} | Logging with prompts: {[' | '.join(validation_prompts)]}" 125 | ) 126 | 127 | images = self.model.generate( 128 | validation_prompts, 129 | cond_images=cond_image, 130 | cond_scale=cond_scale, 131 | temperature=temperature, 132 | timesteps=timesteps, 133 | ).to(self.accelerator.device) 134 | 135 | save_dir = self.results_dir.joinpath("MaskGit") 136 | save_dir.mkdir(exist_ok=True, parents=True) 137 | save_file = save_dir.joinpath(f"maskgit_{step}.png") 138 | 139 | if self.accelerator.is_main_process: 140 | save_image(images, save_file, "png") 141 | self.log_validation_images([Image.open(save_file)], step, ["|".join(validation_prompts)]) 142 | return save_file 143 | 144 | def train(self): 145 | self.steps = self.steps + 1 146 | self.model.train() 147 | 148 | if self.accelerator.is_main_process: 149 | proc_label = f"[P{self.accelerator.process_index}][Master]" 150 | else: 151 | proc_label = f"[P{self.accelerator.process_index}][Worker]" 152 | 153 | # logs 154 | for epoch in range(self.current_step // len(self.dl), self.num_epochs): 155 | for imgs, input_ids, attn_mask in iter(self.dl): 156 | train_loss = 0.0 157 | steps = int(self.steps.item()) 158 | 159 | with torch.no_grad(): 160 | text_embeds = t5_encode_text_from_encoded( 161 | input_ids, attn_mask, self.model.transformer.t5, self.accelerator.device 162 | ) 163 | 164 | with self.accelerator.accumulate(self.model), self.accelerator.autocast(): 165 | loss = self.model(imgs, text_embeds=text_embeds) 166 | self.accelerator.backward(loss) 167 | if self.max_grad_norm is not None and self.accelerator.sync_gradients: 168 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 169 | self.optim.step() 170 | self.lr_scheduler.step() 171 | self.optim.zero_grad() 172 | 173 | if self.use_ema: 174 | self.ema_model.update() 175 | 176 | gathered_loss = self.accelerator.gather_for_metrics(loss) 177 | train_loss = gathered_loss.mean() / self.gradient_accumulation_steps 178 | 179 | logs = {"loss": train_loss, "lr": self.lr_scheduler.get_last_lr()[0]} 180 | 181 | if self.on_tpu: 182 | self.accelerator.print( 183 | f"\n[E{epoch + 1}][{steps}]{proc_label}: " 184 | f"maskgit loss: {logs['loss']} - lr: {logs['lr']}" 185 | ) 186 | else: 187 | self.training_bar.update() 188 | self.info_bar.set_description_str( 189 | f"[E{epoch + 1}]{proc_label}: " f"maskgit loss: {logs['loss']} - lr: {logs['lr']}" 190 | ) 191 | 192 | self.accelerator.log(logs, step=steps) 193 | 194 | if not (steps % self.save_model_every): 195 | self.accelerator.print( 196 | f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"saving model to {self.results_dir}" 197 | ) 198 | 199 | state_dict = self.accelerator.unwrap_model(self.model).state_dict() 200 | maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit" 201 | file_name = ( 202 | f"{maskgit_save_name}.{steps}.pt" 203 | if not self.only_save_last_checkpoint 204 | else f"{maskgit_save_name}.pt" 205 | ) 206 | 207 | model_path = self.results_dir.joinpath(file_name) 208 | self.accelerator.wait_for_everyone() 209 | self.accelerator.save(state_dict, model_path) 210 | 211 | if self.args and not self.args.do_not_save_config: 212 | # save config file next to the model file. 
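                        # vars(self.args) turns the parsed training arguments into a plain dict,
                        # which OmegaConf then writes out as "<model_path>.yaml" so the settings
                        # used for this checkpoint are kept alongside the weights.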
213 | conf = OmegaConf.create(vars(self.args)) 214 | OmegaConf.save(conf, f"{model_path}.yaml") 215 | 216 | if self.use_ema: 217 | self.accelerator.print( 218 | f"\n[E{epoch + 1}][{steps}]{proc_label}: " 219 | f"saving EMA model to {self.results_dir}" 220 | ) 221 | 222 | ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() 223 | file_name = ( 224 | f"{maskgit_save_name}.{steps}.ema.pt" 225 | if not self.only_save_last_checkpoint 226 | else f"{maskgit_save_name}.ema.pt" 227 | ) 228 | model_path = str(self.results_dir / file_name) 229 | self.accelerator.wait_for_everyone() 230 | self.accelerator.save(ema_state_dict, model_path) 231 | 232 | if self.args and not self.args.do_not_save_config: 233 | # save config file next to the model file. 234 | conf = OmegaConf.create(vars(self.args)) 235 | OmegaConf.save(conf, f"{model_path}.yaml") 236 | 237 | if not (steps % self.save_results_every): 238 | cond_image = None 239 | if self.model.cond_image_size: 240 | cond_image = F.interpolate(imgs, self.model.cond_image_size, mode="nearest") 241 | self.validation_prompts = [""] * self.batch_size 242 | 243 | if self.on_tpu: 244 | self.accelerator.print(f"\n[E{epoch + 1}]{proc_label}: " f"Logging validation images") 245 | else: 246 | self.info_bar.set_description_str( 247 | f"[E{epoch + 1}]{proc_label}: " f"Logging validation images" 248 | ) 249 | 250 | saved_image = self.save_validation_images( 251 | self.validation_prompts, 252 | steps, 253 | cond_image=cond_image, 254 | timesteps=self.timesteps, 255 | ) 256 | if self.on_tpu: 257 | self.accelerator.print( 258 | f"\n[E{epoch + 1}][{steps}]{proc_label}: saved to {saved_image}" 259 | ) 260 | else: 261 | self.info_bar.set_description_str( 262 | f"[E{epoch + 1}]{proc_label}: " f"saved to {saved_image}" 263 | ) 264 | 265 | if met is not None and not (steps % self.log_metrics_every): 266 | if self.on_tpu: 267 | self.accelerator.print(f"\n[E{epoch + 1}][{steps}]{proc_label}: metrics:") 268 | else: 269 | self.info_bar.set_description_str(f"[E{epoch + 1}]{proc_label}: metrics:") 270 | 271 | self.steps += 1 272 | 273 | # if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: 274 | # if self.on_tpu: 275 | # self.accelerator.print( 276 | # f"\n[E{epoch + 1}][{int(self.steps.item())}]{proc_label}" 277 | # f"[STOP EARLY]: Stopping training early..." 278 | # ) 279 | # else: 280 | # self.info_bar.set_description_str( 281 | # f"[E{epoch + 1}]{proc_label}" f"[STOP EARLY]: Stopping training early..." 282 | # ) 283 | # break 284 | 285 | # loop complete, save final model 286 | self.accelerator.print( 287 | f"\n[E{epoch + 1}][{steps}]{proc_label}[FINAL]: saving model to {self.results_dir}" 288 | ) 289 | state_dict = self.accelerator.unwrap_model(self.model).state_dict() 290 | maskgit_save_name = "maskgit_superres" if self.model.cond_image_size else "maskgit" 291 | file_name = ( 292 | f"{maskgit_save_name}.{steps}.pt" 293 | if not self.only_save_last_checkpoint 294 | else f"{maskgit_save_name}.pt" 295 | ) 296 | 297 | model_path = self.results_dir.joinpath(file_name) 298 | self.accelerator.wait_for_everyone() 299 | self.accelerator.save(state_dict, model_path) 300 | 301 | if self.args and not self.args.do_not_save_config: 302 | # save config file next to the model file. 
303 | conf = OmegaConf.create(vars(self.args)) 304 | OmegaConf.save(conf, f"{model_path}.yaml") 305 | 306 | if self.use_ema: 307 | self.accelerator.print(f"\n[{steps}]{proc_label}[FINAL]: saving EMA model to {self.results_dir}") 308 | ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() 309 | file_name = ( 310 | f"{maskgit_save_name}.{steps}.ema.pt" 311 | if not self.only_save_last_checkpoint 312 | else f"{maskgit_save_name}.ema.pt" 313 | ) 314 | model_path = str(self.results_dir / file_name) 315 | self.accelerator.wait_for_everyone() 316 | self.accelerator.save(ema_state_dict, model_path) 317 | 318 | if self.args and not self.args.do_not_save_config: 319 | # save config file next to the model file. 320 | conf = OmegaConf.create(vars(self.args)) 321 | OmegaConf.save(conf, f"{model_path}.yaml") 322 | 323 | cond_image = None 324 | if self.model.cond_image_size: 325 | self.accelerator.print( 326 | "With conditional image training, we recommend keeping the validation prompts to empty strings" 327 | ) 328 | cond_image = F.interpolate(imgs, self.model.cond_image_size, mode="nearest") 329 | 330 | steps = int(self.steps.item()) + 1 # get the final step count, plus one 331 | self.accelerator.print(f"\n[{steps}]{proc_label}: Logging validation images") 332 | saved_image = self.save_validation_images(self.validation_prompts, steps, cond_image=cond_image) 333 | self.accelerator.print(f"\n[{steps}]{proc_label}: saved to {saved_image}") 334 | 335 | if met is not None and not (steps % self.log_metrics_every): 336 | self.accelerator.print(f"\n[{steps}]{proc_label}: metrics:") 337 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/trainers/vqvae_trainers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from accelerate import Accelerator 3 | from diffusers.optimization import get_scheduler 4 | from einops import rearrange 5 | from ema_pytorch import EMA 6 | from omegaconf import OmegaConf 7 | from PIL import Image 8 | from torch.optim.lr_scheduler import LRScheduler 9 | from torch.utils.data import DataLoader 10 | from torchvision.utils import make_grid, save_image 11 | from tqdm import tqdm 12 | 13 | from muse_maskgit_pytorch.trainers.base_accelerated_trainer import ( 14 | BaseAcceleratedTrainer, 15 | get_optimizer, 16 | ) 17 | from muse_maskgit_pytorch.vqgan_vae import VQGanVAE 18 | 19 | 20 | def noop(*args, **kwargs): 21 | pass 22 | 23 | 24 | def accum_log(log, new_logs): 25 | for key, new_value in new_logs.items(): 26 | old_value = log.get(key, 0.0) 27 | log[key] = old_value + new_value 28 | return log 29 | 30 | 31 | def exists(val): 32 | return val is not None 33 | 34 | 35 | class VQGanVAETrainer(BaseAcceleratedTrainer): 36 | def __init__( 37 | self, 38 | vae: VQGanVAE, 39 | dataloader: DataLoader, 40 | valid_dataloader: DataLoader, 41 | accelerator: Accelerator, 42 | *, 43 | current_step, 44 | num_train_steps, 45 | num_epochs: int = 5, 46 | gradient_accumulation_steps=1, 47 | max_grad_norm=None, 48 | save_results_every=100, 49 | save_model_every=1000, 50 | results_dir="./results", 51 | logging_dir="./results/logs", 52 | apply_grad_penalty_every=4, 53 | lr=3e-4, 54 | lr_scheduler_type="constant", 55 | lr_warmup_steps=500, 56 | discr_max_grad_norm=None, 57 | use_ema=True, 58 | ema_beta=0.995, 59 | ema_update_after_step=0, 60 | ema_update_every=1, 61 | clear_previous_experiments=False, 62 | validation_image_scale: float = 1.0, 63 | only_save_last_checkpoint=False, 64 | 
optimizer="Adam", 65 | weight_decay=0.0, 66 | use_8bit_adam=False, 67 | num_cycles=1, 68 | scheduler_power=1.0, 69 | args=None, 70 | ): 71 | super().__init__( 72 | dataloader, 73 | valid_dataloader, 74 | accelerator, 75 | current_step=current_step, 76 | num_train_steps=num_train_steps, 77 | num_epochs=num_epochs, 78 | gradient_accumulation_steps=gradient_accumulation_steps, 79 | max_grad_norm=max_grad_norm, 80 | save_results_every=save_results_every, 81 | save_model_every=save_model_every, 82 | results_dir=results_dir, 83 | logging_dir=logging_dir, 84 | apply_grad_penalty_every=apply_grad_penalty_every, 85 | clear_previous_experiments=clear_previous_experiments, 86 | validation_image_scale=validation_image_scale, 87 | only_save_last_checkpoint=only_save_last_checkpoint, 88 | ) 89 | 90 | # arguments used for the training script, 91 | # we are going to use them later to save them to a config file. 92 | self.args = args 93 | 94 | self.current_step = current_step 95 | 96 | # vae 97 | self.model = vae 98 | 99 | all_parameters = set(vae.parameters()) 100 | discr_parameters = set(vae.discr.parameters()) 101 | vae_parameters = all_parameters - discr_parameters 102 | 103 | # optimizers 104 | self.optim = get_optimizer(use_8bit_adam, optimizer, vae_parameters, lr, weight_decay) 105 | self.discr_optim = get_optimizer(use_8bit_adam, optimizer, discr_parameters, lr, weight_decay) 106 | 107 | if self.num_train_steps > 0: 108 | self.num_lr_steps = self.num_train_steps * self.gradient_accumulation_steps 109 | else: 110 | self.num_lr_steps = self.num_epochs * len(self.dl) 111 | 112 | self.lr_scheduler: LRScheduler = get_scheduler( 113 | lr_scheduler_type, 114 | optimizer=self.optim, 115 | num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps, 116 | num_training_steps=self.num_lr_steps, 117 | num_cycles=num_cycles, 118 | power=scheduler_power, 119 | ) 120 | 121 | self.lr_scheduler_discr: LRScheduler = get_scheduler( 122 | lr_scheduler_type, 123 | optimizer=self.discr_optim, 124 | num_warmup_steps=lr_warmup_steps * self.gradient_accumulation_steps, 125 | num_training_steps=self.num_lr_steps, 126 | num_cycles=num_cycles, 127 | power=scheduler_power, 128 | ) 129 | 130 | self.discr_max_grad_norm = discr_max_grad_norm 131 | 132 | # prepare with accelerator 133 | 134 | ( 135 | self.model, 136 | self.optim, 137 | self.discr_optim, 138 | self.dl, 139 | self.valid_dl, 140 | self.lr_scheduler, 141 | self.lr_scheduler_discr, 142 | ) = accelerator.prepare( 143 | self.model, 144 | self.optim, 145 | self.discr_optim, 146 | self.dl, 147 | self.valid_dl, 148 | self.lr_scheduler, 149 | self.lr_scheduler_discr, 150 | ) 151 | self.model.train() 152 | 153 | self.use_ema = use_ema 154 | 155 | if use_ema: 156 | self.ema_model = EMA( 157 | vae, 158 | update_after_step=ema_update_after_step, 159 | update_every=ema_update_every, 160 | ) 161 | self.ema_model = accelerator.prepare(self.ema_model) 162 | 163 | if not self.on_tpu: 164 | if self.num_train_steps <= 0: 165 | self.training_bar = tqdm(initial=int(self.steps.item()), total=len(self.dl) * self.num_epochs) 166 | else: 167 | self.training_bar = tqdm(initial=int(self.steps.item()), total=self.num_train_steps) 168 | 169 | self.info_bar = tqdm(total=0, bar_format="{desc}") 170 | 171 | def load(self, path): 172 | pkg = super().load(path) 173 | self.discr_optim.load_state_dict(pkg["discr_optim"]) 174 | 175 | def save(self, path): 176 | if not self.is_local_main_process: 177 | return 178 | 179 | pkg = dict( 180 | model=self.get_state_dict(self.model), 181 | 
optim=self.optim.state_dict(), 182 | discr_optim=self.discr_optim.state_dict(), 183 | ) 184 | self.accelerator.save(pkg, path) 185 | 186 | def log_validation_images(self, logs, steps): 187 | log_imgs = [] 188 | self.model.eval() 189 | 190 | try: 191 | valid_data = next(self.valid_dl_iter) 192 | except StopIteration: 193 | self.valid_dl_iter = iter(self.valid_dl) 194 | valid_data = next(self.valid_dl_iter) 195 | 196 | valid_data = valid_data.to(self.device) 197 | 198 | recons = self.model(valid_data, return_recons=True) 199 | 200 | # else save a grid of images 201 | 202 | imgs_and_recons = torch.stack((valid_data, recons), dim=0) 203 | imgs_and_recons = rearrange(imgs_and_recons, "r b ... -> (b r) ...") 204 | 205 | imgs_and_recons = imgs_and_recons.detach().cpu().float().clamp(0.0, 1.0) 206 | grid = make_grid(imgs_and_recons, nrow=2, normalize=True, value_range=(0, 1)) 207 | 208 | logs["reconstructions"] = grid 209 | save_file = str(self.results_dir / f"{steps}.png") 210 | save_image(grid, save_file) 211 | log_imgs.append(Image.open(save_file)) 212 | super().log_validation_images(log_imgs, steps, prompts=["vae"]) 213 | self.model.train() 214 | 215 | def train(self): 216 | self.steps = self.steps + 1 217 | device = self.device 218 | self.model.train() 219 | 220 | if self.accelerator.is_main_process: 221 | proc_label = f"[P{self.accelerator.process_index:03d}][Master]" 222 | else: 223 | proc_label = f"[P{self.accelerator.process_index:03d}][Worker]" 224 | 225 | for epoch in range(self.current_step // len(self.dl), self.num_epochs): 226 | for img in self.dl: 227 | loss = 0.0 228 | steps = int(self.steps.item()) 229 | 230 | apply_grad_penalty = (steps % self.apply_grad_penalty_every) == 0 231 | 232 | discr = self.model.module.discr if self.is_distributed else self.model.discr 233 | if self.use_ema: 234 | ema_model = self.ema_model.module if self.is_distributed else self.ema_model 235 | 236 | # logs 237 | 238 | logs = {} 239 | 240 | # update vae (generator) 241 | 242 | img = img.to(device) 243 | 244 | with self.accelerator.autocast(): 245 | loss = self.model(img, add_gradient_penalty=apply_grad_penalty, return_loss=True) 246 | 247 | self.accelerator.backward(loss / self.gradient_accumulation_steps) 248 | if self.max_grad_norm is not None and self.accelerator.sync_gradients: 249 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 250 | 251 | accum_log(logs, {"Train/vae_loss": loss.item() / self.gradient_accumulation_steps}) 252 | 253 | self.lr_scheduler.step() 254 | self.lr_scheduler_discr.step() 255 | self.optim.step() 256 | self.optim.zero_grad() 257 | 258 | loss = 0.0 259 | 260 | # update discriminator 261 | 262 | if exists(discr): 263 | self.discr_optim.zero_grad() 264 | 265 | with torch.cuda.amp.autocast(): 266 | loss = self.model(img, return_discr_loss=True) 267 | 268 | self.accelerator.backward(loss / self.gradient_accumulation_steps) 269 | if self.discr_max_grad_norm is not None and self.accelerator.sync_gradients: 270 | self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm) 271 | 272 | accum_log( 273 | logs, 274 | {"Train/discr_loss": loss.item() / self.gradient_accumulation_steps}, 275 | ) 276 | 277 | self.discr_optim.step() 278 | 279 | # log 280 | if self.on_tpu: 281 | self.accelerator.print( 282 | f"[E{epoch + 1}][{steps:05d}]{proc_label}: " 283 | f"vae loss: {logs['Train/vae_loss']} - " 284 | f"discr loss: {logs['Train/discr_loss']} - " 285 | f"lr: {self.lr_scheduler.get_last_lr()[0]}" 286 | ) 287 | else: 288 | 
self.training_bar.update() 289 | # Note: we had to remove {proc_label} from the description 290 | # to short it so it doenst go beyond one line on each step. 291 | self.info_bar.set_description_str( 292 | f"[E{epoch + 1}][{steps:05d}]: " 293 | f"vae loss: {logs['Train/vae_loss']} - " 294 | f"discr loss: {logs['Train/discr_loss']} - " 295 | f"lr: {self.lr_scheduler.get_last_lr()[0]}" 296 | ) 297 | 298 | logs["lr"] = self.lr_scheduler.get_last_lr()[0] 299 | self.accelerator.log(logs, step=steps) 300 | 301 | # update exponential moving averaged generator 302 | 303 | if self.use_ema: 304 | ema_model.update() 305 | 306 | # sample results every so often 307 | 308 | if (steps % self.save_results_every) == 0: 309 | self.accelerator.print( 310 | f"\n[E{epoch + 1}][{steps}] | Logging validation images to {str(self.results_dir)}" 311 | ) 312 | 313 | self.log_validation_images(logs, steps) 314 | 315 | # save model every so often 316 | self.accelerator.wait_for_everyone() 317 | if self.is_main_process and (steps % self.save_model_every) == 0: 318 | self.accelerator.print(f"\nStep: {steps} | Saving model to {str(self.results_dir)}") 319 | 320 | state_dict = self.accelerator.unwrap_model(self.model).state_dict() 321 | file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt" 322 | model_path = str(self.results_dir / file_name) 323 | self.accelerator.save(state_dict, model_path) 324 | 325 | if self.args and not self.args.do_not_save_config: 326 | # save config file next to the model file. 327 | conf = OmegaConf.create(vars(self.args)) 328 | OmegaConf.save(conf, f"{model_path}.yaml") 329 | 330 | if self.use_ema: 331 | ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() 332 | file_name = ( 333 | f"vae.{steps}.ema.pt" if not self.only_save_last_checkpoint else "vae.ema.pt" 334 | ) 335 | model_path = str(self.results_dir / file_name) 336 | self.accelerator.save(ema_state_dict, model_path) 337 | 338 | if self.args and not self.args.do_not_save_config: 339 | # save config file next to the model file. 340 | conf = OmegaConf.create(vars(self.args)) 341 | OmegaConf.save(conf, f"{model_path}.yaml") 342 | 343 | self.steps += 1 344 | 345 | # if self.num_train_steps > 0 and int(self.steps.item()) >= self.num_train_steps: 346 | # self.accelerator.print( 347 | # f"\n[E{epoch + 1}][{steps}]{proc_label}: " f"[STOP EARLY]: Stopping training early..." 348 | # ) 349 | # break 350 | 351 | # Loop finished, save model 352 | self.accelerator.wait_for_everyone() 353 | if self.is_main_process: 354 | self.accelerator.print( 355 | f"[E{self.num_epochs}][{steps:05d}]{proc_label}: saving model to {str(self.results_dir)}" 356 | ) 357 | 358 | state_dict = self.accelerator.unwrap_model(self.model).state_dict() 359 | file_name = f"vae.{steps}.pt" if not self.only_save_last_checkpoint else "vae.pt" 360 | model_path = str(self.results_dir / file_name) 361 | self.accelerator.save(state_dict, model_path) 362 | 363 | if self.args and not self.args.do_not_save_config: 364 | # save config file next to the model file. 
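                # (this final save block only runs on the main process, so the checkpoint
                #  and its companion .yaml are written exactly once)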
365 | conf = OmegaConf.create(vars(self.args)) 366 | OmegaConf.save(conf, f"{model_path}.yaml") 367 | 368 | if self.use_ema: 369 | ema_state_dict = self.accelerator.unwrap_model(self.ema_model).state_dict() 370 | file_name = f"vae.{steps}.ema.pt" if not self.only_save_last_checkpoint else "vae.ema.pt" 371 | model_path = str(self.results_dir / file_name) 372 | self.accelerator.save(ema_state_dict, model_path) 373 | 374 | if self.args and not self.args.do_not_save_config: 375 | # save config file next to the model file. 376 | conf = OmegaConf.create(vars(self.args)) 377 | OmegaConf.save(conf, f"{model_path}.yaml") 378 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/vqgan_vae.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from functools import partial, wraps 3 | from pathlib import Path 4 | 5 | import timm 6 | import torch 7 | import torch.nn.functional as F 8 | import torchvision 9 | from accelerate import Accelerator 10 | from beartype import beartype 11 | from einops import rearrange, repeat 12 | from torch import nn 13 | from torch.autograd import grad as torch_grad 14 | from vector_quantize_pytorch import VectorQuantize as VQ 15 | 16 | # constants 17 | 18 | MList = nn.ModuleList 19 | 20 | 21 | # helper functions 22 | def exists(val): 23 | return val is not None 24 | 25 | 26 | def default(val, d): 27 | return val if val is not None else d 28 | 29 | 30 | # decorators 31 | def eval_decorator(fn): 32 | def inner(model, *args, **kwargs): 33 | was_training = model.training 34 | model.eval() 35 | out = fn(model, *args, **kwargs) 36 | model.train(was_training) 37 | return out 38 | 39 | return inner 40 | 41 | 42 | def remove_vgg(fn): 43 | @wraps(fn) 44 | def inner(self, *args, **kwargs): 45 | has_vgg = hasattr(self, "_vgg") 46 | if has_vgg: 47 | vgg = self._vgg 48 | delattr(self, "_vgg") 49 | 50 | out = fn(self, *args, **kwargs) 51 | 52 | if has_vgg: 53 | self._vgg = vgg 54 | 55 | return out 56 | 57 | return inner 58 | 59 | 60 | # keyword argument helpers 61 | def pick_and_pop(keys, d): 62 | values = list(map(lambda key: d.pop(key), keys)) 63 | return dict(zip(keys, values)) 64 | 65 | 66 | def group_dict_by_key(cond, d): 67 | return_val = [dict(), dict()] 68 | for key in d.keys(): 69 | match = bool(cond(key)) 70 | ind = int(not match) 71 | return_val[ind][key] = d[key] 72 | return (*return_val,) 73 | 74 | 75 | def string_begins_with(prefix, string_input): 76 | return string_input.startswith(prefix) 77 | 78 | 79 | def group_by_key_prefix(prefix, d): 80 | return group_dict_by_key(partial(string_begins_with, prefix), d) 81 | 82 | 83 | def groupby_prefix_and_trim(prefix, d): 84 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 85 | kwargs_without_prefix = dict( 86 | map(lambda x: (x[0][len(prefix) :], x[1]), tuple(kwargs_with_prefix.items())) 87 | ) 88 | return kwargs_without_prefix, kwargs 89 | 90 | 91 | # tensor helper functions 92 | def log(t, eps=1e-10): 93 | return torch.log(t + eps) 94 | 95 | 96 | def gradient_penalty(images, output, weight=10): 97 | gradients = torch_grad( 98 | outputs=output, 99 | inputs=images, 100 | grad_outputs=torch.ones(output.size(), device=images.device), 101 | create_graph=True, 102 | retain_graph=True, 103 | only_inputs=True, 104 | )[0] 105 | 106 | gradients = rearrange(gradients, "b ... 
-> b (...)") 107 | return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean() 108 | 109 | 110 | def leaky_relu(p: float = 0.1): 111 | return nn.LeakyReLU(p) 112 | 113 | 114 | def safe_div(numer, denom, eps=1e-8): 115 | return numer / denom.clamp(min=eps) 116 | 117 | 118 | # gan losses 119 | def hinge_discr_loss(fake, real): 120 | return (F.relu(1 + fake) + F.relu(1 - real)).mean() 121 | 122 | 123 | def hinge_gen_loss(fake): 124 | return -fake.mean() 125 | 126 | 127 | def bce_discr_loss(fake, real): 128 | return (-log(1 - torch.sigmoid(fake)) - log(torch.sigmoid(real))).mean() 129 | 130 | 131 | def bce_gen_loss(fake): 132 | return -log(torch.sigmoid(fake)).mean() 133 | 134 | 135 | def grad_layer_wrt_loss(loss, layer): 136 | return torch_grad( 137 | outputs=loss, 138 | inputs=layer, 139 | grad_outputs=torch.ones_like(loss), 140 | retain_graph=True, 141 | )[0].detach() 142 | 143 | 144 | # vqgan vae 145 | class LayerNormChan(nn.Module): 146 | def __init__(self, dim, eps=1e-5): 147 | super().__init__() 148 | self.eps = eps 149 | self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1)) 150 | 151 | def forward(self, x): 152 | var = torch.var(x, dim=1, unbiased=False, keepdim=True) 153 | mean = torch.mean(x, dim=1, keepdim=True) 154 | return (x - mean) * var.clamp(min=self.eps).rsqrt() * self.gamma 155 | 156 | 157 | # discriminator 158 | class Discriminator(nn.Module): 159 | def __init__(self, dims, channels=4, groups=16, init_kernel_size=5): 160 | super().__init__() 161 | dim_pairs = zip(dims[:-1], dims[1:]) 162 | 163 | self.layers = MList( 164 | [ 165 | nn.Sequential( 166 | nn.Conv2d(channels, dims[0], init_kernel_size, padding=init_kernel_size // 2), 167 | leaky_relu(), 168 | ) 169 | ] 170 | ) 171 | 172 | for dim_in, dim_out in dim_pairs: 173 | self.layers.append( 174 | nn.Sequential( 175 | nn.Conv2d(dim_in, dim_out, 4, stride=2, padding=1), 176 | nn.GroupNorm(groups, dim_out), 177 | leaky_relu(), 178 | ) 179 | ) 180 | 181 | dim = dims[-1] 182 | # return 5 x 5, for PatchGAN-esque training 183 | self.to_logits = nn.Sequential(nn.Conv2d(dim, dim, 1), leaky_relu(), nn.Conv2d(dim, 1, 4)) 184 | 185 | def forward(self, x): 186 | for net in self.layers: 187 | x = net(x) 188 | return self.to_logits(x) 189 | 190 | 191 | # resnet encoder / decoder 192 | class ResnetEncDec(nn.Module): 193 | def __init__( 194 | self, 195 | dim: int, 196 | *, 197 | channels=4, 198 | layers=4, 199 | layer_mults=None, 200 | num_resnet_blocks=1, 201 | resnet_groups=16, 202 | first_conv_kernel_size=5, 203 | ): 204 | super().__init__() 205 | assert ( 206 | dim % resnet_groups == 0 207 | ), f"dimension {dim} must be divisible by {resnet_groups} (groups for the groupnorm)" 208 | 209 | self.layers = layers 210 | self.encoders = MList([]) 211 | self.decoders = MList([]) 212 | 213 | layer_mults = default(layer_mults, [2**x for x in range(layers)]) 214 | if len(layer_mults) != layers: 215 | raise ValueError("layer multipliers must be equal to designated number of layers") 216 | 217 | layer_dims = [dim * mult for mult in layer_mults] 218 | dims = (dim, *layer_dims) 219 | 220 | self.encoded_dim = dims[-1] 221 | dim_pairs = zip(dims[:-1], dims[1:]) 222 | 223 | if not isinstance(num_resnet_blocks, tuple): 224 | num_resnet_blocks = (*((0,) * (layers - 1)), num_resnet_blocks) 225 | if len(num_resnet_blocks) != layers: 226 | raise ValueError("number of resnet blocks must be equal to number of layers") 227 | 228 | for _, (dim_in, dim_out), layer_num_resnet_blocks in zip(range(layers), dim_pairs, num_resnet_blocks): 229 | self.encoders.append( 
230 | nn.Sequential(nn.Conv2d(dim_in, dim_out, 4, stride=2, padding=1), leaky_relu()) 231 | ) 232 | self.decoders.insert(0, nn.Sequential(nn.ConvTranspose2d(dim_out, dim_in, 4, 2, 1), leaky_relu())) 233 | for _ in range(layer_num_resnet_blocks): 234 | self.encoders.append(ResBlock(dim_out, groups=resnet_groups)) 235 | self.decoders.insert(0, GLUResBlock(dim_out, groups=resnet_groups)) 236 | 237 | self.encoders.insert( 238 | 0, nn.Conv2d(channels, dim, first_conv_kernel_size, padding=first_conv_kernel_size // 2) 239 | ) 240 | self.decoders.append(nn.Conv2d(dim, channels, 1)) 241 | 242 | def get_encoded_fmap_size(self, image_size: int): 243 | return image_size // (2**self.layers) 244 | 245 | @property 246 | def last_dec_layer(self): 247 | return self.decoders[-1].weight 248 | 249 | def encode(self, x): 250 | for enc in self.encoders: 251 | x = enc(x) 252 | return x 253 | 254 | def decode(self, x): 255 | for dec in self.decoders: 256 | x = dec(x) 257 | return x 258 | 259 | 260 | class GLUResBlock(nn.Module): 261 | def __init__(self, chan, groups=16): 262 | super().__init__() 263 | self.net = nn.Sequential( 264 | nn.Conv2d(chan, chan * 2, 3, padding=1), 265 | nn.GLU(dim=1), 266 | nn.GroupNorm(groups, chan), 267 | nn.Conv2d(chan, chan * 2, 3, padding=1), 268 | nn.GLU(dim=1), 269 | nn.GroupNorm(groups, chan), 270 | nn.Conv2d(chan, chan, 1), 271 | ) 272 | 273 | def forward(self, x): 274 | return self.net(x) + x 275 | 276 | 277 | class ResBlock(nn.Module): 278 | def __init__(self, chan, groups=16): 279 | super().__init__() 280 | self.net = nn.Sequential( 281 | nn.Conv2d(chan, chan, 3, padding=1), 282 | nn.GroupNorm(groups, chan), 283 | leaky_relu(), 284 | nn.Conv2d(chan, chan, 3, padding=1), 285 | nn.GroupNorm(groups, chan), 286 | leaky_relu(), 287 | nn.Conv2d(chan, chan, 1), 288 | ) 289 | 290 | def forward(self, x): 291 | return self.net(x) + x 292 | 293 | 294 | class TimmFeatureEncDec(nn.Module): 295 | def __init__(self, backbone="convnext_base"): 296 | self.timm_model = timm.create_model( 297 | backbone, 298 | pretrained=True, 299 | features_only=True, 300 | exportable=True, 301 | out_indices=self.idx, 302 | ) 303 | return 304 | 305 | def encode(self, x): 306 | for enc in self.encoders: 307 | x = enc(x) 308 | return x 309 | 310 | def decode(self, x): 311 | for dec in self.decoders: 312 | x = dec(x) 313 | return x 314 | 315 | 316 | class HuggingfaceEncDec(nn.Module): 317 | def __init__(self): 318 | return 319 | 320 | def forward(self): 321 | return 322 | 323 | 324 | class WaveletTransformerEncDec(nn.Module): 325 | def __init__(self): 326 | return 327 | 328 | def forward(self): 329 | return 330 | 331 | 332 | # main vqgan-vae classes 333 | @beartype 334 | class VQGanVAE(nn.Module): 335 | def __init__( 336 | self, 337 | *, 338 | dim: int, 339 | accelerator: Accelerator = None, 340 | channels=4, 341 | layers=4, 342 | l2_recon_loss=False, 343 | use_hinge_loss=True, 344 | vgg=None, 345 | vq_codebook_dim=256, 346 | vq_codebook_size=512, 347 | vq_decay=0.8, 348 | vq_commitment_weight=1.0, 349 | vq_kmeans_init=True, 350 | vq_use_cosine_sim=True, 351 | use_vgg_and_gan=True, 352 | discr_layers=4, 353 | **kwargs, 354 | ): 355 | super().__init__() 356 | vq_kwargs, kwargs = groupby_prefix_and_trim("vq_", kwargs) 357 | encdec_kwargs, kwargs = groupby_prefix_and_trim("encdec_", kwargs) 358 | 359 | self.accelerator = accelerator 360 | self.channels = channels 361 | self.codebook_size = vq_codebook_size 362 | self.dim_divisor = 2**layers 363 | 364 | self.enc_dec = ResnetEncDec(dim=dim, channels=channels, 
layers=layers, **encdec_kwargs) 365 | 366 | self.vq = VQ( 367 | dim=self.enc_dec.encoded_dim, 368 | codebook_dim=vq_codebook_dim, 369 | codebook_size=vq_codebook_size, 370 | decay=vq_decay, 371 | commitment_weight=vq_commitment_weight, 372 | accept_image_fmap=True, 373 | kmeans_init=vq_kmeans_init, 374 | use_cosine_sim=vq_use_cosine_sim, 375 | **vq_kwargs, 376 | ) 377 | 378 | # reconstruction loss 379 | self.recon_loss_fn = F.mse_loss if l2_recon_loss else F.l1_loss 380 | 381 | # turn off GAN and perceptual loss if grayscale 382 | self._vgg = None 383 | self.discr = None 384 | self.use_vgg_and_gan = use_vgg_and_gan 385 | if not use_vgg_and_gan: 386 | return 387 | 388 | # preceptual loss 389 | if exists(vgg): 390 | self._vgg = vgg 391 | 392 | # gan related losses 393 | layer_mults = list(map(lambda t: 2**t, range(discr_layers))) 394 | layer_dims = [dim * mult for mult in layer_mults] 395 | dims = (dim, *layer_dims) 396 | 397 | self.discr = Discriminator(dims=dims, channels=channels) 398 | self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss 399 | self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss 400 | 401 | @property 402 | def device(self): 403 | return self.accelerator.device if self.accelerator else next(self.parameters()).device 404 | 405 | @property 406 | def vgg(self): 407 | if exists(self._vgg): 408 | return self._vgg 409 | 410 | vgg = torchvision.models.vgg16(pretrained=True) 411 | vgg.features[0] = nn.Conv2d(self.channels, 64, kernel_size=3, stride=1, padding=1) 412 | vgg.classifier = nn.Sequential(*vgg.classifier[:-2]) 413 | self._vgg = vgg.to(self.device) 414 | return self._vgg 415 | 416 | @property 417 | def encoded_dim(self): 418 | return self.enc_dec.encoded_dim 419 | 420 | def get_encoded_fmap_size(self, image_size): 421 | return self.enc_dec.get_encoded_fmap_size(image_size) 422 | 423 | def copy_for_eval(self): 424 | device = next(self.parameters()).device 425 | vae_copy = copy.deepcopy(self.cpu()) 426 | 427 | if vae_copy.use_vgg_and_gan: 428 | del vae_copy.discr 429 | del vae_copy._vgg 430 | 431 | vae_copy.eval() 432 | return vae_copy.to(device) 433 | 434 | @remove_vgg 435 | def state_dict(self, *args, **kwargs): 436 | return super().state_dict(*args, **kwargs) 437 | 438 | @remove_vgg 439 | def load_state_dict(self, *args, **kwargs): 440 | return super().load_state_dict(*args, **kwargs) 441 | 442 | def save(self, path): 443 | if self.accelerator is not None: 444 | self.accelerator.save(self.state_dict(), path) 445 | else: 446 | torch.save(self.state_dict(), path) 447 | 448 | def load(self, path, map=None): 449 | path = Path(path) 450 | assert path.exists() 451 | state_dict = torch.load(str(path), map_location=map) 452 | self.load_state_dict(state_dict) 453 | 454 | @property 455 | def codebook(self): 456 | return self.vq.codebook 457 | 458 | def encode(self, fmap): 459 | fmap = self.enc_dec.encode(fmap) 460 | fmap, indices, commit_loss = self.vq(fmap) 461 | return fmap, indices, commit_loss 462 | 463 | def decode_from_ids(self, ids): 464 | codes = self.codebook[ids] 465 | fmap = self.vq.project_out(codes) 466 | fmap = rearrange(fmap, "b h w c -> b c h w") 467 | return self.decode(fmap) 468 | 469 | def decode(self, fmap): 470 | return self.enc_dec.decode(fmap) 471 | 472 | def forward( 473 | self, 474 | img, 475 | return_loss=False, 476 | return_discr_loss=False, 477 | return_recons=False, 478 | add_gradient_penalty=True, 479 | relu_loss=True, 480 | ): 481 | batch, channels, height, width, device = *img.shape, img.device 482 | 483 | for 
dim_name, size in (("height", height), ("width", width)): 484 | assert (size % self.dim_divisor) == 0, f"{dim_name} must be divisible by {self.dim_divisor}" 485 | 486 | assert ( 487 | channels == self.channels 488 | ), "number of channels on image or sketch is not equal to the channels set on this VQGanVAE" 489 | 490 | fmap, indices, commit_loss = self.encode(img) 491 | 492 | fmap = self.decode(fmap) 493 | 494 | if not return_loss and not return_discr_loss: 495 | return fmap 496 | 497 | assert ( 498 | return_loss ^ return_discr_loss 499 | ), "you should either return autoencoder loss or discriminator loss, but not both" 500 | 501 | # whether to return discriminator loss 502 | 503 | if return_discr_loss: 504 | assert exists(self.discr), "discriminator must exist to train it" 505 | 506 | fmap.detach_() 507 | img.requires_grad_() 508 | 509 | fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img)) 510 | 511 | discr_loss = self.discr_loss(fmap_discr_logits, img_discr_logits) 512 | 513 | if add_gradient_penalty: 514 | gp = gradient_penalty(img, img_discr_logits) 515 | loss = discr_loss + gp 516 | 517 | if return_recons: 518 | if relu_loss: 519 | return F.relu(loss), fmap 520 | else: 521 | return loss, fmap 522 | 523 | if relu_loss: 524 | return F.relu(loss) 525 | else: 526 | return loss 527 | 528 | # reconstruction loss 529 | 530 | recon_loss = self.recon_loss_fn(fmap, img) 531 | 532 | # early return if training on grayscale 533 | 534 | if not self.use_vgg_and_gan: 535 | if return_recons: 536 | if relu_loss: 537 | return F.relu(recon_loss), fmap 538 | else: 539 | return recon_loss, fmap 540 | 541 | if relu_loss: 542 | return F.relu(recon_loss) 543 | else: 544 | return recon_loss 545 | 546 | # perceptual loss 547 | 548 | img_vgg_input = img 549 | fmap_vgg_input = fmap 550 | 551 | if img.shape[1] == 1: 552 | # handle grayscale for vgg 553 | img_vgg_input, fmap_vgg_input = map( 554 | lambda t: repeat(t, "b 1 ... 
-> b c ...", c=3), 555 | (img_vgg_input, fmap_vgg_input), 556 | ) 557 | 558 | img_vgg_feats = self.vgg(img_vgg_input) 559 | recon_vgg_feats = self.vgg(fmap_vgg_input) 560 | perceptual_loss = F.mse_loss(img_vgg_feats, recon_vgg_feats) 561 | if relu_loss: 562 | perceptual_loss = F.relu(perceptual_loss) 563 | 564 | # generator loss 565 | 566 | gen_loss = self.gen_loss(self.discr(fmap)) 567 | if relu_loss: 568 | gen_loss = F.relu(gen_loss) 569 | 570 | # calculate adaptive weight 571 | 572 | last_dec_layer = self.enc_dec.last_dec_layer 573 | 574 | norm_grad_wrt_gen_loss = grad_layer_wrt_loss(gen_loss, last_dec_layer).norm(p=2) 575 | norm_grad_wrt_perceptual_loss = grad_layer_wrt_loss(perceptual_loss, last_dec_layer).norm(p=2) 576 | 577 | adaptive_weight = safe_div(norm_grad_wrt_perceptual_loss, norm_grad_wrt_gen_loss) 578 | adaptive_weight.clamp_(max=1e4) 579 | 580 | # combine losses 581 | # recon loss is reconstruction loss mse 582 | # perceptual loss is loss in vgg features mse 583 | # commit loss is loss in quanitizing in vq mse 584 | # gan loss is 585 | if relu_loss: 586 | loss = ( 587 | F.relu(recon_loss) 588 | + F.relu(perceptual_loss) 589 | + F.relu(commit_loss) 590 | + F.relu(adaptive_weight) * F.relu(gen_loss) 591 | ) 592 | else: 593 | loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss 594 | 595 | if return_recons: 596 | if relu_loss: 597 | return F.relu(loss), fmap 598 | else: 599 | return loss, fmap 600 | 601 | if relu_loss: 602 | return F.relu(loss) 603 | else: 604 | return loss 605 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/vqgan_vae_taming.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import importlib 3 | from math import log, sqrt 4 | from pathlib import Path 5 | from urllib.parse import urlparse 6 | 7 | import requests 8 | import torch 9 | import torch.nn.functional as F 10 | from accelerate import Accelerator 11 | from einops import rearrange 12 | from omegaconf import DictConfig, OmegaConf 13 | from taming.models.vqgan import VQModel 14 | from torch import nn 15 | from tqdm_loggable.auto import tqdm 16 | 17 | # constants 18 | CACHE_PATH = Path.home().joinpath(".cache/taming") 19 | 20 | VQGAN_VAE_PATH = "https://heibox.uni-heidelberg.de/f/140747ba53464f49b476/?dl=1" 21 | VQGAN_VAE_CONFIG_PATH = "https://heibox.uni-heidelberg.de/f/6ecf2af6c658432c8298/?dl=1" 22 | 23 | # helpers methods 24 | 25 | 26 | def exists(val): 27 | return val is not None 28 | 29 | 30 | def default(val, d): 31 | return val if exists(val) else d 32 | 33 | 34 | def download(url, filename=None, root=CACHE_PATH, chunk_size=1024): 35 | filename = default(filename, urlparse(url).path.split("/")[-1]) 36 | root_dir = Path(root) 37 | 38 | target_path = root_dir.joinpath(filename) 39 | if target_path.exists(): 40 | if target_path.isfile(): 41 | return str(target_path) 42 | raise RuntimeError(f"{target_path} exists and is not a regular file") 43 | 44 | target_tmp = target_path.with_name(f".{target_path.name}.tmp") 45 | resp = requests.get(url, stream=True) 46 | resp.raise_for_status() 47 | 48 | filesize = int(resp.headers.get("content-length", 0)) 49 | with target_tmp.open("wb") as f: 50 | for data in tqdm( 51 | resp.iter_content(chunk_size=chunk_size), 52 | desc=filename, 53 | total=filesize, 54 | unit="iB", 55 | unit_scale=True, 56 | unit_divisor=1024, 57 | ): 58 | f.write(data) 59 | target_tmp.rename(target_path) 60 | return target_path 61 | 62 | 63 | # VQGAN from 
Taming Transformers paper 64 | # https://arxiv.org/abs/2012.09841 65 | 66 | 67 | def get_obj_from_str(string, reload=False): 68 | module, cls = string.rsplit(".", 1) 69 | if reload: 70 | module_imp = importlib.import_module(module) 71 | importlib.reload(module_imp) 72 | return getattr(importlib.import_module(module, package=None), cls) 73 | 74 | 75 | def instantiate_from_config(config): 76 | if "target" not in config: 77 | raise KeyError("Expected key `target` to instantiate.") 78 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 79 | 80 | 81 | class VQGanVAETaming(nn.Module): 82 | def __init__(self, vqgan_model_path=None, vqgan_config_path=None, accelerator: Accelerator = None): 83 | super().__init__() 84 | if accelerator is None: 85 | accelerator = Accelerator() 86 | 87 | # Download model if needed 88 | if vqgan_model_path is None: 89 | CACHE_PATH.mkdir(parents=True, exist_ok=True) 90 | model_filename = "vqgan.1024.model.ckpt" 91 | config_filename = "vqgan.1024.config.yml" 92 | with accelerator.local_main_process_first(): 93 | config_path = download(VQGAN_VAE_CONFIG_PATH, config_filename) 94 | model_path = download(VQGAN_VAE_PATH, model_filename) 95 | else: 96 | config_path = Path(vqgan_config_path) 97 | model_path = Path(vqgan_model_path) 98 | 99 | with accelerator.local_main_process_first(): 100 | config: DictConfig = OmegaConf.load(config_path) 101 | model: VQModel = instantiate_from_config(config["model"]) 102 | state = torch.load(model_path, map_location="cpu")["state_dict"] 103 | model.load_state_dict(state, strict=False) 104 | 105 | print(f"Loaded VQGAN from {model_path} and {config_path}") 106 | self.model: VQModel = model 107 | 108 | # f as used in https://github.com/CompVis/taming-transformers#overview-of-pretrained-models 109 | f = config.model.params.ddconfig.resolution / config.model.params.ddconfig.attn_resolutions[0] 110 | self.num_layers = int(log(f) / log(2)) 111 | self.channels = 3 112 | self.image_size = 256 113 | self.num_tokens = config.model.params.n_embed 114 | self.is_gumbel = False # isinstance(self.model, GumbelVQ) 115 | self.codebook_size = config["model"]["params"]["n_embed"] 116 | 117 | @torch.no_grad() 118 | def get_codebook_indices(self, img): 119 | b = img.shape[0] 120 | img = (2 * img) - 1 121 | _, _, [_, _, indices] = self.model.encode(img) 122 | if self.is_gumbel: 123 | return rearrange(indices, "b h w -> b (h w)", b=b) 124 | return rearrange(indices, "(b n) -> b n", b=b) 125 | 126 | def get_encoded_fmap_size(self, image_size): 127 | return image_size // (2**self.num_layers) 128 | 129 | def decode_from_ids(self, img_seq): 130 | img_seq = rearrange(img_seq, "b h w -> b (h w)") 131 | b, n = img_seq.shape 132 | one_hot_indices = F.one_hot(img_seq, num_classes=self.num_tokens).float() 133 | z = ( 134 | one_hot_indices @ self.model.quantize.embed.weight 135 | if self.is_gumbel 136 | else (one_hot_indices @ self.model.quantize.embedding.weight) 137 | ) 138 | 139 | z = rearrange(z, "b (h w) c -> b c h w", h=int(sqrt(n))) 140 | img = self.model.decode(z) 141 | 142 | img = (img.clamp(-1.0, 1.0) + 1) * 0.5 143 | return img 144 | 145 | def encode(self, im_seq): 146 | # encode output 147 | # fmap, loss, (perplexity, min_encodings, min_encodings_indices) = self.model.encode(im_seq) 148 | fmap, loss, (_, _, min_encodings_indices) = self.model.encode(im_seq) 149 | 150 | b, _, h, w = fmap.shape 151 | min_encodings_indices = rearrange(min_encodings_indices, "(b h w) -> b h w", h=h, w=w, b=b) 152 | return fmap, min_encodings_indices, loss 153 | 
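    # Minimal round-trip sketch (illustrative only; the checkpoint/config paths and the
    # image tensor `img` are placeholders, not files shipped with this repo):
    #   vae = VQGanVAETaming(vqgan_model_path="vqgan.ckpt", vqgan_config_path="vqgan.yaml")
    #   fmap, ids, loss = vae.encode(img)   # ids: integer codebook indices, shape (B, h, w)
    #   recon = vae.decode_from_ids(ids)    # reconstruction rescaled back to [0, 1]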
154 | def decode_ids(self, ids): 155 | return self.model.decode_code(ids) 156 | 157 | def copy_for_eval(self): 158 | device = next(self.parameters()).device 159 | vae_copy = copy.deepcopy(self.cpu()) 160 | 161 | vae_copy.eval() 162 | return vae_copy.to(device) 163 | 164 | def forward(self, img): 165 | raise NotImplementedError("Forward not implemented for Taming VAE") 166 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/vqvae/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import VQVAEConfig 2 | from .vqvae import VQVAE 3 | 4 | __all__ = [ 5 | "VQVAE", 6 | "VQVAEConfig", 7 | ] 8 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/vqvae/config.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | 3 | 4 | class EncoderConfig(BaseModel): 5 | image_size: int = Field(...) 6 | patch_size: int = Field(...) 7 | dim: int = Field(...) 8 | depth: int = Field(...) 9 | num_head: int = Field(...) 10 | mlp_dim: int = Field(...) 11 | in_channels: int = Field(...) 12 | dim_head: int = Field(...) 13 | dropout: float = Field(...) 14 | 15 | 16 | class DecoderConfig(BaseModel): 17 | image_size: int = Field(...) 18 | patch_size: int = Field(...) 19 | dim: int = Field(...) 20 | depth: int = Field(...) 21 | num_head: int = Field(...) 22 | mlp_dim: int = Field(...) 23 | out_channels: int = Field(...) 24 | dim_head: int = Field(...) 25 | dropout: float = Field(...) 26 | 27 | 28 | class VQVAEConfig(BaseModel): 29 | n_embed: int = Field(...) 30 | embed_dim: int = Field(...) 31 | beta: float = Field(...) 32 | enc: EncoderConfig = Field(...) 33 | dec: DecoderConfig = Field(...) 
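# Usage sketch (an assumption about how these configs are consumed, not part of the
# module's public API): VQVAE unpacks the nested enc/dec dicts into Encoder/Decoder,
# so a config instance can be expanded directly into the constructor:
#   from muse_maskgit_pytorch.vqvae import VQVAE
#   vae = VQVAE(**VIT_S_CONFIG.dict())  # VIT_S_CONFIG is defined just below
# (scripts/vqvae_test.py instead loads pretrained weights via VQVAE.from_pretrained.)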
34 | 35 | 36 | VIT_S_CONFIG = VQVAEConfig( 37 | n_embed=8192, 38 | embed_dim=32, 39 | beta=0.25, 40 | enc=EncoderConfig( 41 | image_size=256, 42 | patch_size=8, 43 | dim=512, 44 | depth=8, 45 | num_head=8, 46 | mlp_dim=2048, 47 | in_channels=3, 48 | dim_head=64, 49 | dropout=0.0, 50 | ), 51 | dec=DecoderConfig( 52 | image_size=256, 53 | patch_size=8, 54 | dim=512, 55 | depth=8, 56 | num_head=8, 57 | mlp_dim=2048, 58 | out_channels=3, 59 | dim_head=64, 60 | dropout=0.0, 61 | ), 62 | ) 63 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/vqvae/discriminator.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch.nn as nn 4 | 5 | 6 | def weights_init(m): 7 | classname = m.__class__.__name__ 8 | if classname.find("Conv") != -1: 9 | nn.init.normal_(m.weight.data, 0.0, 0.02) 10 | elif classname.find("BatchNorm") != -1: 11 | nn.init.normal_(m.weight.data, 1.0, 0.02) 12 | nn.init.constant_(m.bias.data, 0) 13 | 14 | 15 | class NLayerDiscriminator(nn.Module): 16 | """Defines a PatchGAN discriminator""" 17 | 18 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 19 | """Construct a PatchGAN discriminator 20 | Parameters: 21 | input_nc (int) -- the number of channels in input images 22 | ndf (int) -- the number of filters in the last conv layer 23 | n_layers (int) -- the number of conv layers in the discriminator 24 | norm_layer -- normalization layer 25 | """ 26 | super().__init__() 27 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 28 | use_bias = norm_layer.func == nn.InstanceNorm2d 29 | else: 30 | use_bias = norm_layer == nn.InstanceNorm2d 31 | 32 | kw = 4 33 | padw = 1 34 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 35 | nf_mult = 1 36 | nf_mult_prev = 1 37 | for n in range(1, n_layers): # gradually increase the number of filters 38 | nf_mult_prev = nf_mult 39 | nf_mult = min(2**n, 8) 40 | sequence += [ 41 | nn.Conv2d( 42 | ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias 43 | ), 44 | norm_layer(ndf * nf_mult), 45 | nn.LeakyReLU(0.2, True), 46 | ] 47 | 48 | nf_mult_prev = nf_mult 49 | nf_mult = min(2**n_layers, 8) 50 | sequence += [ 51 | nn.Conv2d( 52 | ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias 53 | ), 54 | norm_layer(ndf * nf_mult), 55 | nn.LeakyReLU(0.2, True), 56 | ] 57 | 58 | sequence += [ 59 | nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw) 60 | ] # output 1 channel prediction map 61 | self.model = nn.Sequential(*sequence) 62 | 63 | self.apply(self.init_func) 64 | 65 | def forward(self, input): 66 | """Standard forward.""" 67 | return self.model(input) 68 | 69 | def init_func(self, m): # define the initialization function 70 | init_gain = 0.02 71 | classname = m.__class__.__name__ 72 | if hasattr(m, "weight") and (classname.find("Conv") != -1 or classname.find("Linear") != -1): 73 | nn.init.normal_(m.weight.data, 0.0, init_gain) 74 | if hasattr(m, "bias") and m.bias is not None: 75 | nn.init.constant_(m.bias.data, 0.0) 76 | elif ( 77 | classname.find("BatchNorm2d") != -1 78 | ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 
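            # Descriptive note: N(1.0, 0.02) for norm-layer weights with zero bias follows
            # the standard pix2pix/PatchGAN initialization convention, matching weights_init above.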
79 | nn.init.normal_(m.weight.data, 1.0, init_gain) 80 | nn.init.constant_(m.bias.data, 0.0) 81 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/vqvae/layers.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from diffusers.utils import is_xformers_available 3 | from einops import rearrange 4 | from einops.layers.torch import Rearrange 5 | from torch import nn 6 | 7 | from muse_maskgit_pytorch.modules import CrossAttention, MemoryEfficientCrossAttention, SwiGLUFFNFused 8 | 9 | 10 | def pair(t): 11 | return t if isinstance(t, tuple) else (t, t) 12 | 13 | 14 | class FeedForward(nn.Module): 15 | def __init__(self, dim, mlp_dim, dropout=0.0): 16 | super().__init__() 17 | self.w_1 = nn.Linear(dim, mlp_dim) 18 | self.act = nn.GELU() 19 | self.dropout = nn.Dropout(p=dropout) 20 | self.w_2 = nn.Linear(mlp_dim, dim) 21 | 22 | def forward(self, x): 23 | x = self.w_1(x) 24 | x = self.act(x) 25 | x = self.dropout(x) 26 | x = self.w_2(x) 27 | 28 | return x 29 | 30 | 31 | class LayerScale(nn.Module): 32 | def __init__(self, dim, init_values=1e-5, inplace=False): 33 | super().__init__() 34 | self.inplace = inplace 35 | self.gamma = nn.Parameter(init_values * torch.ones(dim)) 36 | 37 | def forward(self, x): 38 | return x.mul_(self.gamma) if self.inplace else x * self.gamma 39 | 40 | 41 | class Layer(nn.Module): 42 | ATTENTION_MODES = {"vanilla": CrossAttention, "xformer": MemoryEfficientCrossAttention} 43 | 44 | def __init__(self, dim, dim_head, mlp_dim, num_head=8, dropout=0.0): 45 | super().__init__() 46 | attn_mode = "xformer" if is_xformers_available() else "vanilla" 47 | attn_cls = self.ATTENTION_MODES[attn_mode] 48 | self.norm1 = nn.LayerNorm(dim) 49 | self.attn1 = attn_cls(query_dim=dim, heads=num_head, dim_head=dim_head, dropout=dropout) 50 | self.norm2 = nn.LayerNorm(dim) 51 | self.ffnet = SwiGLUFFNFused(in_features=dim, hidden_features=mlp_dim) 52 | 53 | def forward(self, x): 54 | x = self.attn1(self.norm1(x)) + x 55 | x = self.ffnet(self.norm2(x)) + x 56 | 57 | return x 58 | 59 | 60 | class Transformer(nn.Module): 61 | def __init__(self, dim, depth, num_head, dim_head, mlp_dim, dropout=0.0): 62 | super().__init__() 63 | self.layers = nn.Sequential(*[Layer(dim, dim_head, mlp_dim, num_head, dropout) for i in range(depth)]) 64 | 65 | def forward(self, x): 66 | x = self.layers(x) 67 | 68 | return x 69 | 70 | 71 | class Encoder(nn.Module): 72 | def __init__( 73 | self, 74 | image_size, 75 | patch_size, 76 | dim, 77 | depth, 78 | num_head, 79 | mlp_dim, 80 | in_channels=3, 81 | dim_head=64, 82 | dropout=0.0, 83 | ): 84 | super().__init__() 85 | 86 | self.image_size = image_size 87 | self.patch_size = patch_size 88 | 89 | assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size." 
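        # e.g. with the VIT_S_CONFIG values (image_size=256, patch_size=8), the strided
        # Conv2d below yields a 32x32 patch grid, i.e. (256 // 8) ** 2 = 1024 tokens of
        # width `dim`, matching the learned position embedding.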
90 | 91 | self.to_patch_embedding = nn.Sequential( 92 | nn.Conv2d(in_channels, dim, kernel_size=patch_size, stride=patch_size, bias=False), 93 | Rearrange("b c h w -> b (h w) c"), 94 | ) 95 | 96 | scale = dim**-0.5 97 | num_patches = (image_size // patch_size) ** 2 98 | self.position_embedding = nn.Parameter(torch.randn(1, num_patches, dim) * scale) 99 | self.norm_pre = nn.LayerNorm(dim) 100 | self.transformer = Transformer(dim, depth, num_head, dim_head, mlp_dim, dropout) 101 | 102 | self.initialize_weights() 103 | 104 | def initialize_weights(self): 105 | self.apply(self._init_weights) 106 | 107 | def _init_weights(self, m): 108 | if isinstance(m, nn.Linear): 109 | torch.nn.init.xavier_uniform_(m.weight) 110 | if isinstance(m, nn.Linear) and m.bias is not None: 111 | nn.init.constant_(m.bias, 0) 112 | elif isinstance(m, nn.LayerNorm): 113 | nn.init.constant_(m.bias, 0) 114 | nn.init.constant_(m.weight, 1.0) 115 | 116 | def forward(self, x): 117 | x = self.to_patch_embedding(x) 118 | x = x + self.position_embedding 119 | x = self.norm_pre(x) 120 | x = self.transformer(x) 121 | 122 | return x 123 | 124 | 125 | class Decoder(nn.Module): 126 | def __init__( 127 | self, 128 | image_size, 129 | patch_size, 130 | dim, 131 | depth, 132 | num_head, 133 | mlp_dim, 134 | out_channels=3, 135 | dim_head=64, 136 | dropout=0.0, 137 | ): 138 | super().__init__() 139 | 140 | self.image_size = image_size 141 | self.patch_size = patch_size 142 | 143 | assert image_size % patch_size == 0, "Image dimensions must be divisible by the patch size." 144 | 145 | scale = dim**-0.5 146 | num_patches = (image_size // patch_size) ** 2 147 | self.position_embedding = nn.Parameter(torch.randn(1, num_patches, dim) * scale) 148 | self.transformer = Transformer(dim, depth, num_head, dim_head, mlp_dim, dropout) 149 | self.norm = nn.LayerNorm(dim) 150 | self.proj = nn.Linear(dim, out_channels * patch_size * patch_size, bias=True) 151 | 152 | self.initialize_weights() 153 | 154 | def initialize_weights(self): 155 | self.apply(self._init_weights) 156 | 157 | def _init_weights(self, m): 158 | if isinstance(m, nn.Linear): 159 | torch.nn.init.xavier_uniform_(m.weight) 160 | if isinstance(m, nn.Linear) and m.bias is not None: 161 | nn.init.constant_(m.bias, 0) 162 | elif isinstance(m, nn.LayerNorm): 163 | nn.init.constant_(m.bias, 0) 164 | nn.init.constant_(m.weight, 1.0) 165 | 166 | def forward(self, x): 167 | x = x + self.position_embedding 168 | x = self.transformer(x) 169 | x = self.norm(x) 170 | x = self.proj(x) 171 | x = rearrange( 172 | x, 173 | "b (h w) (p1 p2 c) -> b c (h p1) (w p2)", 174 | h=self.image_size // self.patch_size, 175 | p1=self.patch_size, 176 | p2=self.patch_size, 177 | ) 178 | 179 | return x 180 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/vqvae/quantize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class VectorQuantize(nn.Module): 7 | def __init__(self, n_e, vq_embed_dim, beta=0.25): 8 | super().__init__() 9 | self.n_e = n_e 10 | self.vq_embed_dim = vq_embed_dim 11 | self.beta = beta 12 | 13 | self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim) 14 | self.embedding.weight.data.normal_() 15 | 16 | def forward(self, z): 17 | z = F.normalize(z, p=2, dim=-1) 18 | z_flattened = z.view(-1, self.vq_embed_dim) 19 | embed_norm = F.normalize(self.embedding.weight, p=2, dim=-1) 20 | 21 | # distances from z to embeddings e_j 
(z - e)^2 = z^2 + e^2 - 2 e * z 22 | d = ( 23 | torch.sum(z_flattened**2, dim=1, keepdim=True) 24 | + torch.sum(embed_norm**2, dim=1) 25 | - 2 * torch.einsum("bd,nd->bn", z_flattened, embed_norm) 26 | ) 27 | 28 | encoding_indices = torch.argmin(d, dim=1).view(*z.shape[:-1]) 29 | z_q = self.embedding(encoding_indices).view(z.shape) 30 | z_q = F.normalize(z_q, p=2, dim=-1) 31 | 32 | # compute loss for embedding 33 | loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + torch.mean((z_q - z.detach()) ** 2) 34 | 35 | # preserve gradients 36 | z_q = z + (z_q - z).detach() 37 | 38 | return z_q, loss, encoding_indices 39 | 40 | def decode_ids(self, indices): 41 | z_q = self.embedding(indices) 42 | z_q = F.normalize(z_q, p=2, dim=-1) 43 | 44 | return z_q 45 | -------------------------------------------------------------------------------- /muse_maskgit_pytorch/vqvae/vqvae.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | from diffusers import ConfigMixin, ModelMixin 6 | from diffusers.configuration_utils import register_to_config 7 | 8 | from .layers import Decoder, Encoder 9 | from .quantize import VectorQuantize 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class VQVAE(ModelMixin, ConfigMixin): 15 | @register_to_config 16 | def __init__(self, n_embed, embed_dim, beta, enc, dec, **kwargs): 17 | super().__init__() 18 | self.encoder = Encoder(**enc) 19 | self.decoder = Decoder(**dec) 20 | 21 | self.prev_quant = nn.Linear(enc["dim"], embed_dim) 22 | self.quantize = VectorQuantize(n_embed, embed_dim, beta) 23 | self.post_quant = nn.Linear(embed_dim, dec["dim"]) 24 | 25 | def freeze(self): 26 | self.eval() 27 | self.requires_grad_(False) 28 | 29 | def encode(self, x): 30 | x = self.encoder(x) 31 | x = self.prev_quant(x) 32 | x, loss, indices = self.quantize(x) 33 | return x, loss, indices 34 | 35 | def decode(self, x): 36 | x = self.post_quant(x) 37 | x = self.decoder(x) 38 | return x.clamp(-1.0, 1.0) 39 | 40 | def forward(self, inputs: torch.FloatTensor): 41 | z, loss, _ = self.encode(inputs) 42 | rec = self.decode(z) 43 | return rec, loss 44 | 45 | def encode_to_ids(self, inputs): 46 | _, _, indices = self.encode(inputs) 47 | return indices 48 | 49 | def decode_from_ids(self, indice): 50 | z_q = self.quantize.decode_ids(indice) 51 | img = self.decode(z_q) 52 | return img 53 | 54 | def __call__(self, inputs: torch.FloatTensor): 55 | return self.forward(inputs) 56 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | build-backend = "setuptools.build_meta" 3 | requires = ["setuptools>=61.0.0", "wheel", "setuptools_scm[toml]>=6.2"] 4 | 5 | [tool.setuptools_scm] 6 | write_to = "muse_maskgit_pytorch/_version.py" 7 | 8 | [tool.black] 9 | line-length = 110 10 | target-version = ['py38', 'py39', 'py310'] 11 | 12 | [tool.ruff] 13 | line-length = 110 14 | target-version = 'py38' 15 | format = "grouped" 16 | ignore-init-module-imports = true 17 | select = ["E", "F", "I"] 18 | ignore = ['F841', 'F401', 'E501'] 19 | 20 | [tool.ruff.isort] 21 | combine-as-imports = true 22 | force-wrap-aliases = true 23 | known-local-folder = ["muse_maskgit_pytorch"] 24 | known-first-party = ["muse_maskgit_pytorch"] 25 | -------------------------------------------------------------------------------- /scripts/vqvae_test.py: 
-------------------------------------------------------------------------------- 1 | import logging 2 | from pathlib import Path 3 | 4 | import torch 5 | from huggingface_hub import hf_hub_download 6 | from PIL import Image 7 | from torchvision import transforms as T 8 | from torchvision.utils import save_image 9 | 10 | from muse_maskgit_pytorch.vqvae import VQVAE 11 | 12 | logging.basicConfig(level=logging.INFO) 13 | logger = logging.getLogger(__name__) 14 | 15 | # where to find the model and the test images 16 | model_repo = "neggles/vaedump" 17 | model_subdir = "vit-s-vqgan-f4" 18 | test_images = ["testimg_1.png", "testimg_2.png"] 19 | 20 | # where to save the preprocessed and reconstructed images 21 | image_dir = Path.cwd().joinpath("temp") 22 | image_dir.mkdir(exist_ok=True, parents=True) 23 | 24 | # image transforms for the VQVAE 25 | transform_enc = T.Compose([T.Resize(512), T.RandomCrop(256), T.ToTensor()]) 26 | transform_dec = T.Compose([T.ConvertImageDtype(torch.uint8), T.ToPILImage()]) 27 | 28 | 29 | def get_save_path(path: Path, append: str) -> Path: 30 | # append a string to the filename before the extension 31 | # n.b. only keeps the final suffix, e.g. "foo.xyz.png" -> "foo-prepro.png" 32 | return path.with_name(f"{path.stem}-{append}{path.suffix}") 33 | 34 | 35 | def main(): 36 | torch_device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 37 | dtype = torch.float32 38 | 39 | # load VAE 40 | logger.info(f"Loading VQVAE from {model_repo}/{model_subdir}...") 41 | vae: VQVAE = VQVAE.from_pretrained(model_repo, subfolder=model_subdir, torch_dtype=dtype) 42 | vae = vae.to(torch_device) 43 | logger.info(f"Loaded VQVAE from {model_repo} to {vae.device} with dtype {vae.dtype}") 44 | 45 | # download and process images 46 | for image in test_images: 47 | image_path = hf_hub_download(model_repo, subfolder="images", filename=image, local_dir=image_dir) 48 | image_path = Path(image_path) 49 | logger.info(f"Downloaded {image_path}, size {image_path.stat().st_size} bytes") 50 | 51 | # preprocess 52 | image_obj = Image.open(image_path).convert("RGB") 53 | image_tensor: torch.Tensor = transform_enc(image_obj) 54 | save_path = get_save_path(image_path, "prepro") 55 | save_image(image_tensor, save_path, normalize=True, range=(-1.0, 1.0)) 56 | logger.info(f"Saved preprocessed image to {save_path}") 57 | 58 | # encode 59 | encoded, _, _ = vae.encode(image_tensor.unsqueeze(0).to(vae.device)) 60 | 61 | # decode 62 | reconstructed = vae.decode(encoded).squeeze(0) 63 | reconstructed = torch.clamp(reconstructed, -1.0, 1.0) 64 | 65 | # save 66 | save_path = get_save_path(image_path, "recon") 67 | save_image(reconstructed, save_path, normalize=True, range=(-1.0, 1.0)) 68 | logger.info(f"Saved reconstructed image to {save_path}") 69 | 70 | # compare 71 | image_prepro = transform_dec(image_tensor) 72 | image_recon = transform_dec(reconstructed) 73 | canvas = Image.new("RGB", (512 + 12, 256 + 8), (248, 248, 242)) 74 | canvas.paste(image_prepro, (4, 4)) 75 | canvas.paste(image_recon, (256 + 8, 4)) 76 | save_path = get_save_path(image_path, "compare") 77 | canvas.save(save_path) 78 | logger.info(f"Saved comparison image to {save_path}") 79 | 80 | logger.info("Done!") 81 | 82 | 83 | if __name__ == "__main__": 84 | main() 85 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | 
name="muse-maskgit-pytorch", 5 | packages=find_packages(exclude=[]), 6 | version="0.1.6", 7 | license="MIT", 8 | description="MUSE - Text-to-Image Generation via Masked Generative Transformers, in Pytorch", 9 | author="Phil Wang", 10 | author_email="lucidrains@gmail.com", 11 | long_description_content_type="text/markdown", 12 | url="https://github.com/lucidrains/muse-maskgit-pytorch", 13 | keywords=[ 14 | "artificial intelligence", 15 | "deep learning", 16 | "transformers", 17 | "attention mechanism", 18 | "text-to-image", 19 | ], 20 | extras_require={ 21 | "dev": [ 22 | "pre-commit>=3.3.2", 23 | "black>=23.3.0", 24 | "ruff>=0.0.272", 25 | ] 26 | }, 27 | install_requires=[ 28 | "accelerate", 29 | "diffusers", 30 | "datasets", 31 | "beartype", 32 | "einops>=0.6", 33 | "ema-pytorch", 34 | "omegaconf>=2.3.0", 35 | "pillow", 36 | "sentencepiece", 37 | "torch>=2.0", 38 | "torchmetrics<0.8.0", 39 | "pytorch-lightning>=2.0.0", 40 | "taming-transformers @ git+https://github.com/neggles/taming-transformers.git@v0.0.2", 41 | "transformers", 42 | "torchvision", 43 | "torch_optimizer", 44 | "tqdm", 45 | "timm", 46 | "tqdm-loggable", 47 | "vector-quantize-pytorch>=0.10.14", 48 | "lion-pytorch", 49 | "omegaconf", 50 | "xformers>=0.0.20", 51 | ], 52 | classifiers=[ 53 | "Development Status :: 4 - Beta", 54 | "Intended Audience :: Developers", 55 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 56 | "License :: OSI Approved :: MIT License", 57 | "Programming Language :: Python :: 3.6", 58 | ], 59 | ) 60 | -------------------------------------------------------------------------------- /tpu-vm.env: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | # dot-source this, ok? 3 | 4 | export PYTHONUNBUFFERED='1' 5 | 6 | ## General log level opts for Accelerate/Transformers 7 | #export ACCELERATE_LOG_LEVEL='INFO' 8 | #export TRANSFORMERS_LOG_LEVEL='INFO' 9 | 10 | # tcmalloc breaks things and google enable it by default, so that's gotta go 11 | unset LD_PRELOAD 12 | 13 | # add the dir where `libtpu-nightly` puts the library to LD_LIBRARY_PATH 14 | export LD_LIBRARY_PATH="/usr/local/lib/python3.8/dist-packages/libtpu/:${LD_LIBRARY_PATH}" 15 | 16 | # PJRT doesn't work with Accelerate yet so we deconfigure it and go back to old XRT 17 | unset PJRT_DEVICE 18 | export XRT_TPU_CONFIG='localservice;0;localhost:51011' 19 | export MASTER_ADDR='localhost' 20 | export MASTER_PORT='12355' 21 | 22 | ## see https://github.com/pytorch/xla/issues/4914 23 | export XLA_IR_SHAPE_CACHE_SIZE=12288 24 | 25 | ## useful options for debug 26 | #export PT_XLA_DEBUG=1 27 | # Enables the Python stack trace to be captured where creating IR nodes, hence allowing to understand which PyTorch operation was responsible for generating the IR. 28 | #export XLA_IR_DEBUG=1 29 | # Path to save the IR graphs generated during execution. 30 | #export XLA_SAVE_TENSORS_FILE='' 31 | # File type for above. can be text, dot (GraphViz), or hlo (native) 32 | #export XLA_SAVE_TENSORS_FMT='text' 33 | # Path to save metrics after every op 34 | #export XLA_METRICS_FILE= 35 | # In case of compilation/execution error, the offending HLO graph will be saved here. 
36 | #export XLA_SAVE_HLO_FILE= 37 | 38 | # Enable OpByOp dispatch for "get tensors" 39 | #export XLA_GET_TENSORS_OPBYOP=1 40 | # Enable OpByOp dispatch for "sync tensors" 41 | #export XLA_SYNC_TENSORS_OPBYOP=1 42 | # Force XLA tensor sync before moving to next step 43 | #export XLA_SYNC_WAIT=1 44 | 45 | # Force downcasting of fp32 to bf16 46 | #export XLA_USE_BF16=1 47 | # Force downcasting of fp32 to fp16 48 | #export XLA_USE_F16=1 49 | # Force downcasting of fp64 to fp32 50 | #export XLA_USE_32BIT_LONG=1 51 | 52 | ## TPU runtime / compilation debug logging 53 | # All XLA log messages are INFO level so this is required 54 | #export TF_CPP_MIN_LOG_LEVEL=0 55 | # Print the thread ID in log messages 56 | #export TF_CPP_LOG_THREAD_ID=1 57 | # What modules to print from at what level 58 | #export TF_CPP_VMODULE='tensor=4,computation_client=5,xrt_computation_client=5,aten_xla_type=5' 59 | 60 | ## Limit to single TPU chip/core, can be useful for testing 61 | # export TPU_PROCESS_BOUNDS='1,1,1' 62 | # export TPU_VISIBLE_CHIPS=0 63 | -------------------------------------------------------------------------------- /train_muse_vae.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import glob 3 | import os 4 | import re 5 | from dataclasses import dataclass 6 | from typing import Optional, Union 7 | 8 | import wandb 9 | from accelerate.utils import ProjectConfiguration 10 | from datasets import load_dataset 11 | from omegaconf import OmegaConf 12 | 13 | from muse_maskgit_pytorch import ( 14 | VQGanVAE, 15 | VQGanVAETaming, 16 | VQGanVAETrainer, 17 | get_accelerator, 18 | ) 19 | from muse_maskgit_pytorch.dataset import ( 20 | ImageDataset, 21 | get_dataset_from_dataroot, 22 | split_dataset_into_dataloaders, 23 | ) 24 | 25 | # disable bitsandbytes welcome message. 
26 | os.environ["BITSANDBYTES_NOWELCOME"] = "1" 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument("--webdataset", type=str, default=None, help="Path to webdataset if using one.") 30 | parser.add_argument( 31 | "--only_save_last_checkpoint", 32 | action="store_true", 33 | help="Only save last checkpoint.", 34 | ) 35 | parser.add_argument( 36 | "--validation_image_scale", 37 | default=1, 38 | type=float, 39 | help="Factor by which to scale the validation images.", 40 | ) 41 | parser.add_argument( 42 | "--no_center_crop", 43 | action="store_true", 44 | help="Don't do center crop.", 45 | ) 46 | parser.add_argument( 47 | "--no_flip", 48 | action="store_true", 49 | help="Don't flip image.", 50 | ) 51 | parser.add_argument( 52 | "--random_crop", 53 | action="store_true", 54 | help="Crop the images at random locations instead of cropping from the center.", 55 | ) 56 | parser.add_argument( 57 | "--dataset_save_path", 58 | type=str, 59 | default="dataset", 60 | help="Path to save the dataset if you are making one from a directory", 61 | ) 62 | parser.add_argument( 63 | "--clear_previous_experiments", 64 | action="store_true", 65 | help="Whether to clear previous experiments.", 66 | ) 67 | parser.add_argument("--max_grad_norm", type=float, default=None, help="Max gradient norm.") 68 | parser.add_argument( 69 | "--discr_max_grad_norm", 70 | type=float, 71 | default=None, 72 | help="Max gradient norm for discriminator.", 73 | ) 74 | parser.add_argument("--seed", type=int, default=42, help="Seed.") 75 | parser.add_argument("--valid_frac", type=float, default=0.05, help="validation fraction.") 76 | parser.add_argument("--use_ema", action="store_true", help="Whether to use ema.") 77 | parser.add_argument("--ema_beta", type=float, default=0.995, help="Ema beta.") 78 | parser.add_argument("--ema_update_after_step", type=int, default=1, help="Ema update after step.") 79 | parser.add_argument( 80 | "--ema_update_every", 81 | type=int, 82 | default=1, 83 | help="Ema update every this number of steps.", 84 | ) 85 | parser.add_argument( 86 | "--apply_grad_penalty_every", 87 | type=int, 88 | default=4, 89 | help="Apply gradient penalty every this number of steps.", 90 | ) 91 | parser.add_argument( 92 | "--image_column", 93 | type=str, 94 | default="image", 95 | help="The column of the dataset containing an image.", 96 | ) 97 | parser.add_argument( 98 | "--caption_column", 99 | type=str, 100 | default="caption", 101 | help="The column of the dataset containing a caption or a list of captions.", 102 | ) 103 | parser.add_argument( 104 | "--log_with", 105 | type=str, 106 | default="wandb", 107 | help=( 108 | 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' 109 | ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' 110 | ), 111 | ) 112 | parser.add_argument( 113 | "--project_name", 114 | type=str, 115 | default="muse_vae", 116 | help=("Name to use for the project to identify it when saved to a tracker such as wandb or tensorboard."), 117 | ) 118 | parser.add_argument( 119 | "--run_name", 120 | type=str, 121 | default=None, 122 | help=( 123 | "Name to use for the run to identify it when saved to a tracker such" 124 | " as wandb or tensorboard. If not specified a random one will be generated." 125 | ), 126 | ) 127 | parser.add_argument( 128 | "--wandb_user", 129 | type=str, 130 | default=None, 131 | help=( 132 | "Specify the name for the user or the organization in which the project will be saved when using wand." 
133 | ), 134 | ) 135 | parser.add_argument( 136 | "--mixed_precision", 137 | type=str, 138 | default="no", 139 | choices=["no", "fp8", "fp16", "bf16"], 140 | help="Precision to train on.", 141 | ) 142 | parser.add_argument( 143 | "--use_8bit_adam", 144 | action="store_true", 145 | help="Whether to use the 8bit adam optimiser", 146 | ) 147 | parser.add_argument( 148 | "--results_dir", 149 | type=str, 150 | default="results", 151 | help="Path to save the training samples and checkpoints", 152 | ) 153 | parser.add_argument( 154 | "--logging_dir", 155 | type=str, 156 | default=None, 157 | help="Path to log the losses and LR", 158 | ) 159 | 160 | # vae_trainer args 161 | parser.add_argument( 162 | "--dataset_name", 163 | type=str, 164 | default=None, 165 | help="Name of the huggingface dataset used.", 166 | ) 167 | parser.add_argument( 168 | "--hf_split_name", 169 | type=str, 170 | default="train", 171 | help="Subset or split to use from the dataset when using a dataset form HuggingFace.", 172 | ) 173 | parser.add_argument( 174 | "--streaming", 175 | action="store_true", 176 | help="Whether to stream the huggingface dataset", 177 | ) 178 | parser.add_argument( 179 | "--train_data_dir", 180 | type=str, 181 | default=None, 182 | help="Dataset folder where your input images for training are.", 183 | ) 184 | parser.add_argument( 185 | "--num_train_steps", 186 | type=int, 187 | default=-1, 188 | help="Total number of steps to train for. eg. 50000. | Use only if you want to stop training early", 189 | ) 190 | parser.add_argument( 191 | "--num_epochs", 192 | type=int, 193 | default=5, 194 | help="Total number of epochs to train for. eg. 5.", 195 | ) 196 | parser.add_argument("--dim", type=int, default=128, help="Model dimension.") 197 | parser.add_argument("--batch_size", type=int, default=1, help="Batch Size.") 198 | parser.add_argument("--lr", type=float, default=1e-5, help="Learning Rate.") 199 | parser.add_argument( 200 | "--gradient_accumulation_steps", 201 | type=int, 202 | default=1, 203 | help="Gradient Accumulation.", 204 | ) 205 | parser.add_argument( 206 | "--save_results_every", 207 | type=int, 208 | default=100, 209 | help="Save results every this number of steps.", 210 | ) 211 | parser.add_argument( 212 | "--save_model_every", 213 | type=int, 214 | default=500, 215 | help="Save the model every this number of steps.", 216 | ) 217 | parser.add_argument( 218 | "--checkpoint_limit", 219 | type=int, 220 | default=None, 221 | help="Keep only X number of checkpoints and delete the older ones.", 222 | ) 223 | parser.add_argument("--vq_codebook_size", type=int, default=256, help="Image Size.") 224 | parser.add_argument("--vq_codebook_dim", type=int, default=256, help="VQ Codebook dimensions.") 225 | parser.add_argument( 226 | "--channels", type=int, default=3, help="Number of channels for the VAE. Use 3 for RGB or 4 for RGBA." 227 | ) 228 | parser.add_argument("--layers", type=int, default=4, help="Number of layers for the VAE.") 229 | parser.add_argument("--discr_layers", type=int, default=4, help="Number of layers for the VAE discriminator.") 230 | parser.add_argument( 231 | "--image_size", 232 | type=int, 233 | default=256, 234 | help="Image size. You may want to start with small images, and then curriculum learn to larger ones, but because the vae is all convolution, it should generalize to 512 (as in paper) without training on it", 235 | ) 236 | parser.add_argument( 237 | "--lr_scheduler", 238 | type=str, 239 | default="constant", 240 | help='The scheduler type to use. 
Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]', 241 | ) 242 | parser.add_argument( 243 | "--scheduler_power", 244 | type=float, 245 | default=1.0, 246 | help="Controls the power of the polynomial decay schedule used by the CosineScheduleWithWarmup scheduler. " 247 | "It determines the rate at which the learning rate decreases during the schedule.", 248 | ) 249 | parser.add_argument( 250 | "--lr_warmup_steps", 251 | type=int, 252 | default=0, 253 | help="Number of steps for the warmup in the lr scheduler.", 254 | ) 255 | parser.add_argument( 256 | "--num_cycles", 257 | type=int, 258 | default=1, 259 | help="Number of cycles for the lr scheduler.", 260 | ) 261 | parser.add_argument( 262 | "--resume_path", 263 | type=str, 264 | default=None, 265 | help="Path to the last saved checkpoint. 'results/vae.steps.pt'", 266 | ) 267 | parser.add_argument( 268 | "--weight_decay", 269 | type=float, 270 | default=0.0, 271 | help="Optimizer weight_decay to use. Default: 0.0", 272 | ) 273 | parser.add_argument( 274 | "--taming_model_path", 275 | type=str, 276 | default=None, 277 | help="path to your trained VQGAN weights. This should be a .ckpt file. (only valid when taming option is enabled)", 278 | ) 279 | parser.add_argument( 280 | "--taming_config_path", 281 | type=str, 282 | default=None, 283 | help="path to your trained VQGAN config. This should be a .yaml file. (only valid when taming option is enabled)", 284 | ) 285 | parser.add_argument( 286 | "--optimizer", 287 | type=str, 288 | default="Lion", 289 | help="Optimizer to use. Choose between: ['Adam', 'AdamW','Lion']. Default: Lion", 290 | ) 291 | parser.add_argument( 292 | "--cache_path", 293 | type=str, 294 | default=None, 295 | help="The path to cache huggingface models", 296 | ) 297 | parser.add_argument( 298 | "--no_cache", 299 | action="store_true", 300 | help="Do not save the dataset pyarrow cache/files to disk to save disk space and reduce the time it takes to launch the training.", 301 | ) 302 | parser.add_argument( 303 | "--latest_checkpoint", 304 | action="store_true", 305 | help="Whether to use the latest checkpoint", 306 | ) 307 | parser.add_argument( 308 | "--do_not_save_config", 309 | action="store_true", 310 | default=False, 311 | help="Generate example YAML configuration file", 312 | ) 313 | parser.add_argument( 314 | "--use_l2_recon_loss", 315 | action="store_true", 316 | help="Use F.mse_loss instead of F.l1_loss.", 317 | ) 318 | 319 | 320 | @dataclass 321 | class Arguments: 322 | total_params: Optional[int] = None 323 | only_save_last_checkpoint: bool = False 324 | validation_image_scale: float = 1.0 325 | no_center_crop: bool = False 326 | no_flip: bool = False 327 | random_crop: bool = False 328 | dataset_save_path: Optional[str] = None 329 | clear_previous_experiments: bool = False 330 | max_grad_norm: Optional[float] = None 331 | discr_max_grad_norm: Optional[float] = None 332 | num_tokens: int = 256 333 | seq_len: int = 1024 334 | seed: int = 42 335 | valid_frac: float = 0.05 336 | use_ema: bool = False 337 | ema_beta: float = 0.995 338 | ema_update_after_step: int = 1 339 | ema_update_every: int = 1 340 | apply_grad_penalty_every: int = 4 341 | image_column: str = "image" 342 | caption_column: str = "caption" 343 | log_with: str = "wandb" 344 | mixed_precision: str = "no" 345 | use_8bit_adam: bool = False 346 | results_dir: str = "results" 347 | logging_dir: Optional[str] = None 348 | resume_path: Optional[str] = None 349 | dataset_name: Optional[str] = None 350 
    streaming: bool = False
    train_data_dir: Optional[str] = None
    num_train_steps: int = -1
    num_epochs: int = 5
    dim: int = 128
    batch_size: int = 512
    lr: float = 1e-5
    gradient_accumulation_steps: int = 1
    save_results_every: int = 100
    save_model_every: int = 500
    checkpoint_limit: Optional[Union[int, str]] = None
    vq_codebook_size: int = 256
    vq_codebook_dim: int = 256
    cond_drop_prob: float = 0.5
    image_size: int = 256
    lr_scheduler: str = "constant"
    scheduler_power: float = 1.0
    lr_warmup_steps: int = 0
    num_cycles: int = 1
    taming_model_path: Optional[str] = None
    taming_config_path: Optional[str] = None
    optimizer: str = "Lion"
    weight_decay: float = 0.0
    cache_path: Optional[str] = None
    no_cache: bool = False
    latest_checkpoint: bool = False
    do_not_save_config: bool = False
    use_l2_recon_loss: bool = False
    debug: bool = False
    config_path: Optional[str] = None


def preprocess_webdataset(args, image):
    return {args.image_column: image}


def main():
    args = parser.parse_args(namespace=Arguments())

    if args.config_path:
        print("Using config file and ignoring CLI args")

        try:
            conf = OmegaConf.load(args.config_path)
            args_to_convert = vars(args)

            for key in conf.keys():
                try:
                    args_to_convert[key] = conf[key]
                except KeyError:
                    print(f"Error parsing config - {key}: {conf[key]} | Using default or parsed value")

        except FileNotFoundError:
            print("Could not find config, using default and parsed values...")

    project_config = ProjectConfiguration(
        project_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
        total_limit=args.checkpoint_limit,
        automatic_checkpoint_naming=True,
    )

    accelerator = get_accelerator(
        log_with=args.log_with,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        project_config=project_config,
        even_batches=True,
    )
    if accelerator.is_main_process:
        accelerator.init_trackers(
            args.project_name,
            config=vars(args),
            init_kwargs={
                "wandb": {
                    "entity": f"{args.wandb_user or wandb.api.default_entity}",
                    "name": args.run_name,
                },
            },
        )

    if args.webdataset is not None:
        import webdataset as wds

        dataset = wds.WebDataset(args.webdataset).shuffle(1000).decode("rgb").to_tuple("png")
        dataset = dataset.map(lambda image: preprocess_webdataset(args, image))
    elif args.train_data_dir:
        dataset = get_dataset_from_dataroot(
            args.train_data_dir,
            image_column=args.image_column,
            caption_column=args.caption_column,
            save_path=args.dataset_save_path,
            save=not args.no_cache,
        )
    elif args.dataset_name:
        # cache_dir=None falls back to the default Hugging Face cache location.
        dataset = load_dataset(args.dataset_name, streaming=args.streaming, cache_dir=args.cache_path)["train"]
        if args.streaming:
            if dataset.info.dataset_size is None:
                print("Dataset doesn't support streaming, disabling streaming")
                args.streaming = False
                # Reload the dataset without streaming; cache_dir=None uses the default HF cache.
                dataset = load_dataset(args.dataset_name, cache_dir=args.cache_path)[args.hf_split_name]

    if args.resume_path is not None and len(args.resume_path) > 1:
        load = True
        accelerator.print(f"Using Muse VQGanVAE, loading from {args.resume_path}")
        vae = VQGanVAE(
            dim=args.dim,
            vq_codebook_dim=args.vq_codebook_dim,
            vq_codebook_size=args.vq_codebook_size,
            l2_recon_loss=args.use_l2_recon_loss,
            channels=args.channels,
            layers=args.layers,
            discr_layers=args.discr_layers,
            accelerator=accelerator,
        )

        if args.latest_checkpoint:
            accelerator.print("Finding latest checkpoint...")
            orig_vae_path = args.resume_path

            if os.path.isfile(args.resume_path) or ".pt" in args.resume_path:
                # If args.resume_path points to a file, split it into directory and filename
                args.resume_path, _ = os.path.split(args.resume_path)

            checkpoint_files = glob.glob(os.path.join(args.resume_path, "vae.*.pt"))
            if checkpoint_files:
                latest_checkpoint_file = max(
                    checkpoint_files, key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1))
                )

                # Check if the latest checkpoint is empty or unreadable
                if os.path.getsize(latest_checkpoint_file) == 0 or not os.access(
                    latest_checkpoint_file, os.R_OK
                ):
                    accelerator.print(
                        f"Warning: latest checkpoint {latest_checkpoint_file} is empty or unreadable."
                    )
                    if len(checkpoint_files) > 1:
                        # Use the second-newest checkpoint as a fallback
                        latest_checkpoint_file = max(
                            checkpoint_files[:-1], key=lambda x: int(re.search(r"vae\.(\d+)\.pt", x).group(1))
                        )
                        accelerator.print("Using second last checkpoint: ", latest_checkpoint_file)
                    else:
                        accelerator.print("No usable checkpoint found.")
                        load = False
                elif latest_checkpoint_file != orig_vae_path:
                    accelerator.print("Resuming VAE from latest checkpoint: ", latest_checkpoint_file)
                else:
                    accelerator.print("Using checkpoint specified in resume_path: ", orig_vae_path)

                args.resume_path = latest_checkpoint_file
            else:
                accelerator.print("No checkpoints found in directory: ", args.resume_path)
                load = False
        else:
            accelerator.print("Resuming VAE from: ", args.resume_path)

        if load:
            vae.load(args.resume_path, map="cpu")

            # Recover the training step from the checkpoint filename, e.g. vae.1000.pt -> 1000
            current_step = 0
            resume_from_parts = args.resume_path.split(".")
            for i in range(len(resume_from_parts) - 1, -1, -1):
                if resume_from_parts[i].isdigit():
                    current_step = int(resume_from_parts[i])
                    accelerator.print(f"Found step {current_step} for the VAE model.")
                    break
            if current_step == 0:
                accelerator.print("No step found for the VAE model.")
        else:
            accelerator.print("No step found for the VAE model.")
            current_step = 0

    elif args.taming_model_path is not None and args.taming_config_path is not None:
        print(f"Using Taming VQGanVAE, loading from {args.taming_model_path}")
        vae = VQGanVAETaming(
            vqgan_model_path=args.taming_model_path,
            vqgan_config_path=args.taming_config_path,
            accelerator=accelerator,
        )
        args.num_tokens = vae.codebook_size
        args.seq_len = vae.get_encoded_fmap_size(args.image_size) ** 2
        current_step = 0  # taming checkpoints do not carry a training step
    else:
        print("Initialising empty VAE")
        vae = VQGanVAE(
            dim=args.dim,
            vq_codebook_dim=args.vq_codebook_dim,
            vq_codebook_size=args.vq_codebook_size,
            channels=args.channels,
            layers=args.layers,
            discr_layers=args.discr_layers,
            accelerator=accelerator,
        )

        current_step = 0

    # Count all learnable parameters of the model
    total_params = sum(p.numel() for p in vae.parameters())
    args.total_params = total_params

    print(f"Total number of parameters: {format(total_params, ',d')}")

    dataset = ImageDataset(
        dataset,
        args.image_size,
        image_column=args.image_column,
        center_crop=not args.no_center_crop,
        flip=not args.no_flip,
        stream=args.streaming,
        random_crop=args.random_crop,
        alpha_channel=args.channels != 3,
    )

    # Split the dataset into training and validation dataloaders
    dataloader, validation_dataloader = split_dataset_into_dataloaders(
        dataset, args.valid_frac, args.seed, args.batch_size
    )
    trainer = VQGanVAETrainer(
        vae,
        dataloader,
        validation_dataloader,
        accelerator,
        current_step=current_step + 1 if current_step != 0 else current_step,
        num_train_steps=args.num_train_steps,
        lr=args.lr,
        lr_scheduler_type=args.lr_scheduler,
        lr_warmup_steps=args.lr_warmup_steps,
        max_grad_norm=args.max_grad_norm,
        discr_max_grad_norm=args.discr_max_grad_norm,
        save_results_every=args.save_results_every,
        save_model_every=args.save_model_every,
        results_dir=args.results_dir,
        logging_dir=args.logging_dir if args.logging_dir else os.path.join(args.results_dir, "logs"),
        use_ema=args.use_ema,
        ema_beta=args.ema_beta,
        ema_update_after_step=args.ema_update_after_step,
        ema_update_every=args.ema_update_every,
        apply_grad_penalty_every=args.apply_grad_penalty_every,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        clear_previous_experiments=args.clear_previous_experiments,
        validation_image_scale=args.validation_image_scale,
        only_save_last_checkpoint=args.only_save_last_checkpoint,
        optimizer=args.optimizer,
        use_8bit_adam=args.use_8bit_adam,
        num_cycles=args.num_cycles,
        scheduler_power=args.scheduler_power,
        num_epochs=args.num_epochs,
        args=args,
    )

    trainer.train()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
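Note on the configuration handling in main() above: parse_args(namespace=Arguments()) fills an Arguments dataclass instance, so fields already present on the dataclass keep their defaults unless the corresponding flag is passed on the command line, and an optional OmegaConf YAML file can then overwrite the parsed values. The following minimal standalone sketch illustrates the same pattern; DemoArgs, config_override_sketch.py, and demo.yaml are hypothetical names, not part of this repository, and the override loop is a simplified equivalent of the one in main().

# config_override_sketch.py -- illustrative sketch only; names are hypothetical.
import argparse
from dataclasses import dataclass
from typing import Optional

from omegaconf import OmegaConf


@dataclass
class DemoArgs:
    lr: float = 1e-5
    batch_size: int = 512
    optimizer: str = "Lion"
    config_path: Optional[str] = None


parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float)
parser.add_argument("--batch_size", type=int)
parser.add_argument("--optimizer", type=str)
parser.add_argument("--config_path", type=str)

# Parsing into a dataclass instance: attributes already set on the namespace keep
# their dataclass defaults unless the flag is passed explicitly on the CLI.
args = parser.parse_args(namespace=DemoArgs())

# If a YAML config is supplied, its values overwrite the parsed/default values.
if args.config_path:
    conf = OmegaConf.load(args.config_path)
    for key, value in conf.items():
        setattr(args, key, value)

print(args.lr, args.batch_size, args.optimizer)

Running, for example, python config_override_sketch.py --optimizer AdamW keeps the dataclass defaults for lr and batch_size, while python config_override_sketch.py --config_path demo.yaml lets the YAML values take precedence, which is the same precedence order the training script applies.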