├── .github ├── FUNDING.yml └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── imagen.png ├── imagen_pytorch ├── __init__.py ├── cli.py ├── configs.py ├── data.py ├── elucidated_imagen.py ├── imagen_pytorch.py ├── imagen_video │ ├── __init__.py │ └── imagen_video.py ├── t5.py ├── trainer.py ├── utils.py └── version.py └── setup.py /.github/FUNDING.yml: -------------------------------------------------------------------------------- 1 | # These are supported funding model platforms 2 | 3 | github: [lucidrains] 4 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Imagen - Pytorch 4 | 5 | Implementation of Imagen, Google's Text-to-Image Neural Network that beats DALL-E2, in Pytorch. It is the new SOTA for text-to-image synthesis. 
6 | 7 | Architecturally, it is actually much simpler than DALL-E2. It consists of a cascading DDPM conditioned on text embeddings from a large pretrained T5 model (attention network). It also contains dynamic clipping for improved classifier free guidance, noise level conditioning, and a memory efficient unet design. 8 | 9 | It appears neither CLIP nor prior network is needed after all. And so research continues. 10 | 11 | AI Coffee Break with Letitia | Assembly AI | Yannic Kilcher 12 | 13 | Please join Join us on Discord if you are interested in helping out with the replication with the LAION community 14 | 15 | ## Shoutouts 16 | 17 | - StabilityAI for the generous sponsorship, as well as my other sponsors out there 18 | 19 | - 🤗 Huggingface for their amazing transformers library. The text encoder portion is pretty much taken care of because of them 20 | 21 | - Sylvain and Zachary for the Accelerate library, which this repository uses for distributed training 22 | 23 | - Alex for einops, indispensable tool for tensor manipulation 24 | 25 | - Jorge Gomes for helping out with the T5 loading code and advice on the correct T5 version 26 | 27 | - Katherine Crowson, for her beautiful code, which helped me understand the continuous time version of gaussian diffusion 28 | 29 | - Marunine and Netruk44, for reviewing code, sharing experimental results, and help with debugging 30 | 31 | - Marunine for providing a potential solution for a color shifting issue in the memory efficient u-nets. Thanks to Jacob for sharing experimental comparisons between the base and memory-efficient unets 32 | 33 | - Marunine for finding numerous bugs, resolving an issue with resize right, and for sharing his experimental configurations and results 34 | 35 | - MalumaDev for proposing the use of pixel shuffle upsampler to fix checkboard artifacts 36 | 37 | - Valentin for pointing out insufficient skip connections in the unet, as well as the specific method of attention conditioning in the base-unet in the appendix 38 | 39 | - BIGJUN for catching a big bug with continuous time gaussian diffusion noise level conditioning at inference time 40 | 41 | - Bingbing for identifying a bug with sampling and order of normalizing and noising with low resolution conditioning image 42 | 43 | ## Install 44 | 45 | ```bash 46 | $ pip install imagen-pytorch 47 | ``` 48 | 49 | ## Usage 50 | 51 | ```python 52 | import torch 53 | from imagen_pytorch import Unet, Imagen 54 | 55 | # unet for imagen 56 | 57 | unet1 = Unet( 58 | dim = 32, 59 | cond_dim = 512, 60 | dim_mults = (1, 2, 4, 8), 61 | num_resnet_blocks = 3, 62 | layer_attns = (False, True, True, True), 63 | layer_cross_attns = (False, True, True, True) 64 | ) 65 | 66 | unet2 = Unet( 67 | dim = 32, 68 | cond_dim = 512, 69 | dim_mults = (1, 2, 4, 8), 70 | num_resnet_blocks = (2, 4, 8, 8), 71 | layer_attns = (False, False, False, True), 72 | layer_cross_attns = (False, False, False, True) 73 | ) 74 | 75 | # imagen, which contains the unets above (base unet and super resoluting ones) 76 | 77 | imagen = Imagen( 78 | unets = (unet1, unet2), 79 | image_sizes = (64, 256), 80 | timesteps = 1000, 81 | cond_drop_prob = 0.1 82 | ).cuda() 83 | 84 | # mock images (get a lot of this) and text encodings from large T5 85 | 86 | text_embeds = torch.randn(4, 256, 768).cuda() 87 | images = torch.randn(4, 3, 256, 256).cuda() 88 | 89 | # feed images into imagen, training each unet in the cascade 90 | 91 | for i in (1, 2): 92 | loss = imagen(images, text_embeds = text_embeds, unet_number = i) 93 | 
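    # the loss returned is a scalar tensor; after the backward() call below you would
    # step an optimizer of your choice (the ImagenTrainer wrapper further down
    # takes care of optimizers and EMA for you)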
loss.backward() 94 | 95 | # do the above for many many many many steps 96 | # now you can sample an image based on the text embeddings from the cascading ddpm 97 | 98 | images = imagen.sample(texts = [ 99 | 'a whale breaching from afar', 100 | 'young girl blowing out candles on her birthday cake', 101 | 'fireworks with blue and green sparkles' 102 | ], cond_scale = 3.) 103 | 104 | images.shape # (3, 3, 256, 256) 105 | ``` 106 | 107 | For simpler training, you can directly supply text strings instead of precomputing text encodings. (Although for scaling purposes, you will definitely want to precompute the textual embeddings + mask) 108 | 109 | The number of textual captions must match the batch size of the images if you go this route. 110 | 111 | ```python 112 | # mock images and text (get a lot of this) 113 | 114 | texts = [ 115 | 'a child screaming at finding a worm within a half-eaten apple', 116 | 'lizard running across the desert on two feet', 117 | 'waking up to a psychedelic landscape', 118 | 'seashells sparkling in the shallow waters' 119 | ] 120 | 121 | images = torch.randn(4, 3, 256, 256).cuda() 122 | 123 | # feed images into imagen, training each unet in the cascade 124 | 125 | for i in (1, 2): 126 | loss = imagen(images, texts = texts, unet_number = i) 127 | loss.backward() 128 | ``` 129 | 130 | With the `ImagenTrainer` wrapper class, the exponential moving averages for all of the U-nets in the cascading DDPM will be automatically taken care of when calling `update` 131 | 132 | ```python 133 | import torch 134 | from imagen_pytorch import Unet, Imagen, ImagenTrainer 135 | 136 | # unet for imagen 137 | 138 | unet1 = Unet( 139 | dim = 32, 140 | cond_dim = 512, 141 | dim_mults = (1, 2, 4, 8), 142 | num_resnet_blocks = 3, 143 | layer_attns = (False, True, True, True), 144 | ) 145 | 146 | unet2 = Unet( 147 | dim = 32, 148 | cond_dim = 512, 149 | dim_mults = (1, 2, 4, 8), 150 | num_resnet_blocks = (2, 4, 8, 8), 151 | layer_attns = (False, False, False, True), 152 | layer_cross_attns = (False, False, False, True) 153 | ) 154 | 155 | # imagen, which contains the unets above (base unet and super resoluting ones) 156 | 157 | imagen = Imagen( 158 | unets = (unet1, unet2), 159 | text_encoder_name = 't5-large', 160 | image_sizes = (64, 256), 161 | timesteps = 1000, 162 | cond_drop_prob = 0.1 163 | ).cuda() 164 | 165 | # wrap imagen with the trainer class 166 | 167 | trainer = ImagenTrainer(imagen) 168 | 169 | # mock images (get a lot of this) and text encodings from large T5 170 | 171 | text_embeds = torch.randn(64, 256, 1024).cuda() 172 | images = torch.randn(64, 3, 256, 256).cuda() 173 | 174 | # feed images into imagen, training each unet in the cascade 175 | 176 | loss = trainer( 177 | images, 178 | text_embeds = text_embeds, 179 | unet_number = 1, # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2 180 | max_batch_size = 4 # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory 181 | ) 182 | 183 | trainer.update(unet_number = 1) 184 | 185 | # do the above for many many many many steps 186 | # now you can sample an image based on the text embeddings from the cascading ddpm 187 | 188 | images = trainer.sample(texts = [ 189 | 'a puppy looking anxiously at a giant donut on the table', 190 | 'the milky way galaxy in the style of monet' 191 | ], cond_scale = 3.) 
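# sample returns a tensor of images by default; pass return_pil_images = True
# (as in the Dataloader example below) to get a list of PIL images instead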
192 | 193 | images.shape # (2, 3, 256, 256) 194 | ``` 195 | 196 | You can also train Imagen without text (unconditional image generation) as follows 197 | 198 | ```python 199 | import torch 200 | from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer 201 | 202 | # unets for unconditional imagen 203 | 204 | unet1 = Unet( 205 | dim = 32, 206 | dim_mults = (1, 2, 4), 207 | num_resnet_blocks = 3, 208 | layer_attns = (False, True, True), 209 | layer_cross_attns = False, 210 | use_linear_attn = True 211 | ) 212 | 213 | unet2 = SRUnet256( 214 | dim = 32, 215 | dim_mults = (1, 2, 4), 216 | num_resnet_blocks = (2, 4, 8), 217 | layer_attns = (False, False, True), 218 | layer_cross_attns = False 219 | ) 220 | 221 | # imagen, which contains the unets above (base unet and super resoluting ones) 222 | 223 | imagen = Imagen( 224 | condition_on_text = False, # this must be set to False for unconditional Imagen 225 | unets = (unet1, unet2), 226 | image_sizes = (64, 128), 227 | timesteps = 1000 228 | ) 229 | 230 | trainer = ImagenTrainer(imagen).cuda() 231 | 232 | # now get a ton of images and feed it through the Imagen trainer 233 | 234 | training_images = torch.randn(4, 3, 256, 256).cuda() 235 | 236 | # train each unet separately 237 | # in this example, only training on unet number 1 238 | 239 | loss = trainer(training_images, unet_number = 1) 240 | trainer.update(unet_number = 1) 241 | 242 | # do the above for many many many many steps 243 | # now you can sample images unconditionally from the cascading unet(s) 244 | 245 | images = trainer.sample(batch_size = 16) # (16, 3, 128, 128) 246 | ``` 247 | 248 | Or train only super-resoluting unets 249 | 250 | ```python 251 | import torch 252 | from imagen_pytorch import Unet, NullUnet, Imagen 253 | 254 | # unet for imagen 255 | 256 | unet1 = NullUnet() # add a placeholder "null" unet for the base unet 257 | 258 | unet2 = Unet( 259 | dim = 32, 260 | cond_dim = 512, 261 | dim_mults = (1, 2, 4, 8), 262 | num_resnet_blocks = (2, 4, 8, 8), 263 | layer_attns = (False, False, False, True), 264 | layer_cross_attns = (False, False, False, True) 265 | ) 266 | 267 | # imagen, which contains the unets above (base unet and super resoluting ones) 268 | 269 | imagen = Imagen( 270 | unets = (unet1, unet2), 271 | image_sizes = (64, 256), 272 | timesteps = 250, 273 | cond_drop_prob = 0.1 274 | ).cuda() 275 | 276 | # mock images (get a lot of this) and text encodings from large T5 277 | 278 | text_embeds = torch.randn(4, 256, 768).cuda() 279 | images = torch.randn(4, 3, 256, 256).cuda() 280 | 281 | # feed images into imagen, training each unet in the cascade 282 | 283 | loss = imagen(images, text_embeds = text_embeds, unet_number = 2) 284 | loss.backward() 285 | 286 | # do the above for many many many many steps 287 | # now you can sample an image based on the text embeddings as well as low resolution images 288 | 289 | lowres_images = torch.randn(3, 3, 64, 64).cuda() # starting un-resoluted images 290 | 291 | images = imagen.sample( 292 | texts = [ 293 | 'a whale breaching from afar', 294 | 'young girl blowing out candles on her birthday cake', 295 | 'fireworks with blue and green sparkles' 296 | ], 297 | start_at_unet_number = 2, # start at unet number 2 298 | start_image_or_video = lowres_images, # pass in low resolution images to be resoluted 299 | cond_scale = 3.) 300 | 301 | images.shape # (3, 3, 256, 256) 302 | ``` 303 | 304 | At any time you can save and load the trainer and all associated states with the `save` and `load` methods. 
It is recommended you use these methods instead of manually saving with a `state_dict` call, as there are some device memory management being done underneath the hood within the trainer. 305 | 306 | ex. 307 | 308 | ```python 309 | trainer.save('./path/to/checkpoint.pt') 310 | 311 | trainer.load('./path/to/checkpoint.pt') 312 | 313 | trainer.steps # (2,) step number for each of the unets, in this case 2 314 | ``` 315 | 316 | ## Dataloader 317 | 318 | You can also rely on the `ImagenTrainer` to automatically train off `DataLoader` instances. You simply have to craft your `DataLoader` to return either `images` (for unconditional case), or of `('images', 'text_embeds')` for text-guided generation. 319 | 320 | ex. unconditional training 321 | 322 | ```python 323 | from imagen_pytorch import Unet, Imagen, ImagenTrainer 324 | from imagen_pytorch.data import Dataset 325 | 326 | # unets for unconditional imagen 327 | 328 | unet = Unet( 329 | dim = 32, 330 | dim_mults = (1, 2, 4, 8), 331 | num_resnet_blocks = 1, 332 | layer_attns = (False, False, False, True), 333 | layer_cross_attns = False 334 | ) 335 | 336 | # imagen, which contains the unet above 337 | 338 | imagen = Imagen( 339 | condition_on_text = False, # this must be set to False for unconditional Imagen 340 | unets = unet, 341 | image_sizes = 128, 342 | timesteps = 1000 343 | ) 344 | 345 | trainer = ImagenTrainer( 346 | imagen = imagen, 347 | split_valid_from_train = True # whether to split the validation dataset from the training 348 | ).cuda() 349 | 350 | # instantiate your dataloader, which returns the necessary inputs to the DDPM as tuple in the order of images, text embeddings, then text masks. in this case, only images is returned as it is unconditional training 351 | 352 | dataset = Dataset('/path/to/training/images', image_size = 128) 353 | 354 | trainer.add_train_dataset(dataset, batch_size = 16) 355 | 356 | # working training loop 357 | 358 | for i in range(200000): 359 | loss = trainer.train_step(unet_number = 1, max_batch_size = 4) 360 | print(f'loss: {loss}') 361 | 362 | if not (i % 50): 363 | valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4) 364 | print(f'valid loss: {valid_loss}') 365 | 366 | if not (i % 100) and trainer.is_main: # is_main makes sure this can run in distributed 367 | images = trainer.sample(batch_size = 1, return_pil_images = True) # returns List[Image] 368 | images[0].save(f'./sample-{i // 100}.png') 369 | 370 | ``` 371 | 372 | ## Multi GPU 373 | 374 | Thanks to 🤗 Accelerate, you can do multi GPU training easily with two steps. 375 | 376 | First you need to invoke `accelerate config` in the same directory as your training script (say it is named `train.py`) 377 | 378 | ```bash 379 | $ accelerate config 380 | ``` 381 | 382 | Next, instead of calling `python train.py` as you would for single GPU, you would use the accelerate CLI as so 383 | 384 | ```bash 385 | $ accelerate launch train.py 386 | ``` 387 | 388 | That's it! 389 | 390 | ## Command-line 391 | 392 | To further democratize the use of this machine imagination, I have built in the ability to generate an image with any text prompt using one command line as so 393 | 394 | ex. 
395 | 396 | ```bash 397 | $ imagen --model ./path/to/model/checkpoint.pt "a squirrel raiding the birdfeeder" 398 | # image is saved to ./a_squirrel_raiding_the_birdfeeder.png 399 | ``` 400 | 401 | In order to save checkpoints that can make use of this feature, you must instantiate your Imagen instance using the config classes, `ImagenConfig` and `ElucidatedImagenConfig` 402 | 403 | For proper training, you'll likely want to setup config-driven training anyways. 404 | 405 | ex. 406 | 407 | ```python 408 | import torch 409 | from imagen_pytorch import ImagenConfig, ElucidatedImagenConfig, ImagenTrainer 410 | 411 | # in this example, using elucidated imagen 412 | 413 | imagen = ElucidatedImagenConfig( 414 | unets = [ 415 | dict(dim = 32, dim_mults = (1, 2, 4, 8)), 416 | dict(dim = 32, dim_mults = (1, 2, 4, 8)) 417 | ], 418 | image_sizes = (64, 128), 419 | cond_drop_prob = 0.5, 420 | num_sample_steps = 32 421 | ).create() 422 | 423 | trainer = ImagenTrainer(imagen) 424 | 425 | # do your training ... 426 | 427 | # then save it 428 | 429 | trainer.save('./checkpoint.pt') 430 | 431 | # you should see a message informing you that ./checkpoint.pt is commandable from the terminal 432 | ``` 433 | 434 | It really should be as simple as that 435 | 436 | You can also pass this checkpoint file around, and anyone can continue finetune on their own data 437 | 438 | ```python 439 | from imagen_pytorch import load_imagen_from_checkpoint, ImagenTrainer 440 | 441 | imagen = load_imagen_from_checkpoint('./checkpoint.pt') 442 | 443 | trainer = ImagenTrainer(imagen) 444 | 445 | # continue training / fine-tuning 446 | ``` 447 | 448 | ## Inpainting 449 | 450 | Inpainting follows the formulation laid out by the recent Repaint paper. Simply pass in `inpaint_images` and `inpaint_masks` to the `sample` function on either `Imagen` or `ElucidatedImagen` 451 | 452 | ```python 453 | 454 | inpaint_images = torch.randn(4, 3, 512, 512).cuda() # (batch, channels, height, width) 455 | inpaint_masks = torch.ones((4, 512, 512)).bool().cuda() # (batch, height, width) 456 | 457 | inpainted_images = trainer.sample(texts = [ 458 | 'a whale breaching from afar', 459 | 'young girl blowing out candles on her birthday cake', 460 | 'fireworks with blue and green sparkles', 461 | 'dust motes swirling in the morning sunshine on the windowsill' 462 | ], inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, cond_scale = 5.) 463 | 464 | inpainted_images # (4, 3, 512, 512) 465 | ``` 466 | 467 | ## Experimental 468 | 469 | Tero Karras of StyleGAN fame has written a new paper with results that have been corroborated by a number of independent researchers as well as on my own machine. I have decided to create a version of `Imagen`, the `ElucidatedImagen`, so that one can use the new elucidated DDPM for text-guided cascading generation. 470 | 471 | Simply import `ElucidatedImagen`, and then instantiate the instance as you did before. The hyperparameters are different than the usual ones for discrete and continuous time gaussian diffusion, and can be individualized for each unet in the cascade. 472 | 473 | Ex. 474 | 475 | ```python 476 | from imagen_pytorch import ElucidatedImagen 477 | 478 | # instantiate your unets ... 
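# as one hedged example, the two text-conditioned unets from the Usage section
# above also work here (any Unet definitions matching the number of image_sizes will do)

from imagen_pytorch import Unet

unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)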
479 | 480 | imagen = ElucidatedImagen( 481 | unets = (unet1, unet2), 482 | image_sizes = (64, 128), 483 | cond_drop_prob = 0.1, 484 | num_sample_steps = (64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are) 485 | sigma_min = 0.002, # min noise level 486 | sigma_max = (80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler 487 | sigma_data = 0.5, # standard deviation of data distribution 488 | rho = 7, # controls the sampling schedule 489 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training 490 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training 491 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper 492 | S_tmin = 0.05, 493 | S_tmax = 50, 494 | S_noise = 1.003, 495 | ).cuda() 496 | 497 | # rest is the same as above 498 | 499 | ``` 500 | 501 | ## Text to Video (ongoing research) 502 | 503 | This repository will also start accumulating new research around text guided video synthesis. For starters it will adopt the 3d unet architecture described by Jonathan Ho in Video Diffusion Models 504 | 505 | Ex. 506 | 507 | ```python 508 | import torch 509 | from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer 510 | 511 | unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda() 512 | 513 | unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda() 514 | 515 | # elucidated imagen, which contains the unets above (base unet and super resoluting ones) 516 | 517 | imagen = ElucidatedImagen( 518 | unets = (unet1, unet2), 519 | image_sizes = (16, 32), 520 | random_crop_sizes = (None, 16), 521 | num_sample_steps = 10, 522 | cond_drop_prob = 0.1, 523 | sigma_min = 0.002, # min noise level 524 | sigma_max = (80, 160), # max noise level, double the max noise level for upsampler 525 | sigma_data = 0.5, # standard deviation of data distribution 526 | rho = 7, # controls the sampling schedule 527 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training 528 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training 529 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper 530 | S_tmin = 0.05, 531 | S_tmax = 50, 532 | S_noise = 1.003, 533 | ).cuda() 534 | 535 | # mock videos (get a lot of this) and text encodings from large T5 536 | 537 | texts = [ 538 | 'a whale breaching from afar', 539 | 'young girl blowing out candles on her birthday cake', 540 | 'fireworks with blue and green sparkles', 541 | 'dust motes swirling in the morning sunshine on the windowsill' 542 | ] 543 | 544 | videos = torch.randn(4, 3, 10, 32, 32).cuda() # (batch, channels, time / video frames, height, width) 545 | 546 | # feed images into imagen, training each unet in the cascade 547 | # for this example, only training unet 1 548 | 549 | trainer = ImagenTrainer(imagen) 550 | trainer(videos, texts = texts, unet_number = 1) 551 | trainer.update(unet_number = 1) 552 | 553 | videos = trainer.sample(texts = texts, video_frames = 20) # extrapolating to 20 frames from training on 10 frames 554 | 555 | videos.shape # (4, 3, 20, 32, 32) 556 | 557 | ``` 558 | 559 | ## FAQ 560 | 561 | - Why are my generated images not aligning well with the text? 562 | 563 | Imagen uses an algorithm called Classifier Free Guidance. 
When sampling, you apply a scale to the conditioning (text in this case) of greater than `1.0`. 564 | 565 | Researcher Netruk44 have reported `5-10` to be optimal, but anything greater than `10` to break. 566 | 567 | ```python 568 | trainer.sample(texts = [ 569 | 'a cloud in the shape of a roman gladiator' 570 | ], cond_scale = 5.) # <-- cond_scale is the conditioning scale, needs to be greater than 1.0 to be better than average 571 | ``` 572 | 573 | - Are there any pretrained models yet? 574 | 575 | Not at the moment but one will likely be trained and open sourced within the year, if not sooner. If you would like to participate, you can join the community of artificial neural network trainers at Laion (discord link is in the Readme above) and start collaborating. 576 | 577 | - Will this technology take my job? 578 | 579 | More the reason why you should start training your own model, starting today! The last thing we need is this technology being in the hands of an elite few. Hopefully this repository reduces the work to just finding the necessary compute, and augmenting with your own curated dataset. 580 | 581 | - What am I allowed to do with this repository? 582 | 583 | Anything! It is MIT licensed. In other words, you can freely copy / paste for your own research, remixed for whatever modality you can think of. Go train amazing models for profit, for science, or simply to satiate your own personal pleasure at witnessing something divine unravel in front of you. 584 | 585 | ## Related Works 586 | 587 | - Audio diffusion from Flavio Schneider 588 | 589 | - Mini Imagen from Ryan O. | AssemblyAI writeup 590 | 591 | ## Todo 592 | 593 | - [x] use huggingface transformers for T5-small text embeddings 594 | - [x] add dynamic thresholding 595 | - [x] add dynamic thresholding DALLE2 and video-diffusion repository as well 596 | - [x] allow for one to set T5-large (and perhaps small factory method to take in any huggingface transformer) 597 | - [x] add the lowres noise level with the pseudocode in appendix, and figure out what is this sweep they do at inference time 598 | - [x] port over some training code from DALLE2 599 | - [x] need to be able to use a different noise schedule per unet (cosine was used for base, but linear for SR) 600 | - [x] just make one master-configurable unet 601 | - [x] complete resnet block (biggan inspired? but with groupnorm) - complete self attention 602 | - [x] complete conditioning embedding block (and make it completely configurable, whether it be attention, film etc) 603 | - [x] consider using perceiver-resampler from https://github.com/lucidrains/flamingo-pytorch in place of attention pooling 604 | - [x] add attention pooling option, in addition to cross attention and film 605 | - [x] add optional cosine decay schedule with warmup, for each unet, to trainer 606 | - [x] switch to continuous timesteps instead of discretized, as it seems that is what they used for all stages - first figure out the linear noise schedule case from the variational ddpm paper https://openreview.net/forum?id=2LdBqxc1Yv 607 | - [x] figure out log(snr) for alpha cosine noise schedule. 
608 | - [x] suppress the transformers warning because only T5encoder is used 609 | - [x] allow setting for using linear attention on layers where full attention cannot be used 610 | - [x] force unets in continuous time case to use non-fouriered conditions (just pass the log(snr) through an MLP with optional layernorms), as that is what i have working locally 611 | - [x] removed learned variance 612 | - [x] add p2 loss weighting for continuous time 613 | - [x] make sure cascading ddpm can be trained without text condition, and make sure both continuous and discrete time gaussian diffusion works 614 | - [x] use primer's depthwise convs on the qkv projections in linear attention (or use token shifting before projections) - also use new dropout proposed by bayesformer, as it seems to work well with linear attention 615 | - [x] explore skip layer excitation in unet decoder 616 | - [x] accelerate integration 617 | - [x] build out CLI tool and one-line generation of image 618 | - [x] knock out any issues that arised from accelerate 619 | - [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865 620 | - [x] build a simple checkpointing system, backed by a folder 621 | - [x] add skip connection from outputs of all upsample blocks, used in unet squared paper and some previous unet works 622 | - [x] add fsspec, recommended by Romain @rom1504, for cloud / local file system agnostic persistence of checkpoints 623 | - [x] test out persistence in gcs with https://github.com/fsspec/gcsfs 624 | - [x] extend to video generation, using axial time attention as in Ho's video ddpm paper 625 | - [x] allow elucidated imagen to generalize to any shape 626 | - [x] allow for imagen to generalize to any shape 627 | - [x] add dynamic positional bias for the best type of length extrapolation across video time 628 | - [x] move video frames to sample function, as we will be attempting time extrapolation 629 | - [x] attention bias to null key / values should be a learned scalar of head dimension 630 | - [x] add self-conditioning from bit diffusion paper, already coded up at ddpm-pytorch 631 | - [ ] reread cogvideo and figure out how frame rate conditioning could be used 632 | - [ ] bring in attention expertise for self attention layers in unet3d 633 | - [ ] consider bringing in NUWA's 3d convolutional attention 634 | - [ ] consider transformer-xl memories in the temporal attention blocks 635 | - [ ] consider perceiver-ar approach to attending to past time 636 | - [ ] frame dropouts during attention for achieving both regularizing effect as well as shortened training time 637 | - [ ] investigate frank wood's claims https://github.com/lucidrains/flexible-diffusion-modeling-videos-pytorch and either add the hierarchical sampling technique, or let people know about its deficiencies 638 | - [ ] make sure inpainting works with video 639 | - [ ] offer challenging moving mnist (with distractor objects) as a one-line trainable baseline for researchers to branch off of for text to video 640 | - [ ] build out CLI tool for training, resuming training off config file 641 | - [ ] preencoding of text to memmapped embeddings 642 | - [ ] be able to create dataloader iterators based on the old epoch style, also configure shuffling etc 643 | - [ ] be able to also pass in arguments (instead of requiring forward to be all keyword args on model) 644 | - [ ] bring in reversible blocks from revnets for 3d unet, to lessen memory burden 645 | - [ ] add ability to only train super-resolution network 646 | - [ 
] read dpm-solver see if it is applicable to continuous time gaussian diffusion 647 | - [ ] allow for conditioning video frames with arbitrary absolute times (calculate RPE during temporal attention) 648 | - [ ] accommodate dream booth fine tuning 649 | - [ ] add textual inversion 650 | - [ ] cleanup self conditioning to be extracted at imagen instantiation 651 | - [ ] incorporate all learnings from make-a-video (https://makeavideo.studio/) 652 | - [ ] add v-parameterization (https://arxiv.org/abs/2202.00512) from imagen video paper, the only thing new 653 | 654 | ## Citations 655 | 656 | ```bibtex 657 | @inproceedings{Saharia2022PhotorealisticTD, 658 | title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding}, 659 | author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily L. Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and Seyedeh Sara Mahdavi and Raphael Gontijo Lopes and Tim Salimans and Jonathan Ho and David Fleet and Mohammad Norouzi}, 660 | year = {2022} 661 | } 662 | ``` 663 | 664 | ```bibtex 665 | @article{Alayrac2022Flamingo, 666 | title = {Flamingo: a Visual Language Model for Few-Shot Learning}, 667 | author = {Jean-Baptiste Alayrac et al}, 668 | year = {2022} 669 | } 670 | ``` 671 | 672 | ```bibtex 673 | @article{Choi2022PerceptionPT, 674 | title = {Perception Prioritized Training of Diffusion Models}, 675 | author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon}, 676 | journal = {ArXiv}, 677 | year = {2022}, 678 | volume = {abs/2204.00227} 679 | } 680 | ``` 681 | 682 | ```bibtex 683 | @inproceedings{Sankararaman2022BayesFormerTW, 684 | title = {BayesFormer: Transformer with Uncertainty Estimation}, 685 | author = {Karthik Abinav Sankararaman and Sinong Wang and Han Fang}, 686 | year = {2022} 687 | } 688 | ``` 689 | 690 | ```bibtex 691 | @article{So2021PrimerSF, 692 | title = {Primer: Searching for Efficient Transformers for Language Modeling}, 693 | author = {David R. So and Wojciech Ma'nke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. Le}, 694 | journal = {ArXiv}, 695 | year = {2021}, 696 | volume = {abs/2109.08668} 697 | } 698 | ``` 699 | 700 | ```bibtex 701 | @misc{cao2020global, 702 | title = {Global Context Networks}, 703 | author = {Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu}, 704 | year = {2020}, 705 | eprint = {2012.13375}, 706 | archivePrefix = {arXiv}, 707 | primaryClass = {cs.CV} 708 | } 709 | ``` 710 | 711 | ```bibtex 712 | @article{Karras2022ElucidatingTD, 713 | title = {Elucidating the Design Space of Diffusion-Based Generative Models}, 714 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine}, 715 | journal = {ArXiv}, 716 | year = {2022}, 717 | volume = {abs/2206.00364} 718 | } 719 | ``` 720 | 721 | ```bibtex 722 | @inproceedings{NEURIPS2020_4c5bcfec, 723 | author = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter}, 724 | booktitle = {Advances in Neural Information Processing Systems}, 725 | editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. 
Lin}, 726 | pages = {6840--6851}, 727 | publisher = {Curran Associates, Inc.}, 728 | title = {Denoising Diffusion Probabilistic Models}, 729 | url = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf}, 730 | volume = {33}, 731 | year = {2020} 732 | } 733 | ``` 734 | 735 | ```bibtex 736 | @article{Lugmayr2022RePaintIU, 737 | title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models}, 738 | author = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool}, 739 | journal = {ArXiv}, 740 | year = {2022}, 741 | volume = {abs/2201.09865} 742 | } 743 | ``` 744 | 745 | ```bibtex 746 | @misc{ho2022video, 747 | title = {Video Diffusion Models}, 748 | author = {Jonathan Ho and Tim Salimans and Alexey Gritsenko and William Chan and Mohammad Norouzi and David J. Fleet}, 749 | year = {2022}, 750 | eprint = {2204.03458}, 751 | archivePrefix = {arXiv}, 752 | primaryClass = {cs.CV} 753 | } 754 | ``` 755 | 756 | ```bibtex 757 | @inproceedings{rogozhnikov2022einops, 758 | title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation}, 759 | author = {Alex Rogozhnikov}, 760 | booktitle = {International Conference on Learning Representations}, 761 | year = {2022}, 762 | url = {https://openreview.net/forum?id=oapKSVM2bcj} 763 | } 764 | ``` 765 | 766 | ```bibtex 767 | @misc{chen2022analog, 768 | title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning}, 769 | author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton}, 770 | year = {2022}, 771 | eprint = {2208.04202}, 772 | archivePrefix = {arXiv}, 773 | primaryClass = {cs.CV} 774 | } 775 | ``` 776 | 777 | ```bibtex 778 | @article{Sunkara2022NoMS, 779 | title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects}, 780 | author = {Raja Sunkara and Tie Luo}, 781 | journal = {ArXiv}, 782 | year = {2022}, 783 | volume = {abs/2208.03641} 784 | } 785 | ``` 786 | -------------------------------------------------------------------------------- /imagen.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/abishakdevarasan/imagen-pytorch/fa29d249a8bdd7b97ae0da8da02261fc69292b72/imagen.png -------------------------------------------------------------------------------- /imagen_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from imagen_pytorch.imagen_pytorch import Imagen, Unet 2 | from imagen_pytorch.imagen_pytorch import NullUnet 3 | from imagen_pytorch.imagen_pytorch import BaseUnet64, SRUnet256, SRUnet1024 4 | from imagen_pytorch.trainer import ImagenTrainer 5 | from imagen_pytorch.version import __version__ 6 | 7 | # imagen using the elucidated ddpm from Tero Karras' new paper 8 | 9 | from imagen_pytorch.elucidated_imagen import ElucidatedImagen 10 | 11 | # config driven creation of imagen instances 12 | 13 | from imagen_pytorch.configs import UnetConfig, ImagenConfig, ElucidatedImagenConfig, ImagenTrainerConfig 14 | 15 | # utils 16 | 17 | from imagen_pytorch.utils import load_imagen_from_checkpoint 18 | 19 | # video 20 | 21 | from imagen_pytorch.imagen_video import Unet3D 22 | -------------------------------------------------------------------------------- /imagen_pytorch/cli.py: -------------------------------------------------------------------------------- 1 | import click 2 | import torch 3 | from pathlib import Path 4 
| 5 | from imagen_pytorch import load_imagen_from_checkpoint 6 | from imagen_pytorch.version import __version__ 7 | from imagen_pytorch.utils import safeget 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | def simple_slugify(text, max_length = 255): 13 | return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length] 14 | 15 | def main(): 16 | pass 17 | 18 | @click.command() 19 | @click.option('--model', default = './imagen.pt', help = 'path to trained Imagen model') 20 | @click.option('--cond_scale', default = 5, help = 'conditioning scale (classifier free guidance) in decoder') 21 | @click.option('--load_ema', default = True, help = 'load EMA version of unets if available') 22 | @click.argument('text') 23 | def imagen( 24 | model, 25 | cond_scale, 26 | load_ema, 27 | text 28 | ): 29 | model_path = Path(model) 30 | full_model_path = str(model_path.resolve()) 31 | assert model_path.exists(), f'model not found at {full_model_path}' 32 | loaded = torch.load(str(model_path)) 33 | 34 | # get version 35 | 36 | version = safeget(loaded, 'version') 37 | print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}') 38 | 39 | # get imagen parameters and type 40 | 41 | imagen = load_imagen_from_checkpoint(str(model_path), load_ema_if_available = load_ema) 42 | imagen.cuda() 43 | 44 | # generate image 45 | 46 | pil_image = imagen.sample(text, cond_scale = cond_scale, return_pil_images = True) 47 | 48 | image_path = f'./{simple_slugify(text)}.png' 49 | pil_image[0].save(image_path) 50 | 51 | print(f'image saved to {str(image_path)}') 52 | return 53 | -------------------------------------------------------------------------------- /imagen_pytorch/configs.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pydantic import BaseModel, validator, root_validator 3 | from typing import List, Iterable, Optional, Union, Tuple, Dict, Any 4 | from enum import Enum 5 | 6 | from imagen_pytorch.imagen_pytorch import Imagen, Unet, Unet3D, NullUnet 7 | from imagen_pytorch.trainer import ImagenTrainer 8 | from imagen_pytorch.elucidated_imagen import ElucidatedImagen 9 | from imagen_pytorch.t5 import DEFAULT_T5_NAME, get_encoded_dim 10 | 11 | # helper functions 12 | 13 | def exists(val): 14 | return val is not None 15 | 16 | def default(val, d): 17 | return val if exists(val) else d 18 | 19 | def ListOrTuple(inner_type): 20 | return Union[List[inner_type], Tuple[inner_type]] 21 | 22 | def SingleOrList(inner_type): 23 | return Union[inner_type, ListOrTuple(inner_type)] 24 | 25 | # noise schedule 26 | 27 | class NoiseSchedule(Enum): 28 | cosine = 'cosine' 29 | linear = 'linear' 30 | 31 | class AllowExtraBaseModel(BaseModel): 32 | class Config: 33 | extra = "allow" 34 | use_enum_values = True 35 | 36 | # imagen pydantic classes 37 | 38 | class NullUnetConfig(BaseModel): 39 | is_null: bool 40 | 41 | def create(self): 42 | return NullUnet() 43 | 44 | class UnetConfig(AllowExtraBaseModel): 45 | dim: int 46 | dim_mults: ListOrTuple(int) 47 | text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME) 48 | cond_dim: int = None 49 | channels: int = 3 50 | attn_dim_head: int = 32 51 | attn_heads: int = 16 52 | 53 | def create(self): 54 | return Unet(**self.dict()) 55 | 56 | class Unet3DConfig(AllowExtraBaseModel): 57 | dim: int 58 | dim_mults: ListOrTuple(int) 59 | text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME) 60 | cond_dim: int = None 61 | channels: int 
= 3 62 | attn_dim_head: int = 32 63 | attn_heads: int = 16 64 | 65 | def create(self): 66 | return Unet3D(**self.dict()) 67 | 68 | class ImagenConfig(AllowExtraBaseModel): 69 | unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig]) 70 | image_sizes: ListOrTuple(int) 71 | video: bool = False 72 | timesteps: SingleOrList(int) = 1000 73 | noise_schedules: SingleOrList(NoiseSchedule) = 'cosine' 74 | text_encoder_name: str = DEFAULT_T5_NAME 75 | channels: int = 3 76 | loss_type: str = 'l2' 77 | cond_drop_prob: float = 0.5 78 | 79 | @validator('image_sizes') 80 | def check_image_sizes(cls, image_sizes, values): 81 | unets = values.get('unets') 82 | if len(image_sizes) != len(unets): 83 | raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}') 84 | return image_sizes 85 | 86 | def create(self): 87 | decoder_kwargs = self.dict() 88 | unets_kwargs = decoder_kwargs.pop('unets') 89 | is_video = decoder_kwargs.pop('video', False) 90 | 91 | unets = [] 92 | 93 | for unet, unet_kwargs in zip(self.unets, unets_kwargs): 94 | if isinstance(unet, NullUnetConfig): 95 | unet_klass = NullUnet 96 | elif is_video: 97 | unet_klass = Unet3D 98 | else: 99 | unet_klass = Unet 100 | 101 | unets.append(unet_klass(**unet_kwargs)) 102 | 103 | imagen = Imagen(unets, **decoder_kwargs) 104 | 105 | imagen._config = self.dict().copy() 106 | return imagen 107 | 108 | class ElucidatedImagenConfig(AllowExtraBaseModel): 109 | unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig]) 110 | image_sizes: ListOrTuple(int) 111 | video: bool = False 112 | text_encoder_name: str = DEFAULT_T5_NAME 113 | channels: int = 3 114 | cond_drop_prob: float = 0.5 115 | num_sample_steps: SingleOrList(int) = 32 116 | sigma_min: SingleOrList(float) = 0.002 117 | sigma_max: SingleOrList(int) = 80 118 | sigma_data: SingleOrList(float) = 0.5 119 | rho: SingleOrList(int) = 7 120 | P_mean: SingleOrList(float) = -1.2 121 | P_std: SingleOrList(float) = 1.2 122 | S_churn: SingleOrList(int) = 80 123 | S_tmin: SingleOrList(float) = 0.05 124 | S_tmax: SingleOrList(int) = 50 125 | S_noise: SingleOrList(float) = 1.003 126 | 127 | @validator('image_sizes') 128 | def check_image_sizes(cls, image_sizes, values): 129 | unets = values.get('unets') 130 | if len(image_sizes) != len(unets): 131 | raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}') 132 | return image_sizes 133 | 134 | def create(self): 135 | decoder_kwargs = self.dict() 136 | unets_kwargs = decoder_kwargs.pop('unets') 137 | is_video = decoder_kwargs.pop('video', False) 138 | 139 | unet_klass = Unet3D if is_video else Unet 140 | 141 | unets = [] 142 | 143 | for unet, unet_kwargs in zip(self.unets, unets_kwargs): 144 | if isinstance(unet, NullUnetConfig): 145 | unet_klass = NullUnet 146 | elif is_video: 147 | unet_klass = Unet3D 148 | else: 149 | unet_klass = Unet 150 | 151 | unets.append(unet_klass(**unet_kwargs)) 152 | 153 | imagen = ElucidatedImagen(unets, **decoder_kwargs) 154 | 155 | imagen._config = self.dict().copy() 156 | return imagen 157 | 158 | class ImagenTrainerConfig(AllowExtraBaseModel): 159 | imagen: dict 160 | elucidated: bool = False 161 | video: bool = False 162 | use_ema: bool = True 163 | lr: SingleOrList(float) = 1e-4 164 | eps: SingleOrList(float) = 1e-8 165 | beta1: float = 0.9 166 | beta2: float = 0.99 167 | max_grad_norm: Optional[float] = None 168 | group_wd_params: bool = True 169 | warmup_steps: SingleOrList(Optional[int]) = None 170 | 
cosine_decay_max_steps: SingleOrList(Optional[int]) = None 171 | 172 | def create(self): 173 | trainer_kwargs = self.dict() 174 | 175 | imagen_config = trainer_kwargs.pop('imagen') 176 | elucidated = trainer_kwargs.pop('elucidated') 177 | 178 | imagen_config_klass = ElucidatedImagenConfig if elucidated else ImagenConfig 179 | imagen = imagen_config_klass(**{**imagen_config, 'video': video}).create() 180 | 181 | return ImagenTrainer(imagen, **trainer_kwargs) 182 | -------------------------------------------------------------------------------- /imagen_pytorch/data.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from functools import partial 3 | 4 | import torch 5 | from torch import nn 6 | from torch.utils.data import Dataset, DataLoader 7 | from torchvision import transforms as T, utils 8 | 9 | from PIL import Image 10 | 11 | # helpers functions 12 | 13 | def exists(val): 14 | return val is not None 15 | 16 | def cycle(dl): 17 | while True: 18 | for data in dl: 19 | yield data 20 | 21 | def convert_image_to(img_type, image): 22 | if image.mode != img_type: 23 | return image.convert(img_type) 24 | return image 25 | 26 | # dataset and dataloader 27 | 28 | class Dataset(Dataset): 29 | def __init__( 30 | self, 31 | folder, 32 | image_size, 33 | exts = ['jpg', 'jpeg', 'png', 'tiff'], 34 | convert_image_to_type = None 35 | ): 36 | super().__init__() 37 | self.folder = folder 38 | self.image_size = image_size 39 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')] 40 | 41 | convert_fn = partial(convert_image_to, convert_image_to_type) if exists(convert_image_to_type) else nn.Identity() 42 | 43 | self.transform = T.Compose([ 44 | T.Lambda(convert_fn), 45 | T.Resize(image_size), 46 | T.RandomHorizontalFlip(), 47 | T.CenterCrop(image_size), 48 | T.ToTensor() 49 | ]) 50 | 51 | def __len__(self): 52 | return len(self.paths) 53 | 54 | def __getitem__(self, index): 55 | path = self.paths[index] 56 | img = Image.open(path) 57 | return self.transform(img) 58 | 59 | def get_images_dataloader( 60 | folder, 61 | *, 62 | batch_size, 63 | image_size, 64 | shuffle = True, 65 | cycle_dl = False, 66 | pin_memory = True 67 | ): 68 | ds = Dataset(folder, image_size) 69 | dl = DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory) 70 | 71 | if cycle_dl: 72 | dl = cycle(dl) 73 | return dl 74 | -------------------------------------------------------------------------------- /imagen_pytorch/elucidated_imagen.py: -------------------------------------------------------------------------------- 1 | from math import sqrt 2 | from random import random 3 | from functools import partial 4 | from contextlib import contextmanager, nullcontext 5 | from typing import List, Union 6 | from collections import namedtuple 7 | from tqdm.auto import tqdm 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn, einsum 12 | from torch.cuda.amp import autocast 13 | from torch.nn.parallel import DistributedDataParallel 14 | import torchvision.transforms as T 15 | 16 | import kornia.augmentation as K 17 | 18 | from einops import rearrange, repeat, reduce 19 | from einops_exts import rearrange_many 20 | 21 | from imagen_pytorch.imagen_pytorch import ( 22 | GaussianDiffusionContinuousTimes, 23 | Unet, 24 | NullUnet, 25 | first, 26 | exists, 27 | identity, 28 | maybe, 29 | default, 30 | cast_tuple, 31 | cast_uint8_images_to_float, 32 | is_float_dtype, 33 | eval_decorator, 34 | check_shape, 
35 | pad_tuple_to_length, 36 | resize_image_to, 37 | right_pad_dims_to, 38 | module_device, 39 | normalize_neg_one_to_one, 40 | unnormalize_zero_to_one, 41 | ) 42 | 43 | from imagen_pytorch.imagen_video.imagen_video import ( 44 | Unet3D, 45 | resize_video_to 46 | ) 47 | 48 | from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME 49 | 50 | # constants 51 | 52 | Hparams_fields = [ 53 | 'num_sample_steps', 54 | 'sigma_min', 55 | 'sigma_max', 56 | 'sigma_data', 57 | 'rho', 58 | 'P_mean', 59 | 'P_std', 60 | 'S_churn', 61 | 'S_tmin', 62 | 'S_tmax', 63 | 'S_noise' 64 | ] 65 | 66 | Hparams = namedtuple('Hparams', Hparams_fields) 67 | 68 | # helper functions 69 | 70 | def log(t, eps = 1e-20): 71 | return torch.log(t.clamp(min = eps)) 72 | 73 | # main class 74 | 75 | class ElucidatedImagen(nn.Module): 76 | def __init__( 77 | self, 78 | unets, 79 | *, 80 | image_sizes, # for cascading ddpm, image size at each stage 81 | text_encoder_name = DEFAULT_T5_NAME, 82 | text_embed_dim = None, 83 | channels = 3, 84 | cond_drop_prob = 0.1, 85 | random_crop_sizes = None, 86 | lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level 87 | per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find 88 | condition_on_text = True, 89 | auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader 90 | dynamic_thresholding = True, 91 | dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper 92 | only_train_unet_number = None, 93 | lowres_noise_schedule = 'linear', 94 | num_sample_steps = 32, # number of sampling steps 95 | sigma_min = 0.002, # min noise level 96 | sigma_max = 80, # max noise level 97 | sigma_data = 0.5, # standard deviation of data distribution 98 | rho = 7, # controls the sampling schedule 99 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training 100 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training 101 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in apper 102 | S_tmin = 0.05, 103 | S_tmax = 50, 104 | S_noise = 1.003, 105 | ): 106 | super().__init__() 107 | 108 | self.only_train_unet_number = only_train_unet_number 109 | 110 | # conditioning hparams 111 | 112 | self.condition_on_text = condition_on_text 113 | self.unconditional = not condition_on_text 114 | 115 | # channels 116 | 117 | self.channels = channels 118 | 119 | # automatically take care of ensuring that first unet is unconditional 120 | # while the rest of the unets are conditioned on the low resolution image produced by previous unet 121 | 122 | unets = cast_tuple(unets) 123 | num_unets = len(unets) 124 | 125 | # randomly cropping for upsampler training 126 | 127 | self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets) 128 | assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example' 129 | 130 | # lowres augmentation noise 
schedule 131 | 132 | self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule) 133 | 134 | # get text encoder 135 | 136 | self.text_encoder_name = text_encoder_name 137 | self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name)) 138 | 139 | self.encode_text = partial(t5_encode_text, name = text_encoder_name) 140 | 141 | # construct unets 142 | 143 | self.unets = nn.ModuleList([]) 144 | self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment 145 | 146 | for ind, one_unet in enumerate(unets): 147 | assert isinstance(one_unet, (Unet, Unet3D, NullUnet)) 148 | is_first = ind == 0 149 | 150 | one_unet = one_unet.cast_model_parameters( 151 | lowres_cond = not is_first, 152 | cond_on_text = self.condition_on_text, 153 | text_embed_dim = self.text_embed_dim if self.condition_on_text else None, 154 | channels = self.channels, 155 | channels_out = self.channels 156 | ) 157 | 158 | self.unets.append(one_unet) 159 | 160 | # determine whether we are training on images or video 161 | 162 | is_video = any([isinstance(unet, Unet3D) for unet in self.unets]) 163 | self.is_video = is_video 164 | 165 | self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1')) 166 | self.resize_to = resize_video_to if is_video else resize_image_to 167 | 168 | # unet image sizes 169 | 170 | self.image_sizes = cast_tuple(image_sizes) 171 | assert num_unets == len(self.image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}' 172 | 173 | self.sample_channels = cast_tuple(self.channels, num_unets) 174 | 175 | # cascading ddpm related stuff 176 | 177 | lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets)) 178 | assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True' 179 | 180 | self.lowres_sample_noise_level = lowres_sample_noise_level 181 | self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level 182 | 183 | # classifier free guidance 184 | 185 | self.cond_drop_prob = cond_drop_prob 186 | self.can_classifier_guidance = cond_drop_prob > 0. 187 | 188 | # normalize and unnormalize image functions 189 | 190 | self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity 191 | self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity 192 | self.input_image_range = (0. if auto_normalize_img else -1., 1.) 
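        # i.e. inputs are expected in [0, 1] when auto-normalizing (the default),
        # otherwise they should already be in [-1, 1]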
193 | 194 | # dynamic thresholding 195 | 196 | self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets) 197 | self.dynamic_thresholding_percentile = dynamic_thresholding_percentile 198 | 199 | # elucidating parameters 200 | 201 | hparams = [ 202 | num_sample_steps, 203 | sigma_min, 204 | sigma_max, 205 | sigma_data, 206 | rho, 207 | P_mean, 208 | P_std, 209 | S_churn, 210 | S_tmin, 211 | S_tmax, 212 | S_noise, 213 | ] 214 | 215 | hparams = [cast_tuple(hp, num_unets) for hp in hparams] 216 | self.hparams = [Hparams(*unet_hp) for unet_hp in zip(*hparams)] 217 | 218 | # one temp parameter for keeping track of device 219 | 220 | self.register_buffer('_temp', torch.tensor([0.]), persistent = False) 221 | 222 | # default to device of unets passed in 223 | 224 | self.to(next(self.unets.parameters()).device) 225 | 226 | def force_unconditional_(self): 227 | self.condition_on_text = False 228 | self.unconditional = True 229 | 230 | for unet in self.unets: 231 | unet.cond_on_text = False 232 | 233 | @property 234 | def device(self): 235 | return self._temp.device 236 | 237 | def get_unet(self, unet_number): 238 | assert 0 < unet_number <= len(self.unets) 239 | index = unet_number - 1 240 | 241 | if isinstance(self.unets, nn.ModuleList): 242 | unets_list = [unet for unet in self.unets] 243 | delattr(self, 'unets') 244 | self.unets = unets_list 245 | 246 | if index != self.unet_being_trained_index: 247 | for unet_index, unet in enumerate(self.unets): 248 | unet.to(self.device if unet_index == index else 'cpu') 249 | 250 | self.unet_being_trained_index = index 251 | return self.unets[index] 252 | 253 | def reset_unets_all_one_device(self, device = None): 254 | device = default(device, self.device) 255 | self.unets = nn.ModuleList([*self.unets]) 256 | self.unets.to(device) 257 | 258 | self.unet_being_trained_index = -1 259 | 260 | @contextmanager 261 | def one_unet_in_gpu(self, unet_number = None, unet = None): 262 | assert exists(unet_number) ^ exists(unet) 263 | 264 | if exists(unet_number): 265 | unet = self.unets[unet_number - 1] 266 | 267 | devices = [module_device(unet) for unet in self.unets] 268 | self.unets.cpu() 269 | unet.to(self.device) 270 | 271 | yield 272 | 273 | for unet, device in zip(self.unets, devices): 274 | unet.to(device) 275 | 276 | # overriding state dict functions 277 | 278 | def state_dict(self, *args, **kwargs): 279 | self.reset_unets_all_one_device() 280 | return super().state_dict(*args, **kwargs) 281 | 282 | def load_state_dict(self, *args, **kwargs): 283 | self.reset_unets_all_one_device() 284 | return super().load_state_dict(*args, **kwargs) 285 | 286 | # dynamic thresholding 287 | 288 | def threshold_x_start(self, x_start, dynamic_threshold = True): 289 | if not dynamic_threshold: 290 | return x_start.clamp(-1., 1.) 291 | 292 | s = torch.quantile( 293 | rearrange(x_start, 'b ... -> b (...)').abs(), 294 | self.dynamic_thresholding_percentile, 295 | dim = -1 296 | ) 297 | 298 | s.clamp_(min = 1.) 
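        # this is the Imagen paper's dynamic thresholding trick: s is the per-sample
        # `dynamic_thresholding_percentile` quantile of |x_start| over all pixels, floored at 1 so
        # predictions already inside [-1, 1] are left untouched; the two lines below broadcast s back
        # to image shape, clamp x_start to [-s, s] and rescale by s, keeping the denoised estimate in
        # [-1, 1] without the hard saturation that plain clamping causes under strong classifier free
        # guidance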
299 | s = right_pad_dims_to(x_start, s) 300 | return x_start.clamp(-s, s) / s 301 | 302 | # derived preconditioning params - Table 1 303 | 304 | def c_skip(self, sigma_data, sigma): 305 | return (sigma_data ** 2) / (sigma ** 2 + sigma_data ** 2) 306 | 307 | def c_out(self, sigma_data, sigma): 308 | return sigma * sigma_data * (sigma_data ** 2 + sigma ** 2) ** -0.5 309 | 310 | def c_in(self, sigma_data, sigma): 311 | return 1 * (sigma ** 2 + sigma_data ** 2) ** -0.5 312 | 313 | def c_noise(self, sigma): 314 | return log(sigma) * 0.25 315 | 316 | # preconditioned network output 317 | # equation (7) in the paper 318 | 319 | def preconditioned_network_forward( 320 | self, 321 | unet_forward, 322 | noised_images, 323 | sigma, 324 | *, 325 | sigma_data, 326 | clamp = False, 327 | dynamic_threshold = True, 328 | **kwargs 329 | ): 330 | batch, device = noised_images.shape[0], noised_images.device 331 | 332 | if isinstance(sigma, float): 333 | sigma = torch.full((batch,), sigma, device = device) 334 | 335 | padded_sigma = self.right_pad_dims_to_datatype(sigma) 336 | 337 | net_out = unet_forward( 338 | self.c_in(sigma_data, padded_sigma) * noised_images, 339 | self.c_noise(sigma), 340 | **kwargs 341 | ) 342 | 343 | out = self.c_skip(sigma_data, padded_sigma) * noised_images + self.c_out(sigma_data, padded_sigma) * net_out 344 | 345 | if not clamp: 346 | return out 347 | 348 | return self.threshold_x_start(out, dynamic_threshold) 349 | 350 | # sampling 351 | 352 | # sample schedule 353 | # equation (5) in the paper 354 | 355 | def sample_schedule( 356 | self, 357 | num_sample_steps, 358 | rho, 359 | sigma_min, 360 | sigma_max 361 | ): 362 | N = num_sample_steps 363 | inv_rho = 1 / rho 364 | 365 | steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32) 366 | sigmas = (sigma_max ** inv_rho + steps / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho 367 | 368 | sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0. 369 | return sigmas 370 | 371 | @torch.no_grad() 372 | def one_unet_sample( 373 | self, 374 | unet, 375 | shape, 376 | *, 377 | unet_number, 378 | clamp = True, 379 | dynamic_threshold = True, 380 | cond_scale = 1., 381 | use_tqdm = True, 382 | inpaint_images = None, 383 | inpaint_masks = None, 384 | inpaint_resample_times = 5, 385 | init_images = None, 386 | skip_steps = None, 387 | sigma_min = None, 388 | sigma_max = None, 389 | **kwargs 390 | ): 391 | # get specific sampling hyperparameters for unet 392 | 393 | hp = self.hparams[unet_number - 1] 394 | 395 | sigma_min = default(sigma_min, hp.sigma_min) 396 | sigma_max = default(sigma_max, hp.sigma_max) 397 | 398 | # get the schedule, which is returned as (sigma, gamma) tuple, and pair up with the next sigma and gamma 399 | 400 | sigmas = self.sample_schedule(hp.num_sample_steps, hp.rho, sigma_min, sigma_max) 401 | 402 | gammas = torch.where( 403 | (sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax), 404 | min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1), 405 | 0. 
406 | ) 407 | 408 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1])) 409 | 410 | # images is noise at the beginning 411 | 412 | init_sigma = sigmas[0] 413 | 414 | images = init_sigma * torch.randn(shape, device = self.device) 415 | 416 | # initializing with an image 417 | 418 | if exists(init_images): 419 | images += init_images 420 | 421 | # keeping track of x0, for self conditioning if needed 422 | 423 | x_start = None 424 | 425 | # prepare inpainting images and mask 426 | 427 | has_inpainting = exists(inpaint_images) and exists(inpaint_masks) 428 | resample_times = inpaint_resample_times if has_inpainting else 1 429 | 430 | if has_inpainting: 431 | inpaint_images = self.normalize_img(inpaint_images) 432 | inpaint_images = self.resize_to(inpaint_images, shape[-1]) 433 | inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1]).bool() 434 | 435 | # unet kwargs 436 | 437 | unet_kwargs = dict( 438 | sigma_data = hp.sigma_data, 439 | clamp = clamp, 440 | dynamic_threshold = dynamic_threshold, 441 | cond_scale = cond_scale, 442 | **kwargs 443 | ) 444 | 445 | # gradually denoise 446 | 447 | initial_step = default(skip_steps, 0) 448 | sigmas_and_gammas = sigmas_and_gammas[initial_step:] 449 | 450 | total_steps = len(sigmas_and_gammas) 451 | 452 | for ind, (sigma, sigma_next, gamma) in tqdm(enumerate(sigmas_and_gammas), total = total_steps, desc = 'sampling time step', disable = not use_tqdm): 453 | is_last_timestep = ind == (total_steps - 1) 454 | 455 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma)) 456 | 457 | for r in reversed(range(resample_times)): 458 | is_last_resample_step = r == 0 459 | 460 | eps = hp.S_noise * torch.randn(shape, device = self.device) # stochastic sampling 461 | 462 | sigma_hat = sigma + gamma * sigma 463 | added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps 464 | 465 | images_hat = images + added_noise 466 | 467 | self_cond = x_start if unet.self_cond else None 468 | 469 | if has_inpainting: 470 | images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks 471 | 472 | model_output = self.preconditioned_network_forward( 473 | unet.forward_with_cond_scale, 474 | images_hat, 475 | sigma_hat, 476 | self_cond = self_cond, 477 | **unet_kwargs 478 | ) 479 | 480 | denoised_over_sigma = (images_hat - model_output) / sigma_hat 481 | 482 | images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma 483 | 484 | # second order correction, if not the last timestep 485 | 486 | if sigma_next != 0: 487 | self_cond = model_output if unet.self_cond else None 488 | 489 | model_output_next = self.preconditioned_network_forward( 490 | unet.forward_with_cond_scale, 491 | images_next, 492 | sigma_next, 493 | self_cond = self_cond, 494 | **unet_kwargs 495 | ) 496 | 497 | denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next 498 | images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma) 499 | 500 | images = images_next 501 | 502 | if has_inpainting and not (is_last_resample_step or is_last_timestep): 503 | # renoise in repaint and then resample 504 | repaint_noise = torch.randn(shape, device = self.device) 505 | images = images + (sigma - sigma_next) * repaint_noise 506 | 507 | x_start = model_output # save model output for self conditioning 508 | 509 | images = images.clamp(-1., 1.) 
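        # the loop above follows the stochastic sampler of the elucidating paper (Karras et al.,
        # Algorithm 2): each step first perturbs the current images up to sigma_hat = sigma + gamma * sigma
        # with fresh noise scaled by S_noise, takes an Euler step toward sigma_next using the
        # preconditioned denoiser output, and, whenever sigma_next is non-zero, applies a Heun style
        # second order correction by averaging the slopes at sigma_hat and sigma_next; when
        # inpainting, the known region is re-noised and resampled in the spirit of RePaint before
        # moving on to the next sigma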
510 | 511 | if has_inpainting: 512 | images = images * ~inpaint_masks + inpaint_images * inpaint_masks 513 | 514 | return self.unnormalize_img(images) 515 | 516 | @torch.no_grad() 517 | @eval_decorator 518 | def sample( 519 | self, 520 | texts: List[str] = None, 521 | text_masks = None, 522 | text_embeds = None, 523 | cond_images = None, 524 | inpaint_images = None, 525 | inpaint_masks = None, 526 | inpaint_resample_times = 5, 527 | init_images = None, 528 | skip_steps = None, 529 | sigma_min = None, 530 | sigma_max = None, 531 | video_frames = None, 532 | batch_size = 1, 533 | cond_scale = 1., 534 | lowres_sample_noise_level = None, 535 | start_at_unet_number = 1, 536 | start_image_or_video = None, 537 | stop_at_unet_number = None, 538 | return_all_unet_outputs = False, 539 | return_pil_images = False, 540 | use_tqdm = True, 541 | device = None, 542 | ): 543 | device = default(device, self.device) 544 | self.reset_unets_all_one_device(device = device) 545 | 546 | cond_images = maybe(cast_uint8_images_to_float)(cond_images) 547 | 548 | if exists(texts) and not exists(text_embeds) and not self.unconditional: 549 | assert all([*map(len, texts)]), 'text cannot be empty' 550 | 551 | with autocast(enabled = False): 552 | text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) 553 | 554 | text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks)) 555 | 556 | if not self.unconditional: 557 | assert exists(text_embeds), 'text or text embeddings must be passed in, since this imagen was trained with text conditioning - otherwise instantiate it with `condition_on_text = False`' 558 | 559 | text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1)) 560 | batch_size = text_embeds.shape[0] 561 | 562 | if exists(inpaint_images): 563 | if self.unconditional: 564 | if batch_size == 1: # assume researcher wants to broadcast along inpainted images 565 | batch_size = inpaint_images.shape[0] 566 | 567 | assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the batch size specified on sample (`sample(batch_size=...)`)' 568 | assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of texts to be conditioned on' 569 | 570 | assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified' 571 | assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented' 572 | assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' 573 | 574 | assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must both be passed in to do inpainting' 575 | 576 | outputs = [] 577 | 578 | is_cuda = next(self.parameters()).is_cuda 579 | device = next(self.parameters()).device 580 | 581 | lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level) 582 | 583 | num_unets = len(self.unets) 584 | cond_scale = cast_tuple(cond_scale, num_unets) 585 | 586 | # handle video and frame dimension 587 | 588 | assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in at sample time if the model was trained on video' 589 | 590 | frame_dims = (video_frames,) if self.is_video else tuple() 591 | 592 | # initializing with an image or video 593 | 594 | init_images =
cast_tuple(init_images, num_unets) 595 | init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images] 596 | 597 | skip_steps = cast_tuple(skip_steps, num_unets) 598 | 599 | sigma_min = cast_tuple(sigma_min, num_unets) 600 | sigma_max = cast_tuple(sigma_max, num_unets) 601 | 602 | # handle starting at a unet greater than 1, for when only the upsampler(s) are being trained 603 | 604 | if start_at_unet_number > 1: 605 | assert start_at_unet_number <= num_unets, 'must start at a unet number that does not exceed the total number of unets' 606 | assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number 607 | assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling' 608 | 609 | prev_image_size = self.image_sizes[start_at_unet_number - 2] 610 | img = self.resize_to(start_image_or_video, prev_image_size) 611 | 612 | # go through each unet in cascade 613 | 614 | for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm): 615 | if unet_number < start_at_unet_number: 616 | continue 617 | 618 | assert not isinstance(unet, NullUnet), 'cannot sample from null unet' 619 | 620 | context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext() 621 | 622 | with context: 623 | lowres_cond_img = lowres_noise_times = None 624 | 625 | shape = (batch_size, channel, *frame_dims, image_size, image_size) 626 | 627 | if unet.lowres_cond: 628 | lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device) 629 | 630 | lowres_cond_img = self.resize_to(img, image_size) 631 | lowres_cond_img = self.normalize_img(lowres_cond_img) 632 | 633 | lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img)) 634 | 635 | if exists(unet_init_images): 636 | unet_init_images = self.resize_to(unet_init_images, image_size) 637 | 638 | shape = (batch_size, self.channels, *frame_dims, image_size, image_size) 639 | 640 | img = self.one_unet_sample( 641 | unet, 642 | shape, 643 | unet_number = unet_number, 644 | text_embeds = text_embeds, 645 | text_mask = text_masks, 646 | cond_images = cond_images, 647 | inpaint_images = inpaint_images, 648 | inpaint_masks = inpaint_masks, 649 | inpaint_resample_times = inpaint_resample_times, 650 | init_images = unet_init_images, 651 | skip_steps = unet_skip_steps, 652 | sigma_min = unet_sigma_min, 653 | sigma_max = unet_sigma_max, 654 | cond_scale = unet_cond_scale, 655 | lowres_cond_img = lowres_cond_img, 656 | lowres_noise_times = lowres_noise_times, 657 | dynamic_threshold = dynamic_threshold, 658 | use_tqdm = use_tqdm 659 | ) 660 | 661 | outputs.append(img) 662 | 663 | if exists(stop_at_unet_number) and stop_at_unet_number == unet_number: 664 | break 665 | 666 | output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs 667 | 668 | if not return_pil_images: 669 | return outputs[output_index] 670 | 671 | if not return_all_unet_outputs: 672 | outputs = outputs[-1:] 673 | 674 | assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet' 675 | 676 | pil_images =
list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs)) 677 | 678 | return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png) 679 | 680 | # training 681 | 682 | def loss_weight(self, sigma_data, sigma): 683 | return (sigma ** 2 + sigma_data ** 2) * (sigma * sigma_data) ** -2 684 | 685 | def noise_distribution(self, P_mean, P_std, batch_size): 686 | return (P_mean + P_std * torch.randn((batch_size,), device = self.device)).exp() 687 | 688 | def forward( 689 | self, 690 | images, 691 | unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None, 692 | texts: List[str] = None, 693 | text_embeds = None, 694 | text_masks = None, 695 | unet_number = None, 696 | cond_images = None 697 | ): 698 | assert images.shape[-1] == images.shape[-2], f'the images you pass in must be square, but received dimensions of {images.shape[-2]}, {images.shape[-1]}' 699 | assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)' 700 | unet_number = default(unet_number, 1) 701 | assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train on unet #{self.only_train_unet_number}' 702 | 703 | images = cast_uint8_images_to_float(images) 704 | cond_images = maybe(cast_uint8_images_to_float)(cond_images) 705 | 706 | assert is_float_dtype(images.dtype), f'images tensor needs to be of a float dtype, but {images.dtype} dtype was found instead' 707 | 708 | unet_index = unet_number - 1 709 | 710 | unet = default(unet, lambda: self.get_unet(unet_number)) 711 | 712 | assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained' 713 | 714 | target_image_size = self.image_sizes[unet_index] 715 | random_crop_size = self.random_crop_sizes[unet_index] 716 | prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None 717 | hp = self.hparams[unet_index] 718 | 719 | batch_size, c, *_, h, w, device, is_video = *images.shape, images.device, (images.ndim == 5) 720 | 721 | frames = images.shape[2] if is_video else None 722 | 723 | check_shape(images, 'b c ...', c = self.channels) 724 | 725 | assert h >= target_image_size and w >= target_image_size 726 | 727 | if exists(texts) and not exists(text_embeds) and not self.unconditional: 728 | assert all([*map(len, texts)]), 'text cannot be empty' 729 | assert len(texts) == len(images), 'number of text captions does not match up with the number of images given' 730 | 731 | with autocast(enabled = False): 732 | text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True) 733 | 734 | text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks)) 735 | 736 | if not self.unconditional: 737 | text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1)) 738 | 739 | assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified' 740 | assert not (not self.condition_on_text and exists(text_embeds)), 'imagen specified not to be conditioned on text, yet it is presented' 741 | 742 | assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})' 743 | 744 | lowres_cond_img = lowres_aug_times = None 745 | if exists(prev_image_size): 746 | lowres_cond_img =
self.resize_to(images, prev_image_size, clamp_range = self.input_image_range) 747 | lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range) 748 | 749 | if self.per_sample_random_aug_noise_level: 750 | lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device) 751 | else: 752 | lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device) 753 | lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size) 754 | 755 | images = self.resize_to(images, target_image_size) 756 | 757 | # normalize to [-1, 1] 758 | 759 | images = self.normalize_img(images) 760 | lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img) 761 | 762 | # random cropping during training 763 | # for upsamplers 764 | 765 | if exists(random_crop_size): 766 | aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.) 767 | 768 | if is_video: 769 | images, lowres_cond_img = rearrange_many((images, lowres_cond_img), 'b c f h w -> (b f) c h w') 770 | 771 | # make sure low res conditioner and image both get augmented the same way 772 | # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop 773 | images = aug(images) 774 | lowres_cond_img = aug(lowres_cond_img, params = aug._params) 775 | 776 | if is_video: 777 | images, lowres_cond_img = rearrange_many((images, lowres_cond_img), '(b f) c h w -> b c f h w', f = frames) 778 | 779 | # noise the lowres conditioning image 780 | # at sample time, they then fix the noise level of 0.1 - 0.3 781 | 782 | lowres_cond_img_noisy = None 783 | if exists(lowres_cond_img): 784 | lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img)) 785 | 786 | # get the sigmas 787 | 788 | sigmas = self.noise_distribution(hp.P_mean, hp.P_std, batch_size) 789 | padded_sigmas = self.right_pad_dims_to_datatype(sigmas) 790 | 791 | # noise 792 | 793 | noise = torch.randn_like(images) 794 | noised_images = images + padded_sigmas * noise # alphas are 1. in the paper 795 | 796 | # unet kwargs 797 | 798 | unet_kwargs = dict( 799 | sigma_data = hp.sigma_data, 800 | text_embeds = text_embeds, 801 | text_mask = text_masks, 802 | cond_images = cond_images, 803 | lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times), 804 | lowres_cond_img = lowres_cond_img_noisy, 805 | cond_drop_prob = self.cond_drop_prob, 806 | ) 807 | 808 | # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower 809 | 810 | # Because 'unet' can be an instance of DistributedDataParallel coming from the 811 | # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to 812 | # access the member 'module' of the wrapped unet instance. 
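        # as for the self conditioning itself: on roughly half of the training steps (the
        # `random() < 0.5` check below) the unet first predicts x_start with gradients disabled, and
        # that prediction is fed back in through the `self_cond` kwarg for the real forward / backward
        # pass; at sampling time, `one_unet_sample` passes the previous step's model output through
        # the same slot, which is what the quoted ~25% training slowdown is paying for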
813 | self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet 814 | 815 | if self_cond and random() < 0.5: 816 | with torch.no_grad(): 817 | pred_x0 = self.preconditioned_network_forward( 818 | unet.forward, 819 | noised_images, 820 | sigmas, 821 | **unet_kwargs 822 | ).detach() 823 | 824 | unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0} 825 | 826 | # get prediction 827 | 828 | denoised_images = self.preconditioned_network_forward( 829 | unet.forward, 830 | noised_images, 831 | sigmas, 832 | **unet_kwargs 833 | ) 834 | 835 | # losses 836 | 837 | losses = F.mse_loss(denoised_images, images, reduction = 'none') 838 | losses = reduce(losses, 'b ... -> b', 'mean') 839 | 840 | # loss weighting 841 | 842 | losses = losses * self.loss_weight(hp.sigma_data, sigmas) 843 | 844 | # return average loss 845 | 846 | return losses.mean() 847 | -------------------------------------------------------------------------------- /imagen_pytorch/imagen_video/__init__.py: -------------------------------------------------------------------------------- 1 | from imagen_pytorch.imagen_video.imagen_video import Unet3D 2 | -------------------------------------------------------------------------------- /imagen_pytorch/imagen_video/imagen_video.py: -------------------------------------------------------------------------------- 1 | import math 2 | import copy 3 | from typing import List 4 | from tqdm.auto import tqdm 5 | from functools import partial, wraps 6 | from contextlib import contextmanager, nullcontext 7 | from collections import namedtuple 8 | from pathlib import Path 9 | 10 | import torch 11 | import torch.nn.functional as F 12 | from torch import nn, einsum 13 | 14 | from einops import rearrange, repeat, reduce 15 | from einops.layers.torch import Rearrange, Reduce 16 | from einops_exts import rearrange_many, repeat_many, check_shape 17 | from einops_exts.torch import EinopsToAndFrom 18 | 19 | from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME 20 | 21 | # helper functions 22 | 23 | def exists(val): 24 | return val is not None 25 | 26 | def identity(t, *args, **kwargs): 27 | return t 28 | 29 | def first(arr, d = None): 30 | if len(arr) == 0: 31 | return d 32 | return arr[0] 33 | 34 | def maybe(fn): 35 | @wraps(fn) 36 | def inner(x): 37 | if not exists(x): 38 | return x 39 | return fn(x) 40 | return inner 41 | 42 | def once(fn): 43 | called = False 44 | @wraps(fn) 45 | def inner(x): 46 | nonlocal called 47 | if called: 48 | return 49 | called = True 50 | return fn(x) 51 | return inner 52 | 53 | print_once = once(print) 54 | 55 | def default(val, d): 56 | if exists(val): 57 | return val 58 | return d() if callable(d) else d 59 | 60 | def cast_tuple(val, length = None): 61 | if isinstance(val, list): 62 | val = tuple(val) 63 | 64 | output = val if isinstance(val, tuple) else ((val,) * default(length, 1)) 65 | 66 | if exists(length): 67 | assert len(output) == length 68 | 69 | return output 70 | 71 | def cast_uint8_images_to_float(images): 72 | if not images.dtype == torch.uint8: 73 | return images 74 | return images / 255 75 | 76 | def module_device(module): 77 | return next(module.parameters()).device 78 | 79 | def zero_init_(m): 80 | nn.init.zeros_(m.weight) 81 | if exists(m.bias): 82 | nn.init.zeros_(m.bias) 83 | 84 | def eval_decorator(fn): 85 | def inner(model, *args, **kwargs): 86 | was_training = model.training 87 | model.eval() 88 | out = fn(model, *args, **kwargs) 89 | model.train(was_training) 90 | return out 91 | return inner 92 | 93 | 
def pad_tuple_to_length(t, length, fillvalue = None): 94 | remain_length = length - len(t) 95 | if remain_length <= 0: 96 | return t 97 | return (*t, *((fillvalue,) * remain_length)) 98 | 99 | # helper classes 100 | 101 | class Identity(nn.Module): 102 | def __init__(self, *args, **kwargs): 103 | super().__init__() 104 | 105 | def forward(self, x, *args, **kwargs): 106 | return x 107 | 108 | # tensor helpers 109 | 110 | def log(t, eps: float = 1e-12): 111 | return torch.log(t.clamp(min = eps)) 112 | 113 | def l2norm(t): 114 | return F.normalize(t, dim = -1) 115 | 116 | def right_pad_dims_to(x, t): 117 | padding_dims = x.ndim - t.ndim 118 | if padding_dims <= 0: 119 | return t 120 | return t.view(*t.shape, *((1,) * padding_dims)) 121 | 122 | def masked_mean(t, *, dim, mask = None): 123 | if not exists(mask): 124 | return t.mean(dim = dim) 125 | 126 | denom = mask.sum(dim = dim, keepdim = True) 127 | mask = rearrange(mask, 'b n -> b n 1') 128 | masked_t = t.masked_fill(~mask, 0.) 129 | 130 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5) 131 | 132 | def resize_video_to( 133 | video, 134 | target_image_size, 135 | clamp_range = None 136 | ): 137 | orig_video_size = video.shape[-1] 138 | 139 | if orig_video_size == target_image_size: 140 | return video 141 | 142 | 143 | frames = video.shape[2] 144 | video = rearrange(video, 'b c f h w -> (b f) c h w') 145 | 146 | out = F.interpolate(video, target_image_size, mode = 'nearest') 147 | 148 | if exists(clamp_range): 149 | out = out.clamp(*clamp_range) 150 | 151 | out = rearrange(out, '(b f) c h w -> b c f h w', f = frames) 152 | 153 | return out 154 | 155 | # classifier free guidance functions 156 | 157 | def prob_mask_like(shape, prob, device): 158 | if prob == 1: 159 | return torch.ones(shape, device = device, dtype = torch.bool) 160 | elif prob == 0: 161 | return torch.zeros(shape, device = device, dtype = torch.bool) 162 | else: 163 | return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob 164 | 165 | # norms and residuals 166 | 167 | class LayerNorm(nn.Module): 168 | def __init__(self, dim, stable = False): 169 | super().__init__() 170 | self.stable = stable 171 | self.g = nn.Parameter(torch.ones(dim)) 172 | 173 | def forward(self, x): 174 | if self.stable: 175 | x = x / x.amax(dim = -1, keepdim = True).detach() 176 | 177 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 178 | var = torch.var(x, dim = -1, unbiased = False, keepdim = True) 179 | mean = torch.mean(x, dim = -1, keepdim = True) 180 | return (x - mean) * (var + eps).rsqrt() * self.g 181 | 182 | class ChanLayerNorm(nn.Module): 183 | def __init__(self, dim, stable = False): 184 | super().__init__() 185 | self.stable = stable 186 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) 187 | 188 | def forward(self, x): 189 | if self.stable: 190 | x = x / x.amax(dim = 1, keepdim = True).detach() 191 | 192 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3 193 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True) 194 | mean = torch.mean(x, dim = 1, keepdim = True) 195 | return (x - mean) * (var + eps).rsqrt() * self.g 196 | 197 | class Always(): 198 | def __init__(self, val): 199 | self.val = val 200 | 201 | def __call__(self, *args, **kwargs): 202 | return self.val 203 | 204 | class Residual(nn.Module): 205 | def __init__(self, fn): 206 | super().__init__() 207 | self.fn = fn 208 | 209 | def forward(self, x, **kwargs): 210 | return self.fn(x, **kwargs) + x 211 | 212 | class Parallel(nn.Module): 213 | def __init__(self, *fns): 214 | super().__init__() 
215 | self.fns = nn.ModuleList(fns) 216 | 217 | def forward(self, x): 218 | outputs = [fn(x) for fn in self.fns] 219 | return sum(outputs) 220 | 221 | # attention pooling 222 | 223 | class PerceiverAttention(nn.Module): 224 | def __init__( 225 | self, 226 | *, 227 | dim, 228 | dim_head = 64, 229 | heads = 8, 230 | cosine_sim_attn = False 231 | ): 232 | super().__init__() 233 | self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1 234 | self.cosine_sim_attn = cosine_sim_attn 235 | self.cosine_sim_scale = 16 if cosine_sim_attn else 1 236 | 237 | self.heads = heads 238 | inner_dim = dim_head * heads 239 | 240 | self.norm = nn.LayerNorm(dim) 241 | self.norm_latents = nn.LayerNorm(dim) 242 | 243 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 244 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False) 245 | 246 | self.to_out = nn.Sequential( 247 | nn.Linear(inner_dim, dim, bias = False), 248 | nn.LayerNorm(dim) 249 | ) 250 | 251 | def forward(self, x, latents, mask = None): 252 | x = self.norm(x) 253 | latents = self.norm_latents(latents) 254 | 255 | b, h = x.shape[0], self.heads 256 | 257 | q = self.to_q(latents) 258 | 259 | # the paper differs from Perceiver in which they also concat the key / values derived from the latents to be attended to 260 | kv_input = torch.cat((x, latents), dim = -2) 261 | k, v = self.to_kv(kv_input).chunk(2, dim = -1) 262 | 263 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h) 264 | 265 | q = q * self.scale 266 | 267 | # cosine sim attention 268 | 269 | if self.cosine_sim_attn: 270 | q, k = map(l2norm, (q, k)) 271 | 272 | # similarities and masking 273 | 274 | sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale 275 | 276 | if exists(mask): 277 | max_neg_value = -torch.finfo(sim.dtype).max 278 | mask = F.pad(mask, (0, latents.shape[-2]), value = True) 279 | mask = rearrange(mask, 'b j -> b 1 1 j') 280 | sim = sim.masked_fill(~mask, max_neg_value) 281 | 282 | # attention 283 | 284 | attn = sim.softmax(dim = -1) 285 | 286 | out = einsum('... i j, ... j d -> ... 
i d', attn, v) 287 | out = rearrange(out, 'b h n d -> b n (h d)', h = h) 288 | return self.to_out(out) 289 | 290 | class PerceiverResampler(nn.Module): 291 | def __init__( 292 | self, 293 | *, 294 | dim, 295 | depth, 296 | dim_head = 64, 297 | heads = 8, 298 | num_latents = 64, 299 | num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence 300 | max_seq_len = 512, 301 | ff_mult = 4, 302 | cosine_sim_attn = False 303 | ): 304 | super().__init__() 305 | self.pos_emb = nn.Embedding(max_seq_len, dim) 306 | 307 | self.latents = nn.Parameter(torch.randn(num_latents, dim)) 308 | 309 | self.to_latents_from_mean_pooled_seq = None 310 | 311 | if num_latents_mean_pooled > 0: 312 | self.to_latents_from_mean_pooled_seq = nn.Sequential( 313 | LayerNorm(dim), 314 | nn.Linear(dim, dim * num_latents_mean_pooled), 315 | Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled) 316 | ) 317 | 318 | self.layers = nn.ModuleList([]) 319 | for _ in range(depth): 320 | self.layers.append(nn.ModuleList([ 321 | PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn), 322 | FeedForward(dim = dim, mult = ff_mult) 323 | ])) 324 | 325 | def forward(self, x, mask = None): 326 | n, device = x.shape[1], x.device 327 | pos_emb = self.pos_emb(torch.arange(n, device = device)) 328 | 329 | x_with_pos = x + pos_emb 330 | 331 | latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0]) 332 | 333 | if exists(self.to_latents_from_mean_pooled_seq): 334 | meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool)) 335 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq) 336 | latents = torch.cat((meanpooled_latents, latents), dim = -2) 337 | 338 | for attn, ff in self.layers: 339 | latents = attn(x_with_pos, latents, mask = mask) + latents 340 | latents = ff(latents) + latents 341 | 342 | return latents 343 | 344 | # attention 345 | 346 | class Attention(nn.Module): 347 | def __init__( 348 | self, 349 | dim, 350 | *, 351 | dim_head = 64, 352 | heads = 8, 353 | causal = False, 354 | context_dim = None, 355 | cosine_sim_attn = False 356 | ): 357 | super().__init__() 358 | self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1. 
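        # when cosine_sim_attn is enabled, queries and keys are l2-normalized in forward() and the
        # usual dim_head ** -0.5 scaling is dropped in favor of the fixed cosine_sim_scale of 16 set
        # just below, so the attention logits become 16 * cos(q, k) - a trick intended to keep the
        # logits bounded and the attention maps stable during training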
359 | self.causal = causal 360 | 361 | self.cosine_sim_attn = cosine_sim_attn 362 | self.cosine_sim_scale = 16 if cosine_sim_attn else 1 363 | 364 | self.heads = heads 365 | inner_dim = dim_head * heads 366 | 367 | self.norm = LayerNorm(dim) 368 | 369 | self.null_attn_bias = nn.Parameter(torch.randn(heads)) 370 | 371 | self.null_kv = nn.Parameter(torch.randn(2, dim_head)) 372 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 373 | self.to_kv = nn.Linear(dim, dim_head * 2, bias = False) 374 | 375 | self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None 376 | 377 | self.to_out = nn.Sequential( 378 | nn.Linear(inner_dim, dim, bias = False), 379 | LayerNorm(dim) 380 | ) 381 | 382 | def forward(self, x, context = None, mask = None, attn_bias = None): 383 | b, n, device = *x.shape[:2], x.device 384 | 385 | x = self.norm(x) 386 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1)) 387 | 388 | q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads) 389 | q = q * self.scale 390 | 391 | # add null key / value for classifier free guidance in prior net 392 | 393 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b) 394 | k = torch.cat((nk, k), dim = -2) 395 | v = torch.cat((nv, v), dim = -2) 396 | 397 | # add text conditioning, if present 398 | 399 | if exists(context): 400 | assert exists(self.to_context) 401 | ck, cv = self.to_context(context).chunk(2, dim = -1) 402 | k = torch.cat((ck, k), dim = -2) 403 | v = torch.cat((cv, v), dim = -2) 404 | 405 | # cosine sim attention 406 | 407 | if self.cosine_sim_attn: 408 | q, k = map(l2norm, (q, k)) 409 | 410 | # calculate query / key similarities 411 | 412 | sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale 413 | 414 | # relative positional encoding (T5 style) 415 | 416 | if exists(attn_bias): 417 | null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n) 418 | attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1) 419 | sim = sim + attn_bias 420 | 421 | # masking 422 | 423 | max_neg_value = -torch.finfo(sim.dtype).max 424 | 425 | if self.causal: 426 | i, j = sim.shape[-2:] 427 | causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1) 428 | sim = sim.masked_fill(causal_mask, max_neg_value) 429 | 430 | if exists(mask): 431 | mask = F.pad(mask, (1, 0), value = True) 432 | mask = rearrange(mask, 'b j -> b 1 1 j') 433 | sim = sim.masked_fill(~mask, max_neg_value) 434 | 435 | # attention 436 | 437 | attn = sim.softmax(dim = -1) 438 | 439 | # aggregate values 440 | 441 | out = einsum('b h i j, b j d -> b h i d', attn, v) 442 | 443 | out = rearrange(out, 'b h n d -> b n (h d)') 444 | return self.to_out(out) 445 | 446 | # pseudo conv2d that uses conv3d but with kernel size of 1 across frames dimension 447 | 448 | def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs): 449 | kernel = cast_tuple(kernel, 2) 450 | stride = cast_tuple(stride, 2) 451 | padding = cast_tuple(padding, 2) 452 | 453 | if len(kernel) == 2: 454 | kernel = (1, *kernel) 455 | 456 | if len(stride) == 2: 457 | stride = (1, *stride) 458 | 459 | if len(padding) == 2: 460 | padding = (0, *padding) 461 | 462 | return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs) 463 | 464 | class Pad(nn.Module): 465 | def __init__(self, padding, value = 0.): 466 | super().__init__() 467 | self.padding = padding 468 | self.value = value 469 | 470 | def forward(self, x): 471 | return F.pad(x, 
self.padding, value = self.value) 472 | 473 | # decoder 474 | 475 | def Upsample(dim, dim_out = None): 476 | dim_out = default(dim_out, dim) 477 | 478 | return nn.Sequential( 479 | nn.Upsample(scale_factor = 2, mode = 'nearest'), 480 | Conv2d(dim, dim_out, 3, padding = 1) 481 | ) 482 | 483 | class PixelShuffleUpsample(nn.Module): 484 | def __init__(self, dim, dim_out = None): 485 | super().__init__() 486 | dim_out = default(dim_out, dim) 487 | conv = Conv2d(dim, dim_out * 4, 1) 488 | 489 | self.net = nn.Sequential( 490 | conv, 491 | nn.SiLU() 492 | ) 493 | 494 | self.pixel_shuffle = nn.PixelShuffle(2) 495 | 496 | self.init_conv_(conv) 497 | 498 | def init_conv_(self, conv): 499 | o, i, f, h, w = conv.weight.shape 500 | conv_weight = torch.empty(o // 4, i, f, h, w) 501 | nn.init.kaiming_uniform_(conv_weight) 502 | conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...') 503 | 504 | conv.weight.data.copy_(conv_weight) 505 | nn.init.zeros_(conv.bias.data) 506 | 507 | def forward(self, x): 508 | out = self.net(x) 509 | frames = x.shape[2] 510 | out = rearrange(out, 'b c f h w -> (b f) c h w') 511 | out = self.pixel_shuffle(out) 512 | return rearrange(out, '(b f) c h w -> b c f h w', f = frames) 513 | 514 | def Downsample(dim, dim_out = None): 515 | dim_out = default(dim_out, dim) 516 | return Conv2d(dim, dim_out, 4, 2, 1) 517 | 518 | class SinusoidalPosEmb(nn.Module): 519 | def __init__(self, dim): 520 | super().__init__() 521 | self.dim = dim 522 | 523 | def forward(self, x): 524 | half_dim = self.dim // 2 525 | emb = math.log(10000) / (half_dim - 1) 526 | emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb) 527 | emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j') 528 | return torch.cat((emb.sin(), emb.cos()), dim = -1) 529 | 530 | class LearnedSinusoidalPosEmb(nn.Module): 531 | def __init__(self, dim): 532 | super().__init__() 533 | assert (dim % 2) == 0 534 | half_dim = dim // 2 535 | self.weights = nn.Parameter(torch.randn(half_dim)) 536 | 537 | def forward(self, x): 538 | x = rearrange(x, 'b -> b 1') 539 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi 540 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) 541 | fouriered = torch.cat((x, fouriered), dim = -1) 542 | return fouriered 543 | 544 | class Block(nn.Module): 545 | def __init__( 546 | self, 547 | dim, 548 | dim_out, 549 | groups = 8, 550 | norm = True 551 | ): 552 | super().__init__() 553 | self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity() 554 | self.activation = nn.SiLU() 555 | self.project = Conv2d(dim, dim_out, 3, padding = 1) 556 | 557 | def forward(self, x, scale_shift = None): 558 | x = self.groupnorm(x) 559 | 560 | if exists(scale_shift): 561 | scale, shift = scale_shift 562 | x = x * (scale + 1) + shift 563 | 564 | x = self.activation(x) 565 | return self.project(x) 566 | 567 | class ResnetBlock(nn.Module): 568 | def __init__( 569 | self, 570 | dim, 571 | dim_out, 572 | *, 573 | cond_dim = None, 574 | time_cond_dim = None, 575 | groups = 8, 576 | linear_attn = False, 577 | use_gca = False, 578 | squeeze_excite = False, 579 | **attn_kwargs 580 | ): 581 | super().__init__() 582 | 583 | self.time_mlp = None 584 | 585 | if exists(time_cond_dim): 586 | self.time_mlp = nn.Sequential( 587 | nn.SiLU(), 588 | nn.Linear(time_cond_dim, dim_out * 2) 589 | ) 590 | 591 | self.cross_attn = None 592 | 593 | if exists(cond_dim): 594 | attn_klass = CrossAttention if not linear_attn else LinearCrossAttention 595 | 596 | self.cross_attn = EinopsToAndFrom( 597 | 'b c f h w', 598 | 
'b (f h w) c', 599 | attn_klass( 600 | dim = dim_out, 601 | context_dim = cond_dim, 602 | **attn_kwargs 603 | ) 604 | ) 605 | 606 | self.block1 = Block(dim, dim_out, groups = groups) 607 | self.block2 = Block(dim_out, dim_out, groups = groups) 608 | 609 | self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1) 610 | 611 | self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity() 612 | 613 | 614 | def forward(self, x, time_emb = None, cond = None): 615 | 616 | scale_shift = None 617 | if exists(self.time_mlp) and exists(time_emb): 618 | time_emb = self.time_mlp(time_emb) 619 | time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') 620 | scale_shift = time_emb.chunk(2, dim = 1) 621 | 622 | h = self.block1(x) 623 | 624 | if exists(self.cross_attn): 625 | assert exists(cond) 626 | h = self.cross_attn(h, context = cond) + h 627 | 628 | h = self.block2(h, scale_shift = scale_shift) 629 | 630 | h = h * self.gca(h) 631 | 632 | return h + self.res_conv(x) 633 | 634 | class CrossAttention(nn.Module): 635 | def __init__( 636 | self, 637 | dim, 638 | *, 639 | context_dim = None, 640 | dim_head = 64, 641 | heads = 8, 642 | norm_context = False, 643 | cosine_sim_attn = False 644 | ): 645 | super().__init__() 646 | self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1. 647 | self.cosine_sim_attn = cosine_sim_attn 648 | self.cosine_sim_scale = 16 if cosine_sim_attn else 1 649 | 650 | self.heads = heads 651 | inner_dim = dim_head * heads 652 | 653 | context_dim = default(context_dim, dim) 654 | 655 | self.norm = LayerNorm(dim) 656 | self.norm_context = LayerNorm(context_dim) if norm_context else Identity() 657 | 658 | self.null_kv = nn.Parameter(torch.randn(2, dim_head)) 659 | self.to_q = nn.Linear(dim, inner_dim, bias = False) 660 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False) 661 | 662 | self.to_out = nn.Sequential( 663 | nn.Linear(inner_dim, dim, bias = False), 664 | LayerNorm(dim) 665 | ) 666 | 667 | def forward(self, x, context, mask = None): 668 | b, n, device = *x.shape[:2], x.device 669 | 670 | x = self.norm(x) 671 | context = self.norm_context(context) 672 | 673 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1)) 674 | 675 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads) 676 | 677 | # add null key / value for classifier free guidance in prior net 678 | 679 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b) 680 | 681 | k = torch.cat((nk, k), dim = -2) 682 | v = torch.cat((nv, v), dim = -2) 683 | 684 | q = q * self.scale 685 | 686 | # cosine sim attention 687 | 688 | if self.cosine_sim_attn: 689 | q, k = map(l2norm, (q, k)) 690 | 691 | # similarities 692 | 693 | sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale 694 | 695 | # masking 696 | 697 | max_neg_value = -torch.finfo(sim.dtype).max 698 | 699 | if exists(mask): 700 | mask = F.pad(mask, (1, 0), value = True) 701 | mask = rearrange(mask, 'b j -> b 1 1 j') 702 | sim = sim.masked_fill(~mask, max_neg_value) 703 | 704 | attn = sim.softmax(dim = -1, dtype = torch.float32) 705 | 706 | out = einsum('b h i j, b h j d -> b h i d', attn, v) 707 | out = rearrange(out, 'b h n d -> b n (h d)') 708 | return self.to_out(out) 709 | 710 | class LinearCrossAttention(CrossAttention): 711 | def forward(self, x, context, mask = None): 712 | b, n, device = *x.shape[:2], x.device 713 | 714 | x = self.norm(x) 715 | context = self.norm_context(context) 716 | 717 | q, k, v = (self.to_q(x), 
*self.to_kv(context).chunk(2, dim = -1)) 718 | 719 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads) 720 | 721 | # add null key / value for classifier free guidance in prior net 722 | 723 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads, b = b) 724 | 725 | k = torch.cat((nk, k), dim = -2) 726 | v = torch.cat((nv, v), dim = -2) 727 | 728 | # masking 729 | 730 | max_neg_value = -torch.finfo(x.dtype).max 731 | 732 | if exists(mask): 733 | mask = F.pad(mask, (1, 0), value = True) 734 | mask = rearrange(mask, 'b n -> b n 1') 735 | k = k.masked_fill(~mask, max_neg_value) 736 | v = v.masked_fill(~mask, 0.) 737 | 738 | # linear attention 739 | 740 | q = q.softmax(dim = -1) 741 | k = k.softmax(dim = -2) 742 | 743 | q = q * self.scale 744 | 745 | context = einsum('b n d, b n e -> b d e', k, v) 746 | out = einsum('b n d, b d e -> b n e', q, context) 747 | out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads) 748 | return self.to_out(out) 749 | 750 | class LinearAttention(nn.Module): 751 | def __init__( 752 | self, 753 | dim, 754 | dim_head = 32, 755 | heads = 8, 756 | dropout = 0.05, 757 | context_dim = None, 758 | **kwargs 759 | ): 760 | super().__init__() 761 | self.scale = dim_head ** -0.5 762 | self.heads = heads 763 | inner_dim = dim_head * heads 764 | self.norm = ChanLayerNorm(dim) 765 | 766 | self.nonlin = nn.SiLU() 767 | 768 | self.to_q = nn.Sequential( 769 | nn.Dropout(dropout), 770 | Conv2d(dim, inner_dim, 1, bias = False), 771 | Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) 772 | ) 773 | 774 | self.to_k = nn.Sequential( 775 | nn.Dropout(dropout), 776 | Conv2d(dim, inner_dim, 1, bias = False), 777 | Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) 778 | ) 779 | 780 | self.to_v = nn.Sequential( 781 | nn.Dropout(dropout), 782 | Conv2d(dim, inner_dim, 1, bias = False), 783 | Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim) 784 | ) 785 | 786 | self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None 787 | 788 | self.to_out = nn.Sequential( 789 | Conv2d(inner_dim, dim, 1, bias = False), 790 | ChanLayerNorm(dim) 791 | ) 792 | 793 | def forward(self, fmap, context = None): 794 | h, x, y = self.heads, *fmap.shape[-2:] 795 | 796 | fmap = self.norm(fmap) 797 | q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v)) 798 | q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h) 799 | 800 | if exists(context): 801 | assert exists(self.to_context) 802 | ck, cv = self.to_context(context).chunk(2, dim = -1) 803 | ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h) 804 | k = torch.cat((k, ck), dim = -2) 805 | v = torch.cat((v, cv), dim = -2) 806 | 807 | q = q.softmax(dim = -1) 808 | k = k.softmax(dim = -2) 809 | 810 | q = q * self.scale 811 | 812 | context = einsum('b n d, b n e -> b d e', k, v) 813 | out = einsum('b n d, b d e -> b n e', q, context) 814 | out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y) 815 | 816 | out = self.nonlin(out) 817 | return self.to_out(out) 818 | 819 | class GlobalContext(nn.Module): 820 | """ basically a superior form of squeeze-excitation that is attention-esque """ 821 | 822 | def __init__( 823 | self, 824 | *, 825 | dim_in, 826 | dim_out 827 | ): 828 | super().__init__() 829 | self.to_k = Conv2d(dim_in, 1, 1) 830 | hidden_dim = max(3, dim_out // 2) 
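        # the forward pass below turns the single-channel key into softmax weights over all spatial
        # (and frame) positions, uses them to pool the input into one descriptor per channel, and
        # squeezes that descriptor through this small hidden_dim bottleneck with a sigmoid gate;
        # ResnetBlock then multiplies its features by the result, so this behaves like squeeze
        # excitation with attention style pooling rather than global average pooling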
831 | 832 | self.net = nn.Sequential( 833 | Conv2d(dim_in, hidden_dim, 1), 834 | nn.SiLU(), 835 | Conv2d(hidden_dim, dim_out, 1), 836 | nn.Sigmoid() 837 | ) 838 | 839 | def forward(self, x): 840 | context = self.to_k(x) 841 | x, context = rearrange_many((x, context), 'b n ... -> b n (...)') 842 | out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x) 843 | out = rearrange(out, '... -> ... 1 1') 844 | return self.net(out) 845 | 846 | def FeedForward(dim, mult = 2): 847 | hidden_dim = int(dim * mult) 848 | return nn.Sequential( 849 | LayerNorm(dim), 850 | nn.Linear(dim, hidden_dim, bias = False), 851 | nn.GELU(), 852 | LayerNorm(hidden_dim), 853 | nn.Linear(hidden_dim, dim, bias = False) 854 | ) 855 | 856 | def ChanFeedForward(dim, mult = 2): # in paper, it seems for self attention layers they did feedforwards with twice channel width 857 | hidden_dim = int(dim * mult) 858 | return nn.Sequential( 859 | ChanLayerNorm(dim), 860 | Conv2d(dim, hidden_dim, 1, bias = False), 861 | nn.GELU(), 862 | ChanLayerNorm(hidden_dim), 863 | Conv2d(hidden_dim, dim, 1, bias = False) 864 | ) 865 | 866 | class TransformerBlock(nn.Module): 867 | def __init__( 868 | self, 869 | dim, 870 | *, 871 | depth = 1, 872 | heads = 8, 873 | dim_head = 32, 874 | ff_mult = 2, 875 | context_dim = None, 876 | cosine_sim_attn = False 877 | ): 878 | super().__init__() 879 | self.layers = nn.ModuleList([]) 880 | 881 | for _ in range(depth): 882 | self.layers.append(nn.ModuleList([ 883 | EinopsToAndFrom('b c f h w', 'b (f h w) c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn)), 884 | ChanFeedForward(dim = dim, mult = ff_mult) 885 | ])) 886 | 887 | def forward(self, x, context = None): 888 | for attn, ff in self.layers: 889 | x = attn(x, context = context) + x 890 | x = ff(x) + x 891 | return x 892 | 893 | class LinearAttentionTransformerBlock(nn.Module): 894 | def __init__( 895 | self, 896 | dim, 897 | *, 898 | depth = 1, 899 | heads = 8, 900 | dim_head = 32, 901 | ff_mult = 2, 902 | context_dim = None, 903 | **kwargs 904 | ): 905 | super().__init__() 906 | self.layers = nn.ModuleList([]) 907 | 908 | for _ in range(depth): 909 | self.layers.append(nn.ModuleList([ 910 | LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim), 911 | ChanFeedForward(dim = dim, mult = ff_mult) 912 | ])) 913 | 914 | def forward(self, x, context = None): 915 | for attn, ff in self.layers: 916 | x = attn(x, context = context) + x 917 | x = ff(x) + x 918 | return x 919 | 920 | class CrossEmbedLayer(nn.Module): 921 | def __init__( 922 | self, 923 | dim_in, 924 | kernel_sizes, 925 | dim_out = None, 926 | stride = 2 927 | ): 928 | super().__init__() 929 | assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)]) 930 | dim_out = default(dim_out, dim_in) 931 | 932 | kernel_sizes = sorted(kernel_sizes) 933 | num_scales = len(kernel_sizes) 934 | 935 | # calculate the dimension at each scale 936 | dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)] 937 | dim_scales = [*dim_scales, dim_out - sum(dim_scales)] 938 | 939 | self.convs = nn.ModuleList([]) 940 | for kernel, dim_scale in zip(kernel_sizes, dim_scales): 941 | self.convs.append(Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2)) 942 | 943 | def forward(self, x): 944 | fmaps = tuple(map(lambda conv: conv(x), self.convs)) 945 | return torch.cat(fmaps, dim = 1) 946 | 947 | class UpsampleCombiner(nn.Module): 948 | def __init__( 949 | 
self, 950 | dim, 951 | *, 952 | enabled = False, 953 | dim_ins = tuple(), 954 | dim_outs = tuple() 955 | ): 956 | super().__init__() 957 | dim_outs = cast_tuple(dim_outs, len(dim_ins)) 958 | assert len(dim_ins) == len(dim_outs) 959 | 960 | self.enabled = enabled 961 | 962 | if not self.enabled: 963 | self.dim_out = dim 964 | return 965 | 966 | self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)]) 967 | self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0) 968 | 969 | def forward(self, x, fmaps = None): 970 | target_size = x.shape[-1] 971 | 972 | fmaps = default(fmaps, tuple()) 973 | 974 | if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0: 975 | return x 976 | 977 | fmaps = [resize_video_to(fmap, target_size) for fmap in fmaps] 978 | outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)] 979 | return torch.cat((x, *outs), dim = 1) 980 | 981 | class DynamicPositionBias(nn.Module): 982 | def __init__( 983 | self, 984 | dim, 985 | *, 986 | heads, 987 | depth 988 | ): 989 | super().__init__() 990 | self.mlp = nn.ModuleList([]) 991 | 992 | self.mlp.append(nn.Sequential( 993 | nn.Linear(1, dim), 994 | LayerNorm(dim), 995 | nn.SiLU() 996 | )) 997 | 998 | for _ in range(max(depth - 1, 0)): 999 | self.mlp.append(nn.Sequential( 1000 | nn.Linear(dim, dim), 1001 | LayerNorm(dim), 1002 | nn.SiLU() 1003 | )) 1004 | 1005 | self.mlp.append(nn.Linear(dim, heads)) 1006 | 1007 | def forward(self, n, device, dtype): 1008 | i = torch.arange(n, device = device) 1009 | j = torch.arange(n, device = device) 1010 | 1011 | indices = rearrange(i, 'i -> i 1') - rearrange(j, 'j -> 1 j') 1012 | indices += (n - 1) 1013 | 1014 | pos = torch.arange(-n + 1, n, device = device, dtype = dtype) 1015 | pos = rearrange(pos, '... -> ... 
1') 1016 | 1017 | for layer in self.mlp: 1018 | pos = layer(pos) 1019 | 1020 | bias = pos[indices] 1021 | bias = rearrange(bias, 'i j h -> h i j') 1022 | return bias 1023 | 1024 | class Unet3D(nn.Module): 1025 | def __init__( 1026 | self, 1027 | *, 1028 | dim, 1029 | image_embed_dim = 1024, 1030 | text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME), 1031 | num_resnet_blocks = 1, 1032 | cond_dim = None, 1033 | num_image_tokens = 4, 1034 | num_time_tokens = 2, 1035 | learned_sinu_pos_emb_dim = 16, 1036 | out_dim = None, 1037 | dim_mults=(1, 2, 4, 8), 1038 | cond_images_channels = 0, 1039 | channels = 3, 1040 | channels_out = None, 1041 | attn_dim_head = 64, 1042 | attn_heads = 8, 1043 | ff_mult = 2., 1044 | lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/ 1045 | layer_attns = False, 1046 | layer_attns_depth = 1, 1047 | layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1 1048 | attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention) 1049 | time_rel_pos_bias_depth = 2, 1050 | time_causal_attn = True, 1051 | layer_cross_attns = True, 1052 | use_linear_attn = False, 1053 | use_linear_cross_attn = False, 1054 | cond_on_text = True, 1055 | max_text_len = 256, 1056 | init_dim = None, 1057 | resnet_groups = 8, 1058 | init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed 1059 | init_cross_embed = True, 1060 | init_cross_embed_kernel_sizes = (3, 7, 15), 1061 | cross_embed_downsample = False, 1062 | cross_embed_downsample_kernel_sizes = (2, 4), 1063 | attn_pool_text = True, 1064 | attn_pool_num_latents = 32, 1065 | dropout = 0., 1066 | memory_efficient = False, 1067 | init_conv_to_final_conv_residual = False, 1068 | use_global_context_attn = True, 1069 | scale_skip_connection = True, 1070 | final_resnet_block = True, 1071 | final_conv_kernel_size = 3, 1072 | cosine_sim_attn = False, 1073 | self_cond = False, 1074 | combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully 1075 | pixel_shuffle_upsample = True # may address checkboard artifacts 1076 | ): 1077 | super().__init__() 1078 | 1079 | # guide researchers 1080 | 1081 | assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8' 1082 | 1083 | if dim < 128: 1084 | print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/') 1085 | 1086 | # save locals to take care of some hyperparameters for cascading DDPM 1087 | 1088 | self._locals = locals() 1089 | self._locals.pop('self', None) 1090 | self._locals.pop('__class__', None) 1091 | 1092 | self.self_cond = self_cond 1093 | 1094 | # determine dimensions 1095 | 1096 | self.channels = channels 1097 | self.channels_out = default(channels_out, channels) 1098 | 1099 | # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis 1100 | # (2) in self conditioning, one appends the predict x0 (x_start) 1101 | init_channels = channels * (1 + int(lowres_cond) + int(self_cond)) 1102 | init_dim = default(init_dim, dim) 1103 | 1104 | # optional image conditioning 1105 | 1106 | self.has_cond_image = cond_images_channels > 0 1107 | 
self.cond_images_channels = cond_images_channels 1108 | 1109 | init_channels += cond_images_channels 1110 | 1111 | # initial convolution 1112 | 1113 | self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2) 1114 | 1115 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 1116 | in_out = list(zip(dims[:-1], dims[1:])) 1117 | 1118 | # time conditioning 1119 | 1120 | cond_dim = default(cond_dim, dim) 1121 | time_cond_dim = dim * 4 * (2 if lowres_cond else 1) 1122 | 1123 | # embedding time for log(snr) noise from continuous version 1124 | 1125 | sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim) 1126 | sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1 1127 | 1128 | self.to_time_hiddens = nn.Sequential( 1129 | sinu_pos_emb, 1130 | nn.Linear(sinu_pos_emb_input_dim, time_cond_dim), 1131 | nn.SiLU() 1132 | ) 1133 | 1134 | self.to_time_cond = nn.Sequential( 1135 | nn.Linear(time_cond_dim, time_cond_dim) 1136 | ) 1137 | 1138 | # project to time tokens as well as time hiddens 1139 | 1140 | self.to_time_tokens = nn.Sequential( 1141 | nn.Linear(time_cond_dim, cond_dim * num_time_tokens), 1142 | Rearrange('b (r d) -> b r d', r = num_time_tokens) 1143 | ) 1144 | 1145 | # low res aug noise conditioning 1146 | 1147 | self.lowres_cond = lowres_cond 1148 | 1149 | if lowres_cond: 1150 | self.to_lowres_time_hiddens = nn.Sequential( 1151 | LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim), 1152 | nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim), 1153 | nn.SiLU() 1154 | ) 1155 | 1156 | self.to_lowres_time_cond = nn.Sequential( 1157 | nn.Linear(time_cond_dim, time_cond_dim) 1158 | ) 1159 | 1160 | self.to_lowres_time_tokens = nn.Sequential( 1161 | nn.Linear(time_cond_dim, cond_dim * num_time_tokens), 1162 | Rearrange('b (r d) -> b r d', r = num_time_tokens) 1163 | ) 1164 | 1165 | # normalizations 1166 | 1167 | self.norm_cond = nn.LayerNorm(cond_dim) 1168 | 1169 | # text encoding conditioning (optional) 1170 | 1171 | self.text_to_cond = None 1172 | 1173 | if cond_on_text: 1174 | assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True' 1175 | self.text_to_cond = nn.Linear(text_embed_dim, cond_dim) 1176 | 1177 | # finer control over whether to condition on text encodings 1178 | 1179 | self.cond_on_text = cond_on_text 1180 | 1181 | # attention pooling 1182 | 1183 | self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents, cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None 1184 | 1185 | # for classifier free guidance 1186 | 1187 | self.max_text_len = max_text_len 1188 | 1189 | self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim)) 1190 | self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim)) 1191 | 1192 | # for non-attention based text conditioning at all points in the network where time is also conditioned 1193 | 1194 | self.to_text_non_attn_cond = None 1195 | 1196 | if cond_on_text: 1197 | self.to_text_non_attn_cond = nn.Sequential( 1198 | nn.LayerNorm(cond_dim), 1199 | nn.Linear(cond_dim, time_cond_dim), 1200 | nn.SiLU(), 1201 | nn.Linear(time_cond_dim, time_cond_dim) 1202 | ) 1203 | 1204 | # attention related params 1205 | 1206 | attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn) 1207 
| 1208 | num_layers = len(in_out) 1209 | 1210 | # temporal attention - attention across video frames 1211 | 1212 | temporal_peg_padding = (0, 0, 0, 0, 2, 0) if time_causal_attn else (0, 0, 0, 0, 1, 1) 1213 | temporal_peg = lambda dim: Residual(nn.Sequential(Pad(temporal_peg_padding), nn.Conv3d(dim, dim, (3, 1, 1), groups = dim))) 1214 | 1215 | temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', '(b h w) f c', Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn}))) 1216 | 1217 | # temporal attention relative positional encoding 1218 | 1219 | self.time_rel_pos_bias = DynamicPositionBias(dim = dim * 2, heads = attn_heads, depth = time_rel_pos_bias_depth) 1220 | 1221 | # resnet block klass 1222 | 1223 | num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers) 1224 | resnet_groups = cast_tuple(resnet_groups, num_layers) 1225 | 1226 | resnet_klass = partial(ResnetBlock, **attn_kwargs) 1227 | 1228 | layer_attns = cast_tuple(layer_attns, num_layers) 1229 | layer_attns_depth = cast_tuple(layer_attns_depth, num_layers) 1230 | layer_cross_attns = cast_tuple(layer_cross_attns, num_layers) 1231 | 1232 | assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))]) 1233 | 1234 | # downsample klass 1235 | 1236 | downsample_klass = Downsample 1237 | 1238 | if cross_embed_downsample: 1239 | downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes) 1240 | 1241 | # initial resnet block (for memory efficient unet) 1242 | 1243 | self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None 1244 | 1245 | self.init_temporal_peg = temporal_peg(init_dim) 1246 | self.init_temporal_attn = temporal_attn(init_dim) 1247 | 1248 | # scale for resnet skip connections 1249 | 1250 | self.skip_connect_scale = 1. 
if not scale_skip_connection else (2 ** -0.5) 1251 | 1252 | # layers 1253 | 1254 | self.downs = nn.ModuleList([]) 1255 | self.ups = nn.ModuleList([]) 1256 | num_resolutions = len(in_out) 1257 | 1258 | layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns] 1259 | reversed_layer_params = list(map(reversed, layer_params)) 1260 | 1261 | # downsampling layers 1262 | 1263 | skip_connect_dims = [] # keep track of skip connection dimensions 1264 | 1265 | for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(in_out, *layer_params)): 1266 | is_last = ind >= (num_resolutions - 1) 1267 | 1268 | layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn 1269 | layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None 1270 | 1271 | transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity) 1272 | 1273 | current_dim = dim_in 1274 | 1275 | # whether to pre-downsample, from memory efficient unet 1276 | 1277 | pre_downsample = None 1278 | 1279 | if memory_efficient: 1280 | pre_downsample = downsample_klass(dim_in, dim_out) 1281 | current_dim = dim_out 1282 | 1283 | skip_connect_dims.append(current_dim) 1284 | 1285 | # whether to do post-downsample, for non-memory efficient unet 1286 | 1287 | post_downsample = None 1288 | if not memory_efficient: 1289 | post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(Conv2d(dim_in, dim_out, 3, padding = 1), Conv2d(dim_in, dim_out, 1)) 1290 | 1291 | self.downs.append(nn.ModuleList([ 1292 | pre_downsample, 1293 | resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), 1294 | nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), 1295 | transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), 1296 | temporal_peg(current_dim), 1297 | temporal_attn(current_dim), 1298 | post_downsample 1299 | ])) 1300 | 1301 | # middle layers 1302 | 1303 | mid_dim = dims[-1] 1304 | 1305 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) 1306 | self.mid_attn = EinopsToAndFrom('b c f h w', 'b (f h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None 1307 | self.mid_temporal_peg = temporal_peg(mid_dim) 1308 | self.mid_temporal_attn = temporal_attn(mid_dim) 1309 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1]) 1310 | 1311 | # upsample klass 1312 | 1313 | upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample 1314 | 1315 | # upsampling layers 1316 | 1317 | upsample_fmap_dims = [] 1318 | 1319 | for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)): 1320 | is_last = ind == (len(in_out) - 1) 1321 | layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn 1322 | layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None 1323 | transformer_block_klass = TransformerBlock if 
layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity) 1324 | 1325 | skip_connect_dim = skip_connect_dims.pop() 1326 | 1327 | upsample_fmap_dims.append(dim_out) 1328 | 1329 | self.ups.append(nn.ModuleList([ 1330 | resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups), 1331 | nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]), 1332 | transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs), 1333 | temporal_peg(dim_out), 1334 | temporal_attn(dim_out), 1335 | upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity() 1336 | ])) 1337 | 1338 | # whether to combine feature maps from all upsample blocks before final resnet block out 1339 | 1340 | self.upsample_combiner = UpsampleCombiner( 1341 | dim = dim, 1342 | enabled = combine_upsample_fmaps, 1343 | dim_ins = upsample_fmap_dims, 1344 | dim_outs = dim 1345 | ) 1346 | 1347 | # whether to do a final residual from initial conv to the final resnet block out 1348 | 1349 | self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual 1350 | final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0) 1351 | 1352 | # final optional resnet block and convolution out 1353 | 1354 | self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None 1355 | 1356 | final_conv_dim_in = dim if final_resnet_block else final_conv_dim 1357 | final_conv_dim_in += (channels if lowres_cond else 0) 1358 | 1359 | self.final_conv = Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2) 1360 | 1361 | zero_init_(self.final_conv) 1362 | 1363 | # if the current settings for the unet are not correct 1364 | # for cascading DDPM, then reinit the unet with the right settings 1365 | def cast_model_parameters( 1366 | self, 1367 | *, 1368 | lowres_cond, 1369 | text_embed_dim, 1370 | channels, 1371 | channels_out, 1372 | cond_on_text 1373 | ): 1374 | if lowres_cond == self.lowres_cond and \ 1375 | channels == self.channels and \ 1376 | cond_on_text == self.cond_on_text and \ 1377 | text_embed_dim == self._locals['text_embed_dim'] and \ 1378 | channels_out == self.channels_out: 1379 | return self 1380 | 1381 | updated_kwargs = dict( 1382 | lowres_cond = lowres_cond, 1383 | text_embed_dim = text_embed_dim, 1384 | channels = channels, 1385 | channels_out = channels_out, 1386 | cond_on_text = cond_on_text 1387 | ) 1388 | 1389 | return self.__class__(**{**self._locals, **updated_kwargs}) 1390 | 1391 | # methods for returning the full unet config as well as its parameter state 1392 | 1393 | def to_config_and_state_dict(self): 1394 | return self._locals, self.state_dict() 1395 | 1396 | # class method for rehydrating the unet from its config and state dict 1397 | 1398 | @classmethod 1399 | def from_config_and_state_dict(klass, config, state_dict): 1400 | unet = klass(**config) 1401 | unet.load_state_dict(state_dict) 1402 | return unet 1403 | 1404 | # methods for persisting unet to disk 1405 | 1406 | def persist_to_file(self, path): 1407 | path = Path(path) 1408 | path.parents[0].mkdir(exist_ok = True, parents = True) 1409 | 1410 | 
config, state_dict = self.to_config_and_state_dict() 1411 | pkg = dict(config = config, state_dict = state_dict) 1412 | torch.save(pkg, str(path)) 1413 | 1414 | # class method for rehydrating the unet from file saved with `persist_to_file` 1415 | 1416 | @classmethod 1417 | def hydrate_from_file(klass, path): 1418 | path = Path(path) 1419 | assert path.exists() 1420 | pkg = torch.load(str(path)) 1421 | 1422 | assert 'config' in pkg and 'state_dict' in pkg 1423 | config, state_dict = pkg['config'], pkg['state_dict'] 1424 | 1425 | return Unet.from_config_and_state_dict(config, state_dict) 1426 | 1427 | # forward with classifier free guidance 1428 | 1429 | def forward_with_cond_scale( 1430 | self, 1431 | *args, 1432 | cond_scale = 1., 1433 | **kwargs 1434 | ): 1435 | logits = self.forward(*args, **kwargs) 1436 | 1437 | if cond_scale == 1: 1438 | return logits 1439 | 1440 | null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs) 1441 | return null_logits + (logits - null_logits) * cond_scale 1442 | 1443 | def forward( 1444 | self, 1445 | x, 1446 | time, 1447 | *, 1448 | lowres_cond_img = None, 1449 | lowres_noise_times = None, 1450 | text_embeds = None, 1451 | text_mask = None, 1452 | cond_images = None, 1453 | self_cond = None, 1454 | cond_drop_prob = 0. 1455 | ): 1456 | assert x.ndim == 5, 'input to 3d unet must have 5 dimensions (batch, channels, time, height, width)' 1457 | 1458 | batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype 1459 | 1460 | # add self conditioning if needed 1461 | 1462 | if self.self_cond: 1463 | self_cond = default(self_cond, lambda: torch.zeros_like(x)) 1464 | x = torch.cat((x, self_cond), dim = 1) 1465 | 1466 | # add low resolution conditioning, if present 1467 | 1468 | assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present' 1469 | assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present' 1470 | 1471 | if exists(lowres_cond_img): 1472 | x = torch.cat((x, lowres_cond_img), dim = 1) 1473 | 1474 | # condition on input image 1475 | 1476 | assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image on the unet, but the conditioning image is not supplied, or vice versa' 1477 | 1478 | if exists(cond_images): 1479 | assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialiation of the unet' 1480 | cond_images = resize_video_to(cond_images, x.shape[-1]) 1481 | x = torch.cat((cond_images, x), dim = 1) 1482 | 1483 | # get time relative positions 1484 | 1485 | time_attn_bias = self.time_rel_pos_bias(frames, device = device, dtype = dtype) 1486 | 1487 | # initial convolution 1488 | 1489 | x = self.init_conv(x) 1490 | 1491 | x = self.init_temporal_peg(x) 1492 | x = self.init_temporal_attn(x, attn_bias = time_attn_bias) 1493 | 1494 | # init conv residual 1495 | 1496 | if self.init_conv_to_final_conv_residual: 1497 | init_conv_residual = x.clone() 1498 | 1499 | # time conditioning 1500 | 1501 | time_hiddens = self.to_time_hiddens(time) 1502 | 1503 | # derive time tokens 1504 | 1505 | time_tokens = self.to_time_tokens(time_hiddens) 1506 | t = self.to_time_cond(time_hiddens) 1507 | 1508 | # add lowres time conditioning to time hiddens 1509 | # and add lowres time tokens along sequence dimension for attention 1510 | 1511 | if self.lowres_cond: 1512 | lowres_time_hiddens = 
self.to_lowres_time_hiddens(lowres_noise_times) 1513 | lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens) 1514 | lowres_t = self.to_lowres_time_cond(lowres_time_hiddens) 1515 | 1516 | t = t + lowres_t 1517 | time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2) 1518 | 1519 | # text conditioning 1520 | 1521 | text_tokens = None 1522 | 1523 | if exists(text_embeds) and self.cond_on_text: 1524 | 1525 | # conditional dropout 1526 | 1527 | text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device) 1528 | 1529 | text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1') 1530 | text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1') 1531 | 1532 | # calculate text embeds 1533 | 1534 | text_tokens = self.text_to_cond(text_embeds) 1535 | 1536 | text_tokens = text_tokens[:, :self.max_text_len] 1537 | 1538 | if exists(text_mask): 1539 | text_mask = text_mask[:, :self.max_text_len] 1540 | 1541 | text_tokens_len = text_tokens.shape[1] 1542 | remainder = self.max_text_len - text_tokens_len 1543 | 1544 | if remainder > 0: 1545 | text_tokens = F.pad(text_tokens, (0, 0, 0, remainder)) 1546 | 1547 | if exists(text_mask): 1548 | if remainder > 0: 1549 | text_mask = F.pad(text_mask, (0, remainder), value = False) 1550 | 1551 | text_mask = rearrange(text_mask, 'b n -> b n 1') 1552 | text_keep_mask_embed = text_mask & text_keep_mask_embed 1553 | 1554 | null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working 1555 | 1556 | text_tokens = torch.where( 1557 | text_keep_mask_embed, 1558 | text_tokens, 1559 | null_text_embed 1560 | ) 1561 | 1562 | if exists(self.attn_pool): 1563 | text_tokens = self.attn_pool(text_tokens) 1564 | 1565 | # extra non-attention conditioning by projecting and then summing text embeddings to time 1566 | # termed as text hiddens 1567 | 1568 | mean_pooled_text_tokens = text_tokens.mean(dim = -2) 1569 | 1570 | text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens) 1571 | 1572 | null_text_hidden = self.null_text_hidden.to(t.dtype) 1573 | 1574 | text_hiddens = torch.where( 1575 | text_keep_mask_hidden, 1576 | text_hiddens, 1577 | null_text_hidden 1578 | ) 1579 | 1580 | t = t + text_hiddens 1581 | 1582 | # main conditioning tokens (c) 1583 | 1584 | c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2) 1585 | 1586 | # normalize conditioning tokens 1587 | 1588 | c = self.norm_cond(c) 1589 | 1590 | # initial resnet block (for memory efficient unet) 1591 | 1592 | if exists(self.init_resnet_block): 1593 | x = self.init_resnet_block(x, t) 1594 | 1595 | # go through the layers of the unet, down and up 1596 | 1597 | hiddens = [] 1598 | 1599 | for pre_downsample, init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, post_downsample in self.downs: 1600 | if exists(pre_downsample): 1601 | x = pre_downsample(x) 1602 | 1603 | x = init_block(x, t, c) 1604 | 1605 | for resnet_block in resnet_blocks: 1606 | x = resnet_block(x, t) 1607 | hiddens.append(x) 1608 | 1609 | x = attn_block(x, c) 1610 | x = temporal_peg(x) 1611 | x = temporal_attn(x, attn_bias = time_attn_bias) 1612 | 1613 | hiddens.append(x) 1614 | 1615 | if exists(post_downsample): 1616 | x = post_downsample(x) 1617 | 1618 | x = self.mid_block1(x, t, c) 1619 | 1620 | if exists(self.mid_attn): 1621 | x = self.mid_attn(x) 1622 | 1623 | x = self.mid_temporal_peg(x) 1624 | x = self.mid_temporal_attn(x, attn_bias = time_attn_bias) 1625 | 1626 | x = self.mid_block2(x, t, c) 1627 
| 1628 | add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1) 1629 | 1630 | up_hiddens = [] 1631 | 1632 | for init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, upsample in self.ups: 1633 | x = add_skip_connection(x) 1634 | x = init_block(x, t, c) 1635 | 1636 | for resnet_block in resnet_blocks: 1637 | x = add_skip_connection(x) 1638 | x = resnet_block(x, t) 1639 | 1640 | x = attn_block(x, c) 1641 | x = temporal_peg(x) 1642 | x = temporal_attn(x, attn_bias = time_attn_bias) 1643 | 1644 | up_hiddens.append(x.contiguous()) 1645 | x = upsample(x) 1646 | 1647 | # whether to combine all feature maps from upsample blocks 1648 | 1649 | x = self.upsample_combiner(x, up_hiddens) 1650 | 1651 | # final top-most residual if needed 1652 | 1653 | if self.init_conv_to_final_conv_residual: 1654 | x = torch.cat((x, init_conv_residual), dim = 1) 1655 | 1656 | if exists(self.final_res_block): 1657 | x = self.final_res_block(x, t) 1658 | 1659 | if exists(lowres_cond_img): 1660 | x = torch.cat((x, lowres_cond_img), dim = 1) 1661 | 1662 | return self.final_conv(x) 1663 | -------------------------------------------------------------------------------- /imagen_pytorch/t5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import transformers 3 | from typing import List 4 | from transformers import T5Tokenizer, T5EncoderModel, T5Config 5 | from einops import rearrange 6 | 7 | transformers.logging.set_verbosity_error() 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | def default(val, d): 13 | if exists(val): 14 | return val 15 | return d() if callable(d) else d 16 | 17 | # config 18 | 19 | MAX_LENGTH = 256 20 | 21 | DEFAULT_T5_NAME = 'google/t5-v1_1-base' 22 | 23 | T5_CONFIGS = {} 24 | 25 | # singleton globals 26 | 27 | def get_tokenizer(name): 28 | tokenizer = T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH) 29 | return tokenizer 30 | 31 | def get_model(name): 32 | model = T5EncoderModel.from_pretrained(name) 33 | return model 34 | 35 | def get_model_and_tokenizer(name): 36 | global T5_CONFIGS 37 | 38 | if name not in T5_CONFIGS: 39 | T5_CONFIGS[name] = dict() 40 | if "model" not in T5_CONFIGS[name]: 41 | T5_CONFIGS[name]["model"] = get_model(name) 42 | if "tokenizer" not in T5_CONFIGS[name]: 43 | T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name) 44 | 45 | return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer'] 46 | 47 | def get_encoded_dim(name): 48 | if name not in T5_CONFIGS: 49 | # avoids loading the model if we only want to get the dim 50 | config = T5Config.from_pretrained(name) 51 | T5_CONFIGS[name] = dict(config=config) 52 | elif "config" in T5_CONFIGS[name]: 53 | config = T5_CONFIGS[name]["config"] 54 | elif "model" in T5_CONFIGS[name]: 55 | config = T5_CONFIGS[name]["model"].config 56 | else: 57 | assert False 58 | return config.d_model 59 | 60 | # encoding text 61 | 62 | def t5_tokenize( 63 | texts: List[str], 64 | name = DEFAULT_T5_NAME 65 | ): 66 | t5, tokenizer = get_model_and_tokenizer(name) 67 | 68 | if torch.cuda.is_available(): 69 | t5 = t5.cuda() 70 | 71 | device = next(t5.parameters()).device 72 | 73 | encoded = tokenizer.batch_encode_plus( 74 | texts, 75 | return_tensors = "pt", 76 | padding = 'longest', 77 | max_length = MAX_LENGTH, 78 | truncation = True 79 | ) 80 | 81 | input_ids = encoded.input_ids.to(device) 82 | attn_mask = encoded.attention_mask.to(device) 83 | return input_ids, attn_mask 84 | 85 | def t5_encode_tokenized_text( 86 | 
token_ids, 87 | attn_mask = None, 88 | pad_id = None, 89 | name = DEFAULT_T5_NAME 90 | ): 91 | assert exists(attn_mask) or exists(pad_id) 92 | t5, _ = get_model_and_tokenizer(name) 93 | 94 | attn_mask = default(attn_mask, lambda: (token_ids != pad_id).long()) 95 | 96 | t5.eval() 97 | 98 | with torch.no_grad(): 99 | output = t5(input_ids = token_ids, attention_mask = attn_mask) 100 | encoded_text = output.last_hidden_state.detach() 101 | 102 | attn_mask = attn_mask.bool() 103 | 104 | encoded_text = encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.) # just force all embeddings that is padding to be equal to 0. 105 | return encoded_text 106 | 107 | def t5_encode_text( 108 | texts: List[str], 109 | name = DEFAULT_T5_NAME, 110 | return_attn_mask = False 111 | ): 112 | token_ids, attn_mask = t5_tokenize(texts, name = name) 113 | encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name) 114 | 115 | if return_attn_mask: 116 | attn_mask = attn_mask.bool() 117 | return encoded_text, attn_mask 118 | 119 | return encoded_text 120 | -------------------------------------------------------------------------------- /imagen_pytorch/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import copy 4 | from pathlib import Path 5 | from math import ceil 6 | from contextlib import contextmanager, nullcontext 7 | from functools import partial, wraps 8 | from collections.abc import Iterable 9 | 10 | import torch 11 | from torch import nn 12 | import torch.nn.functional as F 13 | from torch.utils.data import random_split, DataLoader 14 | from torch.optim import Adam 15 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR 16 | from torch.cuda.amp import autocast, GradScaler 17 | 18 | import pytorch_warmup as warmup 19 | 20 | from imagen_pytorch.imagen_pytorch import Imagen, NullUnet 21 | from imagen_pytorch.elucidated_imagen import ElucidatedImagen 22 | from imagen_pytorch.data import cycle 23 | 24 | from imagen_pytorch.version import __version__ 25 | from packaging import version 26 | 27 | import numpy as np 28 | 29 | from ema_pytorch import EMA 30 | 31 | from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs 32 | 33 | from fsspec.core import url_to_fs 34 | from fsspec.implementations.local import LocalFileSystem 35 | 36 | # helper functions 37 | 38 | def exists(val): 39 | return val is not None 40 | 41 | def default(val, d): 42 | if exists(val): 43 | return val 44 | return d() if callable(d) else d 45 | 46 | def cast_tuple(val, length = 1): 47 | if isinstance(val, list): 48 | val = tuple(val) 49 | 50 | return val if isinstance(val, tuple) else ((val,) * length) 51 | 52 | def find_first(fn, arr): 53 | for ind, el in enumerate(arr): 54 | if fn(el): 55 | return ind 56 | return -1 57 | 58 | def pick_and_pop(keys, d): 59 | values = list(map(lambda key: d.pop(key), keys)) 60 | return dict(zip(keys, values)) 61 | 62 | def group_dict_by_key(cond, d): 63 | return_val = [dict(),dict()] 64 | for key in d.keys(): 65 | match = bool(cond(key)) 66 | ind = int(not match) 67 | return_val[ind][key] = d[key] 68 | return (*return_val,) 69 | 70 | def string_begins_with(prefix, str): 71 | return str.startswith(prefix) 72 | 73 | def group_by_key_prefix(prefix, d): 74 | return group_dict_by_key(partial(string_begins_with, prefix), d) 75 | 76 | def groupby_prefix_and_trim(prefix, d): 77 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d) 78 | 
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items()))) 79 | return kwargs_without_prefix, kwargs 80 | 81 | def num_to_groups(num, divisor): 82 | groups = num // divisor 83 | remainder = num % divisor 84 | arr = [divisor] * groups 85 | if remainder > 0: 86 | arr.append(remainder) 87 | return arr 88 | 89 | # url to fs, bucket, path - for checkpointing to cloud 90 | 91 | def url_to_bucket(url): 92 | if '://' not in url: 93 | return url 94 | 95 | prefix, suffix = url.split('://') 96 | 97 | if prefix in {'gs', 's3'}: 98 | return suffix.split('/')[0] 99 | else: 100 | raise ValueError(f'storage type prefix "{prefix}" is not supported yet') 101 | 102 | # decorators 103 | 104 | def eval_decorator(fn): 105 | def inner(model, *args, **kwargs): 106 | was_training = model.training 107 | model.eval() 108 | out = fn(model, *args, **kwargs) 109 | model.train(was_training) 110 | return out 111 | return inner 112 | 113 | def cast_torch_tensor(fn, cast_fp16 = False): 114 | @wraps(fn) 115 | def inner(model, *args, **kwargs): 116 | device = kwargs.pop('_device', model.device) 117 | cast_device = kwargs.pop('_cast_device', True) 118 | 119 | should_cast_fp16 = cast_fp16 and model.cast_half_at_training 120 | 121 | kwargs_keys = kwargs.keys() 122 | all_args = (*args, *kwargs.values()) 123 | split_kwargs_index = len(all_args) - len(kwargs_keys) 124 | all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args)) 125 | 126 | if cast_device: 127 | all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args)) 128 | 129 | if should_cast_fp16: 130 | all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args)) 131 | 132 | args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:] 133 | kwargs = dict(tuple(zip(kwargs_keys, kwargs_values))) 134 | 135 | out = fn(model, *args, **kwargs) 136 | return out 137 | return inner 138 | 139 | # gradient accumulation functions 140 | 141 | def split_iterable(it, split_size): 142 | accum = [] 143 | for ind in range(ceil(len(it) / split_size)): 144 | start_index = ind * split_size 145 | accum.append(it[start_index: (start_index + split_size)]) 146 | return accum 147 | 148 | def split(t, split_size = None): 149 | if not exists(split_size): 150 | return t 151 | 152 | if isinstance(t, torch.Tensor): 153 | return t.split(split_size, dim = 0) 154 | 155 | if isinstance(t, Iterable): 156 | return split_iterable(t, split_size) 157 | 158 | raise TypeError 159 | 160 | def find_first(cond, arr): 161 | for el in arr: 162 | if cond(el): 163 | return el 164 | return None 165 | 166 | def split_args_and_kwargs(*args, split_size = None, **kwargs): 167 | all_args = (*args, *kwargs.values()) 168 | len_all_args = len(all_args) 169 | first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args) 170 | assert exists(first_tensor) 171 | 172 | batch_size = len(first_tensor) 173 | split_size = default(split_size, batch_size) 174 | num_chunks = ceil(batch_size / split_size) 175 | 176 | dict_len = len(kwargs) 177 | dict_keys = kwargs.keys() 178 | split_kwargs_index = len_all_args - dict_len 179 | 180 | split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args] 181 | chunk_sizes = tuple(map(len, split_all_args[0])) 182 | 183 | for (chunk_size, *chunked_all_args) in
tuple(zip(chunk_sizes, *split_all_args)): 184 | chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:] 185 | chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values))) 186 | chunk_size_frac = chunk_size / batch_size 187 | yield chunk_size_frac, (chunked_args, chunked_kwargs) 188 | 189 | # imagen trainer 190 | 191 | def imagen_sample_in_chunks(fn): 192 | @wraps(fn) 193 | def inner(self, *args, max_batch_size = None, **kwargs): 194 | if not exists(max_batch_size): 195 | return fn(self, *args, **kwargs) 196 | 197 | if self.imagen.unconditional: 198 | batch_size = kwargs.get('batch_size') 199 | batch_sizes = num_to_groups(batch_size, max_batch_size) 200 | outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes] 201 | else: 202 | outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)] 203 | 204 | if isinstance(outputs[0], torch.Tensor): 205 | return torch.cat(outputs, dim = 0) 206 | 207 | return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs)))) 208 | 209 | return inner 210 | 211 | 212 | def restore_parts(state_dict_target, state_dict_from): 213 | for name, param in state_dict_from.items(): 214 | 215 | if name not in state_dict_target: 216 | continue 217 | 218 | if param.size() == state_dict_target[name].size(): 219 | state_dict_target[name].copy_(param) 220 | else: 221 | print(f"layer {name}({param.size()} different than target: {state_dict_target[name].size()}") 222 | 223 | return state_dict_target 224 | 225 | 226 | class ImagenTrainer(nn.Module): 227 | locked = False 228 | 229 | def __init__( 230 | self, 231 | imagen = None, 232 | imagen_checkpoint_path = None, 233 | use_ema = True, 234 | lr = 1e-4, 235 | eps = 1e-8, 236 | beta1 = 0.9, 237 | beta2 = 0.99, 238 | max_grad_norm = None, 239 | group_wd_params = True, 240 | warmup_steps = None, 241 | cosine_decay_max_steps = None, 242 | only_train_unet_number = None, 243 | fp16 = False, 244 | precision = None, 245 | split_batches = True, 246 | dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'), 247 | verbose = True, 248 | split_valid_fraction = 0.025, 249 | split_valid_from_train = False, 250 | split_random_seed = 42, 251 | checkpoint_path = None, 252 | checkpoint_every = None, 253 | checkpoint_fs = None, 254 | fs_kwargs: dict = None, 255 | max_checkpoints_keep = 20, 256 | **kwargs 257 | ): 258 | super().__init__() 259 | assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)' 260 | assert exists(imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config' 261 | 262 | # determine filesystem, using fsspec, for saving to local filesystem or cloud 263 | 264 | self.fs = checkpoint_fs 265 | 266 | if not exists(self.fs): 267 | fs_kwargs = default(fs_kwargs, {}) 268 | self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs) 269 | 270 | assert isinstance(imagen, (Imagen, ElucidatedImagen)) 271 | ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs) 272 | 273 | # elucidated or not 274 | 275 | self.is_elucidated = isinstance(imagen, ElucidatedImagen) 276 | 277 | # create accelerator instance 278 
| 279 | accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs) 280 | 281 | assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator' 282 | accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no') 283 | 284 | self.accelerator = Accelerator(**{ 285 | 'split_batches': split_batches, 286 | 'mixed_precision': accelerator_mixed_precision, 287 | 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)] 288 | , **accelerate_kwargs}) 289 | 290 | ImagenTrainer.locked = self.is_distributed 291 | 292 | # cast data to fp16 at training time if needed 293 | 294 | self.cast_half_at_training = accelerator_mixed_precision == 'fp16' 295 | 296 | # grad scaler must be managed outside of accelerator 297 | 298 | grad_scaler_enabled = fp16 299 | 300 | # imagen, unets and ema unets 301 | 302 | self.imagen = imagen 303 | self.num_unets = len(self.imagen.unets) 304 | 305 | self.use_ema = use_ema and self.is_main 306 | self.ema_unets = nn.ModuleList([]) 307 | 308 | # keep track of what unet is being trained on 309 | # only going to allow 1 unet training at a time 310 | 311 | self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on 312 | 313 | # data related functions 314 | 315 | self.train_dl_iter = None 316 | self.train_dl = None 317 | 318 | self.valid_dl_iter = None 319 | self.valid_dl = None 320 | 321 | self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names 322 | 323 | # auto splitting validation from training, if dataset is passed in 324 | 325 | self.split_valid_from_train = split_valid_from_train 326 | 327 | assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1' 328 | self.split_valid_fraction = split_valid_fraction 329 | self.split_random_seed = split_random_seed 330 | 331 | # be able to finely customize learning rate, weight decay 332 | # per unet 333 | 334 | lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps)) 335 | 336 | for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)): 337 | optimizer = Adam( 338 | unet.parameters(), 339 | lr = unet_lr, 340 | eps = unet_eps, 341 | betas = (beta1, beta2), 342 | **kwargs 343 | ) 344 | 345 | if self.use_ema: 346 | self.ema_unets.append(EMA(unet, **ema_kwargs)) 347 | 348 | scaler = GradScaler(enabled = grad_scaler_enabled) 349 | 350 | scheduler = warmup_scheduler = None 351 | 352 | if exists(unet_cosine_decay_max_steps): 353 | scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps) 354 | 355 | if exists(unet_warmup_steps): 356 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps) 357 | 358 | if not exists(scheduler): 359 | scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0) 360 | 361 | # set on object 362 | 363 | setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers 364 | setattr(self, f'scaler{ind}', scaler) 365 | setattr(self, f'scheduler{ind}', scheduler) 366 | setattr(self, f'warmup{ind}', warmup_scheduler) 367 | 368 | # gradient clipping if needed 369 | 370 | self.max_grad_norm = max_grad_norm 371 | 372 | # step tracker and misc 373 | 374 | self.register_buffer('steps', torch.tensor([0] * self.num_unets)) 375 | 376 | self.verbose = verbose 377 | 378 
| # automatic set devices based on what accelerator decided 379 | 380 | self.imagen.to(self.device) 381 | self.to(self.device) 382 | 383 | # checkpointing 384 | 385 | assert not (exists(checkpoint_path) ^ exists(checkpoint_every)) 386 | self.checkpoint_path = checkpoint_path 387 | self.checkpoint_every = checkpoint_every 388 | self.max_checkpoints_keep = max_checkpoints_keep 389 | 390 | self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main 391 | 392 | if exists(checkpoint_path) and self.can_checkpoint: 393 | bucket = url_to_bucket(checkpoint_path) 394 | 395 | if not self.fs.exists(bucket): 396 | self.fs.mkdir(bucket) 397 | 398 | self.load_from_checkpoint_folder() 399 | 400 | # only allowing training for unet 401 | 402 | self.only_train_unet_number = only_train_unet_number 403 | self.validate_and_set_unet_being_trained(only_train_unet_number) 404 | 405 | # computed values 406 | 407 | @property 408 | def device(self): 409 | return self.accelerator.device 410 | 411 | @property 412 | def is_distributed(self): 413 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1) 414 | 415 | @property 416 | def is_main(self): 417 | return self.accelerator.is_main_process 418 | 419 | @property 420 | def is_local_main(self): 421 | return self.accelerator.is_local_main_process 422 | 423 | @property 424 | def unwrapped_unet(self): 425 | return self.accelerator.unwrap_model(self.unet_being_trained) 426 | 427 | # optimizer helper functions 428 | 429 | def get_lr(self, unet_number): 430 | self.validate_unet_number(unet_number) 431 | unet_index = unet_number - 1 432 | 433 | optim = getattr(self, f'optim{unet_index}') 434 | 435 | return optim.param_groups[0]['lr'] 436 | 437 | # function for allowing only one unet from being trained at a time 438 | 439 | def validate_and_set_unet_being_trained(self, unet_number = None): 440 | if exists(unet_number): 441 | self.validate_unet_number(unet_number) 442 | 443 | assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you cannot only train on one unet at a time. 
you will need to save the trainer into a checkpoint, and resume training on a new unet' 444 | 445 | self.only_train_unet_number = unet_number 446 | self.imagen.only_train_unet_number = unet_number 447 | 448 | if not exists(unet_number): 449 | return 450 | 451 | self.wrap_unet(unet_number) 452 | 453 | def wrap_unet(self, unet_number): 454 | if hasattr(self, 'one_unet_wrapped'): 455 | return 456 | 457 | unet = self.imagen.get_unet(unet_number) 458 | self.unet_being_trained = self.accelerator.prepare(unet) 459 | unet_index = unet_number - 1 460 | 461 | optimizer = getattr(self, f'optim{unet_index}') 462 | scheduler = getattr(self, f'scheduler{unet_index}') 463 | 464 | optimizer = self.accelerator.prepare(optimizer) 465 | 466 | if exists(scheduler): 467 | scheduler = self.accelerator.prepare(scheduler) 468 | 469 | setattr(self, f'optim{unet_index}', optimizer) 470 | setattr(self, f'scheduler{unet_index}', scheduler) 471 | 472 | self.one_unet_wrapped = True 473 | 474 | # hacking accelerator due to not having separate gradscaler per optimizer 475 | 476 | def set_accelerator_scaler(self, unet_number): 477 | unet_number = self.validate_unet_number(unet_number) 478 | scaler = getattr(self, f'scaler{unet_number - 1}') 479 | 480 | self.accelerator.scaler = scaler 481 | for optimizer in self.accelerator._optimizers: 482 | optimizer.scaler = scaler 483 | 484 | # helper print 485 | 486 | def print(self, msg): 487 | if not self.is_main: 488 | return 489 | 490 | if not self.verbose: 491 | return 492 | 493 | return self.accelerator.print(msg) 494 | 495 | # validating the unet number 496 | 497 | def validate_unet_number(self, unet_number = None): 498 | if self.num_unets == 1: 499 | unet_number = default(unet_number, 1) 500 | 501 | assert 0 < unet_number <= self.num_unets, f'unet number should be in between 1 and {self.num_unets}' 502 | return unet_number 503 | 504 | # number of training steps taken 505 | 506 | def num_steps_taken(self, unet_number = None): 507 | if self.num_unets == 1: 508 | unet_number = default(unet_number, 1) 509 | 510 | return self.steps[unet_number - 1].item() 511 | 512 | def print_untrained_unets(self): 513 | print_final_error = False 514 | 515 | for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)): 516 | if steps > 0 or isinstance(unet, NullUnet): 517 | continue 518 | 519 | self.print(f'unet {ind + 1} has not been trained') 520 | print_final_error = True 521 | 522 | if print_final_error: 523 | self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets') 524 | 525 | # data related functions 526 | 527 | def add_train_dataloader(self, dl = None): 528 | if not exists(dl): 529 | return 530 | 531 | assert not exists(self.train_dl), 'training dataloader was already added' 532 | self.train_dl = self.accelerator.prepare(dl) 533 | 534 | def add_valid_dataloader(self, dl): 535 | if not exists(dl): 536 | return 537 | 538 | assert not exists(self.valid_dl), 'validation dataloader was already added' 539 | self.valid_dl = self.accelerator.prepare(dl) 540 | 541 | def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs): 542 | if not exists(ds): 543 | return 544 | 545 | assert not exists(self.train_dl), 'training dataloader was already added' 546 | 547 | valid_ds = None 548 | if self.split_valid_from_train: 549 | train_size = int((1 - self.split_valid_fraction) * len(ds)) 550 | valid_size = len(ds) - train_size 551 | 552 | ds, valid_ds = random_split(ds, [train_size, valid_size], 
generator = torch.Generator().manual_seed(self.split_random_seed)) 553 | self.print(f'training with dataset of {len(ds)} samples and validating with randomly splitted {len(valid_ds)} samples') 554 | 555 | dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) 556 | self.train_dl = self.accelerator.prepare(dl) 557 | 558 | if not self.split_valid_from_train: 559 | return 560 | 561 | self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs) 562 | 563 | def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs): 564 | if not exists(ds): 565 | return 566 | 567 | assert not exists(self.valid_dl), 'validation dataloader was already added' 568 | 569 | dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs) 570 | self.valid_dl = self.accelerator.prepare(dl) 571 | 572 | def create_train_iter(self): 573 | assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet' 574 | 575 | if exists(self.train_dl_iter): 576 | return 577 | 578 | self.train_dl_iter = cycle(self.train_dl) 579 | 580 | def create_valid_iter(self): 581 | assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet' 582 | 583 | if exists(self.valid_dl_iter): 584 | return 585 | 586 | self.valid_dl_iter = cycle(self.valid_dl) 587 | 588 | def train_step(self, unet_number = None, **kwargs): 589 | self.create_train_iter() 590 | loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs) 591 | self.update(unet_number = unet_number) 592 | return loss 593 | 594 | @torch.no_grad() 595 | @eval_decorator 596 | def valid_step(self, **kwargs): 597 | self.create_valid_iter() 598 | 599 | context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext 600 | 601 | with context(): 602 | loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs) 603 | return loss 604 | 605 | def step_with_dl_iter(self, dl_iter, **kwargs): 606 | dl_tuple_output = cast_tuple(next(dl_iter)) 607 | model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output))) 608 | loss = self.forward(**{**kwargs, **model_input}) 609 | return loss 610 | 611 | # checkpointing functions 612 | 613 | @property 614 | def all_checkpoints_sorted(self): 615 | glob_pattern = os.path.join(self.checkpoint_path, '*.pt') 616 | checkpoints = self.fs.glob(glob_pattern) 617 | sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True) 618 | return sorted_checkpoints 619 | 620 | def load_from_checkpoint_folder(self, last_total_steps = -1): 621 | if last_total_steps != -1: 622 | filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt') 623 | self.load(filepath) 624 | return 625 | 626 | sorted_checkpoints = self.all_checkpoints_sorted 627 | 628 | if len(sorted_checkpoints) == 0: 629 | self.print(f'no checkpoints found to load from at {self.checkpoint_path}') 630 | return 631 | 632 | last_checkpoint = sorted_checkpoints[0] 633 | self.load(last_checkpoint) 634 | 635 | def save_to_checkpoint_folder(self): 636 | self.accelerator.wait_for_everyone() 637 | 638 | if not self.can_checkpoint: 639 | return 640 | 641 | total_steps = int(self.steps.sum().item()) 642 | filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt') 643 | 644 | self.save(filepath) 645 | 646 | if self.max_checkpoints_keep <= 0: 647 | return 648 | 649 | sorted_checkpoints = self.all_checkpoints_sorted 650 | checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:] 651 | 652 | for checkpoint in 
checkpoints_to_discard: 653 | self.fs.rm(checkpoint) 654 | 655 | # saving and loading functions 656 | 657 | def save( 658 | self, 659 | path, 660 | overwrite = True, 661 | without_optim_and_sched = False, 662 | **kwargs 663 | ): 664 | self.accelerator.wait_for_everyone() 665 | 666 | if not self.can_checkpoint: 667 | return 668 | 669 | fs = self.fs 670 | 671 | assert not (fs.exists(path) and not overwrite) 672 | 673 | self.reset_ema_unets_all_one_device() 674 | 675 | save_obj = dict( 676 | model = self.imagen.state_dict(), 677 | version = __version__, 678 | steps = self.steps.cpu(), 679 | **kwargs 680 | ) 681 | 682 | save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple() 683 | 684 | for ind in save_optim_and_sched_iter: 685 | scaler_key = f'scaler{ind}' 686 | optimizer_key = f'optim{ind}' 687 | scheduler_key = f'scheduler{ind}' 688 | warmup_scheduler_key = f'warmup{ind}' 689 | 690 | scaler = getattr(self, scaler_key) 691 | optimizer = getattr(self, optimizer_key) 692 | scheduler = getattr(self, scheduler_key) 693 | warmup_scheduler = getattr(self, warmup_scheduler_key) 694 | 695 | if exists(scheduler): 696 | save_obj = {**save_obj, scheduler_key: scheduler.state_dict()} 697 | 698 | if exists(warmup_scheduler): 699 | save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()} 700 | 701 | save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()} 702 | 703 | if self.use_ema: 704 | save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()} 705 | 706 | # determine if imagen config is available 707 | 708 | if hasattr(self.imagen, '_config'): 709 | self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"\""') 710 | 711 | save_obj = { 712 | **save_obj, 713 | 'imagen_type': 'elucidated' if self.is_elucidated else 'original', 714 | 'imagen_params': self.imagen._config 715 | } 716 | 717 | #save to path 718 | 719 | with fs.open(path, 'wb') as f: 720 | torch.save(save_obj, f) 721 | 722 | self.print(f'checkpoint saved to {path}') 723 | 724 | def load(self, path, only_model = False, strict = True, noop_if_not_exist = False): 725 | fs = self.fs 726 | 727 | if noop_if_not_exist and not fs.exists(path): 728 | self.print(f'trainer checkpoint not found at {str(path)}') 729 | return 730 | 731 | assert fs.exists(path), f'{path} does not exist' 732 | 733 | self.reset_ema_unets_all_one_device() 734 | 735 | # to avoid extra GPU memory usage in main process when using Accelerate 736 | 737 | with fs.open(path) as f: 738 | loaded_obj = torch.load(f, map_location='cpu') 739 | 740 | if version.parse(__version__) != version.parse(loaded_obj['version']): 741 | self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}') 742 | 743 | try: 744 | self.imagen.load_state_dict(loaded_obj['model'], strict = strict) 745 | except RuntimeError: 746 | print("Failed loading state dict. 
Trying partial load") 747 | self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(), 748 | loaded_obj['model'])) 749 | 750 | if only_model: 751 | return loaded_obj 752 | 753 | self.steps.copy_(loaded_obj['steps']) 754 | 755 | for ind in range(0, self.num_unets): 756 | scaler_key = f'scaler{ind}' 757 | optimizer_key = f'optim{ind}' 758 | scheduler_key = f'scheduler{ind}' 759 | warmup_scheduler_key = f'warmup{ind}' 760 | 761 | scaler = getattr(self, scaler_key) 762 | optimizer = getattr(self, optimizer_key) 763 | scheduler = getattr(self, scheduler_key) 764 | warmup_scheduler = getattr(self, warmup_scheduler_key) 765 | 766 | if exists(scheduler) and scheduler_key in loaded_obj: 767 | scheduler.load_state_dict(loaded_obj[scheduler_key]) 768 | 769 | if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj: 770 | warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key]) 771 | 772 | if exists(optimizer): 773 | try: 774 | optimizer.load_state_dict(loaded_obj[optimizer_key]) 775 | scaler.load_state_dict(loaded_obj[scaler_key]) 776 | except: 777 | self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers') 778 | 779 | if self.use_ema: 780 | assert 'ema' in loaded_obj 781 | try: 782 | self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict) 783 | except RuntimeError: 784 | print("Failed loading state dict. Trying partial load") 785 | self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(), 786 | loaded_obj['ema'])) 787 | 788 | self.print(f'checkpoint loaded from {path}') 789 | return loaded_obj 790 | 791 | # managing ema unets and their devices 792 | 793 | @property 794 | def unets(self): 795 | return nn.ModuleList([ema.ema_model for ema in self.ema_unets]) 796 | 797 | def get_ema_unet(self, unet_number = None): 798 | if not self.use_ema: 799 | return 800 | 801 | unet_number = self.validate_unet_number(unet_number) 802 | index = unet_number - 1 803 | 804 | if isinstance(self.unets, nn.ModuleList): 805 | unets_list = [unet for unet in self.ema_unets] 806 | delattr(self, 'ema_unets') 807 | self.ema_unets = unets_list 808 | 809 | if index != self.ema_unet_being_trained_index: 810 | for unet_index, unet in enumerate(self.ema_unets): 811 | unet.to(self.device if unet_index == index else 'cpu') 812 | 813 | self.ema_unet_being_trained_index = index 814 | return self.ema_unets[index] 815 | 816 | def reset_ema_unets_all_one_device(self, device = None): 817 | if not self.use_ema: 818 | return 819 | 820 | device = default(device, self.device) 821 | self.ema_unets = nn.ModuleList([*self.ema_unets]) 822 | self.ema_unets.to(device) 823 | 824 | self.ema_unet_being_trained_index = -1 825 | 826 | @torch.no_grad() 827 | @contextmanager 828 | def use_ema_unets(self): 829 | if not self.use_ema: 830 | output = yield 831 | return output 832 | 833 | self.reset_ema_unets_all_one_device() 834 | self.imagen.reset_unets_all_one_device() 835 | 836 | self.unets.eval() 837 | 838 | trainable_unets = self.imagen.unets 839 | self.imagen.unets = self.unets # swap in exponential moving averaged unets for sampling 840 | 841 | output = yield 842 | 843 | self.imagen.unets = trainable_unets # restore original training unets 844 | 845 | # cast the ema_model unets back to original device 846 | for ema in self.ema_unets: 847 | ema.restore_ema_model_device() 848 | 849 | return output 850 | 851 | def print_unet_devices(self): 852 | self.print('unet devices:') 853 | for i, unet 
in enumerate(self.imagen.unets): 854 | device = next(unet.parameters()).device 855 | self.print(f'\tunet {i}: {device}') 856 | 857 | if not self.use_ema: 858 | return 859 | 860 | self.print('\nema unet devices:') 861 | for i, ema_unet in enumerate(self.ema_unets): 862 | device = next(ema_unet.parameters()).device 863 | self.print(f'\tema unet {i}: {device}') 864 | 865 | # overriding state dict functions 866 | 867 | def state_dict(self, *args, **kwargs): 868 | self.reset_ema_unets_all_one_device() 869 | return super().state_dict(*args, **kwargs) 870 | 871 | def load_state_dict(self, *args, **kwargs): 872 | self.reset_ema_unets_all_one_device() 873 | return super().load_state_dict(*args, **kwargs) 874 | 875 | # encoding text functions 876 | 877 | def encode_text(self, text, **kwargs): 878 | return self.imagen.encode_text(text, **kwargs) 879 | 880 | # forwarding functions and gradient step updates 881 | 882 | def update(self, unet_number = None): 883 | unet_number = self.validate_unet_number(unet_number) 884 | self.validate_and_set_unet_being_trained(unet_number) 885 | self.set_accelerator_scaler(unet_number) 886 | 887 | index = unet_number - 1 888 | unet = self.unet_being_trained 889 | 890 | optimizer = getattr(self, f'optim{index}') 891 | scaler = getattr(self, f'scaler{index}') 892 | scheduler = getattr(self, f'scheduler{index}') 893 | warmup_scheduler = getattr(self, f'warmup{index}') 894 | 895 | # set the grad scaler on the accelerator, since we are managing one per u-net 896 | 897 | if exists(self.max_grad_norm): 898 | self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm) 899 | 900 | optimizer.step() 901 | optimizer.zero_grad() 902 | 903 | if self.use_ema: 904 | ema_unet = self.get_ema_unet(unet_number) 905 | ema_unet.update() 906 | 907 | # scheduler, if needed 908 | 909 | maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening() 910 | 911 | with maybe_warmup_context: 912 | if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs 913 | scheduler.step() 914 | 915 | self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps)) 916 | 917 | if not exists(self.checkpoint_path): 918 | return 919 | 920 | total_steps = int(self.steps.sum().item()) 921 | 922 | if total_steps % self.checkpoint_every: 923 | return 924 | 925 | self.save_to_checkpoint_folder() 926 | 927 | @torch.no_grad() 928 | @cast_torch_tensor 929 | @imagen_sample_in_chunks 930 | def sample(self, *args, **kwargs): 931 | context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets 932 | 933 | self.print_untrained_unets() 934 | 935 | if not self.is_main: 936 | kwargs['use_tqdm'] = False 937 | 938 | with context(): 939 | output = self.imagen.sample(*args, device = self.device, **kwargs) 940 | 941 | return output 942 | 943 | @partial(cast_torch_tensor, cast_fp16 = True) 944 | def forward( 945 | self, 946 | *args, 947 | unet_number = None, 948 | max_batch_size = None, 949 | **kwargs 950 | ): 951 | unet_number = self.validate_unet_number(unet_number) 952 | self.validate_and_set_unet_being_trained(unet_number) 953 | self.set_accelerator_scaler(unet_number) 954 | 955 | assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}' 956 | 957 | total_loss = 0. 
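# note: gradient accumulation - the incoming batch is split into chunks of at most `max_batch_size` by `split_args_and_kwargs`; each chunk's loss is scaled by the fraction of the batch it represents and backpropagated on its own, so gradients accumulate until `update()` clips them and steps the optimizer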
958 | 959 | for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs): 960 | with self.accelerator.autocast(): 961 | loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs) 962 | loss = loss * chunk_size_frac 963 | 964 | total_loss += loss.item() 965 | 966 | if self.training: 967 | self.accelerator.backward(loss) 968 | 969 | return total_loss 970 | -------------------------------------------------------------------------------- /imagen_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from functools import reduce 4 | from pathlib import Path 5 | 6 | from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig 7 | from ema_pytorch import EMA 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | def safeget(dictionary, keys, default = None): 13 | return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary) 14 | 15 | def load_imagen_from_checkpoint( 16 | checkpoint_path, 17 | load_weights = True, 18 | load_ema_if_available = False 19 | ): 20 | model_path = Path(checkpoint_path) 21 | full_model_path = str(model_path.resolve()) 22 | assert model_path.exists(), f'checkpoint not found at {full_model_path}' 23 | loaded = torch.load(str(model_path), map_location='cpu') 24 | 25 | imagen_params = safeget(loaded, 'imagen_params') 26 | imagen_type = safeget(loaded, 'imagen_type') 27 | 28 | if imagen_type == 'original': 29 | imagen_klass = ImagenConfig 30 | elif imagen_type == 'elucidated': 31 | imagen_klass = ElucidatedImagenConfig 32 | else: 33 | raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen with configurations, using classes ImagenConfig or ElucidatedImagenConfig') 34 | 35 | assert exists(imagen_params) and exists(imagen_type), 'imagen type and configuration not saved in this checkpoint' 36 | 37 | imagen = imagen_klass(**imagen_params).create() 38 | 39 | if not load_weights: 40 | return imagen 41 | 42 | has_ema = 'ema' in loaded 43 | should_load_ema = has_ema and load_ema_if_available 44 | 45 | imagen.load_state_dict(loaded['model']) 46 | 47 | if not should_load_ema: 48 | print('loading non-EMA version of unets') 49 | return imagen 50 | 51 | ema_unets = nn.ModuleList([]) 52 | for unet in imagen.unets: 53 | ema_unets.append(EMA(unet)) 54 | 55 | ema_unets.load_state_dict(loaded['ema']) 56 | 57 | for unet, ema_unet in zip(imagen.unets, ema_unets): 58 | unet.load_state_dict(ema_unet.ema_model.state_dict()) 59 | 60 | print('loaded EMA version of unets') 61 | return imagen 62 | -------------------------------------------------------------------------------- /imagen_pytorch/version.py: -------------------------------------------------------------------------------- 1 | __version__ = '1.11.15' 2 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | exec(open('imagen_pytorch/version.py').read()) 3 | 4 | setup( 5 | name = 'imagen-pytorch', 6 | packages = find_packages(exclude=[]), 7 | include_package_data = True, 8 | entry_points={ 9 | 'console_scripts': [ 10 | 'imagen_pytorch = imagen_pytorch.cli:main', 11 | 'imagen = imagen_pytorch.cli:imagen' 12 | ], 13 | }, 14 | version = __version__, 15 | license='MIT', 16 | 
description = 'Imagen - unprecedented photorealism × deep level of language understanding', 17 | author = 'Phil Wang', 18 | author_email = 'lucidrains@gmail.com', 19 | long_description_content_type = 'text/markdown', 20 | url = 'https://github.com/lucidrains/imagen-pytorch', 21 | keywords = [ 22 | 'artificial intelligence', 23 | 'deep learning', 24 | 'transformers', 25 | 'text-to-image', 26 | 'denoising-diffusion' 27 | ], 28 | install_requires=[ 29 | 'accelerate', 30 | 'click', 31 | 'einops>=0.4', 32 | 'einops-exts', 33 | 'ema-pytorch>=0.0.3', 34 | 'fsspec', 35 | 'kornia', 36 | 'numpy', 37 | 'packaging', 38 | 'pillow', 39 | 'pydantic', 40 | 'pytorch-lightning', 41 | 'pytorch-warmup', 42 | 'sentencepiece', 43 | 'torch>=1.6', 44 | 'torchvision', 45 | 'transformers', 46 | 'tqdm' 47 | ], 48 | classifiers=[ 49 | 'Development Status :: 4 - Beta', 50 | 'Intended Audience :: Developers', 51 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 52 | 'License :: OSI Approved :: MIT License', 53 | 'Programming Language :: Python :: 3.6', 54 | ], 55 | ) 56 | --------------------------------------------------------------------------------
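For orientation, here is a minimal sketch of how the `ImagenTrainer` API above might be driven. It only uses methods visible in `imagen_pytorch/trainer.py`; the names `imagen` (an already-constructed `Imagen` or `ElucidatedImagen` instance) and `dataset` (a torch `Dataset` whose items line up with `dl_tuple_output_keywords_names`) are assumptions, not part of this repository's files.

# sketch only - `imagen` and `dataset` are assumed to be built elsewhere
from imagen_pytorch.trainer import ImagenTrainer

trainer = ImagenTrainer(
    imagen = imagen,                     # Imagen or ElucidatedImagen instance
    lr = 1e-4,
    use_ema = True,                      # keep exponential moving average copies of the unets
    split_valid_from_train = True,       # carve a validation split off the training dataset
    checkpoint_path = './checkpoints',   # together with checkpoint_every, enables periodic checkpointing
    checkpoint_every = 1000
)

# dataset items are matched positionally to dl_tuple_output_keywords_names
# ('images', 'text_embeds', 'text_masks', 'cond_images')
trainer.add_train_dataset(dataset, batch_size = 16)

for step in range(100000):
    # max_batch_size triggers the gradient accumulation path in forward()
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)

    if step % 500 == 0:
        valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)

trainer.save('./trained-unet1.pt')

Note that the trainer restricts training to a single unet per process (see `validate_and_set_unet_being_trained` and the `ImagenTrainer.locked` assertion), so the other unets in the cascade are trained by saving a checkpoint and restarting with a different `unet_number`.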