├── .github
│   ├── FUNDING.yml
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── imagen.png
├── imagen_pytorch
│   ├── __init__.py
│   ├── cli.py
│   ├── configs.py
│   ├── data.py
│   ├── elucidated_imagen.py
│   ├── imagen_pytorch.py
│   ├── imagen_video
│   │   ├── __init__.py
│   │   └── imagen_video.py
│   ├── t5.py
│   ├── trainer.py
│   ├── utils.py
│   └── version.py
└── setup.py
/.github/FUNDING.yml:
--------------------------------------------------------------------------------
1 | # These are supported funding model platforms
2 |
3 | github: [lucidrains]
4 |
--------------------------------------------------------------------------------
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 | release:
15 | types: [published]
16 |
17 | jobs:
18 | deploy:
19 |
20 | runs-on: ubuntu-latest
21 |
22 | steps:
23 | - uses: actions/checkout@v2
24 | - name: Set up Python
25 | uses: actions/setup-python@v2
26 | with:
27 | python-version: '3.x'
28 | - name: Install dependencies
29 | run: |
30 | python -m pip install --upgrade pip
31 | pip install build
32 | - name: Build package
33 | run: python -m build
34 | - name: Publish package
35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 | with:
37 | user: __token__
38 | password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Imagen - Pytorch
4 |
5 | Implementation of Imagen, Google's Text-to-Image Neural Network that beats DALL-E2, in Pytorch. It is the new SOTA for text-to-image synthesis.
6 |
7 | Architecturally, it is actually much simpler than DALL-E2. It consists of a cascading DDPM conditioned on text embeddings from a large pretrained T5 model (attention network). It also contains dynamic clipping for improved classifier free guidance, noise level conditioning, and a memory efficient unet design.
8 |
9 | It appears neither CLIP nor a prior network is needed after all. And so research continues.
10 |
11 | AI Coffee Break with Letitia | Assembly AI | Yannic Kilcher
12 |
13 | Please join the LAION community Discord if you are interested in helping out with the replication
14 |
15 | ## Shoutouts
16 |
17 | - StabilityAI for the generous sponsorship, as well as my other sponsors out there
18 |
19 | - 🤗 Huggingface for their amazing transformers library. The text encoder portion is pretty much taken care of because of them
20 |
21 | - Sylvain and Zachary for the Accelerate library, which this repository uses for distributed training
22 |
23 | - Alex for einops, an indispensable tool for tensor manipulation
24 |
25 | - Jorge Gomes for helping out with the T5 loading code and advice on the correct T5 version
26 |
27 | - Katherine Crowson, for her beautiful code, which helped me understand the continuous time version of gaussian diffusion
28 |
29 | - Marunine and Netruk44, for reviewing code, sharing experimental results, and help with debugging
30 |
31 | - Marunine for providing a potential solution for a color shifting issue in the memory efficient u-nets. Thanks to Jacob for sharing experimental comparisons between the base and memory-efficient unets
32 |
33 | - Marunine for finding numerous bugs, resolving an issue with resize right, and for sharing his experimental configurations and results
34 |
35 | - MalumaDev for proposing the use of a pixel shuffle upsampler to fix checkerboard artifacts
36 |
37 | - Valentin for pointing out insufficient skip connections in the unet, as well as the specific method of attention conditioning in the base-unet in the appendix
38 |
39 | - BIGJUN for catching a big bug with continuous time gaussian diffusion noise level conditioning at inference time
40 |
41 | - Bingbing for identifying a bug with sampling and the order of normalizing and noising of the low resolution conditioning image
42 |
43 | ## Install
44 |
45 | ```bash
46 | $ pip install imagen-pytorch
47 | ```
48 |
49 | ## Usage
50 |
51 | ```python
52 | import torch
53 | from imagen_pytorch import Unet, Imagen
54 |
55 | # unet for imagen
56 |
57 | unet1 = Unet(
58 | dim = 32,
59 | cond_dim = 512,
60 | dim_mults = (1, 2, 4, 8),
61 | num_resnet_blocks = 3,
62 | layer_attns = (False, True, True, True),
63 | layer_cross_attns = (False, True, True, True)
64 | )
65 |
66 | unet2 = Unet(
67 | dim = 32,
68 | cond_dim = 512,
69 | dim_mults = (1, 2, 4, 8),
70 | num_resnet_blocks = (2, 4, 8, 8),
71 | layer_attns = (False, False, False, True),
72 | layer_cross_attns = (False, False, False, True)
73 | )
74 |
75 | # imagen, which contains the unets above (base unet and super resoluting ones)
76 |
77 | imagen = Imagen(
78 | unets = (unet1, unet2),
79 | image_sizes = (64, 256),
80 | timesteps = 1000,
81 | cond_drop_prob = 0.1
82 | ).cuda()
83 |
84 | # mock images (get a lot of this) and text encodings from large T5
85 |
86 | text_embeds = torch.randn(4, 256, 768).cuda()
87 | images = torch.randn(4, 3, 256, 256).cuda()
88 |
89 | # feed images into imagen, training each unet in the cascade
90 |
91 | for i in (1, 2):
92 | loss = imagen(images, text_embeds = text_embeds, unet_number = i)
93 | loss.backward()
94 |
95 | # do the above for many many many many steps
96 | # now you can sample an image based on the text embeddings from the cascading ddpm
97 |
98 | images = imagen.sample(texts = [
99 | 'a whale breaching from afar',
100 | 'young girl blowing out candles on her birthday cake',
101 | 'fireworks with blue and green sparkles'
102 | ], cond_scale = 3.)
103 |
104 | images.shape # (3, 3, 256, 256)
105 | ```
106 |
107 | For simpler training, you can directly supply text strings instead of precomputing text encodings. (Although for scaling purposes, you will definitely want to precompute the textual embeddings + mask)
108 |
109 | The number of textual captions must match the batch size of the images if you go this route.
110 |
111 | ```python
112 | # mock images and text (get a lot of this)
113 |
114 | texts = [
115 | 'a child screaming at finding a worm within a half-eaten apple',
116 | 'lizard running across the desert on two feet',
117 | 'waking up to a psychedelic landscape',
118 | 'seashells sparkling in the shallow waters'
119 | ]
120 |
121 | images = torch.randn(4, 3, 256, 256).cuda()
122 |
123 | # feed images into imagen, training each unet in the cascade
124 |
125 | for i in (1, 2):
126 | loss = imagen(images, texts = texts, unet_number = i)
127 | loss.backward()
128 | ```
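If you do choose to precompute the embeddings mentioned above, one option is the `t5_encode_text` helper in `imagen_pytorch.t5`, the same function this repository uses internally to encode text. Below is a minimal sketch, assuming the `t5-large` encoder used in the trainer example further down; the exact return values (for instance whether an attention mask is also produced) may vary between versions.

```python
import torch
from imagen_pytorch.t5 import t5_encode_text

texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet'
]

# encode once with the same T5 your Imagen instance is configured with
text_embeds = t5_encode_text(texts, name = 't5-large')  # (batch, seq, dim)

# cache to disk, then at training time feed the tensors back in as `text_embeds`
torch.save(text_embeds, './text_embeds.pt')
```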
129 |
130 | With the `ImagenTrainer` wrapper class, the exponential moving averages for all of the U-nets in the cascading DDPM will be automatically taken care of when calling `update`
131 |
132 | ```python
133 | import torch
134 | from imagen_pytorch import Unet, Imagen, ImagenTrainer
135 |
136 | # unet for imagen
137 |
138 | unet1 = Unet(
139 | dim = 32,
140 | cond_dim = 512,
141 | dim_mults = (1, 2, 4, 8),
142 | num_resnet_blocks = 3,
143 | layer_attns = (False, True, True, True),
144 | )
145 |
146 | unet2 = Unet(
147 | dim = 32,
148 | cond_dim = 512,
149 | dim_mults = (1, 2, 4, 8),
150 | num_resnet_blocks = (2, 4, 8, 8),
151 | layer_attns = (False, False, False, True),
152 | layer_cross_attns = (False, False, False, True)
153 | )
154 |
155 | # imagen, which contains the unets above (base unet and super resoluting ones)
156 |
157 | imagen = Imagen(
158 | unets = (unet1, unet2),
159 | text_encoder_name = 't5-large',
160 | image_sizes = (64, 256),
161 | timesteps = 1000,
162 | cond_drop_prob = 0.1
163 | ).cuda()
164 |
165 | # wrap imagen with the trainer class
166 |
167 | trainer = ImagenTrainer(imagen)
168 |
169 | # mock images (get a lot of this) and text encodings from large T5
170 |
171 | text_embeds = torch.randn(64, 256, 1024).cuda()
172 | images = torch.randn(64, 3, 256, 256).cuda()
173 |
174 | # feed images into imagen, training each unet in the cascade
175 |
176 | loss = trainer(
177 | images,
178 | text_embeds = text_embeds,
179 | unet_number = 1, # training on unet number 1 in this example, but you will have to also save checkpoints and then reload and continue training on unet number 2
180 | max_batch_size = 4 # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
181 | )
182 |
183 | trainer.update(unet_number = 1)
184 |
185 | # do the above for many many many many steps
186 | # now you can sample an image based on the text embeddings from the cascading ddpm
187 |
188 | images = trainer.sample(texts = [
189 | 'a puppy looking anxiously at a giant donut on the table',
190 | 'the milky way galaxy in the style of monet'
191 | ], cond_scale = 3.)
192 |
193 | images.shape # (2, 3, 256, 256)
194 | ```
195 |
196 | You can also train Imagen without text (unconditional image generation) as follows
197 |
198 | ```python
199 | import torch
200 | from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer
201 |
202 | # unets for unconditional imagen
203 |
204 | unet1 = Unet(
205 | dim = 32,
206 | dim_mults = (1, 2, 4),
207 | num_resnet_blocks = 3,
208 | layer_attns = (False, True, True),
209 | layer_cross_attns = False,
210 | use_linear_attn = True
211 | )
212 |
213 | unet2 = SRUnet256(
214 | dim = 32,
215 | dim_mults = (1, 2, 4),
216 | num_resnet_blocks = (2, 4, 8),
217 | layer_attns = (False, False, True),
218 | layer_cross_attns = False
219 | )
220 |
221 | # imagen, which contains the unets above (base unet and super resoluting ones)
222 |
223 | imagen = Imagen(
224 | condition_on_text = False, # this must be set to False for unconditional Imagen
225 | unets = (unet1, unet2),
226 | image_sizes = (64, 128),
227 | timesteps = 1000
228 | )
229 |
230 | trainer = ImagenTrainer(imagen).cuda()
231 |
232 | # now get a ton of images and feed them through the Imagen trainer
233 |
234 | training_images = torch.randn(4, 3, 256, 256).cuda()
235 |
236 | # train each unet separately
237 | # in this example, only training on unet number 1
238 |
239 | loss = trainer(training_images, unet_number = 1)
240 | trainer.update(unet_number = 1)
241 |
242 | # do the above for many many many many steps
243 | # now you can sample images unconditionally from the cascading unet(s)
244 |
245 | images = trainer.sample(batch_size = 16) # (16, 3, 128, 128)
246 | ```
247 |
248 | Or train only super-resoluting unets
249 |
250 | ```python
251 | import torch
252 | from imagen_pytorch import Unet, NullUnet, Imagen
253 |
254 | # unet for imagen
255 |
256 | unet1 = NullUnet() # add a placeholder "null" unet for the base unet
257 |
258 | unet2 = Unet(
259 | dim = 32,
260 | cond_dim = 512,
261 | dim_mults = (1, 2, 4, 8),
262 | num_resnet_blocks = (2, 4, 8, 8),
263 | layer_attns = (False, False, False, True),
264 | layer_cross_attns = (False, False, False, True)
265 | )
266 |
267 | # imagen, which contains the unets above (base unet and super resoluting ones)
268 |
269 | imagen = Imagen(
270 | unets = (unet1, unet2),
271 | image_sizes = (64, 256),
272 | timesteps = 250,
273 | cond_drop_prob = 0.1
274 | ).cuda()
275 |
276 | # mock images (get a lot of this) and text encodings from large T5
277 |
278 | text_embeds = torch.randn(4, 256, 768).cuda()
279 | images = torch.randn(4, 3, 256, 256).cuda()
280 |
281 | # feed images into imagen, training each unet in the cascade
282 |
283 | loss = imagen(images, text_embeds = text_embeds, unet_number = 2)
284 | loss.backward()
285 |
286 | # do the above for many many many many steps
287 | # now you can sample an image based on the text embeddings as well as low resolution images
288 |
289 | lowres_images = torch.randn(3, 3, 64, 64).cuda() # starting un-resoluted images
290 |
291 | images = imagen.sample(
292 | texts = [
293 | 'a whale breaching from afar',
294 | 'young girl blowing out candles on her birthday cake',
295 | 'fireworks with blue and green sparkles'
296 | ],
297 | start_at_unet_number = 2, # start at unet number 2
298 | start_image_or_video = lowres_images, # pass in low resolution images to be resoluted
299 | cond_scale = 3.)
300 |
301 | images.shape # (3, 3, 256, 256)
302 | ```
303 |
304 | At any time you can save and load the trainer and all associated states with the `save` and `load` methods. It is recommended you use these methods instead of manually saving with a `state_dict` call, as some device memory management is done under the hood within the trainer.
305 |
306 | ex.
307 |
308 | ```python
309 | trainer.save('./path/to/checkpoint.pt')
310 |
311 | trainer.load('./path/to/checkpoint.pt')
312 |
313 | trainer.steps # (2,) tensor with the training step count for each of the two unets in this example
314 | ```
315 |
316 | ## Dataloader
317 |
318 | You can also rely on the `ImagenTrainer` to automatically train off `DataLoader` instances. You simply have to craft your `DataLoader` to return either `images` (for the unconditional case), or tuples of `('images', 'text_embeds')` for text-guided generation (a text-conditioned sketch follows the example below).
319 |
320 | ex. unconditional training
321 |
322 | ```python
323 | from imagen_pytorch import Unet, Imagen, ImagenTrainer
324 | from imagen_pytorch.data import Dataset
325 |
326 | # unets for unconditional imagen
327 |
328 | unet = Unet(
329 | dim = 32,
330 | dim_mults = (1, 2, 4, 8),
331 | num_resnet_blocks = 1,
332 | layer_attns = (False, False, False, True),
333 | layer_cross_attns = False
334 | )
335 |
336 | # imagen, which contains the unet above
337 |
338 | imagen = Imagen(
339 | condition_on_text = False, # this must be set to False for unconditional Imagen
340 | unets = unet,
341 | image_sizes = 128,
342 | timesteps = 1000
343 | )
344 |
345 | trainer = ImagenTrainer(
346 | imagen = imagen,
347 | split_valid_from_train = True # whether to split the validation dataset from the training
348 | ).cuda()
349 |
350 | # instantiate your dataloader, which returns the necessary inputs to the DDPM as a tuple in the order of images, text embeddings, then text masks. in this case, only images are returned as it is unconditional training
351 |
352 | dataset = Dataset('/path/to/training/images', image_size = 128)
353 |
354 | trainer.add_train_dataset(dataset, batch_size = 16)
355 |
356 | # working training loop
357 |
358 | for i in range(200000):
359 | loss = trainer.train_step(unet_number = 1, max_batch_size = 4)
360 | print(f'loss: {loss}')
361 |
362 | if not (i % 50):
363 | valid_loss = trainer.valid_step(unet_number = 1, max_batch_size = 4)
364 | print(f'valid loss: {valid_loss}')
365 |
366 | if not (i % 100) and trainer.is_main: # is_main makes sure this can run in distributed
367 | images = trainer.sample(batch_size = 1, return_pil_images = True) # returns List[Image]
368 | images[0].save(f'./sample-{i // 100}.png')
369 |
370 | ```
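For text-conditioned training, the dataset simply needs to yield `(image, text_embed)` tuples, as described above. Here is a minimal sketch assuming you have precomputed one T5 embedding tensor per image; the folder layout and the `TextImageDataset` name are purely illustrative.

```python
import torch
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms as T

class TextImageDataset(Dataset):
    # hypothetical layout: images in one folder, a matching `<name>.pt` T5 embedding per image in another
    def __init__(self, image_folder, embed_folder, image_size = 128):
        super().__init__()
        self.image_paths = sorted(Path(image_folder).glob('*.png'))
        self.embed_folder = Path(embed_folder)

        self.transform = T.Compose([
            T.Resize(image_size),
            T.CenterCrop(image_size),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        path = self.image_paths[index]
        image = self.transform(Image.open(path).convert('RGB'))
        text_embed = torch.load(self.embed_folder / f'{path.stem}.pt')  # (seq, dim) - pad to a common length beforehand so default collation works
        return image, text_embed

# then hand it to the trainer exactly as in the unconditional example
# trainer.add_train_dataset(TextImageDataset('/path/to/images', '/path/to/embeds'), batch_size = 16)
```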
371 |
372 | ## Multi GPU
373 |
374 | Thanks to 🤗 Accelerate, you can do multi GPU training easily in two steps.
375 |
376 | First you need to invoke `accelerate config` in the same directory as your training script (say it is named `train.py`)
377 |
378 | ```bash
379 | $ accelerate config
380 | ```
381 |
382 | Next, instead of calling `python train.py` as you would for single GPU, you would use the accelerate CLI like so
383 |
384 | ```bash
385 | $ accelerate launch train.py
386 | ```
387 |
388 | That's it!
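For reference, a minimal `train.py` combining the unconditional pieces from the Dataloader section above might look like the following. This is only a sketch; the dataset path, batch sizes and checkpoint path are placeholders.

```python
# train.py - a minimal sketch to be launched with `accelerate launch train.py`
from imagen_pytorch import Unet, Imagen, ImagenTrainer
from imagen_pytorch.data import Dataset

unet = Unet(
    dim = 32,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 1,
    layer_attns = (False, False, False, True),
    layer_cross_attns = False
)

imagen = Imagen(
    condition_on_text = False,  # unconditional, as in the dataloader example above
    unets = unet,
    image_sizes = 128,
    timesteps = 1000
)

trainer = ImagenTrainer(imagen).cuda()  # the trainer handles distribution via 🤗 Accelerate

dataset = Dataset('/path/to/training/images', image_size = 128)
trainer.add_train_dataset(dataset, batch_size = 16)

for i in range(200000):
    loss = trainer.train_step(unet_number = 1, max_batch_size = 4)

    if not (i % 1000):
        trainer.save('./checkpoint.pt')  # save periodically
```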
389 |
390 | ## Command-line
391 |
392 | To further democratize the use of this machine imagination, I have built in the ability to generate an image from any text prompt with a single command, like so
393 |
394 | ex.
395 |
396 | ```bash
397 | $ imagen --model ./path/to/model/checkpoint.pt "a squirrel raiding the birdfeeder"
398 | # image is saved to ./a_squirrel_raiding_the_birdfeeder.png
399 | ```
400 |
401 | In order to save checkpoints that can make use of this feature, you must instantiate your Imagen instance using the config classes, `ImagenConfig` and `ElucidatedImagenConfig`
402 |
403 | For proper training, you'll likely want to set up config-driven training anyway.
404 |
405 | ex.
406 |
407 | ```python
408 | import torch
409 | from imagen_pytorch import ImagenConfig, ElucidatedImagenConfig, ImagenTrainer
410 |
411 | # in this example, using elucidated imagen
412 |
413 | imagen = ElucidatedImagenConfig(
414 | unets = [
415 | dict(dim = 32, dim_mults = (1, 2, 4, 8)),
416 | dict(dim = 32, dim_mults = (1, 2, 4, 8))
417 | ],
418 | image_sizes = (64, 128),
419 | cond_drop_prob = 0.5,
420 | num_sample_steps = 32
421 | ).create()
422 |
423 | trainer = ImagenTrainer(imagen)
424 |
425 | # do your training ...
426 |
427 | # then save it
428 |
429 | trainer.save('./checkpoint.pt')
430 |
431 | # you should see a message informing you that ./checkpoint.pt is commandable from the terminal
432 | ```
433 |
434 | It really should be as simple as that
435 |
436 | You can also pass this checkpoint file around, and anyone can continue finetuning on their own data
437 |
438 | ```python
439 | from imagen_pytorch import load_imagen_from_checkpoint, ImagenTrainer
440 |
441 | imagen = load_imagen_from_checkpoint('./checkpoint.pt')
442 |
443 | trainer = ImagenTrainer(imagen)
444 |
445 | # continue training / fine-tuning
446 | ```
447 |
448 | ## Inpainting
449 |
450 | Inpainting follows the formulation laid out by the recent RePaint paper. Simply pass in `inpaint_images` and `inpaint_masks` to the `sample` function on either `Imagen` or `ElucidatedImagen`
451 |
452 | ```python
453 |
454 | inpaint_images = torch.randn(4, 3, 512, 512).cuda() # (batch, channels, height, width)
455 | inpaint_masks = torch.ones((4, 512, 512)).bool().cuda() # (batch, height, width)
456 |
457 | inpainted_images = trainer.sample(texts = [
458 | 'a whale breaching from afar',
459 | 'young girl blowing out candles on her birthday cake',
460 | 'fireworks with blue and green sparkles',
461 | 'dust motes swirling in the morning sunshine on the windowsill'
462 | ], inpaint_images = inpaint_images, inpaint_masks = inpaint_masks, cond_scale = 5.)
463 |
464 | inpainted_images # (4, 3, 512, 512)
465 | ```
466 |
467 | ## Experimental
468 |
469 | Tero Karras of StyleGAN fame has written a new paper with results that have been corroborated by a number of independent researchers as well as on my own machine. I have decided to create a version of `Imagen`, the `ElucidatedImagen`, so that one can use the new elucidated DDPM for text-guided cascading generation.
470 |
471 | Simply import `ElucidatedImagen`, and then instantiate the instance as you did before. The hyperparameters are different than the usual ones for discrete and continuous time gaussian diffusion, and can be individualized for each unet in the cascade.
472 |
473 | Ex.
474 |
475 | ```python
476 | from imagen_pytorch import ElucidatedImagen
477 |
478 | # instantiate your unets ...
479 |
480 | imagen = ElucidatedImagen(
481 | unets = (unet1, unet2),
482 | image_sizes = (64, 128),
483 | cond_drop_prob = 0.1,
484 | num_sample_steps = (64, 32), # number of sample steps - 64 for base unet, 32 for upsampler (just an example, have no clue what the optimal values are)
485 | sigma_min = 0.002, # min noise level
486 | sigma_max = (80, 160), # max noise level, @crowsonkb recommends double the max noise level for upsampler
487 | sigma_data = 0.5, # standard deviation of data distribution
488 | rho = 7, # controls the sampling schedule
489 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
490 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
491 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
492 | S_tmin = 0.05,
493 | S_tmax = 50,
494 | S_noise = 1.003,
495 | ).cuda()
496 |
497 | # rest is the same as above
498 |
499 | ```
500 |
501 | ## Text to Video (ongoing research)
502 |
503 | This repository will also start accumulating new research around text-guided video synthesis. For starters, it will adopt the 3d unet architecture described by Jonathan Ho in Video Diffusion Models
504 |
505 | Ex.
506 |
507 | ```python
508 | import torch
509 | from imagen_pytorch import Unet3D, ElucidatedImagen, ImagenTrainer
510 |
511 | unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()
512 |
513 | unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8)).cuda()
514 |
515 | # elucidated imagen, which contains the unets above (base unet and super resoluting ones)
516 |
517 | imagen = ElucidatedImagen(
518 | unets = (unet1, unet2),
519 | image_sizes = (16, 32),
520 | random_crop_sizes = (None, 16),
521 | num_sample_steps = 10,
522 | cond_drop_prob = 0.1,
523 | sigma_min = 0.002, # min noise level
524 | sigma_max = (80, 160), # max noise level, double the max noise level for upsampler
525 | sigma_data = 0.5, # standard deviation of data distribution
526 | rho = 7, # controls the sampling schedule
527 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
528 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
529 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
530 | S_tmin = 0.05,
531 | S_tmax = 50,
532 | S_noise = 1.003,
533 | ).cuda()
534 |
535 | # mock videos (get a lot of this) and text encodings from large T5
536 |
537 | texts = [
538 | 'a whale breaching from afar',
539 | 'young girl blowing out candles on her birthday cake',
540 | 'fireworks with blue and green sparkles',
541 | 'dust motes swirling in the morning sunshine on the windowsill'
542 | ]
543 |
544 | videos = torch.randn(4, 3, 10, 32, 32).cuda() # (batch, channels, time / video frames, height, width)
545 |
546 | # feed videos into imagen, training each unet in the cascade
547 | # for this example, only training unet 1
548 |
549 | trainer = ImagenTrainer(imagen)
550 | trainer(videos, texts = texts, unet_number = 1)
551 | trainer.update(unet_number = 1)
552 |
553 | videos = trainer.sample(texts = texts, video_frames = 20) # extrapolating to 20 frames from training on 10 frames
554 |
555 | videos.shape # (4, 3, 20, 32, 32)
556 |
557 | ```
558 |
559 | ## FAQ
560 |
561 | - Why are my generated images not aligning well with the text?
562 |
563 | Imagen uses an algorithm called Classifier Free Guidance. When sampling, you apply a scale greater than `1.0` to the conditioning (text in this case).
564 |
565 | Researcher Netruk44 has reported `5-10` to be optimal, but anything greater than `10` to break.
566 |
567 | ```python
568 | trainer.sample(texts = [
569 | 'a cloud in the shape of a roman gladiator'
570 | ], cond_scale = 5.) # <-- cond_scale is the conditioning scale, needs to be greater than 1.0 to be better than average
571 | ```
572 |
573 | - Are there any pretrained models yet?
574 |
575 | Not at the moment, but one will likely be trained and open sourced within the year, if not sooner. If you would like to participate, you can join the community of artificial neural network trainers at LAION (Discord link is in the Readme above) and start collaborating.
576 |
577 | - Will this technology take my job?
578 |
579 | All the more reason why you should start training your own model, starting today! The last thing we need is this technology being in the hands of an elite few. Hopefully this repository reduces the work to just finding the necessary compute, and augmenting it with your own curated dataset.
580 |
581 | - What am I allowed to do with this repository?
582 |
583 | Anything! It is MIT licensed. In other words, you can freely copy / paste it for your own research, remixing it for whatever modality you can think of. Go train amazing models for profit, for science, or simply to satiate your own personal pleasure at witnessing something divine unravel in front of you.
584 |
585 | ## Related Works
586 |
587 | - Audio diffusion from Flavio Schneider
588 |
589 | - Mini Imagen from Ryan O. | AssemblyAI writeup
590 |
591 | ## Todo
592 |
593 | - [x] use huggingface transformers for T5-small text embeddings
594 | - [x] add dynamic thresholding
595 | - [x] add dynamic thresholding to the DALLE2 and video-diffusion repositories as well
596 | - [x] allow for one to set T5-large (and perhaps small factory method to take in any huggingface transformer)
597 | - [x] add the lowres noise level with the pseudocode in appendix, and figure out what is this sweep they do at inference time
598 | - [x] port over some training code from DALLE2
599 | - [x] need to be able to use a different noise schedule per unet (cosine was used for base, but linear for SR)
600 | - [x] just make one master-configurable unet
601 | - [x] complete resnet block (biggan inspired? but with groupnorm) - complete self attention
602 | - [x] complete conditioning embedding block (and make it completely configurable, whether it be attention, film etc)
603 | - [x] consider using perceiver-resampler from https://github.com/lucidrains/flamingo-pytorch in place of attention pooling
604 | - [x] add attention pooling option, in addition to cross attention and film
605 | - [x] add optional cosine decay schedule with warmup, for each unet, to trainer
606 | - [x] switch to continuous timesteps instead of discretized, as it seems that is what they used for all stages - first figure out the linear noise schedule case from the variational ddpm paper https://openreview.net/forum?id=2LdBqxc1Yv
607 | - [x] figure out log(snr) for alpha cosine noise schedule.
608 | - [x] suppress the transformers warning because only T5encoder is used
609 | - [x] allow setting for using linear attention on layers where full attention cannot be used
610 | - [x] force unets in continuous time case to use non-fouriered conditions (just pass the log(snr) through an MLP with optional layernorms), as that is what i have working locally
611 | - [x] removed learned variance
612 | - [x] add p2 loss weighting for continuous time
613 | - [x] make sure cascading ddpm can be trained without text condition, and make sure both continuous and discrete time gaussian diffusion works
614 | - [x] use primer's depthwise convs on the qkv projections in linear attention (or use token shifting before projections) - also use new dropout proposed by bayesformer, as it seems to work well with linear attention
615 | - [x] explore skip layer excitation in unet decoder
616 | - [x] accelerate integration
617 | - [x] build out CLI tool and one-line generation of image
618 | - [x] knock out any issues that arose from accelerate
619 | - [x] add inpainting ability using resampler from repaint paper https://arxiv.org/abs/2201.09865
620 | - [x] build a simple checkpointing system, backed by a folder
621 | - [x] add skip connection from outputs of all upsample blocks, used in unet squared paper and some previous unet works
622 | - [x] add fsspec, recommended by Romain @rom1504, for cloud / local file system agnostic persistence of checkpoints
623 | - [x] test out persistence in gcs with https://github.com/fsspec/gcsfs
624 | - [x] extend to video generation, using axial time attention as in Ho's video ddpm paper
625 | - [x] allow elucidated imagen to generalize to any shape
626 | - [x] allow for imagen to generalize to any shape
627 | - [x] add dynamic positional bias for the best type of length extrapolation across video time
628 | - [x] move video frames to sample function, as we will be attempting time extrapolation
629 | - [x] attention bias to null key / values should be a learned scalar of head dimension
630 | - [x] add self-conditioning from bit diffusion paper, already coded up at ddpm-pytorch
631 | - [ ] reread cogvideo and figure out how frame rate conditioning could be used
632 | - [ ] bring in attention expertise for self attention layers in unet3d
633 | - [ ] consider bringing in NUWA's 3d convolutional attention
634 | - [ ] consider transformer-xl memories in the temporal attention blocks
635 | - [ ] consider perceiver-ar approach to attending to past time
636 | - [ ] frame dropouts during attention for achieving both regularizing effect as well as shortened training time
637 | - [ ] investigate frank wood's claims https://github.com/lucidrains/flexible-diffusion-modeling-videos-pytorch and either add the hierarchical sampling technique, or let people know about its deficiencies
638 | - [ ] make sure inpainting works with video
639 | - [ ] offer challenging moving mnist (with distractor objects) as a one-line trainable baseline for researchers to branch off of for text to video
640 | - [ ] build out CLI tool for training, resuming training off config file
641 | - [ ] preencoding of text to memmapped embeddings
642 | - [ ] be able to create dataloader iterators based on the old epoch style, also configure shuffling etc
643 | - [ ] be able to also pass in arguments (instead of requiring forward to be all keyword args on model)
644 | - [ ] bring in reversible blocks from revnets for 3d unet, to lessen memory burden
645 | - [ ] add ability to only train super-resolution network
646 | - [ ] read dpm-solver see if it is applicable to continuous time gaussian diffusion
647 | - [ ] allow for conditioning video frames with arbitrary absolute times (calculate RPE during temporal attention)
648 | - [ ] accommodate dream booth fine tuning
649 | - [ ] add textual inversion
650 | - [ ] cleanup self conditioning to be extracted at imagen instantiation
651 | - [ ] incorporate all learnings from make-a-video (https://makeavideo.studio/)
652 | - [ ] add v-parameterization (https://arxiv.org/abs/2202.00512) from imagen video paper, the only thing new
653 |
654 | ## Citations
655 |
656 | ```bibtex
657 | @inproceedings{Saharia2022PhotorealisticTD,
658 | title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
659 | author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily L. Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and Seyedeh Sara Mahdavi and Raphael Gontijo Lopes and Tim Salimans and Jonathan Ho and David Fleet and Mohammad Norouzi},
660 | year = {2022}
661 | }
662 | ```
663 |
664 | ```bibtex
665 | @article{Alayrac2022Flamingo,
666 | title = {Flamingo: a Visual Language Model for Few-Shot Learning},
667 | author = {Jean-Baptiste Alayrac et al},
668 | year = {2022}
669 | }
670 | ```
671 |
672 | ```bibtex
673 | @article{Choi2022PerceptionPT,
674 | title = {Perception Prioritized Training of Diffusion Models},
675 | author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
676 | journal = {ArXiv},
677 | year = {2022},
678 | volume = {abs/2204.00227}
679 | }
680 | ```
681 |
682 | ```bibtex
683 | @inproceedings{Sankararaman2022BayesFormerTW,
684 | title = {BayesFormer: Transformer with Uncertainty Estimation},
685 | author = {Karthik Abinav Sankararaman and Sinong Wang and Han Fang},
686 | year = {2022}
687 | }
688 | ```
689 |
690 | ```bibtex
691 | @article{So2021PrimerSF,
692 | title = {Primer: Searching for Efficient Transformers for Language Modeling},
693 | author = {David R. So and Wojciech Ma'nke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. Le},
694 | journal = {ArXiv},
695 | year = {2021},
696 | volume = {abs/2109.08668}
697 | }
698 | ```
699 |
700 | ```bibtex
701 | @misc{cao2020global,
702 | title = {Global Context Networks},
703 | author = {Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
704 | year = {2020},
705 | eprint = {2012.13375},
706 | archivePrefix = {arXiv},
707 | primaryClass = {cs.CV}
708 | }
709 | ```
710 |
711 | ```bibtex
712 | @article{Karras2022ElucidatingTD,
713 | title = {Elucidating the Design Space of Diffusion-Based Generative Models},
714 | author = {Tero Karras and Miika Aittala and Timo Aila and Samuli Laine},
715 | journal = {ArXiv},
716 | year = {2022},
717 | volume = {abs/2206.00364}
718 | }
719 | ```
720 |
721 | ```bibtex
722 | @inproceedings{NEURIPS2020_4c5bcfec,
723 | author = {Ho, Jonathan and Jain, Ajay and Abbeel, Pieter},
724 | booktitle = {Advances in Neural Information Processing Systems},
725 | editor = {H. Larochelle and M. Ranzato and R. Hadsell and M.F. Balcan and H. Lin},
726 | pages = {6840--6851},
727 | publisher = {Curran Associates, Inc.},
728 | title = {Denoising Diffusion Probabilistic Models},
729 | url = {https://proceedings.neurips.cc/paper/2020/file/4c5bcfec8584af0d967f1ab10179ca4b-Paper.pdf},
730 | volume = {33},
731 | year = {2020}
732 | }
733 | ```
734 |
735 | ```bibtex
736 | @article{Lugmayr2022RePaintIU,
737 | title = {RePaint: Inpainting using Denoising Diffusion Probabilistic Models},
738 | author = {Andreas Lugmayr and Martin Danelljan and Andr{\'e}s Romero and Fisher Yu and Radu Timofte and Luc Van Gool},
739 | journal = {ArXiv},
740 | year = {2022},
741 | volume = {abs/2201.09865}
742 | }
743 | ```
744 |
745 | ```bibtex
746 | @misc{ho2022video,
747 | title = {Video Diffusion Models},
748 | author = {Jonathan Ho and Tim Salimans and Alexey Gritsenko and William Chan and Mohammad Norouzi and David J. Fleet},
749 | year = {2022},
750 | eprint = {2204.03458},
751 | archivePrefix = {arXiv},
752 | primaryClass = {cs.CV}
753 | }
754 | ```
755 |
756 | ```bibtex
757 | @inproceedings{rogozhnikov2022einops,
758 | title = {Einops: Clear and Reliable Tensor Manipulations with Einstein-like Notation},
759 | author = {Alex Rogozhnikov},
760 | booktitle = {International Conference on Learning Representations},
761 | year = {2022},
762 | url = {https://openreview.net/forum?id=oapKSVM2bcj}
763 | }
764 | ```
765 |
766 | ```bibtex
767 | @misc{chen2022analog,
768 | title = {Analog Bits: Generating Discrete Data using Diffusion Models with Self-Conditioning},
769 | author = {Ting Chen and Ruixiang Zhang and Geoffrey Hinton},
770 | year = {2022},
771 | eprint = {2208.04202},
772 | archivePrefix = {arXiv},
773 | primaryClass = {cs.CV}
774 | }
775 | ```
776 |
777 | ```bibtex
778 | @article{Sunkara2022NoMS,
779 | title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
780 | author = {Raja Sunkara and Tie Luo},
781 | journal = {ArXiv},
782 | year = {2022},
783 | volume = {abs/2208.03641}
784 | }
785 | ```
786 |
--------------------------------------------------------------------------------
/imagen.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/abishakdevarasan/imagen-pytorch/fa29d249a8bdd7b97ae0da8da02261fc69292b72/imagen.png
--------------------------------------------------------------------------------
/imagen_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from imagen_pytorch.imagen_pytorch import Imagen, Unet
2 | from imagen_pytorch.imagen_pytorch import NullUnet
3 | from imagen_pytorch.imagen_pytorch import BaseUnet64, SRUnet256, SRUnet1024
4 | from imagen_pytorch.trainer import ImagenTrainer
5 | from imagen_pytorch.version import __version__
6 |
7 | # imagen using the elucidated ddpm from Tero Karras' new paper
8 |
9 | from imagen_pytorch.elucidated_imagen import ElucidatedImagen
10 |
11 | # config driven creation of imagen instances
12 |
13 | from imagen_pytorch.configs import UnetConfig, ImagenConfig, ElucidatedImagenConfig, ImagenTrainerConfig
14 |
15 | # utils
16 |
17 | from imagen_pytorch.utils import load_imagen_from_checkpoint
18 |
19 | # video
20 |
21 | from imagen_pytorch.imagen_video import Unet3D
22 |
--------------------------------------------------------------------------------
/imagen_pytorch/cli.py:
--------------------------------------------------------------------------------
1 | import click
2 | import torch
3 | from pathlib import Path
4 |
5 | from imagen_pytorch import load_imagen_from_checkpoint
6 | from imagen_pytorch.version import __version__
7 | from imagen_pytorch.utils import safeget
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 | def simple_slugify(text, max_length = 255):
13 | return text.replace("-", "_").replace(",", "").replace(" ", "_").replace("|", "--").strip('-_')[:max_length]
14 |
15 | def main():
16 | pass
17 |
18 | @click.command()
19 | @click.option('--model', default = './imagen.pt', help = 'path to trained Imagen model')
20 | @click.option('--cond_scale', default = 5, help = 'conditioning scale (classifier free guidance) in decoder')
21 | @click.option('--load_ema', default = True, help = 'load EMA version of unets if available')
22 | @click.argument('text')
23 | def imagen(
24 | model,
25 | cond_scale,
26 | load_ema,
27 | text
28 | ):
29 | model_path = Path(model)
30 | full_model_path = str(model_path.resolve())
31 | assert model_path.exists(), f'model not found at {full_model_path}'
32 | loaded = torch.load(str(model_path))
33 |
34 | # get version
35 |
36 | version = safeget(loaded, 'version')
37 | print(f'loading Imagen from {full_model_path}, saved at version {version} - current package version is {__version__}')
38 |
39 | # get imagen parameters and type
40 |
41 | imagen = load_imagen_from_checkpoint(str(model_path), load_ema_if_available = load_ema)
42 | imagen.cuda()
43 |
44 | # generate image
45 |
46 | pil_image = imagen.sample(text, cond_scale = cond_scale, return_pil_images = True)
47 |
48 | image_path = f'./{simple_slugify(text)}.png'
49 | pil_image[0].save(image_path)
50 |
51 | print(f'image saved to {str(image_path)}')
52 | return
53 |
--------------------------------------------------------------------------------
/imagen_pytorch/configs.py:
--------------------------------------------------------------------------------
1 | import json
2 | from pydantic import BaseModel, validator, root_validator
3 | from typing import List, Iterable, Optional, Union, Tuple, Dict, Any
4 | from enum import Enum
5 |
6 | from imagen_pytorch.imagen_pytorch import Imagen, Unet, Unet3D, NullUnet
7 | from imagen_pytorch.trainer import ImagenTrainer
8 | from imagen_pytorch.elucidated_imagen import ElucidatedImagen
9 | from imagen_pytorch.t5 import DEFAULT_T5_NAME, get_encoded_dim
10 |
11 | # helper functions
12 |
13 | def exists(val):
14 | return val is not None
15 |
16 | def default(val, d):
17 | return val if exists(val) else d
18 |
19 | def ListOrTuple(inner_type):
20 | return Union[List[inner_type], Tuple[inner_type]]
21 |
22 | def SingleOrList(inner_type):
23 | return Union[inner_type, ListOrTuple(inner_type)]
24 |
25 | # noise schedule
26 |
27 | class NoiseSchedule(Enum):
28 | cosine = 'cosine'
29 | linear = 'linear'
30 |
31 | class AllowExtraBaseModel(BaseModel):
32 | class Config:
33 | extra = "allow"
34 | use_enum_values = True
35 |
36 | # imagen pydantic classes
37 |
38 | class NullUnetConfig(BaseModel):
39 | is_null: bool
40 |
41 | def create(self):
42 | return NullUnet()
43 |
44 | class UnetConfig(AllowExtraBaseModel):
45 | dim: int
46 | dim_mults: ListOrTuple(int)
47 | text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
48 | cond_dim: int = None
49 | channels: int = 3
50 | attn_dim_head: int = 32
51 | attn_heads: int = 16
52 |
53 | def create(self):
54 | return Unet(**self.dict())
55 |
56 | class Unet3DConfig(AllowExtraBaseModel):
57 | dim: int
58 | dim_mults: ListOrTuple(int)
59 | text_embed_dim: int = get_encoded_dim(DEFAULT_T5_NAME)
60 | cond_dim: int = None
61 | channels: int = 3
62 | attn_dim_head: int = 32
63 | attn_heads: int = 16
64 |
65 | def create(self):
66 | return Unet3D(**self.dict())
67 |
68 | class ImagenConfig(AllowExtraBaseModel):
69 | unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
70 | image_sizes: ListOrTuple(int)
71 | video: bool = False
72 | timesteps: SingleOrList(int) = 1000
73 | noise_schedules: SingleOrList(NoiseSchedule) = 'cosine'
74 | text_encoder_name: str = DEFAULT_T5_NAME
75 | channels: int = 3
76 | loss_type: str = 'l2'
77 | cond_drop_prob: float = 0.5
78 |
79 | @validator('image_sizes')
80 | def check_image_sizes(cls, image_sizes, values):
81 | unets = values.get('unets')
82 | if len(image_sizes) != len(unets):
83 | raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
84 | return image_sizes
85 |
86 | def create(self):
87 | decoder_kwargs = self.dict()
88 | unets_kwargs = decoder_kwargs.pop('unets')
89 | is_video = decoder_kwargs.pop('video', False)
90 |
91 | unets = []
92 |
93 | for unet, unet_kwargs in zip(self.unets, unets_kwargs):
94 | if isinstance(unet, NullUnetConfig):
95 | unet_klass = NullUnet
96 | elif is_video:
97 | unet_klass = Unet3D
98 | else:
99 | unet_klass = Unet
100 |
101 | unets.append(unet_klass(**unet_kwargs))
102 |
103 | imagen = Imagen(unets, **decoder_kwargs)
104 |
105 | imagen._config = self.dict().copy()
106 | return imagen
107 |
108 | class ElucidatedImagenConfig(AllowExtraBaseModel):
109 | unets: ListOrTuple(Union[UnetConfig, Unet3DConfig, NullUnetConfig])
110 | image_sizes: ListOrTuple(int)
111 | video: bool = False
112 | text_encoder_name: str = DEFAULT_T5_NAME
113 | channels: int = 3
114 | cond_drop_prob: float = 0.5
115 | num_sample_steps: SingleOrList(int) = 32
116 | sigma_min: SingleOrList(float) = 0.002
117 | sigma_max: SingleOrList(int) = 80
118 | sigma_data: SingleOrList(float) = 0.5
119 | rho: SingleOrList(int) = 7
120 | P_mean: SingleOrList(float) = -1.2
121 | P_std: SingleOrList(float) = 1.2
122 | S_churn: SingleOrList(int) = 80
123 | S_tmin: SingleOrList(float) = 0.05
124 | S_tmax: SingleOrList(int) = 50
125 | S_noise: SingleOrList(float) = 1.003
126 |
127 | @validator('image_sizes')
128 | def check_image_sizes(cls, image_sizes, values):
129 | unets = values.get('unets')
130 | if len(image_sizes) != len(unets):
131 | raise ValueError(f'image sizes length {len(image_sizes)} must be equivalent to the number of unets {len(unets)}')
132 | return image_sizes
133 |
134 | def create(self):
135 | decoder_kwargs = self.dict()
136 | unets_kwargs = decoder_kwargs.pop('unets')
137 | is_video = decoder_kwargs.pop('video', False)
138 |
139 | unet_klass = Unet3D if is_video else Unet
140 |
141 | unets = []
142 |
143 | for unet, unet_kwargs in zip(self.unets, unets_kwargs):
144 | if isinstance(unet, NullUnetConfig):
145 | unet_klass = NullUnet
146 | elif is_video:
147 | unet_klass = Unet3D
148 | else:
149 | unet_klass = Unet
150 |
151 | unets.append(unet_klass(**unet_kwargs))
152 |
153 | imagen = ElucidatedImagen(unets, **decoder_kwargs)
154 |
155 | imagen._config = self.dict().copy()
156 | return imagen
157 |
158 | class ImagenTrainerConfig(AllowExtraBaseModel):
159 | imagen: dict
160 | elucidated: bool = False
161 | video: bool = False
162 | use_ema: bool = True
163 | lr: SingleOrList(float) = 1e-4
164 | eps: SingleOrList(float) = 1e-8
165 | beta1: float = 0.9
166 | beta2: float = 0.99
167 | max_grad_norm: Optional[float] = None
168 | group_wd_params: bool = True
169 | warmup_steps: SingleOrList(Optional[int]) = None
170 | cosine_decay_max_steps: SingleOrList(Optional[int]) = None
171 |
172 | def create(self):
173 | trainer_kwargs = self.dict()
174 |
175 | imagen_config = trainer_kwargs.pop('imagen')
176 | elucidated = trainer_kwargs.pop('elucidated')
177 | video = trainer_kwargs.pop('video')
178 | imagen_config_klass = ElucidatedImagenConfig if elucidated else ImagenConfig
179 | imagen = imagen_config_klass(**{**imagen_config, 'video': video}).create()
180 |
181 | return ImagenTrainer(imagen, **trainer_kwargs)
182 |
--------------------------------------------------------------------------------
/imagen_pytorch/data.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from functools import partial
3 |
4 | import torch
5 | from torch import nn
6 | from torch.utils.data import Dataset, DataLoader
7 | from torchvision import transforms as T, utils
8 |
9 | from PIL import Image
10 |
11 | # helpers functions
12 |
13 | def exists(val):
14 | return val is not None
15 |
16 | def cycle(dl):
17 | while True:
18 | for data in dl:
19 | yield data
20 |
21 | def convert_image_to(img_type, image):
22 | if image.mode != img_type:
23 | return image.convert(img_type)
24 | return image
25 |
26 | # dataset and dataloader
27 |
28 | class Dataset(Dataset):
29 | def __init__(
30 | self,
31 | folder,
32 | image_size,
33 | exts = ['jpg', 'jpeg', 'png', 'tiff'],
34 | convert_image_to_type = None
35 | ):
36 | super().__init__()
37 | self.folder = folder
38 | self.image_size = image_size
39 | self.paths = [p for ext in exts for p in Path(f'{folder}').glob(f'**/*.{ext}')]
40 |
41 | convert_fn = partial(convert_image_to, convert_image_to_type) if exists(convert_image_to_type) else nn.Identity()
42 |
43 | self.transform = T.Compose([
44 | T.Lambda(convert_fn),
45 | T.Resize(image_size),
46 | T.RandomHorizontalFlip(),
47 | T.CenterCrop(image_size),
48 | T.ToTensor()
49 | ])
50 |
51 | def __len__(self):
52 | return len(self.paths)
53 |
54 | def __getitem__(self, index):
55 | path = self.paths[index]
56 | img = Image.open(path)
57 | return self.transform(img)
58 |
59 | def get_images_dataloader(
60 | folder,
61 | *,
62 | batch_size,
63 | image_size,
64 | shuffle = True,
65 | cycle_dl = False,
66 | pin_memory = True
67 | ):
68 | ds = Dataset(folder, image_size)
69 | dl = DataLoader(ds, batch_size = batch_size, shuffle = shuffle, pin_memory = pin_memory)
70 |
71 | if cycle_dl:
72 | dl = cycle(dl)
73 | return dl
74 |
--------------------------------------------------------------------------------
/imagen_pytorch/elucidated_imagen.py:
--------------------------------------------------------------------------------
1 | from math import sqrt
2 | from random import random
3 | from functools import partial
4 | from contextlib import contextmanager, nullcontext
5 | from typing import List, Union
6 | from collections import namedtuple
7 | from tqdm.auto import tqdm
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn, einsum
12 | from torch.cuda.amp import autocast
13 | from torch.nn.parallel import DistributedDataParallel
14 | import torchvision.transforms as T
15 |
16 | import kornia.augmentation as K
17 |
18 | from einops import rearrange, repeat, reduce
19 | from einops_exts import rearrange_many
20 |
21 | from imagen_pytorch.imagen_pytorch import (
22 | GaussianDiffusionContinuousTimes,
23 | Unet,
24 | NullUnet,
25 | first,
26 | exists,
27 | identity,
28 | maybe,
29 | default,
30 | cast_tuple,
31 | cast_uint8_images_to_float,
32 | is_float_dtype,
33 | eval_decorator,
34 | check_shape,
35 | pad_tuple_to_length,
36 | resize_image_to,
37 | right_pad_dims_to,
38 | module_device,
39 | normalize_neg_one_to_one,
40 | unnormalize_zero_to_one,
41 | )
42 |
43 | from imagen_pytorch.imagen_video.imagen_video import (
44 | Unet3D,
45 | resize_video_to
46 | )
47 |
48 | from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
49 |
50 | # constants
51 |
52 | Hparams_fields = [
53 | 'num_sample_steps',
54 | 'sigma_min',
55 | 'sigma_max',
56 | 'sigma_data',
57 | 'rho',
58 | 'P_mean',
59 | 'P_std',
60 | 'S_churn',
61 | 'S_tmin',
62 | 'S_tmax',
63 | 'S_noise'
64 | ]
65 |
66 | Hparams = namedtuple('Hparams', Hparams_fields)
67 |
68 | # helper functions
69 |
70 | def log(t, eps = 1e-20):
71 | return torch.log(t.clamp(min = eps))
72 |
73 | # main class
74 |
75 | class ElucidatedImagen(nn.Module):
76 | def __init__(
77 | self,
78 | unets,
79 | *,
80 | image_sizes, # for cascading ddpm, image size at each stage
81 | text_encoder_name = DEFAULT_T5_NAME,
82 | text_embed_dim = None,
83 | channels = 3,
84 | cond_drop_prob = 0.1,
85 | random_crop_sizes = None,
86 | lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
87 | per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
88 | condition_on_text = True,
89 | auto_normalize_img = True, # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
90 | dynamic_thresholding = True,
91 | dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper
92 | only_train_unet_number = None,
93 | lowres_noise_schedule = 'linear',
94 | num_sample_steps = 32, # number of sampling steps
95 | sigma_min = 0.002, # min noise level
96 | sigma_max = 80, # max noise level
97 | sigma_data = 0.5, # standard deviation of data distribution
98 | rho = 7, # controls the sampling schedule
99 | P_mean = -1.2, # mean of log-normal distribution from which noise is drawn for training
100 | P_std = 1.2, # standard deviation of log-normal distribution from which noise is drawn for training
101 | S_churn = 80, # parameters for stochastic sampling - depends on dataset, Table 5 in the paper
102 | S_tmin = 0.05,
103 | S_tmax = 50,
104 | S_noise = 1.003,
105 | ):
106 | super().__init__()
107 |
108 | self.only_train_unet_number = only_train_unet_number
109 |
110 | # conditioning hparams
111 |
112 | self.condition_on_text = condition_on_text
113 | self.unconditional = not condition_on_text
114 |
115 | # channels
116 |
117 | self.channels = channels
118 |
119 | # automatically take care of ensuring that first unet is unconditional
120 | # while the rest of the unets are conditioned on the low resolution image produced by previous unet
121 |
122 | unets = cast_tuple(unets)
123 | num_unets = len(unets)
124 |
125 | # randomly cropping for upsampler training
126 |
127 | self.random_crop_sizes = cast_tuple(random_crop_sizes, num_unets)
128 | assert not exists(first(self.random_crop_sizes)), 'you should not need to randomly crop image during training for base unet, only for upsamplers - so pass in `random_crop_sizes = (None, 128, 256)` as example'
129 |
130 | # lowres augmentation noise schedule
131 |
132 | self.lowres_noise_schedule = GaussianDiffusionContinuousTimes(noise_schedule = lowres_noise_schedule)
133 |
134 | # get text encoder
135 |
136 | self.text_encoder_name = text_encoder_name
137 | self.text_embed_dim = default(text_embed_dim, lambda: get_encoded_dim(text_encoder_name))
138 |
139 | self.encode_text = partial(t5_encode_text, name = text_encoder_name)
140 |
141 | # construct unets
142 |
143 | self.unets = nn.ModuleList([])
144 | self.unet_being_trained_index = -1 # keeps track of which unet is being trained at the moment
145 |
146 | for ind, one_unet in enumerate(unets):
147 | assert isinstance(one_unet, (Unet, Unet3D, NullUnet))
148 | is_first = ind == 0
149 |
150 | one_unet = one_unet.cast_model_parameters(
151 | lowres_cond = not is_first,
152 | cond_on_text = self.condition_on_text,
153 | text_embed_dim = self.text_embed_dim if self.condition_on_text else None,
154 | channels = self.channels,
155 | channels_out = self.channels
156 | )
157 |
158 | self.unets.append(one_unet)
159 |
160 | # determine whether we are training on images or video
161 |
162 | is_video = any([isinstance(unet, Unet3D) for unet in self.unets])
163 | self.is_video = is_video
164 |
165 | self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
166 | self.resize_to = resize_video_to if is_video else resize_image_to
167 |
168 | # unet image sizes
169 |
170 | self.image_sizes = cast_tuple(image_sizes)
171 | assert num_unets == len(self.image_sizes), f'you did not supply the correct number of u-nets ({len(self.unets)}) for resolutions {self.image_sizes}'
172 |
173 | self.sample_channels = cast_tuple(self.channels, num_unets)
174 |
175 | # cascading ddpm related stuff
176 |
177 | lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
178 | assert lowres_conditions == (False, *((True,) * (num_unets - 1))), 'the first unet must be unconditioned (by low resolution image), and the rest of the unets must have `lowres_cond` set to True'
179 |
180 | self.lowres_sample_noise_level = lowres_sample_noise_level
181 | self.per_sample_random_aug_noise_level = per_sample_random_aug_noise_level
182 |
183 | # classifier free guidance
184 |
185 | self.cond_drop_prob = cond_drop_prob
186 | self.can_classifier_guidance = cond_drop_prob > 0.
187 |
188 | # normalize and unnormalize image functions
189 |
190 | self.normalize_img = normalize_neg_one_to_one if auto_normalize_img else identity
191 | self.unnormalize_img = unnormalize_zero_to_one if auto_normalize_img else identity
192 | self.input_image_range = (0. if auto_normalize_img else -1., 1.)
193 |
194 | # dynamic thresholding
195 |
196 | self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
197 | self.dynamic_thresholding_percentile = dynamic_thresholding_percentile
198 |
199 | # elucidating parameters
200 |
201 | hparams = [
202 | num_sample_steps,
203 | sigma_min,
204 | sigma_max,
205 | sigma_data,
206 | rho,
207 | P_mean,
208 | P_std,
209 | S_churn,
210 | S_tmin,
211 | S_tmax,
212 | S_noise,
213 | ]
214 |
215 | hparams = [cast_tuple(hp, num_unets) for hp in hparams]
216 | self.hparams = [Hparams(*unet_hp) for unet_hp in zip(*hparams)]
217 |
218 |         # one temp buffer for keeping track of device
219 |
220 | self.register_buffer('_temp', torch.tensor([0.]), persistent = False)
221 |
222 | # default to device of unets passed in
223 |
224 | self.to(next(self.unets.parameters()).device)
225 |
226 | def force_unconditional_(self):
227 | self.condition_on_text = False
228 | self.unconditional = True
229 |
230 | for unet in self.unets:
231 | unet.cond_on_text = False
232 |
233 | @property
234 | def device(self):
235 | return self._temp.device
236 |
237 | def get_unet(self, unet_number):
238 | assert 0 < unet_number <= len(self.unets)
239 | index = unet_number - 1
240 |
241 | if isinstance(self.unets, nn.ModuleList):
242 | unets_list = [unet for unet in self.unets]
243 | delattr(self, 'unets')
244 | self.unets = unets_list
245 |
246 | if index != self.unet_being_trained_index:
247 | for unet_index, unet in enumerate(self.unets):
248 | unet.to(self.device if unet_index == index else 'cpu')
249 |
250 | self.unet_being_trained_index = index
251 | return self.unets[index]
252 |
253 | def reset_unets_all_one_device(self, device = None):
254 | device = default(device, self.device)
255 | self.unets = nn.ModuleList([*self.unets])
256 | self.unets.to(device)
257 |
258 | self.unet_being_trained_index = -1
259 |
260 | @contextmanager
261 | def one_unet_in_gpu(self, unet_number = None, unet = None):
262 | assert exists(unet_number) ^ exists(unet)
263 |
264 | if exists(unet_number):
265 | unet = self.unets[unet_number - 1]
266 |
267 | devices = [module_device(unet) for unet in self.unets]
268 | self.unets.cpu()
269 | unet.to(self.device)
270 |
271 | yield
272 |
273 | for unet, device in zip(self.unets, devices):
274 | unet.to(device)
275 |
276 | # overriding state dict functions
277 |
278 | def state_dict(self, *args, **kwargs):
279 | self.reset_unets_all_one_device()
280 | return super().state_dict(*args, **kwargs)
281 |
282 | def load_state_dict(self, *args, **kwargs):
283 | self.reset_unets_all_one_device()
284 | return super().load_state_dict(*args, **kwargs)
285 |
286 | # dynamic thresholding
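    |     # as proposed in the Imagen paper: s is the `dynamic_thresholding_percentile` quantile of |x0|
    |     # per sample, clamped to at least 1; x0 is clipped to [-s, s] and rescaled by s,
    |     # which keeps predictions in range at high classifier free guidance scales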
287 |
288 | def threshold_x_start(self, x_start, dynamic_threshold = True):
289 | if not dynamic_threshold:
290 | return x_start.clamp(-1., 1.)
291 |
292 | s = torch.quantile(
293 | rearrange(x_start, 'b ... -> b (...)').abs(),
294 | self.dynamic_thresholding_percentile,
295 | dim = -1
296 | )
297 |
298 | s.clamp_(min = 1.)
299 | s = right_pad_dims_to(x_start, s)
300 | return x_start.clamp(-s, s) / s
301 |
302 | # derived preconditioning params - Table 1
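    |     # from Table 1 of Karras et al. (2022), the elucidated diffusion (EDM) paper:
    |     #   c_skip(σ)  = σ_data² / (σ² + σ_data²)
    |     #   c_out(σ)   = σ · σ_data / sqrt(σ² + σ_data²)
    |     #   c_in(σ)    = 1 / sqrt(σ² + σ_data²)
    |     #   c_noise(σ) = ln(σ) / 4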
303 |
304 | def c_skip(self, sigma_data, sigma):
305 | return (sigma_data ** 2) / (sigma ** 2 + sigma_data ** 2)
306 |
307 | def c_out(self, sigma_data, sigma):
308 | return sigma * sigma_data * (sigma_data ** 2 + sigma ** 2) ** -0.5
309 |
310 | def c_in(self, sigma_data, sigma):
311 | return 1 * (sigma ** 2 + sigma_data ** 2) ** -0.5
312 |
313 | def c_noise(self, sigma):
314 | return log(sigma) * 0.25
315 |
316 | # preconditioned network output
317 | # equation (7) in the paper
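    |     # D(x; σ) = c_skip(σ) · x + c_out(σ) · F(c_in(σ) · x, c_noise(σ))
    |     # where F is the raw unet and D is the denoised output that the sampler and loss operate on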
318 |
319 | def preconditioned_network_forward(
320 | self,
321 | unet_forward,
322 | noised_images,
323 | sigma,
324 | *,
325 | sigma_data,
326 | clamp = False,
327 | dynamic_threshold = True,
328 | **kwargs
329 | ):
330 | batch, device = noised_images.shape[0], noised_images.device
331 |
332 | if isinstance(sigma, float):
333 | sigma = torch.full((batch,), sigma, device = device)
334 |
335 | padded_sigma = self.right_pad_dims_to_datatype(sigma)
336 |
337 | net_out = unet_forward(
338 | self.c_in(sigma_data, padded_sigma) * noised_images,
339 | self.c_noise(sigma),
340 | **kwargs
341 | )
342 |
343 | out = self.c_skip(sigma_data, padded_sigma) * noised_images + self.c_out(sigma_data, padded_sigma) * net_out
344 |
345 | if not clamp:
346 | return out
347 |
348 | return self.threshold_x_start(out, dynamic_threshold)
349 |
350 | # sampling
351 |
352 | # sample schedule
353 | # equation (5) in the paper
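    |     # σ_i = (σ_max^(1/ρ) + i / (N - 1) · (σ_min^(1/ρ) - σ_max^(1/ρ)))^ρ  for i = 0 ... N - 1, padded with a final σ_N = 0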
354 |
355 | def sample_schedule(
356 | self,
357 | num_sample_steps,
358 | rho,
359 | sigma_min,
360 | sigma_max
361 | ):
362 | N = num_sample_steps
363 | inv_rho = 1 / rho
364 |
365 | steps = torch.arange(num_sample_steps, device = self.device, dtype = torch.float32)
366 | sigmas = (sigma_max ** inv_rho + steps / (N - 1) * (sigma_min ** inv_rho - sigma_max ** inv_rho)) ** rho
367 |
368 | sigmas = F.pad(sigmas, (0, 1), value = 0.) # last step is sigma value of 0.
369 | return sigmas
370 |
371 | @torch.no_grad()
372 | def one_unet_sample(
373 | self,
374 | unet,
375 | shape,
376 | *,
377 | unet_number,
378 | clamp = True,
379 | dynamic_threshold = True,
380 | cond_scale = 1.,
381 | use_tqdm = True,
382 | inpaint_images = None,
383 | inpaint_masks = None,
384 | inpaint_resample_times = 5,
385 | init_images = None,
386 | skip_steps = None,
387 | sigma_min = None,
388 | sigma_max = None,
389 | **kwargs
390 | ):
391 | # get specific sampling hyperparameters for unet
392 |
393 | hp = self.hparams[unet_number - 1]
394 |
395 | sigma_min = default(sigma_min, hp.sigma_min)
396 | sigma_max = default(sigma_max, hp.sigma_max)
397 |
398 |         # get the sigma schedule, derive the per-step gammas, and pair each sigma with the next sigma and its gamma
399 |
400 | sigmas = self.sample_schedule(hp.num_sample_steps, hp.rho, sigma_min, sigma_max)
401 |
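    |         # churn amount from the EDM stochastic sampler (Algorithm 2): gamma = min(S_churn / N, sqrt(2) - 1)
    |         # for sigmas within [S_tmin, S_tmax], else 0 - the sampler briefly raises the noise level by gamma * sigma before each step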
402 | gammas = torch.where(
403 | (sigmas >= hp.S_tmin) & (sigmas <= hp.S_tmax),
404 | min(hp.S_churn / hp.num_sample_steps, sqrt(2) - 1),
405 | 0.
406 | )
407 |
408 | sigmas_and_gammas = list(zip(sigmas[:-1], sigmas[1:], gammas[:-1]))
409 |
410 | # images is noise at the beginning
411 |
412 | init_sigma = sigmas[0]
413 |
414 | images = init_sigma * torch.randn(shape, device = self.device)
415 |
416 | # initializing with an image
417 |
418 | if exists(init_images):
419 | images += init_images
420 |
421 | # keeping track of x0, for self conditioning if needed
422 |
423 | x_start = None
424 |
425 | # prepare inpainting images and mask
426 |
427 | has_inpainting = exists(inpaint_images) and exists(inpaint_masks)
428 | resample_times = inpaint_resample_times if has_inpainting else 1
429 |
430 | if has_inpainting:
431 | inpaint_images = self.normalize_img(inpaint_images)
432 | inpaint_images = self.resize_to(inpaint_images, shape[-1])
433 | inpaint_masks = self.resize_to(rearrange(inpaint_masks, 'b ... -> b 1 ...').float(), shape[-1]).bool()
434 |
435 | # unet kwargs
436 |
437 | unet_kwargs = dict(
438 | sigma_data = hp.sigma_data,
439 | clamp = clamp,
440 | dynamic_threshold = dynamic_threshold,
441 | cond_scale = cond_scale,
442 | **kwargs
443 | )
444 |
445 | # gradually denoise
446 |
447 | initial_step = default(skip_steps, 0)
448 | sigmas_and_gammas = sigmas_and_gammas[initial_step:]
449 |
450 | total_steps = len(sigmas_and_gammas)
451 |
452 | for ind, (sigma, sigma_next, gamma) in tqdm(enumerate(sigmas_and_gammas), total = total_steps, desc = 'sampling time step', disable = not use_tqdm):
453 | is_last_timestep = ind == (total_steps - 1)
454 |
455 | sigma, sigma_next, gamma = map(lambda t: t.item(), (sigma, sigma_next, gamma))
456 |
457 | for r in reversed(range(resample_times)):
458 | is_last_resample_step = r == 0
459 |
460 | eps = hp.S_noise * torch.randn(shape, device = self.device) # stochastic sampling
461 |
462 | sigma_hat = sigma + gamma * sigma
463 | added_noise = sqrt(sigma_hat ** 2 - sigma ** 2) * eps
464 |
465 | images_hat = images + added_noise
466 |
467 | self_cond = x_start if unet.self_cond else None
468 |
469 | if has_inpainting:
470 | images_hat = images_hat * ~inpaint_masks + (inpaint_images + added_noise) * inpaint_masks
471 |
472 | model_output = self.preconditioned_network_forward(
473 | unet.forward_with_cond_scale,
474 | images_hat,
475 | sigma_hat,
476 | self_cond = self_cond,
477 | **unet_kwargs
478 | )
479 |
480 | denoised_over_sigma = (images_hat - model_output) / sigma_hat
481 |
482 | images_next = images_hat + (sigma_next - sigma_hat) * denoised_over_sigma
483 |
484 | # second order correction, if not the last timestep
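    |                 # Heun's method: re-evaluate the denoiser at (images_next, sigma_next) and average the two slopes,
    |                 # giving a second order accurate step at the cost of one extra unet evaluation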
485 |
486 | if sigma_next != 0:
487 | self_cond = model_output if unet.self_cond else None
488 |
489 | model_output_next = self.preconditioned_network_forward(
490 | unet.forward_with_cond_scale,
491 | images_next,
492 | sigma_next,
493 | self_cond = self_cond,
494 | **unet_kwargs
495 | )
496 |
497 | denoised_prime_over_sigma = (images_next - model_output_next) / sigma_next
498 | images_next = images_hat + 0.5 * (sigma_next - sigma_hat) * (denoised_over_sigma + denoised_prime_over_sigma)
499 |
500 | images = images_next
501 |
502 | if has_inpainting and not (is_last_resample_step or is_last_timestep):
503 | # renoise in repaint and then resample
504 | repaint_noise = torch.randn(shape, device = self.device)
505 |                     images = images + sqrt(sigma ** 2 - sigma_next ** 2) * repaint_noise # renoise back up to the current sigma level before resampling
506 |
507 | x_start = model_output # save model output for self conditioning
508 |
509 | images = images.clamp(-1., 1.)
510 |
511 | if has_inpainting:
512 | images = images * ~inpaint_masks + inpaint_images * inpaint_masks
513 |
514 | return self.unnormalize_img(images)
515 |
516 | @torch.no_grad()
517 | @eval_decorator
518 | def sample(
519 | self,
520 | texts: List[str] = None,
521 | text_masks = None,
522 | text_embeds = None,
523 | cond_images = None,
524 | inpaint_images = None,
525 | inpaint_masks = None,
526 | inpaint_resample_times = 5,
527 | init_images = None,
528 | skip_steps = None,
529 | sigma_min = None,
530 | sigma_max = None,
531 | video_frames = None,
532 | batch_size = 1,
533 | cond_scale = 1.,
534 | lowres_sample_noise_level = None,
535 | start_at_unet_number = 1,
536 | start_image_or_video = None,
537 | stop_at_unet_number = None,
538 | return_all_unet_outputs = False,
539 | return_pil_images = False,
540 | use_tqdm = True,
541 | device = None,
542 | ):
543 | device = default(device, self.device)
544 | self.reset_unets_all_one_device(device = device)
545 |
546 | cond_images = maybe(cast_uint8_images_to_float)(cond_images)
547 |
548 | if exists(texts) and not exists(text_embeds) and not self.unconditional:
549 | assert all([*map(len, texts)]), 'text cannot be empty'
550 |
551 | with autocast(enabled = False):
552 | text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
553 |
554 | text_embeds, text_masks = map(lambda t: t.to(device), (text_embeds, text_masks))
555 |
556 | if not self.unconditional:
557 |             assert exists(text_embeds), 'text or text embeddings must be passed in, since the network was trained with text conditioning - to train without text, set `condition_on_text = False`'
558 |
559 | text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
560 | batch_size = text_embeds.shape[0]
561 |
562 | if exists(inpaint_images):
563 | if self.unconditional:
564 | if batch_size == 1: # assume researcher wants to broadcast along inpainted images
565 | batch_size = inpaint_images.shape[0]
566 |
567 |             assert inpaint_images.shape[0] == batch_size, 'number of inpainting images must be equal to the batch size specified on sampling `sample(batch_size = )`'
568 | assert not (self.condition_on_text and inpaint_images.shape[0] != text_embeds.shape[0]), 'number of inpainting images must be equal to the number of text to be conditioned on'
569 |
570 | assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
571 |         assert not (not self.condition_on_text and exists(text_embeds)), 'imagen was specified not to be conditioned on text, yet text embeddings were provided'
572 | assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
573 |
574 | assert not (exists(inpaint_images) ^ exists(inpaint_masks)), 'inpaint images and masks must be both passed in to do inpainting'
575 |
576 | outputs = []
577 |
578 | is_cuda = next(self.parameters()).is_cuda
579 | device = next(self.parameters()).device
580 |
581 | lowres_sample_noise_level = default(lowres_sample_noise_level, self.lowres_sample_noise_level)
582 |
583 | num_unets = len(self.unets)
584 | cond_scale = cast_tuple(cond_scale, num_unets)
585 |
586 | # handle video and frame dimension
587 |
588 |         assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in at sample time if training on video'
589 |
590 | frame_dims = (video_frames,) if self.is_video else tuple()
591 |
592 | # initializing with an image or video
593 |
594 | init_images = cast_tuple(init_images, num_unets)
595 | init_images = [maybe(self.normalize_img)(init_image) for init_image in init_images]
596 |
597 | skip_steps = cast_tuple(skip_steps, num_unets)
598 |
599 | sigma_min = cast_tuple(sigma_min, num_unets)
600 | sigma_max = cast_tuple(sigma_max, num_unets)
601 |
602 |         # handle starting at a unet greater than 1, for upscaler-only training
603 |
604 | if start_at_unet_number > 1:
605 |             assert start_at_unet_number <= num_unets, 'must start at a unet number that does not exceed the total number of unets'
606 | assert not exists(stop_at_unet_number) or start_at_unet_number <= stop_at_unet_number
607 | assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'
608 |
609 | prev_image_size = self.image_sizes[start_at_unet_number - 2]
610 | img = self.resize_to(start_image_or_video, prev_image_size)
611 |
612 | # go through each unet in cascade
613 |
614 | for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
615 | if unet_number < start_at_unet_number:
616 | continue
617 |
618 | assert not isinstance(unet, NullUnet), 'cannot sample from null unet'
619 |
620 | context = self.one_unet_in_gpu(unet = unet) if is_cuda else nullcontext()
621 |
622 | with context:
623 | lowres_cond_img = lowres_noise_times = None
624 |
625 | shape = (batch_size, channel, *frame_dims, image_size, image_size)
626 |
627 | if unet.lowres_cond:
628 | lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)
629 |
630 | lowres_cond_img = self.resize_to(img, image_size)
631 | lowres_cond_img = self.normalize_img(lowres_cond_img)
632 |
633 | lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))
634 |
635 | if exists(unet_init_images):
636 | unet_init_images = self.resize_to(unet_init_images, image_size)
637 |
638 | shape = (batch_size, self.channels, *frame_dims, image_size, image_size)
639 |
640 | img = self.one_unet_sample(
641 | unet,
642 | shape,
643 | unet_number = unet_number,
644 | text_embeds = text_embeds,
645 | text_mask = text_masks,
646 | cond_images = cond_images,
647 | inpaint_images = inpaint_images,
648 | inpaint_masks = inpaint_masks,
649 | inpaint_resample_times = inpaint_resample_times,
650 | init_images = unet_init_images,
651 | skip_steps = unet_skip_steps,
652 | sigma_min = unet_sigma_min,
653 | sigma_max = unet_sigma_max,
654 | cond_scale = unet_cond_scale,
655 | lowres_cond_img = lowres_cond_img,
656 | lowres_noise_times = lowres_noise_times,
657 | dynamic_threshold = dynamic_threshold,
658 | use_tqdm = use_tqdm
659 | )
660 |
661 | outputs.append(img)
662 |
663 | if exists(stop_at_unet_number) and stop_at_unet_number == unet_number:
664 | break
665 |
666 | output_index = -1 if not return_all_unet_outputs else slice(None) # either return last unet output or all unet outputs
667 |
668 | if not return_pil_images:
669 | return outputs[output_index]
670 |
671 | if not return_all_unet_outputs:
672 | outputs = outputs[-1:]
673 |
674 | assert not self.is_video, 'automatically converting video tensor to video file for saving is not built yet'
675 |
676 | pil_images = list(map(lambda img: list(map(T.ToPILImage(), img.unbind(dim = 0))), outputs))
677 |
678 | return pil_images[output_index] # now you have a bunch of pillow images you can just .save(/where/ever/you/want.png)
679 |
680 | # training
681 |
682 | def loss_weight(self, sigma_data, sigma):
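    |         # EDM loss weighting λ(σ) = (σ² + σ_data²) / (σ · σ_data)²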
683 | return (sigma ** 2 + sigma_data ** 2) * (sigma * sigma_data) ** -2
684 |
685 | def noise_distribution(self, P_mean, P_std, batch_size):
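    |         # training noise levels are drawn log-normally: ln(σ) ~ N(P_mean, P_std²)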
686 | return (P_mean + P_std * torch.randn((batch_size,), device = self.device)).exp()
687 |
688 | def forward(
689 | self,
690 | images,
691 | unet: Union[Unet, Unet3D, NullUnet, DistributedDataParallel] = None,
692 | texts: List[str] = None,
693 | text_embeds = None,
694 | text_masks = None,
695 | unet_number = None,
696 | cond_images = None
697 | ):
698 |         assert images.shape[-1] == images.shape[-2], f'the images you pass in must be square, but received dimensions of {images.shape[-2]}, {images.shape[-1]}'
699 | assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
700 | unet_number = default(unet_number, 1)
701 |         assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train on unet #{self.only_train_unet_number}'
702 |
703 | images = cast_uint8_images_to_float(images)
704 | cond_images = maybe(cast_uint8_images_to_float)(cond_images)
705 |
706 | assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'
707 |
708 | unet_index = unet_number - 1
709 |
710 | unet = default(unet, lambda: self.get_unet(unet_number))
711 |
712 | assert not isinstance(unet, NullUnet), 'null unet cannot and should not be trained'
713 |
714 | target_image_size = self.image_sizes[unet_index]
715 | random_crop_size = self.random_crop_sizes[unet_index]
716 | prev_image_size = self.image_sizes[unet_index - 1] if unet_index > 0 else None
717 | hp = self.hparams[unet_index]
718 |
719 | batch_size, c, *_, h, w, device, is_video = *images.shape, images.device, (images.ndim == 5)
720 |
721 | frames = images.shape[2] if is_video else None
722 |
723 | check_shape(images, 'b c ...', c = self.channels)
724 |
725 | assert h >= target_image_size and w >= target_image_size
726 |
727 | if exists(texts) and not exists(text_embeds) and not self.unconditional:
728 | assert all([*map(len, texts)]), 'text cannot be empty'
729 | assert len(texts) == len(images), 'number of text captions does not match up with the number of images given'
730 |
731 | with autocast(enabled = False):
732 | text_embeds, text_masks = self.encode_text(texts, return_attn_mask = True)
733 |
734 | text_embeds, text_masks = map(lambda t: t.to(images.device), (text_embeds, text_masks))
735 |
736 | if not self.unconditional:
737 | text_masks = default(text_masks, lambda: torch.any(text_embeds != 0., dim = -1))
738 |
739 |         assert not (self.condition_on_text and not exists(text_embeds)), 'text or text encodings must be passed into imagen if specified'
740 |         assert not (not self.condition_on_text and exists(text_embeds)), 'imagen was specified not to be conditioned on text, yet text embeddings were provided'
741 |
742 | assert not (exists(text_embeds) and text_embeds.shape[-1] != self.text_embed_dim), f'invalid text embedding dimension being passed in (should be {self.text_embed_dim})'
743 |
744 | lowres_cond_img = lowres_aug_times = None
745 | if exists(prev_image_size):
746 | lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
747 | lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)
748 |
749 | if self.per_sample_random_aug_noise_level:
750 | lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device)
751 | else:
752 | lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
753 | lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size)
754 |
755 | images = self.resize_to(images, target_image_size)
756 |
757 | # normalize to [-1, 1]
758 |
759 | images = self.normalize_img(images)
760 | lowres_cond_img = maybe(self.normalize_img)(lowres_cond_img)
761 |
762 | # random cropping during training
763 | # for upsamplers
764 |
765 | if exists(random_crop_size):
766 | aug = K.RandomCrop((random_crop_size, random_crop_size), p = 1.)
767 |
768 | if is_video:
769 | images, lowres_cond_img = rearrange_many((images, lowres_cond_img), 'b c f h w -> (b f) c h w')
770 |
771 | # make sure low res conditioner and image both get augmented the same way
772 | # detailed https://kornia.readthedocs.io/en/latest/augmentation.module.html?highlight=randomcrop#kornia.augmentation.RandomCrop
773 | images = aug(images)
774 | lowres_cond_img = aug(lowres_cond_img, params = aug._params)
775 |
776 | if is_video:
777 | images, lowres_cond_img = rearrange_many((images, lowres_cond_img), '(b f) c h w -> b c f h w', f = frames)
778 |
779 | # noise the lowres conditioning image
780 |         # at sample time, they then fix the noise level to 0.1 - 0.3
781 |
782 | lowres_cond_img_noisy = None
783 | if exists(lowres_cond_img):
784 | lowres_cond_img_noisy, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_aug_times, noise = torch.randn_like(lowres_cond_img))
785 |
786 | # get the sigmas
787 |
788 | sigmas = self.noise_distribution(hp.P_mean, hp.P_std, batch_size)
789 | padded_sigmas = self.right_pad_dims_to_datatype(sigmas)
790 |
791 | # noise
792 |
793 | noise = torch.randn_like(images)
794 | noised_images = images + padded_sigmas * noise # alphas are 1. in the paper
795 |
796 | # unet kwargs
797 |
798 | unet_kwargs = dict(
799 | sigma_data = hp.sigma_data,
800 | text_embeds = text_embeds,
801 | text_mask = text_masks,
802 | cond_images = cond_images,
803 | lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
804 | lowres_cond_img = lowres_cond_img_noisy,
805 | cond_drop_prob = self.cond_drop_prob,
806 | )
807 |
808 | # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower
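    |         # half of the time, first predict x0 without self conditioning, detach it,
    |         # and pass it back in as `self_cond` for the gradient-carrying forward pass below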
809 |
810 | # Because 'unet' can be an instance of DistributedDataParallel coming from the
811 | # ImagenTrainer.unet_being_trained when invoking ImagenTrainer.forward(), we need to
812 | # access the member 'module' of the wrapped unet instance.
813 |         self_cond = unet.module.self_cond if isinstance(unet, DistributedDataParallel) else unet.self_cond
814 |
815 | if self_cond and random() < 0.5:
816 | with torch.no_grad():
817 | pred_x0 = self.preconditioned_network_forward(
818 | unet.forward,
819 | noised_images,
820 | sigmas,
821 | **unet_kwargs
822 | ).detach()
823 |
824 | unet_kwargs = {**unet_kwargs, 'self_cond': pred_x0}
825 |
826 | # get prediction
827 |
828 | denoised_images = self.preconditioned_network_forward(
829 | unet.forward,
830 | noised_images,
831 | sigmas,
832 | **unet_kwargs
833 | )
834 |
835 | # losses
836 |
837 | losses = F.mse_loss(denoised_images, images, reduction = 'none')
838 | losses = reduce(losses, 'b ... -> b', 'mean')
839 |
840 | # loss weighting
841 |
842 | losses = losses * self.loss_weight(hp.sigma_data, sigmas)
843 |
844 | # return average loss
845 |
846 | return losses.mean()
847 |
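    | # a minimal usage sketch (as comments, so nothing runs on import); the Unet hyperparameters
    | # below are illustrative assumptions rather than tested settings
    | #
    | #   unet1 = Unet(dim = 128, dim_mults = (1, 2, 4))   # hypothetical base unet config
    | #   unet2 = Unet(dim = 64, dim_mults = (1, 2, 4, 8)) # hypothetical upsampler config
    | #
    | #   imagen = ElucidatedImagen(
    | #       unets = (unet1, unet2),
    | #       image_sizes = (64, 256),
    | #       cond_drop_prob = 0.1,
    | #       num_sample_steps = (64, 32)
    | #   )
    | #
    | #   # training: images is a float tensor of shape (batch, 3, 256, 256), texts a list of captions
    | #   loss = imagen(images, texts = texts, unet_number = 1)
    | #   loss.backward()
    | #
    | #   # sampling with classifier free guidance
    | #   samples = imagen.sample(texts = ['a photo of a dog'], cond_scale = 3.)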
--------------------------------------------------------------------------------
/imagen_pytorch/imagen_video/__init__.py:
--------------------------------------------------------------------------------
1 | from imagen_pytorch.imagen_video.imagen_video import Unet3D
2 |
--------------------------------------------------------------------------------
/imagen_pytorch/imagen_video/imagen_video.py:
--------------------------------------------------------------------------------
1 | import math
2 | import copy
3 | from typing import List
4 | from tqdm.auto import tqdm
5 | from functools import partial, wraps
6 | from contextlib import contextmanager, nullcontext
7 | from collections import namedtuple
8 | from pathlib import Path
9 |
10 | import torch
11 | import torch.nn.functional as F
12 | from torch import nn, einsum
13 |
14 | from einops import rearrange, repeat, reduce
15 | from einops.layers.torch import Rearrange, Reduce
16 | from einops_exts import rearrange_many, repeat_many, check_shape
17 | from einops_exts.torch import EinopsToAndFrom
18 |
19 | from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim, DEFAULT_T5_NAME
20 |
21 | # helper functions
22 |
23 | def exists(val):
24 | return val is not None
25 |
26 | def identity(t, *args, **kwargs):
27 | return t
28 |
29 | def first(arr, d = None):
30 | if len(arr) == 0:
31 | return d
32 | return arr[0]
33 |
34 | def maybe(fn):
35 | @wraps(fn)
36 | def inner(x):
37 | if not exists(x):
38 | return x
39 | return fn(x)
40 | return inner
41 |
42 | def once(fn):
43 | called = False
44 | @wraps(fn)
45 | def inner(x):
46 | nonlocal called
47 | if called:
48 | return
49 | called = True
50 | return fn(x)
51 | return inner
52 |
53 | print_once = once(print)
54 |
55 | def default(val, d):
56 | if exists(val):
57 | return val
58 | return d() if callable(d) else d
59 |
60 | def cast_tuple(val, length = None):
61 | if isinstance(val, list):
62 | val = tuple(val)
63 |
64 | output = val if isinstance(val, tuple) else ((val,) * default(length, 1))
65 |
66 | if exists(length):
67 | assert len(output) == length
68 |
69 | return output
70 |
71 | def cast_uint8_images_to_float(images):
72 | if not images.dtype == torch.uint8:
73 | return images
74 | return images / 255
75 |
76 | def module_device(module):
77 | return next(module.parameters()).device
78 |
79 | def zero_init_(m):
80 | nn.init.zeros_(m.weight)
81 | if exists(m.bias):
82 | nn.init.zeros_(m.bias)
83 |
84 | def eval_decorator(fn):
85 | def inner(model, *args, **kwargs):
86 | was_training = model.training
87 | model.eval()
88 | out = fn(model, *args, **kwargs)
89 | model.train(was_training)
90 | return out
91 | return inner
92 |
93 | def pad_tuple_to_length(t, length, fillvalue = None):
94 | remain_length = length - len(t)
95 | if remain_length <= 0:
96 | return t
97 | return (*t, *((fillvalue,) * remain_length))
98 |
99 | # helper classes
100 |
101 | class Identity(nn.Module):
102 | def __init__(self, *args, **kwargs):
103 | super().__init__()
104 |
105 | def forward(self, x, *args, **kwargs):
106 | return x
107 |
108 | # tensor helpers
109 |
110 | def log(t, eps: float = 1e-12):
111 | return torch.log(t.clamp(min = eps))
112 |
113 | def l2norm(t):
114 | return F.normalize(t, dim = -1)
115 |
116 | def right_pad_dims_to(x, t):
117 | padding_dims = x.ndim - t.ndim
118 | if padding_dims <= 0:
119 | return t
120 | return t.view(*t.shape, *((1,) * padding_dims))
121 |
122 | def masked_mean(t, *, dim, mask = None):
123 | if not exists(mask):
124 | return t.mean(dim = dim)
125 |
126 | denom = mask.sum(dim = dim, keepdim = True)
127 | mask = rearrange(mask, 'b n -> b n 1')
128 | masked_t = t.masked_fill(~mask, 0.)
129 |
130 | return masked_t.sum(dim = dim) / denom.clamp(min = 1e-5)
131 |
132 | def resize_video_to(
133 | video,
134 | target_image_size,
135 | clamp_range = None
136 | ):
137 | orig_video_size = video.shape[-1]
138 |
139 | if orig_video_size == target_image_size:
140 | return video
141 |
142 |
143 | frames = video.shape[2]
144 | video = rearrange(video, 'b c f h w -> (b f) c h w')
145 |
146 | out = F.interpolate(video, target_image_size, mode = 'nearest')
147 |
148 | if exists(clamp_range):
149 | out = out.clamp(*clamp_range)
150 |
151 | out = rearrange(out, '(b f) c h w -> b c f h w', f = frames)
152 |
153 | return out
154 |
155 | # classifier free guidance functions
156 |
157 | def prob_mask_like(shape, prob, device):
158 | if prob == 1:
159 | return torch.ones(shape, device = device, dtype = torch.bool)
160 | elif prob == 0:
161 | return torch.zeros(shape, device = device, dtype = torch.bool)
162 | else:
163 | return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
164 |
165 | # norms and residuals
166 |
167 | class LayerNorm(nn.Module):
168 | def __init__(self, dim, stable = False):
169 | super().__init__()
170 | self.stable = stable
171 | self.g = nn.Parameter(torch.ones(dim))
172 |
173 | def forward(self, x):
174 | if self.stable:
175 | x = x / x.amax(dim = -1, keepdim = True).detach()
176 |
177 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
178 | var = torch.var(x, dim = -1, unbiased = False, keepdim = True)
179 | mean = torch.mean(x, dim = -1, keepdim = True)
180 | return (x - mean) * (var + eps).rsqrt() * self.g
181 |
182 | class ChanLayerNorm(nn.Module):
183 | def __init__(self, dim, stable = False):
184 | super().__init__()
185 | self.stable = stable
186 | self.g = nn.Parameter(torch.ones(1, dim, 1, 1, 1))
187 |
188 | def forward(self, x):
189 | if self.stable:
190 | x = x / x.amax(dim = 1, keepdim = True).detach()
191 |
192 | eps = 1e-5 if x.dtype == torch.float32 else 1e-3
193 | var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
194 | mean = torch.mean(x, dim = 1, keepdim = True)
195 | return (x - mean) * (var + eps).rsqrt() * self.g
196 |
197 | class Always():
198 | def __init__(self, val):
199 | self.val = val
200 |
201 | def __call__(self, *args, **kwargs):
202 | return self.val
203 |
204 | class Residual(nn.Module):
205 | def __init__(self, fn):
206 | super().__init__()
207 | self.fn = fn
208 |
209 | def forward(self, x, **kwargs):
210 | return self.fn(x, **kwargs) + x
211 |
212 | class Parallel(nn.Module):
213 | def __init__(self, *fns):
214 | super().__init__()
215 | self.fns = nn.ModuleList(fns)
216 |
217 | def forward(self, x):
218 | outputs = [fn(x) for fn in self.fns]
219 | return sum(outputs)
220 |
221 | # attention pooling
222 |
223 | class PerceiverAttention(nn.Module):
224 | def __init__(
225 | self,
226 | *,
227 | dim,
228 | dim_head = 64,
229 | heads = 8,
230 | cosine_sim_attn = False
231 | ):
232 | super().__init__()
233 | self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1
234 | self.cosine_sim_attn = cosine_sim_attn
235 | self.cosine_sim_scale = 16 if cosine_sim_attn else 1
236 |
237 | self.heads = heads
238 | inner_dim = dim_head * heads
239 |
240 | self.norm = nn.LayerNorm(dim)
241 | self.norm_latents = nn.LayerNorm(dim)
242 |
243 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
244 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
245 |
246 | self.to_out = nn.Sequential(
247 | nn.Linear(inner_dim, dim, bias = False),
248 | nn.LayerNorm(dim)
249 | )
250 |
251 | def forward(self, x, latents, mask = None):
252 | x = self.norm(x)
253 | latents = self.norm_latents(latents)
254 |
255 | b, h = x.shape[0], self.heads
256 |
257 | q = self.to_q(latents)
258 |
259 |         # the paper differs from Perceiver in that the keys / values derived from the latents are also concatenated, so the latents attend to themselves as well
260 | kv_input = torch.cat((x, latents), dim = -2)
261 | k, v = self.to_kv(kv_input).chunk(2, dim = -1)
262 |
263 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = h)
264 |
265 | q = q * self.scale
266 |
267 | # cosine sim attention
268 |
269 | if self.cosine_sim_attn:
270 | q, k = map(l2norm, (q, k))
271 |
272 | # similarities and masking
273 |
274 | sim = einsum('... i d, ... j d -> ... i j', q, k) * self.cosine_sim_scale
275 |
276 | if exists(mask):
277 | max_neg_value = -torch.finfo(sim.dtype).max
278 | mask = F.pad(mask, (0, latents.shape[-2]), value = True)
279 | mask = rearrange(mask, 'b j -> b 1 1 j')
280 | sim = sim.masked_fill(~mask, max_neg_value)
281 |
282 | # attention
283 |
284 | attn = sim.softmax(dim = -1)
285 |
286 | out = einsum('... i j, ... j d -> ... i d', attn, v)
287 | out = rearrange(out, 'b h n d -> b n (h d)', h = h)
288 | return self.to_out(out)
289 |
290 | class PerceiverResampler(nn.Module):
291 | def __init__(
292 | self,
293 | *,
294 | dim,
295 | depth,
296 | dim_head = 64,
297 | heads = 8,
298 | num_latents = 64,
299 | num_latents_mean_pooled = 4, # number of latents derived from mean pooled representation of the sequence
300 | max_seq_len = 512,
301 | ff_mult = 4,
302 | cosine_sim_attn = False
303 | ):
304 | super().__init__()
305 | self.pos_emb = nn.Embedding(max_seq_len, dim)
306 |
307 | self.latents = nn.Parameter(torch.randn(num_latents, dim))
308 |
309 | self.to_latents_from_mean_pooled_seq = None
310 |
311 | if num_latents_mean_pooled > 0:
312 | self.to_latents_from_mean_pooled_seq = nn.Sequential(
313 | LayerNorm(dim),
314 | nn.Linear(dim, dim * num_latents_mean_pooled),
315 | Rearrange('b (n d) -> b n d', n = num_latents_mean_pooled)
316 | )
317 |
318 | self.layers = nn.ModuleList([])
319 | for _ in range(depth):
320 | self.layers.append(nn.ModuleList([
321 | PerceiverAttention(dim = dim, dim_head = dim_head, heads = heads, cosine_sim_attn = cosine_sim_attn),
322 | FeedForward(dim = dim, mult = ff_mult)
323 | ]))
324 |
325 | def forward(self, x, mask = None):
326 | n, device = x.shape[1], x.device
327 | pos_emb = self.pos_emb(torch.arange(n, device = device))
328 |
329 | x_with_pos = x + pos_emb
330 |
331 | latents = repeat(self.latents, 'n d -> b n d', b = x.shape[0])
332 |
333 | if exists(self.to_latents_from_mean_pooled_seq):
334 | meanpooled_seq = masked_mean(x, dim = 1, mask = torch.ones(x.shape[:2], device = x.device, dtype = torch.bool))
335 | meanpooled_latents = self.to_latents_from_mean_pooled_seq(meanpooled_seq)
336 | latents = torch.cat((meanpooled_latents, latents), dim = -2)
337 |
338 | for attn, ff in self.layers:
339 | latents = attn(x_with_pos, latents, mask = mask) + latents
340 | latents = ff(latents) + latents
341 |
342 | return latents
343 |
344 | # attention
345 |
346 | class Attention(nn.Module):
347 | def __init__(
348 | self,
349 | dim,
350 | *,
351 | dim_head = 64,
352 | heads = 8,
353 | causal = False,
354 | context_dim = None,
355 | cosine_sim_attn = False
356 | ):
357 | super().__init__()
358 | self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
359 | self.causal = causal
360 |
361 | self.cosine_sim_attn = cosine_sim_attn
362 | self.cosine_sim_scale = 16 if cosine_sim_attn else 1
363 |
364 | self.heads = heads
365 | inner_dim = dim_head * heads
366 |
367 | self.norm = LayerNorm(dim)
368 |
369 | self.null_attn_bias = nn.Parameter(torch.randn(heads))
370 |
371 | self.null_kv = nn.Parameter(torch.randn(2, dim_head))
372 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
373 | self.to_kv = nn.Linear(dim, dim_head * 2, bias = False)
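    |         # note: keys / values are projected to a single head (multi-query style attention),
    |         # while queries remain multi-headed - this saves parameters and memory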
374 |
375 | self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, dim_head * 2)) if exists(context_dim) else None
376 |
377 | self.to_out = nn.Sequential(
378 | nn.Linear(inner_dim, dim, bias = False),
379 | LayerNorm(dim)
380 | )
381 |
382 | def forward(self, x, context = None, mask = None, attn_bias = None):
383 | b, n, device = *x.shape[:2], x.device
384 |
385 | x = self.norm(x)
386 | q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = -1))
387 |
388 | q = rearrange(q, 'b n (h d) -> b h n d', h = self.heads)
389 | q = q * self.scale
390 |
391 | # add null key / value for classifier free guidance in prior net
392 |
393 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b 1 d', b = b)
394 | k = torch.cat((nk, k), dim = -2)
395 | v = torch.cat((nv, v), dim = -2)
396 |
397 | # add text conditioning, if present
398 |
399 | if exists(context):
400 | assert exists(self.to_context)
401 | ck, cv = self.to_context(context).chunk(2, dim = -1)
402 | k = torch.cat((ck, k), dim = -2)
403 | v = torch.cat((cv, v), dim = -2)
404 |
405 | # cosine sim attention
406 |
407 | if self.cosine_sim_attn:
408 | q, k = map(l2norm, (q, k))
409 |
410 | # calculate query / key similarities
411 |
412 | sim = einsum('b h i d, b j d -> b h i j', q, k) * self.cosine_sim_scale
413 |
414 | # relative positional encoding (T5 style)
415 |
416 | if exists(attn_bias):
417 | null_attn_bias = repeat(self.null_attn_bias, 'h -> h n 1', n = n)
418 | attn_bias = torch.cat((null_attn_bias, attn_bias), dim = -1)
419 | sim = sim + attn_bias
420 |
421 | # masking
422 |
423 | max_neg_value = -torch.finfo(sim.dtype).max
424 |
425 | if self.causal:
426 | i, j = sim.shape[-2:]
427 | causal_mask = torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
428 | sim = sim.masked_fill(causal_mask, max_neg_value)
429 |
430 | if exists(mask):
431 | mask = F.pad(mask, (1, 0), value = True)
432 | mask = rearrange(mask, 'b j -> b 1 1 j')
433 | sim = sim.masked_fill(~mask, max_neg_value)
434 |
435 | # attention
436 |
437 | attn = sim.softmax(dim = -1)
438 |
439 | # aggregate values
440 |
441 | out = einsum('b h i j, b j d -> b h i d', attn, v)
442 |
443 | out = rearrange(out, 'b h n d -> b n (h d)')
444 | return self.to_out(out)
445 |
446 | # pseudo conv2d that uses conv3d but with kernel size of 1 across frames dimension
447 |
448 | def Conv2d(dim_in, dim_out, kernel, stride = 1, padding = 0, **kwargs):
449 | kernel = cast_tuple(kernel, 2)
450 | stride = cast_tuple(stride, 2)
451 | padding = cast_tuple(padding, 2)
452 |
453 | if len(kernel) == 2:
454 | kernel = (1, *kernel)
455 |
456 | if len(stride) == 2:
457 | stride = (1, *stride)
458 |
459 | if len(padding) == 2:
460 | padding = (0, *padding)
461 |
462 | return nn.Conv3d(dim_in, dim_out, kernel, stride = stride, padding = padding, **kwargs)
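    | # e.g. Conv2d(dim, dim_out, 3, padding = 1) yields an nn.Conv3d with kernel (1, 3, 3) and padding (0, 1, 1),
    | # i.e. a purely spatial convolution applied independently to every frame of the video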
463 |
464 | class Pad(nn.Module):
465 | def __init__(self, padding, value = 0.):
466 | super().__init__()
467 | self.padding = padding
468 | self.value = value
469 |
470 | def forward(self, x):
471 | return F.pad(x, self.padding, value = self.value)
472 |
473 | # decoder
474 |
475 | def Upsample(dim, dim_out = None):
476 | dim_out = default(dim_out, dim)
477 |
478 | return nn.Sequential(
479 | nn.Upsample(scale_factor = 2, mode = 'nearest'),
480 | Conv2d(dim, dim_out, 3, padding = 1)
481 | )
482 |
483 | class PixelShuffleUpsample(nn.Module):
484 | def __init__(self, dim, dim_out = None):
485 | super().__init__()
486 | dim_out = default(dim_out, dim)
487 | conv = Conv2d(dim, dim_out * 4, 1)
488 |
489 | self.net = nn.Sequential(
490 | conv,
491 | nn.SiLU()
492 | )
493 |
494 | self.pixel_shuffle = nn.PixelShuffle(2)
495 |
496 | self.init_conv_(conv)
497 |
498 | def init_conv_(self, conv):
499 | o, i, f, h, w = conv.weight.shape
500 | conv_weight = torch.empty(o // 4, i, f, h, w)
501 | nn.init.kaiming_uniform_(conv_weight)
502 | conv_weight = repeat(conv_weight, 'o ... -> (o 4) ...')
503 |
504 | conv.weight.data.copy_(conv_weight)
505 | nn.init.zeros_(conv.bias.data)
506 |
507 | def forward(self, x):
508 | out = self.net(x)
509 | frames = x.shape[2]
510 | out = rearrange(out, 'b c f h w -> (b f) c h w')
511 | out = self.pixel_shuffle(out)
512 | return rearrange(out, '(b f) c h w -> b c f h w', f = frames)
513 |
514 | def Downsample(dim, dim_out = None):
515 | dim_out = default(dim_out, dim)
516 | return Conv2d(dim, dim_out, 4, 2, 1)
517 |
518 | class SinusoidalPosEmb(nn.Module):
519 | def __init__(self, dim):
520 | super().__init__()
521 | self.dim = dim
522 |
523 | def forward(self, x):
524 | half_dim = self.dim // 2
525 | emb = math.log(10000) / (half_dim - 1)
526 | emb = torch.exp(torch.arange(half_dim, device = x.device) * -emb)
527 | emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')
528 | return torch.cat((emb.sin(), emb.cos()), dim = -1)
529 |
530 | class LearnedSinusoidalPosEmb(nn.Module):
531 | def __init__(self, dim):
532 | super().__init__()
533 | assert (dim % 2) == 0
534 | half_dim = dim // 2
535 | self.weights = nn.Parameter(torch.randn(half_dim))
536 |
537 | def forward(self, x):
538 | x = rearrange(x, 'b -> b 1')
539 | freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi
540 | fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1)
541 | fouriered = torch.cat((x, fouriered), dim = -1)
542 | return fouriered
543 |
544 | class Block(nn.Module):
545 | def __init__(
546 | self,
547 | dim,
548 | dim_out,
549 | groups = 8,
550 | norm = True
551 | ):
552 | super().__init__()
553 | self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
554 | self.activation = nn.SiLU()
555 | self.project = Conv2d(dim, dim_out, 3, padding = 1)
556 |
557 | def forward(self, x, scale_shift = None):
558 | x = self.groupnorm(x)
559 |
560 | if exists(scale_shift):
561 | scale, shift = scale_shift
562 | x = x * (scale + 1) + shift
563 |
564 | x = self.activation(x)
565 | return self.project(x)
566 |
567 | class ResnetBlock(nn.Module):
568 | def __init__(
569 | self,
570 | dim,
571 | dim_out,
572 | *,
573 | cond_dim = None,
574 | time_cond_dim = None,
575 | groups = 8,
576 | linear_attn = False,
577 | use_gca = False,
578 | squeeze_excite = False,
579 | **attn_kwargs
580 | ):
581 | super().__init__()
582 |
583 | self.time_mlp = None
584 |
585 | if exists(time_cond_dim):
586 | self.time_mlp = nn.Sequential(
587 | nn.SiLU(),
588 | nn.Linear(time_cond_dim, dim_out * 2)
589 | )
590 |
591 | self.cross_attn = None
592 |
593 | if exists(cond_dim):
594 | attn_klass = CrossAttention if not linear_attn else LinearCrossAttention
595 |
596 | self.cross_attn = EinopsToAndFrom(
597 | 'b c f h w',
598 | 'b (f h w) c',
599 | attn_klass(
600 | dim = dim_out,
601 | context_dim = cond_dim,
602 | **attn_kwargs
603 | )
604 | )
605 |
606 | self.block1 = Block(dim, dim_out, groups = groups)
607 | self.block2 = Block(dim_out, dim_out, groups = groups)
608 |
609 | self.gca = GlobalContext(dim_in = dim_out, dim_out = dim_out) if use_gca else Always(1)
610 |
611 | self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()
612 |
613 |
614 | def forward(self, x, time_emb = None, cond = None):
615 |
616 | scale_shift = None
617 | if exists(self.time_mlp) and exists(time_emb):
618 | time_emb = self.time_mlp(time_emb)
619 | time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
620 | scale_shift = time_emb.chunk(2, dim = 1)
621 |
622 | h = self.block1(x)
623 |
624 | if exists(self.cross_attn):
625 | assert exists(cond)
626 | h = self.cross_attn(h, context = cond) + h
627 |
628 | h = self.block2(h, scale_shift = scale_shift)
629 |
630 | h = h * self.gca(h)
631 |
632 | return h + self.res_conv(x)
633 |
634 | class CrossAttention(nn.Module):
635 | def __init__(
636 | self,
637 | dim,
638 | *,
639 | context_dim = None,
640 | dim_head = 64,
641 | heads = 8,
642 | norm_context = False,
643 | cosine_sim_attn = False
644 | ):
645 | super().__init__()
646 | self.scale = dim_head ** -0.5 if not cosine_sim_attn else 1.
647 | self.cosine_sim_attn = cosine_sim_attn
648 | self.cosine_sim_scale = 16 if cosine_sim_attn else 1
649 |
650 | self.heads = heads
651 | inner_dim = dim_head * heads
652 |
653 | context_dim = default(context_dim, dim)
654 |
655 | self.norm = LayerNorm(dim)
656 | self.norm_context = LayerNorm(context_dim) if norm_context else Identity()
657 |
658 | self.null_kv = nn.Parameter(torch.randn(2, dim_head))
659 | self.to_q = nn.Linear(dim, inner_dim, bias = False)
660 | self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias = False)
661 |
662 | self.to_out = nn.Sequential(
663 | nn.Linear(inner_dim, dim, bias = False),
664 | LayerNorm(dim)
665 | )
666 |
667 | def forward(self, x, context, mask = None):
668 | b, n, device = *x.shape[:2], x.device
669 |
670 | x = self.norm(x)
671 | context = self.norm_context(context)
672 |
673 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
674 |
675 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> b h n d', h = self.heads)
676 |
677 | # add null key / value for classifier free guidance in prior net
678 |
679 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> b h 1 d', h = self.heads, b = b)
680 |
681 | k = torch.cat((nk, k), dim = -2)
682 | v = torch.cat((nv, v), dim = -2)
683 |
684 | q = q * self.scale
685 |
686 | # cosine sim attention
687 |
688 | if self.cosine_sim_attn:
689 | q, k = map(l2norm, (q, k))
690 |
691 | # similarities
692 |
693 | sim = einsum('b h i d, b h j d -> b h i j', q, k) * self.cosine_sim_scale
694 |
695 | # masking
696 |
697 | max_neg_value = -torch.finfo(sim.dtype).max
698 |
699 | if exists(mask):
700 | mask = F.pad(mask, (1, 0), value = True)
701 | mask = rearrange(mask, 'b j -> b 1 1 j')
702 | sim = sim.masked_fill(~mask, max_neg_value)
703 |
704 | attn = sim.softmax(dim = -1, dtype = torch.float32)
705 |
706 | out = einsum('b h i j, b h j d -> b h i d', attn, v)
707 | out = rearrange(out, 'b h n d -> b n (h d)')
708 | return self.to_out(out)
709 |
710 | class LinearCrossAttention(CrossAttention):
711 | def forward(self, x, context, mask = None):
712 | b, n, device = *x.shape[:2], x.device
713 |
714 | x = self.norm(x)
715 | context = self.norm_context(context)
716 |
717 | q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
718 |
719 | q, k, v = rearrange_many((q, k, v), 'b n (h d) -> (b h) n d', h = self.heads)
720 |
721 | # add null key / value for classifier free guidance in prior net
722 |
723 | nk, nv = repeat_many(self.null_kv.unbind(dim = -2), 'd -> (b h) 1 d', h = self.heads, b = b)
724 |
725 | k = torch.cat((nk, k), dim = -2)
726 | v = torch.cat((nv, v), dim = -2)
727 |
728 | # masking
729 |
730 | max_neg_value = -torch.finfo(x.dtype).max
731 |
732 | if exists(mask):
733 | mask = F.pad(mask, (1, 0), value = True)
734 |             mask = repeat(mask, 'b n -> (b h) n 1', h = self.heads) # repeat mask across heads, since k / v are shaped ((b h), n, d)
735 | k = k.masked_fill(~mask, max_neg_value)
736 | v = v.masked_fill(~mask, 0.)
737 |
738 | # linear attention
739 |
740 | q = q.softmax(dim = -1)
741 | k = k.softmax(dim = -2)
742 |
743 | q = q * self.scale
744 |
745 | context = einsum('b n d, b n e -> b d e', k, v)
746 | out = einsum('b n d, b d e -> b n e', q, context)
747 | out = rearrange(out, '(b h) n d -> b n (h d)', h = self.heads)
748 | return self.to_out(out)
749 |
750 | class LinearAttention(nn.Module):
751 | def __init__(
752 | self,
753 | dim,
754 | dim_head = 32,
755 | heads = 8,
756 | dropout = 0.05,
757 | context_dim = None,
758 | **kwargs
759 | ):
760 | super().__init__()
761 | self.scale = dim_head ** -0.5
762 | self.heads = heads
763 | inner_dim = dim_head * heads
764 | self.norm = ChanLayerNorm(dim)
765 |
766 | self.nonlin = nn.SiLU()
767 |
768 | self.to_q = nn.Sequential(
769 | nn.Dropout(dropout),
770 | Conv2d(dim, inner_dim, 1, bias = False),
771 | Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
772 | )
773 |
774 | self.to_k = nn.Sequential(
775 | nn.Dropout(dropout),
776 | Conv2d(dim, inner_dim, 1, bias = False),
777 | Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
778 | )
779 |
780 | self.to_v = nn.Sequential(
781 | nn.Dropout(dropout),
782 | Conv2d(dim, inner_dim, 1, bias = False),
783 | Conv2d(inner_dim, inner_dim, 3, bias = False, padding = 1, groups = inner_dim)
784 | )
785 |
786 | self.to_context = nn.Sequential(nn.LayerNorm(context_dim), nn.Linear(context_dim, inner_dim * 2, bias = False)) if exists(context_dim) else None
787 |
788 | self.to_out = nn.Sequential(
789 | Conv2d(inner_dim, dim, 1, bias = False),
790 | ChanLayerNorm(dim)
791 | )
792 |
793 | def forward(self, fmap, context = None):
794 | h, x, y = self.heads, *fmap.shape[-2:]
795 |
796 | fmap = self.norm(fmap)
797 | q, k, v = map(lambda fn: fn(fmap), (self.to_q, self.to_k, self.to_v))
798 | q, k, v = rearrange_many((q, k, v), 'b (h c) x y -> (b h) (x y) c', h = h)
799 |
800 | if exists(context):
801 | assert exists(self.to_context)
802 | ck, cv = self.to_context(context).chunk(2, dim = -1)
803 | ck, cv = rearrange_many((ck, cv), 'b n (h d) -> (b h) n d', h = h)
804 | k = torch.cat((k, ck), dim = -2)
805 | v = torch.cat((v, cv), dim = -2)
806 |
807 | q = q.softmax(dim = -1)
808 | k = k.softmax(dim = -2)
809 |
810 | q = q * self.scale
811 |
812 | context = einsum('b n d, b n e -> b d e', k, v)
813 | out = einsum('b n d, b d e -> b n e', q, context)
814 | out = rearrange(out, '(b h) (x y) d -> b (h d) x y', h = h, x = x, y = y)
815 |
816 | out = self.nonlin(out)
817 | return self.to_out(out)
818 |
819 | class GlobalContext(nn.Module):
820 | """ basically a superior form of squeeze-excitation that is attention-esque """
821 |
822 | def __init__(
823 | self,
824 | *,
825 | dim_in,
826 | dim_out
827 | ):
828 | super().__init__()
829 | self.to_k = Conv2d(dim_in, 1, 1)
830 | hidden_dim = max(3, dim_out // 2)
831 |
832 | self.net = nn.Sequential(
833 | Conv2d(dim_in, hidden_dim, 1),
834 | nn.SiLU(),
835 | Conv2d(hidden_dim, dim_out, 1),
836 | nn.Sigmoid()
837 | )
838 |
839 | def forward(self, x):
840 | context = self.to_k(x)
841 | x, context = rearrange_many((x, context), 'b n ... -> b n (...)')
842 | out = einsum('b i n, b c n -> b c i', context.softmax(dim = -1), x)
843 | out = rearrange(out, '... -> ... 1 1')
844 | return self.net(out)
845 |
846 | def FeedForward(dim, mult = 2):
847 | hidden_dim = int(dim * mult)
848 | return nn.Sequential(
849 | LayerNorm(dim),
850 | nn.Linear(dim, hidden_dim, bias = False),
851 | nn.GELU(),
852 | LayerNorm(hidden_dim),
853 | nn.Linear(hidden_dim, dim, bias = False)
854 | )
855 |
856 | def ChanFeedForward(dim, mult = 2): # in the paper, it seems the feedforwards accompanying the self attention layers use twice the channel width
857 | hidden_dim = int(dim * mult)
858 | return nn.Sequential(
859 | ChanLayerNorm(dim),
860 | Conv2d(dim, hidden_dim, 1, bias = False),
861 | nn.GELU(),
862 | ChanLayerNorm(hidden_dim),
863 | Conv2d(hidden_dim, dim, 1, bias = False)
864 | )
865 |
866 | class TransformerBlock(nn.Module):
867 | def __init__(
868 | self,
869 | dim,
870 | *,
871 | depth = 1,
872 | heads = 8,
873 | dim_head = 32,
874 | ff_mult = 2,
875 | context_dim = None,
876 | cosine_sim_attn = False
877 | ):
878 | super().__init__()
879 | self.layers = nn.ModuleList([])
880 |
881 | for _ in range(depth):
882 | self.layers.append(nn.ModuleList([
883 | EinopsToAndFrom('b c f h w', 'b (f h w) c', Attention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim, cosine_sim_attn = cosine_sim_attn)),
884 | ChanFeedForward(dim = dim, mult = ff_mult)
885 | ]))
886 |
887 | def forward(self, x, context = None):
888 | for attn, ff in self.layers:
889 | x = attn(x, context = context) + x
890 | x = ff(x) + x
891 | return x
892 |
893 | class LinearAttentionTransformerBlock(nn.Module):
894 | def __init__(
895 | self,
896 | dim,
897 | *,
898 | depth = 1,
899 | heads = 8,
900 | dim_head = 32,
901 | ff_mult = 2,
902 | context_dim = None,
903 | **kwargs
904 | ):
905 | super().__init__()
906 | self.layers = nn.ModuleList([])
907 |
908 | for _ in range(depth):
909 | self.layers.append(nn.ModuleList([
910 | LinearAttention(dim = dim, heads = heads, dim_head = dim_head, context_dim = context_dim),
911 | ChanFeedForward(dim = dim, mult = ff_mult)
912 | ]))
913 |
914 | def forward(self, x, context = None):
915 | for attn, ff in self.layers:
916 | x = attn(x, context = context) + x
917 | x = ff(x) + x
918 | return x
919 |
920 | class CrossEmbedLayer(nn.Module):
921 | def __init__(
922 | self,
923 | dim_in,
924 | kernel_sizes,
925 | dim_out = None,
926 | stride = 2
927 | ):
928 | super().__init__()
929 | assert all([*map(lambda t: (t % 2) == (stride % 2), kernel_sizes)])
930 | dim_out = default(dim_out, dim_in)
931 |
932 | kernel_sizes = sorted(kernel_sizes)
933 | num_scales = len(kernel_sizes)
934 |
935 | # calculate the dimension at each scale
936 | dim_scales = [int(dim_out / (2 ** i)) for i in range(1, num_scales)]
937 | dim_scales = [*dim_scales, dim_out - sum(dim_scales)]
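    |         # e.g. dim_out = 128 with kernel_sizes (3, 7, 15) gives dim_scales = [64, 32, 32]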
938 |
939 | self.convs = nn.ModuleList([])
940 | for kernel, dim_scale in zip(kernel_sizes, dim_scales):
941 | self.convs.append(Conv2d(dim_in, dim_scale, kernel, stride = stride, padding = (kernel - stride) // 2))
942 |
943 | def forward(self, x):
944 | fmaps = tuple(map(lambda conv: conv(x), self.convs))
945 | return torch.cat(fmaps, dim = 1)
946 |
947 | class UpsampleCombiner(nn.Module):
948 | def __init__(
949 | self,
950 | dim,
951 | *,
952 | enabled = False,
953 | dim_ins = tuple(),
954 | dim_outs = tuple()
955 | ):
956 | super().__init__()
957 | dim_outs = cast_tuple(dim_outs, len(dim_ins))
958 | assert len(dim_ins) == len(dim_outs)
959 |
960 | self.enabled = enabled
961 |
962 | if not self.enabled:
963 | self.dim_out = dim
964 | return
965 |
966 | self.fmap_convs = nn.ModuleList([Block(dim_in, dim_out) for dim_in, dim_out in zip(dim_ins, dim_outs)])
967 | self.dim_out = dim + (sum(dim_outs) if len(dim_outs) > 0 else 0)
968 |
969 | def forward(self, x, fmaps = None):
970 | target_size = x.shape[-1]
971 |
972 | fmaps = default(fmaps, tuple())
973 |
974 | if not self.enabled or len(fmaps) == 0 or len(self.fmap_convs) == 0:
975 | return x
976 |
977 | fmaps = [resize_video_to(fmap, target_size) for fmap in fmaps]
978 | outs = [conv(fmap) for fmap, conv in zip(fmaps, self.fmap_convs)]
979 | return torch.cat((x, *outs), dim = 1)
980 |
981 | class DynamicPositionBias(nn.Module):
982 | def __init__(
983 | self,
984 | dim,
985 | *,
986 | heads,
987 | depth
988 | ):
989 | super().__init__()
990 | self.mlp = nn.ModuleList([])
991 |
992 | self.mlp.append(nn.Sequential(
993 | nn.Linear(1, dim),
994 | LayerNorm(dim),
995 | nn.SiLU()
996 | ))
997 |
998 | for _ in range(max(depth - 1, 0)):
999 | self.mlp.append(nn.Sequential(
1000 | nn.Linear(dim, dim),
1001 | LayerNorm(dim),
1002 | nn.SiLU()
1003 | ))
1004 |
1005 | self.mlp.append(nn.Linear(dim, heads))
1006 |
1007 | def forward(self, n, device, dtype):
1008 | i = torch.arange(n, device = device)
1009 | j = torch.arange(n, device = device)
1010 |
1011 | indices = rearrange(i, 'i -> i 1') - rearrange(j, 'j -> 1 j')
1012 | indices += (n - 1)
1013 |
1014 | pos = torch.arange(-n + 1, n, device = device, dtype = dtype)
1015 | pos = rearrange(pos, '... -> ... 1')
1016 |
1017 | for layer in self.mlp:
1018 | pos = layer(pos)
1019 |
1020 | bias = pos[indices]
1021 | bias = rearrange(bias, 'i j h -> h i j')
1022 | return bias
1023 |
1024 | class Unet3D(nn.Module):
1025 | def __init__(
1026 | self,
1027 | *,
1028 | dim,
1029 | image_embed_dim = 1024,
1030 | text_embed_dim = get_encoded_dim(DEFAULT_T5_NAME),
1031 | num_resnet_blocks = 1,
1032 | cond_dim = None,
1033 | num_image_tokens = 4,
1034 | num_time_tokens = 2,
1035 | learned_sinu_pos_emb_dim = 16,
1036 | out_dim = None,
1037 | dim_mults=(1, 2, 4, 8),
1038 | cond_images_channels = 0,
1039 | channels = 3,
1040 | channels_out = None,
1041 | attn_dim_head = 64,
1042 | attn_heads = 8,
1043 | ff_mult = 2.,
1044 | lowres_cond = False, # for cascading diffusion - https://cascaded-diffusion.github.io/
1045 | layer_attns = False,
1046 | layer_attns_depth = 1,
1047 | layer_attns_add_text_cond = True, # whether to condition the self-attention blocks with the text embeddings, as described in Appendix D.3.1
1048 | attend_at_middle = True, # whether to have a layer of attention at the bottleneck (can turn off for higher resolution in cascading DDPM, before bringing in efficient attention)
1049 | time_rel_pos_bias_depth = 2,
1050 | time_causal_attn = True,
1051 | layer_cross_attns = True,
1052 | use_linear_attn = False,
1053 | use_linear_cross_attn = False,
1054 | cond_on_text = True,
1055 | max_text_len = 256,
1056 | init_dim = None,
1057 | resnet_groups = 8,
1058 | init_conv_kernel_size = 7, # kernel size of initial conv, if not using cross embed
1059 | init_cross_embed = True,
1060 | init_cross_embed_kernel_sizes = (3, 7, 15),
1061 | cross_embed_downsample = False,
1062 | cross_embed_downsample_kernel_sizes = (2, 4),
1063 | attn_pool_text = True,
1064 | attn_pool_num_latents = 32,
1065 | dropout = 0.,
1066 | memory_efficient = False,
1067 | init_conv_to_final_conv_residual = False,
1068 | use_global_context_attn = True,
1069 | scale_skip_connection = True,
1070 | final_resnet_block = True,
1071 | final_conv_kernel_size = 3,
1072 | cosine_sim_attn = False,
1073 | self_cond = False,
1074 | combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
1075 |         pixel_shuffle_upsample = True # may address checkerboard artifacts
1076 | ):
1077 | super().__init__()
1078 |
1079 | # guide researchers
1080 |
1081 | assert attn_heads > 1, 'you need to have more than 1 attention head, ideally at least 4 or 8'
1082 |
1083 | if dim < 128:
1084 | print_once('The base dimension of your u-net should ideally be no smaller than 128, as recommended by a professional DDPM trainer https://nonint.com/2022/05/04/friends-dont-let-friends-train-small-diffusion-models/')
1085 |
1086 | # save locals to take care of some hyperparameters for cascading DDPM
1087 |
1088 | self._locals = locals()
1089 | self._locals.pop('self', None)
1090 | self._locals.pop('__class__', None)
1091 |
1092 | self.self_cond = self_cond
1093 |
1094 | # determine dimensions
1095 |
1096 | self.channels = channels
1097 | self.channels_out = default(channels_out, channels)
1098 |
1099 | # (1) in cascading diffusion, one concats the low resolution image, blurred, for conditioning the higher resolution synthesis
1100 | # (2) in self conditioning, one appends the predicted x0 (x_start)
1101 | init_channels = channels * (1 + int(lowres_cond) + int(self_cond))
1102 | init_dim = default(init_dim, dim)
1103 |
1104 | # optional image conditioning
1105 |
1106 | self.has_cond_image = cond_images_channels > 0
1107 | self.cond_images_channels = cond_images_channels
1108 |
1109 | init_channels += cond_images_channels
1110 |
1111 | # initial convolution
1112 |
1113 | self.init_conv = CrossEmbedLayer(init_channels, dim_out = init_dim, kernel_sizes = init_cross_embed_kernel_sizes, stride = 1) if init_cross_embed else Conv2d(init_channels, init_dim, init_conv_kernel_size, padding = init_conv_kernel_size // 2)
1114 |
1115 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
1116 | in_out = list(zip(dims[:-1], dims[1:]))
1117 |
1118 | # time conditioning
1119 |
1120 | cond_dim = default(cond_dim, dim)
1121 | time_cond_dim = dim * 4 * (2 if lowres_cond else 1)
1122 |
1123 | # embedding time for log(snr) noise from continuous version
1124 |
1125 | sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim)
1126 | sinu_pos_emb_input_dim = learned_sinu_pos_emb_dim + 1
1127 |
1128 | self.to_time_hiddens = nn.Sequential(
1129 | sinu_pos_emb,
1130 | nn.Linear(sinu_pos_emb_input_dim, time_cond_dim),
1131 | nn.SiLU()
1132 | )
1133 |
1134 | self.to_time_cond = nn.Sequential(
1135 | nn.Linear(time_cond_dim, time_cond_dim)
1136 | )
1137 |
1138 | # project to time tokens as well as time hiddens
1139 |
1140 | self.to_time_tokens = nn.Sequential(
1141 | nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
1142 | Rearrange('b (r d) -> b r d', r = num_time_tokens)
1143 | )
1144 |
1145 | # low res aug noise conditioning
1146 |
1147 | self.lowres_cond = lowres_cond
1148 |
1149 | if lowres_cond:
1150 | self.to_lowres_time_hiddens = nn.Sequential(
1151 | LearnedSinusoidalPosEmb(learned_sinu_pos_emb_dim),
1152 | nn.Linear(learned_sinu_pos_emb_dim + 1, time_cond_dim),
1153 | nn.SiLU()
1154 | )
1155 |
1156 | self.to_lowres_time_cond = nn.Sequential(
1157 | nn.Linear(time_cond_dim, time_cond_dim)
1158 | )
1159 |
1160 | self.to_lowres_time_tokens = nn.Sequential(
1161 | nn.Linear(time_cond_dim, cond_dim * num_time_tokens),
1162 | Rearrange('b (r d) -> b r d', r = num_time_tokens)
1163 | )
1164 |
1165 | # normalizations
1166 |
1167 | self.norm_cond = nn.LayerNorm(cond_dim)
1168 |
1169 | # text encoding conditioning (optional)
1170 |
1171 | self.text_to_cond = None
1172 |
1173 | if cond_on_text:
1174 | assert exists(text_embed_dim), 'text_embed_dim must be given to the unet if cond_on_text is True'
1175 | self.text_to_cond = nn.Linear(text_embed_dim, cond_dim)
1176 |
1177 | # finer control over whether to condition on text encodings
1178 |
1179 | self.cond_on_text = cond_on_text
1180 |
1181 | # attention pooling
1182 |
1183 | self.attn_pool = PerceiverResampler(dim = cond_dim, depth = 2, dim_head = attn_dim_head, heads = attn_heads, num_latents = attn_pool_num_latents, cosine_sim_attn = cosine_sim_attn) if attn_pool_text else None
1184 |
1185 | # for classifier free guidance
1186 |
1187 | self.max_text_len = max_text_len
1188 |
1189 | self.null_text_embed = nn.Parameter(torch.randn(1, max_text_len, cond_dim))
1190 | self.null_text_hidden = nn.Parameter(torch.randn(1, time_cond_dim))
1191 |
1192 | # for non-attention based text conditioning at all points in the network where time is also conditioned
1193 |
1194 | self.to_text_non_attn_cond = None
1195 |
1196 | if cond_on_text:
1197 | self.to_text_non_attn_cond = nn.Sequential(
1198 | nn.LayerNorm(cond_dim),
1199 | nn.Linear(cond_dim, time_cond_dim),
1200 | nn.SiLU(),
1201 | nn.Linear(time_cond_dim, time_cond_dim)
1202 | )
1203 |
1204 | # attention related params
1205 |
1206 | attn_kwargs = dict(heads = attn_heads, dim_head = attn_dim_head, cosine_sim_attn = cosine_sim_attn)
1207 |
1208 | num_layers = len(in_out)
1209 |
1210 | # temporal attention - attention across video frames
1211 |
1212 | temporal_peg_padding = (0, 0, 0, 0, 2, 0) if time_causal_attn else (0, 0, 0, 0, 1, 1)
1213 | temporal_peg = lambda dim: Residual(nn.Sequential(Pad(temporal_peg_padding), nn.Conv3d(dim, dim, (3, 1, 1), groups = dim)))
1214 |
1215 | temporal_attn = lambda dim: EinopsToAndFrom('b c f h w', '(b h w) f c', Residual(Attention(dim, **{**attn_kwargs, 'causal': time_causal_attn})))
1216 |
1217 | # temporal attention relative positional encoding
1218 |
1219 | self.time_rel_pos_bias = DynamicPositionBias(dim = dim * 2, heads = attn_heads, depth = time_rel_pos_bias_depth)
1220 |
1221 | # resnet block klass
1222 |
1223 | num_resnet_blocks = cast_tuple(num_resnet_blocks, num_layers)
1224 | resnet_groups = cast_tuple(resnet_groups, num_layers)
1225 |
1226 | resnet_klass = partial(ResnetBlock, **attn_kwargs)
1227 |
1228 | layer_attns = cast_tuple(layer_attns, num_layers)
1229 | layer_attns_depth = cast_tuple(layer_attns_depth, num_layers)
1230 | layer_cross_attns = cast_tuple(layer_cross_attns, num_layers)
1231 |
1232 | assert all([layers == num_layers for layers in list(map(len, (resnet_groups, layer_attns, layer_cross_attns)))])
1233 |
1234 | # downsample klass
1235 |
1236 | downsample_klass = Downsample
1237 |
1238 | if cross_embed_downsample:
1239 | downsample_klass = partial(CrossEmbedLayer, kernel_sizes = cross_embed_downsample_kernel_sizes)
1240 |
1241 | # initial resnet block (for memory efficient unet)
1242 |
1243 | self.init_resnet_block = resnet_klass(init_dim, init_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = use_global_context_attn) if memory_efficient else None
1244 |
1245 | self.init_temporal_peg = temporal_peg(init_dim)
1246 | self.init_temporal_attn = temporal_attn(init_dim)
1247 |
1248 | # scale for resnet skip connections
1249 |
1250 | self.skip_connect_scale = 1. if not scale_skip_connection else (2 ** -0.5)
1251 |
1252 | # layers
1253 |
1254 | self.downs = nn.ModuleList([])
1255 | self.ups = nn.ModuleList([])
1256 | num_resolutions = len(in_out)
1257 |
1258 | layer_params = [num_resnet_blocks, resnet_groups, layer_attns, layer_attns_depth, layer_cross_attns]
1259 | reversed_layer_params = list(map(reversed, layer_params))
1260 |
1261 | # downsampling layers
1262 |
1263 | skip_connect_dims = [] # keep track of skip connection dimensions
1264 |
1265 | for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(in_out, *layer_params)):
1266 | is_last = ind >= (num_resolutions - 1)
1267 |
1268 | layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
1269 | layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
1270 |
1271 | transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)
1272 |
1273 | current_dim = dim_in
1274 |
1275 | # whether to pre-downsample, from memory efficient unet
1276 |
1277 | pre_downsample = None
1278 |
1279 | if memory_efficient:
1280 | pre_downsample = downsample_klass(dim_in, dim_out)
1281 | current_dim = dim_out
1282 |
1283 | skip_connect_dims.append(current_dim)
1284 |
1285 | # whether to do post-downsample, for non-memory efficient unet
1286 |
1287 | post_downsample = None
1288 | if not memory_efficient:
1289 | post_downsample = downsample_klass(current_dim, dim_out) if not is_last else Parallel(Conv2d(dim_in, dim_out, 3, padding = 1), Conv2d(dim_in, dim_out, 1))
1290 |
1291 | self.downs.append(nn.ModuleList([
1292 | pre_downsample,
1293 | resnet_klass(current_dim, current_dim, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
1294 | nn.ModuleList([ResnetBlock(current_dim, current_dim, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
1295 | transformer_block_klass(dim = current_dim, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
1296 | temporal_peg(current_dim),
1297 | temporal_attn(current_dim),
1298 | post_downsample
1299 | ]))
1300 |
1301 | # middle layers
1302 |
1303 | mid_dim = dims[-1]
1304 |
1305 | self.mid_block1 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
1306 | self.mid_attn = EinopsToAndFrom('b c f h w', 'b (f h w) c', Residual(Attention(mid_dim, **attn_kwargs))) if attend_at_middle else None
1307 | self.mid_temporal_peg = temporal_peg(mid_dim)
1308 | self.mid_temporal_attn = temporal_attn(mid_dim)
1309 | self.mid_block2 = ResnetBlock(mid_dim, mid_dim, cond_dim = cond_dim, time_cond_dim = time_cond_dim, groups = resnet_groups[-1])
1310 |
1311 | # upsample klass
1312 |
1313 | upsample_klass = Upsample if not pixel_shuffle_upsample else PixelShuffleUpsample
1314 |
1315 | # upsampling layers
1316 |
1317 | upsample_fmap_dims = []
1318 |
1319 | for ind, ((dim_in, dim_out), layer_num_resnet_blocks, groups, layer_attn, layer_attn_depth, layer_cross_attn) in enumerate(zip(reversed(in_out), *reversed_layer_params)):
1320 | is_last = ind == (len(in_out) - 1)
1321 | layer_use_linear_cross_attn = not layer_cross_attn and use_linear_cross_attn
1322 | layer_cond_dim = cond_dim if layer_cross_attn or layer_use_linear_cross_attn else None
1323 | transformer_block_klass = TransformerBlock if layer_attn else (LinearAttentionTransformerBlock if use_linear_attn else Identity)
1324 |
1325 | skip_connect_dim = skip_connect_dims.pop()
1326 |
1327 | upsample_fmap_dims.append(dim_out)
1328 |
1329 | self.ups.append(nn.ModuleList([
1330 | resnet_klass(dim_out + skip_connect_dim, dim_out, cond_dim = layer_cond_dim, linear_attn = layer_use_linear_cross_attn, time_cond_dim = time_cond_dim, groups = groups),
1331 | nn.ModuleList([ResnetBlock(dim_out + skip_connect_dim, dim_out, time_cond_dim = time_cond_dim, groups = groups, use_gca = use_global_context_attn) for _ in range(layer_num_resnet_blocks)]),
1332 | transformer_block_klass(dim = dim_out, depth = layer_attn_depth, ff_mult = ff_mult, context_dim = cond_dim, **attn_kwargs),
1333 | temporal_peg(dim_out),
1334 | temporal_attn(dim_out),
1335 | upsample_klass(dim_out, dim_in) if not is_last or memory_efficient else Identity()
1336 | ]))
1337 |
1338 | # whether to combine feature maps from all upsample blocks before final resnet block out
1339 |
1340 | self.upsample_combiner = UpsampleCombiner(
1341 | dim = dim,
1342 | enabled = combine_upsample_fmaps,
1343 | dim_ins = upsample_fmap_dims,
1344 | dim_outs = dim
1345 | )
1346 |
1347 | # whether to do a final residual from initial conv to the final resnet block out
1348 |
1349 | self.init_conv_to_final_conv_residual = init_conv_to_final_conv_residual
1350 | final_conv_dim = self.upsample_combiner.dim_out + (dim if init_conv_to_final_conv_residual else 0)
1351 |
1352 | # final optional resnet block and convolution out
1353 |
1354 | self.final_res_block = ResnetBlock(final_conv_dim, dim, time_cond_dim = time_cond_dim, groups = resnet_groups[0], use_gca = True) if final_resnet_block else None
1355 |
1356 | final_conv_dim_in = dim if final_resnet_block else final_conv_dim
1357 | final_conv_dim_in += (channels if lowres_cond else 0)
1358 |
1359 | self.final_conv = Conv2d(final_conv_dim_in, self.channels_out, final_conv_kernel_size, padding = final_conv_kernel_size // 2)
1360 |
1361 | zero_init_(self.final_conv)
1362 |
1363 | # if the current settings for the unet are not correct
1364 | # for cascading DDPM, then reinit the unet with the right settings
1365 | def cast_model_parameters(
1366 | self,
1367 | *,
1368 | lowres_cond,
1369 | text_embed_dim,
1370 | channels,
1371 | channels_out,
1372 | cond_on_text
1373 | ):
1374 | if lowres_cond == self.lowres_cond and \
1375 | channels == self.channels and \
1376 | cond_on_text == self.cond_on_text and \
1377 | text_embed_dim == self._locals['text_embed_dim'] and \
1378 | channels_out == self.channels_out:
1379 | return self
1380 |
1381 | updated_kwargs = dict(
1382 | lowres_cond = lowres_cond,
1383 | text_embed_dim = text_embed_dim,
1384 | channels = channels,
1385 | channels_out = channels_out,
1386 | cond_on_text = cond_on_text
1387 | )
1388 |
1389 | return self.__class__(**{**self._locals, **updated_kwargs})
1390 |
1391 | # methods for returning the full unet config as well as its parameter state
1392 |
1393 | def to_config_and_state_dict(self):
1394 | return self._locals, self.state_dict()
1395 |
1396 | # class method for rehydrating the unet from its config and state dict
1397 |
1398 | @classmethod
1399 | def from_config_and_state_dict(klass, config, state_dict):
1400 | unet = klass(**config)
1401 | unet.load_state_dict(state_dict)
1402 | return unet
1403 |
1404 | # methods for persisting unet to disk
1405 |
1406 | def persist_to_file(self, path):
1407 | path = Path(path)
1408 | path.parents[0].mkdir(exist_ok = True, parents = True)
1409 |
1410 | config, state_dict = self.to_config_and_state_dict()
1411 | pkg = dict(config = config, state_dict = state_dict)
1412 | torch.save(pkg, str(path))
1413 |
1414 | # class method for rehydrating the unet from file saved with `persist_to_file`
1415 |
1416 | @classmethod
1417 | def hydrate_from_file(klass, path):
1418 | path = Path(path)
1419 | assert path.exists()
1420 | pkg = torch.load(str(path))
1421 |
1422 | assert 'config' in pkg and 'state_dict' in pkg
1423 | config, state_dict = pkg['config'], pkg['state_dict']
1424 |
1425 | return klass.from_config_and_state_dict(config, state_dict)
1426 |
1427 | # forward with classifier free guidance
1428 |
1429 | def forward_with_cond_scale(
1430 | self,
1431 | *args,
1432 | cond_scale = 1.,
1433 | **kwargs
1434 | ):
1435 | logits = self.forward(*args, **kwargs)
1436 |
1437 | if cond_scale == 1:
1438 | return logits
1439 |
1440 | null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
1441 | return null_logits + (logits - null_logits) * cond_scale
1442 |
1443 | def forward(
1444 | self,
1445 | x,
1446 | time,
1447 | *,
1448 | lowres_cond_img = None,
1449 | lowres_noise_times = None,
1450 | text_embeds = None,
1451 | text_mask = None,
1452 | cond_images = None,
1453 | self_cond = None,
1454 | cond_drop_prob = 0.
1455 | ):
1456 | assert x.ndim == 5, 'input to 3d unet must have 5 dimensions (batch, channels, time, height, width)'
1457 |
1458 | batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype
1459 |
1460 | # add self conditioning if needed
1461 |
1462 | if self.self_cond:
1463 | self_cond = default(self_cond, lambda: torch.zeros_like(x))
1464 | x = torch.cat((x, self_cond), dim = 1)
1465 |
1466 | # add low resolution conditioning, if present
1467 |
1468 | assert not (self.lowres_cond and not exists(lowres_cond_img)), 'low resolution conditioning image must be present'
1469 | assert not (self.lowres_cond and not exists(lowres_noise_times)), 'low resolution conditioning noise time must be present'
1470 |
1471 | if exists(lowres_cond_img):
1472 | x = torch.cat((x, lowres_cond_img), dim = 1)
1473 |
1474 | # condition on input image
1475 |
1476 | assert not (self.has_cond_image ^ exists(cond_images)), 'you either requested to condition on an image for the unet, but the conditioning image was not supplied, or vice versa'
1477 |
1478 | if exists(cond_images):
1479 | assert cond_images.shape[1] == self.cond_images_channels, 'the number of channels on the conditioning image you are passing in does not match what you specified on initialization of the unet'
1480 | cond_images = resize_video_to(cond_images, x.shape[-1])
1481 | x = torch.cat((cond_images, x), dim = 1)
1482 |
1483 | # get time relative positions
1484 |
1485 | time_attn_bias = self.time_rel_pos_bias(frames, device = device, dtype = dtype)
1486 |
1487 | # initial convolution
1488 |
1489 | x = self.init_conv(x)
1490 |
1491 | x = self.init_temporal_peg(x)
1492 | x = self.init_temporal_attn(x, attn_bias = time_attn_bias)
1493 |
1494 | # init conv residual
1495 |
1496 | if self.init_conv_to_final_conv_residual:
1497 | init_conv_residual = x.clone()
1498 |
1499 | # time conditioning
1500 |
1501 | time_hiddens = self.to_time_hiddens(time)
1502 |
1503 | # derive time tokens
1504 |
1505 | time_tokens = self.to_time_tokens(time_hiddens)
1506 | t = self.to_time_cond(time_hiddens)
1507 |
1508 | # add lowres time conditioning to time hiddens
1509 | # and add lowres time tokens along sequence dimension for attention
1510 |
1511 | if self.lowres_cond:
1512 | lowres_time_hiddens = self.to_lowres_time_hiddens(lowres_noise_times)
1513 | lowres_time_tokens = self.to_lowres_time_tokens(lowres_time_hiddens)
1514 | lowres_t = self.to_lowres_time_cond(lowres_time_hiddens)
1515 |
1516 | t = t + lowres_t
1517 | time_tokens = torch.cat((time_tokens, lowres_time_tokens), dim = -2)
1518 |
1519 | # text conditioning
1520 |
1521 | text_tokens = None
1522 |
1523 | if exists(text_embeds) and self.cond_on_text:
1524 |
1525 | # conditional dropout
1526 |
1527 | text_keep_mask = prob_mask_like((batch_size,), 1 - cond_drop_prob, device = device)
1528 |
1529 | text_keep_mask_embed = rearrange(text_keep_mask, 'b -> b 1 1')
1530 | text_keep_mask_hidden = rearrange(text_keep_mask, 'b -> b 1')
1531 |
1532 | # calculate text embeds
1533 |
1534 | text_tokens = self.text_to_cond(text_embeds)
1535 |
1536 | text_tokens = text_tokens[:, :self.max_text_len]
1537 |
1538 | if exists(text_mask):
1539 | text_mask = text_mask[:, :self.max_text_len]
1540 |
1541 | text_tokens_len = text_tokens.shape[1]
1542 | remainder = self.max_text_len - text_tokens_len
1543 |
1544 | if remainder > 0:
1545 | text_tokens = F.pad(text_tokens, (0, 0, 0, remainder))
1546 |
1547 | if exists(text_mask):
1548 | if remainder > 0:
1549 | text_mask = F.pad(text_mask, (0, remainder), value = False)
1550 |
1551 | text_mask = rearrange(text_mask, 'b n -> b n 1')
1552 | text_keep_mask_embed = text_mask & text_keep_mask_embed
1553 |
1554 | null_text_embed = self.null_text_embed.to(text_tokens.dtype) # for some reason pytorch AMP not working
1555 |
1556 | text_tokens = torch.where(
1557 | text_keep_mask_embed,
1558 | text_tokens,
1559 | null_text_embed
1560 | )
1561 |
1562 | if exists(self.attn_pool):
1563 | text_tokens = self.attn_pool(text_tokens)
1564 |
1565 | # extra non-attention conditioning by projecting and then summing text embeddings to time
1566 | # termed as text hiddens
1567 |
1568 | mean_pooled_text_tokens = text_tokens.mean(dim = -2)
1569 |
1570 | text_hiddens = self.to_text_non_attn_cond(mean_pooled_text_tokens)
1571 |
1572 | null_text_hidden = self.null_text_hidden.to(t.dtype)
1573 |
1574 | text_hiddens = torch.where(
1575 | text_keep_mask_hidden,
1576 | text_hiddens,
1577 | null_text_hidden
1578 | )
1579 |
1580 | t = t + text_hiddens
1581 |
1582 | # main conditioning tokens (c)
1583 |
1584 | c = time_tokens if not exists(text_tokens) else torch.cat((time_tokens, text_tokens), dim = -2)
1585 |
1586 | # normalize conditioning tokens
1587 |
1588 | c = self.norm_cond(c)
1589 |
1590 | # initial resnet block (for memory efficient unet)
1591 |
1592 | if exists(self.init_resnet_block):
1593 | x = self.init_resnet_block(x, t)
1594 |
1595 | # go through the layers of the unet, down and up
1596 |
1597 | hiddens = []
1598 |
1599 | for pre_downsample, init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, post_downsample in self.downs:
1600 | if exists(pre_downsample):
1601 | x = pre_downsample(x)
1602 |
1603 | x = init_block(x, t, c)
1604 |
1605 | for resnet_block in resnet_blocks:
1606 | x = resnet_block(x, t)
1607 | hiddens.append(x)
1608 |
1609 | x = attn_block(x, c)
1610 | x = temporal_peg(x)
1611 | x = temporal_attn(x, attn_bias = time_attn_bias)
1612 |
1613 | hiddens.append(x)
1614 |
1615 | if exists(post_downsample):
1616 | x = post_downsample(x)
1617 |
1618 | x = self.mid_block1(x, t, c)
1619 |
1620 | if exists(self.mid_attn):
1621 | x = self.mid_attn(x)
1622 |
1623 | x = self.mid_temporal_peg(x)
1624 | x = self.mid_temporal_attn(x, attn_bias = time_attn_bias)
1625 |
1626 | x = self.mid_block2(x, t, c)
1627 |
1628 | add_skip_connection = lambda x: torch.cat((x, hiddens.pop() * self.skip_connect_scale), dim = 1)
1629 |
1630 | up_hiddens = []
1631 |
1632 | for init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, upsample in self.ups:
1633 | x = add_skip_connection(x)
1634 | x = init_block(x, t, c)
1635 |
1636 | for resnet_block in resnet_blocks:
1637 | x = add_skip_connection(x)
1638 | x = resnet_block(x, t)
1639 |
1640 | x = attn_block(x, c)
1641 | x = temporal_peg(x)
1642 | x = temporal_attn(x, attn_bias = time_attn_bias)
1643 |
1644 | up_hiddens.append(x.contiguous())
1645 | x = upsample(x)
1646 |
1647 | # whether to combine all feature maps from upsample blocks
1648 |
1649 | x = self.upsample_combiner(x, up_hiddens)
1650 |
1651 | # final top-most residual if needed
1652 |
1653 | if self.init_conv_to_final_conv_residual:
1654 | x = torch.cat((x, init_conv_residual), dim = 1)
1655 |
1656 | if exists(self.final_res_block):
1657 | x = self.final_res_block(x, t)
1658 |
1659 | if exists(lowres_cond_img):
1660 | x = torch.cat((x, lowres_cond_img), dim = 1)
1661 |
1662 | return self.final_conv(x)
1663 |
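
A minimal smoke-test sketch for the 3D unet above (not part of the repository files). Shapes and keyword names follow the constructor and forward signature shown in this file; the top-level import path is an assumption and may need to be imagen_pytorch.imagen_video instead.

import torch
from imagen_pytorch import Unet3D   # assumed re-export; otherwise import from imagen_pytorch.imagen_video

unet = Unet3D(
    dim = 64,                 # small on purpose; triggers the friendly dim < 128 warning
    dim_mults = (1, 2, 4),
    num_resnet_blocks = 1,
    cond_on_text = True,
    text_embed_dim = 768      # matches t5-v1_1-base (see t5.py below)
)

video       = torch.randn(2, 3, 4, 32, 32)   # (batch, channels, frames, height, width)
times       = torch.rand(2)                  # one noise time per sample
text_embeds = torch.randn(2, 8, 768)         # (batch, text sequence length, text_embed_dim)
text_mask   = torch.ones(2, 8).bool()

pred = unet(video, times, text_embeds = text_embeds, text_mask = text_mask)
print(pred.shape)                            # expected: torch.Size([2, 3, 4, 32, 32])
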
--------------------------------------------------------------------------------
/imagen_pytorch/t5.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import transformers
3 | from typing import List
4 | from transformers import T5Tokenizer, T5EncoderModel, T5Config
5 | from einops import rearrange
6 |
7 | transformers.logging.set_verbosity_error()
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 | def default(val, d):
13 | if exists(val):
14 | return val
15 | return d() if callable(d) else d
16 |
17 | # config
18 |
19 | MAX_LENGTH = 256
20 |
21 | DEFAULT_T5_NAME = 'google/t5-v1_1-base'
22 |
23 | T5_CONFIGS = {}
24 |
25 | # singleton globals
26 |
27 | def get_tokenizer(name):
28 | tokenizer = T5Tokenizer.from_pretrained(name, model_max_length=MAX_LENGTH)
29 | return tokenizer
30 |
31 | def get_model(name):
32 | model = T5EncoderModel.from_pretrained(name)
33 | return model
34 |
35 | def get_model_and_tokenizer(name):
36 | global T5_CONFIGS
37 |
38 | if name not in T5_CONFIGS:
39 | T5_CONFIGS[name] = dict()
40 | if "model" not in T5_CONFIGS[name]:
41 | T5_CONFIGS[name]["model"] = get_model(name)
42 | if "tokenizer" not in T5_CONFIGS[name]:
43 | T5_CONFIGS[name]["tokenizer"] = get_tokenizer(name)
44 |
45 | return T5_CONFIGS[name]['model'], T5_CONFIGS[name]['tokenizer']
46 |
47 | def get_encoded_dim(name):
48 | if name not in T5_CONFIGS:
49 | # avoids loading the model if we only want to get the dim
50 | config = T5Config.from_pretrained(name)
51 | T5_CONFIGS[name] = dict(config=config)
52 | elif "config" in T5_CONFIGS[name]:
53 | config = T5_CONFIGS[name]["config"]
54 | elif "model" in T5_CONFIGS[name]:
55 | config = T5_CONFIGS[name]["model"].config
56 | else:
57 | assert False
58 | return config.d_model
59 |
60 | # encoding text
61 |
62 | def t5_tokenize(
63 | texts: List[str],
64 | name = DEFAULT_T5_NAME
65 | ):
66 | t5, tokenizer = get_model_and_tokenizer(name)
67 |
68 | if torch.cuda.is_available():
69 | t5 = t5.cuda()
70 |
71 | device = next(t5.parameters()).device
72 |
73 | encoded = tokenizer.batch_encode_plus(
74 | texts,
75 | return_tensors = "pt",
76 | padding = 'longest',
77 | max_length = MAX_LENGTH,
78 | truncation = True
79 | )
80 |
81 | input_ids = encoded.input_ids.to(device)
82 | attn_mask = encoded.attention_mask.to(device)
83 | return input_ids, attn_mask
84 |
85 | def t5_encode_tokenized_text(
86 | token_ids,
87 | attn_mask = None,
88 | pad_id = None,
89 | name = DEFAULT_T5_NAME
90 | ):
91 | assert exists(attn_mask) or exists(pad_id)
92 | t5, _ = get_model_and_tokenizer(name)
93 |
94 | attn_mask = default(attn_mask, lambda: (token_ids != pad_id).long())
95 |
96 | t5.eval()
97 |
98 | with torch.no_grad():
99 | output = t5(input_ids = token_ids, attention_mask = attn_mask)
100 | encoded_text = output.last_hidden_state.detach()
101 |
102 | attn_mask = attn_mask.bool()
103 |
104 | encoded_text = encoded_text.masked_fill(~rearrange(attn_mask, '... -> ... 1'), 0.) # force all embeddings at padding positions to be 0.
105 | return encoded_text
106 |
107 | def t5_encode_text(
108 | texts: List[str],
109 | name = DEFAULT_T5_NAME,
110 | return_attn_mask = False
111 | ):
112 | token_ids, attn_mask = t5_tokenize(texts, name = name)
113 | encoded_text = t5_encode_tokenized_text(token_ids, attn_mask = attn_mask, name = name)
114 |
115 | if return_attn_mask:
116 | attn_mask = attn_mask.bool()
117 | return encoded_text, attn_mask
118 |
119 | return encoded_text
120 |
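
A short usage sketch for the text-encoding helpers above (not part of the repository files). The first call downloads google/t5-v1_1-base; get_encoded_dim only loads the model config, so it can be used to size a network without pulling the encoder weights.

from imagen_pytorch.t5 import t5_encode_text, get_encoded_dim

dim = get_encoded_dim('google/t5-v1_1-base')    # 768 for t5-v1_1-base

embeds, mask = t5_encode_text(
    ['a cat chasing a red laser dot', 'a watercolor painting of a fox'],
    return_attn_mask = True
)
print(dim, embeds.shape, mask.shape)            # embeddings are zeroed at padded positions
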
--------------------------------------------------------------------------------
/imagen_pytorch/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import copy
4 | from pathlib import Path
5 | from math import ceil
6 | from contextlib import contextmanager, nullcontext
7 | from functools import partial, wraps
8 | from collections.abc import Iterable
9 |
10 | import torch
11 | from torch import nn
12 | import torch.nn.functional as F
13 | from torch.utils.data import random_split, DataLoader
14 | from torch.optim import Adam
15 | from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
16 | from torch.cuda.amp import autocast, GradScaler
17 |
18 | import pytorch_warmup as warmup
19 |
20 | from imagen_pytorch.imagen_pytorch import Imagen, NullUnet
21 | from imagen_pytorch.elucidated_imagen import ElucidatedImagen
22 | from imagen_pytorch.data import cycle
23 |
24 | from imagen_pytorch.version import __version__
25 | from packaging import version
26 |
27 | import numpy as np
28 |
29 | from ema_pytorch import EMA
30 |
31 | from accelerate import Accelerator, DistributedType, DistributedDataParallelKwargs
32 |
33 | from fsspec.core import url_to_fs
34 | from fsspec.implementations.local import LocalFileSystem
35 |
36 | # helper functions
37 |
38 | def exists(val):
39 | return val is not None
40 |
41 | def default(val, d):
42 | if exists(val):
43 | return val
44 | return d() if callable(d) else d
45 |
46 | def cast_tuple(val, length = 1):
47 | if isinstance(val, list):
48 | val = tuple(val)
49 |
50 | return val if isinstance(val, tuple) else ((val,) * length)
51 |
52 | def find_first(fn, arr):
53 | for ind, el in enumerate(arr):
54 | if fn(el):
55 | return ind
56 | return -1
57 |
58 | def pick_and_pop(keys, d):
59 | values = list(map(lambda key: d.pop(key), keys))
60 | return dict(zip(keys, values))
61 |
62 | def group_dict_by_key(cond, d):
63 | return_val = [dict(),dict()]
64 | for key in d.keys():
65 | match = bool(cond(key))
66 | ind = int(not match)
67 | return_val[ind][key] = d[key]
68 | return (*return_val,)
69 |
70 | def string_begins_with(prefix, str):
71 | return str.startswith(prefix)
72 |
73 | def group_by_key_prefix(prefix, d):
74 | return group_dict_by_key(partial(string_begins_with, prefix), d)
75 |
76 | def groupby_prefix_and_trim(prefix, d):
77 | kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
78 | kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
79 | return kwargs_without_prefix, kwargs
80 |
81 | def num_to_groups(num, divisor):
82 | groups = num // divisor
83 | remainder = num % divisor
84 | arr = [divisor] * groups
85 | if remainder > 0:
86 | arr.append(remainder)
87 | return arr
88 |
89 | # url to fs, bucket, path - for checkpointing to cloud
90 |
91 | def url_to_bucket(url):
92 | if '://' not in url:
93 | return url
94 |
95 | prefix, suffix = url.split('://')
96 |
97 | if prefix in {'gs', 's3'}:
98 | return suffix.split('/')[0]
99 | else:
100 | raise ValueError(f'storage type prefix "{prefix}" is not supported yet')
101 |
102 | # decorators
103 |
104 | def eval_decorator(fn):
105 | def inner(model, *args, **kwargs):
106 | was_training = model.training
107 | model.eval()
108 | out = fn(model, *args, **kwargs)
109 | model.train(was_training)
110 | return out
111 | return inner
112 |
113 | def cast_torch_tensor(fn, cast_fp16 = False):
114 | @wraps(fn)
115 | def inner(model, *args, **kwargs):
116 | device = kwargs.pop('_device', model.device)
117 | cast_device = kwargs.pop('_cast_device', True)
118 |
119 | should_cast_fp16 = cast_fp16 and model.cast_half_at_training
120 |
121 | kwargs_keys = kwargs.keys()
122 | all_args = (*args, *kwargs.values())
123 | split_kwargs_index = len(all_args) - len(kwargs_keys)
124 | all_args = tuple(map(lambda t: torch.from_numpy(t) if exists(t) and isinstance(t, np.ndarray) else t, all_args))
125 |
126 | if cast_device:
127 | all_args = tuple(map(lambda t: t.to(device) if exists(t) and isinstance(t, torch.Tensor) else t, all_args))
128 |
129 | if should_cast_fp16:
130 | all_args = tuple(map(lambda t: t.half() if exists(t) and isinstance(t, torch.Tensor) and t.dtype != torch.bool else t, all_args))
131 |
132 | args, kwargs_values = all_args[:split_kwargs_index], all_args[split_kwargs_index:]
133 | kwargs = dict(tuple(zip(kwargs_keys, kwargs_values)))
134 |
135 | out = fn(model, *args, **kwargs)
136 | return out
137 | return inner
138 |
139 | # gradient accumulation functions
140 |
141 | def split_iterable(it, split_size):
142 | accum = []
143 | for ind in range(ceil(len(it) / split_size)):
144 | start_index = ind * split_size
145 | accum.append(it[start_index: (start_index + split_size)])
146 | return accum
147 |
148 | def split(t, split_size = None):
149 | if not exists(split_size):
150 | return t
151 |
152 | if isinstance(t, torch.Tensor):
153 | return t.split(split_size, dim = 0)
154 |
155 | if isinstance(t, Iterable):
156 | return split_iterable(t, split_size)
157 |
158 | raise TypeError
159 |
160 | def find_first(cond, arr):
161 | for el in arr:
162 | if cond(el):
163 | return el
164 | return None
165 |
166 | def split_args_and_kwargs(*args, split_size = None, **kwargs):
167 | all_args = (*args, *kwargs.values())
168 | len_all_args = len(all_args)
169 | first_tensor = find_first(lambda t: isinstance(t, torch.Tensor), all_args)
170 | assert exists(first_tensor)
171 |
172 | batch_size = len(first_tensor)
173 | split_size = default(split_size, batch_size)
174 | num_chunks = ceil(batch_size / split_size)
175 |
176 | dict_len = len(kwargs)
177 | dict_keys = kwargs.keys()
178 | split_kwargs_index = len_all_args - dict_len
179 |
180 | split_all_args = [split(arg, split_size = split_size) if exists(arg) and isinstance(arg, (torch.Tensor, Iterable)) else ((arg,) * num_chunks) for arg in all_args]
181 | chunk_sizes = tuple(map(len, split_all_args[0]))
182 |
183 | for (chunk_size, *chunked_all_args) in tuple(zip(chunk_sizes, *split_all_args)):
184 | chunked_args, chunked_kwargs_values = chunked_all_args[:split_kwargs_index], chunked_all_args[split_kwargs_index:]
185 | chunked_kwargs = dict(tuple(zip(dict_keys, chunked_kwargs_values)))
186 | chunk_size_frac = chunk_size / batch_size
187 | yield chunk_size_frac, (chunked_args, chunked_kwargs)
188 |
189 | # imagen trainer
190 |
191 | def imagen_sample_in_chunks(fn):
192 | @wraps(fn)
193 | def inner(self, *args, max_batch_size = None, **kwargs):
194 | if not exists(max_batch_size):
195 | return fn(self, *args, **kwargs)
196 |
197 | if self.imagen.unconditional:
198 | batch_size = kwargs.get('batch_size')
199 | batch_sizes = num_to_groups(batch_size, max_batch_size)
200 | outputs = [fn(self, *args, **{**kwargs, 'batch_size': sub_batch_size}) for sub_batch_size in batch_sizes]
201 | else:
202 | outputs = [fn(self, *chunked_args, **chunked_kwargs) for _, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs)]
203 |
204 | if isinstance(outputs[0], torch.Tensor):
205 | return torch.cat(outputs, dim = 0)
206 |
207 | return list(map(lambda t: torch.cat(t, dim = 0), list(zip(*outputs))))
208 |
209 | return inner
210 |
211 |
212 | def restore_parts(state_dict_target, state_dict_from):
213 | for name, param in state_dict_from.items():
214 |
215 | if name not in state_dict_target:
216 | continue
217 |
218 | if param.size() == state_dict_target[name].size():
219 | state_dict_target[name].copy_(param)
220 | else:
221 | print(f"layer {name}({param.size()} different than target: {state_dict_target[name].size()}")
222 |
223 | return state_dict_target
224 |
225 |
226 | class ImagenTrainer(nn.Module):
227 | locked = False
228 |
229 | def __init__(
230 | self,
231 | imagen = None,
232 | imagen_checkpoint_path = None,
233 | use_ema = True,
234 | lr = 1e-4,
235 | eps = 1e-8,
236 | beta1 = 0.9,
237 | beta2 = 0.99,
238 | max_grad_norm = None,
239 | group_wd_params = True,
240 | warmup_steps = None,
241 | cosine_decay_max_steps = None,
242 | only_train_unet_number = None,
243 | fp16 = False,
244 | precision = None,
245 | split_batches = True,
246 | dl_tuple_output_keywords_names = ('images', 'text_embeds', 'text_masks', 'cond_images'),
247 | verbose = True,
248 | split_valid_fraction = 0.025,
249 | split_valid_from_train = False,
250 | split_random_seed = 42,
251 | checkpoint_path = None,
252 | checkpoint_every = None,
253 | checkpoint_fs = None,
254 | fs_kwargs: dict = None,
255 | max_checkpoints_keep = 20,
256 | **kwargs
257 | ):
258 | super().__init__()
259 | assert not ImagenTrainer.locked, 'ImagenTrainer can only be initialized once per process - for the sake of distributed training, you will now have to create a separate script to train each unet (or a script that accepts unet number as an argument)'
260 | assert exists(imagen) ^ exists(imagen_checkpoint_path), 'either imagen instance is passed into the trainer, or a checkpoint path that contains the imagen config'
261 |
262 | # determine filesystem, using fsspec, for saving to local filesystem or cloud
263 |
264 | self.fs = checkpoint_fs
265 |
266 | if not exists(self.fs):
267 | fs_kwargs = default(fs_kwargs, {})
268 | self.fs, _ = url_to_fs(default(checkpoint_path, './'), **fs_kwargs)
269 |
270 | assert isinstance(imagen, (Imagen, ElucidatedImagen))
271 | ema_kwargs, kwargs = groupby_prefix_and_trim('ema_', kwargs)
272 |
273 | # elucidated or not
274 |
275 | self.is_elucidated = isinstance(imagen, ElucidatedImagen)
276 |
277 | # create accelerator instance
278 |
279 | accelerate_kwargs, kwargs = groupby_prefix_and_trim('accelerate_', kwargs)
280 |
281 | assert not (fp16 and exists(precision)), 'either set fp16 = True or forward the precision ("fp16", "bf16") to Accelerator'
282 | accelerator_mixed_precision = default(precision, 'fp16' if fp16 else 'no')
283 |
284 | self.accelerator = Accelerator(**{
285 | 'split_batches': split_batches,
286 | 'mixed_precision': accelerator_mixed_precision,
287 | 'kwargs_handlers': [DistributedDataParallelKwargs(find_unused_parameters = True)]
288 | , **accelerate_kwargs})
289 |
290 | ImagenTrainer.locked = self.is_distributed
291 |
292 | # cast data to fp16 at training time if needed
293 |
294 | self.cast_half_at_training = accelerator_mixed_precision == 'fp16'
295 |
296 | # grad scaler must be managed outside of accelerator
297 |
298 | grad_scaler_enabled = fp16
299 |
300 | # imagen, unets and ema unets
301 |
302 | self.imagen = imagen
303 | self.num_unets = len(self.imagen.unets)
304 |
305 | self.use_ema = use_ema and self.is_main
306 | self.ema_unets = nn.ModuleList([])
307 |
308 | # keep track of what unet is being trained on
309 | # only going to allow 1 unet training at a time
310 |
311 | self.ema_unet_being_trained_index = -1 # keeps track of which ema unet is being trained on
312 |
313 | # data related functions
314 |
315 | self.train_dl_iter = None
316 | self.train_dl = None
317 |
318 | self.valid_dl_iter = None
319 | self.valid_dl = None
320 |
321 | self.dl_tuple_output_keywords_names = dl_tuple_output_keywords_names
322 |
323 | # auto splitting validation from training, if dataset is passed in
324 |
325 | self.split_valid_from_train = split_valid_from_train
326 |
327 | assert 0 <= split_valid_fraction <= 1, 'split valid fraction must be between 0 and 1'
328 | self.split_valid_fraction = split_valid_fraction
329 | self.split_random_seed = split_random_seed
330 |
331 | # be able to finely customize learning rate, weight decay
332 | # per unet
333 |
334 | lr, eps, warmup_steps, cosine_decay_max_steps = map(partial(cast_tuple, length = self.num_unets), (lr, eps, warmup_steps, cosine_decay_max_steps))
335 |
336 | for ind, (unet, unet_lr, unet_eps, unet_warmup_steps, unet_cosine_decay_max_steps) in enumerate(zip(self.imagen.unets, lr, eps, warmup_steps, cosine_decay_max_steps)):
337 | optimizer = Adam(
338 | unet.parameters(),
339 | lr = unet_lr,
340 | eps = unet_eps,
341 | betas = (beta1, beta2),
342 | **kwargs
343 | )
344 |
345 | if self.use_ema:
346 | self.ema_unets.append(EMA(unet, **ema_kwargs))
347 |
348 | scaler = GradScaler(enabled = grad_scaler_enabled)
349 |
350 | scheduler = warmup_scheduler = None
351 |
352 | if exists(unet_cosine_decay_max_steps):
353 | scheduler = CosineAnnealingLR(optimizer, T_max = unet_cosine_decay_max_steps)
354 |
355 | if exists(unet_warmup_steps):
356 | warmup_scheduler = warmup.LinearWarmup(optimizer, warmup_period = unet_warmup_steps)
357 |
358 | if not exists(scheduler):
359 | scheduler = LambdaLR(optimizer, lr_lambda = lambda step: 1.0)
360 |
361 | # set on object
362 |
363 | setattr(self, f'optim{ind}', optimizer) # cannot use pytorch ModuleList for some reason with optimizers
364 | setattr(self, f'scaler{ind}', scaler)
365 | setattr(self, f'scheduler{ind}', scheduler)
366 | setattr(self, f'warmup{ind}', warmup_scheduler)
367 |
368 | # gradient clipping if needed
369 |
370 | self.max_grad_norm = max_grad_norm
371 |
372 | # step tracker and misc
373 |
374 | self.register_buffer('steps', torch.tensor([0] * self.num_unets))
375 |
376 | self.verbose = verbose
377 |
378 | # automatic set devices based on what accelerator decided
379 |
380 | self.imagen.to(self.device)
381 | self.to(self.device)
382 |
383 | # checkpointing
384 |
385 | assert not (exists(checkpoint_path) ^ exists(checkpoint_every))
386 | self.checkpoint_path = checkpoint_path
387 | self.checkpoint_every = checkpoint_every
388 | self.max_checkpoints_keep = max_checkpoints_keep
389 |
390 | self.can_checkpoint = self.is_local_main if isinstance(checkpoint_fs, LocalFileSystem) else self.is_main
391 |
392 | if exists(checkpoint_path) and self.can_checkpoint:
393 | bucket = url_to_bucket(checkpoint_path)
394 |
395 | if not self.fs.exists(bucket):
396 | self.fs.mkdir(bucket)
397 |
398 | self.load_from_checkpoint_folder()
399 |
400 | # only allowing training for unet
401 |
402 | self.only_train_unet_number = only_train_unet_number
403 | self.validate_and_set_unet_being_trained(only_train_unet_number)
404 |
405 | # computed values
406 |
407 | @property
408 | def device(self):
409 | return self.accelerator.device
410 |
411 | @property
412 | def is_distributed(self):
413 | return not (self.accelerator.distributed_type == DistributedType.NO and self.accelerator.num_processes == 1)
414 |
415 | @property
416 | def is_main(self):
417 | return self.accelerator.is_main_process
418 |
419 | @property
420 | def is_local_main(self):
421 | return self.accelerator.is_local_main_process
422 |
423 | @property
424 | def unwrapped_unet(self):
425 | return self.accelerator.unwrap_model(self.unet_being_trained)
426 |
427 | # optimizer helper functions
428 |
429 | def get_lr(self, unet_number):
430 | self.validate_unet_number(unet_number)
431 | unet_index = unet_number - 1
432 |
433 | optim = getattr(self, f'optim{unet_index}')
434 |
435 | return optim.param_groups[0]['lr']
436 |
437 | # function for allowing only one unet from being trained at a time
438 |
439 | def validate_and_set_unet_being_trained(self, unet_number = None):
440 | if exists(unet_number):
441 | self.validate_unet_number(unet_number)
442 |
443 | assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, 'you can only train one unet at a time. you will need to save the trainer into a checkpoint, and resume training on a new unet'
444 |
445 | self.only_train_unet_number = unet_number
446 | self.imagen.only_train_unet_number = unet_number
447 |
448 | if not exists(unet_number):
449 | return
450 |
451 | self.wrap_unet(unet_number)
452 |
453 | def wrap_unet(self, unet_number):
454 | if hasattr(self, 'one_unet_wrapped'):
455 | return
456 |
457 | unet = self.imagen.get_unet(unet_number)
458 | self.unet_being_trained = self.accelerator.prepare(unet)
459 | unet_index = unet_number - 1
460 |
461 | optimizer = getattr(self, f'optim{unet_index}')
462 | scheduler = getattr(self, f'scheduler{unet_index}')
463 |
464 | optimizer = self.accelerator.prepare(optimizer)
465 |
466 | if exists(scheduler):
467 | scheduler = self.accelerator.prepare(scheduler)
468 |
469 | setattr(self, f'optim{unet_index}', optimizer)
470 | setattr(self, f'scheduler{unet_index}', scheduler)
471 |
472 | self.one_unet_wrapped = True
473 |
474 | # hacking accelerator due to not having separate gradscaler per optimizer
475 |
476 | def set_accelerator_scaler(self, unet_number):
477 | unet_number = self.validate_unet_number(unet_number)
478 | scaler = getattr(self, f'scaler{unet_number - 1}')
479 |
480 | self.accelerator.scaler = scaler
481 | for optimizer in self.accelerator._optimizers:
482 | optimizer.scaler = scaler
483 |
484 | # helper print
485 |
486 | def print(self, msg):
487 | if not self.is_main:
488 | return
489 |
490 | if not self.verbose:
491 | return
492 |
493 | return self.accelerator.print(msg)
494 |
495 | # validating the unet number
496 |
497 | def validate_unet_number(self, unet_number = None):
498 | if self.num_unets == 1:
499 | unet_number = default(unet_number, 1)
500 |
501 | assert 0 < unet_number <= self.num_unets, f'unet number should be between 1 and {self.num_unets}'
502 | return unet_number
503 |
504 | # number of training steps taken
505 |
506 | def num_steps_taken(self, unet_number = None):
507 | if self.num_unets == 1:
508 | unet_number = default(unet_number, 1)
509 |
510 | return self.steps[unet_number - 1].item()
511 |
512 | def print_untrained_unets(self):
513 | print_final_error = False
514 |
515 | for ind, (steps, unet) in enumerate(zip(self.steps.tolist(), self.imagen.unets)):
516 | if steps > 0 or isinstance(unet, NullUnet):
517 | continue
518 |
519 | self.print(f'unet {ind + 1} has not been trained')
520 | print_final_error = True
521 |
522 | if print_final_error:
523 | self.print('when sampling, you can pass stop_at_unet_number to stop early in the cascade, so it does not try to generate with untrained unets')
524 |
525 | # data related functions
526 |
527 | def add_train_dataloader(self, dl = None):
528 | if not exists(dl):
529 | return
530 |
531 | assert not exists(self.train_dl), 'training dataloader was already added'
532 | self.train_dl = self.accelerator.prepare(dl)
533 |
534 | def add_valid_dataloader(self, dl):
535 | if not exists(dl):
536 | return
537 |
538 | assert not exists(self.valid_dl), 'validation dataloader was already added'
539 | self.valid_dl = self.accelerator.prepare(dl)
540 |
541 | def add_train_dataset(self, ds = None, *, batch_size, **dl_kwargs):
542 | if not exists(ds):
543 | return
544 |
545 | assert not exists(self.train_dl), 'training dataloader was already added'
546 |
547 | valid_ds = None
548 | if self.split_valid_from_train:
549 | train_size = int((1 - self.split_valid_fraction) * len(ds))
550 | valid_size = len(ds) - train_size
551 |
552 | ds, valid_ds = random_split(ds, [train_size, valid_size], generator = torch.Generator().manual_seed(self.split_random_seed))
553 | self.print(f'training with dataset of {len(ds)} samples and validating with randomly split {len(valid_ds)} samples')
554 |
555 | dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
556 | self.train_dl = self.accelerator.prepare(dl)
557 |
558 | if not self.split_valid_from_train:
559 | return
560 |
561 | self.add_valid_dataset(valid_ds, batch_size = batch_size, **dl_kwargs)
562 |
563 | def add_valid_dataset(self, ds, *, batch_size, **dl_kwargs):
564 | if not exists(ds):
565 | return
566 |
567 | assert not exists(self.valid_dl), 'validation dataloader was already added'
568 |
569 | dl = DataLoader(ds, batch_size = batch_size, **dl_kwargs)
570 | self.valid_dl = self.accelerator.prepare(dl)
571 |
572 | def create_train_iter(self):
573 | assert exists(self.train_dl), 'training dataloader has not been registered with the trainer yet'
574 |
575 | if exists(self.train_dl_iter):
576 | return
577 |
578 | self.train_dl_iter = cycle(self.train_dl)
579 |
580 | def create_valid_iter(self):
581 | assert exists(self.valid_dl), 'validation dataloader has not been registered with the trainer yet'
582 |
583 | if exists(self.valid_dl_iter):
584 | return
585 |
586 | self.valid_dl_iter = cycle(self.valid_dl)
587 |
588 | def train_step(self, unet_number = None, **kwargs):
589 | self.create_train_iter()
590 | loss = self.step_with_dl_iter(self.train_dl_iter, unet_number = unet_number, **kwargs)
591 | self.update(unet_number = unet_number)
592 | return loss
593 |
594 | @torch.no_grad()
595 | @eval_decorator
596 | def valid_step(self, **kwargs):
597 | self.create_valid_iter()
598 |
599 | context = self.use_ema_unets if kwargs.pop('use_ema_unets', False) else nullcontext
600 |
601 | with context():
602 | loss = self.step_with_dl_iter(self.valid_dl_iter, **kwargs)
603 | return loss
604 |
605 | def step_with_dl_iter(self, dl_iter, **kwargs):
606 | dl_tuple_output = cast_tuple(next(dl_iter))
607 | model_input = dict(list(zip(self.dl_tuple_output_keywords_names, dl_tuple_output)))
608 | loss = self.forward(**{**kwargs, **model_input})
609 | return loss
610 |
611 | # checkpointing functions
612 |
613 | @property
614 | def all_checkpoints_sorted(self):
615 | glob_pattern = os.path.join(self.checkpoint_path, '*.pt')
616 | checkpoints = self.fs.glob(glob_pattern)
617 | sorted_checkpoints = sorted(checkpoints, key = lambda x: int(str(x).split('.')[-2]), reverse = True)
618 | return sorted_checkpoints
619 |
620 | def load_from_checkpoint_folder(self, last_total_steps = -1):
621 | if last_total_steps != -1:
622 | filepath = os.path.join(self.checkpoint_path, f'checkpoint.{last_total_steps}.pt')
623 | self.load(filepath)
624 | return
625 |
626 | sorted_checkpoints = self.all_checkpoints_sorted
627 |
628 | if len(sorted_checkpoints) == 0:
629 | self.print(f'no checkpoints found to load from at {self.checkpoint_path}')
630 | return
631 |
632 | last_checkpoint = sorted_checkpoints[0]
633 | self.load(last_checkpoint)
634 |
635 | def save_to_checkpoint_folder(self):
636 | self.accelerator.wait_for_everyone()
637 |
638 | if not self.can_checkpoint:
639 | return
640 |
641 | total_steps = int(self.steps.sum().item())
642 | filepath = os.path.join(self.checkpoint_path, f'checkpoint.{total_steps}.pt')
643 |
644 | self.save(filepath)
645 |
646 | if self.max_checkpoints_keep <= 0:
647 | return
648 |
649 | sorted_checkpoints = self.all_checkpoints_sorted
650 | checkpoints_to_discard = sorted_checkpoints[self.max_checkpoints_keep:]
651 |
652 | for checkpoint in checkpoints_to_discard:
653 | self.fs.rm(checkpoint)
654 |
655 | # saving and loading functions
656 |
657 | def save(
658 | self,
659 | path,
660 | overwrite = True,
661 | without_optim_and_sched = False,
662 | **kwargs
663 | ):
664 | self.accelerator.wait_for_everyone()
665 |
666 | if not self.can_checkpoint:
667 | return
668 |
669 | fs = self.fs
670 |
671 | assert not (fs.exists(path) and not overwrite)
672 |
673 | self.reset_ema_unets_all_one_device()
674 |
675 | save_obj = dict(
676 | model = self.imagen.state_dict(),
677 | version = __version__,
678 | steps = self.steps.cpu(),
679 | **kwargs
680 | )
681 |
682 | save_optim_and_sched_iter = range(0, self.num_unets) if not without_optim_and_sched else tuple()
683 |
684 | for ind in save_optim_and_sched_iter:
685 | scaler_key = f'scaler{ind}'
686 | optimizer_key = f'optim{ind}'
687 | scheduler_key = f'scheduler{ind}'
688 | warmup_scheduler_key = f'warmup{ind}'
689 |
690 | scaler = getattr(self, scaler_key)
691 | optimizer = getattr(self, optimizer_key)
692 | scheduler = getattr(self, scheduler_key)
693 | warmup_scheduler = getattr(self, warmup_scheduler_key)
694 |
695 | if exists(scheduler):
696 | save_obj = {**save_obj, scheduler_key: scheduler.state_dict()}
697 |
698 | if exists(warmup_scheduler):
699 | save_obj = {**save_obj, warmup_scheduler_key: warmup_scheduler.state_dict()}
700 |
701 | save_obj = {**save_obj, scaler_key: scaler.state_dict(), optimizer_key: optimizer.state_dict()}
702 |
703 | if self.use_ema:
704 | save_obj = {**save_obj, 'ema': self.ema_unets.state_dict()}
705 |
706 | # determine if imagen config is available
707 |
708 | if hasattr(self.imagen, '_config'):
709 | self.print(f'this checkpoint is commandable from the CLI - "imagen --model {str(path)} \"\""')
710 |
711 | save_obj = {
712 | **save_obj,
713 | 'imagen_type': 'elucidated' if self.is_elucidated else 'original',
714 | 'imagen_params': self.imagen._config
715 | }
716 |
717 | # save to path
718 |
719 | with fs.open(path, 'wb') as f:
720 | torch.save(save_obj, f)
721 |
722 | self.print(f'checkpoint saved to {path}')
723 |
724 | def load(self, path, only_model = False, strict = True, noop_if_not_exist = False):
725 | fs = self.fs
726 |
727 | if noop_if_not_exist and not fs.exists(path):
728 | self.print(f'trainer checkpoint not found at {str(path)}')
729 | return
730 |
731 | assert fs.exists(path), f'{path} does not exist'
732 |
733 | self.reset_ema_unets_all_one_device()
734 |
735 | # to avoid extra GPU memory usage in main process when using Accelerate
736 |
737 | with fs.open(path) as f:
738 | loaded_obj = torch.load(f, map_location='cpu')
739 |
740 | if version.parse(__version__) != version.parse(loaded_obj['version']):
741 | self.print(f'loading saved imagen at version {loaded_obj["version"]}, but current package version is {__version__}')
742 |
743 | try:
744 | self.imagen.load_state_dict(loaded_obj['model'], strict = strict)
745 | except RuntimeError:
746 | print("Failed loading state dict. Trying partial load")
747 | self.imagen.load_state_dict(restore_parts(self.imagen.state_dict(),
748 | loaded_obj['model']))
749 |
750 | if only_model:
751 | return loaded_obj
752 |
753 | self.steps.copy_(loaded_obj['steps'])
754 |
755 | for ind in range(0, self.num_unets):
756 | scaler_key = f'scaler{ind}'
757 | optimizer_key = f'optim{ind}'
758 | scheduler_key = f'scheduler{ind}'
759 | warmup_scheduler_key = f'warmup{ind}'
760 |
761 | scaler = getattr(self, scaler_key)
762 | optimizer = getattr(self, optimizer_key)
763 | scheduler = getattr(self, scheduler_key)
764 | warmup_scheduler = getattr(self, warmup_scheduler_key)
765 |
766 | if exists(scheduler) and scheduler_key in loaded_obj:
767 | scheduler.load_state_dict(loaded_obj[scheduler_key])
768 |
769 | if exists(warmup_scheduler) and warmup_scheduler_key in loaded_obj:
770 | warmup_scheduler.load_state_dict(loaded_obj[warmup_scheduler_key])
771 |
772 | if exists(optimizer):
773 | try:
774 | optimizer.load_state_dict(loaded_obj[optimizer_key])
775 | scaler.load_state_dict(loaded_obj[scaler_key])
776 | except Exception:
777 | self.print('could not load optimizer and scaler, possibly because you have turned on mixed precision training since the last run. resuming with new optimizer and scalers')
778 |
779 | if self.use_ema:
780 | assert 'ema' in loaded_obj
781 | try:
782 | self.ema_unets.load_state_dict(loaded_obj['ema'], strict = strict)
783 | except RuntimeError:
784 | print("Failed loading state dict. Trying partial load")
785 | self.ema_unets.load_state_dict(restore_parts(self.ema_unets.state_dict(),
786 | loaded_obj['ema']))
787 |
788 | self.print(f'checkpoint loaded from {path}')
789 | return loaded_obj
790 |
791 | # managing ema unets and their devices
792 |
793 | @property
794 | def unets(self):
795 | return nn.ModuleList([ema.ema_model for ema in self.ema_unets])
796 |
797 | def get_ema_unet(self, unet_number = None):
798 | if not self.use_ema:
799 | return
800 |
801 | unet_number = self.validate_unet_number(unet_number)
802 | index = unet_number - 1
803 |
804 | if isinstance(self.unets, nn.ModuleList):
805 | unets_list = [unet for unet in self.ema_unets]
806 | delattr(self, 'ema_unets')
807 | self.ema_unets = unets_list
808 |
809 | if index != self.ema_unet_being_trained_index:
810 | for unet_index, unet in enumerate(self.ema_unets):
811 | unet.to(self.device if unet_index == index else 'cpu')
812 |
813 | self.ema_unet_being_trained_index = index
814 | return self.ema_unets[index]
815 |
816 | def reset_ema_unets_all_one_device(self, device = None):
817 | if not self.use_ema:
818 | return
819 |
820 | device = default(device, self.device)
821 | self.ema_unets = nn.ModuleList([*self.ema_unets])
822 | self.ema_unets.to(device)
823 |
824 | self.ema_unet_being_trained_index = -1
825 |
826 | @torch.no_grad()
827 | @contextmanager
828 | def use_ema_unets(self):
829 | if not self.use_ema:
830 | output = yield
831 | return output
832 |
833 | self.reset_ema_unets_all_one_device()
834 | self.imagen.reset_unets_all_one_device()
835 |
836 | self.unets.eval()
837 |
838 | trainable_unets = self.imagen.unets
839 | self.imagen.unets = self.unets # swap in exponential moving averaged unets for sampling
840 |
841 | output = yield
842 |
843 | self.imagen.unets = trainable_unets # restore original training unets
844 |
845 | # cast the ema_model unets back to original device
846 | for ema in self.ema_unets:
847 | ema.restore_ema_model_device()
848 |
849 | return output
850 |
851 | def print_unet_devices(self):
852 | self.print('unet devices:')
853 | for i, unet in enumerate(self.imagen.unets):
854 | device = next(unet.parameters()).device
855 | self.print(f'\tunet {i}: {device}')
856 |
857 | if not self.use_ema:
858 | return
859 |
860 | self.print('\nema unet devices:')
861 | for i, ema_unet in enumerate(self.ema_unets):
862 | device = next(ema_unet.parameters()).device
863 | self.print(f'\tema unet {i}: {device}')
864 |
865 | # overriding state dict functions
866 |
867 | def state_dict(self, *args, **kwargs):
868 | self.reset_ema_unets_all_one_device()
869 | return super().state_dict(*args, **kwargs)
870 |
871 | def load_state_dict(self, *args, **kwargs):
872 | self.reset_ema_unets_all_one_device()
873 | return super().load_state_dict(*args, **kwargs)
874 |
875 | # encoding text functions
876 |
877 | def encode_text(self, text, **kwargs):
878 | return self.imagen.encode_text(text, **kwargs)
879 |
880 | # forwarding functions and gradient step updates
881 |
882 | def update(self, unet_number = None):
883 | unet_number = self.validate_unet_number(unet_number)
884 | self.validate_and_set_unet_being_trained(unet_number)
885 | self.set_accelerator_scaler(unet_number)
886 |
887 | index = unet_number - 1
888 | unet = self.unet_being_trained
889 |
890 | optimizer = getattr(self, f'optim{index}')
891 | scaler = getattr(self, f'scaler{index}')
892 | scheduler = getattr(self, f'scheduler{index}')
893 | warmup_scheduler = getattr(self, f'warmup{index}')
894 |
895 | # set the grad scaler on the accelerator, since we are managing one per u-net
896 |
897 | if exists(self.max_grad_norm):
898 | self.accelerator.clip_grad_norm_(unet.parameters(), self.max_grad_norm)
899 |
900 | optimizer.step()
901 | optimizer.zero_grad()
902 |
903 | if self.use_ema:
904 | ema_unet = self.get_ema_unet(unet_number)
905 | ema_unet.update()
906 |
907 | # scheduler, if needed
908 |
909 | maybe_warmup_context = nullcontext() if not exists(warmup_scheduler) else warmup_scheduler.dampening()
910 |
911 | with maybe_warmup_context:
912 | if exists(scheduler) and not self.accelerator.optimizer_step_was_skipped: # recommended in the docs
913 | scheduler.step()
914 |
915 | self.steps += F.one_hot(torch.tensor(unet_number - 1, device = self.steps.device), num_classes = len(self.steps))
916 |
917 | if not exists(self.checkpoint_path):
918 | return
919 |
920 | total_steps = int(self.steps.sum().item())
921 |
922 | if total_steps % self.checkpoint_every:
923 | return
924 |
925 | self.save_to_checkpoint_folder()
926 |
927 | @torch.no_grad()
928 | @cast_torch_tensor
929 | @imagen_sample_in_chunks
930 | def sample(self, *args, **kwargs):
931 | context = nullcontext if kwargs.pop('use_non_ema', False) else self.use_ema_unets
932 |
933 | self.print_untrained_unets()
934 |
935 | if not self.is_main:
936 | kwargs['use_tqdm'] = False
937 |
938 | with context():
939 | output = self.imagen.sample(*args, device = self.device, **kwargs)
940 |
941 | return output
942 |
943 | @partial(cast_torch_tensor, cast_fp16 = True)
944 | def forward(
945 | self,
946 | *args,
947 | unet_number = None,
948 | max_batch_size = None,
949 | **kwargs
950 | ):
951 | unet_number = self.validate_unet_number(unet_number)
952 | self.validate_and_set_unet_being_trained(unet_number)
953 | self.set_accelerator_scaler(unet_number)
954 |
955 | assert not exists(self.only_train_unet_number) or self.only_train_unet_number == unet_number, f'you can only train unet #{self.only_train_unet_number}'
956 |
957 | total_loss = 0.
958 |
959 | for chunk_size_frac, (chunked_args, chunked_kwargs) in split_args_and_kwargs(*args, split_size = max_batch_size, **kwargs):
960 | with self.accelerator.autocast():
961 | loss = self.imagen(*chunked_args, unet = self.unet_being_trained, unet_number = unet_number, **chunked_kwargs)
962 | loss = loss * chunk_size_frac
963 |
964 | total_loss += loss.item()
965 |
966 | if self.training:
967 | self.accelerator.backward(loss)
968 |
969 | return total_loss
970 |
--------------------------------------------------------------------------------
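Usage sketch (illustrative, not a file in the repository): the `update`, `sample` and `forward` methods above are the trainer's public surface. The snippet below shows how they are typically driven; the `Unet`, `Imagen` and `ImagenTrainer` constructor arguments are assumptions taken from the project README and may differ between versions.

import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer

# build one small unet and wrap it in an Imagen cascade (assumed constructor arguments)
unet = Unet(dim = 32, dim_mults = (1, 2, 4), num_resnet_blocks = 1)
imagen = Imagen(unets = (unet,), image_sizes = (64,), timesteps = 100)
trainer = ImagenTrainer(imagen)

images = torch.randn(4, 3, 64, 64)
texts = ['a photo of a dog'] * 4

# forward() splits the batch by max_batch_size, scales each chunk's loss by its fraction
# and accumulates gradients via accelerator.backward; update() then steps the optimizer,
# the EMA copy and the learning rate scheduler for that unet
loss = trainer(images, texts = texts, unet_number = 1, max_batch_size = 2)
trainer.update(unet_number = 1)

# sample() swaps the EMA unets in for the trainable ones unless use_non_ema = True
samples = trainer.sample(texts = ['a photo of a cat'], use_non_ema = True)

The EMA swap is handled by the `use_ema_unets` context manager above: it moves the EMA copies onto the sampling device, yields, then restores the original training unets and devices.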
/imagen_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from functools import reduce
4 | from pathlib import Path
5 |
6 | from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig
7 | from ema_pytorch import EMA
8 |
9 | def exists(val):
10 | return val is not None
11 |
12 | def safeget(dictionary, keys, default = None):
13 | return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split('.'), dictionary)
14 |
15 | def load_imagen_from_checkpoint(
16 | checkpoint_path,
17 | load_weights = True,
18 | load_ema_if_available = False
19 | ):
20 | model_path = Path(checkpoint_path)
21 | full_model_path = str(model_path.resolve())
22 | assert model_path.exists(), f'checkpoint not found at {full_model_path}'
23 | loaded = torch.load(str(model_path), map_location='cpu')
24 |
25 |     imagen_params = safeget(loaded, 'imagen_params')
26 |     imagen_type = safeget(loaded, 'imagen_type')
27 |
28 |     assert exists(imagen_params) and exists(imagen_type), 'imagen type and configuration not saved in this checkpoint'
29 |
30 |     if imagen_type == 'original':
31 |         imagen_klass = ImagenConfig
32 |     elif imagen_type == 'elucidated':
33 |         imagen_klass = ElucidatedImagenConfig
34 |     else:
35 |         raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen from a configuration class, either ImagenConfig or ElucidatedImagenConfig')
36 |
37 | imagen = imagen_klass(**imagen_params).create()
38 |
39 | if not load_weights:
40 | return imagen
41 |
42 | has_ema = 'ema' in loaded
43 | should_load_ema = has_ema and load_ema_if_available
44 |
45 | imagen.load_state_dict(loaded['model'])
46 |
47 | if not should_load_ema:
48 | print('loading non-EMA version of unets')
49 | return imagen
50 |
51 | ema_unets = nn.ModuleList([])
52 | for unet in imagen.unets:
53 | ema_unets.append(EMA(unet))
54 |
55 | ema_unets.load_state_dict(loaded['ema'])
56 |
57 | for unet, ema_unet in zip(imagen.unets, ema_unets):
58 | unet.load_state_dict(ema_unet.ema_model.state_dict())
59 |
60 | print('loaded EMA version of unets')
61 | return imagen
62 |
--------------------------------------------------------------------------------
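Usage sketch (illustrative, not a file in the repository): `load_imagen_from_checkpoint` only works with checkpoints whose Imagen was built from `ImagenConfig` or `ElucidatedImagenConfig`, since it reads `imagen_type` and `imagen_params` back out of the saved dict. A minimal sketch, with a placeholder checkpoint path:

from imagen_pytorch.utils import load_imagen_from_checkpoint

# rebuild the model from the config stored in the checkpoint and load its weights;
# load_ema_if_available copies the EMA weights into the unets when the trainer saved them
imagen = load_imagen_from_checkpoint(
    './path/to/checkpoint.pt',      # placeholder path
    load_weights = True,            # set False to rebuild the architecture only
    load_ema_if_available = True
)
imagen.eval()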
/imagen_pytorch/version.py:
--------------------------------------------------------------------------------
1 | __version__ = '1.11.15'
2 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 | exec(open('imagen_pytorch/version.py').read())
3 |
4 | setup(
5 | name = 'imagen-pytorch',
6 | packages = find_packages(exclude=[]),
7 | include_package_data = True,
8 | entry_points={
9 | 'console_scripts': [
10 | 'imagen_pytorch = imagen_pytorch.cli:main',
11 | 'imagen = imagen_pytorch.cli:imagen'
12 | ],
13 | },
14 | version = __version__,
15 | license='MIT',
16 | description = 'Imagen - unprecedented photorealism × deep level of language understanding',
17 | author = 'Phil Wang',
18 | author_email = 'lucidrains@gmail.com',
19 | long_description_content_type = 'text/markdown',
20 | url = 'https://github.com/lucidrains/imagen-pytorch',
21 | keywords = [
22 | 'artificial intelligence',
23 | 'deep learning',
24 | 'transformers',
25 | 'text-to-image',
26 | 'denoising-diffusion'
27 | ],
28 | install_requires=[
29 | 'accelerate',
30 | 'click',
31 | 'einops>=0.4',
32 | 'einops-exts',
33 | 'ema-pytorch>=0.0.3',
34 | 'fsspec',
35 | 'kornia',
36 | 'numpy',
37 | 'packaging',
38 | 'pillow',
39 | 'pydantic',
40 | 'pytorch-lightning',
41 | 'pytorch-warmup',
42 | 'sentencepiece',
43 | 'torch>=1.6',
44 | 'torchvision',
45 | 'transformers',
46 | 'tqdm'
47 | ],
48 | classifiers=[
49 | 'Development Status :: 4 - Beta',
50 | 'Intended Audience :: Developers',
51 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
52 | 'License :: OSI Approved :: MIT License',
53 | 'Programming Language :: Python :: 3.6',
54 | ],
55 | )
56 |
--------------------------------------------------------------------------------
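Packaging note (illustrative, not a file in the repository): setup.py exec's version.py rather than importing `imagen_pytorch`, so the version string can be read at build time without torch and the other dependencies being installed. A small sketch of the same single-source-version pattern, using an explicit namespace instead of the module globals:

# read the version string without importing the package or its dependencies
namespace = {}
exec(open('imagen_pytorch/version.py').read(), namespace)
print(namespace['__version__'])   # '1.11.15' at the time of this dump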